diff --git a/datasource/flask_api_source.py b/datasource/flask_api_source.py index a2368b2..2ca69d0 100644 --- a/datasource/flask_api_source.py +++ b/datasource/flask_api_source.py @@ -14,6 +14,8 @@ from datetime import datetime from pathlib import Path from dotenv import load_dotenv +from .models import OHLCVResponse, validate_ohlcv_response + load_dotenv() @@ -132,14 +134,20 @@ class FlaskAPIDataSource: print(f"✗ API返回错误: {data['error']}") return None - # 解析数据 - records = data.get('data', []) - if not records: + # ✅ 使用 Pydantic 模型验证响应(类型安全) + try: + validated = validate_ohlcv_response(data) + except Exception as e: + print(f"✗ {code}: 响应数据验证失败 - {e}") + return None + + # 检查数据是否为空 + if not validated.data: print(f"⚠ {code}: 无数据返回") return None # 转换为 DataFrame - df = pd.DataFrame(records) + df = pd.DataFrame(validated.data) # 处理日期列 if 'date' in df.columns: @@ -153,37 +161,37 @@ class FlaskAPIDataSource: df = df[standard_cols] # 使用 API 返回的实际数据范围(而非请求参数) - actual_start = data.get('date_range', {}).get('start', start_date) - actual_end = data.get('date_range', {}).get('end', end_date) - actual_count = data.get('count', len(df)) + actual_start = validated.date_range.start if validated.date_range else start_date + actual_end = validated.date_range.end if validated.date_range else end_date + actual_count = validated.count # 缓存 info 信息(如果有) - if 'info' in data: - df.attrs['info'] = data['info'] + if validated.info: + df.attrs['info'] = validated.info # ETF 数据自动附加净值和溢价率信息 - if data.get('asset_type') == 'china_etf': + if validated.asset_type == 'china_etf': # 净值数据 - nav_section = data.get('nav', {}) - if nav_section.get('data'): - nav_df = pd.DataFrame(nav_section['data']) + if validated.nav and validated.nav.data: + nav_df = pd.DataFrame(validated.nav.data) if 'date' in nav_df.columns: nav_df['date'] = pd.to_datetime(nav_df['date']) nav_df = nav_df.set_index('date') df.attrs['nav'] = nav_df # 溢价率序列 - if 'premium_series' in data: - df.attrs['premium_series'] = data['premium_series'] + if validated.premium_series: + premium_dict = {item.date: item.premium for item in validated.premium_series} + df.attrs['premium_series'] = premium_dict # 最新溢价率 - if 'latest_premium' in data: - df.attrs['latest_premium'] = data['latest_premium'] - df.attrs['premium_date'] = data.get('premium_date') + if validated.latest_premium is not None: + df.attrs['latest_premium'] = validated.latest_premium + df.attrs['premium_date'] = validated.premium_date # 溢价率统计 - if 'premium_stats' in data: - df.attrs['premium_stats'] = data['premium_stats'] + if validated.premium_stats: + df.attrs['premium_stats'] = validated.premium_stats.model_dump() print(f"✓ {code}: {actual_count} 条数据 ({actual_start} ~ {actual_end})") return df diff --git a/datasource/flask_server.py b/datasource/flask_server.py index 7335699..d7ee61b 100644 --- a/datasource/flask_server.py +++ b/datasource/flask_server.py @@ -44,6 +44,7 @@ import pandas as pd from datasource.universal_fetcher import UniversalDataFetcher from datasource.asset_type_detector import AssetTypeDetector, AssetType +from datasource.models import dataframe_to_ohlcv_response, OHLCVResponse, ErrorResponse # ============================================================ @@ -730,91 +731,71 @@ def get_ohlcv(): result, is_cached = fetch_data_with_ttl(code, start, end, nocache, timeframe, adj, final_type) if result is None: - return jsonify({ - "code": code, - "asset_type": final_type.value, - "adj": adj, - "detected_type": detected_type.value if asset_type_param else None, # 仅当用户指定时显示 - "error": "No data available", - "start": start, - "end": end, - }), 404 + error_response = ErrorResponse( + error="No data available", + code=code, + asset_type=final_type.value, + adj=adj, + detected_type=detected_type.value if asset_type_param else None, + ) + return error_response.model_dump(mode='json'), 404 if "error" in result: - return jsonify({ - "code": code, - "asset_type": final_type.value, - "adj": adj, - "detected_type": detected_type.value if asset_type_param else None, - "error": result["error"], - }), 500 + error_response = ErrorResponse( + error=result["error"], + code=code, + asset_type=final_type.value, + adj=adj, + detected_type=detected_type.value if asset_type_param else None, + ) + return error_response.model_dump(mode='json'), 500 - result['cached'] = is_cached - result['asset_type'] = final_type.value # 使用最终类型 - result['adj'] = adj # 返回使用的 adj 参数 + # ✅ 使用 Pydantic 模型构建响应(类型安全) + # 从 result 中提取数据 + df_data = result.get('data', []) + attrs = result.get('attrs', {}) - # API 层职责:决定如何使用 attrs 中的业务数据 - if 'attrs' in result: - attrs = result['attrs'] - - # 根据资产类型决定日期格式精度 - # 加密货币使用分钟级,其他使用天级 - date_format = '%Y-%m-%d %H:%M:%S' if final_type == AssetType.CRYPTO else '%Y-%m-%d' - - # 提取净值到顶层(方便调用方使用) - if 'nav' in attrs: - nav_df = attrs['nav'] - if isinstance(nav_df, pd.DataFrame): - # 将 DataFrame 转换为列表格式(JSON 可序列化) - nav_df_copy = nav_df.reset_index().copy() - nav_df_copy['date'] = nav_df_copy['date'].dt.strftime(date_format) - nav_dict = { - 'data': nav_df_copy.to_dict(orient='records'), - 'count': len(nav_df_copy) - } - result['nav'] = nav_dict - else: - result['nav'] = nav_df - - # 提取溢价率到顶层(调用业务函数处理格式) - if 'premium' in attrs: - premium_result = build_premium_result_from_attrs(attrs['premium']) - if premium_result: - result.update(premium_result) - - # 将 attrs 中的 DataFrame/Series 转换为字典格式(用于 JSON 序列化) - attrs_serializable = {} - for key, value in attrs.items(): - if isinstance(value, pd.DataFrame): - df_copy = value.reset_index().copy() - if 'date' in df_copy.columns: - df_copy['date'] = df_copy['date'].dt.strftime(date_format) - attrs_serializable[key] = { - 'data': df_copy.to_dict(orient='records'), - 'count': len(df_copy) - } - elif isinstance(value, pd.Series): - # 将 Series 索引转换为字符串 - series_copy = value.copy() - series_copy.index = series_copy.index.strftime(date_format) - attrs_serializable[key] = { - 'type': 'series', - 'data': series_copy.to_dict(), - 'name': value.name - } - else: - attrs_serializable[key] = value - - result['attrs'] = attrs_serializable + # 重建 DataFrame(用于转换函数) + if df_data: + df = pd.DataFrame(df_data) + if 'date' in df.columns: + df['date'] = pd.to_datetime(df['date']) + df = df.set_index('date') + else: + df = pd.DataFrame() - # 如果用户指定了类型但与自动检测不同,显示提示 - if asset_type_param and detected_type != final_type: - result['type_override'] = { + # 提取 nav DataFrame + nav_df = attrs.get('nav') if isinstance(attrs.get('nav'), pd.DataFrame) else None + + # 提取 premium Series + premium_series = attrs.get('premium') if isinstance(attrs.get('premium'), pd.Series) else None + + # 构建响应模型 + response = dataframe_to_ohlcv_response( + df=df if len(df) > 0 else None, + code=code, + asset_type=final_type.value, + adj=adj, + cached=is_cached, + nav_df=nav_df, + premium_series=premium_series, + info=attrs.get('info'), + attrs=attrs, + columns=result.get('columns'), + date_range=result.get('date_range'), + requested_range=result.get('requested_range'), + available_range=result.get('available_range'), + cache_strategy=result.get('cache_strategy'), + timeframe=result.get('timeframe'), + type_override={ "detected": detected_type.value, "specified": final_type.value, "hint": "用户强制覆盖了自动检测结果", - } - return jsonify(result) + } if (asset_type_param and detected_type != final_type) else None, + ) + + # ✅ 自动序列化为 JSON + return response.model_dump(mode='json') diff --git a/datasource/models.py b/datasource/models.py index 5f89e86..5c4e8d7 100644 --- a/datasource/models.py +++ b/datasource/models.py @@ -328,12 +328,13 @@ def validate_ohlcv_response(data: Dict[str, Any]) -> OHLCVResponse: return OHLCVResponse.model_validate(data) -def dataframe_to_records(df) -> List[Dict[str, Any]]: +def dataframe_to_records(df, date_format: str = '%Y-%m-%d') -> List[Dict[str, Any]]: """ 将 DataFrame 转换为 OHLCVRecord 兼容的字典列表 Args: df: pandas DataFrame + date_format: 日期格式字符串 Returns: 字典列表(可直接用于 OHLCVResponse.data) @@ -349,7 +350,7 @@ def dataframe_to_records(df) -> List[Dict[str, Any]]: if col in df_reset.columns: try: import pandas as pd - df_reset[col] = pd.to_datetime(df_reset[col]).dt.strftime('%Y-%m-%d') + df_reset[col] = pd.to_datetime(df_reset[col]).dt.strftime(date_format) if col != 'date': df_reset = df_reset.rename(columns={col: 'date'}) break @@ -357,3 +358,179 @@ def dataframe_to_records(df) -> List[Dict[str, Any]]: pass return df_reset.to_dict(orient='records') + + +# ============================================================ +# DataFrame → Pydantic Model 转换函数 +# ============================================================ + +def dataframe_to_ohlcv_response( + df: Any, # pd.DataFrame + code: str, + asset_type: str, + adj: str = 'raw', + cached: bool = False, + nav_df: Optional[Any] = None, # Optional[pd.DataFrame] + premium_series: Optional[Any] = None, # Optional[pd.Series] + info: Optional[Dict[str, Any]] = None, + attrs: Optional[Dict[str, Any]] = None, + date_format: Optional[str] = None, + **kwargs +) -> 'OHLCVResponse': + """ + 将 DataFrame 转换为 OHLCVResponse 模型 + + 用途: + - Flask API: 统一响应结构 + - 本地调用: 获得类型安全的响应对象 + + Args: + df: 主数据 DataFrame(OHLCV) + code: 标的代码 + asset_type: 资产类型 + adj: 复权类型 + cached: 是否命中缓存 + nav_df: ETF 净值 DataFrame(可选) + premium_series: 溢价率 Series(可选) + info: 标的信息(可选) + attrs: 完整元数据(可选) + date_format: 日期格式(可选,默认根据 asset_type 自动选择) + **kwargs: 其他字段(columns, date_range, timeframe 等) + + Returns: + OHLCVResponse 模型实例 + + 使用示例: + # Flask API + df = fetcher.fetch("META", start, end) + response = dataframe_to_ohlcv_response( + df, code="META", asset_type="us_stock", adj="raw" + ) + return response.model_dump(mode='json') + + # 本地调用 + df = fetcher.fetch("513100.SH", start, end) + nav_df = df.attrs.get('nav') + premium = df.attrs.get('premium') + + response = dataframe_to_ohlcv_response( + df, + code="513100.SH", + asset_type="china_etf", + nav_df=nav_df, + premium_series=premium + ) + print(response.nav.count) # IDE 有自动补全 + """ + import pandas as pd + + # 自动选择日期格式 + if date_format is None: + date_format = '%Y-%m-%d %H:%M:%S' if asset_type == 'crypto' else '%Y-%m-%d' + + # 转换主数据 + data = dataframe_to_records(df, date_format) if df is not None else [] + + # 构建响应数据 + response_data = { + "code": code, + "asset_type": asset_type, + "adj": adj, + "count": len(data), + "data": data, + "cached": cached, + } + + # 添加 info(优先使用传入的,其次从 df.attrs 获取) + if info is not None: + response_data['info'] = info + elif hasattr(df, 'attrs') and df.attrs and 'info' in df.attrs: + response_data['info'] = df.attrs['info'] + + # 添加 nav(如果有) + if nav_df is not None and isinstance(nav_df, pd.DataFrame): + nav_records = dataframe_to_records(nav_df, date_format) + response_data['nav'] = { + "data": nav_records, + "count": len(nav_records) + } + + # 添加 premium(如果有) + if premium_series is not None and isinstance(premium_series, pd.Series) and len(premium_series) > 0: + # 最新溢价率 + latest_premium = float(premium_series.iloc[-1]) + response_data['latest_premium'] = round(latest_premium, 6) + response_data['premium_date'] = premium_series.index[-1].strftime(date_format) + + # 溢价率序列 + premium_list = [ + {"date": date.strftime(date_format), "premium": round(float(premium), 6)} + for date, premium in premium_series.items() + ] + response_data['premium_series'] = premium_list + + # 溢价率统计 + response_data['premium_stats'] = { + "mean": round(float(premium_series.mean()), 6), + "std": round(float(premium_series.std()), 6), + "min": round(float(premium_series.min()), 6), + "max": round(float(premium_series.max()), 6), + "median": round(float(premium_series.median()), 6), + } + + # 添加 attrs(如果有) + if attrs is not None: + # 过滤内部缓存元数据 + public_attrs = {k: v for k, v in attrs.items() if not k.startswith('_cache_')} + + # 转换 DataFrame/Series 为可序列化格式 + attrs_serializable = {} + for key, value in public_attrs.items(): + if isinstance(value, pd.DataFrame): + attrs_serializable[key] = { + 'data': dataframe_to_records(value, date_format), + 'count': len(value) + } + elif isinstance(value, pd.Series): + series_copy = value.copy() + series_copy.index = series_copy.index.strftime(date_format) + attrs_serializable[key] = { + 'type': 'series', + 'data': series_copy.to_dict(), + 'name': value.name + } + else: + attrs_serializable[key] = value + + if attrs_serializable: + response_data['attrs'] = attrs_serializable + elif hasattr(df, 'attrs') and df.attrs: + # 从 df.attrs 提取 + public_attrs = {k: v for k, v in df.attrs.items() if not k.startswith('_cache_')} + if public_attrs: + attrs_serializable = {} + for key, value in public_attrs.items(): + if isinstance(value, pd.DataFrame): + attrs_serializable[key] = { + 'data': dataframe_to_records(value, date_format), + 'count': len(value) + } + elif isinstance(value, pd.Series): + series_copy = value.copy() + series_copy.index = series_copy.index.strftime(date_format) + attrs_serializable[key] = { + 'type': 'series', + 'data': series_copy.to_dict(), + 'name': value.name + } + else: + attrs_serializable[key] = value + + if attrs_serializable: + response_data['attrs'] = attrs_serializable + + # 添加其他辅助信息 + response_data.update(kwargs) + + # 验证并返回模型 + return OHLCVResponse.model_validate(response_data)