- 移除 Tushare 交易日历依赖,A股/美股/港股统一使用 pandas_market_calendars - 简化 get_trading_calendar() 接口,移除 exchange 参数(沪深日历一致) - 删除冗余的 _get_china/us/hk_calendar() 独立函数,直接调用 mcal - 新增 Flask API 端点: /api/v1/trading-calendar, /api/v1/calendar/info - 代码减少 73 行 (-61%),逻辑更集中易维护 - 更新 API 文档描述,三个市场数据源统一
988 lines
33 KiB
Python
988 lines
33 KiB
Python
"""
|
||
Flask 数据服务 API
|
||
==================
|
||
提供 RESTful API 接口,支持获取各类资产的 K 线数据
|
||
|
||
特性:
|
||
- 分层架构:各资产类型独立实现
|
||
- LRU + TTL 双缓存机制
|
||
- SSH隧道支持(港美股)
|
||
- ETF净值获取(计算溢价率)
|
||
|
||
运行:
|
||
python datasource/flask_server.py
|
||
|
||
API 文档:
|
||
GET / - 服务信息
|
||
GET /health - 健康检查
|
||
GET /api/v1/asset-type - 检测资产类型
|
||
GET /api/v1/ohlcv - 获取K线数据
|
||
POST /api/v1/cache/clear - 清理缓存
|
||
GET /api/v1/cache/stats - 缓存统计
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import json
|
||
import pickle
|
||
from pathlib import Path
|
||
from datetime import datetime, timedelta
|
||
from typing import Optional, Dict, Any, List, Tuple
|
||
from functools import lru_cache
|
||
|
||
# 添加项目根目录到路径
|
||
project_root = Path(__file__).parent.parent
|
||
sys.path.insert(0, str(project_root))
|
||
|
||
from dotenv import load_dotenv
|
||
load_dotenv()
|
||
|
||
from flask import Flask, request, jsonify
|
||
from flask_cors import CORS
|
||
from flask_compress import Compress
|
||
import pandas as pd
|
||
|
||
from datasource.universal_fetcher import UniversalDataFetcher
|
||
from datasource.asset_type_detector import AssetTypeDetector, AssetType
|
||
from datasource.models import dataframe_to_ohlcv_response, OHLCVResponse, ErrorResponse
|
||
|
||
|
||
# ============================================================
# Flask application configuration
# ============================================================

app = Flask(__name__)
CORS(app)  # enable cross-origin requests
Compress(app)  # enable gzip compression

# Process-wide data fetcher instance; created lazily by get_fetcher()
fetcher: Optional[UniversalDataFetcher] = None

# Cache configuration (both values can be overridden via environment variables)
CACHE_MAXSIZE = int(os.getenv('CACHE_MAXSIZE', '128'))
CACHE_TTL_SECONDS = int(os.getenv('CACHE_TTL_SECONDS', '7200'))  # default: 2 hours

# Default history start, used when downloading the full data set.
# Set to 1980 to cover the longest histories (S&P 500, Nikkei 225, ...).
DEFAULT_START_DATE = os.getenv('DEFAULT_START_DATE', '1980-01-01')
|
||
|
||
|
||
class TimedCacheEntry:
    """Cache entry that remembers the moment it was created."""

    def __init__(self, data: Any):
        # Creation time is taken from the local wall clock.
        self.timestamp = datetime.now()
        self.data = data

    def is_expired(self) -> bool:
        """Return True once the entry is older than CACHE_TTL_SECONDS."""
        age = datetime.now() - self.timestamp
        return age.total_seconds() > CACHE_TTL_SECONDS
|
||
|
||
|
||
# TTL cache storage: maps a cache key tuple to a timestamped entry.
# fetch_data_with_ttl() uses (code, today, adj) as the key.
_ttl_cache: Dict[Tuple, TimedCacheEntry] = {}
|
||
|
||
|
||
# ============================================================
|
||
# 初始化
|
||
# ============================================================
|
||
|
||
def get_fetcher() -> UniversalDataFetcher:
    """Return the shared data fetcher, building it on first use.

    The instance is created with ``UniversalDataFetcher.from_env()`` and kept
    in the module-level ``fetcher`` variable for later calls.
    """
    global fetcher

    if fetcher is not None:
        return fetcher

    fetcher = UniversalDataFetcher.from_env()
    return fetcher
|
||
|
||
|
||
# ============================================================
|
||
# 缓存机制
|
||
# ============================================================
|
||
|
||
@lru_cache(maxsize=CACHE_MAXSIZE)
def _fetch_full_data_cached(code: str, today: str, adj: str = 'raw') -> Optional[bytes]:
    """
    Download and cache the full history of one instrument as pickle bytes.

    Caching strategy:
    - Daily data (stocks / indices / ETFs / futures): fetched from
      DEFAULT_START_DATE up to ``today``.
    - Crypto: never cached here; the function returns None immediately.
    - Each ``adj`` value (raw / qfq / hfq) is cached independently.

    The LRU key is ``(code, today, adj)``; ``today`` is the real current date
    so that a new day produces a new cache key.

    Args:
        code: instrument code.
        today: current date string, used as fetch end date and cache key part.
        adj: price adjustment mode passed through to the fetcher.

    Returns:
        The pickled DataFrame (``df.attrs`` included), or None when the asset
        is crypto, no rows were returned, or the fetch raised.

    NOTE(review): ``lru_cache`` also memoises the None results, so a failed or
    empty fetch stays cached for this key until the cache is cleared.
    """
    import logging  # local import keeps this block self-contained

    f = get_fetcher()

    asset_type = AssetTypeDetector.detect(code)

    # Crypto is fetched live by the caller and must not be cached.
    if asset_type == AssetType.CRYPTO:
        return None

    # Compatibility between ``adj`` and the asset type is presumably validated
    # inside f.fetch(); any exception it raises is handled below.
    try:
        with f:
            df = f.fetch(code, DEFAULT_START_DATE, today, adj)

        if df is None or len(df) == 0:
            return None

        # Extra metadata stored in attrs so the slicing step can rebuild the result.
        df.attrs['_cache_code'] = code
        df.attrs['_cache_asset_type'] = asset_type.value
        df.attrs['_cache_adj'] = adj

        # Pickle keeps the whole DataFrame, attrs included.
        return pickle.dumps(df)

    except Exception:
        # Best-effort contract: callers expect None on failure, but the cause
        # must not disappear silently.
        logging.getLogger(__name__).exception(
            "Failed to fetch full data for code=%s adj=%s", code, adj
        )
        return None
|
||
|
||
|
||
def _slice_data_from_cache(cached_bytes: bytes, start: str, end: str) -> Dict:
    """
    Cut the requested date window out of a cached, pickled DataFrame.

    Args:
        cached_bytes: pickle-serialised DataFrame.
        start: first date requested by the user.
        end: last date requested by the user.

    Returns:
        The window converted by dataframe_to_json(), extended with code,
        asset_type, adj, requested_range, available_range and, when present,
        the non-internal attrs of the DataFrame.
    """
    # Unpickling restores the DataFrame together with its attrs.
    cached_frame = pickle.loads(cached_bytes)
    meta = cached_frame.attrs

    cached_code = meta.get('_cache_code', '')
    cached_asset_type = meta.get('_cache_asset_type', '')
    cached_adj = meta.get('_cache_adj', 'raw')

    lower_bound = pd.to_datetime(start)
    upper_bound = pd.to_datetime(end)

    # Label slicing needs a sorted index.
    ordered = cached_frame.sort_index()
    window = ordered.loc[lower_bound:upper_bound]

    if len(ordered) > 0:
        available_range = {
            'start': ordered.index.min().strftime('%Y-%m-%d'),
            'end': ordered.index.max().strftime('%Y-%m-%d'),
        }
    else:
        available_range = {'start': None, 'end': None}

    result = dataframe_to_json(window)
    result.update({
        'code': cached_code,
        'asset_type': cached_asset_type,
        'adj': cached_adj,
        'requested_range': {'start': start, 'end': end},
        'available_range': available_range,
    })

    # The cache layer only stores and restores attrs; entries such as nav or
    # premium are interpreted by the API layer. Internal ``_cache_*`` keys are
    # dropped here.
    public_attrs = {
        key: value
        for key, value in window.attrs.items()
        if not key.startswith('_cache_')
    }
    if public_attrs:
        result['attrs'] = public_attrs

    return result
|
||
|
||
|
||
def fetch_data_with_ttl(
    code: str,
    start: str,
    end: str,
    nocache: bool = False,
    timeframe: str = '1d',
    adj: str = 'raw',
    asset_type: Optional[AssetType] = None,
) -> Tuple[Optional[Dict], bool]:
    """
    Fetch data through the TTL cache (crypto is never cached).

    Caching strategy:
    - Daily data (stocks / indices / ETFs / futures): key = (code, today, adj);
      the full history is cached and the requested window is sliced from it.
    - Crypto: downloaded live on every call using ``timeframe``.
    - Each ``adj`` value is cached independently.

    Args:
        code: instrument code.
        start: first date requested by the user.
        end: last date requested by the user.
        nocache: when True, clear all caches and download again.
        timeframe: candle period (only used for crypto).
        adj: price adjustment mode (raw / qfq / hfq).
        asset_type: asset type; auto-detected from ``code`` when omitted.

    Returns:
        ``(data, is_cached)``. ``data`` is None when nothing is available, or a
        dict with an ``error`` key when the crypto download raised.

    NOTE(review): ``asset_type`` only selects the crypto branch here;
    _fetch_full_data_cached() detects the type from ``code`` again, so an
    override is not propagated to the cached download.
    """
    # The real current date is part of the cache key.
    today = datetime.now().strftime('%Y-%m-%d')

    if asset_type is None:
        asset_type = AssetTypeDetector.detect(code)

    # Crypto: download directly, no caching.
    if asset_type == AssetType.CRYPTO:
        f = get_fetcher()
        try:
            with f:
                # Crypto only supports adj='raw'.
                df = f.fetch(code, start, end, adj='raw', timeframe=timeframe)
            if df is None or len(df) == 0:
                return None, False
            result = dataframe_to_json(df, asset_type.value)
            result['code'] = code
            result['asset_type'] = asset_type.value
            result['adj'] = 'raw'
            result['cache_strategy'] = 'no_cache_crypto'
            result['requested_range'] = {'start': start, 'end': end}
            result['timeframe'] = timeframe
            return result, False
        except Exception as e:
            return {'error': str(e), 'code': code, 'asset_type': asset_type.value}, False

    full_cache_key = (code, today, adj)

    # nocache: wipe both cache layers, then download again.
    if nocache:
        _fetch_full_data_cached.cache_clear()
        _ttl_cache.clear()
        cached_bytes = _fetch_full_data_cached(code, today, adj)
        if cached_bytes is None:
            return None, False
        return _slice_data_from_cache(cached_bytes, start, end), False

    stale_bytes = None
    entry = _ttl_cache.get(full_cache_key)
    if entry is not None:
        if not entry.is_expired():
            return _slice_data_from_cache(entry.data, start, end), True
        # Expired: drop it, but keep the payload as a fallback.
        _ttl_cache.pop(full_cache_key, None)
        stale_bytes = entry.data

    if stale_bytes is not None:
        # The LRU layer still holds the same bytes under the same key, so going
        # through it would return the stale copy and the TTL would never
        # trigger a download. Call the undecorated function instead.
        fetch_fresh = getattr(
            _fetch_full_data_cached, '__wrapped__', _fetch_full_data_cached
        )
        cached_bytes = fetch_fresh(code, today, adj)
        if cached_bytes is None:
            # Refresh failed: serving the previous data beats serving nothing.
            cached_bytes = stale_bytes
    else:
        cached_bytes = _fetch_full_data_cached(code, today, adj)

    if cached_bytes is None:
        return None, False

    # Keys contain the date, so entries from earlier days are never looked up
    # again; purge everything that has expired to keep the dict bounded.
    for stale_key in [k for k, v in list(_ttl_cache.items()) if v.is_expired()]:
        _ttl_cache.pop(stale_key, None)

    _ttl_cache[full_cache_key] = TimedCacheEntry(cached_bytes)

    return _slice_data_from_cache(cached_bytes, start, end), False
|
||
|
||
|
||
def clear_cache():
    """Empty both cache layers: the TTL dict and the LRU-cached downloads."""
    _ttl_cache.clear()
    _fetch_full_data_cached.cache_clear()
|
||
|
||
|
||
def get_cache_info() -> Dict:
    """Collect statistics and settings of both cache layers."""
    lru_info = _fetch_full_data_cached.cache_info()
    lru_stats = {
        field: getattr(lru_info, field)
        for field in ("hits", "misses", "maxsize", "currsize")
    }

    stats = {"lru_cache": lru_stats}
    stats["ttl_cache_size"] = len(_ttl_cache)
    stats["ttl_seconds"] = CACHE_TTL_SECONDS
    stats["default_start_date"] = DEFAULT_START_DATE
    stats["cache_strategy"] = "full_data_by_code_and_today"
    return stats
|
||
|
||
|
||
# ============================================================
|
||
# DataFrame 转换
|
||
# ============================================================
|
||
|
||
class JSONEncoder(json.JSONEncoder):
    """JSON encoder for timestamps, numpy scalars and non-finite floats.

    - Objects exposing ``isoformat()`` are written as ISO strings.
    - Objects exposing ``item()`` (numpy scalars) are unwrapped to native values.
    - NaN / Infinity / -Infinity are written as ``null``.

    The base encoder serialises native floats itself and never passes them to
    ``default()``, so non-finite floats are replaced before encoding starts.
    """

    @staticmethod
    def _is_non_finite(value) -> bool:
        """Return True for float NaN, +inf and -inf."""
        return isinstance(value, float) and (
            value != value or value in (float('inf'), float('-inf'))
        )

    @classmethod
    def _sanitize(cls, obj):
        """Return ``obj`` with non-finite floats in dicts/lists/tuples set to None."""
        if cls._is_non_finite(obj):
            return None
        if isinstance(obj, dict):
            return {key: cls._sanitize(value) for key, value in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [cls._sanitize(item) for item in obj]
        return obj

    def iterencode(self, o, _one_shot=False):
        # encode() delegates to iterencode(), so this hook covers both paths.
        return super().iterencode(self._sanitize(o), _one_shot)

    def default(self, obj):
        # pandas Timestamp / datetime-like objects
        if hasattr(obj, 'isoformat'):
            return obj.isoformat()
        # numpy scalar types
        if hasattr(obj, 'item'):
            value = obj.item()
            # e.g. float32 NaN unwraps to a native NaN, which is not valid JSON
            return None if self._is_non_finite(value) else value
        return super().default(obj)
|
||
|
||
|
||
def dataframe_to_json(df: pd.DataFrame, asset_type: Optional[str] = None) -> Dict:
    """Convert a DataFrame into a JSON-serialisable dict.

    Args:
        df: the data to convert; None or an empty frame yields an empty result.
        asset_type: selects the date precision:
            - 'crypto': '%Y-%m-%d %H:%M:%S'
            - anything else: '%Y-%m-%d'

    Returns:
        A dict with ``data`` (list of row dicts), ``count``, ``columns`` and
        ``date_range``. If ``df.attrs`` has an ``info`` entry it is copied to
        the top level, for empty data as well. NaN and +/-inf in float columns
        are returned as None.
    """
    if df is None or len(df) == 0:
        result = {"data": [], "count": 0}
        # Even without rows, pass the info along when present.
        if hasattr(df, 'attrs') and 'info' in df.attrs:
            result['info'] = df.attrs['info']
        return result

    date_format = '%Y-%m-%d %H:%M:%S' if asset_type == 'crypto' else '%Y-%m-%d'

    df_reset = df.reset_index()

    # Format the first usable date-like column and expose it as 'date'.
    for col in ('date', 'Date', 'index', 'trade_date', 'datetime'):
        if col not in df_reset.columns:
            continue
        try:
            df_reset[col] = pd.to_datetime(df_reset[col]).dt.strftime(date_format)
            if col != 'date':
                df_reset = df_reset.rename(columns={col: 'date'})
        except Exception:
            # Best effort: this column is not date-like, try the next candidate.
            continue
        break

    # Replace NaN / inf by None. A float column cannot hold None (``where``
    # would put NaN back), so the column is cast to object first.
    for col in df_reset.columns:
        column = df_reset[col]
        if pd.api.types.is_float_dtype(column.dtype):
            finite = column.notna() & ~column.isin([float('inf'), float('-inf')])
            df_reset[col] = column.astype(object).where(finite, None)

    records = df_reset.to_dict(orient='records')

    result = {
        "data": records,
        "count": len(records),
        "columns": list(df_reset.columns),
        "date_range": {
            "start": df.index.min().strftime(date_format) if len(df) > 0 else None,
            "end": df.index.max().strftime(date_format) if len(df) > 0 else None,
        }
    }

    # Copy info from df.attrs to the top level.
    if hasattr(df, 'attrs') and 'info' in df.attrs:
        result['info'] = df.attrs['info']

    return result
|
||
|
||
|
||
def validate_date(date_str: str) -> bool:
    """Return True when ``date_str`` parses with the '%Y-%m-%d' format."""
    try:
        datetime.strptime(date_str, '%Y-%m-%d')
    except ValueError:
        return False
    else:
        return True
|
||
|
||
|
||
def get_default_dates() -> Tuple[str, str]:
    """Return the default (start, end) range: the last 90 days up to now."""
    fmt = '%Y-%m-%d'
    now = datetime.now()
    ninety_days_ago = now - timedelta(days=90)
    return ninety_days_ago.strftime(fmt), now.strftime(fmt)
|
||
|
||
|
||
def _round_or_none(value, ndigits: int = 6) -> Optional[float]:
    """Round ``value`` to a native float; None for missing, NaN or infinite values."""
    try:
        number = float(value)
    except (TypeError, ValueError):
        return None
    if number != number or number in (float('inf'), float('-inf')):
        # NaN / inf cannot be represented in strict JSON.
        return None
    return round(number, ndigits)


def build_premium_result(premium_series: pd.Series) -> Dict:
    """
    Build the premium-rate part of a response.

    Args:
        premium_series: premium rates indexed by date.

    Returns:
        A dict with premium_series, latest_premium, premium_date and
        premium_stats, or an empty dict when the input is None or empty.
        All numbers are native floats rounded to 6 decimals; values that are
        NaN (for example the std of a single observation) are None.
    """
    if premium_series is None or len(premium_series) == 0:
        return {}

    # Day precision when every timestamp is at 00:00:00, second precision otherwise.
    has_time = any(
        t.hour != 0 or t.minute != 0 or t.second != 0
        for t in premium_series.index
    )
    date_format = '%Y-%m-%d %H:%M:%S' if has_time else '%Y-%m-%d'

    premium_data = [
        {"date": date.strftime(date_format), "premium": _round_or_none(premium)}
        for date, premium in premium_series.items()
    ]

    return {
        "premium_series": premium_data,
        "latest_premium": _round_or_none(premium_series.iloc[-1]),
        "premium_date": premium_series.index[-1].strftime(date_format),
        "premium_stats": {
            "mean": _round_or_none(premium_series.mean()),
            "std": _round_or_none(premium_series.std()),
            "min": _round_or_none(premium_series.min()),
            "max": _round_or_none(premium_series.max()),
            "median": _round_or_none(premium_series.median()),
        },
    }


def build_premium_result_from_attrs(premium_data) -> Dict:
    """
    Build the premium-rate result from a value stored in DataFrame attrs.

    Args:
        premium_data: either a pd.Series (as restored from pickle) or a dict of
            the older JSON form:
            {
                'type': 'series',
                'data': {date_str: premium_value, ...},
                'name': 'premium'
            }

    Returns:
        The same structure as build_premium_result(); an empty dict for None,
        unsupported types, dicts whose type is not 'series', or empty data.
    """
    if isinstance(premium_data, pd.Series):
        return build_premium_result(premium_data)

    if not isinstance(premium_data, dict) or premium_data.get('type') != 'series':
        return {}

    premium_dict = premium_data.get('data', {})
    if not premium_dict:
        return {}

    # Date strings become a datetime index so the shared formatter applies.
    premium_series = pd.Series(premium_dict)
    premium_series.index = pd.to_datetime(premium_series.index)
    return build_premium_result(premium_series)
|
||
|
||
|
||
# ============================================================
|
||
# API 路由
|
||
# ============================================================
|
||
|
||
@app.route('/')
def index():
    """Landing page: describe the service, its endpoints and its runtime status."""
    features = [
        "分层架构(各资产类型独立实现)",
        "LRU + TTL 双缓存机制",
        "SSH隧道支持(港美股)",
        "ETF净值获取(计算溢价率)",
        "多市场交易日历(A股/美股/港股)",
    ]

    endpoints = {
        "info": "/",
        "health": "/health",
        "asset_type": "/api/v1/asset-type?code={code}",
        "trading_calendar": "/api/v1/trading-calendar?market={A|US|HK}&start={YYYY-MM-DD}&end={YYYY-MM-DD}",
        "calendar_info": "/api/v1/calendar/info",
        "ohlcv": "/api/v1/ohlcv?code={code}&start={YYYY-MM-DD}&end={YYYY-MM-DD}&asset_type={type}",
        "ohlcv_nocache": "/api/v1/ohlcv?code={code}&nocache=true",
        "ohlcv_crypto": "/api/v1/ohlcv?code=BTC&timeframe=1d (加密货币必须指定 timeframe)",
        "ohlcv_asset_type": "/api/v1/ohlcv?code={code}&asset_type=china_index (强制覆盖类型)",
        "cache_clear": "POST /api/v1/cache/clear",
        "cache_stats": "/api/v1/cache/stats",
    }

    calendar_markets = {
        "A": "A股(pandas_market_calendars)",
        "US": "美股(pandas_market_calendars)",
        "HK": "港股(pandas_market_calendars)",
    }

    crypto_timeframes = {
        "1d": "日线",
        "1h": "小时线",
        "4h": "4小时线",
        "15m": "15分钟线",
        "1m": "分钟线",
    }

    asset_types = {
        "china_index": "中国指数 (000300.SH, 399006.SZ等)",
        "china_etf": "中国ETF (159915.SZ, 513100.SH等)",
        "us_index": "美股指数 (NDX, SPX, N225等)",
        "hk_index": "港股指数 (HSI, HSTECH.HK等)",
        "futures": "期货 (AU.SHF, CU.SHF等)",
        "crypto": "加密货币 (BTC, ETH - 不缓存)",
    }

    supported_assets = {
        "china_index": ["000300.SH", "399006.SZ", "H30269.CSI"],
        "china_etf": ["159915.SZ", "513100.SH", "518880.SH"],
        "hk_index": ["HSI", "HSTECH.HK"],
        "us_index": ["NDX", "SPX", "N225", "GDAXI"],
        "futures": ["AU.SHF", "CU.SHF", "CL.NYM"],
        "crypto": ["BTC", "ETH"],
    }

    # Live status is gathered last, in the same order as the response keys.
    cache_config = get_cache_info()
    ssh_status = get_fetcher().get_ssh_status()
    calendar_status = get_fetcher().get_calendar_info()

    payload = {
        "name": "Universal Data Fetcher API",
        "version": "2.1.0",
        "description": "统一数据获取服务(分层架构 + 交易日历)",
        "architecture": "Unified entry + Asset-specific methods",
        "features": features,
        "endpoints": endpoints,
        "trading_calendar_markets": calendar_markets,
        "crypto_timeframes": crypto_timeframes,
        "asset_types": asset_types,
        "supported_assets": supported_assets,
        "cache_config": cache_config,
        "ssh": ssh_status,
        "calendar_info": calendar_status,
    }
    return jsonify(payload)
|
||
|
||
|
||
@app.route('/health')
def health():
    """Health check: report status, the current server time and the SSH status."""
    report = {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "ssh": get_fetcher().get_ssh_status(),
    }
    return jsonify(report)
|
||
|
||
|
||
@app.route('/api/v1/asset-type')
def detect_asset_type():
    """Detect the asset type of the instrument given in the ``code`` query parameter."""
    code = request.args.get('code', '').strip()

    if code:
        detected = AssetTypeDetector.detect(code)
        return jsonify({
            "code": code,
            "asset_type": detected.value,
            "description": AssetTypeDetector.get_description(detected),
        })

    # No code supplied: explain how to call the endpoint.
    return jsonify({
        "error": "Missing required parameter: code",
        "example": "/api/v1/asset-type?code=000300.SH"
    }), 400
|
||
|
||
|
||
@app.route('/api/v1/trading-calendar')
def get_trading_calendar():
    """
    Return the trading days of a market within a date range.

    Query Parameters:
        market: market code (required)
            - A: A-shares (SSE / SZSE, which share one calendar)
            - US: US equities (NYSE)
            - HK: Hong Kong equities (HKEX)
        start: start date, YYYY-MM-DD (required)
        end: end date, YYYY-MM-DD (required)

    Returns:
        JSON with a ``trading_dates`` list of date strings.

    Examples:
        /api/v1/trading-calendar?market=A&start=2024-01-01&end=2024-12-31
        /api/v1/trading-calendar?market=US&start=2024-01-01&end=2024-12-31
        /api/v1/trading-calendar?market=HK&start=2024-01-01&end=2024-12-31
    """
    query = request.args
    market = query.get('market', '').strip()
    start = query.get('start', '').strip()
    end = query.get('end', '').strip()

    usage_example = "/api/v1/trading-calendar?market=A&start=2024-01-01&end=2024-12-31"

    # --- parameter validation ---
    if not market:
        return jsonify({
            "error": "Missing required parameter: market",
            "example": usage_example,
            "supported_markets": ["A", "US", "HK"],
        }), 400

    if not (start and end):
        return jsonify({
            "error": "Missing required parameters: start and end",
            "example": usage_example,
        }), 400

    if not (validate_date(start) and validate_date(end)):
        return jsonify({
            "error": "Invalid date format. Use YYYY-MM-DD",
            "start": start,
            "end": end,
        }), 400

    try:
        calendar_days = get_fetcher().get_trading_calendar(market, start, end)
        dates_list = [day.strftime('%Y-%m-%d') for day in calendar_days]

        # Default exchange name reported for each market.
        market_code = market.upper()
        default_exchanges = {'A': 'SSE', 'US': 'NYSE', 'HK': 'HKEX'}

        return jsonify({
            "market": market_code,
            "exchange": default_exchanges.get(market_code, ''),
            "start": start,
            "end": end,
            "trading_dates": dates_list,
            "count": len(dates_list),
        })

    except ValueError as exc:
        return jsonify({
            "error": str(exc),
            "supported_markets": ["A", "US", "HK"],
        }), 400
    except ImportError as exc:
        return jsonify({
            "error": str(exc),
            "hint": "请安装 pandas_market_calendars: pip install pandas_market_calendars",
        }), 500
    except Exception as exc:
        return jsonify({
            "error": f"Failed to fetch trading calendar: {str(exc)}",
        }), 500
|
||
|
||
|
||
|
||
@app.route('/api/v1/calendar/info')
def calendar_info():
    """Return the trading-calendar support information reported by the fetcher."""
    try:
        return jsonify(get_fetcher().get_calendar_info())
    except Exception as exc:
        failure = {"error": f"Failed to get calendar info: {str(exc)}"}
        return jsonify(failure), 500
|
||
|
||
|
||
@app.route('/api/v1/ohlcv')
def get_ohlcv():
    """
    Return OHLCV data for a single instrument.

    Query Parameters:
        code: instrument code (required)
        start: start date YYYY-MM-DD (optional, default: 90 days before ``end``)
        end: end date YYYY-MM-DD (optional, default: today)
        asset_type: asset type (optional, overrides auto-detection)
            - china_index / china_etf / china_stock
            - us_index / us_stock
            - hk_index / hk_stock
            - futures
            - crypto
        adj: price adjustment (optional, default raw)
            - raw: unadjusted prices
            - qfq: forward-adjusted
            - hfq: backward-adjusted
            Only the basic value check happens here; compatibility with the
            asset type is left to the fetcher.
        timeframe: candle period (optional, only used for crypto)
            - 1d (default), 1h, 4h, 15m, 1m
        nocache: skip the cache (optional, default false)

    Responses:
        200 with the serialised OHLCV response model, 400 for invalid
        parameters, 404 when no data is available, 500 when the fetch layer
        reported an error.
    """
    code = request.args.get('code', '').strip()
    start = request.args.get('start', '').strip()
    end = request.args.get('end', '').strip()
    asset_type_param = request.args.get('asset_type', '').strip().lower()
    adj = request.args.get('adj', 'raw').strip().lower()
    timeframe = request.args.get('timeframe', '1d').strip().lower()
    nocache = request.args.get('nocache', 'false').lower() == 'true'

    if not code:
        return jsonify({
            "error": "Missing required parameter: code",
            "example": "/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31",
            "adj_hint": "可选 adj 参数获取复权数据(raw/qfq/hfq)",
        }), 400

    # Basic adj check only.
    if adj not in ['raw', 'qfq', 'hfq']:
        return jsonify({
            "error": f"Invalid adj parameter: {adj}",
            "valid_adj": ['raw', 'qfq', 'hfq'],
            "hint": "adj 必须是 raw/qfq/hfq",
        }), 400

    # Default each bound on its own, so a caller who supplies only one of
    # start/end keeps the value they sent.
    default_start, default_end = get_default_dates()
    if not end:
        end = default_end
    if not start:
        if validate_date(end):
            # Anchor the 90-day window on the effective end date.
            anchor = datetime.strptime(end, '%Y-%m-%d')
            start = (anchor - timedelta(days=90)).strftime('%Y-%m-%d')
        else:
            start = default_start

    if not validate_date(start) or not validate_date(end):
        return jsonify({
            "error": "Invalid date format. Use YYYY-MM-DD",
            "start": start,
            "end": end,
        }), 400

    detected_type = AssetTypeDetector.detect(code)

    # A user-specified type takes precedence over the detected one.
    final_type = detected_type
    if asset_type_param:
        try:
            final_type = AssetType(asset_type_param)
        except ValueError:
            return jsonify({
                "error": f"Invalid asset_type: {asset_type_param}",
                "valid_types": [t.value for t in AssetType],
            }), 400

    # Crypto needs a supported timeframe, whether detected or forced.
    if final_type == AssetType.CRYPTO:
        valid_timeframes = ['1d', '1h', '4h', '15m', '1m', 'daily', 'hourly']
        if timeframe not in valid_timeframes:
            return jsonify({
                "error": f"Invalid timeframe for crypto: {timeframe}",
                "valid_timeframes": valid_timeframes,
                "hint": "加密货币必须指定 timeframe 参数",
            }), 400

    # final_type is passed along so the type is not detected a second time here.
    result, is_cached = fetch_data_with_ttl(code, start, end, nocache, timeframe, adj, final_type)

    if result is None:
        error_response = ErrorResponse(
            error="No data available",
            code=code,
            asset_type=final_type.value,
            adj=adj,
            detected_type=detected_type.value if asset_type_param else None,
        )
        return error_response.model_dump(mode='json'), 404

    if "error" in result:
        error_response = ErrorResponse(
            error=result["error"],
            code=code,
            asset_type=final_type.value,
            adj=adj,
            detected_type=detected_type.value if asset_type_param else None,
        )
        return error_response.model_dump(mode='json'), 500

    df_data = result.get('data', [])
    attrs = result.get('attrs', {})

    # Rebuild a DataFrame for the response conversion helper.
    if df_data:
        df = pd.DataFrame(df_data)
        if 'date' in df.columns:
            df['date'] = pd.to_datetime(df['date'])
            df = df.set_index('date')
    else:
        df = pd.DataFrame()

    # nav / premium are forwarded only when they have the expected pandas types.
    nav_df = attrs.get('nav') if isinstance(attrs.get('nav'), pd.DataFrame) else None
    premium_series = attrs.get('premium') if isinstance(attrs.get('premium'), pd.Series) else None

    response = dataframe_to_ohlcv_response(
        df=df if len(df) > 0 else None,
        code=code,
        asset_type=final_type.value,
        adj=adj,
        cached=is_cached,
        nav_df=nav_df,
        premium_series=premium_series,
        info=attrs.get('info'),
        attrs=attrs,
        columns=result.get('columns'),
        date_range=result.get('date_range'),
        requested_range=result.get('requested_range'),
        available_range=result.get('available_range'),
        cache_strategy=result.get('cache_strategy'),
        timeframe=result.get('timeframe'),
        type_override={
            "detected": detected_type.value,
            "specified": final_type.value,
            "hint": "用户强制覆盖了自动检测结果",
        } if (asset_type_param and detected_type != final_type) else None,
    )

    return response.model_dump(mode='json')
|
||
|
||
|
||
|
||
@app.route('/api/v1/cache/clear', methods=['POST'])
def clear_cache_endpoint():
    """Clear all caches and report the cache statistics before and after."""
    stats_before = get_cache_info()
    clear_cache()
    stats_after = get_cache_info()

    summary = {
        "message": "Cache cleared successfully",
        "before": stats_before,
        "after": stats_after,
    }
    return jsonify(summary)
|
||
|
||
|
||
@app.route('/api/v1/cache/stats')
def cache_stats():
    """Return the current cache statistics as JSON."""
    current_stats = get_cache_info()
    return jsonify(current_stats)
|
||
|
||
|
||
# ============================================================
|
||
# 错误处理
|
||
# ============================================================
|
||
|
||
@app.errorhandler(404)
def not_found(error):
    """404 handler: answer with JSON that lists the available endpoints."""
    known_endpoints = [
        "/",
        "/health",
        "/api/v1/asset-type",
        "/api/v1/trading-calendar",
        "/api/v1/calendar/info",
        "/api/v1/ohlcv",
        "/api/v1/cache/clear",
        "/api/v1/cache/stats",
    ]
    body = {
        "error": "Endpoint not found",
        "available_endpoints": known_endpoints,
    }
    return jsonify(body), 404
|
||
|
||
|
||
@app.errorhandler(500)
def internal_error(error):
    """500 handler: answer with JSON carrying the string form of the error."""
    body = {"error": "Internal server error"}
    body["message"] = str(error)
    return jsonify(body), 500
|
||
|
||
|
||
# ============================================================
|
||
# 启动服务
|
||
# ============================================================
|
||
|
||
# ------------------------------------------------------------
# Script entry point: parse CLI options, show the configuration, start Flask.
# ------------------------------------------------------------
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Universal Data Fetcher API Server')
    parser.add_argument('--host', default='0.0.0.0', help='Host to bind')
    parser.add_argument('--port', type=int, default=80, help='Port to bind')
    parser.add_argument('--debug', action='store_true', help='Enable debug mode')

    args = parser.parse_args()

    # Create the fetcher up front and show its SSH configuration.
    f = get_fetcher()
    ssh_status = f.get_ssh_status()
    if ssh_status['status'] == 'enabled':
        print(f"✓ SSH 隧道已配置: {ssh_status['host']}:{ssh_status['port']}")
    else:
        print("✗ SSH 隧道未启用(仅支持A股数据)")

    # Banner version matches the "version" value served by the index route.
    print("\n🚀 Universal Data Fetcher API Server v2.1.0")
    print(f"   Host: {args.host}")
    print(f"   Port: {args.port}")
    print(f"   Cache: LRU({CACHE_MAXSIZE}) + TTL({CACHE_TTL_SECONDS}s)")
    print(f"\n📖 API: http://{args.host}:{args.port}/")
    print(f"   健康检查: http://{args.host}:{args.port}/health")

    app.run(host=args.host, port=args.port, debug=args.debug)