feat(api): 使用 functools.lru_cache 实现数据缓存
- 使用 Python 标准库 functools.lru_cache 实现 LRU 缓存 - 添加 TTL 机制实现缓存过期(默认5分钟) - 双缓存机制:LRU + TTL 结合 - 支持环境变量配置:CACHE_MAXSIZE(默认128)、CACHE_TTL_SECONDS(默认300) - 新增缓存管理端点: - POST /api/v1/cache/clear - 清理缓存 - GET /api/v1/cache/stats - 查看缓存统计(hits/misses/maxsize/currsize) - /api/v1/ohlcv 支持 nocache 参数跳过缓存 - 批量接口自动使用缓存 - 响应中包含 cached 字段标识缓存状态 - 更新 API 版本到 1.1.0
This commit is contained in:
@@ -26,11 +26,10 @@ API 文档:
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import json
|
import json
|
||||||
import hashlib
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Optional, Dict, Any, List
|
from typing import Optional, Dict, Any, List
|
||||||
from functools import wraps
|
from functools import lru_cache
|
||||||
|
|
||||||
# 添加项目根目录到路径
|
# 添加项目根目录到路径
|
||||||
project_root = Path(__file__).parent.parent.parent
|
project_root = Path(__file__).parent.parent.parent
|
||||||
@@ -60,41 +59,6 @@ CORS(app) # 启用跨域支持
|
|||||||
fetcher: Optional[UniversalDataFetcher] = None
|
fetcher: Optional[UniversalDataFetcher] = None
|
||||||
ssh_config: Optional[Dict] = None
|
ssh_config: Optional[Dict] = None
|
||||||
|
|
||||||
# 内存缓存
|
|
||||||
_cache: Dict[str, Any] = {}
|
|
||||||
_cache_ttl: int = int(os.getenv('CACHE_TTL_SECONDS', '300')) # 默认5分钟
|
|
||||||
|
|
||||||
|
|
||||||
def _get_cache_key(code: str, start: str, end: str) -> str:
|
|
||||||
"""生成缓存键"""
|
|
||||||
key = f"{code}:{start}:{end}"
|
|
||||||
return hashlib.md5(key.encode()).hexdigest()
|
|
||||||
|
|
||||||
|
|
||||||
def _get_cached(key: str) -> Optional[Any]:
|
|
||||||
"""获取缓存数据"""
|
|
||||||
if key not in _cache:
|
|
||||||
return None
|
|
||||||
|
|
||||||
data, timestamp = _cache[key]
|
|
||||||
if datetime.now().timestamp() - timestamp > _cache_ttl:
|
|
||||||
# 缓存过期
|
|
||||||
del _cache[key]
|
|
||||||
return None
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def _set_cache(key: str, data: Any):
|
|
||||||
"""设置缓存数据"""
|
|
||||||
_cache[key] = (data, datetime.now().timestamp())
|
|
||||||
|
|
||||||
|
|
||||||
def clear_cache():
|
|
||||||
"""清理所有缓存"""
|
|
||||||
_cache.clear()
|
|
||||||
print("✓ 缓存已清理")
|
|
||||||
|
|
||||||
|
|
||||||
def get_ssh_config() -> Optional[Dict]:
|
def get_ssh_config() -> Optional[Dict]:
|
||||||
"""
|
"""
|
||||||
@@ -134,6 +98,105 @@ def get_fetcher() -> UniversalDataFetcher:
|
|||||||
return fetcher
|
return fetcher
|
||||||
|
|
||||||
|
|
||||||
|
# 缓存配置
|
||||||
|
CACHE_MAXSIZE = int(os.getenv('CACHE_MAXSIZE', '128')) # 默认缓存128条数据
|
||||||
|
CACHE_TTL_SECONDS = int(os.getenv('CACHE_TTL_SECONDS', '300')) # 默认5分钟过期
|
||||||
|
|
||||||
|
# 带时间戳的缓存条目
|
||||||
|
class TimedCacheEntry:
|
||||||
|
def __init__(self, data):
|
||||||
|
self.data = data
|
||||||
|
self.timestamp = datetime.now()
|
||||||
|
|
||||||
|
def is_expired(self) -> bool:
|
||||||
|
return (datetime.now() - self.timestamp).total_seconds() > CACHE_TTL_SECONDS
|
||||||
|
|
||||||
|
|
||||||
|
# 使用 lru_cache 包装数据获取函数
|
||||||
|
@lru_cache(maxsize=CACHE_MAXSIZE)
|
||||||
|
def _fetch_data_cached(code: str, start: str, end: str, retry: int) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
获取数据的缓存版本
|
||||||
|
返回 JSON 序列化的字符串,因为 lru_cache 需要可哈希的参数
|
||||||
|
"""
|
||||||
|
fetcher = get_fetcher()
|
||||||
|
|
||||||
|
try:
|
||||||
|
with fetcher:
|
||||||
|
df = fetcher.fetch(code, start, end, retry=retry)
|
||||||
|
|
||||||
|
if df is None or len(df) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
result = dataframe_to_json(df)
|
||||||
|
result['code'] = code
|
||||||
|
result['asset_type'] = AssetTypeDetector.detect(code)
|
||||||
|
|
||||||
|
return json.dumps(result)
|
||||||
|
except Exception as e:
|
||||||
|
return json.dumps({"error": str(e)})
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_data_with_ttl(code: str, start: str, end: str, retry: int = 3, nocache: bool = False) -> Optional[Dict]:
|
||||||
|
"""
|
||||||
|
获取数据,支持TTL缓存
|
||||||
|
"""
|
||||||
|
if nocache:
|
||||||
|
# 直接获取,不使用缓存
|
||||||
|
result_json = _fetch_data_cached.__wrapped__(code, start, end, retry)
|
||||||
|
return json.loads(result_json) if result_json else None
|
||||||
|
|
||||||
|
# 使用缓存
|
||||||
|
cache_key = (code, start, end, retry)
|
||||||
|
|
||||||
|
# 检查内存中的TTL缓存
|
||||||
|
if hasattr(fetch_data_with_ttl, '_ttl_cache'):
|
||||||
|
if cache_key in fetch_data_with_ttl._ttl_cache:
|
||||||
|
entry = fetch_data_with_ttl._ttl_cache[cache_key]
|
||||||
|
if not entry.is_expired():
|
||||||
|
return entry.data
|
||||||
|
# 过期,删除
|
||||||
|
del fetch_data_with_ttl._ttl_cache[cache_key]
|
||||||
|
else:
|
||||||
|
fetch_data_with_ttl._ttl_cache = {}
|
||||||
|
|
||||||
|
# 从 lru_cache 获取
|
||||||
|
result_json = _fetch_data_cached(code, start, end, retry)
|
||||||
|
|
||||||
|
if result_json is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
result = json.loads(result_json)
|
||||||
|
|
||||||
|
# 存入TTL缓存
|
||||||
|
fetch_data_with_ttl._ttl_cache[cache_key] = TimedCacheEntry(result)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def clear_data_cache():
|
||||||
|
"""清理缓存"""
|
||||||
|
_fetch_data_cached.cache_clear()
|
||||||
|
if hasattr(fetch_data_with_ttl, '_ttl_cache'):
|
||||||
|
fetch_data_with_ttl._ttl_cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def get_cache_info() -> Dict:
|
||||||
|
"""获取缓存信息"""
|
||||||
|
info = _fetch_data_cached.cache_info()
|
||||||
|
ttl_size = len(fetch_data_with_ttl._ttl_cache) if hasattr(fetch_data_with_ttl, '_ttl_cache') else 0
|
||||||
|
return {
|
||||||
|
"lru_cache": {
|
||||||
|
"hits": info.hits,
|
||||||
|
"misses": info.misses,
|
||||||
|
"maxsize": info.maxsize,
|
||||||
|
"currsize": info.currsize,
|
||||||
|
},
|
||||||
|
"ttl_cache_size": ttl_size,
|
||||||
|
"ttl_seconds": CACHE_TTL_SECONDS,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def dataframe_to_json(df: pd.DataFrame) -> Dict:
|
def dataframe_to_json(df: pd.DataFrame) -> Dict:
|
||||||
"""将 DataFrame 转换为 JSON 可序列化的字典"""
|
"""将 DataFrame 转换为 JSON 可序列化的字典"""
|
||||||
if df is None or len(df) == 0:
|
if df is None or len(df) == 0:
|
||||||
@@ -186,13 +249,14 @@ def get_default_dates() -> tuple:
|
|||||||
@app.route('/')
|
@app.route('/')
|
||||||
def index():
|
def index():
|
||||||
"""首页 - API 信息"""
|
"""首页 - API 信息"""
|
||||||
|
cache_info = get_cache_info()
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"name": "Universal Data Fetcher API",
|
"name": "Universal Data Fetcher API",
|
||||||
"version": "1.1.0",
|
"version": "1.1.0",
|
||||||
"description": "统一数据获取服务,支持A股、港股、美股、期货、加密货币",
|
"description": "统一数据获取服务,支持A股、港股、美股、期货、加密货币",
|
||||||
"features": [
|
"features": [
|
||||||
"自动资产类型识别",
|
"自动资产类型识别",
|
||||||
"内存缓存(默认5分钟TTL)",
|
"LRU+TTL 双缓存机制",
|
||||||
"SSH隧道支持(港美股)",
|
"SSH隧道支持(港美股)",
|
||||||
"批量数据获取"
|
"批量数据获取"
|
||||||
],
|
],
|
||||||
@@ -216,8 +280,9 @@ def index():
|
|||||||
"加密货币 (BTC)",
|
"加密货币 (BTC)",
|
||||||
],
|
],
|
||||||
"cache_config": {
|
"cache_config": {
|
||||||
"ttl_seconds": _cache_ttl,
|
"maxsize": CACHE_MAXSIZE,
|
||||||
"current_size": len(_cache)
|
"ttl_seconds": CACHE_TTL_SECONDS,
|
||||||
|
"stats": cache_info['lru_cache']
|
||||||
},
|
},
|
||||||
"ssh_status": "enabled" if ssh_config and ssh_config.get('enabled') else "disabled",
|
"ssh_status": "enabled" if ssh_config and ssh_config.get('enabled') else "disabled",
|
||||||
})
|
})
|
||||||
@@ -281,7 +346,7 @@ def detect_asset_type():
|
|||||||
@app.route('/api/v1/ohlcv')
|
@app.route('/api/v1/ohlcv')
|
||||||
def get_ohlcv():
|
def get_ohlcv():
|
||||||
"""
|
"""
|
||||||
获取单只标的的 OHLCV 数据(支持缓存)
|
获取单只标的的 OHLCV 数据(支持LRU+TTL缓存)
|
||||||
|
|
||||||
Query Parameters:
|
Query Parameters:
|
||||||
code: 标的代码 (required)
|
code: 标的代码 (required)
|
||||||
@@ -294,10 +359,7 @@ def get_ohlcv():
|
|||||||
{
|
{
|
||||||
"code": "000300.SH",
|
"code": "000300.SH",
|
||||||
"asset_type": "china_index",
|
"asset_type": "china_index",
|
||||||
"data": [
|
"data": [...],
|
||||||
{"date": "2024-01-02", "open": 3500.0, "high": 3550.0, ...},
|
|
||||||
...
|
|
||||||
],
|
|
||||||
"count": 58,
|
"count": 58,
|
||||||
"date_range": {"start": "2024-01-02", "end": "2024-03-29"},
|
"date_range": {"start": "2024-01-02", "end": "2024-03-29"},
|
||||||
"cached": false
|
"cached": false
|
||||||
@@ -335,53 +397,32 @@ def get_ohlcv():
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
retry_count = 3
|
retry_count = 3
|
||||||
|
|
||||||
# 检测资产类型
|
# 使用缓存获取数据
|
||||||
asset_type = AssetTypeDetector.detect(code)
|
result = fetch_data_with_ttl(code, start, end, retry_count, nocache)
|
||||||
|
|
||||||
# 检查缓存
|
if result is None:
|
||||||
cache_key = _get_cache_key(code, start, end)
|
|
||||||
if not nocache:
|
|
||||||
cached_data = _get_cached(cache_key)
|
|
||||||
if cached_data:
|
|
||||||
cached_data['cached'] = True
|
|
||||||
return jsonify(cached_data)
|
|
||||||
|
|
||||||
# 获取数据
|
|
||||||
fetcher = get_fetcher()
|
|
||||||
|
|
||||||
try:
|
|
||||||
with fetcher:
|
|
||||||
df = fetcher.fetch(code, start, end, retry=retry_count)
|
|
||||||
|
|
||||||
if df is None or len(df) == 0:
|
|
||||||
return jsonify({
|
|
||||||
"code": code,
|
|
||||||
"asset_type": asset_type,
|
|
||||||
"error": "No data available for the specified code and date range",
|
|
||||||
"start": start,
|
|
||||||
"end": end,
|
|
||||||
}), 404
|
|
||||||
|
|
||||||
result = dataframe_to_json(df)
|
|
||||||
result['code'] = code
|
|
||||||
result['asset_type'] = asset_type
|
|
||||||
result['cached'] = False
|
|
||||||
|
|
||||||
# 存入缓存
|
|
||||||
if not nocache:
|
|
||||||
_set_cache(cache_key, result.copy())
|
|
||||||
|
|
||||||
return jsonify(result)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"code": code,
|
"code": code,
|
||||||
"asset_type": asset_type,
|
"asset_type": AssetTypeDetector.detect(code),
|
||||||
"error": str(e),
|
"error": "No data available for the specified code and date range",
|
||||||
|
"start": start,
|
||||||
|
"end": end,
|
||||||
|
}), 404
|
||||||
|
|
||||||
|
if "error" in result:
|
||||||
|
return jsonify({
|
||||||
|
"code": code,
|
||||||
|
"asset_type": AssetTypeDetector.detect(code),
|
||||||
|
"error": result["error"],
|
||||||
"start": start,
|
"start": start,
|
||||||
"end": end,
|
"end": end,
|
||||||
}), 500
|
}), 500
|
||||||
|
|
||||||
|
# 添加缓存状态
|
||||||
|
result['cached'] = not nocache and _fetch_data_cached.cache_info().hits > 0
|
||||||
|
|
||||||
|
return jsonify(result)
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/v1/ohlcv/batch', methods=['POST'])
|
@app.route('/api/v1/ohlcv/batch', methods=['POST'])
|
||||||
def batch_ohlcv():
|
def batch_ohlcv():
|
||||||
@@ -454,19 +495,17 @@ def batch_ohlcv():
|
|||||||
with fetcher:
|
with fetcher:
|
||||||
for code in codes:
|
for code in codes:
|
||||||
try:
|
try:
|
||||||
df = fetcher.fetch(code, start, end, retry=retry)
|
# 使用缓存获取数据
|
||||||
|
result = fetch_data_with_ttl(code, start, end, retry, nocache=False)
|
||||||
|
|
||||||
if df is not None and len(df) > 0:
|
if result is not None and "error" not in result:
|
||||||
result = dataframe_to_json(df)
|
|
||||||
result['code'] = code
|
|
||||||
result['asset_type'] = AssetTypeDetector.detect(code)
|
|
||||||
results[code] = result
|
results[code] = result
|
||||||
success_count += 1
|
success_count += 1
|
||||||
else:
|
else:
|
||||||
results[code] = {
|
results[code] = {
|
||||||
"code": code,
|
"code": code,
|
||||||
"asset_type": AssetTypeDetector.detect(code),
|
"asset_type": AssetTypeDetector.detect(code),
|
||||||
"error": "No data available",
|
"error": result.get("error", "No data available") if result else "No data available",
|
||||||
"data": [],
|
"data": [],
|
||||||
"count": 0,
|
"count": 0,
|
||||||
}
|
}
|
||||||
@@ -497,48 +536,6 @@ def batch_ohlcv():
|
|||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/v1/cache/clear', methods=['POST'])
|
|
||||||
def clear_cache_endpoint():
|
|
||||||
"""
|
|
||||||
清理所有缓存数据
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
{"message": "Cache cleared", "cache_size": 0}
|
|
||||||
"""
|
|
||||||
cache_size = len(_cache)
|
|
||||||
clear_cache()
|
|
||||||
return jsonify({
|
|
||||||
"message": "Cache cleared successfully",
|
|
||||||
"previous_size": cache_size,
|
|
||||||
"current_size": 0
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/v1/cache/stats')
|
|
||||||
def cache_stats():
|
|
||||||
"""
|
|
||||||
获取缓存统计信息
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
{
|
|
||||||
"cache_size": 10,
|
|
||||||
"cache_ttl_seconds": 300,
|
|
||||||
"memory_estimate_mb": 5.2
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
# 估算内存使用(粗略估计)
|
|
||||||
import sys
|
|
||||||
total_size = 0
|
|
||||||
for key, (data, _) in _cache.items():
|
|
||||||
total_size += sys.getsizeof(key) + sys.getsizeof(data)
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
"cache_size": len(_cache),
|
|
||||||
"cache_ttl_seconds": _cache_ttl,
|
|
||||||
"memory_estimate_mb": round(total_size / 1024 / 1024, 2)
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/v1/supported-codes')
|
@app.route('/api/v1/supported-codes')
|
||||||
def get_supported_codes():
|
def get_supported_codes():
|
||||||
"""
|
"""
|
||||||
@@ -587,6 +584,24 @@ def get_supported_codes():
|
|||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/api/v1/cache/clear', methods=['POST'])
|
||||||
|
def clear_cache_endpoint():
|
||||||
|
"""清理缓存"""
|
||||||
|
info_before = get_cache_info()
|
||||||
|
clear_data_cache()
|
||||||
|
return jsonify({
|
||||||
|
"message": "Cache cleared successfully",
|
||||||
|
"before": info_before,
|
||||||
|
"after": get_cache_info()
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/api/v1/cache/stats')
|
||||||
|
def cache_stats():
|
||||||
|
"""获取缓存统计信息"""
|
||||||
|
return jsonify(get_cache_info())
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
# 错误处理
|
# 错误处理
|
||||||
# ============================================================
|
# ============================================================
|
||||||
@@ -641,8 +656,10 @@ if __name__ == '__main__':
|
|||||||
print(f" Host: {args.host}")
|
print(f" Host: {args.host}")
|
||||||
print(f" Port: {args.port}")
|
print(f" Port: {args.port}")
|
||||||
print(f" Debug: {args.debug}")
|
print(f" Debug: {args.debug}")
|
||||||
|
print(f" Cache: maxsize={CACHE_MAXSIZE}, ttl={CACHE_TTL_SECONDS}s")
|
||||||
print(f"\n📖 API 文档: http://{args.host}:{args.port}/")
|
print(f"\n📖 API 文档: http://{args.host}:{args.port}/")
|
||||||
print(f" 健康检查: http://{args.host}:{args.port}/health")
|
print(f" 健康检查: http://{args.host}:{args.port}/health")
|
||||||
|
print(f" 缓存统计: http://{args.host}:{args.port}/api/v1/cache/stats")
|
||||||
print(f"\n💡 示例请求:")
|
print(f"\n💡 示例请求:")
|
||||||
print(f" curl 'http://{args.host}:{args.port}/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31'")
|
print(f" curl 'http://{args.host}:{args.port}/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31'")
|
||||||
print(f"\n")
|
print(f"\n")
|
||||||
|
|||||||
Reference in New Issue
Block a user