""" Flask 数据服务 API ================== 提供 RESTful API 接口,支持获取各类资产的 K 线数据 特性: - 分层架构:各资产类型独立实现 - LRU + TTL 双缓存机制 - SSH隧道支持(港美股) - ETF净值获取(计算溢价率) 运行: python datasource/flask_server.py API 文档: GET / - 服务信息 GET /health - 健康检查 GET /api/v1/asset-type - 检测资产类型 GET /api/v1/ohlcv - 获取K线数据 POST /api/v1/ohlcv/batch - 批量获取K线数据 GET /api/v1/etf/nav - 获取ETF净值 POST /api/v1/cache/clear - 清理缓存 GET /api/v1/cache/stats - 缓存统计 """ import os import sys import json from pathlib import Path from datetime import datetime, timedelta from typing import Optional, Dict, Any, List, Tuple from functools import lru_cache # 添加项目根目录到路径 project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) from dotenv import load_dotenv load_dotenv() from flask import Flask, request, jsonify from flask_cors import CORS import pandas as pd from datasource.universal_fetcher import UniversalDataFetcher from datasource.asset_type_detector import AssetTypeDetector, AssetType # ============================================================ # Flask 应用配置 # ============================================================ app = Flask(__name__) CORS(app) # 启用跨域支持 # 全局数据获取器实例 fetcher: Optional[UniversalDataFetcher] = None ssh_config: Optional[Dict] = None # 缓存配置 CACHE_MAXSIZE = int(os.getenv('CACHE_MAXSIZE', '128')) CACHE_TTL_SECONDS = int(os.getenv('CACHE_TTL_SECONDS', '7200')) # 默认2小时 class TimedCacheEntry: """带时间戳的缓存条目""" def __init__(self, data: Any): self.data = data self.timestamp = datetime.now() def is_expired(self) -> bool: return (datetime.now() - self.timestamp).total_seconds() > CACHE_TTL_SECONDS # TTL缓存存储 _ttl_cache: Dict[Tuple, TimedCacheEntry] = {} # ============================================================ # 初始化 # ============================================================ def get_ssh_config() -> Optional[Dict]: """从环境变量获取 SSH 配置""" enabled = os.getenv('SSH_ENABLED', 'false').lower() == 'true' if not enabled: return None return { "enabled": True, "host": os.getenv('SSH_HOST', ''), "port": int(os.getenv('SSH_PORT', '22')), "username": os.getenv('SSH_USERNAME', ''), "key_path": os.getenv('SSH_KEY_PATH', 'config/hk_ecs.pem'), "local_port": int(os.getenv('SSH_LOCAL_PORT', '1080')), } def get_fetcher() -> UniversalDataFetcher: """获取或创建数据获取器实例""" global fetcher, ssh_config if fetcher is None: ssh_config = get_ssh_config() fetcher = UniversalDataFetcher(ssh_config=ssh_config) return fetcher # ============================================================ # 缓存机制 # ============================================================ @lru_cache(maxsize=CACHE_MAXSIZE) def _fetch_data_cached(code: str, start: str, end: str) -> Optional[str]: """ 获取数据的缓存版本 返回 JSON 序列化的字符串 """ f = get_fetcher() try: with f: df = f.fetch(code, start, end) if df is None or len(df) == 0: return None result = dataframe_to_json(df) result['code'] = code result['asset_type'] = AssetTypeDetector.detect(code).value return json.dumps(result) except Exception as e: return json.dumps({"error": str(e)}) def fetch_data_with_ttl( code: str, start: str, end: str, nocache: bool = False ) -> Tuple[Optional[Dict], bool]: """ 获取数据,支持 TTL 缓存 Args: code: 标的代码 start: 开始日期 end: 结束日期 nocache: 是否跳过缓存 Returns: (data, is_cached): 数据和是否命中缓存 """ cache_key = (code, start, end) # 跳过缓存 if nocache: _fetch_data_cached.cache_clear() result_json = _fetch_data_cached(code, start, end) return (json.loads(result_json) if result_json else None, False) # 检查 TTL 缓存 global _ttl_cache if cache_key in _ttl_cache: entry = _ttl_cache[cache_key] if not entry.is_expired(): return entry.data, True # 过期,删除 del _ttl_cache[cache_key] # 从 LRU 缓存获取 result_json = _fetch_data_cached(code, start, end) if result_json is None: return None, False result = json.loads(result_json) # 存入 TTL 缓存 _ttl_cache[cache_key] = TimedCacheEntry(result) return result, False def clear_cache(): """清理所有缓存""" global _ttl_cache _fetch_data_cached.cache_clear() _ttl_cache.clear() def get_cache_info() -> Dict: """获取缓存统计信息""" info = _fetch_data_cached.cache_info() return { "lru_cache": { "hits": info.hits, "misses": info.misses, "maxsize": info.maxsize, "currsize": info.currsize, }, "ttl_cache_size": len(_ttl_cache), "ttl_seconds": CACHE_TTL_SECONDS, } # ============================================================ # DataFrame 转换 # ============================================================ def dataframe_to_json(df: pd.DataFrame) -> Dict: """将 DataFrame 转换为 JSON 可序列化的字典""" if df is None or len(df) == 0: return {"data": [], "count": 0} # 重置索引 df_reset = df.reset_index() # 处理日期列 date_columns = ['date', 'Date', 'index', 'trade_date', 'datetime'] for col in date_columns: if col in df_reset.columns: try: df_reset[col] = pd.to_datetime(df_reset[col]).dt.strftime('%Y-%m-%d') if col != 'date': df_reset = df_reset.rename(columns={col: 'date'}) break except Exception: pass # 转换为字典列表 records = df_reset.to_dict(orient='records') return { "data": records, "count": len(records), "columns": list(df_reset.columns), "date_range": { "start": df.index.min().strftime('%Y-%m-%d') if len(df) > 0 else None, "end": df.index.max().strftime('%Y-%m-%d') if len(df) > 0 else None, } } def validate_date(date_str: str) -> bool: """验证日期格式""" try: datetime.strptime(date_str, '%Y-%m-%d') return True except ValueError: return False def get_default_dates() -> Tuple[str, str]: """获取默认日期范围(最近3个月)""" end = datetime.now() start = end - timedelta(days=90) return start.strftime('%Y-%m-%d'), end.strftime('%Y-%m-%d') # ============================================================ # API 路由 # ============================================================ @app.route('/') def index(): """首页 - API 信息""" return jsonify({ "name": "Universal Data Fetcher API", "version": "2.0.0", "description": "统一数据获取服务(分层架构)", "architecture": "Unified entry + Asset-specific methods", "features": [ "分层架构(各资产类型独立实现)", "LRU + TTL 双缓存机制", "SSH隧道支持(港美股)", "ETF净值获取(计算溢价率)", ], "endpoints": { "info": "/", "health": "/health", "asset_type": "/api/v1/asset-type?code={code}", "ohlcv": "/api/v1/ohlcv?code={code}&start={YYYY-MM-DD}&end={YYYY-MM-DD}", "ohlcv_nocache": "/api/v1/ohlcv?code={code}&nocache=true", "batch": "POST /api/v1/ohlcv/batch", "etf_nav": "/api/v1/etf/nav?code={code}", "cache_clear": "POST /api/v1/cache/clear", "cache_stats": "/api/v1/cache/stats", }, "supported_assets": { "china_index": ["000300.SH", "399006.SZ", "H30269.CSI"], "china_etf": ["159915.SZ", "513100.SH", "518880.SH"], "hk_index": ["HSI", "HSTECH.HK"], "us_index": ["NDX", "SPX", "N225", "GDAXI"], "futures": ["AU.SHF", "CU.SHF", "CL.NYM"], "crypto": ["BTC", "ETH"], }, "cache_config": get_cache_info(), "ssh_status": "enabled" if ssh_config and ssh_config.get('enabled') else "disabled", }) @app.route('/health') def health(): """健康检查""" return jsonify({ "status": "healthy", "timestamp": datetime.now().isoformat(), "ssh_configured": ssh_config is not None and ssh_config.get('enabled', False), }) @app.route('/api/v1/asset-type') def detect_asset_type(): """检测资产类型""" code = request.args.get('code', '').strip() if not code: return jsonify({ "error": "Missing required parameter: code", "example": "/api/v1/asset-type?code=000300.SH" }), 400 asset_type = AssetTypeDetector.detect(code) description = AssetTypeDetector.get_description(asset_type) return jsonify({ "code": code, "asset_type": asset_type.value, "description": description, }) @app.route('/api/v1/ohlcv') def get_ohlcv(): """ 获取单只标的的 OHLCV 数据 Query Parameters: code: 标的代码 (required) start: 开始日期 YYYY-MM-DD (optional, 默认90天前) end: 结束日期 YYYY-MM-DD (optional, 默认今天) nocache: 是否跳过缓存 (optional, 默认false) """ code = request.args.get('code', '').strip() start = request.args.get('start', '').strip() end = request.args.get('end', '').strip() nocache = request.args.get('nocache', 'false').lower() == 'true' # 参数验证 if not code: return jsonify({ "error": "Missing required parameter: code", "example": "/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31" }), 400 # 设置默认日期 if not start or not end: start, end = get_default_dates() # 日期格式验证 if not validate_date(start) or not validate_date(end): return jsonify({ "error": "Invalid date format. Use YYYY-MM-DD", "start": start, "end": end, }), 400 # 使用缓存获取数据 result, is_cached = fetch_data_with_ttl(code, start, end, nocache) if result is None: return jsonify({ "code": code, "asset_type": AssetTypeDetector.detect(code).value, "error": "No data available", "start": start, "end": end, }), 404 if "error" in result: return jsonify({ "code": code, "asset_type": AssetTypeDetector.detect(code).value, "error": result["error"], }), 500 result['cached'] = is_cached return jsonify(result) @app.route('/api/v1/ohlcv/batch', methods=['POST']) def batch_ohlcv(): """批量获取多只标的的 OHLCV 数据""" data = request.get_json() if not data: return jsonify({ "error": "Missing request body", "example": { "codes": ["000300.SH", "NDX"], "start": "2024-01-01", "end": "2024-03-31", } }), 400 codes = data.get('codes', []) start = data.get('start', '').strip() end = data.get('end', '').strip() if not codes or not isinstance(codes, list): return jsonify({ "error": "Missing or invalid parameter: codes (must be a list)" }), 400 if not start or not end: start, end = get_default_dates() if not validate_date(start) or not validate_date(end): return jsonify({ "error": "Invalid date format. Use YYYY-MM-DD", }), 400 # 获取数据 f = get_fetcher() results = {} success_count = 0 failed_count = 0 try: with f: for code in codes: result, _ = fetch_data_with_ttl(code, start, end) if result is not None and "error" not in result: results[code] = result success_count += 1 else: results[code] = { "code": code, "asset_type": AssetTypeDetector.detect(code).value, "error": result.get("error", "No data") if result else "No data", "data": [], "count": 0, } failed_count += 1 except Exception as e: return jsonify({"error": f"Batch fetch failed: {str(e)}"}), 500 return jsonify({ "results": results, "success_count": success_count, "failed_count": failed_count, "total": len(codes), "start": start, "end": end, }) @app.route('/api/v1/etf/nav') def get_etf_nav(): """获取ETF净值数据(用于计算溢价率)""" code = request.args.get('code', '').strip() start = request.args.get('start', '').strip() end = request.args.get('end', '').strip() if not code: return jsonify({ "error": "Missing required parameter: code", "example": "/api/v1/etf/nav?code=513100.SH" }), 400 if not start or not end: start, end = get_default_dates() # 检查是否为ETF asset_type = AssetTypeDetector.detect(code) if asset_type != AssetType.CHINA_ETF: return jsonify({ "error": f"Not an ETF: {code} (type: {asset_type.value})", "hint": "Only A股ETF (codes starting with 51/52/15/16) supported", }), 400 # 获取净值和溢价率 f = get_fetcher() try: with f: price_df, nav_df, premium_series = f.fetch_etf_with_nav(code, start, end) result = { "code": code, "price": dataframe_to_json(price_df) if price_df else {"data": [], "count": 0}, "nav": dataframe_to_json(nav_df) if nav_df else {"data": [], "count": 0}, } # 添加历史溢价率序列 if premium_series is not None and len(premium_series) > 0: # 转换为日期-溢价率列表 premium_data = [ {"date": date.strftime('%Y-%m-%d'), "premium": round(premium, 6)} for date, premium in premium_series.items() ] result['premium_series'] = premium_data # 最新溢价率 latest_premium = premium_series.iloc[-1] latest_date = premium_series.index[-1].strftime('%Y-%m-%d') result['latest_premium'] = round(latest_premium, 6) result['premium_date'] = latest_date # 溢价率统计 result['premium_stats'] = { "mean": round(premium_series.mean(), 6), "std": round(premium_series.std(), 6), "min": round(premium_series.min(), 6), "max": round(premium_series.max(), 6), "median": round(premium_series.median(), 6), } return jsonify(result) except Exception as e: return jsonify({"error": str(e)}), 500 @app.route('/api/v1/cache/clear', methods=['POST']) def clear_cache_endpoint(): """清理缓存""" info_before = get_cache_info() clear_cache() return jsonify({ "message": "Cache cleared successfully", "before": info_before, "after": get_cache_info() }) @app.route('/api/v1/cache/stats') def cache_stats(): """获取缓存统计""" return jsonify(get_cache_info()) # ============================================================ # 错误处理 # ============================================================ @app.errorhandler(404) def not_found(error): return jsonify({ "error": "Endpoint not found", "available_endpoints": [ "/", "/health", "/api/v1/asset-type", "/api/v1/ohlcv", "/api/v1/ohlcv/batch", "/api/v1/etf/nav", "/api/v1/cache/clear", "/api/v1/cache/stats", ] }), 404 @app.errorhandler(500) def internal_error(error): return jsonify({ "error": "Internal server error", "message": str(error) }), 500 # ============================================================ # 启动服务 # ============================================================ if __name__ == '__main__': import argparse parser = argparse.ArgumentParser(description='Universal Data Fetcher API Server') parser.add_argument('--host', default='0.0.0.0', help='Host to bind') parser.add_argument('--port', type=int, default=5000, help='Port to bind') parser.add_argument('--debug', action='store_true', help='Enable debug mode') args = parser.parse_args() # 预加载 SSH 配置 ssh_config = get_ssh_config() if ssh_config and ssh_config.get('enabled'): print(f"✓ SSH 隧道已配置: {ssh_config['host']}:{ssh_config['port']}") else: print("✗ SSH 隧道未启用(仅支持A股数据)") print(f"\n🚀 Universal Data Fetcher API Server v2.0") print(f" Host: {args.host}") print(f" Port: {args.port}") print(f" Cache: LRU({CACHE_MAXSIZE}) + TTL({CACHE_TTL_SECONDS}s)") print(f"\n📖 API: http://{args.host}:{args.port}/") print(f" 健康检查: http://{args.host}:{args.port}/health") app.run(host=args.host, port=args.port, debug=args.debug)