diff --git a/tests/test_flask_api.py b/tests/test_flask_api.py new file mode 100644 index 0000000..319b958 --- /dev/null +++ b/tests/test_flask_api.py @@ -0,0 +1,140 @@ +""" +Flask API 快速测试 +================== +测试 Flask 服务的基本功能 +""" + +import sys +from pathlib import Path + +# 添加项目根目录到路径 +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +from dotenv import load_dotenv +load_dotenv() + +import requests +import json + + +def test_api(): + """测试 API 功能""" + BASE_URL = "http://localhost:5000" + + print("\n" + "="*60) + print("Flask API 测试") + print("="*60) + + # 测试1: 健康检查 + print("\n1. 测试健康检查...") + try: + response = requests.get(f"{BASE_URL}/health", timeout=5) + print(f" 状态码: {response.status_code}") + print(f" 响应: {json.dumps(response.json(), indent=2, ensure_ascii=False)[:200]}") + except Exception as e: + print(f" ✗ 失败: {e}") + + # 测试2: 首页 + print("\n2. 测试首页...") + try: + response = requests.get(f"{BASE_URL}/", timeout=5) + print(f" 状态码: {response.status_code}") + data = response.json() + print(f" API名称: {data.get('name')}") + print(f" 版本: {data.get('version')}") + print(f" SSH状态: {data.get('ssh_status')}") + except Exception as e: + print(f" ✗ 失败: {e}") + + # 测试3: 资产类型检测 + print("\n3. 测试资产类型检测...") + codes = ["000300.SH", "NDX", "BTC"] + for code in codes: + try: + response = requests.get( + f"{BASE_URL}/api/v1/asset-type", + params={"code": code}, + timeout=5 + ) + data = response.json() + print(f" {code:15s} -> {data.get('asset_type'):15s} ({data.get('description')})") + except Exception as e: + print(f" {code:15s} ✗ 失败: {e}") + + # 测试4: 获取K线数据 + print("\n4. 测试获取K线数据...") + try: + response = requests.get( + f"{BASE_URL}/api/v1/ohlcv", + params={ + "code": "000300.SH", + "start": "2024-01-01", + "end": "2024-01-31" + }, + timeout=30 + ) + data = response.json() + if "error" in data: + print(f" ✗ 错误: {data['error']}") + else: + print(f" ✓ 获取成功: {data.get('count')} 条") + print(f" 资产类型: {data.get('asset_type')}") + if data.get('data'): + latest = data['data'][-1] + print(f" 最新数据: {latest.get('date')} 收盘 {latest.get('close')}") + except Exception as e: + print(f" ✗ 失败: {e}") + + # 测试5: 批量获取 + print("\n5. 测试批量获取...") + try: + response = requests.post( + f"{BASE_URL}/api/v1/ohlcv/batch", + json={ + "codes": ["000300.SH", "510300.SH"], + "start": "2024-01-01", + "end": "2024-01-31" + }, + timeout=60 + ) + data = response.json() + print(f" 成功: {data.get('success_count')}/{data.get('total')}") + print(f" 失败: {data.get('failed_count')}/{data.get('total')}") + + for code, result in data.get('results', {}).items(): + if 'error' in result: + print(f" ✗ {code}: {result['error']}") + else: + print(f" ✓ {code}: {result['count']} 条") + except Exception as e: + print(f" ✗ 失败: {e}") + + # 测试6: 支持的代码 + print("\n6. 测试获取支持的代码...") + try: + response = requests.get(f"{BASE_URL}/api/v1/supported-codes", timeout=5) + data = response.json() + print(f" 支持的资产类型: {len(data)} 种") + for asset_type, info in list(data.items())[:3]: + print(f" - {asset_type}: {info.get('description')}") + print(f" 示例: {', '.join(info.get('examples', [])[:3])}") + except Exception as e: + print(f" ✗ 失败: {e}") + + print("\n" + "="*60) + print("测试完成") + print("="*60) + + +if __name__ == "__main__": + print("\n" + "="*60) + print("Flask API 测试客户端") + print("="*60) + print("\n请确保 Flask 服务已启动:") + print(" ./start_flask_server.sh") + print("或") + print(" python core/datasource/flask_server.py") + print("") + + test_api() diff --git a/tests/test_ssh_tunnel.py b/tests/test_ssh_tunnel.py new file mode 100644 index 0000000..18d2b5f --- /dev/null +++ b/tests/test_ssh_tunnel.py @@ -0,0 +1,217 @@ +""" +使用香港服务器 SSH 隧道测试 UniversalDataFetcher +============================================== +通过阿里云香港 ECS 服务器建立 SOCKS5 代理,获取港美股数据 +""" + +import sys +from pathlib import Path + +# 添加项目根目录到路径 +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from dotenv import load_dotenv +load_dotenv() + +from core.datasource.universal_fetcher import UniversalDataFetcher, fetch_kline + + +def test_with_ssh_tunnel(): + """使用 SSH 隧道获取港美股数据""" + print("\n" + "="*60) + print("使用香港服务器 SSH 隧道获取数据") + print("="*60) + + # SSH 隧道配置(使用 hk_ecs.pem) + ssh_config = { + "enabled": True, + "host": "8.218.167.69", # 阿里云香港 ECS IP + "port": 22, # SSH 端口 + "username": "root", # SSH 用户名 + "key_path": "hk_ecs.pem", # SSH 私钥路径(相对于项目根目录) + "local_port": 1080, # 本地 SOCKS5 代理端口 + } + + # 要测试的标的 + test_codes = [ + ("NDX", "纳斯达克100"), + ("SPX", "标普500"), + ("HSI", "恒生指数"), + ("N225", "日经225"), + ("GDAXI", "德国DAX"), + ] + + print(f"\nSSH 配置:") + print(f" 服务器: {ssh_config['host']}:{ssh_config['port']}") + print(f" 用户名: {ssh_config['username']}") + print(f" 私钥: {ssh_config['key_path']}") + print(f" 本地端口: {ssh_config['local_port']}") + + # 创建带 SSH 隧道的数据获取器 + fetcher = UniversalDataFetcher(ssh_config=ssh_config) + + print("\n正在建立 SSH 隧道...") + with fetcher: + print("✓ SSH 隧道已建立") + print("\n开始获取数据...") + + for code, name in test_codes: + print(f"\n[{code}] {name}") + try: + df = fetcher.fetch(code, "2024-01-01", "2024-03-31") + if df is not None and len(df) > 0: + print(f" ✓ 获取成功: {len(df)} 条") + print(f" 日期范围: {df.index.min().strftime('%Y-%m-%d')} ~ {df.index.max().strftime('%Y-%m-%d')}") + print(f" 最新收盘价: {df['close'].iloc[-1]:.3f}") + print(f" 数据预览:") + print(df.tail(3)[['open', 'high', 'low', 'close']].to_string()) + else: + print(f" ✗ 无数据") + except Exception as e: + print(f" ✗ 错误: {e}") + + print("\n✓ SSH 隧道已关闭") + + +def test_without_ssh(): + """不使用 SSH 隧道(仅获取A股数据)""" + print("\n" + "="*60) + print("不使用 SSH 隧道(仅A股)") + print("="*60) + + test_codes = [ + ("000300.SH", "沪深300"), + ("510300.SH", "沪深300ETF"), + ("AU.SHF", "黄金期货"), + ] + + fetcher = UniversalDataFetcher(ssh_config={"enabled": False}) + + with fetcher: + for code, name in test_codes: + print(f"\n[{code}] {name}") + df = fetcher.fetch(code, "2024-01-01", "2024-03-31") + if df is not None: + print(f" ✓ {len(df)} 条, 最新: {df['close'].iloc[-1]:.3f}") + else: + print(f" ✗ 无数据") + + +def test_mixed_markets(): + """混合市场测试(A股 + 港美股)""" + print("\n" + "="*60) + print("混合市场测试(A股 + 港美股)") + print("="*60) + + ssh_config = { + "enabled": True, + "host": "8.218.167.69", + "port": 22, + "username": "root", + "key_path": "hk_ecs.pem", + "local_port": 1080, + } + + # 混合代码 + codes = [ + "000300.SH", # A股指数 + "NDX", # 美股指数 + "HSI", # 港股指数 + "510300.SH", # A股ETF + "N225", # 日本指数 + ] + + fetcher = UniversalDataFetcher(ssh_config=ssh_config) + + print("\n开始批量获取...") + with fetcher: + results = fetcher.fetch_multiple(codes, "2024-01-01", "2024-03-31") + + print("\n获取结果:") + for code, df in results.items(): + if df is not None: + print(f" ✓ {code:15s} {len(df):4d} 条, " + f"最新: {df['close'].iloc[-1]:.3f}") + else: + print(f" ✗ {code:15s} 无数据") + + +def test_ssh_connection_only(): + """仅测试 SSH 连接""" + print("\n" + "="*60) + print("测试 SSH 连接") + print("="*60) + + ssh_config = { + "enabled": True, + "host": "8.218.167.69", + "port": 22, + "username": "root", + "key_path": "hk_ecs.pem", + "local_port": 1080, + } + + print("\n正在建立 SSH 隧道...") + fetcher = UniversalDataFetcher(ssh_config=ssh_config) + + try: + with fetcher: + print("✓ SSH 隧道建立成功!") + print(f" 本地 SOCKS5 代理: socks5h://127.0.0.1:{ssh_config['local_port']}") + + # 测试一个简单的 HTTP 请求 + import requests + print("\n测试代理连接...") + try: + response = requests.get( + "https://www.google.com", + proxies={ + "http": f"socks5h://127.0.0.1:{ssh_config['local_port']}", + "https": f"socks5h://127.0.0.1:{ssh_config['local_port']}", + }, + timeout=10 + ) + print(f"✓ 代理测试成功!状态码: {response.status_code}") + except Exception as e: + print(f"✗ 代理测试失败: {e}") + + print("\n✓ SSH 隧道已关闭") + except Exception as e: + print(f"✗ SSH 隧道建立失败: {e}") + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + print("\n" + "="*60) + print("UniversalDataFetcher + 香港 SSH 隧道测试") + print("="*60) + + # 选择要运行的测试 + tests = [ + ("1", "仅测试 SSH 连接", test_ssh_connection_only), + ("2", "不使用 SSH(仅A股)", test_without_ssh), + ("3", "使用 SSH 隧道(港美股)", test_with_ssh_tunnel), + ("4", "混合市场测试", test_mixed_markets), + ] + + print("\n可用测试:") + for num, name, _ in tests: + print(f" {num}. {name}") + + # 运行所有测试 + print("\n" + "="*60) + print("运行所有测试...") + print("="*60) + + for num, name, func in tests: + try: + func() + except Exception as e: + print(f"\n测试 '{name}' 失败: {e}") + import traceback + traceback.print_exc() + + print("\n" + "="*60) + print("测试完成") + print("="*60) diff --git a/tests/test_universal_fetcher.py b/tests/test_universal_fetcher.py new file mode 100644 index 0000000..8d5b826 --- /dev/null +++ b/tests/test_universal_fetcher.py @@ -0,0 +1,228 @@ +""" +统一数据获取接口测试 +==================== +测试各种资产类型的K线数据获取 +""" + +import sys +from pathlib import Path + +# 添加项目根目录到路径 +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from dotenv import load_dotenv +load_dotenv() + +from core.datasource.universal_fetcher import ( + UniversalDataFetcher, + AssetTypeDetector, + detect_asset_type, + fetch_kline +) + + +def test_asset_detection(): + """测试资产类型检测""" + print("\n" + "="*60) + print("测试1: 资产类型检测") + print("="*60) + + test_cases = [ + # A股指数 + ("000300.SH", "china_index"), + ("399006.SZ", "china_index"), + ("H30269.CSI", "china_index"), + + # A股ETF + ("510300.SH", "china_etf"), + ("159915.SZ", "china_etf"), + ("513100.SH", "china_etf"), + + # A股股票 + ("600000.SH", "china_stock"), + ("000001.SZ", "china_stock"), + + # 港股 + ("HSI", "hk_index"), + ("HSTECH.HK", "hk_index"), + + # 美股 + ("NDX", "us_index"), + ("SPX", "us_index"), + ("AAPL", "us_stock"), + + # 期货 + ("AU.SHF", "futures"), + ("CU.SHF", "futures"), + + # 加密货币 + ("BTC", "crypto"), + ("ETH", "crypto"), + ] + + correct = 0 + for code, expected in test_cases: + result = detect_asset_type(code) + status = "✓" if result == expected else "✗" + if result == expected: + correct += 1 + print(f" {status} {code:15s} -> {result:15s} (期望: {expected})") + + print(f"\n检测准确率: {correct}/{len(test_cases)} ({100*correct/len(test_cases):.1f}%)") + + +def test_single_fetch(): + """测试单只标的获取""" + print("\n" + "="*60) + print("测试2: 单只标的获取") + print("="*60) + + # 测试A股指数 + print("\n[A股指数] 000300.SH (沪深300)") + df = fetch_kline("000300.SH", "2024-01-01", "2024-03-31") + if df is not None: + print(f" ✓ 获取成功: {len(df)} 条") + print(f" 日期范围: {df.index.min()} ~ {df.index.max()}") + print(f" 列: {list(df.columns)}") + print(f" 最新数据:\n{df.tail(3)}") + else: + print(" ✗ 获取失败") + + # 测试A股ETF + print("\n[A股ETF] 510300.SH (沪深300ETF)") + df = fetch_kline("510300.SH", "2024-01-01", "2024-03-31") + if df is not None: + print(f" ✓ 获取成功: {len(df)} 条") + print(f" 最新收盘价: {df['close'].iloc[-1]:.3f}") + else: + print(" ✗ 获取失败") + + # 测试美股指数 + print("\n[美股指数] NDX (纳斯达克100)") + df = fetch_kline("NDX", "2024-01-01", "2024-03-31") + if df is not None: + print(f" ✓ 获取成功: {len(df)} 条") + print(f" 最新收盘价: {df['close'].iloc[-1]:.3f}") + else: + print(" ✗ 获取失败(可能需要SSH隧道)") + + # 测试港股指数 + print("\n[港股指数] HSI (恒生指数)") + df = fetch_kline("HSI", "2024-01-01", "2024-03-31") + if df is not None: + print(f" ✓ 获取成功: {len(df)} 条") + print(f" 最新收盘价: {df['close'].iloc[-1]:.3f}") + else: + print(" ✗ 获取失败(可能需要SSH隧道)") + + +def test_multiple_fetch(): + """测试批量获取""" + print("\n" + "="*60) + print("测试3: 批量获取") + print("="*60) + + codes = [ + "000300.SH", # A股指数 + "510300.SH", # A股ETF + "NDX", # 美股指数 + "HSI", # 港股指数 + "AU.SHF", # 期货 + # "BTC", # 加密货币(需要SSH隧道) + ] + + fetcher = UniversalDataFetcher() + with fetcher: + results = fetcher.fetch_multiple(codes, "2024-01-01", "2024-03-31") + + print(f"\n获取结果:") + for code, df in results.items(): + if df is not None: + print(f" ✓ {code:15s} {len(df):4d} 条, " + f"最新收盘价: {df['close'].iloc[-1]:.3f}") + else: + print(f" ✗ {code:15s} 无数据") + + +def test_context_manager(): + """测试上下文管理器(SSH隧道)""" + print("\n" + "="*60) + print("测试4: 上下文管理器(SSH隧道)") + print("="*60) + + # 不启用SSH + print("\n[不启用SSH] 获取A股数据(应成功)") + fetcher = UniversalDataFetcher(ssh_config={"enabled": False}) + with fetcher: + df = fetcher.fetch("000300.SH", "2024-01-01", "2024-01-31") + if df is not None: + print(f" ✓ 成功: {len(df)} 条") + else: + print(" ✗ 失败") + + # 启用SSH(如果配置了) + ssh_config = { + "enabled": False, # 改为 True 并填入实际配置以测试 + "host": "", + "port": 22, + "username": "", + "key_path": "", + "local_port": 1080, + } + + if ssh_config["enabled"]: + print("\n[启用SSH] 获取美股数据") + fetcher = UniversalDataFetcher(ssh_config=ssh_config) + with fetcher: + df = fetcher.fetch("NDX", "2024-01-01", "2024-01-31") + if df is not None: + print(f" ✓ 成功: {len(df)} 条") + else: + print(" ✗ 失败") + else: + print("\n[跳过SSH测试] SSH未启用") + + +def test_edge_cases(): + """测试边界情况""" + print("\n" + "="*60) + print("测试5: 边界情况") + print("="*60) + + # 测试无效代码 + print("\n[无效代码] INVALID") + df = fetch_kline("INVALID", "2024-01-01", "2024-01-31") + print(f" 结果: {df}") + + # 测试日期范围 + print("\n[空日期范围]") + df = fetch_kline("000300.SH", "2030-01-01", "2030-01-31") + if df is None or len(df) == 0: + print(" ✓ 正确处理(无数据)") + else: + print(f" ✗ 意外获取到数据: {len(df)} 条") + + # 测试代码格式转换 + print("\n[代码格式转换] 000300.SS -> 000300.SH") + df = fetch_kline("000300.SS", "2024-01-01", "2024-01-31") + if df is not None: + print(f" ✓ 转换成功: {len(df)} 条") + else: + print(" ✗ 失败") + + +if __name__ == "__main__": + print("\n" + "="*60) + print("统一数据获取接口测试") + print("="*60) + + # 运行所有测试 + test_asset_detection() + test_single_fetch() + test_multiple_fetch() + test_context_manager() + test_edge_cases() + + print("\n" + "="*60) + print("测试完成") + print("="*60)