Files
etf/framework_v2/shared/data/flask_api_fetcher.py
aszerW 2ff48e8d56 refactor(flask_api_fetcher): 暴露adj参数,增强接口透明度和灵活性
改进:
- fetch_indices()添加adj参数,默认'raw',可自定义
- fetch_etf()添加adj参数,默认'hfq',可自定义
- 改进日志输出,显示实际使用的adj参数
- 保持向后兼容,默认值保持原有行为

优势:
- 透明性:调用者清楚知道使用的复权方式
- 灵活性:可按需获取raw/qfq/hfq数据
- 一致性:两个方法接口统一
- 向后兼容:不影响现有代码
2026-05-26 19:54:41 +08:00

270 lines
8.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Flask API data fetcher (framework_v2 implementation).
Subclasses the DataFetcher abstract base class and uses FlaskAPIDataSource to fetch online data.
Supports fetching index and ETF data.
"""
import pandas as pd
from typing import Dict, List, Optional
from pathlib import Path
import sys
# Put the project root (four directory levels above this file) at the front of
# sys.path so the absolute project imports that follow can be resolved even
# when the project is not installed as a package. The membership check keeps
# repeated imports of this module from adding duplicate entries.
project_root = Path(__file__).parent.parent.parent.parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))
from framework_v2.core.data import DataFetcher
from datasource.flask_api_source import FlaskAPIDataSource
class FlaskAPIFetcher(DataFetcher):
    """Fetch index and ETF market data through the Flask HTTP API.

    Every request is delegated to a ``FlaskAPIDataSource`` instance. According
    to the original documentation, no local SSH tunnel configuration is needed
    because the data is served over HTTP.

    Usage:
        fetcher = FlaskAPIFetcher(base_url="https://k3s.tokenpluse.xyz")
        data = fetcher.fetch_indices(["000300.SH"], "2024-01-01", "2024-12-31")
    """

    name = "flask_api"

    def __init__(
        self,
        base_url: Optional[str] = None,
        timeout: int = 120,
        retries: int = 3
    ):
        """Create the fetcher and its underlying data source.

        Args:
            base_url: Address of the API service. ``None`` is forwarded
                unchanged; per the original documentation the data source
                then reads the address from an environment variable
                (not visible here; confirm in FlaskAPIDataSource).
            timeout: Request timeout in seconds.
            retries: Number of retries.
        """
        super().__init__(base_url=base_url, timeout=timeout, retries=retries)
        # Underlying data source that performs the actual HTTP requests.
        self._source = FlaskAPIDataSource(
            base_url=base_url,
            timeout=timeout,
            retries=retries
        )

    @staticmethod
    def _as_code_list(codes) -> List[str]:
        """Return *codes* as a list, treating a bare string as one code."""
        # Iterating a plain string would request one "code" per character.
        if isinstance(codes, str):
            return [codes]
        return list(codes)

    def fetch_indices(
        self,
        codes: List[str],
        start: str,
        end: str,
        adj: str = 'raw'
    ) -> Dict[str, pd.DataFrame]:
        """Fetch OHLCV data for a list of indices.

        Args:
            codes: Index codes, e.g. ["000300.SH", "000905.SH"]. A single
                string is accepted and treated as one code.
            start: Start date (YYYY-MM-DD).
            end: End date (YYYY-MM-DD).
            adj: Price adjustment type forwarded to the data source.
                Defaults to 'raw' (indices normally use unadjusted prices).

        Returns:
            ``{code: DataFrame}`` for every code the data source returned
            data for. Codes for which the data source returned ``None`` are
            omitted. Per the original documentation each DataFrame holds
            OHLCV columns.

        Example:
            >>> fetcher = FlaskAPIFetcher()
            >>> data = fetcher.fetch_indices(
            ...     ["000300.SH", "000905.SH"],
            ...     "2024-01-01",
            ...     "2024-12-31"
            ... )
            >>> print(data["000300.SH"].head())
        """
        codes = self._as_code_list(codes)
        total = len(codes)
        print(f"\n[FlaskAPI] 获取 {total} 只指数数据adj='{adj}'...")

        results: Dict[str, pd.DataFrame] = {}
        for i, code in enumerate(codes, 1):
            print(f" [{i}/{total}] {code}...")
            df = self._source.fetch(
                code=code,
                start_date=start,
                end_date=end,
                adj=adj
            )
            if df is None:
                # Best-effort batch: report the failure and keep going.
                print(" ✗ 获取失败")
                continue
            results[code] = df
            print(f"{len(df)} 条数据")

        print(f"\n[FlaskAPI] 指数数据获取完成: {len(results)}/{total} 成功")
        return results

    def fetch_etf(
        self,
        codes: List[str],
        start: str,
        end: str,
        adj: str = 'hfq'
    ) -> Dict[str, pd.DataFrame]:
        """Fetch ETF data (prices plus net asset value).

        Args:
            codes: ETF codes, e.g. ["510300.SH", "159919.SZ"]. A single
                string is accepted and treated as one code.
            start: Start date (YYYY-MM-DD).
            end: End date (YYYY-MM-DD).
            adj: Price adjustment type forwarded to the data source.
                Defaults to 'hfq' (backward-adjusted, recommended for ETF
                return calculations).

        Returns:
            ``{code: DataFrame}`` for every code the data source returned
            data for. Codes for which the data source returned ``None`` are
            omitted. Per the original documentation each DataFrame holds
            OHLCV columns, ``df.attrs['nav']`` holds the NAV data and
            ``df.attrs['premium_series']`` holds the premium-rate series;
            this method itself only reads ``attrs['nav']`` and
            ``attrs['latest_premium']`` for its progress output.

        Example:
            >>> fetcher = FlaskAPIFetcher()
            >>> # Uses hfq (backward-adjusted prices) by default
            >>> data = fetcher.fetch_etf(
            ...     ["510300.SH", "159919.SZ"],
            ...     "2024-01-01",
            ...     "2024-12-31"
            ... )
            >>> # Or request raw prices explicitly, e.g. to compute premiums
            >>> data_raw = fetcher.fetch_etf(
            ...     ["510300.SH"],
            ...     "2024-01-01",
            ...     "2024-12-31",
            ...     adj='raw'
            ... )
            >>> # Access the NAV data
            >>> nav = data["510300.SH"].attrs.get('nav')
        """
        codes = self._as_code_list(codes)
        total = len(codes)
        print(f"\n[FlaskAPI] 获取 {total} 只 ETF 数据adj='{adj}'...")

        results: Dict[str, pd.DataFrame] = {}
        for i, code in enumerate(codes, 1):
            print(f" [{i}/{total}] {code}...")
            df = self._source.fetch(
                code=code,
                start_date=start,
                end_date=end,
                adj=adj,
                asset_type='china_etf'  # always request the ETF asset type
            )
            if df is None:
                # Best-effort batch: report the failure and keep going.
                print(" ✗ 获取失败")
                continue
            results[code] = df

            # Progress details. 'nav' may be absent or explicitly None, so
            # guard before calling len() instead of assuming a DataFrame.
            nav = df.attrs.get('nav')
            nav_count = 0 if nav is None else len(nav)
            # Only append the percent sign when a premium value exists.
            premium = df.attrs.get('latest_premium')
            premium_text = 'N/A' if premium is None else f"{premium}%"
            print(f"{len(df)} 条价格, {nav_count} 条净值, 溢价率: {premium_text}")

        print(f"\n[FlaskAPI] ETF 数据获取完成: {len(results)}/{total} 成功")
        return results

    def get_trading_calendar(
        self,
        market: str = 'A',
        start: Optional[str] = None,
        end: Optional[str] = None
    ) -> pd.Index:
        """Fetch the trading calendar through the API.

        Args:
            market: Market code, forwarded unchanged to the data source.
                Per the original documentation:
                - 'A' / 'china': A-shares (Shanghai / Shenzhen exchanges)
                - 'US' / 'us': US equities (NYSE)
                - 'HK' / 'hk': Hong Kong equities (HKEX)
            start: Start date YYYY-MM-DD; defaults to 2020-01-01.
            end: End date YYYY-MM-DD; defaults to 2025-12-31.

        Returns:
            The calendar returned by the data source (documented as a
            DatetimeIndex of trading days).

        Raises:
            ValueError: If the data source returns ``None``.

        Example:
            >>> fetcher = FlaskAPIFetcher()
            >>> # A-share trading calendar for 2024
            >>> calendar = fetcher.get_trading_calendar('A', '2024-01-01', '2024-12-31')
            >>> # US trading calendar
            >>> calendar = fetcher.get_trading_calendar('US', '2024-01-01', '2024-12-31')
        """
        # NOTE(review): the default range is a fixed window; callers needing
        # dates after 2025-12-31 must pass `end` explicitly.
        if start is None:
            start = '2020-01-01'
        if end is None:
            end = '2025-12-31'

        calendar = self._source.get_trading_calendar(
            market=market,
            start_date=start,
            end_date=end
        )
        if calendar is None:
            # Fail loudly instead of silently degrading to a guessed calendar.
            raise ValueError(
                f"交易日历获取失败: market={market}, {start} ~ {end}。"
                "请检查 API 服务是否可用。"
            )
        return calendar

    def get_benchmark(
        self,
        code: str = "000300.SH",
        start: str = "2020-01-01",
        end: str = "2025-12-31"
    ) -> pd.Series:
        """Fetch the close-price series of a benchmark.

        Args:
            code: Benchmark code (defaults to "000300.SH", the CSI 300).
            start: Start date.
            end: End date.

        Returns:
            The 'close' column of the fetched data, requested with raw
            (unadjusted) prices.

        Raises:
            ValueError: If the data source returns ``None``.
        """
        df = self._source.fetch(
            code=code,
            start_date=start,
            end_date=end,
            adj='raw'
        )
        if df is None:
            raise ValueError(f"基准数据获取失败: {code}")
        return df['close']

    def get_health(self) -> Dict:
        """Return the API health status reported by the data source."""
        return self._source.get_health()

    def __repr__(self) -> str:
        return f"FlaskAPIFetcher(base_url={self._source.base_url})"