feat(strategy): rotation策略支持Flask API数据获取
- 新增 flask_api_source.py: Flask API远程数据源模块 - 修改 strategy.py: get_data() 支持通过Flask API获取数据 使用方式: strategy.get_data(use_flask_api=True) # 通过部署服务获取 strategy.get_data(use_flask_api=False) # 本地HybridDataSource 配置项: flask_api_url: 可在config.yaml中指定API地址
This commit is contained in:
293
datasource/flask_api_source.py
Normal file
293
datasource/flask_api_source.py
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
"""
|
||||||
|
Flask API 数据源
|
||||||
|
|
||||||
|
通过部署后的 Flask API 服务获取 OHLCV 数据
|
||||||
|
支持远程调用,无需本地 SSH 隧道
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import requests
|
||||||
|
import pandas as pd
|
||||||
|
from typing import Optional, Dict, List
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
class FlaskAPIDataSource:
|
||||||
|
"""
|
||||||
|
Flask API 数据源
|
||||||
|
|
||||||
|
通过 HTTP API 获取数据,无需本地配置 SSH 隧道
|
||||||
|
适用于远程调用或生产环境
|
||||||
|
|
||||||
|
用法:
|
||||||
|
source = FlaskAPIDataSource(base_url="https://k3s.tokenpluse.xyz")
|
||||||
|
df = source.fetch("000300.SH", "2024-01-01", "2024-12-31")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_url: str = None,
|
||||||
|
api_path: str = "/api/v1/ohlcv",
|
||||||
|
timeout: int = 120,
|
||||||
|
retries: int = 3
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_url: API 服务基础地址,默认从环境变量读取
|
||||||
|
api_path: API 路径
|
||||||
|
timeout: 请求超时时间(秒)
|
||||||
|
retries: 重试次数
|
||||||
|
"""
|
||||||
|
self.base_url = base_url or os.getenv(
|
||||||
|
'FLASK_API_URL',
|
||||||
|
'https://k3s.tokenpluse.xyz'
|
||||||
|
)
|
||||||
|
self.api_path = api_path
|
||||||
|
self.timeout = timeout
|
||||||
|
self.retries = retries
|
||||||
|
|
||||||
|
# 确保 base_url 不以 / 结尾
|
||||||
|
self.base_url = self.base_url.rstrip('/')
|
||||||
|
|
||||||
|
def fetch(
|
||||||
|
self,
|
||||||
|
code: str,
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
asset_type: str = None,
|
||||||
|
timeframe: str = '1d'
|
||||||
|
) -> Optional[pd.DataFrame]:
|
||||||
|
"""
|
||||||
|
获取单只标的 OHLCV 数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: 标的代码
|
||||||
|
start_date: 开始日期 YYYY-MM-DD
|
||||||
|
end_date: 结束日期 YYYY-MM-DD
|
||||||
|
asset_type: 资产类型(可选,用于覆盖自动检测)
|
||||||
|
timeframe: K线周期(加密货币需要)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame with columns: date, open, high, low, close, volume
|
||||||
|
"""
|
||||||
|
# 构建请求 URL
|
||||||
|
url = f"{self.base_url}{self.api_path}"
|
||||||
|
|
||||||
|
# 构建请求参数
|
||||||
|
params = {
|
||||||
|
'code': code,
|
||||||
|
'start': start_date,
|
||||||
|
'end': end_date,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 加密货币需要 timeframe 参数
|
||||||
|
if asset_type == 'crypto' or code.upper() in ['BTC', 'ETH']:
|
||||||
|
params['timeframe'] = timeframe
|
||||||
|
|
||||||
|
# 可选:强制指定 asset_type
|
||||||
|
if asset_type:
|
||||||
|
params['asset_type'] = asset_type
|
||||||
|
|
||||||
|
for attempt in range(self.retries):
|
||||||
|
try:
|
||||||
|
response = requests.get(
|
||||||
|
url,
|
||||||
|
params=params,
|
||||||
|
timeout=self.timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
if attempt < self.retries - 1:
|
||||||
|
continue
|
||||||
|
print(f"✗ API请求失败: {response.status_code} - {response.text[:100]}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# 检查错误
|
||||||
|
if 'error' in data:
|
||||||
|
print(f"✗ API返回错误: {data['error']}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 解析数据
|
||||||
|
records = data.get('data', [])
|
||||||
|
if not records:
|
||||||
|
print(f"⚠ {code}: 无数据返回")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 转换为 DataFrame
|
||||||
|
df = pd.DataFrame(records)
|
||||||
|
|
||||||
|
# 处理日期列
|
||||||
|
if 'date' in df.columns:
|
||||||
|
df['date'] = pd.to_datetime(df['date'])
|
||||||
|
df = df.set_index('date')
|
||||||
|
|
||||||
|
# 确保列名标准化
|
||||||
|
df = df[['open', 'high', 'low', 'close', 'volume']]
|
||||||
|
|
||||||
|
# 缓存 info 信息(如果有)
|
||||||
|
if 'info' in data:
|
||||||
|
df.attrs['info'] = data['info']
|
||||||
|
|
||||||
|
print(f"✓ {code}: {len(df)} 条数据 ({start_date} ~ {end_date})")
|
||||||
|
return df
|
||||||
|
|
||||||
|
except requests.exceptions.Timeout:
|
||||||
|
if attempt < self.retries - 1:
|
||||||
|
print(f"⚠ {code}: 请求超时,重试 {attempt + 2}/{self.retries}")
|
||||||
|
continue
|
||||||
|
print(f"✗ {code}: 请求超时")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
if attempt < self.retries - 1:
|
||||||
|
continue
|
||||||
|
print(f"✗ {code}: 请求异常 - {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
print(f"✗ {code}: JSON解析失败 - {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def fetch_batch(
|
||||||
|
self,
|
||||||
|
codes: List[str],
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
asset_types: Dict[str, str] = None
|
||||||
|
) -> Dict[str, Optional[pd.DataFrame]]:
|
||||||
|
"""
|
||||||
|
批量获取多只标的数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
codes: 标的代码列表
|
||||||
|
start_date: 开始日期
|
||||||
|
end_date: 结束日期
|
||||||
|
asset_types: 资产类型映射 {code: asset_type}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{code: DataFrame}
|
||||||
|
"""
|
||||||
|
results = {}
|
||||||
|
asset_types = asset_types or {}
|
||||||
|
|
||||||
|
print(f"从 Flask API 获取 {len(codes)} 只标的...")
|
||||||
|
|
||||||
|
for i, code in enumerate(codes, 1):
|
||||||
|
asset_type = asset_types.get(code)
|
||||||
|
df = self.fetch(code, start_date, end_date, asset_type)
|
||||||
|
results[code] = df
|
||||||
|
|
||||||
|
# 显示进度
|
||||||
|
if i % 5 == 0 or i == len(codes):
|
||||||
|
success = sum(1 for v in results.values() if v is not None)
|
||||||
|
print(f" 进度: {i}/{len(codes)} (成功: {success})")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def fetch_etf_nav(
|
||||||
|
self,
|
||||||
|
code: str,
|
||||||
|
start_date: str,
|
||||||
|
end_date: str
|
||||||
|
) -> Optional[pd.DataFrame]:
|
||||||
|
"""
|
||||||
|
获取 ETF 净值数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: ETF代码
|
||||||
|
start_date: 开始日期
|
||||||
|
end_date: 结束日期
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame with nav column
|
||||||
|
"""
|
||||||
|
url = f"{self.base_url}/api/v1/etf/nav"
|
||||||
|
|
||||||
|
params = {
|
||||||
|
'code': code,
|
||||||
|
'start': start_date,
|
||||||
|
'end': end_date
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.get(url, params=params, timeout=self.timeout)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
return None
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
if 'error' in data or 'nav' not in data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
nav_data = data.get('nav', {})
|
||||||
|
records = nav_data.get('data', [])
|
||||||
|
|
||||||
|
if not records:
|
||||||
|
return None
|
||||||
|
|
||||||
|
df = pd.DataFrame(records)
|
||||||
|
if 'date' in df.columns:
|
||||||
|
df['date'] = pd.to_datetime(df['date'])
|
||||||
|
df = df.set_index('date')
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ {code} 净值获取失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_health(self) -> Dict:
|
||||||
|
"""获取服务健康状态"""
|
||||||
|
# 先尝试 ohlcv 端点检查服务是否可用
|
||||||
|
url = f"{self.base_url}{self.api_path}"
|
||||||
|
params = {'code': '000300.SH', 'start': '2024-01-01', 'end': '2024-01-05'}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.get(url, params=params, timeout=self.timeout)
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
return {
|
||||||
|
'status': 'healthy',
|
||||||
|
'ssh_configured': True,
|
||||||
|
'available': True
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {'status': 'error', 'available': False}
|
||||||
|
except Exception as e:
|
||||||
|
return {'status': 'error', 'message': str(e), 'available': False}
|
||||||
|
|
||||||
|
def get_service_info(self) -> Dict:
|
||||||
|
"""获取服务信息"""
|
||||||
|
url = f"{self.base_url}/"
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.get(url, timeout=10)
|
||||||
|
return response.json()
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
# 全局实例
|
||||||
|
_flask_api_source: Optional[FlaskAPIDataSource] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_flask_api_source(base_url: str = None) -> FlaskAPIDataSource:
|
||||||
|
"""获取 Flask API 数据源实例"""
|
||||||
|
global _flask_api_source
|
||||||
|
|
||||||
|
if _flask_api_source is None:
|
||||||
|
_flask_api_source = FlaskAPIDataSource(base_url=base_url)
|
||||||
|
|
||||||
|
return _flask_api_source
|
||||||
@@ -105,8 +105,14 @@ class RotationStrategy(StrategyBase):
|
|||||||
group_mapping[code] = cfg.get('market', 'default')
|
group_mapping[code] = cfg.get('market', 'default')
|
||||||
return group_mapping
|
return group_mapping
|
||||||
|
|
||||||
def get_data(self) -> dict:
|
def get_data(self, use_flask_api: bool = True) -> dict:
|
||||||
"""获取数据(使用新数据源模块)"""
|
"""
|
||||||
|
获取数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_flask_api: 是否使用 Flask API 服务获取数据(默认 True)
|
||||||
|
False 则使用本地 HybridDataSource
|
||||||
|
"""
|
||||||
code_list_config = self.config.get('code_list', {})
|
code_list_config = self.config.get('code_list', {})
|
||||||
benchmark_config = self.config.get('benchmark', {})
|
benchmark_config = self.config.get('benchmark', {})
|
||||||
benchmark_code = benchmark_config.get('code', '000300.SH')
|
benchmark_code = benchmark_config.get('code', '000300.SH')
|
||||||
@@ -114,7 +120,123 @@ class RotationStrategy(StrategyBase):
|
|||||||
if not code_list_config:
|
if not code_list_config:
|
||||||
raise ValueError("配置中未找到 code_list")
|
raise ValueError("配置中未找到 code_list")
|
||||||
|
|
||||||
# 使用新数据源模块
|
# 获取 Flask API 地址
|
||||||
|
flask_api_url = self.config.get('flask_api_url')
|
||||||
|
|
||||||
|
if use_flask_api:
|
||||||
|
# 使用 Flask API 服务获取数据(远程调用)
|
||||||
|
return self._get_data_from_flask_api(
|
||||||
|
code_list_config,
|
||||||
|
benchmark_code,
|
||||||
|
flask_api_url
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 使用本地 HybridDataSource(需要本地 SSH 隧道)
|
||||||
|
return self._get_data_from_local(
|
||||||
|
code_list_config,
|
||||||
|
benchmark_code
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_data_from_flask_api(
|
||||||
|
self,
|
||||||
|
code_list_config: dict,
|
||||||
|
benchmark_code: str,
|
||||||
|
flask_api_url: str = None
|
||||||
|
) -> dict:
|
||||||
|
"""通过 Flask API 服务获取数据"""
|
||||||
|
from datasource.flask_api_source import FlaskAPIDataSource
|
||||||
|
|
||||||
|
# 初始化 Flask API 数据源
|
||||||
|
api_source = FlaskAPIDataSource(base_url=flask_api_url)
|
||||||
|
|
||||||
|
# 检查服务状态
|
||||||
|
health = api_source.get_health()
|
||||||
|
if health.get('status') != 'healthy':
|
||||||
|
print(f"⚠ Flask API 服务状态: {health}")
|
||||||
|
else:
|
||||||
|
print(f"✓ Flask API 服务正常 (SSH: {health.get('ssh_configured', False)})")
|
||||||
|
|
||||||
|
# 获取指数代码列表
|
||||||
|
index_codes = list(code_list_config.keys())
|
||||||
|
|
||||||
|
# 获取 ETF 代码映射
|
||||||
|
etf_code_map = {}
|
||||||
|
etf_codes = []
|
||||||
|
for index_code, cfg in code_list_config.items():
|
||||||
|
if isinstance(cfg, dict) and cfg.get('etf'):
|
||||||
|
etf_code_map[index_code] = cfg['etf']
|
||||||
|
etf_codes.append(cfg['etf'])
|
||||||
|
|
||||||
|
# 获取指数 OHLCV 数据
|
||||||
|
print(f"\n获取指数数据 ({len(index_codes)} 只)...")
|
||||||
|
index_ohlcv_data = api_source.fetch_batch(
|
||||||
|
index_codes,
|
||||||
|
self.start_date,
|
||||||
|
self.end_date
|
||||||
|
)
|
||||||
|
|
||||||
|
# 过滤有效代码
|
||||||
|
valid_codes = [code for code, df in index_ohlcv_data.items() if df is not None and len(df) > 0]
|
||||||
|
print(f"有效指数: {len(valid_codes)} 只")
|
||||||
|
|
||||||
|
# 获取 ETF 价格数据
|
||||||
|
print(f"\n获取 ETF 数据 ({len(etf_codes)} 只)...")
|
||||||
|
etf_ohlcv_data = api_source.fetch_batch(
|
||||||
|
etf_codes,
|
||||||
|
self.start_date,
|
||||||
|
self.end_date
|
||||||
|
)
|
||||||
|
|
||||||
|
# 转换为宽格式 DataFrame
|
||||||
|
etf_data = None
|
||||||
|
if etf_ohlcv_data:
|
||||||
|
etf_close_dict = {}
|
||||||
|
for etf_code, df in etf_ohlcv_data.items():
|
||||||
|
if df is not None and 'close' in df.columns:
|
||||||
|
etf_close_dict[etf_code] = df['close']
|
||||||
|
if etf_close_dict:
|
||||||
|
etf_data = pd.DataFrame(etf_close_dict)
|
||||||
|
|
||||||
|
# 获取基准数据
|
||||||
|
print(f"\n获取基准数据 ({benchmark_code})...")
|
||||||
|
benchmark_ohlcv = api_source.fetch(benchmark_code, self.start_date, self.end_date)
|
||||||
|
benchmark_data = None
|
||||||
|
if benchmark_ohlcv is not None:
|
||||||
|
benchmark_data = benchmark_ohlcv['close']
|
||||||
|
|
||||||
|
# 构建指数收盘价宽格式 DataFrame(用于因子计算)
|
||||||
|
index_close_dict = {}
|
||||||
|
for code in valid_codes:
|
||||||
|
df = index_ohlcv_data.get(code)
|
||||||
|
if df is not None and 'close' in df.columns:
|
||||||
|
index_close_dict[code] = df['close']
|
||||||
|
index_close = pd.DataFrame(index_close_dict) if index_close_dict else None
|
||||||
|
|
||||||
|
# 获取 ETF 净值数据(用于溢价率计算)
|
||||||
|
print(f"\n获取 ETF 净值数据...")
|
||||||
|
etf_nav_data = {}
|
||||||
|
for etf_code in etf_codes:
|
||||||
|
nav_df = api_source.fetch_etf_nav(etf_code, self.start_date, self.end_date)
|
||||||
|
if nav_df is not None:
|
||||||
|
etf_nav_data[etf_code] = nav_df
|
||||||
|
print(f"有效净值: {len(etf_nav_data)} 只")
|
||||||
|
|
||||||
|
return {
|
||||||
|
'index_data': index_ohlcv_data, # 原始 OHLCV 数据 {code: DataFrame}
|
||||||
|
'index_close': index_close, # 对齐后的收盘价(宽格式)
|
||||||
|
'etf_data': etf_data, # ETF 收盘价(宽格式)
|
||||||
|
'etf_nav_data': etf_nav_data, # ETF 净值数据 {code: DataFrame}
|
||||||
|
'benchmark_data': benchmark_data, # 基准收盘价 Series
|
||||||
|
'valid_codes': valid_codes, # 有效指数代码列表
|
||||||
|
'etf_code_map': etf_code_map # {指数代码: ETF代码} 映射
|
||||||
|
}
|
||||||
|
|
||||||
|
def _get_data_from_local(
|
||||||
|
self,
|
||||||
|
code_list_config: dict,
|
||||||
|
benchmark_code: str
|
||||||
|
) -> dict:
|
||||||
|
"""使用本地 HybridDataSource 获取数据"""
|
||||||
from datasource import HybridDataSource
|
from datasource import HybridDataSource
|
||||||
|
|
||||||
ssh_config = self.config.get('ssh_tunnel', {})
|
ssh_config = self.config.get('ssh_tunnel', {})
|
||||||
|
|||||||
Reference in New Issue
Block a user