docs(experiment): add select_num A/B/C comparison report (005)
- Experiment: select_num = 1, 2, 3 comparison - Period: 2020-01-10 ~ 2026-06-02 (1546 trading days) - Key findings: - Top-1: highest return (600%), highest drawdown (-25.5%) - Top-3: best risk-adjusted return (Calmar 1.73, Sharpe 1.35) - Top-2: balanced middle ground (Calmar 1.69) - Add rotation/experiment_select_num.py experiment script - Save report to docs/experiments/005_select_num_comparison.md
This commit is contained in:
193
rotation/experiment_select_num.py
Normal file
193
rotation/experiment_select_num.py
Normal file
@@ -0,0 +1,193 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
select_num A/B 实验:对比 Top-1 / Top-2 / Top-3 的表现
|
||||
|
||||
用法:
|
||||
python rotation/experiment_select_num.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import yaml
|
||||
import json
|
||||
import tempfile
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from rotation.simple_rotation import SimpleRotationStrategy
|
||||
|
||||
|
||||
def run_with_select_num(config_path: str, select_num: int, output_dir: Path) -> dict:
|
||||
"""运行一次策略,覆盖 select_num"""
|
||||
print(f"\n{'='*60}")
|
||||
print(f" 实验: select_num = {select_num}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# 读取原始配置,修改 select_num,写入临时文件
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
cfg['rotation']['select_num'] = select_num
|
||||
|
||||
tmp_path = output_dir / f'config_select_{select_num}.yaml'
|
||||
with open(tmp_path, 'w', encoding='utf-8') as f:
|
||||
yaml.dump(cfg, f, default_flow_style=False, allow_unicode=True)
|
||||
|
||||
strategy = SimpleRotationStrategy(config_path=str(tmp_path))
|
||||
result = strategy.run()
|
||||
|
||||
if result:
|
||||
# 导出到子目录
|
||||
sub_dir = output_dir / f'select_{select_num}'
|
||||
sub_dir.mkdir(parents=True, exist_ok=True)
|
||||
strategy.export_results(output_dir=str(sub_dir))
|
||||
return result.get('metrics', {})
|
||||
return {}
|
||||
|
||||
|
||||
def print_comparison(all_metrics: dict):
|
||||
"""打印对比表格"""
|
||||
print(f"\n\n{'='*80}")
|
||||
print(f" select_num 实验对比结果")
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
header = f"{'指标':<16}"
|
||||
for n in sorted(all_metrics.keys()):
|
||||
header += f"{'Top-'+str(n):>12}"
|
||||
print(header)
|
||||
print("-" * (16 + 12 * len(all_metrics)))
|
||||
|
||||
rows = [
|
||||
('累计收益', 'total_return', '{:.2%}'),
|
||||
('年化收益', 'annual_return', '{:.2%}'),
|
||||
('最大回撤', 'max_drawdown', '{:.2%}'),
|
||||
('夏普比率', 'sharpe_ratio', '{:.2f}'),
|
||||
('Calmar比率', 'calmar_ratio', '{:.2f}'),
|
||||
('日胜率', 'win_rate', '{:.2%}'),
|
||||
('交易日数', 'n_days', '{}'),
|
||||
('调仓次数', 'rebalance_count', '{}'),
|
||||
]
|
||||
|
||||
for label, key, fmt in rows:
|
||||
row = f"{label:<16}"
|
||||
for n in sorted(all_metrics.keys()):
|
||||
val = all_metrics[n].get(key, 0)
|
||||
row += f"{fmt.format(val):>12}"
|
||||
print(row)
|
||||
|
||||
print(f"\n{'='*80}")
|
||||
|
||||
|
||||
def plot_comparison(all_metrics: dict, output_dir: Path):
|
||||
"""生成对比图表"""
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig, axes = plt.subplots(1, 3, figsize=(16, 5))
|
||||
fig.suptitle("select_num A/B Experiment", fontsize=14, fontweight="bold")
|
||||
|
||||
nums = sorted(all_metrics.keys())
|
||||
colors = ['#E74C3C', '#3498DB', '#2ECC71']
|
||||
|
||||
# 1. 收益对比
|
||||
ax = axes[0]
|
||||
annuals = [all_metrics[n].get('annual_return', 0) for n in nums]
|
||||
totals = [all_metrics[n].get('total_return', 0) for n in nums]
|
||||
x = np.arange(len(nums))
|
||||
w = 0.35
|
||||
ax.bar(x - w/2, [a*100 for a in annuals], w, label='Annual %', color='#E74C3C', alpha=0.8)
|
||||
ax.bar(x + w/2, [t*100 for t in totals], w, label='Total %', color='#3498DB', alpha=0.8)
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels([f'Top-{n}' for n in nums])
|
||||
ax.set_ylabel('Return (%)')
|
||||
ax.set_title('Returns')
|
||||
ax.legend()
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
# 2. 风险对比
|
||||
ax = axes[1]
|
||||
dds = [abs(all_metrics[n].get('max_drawdown', 0)) * 100 for n in nums]
|
||||
ax.bar(x, dds, color='#E74C3C', alpha=0.7)
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels([f'Top-{n}' for n in nums])
|
||||
ax.set_ylabel('Max Drawdown (%)')
|
||||
ax.set_title('Risk')
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
# 3. 夏普 & Calmar
|
||||
ax = axes[2]
|
||||
sharpes = [all_metrics[n].get('sharpe_ratio', 0) for n in nums]
|
||||
calmars = [all_metrics[n].get('calmar_ratio', 0) for n in nums]
|
||||
ax.bar(x - w/2, sharpes, w, label='Sharpe', color='#2ECC71', alpha=0.8)
|
||||
ax.bar(x + w/2, calmars, w, label='Calmar', color='#F39C12', alpha=0.8)
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels([f'Top-{n}' for n in nums])
|
||||
ax.set_ylabel('Ratio')
|
||||
ax.set_title('Risk-Adjusted')
|
||||
ax.legend()
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
plt.tight_layout()
|
||||
chart_path = output_dir / 'select_num_comparison.png'
|
||||
plt.savefig(str(chart_path), dpi=150, bbox_inches="tight")
|
||||
plt.close()
|
||||
print(f"\n + Chart: {chart_path}")
|
||||
|
||||
|
||||
def plot_nav_comparison(output_dir: Path):
|
||||
"""加载三组 NAV 画在同一张图上"""
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig, ax = plt.subplots(figsize=(14, 6))
|
||||
colors = {'1': '#E74C3C', '2': '#3498DB', '3': '#2ECC71'}
|
||||
|
||||
for n in [1, 2, 3]:
|
||||
nav_path = output_dir / f'select_{n}' / 'simple_rotation_nav.csv'
|
||||
if nav_path.exists():
|
||||
df = pd.read_csv(nav_path, parse_dates=['date'])
|
||||
ax.plot(df['date'], df['nav'], label=f'Top-{n}', linewidth=1.5, color=colors[str(n)])
|
||||
|
||||
ax.set_title("NAV Curve Comparison (select_num)", fontsize=14, fontweight="bold")
|
||||
ax.set_ylabel("NAV")
|
||||
ax.set_yscale("log")
|
||||
ax.legend(fontsize=11)
|
||||
ax.grid(True, alpha=0.3)
|
||||
plt.tight_layout()
|
||||
|
||||
nav_chart = output_dir / 'select_num_nav_comparison.png'
|
||||
plt.savefig(str(nav_chart), dpi=150, bbox_inches="tight")
|
||||
plt.close()
|
||||
print(f" + NAV Chart: {nav_chart}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if 'FLASK_API_URL' not in os.environ:
|
||||
os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz'
|
||||
|
||||
config_path = str(Path(__file__).parent / 'config_simple.yaml')
|
||||
output_dir = PROJECT_ROOT / 'results' / 'experiment_select_num'
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
all_metrics = {}
|
||||
for n in [1, 2, 3]:
|
||||
metrics = run_with_select_num(config_path, n, output_dir)
|
||||
if metrics:
|
||||
all_metrics[n] = metrics
|
||||
|
||||
if all_metrics:
|
||||
print_comparison(all_metrics)
|
||||
plot_comparison(all_metrics, output_dir)
|
||||
plot_nav_comparison(output_dir)
|
||||
|
||||
# 保存原始指标
|
||||
metrics_path = output_dir / 'experiment_metrics.json'
|
||||
with open(metrics_path, 'w', encoding='utf-8') as f:
|
||||
json.dump({str(k): v for k, v in all_metrics.items()}, f, ensure_ascii=False, indent=2)
|
||||
print(f" + Metrics: {metrics_path}")
|
||||
Reference in New Issue
Block a user