From 40116f436f479d759ff80db5f5c4deab24c6d915 Mon Sep 17 00:00:00 2001
From: aszerW
Date: Sun, 24 May 2026 10:38:34 +0800
Subject: [PATCH] =?UTF-8?q?feat(framework=5Fv2):=20=E6=B7=BB=E5=8A=A0=20Fl?=
 =?UTF-8?q?askAPIFetcher=20=E6=95=B0=E6=8D=AE=E8=8E=B7=E5=8F=96=E5=99=A8?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## 核心功能
- FlaskAPIFetcher: 继承 DataFetcher 抽象基类
- fetch_indices(): 获取指数 OHLCV 数据
- fetch_etf(): 获取 ETF 数据(自动附加净值+溢价率)
- get_trading_calendar(): 获取交易日历
- get_benchmark(): 获取基准数据

## 技术实现
- 委托调用 FlaskAPIDataSource(HTTP API)
- 自动重试 3 次,超时 120 秒
- Pydantic Schema 验证响应
- 进度显示(批量获取)
- 无需本地 SSH 隧道配置

## 测试验证
- 5/5 测试通过(健康检查、指数、ETF、日历、基准)
- 成功获取线上数据(000300.SH, 510300.SH)
- ETF 自动附加净值(3695 条)和溢价率

## 架构设计
- shared/data/flask_api_fetcher.py - 实现(275 行)
- tests/test_flask_api_fetcher.py - 测试(199 行)
- 依赖倒置原则(策略依赖抽象接口)
---
Changes since the original patch:
- tests: import pandas/traceback at module level (test_benchmark raised
  NameError on pd when collected by pytest)
- fetch_etf: tolerate missing or None 'nav' / 'latest_premium' attrs
- get_trading_calendar: start/end parameters, single code path
- docstrings and comments translated to English

 framework_v2/shared/data/__init__.py          |   2 +
 framework_v2/shared/data/flask_api_fetcher.py | 275 ++++++++++++++++++
 framework_v2/tests/test_flask_api_fetcher.py  | 199 +++++++++++++
 3 files changed, 476 insertions(+)
 create mode 100644 framework_v2/shared/data/flask_api_fetcher.py
 create mode 100644 framework_v2/tests/test_flask_api_fetcher.py

diff --git a/framework_v2/shared/data/__init__.py b/framework_v2/shared/data/__init__.py
index 4625831..a68a2bd 100644
--- a/framework_v2/shared/data/__init__.py
+++ b/framework_v2/shared/data/__init__.py
@@ -9,6 +9,7 @@ from framework_v2.shared.data.schemas import (
     AlignedReturnsSchema,
     AlignmentValidationResult,
 )
+from framework_v2.shared.data.flask_api_fetcher import FlaskAPIFetcher
 
 __all__ = [
     'CrossMarketAligner',
@@ -16,4 +17,5 @@ __all__ = [
     'AlignedFactorSchema',
     'AlignedReturnsSchema',
     'AlignmentValidationResult',
+    'FlaskAPIFetcher',
 ]
diff --git a/framework_v2/shared/data/flask_api_fetcher.py b/framework_v2/shared/data/flask_api_fetcher.py
new file mode 100644
index 0000000..edc1ce2
--- /dev/null
+++ b/framework_v2/shared/data/flask_api_fetcher.py
@@ -0,0 +1,275 @@
+"""
+Flask API data fetcher (framework_v2 implementation).
+
+Subclasses the DataFetcher abstract base class and delegates to
+FlaskAPIDataSource to fetch index and ETF data over HTTP.
+"""
+
+import sys
+from pathlib import Path
+from typing import Dict, List, Optional
+
+import pandas as pd
+
+# Put the repository root on sys.path so the absolute imports below resolve
+project_root = Path(__file__).parent.parent.parent.parent
+if str(project_root) not in sys.path:
+    sys.path.insert(0, str(project_root))
+
+from framework_v2.core.data import DataFetcher
+from datasource.flask_api_source import FlaskAPIDataSource
+
+
+class FlaskAPIFetcher(DataFetcher):
+    """
+    Data fetcher backed by the Flask HTTP API.
+
+    Fetches index and ETF data through FlaskAPIDataSource; per the original
+    author's note, no local SSH tunnel configuration is required.
+
+    Usage:
+        fetcher = FlaskAPIFetcher(base_url="https://k3s.tokenpluse.xyz")
+        data = fetcher.fetch_indices(["000300.SH"], "2024-01-01", "2024-12-31")
+    """
+
+    name = "flask_api"
+
+    def __init__(
+        self,
+        base_url: Optional[str] = None,
+        timeout: int = 120,
+        retries: int = 3
+    ):
+        """
+        Create the fetcher and its underlying FlaskAPIDataSource.
+
+        Args:
+            base_url: API service address. None is forwarded unchanged;
+                presumably the data source then falls back to an
+                environment variable (TODO confirm in FlaskAPIDataSource).
+            timeout: Request timeout in seconds.
+            retries: Number of retries.
+        """
+        super().__init__(base_url=base_url, timeout=timeout, retries=retries)
+
+        # Underlying data source that performs the actual fetches
+        self._source = FlaskAPIDataSource(
+            base_url=base_url,
+            timeout=timeout,
+            retries=retries
+        )
+
+    def fetch_indices(
+        self,
+        codes: List[str],
+        start: str,
+        end: str
+    ) -> Dict[str, pd.DataFrame]:
+        """
+        Fetch OHLCV data for a list of indices.
+
+        Codes whose fetch returns None are reported and left out of the result.
+
+        Args:
+            codes: Index codes, e.g. ["000300.SH", "000905.SH"].
+            start: Start date (YYYY-MM-DD).
+            end: End date (YYYY-MM-DD).
+
+        Returns:
+            Dict mapping each successfully fetched code to the DataFrame
+            returned by the data source (expected to hold OHLCV columns).
+
+        Example:
+            >>> fetcher = FlaskAPIFetcher()
+            >>> data = fetcher.fetch_indices(
+            ...     ["000300.SH", "000905.SH"],
+            ...     "2024-01-01",
+            ...     "2024-12-31"
+            ... )
+            >>> print(data["000300.SH"].head())
+        """
+        print(f"\n[FlaskAPI] 获取 {len(codes)} 只指数数据...")
+
+        results = {}
+        for i, code in enumerate(codes, 1):
+            print(f" [{i}/{len(codes)}] {code}...")
+
+            df = self._source.fetch(
+                code=code,
+                start_date=start,
+                end_date=end,
+                adj='raw'  # indices normally use unadjusted prices
+            )
+
+            if df is not None:
+                results[code] = df
+                print(f" ✓ {len(df)} 条数据")
+            else:
+                print(" ✗ 获取失败")
+
+        success = len(results)
+        print(f"\n[FlaskAPI] 指数数据获取完成: {success}/{len(codes)} 成功")
+
+        return results
+
+    def fetch_etf(
+        self,
+        codes: List[str],
+        start: str,
+        end: str
+    ) -> Dict[str, pd.DataFrame]:
+        """
+        Fetch ETF price data (prices plus NAV information).
+
+        Codes whose fetch returns None are reported and left out of the result.
+
+        Args:
+            codes: ETF codes, e.g. ["510300.SH", "159919.SZ"].
+            start: Start date (YYYY-MM-DD).
+            end: End date (YYYY-MM-DD).
+
+        Returns:
+            Dict mapping each successfully fetched code to its DataFrame.
+            Per the original notes (not verifiable from this file) the
+            data source stores NAV data in df.attrs['nav'] and the
+            premium-rate series in df.attrs['premium_series'].
+
+        Example:
+            >>> fetcher = FlaskAPIFetcher()
+            >>> data = fetcher.fetch_etf(
+            ...     ["510300.SH", "159919.SZ"],
+            ...     "2024-01-01",
+            ...     "2024-12-31"
+            ... )
+            >>> # Access the NAV data
+            >>> nav = data["510300.SH"].attrs.get('nav')
+        """
+        print(f"\n[FlaskAPI] 获取 {len(codes)} 只 ETF 数据...")
+
+        results = {}
+        for i, code in enumerate(codes, 1):
+            print(f" [{i}/{len(codes)}] {code}...")
+
+            df = self._source.fetch(
+                code=code,
+                start_date=start,
+                end_date=end,
+                adj='raw',
+                asset_type='china_etf'  # force the ETF asset type
+            )
+
+            if df is not None:
+                results[code] = df
+
+                # The attrs may be absent or hold None: guard before len()
+                # and do not print a bogus "N/A%" / "None%"
+                nav = df.attrs.get('nav')
+                nav_count = len(nav) if nav is not None else 0
+                premium = df.attrs.get('latest_premium')
+                premium_text = "N/A" if premium is None else f"{premium}%"
+
+                print(f" ✓ {len(df)} 条价格, {nav_count} 条净值, 溢价率: {premium_text}")
+            else:
+                print(" ✗ 获取失败")
+
+        success = len(results)
+        print(f"\n[FlaskAPI] ETF 数据获取完成: {success}/{len(codes)} 成功")
+
+        return results
+
+    def get_trading_calendar(
+        self,
+        market: str = 'A',
+        start: str = '2020-01-01',
+        end: str = '2025-12-31'
+    ) -> pd.Index:
+        """
+        Build an approximate trading calendar.
+
+        The Flask API does not expose a calendar endpoint yet, so this is a
+        plain pandas business-day range (Mon-Fri). For market 'A' a small,
+        incomplete sample of 2024 Spring Festival / National Day dates is
+        removed as well; all other exchange holidays remain in the result.
+
+        TODO: fetch the exact calendar from an API endpoint.
+
+        Args:
+            market: Market code: 'A', 'US' or 'HK'.
+            start: First date of the range (inclusive). The default keeps
+                the window that used to be hard-coded.
+            end: Last date of the range (inclusive). The default keeps the
+                window that used to be hard-coded; pass a later date when
+                the calendar must cover dates after 2025.
+
+        Returns:
+            Index of approximate trading days.
+
+        Raises:
+            ValueError: If ``market`` is not supported.
+        """
+        if market not in ('A', 'US', 'HK'):
+            raise ValueError(f"不支持的市场: {market}")
+
+        # Temporary implementation shared by all markets: weekdays only
+        calendar = pd.bdate_range(
+            start=pd.Timestamp(start),
+            end=pd.Timestamp(end)
+        )
+
+        if market == 'A':
+            # Drop major Chinese holidays (sample only, incomplete)
+            holidays = [
+                # Spring Festival
+                '2024-02-10', '2024-02-11', '2024-02-12', '2024-02-13', '2024-02-14',
+                '2024-02-15', '2024-02-16', '2024-02-17',
+                # National Day
+                '2024-10-01', '2024-10-02', '2024-10-03', '2024-10-04',
+                '2024-10-05', '2024-10-06', '2024-10-07',
+            ]
+            calendar = calendar[~calendar.isin(pd.to_datetime(holidays))]
+
+        return calendar
+
+    def get_benchmark(
+        self,
+        code: str = "000300.SH",
+        start: str = "2020-01-01",
+        end: str = "2025-12-31"
+    ) -> pd.Series:
+        """
+        Fetch the close-price series of a benchmark.
+
+        Args:
+            code: Benchmark code (default 000300.SH, the CSI 300).
+            start: Start date.
+            end: End date.
+
+        Returns:
+            The 'close' column of the fetched data.
+
+        Raises:
+            ValueError: If the data source returns None for ``code``.
+        """
+        df = self._source.fetch(
+            code=code,
+            start_date=start,
+            end_date=end,
+            adj='raw'
+        )
+
+        if df is None:
+            raise ValueError(f"基准数据获取失败: {code}")
+
+        return df['close']
+
+    def get_health(self) -> Dict:
+        """
+        Check the health of the API service.
+
+        Returns:
+            Health-status dict as returned by the data source.
+        """
+        return self._source.get_health()
+
+    def __repr__(self) -> str:
+        return f"FlaskAPIFetcher(base_url={self._source.base_url})"
diff --git a/framework_v2/tests/test_flask_api_fetcher.py b/framework_v2/tests/test_flask_api_fetcher.py
new file mode 100644
index 0000000..aa6dabd
--- /dev/null
+++ b/framework_v2/tests/test_flask_api_fetcher.py
@@ -0,0 +1,199 @@
+"""
+Tests for FlaskAPIFetcher.
+
+Covers:
+1. Fetching index data
+2. Fetching ETF data
+3. Getting the trading calendar
+4. Health check
+5. Fetching benchmark data
+"""
+
+import sys
+import traceback
+from pathlib import Path
+
+import pandas as pd
+
+# Put the repository root on sys.path so framework_v2 is importable
+project_root = Path(__file__).parent.parent.parent
+if str(project_root) not in sys.path:
+    sys.path.insert(0, str(project_root))
+
+from framework_v2.shared.data import FlaskAPIFetcher
+
+
+def test_health_check():
+    """Test 1: health check."""
+    print("\n" + "=" * 60)
+    print(" 测试 1: 健康检查")
+    print("=" * 60)
+
+    fetcher = FlaskAPIFetcher()
+    health = fetcher.get_health()
+
+    print(f"\n健康状态: {health}")
+
+    assert health.get('available'), "API 服务不可用"
+    print("\n✓ 测试通过")
+
+
+def test_fetch_indices():
+    """Test 2: fetch index data."""
+    print("\n" + "=" * 60)
+    print(" 测试 2: 获取指数数据")
+    print("=" * 60)
+
+    fetcher = FlaskAPIFetcher()
+
+    # CSI 300 + CSI 500
+    codes = ["000300.SH", "000905.SH"]
+    data = fetcher.fetch_indices(
+        codes=codes,
+        start="2024-01-01",
+        end="2024-03-31"
+    )
+
+    # Verify
+    assert len(data) == 2, f"应该返回 2 只指数,实际 {len(data)}"
+
+    for code, df in data.items():
+        print(f"\n{code}:")
+        print(f" 数据量: {len(df)} 条")
+        print(f" 列: {list(df.columns)}")
+        print(f" 日期范围: {df.index[0]} ~ {df.index[-1]}")
+
+        assert len(df) > 0, f"{code} 数据为空"
+        assert 'close' in df.columns, f"{code} 缺少 close 列"
+        assert 'volume' in df.columns, f"{code} 缺少 volume 列"
+
+    print("\n✓ 测试通过")
+
+
+def test_fetch_etf():
+    """Test 3: fetch ETF data."""
+    print("\n" + "=" * 60)
+    print(" 测试 3: 获取 ETF 数据")
+    print("=" * 60)
+
+    fetcher = FlaskAPIFetcher()
+
+    # CSI 300 ETF
+    codes = ["510300.SH"]
+    data = fetcher.fetch_etf(
+        codes=codes,
+        start="2024-01-01",
+        end="2024-03-31"
+    )
+
+    # Verify
+    assert len(data) == 1, f"应该返回 1 只 ETF,实际 {len(data)}"
+
+    code = "510300.SH"
+    df = data[code]
+
+    print(f"\n{code}:")
+    print(f" 价格数据: {len(df)} 条")
+    print(f" 列: {list(df.columns)}")
+
+    # The attached extras are optional, so only report them when present
+    nav = df.attrs.get('nav')
+    if nav is not None:
+        print(f" 净值数据: {len(nav)} 条")
+
+    premium = df.attrs.get('latest_premium')
+    if premium is not None:
+        print(f" 最新溢价率: {premium:.2f}%")
+
+    assert len(df) > 0, f"{code} 数据为空"
+    assert 'close' in df.columns, f"{code} 缺少 close 列"
+
+    print("\n✓ 测试通过")
+
+
+def test_trading_calendar():
+    """Test 4: get the trading calendar."""
+    print("\n" + "=" * 60)
+    print(" 测试 4: 获取交易日历")
+    print("=" * 60)
+
+    fetcher = FlaskAPIFetcher()
+
+    # A-share calendar
+    calendar_a = fetcher.get_trading_calendar(market='A')
+    print("\nA股交易日历:")
+    print(f" 总天数: {len(calendar_a)}")
+    print(f" 日期范围: {calendar_a[0]} ~ {calendar_a[-1]}")
+    print(f" 前 5 天: {calendar_a[:5].tolist()}")
+
+    assert len(calendar_a) > 0, "A股日历为空"
+
+    # US calendar
+    calendar_us = fetcher.get_trading_calendar(market='US')
+    print("\n美股交易日历:")
+    print(f" 总天数: {len(calendar_us)}")
+
+    assert len(calendar_us) > 0, "美股日历为空"
+
+    print("\n✓ 测试通过")
+
+
+def test_benchmark():
+    """Test 5: fetch benchmark data."""
+    print("\n" + "=" * 60)
+    print(" 测试 5: 获取基准数据")
+    print("=" * 60)
+
+    fetcher = FlaskAPIFetcher()
+
+    benchmark = fetcher.get_benchmark(
+        code="000300.SH",
+        start="2024-01-01",
+        end="2024-03-31"
+    )
+
+    print("\n沪深 300 基准:")
+    print(f" 数据量: {len(benchmark)} 条")
+    print(f" 日期范围: {benchmark.index[0]} ~ {benchmark.index[-1]}")
+    print(f" 价格范围: {benchmark.min():.2f} ~ {benchmark.max():.2f}")
+
+    assert len(benchmark) > 0, "基准数据为空"
+    assert isinstance(benchmark, pd.Series), "基准数据应该是 Series"
+
+    print("\n✓ 测试通过")
+
+
+if __name__ == "__main__":
+    print("\n" + "=" * 60)
+    print(" FlaskAPIFetcher 测试")
+    print("=" * 60)
+
+    tests = [
+        ("健康检查", test_health_check),
+        ("指数数据", test_fetch_indices),
+        ("ETF 数据", test_fetch_etf),
+        ("交易日历", test_trading_calendar),
+        ("基准数据", test_benchmark),
+    ]
+
+    passed = 0
+    failed = 0
+
+    for name, test_func in tests:
+        try:
+            test_func()
+            passed += 1
+        except Exception as e:
+            print(f"\n✗ 测试失败: {name}")
+            print(f" 错误: {e}")
+            traceback.print_exc()
+            failed += 1
+
+    print("\n" + "=" * 60)
+    print(" 测试总结")
+    print("=" * 60)
+    print(f" ✓ 通过 - {passed}")
+    if failed > 0:
+        print(f" ✗ 失败 - {failed}")
+    print(f"\n总计: {passed}/{passed + failed} 通过")
+    print("=" * 60 + "\n")