Files
etf/rotation/experiments/experiment_select_num.py
aszerW ac022020c7 refactor: 整理rotation目录结构
将分析/测试/实验脚本从核心目录移出:
- enrich_etf_data.py → scripts/
- oil_tracking.py → analysis/
- tracking_error_full.py → analysis/
- tracking_error_validation.py → analysis/
- test_start_year_analysis.py → experiments/
- experiment_select_num.py → experiments/

rotation/ 目录现在只保留核心策略代码:
- simple_rotation.py (策略主逻辑)
- config_loader.py (配置加载)
- config_simple.yaml (配置文件)
- daily_scheduler.py (调度器)
2026-06-21 13:38:15 +08:00

194 lines
6.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
select_num A/B experiment: compare the performance of Top-1 / Top-2 / Top-3.

Usage (path relative to the project root):
    python rotation/experiments/experiment_select_num.py
"""
import os
import sys
import yaml
import json
import tempfile
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime
# This script lives in <project>/rotation/experiments/, so the project root
# (the directory that contains the `rotation` package imported below) is
# three levels up. resolve() makes the result independent of the current
# working directory and of how the script was invoked.
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
from rotation.simple_rotation import SimpleRotationStrategy
def run_with_select_num(config_path: str, select_num: int, output_dir: Path) -> dict:
    """Run the strategy once with ``rotation.select_num`` overridden.

    The YAML config at *config_path* is loaded, its ``rotation.select_num``
    entry is replaced by *select_num*, and the modified config is written to
    ``output_dir/config_select_<select_num>.yaml``. The strategy is then built
    from that file and run.

    If the run returns a truthy result, its exports are written into
    ``output_dir/select_<select_num>/``.

    Returns:
        The result's ``'metrics'`` entry (``{}`` if it has none), or ``{}``
        when the run returned a falsy result.

    Raises:
        KeyError: if the config has no ``rotation`` section.
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print(f" 实验: select_num = {select_num}")
    print(f"{banner}\n")

    # Load the base config and persist an overridden copy for this run.
    with open(config_path, 'r', encoding='utf-8') as src:
        config = yaml.safe_load(src)
    config['rotation']['select_num'] = select_num

    override_path = output_dir / f'config_select_{select_num}.yaml'
    with open(override_path, 'w', encoding='utf-8') as dst:
        yaml.dump(config, dst, default_flow_style=False, allow_unicode=True)

    strategy = SimpleRotationStrategy(config_path=str(override_path))
    result = strategy.run()
    if not result:
        return {}

    # Export this run's artifacts into its own sub-directory.
    export_dir = output_dir / f'select_{select_num}'
    export_dir.mkdir(parents=True, exist_ok=True)
    strategy.export_results(output_dir=str(export_dir))
    return result.get('metrics', {})
def print_comparison(all_metrics: dict):
"""打印对比表格"""
print(f"\n\n{'='*80}")
print(f" select_num 实验对比结果")
print(f"{'='*80}\n")
header = f"{'指标':<16}"
for n in sorted(all_metrics.keys()):
header += f"{'Top-'+str(n):>12}"
print(header)
print("-" * (16 + 12 * len(all_metrics)))
rows = [
('累计收益', 'total_return', '{:.2%}'),
('年化收益', 'annual_return', '{:.2%}'),
('最大回撤', 'max_drawdown', '{:.2%}'),
('夏普比率', 'sharpe_ratio', '{:.2f}'),
('Calmar比率', 'calmar_ratio', '{:.2f}'),
('日胜率', 'win_rate', '{:.2%}'),
('交易日数', 'n_days', '{}'),
('调仓次数', 'rebalance_count', '{}'),
]
for label, key, fmt in rows:
row = f"{label:<16}"
for n in sorted(all_metrics.keys()):
val = all_metrics[n].get(key, 0)
row += f"{fmt.format(val):>12}"
print(row)
print(f"\n{'='*80}")
def plot_comparison(all_metrics: dict, output_dir: Path):
    """Save a three-panel bar chart comparing the select_num runs.

    The chart has one bar group per select_num (ascending) in each of three panels:
      1. annual vs. total return, in percent;
      2. absolute value of the max drawdown, in percent;
      3. Sharpe vs. Calmar ratio.

    Args:
        all_metrics: mapping ``select_num -> metrics dict``. Missing or
            ``None`` metrics are drawn as 0.
        output_dir: directory that receives ``select_num_comparison.png``.
    """
    import matplotlib
    matplotlib.use("Agg")  # non-interactive backend: the chart is only written to disk
    import matplotlib.pyplot as plt

    nums = sorted(all_metrics.keys())
    x = np.arange(len(nums))
    w = 0.35  # width of each bar in the paired-bar panels
    tick_labels = [f'Top-{n}' for n in nums]

    def series(key):
        """Return metric *key* for every run; missing or None becomes 0."""
        # `or 0` also maps an explicit None to 0, so the arithmetic below cannot fail.
        return [all_metrics[n].get(key) or 0 for n in nums]

    fig, (ax_ret, ax_risk, ax_ratio) = plt.subplots(1, 3, figsize=(16, 5))
    fig.suptitle("select_num A/B Experiment", fontsize=14, fontweight="bold")

    # 1. Returns (stored as fractions, shown as percent)
    ax_ret.bar(x - w/2, [a * 100 for a in series('annual_return')], w,
               label='Annual %', color='#E74C3C', alpha=0.8)
    ax_ret.bar(x + w/2, [t * 100 for t in series('total_return')], w,
               label='Total %', color='#3498DB', alpha=0.8)
    ax_ret.legend()

    # 2. Risk: drawdown plotted as a positive magnitude
    ax_risk.bar(x, [abs(d) * 100 for d in series('max_drawdown')],
                color='#E74C3C', alpha=0.7)

    # 3. Risk-adjusted ratios
    ax_ratio.bar(x - w/2, series('sharpe_ratio'), w,
                 label='Sharpe', color='#2ECC71', alpha=0.8)
    ax_ratio.bar(x + w/2, series('calmar_ratio'), w,
                 label='Calmar', color='#F39C12', alpha=0.8)
    ax_ratio.legend()

    # Axis styling shared by all three panels.
    for ax, ylabel, title in (
        (ax_ret, 'Return (%)', 'Returns'),
        (ax_risk, 'Max Drawdown (%)', 'Risk'),
        (ax_ratio, 'Ratio', 'Risk-Adjusted'),
    ):
        ax.set_xticks(x)
        ax.set_xticklabels(tick_labels)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    chart_path = output_dir / 'select_num_comparison.png'
    fig.savefig(str(chart_path), dpi=150, bbox_inches="tight")
    plt.close(fig)  # close the figure created here, not whichever is "current"
    print(f"\n + Chart: {chart_path}")
def plot_nav_comparison(output_dir: Path, nums=(1, 2, 3)):
    """Overlay the NAV curves of several select_num runs on one log-scale chart.

    For each ``n`` in *nums*, this reads
    ``output_dir/select_<n>/simple_rotation_nav.csv``. The file is parsed
    with a ``date`` column and plotted from its ``nav`` column. Runs whose
    file is absent are skipped. The chart is saved as
    ``output_dir/select_num_nav_comparison.png``.

    Args:
        output_dir: experiment output directory holding the ``select_<n>``
            sub-directories.
        nums: select_num values to look for (default: 1, 2, 3). Values
            without a preset colour use matplotlib's default colour cycle.

    If none of the NAV files exist, a notice is printed and no chart is
    written. Saving an empty plot with an empty legend would be useless.
    """
    import matplotlib
    matplotlib.use("Agg")  # non-interactive backend: the chart is only written to disk
    import matplotlib.pyplot as plt

    colors = {1: '#E74C3C', 2: '#3498DB', 3: '#2ECC71'}

    navs = []
    for n in nums:
        nav_path = output_dir / f'select_{n}' / 'simple_rotation_nav.csv'
        if nav_path.exists():
            navs.append((n, pd.read_csv(nav_path, parse_dates=['date'])))

    if not navs:
        print(f" - NAV Chart skipped: no simple_rotation_nav.csv found under {output_dir}")
        return

    fig, ax = plt.subplots(figsize=(14, 6))
    for n, df in navs:
        # color=None falls back to matplotlib's colour cycle for extra runs
        ax.plot(df['date'], df['nav'], label=f'Top-{n}', linewidth=1.5,
                color=colors.get(n))
    ax.set_title("NAV Curve Comparison (select_num)", fontsize=14, fontweight="bold")
    ax.set_ylabel("NAV")
    ax.set_yscale("log")
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    nav_chart = output_dir / 'select_num_nav_comparison.png'
    fig.savefig(str(nav_chart), dpi=150, bbox_inches="tight")
    plt.close(fig)  # close the figure created here, not whichever is "current"
    print(f" + NAV Chart: {nav_chart}")
if __name__ == "__main__":
    # Only provide a default API endpoint when the caller has not set one.
    os.environ.setdefault('FLASK_API_URL', 'https://k3s.tokenpluse.xyz')

    # This script was moved into rotation/experiments/, while config_simple.yaml
    # stays in the core rotation/ directory. Look there first, and keep the
    # script's own directory as a fallback for the previous layout. If neither
    # exists, the primary path is used so the error names the expected location.
    script_dir = Path(__file__).resolve().parent
    config_candidates = [
        script_dir.parent / 'config_simple.yaml',
        script_dir / 'config_simple.yaml',
    ]
    config_path = str(next((p for p in config_candidates if p.exists()),
                           config_candidates[0]))

    output_dir = PROJECT_ROOT / 'results' / 'experiment_select_num'
    output_dir.mkdir(parents=True, exist_ok=True)

    all_metrics = {}
    for n in [1, 2, 3]:
        metrics = run_with_select_num(config_path, n, output_dir)
        if metrics:
            all_metrics[n] = metrics

    if all_metrics:
        print_comparison(all_metrics)
        plot_comparison(all_metrics, output_dir)
        plot_nav_comparison(output_dir)

        def _json_default(obj):
            """Convert NumPy scalars (e.g. np.int64 counts) to builtins for json."""
            if isinstance(obj, np.generic):
                return obj.item()
            raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

        # Save the raw metrics
        metrics_path = output_dir / 'experiment_metrics.json'
        with open(metrics_path, 'w', encoding='utf-8') as f:
            json.dump({str(k): v for k, v in all_metrics.items()}, f,
                      ensure_ascii=False, indent=2, default=_json_default)
        print(f" + Metrics: {metrics_path}")
    else:
        print("\n No select_num run produced metrics; nothing to compare.")