diff --git a/rotation/simple_rotation.py b/rotation/simple_rotation.py index f7fd7aa..d495cad 100644 --- a/rotation/simple_rotation.py +++ b/rotation/simple_rotation.py @@ -1312,8 +1312,8 @@ class SimpleRotationStrategy: n_rows = len(positions_info) + len(exit_positions) + len(unselected_positions) signal_h = max(1.5, 0.5 + n_rows * 0.35) - fig = plt.figure(figsize=(16, 10 + signal_h + 1.2 + 8)) - gs = fig.add_gridspec(5, 1, height_ratios=[signal_h, 1.2, 3, 1, 1.2], hspace=0.35) + fig = plt.figure(figsize=(16, 10 + signal_h + 1.2 + 1.5 + 8)) + gs = fig.add_gridspec(6, 1, height_ratios=[signal_h, 1.2, 1.5, 3, 1, 1.2], hspace=0.35) # Panel 0: Full asset ranking table ax0 = fig.add_subplot(gs[0]) @@ -1394,51 +1394,123 @@ class SimpleRotationStrategy: ptbl[1, j].set_facecolor("#d4edda") ptbl[2, j].set_facecolor("#cce5ff") - # Panel 2: NAV curves + # Panel 2: Monthly returns heatmap (year × month matrix) ax2 = fig.add_subplot(gs[2]) - ax2.plot(strategy_nav.index, strategy_nav.values, + ax2.axis('off') + ax2.set_title('月度收益矩阵 (%)', fontsize=14, fontweight='bold', loc='left', pad=10) + # Resample to monthly NAV (last business day) + monthly_nav = strategy_nav.resample('ME').last().dropna() + monthly_ret = monthly_nav.pct_change().dropna() + # Build year-month matrix + monthly_ret_pct = (monthly_ret * 100).round(2) + year_month_df = pd.DataFrame({ + 'year': monthly_ret_pct.index.year, + 'month': monthly_ret_pct.index.month, + 'ret': monthly_ret_pct.values, + }) + years = sorted(year_month_df['year'].unique()) + month_labels = [f'{m}月' for m in range(1, 13)] + col_labels = month_labels + ['年度'] + # Build table data + table_data = [] + for yr in years: + row = [] + yr_total = 0.0 + yr_count = 0 + for m in range(1, 13): + mask = (year_month_df['year'] == yr) & (year_month_df['month'] == m) + vals = year_month_df.loc[mask, 'ret'].values + if len(vals) > 0: + v = float(vals[0]) + row.append(v) + yr_total = (1 + yr_total / 100) * (1 + v / 100) * 100 - 100 + yr_count += 1 + else: + row.append(None) + row.append(round(yr_total, 2) if yr_count > 0 else None) + table_data.append(row) + + if table_data: + cell_text = [] + for row in table_data: + cell_text.append([f'{v:+.1f}' if v is not None else '—' for v in row]) + tbl = ax2.table(cellText=cell_text, rowLabels=[str(y) for y in years], + colLabels=col_labels, loc='center', cellLoc='center', bbox=[0, 0, 1, 1]) + tbl.auto_set_font_size(False) + tbl.set_fontsize(9) + tbl.scale(1, 1.6) + # Style header (colLabels row) + for j in range(len(col_labels)): + tbl[0, j].set_facecolor('#2C3E50') + tbl[0, j].set_text_props(color='white', fontweight='bold') + # Color cells based on value (sqrt mapping for better visibility of small returns) + max_abs = max(abs(v) for row in table_data for v in row if v is not None) or 1 + for i, row in enumerate(table_data): + for j, v in enumerate(row): + if v is not None: + intensity = min((abs(v) / max_abs) ** 0.5, 1.0) + if v >= 0: + r, g, b = 1.0, 1.0 - intensity * 0.55, 1.0 - intensity * 0.55 + else: + r, g, b = 1.0 - intensity * 0.55, 1.0, 1.0 - intensity * 0.55 + tbl[i + 1, j].set_facecolor((r, g, b)) + else: + tbl[i + 1, j].set_facecolor('#f5f5f5') + # Annual column: stronger coloring + ann_v = row[-1] + if ann_v is not None: + intensity = min((abs(ann_v) / max_abs) ** 0.5, 1.0) + if ann_v >= 0: + r, g, b = 1.0, 1.0 - intensity * 0.7, 1.0 - intensity * 0.7 + else: + r, g, b = 1.0 - intensity * 0.7, 1.0, 1.0 - intensity * 0.7 + tbl[i + 1, len(col_labels) - 1].set_facecolor((r, g, b)) + + # Panel 3: NAV curves + ax3 = fig.add_subplot(gs[3]) + ax3.plot(strategy_nav.index, strategy_nav.values, label="轮动策略", linewidth=2, color="#E74C3C") if benchmark_nav is not None: - ax2.plot(benchmark_nav.index, benchmark_nav.values, + ax3.plot(benchmark_nav.index, benchmark_nav.values, label=self.benchmark_name, linewidth=1.5, color="#3498DB", alpha=0.8) colors = plt.cm.tab20.colors for i, code in enumerate(self.signal_codes): if code in asset_navs: cfg = code_config.get(code, {}) lbl = cfg.get('name', code) if i < 10 else None - ax2.plot(asset_navs[code].index, asset_navs[code].values, + ax3.plot(asset_navs[code].index, asset_navs[code].values, label=lbl, linewidth=0.8, alpha=0.4, color=colors[i % len(colors)]) - ax2.set_title("ETF轮动策略 - 净值曲线", fontsize=16, fontweight="bold") - ax2.set_ylabel("净值") - ax2.legend(loc="upper left", fontsize=8, ncol=2) - ax2.grid(True, alpha=0.3) - ax2.set_yscale("log") - - # Panel 3: Drawdown - ax3 = fig.add_subplot(gs[3]) - drawdown = (strategy_nav - s_peak) / s_peak - ax3.fill_between(drawdown.index, drawdown.values, 0, alpha=0.5, color="#E74C3C") - ax3.set_title("策略回撤", fontsize=12) - ax3.set_ylabel("回撤") + ax3.set_title("ETF轮动策略 - 净值曲线", fontsize=16, fontweight="bold") + ax3.set_ylabel("净值") + ax3.legend(loc="upper left", fontsize=8, ncol=2) ax3.grid(True, alpha=0.3) + ax3.set_yscale("log") - # Panel 4: Holdings distribution + # Panel 4: Drawdown ax4 = fig.add_subplot(gs[4]) + drawdown = (strategy_nav - s_peak) / s_peak + ax4.fill_between(drawdown.index, drawdown.values, 0, alpha=0.5, color="#E74C3C") + ax4.set_title("策略回撤", fontsize=12) + ax4.set_ylabel("回撤") + ax4.grid(True, alpha=0.3) + + # Panel 5: Holdings distribution + ax5 = fig.add_subplot(gs[5]) holdings_series = df['holdings'] for i, code in enumerate(self.signal_codes): cfg = code_config.get(code, {}) name = cfg.get('name', code) mask = holdings_series.apply(lambda h: code in h) if mask.any(): - ax4.fill_between(mask.index, i, i + 0.8, + ax5.fill_between(mask.index, i, i + 0.8, where=mask, alpha=0.7, color=colors[i % len(colors)], label=name) ylabels = [code_config.get(c, {}).get('name', c) for c in self.signal_codes] - ax4.set_title("每日持仓分布", fontsize=12) - ax4.set_yticks(range(len(ylabels))) - ax4.set_yticklabels(ylabels, fontsize=7) - ax4.grid(True, alpha=0.3) + ax5.set_title("每日持仓分布", fontsize=12) + ax5.set_yticks(range(len(ylabels))) + ax5.set_yticklabels(ylabels, fontsize=7) + ax5.grid(True, alpha=0.3) chart_path = output_dir / 'simple_rotation_report.png' plt.savefig(str(chart_path), dpi=150, bbox_inches="tight")