feat(models): 添加 Pydantic 数据模型(Phase 1)

- 定义请求模型:OHLCVRequest, AssetTypeRequest
- 定义响应模型:OHLCVResponse, AssetTypeResponse, ErrorResponse
- 定义枚举类型:AssetTypeEnum, AdjTypeEnum, TimeframeEnum
- 提供类型安全的 API 响应验证
- 支持 IDE 自动补全和类型检查
- 为 Phase 2 (Flask-Pydantic 集成) 做准备

测试通过:
 请求参数自动验证(日期格式、adj 值)
 响应数据验证(美股 META、ETF 513100.SH、BTC)
 序列化/反序列化正常
 类型安全检查(缺失字段、类型错误)
This commit is contained in:
2026-05-24 00:42:22 +08:00
parent 11a0a6502b
commit 72df18a28b

359
datasource/models.py Normal file
View File

@@ -0,0 +1,359 @@
"""
Pydantic 数据模型
==================
定义 Flask API 的请求和响应数据结构
用途:
- 类型安全的 API 响应结构
- 调用方数据验证
- IDE 自动补全支持
- 为 Phase 2 (Flask-Pydantic 集成) 做准备
使用示例:
from datasource.models import OHLCVResponse
# 验证 API 响应
response = requests.get(url).json()
validated = OHLCVResponse.model_validate(response)
# 访问字段(有类型提示)
print(validated.code)
print(validated.latest_premium)
"""
from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, ConfigDict, Field, field_validator
# ============================================================
# 枚举类型
# ============================================================
class AssetTypeEnum(str, Enum):
    """Asset categories recognised by the API.

    Members also subclass ``str``, so each one compares equal to its raw
    value (for example ``AssetTypeEnum.CRYPTO == "crypto"``).
    """

    # China-prefixed categories
    CHINA_INDEX = "china_index"
    CHINA_ETF = "china_etf"
    CHINA_STOCK = "china_stock"

    # US-prefixed categories
    US_INDEX = "us_index"
    US_STOCK = "us_stock"

    # HK-prefixed categories
    HK_INDEX = "hk_index"
    HK_STOCK = "hk_stock"

    # Remaining asset classes
    FUTURES = "futures"
    CRYPTO = "crypto"
class AdjTypeEnum(str, Enum):
    """Price-adjustment modes accepted by the ``adj`` parameter.

    Members also subclass ``str`` and compare equal to their raw value.
    """

    RAW = "raw"
    # NOTE(review): "qfq" / "hfq" look like the usual abbreviations for
    # forward- and backward-adjusted prices -- confirm with the data source.
    QFQ = "qfq"
    HFQ = "hfq"
class TimeframeEnum(str, Enum):
    """Candlestick (K-line) intervals, used for cryptocurrency data.

    Members also subclass ``str`` and compare equal to their raw value.
    """

    D1 = "1d"
    H1 = "1h"
    H4 = "4h"
    M15 = "15m"
    M1 = "1m"
# ============================================================
# 请求模型(Phase 2 使用)
# ============================================================
class OHLCVRequest(BaseModel):
    """Request parameters for fetching OHLCV data.

    Example:
        request = OHLCVRequest(
            code="META",
            start="2024-01-01",
            end="2024-03-31",
            adj=AdjTypeEnum.RAW,
        )
    """

    # Instrument code; required, 1 to 50 characters.
    code: str = Field(
        default=...,
        min_length=1,
        max_length=50,
        description="标的代码(如 META, 513100.SH, BTC",
    )
    # Optional start date in YYYY-MM-DD form; None when omitted.
    start: Optional[str] = Field(
        default=None,
        pattern=r'^\d{4}-\d{2}-\d{2}$',
        description="开始日期YYYY-MM-DD默认90天前",
    )
    # Optional end date in YYYY-MM-DD form; None when omitted.
    end: Optional[str] = Field(
        default=None,
        pattern=r'^\d{4}-\d{2}-\d{2}$',
        description="结束日期YYYY-MM-DD默认今天",
    )
    # Explicit asset type; None leaves the type unspecified.
    asset_type: Optional[AssetTypeEnum] = Field(
        default=None,
        description="资产类型(可选,覆盖自动检测)",
    )
    # Price-adjustment mode; defaults to RAW.
    adj: AdjTypeEnum = Field(
        default=AdjTypeEnum.RAW,
        description="复权类型raw/qfq/hfq",
    )
    # Candlestick interval; defaults to D1.
    timeframe: Optional[TimeframeEnum] = Field(
        default=TimeframeEnum.D1,
        description="K线周期仅加密货币需要",
    )
    # Whether to bypass the cache; defaults to False.
    nocache: bool = Field(
        default=False,
        description="是否跳过缓存",
    )

    @field_validator('start', 'end')
    @classmethod
    def validate_date_format(cls, v: Optional[str]) -> Optional[str]:
        """Check that a supplied date string is a real calendar date.

        Args:
            v: The ``start`` or ``end`` value, or None when omitted.

        Returns:
            The value unchanged.

        Raises:
            ValueError: If ``v`` cannot be parsed as ``%Y-%m-%d``.
        """
        if v is not None:
            from datetime import datetime

            # The field pattern only checks the digit layout; parsing also
            # rejects impossible dates such as month 13.
            try:
                datetime.strptime(v, '%Y-%m-%d')
            except ValueError:
                raise ValueError(f'Invalid date format: {v}. Use YYYY-MM-DD')
        return v
class AssetTypeRequest(BaseModel):
    """Request parameters for asset-type detection."""

    # Instrument code; required, 1 to 50 characters.
    code: str = Field(
        default=...,
        min_length=1,
        max_length=50,
        description="标的代码",
    )
# ============================================================
# 响应模型
# ============================================================
class DataRange(BaseModel):
    """A date range; either bound may be absent."""

    # Both bounds are optional strings that default to None.
    start: Optional[str] = None
    end: Optional[str] = None
class RequestedRange(BaseModel):
    """The date range the caller asked for; both bounds are required."""

    start: str
    end: str
class AvailableRange(BaseModel):
    """The date range for which data is available; either bound may be absent."""

    # Both bounds are optional strings that default to None.
    start: Optional[str] = None
    end: Optional[str] = None
class TypeOverride(BaseModel):
    """Details of an asset-type override; all three fields are required."""

    # Asset type that was detected.
    detected: str
    # Asset type that was specified instead.
    specified: str
    # Accompanying hint text.
    hint: str
class OHLCVRecord(BaseModel):
    """One OHLCV data row.

    Only ``date`` is declared. Every other column is accepted as an extra
    field because the set of columns varies by asset type:

    - stocks / indices / ETFs: date, open, high, low, close, volume
    - cryptocurrencies: date, open, high, low, close, volume
    - others may also carry amount, change, pct_chg, and so on
    """

    # Pydantic v2 configuration; replaces the deprecated v1-style
    # ``class Config``. ``extra="allow"`` keeps undeclared columns.
    model_config = ConfigDict(extra="allow")

    date: str
class NavData(BaseModel):
    """ETF net-asset-value payload."""

    # NAV rows as free-form dictionaries.
    data: List[Dict[str, Any]]
    # Row count reported alongside the data.
    count: int
class PremiumStats(BaseModel):
    """Summary statistics of the premium rate; every field is required."""

    mean: float
    std: float
    min: float
    max: float
    median: float
class PremiumSeriesItem(BaseModel):
    """A single point in the premium-rate series."""

    # Date of the observation, as a string.
    date: str
    # Premium rate on that date.
    premium: float
class OHLCVResponse(BaseModel):
    """Response returned by the OHLCV data endpoint.

    Fields:
        code: Instrument code.
        asset_type: Asset type.
        adj: Price-adjustment mode.
        count: Number of data rows.
        data: List of OHLCV rows.
        cached: Whether the cache was hit.
        nav: ETF net-asset-value data (china_etf only).
        latest_premium: Latest premium rate (china_etf only).
        premium_series: Premium-rate series (china_etf only).
        premium_date: Date of the latest premium rate (china_etf only).
        premium_stats: Premium-rate statistics (china_etf only).
        attrs: Full metadata.
        info: Instrument information (stocks / indices).
        columns: Data column names.
        date_range: Date range of the returned data.
        requested_range: Range requested by the caller.
        available_range: Range for which data is available.
        cache_strategy: Description of the cache strategy.
        timeframe: Candlestick interval (cryptocurrencies).
        type_override: Asset-type override information.

    Example:
        response = requests.get(url).json()
        validated = OHLCVResponse.model_validate(response)
        print(validated.code)
        print(validated.latest_premium)
        print(validated.nav.count if validated.nav else 0)
    """

    # Pydantic v2 configuration; replaces the deprecated v1-style
    # ``class Config``. Unknown keys are kept for backward compatibility.
    model_config = ConfigDict(extra="allow")

    # Core fields (required, except ``cached``)
    code: str = Field(..., description="标的代码")
    asset_type: str = Field(..., description="资产类型")
    adj: str = Field(..., description="复权类型")
    count: int = Field(..., description="数据条数")
    data: List[Dict[str, Any]] = Field(..., description="OHLCV 数据列表")
    cached: Optional[bool] = Field(None, description="是否命中缓存")

    # ETF-related fields (china_etf only)
    nav: Optional[NavData] = Field(None, description="ETF 净值数据")
    latest_premium: Optional[float] = Field(None, description="最新溢价率(%")
    premium_series: Optional[List[PremiumSeriesItem]] = Field(None, description="溢价率序列")
    premium_date: Optional[str] = Field(None, description="最新溢价率日期")
    premium_stats: Optional[PremiumStats] = Field(None, description="溢价率统计")

    # Metadata
    attrs: Optional[Dict[str, Any]] = Field(None, description="完整元数据")
    info: Optional[Dict[str, Any]] = Field(None, description="标的信息")

    # Auxiliary information
    columns: Optional[List[str]] = Field(None, description="数据列名")
    date_range: Optional[DataRange] = Field(None, description="数据日期范围")
    requested_range: Optional[RequestedRange] = Field(None, description="用户请求范围")
    available_range: Optional[AvailableRange] = Field(None, description="可用数据范围")
    cache_strategy: Optional[str] = Field(None, description="缓存策略说明")
    timeframe: Optional[str] = Field(None, description="K线周期加密货币")
    type_override: Optional[TypeOverride] = Field(None, description="类型覆盖信息")
class AssetTypeResponse(BaseModel):
    """Response returned by asset-type detection; all fields are required."""

    # Instrument code.
    code: str = Field(default=..., description="标的代码")
    # Asset type.
    asset_type: str = Field(default=..., description="资产类型")
    # Description of the asset type.
    description: str = Field(default=..., description="资产类型描述")
class CacheStats(BaseModel):
    """Cache statistics; all fields are required."""

    # LRU cache statistics as a free-form dictionary.
    lru_cache: Dict[str, Any] = Field(default=..., description="LRU 缓存统计")
    # Size of the TTL cache.
    ttl_cache_size: int = Field(default=..., description="TTL 缓存大小")
    # TTL in seconds.
    ttl_seconds: int = Field(default=..., description="TTL 秒数")
    # Default data start date.
    default_start_date: str = Field(default=..., description="默认数据起点")
    # Cache strategy.
    cache_strategy: str = Field(default=..., description="缓存策略")
class ErrorResponse(BaseModel):
    """Error payload returned by the API.

    Only ``error`` is required. The remaining fields are optional context
    and default to None.
    """

    # Pydantic v2 configuration; replaces the deprecated v1-style
    # ``class Config``. Unknown keys are kept rather than rejected.
    model_config = ConfigDict(extra="allow")

    # Error message (required).
    error: str = Field(..., description="错误信息")

    # Echo of the request context, when available.
    code: Optional[str] = Field(None, description="标的代码")
    asset_type: Optional[str] = Field(None, description="资产类型")
    adj: Optional[str] = Field(None, description="复权类型")
    detected_type: Optional[str] = Field(None, description="自动检测的类型")

    # Guidance for correcting the request.
    example: Optional[str] = Field(None, description="示例请求")
    hint: Optional[str] = Field(None, description="提示信息")
    valid_adj: Optional[List[str]] = Field(None, description="有效的 adj 值")
    valid_types: Optional[List[str]] = Field(None, description="有效的资产类型")
    valid_timeframes: Optional[List[str]] = Field(None, description="有效的 timeframe 值")
# ============================================================
# 模型工具函数
# ============================================================
def validate_ohlcv_response(data: Dict[str, Any]) -> OHLCVResponse:
    """Validate an OHLCV API response and return it as a typed model.

    Args:
        data: JSON data returned by the API, already decoded to a dict.

    Returns:
        The validated ``OHLCVResponse`` instance.

    Raises:
        ValidationError: If ``data`` does not conform to the model.

    Example:
        response = requests.get(url).json()
        validated = validate_ohlcv_response(response)
        if validated.nav:
            print(validated.nav.count)
    """
    validated = OHLCVResponse.model_validate(data)
    return validated
def dataframe_to_records(df) -> List[Dict[str, Any]]:
    """Convert a DataFrame into a list of record dicts with a ``date`` key.

    The index is moved into the columns with ``reset_index()``. The first
    candidate date column that parses successfully is formatted as
    ``YYYY-MM-DD`` and renamed to ``date``. Candidates are tried in this
    order: ``date``, ``Date``, ``index``, ``trade_date``, ``datetime``.

    A numeric ``index`` column is skipped. ``reset_index()`` creates one
    from a default integer index, and parsing those integers would turn
    every row into ``1970-01-01`` and prevent the real date column
    (``trade_date`` / ``datetime``) from being processed.

    Args:
        df: pandas DataFrame, or None.

    Returns:
        A list of dictionaries (usable as ``OHLCVResponse.data``); an
        empty list when ``df`` is None or has no rows.
    """
    if df is None or len(df) == 0:
        return []

    # Local imports keep the module importable without pandas until a
    # non-empty frame is actually converted.
    import logging

    import pandas as pd

    df_reset = df.reset_index()

    date_columns = ('date', 'Date', 'index', 'trade_date', 'datetime')
    for col in date_columns:
        if col not in df_reset.columns:
            continue
        # An integer 'index' column is a row counter, not a date.
        if col == 'index' and pd.api.types.is_numeric_dtype(df_reset[col]):
            continue
        try:
            formatted = pd.to_datetime(df_reset[col]).dt.strftime('%Y-%m-%d')
        except Exception as exc:
            # Best effort: an unparseable candidate is left untouched and
            # the next candidate is tried.
            logging.getLogger(__name__).debug(
                "Column %r could not be parsed as dates: %s", col, exc
            )
            continue
        df_reset[col] = formatted
        if col != 'date':
            df_reset = df_reset.rename(columns={col: 'date'})
        break

    return df_reset.to_dict(orient='records')