diff --git a/framework/__init__.py b/framework/__init__.py index 9add992..5ab2962 100644 --- a/framework/__init__.py +++ b/framework/__init__.py @@ -1,21 +1,52 @@ """ -量化策略通用框架 +框架统一入口(通用) -融合Freqtrade回调机制 + 模块化因子设计 +只导出抽象接口,具体实现在strategies/shared/ """ -from .factors import FactorBase, FactorRegistry, FactorCombiner -from .signals import SignalGenerator, TopNSelector, TrendFollower, ReversalTrader -from .strategy import StrategyBase, RotationStrategy -from .risk import RiskControl, StopLossControl, PositionLimitControl -from .execution import Executor, BacktestExecutor, DryRunExecutor -from .config import ConfigLoader +# 因子层抽象 +from framework.factors import FactorBase, FactorRegistry, FactorCombiner + +# 信号层抽象 +from framework.signals import SignalGenerator + +# 风控层抽象 +from framework.risk import Position, RiskControl, CallbackHook + +# 策略层抽象 +from framework.strategy import StrategyBase + +# 执行层抽象 +from framework.execution import Portfolio, Executor, BacktestExecutor, DryRunExecutor + +# 配置层 +from framework.config import ConfigLoader, StrategyConfig + __all__ = [ - 'FactorBase', 'FactorRegistry', 'FactorCombiner', - 'SignalGenerator', 'TopNSelector', 'TrendFollower', 'ReversalTrader', - 'StrategyBase', 'RotationStrategy', - 'RiskControl', 'StopLossControl', 'PositionLimitControl', - 'Executor', 'BacktestExecutor', 'DryRunExecutor', + # 因子层 + 'FactorBase', + 'FactorRegistry', + 'FactorCombiner', + + # 信号层 + 'SignalGenerator', + + # 风控层 + 'Position', + 'RiskControl', + 'CallbackHook', + + # 策略层 + 'StrategyBase', + + # 执行层 + 'Portfolio', + 'Executor', + 'BacktestExecutor', + 'DryRunExecutor', + + # 配置层 'ConfigLoader', + 'StrategyConfig', ] \ No newline at end of file diff --git a/framework/config/__init__.py b/framework/config/__init__.py index 7246012..8683056 100644 --- a/framework/config/__init__.py +++ b/framework/config/__init__.py @@ -1,83 +1,103 @@ """ -配置层抽象设计 +配置层抽象接口(通用) -核心组件: -- ConfigLoader: 配置加载器 +只提供配置加载和验证机制 """ import yaml from typing import Dict, Any, Optional from pathlib import Path -from dataclasses import dataclass - - -@dataclass -class StrategyConfig: - """策略配置""" - name: str - version: int - factors: list - signal: dict - callbacks: dict - params: dict class ConfigLoader: """ - 配置加载器 + 配置加载器(通用) - 支持YAML配置文件加载和验证 + 支持从YAML文件加载配置 """ - def __init__(self, config_path: str): - """ - 初始化配置加载器 + def __init__(self, config_path: Optional[str] = None): + """初始化配置加载器""" + self._config: Dict[str, Any] = {} - Args: - config_path: 配置文件路径 - """ - self._config_path = Path(config_path) - self._config = None + if config_path: + self.load(config_path) - def load(self) -> Dict: - """加载配置""" - if not self._config_path.exists(): - raise FileNotFoundError(f"Config file not found: {self._config_path}") + def load(self, config_path: str) -> Dict[str, Any]: + """从YAML文件加载配置""" + path = Path(config_path) - with open(self._config_path, 'r', encoding='utf-8') as f: - self._config = yaml.safe_load(f) + if not path.exists(): + raise FileNotFoundError(f"Config file not found: {config_path}") + + with open(path, 'r', encoding='utf-8') as f: + self._config = yaml.safe_load(f) or {} return self._config - def validate(self) -> bool: - """验证配置""" - if self._config is None: - self.load() + def get(self, key: str, default: Any = None) -> Any: + """获取配置项""" + return self._config.get(key, default) + + def get_section(self, section: str) -> Dict[str, Any]: + """获取配置区块""" + return self._config.get(section, {}) + + def validate(self, required_keys: list) -> bool: + """验证必填配置项""" + missing = [key for key in required_keys if key not in self._config] - # 必须字段 - required_fields = ['strategy', 'factors', 'signal'] - - for field in required_fields: - if field not in self._config: - raise ValueError(f"Missing required field: {field}") + if missing: + raise ValueError(f"Missing required config keys: {missing}") return True - def get_strategy_config(self) -> StrategyConfig: - """获取策略配置""" - if self._config is None: - self.load() - - return StrategyConfig( - name=self._config['strategy']['name'], - version=self._config['strategy'].get('version', 1), - factors=self._config['factors'], - signal=self._config['signal'], - callbacks=self._config.get('callbacks', {}), - params=self._config.get('params', {}) - ) + def __repr__(self) -> str: + return f"ConfigLoader(keys={list(self._config.keys())})" + + +class StrategyConfig: + """ + 策略配置类(通用) - @staticmethod - def from_yaml(yaml_str: str) -> Dict: - """从YAML字符串加载""" - return yaml.safe_load(yaml_str) \ No newline at end of file + 用于封装策略配置 + """ + + def __init__(self, config: Dict[str, Any]): + """初始化策略配置""" + self._config = config + + @property + def name(self) -> str: + """策略名称""" + return self._config.get('strategy', {}).get('name', 'unknown') + + @property + def select_num(self) -> int: + """选中数量""" + return self._config.get('signal', {}).get('select_num', 3) + + @property + def stoploss(self) -> float: + """止损阈值""" + return self._config.get('risk', {}).get('stop_loss', -0.05) + + def get_factor_config(self) -> list: + """获取因子配置""" + return self._config.get('factors', []) + + def get_signal_config(self) -> dict: + """获取信号配置""" + return self._config.get('signal', {}) + + def get_risk_config(self) -> list: + """获取风控配置""" + return self._config.get('risk', []) + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return self._config.copy() + + +# 导出抽象接口 +__all__ = ['ConfigLoader', 'StrategyConfig'] \ No newline at end of file diff --git a/framework/execution/__init__.py b/framework/execution/__init__.py index ea8f200..7238b88 100644 --- a/framework/execution/__init__.py +++ b/framework/execution/__init__.py @@ -1,55 +1,119 @@ """ -执行层抽象设计 +执行层抽象接口(通用) -核心组件: -- Executor: 执行器抽象基类 -- BacktestExecutor: 回测执行器 -- DryRunExecutor: 模拟盘执行器 +只提供抽象基类和Portfolio数据结构,具体执行器可扩展 """ -import pandas as pd from abc import ABC, abstractmethod -from typing import Dict, Any, Optional, List -from dataclasses import dataclass +from typing import Dict, List, Optional +import pandas as pd from datetime import datetime +from framework.risk import Position + -@dataclass class Portfolio: - """持仓组合""" - positions: Dict[str, Any] # {code: Position} - cash: float - nav: float - trades: List[Any] + """ + 投资组合数据结构(通用) - def get_total_value(self) -> float: - """获取总价值""" - position_value = sum( - pos.quantity * pos.current_price - for pos in self.positions.values() + 用于管理持仓集合 + """ + + def __init__(self, initial_capital: float = 100000): + """初始化投资组合""" + self.initial_capital = initial_capital + self.cash = initial_capital + self.positions: Dict[str, Position] = {} + self.trades: List[Dict] = [] + self._net_value_history: List[float] = [] + + def add_position(self, code: str, price: float, quantity: float, time: datetime) -> None: + """添加持仓""" + position = Position( + code=code, + entry_price=price, + current_price=price, + entry_time=time, + quantity=quantity ) - return self.cash + position_value + self.positions[code] = position + self.cash -= price * quantity + + self.trades.append({ + 'action': 'BUY', + 'code': code, + 'price': price, + 'quantity': quantity, + 'time': time + }) - def get_position_codes(self) -> List[str]: - """获取持仓代码列表""" - return list(self.positions.keys()) + def remove_position(self, code: str, price: float, time: datetime) -> float: + """移除持仓""" + if code not in self.positions: + return 0 + + position = self.positions[code] + profit = (price - position.entry_price) * position.quantity + self.cash += price * position.quantity + + del self.positions[code] + + self.trades.append({ + 'action': 'SELL', + 'code': code, + 'price': price, + 'quantity': position.quantity, + 'time': time, + 'profit': profit + }) + + return profit + + def update_prices(self, prices: Dict[str, float]) -> None: + """更新持仓价格""" + for code, price in prices.items(): + if code in self.positions: + self.positions[code].current_price = price + + def get_net_value(self) -> float: + """计算净值""" + positions_value = sum( + pos.current_price * pos.quantity for pos in self.positions.values() + ) + return self.cash + positions_value + + def record_net_value(self) -> None: + """记录当前净值""" + self._net_value_history.append(self.get_net_value()) + + def get_net_value_series(self) -> pd.Series: + """获取净值序列""" + return pd.Series(self._net_value_history) + + def get_weight(self, code: str) -> float: + """计算持仓权重""" + if code not in self.positions: + return 0 + + position_value = self.positions[code].current_price * self.positions[code].quantity + return position_value / self.get_net_value() + + def __repr__(self) -> str: + return f"Portfolio(capital={self.cash:.2f}, positions={len(self.positions)})" class Executor(ABC): """ 执行器抽象基类 - 支持不同执行模式: - - backtest: 回测模式 - - dry_run: 模拟盘模式 - - live: 实盘模式(TODO) + 所有执行器必须实现execute方法 """ mode: str = "base" - def __init__(self, config: Optional[Dict] = None): - self._config = config or {} - self._portfolio = None + def __init__(self, **params): + """初始化执行器参数""" + self._params = params @abstractmethod def execute(self, signals: pd.DataFrame, data: pd.DataFrame) -> Portfolio: @@ -58,121 +122,71 @@ class Executor(ABC): Args: signals: 信号DataFrame - data: 价格数据 + data: OHLCV数据 Returns: - 持仓组合 + Portfolio对象 """ pass - @abstractmethod - def get_mode(self) -> str: - """获取执行模式""" - pass - - @property - def portfolio(self) -> Optional[Portfolio]: - """获取当前持仓""" - return self._portfolio + def __repr__(self) -> str: + params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()]) + return f"{self.__class__.__name__}({params_str})" class BacktestExecutor(Executor): """ - 回测执行器 + 回测执行器(通用骨架) - 执行回测逻辑: - - 处理信号 - - 计算净值 - - 记录交易 + 具体回测逻辑需要在strategies中定制实现 """ mode = "backtest" - def __init__( - self, - initial_capital: float = 100000.0, - trade_cost: float = 0.001 - ): - super().__init__() + def __init__(self, initial_capital: float = 100000, trade_cost: float = 0.001): + super().__init__(initial_capital=initial_capital, trade_cost=trade_cost) self.initial_capital = initial_capital self.trade_cost = trade_cost def execute(self, signals: pd.DataFrame, data: pd.DataFrame) -> Portfolio: - """执行回测""" - # 初始化持仓 - self._portfolio = Portfolio( - positions={}, - cash=self.initial_capital, - nav=1.0, - trades=[] - ) + """ + 执行回测(简化版本) - # 回测逻辑(简化版) - result = pd.DataFrame(index=signals.index) - result['nav'] = 1.0 - result['daily_return'] = 0.0 + 完整回测逻辑需要定制实现 + """ + portfolio = Portfolio(self.initial_capital) - # TODO: 完整回测逻辑迁移 + # 这里只提供骨架,具体逻辑需要定制实现 + # 包括:净值计算、交易成本扣除、基准对比等 - return self._portfolio - - def get_mode(self) -> str: - return "backtest" + return portfolio class DryRunExecutor(Executor): """ - 模拟盘执行器 + Dry-run执行器(通用) - 执行模拟交易: - - 模拟下单 - - 模拟成交 - - 模拟持仓更新 + 用于模拟运行,不实际执行交易 """ mode = "dry_run" - def __init__( - self, - initial_capital: float = 100000.0, - simulated_exchange = None - ): - super().__init__() - self.initial_capital = initial_capital - self.simulated_exchange = simulated_exchange + def __init__(self, verbose: bool = True): + super().__init__(verbose=verbose) + self.verbose = verbose def execute(self, signals: pd.DataFrame, data: pd.DataFrame) -> Portfolio: - """执行模拟盘""" - # 初始化持仓 - self._portfolio = Portfolio( - positions={}, - cash=self.initial_capital, - nav=1.0, - trades=[] - ) + """模拟执行""" + portfolio = Portfolio(100000) - # 模拟执行逻辑 - # TODO: 模拟订单执行 + for date in signals.index: + signal = signals.loc[date, 'signal'] + + if signal and self.verbose: + print(f"[{date}] Signal: {signal}") - return self._portfolio - - def get_mode(self) -> str: - return "dry_run" - - def simulate_order(self, code: str, direction: str, quantity: float, price: float): - """模拟下单""" - # 记录模拟订单 - print(f"[DRY_RUN] {direction} {quantity} {code} @ {price}") - - # 更新持仓 - if direction == 'BUY': - # 模拟买入 - cost = quantity * price - if cost <= self._portfolio.cash: - self._portfolio.cash -= cost - # TODO: 创建Position对象 - elif direction == 'SELL': - # 模拟卖出 - if code in self._portfolio.positions: - # TODO: 平仓逻辑 - pass \ No newline at end of file + return portfolio + + +# 导出抽象接口 +__all__ = ['Portfolio', 'Executor', 'BacktestExecutor', 'DryRunExecutor'] \ No newline at end of file diff --git a/framework/factors/__init__.py b/framework/factors/__init__.py index ec48e82..e5c9d30 100644 --- a/framework/factors/__init__.py +++ b/framework/factors/__init__.py @@ -1,54 +1,28 @@ """ -因子层抽象设计 +因子层抽象接口(通用) -核心组件: -- FactorBase: 因子抽象基类 -- FactorRegistry: 因子注册器 -- FactorCombiner: 因子组合器 +只提供抽象基类和注册机制,具体因子实现在strategies/shared/factors/ """ +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Any, Type import pandas as pd import numpy as np -from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Any -from dataclasses import dataclass - - -@dataclass -class FactorMeta: - """因子元信息""" - name: str - category: str # 'momentum', 'trend', 'reversal', 'volatility', 'fundamental' - params: Dict[str, Any] - description: str = "" class FactorBase(ABC): """ 因子抽象基类 - 所有因子必须继承此基类,实现compute方法。 - 支持参数配置、数据验证、元信息管理。 + 所有因子必须实现compute方法 """ - # 类属性(可被配置覆盖) name: str = "base" category: str = "unknown" def __init__(self, **params): - """ - 初始化因子 - - Args: - **params: 因子参数(如n_days=25, period=14等) - """ + """初始化因子参数""" self._params = params - self._meta = FactorMeta( - name=self.name, - category=self.category, - params=params, - description=self.__doc__ or "" - ) @abstractmethod def compute(self, data: pd.DataFrame) -> pd.Series: @@ -56,139 +30,84 @@ class FactorBase(ABC): 计算因子值 Args: - data: 包含OHLCV数据的DataFrame + data: OHLCV数据,必须包含'close'列 Returns: - 因子值序列(Series) + 因子值序列 """ pass - @property - def params(self) -> Dict[str, Any]: - """获取因子参数""" - return self._params - - @property - def meta(self) -> FactorMeta: - """获取因子元信息""" - return self._meta - def validate_data(self, data: pd.DataFrame) -> bool: - """ - 验证数据是否满足计算要求 + """验证数据是否满足计算要求""" + if 'close' not in data.columns: + return False - Args: - data: 数据DataFrame - - Returns: - 是否满足要求 - """ - # 默认验证:数据长度 >= 最小周期 min_periods = self._params.get('min_periods', 20) return len(data) >= min_periods def __repr__(self) -> str: - return f"{self.__class__.__name__}(name={self.name}, params={self._params})" + params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()]) + return f"{self.__class__.__name__}({params_str})" class FactorRegistry: """ - 因子注册器 + 因子注册器(通用) - 管理所有注册的因子,支持: - - 注册因子类 - - 获取因子实例 - - 列出可用因子 - - 按类别筛选因子 + 管理因子类的注册和获取 """ - _factors: Dict[str, type] = {} + _factors: Dict[str, Type[FactorBase]] = {} @classmethod - def register(cls, factor_class: type) -> None: - """ - 注册因子类 - - Args: - factor_class: 因子类(必须继承FactorBase) - """ - if not isinstance(factor_class, type) or not issubclass(factor_class, FactorBase): - raise TypeError(f"factor_class must be a subclass of FactorBase") - - # 创建临时实例获取名称 + def register(cls, factor_class: Type[FactorBase]) -> None: + """注册因子类""" temp_instance = factor_class() name = temp_instance.name + + if name in cls._factors: + print(f"因子已注册,覆盖: {name}") + cls._factors[name] = factor_class - print(f"✓ 因子已注册: {name} ({factor_class.__name__})") @classmethod def get(cls, name: str, **params) -> FactorBase: - """ - 获取因子实例 - - Args: - name: 因子名称 - **params: 因子参数 - - Returns: - 因子实例 - """ + """获取因子实例""" if name not in cls._factors: - raise KeyError(f"Factor '{name}' not registered. Available: {cls.list()}") + raise ValueError(f"因子未注册: {name}") factor_class = cls._factors[name] return factor_class(**params) @classmethod - def list(cls, category: str = None) -> List[str]: - """ - 列出可用因子 - - Args: - category: 按类别筛选(可选) - - Returns: - 因子名称列表 - """ - if category: - return [ - name for name, factor_class in cls._factors.items() - if factor_class().category == category - ] + def list_factors(cls) -> List[str]: + """列出所有已注册因子""" return list(cls._factors.keys()) @classmethod - def list_by_category(cls) -> Dict[str, List[str]]: - """ - 按类别列出因子 - - Returns: - 类别→因子列表字典 - """ - result = {} - for name, factor_class in cls._factors.items(): - cat = factor_class().category - if cat not in result: - result[cat] = [] - result[cat].append(name) - return result + def clear(cls) -> None: + """清空注册表""" + cls._factors = {} @classmethod - def clear(cls) -> None: - """清空注册表(用于测试)""" - cls._factors.clear() + def get_category(cls, name: str) -> str: + """获取因子类别""" + if name not in cls._factors: + return "unknown" + + temp_instance = cls._factors[name]() + return temp_instance.category class FactorCombiner: """ - 因子组合器 + 因子组合器(通用) - 支持多因子加权组合,用于: - - 多因子策略 - - 因子权重调整 - - 因子结果合并 + 支持多因子加权组合 """ + SUPPORTED_METHODS = ['weighted_sum', 'rank_average', 'zscore_sum', 'equal_weight'] + def __init__( self, factors: List[FactorBase], @@ -196,87 +115,77 @@ class FactorCombiner: method: str = 'weighted_sum' ): """ - 初始化因子组合器 + 初始化组合器 Args: factors: 因子实例列表 - weights: 权重列表(默认等权) - method: 组合方法 ('weighted_sum', 'average', 'max', 'min') + weights: 因子权重列表(可选) + method: 组合方法(weighted_sum/rank_average/zscore_sum/equal_weight) """ + if not factors: + raise ValueError("factors list cannot be empty") + + if method not in self.SUPPORTED_METHODS: + raise ValueError(f"Unsupported method: {method}") + self._factors = factors - self._weights = weights or [1.0 / len(factors)] * len(factors) + + if weights is None: + self._weights = [1.0 / len(factors)] * len(factors) + else: + if len(weights) != len(factors): + raise ValueError("weights length must match factors length") + self._weights = weights + self._method = method - - # 验证权重 - if len(self._weights) != len(factors): - raise ValueError(f"weights length ({len(self._weights)}) != factors length ({len(factors)})") - - # 归一化权重 - total_weight = sum(self._weights) - self._weights = [w / total_weight for w in self._weights] def compute(self, data: pd.DataFrame) -> pd.DataFrame: """ 计算所有因子并组合 - Args: - data: 输入数据 - Returns: - 包含各因子值和组合因子值的DataFrame + DataFrame包含各因子值和combined列 """ result = pd.DataFrame(index=data.index) # 计算各因子 for i, factor in enumerate(self._factors): - # 验证数据 - if not factor.validate_data(data): - print(f"⚠ 因子 {factor.name} 数据验证失败,跳过") - continue - - # 计算因子值 factor_values = factor.compute(data) - result[factor.name] = factor_values - - # 加权因子值 - result[f"{factor.name}_weighted"] = factor_values * self._weights[i] + col_name = f"{factor.name}" + result[col_name] = factor_values # 组合因子值 - weighted_cols = [f"{f.name}_weighted" for f in self._factors if f.name in result.columns] - if self._method == 'weighted_sum': - result['combined'] = result[weighted_cols].sum(axis=1) - elif self._method == 'average': - factor_cols = [f.name for f in self._factors if f.name in result.columns] + weighted_cols = [f.name for f in self._factors] + result['combined'] = result[weighted_cols].apply( + lambda row: sum(row[col] * self._weights[i] for i, col in enumerate(weighted_cols) if pd.notna(row[col])), + axis=1 + ) + + elif self._method == 'equal_weight': + factor_cols = [f.name for f in self._factors] result['combined'] = result[factor_cols].mean(axis=1) - elif self._method == 'max': - factor_cols = [f.name for f in self._factors if f.name in result.columns] - result['combined'] = result[factor_cols].max(axis=1) - elif self._method == 'min': - factor_cols = [f.name for f in self._factors if f.name in result.columns] - result['combined'] = result[factor_cols].min(axis=1) - else: - raise ValueError(f"Unknown method: {self._method}") + + elif self._method == 'rank_average': + factor_cols = [f.name for f in self._factors] + ranks = result[factor_cols].rank(axis=1) + result['combined'] = ranks.mean(axis=1) + + elif self._method == 'zscore_sum': + factor_cols = [f.name for f in self._factors] + zscores = result[factor_cols].apply(lambda x: (x - x.mean()) / x.std()) + result['combined'] = zscores.sum(axis=1) return result - @property - def factors(self) -> List[FactorBase]: - """获取因子列表""" - return self._factors - - @property - def weights(self) -> List[float]: - """获取权重列表""" - return self._weights - - def set_weights(self, weights: List[float]) -> None: - """设置权重""" - if len(weights) != len(self._factors): - raise ValueError(f"weights length must equal factors length") - total = sum(weights) - self._weights = [w / total for w in weights] + def get_factor_names(self) -> List[str]: + """获取因子名称列表""" + return [f.name for f in self._factors] def __repr__(self) -> str: factor_names = [f.name for f in self._factors] - return f"FactorCombiner(factors={factor_names}, weights={self._weights})" \ No newline at end of file + return f"FactorCombiner(factors={factor_names}, weights={self._weights}, method={self._method})" + + +# 导出抽象接口 +__all__ = ['FactorBase', 'FactorRegistry', 'FactorCombiner'] \ No newline at end of file diff --git a/framework/factors/momentum.py b/framework/factors/momentum.py deleted file mode 100644 index 99139c0..0000000 --- a/framework/factors/momentum.py +++ /dev/null @@ -1,312 +0,0 @@ -""" -动量因子实现 - -基于加权线性回归动量的因子 -""" - -import pandas as pd -import numpy as np -import math -from typing import Optional - -from framework.factors import FactorBase - - -class MomentumFactor(FactorBase): - """ - 动量因子 - - 计算加权线性回归动量得分: - 得分 = 年化收益率 × R² - - 参数: - - n_days: 动量窗口(默认25) - - weighted: 是否加权(默认True) - - crash_filter: 是否启用崩盘过滤(默认True) - """ - - name = "momentum" - category = "momentum" - - def __init__( - self, - n_days: int = 25, - weighted: bool = True, - crash_filter: bool = True - ): - super().__init__(n_days=n_days, weighted=weighted, crash_filter=crash_filter) - self.n_days = n_days - self.weighted = weighted - self.crash_filter = crash_filter - - def compute(self, data: pd.DataFrame) -> pd.Series: - """计算动量因子值""" - if 'close' not in data.columns: - raise ValueError("data must contain 'close' column") - - prices = data['close'] - - if self.weighted: - # 加权动量得分 - factor_values = prices.rolling(self.n_days).apply( - lambda x: self._weighted_momentum_score(x.values), - raw=False - ) - else: - # 简单动量 - factor_values = prices.pct_change(self.n_days) - - # 应用崩盘过滤 - if self.crash_filter: - factor_values = self._apply_crash_filter(prices, factor_values) - - return factor_values - - def _weighted_momentum_score(self, prices: np.ndarray) -> float: - """计算加权动量得分""" - if len(prices) < 5: - return 0.0 - - y = np.log(prices) - x = np.arange(len(y)) - weights = np.linspace(1, 2, len(y)) - - # 加权线性回归 - slope, intercept = np.polyfit(x, y, 1, w=weights) - annualized_returns = math.exp(slope * 250) - 1 - - # 加权R² - y_pred = slope * x + intercept - ss_res = np.sum(weights * (y - y_pred) ** 2) - ss_tot = np.sum(weights * (y - np.average(y, weights=weights)) ** 2) - r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0 - - return annualized_returns * r2 - - def _apply_crash_filter( - self, - prices: pd.Series, - factor_values: pd.Series - ) -> pd.Series: - """崩盘过滤:连续3天跌>5%清零""" - result = factor_values.copy() - - for i in range(3, len(prices)): - r1 = prices.iloc[i] / prices.iloc[i-1] - r2 = prices.iloc[i-1] / prices.iloc[i-2] - r3 = prices.iloc[i-2] / prices.iloc[i-3] - - # 条件1:任一天跌>5% - con1 = min(r1, r2, r3) < 0.95 - # 条件2:连续下跌且累计跌>5% - con2 = (r1 < 1) and (r2 < 1) and (r3 < 1) and (prices.iloc[i] / prices.iloc[i-3] < 0.95) - - if con1 or con2: - result.iloc[i] = 0.0 - - return result - - -class TrendFactor(FactorBase): - """ - 趋势因子 - - 计算趋势强度: - - MA交叉偏离度 - - MACD趋势 - - 参数: - - method: 趋势方法('ma_cross', 'macd') - - fast: 快线周期(默认5) - - slow: 慢线周期(默认20) - """ - - name = "trend" - category = "trend" - - def __init__( - self, - method: str = 'ma_cross', - fast: int = 5, - slow: int = 20 - ): - super().__init__(method=method, fast=fast, slow=slow) - self.method = method - self.fast = fast - self.slow = slow - - def compute(self, data: pd.DataFrame) -> pd.Series: - """计算趋势因子值""" - if 'close' not in data.columns: - raise ValueError("data must contain 'close' column") - - prices = data['close'] - - if self.method == 'ma_cross': - # MA交叉偏离度 - fast_ma = prices.rolling(self.fast).mean() - slow_ma = prices.rolling(self.slow).mean() - trend_strength = (fast_ma - slow_ma) / slow_ma - return trend_strength - - elif self.method == 'macd': - # MACD趋势 - ema12 = prices.ewm(span=12).mean() - ema26 = prices.ewm(span=26).mean() - macd = ema12 - ema26 - signal = macd.ewm(span=9).mean() - return macd - signal - - else: - raise ValueError(f"Unknown method: {self.method}") - - -class ReversalFactor(FactorBase): - """ - 反转因子 - - 计算超买超卖信号: - - RSI偏离度 - - KDJ - - 参数: - - method: 反转方法('rsi', 'kdj') - - period: 周期(默认14) - - overbought: 超买阈值(默认70) - - oversold: 超卖阈值(默认30) - """ - - name = "reversal" - category = "reversal" - - def __init__( - self, - method: str = 'rsi', - period: int = 14, - overbought: float = 70, - oversold: float = 30 - ): - super().__init__(method=method, period=period, overbought=overbought, oversold=oversold) - self.method = method - self.period = period - self.overbought = overbought - self.oversold = oversold - - def compute(self, data: pd.DataFrame) -> pd.Series: - """计算反转因子值""" - if 'close' not in data.columns: - raise ValueError("data must contain 'close' column") - - prices = data['close'] - - if self.method == 'rsi': - # RSI反转信号 - rsi = self._compute_rsi(prices, self.period) - - # 超买超卖偏离度 - # 超买 → 负值(反转向下信号) - # 超卖 → 正值(反转向上信号) - reversal_signal = pd.Series(index=prices.index, dtype=float) - reversal_signal = np.where( - rsi > self.overbought, - -(rsi - self.overbought) / (100 - self.overbought), # 超买:负值 - np.where( - rsi < self.oversold, - (self.oversold - rsi) / self.oversold, # 超卖:正值 - 0 # 正常区间:0 - ) - ) - return pd.Series(reversal_signal, index=prices.index) - - elif self.method == 'kdj': - # KDJ反转信号 - return self._compute_kdj(data) - - else: - raise ValueError(f"Unknown method: {self.method}") - - def _compute_rsi(self, prices: pd.Series, period: int) -> pd.Series: - """计算RSI""" - delta = prices.diff() - gain = delta.where(delta > 0, 0) - loss = (-delta).where(delta < 0, 0) - - avg_gain = gain.rolling(period).mean() - avg_loss = loss.rolling(period).mean() - - rs = avg_gain / avg_loss - rsi = 100 - (100 / (1 + rs)) - return rsi - - def _compute_kdj(self, data: pd.DataFrame) -> pd.Series: - """计算KDJ反转信号""" - low = data['low'] - high = data['high'] - close = data['close'] - - # 计算K、D、J - low_min = low.rolling(self.period).min() - high_max = high.rolling(self.period).max() - - rsv = (close - low_min) / (high_max - low_min) * 100 - - k = rsv.ewm(alpha=1/3).mean() - d = k.ewm(alpha=1/3).mean() - j = 3 * k - 2 * d - - # J值偏离度作为反转信号 - return j - - -class VolatilityFactor(FactorBase): - """ - 波动率因子 - - 计算价格波动率: - - ATR - - 标准差 - - 参数: - - method: 波动率方法('atr', 'std') - - period: 周期(默认20) - """ - - name = "volatility" - category = "volatility" - - def __init__( - self, - method: str = 'std', - period: int = 20 - ): - super().__init__(method=method, period=period) - self.method = method - self.period = period - - def compute(self, data: pd.DataFrame) -> pd.Series: - """计算波动率因子值""" - if self.method == 'std': - # 标准差波动率 - return data['close'].rolling(self.period).std() - - elif self.method == 'atr': - # ATR波动率 - return self._compute_atr(data) - - else: - raise ValueError(f"Unknown method: {self.method}") - - def _compute_atr(self, data: pd.DataFrame) -> pd.Series: - """计算ATR""" - high = data['high'] - low = data['low'] - close = data['close'] - - prev_close = close.shift(1) - tr = pd.concat([ - high - low, - (high - prev_close).abs(), - (low - prev_close).abs() - ], axis=1).max(axis=1) - - return tr.rolling(self.period).mean() \ No newline at end of file diff --git a/framework/risk/__init__.py b/framework/risk/__init__.py index 1596769..a2a2c8c 100644 --- a/framework/risk/__init__.py +++ b/framework/risk/__init__.py @@ -1,351 +1,188 @@ """ -回调钩子与风控层设计 +风控层抽象接口(通用) -核心组件: -- RiskControl: 风控抽象基类 -- StopLossControl: 止损控制 -- PositionLimitControl: 仓位限制控制 -- CallbackHook: 回调钩子管理 +只提供抽象基类和回调机制,具体风控组件在strategies/shared/risk/ """ -import pandas as pd from abc import ABC, abstractmethod -from typing import Dict, Any, Optional, List +from typing import Dict, List, Any, Callable, Optional from dataclasses import dataclass from datetime import datetime @dataclass class Position: - """持仓信息""" + """ + 持仓数据结构(通用) + + 用于表示单个持仓的状态 + """ code: str entry_price: float - entry_date: datetime current_price: float - current_date: datetime - quantity: float - weight: float + entry_time: datetime + quantity: float = 1.0 + weight: float = 1.0 @property def profit_ratio(self) -> float: - """盈亏比例""" + """计算盈亏比例""" return (self.current_price - self.entry_price) / self.entry_price @property - def holding_days(self) -> int: - """持仓天数""" - return (self.current_date - self.entry_date).days + def profit_amount(self) -> float: + """计算盈亏金额""" + return (self.current_price - self.entry_price) * self.quantity @property - def is_profit(self) -> bool: - """是否盈利""" - return self.profit_ratio > 0 - - -@dataclass -class Trade: - """交易信息""" - code: str - direction: str # 'entry' or 'exit' - price: float - date: datetime - quantity: float - reason: str = "" + def holding_days(self) -> int: + """计算持仓天数""" + if self.entry_time is None: + return 0 + return (datetime.now() - self.entry_time).days + + def __repr__(self) -> str: + return f"Position(code={self.code}, profit={self.profit_ratio:.2%}, days={self.holding_days})" class RiskControl(ABC): """ - 风控抽象基类 + 风控组件抽象基类 - 所有风控组件必须继承此基类。 + 所有风控组件必须实现check方法 """ name: str = "base" def __init__(self, **params): + """初始化风控参数""" self._params = params @abstractmethod - def check(self, position: Optional[Position], **kwargs) -> bool: + def check(self, position: Position, **kwargs) -> bool: """ - 风控检查 + 检查风控条件 Args: - position: 持仓信息(可选) - **kwargs: 其他参数 + position: 持仓对象 + kwargs: 额外参数(如premium、history等) Returns: - 是否通过检查 + True表示通过检查,False表示触发风控 """ pass - @abstractmethod - def apply(self, position: Position) -> Optional[float]: + def apply(self, position: Position) -> Any: """ - 应用风控 + 应用风控规则(可选) Args: - position: 持仓信息 + position: 持仓对象 Returns: - 应用结果(如止损价格、仓位调整比例等) + 风控结果(如止损价格、建议仓位等) """ - pass - - @property - def params(self) -> Dict[str, Any]: - return self._params - - -class StopLossControl(RiskControl): - """ - 止损控制 - - 参数: - - threshold: 止损阈值(默认-0.05) - - trailing: 是否跟踪止损(默认False) - - trailing_percent: 跟踪止损比例(默认0.03) - """ - - name = "stop_loss" - - def __init__( - self, - threshold: float = -0.05, - trailing: bool = False, - trailing_percent: float = 0.03 - ): - super().__init__( - threshold=threshold, - trailing=trailing, - trailing_percent=trailing_percent - ) - self.threshold = threshold - self.trailing = trailing - self.trailing_percent = trailing_percent - self._highest_price = {} # 跟踪最高价 - - def check(self, position: Optional[Position], **kwargs) -> bool: - """检查是否触发止损""" - if position is None: - return True - - # 更新最高价(跟踪止损) - if self.trailing: - if position.code not in self._highest_price: - self._highest_price[position.code] = position.entry_price - self._highest_price[position.code] = max( - self._highest_price[position.code], - position.current_price - ) - - # 检查止损 - if self.trailing: - # 跟踪止损:从最高价回撤超过阈值 - highest = self._highest_price[position.code] - drawdown = (position.current_price - highest) / highest - return drawdown > -self.trailing_percent - else: - # 固定止损:从入场价亏损超过阈值 - return position.profit_ratio > self.threshold - - def apply(self, position: Position) -> Optional[float]: - """返回止损价格""" - if self.trailing: - highest = self._highest_price.get(position.code, position.entry_price) - return highest * (1 - self.trailing_percent) - else: - return position.entry_price * (1 + self.threshold) - - -class PositionLimitControl(RiskControl): - """ - 仓位限制控制 - - 参数: - - max_position: 单品种最大仓位(默认0.33) - - max_total: 总仓位上限(默认1.0) - """ - - name = "position_limit" - - def __init__( - self, - max_position: float = 0.33, - max_total: float = 1.0 - ): - super().__init__( - max_position=max_position, - max_total=max_total - ) - self.max_position = max_position - self.max_total = max_total - - def check(self, position: Optional[Position], **kwargs) -> bool: - """检查仓位是否超限""" - if position is None: - return True - - # 检查单品种仓位 - if position.weight > self.max_position: - return False - - return True - - def apply(self, position: Position) -> Optional[float]: - """返回建议仓位""" - return min(position.weight, self.max_position) - - -class PremiumControl(RiskControl): - """ - 溢价控制 - - 参数: - - threshold: 溢价阈值(默认0.10) - - mode: 控制模式('filter'或'penalize') - """ - - name = "premium" - - def __init__( - self, - threshold: float = 0.10, - mode: str = 'filter' - ): - super().__init__( - threshold=threshold, - mode=mode - ) - self.threshold = threshold - self.mode = mode - - def check(self, position: Optional[Position], **kwargs) -> bool: - """检查溢价是否超限""" - premium = kwargs.get('premium', 0) - - if self.mode == 'filter': - # 完全排除 - return premium <= self.threshold - else: - # 仅降权,允许通过 - return True - - def apply(self, position: Position) -> Optional[float]: - """返回溢价惩罚系数""" - if self.mode == 'penalize': - return 0.5 # 降权50% return None + + def __repr__(self) -> str: + params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()]) + return f"{self.__class__.__name__}({params_str})" class CallbackHook: """ - 回调钩子管理 + 回调钩子管理(通用) - 支持策略生命周期回调: - - before_entry: 入场前检查 - - after_entry: 入场后处理 - - before_exit: 出场前检查 - - after_exit: 出场后处理 - - dynamic_stoploss: 动态止损 - - custom_exit: 自定义出场 + 支持在策略生命周期的关键节点注入自定义逻辑 """ + SUPPORTED_HOOKS = [ + 'before_entry', # 入场前检查 + 'after_entry', # 入场后处理 + 'before_exit', # 出场前检查 + 'after_exit', # 出场后处理 + 'dynamic_stoploss', # 动态止损计算 + 'custom_exit' # 自定义出场条件 + ] + def __init__(self): - self._hooks = { - 'before_entry': [], - 'after_entry': [], - 'before_exit': [], - 'after_exit': [], - 'dynamic_stoploss': [], - 'custom_exit': [] + """初始化回调钩子""" + self._hooks: Dict[str, List[Callable]] = { + hook: [] for hook in self.SUPPORTED_HOOKS } - def register(self, hook_name: str, callback: callable) -> None: - """注册回调""" + def register(self, hook_name: str, callback: Callable) -> None: + """注册回调函数""" if hook_name not in self._hooks: - raise ValueError(f"Unknown hook: {hook_name}") + raise ValueError(f"Unsupported hook: {hook_name}") + self._hooks[hook_name].append(callback) def trigger(self, hook_name: str, *args, **kwargs) -> Any: - """触发回调""" + """ + 触发回调 + + Args: + hook_name: 钩子名称 + args: 位置参数 + kwargs: 关键字参数 + + Returns: + 回调结果(根据钩子类型返回不同结果) + """ if hook_name not in self._hooks: - raise ValueError(f"Unknown hook: {hook_name}") + raise ValueError(f"Unsupported hook: {hook_name}") + + callbacks = self._hooks[hook_name] + + if not callbacks: + # 默认行为 + if hook_name == 'dynamic_stoploss': + return kwargs.get('default_stoploss', -0.05) + elif hook_name in ['before_entry', 'before_exit']: + return True + elif hook_name == 'custom_exit': + return False + return None results = [] - for callback in self._hooks[hook_name]: - try: - result = callback(*args, **kwargs) - results.append(result) - except Exception as e: - print(f"⚠ Callback error: {e}") + for callback in callbacks: + result = callback(*args, **kwargs) + results.append(result) - # before_entry和before_exit需要所有回调返回True + # 根据钩子类型返回不同结果 if hook_name in ['before_entry', 'before_exit']: + # 所有回调返回True才允许 return all(results) - # dynamic_stoploss返回最小的止损值 if hook_name == 'dynamic_stoploss': - return min(results) if results else -0.05 + # 返回最小止损值(最严格) + return min(results) - # custom_exit返回是否有任一回调触发出场 if hook_name == 'custom_exit': + # 任一回调触发出场 return any(results) - return results + # 其他钩子返回最后一个结果 + return results[-1] if results else None - def clear(self, hook_name: str = None) -> None: + def clear(self, hook_name: Optional[str] = None) -> None: """清空回调""" if hook_name: - self._hooks[hook_name] = [] + if hook_name in self._hooks: + self._hooks[hook_name] = [] else: - for key in self._hooks: - self._hooks[key] = [] + for hook in self._hooks: + self._hooks[hook] = [] + + def list_hooks(self) -> List[str]: + """列出支持的钩子""" + return self.SUPPORTED_HOOKS + + def get_callbacks(self, hook_name: str) -> List[Callable]: + """获取钩子的所有回调""" + return self._hooks.get(hook_name, []) -# 便捷回调函数 -def premium_filter_callback(threshold: float = 0.10): - """溢价过滤回调""" - def callback(code: str, price: float, **kwargs) -> bool: - premium = kwargs.get('premium', 0) - if premium > threshold: - print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})") - return False - return True - return callback - - -def crash_filter_callback(lookback: int = 3, crash_threshold: float = 0.05): - """崩盘过滤回调""" - def callback(code: str, price: float, **kwargs) -> bool: - history = kwargs.get('history', None) - if history is None: - return True - - # 检查最近N天是否有崩盘 - recent = history.tail(lookback) - if len(recent) < lookback: - return True - - returns = recent['close'].pct_change() - min_return = returns.min() - - if min_return < -crash_threshold: - print(f"崩盘检测,拒绝入场: {code} (最大跌幅={min_return:.2%})") - return False - return True - return callback - - -def holding_time_stoploss_callback( - day_5_stoploss: float = -0.05, - day_10_stoploss: float = -0.03 -): - """持仓时间动态止损回调""" - def callback(position: Position) -> float: - if position.holding_days >= 10: - return day_10_stoploss # 10天后收紧止损 - elif position.holding_days >= 5: - return day_5_stoploss - return -0.10 # 默认止损 - return callback \ No newline at end of file +# 导出抽象接口 +__all__ = ['Position', 'RiskControl', 'CallbackHook'] \ No newline at end of file diff --git a/framework/signals/__init__.py b/framework/signals/__init__.py index c2dad7f..a87a426 100644 --- a/framework/signals/__init__.py +++ b/framework/signals/__init__.py @@ -1,52 +1,26 @@ """ -信号层抽象设计 +信号层抽象接口(通用) -核心组件: -- SignalGenerator: 信号生成器抽象基类 -- TopNSelector: Top N选股器(轮动策略) -- TrendFollower: 趋势跟随器(趋势策略) -- ReversalTrader: 反转交易器(反转策略) +只提供抽象基类,具体信号生成器在strategies/shared/signals/ """ -import pandas as pd -import numpy as np from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Any -from dataclasses import dataclass - - -@dataclass -class SignalMeta: - """信号元信息""" - mode: str # 'top_n', 'trend', 'reversal' - select_num: int - description: str = "" +from typing import List, Optional, Any +import pandas as pd class SignalGenerator(ABC): """ 信号生成器抽象基类 - 所有信号生成器必须继承此基类,实现generate方法。 - 支持不同策略类型的信号生成逻辑。 + 所有信号生成器必须实现generate方法 """ - # 类属性(可被配置覆盖) mode: str = "base" def __init__(self, **params): - """ - 初始化信号生成器 - - Args: - **params: 信号参数 - """ + """初始化信号生成器参数""" self._params = params - self._meta = SignalMeta( - mode=self.mode, - select_num=params.get('select_num', 1), - description=self.__doc__ or "" - ) @abstractmethod def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame: @@ -54,300 +28,27 @@ class SignalGenerator(ABC): 生成交易信号 Args: - factor_data: 包含因子值的DataFrame + factor_data: 因子数据DataFrame Returns: - 包含信号列的DataFrame + 包含'signal'列的DataFrame """ pass - @property - def params(self) -> Dict[str, Any]: - """获取信号参数""" - return self._params - - @property - def meta(self) -> SignalMeta: - """获取信号元信息""" - return self._meta + def validate_factor_data(self, factor_data: pd.DataFrame) -> bool: + """验证因子数据是否有效""" + if factor_data.empty: + return False + + if 'signal' in factor_data.columns: + print("Warning: factor_data already contains 'signal' column") + + return True def __repr__(self) -> str: - return f"{self.__class__.__name__}(mode={self.mode}, params={self._params})" + params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()]) + return f"{self.__class__.__name__}({params_str})" -class TopNSelector(SignalGenerator): - """ - Top N选股器 - - 用于轮动策略: - - 按因子值排序,选出Top N标的 - - 支持分组选股(先类内竞争,再跨类排序) - - 参数: - - select_num: 选中数量(默认3) - - group_by: 分组列名(可选,如'market') - - top_per_group: 每组选中数量(默认1) - - min_score: 最小得分阈值(可选) - """ - - mode = "top_n" - - def __init__( - self, - select_num: int = 3, - group_by: Optional[str] = None, - top_per_group: int = 1, - min_score: Optional[float] = None - ): - super().__init__( - select_num=select_num, - group_by=group_by, - top_per_group=top_per_group, - min_score=min_score - ) - self.select_num = select_num - self.group_by = group_by - self.top_per_group = top_per_group - self.min_score = min_score - - def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame: - """生成Top N选股信号""" - result = pd.DataFrame(index=factor_data.index) - - # 获取因子列(排除非因子列) - factor_cols = self._get_factor_columns(factor_data) - - if not factor_cols: - print("⚠ 未找到因子列") - result['signal'] = '' - return result - - # 每日选股 - signals = [] - for date in factor_data.index: - row = factor_data.loc[date] - - # 提取当日因子值 - scores = {} - for col in factor_cols: - score = row[col] - if pd.notna(score): - scores[col] = score - - # 应用最小得分过滤 - if self.min_score: - scores = {k: v for k, v in scores.items() if v >= self.min_score} - - # 选股逻辑 - if self.group_by and 'group_info' in factor_data.columns: - # 分组选股:先类内竞争,再跨类排序 - selected = self._grouped_selection(scores, factor_data.loc[date]) - else: - # 全局Top N - selected = self._global_top_n(scores) - - # 信号格式:逗号分隔的代码列表 - signals.append(','.join(selected) if selected else '') - - result['signal'] = signals - result['signal_raw'] = signals # 原始信号(未shift) - - # T+1执行:信号向后移位1天 - result['signal'] = result['signal'].shift(1) - - return result - - def _get_factor_columns(self, data: pd.DataFrame) -> List[str]: - """获取因子列名""" - # 排除已知非因子列 - exclude_cols = ['signal', 'signal_raw', 'group_info', 'combined', 'open', 'high', 'low', 'close', 'volume'] - factor_cols = [col for col in data.columns if col not in exclude_cols and not col.endswith('_weighted')] - return factor_cols - - def _global_top_n(self, scores: Dict[str, float]) -> List[str]: - """全局Top N选股""" - if not scores: - return [] - - # 按得分排序 - sorted_items = sorted(scores.items(), key=lambda x: x[1], reverse=True) - - # 选Top N - selected = [item[0] for item in sorted_items[:self.select_num]] - return selected - - def _grouped_selection( - self, - scores: Dict[str, float], - row: pd.Series - ) -> List[str]: - """分组选股:先类内竞争,再跨类排序""" - if 'group_info' not in row.index: - return self._global_top_n(scores) - - group_info = row['group_info'] - if pd.isna(group_info): - return self._global_top_n(scores) - - # 解析分组信息:{code: group} - groups = group_info if isinstance(group_info, dict) else {} - - # 类内竞争:每组选Top1 - group_champions = {} - for code, score in scores.items(): - group = groups.get(code, 'default') - if group not in group_champions or score > group_champions[group][1]: - group_champions[group] = (code, score) - - # 跨类排序:从冠军中选Top N - champions_scores = {code: score for code, score in group_champions.values()} - return self._global_top_n(champions_scores) - - -class TrendFollower(SignalGenerator): - """ - 趋势跟随器 - - 用于趋势跟踪策略: - - 趋势强度 > 入场阈值 → 入场信号 - - 趋势强度 < 出场阈值 → 出场信号 - - 参数: - - entry_threshold: 入场阈值(默认0.02) - - exit_threshold: 出场阈值(默认-0.02) - - select_num: 最大持仓数量(默认1) - """ - - mode = "trend" - - def __init__( - self, - entry_threshold: float = 0.02, - exit_threshold: float = -0.02, - select_num: int = 1 - ): - super().__init__( - entry_threshold=entry_threshold, - exit_threshold=exit_threshold, - select_num=select_num - ) - self.entry_threshold = entry_threshold - self.exit_threshold = exit_threshold - self.select_num = select_num - - def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame: - """生成趋势跟随信号""" - result = pd.DataFrame(index=factor_data.index) - - factor_cols = self._get_factor_columns(factor_data) - - for col in factor_cols: - trend_strength = factor_data[col] - - # 入场信号:趋势强度 > 阈值 - result[f'{col}_entry'] = trend_strength > self.entry_threshold - - # 出场信号:趋势强度 < 阈值 - result[f'{col}_exit'] = trend_strength < self.exit_threshold - - # 综合信号:入场强度最高的Top N - signals = [] - for date in result.index: - entry_signals = [] - for col in factor_cols: - if result.loc[date, f'{col}_entry']: - score = factor_data.loc[date, col] - if pd.notna(score): - entry_signals.append((col, score)) - - # 按强度排序,选Top N - entry_signals.sort(key=lambda x: x[1], reverse=True) - selected = [item[0] for item in entry_signals[:self.select_num]] - signals.append(','.join(selected) if selected else '') - - result['signal'] = signals - result['signal'] = result['signal'].shift(1) # T+1执行 - - return result - - def _get_factor_columns(self, data: pd.DataFrame) -> List[str]: - """获取因子列名""" - exclude_cols = ['signal', 'signal_raw', 'combined', 'open', 'high', 'low', 'close', 'volume'] - return [col for col in data.columns if col not in exclude_cols and not col.endswith('_weighted')] - - -class ReversalTrader(SignalGenerator): - """ - 反转交易器 - - 用于反转策略: - - 超买区域(RSI>70) → 反转向下信号(卖出) - - 超卖区域(RSI<30) → 反转向上信号(买入) - - 参数: - - overbought: 超买阈值(默认70) - - oversold: 超卖阈值(默认30) - - reversal_threshold: 反转信号强度阈值(默认0.1) - """ - - mode = "reversal" - - def __init__( - self, - overbought: float = 70, - oversold: float = 30, - reversal_threshold: float = 0.1 - ): - super().__init__( - overbought=overbought, - oversold=oversold, - reversal_threshold=reversal_threshold - ) - self.overbought = overbought - self.oversold = oversold - self.reversal_threshold = reversal_threshold - - def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame: - """生成反转交易信号""" - result = pd.DataFrame(index=factor_data.index) - - factor_cols = self._get_factor_columns(factor_data) - - for col in factor_cols: - reversal_signal = factor_data[col] - - # 买入信号:反转信号 > 阈值(正值,超卖反转) - result[f'{col}_buy'] = reversal_signal > self.reversal_threshold - - # 卖出信号:反转信号 < -阈值(负值,超买反转) - result[f'{col}_sell'] = reversal_signal < -self.reversal_threshold - - # 综合信号 - signals = [] - for date in result.index: - buy_signals = [] - sell_signals = [] - - for col in factor_cols: - if result.loc[date, f'{col}_buy']: - buy_signals.append(col) - if result.loc[date, f'{col}_sell']: - sell_signals.append(col) - - # 信号格式:'BUY:code1,code2' 或 'SELL:code1' 或 '' - if buy_signals: - signals.append(f"BUY:{','.join(buy_signals)}") - elif sell_signals: - signals.append(f"SELL:{','.join(sell_signals)}") - else: - signals.append('') - - result['signal'] = signals - result['signal'] = result['signal'].shift(1) # T+1执行 - - return result - - def _get_factor_columns(self, data: pd.DataFrame) -> List[str]: - """获取因子列名""" - exclude_cols = ['signal', 'signal_raw', 'combined', 'open', 'high', 'low', 'close', 'volume'] - return [col for col in data.columns if col not in exclude_cols and not col.endswith('_weighted')] \ No newline at end of file +# 导出抽象接口 +__all__ = ['SignalGenerator'] \ No newline at end of file diff --git a/framework/strategy/__init__.py b/framework/strategy/__init__.py index 9b025cf..c41f1fa 100644 --- a/framework/strategy/__init__.py +++ b/framework/strategy/__init__.py @@ -1,63 +1,30 @@ """ -策略基类与配置层 +策略层抽象基类(通用) -核心组件: -- StrategyBase: 策略抽象基类(含回调钩子) -- ConfigLoader: 配置加载器 +只提供抽象接口,具体策略实现在strategies/ """ -import yaml -import pandas as pd from abc import ABC, abstractmethod -from typing import Dict, Any, Optional, List -from pathlib import Path -from dataclasses import dataclass +from typing import Dict, Optional, Any +import pandas as pd -from framework.factors import FactorBase, FactorRegistry, FactorCombiner -from framework.factors.momentum import MomentumFactor -from framework.signals import SignalGenerator, TopNSelector +from framework.factors import FactorCombiner +from framework.signals import SignalGenerator from framework.risk import CallbackHook, Position -@dataclass -class StrategyConfig: - """策略配置""" - name: str - version: int - factors: List[Dict] - signal: Dict - callbacks: Dict - params: Dict - - class StrategyBase(ABC): """ 策略抽象基类 - 融合Freqtrade回调机制 + 模块化因子设计 - - 类属性(可被配置覆盖): - - name: 策略名称 - - version: 接口版本 - - timeframe: K线周期 - - select_num: 选中数量 - - stoploss: 止损比例 - - 回调钩子(可选实现): - - before_entry: 入场前检查 - - after_entry: 入场后处理 - - before_exit: 出场前检查 - - after_exit: 出场后处理 - - dynamic_stoploss: 动态止损 - - custom_exit: 自定义出场条件 + 所有策略必须实现init_factors和init_signal_generator方法 """ - # 接口版本 INTERFACE_VERSION = 1 - - # 类属性(可被配置覆盖) name: str = "base" timeframe: str = "1d" + + # 类属性(可被配置覆盖) select_num: int = 3 stoploss: float = -0.05 @@ -66,53 +33,50 @@ class StrategyBase(ABC): 初始化策略 Args: - config: 策略配置(可覆盖类属性) + config: 配置字典(可选,用于覆盖类属性) """ - # 配置覆盖类属性 if config: self._apply_config(config) - # 初始化回调钩子 self._callbacks = CallbackHook() self._register_default_callbacks() - # 初始化因子和信号生成器 - self._factors = None - self._signal_gen = None + self._factors = self.init_factors() + self._signal_gen = self.init_signal_generator() def _apply_config(self, config: Dict) -> None: - """应用配置""" - params = config.get('params', {}) - - # 覆盖类属性 - for key, value in params.items(): + """应用配置覆盖类属性""" + for key, value in config.items(): if hasattr(self, key): setattr(self, key, value) - - # 保存完整配置 - self._config = config def _register_default_callbacks(self) -> None: - """注册默认回调""" - # 注册入场前回调(溢价过滤) + """注册默认回调方法""" if hasattr(self, 'before_entry'): self._callbacks.register('before_entry', self.before_entry) - # 注册动态止损回调 + if hasattr(self, 'after_entry'): + self._callbacks.register('after_entry', self.after_entry) + + if hasattr(self, 'before_exit'): + self._callbacks.register('before_exit', self.before_exit) + + if hasattr(self, 'after_exit'): + self._callbacks.register('after_exit', self.after_exit) + if hasattr(self, 'dynamic_stoploss'): self._callbacks.register('dynamic_stoploss', self.dynamic_stoploss) - # 注册自定义出场回调 if hasattr(self, 'custom_exit'): self._callbacks.register('custom_exit', self.custom_exit) @abstractmethod def init_factors(self) -> FactorCombiner: """ - 初始化因子组合 + 初始化因子组合器 Returns: - 因子组合器 + FactorCombiner实例 """ pass @@ -122,7 +86,7 @@ class StrategyBase(ABC): 初始化信号生成器 Returns: - 信号生成器 + SignalGenerator实例 """ pass @@ -131,85 +95,28 @@ class StrategyBase(ABC): 运行策略 Args: - data: 输入数据 + data: OHLCV数据 Returns: 包含信号的DataFrame """ - # 初始化因子和信号生成器 - if self._factors is None: - self._factors = self.init_factors() - - if self._signal_gen is None: - self._signal_gen = self.init_signal_generator() - - # 1. 计算因子 factor_data = self._factors.compute(data) - - # 2. 生成信号 signals = self._signal_gen.generate(factor_data) - # 3. 应用回调钩子 signals = self._apply_callbacks(signals, data) return signals def _apply_callbacks(self, signals: pd.DataFrame, data: pd.DataFrame) -> pd.DataFrame: - """应用回调钩子""" - # 遍历每行信号 - for date in signals.index: - signal = signals.loc[date, 'signal'] - - if not signal or pd.isna(signal): - continue - - # 解析信号(逗号分隔的代码) - codes = signal.split(',') - - # 应用入场前回调 - for code in codes: - if code not in data.columns: - continue - - price = data.loc[date, code] - premium = 0.0 # TODO: 从溢价数据获取 - - # 触发回调 - allowed = self._callbacks.trigger( - 'before_entry', - code, - price, - premium=premium, - history=data - ) - - if not allowed: - # 移除被拒绝的代码 - codes.remove(code) - - # 更新信号 - signals.loc[date, 'signal'] = ','.join(codes) if codes else '' - + """应用回调处理""" return signals - # ===== 可选回调方法 ===== - + # 可选回调方法(子类可覆盖) def before_entry(self, code: str, price: float, **kwargs) -> bool: - """ - 入场前检查 - - Args: - code: 标的代码 - price: 入场价格 - **kwargs: 其他参数(premium, history等) - - Returns: - 是否允许入场 - """ - # 默认:允许入场 + """入场前检查""" return True - def after_entry(self, trade, **kwargs) -> None: + def after_entry(self, code: str, price: float, **kwargs) -> None: """入场后处理""" pass @@ -217,141 +124,21 @@ class StrategyBase(ABC): """出场前检查""" return True - def after_exit(self, trade, **kwargs) -> None: + def after_exit(self, position: Position, **kwargs) -> None: """出场后处理""" pass def dynamic_stoploss(self, position: Position) -> float: - """ - 动态止损 - - Args: - position: 持仓信息 - - Returns: - 止损比例 - """ - # 默认:返回固定止损 + """动态止损""" return self.stoploss def custom_exit(self, position: Position) -> bool: - """ - 自定义出场条件 - - Args: - position: 持仓信息 - - Returns: - 是否触发出场 - """ + """自定义出场条件""" return False + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(name={self.name})" -class ConfigLoader: - """ - 配置加载器 - - 支持YAML配置文件加载和验证 - """ - - def __init__(self, config_path: str): - """ - 初始化配置加载器 - - Args: - config_path: 配置文件路径 - """ - self._config_path = Path(config_path) - self._config = None - - def load(self) -> Dict: - """加载配置""" - if not self._config_path.exists(): - raise FileNotFoundError(f"Config file not found: {self._config_path}") - - with open(self._config_path, 'r', encoding='utf-8') as f: - self._config = yaml.safe_load(f) - - return self._config - - def validate(self) -> bool: - """验证配置""" - if self._config is None: - self.load() - - # 必须字段 - required_fields = ['strategy', 'factors', 'signal'] - - for field in required_fields: - if field not in self._config: - raise ValueError(f"Missing required field: {field}") - - return True - - def get_strategy_config(self) -> StrategyConfig: - """获取策略配置""" - if self._config is None: - self.load() - - return StrategyConfig( - name=self._config['strategy']['name'], - version=self._config['strategy'].get('version', 1), - factors=self._config['factors'], - signal=self._config['signal'], - callbacks=self._config.get('callbacks', {}), - params=self._config.get('params', {}) - ) - - @staticmethod - def from_yaml(yaml_str: str) -> Dict: - """从YAML字符串加载""" - return yaml.safe_load(yaml_str) - - -# 示例策略实现 -class RotationStrategy(StrategyBase): - """ - ETF轮动策略 - - 基于动量因子 + Top N选股 - """ - - name = "rotation" - select_num = 3 - - def init_factors(self) -> FactorCombiner: - """初始化动量因子""" - FactorRegistry.clear() - FactorRegistry.register(MomentumFactor) - - return FactorCombiner([ - FactorRegistry.get('momentum', n_days=25, crash_filter=True) - ]) - - def init_signal_generator(self) -> SignalGenerator: - """初始化Top N选股器""" - from framework.signals import TopNSelector - - return TopNSelector( - select_num=self.select_num, - min_score=0.0 - ) - - def before_entry(self, code: str, price: float, **kwargs) -> bool: - """入场前:溢价过滤""" - premium = kwargs.get('premium', 0) - - # 溢价超过10%拒绝入场 - if premium > 0.10: - print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})") - return False - - return True - - def dynamic_stoploss(self, position: Position) -> float: - """动态止损:根据持仓时间调整""" - if position.holding_days >= 10: - return -0.03 # 10天后收紧止损 - elif position.holding_days >= 5: - return -0.05 - return -0.10 \ No newline at end of file +# 导出抽象接口 +__all__ = ['StrategyBase'] \ No newline at end of file