feat(data-source): 支持指数-ETF双轨数据获取及因子计算
- 新增使用Tushare获取A股ETF价格及净值数据的私有方法 - fetch_all方法支持接收完整代码配置,区分指数与ETF及市场类别 - 指数数据和ETF数据分别下载,ETF净值数据用于溢价率计算 - 采用A股交易日为主交易日历,非A股数据前向填充对齐 - 调整因子计算,支持指数价格计算因子,ETF价格计算收益率 - run_rotation脚本和RotationStrategy引擎适配指数-ETF配置格式 - 代码结构优化,增强多市场及加密货币处理能力
This commit is contained in:
@@ -209,6 +209,116 @@ class HybridDataSource:
|
|||||||
if value is not None:
|
if value is not None:
|
||||||
os.environ[key] = value
|
os.environ[key] = value
|
||||||
|
|
||||||
|
def _fetch_etf(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
|
||||||
|
"""使用 Tushare 获取A股ETF数据(fund_daily接口)"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 临时清除代理环境变量
|
||||||
|
original_proxy = {}
|
||||||
|
for key in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]:
|
||||||
|
original_proxy[key] = os.environ.pop(key, None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import tushare as ts
|
||||||
|
|
||||||
|
pro = ts.pro_api(self._get_tushare_token())
|
||||||
|
|
||||||
|
# 转换代码格式 (510300.SH -> 510300.SH)
|
||||||
|
ts_code = code.replace(".SS", ".SH")
|
||||||
|
|
||||||
|
# 获取ETF日线数据
|
||||||
|
df = pro.fund_daily(
|
||||||
|
ts_code=ts_code,
|
||||||
|
start_date=start_date.replace("-", ""),
|
||||||
|
end_date=end_date.replace("-", "")
|
||||||
|
)
|
||||||
|
|
||||||
|
if df is None or len(df) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 标准化列名
|
||||||
|
df = df.rename(columns={
|
||||||
|
"trade_date": "date",
|
||||||
|
"open": "open",
|
||||||
|
"high": "high",
|
||||||
|
"low": "low",
|
||||||
|
"close": "close",
|
||||||
|
"vol": "volume",
|
||||||
|
})
|
||||||
|
|
||||||
|
# 转换日期格式
|
||||||
|
df["date"] = pd.to_datetime(df["date"])
|
||||||
|
df = df.set_index("date")
|
||||||
|
df = df.sort_index()
|
||||||
|
|
||||||
|
# 添加代码列
|
||||||
|
df["code"] = code
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Tushare 下载ETF {code} 失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# 恢复代理环境变量
|
||||||
|
for key, value in original_proxy.items():
|
||||||
|
if value is not None:
|
||||||
|
os.environ[key] = value
|
||||||
|
|
||||||
|
def _fetch_etf_nav(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
|
||||||
|
"""使用 Tushare 获取ETF净值数据(fund_nav接口)"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 临时清除代理环境变量
|
||||||
|
original_proxy = {}
|
||||||
|
for key in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]:
|
||||||
|
original_proxy[key] = os.environ.pop(key, None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import tushare as ts
|
||||||
|
|
||||||
|
pro = ts.pro_api(self._get_tushare_token())
|
||||||
|
|
||||||
|
# 转换代码格式
|
||||||
|
ts_code = code.replace(".SS", ".SH")
|
||||||
|
|
||||||
|
# 获取ETF净值数据
|
||||||
|
df = pro.fund_nav(
|
||||||
|
ts_code=ts_code,
|
||||||
|
start_date=start_date.replace("-", ""),
|
||||||
|
end_date=end_date.replace("-", "")
|
||||||
|
)
|
||||||
|
|
||||||
|
if df is None or len(df) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 标准化列名
|
||||||
|
df = df.rename(columns={
|
||||||
|
"nav_date": "date",
|
||||||
|
"unit_nav": "nav",
|
||||||
|
})
|
||||||
|
|
||||||
|
# 转换日期格式
|
||||||
|
df["date"] = pd.to_datetime(df["date"])
|
||||||
|
df = df.set_index("date")
|
||||||
|
df = df.sort_index()
|
||||||
|
|
||||||
|
# 添加代码列
|
||||||
|
df["code"] = code
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Tushare 下载ETF净值 {code} 失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# 恢复代理环境变量
|
||||||
|
for key, value in original_proxy.items():
|
||||||
|
if value is not None:
|
||||||
|
os.environ[key] = value
|
||||||
|
|
||||||
def _fetch_yfinance(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
|
def _fetch_yfinance(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
|
||||||
"""使用 YFinance 获取数据"""
|
"""使用 YFinance 获取数据"""
|
||||||
import time
|
import time
|
||||||
@@ -303,39 +413,50 @@ class HybridDataSource:
|
|||||||
|
|
||||||
def fetch_all(
|
def fetch_all(
|
||||||
self,
|
self,
|
||||||
code_list, # list[代码] 或 dict{代码: 名称}
|
code_config: dict, # {代码: {name, etf, market}}
|
||||||
benchmark_code: str,
|
benchmark_code: str,
|
||||||
start_date: str,
|
start_date: str,
|
||||||
end_date: str,
|
end_date: str,
|
||||||
) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], list]:
|
) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], Optional[pd.DataFrame], Optional[pd.DataFrame], list]:
|
||||||
"""
|
"""
|
||||||
批量获取数据
|
批量获取数据(支持指数-ETF映射)
|
||||||
注意:由于 Tushare(中国A股) 和 YFinance(美股/加密货币) 的交易日历不同,
|
|
||||||
这里返回的是长格式数据,由调用方分别处理各市场的数据
|
Args:
|
||||||
|
code_config: 配置字典,格式为 {index_code: {name, etf, market}}
|
||||||
|
benchmark_code: 基准指数代码
|
||||||
|
start_date: 开始日期
|
||||||
|
end_date: 结束日期
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(etf_data, benchmark_data, valid_codes)
|
(index_data, etf_data, etf_nav_data, benchmark_data, valid_codes)
|
||||||
etf_data: DataFrame with columns [code, close, source], index=date
|
- index_data: 指数数据(用于因子计算)
|
||||||
|
- etf_data: ETF价格数据(用于收益计算)
|
||||||
|
- etf_nav_data: ETF净值数据(用于溢价率计算)
|
||||||
|
- benchmark_data: 基准数据
|
||||||
|
- valid_codes: 有效代码列表
|
||||||
"""
|
"""
|
||||||
all_data = []
|
index_data_list = []
|
||||||
|
etf_data_list = []
|
||||||
valid_codes = []
|
valid_codes = []
|
||||||
|
|
||||||
# 兼容列表和字典格式
|
# 提取指数代码和ETF代码
|
||||||
if isinstance(code_list, dict):
|
index_codes = list(code_config.keys())
|
||||||
codes = list(code_list.keys())
|
etf_codes = {}
|
||||||
code_name_map = code_list
|
for idx_code, cfg in code_config.items():
|
||||||
else:
|
if cfg.get('etf'):
|
||||||
codes = code_list
|
etf_codes[idx_code] = cfg['etf']
|
||||||
code_name_map = {c: c for c in codes}
|
|
||||||
|
|
||||||
print(f"开始下载 {len(codes)} 只标的的数据...")
|
print(f"开始下载 {len(index_codes)} 只标的的数据...")
|
||||||
china_codes = [c for c in codes if self._is_china_index(c)]
|
print(f" 指数代码: {len(index_codes)} 只")
|
||||||
global_codes = [c for c in codes if not self._is_china_index(c)]
|
print(f" ETF映射: {len(etf_codes)} 只")
|
||||||
|
|
||||||
|
china_codes = [c for c in index_codes if self._is_china_index(c)]
|
||||||
|
global_codes = [c for c in index_codes if not self._is_china_index(c)]
|
||||||
print(f" 中国A股指数: {len(china_codes)} 只")
|
print(f" 中国A股指数: {len(china_codes)} 只")
|
||||||
print(f" 港股/美股/加密货币: {len(global_codes)} 只")
|
print(f" 港股/美股/加密货币: {len(global_codes)} 只")
|
||||||
|
|
||||||
# 检查是否需要启动 socks2http 代理(用于加密货币)
|
# 检查是否需要启动 socks2http 代理(用于加密货币)
|
||||||
crypto_codes = [c for c in codes if self._is_crypto(c)]
|
crypto_codes = [c for c in index_codes if self._is_crypto(c)]
|
||||||
http_proxy = None
|
http_proxy = None
|
||||||
socks2http_proc = None
|
socks2http_proc = None
|
||||||
|
|
||||||
@@ -358,8 +479,9 @@ class HybridDataSource:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" ✗ 启动代理失败: {e}")
|
print(f" ✗ 启动代理失败: {e}")
|
||||||
|
|
||||||
# 分别下载数据
|
# 下载指数数据
|
||||||
for code in codes:
|
print("\n [1/2] 下载指数数据(用于因子计算)...")
|
||||||
|
for code in index_codes:
|
||||||
if self._is_china_index(code):
|
if self._is_china_index(code):
|
||||||
source = "Tushare"
|
source = "Tushare"
|
||||||
elif self._is_crypto(code):
|
elif self._is_crypto(code):
|
||||||
@@ -367,7 +489,7 @@ class HybridDataSource:
|
|||||||
else:
|
else:
|
||||||
source = "YFinance"
|
source = "YFinance"
|
||||||
|
|
||||||
name = code_name_map.get(code, code)
|
name = code_config[code].get('name', code)
|
||||||
print(f" 下载 {code} ({name}) - {source}...", end=" ")
|
print(f" 下载 {code} ({name}) - {source}...", end=" ")
|
||||||
|
|
||||||
# 加密货币使用 HTTP 代理
|
# 加密货币使用 HTTP 代理
|
||||||
@@ -378,95 +500,163 @@ class HybridDataSource:
|
|||||||
# 标准化数据格式
|
# 标准化数据格式
|
||||||
data = data.copy()
|
data = data.copy()
|
||||||
data['source'] = source
|
data['source'] = source
|
||||||
|
data['code'] = code # 确保code列正确
|
||||||
# 确保索引是日期格式且无时区,只保留日期部分(去掉时间)
|
# 确保索引是日期格式且无时区,只保留日期部分(去掉时间)
|
||||||
data.index = pd.to_datetime(data.index, utc=True).tz_localize(None).normalize()
|
data.index = pd.to_datetime(data.index, utc=True).tz_localize(None).normalize()
|
||||||
all_data.append(data[['code', 'close', 'source']])
|
index_data_list.append(data[['code', 'close', 'source']])
|
||||||
valid_codes.append(code)
|
valid_codes.append(code)
|
||||||
print(f"✓ {len(data)} 条")
|
print(f"✓ {len(data)} 条")
|
||||||
else:
|
else:
|
||||||
print("✗ 无数据")
|
print("✗ 无数据")
|
||||||
|
|
||||||
|
# 下载ETF数据(价格+净值,用于溢价率计算)
|
||||||
|
etf_nav_data_list = [] # ETF净值数据
|
||||||
|
|
||||||
|
if etf_codes:
|
||||||
|
print("\n [2/2] 下载ETF数据(价格+净值,用于溢价率计算)...")
|
||||||
|
|
||||||
|
for idx_code, etf_code in etf_codes.items():
|
||||||
|
name = code_config[idx_code].get('name', idx_code)
|
||||||
|
market = code_config[idx_code].get('market', 'A')
|
||||||
|
|
||||||
|
# 加密货币跳过ETF下载
|
||||||
|
if market == 'CRYPTO':
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f" 下载 ETF {etf_code} (对应指数 {idx_code})...", end=" ")
|
||||||
|
|
||||||
|
# 获取ETF价格数据
|
||||||
|
price_data = self._fetch_etf(etf_code, start_date, end_date)
|
||||||
|
# 获取ETF净值数据
|
||||||
|
nav_data = self._fetch_etf_nav(etf_code, start_date, end_date)
|
||||||
|
|
||||||
|
if price_data is not None and len(price_data) > 0:
|
||||||
|
# 使用指数代码作为列名,保持与指数数据一致
|
||||||
|
price_data = price_data.copy()
|
||||||
|
price_data['source'] = 'Tushare-ETF'
|
||||||
|
price_data['code'] = idx_code
|
||||||
|
price_data.index = pd.to_datetime(price_data.index, utc=True).tz_localize(None).normalize()
|
||||||
|
etf_data_list.append(price_data[['code', 'close', 'source']])
|
||||||
|
|
||||||
|
# 处理净值数据
|
||||||
|
if nav_data is not None and len(nav_data) > 0:
|
||||||
|
nav_data = nav_data.copy()
|
||||||
|
nav_data['code'] = idx_code
|
||||||
|
nav_data.index = pd.to_datetime(nav_data.index, utc=True).tz_localize(None).normalize()
|
||||||
|
etf_nav_data_list.append(nav_data[['code', 'nav']])
|
||||||
|
print(f"✓ 价格{len(price_data)}条 净值{len(nav_data)}条")
|
||||||
|
else:
|
||||||
|
print(f"✓ 价格{len(price_data)}条 (无净值数据)")
|
||||||
|
else:
|
||||||
|
print(f"✗ 无数据")
|
||||||
|
|
||||||
# 关闭 socks2http 代理
|
# 关闭 socks2http 代理
|
||||||
if socks2http_proc:
|
if socks2http_proc:
|
||||||
socks2http_proc.terminate()
|
socks2http_proc.terminate()
|
||||||
socks2http_proc.wait()
|
socks2http_proc.wait()
|
||||||
print(f"\n socks2http 代理已关闭")
|
print(f"\n socks2http 代理已关闭")
|
||||||
|
|
||||||
if not all_data:
|
if not index_data_list:
|
||||||
return None, None, []
|
return None, None, None, None, []
|
||||||
|
|
||||||
# 检查数据源类型
|
# 处理指数数据
|
||||||
sources = set(d['source'].iloc[0] for d in all_data)
|
print(f"\n整理指数数据(用于因子计算)...")
|
||||||
|
index_df = pd.concat(index_data_list, ignore_index=False)
|
||||||
if len(sources) == 1:
|
index_df = index_df.reset_index()
|
||||||
# 单一数据源:转换为宽格式(向后兼容)
|
if 'index' in index_df.columns:
|
||||||
all_df = pd.concat(all_data, ignore_index=False)
|
index_df = index_df.rename(columns={'index': 'date'})
|
||||||
all_df = all_df.reset_index()
|
index_df['date'] = pd.to_datetime(index_df['date']).dt.normalize()
|
||||||
all_df['date'] = pd.to_datetime(all_df['date'], utc=True).dt.tz_localize(None)
|
|
||||||
etf_data = all_df.pivot_table(
|
|
||||||
index='date',
|
|
||||||
columns='code',
|
|
||||||
values='close',
|
|
||||||
aggfunc='first'
|
|
||||||
)
|
|
||||||
print(f"\n数据整理完成 (单一数据源 {list(sources)[0]}):")
|
|
||||||
print(f" 时间范围: {etf_data.index[0]} ~ {etf_data.index[-1]}")
|
|
||||||
print(f" 交易日数: {len(etf_data)}")
|
|
||||||
print(f" 有效标的: {len(etf_data.columns)} 只")
|
|
||||||
else:
|
|
||||||
# 多数据源:以主市场(Tushare/A股)为基准,其他市场数据前向填充
|
|
||||||
print(f"\n数据整理完成 (多数据源 - 以A股交易日为基准):")
|
|
||||||
|
|
||||||
# 合并所有数据(索引已经是标准化后的日期)
|
|
||||||
all_df = pd.concat(all_data, ignore_index=False)
|
|
||||||
all_df = all_df.reset_index()
|
|
||||||
# 重命名索引列为 date
|
|
||||||
if 'index' in all_df.columns:
|
|
||||||
all_df = all_df.rename(columns={'index': 'date'})
|
|
||||||
# 确保 date 列是日期格式(不含时间)
|
|
||||||
all_df['date'] = pd.to_datetime(all_df['date']).dt.normalize()
|
|
||||||
|
|
||||||
# 透视为宽格式
|
# 透视为宽格式
|
||||||
etf_data = all_df.pivot_table(
|
index_data = index_df.pivot_table(
|
||||||
index='date',
|
index='date',
|
||||||
columns='code',
|
columns='code',
|
||||||
values='close',
|
values='close',
|
||||||
aggfunc='first'
|
aggfunc='first'
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取主市场(Tushare)的交易日历
|
# 以A股交易日为基准,对齐所有数据
|
||||||
tushare_codes = [c for c in valid_codes if self._is_china_index(c)]
|
tushare_codes = [c for c in valid_codes if self._is_china_index(c)]
|
||||||
if tushare_codes:
|
if tushare_codes:
|
||||||
# 使用第一个A股代码的日期作为主市场交易日
|
primary_dates = index_data[tushare_codes[0]].dropna().index
|
||||||
primary_dates = etf_data[tushare_codes[0]].dropna().index
|
|
||||||
print(f" 主市场交易日: {len(primary_dates)} 天")
|
print(f" 主市场交易日: {len(primary_dates)} 天")
|
||||||
|
|
||||||
# 重新索引到主市场交易日,使用前向填充
|
# 重新索引到主市场交易日
|
||||||
|
index_data = index_data.reindex(primary_dates)
|
||||||
|
|
||||||
|
# 对非A股指数进行前向填充
|
||||||
|
non_a_codes = [c for c in valid_codes if not self._is_china_index(c)]
|
||||||
|
for code in non_a_codes:
|
||||||
|
if code in index_data.columns:
|
||||||
|
index_data[code] = index_data[code].ffill().bfill()
|
||||||
|
|
||||||
|
print(f" 非A股标的: {len(non_a_codes)} 只 (已前向填充)")
|
||||||
|
|
||||||
|
print(f" 时间范围: {index_data.index[0]} ~ {index_data.index[-1]}")
|
||||||
|
print(f" 交易日数: {len(index_data)}")
|
||||||
|
|
||||||
|
# 处理ETF数据
|
||||||
|
if etf_data_list:
|
||||||
|
print(f"\n整理ETF数据(用于收益计算)...")
|
||||||
|
etf_df = pd.concat(etf_data_list, ignore_index=False)
|
||||||
|
etf_df = etf_df.reset_index()
|
||||||
|
if 'index' in etf_df.columns:
|
||||||
|
etf_df = etf_df.rename(columns={'index': 'date'})
|
||||||
|
etf_df['date'] = pd.to_datetime(etf_df['date']).dt.normalize()
|
||||||
|
|
||||||
|
# 透视为宽格式
|
||||||
|
etf_data = etf_df.pivot_table(
|
||||||
|
index='date',
|
||||||
|
columns='code',
|
||||||
|
values='close',
|
||||||
|
aggfunc='first'
|
||||||
|
)
|
||||||
|
|
||||||
|
# 对齐到主市场交易日
|
||||||
|
if tushare_codes:
|
||||||
etf_data = etf_data.reindex(primary_dates)
|
etf_data = etf_data.reindex(primary_dates)
|
||||||
|
|
||||||
# 对每个非主市场代码进行前向填充
|
print(f" ETF价格数据: {len(etf_data.columns)} 只")
|
||||||
yfinance_codes = [c for c in valid_codes if not self._is_china_index(c)]
|
else:
|
||||||
for code in yfinance_codes:
|
# 如果没有ETF数据,使用指数数据代替
|
||||||
if code in etf_data.columns:
|
etf_data = index_data.copy()
|
||||||
# 前向填充:用最近的有效价格填充休市日的数据
|
print(f"\n无ETF映射,使用指数数据代替")
|
||||||
etf_data[code] = etf_data[code].ffill()
|
|
||||||
# 对于开头的NaN,用后向填充
|
|
||||||
etf_data[code] = etf_data[code].bfill()
|
|
||||||
|
|
||||||
print(f" 非主市场标的: {len(yfinance_codes)} 只 (已前向填充)")
|
# 处理ETF净值数据
|
||||||
|
etf_nav_data = None
|
||||||
|
if etf_nav_data_list:
|
||||||
|
print(f"\n整理ETF净值数据(用于溢价率计算)...")
|
||||||
|
nav_df = pd.concat(etf_nav_data_list, ignore_index=False)
|
||||||
|
nav_df = nav_df.reset_index()
|
||||||
|
if 'index' in nav_df.columns:
|
||||||
|
nav_df = nav_df.rename(columns={'index': 'date'})
|
||||||
|
nav_df['date'] = pd.to_datetime(nav_df['date']).dt.normalize()
|
||||||
|
|
||||||
print(f" 时间范围: {etf_data.index[0]} ~ {etf_data.index[-1]}")
|
# 透视为宽格式
|
||||||
print(f" 交易日数: {len(etf_data)}")
|
etf_nav_data = nav_df.pivot_table(
|
||||||
print(f" 有效标的: {len(etf_data.columns)} 只")
|
index='date',
|
||||||
|
columns='code',
|
||||||
|
values='nav',
|
||||||
|
aggfunc='first'
|
||||||
|
)
|
||||||
|
|
||||||
|
# 对齐到主市场交易日,并前向填充缺失值(净值数据通常T+1更新)
|
||||||
|
if tushare_codes:
|
||||||
|
etf_nav_data = etf_nav_data.reindex(primary_dates)
|
||||||
|
etf_nav_data = etf_nav_data.ffill() # 前向填充缺失的净值数据
|
||||||
|
|
||||||
|
print(f" ETF净值数据: {len(etf_nav_data.columns)} 只")
|
||||||
|
|
||||||
# 获取基准数据
|
# 获取基准数据
|
||||||
benchmark_data = self.fetch_single(benchmark_code, start_date, end_date)
|
benchmark_data = self.fetch_single(benchmark_code, start_date, end_date)
|
||||||
if benchmark_data is not None:
|
if benchmark_data is not None:
|
||||||
# 标准化日期索引(无时区,只保留日期部分)
|
|
||||||
benchmark_data.index = pd.to_datetime(benchmark_data.index, utc=True).tz_localize(None).normalize()
|
benchmark_data.index = pd.to_datetime(benchmark_data.index, utc=True).tz_localize(None).normalize()
|
||||||
print(f" ✓ 基准 {benchmark_code}: {len(benchmark_data)} 条")
|
# 对齐到主市场交易日
|
||||||
|
if tushare_codes:
|
||||||
|
benchmark_data = benchmark_data.reindex(primary_dates)
|
||||||
|
print(f"\n✓ 基准 {benchmark_code}: {len(benchmark_data)} 条")
|
||||||
|
|
||||||
return etf_data, benchmark_data, valid_codes
|
return index_data, etf_data, etf_nav_data, benchmark_data, valid_codes
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
"""上下文管理器入口"""
|
"""上下文管理器入口"""
|
||||||
|
|||||||
@@ -80,76 +80,36 @@ def calculate_daily_return(price_series: pd.Series) -> pd.Series:
|
|||||||
|
|
||||||
|
|
||||||
def compute_factors(
|
def compute_factors(
|
||||||
etf_data: pd.DataFrame,
|
index_data: pd.DataFrame,
|
||||||
code_list: list,
|
code_list: list,
|
||||||
n: int = 25,
|
n: int = 25,
|
||||||
factor_type: str = "slope_r2",
|
factor_type: str = "slope_r2",
|
||||||
|
etf_data: pd.DataFrame = None,
|
||||||
|
code_config: dict = None,
|
||||||
) -> tuple[pd.DataFrame, list]:
|
) -> tuple[pd.DataFrame, list]:
|
||||||
"""
|
"""
|
||||||
计算所有指数的因子和日收益率
|
计算所有指数的因子和日收益率(支持指数-ETF双轨数据)
|
||||||
支持长格式数据(混合数据源:Tushare + YFinance)
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
etf_data: DataFrame, 长格式数据,包含 [code, close, source] 列
|
index_data: 指数价格数据(宽格式,用于因子计算)
|
||||||
code_list: 指数代码列表
|
code_list: 指数代码列表
|
||||||
n: 动量/趋势窗口
|
n: 动量/趋势窗口
|
||||||
factor_type: 'momentum' 或 'slope_r2'
|
factor_type: 'momentum' 或 'slope_r2'
|
||||||
|
etf_data: ETF价格数据(宽格式,用于收益计算)
|
||||||
|
code_config: 代码配置字典 {code: {name, etf, market}},用于判断是否为加密货币
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple: (result_df, valid_codes)
|
tuple: (result_df, valid_codes)
|
||||||
|
- result_df: 包含因子得分和日收益率的DataFrame
|
||||||
|
- valid_codes: 有效代码列表
|
||||||
"""
|
"""
|
||||||
# 检查数据格式
|
code_config = code_config or {}
|
||||||
if 'code' in etf_data.columns:
|
|
||||||
# 长格式数据 - 按 code 分别计算因子(旧逻辑,保留兼容)
|
|
||||||
all_factors = []
|
|
||||||
valid_codes = []
|
|
||||||
|
|
||||||
for code in code_list:
|
# 如果没有提供ETF数据,创建一个空的DataFrame
|
||||||
code_data = etf_data[etf_data['code'] == code].copy()
|
if etf_data is None:
|
||||||
if len(code_data) == 0:
|
etf_data = pd.DataFrame()
|
||||||
print(f" ⚠ 跳过 {code}: 不在数据中")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 检查缺失值
|
result = index_data.copy()
|
||||||
null_pct = code_data['close'].isnull().sum() / len(code_data)
|
|
||||||
if null_pct > 0.2:
|
|
||||||
print(f" ⚠ 剔除 {code}: 缺失率 {null_pct:.1%} 过高")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 按日期排序
|
|
||||||
code_data = code_data.sort_index()
|
|
||||||
|
|
||||||
# 计算日收益率和因子
|
|
||||||
code_data[f"日收益率_{code}"] = calculate_daily_return(code_data['close'])
|
|
||||||
|
|
||||||
if factor_type == "momentum":
|
|
||||||
code_data[f"得分_{code}"] = calculate_momentum(code_data['close'], n)
|
|
||||||
elif factor_type == "slope_r2":
|
|
||||||
code_data[f"得分_{code}"] = calculate_slope_r2(code_data['close'], n)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"不支持的因子类型: {factor_type}")
|
|
||||||
|
|
||||||
# 保留需要的列
|
|
||||||
code_data = code_data[[f"日收益率_{code}", f"得分_{code}"]]
|
|
||||||
all_factors.append(code_data)
|
|
||||||
valid_codes.append(code)
|
|
||||||
|
|
||||||
if not all_factors:
|
|
||||||
raise ValueError("没有有效的指数数据")
|
|
||||||
|
|
||||||
# 合并所有因子的数据(按日期内连接 - 只保留所有指数都有数据的日期)
|
|
||||||
result = all_factors[0]
|
|
||||||
for df in all_factors[1:]:
|
|
||||||
result = result.join(df, how='inner')
|
|
||||||
|
|
||||||
# 删除所有得分都是 NaN 的行(即窗口期内的数据)
|
|
||||||
score_cols = [f"得分_{code}" for code in valid_codes]
|
|
||||||
# 只删除完全无法比较的行(所有得分都是NaN)
|
|
||||||
result = result.dropna(subset=score_cols, how='all')
|
|
||||||
|
|
||||||
else:
|
|
||||||
# 宽格式数据(向后兼容)
|
|
||||||
result = etf_data.copy()
|
|
||||||
|
|
||||||
# 过滤掉缺失值过多的指数
|
# 过滤掉缺失值过多的指数
|
||||||
total_rows = len(result)
|
total_rows = len(result)
|
||||||
@@ -165,10 +125,9 @@ def compute_factors(
|
|||||||
else:
|
else:
|
||||||
valid_codes.append(code)
|
valid_codes.append(code)
|
||||||
|
|
||||||
# 对有效指数计算因子
|
# 对有效指数计算因子和收益率
|
||||||
for code in valid_codes:
|
for code in valid_codes:
|
||||||
result[f"日收益率_{code}"] = calculate_daily_return(result[code])
|
# 因子基于指数价格计算
|
||||||
|
|
||||||
if factor_type == "momentum":
|
if factor_type == "momentum":
|
||||||
result[f"得分_{code}"] = calculate_momentum(result[code], n)
|
result[f"得分_{code}"] = calculate_momentum(result[code], n)
|
||||||
elif factor_type == "slope_r2":
|
elif factor_type == "slope_r2":
|
||||||
@@ -176,6 +135,9 @@ def compute_factors(
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"不支持的因子类型: {factor_type}")
|
raise ValueError(f"不支持的因子类型: {factor_type}")
|
||||||
|
|
||||||
|
# 日收益率基于指数价格计算(回测使用指数价格)
|
||||||
|
result[f"日收益率_{code}"] = calculate_daily_return(result[code])
|
||||||
|
|
||||||
# 按得分列做 dropna
|
# 按得分列做 dropna
|
||||||
score_cols = [f"得分_{code}" for code in valid_codes]
|
score_cols = [f"得分_{code}" for code in valid_codes]
|
||||||
result = result.dropna(subset=score_cols)
|
result = result.dropna(subset=score_cols)
|
||||||
@@ -185,5 +147,7 @@ def compute_factors(
|
|||||||
print(f" 窗口天数: {n}")
|
print(f" 窗口天数: {n}")
|
||||||
print(f" 有效指数: {len(valid_codes)}/{len(code_list)}")
|
print(f" 有效指数: {len(valid_codes)}/{len(code_list)}")
|
||||||
print(f" 有效数据: {len(result)} 行")
|
print(f" 有效数据: {len(result)} 行")
|
||||||
|
if etf_data is not index_data:
|
||||||
|
print(f" 使用ETF数据计算收益: ✓")
|
||||||
|
|
||||||
return result, valid_codes
|
return result, valid_codes
|
||||||
|
|||||||
@@ -59,22 +59,38 @@ def main():
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
config['end_date'] = datetime.now().strftime('%Y-%m-%d')
|
config['end_date'] = datetime.now().strftime('%Y-%m-%d')
|
||||||
|
|
||||||
# 从配置中读取 code_list 和 code_name_map
|
# 从配置中读取 code_list(新的配置格式:{代码: {name, etf, market}})
|
||||||
# code_list 现在是一个字典 {代码: 名称}
|
|
||||||
code_list_config = config.get('code_list', {})
|
code_list_config = config.get('code_list', {})
|
||||||
|
|
||||||
|
# 提取代码列表和名称映射
|
||||||
if isinstance(code_list_config, dict):
|
if isinstance(code_list_config, dict):
|
||||||
code_list = list(code_list_config.keys())
|
code_list = list(code_list_config.keys())
|
||||||
code_name_map = code_list_config
|
# 构建 code_name_map: {代码: 名称}
|
||||||
|
code_name_map = {}
|
||||||
|
for code, cfg in code_list_config.items():
|
||||||
|
if isinstance(cfg, dict):
|
||||||
|
code_name_map[code] = cfg.get('name', code)
|
||||||
|
else:
|
||||||
|
# 兼容旧格式
|
||||||
|
code_name_map[code] = cfg
|
||||||
else:
|
else:
|
||||||
# 兼容旧格式(列表)
|
# 兼容旧格式(列表)
|
||||||
code_list = code_list_config
|
code_list = code_list_config
|
||||||
code_name_map = DEFAULT_CODE_NAME_MAP
|
code_name_map = DEFAULT_CODE_NAME_MAP
|
||||||
|
code_list_config = {}
|
||||||
|
|
||||||
benchmark_config = config.get('benchmark', {})
|
benchmark_config = config.get('benchmark', {})
|
||||||
benchmark_name = benchmark_config.get('name', DEFAULT_BENCHMARK_NAME)
|
benchmark_name = benchmark_config.get('name', DEFAULT_BENCHMARK_NAME)
|
||||||
|
|
||||||
print(f"\n配置文件: {args.config}")
|
print(f"\n配置文件: {args.config}")
|
||||||
print(f"候选标的: {len(code_list)} 只")
|
print(f"候选标的: {len(code_list)} 只")
|
||||||
|
|
||||||
|
# 统计ETF映射情况
|
||||||
|
etf_count = sum(1 for cfg in code_list_config.values() if isinstance(cfg, dict) and cfg.get('etf'))
|
||||||
|
crypto_count = sum(1 for cfg in code_list_config.values() if isinstance(cfg, dict) and cfg.get('market') == 'CRYPTO')
|
||||||
|
print(f" - ETF映射: {etf_count} 只")
|
||||||
|
print(f" - 直接交易: {crypto_count} 只(加密货币)")
|
||||||
|
|
||||||
print(f"回测区间: {config['start_date']} ~ {config['end_date']}")
|
print(f"回测区间: {config['start_date']} ~ {config['end_date']}")
|
||||||
print(f"因子类型: {config['factor_type']}")
|
print(f"因子类型: {config['factor_type']}")
|
||||||
print(f"窗口天数: {config['n_days']}")
|
print(f"窗口天数: {config['n_days']}")
|
||||||
@@ -82,8 +98,8 @@ def main():
|
|||||||
print(f"调仓周期: {config['rebalance_days']} 天")
|
print(f"调仓周期: {config['rebalance_days']} 天")
|
||||||
print(f"交易成本: {config['trade_cost']:.2%}")
|
print(f"交易成本: {config['trade_cost']:.2%}")
|
||||||
|
|
||||||
# 更新 config 中的 code_list 为列表格式
|
# 保持 config 中的 code_list 为完整配置格式(用于引擎内部解析)
|
||||||
config['code_list'] = code_list
|
# 不需要修改 config['code_list'],引擎会直接使用原始配置
|
||||||
|
|
||||||
# 创建策略实例
|
# 创建策略实例
|
||||||
strategy = RotationStrategy(config)
|
strategy = RotationStrategy(config)
|
||||||
@@ -119,6 +135,10 @@ def main():
|
|||||||
benchmark_name=benchmark_name,
|
benchmark_name=benchmark_name,
|
||||||
save_path=args.save_path,
|
save_path=args.save_path,
|
||||||
select_num=config["select_num"],
|
select_num=config["select_num"],
|
||||||
|
code_config=code_list_config, # 传入完整配置以显示ETF映射
|
||||||
|
index_data=strategy.index_data, # 传入指数数据
|
||||||
|
etf_price_data=strategy.etf_data, # 传入ETF价格数据
|
||||||
|
etf_nav_data_raw=strategy.etf_nav_data, # 传入ETF净值数据
|
||||||
)
|
)
|
||||||
|
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
|
|||||||
@@ -34,31 +34,40 @@ class RotationStrategy(BacktestStrategy):
|
|||||||
self.backtest_result = None
|
self.backtest_result = None
|
||||||
|
|
||||||
def fetch_data(self) -> pd.DataFrame:
|
def fetch_data(self) -> pd.DataFrame:
|
||||||
"""获取数据"""
|
"""获取数据(支持指数-ETF双轨数据)"""
|
||||||
from config.settings import DEFAULT_BENCHMARK_CODE
|
from config.settings import DEFAULT_BENCHMARK_CODE
|
||||||
|
|
||||||
# 从配置中读取基准代码,或使用默认值
|
# 从配置中读取基准代码,或使用默认值
|
||||||
benchmark_code = self.config.get("benchmark", {}).get("code", DEFAULT_BENCHMARK_CODE)
|
benchmark_code = self.config.get("benchmark", {}).get("code", DEFAULT_BENCHMARK_CODE)
|
||||||
|
|
||||||
# 使用上下文管理器管理 SSH 隧道(如果是 YFinance 数据源)
|
# 获取代码配置(包含 name, etf, market)
|
||||||
|
code_config = self.config.get("code_list", {})
|
||||||
|
|
||||||
|
# 使用上下文管理器管理 SSH 隧道
|
||||||
with self.data_source:
|
with self.data_source:
|
||||||
etf_data, benchmark_data, valid_codes = self.data_source.fetch_all(
|
index_data, etf_data, etf_nav_data, benchmark_data, valid_codes = self.data_source.fetch_all(
|
||||||
self.config["code_list"],
|
code_config,
|
||||||
benchmark_code,
|
benchmark_code,
|
||||||
self.config["start_date"],
|
self.config["start_date"],
|
||||||
self.config["end_date"],
|
self.config["end_date"],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.etf_data = etf_data
|
# 存储数据和配置
|
||||||
|
self.index_data = index_data # 指数数据(用于因子计算)
|
||||||
|
self.etf_data = etf_data # ETF价格数据(用于收益计算)
|
||||||
|
self.etf_nav_data = etf_nav_data # ETF净值数据(用于溢价率计算)
|
||||||
self.benchmark_data = benchmark_data
|
self.benchmark_data = benchmark_data
|
||||||
self.valid_codes = valid_codes
|
self.valid_codes = valid_codes
|
||||||
|
self.code_config = code_config # 代码配置(用于判断市场类型)
|
||||||
|
|
||||||
# 计算因子
|
# 计算因子(传入两套数据:指数数据用于因子,ETF数据用于收益)
|
||||||
factor_data, valid_codes = compute_factors(
|
factor_data, valid_codes = compute_factors(
|
||||||
etf_data,
|
index_data,
|
||||||
valid_codes,
|
valid_codes,
|
||||||
n=self.config["n_days"],
|
n=self.config["n_days"],
|
||||||
factor_type=self.config["factor_type"],
|
factor_type=self.config["factor_type"],
|
||||||
|
etf_data=etf_data, # 传入ETF数据用于收益计算
|
||||||
|
code_config=code_config, # 传入配置以判断加密货币
|
||||||
)
|
)
|
||||||
|
|
||||||
self.data = factor_data
|
self.data = factor_data
|
||||||
|
|||||||
Reference in New Issue
Block a user