refactor: 整理rotation目录结构

将分析/测试/实验脚本从核心目录移出:
- enrich_etf_data.py → scripts/
- oil_tracking.py → analysis/
- tracking_error_full.py → analysis/
- tracking_error_validation.py → analysis/
- test_start_year_analysis.py → experiments/
- experiment_select_num.py → experiments/

rotation/ 目录现在只保留核心策略代码:
- simple_rotation.py (策略主逻辑)
- config_loader.py (配置加载)
- config_simple.yaml (配置文件)
- daily_scheduler.py (调度器)
This commit is contained in:
2026-06-21 13:38:15 +08:00
parent 0da0306894
commit ac022020c7
6 changed files with 0 additions and 0 deletions

View File

@@ -0,0 +1,193 @@
#!/usr/bin/env python3
"""
select_num A/B 实验:对比 Top-1 / Top-2 / Top-3 的表现
用法:
python rotation/experiment_select_num.py
"""
import os
import sys
import yaml
import json
import tempfile
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
from rotation.simple_rotation import SimpleRotationStrategy
def run_with_select_num(config_path: str, select_num: int, output_dir: Path) -> dict:
"""运行一次策略,覆盖 select_num"""
print(f"\n{'='*60}")
print(f" 实验: select_num = {select_num}")
print(f"{'='*60}\n")
# 读取原始配置,修改 select_num写入临时文件
with open(config_path, 'r', encoding='utf-8') as f:
cfg = yaml.safe_load(f)
cfg['rotation']['select_num'] = select_num
tmp_path = output_dir / f'config_select_{select_num}.yaml'
with open(tmp_path, 'w', encoding='utf-8') as f:
yaml.dump(cfg, f, default_flow_style=False, allow_unicode=True)
strategy = SimpleRotationStrategy(config_path=str(tmp_path))
result = strategy.run()
if result:
# 导出到子目录
sub_dir = output_dir / f'select_{select_num}'
sub_dir.mkdir(parents=True, exist_ok=True)
strategy.export_results(output_dir=str(sub_dir))
return result.get('metrics', {})
return {}
def print_comparison(all_metrics: dict):
"""打印对比表格"""
print(f"\n\n{'='*80}")
print(f" select_num 实验对比结果")
print(f"{'='*80}\n")
header = f"{'指标':<16}"
for n in sorted(all_metrics.keys()):
header += f"{'Top-'+str(n):>12}"
print(header)
print("-" * (16 + 12 * len(all_metrics)))
rows = [
('累计收益', 'total_return', '{:.2%}'),
('年化收益', 'annual_return', '{:.2%}'),
('最大回撤', 'max_drawdown', '{:.2%}'),
('夏普比率', 'sharpe_ratio', '{:.2f}'),
('Calmar比率', 'calmar_ratio', '{:.2f}'),
('日胜率', 'win_rate', '{:.2%}'),
('交易日数', 'n_days', '{}'),
('调仓次数', 'rebalance_count', '{}'),
]
for label, key, fmt in rows:
row = f"{label:<16}"
for n in sorted(all_metrics.keys()):
val = all_metrics[n].get(key, 0)
row += f"{fmt.format(val):>12}"
print(row)
print(f"\n{'='*80}")
def plot_comparison(all_metrics: dict, output_dir: Path):
"""生成对比图表"""
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 3, figsize=(16, 5))
fig.suptitle("select_num A/B Experiment", fontsize=14, fontweight="bold")
nums = sorted(all_metrics.keys())
colors = ['#E74C3C', '#3498DB', '#2ECC71']
# 1. 收益对比
ax = axes[0]
annuals = [all_metrics[n].get('annual_return', 0) for n in nums]
totals = [all_metrics[n].get('total_return', 0) for n in nums]
x = np.arange(len(nums))
w = 0.35
ax.bar(x - w/2, [a*100 for a in annuals], w, label='Annual %', color='#E74C3C', alpha=0.8)
ax.bar(x + w/2, [t*100 for t in totals], w, label='Total %', color='#3498DB', alpha=0.8)
ax.set_xticks(x)
ax.set_xticklabels([f'Top-{n}' for n in nums])
ax.set_ylabel('Return (%)')
ax.set_title('Returns')
ax.legend()
ax.grid(True, alpha=0.3)
# 2. 风险对比
ax = axes[1]
dds = [abs(all_metrics[n].get('max_drawdown', 0)) * 100 for n in nums]
ax.bar(x, dds, color='#E74C3C', alpha=0.7)
ax.set_xticks(x)
ax.set_xticklabels([f'Top-{n}' for n in nums])
ax.set_ylabel('Max Drawdown (%)')
ax.set_title('Risk')
ax.grid(True, alpha=0.3)
# 3. 夏普 & Calmar
ax = axes[2]
sharpes = [all_metrics[n].get('sharpe_ratio', 0) for n in nums]
calmars = [all_metrics[n].get('calmar_ratio', 0) for n in nums]
ax.bar(x - w/2, sharpes, w, label='Sharpe', color='#2ECC71', alpha=0.8)
ax.bar(x + w/2, calmars, w, label='Calmar', color='#F39C12', alpha=0.8)
ax.set_xticks(x)
ax.set_xticklabels([f'Top-{n}' for n in nums])
ax.set_ylabel('Ratio')
ax.set_title('Risk-Adjusted')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
chart_path = output_dir / 'select_num_comparison.png'
plt.savefig(str(chart_path), dpi=150, bbox_inches="tight")
plt.close()
print(f"\n + Chart: {chart_path}")
def plot_nav_comparison(output_dir: Path):
"""加载三组 NAV 画在同一张图上"""
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(14, 6))
colors = {'1': '#E74C3C', '2': '#3498DB', '3': '#2ECC71'}
for n in [1, 2, 3]:
nav_path = output_dir / f'select_{n}' / 'simple_rotation_nav.csv'
if nav_path.exists():
df = pd.read_csv(nav_path, parse_dates=['date'])
ax.plot(df['date'], df['nav'], label=f'Top-{n}', linewidth=1.5, color=colors[str(n)])
ax.set_title("NAV Curve Comparison (select_num)", fontsize=14, fontweight="bold")
ax.set_ylabel("NAV")
ax.set_yscale("log")
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
plt.tight_layout()
nav_chart = output_dir / 'select_num_nav_comparison.png'
plt.savefig(str(nav_chart), dpi=150, bbox_inches="tight")
plt.close()
print(f" + NAV Chart: {nav_chart}")
if __name__ == "__main__":
if 'FLASK_API_URL' not in os.environ:
os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz'
config_path = str(Path(__file__).parent / 'config_simple.yaml')
output_dir = PROJECT_ROOT / 'results' / 'experiment_select_num'
output_dir.mkdir(parents=True, exist_ok=True)
all_metrics = {}
for n in [1, 2, 3]:
metrics = run_with_select_num(config_path, n, output_dir)
if metrics:
all_metrics[n] = metrics
if all_metrics:
print_comparison(all_metrics)
plot_comparison(all_metrics, output_dir)
plot_nav_comparison(output_dir)
# 保存原始指标
metrics_path = output_dir / 'experiment_metrics.json'
with open(metrics_path, 'w', encoding='utf-8') as f:
json.dump({str(k): v for k, v in all_metrics.items()}, f, ensure_ascii=False, indent=2)
print(f" + Metrics: {metrics_path}")

View File

@@ -0,0 +1,112 @@
"""
Test different start years with select_num=1
"""
import os
import sys
import yaml
from pathlib import Path
from datetime import datetime
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
from dotenv import load_dotenv
load_dotenv(PROJECT_ROOT / '.env')
from rotation.config_loader import load_rotation_config
from rotation.simple_rotation import SimpleRotationStrategy
def run_test(start_date: str, select_num: int) -> dict:
"""Run backtest with specified start date and select_num."""
config_path = PROJECT_ROOT / 'rotation' / 'config_simple.yaml'
with open(config_path, 'r') as f:
config = yaml.safe_load(f)
config['backtest']['start_date'] = start_date
config['rotation']['select_num'] = select_num
temp_config_path = PROJECT_ROOT / 'rotation' / 'temp_config.yaml'
with open(temp_config_path, 'w') as f:
yaml.dump(config, f)
try:
strategy = SimpleRotationStrategy(str(temp_config_path))
result = strategy.run()
return result['metrics']
finally:
if temp_config_path.exists():
temp_config_path.unlink()
def main():
select_num = 1
years = [2020, 2021, 2022, 2023, 2024, 2025]
print(f"\n{'='*80}")
print(f"Testing select_num={select_num} with different start years")
print(f"{'='*80}")
results = []
for year in years:
start_date = f"{year}-01-01"
print(f"\nTesting start_date={start_date}...")
try:
metrics = run_test(start_date, select_num)
results.append({
'start_year': year,
'start_date': start_date,
'select_num': select_num,
'total_return': metrics.get('total_return', 0),
'annual_return': metrics.get('annual_return', 0),
'max_drawdown': metrics.get('max_drawdown', 0),
'sharpe_ratio': metrics.get('sharpe_ratio', 0),
'rebalance_count': metrics.get('rebalance_count', 0),
'win_rate': metrics.get('win_rate', 0),
})
print(f" Total Return: {metrics.get('total_return', 0)*100:.2f}%")
print(f" Annual Return: {metrics.get('annual_return', 0)*100:.2f}%")
print(f" Max Drawdown: {metrics.get('max_drawdown', 0)*100:.2f}%")
print(f" Sharpe Ratio: {metrics.get('sharpe_ratio', 0):.3f}")
print(f" Rebalance Count: {metrics.get('rebalance_count', 0)}")
except Exception as e:
print(f" Error: {e}")
results.append({
'start_year': year,
'start_date': start_date,
'select_num': select_num,
'error': str(e)
})
# Print summary table
print(f"\n{'='*80}")
print(f"SUMMARY TABLE (select_num={select_num})")
print(f"{'='*80}")
print(f"{'Start Year':<12} {'Total Return':<15} {'Annual Return':<15} {'Max Drawdown':<15} {'Sharpe':<10} {'Rebal':<8}")
print(f"{'-'*80}")
for r in results:
if 'error' in r:
print(f"{r['start_year']:<12} {'ERROR':<15}")
else:
print(f"{r['start_year']:<12} {r['total_return']*100:>13.2f}% {r['annual_return']*100:>13.2f}% {r['max_drawdown']*100:>13.2f}% {r['sharpe_ratio']:>9.3f} {r['rebalance_count']:>7}")
# Save results to YAML
output_path = PROJECT_ROOT / 'rotation' / 'results' / 'start_year_analysis.yaml'
output_path.parent.mkdir(exist_ok=True)
with open(output_path, 'w') as f:
yaml.dump({
'select_num': select_num,
'test_date': datetime.now().isoformat(),
'results': results
}, f, default_flow_style=False)
print(f"\nResults saved to: {output_path}")
if __name__ == '__main__':
main()