feat(pydantic): 集成 Pydantic 模型到 Flask API 层

1. models.py:
   - 添加 dataframe_to_ohlcv_response() 转换函数
   - 支持 DataFrame → OHLCVResponse 自动转换
   - 自动处理 nav、premium、attrs 等业务数据

2. flask_server.py:
   - 使用 Pydantic 模型构建响应(替代手动 Dict)
   - 错误响应使用 ErrorResponse 模型
   - 代码减少 20+ 行,类型安全提升

3. flask_api_source.py:
   - 使用 validate_ohlcv_response() 验证 API 响应
   - 类型安全访问 nav、premium、info 等字段
   - ETF 数据解析更可靠

测试通过:
 DataFrame → Pydantic 转换正常
 ETF 净值和溢价率正确处理
 线上 API 响应验证成功
 FlaskAPIDataSource 集成正常
This commit is contained in:
2026-05-24 01:13:33 +08:00
parent 72df18a28b
commit 226a27361f
3 changed files with 265 additions and 99 deletions

View File

@@ -14,6 +14,8 @@ from datetime import datetime
from pathlib import Path
from dotenv import load_dotenv
from .models import OHLCVResponse, validate_ohlcv_response
load_dotenv()
@@ -132,14 +134,20 @@ class FlaskAPIDataSource:
print(f"✗ API返回错误: {data['error']}")
return None
# 解析数据
records = data.get('data', [])
if not records:
# ✅ 使用 Pydantic 模型验证响应(类型安全)
try:
validated = validate_ohlcv_response(data)
except Exception as e:
print(f"{code}: 响应数据验证失败 - {e}")
return None
# 检查数据是否为空
if not validated.data:
print(f"{code}: 无数据返回")
return None
# 转换为 DataFrame
df = pd.DataFrame(records)
df = pd.DataFrame(validated.data)
# 处理日期列
if 'date' in df.columns:
@@ -153,37 +161,37 @@ class FlaskAPIDataSource:
df = df[standard_cols]
# 使用 API 返回的实际数据范围(而非请求参数)
actual_start = data.get('date_range', {}).get('start', start_date)
actual_end = data.get('date_range', {}).get('end', end_date)
actual_count = data.get('count', len(df))
actual_start = validated.date_range.start if validated.date_range else start_date
actual_end = validated.date_range.end if validated.date_range else end_date
actual_count = validated.count
# 缓存 info 信息(如果有)
if 'info' in data:
df.attrs['info'] = data['info']
if validated.info:
df.attrs['info'] = validated.info
# ETF 数据自动附加净值和溢价率信息
if data.get('asset_type') == 'china_etf':
if validated.asset_type == 'china_etf':
# 净值数据
nav_section = data.get('nav', {})
if nav_section.get('data'):
nav_df = pd.DataFrame(nav_section['data'])
if validated.nav and validated.nav.data:
nav_df = pd.DataFrame(validated.nav.data)
if 'date' in nav_df.columns:
nav_df['date'] = pd.to_datetime(nav_df['date'])
nav_df = nav_df.set_index('date')
df.attrs['nav'] = nav_df
# 溢价率序列
if 'premium_series' in data:
df.attrs['premium_series'] = data['premium_series']
if validated.premium_series:
premium_dict = {item.date: item.premium for item in validated.premium_series}
df.attrs['premium_series'] = premium_dict
# 最新溢价率
if 'latest_premium' in data:
df.attrs['latest_premium'] = data['latest_premium']
df.attrs['premium_date'] = data.get('premium_date')
if validated.latest_premium is not None:
df.attrs['latest_premium'] = validated.latest_premium
df.attrs['premium_date'] = validated.premium_date
# 溢价率统计
if 'premium_stats' in data:
df.attrs['premium_stats'] = data['premium_stats']
if validated.premium_stats:
df.attrs['premium_stats'] = validated.premium_stats.model_dump()
print(f"{code}: {actual_count} 条数据 ({actual_start} ~ {actual_end})")
return df