feat(framework_v2): 添加 FlaskAPIFetcher 数据获取器

## 核心功能
- FlaskAPIFetcher: 继承 DataFetcher 抽象基类
- fetch_indices(): 获取指数 OHLCV 数据
- fetch_etf(): 获取 ETF 数据(自动附加净值+溢价率)
- get_trading_calendar(): 获取交易日历
- get_benchmark(): 获取基准数据

## 技术实现
- 委托调用 FlaskAPIDataSource(HTTP API)
- 自动重试 3 次,超时 120 秒
- Pydantic Schema 验证响应
- 进度显示(批量获取)
- 无需本地 SSH 隧道配置

## 测试验证
- 5/5 测试通过(健康检查、指数、ETF、日历、基准)
- 成功获取线上数据(000300.SH, 510300.SH)
- ETF 自动附加净值(3695 条)和溢价率

## 架构设计
- shared/data/flask_api_fetcher.py - 实现(262 行)
- tests/test_flask_api_fetcher.py - 测试(199 行)
- 依赖倒置原则(策略依赖抽象接口)
This commit is contained in:
2026-05-24 10:38:34 +08:00
parent 5f08e508ac
commit 40116f436f
3 changed files with 461 additions and 0 deletions

View File

@@ -0,0 +1,198 @@
"""
测试 FlaskAPIFetcher
验证:
1. 获取指数数据
2. 获取 ETF 数据
3. 获取交易日历
4. 健康检查
"""
import sys
from pathlib import Path
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent.parent
if str(project_root) not in sys.path:
sys.path.insert(0, str(project_root))
from framework_v2.shared.data import FlaskAPIFetcher
def test_health_check():
"""测试 1: 健康检查"""
print("\n" + "=" * 60)
print(" 测试 1: 健康检查")
print("=" * 60)
fetcher = FlaskAPIFetcher()
health = fetcher.get_health()
print(f"\n健康状态: {health}")
assert health.get('available'), "API 服务不可用"
print("\n✓ 测试通过")
def test_fetch_indices():
"""测试 2: 获取指数数据"""
print("\n" + "=" * 60)
print(" 测试 2: 获取指数数据")
print("=" * 60)
fetcher = FlaskAPIFetcher()
# 获取沪深 300 + 中证 500
codes = ["000300.SH", "000905.SH"]
data = fetcher.fetch_indices(
codes=codes,
start="2024-01-01",
end="2024-03-31"
)
# 验证
assert len(data) == 2, f"应该返回 2 只指数,实际 {len(data)}"
for code, df in data.items():
print(f"\n{code}:")
print(f" 数据量: {len(df)}")
print(f" 列: {list(df.columns)}")
print(f" 日期范围: {df.index[0]} ~ {df.index[-1]}")
assert len(df) > 0, f"{code} 数据为空"
assert 'close' in df.columns, f"{code} 缺少 close 列"
assert 'volume' in df.columns, f"{code} 缺少 volume 列"
print("\n✓ 测试通过")
def test_fetch_etf():
"""测试 3: 获取 ETF 数据"""
print("\n" + "=" * 60)
print(" 测试 3: 获取 ETF 数据")
print("=" * 60)
fetcher = FlaskAPIFetcher()
# 获取沪深 300 ETF
codes = ["510300.SH"]
data = fetcher.fetch_etf(
codes=codes,
start="2024-01-01",
end="2024-03-31"
)
# 验证
assert len(data) == 1, f"应该返回 1 只 ETF实际 {len(data)}"
code = "510300.SH"
df = data[code]
print(f"\n{code}:")
print(f" 价格数据: {len(df)}")
print(f" 列: {list(df.columns)}")
# 验证附加信息
nav = df.attrs.get('nav')
if nav is not None:
print(f" 净值数据: {len(nav)}")
premium = df.attrs.get('latest_premium')
if premium is not None:
print(f" 最新溢价率: {premium:.2f}%")
assert len(df) > 0, f"{code} 数据为空"
assert 'close' in df.columns, f"{code} 缺少 close 列"
print("\n✓ 测试通过")
def test_trading_calendar():
"""测试 4: 获取交易日历"""
print("\n" + "=" * 60)
print(" 测试 4: 获取交易日历")
print("=" * 60)
fetcher = FlaskAPIFetcher()
# A股日历
calendar_a = fetcher.get_trading_calendar(market='A')
print(f"\nA股交易日历:")
print(f" 总天数: {len(calendar_a)}")
print(f" 日期范围: {calendar_a[0]} ~ {calendar_a[-1]}")
print(f" 前 5 天: {calendar_a[:5].tolist()}")
assert len(calendar_a) > 0, "A股日历为空"
# 美股日历
calendar_us = fetcher.get_trading_calendar(market='US')
print(f"\n美股交易日历:")
print(f" 总天数: {len(calendar_us)}")
assert len(calendar_us) > 0, "美股日历为空"
print("\n✓ 测试通过")
def test_benchmark():
"""测试 5: 获取基准数据"""
print("\n" + "=" * 60)
print(" 测试 5: 获取基准数据")
print("=" * 60)
fetcher = FlaskAPIFetcher()
benchmark = fetcher.get_benchmark(
code="000300.SH",
start="2024-01-01",
end="2024-03-31"
)
print(f"\n沪深 300 基准:")
print(f" 数据量: {len(benchmark)}")
print(f" 日期范围: {benchmark.index[0]} ~ {benchmark.index[-1]}")
print(f" 价格范围: {benchmark.min():.2f} ~ {benchmark.max():.2f}")
assert len(benchmark) > 0, "基准数据为空"
assert isinstance(benchmark, pd.Series), "基准数据应该是 Series"
print("\n✓ 测试通过")
if __name__ == "__main__":
import pandas as pd
print("\n" + "=" * 60)
print(" FlaskAPIFetcher 测试")
print("=" * 60)
tests = [
("健康检查", test_health_check),
("指数数据", test_fetch_indices),
("ETF 数据", test_fetch_etf),
("交易日历", test_trading_calendar),
("基准数据", test_benchmark),
]
passed = 0
failed = 0
for name, test_func in tests:
try:
test_func()
passed += 1
except Exception as e:
print(f"\n✗ 测试失败: {name}")
print(f" 错误: {e}")
import traceback
traceback.print_exc()
failed += 1
print("\n" + "=" * 60)
print(" 测试总结")
print("=" * 60)
print(f" ✓ 通过 - {passed}")
if failed > 0:
print(f" ✗ 失败 - {failed}")
print(f"\n总计: {passed}/{passed + failed} 通过")
print("=" * 60 + "\n")