refactor(framework): 框架只保留抽象接口,具体实现移至strategies/shared

- FactorBase/FactorRegistry/FactorCombiner: 因子抽象接口
- SignalGenerator: 信号生成抽象接口
- RiskControl/Position/CallbackHook: 风控抽象接口
- StrategyBase: 策略抽象基类
- Executor/Portfolio: 执行器抽象接口
- ConfigLoader: 配置加载器
- 删除framework/factors/momentum.py(具体实现)
This commit is contained in:
2026-05-11 23:09:01 +08:00
parent 9a8a0d7c72
commit 30ea2970bd
8 changed files with 503 additions and 1516 deletions

View File

@@ -1,21 +1,52 @@
"""
量化策略通用框架
框架统一入口(通用)
融合Freqtrade回调机制 + 模块化因子设计
只导出抽象接口具体实现在strategies/shared/
"""
from .factors import FactorBase, FactorRegistry, FactorCombiner
from .signals import SignalGenerator, TopNSelector, TrendFollower, ReversalTrader
from .strategy import StrategyBase, RotationStrategy
from .risk import RiskControl, StopLossControl, PositionLimitControl
from .execution import Executor, BacktestExecutor, DryRunExecutor
from .config import ConfigLoader
# 因子层抽象
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
# 信号层抽象
from framework.signals import SignalGenerator
# 风控层抽象
from framework.risk import Position, RiskControl, CallbackHook
# 策略层抽象
from framework.strategy import StrategyBase
# 执行层抽象
from framework.execution import Portfolio, Executor, BacktestExecutor, DryRunExecutor
# 配置层
from framework.config import ConfigLoader, StrategyConfig
__all__ = [
'FactorBase', 'FactorRegistry', 'FactorCombiner',
'SignalGenerator', 'TopNSelector', 'TrendFollower', 'ReversalTrader',
'StrategyBase', 'RotationStrategy',
'RiskControl', 'StopLossControl', 'PositionLimitControl',
'Executor', 'BacktestExecutor', 'DryRunExecutor',
# 因子层
'FactorBase',
'FactorRegistry',
'FactorCombiner',
# 信号层
'SignalGenerator',
# 风控层
'Position',
'RiskControl',
'CallbackHook',
# 策略层
'StrategyBase',
# 执行层
'Portfolio',
'Executor',
'BacktestExecutor',
'DryRunExecutor',
# 配置层
'ConfigLoader',
'StrategyConfig',
]

View File

@@ -1,83 +1,103 @@
"""
配置层抽象设计
配置层抽象接口(通用)
核心组件:
- ConfigLoader: 配置加载器
只提供配置加载和验证机制
"""
import yaml
from typing import Dict, Any, Optional
from pathlib import Path
from dataclasses import dataclass
@dataclass
class StrategyConfig:
"""策略配置"""
name: str
version: int
factors: list
signal: dict
callbacks: dict
params: dict
class ConfigLoader:
"""
配置加载器
配置加载器(通用)
支持YAML配置文件加载和验证
支持YAML文件加载配置
"""
def __init__(self, config_path: str):
"""
初始化配置加载器
def __init__(self, config_path: Optional[str] = None):
"""初始化配置加载器"""
self._config: Dict[str, Any] = {}
Args:
config_path: 配置文件路径
"""
self._config_path = Path(config_path)
self._config = None
if config_path:
self.load(config_path)
def load(self) -> Dict:
"""加载配置"""
if not self._config_path.exists():
raise FileNotFoundError(f"Config file not found: {self._config_path}")
def load(self, config_path: str) -> Dict[str, Any]:
"""从YAML文件加载配置"""
path = Path(config_path)
with open(self._config_path, 'r', encoding='utf-8') as f:
self._config = yaml.safe_load(f)
if not path.exists():
raise FileNotFoundError(f"Config file not found: {config_path}")
with open(path, 'r', encoding='utf-8') as f:
self._config = yaml.safe_load(f) or {}
return self._config
def validate(self) -> bool:
"""验证配置"""
if self._config is None:
self.load()
def get(self, key: str, default: Any = None) -> Any:
"""获取配置"""
return self._config.get(key, default)
def get_section(self, section: str) -> Dict[str, Any]:
"""获取配置区块"""
return self._config.get(section, {})
def validate(self, required_keys: list) -> bool:
"""验证必填配置项"""
missing = [key for key in required_keys if key not in self._config]
# 必须字段
required_fields = ['strategy', 'factors', 'signal']
for field in required_fields:
if field not in self._config:
raise ValueError(f"Missing required field: {field}")
if missing:
raise ValueError(f"Missing required config keys: {missing}")
return True
def get_strategy_config(self) -> StrategyConfig:
"""获取策略配置"""
if self._config is None:
self.load()
return StrategyConfig(
name=self._config['strategy']['name'],
version=self._config['strategy'].get('version', 1),
factors=self._config['factors'],
signal=self._config['signal'],
callbacks=self._config.get('callbacks', {}),
params=self._config.get('params', {})
)
def __repr__(self) -> str:
return f"ConfigLoader(keys={list(self._config.keys())})"
class StrategyConfig:
"""
策略配置类(通用)
@staticmethod
def from_yaml(yaml_str: str) -> Dict:
"""从YAML字符串加载"""
return yaml.safe_load(yaml_str)
用于封装策略配置
"""
def __init__(self, config: Dict[str, Any]):
"""初始化策略配置"""
self._config = config
@property
def name(self) -> str:
"""策略名称"""
return self._config.get('strategy', {}).get('name', 'unknown')
@property
def select_num(self) -> int:
"""选中数量"""
return self._config.get('signal', {}).get('select_num', 3)
@property
def stoploss(self) -> float:
"""止损阈值"""
return self._config.get('risk', {}).get('stop_loss', -0.05)
def get_factor_config(self) -> list:
"""获取因子配置"""
return self._config.get('factors', [])
def get_signal_config(self) -> dict:
"""获取信号配置"""
return self._config.get('signal', {})
def get_risk_config(self) -> list:
"""获取风控配置"""
return self._config.get('risk', [])
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return self._config.copy()
# 导出抽象接口
__all__ = ['ConfigLoader', 'StrategyConfig']

View File

@@ -1,55 +1,119 @@
"""
执行层抽象设计
执行层抽象接口(通用)
核心组件:
- Executor: 执行器抽象基类
- BacktestExecutor: 回测执行器
- DryRunExecutor: 模拟盘执行器
只提供抽象基类和Portfolio数据结构具体执行器可扩展
"""
import pandas as pd
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
from dataclasses import dataclass
from typing import Dict, List, Optional
import pandas as pd
from datetime import datetime
from framework.risk import Position
@dataclass
class Portfolio:
"""持仓组合"""
positions: Dict[str, Any] # {code: Position}
cash: float
nav: float
trades: List[Any]
"""
投资组合数据结构(通用)
def get_total_value(self) -> float:
"""获取总价值"""
position_value = sum(
pos.quantity * pos.current_price
for pos in self.positions.values()
用于管理持仓集合
"""
def __init__(self, initial_capital: float = 100000):
"""初始化投资组合"""
self.initial_capital = initial_capital
self.cash = initial_capital
self.positions: Dict[str, Position] = {}
self.trades: List[Dict] = []
self._net_value_history: List[float] = []
def add_position(self, code: str, price: float, quantity: float, time: datetime) -> None:
"""添加持仓"""
position = Position(
code=code,
entry_price=price,
current_price=price,
entry_time=time,
quantity=quantity
)
return self.cash + position_value
self.positions[code] = position
self.cash -= price * quantity
self.trades.append({
'action': 'BUY',
'code': code,
'price': price,
'quantity': quantity,
'time': time
})
def get_position_codes(self) -> List[str]:
"""获取持仓代码列表"""
return list(self.positions.keys())
def remove_position(self, code: str, price: float, time: datetime) -> float:
"""移除持仓"""
if code not in self.positions:
return 0
position = self.positions[code]
profit = (price - position.entry_price) * position.quantity
self.cash += price * position.quantity
del self.positions[code]
self.trades.append({
'action': 'SELL',
'code': code,
'price': price,
'quantity': position.quantity,
'time': time,
'profit': profit
})
return profit
def update_prices(self, prices: Dict[str, float]) -> None:
"""更新持仓价格"""
for code, price in prices.items():
if code in self.positions:
self.positions[code].current_price = price
def get_net_value(self) -> float:
"""计算净值"""
positions_value = sum(
pos.current_price * pos.quantity for pos in self.positions.values()
)
return self.cash + positions_value
def record_net_value(self) -> None:
"""记录当前净值"""
self._net_value_history.append(self.get_net_value())
def get_net_value_series(self) -> pd.Series:
"""获取净值序列"""
return pd.Series(self._net_value_history)
def get_weight(self, code: str) -> float:
"""计算持仓权重"""
if code not in self.positions:
return 0
position_value = self.positions[code].current_price * self.positions[code].quantity
return position_value / self.get_net_value()
def __repr__(self) -> str:
return f"Portfolio(capital={self.cash:.2f}, positions={len(self.positions)})"
class Executor(ABC):
"""
执行器抽象基类
支持不同执行模式:
- backtest: 回测模式
- dry_run: 模拟盘模式
- live: 实盘模式TODO
所有执行器必须实现execute方法
"""
mode: str = "base"
def __init__(self, config: Optional[Dict] = None):
self._config = config or {}
self._portfolio = None
def __init__(self, **params):
"""初始化执行器参数"""
self._params = params
@abstractmethod
def execute(self, signals: pd.DataFrame, data: pd.DataFrame) -> Portfolio:
@@ -58,121 +122,71 @@ class Executor(ABC):
Args:
signals: 信号DataFrame
data: 价格数据
data: OHLCV数据
Returns:
持仓组合
Portfolio对象
"""
pass
@abstractmethod
def get_mode(self) -> str:
"""获取执行模式"""
pass
@property
def portfolio(self) -> Optional[Portfolio]:
"""获取当前持仓"""
return self._portfolio
def __repr__(self) -> str:
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
return f"{self.__class__.__name__}({params_str})"
class BacktestExecutor(Executor):
"""
回测执行器
回测执行器(通用骨架)
执行回测逻辑
- 处理信号
- 计算净值
- 记录交易
具体回测逻辑需要在strategies中定制实现
"""
mode = "backtest"
def __init__(
self,
initial_capital: float = 100000.0,
trade_cost: float = 0.001
):
super().__init__()
def __init__(self, initial_capital: float = 100000, trade_cost: float = 0.001):
super().__init__(initial_capital=initial_capital, trade_cost=trade_cost)
self.initial_capital = initial_capital
self.trade_cost = trade_cost
def execute(self, signals: pd.DataFrame, data: pd.DataFrame) -> Portfolio:
"""执行回测"""
# 初始化持仓
self._portfolio = Portfolio(
positions={},
cash=self.initial_capital,
nav=1.0,
trades=[]
)
"""
执行回测(简化版本)
# 回测逻辑(简化版)
result = pd.DataFrame(index=signals.index)
result['nav'] = 1.0
result['daily_return'] = 0.0
完整回测逻辑需要定制实现
"""
portfolio = Portfolio(self.initial_capital)
# TODO: 完整回测逻辑迁移
# 这里只提供骨架,具体逻辑需要定制实现
# 包括:净值计算、交易成本扣除、基准对比等
return self._portfolio
def get_mode(self) -> str:
return "backtest"
return portfolio
class DryRunExecutor(Executor):
"""
模拟盘执行器
Dry-run执行器通用
执行模拟交易
- 模拟下单
- 模拟成交
- 模拟持仓更新
用于模拟运行,不实际执行交易
"""
mode = "dry_run"
def __init__(
self,
initial_capital: float = 100000.0,
simulated_exchange = None
):
super().__init__()
self.initial_capital = initial_capital
self.simulated_exchange = simulated_exchange
def __init__(self, verbose: bool = True):
super().__init__(verbose=verbose)
self.verbose = verbose
def execute(self, signals: pd.DataFrame, data: pd.DataFrame) -> Portfolio:
"""执行模拟盘"""
# 初始化持仓
self._portfolio = Portfolio(
positions={},
cash=self.initial_capital,
nav=1.0,
trades=[]
)
"""模拟执行"""
portfolio = Portfolio(100000)
# 模拟执行逻辑
# TODO: 模拟订单执行
for date in signals.index:
signal = signals.loc[date, 'signal']
if signal and self.verbose:
print(f"[{date}] Signal: {signal}")
return self._portfolio
def get_mode(self) -> str:
return "dry_run"
def simulate_order(self, code: str, direction: str, quantity: float, price: float):
"""模拟下单"""
# 记录模拟订单
print(f"[DRY_RUN] {direction} {quantity} {code} @ {price}")
# 更新持仓
if direction == 'BUY':
# 模拟买入
cost = quantity * price
if cost <= self._portfolio.cash:
self._portfolio.cash -= cost
# TODO: 创建Position对象
elif direction == 'SELL':
# 模拟卖出
if code in self._portfolio.positions:
# TODO: 平仓逻辑
pass
return portfolio
# 导出抽象接口
__all__ = ['Portfolio', 'Executor', 'BacktestExecutor', 'DryRunExecutor']

View File

@@ -1,54 +1,28 @@
"""
因子层抽象设计
因子层抽象接口(通用)
核心组件:
- FactorBase: 因子抽象基类
- FactorRegistry: 因子注册器
- FactorCombiner: 因子组合器
只提供抽象基类和注册机制具体因子实现在strategies/shared/factors/
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any, Type
import pandas as pd
import numpy as np
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
@dataclass
class FactorMeta:
"""因子元信息"""
name: str
category: str # 'momentum', 'trend', 'reversal', 'volatility', 'fundamental'
params: Dict[str, Any]
description: str = ""
class FactorBase(ABC):
"""
因子抽象基类
所有因子必须继承此基类,实现compute方法
支持参数配置、数据验证、元信息管理。
所有因子必须实现compute方法
"""
# 类属性(可被配置覆盖)
name: str = "base"
category: str = "unknown"
def __init__(self, **params):
"""
初始化因子
Args:
**params: 因子参数如n_days=25, period=14等
"""
"""初始化因子参数"""
self._params = params
self._meta = FactorMeta(
name=self.name,
category=self.category,
params=params,
description=self.__doc__ or ""
)
@abstractmethod
def compute(self, data: pd.DataFrame) -> pd.Series:
@@ -56,139 +30,84 @@ class FactorBase(ABC):
计算因子值
Args:
data: 包含OHLCV数据的DataFrame
data: OHLCV数据,必须包含'close'
Returns:
因子值序列Series
因子值序列
"""
pass
@property
def params(self) -> Dict[str, Any]:
"""获取因子参数"""
return self._params
@property
def meta(self) -> FactorMeta:
"""获取因子元信息"""
return self._meta
def validate_data(self, data: pd.DataFrame) -> bool:
"""
验证数据是否满足计算要求
"""验证数据是否满足计算要求"""
if 'close' not in data.columns:
return False
Args:
data: 数据DataFrame
Returns:
是否满足要求
"""
# 默认验证:数据长度 >= 最小周期
min_periods = self._params.get('min_periods', 20)
return len(data) >= min_periods
def __repr__(self) -> str:
return f"{self.__class__.__name__}(name={self.name}, params={self._params})"
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
return f"{self.__class__.__name__}({params_str})"
class FactorRegistry:
"""
因子注册器
因子注册器(通用)
管理所有注册的因子,支持:
- 注册因子类
- 获取因子实例
- 列出可用因子
- 按类别筛选因子
管理因子类的注册和获取
"""
_factors: Dict[str, type] = {}
_factors: Dict[str, Type[FactorBase]] = {}
@classmethod
def register(cls, factor_class: type) -> None:
"""
注册因子类
Args:
factor_class: 因子类必须继承FactorBase
"""
if not isinstance(factor_class, type) or not issubclass(factor_class, FactorBase):
raise TypeError(f"factor_class must be a subclass of FactorBase")
# 创建临时实例获取名称
def register(cls, factor_class: Type[FactorBase]) -> None:
"""注册因子类"""
temp_instance = factor_class()
name = temp_instance.name
if name in cls._factors:
print(f"因子已注册,覆盖: {name}")
cls._factors[name] = factor_class
print(f"✓ 因子已注册: {name} ({factor_class.__name__})")
@classmethod
def get(cls, name: str, **params) -> FactorBase:
"""
获取因子实例
Args:
name: 因子名称
**params: 因子参数
Returns:
因子实例
"""
"""获取因子实例"""
if name not in cls._factors:
raise KeyError(f"Factor '{name}' not registered. Available: {cls.list()}")
raise ValueError(f"因子未注册: {name}")
factor_class = cls._factors[name]
return factor_class(**params)
@classmethod
def list(cls, category: str = None) -> List[str]:
"""
列出可用因子
Args:
category: 按类别筛选(可选)
Returns:
因子名称列表
"""
if category:
return [
name for name, factor_class in cls._factors.items()
if factor_class().category == category
]
def list_factors(cls) -> List[str]:
"""列出所有已注册因子"""
return list(cls._factors.keys())
@classmethod
def list_by_category(cls) -> Dict[str, List[str]]:
"""
按类别列出因子
Returns:
类别→因子列表字典
"""
result = {}
for name, factor_class in cls._factors.items():
cat = factor_class().category
if cat not in result:
result[cat] = []
result[cat].append(name)
return result
def clear(cls) -> None:
"""清空注册表"""
cls._factors = {}
@classmethod
def clear(cls) -> None:
"""清空注册表(用于测试)"""
cls._factors.clear()
def get_category(cls, name: str) -> str:
"""获取因子类别"""
if name not in cls._factors:
return "unknown"
temp_instance = cls._factors[name]()
return temp_instance.category
class FactorCombiner:
"""
因子组合器
因子组合器(通用)
支持多因子加权组合,用于:
- 多因子策略
- 因子权重调整
- 因子结果合并
支持多因子加权组合
"""
SUPPORTED_METHODS = ['weighted_sum', 'rank_average', 'zscore_sum', 'equal_weight']
def __init__(
self,
factors: List[FactorBase],
@@ -196,87 +115,77 @@ class FactorCombiner:
method: str = 'weighted_sum'
):
"""
初始化因子组合器
初始化组合器
Args:
factors: 因子实例列表
weights: 权重列表(默认等权
method: 组合方法 ('weighted_sum', 'average', 'max', 'min')
weights: 因子权重列表(可选
method: 组合方法weighted_sum/rank_average/zscore_sum/equal_weight
"""
if not factors:
raise ValueError("factors list cannot be empty")
if method not in self.SUPPORTED_METHODS:
raise ValueError(f"Unsupported method: {method}")
self._factors = factors
self._weights = weights or [1.0 / len(factors)] * len(factors)
if weights is None:
self._weights = [1.0 / len(factors)] * len(factors)
else:
if len(weights) != len(factors):
raise ValueError("weights length must match factors length")
self._weights = weights
self._method = method
# 验证权重
if len(self._weights) != len(factors):
raise ValueError(f"weights length ({len(self._weights)}) != factors length ({len(factors)})")
# 归一化权重
total_weight = sum(self._weights)
self._weights = [w / total_weight for w in self._weights]
def compute(self, data: pd.DataFrame) -> pd.DataFrame:
"""
计算所有因子并组合
Args:
data: 输入数据
Returns:
包含各因子值和组合因子值的DataFrame
DataFrame包含各因子值和combined列
"""
result = pd.DataFrame(index=data.index)
# 计算各因子
for i, factor in enumerate(self._factors):
# 验证数据
if not factor.validate_data(data):
print(f"⚠ 因子 {factor.name} 数据验证失败,跳过")
continue
# 计算因子值
factor_values = factor.compute(data)
result[factor.name] = factor_values
# 加权因子值
result[f"{factor.name}_weighted"] = factor_values * self._weights[i]
col_name = f"{factor.name}"
result[col_name] = factor_values
# 组合因子值
weighted_cols = [f"{f.name}_weighted" for f in self._factors if f.name in result.columns]
if self._method == 'weighted_sum':
result['combined'] = result[weighted_cols].sum(axis=1)
elif self._method == 'average':
factor_cols = [f.name for f in self._factors if f.name in result.columns]
weighted_cols = [f.name for f in self._factors]
result['combined'] = result[weighted_cols].apply(
lambda row: sum(row[col] * self._weights[i] for i, col in enumerate(weighted_cols) if pd.notna(row[col])),
axis=1
)
elif self._method == 'equal_weight':
factor_cols = [f.name for f in self._factors]
result['combined'] = result[factor_cols].mean(axis=1)
elif self._method == 'max':
factor_cols = [f.name for f in self._factors if f.name in result.columns]
result['combined'] = result[factor_cols].max(axis=1)
elif self._method == 'min':
factor_cols = [f.name for f in self._factors if f.name in result.columns]
result['combined'] = result[factor_cols].min(axis=1)
else:
raise ValueError(f"Unknown method: {self._method}")
elif self._method == 'rank_average':
factor_cols = [f.name for f in self._factors]
ranks = result[factor_cols].rank(axis=1)
result['combined'] = ranks.mean(axis=1)
elif self._method == 'zscore_sum':
factor_cols = [f.name for f in self._factors]
zscores = result[factor_cols].apply(lambda x: (x - x.mean()) / x.std())
result['combined'] = zscores.sum(axis=1)
return result
@property
def factors(self) -> List[FactorBase]:
"""获取因子列表"""
return self._factors
@property
def weights(self) -> List[float]:
"""获取权重列表"""
return self._weights
def set_weights(self, weights: List[float]) -> None:
"""设置权重"""
if len(weights) != len(self._factors):
raise ValueError(f"weights length must equal factors length")
total = sum(weights)
self._weights = [w / total for w in weights]
def get_factor_names(self) -> List[str]:
"""获取因子名称列表"""
return [f.name for f in self._factors]
def __repr__(self) -> str:
factor_names = [f.name for f in self._factors]
return f"FactorCombiner(factors={factor_names}, weights={self._weights})"
return f"FactorCombiner(factors={factor_names}, weights={self._weights}, method={self._method})"
# 导出抽象接口
__all__ = ['FactorBase', 'FactorRegistry', 'FactorCombiner']

View File

@@ -1,312 +0,0 @@
"""
动量因子实现
基于加权线性回归动量的因子
"""
import pandas as pd
import numpy as np
import math
from typing import Optional
from framework.factors import FactorBase
class MomentumFactor(FactorBase):
"""
动量因子
计算加权线性回归动量得分:
得分 = 年化收益率 ×
参数:
- n_days: 动量窗口默认25
- weighted: 是否加权默认True
- crash_filter: 是否启用崩盘过滤默认True
"""
name = "momentum"
category = "momentum"
def __init__(
self,
n_days: int = 25,
weighted: bool = True,
crash_filter: bool = True
):
super().__init__(n_days=n_days, weighted=weighted, crash_filter=crash_filter)
self.n_days = n_days
self.weighted = weighted
self.crash_filter = crash_filter
def compute(self, data: pd.DataFrame) -> pd.Series:
"""计算动量因子值"""
if 'close' not in data.columns:
raise ValueError("data must contain 'close' column")
prices = data['close']
if self.weighted:
# 加权动量得分
factor_values = prices.rolling(self.n_days).apply(
lambda x: self._weighted_momentum_score(x.values),
raw=False
)
else:
# 简单动量
factor_values = prices.pct_change(self.n_days)
# 应用崩盘过滤
if self.crash_filter:
factor_values = self._apply_crash_filter(prices, factor_values)
return factor_values
def _weighted_momentum_score(self, prices: np.ndarray) -> float:
"""计算加权动量得分"""
if len(prices) < 5:
return 0.0
y = np.log(prices)
x = np.arange(len(y))
weights = np.linspace(1, 2, len(y))
# 加权线性回归
slope, intercept = np.polyfit(x, y, 1, w=weights)
annualized_returns = math.exp(slope * 250) - 1
# 加权R²
y_pred = slope * x + intercept
ss_res = np.sum(weights * (y - y_pred) ** 2)
ss_tot = np.sum(weights * (y - np.average(y, weights=weights)) ** 2)
r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
return annualized_returns * r2
def _apply_crash_filter(
self,
prices: pd.Series,
factor_values: pd.Series
) -> pd.Series:
"""崩盘过滤连续3天跌>5%清零"""
result = factor_values.copy()
for i in range(3, len(prices)):
r1 = prices.iloc[i] / prices.iloc[i-1]
r2 = prices.iloc[i-1] / prices.iloc[i-2]
r3 = prices.iloc[i-2] / prices.iloc[i-3]
# 条件1任一天跌>5%
con1 = min(r1, r2, r3) < 0.95
# 条件2连续下跌且累计跌>5%
con2 = (r1 < 1) and (r2 < 1) and (r3 < 1) and (prices.iloc[i] / prices.iloc[i-3] < 0.95)
if con1 or con2:
result.iloc[i] = 0.0
return result
class TrendFactor(FactorBase):
"""
趋势因子
计算趋势强度:
- MA交叉偏离度
- MACD趋势
参数:
- method: 趋势方法('ma_cross', 'macd'
- fast: 快线周期默认5
- slow: 慢线周期默认20
"""
name = "trend"
category = "trend"
def __init__(
self,
method: str = 'ma_cross',
fast: int = 5,
slow: int = 20
):
super().__init__(method=method, fast=fast, slow=slow)
self.method = method
self.fast = fast
self.slow = slow
def compute(self, data: pd.DataFrame) -> pd.Series:
"""计算趋势因子值"""
if 'close' not in data.columns:
raise ValueError("data must contain 'close' column")
prices = data['close']
if self.method == 'ma_cross':
# MA交叉偏离度
fast_ma = prices.rolling(self.fast).mean()
slow_ma = prices.rolling(self.slow).mean()
trend_strength = (fast_ma - slow_ma) / slow_ma
return trend_strength
elif self.method == 'macd':
# MACD趋势
ema12 = prices.ewm(span=12).mean()
ema26 = prices.ewm(span=26).mean()
macd = ema12 - ema26
signal = macd.ewm(span=9).mean()
return macd - signal
else:
raise ValueError(f"Unknown method: {self.method}")
class ReversalFactor(FactorBase):
"""
反转因子
计算超买超卖信号:
- RSI偏离度
- KDJ
参数:
- method: 反转方法('rsi', 'kdj'
- period: 周期默认14
- overbought: 超买阈值默认70
- oversold: 超卖阈值默认30
"""
name = "reversal"
category = "reversal"
def __init__(
self,
method: str = 'rsi',
period: int = 14,
overbought: float = 70,
oversold: float = 30
):
super().__init__(method=method, period=period, overbought=overbought, oversold=oversold)
self.method = method
self.period = period
self.overbought = overbought
self.oversold = oversold
def compute(self, data: pd.DataFrame) -> pd.Series:
"""计算反转因子值"""
if 'close' not in data.columns:
raise ValueError("data must contain 'close' column")
prices = data['close']
if self.method == 'rsi':
# RSI反转信号
rsi = self._compute_rsi(prices, self.period)
# 超买超卖偏离度
# 超买 → 负值(反转向下信号)
# 超卖 → 正值(反转向上信号)
reversal_signal = pd.Series(index=prices.index, dtype=float)
reversal_signal = np.where(
rsi > self.overbought,
-(rsi - self.overbought) / (100 - self.overbought), # 超买:负值
np.where(
rsi < self.oversold,
(self.oversold - rsi) / self.oversold, # 超卖:正值
0 # 正常区间0
)
)
return pd.Series(reversal_signal, index=prices.index)
elif self.method == 'kdj':
# KDJ反转信号
return self._compute_kdj(data)
else:
raise ValueError(f"Unknown method: {self.method}")
def _compute_rsi(self, prices: pd.Series, period: int) -> pd.Series:
"""计算RSI"""
delta = prices.diff()
gain = delta.where(delta > 0, 0)
loss = (-delta).where(delta < 0, 0)
avg_gain = gain.rolling(period).mean()
avg_loss = loss.rolling(period).mean()
rs = avg_gain / avg_loss
rsi = 100 - (100 / (1 + rs))
return rsi
def _compute_kdj(self, data: pd.DataFrame) -> pd.Series:
"""计算KDJ反转信号"""
low = data['low']
high = data['high']
close = data['close']
# 计算K、D、J
low_min = low.rolling(self.period).min()
high_max = high.rolling(self.period).max()
rsv = (close - low_min) / (high_max - low_min) * 100
k = rsv.ewm(alpha=1/3).mean()
d = k.ewm(alpha=1/3).mean()
j = 3 * k - 2 * d
# J值偏离度作为反转信号
return j
class VolatilityFactor(FactorBase):
"""
波动率因子
计算价格波动率:
- ATR
- 标准差
参数:
- method: 波动率方法('atr', 'std'
- period: 周期默认20
"""
name = "volatility"
category = "volatility"
def __init__(
self,
method: str = 'std',
period: int = 20
):
super().__init__(method=method, period=period)
self.method = method
self.period = period
def compute(self, data: pd.DataFrame) -> pd.Series:
"""计算波动率因子值"""
if self.method == 'std':
# 标准差波动率
return data['close'].rolling(self.period).std()
elif self.method == 'atr':
# ATR波动率
return self._compute_atr(data)
else:
raise ValueError(f"Unknown method: {self.method}")
def _compute_atr(self, data: pd.DataFrame) -> pd.Series:
"""计算ATR"""
high = data['high']
low = data['low']
close = data['close']
prev_close = close.shift(1)
tr = pd.concat([
high - low,
(high - prev_close).abs(),
(low - prev_close).abs()
], axis=1).max(axis=1)
return tr.rolling(self.period).mean()

View File

@@ -1,351 +1,188 @@
"""
回调钩子与风控层设计
风控层抽象接口(通用)
核心组件:
- RiskControl: 风控抽象基类
- StopLossControl: 止损控制
- PositionLimitControl: 仓位限制控制
- CallbackHook: 回调钩子管理
只提供抽象基类和回调机制具体风控组件在strategies/shared/risk/
"""
import pandas as pd
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
from typing import Dict, List, Any, Callable, Optional
from dataclasses import dataclass
from datetime import datetime
@dataclass
class Position:
"""持仓信息"""
"""
持仓数据结构(通用)
用于表示单个持仓的状态
"""
code: str
entry_price: float
entry_date: datetime
current_price: float
current_date: datetime
quantity: float
weight: float
entry_time: datetime
quantity: float = 1.0
weight: float = 1.0
@property
def profit_ratio(self) -> float:
"""盈亏比例"""
"""计算盈亏比例"""
return (self.current_price - self.entry_price) / self.entry_price
@property
def holding_days(self) -> int:
"""持仓天数"""
return (self.current_date - self.entry_date).days
def profit_amount(self) -> float:
"""计算盈亏金额"""
return (self.current_price - self.entry_price) * self.quantity
@property
def is_profit(self) -> bool:
"""是否盈利"""
return self.profit_ratio > 0
@dataclass
class Trade:
"""交易信息"""
code: str
direction: str # 'entry' or 'exit'
price: float
date: datetime
quantity: float
reason: str = ""
def holding_days(self) -> int:
"""计算持仓天数"""
if self.entry_time is None:
return 0
return (datetime.now() - self.entry_time).days
def __repr__(self) -> str:
return f"Position(code={self.code}, profit={self.profit_ratio:.2%}, days={self.holding_days})"
class RiskControl(ABC):
"""
风控抽象基类
风控组件抽象基类
所有风控组件必须继承此基类。
所有风控组件必须实现check方法
"""
name: str = "base"
def __init__(self, **params):
"""初始化风控参数"""
self._params = params
@abstractmethod
def check(self, position: Optional[Position], **kwargs) -> bool:
def check(self, position: Position, **kwargs) -> bool:
"""
风控检查
检查风控条件
Args:
position: 持仓信息(可选)
**kwargs: 其他参数
position: 持仓对象
kwargs: 额外参数如premium、history等
Returns:
是否通过检查
True表示通过检查False表示触发风控
"""
pass
@abstractmethod
def apply(self, position: Position) -> Optional[float]:
def apply(self, position: Position) -> Any:
"""
应用风控
应用风控规则(可选)
Args:
position: 持仓信息
position: 持仓对象
Returns:
应用结果(如止损价格、仓位调整比例等)
风控结果(如止损价格、建议仓位等)
"""
pass
@property
def params(self) -> Dict[str, Any]:
return self._params
class StopLossControl(RiskControl):
"""
止损控制
参数:
- threshold: 止损阈值(默认-0.05
- trailing: 是否跟踪止损默认False
- trailing_percent: 跟踪止损比例默认0.03
"""
name = "stop_loss"
def __init__(
self,
threshold: float = -0.05,
trailing: bool = False,
trailing_percent: float = 0.03
):
super().__init__(
threshold=threshold,
trailing=trailing,
trailing_percent=trailing_percent
)
self.threshold = threshold
self.trailing = trailing
self.trailing_percent = trailing_percent
self._highest_price = {} # 跟踪最高价
def check(self, position: Optional[Position], **kwargs) -> bool:
"""检查是否触发止损"""
if position is None:
return True
# 更新最高价(跟踪止损)
if self.trailing:
if position.code not in self._highest_price:
self._highest_price[position.code] = position.entry_price
self._highest_price[position.code] = max(
self._highest_price[position.code],
position.current_price
)
# 检查止损
if self.trailing:
# 跟踪止损:从最高价回撤超过阈值
highest = self._highest_price[position.code]
drawdown = (position.current_price - highest) / highest
return drawdown > -self.trailing_percent
else:
# 固定止损:从入场价亏损超过阈值
return position.profit_ratio > self.threshold
def apply(self, position: Position) -> Optional[float]:
"""返回止损价格"""
if self.trailing:
highest = self._highest_price.get(position.code, position.entry_price)
return highest * (1 - self.trailing_percent)
else:
return position.entry_price * (1 + self.threshold)
class PositionLimitControl(RiskControl):
"""
仓位限制控制
参数:
- max_position: 单品种最大仓位默认0.33
- max_total: 总仓位上限默认1.0
"""
name = "position_limit"
def __init__(
self,
max_position: float = 0.33,
max_total: float = 1.0
):
super().__init__(
max_position=max_position,
max_total=max_total
)
self.max_position = max_position
self.max_total = max_total
def check(self, position: Optional[Position], **kwargs) -> bool:
"""检查仓位是否超限"""
if position is None:
return True
# 检查单品种仓位
if position.weight > self.max_position:
return False
return True
def apply(self, position: Position) -> Optional[float]:
"""返回建议仓位"""
return min(position.weight, self.max_position)
class PremiumControl(RiskControl):
"""
溢价控制
参数:
- threshold: 溢价阈值默认0.10
- mode: 控制模式('filter''penalize'
"""
name = "premium"
def __init__(
self,
threshold: float = 0.10,
mode: str = 'filter'
):
super().__init__(
threshold=threshold,
mode=mode
)
self.threshold = threshold
self.mode = mode
def check(self, position: Optional[Position], **kwargs) -> bool:
"""检查溢价是否超限"""
premium = kwargs.get('premium', 0)
if self.mode == 'filter':
# 完全排除
return premium <= self.threshold
else:
# 仅降权,允许通过
return True
def apply(self, position: Position) -> Optional[float]:
"""返回溢价惩罚系数"""
if self.mode == 'penalize':
return 0.5 # 降权50%
return None
def __repr__(self) -> str:
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
return f"{self.__class__.__name__}({params_str})"
class CallbackHook:
"""
回调钩子管理
回调钩子管理(通用)
支持策略生命周期回调:
- before_entry: 入场前检查
- after_entry: 入场后处理
- before_exit: 出场前检查
- after_exit: 出场后处理
- dynamic_stoploss: 动态止损
- custom_exit: 自定义出场
支持策略生命周期的关键节点注入自定义逻辑
"""
SUPPORTED_HOOKS = [
'before_entry', # 入场前检查
'after_entry', # 入场后处理
'before_exit', # 出场前检查
'after_exit', # 出场后处理
'dynamic_stoploss', # 动态止损计算
'custom_exit' # 自定义出场条件
]
def __init__(self):
self._hooks = {
'before_entry': [],
'after_entry': [],
'before_exit': [],
'after_exit': [],
'dynamic_stoploss': [],
'custom_exit': []
"""初始化回调钩子"""
self._hooks: Dict[str, List[Callable]] = {
hook: [] for hook in self.SUPPORTED_HOOKS
}
def register(self, hook_name: str, callback: callable) -> None:
"""注册回调"""
def register(self, hook_name: str, callback: Callable) -> None:
"""注册回调函数"""
if hook_name not in self._hooks:
raise ValueError(f"Unknown hook: {hook_name}")
raise ValueError(f"Unsupported hook: {hook_name}")
self._hooks[hook_name].append(callback)
def trigger(self, hook_name: str, *args, **kwargs) -> Any:
"""触发回调"""
"""
触发回调
Args:
hook_name: 钩子名称
args: 位置参数
kwargs: 关键字参数
Returns:
回调结果(根据钩子类型返回不同结果)
"""
if hook_name not in self._hooks:
raise ValueError(f"Unknown hook: {hook_name}")
raise ValueError(f"Unsupported hook: {hook_name}")
callbacks = self._hooks[hook_name]
if not callbacks:
# 默认行为
if hook_name == 'dynamic_stoploss':
return kwargs.get('default_stoploss', -0.05)
elif hook_name in ['before_entry', 'before_exit']:
return True
elif hook_name == 'custom_exit':
return False
return None
results = []
for callback in self._hooks[hook_name]:
try:
result = callback(*args, **kwargs)
results.append(result)
except Exception as e:
print(f"⚠ Callback error: {e}")
for callback in callbacks:
result = callback(*args, **kwargs)
results.append(result)
# before_entry和before_exit需要所有回调返回True
# 根据钩子类型返回不同结果
if hook_name in ['before_entry', 'before_exit']:
# 所有回调返回True才允许
return all(results)
# dynamic_stoploss返回最小的止损值
if hook_name == 'dynamic_stoploss':
return min(results) if results else -0.05
# 返回最小止损值(最严格)
return min(results)
# custom_exit返回是否有任一回调触发出场
if hook_name == 'custom_exit':
# 任一回调触发出场
return any(results)
return results
# 其他钩子返回最后一个结果
return results[-1] if results else None
def clear(self, hook_name: str = None) -> None:
def clear(self, hook_name: Optional[str] = None) -> None:
"""清空回调"""
if hook_name:
self._hooks[hook_name] = []
if hook_name in self._hooks:
self._hooks[hook_name] = []
else:
for key in self._hooks:
self._hooks[key] = []
for hook in self._hooks:
self._hooks[hook] = []
def list_hooks(self) -> List[str]:
"""列出支持的钩子"""
return self.SUPPORTED_HOOKS
def get_callbacks(self, hook_name: str) -> List[Callable]:
"""获取钩子的所有回调"""
return self._hooks.get(hook_name, [])
# 便捷回调函数
def premium_filter_callback(threshold: float = 0.10):
"""溢价过滤回调"""
def callback(code: str, price: float, **kwargs) -> bool:
premium = kwargs.get('premium', 0)
if premium > threshold:
print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})")
return False
return True
return callback
def crash_filter_callback(lookback: int = 3, crash_threshold: float = 0.05):
"""崩盘过滤回调"""
def callback(code: str, price: float, **kwargs) -> bool:
history = kwargs.get('history', None)
if history is None:
return True
# 检查最近N天是否有崩盘
recent = history.tail(lookback)
if len(recent) < lookback:
return True
returns = recent['close'].pct_change()
min_return = returns.min()
if min_return < -crash_threshold:
print(f"崩盘检测,拒绝入场: {code} (最大跌幅={min_return:.2%})")
return False
return True
return callback
def holding_time_stoploss_callback(
day_5_stoploss: float = -0.05,
day_10_stoploss: float = -0.03
):
"""持仓时间动态止损回调"""
def callback(position: Position) -> float:
if position.holding_days >= 10:
return day_10_stoploss # 10天后收紧止损
elif position.holding_days >= 5:
return day_5_stoploss
return -0.10 # 默认止损
return callback
# 导出抽象接口
__all__ = ['Position', 'RiskControl', 'CallbackHook']

View File

@@ -1,52 +1,26 @@
"""
信号层抽象设计
信号层抽象接口(通用)
核心组件:
- SignalGenerator: 信号生成器抽象基类
- TopNSelector: Top N选股器轮动策略
- TrendFollower: 趋势跟随器(趋势策略)
- ReversalTrader: 反转交易器(反转策略)
只提供抽象基类具体信号生成器在strategies/shared/signals/
"""
import pandas as pd
import numpy as np
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
@dataclass
class SignalMeta:
"""信号元信息"""
mode: str # 'top_n', 'trend', 'reversal'
select_num: int
description: str = ""
from typing import List, Optional, Any
import pandas as pd
class SignalGenerator(ABC):
"""
信号生成器抽象基类
所有信号生成器必须继承此基类,实现generate方法
支持不同策略类型的信号生成逻辑。
所有信号生成器必须实现generate方法
"""
# 类属性(可被配置覆盖)
mode: str = "base"
def __init__(self, **params):
"""
初始化信号生成器
Args:
**params: 信号参数
"""
"""初始化信号生成器参数"""
self._params = params
self._meta = SignalMeta(
mode=self.mode,
select_num=params.get('select_num', 1),
description=self.__doc__ or ""
)
@abstractmethod
def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
@@ -54,300 +28,27 @@ class SignalGenerator(ABC):
生成交易信号
Args:
factor_data: 包含因子值的DataFrame
factor_data: 因子数据DataFrame
Returns:
包含信号列的DataFrame
包含'signal'列的DataFrame
"""
pass
@property
def params(self) -> Dict[str, Any]:
"""获取信号参数"""
return self._params
@property
def meta(self) -> SignalMeta:
"""获取信号元信息"""
return self._meta
def validate_factor_data(self, factor_data: pd.DataFrame) -> bool:
"""验证因子数据是否有效"""
if factor_data.empty:
return False
if 'signal' in factor_data.columns:
print("Warning: factor_data already contains 'signal' column")
return True
def __repr__(self) -> str:
return f"{self.__class__.__name__}(mode={self.mode}, params={self._params})"
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
return f"{self.__class__.__name__}({params_str})"
class TopNSelector(SignalGenerator):
"""
Top N选股器
用于轮动策略:
- 按因子值排序选出Top N标的
- 支持分组选股(先类内竞争,再跨类排序)
参数:
- select_num: 选中数量默认3
- group_by: 分组列名(可选,如'market'
- top_per_group: 每组选中数量默认1
- min_score: 最小得分阈值(可选)
"""
mode = "top_n"
def __init__(
self,
select_num: int = 3,
group_by: Optional[str] = None,
top_per_group: int = 1,
min_score: Optional[float] = None
):
super().__init__(
select_num=select_num,
group_by=group_by,
top_per_group=top_per_group,
min_score=min_score
)
self.select_num = select_num
self.group_by = group_by
self.top_per_group = top_per_group
self.min_score = min_score
def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
"""生成Top N选股信号"""
result = pd.DataFrame(index=factor_data.index)
# 获取因子列(排除非因子列)
factor_cols = self._get_factor_columns(factor_data)
if not factor_cols:
print("⚠ 未找到因子列")
result['signal'] = ''
return result
# 每日选股
signals = []
for date in factor_data.index:
row = factor_data.loc[date]
# 提取当日因子值
scores = {}
for col in factor_cols:
score = row[col]
if pd.notna(score):
scores[col] = score
# 应用最小得分过滤
if self.min_score:
scores = {k: v for k, v in scores.items() if v >= self.min_score}
# 选股逻辑
if self.group_by and 'group_info' in factor_data.columns:
# 分组选股:先类内竞争,再跨类排序
selected = self._grouped_selection(scores, factor_data.loc[date])
else:
# 全局Top N
selected = self._global_top_n(scores)
# 信号格式:逗号分隔的代码列表
signals.append(','.join(selected) if selected else '')
result['signal'] = signals
result['signal_raw'] = signals # 原始信号未shift
# T+1执行信号向后移位1天
result['signal'] = result['signal'].shift(1)
return result
def _get_factor_columns(self, data: pd.DataFrame) -> List[str]:
"""获取因子列名"""
# 排除已知非因子列
exclude_cols = ['signal', 'signal_raw', 'group_info', 'combined', 'open', 'high', 'low', 'close', 'volume']
factor_cols = [col for col in data.columns if col not in exclude_cols and not col.endswith('_weighted')]
return factor_cols
def _global_top_n(self, scores: Dict[str, float]) -> List[str]:
"""全局Top N选股"""
if not scores:
return []
# 按得分排序
sorted_items = sorted(scores.items(), key=lambda x: x[1], reverse=True)
# 选Top N
selected = [item[0] for item in sorted_items[:self.select_num]]
return selected
def _grouped_selection(
self,
scores: Dict[str, float],
row: pd.Series
) -> List[str]:
"""分组选股:先类内竞争,再跨类排序"""
if 'group_info' not in row.index:
return self._global_top_n(scores)
group_info = row['group_info']
if pd.isna(group_info):
return self._global_top_n(scores)
# 解析分组信息:{code: group}
groups = group_info if isinstance(group_info, dict) else {}
# 类内竞争每组选Top1
group_champions = {}
for code, score in scores.items():
group = groups.get(code, 'default')
if group not in group_champions or score > group_champions[group][1]:
group_champions[group] = (code, score)
# 跨类排序从冠军中选Top N
champions_scores = {code: score for code, score in group_champions.values()}
return self._global_top_n(champions_scores)
class TrendFollower(SignalGenerator):
"""
趋势跟随器
用于趋势跟踪策略:
- 趋势强度 > 入场阈值 → 入场信号
- 趋势强度 < 出场阈值 → 出场信号
参数:
- entry_threshold: 入场阈值默认0.02
- exit_threshold: 出场阈值(默认-0.02
- select_num: 最大持仓数量默认1
"""
mode = "trend"
def __init__(
self,
entry_threshold: float = 0.02,
exit_threshold: float = -0.02,
select_num: int = 1
):
super().__init__(
entry_threshold=entry_threshold,
exit_threshold=exit_threshold,
select_num=select_num
)
self.entry_threshold = entry_threshold
self.exit_threshold = exit_threshold
self.select_num = select_num
def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
"""生成趋势跟随信号"""
result = pd.DataFrame(index=factor_data.index)
factor_cols = self._get_factor_columns(factor_data)
for col in factor_cols:
trend_strength = factor_data[col]
# 入场信号:趋势强度 > 阈值
result[f'{col}_entry'] = trend_strength > self.entry_threshold
# 出场信号:趋势强度 < 阈值
result[f'{col}_exit'] = trend_strength < self.exit_threshold
# 综合信号入场强度最高的Top N
signals = []
for date in result.index:
entry_signals = []
for col in factor_cols:
if result.loc[date, f'{col}_entry']:
score = factor_data.loc[date, col]
if pd.notna(score):
entry_signals.append((col, score))
# 按强度排序选Top N
entry_signals.sort(key=lambda x: x[1], reverse=True)
selected = [item[0] for item in entry_signals[:self.select_num]]
signals.append(','.join(selected) if selected else '')
result['signal'] = signals
result['signal'] = result['signal'].shift(1) # T+1执行
return result
def _get_factor_columns(self, data: pd.DataFrame) -> List[str]:
"""获取因子列名"""
exclude_cols = ['signal', 'signal_raw', 'combined', 'open', 'high', 'low', 'close', 'volume']
return [col for col in data.columns if col not in exclude_cols and not col.endswith('_weighted')]
class ReversalTrader(SignalGenerator):
"""
反转交易器
用于反转策略:
- 超买区域RSI>70 → 反转向下信号(卖出)
- 超卖区域RSI<30 → 反转向上信号(买入)
参数:
- overbought: 超买阈值默认70
- oversold: 超卖阈值默认30
- reversal_threshold: 反转信号强度阈值默认0.1
"""
mode = "reversal"
def __init__(
self,
overbought: float = 70,
oversold: float = 30,
reversal_threshold: float = 0.1
):
super().__init__(
overbought=overbought,
oversold=oversold,
reversal_threshold=reversal_threshold
)
self.overbought = overbought
self.oversold = oversold
self.reversal_threshold = reversal_threshold
def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
"""生成反转交易信号"""
result = pd.DataFrame(index=factor_data.index)
factor_cols = self._get_factor_columns(factor_data)
for col in factor_cols:
reversal_signal = factor_data[col]
# 买入信号:反转信号 > 阈值(正值,超卖反转)
result[f'{col}_buy'] = reversal_signal > self.reversal_threshold
# 卖出信号:反转信号 < -阈值(负值,超买反转)
result[f'{col}_sell'] = reversal_signal < -self.reversal_threshold
# 综合信号
signals = []
for date in result.index:
buy_signals = []
sell_signals = []
for col in factor_cols:
if result.loc[date, f'{col}_buy']:
buy_signals.append(col)
if result.loc[date, f'{col}_sell']:
sell_signals.append(col)
# 信号格式:'BUY:code1,code2' 或 'SELL:code1' 或 ''
if buy_signals:
signals.append(f"BUY:{','.join(buy_signals)}")
elif sell_signals:
signals.append(f"SELL:{','.join(sell_signals)}")
else:
signals.append('')
result['signal'] = signals
result['signal'] = result['signal'].shift(1) # T+1执行
return result
def _get_factor_columns(self, data: pd.DataFrame) -> List[str]:
"""获取因子列名"""
exclude_cols = ['signal', 'signal_raw', 'combined', 'open', 'high', 'low', 'close', 'volume']
return [col for col in data.columns if col not in exclude_cols and not col.endswith('_weighted')]
# 导出抽象接口
__all__ = ['SignalGenerator']

View File

@@ -1,63 +1,30 @@
"""
策略基类与配置
策略层抽象基类(通用)
核心组件:
- StrategyBase: 策略抽象基类(含回调钩子)
- ConfigLoader: 配置加载器
只提供抽象接口具体策略实现在strategies/
"""
import yaml
import pandas as pd
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, Optional, Any
import pandas as pd
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
from framework.factors.momentum import MomentumFactor
from framework.signals import SignalGenerator, TopNSelector
from framework.factors import FactorCombiner
from framework.signals import SignalGenerator
from framework.risk import CallbackHook, Position
@dataclass
class StrategyConfig:
"""策略配置"""
name: str
version: int
factors: List[Dict]
signal: Dict
callbacks: Dict
params: Dict
class StrategyBase(ABC):
"""
策略抽象基类
融合Freqtrade回调机制 + 模块化因子设计
类属性(可被配置覆盖):
- name: 策略名称
- version: 接口版本
- timeframe: K线周期
- select_num: 选中数量
- stoploss: 止损比例
回调钩子(可选实现):
- before_entry: 入场前检查
- after_entry: 入场后处理
- before_exit: 出场前检查
- after_exit: 出场后处理
- dynamic_stoploss: 动态止损
- custom_exit: 自定义出场条件
所有策略必须实现init_factors和init_signal_generator方法
"""
# 接口版本
INTERFACE_VERSION = 1
# 类属性(可被配置覆盖)
name: str = "base"
timeframe: str = "1d"
# 类属性(可被配置覆盖)
select_num: int = 3
stoploss: float = -0.05
@@ -66,53 +33,50 @@ class StrategyBase(ABC):
初始化策略
Args:
config: 策略配置(可覆盖类属性)
config: 配置字典(可选,用于覆盖类属性)
"""
# 配置覆盖类属性
if config:
self._apply_config(config)
# 初始化回调钩子
self._callbacks = CallbackHook()
self._register_default_callbacks()
# 初始化因子和信号生成器
self._factors = None
self._signal_gen = None
self._factors = self.init_factors()
self._signal_gen = self.init_signal_generator()
def _apply_config(self, config: Dict) -> None:
"""应用配置"""
params = config.get('params', {})
# 覆盖类属性
for key, value in params.items():
"""应用配置覆盖类属性"""
for key, value in config.items():
if hasattr(self, key):
setattr(self, key, value)
# 保存完整配置
self._config = config
def _register_default_callbacks(self) -> None:
"""注册默认回调"""
# 注册入场前回调(溢价过滤)
"""注册默认回调方法"""
if hasattr(self, 'before_entry'):
self._callbacks.register('before_entry', self.before_entry)
# 注册动态止损回调
if hasattr(self, 'after_entry'):
self._callbacks.register('after_entry', self.after_entry)
if hasattr(self, 'before_exit'):
self._callbacks.register('before_exit', self.before_exit)
if hasattr(self, 'after_exit'):
self._callbacks.register('after_exit', self.after_exit)
if hasattr(self, 'dynamic_stoploss'):
self._callbacks.register('dynamic_stoploss', self.dynamic_stoploss)
# 注册自定义出场回调
if hasattr(self, 'custom_exit'):
self._callbacks.register('custom_exit', self.custom_exit)
@abstractmethod
def init_factors(self) -> FactorCombiner:
"""
初始化因子组合
初始化因子组合
Returns:
因子组合器
FactorCombiner实例
"""
pass
@@ -122,7 +86,7 @@ class StrategyBase(ABC):
初始化信号生成器
Returns:
信号生成器
SignalGenerator实例
"""
pass
@@ -131,85 +95,28 @@ class StrategyBase(ABC):
运行策略
Args:
data: 输入数据
data: OHLCV数据
Returns:
包含信号的DataFrame
"""
# 初始化因子和信号生成器
if self._factors is None:
self._factors = self.init_factors()
if self._signal_gen is None:
self._signal_gen = self.init_signal_generator()
# 1. 计算因子
factor_data = self._factors.compute(data)
# 2. 生成信号
signals = self._signal_gen.generate(factor_data)
# 3. 应用回调钩子
signals = self._apply_callbacks(signals, data)
return signals
def _apply_callbacks(self, signals: pd.DataFrame, data: pd.DataFrame) -> pd.DataFrame:
"""应用回调钩子"""
# 遍历每行信号
for date in signals.index:
signal = signals.loc[date, 'signal']
if not signal or pd.isna(signal):
continue
# 解析信号(逗号分隔的代码)
codes = signal.split(',')
# 应用入场前回调
for code in codes:
if code not in data.columns:
continue
price = data.loc[date, code]
premium = 0.0 # TODO: 从溢价数据获取
# 触发回调
allowed = self._callbacks.trigger(
'before_entry',
code,
price,
premium=premium,
history=data
)
if not allowed:
# 移除被拒绝的代码
codes.remove(code)
# 更新信号
signals.loc[date, 'signal'] = ','.join(codes) if codes else ''
"""应用回调处理"""
return signals
# ===== 可选回调方法 =====
# 可选回调方法(子类可覆盖)
def before_entry(self, code: str, price: float, **kwargs) -> bool:
"""
入场前检查
Args:
code: 标的代码
price: 入场价格
**kwargs: 其他参数premium, history等
Returns:
是否允许入场
"""
# 默认:允许入场
"""入场前检查"""
return True
def after_entry(self, trade, **kwargs) -> None:
def after_entry(self, code: str, price: float, **kwargs) -> None:
"""入场后处理"""
pass
@@ -217,141 +124,21 @@ class StrategyBase(ABC):
"""出场前检查"""
return True
def after_exit(self, trade, **kwargs) -> None:
def after_exit(self, position: Position, **kwargs) -> None:
"""出场后处理"""
pass
def dynamic_stoploss(self, position: Position) -> float:
"""
动态止损
Args:
position: 持仓信息
Returns:
止损比例
"""
# 默认:返回固定止损
"""动态止损"""
return self.stoploss
def custom_exit(self, position: Position) -> bool:
"""
自定义出场条件
Args:
position: 持仓信息
Returns:
是否触发出场
"""
"""自定义出场条件"""
return False
def __repr__(self) -> str:
return f"{self.__class__.__name__}(name={self.name})"
class ConfigLoader:
"""
配置加载器
支持YAML配置文件加载和验证
"""
def __init__(self, config_path: str):
"""
初始化配置加载器
Args:
config_path: 配置文件路径
"""
self._config_path = Path(config_path)
self._config = None
def load(self) -> Dict:
"""加载配置"""
if not self._config_path.exists():
raise FileNotFoundError(f"Config file not found: {self._config_path}")
with open(self._config_path, 'r', encoding='utf-8') as f:
self._config = yaml.safe_load(f)
return self._config
def validate(self) -> bool:
"""验证配置"""
if self._config is None:
self.load()
# 必须字段
required_fields = ['strategy', 'factors', 'signal']
for field in required_fields:
if field not in self._config:
raise ValueError(f"Missing required field: {field}")
return True
def get_strategy_config(self) -> StrategyConfig:
"""获取策略配置"""
if self._config is None:
self.load()
return StrategyConfig(
name=self._config['strategy']['name'],
version=self._config['strategy'].get('version', 1),
factors=self._config['factors'],
signal=self._config['signal'],
callbacks=self._config.get('callbacks', {}),
params=self._config.get('params', {})
)
@staticmethod
def from_yaml(yaml_str: str) -> Dict:
"""从YAML字符串加载"""
return yaml.safe_load(yaml_str)
# 示例策略实现
class RotationStrategy(StrategyBase):
"""
ETF轮动策略
基于动量因子 + Top N选股
"""
name = "rotation"
select_num = 3
def init_factors(self) -> FactorCombiner:
"""初始化动量因子"""
FactorRegistry.clear()
FactorRegistry.register(MomentumFactor)
return FactorCombiner([
FactorRegistry.get('momentum', n_days=25, crash_filter=True)
])
def init_signal_generator(self) -> SignalGenerator:
"""初始化Top N选股器"""
from framework.signals import TopNSelector
return TopNSelector(
select_num=self.select_num,
min_score=0.0
)
def before_entry(self, code: str, price: float, **kwargs) -> bool:
"""入场前:溢价过滤"""
premium = kwargs.get('premium', 0)
# 溢价超过10%拒绝入场
if premium > 0.10:
print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})")
return False
return True
def dynamic_stoploss(self, position: Position) -> float:
"""动态止损:根据持仓时间调整"""
if position.holding_days >= 10:
return -0.03 # 10天后收紧止损
elif position.holding_days >= 5:
return -0.05
return -0.10
# 导出抽象接口
__all__ = ['StrategyBase']