# 改进:
# - fetch_indices() 添加 adj 参数,默认 'raw',可自定义
# - fetch_etf() 添加 adj 参数,默认 'hfq',可自定义
# - 改进日志输出,显示实际使用的 adj 参数
# - 保持向后兼容,默认值保持原有行为
# 优势:
# - 透明性:调用者清楚知道使用的复权方式
# - 灵活性:可按需获取 raw/qfq/hfq 数据
# - 一致性:两个方法接口统一
# - 向后兼容:不影响现有代码
# (270 lines · 8.1 KiB · Python)
"""
|
||
Flask API 数据获取器(framework_v2 实现)
|
||
|
||
继承 DataFetcher 抽象基类,使用 FlaskAPIDataSource 获取线上数据
|
||
支持指数、ETF 数据获取
|
||
"""
|
||
|
||
import pandas as pd
|
||
from typing import Dict, List, Optional
|
||
from pathlib import Path
|
||
import sys
|
||
|
||
# 添加项目根目录到路径
|
||
project_root = Path(__file__).parent.parent.parent.parent
|
||
if str(project_root) not in sys.path:
|
||
sys.path.insert(0, str(project_root))
|
||
|
||
from framework_v2.core.data import DataFetcher
|
||
from datasource.flask_api_source import FlaskAPIDataSource
|
||
|
||
|
||
class FlaskAPIFetcher(DataFetcher):
    """
    Data fetcher backed by the Flask HTTP API.

    Fetches online data (indices, ETFs) over HTTP; no local SSH tunnel
    configuration is required.

    Usage:
        fetcher = FlaskAPIFetcher(base_url="https://k3s.tokenpluse.xyz")
        data = fetcher.fetch_indices(["000300.SH"], "2024-01-01", "2024-12-31")
    """

    name = "flask_api"

    def __init__(
        self,
        base_url: Optional[str] = None,
        timeout: int = 120,
        retries: int = 3,
    ):
        """
        Initialize the fetcher.

        Args:
            base_url: API service address. ``None`` is forwarded unchanged to
                ``FlaskAPIDataSource`` (per the original docs, the default is
                then read from an environment variable — resolved there, not
                in this class).
            timeout: Request timeout in seconds.
            retries: Number of retries.
        """
        super().__init__(base_url=base_url, timeout=timeout, retries=retries)

        # Underlying data source that performs the actual API calls.
        self._source = FlaskAPIDataSource(
            base_url=base_url,
            timeout=timeout,
            retries=retries,
        )

    @staticmethod
    def _as_code_list(codes) -> List[str]:
        """Return *codes* as a list, wrapping a lone code string.

        Without this, a bare string such as ``"000300.SH"`` would be iterated
        character by character, and a generator would fail on ``len()``.
        """
        if isinstance(codes, str):
            return [codes]
        return list(codes)

    def fetch_indices(
        self,
        codes: List[str],
        start: str,
        end: str,
        adj: str = 'raw',
    ) -> Dict[str, pd.DataFrame]:
        """
        Fetch index OHLCV data.

        Args:
            codes: Index codes, e.g. ["000300.SH", "000905.SH"]. A single
                code string is also accepted.
            start: Start date (YYYY-MM-DD).
            end: End date (YYYY-MM-DD).
            adj: Price-adjustment type, default 'raw' (indices usually use
                unadjusted prices); forwarded unchanged to the data source.

        Returns:
            ``{code: DataFrame}`` containing only the codes fetched
            successfully; each DataFrame holds OHLCV columns.

        Example:
            >>> fetcher = FlaskAPIFetcher()
            >>> data = fetcher.fetch_indices(
            ...     ["000300.SH", "000905.SH"],
            ...     "2024-01-01",
            ...     "2024-12-31"
            ... )
            >>> print(data["000300.SH"].head())
        """
        codes = self._as_code_list(codes)
        total = len(codes)
        print(f"\n[FlaskAPI] 获取 {total} 只指数数据(adj='{adj}')...")

        results = {}
        for i, code in enumerate(codes, 1):
            print(f"  [{i}/{total}] {code}...")

            df = self._source.fetch(
                code=code,
                start_date=start,
                end_date=end,
                adj=adj,
            )

            if df is None:
                # The source returns None on failure; report it and move on.
                print("      ✗ 获取失败")
                continue

            results[code] = df
            print(f"      ✓ {len(df)} 条数据")

        print(f"\n[FlaskAPI] 指数数据获取完成: {len(results)}/{total} 成功")
        return results

    def fetch_etf(
        self,
        codes: List[str],
        start: str,
        end: str,
        adj: str = 'hfq',
    ) -> Dict[str, pd.DataFrame]:
        """
        Fetch ETF data (prices + NAV).

        Args:
            codes: ETF codes, e.g. ["510300.SH", "159919.SZ"]. A single code
                string is also accepted.
            start: Start date (YYYY-MM-DD).
            end: End date (YYYY-MM-DD).
            adj: Price-adjustment type, default 'hfq' (back-adjusted,
                recommended for ETF return calculations); forwarded unchanged
                to the data source.

        Returns:
            ``{code: DataFrame}`` containing only the codes fetched
            successfully.
            Each DataFrame holds OHLCV columns.
            df.attrs['nav'] holds the NAV data.
            df.attrs['premium_series'] holds the premium-rate series.

        Example:
            >>> fetcher = FlaskAPIFetcher()
            >>> # Default: hfq (back-adjusted)
            >>> data = fetcher.fetch_etf(
            ...     ["510300.SH", "159919.SZ"],
            ...     "2024-01-01",
            ...     "2024-12-31"
            ... )
            >>> # Or explicitly request raw prices (used for premium-rate calculation)
            >>> data_raw = fetcher.fetch_etf(
            ...     ["510300.SH"],
            ...     "2024-01-01",
            ...     "2024-12-31",
            ...     adj='raw'
            ... )
            >>> # Access the NAV
            >>> nav = data["510300.SH"].attrs.get('nav')
        """
        codes = self._as_code_list(codes)
        total = len(codes)
        print(f"\n[FlaskAPI] 获取 {total} 只 ETF 数据(adj='{adj}')...")

        results = {}
        for i, code in enumerate(codes, 1):
            print(f"  [{i}/{total}] {code}...")

            df = self._source.fetch(
                code=code,
                start_date=start,
                end_date=end,
                adj=adj,
                asset_type='china_etf',  # force the ETF asset type
            )

            if df is None:
                # The source returns None on failure; report it and move on.
                print("      ✗ 获取失败")
                continue

            results[code] = df

            # Extra information attached by the data source via DataFrame.attrs.
            nav = df.attrs.get('nav')
            # Tolerate a missing *or* None 'nav' entry (len(None) would raise).
            nav_count = 0 if nav is None else len(nav)
            premium = df.attrs.get('latest_premium')
            # Append '%' only to a real value so a missing premium prints
            # "N/A" instead of "N/A%".
            premium_text = "N/A" if premium is None else f"{premium}%"
            print(f"      ✓ {len(df)} 条价格, {nav_count} 条净值, 溢价率: {premium_text}")

        print(f"\n[FlaskAPI] ETF 数据获取完成: {len(results)}/{total} 成功")
        return results

    def get_trading_calendar(
        self,
        market: str = 'A',
        start: Optional[str] = None,
        end: Optional[str] = None,
    ) -> pd.Index:
        """
        Fetch the trading calendar via the API.

        Args:
            market: Market code, passed through unchanged to the data source:
                - 'A' or 'china': China A-shares (SSE/SZSE)
                - 'US' or 'us': US equities (NYSE)
                - 'HK' or 'hk': Hong Kong equities (HKEX)
            start: Start date YYYY-MM-DD (default 2020-01-01).
            end: End date YYYY-MM-DD (default 2025-12-31).

        Returns:
            Trading-calendar DatetimeIndex.

        Raises:
            ValueError: If the data source returns no calendar.

        Example:
            >>> fetcher = FlaskAPIFetcher()
            >>> # A-share trading calendar for 2024
            >>> calendar = fetcher.get_trading_calendar('A', '2024-01-01', '2024-12-31')
            >>> # US trading calendar
            >>> calendar = fetcher.get_trading_calendar('US', '2024-01-01', '2024-12-31')
        """
        # Default date range.
        if start is None:
            start = '2020-01-01'
        if end is None:
            end = '2025-12-31'

        # Ask the API for the exact calendar.
        calendar = self._source.get_trading_calendar(
            market=market,
            start_date=start,
            end_date=end,
        )

        if calendar is None:
            # API failure: raise rather than silently degrading.
            raise ValueError(
                f"交易日历获取失败: market={market}, {start} ~ {end}。"
                "请检查 API 服务是否可用。"
            )

        return calendar

    def get_benchmark(
        self,
        code: str = "000300.SH",
        start: str = "2020-01-01",
        end: str = "2025-12-31",
    ) -> pd.Series:
        """
        Fetch benchmark data.

        Args:
            code: Benchmark code (default: CSI 300, "000300.SH").
            start: Start date.
            end: End date.

        Returns:
            The benchmark's closing-price Series: the ``close`` column of
            data fetched with ``adj='raw'``.

        Raises:
            ValueError: If the data source returns no data.
        """
        df = self._source.fetch(
            code=code,
            start_date=start,
            end_date=end,
            adj='raw',
        )

        if df is None:
            raise ValueError(f"基准数据获取失败: {code}")

        return df['close']

    def get_health(self) -> Dict:
        """
        Check the API service's health status.

        Returns:
            Health-status dict as returned by the data source.
        """
        return self._source.get_health()

    def __repr__(self) -> str:
        return f"FlaskAPIFetcher(base_url={self._source.base_url})"