Files
etf/strategies/rotation/strategy.py
aszerW 0a9795febb feat(strategy): rotation策略支持Flask API数据获取
- 新增 flask_api_source.py: Flask API远程数据源模块
- 修改 strategy.py: get_data() 支持通过Flask API获取数据

使用方式:
strategy.get_data(use_flask_api=True)  # 通过部署服务获取
strategy.get_data(use_flask_api=False) # 本地HybridDataSource

配置项:
flask_api_url: 可在config.yaml中指定API地址
2026-05-13 23:49:26 +08:00

457 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
轮动策略完整实现
整合数据获取、因子计算、信号生成、回测执行
"""
import pandas as pd
import yaml
from datetime import datetime
from pathlib import Path
# 加载环境变量
from dotenv import load_dotenv
load_dotenv()
from framework.factors import FactorRegistry, FactorCombiner
from framework.signals import SignalGenerator
from framework.execution import BacktestExecutor
from framework.risk import CallbackHook, Position
from framework.strategy import StrategyBase
# 导入定制组件
from strategies.shared.factors.momentum import MomentumFactor
from strategies.shared.signals.selectors import TopNSelector
class RotationStrategy(StrategyBase):
    """
    Complete ETF rotation strategy.

    Built from a momentum factor, Top-N selection and per-group
    diversification.

    Usage:
        from strategies.rotation.strategy import RotationStrategy
        strategy = RotationStrategy.from_yaml('strategies/rotation/config.yaml')
        result = strategy.run_backtest()
    """
    # Strategy identifier.
    name = "rotation"
    # Number of instruments to select; forwarded to TopNSelector and
    # BacktestExecutor. Overridable through the 'select_num' config key.
    select_num = 3
    # Stop-loss level. NOTE(review): never read in this file and not
    # overridable through _apply_config -- presumably consumed by
    # StrategyBase; confirm.
    stoploss = -0.05
    # Momentum window handed to the factor as n_days; compute_factors also
    # requires at least n_days + 1 rows of history per instrument.
    n_days = 25
    # Rebalance settings forwarded unchanged to TopNSelector.
    rebalance_days = 1
    rebalance_threshold = 0.0
    # Trading cost forwarded to BacktestExecutor as trade_cost.
    trade_cost = 0.001
def __init__(self, config: dict = None):
    """Initialise the strategy from an optional configuration mapping.

    Args:
        config: Strategy settings (typically the parsed YAML file). ``None``
            or an empty mapping selects the class-level defaults.
    """
    # Always run _apply_config, even without a config: it is the only place
    # that sets start_date / end_date, which get_data() and run_backtest()
    # read unconditionally. Skipping it left those attributes undefined.
    self.config = config if config else {}
    self._apply_config(self.config)

    # Factor setup. NOTE(review): the registry is cleared on every
    # instantiation; if FactorRegistry holds process-wide state this drops
    # factors registered by other strategies -- confirm that is intended.
    FactorRegistry.clear()
    FactorRegistry.register(MomentumFactor)
    self._factor = FactorRegistry.get(
        'momentum',
        n_days=self.n_days,
        crash_filter=True
    )

    # Code -> group mapping used by the selector for diversification.
    self._group_mapping = self._build_group_mapping()

    # Signal generator.
    self._selector = TopNSelector(
        select_num=self.select_num,
        group_mapping=self._group_mapping,
        min_score=0.0,
        rebalance_days=self.rebalance_days,
        rebalance_threshold=self.rebalance_threshold
    )
@classmethod
def from_yaml(cls, config_path: str) -> 'RotationStrategy':
    """Create a strategy instance from a YAML configuration file.

    Args:
        config_path: Path to a UTF-8 encoded YAML file.

    Returns:
        A strategy built from the parsed configuration. A missing or empty
        ``end_date`` is replaced by today's date (``YYYY-MM-DD``).

    Raises:
        ValueError: If the top-level YAML node is not a mapping.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # safe_load returns None for an empty document; treat that as "no
    # settings" instead of failing on ``None.get`` below.
    if config is None:
        config = {}
    if not isinstance(config, dict):
        raise ValueError(
            f"Top-level YAML node in {config_path!r} must be a mapping, "
            f"got {type(config).__name__}"
        )

    # Default the end date to today when it is absent or blank.
    if not config.get('end_date'):
        config['end_date'] = datetime.now().strftime('%Y-%m-%d')
    return cls(config)
def _apply_config(self, config: dict) -> None:
    """Copy the recognised settings from ``config`` onto this instance.

    Tunable parameters keep their current value when the key is absent.
    ``start_date`` falls back to '2019-01-01' and ``end_date`` to today's
    date.

    Args:
        config: Mapping of setting name to value.
    """
    tunables = (
        'select_num',
        'n_days',
        'rebalance_days',
        'rebalance_threshold',
        'trade_cost',
    )
    for key in tunables:
        current = getattr(self, key)
        setattr(self, key, config.get(key, current))

    today = datetime.now().strftime('%Y-%m-%d')
    self.start_date = config.get('start_date', '2019-01-01')
    self.end_date = config.get('end_date', today)
def _build_group_mapping(self) -> dict:
    """Build the ``{code: group}`` mapping used for diversified selection.

    Each ``code_list`` entry whose value is a mapping contributes its
    ``market`` value, or 'default' when that key is absent. Entries whose
    value is not a mapping are left out.

    Returns:
        Mapping of instrument code to group name; empty when ``code_list``
        is missing or null.
    """
    # ``or {}`` also covers an explicit ``code_list: null`` in the YAML, so
    # construction succeeds and get_data() can report the missing list.
    code_list_config = self.config.get('code_list') or {}
    return {
        code: cfg.get('market', 'default')
        for code, cfg in code_list_config.items()
        if isinstance(cfg, dict)
    }
def get_data(self, use_flask_api: bool = True) -> dict:
    """Fetch the data needed for a backtest.

    Args:
        use_flask_api: When True (the default) data is fetched through the
            Flask API service; when False the local HybridDataSource is
            used instead.

    Returns:
        The dictionary produced by the selected data-source helper.

    Raises:
        ValueError: If the configuration has no ``code_list``.
    """
    settings = self.config
    codes_cfg = settings.get('code_list', {})
    benchmark_code = settings.get('benchmark', {}).get('code', '000300.SH')

    if not codes_cfg:
        raise ValueError("配置中未找到 code_list")

    if not use_flask_api:
        # Local path (the original comment notes it needs a local SSH tunnel).
        return self._get_data_from_local(codes_cfg, benchmark_code)

    # Remote path: the API address is optional and may be None.
    api_url = settings.get('flask_api_url')
    return self._get_data_from_flask_api(codes_cfg, benchmark_code, api_url)
def _get_data_from_flask_api(
    self,
    code_list_config: dict,
    benchmark_code: str,
    flask_api_url: str = None
) -> dict:
    """Fetch all backtest inputs through the Flask API service.

    Args:
        code_list_config: ``{index code: settings}``; a settings mapping
            with a truthy ``etf`` entry links the index to its ETF code.
        benchmark_code: Code of the benchmark instrument.
        flask_api_url: Base URL handed to FlaskAPIDataSource (may be None).

    Returns:
        A dict with the keys ``index_data``, ``index_close``, ``etf_data``,
        ``etf_nav_data``, ``benchmark_data``, ``valid_codes`` and
        ``etf_code_map``.
    """
    from datasource.flask_api_source import FlaskAPIDataSource

    api_source = FlaskAPIDataSource(base_url=flask_api_url)

    # The health check is advisory: an unhealthy status is only reported,
    # the fetches below are attempted regardless.
    health = api_source.get_health()
    if health.get('status') != 'healthy':
        print(f"⚠ Flask API 服务状态: {health}")
    else:
        print(f"✓ Flask API 服务正常 (SSH: {health.get('ssh_configured', False)})")

    index_codes = list(code_list_config)

    # {index code: ETF code} for every entry that declares an ETF.
    etf_code_map = {
        index_code: cfg['etf']
        for index_code, cfg in code_list_config.items()
        if isinstance(cfg, dict) and cfg.get('etf')
    }
    etf_codes = list(etf_code_map.values())

    # Index OHLCV data.
    print(f"\n获取指数数据 ({len(index_codes)} 只)...")
    index_ohlcv_data = api_source.fetch_batch(
        index_codes,
        self.start_date,
        self.end_date
    )
    valid_codes = [
        code for code, df in index_ohlcv_data.items()
        if df is not None and len(df) > 0
    ]
    print(f"有效指数: {len(valid_codes)}")

    # ETF price data, reshaped to a wide close-price frame.
    print(f"\n获取 ETF 数据 ({len(etf_codes)} 只)...")
    etf_ohlcv_data = api_source.fetch_batch(
        etf_codes,
        self.start_date,
        self.end_date
    )
    etf_close_dict = {
        etf_code: df['close']
        for etf_code, df in (etf_ohlcv_data or {}).items()
        if df is not None and 'close' in df.columns
    }
    etf_data = pd.DataFrame(etf_close_dict) if etf_close_dict else None

    # Benchmark close series. Guard the column lookup the same way as the
    # index/ETF paths so a response without 'close' yields None rather
    # than a KeyError.
    print(f"\n获取基准数据 ({benchmark_code})...")
    benchmark_ohlcv = api_source.fetch(benchmark_code, self.start_date, self.end_date)
    benchmark_data = None
    if benchmark_ohlcv is not None and 'close' in benchmark_ohlcv.columns:
        benchmark_data = benchmark_ohlcv['close']

    # Wide close-price frame of the valid indices (used for factors).
    # NOTE(review): the frame's index is the union of all instruments'
    # dates with no forward fill, whereas the local path labels its frame
    # as already aligned -- confirm downstream code tolerates that.
    index_close_dict = {}
    for code in valid_codes:
        df = index_ohlcv_data.get(code)
        if df is not None and 'close' in df.columns:
            index_close_dict[code] = df['close']
    index_close = pd.DataFrame(index_close_dict) if index_close_dict else None

    # ETF NAV data (the original comment says it feeds premium-rate
    # calculation); fetched one code at a time.
    print("\n获取 ETF 净值数据...")
    etf_nav_data = {}
    for etf_code in etf_codes:
        nav_df = api_source.fetch_etf_nav(etf_code, self.start_date, self.end_date)
        if nav_df is not None:
            etf_nav_data[etf_code] = nav_df
    print(f"有效净值: {len(etf_nav_data)}")

    return {
        'index_data': index_ohlcv_data,    # raw OHLCV {code: DataFrame}
        'index_close': index_close,        # wide close-price frame
        'etf_data': etf_data,              # wide ETF close-price frame
        'etf_nav_data': etf_nav_data,      # ETF NAV {code: DataFrame}
        'benchmark_data': benchmark_data,  # benchmark close Series
        'valid_codes': valid_codes,        # index codes with data
        'etf_code_map': etf_code_map       # {index code: ETF code}
    }
def _get_data_from_local(
    self,
    code_list_config: dict,
    benchmark_code: str
) -> dict:
    """Fetch all backtest inputs through the local HybridDataSource.

    Args:
        code_list_config: The ``code_list`` section of the configuration.
        benchmark_code: Code of the benchmark instrument.

    Returns:
        A dict with the keys ``index_data``, ``index_close``, ``etf_data``,
        ``etf_nav_data``, ``benchmark_data``, ``valid_codes`` and
        ``etf_code_map``.
    """
    from datasource import HybridDataSource

    source = HybridDataSource(
        ssh_config=self.config.get('ssh_tunnel', {}),
        use_cache=self.config.get('use_cache', True)
    )

    fetched = source.fetch_all(
        code_config=code_list_config,
        benchmark_code=benchmark_code,
        start_date=self.start_date,
        end_date=self.end_date
    )
    # fetch_all returns a 7-tuple; name each part before repackaging.
    (
        aligned_close,
        etf_prices,
        etf_navs,
        benchmark_series,
        usable_codes,
        raw_ohlcv,
        index_to_etf,
    ) = fetched

    return {
        'index_data': raw_ohlcv,            # raw OHLCV data
        'index_close': aligned_close,       # aligned close prices (wide)
        'etf_data': etf_prices,
        'etf_nav_data': etf_navs,
        'benchmark_data': benchmark_series,
        'valid_codes': usable_codes,
        'etf_code_map': index_to_etf        # {index code: ETF code}
    }
def compute_factors(self, data: dict) -> pd.DataFrame:
    """Compute factor values per instrument, then align them to one calendar.

    The factor is computed on each instrument's own cleaned price history
    and only afterwards forward-filled onto the reference calendar, so the
    rolling window never sees filled-in duplicate rows.

    Args:
        data: Dict with ``index_data`` ({code: DataFrame}), ``valid_codes``
            and optionally ``index_close`` (a frame whose index serves as
            the reference calendar).

    Returns:
        DataFrame of factor values, one column per surviving code.

    Side effects:
        ``data['valid_codes']`` is replaced by the list of surviving codes.
    """
    index_data = data['index_data']
    valid_codes = data['valid_codes']

    # Reference calendar: the aligned close frame when available, otherwise
    # the first code with an A-share style suffix, otherwise the first code.
    index_close = data.get('index_close')
    if index_close is not None:
        a_share_dates = index_close.index
    elif valid_codes:
        calendar_code = next(
            (code for code in valid_codes
             if code.endswith(('.SH', '.SZ', '.CSI'))),
            valid_codes[0],
        )
        a_share_dates = index_data[calendar_code].index
    else:
        # Nothing to compute; avoids indexing into an empty code list.
        a_share_dates = pd.Index([])

    price_cols = ['open', 'high', 'low', 'close']
    factor_values = {}
    for code in valid_codes:
        df = index_data[code]

        if all(col in df.columns for col in price_cols):
            # Drop every row with a gap in any price column. This also
            # removes instruments that only carry close prices
            # (open/high/low all empty). 'volume' takes part only when the
            # column exists, so frames without it no longer raise KeyError.
            ohlcv_cols = price_cols + (['volume'] if 'volume' in df.columns else [])
            df_clean = df[ohlcv_cols].dropna()
            if len(df_clean) < self.n_days + 1:
                print(f" ⚠ 剔除 {code}: OHLCV数据不足 ({len(df_clean)} < {self.n_days + 1})")
                continue
            close_series = df_clean['close']
        else:
            # Close-only data.
            if 'close' in df.columns:
                close_series = df['close'].dropna()
            else:
                close_series = df.dropna()
            if len(close_series) < self.n_days + 1:
                print(f" ⚠ 剔除 {code}: close数据不足 ({len(close_series)} < {self.n_days + 1})")
                continue

        # Compute on the instrument's own trading days first ...
        close_df = pd.DataFrame({'close': close_series})
        factor_series = self._factor.compute(close_df)
        # ... then carry values forward onto the reference calendar.
        factor_values[code] = factor_series.reindex(a_share_dates, method='ffill')

    factor_df = pd.DataFrame(factor_values)

    # Drop instruments whose aligned factor is missing on more than half
    # of the calendar.
    total_rows = len(factor_df)
    for code in list(factor_values):
        null_pct = factor_df[code].isnull().sum() / total_rows
        if null_pct > 0.5:
            print(f" ⚠ 剔除 {code}: 缺失率 {null_pct:.1%} 过高")
            factor_df = factor_df.drop(columns=[code])

    # Publish the surviving codes back to the caller's data dict.
    data['valid_codes'] = [c for c in factor_values if c in factor_df.columns]
    return factor_df
def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame:
    """Produce signals by delegating to the configured selector.

    Args:
        factor_df: Factor values as returned by ``compute_factors``.

    Returns:
        Whatever the selector's ``generate`` method returns for ``factor_df``.
    """
    selector = self._selector
    return selector.generate(factor_df)
def run_backtest(self, data: dict = None, save_path: str = None) -> dict:
    """Run the full pipeline: data, factors, signals, backtest.

    Args:
        data: Optional pre-fetched data dict; fetched via ``get_data()``
            when omitted.
        save_path: Optional path prefix; when given, ``<prefix>_nav.csv``
            and ``<prefix>_signals.csv`` are written.

    Returns:
        Dict with the keys ``signals``, ``result``, ``portfolio`` and
        ``total_return``. ``result`` and ``total_return`` are None when the
        executor's portfolio exposes no ``backtest_result``.
    """
    print("\n" + "=" * 60)
    print(" ETF轮动策略 回测系统")
    print("=" * 60)

    # 1. Data
    if data is None:
        data = self.get_data()
    # NOTE(review): this list is captured before compute_factors() prunes
    # data['valid_codes'], so the returns frame may carry columns for
    # instruments that were dropped afterwards -- confirm that is harmless.
    valid_codes = data['valid_codes']
    print(f"\n候选标的: {len(valid_codes)}")
    print(f"回测区间: {self.start_date} ~ {self.end_date}")

    # 2. Factors
    print("\n计算因子...")
    factor_df = self.compute_factors(data)
    print(f" 因子类型: momentum (weighted)\n 窗口天数: {self.n_days}\n 计算完成: {len(factor_df.columns)}")

    # 3. Signals
    print("\n生成信号...")
    signals = self.generate_signals(factor_df)
    print(f" 选股数量: {self.select_num}\n 分组选股: {len(set(self._group_mapping.values()))} 个大类\n 信号日期: {len(signals)}")

    # 4. Backtest
    print("\n执行回测...")
    returns_df = self._build_returns(data, valid_codes)

    # Restrict both frames to the dates they share.
    common_dates = signals.index.intersection(returns_df.index)
    signals = signals.loc[common_dates]
    returns_df = returns_df.loc[common_dates]
    print(f" 对齐后日期: {len(common_dates)}")

    executor = BacktestExecutor(
        initial_capital=100000,
        trade_cost=self.trade_cost,
        select_num=self.select_num
    )
    portfolio = executor.execute(signals, returns_df)

    # 5. Report
    if not hasattr(portfolio, 'backtest_result'):
        # Same key set as the success path so callers can index uniformly.
        return {
            'signals': signals,
            'result': None,
            'portfolio': portfolio,
            'total_return': None
        }

    result = portfolio.backtest_result
    final_nav = result['策略净值'].iloc[-1]
    total_return = (final_nav - 1) * 100
    print("\n回测结果:")
    print(f" 最终净值: {final_nav:.4f}\n 累计收益: {total_return:.2f}%")

    if save_path:
        result[['策略净值']].to_csv(f"{save_path}_nav.csv")
        signals.to_csv(f"{save_path}_signals.csv")
        print(f" 报告保存: {save_path}_*.csv")

    return {
        'signals': signals,
        'result': result,
        'portfolio': portfolio,
        'total_return': total_return
    }

def _build_returns(self, data: dict, valid_codes: list) -> pd.DataFrame:
    """Build the daily-return frame, one ``日收益率_<index code>`` column each.

    ETF prices are preferred, then the wide index close frame, then the
    per-code raw index data.
    """
    etf_data = data.get('etf_data')
    if etf_data is not None and not etf_data.empty:
        # Returns come from ETF prices; columns keep the index code.
        etf_code_map = data.get('etf_code_map', {})
        returns_data = {}
        for idx_code in valid_codes:
            etf_code = etf_code_map.get(idx_code, idx_code)
            if etf_code in etf_data.columns:
                returns_data[f'日收益率_{idx_code}'] = etf_data[etf_code].pct_change()
        return pd.DataFrame(returns_data)

    index_close = data.get('index_close')
    if index_close is not None and not index_close.empty:
        returns_df = index_close.pct_change()
        returns_df.columns = [f'日收益率_{col}' for col in returns_df.columns]
        return returns_df

    # Last resort: raw per-code closes. The DataFrame constructor aligns
    # the series on their own dates; the index is deliberately NOT
    # overwritten with the first code's index, which raised on a length
    # mismatch and mislabelled rows when only the dates differed.
    index_data = data['index_data']
    return pd.DataFrame({
        f'日收益率_{code}': index_data[code]['close'].pct_change()
        for code in valid_codes
        if code in index_data
    })
# Implementations of the base-class hooks.
def init_factors(self) -> FactorCombiner:
    """Return a FactorCombiner wrapping this strategy's single factor."""
    factors = [self._factor]
    return FactorCombiner(factors)
def init_signal_generator(self) -> SignalGenerator:
    """Return the selector built in ``__init__`` as the signal generator."""
    selector = self._selector
    return selector