Files
etf/rotation/experiment_select_num.py
aszerW a47af0f0eb docs(experiment): add select_num A/B/C comparison report (005)
- Experiment: select_num = 1, 2, 3 comparison
- Period: 2020-01-10 ~ 2026-06-02 (1546 trading days)
- Key findings:
  - Top-1: highest return (600%), highest drawdown (-25.5%)
  - Top-3: best risk-adjusted return (Calmar 1.73, Sharpe 1.35)
  - Top-2: balanced middle ground (Calmar 1.69)
- Add rotation/experiment_select_num.py experiment script
- Save report to docs/experiments/005_select_num_comparison.md
2026-06-02 01:32:43 +08:00

194 lines
6.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
select_num A/B 实验:对比 Top-1 / Top-2 / Top-3 的表现
用法:
python rotation/experiment_select_num.py
"""
import os
import sys
import yaml
import json
import tempfile
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
from rotation.simple_rotation import SimpleRotationStrategy
def run_with_select_num(config_path: str, select_num: int, output_dir: Path) -> dict:
    """Run the rotation strategy once with ``rotation.select_num`` overridden.

    The base YAML config is loaded, its ``rotation.select_num`` entry is
    replaced, and the modified config is written to
    ``output_dir / 'config_select_<select_num>.yaml'``.  That file is not a
    temporary file: it is left in ``output_dir`` after the run.

    Args:
        config_path: Path of the base YAML configuration file.
        select_num: Value stored under ``rotation.select_num``.
        output_dir: Directory that receives the derived config and, when the
            run yields a result, the ``select_<select_num>`` sub-directory
            with the exported results.  Created if it does not exist.

    Returns:
        ``result['metrics']`` of the strategy run, or an empty dict when the
        run returns a falsy result or the result has no ``'metrics'`` key.

    Raises:
        KeyError: If the loaded config has no ``rotation`` section.
    """
    print(f"\n{'='*60}")
    print(f" 实验: select_num = {select_num}")
    print(f"{'='*60}\n")

    # The derived config is written into output_dir, so make sure it exists
    # instead of relying on the caller having created it beforehand.
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load the base config and override select_num.
    with open(config_path, 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f)
    cfg['rotation']['select_num'] = select_num

    # Persist the derived config; the strategy is constructed from a path.
    tmp_path = output_dir / f'config_select_{select_num}.yaml'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        yaml.dump(cfg, f, default_flow_style=False, allow_unicode=True)

    strategy = SimpleRotationStrategy(config_path=str(tmp_path))
    result = strategy.run()
    if not result:
        return {}

    # Export this run's artefacts into its own sub-directory.
    sub_dir = output_dir / f'select_{select_num}'
    sub_dir.mkdir(parents=True, exist_ok=True)
    strategy.export_results(output_dir=str(sub_dir))
    return result.get('metrics', {})
def _format_metric(val, fmt: str) -> str:
    """Format one table cell; never raise on a missing or non-numeric value."""
    if val is None:
        return "N/A"
    try:
        return fmt.format(val)
    except (TypeError, ValueError):
        # The value does not support the numeric format spec; show it as-is.
        return str(val)


def print_comparison(all_metrics: dict):
    """Print a side-by-side metrics table, one column per select_num.

    Args:
        all_metrics: Mapping of select_num to that run's metrics dict.
            Columns are printed in ascending key order.

    Each row is one metric (total return, annual return, max drawdown,
    Sharpe ratio, Calmar ratio, daily win rate, trading days, rebalance
    count).  A metric missing from a run is formatted from the default 0;
    a metric that is present but ``None`` is shown as ``N/A``, and a value
    the format spec cannot handle is shown via ``str()`` instead of
    aborting the whole table.
    """
    # Column order is fixed once and reused for the header and every row.
    nums = sorted(all_metrics)

    print(f"\n\n{'='*80}")
    print(f" select_num 实验对比结果")
    print(f"{'='*80}\n")

    header = f"{'指标':<16}" + "".join(f"{'Top-'+str(n):>12}" for n in nums)
    print(header)
    print("-" * (16 + 12 * len(nums)))

    # (row label, metrics key, format spec)
    rows = [
        ('累计收益', 'total_return', '{:.2%}'),
        ('年化收益', 'annual_return', '{:.2%}'),
        ('最大回撤', 'max_drawdown', '{:.2%}'),
        ('夏普比率', 'sharpe_ratio', '{:.2f}'),
        ('Calmar比率', 'calmar_ratio', '{:.2f}'),
        ('日胜率', 'win_rate', '{:.2%}'),
        ('交易日数', 'n_days', '{}'),
        ('调仓次数', 'rebalance_count', '{}'),
    ]
    for label, key, fmt in rows:
        cells = "".join(
            f"{_format_metric(all_metrics[n].get(key, 0), fmt):>12}" for n in nums
        )
        print(f"{label:<16}" + cells)

    print(f"\n{'='*80}")
def plot_comparison(all_metrics: dict, output_dir: Path):
    """Draw a three-panel bar chart comparing the runs and save it as a PNG.

    Panels, left to right: annual and total return (in percent), absolute
    max drawdown (in percent), and Sharpe / Calmar ratios.  Runs are ordered
    by ascending select_num; a metric missing from a run is plotted as 0.

    Args:
        all_metrics: Mapping of select_num to that run's metrics dict.
        output_dir: Directory in which ``select_num_comparison.png`` is saved.
    """
    import matplotlib
    # The Agg backend is selected before pyplot is imported.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    keys = sorted(all_metrics.keys())
    tick_labels = [f'Top-{k}' for k in keys]
    positions = np.arange(len(keys))
    bar_w = 0.35

    def series(metric_key):
        """Collect one metric across all runs, defaulting to 0 when absent."""
        return [all_metrics[k].get(metric_key, 0) for k in keys]

    def label_ticks(axis):
        """Put one 'Top-N' tick under each group of bars."""
        axis.set_xticks(positions)
        axis.set_xticklabels(tick_labels)

    fig, (ax_ret, ax_risk, ax_ratio) = plt.subplots(1, 3, figsize=(16, 5))
    fig.suptitle("select_num A/B Experiment", fontsize=14, fontweight="bold")

    # Panel 1: returns, converted from fractions to percent.
    annual_pct = [value * 100 for value in series('annual_return')]
    total_pct = [value * 100 for value in series('total_return')]
    ax_ret.bar(positions - bar_w / 2, annual_pct, bar_w,
               label='Annual %', color='#E74C3C', alpha=0.8)
    ax_ret.bar(positions + bar_w / 2, total_pct, bar_w,
               label='Total %', color='#3498DB', alpha=0.8)
    label_ticks(ax_ret)
    ax_ret.set_ylabel('Return (%)')
    ax_ret.set_title('Returns')
    ax_ret.legend()
    ax_ret.grid(True, alpha=0.3)

    # Panel 2: risk, drawn as the magnitude of max drawdown in percent.
    drawdown_pct = [abs(value) * 100 for value in series('max_drawdown')]
    ax_risk.bar(positions, drawdown_pct, color='#E74C3C', alpha=0.7)
    label_ticks(ax_risk)
    ax_risk.set_ylabel('Max Drawdown (%)')
    ax_risk.set_title('Risk')
    ax_risk.grid(True, alpha=0.3)

    # Panel 3: Sharpe and Calmar ratios side by side.
    ax_ratio.bar(positions - bar_w / 2, series('sharpe_ratio'), bar_w,
                 label='Sharpe', color='#2ECC71', alpha=0.8)
    ax_ratio.bar(positions + bar_w / 2, series('calmar_ratio'), bar_w,
                 label='Calmar', color='#F39C12', alpha=0.8)
    label_ticks(ax_ratio)
    ax_ratio.set_ylabel('Ratio')
    ax_ratio.set_title('Risk-Adjusted')
    ax_ratio.legend()
    ax_ratio.grid(True, alpha=0.3)

    plt.tight_layout()
    chart_path = output_dir / 'select_num_comparison.png'
    plt.savefig(str(chart_path), dpi=150, bbox_inches="tight")
    plt.close()
    print(f"\n + Chart: {chart_path}")
def plot_nav_comparison(output_dir: Path):
    """Plot the NAV curves of the select_num 1, 2 and 3 runs on one chart.

    For each n in 1..3 the file
    ``output_dir / 'select_<n>' / 'simple_rotation_nav.csv'`` is read if it
    exists (columns ``date`` and ``nav`` are used); missing files are
    skipped.  The y-axis is logarithmic.  The figure is saved as
    ``select_num_nav_comparison.png`` in ``output_dir``.

    Args:
        output_dir: Directory holding the per-run sub-directories; also the
            destination of the generated PNG.
    """
    import matplotlib
    # The Agg backend is selected before pyplot is imported.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    # One fixed colour per run, in plotting order.
    line_colors = {1: '#E74C3C', 2: '#3498DB', 3: '#2ECC71'}

    fig, ax = plt.subplots(figsize=(14, 6))
    for num, color in line_colors.items():
        csv_file = output_dir / f'select_{num}' / 'simple_rotation_nav.csv'
        if not csv_file.exists():
            continue
        nav = pd.read_csv(csv_file, parse_dates=['date'])
        ax.plot(nav['date'], nav['nav'], label=f'Top-{num}',
                linewidth=1.5, color=color)

    ax.set_title("NAV Curve Comparison (select_num)", fontsize=14, fontweight="bold")
    ax.set_ylabel("NAV")
    ax.set_yscale("log")
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    nav_chart = output_dir / 'select_num_nav_comparison.png'
    plt.savefig(str(nav_chart), dpi=150, bbox_inches="tight")
    plt.close()
    print(f" + NAV Chart: {nav_chart}")
if __name__ == "__main__":
    # Use the default API endpoint only when the environment does not set one.
    os.environ.setdefault('FLASK_API_URL', 'https://k3s.tokenpluse.xyz')

    base_config = str(Path(__file__).parent / 'config_simple.yaml')
    results_dir = PROJECT_ROOT / 'results' / 'experiment_select_num'
    results_dir.mkdir(parents=True, exist_ok=True)

    # Run the strategy once per select_num; keep only runs that produced metrics.
    collected = {}
    for num in (1, 2, 3):
        run_metrics = run_with_select_num(base_config, num, results_dir)
        if run_metrics:
            collected[num] = run_metrics

    if collected:
        print_comparison(collected)
        plot_comparison(collected, results_dir)
        plot_nav_comparison(results_dir)

        # Save the raw metrics; JSON object keys must be strings.
        metrics_path = results_dir / 'experiment_metrics.json'
        serializable = {str(key): value for key, value in collected.items()}
        with open(metrics_path, 'w', encoding='utf-8') as f:
            json.dump(serializable, f, ensure_ascii=False, indent=2)
        print(f" + Metrics: {metrics_path}")