Files
etf/datasource/models.py
aszerW 226a27361f feat(pydantic): 集成 Pydantic 模型到 Flask API 层
1. models.py:
   - 添加 dataframe_to_ohlcv_response() 转换函数
   - 支持 DataFrame → OHLCVResponse 自动转换
   - 自动处理 nav、premium、attrs 等业务数据

2. flask_server.py:
   - 使用 Pydantic 模型构建响应(替代手动 Dict)
   - 错误响应使用 ErrorResponse 模型
   - 代码减少 20+ 行,类型安全提升

3. flask_api_source.py:
   - 使用 validate_ohlcv_response() 验证 API 响应
   - 类型安全访问 nav、premium、info 等字段
   - ETF 数据解析更可靠

测试通过:
- DataFrame → Pydantic 转换正常
- ETF 净值和溢价率正确处理
- 线上 API 响应验证成功
- FlaskAPIDataSource 集成正常
2026-05-24 01:13:33 +08:00

537 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Pydantic 数据模型
==================
定义 Flask API 的请求和响应数据结构
用途:
- 类型安全的 API 响应结构
- 调用方数据验证
- IDE 自动补全支持
- 为 Phase 2 (Flask-Pydantic 集成) 做准备
使用示例:
from datasource.models import OHLCVResponse
# 验证 API 响应
response = requests.get(url).json()
validated = OHLCVResponse.model_validate(response)
# 访问字段(有类型提示)
print(validated.code)
print(validated.latest_premium)
"""
from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, ConfigDict, Field, field_validator
# ============================================================
# 枚举类型
# ============================================================
class AssetTypeEnum(str, Enum):
    """Asset classes known to the data layer.

    Members mix in ``str``, so each one compares equal to its plain string
    value, e.g. ``AssetTypeEnum.CHINA_ETF == "china_etf"``.
    """

    # Mainland China
    CHINA_INDEX = "china_index"
    CHINA_ETF = "china_etf"
    CHINA_STOCK = "china_stock"
    # United States
    US_INDEX = "us_index"
    US_STOCK = "us_stock"
    # Hong Kong
    HK_INDEX = "hk_index"
    HK_STOCK = "hk_stock"
    # Other asset classes
    FUTURES = "futures"
    CRYPTO = "crypto"
class AdjTypeEnum(str, Enum):
    """Price-adjustment (复权) modes.

    ``raw`` means unadjusted prices. ``qfq`` / ``hfq`` presumably follow the
    usual 前复权 / 后复权 (forward / backward adjusted) convention — confirm
    against the data source.
    """

    RAW = "raw"  # unadjusted
    QFQ = "qfq"
    HFQ = "hfq"
class TimeframeEnum(str, Enum):
    """Candlestick (K-line) intervals; per the original note, used for crypto data."""

    D1 = "1d"    # daily
    H1 = "1h"    # hourly
    H4 = "4h"    # 4-hour
    M15 = "15m"  # 15-minute
    M1 = "1m"    # 1-minute
# ============================================================
# 请求模型（Phase 2 使用）
# ============================================================
class OHLCVRequest(BaseModel):
    """
    Request payload for fetching OHLCV data.

    Example:
        request = OHLCVRequest(
            code="META",
            start="2024-01-01",
            end="2024-03-31",
            adj=AdjTypeEnum.RAW,
        )
    """

    # Ticker / symbol, 1-50 characters (e.g. META, 513100.SH, BTC).
    code: str = Field(
        default=...,
        min_length=1,
        max_length=50,
        description="标的代码(如 META, 513100.SH, BTC",
    )
    # Optional YYYY-MM-DD bounds; the regex checks the shape and the
    # validator below checks that the date actually exists.
    start: Optional[str] = Field(
        default=None,
        pattern=r'^\d{4}-\d{2}-\d{2}$',
        description="开始日期YYYY-MM-DD默认90天前",
    )
    end: Optional[str] = Field(
        default=None,
        pattern=r'^\d{4}-\d{2}-\d{2}$',
        description="结束日期YYYY-MM-DD默认今天",
    )
    # Explicit asset type; None leaves detection to the server.
    asset_type: Optional[AssetTypeEnum] = Field(
        default=None,
        description="资产类型(可选,覆盖自动检测)",
    )
    adj: AdjTypeEnum = Field(
        default=AdjTypeEnum.RAW,
        description="复权类型raw/qfq/hfq",
    )
    timeframe: Optional[TimeframeEnum] = Field(
        default=TimeframeEnum.D1,
        description="K线周期仅加密货币需要",
    )
    nocache: bool = Field(
        default=False,
        description="是否跳过缓存",
    )

    @field_validator('start', 'end')
    @classmethod
    def validate_date_format(cls, value: Optional[str]) -> Optional[str]:
        """Reject calendar-invalid dates such as 2024-02-30 that the regex alone accepts."""
        from datetime import datetime

        if value is not None:
            try:
                datetime.strptime(value, '%Y-%m-%d')
            except ValueError:
                raise ValueError(f'Invalid date format: {value}. Use YYYY-MM-DD')
        return value
class AssetTypeRequest(BaseModel):
    """Request payload for asset-type detection."""

    # Ticker / symbol to classify, 1-50 characters.
    code: str = Field(default=..., min_length=1, max_length=50, description="标的代码")
# ============================================================
# 响应模型
# ============================================================
class DataRange(BaseModel):
    """Date span of the returned data; either bound may be missing."""

    start: Optional[str] = Field(default=None)  # first date, if known
    end: Optional[str] = Field(default=None)    # last date, if known
class RequestedRange(BaseModel):
    """Date span the caller asked for; both bounds are required."""

    start: str = Field(default=...)
    end: str = Field(default=...)
class AvailableRange(BaseModel):
    """Date span for which data is available; either bound may be missing."""

    start: Optional[str] = Field(default=None)  # earliest available date, if known
    end: Optional[str] = Field(default=None)    # latest available date, if known
class TypeOverride(BaseModel):
    """Information about an asset-type override (all fields required)."""

    detected: str = Field(default=...)   # presumably the auto-detected type
    specified: str = Field(default=...)  # presumably the caller-supplied type
    hint: str = Field(default=...)       # free-text explanation
class OHLCVRecord(BaseModel):
    """
    A single OHLCV row.

    Only ``date`` is declared. Every other column (typically open, high, low,
    close, volume, plus asset-specific extras such as amount, change or
    pct_chg) is accepted as an extra field, because the column set varies by
    asset type.
    """

    # Pydantic v2 configuration; the class-based ``Config`` used previously is
    # deprecated in v2 and emits a warning.
    model_config = ConfigDict(extra="allow")

    date: str
class NavData(BaseModel):
    """ETF net-asset-value (NAV) rows together with their count."""

    # Row dicts; dataframe_to_ohlcv_response builds them with dataframe_to_records.
    data: List[Dict[str, Any]] = Field(default=...)
    count: int = Field(default=...)
class PremiumStats(BaseModel):
    """Summary statistics of the premium-rate series."""

    mean: float = Field(default=...)
    std: float = Field(default=...)
    min: float = Field(default=...)
    max: float = Field(default=...)
    median: float = Field(default=...)
class PremiumSeriesItem(BaseModel):
    """One point of the premium-rate series: a date string and its premium value."""

    date: str = Field(default=...)
    premium: float = Field(default=...)
class OHLCVResponse(BaseModel):
    """
    Response payload for an OHLCV fetch.

    Fields:
    - code: ticker / symbol
    - asset_type: asset class
    - adj: price-adjustment mode
    - count: number of data rows
    - data: list of OHLCV row dicts
    - cached: whether the result came from cache
    - nav: ETF NAV data (china_etf only)
    - latest_premium: most recent premium rate, in percent (china_etf only)
    - premium_series: premium-rate series (china_etf only)
    - premium_date: date of the most recent premium (china_etf only)
    - premium_stats: premium-rate statistics (china_etf only)
    - attrs: full metadata
    - info: instrument information (stocks / indices)
    - columns: data column names
    - date_range: date span of the returned data
    - requested_range: date span requested by the caller
    - available_range: date span for which data is available
    - cache_strategy: description of the cache strategy
    - timeframe: candlestick interval (crypto)
    - type_override: asset-type override information

    Unknown keys are kept as extra fields for backward compatibility.

    Example:
        response = requests.get(url).json()
        validated = OHLCVResponse.model_validate(response)
        print(validated.code)
        print(validated.latest_premium)
        print(validated.nav.count if validated.nav else 0)
    """

    # Pydantic v2 configuration; the class-based ``Config`` used previously is
    # deprecated in v2 and emits a warning.
    model_config = ConfigDict(extra="allow")

    code: str = Field(..., description="标的代码")
    asset_type: str = Field(..., description="资产类型")
    adj: str = Field(..., description="复权类型")
    count: int = Field(..., description="数据条数")
    data: List[Dict[str, Any]] = Field(..., description="OHLCV 数据列表")
    cached: Optional[bool] = Field(None, description="是否命中缓存")

    # ETF-specific (china_etf only)
    nav: Optional[NavData] = Field(None, description="ETF 净值数据")
    latest_premium: Optional[float] = Field(None, description="最新溢价率(%")
    premium_series: Optional[List[PremiumSeriesItem]] = Field(None, description="溢价率序列")
    premium_date: Optional[str] = Field(None, description="最新溢价率日期")
    premium_stats: Optional[PremiumStats] = Field(None, description="溢价率统计")

    # Metadata
    attrs: Optional[Dict[str, Any]] = Field(None, description="完整元数据")
    info: Optional[Dict[str, Any]] = Field(None, description="标的信息")

    # Auxiliary information
    columns: Optional[List[str]] = Field(None, description="数据列名")
    date_range: Optional[DataRange] = Field(None, description="数据日期范围")
    requested_range: Optional[RequestedRange] = Field(None, description="用户请求范围")
    available_range: Optional[AvailableRange] = Field(None, description="可用数据范围")
    cache_strategy: Optional[str] = Field(None, description="缓存策略说明")
    timeframe: Optional[str] = Field(None, description="K线周期加密货币")
    type_override: Optional[TypeOverride] = Field(None, description="类型覆盖信息")
class AssetTypeResponse(BaseModel):
    """Result of asset-type detection for one symbol (all fields required)."""

    code: str = Field(default=..., description="标的代码")
    asset_type: str = Field(default=..., description="资产类型")
    description: str = Field(default=..., description="资产类型描述")
class CacheStats(BaseModel):
    """Cache statistics payload (all fields required)."""

    # Free-form stats of the LRU layer.
    lru_cache: Dict[str, Any] = Field(default=..., description="LRU 缓存统计")
    # TTL layer: current size and expiry in seconds.
    ttl_cache_size: int = Field(default=..., description="TTL 缓存大小")
    ttl_seconds: int = Field(default=..., description="TTL 秒数")
    default_start_date: str = Field(default=..., description="默认数据起点")
    cache_strategy: str = Field(default=..., description="缓存策略")
class ErrorResponse(BaseModel):
    """
    Error payload.

    Only ``error`` is required. The optional fields carry context for the
    failed request (symbol, asset type, adjustment mode), hints, and the lists
    of accepted values. Unknown keys are kept as extra fields.
    """

    # Pydantic v2 configuration; the class-based ``Config`` used previously is
    # deprecated in v2 and emits a warning.
    model_config = ConfigDict(extra="allow")

    error: str = Field(..., description="错误信息")
    code: Optional[str] = Field(None, description="标的代码")
    asset_type: Optional[str] = Field(None, description="资产类型")
    adj: Optional[str] = Field(None, description="复权类型")
    detected_type: Optional[str] = Field(None, description="自动检测的类型")
    example: Optional[str] = Field(None, description="示例请求")
    hint: Optional[str] = Field(None, description="提示信息")
    valid_adj: Optional[List[str]] = Field(None, description="有效的 adj 值")
    valid_types: Optional[List[str]] = Field(None, description="有效的资产类型")
    valid_timeframes: Optional[List[str]] = Field(None, description="有效的 timeframe 值")
# ============================================================
# 模型工具函数
# ============================================================
def validate_ohlcv_response(data: Dict[str, Any]) -> OHLCVResponse:
    """
    Parse a raw OHLCV API payload into an ``OHLCVResponse``.

    Args:
        data: decoded JSON body returned by the API.

    Returns:
        The validated response model.

    Raises:
        pydantic.ValidationError: if ``data`` does not match ``OHLCVResponse``.

    Example:
        response = requests.get(url).json()
        validated = validate_ohlcv_response(response)
        if validated.nav:
            print(f"净值数据条数: {validated.nav.count}")
    """
    validated = OHLCVResponse.model_validate(data)
    return validated
def dataframe_to_records(df, date_format: str = '%Y-%m-%d') -> List[Dict[str, Any]]:
    """
    Convert a DataFrame into a list of row dicts usable as ``OHLCVResponse.data``.

    The index is moved into the columns. Then the first candidate column among
    ``date``, ``Date``, ``index``, ``trade_date`` and ``datetime`` that
    ``pd.to_datetime`` can parse is formatted with *date_format* and renamed to
    ``date``. Candidates that fail to parse are skipped (best effort).

    The index is dropped instead of being moved into the columns when it
    carries no usable information:

    * An unnamed ``RangeIndex`` (plain row positions). Keeping it would add an
      ``index`` column of 0, 1, 2, ... that ``pd.to_datetime`` parses as epoch
      timestamps (``1970-01-01``). Because ``index`` is tried before
      ``trade_date`` / ``datetime``, those bogus dates would win over the real
      date column.
    * An index whose name duplicates an existing column (e.g. after
      ``set_index('date', drop=False)``). In that case ``reset_index`` would
      raise ``ValueError``.

    Args:
        df: pandas DataFrame; ``None`` or an empty frame yields ``[]``.
        date_format: ``strftime`` format applied to the detected date column.

    Returns:
        One dict per row, mapping column name to value.
    """
    if df is None or len(df) == 0:
        return []

    import pandas as pd

    index = df.index
    positional_index = isinstance(index, pd.RangeIndex) and index.name is None
    duplicate_index = index.name is not None and index.name in df.columns
    df_reset = df.reset_index(drop=positional_index or duplicate_index)

    for col in ('date', 'Date', 'index', 'trade_date', 'datetime'):
        if col not in df_reset.columns:
            continue
        try:
            formatted = pd.to_datetime(df_reset[col]).dt.strftime(date_format)
        except Exception:
            # Best effort: leave an unparseable candidate untouched and try the next one.
            continue
        df_reset[col] = formatted
        if col != 'date':
            df_reset = df_reset.rename(columns={col: 'date'})
        break

    return df_reset.to_dict(orient='records')
# ============================================================
# DataFrame → Pydantic Model 转换函数
# ============================================================
def _format_index_labels(index, date_format: str) -> List[str]:
    """Render index labels as strings, formatting datetime-like labels with *date_format*."""
    if hasattr(index, 'strftime'):
        # DatetimeIndex / PeriodIndex expose a vectorised strftime.
        return list(index.strftime(date_format))
    # Other index types (e.g. dates already stored as strings) have no
    # ``strftime``; stringify them instead of raising AttributeError.
    return [str(label) for label in index]


def _serialize_attrs(raw_attrs: Dict[Any, Any], date_format: str) -> Dict[Any, Any]:
    """Drop internal ``_cache_*`` keys and convert pandas values in *raw_attrs* to plain containers."""
    import pandas as pd

    serialized: Dict[Any, Any] = {}
    for key, value in raw_attrs.items():
        # Internal cache bookkeeping is not part of the public payload.
        if isinstance(key, str) and key.startswith('_cache_'):
            continue
        if isinstance(value, pd.DataFrame):
            serialized[key] = {
                'data': dataframe_to_records(value, date_format),
                'count': len(value),
            }
        elif isinstance(value, pd.Series):
            series_copy = value.copy()
            series_copy.index = _format_index_labels(value.index, date_format)
            serialized[key] = {
                'type': 'series',
                'data': series_copy.to_dict(),
                'name': value.name,
            }
        else:
            serialized[key] = value
    return serialized


def dataframe_to_ohlcv_response(
    df: Any,  # pd.DataFrame
    code: str,
    asset_type: str,
    adj: str = 'raw',
    cached: bool = False,
    nav_df: Optional[Any] = None,  # Optional[pd.DataFrame]
    premium_series: Optional[Any] = None,  # Optional[pd.Series]
    info: Optional[Dict[str, Any]] = None,
    attrs: Optional[Dict[str, Any]] = None,
    date_format: Optional[str] = None,
    **kwargs
) -> 'OHLCVResponse':
    """
    Convert a DataFrame (plus optional ETF extras) into an ``OHLCVResponse``.

    Intended for two uses: building a uniform response in the Flask API, and
    getting a typed response object for local callers.

    Args:
        df: main OHLCV DataFrame (may be None, giving an empty ``data`` list).
        code: ticker / symbol.
        asset_type: asset type string.
        adj: price-adjustment mode.
        cached: whether the data came from cache.
        nav_df: ETF NAV DataFrame (optional; ignored unless a DataFrame).
        premium_series: premium-rate Series (optional; ignored unless a Series).
            NaN entries are dropped before anything is derived from it.
        info: instrument information. Falls back to ``df.attrs['info']``.
        attrs: metadata dict. Falls back to ``df.attrs``. Keys starting with
            ``_cache_`` are dropped, and DataFrame/Series values are converted
            to plain dicts.
        date_format: strftime format. Defaults to ``'%Y-%m-%d %H:%M:%S'`` for
            ``asset_type == 'crypto'`` and ``'%Y-%m-%d'`` otherwise.
        **kwargs: extra response fields (columns, date_range, timeframe, ...).
            They are applied last, so they override computed keys.

    Returns:
        The validated ``OHLCVResponse``.

    Raises:
        pydantic.ValidationError: if the assembled payload does not match the model.

    Example:
        # Flask API
        df = fetcher.fetch("META", start, end)
        response = dataframe_to_ohlcv_response(
            df, code="META", asset_type="us_stock", adj="raw"
        )
        return response.model_dump(mode='json')

        # Local call
        df = fetcher.fetch("513100.SH", start, end)
        response = dataframe_to_ohlcv_response(
            df,
            code="513100.SH",
            asset_type="china_etf",
            nav_df=df.attrs.get('nav'),
            premium_series=df.attrs.get('premium'),
        )
        print(response.nav.count)
    """
    import pandas as pd

    # Crypto data may be intraday (see TimeframeEnum), so keep the time of day.
    if date_format is None:
        date_format = '%Y-%m-%d %H:%M:%S' if asset_type == 'crypto' else '%Y-%m-%d'

    data = dataframe_to_records(df, date_format) if df is not None else []
    response_data: Dict[str, Any] = {
        "code": code,
        "asset_type": asset_type,
        "adj": adj,
        "count": len(data),
        "data": data,
        "cached": cached,
    }

    # df.attrs is the fallback source for both ``info`` and ``attrs``.
    df_attrs = getattr(df, 'attrs', None) or {}

    # info: an explicit argument wins over df.attrs['info'].
    if info is not None:
        response_data['info'] = info
    elif 'info' in df_attrs:
        response_data['info'] = df_attrs['info']

    # nav: ETF net-asset-value rows.
    if isinstance(nav_df, pd.DataFrame):
        nav_records = dataframe_to_records(nav_df, date_format)
        response_data['nav'] = {"data": nav_records, "count": len(nav_records)}

    # premium: latest value, full series and summary statistics.
    if isinstance(premium_series, pd.Series):
        # NaN is not valid JSON and PremiumSeriesItem.premium is a float, so
        # missing premiums are dropped. pandas' mean/std/min/max/median skip
        # NaN anyway, so the statistics are unaffected.
        valid = premium_series.dropna()
        if len(valid) > 0:
            labels = _format_index_labels(valid.index, date_format)
            values = [round(float(p), 6) for p in valid]
            response_data['latest_premium'] = values[-1]
            response_data['premium_date'] = labels[-1]
            response_data['premium_series'] = [
                {"date": label, "premium": value}
                for label, value in zip(labels, values)
            ]
            response_data['premium_stats'] = {
                "mean": round(float(valid.mean()), 6),
                "std": round(float(valid.std()), 6),
                "min": round(float(valid.min()), 6),
                "max": round(float(valid.max()), 6),
                "median": round(float(valid.median()), 6),
            }

    # attrs: an explicit argument (even {}) wins over df.attrs.
    source_attrs = attrs if attrs is not None else df_attrs
    serialized_attrs = _serialize_attrs(source_attrs, date_format)
    if serialized_attrs:
        response_data['attrs'] = serialized_attrs

    # Caller-supplied auxiliary fields go last.
    response_data.update(kwargs)

    return OHLCVResponse.model_validate(response_data)