feat: Flask统一数据服务迁移(分层架构)

架构设计:
- 对外统一接口 fetch():自动识别资产类型并路由
- 对内分层实现:各资产类型独立方法,职责单一

新增文件:
- datasource/universal_fetcher.py: 统一数据获取器
  - _fetch_china_index: A股指数(Tushare)
  - _fetch_china_etf: A股ETF(含净值)
  - _fetch_us_index: 美股指数(YFinance+SSH)
  - _fetch_hk_index: 港股指数(YFinance+SSH)
  - _fetch_futures: 期货(Tushare/YFinance)
  - fetch_etf_with_nav: ETF价格+净值(计算溢价率)

- datasource/asset_type_detector.py: 资产类型检测器
  - AssetType枚举:9种资产类型
  - detect(): 自动识别资产类型
  - group_by_type(): 批量分组

- datasource/flask_server.py: Flask API服务
  - LRU + TTL 双缓存机制
  - 8个API端点:ohlcv、etf/nav、batch、cache等

更新:
- datasource/__init__.py: 导出新模块

验证:
- 模块导入成功
- 资产类型检测正确
- A股数据获取正常(沪深300: 5条)
This commit is contained in:
2026-05-12 21:33:19 +08:00
parent c63158c99d
commit 4e3aac5e0e
5 changed files with 1144 additions and 0 deletions

588
datasource/flask_server.py Normal file
View File

@@ -0,0 +1,588 @@
"""
Flask 数据服务 API
==================
提供 RESTful API 接口,支持获取各类资产的 K 线数据
特性:
- 分层架构:各资产类型独立实现
- LRU + TTL 双缓存机制
- SSH隧道支持港美股
- ETF净值获取计算溢价率
运行:
python datasource/flask_server.py
API 文档:
GET / - 服务信息
GET /health - 健康检查
GET /api/v1/asset-type - 检测资产类型
GET /api/v1/ohlcv - 获取K线数据
POST /api/v1/ohlcv/batch - 批量获取K线数据
GET /api/v1/etf/nav - 获取ETF净值
POST /api/v1/cache/clear - 清理缓存
GET /api/v1/cache/stats - 缓存统计
"""
import os
import sys
import json
from pathlib import Path
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List, Tuple
from functools import lru_cache
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from dotenv import load_dotenv
load_dotenv()
from flask import Flask, request, jsonify
from flask_cors import CORS
import pandas as pd
from datasource.universal_fetcher import UniversalDataFetcher
from datasource.asset_type_detector import AssetTypeDetector, AssetType
# ============================================================
# Flask 应用配置
# ============================================================
# Flask application object; cross-origin requests are allowed on every route.
app = Flask(__name__)
CORS(app)

# Process-wide fetcher and SSH settings. Both start empty and are filled in
# by get_fetcher() (ssh_config is also set by the script entry point).
fetcher: Optional[UniversalDataFetcher] = None
ssh_config: Optional[Dict] = None

# Cache tuning, overridable through environment variables.
CACHE_MAXSIZE = int(os.environ.get('CACHE_MAXSIZE', '128'))
CACHE_TTL_SECONDS = int(os.environ.get('CACHE_TTL_SECONDS', '7200'))  # two hours by default
class TimedCacheEntry:
    """A cached payload stamped with the moment it was stored."""

    def __init__(self, data: Any):
        # Creation time is taken from the local wall clock.
        self.timestamp = datetime.now()
        self.data = data

    def is_expired(self) -> bool:
        """Return True once the entry is older than CACHE_TTL_SECONDS."""
        age_seconds = (datetime.now() - self.timestamp).total_seconds()
        return age_seconds > CACHE_TTL_SECONDS
# TTL cache storage, keyed by the (code, start, end) tuple built in fetch_data_with_ttl
_ttl_cache: Dict[Tuple, TimedCacheEntry] = {}
# ============================================================
# 初始化
# ============================================================
def get_ssh_config() -> Optional[Dict]:
    """Build the SSH tunnel settings from environment variables.

    Returns:
        None unless SSH_ENABLED equals "true" (case-insensitive); otherwise a
        dict with the keys enabled, host, port, username, key_path and
        local_port, read from the SSH_* variables with their defaults.
    """
    if os.getenv('SSH_ENABLED', 'false').lower() != 'true':
        return None

    env = os.getenv
    config: Dict = {"enabled": True}
    config["host"] = env('SSH_HOST', '')
    config["port"] = int(env('SSH_PORT', '22'))
    config["username"] = env('SSH_USERNAME', '')
    config["key_path"] = env('SSH_KEY_PATH', 'hk_ecs.pem')
    config["local_port"] = int(env('SSH_LOCAL_PORT', '1080'))
    return config
def get_fetcher() -> UniversalDataFetcher:
    """Return the shared UniversalDataFetcher, creating it on first use.

    The first call also stores the SSH settings from get_ssh_config() in the
    module-level ssh_config.
    """
    global fetcher, ssh_config
    if fetcher is not None:
        return fetcher

    ssh_config = get_ssh_config()
    fetcher = UniversalDataFetcher(ssh_config=ssh_config)
    return fetcher
# ============================================================
# 缓存机制
# ============================================================
@lru_cache(maxsize=CACHE_MAXSIZE)
def _fetch_data_cached(code: str, start: str, end: str) -> Optional[str]:
    """Fetch one instrument and return the payload as a JSON string.

    Because of lru_cache, every return value is memoized per
    (code, start, end), including None and error payloads.

    Returns:
        None when the fetcher yields no rows; otherwise a JSON string holding
        the dataframe_to_json() payload plus "code" and "asset_type". If any
        exception is raised, a JSON string of the form {"error": "<message>"}.
    """
    data_fetcher = get_fetcher()
    try:
        with data_fetcher:
            frame = data_fetcher.fetch(code, start, end)

        if frame is None or len(frame) == 0:
            return None

        payload = dataframe_to_json(frame)
        payload['code'] = code
        payload['asset_type'] = AssetTypeDetector.detect(code).value
        return json.dumps(payload)
    except Exception as exc:
        # Errors travel back as data so the cached signature stays a string.
        return json.dumps({"error": str(exc)})
def fetch_data_with_ttl(
    code: str,
    start: str,
    end: str,
    nocache: bool = False
) -> Tuple[Optional[Dict], bool]:
    """Fetch data for one instrument, honouring the TTL cache.

    Args:
        code: Instrument code.
        start: Start date.
        end: End date.
        nocache: When True, skip the cached entry and fetch again.

    Returns:
        (data, is_cached): the payload dict (None when there is no data, or a
        dict containing "error" when the fetch failed) and whether it was
        served from the TTL cache. The dict is a shallow copy, so callers may
        add keys without altering the cached entry.
    """
    cache_key = (code, start, end)

    if not nocache:
        entry = _ttl_cache.get(cache_key)
        if entry is not None:
            if not entry.is_expired():
                return dict(entry.data), True
            # Expired: drop it and fall through to a fresh fetch.
            _ttl_cache.pop(cache_key, None)

    # lru_cache entries never expire and cannot be invalidated per key. Without
    # this clear, an expired TTL entry would be refilled with the same stale
    # payload, and a memoized None/error result would be replayed forever.
    _fetch_data_cached.cache_clear()
    result_json = _fetch_data_cached(code, start, end)
    if result_json is None:
        return None, False

    result = json.loads(result_json)
    if "error" in result:
        # Failures are not stored, so the next request retries the fetch.
        return result, False

    # Purge expired entries on every store so the dict does not grow forever.
    # list() snapshots the items before any entry is removed.
    for key, cached in list(_ttl_cache.items()):
        if cached.is_expired():
            _ttl_cache.pop(key, None)

    # Also reached with nocache=True, so later cached reads see the fresh data.
    _ttl_cache[cache_key] = TimedCacheEntry(result)
    return dict(result), False
def clear_cache():
    """Empty both cache layers: the TTL dictionary and the lru_cache."""
    _ttl_cache.clear()
    _fetch_data_cached.cache_clear()
def get_cache_info() -> Dict:
    """Report statistics for both cache layers.

    Returns:
        A dict with "lru_cache" (hits, misses, maxsize, currsize of
        _fetch_data_cached), "ttl_cache_size" and "ttl_seconds".
    """
    lru_stats = _fetch_data_cached.cache_info()
    lru_section = {
        field: getattr(lru_stats, field)
        for field in ("hits", "misses", "maxsize", "currsize")
    }
    return {
        "lru_cache": lru_section,
        "ttl_cache_size": len(_ttl_cache),
        "ttl_seconds": CACHE_TTL_SECONDS,
    }
# ============================================================
# DataFrame 转换
# ============================================================
def dataframe_to_json(df: pd.DataFrame) -> Dict:
    """Convert a DataFrame into a JSON-serialisable dictionary.

    The first usable date column (or a DatetimeIndex) is formatted as
    YYYY-MM-DD and exposed under the key "date". Missing values (NaN, NaT,
    NA) become None so the payload is valid JSON.

    Args:
        df: Source frame; None or an empty frame is accepted.

    Returns:
        {"data": [], "count": 0} for None/empty input; otherwise a dict with
        "data" (list of row dicts), "count", "columns" and "date_range"
        (start/end taken from the formatted date column, or None when no date
        column could be identified).
    """
    if df is None or len(df) == 0:
        return {"data": [], "count": 0}

    if isinstance(df.index, pd.DatetimeIndex) and 'date' not in df.columns:
        # Name the index up front so it is picked up whatever it was called.
        df_reset = df.rename_axis('date').reset_index()
    else:
        df_reset = df.reset_index()

    date_strings = []
    for col in ('date', 'Date', 'index', 'trade_date', 'datetime'):
        if col not in df_reset.columns:
            continue
        column = df_reset[col]
        # A reset positional index is an integer column; to_datetime would
        # read those integers as nanoseconds since 1970 and shadow the real
        # date column that comes later in the list.
        if pd.api.types.is_numeric_dtype(column):
            continue
        try:
            formatted = pd.to_datetime(column).dt.strftime('%Y-%m-%d')
        except (ValueError, TypeError, OverflowError, AttributeError):
            # Not parseable as dates: try the next candidate column.
            continue
        df_reset[col] = formatted
        if col != 'date':
            df_reset = df_reset.rename(columns={col: 'date'})
        date_strings = [d for d in formatted.tolist() if isinstance(d, str)]
        break

    records = [
        {key: _json_safe(value) for key, value in row.items()}
        for row in df_reset.to_dict(orient='records')
    ]

    # Zero-padded ISO dates sort lexicographically in chronological order.
    return {
        "data": records,
        "count": len(records),
        "columns": list(df_reset.columns),
        "date_range": {
            "start": min(date_strings) if date_strings else None,
            "end": max(date_strings) if date_strings else None,
        }
    }


def _json_safe(value):
    """Map scalar missing values (NaN, NaT, NA, None) to None."""
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return None
    return value
def validate_date(date_str: str) -> bool:
    """Return True when date_str parses as a YYYY-MM-DD calendar date.

    Non-string input (for example None or a number taken from a JSON body)
    is reported as invalid instead of raising.
    """
    try:
        datetime.strptime(date_str, '%Y-%m-%d')
    except (ValueError, TypeError):
        # ValueError: wrong format or impossible date; TypeError: not a string.
        return False
    return True
def get_default_dates() -> Tuple[str, str]:
    """Return the default (start, end) range: the 90 days ending now, as YYYY-MM-DD strings."""
    date_format = '%Y-%m-%d'
    now = datetime.now()
    ninety_days_ago = now - timedelta(days=90)
    return ninety_days_ago.strftime(date_format), now.strftime(date_format)
# ============================================================
# API 路由
# ============================================================
@app.route('/')
def index():
    """Landing page: describe the service, its endpoints and the current cache/SSH state."""
    features = [
        "分层架构(各资产类型独立实现)",
        "LRU + TTL 双缓存机制",
        "SSH隧道支持港美股",
        "ETF净值获取计算溢价率",
    ]
    endpoints = {
        "info": "/",
        "health": "/health",
        "asset_type": "/api/v1/asset-type?code={code}",
        "ohlcv": "/api/v1/ohlcv?code={code}&start={YYYY-MM-DD}&end={YYYY-MM-DD}",
        "ohlcv_nocache": "/api/v1/ohlcv?code={code}&nocache=true",
        "batch": "POST /api/v1/ohlcv/batch",
        "etf_nav": "/api/v1/etf/nav?code={code}",
        "cache_clear": "POST /api/v1/cache/clear",
        "cache_stats": "/api/v1/cache/stats",
    }
    supported_assets = {
        "china_index": ["000300.SH", "399006.SZ", "H30269.CSI"],
        "china_etf": ["159915.SZ", "513100.SH", "518880.SH"],
        "hk_index": ["HSI", "HSTECH.HK"],
        "us_index": ["NDX", "SPX", "N225", "GDAXI"],
        "futures": ["AU.SHF", "CU.SHF", "CL.NYM"],
        "crypto": ["BTC", "ETH"],
    }

    # Reflects the module-level ssh_config as it stands at request time.
    if ssh_config and ssh_config.get('enabled'):
        ssh_status = "enabled"
    else:
        ssh_status = "disabled"

    return jsonify({
        "name": "Universal Data Fetcher API",
        "version": "2.0.0",
        "description": "统一数据获取服务(分层架构)",
        "architecture": "Unified entry + Asset-specific methods",
        "features": features,
        "endpoints": endpoints,
        "supported_assets": supported_assets,
        "cache_config": get_cache_info(),
        "ssh_status": ssh_status,
    })
@app.route('/health')
def health():
    """Health check: report status, the current timestamp and whether SSH is configured."""
    if ssh_config is None:
        ssh_configured = False
    else:
        ssh_configured = ssh_config.get('enabled', False)

    payload = {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "ssh_configured": ssh_configured,
    }
    return jsonify(payload)
@app.route('/api/v1/asset-type')
def detect_asset_type():
    """Detect the asset type of the instrument given in the `code` query parameter.

    Responds with 400 when `code` is missing or blank.
    """
    code = request.args.get('code', '').strip()
    if code:
        detected = AssetTypeDetector.detect(code)
        return jsonify({
            "code": code,
            "asset_type": detected.value,
            "description": AssetTypeDetector.get_description(detected),
        })

    return jsonify({
        "error": "Missing required parameter: code",
        "example": "/api/v1/asset-type?code=000300.SH"
    }), 400
@app.route('/api/v1/ohlcv')
def get_ohlcv():
    """Return OHLCV data for a single instrument.

    Query Parameters:
        code: Instrument code (required).
        start: Start date YYYY-MM-DD (optional, defaults to 90 days before `end`).
        end: End date YYYY-MM-DD (optional, defaults to today).
        nocache: Skip the cache when "true" (optional, defaults to false).

    Responses:
        200 with the payload plus a "cached" flag; 400 for a missing code or
        a malformed date; 404 when no data is available; 500 when the fetch
        reported an error.
    """
    args = request.args
    code = args.get('code', '').strip()
    start = args.get('start', '').strip()
    end = args.get('end', '').strip()
    nocache = args.get('nocache', 'false').lower() == 'true'

    if not code:
        return jsonify({
            "error": "Missing required parameter: code",
            "example": "/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31"
        }), 400

    # Validate whatever the client supplied before filling in defaults.
    if (start and not validate_date(start)) or (end and not validate_date(end)):
        return jsonify({
            "error": "Invalid date format. Use YYYY-MM-DD",
            "start": start,
            "end": end,
        }), 400

    # Default each bound on its own: a lone `start` or `end` must not be
    # thrown away just because the other one is missing.
    if not end:
        end = get_default_dates()[1]
    if not start:
        # Anchor the window on `end` so a past `end` cannot yield start > end.
        window_end = datetime.strptime(end, '%Y-%m-%d')
        start = (window_end - timedelta(days=90)).strftime('%Y-%m-%d')

    result, is_cached = fetch_data_with_ttl(code, start, end, nocache)

    if result is None:
        return jsonify({
            "code": code,
            "asset_type": AssetTypeDetector.detect(code).value,
            "error": "No data available",
            "start": start,
            "end": end,
        }), 404

    if "error" in result:
        return jsonify({
            "code": code,
            "asset_type": AssetTypeDetector.detect(code).value,
            "error": result["error"],
        }), 500

    # Build a new dict: `result` may be the object held by the cache layer.
    return jsonify({**result, 'cached': is_cached})
@app.route('/api/v1/ohlcv/batch', methods=['POST'])
def batch_ohlcv():
    """Return OHLCV data for several instruments in one request.

    JSON body:
        codes: Non-empty list of instrument code strings (required).
        start: Start date YYYY-MM-DD (optional, defaults to 90 days before `end`).
        end: End date YYYY-MM-DD (optional, defaults to today).

    Responses:
        200 with per-code results and success/failure counts; 400 for a
        missing or malformed body, codes list or date; 500 when the batch
        itself raised.
    """
    # silent=True yields None for an unparsable or non-JSON body, so the
    # client gets the usage hint below rather than a generic framework error.
    data = request.get_json(silent=True)
    if not data or not isinstance(data, dict):
        return jsonify({
            "error": "Missing request body",
            "example": {
                "codes": ["000300.SH", "NDX"],
                "start": "2024-01-01",
                "end": "2024-03-31",
            }
        }), 400

    codes = data.get('codes', [])
    if not codes or not isinstance(codes, list):
        return jsonify({
            "error": "Missing or invalid parameter: codes (must be a list)"
        }), 400
    # Non-string entries (e.g. nested lists) cannot be used as result keys
    # and would otherwise abort the whole batch.
    if not all(isinstance(item, str) and item.strip() for item in codes):
        return jsonify({
            "error": "Invalid parameter: codes must contain only non-empty strings"
        }), 400
    codes = [item.strip() for item in codes]

    # JSON may carry null or numbers here; only strings can be stripped.
    start = data.get('start') or ''
    end = data.get('end') or ''
    if not isinstance(start, str) or not isinstance(end, str):
        return jsonify({
            "error": "Invalid date format. Use YYYY-MM-DD",
        }), 400
    start, end = start.strip(), end.strip()

    if (start and not validate_date(start)) or (end and not validate_date(end)):
        return jsonify({
            "error": "Invalid date format. Use YYYY-MM-DD",
        }), 400

    # Default each bound on its own so a lone `start` or `end` is honoured.
    if not end:
        end = get_default_dates()[1]
    if not start:
        window_end = datetime.strptime(end, '%Y-%m-%d')
        start = (window_end - timedelta(days=90)).strftime('%Y-%m-%d')

    f = get_fetcher()
    results = {}
    success_count = 0
    failed_count = 0
    try:
        # NOTE(review): fetch_data_with_ttl ends up in _fetch_data_cached,
        # which enters `with f:` again. Confirm the fetcher's context manager
        # tolerates nested entry.
        with f:
            for code in codes:
                result, _ = fetch_data_with_ttl(code, start, end)
                if result is not None and "error" not in result:
                    results[code] = result
                    success_count += 1
                else:
                    results[code] = {
                        "code": code,
                        "asset_type": AssetTypeDetector.detect(code).value,
                        "error": result.get("error", "No data") if result else "No data",
                        "data": [],
                        "count": 0,
                    }
                    failed_count += 1
    except Exception as e:
        return jsonify({"error": f"Batch fetch failed: {str(e)}"}), 500

    return jsonify({
        "results": results,
        "success_count": success_count,
        "failed_count": failed_count,
        "total": len(codes),
        "start": start,
        "end": end,
    })
@app.route('/api/v1/etf/nav')
def get_etf_nav():
    """Return ETF price and net-asset-value data, plus the latest premium rate.

    Query Parameters:
        code: ETF code (required).
        start: Start date YYYY-MM-DD (optional, defaults to 90 days before `end`).
        end: End date YYYY-MM-DD (optional, defaults to today).

    Responses:
        200 with "price", "nav" and, when both series are non-empty and the
        latest NAV is positive, "premium_rate" and "premium_date"; 400 for a
        missing code, a malformed date or a non-ETF code; 500 when the fetch
        raised.
    """
    code = request.args.get('code', '').strip()
    start = request.args.get('start', '').strip()
    end = request.args.get('end', '').strip()

    if not code:
        return jsonify({
            "error": "Missing required parameter: code",
            "example": "/api/v1/etf/nav?code=513100.SH"
        }), 400

    # Same date validation as the OHLCV endpoints.
    if (start and not validate_date(start)) or (end and not validate_date(end)):
        return jsonify({
            "error": "Invalid date format. Use YYYY-MM-DD",
            "start": start,
            "end": end,
        }), 400

    # Default each bound on its own so a lone `start` or `end` is honoured.
    if not end:
        end = get_default_dates()[1]
    if not start:
        window_end = datetime.strptime(end, '%Y-%m-%d')
        start = (window_end - timedelta(days=90)).strftime('%Y-%m-%d')

    asset_type = AssetTypeDetector.detect(code)
    if asset_type != AssetType.CHINA_ETF:
        return jsonify({
            "error": f"Not an ETF: {code} (type: {asset_type.value})",
            "hint": "Only A股ETF (codes starting with 51/52/15/16) supported",
        }), 400

    f = get_fetcher()
    try:
        with f:
            price_df, nav_df = f.fetch_etf_with_nav(code, start, end)

        # `if price_df` raises ValueError on a DataFrame (ambiguous truth
        # value), so presence is tested explicitly.
        has_price = price_df is not None and len(price_df) > 0
        has_nav = nav_df is not None and len(nav_df) > 0

        result = {
            "code": code,
            "price": dataframe_to_json(price_df) if price_df is not None else {"data": [], "count": 0},
            "nav": dataframe_to_json(nav_df) if nav_df is not None else {"data": [], "count": 0},
        }

        if has_price and has_nav:
            # NOTE(review): takes the last row of each series; confirm the
            # latest price and latest NAV refer to the same trading day.
            latest_nav = nav_df['nav'].iloc[-1]
            latest_price = price_df['close'].iloc[-1]
            if latest_nav > 0:
                premium = (latest_price - latest_nav) / latest_nav
                # Skip a NaN premium (missing latest price): not valid JSON.
                if pd.notna(premium):
                    result['premium_rate'] = float(premium)
                    result['premium_date'] = nav_df.index[-1].strftime('%Y-%m-%d')

        return jsonify(result)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route('/api/v1/cache/clear', methods=['POST'])
def clear_cache_endpoint():
    """Clear every cache layer and report the statistics before and after."""
    stats_before = get_cache_info()
    clear_cache()
    stats_after = get_cache_info()

    return jsonify({
        "message": "Cache cleared successfully",
        "before": stats_before,
        "after": stats_after
    })
@app.route('/api/v1/cache/stats')
def cache_stats():
    """Return the current cache statistics as JSON."""
    stats = get_cache_info()
    return jsonify(stats)
# ============================================================
# 错误处理
# ============================================================
@app.errorhandler(404)
def not_found(error):
    """Answer 404 errors with a JSON body listing the available endpoints."""
    available = [
        "/", "/health",
        "/api/v1/asset-type",
        "/api/v1/ohlcv",
        "/api/v1/ohlcv/batch",
        "/api/v1/etf/nav",
        "/api/v1/cache/clear",
        "/api/v1/cache/stats",
    ]
    body = {
        "error": "Endpoint not found",
        "available_endpoints": available,
    }
    return jsonify(body), 404
@app.errorhandler(500)
def internal_error(error):
    """Answer 500 errors with a JSON body carrying the error's string form."""
    body = {
        "error": "Internal server error",
        "message": str(error),
    }
    return jsonify(body), 500
# ============================================================
# 启动服务
# ============================================================
if __name__ == '__main__':
    import argparse

    # Command-line options for the built-in Flask server.
    parser = argparse.ArgumentParser(description='Universal Data Fetcher API Server')
    parser.add_argument('--host', default='0.0.0.0', help='Host to bind')
    parser.add_argument('--port', type=int, default=5000, help='Port to bind')
    parser.add_argument('--debug', action='store_true', help='Enable debug mode')
    args = parser.parse_args()

    # Load the SSH settings up front so the startup banner can report them.
    ssh_config = get_ssh_config()
    if not (ssh_config and ssh_config.get('enabled')):
        print("✗ SSH 隧道未启用仅支持A股数据")
    else:
        print(f"✓ SSH 隧道已配置: {ssh_config['host']}:{ssh_config['port']}")

    banner_lines = [
        f"\n🚀 Universal Data Fetcher API Server v2.0",
        f" Host: {args.host}",
        f" Port: {args.port}",
        f" Cache: LRU({CACHE_MAXSIZE}) + TTL({CACHE_TTL_SECONDS}s)",
        f"\n📖 API: http://{args.host}:{args.port}/",
        f" 健康检查: http://{args.host}:{args.port}/health",
    ]
    for banner_line in banner_lines:
        print(banner_line)

    app.run(host=args.host, port=args.port, debug=args.debug)