clean(rotation): add simple rotation strategy and remove unused files
New: - rotation/simple_rotation.py: daily-iteration rotation strategy (584 lines) - rotation/config_loader.py: standalone config loader - rotation/config_simple.yaml: 11 assets, 7 groups - rotation/README_SIMPLE.md: usage guide - scripts/get_trading_calendar.py: trading calendar fetcher Removed: - rotation/example_usage.py, run_strategy.py (replaced by simple_rotation.py) - rotation/results/ output files (gitignored) - scripts/verify_*.py, calculate_returns_from_detail.py (one-off scripts) - scripts/README_TRADING_CALENDAR.md Backtest result (2020-01-10 ~ 2026-06-01): - Total return: 1237.6%, Annual: 52.66% - Max drawdown: -11.71%, Sharpe: 2.50
This commit is contained in:
457
rotation/config_loader.py
Normal file
457
rotation/config_loader.py
Normal file
@@ -0,0 +1,457 @@
|
||||
"""
|
||||
配置加载器 - rotation 目录专用(独立实现)
|
||||
|
||||
不依赖 framework_v2,完全独立的配置加载和验证机制。
|
||||
|
||||
使用示例:
|
||||
from config_loader import load_rotation_config
|
||||
|
||||
# 加载配置(自动处理环境变量)
|
||||
config = load_rotation_config('config_simple.yaml')
|
||||
|
||||
# 访问配置
|
||||
print(f"策略版本: {config.metadata.version}")
|
||||
print(f"动量窗口: {config.factor.n_days}")
|
||||
print(f"资产数量: {config.asset_pools.count()}")
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List, Any
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from enum import Enum
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 枚举类型
|
||||
# ============================================================
|
||||
|
||||
class FactorType(str, Enum):
|
||||
"""因子类型枚举"""
|
||||
MOMENTUM = "momentum"
|
||||
SLOPE_R2 = "slope_r2"
|
||||
WEIGHTED_MOMENTUM = "weighted_momentum"
|
||||
|
||||
|
||||
class PremiumMode(str, Enum):
|
||||
"""溢价控制模式"""
|
||||
FILTER = "filter" # 完全排除
|
||||
PENALIZE = "penalize" # 降权
|
||||
|
||||
|
||||
class ThresholdMode(str, Enum):
|
||||
"""阈值模式"""
|
||||
FIXED = "fixed" # 固定阈值
|
||||
DYNAMIC = "dynamic" # 动态阈值
|
||||
|
||||
|
||||
class DataSourceType(str, Enum):
|
||||
"""数据源类型"""
|
||||
FLASK_API = "flask_api"
|
||||
TUSHARE = "tushare"
|
||||
YFINANCE = "yfinance"
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Schema 定义(独立实现)
|
||||
# ============================================================
|
||||
|
||||
class PremiumConfig(BaseModel):
|
||||
"""溢价控制配置"""
|
||||
enabled: bool = Field(default=True, description="是否启用")
|
||||
threshold: float = Field(default=0.10, ge=0, le=1.0, description="溢价阈值")
|
||||
|
||||
|
||||
class AssetConfig(BaseModel):
|
||||
"""标的配置"""
|
||||
name: str = Field(..., description="标的名称")
|
||||
group: str = Field(..., description="策略分组")
|
||||
|
||||
signal_source: str = Field(..., description="信号来源代码")
|
||||
trade_source: str = Field(..., description="交易来源代码")
|
||||
|
||||
description: Optional[str] = Field(None, description="标的描述")
|
||||
etf: Optional[str] = Field(None, description="ETF代码(兼容旧配置)")
|
||||
premium_control: Optional["PremiumConfig"] = Field(None, description="溢价控制配置")
|
||||
|
||||
@property
|
||||
def is_cross_market(self) -> bool:
|
||||
"""是否跨市场"""
|
||||
return self.signal_source != self.trade_source
|
||||
|
||||
|
||||
class AssetPool(BaseModel):
|
||||
"""资产池(扁平化设计)"""
|
||||
assets: Dict[str, AssetConfig] = Field(
|
||||
default_factory=dict,
|
||||
description="所有标的"
|
||||
)
|
||||
|
||||
@property
|
||||
def by_group(self) -> Dict[str, Dict[str, AssetConfig]]:
|
||||
"""按策略分组"""
|
||||
groups = {}
|
||||
for code, asset in self.assets.items():
|
||||
group = asset.group
|
||||
if group not in groups:
|
||||
groups[group] = {}
|
||||
groups[group][code] = asset
|
||||
return groups
|
||||
|
||||
@property
|
||||
def groups(self) -> list:
|
||||
"""获取所有策略分组"""
|
||||
return list(self.by_group.keys())
|
||||
|
||||
def get_signal_codes(self, group: str = None) -> list:
|
||||
"""获取信号标的"""
|
||||
if group:
|
||||
return [
|
||||
asset.signal_source
|
||||
for asset in self.assets.values()
|
||||
if asset.group == group
|
||||
]
|
||||
return [asset.signal_source for asset in self.assets.values()]
|
||||
|
||||
def get_trade_codes(self, group: str = None) -> list:
|
||||
"""获取交易标的"""
|
||||
if group:
|
||||
return [
|
||||
asset.trade_source
|
||||
for asset in self.assets.values()
|
||||
if asset.group == group
|
||||
]
|
||||
return [asset.trade_source for asset in self.assets.values()]
|
||||
|
||||
def get_signal_to_trade_mapping(self, group: str = None) -> Dict[str, str]:
|
||||
"""获取信号→交易映射"""
|
||||
mapping = {}
|
||||
for asset in self.assets.values():
|
||||
if group and asset.group != group:
|
||||
continue
|
||||
mapping[asset.signal_source] = asset.trade_source
|
||||
return mapping
|
||||
|
||||
def count(self, group: str = None) -> int:
|
||||
"""获取标的数量"""
|
||||
if group:
|
||||
return len([a for a in self.assets.values() if a.group == group])
|
||||
return len(self.assets)
|
||||
|
||||
|
||||
class FactorConfig(BaseModel):
|
||||
"""因子配置"""
|
||||
type: FactorType = Field(default=FactorType.WEIGHTED_MOMENTUM)
|
||||
n_days: int = Field(default=25, ge=5, le=250)
|
||||
|
||||
|
||||
class DynamicThresholdConfig(BaseModel):
|
||||
"""动态阈值配置"""
|
||||
reference: str = Field(..., description="阈值参考标的")
|
||||
ratio: float = Field(default=1.0, ge=0, description="阈值倍数")
|
||||
fallback_enabled: bool = Field(default=True)
|
||||
fallback_value: float = Field(default=0.0)
|
||||
|
||||
|
||||
class ThresholdConfig(BaseModel):
|
||||
"""阈值配置"""
|
||||
mode: ThresholdMode = Field(default=ThresholdMode.DYNAMIC)
|
||||
fixed_value: float = Field(default=0.0)
|
||||
dynamic: Optional[DynamicThresholdConfig] = Field(None)
|
||||
|
||||
|
||||
class RotationConfig(BaseModel):
|
||||
"""轮动配置"""
|
||||
select_num: int = Field(default=3, ge=1, le=10)
|
||||
diversified: bool = Field(default=True)
|
||||
threshold: ThresholdConfig = Field(default_factory=ThresholdConfig)
|
||||
|
||||
|
||||
class RebalanceConfig(BaseModel):
|
||||
"""调仓配置"""
|
||||
min_hold_days: int = Field(default=1, ge=1)
|
||||
score_threshold: float = Field(default=0.0)
|
||||
trade_cost: float = Field(default=0.001, ge=0, le=0.01)
|
||||
|
||||
|
||||
class PremiumControlConfig(BaseModel):
|
||||
"""溢价控制配置"""
|
||||
enabled: bool = Field(default=False)
|
||||
default_threshold: float = Field(default=0.10)
|
||||
mode: PremiumMode = Field(default=PremiumMode.FILTER)
|
||||
penalty_factor: float = Field(default=0.5)
|
||||
market_overrides: Optional[Dict[str, Dict[str, Any]]] = Field(None)
|
||||
|
||||
|
||||
class DataSourceConfig(BaseModel):
|
||||
"""数据源配置"""
|
||||
type: DataSourceType = Field(...)
|
||||
enabled: bool = Field(default=True)
|
||||
timeout: int = Field(default=120, ge=10, le=600)
|
||||
url: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class DataConfig(BaseModel):
|
||||
"""数据配置"""
|
||||
sources: List[DataSourceConfig] = Field(...)
|
||||
|
||||
@field_validator('sources')
|
||||
@classmethod
|
||||
def check_at_least_one(cls, v):
|
||||
"""至少配置一个数据源"""
|
||||
if not v:
|
||||
raise ValueError("必须配置至少一个数据源")
|
||||
return v
|
||||
|
||||
|
||||
class BenchmarkConfig(BaseModel):
|
||||
"""基准配置"""
|
||||
code: str = Field(...)
|
||||
name: str = Field(...)
|
||||
|
||||
|
||||
class BacktestConfig(BaseModel):
|
||||
"""回测配置"""
|
||||
start_date: str = Field(...)
|
||||
end_date: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class MetadataConfig(BaseModel):
|
||||
"""配置元数据"""
|
||||
version: str = Field(default="1.0.0")
|
||||
strategy: str = Field(default="rotation")
|
||||
description: str = Field(default="")
|
||||
last_updated: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class RotationStrategyConfig(BaseModel):
|
||||
"""ETF轮动策略完整配置"""
|
||||
metadata: MetadataConfig = Field(default_factory=MetadataConfig)
|
||||
asset_pools: AssetPool = Field(...)
|
||||
benchmark: BenchmarkConfig = Field(...)
|
||||
backtest: BacktestConfig = Field(...)
|
||||
factor: FactorConfig = Field(default_factory=FactorConfig)
|
||||
rotation: RotationConfig = Field(default_factory=RotationConfig)
|
||||
rebalance: RebalanceConfig = Field(default_factory=RebalanceConfig)
|
||||
premium_control: PremiumControlConfig = Field(default_factory=PremiumControlConfig)
|
||||
data: DataConfig = Field(...)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 配置加载器(独立实现)
|
||||
# ============================================================
|
||||
|
||||
class ConfigLoader:
|
||||
"""配置加载器(独立实现)"""
|
||||
|
||||
def __init__(self, config_dir: str = None):
|
||||
"""初始化"""
|
||||
if config_dir is None:
|
||||
config_dir = Path(__file__).parent
|
||||
|
||||
self.config_dir = Path(config_dir)
|
||||
|
||||
def load(self, config_file: str) -> RotationStrategyConfig:
|
||||
"""加载配置文件"""
|
||||
# 1. 解析路径
|
||||
config_path = self._resolve_path(config_file)
|
||||
|
||||
# 2. 读取 YAML
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config_dict = yaml.safe_load(f)
|
||||
|
||||
# 3. 环境变量替换
|
||||
config_dict = self._substitute_env_vars(config_dict)
|
||||
|
||||
# 4. Pydantic 验证
|
||||
config = RotationStrategyConfig(**config_dict)
|
||||
|
||||
return config
|
||||
|
||||
def _resolve_path(self, config_file: str) -> Path:
|
||||
"""解析配置文件路径"""
|
||||
path = Path(config_file)
|
||||
|
||||
if path.is_absolute():
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"配置文件不存在: {path}")
|
||||
return path
|
||||
|
||||
# 相对路径:先在配置目录查找
|
||||
config_path = self.config_dir / path
|
||||
if config_path.exists():
|
||||
return config_path
|
||||
|
||||
# 然后在当前工作目录查找
|
||||
cwd_path = Path.cwd() / path
|
||||
if cwd_path.exists():
|
||||
return cwd_path
|
||||
|
||||
raise FileNotFoundError(
|
||||
f"配置文件未找到: {path}\n"
|
||||
f"搜索路径:\n"
|
||||
f" - {config_path}\n"
|
||||
f" - {cwd_path}"
|
||||
)
|
||||
|
||||
def _substitute_env_vars(self, config: Any) -> Any:
|
||||
"""替换环境变量"""
|
||||
if isinstance(config, dict):
|
||||
return {
|
||||
key: self._substitute_env_vars(value)
|
||||
for key, value in config.items()
|
||||
}
|
||||
elif isinstance(config, list):
|
||||
return [
|
||||
self._substitute_env_vars(item)
|
||||
for item in config
|
||||
]
|
||||
elif isinstance(config, str):
|
||||
# 匹配 ${VAR_NAME} 或 ${VAR_NAME:default}
|
||||
pattern = r'\$\{([^}]+)\}'
|
||||
|
||||
def replace_match(match):
|
||||
var_expr = match.group(1)
|
||||
|
||||
if ':' in var_expr:
|
||||
var_name, default = var_expr.split(':', 1)
|
||||
else:
|
||||
var_name = var_expr
|
||||
default = None
|
||||
|
||||
value = os.getenv(var_name)
|
||||
|
||||
if value is None:
|
||||
if default is not None:
|
||||
return default
|
||||
else:
|
||||
raise ValueError(
|
||||
f"环境变量未设置: {var_name}\n"
|
||||
f"请设置环境变量或在配置中使用默认值: ${{{var_name}:default_value}}"
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
return re.sub(pattern, replace_match, config)
|
||||
else:
|
||||
return config
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 快捷函数
|
||||
# ============================================================
|
||||
|
||||
def load_rotation_config(config_file: str = 'config_simple.yaml') -> RotationStrategyConfig:
|
||||
"""
|
||||
加载 rotation 配置
|
||||
|
||||
Args:
|
||||
config_file: 配置文件名
|
||||
|
||||
Returns:
|
||||
RotationStrategyConfig: 验证后的配置对象
|
||||
|
||||
环境变量:
|
||||
必须设置 FLASK_API_URL 或提供默认值
|
||||
"""
|
||||
loader = ConfigLoader()
|
||||
return loader.load(config_file)
|
||||
|
||||
|
||||
def get_config_info(config: RotationStrategyConfig) -> dict:
|
||||
"""获取配置摘要"""
|
||||
return {
|
||||
'version': config.metadata.version,
|
||||
'strategy': config.metadata.strategy,
|
||||
'description': config.metadata.description,
|
||||
'asset_count': config.asset_pools.count(),
|
||||
'groups': config.asset_pools.groups,
|
||||
'factor_type': config.factor.type.value,
|
||||
'n_days': config.factor.n_days,
|
||||
'select_num': config.rotation.select_num,
|
||||
'start_date': config.backtest.start_date,
|
||||
'end_date': config.backtest.end_date or '至今',
|
||||
'threshold_mode': config.rotation.threshold.mode.value,
|
||||
}
|
||||
|
||||
|
||||
def print_config_summary(config: RotationStrategyConfig) -> None:
|
||||
"""打印配置摘要"""
|
||||
info = get_config_info(config)
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print(" ETF 轮动策略配置摘要")
|
||||
print("=" * 50)
|
||||
|
||||
print(f"\n版本: {info['version']}")
|
||||
print(f"策略: {info['strategy']}")
|
||||
print(f"描述: {info['description']}")
|
||||
|
||||
print(f"\n资产池:")
|
||||
print(f" 总数: {info['asset_count']}")
|
||||
print(f" 分组: {', '.join(info['groups'])}")
|
||||
|
||||
print(f"\n因子配置:")
|
||||
print(f" 类型: {info['factor_type']}")
|
||||
print(f" 窗口: {info['n_days']} 天")
|
||||
|
||||
print(f"\n轮动配置:")
|
||||
print(f" 选股数: {info['select_num']}")
|
||||
print(f" 阈值模式: {info['threshold_mode']}")
|
||||
|
||||
print(f"\n回测配置:")
|
||||
print(f" 起始: {info['start_date']}")
|
||||
print(f" 结束: {info['end_date']}")
|
||||
|
||||
print(f"\n各分组资产:")
|
||||
for group, assets in config.asset_pools.by_group.items():
|
||||
print(f"\n [{group}] ({len(assets)} 个):")
|
||||
for code, asset in assets.items():
|
||||
cross_market = "✓ 跨市场" if asset.is_cross_market else ""
|
||||
print(f" - {asset.name} ({code})")
|
||||
print(f" 信号: {asset.signal_source} → 交易: {asset.trade_source} {cross_market}")
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 测试
|
||||
# ============================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""测试配置加载器"""
|
||||
# 设置环境变量(如果没有设置)
|
||||
if 'FLASK_API_URL' not in os.environ:
|
||||
print("提示: FLASK_API_URL 未设置,使用默认值")
|
||||
os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz'
|
||||
|
||||
try:
|
||||
# 加载配置
|
||||
print("\n[测试] 加载 config_simple.yaml...")
|
||||
config = load_rotation_config()
|
||||
|
||||
# 打印摘要
|
||||
print_config_summary(config)
|
||||
|
||||
# 测试基本功能
|
||||
print("\n[验证] 基本功能测试:")
|
||||
print(f" ✓ 配置对象类型: {type(config).__name__}")
|
||||
print(f" ✓ 资产池对象类型: {type(config.asset_pools).__name__}")
|
||||
print(f" ✓ 资产数量: {config.asset_pools.count()}")
|
||||
print(f" ✓ 分组数量: {len(config.asset_pools.groups)}")
|
||||
|
||||
# 测试数据源
|
||||
print("\n[验证] 数据源配置:")
|
||||
for i, source in enumerate(config.data.sources, 1):
|
||||
print(f" {i}. {source.type.value}")
|
||||
print(f" URL: {source.url}")
|
||||
print(f" 启用: {source.enabled}")
|
||||
|
||||
print("\n✓ 所有测试通过!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ 加载失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
Reference in New Issue
Block a user