Files
etf/core/datasource/flask_server.py
aszerW a539a7d096 fix(api): 修复缓存状态判断逻辑
- fetch_data_with_ttl 返回 (data, is_cached) 元组
- 修复 cached 字段始终为 false 的问题
- 批量接口适配新的返回格式
2026-05-07 23:57:19 +08:00

695 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Flask 数据服务 API
==================
提供 RESTful API 接口,支持获取各类资产的 K 线数据
支持的资产类型:
- A股指数 (000300.SH, 399006.SZ)
- A股ETF (510300.SH, 159915.SZ)
- A股股票 (600000.SH, 000001.SZ)
- 港股指数 (HSI, HSTECH.HK)
- 美股指数 (NDX, SPX, DJI)
- 美股股票 (AAPL, MSFT, GOOGL)
- 期货合约 (AU.SHF, CU.SHF)
- 加密货币 (BTC, ETH)
运行:
python core/datasource/flask_server.py
API 文档:
GET /health - 健康检查
GET /api/v1/asset-type?code=000300.SH - 检测资产类型
GET /api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31 - 获取K线数据
POST /api/v1/ohlcv/batch - 批量获取K线数据
"""
import os
import sys
import json
from pathlib import Path
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List
from functools import lru_cache
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from dotenv import load_dotenv
load_dotenv()
from flask import Flask, request, jsonify
from flask_cors import CORS
import pandas as pd
from core.datasource.universal_fetcher import (
UniversalDataFetcher,
AssetTypeDetector,
)
# ============================================================
# Flask application setup
# ============================================================
app = Flask(__name__)
CORS(app)  # enable cross-origin (CORS) support for the app
# Module-level data fetcher instance; created lazily by get_fetcher()
fetcher: Optional[UniversalDataFetcher] = None
# SSH tunnel settings; assigned by get_fetcher() and by the __main__ block
ssh_config: Optional[Dict] = None
def get_ssh_config() -> Optional[Dict]:
    """
    Build the SSH tunnel configuration from environment variables.

    Environment variables:
        SSH_ENABLED: "true" (case-insensitive) turns the tunnel on
        SSH_HOST: SSH server address (default: empty string)
        SSH_PORT: SSH port (default: 22)
        SSH_USERNAME: SSH user name (default: empty string)
        SSH_KEY_PATH: path to the private key (default: hk_ecs.pem)
        SSH_LOCAL_PORT: local port (default: 1080)

    Returns:
        A dict with the keys enabled/host/port/username/key_path/local_port,
        or None when SSH_ENABLED is anything other than "true".

    Raises:
        ValueError: if SSH_PORT or SSH_LOCAL_PORT is not an integer.
    """
    flag = os.getenv('SSH_ENABLED', 'false')
    if flag.lower() != 'true':
        return None

    config = {"enabled": True}
    config["host"] = os.getenv('SSH_HOST', '')
    config["port"] = int(os.getenv('SSH_PORT', '22'))
    config["username"] = os.getenv('SSH_USERNAME', '')
    config["key_path"] = os.getenv('SSH_KEY_PATH', 'hk_ecs.pem')
    config["local_port"] = int(os.getenv('SSH_LOCAL_PORT', '1080'))
    return config
def get_fetcher() -> UniversalDataFetcher:
    """
    Return the shared data fetcher, creating it on first use.

    On the first call the SSH configuration is read via get_ssh_config(),
    stored in the module-level `ssh_config`, and passed to the new
    UniversalDataFetcher, which is stored in the module-level `fetcher`.
    """
    global fetcher, ssh_config
    if fetcher is not None:
        return fetcher

    ssh_config = get_ssh_config()
    fetcher = UniversalDataFetcher(ssh_config=ssh_config)
    return fetcher
# Cache configuration (both values can be overridden through the environment)
CACHE_MAXSIZE = int(os.getenv('CACHE_MAXSIZE', '128'))  # default: 128, used as the lru_cache maxsize
CACHE_TTL_SECONDS = int(os.getenv('CACHE_TTL_SECONDS', '7200'))  # default: 7200 s, i.e. entries expire after 2 hours
# Cache entry that remembers when it was created
class TimedCacheEntry:
    """Pair a cached payload with the wall-clock time it was stored."""

    def __init__(self, data):
        self.timestamp = datetime.now()
        self.data = data

    def is_expired(self) -> bool:
        """Return True once the entry is older than CACHE_TTL_SECONDS."""
        age_seconds = (datetime.now() - self.timestamp).total_seconds()
        return age_seconds > CACHE_TTL_SECONDS
# Memoise the fetch function with lru_cache
@lru_cache(maxsize=CACHE_MAXSIZE)
def _fetch_data_cached(code: str, start: str, end: str, retry: int) -> Optional[str]:
    """
    Fetch data for one instrument and return it as a JSON string.

    The result is memoised by lru_cache on (code, start, end, retry).

    Returns:
        - None when the fetcher yields no rows,
        - a JSON object with the converted data plus "code" and
          "asset_type" on success,
        - a JSON object {"error": "<message>"} when fetching or converting
          raises; this value is memoised like any other.
    """
    data_source = get_fetcher()
    try:
        with data_source:
            frame = data_source.fetch(code, start, end, retry=retry)
            if frame is None or len(frame) == 0:
                return None

            payload = dataframe_to_json(frame)
            payload['code'] = code
            payload['asset_type'] = AssetTypeDetector.detect(code)
            return json.dumps(payload)
    except Exception as exc:
        # Failures are reported in-band so callers can inspect the "error" key.
        failure = {"error": str(exc)}
        return json.dumps(failure)
def fetch_data_with_ttl(code: str, start: str, end: str, retry: int = 3, nocache: bool = False) -> tuple[Optional[Dict], bool]:
    """
    Fetch data for one instrument through the TTL cache.

    Args:
        code: instrument code.
        start: start date string.
        end: end date string.
        retry: retry count forwarded to the fetcher; part of the cache key.
        nocache: when True, fetch directly and neither read nor populate
            any cache.

    Returns:
        (data, is_cached): `data` is the decoded result dict, a dict with an
        "error" key when the fetch failed, or None when there is no data.
        `is_cached` is True only when a live TTL entry was served.
        Successful results are returned as shallow copies, so callers may add
        top-level keys without altering the cached entry.
    """
    if nocache:
        # __wrapped__ is the undecorated function, so the LRU layer is bypassed.
        fresh_json = _fetch_data_cached.__wrapped__(code, start, end, retry)
        return (json.loads(fresh_json) if fresh_json else None, False)

    # The TTL store lives on the function object; clear_data_cache() and
    # get_cache_info() look it up under the same attribute name.
    if not hasattr(fetch_data_with_ttl, '_ttl_cache'):
        fetch_data_with_ttl._ttl_cache = {}
    ttl_cache = fetch_data_with_ttl._ttl_cache

    cache_key = (code, start, end, retry)
    entry = ttl_cache.get(cache_key)
    if entry is not None and not entry.is_expired():
        return dict(entry.data), True

    # Purge every expired entry, not only the requested one, so keys that are
    # never asked for again do not accumulate. Iterate over a snapshot because
    # the dict is modified while purging.
    stale_keys = [key for key, cached in list(ttl_cache.items()) if cached.is_expired()]
    if entry is not None:
        stale_keys.append(cache_key)
    for key in stale_keys:
        ttl_cache.pop(key, None)
    if stale_keys:
        # lru_cache has no expiry and cannot drop a single key. Without this
        # reset the call below would hand back the expired payload again.
        # Live results are still served from the TTL store above.
        _fetch_data_cached.cache_clear()

    result_json = _fetch_data_cached(code, start, end, retry)
    result = json.loads(result_json) if result_json else None

    if result is None or "error" in result:
        # Do not keep failures: the LRU layer has just memoised this one, so
        # reset it to let the next request retry the upstream source.
        _fetch_data_cached.cache_clear()
        return result, False

    ttl_cache[cache_key] = TimedCacheEntry(result)
    return dict(result), False
def clear_data_cache():
    """Empty the lru_cache layer and, if it has been created, the TTL store."""
    _fetch_data_cached.cache_clear()
    ttl_store = getattr(fetch_data_with_ttl, '_ttl_cache', None)
    if ttl_store is not None:
        ttl_store.clear()
def get_cache_info() -> Dict:
    """
    Report statistics for both cache layers.

    Returns:
        A dict with "lru_cache" (hits, misses, maxsize, currsize as reported
        by lru_cache), "ttl_cache_size" (number of TTL entries, 0 if the
        store has not been created yet) and "ttl_seconds".
    """
    lru_stats = _fetch_data_cached.cache_info()
    ttl_store = getattr(fetch_data_with_ttl, '_ttl_cache', None)
    ttl_size = 0 if ttl_store is None else len(ttl_store)

    lru_section = {
        "hits": lru_stats.hits,
        "misses": lru_stats.misses,
        "maxsize": lru_stats.maxsize,
        "currsize": lru_stats.currsize,
    }
    return {
        "lru_cache": lru_section,
        "ttl_cache_size": ttl_size,
        "ttl_seconds": CACHE_TTL_SECONDS,
    }
def _json_safe_value(value):
    """Map missing or non-finite values to None so the payload is valid JSON."""
    if value is pd.NaT or value is pd.NA:
        return None
    # NaN fails every comparison and +/-inf falls outside the open interval,
    # so this catches exactly the floats that JSON cannot represent.
    if isinstance(value, float) and not (float('-inf') < value < float('inf')):
        return None
    return value


def dataframe_to_json(df: pd.DataFrame) -> Dict:
    """
    Convert a DataFrame into a JSON-serialisable dict.

    Args:
        df: the frame to convert; None or an empty frame is accepted.

    Returns:
        {"data": [], "count": 0} for None/empty input. Otherwise a dict with:
        - "data": one record per row; NaN, NaT, pd.NA and +/-inf become None
          (JSON null) because NaN/Infinity are not valid JSON,
        - "count": number of records,
        - "columns": column names after the index has been reset,
        - "date_range": min/max of the original index, formatted as
          YYYY-MM-DD when the values support strftime, else str(),
        - "info": copied from df.attrs['info'] when present and truthy.
    """
    if df is None or len(df) == 0:
        return {"data": [], "count": 0}

    # Reset the index so that it becomes a regular column.
    df_reset = df.reset_index()

    # Format the first recognised date column as YYYY-MM-DD and name it 'date'.
    date_columns = ['date', 'Date', 'index', 'trade_date', 'datetime']
    for col in date_columns:
        if col not in df_reset.columns:
            continue
        try:
            df_reset[col] = pd.to_datetime(df_reset[col]).dt.strftime('%Y-%m-%d')
            if col != 'date':
                df_reset = df_reset.rename(columns={col: 'date'})
            break
        except Exception:
            # Best effort: a column that cannot be parsed as dates is left
            # untouched and the next candidate is tried.
            continue

    # Object columns may still hold date-like objects; format those as well.
    for col in df_reset.columns:
        if df_reset[col].dtype == 'object':
            try:
                df_reset[col] = df_reset[col].apply(
                    lambda x: x.strftime('%Y-%m-%d') if hasattr(x, 'strftime') else x
                )
            except Exception:
                # Best effort: keep the column as it is.
                pass

    records = [
        {key: _json_safe_value(value) for key, value in row.items()}
        for row in df_reset.to_dict(orient='records')
    ]

    first, last = df.index.min(), df.index.max()
    result = {
        "data": records,
        "count": len(records),
        "columns": list(df_reset.columns),
        "date_range": {
            "start": first.strftime('%Y-%m-%d') if hasattr(first, 'strftime') else str(first),
            "end": last.strftime('%Y-%m-%d') if hasattr(last, 'strftime') else str(last),
        },
    }

    # Attach instrument info when the frame carries it.
    if hasattr(df, 'attrs') and df.attrs.get('info'):
        result['info'] = df.attrs['info']
    return result
def validate_date(date_str: str) -> bool:
    """Return True when `date_str` parses with the format '%Y-%m-%d'."""
    try:
        datetime.strptime(date_str, '%Y-%m-%d')
    except ValueError:
        return False
    else:
        return True
def get_default_dates() -> tuple:
    """Return (start, end) as YYYY-MM-DD strings: 90 days ago and today."""
    fmt = '%Y-%m-%d'
    today = datetime.now()
    window_start = today - timedelta(days=90)
    return window_start.strftime(fmt), today.strftime(fmt)
# ============================================================
# API 路由
# ============================================================
@app.route('/')
def index():
    """Landing page: describe the API, its endpoints and the cache settings."""
    cache_info = get_cache_info()
    # The module-level `ssh_config` is only assigned by get_fetcher() and by
    # the __main__ block, so it can still be None before the first data
    # request (e.g. when the app is imported by a WSGI server). Read the
    # environment in that case so the reported status is accurate.
    active_ssh = ssh_config if ssh_config is not None else get_ssh_config()
    return jsonify({
        "name": "Universal Data Fetcher API",
        "version": "1.1.0",
        "description": "统一数据获取服务支持A股、港股、美股、期货、加密货币",
        "features": [
            "自动资产类型识别",
            "LRU+TTL 双缓存机制",
            "SSH隧道支持港美股",
            "批量数据获取"
        ],
        "endpoints": {
            "health": "/health",
            "asset_type": "/api/v1/asset-type?code={code}",
            "ohlcv": "/api/v1/ohlcv?code={code}&start={YYYY-MM-DD}&end={YYYY-MM-DD}",
            "ohlcv_nocache": "/api/v1/ohlcv?code={code}&start={YYYY-MM-DD}&end={YYYY-MM-DD}&nocache=true",
            "batch": "POST /api/v1/ohlcv/batch",
            "cache_clear": "POST /api/v1/cache/clear",
            "cache_stats": "/api/v1/cache/stats",
        },
        "supported_assets": [
            "A股指数 (000300.SH)",
            "A股ETF (510300.SH)",
            "A股股票 (600000.SH)",
            "港股指数 (HSI)",
            "美股指数 (NDX)",
            "美股股票 (AAPL)",
            "期货 (AU.SHF)",
            "加密货币 (BTC)",
        ],
        "cache_config": {
            "maxsize": CACHE_MAXSIZE,
            "ttl_seconds": CACHE_TTL_SECONDS,
            "stats": cache_info['lru_cache']
        },
        "ssh_status": "enabled" if active_ssh and active_ssh.get('enabled') else "disabled",
    })
@app.route('/health')
def health():
    """Health check: report liveness, the current time and the SSH setting."""
    # The module-level `ssh_config` is only assigned by get_fetcher() and by
    # the __main__ block, so it can still be None before the first data
    # request. Read the environment in that case so the flag is accurate.
    active_ssh = ssh_config if ssh_config is not None else get_ssh_config()
    return jsonify({
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "ssh_configured": active_ssh is not None and active_ssh.get('enabled', False),
    })
@app.route('/api/v1/asset-type')
def detect_asset_type():
    """
    Detect the asset type of an instrument code.

    Query Parameters:
        code: instrument code (required)

    Returns:
        200 with
        {
            "code": "000300.SH",
            "asset_type": "china_index",
            "description": "A股指数"
        }
        or 400 with an error message when `code` is missing or blank.
    """
    code = request.args.get('code', '').strip()
    if code == '':
        missing = {
            "error": "Missing required parameter: code",
            "example": "/api/v1/asset-type?code=000300.SH"
        }
        return jsonify(missing), 400

    # Human-readable label for each asset type
    labels = {
        'china_index': 'A股指数',
        'china_etf': 'A股ETF',
        'china_stock': 'A股股票',
        'hk_index': '港股指数',
        'hk_stock': '港股股票',
        'us_index': '美股指数',
        'us_stock': '美股股票',
        'futures': '期货合约',
        'crypto': '加密货币',
    }
    detected = AssetTypeDetector.detect(code)
    return jsonify({
        "code": code,
        "asset_type": detected,
        "description": labels.get(detected, '未知类型'),
    })
@app.route('/api/v1/ohlcv')
def get_ohlcv():
    """
    Return OHLCV data for a single instrument, served through the data cache.

    Query Parameters:
        code: instrument code (required)
        start: start date, YYYY-MM-DD (optional, defaults to 90 days ago)
        end: end date, YYYY-MM-DD (optional, defaults to today)
        retry: retry count (optional, defaults to 3; a non-integer value
            also falls back to 3)
        nocache: "true" to bypass the cache (optional, defaults to false)

    Returns:
        200 with the data payload plus a "cached" flag, e.g.
        {
            "code": "000300.SH",
            "asset_type": "china_index",
            "data": [...],
            "count": 58,
            "date_range": {"start": "2024-01-02", "end": "2024-03-29"},
            "cached": false
        }
        400 when `code` is missing or a date is malformed,
        404 when no data is available,
        500 when the fetch reported an error.
    """
    code = request.args.get('code', '').strip()
    start = request.args.get('start', '').strip()
    end = request.args.get('end', '').strip()
    retry = request.args.get('retry', '3')
    nocache = request.args.get('nocache', 'false').lower() == 'true'

    # Parameter validation
    if not code:
        return jsonify({
            "error": "Missing required parameter: code",
            "example": "/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31"
        }), 400

    # Fill in whichever date bound was omitted
    if not start or not end:
        default_start, default_end = get_default_dates()
        start = start or default_start
        end = end or default_end

    if not validate_date(start) or not validate_date(end):
        return jsonify({
            "error": "Invalid date format. Use YYYY-MM-DD",
            "start": start,
            "end": end,
        }), 400

    try:
        retry_count = int(retry)
    except ValueError:
        retry_count = 3

    result, is_cached = fetch_data_with_ttl(code, start, end, retry_count, nocache)

    if result is None:
        return jsonify({
            "code": code,
            "asset_type": AssetTypeDetector.detect(code),
            "error": "No data available for the specified code and date range",
            "start": start,
            "end": end,
        }), 404

    if "error" in result:
        return jsonify({
            "code": code,
            "asset_type": AssetTypeDetector.detect(code),
            "error": result["error"],
            "start": start,
            "end": end,
        }), 500

    # Build a new dict instead of writing into `result`: that object may be
    # the one held by the cache, and a "cached" key stored there would leak
    # into later responses (including batch results).
    return jsonify({**result, "cached": is_cached})
@app.route('/api/v1/ohlcv/batch', methods=['POST'])
def batch_ohlcv():
    """
    Return OHLCV data for several instruments in one request.

    Request Body (JSON object):
        {
            "codes": ["000300.SH", "NDX", "HSI"],
            "start": "2024-01-01",
            "end": "2024-03-31",
            "retry": 3
        }
        `start`/`end` default to the last 90 days; `retry` defaults to 3 and
        falls back to 3 when it is not an integer.

    Returns:
        200 with
        {
            "results": {
                "000300.SH": {"data": [...], "count": 58, ...},
                "NDX": {"data": [...], "count": 61, ...},
                "HSI": {"error": "No data available", ...}
            },
            "success_count": 2,
            "failed_count": 1,
            "total": 3,
            "start": "...",
            "end": "..."
        }
        400 when the body is missing or not a JSON object, `codes` is not a
        non-empty list of strings, or a date is malformed,
        500 when the batch as a whole fails.
    """
    # silent=True makes a missing, malformed or non-JSON body come back as
    # None, so the client gets the JSON error below instead of a framework
    # error page.
    data = request.get_json(silent=True)
    if not data:
        return jsonify({
            "error": "Missing request body",
            "example": {
                "codes": ["000300.SH", "NDX"],
                "start": "2024-01-01",
                "end": "2024-03-31",
            }
        }), 400
    if not isinstance(data, dict):
        return jsonify({
            "error": "Request body must be a JSON object"
        }), 400

    codes = data.get('codes', [])
    if not codes or not isinstance(codes, list):
        return jsonify({
            "error": "Missing or invalid parameter: codes (must be a list)"
        }), 400
    # Codes become cache keys and result keys, so nested lists/objects would
    # otherwise fail deep inside the loop and abort the whole batch.
    if not all(isinstance(code, str) for code in codes):
        return jsonify({
            "error": "Invalid parameter: codes must contain only strings"
        }), 400

    # JSON values may be null or non-strings; normalise before stripping.
    start = str(data.get('start') or '').strip()
    end = str(data.get('end') or '').strip()
    try:
        retry = int(data.get('retry', 3))
    except (TypeError, ValueError):
        retry = 3

    # Fill in whichever date bound was omitted
    if not start or not end:
        default_start, default_end = get_default_dates()
        start = start or default_start
        end = end or default_end

    if not validate_date(start) or not validate_date(end):
        return jsonify({
            "error": "Invalid date format. Use YYYY-MM-DD",
            "start": start,
            "end": end,
        }), 400

    fetcher = get_fetcher()
    results = {}
    success_count = 0
    failed_count = 0
    try:
        # NOTE(review): the cached fetch function enters the same fetcher
        # context again; presumably the context manager is re-entrant --
        # confirm against UniversalDataFetcher.
        with fetcher:
            for code in codes:
                error_message = None
                result = None
                try:
                    result, _ = fetch_data_with_ttl(code, start, end, retry, nocache=False)
                except Exception as e:
                    error_message = str(e)
                else:
                    if result is None:
                        error_message = "No data available"
                    elif "error" in result:
                        error_message = result["error"]

                if error_message is None:
                    results[code] = result
                    success_count += 1
                else:
                    results[code] = {
                        "code": code,
                        "asset_type": AssetTypeDetector.detect(code),
                        "error": error_message,
                        "data": [],
                        "count": 0,
                    }
                    failed_count += 1
    except Exception as e:
        return jsonify({
            "error": f"Batch fetch failed: {str(e)}"
        }), 500

    return jsonify({
        "results": results,
        "success_count": success_count,
        "failed_count": failed_count,
        "total": len(codes),
        "start": start,
        "end": end,
    })
@app.route('/api/v1/supported-codes')
def get_supported_codes():
    """
    List example codes for each supported asset type.

    Returns:
        A JSON object keyed by asset type; every value holds a
        "description" and a list of "examples", e.g.
        {
            "china_index": {"description": "...", "examples": ["000300.SH", ...]},
            ...
        }
    """
    # (asset type, description, example codes)
    catalog = (
        ("china_index", "A股指数", ["000300.SH", "399006.SZ", "000016.SH", "H30269.CSI"]),
        ("china_etf", "A股ETF", ["510300.SH", "159915.SZ", "510500.SH", "513100.SH"]),
        ("china_stock", "A股股票", ["600000.SH", "000001.SZ"]),
        ("hk_index", "港股指数", ["HSI", "HSTECH.HK"]),
        ("us_index", "美股指数", ["NDX", "SPX", "DJI", "N225", "GDAXI"]),
        ("us_stock", "美股股票", ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"]),
        ("futures", "期货合约", ["AU.SHF", "CU.SHF"]),
        ("crypto", "加密货币", ["BTC", "ETH"]),
    )
    listing = {
        asset_type: {"description": label, "examples": samples}
        for asset_type, label, samples in catalog
    }
    return jsonify(listing)
@app.route('/api/v1/cache/clear', methods=['POST'])
def clear_cache_endpoint():
    """Clear both cache layers and report the statistics before and after."""
    stats_before = get_cache_info()
    clear_data_cache()
    stats_after = get_cache_info()
    return jsonify({
        "message": "Cache cleared successfully",
        "before": stats_before,
        "after": stats_after
    })
@app.route('/api/v1/cache/stats')
def cache_stats():
    """Return the current cache statistics as JSON."""
    stats = get_cache_info()
    return jsonify(stats)
# ============================================================
# 错误处理
# ============================================================
@app.errorhandler(404)
def not_found(error):
    """Answer unknown routes with a JSON 404 that lists the known endpoints."""
    known_endpoints = [
        "/",
        "/health",
        "/api/v1/asset-type?code={code}",
        "/api/v1/ohlcv?code={code}&start={YYYY-MM-DD}&end={YYYY-MM-DD}",
        "/api/v1/ohlcv/batch",
        "/api/v1/cache/clear",
        "/api/v1/cache/stats",
        "/api/v1/supported-codes",
    ]
    body = {
        "error": "Endpoint not found",
        "available_endpoints": known_endpoints
    }
    return jsonify(body), 404
@app.errorhandler(500)
def internal_error(error):
    """Answer server errors with a JSON 500 carrying str(error) as message."""
    body = {
        "error": "Internal server error",
        "message": str(error)
    }
    return jsonify(body), 500
# ============================================================
# 启动服务
# ============================================================
if __name__ == '__main__':
    import argparse

    cli = argparse.ArgumentParser(description='Universal Data Fetcher API Server')
    cli.add_argument('--host', default='0.0.0.0', help='Host to bind (default: 0.0.0.0)')
    cli.add_argument('--port', type=int, default=5000, help='Port to bind (default: 5000)')
    cli.add_argument('--debug', action='store_true', help='Enable debug mode')
    args = cli.parse_args()

    # Load the SSH configuration up front so the status can be printed
    ssh_config = get_ssh_config()
    if not (ssh_config and ssh_config.get('enabled')):
        print("✗ SSH 隧道未启用仅支持A股数据")
    else:
        print(f"✓ SSH 隧道已配置: {ssh_config['host']}:{ssh_config['port']}")

    # Startup banner, printed one line per entry
    banner_lines = (
        f"\n🚀 启动 Universal Data Fetcher API Server",
        f" Host: {args.host}",
        f" Port: {args.port}",
        f" Debug: {args.debug}",
        f" Cache: maxsize={CACHE_MAXSIZE}, ttl={CACHE_TTL_SECONDS}s",
        f"\n📖 API 文档: http://{args.host}:{args.port}/",
        f" 健康检查: http://{args.host}:{args.port}/health",
        f" 缓存统计: http://{args.host}:{args.port}/api/v1/cache/stats",
        f"\n💡 示例请求:",
        f" curl 'http://{args.host}:{args.port}/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31'",
        f"\n",
    )
    for banner_line in banner_lines:
        print(banner_line)

    app.run(host=args.host, port=args.port, debug=args.debug)