""" Flask 数据服务 API ================== 提供 RESTful API 接口,支持获取各类资产的 K 线数据 支持的资产类型: - A股指数 (000300.SH, 399006.SZ) - A股ETF (510300.SH, 159915.SZ) - A股股票 (600000.SH, 000001.SZ) - 港股指数 (HSI, HSTECH.HK) - 美股指数 (NDX, SPX, DJI) - 美股股票 (AAPL, MSFT, GOOGL) - 期货合约 (AU.SHF, CU.SHF) - 加密货币 (BTC, ETH) 运行: python core/datasource/flask_server.py API 文档: GET /health - 健康检查 GET /api/v1/asset-type?code=000300.SH - 检测资产类型 GET /api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31 - 获取K线数据 POST /api/v1/ohlcv/batch - 批量获取K线数据 """ import os import sys import json from pathlib import Path from datetime import datetime, timedelta from typing import Optional, Dict, Any, List from functools import lru_cache # 添加项目根目录到路径 project_root = Path(__file__).parent.parent.parent sys.path.insert(0, str(project_root)) from dotenv import load_dotenv load_dotenv() from flask import Flask, request, jsonify from flask_cors import CORS import pandas as pd from core.datasource.universal_fetcher import ( UniversalDataFetcher, AssetTypeDetector, ) # ============================================================ # Flask 应用配置 # ============================================================ app = Flask(__name__) CORS(app) # 启用跨域支持 # 全局数据获取器实例 fetcher: Optional[UniversalDataFetcher] = None ssh_config: Optional[Dict] = None def get_ssh_config() -> Optional[Dict]: """ 从环境变量获取 SSH 配置 环境变量: SSH_ENABLED: 是否启用 SSH 隧道 (true/false) SSH_HOST: SSH 服务器地址 SSH_PORT: SSH 端口 (默认 22) SSH_USERNAME: SSH 用户名 SSH_KEY_PATH: SSH 私钥路径 SSH_LOCAL_PORT: 本地 SOCKS5 端口 (默认 1080) """ enabled = os.getenv('SSH_ENABLED', 'false').lower() == 'true' if not enabled: return None return { "enabled": True, "host": os.getenv('SSH_HOST', ''), "port": int(os.getenv('SSH_PORT', '22')), "username": os.getenv('SSH_USERNAME', ''), "key_path": os.getenv('SSH_KEY_PATH', 'hk_ecs.pem'), "local_port": int(os.getenv('SSH_LOCAL_PORT', '1080')), } def get_fetcher() -> UniversalDataFetcher: """获取或创建数据获取器实例""" global fetcher, ssh_config if fetcher is None: ssh_config 
= get_ssh_config() fetcher = UniversalDataFetcher(ssh_config=ssh_config) return fetcher # 缓存配置 CACHE_MAXSIZE = int(os.getenv('CACHE_MAXSIZE', '128')) # 默认缓存128条数据 CACHE_TTL_SECONDS = int(os.getenv('CACHE_TTL_SECONDS', '7200')) # 默认5分钟过期 # 带时间戳的缓存条目 class TimedCacheEntry: def __init__(self, data): self.data = data self.timestamp = datetime.now() def is_expired(self) -> bool: return (datetime.now() - self.timestamp).total_seconds() > CACHE_TTL_SECONDS # 使用 lru_cache 包装数据获取函数 @lru_cache(maxsize=CACHE_MAXSIZE) def _fetch_data_cached(code: str, start: str, end: str, retry: int) -> Optional[str]: """ 获取数据的缓存版本 返回 JSON 序列化的字符串,因为 lru_cache 需要可哈希的参数 """ fetcher = get_fetcher() try: with fetcher: df = fetcher.fetch(code, start, end, retry=retry) if df is None or len(df) == 0: return None result = dataframe_to_json(df) result['code'] = code result['asset_type'] = AssetTypeDetector.detect(code) return json.dumps(result) except Exception as e: return json.dumps({"error": str(e)}) def fetch_data_with_ttl(code: str, start: str, end: str, retry: int = 3, nocache: bool = False) -> Optional[Dict]: """ 获取数据,支持TTL缓存 """ if nocache: # 直接获取,不使用缓存 result_json = _fetch_data_cached.__wrapped__(code, start, end, retry) return json.loads(result_json) if result_json else None # 使用缓存 cache_key = (code, start, end, retry) # 检查内存中的TTL缓存 if hasattr(fetch_data_with_ttl, '_ttl_cache'): if cache_key in fetch_data_with_ttl._ttl_cache: entry = fetch_data_with_ttl._ttl_cache[cache_key] if not entry.is_expired(): return entry.data # 过期,删除 del fetch_data_with_ttl._ttl_cache[cache_key] else: fetch_data_with_ttl._ttl_cache = {} # 从 lru_cache 获取 result_json = _fetch_data_cached(code, start, end, retry) if result_json is None: return None result = json.loads(result_json) # 存入TTL缓存 fetch_data_with_ttl._ttl_cache[cache_key] = TimedCacheEntry(result) return result def clear_data_cache(): """清理缓存""" _fetch_data_cached.cache_clear() if hasattr(fetch_data_with_ttl, '_ttl_cache'): 
fetch_data_with_ttl._ttl_cache.clear() def get_cache_info() -> Dict: """获取缓存信息""" info = _fetch_data_cached.cache_info() ttl_size = len(fetch_data_with_ttl._ttl_cache) if hasattr(fetch_data_with_ttl, '_ttl_cache') else 0 return { "lru_cache": { "hits": info.hits, "misses": info.misses, "maxsize": info.maxsize, "currsize": info.currsize, }, "ttl_cache_size": ttl_size, "ttl_seconds": CACHE_TTL_SECONDS, } def dataframe_to_json(df: pd.DataFrame) -> Dict: """将 DataFrame 转换为 JSON 可序列化的字典""" if df is None or len(df) == 0: return {"data": [], "count": 0} # 重置索引,将日期转为列 df_reset = df.reset_index() # 处理日期列(支持多种可能的列名) date_columns = ['date', 'Date', 'index', 'trade_date', 'datetime'] for col in date_columns: if col in df_reset.columns: try: df_reset[col] = pd.to_datetime(df_reset[col]).dt.strftime('%Y-%m-%d') # 统一重命名为 'date' if col != 'date': df_reset = df_reset.rename(columns={col: 'date'}) break except Exception: pass # 确保所有数据都是 JSON 可序列化的 for col in df_reset.columns: if df_reset[col].dtype == 'object': # 转换可能的 Timestamp 对象 try: df_reset[col] = df_reset[col].apply( lambda x: x.strftime('%Y-%m-%d') if hasattr(x, 'strftime') else x ) except Exception: pass # 转换为字典列表 records = df_reset.to_dict(orient='records') return { "data": records, "count": len(records), "columns": list(df_reset.columns), "date_range": { "start": df.index.min().strftime('%Y-%m-%d') if hasattr(df.index.min(), 'strftime') else str(df.index.min()), "end": df.index.max().strftime('%Y-%m-%d') if hasattr(df.index.max(), 'strftime') else str(df.index.max()), } if len(df) > 0 else None } def validate_date(date_str: str) -> bool: """验证日期格式""" try: datetime.strptime(date_str, '%Y-%m-%d') return True except ValueError: return False def get_default_dates() -> tuple: """获取默认日期范围(最近3个月)""" end = datetime.now() start = end - timedelta(days=90) return start.strftime('%Y-%m-%d'), end.strftime('%Y-%m-%d') # ============================================================ # API 路由 # 
# ============================================================

_DATE_FORMAT = '%Y-%m-%d'

# Display label for each asset type handled by /api/v1/asset-type.
_ASSET_TYPE_DESCRIPTIONS = {
    'china_index': 'A股指数',
    'china_etf': 'A股ETF',
    'china_stock': 'A股股票',
    'hk_index': '港股指数',
    'hk_stock': '港股股票',
    'us_index': '美股指数',
    'us_stock': '美股股票',
    'futures': '期货合约',
    'crypto': '加密货币',
}


def _coerce_retry(raw, default: int = 3) -> int:
    """Convert a client-supplied retry value to int, falling back to ``default``."""
    try:
        return int(raw)
    except (TypeError, ValueError):
        return default


def _resolve_date_range(start: str, end: str):
    """Apply default dates, then validate and normalize the range.

    Returns:
        ``(start, end, error)``. ``error`` is None when the range is usable,
        otherwise the JSON body for a 400 response. Valid dates are
        re-rendered through ``strftime`` so that equivalent spellings such as
        ``2024-1-5`` and ``2024-01-05`` map to the same cache key.
    """
    if not start or not end:
        default_start, default_end = get_default_dates()
        start = start or default_start
        end = end or default_end

    if not validate_date(start) or not validate_date(end):
        return start, end, {
            "error": "Invalid date format. Use YYYY-MM-DD",
            "start": start,
            "end": end,
        }

    start_day = datetime.strptime(start, _DATE_FORMAT)
    end_day = datetime.strptime(end, _DATE_FORMAT)
    start = start_day.strftime(_DATE_FORMAT)
    end = end_day.strftime(_DATE_FORMAT)

    if start_day > end_day:
        return start, end, {
            "error": "Invalid date range: start must not be after end",
            "start": start,
            "end": end,
        }
    return start, end, None


def _ttl_cache_has(cache_key: tuple) -> bool:
    """Return True when the TTL cache holds a live entry for ``cache_key``."""
    ttl_cache = getattr(fetch_data_with_ttl, '_ttl_cache', None)
    entry = ttl_cache.get(cache_key) if ttl_cache else None
    return entry is not None and not entry.is_expired()


def _batch_failure(code: str, message: str) -> Dict:
    """Build the per-code entry reported for a failed batch item."""
    return {
        "code": code,
        "asset_type": AssetTypeDetector.detect(code),
        "error": message,
        "data": [],
        "count": 0,
    }


@app.route('/')
def index():
    """Landing page: describe the API, its endpoints and the cache state."""
    cache_info = get_cache_info()
    return jsonify({
        "name": "Universal Data Fetcher API",
        "version": "1.1.0",
        "description": "统一数据获取服务,支持A股、港股、美股、期货、加密货币",
        "features": [
            "自动资产类型识别",
            "LRU+TTL 双缓存机制",
            "SSH隧道支持(港美股)",
            "批量数据获取"
        ],
        "endpoints": {
            "health": "/health",
            "asset_type": "/api/v1/asset-type?code={code}",
            "ohlcv": "/api/v1/ohlcv?code={code}&start={YYYY-MM-DD}&end={YYYY-MM-DD}",
            "ohlcv_nocache": "/api/v1/ohlcv?code={code}&start={YYYY-MM-DD}&end={YYYY-MM-DD}&nocache=true",
            "batch": "POST /api/v1/ohlcv/batch",
            "cache_clear": "POST /api/v1/cache/clear",
            "cache_stats": "/api/v1/cache/stats",
        },
        "supported_assets": [
            "A股指数 (000300.SH)",
            "A股ETF (510300.SH)",
            "A股股票 (600000.SH)",
            "港股指数 (HSI)",
            "美股指数 (NDX)",
            "美股股票 (AAPL)",
            "期货 (AU.SHF)",
            "加密货币 (BTC)",
        ],
        "cache_config": {
            "maxsize": CACHE_MAXSIZE,
            "ttl_seconds": CACHE_TTL_SECONDS,
            "stats": cache_info['lru_cache']
        },
        "ssh_status": "enabled" if ssh_config and ssh_config.get('enabled') else "disabled",
    })


@app.route('/health')
def health():
    """Health check: report liveness, server time and whether SSH is configured."""
    return jsonify({
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "ssh_configured": ssh_config is not None and ssh_config.get('enabled', False),
    })


@app.route('/api/v1/asset-type')
def detect_asset_type():
    """
    Detect the asset type of a code.

    Query Parameters:
        code: instrument code (required)

    Returns:
        {
            "code": "000300.SH",
            "asset_type": "china_index",
            "description": "A股指数"
        }
    """
    code = request.args.get('code', '').strip()

    if not code:
        return jsonify({
            "error": "Missing required parameter: code",
            "example": "/api/v1/asset-type?code=000300.SH"
        }), 400

    asset_type = AssetTypeDetector.detect(code)

    return jsonify({
        "code": code,
        "asset_type": asset_type,
        "description": _ASSET_TYPE_DESCRIPTIONS.get(asset_type, '未知类型'),
    })


@app.route('/api/v1/ohlcv')
def get_ohlcv():
    """
    Fetch OHLCV data for a single code (served through the LRU+TTL cache).

    Query Parameters:
        code: instrument code (required)
        start: start date, YYYY-MM-DD (optional, default 90 days ago)
        end: end date, YYYY-MM-DD (optional, default today)
        retry: retry count (optional, default 3)
        nocache: skip the cache (optional, default false)

    Returns:
        {
            "code": "000300.SH",
            "asset_type": "china_index",
            "data": [...],
            "count": 58,
            "date_range": {"start": "2024-01-02", "end": "2024-03-29"},
            "cached": false
        }
        400 for missing/invalid parameters, 404 when no data is available,
        500 when the upstream fetch failed.
    """
    code = request.args.get('code', '').strip()
    start = request.args.get('start', '').strip()
    end = request.args.get('end', '').strip()
    retry_count = _coerce_retry(request.args.get('retry', '3'))
    nocache = request.args.get('nocache', 'false').lower() == 'true'

    if not code:
        return jsonify({
            "error": "Missing required parameter: code",
            "example": "/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31"
        }), 400

    start, end, date_error = _resolve_date_range(start, end)
    if date_error is not None:
        return jsonify(date_error), 400

    # Snapshot the cache state before the call so "cached" describes THIS
    # request rather than the lifetime hit counter. Best effort under
    # concurrent requests.
    lru_hits_before = _fetch_data_cached.cache_info().hits
    ttl_hit = _ttl_cache_has((code, start, end, retry_count))

    result = fetch_data_with_ttl(code, start, end, retry_count, nocache)

    if result is None:
        return jsonify({
            "code": code,
            "asset_type": AssetTypeDetector.detect(code),
            "error": "No data available for the specified code and date range",
            "start": start,
            "end": end,
        }), 404

    if "error" in result:
        return jsonify({
            "code": code,
            "asset_type": AssetTypeDetector.detect(code),
            "error": result["error"],
            "start": start,
            "end": end,
        }), 500

    lru_hit = _fetch_data_cached.cache_info().hits > lru_hits_before
    # Copy before annotating: the dict may be the object held by the cache.
    response = dict(result)
    response['cached'] = not nocache and (ttl_hit or lru_hit)

    return jsonify(response)


@app.route('/api/v1/ohlcv/batch', methods=['POST'])
def batch_ohlcv():
    """
    Fetch OHLCV data for several codes in one request.

    Request Body:
        {
            "codes": ["000300.SH", "NDX", "HSI"],
            "start": "2024-01-01",
            "end": "2024-03-31",
            "retry": 3
        }

    Returns:
        {
            "results": {
                "000300.SH": {"data": [...], "count": 58, ...},
                "NDX": {"data": [...], "count": 61, ...},
                "HSI": {"error": "No data available", ...}
            },
            "success_count": 2,
            "failed_count": 1,
            "total": 3
        }
        Codes are stripped and de-duplicated, so ``total`` counts distinct
        codes and always equals ``success_count + failed_count``.
    """
    # silent=True: a malformed body or wrong content type yields the JSON
    # error below instead of Flask's default error page.
    data = request.get_json(silent=True)

    if not data or not isinstance(data, dict):
        return jsonify({
            "error": "Missing request body",
            "example": {
                "codes": ["000300.SH", "NDX"],
                "start": "2024-01-01",
                "end": "2024-03-31",
            }
        }), 400

    codes = data.get('codes', [])
    # JSON values may be null or non-string; coerce before stripping.
    start = str(data.get('start') or '').strip()
    end = str(data.get('end') or '').strip()
    retry = _coerce_retry(data.get('retry', 3))

    if not codes or not isinstance(codes, list):
        return jsonify({
            "error": "Missing or invalid parameter: codes (must be a list)"
        }), 400

    if not all(isinstance(item, str) and item.strip() for item in codes):
        return jsonify({
            "error": "Invalid parameter: codes must contain only non-empty strings"
        }), 400

    # Strip like the single-code endpoint does; drop duplicates, keep order.
    codes = list(dict.fromkeys(item.strip() for item in codes))

    start, end, date_error = _resolve_date_range(start, end)
    if date_error is not None:
        return jsonify(date_error), 400

    data_fetcher = get_fetcher()
    results = {}
    success_count = 0
    failed_count = 0

    try:
        # NOTE(review): fetch_data_with_ttl presumably enters the fetcher
        # context again for every uncached code; confirm that nested
        # enter/exit on UniversalDataFetcher is supported.
        with data_fetcher:
            for code in codes:
                try:
                    result = fetch_data_with_ttl(code, start, end, retry, nocache=False)

                    if result is not None and "error" not in result:
                        results[code] = result
                        success_count += 1
                    else:
                        message = result.get("error", "No data available") if result else "No data available"
                        results[code] = _batch_failure(code, message)
                        failed_count += 1

                except Exception as e:
                    # One failing code must not abort the rest of the batch.
                    results[code] = _batch_failure(code, str(e))
                    failed_count += 1

    except Exception as e:
        return jsonify({
            "error": f"Batch fetch failed: {str(e)}"
        }), 500

    return jsonify({
        "results": results,
        "success_count": success_count,
        "failed_count": failed_count,
        "total": len(codes),
        "start": start,
        "end": end,
    })


@app.route('/api/v1/supported-codes')
def get_supported_codes():
    """
    List example codes for every supported asset type.

    Returns:
        {
            "china_index": {"description": "...", "examples": ["000300.SH", ...]},
            "china_etf": {"description": "...", "examples": ["510300.SH", ...]},
            ...
        }
    """
    return jsonify({
        "china_index": {
            "description": "A股指数",
            "examples": ["000300.SH", "399006.SZ", "000016.SH", "H30269.CSI"],
        },
        "china_etf": {
            "description": "A股ETF",
            "examples": ["510300.SH", "159915.SZ", "510500.SH", "513100.SH"],
        },
        "china_stock": {
            "description": "A股股票",
            "examples": ["600000.SH", "000001.SZ"],
        },
        "hk_index": {
            "description": "港股指数",
            "examples": ["HSI", "HSTECH.HK"],
        },
        "us_index": {
            "description": "美股指数",
            "examples": ["NDX", "SPX", "DJI", "N225", "GDAXI"],
        },
        "us_stock": {
            "description": "美股股票",
            "examples": ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"],
        },
        "futures": {
            "description": "期货合约",
            "examples": ["AU.SHF", "CU.SHF"],
        },
        "crypto": {
            "description": "加密货币",
            "examples": ["BTC", "ETH"],
        },
    })


@app.route('/api/v1/cache/clear', methods=['POST'])
def clear_cache_endpoint():
    """Clear the data cache and report its state before and after."""
    info_before = get_cache_info()
    clear_data_cache()
    return jsonify({
        "message": "Cache cleared successfully",
        "before": info_before,
        "after": get_cache_info()
    })


@app.route('/api/v1/cache/stats')
def cache_stats():
    """Return cache statistics."""
    return jsonify(get_cache_info())


# ============================================================
# Error handlers
# ============================================================

@app.errorhandler(404)
def not_found(error):
    """Answer unknown routes with a JSON body listing the available endpoints."""
    return jsonify({
        "error": "Endpoint not found",
        "available_endpoints": [
            "/",
            "/health",
            "/api/v1/asset-type?code={code}",
            "/api/v1/ohlcv?code={code}&start={YYYY-MM-DD}&end={YYYY-MM-DD}",
            "/api/v1/ohlcv/batch",
            "/api/v1/cache/clear",
            "/api/v1/cache/stats",
            "/api/v1/supported-codes",
        ]
    }), 404


@app.errorhandler(500)
def internal_error(error):
    """Answer unhandled server errors with a JSON body."""
    return jsonify({
        "error": "Internal server error",
        "message": str(error)
    }), 500


# ============================================================
# Server start-up
# ============================================================

def main() -> None:
    """Parse command-line options, print a start-up banner and run the server."""
    global ssh_config

    import argparse

    parser = argparse.ArgumentParser(description='Universal Data Fetcher API Server')
    parser.add_argument('--host', default='0.0.0.0', help='Host to bind (default: 0.0.0.0)')
    parser.add_argument('--port', type=int, default=5000, help='Port to bind (default: 5000)')
    parser.add_argument('--debug', action='store_true', help='Enable debug mode')

    args = parser.parse_args()

    # Load the SSH settings up front so / and /health report them before the
    # first fetch creates the fetcher.
    ssh_config = get_ssh_config()
    if ssh_config and ssh_config.get('enabled'):
        print(f"✓ SSH 隧道已配置: {ssh_config['host']}:{ssh_config['port']}")
    else:
        print("✗ SSH 隧道未启用(仅支持A股数据)")

    if args.debug and args.host not in ('127.0.0.1', 'localhost', '::1'):
        # The Werkzeug debugger lets anyone who can reach it run code.
        print("WARNING: debug mode is enabled on a non-loopback address")

    print("\n🚀 启动 Universal Data Fetcher API Server")
    print(f" Host: {args.host}")
    print(f" Port: {args.port}")
    print(f" Debug: {args.debug}")
    print(f" Cache: maxsize={CACHE_MAXSIZE}, ttl={CACHE_TTL_SECONDS}s")
    print(f"\n📖 API 文档: http://{args.host}:{args.port}/")
    print(f" 健康检查: http://{args.host}:{args.port}/health")
    print(f" 缓存统计: http://{args.host}:{args.port}/api/v1/cache/stats")
    print("\n💡 示例请求:")
    print(f" curl 'http://{args.host}:{args.port}/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31'")
    print("\n")

    app.run(host=args.host, port=args.port, debug=args.debug)


if __name__ == '__main__':
    main()