From b4a45e479f1ef7e3071a52fbb50e5aeb558254d1 Mon Sep 17 00:00:00 2001 From: aszerW Date: Thu, 7 May 2026 23:23:06 +0800 Subject: [PATCH] =?UTF-8?q?feat(api):=20=E4=BD=BF=E7=94=A8=20functools.lru?= =?UTF-8?q?=5Fcache=20=E5=AE=9E=E7=8E=B0=E6=95=B0=E6=8D=AE=E7=BC=93?= =?UTF-8?q?=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 使用 Python 标准库 functools.lru_cache 实现 LRU 缓存 - 添加 TTL 机制实现缓存过期(默认5分钟) - 双缓存机制:LRU + TTL 结合 - 支持环境变量配置:CACHE_MAXSIZE(默认128)、CACHE_TTL_SECONDS(默认300) - 新增缓存管理端点: - POST /api/v1/cache/clear - 清理缓存 - GET /api/v1/cache/stats - 查看缓存统计(hits/misses/maxsize/currsize) - /api/v1/ohlcv 支持 nocache 参数跳过缓存 - 批量接口自动使用缓存 - 响应中包含 cached 字段标识缓存状态 - 更新 API 版本到 1.1.0 --- core/datasource/flask_server.py | 283 +++++++++++++++++--------------- 1 file changed, 150 insertions(+), 133 deletions(-) diff --git a/core/datasource/flask_server.py b/core/datasource/flask_server.py index e1dceff..c2b3b12 100644 --- a/core/datasource/flask_server.py +++ b/core/datasource/flask_server.py @@ -26,11 +26,10 @@ API 文档: import os import sys import json -import hashlib from pathlib import Path from datetime import datetime, timedelta from typing import Optional, Dict, Any, List -from functools import wraps +from functools import lru_cache # 添加项目根目录到路径 project_root = Path(__file__).parent.parent.parent @@ -60,41 +59,6 @@ CORS(app) # 启用跨域支持 fetcher: Optional[UniversalDataFetcher] = None ssh_config: Optional[Dict] = None -# 内存缓存 -_cache: Dict[str, Any] = {} -_cache_ttl: int = int(os.getenv('CACHE_TTL_SECONDS', '300')) # 默认5分钟 - - -def _get_cache_key(code: str, start: str, end: str) -> str: - """生成缓存键""" - key = f"{code}:{start}:{end}" - return hashlib.md5(key.encode()).hexdigest() - - -def _get_cached(key: str) -> Optional[Any]: - """获取缓存数据""" - if key not in _cache: - return None - - data, timestamp = _cache[key] - if datetime.now().timestamp() - timestamp > _cache_ttl: - # 缓存过期 - del _cache[key] - return None - - return data - - -def _set_cache(key: str, data: Any): - """设置缓存数据""" - _cache[key] = (data, datetime.now().timestamp()) - - -def clear_cache(): - """清理所有缓存""" - _cache.clear() - print("✓ 缓存已清理") - def get_ssh_config() -> Optional[Dict]: """ @@ -134,6 +98,105 @@ def get_fetcher() -> UniversalDataFetcher: return fetcher +# 缓存配置 +CACHE_MAXSIZE = int(os.getenv('CACHE_MAXSIZE', '128')) # 默认缓存128条数据 +CACHE_TTL_SECONDS = int(os.getenv('CACHE_TTL_SECONDS', '300')) # 默认5分钟过期 + +# 带时间戳的缓存条目 +class TimedCacheEntry: + def __init__(self, data): + self.data = data + self.timestamp = datetime.now() + + def is_expired(self) -> bool: + return (datetime.now() - self.timestamp).total_seconds() > CACHE_TTL_SECONDS + + +# 使用 lru_cache 包装数据获取函数 +@lru_cache(maxsize=CACHE_MAXSIZE) +def _fetch_data_cached(code: str, start: str, end: str, retry: int) -> Optional[str]: + """ + 获取数据的缓存版本 + 返回 JSON 序列化的字符串,因为 lru_cache 需要可哈希的参数 + """ + fetcher = get_fetcher() + + try: + with fetcher: + df = fetcher.fetch(code, start, end, retry=retry) + + if df is None or len(df) == 0: + return None + + result = dataframe_to_json(df) + result['code'] = code + result['asset_type'] = AssetTypeDetector.detect(code) + + return json.dumps(result) + except Exception as e: + return json.dumps({"error": str(e)}) + + +def fetch_data_with_ttl(code: str, start: str, end: str, retry: int = 3, nocache: bool = False) -> Optional[Dict]: + """ + 获取数据,支持TTL缓存 + """ + if nocache: + # 直接获取,不使用缓存 + result_json = _fetch_data_cached.__wrapped__(code, start, end, retry) + return json.loads(result_json) if result_json else None + + # 使用缓存 + cache_key = (code, start, end, retry) + + # 检查内存中的TTL缓存 + if hasattr(fetch_data_with_ttl, '_ttl_cache'): + if cache_key in fetch_data_with_ttl._ttl_cache: + entry = fetch_data_with_ttl._ttl_cache[cache_key] + if not entry.is_expired(): + return entry.data + # 过期,删除 + del fetch_data_with_ttl._ttl_cache[cache_key] + else: + fetch_data_with_ttl._ttl_cache = {} + + # 从 lru_cache 获取 + result_json = _fetch_data_cached(code, start, end, retry) + + if result_json is None: + return None + + result = json.loads(result_json) + + # 存入TTL缓存 + fetch_data_with_ttl._ttl_cache[cache_key] = TimedCacheEntry(result) + + return result + + +def clear_data_cache(): + """清理缓存""" + _fetch_data_cached.cache_clear() + if hasattr(fetch_data_with_ttl, '_ttl_cache'): + fetch_data_with_ttl._ttl_cache.clear() + + +def get_cache_info() -> Dict: + """获取缓存信息""" + info = _fetch_data_cached.cache_info() + ttl_size = len(fetch_data_with_ttl._ttl_cache) if hasattr(fetch_data_with_ttl, '_ttl_cache') else 0 + return { + "lru_cache": { + "hits": info.hits, + "misses": info.misses, + "maxsize": info.maxsize, + "currsize": info.currsize, + }, + "ttl_cache_size": ttl_size, + "ttl_seconds": CACHE_TTL_SECONDS, + } + + def dataframe_to_json(df: pd.DataFrame) -> Dict: """将 DataFrame 转换为 JSON 可序列化的字典""" if df is None or len(df) == 0: @@ -186,13 +249,14 @@ def get_default_dates() -> tuple: @app.route('/') def index(): """首页 - API 信息""" + cache_info = get_cache_info() return jsonify({ "name": "Universal Data Fetcher API", "version": "1.1.0", "description": "统一数据获取服务,支持A股、港股、美股、期货、加密货币", "features": [ "自动资产类型识别", - "内存缓存(默认5分钟TTL)", + "LRU+TTL 双缓存机制", "SSH隧道支持(港美股)", "批量数据获取" ], @@ -216,8 +280,9 @@ def index(): "加密货币 (BTC)", ], "cache_config": { - "ttl_seconds": _cache_ttl, - "current_size": len(_cache) + "maxsize": CACHE_MAXSIZE, + "ttl_seconds": CACHE_TTL_SECONDS, + "stats": cache_info['lru_cache'] }, "ssh_status": "enabled" if ssh_config and ssh_config.get('enabled') else "disabled", }) @@ -281,7 +346,7 @@ def detect_asset_type(): @app.route('/api/v1/ohlcv') def get_ohlcv(): """ - 获取单只标的的 OHLCV 数据(支持缓存) + 获取单只标的的 OHLCV 数据(支持LRU+TTL缓存) Query Parameters: code: 标的代码 (required) @@ -294,10 +359,7 @@ def get_ohlcv(): { "code": "000300.SH", "asset_type": "china_index", - "data": [ - {"date": "2024-01-02", "open": 3500.0, "high": 3550.0, ...}, - ... - ], + "data": [...], "count": 58, "date_range": {"start": "2024-01-02", "end": "2024-03-29"}, "cached": false @@ -335,52 +397,31 @@ def get_ohlcv(): except ValueError: retry_count = 3 - # 检测资产类型 - asset_type = AssetTypeDetector.detect(code) + # 使用缓存获取数据 + result = fetch_data_with_ttl(code, start, end, retry_count, nocache) - # 检查缓存 - cache_key = _get_cache_key(code, start, end) - if not nocache: - cached_data = _get_cached(cache_key) - if cached_data: - cached_data['cached'] = True - return jsonify(cached_data) - - # 获取数据 - fetcher = get_fetcher() - - try: - with fetcher: - df = fetcher.fetch(code, start, end, retry=retry_count) - - if df is None or len(df) == 0: - return jsonify({ - "code": code, - "asset_type": asset_type, - "error": "No data available for the specified code and date range", - "start": start, - "end": end, - }), 404 - - result = dataframe_to_json(df) - result['code'] = code - result['asset_type'] = asset_type - result['cached'] = False - - # 存入缓存 - if not nocache: - _set_cache(cache_key, result.copy()) - - return jsonify(result) - - except Exception as e: + if result is None: return jsonify({ "code": code, - "asset_type": asset_type, - "error": str(e), + "asset_type": AssetTypeDetector.detect(code), + "error": "No data available for the specified code and date range", + "start": start, + "end": end, + }), 404 + + if "error" in result: + return jsonify({ + "code": code, + "asset_type": AssetTypeDetector.detect(code), + "error": result["error"], "start": start, "end": end, }), 500 + + # 添加缓存状态 + result['cached'] = not nocache and _fetch_data_cached.cache_info().hits > 0 + + return jsonify(result) @app.route('/api/v1/ohlcv/batch', methods=['POST']) @@ -454,19 +495,17 @@ def batch_ohlcv(): with fetcher: for code in codes: try: - df = fetcher.fetch(code, start, end, retry=retry) + # 使用缓存获取数据 + result = fetch_data_with_ttl(code, start, end, retry, nocache=False) - if df is not None and len(df) > 0: - result = dataframe_to_json(df) - result['code'] = code - result['asset_type'] = AssetTypeDetector.detect(code) + if result is not None and "error" not in result: results[code] = result success_count += 1 else: results[code] = { "code": code, "asset_type": AssetTypeDetector.detect(code), - "error": "No data available", + "error": result.get("error", "No data available") if result else "No data available", "data": [], "count": 0, } @@ -497,48 +536,6 @@ def batch_ohlcv(): }) -@app.route('/api/v1/cache/clear', methods=['POST']) -def clear_cache_endpoint(): - """ - 清理所有缓存数据 - - Returns: - {"message": "Cache cleared", "cache_size": 0} - """ - cache_size = len(_cache) - clear_cache() - return jsonify({ - "message": "Cache cleared successfully", - "previous_size": cache_size, - "current_size": 0 - }) - - -@app.route('/api/v1/cache/stats') -def cache_stats(): - """ - 获取缓存统计信息 - - Returns: - { - "cache_size": 10, - "cache_ttl_seconds": 300, - "memory_estimate_mb": 5.2 - } - """ - # 估算内存使用(粗略估计) - import sys - total_size = 0 - for key, (data, _) in _cache.items(): - total_size += sys.getsizeof(key) + sys.getsizeof(data) - - return jsonify({ - "cache_size": len(_cache), - "cache_ttl_seconds": _cache_ttl, - "memory_estimate_mb": round(total_size / 1024 / 1024, 2) - }) - - @app.route('/api/v1/supported-codes') def get_supported_codes(): """ @@ -587,6 +584,24 @@ def get_supported_codes(): }) +@app.route('/api/v1/cache/clear', methods=['POST']) +def clear_cache_endpoint(): + """清理缓存""" + info_before = get_cache_info() + clear_data_cache() + return jsonify({ + "message": "Cache cleared successfully", + "before": info_before, + "after": get_cache_info() + }) + + +@app.route('/api/v1/cache/stats') +def cache_stats(): + """获取缓存统计信息""" + return jsonify(get_cache_info()) + + # ============================================================ # 错误处理 # ============================================================ @@ -641,8 +656,10 @@ if __name__ == '__main__': print(f" Host: {args.host}") print(f" Port: {args.port}") print(f" Debug: {args.debug}") + print(f" Cache: maxsize={CACHE_MAXSIZE}, ttl={CACHE_TTL_SECONDS}s") print(f"\n📖 API 文档: http://{args.host}:{args.port}/") print(f" 健康检查: http://{args.host}:{args.port}/health") + print(f" 缓存统计: http://{args.host}:{args.port}/api/v1/cache/stats") print(f"\n💡 示例请求:") print(f" curl 'http://{args.host}:{args.port}/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31'") print(f"\n")