feat(data-source): 支持指数-ETF双轨数据获取及因子计算

- 新增使用Tushare获取A股ETF价格及净值数据的私有方法
- fetch_all方法支持接收完整代码配置,区分指数与ETF及市场类别
- 指数数据和ETF数据分别下载,ETF净值数据用于溢价率计算
- 采用A股交易日为主交易日历,非A股数据前向填充对齐
- 调整因子计算,支持指数价格计算因子,ETF价格计算收益率
- run_rotation脚本和RotationStrategy引擎适配指数-ETF配置格式
- 代码结构优化,增强多市场及加密货币处理能力
This commit is contained in:
2026-03-25 22:01:44 +08:00
parent e6898a851c
commit ec749314bc
4 changed files with 366 additions and 183 deletions

View File

@@ -209,6 +209,116 @@ class HybridDataSource:
if value is not None: if value is not None:
os.environ[key] = value os.environ[key] = value
def _fetch_etf(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
"""使用 Tushare 获取A股ETF数据fund_daily接口"""
import os
# 临时清除代理环境变量
original_proxy = {}
for key in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]:
original_proxy[key] = os.environ.pop(key, None)
try:
import tushare as ts
pro = ts.pro_api(self._get_tushare_token())
# 转换代码格式 (510300.SH -> 510300.SH)
ts_code = code.replace(".SS", ".SH")
# 获取ETF日线数据
df = pro.fund_daily(
ts_code=ts_code,
start_date=start_date.replace("-", ""),
end_date=end_date.replace("-", "")
)
if df is None or len(df) == 0:
return None
# 标准化列名
df = df.rename(columns={
"trade_date": "date",
"open": "open",
"high": "high",
"low": "low",
"close": "close",
"vol": "volume",
})
# 转换日期格式
df["date"] = pd.to_datetime(df["date"])
df = df.set_index("date")
df = df.sort_index()
# 添加代码列
df["code"] = code
return df
except Exception as e:
print(f"Tushare 下载ETF {code} 失败: {e}")
return None
finally:
# 恢复代理环境变量
for key, value in original_proxy.items():
if value is not None:
os.environ[key] = value
def _fetch_etf_nav(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
"""使用 Tushare 获取ETF净值数据fund_nav接口"""
import os
# 临时清除代理环境变量
original_proxy = {}
for key in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]:
original_proxy[key] = os.environ.pop(key, None)
try:
import tushare as ts
pro = ts.pro_api(self._get_tushare_token())
# 转换代码格式
ts_code = code.replace(".SS", ".SH")
# 获取ETF净值数据
df = pro.fund_nav(
ts_code=ts_code,
start_date=start_date.replace("-", ""),
end_date=end_date.replace("-", "")
)
if df is None or len(df) == 0:
return None
# 标准化列名
df = df.rename(columns={
"nav_date": "date",
"unit_nav": "nav",
})
# 转换日期格式
df["date"] = pd.to_datetime(df["date"])
df = df.set_index("date")
df = df.sort_index()
# 添加代码列
df["code"] = code
return df
except Exception as e:
print(f"Tushare 下载ETF净值 {code} 失败: {e}")
return None
finally:
# 恢复代理环境变量
for key, value in original_proxy.items():
if value is not None:
os.environ[key] = value
def _fetch_yfinance(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]: def _fetch_yfinance(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
"""使用 YFinance 获取数据""" """使用 YFinance 获取数据"""
import time import time
@@ -303,39 +413,50 @@ class HybridDataSource:
def fetch_all( def fetch_all(
self, self,
code_list, # list[代码] 或 dict{代码: 名称} code_config: dict, # {代码: {name, etf, market}}
benchmark_code: str, benchmark_code: str,
start_date: str, start_date: str,
end_date: str, end_date: str,
) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], list]: ) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], Optional[pd.DataFrame], Optional[pd.DataFrame], list]:
""" """
批量获取数据 批量获取数据(支持指数-ETF映射
注意:由于 Tushare(中国A股) 和 YFinance(美股/加密货币) 的交易日历不同,
这里返回的是长格式数据,由调用方分别处理各市场的数据 Args:
code_config: 配置字典,格式为 {index_code: {name, etf, market}}
benchmark_code: 基准指数代码
start_date: 开始日期
end_date: 结束日期
Returns: Returns:
(etf_data, benchmark_data, valid_codes) (index_data, etf_data, etf_nav_data, benchmark_data, valid_codes)
etf_data: DataFrame with columns [code, close, source], index=date - index_data: 指数数据(用于因子计算)
- etf_data: ETF价格数据用于收益计算
- etf_nav_data: ETF净值数据用于溢价率计算
- benchmark_data: 基准数据
- valid_codes: 有效代码列表
""" """
all_data = [] index_data_list = []
etf_data_list = []
valid_codes = [] valid_codes = []
# 兼容列表和字典格式 # 提取指数代码和ETF代码
if isinstance(code_list, dict): index_codes = list(code_config.keys())
codes = list(code_list.keys()) etf_codes = {}
code_name_map = code_list for idx_code, cfg in code_config.items():
else: if cfg.get('etf'):
codes = code_list etf_codes[idx_code] = cfg['etf']
code_name_map = {c: c for c in codes}
print(f"开始下载 {len(index_codes)} 只标的的数据...")
print(f"开始下载 {len(codes)}标的的数据...") print(f" 指数代码: {len(index_codes)}")
china_codes = [c for c in codes if self._is_china_index(c)] print(f" ETF映射: {len(etf_codes)}")
global_codes = [c for c in codes if not self._is_china_index(c)]
china_codes = [c for c in index_codes if self._is_china_index(c)]
global_codes = [c for c in index_codes if not self._is_china_index(c)]
print(f" 中国A股指数: {len(china_codes)}") print(f" 中国A股指数: {len(china_codes)}")
print(f" 港股/美股/加密货币: {len(global_codes)}") print(f" 港股/美股/加密货币: {len(global_codes)}")
# 检查是否需要启动 socks2http 代理(用于加密货币) # 检查是否需要启动 socks2http 代理(用于加密货币)
crypto_codes = [c for c in codes if self._is_crypto(c)] crypto_codes = [c for c in index_codes if self._is_crypto(c)]
http_proxy = None http_proxy = None
socks2http_proc = None socks2http_proc = None
@@ -358,8 +479,9 @@ class HybridDataSource:
except Exception as e: except Exception as e:
print(f" ✗ 启动代理失败: {e}") print(f" ✗ 启动代理失败: {e}")
# 分别下载数据 # 下载指数数据
for code in codes: print("\n [1/2] 下载指数数据(用于因子计算)...")
for code in index_codes:
if self._is_china_index(code): if self._is_china_index(code):
source = "Tushare" source = "Tushare"
elif self._is_crypto(code): elif self._is_crypto(code):
@@ -367,8 +489,8 @@ class HybridDataSource:
else: else:
source = "YFinance" source = "YFinance"
name = code_name_map.get(code, code) name = code_config[code].get('name', code)
print(f" 下载 {code} ({name}) - {source}...", end=" ") print(f" 下载 {code} ({name}) - {source}...", end=" ")
# 加密货币使用 HTTP 代理 # 加密货币使用 HTTP 代理
proxy = http_proxy if self._is_crypto(code) else None proxy = http_proxy if self._is_crypto(code) else None
@@ -378,95 +500,163 @@ class HybridDataSource:
# 标准化数据格式 # 标准化数据格式
data = data.copy() data = data.copy()
data['source'] = source data['source'] = source
data['code'] = code # 确保code列正确
# 确保索引是日期格式且无时区,只保留日期部分(去掉时间) # 确保索引是日期格式且无时区,只保留日期部分(去掉时间)
data.index = pd.to_datetime(data.index, utc=True).tz_localize(None).normalize() data.index = pd.to_datetime(data.index, utc=True).tz_localize(None).normalize()
all_data.append(data[['code', 'close', 'source']]) index_data_list.append(data[['code', 'close', 'source']])
valid_codes.append(code) valid_codes.append(code)
print(f"{len(data)}") print(f"{len(data)}")
else: else:
print("✗ 无数据") print("✗ 无数据")
# 下载ETF数据价格+净值,用于溢价率计算)
etf_nav_data_list = [] # ETF净值数据
if etf_codes:
print("\n [2/2] 下载ETF数据价格+净值,用于溢价率计算)...")
for idx_code, etf_code in etf_codes.items():
name = code_config[idx_code].get('name', idx_code)
market = code_config[idx_code].get('market', 'A')
# 加密货币跳过ETF下载
if market == 'CRYPTO':
continue
print(f" 下载 ETF {etf_code} (对应指数 {idx_code})...", end=" ")
# 获取ETF价格数据
price_data = self._fetch_etf(etf_code, start_date, end_date)
# 获取ETF净值数据
nav_data = self._fetch_etf_nav(etf_code, start_date, end_date)
if price_data is not None and len(price_data) > 0:
# 使用指数代码作为列名,保持与指数数据一致
price_data = price_data.copy()
price_data['source'] = 'Tushare-ETF'
price_data['code'] = idx_code
price_data.index = pd.to_datetime(price_data.index, utc=True).tz_localize(None).normalize()
etf_data_list.append(price_data[['code', 'close', 'source']])
# 处理净值数据
if nav_data is not None and len(nav_data) > 0:
nav_data = nav_data.copy()
nav_data['code'] = idx_code
nav_data.index = pd.to_datetime(nav_data.index, utc=True).tz_localize(None).normalize()
etf_nav_data_list.append(nav_data[['code', 'nav']])
print(f"✓ 价格{len(price_data)}条 净值{len(nav_data)}")
else:
print(f"✓ 价格{len(price_data)}条 (无净值数据)")
else:
print(f"✗ 无数据")
# 关闭 socks2http 代理 # 关闭 socks2http 代理
if socks2http_proc: if socks2http_proc:
socks2http_proc.terminate() socks2http_proc.terminate()
socks2http_proc.wait() socks2http_proc.wait()
print(f"\n socks2http 代理已关闭") print(f"\n socks2http 代理已关闭")
if not all_data: if not index_data_list:
return None, None, [] return None, None, None, None, []
# 检查数据源类型
sources = set(d['source'].iloc[0] for d in all_data)
if len(sources) == 1:
# 单一数据源:转换为宽格式(向后兼容)
all_df = pd.concat(all_data, ignore_index=False)
all_df = all_df.reset_index()
all_df['date'] = pd.to_datetime(all_df['date'], utc=True).dt.tz_localize(None)
etf_data = all_df.pivot_table(
index='date',
columns='code',
values='close',
aggfunc='first'
)
print(f"\n数据整理完成 (单一数据源 {list(sources)[0]}):")
print(f" 时间范围: {etf_data.index[0]} ~ {etf_data.index[-1]}")
print(f" 交易日数: {len(etf_data)}")
print(f" 有效标的: {len(etf_data.columns)}")
else:
# 多数据源以主市场Tushare/A股为基准其他市场数据前向填充
print(f"\n数据整理完成 (多数据源 - 以A股交易日为基准):")
# 合并所有数据(索引已经是标准化后的日期)
all_df = pd.concat(all_data, ignore_index=False)
all_df = all_df.reset_index()
# 重命名索引列为 date
if 'index' in all_df.columns:
all_df = all_df.rename(columns={'index': 'date'})
# 确保 date 列是日期格式(不含时间)
all_df['date'] = pd.to_datetime(all_df['date']).dt.normalize()
# 处理指数数据
print(f"\n整理指数数据(用于因子计算)...")
index_df = pd.concat(index_data_list, ignore_index=False)
index_df = index_df.reset_index()
if 'index' in index_df.columns:
index_df = index_df.rename(columns={'index': 'date'})
index_df['date'] = pd.to_datetime(index_df['date']).dt.normalize()
# 透视为宽格式
index_data = index_df.pivot_table(
index='date',
columns='code',
values='close',
aggfunc='first'
)
# 以A股交易日为基准对齐所有数据
tushare_codes = [c for c in valid_codes if self._is_china_index(c)]
if tushare_codes:
primary_dates = index_data[tushare_codes[0]].dropna().index
print(f" 主市场交易日: {len(primary_dates)}")
# 重新索引到主市场交易日
index_data = index_data.reindex(primary_dates)
# 对非A股指数进行前向填充
non_a_codes = [c for c in valid_codes if not self._is_china_index(c)]
for code in non_a_codes:
if code in index_data.columns:
index_data[code] = index_data[code].ffill().bfill()
print(f" 非A股标的: {len(non_a_codes)} 只 (已前向填充)")
print(f" 时间范围: {index_data.index[0]} ~ {index_data.index[-1]}")
print(f" 交易日数: {len(index_data)}")
# 处理ETF数据
if etf_data_list:
print(f"\n整理ETF数据用于收益计算...")
etf_df = pd.concat(etf_data_list, ignore_index=False)
etf_df = etf_df.reset_index()
if 'index' in etf_df.columns:
etf_df = etf_df.rename(columns={'index': 'date'})
etf_df['date'] = pd.to_datetime(etf_df['date']).dt.normalize()
# 透视为宽格式 # 透视为宽格式
etf_data = all_df.pivot_table( etf_data = etf_df.pivot_table(
index='date', index='date',
columns='code', columns='code',
values='close', values='close',
aggfunc='first' aggfunc='first'
) )
# 获取主市场Tushare交易日 # 对齐到主市场交易日
tushare_codes = [c for c in valid_codes if self._is_china_index(c)]
if tushare_codes: if tushare_codes:
# 使用第一个A股代码的日期作为主市场交易日
primary_dates = etf_data[tushare_codes[0]].dropna().index
print(f" 主市场交易日: {len(primary_dates)}")
# 重新索引到主市场交易日,使用前向填充
etf_data = etf_data.reindex(primary_dates) etf_data = etf_data.reindex(primary_dates)
# 对每个非主市场代码进行前向填充 print(f" ETF价格数据: {len(etf_data.columns)}")
yfinance_codes = [c for c in valid_codes if not self._is_china_index(c)] else:
for code in yfinance_codes: # 如果没有ETF数据使用指数数据代替
if code in etf_data.columns: etf_data = index_data.copy()
# 前向填充:用最近的有效价格填充休市日的数据 print(f"\n无ETF映射使用指数数据代替")
etf_data[code] = etf_data[code].ffill()
# 对于开头的NaN用后向填充 # 处理ETF净值数据
etf_data[code] = etf_data[code].bfill() etf_nav_data = None
if etf_nav_data_list:
print(f" 非主市场标的: {len(yfinance_codes)} 只 (已前向填充)") print(f"\n整理ETF净值数据用于溢价率计算...")
nav_df = pd.concat(etf_nav_data_list, ignore_index=False)
print(f" 时间范围: {etf_data.index[0]} ~ {etf_data.index[-1]}") nav_df = nav_df.reset_index()
print(f" 交易日数: {len(etf_data)}") if 'index' in nav_df.columns:
print(f" 有效标的: {len(etf_data.columns)}") nav_df = nav_df.rename(columns={'index': 'date'})
nav_df['date'] = pd.to_datetime(nav_df['date']).dt.normalize()
# 透视为宽格式
etf_nav_data = nav_df.pivot_table(
index='date',
columns='code',
values='nav',
aggfunc='first'
)
# 对齐到主市场交易日并前向填充缺失值净值数据通常T+1更新
if tushare_codes:
etf_nav_data = etf_nav_data.reindex(primary_dates)
etf_nav_data = etf_nav_data.ffill() # 前向填充缺失的净值数据
print(f" ETF净值数据: {len(etf_nav_data.columns)}")
# 获取基准数据 # 获取基准数据
benchmark_data = self.fetch_single(benchmark_code, start_date, end_date) benchmark_data = self.fetch_single(benchmark_code, start_date, end_date)
if benchmark_data is not None: if benchmark_data is not None:
# 标准化日期索引(无时区,只保留日期部分)
benchmark_data.index = pd.to_datetime(benchmark_data.index, utc=True).tz_localize(None).normalize() benchmark_data.index = pd.to_datetime(benchmark_data.index, utc=True).tz_localize(None).normalize()
print(f" ✓ 基准 {benchmark_code}: {len(benchmark_data)}") # 对齐到主市场交易日
if tushare_codes:
benchmark_data = benchmark_data.reindex(primary_dates)
print(f"\n✓ 基准 {benchmark_code}: {len(benchmark_data)}")
return etf_data, benchmark_data, valid_codes return index_data, etf_data, etf_nav_data, benchmark_data, valid_codes
def __enter__(self): def __enter__(self):
"""上下文管理器入口""" """上下文管理器入口"""

View File

@@ -80,110 +80,74 @@ def calculate_daily_return(price_series: pd.Series) -> pd.Series:
def compute_factors( def compute_factors(
etf_data: pd.DataFrame, index_data: pd.DataFrame,
code_list: list, code_list: list,
n: int = 25, n: int = 25,
factor_type: str = "slope_r2", factor_type: str = "slope_r2",
etf_data: pd.DataFrame = None,
code_config: dict = None,
) -> tuple[pd.DataFrame, list]: ) -> tuple[pd.DataFrame, list]:
""" """
计算所有指数的因子和日收益率 计算所有指数的因子和日收益率(支持指数-ETF双轨数据
支持长格式数据混合数据源Tushare + YFinance
Args: Args:
etf_data: DataFrame, 长格式数据,包含 [code, close, source] 列 index_data: 指数价格数据(宽格式,用于因子计算)
code_list: 指数代码列表 code_list: 指数代码列表
n: 动量/趋势窗口 n: 动量/趋势窗口
factor_type: 'momentum''slope_r2' factor_type: 'momentum''slope_r2'
etf_data: ETF价格数据宽格式用于收益计算
code_config: 代码配置字典 {code: {name, etf, market}},用于判断是否为加密货币
Returns: Returns:
tuple: (result_df, valid_codes) tuple: (result_df, valid_codes)
- result_df: 包含因子得分和日收益率的DataFrame
- valid_codes: 有效代码列表
""" """
# 检查数据格式 code_config = code_config or {}
if 'code' in etf_data.columns:
# 长格式数据 - 按 code 分别计算因子(旧逻辑,保留兼容) # 如果没有提供ETF数据创建一个空的DataFrame
all_factors = [] if etf_data is None:
valid_codes = [] etf_data = pd.DataFrame()
for code in code_list: result = index_data.copy()
code_data = etf_data[etf_data['code'] == code].copy()
if len(code_data) == 0: # 过滤掉缺失值过多的指数
print(f" ⚠ 跳过 {code}: 不在数据中") total_rows = len(result)
continue valid_codes = []
for code in code_list:
# 检查缺失值 if code not in result.columns:
null_pct = code_data['close'].isnull().sum() / len(code_data) print(f" ⚠ 跳过 {code}: 不在数据中")
if null_pct > 0.2: continue
print(f" ⚠ 剔除 {code}: 缺失率 {null_pct:.1%} 过高") null_pct = result[code].isnull().sum() / total_rows
continue if null_pct > 0.2:
print(f" ⚠ 剔除 {code}: 缺失率 {null_pct:.1%} 过高")
# 按日期排序 result = result.drop(columns=[code])
code_data = code_data.sort_index() else:
# 计算日收益率和因子
code_data[f"日收益率_{code}"] = calculate_daily_return(code_data['close'])
if factor_type == "momentum":
code_data[f"得分_{code}"] = calculate_momentum(code_data['close'], n)
elif factor_type == "slope_r2":
code_data[f"得分_{code}"] = calculate_slope_r2(code_data['close'], n)
else:
raise ValueError(f"不支持的因子类型: {factor_type}")
# 保留需要的列
code_data = code_data[[f"日收益率_{code}", f"得分_{code}"]]
all_factors.append(code_data)
valid_codes.append(code) valid_codes.append(code)
if not all_factors: # 对有效指数计算因子和收益率
raise ValueError("没有有效的指数数据") for code in valid_codes:
# 因子基于指数价格计算
if factor_type == "momentum":
result[f"得分_{code}"] = calculate_momentum(result[code], n)
elif factor_type == "slope_r2":
result[f"得分_{code}"] = calculate_slope_r2(result[code], n)
else:
raise ValueError(f"不支持的因子类型: {factor_type}")
# 日收益率基于指数价格计算(回测使用指数价格)
result[f"日收益率_{code}"] = calculate_daily_return(result[code])
# 合并所有因子的数据(按日期内连接 - 只保留所有指数都有数据的日期) # 按得分列做 dropna
result = all_factors[0] score_cols = [f"得分_{code}" for code in valid_codes]
for df in all_factors[1:]: result = result.dropna(subset=score_cols)
result = result.join(df, how='inner')
# 删除所有得分都是 NaN 的行(即窗口期内的数据)
score_cols = [f"得分_{code}" for code in valid_codes]
# 只删除完全无法比较的行所有得分都是NaN
result = result.dropna(subset=score_cols, how='all')
else:
# 宽格式数据(向后兼容)
result = etf_data.copy()
# 过滤掉缺失值过多的指数
total_rows = len(result)
valid_codes = []
for code in code_list:
if code not in result.columns:
print(f" ⚠ 跳过 {code}: 不在数据中")
continue
null_pct = result[code].isnull().sum() / total_rows
if null_pct > 0.2:
print(f" ⚠ 剔除 {code}: 缺失率 {null_pct:.1%} 过高")
result = result.drop(columns=[code])
else:
valid_codes.append(code)
# 对有效指数计算因子
for code in valid_codes:
result[f"日收益率_{code}"] = calculate_daily_return(result[code])
if factor_type == "momentum":
result[f"得分_{code}"] = calculate_momentum(result[code], n)
elif factor_type == "slope_r2":
result[f"得分_{code}"] = calculate_slope_r2(result[code], n)
else:
raise ValueError(f"不支持的因子类型: {factor_type}")
# 按得分列做 dropna
score_cols = [f"得分_{code}" for code in valid_codes]
result = result.dropna(subset=score_cols)
print("\n因子计算完成:") print("\n因子计算完成:")
print(f" 因子类型: {factor_type}") print(f" 因子类型: {factor_type}")
print(f" 窗口天数: {n}") print(f" 窗口天数: {n}")
print(f" 有效指数: {len(valid_codes)}/{len(code_list)}") print(f" 有效指数: {len(valid_codes)}/{len(code_list)}")
print(f" 有效数据: {len(result)}") print(f" 有效数据: {len(result)}")
if etf_data is not index_data:
print(f" 使用ETF数据计算收益: ✓")
return result, valid_codes return result, valid_codes

View File

@@ -59,22 +59,38 @@ def main():
from datetime import datetime from datetime import datetime
config['end_date'] = datetime.now().strftime('%Y-%m-%d') config['end_date'] = datetime.now().strftime('%Y-%m-%d')
# 从配置中读取 code_list 和 code_name_map # 从配置中读取 code_list(新的配置格式:{代码: {name, etf, market}}
# code_list 现在是一个字典 {代码: 名称}
code_list_config = config.get('code_list', {}) code_list_config = config.get('code_list', {})
# 提取代码列表和名称映射
if isinstance(code_list_config, dict): if isinstance(code_list_config, dict):
code_list = list(code_list_config.keys()) code_list = list(code_list_config.keys())
code_name_map = code_list_config # 构建 code_name_map: {代码: 名称}
code_name_map = {}
for code, cfg in code_list_config.items():
if isinstance(cfg, dict):
code_name_map[code] = cfg.get('name', code)
else:
# 兼容旧格式
code_name_map[code] = cfg
else: else:
# 兼容旧格式(列表) # 兼容旧格式(列表)
code_list = code_list_config code_list = code_list_config
code_name_map = DEFAULT_CODE_NAME_MAP code_name_map = DEFAULT_CODE_NAME_MAP
code_list_config = {}
benchmark_config = config.get('benchmark', {}) benchmark_config = config.get('benchmark', {})
benchmark_name = benchmark_config.get('name', DEFAULT_BENCHMARK_NAME) benchmark_name = benchmark_config.get('name', DEFAULT_BENCHMARK_NAME)
print(f"\n配置文件: {args.config}") print(f"\n配置文件: {args.config}")
print(f"候选标的: {len(code_list)}") print(f"候选标的: {len(code_list)}")
# 统计ETF映射情况
etf_count = sum(1 for cfg in code_list_config.values() if isinstance(cfg, dict) and cfg.get('etf'))
crypto_count = sum(1 for cfg in code_list_config.values() if isinstance(cfg, dict) and cfg.get('market') == 'CRYPTO')
print(f" - ETF映射: {etf_count}")
print(f" - 直接交易: {crypto_count} 只(加密货币)")
print(f"回测区间: {config['start_date']} ~ {config['end_date']}") print(f"回测区间: {config['start_date']} ~ {config['end_date']}")
print(f"因子类型: {config['factor_type']}") print(f"因子类型: {config['factor_type']}")
print(f"窗口天数: {config['n_days']}") print(f"窗口天数: {config['n_days']}")
@@ -82,8 +98,8 @@ def main():
print(f"调仓周期: {config['rebalance_days']}") print(f"调仓周期: {config['rebalance_days']}")
print(f"交易成本: {config['trade_cost']:.2%}") print(f"交易成本: {config['trade_cost']:.2%}")
# 更新 config 中的 code_list 为列表格式 # 保持 config 中的 code_list 为完整配置格式(用于引擎内部解析)
config['code_list'] = code_list # 不需要修改 config['code_list'],引擎会直接使用原始配置
# 创建策略实例 # 创建策略实例
strategy = RotationStrategy(config) strategy = RotationStrategy(config)
@@ -119,6 +135,10 @@ def main():
benchmark_name=benchmark_name, benchmark_name=benchmark_name,
save_path=args.save_path, save_path=args.save_path,
select_num=config["select_num"], select_num=config["select_num"],
code_config=code_list_config, # 传入完整配置以显示ETF映射
index_data=strategy.index_data, # 传入指数数据
etf_price_data=strategy.etf_data, # 传入ETF价格数据
etf_nav_data_raw=strategy.etf_nav_data, # 传入ETF净值数据
) )
elapsed = time.time() - start_time elapsed = time.time() - start_time

View File

@@ -34,31 +34,40 @@ class RotationStrategy(BacktestStrategy):
self.backtest_result = None self.backtest_result = None
def fetch_data(self) -> pd.DataFrame: def fetch_data(self) -> pd.DataFrame:
"""获取数据""" """获取数据(支持指数-ETF双轨数据"""
from config.settings import DEFAULT_BENCHMARK_CODE from config.settings import DEFAULT_BENCHMARK_CODE
# 从配置中读取基准代码,或使用默认值 # 从配置中读取基准代码,或使用默认值
benchmark_code = self.config.get("benchmark", {}).get("code", DEFAULT_BENCHMARK_CODE) benchmark_code = self.config.get("benchmark", {}).get("code", DEFAULT_BENCHMARK_CODE)
# 获取代码配置(包含 name, etf, market
code_config = self.config.get("code_list", {})
# 使用上下文管理器管理 SSH 隧道(如果是 YFinance 数据源) # 使用上下文管理器管理 SSH 隧道
with self.data_source: with self.data_source:
etf_data, benchmark_data, valid_codes = self.data_source.fetch_all( index_data, etf_data, etf_nav_data, benchmark_data, valid_codes = self.data_source.fetch_all(
self.config["code_list"], code_config,
benchmark_code, benchmark_code,
self.config["start_date"], self.config["start_date"],
self.config["end_date"], self.config["end_date"],
) )
self.etf_data = etf_data # 存储数据和配置
self.index_data = index_data # 指数数据(用于因子计算)
self.etf_data = etf_data # ETF价格数据用于收益计算
self.etf_nav_data = etf_nav_data # ETF净值数据用于溢价率计算
self.benchmark_data = benchmark_data self.benchmark_data = benchmark_data
self.valid_codes = valid_codes self.valid_codes = valid_codes
self.code_config = code_config # 代码配置(用于判断市场类型)
# 计算因子 # 计算因子传入两套数据指数数据用于因子ETF数据用于收益
factor_data, valid_codes = compute_factors( factor_data, valid_codes = compute_factors(
etf_data, index_data,
valid_codes, valid_codes,
n=self.config["n_days"], n=self.config["n_days"],
factor_type=self.config["factor_type"], factor_type=self.config["factor_type"],
etf_data=etf_data, # 传入ETF数据用于收益计算
code_config=code_config, # 传入配置以判断加密货币
) )
self.data = factor_data self.data = factor_data