#!/usr/bin/env python3
"""
select_num A/B experiment: compare Top-1 / Top-2 / Top-3 performance.

Runs SimpleRotationStrategy once per ``select_num`` value, prints a metric
comparison table, renders comparison charts and saves the raw metrics as
JSON under ``results/experiment_select_num/``.

Usage:
    python rotation/experiment_select_num.py
"""
import os
import sys
import yaml
import json
import tempfile
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime

PROJECT_ROOT = Path(__file__).parent.parent
# Make the project package importable when this file is run as a script.
sys.path.insert(0, str(PROJECT_ROOT))

from rotation.simple_rotation import SimpleRotationStrategy

# select_num values compared by this experiment (shared by the run loop and
# the NAV chart so the two can never drift apart).
SELECT_NUMS = (1, 2, 3)

# NAV line colours, indexed by select_num - 1 (cycles for larger values);
# the first three match the original Top-1/2/3 colours.
_NAV_PALETTE = ['#E74C3C', '#3498DB', '#2ECC71', '#F39C12', '#9B59B6']


def run_with_select_num(config_path: str, select_num: int, output_dir: Path) -> dict:
    """Run the strategy once with ``rotation.select_num`` overridden.

    The base YAML config is loaded, ``cfg['rotation']['select_num']`` is
    replaced, and the modified config is written to
    ``output_dir/config_select_<select_num>.yaml`` (kept on disk so each run
    can be reproduced). When the run returns a truthy result, the strategy's
    results are exported to ``output_dir/select_<select_num>/``.

    Args:
        config_path: Path to the base YAML config.
        select_num: Value injected as ``rotation.select_num`` (reported as
            "Top-<select_num>" in the comparison output).
        output_dir: Experiment output directory; must already exist.

    Returns:
        ``result.get('metrics', {})`` from the run, or ``{}`` when the run
        returned a falsy result.

    Raises:
        KeyError: If the base config has no ``rotation`` section.
    """
    print(f"\n{'='*60}")
    print(f" 实验: select_num = {select_num}")
    print(f"{'='*60}\n")

    # Load the base config, override select_num, write a per-run copy.
    with open(config_path, 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f)
    cfg['rotation']['select_num'] = select_num

    run_config_path = output_dir / f'config_select_{select_num}.yaml'
    with open(run_config_path, 'w', encoding='utf-8') as f:
        yaml.dump(cfg, f, default_flow_style=False, allow_unicode=True)

    strategy = SimpleRotationStrategy(config_path=str(run_config_path))
    result = strategy.run()
    if not result:
        return {}

    # Export this run's artefacts into its own sub-directory.
    sub_dir = output_dir / f'select_{select_num}'
    sub_dir.mkdir(parents=True, exist_ok=True)
    strategy.export_results(output_dir=str(sub_dir))
    return result.get('metrics', {})


def _format_metric(fmt: str, value) -> str:
    """Format one table cell, tolerating missing or non-numeric metrics.

    ``None`` (metric absent or null) renders as ``'N/A'``. A value the format
    spec rejects (e.g. a string under ``'{:.2%}'``) falls back to
    ``str(value)`` instead of aborting the whole table.
    """
    if value is None:
        return 'N/A'
    try:
        return fmt.format(value)
    except (TypeError, ValueError):
        return str(value)


def print_comparison(all_metrics: dict):
    """Print a side-by-side metric table with one column per select_num.

    Args:
        all_metrics: Mapping ``select_num -> metrics dict`` (as returned by
            ``run_with_select_num``). Columns are ordered by select_num.
    """
    nums = sorted(all_metrics)

    print(f"\n\n{'='*80}")
    print(" select_num 实验对比结果")
    print(f"{'='*80}\n")

    header = f"{'指标':<16}" + ''.join(f"{'Top-'+str(n):>12}" for n in nums)
    print(header)
    print("-" * (16 + 12 * len(nums)))

    # (row label, metrics key, format spec)
    rows = [
        ('累计收益', 'total_return', '{:.2%}'),
        ('年化收益', 'annual_return', '{:.2%}'),
        ('最大回撤', 'max_drawdown', '{:.2%}'),
        ('夏普比率', 'sharpe_ratio', '{:.2f}'),
        ('Calmar比率', 'calmar_ratio', '{:.2f}'),
        ('日胜率', 'win_rate', '{:.2%}'),
        ('交易日数', 'n_days', '{}'),
        ('调仓次数', 'rebalance_count', '{}'),
    ]
    for label, key, fmt in rows:
        cells = ''.join(
            f"{_format_metric(fmt, all_metrics[n].get(key)):>12}" for n in nums
        )
        print(f"{label:<16}{cells}")

    print(f"\n{'='*80}")


def _metric_values(all_metrics: dict, nums, key: str) -> list:
    """Return ``all_metrics[n][key]`` for each n in ``nums`` (missing/None -> 0)."""
    values = []
    for n in nums:
        value = all_metrics[n].get(key)
        values.append(0 if value is None else value)
    return values


def plot_comparison(all_metrics: dict, output_dir: Path):
    """Render a 3-panel bar chart comparing the select_num runs.

    Panels: annual vs. total return (%), max drawdown magnitude (%), and
    Sharpe vs. Calmar ratio. Missing metrics are drawn as 0. The figure is
    saved to ``output_dir/select_num_comparison.png``.
    """
    import matplotlib
    matplotlib.use("Agg")  # file-only backend; no display required
    import matplotlib.pyplot as plt

    nums = sorted(all_metrics.keys())
    x = np.arange(len(nums))
    w = 0.35  # width of each bar in the paired panels
    tick_labels = [f'Top-{n}' for n in nums]

    fig, axes = plt.subplots(1, 3, figsize=(16, 5))
    fig.suptitle("select_num A/B Experiment", fontsize=14, fontweight="bold")

    # 1. Returns (fractions converted to percent)
    ax = axes[0]
    annuals = _metric_values(all_metrics, nums, 'annual_return')
    totals = _metric_values(all_metrics, nums, 'total_return')
    ax.bar(x - w/2, [a * 100 for a in annuals], w, label='Annual %', color='#E74C3C', alpha=0.8)
    ax.bar(x + w/2, [t * 100 for t in totals], w, label='Total %', color='#3498DB', alpha=0.8)
    ax.set_xticks(x)
    ax.set_xticklabels(tick_labels)
    ax.set_ylabel('Return (%)')
    ax.set_title('Returns')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # 2. Risk: drawdown magnitude, so bars point upward regardless of sign
    ax = axes[1]
    dds = [abs(d) * 100 for d in _metric_values(all_metrics, nums, 'max_drawdown')]
    ax.bar(x, dds, color='#E74C3C', alpha=0.7)
    ax.set_xticks(x)
    ax.set_xticklabels(tick_labels)
    ax.set_ylabel('Max Drawdown (%)')
    ax.set_title('Risk')
    ax.grid(True, alpha=0.3)

    # 3. Sharpe & Calmar
    ax = axes[2]
    sharpes = _metric_values(all_metrics, nums, 'sharpe_ratio')
    calmars = _metric_values(all_metrics, nums, 'calmar_ratio')
    ax.bar(x - w/2, sharpes, w, label='Sharpe', color='#2ECC71', alpha=0.8)
    ax.bar(x + w/2, calmars, w, label='Calmar', color='#F39C12', alpha=0.8)
    ax.set_xticks(x)
    ax.set_xticklabels(tick_labels)
    ax.set_ylabel('Ratio')
    ax.set_title('Risk-Adjusted')
    ax.legend()
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    chart_path = output_dir / 'select_num_comparison.png'
    fig.savefig(str(chart_path), dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"\n + Chart: {chart_path}")


def plot_nav_comparison(output_dir: Path, nums=SELECT_NUMS):
    """Overlay the NAV curves of several select_num runs on one log-scale chart.

    For each ``n`` in ``nums``, reads
    ``output_dir/select_<n>/simple_rotation_nav.csv`` (``date`` and ``nav``
    columns are used); runs without that file are skipped. The chart is saved
    to ``output_dir/select_num_nav_comparison.png``. If no file is found,
    nothing is written.

    Args:
        output_dir: Experiment output directory.
        nums: select_num values to plot. Defaults to ``SELECT_NUMS``, i.e. the
            original hard-coded (1, 2, 3). Pass only the runs that succeeded in
            the current experiment so stale CSVs left in ``output_dir`` by an
            earlier experiment are not mixed in.
    """
    import matplotlib
    matplotlib.use("Agg")  # file-only backend; no display required
    import matplotlib.pyplot as plt

    curves = []
    for n in nums:
        nav_path = output_dir / f'select_{n}' / 'simple_rotation_nav.csv'
        if nav_path.exists():
            curves.append((n, pd.read_csv(nav_path, parse_dates=['date'])))

    if not curves:
        # An empty log-scale figure with an empty legend is useless; skip it.
        print(" ! NAV Chart skipped: no simple_rotation_nav.csv found")
        return

    fig, ax = plt.subplots(figsize=(14, 6))
    for n, df in curves:
        ax.plot(df['date'], df['nav'], label=f'Top-{n}', linewidth=1.5,
                color=_NAV_PALETTE[(n - 1) % len(_NAV_PALETTE)])

    ax.set_title("NAV Curve Comparison (select_num)", fontsize=14, fontweight="bold")
    ax.set_ylabel("NAV")
    ax.set_yscale("log")
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    nav_chart = output_dir / 'select_num_nav_comparison.png'
    fig.savefig(str(nav_chart), dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" + NAV Chart: {nav_chart}")


def _json_default(obj):
    """``json.dump`` fallback for values the stdlib encoder rejects.

    Metric dicts built with NumPy/pandas can contain NumPy integer scalars
    (e.g. ``np.int64``, which unlike ``np.float64`` does not subclass a
    built-in JSON type), arrays, or timestamps. Without this hook
    ``json.dump`` raises ``TypeError`` only after the whole experiment ran.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, datetime):  # also covers pandas.Timestamp
        return obj.isoformat()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def main():
    """Run every SELECT_NUMS experiment, then report, plot and save metrics."""
    # Keep any FLASK_API_URL already set in the environment; otherwise fall
    # back to this default (presumably consumed by the strategy's data layer).
    os.environ.setdefault('FLASK_API_URL', 'https://k3s.tokenpluse.xyz')

    config_path = str(Path(__file__).parent / 'config_simple.yaml')
    output_dir = PROJECT_ROOT / 'results' / 'experiment_select_num'
    output_dir.mkdir(parents=True, exist_ok=True)

    all_metrics = {}
    for n in SELECT_NUMS:
        metrics = run_with_select_num(config_path, n, output_dir)
        if metrics:
            all_metrics[n] = metrics

    if not all_metrics:
        return

    print_comparison(all_metrics)
    plot_comparison(all_metrics, output_dir)
    # output_dir persists across experiments, so restrict the NAV chart to
    # runs that succeeded now rather than whatever CSVs happen to exist.
    plot_nav_comparison(output_dir, nums=sorted(all_metrics))

    # Save the raw metrics. Serialize first so an encoding error cannot leave
    # a truncated JSON file behind.
    metrics_path = output_dir / 'experiment_metrics.json'
    payload = json.dumps({str(k): v for k, v in all_metrics.items()},
                         ensure_ascii=False, indent=2, default=_json_default)
    metrics_path.write_text(payload, encoding='utf-8')
    print(f" + Metrics: {metrics_path}")


if __name__ == "__main__":
    main()