refactor(archive): move unused modules to archive/

Archive legacy framework and utility modules that are no longer
referenced by the active core (datasource/ and rotation/):

- framework/ -> archive/framework/
- framework_v2/ -> archive/framework_v2/
- strategies/ -> archive/strategies/
- config/ -> archive/config/
- visualization/ -> archive/visualization/
- scripts/ -> archive/scripts/
- tests/ -> archive/tests/
- run_rotation.py, run_us_rotation.py -> archive/single_files/
- compare_*.py, test_api_dates.py -> archive/single_files/
This commit is contained in:
2026-06-03 23:41:46 +08:00
parent d700bc1dfd
commit c905230a40
98 changed files with 0 additions and 714 deletions

View File

@@ -0,0 +1,60 @@
"""
框架统一入口(通用)
只导出抽象接口具体实现在strategies/shared/
"""
# 因子层抽象
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
# 信号层抽象
from framework.signals import SignalGenerator
# 风控层抽象
from framework.risk import Position, RiskControl, CallbackHook
# 策略层抽象
from framework.strategy import StrategyBase
# 执行层抽象
from framework.execution import Portfolio, Executor, BacktestExecutor, DryRunExecutor
# 配置层
from framework.config import ConfigLoader, StrategyConfig
# 数据层抽象
from framework.data import OHLCVData, DataSource, DataCache
__all__ = [
# 因子层
'FactorBase',
'FactorRegistry',
'FactorCombiner',
# 信号层
'SignalGenerator',
# 风控层
'Position',
'RiskControl',
'CallbackHook',
# 策略层
'StrategyBase',
# 执行层
'Portfolio',
'Executor',
'BacktestExecutor',
'DryRunExecutor',
# 配置层
'ConfigLoader',
'StrategyConfig',
# 数据层
'OHLCVData',
'DataSource',
'DataCache',
]

View File

@@ -0,0 +1,103 @@
"""
配置层抽象接口(通用)
只提供配置加载和验证机制
"""
import yaml
from typing import Dict, Any, Optional
from pathlib import Path
class ConfigLoader:
"""
配置加载器(通用)
支持从YAML文件加载配置
"""
def __init__(self, config_path: Optional[str] = None):
"""初始化配置加载器"""
self._config: Dict[str, Any] = {}
if config_path:
self.load(config_path)
def load(self, config_path: str) -> Dict[str, Any]:
"""从YAML文件加载配置"""
path = Path(config_path)
if not path.exists():
raise FileNotFoundError(f"Config file not found: {config_path}")
with open(path, 'r', encoding='utf-8') as f:
self._config = yaml.safe_load(f) or {}
return self._config
def get(self, key: str, default: Any = None) -> Any:
"""获取配置项"""
return self._config.get(key, default)
def get_section(self, section: str) -> Dict[str, Any]:
"""获取配置区块"""
return self._config.get(section, {})
def validate(self, required_keys: list) -> bool:
"""验证必填配置项"""
missing = [key for key in required_keys if key not in self._config]
if missing:
raise ValueError(f"Missing required config keys: {missing}")
return True
def __repr__(self) -> str:
return f"ConfigLoader(keys={list(self._config.keys())})"
class StrategyConfig:
"""
策略配置类(通用)
用于封装策略配置
"""
def __init__(self, config: Dict[str, Any]):
"""初始化策略配置"""
self._config = config
@property
def name(self) -> str:
"""策略名称"""
return self._config.get('strategy', {}).get('name', 'unknown')
@property
def select_num(self) -> int:
"""选中数量"""
return self._config.get('signal', {}).get('select_num', 3)
@property
def stoploss(self) -> float:
"""止损阈值"""
return self._config.get('risk', {}).get('stop_loss', -0.05)
def get_factor_config(self) -> list:
"""获取因子配置"""
return self._config.get('factors', [])
def get_signal_config(self) -> dict:
"""获取信号配置"""
return self._config.get('signal', {})
def get_risk_config(self) -> list:
"""获取风控配置"""
return self._config.get('risk', [])
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return self._config.copy()
# 导出抽象接口
__all__ = ['ConfigLoader', 'StrategyConfig']

View File

@@ -0,0 +1,126 @@
"""
数据层抽象接口(通用)
只提供数据获取抽象接口具体实现在strategies/shared/data/
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any
import pandas as pd
from dataclasses import dataclass
from datetime import datetime
@dataclass
class OHLCVData:
"""
OHLCV数据结构通用
标准化的K线数据格式
"""
code: str
name: str = ""
start_date: datetime = None
end_date: datetime = None
# OHLCV数据DataFrame
data: pd.DataFrame = None
@property
def length(self) -> int:
"""数据长度"""
return len(self.data) if self.data is not None else 0
def validate(self) -> bool:
"""验证数据完整性"""
if self.data is None or self.data.empty:
return False
required_cols = ['close']
return all(col in self.data.columns for col in required_cols)
def __repr__(self) -> str:
return f"OHLCVData(code={self.code}, name={self.name}, length={self.length})"
class DataSource(ABC):
"""
数据源抽象接口
所有数据源必须实现fetch方法
"""
name: str = "base"
def __init__(self, **params):
"""初始化数据源参数"""
self._params = params
@abstractmethod
def fetch(self, code: str, start: str, end: str) -> OHLCVData:
"""
获取单个标的的OHLCV数据
Args:
code: 标的代码
start: 开始日期 (YYYY-MM-DD)
end: 结束日期 (YYYY-MM-DD)
Returns:
OHLCVData对象
"""
pass
@abstractmethod
def fetch_batch(self, codes: List[str], start: str, end: str) -> Dict[str, OHLCVData]:
"""
批量获取多个标的的OHLCV数据
Args:
codes: 标的代码列表
start: 开始日期
end: 结束日期
Returns:
{code: OHLCVData}字典
"""
pass
def get_supported_codes(self) -> List[str]:
"""获取支持的数据源代码列表"""
return []
def __repr__(self) -> str:
return f"{self.__class__.__name__}(name={self.name})"
class DataCache(ABC):
"""
数据缓存抽象接口(通用)
支持本地缓存管理
"""
@abstractmethod
def get(self, code: str, start: str, end: str) -> Optional[OHLCVData]:
"""从缓存获取数据"""
pass
@abstractmethod
def set(self, code: str, data: OHLCVData) -> None:
"""写入缓存"""
pass
@abstractmethod
def is_fresh(self, code: str, max_age_days: int = 1) -> bool:
"""检查缓存是否新鲜"""
pass
@abstractmethod
def clear(self, code: Optional[str] = None) -> None:
"""清空缓存"""
pass
# 导出抽象接口
__all__ = ['OHLCVData', 'DataSource', 'DataCache']

View File

@@ -0,0 +1,500 @@
"""
执行层抽象接口(通用)
只提供抽象基类和Portfolio数据结构具体执行器可扩展
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
import pandas as pd
import numpy as np
from datetime import datetime
from framework.risk import Position
class Portfolio:
"""
投资组合数据结构(通用)
用于管理持仓集合
"""
def __init__(self, initial_capital: float = 100000):
"""初始化投资组合"""
self.initial_capital = initial_capital
self.cash = initial_capital
self.positions: Dict[str, Position] = {}
self.trades: List[Dict] = []
self._net_value_history: List[float] = []
def add_position(self, code: str, price: float, quantity: float, time: datetime) -> None:
"""添加持仓"""
position = Position(
code=code,
entry_price=price,
current_price=price,
entry_time=time,
quantity=quantity
)
self.positions[code] = position
self.cash -= price * quantity
self.trades.append({
'action': 'BUY',
'code': code,
'price': price,
'quantity': quantity,
'time': time
})
def remove_position(self, code: str, price: float, time: datetime) -> float:
"""移除持仓"""
if code not in self.positions:
return 0
position = self.positions[code]
profit = (price - position.entry_price) * position.quantity
self.cash += price * position.quantity
del self.positions[code]
self.trades.append({
'action': 'SELL',
'code': code,
'price': price,
'quantity': position.quantity,
'time': time,
'profit': profit
})
return profit
def update_prices(self, prices: Dict[str, float]) -> None:
"""更新持仓价格"""
for code, price in prices.items():
if code in self.positions:
self.positions[code].current_price = price
def get_net_value(self) -> float:
"""计算净值"""
positions_value = sum(
pos.current_price * pos.quantity for pos in self.positions.values()
)
return self.cash + positions_value
def record_net_value(self) -> None:
"""记录当前净值"""
self._net_value_history.append(self.get_net_value())
def get_net_value_series(self) -> pd.Series:
"""获取净值序列"""
return pd.Series(self._net_value_history)
def get_weight(self, code: str) -> float:
"""计算持仓权重"""
if code not in self.positions:
return 0
position_value = self.positions[code].current_price * self.positions[code].quantity
return position_value / self.get_net_value()
def __repr__(self) -> str:
return f"Portfolio(capital={self.cash:.2f}, positions={len(self.positions)})"
class Executor(ABC):
"""
执行器抽象基类
所有执行器必须实现execute方法
"""
mode: str = "base"
def __init__(self, **params):
"""初始化执行器参数"""
self._params = params
@abstractmethod
def execute(self, signals: pd.DataFrame, data: pd.DataFrame) -> Portfolio:
"""
执行信号
Args:
signals: 信号DataFrame
data: OHLCV数据
Returns:
Portfolio对象
"""
pass
def __repr__(self) -> str:
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
return f"{self.__class__.__name__}({params_str})"
class BacktestExecutor(Executor):
"""
完整回测执行器(通用)
支持:
- 日收益率计算
- 交易成本扣除
- 净值计算(起点归一化)
- 基准对比
- 持仓跟踪
"""
mode = "backtest"
def __init__(
self,
initial_capital: float = 100000,
trade_cost: float = 0.001,
select_num: int = 1,
benchmark_data: Optional[pd.DataFrame] = None
):
super().__init__(
initial_capital=initial_capital,
trade_cost=trade_cost,
select_num=select_num,
benchmark_data=benchmark_data
)
self.initial_capital = initial_capital
self.trade_cost = trade_cost
self.select_num = select_num
self.benchmark_data = benchmark_data
def execute(self, signals: pd.DataFrame, data: pd.DataFrame) -> Portfolio:
"""
执行完整回测
Args:
signals: 信号DataFrame包含signal或信号列
data: OHLCV数据和日收益率数据
Returns:
Portfolio对象含净值序列、交易记录
"""
portfolio = Portfolio(self.initial_capital)
# 支持中英文列名
signal_col = 'signal' if 'signal' in signals.columns else '信号'
# 删除空信号行
signals = signals.dropna(subset=[signal_col])
signals = signals[signals[signal_col] != '']
if signals.empty:
return portfolio
# 计算策略日收益率
result = self._calculate_daily_returns(signals, data, signal_col)
# 扣除交易成本(同时记录调仓事件)
result, rebalance_events = self._apply_trade_cost_with_events(result, signals, signal_col)
# 计算净值(起点归一化)
result = self._calculate_net_value(result)
# 计算基准净值
result = self._calculate_benchmark(result)
# 记录净值历史
for date in result.index:
portfolio.record_net_value()
# 存储回测结果
portfolio.backtest_result = result
portfolio.rebalance_events = rebalance_events # 新增:调仓事件记录
# 补充调仓事件的净值信息
if not rebalance_events.empty:
rebalance_events = self._enrich_rebalance_events(rebalance_events, result)
portfolio.rebalance_events = rebalance_events
return portfolio
def _calculate_daily_returns(self, signals: pd.DataFrame, data: pd.DataFrame, signal_col: str = 'signal') -> pd.DataFrame:
"""计算策略日收益率"""
result = signals.copy()
# 日收益率列名格式日收益率_{code} 或 日收益率_{code}
return_cols = [col for col in data.columns if col.startswith('日收益率_')]
if self.select_num == 1:
# 单标的策略
def calc_return(row):
signal = row[signal_col]
if not signal or pd.isna(signal):
return 0.0
return data.loc[row.name, f'日收益率_{signal}'] if f'日收益率_{signal}' in data.columns else 0.0
result['策略日收益率'] = result.apply(calc_return, axis=1)
else:
# 多标的策略(等权组合)
# 按实际持仓数量等权分配选出2只时每只50%选出1只时100%
def calc_multi_return(row):
codes = [c for c in row[signal_col].split(',') if c]
if not codes:
return 0.0
returns = []
for c in codes:
ret = data.loc[row.name, f'日收益率_{c}'] if f'日收益率_{c}' in data.columns else None
if ret is not None and pd.notna(ret):
returns.append(ret)
return np.mean(returns) if returns else 0.0
result['策略日收益率'] = result.apply(calc_multi_return, axis=1)
return result
def _apply_trade_cost(self, result: pd.DataFrame, signals: pd.DataFrame, signal_col: str = 'signal') -> pd.DataFrame:
"""扣除交易成本"""
if self.trade_cost <= 0:
return result
prev_signal = signals[signal_col].shift(1)
if self.select_num == 1:
# 单标的策略:调仓时扣除固定成本
changed = (signals[signal_col] != prev_signal) & prev_signal.notna()
result.loc[changed, '策略日收益率'] -= self.trade_cost
else:
# 多标的策略:按换手率比例扣除成本
turnover_list = []
for curr, prev in zip(signals[signal_col], prev_signal):
if pd.isna(prev) or curr == prev:
turnover_list.append(0.0)
else:
old = set(prev.split(','))
new = set(curr.split(','))
swapped = len(old - new)
turnover = swapped / len(old) if old else 0.0
turnover_list.append(turnover)
result['换手率'] = turnover_list
result['策略日收益率'] -= result['换手率'] * self.trade_cost
return result
def _apply_trade_cost_with_events(self, result: pd.DataFrame, signals: pd.DataFrame, signal_col: str = 'signal') -> tuple:
"""
扣除交易成本并记录调仓事件
Returns:
(result, rebalance_events): 回测结果DataFrame和调仓事件DataFrame
"""
prev_signal = signals[signal_col].shift(1)
# 记录调仓事件
rebalance_events = []
last_rebalance_date = None
# 先计算累积收益率(用于计算调仓前后的净值)
cum_return_before_cost = result['策略日收益率'].copy()
if self.select_num == 1:
# 单标的策略
for i, (date, curr, prev) in enumerate(zip(signals.index, signals[signal_col], prev_signal)):
# 检查是否调仓
is_rebalance = False
turnover = 0.0
added = []
removed = []
if pd.notna(prev) and curr != prev:
is_rebalance = True
turnover = 1.0 if prev else 0.0
added = [curr] if curr else []
removed = [prev] if prev else []
# 扣除成本
result.loc[date, '策略日收益率'] -= self.trade_cost
# 记录调仓事件
if is_rebalance:
# 计算持仓天数
holding_days = 0
if last_rebalance_date is not None:
holding_days = (date - last_rebalance_date).days
event = {
'日期': date,
'调仓前持仓': prev if pd.notna(prev) else '',
'调仓后持仓': curr,
'调入标的': ','.join(added) if added else '',
'调出标的': ','.join(removed) if removed else '',
'换手率': turnover,
'调仓成本': self.trade_cost * turnover,
'持仓天数': holding_days,
'当日收益': result.loc[date, '策略日收益率'] + self.trade_cost * turnover, # 原始收益(扣除成本前)
}
rebalance_events.append(event)
last_rebalance_date = date
else:
# 多标的策略
turnover_list = []
for i, (date, curr, prev) in enumerate(zip(signals.index, signals[signal_col], prev_signal)):
# 检查是否调仓
is_rebalance = False
turnover = 0.0
added = []
removed = []
if pd.notna(prev) and curr != prev:
old = set(prev.split(',')) if prev else set()
new = set(curr.split(',')) if curr else set()
added = list(new - old)
removed = list(old - new)
swapped = len(removed)
turnover = swapped / len(old) if old else 0.0
is_rebalance = len(added) > 0 or len(removed) > 0
turnover_list.append(turnover)
# 扣除成本
result.loc[date, '策略日收益率'] -= turnover * self.trade_cost
else:
turnover_list.append(0.0)
# 记录调仓事件
if is_rebalance:
# 计算持仓天数
holding_days = 0
if last_rebalance_date is not None:
holding_days = (date - last_rebalance_date).days
event = {
'日期': date,
'调仓前持仓': prev if pd.notna(prev) else '',
'调仓后持仓': curr,
'调入标的': ','.join(added) if added else '',
'调出标的': ','.join(removed) if removed else '',
'换手率': turnover,
'调仓成本': self.trade_cost * turnover,
'持仓天数': holding_days,
'当日收益': result.loc[date, '策略日收益率'] + turnover * self.trade_cost, # 原始收益(扣除成本前)
}
rebalance_events.append(event)
last_rebalance_date = date
result['换手率'] = turnover_list
# 转换为DataFrame
rebalance_df = pd.DataFrame(rebalance_events) if rebalance_events else pd.DataFrame()
if not rebalance_df.empty:
rebalance_df['日期'] = pd.to_datetime(rebalance_df['日期'])
rebalance_df = rebalance_df.set_index('日期')
return result, rebalance_df
def _enrich_rebalance_events(self, rebalance_df: pd.DataFrame, result: pd.DataFrame) -> pd.DataFrame:
"""
补充调仓事件的净值信息
Args:
rebalance_df: 调仓事件DataFrame
result: 回测结果DataFrame含净值序列
Returns:
补充净值信息后的调仓事件DataFrame
"""
# 计算调仓前后净值变化
nav_before_list = []
nav_after_list = []
nav_change_list = []
for date in rebalance_df.index:
# 获取调仓日的净值
if date in result.index:
# 调仓前净值:前一天收盘净值
prev_date_idx = result.index.get_loc(date) - 1
if prev_date_idx >= 0:
nav_before = result['策略净值'].iloc[prev_date_idx]
else:
nav_before = 1.0
# 调仓后净值:当天收盘净值
nav_after = result.loc[date, '策略净值']
# 净值变化
nav_change = (nav_after / nav_before - 1) * 100
else:
nav_before = None
nav_after = None
nav_change = None
nav_before_list.append(nav_before)
nav_after_list.append(nav_after)
nav_change_list.append(nav_change)
# 添加净值信息列
rebalance_df['调仓前净值'] = nav_before_list
rebalance_df['调仓后净值'] = nav_after_list
rebalance_df['净值变化%'] = nav_change_list
return rebalance_df
def _calculate_net_value(self, result: pd.DataFrame) -> pd.DataFrame:
"""计算净值(起点归一化)"""
result['策略净值'] = (1 + result['策略日收益率']).cumprod()
# 归一化确保净值起点为1.0
result['策略净值'] = result['策略净值'] / result['策略净值'].iloc[0]
return result
def _calculate_benchmark(self, result: pd.DataFrame) -> pd.DataFrame:
"""计算基准净值"""
if self.benchmark_data is None:
return result
# 获取基准收益率
if isinstance(self.benchmark_data, pd.DataFrame):
if 'close' in self.benchmark_data.columns:
bench_close = self.benchmark_data['close']
else:
bench_close = self.benchmark_data.iloc[:, 0]
else:
bench_close = self.benchmark_data
bench_ret = bench_close.pct_change().dropna()
common_dates = result.index.intersection(bench_ret.index)
bench_ret = bench_ret.loc[common_dates]
result['基准日收益率'] = bench_ret.reindex(result.index, fill_value=0)
result['基准净值'] = (1 + result['基准日收益率']).cumprod()
result['基准净值'] = result['基准净值'] / result['基准净值'].iloc[0]
return result
class DryRunExecutor(Executor):
"""
Dry-run执行器通用
用于模拟运行,不实际执行交易
"""
mode = "dry_run"
def __init__(self, verbose: bool = True):
super().__init__(verbose=verbose)
self.verbose = verbose
def execute(self, signals: pd.DataFrame, data: pd.DataFrame) -> Portfolio:
"""模拟执行"""
portfolio = Portfolio(100000)
for date in signals.index:
signal = signals.loc[date, 'signal']
if signal and self.verbose:
print(f"[{date}] Signal: {signal}")
return portfolio
# 导出抽象接口
__all__ = ['Portfolio', 'Executor', 'BacktestExecutor', 'DryRunExecutor']

View File

@@ -0,0 +1,191 @@
"""
因子层抽象接口(通用)
只提供抽象基类和注册机制具体因子实现在strategies/shared/factors/
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any, Type
import pandas as pd
import numpy as np
class FactorBase(ABC):
"""
因子抽象基类
所有因子必须实现compute方法
"""
name: str = "base"
category: str = "unknown"
def __init__(self, **params):
"""初始化因子参数"""
self._params = params
@abstractmethod
def compute(self, data: pd.DataFrame) -> pd.Series:
"""
计算因子值
Args:
data: OHLCV数据必须包含'close'
Returns:
因子值序列
"""
pass
def validate_data(self, data: pd.DataFrame) -> bool:
"""验证数据是否满足计算要求"""
if 'close' not in data.columns:
return False
min_periods = self._params.get('min_periods', 20)
return len(data) >= min_periods
def __repr__(self) -> str:
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
return f"{self.__class__.__name__}({params_str})"
class FactorRegistry:
"""
因子注册器(通用)
管理因子类的注册和获取
"""
_factors: Dict[str, Type[FactorBase]] = {}
@classmethod
def register(cls, factor_class: Type[FactorBase]) -> None:
"""注册因子类"""
temp_instance = factor_class()
name = temp_instance.name
if name in cls._factors:
print(f"因子已注册,覆盖: {name}")
cls._factors[name] = factor_class
@classmethod
def get(cls, name: str, **params) -> FactorBase:
"""获取因子实例"""
if name not in cls._factors:
raise ValueError(f"因子未注册: {name}")
factor_class = cls._factors[name]
return factor_class(**params)
@classmethod
def list_factors(cls) -> List[str]:
"""列出所有已注册因子"""
return list(cls._factors.keys())
@classmethod
def clear(cls) -> None:
"""清空注册表"""
cls._factors = {}
@classmethod
def get_category(cls, name: str) -> str:
"""获取因子类别"""
if name not in cls._factors:
return "unknown"
temp_instance = cls._factors[name]()
return temp_instance.category
class FactorCombiner:
"""
因子组合器(通用)
支持多因子加权组合
"""
SUPPORTED_METHODS = ['weighted_sum', 'rank_average', 'zscore_sum', 'equal_weight']
def __init__(
self,
factors: List[FactorBase],
weights: Optional[List[float]] = None,
method: str = 'weighted_sum'
):
"""
初始化组合器
Args:
factors: 因子实例列表
weights: 因子权重列表(可选)
method: 组合方法weighted_sum/rank_average/zscore_sum/equal_weight
"""
if not factors:
raise ValueError("factors list cannot be empty")
if method not in self.SUPPORTED_METHODS:
raise ValueError(f"Unsupported method: {method}")
self._factors = factors
if weights is None:
self._weights = [1.0 / len(factors)] * len(factors)
else:
if len(weights) != len(factors):
raise ValueError("weights length must match factors length")
self._weights = weights
self._method = method
def compute(self, data: pd.DataFrame) -> pd.DataFrame:
"""
计算所有因子并组合
Returns:
DataFrame包含各因子值和combined列
"""
result = pd.DataFrame(index=data.index)
# 计算各因子
for i, factor in enumerate(self._factors):
factor_values = factor.compute(data)
col_name = f"{factor.name}"
result[col_name] = factor_values
# 组合因子值
if self._method == 'weighted_sum':
weighted_cols = [f.name for f in self._factors]
result['combined'] = result[weighted_cols].apply(
lambda row: sum(row[col] * self._weights[i] for i, col in enumerate(weighted_cols) if pd.notna(row[col])),
axis=1
)
elif self._method == 'equal_weight':
factor_cols = [f.name for f in self._factors]
result['combined'] = result[factor_cols].mean(axis=1)
elif self._method == 'rank_average':
factor_cols = [f.name for f in self._factors]
ranks = result[factor_cols].rank(axis=1)
result['combined'] = ranks.mean(axis=1)
elif self._method == 'zscore_sum':
factor_cols = [f.name for f in self._factors]
zscores = result[factor_cols].apply(lambda x: (x - x.mean()) / x.std())
result['combined'] = zscores.sum(axis=1)
return result
def get_factor_names(self) -> List[str]:
"""获取因子名称列表"""
return [f.name for f in self._factors]
def __repr__(self) -> str:
factor_names = [f.name for f in self._factors]
return f"FactorCombiner(factors={factor_names}, weights={self._weights}, method={self._method})"
# 导出抽象接口
__all__ = ['FactorBase', 'FactorRegistry', 'FactorCombiner']

View File

@@ -0,0 +1,188 @@
"""
风控层抽象接口(通用)
只提供抽象基类和回调机制具体风控组件在strategies/shared/risk/
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Any, Callable, Optional
from dataclasses import dataclass
from datetime import datetime
@dataclass
class Position:
"""
持仓数据结构(通用)
用于表示单个持仓的状态
"""
code: str
entry_price: float
current_price: float
entry_time: datetime
quantity: float = 1.0
weight: float = 1.0
@property
def profit_ratio(self) -> float:
"""计算盈亏比例"""
return (self.current_price - self.entry_price) / self.entry_price
@property
def profit_amount(self) -> float:
"""计算盈亏金额"""
return (self.current_price - self.entry_price) * self.quantity
@property
def holding_days(self) -> int:
"""计算持仓天数"""
if self.entry_time is None:
return 0
return (datetime.now() - self.entry_time).days
def __repr__(self) -> str:
return f"Position(code={self.code}, profit={self.profit_ratio:.2%}, days={self.holding_days})"
class RiskControl(ABC):
"""
风控组件抽象基类
所有风控组件必须实现check方法
"""
name: str = "base"
def __init__(self, **params):
"""初始化风控参数"""
self._params = params
@abstractmethod
def check(self, position: Position, **kwargs) -> bool:
"""
检查风控条件
Args:
position: 持仓对象
kwargs: 额外参数如premium、history等
Returns:
True表示通过检查False表示触发风控
"""
pass
def apply(self, position: Position) -> Any:
"""
应用风控规则(可选)
Args:
position: 持仓对象
Returns:
风控结果(如止损价格、建议仓位等)
"""
return None
def __repr__(self) -> str:
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
return f"{self.__class__.__name__}({params_str})"
class CallbackHook:
"""
回调钩子管理(通用)
支持在策略生命周期的关键节点注入自定义逻辑
"""
SUPPORTED_HOOKS = [
'before_entry', # 入场前检查
'after_entry', # 入场后处理
'before_exit', # 出场前检查
'after_exit', # 出场后处理
'dynamic_stoploss', # 动态止损计算
'custom_exit' # 自定义出场条件
]
def __init__(self):
"""初始化回调钩子"""
self._hooks: Dict[str, List[Callable]] = {
hook: [] for hook in self.SUPPORTED_HOOKS
}
def register(self, hook_name: str, callback: Callable) -> None:
"""注册回调函数"""
if hook_name not in self._hooks:
raise ValueError(f"Unsupported hook: {hook_name}")
self._hooks[hook_name].append(callback)
def trigger(self, hook_name: str, *args, **kwargs) -> Any:
"""
触发回调
Args:
hook_name: 钩子名称
args: 位置参数
kwargs: 关键字参数
Returns:
回调结果(根据钩子类型返回不同结果)
"""
if hook_name not in self._hooks:
raise ValueError(f"Unsupported hook: {hook_name}")
callbacks = self._hooks[hook_name]
if not callbacks:
# 默认行为
if hook_name == 'dynamic_stoploss':
return kwargs.get('default_stoploss', -0.05)
elif hook_name in ['before_entry', 'before_exit']:
return True
elif hook_name == 'custom_exit':
return False
return None
results = []
for callback in callbacks:
result = callback(*args, **kwargs)
results.append(result)
# 根据钩子类型返回不同结果
if hook_name in ['before_entry', 'before_exit']:
# 所有回调返回True才允许
return all(results)
if hook_name == 'dynamic_stoploss':
# 返回最小止损值(最严格)
return min(results)
if hook_name == 'custom_exit':
# 任一回调触发出场
return any(results)
# 其他钩子返回最后一个结果
return results[-1] if results else None
def clear(self, hook_name: Optional[str] = None) -> None:
"""清空回调"""
if hook_name:
if hook_name in self._hooks:
self._hooks[hook_name] = []
else:
for hook in self._hooks:
self._hooks[hook] = []
def list_hooks(self) -> List[str]:
"""列出支持的钩子"""
return self.SUPPORTED_HOOKS
def get_callbacks(self, hook_name: str) -> List[Callable]:
"""获取钩子的所有回调"""
return self._hooks.get(hook_name, [])
# 导出抽象接口
__all__ = ['Position', 'RiskControl', 'CallbackHook']

View File

@@ -0,0 +1,54 @@
"""
信号层抽象接口(通用)
只提供抽象基类具体信号生成器在strategies/shared/signals/
"""
from abc import ABC, abstractmethod
from typing import List, Optional, Any
import pandas as pd
class SignalGenerator(ABC):
"""
信号生成器抽象基类
所有信号生成器必须实现generate方法
"""
mode: str = "base"
def __init__(self, **params):
"""初始化信号生成器参数"""
self._params = params
@abstractmethod
def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
"""
生成交易信号
Args:
factor_data: 因子数据DataFrame
Returns:
包含'signal'列的DataFrame
"""
pass
def validate_factor_data(self, factor_data: pd.DataFrame) -> bool:
"""验证因子数据是否有效"""
if factor_data.empty:
return False
if 'signal' in factor_data.columns:
print("Warning: factor_data already contains 'signal' column")
return True
def __repr__(self) -> str:
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
return f"{self.__class__.__name__}({params_str})"
# 导出抽象接口
__all__ = ['SignalGenerator']

View File

@@ -0,0 +1,144 @@
"""
策略层抽象基类(通用)
只提供抽象接口具体策略实现在strategies/
"""
from abc import ABC, abstractmethod
from typing import Dict, Optional, Any
import pandas as pd
from framework.factors import FactorCombiner
from framework.signals import SignalGenerator
from framework.risk import CallbackHook, Position
class StrategyBase(ABC):
"""
策略抽象基类
所有策略必须实现init_factors和init_signal_generator方法
"""
INTERFACE_VERSION = 1
name: str = "base"
timeframe: str = "1d"
# 类属性(可被配置覆盖)
select_num: int = 3
stoploss: float = -0.05
def __init__(self, config: Optional[Dict] = None):
"""
初始化策略
Args:
config: 配置字典(可选,用于覆盖类属性)
"""
if config:
self._apply_config(config)
self._callbacks = CallbackHook()
self._register_default_callbacks()
self._factors = self.init_factors()
self._signal_gen = self.init_signal_generator()
def _apply_config(self, config: Dict) -> None:
"""应用配置覆盖类属性"""
for key, value in config.items():
if hasattr(self, key):
setattr(self, key, value)
def _register_default_callbacks(self) -> None:
"""注册默认回调方法"""
if hasattr(self, 'before_entry'):
self._callbacks.register('before_entry', self.before_entry)
if hasattr(self, 'after_entry'):
self._callbacks.register('after_entry', self.after_entry)
if hasattr(self, 'before_exit'):
self._callbacks.register('before_exit', self.before_exit)
if hasattr(self, 'after_exit'):
self._callbacks.register('after_exit', self.after_exit)
if hasattr(self, 'dynamic_stoploss'):
self._callbacks.register('dynamic_stoploss', self.dynamic_stoploss)
if hasattr(self, 'custom_exit'):
self._callbacks.register('custom_exit', self.custom_exit)
@abstractmethod
def init_factors(self) -> FactorCombiner:
"""
初始化因子组合器
Returns:
FactorCombiner实例
"""
pass
@abstractmethod
def init_signal_generator(self) -> SignalGenerator:
"""
初始化信号生成器
Returns:
SignalGenerator实例
"""
pass
def run(self, data: pd.DataFrame) -> pd.DataFrame:
"""
运行策略
Args:
data: OHLCV数据
Returns:
包含信号的DataFrame
"""
factor_data = self._factors.compute(data)
signals = self._signal_gen.generate(factor_data)
signals = self._apply_callbacks(signals, data)
return signals
def _apply_callbacks(self, signals: pd.DataFrame, data: pd.DataFrame) -> pd.DataFrame:
"""应用回调处理"""
return signals
# 可选回调方法(子类可覆盖)
def before_entry(self, code: str, price: float, **kwargs) -> bool:
"""入场前检查"""
return True
def after_entry(self, code: str, price: float, **kwargs) -> None:
"""入场后处理"""
pass
def before_exit(self, position: Position, **kwargs) -> bool:
"""出场前检查"""
return True
def after_exit(self, position: Position, **kwargs) -> None:
"""出场后处理"""
pass
def dynamic_stoploss(self, position: Position) -> float:
"""动态止损"""
return self.stoploss
def custom_exit(self, position: Position) -> bool:
"""自定义出场条件"""
return False
def __repr__(self) -> str:
return f"{self.__class__.__name__}(name={self.name})"
# 导出抽象接口
__all__ = ['StrategyBase']

View File

@@ -0,0 +1,174 @@
"""
执行层测试
测试Portfolio、Executor抽象接口
"""
import pandas as pd
import numpy as np
import pytest
from datetime import datetime
from framework.execution import Portfolio, Executor, BacktestExecutor, DryRunExecutor
from framework.risk import Position
class TestPortfolio:
"""测试Portfolio"""
def test_portfolio_init(self):
"""测试初始化"""
portfolio = Portfolio(initial_capital=100000)
assert portfolio.initial_capital == 100000
assert portfolio.cash == 100000
assert len(portfolio.positions) == 0
def test_add_position(self):
"""测试添加持仓"""
portfolio = Portfolio(initial_capital=100000)
portfolio.add_position(
code='AAPL',
price=100.0,
quantity=10,
time=datetime.now()
)
assert len(portfolio.positions) == 1
assert 'AAPL' in portfolio.positions
assert portfolio.cash == 99000 # 100000 - 100*10
def test_remove_position(self):
"""测试移除持仓"""
portfolio = Portfolio(initial_capital=100000)
portfolio.add_position('AAPL', 100.0, 10, datetime.now())
profit = portfolio.remove_position('AAPL', 110.0, datetime.now())
assert len(portfolio.positions) == 0
assert profit == 100.0 # (110-100)*10
assert portfolio.cash == 100100 # 99000 + 110*10
def test_update_prices(self):
"""测试更新价格"""
portfolio = Portfolio(initial_capital=100000)
portfolio.add_position('AAPL', 100.0, 10, datetime.now())
portfolio.add_position('MSFT', 200.0, 5, datetime.now())
portfolio.update_prices({'AAPL': 110.0, 'MSFT': 220.0})
assert portfolio.positions['AAPL'].current_price == 110.0
assert portfolio.positions['MSFT'].current_price == 220.0
def test_get_net_value(self):
"""测试净值计算"""
portfolio = Portfolio(initial_capital=100000)
portfolio.add_position('AAPL', 100.0, 10, datetime.now())
portfolio.update_prices({'AAPL': 110.0})
net_value = portfolio.get_net_value()
expected = 99000 + 110 * 10 # cash + position_value
assert net_value == expected
def test_get_weight(self):
"""测试权重计算"""
portfolio = Portfolio(initial_capital=100000)
portfolio.add_position('AAPL', 100.0, 100, datetime.now())
portfolio.update_prices({'AAPL': 110.0})
weight = portfolio.get_weight('AAPL')
# 持仓价值=110*100=11000净值=90000+11000=101000
expected_weight = 11000 / 101000
assert abs(weight - expected_weight) < 0.01
def test_record_net_value(self):
"""测试净值记录"""
portfolio = Portfolio(initial_capital=100000)
portfolio.record_net_value()
portfolio.add_position('AAPL', 100.0, 10, datetime.now())
portfolio.record_net_value()
series = portfolio.get_net_value_series()
assert len(series) == 2
def test_portfolio_repr(self):
"""测试字符串表示"""
portfolio = Portfolio(initial_capital=100000)
portfolio.add_position('AAPL', 100.0, 10, datetime.now())
repr_str = repr(portfolio)
assert 'Portfolio' in repr_str
assert 'positions=1' in repr_str
class TestBacktestExecutor:
"""测试回测执行器"""
def test_backtest_executor_init(self):
"""测试初始化"""
executor = BacktestExecutor(initial_capital=100000, trade_cost=0.001)
assert executor.initial_capital == 100000
assert executor.trade_cost == 0.001
assert executor.mode == "backtest"
def test_backtest_execute(self):
"""测试回测执行"""
executor = BacktestExecutor(initial_capital=100000)
dates = pd.date_range('2020-01-01', periods=50)
signals = pd.DataFrame({
'signal': ['AAPL'] * 50
}, index=dates)
data = pd.DataFrame({
'close': np.random.randn(50).cumsum() + 100
}, index=dates)
portfolio = executor.execute(signals, data)
assert isinstance(portfolio, Portfolio)
def test_backtest_executor_repr(self):
"""测试字符串表示"""
executor = BacktestExecutor(initial_capital=100000, trade_cost=0.001)
repr_str = repr(executor)
assert 'BacktestExecutor' in repr_str
class TestDryRunExecutor:
"""测试DryRun执行器"""
def test_dry_run_executor_init(self):
"""测试初始化"""
executor = DryRunExecutor(verbose=True)
assert executor.verbose == True
assert executor.mode == "dry_run"
def test_dry_run_execute(self):
"""测试模拟执行"""
executor = DryRunExecutor(verbose=False)
dates = pd.date_range('2020-01-01', periods=50)
signals = pd.DataFrame({
'signal': ['AAPL,MSFT'] * 50
}, index=dates)
data = pd.DataFrame({
'close': np.random.randn(50).cumsum() + 100
}, index=dates)
portfolio = executor.execute(signals, data)
assert isinstance(portfolio, Portfolio)
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@@ -0,0 +1,288 @@
"""
因子层测试
测试FactorBase、FactorRegistry、FactorCombiner抽象接口
"""
import pandas as pd
import numpy as np
import pytest
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
from strategies.shared.factors.momentum import MomentumFactor, TrendFactor, ReversalFactor, VolatilityFactor
class TestFactorBase:
"""测试FactorBase抽象基类"""
def test_factor_meta(self):
"""测试因子元信息"""
factor = MomentumFactor(n_days=25)
assert factor.name == "momentum"
assert factor.category == "momentum"
def test_factor_repr(self):
"""测试因子字符串表示"""
factor = MomentumFactor(n_days=30)
repr_str = repr(factor)
assert "MomentumFactor" in repr_str
def test_validate_data(self):
"""测试数据验证"""
factor = MomentumFactor(n_days=25)
# 数据充足
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100
})
assert factor.validate_data(data) == True
# 数据不足
short_data = pd.DataFrame({
'close': np.random.randn(10).cumsum() + 100
})
assert factor.validate_data(short_data) == False
class TestFactorRegistry:
"""测试因子注册器"""
def setup_method(self):
"""每个测试前清空注册表"""
FactorRegistry.clear()
def test_register_factor(self):
"""测试因子注册"""
FactorRegistry.register(MomentumFactor)
assert 'momentum' in FactorRegistry.list_factors()
def test_get_factor(self):
"""测试获取因子实例"""
FactorRegistry.register(MomentumFactor)
factor = FactorRegistry.get('momentum', n_days=30)
assert isinstance(factor, MomentumFactor)
assert factor.n_days == 30
def test_get_unknown_factor(self):
"""测试获取未注册因子"""
with pytest.raises(ValueError):
FactorRegistry.get('unknown_factor')
def test_get_category(self):
"""测试获取因子类别"""
FactorRegistry.register(MomentumFactor)
category = FactorRegistry.get_category('momentum')
assert category == 'momentum'
class TestFactorCombiner:
"""测试因子组合器"""
def setup_method(self):
"""每个测试前清空注册表"""
FactorRegistry.clear()
def test_combiner_init(self):
"""测试组合器初始化"""
factors = [
MomentumFactor(n_days=25),
TrendFactor(method='ma_cross')
]
combiner = FactorCombiner(factors, weights=[0.7, 0.3])
assert len(combiner.get_factor_names()) == 2
def test_combiner_equal_weights(self):
"""测试等权组合"""
factors = [
MomentumFactor(n_days=25),
TrendFactor()
]
combiner = FactorCombiner(factors) # 默认等权
# 权重应该归一化
assert sum(combiner._weights) == 1.0
def test_combiner_compute(self):
"""测试因子组合计算"""
# 生成测试数据
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100,
'high': np.random.randn(100).cumsum() + 105,
'low': np.random.randn(100).cumsum() + 95
}, index=dates)
factors = [
MomentumFactor(n_days=20),
TrendFactor(fast=5, slow=10)
]
combiner = FactorCombiner(factors, weights=[0.6, 0.4])
result = combiner.compute(data)
# 检查结果列
assert 'momentum' in result.columns
assert 'trend' in result.columns
assert 'combined' in result.columns
def test_combiner_method_rank_average(self):
"""测试rank_average组合方法"""
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100
}, index=dates)
factors = [
MomentumFactor(n_days=20),
TrendFactor()
]
combiner = FactorCombiner(factors, method='rank_average')
result = combiner.compute(data)
# combined应该是排名平均值
assert 'combined' in result.columns
class TestMomentumFactor:
"""测试动量因子"""
def test_momentum_compute(self):
"""测试动量因子计算"""
dates = pd.date_range('2020-01-01', periods=100)
# 生成上升趋势数据
prices = 100 + np.arange(100) * 0.5
data = pd.DataFrame({'close': prices}, index=dates)
factor = MomentumFactor(n_days=25, weighted=True)
values = factor.compute(data)
# 上升趋势应该有正的动量得分
assert values.iloc[-1] > 0
def test_crash_filter(self):
"""测试崩盘过滤"""
dates = pd.date_range('2020-01-01', periods=100)
# 生成正常数据,然后在末尾添加崩盘
prices = 100 + np.random.randn(100).cumsum()
prices[-3:] = prices[-4] * np.array([0.96, 0.93, 0.90]) # 连续大跌
data = pd.DataFrame({'close': prices}, index=dates)
factor = MomentumFactor(n_days=25, crash_filter=True)
values = factor.compute(data)
# 崩盘后动量得分应该被清零
assert values.iloc[-1] == 0.0
def test_simple_momentum(self):
"""测试简单动量(无加权,无崩盘过滤)"""
dates = pd.date_range('2020-01-01', periods=100)
prices = 100 + np.random.randn(100).cumsum()
data = pd.DataFrame({'close': prices}, index=dates)
factor = MomentumFactor(n_days=25, weighted=False, crash_filter=False)
values = factor.compute(data)
# 简单动量应该是N日涨幅
expected = data['close'].pct_change(25)
# 验证长度一致
assert len(values) == len(expected)
class TestTrendFactor:
"""测试趋势因子"""
def test_ma_cross(self):
"""测试MA交叉趋势"""
dates = pd.date_range('2020-01-01', periods=100)
# 生成上升趋势
prices = 100 + np.arange(100) * 0.5
data = pd.DataFrame({'close': prices}, index=dates)
factor = TrendFactor(method='ma_cross', fast=5, slow=20)
values = factor.compute(data)
# 上升趋势应该有正的趋势强度
assert values.iloc[-1] > 0
def test_macd(self):
"""测试MACD趋势"""
dates = pd.date_range('2020-01-01', periods=100)
prices = 100 + np.random.randn(100).cumsum()
data = pd.DataFrame({'close': prices}, index=dates)
factor = TrendFactor(method='macd')
values = factor.compute(data)
# 检查计算结果
assert len(values) == len(data)
class TestReversalFactor:
"""测试反转因子"""
def test_rsi_reversal(self):
"""测试RSI反转信号"""
dates = pd.date_range('2020-01-01', periods=100)
# 生成超买数据(持续上涨)
prices = 100 + np.arange(100) * 1.0
data = pd.DataFrame({'close': prices}, index=dates)
factor = ReversalFactor(method='rsi', period=14, overbought=70)
values = factor.compute(data)
# RSI超过70应该产生负值反转向下信号
assert values.iloc[-1] < 0
def test_rsi_oversold(self):
"""测试RSI超卖信号"""
dates = pd.date_range('2020-01-01', periods=100)
# 生成超卖数据(持续下跌)
prices = 100 - np.arange(100) * 1.0
data = pd.DataFrame({'close': prices}, index=dates)
factor = ReversalFactor(method='rsi', period=14, oversold=30)
values = factor.compute(data)
# RSI低于30应该产生正值反转向上信号
assert values.iloc[-1] > 0
class TestVolatilityFactor:
"""测试波动率因子"""
def test_std_volatility(self):
"""测试标准差波动率"""
dates = pd.date_range('2020-01-01', periods=100)
prices = 100 + np.random.randn(100).cumsum()
data = pd.DataFrame({'close': prices}, index=dates)
factor = VolatilityFactor(method='std', period=20)
values = factor.compute(data)
assert len(values) == len(data)
def test_atr_volatility(self):
"""测试ATR波动率"""
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100,
'high': np.random.randn(100).cumsum() + 105,
'low': np.random.randn(100).cumsum() + 95
}, index=dates)
factor = VolatilityFactor(method='atr', period=20)
values = factor.compute(data)
assert len(values) == len(data)
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@@ -0,0 +1,166 @@
"""
集成测试
测试框架与定制组件的集成
"""
import pandas as pd
import numpy as np
import pytest
from framework.factors import FactorRegistry, FactorCombiner
from framework.signals import SignalGenerator
from framework.risk import CallbackHook, Position
from framework.strategy import StrategyBase
from framework.execution import Portfolio, BacktestExecutor
from strategies.shared.factors.momentum import MomentumFactor, VolatilityFactor
from strategies.shared.signals.selectors import TopNSelector
from strategies.shared.risk.controls import StopLossControl, premium_filter_callback
from strategies.rotation.strategy import RotationStrategy
class TestFactorIntegration:
"""测试因子集成"""
def setup_method(self):
"""每个测试前清空注册表"""
FactorRegistry.clear()
def test_register_and_use_custom_factor(self):
"""测试注册并使用定制因子"""
FactorRegistry.register(MomentumFactor)
factor = FactorRegistry.get('momentum', n_days=25, crash_filter=True)
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100
}, index=dates)
values = factor.compute(data)
assert len(values) == len(data)
def test_combiner_with_custom_factors(self):
"""测试组合器使用定制因子"""
FactorRegistry.register(MomentumFactor)
FactorRegistry.register(VolatilityFactor)
momentum = FactorRegistry.get('momentum', n_days=25)
volatility = FactorRegistry.get('volatility', method='std', period=20)
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100,
'high': np.random.randn(100).cumsum() + 105,
'low': np.random.randn(100).cumsum() + 95
}, index=dates)
combiner = FactorCombiner([momentum, volatility], weights=[0.7, 0.3])
result = combiner.compute(data)
assert 'momentum' in result.columns
assert 'volatility' in result.columns
assert 'combined' in result.columns
class TestSignalIntegration:
"""测试信号生成器集成"""
def test_custom_signal_generator_with_factors(self):
"""测试定制信号生成器与因子集成"""
dates = pd.date_range('2020-01-01', periods=50)
factor_data = pd.DataFrame({
'momentum_A': np.random.randn(50),
'momentum_B': np.random.randn(50),
'momentum_C': np.random.randn(50),
}, index=dates)
selector = TopNSelector(select_num=2, min_score=0.0)
result = selector.generate(factor_data)
assert 'signal' in result.columns
class TestRiskIntegration:
"""测试风控组件集成"""
def test_custom_risk_control(self):
"""测试定制风控组件"""
control = StopLossControl(threshold=-0.05, trailing=True)
position = Position(
code='AAPL',
entry_price=100.0,
current_price=105.0,
entry_time=pd.Timestamp.now()
)
# 首次检查,设置最高价
assert control.check(position) == True
# 价格下跌,触发跟踪止损
position.current_price = 100.0
assert control.check(position) == False
def test_callback_hook_with_custom_callback(self):
"""测试回调钩子与定制回调集成"""
hook = CallbackHook()
# 注册定制回调
callback = premium_filter_callback(threshold=0.10)
hook.register('before_entry', callback)
# 正常溢价通过
result = hook.trigger('before_entry', 'AAPL', 100.0, premium=0.05)
assert result == True
# 高溢价拒绝
result = hook.trigger('before_entry', 'AAPL', 100.0, premium=0.15)
assert result == False
class TestStrategyIntegration:
"""测试策略集成"""
def setup_method(self):
"""每个测试前清空注册表"""
FactorRegistry.clear()
def test_rotation_strategy_full_flow(self):
"""测试轮动策略完整流程"""
strategy = RotationStrategy()
# 生成测试数据
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100
}, index=dates)
# 运行策略
result = strategy.run(data)
assert 'signal' in result.columns
def test_strategy_with_backtest_executor(self):
"""测试策略与回测执行器集成"""
FactorRegistry.clear()
strategy = RotationStrategy()
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100
}, index=dates)
signals = strategy.run(data)
executor = BacktestExecutor(initial_capital=100000, trade_cost=0.001)
portfolio = executor.execute(signals, data)
assert isinstance(portfolio, Portfolio)
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@@ -0,0 +1,325 @@
"""
风控层测试
测试RiskControl、CallbackHook抽象接口
"""
import pandas as pd
import numpy as np
import pytest
from datetime import datetime
from framework.risk import RiskControl, CallbackHook, Position
from strategies.shared.risk.controls import StopLossControl, PositionLimitControl, PremiumControl
class TestPosition:
"""测试Position数据结构"""
def test_position_creation(self):
"""测试持仓创建"""
position = Position(
code='AAPL',
entry_price=100.0,
current_price=105.0,
entry_time=datetime.now(),
quantity=10
)
assert position.code == 'AAPL'
assert position.entry_price == 100.0
assert position.current_price == 105.0
def test_profit_ratio(self):
"""测试盈亏比例计算"""
position = Position(
code='AAPL',
entry_price=100.0,
current_price=110.0,
entry_time=datetime.now()
)
assert position.profit_ratio == 0.10
def test_profit_amount(self):
"""测试盈亏金额计算"""
position = Position(
code='AAPL',
entry_price=100.0,
current_price=110.0,
entry_time=datetime.now(),
quantity=10
)
assert position.profit_amount == 100.0
def test_position_repr(self):
"""测试持仓字符串表示"""
position = Position(
code='AAPL',
entry_price=100.0,
current_price=105.0,
entry_time=datetime.now()
)
repr_str = repr(position)
assert 'AAPL' in repr_str
assert '5.00%' in repr_str
class TestStopLossControl:
"""测试止损控制"""
def test_stop_loss_init(self):
"""测试初始化"""
control = StopLossControl(threshold=-0.05)
assert control.threshold == -0.05
assert control.name == "stop_loss"
def test_stop_loss_check(self):
"""测试止损检查"""
control = StopLossControl(threshold=-0.05)
# 盈利持仓应该通过
profit_position = Position(
code='AAPL',
entry_price=100.0,
current_price=110.0,
entry_time=datetime.now()
)
assert control.check(profit_position) == True
# 亏损持仓应该触发止损
loss_position = Position(
code='AAPL',
entry_price=100.0,
current_price=90.0, # 亏损10%
entry_time=datetime.now()
)
# 亏损超过阈值应该触发
assert control.check(loss_position) == False
def test_trailing_stop_loss(self):
"""测试跟踪止损"""
control = StopLossControl(threshold=-0.05, trailing=True, trailing_percent=0.03)
position = Position(
code='AAPL',
entry_price=100.0,
current_price=105.0,
entry_time=datetime.now()
)
# 初始最高价=入场价100当前价105
# 更新后最高价=105
control.check(position)
# 从最高价回撤不超过阈值应该通过
position.current_price = 103.0 # 回撤约2%
assert control.check(position) == True
# 回撤超过阈值应该触发
position.current_price = 100.0 # 回撤约5%
assert control.check(position) == False
class TestPositionLimitControl:
"""测试仓位限制控制"""
def test_position_limit_init(self):
"""测试初始化"""
control = PositionLimitControl(max_position=0.33)
assert control.max_position == 0.33
assert control.name == "position_limit"
def test_position_limit_check(self):
"""测试仓位限制检查"""
control = PositionLimitControl(max_position=0.33)
# 正常仓位应该通过
normal_position = Position(
code='AAPL',
entry_price=100.0,
current_price=105.0,
entry_time=datetime.now(),
weight=0.20
)
assert control.check(normal_position) == True
# 超限仓位应该触发
over_position = Position(
code='AAPL',
entry_price=100.0,
current_price=105.0,
entry_time=datetime.now(),
weight=0.50
)
assert control.check(over_position) == False
class TestPremiumControl:
"""测试溢价控制"""
def test_premium_control_init(self):
"""测试初始化"""
control = PremiumControl(threshold=0.10)
assert control.threshold == 0.10
assert control.name == "premium"
def test_premium_filter_mode(self):
"""测试溢价过滤模式"""
control = PremiumControl(threshold=0.10, mode='filter')
position = Position(
code='AAPL',
entry_price=100.0,
current_price=105.0,
entry_time=datetime.now()
)
# 正常溢价应该通过
assert control.check(position, premium=0.05) == True
# 高溢价应该被过滤
assert control.check(position, premium=0.15) == False
class TestCallbackHook:
"""测试回调钩子"""
def test_callback_hook_init(self):
"""测试初始化"""
hook = CallbackHook()
assert len(hook.list_hooks()) == 6
def test_register_callback(self):
"""测试注册回调"""
hook = CallbackHook()
def my_callback(code, price, **kwargs):
return True
hook.register('before_entry', my_callback)
callbacks = hook.get_callbacks('before_entry')
assert len(callbacks) == 1
def test_trigger_before_entry(self):
"""测试触发入场前回调"""
hook = CallbackHook()
def always_pass(code, price, **kwargs):
return True
def always_block(code, price, **kwargs):
return False
hook.register('before_entry', always_pass)
hook.register('before_entry', always_block)
# before_entry要求所有回调返回True才允许
result = hook.trigger('before_entry', 'AAPL', 100.0)
assert result == False
def test_trigger_dynamic_stoploss(self):
"""测试触发动态止损回调"""
hook = CallbackHook()
def stoploss_5p(position):
return -0.05
def stoploss_3p(position):
return -0.03
hook.register('dynamic_stoploss', stoploss_5p)
hook.register('dynamic_stoploss', stoploss_3p)
position = Position(
code='AAPL',
entry_price=100.0,
current_price=105.0,
entry_time=datetime.now()
)
# dynamic_stoploss返回最小止损值最严格
result = hook.trigger('dynamic_stoploss', position)
assert result == -0.05
def test_trigger_custom_exit(self):
"""测试触发自定义出场回调"""
hook = CallbackHook()
def exit_on_loss(position):
return position.profit_ratio < -0.05
def exit_on_profit(position):
return position.profit_ratio > 0.20
hook.register('custom_exit', exit_on_loss)
hook.register('custom_exit', exit_on_profit)
# custom_exit任一回调触发即可
profit_position = Position(
code='AAPL',
entry_price=100.0,
current_price=125.0, # 盈利25%
entry_time=datetime.now()
)
result = hook.trigger('custom_exit', profit_position)
assert result == True
# 未触发
normal_position = Position(
code='AAPL',
entry_price=100.0,
current_price=110.0,
entry_time=datetime.now()
)
result = hook.trigger('custom_exit', normal_position)
assert result == False
def test_clear_hooks(self):
"""测试清空回调"""
hook = CallbackHook()
def callback(code, price, **kwargs):
return True
hook.register('before_entry', callback)
# 清空特定钩子
hook.clear('before_entry')
assert len(hook.get_callbacks('before_entry')) == 0
# 清空所有钩子
hook.register('before_entry', callback)
hook.register('after_entry', callback)
hook.clear()
assert len(hook.get_callbacks('before_entry')) == 0
assert len(hook.get_callbacks('after_entry')) == 0
def test_default_behavior(self):
"""测试默认行为"""
hook = CallbackHook() # 无注册回调
# before_entry默认允许
result = hook.trigger('before_entry', 'AAPL', 100.0)
assert result == True
# dynamic_stoploss默认值
result = hook.trigger('dynamic_stoploss', None, default_stoploss=-0.05)
assert result == -0.05
# custom_exit默认不出场
position = Position(
code='AAPL',
entry_price=100.0,
current_price=105.0,
entry_time=datetime.now()
)
result = hook.trigger('custom_exit', position)
assert result == False
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@@ -0,0 +1,163 @@
"""
信号层测试
测试SignalGenerator抽象接口
"""
import pandas as pd
import numpy as np
import pytest
from framework.signals import SignalGenerator
from strategies.shared.signals.selectors import TopNSelector, TrendFollower, ReversalTrader
class TestTopNSelector:
"""测试TopNSelector"""
def test_top_n_selector_init(self):
"""测试初始化"""
selector = TopNSelector(select_num=3)
assert selector.select_num == 3
assert selector.mode == "top_n"
def test_top_n_selection(self):
"""测试Top N选股"""
dates = pd.date_range('2020-01-01', periods=50)
# 生成因子数据
data = pd.DataFrame({
'factor_A': np.random.randn(50),
'factor_B': np.random.randn(50),
'factor_C': np.random.randn(50),
}, index=dates)
selector = TopNSelector(select_num=2)
result = selector.generate(data)
# 检查结果列
assert 'signal' in result.columns
assert 'signal_raw' in result.columns
# 检查T+1移位signal比signal_raw滞后1天
assert result['signal'].iloc[0] == '' or pd.isna(result['signal'].iloc[0])
def test_min_score_filter(self):
"""测试最小得分过滤"""
dates = pd.date_range('2020-01-01', periods=50)
# 生成因子数据,部分为负值
data = pd.DataFrame({
'factor_A': [0.1] * 50,
'factor_B': [-0.1] * 50, # 负分
'factor_C': [0.2] * 50,
}, index=dates)
selector = TopNSelector(select_num=2, min_score=0.0)
result = selector.generate(data)
# 负分因子应该被过滤
signals = result['signal_raw'].dropna().unique()
for sig in signals:
if sig:
codes = sig.split(',')
assert 'factor_B' not in codes
def test_grouped_selection(self):
"""测试分组选股"""
dates = pd.date_range('2020-01-01', periods=50)
# 生成因子数据和分组信息
data = pd.DataFrame({
'factor_A': [0.1] * 50,
'factor_B': [0.2] * 50,
'factor_C': [0.15] * 50,
}, index=dates)
# 分组信息(模拟)
# 注实际使用需要在数据中包含group_info列
selector = TopNSelector(select_num=2, group_by='market')
result = selector.generate(data)
assert 'signal' in result.columns
class TestTrendFollower:
"""测试趋势跟随器"""
def test_trend_follower_init(self):
"""测试初始化"""
follower = TrendFollower(entry_threshold=0.02, exit_threshold=-0.02)
assert follower.entry_threshold == 0.02
assert follower.exit_threshold == -0.02
assert follower.mode == "trend"
def test_trend_signal_generation(self):
"""测试趋势信号生成"""
dates = pd.date_range('2020-01-01', periods=50)
# 生成趋势因子数据
data = pd.DataFrame({
'trend_A': [0.05] * 50, # 强趋势
'trend_B': [-0.05] * 50, # 弱趋势
}, index=dates)
follower = TrendFollower(entry_threshold=0.02)
result = follower.generate(data)
# 检查信号列
assert 'signal' in result.columns
class TestReversalTrader:
"""测试反转交易器"""
def test_reversal_trader_init(self):
"""测试初始化"""
trader = ReversalTrader(overbought=70, oversold=30)
assert trader.overbought == 70
assert trader.oversold == 30
assert trader.mode == "reversal"
def test_reversal_signal_generation(self):
"""测试反转信号生成"""
dates = pd.date_range('2020-01-01', periods=50)
# 生成反转因子数据
data = pd.DataFrame({
'reversal_A': [0.15] * 50, # 超卖反转
'reversal_B': [-0.15] * 50, # 超买反转
}, index=dates)
trader = ReversalTrader(reversal_threshold=0.1)
result = trader.generate(data)
# 检查信号列
assert 'signal' in result.columns
class TestSignalGeneratorBase:
"""测试SignalGenerator抽象基类"""
def test_validate_factor_data(self):
"""测试数据验证"""
selector = TopNSelector(select_num=3)
# 空数据应该返回False
empty_data = pd.DataFrame()
assert selector.validate_factor_data(empty_data) == False
# 有效数据应该返回True
valid_data = pd.DataFrame({'factor_A': [1, 2, 3]})
assert selector.validate_factor_data(valid_data) == True
def test_repr(self):
"""测试字符串表示"""
selector = TopNSelector(select_num=3, min_score=0.5)
repr_str = repr(selector)
assert "TopNSelector" in repr_str
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@@ -0,0 +1,194 @@
"""
策略层测试
测试StrategyBase抽象接口
"""
import pandas as pd
import numpy as np
import pytest
from datetime import datetime
from framework.strategy import StrategyBase
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
from framework.signals import SignalGenerator
from framework.risk import Position
class TestStrategyBase:
"""测试StrategyBase抽象基类"""
def test_strategy_config_override(self):
"""测试配置覆盖类属性"""
from strategies.rotation.strategy import RotationStrategy
strategy = RotationStrategy(config={'select_num': 5, 'stoploss': -0.03})
assert strategy.select_num == 5
assert strategy.stoploss == -0.03
def test_strategy_default_values(self):
"""测试默认值"""
from strategies.rotation.strategy import RotationStrategy
strategy = RotationStrategy()
assert strategy.select_num == 3
assert strategy.stoploss == -0.05
def test_strategy_repr(self):
"""测试字符串表示"""
from strategies.rotation.strategy import RotationStrategy
strategy = RotationStrategy()
repr_str = repr(strategy)
assert 'RotationStrategy' in repr_str
assert 'rotation' in repr_str
def test_strategy_interface_version(self):
"""测试接口版本"""
from strategies.rotation.strategy import RotationStrategy
strategy = RotationStrategy()
assert strategy.INTERFACE_VERSION == 1
class TestRotationStrategy:
"""测试轮动策略"""
def test_rotation_strategy_init(self):
"""测试初始化"""
from strategies.rotation.strategy import RotationStrategy
FactorRegistry.clear()
strategy = RotationStrategy()
# 检查因子初始化
assert strategy._factors is not None
assert strategy._signal_gen is not None
def test_rotation_strategy_run(self):
"""测试策略运行"""
from strategies.rotation.strategy import RotationStrategy
FactorRegistry.clear()
strategy = RotationStrategy()
# 生成测试数据
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100
}, index=dates)
result = strategy.run(data)
# 检查结果
assert 'signal' in result.columns
def test_dynamic_stoploss(self):
"""测试动态止损"""
from strategies.rotation.strategy import RotationStrategy
FactorRegistry.clear()
strategy = RotationStrategy()
# 测试不同持仓时间
position_5days = Position(
code='AAPL',
entry_price=100.0,
current_price=95.0,
entry_time=datetime.now() - pd.Timedelta(days=5)
)
# 5天持仓止损阈值应该为-0.05
stoploss = strategy.dynamic_stoploss(position_5days)
assert stoploss == -0.05
def test_before_entry_premium_filter(self):
"""测试入场前溢价过滤"""
from strategies.rotation.strategy import RotationStrategy
FactorRegistry.clear()
strategy = RotationStrategy()
# 正常溢价应该通过
result = strategy.before_entry('AAPL', 100.0, premium=0.05)
assert result == True
# 高溢价应该被拒绝
result = strategy.before_entry('AAPL', 100.0, premium=0.15)
assert result == False
def test_custom_exit(self):
"""测试自定义出场"""
from strategies.rotation.strategy import RotationStrategy
FactorRegistry.clear()
strategy = RotationStrategy()
# 正常盈亏不触发
normal_position = Position(
code='AAPL',
entry_price=100.0,
current_price=95.0,
entry_time=datetime.now()
)
result = strategy.custom_exit(normal_position)
assert result == False
# 大亏损触发出场
loss_position = Position(
code='AAPL',
entry_price=100.0,
current_price=85.0,
entry_time=datetime.now()
)
result = strategy.custom_exit(loss_position)
assert result == True
class TestStrategyCallbacks:
"""测试策略回调机制"""
def test_callback_registration(self):
"""测试回调自动注册"""
from strategies.rotation.strategy import RotationStrategy
FactorRegistry.clear()
strategy = RotationStrategy()
# 检查回调是否注册
callbacks = strategy._callbacks.get_callbacks('before_entry')
assert len(callbacks) > 0
callbacks = strategy._callbacks.get_callbacks('dynamic_stoploss')
assert len(callbacks) > 0
def test_callback_trigger_in_run(self):
"""测试回调在策略运行中触发"""
from strategies.rotation.strategy import RotationStrategy
FactorRegistry.clear()
strategy = RotationStrategy()
# 添加自定义回调
call_count = {'count': 0}
def counting_callback(code, price, **kwargs):
call_count['count'] += 1
return True
strategy._callbacks.register('before_entry', counting_callback)
# 运行策略
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100
}, index=dates)
strategy.run(data)
if __name__ == '__main__':
pytest.main([__file__, '-v'])