refactor(archive): move unused modules to archive/
Archive legacy framework and utility modules that are no longer referenced by the active core (datasource/ and rotation/): - framework/ -> archive/framework/ - framework_v2/ -> archive/framework_v2/ - strategies/ -> archive/strategies/ - config/ -> archive/config/ - visualization/ -> archive/visualization/ - scripts/ -> archive/scripts/ - tests/ -> archive/tests/ - run_rotation.py, run_us_rotation.py -> archive/single_files/ - compare_*.py, test_api_dates.py -> archive/single_files/
This commit is contained in:
9
archive/strategies/__init__.py
Normal file
9
archive/strategies/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
strategies模块入口
|
||||
|
||||
包含所有策略实现
|
||||
"""
|
||||
|
||||
from strategies.rotation import RotationStrategy
|
||||
|
||||
__all__ = ['RotationStrategy']
|
||||
41
archive/strategies/base.py
Normal file
41
archive/strategies/base.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
策略基类定义
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Strategy(ABC):
|
||||
"""策略抽象基类"""
|
||||
|
||||
def __init__(self, name: str, config: dict = None):
|
||||
self.name = name
|
||||
self.config = config or {}
|
||||
|
||||
@abstractmethod
|
||||
def run(self, **kwargs) -> Any:
|
||||
"""执行策略"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_signals(self, **kwargs) -> Any:
|
||||
"""获取当前信号"""
|
||||
pass
|
||||
|
||||
|
||||
class BacktestStrategy(Strategy):
|
||||
"""回测策略基类"""
|
||||
|
||||
def __init__(self, name: str, config: dict = None):
|
||||
super().__init__(name, config)
|
||||
self.results = None
|
||||
|
||||
@abstractmethod
|
||||
def run_backtest(self, **kwargs) -> dict:
|
||||
"""执行回测,返回绩效指标"""
|
||||
pass
|
||||
|
||||
def get_results(self) -> dict:
|
||||
"""获取回测结果"""
|
||||
return self.results
|
||||
7
archive/strategies/rotation/__init__.py
Normal file
7
archive/strategies/rotation/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
轮动策略模块入口
|
||||
"""
|
||||
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
__all__ = ['RotationStrategy']
|
||||
182
archive/strategies/rotation/config.yaml
Normal file
182
archive/strategies/rotation/config.yaml
Normal file
@@ -0,0 +1,182 @@
|
||||
# ETF轮动策略配置
|
||||
|
||||
# ==================== 候选池配置 ====================
|
||||
# 指数-ETF映射配置
|
||||
# index: 指数代码(用于计算因子信号)
|
||||
# etf: ETF代码(用于实际交易和收益计算),null表示直接交易指数/加密货币
|
||||
code_list:
|
||||
# 中国A股指数
|
||||
"399006.SZ":
|
||||
name: "创业板指"
|
||||
etf: "159915.SZ"
|
||||
market: "A"
|
||||
"H30269.CSI":
|
||||
name: "中证红利低波"
|
||||
etf: "512890.SH"
|
||||
market: "A"
|
||||
|
||||
|
||||
# 全球市场
|
||||
"NDX":
|
||||
name: "纳指100"
|
||||
etf: "513100.SH"
|
||||
market: "US"
|
||||
"N225":
|
||||
name: "日经225"
|
||||
etf: "513520.SH"
|
||||
market: "JP"
|
||||
"GDAXI":
|
||||
name: "德国DAX"
|
||||
etf: "513030.SH"
|
||||
market: "EU"
|
||||
"HSI":
|
||||
name: "恒生指数"
|
||||
etf: "159920.SZ"
|
||||
market: "HK"
|
||||
"HSTECH.HK":
|
||||
name: "恒生科技"
|
||||
etf: "513130.SH"
|
||||
market: "HK"
|
||||
|
||||
# 商品 & 固收
|
||||
# 使用 COMEX/WTI 期货替代上期所主力合约(数据更长)
|
||||
"GC=F": # COMEX黄金期货(2000年至今)
|
||||
name: "黄金"
|
||||
etf: "518880.SH" # 国内黄金ETF
|
||||
market: "COMMODITY"
|
||||
"CL=F": # WTI原油期货(2000年至今)
|
||||
name: "原油"
|
||||
etf: "160723.SZ" # 国内原油ETF
|
||||
market: "COMMODITY"
|
||||
# 使用 COMEX 铜期货替代上期所主力合约(数据更长)
|
||||
"HG=F": # COMEX铜期货(2000年至今)
|
||||
name: "有色金属"
|
||||
etf: "159980.SZ" # 国内有色金属ETF
|
||||
market: "COMMODITY"
|
||||
|
||||
# 防御类资产:短债指数
|
||||
# 931862.CSI = 中证0-9个月国债指数(短债指数)
|
||||
# 数据范围:2007-12-31开始,约19年数据
|
||||
# 久期:极短(<1年),波动极小,熊市防御效果最佳
|
||||
#
|
||||
# 【收益归因实证分析结论】
|
||||
# 分析方法:将持有债券期间的收益分解为两部分
|
||||
# 1. 标的收益:债券本身的价格上涨(年化约3%)
|
||||
# 2. 决策收益:持有债券期间避免持有股票带来的损失
|
||||
#
|
||||
# 实证结果(2002-2026回测区间):
|
||||
# - 标的收益占比:约17%(债券本身增长贡献)
|
||||
# - 决策收益占比:约83%(避险决策贡献)
|
||||
#
|
||||
# 核心结论:
|
||||
# 短债指数在轮动策略中的价值主要来自"正确避险决策",
|
||||
# 而非债券本身的价格增长。这验证了动量轮动策略的
|
||||
# 核心逻辑:通过仓位选择规避下跌风险,而非依赖标的本身收益。
|
||||
#
|
||||
# 收益对比(策略净值):
|
||||
# - 使用931862.CSI(短债):净值264.54,收益26354%
|
||||
# - 使用000012.SH(综合国债):净值216.30,收益21530%
|
||||
# - 短债防御效果更好,收益高18.2%
|
||||
#
|
||||
# 注意:无对应ETF可交易,直接使用指数数据计算动量和收益
|
||||
"931862.CSI":
|
||||
name: "短债指数"
|
||||
etf: null
|
||||
market: "BOND"
|
||||
|
||||
# 000012.SH(上证国债指数)配置已注释,原因:
|
||||
# 1. 000012.SH是综合国债指数(包含短债到长债),无对应ETF
|
||||
# 2. 之前错误映射到511520.SH(政金债ETF),指数-ETF不匹配
|
||||
# 3. 收益低于短债指数(216.30 vs 264.54)
|
||||
# "000012.SH":
|
||||
# name: "上证国债指数"
|
||||
# etf: "511520.SH"
|
||||
# market: "BOND"
|
||||
|
||||
# 主市场配置
|
||||
primary_market:
|
||||
source: "Tushare"
|
||||
code: "000300.SH"
|
||||
|
||||
# 基准指数配置
|
||||
benchmark:
|
||||
code: "000300.SH"
|
||||
name: "沪深300"
|
||||
|
||||
# ==================== 回测参数 ====================
|
||||
start_date: "2020-01-01"
|
||||
|
||||
# ==================== 因子参数 ====================
|
||||
# 动量/趋势窗口期(天数)
|
||||
n_days: 25
|
||||
# 因子类型:'momentum', 'slope_r2', 'weighted_momentum'
|
||||
factor_type: "weighted_momentum"
|
||||
|
||||
# 动态周期参数 (匹配 JoinQuant 策略)
|
||||
auto_day: false
|
||||
min_days: 20
|
||||
max_days: 60
|
||||
|
||||
# ==================== 轮动参数 ====================
|
||||
select_num: 3
|
||||
# 强制分散化:每个大类只选 Top 1
|
||||
diversified: true
|
||||
# 动量最低阈值:标的动量得分需>=此值才考虑入选(年化收益率*R²)
|
||||
# 设置为0表示过滤负动量标的,更高阈值虽能改善回撤但可能错过正动量机会
|
||||
min_score: 0.0
|
||||
|
||||
# V3: 动态阈值配置(替代固定 min_score: 0.0)
|
||||
# 使用短债动量作为动态 min_score:标的动量 < 短债动量 → 不持有
|
||||
bond_threshold:
|
||||
enabled: true # true=V3动态阈值, false=退化为V2固定阈值
|
||||
bond_code: "931862.CSI" # 阈值参考标的(短债指数)
|
||||
ratio: 1.0 # 阈值 = 短债动量 × ratio
|
||||
fill_bond: true # 选出不足select_num只时,用短债填充空余仓位
|
||||
|
||||
# ==================== 调仓控制 ====================
|
||||
# 最低调仓周期(交易日):持仓至少持有 N 天后才允许换仓
|
||||
rebalance_days: 1
|
||||
# 调仓得分阈值:新组合总得分需超过当前组合 X% 才触发调仓
|
||||
rebalance_threshold: 0.0
|
||||
# 单次换仓成本(双边,含佣金+滑点)
|
||||
trade_cost: 0.001
|
||||
|
||||
# ==================== 溢价控制配置 ====================
|
||||
# 跨境ETF溢价过滤机制(防止高溢价买入)
|
||||
premium_control:
|
||||
enabled: true
|
||||
default_threshold: 0.10 # 默认溢价阈值 10%
|
||||
mode: "filter" # "filter"(完全排除) 或 "penalize"(降权)
|
||||
penalty_factor: 0.5 # 降权模式下的惩罚系数
|
||||
|
||||
# 按市场类型覆盖配置
|
||||
market_overrides:
|
||||
A: # A股 ETF
|
||||
enabled: false # 不启用(溢价通常 < 0.5%)
|
||||
HK: # 港股 ETF
|
||||
enabled: true
|
||||
threshold: 0.10 # 阈值 10%
|
||||
US: # 美股 ETF
|
||||
enabled: true
|
||||
threshold: 0.10 # 阈值 10%
|
||||
COMMODITY: # 商品 ETF
|
||||
enabled: false
|
||||
|
||||
# ==================== 数据缓存 ====================
|
||||
# 是否使用本地缓存(True=优先从本地读取)
|
||||
use_cache: true
|
||||
|
||||
# ==================== 数据源配置 ====================
|
||||
# Flask API 服务配置(优先使用远程 API 获取数据)
|
||||
flask_api:
|
||||
enabled: true # 是否启用 Flask API
|
||||
url: "https://k3s.tokenpluse.xyz" # Flask API 服务地址
|
||||
|
||||
# SSH 隧道配置(用于网络受限环境,通过境外服务器访问 yfinance)
|
||||
ssh_tunnel:
|
||||
enabled: true # 是否启用 SSH 隧道
|
||||
host: "8.218.167.69" # SSH 服务器地址(阿里云香港 ECS IP)
|
||||
port: 22 # SSH 端口
|
||||
username: "root" # SSH 用户名
|
||||
key_path: "hk_ecs.pem" # SSH 私钥路径(相对于项目根目录)
|
||||
local_port: 1080 # 本地 SOCKS5 代理端口
|
||||
563
archive/strategies/rotation/strategy.py
Normal file
563
archive/strategies/rotation/strategy.py
Normal file
@@ -0,0 +1,563 @@
|
||||
"""
|
||||
轮动策略完整实现
|
||||
|
||||
整合数据获取、因子计算、信号生成、回测执行
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import yaml
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
# 加载环境变量
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
from framework.factors import FactorRegistry, FactorCombiner
|
||||
from framework.signals import SignalGenerator
|
||||
from framework.execution import BacktestExecutor
|
||||
from framework.risk import CallbackHook, Position
|
||||
from framework.strategy import StrategyBase
|
||||
|
||||
# 导入定制组件
|
||||
from strategies.shared.factors.momentum import MomentumFactor
|
||||
from strategies.shared.signals.selectors import TopNSelector
|
||||
|
||||
|
||||
class RotationStrategy(StrategyBase):
|
||||
"""
|
||||
ETF轮动策略(完整实现)
|
||||
|
||||
基于动量因子 + Top N选股 + 分散化
|
||||
|
||||
使用方式:
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
strategy = RotationStrategy.from_yaml('strategies/rotation/config.yaml')
|
||||
result = strategy.run_backtest()
|
||||
"""
|
||||
|
||||
name = "rotation"
|
||||
select_num = 3
|
||||
stoploss = -0.05
|
||||
n_days = 25
|
||||
rebalance_days = 1
|
||||
rebalance_threshold = 0.0
|
||||
trade_cost = 0.001
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
"""初始化策略"""
|
||||
# 应用配置
|
||||
if config:
|
||||
self._apply_config(config)
|
||||
self.config = config
|
||||
else:
|
||||
self.config = {}
|
||||
|
||||
# 初始化因子
|
||||
FactorRegistry.clear()
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
self._factor = FactorRegistry.get(
|
||||
'momentum',
|
||||
n_days=self.n_days,
|
||||
crash_filter=True
|
||||
)
|
||||
|
||||
# 构建分组映射(分散化选股)
|
||||
self._group_mapping = self._build_group_mapping()
|
||||
|
||||
# 初始化信号生成器
|
||||
self._selector = TopNSelector(
|
||||
select_num=self.select_num,
|
||||
group_mapping=self._group_mapping,
|
||||
min_score=self.min_score, # 从配置读取,支持动态调整阈值
|
||||
rebalance_days=self.rebalance_days,
|
||||
rebalance_threshold=self.rebalance_threshold,
|
||||
bond_threshold_config=self.config.get('bond_threshold', {}) # V3动态阈值配置
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, config_path: str) -> 'RotationStrategy':
|
||||
"""从YAML配置创建策略实例"""
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
# 设置结束日期
|
||||
if not config.get('end_date'):
|
||||
config['end_date'] = datetime.now().strftime('%Y-%m-%d')
|
||||
|
||||
return cls(config)
|
||||
|
||||
def _apply_config(self, config: dict) -> None:
|
||||
"""应用配置参数"""
|
||||
self.select_num = config.get('select_num', self.select_num)
|
||||
self.n_days = config.get('n_days', self.n_days)
|
||||
self.rebalance_days = config.get('rebalance_days', self.rebalance_days)
|
||||
self.rebalance_threshold = config.get('rebalance_threshold', self.rebalance_threshold)
|
||||
self.trade_cost = config.get('trade_cost', self.trade_cost)
|
||||
self.min_score = config.get('min_score', 0.0) # 动量最低阈值,默认过滤负动量
|
||||
self.start_date = config.get('start_date', '2019-01-01')
|
||||
self.end_date = config.get('end_date', datetime.now().strftime('%Y-%m-%d'))
|
||||
|
||||
def _build_group_mapping(self) -> dict:
|
||||
"""构建分组映射(分散化选股)"""
|
||||
group_mapping = {}
|
||||
code_list_config = self.config.get('code_list', {})
|
||||
for code, cfg in code_list_config.items():
|
||||
if isinstance(cfg, dict):
|
||||
group_mapping[code] = cfg.get('market', 'default')
|
||||
return group_mapping
|
||||
|
||||
def get_data(self, use_flask_api: bool = True) -> dict:
|
||||
"""
|
||||
获取数据
|
||||
|
||||
Args:
|
||||
use_flask_api: 是否使用 Flask API 服务获取数据(默认 True)
|
||||
False 则使用本地 UniversalDataFetcher
|
||||
"""
|
||||
code_list_config = self.config.get('code_list', {})
|
||||
benchmark_config = self.config.get('benchmark', {})
|
||||
benchmark_code = benchmark_config.get('code', '000300.SH')
|
||||
|
||||
if not code_list_config:
|
||||
raise ValueError("配置中未找到 code_list")
|
||||
|
||||
# 获取 Flask API 地址
|
||||
flask_api_config = self.config.get('flask_api', {})
|
||||
flask_api_url = flask_api_config.get('url') if flask_api_config.get('enabled') else None
|
||||
|
||||
if use_flask_api:
|
||||
# 使用 Flask API 服务获取数据(远程调用)
|
||||
return self._get_data_from_flask_api(
|
||||
code_list_config,
|
||||
benchmark_code,
|
||||
flask_api_url
|
||||
)
|
||||
else:
|
||||
# 使用本地 HybridDataSource(需要本地 SSH 隧道)
|
||||
return self._get_data_from_local(
|
||||
code_list_config,
|
||||
benchmark_code
|
||||
)
|
||||
|
||||
def _get_data_from_flask_api(
|
||||
self,
|
||||
code_list_config: dict,
|
||||
benchmark_code: str,
|
||||
flask_api_url: str = None
|
||||
) -> dict:
|
||||
"""通过 Flask API 服务获取数据"""
|
||||
from datasource.flask_api_source import FlaskAPIDataSource
|
||||
|
||||
# 初始化 Flask API 数据源
|
||||
api_source = FlaskAPIDataSource(base_url=flask_api_url)
|
||||
|
||||
# 检查服务状态
|
||||
health = api_source.get_health()
|
||||
if health.get('status') != 'healthy':
|
||||
print(f"⚠ Flask API 服务状态: {health}")
|
||||
else:
|
||||
print(f"✓ Flask API 服务正常 (SSH: {health.get('ssh_configured', False)})")
|
||||
|
||||
# 打印回测时间区间说明
|
||||
print(f"\n回测配置区间: {self.start_date} ~ {self.end_date}")
|
||||
print("注: 各标的实际数据范围可能因上市时间/数据源限制而不同")
|
||||
|
||||
# 获取指数代码列表
|
||||
index_codes = list(code_list_config.keys())
|
||||
|
||||
# 获取 ETF 代码映射
|
||||
etf_code_map = {}
|
||||
etf_codes = []
|
||||
for index_code, cfg in code_list_config.items():
|
||||
if isinstance(cfg, dict) and cfg.get('etf'):
|
||||
etf_code_map[index_code] = cfg['etf']
|
||||
etf_codes.append(cfg['etf'])
|
||||
|
||||
# 获取指数 OHLCV 数据
|
||||
print(f"\n获取指数数据 ({len(index_codes)} 只)...")
|
||||
index_ohlcv_data = api_source.fetch_batch(
|
||||
index_codes,
|
||||
self.start_date,
|
||||
self.end_date
|
||||
)
|
||||
|
||||
# 过滤有效代码
|
||||
valid_codes = [code for code, df in index_ohlcv_data.items() if df is not None and len(df) > 0]
|
||||
print(f"有效指数: {len(valid_codes)} 只")
|
||||
|
||||
# 获取 ETF 价格数据(同时获取净值和溢价率)
|
||||
print(f"\n获取 ETF 数据 ({len(etf_codes)} 只)...")
|
||||
etf_ohlcv_data = api_source.fetch_batch(
|
||||
etf_codes,
|
||||
self.start_date,
|
||||
self.end_date
|
||||
)
|
||||
|
||||
# 转换为宽格式 DataFrame,并提取净值/溢价率数据
|
||||
etf_data = None
|
||||
etf_nav_data = {}
|
||||
etf_premium_data = {}
|
||||
|
||||
if etf_ohlcv_data:
|
||||
etf_close_dict = {}
|
||||
for etf_code, df in etf_ohlcv_data.items():
|
||||
if df is not None and 'close' in df.columns:
|
||||
etf_close_dict[etf_code] = df['close']
|
||||
|
||||
# 从 DataFrame.attrs 中提取净值和溢价率数据
|
||||
# Flask API 已自动附加这些数据
|
||||
if 'nav' in df.attrs:
|
||||
etf_nav_data[etf_code] = df.attrs['nav']
|
||||
if 'premium_series' in df.attrs:
|
||||
etf_premium_data[etf_code] = {
|
||||
'series': df.attrs['premium_series'],
|
||||
'latest': df.attrs.get('latest_premium'),
|
||||
'date': df.attrs.get('premium_date'),
|
||||
'stats': df.attrs.get('premium_stats'),
|
||||
}
|
||||
if etf_close_dict:
|
||||
etf_data = pd.DataFrame(etf_close_dict)
|
||||
|
||||
print(f"有效净值: {len(etf_nav_data)} 只")
|
||||
print(f"有效溢价率: {len(etf_premium_data)} 只")
|
||||
|
||||
# 获取基准数据
|
||||
print(f"\n获取基准数据 ({benchmark_code})...")
|
||||
benchmark_ohlcv = api_source.fetch(benchmark_code, self.start_date, self.end_date)
|
||||
benchmark_data = None
|
||||
if benchmark_ohlcv is not None:
|
||||
benchmark_data = benchmark_ohlcv['close']
|
||||
|
||||
# 构建指数收盘价宽格式 DataFrame(用于因子计算)
|
||||
index_close_dict = {}
|
||||
for code in valid_codes:
|
||||
df = index_ohlcv_data.get(code)
|
||||
if df is not None and 'close' in df.columns:
|
||||
index_close_dict[code] = df['close']
|
||||
index_close = pd.DataFrame(index_close_dict) if index_close_dict else None
|
||||
|
||||
# 获取 A 股 SSE 官方交易日历
|
||||
from datasource.tushare_source import TushareSource
|
||||
tushare = TushareSource()
|
||||
a_share_dates = tushare.fetch_trade_cal(self.start_date, self.end_date)
|
||||
print(f"A股交易日历: {len(a_share_dates)} 天")
|
||||
|
||||
return {
|
||||
'index_data': index_ohlcv_data, # 原始 OHLCV 数据 {code: DataFrame}
|
||||
'index_close': index_close, # 对齐后的收盘价(宽格式)
|
||||
'etf_data': etf_data, # ETF 收盘价(宽格式)
|
||||
'etf_nav_data': etf_nav_data, # ETF 净值数据 {code: DataFrame}
|
||||
'etf_premium_data': etf_premium_data, # ETF 溢价率数据 {code: dict}
|
||||
'benchmark_data': benchmark_data, # 基准收盘价 Series
|
||||
'valid_codes': valid_codes, # 有效指数代码列表
|
||||
'etf_code_map': etf_code_map, # {指数代码: ETF代码} 映射
|
||||
'a_share_dates': a_share_dates # A股SSE交易日历
|
||||
}
|
||||
|
||||
def _get_data_from_local(
|
||||
self,
|
||||
code_list_config: dict,
|
||||
benchmark_code: str
|
||||
) -> dict:
|
||||
"""使用本地 UniversalDataFetcher 获取数据"""
|
||||
from datasource import UniversalDataFetcher
|
||||
from datasource.tushare_source import TushareSource
|
||||
|
||||
ssh_config = self.config.get('ssh_tunnel', {})
|
||||
|
||||
fetcher = UniversalDataFetcher(
|
||||
ssh_config=ssh_config,
|
||||
use_cache=self.config.get('use_cache', True)
|
||||
)
|
||||
|
||||
index_codes = list(code_list_config.keys())
|
||||
etf_code_map = {idx_code: cfg['etf'] for idx_code, cfg in code_list_config.items() if cfg.get('etf')}
|
||||
|
||||
# 获取指数数据
|
||||
index_ohlcv_data = {}
|
||||
valid_codes = []
|
||||
|
||||
with fetcher: # 使用上下文管理器自动管理 SSH 隧道
|
||||
for code in index_codes:
|
||||
data = fetcher.fetch(code, self.start_date, self.end_date)
|
||||
if data is not None and len(data) > 0:
|
||||
index_ohlcv_data[code] = data
|
||||
valid_codes.append(code)
|
||||
print(f"✓ {code}: {len(data)} 条")
|
||||
else:
|
||||
print(f"✗ {code}: 无数据")
|
||||
|
||||
# 构建宽格式收盘价
|
||||
index_close = None
|
||||
if index_ohlcv_data:
|
||||
close_list = []
|
||||
for code, df in index_ohlcv_data.items():
|
||||
close_df = df[['close']].copy()
|
||||
close_df.columns = [code]
|
||||
close_list.append(close_df)
|
||||
index_close = pd.concat(close_list, axis=1)
|
||||
|
||||
# 获取 ETF 数据
|
||||
etf_data = None
|
||||
etf_nav_data = None
|
||||
|
||||
tushare = TushareSource()
|
||||
|
||||
if etf_code_map:
|
||||
etf_price_list = []
|
||||
etf_nav_list = []
|
||||
|
||||
for idx_code, etf_code in etf_code_map.items():
|
||||
# ETF 价格
|
||||
etf_df = tushare.fetch_etf(etf_code, self.start_date, self.end_date)
|
||||
if etf_df is not None and len(etf_df) > 0:
|
||||
etf_df = etf_df[['close']].copy()
|
||||
etf_df.columns = [etf_code]
|
||||
etf_price_list.append(etf_df)
|
||||
|
||||
# ETF 净值
|
||||
nav_df = tushare.fetch_etf_nav(etf_code, self.start_date, self.end_date)
|
||||
if nav_df is not None and len(nav_df) > 0:
|
||||
nav_df = nav_df[['nav']].copy()
|
||||
nav_df.columns = [etf_code]
|
||||
etf_nav_list.append(nav_df)
|
||||
|
||||
if etf_price_list:
|
||||
etf_data = pd.concat(etf_price_list, axis=1)
|
||||
if etf_nav_list:
|
||||
etf_nav_data = pd.concat(etf_nav_list, axis=1)
|
||||
|
||||
# 基准数据
|
||||
benchmark_data = tushare.fetch_index(benchmark_code, self.start_date, self.end_date)
|
||||
|
||||
# A股交易日历
|
||||
a_share_dates = tushare.fetch_trade_cal(self.start_date, self.end_date)
|
||||
print(f"A股交易日历: {len(a_share_dates)} 天")
|
||||
|
||||
return {
|
||||
'index_data': index_ohlcv_data, # 原始OHLCV数据
|
||||
'index_close': index_close, # 对齐后的收盘价(宽格式)
|
||||
'etf_data': etf_data,
|
||||
'etf_nav_data': etf_nav_data,
|
||||
'benchmark_data': benchmark_data,
|
||||
'valid_codes': valid_codes,
|
||||
'etf_code_map': etf_code_map, # {指数代码: ETF代码} 映射
|
||||
'a_share_dates': a_share_dates # A股SSE交易日历
|
||||
}
|
||||
|
||||
def compute_factors(self, data: dict) -> pd.DataFrame:
|
||||
"""计算因子值(匹配原引擎:先计算因子再对齐到A股交易日历)
|
||||
|
||||
注意:不剔除数据不足的标的,保留所有标的以暴露策略问题
|
||||
"""
|
||||
index_data = data['index_data']
|
||||
valid_codes = data['valid_codes']
|
||||
|
||||
# 获取 A 股 SSE 官方交易日历(优先使用已获取的)
|
||||
a_share_dates = data.get('a_share_dates')
|
||||
if a_share_dates is None or len(a_share_dates) == 0:
|
||||
# 回退:使用已有的对齐后数据索引
|
||||
index_close = data.get('index_close')
|
||||
if index_close is not None:
|
||||
a_share_dates = index_close.index
|
||||
else:
|
||||
for code in valid_codes:
|
||||
if code.endswith('.SH') or code.endswith('.SZ') or code.endswith('.CSI'):
|
||||
a_share_dates = index_data[code].index
|
||||
break
|
||||
else:
|
||||
a_share_dates = index_data[valid_codes[0]].index
|
||||
|
||||
factor_values = {}
|
||||
final_valid_codes = []
|
||||
|
||||
for code in valid_codes:
|
||||
df = index_data[code].copy()
|
||||
|
||||
# 检查是否有有效的OHLCV数据(列存在且不全为None)
|
||||
ohlcv_cols = ['open', 'high', 'low', 'close', 'volume']
|
||||
required_cols = ['open', 'high', 'low', 'close']
|
||||
|
||||
# 检查列是否存在
|
||||
cols_exist = all(col in df.columns for col in required_cols)
|
||||
|
||||
# 检查数据是否有效(不全为None/NaN)
|
||||
if cols_exist:
|
||||
cols_have_data = all(df[col].notna().any() for col in required_cols)
|
||||
else:
|
||||
cols_have_data = False
|
||||
|
||||
if cols_exist and cols_have_data:
|
||||
# 有完整有效的OHLCV数据,整行dropna()后提取close
|
||||
df_clean = df[ohlcv_cols].dropna()
|
||||
close_series = df_clean['close'] if len(df_clean) > 0 else pd.Series(dtype=float)
|
||||
elif 'close' in df.columns and df['close'].notna().any():
|
||||
# 只有close列有效数据(如债券指数)
|
||||
close_series = df['close'].dropna()
|
||||
else:
|
||||
# 无有效数据
|
||||
close_series = pd.Series(dtype=float)
|
||||
|
||||
# 检查数据长度并警告,但不剔除
|
||||
if len(close_series) < self.n_days + 1:
|
||||
print(f" ⚠ {code}: 数据不足 ({len(close_series)} < {self.n_days + 1}),保留但因子值可能为NaN")
|
||||
|
||||
# 原引擎逻辑:先在原始交易日历上计算因子
|
||||
# rolling窗口使用的是原始交易日数据,不包含ffill填充的重复值
|
||||
if len(close_series) > 0:
|
||||
close_df = pd.DataFrame({'close': close_series})
|
||||
factor_series = self._factor.compute(close_df)
|
||||
|
||||
# 然后对齐因子序列到A股交易日历(匹配原引擎逻辑)
|
||||
factor_aligned = factor_series.reindex(a_share_dates, method='ffill')
|
||||
else:
|
||||
# 没有数据,创建空的因子序列
|
||||
factor_aligned = pd.Series(index=a_share_dates, dtype=float)
|
||||
|
||||
factor_values[code] = factor_aligned
|
||||
final_valid_codes.append(code)
|
||||
|
||||
factor_df = pd.DataFrame(factor_values)
|
||||
|
||||
# 检查缺失率并警告,但不剔除(保留所有标的以暴露策略问题)
|
||||
total_rows = len(factor_df)
|
||||
for code in final_valid_codes:
|
||||
if code in factor_df.columns:
|
||||
null_pct = factor_df[code].isnull().sum() / total_rows
|
||||
if null_pct > 0.5:
|
||||
print(f" ⚠ {code}: 缺失率 {null_pct:.1%} 较高,保留但信号生成时可能跳过")
|
||||
|
||||
# 不更新有效代码列表,保留所有原始代码
|
||||
data['valid_codes'] = final_valid_codes
|
||||
|
||||
return factor_df
|
||||
|
||||
def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""生成信号"""
|
||||
return self._selector.generate(factor_df)
|
||||
|
||||
def run_backtest(self, data: dict = None, save_path: str = None) -> dict:
|
||||
"""
|
||||
完整回测流程
|
||||
|
||||
Args:
|
||||
data: 可选,如不提供则自动获取
|
||||
save_path: 报告保存路径
|
||||
|
||||
Returns:
|
||||
回测结果字典
|
||||
"""
|
||||
print("\n" + "=" * 60)
|
||||
print(" ETF轮动策略 回测系统")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. 获取数据
|
||||
if data is None:
|
||||
data = self.get_data()
|
||||
|
||||
valid_codes = data['valid_codes']
|
||||
index_data = data['index_data']
|
||||
|
||||
print(f"\n候选标的: {len(valid_codes)} 只")
|
||||
print(f"回测区间: {self.start_date} ~ {self.end_date}")
|
||||
|
||||
# 2. 计算因子
|
||||
print("\n计算因子...")
|
||||
factor_df = self.compute_factors(data)
|
||||
print(f" 因子类型: momentum (weighted)\n 窗口天数: {self.n_days}\n 计算完成: {len(factor_df.columns)} 只")
|
||||
|
||||
# 3. 生成信号
|
||||
print("\n生成信号...")
|
||||
signals = self.generate_signals(factor_df)
|
||||
print(f" 选股数量: {self.select_num}\n 分组选股: {len(set(self._group_mapping.values()))} 个大类\n 信号日期: {len(signals)} 天")
|
||||
|
||||
# 4. 执行回测
|
||||
print("\n执行回测...")
|
||||
|
||||
# 获取 A 股 SSE 官方交易日历(优先使用已获取的)
|
||||
a_share_dates = data.get('a_share_dates')
|
||||
if a_share_dates is None or len(a_share_dates) == 0:
|
||||
a_share_dates = signals.index
|
||||
|
||||
# 将信号对齐到 A 股日历
|
||||
if a_share_dates is not signals.index:
|
||||
signals = signals.reindex(a_share_dates, method='ffill').dropna(subset=[signals.columns[0]])
|
||||
print(f" 信号对齐到A股日历: {len(signals)} 天")
|
||||
|
||||
# 计算日收益率:先在原始交易日历计算,再对齐到A股日历
|
||||
# 关键:与因子计算逻辑一致,避免交易日不对齐导致收益率NaN
|
||||
returns_data = {}
|
||||
for code in valid_codes:
|
||||
if code in index_data:
|
||||
df = index_data[code]
|
||||
# 提取原始收盘价序列
|
||||
if 'close' in df.columns:
|
||||
close_series = df['close'].dropna()
|
||||
# 修复:先ffill价格对齐到A股日历,再计算收益率
|
||||
# 原因:若先pct_change再ffill,休市日会复制前一天的非零收益率
|
||||
# 正确做法:休市日价格不变 → 收益率应为0%
|
||||
close_aligned = close_series.reindex(a_share_dates, method='ffill')
|
||||
returns_aligned = close_aligned.pct_change(fill_method=None)
|
||||
returns_data[f'日收益率_{code}'] = returns_aligned
|
||||
|
||||
returns_df = pd.DataFrame(returns_data)
|
||||
|
||||
# 确保信号和收益率数据日期对齐
|
||||
common_dates = signals.index.intersection(returns_df.index)
|
||||
signals = signals.loc[common_dates]
|
||||
returns_df = returns_df.loc[common_dates]
|
||||
|
||||
print(f" 对齐后日期: {len(common_dates)} 天")
|
||||
|
||||
executor = BacktestExecutor(
|
||||
initial_capital=100000,
|
||||
trade_cost=self.trade_cost,
|
||||
select_num=self.select_num
|
||||
)
|
||||
|
||||
portfolio = executor.execute(signals, returns_df)
|
||||
|
||||
# 5. 输出结果
|
||||
if hasattr(portfolio, 'backtest_result'):
|
||||
result = portfolio.backtest_result
|
||||
final_nav = result['策略净值'].iloc[-1]
|
||||
total_return = (final_nav - 1) * 100
|
||||
|
||||
print("\n回测结果:")
|
||||
print(f" 最终净值: {final_nav:.4f}\n 累计收益: {total_return:.2f}%")
|
||||
|
||||
# 获取调仓事件
|
||||
rebalance_events = getattr(portfolio, 'rebalance_events', pd.DataFrame())
|
||||
if not rebalance_events.empty:
|
||||
print(f" 调仓次数: {len(rebalance_events)} 次")
|
||||
|
||||
# 保存报告
|
||||
if save_path:
|
||||
result[['策略净值']].to_csv(f"{save_path}_nav.csv")
|
||||
signals.to_csv(f"{save_path}_signals.csv")
|
||||
|
||||
# 保存调仓事件记录
|
||||
if not rebalance_events.empty:
|
||||
rebalance_events.to_csv(f"{save_path}_rebalances.csv")
|
||||
print(f" 报告保存: {save_path}_*.csv (含调仓记录)")
|
||||
else:
|
||||
print(f" 报告保存: {save_path}_*.csv")
|
||||
|
||||
return {
|
||||
'signals': signals,
|
||||
'result': result,
|
||||
'portfolio': portfolio,
|
||||
'total_return': total_return,
|
||||
'rebalance_events': rebalance_events
|
||||
}
|
||||
|
||||
return {'signals': signals, 'result': None}
|
||||
|
||||
# 保留抽象方法实现
|
||||
def init_factors(self) -> FactorCombiner:
|
||||
return FactorCombiner([self._factor])
|
||||
|
||||
def init_signal_generator(self) -> SignalGenerator:
|
||||
return self._selector
|
||||
54
archive/strategies/shared/__init__.py
Normal file
54
archive/strategies/shared/__init__.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""
|
||||
定制组件统一入口
|
||||
|
||||
所有定制因子、信号生成器、风控组件都在这里导出
|
||||
"""
|
||||
|
||||
# 定制因子
|
||||
from strategies.shared.factors.momentum import (
|
||||
MomentumFactor,
|
||||
TrendFactor,
|
||||
ReversalFactor,
|
||||
VolatilityFactor
|
||||
)
|
||||
|
||||
# 定制信号生成器
|
||||
from strategies.shared.signals.selectors import (
|
||||
TopNSelector,
|
||||
TrendFollower,
|
||||
ReversalTrader
|
||||
)
|
||||
|
||||
# 定制风控组件
|
||||
from strategies.shared.risk.controls import (
|
||||
StopLossControl,
|
||||
PositionLimitControl,
|
||||
PremiumControl,
|
||||
premium_filter_callback,
|
||||
crash_filter_callback,
|
||||
holding_time_stoploss_callback
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
# 因子
|
||||
'MomentumFactor',
|
||||
'TrendFactor',
|
||||
'ReversalFactor',
|
||||
'VolatilityFactor',
|
||||
|
||||
# 信号生成器
|
||||
'TopNSelector',
|
||||
'TrendFollower',
|
||||
'ReversalTrader',
|
||||
|
||||
# 风控组件
|
||||
'StopLossControl',
|
||||
'PositionLimitControl',
|
||||
'PremiumControl',
|
||||
|
||||
# 回调函数
|
||||
'premium_filter_callback',
|
||||
'crash_filter_callback',
|
||||
'holding_time_stoploss_callback',
|
||||
]
|
||||
18
archive/strategies/shared/data/__init__.py
Normal file
18
archive/strategies/shared/data/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
定制数据源统一入口
|
||||
"""
|
||||
|
||||
from strategies.shared.data.sources import (
|
||||
LocalFileCache,
|
||||
HybridDataSourceAdapter,
|
||||
TushareDataSource,
|
||||
YFinanceDataSource
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
'LocalFileCache',
|
||||
'HybridDataSourceAdapter',
|
||||
'TushareDataSource',
|
||||
'YFinanceDataSource'
|
||||
]
|
||||
223
archive/strategies/shared/data/sources.py
Normal file
223
archive/strategies/shared/data/sources.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""
|
||||
定制数据源实现
|
||||
|
||||
具体数据源适配器继承framework.data.DataSource
|
||||
"""
|
||||
|
||||
from framework.data import DataSource, OHLCVData, DataCache
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime
|
||||
import os
|
||||
import json
|
||||
|
||||
|
||||
class LocalFileCache(DataCache):
|
||||
"""
|
||||
本地文件缓存(定制实现)
|
||||
|
||||
支持版本控制和新鲜性检查
|
||||
"""
|
||||
|
||||
def __init__(self, cache_dir: str = "data/etf_cache/daily"):
|
||||
"""初始化缓存"""
|
||||
self.cache_dir = cache_dir
|
||||
|
||||
def get(self, code: str, start: str, end: str) -> Optional[OHLCVData]:
|
||||
"""从缓存获取数据"""
|
||||
cache_file = os.path.join(self.cache_dir, f"{code}.csv")
|
||||
|
||||
if not os.path.exists(cache_file):
|
||||
return None
|
||||
|
||||
try:
|
||||
df = pd.read_csv(cache_file, index_col=0, parse_dates=True)
|
||||
|
||||
# 过滤日期范围
|
||||
df = df.loc[start:end]
|
||||
|
||||
if df.empty:
|
||||
return None
|
||||
|
||||
return OHLCVData(
|
||||
code=code,
|
||||
data=df,
|
||||
start_date=df.index.min(),
|
||||
end_date=df.index.max()
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"缓存读取失败: {code} - {e}")
|
||||
return None
|
||||
|
||||
def set(self, code: str, data: OHLCVData) -> None:
|
||||
"""写入缓存"""
|
||||
if data.data is None:
|
||||
return
|
||||
|
||||
cache_file = os.path.join(self.cache_dir, f"{code}.csv")
|
||||
|
||||
# 如果已存在,追加新数据
|
||||
if os.path.exists(cache_file):
|
||||
existing = pd.read_csv(cache_file, index_col=0, parse_dates=True)
|
||||
combined = pd.concat([existing, data.data]).drop_duplicates()
|
||||
combined.to_csv(cache_file)
|
||||
else:
|
||||
data.data.to_csv(cache_file)
|
||||
|
||||
def is_fresh(self, code: str, max_age_days: int = 1) -> bool:
|
||||
"""检查缓存是否新鲜"""
|
||||
meta_file = os.path.join(self.cache_dir, f"{code}.meta.json")
|
||||
|
||||
if not os.path.exists(meta_file):
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(meta_file, 'r') as f:
|
||||
meta = json.load(f)
|
||||
|
||||
last_update = datetime.fromisoformat(meta.get('last_update', ''))
|
||||
age = (datetime.now() - last_update).days
|
||||
|
||||
return age <= max_age_days
|
||||
except:
|
||||
return False
|
||||
|
||||
def clear(self, code: Optional[str] = None) -> None:
|
||||
"""清空缓存"""
|
||||
if code:
|
||||
cache_file = os.path.join(self.cache_dir, f"{code}.csv")
|
||||
meta_file = os.path.join(self.cache_dir, f"{code}.meta.json")
|
||||
|
||||
if os.path.exists(cache_file):
|
||||
os.remove(cache_file)
|
||||
if os.path.exists(meta_file):
|
||||
os.remove(meta_file)
|
||||
else:
|
||||
# 清空整个缓存目录
|
||||
for f in os.listdir(self.cache_dir):
|
||||
os.remove(os.path.join(self.cache_dir, f))
|
||||
|
||||
|
||||
class HybridDataSourceAdapter(DataSource):
|
||||
"""
|
||||
混合数据源适配器(定制实现)
|
||||
|
||||
封装现有的HybridDataSource,适配到框架DataSource接口
|
||||
"""
|
||||
|
||||
name = "hybrid"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
use_cache: bool = True,
|
||||
cache_dir: str = "data/etf_cache/daily",
|
||||
ssh_config: Optional[Dict] = None
|
||||
):
|
||||
"""初始化混合数据源"""
|
||||
super().__init__(use_cache=use_cache, cache_dir=cache_dir, ssh_config=ssh_config)
|
||||
self.use_cache = use_cache
|
||||
self.cache = LocalFileCache(cache_dir) if use_cache else None
|
||||
self.ssh_config = ssh_config or {}
|
||||
|
||||
# 内部使用现有的HybridDataSource
|
||||
self._hybrid_source = None
|
||||
|
||||
def _init_hybrid_source(self):
|
||||
"""延迟初始化HybridDataSource"""
|
||||
if self._hybrid_source is None:
|
||||
from core.datasource.hybrid_source import HybridDataSource
|
||||
self._hybrid_source = HybridDataSource(
|
||||
ssh_config=self.ssh_config,
|
||||
use_cache=self.use_cache
|
||||
)
|
||||
|
||||
def fetch(self, code: str, start: str, end: str) -> OHLCVData:
|
||||
"""获取单个标的数据"""
|
||||
# 先检查缓存
|
||||
if self.cache and self.cache.is_fresh(code):
|
||||
cached = self.cache.get(code, start, end)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
# 从数据源获取
|
||||
self._init_hybrid_source()
|
||||
|
||||
# 这里需要根据代码类型判断使用哪个数据源
|
||||
# 简化实现:直接调用现有HybridDataSource
|
||||
# TODO: 完整实现需要适配现有数据源
|
||||
|
||||
return OHLCVData(code=code, data=pd.DataFrame())
|
||||
|
||||
def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
|
||||
"""批量获取数据"""
|
||||
result = {}
|
||||
for code in codes:
|
||||
result[code] = self.fetch(code, start, end)
|
||||
return result
|
||||
|
||||
def get_supported_codes(self) -> List[str]:
|
||||
"""获取支持的代码列表"""
|
||||
# 从fund_basic.csv读取
|
||||
basic_file = "data/etf_cache/fund_basic.csv"
|
||||
if os.path.exists(basic_file):
|
||||
df = pd.read_csv(basic_file)
|
||||
return df['code'].tolist()
|
||||
return []
|
||||
|
||||
|
||||
class TushareDataSource(DataSource):
|
||||
"""
|
||||
Tushare数据源(定制实现)
|
||||
|
||||
用于获取A股数据
|
||||
"""
|
||||
|
||||
name = "tushare"
|
||||
|
||||
def __init__(self, token: Optional[str] = None):
|
||||
"""初始化Tushare数据源"""
|
||||
super().__init__(token=token)
|
||||
self.token = token
|
||||
|
||||
def fetch(self, code: str, start: str, end: str) -> OHLCVData:
|
||||
"""获取A股指数数据"""
|
||||
# TODO: 实现Tushare数据获取
|
||||
return OHLCVData(code=code, data=pd.DataFrame())
|
||||
|
||||
def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
|
||||
"""批量获取"""
|
||||
return {code: self.fetch(code, start, end) for code in codes}
|
||||
|
||||
|
||||
class YFinanceDataSource(DataSource):
|
||||
"""
|
||||
YFinance数据源(定制实现)
|
||||
|
||||
用于获取港股/美股/加密货币数据
|
||||
"""
|
||||
|
||||
name = "yfinance"
|
||||
|
||||
def __init__(self, use_ssh_tunnel: bool = False, ssh_config: Optional[Dict] = None):
|
||||
"""初始化YFinance数据源"""
|
||||
super().__init__(use_ssh_tunnel=use_ssh_tunnel, ssh_config=ssh_config)
|
||||
self.use_ssh_tunnel = use_ssh_tunnel
|
||||
self.ssh_config = ssh_config or {}
|
||||
|
||||
def fetch(self, code: str, start: str, end: str) -> OHLCVData:
|
||||
"""获取境外数据"""
|
||||
# TODO: 实现YFinance数据获取(含SSH隧道)
|
||||
return OHLCVData(code=code, data=pd.DataFrame())
|
||||
|
||||
def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
|
||||
"""批量获取"""
|
||||
return {code: self.fetch(code, start, end) for code in codes}
|
||||
|
||||
|
||||
# 导出定制数据源
|
||||
__all__ = [
|
||||
'LocalFileCache',
|
||||
'HybridDataSourceAdapter',
|
||||
'TushareDataSource',
|
||||
'YFinanceDataSource'
|
||||
]
|
||||
250
archive/strategies/shared/factors/momentum.py
Normal file
250
archive/strategies/shared/factors/momentum.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""
|
||||
定制因子实现
|
||||
|
||||
这些因子继承framework.core.factors.FactorBase
|
||||
"""
|
||||
|
||||
from framework.factors import FactorBase, FactorRegistry
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
|
||||
class MomentumFactor(FactorBase):
|
||||
"""
|
||||
动量因子(定制实现)
|
||||
|
||||
计算加权线性回归动量得分:
|
||||
得分 = 年化收益率 × R²
|
||||
|
||||
参数:
|
||||
- n_days: 动量窗口(默认25)
|
||||
- weighted: 是否加权(默认True)
|
||||
- crash_filter: 是否启用崩盘过滤(默认True)
|
||||
"""
|
||||
|
||||
name = "momentum"
|
||||
category = "momentum"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_days: int = 25,
|
||||
weighted: bool = True,
|
||||
crash_filter: bool = True
|
||||
):
|
||||
super().__init__(n_days=n_days, weighted=weighted, crash_filter=crash_filter)
|
||||
self.n_days = n_days
|
||||
self.weighted = weighted
|
||||
self.crash_filter = crash_filter
|
||||
|
||||
def compute(self, data: pd.DataFrame) -> pd.Series:
|
||||
"""计算动量因子值"""
|
||||
if 'close' not in data.columns:
|
||||
raise ValueError("data must contain 'close' column")
|
||||
|
||||
prices = data['close']
|
||||
|
||||
if self.weighted:
|
||||
factor_values = prices.rolling(self.n_days).apply(
|
||||
lambda x: self._weighted_momentum_score(x.values),
|
||||
raw=False
|
||||
)
|
||||
else:
|
||||
factor_values = prices.pct_change(self.n_days)
|
||||
|
||||
if self.crash_filter:
|
||||
factor_values = self._apply_crash_filter(prices, factor_values)
|
||||
|
||||
return factor_values
|
||||
|
||||
def _weighted_momentum_score(self, prices: np.ndarray) -> float:
|
||||
"""计算加权动量得分"""
|
||||
if len(prices) < 5:
|
||||
return 0.0
|
||||
|
||||
# 价格下界 clip,防止 log(0) 或 log(负数)
|
||||
prices = np.clip(prices, 0.01, None)
|
||||
y = np.log(prices)
|
||||
|
||||
# 异常值检测
|
||||
if np.any(np.isnan(y)) or np.any(np.isinf(y)):
|
||||
return 0.0
|
||||
|
||||
x = np.arange(len(y))
|
||||
weights = np.linspace(1, 2, len(y))
|
||||
|
||||
slope, intercept = np.polyfit(x, y, 1, w=weights)
|
||||
annualized_returns = math.exp(slope * 250) - 1
|
||||
|
||||
y_pred = slope * x + intercept
|
||||
ss_res = np.sum(weights * (y - y_pred) ** 2)
|
||||
ss_tot = np.sum(weights * (y - np.average(y, weights=weights)) ** 2)
|
||||
r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
|
||||
|
||||
return annualized_returns * r2
|
||||
|
||||
def _apply_crash_filter(self, prices: pd.Series, factor_values: pd.Series) -> pd.Series:
|
||||
"""崩盘过滤:连续3天跌>5%清零"""
|
||||
result = factor_values.copy()
|
||||
|
||||
for i in range(3, len(prices)):
|
||||
r1 = prices.iloc[i] / prices.iloc[i-1]
|
||||
r2 = prices.iloc[i-1] / prices.iloc[i-2]
|
||||
r3 = prices.iloc[i-2] / prices.iloc[i-3]
|
||||
|
||||
con1 = min(r1, r2, r3) < 0.95
|
||||
con2 = (r1 < 1) and (r2 < 1) and (r3 < 1) and (prices.iloc[i] / prices.iloc[i-3] < 0.95)
|
||||
|
||||
if con1 or con2:
|
||||
result.iloc[i] = 0.0
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class TrendFactor(FactorBase):
|
||||
"""趋势因子(定制实现)"""
|
||||
|
||||
name = "trend"
|
||||
category = "trend"
|
||||
|
||||
def __init__(self, method: str = 'ma_cross', fast: int = 5, slow: int = 20):
|
||||
super().__init__(method=method, fast=fast, slow=slow)
|
||||
self.method = method
|
||||
self.fast = fast
|
||||
self.slow = slow
|
||||
|
||||
def compute(self, data: pd.DataFrame) -> pd.Series:
|
||||
"""计算趋势因子值"""
|
||||
if 'close' not in data.columns:
|
||||
raise ValueError("data must contain 'close' column")
|
||||
|
||||
prices = data['close']
|
||||
|
||||
if self.method == 'ma_cross':
|
||||
fast_ma = prices.rolling(self.fast).mean()
|
||||
slow_ma = prices.rolling(self.slow).mean()
|
||||
return (fast_ma - slow_ma) / slow_ma
|
||||
|
||||
elif self.method == 'macd':
|
||||
ema12 = prices.ewm(span=12).mean()
|
||||
ema26 = prices.ewm(span=26).mean()
|
||||
macd = ema12 - ema26
|
||||
signal = macd.ewm(span=9).mean()
|
||||
return macd - signal
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown method: {self.method}")
|
||||
|
||||
|
||||
class ReversalFactor(FactorBase):
|
||||
"""反转因子(定制实现)"""
|
||||
|
||||
name = "reversal"
|
||||
category = "reversal"
|
||||
|
||||
def __init__(self, method: str = 'rsi', period: int = 14, overbought: float = 70, oversold: float = 30):
|
||||
super().__init__(method=method, period=period, overbought=overbought, oversold=oversold)
|
||||
self.method = method
|
||||
self.period = period
|
||||
self.overbought = overbought
|
||||
self.oversold = oversold
|
||||
|
||||
def compute(self, data: pd.DataFrame) -> pd.Series:
|
||||
"""计算反转因子值"""
|
||||
if 'close' not in data.columns:
|
||||
raise ValueError("data must contain 'close' column")
|
||||
|
||||
prices = data['close']
|
||||
|
||||
if self.method == 'rsi':
|
||||
rsi = self._compute_rsi(prices, self.period)
|
||||
reversal_signal = np.where(
|
||||
rsi > self.overbought,
|
||||
-(rsi - self.overbought) / (100 - self.overbought),
|
||||
np.where(
|
||||
rsi < self.oversold,
|
||||
(self.oversold - rsi) / self.oversold,
|
||||
0
|
||||
)
|
||||
)
|
||||
return pd.Series(reversal_signal, index=prices.index)
|
||||
|
||||
elif self.method == 'kdj':
|
||||
return self._compute_kdj(data)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown method: {self.method}")
|
||||
|
||||
def _compute_rsi(self, prices: pd.Series, period: int) -> pd.Series:
|
||||
"""计算RSI"""
|
||||
delta = prices.diff()
|
||||
gain = delta.where(delta > 0, 0)
|
||||
loss = (-delta).where(delta < 0, 0)
|
||||
|
||||
avg_gain = gain.rolling(period).mean()
|
||||
avg_loss = loss.rolling(period).mean()
|
||||
|
||||
rs = avg_gain / avg_loss
|
||||
return 100 - (100 / (1 + rs))
|
||||
|
||||
def _compute_kdj(self, data: pd.DataFrame) -> pd.Series:
|
||||
"""计算KDJ反转信号"""
|
||||
low = data['low']
|
||||
high = data['high']
|
||||
close = data['close']
|
||||
|
||||
low_min = low.rolling(self.period).min()
|
||||
high_max = high.rolling(self.period).max()
|
||||
|
||||
rsv = (close - low_min) / (high_max - low_min) * 100
|
||||
|
||||
k = rsv.ewm(alpha=1/3).mean()
|
||||
d = k.ewm(alpha=1/3).mean()
|
||||
j = 3 * k - 2 * d
|
||||
|
||||
return j
|
||||
|
||||
|
||||
class VolatilityFactor(FactorBase):
|
||||
"""波动率因子(定制实现)"""
|
||||
|
||||
name = "volatility"
|
||||
category = "volatility"
|
||||
|
||||
def __init__(self, method: str = 'std', period: int = 20):
|
||||
super().__init__(method=method, period=period)
|
||||
self.method = method
|
||||
self.period = period
|
||||
|
||||
def compute(self, data: pd.DataFrame) -> pd.Series:
|
||||
"""计算波动率因子值"""
|
||||
if self.method == 'std':
|
||||
return data['close'].rolling(self.period).std()
|
||||
|
||||
elif self.method == 'atr':
|
||||
return self._compute_atr(data)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown method: {self.method}")
|
||||
|
||||
def _compute_atr(self, data: pd.DataFrame) -> pd.Series:
|
||||
"""计算ATR"""
|
||||
high = data['high']
|
||||
low = data['low']
|
||||
close = data['close']
|
||||
|
||||
prev_close = close.shift(1)
|
||||
tr = pd.concat([
|
||||
high - low,
|
||||
(high - prev_close).abs(),
|
||||
(low - prev_close).abs()
|
||||
], axis=1).max(axis=1)
|
||||
|
||||
return tr.rolling(self.period).mean()
|
||||
|
||||
|
||||
# 注册因子
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
FactorRegistry.register(TrendFactor)
|
||||
FactorRegistry.register(ReversalFactor)
|
||||
FactorRegistry.register(VolatilityFactor)
|
||||
143
archive/strategies/shared/risk/controls.py
Normal file
143
archive/strategies/shared/risk/controls.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""
|
||||
定制风控组件实现
|
||||
|
||||
这些风控组件继承framework.core.risk.RiskControl
|
||||
"""
|
||||
|
||||
from framework.risk import RiskControl, Position, CallbackHook
|
||||
|
||||
|
||||
class StopLossControl(RiskControl):
|
||||
"""止损控制(定制实现)"""
|
||||
|
||||
name = "stop_loss"
|
||||
|
||||
def __init__(self, threshold: float = -0.05, trailing: bool = False, trailing_percent: float = 0.03):
|
||||
super().__init__(threshold=threshold, trailing=trailing, trailing_percent=trailing_percent)
|
||||
self.threshold = threshold
|
||||
self.trailing = trailing
|
||||
self.trailing_percent = trailing_percent
|
||||
self._highest_price = {}
|
||||
|
||||
def check(self, position: Position, **kwargs) -> bool:
|
||||
"""检查是否触发止损"""
|
||||
if position is None:
|
||||
return True
|
||||
|
||||
if self.trailing:
|
||||
if position.code not in self._highest_price:
|
||||
self._highest_price[position.code] = position.entry_price
|
||||
self._highest_price[position.code] = max(
|
||||
self._highest_price[position.code],
|
||||
position.current_price
|
||||
)
|
||||
|
||||
if self.trailing:
|
||||
highest = self._highest_price[position.code]
|
||||
drawdown = (position.current_price - highest) / highest
|
||||
return drawdown > -self.trailing_percent
|
||||
else:
|
||||
return position.profit_ratio > self.threshold
|
||||
|
||||
def apply(self, position: Position):
|
||||
"""返回止损价格"""
|
||||
if self.trailing:
|
||||
highest = self._highest_price.get(position.code, position.entry_price)
|
||||
return highest * (1 - self.trailing_percent)
|
||||
else:
|
||||
return position.entry_price * (1 + self.threshold)
|
||||
|
||||
|
||||
class PositionLimitControl(RiskControl):
|
||||
"""仓位限制控制(定制实现)"""
|
||||
|
||||
name = "position_limit"
|
||||
|
||||
def __init__(self, max_position: float = 0.33, max_total: float = 1.0):
|
||||
super().__init__(max_position=max_position, max_total=max_total)
|
||||
self.max_position = max_position
|
||||
self.max_total = max_total
|
||||
|
||||
def check(self, position: Position, **kwargs) -> bool:
|
||||
"""检查仓位是否超限"""
|
||||
if position is None:
|
||||
return True
|
||||
|
||||
if position.weight > self.max_position:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def apply(self, position: Position):
|
||||
"""返回建议仓位"""
|
||||
return min(position.weight, self.max_position)
|
||||
|
||||
|
||||
class PremiumControl(RiskControl):
|
||||
"""溢价控制(定制实现)"""
|
||||
|
||||
name = "premium"
|
||||
|
||||
def __init__(self, threshold: float = 0.10, mode: str = 'filter'):
|
||||
super().__init__(threshold=threshold, mode=mode)
|
||||
self.threshold = threshold
|
||||
self.mode = mode
|
||||
|
||||
def check(self, position: Position, **kwargs) -> bool:
|
||||
"""检查溢价是否超限"""
|
||||
premium = kwargs.get('premium', 0)
|
||||
|
||||
if self.mode == 'filter':
|
||||
return premium <= self.threshold
|
||||
else:
|
||||
return True
|
||||
|
||||
def apply(self, position: Position):
|
||||
"""返回溢价惩罚系数"""
|
||||
if self.mode == 'penalize':
|
||||
return 0.5
|
||||
return None
|
||||
|
||||
|
||||
# 定制回调函数
|
||||
def premium_filter_callback(threshold: float = 0.10):
|
||||
"""溢价过滤回调(定制实现)"""
|
||||
def callback(code: str, price: float, **kwargs) -> bool:
|
||||
premium = kwargs.get('premium', 0)
|
||||
if premium > threshold:
|
||||
print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})")
|
||||
return False
|
||||
return True
|
||||
return callback
|
||||
|
||||
|
||||
def crash_filter_callback(lookback: int = 3, crash_threshold: float = 0.05):
|
||||
"""崩盘过滤回调(定制实现)"""
|
||||
def callback(code: str, price: float, **kwargs) -> bool:
|
||||
history = kwargs.get('history', None)
|
||||
if history is None:
|
||||
return True
|
||||
|
||||
recent = history.tail(lookback)
|
||||
if len(recent) < lookback:
|
||||
return True
|
||||
|
||||
returns = recent['close'].pct_change()
|
||||
min_return = returns.min()
|
||||
|
||||
if min_return < -crash_threshold:
|
||||
print(f"崩盘检测,拒绝入场: {code} (最大跌幅={min_return:.2%})")
|
||||
return False
|
||||
return True
|
||||
return callback
|
||||
|
||||
|
||||
def holding_time_stoploss_callback(day_5_stoploss: float = -0.05, day_10_stoploss: float = -0.03):
|
||||
"""持仓时间动态止损回调(定制实现)"""
|
||||
def callback(position: Position) -> float:
|
||||
if position.holding_days >= 10:
|
||||
return day_10_stoploss
|
||||
elif position.holding_days >= 5:
|
||||
return day_5_stoploss
|
||||
return -0.10
|
||||
return callback
|
||||
366
archive/strategies/shared/signals/selectors.py
Normal file
366
archive/strategies/shared/signals/selectors.py
Normal file
@@ -0,0 +1,366 @@
|
||||
"""
|
||||
定制信号生成器实现
|
||||
|
||||
这些信号生成器继承framework.core.signals.SignalGenerator
|
||||
"""
|
||||
|
||||
from framework.signals import SignalGenerator
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
|
||||
class TopNSelector(SignalGenerator):
|
||||
"""
|
||||
Top N选股器(定制实现)
|
||||
|
||||
用于轮动策略:
|
||||
- 按因子值排序,选出Top N标的
|
||||
- 支持分组选股(先类内竞争,再跨类排序)
|
||||
- 支持调仓阈值检查(新组合得分需超过当前组合一定比例才调仓)
|
||||
- V3: 支持动态阈值(短债动量作为过滤阈值)
|
||||
|
||||
参数:
|
||||
- select_num: 选中数量(默认3)
|
||||
- group_by: 分组键名(可选,如'market')
|
||||
- group_mapping: 分组映射字典(可选,{code: group})
|
||||
- top_per_group: 每组选中数量(默认1)
|
||||
- min_score: 最小得分阈值(可选,如0表示过滤负分)
|
||||
- rebalance_threshold: 调仓阈值(可选,新组合得分需超过当前组合X%才调仓)
|
||||
- rebalance_days: 最低调仓周期(可选,持仓至少N天才能调仓)
|
||||
- bond_threshold_config: V3动态阈值配置
|
||||
"""
|
||||
|
||||
mode = "top_n"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
select_num: int = 3,
|
||||
group_by: Optional[str] = None,
|
||||
group_mapping: Optional[Dict[str, str]] = None,
|
||||
top_per_group: int = 1,
|
||||
min_score: Optional[float] = None,
|
||||
rebalance_threshold: float = 0.0,
|
||||
rebalance_days: int = 1,
|
||||
bond_threshold_config: Optional[Dict] = None
|
||||
):
|
||||
super().__init__(
|
||||
select_num=select_num,
|
||||
group_by=group_by,
|
||||
group_mapping=group_mapping,
|
||||
top_per_group=top_per_group,
|
||||
min_score=min_score,
|
||||
rebalance_threshold=rebalance_threshold,
|
||||
rebalance_days=rebalance_days
|
||||
)
|
||||
self.select_num = select_num
|
||||
self.group_by = group_by
|
||||
self.group_mapping = group_mapping or {}
|
||||
self.top_per_group = top_per_group
|
||||
self.min_score = min_score
|
||||
self.rebalance_threshold = rebalance_threshold
|
||||
self.rebalance_days = rebalance_days
|
||||
self.bond_threshold_config = bond_threshold_config or {}
|
||||
|
||||
def _get_dynamic_threshold(self, scores: Dict[str, float]) -> float:
|
||||
"""获取动态阈值:短债动量 × ratio,无数据时退化为 min_score
|
||||
|
||||
V3动态阈值逻辑:
|
||||
- 若bond_threshold.enabled=true,阈值 = 短债动量 × ratio
|
||||
- 若短债无数据或动量<0,退化为固定min_score
|
||||
- 若enabled=false,退化为固定min_score
|
||||
"""
|
||||
cfg = self.bond_threshold_config
|
||||
if not cfg.get('enabled', False):
|
||||
return self.min_score if self.min_score is not None else 0.0
|
||||
|
||||
bond_code = cfg.get('bond_code', '931862.CSI')
|
||||
ratio = cfg.get('ratio', 1.0)
|
||||
|
||||
bond_score = scores.get(bond_code, None)
|
||||
if bond_score is None or bond_score < 0:
|
||||
return self.min_score if self.min_score is not None else 0.0
|
||||
|
||||
return bond_score * ratio
|
||||
|
||||
def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""生成Top N选股信号(支持调仓周期控制)"""
|
||||
result = pd.DataFrame(index=factor_data.index)
|
||||
|
||||
factor_cols = self._get_factor_columns(factor_data)
|
||||
|
||||
if not factor_cols:
|
||||
result['signal'] = ''
|
||||
return result
|
||||
|
||||
# Step 1: 每日目标组合(不考虑调仓周期)
|
||||
daily_target = []
|
||||
for date in factor_data.index:
|
||||
row = factor_data.loc[date]
|
||||
|
||||
# 提取得分
|
||||
scores = {}
|
||||
for col in factor_cols:
|
||||
score = row[col]
|
||||
if pd.notna(score):
|
||||
scores[col] = score
|
||||
|
||||
# V3: 过滤前检查bond是否有因子数据(用于填充守卫)
|
||||
cfg = self.bond_threshold_config
|
||||
bond_code = cfg.get('bond_code', '931862.CSI') if cfg.get('enabled') else None
|
||||
bond_has_data = bond_code in scores # scores此时是过滤前的完整字典
|
||||
|
||||
# V3: 动态阈值过滤(替代固定 min_score)
|
||||
threshold = self._get_dynamic_threshold(scores)
|
||||
scores = {k: v for k, v in scores.items() if v >= threshold}
|
||||
|
||||
# 分组选股或全局选股
|
||||
if self.group_mapping:
|
||||
selected = self._grouped_selection(scores, bond_has_data)
|
||||
else:
|
||||
selected = self._global_top_n(scores)
|
||||
|
||||
daily_target.append(','.join(selected) if selected else '')
|
||||
|
||||
# Step 2: 逐日生成信号(调仓周期控制)
|
||||
signals = self._apply_rebalance_control(daily_target, factor_data)
|
||||
|
||||
result['signal_raw'] = daily_target # 每日目标组合
|
||||
result['signal'] = signals
|
||||
|
||||
# T+1执行:信号向后移位1天
|
||||
result['signal'] = result['signal'].shift(1)
|
||||
|
||||
return result
|
||||
|
||||
def _get_factor_columns(self, data: pd.DataFrame) -> List[str]:
|
||||
"""获取因子列名"""
|
||||
exclude_cols = ['signal', 'signal_raw', 'group_info', 'combined', 'open', 'high', 'low', 'close', 'volume']
|
||||
return [col for col in data.columns if col not in exclude_cols and not col.endswith('_weighted')]
|
||||
|
||||
def _global_top_n(self, scores: Dict[str, float]) -> List[str]:
|
||||
"""全局Top N选股"""
|
||||
if not scores:
|
||||
return []
|
||||
|
||||
sorted_items = sorted(scores.items(), key=lambda x: x[1], reverse=True)
|
||||
return [item[0] for item in sorted_items[:self.select_num]]
|
||||
|
||||
def _apply_rebalance_control(self, daily_target: List[str], factor_data: pd.DataFrame) -> List[str]:
|
||||
"""应用调仓周期控制"""
|
||||
signals = []
|
||||
current_held = None
|
||||
last_rebalance_idx = 0
|
||||
|
||||
for i, target in enumerate(daily_target):
|
||||
# 初始持仓为空,等待第一个有效信号
|
||||
if current_held is None:
|
||||
if not target:
|
||||
signals.append('')
|
||||
continue
|
||||
current_held = target
|
||||
last_rebalance_idx = i
|
||||
signals.append(current_held)
|
||||
continue
|
||||
|
||||
# 检查调仓周期
|
||||
days_since = i - last_rebalance_idx
|
||||
if days_since < self.rebalance_days:
|
||||
# 未达到最低调仓周期,保持当前持仓
|
||||
signals.append(current_held)
|
||||
continue
|
||||
|
||||
# 检查是否应该调仓
|
||||
if target: # 目标信号有效
|
||||
should = self._check_rebalance(
|
||||
factor_data.iloc[i],
|
||||
current_held,
|
||||
target,
|
||||
self._get_factor_columns(factor_data)
|
||||
)
|
||||
if should:
|
||||
current_held = target
|
||||
last_rebalance_idx = i
|
||||
else:
|
||||
# V3: target为空时保持当前持仓不变(与独立脚本行为一致)
|
||||
# 在有bond数据的时期target不会为空(会被bond填充)
|
||||
# target为空仅发生在2002-2007无bond数据期
|
||||
# 保持旧持仓比突然清仓更平滑
|
||||
pass
|
||||
|
||||
signals.append(current_held)
|
||||
|
||||
return signals
|
||||
|
||||
def _check_rebalance(
|
||||
self,
|
||||
row: pd.Series,
|
||||
current_held: str,
|
||||
target: str,
|
||||
factor_cols: List[str]
|
||||
) -> bool:
|
||||
"""检查是否应该调仓(得分阈值检查)"""
|
||||
# 提取当前持仓和目标持仓的代码
|
||||
old_codes = [c for c in current_held.split(',') if c]
|
||||
new_codes = [c for c in target.split(',') if c]
|
||||
|
||||
if not new_codes or not old_codes:
|
||||
return True
|
||||
|
||||
if set(new_codes) == set(old_codes):
|
||||
return False # 组合完全相同,不调仓
|
||||
|
||||
# 计算新旧组合的总得分
|
||||
# 修复: 按实际份数计算得分(支持重复代码如"931862.CSI,931862.CSI")
|
||||
old_total = sum(float(row.get(c, 0)) for c in old_codes)
|
||||
new_total = sum(float(row.get(c, 0)) for c in new_codes)
|
||||
|
||||
# 新组合得分需超过当前组合一定比例才调仓
|
||||
# 即使 threshold=0,也要确保 new_total >= old_total
|
||||
if old_total > 0:
|
||||
return (new_total / old_total - 1) >= self.rebalance_threshold
|
||||
|
||||
return new_total > 0
|
||||
|
||||
def _grouped_selection(self, scores: Dict[str, float], bond_has_data: bool = True) -> List[str]:
|
||||
"""V3分组选股:BOND不参与竞争,空余仓位填充短债
|
||||
|
||||
V3逻辑:
|
||||
1. BOND大类标的不参与冠军竞争(它是阈值,不是候选)
|
||||
2. 选出不足 select_num 只时,用短债填充
|
||||
3. bond无数据时(2002-2007)不填充
|
||||
4. V2退化:若bond_threshold.enabled=false,BOND正常参与竞争
|
||||
"""
|
||||
if not scores:
|
||||
return []
|
||||
|
||||
cfg = self.bond_threshold_config
|
||||
bond_code = cfg.get('bond_code', '931862.CSI') if cfg.get('enabled') else None
|
||||
|
||||
# 建立 group -> (code, score) 映射
|
||||
# V3: 排除 BOND 大类(它不参与竞争)
|
||||
group_champions = {}
|
||||
for code, score in scores.items():
|
||||
group = self.group_mapping.get(code, 'default')
|
||||
|
||||
# V3: BOND大类不参与竞争
|
||||
if cfg.get('enabled') and group == 'BOND':
|
||||
continue
|
||||
|
||||
if group not in group_champions or score > group_champions[group][1]:
|
||||
group_champions[group] = (code, score)
|
||||
|
||||
# 跨类排序取 Top N
|
||||
sorted_champions = sorted(group_champions.values(), key=lambda x: x[1], reverse=True)
|
||||
selected = [code for code, score in sorted_champions[:self.select_num]]
|
||||
|
||||
# V3: 空余仓位填充短债
|
||||
# 短债填充是防御机制,但需要有数据才能填充
|
||||
# bond有数据(含负值)→ 填充 ✓(防御机制不受动量影响)
|
||||
# bond无数据(NaN)→ 不填充 ✓(2002-2007正常退化)
|
||||
if cfg.get('fill_bond', False) and bond_code and bond_has_data:
|
||||
n_bond_slots = self.select_num - len(selected)
|
||||
if n_bond_slots > 0:
|
||||
for _ in range(n_bond_slots):
|
||||
selected.append(bond_code)
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
class TrendFollower(SignalGenerator):
|
||||
"""趋势跟随器(定制实现)"""
|
||||
|
||||
mode = "trend"
|
||||
|
||||
def __init__(self, entry_threshold: float = 0.02, exit_threshold: float = -0.02, select_num: int = 1):
|
||||
super().__init__(entry_threshold=entry_threshold, exit_threshold=exit_threshold, select_num=select_num)
|
||||
self.entry_threshold = entry_threshold
|
||||
self.exit_threshold = exit_threshold
|
||||
self.select_num = select_num
|
||||
|
||||
def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""生成趋势跟随信号"""
|
||||
result = pd.DataFrame(index=factor_data.index)
|
||||
|
||||
factor_cols = self._get_factor_columns(factor_data)
|
||||
|
||||
for col in factor_cols:
|
||||
trend_strength = factor_data[col]
|
||||
|
||||
result[f'{col}_entry'] = trend_strength > self.entry_threshold
|
||||
result[f'{col}_exit'] = trend_strength < self.exit_threshold
|
||||
|
||||
signals = []
|
||||
for date in result.index:
|
||||
entry_signals = []
|
||||
for col in factor_cols:
|
||||
if result.loc[date, f'{col}_entry']:
|
||||
score = factor_data.loc[date, col]
|
||||
if pd.notna(score):
|
||||
entry_signals.append((col, score))
|
||||
|
||||
entry_signals.sort(key=lambda x: x[1], reverse=True)
|
||||
selected = [item[0] for item in entry_signals[:self.select_num]]
|
||||
signals.append(','.join(selected) if selected else '')
|
||||
|
||||
result['signal'] = signals
|
||||
result['signal'] = result['signal'].shift(1)
|
||||
|
||||
return result
|
||||
|
||||
def _get_factor_columns(self, data: pd.DataFrame) -> List[str]:
|
||||
"""获取因子列名"""
|
||||
exclude_cols = ['signal', 'signal_raw', 'combined', 'open', 'high', 'low', 'close', 'volume']
|
||||
return [col for col in data.columns if col not in exclude_cols and not col.endswith('_weighted')]
|
||||
|
||||
|
||||
class ReversalTrader(SignalGenerator):
|
||||
"""反转交易器(定制实现)"""
|
||||
|
||||
mode = "reversal"
|
||||
|
||||
def __init__(self, overbought: float = 70, oversold: float = 30, reversal_threshold: float = 0.1):
|
||||
super().__init__(overbought=overbought, oversold=oversold, reversal_threshold=reversal_threshold)
|
||||
self.overbought = overbought
|
||||
self.oversold = oversold
|
||||
self.reversal_threshold = reversal_threshold
|
||||
|
||||
def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""生成反转交易信号"""
|
||||
result = pd.DataFrame(index=factor_data.index)
|
||||
|
||||
factor_cols = self._get_factor_columns(factor_data)
|
||||
|
||||
for col in factor_cols:
|
||||
reversal_signal = factor_data[col]
|
||||
|
||||
result[f'{col}_buy'] = reversal_signal > self.reversal_threshold
|
||||
result[f'{col}_sell'] = reversal_signal < -self.reversal_threshold
|
||||
|
||||
signals = []
|
||||
for date in result.index:
|
||||
buy_signals = []
|
||||
sell_signals = []
|
||||
|
||||
for col in factor_cols:
|
||||
if result.loc[date, f'{col}_buy']:
|
||||
buy_signals.append(col)
|
||||
if result.loc[date, f'{col}_sell']:
|
||||
sell_signals.append(col)
|
||||
|
||||
if buy_signals:
|
||||
signals.append(f"BUY:{','.join(buy_signals)}")
|
||||
elif sell_signals:
|
||||
signals.append(f"SELL:{','.join(sell_signals)}")
|
||||
else:
|
||||
signals.append('')
|
||||
|
||||
result['signal'] = signals
|
||||
result['signal'] = result['signal'].shift(1)
|
||||
|
||||
return result
|
||||
|
||||
def _get_factor_columns(self, data: pd.DataFrame) -> List[str]:
|
||||
"""获取因子列名"""
|
||||
exclude_cols = ['signal', 'signal_raw', 'combined', 'open', 'high', 'low', 'close', 'volume']
|
||||
return [col for col in data.columns if col not in exclude_cols and not col.endswith('_weighted')]
|
||||
7
archive/strategies/us_rotation/__init__.py
Normal file
7
archive/strategies/us_rotation/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
美股轮动策略模块
|
||||
"""
|
||||
|
||||
from .strategy import USRotationStrategy
|
||||
|
||||
__all__ = ['USRotationStrategy']
|
||||
194
archive/strategies/us_rotation/config.yaml
Normal file
194
archive/strategies/us_rotation/config.yaml
Normal file
@@ -0,0 +1,194 @@
|
||||
# 美股轮动策略配置
|
||||
|
||||
# ==================== 候选池配置 ====================
|
||||
code_list:
|
||||
# 科技巨头
|
||||
"AAPL":
|
||||
name: "Apple"
|
||||
sector: "Technology"
|
||||
"ADBE":
|
||||
name: "Adobe"
|
||||
sector: "Technology"
|
||||
"AMD":
|
||||
name: "AMD"
|
||||
sector: "Technology"
|
||||
"AMZN":
|
||||
name: "Amazon"
|
||||
sector: "Technology"
|
||||
"ASML":
|
||||
name: "ASML"
|
||||
sector: "Technology"
|
||||
"AVGO":
|
||||
name: "Broadcom"
|
||||
sector: "Technology"
|
||||
"CRM":
|
||||
name: "Salesforce"
|
||||
sector: "Technology"
|
||||
"CRWD":
|
||||
name: "CrowdStrike"
|
||||
sector: "Technology"
|
||||
"CSCO":
|
||||
name: "Cisco"
|
||||
sector: "Technology"
|
||||
"FICO":
|
||||
name: "FICO"
|
||||
sector: "Technology"
|
||||
"GOOGL":
|
||||
name: "Google"
|
||||
sector: "Technology"
|
||||
"INTC":
|
||||
name: "Intel"
|
||||
sector: "Technology"
|
||||
"KLAC":
|
||||
name: "KLA"
|
||||
sector: "Technology"
|
||||
"LRCX":
|
||||
name: "Lam Research"
|
||||
sector: "Technology"
|
||||
"META":
|
||||
name: "Meta"
|
||||
sector: "Technology"
|
||||
"MSFT":
|
||||
name: "Microsoft"
|
||||
sector: "Technology"
|
||||
"MU":
|
||||
name: "Micron"
|
||||
sector: "Technology"
|
||||
"NET":
|
||||
name: "Cloudflare"
|
||||
sector: "Technology"
|
||||
"NFLX":
|
||||
name: "Netflix"
|
||||
sector: "Technology"
|
||||
"NVDA":
|
||||
name: "NVIDIA"
|
||||
sector: "Technology"
|
||||
"ORCL":
|
||||
name: "Oracle"
|
||||
sector: "Technology"
|
||||
"PANW":
|
||||
name: "Palo Alto"
|
||||
sector: "Technology"
|
||||
"PLTR":
|
||||
name: "Palantir"
|
||||
sector: "Technology"
|
||||
"QCOM":
|
||||
name: "Qualcomm"
|
||||
sector: "Technology"
|
||||
"SNOW":
|
||||
name: "Snowflake"
|
||||
sector: "Technology"
|
||||
"TSLA":
|
||||
name: "Tesla"
|
||||
sector: "Technology"
|
||||
"TSM":
|
||||
name: "TSMC"
|
||||
sector: "Technology"
|
||||
|
||||
# 金融
|
||||
"AXP":
|
||||
name: "American Express"
|
||||
sector: "Financial"
|
||||
"BAC":
|
||||
name: "Bank of America"
|
||||
sector: "Financial"
|
||||
"C":
|
||||
name: "Citigroup"
|
||||
sector: "Financial"
|
||||
"GS":
|
||||
name: "Goldman Sachs"
|
||||
sector: "Financial"
|
||||
"JPM":
|
||||
name: "JPMorgan"
|
||||
sector: "Financial"
|
||||
"MA":
|
||||
name: "Mastercard"
|
||||
sector: "Financial"
|
||||
"MS":
|
||||
name: "Morgan Stanley"
|
||||
sector: "Financial"
|
||||
|
||||
# 消费零售
|
||||
"COST":
|
||||
name: "Costco"
|
||||
sector: "Consumer"
|
||||
"LULU":
|
||||
name: "Lululemon"
|
||||
sector: "Consumer"
|
||||
"PDD":
|
||||
name: "PDD Holdings"
|
||||
sector: "Consumer"
|
||||
"SHOP":
|
||||
name: "Shopify"
|
||||
sector: "Consumer"
|
||||
|
||||
# 医药健康
|
||||
"LLY":
|
||||
name: "Eli Lilly"
|
||||
sector: "Healthcare"
|
||||
"NVO":
|
||||
name: "Novo Nordisk"
|
||||
sector: "Healthcare"
|
||||
|
||||
# 其他
|
||||
"CAT":
|
||||
name: "Caterpillar"
|
||||
sector: "Industrial"
|
||||
"COIN":
|
||||
name: "Coinbase"
|
||||
sector: "Crypto"
|
||||
"CRCL":
|
||||
name: "Circle"
|
||||
sector: "Crypto"
|
||||
"FUTU":
|
||||
name: "Futu"
|
||||
sector: "Financial"
|
||||
"HOOD":
|
||||
name: "Robinhood"
|
||||
sector: "Financial"
|
||||
"SAP":
|
||||
name: "SAP"
|
||||
sector: "Technology"
|
||||
"SCCO":
|
||||
name: "Southern Copper"
|
||||
sector: "Materials"
|
||||
|
||||
# ==================== 基准配置 ====================
|
||||
benchmark:
|
||||
code: "NDX"
|
||||
name: "纳斯达克100"
|
||||
|
||||
# ==================== 回测参数 ====================
|
||||
start_date: "2016-01-01"
|
||||
|
||||
# ==================== 因子参数 ====================
|
||||
# 动量窗口(天数)
|
||||
n_days: 250
|
||||
# 因子类型
|
||||
factor_type: "momentum"
|
||||
|
||||
# ==================== 轮动参数 ====================
|
||||
# 不分组,直接选 Top N
|
||||
diversified: false
|
||||
select_num: 5
|
||||
|
||||
# ==================== 调仓控制 ====================
|
||||
# 每日调仓
|
||||
rebalance_days: 1
|
||||
# 调仓阈值:新组合得分超过当前组合 X% 才触发调仓
|
||||
rebalance_threshold: 0.0
|
||||
# 交易成本(双边)
|
||||
trade_cost: 0.001
|
||||
|
||||
# ==================== 数据缓存 ====================
|
||||
use_cache: true
|
||||
|
||||
# ==================== 数据源配置 ====================
|
||||
# SSH 隧道配置(用于 yfinance)
|
||||
ssh_tunnel:
|
||||
enabled: true
|
||||
host: "8.218.167.69"
|
||||
port: 22
|
||||
username: "root"
|
||||
key_path: "hk_ecs.pem"
|
||||
local_port: 1080
|
||||
354
archive/strategies/us_rotation/strategy.py
Normal file
354
archive/strategies/us_rotation/strategy.py
Normal file
@@ -0,0 +1,354 @@
|
||||
"""
|
||||
美股轮动策略
|
||||
|
||||
纯美股轮动策略,使用动量因子选股
|
||||
特点:
|
||||
- 全部使用 yfinance 数据源
|
||||
- 美股交易日历
|
||||
- 不分组,直接选 Top 5
|
||||
- 基准为纳指 NDX
|
||||
"""
|
||||
|
||||
import sys
|
||||
import yaml
|
||||
import time
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
# 添加项目根目录
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from datasource.yfinance_source import YFinanceSource
|
||||
from datasource.ssh_tunnel import SSHTunnelManager
|
||||
from strategies.shared.factors.momentum import MomentumFactor
|
||||
from strategies.shared.signals.selectors import TopNSelector
|
||||
from framework.execution import BacktestExecutor
|
||||
|
||||
|
||||
class USRotationStrategy:
|
||||
"""美股轮动策略"""
|
||||
|
||||
def __init__(self, config_path: str = None, config: dict = None):
|
||||
"""
|
||||
初始化策略
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径
|
||||
config: 配置字典(可选)
|
||||
"""
|
||||
if config_path:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
self.config = yaml.safe_load(f)
|
||||
elif config:
|
||||
self.config = config
|
||||
else:
|
||||
raise ValueError("需要提供 config_path 或 config")
|
||||
|
||||
# 结束日期默认今天
|
||||
if not self.config.get('end_date'):
|
||||
self.config['end_date'] = datetime.now().strftime('%Y-%m-%d')
|
||||
|
||||
# 应用配置
|
||||
self._apply_config()
|
||||
|
||||
# 初始化因子
|
||||
self._factor = MomentumFactor(
|
||||
n_days=self.n_days,
|
||||
weighted=True,
|
||||
crash_filter=True
|
||||
)
|
||||
|
||||
# 初始化选择器(不分组,直接选 Top N)
|
||||
self._selector = TopNSelector(
|
||||
select_num=self.select_num,
|
||||
rebalance_days=self.rebalance_days,
|
||||
rebalance_threshold=self.rebalance_threshold
|
||||
)
|
||||
|
||||
# 数据源(延迟初始化)
|
||||
self._yfinance: Optional[YFinanceSource] = None
|
||||
self._tunnel: Optional[SSHTunnelManager] = None
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, config_path: str) -> 'USRotationStrategy':
|
||||
"""从 YAML 文件创建策略实例"""
|
||||
return cls(config_path=config_path)
|
||||
|
||||
def _apply_config(self):
|
||||
"""应用配置参数"""
|
||||
self.select_num = self.config.get('select_num', 5)
|
||||
self.n_days = self.config.get('n_days', 250)
|
||||
self.rebalance_days = self.config.get('rebalance_days', 1)
|
||||
self.rebalance_threshold = self.config.get('rebalance_threshold', 0.0)
|
||||
self.trade_cost = self.config.get('trade_cost', 0.001)
|
||||
self.start_date = self.config.get('start_date', '2016-01-01')
|
||||
self.end_date = self.config['end_date']
|
||||
self.use_cache = self.config.get('use_cache', True)
|
||||
|
||||
def _start_tunnel(self) -> bool:
|
||||
"""启动 SSH 隧道"""
|
||||
ssh_config = self.config.get('ssh_tunnel', {})
|
||||
if not ssh_config.get('enabled', False):
|
||||
return True
|
||||
|
||||
self._tunnel = SSHTunnelManager(ssh_config)
|
||||
return self._tunnel.start()
|
||||
|
||||
def _stop_tunnel(self):
|
||||
"""停止 SSH 隧道"""
|
||||
if self._tunnel:
|
||||
self._tunnel.stop()
|
||||
self._tunnel = None
|
||||
|
||||
def _init_yfinance(self):
|
||||
"""初始化 YFinance 数据源"""
|
||||
if self._yfinance is None:
|
||||
self._yfinance = YFinanceSource(use_ssh_tunnel=True)
|
||||
|
||||
def fetch_data(self) -> Dict:
|
||||
"""获取数据(全部使用 yfinance)"""
|
||||
print("\n" + "=" * 60)
|
||||
print("获取美股数据")
|
||||
print("=" * 60)
|
||||
|
||||
code_list_config = self.config.get('code_list', {})
|
||||
benchmark_config = self.config.get('benchmark', {})
|
||||
benchmark_code = benchmark_config.get('code', 'NDX')
|
||||
|
||||
if not code_list_config:
|
||||
raise ValueError("配置中未找到 code_list")
|
||||
|
||||
codes = list(code_list_config.keys())
|
||||
print(f"标的池: {len(codes)} 只股票")
|
||||
print(f"基准: {benchmark_code}")
|
||||
print(f"时间范围: {self.start_date} ~ {self.end_date}")
|
||||
|
||||
# 启动 SSH 隓道
|
||||
print("\n启动 SSH 隧道...")
|
||||
if not self._start_tunnel():
|
||||
raise RuntimeError("SSH 隧道启动失败")
|
||||
|
||||
self._init_yfinance()
|
||||
|
||||
# 获取数据
|
||||
all_data: Dict[str, pd.DataFrame] = {}
|
||||
valid_codes: List[str] = []
|
||||
|
||||
print("\n获取股票数据...")
|
||||
for i, code in enumerate(codes):
|
||||
print(f" [{i+1}/{len(codes)}] {code}...", end=" ")
|
||||
|
||||
try:
|
||||
df = self._yfinance.fetch(code, self.start_date, self.end_date)
|
||||
time.sleep(0.5) # 避免限流
|
||||
|
||||
if df is not None and len(df) >= self.n_days:
|
||||
all_data[code] = df
|
||||
valid_codes.append(code)
|
||||
print(f"✓ {len(df)} 条")
|
||||
else:
|
||||
print(f"✗ 数据不足")
|
||||
except Exception as e:
|
||||
print(f"✗ 失败: {e}")
|
||||
|
||||
# 获取基准数据
|
||||
print(f"\n获取基准 {benchmark_code}...", end=" ")
|
||||
try:
|
||||
benchmark_df = self._yfinance.fetch(benchmark_code, self.start_date, self.end_date)
|
||||
if benchmark_df is not None and len(benchmark_df) > 0:
|
||||
print(f"✓ {len(benchmark_df)} 条")
|
||||
else:
|
||||
print(f"✗ 基准数据获取失败")
|
||||
benchmark_df = None
|
||||
except Exception as e:
|
||||
print(f"✗ 失败: {e}")
|
||||
benchmark_df = None
|
||||
|
||||
# 停止隧道
|
||||
self._stop_tunnel()
|
||||
|
||||
print(f"\n数据获取完成: {len(valid_codes)}/{len(codes)} 只有效")
|
||||
|
||||
return {
|
||||
'stock_data': all_data,
|
||||
'valid_codes': valid_codes,
|
||||
'benchmark': benchmark_df,
|
||||
'benchmark_code': benchmark_code
|
||||
}
|
||||
|
||||
def compute_factors(self, data: Dict) -> pd.DataFrame:
|
||||
"""计算动量因子"""
|
||||
print("\n" + "=" * 60)
|
||||
print("计算动量因子")
|
||||
print("=" * 60)
|
||||
|
||||
stock_data = data['stock_data']
|
||||
valid_codes = data['valid_codes']
|
||||
|
||||
factor_values: Dict[str, pd.Series] = {}
|
||||
|
||||
for code in valid_codes:
|
||||
df = stock_data[code]
|
||||
|
||||
if 'close' not in df.columns:
|
||||
continue
|
||||
|
||||
# 数据长度检查
|
||||
if len(df) < self.n_days:
|
||||
print(f" ⚠ {code}: 数据不足 {self.n_days} 天,跳过")
|
||||
continue
|
||||
|
||||
# MomentumFactor.compute 需要DataFrame
|
||||
factor_series = self._factor.compute(df)
|
||||
|
||||
if factor_series is not None and len(factor_series) > 0:
|
||||
factor_values[code] = factor_series
|
||||
|
||||
# 合成 DataFrame
|
||||
factor_df = pd.DataFrame(factor_values)
|
||||
|
||||
print(f"\n因子计算完成: {len(factor_df.columns)} 只标的")
|
||||
print(f" 窗口: {self.n_days} 天")
|
||||
if len(factor_df) > 0:
|
||||
print(f" 日期范围: {factor_df.index.min()} ~ {factor_df.index.max()}")
|
||||
|
||||
return factor_df
|
||||
|
||||
def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""生成轮动信号(不分组,直接选 Top 5)"""
|
||||
print("\n" + "=" * 60)
|
||||
print("生成轮动信号")
|
||||
print("=" * 60)
|
||||
|
||||
# 不分组,直接对因子排序选 Top 5
|
||||
# TopNSelector.generate 会自动处理调仓周期和T+1
|
||||
signals_df = self._selector.generate(factor_df)
|
||||
|
||||
print(f"\n信号生成完成:")
|
||||
print(f" 选股数量: {self.select_num}")
|
||||
if 'signal' in signals_df.columns:
|
||||
valid_signals = signals_df[signals_df['signal'] != '']
|
||||
print(f" 有效信号天数: {len(valid_signals)}")
|
||||
|
||||
return signals_df
|
||||
|
||||
def run_backtest(self, data: Dict = None, save_path: str = None) -> Dict:
|
||||
"""运行回测"""
|
||||
print("\n" + "=" * 60)
|
||||
print("美股动量轮动策略 回测")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. 获取数据
|
||||
if data is None:
|
||||
data = self.fetch_data()
|
||||
|
||||
valid_codes = data['valid_codes']
|
||||
|
||||
# 2. 计算因子
|
||||
factor_df = self.compute_factors(data)
|
||||
|
||||
# 3. 生成信号
|
||||
signals = self.generate_signals(factor_df)
|
||||
|
||||
# 4. 构建收益数据(BacktestExecutor期望列名格式:日收益率_{code})
|
||||
print("\n构建收益数据...")
|
||||
stock_data = data['stock_data']
|
||||
|
||||
# 计算日收益率,列名格式为 '日收益率_{code}'
|
||||
returns_data: Dict[str, pd.Series] = {}
|
||||
for code in valid_codes:
|
||||
if code in stock_data:
|
||||
df = stock_data[code]
|
||||
if 'close' in df.columns:
|
||||
returns_data[f'日收益率_{code}'] = df['close'].pct_change()
|
||||
|
||||
returns_df = pd.DataFrame(returns_data)
|
||||
|
||||
# 对齐日期
|
||||
common_dates = signals.index.intersection(returns_df.index)
|
||||
signals = signals.loc[common_dates]
|
||||
returns_df = returns_df.loc[common_dates]
|
||||
|
||||
# 5. 执行回测
|
||||
print("\n执行回测...")
|
||||
executor = BacktestExecutor(
|
||||
initial_capital=100,
|
||||
trade_cost=self.trade_cost,
|
||||
select_num=self.select_num
|
||||
)
|
||||
|
||||
portfolio = executor.execute(signals, returns_df)
|
||||
|
||||
# 6. 计算基准收益
|
||||
benchmark_df = data['benchmark']
|
||||
benchmark_code = data['benchmark_code']
|
||||
|
||||
if benchmark_df is not None and 'close' in benchmark_df.columns:
|
||||
benchmark_returns = benchmark_df['close'].pct_change()
|
||||
# 对齐日期
|
||||
benchmark_returns = benchmark_returns.loc[common_dates]
|
||||
|
||||
# 7. 输出结果
|
||||
if hasattr(portfolio, 'backtest_result') and portfolio.backtest_result is not None:
|
||||
result = portfolio.backtest_result
|
||||
|
||||
# 策略净值(DataFrame列)
|
||||
if '策略净值' in result.columns:
|
||||
strategy_nav = result['策略净值'].values
|
||||
final_nav = strategy_nav[-1] if len(strategy_nav) > 0 else 100
|
||||
total_return = (final_nav - 1) * 100 # 净值归一化起点为1
|
||||
else:
|
||||
final_nav = 100
|
||||
total_return = 0
|
||||
|
||||
# 基准收益
|
||||
if benchmark_df is not None and 'close' in benchmark_df.columns:
|
||||
benchmark_start = benchmark_df['close'].iloc[0]
|
||||
benchmark_end = benchmark_df['close'].iloc[-1]
|
||||
benchmark_return = (benchmark_end / benchmark_start - 1) * 100
|
||||
else:
|
||||
benchmark_return = 0
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("回测结果")
|
||||
print("=" * 60)
|
||||
print(f"策略最终净值: {final_nav:.2f}")
|
||||
print(f"策略总收益: {total_return:.2f}%")
|
||||
print(f"基准 ({benchmark_code}) 收益: {benchmark_return:.2f}%")
|
||||
print(f"超额收益: {total_return - benchmark_return:.2f}%")
|
||||
print(f"交易成本: {self.trade_cost * 100:.1f}%")
|
||||
|
||||
# 保存结果
|
||||
if save_path:
|
||||
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 保存净值曲线
|
||||
if '策略净值' in result.columns:
|
||||
nav_df = pd.DataFrame({
|
||||
'date': result.index,
|
||||
'strategy_nav': result['策略净值'].values
|
||||
})
|
||||
if benchmark_df is not None and 'close' in benchmark_df.columns:
|
||||
# 重建基准净值
|
||||
benchmark_nav = (benchmark_df['close'].pct_change() + 1).cumprod()
|
||||
nav_df['benchmark_nav'] = benchmark_nav.reindex(result.index, method='ffill').values
|
||||
nav_df.to_csv(f"{save_path}_nav.csv", index=False)
|
||||
|
||||
# 保存信号
|
||||
signals.to_csv(f"{save_path}_signals.csv")
|
||||
|
||||
print(f"\n报告保存: {save_path}_*.csv")
|
||||
|
||||
return {
|
||||
'final_nav': final_nav,
|
||||
'total_return': total_return,
|
||||
'benchmark_return': benchmark_return,
|
||||
'excess_return': total_return - benchmark_return,
|
||||
'signals': signals,
|
||||
'result': result
|
||||
}
|
||||
|
||||
return {'signals': signals}
|
||||
Reference in New Issue
Block a user