feat(strategy): rotation策略支持Flask API数据获取

- 新增 flask_api_source.py: Flask API远程数据源模块
- 修改 strategy.py: get_data() 支持通过Flask API获取数据

使用方式:
strategy.get_data(use_flask_api=True)  # 通过部署服务获取
strategy.get_data(use_flask_api=False) # 本地HybridDataSource

配置项:
flask_api_url: 可在config.yaml中指定API地址
This commit is contained in:
2026-05-13 23:49:26 +08:00
parent 416f708d53
commit 0a9795febb
2 changed files with 418 additions and 3 deletions

View File

@@ -105,8 +105,14 @@ class RotationStrategy(StrategyBase):
group_mapping[code] = cfg.get('market', 'default')
return group_mapping
def get_data(self) -> dict:
"""获取数据(使用新数据源模块)"""
def get_data(self, use_flask_api: bool = True) -> dict:
"""
获取数据
Args:
use_flask_api: 是否使用 Flask API 服务获取数据(默认 True
False 则使用本地 HybridDataSource
"""
code_list_config = self.config.get('code_list', {})
benchmark_config = self.config.get('benchmark', {})
benchmark_code = benchmark_config.get('code', '000300.SH')
@@ -114,7 +120,123 @@ class RotationStrategy(StrategyBase):
if not code_list_config:
raise ValueError("配置中未找到 code_list")
# 使用新数据源模块
# 获取 Flask API 地址
flask_api_url = self.config.get('flask_api_url')
if use_flask_api:
# 使用 Flask API 服务获取数据(远程调用)
return self._get_data_from_flask_api(
code_list_config,
benchmark_code,
flask_api_url
)
else:
# 使用本地 HybridDataSource需要本地 SSH 隧道)
return self._get_data_from_local(
code_list_config,
benchmark_code
)
def _get_data_from_flask_api(
self,
code_list_config: dict,
benchmark_code: str,
flask_api_url: str = None
) -> dict:
"""通过 Flask API 服务获取数据"""
from datasource.flask_api_source import FlaskAPIDataSource
# 初始化 Flask API 数据源
api_source = FlaskAPIDataSource(base_url=flask_api_url)
# 检查服务状态
health = api_source.get_health()
if health.get('status') != 'healthy':
print(f"⚠ Flask API 服务状态: {health}")
else:
print(f"✓ Flask API 服务正常 (SSH: {health.get('ssh_configured', False)})")
# 获取指数代码列表
index_codes = list(code_list_config.keys())
# 获取 ETF 代码映射
etf_code_map = {}
etf_codes = []
for index_code, cfg in code_list_config.items():
if isinstance(cfg, dict) and cfg.get('etf'):
etf_code_map[index_code] = cfg['etf']
etf_codes.append(cfg['etf'])
# 获取指数 OHLCV 数据
print(f"\n获取指数数据 ({len(index_codes)} 只)...")
index_ohlcv_data = api_source.fetch_batch(
index_codes,
self.start_date,
self.end_date
)
# 过滤有效代码
valid_codes = [code for code, df in index_ohlcv_data.items() if df is not None and len(df) > 0]
print(f"有效指数: {len(valid_codes)}")
# 获取 ETF 价格数据
print(f"\n获取 ETF 数据 ({len(etf_codes)} 只)...")
etf_ohlcv_data = api_source.fetch_batch(
etf_codes,
self.start_date,
self.end_date
)
# 转换为宽格式 DataFrame
etf_data = None
if etf_ohlcv_data:
etf_close_dict = {}
for etf_code, df in etf_ohlcv_data.items():
if df is not None and 'close' in df.columns:
etf_close_dict[etf_code] = df['close']
if etf_close_dict:
etf_data = pd.DataFrame(etf_close_dict)
# 获取基准数据
print(f"\n获取基准数据 ({benchmark_code})...")
benchmark_ohlcv = api_source.fetch(benchmark_code, self.start_date, self.end_date)
benchmark_data = None
if benchmark_ohlcv is not None:
benchmark_data = benchmark_ohlcv['close']
# 构建指数收盘价宽格式 DataFrame用于因子计算
index_close_dict = {}
for code in valid_codes:
df = index_ohlcv_data.get(code)
if df is not None and 'close' in df.columns:
index_close_dict[code] = df['close']
index_close = pd.DataFrame(index_close_dict) if index_close_dict else None
# 获取 ETF 净值数据(用于溢价率计算)
print(f"\n获取 ETF 净值数据...")
etf_nav_data = {}
for etf_code in etf_codes:
nav_df = api_source.fetch_etf_nav(etf_code, self.start_date, self.end_date)
if nav_df is not None:
etf_nav_data[etf_code] = nav_df
print(f"有效净值: {len(etf_nav_data)}")
return {
'index_data': index_ohlcv_data, # 原始 OHLCV 数据 {code: DataFrame}
'index_close': index_close, # 对齐后的收盘价(宽格式)
'etf_data': etf_data, # ETF 收盘价(宽格式)
'etf_nav_data': etf_nav_data, # ETF 净值数据 {code: DataFrame}
'benchmark_data': benchmark_data, # 基准收盘价 Series
'valid_codes': valid_codes, # 有效指数代码列表
'etf_code_map': etf_code_map # {指数代码: ETF代码} 映射
}
def _get_data_from_local(
self,
code_list_config: dict,
benchmark_code: str
) -> dict:
"""使用本地 HybridDataSource 获取数据"""
from datasource import HybridDataSource
ssh_config = self.config.get('ssh_tunnel', {})