Files
etf/datasource/flask_server.py
aszerW 4e3aac5e0e feat: Flask统一数据服务迁移(分层架构)
架构设计:
- 对外统一接口 fetch():自动识别资产类型并路由
- 对内分层实现:各资产类型独立方法,职责单一

新增文件:
- datasource/universal_fetcher.py: 统一数据获取器
  - _fetch_china_index: A股指数(Tushare)
  - _fetch_china_etf: A股ETF(含净值)
  - _fetch_us_index: 美股指数(YFinance+SSH)
  - _fetch_hk_index: 港股指数(YFinance+SSH)
  - _fetch_futures: 期货(Tushare/YFinance)
  - fetch_etf_with_nav: ETF价格+净值(计算溢价率)

- datasource/asset_type_detector.py: 资产类型检测器
  - AssetType枚举:9种资产类型
  - detect(): 自动识别资产类型
  - group_by_type(): 批量分组

- datasource/flask_server.py: Flask API服务
  - LRU + TTL 双缓存机制
  - 8个API端点:ohlcv、etf/nav、batch、cache等

更新:
- datasource/__init__.py: 导出新模块

验证:
- 模块导入成功
- 资产类型检测正确
- A股数据获取正常(沪深300: 5条)
2026-05-12 21:33:19 +08:00

588 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Flask 数据服务 API
==================
提供 RESTful API 接口,支持获取各类资产的 K 线数据
特性:
- 分层架构:各资产类型独立实现
- LRU + TTL 双缓存机制
- SSH隧道支持港美股
- ETF净值获取计算溢价率
运行:
python datasource/flask_server.py
API 文档:
GET / - 服务信息
GET /health - 健康检查
GET /api/v1/asset-type - 检测资产类型
GET /api/v1/ohlcv - 获取K线数据
POST /api/v1/ohlcv/batch - 批量获取K线数据
GET /api/v1/etf/nav - 获取ETF净值
POST /api/v1/cache/clear - 清理缓存
GET /api/v1/cache/stats - 缓存统计
"""
import json
import math
import os
import sys
from datetime import datetime, timedelta
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from dotenv import load_dotenv
load_dotenv()
from flask import Flask, request, jsonify
from flask_cors import CORS
import pandas as pd
from datasource.universal_fetcher import UniversalDataFetcher
from datasource.asset_type_detector import AssetTypeDetector, AssetType
# ============================================================
# Flask 应用配置
# ============================================================
# Cache tuning, read once from the environment at import time.
CACHE_MAXSIZE = int(os.getenv('CACHE_MAXSIZE', '128'))  # entry limit handed to lru_cache
CACHE_TTL_SECONDS = int(os.getenv('CACHE_TTL_SECONDS', '7200'))  # TTL lifetime; 7200 s = 2 hours

# Flask application with cross-origin (CORS) support enabled.
app = Flask(__name__)
CORS(app)

# Shared state, populated lazily: get_fetcher() sets both on first use, and
# the __main__ entry point also loads ssh_config at start-up.
fetcher: Optional[UniversalDataFetcher] = None
ssh_config: Optional[Dict] = None
class TimedCacheEntry:
    """Pair a cached payload with the moment it was stored.

    An entry counts as expired once more than ``CACHE_TTL_SECONDS`` have
    elapsed since it was created. The module-level setting is read on every
    ``is_expired()`` call, not captured at construction time.
    """

    def __init__(self, data: Any):
        self.timestamp = datetime.now()  # creation time (naive local wall clock)
        self.data = data  # the cached payload, stored as-is

    def is_expired(self) -> bool:
        """Return True when the entry is older than CACHE_TTL_SECONDS."""
        age = datetime.now() - self.timestamp
        return age.total_seconds() > CACHE_TTL_SECONDS


# TTL cache store: (code, start, end) -> TimedCacheEntry.
_ttl_cache: Dict[Tuple, TimedCacheEntry] = dict()
# ============================================================
# 初始化
# ============================================================
def get_ssh_config() -> Optional[Dict]:
    """Read the SSH tunnel settings from the environment.

    Returns None unless ``SSH_ENABLED`` equals "true" (case-insensitive).
    Otherwise returns a dict with ``enabled`` (always True), ``host``,
    ``port``, ``username``, ``key_path`` and ``local_port``. Both port values
    go through ``int()``, so a non-numeric value raises ValueError.
    """
    env = os.getenv
    if env('SSH_ENABLED', 'false').lower() != 'true':
        return None

    config = {"enabled": True}
    config["host"] = env('SSH_HOST', '')
    config["port"] = int(env('SSH_PORT', '22'))
    config["username"] = env('SSH_USERNAME', '')
    config["key_path"] = env('SSH_KEY_PATH', 'hk_ecs.pem')
    config["local_port"] = int(env('SSH_LOCAL_PORT', '1080'))
    return config
def get_fetcher() -> UniversalDataFetcher:
    """Return the shared UniversalDataFetcher, creating it on the first call.

    The first call also loads the SSH settings into the module-level
    ``ssh_config`` and passes them to the new fetcher. No lock is taken.
    NOTE(review): if requests are served concurrently, simultaneous first
    calls could each build an instance. Confirm that is acceptable.
    """
    global fetcher, ssh_config
    if fetcher is not None:
        return fetcher
    ssh_config = get_ssh_config()
    fetcher = UniversalDataFetcher(ssh_config=ssh_config)
    return fetcher
# ============================================================
# 缓存机制
# ============================================================
class _NoDataAvailable(Exception):
    """Raised by _fetch_data_lru when the fetcher returns no rows."""


@lru_cache(maxsize=CACHE_MAXSIZE)
def _fetch_data_lru(code: str, start: str, end: str) -> str:
    """
    Fetch (code, start, end) and return the payload serialized as JSON.

    Memoized with lru_cache. Every failure is raised rather than returned:
    lru_cache only stores values that are returned, so a raised exception
    leaves nothing in the cache and the next call retries the fetch.

    Raises:
        _NoDataAvailable: the fetcher returned None or an empty frame.
        Exception: whatever the fetcher or the serialization raises.
    """
    f = get_fetcher()
    # Pre-bind so that, if the fetcher's context manager swallows an
    # exception from fetch(), this is treated as "no data".
    df = None
    with f:
        df = f.fetch(code, start, end)
    if df is None or len(df) == 0:
        raise _NoDataAvailable(code)
    result = dataframe_to_json(df)
    result['code'] = code
    result['asset_type'] = AssetTypeDetector.detect(code).value
    return json.dumps(result)


def _fetch_data_cached(code: str, start: str, end: str) -> Optional[str]:
    """
    Fetch data through the LRU cache and return it as a JSON string.

    Returns:
        - the serialized payload (dataframe_to_json output plus ``code`` and
          ``asset_type``) on success;
        - None when no data is available;
        - ``json.dumps({"error": message})`` when the fetch raised.

    Only successful payloads are memoized. "No data" and errors are
    recomputed on every call, so a transient upstream failure is not pinned
    in the cache until someone clears it.
    """
    try:
        return _fetch_data_lru(code, start, end)
    except _NoDataAvailable:
        return None
    except Exception as e:
        return json.dumps({"error": str(e)})


# Callers manage the LRU layer via _fetch_data_cached.cache_clear() and
# .cache_info(), so forward both from the memoized implementation.
_fetch_data_cached.cache_clear = _fetch_data_lru.cache_clear
_fetch_data_cached.cache_info = _fetch_data_lru.cache_info
def fetch_data_with_ttl(
    code: str,
    start: str,
    end: str,
    nocache: bool = False
) -> Tuple[Optional[Dict], bool]:
    """
    Fetch data through the TTL cache, which sits in front of the LRU cache.

    A non-expired ``_ttl_cache`` entry is returned as-is. Otherwise the data
    is loaded through ``_fetch_data_cached``.

    ``_fetch_data_cached`` is an lru_cache with no expiry of its own.
    Whenever a fresh fetch is required (``nocache`` was requested or the TTL
    entry has expired), the LRU layer is cleared first. Without that, it
    would hand back the same stale payload and the TTL would never take
    effect. Clearing it wholesale costs other keys little, because any key
    with a live TTL entry is answered from ``_ttl_cache`` without consulting
    the LRU.

    Only successful payloads are stored in ``_ttl_cache``. After a "no data"
    (None) or ``{"error": ...}`` result, the LRU is cleared as well, so the
    next request retries instead of replaying a memoized failure. A failed
    ``nocache`` fetch leaves any existing TTL entry untouched.

    NOTE: expired entries are only removed when their key is looked up
    again (or by clear_cache()), so keys that are never re-requested stay
    in ``_ttl_cache``.

    Args:
        code: instrument code
        start: start date (YYYY-MM-DD)
        end: end date (YYYY-MM-DD)
        nocache: skip cached data and force a fresh fetch; a successful
            result replaces the TTL entry for this key

    Returns:
        (data, is_cached): the payload dict (possibly an ``{"error": ...}``
        dict), or None when no data is available, and whether it was served
        from the TTL cache.
    """
    cache_key = (code, start, end)

    if nocache:
        _fetch_data_cached.cache_clear()
    else:
        entry = _ttl_cache.get(cache_key)
        if entry is not None:
            if not entry.is_expired():
                return entry.data, True
            # Expired: drop it, and drop the never-expiring LRU layer too so
            # the refetch below is real rather than a stale LRU hit.
            del _ttl_cache[cache_key]
            _fetch_data_cached.cache_clear()

    result_json = _fetch_data_cached(code, start, end)
    result = json.loads(result_json) if result_json else None

    if result is None or "error" in result:
        # Do not cache failures in either layer; the next request retries.
        _fetch_data_cached.cache_clear()
        return result, False

    _ttl_cache[cache_key] = TimedCacheEntry(result)
    return result, False
def clear_cache():
    """Empty both cache layers: the TTL dict and the LRU-memoized fetcher.

    The TTL dict is emptied in place, so no ``global`` declaration is needed.
    """
    _ttl_cache.clear()
    _fetch_data_cached.cache_clear()
def get_cache_info() -> Dict:
    """Summarize cache state: LRU counters, TTL entry count and TTL length."""
    # CacheInfo is a named tuple ordered (hits, misses, maxsize, currsize).
    hits, misses, maxsize, currsize = _fetch_data_cached.cache_info()
    lru_stats = {
        "hits": hits,
        "misses": misses,
        "maxsize": maxsize,
        "currsize": currsize,
    }
    return {
        "lru_cache": lru_stats,
        "ttl_cache_size": len(_ttl_cache),
        "ttl_seconds": CACHE_TTL_SECONDS,
    }
# ============================================================
# DataFrame 转换
# ============================================================
def _json_safe_value(value: Any) -> Any:
    """Return None for values JSON cannot encode (NaN, +/-inf, NaT, pd.NA); otherwise return the value unchanged."""
    if isinstance(value, float) and not math.isfinite(value):
        return None
    if value is pd.NaT or value is pd.NA:
        return None
    return value


def _date_range(df: pd.DataFrame, df_reset: pd.DataFrame, has_date_column: bool) -> Dict[str, Optional[str]]:
    """
    Compute ``{"start": ..., "end": ...}`` as YYYY-MM-DD strings.

    Uses the original index when its min/max support ``strftime`` (for
    example a DatetimeIndex). Otherwise, such as a RangeIndex with a
    ``trade_date`` column, it falls back to the normalized ``date`` column,
    whose YYYY-MM-DD strings sort chronologically. Both bounds are None
    when neither source is usable.
    """
    try:
        return {
            "start": df.index.min().strftime('%Y-%m-%d'),
            "end": df.index.max().strftime('%Y-%m-%d'),
        }
    except (AttributeError, TypeError, ValueError):
        pass  # index is not date-like (or its min/max is NaT): try the date column
    if has_date_column:
        dates = df_reset['date'].dropna()
        if len(dates) > 0:
            return {"start": dates.min(), "end": dates.max()}
    return {"start": None, "end": None}


def dataframe_to_json(df: pd.DataFrame) -> Dict:
    """
    Convert an OHLCV DataFrame into a JSON-serializable dict.

    The index becomes a column. An unnamed integer index (e.g. a default
    RangeIndex) is dropped instead: it carries no dates, and
    ``pd.to_datetime`` would read its integers as epoch nanoseconds. The
    first parseable column among date/Date/index/trade_date/datetime is
    formatted as YYYY-MM-DD and renamed to ``date``. NaN, infinity, NaT and
    pd.NA cells become None, because ``json.dumps`` would otherwise emit the
    non-standard ``NaN`` / ``Infinity`` tokens.

    Returns:
        ``{"data": [], "count": 0}`` for None or an empty frame. Otherwise
        ``{"data": records, "count": n, "columns": [...],
        "date_range": {"start": ..., "end": ...}}``.
    """
    if df is None or len(df) == 0:
        return {"data": [], "count": 0}

    drop_index = df.index.name is None and pd.api.types.is_integer_dtype(df.index)
    df_reset = df.reset_index(drop=drop_index)

    has_date_column = False
    for col in ('date', 'Date', 'index', 'trade_date', 'datetime'):
        if col not in df_reset.columns:
            continue
        try:
            df_reset[col] = pd.to_datetime(df_reset[col]).dt.strftime('%Y-%m-%d')
        except Exception:
            continue  # best effort: not parseable as dates, try the next candidate
        if col != 'date':
            df_reset = df_reset.rename(columns={col: 'date'})
        has_date_column = True
        break

    records = [
        {key: _json_safe_value(value) for key, value in row.items()}
        for row in df_reset.to_dict(orient='records')
    ]
    return {
        "data": records,
        "count": len(records),
        "columns": list(df_reset.columns),
        "date_range": _date_range(df, df_reset, has_date_column),
    }
def validate_date(date_str: str) -> bool:
    """Return True if *date_str* parses with the format ``%Y-%m-%d``.

    Parsing is delegated to ``datetime.strptime``, so impossible calendar
    dates such as ``2024-02-30`` are rejected.
    """
    try:
        datetime.strptime(date_str, '%Y-%m-%d')
    except ValueError:
        return False
    else:
        return True
def get_default_dates() -> Tuple[str, str]:
    """Return ``(start, end)`` spanning the 90 days up to now, as YYYY-MM-DD strings."""
    fmt = '%Y-%m-%d'
    today = datetime.now()
    ninety_days_ago = today - timedelta(days=90)
    return ninety_days_ago.strftime(fmt), today.strftime(fmt)
# ============================================================
# API 路由
# ============================================================
@app.route('/')
def index():
    """Landing page: service name, version, endpoint map, cache and SSH status.

    The module global ``ssh_config`` is only populated by ``get_fetcher()``
    or the ``__main__`` entry point. When the app is imported by another
    launcher, it is still None before the first data request. In that case
    the settings are read from the environment so that ``ssh_status`` is
    accurate from start-up.
    """
    ssh_settings = ssh_config if ssh_config is not None else get_ssh_config()
    endpoints = {
        "info": "/",
        "health": "/health",
        "asset_type": "/api/v1/asset-type?code={code}",
        "ohlcv": "/api/v1/ohlcv?code={code}&start={YYYY-MM-DD}&end={YYYY-MM-DD}",
        "ohlcv_nocache": "/api/v1/ohlcv?code={code}&nocache=true",
        "batch": "POST /api/v1/ohlcv/batch",
        "etf_nav": "/api/v1/etf/nav?code={code}",
        "cache_clear": "POST /api/v1/cache/clear",
        "cache_stats": "/api/v1/cache/stats",
    }
    supported_assets = {
        "china_index": ["000300.SH", "399006.SZ", "H30269.CSI"],
        "china_etf": ["159915.SZ", "513100.SH", "518880.SH"],
        "hk_index": ["HSI", "HSTECH.HK"],
        "us_index": ["NDX", "SPX", "N225", "GDAXI"],
        "futures": ["AU.SHF", "CU.SHF", "CL.NYM"],
        "crypto": ["BTC", "ETH"],
    }
    return jsonify({
        "name": "Universal Data Fetcher API",
        "version": "2.0.0",
        "description": "统一数据获取服务(分层架构)",
        "architecture": "Unified entry + Asset-specific methods",
        "features": [
            "分层架构(各资产类型独立实现)",
            "LRU + TTL 双缓存机制",
            "SSH隧道支持港美股",
            "ETF净值获取计算溢价率",
        ],
        "endpoints": endpoints,
        "supported_assets": supported_assets,
        "cache_config": get_cache_info(),
        "ssh_status": "enabled" if ssh_settings and ssh_settings.get('enabled') else "disabled",
    })
@app.route('/health')
def health():
    """Health check: status, current timestamp and whether SSH is configured.

    The module global ``ssh_config`` stays None until ``get_fetcher()`` or
    the ``__main__`` entry point sets it. In that case the settings are read
    from the environment, so ``ssh_configured`` is not falsely reported as
    False right after start-up.
    """
    ssh_settings = ssh_config if ssh_config is not None else get_ssh_config()
    return jsonify({
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "ssh_configured": ssh_settings is not None and ssh_settings.get('enabled', False),
    })
@app.route('/api/v1/asset-type')
def detect_asset_type():
    """Classify ``?code=`` with AssetTypeDetector and return its type and description.

    Responds 400 when ``code`` is missing or blank.
    """
    code = request.args.get('code', '').strip()
    if code:
        asset_type = AssetTypeDetector.detect(code)
        description = AssetTypeDetector.get_description(asset_type)
        payload = {"code": code, "asset_type": asset_type.value, "description": description}
        return jsonify(payload)

    return jsonify({
        "error": "Missing required parameter: code",
        "example": "/api/v1/asset-type?code=000300.SH"
    }), 400
@app.route('/api/v1/ohlcv')
def get_ohlcv():
    """
    Return OHLCV data for a single instrument.

    Query Parameters:
        code: instrument code (required)
        start: start date YYYY-MM-DD (optional, defaults to 90 days ago)
        end: end date YYYY-MM-DD (optional, defaults to today)
        nocache: "true" to bypass the cache (optional, default false)

    Only an omitted bound is defaulted; a supplied one is kept.

    Responses: 200 with the payload plus a ``cached`` flag; 400 for a missing
    code, a malformed date or start after end; 404 when no data is
    available; 500 when the fetch reported an error.
    """
    code = request.args.get('code', '').strip()
    start = request.args.get('start', '').strip()
    end = request.args.get('end', '').strip()
    nocache = request.args.get('nocache', 'false').lower() == 'true'

    if not code:
        return jsonify({
            "error": "Missing required parameter: code",
            "example": "/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31"
        }), 400

    # Default only the missing bound(s) instead of discarding a supplied one.
    if not start or not end:
        default_start, default_end = get_default_dates()
        start = start or default_start
        end = end or default_end

    if not validate_date(start) or not validate_date(end):
        return jsonify({
            "error": "Invalid date format. Use YYYY-MM-DD",
            "start": start,
            "end": end,
        }), 400

    # Compare parsed dates: strptime also accepts non-zero-padded input such
    # as 2024-1-5, for which plain string comparison would be wrong.
    if datetime.strptime(start, '%Y-%m-%d') > datetime.strptime(end, '%Y-%m-%d'):
        return jsonify({
            "error": "Invalid date range: start must not be after end",
            "start": start,
            "end": end,
        }), 400

    result, is_cached = fetch_data_with_ttl(code, start, end, nocache)

    if result is None:
        return jsonify({
            "code": code,
            "asset_type": AssetTypeDetector.detect(code).value,
            "error": "No data available",
            "start": start,
            "end": end,
        }), 404

    if "error" in result:
        return jsonify({
            "code": code,
            "asset_type": AssetTypeDetector.detect(code).value,
            "error": result["error"],
        }), 500

    # The payload may be the very dict held in the TTL cache (shared with
    # other requests). Attach the per-request flag to a shallow copy so it
    # never leaks into the cache.
    response = dict(result)
    response['cached'] = is_cached
    return jsonify(response)
@app.route('/api/v1/ohlcv/batch', methods=['POST'])
def batch_ohlcv():
    """
    Fetch OHLCV data for several instruments in one request.

    JSON body: ``{"codes": ["000300.SH", "NDX"], "start": "2024-01-01",
    "end": "2024-03-31"}``. ``codes`` must be a non-empty list of non-empty
    strings; surrounding whitespace is stripped, as on the GET endpoint.
    ``start``/``end`` are optional (null counts as omitted), and only an
    omitted bound is defaulted.

    Returns 400 for a missing or invalid body, codes or dates. Otherwise
    returns a per-code ``results`` map (failed codes carry ``error`` and
    empty ``data``) plus success/failure counts, or 500 if the batch aborts.
    """
    # silent=True: a missing body, a non-JSON Content-Type or malformed JSON
    # yields None, so the client receives the descriptive 400 below instead
    # of a bare framework error.
    data = request.get_json(silent=True)
    if not data:
        return jsonify({
            "error": "Missing request body",
            "example": {
                "codes": ["000300.SH", "NDX"],
                "start": "2024-01-01",
                "end": "2024-03-31",
            }
        }), 400
    if not isinstance(data, dict):
        return jsonify({
            "error": "Invalid request body: expected a JSON object"
        }), 400

    codes = data.get('codes', [])
    if not codes or not isinstance(codes, list):
        return jsonify({
            "error": "Missing or invalid parameter: codes (must be a list)"
        }), 400
    if not all(isinstance(code, str) and code.strip() for code in codes):
        return jsonify({
            "error": "Invalid parameter: codes must contain only non-empty strings"
        }), 400
    codes = [code.strip() for code in codes]

    # str(... or '') tolerates null / non-string values instead of crashing
    # on .strip(); a non-date value is then rejected by validate_date.
    start = str(data.get('start') or '').strip()
    end = str(data.get('end') or '').strip()
    if not start or not end:
        default_start, default_end = get_default_dates()
        start = start or default_start
        end = end or default_end

    if not validate_date(start) or not validate_date(end):
        return jsonify({
            "error": "Invalid date format. Use YYYY-MM-DD",
        }), 400
    if datetime.strptime(start, '%Y-%m-%d') > datetime.strptime(end, '%Y-%m-%d'):
        return jsonify({
            "error": "Invalid date range: start must not be after end",
        }), 400

    f = get_fetcher()
    results = {}
    success_count = 0
    failed_count = 0
    try:
        # NOTE(review): the per-code fetch (_fetch_data_cached) enters
        # ``with f:`` again for uncached codes. Presumably this outer
        # ``with`` keeps the connection open for the whole batch. Confirm
        # that UniversalDataFetcher supports nested use.
        with f:
            for code in codes:
                result, _ = fetch_data_with_ttl(code, start, end)
                if result is not None and "error" not in result:
                    results[code] = result
                    success_count += 1
                else:
                    results[code] = {
                        "code": code,
                        "asset_type": AssetTypeDetector.detect(code).value,
                        "error": result.get("error", "No data") if result else "No data",
                        "data": [],
                        "count": 0,
                    }
                    failed_count += 1
    except Exception as e:
        return jsonify({"error": f"Batch fetch failed: {str(e)}"}), 500

    return jsonify({
        "results": results,
        "success_count": success_count,
        "failed_count": failed_count,
        "total": len(codes),
        "start": start,
        "end": end,
    })
@app.route('/api/v1/etf/nav')
def get_etf_nav():
    """
    Return price and NAV series for an A-share ETF, plus the latest premium rate.

    Query Parameters:
        code: ETF code (required)
        start / end: YYYY-MM-DD (optional; only an omitted bound is defaulted)

    ``premium_rate = (latest close - latest NAV) / latest NAV``. It is
    reported only when both values are present and not NaN and the NAV is
    positive. ``premium_date`` is the index date of the latest NAV row.

    Responses: 400 for a missing code, bad dates or a non-ETF code; 500 with
    the error message if fetching or processing fails.
    """
    code = request.args.get('code', '').strip()
    start = request.args.get('start', '').strip()
    end = request.args.get('end', '').strip()

    if not code:
        return jsonify({
            "error": "Missing required parameter: code",
            "example": "/api/v1/etf/nav?code=513100.SH"
        }), 400

    if not start or not end:
        default_start, default_end = get_default_dates()
        start = start or default_start
        end = end or default_end

    # Validate dates the same way /api/v1/ohlcv does instead of passing
    # malformed values upstream.
    if not validate_date(start) or not validate_date(end):
        return jsonify({
            "error": "Invalid date format. Use YYYY-MM-DD",
            "start": start,
            "end": end,
        }), 400
    if datetime.strptime(start, '%Y-%m-%d') > datetime.strptime(end, '%Y-%m-%d'):
        return jsonify({
            "error": "Invalid date range: start must not be after end",
            "start": start,
            "end": end,
        }), 400

    asset_type = AssetTypeDetector.detect(code)
    if asset_type != AssetType.CHINA_ETF:
        return jsonify({
            "error": f"Not an ETF: {code} (type: {asset_type.value})",
            "hint": "Only A股ETF (codes starting with 51/52/15/16) supported",
        }), 400

    f = get_fetcher()
    try:
        with f:
            price_df, nav_df = f.fetch_etf_with_nav(code, start, end)

        # dataframe_to_json already maps None / empty frames to
        # {"data": [], "count": 0}. Never test a DataFrame's truthiness
        # (``if price_df``): pandas raises ValueError for it.
        result = {
            "code": code,
            "price": dataframe_to_json(price_df),
            "nav": dataframe_to_json(nav_df),
        }

        has_nav = nav_df is not None and len(nav_df) > 0
        has_price = price_df is not None and len(price_df) > 0
        if has_nav and has_price:
            latest_nav = nav_df['nav'].iloc[-1]
            latest_price = price_df['close'].iloc[-1]
            # NOTE(review): the last price row and the last NAV row are taken
            # independently and may fall on different dates.
            # NaN guards keep the response valid JSON (no bare NaN token).
            if pd.notna(latest_nav) and pd.notna(latest_price) and latest_nav > 0:
                result['premium_rate'] = float((latest_price - latest_nav) / latest_nav)
                result['premium_date'] = nav_df.index[-1].strftime('%Y-%m-%d')

        return jsonify(result)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route('/api/v1/cache/clear', methods=['POST'])
def clear_cache_endpoint():
    """Drop every cached payload; report cache statistics before and after."""
    snapshot_before = get_cache_info()
    clear_cache()
    snapshot_after = get_cache_info()
    return jsonify({
        "message": "Cache cleared successfully",
        "before": snapshot_before,
        "after": snapshot_after,
    })
@app.route('/api/v1/cache/stats')
def cache_stats():
    """Return the current cache statistics produced by get_cache_info()."""
    stats = get_cache_info()
    return jsonify(stats)
# ============================================================
# 错误处理
# ============================================================
@app.errorhandler(404)
def not_found(error):
    """JSON 404 response listing the routes this service exposes.

    ``error`` is the exception Flask passes to the handler; it is unused.
    """
    known_endpoints = [
        "/",
        "/health",
        "/api/v1/asset-type",
        "/api/v1/ohlcv",
        "/api/v1/ohlcv/batch",
        "/api/v1/etf/nav",
        "/api/v1/cache/clear",
        "/api/v1/cache/stats",
    ]
    body = {"error": "Endpoint not found", "available_endpoints": known_endpoints}
    return jsonify(body), 404
@app.errorhandler(500)
def internal_error(error):
    """JSON body for HTTP 500 responses, echoing ``str(error)`` as the message."""
    payload = {"error": "Internal server error"}
    payload["message"] = str(error)
    return jsonify(payload), 500
# ============================================================
# 启动服务
# ============================================================
def _run_server() -> None:
    """Command-line entry point: parse options, report SSH/cache setup, start Flask.

    Loads the SSH settings into the module-level ``ssh_config`` before the
    server starts.
    NOTE(review): ``--debug`` turns on Flask debug mode. Combined with the
    default host 0.0.0.0, it exposes the debugger beyond localhost.
    """
    import argparse

    global ssh_config

    parser = argparse.ArgumentParser(description='Universal Data Fetcher API Server')
    parser.add_argument('--host', default='0.0.0.0', help='Host to bind')
    parser.add_argument('--port', type=int, default=5000, help='Port to bind')
    parser.add_argument('--debug', action='store_true', help='Enable debug mode')
    args = parser.parse_args()

    ssh_config = get_ssh_config()
    ssh_enabled = bool(ssh_config and ssh_config.get('enabled'))
    if ssh_enabled:
        print(f"✓ SSH 隧道已配置: {ssh_config['host']}:{ssh_config['port']}")
    else:
        print("✗ SSH 隧道未启用仅支持A股数据")

    base_url = f"http://{args.host}:{args.port}"
    print("\n🚀 Universal Data Fetcher API Server v2.0")
    print(f" Host: {args.host}")
    print(f" Port: {args.port}")
    print(f" Cache: LRU({CACHE_MAXSIZE}) + TTL({CACHE_TTL_SECONDS}s)")
    print(f"\n📖 API: {base_url}/")
    print(f" 健康检查: {base_url}/health")

    app.run(host=args.host, port=args.port, debug=args.debug)


if __name__ == '__main__':
    _run_server()