Files
etf/archive/framework_v2/tests/test_flask_api_fetcher.py
aszerW c905230a40 refactor(archive): move unused modules to archive/
Archive legacy framework and utility modules that are no longer
referenced by the active core (datasource/ and rotation/):

- framework/ -> archive/framework/
- framework_v2/ -> archive/framework_v2/
- strategies/ -> archive/strategies/
- config/ -> archive/config/
- visualization/ -> archive/visualization/
- scripts/ -> archive/scripts/
- tests/ -> archive/tests/
- run_rotation.py, run_us_rotation.py -> archive/single_files/
- compare_*.py, test_api_dates.py -> archive/single_files/
2026-06-03 23:41:46 +08:00

199 lines
5.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
测试 FlaskAPIFetcher
验证:
1. 获取指数数据
2. 获取 ETF 数据
3. 获取交易日历
4. 健康检查
"""
import sys
from pathlib import Path
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent.parent
if str(project_root) not in sys.path:
sys.path.insert(0, str(project_root))
from framework_v2.shared.data import FlaskAPIFetcher
def test_health_check():
"""测试 1: 健康检查"""
print("\n" + "=" * 60)
print(" 测试 1: 健康检查")
print("=" * 60)
fetcher = FlaskAPIFetcher()
health = fetcher.get_health()
print(f"\n健康状态: {health}")
assert health.get('available'), "API 服务不可用"
print("\n✓ 测试通过")
def test_fetch_indices():
"""测试 2: 获取指数数据"""
print("\n" + "=" * 60)
print(" 测试 2: 获取指数数据")
print("=" * 60)
fetcher = FlaskAPIFetcher()
# 获取沪深 300 + 中证 500
codes = ["000300.SH", "000905.SH"]
data = fetcher.fetch_indices(
codes=codes,
start="2024-01-01",
end="2024-03-31"
)
# 验证
assert len(data) == 2, f"应该返回 2 只指数,实际 {len(data)}"
for code, df in data.items():
print(f"\n{code}:")
print(f" 数据量: {len(df)}")
print(f" 列: {list(df.columns)}")
print(f" 日期范围: {df.index[0]} ~ {df.index[-1]}")
assert len(df) > 0, f"{code} 数据为空"
assert 'close' in df.columns, f"{code} 缺少 close 列"
assert 'volume' in df.columns, f"{code} 缺少 volume 列"
print("\n✓ 测试通过")
def test_fetch_etf():
"""测试 3: 获取 ETF 数据"""
print("\n" + "=" * 60)
print(" 测试 3: 获取 ETF 数据")
print("=" * 60)
fetcher = FlaskAPIFetcher()
# 获取沪深 300 ETF
codes = ["510300.SH"]
data = fetcher.fetch_etf(
codes=codes,
start="2024-01-01",
end="2024-03-31"
)
# 验证
assert len(data) == 1, f"应该返回 1 只 ETF实际 {len(data)}"
code = "510300.SH"
df = data[code]
print(f"\n{code}:")
print(f" 价格数据: {len(df)}")
print(f" 列: {list(df.columns)}")
# 验证附加信息
nav = df.attrs.get('nav')
if nav is not None:
print(f" 净值数据: {len(nav)}")
premium = df.attrs.get('latest_premium')
if premium is not None:
print(f" 最新溢价率: {premium:.2f}%")
assert len(df) > 0, f"{code} 数据为空"
assert 'close' in df.columns, f"{code} 缺少 close 列"
print("\n✓ 测试通过")
def test_trading_calendar():
"""测试 4: 获取交易日历"""
print("\n" + "=" * 60)
print(" 测试 4: 获取交易日历")
print("=" * 60)
fetcher = FlaskAPIFetcher()
# A股日历
calendar_a = fetcher.get_trading_calendar(market='A')
print(f"\nA股交易日历:")
print(f" 总天数: {len(calendar_a)}")
print(f" 日期范围: {calendar_a[0]} ~ {calendar_a[-1]}")
print(f" 前 5 天: {calendar_a[:5].tolist()}")
assert len(calendar_a) > 0, "A股日历为空"
# 美股日历
calendar_us = fetcher.get_trading_calendar(market='US')
print(f"\n美股交易日历:")
print(f" 总天数: {len(calendar_us)}")
assert len(calendar_us) > 0, "美股日历为空"
print("\n✓ 测试通过")
def test_benchmark():
"""测试 5: 获取基准数据"""
print("\n" + "=" * 60)
print(" 测试 5: 获取基准数据")
print("=" * 60)
fetcher = FlaskAPIFetcher()
benchmark = fetcher.get_benchmark(
code="000300.SH",
start="2024-01-01",
end="2024-03-31"
)
print(f"\n沪深 300 基准:")
print(f" 数据量: {len(benchmark)}")
print(f" 日期范围: {benchmark.index[0]} ~ {benchmark.index[-1]}")
print(f" 价格范围: {benchmark.min():.2f} ~ {benchmark.max():.2f}")
assert len(benchmark) > 0, "基准数据为空"
assert isinstance(benchmark, pd.Series), "基准数据应该是 Series"
print("\n✓ 测试通过")
if __name__ == "__main__":
import pandas as pd
print("\n" + "=" * 60)
print(" FlaskAPIFetcher 测试")
print("=" * 60)
tests = [
("健康检查", test_health_check),
("指数数据", test_fetch_indices),
("ETF 数据", test_fetch_etf),
("交易日历", test_trading_calendar),
("基准数据", test_benchmark),
]
passed = 0
failed = 0
for name, test_func in tests:
try:
test_func()
passed += 1
except Exception as e:
print(f"\n✗ 测试失败: {name}")
print(f" 错误: {e}")
import traceback
traceback.print_exc()
failed += 1
print("\n" + "=" * 60)
print(" 测试总结")
print("=" * 60)
print(f" ✓ 通过 - {passed}")
if failed > 0:
print(f" ✗ 失败 - {failed}")
print(f"\n总计: {passed}/{passed + failed} 通过")
print("=" * 60 + "\n")