refactor(archive): move unused modules to archive/
Archive legacy framework and utility modules that are no longer referenced by the active core (datasource/ and rotation/): - framework/ -> archive/framework/ - framework_v2/ -> archive/framework_v2/ - strategies/ -> archive/strategies/ - config/ -> archive/config/ - visualization/ -> archive/visualization/ - scripts/ -> archive/scripts/ - tests/ -> archive/tests/ - run_rotation.py, run_us_rotation.py -> archive/single_files/ - compare_*.py, test_api_dates.py -> archive/single_files/
This commit is contained in:
60
archive/framework/__init__.py
Normal file
60
archive/framework/__init__.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
框架统一入口(通用)
|
||||
|
||||
只导出抽象接口,具体实现在strategies/shared/
|
||||
"""
|
||||
|
||||
# 因子层抽象
|
||||
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
|
||||
|
||||
# 信号层抽象
|
||||
from framework.signals import SignalGenerator
|
||||
|
||||
# 风控层抽象
|
||||
from framework.risk import Position, RiskControl, CallbackHook
|
||||
|
||||
# 策略层抽象
|
||||
from framework.strategy import StrategyBase
|
||||
|
||||
# 执行层抽象
|
||||
from framework.execution import Portfolio, Executor, BacktestExecutor, DryRunExecutor
|
||||
|
||||
# 配置层
|
||||
from framework.config import ConfigLoader, StrategyConfig
|
||||
|
||||
# 数据层抽象
|
||||
from framework.data import OHLCVData, DataSource, DataCache
|
||||
|
||||
|
||||
__all__ = [
|
||||
# 因子层
|
||||
'FactorBase',
|
||||
'FactorRegistry',
|
||||
'FactorCombiner',
|
||||
|
||||
# 信号层
|
||||
'SignalGenerator',
|
||||
|
||||
# 风控层
|
||||
'Position',
|
||||
'RiskControl',
|
||||
'CallbackHook',
|
||||
|
||||
# 策略层
|
||||
'StrategyBase',
|
||||
|
||||
# 执行层
|
||||
'Portfolio',
|
||||
'Executor',
|
||||
'BacktestExecutor',
|
||||
'DryRunExecutor',
|
||||
|
||||
# 配置层
|
||||
'ConfigLoader',
|
||||
'StrategyConfig',
|
||||
|
||||
# 数据层
|
||||
'OHLCVData',
|
||||
'DataSource',
|
||||
'DataCache',
|
||||
]
|
||||
103
archive/framework/config/__init__.py
Normal file
103
archive/framework/config/__init__.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
配置层抽象接口(通用)
|
||||
|
||||
只提供配置加载和验证机制
|
||||
"""
|
||||
|
||||
import yaml
|
||||
from typing import Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class ConfigLoader:
|
||||
"""
|
||||
配置加载器(通用)
|
||||
|
||||
支持从YAML文件加载配置
|
||||
"""
|
||||
|
||||
def __init__(self, config_path: Optional[str] = None):
|
||||
"""初始化配置加载器"""
|
||||
self._config: Dict[str, Any] = {}
|
||||
|
||||
if config_path:
|
||||
self.load(config_path)
|
||||
|
||||
def load(self, config_path: str) -> Dict[str, Any]:
|
||||
"""从YAML文件加载配置"""
|
||||
path = Path(config_path)
|
||||
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
self._config = yaml.safe_load(f) or {}
|
||||
|
||||
return self._config
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
"""获取配置项"""
|
||||
return self._config.get(key, default)
|
||||
|
||||
def get_section(self, section: str) -> Dict[str, Any]:
|
||||
"""获取配置区块"""
|
||||
return self._config.get(section, {})
|
||||
|
||||
def validate(self, required_keys: list) -> bool:
|
||||
"""验证必填配置项"""
|
||||
missing = [key for key in required_keys if key not in self._config]
|
||||
|
||||
if missing:
|
||||
raise ValueError(f"Missing required config keys: {missing}")
|
||||
|
||||
return True
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"ConfigLoader(keys={list(self._config.keys())})"
|
||||
|
||||
|
||||
class StrategyConfig:
|
||||
"""
|
||||
策略配置类(通用)
|
||||
|
||||
用于封装策略配置
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
"""初始化策略配置"""
|
||||
self._config = config
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""策略名称"""
|
||||
return self._config.get('strategy', {}).get('name', 'unknown')
|
||||
|
||||
@property
|
||||
def select_num(self) -> int:
|
||||
"""选中数量"""
|
||||
return self._config.get('signal', {}).get('select_num', 3)
|
||||
|
||||
@property
|
||||
def stoploss(self) -> float:
|
||||
"""止损阈值"""
|
||||
return self._config.get('risk', {}).get('stop_loss', -0.05)
|
||||
|
||||
def get_factor_config(self) -> list:
|
||||
"""获取因子配置"""
|
||||
return self._config.get('factors', [])
|
||||
|
||||
def get_signal_config(self) -> dict:
|
||||
"""获取信号配置"""
|
||||
return self._config.get('signal', {})
|
||||
|
||||
def get_risk_config(self) -> list:
|
||||
"""获取风控配置"""
|
||||
return self._config.get('risk', [])
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return self._config.copy()
|
||||
|
||||
|
||||
# 导出抽象接口
|
||||
__all__ = ['ConfigLoader', 'StrategyConfig']
|
||||
126
archive/framework/data/__init__.py
Normal file
126
archive/framework/data/__init__.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""
|
||||
数据层抽象接口(通用)
|
||||
|
||||
只提供数据获取抽象接口,具体实现在strategies/shared/data/
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Any
|
||||
import pandas as pd
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class OHLCVData:
|
||||
"""
|
||||
OHLCV数据结构(通用)
|
||||
|
||||
标准化的K线数据格式
|
||||
"""
|
||||
code: str
|
||||
name: str = ""
|
||||
start_date: datetime = None
|
||||
end_date: datetime = None
|
||||
|
||||
# OHLCV数据DataFrame
|
||||
data: pd.DataFrame = None
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
"""数据长度"""
|
||||
return len(self.data) if self.data is not None else 0
|
||||
|
||||
def validate(self) -> bool:
|
||||
"""验证数据完整性"""
|
||||
if self.data is None or self.data.empty:
|
||||
return False
|
||||
|
||||
required_cols = ['close']
|
||||
return all(col in self.data.columns for col in required_cols)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"OHLCVData(code={self.code}, name={self.name}, length={self.length})"
|
||||
|
||||
|
||||
class DataSource(ABC):
|
||||
"""
|
||||
数据源抽象接口
|
||||
|
||||
所有数据源必须实现fetch方法
|
||||
"""
|
||||
|
||||
name: str = "base"
|
||||
|
||||
def __init__(self, **params):
|
||||
"""初始化数据源参数"""
|
||||
self._params = params
|
||||
|
||||
@abstractmethod
|
||||
def fetch(self, code: str, start: str, end: str) -> OHLCVData:
|
||||
"""
|
||||
获取单个标的的OHLCV数据
|
||||
|
||||
Args:
|
||||
code: 标的代码
|
||||
start: 开始日期 (YYYY-MM-DD)
|
||||
end: 结束日期 (YYYY-MM-DD)
|
||||
|
||||
Returns:
|
||||
OHLCVData对象
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
|
||||
"""
|
||||
批量获取多个标的的OHLCV数据
|
||||
|
||||
Args:
|
||||
codes: 标的代码列表
|
||||
start: 开始日期
|
||||
end: 结束日期
|
||||
|
||||
Returns:
|
||||
{code: OHLCVData}字典
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_supported_codes(self) -> List[str]:
|
||||
"""获取支持的数据源代码列表"""
|
||||
return []
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(name={self.name})"
|
||||
|
||||
|
||||
class DataCache(ABC):
|
||||
"""
|
||||
数据缓存抽象接口(通用)
|
||||
|
||||
支持本地缓存管理
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, code: str, start: str, end: str) -> Optional[OHLCVData]:
|
||||
"""从缓存获取数据"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(self, code: str, data: OHLCVData) -> None:
|
||||
"""写入缓存"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_fresh(self, code: str, max_age_days: int = 1) -> bool:
|
||||
"""检查缓存是否新鲜"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clear(self, code: Optional[str] = None) -> None:
|
||||
"""清空缓存"""
|
||||
pass
|
||||
|
||||
|
||||
# 导出抽象接口
|
||||
__all__ = ['OHLCVData', 'DataSource', 'DataCache']
|
||||
BIN
archive/framework/data/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
archive/framework/data/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
500
archive/framework/execution/__init__.py
Normal file
500
archive/framework/execution/__init__.py
Normal file
@@ -0,0 +1,500 @@
|
||||
"""
|
||||
执行层抽象接口(通用)
|
||||
|
||||
只提供抽象基类和Portfolio数据结构,具体执行器可扩展
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
|
||||
from framework.risk import Position
|
||||
|
||||
|
||||
class Portfolio:
|
||||
"""
|
||||
投资组合数据结构(通用)
|
||||
|
||||
用于管理持仓集合
|
||||
"""
|
||||
|
||||
def __init__(self, initial_capital: float = 100000):
|
||||
"""初始化投资组合"""
|
||||
self.initial_capital = initial_capital
|
||||
self.cash = initial_capital
|
||||
self.positions: Dict[str, Position] = {}
|
||||
self.trades: List[Dict] = []
|
||||
self._net_value_history: List[float] = []
|
||||
|
||||
def add_position(self, code: str, price: float, quantity: float, time: datetime) -> None:
|
||||
"""添加持仓"""
|
||||
position = Position(
|
||||
code=code,
|
||||
entry_price=price,
|
||||
current_price=price,
|
||||
entry_time=time,
|
||||
quantity=quantity
|
||||
)
|
||||
self.positions[code] = position
|
||||
self.cash -= price * quantity
|
||||
|
||||
self.trades.append({
|
||||
'action': 'BUY',
|
||||
'code': code,
|
||||
'price': price,
|
||||
'quantity': quantity,
|
||||
'time': time
|
||||
})
|
||||
|
||||
def remove_position(self, code: str, price: float, time: datetime) -> float:
|
||||
"""移除持仓"""
|
||||
if code not in self.positions:
|
||||
return 0
|
||||
|
||||
position = self.positions[code]
|
||||
profit = (price - position.entry_price) * position.quantity
|
||||
self.cash += price * position.quantity
|
||||
|
||||
del self.positions[code]
|
||||
|
||||
self.trades.append({
|
||||
'action': 'SELL',
|
||||
'code': code,
|
||||
'price': price,
|
||||
'quantity': position.quantity,
|
||||
'time': time,
|
||||
'profit': profit
|
||||
})
|
||||
|
||||
return profit
|
||||
|
||||
def update_prices(self, prices: Dict[str, float]) -> None:
|
||||
"""更新持仓价格"""
|
||||
for code, price in prices.items():
|
||||
if code in self.positions:
|
||||
self.positions[code].current_price = price
|
||||
|
||||
def get_net_value(self) -> float:
|
||||
"""计算净值"""
|
||||
positions_value = sum(
|
||||
pos.current_price * pos.quantity for pos in self.positions.values()
|
||||
)
|
||||
return self.cash + positions_value
|
||||
|
||||
def record_net_value(self) -> None:
|
||||
"""记录当前净值"""
|
||||
self._net_value_history.append(self.get_net_value())
|
||||
|
||||
def get_net_value_series(self) -> pd.Series:
|
||||
"""获取净值序列"""
|
||||
return pd.Series(self._net_value_history)
|
||||
|
||||
def get_weight(self, code: str) -> float:
|
||||
"""计算持仓权重"""
|
||||
if code not in self.positions:
|
||||
return 0
|
||||
|
||||
position_value = self.positions[code].current_price * self.positions[code].quantity
|
||||
return position_value / self.get_net_value()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Portfolio(capital={self.cash:.2f}, positions={len(self.positions)})"
|
||||
|
||||
|
||||
class Executor(ABC):
|
||||
"""
|
||||
执行器抽象基类
|
||||
|
||||
所有执行器必须实现execute方法
|
||||
"""
|
||||
|
||||
mode: str = "base"
|
||||
|
||||
def __init__(self, **params):
|
||||
"""初始化执行器参数"""
|
||||
self._params = params
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, signals: pd.DataFrame, data: pd.DataFrame) -> Portfolio:
|
||||
"""
|
||||
执行信号
|
||||
|
||||
Args:
|
||||
signals: 信号DataFrame
|
||||
data: OHLCV数据
|
||||
|
||||
Returns:
|
||||
Portfolio对象
|
||||
"""
|
||||
pass
|
||||
|
||||
def __repr__(self) -> str:
|
||||
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
|
||||
return f"{self.__class__.__name__}({params_str})"
|
||||
|
||||
|
||||
class BacktestExecutor(Executor):
|
||||
"""
|
||||
完整回测执行器(通用)
|
||||
|
||||
支持:
|
||||
- 日收益率计算
|
||||
- 交易成本扣除
|
||||
- 净值计算(起点归一化)
|
||||
- 基准对比
|
||||
- 持仓跟踪
|
||||
"""
|
||||
|
||||
mode = "backtest"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
initial_capital: float = 100000,
|
||||
trade_cost: float = 0.001,
|
||||
select_num: int = 1,
|
||||
benchmark_data: Optional[pd.DataFrame] = None
|
||||
):
|
||||
super().__init__(
|
||||
initial_capital=initial_capital,
|
||||
trade_cost=trade_cost,
|
||||
select_num=select_num,
|
||||
benchmark_data=benchmark_data
|
||||
)
|
||||
self.initial_capital = initial_capital
|
||||
self.trade_cost = trade_cost
|
||||
self.select_num = select_num
|
||||
self.benchmark_data = benchmark_data
|
||||
|
||||
def execute(self, signals: pd.DataFrame, data: pd.DataFrame) -> Portfolio:
|
||||
"""
|
||||
执行完整回测
|
||||
|
||||
Args:
|
||||
signals: 信号DataFrame,包含signal或信号列
|
||||
data: OHLCV数据和日收益率数据
|
||||
|
||||
Returns:
|
||||
Portfolio对象(含净值序列、交易记录)
|
||||
"""
|
||||
portfolio = Portfolio(self.initial_capital)
|
||||
|
||||
# 支持中英文列名
|
||||
signal_col = 'signal' if 'signal' in signals.columns else '信号'
|
||||
|
||||
# 删除空信号行
|
||||
signals = signals.dropna(subset=[signal_col])
|
||||
signals = signals[signals[signal_col] != '']
|
||||
|
||||
if signals.empty:
|
||||
return portfolio
|
||||
|
||||
# 计算策略日收益率
|
||||
result = self._calculate_daily_returns(signals, data, signal_col)
|
||||
|
||||
# 扣除交易成本(同时记录调仓事件)
|
||||
result, rebalance_events = self._apply_trade_cost_with_events(result, signals, signal_col)
|
||||
|
||||
# 计算净值(起点归一化)
|
||||
result = self._calculate_net_value(result)
|
||||
|
||||
# 计算基准净值
|
||||
result = self._calculate_benchmark(result)
|
||||
|
||||
# 记录净值历史
|
||||
for date in result.index:
|
||||
portfolio.record_net_value()
|
||||
|
||||
# 存储回测结果
|
||||
portfolio.backtest_result = result
|
||||
portfolio.rebalance_events = rebalance_events # 新增:调仓事件记录
|
||||
|
||||
# 补充调仓事件的净值信息
|
||||
if not rebalance_events.empty:
|
||||
rebalance_events = self._enrich_rebalance_events(rebalance_events, result)
|
||||
portfolio.rebalance_events = rebalance_events
|
||||
|
||||
return portfolio
|
||||
|
||||
def _calculate_daily_returns(self, signals: pd.DataFrame, data: pd.DataFrame, signal_col: str = 'signal') -> pd.DataFrame:
|
||||
"""计算策略日收益率"""
|
||||
result = signals.copy()
|
||||
|
||||
# 日收益率列名格式:日收益率_{code} 或 日收益率_{code}
|
||||
return_cols = [col for col in data.columns if col.startswith('日收益率_')]
|
||||
|
||||
if self.select_num == 1:
|
||||
# 单标的策略
|
||||
def calc_return(row):
|
||||
signal = row[signal_col]
|
||||
if not signal or pd.isna(signal):
|
||||
return 0.0
|
||||
return data.loc[row.name, f'日收益率_{signal}'] if f'日收益率_{signal}' in data.columns else 0.0
|
||||
|
||||
result['策略日收益率'] = result.apply(calc_return, axis=1)
|
||||
else:
|
||||
# 多标的策略(等权组合)
|
||||
# 按实际持仓数量等权分配:选出2只时每只50%,选出1只时100%
|
||||
def calc_multi_return(row):
|
||||
codes = [c for c in row[signal_col].split(',') if c]
|
||||
if not codes:
|
||||
return 0.0
|
||||
returns = []
|
||||
for c in codes:
|
||||
ret = data.loc[row.name, f'日收益率_{c}'] if f'日收益率_{c}' in data.columns else None
|
||||
if ret is not None and pd.notna(ret):
|
||||
returns.append(ret)
|
||||
return np.mean(returns) if returns else 0.0
|
||||
|
||||
result['策略日收益率'] = result.apply(calc_multi_return, axis=1)
|
||||
|
||||
return result
|
||||
|
||||
def _apply_trade_cost(self, result: pd.DataFrame, signals: pd.DataFrame, signal_col: str = 'signal') -> pd.DataFrame:
|
||||
"""扣除交易成本"""
|
||||
if self.trade_cost <= 0:
|
||||
return result
|
||||
|
||||
prev_signal = signals[signal_col].shift(1)
|
||||
|
||||
if self.select_num == 1:
|
||||
# 单标的策略:调仓时扣除固定成本
|
||||
changed = (signals[signal_col] != prev_signal) & prev_signal.notna()
|
||||
result.loc[changed, '策略日收益率'] -= self.trade_cost
|
||||
else:
|
||||
# 多标的策略:按换手率比例扣除成本
|
||||
turnover_list = []
|
||||
for curr, prev in zip(signals[signal_col], prev_signal):
|
||||
if pd.isna(prev) or curr == prev:
|
||||
turnover_list.append(0.0)
|
||||
else:
|
||||
old = set(prev.split(','))
|
||||
new = set(curr.split(','))
|
||||
swapped = len(old - new)
|
||||
turnover = swapped / len(old) if old else 0.0
|
||||
turnover_list.append(turnover)
|
||||
|
||||
result['换手率'] = turnover_list
|
||||
result['策略日收益率'] -= result['换手率'] * self.trade_cost
|
||||
|
||||
return result
|
||||
|
||||
def _apply_trade_cost_with_events(self, result: pd.DataFrame, signals: pd.DataFrame, signal_col: str = 'signal') -> tuple:
|
||||
"""
|
||||
扣除交易成本并记录调仓事件
|
||||
|
||||
Returns:
|
||||
(result, rebalance_events): 回测结果DataFrame和调仓事件DataFrame
|
||||
"""
|
||||
prev_signal = signals[signal_col].shift(1)
|
||||
|
||||
# 记录调仓事件
|
||||
rebalance_events = []
|
||||
last_rebalance_date = None
|
||||
|
||||
# 先计算累积收益率(用于计算调仓前后的净值)
|
||||
cum_return_before_cost = result['策略日收益率'].copy()
|
||||
|
||||
if self.select_num == 1:
|
||||
# 单标的策略
|
||||
for i, (date, curr, prev) in enumerate(zip(signals.index, signals[signal_col], prev_signal)):
|
||||
# 检查是否调仓
|
||||
is_rebalance = False
|
||||
turnover = 0.0
|
||||
added = []
|
||||
removed = []
|
||||
|
||||
if pd.notna(prev) and curr != prev:
|
||||
is_rebalance = True
|
||||
turnover = 1.0 if prev else 0.0
|
||||
added = [curr] if curr else []
|
||||
removed = [prev] if prev else []
|
||||
# 扣除成本
|
||||
result.loc[date, '策略日收益率'] -= self.trade_cost
|
||||
|
||||
# 记录调仓事件
|
||||
if is_rebalance:
|
||||
# 计算持仓天数
|
||||
holding_days = 0
|
||||
if last_rebalance_date is not None:
|
||||
holding_days = (date - last_rebalance_date).days
|
||||
|
||||
event = {
|
||||
'日期': date,
|
||||
'调仓前持仓': prev if pd.notna(prev) else '',
|
||||
'调仓后持仓': curr,
|
||||
'调入标的': ','.join(added) if added else '',
|
||||
'调出标的': ','.join(removed) if removed else '',
|
||||
'换手率': turnover,
|
||||
'调仓成本': self.trade_cost * turnover,
|
||||
'持仓天数': holding_days,
|
||||
'当日收益': result.loc[date, '策略日收益率'] + self.trade_cost * turnover, # 原始收益(扣除成本前)
|
||||
}
|
||||
rebalance_events.append(event)
|
||||
last_rebalance_date = date
|
||||
|
||||
else:
|
||||
# 多标的策略
|
||||
turnover_list = []
|
||||
for i, (date, curr, prev) in enumerate(zip(signals.index, signals[signal_col], prev_signal)):
|
||||
# 检查是否调仓
|
||||
is_rebalance = False
|
||||
turnover = 0.0
|
||||
added = []
|
||||
removed = []
|
||||
|
||||
if pd.notna(prev) and curr != prev:
|
||||
old = set(prev.split(',')) if prev else set()
|
||||
new = set(curr.split(',')) if curr else set()
|
||||
added = list(new - old)
|
||||
removed = list(old - new)
|
||||
swapped = len(removed)
|
||||
turnover = swapped / len(old) if old else 0.0
|
||||
is_rebalance = len(added) > 0 or len(removed) > 0
|
||||
turnover_list.append(turnover)
|
||||
# 扣除成本
|
||||
result.loc[date, '策略日收益率'] -= turnover * self.trade_cost
|
||||
else:
|
||||
turnover_list.append(0.0)
|
||||
|
||||
# 记录调仓事件
|
||||
if is_rebalance:
|
||||
# 计算持仓天数
|
||||
holding_days = 0
|
||||
if last_rebalance_date is not None:
|
||||
holding_days = (date - last_rebalance_date).days
|
||||
|
||||
event = {
|
||||
'日期': date,
|
||||
'调仓前持仓': prev if pd.notna(prev) else '',
|
||||
'调仓后持仓': curr,
|
||||
'调入标的': ','.join(added) if added else '',
|
||||
'调出标的': ','.join(removed) if removed else '',
|
||||
'换手率': turnover,
|
||||
'调仓成本': self.trade_cost * turnover,
|
||||
'持仓天数': holding_days,
|
||||
'当日收益': result.loc[date, '策略日收益率'] + turnover * self.trade_cost, # 原始收益(扣除成本前)
|
||||
}
|
||||
rebalance_events.append(event)
|
||||
last_rebalance_date = date
|
||||
|
||||
result['换手率'] = turnover_list
|
||||
|
||||
# 转换为DataFrame
|
||||
rebalance_df = pd.DataFrame(rebalance_events) if rebalance_events else pd.DataFrame()
|
||||
if not rebalance_df.empty:
|
||||
rebalance_df['日期'] = pd.to_datetime(rebalance_df['日期'])
|
||||
rebalance_df = rebalance_df.set_index('日期')
|
||||
|
||||
return result, rebalance_df
|
||||
|
||||
def _enrich_rebalance_events(self, rebalance_df: pd.DataFrame, result: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
补充调仓事件的净值信息
|
||||
|
||||
Args:
|
||||
rebalance_df: 调仓事件DataFrame
|
||||
result: 回测结果DataFrame(含净值序列)
|
||||
|
||||
Returns:
|
||||
补充净值信息后的调仓事件DataFrame
|
||||
"""
|
||||
# 计算调仓前后净值变化
|
||||
nav_before_list = []
|
||||
nav_after_list = []
|
||||
nav_change_list = []
|
||||
|
||||
for date in rebalance_df.index:
|
||||
# 获取调仓日的净值
|
||||
if date in result.index:
|
||||
# 调仓前净值:前一天收盘净值
|
||||
prev_date_idx = result.index.get_loc(date) - 1
|
||||
if prev_date_idx >= 0:
|
||||
nav_before = result['策略净值'].iloc[prev_date_idx]
|
||||
else:
|
||||
nav_before = 1.0
|
||||
|
||||
# 调仓后净值:当天收盘净值
|
||||
nav_after = result.loc[date, '策略净值']
|
||||
|
||||
# 净值变化
|
||||
nav_change = (nav_after / nav_before - 1) * 100
|
||||
else:
|
||||
nav_before = None
|
||||
nav_after = None
|
||||
nav_change = None
|
||||
|
||||
nav_before_list.append(nav_before)
|
||||
nav_after_list.append(nav_after)
|
||||
nav_change_list.append(nav_change)
|
||||
|
||||
# 添加净值信息列
|
||||
rebalance_df['调仓前净值'] = nav_before_list
|
||||
rebalance_df['调仓后净值'] = nav_after_list
|
||||
rebalance_df['净值变化%'] = nav_change_list
|
||||
|
||||
return rebalance_df
|
||||
|
||||
def _calculate_net_value(self, result: pd.DataFrame) -> pd.DataFrame:
|
||||
"""计算净值(起点归一化)"""
|
||||
result['策略净值'] = (1 + result['策略日收益率']).cumprod()
|
||||
|
||||
# 归一化:确保净值起点为1.0
|
||||
result['策略净值'] = result['策略净值'] / result['策略净值'].iloc[0]
|
||||
|
||||
return result
|
||||
|
||||
def _calculate_benchmark(self, result: pd.DataFrame) -> pd.DataFrame:
|
||||
"""计算基准净值"""
|
||||
if self.benchmark_data is None:
|
||||
return result
|
||||
|
||||
# 获取基准收益率
|
||||
if isinstance(self.benchmark_data, pd.DataFrame):
|
||||
if 'close' in self.benchmark_data.columns:
|
||||
bench_close = self.benchmark_data['close']
|
||||
else:
|
||||
bench_close = self.benchmark_data.iloc[:, 0]
|
||||
else:
|
||||
bench_close = self.benchmark_data
|
||||
|
||||
bench_ret = bench_close.pct_change().dropna()
|
||||
common_dates = result.index.intersection(bench_ret.index)
|
||||
bench_ret = bench_ret.loc[common_dates]
|
||||
|
||||
result['基准日收益率'] = bench_ret.reindex(result.index, fill_value=0)
|
||||
result['基准净值'] = (1 + result['基准日收益率']).cumprod()
|
||||
result['基准净值'] = result['基准净值'] / result['基准净值'].iloc[0]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class DryRunExecutor(Executor):
|
||||
"""
|
||||
Dry-run执行器(通用)
|
||||
|
||||
用于模拟运行,不实际执行交易
|
||||
"""
|
||||
|
||||
mode = "dry_run"
|
||||
|
||||
def __init__(self, verbose: bool = True):
|
||||
super().__init__(verbose=verbose)
|
||||
self.verbose = verbose
|
||||
|
||||
def execute(self, signals: pd.DataFrame, data: pd.DataFrame) -> Portfolio:
|
||||
"""模拟执行"""
|
||||
portfolio = Portfolio(100000)
|
||||
|
||||
for date in signals.index:
|
||||
signal = signals.loc[date, 'signal']
|
||||
|
||||
if signal and self.verbose:
|
||||
print(f"[{date}] Signal: {signal}")
|
||||
|
||||
return portfolio
|
||||
|
||||
|
||||
# 导出抽象接口
|
||||
__all__ = ['Portfolio', 'Executor', 'BacktestExecutor', 'DryRunExecutor']
|
||||
191
archive/framework/factors/__init__.py
Normal file
191
archive/framework/factors/__init__.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
因子层抽象接口(通用)
|
||||
|
||||
只提供抽象基类和注册机制,具体因子实现在strategies/shared/factors/
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Any, Type
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
|
||||
class FactorBase(ABC):
|
||||
"""
|
||||
因子抽象基类
|
||||
|
||||
所有因子必须实现compute方法
|
||||
"""
|
||||
|
||||
name: str = "base"
|
||||
category: str = "unknown"
|
||||
|
||||
def __init__(self, **params):
|
||||
"""初始化因子参数"""
|
||||
self._params = params
|
||||
|
||||
@abstractmethod
|
||||
def compute(self, data: pd.DataFrame) -> pd.Series:
|
||||
"""
|
||||
计算因子值
|
||||
|
||||
Args:
|
||||
data: OHLCV数据,必须包含'close'列
|
||||
|
||||
Returns:
|
||||
因子值序列
|
||||
"""
|
||||
pass
|
||||
|
||||
def validate_data(self, data: pd.DataFrame) -> bool:
|
||||
"""验证数据是否满足计算要求"""
|
||||
if 'close' not in data.columns:
|
||||
return False
|
||||
|
||||
min_periods = self._params.get('min_periods', 20)
|
||||
return len(data) >= min_periods
|
||||
|
||||
def __repr__(self) -> str:
|
||||
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
|
||||
return f"{self.__class__.__name__}({params_str})"
|
||||
|
||||
|
||||
class FactorRegistry:
|
||||
"""
|
||||
因子注册器(通用)
|
||||
|
||||
管理因子类的注册和获取
|
||||
"""
|
||||
|
||||
_factors: Dict[str, Type[FactorBase]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, factor_class: Type[FactorBase]) -> None:
|
||||
"""注册因子类"""
|
||||
temp_instance = factor_class()
|
||||
name = temp_instance.name
|
||||
|
||||
if name in cls._factors:
|
||||
print(f"因子已注册,覆盖: {name}")
|
||||
|
||||
cls._factors[name] = factor_class
|
||||
|
||||
@classmethod
|
||||
def get(cls, name: str, **params) -> FactorBase:
|
||||
"""获取因子实例"""
|
||||
if name not in cls._factors:
|
||||
raise ValueError(f"因子未注册: {name}")
|
||||
|
||||
factor_class = cls._factors[name]
|
||||
return factor_class(**params)
|
||||
|
||||
@classmethod
|
||||
def list_factors(cls) -> List[str]:
|
||||
"""列出所有已注册因子"""
|
||||
return list(cls._factors.keys())
|
||||
|
||||
@classmethod
|
||||
def clear(cls) -> None:
|
||||
"""清空注册表"""
|
||||
cls._factors = {}
|
||||
|
||||
@classmethod
|
||||
def get_category(cls, name: str) -> str:
|
||||
"""获取因子类别"""
|
||||
if name not in cls._factors:
|
||||
return "unknown"
|
||||
|
||||
temp_instance = cls._factors[name]()
|
||||
return temp_instance.category
|
||||
|
||||
|
||||
class FactorCombiner:
|
||||
"""
|
||||
因子组合器(通用)
|
||||
|
||||
支持多因子加权组合
|
||||
"""
|
||||
|
||||
SUPPORTED_METHODS = ['weighted_sum', 'rank_average', 'zscore_sum', 'equal_weight']
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
factors: List[FactorBase],
|
||||
weights: Optional[List[float]] = None,
|
||||
method: str = 'weighted_sum'
|
||||
):
|
||||
"""
|
||||
初始化组合器
|
||||
|
||||
Args:
|
||||
factors: 因子实例列表
|
||||
weights: 因子权重列表(可选)
|
||||
method: 组合方法(weighted_sum/rank_average/zscore_sum/equal_weight)
|
||||
"""
|
||||
if not factors:
|
||||
raise ValueError("factors list cannot be empty")
|
||||
|
||||
if method not in self.SUPPORTED_METHODS:
|
||||
raise ValueError(f"Unsupported method: {method}")
|
||||
|
||||
self._factors = factors
|
||||
|
||||
if weights is None:
|
||||
self._weights = [1.0 / len(factors)] * len(factors)
|
||||
else:
|
||||
if len(weights) != len(factors):
|
||||
raise ValueError("weights length must match factors length")
|
||||
self._weights = weights
|
||||
|
||||
self._method = method
|
||||
|
||||
def compute(self, data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
计算所有因子并组合
|
||||
|
||||
Returns:
|
||||
DataFrame包含各因子值和combined列
|
||||
"""
|
||||
result = pd.DataFrame(index=data.index)
|
||||
|
||||
# 计算各因子
|
||||
for i, factor in enumerate(self._factors):
|
||||
factor_values = factor.compute(data)
|
||||
col_name = f"{factor.name}"
|
||||
result[col_name] = factor_values
|
||||
|
||||
# 组合因子值
|
||||
if self._method == 'weighted_sum':
|
||||
weighted_cols = [f.name for f in self._factors]
|
||||
result['combined'] = result[weighted_cols].apply(
|
||||
lambda row: sum(row[col] * self._weights[i] for i, col in enumerate(weighted_cols) if pd.notna(row[col])),
|
||||
axis=1
|
||||
)
|
||||
|
||||
elif self._method == 'equal_weight':
|
||||
factor_cols = [f.name for f in self._factors]
|
||||
result['combined'] = result[factor_cols].mean(axis=1)
|
||||
|
||||
elif self._method == 'rank_average':
|
||||
factor_cols = [f.name for f in self._factors]
|
||||
ranks = result[factor_cols].rank(axis=1)
|
||||
result['combined'] = ranks.mean(axis=1)
|
||||
|
||||
elif self._method == 'zscore_sum':
|
||||
factor_cols = [f.name for f in self._factors]
|
||||
zscores = result[factor_cols].apply(lambda x: (x - x.mean()) / x.std())
|
||||
result['combined'] = zscores.sum(axis=1)
|
||||
|
||||
return result
|
||||
|
||||
def get_factor_names(self) -> List[str]:
|
||||
"""获取因子名称列表"""
|
||||
return [f.name for f in self._factors]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
factor_names = [f.name for f in self._factors]
|
||||
return f"FactorCombiner(factors={factor_names}, weights={self._weights}, method={self._method})"
|
||||
|
||||
|
||||
# 导出抽象接口
|
||||
__all__ = ['FactorBase', 'FactorRegistry', 'FactorCombiner']
|
||||
188
archive/framework/risk/__init__.py
Normal file
188
archive/framework/risk/__init__.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
风控层抽象接口(通用)
|
||||
|
||||
只提供抽象基类和回调机制,具体风控组件在strategies/shared/risk/
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Any, Callable, Optional
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class Position:
|
||||
"""
|
||||
持仓数据结构(通用)
|
||||
|
||||
用于表示单个持仓的状态
|
||||
"""
|
||||
code: str
|
||||
entry_price: float
|
||||
current_price: float
|
||||
entry_time: datetime
|
||||
quantity: float = 1.0
|
||||
weight: float = 1.0
|
||||
|
||||
@property
|
||||
def profit_ratio(self) -> float:
|
||||
"""计算盈亏比例"""
|
||||
return (self.current_price - self.entry_price) / self.entry_price
|
||||
|
||||
@property
|
||||
def profit_amount(self) -> float:
|
||||
"""计算盈亏金额"""
|
||||
return (self.current_price - self.entry_price) * self.quantity
|
||||
|
||||
@property
|
||||
def holding_days(self) -> int:
|
||||
"""计算持仓天数"""
|
||||
if self.entry_time is None:
|
||||
return 0
|
||||
return (datetime.now() - self.entry_time).days
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Position(code={self.code}, profit={self.profit_ratio:.2%}, days={self.holding_days})"
|
||||
|
||||
|
||||
class RiskControl(ABC):
|
||||
"""
|
||||
风控组件抽象基类
|
||||
|
||||
所有风控组件必须实现check方法
|
||||
"""
|
||||
|
||||
name: str = "base"
|
||||
|
||||
def __init__(self, **params):
|
||||
"""初始化风控参数"""
|
||||
self._params = params
|
||||
|
||||
@abstractmethod
|
||||
def check(self, position: Position, **kwargs) -> bool:
|
||||
"""
|
||||
检查风控条件
|
||||
|
||||
Args:
|
||||
position: 持仓对象
|
||||
kwargs: 额外参数(如premium、history等)
|
||||
|
||||
Returns:
|
||||
True表示通过检查,False表示触发风控
|
||||
"""
|
||||
pass
|
||||
|
||||
def apply(self, position: Position) -> Any:
|
||||
"""
|
||||
应用风控规则(可选)
|
||||
|
||||
Args:
|
||||
position: 持仓对象
|
||||
|
||||
Returns:
|
||||
风控结果(如止损价格、建议仓位等)
|
||||
"""
|
||||
return None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
|
||||
return f"{self.__class__.__name__}({params_str})"
|
||||
|
||||
|
||||
class CallbackHook:
|
||||
"""
|
||||
回调钩子管理(通用)
|
||||
|
||||
支持在策略生命周期的关键节点注入自定义逻辑
|
||||
"""
|
||||
|
||||
SUPPORTED_HOOKS = [
|
||||
'before_entry', # 入场前检查
|
||||
'after_entry', # 入场后处理
|
||||
'before_exit', # 出场前检查
|
||||
'after_exit', # 出场后处理
|
||||
'dynamic_stoploss', # 动态止损计算
|
||||
'custom_exit' # 自定义出场条件
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
"""初始化回调钩子"""
|
||||
self._hooks: Dict[str, List[Callable]] = {
|
||||
hook: [] for hook in self.SUPPORTED_HOOKS
|
||||
}
|
||||
|
||||
def register(self, hook_name: str, callback: Callable) -> None:
|
||||
"""注册回调函数"""
|
||||
if hook_name not in self._hooks:
|
||||
raise ValueError(f"Unsupported hook: {hook_name}")
|
||||
|
||||
self._hooks[hook_name].append(callback)
|
||||
|
||||
def trigger(self, hook_name: str, *args, **kwargs) -> Any:
|
||||
"""
|
||||
触发回调
|
||||
|
||||
Args:
|
||||
hook_name: 钩子名称
|
||||
args: 位置参数
|
||||
kwargs: 关键字参数
|
||||
|
||||
Returns:
|
||||
回调结果(根据钩子类型返回不同结果)
|
||||
"""
|
||||
if hook_name not in self._hooks:
|
||||
raise ValueError(f"Unsupported hook: {hook_name}")
|
||||
|
||||
callbacks = self._hooks[hook_name]
|
||||
|
||||
if not callbacks:
|
||||
# 默认行为
|
||||
if hook_name == 'dynamic_stoploss':
|
||||
return kwargs.get('default_stoploss', -0.05)
|
||||
elif hook_name in ['before_entry', 'before_exit']:
|
||||
return True
|
||||
elif hook_name == 'custom_exit':
|
||||
return False
|
||||
return None
|
||||
|
||||
results = []
|
||||
for callback in callbacks:
|
||||
result = callback(*args, **kwargs)
|
||||
results.append(result)
|
||||
|
||||
# 根据钩子类型返回不同结果
|
||||
if hook_name in ['before_entry', 'before_exit']:
|
||||
# 所有回调返回True才允许
|
||||
return all(results)
|
||||
|
||||
if hook_name == 'dynamic_stoploss':
|
||||
# 返回最小止损值(最严格)
|
||||
return min(results)
|
||||
|
||||
if hook_name == 'custom_exit':
|
||||
# 任一回调触发出场
|
||||
return any(results)
|
||||
|
||||
# 其他钩子返回最后一个结果
|
||||
return results[-1] if results else None
|
||||
|
||||
def clear(self, hook_name: Optional[str] = None) -> None:
|
||||
"""清空回调"""
|
||||
if hook_name:
|
||||
if hook_name in self._hooks:
|
||||
self._hooks[hook_name] = []
|
||||
else:
|
||||
for hook in self._hooks:
|
||||
self._hooks[hook] = []
|
||||
|
||||
def list_hooks(self) -> List[str]:
|
||||
"""列出支持的钩子"""
|
||||
return self.SUPPORTED_HOOKS
|
||||
|
||||
def get_callbacks(self, hook_name: str) -> List[Callable]:
|
||||
"""获取钩子的所有回调"""
|
||||
return self._hooks.get(hook_name, [])
|
||||
|
||||
|
||||
# 导出抽象接口
|
||||
__all__ = ['Position', 'RiskControl', 'CallbackHook']
|
||||
54
archive/framework/signals/__init__.py
Normal file
54
archive/framework/signals/__init__.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""
|
||||
信号层抽象接口(通用)
|
||||
|
||||
只提供抽象基类,具体信号生成器在strategies/shared/signals/
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Any
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class SignalGenerator(ABC):
|
||||
"""
|
||||
信号生成器抽象基类
|
||||
|
||||
所有信号生成器必须实现generate方法
|
||||
"""
|
||||
|
||||
mode: str = "base"
|
||||
|
||||
def __init__(self, **params):
|
||||
"""初始化信号生成器参数"""
|
||||
self._params = params
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
生成交易信号
|
||||
|
||||
Args:
|
||||
factor_data: 因子数据DataFrame
|
||||
|
||||
Returns:
|
||||
包含'signal'列的DataFrame
|
||||
"""
|
||||
pass
|
||||
|
||||
def validate_factor_data(self, factor_data: pd.DataFrame) -> bool:
|
||||
"""验证因子数据是否有效"""
|
||||
if factor_data.empty:
|
||||
return False
|
||||
|
||||
if 'signal' in factor_data.columns:
|
||||
print("Warning: factor_data already contains 'signal' column")
|
||||
|
||||
return True
|
||||
|
||||
def __repr__(self) -> str:
|
||||
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
|
||||
return f"{self.__class__.__name__}({params_str})"
|
||||
|
||||
|
||||
# 导出抽象接口
|
||||
__all__ = ['SignalGenerator']
|
||||
144
archive/framework/strategy/__init__.py
Normal file
144
archive/framework/strategy/__init__.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
策略层抽象基类(通用)
|
||||
|
||||
只提供抽象接口,具体策略实现在strategies/
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional, Any
|
||||
import pandas as pd
|
||||
|
||||
from framework.factors import FactorCombiner
|
||||
from framework.signals import SignalGenerator
|
||||
from framework.risk import CallbackHook, Position
|
||||
|
||||
|
||||
class StrategyBase(ABC):
|
||||
"""
|
||||
策略抽象基类
|
||||
|
||||
所有策略必须实现init_factors和init_signal_generator方法
|
||||
"""
|
||||
|
||||
INTERFACE_VERSION = 1
|
||||
name: str = "base"
|
||||
timeframe: str = "1d"
|
||||
|
||||
# 类属性(可被配置覆盖)
|
||||
select_num: int = 3
|
||||
stoploss: float = -0.05
|
||||
|
||||
def __init__(self, config: Optional[Dict] = None):
|
||||
"""
|
||||
初始化策略
|
||||
|
||||
Args:
|
||||
config: 配置字典(可选,用于覆盖类属性)
|
||||
"""
|
||||
if config:
|
||||
self._apply_config(config)
|
||||
|
||||
self._callbacks = CallbackHook()
|
||||
self._register_default_callbacks()
|
||||
|
||||
self._factors = self.init_factors()
|
||||
self._signal_gen = self.init_signal_generator()
|
||||
|
||||
def _apply_config(self, config: Dict) -> None:
|
||||
"""应用配置覆盖类属性"""
|
||||
for key, value in config.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
|
||||
def _register_default_callbacks(self) -> None:
|
||||
"""注册默认回调方法"""
|
||||
if hasattr(self, 'before_entry'):
|
||||
self._callbacks.register('before_entry', self.before_entry)
|
||||
|
||||
if hasattr(self, 'after_entry'):
|
||||
self._callbacks.register('after_entry', self.after_entry)
|
||||
|
||||
if hasattr(self, 'before_exit'):
|
||||
self._callbacks.register('before_exit', self.before_exit)
|
||||
|
||||
if hasattr(self, 'after_exit'):
|
||||
self._callbacks.register('after_exit', self.after_exit)
|
||||
|
||||
if hasattr(self, 'dynamic_stoploss'):
|
||||
self._callbacks.register('dynamic_stoploss', self.dynamic_stoploss)
|
||||
|
||||
if hasattr(self, 'custom_exit'):
|
||||
self._callbacks.register('custom_exit', self.custom_exit)
|
||||
|
||||
@abstractmethod
|
||||
def init_factors(self) -> FactorCombiner:
|
||||
"""
|
||||
初始化因子组合器
|
||||
|
||||
Returns:
|
||||
FactorCombiner实例
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def init_signal_generator(self) -> SignalGenerator:
|
||||
"""
|
||||
初始化信号生成器
|
||||
|
||||
Returns:
|
||||
SignalGenerator实例
|
||||
"""
|
||||
pass
|
||||
|
||||
def run(self, data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
运行策略
|
||||
|
||||
Args:
|
||||
data: OHLCV数据
|
||||
|
||||
Returns:
|
||||
包含信号的DataFrame
|
||||
"""
|
||||
factor_data = self._factors.compute(data)
|
||||
signals = self._signal_gen.generate(factor_data)
|
||||
|
||||
signals = self._apply_callbacks(signals, data)
|
||||
|
||||
return signals
|
||||
|
||||
def _apply_callbacks(self, signals: pd.DataFrame, data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""应用回调处理"""
|
||||
return signals
|
||||
|
||||
# 可选回调方法(子类可覆盖)
|
||||
def before_entry(self, code: str, price: float, **kwargs) -> bool:
|
||||
"""入场前检查"""
|
||||
return True
|
||||
|
||||
def after_entry(self, code: str, price: float, **kwargs) -> None:
|
||||
"""入场后处理"""
|
||||
pass
|
||||
|
||||
def before_exit(self, position: Position, **kwargs) -> bool:
|
||||
"""出场前检查"""
|
||||
return True
|
||||
|
||||
def after_exit(self, position: Position, **kwargs) -> None:
|
||||
"""出场后处理"""
|
||||
pass
|
||||
|
||||
def dynamic_stoploss(self, position: Position) -> float:
|
||||
"""动态止损"""
|
||||
return self.stoploss
|
||||
|
||||
def custom_exit(self, position: Position) -> bool:
|
||||
"""自定义出场条件"""
|
||||
return False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(name={self.name})"
|
||||
|
||||
|
||||
# 导出抽象接口
|
||||
__all__ = ['StrategyBase']
|
||||
174
archive/framework/tests/test_execution.py
Normal file
174
archive/framework/tests/test_execution.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
执行层测试
|
||||
|
||||
测试Portfolio、Executor抽象接口
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
|
||||
from framework.execution import Portfolio, Executor, BacktestExecutor, DryRunExecutor
|
||||
from framework.risk import Position
|
||||
|
||||
|
||||
class TestPortfolio:
|
||||
"""测试Portfolio"""
|
||||
|
||||
def test_portfolio_init(self):
|
||||
"""测试初始化"""
|
||||
portfolio = Portfolio(initial_capital=100000)
|
||||
|
||||
assert portfolio.initial_capital == 100000
|
||||
assert portfolio.cash == 100000
|
||||
assert len(portfolio.positions) == 0
|
||||
|
||||
def test_add_position(self):
|
||||
"""测试添加持仓"""
|
||||
portfolio = Portfolio(initial_capital=100000)
|
||||
|
||||
portfolio.add_position(
|
||||
code='AAPL',
|
||||
price=100.0,
|
||||
quantity=10,
|
||||
time=datetime.now()
|
||||
)
|
||||
|
||||
assert len(portfolio.positions) == 1
|
||||
assert 'AAPL' in portfolio.positions
|
||||
assert portfolio.cash == 99000 # 100000 - 100*10
|
||||
|
||||
def test_remove_position(self):
|
||||
"""测试移除持仓"""
|
||||
portfolio = Portfolio(initial_capital=100000)
|
||||
|
||||
portfolio.add_position('AAPL', 100.0, 10, datetime.now())
|
||||
profit = portfolio.remove_position('AAPL', 110.0, datetime.now())
|
||||
|
||||
assert len(portfolio.positions) == 0
|
||||
assert profit == 100.0 # (110-100)*10
|
||||
assert portfolio.cash == 100100 # 99000 + 110*10
|
||||
|
||||
def test_update_prices(self):
|
||||
"""测试更新价格"""
|
||||
portfolio = Portfolio(initial_capital=100000)
|
||||
|
||||
portfolio.add_position('AAPL', 100.0, 10, datetime.now())
|
||||
portfolio.add_position('MSFT', 200.0, 5, datetime.now())
|
||||
|
||||
portfolio.update_prices({'AAPL': 110.0, 'MSFT': 220.0})
|
||||
|
||||
assert portfolio.positions['AAPL'].current_price == 110.0
|
||||
assert portfolio.positions['MSFT'].current_price == 220.0
|
||||
|
||||
def test_get_net_value(self):
|
||||
"""测试净值计算"""
|
||||
portfolio = Portfolio(initial_capital=100000)
|
||||
|
||||
portfolio.add_position('AAPL', 100.0, 10, datetime.now())
|
||||
portfolio.update_prices({'AAPL': 110.0})
|
||||
|
||||
net_value = portfolio.get_net_value()
|
||||
expected = 99000 + 110 * 10 # cash + position_value
|
||||
assert net_value == expected
|
||||
|
||||
def test_get_weight(self):
|
||||
"""测试权重计算"""
|
||||
portfolio = Portfolio(initial_capital=100000)
|
||||
|
||||
portfolio.add_position('AAPL', 100.0, 100, datetime.now())
|
||||
portfolio.update_prices({'AAPL': 110.0})
|
||||
|
||||
weight = portfolio.get_weight('AAPL')
|
||||
# 持仓价值=110*100=11000,净值=90000+11000=101000
|
||||
expected_weight = 11000 / 101000
|
||||
assert abs(weight - expected_weight) < 0.01
|
||||
|
||||
def test_record_net_value(self):
|
||||
"""测试净值记录"""
|
||||
portfolio = Portfolio(initial_capital=100000)
|
||||
|
||||
portfolio.record_net_value()
|
||||
portfolio.add_position('AAPL', 100.0, 10, datetime.now())
|
||||
portfolio.record_net_value()
|
||||
|
||||
series = portfolio.get_net_value_series()
|
||||
assert len(series) == 2
|
||||
|
||||
def test_portfolio_repr(self):
|
||||
"""测试字符串表示"""
|
||||
portfolio = Portfolio(initial_capital=100000)
|
||||
portfolio.add_position('AAPL', 100.0, 10, datetime.now())
|
||||
|
||||
repr_str = repr(portfolio)
|
||||
assert 'Portfolio' in repr_str
|
||||
assert 'positions=1' in repr_str
|
||||
|
||||
|
||||
class TestBacktestExecutor:
|
||||
"""测试回测执行器"""
|
||||
|
||||
def test_backtest_executor_init(self):
|
||||
"""测试初始化"""
|
||||
executor = BacktestExecutor(initial_capital=100000, trade_cost=0.001)
|
||||
|
||||
assert executor.initial_capital == 100000
|
||||
assert executor.trade_cost == 0.001
|
||||
assert executor.mode == "backtest"
|
||||
|
||||
def test_backtest_execute(self):
|
||||
"""测试回测执行"""
|
||||
executor = BacktestExecutor(initial_capital=100000)
|
||||
|
||||
dates = pd.date_range('2020-01-01', periods=50)
|
||||
signals = pd.DataFrame({
|
||||
'signal': ['AAPL'] * 50
|
||||
}, index=dates)
|
||||
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(50).cumsum() + 100
|
||||
}, index=dates)
|
||||
|
||||
portfolio = executor.execute(signals, data)
|
||||
|
||||
assert isinstance(portfolio, Portfolio)
|
||||
|
||||
def test_backtest_executor_repr(self):
|
||||
"""测试字符串表示"""
|
||||
executor = BacktestExecutor(initial_capital=100000, trade_cost=0.001)
|
||||
repr_str = repr(executor)
|
||||
|
||||
assert 'BacktestExecutor' in repr_str
|
||||
|
||||
|
||||
class TestDryRunExecutor:
|
||||
"""测试DryRun执行器"""
|
||||
|
||||
def test_dry_run_executor_init(self):
|
||||
"""测试初始化"""
|
||||
executor = DryRunExecutor(verbose=True)
|
||||
|
||||
assert executor.verbose == True
|
||||
assert executor.mode == "dry_run"
|
||||
|
||||
def test_dry_run_execute(self):
|
||||
"""测试模拟执行"""
|
||||
executor = DryRunExecutor(verbose=False)
|
||||
|
||||
dates = pd.date_range('2020-01-01', periods=50)
|
||||
signals = pd.DataFrame({
|
||||
'signal': ['AAPL,MSFT'] * 50
|
||||
}, index=dates)
|
||||
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(50).cumsum() + 100
|
||||
}, index=dates)
|
||||
|
||||
portfolio = executor.execute(signals, data)
|
||||
|
||||
assert isinstance(portfolio, Portfolio)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
288
archive/framework/tests/test_factors.py
Normal file
288
archive/framework/tests/test_factors.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
因子层测试
|
||||
|
||||
测试FactorBase、FactorRegistry、FactorCombiner抽象接口
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
|
||||
from strategies.shared.factors.momentum import MomentumFactor, TrendFactor, ReversalFactor, VolatilityFactor
|
||||
|
||||
|
||||
class TestFactorBase:
|
||||
"""测试FactorBase抽象基类"""
|
||||
|
||||
def test_factor_meta(self):
|
||||
"""测试因子元信息"""
|
||||
factor = MomentumFactor(n_days=25)
|
||||
assert factor.name == "momentum"
|
||||
assert factor.category == "momentum"
|
||||
|
||||
def test_factor_repr(self):
|
||||
"""测试因子字符串表示"""
|
||||
factor = MomentumFactor(n_days=30)
|
||||
repr_str = repr(factor)
|
||||
assert "MomentumFactor" in repr_str
|
||||
|
||||
def test_validate_data(self):
|
||||
"""测试数据验证"""
|
||||
factor = MomentumFactor(n_days=25)
|
||||
|
||||
# 数据充足
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100
|
||||
})
|
||||
assert factor.validate_data(data) == True
|
||||
|
||||
# 数据不足
|
||||
short_data = pd.DataFrame({
|
||||
'close': np.random.randn(10).cumsum() + 100
|
||||
})
|
||||
assert factor.validate_data(short_data) == False
|
||||
|
||||
|
||||
class TestFactorRegistry:
|
||||
"""测试因子注册器"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试前清空注册表"""
|
||||
FactorRegistry.clear()
|
||||
|
||||
def test_register_factor(self):
|
||||
"""测试因子注册"""
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
assert 'momentum' in FactorRegistry.list_factors()
|
||||
|
||||
def test_get_factor(self):
|
||||
"""测试获取因子实例"""
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
factor = FactorRegistry.get('momentum', n_days=30)
|
||||
assert isinstance(factor, MomentumFactor)
|
||||
assert factor.n_days == 30
|
||||
|
||||
def test_get_unknown_factor(self):
|
||||
"""测试获取未注册因子"""
|
||||
with pytest.raises(ValueError):
|
||||
FactorRegistry.get('unknown_factor')
|
||||
|
||||
def test_get_category(self):
|
||||
"""测试获取因子类别"""
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
category = FactorRegistry.get_category('momentum')
|
||||
assert category == 'momentum'
|
||||
|
||||
|
||||
class TestFactorCombiner:
|
||||
"""测试因子组合器"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试前清空注册表"""
|
||||
FactorRegistry.clear()
|
||||
|
||||
def test_combiner_init(self):
|
||||
"""测试组合器初始化"""
|
||||
factors = [
|
||||
MomentumFactor(n_days=25),
|
||||
TrendFactor(method='ma_cross')
|
||||
]
|
||||
combiner = FactorCombiner(factors, weights=[0.7, 0.3])
|
||||
|
||||
assert len(combiner.get_factor_names()) == 2
|
||||
|
||||
def test_combiner_equal_weights(self):
|
||||
"""测试等权组合"""
|
||||
factors = [
|
||||
MomentumFactor(n_days=25),
|
||||
TrendFactor()
|
||||
]
|
||||
combiner = FactorCombiner(factors) # 默认等权
|
||||
|
||||
# 权重应该归一化
|
||||
assert sum(combiner._weights) == 1.0
|
||||
|
||||
def test_combiner_compute(self):
|
||||
"""测试因子组合计算"""
|
||||
# 生成测试数据
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100,
|
||||
'high': np.random.randn(100).cumsum() + 105,
|
||||
'low': np.random.randn(100).cumsum() + 95
|
||||
}, index=dates)
|
||||
|
||||
factors = [
|
||||
MomentumFactor(n_days=20),
|
||||
TrendFactor(fast=5, slow=10)
|
||||
]
|
||||
combiner = FactorCombiner(factors, weights=[0.6, 0.4])
|
||||
|
||||
result = combiner.compute(data)
|
||||
|
||||
# 检查结果列
|
||||
assert 'momentum' in result.columns
|
||||
assert 'trend' in result.columns
|
||||
assert 'combined' in result.columns
|
||||
|
||||
def test_combiner_method_rank_average(self):
|
||||
"""测试rank_average组合方法"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100
|
||||
}, index=dates)
|
||||
|
||||
factors = [
|
||||
MomentumFactor(n_days=20),
|
||||
TrendFactor()
|
||||
]
|
||||
combiner = FactorCombiner(factors, method='rank_average')
|
||||
|
||||
result = combiner.compute(data)
|
||||
|
||||
# combined应该是排名平均值
|
||||
assert 'combined' in result.columns
|
||||
|
||||
|
||||
class TestMomentumFactor:
|
||||
"""测试动量因子"""
|
||||
|
||||
def test_momentum_compute(self):
|
||||
"""测试动量因子计算"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
|
||||
# 生成上升趋势数据
|
||||
prices = 100 + np.arange(100) * 0.5
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = MomentumFactor(n_days=25, weighted=True)
|
||||
values = factor.compute(data)
|
||||
|
||||
# 上升趋势应该有正的动量得分
|
||||
assert values.iloc[-1] > 0
|
||||
|
||||
def test_crash_filter(self):
|
||||
"""测试崩盘过滤"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
|
||||
# 生成正常数据,然后在末尾添加崩盘
|
||||
prices = 100 + np.random.randn(100).cumsum()
|
||||
prices[-3:] = prices[-4] * np.array([0.96, 0.93, 0.90]) # 连续大跌
|
||||
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = MomentumFactor(n_days=25, crash_filter=True)
|
||||
values = factor.compute(data)
|
||||
|
||||
# 崩盘后动量得分应该被清零
|
||||
assert values.iloc[-1] == 0.0
|
||||
|
||||
def test_simple_momentum(self):
|
||||
"""测试简单动量(无加权,无崩盘过滤)"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
prices = 100 + np.random.randn(100).cumsum()
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = MomentumFactor(n_days=25, weighted=False, crash_filter=False)
|
||||
values = factor.compute(data)
|
||||
|
||||
# 简单动量应该是N日涨幅
|
||||
expected = data['close'].pct_change(25)
|
||||
# 验证长度一致
|
||||
assert len(values) == len(expected)
|
||||
|
||||
|
||||
class TestTrendFactor:
|
||||
"""测试趋势因子"""
|
||||
|
||||
def test_ma_cross(self):
|
||||
"""测试MA交叉趋势"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
|
||||
# 生成上升趋势
|
||||
prices = 100 + np.arange(100) * 0.5
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = TrendFactor(method='ma_cross', fast=5, slow=20)
|
||||
values = factor.compute(data)
|
||||
|
||||
# 上升趋势应该有正的趋势强度
|
||||
assert values.iloc[-1] > 0
|
||||
|
||||
def test_macd(self):
|
||||
"""测试MACD趋势"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
prices = 100 + np.random.randn(100).cumsum()
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = TrendFactor(method='macd')
|
||||
values = factor.compute(data)
|
||||
|
||||
# 检查计算结果
|
||||
assert len(values) == len(data)
|
||||
|
||||
|
||||
class TestReversalFactor:
|
||||
"""测试反转因子"""
|
||||
|
||||
def test_rsi_reversal(self):
|
||||
"""测试RSI反转信号"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
|
||||
# 生成超买数据(持续上涨)
|
||||
prices = 100 + np.arange(100) * 1.0
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = ReversalFactor(method='rsi', period=14, overbought=70)
|
||||
values = factor.compute(data)
|
||||
|
||||
# RSI超过70应该产生负值(反转向下信号)
|
||||
assert values.iloc[-1] < 0
|
||||
|
||||
def test_rsi_oversold(self):
|
||||
"""测试RSI超卖信号"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
|
||||
# 生成超卖数据(持续下跌)
|
||||
prices = 100 - np.arange(100) * 1.0
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = ReversalFactor(method='rsi', period=14, oversold=30)
|
||||
values = factor.compute(data)
|
||||
|
||||
# RSI低于30应该产生正值(反转向上信号)
|
||||
assert values.iloc[-1] > 0
|
||||
|
||||
|
||||
class TestVolatilityFactor:
|
||||
"""测试波动率因子"""
|
||||
|
||||
def test_std_volatility(self):
|
||||
"""测试标准差波动率"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
prices = 100 + np.random.randn(100).cumsum()
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = VolatilityFactor(method='std', period=20)
|
||||
values = factor.compute(data)
|
||||
|
||||
assert len(values) == len(data)
|
||||
|
||||
def test_atr_volatility(self):
|
||||
"""测试ATR波动率"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100,
|
||||
'high': np.random.randn(100).cumsum() + 105,
|
||||
'low': np.random.randn(100).cumsum() + 95
|
||||
}, index=dates)
|
||||
|
||||
factor = VolatilityFactor(method='atr', period=20)
|
||||
values = factor.compute(data)
|
||||
|
||||
assert len(values) == len(data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
166
archive/framework/tests/test_integration.py
Normal file
166
archive/framework/tests/test_integration.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
集成测试
|
||||
|
||||
测试框架与定制组件的集成
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from framework.factors import FactorRegistry, FactorCombiner
|
||||
from framework.signals import SignalGenerator
|
||||
from framework.risk import CallbackHook, Position
|
||||
from framework.strategy import StrategyBase
|
||||
from framework.execution import Portfolio, BacktestExecutor
|
||||
|
||||
from strategies.shared.factors.momentum import MomentumFactor, VolatilityFactor
|
||||
from strategies.shared.signals.selectors import TopNSelector
|
||||
from strategies.shared.risk.controls import StopLossControl, premium_filter_callback
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
|
||||
class TestFactorIntegration:
|
||||
"""测试因子集成"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试前清空注册表"""
|
||||
FactorRegistry.clear()
|
||||
|
||||
def test_register_and_use_custom_factor(self):
|
||||
"""测试注册并使用定制因子"""
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
|
||||
factor = FactorRegistry.get('momentum', n_days=25, crash_filter=True)
|
||||
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100
|
||||
}, index=dates)
|
||||
|
||||
values = factor.compute(data)
|
||||
|
||||
assert len(values) == len(data)
|
||||
|
||||
def test_combiner_with_custom_factors(self):
|
||||
"""测试组合器使用定制因子"""
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
FactorRegistry.register(VolatilityFactor)
|
||||
|
||||
momentum = FactorRegistry.get('momentum', n_days=25)
|
||||
volatility = FactorRegistry.get('volatility', method='std', period=20)
|
||||
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100,
|
||||
'high': np.random.randn(100).cumsum() + 105,
|
||||
'low': np.random.randn(100).cumsum() + 95
|
||||
}, index=dates)
|
||||
|
||||
combiner = FactorCombiner([momentum, volatility], weights=[0.7, 0.3])
|
||||
result = combiner.compute(data)
|
||||
|
||||
assert 'momentum' in result.columns
|
||||
assert 'volatility' in result.columns
|
||||
assert 'combined' in result.columns
|
||||
|
||||
|
||||
class TestSignalIntegration:
|
||||
"""测试信号生成器集成"""
|
||||
|
||||
def test_custom_signal_generator_with_factors(self):
|
||||
"""测试定制信号生成器与因子集成"""
|
||||
dates = pd.date_range('2020-01-01', periods=50)
|
||||
|
||||
factor_data = pd.DataFrame({
|
||||
'momentum_A': np.random.randn(50),
|
||||
'momentum_B': np.random.randn(50),
|
||||
'momentum_C': np.random.randn(50),
|
||||
}, index=dates)
|
||||
|
||||
selector = TopNSelector(select_num=2, min_score=0.0)
|
||||
result = selector.generate(factor_data)
|
||||
|
||||
assert 'signal' in result.columns
|
||||
|
||||
|
||||
class TestRiskIntegration:
|
||||
"""测试风控组件集成"""
|
||||
|
||||
def test_custom_risk_control(self):
|
||||
"""测试定制风控组件"""
|
||||
control = StopLossControl(threshold=-0.05, trailing=True)
|
||||
|
||||
position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=105.0,
|
||||
entry_time=pd.Timestamp.now()
|
||||
)
|
||||
|
||||
# 首次检查,设置最高价
|
||||
assert control.check(position) == True
|
||||
|
||||
# 价格下跌,触发跟踪止损
|
||||
position.current_price = 100.0
|
||||
assert control.check(position) == False
|
||||
|
||||
def test_callback_hook_with_custom_callback(self):
|
||||
"""测试回调钩子与定制回调集成"""
|
||||
hook = CallbackHook()
|
||||
|
||||
# 注册定制回调
|
||||
callback = premium_filter_callback(threshold=0.10)
|
||||
hook.register('before_entry', callback)
|
||||
|
||||
# 正常溢价通过
|
||||
result = hook.trigger('before_entry', 'AAPL', 100.0, premium=0.05)
|
||||
assert result == True
|
||||
|
||||
# 高溢价拒绝
|
||||
result = hook.trigger('before_entry', 'AAPL', 100.0, premium=0.15)
|
||||
assert result == False
|
||||
|
||||
|
||||
class TestStrategyIntegration:
|
||||
"""测试策略集成"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试前清空注册表"""
|
||||
FactorRegistry.clear()
|
||||
|
||||
def test_rotation_strategy_full_flow(self):
|
||||
"""测试轮动策略完整流程"""
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 生成测试数据
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100
|
||||
}, index=dates)
|
||||
|
||||
# 运行策略
|
||||
result = strategy.run(data)
|
||||
|
||||
assert 'signal' in result.columns
|
||||
|
||||
def test_strategy_with_backtest_executor(self):
|
||||
"""测试策略与回测执行器集成"""
|
||||
FactorRegistry.clear()
|
||||
strategy = RotationStrategy()
|
||||
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100
|
||||
}, index=dates)
|
||||
|
||||
signals = strategy.run(data)
|
||||
|
||||
executor = BacktestExecutor(initial_capital=100000, trade_cost=0.001)
|
||||
portfolio = executor.execute(signals, data)
|
||||
|
||||
assert isinstance(portfolio, Portfolio)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
325
archive/framework/tests/test_risk.py
Normal file
325
archive/framework/tests/test_risk.py
Normal file
@@ -0,0 +1,325 @@
|
||||
"""
|
||||
风控层测试
|
||||
|
||||
测试RiskControl、CallbackHook抽象接口
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
|
||||
from framework.risk import RiskControl, CallbackHook, Position
|
||||
from strategies.shared.risk.controls import StopLossControl, PositionLimitControl, PremiumControl
|
||||
|
||||
|
||||
class TestPosition:
|
||||
"""测试Position数据结构"""
|
||||
|
||||
def test_position_creation(self):
|
||||
"""测试持仓创建"""
|
||||
position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=105.0,
|
||||
entry_time=datetime.now(),
|
||||
quantity=10
|
||||
)
|
||||
|
||||
assert position.code == 'AAPL'
|
||||
assert position.entry_price == 100.0
|
||||
assert position.current_price == 105.0
|
||||
|
||||
def test_profit_ratio(self):
|
||||
"""测试盈亏比例计算"""
|
||||
position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=110.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
|
||||
assert position.profit_ratio == 0.10
|
||||
|
||||
def test_profit_amount(self):
|
||||
"""测试盈亏金额计算"""
|
||||
position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=110.0,
|
||||
entry_time=datetime.now(),
|
||||
quantity=10
|
||||
)
|
||||
|
||||
assert position.profit_amount == 100.0
|
||||
|
||||
def test_position_repr(self):
|
||||
"""测试持仓字符串表示"""
|
||||
position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=105.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
|
||||
repr_str = repr(position)
|
||||
assert 'AAPL' in repr_str
|
||||
assert '5.00%' in repr_str
|
||||
|
||||
|
||||
class TestStopLossControl:
|
||||
"""测试止损控制"""
|
||||
|
||||
def test_stop_loss_init(self):
|
||||
"""测试初始化"""
|
||||
control = StopLossControl(threshold=-0.05)
|
||||
assert control.threshold == -0.05
|
||||
assert control.name == "stop_loss"
|
||||
|
||||
def test_stop_loss_check(self):
|
||||
"""测试止损检查"""
|
||||
control = StopLossControl(threshold=-0.05)
|
||||
|
||||
# 盈利持仓应该通过
|
||||
profit_position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=110.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
assert control.check(profit_position) == True
|
||||
|
||||
# 亏损持仓应该触发止损
|
||||
loss_position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=90.0, # 亏损10%
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
# 亏损超过阈值应该触发
|
||||
assert control.check(loss_position) == False
|
||||
|
||||
def test_trailing_stop_loss(self):
|
||||
"""测试跟踪止损"""
|
||||
control = StopLossControl(threshold=-0.05, trailing=True, trailing_percent=0.03)
|
||||
|
||||
position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=105.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
|
||||
# 初始最高价=入场价100,当前价105
|
||||
# 更新后最高价=105
|
||||
control.check(position)
|
||||
|
||||
# 从最高价回撤不超过阈值应该通过
|
||||
position.current_price = 103.0 # 回撤约2%
|
||||
assert control.check(position) == True
|
||||
|
||||
# 回撤超过阈值应该触发
|
||||
position.current_price = 100.0 # 回撤约5%
|
||||
assert control.check(position) == False
|
||||
|
||||
|
||||
class TestPositionLimitControl:
|
||||
"""测试仓位限制控制"""
|
||||
|
||||
def test_position_limit_init(self):
|
||||
"""测试初始化"""
|
||||
control = PositionLimitControl(max_position=0.33)
|
||||
assert control.max_position == 0.33
|
||||
assert control.name == "position_limit"
|
||||
|
||||
def test_position_limit_check(self):
|
||||
"""测试仓位限制检查"""
|
||||
control = PositionLimitControl(max_position=0.33)
|
||||
|
||||
# 正常仓位应该通过
|
||||
normal_position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=105.0,
|
||||
entry_time=datetime.now(),
|
||||
weight=0.20
|
||||
)
|
||||
assert control.check(normal_position) == True
|
||||
|
||||
# 超限仓位应该触发
|
||||
over_position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=105.0,
|
||||
entry_time=datetime.now(),
|
||||
weight=0.50
|
||||
)
|
||||
assert control.check(over_position) == False
|
||||
|
||||
|
||||
class TestPremiumControl:
|
||||
"""测试溢价控制"""
|
||||
|
||||
def test_premium_control_init(self):
|
||||
"""测试初始化"""
|
||||
control = PremiumControl(threshold=0.10)
|
||||
assert control.threshold == 0.10
|
||||
assert control.name == "premium"
|
||||
|
||||
def test_premium_filter_mode(self):
|
||||
"""测试溢价过滤模式"""
|
||||
control = PremiumControl(threshold=0.10, mode='filter')
|
||||
|
||||
position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=105.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
|
||||
# 正常溢价应该通过
|
||||
assert control.check(position, premium=0.05) == True
|
||||
|
||||
# 高溢价应该被过滤
|
||||
assert control.check(position, premium=0.15) == False
|
||||
|
||||
|
||||
class TestCallbackHook:
|
||||
"""测试回调钩子"""
|
||||
|
||||
def test_callback_hook_init(self):
|
||||
"""测试初始化"""
|
||||
hook = CallbackHook()
|
||||
assert len(hook.list_hooks()) == 6
|
||||
|
||||
def test_register_callback(self):
|
||||
"""测试注册回调"""
|
||||
hook = CallbackHook()
|
||||
|
||||
def my_callback(code, price, **kwargs):
|
||||
return True
|
||||
|
||||
hook.register('before_entry', my_callback)
|
||||
|
||||
callbacks = hook.get_callbacks('before_entry')
|
||||
assert len(callbacks) == 1
|
||||
|
||||
def test_trigger_before_entry(self):
|
||||
"""测试触发入场前回调"""
|
||||
hook = CallbackHook()
|
||||
|
||||
def always_pass(code, price, **kwargs):
|
||||
return True
|
||||
|
||||
def always_block(code, price, **kwargs):
|
||||
return False
|
||||
|
||||
hook.register('before_entry', always_pass)
|
||||
hook.register('before_entry', always_block)
|
||||
|
||||
# before_entry要求所有回调返回True才允许
|
||||
result = hook.trigger('before_entry', 'AAPL', 100.0)
|
||||
assert result == False
|
||||
|
||||
def test_trigger_dynamic_stoploss(self):
|
||||
"""测试触发动态止损回调"""
|
||||
hook = CallbackHook()
|
||||
|
||||
def stoploss_5p(position):
|
||||
return -0.05
|
||||
|
||||
def stoploss_3p(position):
|
||||
return -0.03
|
||||
|
||||
hook.register('dynamic_stoploss', stoploss_5p)
|
||||
hook.register('dynamic_stoploss', stoploss_3p)
|
||||
|
||||
position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=105.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
|
||||
# dynamic_stoploss返回最小止损值(最严格)
|
||||
result = hook.trigger('dynamic_stoploss', position)
|
||||
assert result == -0.05
|
||||
|
||||
def test_trigger_custom_exit(self):
|
||||
"""测试触发自定义出场回调"""
|
||||
hook = CallbackHook()
|
||||
|
||||
def exit_on_loss(position):
|
||||
return position.profit_ratio < -0.05
|
||||
|
||||
def exit_on_profit(position):
|
||||
return position.profit_ratio > 0.20
|
||||
|
||||
hook.register('custom_exit', exit_on_loss)
|
||||
hook.register('custom_exit', exit_on_profit)
|
||||
|
||||
# custom_exit任一回调触发即可
|
||||
profit_position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=125.0, # 盈利25%
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
result = hook.trigger('custom_exit', profit_position)
|
||||
assert result == True
|
||||
|
||||
# 未触发
|
||||
normal_position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=110.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
result = hook.trigger('custom_exit', normal_position)
|
||||
assert result == False
|
||||
|
||||
def test_clear_hooks(self):
|
||||
"""测试清空回调"""
|
||||
hook = CallbackHook()
|
||||
|
||||
def callback(code, price, **kwargs):
|
||||
return True
|
||||
|
||||
hook.register('before_entry', callback)
|
||||
|
||||
# 清空特定钩子
|
||||
hook.clear('before_entry')
|
||||
assert len(hook.get_callbacks('before_entry')) == 0
|
||||
|
||||
# 清空所有钩子
|
||||
hook.register('before_entry', callback)
|
||||
hook.register('after_entry', callback)
|
||||
hook.clear()
|
||||
assert len(hook.get_callbacks('before_entry')) == 0
|
||||
assert len(hook.get_callbacks('after_entry')) == 0
|
||||
|
||||
def test_default_behavior(self):
|
||||
"""测试默认行为"""
|
||||
hook = CallbackHook() # 无注册回调
|
||||
|
||||
# before_entry默认允许
|
||||
result = hook.trigger('before_entry', 'AAPL', 100.0)
|
||||
assert result == True
|
||||
|
||||
# dynamic_stoploss默认值
|
||||
result = hook.trigger('dynamic_stoploss', None, default_stoploss=-0.05)
|
||||
assert result == -0.05
|
||||
|
||||
# custom_exit默认不出场
|
||||
position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=105.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
result = hook.trigger('custom_exit', position)
|
||||
assert result == False
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
163
archive/framework/tests/test_signals.py
Normal file
163
archive/framework/tests/test_signals.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
信号层测试
|
||||
|
||||
测试SignalGenerator抽象接口
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from framework.signals import SignalGenerator
|
||||
from strategies.shared.signals.selectors import TopNSelector, TrendFollower, ReversalTrader
|
||||
|
||||
|
||||
class TestTopNSelector:
|
||||
"""测试TopNSelector"""
|
||||
|
||||
def test_top_n_selector_init(self):
|
||||
"""测试初始化"""
|
||||
selector = TopNSelector(select_num=3)
|
||||
assert selector.select_num == 3
|
||||
assert selector.mode == "top_n"
|
||||
|
||||
def test_top_n_selection(self):
|
||||
"""测试Top N选股"""
|
||||
dates = pd.date_range('2020-01-01', periods=50)
|
||||
|
||||
# 生成因子数据
|
||||
data = pd.DataFrame({
|
||||
'factor_A': np.random.randn(50),
|
||||
'factor_B': np.random.randn(50),
|
||||
'factor_C': np.random.randn(50),
|
||||
}, index=dates)
|
||||
|
||||
selector = TopNSelector(select_num=2)
|
||||
result = selector.generate(data)
|
||||
|
||||
# 检查结果列
|
||||
assert 'signal' in result.columns
|
||||
assert 'signal_raw' in result.columns
|
||||
|
||||
# 检查T+1移位(signal比signal_raw滞后1天)
|
||||
assert result['signal'].iloc[0] == '' or pd.isna(result['signal'].iloc[0])
|
||||
|
||||
def test_min_score_filter(self):
|
||||
"""测试最小得分过滤"""
|
||||
dates = pd.date_range('2020-01-01', periods=50)
|
||||
|
||||
# 生成因子数据,部分为负值
|
||||
data = pd.DataFrame({
|
||||
'factor_A': [0.1] * 50,
|
||||
'factor_B': [-0.1] * 50, # 负分
|
||||
'factor_C': [0.2] * 50,
|
||||
}, index=dates)
|
||||
|
||||
selector = TopNSelector(select_num=2, min_score=0.0)
|
||||
result = selector.generate(data)
|
||||
|
||||
# 负分因子应该被过滤
|
||||
signals = result['signal_raw'].dropna().unique()
|
||||
for sig in signals:
|
||||
if sig:
|
||||
codes = sig.split(',')
|
||||
assert 'factor_B' not in codes
|
||||
|
||||
def test_grouped_selection(self):
|
||||
"""测试分组选股"""
|
||||
dates = pd.date_range('2020-01-01', periods=50)
|
||||
|
||||
# 生成因子数据和分组信息
|
||||
data = pd.DataFrame({
|
||||
'factor_A': [0.1] * 50,
|
||||
'factor_B': [0.2] * 50,
|
||||
'factor_C': [0.15] * 50,
|
||||
}, index=dates)
|
||||
|
||||
# 分组信息(模拟)
|
||||
# 注:实际使用需要在数据中包含group_info列
|
||||
|
||||
selector = TopNSelector(select_num=2, group_by='market')
|
||||
result = selector.generate(data)
|
||||
|
||||
assert 'signal' in result.columns
|
||||
|
||||
|
||||
class TestTrendFollower:
|
||||
"""测试趋势跟随器"""
|
||||
|
||||
def test_trend_follower_init(self):
|
||||
"""测试初始化"""
|
||||
follower = TrendFollower(entry_threshold=0.02, exit_threshold=-0.02)
|
||||
assert follower.entry_threshold == 0.02
|
||||
assert follower.exit_threshold == -0.02
|
||||
assert follower.mode == "trend"
|
||||
|
||||
def test_trend_signal_generation(self):
|
||||
"""测试趋势信号生成"""
|
||||
dates = pd.date_range('2020-01-01', periods=50)
|
||||
|
||||
# 生成趋势因子数据
|
||||
data = pd.DataFrame({
|
||||
'trend_A': [0.05] * 50, # 强趋势
|
||||
'trend_B': [-0.05] * 50, # 弱趋势
|
||||
}, index=dates)
|
||||
|
||||
follower = TrendFollower(entry_threshold=0.02)
|
||||
result = follower.generate(data)
|
||||
|
||||
# 检查信号列
|
||||
assert 'signal' in result.columns
|
||||
|
||||
|
||||
class TestReversalTrader:
|
||||
"""测试反转交易器"""
|
||||
|
||||
def test_reversal_trader_init(self):
|
||||
"""测试初始化"""
|
||||
trader = ReversalTrader(overbought=70, oversold=30)
|
||||
assert trader.overbought == 70
|
||||
assert trader.oversold == 30
|
||||
assert trader.mode == "reversal"
|
||||
|
||||
def test_reversal_signal_generation(self):
|
||||
"""测试反转信号生成"""
|
||||
dates = pd.date_range('2020-01-01', periods=50)
|
||||
|
||||
# 生成反转因子数据
|
||||
data = pd.DataFrame({
|
||||
'reversal_A': [0.15] * 50, # 超卖反转
|
||||
'reversal_B': [-0.15] * 50, # 超买反转
|
||||
}, index=dates)
|
||||
|
||||
trader = ReversalTrader(reversal_threshold=0.1)
|
||||
result = trader.generate(data)
|
||||
|
||||
# 检查信号列
|
||||
assert 'signal' in result.columns
|
||||
|
||||
|
||||
class TestSignalGeneratorBase:
|
||||
"""测试SignalGenerator抽象基类"""
|
||||
|
||||
def test_validate_factor_data(self):
|
||||
"""测试数据验证"""
|
||||
selector = TopNSelector(select_num=3)
|
||||
|
||||
# 空数据应该返回False
|
||||
empty_data = pd.DataFrame()
|
||||
assert selector.validate_factor_data(empty_data) == False
|
||||
|
||||
# 有效数据应该返回True
|
||||
valid_data = pd.DataFrame({'factor_A': [1, 2, 3]})
|
||||
assert selector.validate_factor_data(valid_data) == True
|
||||
|
||||
def test_repr(self):
|
||||
"""测试字符串表示"""
|
||||
selector = TopNSelector(select_num=3, min_score=0.5)
|
||||
repr_str = repr(selector)
|
||||
assert "TopNSelector" in repr_str
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
194
archive/framework/tests/test_strategy.py
Normal file
194
archive/framework/tests/test_strategy.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
策略层测试
|
||||
|
||||
测试StrategyBase抽象接口
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
|
||||
from framework.strategy import StrategyBase
|
||||
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
|
||||
from framework.signals import SignalGenerator
|
||||
from framework.risk import Position
|
||||
|
||||
|
||||
class TestStrategyBase:
|
||||
"""测试StrategyBase抽象基类"""
|
||||
|
||||
def test_strategy_config_override(self):
|
||||
"""测试配置覆盖类属性"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
strategy = RotationStrategy(config={'select_num': 5, 'stoploss': -0.03})
|
||||
|
||||
assert strategy.select_num == 5
|
||||
assert strategy.stoploss == -0.03
|
||||
|
||||
def test_strategy_default_values(self):
|
||||
"""测试默认值"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
strategy = RotationStrategy()
|
||||
|
||||
assert strategy.select_num == 3
|
||||
assert strategy.stoploss == -0.05
|
||||
|
||||
def test_strategy_repr(self):
|
||||
"""测试字符串表示"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
strategy = RotationStrategy()
|
||||
repr_str = repr(strategy)
|
||||
|
||||
assert 'RotationStrategy' in repr_str
|
||||
assert 'rotation' in repr_str
|
||||
|
||||
def test_strategy_interface_version(self):
|
||||
"""测试接口版本"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
strategy = RotationStrategy()
|
||||
assert strategy.INTERFACE_VERSION == 1
|
||||
|
||||
|
||||
class TestRotationStrategy:
|
||||
"""测试轮动策略"""
|
||||
|
||||
def test_rotation_strategy_init(self):
|
||||
"""测试初始化"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
FactorRegistry.clear()
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 检查因子初始化
|
||||
assert strategy._factors is not None
|
||||
assert strategy._signal_gen is not None
|
||||
|
||||
def test_rotation_strategy_run(self):
|
||||
"""测试策略运行"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
FactorRegistry.clear()
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 生成测试数据
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100
|
||||
}, index=dates)
|
||||
|
||||
result = strategy.run(data)
|
||||
|
||||
# 检查结果
|
||||
assert 'signal' in result.columns
|
||||
|
||||
def test_dynamic_stoploss(self):
|
||||
"""测试动态止损"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
FactorRegistry.clear()
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 测试不同持仓时间
|
||||
position_5days = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=95.0,
|
||||
entry_time=datetime.now() - pd.Timedelta(days=5)
|
||||
)
|
||||
|
||||
# 5天持仓止损阈值应该为-0.05
|
||||
stoploss = strategy.dynamic_stoploss(position_5days)
|
||||
assert stoploss == -0.05
|
||||
|
||||
def test_before_entry_premium_filter(self):
|
||||
"""测试入场前溢价过滤"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
FactorRegistry.clear()
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 正常溢价应该通过
|
||||
result = strategy.before_entry('AAPL', 100.0, premium=0.05)
|
||||
assert result == True
|
||||
|
||||
# 高溢价应该被拒绝
|
||||
result = strategy.before_entry('AAPL', 100.0, premium=0.15)
|
||||
assert result == False
|
||||
|
||||
def test_custom_exit(self):
|
||||
"""测试自定义出场"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
FactorRegistry.clear()
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 正常盈亏不触发
|
||||
normal_position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=95.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
result = strategy.custom_exit(normal_position)
|
||||
assert result == False
|
||||
|
||||
# 大亏损触发出场
|
||||
loss_position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=85.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
result = strategy.custom_exit(loss_position)
|
||||
assert result == True
|
||||
|
||||
|
||||
class TestStrategyCallbacks:
|
||||
"""测试策略回调机制"""
|
||||
|
||||
def test_callback_registration(self):
|
||||
"""测试回调自动注册"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
FactorRegistry.clear()
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 检查回调是否注册
|
||||
callbacks = strategy._callbacks.get_callbacks('before_entry')
|
||||
assert len(callbacks) > 0
|
||||
|
||||
callbacks = strategy._callbacks.get_callbacks('dynamic_stoploss')
|
||||
assert len(callbacks) > 0
|
||||
|
||||
def test_callback_trigger_in_run(self):
|
||||
"""测试回调在策略运行中触发"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
FactorRegistry.clear()
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 添加自定义回调
|
||||
call_count = {'count': 0}
|
||||
|
||||
def counting_callback(code, price, **kwargs):
|
||||
call_count['count'] += 1
|
||||
return True
|
||||
|
||||
strategy._callbacks.register('before_entry', counting_callback)
|
||||
|
||||
# 运行策略
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100
|
||||
}, index=dates)
|
||||
|
||||
strategy.run(data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
Reference in New Issue
Block a user