Compare commits

..

2 Commits

Author SHA1 Message Date
062f500369 refactor(rotation): 统一与配置文件代码映射和基准指数使用方式
- 将默认代码映射字典和基准指数改为可被策略配置覆盖的形式
- 修改配置文件rotation.yaml中候选池配置从列表变为代码与名称的字典映射
- 在运行脚本中加载配置时支持字典格式的code_list和benchmark,兼容旧格式列表
- 更新回测策略引擎通过配置动态获取基准指数代码
- 打印输出和函数调用中统一使用从配置加载的代码映射和基准名称数据
2026-03-19 00:33:06 +08:00
9b154a1a25 feat(rotation): 增加最新调仓信号展示功能
- 配置中取消固定end_date,改为默认使用当前日期
- 添加打印最新调仓信号的功能,显示持仓明细及调出品种
- 在报告生成流程中调用最新调仓信号打印函数
- 图表展示中新增最新调仓信号表格,支持颜色区分调入、调出和维持
- 优化报告图表布局,调整画布高度适应信号表内容
- 删除无用test.py测试脚本及相关冗余代码
2026-03-19 00:22:25 +08:00
6 changed files with 318 additions and 136 deletions

View File

@@ -55,8 +55,8 @@ def get_db_config() -> dict:
} }
# ==================== 代码映射 ==================== # ==================== 代码映射(默认,可被策略配置覆盖)====================
CODE_NAME_MAP = { DEFAULT_CODE_NAME_MAP = {
# 宽基 # 宽基
"000300.SH": "沪深300", "000300.SH": "沪深300",
"000905.SH": "中证500", "000905.SH": "中证500",
@@ -95,6 +95,6 @@ CODE_NAME_MAP = {
"399702.SZ": "中证国债指数", "399702.SZ": "中证国债指数",
} }
# 基准指数 # 基准指数(默认,可被策略配置覆盖)
BENCHMARK_CODE = "000300.SH" DEFAULT_BENCHMARK_CODE = "000300.SH"
BENCHMARK_NAME = "沪深300指数" DEFAULT_BENCHMARK_NAME = "沪深300指数"

View File

@@ -1,37 +1,43 @@
# ETF轮动策略配置 # ETF轮动策略配置
# ==================== 候选池配置 ==================== # ==================== 候选池配置 ====================
# A股全行业指数代码列表Tushare格式XXXXXX.SH / XXXXXX.SZ # A股全行业指数配置Tushare格式XXXXXX.SH / XXXXXX.SZ
# 格式: {代码: 名称}
code_list: code_list:
# 宽基指数 # 宽基指数
- "000300.SH" # 沪深300大盘蓝筹 "000300.SH": "沪深300"
- "000905.SH" # 中证500中盘成长 "000905.SH": "中证500"
- "000852.SH" # 中证1000(小盘) "000852.SH": "中证1000"
- "399006.SZ" # 创业板指(创业板龙头) "399006.SZ": "创业板指"
- "000015.SH" # 上证红利(高股息价值) "000015.SH": "上证红利"
# 金融 # 金融
- "399986.SZ" # 中证银行 "399986.SZ": "中证银行"
# 消费 # 消费
- "399997.SZ" # 中证白酒 "399997.SZ": "中证白酒"
# 医药健康 # 医药健康
- "399989.SZ" # 中证医疗 "399989.SZ": "中证医疗"
# 科技信息 # 科技信息
- "000935.SH" # 中证信息技术 "000935.SH": "中证信息"
# 新能源 # 新能源
- "399976.SZ" # 新能源 "399976.SZ": "新能源车"
# 周期资源 # 周期资源
- "399395.SZ" # 国证有色金属 "399395.SZ": "国证有色"
- "399998.SZ" # 中证煤炭 "399998.SZ": "中证煤炭"
- "399813.SZ" # 细分化工 "399813.SZ": "细分化工"
- "000937.SH" # 中证能源 "000937.SH": "中证能源"
# 其他行业 # 其他行业
- "399967.SZ" # 中证军工 "399967.SZ": "中证军工"
- "000949.SH" # 中证农业 "000949.SH": "中证农业"
- "399702.SZ" # 中证国债指数 "399702.SZ": "国债指数"
# 基准指数配置
benchmark:
code: "000300.SH"
name: "沪深300指数"
# ==================== 回测参数 ==================== # ==================== 回测参数 ====================
start_date: "2018-01-01" start_date: "2018-01-01"
end_date: "2025-03-17" # end_date: "2025-03-17"
# ==================== 因子参数 ==================== # ==================== 因子参数 ====================
# 动量/趋势窗口期(天数) # 动量/趋势窗口期(天数)

View File

@@ -20,7 +20,7 @@ sys.path.insert(0, str(project_root))
from strategies.rotation.engine import RotationStrategy from strategies.rotation.engine import RotationStrategy
from strategies.rotation.portfolio import track_positions, save_trades from strategies.rotation.portfolio import track_positions, save_trades
from strategies.rotation.report import generate_performance_report from strategies.rotation.report import generate_performance_report
from config.settings import CODE_NAME_MAP, BENCHMARK_NAME from config.settings import DEFAULT_CODE_NAME_MAP, DEFAULT_BENCHMARK_NAME
def load_config(config_path: str) -> dict: def load_config(config_path: str) -> dict:
@@ -53,8 +53,28 @@ def main():
# 加载配置 # 加载配置
config = load_config(args.config) config = load_config(args.config)
# 如果未设置 end_date默认使用最新日期
if not config.get('end_date'):
from datetime import datetime
config['end_date'] = datetime.now().strftime('%Y-%m-%d')
# 从配置中读取 code_list 和 code_name_map
# code_list 现在是一个字典 {代码: 名称}
code_list_config = config.get('code_list', {})
if isinstance(code_list_config, dict):
code_list = list(code_list_config.keys())
code_name_map = code_list_config
else:
# 兼容旧格式(列表)
code_list = code_list_config
code_name_map = DEFAULT_CODE_NAME_MAP
benchmark_config = config.get('benchmark', {})
benchmark_name = benchmark_config.get('name', DEFAULT_BENCHMARK_NAME)
print(f"\n配置文件: {args.config}") print(f"\n配置文件: {args.config}")
print(f"候选标的: {len(config['code_list'])}") print(f"候选标的: {len(code_list)}")
print(f"回测区间: {config['start_date']} ~ {config['end_date']}") print(f"回测区间: {config['start_date']} ~ {config['end_date']}")
print(f"因子类型: {config['factor_type']}") print(f"因子类型: {config['factor_type']}")
print(f"窗口天数: {config['n_days']}") print(f"窗口天数: {config['n_days']}")
@@ -62,6 +82,9 @@ def main():
print(f"调仓周期: {config['rebalance_days']}") print(f"调仓周期: {config['rebalance_days']}")
print(f"交易成本: {config['trade_cost']:.2%}") print(f"交易成本: {config['trade_cost']:.2%}")
# 更新 config 中的 code_list 为列表格式
config['code_list'] = code_list
# 创建策略实例 # 创建策略实例
strategy = RotationStrategy(config) strategy = RotationStrategy(config)
@@ -79,7 +102,7 @@ def main():
trades_df, summary_df = track_positions( trades_df, summary_df = track_positions(
backtest_result, backtest_result,
code_name_map=CODE_NAME_MAP, code_name_map=code_name_map,
select_num=config["select_num"], select_num=config["select_num"],
) )
save_trades(trades_df, summary_df, save_path=args.save_path) save_trades(trades_df, summary_df, save_path=args.save_path)
@@ -92,8 +115,8 @@ def main():
metrics = generate_performance_report( metrics = generate_performance_report(
backtest_result, backtest_result,
strategy.valid_codes, strategy.valid_codes,
code_name_map=CODE_NAME_MAP, code_name_map=code_name_map,
benchmark_name=BENCHMARK_NAME, benchmark_name=benchmark_name,
save_path=args.save_path, save_path=args.save_path,
select_num=config["select_num"], select_num=config["select_num"],
) )

View File

@@ -25,11 +25,14 @@ class RotationStrategy(BacktestStrategy):
def fetch_data(self) -> pd.DataFrame: def fetch_data(self) -> pd.DataFrame:
"""获取数据""" """获取数据"""
from config.settings import BENCHMARK_CODE from config.settings import DEFAULT_BENCHMARK_CODE
# 从配置中读取基准代码,或使用默认值
benchmark_code = self.config.get("benchmark", {}).get("code", DEFAULT_BENCHMARK_CODE)
etf_data, benchmark_data, valid_codes = self.data_source.fetch_all( etf_data, benchmark_data, valid_codes = self.data_source.fetch_all(
self.config["code_list"], self.config["code_list"],
BENCHMARK_CODE, benchmark_code,
self.config["start_date"], self.config["start_date"],
self.config["end_date"], self.config["end_date"],
) )

View File

@@ -77,6 +77,9 @@ def generate_performance_report(
print(f' {"最大回撤区间":<22} {str(s_dd_start.date()):>10} ~ {str(s_dd_end.date())}') print(f' {"最大回撤区间":<22} {str(s_dd_start.date()):>10} ~ {str(s_dd_end.date())}')
print("=" * 70) print("=" * 70)
# 打印最新调仓信号
_print_latest_signal(backtest_result, code_list, code_name_map, select_num)
# 绘制图表 # 绘制图表
_plot_report_chart( _plot_report_chart(
backtest_result, code_list, code_name_map, backtest_result, code_list, code_name_map,
@@ -99,6 +102,178 @@ def generate_performance_report(
} }
def _print_latest_signal(backtest_result: pd.DataFrame, code_list: list, code_name_map: dict, select_num: int):
"""打印最新调仓信号"""
latest = _extract_latest_positions(backtest_result, code_list, code_name_map, select_num)
signal_date_str = latest["signal_date"].strftime("%Y-%m-%d")
print("\n")
print("=" * 100)
print(" 最新调仓信号 (下一交易日执行)")
print("=" * 100)
print(f" 数据截止: {signal_date_str}")
print()
# 表头
print(f' {"品种名称":<8} {"代码":>10} {"仓位":>6} {"得分":>8} {"进场日期":>12} {"进场价":>10} {"最新价":>10} {"操作":>6} {"持有天数":>8} {"盈亏":>10}')
print(" " + "-" * 115)
# 下期持仓(调入/维持)
for pos in latest["positions"]:
pnl_str = f'{pos["pnl"]:>+9.2%}' if pos["pnl"] is not None else ''
days_str = f'{pos["holding_days"]:>7}' if pos["holding_days"] is not None else ''
entry_str = f'{pos["entry_price"]:>10.2f}' if pos["entry_price"] is not None else ''
entry_date_str = pos["entry_date"].strftime("%Y-%m-%d") if pos.get("entry_date") else ''
score_str = f'{pos["score"]:>8.2f}' if pos["score"] is not None else ''
flag = '' if pos["action"] == "调入" else ' '
print(f' {pos["name"]:<8} {pos["code"]:>10} {pos["weight"]:>6.0%} {score_str} {entry_date_str:>12} {entry_str} {pos["current_price"]:>10.2f} {flag}{pos["action"]:>4} {days_str} {pnl_str}')
# 需调出的品种
if latest["exit_positions"]:
print()
print(" 需调出:")
for pos in latest["exit_positions"]:
pnl_str = f'{pos["pnl"]:>+9.2%}' if pos["pnl"] is not None else ''
days_str = f'{pos["holding_days"]:>7}' if pos["holding_days"] is not None else ''
entry_str = f'{pos["entry_price"]:>10.2f}' if pos["entry_price"] is not None else ''
entry_date_str = pos["entry_date"].strftime("%Y-%m-%d") if pos.get("entry_date") else ''
score_str = '' # 调出品种无得分
print(f' {pos["name"]:<8} {pos["code"]:>10} {pos["weight"]:>6.0%} {score_str} {entry_date_str:>12} {entry_str} {pos["current_price"]:>10.2f} ▼调出 {days_str} {pnl_str}')
print("=" * 120)
def _extract_latest_positions(backtest_result: pd.DataFrame, code_list: list, code_name_map: dict, select_num: int) -> dict:
"""提取最新持仓和下期调仓建议"""
last_date = backtest_result.index[-1]
last_row = backtest_result.iloc[-1]
# 当前持仓
current_signal = last_row["信号"]
if select_num == 1:
current_codes = [current_signal]
else:
current_codes = current_signal.split(",")
# 下期建议
score_cols = [f"得分_{code}" for code in code_list if f"得分_{code}" in backtest_result.columns]
scores = pd.to_numeric(last_row[score_cols], errors="coerce")
top_n = scores.nlargest(select_num)
next_codes = [c.replace("得分_", "") for c in top_n.index]
# 计算持仓信息
positions_info = []
weight = 1.0 / select_num
for code in next_codes:
name = code_name_map.get(code, code)
action = "维持" if code in current_codes else "调入"
# 获取当前价格和得分
current_price = last_row.get(code, 0)
score = scores.get(f"得分_{code}", None)
# 计算持仓信息(如果是维持的仓位)
entry_date = None
entry_price = None
holding_days = None
pnl = None
if action == "维持":
# 找到该标的最近一次连续持仓的起始日期
signal_series = backtest_result["信号"]
mask = signal_series == code if select_num == 1 else signal_series.str.contains(code, regex=False, na=False)
# 找到连续持仓段(从后往前找)
is_holding = mask.values
dates = backtest_result.index
# 从最后一天往前遍历,找到连续持仓的起始点
entry_date = None
for i in range(len(is_holding) - 1, -1, -1):
if is_holding[i]:
entry_date = dates[i]
else:
break
if entry_date is not None:
entry_price = backtest_result.loc[entry_date, code]
holding_days = (last_date - entry_date).days
if entry_price and entry_price != 0:
pnl = current_price / entry_price - 1
positions_info.append({
"code": code,
"name": name,
"weight": weight,
"score": score,
"action": action,
"current_price": current_price,
"entry_date": entry_date,
"entry_price": entry_price,
"holding_days": holding_days,
"pnl": pnl,
})
# 需调出的品种信息
exit_positions = []
for code in current_codes:
if code not in next_codes:
name = code_name_map.get(code, code)
current_price = last_row.get(code, 0)
# 计算调出品种的持仓信息(最近一次连续持仓)
signal_series = backtest_result["信号"]
mask = signal_series == code if select_num == 1 else signal_series.str.contains(code, regex=False, na=False)
# 找到连续持仓段(从后往前找)
is_holding = mask.values
dates = backtest_result.index
entry_price = None
holding_days = None
pnl = None
# 从最后一天往前遍历,找到连续持仓的起始点
entry_date = None
for i in range(len(is_holding) - 1, -1, -1):
if is_holding[i]:
entry_date = dates[i]
else:
break
if entry_date is not None:
entry_price = backtest_result.loc[entry_date, code]
holding_days = (last_date - entry_date).days
if entry_price and entry_price != 0:
pnl = current_price / entry_price - 1
exit_positions.append({
"code": code,
"name": name,
"weight": weight,
"score": None, # 调出品种无得分
"action": "调出",
"current_price": current_price,
"entry_date": entry_date,
"entry_price": entry_price,
"holding_days": holding_days,
"pnl": pnl,
})
return {
"signal_date": last_date,
"current_codes": current_codes,
"next_codes": next_codes,
"positions": positions_info,
"exit_positions": exit_positions,
}
def _plot_report_chart( def _plot_report_chart(
backtest_result: pd.DataFrame, backtest_result: pd.DataFrame,
code_list: list, code_list: list,
@@ -114,10 +289,86 @@ def _plot_report_chart(
strategy_nav = backtest_result["轮动策略净值"] strategy_nav = backtest_result["轮动策略净值"]
benchmark_nav = backtest_result["基准净值"] benchmark_nav = backtest_result["基准净值"]
fig, axes = plt.subplots(3, 1, figsize=(14, 12)) # 提取最新调仓信息
latest = _extract_latest_positions(backtest_result, code_list, code_name_map, select_num)
# 计算表格行数
n_table_rows = len(latest["positions"]) + len(latest["exit_positions"])
table_height = max(1.5, 0.5 + n_table_rows * 0.28)
fig = plt.figure(figsize=(14, 10 + table_height + 8))
gs = fig.add_gridspec(4, 1, height_ratios=[table_height, 3, 1, 1.2], hspace=0.35)
# 面板0: 最新调仓信号表
ax0 = fig.add_subplot(gs[0])
ax0.axis("off")
signal_date_str = latest["signal_date"].strftime("%Y-%m-%d")
ax0.set_title(f"最新调仓信号 (数据截止: {signal_date_str},下一交易日执行)", fontsize=14, fontweight="bold", loc="left", pad=15)
# 构建表格数据
table_data = []
col_labels = ["品种名称", "代码", "仓位", "得分", "进场日期", "进场价", "最新价", "操作", "持有天数", "盈亏"]
# 下期持仓(调入/维持)
for pos in latest["positions"]:
pnl_str = f'{pos["pnl"]:+.2%}' if pos["pnl"] is not None else ""
days_str = f'{pos["holding_days"]}' if pos["holding_days"] is not None else ""
entry_str = f'{pos["entry_price"]:.2f}' if pos["entry_price"] is not None else ""
entry_date_str = pos["entry_date"].strftime("%m-%d") if pos.get("entry_date") else ""
score_str = f'{pos["score"]:.2f}' if pos["score"] is not None else ""
table_data.append([
pos["name"], pos["code"], f'{pos["weight"]:.0%}',
score_str, entry_date_str, entry_str, f'{pos["current_price"]:.2f}',
pos["action"], days_str, pnl_str
])
# 需调出的品种
for pos in latest["exit_positions"]:
pnl_str = f'{pos["pnl"]:+.2%}' if pos["pnl"] is not None else ""
days_str = f'{pos["holding_days"]}' if pos["holding_days"] is not None else ""
entry_str = f'{pos["entry_price"]:.2f}' if pos["entry_price"] is not None else ""
entry_date_str = pos["entry_date"].strftime("%m-%d") if pos.get("entry_date") else ""
score_str = "" # 调出品种无得分
table_data.append([
pos["name"], pos["code"], f'{pos["weight"]:.0%}',
score_str, entry_date_str, entry_str, f'{pos["current_price"]:.2f}',
"调出", days_str, pnl_str
])
if table_data:
table = ax0.table(
cellText=table_data,
colLabels=col_labels,
loc="upper center",
cellLoc="center",
colWidths=[0.08, 0.08, 0.05, 0.06, 0.06, 0.07, 0.07, 0.05, 0.06, 0.07],
)
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1, 2.0)
# 表头深色
for j in range(len(col_labels)):
table[0, j].set_facecolor("#2C3E50")
table[0, j].set_text_props(color="white", fontweight="bold")
# 数据行按操作着色
for i in range(len(table_data)):
action = table_data[i][3]
if action == "调入":
color = "#d4edda" # 绿色
elif action == "调出":
color = "#f8d7da" # 红色
else:
color = "#fff3cd" # 黄色
for j in range(len(col_labels)):
table[i + 1, j].set_facecolor(color)
# 面板1: 净值曲线 # 面板1: 净值曲线
ax1 = axes[0] ax1 = fig.add_subplot(gs[1])
ax1.plot(strategy_nav.index, strategy_nav.values, ax1.plot(strategy_nav.index, strategy_nav.values,
label="轮动策略", linewidth=2, color="#E74C3C") label="轮动策略", linewidth=2, color="#E74C3C")
ax1.plot(benchmark_nav.index, benchmark_nav.values, ax1.plot(benchmark_nav.index, benchmark_nav.values,
@@ -139,7 +390,7 @@ def _plot_report_chart(
ax1.set_yscale("log") ax1.set_yscale("log")
# 面板2: 回撤曲线 # 面板2: 回撤曲线
ax2 = axes[1] ax2 = fig.add_subplot(gs[2])
cummax = strategy_nav.cummax() cummax = strategy_nav.cummax()
drawdown = (strategy_nav - cummax) / cummax drawdown = (strategy_nav - cummax) / cummax
ax2.fill_between(drawdown.index, drawdown.values, 0, alpha=0.5, color="#E74C3C") ax2.fill_between(drawdown.index, drawdown.values, 0, alpha=0.5, color="#E74C3C")
@@ -148,7 +399,7 @@ def _plot_report_chart(
ax2.grid(True, alpha=0.3) ax2.grid(True, alpha=0.3)
# 面板3: 持仓分布 # 面板3: 持仓分布
ax3 = axes[2] ax3 = fig.add_subplot(gs[3])
signal_series = backtest_result["信号"] signal_series = backtest_result["信号"]
for i, code in enumerate(code_list): for i, code in enumerate(code_list):
name = code_name_map.get(code, code) name = code_name_map.get(code, code)
@@ -168,7 +419,6 @@ def _plot_report_chart(
ax3.set_yticklabels(ylabels, fontsize=7) ax3.set_yticklabels(ylabels, fontsize=7)
ax3.grid(True, alpha=0.3) ax3.grid(True, alpha=0.3)
plt.tight_layout()
chart_path = f"{save_path}_chart.png" chart_path = f"{save_path}_chart.png"
plt.savefig(chart_path, dpi=150, bbox_inches="tight") plt.savefig(chart_path, dpi=150, bbox_inches="tight")
plt.close() plt.close()

100
test.py
View File

@@ -1,100 +0,0 @@
import pandas as pd
import numpy as np
import vectorbt as vbt
from numba import njit
import vectorbt as vbt
import pandas as pd
import numpy as np
import pandas as pd
from loguru import logger
from chart import resample_data, QuantChart
from db_config import DatabaseManager, DatabaseConfig
from vectorbt.base.reshape_fns import to_2d_array
def get_kline(code: str) -> list:
"""
获取所有指数代码
:return:
"""
db_config = DatabaseConfig()
logger.info(f"数据库连接: {db_config.connection_string}")
db_manager = DatabaseManager(db_config)
sql = f"SELECT date as time, open, high, low, close, volume FROM public.index_kline where code='{code}' order by date;"
res = db_manager.execute_query(sql)
data_list = [dict(item) for item in res]
df = pd.DataFrame(data_list)
df["time"] = pd.to_datetime(df["time"])
num_cols = ["open", "high", "low", "close", "volume"]
for col in num_cols:
if col in df.columns:
df[col] = pd.to_numeric(df[col], errors="coerce").astype(float)
return df
symbol = "399998"
timeframe = "1D"
df = get_kline(code=symbol)
df = resample_data(df, timeframe)
df.rename(columns={'time': 'date'}, inplace=True)
print(df.head())
if 'date' in df.columns:
df = df.set_index('date')
price = df['close']
# 2. 计算90天滚动波动率(年化)
returns = price.pct_change()
volatility_90d = returns.rolling(window=90, min_periods=90).std() * np.sqrt(365)
# 3. 计算波动率倒数作为权重
inv_vol = 1 / volatility_90d
# 标准化权重(可选,使其更易解释)
inv_vol_normalized = inv_vol / inv_vol.rolling(window=252).mean()
# 4. 创建每周重新平衡的信号
# 获取每周最后一个交易日
weekly_rebalance = pd.Series(False, index=price.index)
weekly_last_days = price.resample('W').last().index
for date in weekly_last_days:
# 找到最接近的交易日
idx = price.index.get_indexer([date], method='ffill')[0]
if idx >= 0:
weekly_rebalance.iloc[idx] = True
# 5. 定义订单函数
@njit
def order_func_nb(c, inv_vol_arr, rebalance_arr):
# 获取当前的波动率倒数权重
inv_vol_now = vbt.nb.flex_select_auto_nb(inv_vol_arr, c.i, c.col, False)
rebalance_now = vbt.nb.flex_select_auto_nb(rebalance_arr, c.i, c.col, False)
# 只在重新平衡日调整仓位
if not rebalance_now or np.isnan(inv_vol_now):
return vbt.nb.order_nothing_nb()
# 目标仓位 = 总价值 * 波动率倒数权重
# 这里使用 TargetPercent 类型,权重越高仓位越大
target_percent = min(inv_vol_now, 1.0) # 限制最大100%仓位
return vbt.nb.order_nb(
size=target_percent,
size_type=vbt.SizeType.TargetPercent,
direction=vbt.Direction.LongOnly
)
# 6. 运行回测
pf = vbt.Portfolio.from_order_func(
price,
order_func_nb,
to_2d_array(inv_vol_normalized),
to_2d_array(weekly_rebalance),
init_cash=100,
freq='1D'
)
# 7. 查看结果
print(pf.stats())
print(f"\n波动率倒数策略 vs 买入持有:")
print(f"总收益率: {pf.total_return():.2%}")