Archive legacy framework and utility modules that are no longer referenced by the active core (datasource/ and rotation/): - framework/ -> archive/framework/ - framework_v2/ -> archive/framework_v2/ - strategies/ -> archive/strategies/ - config/ -> archive/config/ - visualization/ -> archive/visualization/ - scripts/ -> archive/scripts/ - tests/ -> archive/tests/ - run_rotation.py, run_us_rotation.py -> archive/single_files/ - compare_*.py, test_api_dates.py -> archive/single_files/
199 lines
5.1 KiB
Python
199 lines
5.1 KiB
Python
"""
|
||
测试 FlaskAPIFetcher
|
||
|
||
验证:
|
||
1. 获取指数数据
|
||
2. 获取 ETF 数据
|
||
3. 获取交易日历
|
||
4. 健康检查
|
||
"""
|
||
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
# 添加项目根目录到路径
|
||
project_root = Path(__file__).parent.parent.parent
|
||
if str(project_root) not in sys.path:
|
||
sys.path.insert(0, str(project_root))
|
||
|
||
from framework_v2.shared.data import FlaskAPIFetcher
|
||
|
||
|
||
def test_health_check():
|
||
"""测试 1: 健康检查"""
|
||
print("\n" + "=" * 60)
|
||
print(" 测试 1: 健康检查")
|
||
print("=" * 60)
|
||
|
||
fetcher = FlaskAPIFetcher()
|
||
health = fetcher.get_health()
|
||
|
||
print(f"\n健康状态: {health}")
|
||
|
||
assert health.get('available'), "API 服务不可用"
|
||
print("\n✓ 测试通过")
|
||
|
||
|
||
def test_fetch_indices():
|
||
"""测试 2: 获取指数数据"""
|
||
print("\n" + "=" * 60)
|
||
print(" 测试 2: 获取指数数据")
|
||
print("=" * 60)
|
||
|
||
fetcher = FlaskAPIFetcher()
|
||
|
||
# 获取沪深 300 + 中证 500
|
||
codes = ["000300.SH", "000905.SH"]
|
||
data = fetcher.fetch_indices(
|
||
codes=codes,
|
||
start="2024-01-01",
|
||
end="2024-03-31"
|
||
)
|
||
|
||
# 验证
|
||
assert len(data) == 2, f"应该返回 2 只指数,实际 {len(data)}"
|
||
|
||
for code, df in data.items():
|
||
print(f"\n{code}:")
|
||
print(f" 数据量: {len(df)} 条")
|
||
print(f" 列: {list(df.columns)}")
|
||
print(f" 日期范围: {df.index[0]} ~ {df.index[-1]}")
|
||
|
||
assert len(df) > 0, f"{code} 数据为空"
|
||
assert 'close' in df.columns, f"{code} 缺少 close 列"
|
||
assert 'volume' in df.columns, f"{code} 缺少 volume 列"
|
||
|
||
print("\n✓ 测试通过")
|
||
|
||
|
||
def test_fetch_etf():
|
||
"""测试 3: 获取 ETF 数据"""
|
||
print("\n" + "=" * 60)
|
||
print(" 测试 3: 获取 ETF 数据")
|
||
print("=" * 60)
|
||
|
||
fetcher = FlaskAPIFetcher()
|
||
|
||
# 获取沪深 300 ETF
|
||
codes = ["510300.SH"]
|
||
data = fetcher.fetch_etf(
|
||
codes=codes,
|
||
start="2024-01-01",
|
||
end="2024-03-31"
|
||
)
|
||
|
||
# 验证
|
||
assert len(data) == 1, f"应该返回 1 只 ETF,实际 {len(data)}"
|
||
|
||
code = "510300.SH"
|
||
df = data[code]
|
||
|
||
print(f"\n{code}:")
|
||
print(f" 价格数据: {len(df)} 条")
|
||
print(f" 列: {list(df.columns)}")
|
||
|
||
# 验证附加信息
|
||
nav = df.attrs.get('nav')
|
||
if nav is not None:
|
||
print(f" 净值数据: {len(nav)} 条")
|
||
|
||
premium = df.attrs.get('latest_premium')
|
||
if premium is not None:
|
||
print(f" 最新溢价率: {premium:.2f}%")
|
||
|
||
assert len(df) > 0, f"{code} 数据为空"
|
||
assert 'close' in df.columns, f"{code} 缺少 close 列"
|
||
|
||
print("\n✓ 测试通过")
|
||
|
||
|
||
def test_trading_calendar():
|
||
"""测试 4: 获取交易日历"""
|
||
print("\n" + "=" * 60)
|
||
print(" 测试 4: 获取交易日历")
|
||
print("=" * 60)
|
||
|
||
fetcher = FlaskAPIFetcher()
|
||
|
||
# A股日历
|
||
calendar_a = fetcher.get_trading_calendar(market='A')
|
||
print(f"\nA股交易日历:")
|
||
print(f" 总天数: {len(calendar_a)}")
|
||
print(f" 日期范围: {calendar_a[0]} ~ {calendar_a[-1]}")
|
||
print(f" 前 5 天: {calendar_a[:5].tolist()}")
|
||
|
||
assert len(calendar_a) > 0, "A股日历为空"
|
||
|
||
# 美股日历
|
||
calendar_us = fetcher.get_trading_calendar(market='US')
|
||
print(f"\n美股交易日历:")
|
||
print(f" 总天数: {len(calendar_us)}")
|
||
|
||
assert len(calendar_us) > 0, "美股日历为空"
|
||
|
||
print("\n✓ 测试通过")
|
||
|
||
|
||
def test_benchmark():
|
||
"""测试 5: 获取基准数据"""
|
||
print("\n" + "=" * 60)
|
||
print(" 测试 5: 获取基准数据")
|
||
print("=" * 60)
|
||
|
||
fetcher = FlaskAPIFetcher()
|
||
|
||
benchmark = fetcher.get_benchmark(
|
||
code="000300.SH",
|
||
start="2024-01-01",
|
||
end="2024-03-31"
|
||
)
|
||
|
||
print(f"\n沪深 300 基准:")
|
||
print(f" 数据量: {len(benchmark)} 条")
|
||
print(f" 日期范围: {benchmark.index[0]} ~ {benchmark.index[-1]}")
|
||
print(f" 价格范围: {benchmark.min():.2f} ~ {benchmark.max():.2f}")
|
||
|
||
assert len(benchmark) > 0, "基准数据为空"
|
||
assert isinstance(benchmark, pd.Series), "基准数据应该是 Series"
|
||
|
||
print("\n✓ 测试通过")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
import pandas as pd
|
||
|
||
print("\n" + "=" * 60)
|
||
print(" FlaskAPIFetcher 测试")
|
||
print("=" * 60)
|
||
|
||
tests = [
|
||
("健康检查", test_health_check),
|
||
("指数数据", test_fetch_indices),
|
||
("ETF 数据", test_fetch_etf),
|
||
("交易日历", test_trading_calendar),
|
||
("基准数据", test_benchmark),
|
||
]
|
||
|
||
passed = 0
|
||
failed = 0
|
||
|
||
for name, test_func in tests:
|
||
try:
|
||
test_func()
|
||
passed += 1
|
||
except Exception as e:
|
||
print(f"\n✗ 测试失败: {name}")
|
||
print(f" 错误: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
failed += 1
|
||
|
||
print("\n" + "=" * 60)
|
||
print(" 测试总结")
|
||
print("=" * 60)
|
||
print(f" ✓ 通过 - {passed}")
|
||
if failed > 0:
|
||
print(f" ✗ 失败 - {failed}")
|
||
print(f"\n总计: {passed}/{passed + failed} 通过")
|
||
print("=" * 60 + "\n")
|