feat(pydantic): 集成 Pydantic 模型到 Flask API 层

1. models.py:
   - 添加 dataframe_to_ohlcv_response() 转换函数
   - 支持 DataFrame → OHLCVResponse 自动转换
   - 自动处理 nav、premium、attrs 等业务数据

2. flask_server.py:
   - 使用 Pydantic 模型构建响应(替代手动 Dict)
   - 错误响应使用 ErrorResponse 模型
   - 代码减少 20+ 行,类型安全提升

3. flask_api_source.py:
   - 使用 validate_ohlcv_response() 验证 API 响应
   - 类型安全访问 nav、premium、info 等字段
   - ETF 数据解析更可靠

测试通过:
 DataFrame → Pydantic 转换正常
 ETF 净值和溢价率正确处理
 线上 API 响应验证成功
 FlaskAPIDataSource 集成正常
This commit is contained in:
2026-05-24 01:13:33 +08:00
parent 72df18a28b
commit 226a27361f
3 changed files with 265 additions and 99 deletions

View File

@@ -328,12 +328,13 @@ def validate_ohlcv_response(data: Dict[str, Any]) -> OHLCVResponse:
return OHLCVResponse.model_validate(data)
def dataframe_to_records(df) -> List[Dict[str, Any]]:
def dataframe_to_records(df, date_format: str = '%Y-%m-%d') -> List[Dict[str, Any]]:
"""
将 DataFrame 转换为 OHLCVRecord 兼容的字典列表
Args:
df: pandas DataFrame
date_format: 日期格式字符串
Returns:
字典列表(可直接用于 OHLCVResponse.data
@@ -349,7 +350,7 @@ def dataframe_to_records(df) -> List[Dict[str, Any]]:
if col in df_reset.columns:
try:
import pandas as pd
df_reset[col] = pd.to_datetime(df_reset[col]).dt.strftime('%Y-%m-%d')
df_reset[col] = pd.to_datetime(df_reset[col]).dt.strftime(date_format)
if col != 'date':
df_reset = df_reset.rename(columns={col: 'date'})
break
@@ -357,3 +358,179 @@ def dataframe_to_records(df) -> List[Dict[str, Any]]:
pass
return df_reset.to_dict(orient='records')
# ============================================================
# DataFrame → Pydantic Model 转换函数
# ============================================================
def dataframe_to_ohlcv_response(
df: Any, # pd.DataFrame
code: str,
asset_type: str,
adj: str = 'raw',
cached: bool = False,
nav_df: Optional[Any] = None, # Optional[pd.DataFrame]
premium_series: Optional[Any] = None, # Optional[pd.Series]
info: Optional[Dict[str, Any]] = None,
attrs: Optional[Dict[str, Any]] = None,
date_format: Optional[str] = None,
**kwargs
) -> 'OHLCVResponse':
"""
将 DataFrame 转换为 OHLCVResponse 模型
用途:
- Flask API: 统一响应结构
- 本地调用: 获得类型安全的响应对象
Args:
df: 主数据 DataFrameOHLCV
code: 标的代码
asset_type: 资产类型
adj: 复权类型
cached: 是否命中缓存
nav_df: ETF 净值 DataFrame可选
premium_series: 溢价率 Series可选
info: 标的信息(可选)
attrs: 完整元数据(可选)
date_format: 日期格式(可选,默认根据 asset_type 自动选择)
**kwargs: 其他字段columns, date_range, timeframe 等)
Returns:
OHLCVResponse 模型实例
使用示例:
# Flask API
df = fetcher.fetch("META", start, end)
response = dataframe_to_ohlcv_response(
df, code="META", asset_type="us_stock", adj="raw"
)
return response.model_dump(mode='json')
# 本地调用
df = fetcher.fetch("513100.SH", start, end)
nav_df = df.attrs.get('nav')
premium = df.attrs.get('premium')
response = dataframe_to_ohlcv_response(
df,
code="513100.SH",
asset_type="china_etf",
nav_df=nav_df,
premium_series=premium
)
print(response.nav.count) # IDE 有自动补全
"""
import pandas as pd
# 自动选择日期格式
if date_format is None:
date_format = '%Y-%m-%d %H:%M:%S' if asset_type == 'crypto' else '%Y-%m-%d'
# 转换主数据
data = dataframe_to_records(df, date_format) if df is not None else []
# 构建响应数据
response_data = {
"code": code,
"asset_type": asset_type,
"adj": adj,
"count": len(data),
"data": data,
"cached": cached,
}
# 添加 info优先使用传入的其次从 df.attrs 获取)
if info is not None:
response_data['info'] = info
elif hasattr(df, 'attrs') and df.attrs and 'info' in df.attrs:
response_data['info'] = df.attrs['info']
# 添加 nav如果有
if nav_df is not None and isinstance(nav_df, pd.DataFrame):
nav_records = dataframe_to_records(nav_df, date_format)
response_data['nav'] = {
"data": nav_records,
"count": len(nav_records)
}
# 添加 premium如果有
if premium_series is not None and isinstance(premium_series, pd.Series) and len(premium_series) > 0:
# 最新溢价率
latest_premium = float(premium_series.iloc[-1])
response_data['latest_premium'] = round(latest_premium, 6)
response_data['premium_date'] = premium_series.index[-1].strftime(date_format)
# 溢价率序列
premium_list = [
{"date": date.strftime(date_format), "premium": round(float(premium), 6)}
for date, premium in premium_series.items()
]
response_data['premium_series'] = premium_list
# 溢价率统计
response_data['premium_stats'] = {
"mean": round(float(premium_series.mean()), 6),
"std": round(float(premium_series.std()), 6),
"min": round(float(premium_series.min()), 6),
"max": round(float(premium_series.max()), 6),
"median": round(float(premium_series.median()), 6),
}
# 添加 attrs如果有
if attrs is not None:
# 过滤内部缓存元数据
public_attrs = {k: v for k, v in attrs.items() if not k.startswith('_cache_')}
# 转换 DataFrame/Series 为可序列化格式
attrs_serializable = {}
for key, value in public_attrs.items():
if isinstance(value, pd.DataFrame):
attrs_serializable[key] = {
'data': dataframe_to_records(value, date_format),
'count': len(value)
}
elif isinstance(value, pd.Series):
series_copy = value.copy()
series_copy.index = series_copy.index.strftime(date_format)
attrs_serializable[key] = {
'type': 'series',
'data': series_copy.to_dict(),
'name': value.name
}
else:
attrs_serializable[key] = value
if attrs_serializable:
response_data['attrs'] = attrs_serializable
elif hasattr(df, 'attrs') and df.attrs:
# 从 df.attrs 提取
public_attrs = {k: v for k, v in df.attrs.items() if not k.startswith('_cache_')}
if public_attrs:
attrs_serializable = {}
for key, value in public_attrs.items():
if isinstance(value, pd.DataFrame):
attrs_serializable[key] = {
'data': dataframe_to_records(value, date_format),
'count': len(value)
}
elif isinstance(value, pd.Series):
series_copy = value.copy()
series_copy.index = series_copy.index.strftime(date_format)
attrs_serializable[key] = {
'type': 'series',
'data': series_copy.to_dict(),
'name': value.name
}
else:
attrs_serializable[key] = value
if attrs_serializable:
response_data['attrs'] = attrs_serializable
# 添加其他辅助信息
response_data.update(kwargs)
# 验证并返回模型
return OHLCVResponse.model_validate(response_data)