feat(api): 使用 functools.lru_cache 实现数据缓存

- 使用 Python 标准库 functools.lru_cache 实现 LRU 缓存
- 添加 TTL 机制实现缓存过期(默认5分钟)
- 双缓存机制:LRU + TTL 结合
- 支持环境变量配置:CACHE_MAXSIZE(默认128)、CACHE_TTL_SECONDS(默认300)
- 新增缓存管理端点:
  - POST /api/v1/cache/clear - 清理缓存
  - GET /api/v1/cache/stats - 查看缓存统计(hits/misses/maxsize/currsize)
- /api/v1/ohlcv 支持 nocache 参数跳过缓存
- 批量接口自动使用缓存
- 响应中包含 cached 字段标识缓存状态
- 更新 API 版本到 1.1.0
This commit is contained in:
2026-05-07 23:23:06 +08:00
parent d703974c5b
commit b4a45e479f

View File

@@ -26,11 +26,10 @@ API 文档:
import hashlib
import json
import os
import sys
import time
from datetime import datetime, timedelta
from functools import lru_cache
from functools import wraps
from pathlib import Path
from typing import Optional, Dict, Any, List
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent.parent
@@ -60,41 +59,6 @@ CORS(app) # 启用跨域支持
fetcher: Optional[UniversalDataFetcher] = None
ssh_config: Optional[Dict] = None
# 内存缓存
_cache: Dict[str, Any] = {}
_cache_ttl: int = int(os.getenv('CACHE_TTL_SECONDS', '300')) # 默认5分钟
def _get_cache_key(code: str, start: str, end: str) -> str:
"""生成缓存键"""
key = f"{code}:{start}:{end}"
return hashlib.md5(key.encode()).hexdigest()
def _get_cached(key: str) -> Optional[Any]:
    """Look up ``key`` in the in-memory cache.

    Returns:
        The stored payload, or ``None`` when the key is absent or its entry
        is older than ``_cache_ttl`` seconds (the stale entry is removed).
    """
    if key in _cache:
        payload, stored_at = _cache[key]
        age_seconds = datetime.now().timestamp() - stored_at
        if age_seconds > _cache_ttl:
            # Stale entry: evict it so later lookups miss right away.
            del _cache[key]
            return None
        return payload
    return None
def _set_cache(key: str, data: Any):
    """Store ``data`` under ``key`` together with the current epoch time."""
    stored_at = datetime.now().timestamp()
    _cache[key] = (data, stored_at)
def clear_cache():
    """Remove every entry from the in-memory cache and report it on stdout."""
    _cache.clear()
    print("✓ 缓存已清理")
def get_ssh_config() -> Optional[Dict]:
"""
@@ -134,6 +98,105 @@ def get_fetcher() -> UniversalDataFetcher:
return fetcher
# Cache configuration, overridable through environment variables.
CACHE_MAXSIZE = int(os.getenv('CACHE_MAXSIZE', '128'))  # default: 128 cached results
CACHE_TTL_SECONDS = int(os.getenv('CACHE_TTL_SECONDS', '300'))  # default: expire after 5 minutes


class TimedCacheEntry:
    """A cached value stamped with the moment it was stored.

    Attributes:
        data: The cached value, stored as given.
        timestamp: Wall-clock creation time (``datetime.now()``), kept for
            inspection only; it is not consulted when checking expiry.
    """

    def __init__(self, data):
        self.data = data
        self.timestamp = datetime.now()
        # Age is measured on the monotonic clock so that system clock
        # adjustments can neither expire an entry early nor keep a stale
        # entry alive.
        self._created_monotonic = time.monotonic()

    def is_expired(self) -> bool:
        """Return True once the entry is older than CACHE_TTL_SECONDS."""
        return time.monotonic() - self._created_monotonic > CACHE_TTL_SECONDS
class _FetchFailed(Exception):
    """Raised when the upstream fetch fails.

    Raising (instead of returning an error payload) keeps the failure out of
    the lru_cache store, so a transient error is not replayed to later
    requests.
    """


# Data fetch wrapped in lru_cache.
@lru_cache(maxsize=CACHE_MAXSIZE)
def _fetch_data_cached(code: str, start: str, end: str, retry: int, generation: int = 0) -> Optional[str]:
    """
    Fetch one instrument and return the result as JSON text.

    The result is memoized as a JSON string, so every cache hit is decoded
    into a fresh dict and callers can mutate it without touching the cache.

    Args:
        code: Instrument code.
        start: Start of the requested range.
        end: End of the requested range.
        retry: Retry count forwarded to the fetcher.
        generation: Not used by the fetch itself; it is only part of the
            cache key. ``fetch_data_with_ttl`` passes a new value once an
            entry's TTL has elapsed, which forces a refetch.

    Returns:
        The JSON-serialized result, or ``None`` when the fetcher returned no rows.

    Raises:
        _FetchFailed: If fetching or converting the data raised an exception.
    """
    fetcher = get_fetcher()
    try:
        with fetcher:
            df = fetcher.fetch(code, start, end, retry=retry)
            if df is None or len(df) == 0:
                return None

            result = dataframe_to_json(df)
            result['code'] = code
            result['asset_type'] = AssetTypeDetector.detect(code)
            return json.dumps(result)
    except Exception as e:
        raise _FetchFailed(str(e)) from e


def _prune_ttl_index(ttl_cache: Dict) -> None:
    """Keep the TTL index bounded by CACHE_MAXSIZE (expired entries go first)."""
    limit = max(CACHE_MAXSIZE, 0)
    if len(ttl_cache) <= limit:
        return
    for key, entry in list(ttl_cache.items()):
        if entry.is_expired():
            ttl_cache.pop(key, None)
    # Still too large: drop the oldest insertions (dicts keep insertion order).
    while len(ttl_cache) > limit:
        ttl_cache.pop(next(iter(ttl_cache)), None)


def fetch_data_with_ttl(code: str, start: str, end: str, retry: int = 3, nocache: bool = False) -> Optional[Dict]:
    """
    Fetch data through the LRU cache, honouring the TTL.

    The TTL index (``fetch_data_with_ttl._ttl_cache``) records, per request
    key, when the data was fetched and which cache generation holds it. The
    payload itself lives in the lru_cache of ``_fetch_data_cached``. When an
    entry is missing or expired a new generation number is issued, so the
    lru_cache lookup misses and fresh data is fetched. Previously an expired
    entry fell through to the lru_cache under the same key and the stale
    payload was served again.

    Args:
        code: Instrument code.
        start: Start of the requested range.
        end: End of the requested range.
        retry: Retry count forwarded to the fetcher.
        nocache: When True, bypass both cache layers.

    Returns:
        The decoded result dict, ``{"error": message}`` when the fetch
        failed, or ``None`` when no data is available. Failures and empty
        results are not kept in the cache.
    """
    if nocache:
        # Call the undecorated function so nothing is read from or written to the cache.
        try:
            result_json = _fetch_data_cached.__wrapped__(code, start, end, retry)
        except _FetchFailed as exc:
            return {"error": str(exc)}
        return json.loads(result_json) if result_json else None

    ttl_cache = getattr(fetch_data_with_ttl, '_ttl_cache', None)
    if ttl_cache is None:
        ttl_cache = fetch_data_with_ttl._ttl_cache = {}

    cache_key = (code, start, end, retry)
    entry = ttl_cache.get(cache_key)
    if entry is None or entry.is_expired():
        # Issue a new generation: the lru_cache key changes, forcing a refetch.
        generation = getattr(fetch_data_with_ttl, '_generation', 0) + 1
        fetch_data_with_ttl._generation = generation
        entry = TimedCacheEntry(generation)
        # Re-insert so that dict order reflects entry age for pruning.
        ttl_cache.pop(cache_key, None)
        ttl_cache[cache_key] = entry
        _prune_ttl_index(ttl_cache)

    try:
        result_json = _fetch_data_cached(code, start, end, retry, entry.data)
    except _FetchFailed as exc:
        # Forget the key so the next request retries instead of waiting for the TTL.
        ttl_cache.pop(cache_key, None)
        return {"error": str(exc)}

    if result_json is None:
        # "No data" is not tracked either; the next request asks upstream again.
        ttl_cache.pop(cache_key, None)
        return None

    return json.loads(result_json)
def clear_data_cache():
    """Empty both cache layers: the lru_cache store and, if it exists, the TTL dict."""
    _fetch_data_cached.cache_clear()
    if not hasattr(fetch_data_with_ttl, '_ttl_cache'):
        # The TTL dict is created lazily; nothing to clear before first use.
        return
    fetch_data_with_ttl._ttl_cache.clear()
def get_cache_info() -> Dict:
    """Return lru_cache counters plus the TTL dict size and the configured TTL."""
    stats = _fetch_data_cached.cache_info()

    # The TTL dict only exists after the first cached fetch.
    if hasattr(fetch_data_with_ttl, '_ttl_cache'):
        ttl_entries = len(fetch_data_with_ttl._ttl_cache)
    else:
        ttl_entries = 0

    lru_stats = {
        field: getattr(stats, field)
        for field in ("hits", "misses", "maxsize", "currsize")
    }
    return {
        "lru_cache": lru_stats,
        "ttl_cache_size": ttl_entries,
        "ttl_seconds": CACHE_TTL_SECONDS,
    }
def dataframe_to_json(df: pd.DataFrame) -> Dict:
"""将 DataFrame 转换为 JSON 可序列化的字典"""
if df is None or len(df) == 0:
@@ -186,13 +249,14 @@ def get_default_dates() -> tuple:
@app.route('/')
def index():
"""首页 - API 信息"""
cache_info = get_cache_info()
return jsonify({
"name": "Universal Data Fetcher API",
"version": "1.1.0",
"description": "统一数据获取服务支持A股、港股、美股、期货、加密货币",
"features": [
"自动资产类型识别",
"内存缓存默认5分钟TTL",
"LRU+TTL 双缓存机制",
"SSH隧道支持港美股",
"批量数据获取"
],
@@ -216,8 +280,9 @@ def index():
"加密货币 (BTC)",
],
"cache_config": {
"ttl_seconds": _cache_ttl,
"current_size": len(_cache)
"maxsize": CACHE_MAXSIZE,
"ttl_seconds": CACHE_TTL_SECONDS,
"stats": cache_info['lru_cache']
},
"ssh_status": "enabled" if ssh_config and ssh_config.get('enabled') else "disabled",
})
@@ -281,7 +346,7 @@ def detect_asset_type():
@app.route('/api/v1/ohlcv')
def get_ohlcv():
"""
获取单只标的的 OHLCV 数据(支持缓存)
获取单只标的的 OHLCV 数据(支持LRU+TTL缓存)
Query Parameters:
code: 标的代码 (required)
@@ -294,10 +359,7 @@ def get_ohlcv():
{
"code": "000300.SH",
"asset_type": "china_index",
"data": [
{"date": "2024-01-02", "open": 3500.0, "high": 3550.0, ...},
...
],
"data": [...],
"count": 58,
"date_range": {"start": "2024-01-02", "end": "2024-03-29"},
"cached": false
@@ -335,52 +397,31 @@ def get_ohlcv():
except ValueError:
retry_count = 3
# 检测资产类型
asset_type = AssetTypeDetector.detect(code)
# 使用缓存获取数据
result = fetch_data_with_ttl(code, start, end, retry_count, nocache)
# 检查缓存
cache_key = _get_cache_key(code, start, end)
if not nocache:
cached_data = _get_cached(cache_key)
if cached_data:
cached_data['cached'] = True
return jsonify(cached_data)
# 获取数据
fetcher = get_fetcher()
try:
with fetcher:
df = fetcher.fetch(code, start, end, retry=retry_count)
if df is None or len(df) == 0:
return jsonify({
"code": code,
"asset_type": asset_type,
"error": "No data available for the specified code and date range",
"start": start,
"end": end,
}), 404
result = dataframe_to_json(df)
result['code'] = code
result['asset_type'] = asset_type
result['cached'] = False
# 存入缓存
if not nocache:
_set_cache(cache_key, result.copy())
return jsonify(result)
except Exception as e:
if result is None:
return jsonify({
"code": code,
"asset_type": asset_type,
"error": str(e),
"asset_type": AssetTypeDetector.detect(code),
"error": "No data available for the specified code and date range",
"start": start,
"end": end,
}), 404
if "error" in result:
return jsonify({
"code": code,
"asset_type": AssetTypeDetector.detect(code),
"error": result["error"],
"start": start,
"end": end,
}), 500
# 添加缓存状态
result['cached'] = not nocache and _fetch_data_cached.cache_info().hits > 0
return jsonify(result)
@app.route('/api/v1/ohlcv/batch', methods=['POST'])
@@ -454,19 +495,17 @@ def batch_ohlcv():
with fetcher:
for code in codes:
try:
df = fetcher.fetch(code, start, end, retry=retry)
# 使用缓存获取数据
result = fetch_data_with_ttl(code, start, end, retry, nocache=False)
if df is not None and len(df) > 0:
result = dataframe_to_json(df)
result['code'] = code
result['asset_type'] = AssetTypeDetector.detect(code)
if result is not None and "error" not in result:
results[code] = result
success_count += 1
else:
results[code] = {
"code": code,
"asset_type": AssetTypeDetector.detect(code),
"error": "No data available",
"error": result.get("error", "No data available") if result else "No data available",
"data": [],
"count": 0,
}
@@ -497,48 +536,6 @@ def batch_ohlcv():
})
@app.route('/api/v1/cache/clear', methods=['POST'])
def clear_cache_endpoint():
    """
    Clear every cached entry.

    Returns:
        {"message": "Cache cleared successfully", "previous_size": <entries before>, "current_size": 0}
    """
    size_before = len(_cache)
    clear_cache()

    payload = {
        "message": "Cache cleared successfully",
        "previous_size": size_before,
        "current_size": 0
    }
    return jsonify(payload)
@app.route('/api/v1/cache/stats')
def cache_stats():
    """
    Report cache statistics.

    Returns:
        {
            "cache_size": 10,
            "cache_ttl_seconds": 300,
            "memory_estimate_mb": 5.2
        }
    """
    import sys

    # Rough estimate only: sys.getsizeof is shallow and does not follow
    # references into the cached payloads.
    approx_bytes = sum(
        sys.getsizeof(key) + sys.getsizeof(data)
        for key, (data, _) in _cache.items()
    )

    stats = {
        "cache_size": len(_cache),
        "cache_ttl_seconds": _cache_ttl,
        "memory_estimate_mb": round(approx_bytes / 1024 / 1024, 2)
    }
    return jsonify(stats)
@app.route('/api/v1/supported-codes')
def get_supported_codes():
"""
@@ -587,6 +584,24 @@ def get_supported_codes():
})
@app.route('/api/v1/cache/clear', methods=['POST'])
def clear_cache_endpoint():
    """Clear the data cache and return the cache info from before and after."""
    snapshot_before = get_cache_info()
    clear_data_cache()
    snapshot_after = get_cache_info()

    return jsonify({
        "message": "Cache cleared successfully",
        "before": snapshot_before,
        "after": snapshot_after
    })
@app.route('/api/v1/cache/stats')
def cache_stats():
    """Return the current cache statistics as JSON."""
    info = get_cache_info()
    return jsonify(info)
# ============================================================
# 错误处理
# ============================================================
@@ -641,8 +656,10 @@ if __name__ == '__main__':
print(f" Host: {args.host}")
print(f" Port: {args.port}")
print(f" Debug: {args.debug}")
print(f" Cache: maxsize={CACHE_MAXSIZE}, ttl={CACHE_TTL_SECONDS}s")
print(f"\n📖 API 文档: http://{args.host}:{args.port}/")
print(f" 健康检查: http://{args.host}:{args.port}/health")
print(f" 缓存统计: http://{args.host}:{args.port}/api/v1/cache/stats")
print(f"\n💡 示例请求:")
print(f" curl 'http://{args.host}:{args.port}/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31'")
print(f"\n")