feat: Flask统一数据服务迁移(分层架构)
架构设计: - 对外统一接口 fetch():自动识别资产类型并路由 - 对内分层实现:各资产类型独立方法,职责单一 新增文件: - datasource/universal_fetcher.py: 统一数据获取器 - _fetch_china_index: A股指数(Tushare) - _fetch_china_etf: A股ETF(含净值) - _fetch_us_index: 美股指数(YFinance+SSH) - _fetch_hk_index: 港股指数(YFinance+SSH) - _fetch_futures: 期货(Tushare/YFinance) - fetch_etf_with_nav: ETF价格+净值(计算溢价率) - datasource/asset_type_detector.py: 资产类型检测器 - AssetType枚举:9种资产类型 - detect(): 自动识别资产类型 - group_by_type(): 批量分组 - datasource/flask_server.py: Flask API服务 - LRU + TTL 双缓存机制 - 8个API端点:ohlcv、etf/nav、batch、cache等 更新: - datasource/__init__.py: 导出新模块 验证: - 模块导入成功 - 资产类型检测正确 - A股数据获取正常(沪深300: 5条)
This commit is contained in:
@@ -4,16 +4,31 @@
|
||||
核心数据获取能力:
|
||||
- A股数据:Tushare(指数、ETF、期货)
|
||||
- 境外数据:YFinance(港股、美股)通过SSH隧道
|
||||
|
||||
架构设计:
|
||||
- 分层架构:对外统一接口,对内各资产类型独立实现
|
||||
- Flask API:LRU + TTL 双缓存机制
|
||||
|
||||
用法:
|
||||
from datasource import UniversalDataFetcher, AssetType
|
||||
|
||||
fetcher = UniversalDataFetcher()
|
||||
df = fetcher.fetch("000300.SH", "2024-01-01", "2024-12-31")
|
||||
"""
|
||||
|
||||
from .ssh_tunnel import SSHTunnelManager
|
||||
from .tushare_source import TushareSource
|
||||
from .yfinance_source import YFinanceSource
|
||||
from .hybrid_source import HybridDataSource
|
||||
from .asset_type_detector import AssetTypeDetector, AssetType
|
||||
from .universal_fetcher import UniversalDataFetcher
|
||||
|
||||
__all__ = [
|
||||
'SSHTunnelManager',
|
||||
'TushareSource',
|
||||
'YFinanceSource',
|
||||
'HybridDataSource',
|
||||
'AssetTypeDetector',
|
||||
'AssetType',
|
||||
'UniversalDataFetcher',
|
||||
]
|
||||
219
datasource/asset_type_detector.py
Normal file
219
datasource/asset_type_detector.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""
|
||||
资产类型检测器
|
||||
|
||||
根据代码格式自动识别资产类型,支持:
|
||||
- A股指数/ETF/股票
|
||||
- 港股指数/股票
|
||||
- 美股指数/股票
|
||||
- 期货合约
|
||||
- 加密货币
|
||||
|
||||
用法:
|
||||
from datasource.asset_type_detector import AssetTypeDetector, AssetType
|
||||
|
||||
# 检测资产类型
|
||||
asset_type = AssetTypeDetector.detect("000300.SH") # AssetType.CHINA_INDEX
|
||||
|
||||
# 获取描述
|
||||
desc = AssetTypeDetector.get_description(asset_type) # "A股指数"
|
||||
|
||||
# 批量分组
|
||||
grouped = AssetTypeDetector.group_by_type(["000300.SH", "NDX", "AU.SHF"])
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
class AssetType(Enum):
    """Asset categories recognised by the data layer."""

    CHINA_INDEX = "china_index"    # A-share index
    CHINA_ETF = "china_etf"        # A-share ETF
    CHINA_STOCK = "china_stock"    # A-share stock
    HK_INDEX = "hk_index"          # Hong Kong index
    HK_STOCK = "hk_stock"          # Hong Kong stock
    US_INDEX = "us_index"          # US index
    US_STOCK = "us_stock"          # US stock
    FUTURES = "futures"            # futures contract
    CRYPTO = "crypto"              # cryptocurrency
    UNKNOWN = "unknown"            # could not be classified


class AssetTypeDetector:
    """
    Classify an instrument code into an AssetType from its spelling.

    Only the text of the code is inspected (known tickers, exchange suffix,
    numeric prefix); no data source is queried.
    """

    # A-share suffixes
    CHINA_SUFFIXES = ('.SH', '.SZ', '.SS', '.CSI')

    # Chinese futures exchange suffixes
    CHINA_FUTURES_SUFFIXES = ('.SHF', '.DCE', '.CZC', '.INE', '.GFEX')

    # Foreign futures exchange suffixes
    FOREIGN_FUTURES_SUFFIXES = ('.NYM', '.ICE', '.CME', '.CBT')

    # Hong Kong suffix
    HK_SUFFIXES = ('.HK',)

    # Bare tickers treated as cryptocurrencies
    CRYPTO_CODES = {'BTC', 'ETH', 'SOL', 'BNB', 'XRP', 'ADA', 'DOGE'}

    # Well-known index tickers that carry no exchange suffix.
    # NOTE(review): N225 and GDAXI are not US indices but are mapped to
    # US_INDEX here, presumably so they share the same fetch route - confirm.
    SPECIAL_INDEX_CODES = {
        'HSI': AssetType.HK_INDEX,    # Hang Seng Index
        'NDX': AssetType.US_INDEX,    # Nasdaq 100
        'SPX': AssetType.US_INDEX,    # S&P 500
        'DJI': AssetType.US_INDEX,    # Dow Jones
        'N225': AssetType.US_INDEX,   # Nikkei 225 (Japan)
        'GDAXI': AssetType.US_INDEX,  # DAX (Germany)
        'HSCCI': AssetType.HK_INDEX,  # Hang Seng China-Affiliated Corporations Index
        'HSCEI': AssetType.HK_INDEX,  # Hang Seng China Enterprises Index
    }

    # Code -> YFinance symbol; used by detect() to recognise index aliases.
    YFINANCE_INDEX_MAP = {
        "HSTECH.HK": "3033.HK",
        "HSI": "^HSI",
        "NDX": "^NDX",
        "SPX": "^GSPC",
        "DJI": "^DJI",
        "N225": "^N225",
        "GDAXI": "^GDAXI",
        "CL.NYM": "CL=F",
    }

    # Numeric prefixes of exchange-traded funds.
    CHINA_ETF_PREFIXES = ('51', '52', '56', '58', '15', '16')

    # Numeric prefixes treated as indices.
    CHINA_INDEX_PREFIXES = ('000', '001', '002', '399', '930', '931', '932')

    # On the Shenzhen exchange these prefixes are listed stocks
    # (e.g. 000001.SZ, 000002.SZ), whereas 000xxx.SH are Shanghai indices.
    SZ_STOCK_PREFIXES = ('000', '001', '002')

    # Human-readable label for each asset type (runtime strings, kept as-is).
    DESCRIPTIONS = {
        AssetType.CHINA_INDEX: "A股指数",
        AssetType.CHINA_ETF: "A股ETF",
        AssetType.CHINA_STOCK: "A股股票",
        AssetType.HK_INDEX: "港股指数",
        AssetType.HK_STOCK: "港股股票",
        AssetType.US_INDEX: "美股指数",
        AssetType.US_STOCK: "美股股票",
        AssetType.FUTURES: "期货合约",
        AssetType.CRYPTO: "加密货币",
        AssetType.UNKNOWN: "未知类型",
    }

    @classmethod
    def detect(cls, code: str) -> AssetType:
        """
        Detect the asset type of a code.

        The code is stripped and upper-cased, then matched in this order:
        crypto tickers, suffix-less index tickers, YFinance index aliases,
        futures suffixes, Hong Kong suffix, A-share suffixes. Anything else
        is assumed to be a US stock.

        Args:
            code: Instrument code, e.g. "000300.SH", "NDX", "AU.SHF".

        Returns:
            The matching AssetType; AssetType.UNKNOWN for an empty or
            non-string code.
        """
        if not isinstance(code, str):
            return AssetType.UNKNOWN
        code = code.strip().upper()
        if not code:
            # An empty code carries no information; do not guess "US stock".
            return AssetType.UNKNOWN

        # 1. Cryptocurrencies take precedence.
        if code in cls.CRYPTO_CODES:
            return AssetType.CRYPTO

        # 2. Suffix-less index tickers.
        if code in cls.SPECIAL_INDEX_CODES:
            return cls.SPECIAL_INDEX_CODES[code]

        # 3. YFinance aliases: "^..." symbols and "....HK" symbols are indices;
        #    other aliases (e.g. "CL=F") fall through to the suffix rules.
        yf_code = cls.YFINANCE_INDEX_MAP.get(code)
        if yf_code is not None:
            if yf_code.startswith('^'):
                return AssetType.US_INDEX
            if yf_code.endswith('.HK'):
                return AssetType.HK_INDEX

        # 4. Futures, domestic or foreign.
        if code.endswith(cls.CHINA_FUTURES_SUFFIXES + cls.FOREIGN_FUTURES_SUFFIXES):
            return AssetType.FUTURES

        # 5. Hong Kong codes: a four-digit body is treated as an index.
        # NOTE(review): four-digit .HK codes are also used for ordinary
        # listings; this heuristic is kept unchanged - confirm it is intended.
        if code.endswith(cls.HK_SUFFIXES):
            body = code.split('.')[0]
            if body.isdigit() and len(body) == 4:
                return AssetType.HK_INDEX
            return AssetType.HK_STOCK

        # 6. A-share codes need a closer look at the numeric prefix.
        if code.endswith(cls.CHINA_SUFFIXES):
            return cls._classify_china_asset(code)

        # 7. Default: US stock.
        return AssetType.US_STOCK

    @classmethod
    def _classify_china_asset(cls, code: str) -> AssetType:
        """
        Split an A-share code into index / ETF / stock.

        Rules:
            - ".CSI" suffix: index.
            - Body that is not six digits: stock.
            - ETF prefixes (51, 52, 56, 58, 15, 16): ETF.
            - ".SZ" codes starting with 000/001/002: stock (these are
              Shenzhen-listed shares, e.g. 000001.SZ).
            - Index prefixes (000, 001, 002, 399, 930, 931, 932): index,
              so 000001.SH is the index rather than a stock.
            - Everything else: stock.

        Args:
            code: Upper-cased A-share code including its suffix.

        Returns:
            CHINA_INDEX, CHINA_ETF or CHINA_STOCK.
        """
        if code.endswith('.CSI'):
            return AssetType.CHINA_INDEX

        body = code.split('.')[0]
        if not (body.isdigit() and len(body) == 6):
            return AssetType.CHINA_STOCK

        if body.startswith(cls.CHINA_ETF_PREFIXES):
            return AssetType.CHINA_ETF

        # The same numeric range means different things on the two exchanges,
        # so the suffix has to be consulted before the index prefixes.
        if code.endswith('.SZ') and body.startswith(cls.SZ_STOCK_PREFIXES):
            return AssetType.CHINA_STOCK

        if body.startswith(cls.CHINA_INDEX_PREFIXES):
            return AssetType.CHINA_INDEX

        return AssetType.CHINA_STOCK

    @classmethod
    def get_description(cls, asset_type: AssetType) -> str:
        """Return the display label for an asset type ("未知类型" if unmapped)."""
        return cls.DESCRIPTIONS.get(asset_type, "未知类型")

    @classmethod
    def group_by_type(cls, codes: List[str]) -> Dict[AssetType, List[str]]:
        """
        Group codes by detected asset type.

        Args:
            codes: Codes to classify; input order is kept within each group.

        Returns:
            Mapping {AssetType: [codes]} containing only the types that occur.
        """
        grouped: Dict[AssetType, List[str]] = {}
        for code in codes:
            grouped.setdefault(cls.detect(code), []).append(code)
        return grouped
|
||||
588
datasource/flask_server.py
Normal file
588
datasource/flask_server.py
Normal file
@@ -0,0 +1,588 @@
|
||||
"""
|
||||
Flask 数据服务 API
|
||||
==================
|
||||
提供 RESTful API 接口,支持获取各类资产的 K 线数据
|
||||
|
||||
特性:
|
||||
- 分层架构:各资产类型独立实现
|
||||
- LRU + TTL 双缓存机制
|
||||
- SSH隧道支持(港美股)
|
||||
- ETF净值获取(计算溢价率)
|
||||
|
||||
运行:
|
||||
python datasource/flask_server.py
|
||||
|
||||
API 文档:
|
||||
GET / - 服务信息
|
||||
GET /health - 健康检查
|
||||
GET /api/v1/asset-type - 检测资产类型
|
||||
GET /api/v1/ohlcv - 获取K线数据
|
||||
POST /api/v1/ohlcv/batch - 批量获取K线数据
|
||||
GET /api/v1/etf/nav - 获取ETF净值
|
||||
POST /api/v1/cache/clear - 清理缓存
|
||||
GET /api/v1/cache/stats - 缓存统计
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any, List, Tuple
|
||||
from functools import lru_cache
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
from flask import Flask, request, jsonify
|
||||
from flask_cors import CORS
|
||||
import pandas as pd
|
||||
|
||||
from datasource.universal_fetcher import UniversalDataFetcher
|
||||
from datasource.asset_type_detector import AssetTypeDetector, AssetType
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Flask 应用配置
|
||||
# ============================================================
|
||||
|
||||
app = Flask(__name__)
CORS(app)  # allow cross-origin requests

# Process-wide fetcher and SSH settings; populated lazily by get_fetcher().
fetcher: Optional[UniversalDataFetcher] = None
ssh_config: Optional[Dict] = None

# Cache tuning, overridable through the environment.
CACHE_MAXSIZE = int(os.environ.get('CACHE_MAXSIZE', '128'))
CACHE_TTL_SECONDS = int(os.environ.get('CACHE_TTL_SECONDS', '7200'))  # two hours by default
|
||||
|
||||
|
||||
class TimedCacheEntry:
    """A cached payload stamped with the local time at which it was stored."""

    def __init__(self, data: Any):
        self.timestamp = datetime.now()
        self.data = data

    def is_expired(self) -> bool:
        """Return True once the entry is older than CACHE_TTL_SECONDS."""
        age_seconds = (datetime.now() - self.timestamp).total_seconds()
        return age_seconds > CACHE_TTL_SECONDS


# TTL cache storage, keyed by the request tuple.
_ttl_cache: Dict[Tuple, TimedCacheEntry] = {}
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 初始化
|
||||
# ============================================================
|
||||
|
||||
def get_ssh_config() -> Optional[Dict]:
    """
    Read the SSH tunnel settings from environment variables.

    Returns:
        None unless SSH_ENABLED equals "true" (case-insensitive); otherwise a
        dict with enabled/host/port/username/key_path/local_port, using the
        defaults below for variables that are not set.
    """
    env = os.environ
    if env.get('SSH_ENABLED', 'false').lower() != 'true':
        return None

    return dict(
        enabled=True,
        host=env.get('SSH_HOST', ''),
        port=int(env.get('SSH_PORT', '22')),
        username=env.get('SSH_USERNAME', ''),
        key_path=env.get('SSH_KEY_PATH', 'hk_ecs.pem'),
        local_port=int(env.get('SSH_LOCAL_PORT', '1080')),
    )
|
||||
|
||||
|
||||
def get_fetcher() -> UniversalDataFetcher:
    """
    Return the shared UniversalDataFetcher, creating it on first use.

    On creation the module-level ``ssh_config`` is refreshed from the
    environment and handed to the new fetcher.
    """
    global fetcher, ssh_config

    if fetcher is not None:
        return fetcher

    ssh_config = get_ssh_config()
    fetcher = UniversalDataFetcher(ssh_config=ssh_config)
    return fetcher
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 缓存机制
|
||||
# ============================================================
|
||||
|
||||
@lru_cache(maxsize=CACHE_MAXSIZE)
def _fetch_data_cached(code: str, start: str, end: str) -> Optional[str]:
    """
    Fetch one instrument and return it as a JSON string (memoized).

    Returns:
        - None when the fetcher yields no rows,
        - a JSON object with data/count/columns/date_range plus "code" and
          "asset_type" on success,
        - a JSON object {"error": "..."} when fetching or conversion raises.

    NOTE: because of ``lru_cache`` every one of these outcomes, including
    None and the error object, is memoized for the given arguments.
    """
    source = get_fetcher()

    try:
        # The fetcher is used as a context manager around the fetch call.
        with source:
            frame = source.fetch(code, start, end)

        if frame is None or len(frame) == 0:
            return None

        payload = dataframe_to_json(frame)
        payload.update(code=code, asset_type=AssetTypeDetector.detect(code).value)
        return json.dumps(payload)
    except Exception as exc:
        return json.dumps({"error": str(exc)})
|
||||
|
||||
|
||||
def fetch_data_with_ttl(
    code: str,
    start: str,
    end: str,
    nocache: bool = False
) -> Tuple[Optional[Dict], bool]:
    """
    Fetch data through the TTL cache.

    The TTL cache is authoritative for freshness. The ``lru_cache`` on
    ``_fetch_data_cached`` never forgets a key by itself, so whenever a
    refresh is required (``nocache`` or an expired TTL entry) the memoized
    copy is bypassed through ``__wrapped__``; otherwise an expired entry would
    simply be replaced by the same old payload.

    Every outcome of a fetch (data, error object, or "no data") is recorded in
    the TTL cache so that it is retried after the TTL instead of being
    memoized indefinitely.

    Args:
        code: Instrument code.
        start: Start date.
        end: End date.
        nocache: Force a fresh fetch and overwrite the cached entry.

    Returns:
        (data, is_cached): ``data`` is a fresh shallow copy of the payload
        dict (so callers may add keys without touching the cache) or None when
        no data is available; ``is_cached`` tells whether it came from the TTL
        cache.
    """
    cache_key = (code, start, end)
    entry = _ttl_cache.get(cache_key)

    if not nocache and entry is not None and not entry.is_expired():
        cached = entry.data
        return (dict(cached) if cached is not None else None), True

    # An existing-but-expired entry means the memoized copy is at least as
    # old, so it must not be consulted. The expired entry is left in place
    # until it is overwritten below.
    must_refresh = nocache or entry is not None
    if must_refresh:
        result_json = _fetch_data_cached.__wrapped__(code, start, end)
    else:
        result_json = _fetch_data_cached(code, start, end)

    result = json.loads(result_json) if result_json is not None else None

    # NOTE(review): _ttl_cache has no size bound; entries are only replaced,
    # never evicted, until clear_cache() is called.
    _ttl_cache[cache_key] = TimedCacheEntry(result)

    return (dict(result) if result is not None else None), False
|
||||
|
||||
|
||||
def clear_cache():
    """Empty both cache layers: the TTL store and the memoized fetch function."""
    _ttl_cache.clear()
    _fetch_data_cached.cache_clear()
|
||||
|
||||
|
||||
def get_cache_info() -> Dict:
    """
    Report cache statistics.

    Returns:
        Dict with the lru_cache counters (hits, misses, maxsize, currsize),
        the number of TTL entries, and the configured TTL in seconds.
    """
    lru = _fetch_data_cached.cache_info()
    lru_stats = {
        field: getattr(lru, field)
        for field in ("hits", "misses", "maxsize", "currsize")
    }
    return {
        "lru_cache": lru_stats,
        "ttl_cache_size": len(_ttl_cache),
        "ttl_seconds": CACHE_TTL_SECONDS,
    }
|
||||
|
||||
|
||||
# ============================================================
|
||||
# DataFrame 转换
|
||||
# ============================================================
|
||||
|
||||
def dataframe_to_json(df: pd.DataFrame) -> Dict:
    """
    Convert a DataFrame into a JSON-serialisable dict.

    The first column among date/Date/index/trade_date/datetime that parses as
    dates is formatted as YYYY-MM-DD and exposed under the name "date". The
    reported date range is taken from that column, so the frame does not need
    a datetime index. NaN/NaT cells become None because NaN is not valid JSON.

    Args:
        df: Frame to convert; None or an empty frame yields an empty result.

    Returns:
        {"data": [row dicts], "count": n, "columns": [...],
         "date_range": {"start": ..., "end": ...}}; for None/empty input only
        {"data": [], "count": 0}.
    """
    if df is None or len(df) == 0:
        return {"data": [], "count": 0}

    # A default RangeIndex carries no information; keeping it would add an
    # integer "index" column that could be mistaken for dates below.
    df_reset = df.reset_index(drop=isinstance(df.index, pd.RangeIndex))

    date_strings = None
    for col in ('date', 'Date', 'index', 'trade_date', 'datetime'):
        if col not in df_reset.columns:
            continue
        try:
            formatted = pd.to_datetime(df_reset[col]).dt.strftime('%Y-%m-%d')
        except (ValueError, TypeError, AttributeError, OverflowError):
            # Not a date column after all; try the next candidate.
            continue
        df_reset[col] = formatted
        if col != 'date':
            df_reset = df_reset.rename(columns={col: 'date'})
        date_strings = formatted
        break

    if date_strings is None and isinstance(df.index, pd.DatetimeIndex):
        # No recognised date column, but the index itself holds the dates.
        date_strings = pd.Series(df.index.strftime('%Y-%m-%d'))

    def _clean(value):
        # NaN is the only float that differs from itself.
        if value is pd.NaT or (isinstance(value, float) and value != value):
            return None
        return value

    records = [
        {key: _clean(value) for key, value in row.items()}
        for row in df_reset.to_dict(orient='records')
    ]

    range_start = range_end = None
    if date_strings is not None:
        valid = date_strings.dropna()
        if len(valid) > 0:
            # ISO dates sort chronologically as plain strings.
            range_start, range_end = valid.min(), valid.max()

    return {
        "data": records,
        "count": len(records),
        "columns": list(df_reset.columns),
        "date_range": {
            "start": range_start,
            "end": range_end,
        }
    }
|
||||
|
||||
|
||||
def validate_date(date_str: str) -> bool:
    """
    Check that a value is a calendar date in YYYY-MM-DD form.

    Args:
        date_str: Value to check.

    Returns:
        True when it parses with '%Y-%m-%d'; False otherwise, including for
        non-string values such as None or numbers taken from a JSON body.
    """
    try:
        datetime.strptime(date_str, '%Y-%m-%d')
    except (ValueError, TypeError):
        # TypeError: strptime() only accepts str.
        return False
    return True
|
||||
|
||||
|
||||
def get_default_dates() -> Tuple[str, str]:
    """Return (start, end) as YYYY-MM-DD strings: 90 days ago and today."""
    today = datetime.now()
    window_start = today - timedelta(days=90)
    fmt = '%Y-%m-%d'
    return window_start.strftime(fmt), today.strftime(fmt)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# API 路由
|
||||
# ============================================================
|
||||
|
||||
@app.route('/')
def index():
    """Landing page: service metadata, endpoint list, examples and cache/SSH status."""
    features = [
        "分层架构(各资产类型独立实现)",
        "LRU + TTL 双缓存机制",
        "SSH隧道支持(港美股)",
        "ETF净值获取(计算溢价率)",
    ]
    endpoints = {
        "info": "/",
        "health": "/health",
        "asset_type": "/api/v1/asset-type?code={code}",
        "ohlcv": "/api/v1/ohlcv?code={code}&start={YYYY-MM-DD}&end={YYYY-MM-DD}",
        "ohlcv_nocache": "/api/v1/ohlcv?code={code}&nocache=true",
        "batch": "POST /api/v1/ohlcv/batch",
        "etf_nav": "/api/v1/etf/nav?code={code}",
        "cache_clear": "POST /api/v1/cache/clear",
        "cache_stats": "/api/v1/cache/stats",
    }
    supported_assets = {
        "china_index": ["000300.SH", "399006.SZ", "H30269.CSI"],
        "china_etf": ["159915.SZ", "513100.SH", "518880.SH"],
        "hk_index": ["HSI", "HSTECH.HK"],
        "us_index": ["NDX", "SPX", "N225", "GDAXI"],
        "futures": ["AU.SHF", "CU.SHF", "CL.NYM"],
        "crypto": ["BTC", "ETH"],
    }
    ssh_on = bool(ssh_config and ssh_config.get('enabled'))

    return jsonify({
        "name": "Universal Data Fetcher API",
        "version": "2.0.0",
        "description": "统一数据获取服务(分层架构)",
        "architecture": "Unified entry + Asset-specific methods",
        "features": features,
        "endpoints": endpoints,
        "supported_assets": supported_assets,
        "cache_config": get_cache_info(),
        "ssh_status": "enabled" if ssh_on else "disabled",
    })
|
||||
|
||||
|
||||
@app.route('/health')
def health():
    """Health check: status, current timestamp and whether SSH is configured."""
    ssh_configured = ssh_config is not None and ssh_config.get('enabled', False)
    payload = {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "ssh_configured": ssh_configured,
    }
    return jsonify(payload)
|
||||
|
||||
|
||||
@app.route('/api/v1/asset-type')
def detect_asset_type():
    """
    Report the detected asset type of a code.

    Query Parameters:
        code: Instrument code (required).

    Responds 400 when ``code`` is missing or blank.
    """
    code = request.args.get('code', '').strip()
    if code:
        detected = AssetTypeDetector.detect(code)
        return jsonify({
            "code": code,
            "asset_type": detected.value,
            "description": AssetTypeDetector.get_description(detected),
        })

    return jsonify({
        "error": "Missing required parameter: code",
        "example": "/api/v1/asset-type?code=000300.SH"
    }), 400
|
||||
|
||||
|
||||
@app.route('/api/v1/ohlcv')
def get_ohlcv():
    """
    Return OHLCV data for a single instrument.

    Query Parameters:
        code: Instrument code (required).
        start: Start date YYYY-MM-DD (optional, defaults to 90 days ago).
        end: End date YYYY-MM-DD (optional, defaults to today).
        nocache: "true" to bypass the cache (optional, default false).

    ``start`` and ``end`` default independently, so supplying only one of
    them keeps the supplied value.

    Responses:
        200 with the data payload plus "cached";
        400 for a missing code, malformed date or start after end;
        404 when no data is available;
        500 when the fetch reported an error.
    """
    code = request.args.get('code', '').strip()
    start = request.args.get('start', '').strip()
    end = request.args.get('end', '').strip()
    nocache = request.args.get('nocache', 'false').lower() == 'true'

    # Parameter validation
    if not code:
        return jsonify({
            "error": "Missing required parameter: code",
            "example": "/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31"
        }), 400

    # Fill in only the bound that is actually missing.
    default_start, default_end = get_default_dates()
    start = start or default_start
    end = end or default_end

    if not validate_date(start) or not validate_date(end):
        return jsonify({
            "error": "Invalid date format. Use YYYY-MM-DD",
            "start": start,
            "end": end,
        }), 400

    # strptime also accepts unpadded input such as 2024-1-5; normalise it so
    # that comparison, cache keys and the data sources see one spelling.
    start_dt = datetime.strptime(start, '%Y-%m-%d')
    end_dt = datetime.strptime(end, '%Y-%m-%d')
    start = start_dt.strftime('%Y-%m-%d')
    end = end_dt.strftime('%Y-%m-%d')

    if start_dt > end_dt:
        return jsonify({
            "error": "Invalid date range: start must not be after end",
            "start": start,
            "end": end,
        }), 400

    result, is_cached = fetch_data_with_ttl(code, start, end, nocache)

    if result is None:
        return jsonify({
            "code": code,
            "asset_type": AssetTypeDetector.detect(code).value,
            "error": "No data available",
            "start": start,
            "end": end,
        }), 404

    if "error" in result:
        return jsonify({
            "code": code,
            "asset_type": AssetTypeDetector.detect(code).value,
            "error": result["error"],
        }), 500

    # Build the response from a copy so the cached payload is not modified.
    response = dict(result)
    response['cached'] = is_cached
    return jsonify(response)
|
||||
|
||||
|
||||
@app.route('/api/v1/ohlcv/batch', methods=['POST'])
def batch_ohlcv():
    """
    Return OHLCV data for several instruments.

    JSON body:
        codes: list of instrument codes (required).
        start: Start date YYYY-MM-DD (optional, defaults to 90 days ago).
        end: End date YYYY-MM-DD (optional, defaults to today).

    A code that is not a non-empty string is reported as a failed item
    instead of failing the whole batch.

    Responses:
        200 with per-code results and success/failure counts;
        400 for a missing/invalid body, codes list or dates;
        500 when the batch fails unexpectedly.
    """
    # silent=True: a missing, non-JSON or malformed body yields None, so the
    # caller receives the JSON 400 below with a usage example.
    data = request.get_json(silent=True)

    if not data or not isinstance(data, dict):
        return jsonify({
            "error": "Missing request body",
            "example": {
                "codes": ["000300.SH", "NDX"],
                "start": "2024-01-01",
                "end": "2024-03-31",
            }
        }), 400

    codes = data.get('codes', [])
    # JSON may carry null or numbers here; coerce before stripping.
    start = str(data.get('start') or '').strip()
    end = str(data.get('end') or '').strip()

    if not codes or not isinstance(codes, list):
        return jsonify({
            "error": "Missing or invalid parameter: codes (must be a list)"
        }), 400

    # Fill in only the bound that is actually missing.
    default_start, default_end = get_default_dates()
    start = start or default_start
    end = end or default_end

    if not validate_date(start) or not validate_date(end):
        return jsonify({
            "error": "Invalid date format. Use YYYY-MM-DD",
        }), 400

    results = {}
    success_count = 0
    failed_count = 0

    try:
        # No outer fetcher context here: the cached fetch path enters the
        # fetcher context itself for every real fetch, and wrapping the loop
        # would start the tunnel even when every code is served from cache.
        for code in codes:
            if not isinstance(code, str) or not code.strip():
                results[str(code)] = {
                    "code": code,
                    "asset_type": AssetType.UNKNOWN.value,
                    "error": "Invalid code",
                    "data": [],
                    "count": 0,
                }
                failed_count += 1
                continue

            result, _ = fetch_data_with_ttl(code.strip(), start, end)

            if result is not None and "error" not in result:
                results[code] = result
                success_count += 1
            else:
                results[code] = {
                    "code": code,
                    "asset_type": AssetTypeDetector.detect(code).value,
                    "error": result.get("error", "No data") if result else "No data",
                    "data": [],
                    "count": 0,
                }
                failed_count += 1
    except Exception as e:
        return jsonify({"error": f"Batch fetch failed: {str(e)}"}), 500

    return jsonify({
        "results": results,
        "success_count": success_count,
        "failed_count": failed_count,
        "total": len(codes),
        "start": start,
        "end": end,
    })
|
||||
|
||||
|
||||
@app.route('/api/v1/etf/nav')
def get_etf_nav():
    """
    Return ETF price and NAV series plus the latest premium rate.

    Query Parameters:
        code: ETF code (required, must be detected as an A-share ETF).
        start: Start date YYYY-MM-DD (optional, defaults to 90 days ago).
        end: End date YYYY-MM-DD (optional, defaults to today).

    Responses:
        200 with "price" and "nav" payloads and, when both series have rows
        and the last NAV is positive, "premium_rate" / "premium_date";
        400 for a missing code, malformed date or non-ETF code;
        500 when fetching raises.
    """
    code = request.args.get('code', '').strip()
    start = request.args.get('start', '').strip()
    end = request.args.get('end', '').strip()

    if not code:
        return jsonify({
            "error": "Missing required parameter: code",
            "example": "/api/v1/etf/nav?code=513100.SH"
        }), 400

    # Fill in only the bound that is actually missing.
    default_start, default_end = get_default_dates()
    start = start or default_start
    end = end or default_end

    if not validate_date(start) or not validate_date(end):
        return jsonify({
            "error": "Invalid date format. Use YYYY-MM-DD",
            "start": start,
            "end": end,
        }), 400

    # Only A-share ETFs have a NAV series here.
    asset_type = AssetTypeDetector.detect(code)
    if asset_type != AssetType.CHINA_ETF:
        return jsonify({
            "error": f"Not an ETF: {code} (type: {asset_type.value})",
            "hint": "Only A股ETF (codes starting with 51/52/56/58/15/16) supported",
        }), 400

    f = get_fetcher()
    try:
        with f:
            price_df, nav_df = f.fetch_etf_with_nav(code, start, end)

        # A DataFrame has no truth value ("if df" raises), so test explicitly.
        has_price = price_df is not None and len(price_df) > 0
        has_nav = nav_df is not None and len(nav_df) > 0

        result = {
            "code": code,
            "price": dataframe_to_json(price_df) if has_price else {"data": [], "count": 0},
            "nav": dataframe_to_json(nav_df) if has_nav else {"data": [], "count": 0},
        }

        # Latest premium rate: (price - NAV) / NAV using the last row of each.
        # NOTE(review): the two last rows are not checked to be the same day.
        if has_price and has_nav:
            latest_nav = nav_df['nav'].iloc[-1]
            latest_price = price_df['close'].iloc[-1]
            if latest_nav > 0:
                result['premium_rate'] = float((latest_price - latest_nav) / latest_nav)
                last_label = nav_df.index[-1]
                # The index may or may not hold datetimes.
                result['premium_date'] = (
                    last_label.strftime('%Y-%m-%d')
                    if hasattr(last_label, 'strftime')
                    else str(last_label)
                )

        return jsonify(result)

    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
||||
|
||||
|
||||
@app.route('/api/v1/cache/clear', methods=['POST'])
def clear_cache_endpoint():
    """Clear all caches and report the statistics before and after."""
    before = get_cache_info()
    clear_cache()
    after = get_cache_info()
    return jsonify({
        "message": "Cache cleared successfully",
        "before": before,
        "after": after
    })
|
||||
|
||||
|
||||
@app.route('/api/v1/cache/stats')
def cache_stats():
    """Return the current cache statistics as JSON."""
    stats = get_cache_info()
    return jsonify(stats)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 错误处理
|
||||
# ============================================================
|
||||
|
||||
@app.errorhandler(404)
def not_found(error):
    """JSON 404 handler that lists the available endpoints."""
    available = [
        "/", "/health",
        "/api/v1/asset-type",
        "/api/v1/ohlcv",
        "/api/v1/ohlcv/batch",
        "/api/v1/etf/nav",
        "/api/v1/cache/clear",
        "/api/v1/cache/stats",
    ]
    body = {
        "error": "Endpoint not found",
        "available_endpoints": available
    }
    return jsonify(body), 404
|
||||
|
||||
|
||||
@app.errorhandler(500)
def internal_error(error):
    """JSON 500 handler that echoes the error text in "message"."""
    body = {
        "error": "Internal server error",
        "message": str(error)
    }
    return jsonify(body), 500
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 启动服务
|
||||
# ============================================================
|
||||
|
||||
if __name__ == '__main__':
    import argparse

    # Command-line options for the development server.
    cli = argparse.ArgumentParser(description='Universal Data Fetcher API Server')
    cli.add_argument('--host', default='0.0.0.0', help='Host to bind')
    cli.add_argument('--port', type=int, default=5000, help='Port to bind')
    cli.add_argument('--debug', action='store_true', help='Enable debug mode')
    args = cli.parse_args()

    # Load the SSH settings up front so the status can be shown at startup.
    ssh_config = get_ssh_config()
    if ssh_config and ssh_config.get('enabled'):
        print(f"✓ SSH 隧道已配置: {ssh_config['host']}:{ssh_config['port']}")
    else:
        print("✗ SSH 隧道未启用(仅支持A股数据)")

    base_url = f"http://{args.host}:{args.port}"
    banner = [
        "\n🚀 Universal Data Fetcher API Server v2.0",
        f" Host: {args.host}",
        f" Port: {args.port}",
        f" Cache: LRU({CACHE_MAXSIZE}) + TTL({CACHE_TTL_SECONDS}s)",
        f"\n📖 API: {base_url}/",
        f" 健康检查: {base_url}/health",
    ]
    for line in banner:
        print(line)

    app.run(host=args.host, port=args.port, debug=args.debug)
|
||||
322
datasource/universal_fetcher.py
Normal file
322
datasource/universal_fetcher.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""
|
||||
统一数据获取器
|
||||
|
||||
分层架构:对外统一接口,对内按资产类型独立实现
|
||||
支持:A股指数/ETF、港股指数、美股指数、期货、加密货币
|
||||
|
||||
用法:
|
||||
from datasource import UniversalDataFetcher
|
||||
|
||||
fetcher = UniversalDataFetcher()
|
||||
|
||||
# 单标的获取(自动识别类型)
|
||||
df = fetcher.fetch("000300.SH", "2024-01-01", "2024-12-31")
|
||||
|
||||
# ETF获取(含净值)
|
||||
price_df, nav_df = fetcher.fetch_etf_with_nav("513100.SH", "2024-01-01", "2024-12-31")
|
||||
|
||||
# 批量获取
|
||||
results = fetcher.fetch_batch(["000300.SH", "NDX", "N225"], "2024-01-01", "2024-12-31")
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import Optional, Dict, List, Tuple
|
||||
from datetime import datetime
|
||||
import pandas as pd
|
||||
|
||||
from .tushare_source import TushareSource
|
||||
from .yfinance_source import YFinanceSource
|
||||
from .ssh_tunnel import SSHTunnelManager
|
||||
from .asset_type_detector import AssetTypeDetector, AssetType
|
||||
|
||||
|
||||
class UniversalDataFetcher:
    """
    Unified data fetcher.

    Layered design:
        - outward: a single fetch() entry point that routes by asset type;
        - inward: one small method per asset type.

    The object is a context manager: entering starts the SSH tunnel (when
    enabled in ``ssh_config``) and leaving stops it.
    """

    # Asset type -> name of the method that fetches it. Types missing from
    # this table have no fetch route.
    _ROUTES = {
        AssetType.CHINA_INDEX: '_fetch_china_index',
        AssetType.CHINA_ETF: '_fetch_china_etf',
        AssetType.US_INDEX: '_fetch_us_index',
        AssetType.HK_INDEX: '_fetch_hk_index',
        AssetType.FUTURES: '_fetch_futures',
        AssetType.CRYPTO: '_fetch_crypto',
    }

    # Routed types whose fetch method is only a placeholder.
    _PLACEHOLDER_TYPES = frozenset({AssetType.CRYPTO})

    def __init__(
        self,
        ssh_config: Optional[Dict] = None,
        use_cache: bool = True,
        cache_dir: str = "data/etf_cache/daily"
    ):
        """
        Initialise the fetcher.

        Args:
            ssh_config: SSH tunnel settings (used for HK/US data); None
                means no tunnel.
            use_cache: Whether to use the local cache (stored only).
            cache_dir: Cache directory (stored only).
        """
        self.ssh_config = ssh_config or {}
        self.use_cache = use_cache
        self.cache_dir = cache_dir

        # Data source instances
        self._tushare = TushareSource()
        self._yfinance = YFinanceSource()

        # SSH tunnel, created lazily by _start_tunnel()
        self._tunnel: Optional[SSHTunnelManager] = None
        self._tunnel_started = False

    def __enter__(self):
        """Start the tunnel (if enabled) and return self."""
        self._start_tunnel()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the tunnel; exceptions are not suppressed."""
        self._stop_tunnel()

    # ============================================================
    # SSH tunnel management
    # ============================================================

    def _start_tunnel(self) -> bool:
        """
        Start the SSH tunnel if it is enabled and not yet running.

        Returns:
            True when the tunnel is running or not required, False when
            starting it failed.
        """
        if self._tunnel_started:
            return True

        if not self.ssh_config.get('enabled'):
            return True

        self._tunnel = SSHTunnelManager(self.ssh_config)
        if self._tunnel.start():
            self._tunnel_started = True
            return True
        return False

    def _stop_tunnel(self):
        """Stop and discard the SSH tunnel, if one was created."""
        if self._tunnel:
            self._tunnel.stop()
            self._tunnel = None
            self._tunnel_started = False

    # ============================================================
    # Unified entry point (automatic routing)
    # ============================================================

    def fetch(
        self,
        code: str,
        start_date: str,
        end_date: str,
        retry: int = 3
    ) -> Optional[pd.DataFrame]:
        """
        Fetch data for one code, routing by its detected asset type.

        Args:
            code: Instrument code.
            start_date: Start date 'YYYY-MM-DD'.
            end_date: End date 'YYYY-MM-DD'.
            retry: Number of attempts; values below 1 are treated as 1.

        Returns:
            Whatever the routed data source returns (a DataFrame or None);
            None when the type has no route or every attempt raised.
        """
        asset_type = AssetTypeDetector.detect(code)

        handler_name = self._ROUTES.get(asset_type)
        if handler_name is None:
            print(f"⚠️ 未知资产类型: {code} -> {asset_type}")
            return None
        handler = getattr(self, handler_name)

        # Always make at least one attempt, even for retry <= 0.
        attempts = max(1, retry)
        for attempt in range(attempts):
            try:
                return handler(code, start_date, end_date)
            except Exception as e:
                if attempt < attempts - 1:
                    # Short pause before the next attempt.
                    time.sleep(2)
                else:
                    print(f"✗ 获取 {code} 失败 (尝试 {attempt+1}/{attempts}): {e}")

        return None

    # ============================================================
    # Per-asset-type implementations
    # ============================================================

    def _fetch_china_index(
        self,
        code: str,
        start_date: str,
        end_date: str
    ) -> Optional[pd.DataFrame]:
        """Fetch an A-share index from Tushare (no tunnel involved)."""
        return self._tushare.fetch_index(code, start_date, end_date)

    def _fetch_china_etf(
        self,
        code: str,
        start_date: str,
        end_date: str
    ) -> Optional[pd.DataFrame]:
        """Fetch A-share ETF prices from Tushare."""
        return self._tushare.fetch_etf(code, start_date, end_date)

    def fetch_etf_with_nav(
        self,
        code: str,
        start_date: str,
        end_date: str
    ) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame]]:
        """
        Fetch ETF prices together with the NAV series.

        Intended for computing the premium rate.

        Args:
            code: ETF code.
            start_date: Start date.
            end_date: End date.

        Returns:
            (price_df, nav_df) as returned by the Tushare source.
        """
        price_df = self._tushare.fetch_etf(code, start_date, end_date)
        nav_df = self._tushare.fetch_etf_nav(code, start_date, end_date)
        return price_df, nav_df

    def _fetch_us_index(
        self,
        code: str,
        start_date: str,
        end_date: str
    ) -> Optional[pd.DataFrame]:
        """Fetch a US index through YFinance, starting the tunnel first."""
        self._start_tunnel()
        return self._yfinance.fetch(code, start_date, end_date)

    def _fetch_hk_index(
        self,
        code: str,
        start_date: str,
        end_date: str
    ) -> Optional[pd.DataFrame]:
        """Fetch a Hong Kong index through YFinance, starting the tunnel first."""
        self._start_tunnel()
        return self._yfinance.fetch(code, start_date, end_date)

    def _fetch_futures(
        self,
        code: str,
        start_date: str,
        end_date: str
    ) -> Optional[pd.DataFrame]:
        """
        Fetch futures data.

        Codes with a foreign exchange suffix (as listed in
        AssetTypeDetector.FOREIGN_FUTURES_SUFFIXES) go to YFinance through
        the tunnel; all other futures go to Tushare. The suffix is matched
        case-insensitively, the same way the detector classifies the code.
        """
        normalised = code.strip().upper()
        if normalised.endswith(AssetTypeDetector.FOREIGN_FUTURES_SUFFIXES):
            self._start_tunnel()
            return self._yfinance.fetch(code, start_date, end_date)

        return self._tushare.fetch_futures(code, start_date, end_date)

    def _fetch_crypto(
        self,
        code: str,
        start_date: str,
        end_date: str
    ) -> Optional[pd.DataFrame]:
        """
        Placeholder for cryptocurrency data; always returns None.

        TODO: implement cryptocurrency fetching.
        """
        print(f"⚠️ 加密货币数据获取尚未实现: {code}")
        return None

    # ============================================================
    # Batch fetching
    # ============================================================

    def fetch_batch(
        self,
        codes: List[str],
        start_date: str,
        end_date: str,
        retry: int = 3
    ) -> Dict[str, Optional[pd.DataFrame]]:
        """
        Fetch several codes one after another.

        The tunnel is not started up front; the per-type methods that need
        it start it themselves, so batches that never need it do not open an
        SSH connection.

        Args:
            codes: Codes to fetch.
            start_date: Start date.
            end_date: End date.
            retry: Attempts per code.

        Returns:
            {code: DataFrame or None}
        """
        # Grouping is only used for the progress summary below.
        grouped = AssetTypeDetector.group_by_type(codes)

        print(f"开始获取 {len(codes)} 只标的...")
        for asset_type, code_list in grouped.items():
            print(f" {asset_type.value}: {len(code_list)} 只")

        results = {}
        for code in codes:
            results[code] = self.fetch(code, start_date, end_date, retry)

        return results

    # ============================================================
    # Helpers
    # ============================================================

    def get_asset_type(self, code: str) -> AssetType:
        """Return the detected asset type of a code."""
        return AssetTypeDetector.detect(code)

    def is_supported(self, code: str) -> bool:
        """
        Tell whether fetch() has an implemented route for this code.

        Returns False for types without a route and for types whose fetch
        method is only a placeholder.
        """
        asset_type = AssetTypeDetector.detect(code)
        return asset_type in self._ROUTES and asset_type not in self._PLACEHOLDER_TYPES
|
||||
Reference in New Issue
Block a user