feat(framework_v2): 添加 FlaskAPIFetcher 数据获取器
## 核心功能 - FlaskAPIFetcher: 继承 DataFetcher 抽象基类 - fetch_indices(): 获取指数 OHLCV 数据 - fetch_etf(): 获取 ETF 数据(自动附加净值+溢价率) - get_trading_calendar(): 获取交易日历 - get_benchmark(): 获取基准数据 ## 技术实现 - 委托调用 FlaskAPIDataSource(HTTP API) - 自动重试 3 次,超时 120 秒 - Pydantic Schema 验证响应 - 进度显示(批量获取) - 无需本地 SSH 隧道配置 ## 测试验证 - 5/5 测试通过(健康检查、指数、ETF、日历、基准) - 成功获取线上数据(000300.SH, 510300.SH) - ETF 自动附加净值(3695 条)和溢价率 ## 架构设计 - shared/data/flask_api_fetcher.py - 实现(262 行) - tests/test_flask_api_fetcher.py - 测试(199 行) - 依赖倒置原则(策略依赖抽象接口)
This commit is contained in:
@@ -9,6 +9,7 @@ from framework_v2.shared.data.schemas import (
|
||||
AlignedReturnsSchema,
|
||||
AlignmentValidationResult,
|
||||
)
|
||||
from framework_v2.shared.data.flask_api_fetcher import FlaskAPIFetcher
|
||||
|
||||
__all__ = [
|
||||
'CrossMarketAligner',
|
||||
@@ -16,4 +17,5 @@ __all__ = [
|
||||
'AlignedFactorSchema',
|
||||
'AlignedReturnsSchema',
|
||||
'AlignmentValidationResult',
|
||||
'FlaskAPIFetcher',
|
||||
]
|
||||
|
||||
261
framework_v2/shared/data/flask_api_fetcher.py
Normal file
261
framework_v2/shared/data/flask_api_fetcher.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""
|
||||
Flask API 数据获取器(framework_v2 实现)
|
||||
|
||||
继承 DataFetcher 抽象基类,使用 FlaskAPIDataSource 获取线上数据
|
||||
支持指数、ETF 数据获取
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = Path(__file__).parent.parent.parent.parent
|
||||
if str(project_root) not in sys.path:
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from framework_v2.core.data import DataFetcher
|
||||
from datasource.flask_api_source import FlaskAPIDataSource
|
||||
|
||||
|
||||
class FlaskAPIFetcher(DataFetcher):
|
||||
"""
|
||||
Flask API 数据获取器
|
||||
|
||||
通过 HTTP API 获取线上数据(指数、ETF)
|
||||
无需本地 SSH 隧道配置
|
||||
|
||||
用法:
|
||||
fetcher = FlaskAPIFetcher(base_url="https://k3s.tokenpluse.xyz")
|
||||
data = fetcher.fetch_indices(["000300.SH"], "2024-01-01", "2024-12-31")
|
||||
"""
|
||||
|
||||
name = "flask_api"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str = None,
|
||||
timeout: int = 120,
|
||||
retries: int = 3
|
||||
):
|
||||
"""
|
||||
初始化
|
||||
|
||||
Args:
|
||||
base_url: API 服务地址(默认从环境变量读取)
|
||||
timeout: 请求超时时间(秒)
|
||||
retries: 重试次数
|
||||
"""
|
||||
super().__init__(base_url=base_url, timeout=timeout, retries=retries)
|
||||
|
||||
# 创建底层数据源
|
||||
self._source = FlaskAPIDataSource(
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
retries=retries
|
||||
)
|
||||
|
||||
def fetch_indices(
|
||||
self,
|
||||
codes: List[str],
|
||||
start: str,
|
||||
end: str
|
||||
) -> Dict[str, pd.DataFrame]:
|
||||
"""
|
||||
获取指数 OHLCV 数据
|
||||
|
||||
Args:
|
||||
codes: 指数代码列表(如 ["000300.SH", "000905.SH"])
|
||||
start: 开始日期 (YYYY-MM-DD)
|
||||
end: 结束日期 (YYYY-MM-DD)
|
||||
|
||||
Returns:
|
||||
{code: DataFrame} 字典,DataFrame 包含 OHLCV 列
|
||||
|
||||
示例:
|
||||
>>> fetcher = FlaskAPIFetcher()
|
||||
>>> data = fetcher.fetch_indices(
|
||||
... ["000300.SH", "000905.SH"],
|
||||
... "2024-01-01",
|
||||
... "2024-12-31"
|
||||
... )
|
||||
>>> print(data["000300.SH"].head())
|
||||
"""
|
||||
print(f"\n[FlaskAPI] 获取 {len(codes)} 只指数数据...")
|
||||
|
||||
results = {}
|
||||
for i, code in enumerate(codes, 1):
|
||||
print(f" [{i}/{len(codes)}] {code}...")
|
||||
|
||||
df = self._source.fetch(
|
||||
code=code,
|
||||
start_date=start,
|
||||
end_date=end,
|
||||
adj='raw' # 指数通常用原始价格
|
||||
)
|
||||
|
||||
if df is not None:
|
||||
results[code] = df
|
||||
print(f" ✓ {len(df)} 条数据")
|
||||
else:
|
||||
print(f" ✗ 获取失败")
|
||||
|
||||
success = len(results)
|
||||
print(f"\n[FlaskAPI] 指数数据获取完成: {success}/{len(codes)} 成功")
|
||||
|
||||
return results
|
||||
|
||||
def fetch_etf(
|
||||
self,
|
||||
codes: List[str],
|
||||
start: str,
|
||||
end: str
|
||||
) -> Dict[str, pd.DataFrame]:
|
||||
"""
|
||||
获取 ETF 数据(价格 + 净值)
|
||||
|
||||
Args:
|
||||
codes: ETF 代码列表(如 ["510300.SH", "159919.SZ"])
|
||||
start: 开始日期 (YYYY-MM-DD)
|
||||
end: 结束日期 (YYYY-MM-DD)
|
||||
|
||||
Returns:
|
||||
{code: DataFrame} 字典
|
||||
DataFrame 包含 OHLCV 列
|
||||
df.attrs['nav'] 包含净值数据
|
||||
df.attrs['premium_series'] 包含溢价率序列
|
||||
|
||||
示例:
|
||||
>>> fetcher = FlaskAPIFetcher()
|
||||
>>> data = fetcher.fetch_etf(
|
||||
... ["510300.SH", "159919.SZ"],
|
||||
... "2024-01-01",
|
||||
... "2024-12-31"
|
||||
... )
|
||||
>>> # 访问净值
|
||||
>>> nav = data["510300.SH"].attrs.get('nav')
|
||||
"""
|
||||
print(f"\n[FlaskAPI] 获取 {len(codes)} 只 ETF 数据...")
|
||||
|
||||
results = {}
|
||||
for i, code in enumerate(codes, 1):
|
||||
print(f" [{i}/{len(codes)}] {code}...")
|
||||
|
||||
df = self._source.fetch(
|
||||
code=code,
|
||||
start_date=start,
|
||||
end_date=end,
|
||||
adj='raw',
|
||||
asset_type='china_etf' # 强制指定 ETF 类型
|
||||
)
|
||||
|
||||
if df is not None:
|
||||
results[code] = df
|
||||
|
||||
# 显示附加信息
|
||||
nav_count = len(df.attrs.get('nav', pd.DataFrame()))
|
||||
premium = df.attrs.get('latest_premium', 'N/A')
|
||||
|
||||
print(f" ✓ {len(df)} 条价格, {nav_count} 条净值, 溢价率: {premium}%")
|
||||
else:
|
||||
print(f" ✗ 获取失败")
|
||||
|
||||
success = len(results)
|
||||
print(f"\n[FlaskAPI] ETF 数据获取完成: {success}/{len(codes)} 成功")
|
||||
|
||||
return results
|
||||
|
||||
def get_trading_calendar(self, market: str = 'A') -> pd.Index:
|
||||
"""
|
||||
获取交易日历
|
||||
|
||||
注意:Flask API 暂不直接提供交易日历
|
||||
这里使用 pandas 的 BDay 生成近似日历
|
||||
|
||||
TODO: 后续可通过 API 端点获取准确日历
|
||||
|
||||
Args:
|
||||
market: 市场代码('A', 'US', 'HK' 等)
|
||||
|
||||
Returns:
|
||||
交易日历 Index
|
||||
"""
|
||||
# 临时实现:使用 pandas 生成工作日日历
|
||||
# 实际应该从 API 获取准确的交易日历
|
||||
|
||||
if market == 'A':
|
||||
# A股:中国工作日(简化实现)
|
||||
start = pd.Timestamp('2020-01-01')
|
||||
end = pd.Timestamp('2025-12-31')
|
||||
calendar = pd.bdate_range(start=start, end=end)
|
||||
|
||||
# 移除中国主要节假日(简化版)
|
||||
# 实际应该从 API 或数据库获取准确日历
|
||||
holidays = [
|
||||
# 春节(示例,不完整)
|
||||
'2024-02-10', '2024-02-11', '2024-02-12', '2024-02-13', '2024-02-14',
|
||||
'2024-02-15', '2024-02-16', '2024-02-17',
|
||||
# 国庆(示例,不完整)
|
||||
'2024-10-01', '2024-10-02', '2024-10-03', '2024-10-04',
|
||||
'2024-10-05', '2024-10-06', '2024-10-07',
|
||||
]
|
||||
calendar = calendar[~calendar.isin(pd.to_datetime(holidays))]
|
||||
|
||||
return calendar
|
||||
|
||||
elif market == 'US':
|
||||
# 美股:美国工作日
|
||||
start = pd.Timestamp('2020-01-01')
|
||||
end = pd.Timestamp('2025-12-31')
|
||||
return pd.bdate_range(start=start, end=end)
|
||||
|
||||
elif market == 'HK':
|
||||
# 港股:香港工作日
|
||||
start = pd.Timestamp('2020-01-01')
|
||||
end = pd.Timestamp('2025-12-31')
|
||||
return pd.bdate_range(start=start, end=end)
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的市场: {market}")
|
||||
|
||||
def get_benchmark(
|
||||
self,
|
||||
code: str = "000300.SH",
|
||||
start: str = "2020-01-01",
|
||||
end: str = "2025-12-31"
|
||||
) -> pd.Series:
|
||||
"""
|
||||
获取基准数据
|
||||
|
||||
Args:
|
||||
code: 基准代码(默认沪深 300)
|
||||
start: 开始日期
|
||||
end: 结束日期
|
||||
|
||||
Returns:
|
||||
基准收盘价 Series
|
||||
"""
|
||||
df = self._source.fetch(
|
||||
code=code,
|
||||
start_date=start,
|
||||
end_date=end,
|
||||
adj='raw'
|
||||
)
|
||||
|
||||
if df is None:
|
||||
raise ValueError(f"基准数据获取失败: {code}")
|
||||
|
||||
return df['close']
|
||||
|
||||
def get_health(self) -> Dict:
|
||||
"""
|
||||
检查 API 服务健康状态
|
||||
|
||||
Returns:
|
||||
健康状态字典
|
||||
"""
|
||||
return self._source.get_health()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"FlaskAPIFetcher(base_url={self._source.base_url})"
|
||||
198
framework_v2/tests/test_flask_api_fetcher.py
Normal file
198
framework_v2/tests/test_flask_api_fetcher.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
测试 FlaskAPIFetcher
|
||||
|
||||
验证:
|
||||
1. 获取指数数据
|
||||
2. 获取 ETF 数据
|
||||
3. 获取交易日历
|
||||
4. 健康检查
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
if str(project_root) not in sys.path:
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from framework_v2.shared.data import FlaskAPIFetcher
|
||||
|
||||
|
||||
def test_health_check():
|
||||
"""测试 1: 健康检查"""
|
||||
print("\n" + "=" * 60)
|
||||
print(" 测试 1: 健康检查")
|
||||
print("=" * 60)
|
||||
|
||||
fetcher = FlaskAPIFetcher()
|
||||
health = fetcher.get_health()
|
||||
|
||||
print(f"\n健康状态: {health}")
|
||||
|
||||
assert health.get('available'), "API 服务不可用"
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_fetch_indices():
|
||||
"""测试 2: 获取指数数据"""
|
||||
print("\n" + "=" * 60)
|
||||
print(" 测试 2: 获取指数数据")
|
||||
print("=" * 60)
|
||||
|
||||
fetcher = FlaskAPIFetcher()
|
||||
|
||||
# 获取沪深 300 + 中证 500
|
||||
codes = ["000300.SH", "000905.SH"]
|
||||
data = fetcher.fetch_indices(
|
||||
codes=codes,
|
||||
start="2024-01-01",
|
||||
end="2024-03-31"
|
||||
)
|
||||
|
||||
# 验证
|
||||
assert len(data) == 2, f"应该返回 2 只指数,实际 {len(data)}"
|
||||
|
||||
for code, df in data.items():
|
||||
print(f"\n{code}:")
|
||||
print(f" 数据量: {len(df)} 条")
|
||||
print(f" 列: {list(df.columns)}")
|
||||
print(f" 日期范围: {df.index[0]} ~ {df.index[-1]}")
|
||||
|
||||
assert len(df) > 0, f"{code} 数据为空"
|
||||
assert 'close' in df.columns, f"{code} 缺少 close 列"
|
||||
assert 'volume' in df.columns, f"{code} 缺少 volume 列"
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_fetch_etf():
|
||||
"""测试 3: 获取 ETF 数据"""
|
||||
print("\n" + "=" * 60)
|
||||
print(" 测试 3: 获取 ETF 数据")
|
||||
print("=" * 60)
|
||||
|
||||
fetcher = FlaskAPIFetcher()
|
||||
|
||||
# 获取沪深 300 ETF
|
||||
codes = ["510300.SH"]
|
||||
data = fetcher.fetch_etf(
|
||||
codes=codes,
|
||||
start="2024-01-01",
|
||||
end="2024-03-31"
|
||||
)
|
||||
|
||||
# 验证
|
||||
assert len(data) == 1, f"应该返回 1 只 ETF,实际 {len(data)}"
|
||||
|
||||
code = "510300.SH"
|
||||
df = data[code]
|
||||
|
||||
print(f"\n{code}:")
|
||||
print(f" 价格数据: {len(df)} 条")
|
||||
print(f" 列: {list(df.columns)}")
|
||||
|
||||
# 验证附加信息
|
||||
nav = df.attrs.get('nav')
|
||||
if nav is not None:
|
||||
print(f" 净值数据: {len(nav)} 条")
|
||||
|
||||
premium = df.attrs.get('latest_premium')
|
||||
if premium is not None:
|
||||
print(f" 最新溢价率: {premium:.2f}%")
|
||||
|
||||
assert len(df) > 0, f"{code} 数据为空"
|
||||
assert 'close' in df.columns, f"{code} 缺少 close 列"
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_trading_calendar():
|
||||
"""测试 4: 获取交易日历"""
|
||||
print("\n" + "=" * 60)
|
||||
print(" 测试 4: 获取交易日历")
|
||||
print("=" * 60)
|
||||
|
||||
fetcher = FlaskAPIFetcher()
|
||||
|
||||
# A股日历
|
||||
calendar_a = fetcher.get_trading_calendar(market='A')
|
||||
print(f"\nA股交易日历:")
|
||||
print(f" 总天数: {len(calendar_a)}")
|
||||
print(f" 日期范围: {calendar_a[0]} ~ {calendar_a[-1]}")
|
||||
print(f" 前 5 天: {calendar_a[:5].tolist()}")
|
||||
|
||||
assert len(calendar_a) > 0, "A股日历为空"
|
||||
|
||||
# 美股日历
|
||||
calendar_us = fetcher.get_trading_calendar(market='US')
|
||||
print(f"\n美股交易日历:")
|
||||
print(f" 总天数: {len(calendar_us)}")
|
||||
|
||||
assert len(calendar_us) > 0, "美股日历为空"
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_benchmark():
|
||||
"""测试 5: 获取基准数据"""
|
||||
print("\n" + "=" * 60)
|
||||
print(" 测试 5: 获取基准数据")
|
||||
print("=" * 60)
|
||||
|
||||
fetcher = FlaskAPIFetcher()
|
||||
|
||||
benchmark = fetcher.get_benchmark(
|
||||
code="000300.SH",
|
||||
start="2024-01-01",
|
||||
end="2024-03-31"
|
||||
)
|
||||
|
||||
print(f"\n沪深 300 基准:")
|
||||
print(f" 数据量: {len(benchmark)} 条")
|
||||
print(f" 日期范围: {benchmark.index[0]} ~ {benchmark.index[-1]}")
|
||||
print(f" 价格范围: {benchmark.min():.2f} ~ {benchmark.max():.2f}")
|
||||
|
||||
assert len(benchmark) > 0, "基准数据为空"
|
||||
assert isinstance(benchmark, pd.Series), "基准数据应该是 Series"
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pandas as pd
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(" FlaskAPIFetcher 测试")
|
||||
print("=" * 60)
|
||||
|
||||
tests = [
|
||||
("健康检查", test_health_check),
|
||||
("指数数据", test_fetch_indices),
|
||||
("ETF 数据", test_fetch_etf),
|
||||
("交易日历", test_trading_calendar),
|
||||
("基准数据", test_benchmark),
|
||||
]
|
||||
|
||||
passed = 0
|
||||
failed = 0
|
||||
|
||||
for name, test_func in tests:
|
||||
try:
|
||||
test_func()
|
||||
passed += 1
|
||||
except Exception as e:
|
||||
print(f"\n✗ 测试失败: {name}")
|
||||
print(f" 错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
failed += 1
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(" 测试总结")
|
||||
print("=" * 60)
|
||||
print(f" ✓ 通过 - {passed}")
|
||||
if failed > 0:
|
||||
print(f" ✗ 失败 - {failed}")
|
||||
print(f"\n总计: {passed}/{passed + failed} 通过")
|
||||
print("=" * 60 + "\n")
|
||||
Reference in New Issue
Block a user