Files
etf/framework_v2/shared/data/flask_api_fetcher.py
aszerW 40116f436f feat(framework_v2): 添加 FlaskAPIFetcher 数据获取器
## 核心功能
- FlaskAPIFetcher: 继承 DataFetcher 抽象基类
- fetch_indices(): 获取指数 OHLCV 数据
- fetch_etf(): 获取 ETF 数据(自动附加净值+溢价率)
- get_trading_calendar(): 获取交易日历
- get_benchmark(): 获取基准数据

## 技术实现
- 委托调用 FlaskAPIDataSource(HTTP API)
- 自动重试 3 次,超时 120 秒
- Pydantic Schema 验证响应
- 进度显示(批量获取)
- 无需本地 SSH 隧道配置

## 测试验证
- 5/5 测试通过(健康检查、指数、ETF、日历、基准)
- 成功获取线上数据(000300.SH, 510300.SH)
- ETF 自动附加净值(3695 条)和溢价率

## 架构设计
- shared/data/flask_api_fetcher.py - 实现(262 行)
- tests/test_flask_api_fetcher.py - 测试(199 行)
- 依赖倒置原则(策略依赖抽象接口)
2026-05-24 10:38:34 +08:00

262 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Flask API 数据获取器framework_v2 实现)
继承 DataFetcher 抽象基类,使用 FlaskAPIDataSource 获取线上数据
支持指数、ETF 数据获取
"""
import pandas as pd
from typing import Dict, List, Optional
from pathlib import Path
import sys
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent.parent.parent
if str(project_root) not in sys.path:
sys.path.insert(0, str(project_root))
from framework_v2.core.data import DataFetcher
from datasource.flask_api_source import FlaskAPIDataSource
class FlaskAPIFetcher(DataFetcher):
"""
Flask API 数据获取器
通过 HTTP API 获取线上数据指数、ETF
无需本地 SSH 隧道配置
用法:
fetcher = FlaskAPIFetcher(base_url="https://k3s.tokenpluse.xyz")
data = fetcher.fetch_indices(["000300.SH"], "2024-01-01", "2024-12-31")
"""
name = "flask_api"
def __init__(
self,
base_url: str = None,
timeout: int = 120,
retries: int = 3
):
"""
初始化
Args:
base_url: API 服务地址(默认从环境变量读取)
timeout: 请求超时时间(秒)
retries: 重试次数
"""
super().__init__(base_url=base_url, timeout=timeout, retries=retries)
# 创建底层数据源
self._source = FlaskAPIDataSource(
base_url=base_url,
timeout=timeout,
retries=retries
)
def fetch_indices(
self,
codes: List[str],
start: str,
end: str
) -> Dict[str, pd.DataFrame]:
"""
获取指数 OHLCV 数据
Args:
codes: 指数代码列表(如 ["000300.SH", "000905.SH"]
start: 开始日期 (YYYY-MM-DD)
end: 结束日期 (YYYY-MM-DD)
Returns:
{code: DataFrame} 字典DataFrame 包含 OHLCV 列
示例:
>>> fetcher = FlaskAPIFetcher()
>>> data = fetcher.fetch_indices(
... ["000300.SH", "000905.SH"],
... "2024-01-01",
... "2024-12-31"
... )
>>> print(data["000300.SH"].head())
"""
print(f"\n[FlaskAPI] 获取 {len(codes)} 只指数数据...")
results = {}
for i, code in enumerate(codes, 1):
print(f" [{i}/{len(codes)}] {code}...")
df = self._source.fetch(
code=code,
start_date=start,
end_date=end,
adj='raw' # 指数通常用原始价格
)
if df is not None:
results[code] = df
print(f"{len(df)} 条数据")
else:
print(f" ✗ 获取失败")
success = len(results)
print(f"\n[FlaskAPI] 指数数据获取完成: {success}/{len(codes)} 成功")
return results
def fetch_etf(
self,
codes: List[str],
start: str,
end: str
) -> Dict[str, pd.DataFrame]:
"""
获取 ETF 数据(价格 + 净值)
Args:
codes: ETF 代码列表(如 ["510300.SH", "159919.SZ"]
start: 开始日期 (YYYY-MM-DD)
end: 结束日期 (YYYY-MM-DD)
Returns:
{code: DataFrame} 字典
DataFrame 包含 OHLCV 列
df.attrs['nav'] 包含净值数据
df.attrs['premium_series'] 包含溢价率序列
示例:
>>> fetcher = FlaskAPIFetcher()
>>> data = fetcher.fetch_etf(
... ["510300.SH", "159919.SZ"],
... "2024-01-01",
... "2024-12-31"
... )
>>> # 访问净值
>>> nav = data["510300.SH"].attrs.get('nav')
"""
print(f"\n[FlaskAPI] 获取 {len(codes)} 只 ETF 数据...")
results = {}
for i, code in enumerate(codes, 1):
print(f" [{i}/{len(codes)}] {code}...")
df = self._source.fetch(
code=code,
start_date=start,
end_date=end,
adj='raw',
asset_type='china_etf' # 强制指定 ETF 类型
)
if df is not None:
results[code] = df
# 显示附加信息
nav_count = len(df.attrs.get('nav', pd.DataFrame()))
premium = df.attrs.get('latest_premium', 'N/A')
print(f"{len(df)} 条价格, {nav_count} 条净值, 溢价率: {premium}%")
else:
print(f" ✗ 获取失败")
success = len(results)
print(f"\n[FlaskAPI] ETF 数据获取完成: {success}/{len(codes)} 成功")
return results
def get_trading_calendar(self, market: str = 'A') -> pd.Index:
"""
获取交易日历
注意Flask API 暂不直接提供交易日历
这里使用 pandas 的 BDay 生成近似日历
TODO: 后续可通过 API 端点获取准确日历
Args:
market: 市场代码('A', 'US', 'HK' 等)
Returns:
交易日历 Index
"""
# 临时实现:使用 pandas 生成工作日日历
# 实际应该从 API 获取准确的交易日历
if market == 'A':
# A股中国工作日简化实现
start = pd.Timestamp('2020-01-01')
end = pd.Timestamp('2025-12-31')
calendar = pd.bdate_range(start=start, end=end)
# 移除中国主要节假日(简化版)
# 实际应该从 API 或数据库获取准确日历
holidays = [
# 春节(示例,不完整)
'2024-02-10', '2024-02-11', '2024-02-12', '2024-02-13', '2024-02-14',
'2024-02-15', '2024-02-16', '2024-02-17',
# 国庆(示例,不完整)
'2024-10-01', '2024-10-02', '2024-10-03', '2024-10-04',
'2024-10-05', '2024-10-06', '2024-10-07',
]
calendar = calendar[~calendar.isin(pd.to_datetime(holidays))]
return calendar
elif market == 'US':
# 美股:美国工作日
start = pd.Timestamp('2020-01-01')
end = pd.Timestamp('2025-12-31')
return pd.bdate_range(start=start, end=end)
elif market == 'HK':
# 港股:香港工作日
start = pd.Timestamp('2020-01-01')
end = pd.Timestamp('2025-12-31')
return pd.bdate_range(start=start, end=end)
else:
raise ValueError(f"不支持的市场: {market}")
def get_benchmark(
self,
code: str = "000300.SH",
start: str = "2020-01-01",
end: str = "2025-12-31"
) -> pd.Series:
"""
获取基准数据
Args:
code: 基准代码(默认沪深 300
start: 开始日期
end: 结束日期
Returns:
基准收盘价 Series
"""
df = self._source.fetch(
code=code,
start_date=start,
end_date=end,
adj='raw'
)
if df is None:
raise ValueError(f"基准数据获取失败: {code}")
return df['close']
def get_health(self) -> Dict:
"""
检查 API 服务健康状态
Returns:
健康状态字典
"""
return self._source.get_health()
def __repr__(self) -> str:
return f"FlaskAPIFetcher(base_url={self._source.base_url})"