refactor(framework): 框架只保留抽象接口,具体实现移至strategies/shared
- FactorBase/FactorRegistry/FactorCombiner: 因子抽象接口 - SignalGenerator: 信号生成抽象接口 - RiskControl/Position/CallbackHook: 风控抽象接口 - StrategyBase: 策略抽象基类 - Executor/Portfolio: 执行器抽象接口 - ConfigLoader: 配置加载器 - 删除framework/factors/momentum.py(具体实现)
This commit is contained in:
@@ -1,21 +1,52 @@
|
||||
"""
|
||||
量化策略通用框架
|
||||
框架统一入口(通用)
|
||||
|
||||
融合Freqtrade回调机制 + 模块化因子设计
|
||||
只导出抽象接口,具体实现在strategies/shared/
|
||||
"""
|
||||
|
||||
from .factors import FactorBase, FactorRegistry, FactorCombiner
|
||||
from .signals import SignalGenerator, TopNSelector, TrendFollower, ReversalTrader
|
||||
from .strategy import StrategyBase, RotationStrategy
|
||||
from .risk import RiskControl, StopLossControl, PositionLimitControl
|
||||
from .execution import Executor, BacktestExecutor, DryRunExecutor
|
||||
from .config import ConfigLoader
|
||||
# 因子层抽象
|
||||
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
|
||||
|
||||
# 信号层抽象
|
||||
from framework.signals import SignalGenerator
|
||||
|
||||
# 风控层抽象
|
||||
from framework.risk import Position, RiskControl, CallbackHook
|
||||
|
||||
# 策略层抽象
|
||||
from framework.strategy import StrategyBase
|
||||
|
||||
# 执行层抽象
|
||||
from framework.execution import Portfolio, Executor, BacktestExecutor, DryRunExecutor
|
||||
|
||||
# 配置层
|
||||
from framework.config import ConfigLoader, StrategyConfig
|
||||
|
||||
|
||||
__all__ = [
|
||||
'FactorBase', 'FactorRegistry', 'FactorCombiner',
|
||||
'SignalGenerator', 'TopNSelector', 'TrendFollower', 'ReversalTrader',
|
||||
'StrategyBase', 'RotationStrategy',
|
||||
'RiskControl', 'StopLossControl', 'PositionLimitControl',
|
||||
'Executor', 'BacktestExecutor', 'DryRunExecutor',
|
||||
# 因子层
|
||||
'FactorBase',
|
||||
'FactorRegistry',
|
||||
'FactorCombiner',
|
||||
|
||||
# 信号层
|
||||
'SignalGenerator',
|
||||
|
||||
# 风控层
|
||||
'Position',
|
||||
'RiskControl',
|
||||
'CallbackHook',
|
||||
|
||||
# 策略层
|
||||
'StrategyBase',
|
||||
|
||||
# 执行层
|
||||
'Portfolio',
|
||||
'Executor',
|
||||
'BacktestExecutor',
|
||||
'DryRunExecutor',
|
||||
|
||||
# 配置层
|
||||
'ConfigLoader',
|
||||
'StrategyConfig',
|
||||
]
|
||||
@@ -1,83 +1,103 @@
|
||||
"""
|
||||
配置层抽象设计
|
||||
配置层抽象接口(通用)
|
||||
|
||||
核心组件:
|
||||
- ConfigLoader: 配置加载器
|
||||
只提供配置加载和验证机制
|
||||
"""
|
||||
|
||||
import yaml
|
||||
from typing import Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class StrategyConfig:
|
||||
"""策略配置"""
|
||||
name: str
|
||||
version: int
|
||||
factors: list
|
||||
signal: dict
|
||||
callbacks: dict
|
||||
params: dict
|
||||
|
||||
|
||||
class ConfigLoader:
|
||||
"""
|
||||
配置加载器
|
||||
配置加载器(通用)
|
||||
|
||||
支持YAML配置文件加载和验证
|
||||
支持从YAML文件加载配置
|
||||
"""
|
||||
|
||||
def __init__(self, config_path: str):
|
||||
"""
|
||||
初始化配置加载器
|
||||
def __init__(self, config_path: Optional[str] = None):
|
||||
"""初始化配置加载器"""
|
||||
self._config: Dict[str, Any] = {}
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径
|
||||
"""
|
||||
self._config_path = Path(config_path)
|
||||
self._config = None
|
||||
if config_path:
|
||||
self.load(config_path)
|
||||
|
||||
def load(self) -> Dict:
|
||||
"""加载配置"""
|
||||
if not self._config_path.exists():
|
||||
raise FileNotFoundError(f"Config file not found: {self._config_path}")
|
||||
def load(self, config_path: str) -> Dict[str, Any]:
|
||||
"""从YAML文件加载配置"""
|
||||
path = Path(config_path)
|
||||
|
||||
with open(self._config_path, 'r', encoding='utf-8') as f:
|
||||
self._config = yaml.safe_load(f)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
self._config = yaml.safe_load(f) or {}
|
||||
|
||||
return self._config
|
||||
|
||||
def validate(self) -> bool:
|
||||
"""验证配置"""
|
||||
if self._config is None:
|
||||
self.load()
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
"""获取配置项"""
|
||||
return self._config.get(key, default)
|
||||
|
||||
def get_section(self, section: str) -> Dict[str, Any]:
|
||||
"""获取配置区块"""
|
||||
return self._config.get(section, {})
|
||||
|
||||
def validate(self, required_keys: list) -> bool:
|
||||
"""验证必填配置项"""
|
||||
missing = [key for key in required_keys if key not in self._config]
|
||||
|
||||
# 必须字段
|
||||
required_fields = ['strategy', 'factors', 'signal']
|
||||
|
||||
for field in required_fields:
|
||||
if field not in self._config:
|
||||
raise ValueError(f"Missing required field: {field}")
|
||||
if missing:
|
||||
raise ValueError(f"Missing required config keys: {missing}")
|
||||
|
||||
return True
|
||||
|
||||
def get_strategy_config(self) -> StrategyConfig:
|
||||
"""获取策略配置"""
|
||||
if self._config is None:
|
||||
self.load()
|
||||
|
||||
return StrategyConfig(
|
||||
name=self._config['strategy']['name'],
|
||||
version=self._config['strategy'].get('version', 1),
|
||||
factors=self._config['factors'],
|
||||
signal=self._config['signal'],
|
||||
callbacks=self._config.get('callbacks', {}),
|
||||
params=self._config.get('params', {})
|
||||
)
|
||||
def __repr__(self) -> str:
|
||||
return f"ConfigLoader(keys={list(self._config.keys())})"
|
||||
|
||||
|
||||
class StrategyConfig:
|
||||
"""
|
||||
策略配置类(通用)
|
||||
|
||||
@staticmethod
|
||||
def from_yaml(yaml_str: str) -> Dict:
|
||||
"""从YAML字符串加载"""
|
||||
return yaml.safe_load(yaml_str)
|
||||
用于封装策略配置
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
"""初始化策略配置"""
|
||||
self._config = config
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""策略名称"""
|
||||
return self._config.get('strategy', {}).get('name', 'unknown')
|
||||
|
||||
@property
|
||||
def select_num(self) -> int:
|
||||
"""选中数量"""
|
||||
return self._config.get('signal', {}).get('select_num', 3)
|
||||
|
||||
@property
|
||||
def stoploss(self) -> float:
|
||||
"""止损阈值"""
|
||||
return self._config.get('risk', {}).get('stop_loss', -0.05)
|
||||
|
||||
def get_factor_config(self) -> list:
|
||||
"""获取因子配置"""
|
||||
return self._config.get('factors', [])
|
||||
|
||||
def get_signal_config(self) -> dict:
|
||||
"""获取信号配置"""
|
||||
return self._config.get('signal', {})
|
||||
|
||||
def get_risk_config(self) -> list:
|
||||
"""获取风控配置"""
|
||||
return self._config.get('risk', [])
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return self._config.copy()
|
||||
|
||||
|
||||
# 导出抽象接口
|
||||
__all__ = ['ConfigLoader', 'StrategyConfig']
|
||||
@@ -1,55 +1,119 @@
|
||||
"""
|
||||
执行层抽象设计
|
||||
执行层抽象接口(通用)
|
||||
|
||||
核心组件:
|
||||
- Executor: 执行器抽象基类
|
||||
- BacktestExecutor: 回测执行器
|
||||
- DryRunExecutor: 模拟盘执行器
|
||||
只提供抽象基类和Portfolio数据结构,具体执行器可扩展
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional, List
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
import pandas as pd
|
||||
from datetime import datetime
|
||||
|
||||
from framework.risk import Position
|
||||
|
||||
|
||||
@dataclass
|
||||
class Portfolio:
|
||||
"""持仓组合"""
|
||||
positions: Dict[str, Any] # {code: Position}
|
||||
cash: float
|
||||
nav: float
|
||||
trades: List[Any]
|
||||
"""
|
||||
投资组合数据结构(通用)
|
||||
|
||||
def get_total_value(self) -> float:
|
||||
"""获取总价值"""
|
||||
position_value = sum(
|
||||
pos.quantity * pos.current_price
|
||||
for pos in self.positions.values()
|
||||
用于管理持仓集合
|
||||
"""
|
||||
|
||||
def __init__(self, initial_capital: float = 100000):
|
||||
"""初始化投资组合"""
|
||||
self.initial_capital = initial_capital
|
||||
self.cash = initial_capital
|
||||
self.positions: Dict[str, Position] = {}
|
||||
self.trades: List[Dict] = []
|
||||
self._net_value_history: List[float] = []
|
||||
|
||||
def add_position(self, code: str, price: float, quantity: float, time: datetime) -> None:
|
||||
"""添加持仓"""
|
||||
position = Position(
|
||||
code=code,
|
||||
entry_price=price,
|
||||
current_price=price,
|
||||
entry_time=time,
|
||||
quantity=quantity
|
||||
)
|
||||
return self.cash + position_value
|
||||
self.positions[code] = position
|
||||
self.cash -= price * quantity
|
||||
|
||||
self.trades.append({
|
||||
'action': 'BUY',
|
||||
'code': code,
|
||||
'price': price,
|
||||
'quantity': quantity,
|
||||
'time': time
|
||||
})
|
||||
|
||||
def get_position_codes(self) -> List[str]:
|
||||
"""获取持仓代码列表"""
|
||||
return list(self.positions.keys())
|
||||
def remove_position(self, code: str, price: float, time: datetime) -> float:
|
||||
"""移除持仓"""
|
||||
if code not in self.positions:
|
||||
return 0
|
||||
|
||||
position = self.positions[code]
|
||||
profit = (price - position.entry_price) * position.quantity
|
||||
self.cash += price * position.quantity
|
||||
|
||||
del self.positions[code]
|
||||
|
||||
self.trades.append({
|
||||
'action': 'SELL',
|
||||
'code': code,
|
||||
'price': price,
|
||||
'quantity': position.quantity,
|
||||
'time': time,
|
||||
'profit': profit
|
||||
})
|
||||
|
||||
return profit
|
||||
|
||||
def update_prices(self, prices: Dict[str, float]) -> None:
|
||||
"""更新持仓价格"""
|
||||
for code, price in prices.items():
|
||||
if code in self.positions:
|
||||
self.positions[code].current_price = price
|
||||
|
||||
def get_net_value(self) -> float:
|
||||
"""计算净值"""
|
||||
positions_value = sum(
|
||||
pos.current_price * pos.quantity for pos in self.positions.values()
|
||||
)
|
||||
return self.cash + positions_value
|
||||
|
||||
def record_net_value(self) -> None:
|
||||
"""记录当前净值"""
|
||||
self._net_value_history.append(self.get_net_value())
|
||||
|
||||
def get_net_value_series(self) -> pd.Series:
|
||||
"""获取净值序列"""
|
||||
return pd.Series(self._net_value_history)
|
||||
|
||||
def get_weight(self, code: str) -> float:
|
||||
"""计算持仓权重"""
|
||||
if code not in self.positions:
|
||||
return 0
|
||||
|
||||
position_value = self.positions[code].current_price * self.positions[code].quantity
|
||||
return position_value / self.get_net_value()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Portfolio(capital={self.cash:.2f}, positions={len(self.positions)})"
|
||||
|
||||
|
||||
class Executor(ABC):
|
||||
"""
|
||||
执行器抽象基类
|
||||
|
||||
支持不同执行模式:
|
||||
- backtest: 回测模式
|
||||
- dry_run: 模拟盘模式
|
||||
- live: 实盘模式(TODO)
|
||||
所有执行器必须实现execute方法
|
||||
"""
|
||||
|
||||
mode: str = "base"
|
||||
|
||||
def __init__(self, config: Optional[Dict] = None):
|
||||
self._config = config or {}
|
||||
self._portfolio = None
|
||||
def __init__(self, **params):
|
||||
"""初始化执行器参数"""
|
||||
self._params = params
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, signals: pd.DataFrame, data: pd.DataFrame) -> Portfolio:
|
||||
@@ -58,121 +122,71 @@ class Executor(ABC):
|
||||
|
||||
Args:
|
||||
signals: 信号DataFrame
|
||||
data: 价格数据
|
||||
data: OHLCV数据
|
||||
|
||||
Returns:
|
||||
持仓组合
|
||||
Portfolio对象
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_mode(self) -> str:
|
||||
"""获取执行模式"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def portfolio(self) -> Optional[Portfolio]:
|
||||
"""获取当前持仓"""
|
||||
return self._portfolio
|
||||
def __repr__(self) -> str:
|
||||
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
|
||||
return f"{self.__class__.__name__}({params_str})"
|
||||
|
||||
|
||||
class BacktestExecutor(Executor):
|
||||
"""
|
||||
回测执行器
|
||||
回测执行器(通用骨架)
|
||||
|
||||
执行回测逻辑:
|
||||
- 处理信号
|
||||
- 计算净值
|
||||
- 记录交易
|
||||
具体回测逻辑需要在strategies中定制实现
|
||||
"""
|
||||
|
||||
mode = "backtest"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
initial_capital: float = 100000.0,
|
||||
trade_cost: float = 0.001
|
||||
):
|
||||
super().__init__()
|
||||
def __init__(self, initial_capital: float = 100000, trade_cost: float = 0.001):
|
||||
super().__init__(initial_capital=initial_capital, trade_cost=trade_cost)
|
||||
self.initial_capital = initial_capital
|
||||
self.trade_cost = trade_cost
|
||||
|
||||
def execute(self, signals: pd.DataFrame, data: pd.DataFrame) -> Portfolio:
|
||||
"""执行回测"""
|
||||
# 初始化持仓
|
||||
self._portfolio = Portfolio(
|
||||
positions={},
|
||||
cash=self.initial_capital,
|
||||
nav=1.0,
|
||||
trades=[]
|
||||
)
|
||||
"""
|
||||
执行回测(简化版本)
|
||||
|
||||
# 回测逻辑(简化版)
|
||||
result = pd.DataFrame(index=signals.index)
|
||||
result['nav'] = 1.0
|
||||
result['daily_return'] = 0.0
|
||||
完整回测逻辑需要定制实现
|
||||
"""
|
||||
portfolio = Portfolio(self.initial_capital)
|
||||
|
||||
# TODO: 完整回测逻辑迁移
|
||||
# 这里只提供骨架,具体逻辑需要定制实现
|
||||
# 包括:净值计算、交易成本扣除、基准对比等
|
||||
|
||||
return self._portfolio
|
||||
|
||||
def get_mode(self) -> str:
|
||||
return "backtest"
|
||||
return portfolio
|
||||
|
||||
|
||||
class DryRunExecutor(Executor):
|
||||
"""
|
||||
模拟盘执行器
|
||||
Dry-run执行器(通用)
|
||||
|
||||
执行模拟交易:
|
||||
- 模拟下单
|
||||
- 模拟成交
|
||||
- 模拟持仓更新
|
||||
用于模拟运行,不实际执行交易
|
||||
"""
|
||||
|
||||
mode = "dry_run"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
initial_capital: float = 100000.0,
|
||||
simulated_exchange = None
|
||||
):
|
||||
super().__init__()
|
||||
self.initial_capital = initial_capital
|
||||
self.simulated_exchange = simulated_exchange
|
||||
def __init__(self, verbose: bool = True):
|
||||
super().__init__(verbose=verbose)
|
||||
self.verbose = verbose
|
||||
|
||||
def execute(self, signals: pd.DataFrame, data: pd.DataFrame) -> Portfolio:
|
||||
"""执行模拟盘"""
|
||||
# 初始化持仓
|
||||
self._portfolio = Portfolio(
|
||||
positions={},
|
||||
cash=self.initial_capital,
|
||||
nav=1.0,
|
||||
trades=[]
|
||||
)
|
||||
"""模拟执行"""
|
||||
portfolio = Portfolio(100000)
|
||||
|
||||
# 模拟执行逻辑
|
||||
# TODO: 模拟订单执行
|
||||
for date in signals.index:
|
||||
signal = signals.loc[date, 'signal']
|
||||
|
||||
if signal and self.verbose:
|
||||
print(f"[{date}] Signal: {signal}")
|
||||
|
||||
return self._portfolio
|
||||
|
||||
def get_mode(self) -> str:
|
||||
return "dry_run"
|
||||
|
||||
def simulate_order(self, code: str, direction: str, quantity: float, price: float):
|
||||
"""模拟下单"""
|
||||
# 记录模拟订单
|
||||
print(f"[DRY_RUN] {direction} {quantity} {code} @ {price}")
|
||||
|
||||
# 更新持仓
|
||||
if direction == 'BUY':
|
||||
# 模拟买入
|
||||
cost = quantity * price
|
||||
if cost <= self._portfolio.cash:
|
||||
self._portfolio.cash -= cost
|
||||
# TODO: 创建Position对象
|
||||
elif direction == 'SELL':
|
||||
# 模拟卖出
|
||||
if code in self._portfolio.positions:
|
||||
# TODO: 平仓逻辑
|
||||
pass
|
||||
return portfolio
|
||||
|
||||
|
||||
# 导出抽象接口
|
||||
__all__ = ['Portfolio', 'Executor', 'BacktestExecutor', 'DryRunExecutor']
|
||||
@@ -1,54 +1,28 @@
|
||||
"""
|
||||
因子层抽象设计
|
||||
因子层抽象接口(通用)
|
||||
|
||||
核心组件:
|
||||
- FactorBase: 因子抽象基类
|
||||
- FactorRegistry: 因子注册器
|
||||
- FactorCombiner: 因子组合器
|
||||
只提供抽象基类和注册机制,具体因子实现在strategies/shared/factors/
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Any, Type
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class FactorMeta:
|
||||
"""因子元信息"""
|
||||
name: str
|
||||
category: str # 'momentum', 'trend', 'reversal', 'volatility', 'fundamental'
|
||||
params: Dict[str, Any]
|
||||
description: str = ""
|
||||
|
||||
|
||||
class FactorBase(ABC):
|
||||
"""
|
||||
因子抽象基类
|
||||
|
||||
所有因子必须继承此基类,实现compute方法。
|
||||
支持参数配置、数据验证、元信息管理。
|
||||
所有因子必须实现compute方法
|
||||
"""
|
||||
|
||||
# 类属性(可被配置覆盖)
|
||||
name: str = "base"
|
||||
category: str = "unknown"
|
||||
|
||||
def __init__(self, **params):
|
||||
"""
|
||||
初始化因子
|
||||
|
||||
Args:
|
||||
**params: 因子参数(如n_days=25, period=14等)
|
||||
"""
|
||||
"""初始化因子参数"""
|
||||
self._params = params
|
||||
self._meta = FactorMeta(
|
||||
name=self.name,
|
||||
category=self.category,
|
||||
params=params,
|
||||
description=self.__doc__ or ""
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def compute(self, data: pd.DataFrame) -> pd.Series:
|
||||
@@ -56,139 +30,84 @@ class FactorBase(ABC):
|
||||
计算因子值
|
||||
|
||||
Args:
|
||||
data: 包含OHLCV数据的DataFrame
|
||||
data: OHLCV数据,必须包含'close'列
|
||||
|
||||
Returns:
|
||||
因子值序列(Series)
|
||||
因子值序列
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def params(self) -> Dict[str, Any]:
|
||||
"""获取因子参数"""
|
||||
return self._params
|
||||
|
||||
@property
|
||||
def meta(self) -> FactorMeta:
|
||||
"""获取因子元信息"""
|
||||
return self._meta
|
||||
|
||||
def validate_data(self, data: pd.DataFrame) -> bool:
|
||||
"""
|
||||
验证数据是否满足计算要求
|
||||
"""验证数据是否满足计算要求"""
|
||||
if 'close' not in data.columns:
|
||||
return False
|
||||
|
||||
Args:
|
||||
data: 数据DataFrame
|
||||
|
||||
Returns:
|
||||
是否满足要求
|
||||
"""
|
||||
# 默认验证:数据长度 >= 最小周期
|
||||
min_periods = self._params.get('min_periods', 20)
|
||||
return len(data) >= min_periods
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(name={self.name}, params={self._params})"
|
||||
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
|
||||
return f"{self.__class__.__name__}({params_str})"
|
||||
|
||||
|
||||
class FactorRegistry:
|
||||
"""
|
||||
因子注册器
|
||||
因子注册器(通用)
|
||||
|
||||
管理所有注册的因子,支持:
|
||||
- 注册因子类
|
||||
- 获取因子实例
|
||||
- 列出可用因子
|
||||
- 按类别筛选因子
|
||||
管理因子类的注册和获取
|
||||
"""
|
||||
|
||||
_factors: Dict[str, type] = {}
|
||||
_factors: Dict[str, Type[FactorBase]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, factor_class: type) -> None:
|
||||
"""
|
||||
注册因子类
|
||||
|
||||
Args:
|
||||
factor_class: 因子类(必须继承FactorBase)
|
||||
"""
|
||||
if not isinstance(factor_class, type) or not issubclass(factor_class, FactorBase):
|
||||
raise TypeError(f"factor_class must be a subclass of FactorBase")
|
||||
|
||||
# 创建临时实例获取名称
|
||||
def register(cls, factor_class: Type[FactorBase]) -> None:
|
||||
"""注册因子类"""
|
||||
temp_instance = factor_class()
|
||||
name = temp_instance.name
|
||||
|
||||
if name in cls._factors:
|
||||
print(f"因子已注册,覆盖: {name}")
|
||||
|
||||
cls._factors[name] = factor_class
|
||||
print(f"✓ 因子已注册: {name} ({factor_class.__name__})")
|
||||
|
||||
@classmethod
|
||||
def get(cls, name: str, **params) -> FactorBase:
|
||||
"""
|
||||
获取因子实例
|
||||
|
||||
Args:
|
||||
name: 因子名称
|
||||
**params: 因子参数
|
||||
|
||||
Returns:
|
||||
因子实例
|
||||
"""
|
||||
"""获取因子实例"""
|
||||
if name not in cls._factors:
|
||||
raise KeyError(f"Factor '{name}' not registered. Available: {cls.list()}")
|
||||
raise ValueError(f"因子未注册: {name}")
|
||||
|
||||
factor_class = cls._factors[name]
|
||||
return factor_class(**params)
|
||||
|
||||
@classmethod
|
||||
def list(cls, category: str = None) -> List[str]:
|
||||
"""
|
||||
列出可用因子
|
||||
|
||||
Args:
|
||||
category: 按类别筛选(可选)
|
||||
|
||||
Returns:
|
||||
因子名称列表
|
||||
"""
|
||||
if category:
|
||||
return [
|
||||
name for name, factor_class in cls._factors.items()
|
||||
if factor_class().category == category
|
||||
]
|
||||
def list_factors(cls) -> List[str]:
|
||||
"""列出所有已注册因子"""
|
||||
return list(cls._factors.keys())
|
||||
|
||||
@classmethod
|
||||
def list_by_category(cls) -> Dict[str, List[str]]:
|
||||
"""
|
||||
按类别列出因子
|
||||
|
||||
Returns:
|
||||
类别→因子列表字典
|
||||
"""
|
||||
result = {}
|
||||
for name, factor_class in cls._factors.items():
|
||||
cat = factor_class().category
|
||||
if cat not in result:
|
||||
result[cat] = []
|
||||
result[cat].append(name)
|
||||
return result
|
||||
def clear(cls) -> None:
|
||||
"""清空注册表"""
|
||||
cls._factors = {}
|
||||
|
||||
@classmethod
|
||||
def clear(cls) -> None:
|
||||
"""清空注册表(用于测试)"""
|
||||
cls._factors.clear()
|
||||
def get_category(cls, name: str) -> str:
|
||||
"""获取因子类别"""
|
||||
if name not in cls._factors:
|
||||
return "unknown"
|
||||
|
||||
temp_instance = cls._factors[name]()
|
||||
return temp_instance.category
|
||||
|
||||
|
||||
class FactorCombiner:
|
||||
"""
|
||||
因子组合器
|
||||
因子组合器(通用)
|
||||
|
||||
支持多因子加权组合,用于:
|
||||
- 多因子策略
|
||||
- 因子权重调整
|
||||
- 因子结果合并
|
||||
支持多因子加权组合
|
||||
"""
|
||||
|
||||
SUPPORTED_METHODS = ['weighted_sum', 'rank_average', 'zscore_sum', 'equal_weight']
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
factors: List[FactorBase],
|
||||
@@ -196,87 +115,77 @@ class FactorCombiner:
|
||||
method: str = 'weighted_sum'
|
||||
):
|
||||
"""
|
||||
初始化因子组合器
|
||||
初始化组合器
|
||||
|
||||
Args:
|
||||
factors: 因子实例列表
|
||||
weights: 权重列表(默认等权)
|
||||
method: 组合方法 ('weighted_sum', 'average', 'max', 'min')
|
||||
weights: 因子权重列表(可选)
|
||||
method: 组合方法(weighted_sum/rank_average/zscore_sum/equal_weight)
|
||||
"""
|
||||
if not factors:
|
||||
raise ValueError("factors list cannot be empty")
|
||||
|
||||
if method not in self.SUPPORTED_METHODS:
|
||||
raise ValueError(f"Unsupported method: {method}")
|
||||
|
||||
self._factors = factors
|
||||
self._weights = weights or [1.0 / len(factors)] * len(factors)
|
||||
|
||||
if weights is None:
|
||||
self._weights = [1.0 / len(factors)] * len(factors)
|
||||
else:
|
||||
if len(weights) != len(factors):
|
||||
raise ValueError("weights length must match factors length")
|
||||
self._weights = weights
|
||||
|
||||
self._method = method
|
||||
|
||||
# 验证权重
|
||||
if len(self._weights) != len(factors):
|
||||
raise ValueError(f"weights length ({len(self._weights)}) != factors length ({len(factors)})")
|
||||
|
||||
# 归一化权重
|
||||
total_weight = sum(self._weights)
|
||||
self._weights = [w / total_weight for w in self._weights]
|
||||
|
||||
def compute(self, data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
计算所有因子并组合
|
||||
|
||||
Args:
|
||||
data: 输入数据
|
||||
|
||||
Returns:
|
||||
包含各因子值和组合因子值的DataFrame
|
||||
DataFrame包含各因子值和combined列
|
||||
"""
|
||||
result = pd.DataFrame(index=data.index)
|
||||
|
||||
# 计算各因子
|
||||
for i, factor in enumerate(self._factors):
|
||||
# 验证数据
|
||||
if not factor.validate_data(data):
|
||||
print(f"⚠ 因子 {factor.name} 数据验证失败,跳过")
|
||||
continue
|
||||
|
||||
# 计算因子值
|
||||
factor_values = factor.compute(data)
|
||||
result[factor.name] = factor_values
|
||||
|
||||
# 加权因子值
|
||||
result[f"{factor.name}_weighted"] = factor_values * self._weights[i]
|
||||
col_name = f"{factor.name}"
|
||||
result[col_name] = factor_values
|
||||
|
||||
# 组合因子值
|
||||
weighted_cols = [f"{f.name}_weighted" for f in self._factors if f.name in result.columns]
|
||||
|
||||
if self._method == 'weighted_sum':
|
||||
result['combined'] = result[weighted_cols].sum(axis=1)
|
||||
elif self._method == 'average':
|
||||
factor_cols = [f.name for f in self._factors if f.name in result.columns]
|
||||
weighted_cols = [f.name for f in self._factors]
|
||||
result['combined'] = result[weighted_cols].apply(
|
||||
lambda row: sum(row[col] * self._weights[i] for i, col in enumerate(weighted_cols) if pd.notna(row[col])),
|
||||
axis=1
|
||||
)
|
||||
|
||||
elif self._method == 'equal_weight':
|
||||
factor_cols = [f.name for f in self._factors]
|
||||
result['combined'] = result[factor_cols].mean(axis=1)
|
||||
elif self._method == 'max':
|
||||
factor_cols = [f.name for f in self._factors if f.name in result.columns]
|
||||
result['combined'] = result[factor_cols].max(axis=1)
|
||||
elif self._method == 'min':
|
||||
factor_cols = [f.name for f in self._factors if f.name in result.columns]
|
||||
result['combined'] = result[factor_cols].min(axis=1)
|
||||
else:
|
||||
raise ValueError(f"Unknown method: {self._method}")
|
||||
|
||||
elif self._method == 'rank_average':
|
||||
factor_cols = [f.name for f in self._factors]
|
||||
ranks = result[factor_cols].rank(axis=1)
|
||||
result['combined'] = ranks.mean(axis=1)
|
||||
|
||||
elif self._method == 'zscore_sum':
|
||||
factor_cols = [f.name for f in self._factors]
|
||||
zscores = result[factor_cols].apply(lambda x: (x - x.mean()) / x.std())
|
||||
result['combined'] = zscores.sum(axis=1)
|
||||
|
||||
return result
|
||||
|
||||
@property
|
||||
def factors(self) -> List[FactorBase]:
|
||||
"""获取因子列表"""
|
||||
return self._factors
|
||||
|
||||
@property
|
||||
def weights(self) -> List[float]:
|
||||
"""获取权重列表"""
|
||||
return self._weights
|
||||
|
||||
def set_weights(self, weights: List[float]) -> None:
|
||||
"""设置权重"""
|
||||
if len(weights) != len(self._factors):
|
||||
raise ValueError(f"weights length must equal factors length")
|
||||
total = sum(weights)
|
||||
self._weights = [w / total for w in weights]
|
||||
def get_factor_names(self) -> List[str]:
|
||||
"""获取因子名称列表"""
|
||||
return [f.name for f in self._factors]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
factor_names = [f.name for f in self._factors]
|
||||
return f"FactorCombiner(factors={factor_names}, weights={self._weights})"
|
||||
return f"FactorCombiner(factors={factor_names}, weights={self._weights}, method={self._method})"
|
||||
|
||||
|
||||
# 导出抽象接口
|
||||
__all__ = ['FactorBase', 'FactorRegistry', 'FactorCombiner']
|
||||
@@ -1,312 +0,0 @@
|
||||
"""
|
||||
动量因子实现
|
||||
|
||||
基于加权线性回归动量的因子
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
from framework.factors import FactorBase
|
||||
|
||||
|
||||
class MomentumFactor(FactorBase):
|
||||
"""
|
||||
动量因子
|
||||
|
||||
计算加权线性回归动量得分:
|
||||
得分 = 年化收益率 × R²
|
||||
|
||||
参数:
|
||||
- n_days: 动量窗口(默认25)
|
||||
- weighted: 是否加权(默认True)
|
||||
- crash_filter: 是否启用崩盘过滤(默认True)
|
||||
"""
|
||||
|
||||
name = "momentum"
|
||||
category = "momentum"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_days: int = 25,
|
||||
weighted: bool = True,
|
||||
crash_filter: bool = True
|
||||
):
|
||||
super().__init__(n_days=n_days, weighted=weighted, crash_filter=crash_filter)
|
||||
self.n_days = n_days
|
||||
self.weighted = weighted
|
||||
self.crash_filter = crash_filter
|
||||
|
||||
def compute(self, data: pd.DataFrame) -> pd.Series:
|
||||
"""计算动量因子值"""
|
||||
if 'close' not in data.columns:
|
||||
raise ValueError("data must contain 'close' column")
|
||||
|
||||
prices = data['close']
|
||||
|
||||
if self.weighted:
|
||||
# 加权动量得分
|
||||
factor_values = prices.rolling(self.n_days).apply(
|
||||
lambda x: self._weighted_momentum_score(x.values),
|
||||
raw=False
|
||||
)
|
||||
else:
|
||||
# 简单动量
|
||||
factor_values = prices.pct_change(self.n_days)
|
||||
|
||||
# 应用崩盘过滤
|
||||
if self.crash_filter:
|
||||
factor_values = self._apply_crash_filter(prices, factor_values)
|
||||
|
||||
return factor_values
|
||||
|
||||
def _weighted_momentum_score(self, prices: np.ndarray) -> float:
|
||||
"""计算加权动量得分"""
|
||||
if len(prices) < 5:
|
||||
return 0.0
|
||||
|
||||
y = np.log(prices)
|
||||
x = np.arange(len(y))
|
||||
weights = np.linspace(1, 2, len(y))
|
||||
|
||||
# 加权线性回归
|
||||
slope, intercept = np.polyfit(x, y, 1, w=weights)
|
||||
annualized_returns = math.exp(slope * 250) - 1
|
||||
|
||||
# 加权R²
|
||||
y_pred = slope * x + intercept
|
||||
ss_res = np.sum(weights * (y - y_pred) ** 2)
|
||||
ss_tot = np.sum(weights * (y - np.average(y, weights=weights)) ** 2)
|
||||
r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
|
||||
|
||||
return annualized_returns * r2
|
||||
|
||||
def _apply_crash_filter(
|
||||
self,
|
||||
prices: pd.Series,
|
||||
factor_values: pd.Series
|
||||
) -> pd.Series:
|
||||
"""崩盘过滤:连续3天跌>5%清零"""
|
||||
result = factor_values.copy()
|
||||
|
||||
for i in range(3, len(prices)):
|
||||
r1 = prices.iloc[i] / prices.iloc[i-1]
|
||||
r2 = prices.iloc[i-1] / prices.iloc[i-2]
|
||||
r3 = prices.iloc[i-2] / prices.iloc[i-3]
|
||||
|
||||
# 条件1:任一天跌>5%
|
||||
con1 = min(r1, r2, r3) < 0.95
|
||||
# 条件2:连续下跌且累计跌>5%
|
||||
con2 = (r1 < 1) and (r2 < 1) and (r3 < 1) and (prices.iloc[i] / prices.iloc[i-3] < 0.95)
|
||||
|
||||
if con1 or con2:
|
||||
result.iloc[i] = 0.0
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class TrendFactor(FactorBase):
|
||||
"""
|
||||
趋势因子
|
||||
|
||||
计算趋势强度:
|
||||
- MA交叉偏离度
|
||||
- MACD趋势
|
||||
|
||||
参数:
|
||||
- method: 趋势方法('ma_cross', 'macd')
|
||||
- fast: 快线周期(默认5)
|
||||
- slow: 慢线周期(默认20)
|
||||
"""
|
||||
|
||||
name = "trend"
|
||||
category = "trend"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
method: str = 'ma_cross',
|
||||
fast: int = 5,
|
||||
slow: int = 20
|
||||
):
|
||||
super().__init__(method=method, fast=fast, slow=slow)
|
||||
self.method = method
|
||||
self.fast = fast
|
||||
self.slow = slow
|
||||
|
||||
def compute(self, data: pd.DataFrame) -> pd.Series:
|
||||
"""计算趋势因子值"""
|
||||
if 'close' not in data.columns:
|
||||
raise ValueError("data must contain 'close' column")
|
||||
|
||||
prices = data['close']
|
||||
|
||||
if self.method == 'ma_cross':
|
||||
# MA交叉偏离度
|
||||
fast_ma = prices.rolling(self.fast).mean()
|
||||
slow_ma = prices.rolling(self.slow).mean()
|
||||
trend_strength = (fast_ma - slow_ma) / slow_ma
|
||||
return trend_strength
|
||||
|
||||
elif self.method == 'macd':
|
||||
# MACD趋势
|
||||
ema12 = prices.ewm(span=12).mean()
|
||||
ema26 = prices.ewm(span=26).mean()
|
||||
macd = ema12 - ema26
|
||||
signal = macd.ewm(span=9).mean()
|
||||
return macd - signal
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown method: {self.method}")
|
||||
|
||||
|
||||
class ReversalFactor(FactorBase):
|
||||
"""
|
||||
反转因子
|
||||
|
||||
计算超买超卖信号:
|
||||
- RSI偏离度
|
||||
- KDJ
|
||||
|
||||
参数:
|
||||
- method: 反转方法('rsi', 'kdj')
|
||||
- period: 周期(默认14)
|
||||
- overbought: 超买阈值(默认70)
|
||||
- oversold: 超卖阈值(默认30)
|
||||
"""
|
||||
|
||||
name = "reversal"
|
||||
category = "reversal"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
method: str = 'rsi',
|
||||
period: int = 14,
|
||||
overbought: float = 70,
|
||||
oversold: float = 30
|
||||
):
|
||||
super().__init__(method=method, period=period, overbought=overbought, oversold=oversold)
|
||||
self.method = method
|
||||
self.period = period
|
||||
self.overbought = overbought
|
||||
self.oversold = oversold
|
||||
|
||||
def compute(self, data: pd.DataFrame) -> pd.Series:
|
||||
"""计算反转因子值"""
|
||||
if 'close' not in data.columns:
|
||||
raise ValueError("data must contain 'close' column")
|
||||
|
||||
prices = data['close']
|
||||
|
||||
if self.method == 'rsi':
|
||||
# RSI反转信号
|
||||
rsi = self._compute_rsi(prices, self.period)
|
||||
|
||||
# 超买超卖偏离度
|
||||
# 超买 → 负值(反转向下信号)
|
||||
# 超卖 → 正值(反转向上信号)
|
||||
reversal_signal = pd.Series(index=prices.index, dtype=float)
|
||||
reversal_signal = np.where(
|
||||
rsi > self.overbought,
|
||||
-(rsi - self.overbought) / (100 - self.overbought), # 超买:负值
|
||||
np.where(
|
||||
rsi < self.oversold,
|
||||
(self.oversold - rsi) / self.oversold, # 超卖:正值
|
||||
0 # 正常区间:0
|
||||
)
|
||||
)
|
||||
return pd.Series(reversal_signal, index=prices.index)
|
||||
|
||||
elif self.method == 'kdj':
|
||||
# KDJ反转信号
|
||||
return self._compute_kdj(data)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown method: {self.method}")
|
||||
|
||||
def _compute_rsi(self, prices: pd.Series, period: int) -> pd.Series:
|
||||
"""计算RSI"""
|
||||
delta = prices.diff()
|
||||
gain = delta.where(delta > 0, 0)
|
||||
loss = (-delta).where(delta < 0, 0)
|
||||
|
||||
avg_gain = gain.rolling(period).mean()
|
||||
avg_loss = loss.rolling(period).mean()
|
||||
|
||||
rs = avg_gain / avg_loss
|
||||
rsi = 100 - (100 / (1 + rs))
|
||||
return rsi
|
||||
|
||||
def _compute_kdj(self, data: pd.DataFrame) -> pd.Series:
|
||||
"""计算KDJ反转信号"""
|
||||
low = data['low']
|
||||
high = data['high']
|
||||
close = data['close']
|
||||
|
||||
# 计算K、D、J
|
||||
low_min = low.rolling(self.period).min()
|
||||
high_max = high.rolling(self.period).max()
|
||||
|
||||
rsv = (close - low_min) / (high_max - low_min) * 100
|
||||
|
||||
k = rsv.ewm(alpha=1/3).mean()
|
||||
d = k.ewm(alpha=1/3).mean()
|
||||
j = 3 * k - 2 * d
|
||||
|
||||
# J值偏离度作为反转信号
|
||||
return j
|
||||
|
||||
|
||||
class VolatilityFactor(FactorBase):
|
||||
"""
|
||||
波动率因子
|
||||
|
||||
计算价格波动率:
|
||||
- ATR
|
||||
- 标准差
|
||||
|
||||
参数:
|
||||
- method: 波动率方法('atr', 'std')
|
||||
- period: 周期(默认20)
|
||||
"""
|
||||
|
||||
name = "volatility"
|
||||
category = "volatility"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
method: str = 'std',
|
||||
period: int = 20
|
||||
):
|
||||
super().__init__(method=method, period=period)
|
||||
self.method = method
|
||||
self.period = period
|
||||
|
||||
def compute(self, data: pd.DataFrame) -> pd.Series:
|
||||
"""计算波动率因子值"""
|
||||
if self.method == 'std':
|
||||
# 标准差波动率
|
||||
return data['close'].rolling(self.period).std()
|
||||
|
||||
elif self.method == 'atr':
|
||||
# ATR波动率
|
||||
return self._compute_atr(data)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown method: {self.method}")
|
||||
|
||||
def _compute_atr(self, data: pd.DataFrame) -> pd.Series:
|
||||
"""计算ATR"""
|
||||
high = data['high']
|
||||
low = data['low']
|
||||
close = data['close']
|
||||
|
||||
prev_close = close.shift(1)
|
||||
tr = pd.concat([
|
||||
high - low,
|
||||
(high - prev_close).abs(),
|
||||
(low - prev_close).abs()
|
||||
], axis=1).max(axis=1)
|
||||
|
||||
return tr.rolling(self.period).mean()
|
||||
@@ -1,351 +1,188 @@
|
||||
"""
|
||||
回调钩子与风控层设计
|
||||
风控层抽象接口(通用)
|
||||
|
||||
核心组件:
|
||||
- RiskControl: 风控抽象基类
|
||||
- StopLossControl: 止损控制
|
||||
- PositionLimitControl: 仓位限制控制
|
||||
- CallbackHook: 回调钩子管理
|
||||
只提供抽象基类和回调机制,具体风控组件在strategies/shared/risk/
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional, List
|
||||
from typing import Dict, List, Any, Callable, Optional
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class Position:
|
||||
"""持仓信息"""
|
||||
"""
|
||||
持仓数据结构(通用)
|
||||
|
||||
用于表示单个持仓的状态
|
||||
"""
|
||||
code: str
|
||||
entry_price: float
|
||||
entry_date: datetime
|
||||
current_price: float
|
||||
current_date: datetime
|
||||
quantity: float
|
||||
weight: float
|
||||
entry_time: datetime
|
||||
quantity: float = 1.0
|
||||
weight: float = 1.0
|
||||
|
||||
@property
|
||||
def profit_ratio(self) -> float:
|
||||
"""盈亏比例"""
|
||||
"""计算盈亏比例"""
|
||||
return (self.current_price - self.entry_price) / self.entry_price
|
||||
|
||||
@property
|
||||
def holding_days(self) -> int:
|
||||
"""持仓天数"""
|
||||
return (self.current_date - self.entry_date).days
|
||||
def profit_amount(self) -> float:
|
||||
"""计算盈亏金额"""
|
||||
return (self.current_price - self.entry_price) * self.quantity
|
||||
|
||||
@property
|
||||
def is_profit(self) -> bool:
|
||||
"""是否盈利"""
|
||||
return self.profit_ratio > 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class Trade:
|
||||
"""交易信息"""
|
||||
code: str
|
||||
direction: str # 'entry' or 'exit'
|
||||
price: float
|
||||
date: datetime
|
||||
quantity: float
|
||||
reason: str = ""
|
||||
def holding_days(self) -> int:
|
||||
"""计算持仓天数"""
|
||||
if self.entry_time is None:
|
||||
return 0
|
||||
return (datetime.now() - self.entry_time).days
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Position(code={self.code}, profit={self.profit_ratio:.2%}, days={self.holding_days})"
|
||||
|
||||
|
||||
class RiskControl(ABC):
|
||||
"""
|
||||
风控抽象基类
|
||||
风控组件抽象基类
|
||||
|
||||
所有风控组件必须继承此基类。
|
||||
所有风控组件必须实现check方法
|
||||
"""
|
||||
|
||||
name: str = "base"
|
||||
|
||||
def __init__(self, **params):
|
||||
"""初始化风控参数"""
|
||||
self._params = params
|
||||
|
||||
@abstractmethod
|
||||
def check(self, position: Optional[Position], **kwargs) -> bool:
|
||||
def check(self, position: Position, **kwargs) -> bool:
|
||||
"""
|
||||
风控检查
|
||||
检查风控条件
|
||||
|
||||
Args:
|
||||
position: 持仓信息(可选)
|
||||
**kwargs: 其他参数
|
||||
position: 持仓对象
|
||||
kwargs: 额外参数(如premium、history等)
|
||||
|
||||
Returns:
|
||||
是否通过检查
|
||||
True表示通过检查,False表示触发风控
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, position: Position) -> Optional[float]:
|
||||
def apply(self, position: Position) -> Any:
|
||||
"""
|
||||
应用风控
|
||||
应用风控规则(可选)
|
||||
|
||||
Args:
|
||||
position: 持仓信息
|
||||
position: 持仓对象
|
||||
|
||||
Returns:
|
||||
应用结果(如止损价格、仓位调整比例等)
|
||||
风控结果(如止损价格、建议仓位等)
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def params(self) -> Dict[str, Any]:
|
||||
return self._params
|
||||
|
||||
|
||||
class StopLossControl(RiskControl):
|
||||
"""
|
||||
止损控制
|
||||
|
||||
参数:
|
||||
- threshold: 止损阈值(默认-0.05)
|
||||
- trailing: 是否跟踪止损(默认False)
|
||||
- trailing_percent: 跟踪止损比例(默认0.03)
|
||||
"""
|
||||
|
||||
name = "stop_loss"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
threshold: float = -0.05,
|
||||
trailing: bool = False,
|
||||
trailing_percent: float = 0.03
|
||||
):
|
||||
super().__init__(
|
||||
threshold=threshold,
|
||||
trailing=trailing,
|
||||
trailing_percent=trailing_percent
|
||||
)
|
||||
self.threshold = threshold
|
||||
self.trailing = trailing
|
||||
self.trailing_percent = trailing_percent
|
||||
self._highest_price = {} # 跟踪最高价
|
||||
|
||||
def check(self, position: Optional[Position], **kwargs) -> bool:
|
||||
"""检查是否触发止损"""
|
||||
if position is None:
|
||||
return True
|
||||
|
||||
# 更新最高价(跟踪止损)
|
||||
if self.trailing:
|
||||
if position.code not in self._highest_price:
|
||||
self._highest_price[position.code] = position.entry_price
|
||||
self._highest_price[position.code] = max(
|
||||
self._highest_price[position.code],
|
||||
position.current_price
|
||||
)
|
||||
|
||||
# 检查止损
|
||||
if self.trailing:
|
||||
# 跟踪止损:从最高价回撤超过阈值
|
||||
highest = self._highest_price[position.code]
|
||||
drawdown = (position.current_price - highest) / highest
|
||||
return drawdown > -self.trailing_percent
|
||||
else:
|
||||
# 固定止损:从入场价亏损超过阈值
|
||||
return position.profit_ratio > self.threshold
|
||||
|
||||
def apply(self, position: Position) -> Optional[float]:
|
||||
"""返回止损价格"""
|
||||
if self.trailing:
|
||||
highest = self._highest_price.get(position.code, position.entry_price)
|
||||
return highest * (1 - self.trailing_percent)
|
||||
else:
|
||||
return position.entry_price * (1 + self.threshold)
|
||||
|
||||
|
||||
class PositionLimitControl(RiskControl):
|
||||
"""
|
||||
仓位限制控制
|
||||
|
||||
参数:
|
||||
- max_position: 单品种最大仓位(默认0.33)
|
||||
- max_total: 总仓位上限(默认1.0)
|
||||
"""
|
||||
|
||||
name = "position_limit"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_position: float = 0.33,
|
||||
max_total: float = 1.0
|
||||
):
|
||||
super().__init__(
|
||||
max_position=max_position,
|
||||
max_total=max_total
|
||||
)
|
||||
self.max_position = max_position
|
||||
self.max_total = max_total
|
||||
|
||||
def check(self, position: Optional[Position], **kwargs) -> bool:
|
||||
"""检查仓位是否超限"""
|
||||
if position is None:
|
||||
return True
|
||||
|
||||
# 检查单品种仓位
|
||||
if position.weight > self.max_position:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def apply(self, position: Position) -> Optional[float]:
|
||||
"""返回建议仓位"""
|
||||
return min(position.weight, self.max_position)
|
||||
|
||||
|
||||
class PremiumControl(RiskControl):
|
||||
"""
|
||||
溢价控制
|
||||
|
||||
参数:
|
||||
- threshold: 溢价阈值(默认0.10)
|
||||
- mode: 控制模式('filter'或'penalize')
|
||||
"""
|
||||
|
||||
name = "premium"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
threshold: float = 0.10,
|
||||
mode: str = 'filter'
|
||||
):
|
||||
super().__init__(
|
||||
threshold=threshold,
|
||||
mode=mode
|
||||
)
|
||||
self.threshold = threshold
|
||||
self.mode = mode
|
||||
|
||||
def check(self, position: Optional[Position], **kwargs) -> bool:
|
||||
"""检查溢价是否超限"""
|
||||
premium = kwargs.get('premium', 0)
|
||||
|
||||
if self.mode == 'filter':
|
||||
# 完全排除
|
||||
return premium <= self.threshold
|
||||
else:
|
||||
# 仅降权,允许通过
|
||||
return True
|
||||
|
||||
def apply(self, position: Position) -> Optional[float]:
|
||||
"""返回溢价惩罚系数"""
|
||||
if self.mode == 'penalize':
|
||||
return 0.5 # 降权50%
|
||||
return None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
|
||||
return f"{self.__class__.__name__}({params_str})"
|
||||
|
||||
|
||||
class CallbackHook:
|
||||
"""
|
||||
回调钩子管理
|
||||
回调钩子管理(通用)
|
||||
|
||||
支持策略生命周期回调:
|
||||
- before_entry: 入场前检查
|
||||
- after_entry: 入场后处理
|
||||
- before_exit: 出场前检查
|
||||
- after_exit: 出场后处理
|
||||
- dynamic_stoploss: 动态止损
|
||||
- custom_exit: 自定义出场
|
||||
支持在策略生命周期的关键节点注入自定义逻辑
|
||||
"""
|
||||
|
||||
SUPPORTED_HOOKS = [
|
||||
'before_entry', # 入场前检查
|
||||
'after_entry', # 入场后处理
|
||||
'before_exit', # 出场前检查
|
||||
'after_exit', # 出场后处理
|
||||
'dynamic_stoploss', # 动态止损计算
|
||||
'custom_exit' # 自定义出场条件
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self._hooks = {
|
||||
'before_entry': [],
|
||||
'after_entry': [],
|
||||
'before_exit': [],
|
||||
'after_exit': [],
|
||||
'dynamic_stoploss': [],
|
||||
'custom_exit': []
|
||||
"""初始化回调钩子"""
|
||||
self._hooks: Dict[str, List[Callable]] = {
|
||||
hook: [] for hook in self.SUPPORTED_HOOKS
|
||||
}
|
||||
|
||||
def register(self, hook_name: str, callback: callable) -> None:
|
||||
"""注册回调"""
|
||||
def register(self, hook_name: str, callback: Callable) -> None:
|
||||
"""注册回调函数"""
|
||||
if hook_name not in self._hooks:
|
||||
raise ValueError(f"Unknown hook: {hook_name}")
|
||||
raise ValueError(f"Unsupported hook: {hook_name}")
|
||||
|
||||
self._hooks[hook_name].append(callback)
|
||||
|
||||
def trigger(self, hook_name: str, *args, **kwargs) -> Any:
|
||||
"""触发回调"""
|
||||
"""
|
||||
触发回调
|
||||
|
||||
Args:
|
||||
hook_name: 钩子名称
|
||||
args: 位置参数
|
||||
kwargs: 关键字参数
|
||||
|
||||
Returns:
|
||||
回调结果(根据钩子类型返回不同结果)
|
||||
"""
|
||||
if hook_name not in self._hooks:
|
||||
raise ValueError(f"Unknown hook: {hook_name}")
|
||||
raise ValueError(f"Unsupported hook: {hook_name}")
|
||||
|
||||
callbacks = self._hooks[hook_name]
|
||||
|
||||
if not callbacks:
|
||||
# 默认行为
|
||||
if hook_name == 'dynamic_stoploss':
|
||||
return kwargs.get('default_stoploss', -0.05)
|
||||
elif hook_name in ['before_entry', 'before_exit']:
|
||||
return True
|
||||
elif hook_name == 'custom_exit':
|
||||
return False
|
||||
return None
|
||||
|
||||
results = []
|
||||
for callback in self._hooks[hook_name]:
|
||||
try:
|
||||
result = callback(*args, **kwargs)
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
print(f"⚠ Callback error: {e}")
|
||||
for callback in callbacks:
|
||||
result = callback(*args, **kwargs)
|
||||
results.append(result)
|
||||
|
||||
# before_entry和before_exit需要所有回调返回True
|
||||
# 根据钩子类型返回不同结果
|
||||
if hook_name in ['before_entry', 'before_exit']:
|
||||
# 所有回调返回True才允许
|
||||
return all(results)
|
||||
|
||||
# dynamic_stoploss返回最小的止损值
|
||||
if hook_name == 'dynamic_stoploss':
|
||||
return min(results) if results else -0.05
|
||||
# 返回最小止损值(最严格)
|
||||
return min(results)
|
||||
|
||||
# custom_exit返回是否有任一回调触发出场
|
||||
if hook_name == 'custom_exit':
|
||||
# 任一回调触发出场
|
||||
return any(results)
|
||||
|
||||
return results
|
||||
# 其他钩子返回最后一个结果
|
||||
return results[-1] if results else None
|
||||
|
||||
def clear(self, hook_name: str = None) -> None:
|
||||
def clear(self, hook_name: Optional[str] = None) -> None:
|
||||
"""清空回调"""
|
||||
if hook_name:
|
||||
self._hooks[hook_name] = []
|
||||
if hook_name in self._hooks:
|
||||
self._hooks[hook_name] = []
|
||||
else:
|
||||
for key in self._hooks:
|
||||
self._hooks[key] = []
|
||||
for hook in self._hooks:
|
||||
self._hooks[hook] = []
|
||||
|
||||
def list_hooks(self) -> List[str]:
|
||||
"""列出支持的钩子"""
|
||||
return self.SUPPORTED_HOOKS
|
||||
|
||||
def get_callbacks(self, hook_name: str) -> List[Callable]:
|
||||
"""获取钩子的所有回调"""
|
||||
return self._hooks.get(hook_name, [])
|
||||
|
||||
|
||||
# 便捷回调函数
|
||||
def premium_filter_callback(threshold: float = 0.10):
|
||||
"""溢价过滤回调"""
|
||||
def callback(code: str, price: float, **kwargs) -> bool:
|
||||
premium = kwargs.get('premium', 0)
|
||||
if premium > threshold:
|
||||
print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})")
|
||||
return False
|
||||
return True
|
||||
return callback
|
||||
|
||||
|
||||
def crash_filter_callback(lookback: int = 3, crash_threshold: float = 0.05):
|
||||
"""崩盘过滤回调"""
|
||||
def callback(code: str, price: float, **kwargs) -> bool:
|
||||
history = kwargs.get('history', None)
|
||||
if history is None:
|
||||
return True
|
||||
|
||||
# 检查最近N天是否有崩盘
|
||||
recent = history.tail(lookback)
|
||||
if len(recent) < lookback:
|
||||
return True
|
||||
|
||||
returns = recent['close'].pct_change()
|
||||
min_return = returns.min()
|
||||
|
||||
if min_return < -crash_threshold:
|
||||
print(f"崩盘检测,拒绝入场: {code} (最大跌幅={min_return:.2%})")
|
||||
return False
|
||||
return True
|
||||
return callback
|
||||
|
||||
|
||||
def holding_time_stoploss_callback(
|
||||
day_5_stoploss: float = -0.05,
|
||||
day_10_stoploss: float = -0.03
|
||||
):
|
||||
"""持仓时间动态止损回调"""
|
||||
def callback(position: Position) -> float:
|
||||
if position.holding_days >= 10:
|
||||
return day_10_stoploss # 10天后收紧止损
|
||||
elif position.holding_days >= 5:
|
||||
return day_5_stoploss
|
||||
return -0.10 # 默认止损
|
||||
return callback
|
||||
# 导出抽象接口
|
||||
__all__ = ['Position', 'RiskControl', 'CallbackHook']
|
||||
@@ -1,52 +1,26 @@
|
||||
"""
|
||||
信号层抽象设计
|
||||
信号层抽象接口(通用)
|
||||
|
||||
核心组件:
|
||||
- SignalGenerator: 信号生成器抽象基类
|
||||
- TopNSelector: Top N选股器(轮动策略)
|
||||
- TrendFollower: 趋势跟随器(趋势策略)
|
||||
- ReversalTrader: 反转交易器(反转策略)
|
||||
只提供抽象基类,具体信号生成器在strategies/shared/signals/
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class SignalMeta:
|
||||
"""信号元信息"""
|
||||
mode: str # 'top_n', 'trend', 'reversal'
|
||||
select_num: int
|
||||
description: str = ""
|
||||
from typing import List, Optional, Any
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class SignalGenerator(ABC):
|
||||
"""
|
||||
信号生成器抽象基类
|
||||
|
||||
所有信号生成器必须继承此基类,实现generate方法。
|
||||
支持不同策略类型的信号生成逻辑。
|
||||
所有信号生成器必须实现generate方法
|
||||
"""
|
||||
|
||||
# 类属性(可被配置覆盖)
|
||||
mode: str = "base"
|
||||
|
||||
def __init__(self, **params):
|
||||
"""
|
||||
初始化信号生成器
|
||||
|
||||
Args:
|
||||
**params: 信号参数
|
||||
"""
|
||||
"""初始化信号生成器参数"""
|
||||
self._params = params
|
||||
self._meta = SignalMeta(
|
||||
mode=self.mode,
|
||||
select_num=params.get('select_num', 1),
|
||||
description=self.__doc__ or ""
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
|
||||
@@ -54,300 +28,27 @@ class SignalGenerator(ABC):
|
||||
生成交易信号
|
||||
|
||||
Args:
|
||||
factor_data: 包含因子值的DataFrame
|
||||
factor_data: 因子数据DataFrame
|
||||
|
||||
Returns:
|
||||
包含信号列的DataFrame
|
||||
包含'signal'列的DataFrame
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def params(self) -> Dict[str, Any]:
|
||||
"""获取信号参数"""
|
||||
return self._params
|
||||
|
||||
@property
|
||||
def meta(self) -> SignalMeta:
|
||||
"""获取信号元信息"""
|
||||
return self._meta
|
||||
def validate_factor_data(self, factor_data: pd.DataFrame) -> bool:
|
||||
"""验证因子数据是否有效"""
|
||||
if factor_data.empty:
|
||||
return False
|
||||
|
||||
if 'signal' in factor_data.columns:
|
||||
print("Warning: factor_data already contains 'signal' column")
|
||||
|
||||
return True
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(mode={self.mode}, params={self._params})"
|
||||
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
|
||||
return f"{self.__class__.__name__}({params_str})"
|
||||
|
||||
|
||||
class TopNSelector(SignalGenerator):
|
||||
"""
|
||||
Top N选股器
|
||||
|
||||
用于轮动策略:
|
||||
- 按因子值排序,选出Top N标的
|
||||
- 支持分组选股(先类内竞争,再跨类排序)
|
||||
|
||||
参数:
|
||||
- select_num: 选中数量(默认3)
|
||||
- group_by: 分组列名(可选,如'market')
|
||||
- top_per_group: 每组选中数量(默认1)
|
||||
- min_score: 最小得分阈值(可选)
|
||||
"""
|
||||
|
||||
mode = "top_n"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
select_num: int = 3,
|
||||
group_by: Optional[str] = None,
|
||||
top_per_group: int = 1,
|
||||
min_score: Optional[float] = None
|
||||
):
|
||||
super().__init__(
|
||||
select_num=select_num,
|
||||
group_by=group_by,
|
||||
top_per_group=top_per_group,
|
||||
min_score=min_score
|
||||
)
|
||||
self.select_num = select_num
|
||||
self.group_by = group_by
|
||||
self.top_per_group = top_per_group
|
||||
self.min_score = min_score
|
||||
|
||||
def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""生成Top N选股信号"""
|
||||
result = pd.DataFrame(index=factor_data.index)
|
||||
|
||||
# 获取因子列(排除非因子列)
|
||||
factor_cols = self._get_factor_columns(factor_data)
|
||||
|
||||
if not factor_cols:
|
||||
print("⚠ 未找到因子列")
|
||||
result['signal'] = ''
|
||||
return result
|
||||
|
||||
# 每日选股
|
||||
signals = []
|
||||
for date in factor_data.index:
|
||||
row = factor_data.loc[date]
|
||||
|
||||
# 提取当日因子值
|
||||
scores = {}
|
||||
for col in factor_cols:
|
||||
score = row[col]
|
||||
if pd.notna(score):
|
||||
scores[col] = score
|
||||
|
||||
# 应用最小得分过滤
|
||||
if self.min_score:
|
||||
scores = {k: v for k, v in scores.items() if v >= self.min_score}
|
||||
|
||||
# 选股逻辑
|
||||
if self.group_by and 'group_info' in factor_data.columns:
|
||||
# 分组选股:先类内竞争,再跨类排序
|
||||
selected = self._grouped_selection(scores, factor_data.loc[date])
|
||||
else:
|
||||
# 全局Top N
|
||||
selected = self._global_top_n(scores)
|
||||
|
||||
# 信号格式:逗号分隔的代码列表
|
||||
signals.append(','.join(selected) if selected else '')
|
||||
|
||||
result['signal'] = signals
|
||||
result['signal_raw'] = signals # 原始信号(未shift)
|
||||
|
||||
# T+1执行:信号向后移位1天
|
||||
result['signal'] = result['signal'].shift(1)
|
||||
|
||||
return result
|
||||
|
||||
def _get_factor_columns(self, data: pd.DataFrame) -> List[str]:
|
||||
"""获取因子列名"""
|
||||
# 排除已知非因子列
|
||||
exclude_cols = ['signal', 'signal_raw', 'group_info', 'combined', 'open', 'high', 'low', 'close', 'volume']
|
||||
factor_cols = [col for col in data.columns if col not in exclude_cols and not col.endswith('_weighted')]
|
||||
return factor_cols
|
||||
|
||||
def _global_top_n(self, scores: Dict[str, float]) -> List[str]:
|
||||
"""全局Top N选股"""
|
||||
if not scores:
|
||||
return []
|
||||
|
||||
# 按得分排序
|
||||
sorted_items = sorted(scores.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
# 选Top N
|
||||
selected = [item[0] for item in sorted_items[:self.select_num]]
|
||||
return selected
|
||||
|
||||
def _grouped_selection(
|
||||
self,
|
||||
scores: Dict[str, float],
|
||||
row: pd.Series
|
||||
) -> List[str]:
|
||||
"""分组选股:先类内竞争,再跨类排序"""
|
||||
if 'group_info' not in row.index:
|
||||
return self._global_top_n(scores)
|
||||
|
||||
group_info = row['group_info']
|
||||
if pd.isna(group_info):
|
||||
return self._global_top_n(scores)
|
||||
|
||||
# 解析分组信息:{code: group}
|
||||
groups = group_info if isinstance(group_info, dict) else {}
|
||||
|
||||
# 类内竞争:每组选Top1
|
||||
group_champions = {}
|
||||
for code, score in scores.items():
|
||||
group = groups.get(code, 'default')
|
||||
if group not in group_champions or score > group_champions[group][1]:
|
||||
group_champions[group] = (code, score)
|
||||
|
||||
# 跨类排序:从冠军中选Top N
|
||||
champions_scores = {code: score for code, score in group_champions.values()}
|
||||
return self._global_top_n(champions_scores)
|
||||
|
||||
|
||||
class TrendFollower(SignalGenerator):
|
||||
"""
|
||||
趋势跟随器
|
||||
|
||||
用于趋势跟踪策略:
|
||||
- 趋势强度 > 入场阈值 → 入场信号
|
||||
- 趋势强度 < 出场阈值 → 出场信号
|
||||
|
||||
参数:
|
||||
- entry_threshold: 入场阈值(默认0.02)
|
||||
- exit_threshold: 出场阈值(默认-0.02)
|
||||
- select_num: 最大持仓数量(默认1)
|
||||
"""
|
||||
|
||||
mode = "trend"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entry_threshold: float = 0.02,
|
||||
exit_threshold: float = -0.02,
|
||||
select_num: int = 1
|
||||
):
|
||||
super().__init__(
|
||||
entry_threshold=entry_threshold,
|
||||
exit_threshold=exit_threshold,
|
||||
select_num=select_num
|
||||
)
|
||||
self.entry_threshold = entry_threshold
|
||||
self.exit_threshold = exit_threshold
|
||||
self.select_num = select_num
|
||||
|
||||
def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""生成趋势跟随信号"""
|
||||
result = pd.DataFrame(index=factor_data.index)
|
||||
|
||||
factor_cols = self._get_factor_columns(factor_data)
|
||||
|
||||
for col in factor_cols:
|
||||
trend_strength = factor_data[col]
|
||||
|
||||
# 入场信号:趋势强度 > 阈值
|
||||
result[f'{col}_entry'] = trend_strength > self.entry_threshold
|
||||
|
||||
# 出场信号:趋势强度 < 阈值
|
||||
result[f'{col}_exit'] = trend_strength < self.exit_threshold
|
||||
|
||||
# 综合信号:入场强度最高的Top N
|
||||
signals = []
|
||||
for date in result.index:
|
||||
entry_signals = []
|
||||
for col in factor_cols:
|
||||
if result.loc[date, f'{col}_entry']:
|
||||
score = factor_data.loc[date, col]
|
||||
if pd.notna(score):
|
||||
entry_signals.append((col, score))
|
||||
|
||||
# 按强度排序,选Top N
|
||||
entry_signals.sort(key=lambda x: x[1], reverse=True)
|
||||
selected = [item[0] for item in entry_signals[:self.select_num]]
|
||||
signals.append(','.join(selected) if selected else '')
|
||||
|
||||
result['signal'] = signals
|
||||
result['signal'] = result['signal'].shift(1) # T+1执行
|
||||
|
||||
return result
|
||||
|
||||
def _get_factor_columns(self, data: pd.DataFrame) -> List[str]:
|
||||
"""获取因子列名"""
|
||||
exclude_cols = ['signal', 'signal_raw', 'combined', 'open', 'high', 'low', 'close', 'volume']
|
||||
return [col for col in data.columns if col not in exclude_cols and not col.endswith('_weighted')]
|
||||
|
||||
|
||||
class ReversalTrader(SignalGenerator):
|
||||
"""
|
||||
反转交易器
|
||||
|
||||
用于反转策略:
|
||||
- 超买区域(RSI>70) → 反转向下信号(卖出)
|
||||
- 超卖区域(RSI<30) → 反转向上信号(买入)
|
||||
|
||||
参数:
|
||||
- overbought: 超买阈值(默认70)
|
||||
- oversold: 超卖阈值(默认30)
|
||||
- reversal_threshold: 反转信号强度阈值(默认0.1)
|
||||
"""
|
||||
|
||||
mode = "reversal"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
overbought: float = 70,
|
||||
oversold: float = 30,
|
||||
reversal_threshold: float = 0.1
|
||||
):
|
||||
super().__init__(
|
||||
overbought=overbought,
|
||||
oversold=oversold,
|
||||
reversal_threshold=reversal_threshold
|
||||
)
|
||||
self.overbought = overbought
|
||||
self.oversold = oversold
|
||||
self.reversal_threshold = reversal_threshold
|
||||
|
||||
def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""生成反转交易信号"""
|
||||
result = pd.DataFrame(index=factor_data.index)
|
||||
|
||||
factor_cols = self._get_factor_columns(factor_data)
|
||||
|
||||
for col in factor_cols:
|
||||
reversal_signal = factor_data[col]
|
||||
|
||||
# 买入信号:反转信号 > 阈值(正值,超卖反转)
|
||||
result[f'{col}_buy'] = reversal_signal > self.reversal_threshold
|
||||
|
||||
# 卖出信号:反转信号 < -阈值(负值,超买反转)
|
||||
result[f'{col}_sell'] = reversal_signal < -self.reversal_threshold
|
||||
|
||||
# 综合信号
|
||||
signals = []
|
||||
for date in result.index:
|
||||
buy_signals = []
|
||||
sell_signals = []
|
||||
|
||||
for col in factor_cols:
|
||||
if result.loc[date, f'{col}_buy']:
|
||||
buy_signals.append(col)
|
||||
if result.loc[date, f'{col}_sell']:
|
||||
sell_signals.append(col)
|
||||
|
||||
# 信号格式:'BUY:code1,code2' 或 'SELL:code1' 或 ''
|
||||
if buy_signals:
|
||||
signals.append(f"BUY:{','.join(buy_signals)}")
|
||||
elif sell_signals:
|
||||
signals.append(f"SELL:{','.join(sell_signals)}")
|
||||
else:
|
||||
signals.append('')
|
||||
|
||||
result['signal'] = signals
|
||||
result['signal'] = result['signal'].shift(1) # T+1执行
|
||||
|
||||
return result
|
||||
|
||||
def _get_factor_columns(self, data: pd.DataFrame) -> List[str]:
|
||||
"""获取因子列名"""
|
||||
exclude_cols = ['signal', 'signal_raw', 'combined', 'open', 'high', 'low', 'close', 'volume']
|
||||
return [col for col in data.columns if col not in exclude_cols and not col.endswith('_weighted')]
|
||||
# 导出抽象接口
|
||||
__all__ = ['SignalGenerator']
|
||||
@@ -1,63 +1,30 @@
|
||||
"""
|
||||
策略基类与配置层
|
||||
策略层抽象基类(通用)
|
||||
|
||||
核心组件:
|
||||
- StrategyBase: 策略抽象基类(含回调钩子)
|
||||
- ConfigLoader: 配置加载器
|
||||
只提供抽象接口,具体策略实现在strategies/
|
||||
"""
|
||||
|
||||
import yaml
|
||||
import pandas as pd
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional, List
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Any
|
||||
import pandas as pd
|
||||
|
||||
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
|
||||
from framework.factors.momentum import MomentumFactor
|
||||
from framework.signals import SignalGenerator, TopNSelector
|
||||
from framework.factors import FactorCombiner
|
||||
from framework.signals import SignalGenerator
|
||||
from framework.risk import CallbackHook, Position
|
||||
|
||||
|
||||
@dataclass
|
||||
class StrategyConfig:
|
||||
"""策略配置"""
|
||||
name: str
|
||||
version: int
|
||||
factors: List[Dict]
|
||||
signal: Dict
|
||||
callbacks: Dict
|
||||
params: Dict
|
||||
|
||||
|
||||
class StrategyBase(ABC):
|
||||
"""
|
||||
策略抽象基类
|
||||
|
||||
融合Freqtrade回调机制 + 模块化因子设计
|
||||
|
||||
类属性(可被配置覆盖):
|
||||
- name: 策略名称
|
||||
- version: 接口版本
|
||||
- timeframe: K线周期
|
||||
- select_num: 选中数量
|
||||
- stoploss: 止损比例
|
||||
|
||||
回调钩子(可选实现):
|
||||
- before_entry: 入场前检查
|
||||
- after_entry: 入场后处理
|
||||
- before_exit: 出场前检查
|
||||
- after_exit: 出场后处理
|
||||
- dynamic_stoploss: 动态止损
|
||||
- custom_exit: 自定义出场条件
|
||||
所有策略必须实现init_factors和init_signal_generator方法
|
||||
"""
|
||||
|
||||
# 接口版本
|
||||
INTERFACE_VERSION = 1
|
||||
|
||||
# 类属性(可被配置覆盖)
|
||||
name: str = "base"
|
||||
timeframe: str = "1d"
|
||||
|
||||
# 类属性(可被配置覆盖)
|
||||
select_num: int = 3
|
||||
stoploss: float = -0.05
|
||||
|
||||
@@ -66,53 +33,50 @@ class StrategyBase(ABC):
|
||||
初始化策略
|
||||
|
||||
Args:
|
||||
config: 策略配置(可覆盖类属性)
|
||||
config: 配置字典(可选,用于覆盖类属性)
|
||||
"""
|
||||
# 配置覆盖类属性
|
||||
if config:
|
||||
self._apply_config(config)
|
||||
|
||||
# 初始化回调钩子
|
||||
self._callbacks = CallbackHook()
|
||||
self._register_default_callbacks()
|
||||
|
||||
# 初始化因子和信号生成器
|
||||
self._factors = None
|
||||
self._signal_gen = None
|
||||
self._factors = self.init_factors()
|
||||
self._signal_gen = self.init_signal_generator()
|
||||
|
||||
def _apply_config(self, config: Dict) -> None:
|
||||
"""应用配置"""
|
||||
params = config.get('params', {})
|
||||
|
||||
# 覆盖类属性
|
||||
for key, value in params.items():
|
||||
"""应用配置覆盖类属性"""
|
||||
for key, value in config.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
|
||||
# 保存完整配置
|
||||
self._config = config
|
||||
|
||||
def _register_default_callbacks(self) -> None:
|
||||
"""注册默认回调"""
|
||||
# 注册入场前回调(溢价过滤)
|
||||
"""注册默认回调方法"""
|
||||
if hasattr(self, 'before_entry'):
|
||||
self._callbacks.register('before_entry', self.before_entry)
|
||||
|
||||
# 注册动态止损回调
|
||||
if hasattr(self, 'after_entry'):
|
||||
self._callbacks.register('after_entry', self.after_entry)
|
||||
|
||||
if hasattr(self, 'before_exit'):
|
||||
self._callbacks.register('before_exit', self.before_exit)
|
||||
|
||||
if hasattr(self, 'after_exit'):
|
||||
self._callbacks.register('after_exit', self.after_exit)
|
||||
|
||||
if hasattr(self, 'dynamic_stoploss'):
|
||||
self._callbacks.register('dynamic_stoploss', self.dynamic_stoploss)
|
||||
|
||||
# 注册自定义出场回调
|
||||
if hasattr(self, 'custom_exit'):
|
||||
self._callbacks.register('custom_exit', self.custom_exit)
|
||||
|
||||
@abstractmethod
|
||||
def init_factors(self) -> FactorCombiner:
|
||||
"""
|
||||
初始化因子组合
|
||||
初始化因子组合器
|
||||
|
||||
Returns:
|
||||
因子组合器
|
||||
FactorCombiner实例
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -122,7 +86,7 @@ class StrategyBase(ABC):
|
||||
初始化信号生成器
|
||||
|
||||
Returns:
|
||||
信号生成器
|
||||
SignalGenerator实例
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -131,85 +95,28 @@ class StrategyBase(ABC):
|
||||
运行策略
|
||||
|
||||
Args:
|
||||
data: 输入数据
|
||||
data: OHLCV数据
|
||||
|
||||
Returns:
|
||||
包含信号的DataFrame
|
||||
"""
|
||||
# 初始化因子和信号生成器
|
||||
if self._factors is None:
|
||||
self._factors = self.init_factors()
|
||||
|
||||
if self._signal_gen is None:
|
||||
self._signal_gen = self.init_signal_generator()
|
||||
|
||||
# 1. 计算因子
|
||||
factor_data = self._factors.compute(data)
|
||||
|
||||
# 2. 生成信号
|
||||
signals = self._signal_gen.generate(factor_data)
|
||||
|
||||
# 3. 应用回调钩子
|
||||
signals = self._apply_callbacks(signals, data)
|
||||
|
||||
return signals
|
||||
|
||||
def _apply_callbacks(self, signals: pd.DataFrame, data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""应用回调钩子"""
|
||||
# 遍历每行信号
|
||||
for date in signals.index:
|
||||
signal = signals.loc[date, 'signal']
|
||||
|
||||
if not signal or pd.isna(signal):
|
||||
continue
|
||||
|
||||
# 解析信号(逗号分隔的代码)
|
||||
codes = signal.split(',')
|
||||
|
||||
# 应用入场前回调
|
||||
for code in codes:
|
||||
if code not in data.columns:
|
||||
continue
|
||||
|
||||
price = data.loc[date, code]
|
||||
premium = 0.0 # TODO: 从溢价数据获取
|
||||
|
||||
# 触发回调
|
||||
allowed = self._callbacks.trigger(
|
||||
'before_entry',
|
||||
code,
|
||||
price,
|
||||
premium=premium,
|
||||
history=data
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
# 移除被拒绝的代码
|
||||
codes.remove(code)
|
||||
|
||||
# 更新信号
|
||||
signals.loc[date, 'signal'] = ','.join(codes) if codes else ''
|
||||
|
||||
"""应用回调处理"""
|
||||
return signals
|
||||
|
||||
# ===== 可选回调方法 =====
|
||||
|
||||
# 可选回调方法(子类可覆盖)
|
||||
def before_entry(self, code: str, price: float, **kwargs) -> bool:
|
||||
"""
|
||||
入场前检查
|
||||
|
||||
Args:
|
||||
code: 标的代码
|
||||
price: 入场价格
|
||||
**kwargs: 其他参数(premium, history等)
|
||||
|
||||
Returns:
|
||||
是否允许入场
|
||||
"""
|
||||
# 默认:允许入场
|
||||
"""入场前检查"""
|
||||
return True
|
||||
|
||||
def after_entry(self, trade, **kwargs) -> None:
|
||||
def after_entry(self, code: str, price: float, **kwargs) -> None:
|
||||
"""入场后处理"""
|
||||
pass
|
||||
|
||||
@@ -217,141 +124,21 @@ class StrategyBase(ABC):
|
||||
"""出场前检查"""
|
||||
return True
|
||||
|
||||
def after_exit(self, trade, **kwargs) -> None:
|
||||
def after_exit(self, position: Position, **kwargs) -> None:
|
||||
"""出场后处理"""
|
||||
pass
|
||||
|
||||
def dynamic_stoploss(self, position: Position) -> float:
|
||||
"""
|
||||
动态止损
|
||||
|
||||
Args:
|
||||
position: 持仓信息
|
||||
|
||||
Returns:
|
||||
止损比例
|
||||
"""
|
||||
# 默认:返回固定止损
|
||||
"""动态止损"""
|
||||
return self.stoploss
|
||||
|
||||
def custom_exit(self, position: Position) -> bool:
|
||||
"""
|
||||
自定义出场条件
|
||||
|
||||
Args:
|
||||
position: 持仓信息
|
||||
|
||||
Returns:
|
||||
是否触发出场
|
||||
"""
|
||||
"""自定义出场条件"""
|
||||
return False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(name={self.name})"
|
||||
|
||||
|
||||
class ConfigLoader:
|
||||
"""
|
||||
配置加载器
|
||||
|
||||
支持YAML配置文件加载和验证
|
||||
"""
|
||||
|
||||
def __init__(self, config_path: str):
|
||||
"""
|
||||
初始化配置加载器
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径
|
||||
"""
|
||||
self._config_path = Path(config_path)
|
||||
self._config = None
|
||||
|
||||
def load(self) -> Dict:
|
||||
"""加载配置"""
|
||||
if not self._config_path.exists():
|
||||
raise FileNotFoundError(f"Config file not found: {self._config_path}")
|
||||
|
||||
with open(self._config_path, 'r', encoding='utf-8') as f:
|
||||
self._config = yaml.safe_load(f)
|
||||
|
||||
return self._config
|
||||
|
||||
def validate(self) -> bool:
|
||||
"""验证配置"""
|
||||
if self._config is None:
|
||||
self.load()
|
||||
|
||||
# 必须字段
|
||||
required_fields = ['strategy', 'factors', 'signal']
|
||||
|
||||
for field in required_fields:
|
||||
if field not in self._config:
|
||||
raise ValueError(f"Missing required field: {field}")
|
||||
|
||||
return True
|
||||
|
||||
def get_strategy_config(self) -> StrategyConfig:
|
||||
"""获取策略配置"""
|
||||
if self._config is None:
|
||||
self.load()
|
||||
|
||||
return StrategyConfig(
|
||||
name=self._config['strategy']['name'],
|
||||
version=self._config['strategy'].get('version', 1),
|
||||
factors=self._config['factors'],
|
||||
signal=self._config['signal'],
|
||||
callbacks=self._config.get('callbacks', {}),
|
||||
params=self._config.get('params', {})
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_yaml(yaml_str: str) -> Dict:
|
||||
"""从YAML字符串加载"""
|
||||
return yaml.safe_load(yaml_str)
|
||||
|
||||
|
||||
# 示例策略实现
|
||||
class RotationStrategy(StrategyBase):
|
||||
"""
|
||||
ETF轮动策略
|
||||
|
||||
基于动量因子 + Top N选股
|
||||
"""
|
||||
|
||||
name = "rotation"
|
||||
select_num = 3
|
||||
|
||||
def init_factors(self) -> FactorCombiner:
|
||||
"""初始化动量因子"""
|
||||
FactorRegistry.clear()
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
|
||||
return FactorCombiner([
|
||||
FactorRegistry.get('momentum', n_days=25, crash_filter=True)
|
||||
])
|
||||
|
||||
def init_signal_generator(self) -> SignalGenerator:
|
||||
"""初始化Top N选股器"""
|
||||
from framework.signals import TopNSelector
|
||||
|
||||
return TopNSelector(
|
||||
select_num=self.select_num,
|
||||
min_score=0.0
|
||||
)
|
||||
|
||||
def before_entry(self, code: str, price: float, **kwargs) -> bool:
|
||||
"""入场前:溢价过滤"""
|
||||
premium = kwargs.get('premium', 0)
|
||||
|
||||
# 溢价超过10%拒绝入场
|
||||
if premium > 0.10:
|
||||
print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def dynamic_stoploss(self, position: Position) -> float:
|
||||
"""动态止损:根据持仓时间调整"""
|
||||
if position.holding_days >= 10:
|
||||
return -0.03 # 10天后收紧止损
|
||||
elif position.holding_days >= 5:
|
||||
return -0.05
|
||||
return -0.10
|
||||
# 导出抽象接口
|
||||
__all__ = ['StrategyBase']
|
||||
Reference in New Issue
Block a user