From 4e3aac5e0eb4100bd9d599578fc403efb133b349 Mon Sep 17 00:00:00 2001 From: aszerW Date: Tue, 12 May 2026 21:33:19 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20Flask=E7=BB=9F=E4=B8=80=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E6=9C=8D=E5=8A=A1=E8=BF=81=E7=A7=BB=EF=BC=88=E5=88=86?= =?UTF-8?q?=E5=B1=82=E6=9E=B6=E6=9E=84=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 架构设计: - 对外统一接口 fetch():自动识别资产类型并路由 - 对内分层实现:各资产类型独立方法,职责单一 新增文件: - datasource/universal_fetcher.py: 统一数据获取器 - _fetch_china_index: A股指数(Tushare) - _fetch_china_etf: A股ETF(含净值) - _fetch_us_index: 美股指数(YFinance+SSH) - _fetch_hk_index: 港股指数(YFinance+SSH) - _fetch_futures: 期货(Tushare/YFinance) - fetch_etf_with_nav: ETF价格+净值(计算溢价率) - datasource/asset_type_detector.py: 资产类型检测器 - AssetType枚举:9种资产类型 - detect(): 自动识别资产类型 - group_by_type(): 批量分组 - datasource/flask_server.py: Flask API服务 - LRU + TTL 双缓存机制 - 8个API端点:ohlcv、etf/nav、batch、cache等 更新: - datasource/__init__.py: 导出新模块 验证: - 模块导入成功 - 资产类型检测正确 - A股数据获取正常(沪深300: 5条) --- datasource/__init__.py | 15 + datasource/asset_type_detector.py | 219 +++++++ datasource/flask_server.py | 588 ++++++++++++++++++ datasource/universal_fetcher.py | 322 ++++++++++ .../data/__pycache__/__init__.cpython-312.pyc | Bin 5476 -> 5476 bytes 5 files changed, 1144 insertions(+) create mode 100644 datasource/asset_type_detector.py create mode 100644 datasource/flask_server.py create mode 100644 datasource/universal_fetcher.py diff --git a/datasource/__init__.py b/datasource/__init__.py index 46cc232..9b9da9d 100644 --- a/datasource/__init__.py +++ b/datasource/__init__.py @@ -4,16 +4,31 @@ 核心数据获取能力: - A股数据:Tushare(指数、ETF、期货) - 境外数据:YFinance(港股、美股)通过SSH隧道 + +架构设计: +- 分层架构:对外统一接口,对内各资产类型独立实现 +- Flask API:LRU + TTL 双缓存机制 + +用法: + from datasource import UniversalDataFetcher, AssetType + + fetcher = UniversalDataFetcher() + df = fetcher.fetch("000300.SH", "2024-01-01", "2024-12-31") """ from .ssh_tunnel import SSHTunnelManager from .tushare_source import TushareSource from .yfinance_source import YFinanceSource from .hybrid_source import HybridDataSource +from .asset_type_detector import AssetTypeDetector, AssetType +from .universal_fetcher import UniversalDataFetcher __all__ = [ 'SSHTunnelManager', 'TushareSource', 'YFinanceSource', 'HybridDataSource', + 'AssetTypeDetector', + 'AssetType', + 'UniversalDataFetcher', ] \ No newline at end of file diff --git a/datasource/asset_type_detector.py b/datasource/asset_type_detector.py new file mode 100644 index 0000000..6c8e329 --- /dev/null +++ b/datasource/asset_type_detector.py @@ -0,0 +1,219 @@ +""" +资产类型检测器 + +根据代码格式自动识别资产类型,支持: +- A股指数/ETF/股票 +- 港股指数/股票 +- 美股指数/股票 +- 期货合约 +- 加密货币 + +用法: + from datasource.asset_type_detector import AssetTypeDetector, AssetType + + # 检测资产类型 + asset_type = AssetTypeDetector.detect("000300.SH") # AssetType.CHINA_INDEX + + # 获取描述 + desc = AssetTypeDetector.get_description(asset_type) # "A股指数" + + # 批量分组 + grouped = AssetTypeDetector.group_by_type(["000300.SH", "NDX", "AU.SHF"]) +""" + +from enum import Enum +from typing import Dict, List + + +class AssetType(Enum): + """资产类型枚举""" + CHINA_INDEX = "china_index" # A股指数 + CHINA_ETF = "china_etf" # A股ETF + CHINA_STOCK = "china_stock" # A股股票 + HK_INDEX = "hk_index" # 港股指数 + HK_STOCK = "hk_stock" # 港股股票 + US_INDEX = "us_index" # 美股指数 + US_STOCK = "us_stock" # 美股股票 + FUTURES = "futures" # 期货合约 + CRYPTO = "crypto" # 加密货币 + UNKNOWN = "unknown" # 未知类型 + + +class AssetTypeDetector: + """ + 资产类型检测器 + + 根据代码格式自动识别资产类型 + """ + + # A股后缀 + CHINA_SUFFIXES = ('.SH', '.SZ', '.SS', '.CSI') + + # 中国期货后缀 + CHINA_FUTURES_SUFFIXES = ('.SHF', '.DCE', '.CZC', '.INE', '.GFEX') + + # 境外期货后缀 + FOREIGN_FUTURES_SUFFIXES = ('.NYM', '.ICE', '.CME', '.CBT') + + # 港股后缀 + HK_SUFFIXES = ('.HK',) + + # 加密货币代码 + CRYPTO_CODES = {'BTC', 'ETH', 'SOL', 'BNB', 'XRP', 'ADA', 'DOGE'} + + # 特殊指数代码(无后缀) + SPECIAL_INDEX_CODES = { + 'HSI': AssetType.HK_INDEX, # 恒生指数 + 'NDX': AssetType.US_INDEX, # 纳斯达克100 + 'SPX': AssetType.US_INDEX, # 标普500 + 'DJI': AssetType.US_INDEX, # 道琼斯 + 'N225': AssetType.US_INDEX, # 日经225(日本) + 'GDAXI': AssetType.US_INDEX, # 德国DAX(欧洲) + 'HSCCI': AssetType.HK_INDEX, # 恒生国企指数 + 'HSCEI': AssetType.HK_INDEX, # 恒生中企指数 + } + + # YFinance映射(用于判断是否为指数) + YFINANCE_INDEX_MAP = { + "HSTECH.HK": "3033.HK", + "HSI": "^HSI", + "NDX": "^NDX", + "SPX": "^GSPC", + "DJI": "^DJI", + "N225": "^N225", + "GDAXI": "^GDAXI", + "CL.NYM": "CL=F", + } + + # 类型描述映射 + DESCRIPTIONS = { + AssetType.CHINA_INDEX: "A股指数", + AssetType.CHINA_ETF: "A股ETF", + AssetType.CHINA_STOCK: "A股股票", + AssetType.HK_INDEX: "港股指数", + AssetType.HK_STOCK: "港股股票", + AssetType.US_INDEX: "美股指数", + AssetType.US_STOCK: "美股股票", + AssetType.FUTURES: "期货合约", + AssetType.CRYPTO: "加密货币", + AssetType.UNKNOWN: "未知类型", + } + + @classmethod + def detect(cls, code: str) -> AssetType: + """ + 检测资产类型 + + Args: + code: 标的代码 + + Returns: + AssetType 枚举值 + """ + code = code.strip().upper() + + # 1. 加密货币(优先判断) + if code in cls.CRYPTO_CODES: + return AssetType.CRYPTO + + # 2. 特殊指数代码(无后缀) + if code in cls.SPECIAL_INDEX_CODES: + return cls.SPECIAL_INDEX_CODES[code] + + # 3. YFinance映射中的代码(指数) + if code in cls.YFINANCE_INDEX_MAP: + yf_code = cls.YFINANCE_INDEX_MAP[code] + if yf_code.startswith('^'): + return AssetType.US_INDEX + elif yf_code.endswith('.HK'): + return AssetType.HK_INDEX + + # 4. 期货后缀 + if any(code.endswith(suffix) for suffix in cls.CHINA_FUTURES_SUFFIXES): + return AssetType.FUTURES + if any(code.endswith(suffix) for suffix in cls.FOREIGN_FUTURES_SUFFIXES): + return AssetType.FUTURES + + # 5. 港股后缀 + if code.endswith('.HK'): + # 4位数字.HK通常是指数 + code_body = code.split('.')[0] + if code_body.isdigit() and len(code_body) == 4: + return AssetType.HK_INDEX + return AssetType.HK_STOCK + + # 6. A股后缀 + if any(code.endswith(suffix) for suffix in cls.CHINA_SUFFIXES): + return cls._classify_china_asset(code) + + # 7. 默认:美股股票 + return AssetType.US_STOCK + + @classmethod + def _classify_china_asset(cls, code: str) -> AssetType: + """ + 细分A股资产类型 + + 规则: + - .CSI 后缀:中证指数 + - 指数代码段: 000, 001, 002, 399, 930, 931, 932 + - ETF代码段: 51, 52, 56, 58, 15, 16 + - 股票: 其他 + + Args: + code: A股代码(含后缀) + + Returns: + AssetType + """ + # .CSI 后缀直接判定为指数 + if code.endswith('.CSI'): + return AssetType.CHINA_INDEX + + # 提取代码主体 + code_body = code.split('.')[0] + + # 检查是否为6位数字 + if not code_body.isdigit() or len(code_body) != 6: + return AssetType.CHINA_STOCK + + # 特殊情况:000001 是平安银行(股票) + if code_body == '000001': + return AssetType.CHINA_STOCK + + # ETF代码段 + etf_prefixes = ('51', '52', '56', '58', '15', '16') + if code_body.startswith(etf_prefixes): + return AssetType.CHINA_ETF + + # 指数代码段 + index_prefixes = ('000', '001', '002', '399', '930', '931', '932') + if code_body.startswith(index_prefixes): + return AssetType.CHINA_INDEX + + # 默认为股票 + return AssetType.CHINA_STOCK + + @classmethod + def get_description(cls, asset_type: AssetType) -> str: + """获取资产类型描述""" + return cls.DESCRIPTIONS.get(asset_type, "未知类型") + + @classmethod + def group_by_type(cls, codes: List[str]) -> Dict[AssetType, List[str]]: + """ + 按资产类型分组 + + Args: + codes: 代码列表 + + Returns: + {AssetType: [codes]} + """ + grouped = {} + for code in codes: + asset_type = cls.detect(code) + if asset_type not in grouped: + grouped[asset_type] = [] + grouped[asset_type].append(code) + return grouped \ No newline at end of file diff --git a/datasource/flask_server.py b/datasource/flask_server.py new file mode 100644 index 0000000..5e8458d --- /dev/null +++ b/datasource/flask_server.py @@ -0,0 +1,588 @@ +""" +Flask 数据服务 API +================== +提供 RESTful API 接口,支持获取各类资产的 K 线数据 + +特性: +- 分层架构:各资产类型独立实现 +- LRU + TTL 双缓存机制 +- SSH隧道支持(港美股) +- ETF净值获取(计算溢价率) + +运行: + python datasource/flask_server.py + +API 文档: + GET / - 服务信息 + GET /health - 健康检查 + GET /api/v1/asset-type - 检测资产类型 + GET /api/v1/ohlcv - 获取K线数据 + POST /api/v1/ohlcv/batch - 批量获取K线数据 + GET /api/v1/etf/nav - 获取ETF净值 + POST /api/v1/cache/clear - 清理缓存 + GET /api/v1/cache/stats - 缓存统计 +""" + +import os +import sys +import json +from pathlib import Path +from datetime import datetime, timedelta +from typing import Optional, Dict, Any, List, Tuple +from functools import lru_cache + +# 添加项目根目录到路径 +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +from dotenv import load_dotenv +load_dotenv() + +from flask import Flask, request, jsonify +from flask_cors import CORS +import pandas as pd + +from datasource.universal_fetcher import UniversalDataFetcher +from datasource.asset_type_detector import AssetTypeDetector, AssetType + + +# ============================================================ +# Flask 应用配置 +# ============================================================ + +app = Flask(__name__) +CORS(app) # 启用跨域支持 + +# 全局数据获取器实例 +fetcher: Optional[UniversalDataFetcher] = None +ssh_config: Optional[Dict] = None + +# 缓存配置 +CACHE_MAXSIZE = int(os.getenv('CACHE_MAXSIZE', '128')) +CACHE_TTL_SECONDS = int(os.getenv('CACHE_TTL_SECONDS', '7200')) # 默认2小时 + + +class TimedCacheEntry: + """带时间戳的缓存条目""" + def __init__(self, data: Any): + self.data = data + self.timestamp = datetime.now() + + def is_expired(self) -> bool: + return (datetime.now() - self.timestamp).total_seconds() > CACHE_TTL_SECONDS + + +# TTL缓存存储 +_ttl_cache: Dict[Tuple, TimedCacheEntry] = {} + + +# ============================================================ +# 初始化 +# ============================================================ + +def get_ssh_config() -> Optional[Dict]: + """从环境变量获取 SSH 配置""" + enabled = os.getenv('SSH_ENABLED', 'false').lower() == 'true' + + if not enabled: + return None + + return { + "enabled": True, + "host": os.getenv('SSH_HOST', ''), + "port": int(os.getenv('SSH_PORT', '22')), + "username": os.getenv('SSH_USERNAME', ''), + "key_path": os.getenv('SSH_KEY_PATH', 'hk_ecs.pem'), + "local_port": int(os.getenv('SSH_LOCAL_PORT', '1080')), + } + + +def get_fetcher() -> UniversalDataFetcher: + """获取或创建数据获取器实例""" + global fetcher, ssh_config + + if fetcher is None: + ssh_config = get_ssh_config() + fetcher = UniversalDataFetcher(ssh_config=ssh_config) + + return fetcher + + +# ============================================================ +# 缓存机制 +# ============================================================ + +@lru_cache(maxsize=CACHE_MAXSIZE) +def _fetch_data_cached(code: str, start: str, end: str) -> Optional[str]: + """ + 获取数据的缓存版本 + 返回 JSON 序列化的字符串 + """ + f = get_fetcher() + + try: + with f: + df = f.fetch(code, start, end) + + if df is None or len(df) == 0: + return None + + result = dataframe_to_json(df) + result['code'] = code + result['asset_type'] = AssetTypeDetector.detect(code).value + + return json.dumps(result) + except Exception as e: + return json.dumps({"error": str(e)}) + + +def fetch_data_with_ttl( + code: str, + start: str, + end: str, + nocache: bool = False +) -> Tuple[Optional[Dict], bool]: + """ + 获取数据,支持 TTL 缓存 + + Args: + code: 标的代码 + start: 开始日期 + end: 结束日期 + nocache: 是否跳过缓存 + + Returns: + (data, is_cached): 数据和是否命中缓存 + """ + cache_key = (code, start, end) + + # 跳过缓存 + if nocache: + _fetch_data_cached.cache_clear() + result_json = _fetch_data_cached(code, start, end) + return (json.loads(result_json) if result_json else None, False) + + # 检查 TTL 缓存 + global _ttl_cache + if cache_key in _ttl_cache: + entry = _ttl_cache[cache_key] + if not entry.is_expired(): + return entry.data, True + # 过期,删除 + del _ttl_cache[cache_key] + + # 从 LRU 缓存获取 + result_json = _fetch_data_cached(code, start, end) + + if result_json is None: + return None, False + + result = json.loads(result_json) + + # 存入 TTL 缓存 + _ttl_cache[cache_key] = TimedCacheEntry(result) + + return result, False + + +def clear_cache(): + """清理所有缓存""" + global _ttl_cache + _fetch_data_cached.cache_clear() + _ttl_cache.clear() + + +def get_cache_info() -> Dict: + """获取缓存统计信息""" + info = _fetch_data_cached.cache_info() + return { + "lru_cache": { + "hits": info.hits, + "misses": info.misses, + "maxsize": info.maxsize, + "currsize": info.currsize, + }, + "ttl_cache_size": len(_ttl_cache), + "ttl_seconds": CACHE_TTL_SECONDS, + } + + +# ============================================================ +# DataFrame 转换 +# ============================================================ + +def dataframe_to_json(df: pd.DataFrame) -> Dict: + """将 DataFrame 转换为 JSON 可序列化的字典""" + if df is None or len(df) == 0: + return {"data": [], "count": 0} + + # 重置索引 + df_reset = df.reset_index() + + # 处理日期列 + date_columns = ['date', 'Date', 'index', 'trade_date', 'datetime'] + for col in date_columns: + if col in df_reset.columns: + try: + df_reset[col] = pd.to_datetime(df_reset[col]).dt.strftime('%Y-%m-%d') + if col != 'date': + df_reset = df_reset.rename(columns={col: 'date'}) + break + except Exception: + pass + + # 转换为字典列表 + records = df_reset.to_dict(orient='records') + + return { + "data": records, + "count": len(records), + "columns": list(df_reset.columns), + "date_range": { + "start": df.index.min().strftime('%Y-%m-%d') if len(df) > 0 else None, + "end": df.index.max().strftime('%Y-%m-%d') if len(df) > 0 else None, + } + } + + +def validate_date(date_str: str) -> bool: + """验证日期格式""" + try: + datetime.strptime(date_str, '%Y-%m-%d') + return True + except ValueError: + return False + + +def get_default_dates() -> Tuple[str, str]: + """获取默认日期范围(最近3个月)""" + end = datetime.now() + start = end - timedelta(days=90) + return start.strftime('%Y-%m-%d'), end.strftime('%Y-%m-%d') + + +# ============================================================ +# API 路由 +# ============================================================ + +@app.route('/') +def index(): + """首页 - API 信息""" + return jsonify({ + "name": "Universal Data Fetcher API", + "version": "2.0.0", + "description": "统一数据获取服务(分层架构)", + "architecture": "Unified entry + Asset-specific methods", + "features": [ + "分层架构(各资产类型独立实现)", + "LRU + TTL 双缓存机制", + "SSH隧道支持(港美股)", + "ETF净值获取(计算溢价率)", + ], + "endpoints": { + "info": "/", + "health": "/health", + "asset_type": "/api/v1/asset-type?code={code}", + "ohlcv": "/api/v1/ohlcv?code={code}&start={YYYY-MM-DD}&end={YYYY-MM-DD}", + "ohlcv_nocache": "/api/v1/ohlcv?code={code}&nocache=true", + "batch": "POST /api/v1/ohlcv/batch", + "etf_nav": "/api/v1/etf/nav?code={code}", + "cache_clear": "POST /api/v1/cache/clear", + "cache_stats": "/api/v1/cache/stats", + }, + "supported_assets": { + "china_index": ["000300.SH", "399006.SZ", "H30269.CSI"], + "china_etf": ["159915.SZ", "513100.SH", "518880.SH"], + "hk_index": ["HSI", "HSTECH.HK"], + "us_index": ["NDX", "SPX", "N225", "GDAXI"], + "futures": ["AU.SHF", "CU.SHF", "CL.NYM"], + "crypto": ["BTC", "ETH"], + }, + "cache_config": get_cache_info(), + "ssh_status": "enabled" if ssh_config and ssh_config.get('enabled') else "disabled", + }) + + +@app.route('/health') +def health(): + """健康检查""" + return jsonify({ + "status": "healthy", + "timestamp": datetime.now().isoformat(), + "ssh_configured": ssh_config is not None and ssh_config.get('enabled', False), + }) + + +@app.route('/api/v1/asset-type') +def detect_asset_type(): + """检测资产类型""" + code = request.args.get('code', '').strip() + + if not code: + return jsonify({ + "error": "Missing required parameter: code", + "example": "/api/v1/asset-type?code=000300.SH" + }), 400 + + asset_type = AssetTypeDetector.detect(code) + description = AssetTypeDetector.get_description(asset_type) + + return jsonify({ + "code": code, + "asset_type": asset_type.value, + "description": description, + }) + + +@app.route('/api/v1/ohlcv') +def get_ohlcv(): + """ + 获取单只标的的 OHLCV 数据 + + Query Parameters: + code: 标的代码 (required) + start: 开始日期 YYYY-MM-DD (optional, 默认90天前) + end: 结束日期 YYYY-MM-DD (optional, 默认今天) + nocache: 是否跳过缓存 (optional, 默认false) + """ + code = request.args.get('code', '').strip() + start = request.args.get('start', '').strip() + end = request.args.get('end', '').strip() + nocache = request.args.get('nocache', 'false').lower() == 'true' + + # 参数验证 + if not code: + return jsonify({ + "error": "Missing required parameter: code", + "example": "/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31" + }), 400 + + # 设置默认日期 + if not start or not end: + start, end = get_default_dates() + + # 日期格式验证 + if not validate_date(start) or not validate_date(end): + return jsonify({ + "error": "Invalid date format. Use YYYY-MM-DD", + "start": start, + "end": end, + }), 400 + + # 使用缓存获取数据 + result, is_cached = fetch_data_with_ttl(code, start, end, nocache) + + if result is None: + return jsonify({ + "code": code, + "asset_type": AssetTypeDetector.detect(code).value, + "error": "No data available", + "start": start, + "end": end, + }), 404 + + if "error" in result: + return jsonify({ + "code": code, + "asset_type": AssetTypeDetector.detect(code).value, + "error": result["error"], + }), 500 + + result['cached'] = is_cached + return jsonify(result) + + +@app.route('/api/v1/ohlcv/batch', methods=['POST']) +def batch_ohlcv(): + """批量获取多只标的的 OHLCV 数据""" + data = request.get_json() + + if not data: + return jsonify({ + "error": "Missing request body", + "example": { + "codes": ["000300.SH", "NDX"], + "start": "2024-01-01", + "end": "2024-03-31", + } + }), 400 + + codes = data.get('codes', []) + start = data.get('start', '').strip() + end = data.get('end', '').strip() + + if not codes or not isinstance(codes, list): + return jsonify({ + "error": "Missing or invalid parameter: codes (must be a list)" + }), 400 + + if not start or not end: + start, end = get_default_dates() + + if not validate_date(start) or not validate_date(end): + return jsonify({ + "error": "Invalid date format. Use YYYY-MM-DD", + }), 400 + + # 获取数据 + f = get_fetcher() + results = {} + success_count = 0 + failed_count = 0 + + try: + with f: + for code in codes: + result, _ = fetch_data_with_ttl(code, start, end) + + if result is not None and "error" not in result: + results[code] = result + success_count += 1 + else: + results[code] = { + "code": code, + "asset_type": AssetTypeDetector.detect(code).value, + "error": result.get("error", "No data") if result else "No data", + "data": [], + "count": 0, + } + failed_count += 1 + except Exception as e: + return jsonify({"error": f"Batch fetch failed: {str(e)}"}), 500 + + return jsonify({ + "results": results, + "success_count": success_count, + "failed_count": failed_count, + "total": len(codes), + "start": start, + "end": end, + }) + + +@app.route('/api/v1/etf/nav') +def get_etf_nav(): + """获取ETF净值数据(用于计算溢价率)""" + code = request.args.get('code', '').strip() + start = request.args.get('start', '').strip() + end = request.args.get('end', '').strip() + + if not code: + return jsonify({ + "error": "Missing required parameter: code", + "example": "/api/v1/etf/nav?code=513100.SH" + }), 400 + + if not start or not end: + start, end = get_default_dates() + + # 检查是否为ETF + asset_type = AssetTypeDetector.detect(code) + if asset_type != AssetType.CHINA_ETF: + return jsonify({ + "error": f"Not an ETF: {code} (type: {asset_type.value})", + "hint": "Only A股ETF (codes starting with 51/52/15/16) supported", + }), 400 + + # 获取净值 + f = get_fetcher() + try: + with f: + price_df, nav_df = f.fetch_etf_with_nav(code, start, end) + + result = { + "code": code, + "price": dataframe_to_json(price_df) if price_df else {"data": [], "count": 0}, + "nav": dataframe_to_json(nav_df) if nav_df else {"data": [], "count": 0}, + } + + # 计算最新溢价率 + if nav_df is not None and len(nav_df) > 0 and price_df is not None and len(price_df) > 0: + latest_nav = nav_df['nav'].iloc[-1] + latest_price = price_df['close'].iloc[-1] + if latest_nav > 0: + premium = (latest_price - latest_nav) / latest_nav + result['premium_rate'] = premium + result['premium_date'] = nav_df.index[-1].strftime('%Y-%m-%d') + + return jsonify(result) + + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route('/api/v1/cache/clear', methods=['POST']) +def clear_cache_endpoint(): + """清理缓存""" + info_before = get_cache_info() + clear_cache() + return jsonify({ + "message": "Cache cleared successfully", + "before": info_before, + "after": get_cache_info() + }) + + +@app.route('/api/v1/cache/stats') +def cache_stats(): + """获取缓存统计""" + return jsonify(get_cache_info()) + + +# ============================================================ +# 错误处理 +# ============================================================ + +@app.errorhandler(404) +def not_found(error): + return jsonify({ + "error": "Endpoint not found", + "available_endpoints": [ + "/", "/health", + "/api/v1/asset-type", + "/api/v1/ohlcv", + "/api/v1/ohlcv/batch", + "/api/v1/etf/nav", + "/api/v1/cache/clear", + "/api/v1/cache/stats", + ] + }), 404 + + +@app.errorhandler(500) +def internal_error(error): + return jsonify({ + "error": "Internal server error", + "message": str(error) + }), 500 + + +# ============================================================ +# 启动服务 +# ============================================================ + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser(description='Universal Data Fetcher API Server') + parser.add_argument('--host', default='0.0.0.0', help='Host to bind') + parser.add_argument('--port', type=int, default=5000, help='Port to bind') + parser.add_argument('--debug', action='store_true', help='Enable debug mode') + + args = parser.parse_args() + + # 预加载 SSH 配置 + ssh_config = get_ssh_config() + if ssh_config and ssh_config.get('enabled'): + print(f"✓ SSH 隧道已配置: {ssh_config['host']}:{ssh_config['port']}") + else: + print("✗ SSH 隧道未启用(仅支持A股数据)") + + print(f"\n🚀 Universal Data Fetcher API Server v2.0") + print(f" Host: {args.host}") + print(f" Port: {args.port}") + print(f" Cache: LRU({CACHE_MAXSIZE}) + TTL({CACHE_TTL_SECONDS}s)") + print(f"\n📖 API: http://{args.host}:{args.port}/") + print(f" 健康检查: http://{args.host}:{args.port}/health") + + app.run(host=args.host, port=args.port, debug=args.debug) \ No newline at end of file diff --git a/datasource/universal_fetcher.py b/datasource/universal_fetcher.py new file mode 100644 index 0000000..6970541 --- /dev/null +++ b/datasource/universal_fetcher.py @@ -0,0 +1,322 @@ +""" +统一数据获取器 + +分层架构:对外统一接口,对内按资产类型独立实现 +支持:A股指数/ETF、港股指数、美股指数、期货、加密货币 + +用法: + from datasource import UniversalDataFetcher + + fetcher = UniversalDataFetcher() + + # 单标的获取(自动识别类型) + df = fetcher.fetch("000300.SH", "2024-01-01", "2024-12-31") + + # ETF获取(含净值) + price_df, nav_df = fetcher.fetch_etf_with_nav("513100.SH", "2024-01-01", "2024-12-31") + + # 批量获取 + results = fetcher.fetch_batch(["000300.SH", "NDX", "N225"], "2024-01-01", "2024-12-31") +""" + +import os +import time +from typing import Optional, Dict, List, Tuple +from datetime import datetime +import pandas as pd + +from .tushare_source import TushareSource +from .yfinance_source import YFinanceSource +from .ssh_tunnel import SSHTunnelManager +from .asset_type_detector import AssetTypeDetector, AssetType + + +class UniversalDataFetcher: + """ + 统一数据获取器 + + 分层架构: + - 对外:统一 fetch() 接口,自动路由 + - 对内:各资产类型独立方法,职责单一 + """ + + def __init__( + self, + ssh_config: Optional[Dict] = None, + use_cache: bool = True, + cache_dir: str = "data/etf_cache/daily" + ): + """ + 初始化 + + Args: + ssh_config: SSH隧道配置(用于港美股) + use_cache: 是否使用本地缓存 + cache_dir: 缓存目录 + """ + self.ssh_config = ssh_config or {} + self.use_cache = use_cache + self.cache_dir = cache_dir + + # 数据源实例 + self._tushare = TushareSource() + self._yfinance = YFinanceSource() + + # SSH隧道(延迟初始化) + self._tunnel: Optional[SSHTunnelManager] = None + self._tunnel_started = False + + def __enter__(self): + """上下文管理器入口""" + self._start_tunnel() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """上下文管理器退出""" + self._stop_tunnel() + + # ============================================================ + # SSH隧道管理 + # ============================================================ + + def _start_tunnel(self) -> bool: + """启动SSH隧道""" + if self._tunnel_started: + return True + + if self.ssh_config.get('enabled'): + self._tunnel = SSHTunnelManager(self.ssh_config) + if self._tunnel.start(): + self._tunnel_started = True + return True + return False + return True + + def _stop_tunnel(self): + """停止SSH隧道""" + if self._tunnel: + self._tunnel.stop() + self._tunnel = None + self._tunnel_started = False + + # ============================================================ + # 统一入口(自动路由) + # ============================================================ + + def fetch( + self, + code: str, + start_date: str, + end_date: str, + retry: int = 3 + ) -> Optional[pd.DataFrame]: + """ + 统一数据获取入口 + + 自动识别资产类型并路由到对应方法 + + Args: + code: 标的代码 + start_date: 开始日期 'YYYY-MM-DD' + end_date: 结束日期 'YYYY-MM-DD' + retry: 重试次数 + + Returns: + DataFrame with columns: date, open, high, low, close, volume + """ + asset_type = AssetTypeDetector.detect(code) + + for attempt in range(retry): + try: + # 路由到具体方法 + if asset_type == AssetType.CHINA_INDEX: + return self._fetch_china_index(code, start_date, end_date) + elif asset_type == AssetType.CHINA_ETF: + return self._fetch_china_etf(code, start_date, end_date) + elif asset_type == AssetType.US_INDEX: + return self._fetch_us_index(code, start_date, end_date) + elif asset_type == AssetType.HK_INDEX: + return self._fetch_hk_index(code, start_date, end_date) + elif asset_type == AssetType.FUTURES: + return self._fetch_futures(code, start_date, end_date) + elif asset_type == AssetType.CRYPTO: + return self._fetch_crypto(code, start_date, end_date) + else: + print(f"⚠️ 未知资产类型: {code} -> {asset_type}") + return None + + except Exception as e: + if attempt < retry - 1: + time.sleep(2) + else: + print(f"✗ 获取 {code} 失败 (尝试 {attempt+1}/{retry}): {e}") + return None + + return None + + # ============================================================ + # 分层实现:各资产类型独立方法 + # ============================================================ + + def _fetch_china_index( + self, + code: str, + start_date: str, + end_date: str + ) -> Optional[pd.DataFrame]: + """ + 获取A股指数 + + 特点:Tushare API,无需SSH隧道 + """ + return self._tushare.fetch_index(code, start_date, end_date) + + def _fetch_china_etf( + self, + code: str, + start_date: str, + end_date: str + ) -> Optional[pd.DataFrame]: + """ + 获取A股ETF价格 + + 特点:Tushare fund_daily接口 + """ + return self._tushare.fetch_etf(code, start_date, end_date) + + def fetch_etf_with_nav( + self, + code: str, + start_date: str, + end_date: str + ) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame]]: + """ + 获取ETF价格 + 净值 + + 用于计算溢价率 + + Args: + code: ETF代码 + + Returns: + (price_df, nav_df) + """ + price_df = self._tushare.fetch_etf(code, start_date, end_date) + nav_df = self._tushare.fetch_etf_nav(code, start_date, end_date) + return price_df, nav_df + + def _fetch_us_index( + self, + code: str, + start_date: str, + end_date: str + ) -> Optional[pd.DataFrame]: + """ + 获取美股指数 + + 特点:YFinance,需要SSH隧道,指数代码转换 + """ + self._start_tunnel() + return self._yfinance.fetch(code, start_date, end_date) + + def _fetch_hk_index( + self, + code: str, + start_date: str, + end_date: str + ) -> Optional[pd.DataFrame]: + """ + 获取港股指数 + + 特点:YFinance,需要SSH隧道 + """ + self._start_tunnel() + return self._yfinance.fetch(code, start_date, end_date) + + def _fetch_futures( + self, + code: str, + start_date: str, + end_date: str + ) -> Optional[pd.DataFrame]: + """ + 获取期货数据 + + 特点: + - 中国期货(.SHF/.DCE/.CZC): Tushare + - NYMEX(.NYM): YFinance + """ + if code.endswith('.NYM'): + # NYMEX期货走YFinance + self._start_tunnel() + return self._yfinance.fetch(code, start_date, end_date) + else: + # 中国期货走Tushare + return self._tushare.fetch_futures(code, start_date, end_date) + + def _fetch_crypto( + self, + code: str, + start_date: str, + end_date: str + ) -> Optional[pd.DataFrame]: + """ + 获取加密货币 + + 特点:CCXT,不支持SOCKS5代理 + + TODO: 实现加密货币获取 + """ + print(f"⚠️ 加密货币数据获取尚未实现: {code}") + return None + + # ============================================================ + # 批量获取 + # ============================================================ + + def fetch_batch( + self, + codes: List[str], + start_date: str, + end_date: str, + retry: int = 3 + ) -> Dict[str, Optional[pd.DataFrame]]: + """ + 批量获取多只标的数据 + + Args: + codes: 代码列表 + start_date: 开始日期 + end_date: 结束日期 + + Returns: + {code: DataFrame} + """ + results = {} + + # 按资产类型分组 + grouped = AssetTypeDetector.group_by_type(codes) + + print(f"开始获取 {len(codes)} 只标的...") + for asset_type, code_list in grouped.items(): + print(f" {asset_type.value}: {len(code_list)} 只") + + # 启动隧道(港美股需要) + self._start_tunnel() + + for code in codes: + results[code] = self.fetch(code, start_date, end_date, retry) + + return results + + # ============================================================ + # 辅助方法 + # ============================================================ + + def get_asset_type(self, code: str) -> AssetType: + """获取资产类型""" + return AssetTypeDetector.detect(code) + + def is_supported(self, code: str) -> bool: + """判断是否支持该代码""" + return AssetTypeDetector.detect(code) != AssetType.UNKNOWN \ No newline at end of file diff --git a/framework/data/__pycache__/__init__.cpython-312.pyc b/framework/data/__pycache__/__init__.cpython-312.pyc index db21a091a8db38dd65801f8acee39c45ba2cb13d..fdf9cfff87b30b18853d6ba521be34c949618502 100644 GIT binary patch delta 20 acmaE&^+b#NG%qg~0}wQ+GjHUM6a@f2a|G=G delta 20 acmaE&^+b#NG%qg~0}#Z1V%*3bDGC5WC