Files
etf/framework_v2/shared/data/flask_api_fetcher.py
aszerW 94b9ef165b feat(v2): 增强框架核心功能与ETF复权修复
- 修复 end_date=None 导致 Flask API 返回错误时间范围的 bug
  * strategy.py: 自动使用今天日期作为 end_date
  * 验证:回测区间从 77 天恢复到 1539 天

- ETF 收益计算从原始价格改为后复权价格
  * flask_api_fetcher.py: adj='raw' → adj='hfq'
  * 自动处理 ETF 份额拆分事件,确保收益率准确

- V2 简单版添加 A 股交易日过滤
  * simple.py: 获取 SSE 交易日历,过滤非交易日
  * 验证:1999 天 → 1539 天(与 V1 一致)

- 配置严格对齐 V1 config.yaml
  * config_simple.yaml: start_date 从 2020-01-01 改为 2020-01-10
  * group 字段值严格映射 V1 的 market 字段

关键验证:
- V2 简单版回测:1539 天,981.95% 收益(未计入交易成本)
- V2 正式版回测:1539 天,135.63% 收益(已计入交易成本)
- V1 旧版框架:1539 天,103.29% 收益(基准)
2026-05-24 22:53:45 +08:00

258 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Flask API 数据获取器(framework_v2 实现)
继承 DataFetcher 抽象基类,使用 FlaskAPIDataSource 获取线上数据
支持指数、ETF 数据获取
"""
import pandas as pd
from typing import Dict, List, Optional
from pathlib import Path
import sys
# Put the project root (four directory levels above this file) at the front of
# sys.path so the absolute `framework_v2` / `datasource` imports that follow
# resolve; the entry is only added once.
project_root = Path(__file__).parent.parent.parent.parent
_root = str(project_root)
if _root not in sys.path:
    sys.path.insert(0, _root)
del _root
from framework_v2.core.data import DataFetcher
from datasource.flask_api_source import FlaskAPIDataSource
class FlaskAPIFetcher(DataFetcher):
    """
    Data fetcher backed by the Flask HTTP API.

    Retrieves online market data (indices and ETFs) over HTTP through
    ``FlaskAPIDataSource``; no local SSH tunnel configuration is required.

    Usage:
        fetcher = FlaskAPIFetcher(base_url="https://k3s.tokenpluse.xyz")
        data = fetcher.fetch_indices(["000300.SH"], "2024-01-01", "2024-12-31")
    """

    name = "flask_api"

    def __init__(
        self,
        base_url: Optional[str] = None,
        timeout: int = 120,
        retries: int = 3,
    ):
        """
        Initialize the fetcher and its underlying data source.

        Args:
            base_url: API service address. When None the choice is left to
                FlaskAPIDataSource (documented as being read from an
                environment variable).
            timeout: Request timeout in seconds.
            retries: Number of retries.
        """
        super().__init__(base_url=base_url, timeout=timeout, retries=retries)
        # Every request below is delegated to this underlying data source.
        self._source = FlaskAPIDataSource(
            base_url=base_url,
            timeout=timeout,
            retries=retries,
        )

    @staticmethod
    def _resolve_end(end: Optional[str]) -> str:
        """
        Return ``end`` unchanged, or today's local date when it is None.

        Passing ``end_date=None`` straight to the API yields a wrong date
        range, so an open-ended request is pinned to today (YYYY-MM-DD).
        """
        if end is None:
            return pd.Timestamp.today().strftime("%Y-%m-%d")
        return end

    def fetch_indices(
        self,
        codes: List[str],
        start: str,
        end: Optional[str],
    ) -> Dict[str, pd.DataFrame]:
        """
        Fetch index OHLCV data.

        Args:
            codes: Index codes, e.g. ["000300.SH", "000905.SH"].
            start: Start date (YYYY-MM-DD).
            end: End date (YYYY-MM-DD); None means today.

        Returns:
            ``{code: DataFrame}`` whose DataFrames hold OHLCV columns. Codes
            for which the data source returned None are omitted.

        Example:
            >>> fetcher = FlaskAPIFetcher()
            >>> data = fetcher.fetch_indices(
            ...     ["000300.SH", "000905.SH"],
            ...     "2024-01-01",
            ...     "2024-12-31",
            ... )
            >>> print(data["000300.SH"].head())
        """
        end = self._resolve_end(end)
        print(f"\n[FlaskAPI] 获取 {len(codes)} 只指数数据...")
        results = {}
        for i, code in enumerate(codes, 1):
            print(f" [{i}/{len(codes)}] {code}...")
            df = self._source.fetch(
                code=code,
                start_date=start,
                end_date=end,
                adj='raw',  # indices are normally used with raw prices
            )
            if df is not None:
                results[code] = df
                print(f"{len(df)} 条数据")
            else:
                print(" ✗ 获取失败")
        success = len(results)
        print(f"\n[FlaskAPI] 指数数据获取完成: {success}/{len(codes)} 成功")
        return results

    def fetch_etf(
        self,
        codes: List[str],
        start: str,
        end: Optional[str],
    ) -> Dict[str, pd.DataFrame]:
        """
        Fetch ETF data (prices plus NAV).

        Args:
            codes: ETF codes, e.g. ["510300.SH", "159919.SZ"].
            start: Start date (YYYY-MM-DD).
            end: End date (YYYY-MM-DD); None means today.

        Returns:
            ``{code: DataFrame}`` whose DataFrames hold OHLCV columns.
            Codes for which the data source returned None are omitted.
            Per the data source, ``df.attrs['nav']`` holds NAV data and
            ``df.attrs['premium_series']`` the premium-rate series; this
            method also reads ``df.attrs['latest_premium']`` for logging.

        Example:
            >>> fetcher = FlaskAPIFetcher()
            >>> data = fetcher.fetch_etf(
            ...     ["510300.SH", "159919.SZ"],
            ...     "2024-01-01",
            ...     "2024-12-31",
            ... )
            >>> nav = data["510300.SH"].attrs.get('nav')  # access NAV
        """
        end = self._resolve_end(end)
        print(f"\n[FlaskAPI] 获取 {len(codes)} 只 ETF 数据...")
        results = {}
        for i, code in enumerate(codes, 1):
            print(f" [{i}/{len(codes)}] {code}...")
            df = self._source.fetch(
                code=code,
                start_date=start,
                end_date=end,
                # ETF returns must use back-adjusted (hfq) prices so that
                # share-split events do not distort the return series.
                adj='hfq',
                asset_type='china_etf',  # force the ETF asset type
            )
            if df is not None:
                results[code] = df
                # Extra diagnostics; either attr may be missing or None.
                nav = df.attrs.get('nav')
                nav_count = 0 if nav is None else len(nav)
                premium = df.attrs.get('latest_premium')
                premium_text = "N/A" if premium is None else f"{premium}%"
                print(f"{len(df)} 条价格, {nav_count} 条净值, 溢价率: {premium_text}")
            else:
                print(" ✗ 获取失败")
        success = len(results)
        print(f"\n[FlaskAPI] ETF 数据获取完成: {success}/{len(codes)} 成功")
        return results

    def get_trading_calendar(
        self,
        market: str = 'A',
        start: Optional[str] = None,
        end: Optional[str] = None,
    ) -> pd.Index:
        """
        Fetch a trading calendar from the API.

        Args:
            market: Market code:
                - 'A' or 'china': China A-shares (SSE/SZSE)
                - 'US' or 'us': US equities (NYSE)
                - 'HK' or 'hk': Hong Kong equities (HKEX)
            start: Start date YYYY-MM-DD; defaults to 2020-01-01.
            end: End date YYYY-MM-DD; defaults to today. (Previously this
                defaulted to a hard-coded 2025-12-31, which silently dropped
                every later trading day.)

        Returns:
            Trading-calendar index (documented as a DatetimeIndex).

        Raises:
            ValueError: If the data source returns no calendar. The failure
                is raised deliberately rather than silently degraded.

        Example:
            >>> fetcher = FlaskAPIFetcher()
            >>> calendar = fetcher.get_trading_calendar('A', '2024-01-01', '2024-12-31')
            >>> calendar = fetcher.get_trading_calendar('US', '2024-01-01', '2024-12-31')
        """
        if start is None:
            start = '2020-01-01'
        end = self._resolve_end(end)

        calendar = self._source.get_trading_calendar(
            market=market,
            start_date=start,
            end_date=end,
        )
        if calendar is None:
            raise ValueError(
                f"交易日历获取失败: market={market}, {start} ~ {end}, "
                f"请检查 API 服务是否可用。"
            )
        return calendar

    def get_benchmark(
        self,
        code: str = "000300.SH",
        start: str = "2020-01-01",
        # NOTE(review): this hard-coded default end date is stale; it is kept
        # for backward compatibility. Pass end=None to fetch through today.
        end: Optional[str] = "2025-12-31",
    ) -> pd.Series:
        """
        Fetch benchmark close prices.

        Args:
            code: Benchmark code (default: CSI 300, "000300.SH").
            start: Start date (YYYY-MM-DD).
            end: End date (YYYY-MM-DD); None means today.

        Returns:
            The benchmark's ``close`` column as a Series (raw prices).

        Raises:
            ValueError: If the data source returns no data.
        """
        df = self._source.fetch(
            code=code,
            start_date=start,
            end_date=self._resolve_end(end),
            adj='raw',
        )
        if df is None:
            raise ValueError(f"基准数据获取失败: {code}")
        return df['close']

    def get_health(self) -> Dict:
        """
        Check the API service health status.

        Returns:
            The health-status dict reported by the underlying data source.
        """
        return self._source.get_health()

    def __repr__(self) -> str:
        return f"FlaskAPIFetcher(base_url={self._source.base_url})"