feat(pydantic): 集成 Pydantic 模型到 Flask API 层
1. models.py:
   - 添加 dataframe_to_ohlcv_response() 转换函数
   - 支持 DataFrame → OHLCVResponse 自动转换
   - 自动处理 nav、premium、attrs 等业务数据

2. flask_server.py:
   - 使用 Pydantic 模型构建响应(替代手动 Dict)
   - 错误响应使用 ErrorResponse 模型
   - 代码减少 20+ 行,类型安全提升

3. flask_api_source.py:
   - 使用 validate_ohlcv_response() 验证 API 响应
   - 类型安全访问 nav、premium、info 等字段
   - ETF 数据解析更可靠

测试通过:
   ✅ DataFrame → Pydantic 转换正常
   ✅ ETF 净值和溢价率正确处理
   ✅ 线上 API 响应验证成功
   ✅ FlaskAPIDataSource 集成正常
This commit is contained in:
@@ -14,6 +14,8 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from .models import OHLCVResponse, validate_ohlcv_response
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
@@ -132,14 +134,20 @@ class FlaskAPIDataSource:
|
||||
print(f"✗ API返回错误: {data['error']}")
|
||||
return None
|
||||
|
||||
# 解析数据
|
||||
records = data.get('data', [])
|
||||
if not records:
|
||||
# ✅ 使用 Pydantic 模型验证响应(类型安全)
|
||||
try:
|
||||
validated = validate_ohlcv_response(data)
|
||||
except Exception as e:
|
||||
print(f"✗ {code}: 响应数据验证失败 - {e}")
|
||||
return None
|
||||
|
||||
# 检查数据是否为空
|
||||
if not validated.data:
|
||||
print(f"⚠ {code}: 无数据返回")
|
||||
return None
|
||||
|
||||
# 转换为 DataFrame
|
||||
df = pd.DataFrame(records)
|
||||
df = pd.DataFrame(validated.data)
|
||||
|
||||
# 处理日期列
|
||||
if 'date' in df.columns:
|
||||
@@ -153,37 +161,37 @@ class FlaskAPIDataSource:
|
||||
df = df[standard_cols]
|
||||
|
||||
# 使用 API 返回的实际数据范围(而非请求参数)
|
||||
actual_start = data.get('date_range', {}).get('start', start_date)
|
||||
actual_end = data.get('date_range', {}).get('end', end_date)
|
||||
actual_count = data.get('count', len(df))
|
||||
actual_start = validated.date_range.start if validated.date_range else start_date
|
||||
actual_end = validated.date_range.end if validated.date_range else end_date
|
||||
actual_count = validated.count
|
||||
|
||||
# 缓存 info 信息(如果有)
|
||||
if 'info' in data:
|
||||
df.attrs['info'] = data['info']
|
||||
if validated.info:
|
||||
df.attrs['info'] = validated.info
|
||||
|
||||
# ETF 数据自动附加净值和溢价率信息
|
||||
if data.get('asset_type') == 'china_etf':
|
||||
if validated.asset_type == 'china_etf':
|
||||
# 净值数据
|
||||
nav_section = data.get('nav', {})
|
||||
if nav_section.get('data'):
|
||||
nav_df = pd.DataFrame(nav_section['data'])
|
||||
if validated.nav and validated.nav.data:
|
||||
nav_df = pd.DataFrame(validated.nav.data)
|
||||
if 'date' in nav_df.columns:
|
||||
nav_df['date'] = pd.to_datetime(nav_df['date'])
|
||||
nav_df = nav_df.set_index('date')
|
||||
df.attrs['nav'] = nav_df
|
||||
|
||||
# 溢价率序列
|
||||
if 'premium_series' in data:
|
||||
df.attrs['premium_series'] = data['premium_series']
|
||||
if validated.premium_series:
|
||||
premium_dict = {item.date: item.premium for item in validated.premium_series}
|
||||
df.attrs['premium_series'] = premium_dict
|
||||
|
||||
# 最新溢价率
|
||||
if 'latest_premium' in data:
|
||||
df.attrs['latest_premium'] = data['latest_premium']
|
||||
df.attrs['premium_date'] = data.get('premium_date')
|
||||
if validated.latest_premium is not None:
|
||||
df.attrs['latest_premium'] = validated.latest_premium
|
||||
df.attrs['premium_date'] = validated.premium_date
|
||||
|
||||
# 溢价率统计
|
||||
if 'premium_stats' in data:
|
||||
df.attrs['premium_stats'] = data['premium_stats']
|
||||
if validated.premium_stats:
|
||||
df.attrs['premium_stats'] = validated.premium_stats.model_dump()
|
||||
|
||||
print(f"✓ {code}: {actual_count} 条数据 ({actual_start} ~ {actual_end})")
|
||||
return df
|
||||
|
||||
@@ -44,6 +44,7 @@ import pandas as pd
|
||||
|
||||
from datasource.universal_fetcher import UniversalDataFetcher
|
||||
from datasource.asset_type_detector import AssetTypeDetector, AssetType
|
||||
from datasource.models import dataframe_to_ohlcv_response, OHLCVResponse, ErrorResponse
|
||||
|
||||
|
||||
# ============================================================
|
||||
@@ -730,91 +731,71 @@ def get_ohlcv():
|
||||
result, is_cached = fetch_data_with_ttl(code, start, end, nocache, timeframe, adj, final_type)
|
||||
|
||||
if result is None:
|
||||
return jsonify({
|
||||
"code": code,
|
||||
"asset_type": final_type.value,
|
||||
"adj": adj,
|
||||
"detected_type": detected_type.value if asset_type_param else None, # 仅当用户指定时显示
|
||||
"error": "No data available",
|
||||
"start": start,
|
||||
"end": end,
|
||||
}), 404
|
||||
error_response = ErrorResponse(
|
||||
error="No data available",
|
||||
code=code,
|
||||
asset_type=final_type.value,
|
||||
adj=adj,
|
||||
detected_type=detected_type.value if asset_type_param else None,
|
||||
)
|
||||
return error_response.model_dump(mode='json'), 404
|
||||
|
||||
if "error" in result:
|
||||
return jsonify({
|
||||
"code": code,
|
||||
"asset_type": final_type.value,
|
||||
"adj": adj,
|
||||
"detected_type": detected_type.value if asset_type_param else None,
|
||||
"error": result["error"],
|
||||
}), 500
|
||||
error_response = ErrorResponse(
|
||||
error=result["error"],
|
||||
code=code,
|
||||
asset_type=final_type.value,
|
||||
adj=adj,
|
||||
detected_type=detected_type.value if asset_type_param else None,
|
||||
)
|
||||
return error_response.model_dump(mode='json'), 500
|
||||
|
||||
result['cached'] = is_cached
|
||||
result['asset_type'] = final_type.value # 使用最终类型
|
||||
result['adj'] = adj # 返回使用的 adj 参数
|
||||
# ✅ 使用 Pydantic 模型构建响应(类型安全)
|
||||
# 从 result 中提取数据
|
||||
df_data = result.get('data', [])
|
||||
attrs = result.get('attrs', {})
|
||||
|
||||
# API 层职责:决定如何使用 attrs 中的业务数据
|
||||
if 'attrs' in result:
|
||||
attrs = result['attrs']
|
||||
|
||||
# 根据资产类型决定日期格式精度
|
||||
# 加密货币使用分钟级,其他使用天级
|
||||
date_format = '%Y-%m-%d %H:%M:%S' if final_type == AssetType.CRYPTO else '%Y-%m-%d'
|
||||
|
||||
# 提取净值到顶层(方便调用方使用)
|
||||
if 'nav' in attrs:
|
||||
nav_df = attrs['nav']
|
||||
if isinstance(nav_df, pd.DataFrame):
|
||||
# 将 DataFrame 转换为列表格式(JSON 可序列化)
|
||||
nav_df_copy = nav_df.reset_index().copy()
|
||||
nav_df_copy['date'] = nav_df_copy['date'].dt.strftime(date_format)
|
||||
nav_dict = {
|
||||
'data': nav_df_copy.to_dict(orient='records'),
|
||||
'count': len(nav_df_copy)
|
||||
}
|
||||
result['nav'] = nav_dict
|
||||
else:
|
||||
result['nav'] = nav_df
|
||||
|
||||
# 提取溢价率到顶层(调用业务函数处理格式)
|
||||
if 'premium' in attrs:
|
||||
premium_result = build_premium_result_from_attrs(attrs['premium'])
|
||||
if premium_result:
|
||||
result.update(premium_result)
|
||||
|
||||
# 将 attrs 中的 DataFrame/Series 转换为字典格式(用于 JSON 序列化)
|
||||
attrs_serializable = {}
|
||||
for key, value in attrs.items():
|
||||
if isinstance(value, pd.DataFrame):
|
||||
df_copy = value.reset_index().copy()
|
||||
if 'date' in df_copy.columns:
|
||||
df_copy['date'] = df_copy['date'].dt.strftime(date_format)
|
||||
attrs_serializable[key] = {
|
||||
'data': df_copy.to_dict(orient='records'),
|
||||
'count': len(df_copy)
|
||||
}
|
||||
elif isinstance(value, pd.Series):
|
||||
# 将 Series 索引转换为字符串
|
||||
series_copy = value.copy()
|
||||
series_copy.index = series_copy.index.strftime(date_format)
|
||||
attrs_serializable[key] = {
|
||||
'type': 'series',
|
||||
'data': series_copy.to_dict(),
|
||||
'name': value.name
|
||||
}
|
||||
else:
|
||||
attrs_serializable[key] = value
|
||||
|
||||
result['attrs'] = attrs_serializable
|
||||
# 重建 DataFrame(用于转换函数)
|
||||
if df_data:
|
||||
df = pd.DataFrame(df_data)
|
||||
if 'date' in df.columns:
|
||||
df['date'] = pd.to_datetime(df['date'])
|
||||
df = df.set_index('date')
|
||||
else:
|
||||
df = pd.DataFrame()
|
||||
|
||||
# 如果用户指定了类型但与自动检测不同,显示提示
|
||||
if asset_type_param and detected_type != final_type:
|
||||
result['type_override'] = {
|
||||
# 提取 nav DataFrame
|
||||
nav_df = attrs.get('nav') if isinstance(attrs.get('nav'), pd.DataFrame) else None
|
||||
|
||||
# 提取 premium Series
|
||||
premium_series = attrs.get('premium') if isinstance(attrs.get('premium'), pd.Series) else None
|
||||
|
||||
# 构建响应模型
|
||||
response = dataframe_to_ohlcv_response(
|
||||
df=df if len(df) > 0 else None,
|
||||
code=code,
|
||||
asset_type=final_type.value,
|
||||
adj=adj,
|
||||
cached=is_cached,
|
||||
nav_df=nav_df,
|
||||
premium_series=premium_series,
|
||||
info=attrs.get('info'),
|
||||
attrs=attrs,
|
||||
columns=result.get('columns'),
|
||||
date_range=result.get('date_range'),
|
||||
requested_range=result.get('requested_range'),
|
||||
available_range=result.get('available_range'),
|
||||
cache_strategy=result.get('cache_strategy'),
|
||||
timeframe=result.get('timeframe'),
|
||||
type_override={
|
||||
"detected": detected_type.value,
|
||||
"specified": final_type.value,
|
||||
"hint": "用户强制覆盖了自动检测结果",
|
||||
}
|
||||
return jsonify(result)
|
||||
} if (asset_type_param and detected_type != final_type) else None,
|
||||
)
|
||||
|
||||
# ✅ 自动序列化为 JSON
|
||||
return response.model_dump(mode='json')
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -328,12 +328,13 @@ def validate_ohlcv_response(data: Dict[str, Any]) -> OHLCVResponse:
|
||||
return OHLCVResponse.model_validate(data)
|
||||
|
||||
|
||||
def dataframe_to_records(df) -> List[Dict[str, Any]]:
|
||||
def dataframe_to_records(df, date_format: str = '%Y-%m-%d') -> List[Dict[str, Any]]:
|
||||
"""
|
||||
将 DataFrame 转换为 OHLCVRecord 兼容的字典列表
|
||||
|
||||
Args:
|
||||
df: pandas DataFrame
|
||||
date_format: 日期格式字符串
|
||||
|
||||
Returns:
|
||||
字典列表(可直接用于 OHLCVResponse.data)
|
||||
@@ -349,7 +350,7 @@ def dataframe_to_records(df) -> List[Dict[str, Any]]:
|
||||
if col in df_reset.columns:
|
||||
try:
|
||||
import pandas as pd
|
||||
df_reset[col] = pd.to_datetime(df_reset[col]).dt.strftime('%Y-%m-%d')
|
||||
df_reset[col] = pd.to_datetime(df_reset[col]).dt.strftime(date_format)
|
||||
if col != 'date':
|
||||
df_reset = df_reset.rename(columns={col: 'date'})
|
||||
break
|
||||
@@ -357,3 +358,179 @@ def dataframe_to_records(df) -> List[Dict[str, Any]]:
|
||||
pass
|
||||
|
||||
return df_reset.to_dict(orient='records')
|
||||
|
||||
|
||||
# ============================================================
|
||||
# DataFrame → Pydantic Model 转换函数
|
||||
# ============================================================
|
||||
|
||||
def _serialize_attrs(attrs: Dict[str, Any], date_format: str) -> Dict[str, Any]:
    """Convert an attrs mapping into a JSON-serializable dictionary.

    Keys starting with ``_cache_`` are dropped. DataFrame values become
    ``{'data': [...records...], 'count': n}``; Series values become
    ``{'type': 'series', 'data': {...}, 'name': ...}``; every other value is
    passed through unchanged.

    Args:
        attrs: Metadata mapping (e.g. ``DataFrame.attrs``).
        date_format: ``strftime`` format applied to date-like labels.

    Returns:
        A new dictionary; the input mapping is not modified.
    """
    import pandas as pd

    serializable: Dict[str, Any] = {}
    for key, value in attrs.items():
        # Internal cache metadata is not part of the public response.
        if key.startswith('_cache_'):
            continue

        if isinstance(value, pd.DataFrame):
            serializable[key] = {
                'data': dataframe_to_records(value, date_format),
                'count': len(value)
            }
        elif isinstance(value, pd.Series):
            index = value.index
            # Only date-like indexes expose strftime; any other index is
            # stringified instead of raising AttributeError.
            if hasattr(index, 'strftime'):
                labels = index.strftime(date_format)
            else:
                labels = index.astype(str)
            series_copy = value.copy()
            series_copy.index = labels
            serializable[key] = {
                'type': 'series',
                'data': series_copy.to_dict(),
                'name': value.name
            }
        else:
            serializable[key] = value

    return serializable


def dataframe_to_ohlcv_response(
    df: Any,  # pd.DataFrame
    code: str,
    asset_type: str,
    adj: str = 'raw',
    cached: bool = False,
    nav_df: Optional[Any] = None,  # Optional[pd.DataFrame]
    premium_series: Optional[Any] = None,  # Optional[pd.Series]
    info: Optional[Dict[str, Any]] = None,
    attrs: Optional[Dict[str, Any]] = None,
    date_format: Optional[str] = None,
    **kwargs
) -> 'OHLCVResponse':
    """Build an OHLCVResponse model from a DataFrame and optional extras.

    Intended uses:
        - Flask API: a single, uniform response structure.
        - Local callers: a typed response object.

    Args:
        df: Main OHLCV DataFrame; ``None`` yields an empty ``data`` list.
        code: Instrument code.
        asset_type: Asset type string.
        adj: Price adjustment type.
        cached: Whether the data came from cache.
        nav_df: ETF net-asset-value DataFrame (optional).
        premium_series: Premium-rate Series (optional). Its index labels
            must support ``strftime``.
        info: Instrument info (optional). When ``None``, ``df.attrs['info']``
            is used if present.
        attrs: Full metadata mapping (optional). When ``None``, ``df.attrs``
            is used if non-empty.
        date_format: Date format string (optional). Defaults to
            ``'%Y-%m-%d %H:%M:%S'`` when ``asset_type == 'crypto'`` and
            ``'%Y-%m-%d'`` otherwise.
        **kwargs: Extra response fields (columns, date_range, timeframe,
            ...). They are merged last and therefore override any computed
            field with the same key.

    Returns:
        The validated OHLCVResponse instance.

    Example:
        # Flask API
        df = fetcher.fetch("META", start, end)
        response = dataframe_to_ohlcv_response(
            df, code="META", asset_type="us_stock", adj="raw"
        )
        return response.model_dump(mode='json')

        # Local call
        df = fetcher.fetch("513100.SH", start, end)
        nav_df = df.attrs.get('nav')
        premium = df.attrs.get('premium')

        response = dataframe_to_ohlcv_response(
            df,
            code="513100.SH",
            asset_type="china_etf",
            nav_df=nav_df,
            premium_series=premium
        )
        print(response.nav.count)
    """
    import pandas as pd

    # Pick the date precision automatically when the caller did not.
    if date_format is None:
        date_format = '%Y-%m-%d %H:%M:%S' if asset_type == 'crypto' else '%Y-%m-%d'

    # Main OHLCV records.
    data = dataframe_to_records(df, date_format) if df is not None else []

    response_data = {
        "code": code,
        "asset_type": asset_type,
        "adj": adj,
        "count": len(data),
        "data": data,
        "cached": cached,
    }

    # info: the explicit argument wins, otherwise fall back to df.attrs.
    if info is not None:
        response_data['info'] = info
    elif hasattr(df, 'attrs') and df.attrs and 'info' in df.attrs:
        response_data['info'] = df.attrs['info']

    # nav (if provided).
    if nav_df is not None and isinstance(nav_df, pd.DataFrame):
        nav_records = dataframe_to_records(nav_df, date_format)
        response_data['nav'] = {
            "data": nav_records,
            "count": len(nav_records)
        }

    # premium (if provided and non-empty).
    if premium_series is not None and isinstance(premium_series, pd.Series) and len(premium_series) > 0:
        # Latest premium and its date.
        latest_premium = float(premium_series.iloc[-1])
        response_data['latest_premium'] = round(latest_premium, 6)
        response_data['premium_date'] = premium_series.index[-1].strftime(date_format)

        # Full premium series as a list of {date, premium} records.
        response_data['premium_series'] = [
            {"date": date.strftime(date_format), "premium": round(float(premium), 6)}
            for date, premium in premium_series.items()
        ]

        # Sample standard deviation is NaN for a single observation; report
        # 0.0 so the response never carries NaN for that case.
        premium_std = float(premium_series.std())
        if pd.isna(premium_std):
            premium_std = 0.0

        response_data['premium_stats'] = {
            "mean": round(float(premium_series.mean()), 6),
            "std": round(premium_std, 6),
            "min": round(float(premium_series.min()), 6),
            "max": round(float(premium_series.max()), 6),
            "median": round(float(premium_series.median()), 6),
        }

    # attrs: the explicit argument wins (even when empty), otherwise fall
    # back to a non-empty df.attrs.
    if attrs is not None:
        source_attrs = attrs
    elif hasattr(df, 'attrs') and df.attrs:
        source_attrs = df.attrs
    else:
        source_attrs = None

    if source_attrs:
        attrs_serializable = _serialize_attrs(source_attrs, date_format)
        if attrs_serializable:
            response_data['attrs'] = attrs_serializable

    # Merge the remaining auxiliary fields supplied by the caller.
    response_data.update(kwargs)

    # Validate and return the model.
    return OHLCVResponse.model_validate(response_data)
|
||||
|
||||
Reference in New Issue
Block a user