1. models.py: - 添加 dataframe_to_ohlcv_response() 转换函数 - 支持 DataFrame → OHLCVResponse 自动转换 - 自动处理 nav、premium、attrs 等业务数据 2. flask_server.py: - 使用 Pydantic 模型构建响应(替代手动 Dict) - 错误响应使用 ErrorResponse 模型 - 代码减少 20+ 行,类型安全提升 3. flask_api_source.py: - 使用 validate_ohlcv_response() 验证 API 响应 - 类型安全访问 nav、premium、info 等字段 - ETF 数据解析更可靠 测试通过: ✅ DataFrame → Pydantic 转换正常 ✅ ETF 净值和溢价率正确处理 ✅ 线上 API 响应验证成功 ✅ FlaskAPIDataSource 集成正常
537 lines
17 KiB
Python
537 lines
17 KiB
Python
"""
|
||
Pydantic data models
====================

Defines the request and response data structures of the Flask API.

Purpose:
- Type-safe API response structures
- Data validation on the caller side
- IDE auto-completion support
- Preparation for Phase 2 (Flask-Pydantic integration)

Example:
    from datasource.models import OHLCVResponse

    # Validate an API response
    response = requests.get(url).json()
    validated = OHLCVResponse.model_validate(response)

    # Access fields (with type hints)
    print(validated.code)
    print(validated.latest_premium)
|
||
"""
|
||
|
||
from enum import Enum
from typing import Optional, List, Dict, Any

from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||
|
||
|
||
# ============================================================
|
||
# 枚举类型
|
||
# ============================================================
|
||
|
||
class AssetTypeEnum(str, Enum):
    """Asset type enumeration."""
    # Members mix in ``str``, so each one compares equal to its plain string value.
    CHINA_INDEX = "china_index"
    CHINA_ETF = "china_etf"
    CHINA_STOCK = "china_stock"
    US_INDEX = "us_index"
    US_STOCK = "us_stock"
    HK_INDEX = "hk_index"
    HK_STOCK = "hk_stock"
    FUTURES = "futures"
    CRYPTO = "crypto"
|
||
|
||
|
||
class AdjTypeEnum(str, Enum):
    """Price adjustment type enumeration."""
    RAW = "raw"  # unadjusted prices
    QFQ = "qfq"  # presumably forward-adjusted (qian fuquan) -- confirm with data source
    HFQ = "hfq"  # presumably backward-adjusted (hou fuquan) -- confirm with data source
|
||
|
||
|
||
class TimeframeEnum(str, Enum):
    """Candlestick (K-line) interval enumeration (crypto)."""
    D1 = "1d"
    H1 = "1h"
    H4 = "4h"
    M15 = "15m"
    M1 = "1m"
|
||
|
||
|
||
# ============================================================
|
||
# 请求模型(Phase 2 使用)
|
||
# ============================================================
|
||
|
||
class OHLCVRequest(BaseModel):
    """
    Request parameters for fetching OHLCV data.

    Example:
        request = OHLCVRequest(
            code="META",
            start="2024-01-01",
            end="2024-03-31",
            adj=AdjTypeEnum.RAW
        )
    """

    # Required instrument code, 1 to 50 characters long.
    code: str = Field(
        min_length=1,
        max_length=50,
        description="标的代码(如 META, 513100.SH, BTC)",
    )
    # Optional date bounds; the pattern enforces the YYYY-MM-DD shape and the
    # validator below rejects impossible calendar dates.
    start: Optional[str] = Field(
        default=None,
        pattern=r'^\d{4}-\d{2}-\d{2}$',
        description="开始日期(YYYY-MM-DD),默认90天前",
    )
    end: Optional[str] = Field(
        default=None,
        pattern=r'^\d{4}-\d{2}-\d{2}$',
        description="结束日期(YYYY-MM-DD),默认今天",
    )
    # Optional explicit asset type.
    asset_type: Optional[AssetTypeEnum] = Field(
        default=None,
        description="资产类型(可选,覆盖自动检测)",
    )
    adj: AdjTypeEnum = Field(
        default=AdjTypeEnum.RAW,
        description="复权类型(raw/qfq/hfq)",
    )
    timeframe: Optional[TimeframeEnum] = Field(
        default=TimeframeEnum.D1,
        description="K线周期(仅加密货币需要)",
    )
    nocache: bool = Field(
        default=False,
        description="是否跳过缓存",
    )

    @field_validator('start', 'end')
    @classmethod
    def validate_date_format(cls, v: Optional[str]) -> Optional[str]:
        """Check that a supplied date string is a real calendar date.

        ``None`` passes through unchanged. Any other value must parse with
        ``%Y-%m-%d``; otherwise ``ValueError`` is raised.
        """
        from datetime import datetime

        if v is not None:
            try:
                datetime.strptime(v, '%Y-%m-%d')
            except ValueError:
                raise ValueError(f'Invalid date format: {v}. Use YYYY-MM-DD')

        return v
|
||
|
||
|
||
class AssetTypeRequest(BaseModel):
    """Request model for asset type detection."""

    # Required instrument code, 1 to 50 characters long.
    code: str = Field(min_length=1, max_length=50, description="标的代码")
|
||
|
||
|
||
# ============================================================
|
||
# 响应模型
|
||
# ============================================================
|
||
|
||
class DataRange(BaseModel):
    """Date range; both bounds are optional strings defaulting to None."""
    start: Optional[str] = None  # range start (date format not enforced here)
    end: Optional[str] = None  # range end (date format not enforced here)
|
||
|
||
|
||
class RequestedRange(BaseModel):
    """Date range requested by the user; both bounds are required strings."""
    start: str  # requested start
    end: str  # requested end
|
||
|
||
|
||
class AvailableRange(BaseModel):
    """Range of available data; both bounds are optional strings defaulting to None."""
    start: Optional[str] = None  # first available date, if known
    end: Optional[str] = None  # last available date, if known
|
||
|
||
|
||
class TypeOverride(BaseModel):
    """Type override information; all three fields are required strings."""
    detected: str  # presumably the auto-detected asset type -- confirm against the producer
    specified: str  # presumably the asset type supplied by the caller -- confirm
    hint: str  # accompanying hint text
|
||
|
||
|
||
class OHLCVRecord(BaseModel):
    """
    A single OHLCV data record.

    The set of fields varies with the asset type:
    - stocks / indices / ETFs: date, open, high, low, close, volume
    - crypto: date, open, high, low, close, volume
    - others may additionally carry amount, change, pct_chg, etc.

    Only ``date`` is declared; every other field is accepted as an extra
    attribute because the columns differ between asset types.
    """

    # Pydantic v2 configuration; replaces the deprecated nested ``Config``
    # class. ``extra="allow"`` keeps undeclared fields on the model.
    model_config = ConfigDict(extra="allow")

    date: str
|
||
|
||
|
||
class NavData(BaseModel):
    """ETF net asset value (NAV) data."""
    data: List[Dict[str, Any]]  # NAV rows as free-form dicts
    count: int  # row count; dataframe_to_ohlcv_response sets it to len(data)
|
||
|
||
|
||
class PremiumStats(BaseModel):
    """Summary statistics of the premium rate series."""
    mean: float
    std: float
    min: float
    max: float
    median: float
|
||
|
||
|
||
class PremiumSeriesItem(BaseModel):
    """One entry of the premium rate series."""
    date: str  # date label of the observation
    premium: float  # premium rate for that date
|
||
|
||
|
||
class OHLCVResponse(BaseModel):
    """
    Response model for an OHLCV data request.

    Fields:
    - code: instrument code
    - asset_type: asset type
    - adj: price adjustment type
    - count: number of data rows
    - data: list of OHLCV rows
    - cached: whether the cache was hit
    - nav: ETF NAV data (china_etf only)
    - latest_premium: latest premium rate (china_etf only)
    - premium_series: premium rate series (china_etf only)
    - premium_date: date of the latest premium rate (china_etf only)
    - premium_stats: premium rate statistics (china_etf only)
    - attrs: full metadata
    - info: instrument information (stocks / indices)
    - columns: data column names
    - date_range: date range of the data
    - requested_range: range requested by the user
    - available_range: range of available data
    - cache_strategy: cache strategy description
    - timeframe: candlestick interval (crypto)
    - type_override: type override information

    Example:
        response = requests.get(url).json()
        validated = OHLCVResponse.model_validate(response)

        # Typed access
        print(validated.code)
        print(validated.latest_premium)
        print(validated.nav.count if validated.nav else 0)
    """

    # Pydantic v2 configuration; replaces the deprecated nested ``Config``
    # class. Extra fields are allowed to stay backward compatible.
    model_config = ConfigDict(extra="allow")

    code: str = Field(..., description="标的代码")
    asset_type: str = Field(..., description="资产类型")
    adj: str = Field(..., description="复权类型")
    count: int = Field(..., description="数据条数")
    data: List[Dict[str, Any]] = Field(..., description="OHLCV 数据列表")
    cached: Optional[bool] = Field(None, description="是否命中缓存")

    # ETF-specific fields (china_etf only)
    nav: Optional[NavData] = Field(None, description="ETF 净值数据")
    latest_premium: Optional[float] = Field(None, description="最新溢价率(%)")
    premium_series: Optional[List[PremiumSeriesItem]] = Field(None, description="溢价率序列")
    premium_date: Optional[str] = Field(None, description="最新溢价率日期")
    premium_stats: Optional[PremiumStats] = Field(None, description="溢价率统计")

    # Metadata
    attrs: Optional[Dict[str, Any]] = Field(None, description="完整元数据")
    info: Optional[Dict[str, Any]] = Field(None, description="标的信息")

    # Auxiliary information
    columns: Optional[List[str]] = Field(None, description="数据列名")
    date_range: Optional[DataRange] = Field(None, description="数据日期范围")
    requested_range: Optional[RequestedRange] = Field(None, description="用户请求范围")
    available_range: Optional[AvailableRange] = Field(None, description="可用数据范围")
    cache_strategy: Optional[str] = Field(None, description="缓存策略说明")
    timeframe: Optional[str] = Field(None, description="K线周期(加密货币)")
    type_override: Optional[TypeOverride] = Field(None, description="类型覆盖信息")
|
||
|
||
|
||
class AssetTypeResponse(BaseModel):
    """Response model for asset type detection; every field is required."""

    code: str = Field(description="标的代码")
    asset_type: str = Field(description="资产类型")
    description: str = Field(description="资产类型描述")
|
||
|
||
|
||
class CacheStats(BaseModel):
    """Cache statistics; every field is required."""

    lru_cache: Dict[str, Any] = Field(description="LRU 缓存统计")
    ttl_cache_size: int = Field(description="TTL 缓存大小")
    ttl_seconds: int = Field(description="TTL 秒数")
    default_start_date: str = Field(description="默认数据起点")
    cache_strategy: str = Field(description="缓存策略")
|
||
|
||
|
||
class ErrorResponse(BaseModel):
    """
    Error response model.

    Only ``error`` is required; the remaining fields are optional context
    and default to ``None``.
    """

    # Pydantic v2 configuration; replaces the deprecated nested ``Config``
    # class. Extra fields are allowed.
    model_config = ConfigDict(extra="allow")

    error: str = Field(..., description="错误信息")
    code: Optional[str] = Field(None, description="标的代码")
    asset_type: Optional[str] = Field(None, description="资产类型")
    adj: Optional[str] = Field(None, description="复权类型")
    detected_type: Optional[str] = Field(None, description="自动检测的类型")
    example: Optional[str] = Field(None, description="示例请求")
    hint: Optional[str] = Field(None, description="提示信息")
    valid_adj: Optional[List[str]] = Field(None, description="有效的 adj 值")
    valid_types: Optional[List[str]] = Field(None, description="有效的资产类型")
    valid_timeframes: Optional[List[str]] = Field(None, description="有效的 timeframe 值")
|
||
|
||
|
||
# ============================================================
|
||
# 模型工具函数
|
||
# ============================================================
|
||
|
||
def validate_ohlcv_response(data: Dict[str, Any]) -> OHLCVResponse:
    """
    Validate a raw OHLCV API payload against ``OHLCVResponse``.

    Args:
        data: JSON data returned by the API.

    Returns:
        The validated response object.

    Raises:
        ValidationError: the data does not match the model definition.

    Example:
        response = requests.get(url).json()
        validated = validate_ohlcv_response(response)

        # Safe field access
        if validated.nav:
            print(f"NAV rows: {validated.nav.count}")
    """
    validated_response = OHLCVResponse.model_validate(data)
    return validated_response
|
||
|
||
|
||
def dataframe_to_records(df, date_format: str = '%Y-%m-%d') -> List[Dict[str, Any]]:
    """
    Convert a DataFrame into a list of OHLCVRecord-compatible dicts.

    The index is moved into the columns, the first date-like column among
    ``date``, ``Date``, ``index``, ``trade_date`` and ``datetime`` is formatted
    with ``date_format`` and renamed to ``date``, and the rows are returned as
    dicts.

    A default, unnamed ``RangeIndex`` is dropped instead of being moved into
    the columns: it carries no data, would add a spurious integer ``index``
    field to every record, and (when no real date column exists) would be
    parsed as epoch nanoseconds and emitted as "1970-01-01".

    Args:
        df: pandas DataFrame, or None.
        date_format: ``strftime`` format applied to the date column.

    Returns:
        List of row dicts (usable directly as ``OHLCVResponse.data``);
        an empty list when ``df`` is None or has no rows.
    """
    if df is None or len(df) == 0:
        return []

    # Local import, matching the rest of this module.
    import pandas as pd

    has_default_index = isinstance(df.index, pd.RangeIndex) and df.index.name is None
    df_reset = df.reset_index(drop=has_default_index)

    # Candidate date columns, checked in priority order.
    date_columns = ['date', 'Date', 'index', 'trade_date', 'datetime']
    for col in date_columns:
        if col not in df_reset.columns:
            continue
        try:
            df_reset[col] = pd.to_datetime(df_reset[col]).dt.strftime(date_format)
        except Exception:
            # Deliberate best effort: this candidate is not date-like, so
            # leave it untouched and try the next candidate.
            continue
        if col != 'date':
            df_reset = df_reset.rename(columns={col: 'date'})
        break

    return df_reset.to_dict(orient='records')
|
||
|
||
|
||
# ============================================================
|
||
# DataFrame → Pydantic Model 转换函数
|
||
# ============================================================
|
||
|
||
def _format_date_label(label: Any, date_format: str) -> str:
    """Format an index label with ``strftime`` when available, else ``str()``."""
    if hasattr(label, 'strftime'):
        return label.strftime(date_format)
    return str(label)


def _serialize_attrs(attrs: Dict[str, Any], date_format: str) -> Dict[str, Any]:
    """Drop ``_cache_``-prefixed keys and convert DataFrame/Series values to plain dicts."""
    import pandas as pd

    serialized: Dict[str, Any] = {}
    for key, value in attrs.items():
        # Keys starting with "_cache_" are internal cache metadata.
        if key.startswith('_cache_'):
            continue
        if isinstance(value, pd.DataFrame):
            serialized[key] = {
                'data': dataframe_to_records(value, date_format),
                'count': len(value),
            }
        elif isinstance(value, pd.Series):
            series_copy = value.copy()
            # Only datetime-like indexes expose strftime; any other index is
            # kept as-is instead of raising AttributeError.
            if hasattr(series_copy.index, 'strftime'):
                series_copy.index = series_copy.index.strftime(date_format)
            serialized[key] = {
                'type': 'series',
                'data': series_copy.to_dict(),
                'name': value.name,
            }
        else:
            serialized[key] = value
    return serialized


def dataframe_to_ohlcv_response(
    df: Any,  # pd.DataFrame
    code: str,
    asset_type: str,
    adj: str = 'raw',
    cached: bool = False,
    nav_df: Optional[Any] = None,  # Optional[pd.DataFrame]
    premium_series: Optional[Any] = None,  # Optional[pd.Series]
    info: Optional[Dict[str, Any]] = None,
    attrs: Optional[Dict[str, Any]] = None,
    date_format: Optional[str] = None,
    **kwargs
) -> 'OHLCVResponse':
    """
    Convert a DataFrame into an ``OHLCVResponse`` model.

    Purpose:
    - Flask API: a single, uniform response structure
    - Local calls: a type-safe response object

    Args:
        df: main OHLCV DataFrame (may be None, giving an empty ``data`` list).
        code: instrument code.
        asset_type: asset type.
        adj: price adjustment type.
        cached: whether the cache was hit.
        nav_df: ETF NAV DataFrame (optional; ignored unless it is a DataFrame).
        premium_series: premium rate Series (optional; ignored unless it is a
            non-empty Series).
        info: instrument information (optional). When omitted, ``df.attrs['info']``
            is used if present.
        attrs: full metadata (optional). When omitted (None), ``df.attrs`` is
            used. Keys starting with ``_cache_`` are filtered out.
        date_format: date format (optional). Defaults to
            ``'%Y-%m-%d %H:%M:%S'`` for ``asset_type == 'crypto'`` and
            ``'%Y-%m-%d'`` otherwise.
        **kwargs: additional response fields (columns, date_range, timeframe, ...).
            Applied last, so they override the fields built here.

    Returns:
        A validated ``OHLCVResponse`` instance.

    Example:
        # Flask API
        df = fetcher.fetch("META", start, end)
        response = dataframe_to_ohlcv_response(
            df, code="META", asset_type="us_stock", adj="raw"
        )
        return response.model_dump(mode='json')

        # Local call
        df = fetcher.fetch("513100.SH", start, end)
        nav_df = df.attrs.get('nav')
        premium = df.attrs.get('premium')

        response = dataframe_to_ohlcv_response(
            df,
            code="513100.SH",
            asset_type="china_etf",
            nav_df=nav_df,
            premium_series=premium
        )
        print(response.nav.count)  # IDE auto-completion works
    """
    import pandas as pd

    # Pick the date format automatically when the caller did not set one.
    if date_format is None:
        date_format = '%Y-%m-%d %H:%M:%S' if asset_type == 'crypto' else '%Y-%m-%d'

    # Main data rows.
    data = dataframe_to_records(df, date_format) if df is not None else []

    response_data: Dict[str, Any] = {
        "code": code,
        "asset_type": asset_type,
        "adj": adj,
        "count": len(data),
        "data": data,
        "cached": cached,
    }

    # Metadata attached to the frame itself; empty when df is None or has none.
    df_attrs = getattr(df, 'attrs', None) or {}

    # info: the explicit argument wins, otherwise fall back to df.attrs.
    if info is not None:
        response_data['info'] = info
    elif 'info' in df_attrs:
        response_data['info'] = df_attrs['info']

    # nav (if provided).
    if isinstance(nav_df, pd.DataFrame):
        nav_records = dataframe_to_records(nav_df, date_format)
        response_data['nav'] = {
            "data": nav_records,
            "count": len(nav_records),
        }

    # premium (if provided and non-empty).
    if isinstance(premium_series, pd.Series) and len(premium_series) > 0:
        # Latest premium is the last element of the series.
        response_data['latest_premium'] = round(float(premium_series.iloc[-1]), 6)
        response_data['premium_date'] = _format_date_label(
            premium_series.index[-1], date_format
        )

        response_data['premium_series'] = [
            {
                "date": _format_date_label(label, date_format),
                "premium": round(float(premium), 6),
            }
            for label, premium in premium_series.items()
        ]

        response_data['premium_stats'] = {
            "mean": round(float(premium_series.mean()), 6),
            "std": round(float(premium_series.std()), 6),
            "min": round(float(premium_series.min()), 6),
            "max": round(float(premium_series.max()), 6),
            "median": round(float(premium_series.median()), 6),
        }

    # attrs: an explicit argument (even an empty dict) takes precedence over
    # df.attrs; the key is only added when something remains after filtering.
    source_attrs = attrs if attrs is not None else df_attrs
    attrs_serializable = _serialize_attrs(source_attrs, date_format)
    if attrs_serializable:
        response_data['attrs'] = attrs_serializable

    # Remaining auxiliary fields supplied by the caller.
    response_data.update(kwargs)

    # Validate and return the model.
    return OHLCVResponse.model_validate(response_data)
|