"""
Pydantic data models
====================
Request and response data structures for the Flask API.

Purpose:
- Type-safe API response structures
- Data validation on the caller side
- IDE auto-completion support
- Groundwork for Phase 2 (Flask-Pydantic integration)

Example:
    from datasource.models import OHLCVResponse

    # Validate an API response
    response = requests.get(url).json()
    validated = OHLCVResponse.model_validate(response)

    # Typed field access
    print(validated.code)
    print(validated.latest_premium)
"""

import logging
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from pydantic import BaseModel, ConfigDict, Field, field_validator

if TYPE_CHECKING:  # pandas is only needed for annotations at import time
    import pandas as pd

logger = logging.getLogger(__name__)

# Column names probed, in priority order, when looking for the date column of a
# DataFrame in ``dataframe_to_records``.
_DATE_COLUMN_CANDIDATES = ('date', 'Date', 'index', 'trade_date', 'datetime')


# ============================================================
# Enumerations
# ============================================================

class AssetTypeEnum(str, Enum):
    """Asset type."""
    CHINA_INDEX = "china_index"
    CHINA_ETF = "china_etf"
    CHINA_STOCK = "china_stock"
    US_INDEX = "us_index"
    US_STOCK = "us_stock"
    HK_INDEX = "hk_index"
    HK_STOCK = "hk_stock"
    FUTURES = "futures"
    CRYPTO = "crypto"


class AdjTypeEnum(str, Enum):
    """Price adjustment type (raw / forward-adjusted / backward-adjusted)."""
    RAW = "raw"
    QFQ = "qfq"
    HFQ = "hfq"


class TimeframeEnum(str, Enum):
    """Candlestick timeframe (crypto)."""
    D1 = "1d"
    H1 = "1h"
    H4 = "4h"
    M15 = "15m"
    M1 = "1m"


# ============================================================
# Request models (used in Phase 2)
# ============================================================

class OHLCVRequest(BaseModel):
    """
    Request for fetching OHLCV data.

    Example:
        request = OHLCVRequest(
            code="META",
            start="2024-01-01",
            end="2024-03-31",
            adj=AdjTypeEnum.RAW
        )
    """
    code: str = Field(
        ...,
        min_length=1,
        max_length=50,
        description="标的代码(如 META, 513100.SH, BTC)"
    )
    start: Optional[str] = Field(
        None,
        pattern=r'^\d{4}-\d{2}-\d{2}$',
        description="开始日期(YYYY-MM-DD),默认90天前"
    )
    end: Optional[str] = Field(
        None,
        pattern=r'^\d{4}-\d{2}-\d{2}$',
        description="结束日期(YYYY-MM-DD),默认今天"
    )
    asset_type: Optional[AssetTypeEnum] = Field(
        None,
        description="资产类型(可选,覆盖自动检测)"
    )
    adj: AdjTypeEnum = Field(
        AdjTypeEnum.RAW,
        description="复权类型(raw/qfq/hfq)"
    )
    timeframe: Optional[TimeframeEnum] = Field(
        TimeframeEnum.D1,
        description="K线周期(仅加密货币需要)"
    )
    nocache: bool = Field(
        False,
        description="是否跳过缓存"
    )

    @field_validator('start', 'end')
    @classmethod
    def validate_date_format(cls, v: Optional[str]) -> Optional[str]:
        """Reject strings that match the pattern but are not real calendar dates.

        The ``pattern`` constraint only checks the digit layout, so a value
        such as ``2024-13-45`` would slip through without this check.

        Args:
            v: The raw ``start`` / ``end`` value, or ``None`` when omitted.

        Returns:
            ``v`` unchanged.

        Raises:
            ValueError: If ``v`` cannot be parsed as ``%Y-%m-%d``.
        """
        if v is None:
            return v

        try:
            datetime.strptime(v, '%Y-%m-%d')
        except ValueError as exc:
            # Chain the cause so the underlying parse failure stays visible.
            raise ValueError(f'Invalid date format: {v}. Use YYYY-MM-DD') from exc

        return v


class AssetTypeRequest(BaseModel):
    """Request for asset type detection."""
    code: str = Field(
        ...,
        min_length=1,
        max_length=50,
        description="标的代码"
    )


# ============================================================
# Response models
# ============================================================

class DataRange(BaseModel):
    """Date range."""
    start: Optional[str] = None
    end: Optional[str] = None


class RequestedRange(BaseModel):
    """Date range requested by the user."""
    start: str
    end: str


class AvailableRange(BaseModel):
    """Range of data that is available."""
    start: Optional[str] = None
    end: Optional[str] = None


class TypeOverride(BaseModel):
    """Information about an asset type override."""
    detected: str
    specified: str
    hint: str


class OHLCVRecord(BaseModel):
    """
    A single OHLCV data record.

    The set of fields varies with the asset type:
    - Stocks / indices / ETFs: date, open, high, low, close, volume
    - Crypto: date, open, high, low, close, volume
    - Others may additionally carry: amount, change, pct_chg, etc.
    """
    # Only ``date`` is declared; everything else is accepted as an extra field
    # because different asset types expose different columns.
    model_config = ConfigDict(extra="allow")

    date: str


class NavData(BaseModel):
    """ETF net asset value data."""
    data: List[Dict[str, Any]]
    count: int


class PremiumStats(BaseModel):
    """Premium rate statistics."""
    mean: float
    std: float
    min: float
    max: float
    median: float


class PremiumSeriesItem(BaseModel):
    """One item of the premium rate series."""
    date: str
    premium: float


class OHLCVResponse(BaseModel):
    """
    Response of the OHLCV data endpoint.

    Fields:
    - code: instrument code
    - asset_type: asset type
    - adj: price adjustment type
    - count: number of records
    - data: list of OHLCV records
    - cached: whether the cache was hit
    - nav: ETF net asset value data (china_etf only)
    - latest_premium: latest premium rate (china_etf only)
    - premium_series: premium rate series (china_etf only)
    - premium_date: date of the latest premium rate (china_etf only)
    - premium_stats: premium rate statistics (china_etf only)
    - attrs: full metadata
    - info: instrument information (stocks / indices)
    - columns: data column names
    - date_range: date range of the data
    - requested_range: range requested by the user
    - available_range: range of available data
    - cache_strategy: description of the cache strategy
    - timeframe: candlestick timeframe (crypto)
    - type_override: asset type override information

    Example:
        response = requests.get(url).json()
        validated = OHLCVResponse.model_validate(response)

        # Typed access
        print(validated.code)
        print(validated.latest_premium)
        print(validated.nav.count if validated.nav else 0)
    """
    # Unknown keys are kept rather than rejected, for backward compatibility.
    model_config = ConfigDict(extra="allow")

    code: str = Field(..., description="标的代码")
    asset_type: str = Field(..., description="资产类型")
    adj: str = Field(..., description="复权类型")
    count: int = Field(..., description="数据条数")
    data: List[Dict[str, Any]] = Field(..., description="OHLCV 数据列表")
    cached: Optional[bool] = Field(None, description="是否命中缓存")

    # ETF related (china_etf only)
    nav: Optional[NavData] = Field(None, description="ETF 净值数据")
    latest_premium: Optional[float] = Field(None, description="最新溢价率(%)")
    premium_series: Optional[List[PremiumSeriesItem]] = Field(None, description="溢价率序列")
    premium_date: Optional[str] = Field(None, description="最新溢价率日期")
    premium_stats: Optional[PremiumStats] = Field(None, description="溢价率统计")

    # Metadata
    attrs: Optional[Dict[str, Any]] = Field(None, description="完整元数据")
    info: Optional[Dict[str, Any]] = Field(None, description="标的信息")

    # Auxiliary information
    columns: Optional[List[str]] = Field(None, description="数据列名")
    date_range: Optional[DataRange] = Field(None, description="数据日期范围")
    requested_range: Optional[RequestedRange] = Field(None, description="用户请求范围")
    available_range: Optional[AvailableRange] = Field(None, description="可用数据范围")
    cache_strategy: Optional[str] = Field(None, description="缓存策略说明")
    timeframe: Optional[str] = Field(None, description="K线周期(加密货币)")
    type_override: Optional[TypeOverride] = Field(None, description="类型覆盖信息")


class AssetTypeResponse(BaseModel):
    """Response of the asset type detection endpoint."""
    code: str = Field(..., description="标的代码")
    asset_type: str = Field(..., description="资产类型")
    description: str = Field(..., description="资产类型描述")


class CacheStats(BaseModel):
    """Cache statistics."""
    lru_cache: Dict[str, Any] = Field(..., description="LRU 缓存统计")
    ttl_cache_size: int = Field(..., description="TTL 缓存大小")
    ttl_seconds: int = Field(..., description="TTL 秒数")
    default_start_date: str = Field(..., description="默认数据起点")
    cache_strategy: str = Field(..., description="缓存策略")


class ErrorResponse(BaseModel):
    """Error response."""
    # Unknown keys are kept rather than rejected.
    model_config = ConfigDict(extra="allow")

    error: str = Field(..., description="错误信息")
    code: Optional[str] = Field(None, description="标的代码")
    asset_type: Optional[str] = Field(None, description="资产类型")
    adj: Optional[str] = Field(None, description="复权类型")
    detected_type: Optional[str] = Field(None, description="自动检测的类型")
    example: Optional[str] = Field(None, description="示例请求")
    hint: Optional[str] = Field(None, description="提示信息")
    valid_adj: Optional[List[str]] = Field(None, description="有效的 adj 值")
    valid_types: Optional[List[str]] = Field(None, description="有效的资产类型")
    valid_timeframes: Optional[List[str]] = Field(None, description="有效的 timeframe 值")


# ============================================================
# Model helper functions
# ============================================================

def validate_ohlcv_response(data: Dict[str, Any]) -> OHLCVResponse:
    """
    Validate an OHLCV API response.

    Args:
        data: JSON payload returned by the API.

    Returns:
        The validated response object.

    Raises:
        ValidationError: If the data does not conform to the model.

    Example:
        response = requests.get(url).json()
        validated = validate_ohlcv_response(response)

        # Safe field access
        if validated.nav:
            print(f"NAV record count: {validated.nav.count}")
    """
    return OHLCVResponse.model_validate(data)


def dataframe_to_records(df: Optional["pd.DataFrame"]) -> List[Dict[str, Any]]:
    """
    Convert a DataFrame into a list of dicts compatible with OHLCVRecord.

    The index is moved into a regular column, then the first candidate column
    (``date``, ``Date``, ``index``, ``trade_date``, ``datetime``) that can be
    parsed as dates is formatted as ``YYYY-MM-DD`` and renamed to ``date``.
    If no candidate can be parsed, the frame is returned without a normalized
    ``date`` column (best effort).

    Args:
        df: pandas DataFrame, or ``None``.

    Returns:
        List of dicts (usable directly as ``OHLCVResponse.data``); an empty
        list when ``df`` is ``None`` or has no rows.
    """
    if df is None or len(df) == 0:
        return []

    # Imported lazily so that importing the models does not require pandas.
    import pandas as pd

    frame = df.reset_index()

    for col in _DATE_COLUMN_CANDIDATES:
        if col not in frame.columns:
            continue

        series = frame[col]
        # Numeric columns (typically the ``index`` column produced by
        # resetting a default RangeIndex) would be read by ``to_datetime`` as
        # epoch nanoseconds and silently yield 1970-01-01, hiding the real
        # date column further down the candidate list.
        if pd.api.types.is_numeric_dtype(series):
            continue

        try:
            formatted = pd.to_datetime(series).dt.strftime('%Y-%m-%d')
        except (ValueError, TypeError, OverflowError, AttributeError) as exc:
            # Best effort: this candidate is not a date column, try the next.
            logger.debug("Column %r could not be parsed as dates: %s", col, exc)
            continue

        frame[col] = formatted
        if col != 'date':
            frame = frame.rename(columns={col: 'date'})
        break

    return frame.to_dict(orient='records')