feat(strategy): rotation策略支持Flask API数据获取

- 新增 flask_api_source.py: Flask API远程数据源模块
- 修改 strategy.py: get_data() 支持通过Flask API获取数据

使用方式:
strategy.get_data(use_flask_api=True)  # 通过部署服务获取
strategy.get_data(use_flask_api=False) # 本地HybridDataSource

配置项:
flask_api_url: 可在config.yaml中指定API地址
This commit is contained in:
2026-05-13 23:49:26 +08:00
parent 416f708d53
commit 0a9795febb
2 changed files with 418 additions and 3 deletions

View File

@@ -0,0 +1,293 @@
"""
Flask API 数据源
通过部署后的 Flask API 服务获取 OHLCV 数据
支持远程调用,无需本地 SSH 隧道
"""
import os
import json
import requests
import pandas as pd
from typing import Optional, Dict, List
from datetime import datetime
from pathlib import Path
from dotenv import load_dotenv
load_dotenv()
class FlaskAPIDataSource:
"""
Flask API 数据源
通过 HTTP API 获取数据,无需本地配置 SSH 隧道
适用于远程调用或生产环境
用法:
source = FlaskAPIDataSource(base_url="https://k3s.tokenpluse.xyz")
df = source.fetch("000300.SH", "2024-01-01", "2024-12-31")
"""
def __init__(
self,
base_url: str = None,
api_path: str = "/api/v1/ohlcv",
timeout: int = 120,
retries: int = 3
):
"""
初始化
Args:
base_url: API 服务基础地址,默认从环境变量读取
api_path: API 路径
timeout: 请求超时时间(秒)
retries: 重试次数
"""
self.base_url = base_url or os.getenv(
'FLASK_API_URL',
'https://k3s.tokenpluse.xyz'
)
self.api_path = api_path
self.timeout = timeout
self.retries = retries
# 确保 base_url 不以 / 结尾
self.base_url = self.base_url.rstrip('/')
def fetch(
self,
code: str,
start_date: str,
end_date: str,
asset_type: str = None,
timeframe: str = '1d'
) -> Optional[pd.DataFrame]:
"""
获取单只标的 OHLCV 数据
Args:
code: 标的代码
start_date: 开始日期 YYYY-MM-DD
end_date: 结束日期 YYYY-MM-DD
asset_type: 资产类型(可选,用于覆盖自动检测)
timeframe: K线周期加密货币需要
Returns:
DataFrame with columns: date, open, high, low, close, volume
"""
# 构建请求 URL
url = f"{self.base_url}{self.api_path}"
# 构建请求参数
params = {
'code': code,
'start': start_date,
'end': end_date,
}
# 加密货币需要 timeframe 参数
if asset_type == 'crypto' or code.upper() in ['BTC', 'ETH']:
params['timeframe'] = timeframe
# 可选:强制指定 asset_type
if asset_type:
params['asset_type'] = asset_type
for attempt in range(self.retries):
try:
response = requests.get(
url,
params=params,
timeout=self.timeout
)
if response.status_code != 200:
if attempt < self.retries - 1:
continue
print(f"✗ API请求失败: {response.status_code} - {response.text[:100]}")
return None
data = response.json()
# 检查错误
if 'error' in data:
print(f"✗ API返回错误: {data['error']}")
return None
# 解析数据
records = data.get('data', [])
if not records:
print(f"{code}: 无数据返回")
return None
# 转换为 DataFrame
df = pd.DataFrame(records)
# 处理日期列
if 'date' in df.columns:
df['date'] = pd.to_datetime(df['date'])
df = df.set_index('date')
# 确保列名标准化
df = df[['open', 'high', 'low', 'close', 'volume']]
# 缓存 info 信息(如果有)
if 'info' in data:
df.attrs['info'] = data['info']
print(f"{code}: {len(df)} 条数据 ({start_date} ~ {end_date})")
return df
except requests.exceptions.Timeout:
if attempt < self.retries - 1:
print(f"{code}: 请求超时,重试 {attempt + 2}/{self.retries}")
continue
print(f"{code}: 请求超时")
return None
except requests.exceptions.RequestException as e:
if attempt < self.retries - 1:
continue
print(f"{code}: 请求异常 - {e}")
return None
except json.JSONDecodeError as e:
print(f"{code}: JSON解析失败 - {e}")
return None
return None
def fetch_batch(
self,
codes: List[str],
start_date: str,
end_date: str,
asset_types: Dict[str, str] = None
) -> Dict[str, Optional[pd.DataFrame]]:
"""
批量获取多只标的数据
Args:
codes: 标的代码列表
start_date: 开始日期
end_date: 结束日期
asset_types: 资产类型映射 {code: asset_type}
Returns:
{code: DataFrame}
"""
results = {}
asset_types = asset_types or {}
print(f"从 Flask API 获取 {len(codes)} 只标的...")
for i, code in enumerate(codes, 1):
asset_type = asset_types.get(code)
df = self.fetch(code, start_date, end_date, asset_type)
results[code] = df
# 显示进度
if i % 5 == 0 or i == len(codes):
success = sum(1 for v in results.values() if v is not None)
print(f" 进度: {i}/{len(codes)} (成功: {success})")
return results
def fetch_etf_nav(
self,
code: str,
start_date: str,
end_date: str
) -> Optional[pd.DataFrame]:
"""
获取 ETF 净值数据
Args:
code: ETF代码
start_date: 开始日期
end_date: 结束日期
Returns:
DataFrame with nav column
"""
url = f"{self.base_url}/api/v1/etf/nav"
params = {
'code': code,
'start': start_date,
'end': end_date
}
try:
response = requests.get(url, params=params, timeout=self.timeout)
if response.status_code != 200:
return None
data = response.json()
if 'error' in data or 'nav' not in data:
return None
nav_data = data.get('nav', {})
records = nav_data.get('data', [])
if not records:
return None
df = pd.DataFrame(records)
if 'date' in df.columns:
df['date'] = pd.to_datetime(df['date'])
df = df.set_index('date')
return df
except Exception as e:
print(f"{code} 净值获取失败: {e}")
return None
def get_health(self) -> Dict:
"""获取服务健康状态"""
# 先尝试 ohlcv 端点检查服务是否可用
url = f"{self.base_url}{self.api_path}"
params = {'code': '000300.SH', 'start': '2024-01-01', 'end': '2024-01-05'}
try:
response = requests.get(url, params=params, timeout=self.timeout)
if response.status_code == 200:
data = response.json()
return {
'status': 'healthy',
'ssh_configured': True,
'available': True
}
else:
return {'status': 'error', 'available': False}
except Exception as e:
return {'status': 'error', 'message': str(e), 'available': False}
def get_service_info(self) -> Dict:
"""获取服务信息"""
url = f"{self.base_url}/"
try:
response = requests.get(url, timeout=10)
return response.json()
except Exception as e:
return {"error": str(e)}
# 全局实例
_flask_api_source: Optional[FlaskAPIDataSource] = None
def get_flask_api_source(base_url: str = None) -> FlaskAPIDataSource:
"""获取 Flask API 数据源实例"""
global _flask_api_source
if _flask_api_source is None:
_flask_api_source = FlaskAPIDataSource(base_url=base_url)
return _flask_api_source

View File

@@ -105,8 +105,14 @@ class RotationStrategy(StrategyBase):
group_mapping[code] = cfg.get('market', 'default') group_mapping[code] = cfg.get('market', 'default')
return group_mapping return group_mapping
def get_data(self) -> dict: def get_data(self, use_flask_api: bool = True) -> dict:
"""获取数据(使用新数据源模块)""" """
获取数据
Args:
use_flask_api: 是否使用 Flask API 服务获取数据(默认 True
False 则使用本地 HybridDataSource
"""
code_list_config = self.config.get('code_list', {}) code_list_config = self.config.get('code_list', {})
benchmark_config = self.config.get('benchmark', {}) benchmark_config = self.config.get('benchmark', {})
benchmark_code = benchmark_config.get('code', '000300.SH') benchmark_code = benchmark_config.get('code', '000300.SH')
@@ -114,7 +120,123 @@ class RotationStrategy(StrategyBase):
if not code_list_config: if not code_list_config:
raise ValueError("配置中未找到 code_list") raise ValueError("配置中未找到 code_list")
# 使用新数据源模块 # 获取 Flask API 地址
flask_api_url = self.config.get('flask_api_url')
if use_flask_api:
# 使用 Flask API 服务获取数据(远程调用)
return self._get_data_from_flask_api(
code_list_config,
benchmark_code,
flask_api_url
)
else:
# 使用本地 HybridDataSource需要本地 SSH 隧道)
return self._get_data_from_local(
code_list_config,
benchmark_code
)
def _get_data_from_flask_api(
self,
code_list_config: dict,
benchmark_code: str,
flask_api_url: str = None
) -> dict:
"""通过 Flask API 服务获取数据"""
from datasource.flask_api_source import FlaskAPIDataSource
# 初始化 Flask API 数据源
api_source = FlaskAPIDataSource(base_url=flask_api_url)
# 检查服务状态
health = api_source.get_health()
if health.get('status') != 'healthy':
print(f"⚠ Flask API 服务状态: {health}")
else:
print(f"✓ Flask API 服务正常 (SSH: {health.get('ssh_configured', False)})")
# 获取指数代码列表
index_codes = list(code_list_config.keys())
# 获取 ETF 代码映射
etf_code_map = {}
etf_codes = []
for index_code, cfg in code_list_config.items():
if isinstance(cfg, dict) and cfg.get('etf'):
etf_code_map[index_code] = cfg['etf']
etf_codes.append(cfg['etf'])
# 获取指数 OHLCV 数据
print(f"\n获取指数数据 ({len(index_codes)} 只)...")
index_ohlcv_data = api_source.fetch_batch(
index_codes,
self.start_date,
self.end_date
)
# 过滤有效代码
valid_codes = [code for code, df in index_ohlcv_data.items() if df is not None and len(df) > 0]
print(f"有效指数: {len(valid_codes)}")
# 获取 ETF 价格数据
print(f"\n获取 ETF 数据 ({len(etf_codes)} 只)...")
etf_ohlcv_data = api_source.fetch_batch(
etf_codes,
self.start_date,
self.end_date
)
# 转换为宽格式 DataFrame
etf_data = None
if etf_ohlcv_data:
etf_close_dict = {}
for etf_code, df in etf_ohlcv_data.items():
if df is not None and 'close' in df.columns:
etf_close_dict[etf_code] = df['close']
if etf_close_dict:
etf_data = pd.DataFrame(etf_close_dict)
# 获取基准数据
print(f"\n获取基准数据 ({benchmark_code})...")
benchmark_ohlcv = api_source.fetch(benchmark_code, self.start_date, self.end_date)
benchmark_data = None
if benchmark_ohlcv is not None:
benchmark_data = benchmark_ohlcv['close']
# 构建指数收盘价宽格式 DataFrame用于因子计算
index_close_dict = {}
for code in valid_codes:
df = index_ohlcv_data.get(code)
if df is not None and 'close' in df.columns:
index_close_dict[code] = df['close']
index_close = pd.DataFrame(index_close_dict) if index_close_dict else None
# 获取 ETF 净值数据(用于溢价率计算)
print(f"\n获取 ETF 净值数据...")
etf_nav_data = {}
for etf_code in etf_codes:
nav_df = api_source.fetch_etf_nav(etf_code, self.start_date, self.end_date)
if nav_df is not None:
etf_nav_data[etf_code] = nav_df
print(f"有效净值: {len(etf_nav_data)}")
return {
'index_data': index_ohlcv_data, # 原始 OHLCV 数据 {code: DataFrame}
'index_close': index_close, # 对齐后的收盘价(宽格式)
'etf_data': etf_data, # ETF 收盘价(宽格式)
'etf_nav_data': etf_nav_data, # ETF 净值数据 {code: DataFrame}
'benchmark_data': benchmark_data, # 基准收盘价 Series
'valid_codes': valid_codes, # 有效指数代码列表
'etf_code_map': etf_code_map # {指数代码: ETF代码} 映射
}
def _get_data_from_local(
self,
code_list_config: dict,
benchmark_code: str
) -> dict:
"""使用本地 HybridDataSource 获取数据"""
from datasource import HybridDataSource from datasource import HybridDataSource
ssh_config = self.config.get('ssh_tunnel', {}) ssh_config = self.config.get('ssh_tunnel', {})