- 添加内存缓存,默认TTL 5分钟(可通过 CACHE_TTL_SECONDS 环境变量配置) - 新增缓存相关端点: - POST /api/v1/cache/clear - 清理缓存 - GET /api/v1/cache/stats - 缓存统计信息 - /api/v1/ohlcv 支持 nocache 参数跳过缓存 - 响应中返回 cached 字段标识是否命中缓存 - 更新 API 文档和版本号到 1.1.0 - 删除不需要的 build-flask-and-push.sh 和 docker-compose.flask.yml
651 lines
18 KiB
Python
651 lines
18 KiB
Python
"""
Flask 数据服务 API
==================
提供 RESTful API 接口,支持获取各类资产的 K 线数据

支持的资产类型:
    - A股指数 (000300.SH, 399006.SZ)
    - A股ETF (510300.SH, 159915.SZ)
    - A股股票 (600000.SH, 000001.SZ)
    - 港股指数 (HSI, HSTECH.HK)
    - 美股指数 (NDX, SPX, DJI)
    - 美股股票 (AAPL, MSFT, GOOGL)
    - 期货合约 (AU.SHF, CU.SHF)
    - 加密货币 (BTC, ETH)

运行:
    python core/datasource/flask_server.py

API 文档:
    GET /health - 健康检查
    GET /api/v1/asset-type?code=000300.SH - 检测资产类型
    GET /api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31 - 获取K线数据(nocache=true 可跳过缓存)
    POST /api/v1/ohlcv/batch - 批量获取K线数据
    POST /api/v1/cache/clear - 清理缓存
    GET /api/v1/cache/stats - 缓存统计信息
    GET /api/v1/supported-codes - 支持的代码示例
"""
|
||
|
||
import os
|
||
import sys
|
||
import json
|
||
import hashlib
|
||
from pathlib import Path
|
||
from datetime import datetime, timedelta
|
||
from typing import Optional, Dict, Any, List
|
||
from functools import wraps
|
||
|
||
# 添加项目根目录到路径
|
||
project_root = Path(__file__).parent.parent.parent
|
||
sys.path.insert(0, str(project_root))
|
||
|
||
from dotenv import load_dotenv
|
||
load_dotenv()
|
||
|
||
from flask import Flask, request, jsonify
|
||
from flask_cors import CORS
|
||
import pandas as pd
|
||
|
||
from core.datasource.universal_fetcher import (
|
||
UniversalDataFetcher,
|
||
AssetTypeDetector,
|
||
)
|
||
|
||
|
||
# ============================================================
|
||
# Flask 应用配置
|
||
# ============================================================
|
||
|
||
# The WSGI application; CORS is switched on for every route.
app = Flask(__name__)
CORS(app)

# Process-wide fetcher instance and the SSH settings it was built with.
# Both start empty and are filled in by get_fetcher().
fetcher: Optional[UniversalDataFetcher] = None
ssh_config: Optional[Dict] = None

# In-memory cache: cache key -> (payload, unix timestamp when it was stored).
_cache: Dict[str, Any] = {}
# Entry lifetime in seconds, read once at import time from CACHE_TTL_SECONDS
# (defaults to 300, i.e. five minutes).
_cache_ttl: int = int(os.environ.get('CACHE_TTL_SECONDS', '300'))
|
||
|
||
|
||
def _get_cache_key(code: str, start: str, end: str) -> str:
|
||
"""生成缓存键"""
|
||
key = f"{code}:{start}:{end}"
|
||
return hashlib.md5(key.encode()).hexdigest()
|
||
|
||
|
||
def _get_cached(key: str) -> Optional[Any]:
    """Look up a cache entry, evicting it when it has outlived the TTL.

    Returns:
        The stored payload, or None when the key is absent or the entry is
        older than ``_cache_ttl`` seconds (in which case it is also removed).
    """
    try:
        payload, stored_at = _cache[key]
    except KeyError:
        return None

    age_seconds = datetime.now().timestamp() - stored_at
    if age_seconds <= _cache_ttl:
        return payload

    # Expired: drop the entry so it does not linger in memory.
    del _cache[key]
    return None
|
||
|
||
|
||
def _set_cache(key: str, data: Any):
    """Store ``data`` under ``key`` together with the current unix timestamp."""
    stored_at = datetime.now().timestamp()
    _cache[key] = (data, stored_at)
|
||
|
||
|
||
def clear_cache():
    """Remove every entry from the in-memory cache and print a confirmation."""
    _cache.clear()
    print("✓ 缓存已清理")
|
||
|
||
|
||
def get_ssh_config() -> Optional[Dict]:
    """Read the SSH tunnel settings from environment variables.

    Environment variables:
        SSH_ENABLED: "true" (case-insensitive) turns the tunnel on
        SSH_HOST: SSH server address (default "")
        SSH_PORT: SSH port (default 22)
        SSH_USERNAME: SSH user name (default "")
        SSH_KEY_PATH: path to the private key (default "hk_ecs.pem")
        SSH_LOCAL_PORT: local SOCKS5 port (default 1080)

    Returns:
        None when SSH_ENABLED is not "true"; otherwise a dict with the keys
        enabled, host, port, username, key_path and local_port.

    Raises:
        ValueError: if SSH_PORT or SSH_LOCAL_PORT is not an integer.
    """
    env = os.environ.get

    if env('SSH_ENABLED', 'false').lower() != 'true':
        return None

    host = env('SSH_HOST', '')
    port = int(env('SSH_PORT', '22'))
    username = env('SSH_USERNAME', '')
    key_path = env('SSH_KEY_PATH', 'hk_ecs.pem')
    local_port = int(env('SSH_LOCAL_PORT', '1080'))

    return {
        "enabled": True,
        "host": host,
        "port": port,
        "username": username,
        "key_path": key_path,
        "local_port": local_port,
    }
|
||
|
||
|
||
def get_fetcher() -> UniversalDataFetcher:
    """Return the shared data fetcher, creating it on first use.

    On the first call the SSH settings are read via get_ssh_config() and
    stored in the module-level ``ssh_config``; the fetcher built from them is
    kept in the module-level ``fetcher`` and reused afterwards.
    """
    global fetcher, ssh_config

    if fetcher is not None:
        return fetcher

    ssh_config = get_ssh_config()
    fetcher = UniversalDataFetcher(ssh_config=ssh_config)
    return fetcher
|
||
|
||
|
||
def dataframe_to_json(df: pd.DataFrame) -> Dict:
    """Convert a price DataFrame into a JSON-serialisable dictionary.

    The dates are taken from a ``date`` column when one exists after the
    index is reset, otherwise from the (unnamed) index, and are rendered as
    ``YYYY-MM-DD`` strings. Missing values (NaN / NaT / None) become ``None``
    so the response is valid JSON rather than containing bare ``NaN``.

    Args:
        df: the frame to convert; may be None or empty.

    Returns:
        ``{"data": [], "count": 0}`` for a None/empty frame, otherwise a dict
        with ``data`` (list of row dicts), ``count``, ``columns`` and
        ``date_range`` (``{"start", "end"}``, or None when no dates can be
        determined).
    """
    if df is None or len(df) == 0:
        return {"data": [], "count": 0}

    date_fmt = '%Y-%m-%d'

    # A default RangeIndex carries no information when the dates already live
    # in a 'date' column, so it is dropped instead of leaking out as 'index'.
    drop_index = isinstance(df.index, pd.RangeIndex) and 'date' in df.columns
    df_reset = df.reset_index(drop=drop_index)

    # Locate the date values and normalise them to datetimes before
    # formatting, so string/object date columns work as well.
    date_series = None
    if 'date' in df_reset.columns:
        date_series = pd.to_datetime(df_reset['date'])
    elif 'index' in df_reset.columns:
        date_series = pd.to_datetime(df_reset['index'])
        df_reset = df_reset.rename(columns={'index': 'date'})

    if date_series is not None:
        df_reset['date'] = date_series.dt.strftime(date_fmt)

    # Prefer the original datetime index for the range; fall back to the date
    # column when the index is not datetime-like.
    if isinstance(df.index, pd.DatetimeIndex):
        first, last = df.index.min(), df.index.max()
    elif date_series is not None:
        first, last = date_series.min(), date_series.max()
    else:
        first = last = None

    if first is None or pd.isna(first) or pd.isna(last):
        date_range = None
    else:
        date_range = {
            "start": first.strftime(date_fmt),
            "end": last.strftime(date_fmt),
        }

    def _json_safe(value):
        # Scalars that pandas considers missing have no JSON representation.
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return None
        return value

    records = [
        {column: _json_safe(value) for column, value in row.items()}
        for row in df_reset.to_dict(orient='records')
    ]

    return {
        "data": records,
        "count": len(records),
        "columns": list(df_reset.columns),
        "date_range": date_range,
    }
|
||
|
||
|
||
def validate_date(date_str: str) -> bool:
    """Return True when ``date_str`` parses with the ``%Y-%m-%d`` format."""
    try:
        datetime.strptime(date_str, '%Y-%m-%d')
    except ValueError:
        return False
    return True
|
||
|
||
|
||
def get_default_dates() -> tuple:
    """Return the default (start, end) range: the 90 days ending now.

    Both values are ``YYYY-MM-DD`` strings based on the local clock.
    """
    date_fmt = '%Y-%m-%d'
    today = datetime.now()
    ninety_days_ago = today - timedelta(days=90)
    return ninety_days_ago.strftime(date_fmt), today.strftime(date_fmt)
|
||
|
||
|
||
# ============================================================
|
||
# API 路由
|
||
# ============================================================
|
||
|
||
@app.route('/')
def index():
    """Landing page: describe the API, its endpoints and the current cache/SSH state."""
    features = [
        "自动资产类型识别",
        "内存缓存(默认5分钟TTL)",
        "SSH隧道支持(港美股)",
        "批量数据获取"
    ]
    endpoints = {
        "health": "/health",
        "asset_type": "/api/v1/asset-type?code={code}",
        "ohlcv": "/api/v1/ohlcv?code={code}&start={YYYY-MM-DD}&end={YYYY-MM-DD}",
        "ohlcv_nocache": "/api/v1/ohlcv?code={code}&start={YYYY-MM-DD}&end={YYYY-MM-DD}&nocache=true",
        "batch": "POST /api/v1/ohlcv/batch",
        "cache_clear": "POST /api/v1/cache/clear",
        "cache_stats": "/api/v1/cache/stats",
    }
    supported_assets = [
        "A股指数 (000300.SH)",
        "A股ETF (510300.SH)",
        "A股股票 (600000.SH)",
        "港股指数 (HSI)",
        "美股指数 (NDX)",
        "美股股票 (AAPL)",
        "期货 (AU.SHF)",
        "加密货币 (BTC)",
    ]
    cache_config = {
        "ttl_seconds": _cache_ttl,
        "current_size": len(_cache)
    }
    # ssh_config stays None until it has been loaded, which reads as disabled.
    ssh_enabled = ssh_config and ssh_config.get('enabled')

    return jsonify({
        "name": "Universal Data Fetcher API",
        "version": "1.1.0",
        "description": "统一数据获取服务,支持A股、港股、美股、期货、加密货币",
        "features": features,
        "endpoints": endpoints,
        "supported_assets": supported_assets,
        "cache_config": cache_config,
        "ssh_status": "enabled" if ssh_enabled else "disabled",
    })
|
||
|
||
|
||
@app.route('/health')
def health():
    """Health check: report status, the current local time and whether SSH is configured."""
    now_iso = datetime.now().isoformat()
    ssh_configured = ssh_config is not None and ssh_config.get('enabled', False)
    return jsonify({
        "status": "healthy",
        "timestamp": now_iso,
        "ssh_configured": ssh_configured,
    })
|
||
|
||
|
||
@app.route('/api/v1/asset-type')
def detect_asset_type():
    """Report the asset type detected for an instrument code.

    Query Parameters:
        code: instrument code (required)

    Returns:
        200 with ``{"code", "asset_type", "description"}``, e.g.
        ``{"code": "000300.SH", "asset_type": "china_index", "description": "A股指数"}``;
        400 with an error message when ``code`` is missing or blank.
    """
    code = request.args.get('code', '').strip()
    if code == '':
        missing = {
            "error": "Missing required parameter: code",
            "example": "/api/v1/asset-type?code=000300.SH"
        }
        return jsonify(missing), 400

    # Human-readable label for each asset type the detector can return.
    labels = dict(
        china_index='A股指数',
        china_etf='A股ETF',
        china_stock='A股股票',
        hk_index='港股指数',
        hk_stock='港股股票',
        us_index='美股指数',
        us_stock='美股股票',
        futures='期货合约',
        crypto='加密货币',
    )

    asset_type = AssetTypeDetector.detect(code)
    return jsonify({
        "code": code,
        "asset_type": asset_type,
        "description": labels.get(asset_type, '未知类型'),
    })
|
||
|
||
|
||
@app.route('/api/v1/ohlcv')
def get_ohlcv():
    """Return OHLCV bars for one instrument, served from the cache when possible.

    Query Parameters:
        code: instrument code (required)
        start: start date, YYYY-MM-DD (optional; defaults to 90 days ago)
        end: end date, YYYY-MM-DD (optional; defaults to today)
        retry: retry count handed to the fetcher (optional; default 3, also
            used when the value is not an integer)
        nocache: "true" skips both the cache lookup and the cache store
            (optional; default false)

    Returns:
        200 with the converted frame plus ``code``, ``asset_type`` and
        ``cached`` (True when the payload came from the cache), e.g.
        {
            "code": "000300.SH",
            "asset_type": "china_index",
            "data": [{"date": "2024-01-02", "open": 3500.0, ...}, ...],
            "count": 58,
            "date_range": {"start": "2024-01-02", "end": "2024-03-29"},
            "cached": false
        }
        400 when ``code`` is missing, a date is malformed, or start is after end;
        404 when the fetcher returns no rows;
        500 with the exception text when fetching or conversion raises.
    """
    date_fmt = '%Y-%m-%d'
    params = request.args

    code = params.get('code', '').strip()
    start = params.get('start', '').strip()
    end = params.get('end', '').strip()
    retry = params.get('retry', '3')
    nocache = params.get('nocache', 'false').lower() == 'true'

    # Parameter validation
    if not code:
        return jsonify({
            "error": "Missing required parameter: code",
            "example": "/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31"
        }), 400

    # Fill in whichever end of the range was omitted.
    if not start or not end:
        default_start, default_end = get_default_dates()
        start = start or default_start
        end = end or default_end

    if not validate_date(start) or not validate_date(end):
        return jsonify({
            "error": "Invalid date format. Use YYYY-MM-DD",
            "start": start,
            "end": end,
        }), 400

    # Reject inverted ranges up front instead of sending them to the data
    # source. Parsed dates are compared because strptime also accepts
    # non-zero-padded values, for which string comparison would be wrong.
    if datetime.strptime(start, date_fmt) > datetime.strptime(end, date_fmt):
        return jsonify({
            "error": "Invalid date range: start must not be after end",
            "start": start,
            "end": end,
        }), 400

    try:
        retry_count = int(retry)
    except ValueError:
        retry_count = 3

    asset_type = AssetTypeDetector.detect(code)

    # Cache lookup
    cache_key = _get_cache_key(code, start, end)
    if not nocache:
        cached_data = _get_cached(cache_key)
        if cached_data:
            # Build a new dict so the entry held by the cache is not mutated.
            return jsonify({**cached_data, 'cached': True})

    # Named differently from the module-level `fetcher` to avoid shadowing it.
    data_fetcher = get_fetcher()

    try:
        with data_fetcher:
            df = data_fetcher.fetch(code, start, end, retry=retry_count)

        if df is None or len(df) == 0:
            return jsonify({
                "code": code,
                "asset_type": asset_type,
                "error": "No data available for the specified code and date range",
                "start": start,
                "end": end,
            }), 404

        result = dataframe_to_json(df)
        result['code'] = code
        result['asset_type'] = asset_type
        result['cached'] = False

        # Store a copy so later changes to `result` cannot leak into the cache.
        if not nocache:
            _set_cache(cache_key, result.copy())

        return jsonify(result)

    except Exception as e:
        # Route-level boundary: report the failure as a JSON 500.
        return jsonify({
            "code": code,
            "asset_type": asset_type,
            "error": str(e),
            "start": start,
            "end": end,
        }), 500
|
||
|
||
|
||
def _batch_failure(code, message):
    """Build the per-code result entry used when a batch fetch yields nothing."""
    return {
        "code": code,
        "asset_type": AssetTypeDetector.detect(code),
        "error": message,
        "data": [],
        "count": 0,
    }


@app.route('/api/v1/ohlcv/batch', methods=['POST'])
def batch_ohlcv():
    """Fetch OHLCV bars for several instruments in one request.

    Request Body (JSON object):
        {
            "codes": ["000300.SH", "NDX", "HSI"],
            "start": "2024-01-01",
            "end": "2024-03-31",
            "retry": 3
        }
        ``start``/``end`` default to the last 90 days; ``retry`` defaults to 3
        and falls back to 3 when it is not an integer.

    Returns:
        200 with
        {
            "results": {"000300.SH": {"data": [...], "count": 58, ...},
                        "HSI": {"error": "No data available", ...}},
            "success_count": 2,
            "failed_count": 1,
            "total": 3,
            "start": "...",
            "end": "..."
        }
        400 when the body is missing or not a JSON object, ``codes`` is not a
        non-empty list of non-empty strings, a date is malformed, or start is
        after end;
        500 when the fetch loop as a whole fails.
    """
    date_fmt = '%Y-%m-%d'
    example = {
        "codes": ["000300.SH", "NDX"],
        "start": "2024-01-01",
        "end": "2024-03-31",
    }

    # silent=True returns None for a missing or malformed body, so the JSON
    # error below is produced instead of get_json() aborting the request.
    data = request.get_json(silent=True)

    if not data:
        return jsonify({
            "error": "Missing request body",
            "example": example
        }), 400

    if not isinstance(data, dict):
        return jsonify({
            "error": "Request body must be a JSON object",
            "example": example
        }), 400

    codes = data.get('codes', [])
    # JSON may carry null or non-string values here; coerce before stripping.
    start = str(data.get('start') or '').strip()
    end = str(data.get('end') or '').strip()

    # Same fallback as the single-code endpoint: unusable values become 3.
    try:
        retry = int(data.get('retry', 3))
    except (TypeError, ValueError):
        retry = 3

    if not codes or not isinstance(codes, list):
        return jsonify({
            "error": "Missing or invalid parameter: codes (must be a list)"
        }), 400

    # Codes are used as dictionary keys and passed to the fetcher, so each one
    # has to be a non-empty string.
    if not all(isinstance(code, str) and code.strip() for code in codes):
        return jsonify({
            "error": "Invalid parameter: codes must contain only non-empty strings"
        }), 400

    # Fill in whichever end of the range was omitted.
    if not start or not end:
        default_start, default_end = get_default_dates()
        start = start or default_start
        end = end or default_end

    if not validate_date(start) or not validate_date(end):
        return jsonify({
            "error": "Invalid date format. Use YYYY-MM-DD",
            "start": start,
            "end": end,
        }), 400

    # Compare parsed dates: strptime accepts non-zero-padded values, for
    # which plain string comparison would give the wrong order.
    if datetime.strptime(start, date_fmt) > datetime.strptime(end, date_fmt):
        return jsonify({
            "error": "Invalid date range: start must not be after end",
            "start": start,
            "end": end,
        }), 400

    # Named differently from the module-level `fetcher` to avoid shadowing it.
    data_fetcher = get_fetcher()
    results = {}
    success_count = 0
    failed_count = 0

    try:
        with data_fetcher:
            for code in codes:
                # One failing code must not abort the rest of the batch.
                try:
                    df = data_fetcher.fetch(code, start, end, retry=retry)

                    if df is not None and len(df) > 0:
                        result = dataframe_to_json(df)
                        result['code'] = code
                        result['asset_type'] = AssetTypeDetector.detect(code)
                        results[code] = result
                        success_count += 1
                    else:
                        results[code] = _batch_failure(code, "No data available")
                        failed_count += 1

                except Exception as e:
                    results[code] = _batch_failure(code, str(e))
                    failed_count += 1

    except Exception as e:
        return jsonify({
            "error": f"Batch fetch failed: {str(e)}"
        }), 500

    return jsonify({
        "results": results,
        "success_count": success_count,
        "failed_count": failed_count,
        "total": len(codes),
        "start": start,
        "end": end,
    })
|
||
|
||
|
||
@app.route('/api/v1/cache/clear', methods=['POST'])
def clear_cache_endpoint():
    """Empty the in-memory cache.

    Returns:
        {"message": "Cache cleared successfully", "previous_size": <entries
        before clearing>, "current_size": 0}
    """
    entries_before = len(_cache)
    clear_cache()

    payload = {
        "message": "Cache cleared successfully",
        "previous_size": entries_before,
        "current_size": 0
    }
    return jsonify(payload)
|
||
|
||
|
||
def _deep_sizeof(obj: Any) -> int:
    """Sum sys.getsizeof over ``obj`` and everything reachable through dicts, lists, tuples and sets."""
    seen_ids = set()
    pending = [obj]
    total = 0

    while pending:
        item = pending.pop()
        # Count each distinct object once, even when it is referenced repeatedly.
        if id(item) in seen_ids:
            continue
        seen_ids.add(id(item))
        total += sys.getsizeof(item)

        if isinstance(item, dict):
            pending.extend(item.keys())
            pending.extend(item.values())
        elif isinstance(item, (list, tuple, set, frozenset)):
            pending.extend(item)

    return total


@app.route('/api/v1/cache/stats')
def cache_stats():
    """Report cache statistics.

    Returns:
        {
            "cache_size": 10,
            "cache_ttl_seconds": 300,
            "memory_estimate_mb": 5.2
        }
        ``memory_estimate_mb`` is a rough figure. sys.getsizeof is shallow, so
        each cached payload is walked recursively; measuring only the outer
        dict would ignore every cached row.
    """
    total_size = 0
    for key, (data, _) in _cache.items():
        total_size += sys.getsizeof(key) + _deep_sizeof(data)

    return jsonify({
        "cache_size": len(_cache),
        "cache_ttl_seconds": _cache_ttl,
        "memory_estimate_mb": round(total_size / 1024 / 1024, 2)
    })
|
||
|
||
|
||
@app.route('/api/v1/supported-codes')
def get_supported_codes():
    """List example codes for every supported asset type.

    Returns:
        A JSON object keyed by asset type; each value has a ``description``
        and a list of ``examples``, e.g.
        {
            "china_index": {"description": "A股指数", "examples": ["000300.SH", ...]},
            "china_etf": {"description": "A股ETF", "examples": ["510300.SH", ...]},
            ...
        }
    """
    # (asset type, description, example codes)
    catalogue = (
        ("china_index", "A股指数", ["000300.SH", "399006.SZ", "000016.SH", "H30269.CSI"]),
        ("china_etf", "A股ETF", ["510300.SH", "159915.SZ", "510500.SH", "513100.SH"]),
        ("china_stock", "A股股票", ["600000.SH", "000001.SZ"]),
        ("hk_index", "港股指数", ["HSI", "HSTECH.HK"]),
        ("us_index", "美股指数", ["NDX", "SPX", "DJI", "N225", "GDAXI"]),
        ("us_stock", "美股股票", ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"]),
        ("futures", "期货合约", ["AU.SHF", "CU.SHF"]),
        ("crypto", "加密货币", ["BTC", "ETH"]),
    )

    payload = {
        asset_type: {"description": description, "examples": examples}
        for asset_type, description, examples in catalogue
    }
    return jsonify(payload)
|
||
|
||
|
||
# ============================================================
|
||
# 错误处理
|
||
# ============================================================
|
||
|
||
@app.errorhandler(404)
def not_found(error):
    """Answer unknown routes with a JSON 404 that lists the available endpoints."""
    available_endpoints = [
        "/",
        "/health",
        "/api/v1/asset-type?code={code}",
        "/api/v1/ohlcv?code={code}&start={YYYY-MM-DD}&end={YYYY-MM-DD}",
        "/api/v1/ohlcv/batch",
        "/api/v1/cache/clear",
        "/api/v1/cache/stats",
        "/api/v1/supported-codes",
    ]
    body = {
        "error": "Endpoint not found",
        "available_endpoints": available_endpoints
    }
    return jsonify(body), 404
|
||
|
||
|
||
@app.errorhandler(500)
def internal_error(error):
    """Answer unhandled server errors with a JSON 500 carrying the error text."""
    body = {
        "error": "Internal server error",
        "message": str(error)
    }
    return jsonify(body), 500
|
||
|
||
|
||
# ============================================================
|
||
# 启动服务
|
||
# ============================================================
|
||
|
||
if __name__ == '__main__':
    import argparse

    # Command-line options for the development server.
    parser = argparse.ArgumentParser(description='Universal Data Fetcher API Server')
    parser.add_argument('--host', default='0.0.0.0', help='Host to bind (default: 0.0.0.0)')
    parser.add_argument('--port', type=int, default=5000, help='Port to bind (default: 5000)')
    parser.add_argument('--debug', action='store_true', help='Enable debug mode')
    args = parser.parse_args()

    # Load the SSH settings up front so the status can be printed at startup.
    ssh_config = get_ssh_config()
    if ssh_config and ssh_config.get('enabled'):
        print(f"✓ SSH 隧道已配置: {ssh_config['host']}:{ssh_config['port']}")
    else:
        print("✗ SSH 隧道未启用(仅支持A股数据)")

    base_url = f"http://{args.host}:{args.port}"

    # Startup banner
    print("\n🚀 启动 Universal Data Fetcher API Server")
    print(f" Host: {args.host}")
    print(f" Port: {args.port}")
    print(f" Debug: {args.debug}")
    print(f"\n📖 API 文档: {base_url}/")
    print(f" 健康检查: {base_url}/health")
    print("\n💡 示例请求:")
    print(f" curl '{base_url}/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31'")
    print("\n")

    app.run(host=args.host, port=args.port, debug=args.debug)
|