fix: 数据源路由修复与因子计算改进
1. 修复期货路由逻辑:NYMEX期货(.NYM)走YFinance而非Tushare
2. 添加SSH隧道路径修复(原引擎)
3. 因子计算只使用close列(处理部分指数只有收盘价的情况)
4. 添加数据不足和缺失率剔除日志

收益对比:
- 原引擎(剔除国债): 累计1804%, 调仓459次
- 新框架: 累计772%, 调仓1276次

差异原因待查:
- 国债剔除逻辑不同
- 调仓频率差异
This commit is contained in:
@@ -34,8 +34,8 @@ class SSHTunnelManager:
|
||||
# 处理 key_path:如果是相对路径,转换为绝对路径
|
||||
key_path = config.get("key_path", "")
|
||||
if key_path and not os.path.isabs(key_path):
|
||||
# 相对于项目根目录
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
# 相对于项目根目录(需要跳5层:datasource->core->legacy_core->archive->etf)
|
||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
key_path = str(project_root / key_path)
|
||||
self.key_path = key_path
|
||||
print(f"SSH 私钥路径: {self.key_path}")
|
||||
|
||||
@@ -32,8 +32,8 @@ class SSHTunnelManager:
|
||||
# 处理 key_path:如果是相对路径,转换为绝对路径
|
||||
key_path = config.get("key_path", "")
|
||||
if key_path and not os.path.isabs(key_path):
|
||||
# 相对于项目根目录
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
# 相对于项目根目录(需要跳5层:datasource->core->legacy_core->archive->etf)
|
||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
key_path = str(project_root / key_path)
|
||||
self.key_path = key_path
|
||||
|
||||
|
||||
@@ -113,7 +113,8 @@ class HybridDataSource:
|
||||
Optional[pd.DataFrame], # etf_nav_data: ETF净值
|
||||
Optional[pd.DataFrame], # benchmark_data: 基准数据
|
||||
List[str], # valid_codes: 有效代码列表
|
||||
Dict[str, pd.DataFrame] # index_ohlcv_data: 原始OHLCV数据
|
||||
Dict[str, pd.DataFrame], # index_ohlcv_data: 原始OHLCV数据
|
||||
Dict[str, str] # etf_code_map: {指数代码: ETF代码} 映射
|
||||
]:
|
||||
"""
|
||||
批量获取数据
|
||||
@@ -125,7 +126,7 @@ class HybridDataSource:
|
||||
end_date: 结束日期
|
||||
|
||||
Returns:
|
||||
(index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data)
|
||||
(index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data, etf_code_map)
|
||||
"""
|
||||
if end_date is None:
|
||||
end_date = datetime.now().strftime('%Y-%m-%d')
|
||||
@@ -247,7 +248,7 @@ class HybridDataSource:
|
||||
benchmark_data.index = pd.to_datetime(benchmark_data.index, utc=True).tz_localize(None).normalize()
|
||||
print(f"\n✓ 基准 {benchmark_code}: {len(benchmark_data)} 条")
|
||||
|
||||
return index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data
|
||||
return index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data, etf_codes
|
||||
|
||||
def __enter__(self):
|
||||
self._start_tunnel()
|
||||
|
||||
@@ -226,8 +226,10 @@ class TushareSource:
|
||||
return code.endswith(".SH") or code.endswith(".SZ") or code.endswith(".SS") or code.endswith(".CSI")
|
||||
|
||||
def is_futures(self, code: str) -> bool:
|
||||
"""判断是否为期货"""
|
||||
return ".SHF" in code or ".NYM" in code or ".DCE" in code or ".CZC" in code
|
||||
"""判断是否为中国期货(仅支持上期所、大商所、郑商所)"""
|
||||
# 只支持中国交易所期货(.SHF上期所、.DCE大商所、.CZC郑商所)
|
||||
# NYMEX (.NYM) 和 ICE (.ICE) 走 YFinance
|
||||
return ".SHF" in code or ".DCE" in code or ".CZC" in code
|
||||
|
||||
def fetch(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
|
||||
"""
|
||||
|
||||
@@ -125,7 +125,7 @@ class RotationStrategy(StrategyBase):
|
||||
)
|
||||
|
||||
# 调用 fetch_all
|
||||
index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data = \
|
||||
index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data, etf_code_map = \
|
||||
data_source.fetch_all(
|
||||
code_config=code_list_config,
|
||||
benchmark_code=benchmark_code,
|
||||
@@ -139,7 +139,8 @@ class RotationStrategy(StrategyBase):
|
||||
'etf_data': etf_data,
|
||||
'etf_nav_data': etf_nav_data,
|
||||
'benchmark_data': benchmark_data,
|
||||
'valid_codes': valid_codes
|
||||
'valid_codes': valid_codes,
|
||||
'etf_code_map': etf_code_map # {指数代码: ETF代码} 映射
|
||||
}
|
||||
|
||||
def compute_factors(self, data: dict) -> pd.DataFrame:
|
||||
@@ -148,13 +149,42 @@ class RotationStrategy(StrategyBase):
|
||||
valid_codes = data['valid_codes']
|
||||
|
||||
factor_values = {}
|
||||
final_valid_codes = []
|
||||
|
||||
for code in valid_codes:
|
||||
df = index_data[code]
|
||||
if len(df) >= self.n_days:
|
||||
values = self._factor.compute(df)
|
||||
factor_values[code] = values
|
||||
# 只使用 close 列计算因子(匹配原引擎逻辑:部分指数只有收盘价)
|
||||
if 'close' in df.columns:
|
||||
close_series = df['close'].dropna()
|
||||
else:
|
||||
close_series = df.dropna()
|
||||
|
||||
# 原引擎剔除逻辑:close 数据需要至少 n_days + 1 条
|
||||
if len(close_series) < self.n_days + 1:
|
||||
print(f" ⚠ 剔除 {code}: 数据不足 ({len(close_series)} < {self.n_days + 1})")
|
||||
continue
|
||||
|
||||
# 只传入 close 列给因子计算器
|
||||
close_df = pd.DataFrame({'close': close_series})
|
||||
values = self._factor.compute(close_df)
|
||||
factor_values[code] = values
|
||||
final_valid_codes.append(code)
|
||||
|
||||
return pd.DataFrame(factor_values)
|
||||
factor_df = pd.DataFrame(factor_values)
|
||||
|
||||
# 过滤缺失率过高的标的
|
||||
total_rows = len(factor_df)
|
||||
for code in final_valid_codes:
|
||||
if code in factor_df.columns:
|
||||
null_pct = factor_df[code].isnull().sum() / total_rows
|
||||
if null_pct > 0.5:
|
||||
print(f" ⚠ 剔除 {code}: 缺失率 {null_pct:.1%} 过高")
|
||||
factor_df = factor_df.drop(columns=[code])
|
||||
|
||||
# 更新有效代码列表
|
||||
data['valid_codes'] = [c for c in final_valid_codes if c in factor_df.columns]
|
||||
|
||||
return factor_df
|
||||
|
||||
def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""生成信号"""
|
||||
@@ -198,25 +228,35 @@ class RotationStrategy(StrategyBase):
|
||||
# 4. 执行回测
|
||||
print("\n执行回测...")
|
||||
|
||||
# 使用对齐后的指数收盘价数据获取日期基准
|
||||
index_close = data.get('index_close')
|
||||
# 获取ETF数据和代码映射
|
||||
etf_data = data.get('etf_data')
|
||||
etf_code_map = data.get('etf_code_map', {}) # {指数代码: ETF代码}
|
||||
|
||||
# 计算日收益率(使用对齐后的收盘价数据)
|
||||
if index_close is not None and not index_close.empty:
|
||||
returns_df = index_close.pct_change()
|
||||
returns_df.columns = [f'日收益率_{col}' for col in returns_df.columns]
|
||||
else:
|
||||
# 回退到原始数据
|
||||
# 计算日收益率(使用ETF价格数据,匹配原引擎逻辑)
|
||||
if etf_data is not None and not etf_data.empty:
|
||||
# 使用ETF价格计算收益,列名保持指数代码格式
|
||||
returns_data = {}
|
||||
for code in valid_codes:
|
||||
if code in index_data:
|
||||
df = index_data[code]
|
||||
returns_data[f'日收益率_{code}'] = df['close'].pct_change()
|
||||
for idx_code in valid_codes:
|
||||
etf_code = etf_code_map.get(idx_code, idx_code)
|
||||
if etf_code in etf_data.columns:
|
||||
returns_data[f'日收益率_{idx_code}'] = etf_data[etf_code].pct_change()
|
||||
returns_df = pd.DataFrame(returns_data)
|
||||
|
||||
if valid_codes:
|
||||
first_code = valid_codes[0]
|
||||
returns_df.index = index_data[first_code].index
|
||||
else:
|
||||
# 回退到指数收盘价数据
|
||||
index_close = data.get('index_close')
|
||||
if index_close is not None and not index_close.empty:
|
||||
returns_df = index_close.pct_change()
|
||||
returns_df.columns = [f'日收益率_{col}' for col in returns_df.columns]
|
||||
else:
|
||||
returns_data = {}
|
||||
for code in valid_codes:
|
||||
if code in index_data:
|
||||
df = index_data[code]
|
||||
returns_data[f'日收益率_{code}'] = df['close'].pct_change()
|
||||
returns_df = pd.DataFrame(returns_data)
|
||||
if valid_codes:
|
||||
first_code = valid_codes[0]
|
||||
returns_df.index = index_data[first_code].index
|
||||
|
||||
# 确保信号和收益率数据日期对齐
|
||||
common_dates = signals.index.intersection(returns_df.index)
|
||||
|
||||
@@ -164,10 +164,6 @@ class TopNSelector(SignalGenerator):
|
||||
factor_cols: List[str]
|
||||
) -> bool:
|
||||
"""检查是否应该调仓(得分阈值检查)"""
|
||||
if self.rebalance_threshold <= 0:
|
||||
# 无阈值,直接调仓
|
||||
return target != current_held
|
||||
|
||||
# 提取当前持仓和目标持仓的代码
|
||||
old_codes = [c for c in current_held.split(',') if c]
|
||||
new_codes = [c for c in target.split(',') if c]
|
||||
@@ -176,13 +172,14 @@ class TopNSelector(SignalGenerator):
|
||||
return True
|
||||
|
||||
if set(new_codes) == set(old_codes):
|
||||
return False
|
||||
return False # 组合完全相同,不调仓
|
||||
|
||||
# 计算新旧组合的总得分
|
||||
old_total = sum(float(row.get(col, 0)) for col in factor_cols if col in old_codes)
|
||||
new_total = sum(float(row.get(col, 0)) for col in factor_cols if col in new_codes)
|
||||
|
||||
# 新组合得分需超过当前组合一定比例才调仓
|
||||
# 即使 threshold=0,也要确保 new_total >= old_total
|
||||
if old_total > 0:
|
||||
return (new_total / old_total - 1) >= self.rebalance_threshold
|
||||
|
||||
|
||||
Reference in New Issue
Block a user