Files
etf/datasource/flask_api_source.py
aszerW 226a27361f feat(pydantic): 集成 Pydantic 模型到 Flask API 层
1. models.py:
   - 添加 dataframe_to_ohlcv_response() 转换函数
   - 支持 DataFrame → OHLCVResponse 自动转换
   - 自动处理 nav、premium、attrs 等业务数据

2. flask_server.py:
   - 使用 Pydantic 模型构建响应(替代手动 Dict)
   - 错误响应使用 ErrorResponse 模型
   - 代码减少 20+ 行,类型安全提升

3. flask_api_source.py:
   - 使用 validate_ohlcv_response() 验证 API 响应
   - 类型安全访问 nav、premium、info 等字段
   - ETF 数据解析更可靠

测试通过:
 DataFrame → Pydantic 转换正常
 ETF 净值和溢价率正确处理
 线上 API 响应验证成功
 FlaskAPIDataSource 集成正常
2026-05-24 01:13:33 +08:00

395 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Flask API 数据源
通过部署后的 Flask API 服务获取 OHLCV 数据
支持远程调用,无需本地 SSH 隧道
"""
import os
import json
import requests
import pandas as pd
from typing import Optional, Dict, List
from datetime import datetime
from pathlib import Path
from dotenv import load_dotenv
from .models import OHLCVResponse, validate_ohlcv_response
load_dotenv()
class FlaskAPIDataSource:
"""
Flask API 数据源
通过 HTTP API 获取数据,无需本地配置 SSH 隧道
适用于远程调用或生产环境
用法:
source = FlaskAPIDataSource(base_url="https://k3s.tokenpluse.xyz")
df = source.fetch("000300.SH", "2024-01-01", "2024-12-31")
"""
def __init__(
self,
base_url: str = None,
api_path: str = "/api/v1/ohlcv",
timeout: int = 120,
retries: int = 3
):
"""
初始化
Args:
base_url: API 服务基础地址,默认从环境变量读取
api_path: API 路径
timeout: 请求超时时间(秒)
retries: 重试次数
"""
self.base_url = base_url or os.getenv(
'FLASK_API_URL',
'https://k3s.tokenpluse.xyz'
)
self.api_path = api_path
self.timeout = timeout
self.retries = retries
# 确保 base_url 不以 / 结尾
self.base_url = self.base_url.rstrip('/')
def fetch(
self,
code: str,
start_date: str,
end_date: str,
adj: str = 'raw',
asset_type: str = None,
timeframe: str = '1d'
) -> Optional[pd.DataFrame]:
"""
获取单只标的 OHLCV 数据(支持 adj 参数)
Args:
code: 标的代码
start_date: 开始日期 YYYY-MM-DD
end_date: 结束日期 YYYY-MM-DD
adj: 复权类型 'raw'(原始) / 'qfq'(前复权) / 'hfq'(后复权),默认 'raw'
asset_type: 资产类型(可选,用于覆盖自动检测)
timeframe: K线周期加密货币需要
Returns:
DataFrame with columns: date, open, high, low, close, volume
adj='hfq' 时 A股 ETF 会额外返回 adj_factor, close_hfq
示例:
# 原始价格
df = source.fetch("000300.SH", "2020-01-01", "2024-12-31")
# A股股票后复权
df = source.fetch("000001.SZ", "2020-01-01", "2024-12-31", adj='hfq')
"""
# 构建请求 URL
url = f"{self.base_url}{self.api_path}"
# 构建请求参数(包含 adj
params = {
'code': code,
'start': start_date,
'end': end_date,
'adj': adj, # 添加 adj 参数
}
# 加密货币需要 timeframe 参数
if asset_type == 'crypto' or code.upper() in ['BTC', 'ETH']:
params['timeframe'] = timeframe
# 可选:强制指定 asset_type
if asset_type:
params['asset_type'] = asset_type
for attempt in range(self.retries):
try:
response = requests.get(
url,
params=params,
timeout=self.timeout
)
if response.status_code != 200:
if attempt < self.retries - 1:
continue
print(f"✗ API请求失败: {response.status_code} - {response.text[:100]}")
return None
# 尝试解析 JSON支持 zstd 响应)
try:
data = response.json()
except (json.JSONDecodeError, requests.exceptions.JSONDecodeError):
# 如果 response.json() 失败,手动解析
data = json.loads(response.text)
# 检查错误
if 'error' in data:
print(f"✗ API返回错误: {data['error']}")
return None
# ✅ 使用 Pydantic 模型验证响应(类型安全)
try:
validated = validate_ohlcv_response(data)
except Exception as e:
print(f"{code}: 响应数据验证失败 - {e}")
return None
# 检查数据是否为空
if not validated.data:
print(f"{code}: 无数据返回")
return None
# 转换为 DataFrame
df = pd.DataFrame(validated.data)
# 处理日期列
if 'date' in df.columns:
df['date'] = pd.to_datetime(df['date'])
df = df.set_index('date')
# 确保列名标准化(保留 code 列如果存在)
standard_cols = ['open', 'high', 'low', 'close', 'volume']
if 'code' in df.columns:
standard_cols = ['code'] + standard_cols
df = df[standard_cols]
# 使用 API 返回的实际数据范围(而非请求参数)
actual_start = validated.date_range.start if validated.date_range else start_date
actual_end = validated.date_range.end if validated.date_range else end_date
actual_count = validated.count
# 缓存 info 信息(如果有)
if validated.info:
df.attrs['info'] = validated.info
# ETF 数据自动附加净值和溢价率信息
if validated.asset_type == 'china_etf':
# 净值数据
if validated.nav and validated.nav.data:
nav_df = pd.DataFrame(validated.nav.data)
if 'date' in nav_df.columns:
nav_df['date'] = pd.to_datetime(nav_df['date'])
nav_df = nav_df.set_index('date')
df.attrs['nav'] = nav_df
# 溢价率序列
if validated.premium_series:
premium_dict = {item.date: item.premium for item in validated.premium_series}
df.attrs['premium_series'] = premium_dict
# 最新溢价率
if validated.latest_premium is not None:
df.attrs['latest_premium'] = validated.latest_premium
df.attrs['premium_date'] = validated.premium_date
# 溢价率统计
if validated.premium_stats:
df.attrs['premium_stats'] = validated.premium_stats.model_dump()
print(f"{code}: {actual_count} 条数据 ({actual_start} ~ {actual_end})")
return df
except requests.exceptions.Timeout:
if attempt < self.retries - 1:
print(f"{code}: 请求超时,重试 {attempt + 2}/{self.retries}")
continue
print(f"{code}: 请求超时")
return None
except requests.exceptions.RequestException as e:
if attempt < self.retries - 1:
continue
print(f"{code}: 请求异常 - {e}")
return None
except json.JSONDecodeError as e:
print(f"{code}: JSON解析失败 - {e}")
return None
return None
def fetch_batch(
self,
codes: List[str],
start_date: str,
end_date: str,
asset_types: Dict[str, str] = None
) -> Dict[str, Optional[pd.DataFrame]]:
"""
批量获取多只标的数据
Args:
codes: 标的代码列表
start_date: 开始日期
end_date: 结束日期
asset_types: 资产类型映射 {code: asset_type}
Returns:
{code: DataFrame}
"""
results = {}
asset_types = asset_types or {}
print(f"从 Flask API 获取 {len(codes)} 只标的...")
for i, code in enumerate(codes, 1):
asset_type = asset_types.get(code)
df = self.fetch(code, start_date, end_date, asset_type)
results[code] = df
# 显示进度
if i % 5 == 0 or i == len(codes):
success = sum(1 for v in results.values() if v is not None)
print(f" 进度: {i}/{len(codes)} (成功: {success})")
return results
def fetch_etf_nav(
self,
code: str,
start_date: str,
end_date: str
) -> Optional[pd.DataFrame]:
"""
获取 ETF 净值数据
Args:
code: ETF代码
start_date: 开始日期
end_date: 结束日期
Returns:
DataFrame with nav column
"""
url = f"{self.base_url}/api/v1/etf/nav"
params = {
'code': code,
'start': start_date,
'end': end_date
}
try:
response = requests.get(url, params=params, timeout=self.timeout)
if response.status_code != 200:
return None
# 处理 zstd 响应
try:
data = response.json()
except (json.JSONDecodeError, requests.exceptions.JSONDecodeError):
data = json.loads(response.text)
if 'error' in data:
return None
# 解析净值数据
# Flask server 返回格式: {'nav': {'data': [...], 'count': N}, 'premium_series': [...]}
nav_section = data.get('nav', {})
records = nav_section.get('data', [])
if not records:
return None
df = pd.DataFrame(records)
if 'date' in df.columns:
df['date'] = pd.to_datetime(df['date'])
df = df.set_index('date')
# 添加溢价率信息(如果有)
if 'premium_series' in data:
df.attrs['premium_series'] = data['premium_series']
if 'latest_premium' in data:
df.attrs['latest_premium'] = data['latest_premium']
if 'premium_stats' in data:
df.attrs['premium_stats'] = data['premium_stats']
return df
except Exception as e:
print(f"{code} 净值获取失败: {e}")
return None
def fetch_with_adj(
self,
code: str,
start_date: str,
end_date: str,
adj: str = 'raw',
asset_type: str = None,
timeframe: str = '1d'
) -> Optional[pd.DataFrame]:
"""
获取 OHLCV 数据(支持复权参数)- 简化版
直接调用 fetch(adj=adj),无需重复实现。
Args:
code: 标的代码
start_date: 开始日期 YYYY-MM-DD
end_date: 结束日期 YYYY-MM-DD
adj: 复权参数raw/qfq/hfq默认 'raw'
asset_type: 资产类型(可选)
timeframe: K线周期加密货币需要
Returns:
DataFrame结构因 adj 参数略有不同
示例:
# A股股票后复权
df = source.fetch_with_adj("000001.SZ", "2020-01-01", "2024-12-31", adj='hfq')
"""
# 直接调用 fetch传递 adj 参数
return self.fetch(code, start_date, end_date, adj, asset_type, timeframe)
def get_health(self) -> Dict:
"""获取服务健康状态"""
# 先尝试 ohlcv 端点检查服务是否可用
url = f"{self.base_url}{self.api_path}"
params = {'code': '000300.SH', 'start': '2024-01-01', 'end': '2024-01-05'}
try:
response = requests.get(url, params=params, timeout=self.timeout)
if response.status_code == 200:
data = response.json()
return {
'status': 'healthy',
'ssh_configured': True,
'available': True
}
else:
return {'status': 'error', 'available': False}
except Exception as e:
return {'status': 'error', 'message': str(e), 'available': False}
def get_service_info(self) -> Dict:
"""获取服务信息"""
url = f"{self.base_url}/"
try:
response = requests.get(url, timeout=10)
return response.json()
except Exception as e:
return {"error": str(e)}
# 全局实例
_flask_api_source: Optional[FlaskAPIDataSource] = None
def get_flask_api_source(base_url: str = None) -> FlaskAPIDataSource:
"""获取 Flask API 数据源实例"""
global _flask_api_source
if _flask_api_source is None:
_flask_api_source = FlaskAPIDataSource(base_url=base_url)
return _flask_api_source