refactor(archive): move unused modules to archive/
Archive legacy framework and utility modules that are no longer referenced by the active core (datasource/ and rotation/): - framework/ -> archive/framework/ - framework_v2/ -> archive/framework_v2/ - strategies/ -> archive/strategies/ - config/ -> archive/config/ - visualization/ -> archive/visualization/ - scripts/ -> archive/scripts/ - tests/ -> archive/tests/ - run_rotation.py, run_us_rotation.py -> archive/single_files/ - compare_*.py, test_api_dates.py -> archive/single_files/
This commit is contained in:
260
archive/framework_v2/core/strategy.py
Normal file
260
archive/framework_v2/core/strategy.py
Normal file
@@ -0,0 +1,260 @@
|
||||
"""
|
||||
策略抽象基类
|
||||
|
||||
所有策略必须继承此类并实现必要方法
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional, Any
|
||||
import pandas as pd
|
||||
|
||||
from framework_v2.config.schemas import StrategyConfig
|
||||
|
||||
|
||||
class StrategyBase(ABC):
|
||||
"""
|
||||
策略抽象基类(V2 增强版)
|
||||
|
||||
定义策略的标准生命周期:
|
||||
1. 初始化配置(使用 Pydantic Schema)
|
||||
2. 获取数据(支持多数据源)
|
||||
3. 计算因子(使用框架因子库)
|
||||
4. 生成信号(策略特定逻辑)
|
||||
5. 执行回测(框架通用执行器)
|
||||
|
||||
子类必须实现:
|
||||
- get_codes(): 获取标的列表
|
||||
- compute_factors(): 计算因子
|
||||
- generate_signals(): 生成信号
|
||||
- manage_positions(): 仓位管理
|
||||
"""
|
||||
|
||||
INTERFACE_VERSION = 2 # V2 版本
|
||||
|
||||
def __init__(self, config: StrategyConfig):
|
||||
"""
|
||||
初始化策略
|
||||
|
||||
Args:
|
||||
config: 策略配置(Pydantic Schema)
|
||||
"""
|
||||
self.config = config
|
||||
self.name = config.metadata.strategy
|
||||
|
||||
# 组件将在子类中初始化
|
||||
self._data_fetcher = None
|
||||
self._executor = None
|
||||
|
||||
@abstractmethod
|
||||
def get_codes(self) -> list:
|
||||
"""
|
||||
获取标的列表(策略必须实现)
|
||||
|
||||
Returns:
|
||||
标的代码列表,如 ['399006.SZ', 'NDX', 'GC=F']
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def compute_factors(self, data: Dict[str, pd.DataFrame]) -> Dict[str, pd.Series]:
|
||||
"""
|
||||
计算因子(策略必须实现)
|
||||
|
||||
Args:
|
||||
data: 数据字典 {code: DataFrame}
|
||||
|
||||
Returns:
|
||||
因子字典 {code: Series}
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def generate_signals(self, factors: Dict[str, pd.Series]) -> pd.DataFrame:
|
||||
"""
|
||||
生成信号(策略必须实现)
|
||||
|
||||
Args:
|
||||
factors: 因子字典 {code: Series}
|
||||
|
||||
Returns:
|
||||
信号 DataFrame(包含 'signal' 列,1=买入,0=空仓)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def manage_positions(self, signals: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
仓位管理(策略必须实现)
|
||||
|
||||
Args:
|
||||
signals: 信号 DataFrame
|
||||
|
||||
Returns:
|
||||
仓位 DataFrame(包含 'weight' 列,权重和为 1)
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_data(self) -> Dict[str, pd.DataFrame]:
|
||||
"""
|
||||
获取数据(框架实现,策略可覆盖)
|
||||
|
||||
Returns:
|
||||
数据字典 {code: DataFrame}
|
||||
"""
|
||||
if self._data_fetcher is None:
|
||||
self._data_fetcher = self._create_data_fetcher()
|
||||
|
||||
codes = self.get_codes()
|
||||
|
||||
# 处理 end_date 为 None 的情况(使用今天)
|
||||
from datetime import date
|
||||
start = self.config.backtest.start_date
|
||||
end = self.config.backtest.end_date
|
||||
if end is None:
|
||||
end = date.today().strftime('%Y-%m-%d')
|
||||
|
||||
# 批量获取数据(fetch_indices 返回 {code: DataFrame})
|
||||
try:
|
||||
data = self._data_fetcher.fetch_indices(
|
||||
codes=codes,
|
||||
start=start,
|
||||
end=end
|
||||
)
|
||||
return data
|
||||
except Exception as e:
|
||||
print(f" 错误: 数据获取失败 - {e}")
|
||||
return {}
|
||||
|
||||
def compute_factors(self, data: Dict[str, Any]) -> pd.DataFrame:
|
||||
"""
|
||||
计算因子(可选覆盖)
|
||||
|
||||
Args:
|
||||
data: 数据字典
|
||||
|
||||
Returns:
|
||||
因子 DataFrame(日期 × 标的)
|
||||
"""
|
||||
if self._factor is None:
|
||||
self._factor = self.init_factors()
|
||||
|
||||
# 默认实现:遍历标的计算因子
|
||||
factor_values = {}
|
||||
for code in data.get('valid_codes', []):
|
||||
if code in data.get('index_data', {}):
|
||||
factor_values[code] = self._factor.compute(data['index_data'][code])
|
||||
|
||||
return pd.DataFrame(factor_values)
|
||||
|
||||
def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
生成信号
|
||||
|
||||
Args:
|
||||
factor_df: 因子 DataFrame
|
||||
|
||||
Returns:
|
||||
信号 DataFrame(包含 'signal' 列)
|
||||
"""
|
||||
if self._signal_generator is None:
|
||||
self._signal_generator = self.init_signal_generator()
|
||||
|
||||
return self._signal_generator.generate(factor_df)
|
||||
|
||||
def run(self, data: Optional[Dict[str, pd.DataFrame]] = None,
|
||||
export_detail: bool = False, detail_path: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
运行完整回测流程(框架标准流程)
|
||||
|
||||
Args:
|
||||
data: 可选,如不提供则自动获取
|
||||
export_detail: 是否导出逐日明细(默认 False)
|
||||
detail_path: 明细 JSON 文件路径(export_detail=True 时必需)
|
||||
|
||||
Returns:
|
||||
回测结果字典,包含:
|
||||
- equity_curve: 净值曲线
|
||||
- trades: 交易记录
|
||||
- metrics: 绩效指标
|
||||
"""
|
||||
# 1. 获取数据并保存
|
||||
if data is None:
|
||||
print("[1/5] 获取数据...")
|
||||
data = self.get_data()
|
||||
self._data = data # 保存数据供导出使用
|
||||
print(f" 获取 {len(data)} 个标的")
|
||||
else:
|
||||
self._data = data
|
||||
|
||||
# 2. 计算因子
|
||||
print("[2/5] 计算因子...")
|
||||
factors = self.compute_factors(data)
|
||||
print(f" 计算 {len(factors)} 个因子")
|
||||
|
||||
# 3. 生成信号
|
||||
print("[3/5] 生成信号...")
|
||||
signals = self.generate_signals(factors)
|
||||
print(f" 生成 {signals.shape[0]} 个信号")
|
||||
|
||||
# 4. 仓位管理
|
||||
print("[4/5] 仓位管理...")
|
||||
positions = self.manage_positions(signals)
|
||||
# positions 可能是信号矩阵或权重矩阵,计算平均仓位
|
||||
if hasattr(positions, 'sum'):
|
||||
avg_position = positions.sum(axis=1).mean() if hasattr(positions.sum(axis=1), 'mean') else 0
|
||||
print(f" 平均仓位: {avg_position:.2%}")
|
||||
else:
|
||||
print(f" 仓位管理完成")
|
||||
|
||||
# 5. 执行回测
|
||||
print("[5/5] 执行回测...")
|
||||
result = self._execute_backtest(positions, data)
|
||||
print(f" 回测完成")
|
||||
|
||||
# 6. 可选:导出逐日明细
|
||||
if export_detail:
|
||||
if not detail_path:
|
||||
raise ValueError("export_detail=True 时需要指定 detail_path")
|
||||
|
||||
print("\n[额外] 导出逐日明细...")
|
||||
self._export_backtest_detail(
|
||||
factors=factors,
|
||||
signals=signals,
|
||||
positions=positions,
|
||||
result=result,
|
||||
output_path=detail_path
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _execute_backtest(self, signals: pd.DataFrame, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
执行回测(子类可覆盖)
|
||||
|
||||
Args:
|
||||
signals: 信号 DataFrame
|
||||
data: 数据字典
|
||||
|
||||
Returns:
|
||||
回测结果
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement _execute_backtest()")
|
||||
|
||||
def _create_data_fetcher(self):
|
||||
"""
|
||||
创建数据获取器(框架实现)
|
||||
|
||||
Returns:
|
||||
DataFetcher 实例
|
||||
"""
|
||||
from framework_v2.shared.data import FlaskAPIFetcher
|
||||
|
||||
# 使用配置中的第一个启用的数据源
|
||||
for source_config in self.config.data.sources:
|
||||
if source_config.enabled and source_config.type.value == 'flask_api':
|
||||
return FlaskAPIFetcher(
|
||||
base_url=source_config.url,
|
||||
timeout=source_config.timeout
|
||||
)
|
||||
|
||||
raise ValueError("未找到可用的数据源配置")
|
||||
Reference in New Issue
Block a user