""" Flask 数据服务 API ================== 提供 RESTful API 接口,支持获取各类资产的 K 线数据 支持的资产类型: - A股指数 (000300.SH, 399006.SZ) - A股ETF (510300.SH, 159915.SZ) - A股股票 (600000.SH, 000001.SZ) - 港股指数 (HSI, HSTECH.HK) - 美股指数 (NDX, SPX, DJI) - 美股股票 (AAPL, MSFT, GOOGL) - 期货合约 (AU.SHF, CU.SHF) - 加密货币 (BTC, ETH) 运行: python core/datasource/flask_server.py API 文档: GET /health - 健康检查 GET /api/v1/asset-type?code=000300.SH - 检测资产类型 GET /api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31 - 获取K线数据 POST /api/v1/ohlcv/batch - 批量获取K线数据 """ import os import sys import json import hashlib from pathlib import Path from datetime import datetime, timedelta from typing import Optional, Dict, Any, List from functools import wraps # 添加项目根目录到路径 project_root = Path(__file__).parent.parent.parent sys.path.insert(0, str(project_root)) from dotenv import load_dotenv load_dotenv() from flask import Flask, request, jsonify from flask_cors import CORS import pandas as pd from core.datasource.universal_fetcher import ( UniversalDataFetcher, AssetTypeDetector, ) # ============================================================ # Flask 应用配置 # ============================================================ app = Flask(__name__) CORS(app) # 启用跨域支持 # 全局数据获取器实例 fetcher: Optional[UniversalDataFetcher] = None ssh_config: Optional[Dict] = None # 内存缓存 _cache: Dict[str, Any] = {} _cache_ttl: int = int(os.getenv('CACHE_TTL_SECONDS', '300')) # 默认5分钟 def _get_cache_key(code: str, start: str, end: str) -> str: """生成缓存键""" key = f"{code}:{start}:{end}" return hashlib.md5(key.encode()).hexdigest() def _get_cached(key: str) -> Optional[Any]: """获取缓存数据""" if key not in _cache: return None data, timestamp = _cache[key] if datetime.now().timestamp() - timestamp > _cache_ttl: # 缓存过期 del _cache[key] return None return data def _set_cache(key: str, data: Any): """设置缓存数据""" _cache[key] = (data, datetime.now().timestamp()) def clear_cache(): """清理所有缓存""" _cache.clear() print("✓ 缓存已清理") def get_ssh_config() -> Optional[Dict]: """ 从环境变量获取 SSH 配置 环境变量: SSH_ENABLED: 是否启用 SSH 隧道 (true/false) SSH_HOST: SSH 服务器地址 SSH_PORT: SSH 端口 (默认 22) SSH_USERNAME: SSH 用户名 SSH_KEY_PATH: SSH 私钥路径 SSH_LOCAL_PORT: 本地 SOCKS5 端口 (默认 1080) """ enabled = os.getenv('SSH_ENABLED', 'false').lower() == 'true' if not enabled: return None return { "enabled": True, "host": os.getenv('SSH_HOST', ''), "port": int(os.getenv('SSH_PORT', '22')), "username": os.getenv('SSH_USERNAME', ''), "key_path": os.getenv('SSH_KEY_PATH', 'hk_ecs.pem'), "local_port": int(os.getenv('SSH_LOCAL_PORT', '1080')), } def get_fetcher() -> UniversalDataFetcher: """获取或创建数据获取器实例""" global fetcher, ssh_config if fetcher is None: ssh_config = get_ssh_config() fetcher = UniversalDataFetcher(ssh_config=ssh_config) return fetcher def dataframe_to_json(df: pd.DataFrame) -> Dict: """将 DataFrame 转换为 JSON 可序列化的字典""" if df is None or len(df) == 0: return {"data": [], "count": 0} # 重置索引,将日期转为列 df_reset = df.reset_index() # 处理日期列 if 'date' in df_reset.columns: df_reset['date'] = df_reset['date'].dt.strftime('%Y-%m-%d') elif 'index' in df_reset.columns: df_reset['index'] = pd.to_datetime(df_reset['index']).dt.strftime('%Y-%m-%d') df_reset = df_reset.rename(columns={'index': 'date'}) # 转换为字典列表 records = df_reset.to_dict(orient='records') return { "data": records, "count": len(records), "columns": list(df_reset.columns), "date_range": { "start": df.index.min().strftime('%Y-%m-%d'), "end": df.index.max().strftime('%Y-%m-%d'), } if len(df) > 0 else None } def validate_date(date_str: str) -> bool: """验证日期格式""" try: datetime.strptime(date_str, '%Y-%m-%d') return True except ValueError: return False def get_default_dates() -> tuple: """获取默认日期范围(最近3个月)""" end = datetime.now() start = end - timedelta(days=90) return start.strftime('%Y-%m-%d'), end.strftime('%Y-%m-%d') # ============================================================ # API 路由 # ============================================================ @app.route('/') def index(): """首页 - API 信息""" return jsonify({ "name": "Universal Data Fetcher API", "version": "1.1.0", "description": "统一数据获取服务,支持A股、港股、美股、期货、加密货币", "features": [ "自动资产类型识别", "内存缓存(默认5分钟TTL)", "SSH隧道支持(港美股)", "批量数据获取" ], "endpoints": { "health": "/health", "asset_type": "/api/v1/asset-type?code={code}", "ohlcv": "/api/v1/ohlcv?code={code}&start={YYYY-MM-DD}&end={YYYY-MM-DD}", "ohlcv_nocache": "/api/v1/ohlcv?code={code}&start={YYYY-MM-DD}&end={YYYY-MM-DD}&nocache=true", "batch": "POST /api/v1/ohlcv/batch", "cache_clear": "POST /api/v1/cache/clear", "cache_stats": "/api/v1/cache/stats", }, "supported_assets": [ "A股指数 (000300.SH)", "A股ETF (510300.SH)", "A股股票 (600000.SH)", "港股指数 (HSI)", "美股指数 (NDX)", "美股股票 (AAPL)", "期货 (AU.SHF)", "加密货币 (BTC)", ], "cache_config": { "ttl_seconds": _cache_ttl, "current_size": len(_cache) }, "ssh_status": "enabled" if ssh_config and ssh_config.get('enabled') else "disabled", }) @app.route('/health') def health(): """健康检查""" return jsonify({ "status": "healthy", "timestamp": datetime.now().isoformat(), "ssh_configured": ssh_config is not None and ssh_config.get('enabled', False), }) @app.route('/api/v1/asset-type') def detect_asset_type(): """ 检测资产类型 Query Parameters: code: 标的代码 (required) Returns: { "code": "000300.SH", "asset_type": "china_index", "description": "A股指数" } """ code = request.args.get('code', '').strip() if not code: return jsonify({ "error": "Missing required parameter: code", "example": "/api/v1/asset-type?code=000300.SH" }), 400 asset_type = AssetTypeDetector.detect(code) # 资产类型描述映射 descriptions = { 'china_index': 'A股指数', 'china_etf': 'A股ETF', 'china_stock': 'A股股票', 'hk_index': '港股指数', 'hk_stock': '港股股票', 'us_index': '美股指数', 'us_stock': '美股股票', 'futures': '期货合约', 'crypto': '加密货币', } return jsonify({ "code": code, "asset_type": asset_type, "description": descriptions.get(asset_type, '未知类型'), }) @app.route('/api/v1/ohlcv') def get_ohlcv(): """ 获取单只标的的 OHLCV 数据(支持缓存) Query Parameters: code: 标的代码 (required) start: 开始日期,格式 YYYY-MM-DD (optional, 默认90天前) end: 结束日期,格式 YYYY-MM-DD (optional, 默认今天) retry: 重试次数 (optional, 默认3) nocache: 是否跳过缓存 (optional, 默认false) Returns: { "code": "000300.SH", "asset_type": "china_index", "data": [ {"date": "2024-01-02", "open": 3500.0, "high": 3550.0, ...}, ... ], "count": 58, "date_range": {"start": "2024-01-02", "end": "2024-03-29"}, "cached": false } """ code = request.args.get('code', '').strip() start = request.args.get('start', '').strip() end = request.args.get('end', '').strip() retry = request.args.get('retry', '3') nocache = request.args.get('nocache', 'false').lower() == 'true' # 参数验证 if not code: return jsonify({ "error": "Missing required parameter: code", "example": "/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31" }), 400 # 设置默认日期 if not start or not end: default_start, default_end = get_default_dates() start = start or default_start end = end or default_end # 日期格式验证 if not validate_date(start) or not validate_date(end): return jsonify({ "error": "Invalid date format. Use YYYY-MM-DD", "start": start, "end": end, }), 400 try: retry_count = int(retry) except ValueError: retry_count = 3 # 检测资产类型 asset_type = AssetTypeDetector.detect(code) # 检查缓存 cache_key = _get_cache_key(code, start, end) if not nocache: cached_data = _get_cached(cache_key) if cached_data: cached_data['cached'] = True return jsonify(cached_data) # 获取数据 fetcher = get_fetcher() try: with fetcher: df = fetcher.fetch(code, start, end, retry=retry_count) if df is None or len(df) == 0: return jsonify({ "code": code, "asset_type": asset_type, "error": "No data available for the specified code and date range", "start": start, "end": end, }), 404 result = dataframe_to_json(df) result['code'] = code result['asset_type'] = asset_type result['cached'] = False # 存入缓存 if not nocache: _set_cache(cache_key, result.copy()) return jsonify(result) except Exception as e: return jsonify({ "code": code, "asset_type": asset_type, "error": str(e), "start": start, "end": end, }), 500 @app.route('/api/v1/ohlcv/batch', methods=['POST']) def batch_ohlcv(): """ 批量获取多只标的的 OHLCV 数据 Request Body: { "codes": ["000300.SH", "NDX", "HSI"], "start": "2024-01-01", "end": "2024-03-31", "retry": 3 } Returns: { "results": { "000300.SH": {"data": [...], "count": 58, ...}, "NDX": {"data": [...], "count": 61, ...}, "HSI": {"error": "No data available", ...} }, "success_count": 2, "failed_count": 1, "total": 3 } """ data = request.get_json() if not data: return jsonify({ "error": "Missing request body", "example": { "codes": ["000300.SH", "NDX"], "start": "2024-01-01", "end": "2024-03-31", } }), 400 codes = data.get('codes', []) start = data.get('start', '').strip() end = data.get('end', '').strip() retry = data.get('retry', 3) if not codes or not isinstance(codes, list): return jsonify({ "error": "Missing or invalid parameter: codes (must be a list)" }), 400 # 设置默认日期 if not start or not end: default_start, default_end = get_default_dates() start = start or default_start end = end or default_end # 日期格式验证 if not validate_date(start) or not validate_date(end): return jsonify({ "error": "Invalid date format. Use YYYY-MM-DD", "start": start, "end": end, }), 400 # 获取数据 fetcher = get_fetcher() results = {} success_count = 0 failed_count = 0 try: with fetcher: for code in codes: try: df = fetcher.fetch(code, start, end, retry=retry) if df is not None and len(df) > 0: result = dataframe_to_json(df) result['code'] = code result['asset_type'] = AssetTypeDetector.detect(code) results[code] = result success_count += 1 else: results[code] = { "code": code, "asset_type": AssetTypeDetector.detect(code), "error": "No data available", "data": [], "count": 0, } failed_count += 1 except Exception as e: results[code] = { "code": code, "asset_type": AssetTypeDetector.detect(code), "error": str(e), "data": [], "count": 0, } failed_count += 1 except Exception as e: return jsonify({ "error": f"Batch fetch failed: {str(e)}" }), 500 return jsonify({ "results": results, "success_count": success_count, "failed_count": failed_count, "total": len(codes), "start": start, "end": end, }) @app.route('/api/v1/cache/clear', methods=['POST']) def clear_cache_endpoint(): """ 清理所有缓存数据 Returns: {"message": "Cache cleared", "cache_size": 0} """ cache_size = len(_cache) clear_cache() return jsonify({ "message": "Cache cleared successfully", "previous_size": cache_size, "current_size": 0 }) @app.route('/api/v1/cache/stats') def cache_stats(): """ 获取缓存统计信息 Returns: { "cache_size": 10, "cache_ttl_seconds": 300, "memory_estimate_mb": 5.2 } """ # 估算内存使用(粗略估计) import sys total_size = 0 for key, (data, _) in _cache.items(): total_size += sys.getsizeof(key) + sys.getsizeof(data) return jsonify({ "cache_size": len(_cache), "cache_ttl_seconds": _cache_ttl, "memory_estimate_mb": round(total_size / 1024 / 1024, 2) }) @app.route('/api/v1/supported-codes') def get_supported_codes(): """ 获取支持的代码示例 Returns: { "china_index": ["000300.SH", "399006.SZ", ...], "china_etf": ["510300.SH", "159915.SZ", ...], ... } """ return jsonify({ "china_index": { "description": "A股指数", "examples": ["000300.SH", "399006.SZ", "000016.SH", "H30269.CSI"], }, "china_etf": { "description": "A股ETF", "examples": ["510300.SH", "159915.SZ", "510500.SH", "513100.SH"], }, "china_stock": { "description": "A股股票", "examples": ["600000.SH", "000001.SZ"], }, "hk_index": { "description": "港股指数", "examples": ["HSI", "HSTECH.HK"], }, "us_index": { "description": "美股指数", "examples": ["NDX", "SPX", "DJI", "N225", "GDAXI"], }, "us_stock": { "description": "美股股票", "examples": ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"], }, "futures": { "description": "期货合约", "examples": ["AU.SHF", "CU.SHF"], }, "crypto": { "description": "加密货币", "examples": ["BTC", "ETH"], }, }) # ============================================================ # 错误处理 # ============================================================ @app.errorhandler(404) def not_found(error): return jsonify({ "error": "Endpoint not found", "available_endpoints": [ "/", "/health", "/api/v1/asset-type?code={code}", "/api/v1/ohlcv?code={code}&start={YYYY-MM-DD}&end={YYYY-MM-DD}", "/api/v1/ohlcv/batch", "/api/v1/cache/clear", "/api/v1/cache/stats", "/api/v1/supported-codes", ] }), 404 @app.errorhandler(500) def internal_error(error): return jsonify({ "error": "Internal server error", "message": str(error) }), 500 # ============================================================ # 启动服务 # ============================================================ if __name__ == '__main__': import argparse parser = argparse.ArgumentParser(description='Universal Data Fetcher API Server') parser.add_argument('--host', default='0.0.0.0', help='Host to bind (default: 0.0.0.0)') parser.add_argument('--port', type=int, default=5000, help='Port to bind (default: 5000)') parser.add_argument('--debug', action='store_true', help='Enable debug mode') args = parser.parse_args() # 预加载 SSH 配置 ssh_config = get_ssh_config() if ssh_config and ssh_config.get('enabled'): print(f"✓ SSH 隧道已配置: {ssh_config['host']}:{ssh_config['port']}") else: print("✗ SSH 隧道未启用(仅支持A股数据)") print(f"\n🚀 启动 Universal Data Fetcher API Server") print(f" Host: {args.host}") print(f" Port: {args.port}") print(f" Debug: {args.debug}") print(f"\n📖 API 文档: http://{args.host}:{args.port}/") print(f" 健康检查: http://{args.host}:{args.port}/health") print(f"\n💡 示例请求:") print(f" curl 'http://{args.host}:{args.port}/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31'") print(f"\n") app.run(host=args.host, port=args.port, debug=args.debug)