test: 添加 UniversalDataFetcher 和 Flask API 测试

- 新增 test_universal_fetcher.py:资产类型检测、单只/批量获取、边界测试
- 新增 test_ssh_tunnel.py:SSH 隧道连接、港美股数据获取、混合市场测试
- 新增 test_flask_api.py:API 端点测试、健康检查、批量获取测试
- 所有测试已通过验证(A股100%,SSH隧道港美股正常)
This commit is contained in:
2026-05-07 21:19:37 +08:00
parent 8b2c2be6f3
commit fbaa3f9d73
3 changed files with 585 additions and 0 deletions

140
tests/test_flask_api.py Normal file
View File

@@ -0,0 +1,140 @@
"""
Flask API 快速测试
==================
测试 Flask 服务的基本功能
"""
import sys
from pathlib import Path
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from dotenv import load_dotenv
load_dotenv()
import requests
import json
def test_api():
"""测试 API 功能"""
BASE_URL = "http://localhost:5000"
print("\n" + "="*60)
print("Flask API 测试")
print("="*60)
# 测试1: 健康检查
print("\n1. 测试健康检查...")
try:
response = requests.get(f"{BASE_URL}/health", timeout=5)
print(f" 状态码: {response.status_code}")
print(f" 响应: {json.dumps(response.json(), indent=2, ensure_ascii=False)[:200]}")
except Exception as e:
print(f" ✗ 失败: {e}")
# 测试2: 首页
print("\n2. 测试首页...")
try:
response = requests.get(f"{BASE_URL}/", timeout=5)
print(f" 状态码: {response.status_code}")
data = response.json()
print(f" API名称: {data.get('name')}")
print(f" 版本: {data.get('version')}")
print(f" SSH状态: {data.get('ssh_status')}")
except Exception as e:
print(f" ✗ 失败: {e}")
# 测试3: 资产类型检测
print("\n3. 测试资产类型检测...")
codes = ["000300.SH", "NDX", "BTC"]
for code in codes:
try:
response = requests.get(
f"{BASE_URL}/api/v1/asset-type",
params={"code": code},
timeout=5
)
data = response.json()
print(f" {code:15s} -> {data.get('asset_type'):15s} ({data.get('description')})")
except Exception as e:
print(f" {code:15s} ✗ 失败: {e}")
# 测试4: 获取K线数据
print("\n4. 测试获取K线数据...")
try:
response = requests.get(
f"{BASE_URL}/api/v1/ohlcv",
params={
"code": "000300.SH",
"start": "2024-01-01",
"end": "2024-01-31"
},
timeout=30
)
data = response.json()
if "error" in data:
print(f" ✗ 错误: {data['error']}")
else:
print(f" ✓ 获取成功: {data.get('count')}")
print(f" 资产类型: {data.get('asset_type')}")
if data.get('data'):
latest = data['data'][-1]
print(f" 最新数据: {latest.get('date')} 收盘 {latest.get('close')}")
except Exception as e:
print(f" ✗ 失败: {e}")
# 测试5: 批量获取
print("\n5. 测试批量获取...")
try:
response = requests.post(
f"{BASE_URL}/api/v1/ohlcv/batch",
json={
"codes": ["000300.SH", "510300.SH"],
"start": "2024-01-01",
"end": "2024-01-31"
},
timeout=60
)
data = response.json()
print(f" 成功: {data.get('success_count')}/{data.get('total')}")
print(f" 失败: {data.get('failed_count')}/{data.get('total')}")
for code, result in data.get('results', {}).items():
if 'error' in result:
print(f"{code}: {result['error']}")
else:
print(f"{code}: {result['count']}")
except Exception as e:
print(f" ✗ 失败: {e}")
# 测试6: 支持的代码
print("\n6. 测试获取支持的代码...")
try:
response = requests.get(f"{BASE_URL}/api/v1/supported-codes", timeout=5)
data = response.json()
print(f" 支持的资产类型: {len(data)}")
for asset_type, info in list(data.items())[:3]:
print(f" - {asset_type}: {info.get('description')}")
print(f" 示例: {', '.join(info.get('examples', [])[:3])}")
except Exception as e:
print(f" ✗ 失败: {e}")
print("\n" + "="*60)
print("测试完成")
print("="*60)
if __name__ == "__main__":
print("\n" + "="*60)
print("Flask API 测试客户端")
print("="*60)
print("\n请确保 Flask 服务已启动:")
print(" ./start_flask_server.sh")
print("")
print(" python core/datasource/flask_server.py")
print("")
test_api()

217
tests/test_ssh_tunnel.py Normal file
View File

@@ -0,0 +1,217 @@
"""
使用香港服务器 SSH 隧道测试 UniversalDataFetcher
==============================================
通过阿里云香港 ECS 服务器建立 SOCKS5 代理,获取港美股数据
"""
import sys
from pathlib import Path
# 添加项目根目录到路径
sys.path.insert(0, str(Path(__file__).parent.parent))
from dotenv import load_dotenv
load_dotenv()
from core.datasource.universal_fetcher import UniversalDataFetcher, fetch_kline
def test_with_ssh_tunnel():
"""使用 SSH 隧道获取港美股数据"""
print("\n" + "="*60)
print("使用香港服务器 SSH 隧道获取数据")
print("="*60)
# SSH 隧道配置(使用 hk_ecs.pem
ssh_config = {
"enabled": True,
"host": "8.218.167.69", # 阿里云香港 ECS IP
"port": 22, # SSH 端口
"username": "root", # SSH 用户名
"key_path": "hk_ecs.pem", # SSH 私钥路径(相对于项目根目录)
"local_port": 1080, # 本地 SOCKS5 代理端口
}
# 要测试的标的
test_codes = [
("NDX", "纳斯达克100"),
("SPX", "标普500"),
("HSI", "恒生指数"),
("N225", "日经225"),
("GDAXI", "德国DAX"),
]
print(f"\nSSH 配置:")
print(f" 服务器: {ssh_config['host']}:{ssh_config['port']}")
print(f" 用户名: {ssh_config['username']}")
print(f" 私钥: {ssh_config['key_path']}")
print(f" 本地端口: {ssh_config['local_port']}")
# 创建带 SSH 隧道的数据获取器
fetcher = UniversalDataFetcher(ssh_config=ssh_config)
print("\n正在建立 SSH 隧道...")
with fetcher:
print("✓ SSH 隧道已建立")
print("\n开始获取数据...")
for code, name in test_codes:
print(f"\n[{code}] {name}")
try:
df = fetcher.fetch(code, "2024-01-01", "2024-03-31")
if df is not None and len(df) > 0:
print(f" ✓ 获取成功: {len(df)}")
print(f" 日期范围: {df.index.min().strftime('%Y-%m-%d')} ~ {df.index.max().strftime('%Y-%m-%d')}")
print(f" 最新收盘价: {df['close'].iloc[-1]:.3f}")
print(f" 数据预览:")
print(df.tail(3)[['open', 'high', 'low', 'close']].to_string())
else:
print(f" ✗ 无数据")
except Exception as e:
print(f" ✗ 错误: {e}")
print("\n✓ SSH 隧道已关闭")
def test_without_ssh():
"""不使用 SSH 隧道仅获取A股数据"""
print("\n" + "="*60)
print("不使用 SSH 隧道仅A股")
print("="*60)
test_codes = [
("000300.SH", "沪深300"),
("510300.SH", "沪深300ETF"),
("AU.SHF", "黄金期货"),
]
fetcher = UniversalDataFetcher(ssh_config={"enabled": False})
with fetcher:
for code, name in test_codes:
print(f"\n[{code}] {name}")
df = fetcher.fetch(code, "2024-01-01", "2024-03-31")
if df is not None:
print(f"{len(df)} 条, 最新: {df['close'].iloc[-1]:.3f}")
else:
print(f" ✗ 无数据")
def test_mixed_markets():
"""混合市场测试A股 + 港美股)"""
print("\n" + "="*60)
print("混合市场测试A股 + 港美股)")
print("="*60)
ssh_config = {
"enabled": True,
"host": "8.218.167.69",
"port": 22,
"username": "root",
"key_path": "hk_ecs.pem",
"local_port": 1080,
}
# 混合代码
codes = [
"000300.SH", # A股指数
"NDX", # 美股指数
"HSI", # 港股指数
"510300.SH", # A股ETF
"N225", # 日本指数
]
fetcher = UniversalDataFetcher(ssh_config=ssh_config)
print("\n开始批量获取...")
with fetcher:
results = fetcher.fetch_multiple(codes, "2024-01-01", "2024-03-31")
print("\n获取结果:")
for code, df in results.items():
if df is not None:
print(f"{code:15s} {len(df):4d} 条, "
f"最新: {df['close'].iloc[-1]:.3f}")
else:
print(f"{code:15s} 无数据")
def test_ssh_connection_only():
"""仅测试 SSH 连接"""
print("\n" + "="*60)
print("测试 SSH 连接")
print("="*60)
ssh_config = {
"enabled": True,
"host": "8.218.167.69",
"port": 22,
"username": "root",
"key_path": "hk_ecs.pem",
"local_port": 1080,
}
print("\n正在建立 SSH 隧道...")
fetcher = UniversalDataFetcher(ssh_config=ssh_config)
try:
with fetcher:
print("✓ SSH 隧道建立成功!")
print(f" 本地 SOCKS5 代理: socks5h://127.0.0.1:{ssh_config['local_port']}")
# 测试一个简单的 HTTP 请求
import requests
print("\n测试代理连接...")
try:
response = requests.get(
"https://www.google.com",
proxies={
"http": f"socks5h://127.0.0.1:{ssh_config['local_port']}",
"https": f"socks5h://127.0.0.1:{ssh_config['local_port']}",
},
timeout=10
)
print(f"✓ 代理测试成功!状态码: {response.status_code}")
except Exception as e:
print(f"✗ 代理测试失败: {e}")
print("\n✓ SSH 隧道已关闭")
except Exception as e:
print(f"✗ SSH 隧道建立失败: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
print("\n" + "="*60)
print("UniversalDataFetcher + 香港 SSH 隧道测试")
print("="*60)
# 选择要运行的测试
tests = [
("1", "仅测试 SSH 连接", test_ssh_connection_only),
("2", "不使用 SSH仅A股", test_without_ssh),
("3", "使用 SSH 隧道(港美股)", test_with_ssh_tunnel),
("4", "混合市场测试", test_mixed_markets),
]
print("\n可用测试:")
for num, name, _ in tests:
print(f" {num}. {name}")
# 运行所有测试
print("\n" + "="*60)
print("运行所有测试...")
print("="*60)
for num, name, func in tests:
try:
func()
except Exception as e:
print(f"\n测试 '{name}' 失败: {e}")
import traceback
traceback.print_exc()
print("\n" + "="*60)
print("测试完成")
print("="*60)

View File

@@ -0,0 +1,228 @@
"""
统一数据获取接口测试
====================
测试各种资产类型的K线数据获取
"""
import sys
from pathlib import Path
# 添加项目根目录到路径
sys.path.insert(0, str(Path(__file__).parent.parent))
from dotenv import load_dotenv
load_dotenv()
from core.datasource.universal_fetcher import (
UniversalDataFetcher,
AssetTypeDetector,
detect_asset_type,
fetch_kline
)
def test_asset_detection():
"""测试资产类型检测"""
print("\n" + "="*60)
print("测试1: 资产类型检测")
print("="*60)
test_cases = [
# A股指数
("000300.SH", "china_index"),
("399006.SZ", "china_index"),
("H30269.CSI", "china_index"),
# A股ETF
("510300.SH", "china_etf"),
("159915.SZ", "china_etf"),
("513100.SH", "china_etf"),
# A股股票
("600000.SH", "china_stock"),
("000001.SZ", "china_stock"),
# 港股
("HSI", "hk_index"),
("HSTECH.HK", "hk_index"),
# 美股
("NDX", "us_index"),
("SPX", "us_index"),
("AAPL", "us_stock"),
# 期货
("AU.SHF", "futures"),
("CU.SHF", "futures"),
# 加密货币
("BTC", "crypto"),
("ETH", "crypto"),
]
correct = 0
for code, expected in test_cases:
result = detect_asset_type(code)
status = "" if result == expected else ""
if result == expected:
correct += 1
print(f" {status} {code:15s} -> {result:15s} (期望: {expected})")
print(f"\n检测准确率: {correct}/{len(test_cases)} ({100*correct/len(test_cases):.1f}%)")
def test_single_fetch():
"""测试单只标的获取"""
print("\n" + "="*60)
print("测试2: 单只标的获取")
print("="*60)
# 测试A股指数
print("\n[A股指数] 000300.SH (沪深300)")
df = fetch_kline("000300.SH", "2024-01-01", "2024-03-31")
if df is not None:
print(f" ✓ 获取成功: {len(df)}")
print(f" 日期范围: {df.index.min()} ~ {df.index.max()}")
print(f" 列: {list(df.columns)}")
print(f" 最新数据:\n{df.tail(3)}")
else:
print(" ✗ 获取失败")
# 测试A股ETF
print("\n[A股ETF] 510300.SH (沪深300ETF)")
df = fetch_kline("510300.SH", "2024-01-01", "2024-03-31")
if df is not None:
print(f" ✓ 获取成功: {len(df)}")
print(f" 最新收盘价: {df['close'].iloc[-1]:.3f}")
else:
print(" ✗ 获取失败")
# 测试美股指数
print("\n[美股指数] NDX (纳斯达克100)")
df = fetch_kline("NDX", "2024-01-01", "2024-03-31")
if df is not None:
print(f" ✓ 获取成功: {len(df)}")
print(f" 最新收盘价: {df['close'].iloc[-1]:.3f}")
else:
print(" ✗ 获取失败可能需要SSH隧道")
# 测试港股指数
print("\n[港股指数] HSI (恒生指数)")
df = fetch_kline("HSI", "2024-01-01", "2024-03-31")
if df is not None:
print(f" ✓ 获取成功: {len(df)}")
print(f" 最新收盘价: {df['close'].iloc[-1]:.3f}")
else:
print(" ✗ 获取失败可能需要SSH隧道")
def test_multiple_fetch():
"""测试批量获取"""
print("\n" + "="*60)
print("测试3: 批量获取")
print("="*60)
codes = [
"000300.SH", # A股指数
"510300.SH", # A股ETF
"NDX", # 美股指数
"HSI", # 港股指数
"AU.SHF", # 期货
# "BTC", # 加密货币需要SSH隧道
]
fetcher = UniversalDataFetcher()
with fetcher:
results = fetcher.fetch_multiple(codes, "2024-01-01", "2024-03-31")
print(f"\n获取结果:")
for code, df in results.items():
if df is not None:
print(f"{code:15s} {len(df):4d} 条, "
f"最新收盘价: {df['close'].iloc[-1]:.3f}")
else:
print(f"{code:15s} 无数据")
def test_context_manager():
"""测试上下文管理器SSH隧道"""
print("\n" + "="*60)
print("测试4: 上下文管理器SSH隧道")
print("="*60)
# 不启用SSH
print("\n[不启用SSH] 获取A股数据应成功")
fetcher = UniversalDataFetcher(ssh_config={"enabled": False})
with fetcher:
df = fetcher.fetch("000300.SH", "2024-01-01", "2024-01-31")
if df is not None:
print(f" ✓ 成功: {len(df)}")
else:
print(" ✗ 失败")
# 启用SSH如果配置了
ssh_config = {
"enabled": False, # 改为 True 并填入实际配置以测试
"host": "",
"port": 22,
"username": "",
"key_path": "",
"local_port": 1080,
}
if ssh_config["enabled"]:
print("\n[启用SSH] 获取美股数据")
fetcher = UniversalDataFetcher(ssh_config=ssh_config)
with fetcher:
df = fetcher.fetch("NDX", "2024-01-01", "2024-01-31")
if df is not None:
print(f" ✓ 成功: {len(df)}")
else:
print(" ✗ 失败")
else:
print("\n[跳过SSH测试] SSH未启用")
def test_edge_cases():
"""测试边界情况"""
print("\n" + "="*60)
print("测试5: 边界情况")
print("="*60)
# 测试无效代码
print("\n[无效代码] INVALID")
df = fetch_kline("INVALID", "2024-01-01", "2024-01-31")
print(f" 结果: {df}")
# 测试日期范围
print("\n[空日期范围]")
df = fetch_kline("000300.SH", "2030-01-01", "2030-01-31")
if df is None or len(df) == 0:
print(" ✓ 正确处理(无数据)")
else:
print(f" ✗ 意外获取到数据: {len(df)}")
# 测试代码格式转换
print("\n[代码格式转换] 000300.SS -> 000300.SH")
df = fetch_kline("000300.SS", "2024-01-01", "2024-01-31")
if df is not None:
print(f" ✓ 转换成功: {len(df)}")
else:
print(" ✗ 失败")
if __name__ == "__main__":
print("\n" + "="*60)
print("统一数据获取接口测试")
print("="*60)
# 运行所有测试
test_asset_detection()
test_single_fetch()
test_multiple_fetch()
test_context_manager()
test_edge_cases()
print("\n" + "="*60)
print("测试完成")
print("="*60)