Files
etf/datasource/models.py
aszerW 72df18a28b feat(models): 添加 Pydantic 数据模型(Phase 1)
- 定义请求模型:OHLCVRequest, AssetTypeRequest
- 定义响应模型:OHLCVResponse, AssetTypeResponse, ErrorResponse
- 定义枚举类型:AssetTypeEnum, AdjTypeEnum, TimeframeEnum
- 提供类型安全的 API 响应验证
- 支持 IDE 自动补全和类型检查
- 为 Phase 2 (Flask-Pydantic 集成) 做准备

测试通过:
 请求参数自动验证(日期格式、adj 值)
 响应数据验证(美股 META、ETF 513100.SH、BTC)
 序列化/反序列化正常
 类型安全检查(缺失字段、类型错误)
2026-05-24 00:42:22 +08:00

360 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Pydantic 数据模型
==================
定义 Flask API 的请求和响应数据结构
用途:
- 类型安全的 API 响应结构
- 调用方数据验证
- IDE 自动补全支持
- 为 Phase 2 (Flask-Pydantic 集成) 做准备
使用示例:
from datasource.models import OHLCVResponse
# 验证 API 响应
response = requests.get(url).json()
validated = OHLCVResponse.model_validate(response)
# 访问字段(有类型提示)
print(validated.code)
print(validated.latest_premium)
"""
from enum import Enum
from typing import Optional, List, Dict, Any

from pydantic import BaseModel, ConfigDict, Field, field_validator
# ============================================================
# 枚举类型
# ============================================================
class AssetTypeEnum(str, Enum):
    """Asset types accepted by the API.

    Members also subclass ``str``, so each one compares equal to its raw
    string value (``AssetTypeEnum.CRYPTO == "crypto"``).
    """

    CHINA_INDEX = "china_index"
    CHINA_ETF = "china_etf"
    CHINA_STOCK = "china_stock"
    US_INDEX = "us_index"
    US_STOCK = "us_stock"
    HK_INDEX = "hk_index"
    HK_STOCK = "hk_stock"
    FUTURES = "futures"
    CRYPTO = "crypto"
class AdjTypeEnum(str, Enum):
    """Price-adjustment type.

    ``raw`` is the unadjusted series. ``qfq`` / ``hfq`` presumably stand for
    forward- / backward-adjusted prices (the usual meaning of these
    abbreviations) -- confirm against the data provider.
    """

    RAW = "raw"
    QFQ = "qfq"
    HFQ = "hfq"
class TimeframeEnum(str, Enum):
    """K-line (candlestick) interval; the original note ties it to crypto."""

    D1 = "1d"
    H1 = "1h"
    H4 = "4h"
    M15 = "15m"
    M1 = "1m"
# ============================================================
# Request models (used in Phase 2)
# ============================================================
class OHLCVRequest(BaseModel):
    """Request parameters for fetching OHLCV data.

    Example:
        request = OHLCVRequest(
            code="META",
            start="2024-01-01",
            end="2024-03-31",
            adj=AdjTypeEnum.RAW,
        )
    """

    # NOTE(review): several description strings below look like they lost
    # punctuation (e.g. an unclosed parenthesis); they are kept verbatim
    # because they are runtime text -- confirm against the canonical file.

    # Instrument code; required, 1-50 characters.
    code: str = Field(
        ...,
        min_length=1,
        max_length=50,
        description="标的代码(如 META, 513100.SH, BTC"
    )
    # Optional start date; the pattern enforces the YYYY-MM-DD digit layout.
    start: Optional[str] = Field(
        None,
        pattern=r'^\d{4}-\d{2}-\d{2}$',
        description="开始日期YYYY-MM-DD默认90天前"
    )
    # Optional end date; same layout constraint as ``start``.
    end: Optional[str] = Field(
        None,
        pattern=r'^\d{4}-\d{2}-\d{2}$',
        description="结束日期YYYY-MM-DD默认今天"
    )
    # Optional explicit asset type.
    asset_type: Optional[AssetTypeEnum] = Field(
        None,
        description="资产类型(可选,覆盖自动检测)"
    )
    # Price-adjustment mode; defaults to the raw series.
    adj: AdjTypeEnum = Field(
        AdjTypeEnum.RAW,
        description="复权类型raw/qfq/hfq"
    )
    # K-line interval; defaults to daily.
    timeframe: Optional[TimeframeEnum] = Field(
        TimeframeEnum.D1,
        description="K线周期仅加密货币需要"
    )
    # Cache-bypass flag; defaults to False.
    nocache: bool = Field(
        False,
        description="是否跳过缓存"
    )

    @field_validator('start', 'end')
    @classmethod
    def validate_date_format(cls, v: Optional[str]) -> Optional[str]:
        """Reject strings that are not real calendar dates.

        The field pattern only checks the digit layout; parsing with
        ``strptime`` additionally rejects impossible dates such as
        ``2024-02-30``.

        Args:
            v: The candidate date string, or ``None`` when omitted.

        Returns:
            ``v`` unchanged when it is ``None`` or a valid date.

        Raises:
            ValueError: If ``v`` cannot be parsed as ``%Y-%m-%d``.
        """
        if v is None:
            return v

        from datetime import datetime

        try:
            datetime.strptime(v, '%Y-%m-%d')
        except ValueError as exc:
            # Chain the parser error so the root cause stays visible.
            raise ValueError(f'Invalid date format: {v}. Use YYYY-MM-DD') from exc
        return v
class AssetTypeRequest(BaseModel):
    """Request parameters for asset-type detection."""

    # Instrument code; required, 1-50 characters.
    code: str = Field(
        ...,
        min_length=1,
        max_length=50,
        description="标的代码"
    )
# ============================================================
# 响应模型
# ============================================================
class DataRange(BaseModel):
    """Date range; both bounds are optional and default to ``None``."""

    start: Optional[str] = None
    end: Optional[str] = None
class RequestedRange(BaseModel):
    """Date range requested by the user; both bounds are required."""

    start: str
    end: str
class AvailableRange(BaseModel):
    """Range of available data; both bounds are optional."""

    start: Optional[str] = None
    end: Optional[str] = None
class TypeOverride(BaseModel):
    """Type-override information; all three fields are required strings."""

    # Judging by the names: the auto-detected type, the caller-specified
    # type, and a hint message -- confirm against the API that fills them.
    detected: str
    specified: str
    hint: str
class OHLCVRecord(BaseModel):
    """A single OHLCV data record.

    Only ``date`` is declared. Per the original notes, the remaining fields
    vary by asset type:
    - stocks / indices / ETFs: date, open, high, low, close, volume
    - crypto: date, open, high, low, close, volume
    - others may include: amount, change, pct_chg, etc.
    """

    # Pydantic v2 configuration; replaces the deprecated nested ``Config``
    # class. Undeclared fields are kept rather than rejected, because the
    # field set differs between asset types.
    model_config = ConfigDict(extra="allow")

    date: str
class NavData(BaseModel):
    """ETF net-asset-value (NAV) data."""

    # NAV rows as free-form dicts.
    data: List[Dict[str, Any]]
    # Presumably the number of rows in ``data`` -- not enforced here.
    count: int
class PremiumStats(BaseModel):
    """Summary statistics of the premium rate; every field is required."""

    mean: float
    std: float
    min: float
    max: float
    median: float
class PremiumSeriesItem(BaseModel):
    """One entry of the premium-rate series: a date and its premium."""

    date: str
    premium: float
class OHLCVResponse(BaseModel):
    """Response of the OHLCV data endpoint.

    Fields:
    - code: instrument code
    - asset_type: asset type
    - adj: price-adjustment type
    - count: number of records
    - data: list of OHLCV records
    - cached: whether the cache was hit
    - nav: ETF NAV data (china_etf only)
    - latest_premium: latest premium rate (china_etf only)
    - premium_series: premium-rate series (china_etf only)
    - premium_date: date of the latest premium rate (china_etf only)
    - premium_stats: premium-rate statistics (china_etf only)
    - attrs: full metadata
    - info: instrument information (stocks / indices)
    - columns: data column names
    - date_range: date range of the data
    - requested_range: range requested by the user
    - available_range: range of available data
    - cache_strategy: description of the cache strategy
    - timeframe: K-line interval (crypto)
    - type_override: type-override information

    Example:
        response = requests.get(url).json()
        validated = OHLCVResponse.model_validate(response)

        # Typed attribute access
        print(validated.code)
        print(validated.latest_premium)
        print(validated.nav.count if validated.nav else 0)
    """

    # Pydantic v2 configuration; replaces the deprecated nested ``Config``
    # class. Unknown keys are kept for backward compatibility.
    model_config = ConfigDict(extra="allow")

    # NOTE(review): a few description strings look like they lost a closing
    # parenthesis; kept verbatim because they are runtime text.

    # Required core fields.
    code: str = Field(..., description="标的代码")
    asset_type: str = Field(..., description="资产类型")
    adj: str = Field(..., description="复权类型")
    count: int = Field(..., description="数据条数")
    data: List[Dict[str, Any]] = Field(..., description="OHLCV 数据列表")
    cached: Optional[bool] = Field(None, description="是否命中缓存")

    # ETF-related fields (china_etf only).
    nav: Optional[NavData] = Field(None, description="ETF 净值数据")
    latest_premium: Optional[float] = Field(None, description="最新溢价率(%")
    premium_series: Optional[List[PremiumSeriesItem]] = Field(None, description="溢价率序列")
    premium_date: Optional[str] = Field(None, description="最新溢价率日期")
    premium_stats: Optional[PremiumStats] = Field(None, description="溢价率统计")

    # Metadata.
    attrs: Optional[Dict[str, Any]] = Field(None, description="完整元数据")
    info: Optional[Dict[str, Any]] = Field(None, description="标的信息")

    # Auxiliary information.
    columns: Optional[List[str]] = Field(None, description="数据列名")
    date_range: Optional[DataRange] = Field(None, description="数据日期范围")
    requested_range: Optional[RequestedRange] = Field(None, description="用户请求范围")
    available_range: Optional[AvailableRange] = Field(None, description="可用数据范围")
    cache_strategy: Optional[str] = Field(None, description="缓存策略说明")
    timeframe: Optional[str] = Field(None, description="K线周期加密货币")
    type_override: Optional[TypeOverride] = Field(None, description="类型覆盖信息")
class AssetTypeResponse(BaseModel):
    """Response of the asset-type detection endpoint; all fields required."""

    code: str = Field(..., description="标的代码")
    asset_type: str = Field(..., description="资产类型")
    description: str = Field(..., description="资产类型描述")
class CacheStats(BaseModel):
    """Cache statistics; every field is required."""

    # LRU cache statistics as a free-form dict.
    lru_cache: Dict[str, Any] = Field(..., description="LRU 缓存统计")
    # Size of the TTL cache.
    ttl_cache_size: int = Field(..., description="TTL 缓存大小")
    # TTL in seconds.
    ttl_seconds: int = Field(..., description="TTL 秒数")
    # Default starting date of the data.
    default_start_date: str = Field(..., description="默认数据起点")
    # Cache strategy.
    cache_strategy: str = Field(..., description="缓存策略")
class ErrorResponse(BaseModel):
    """Error response; only ``error`` is required."""

    # Pydantic v2 configuration; replaces the deprecated nested ``Config``
    # class. Unknown keys are kept rather than rejected.
    model_config = ConfigDict(extra="allow")

    # Error message.
    error: str = Field(..., description="错误信息")

    # Optional context about the failing request.
    code: Optional[str] = Field(None, description="标的代码")
    asset_type: Optional[str] = Field(None, description="资产类型")
    adj: Optional[str] = Field(None, description="复权类型")
    detected_type: Optional[str] = Field(None, description="自动检测的类型")

    # Optional guidance for the caller.
    example: Optional[str] = Field(None, description="示例请求")
    hint: Optional[str] = Field(None, description="提示信息")
    valid_adj: Optional[List[str]] = Field(None, description="有效的 adj 值")
    valid_types: Optional[List[str]] = Field(None, description="有效的资产类型")
    valid_timeframes: Optional[List[str]] = Field(None, description="有效的 timeframe 值")
# ============================================================
# 模型工具函数
# ============================================================
def validate_ohlcv_response(data: Dict[str, Any]) -> OHLCVResponse:
    """Validate an OHLCV API response.

    Thin wrapper around ``OHLCVResponse.model_validate``.

    Args:
        data: JSON data returned by the API.

    Returns:
        The validated response object.

    Raises:
        ValidationError: If ``data`` does not conform to the model.

    Example:
        response = requests.get(url).json()
        validated = validate_ohlcv_response(response)

        # Safe field access
        if validated.nav:
            print(validated.nav.count)
    """
    return OHLCVResponse.model_validate(data)
def dataframe_to_records(df) -> List[Dict[str, Any]]:
    """Convert a DataFrame into a list of record dicts with a ``date`` key.

    The index is moved into a column via ``reset_index()``. The first
    candidate column (``date``, ``Date``, ``index``, ``trade_date``,
    ``datetime``) that parses as dates is formatted as ``YYYY-MM-DD`` and,
    if needed, renamed to ``date``. Conversion is best-effort: when no
    candidate can be parsed, the frame is returned unconverted.

    NOTE(review): the ``%Y-%m-%d`` format discards the time of day, so
    intraday rows on the same day get identical ``date`` values -- confirm
    this is intended for sub-daily timeframes.

    Args:
        df: pandas DataFrame, or ``None``.

    Returns:
        A list of dicts, one per row (usable as ``OHLCVResponse.data``).
        Empty when ``df`` is ``None`` or has no rows.
    """
    if df is None or len(df) == 0:
        return []

    # Kept as a function-local import, like the original, but hoisted out
    # of the loop and the try-block so an import failure is not swallowed.
    import pandas as pd

    df_reset = df.reset_index()

    for col in ('date', 'Date', 'index', 'trade_date', 'datetime'):
        if col not in df_reset.columns:
            continue

        # A numeric column is not date-like. In particular, the 'index'
        # column produced by resetting a default RangeIndex is a row
        # counter; to_datetime would read it as epoch nanoseconds and
        # yield '1970-01-01' for every row, hiding the real date column.
        if pd.api.types.is_numeric_dtype(df_reset[col]):
            continue

        try:
            formatted = pd.to_datetime(df_reset[col]).dt.strftime('%Y-%m-%d')
        except (ValueError, TypeError, OverflowError, AttributeError):
            # Not parseable as dates; try the next candidate.
            continue

        df_reset[col] = formatted
        # Skip the rename if a 'date' column already exists, to avoid
        # producing duplicate column names.
        if col != 'date' and 'date' not in df_reset.columns:
            df_reset = df_reset.rename(columns={col: 'date'})
        break

    return df_reset.to_dict(orient='records')