Files
cta/user_data/strategies/SimpleRSIStrategy_Fixed.py
2025-10-11 19:05:42 +08:00

332 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# strategies/SimpleRSIStrategy_Fixed.py
import talib.abstract as ta
import pandas as pd
import numpy as np
from freqtrade.strategy import IStrategy
from freqtrade.persistence import Trade
from typing import Dict, List, Optional
from functools import reduce
import logging
logger = logging.getLogger(__name__)
class SimpleRSIStrategyFixed(IStrategy):
    """
    Simplified RSI strategy (fixed version): EMA crossover + RSI oversold
    rebound + rising-trend detection, long only, on daily candles.

    Fixes relative to the original version:
    1. Added the missing imports.
    2. Every ratio is guarded against division by zero (without emitting
       numpy divide-by-zero warnings).
    3. Performance: no unconditional CSV dump on every call; the entry/exit
       rules reuse the ratio columns computed in ``populate_indicators``.
    4. Readability: periods and thresholds come from the class parameters.
    5. Input validation: required columns are checked and the missing ones
       are reported.

    Entry: ``enter_long`` is set when either the EMA-trend rules or the
    RSI-rebound rules match (see ``populate_entry_trend``).
    Exit: ``exit_long`` is set when EMA(10) turns down relative to EMA(20)
    and the gap keeps widening (see ``populate_exit_trend``).
    """

    INTERFACE_VERSION = 3

    # Can this strategy go short?
    can_short: bool = False

    # A ROI target of 100 (i.e. 10,000 %) at minute 0 effectively disables
    # ROI-based exits; trades close on the exit signal or the stoploss.
    minimal_roi = {
        "0": 100
    }

    # Fixed stoploss at -5 %.
    stoploss = -0.05

    timeframe = '1d'

    # Warm-up candles needed before the rules see valid values: the longest
    # indicator is EMA(30) and the rules look back up to 5 candles.
    startup_candle_count: int = 35

    # Strategy parameters. Indicator column names (``rsi_6``, ``rsi_12``,
    # ``ema_10``, ``ema_20``, ``ema_30``) are kept fixed for compatibility
    # with existing consumers, so they only describe their contents while
    # these periods keep their default values.
    rsi_period_short = 6
    rsi_period_long = 12
    ema_period_short = 10
    ema_period_long = 20
    ema_period_trend = 30
    # RSI(6) / RSI(12) "oversold" levels, each relaxed by rsi_oversold_tolerance.
    rsi_oversold_threshold = 20
    rsi_long_oversold_threshold = 25
    rsi_oversold_tolerance = 1.05
    # Minimum (EMA10 - EMA20) / EMA20 for the trend-following entry.
    ema_trend_threshold = 0.01
    # Minimum one-candle relative rise of EMA(10) for the breakout entry.
    ema_breakout_threshold = 0.02
    # Maximum (EMA10 - EMA20) / EMA20 for the converging-EMA exit.
    ema_separation_threshold = 0.02

    # Optional CSV dump of the indicator frame (debugging aid). ``None``
    # disables it; otherwise a path template in which ``{pair}`` is replaced
    # by the pair name with ``/`` and ``:`` turned into ``_``.
    indicator_export_path: Optional[str] = None

    def populate_indicators(self, dataframe: pd.DataFrame, metadata: Dict) -> pd.DataFrame:
        """
        Add RSI/EMA indicators, their 1-5 candle shifts and the
        EMA(10)-EMA(20) spread/ratio columns to ``dataframe``.

        :param dataframe: Candle data; only the ``close`` column is read.
        :param metadata: Pair metadata; ``pair`` is used for logging and for
            the optional CSV export filename.
        :return: The same dataframe with the indicator columns added. If an
            error occurs, it is logged with a traceback and the dataframe is
            returned with whatever columns were added before the failure.
        """
        try:
            pair = metadata.get('pair', 'unknown')
            close = dataframe['close']

            # RSI (6, 12) and their 1-5 candle shifts.
            dataframe['rsi_6'] = ta.RSI(close, timeperiod=self.rsi_period_short)
            dataframe['rsi_12'] = ta.RSI(close, timeperiod=self.rsi_period_long)
            for i in range(1, 6):
                dataframe[f'rsi_6_shift{i}'] = dataframe['rsi_6'].shift(i)
                dataframe[f'rsi_12_shift{i}'] = dataframe['rsi_12'].shift(i)

            # EMA (10, 20, 30) and their 1-5 candle shifts.
            dataframe['ema_10'] = ta.EMA(close, timeperiod=self.ema_period_short)
            dataframe['ema_20'] = ta.EMA(close, timeperiod=self.ema_period_long)
            dataframe['ema_30'] = ta.EMA(close, timeperiod=self.ema_period_trend)
            for i in range(1, 6):
                dataframe[f'ema_10_shift{i}'] = dataframe['ema_10'].shift(i)
                dataframe[f'ema_20_shift{i}'] = dataframe['ema_20'].shift(i)
                dataframe[f'ema_30_shift{i}'] = dataframe['ema_30'].shift(i)

            # EMA(10) - EMA(20) spread and its ratio to EMA(20). The ratio is
            # 0 where EMA(20) is exactly 0; NaN inputs propagate as NaN.
            dataframe['ema_10_20_diff'] = dataframe['ema_10'] - dataframe['ema_20']
            dataframe['ema_10_20_ratio'] = self._safe_ratio(
                dataframe['ema_10_20_diff'], dataframe['ema_20']
            )
            for i in range(1, 6):
                dataframe[f'ema_10_20_diff_shift{i}'] = dataframe['ema_10_20_diff'].shift(i)
                # Shifting the guarded ratio yields the same values as
                # recomputing it from the shifted diff / EMA(20) columns.
                dataframe[f'ema_10_20_ratio_shift{i}'] = dataframe['ema_10_20_ratio'].shift(i)

            # Debug export is opt-in: writing a file on every call is costly
            # in live/hyperopt runs, and a fixed filename would be
            # overwritten by every pair.
            if self.indicator_export_path:
                safe_pair = str(pair).replace('/', '_').replace(':', '_')
                dataframe.to_csv(
                    self.indicator_export_path.format(pair=safe_pair),
                    index=False,
                    encoding='utf-8-sig',
                )

            logger.debug(f"Indicators populated successfully for {pair}")
            return dataframe
        except Exception as e:
            logger.exception(f"Error populating indicators: {e}")
            return dataframe

    def populate_entry_trend(self, dataframe: pd.DataFrame, metadata: Dict) -> pd.DataFrame:
        """
        Set ``enter_long`` = 1 on candles that match the entry rules.

        A candle qualifies when ``cond_1`` (EMA trend) OR ``cond_rsi``
        (RSI rebound) holds:

        * ``cond_1`` - any of
          a. EMA(10) above EMA(20) on this and the previous candle, EMA(10)
             rising for two candles, and the relative spread above
             ``ema_trend_threshold``;
          b. EMA(10) rose more than ``ema_breakout_threshold`` in one candle,
             was already rising, and is above EMA(20);
          c. EMA(10) was below EMA(20) five candles ago, the (EMA20 - EMA10)
             gap is smaller than on each of the last five candles, the
             current spread ratio has the same sign as five candles ago and
             less than 20 % of its magnitude, and the close is above EMA(10).
        * ``cond_rsi`` - RSI(6) is above 1.5x its previous value and above
          0.95x RSI(12), after three candles in which RSI(6) and RSI(12)
          stayed below their tolerance-adjusted oversold levels.

        :param dataframe: Output of ``populate_indicators``.
        :param metadata: Pair metadata; ``pair`` is used for logging.
        :return: The dataframe with an ``enter_long`` column (0/1). If required
            columns are missing or an error occurs, the column stays all 0.
        """
        try:
            pair = metadata.get('pair', 'unknown')
            # Initialise the signal column.
            dataframe.loc[:, 'enter_long'] = 0

            required_columns = [
                'ema_10', 'ema_20', 'ema_30',
                'rsi_6', 'rsi_12',
                'ema_10_shift1', 'ema_10_shift2', 'ema_10_shift3', 'ema_10_shift4', 'ema_10_shift5',
                'ema_20_shift1', 'ema_20_shift2', 'ema_20_shift3', 'ema_20_shift4', 'ema_20_shift5',
                'rsi_6_shift1', 'rsi_6_shift2', 'rsi_6_shift3',
                'rsi_12_shift1', 'rsi_12_shift2', 'rsi_12_shift3',
                'ema_10_20_ratio', 'ema_10_20_ratio_shift5',
                'ema_10_20_diff', 'ema_10_20_diff_shift1', 'ema_10_20_diff_shift2',
                'ema_10_20_diff_shift3', 'ema_10_20_diff_shift4', 'ema_10_20_diff_shift5',
                'close'
            ]
            missing = self._missing_columns(dataframe, required_columns)
            if missing:
                logger.warning(f"Missing required columns: {missing}")
                return dataframe

            close = dataframe['close']
            ema_10 = dataframe['ema_10']
            ema_20 = dataframe['ema_20']
            ema_10_s = {i: dataframe[f'ema_10_shift{i}'] for i in range(1, 6)}
            ema_20_s = {i: dataframe[f'ema_20_shift{i}'] for i in range(1, 6)}
            rsi_6 = dataframe['rsi_6']
            rsi_12 = dataframe['rsi_12']
            rsi_6_s = {i: dataframe[f'rsi_6_shift{i}'] for i in range(1, 4)}
            rsi_12_s = {i: dataframe[f'rsi_12_shift{i}'] for i in range(1, 4)}

            # (EMA10 - EMA20) / EMA20 now and five candles ago (already
            # zero-guarded by populate_indicators).
            ratio_now = dataframe['ema_10_20_ratio']
            ratio_s5 = dataframe['ema_10_20_ratio_shift5']
            # ratio_now / ratio_s5 in (0, 0.2): same sign, < 20 % of the magnitude.
            rel_s5 = self._safe_ratio(ratio_now, ratio_s5)
            ratio_pos_and_small = (ratio_s5 != 0) & (rel_s5 > 0) & (rel_s5 < 0.2)

            # EMA(20) - EMA(10) gap is smaller than on each of the last 5 candles.
            gap = ema_20 - ema_10
            gap_narrowing = self._all_true(
                [gap < (ema_20_s[i] - ema_10_s[i]) for i in range(1, 6)]
            )

            trend_up = (
                (ema_10 > ema_20)
                & (ema_10_s[1] > ema_20_s[1])
                & (ema_10_s[1] > ema_10_s[2])
                & (ema_10 > ema_10_s[1])
                & (ratio_now > self.ema_trend_threshold)
            )
            breakout = (
                (ema_10 > (1 + self.ema_breakout_threshold) * ema_10_s[1])
                & (ema_10_s[1] > ema_10_s[2])
                & (ema_10 > ema_20)
            )
            converging_from_below = (
                ratio_pos_and_small
                & gap_narrowing
                & (ema_20_s[5] > ema_10_s[5])
                & (ema_10 < close)
            )
            cond_1 = trend_up | breakout | converging_from_below

            rsi_6_limit = self.rsi_oversold_threshold * self.rsi_oversold_tolerance
            rsi_12_limit = self.rsi_long_oversold_threshold * self.rsi_oversold_tolerance
            cond_rsi = (
                (rsi_6 > 1.5 * rsi_6_s[1])
                & (rsi_6 > 0.95 * rsi_12)
                & (rsi_6_s[1] < rsi_6_limit)
                & (rsi_6_s[2] < rsi_6_limit)
                & (rsi_6_s[3] < rsi_6_limit)
                & (rsi_12_s[1] < rsi_12_limit)
                & (rsi_12_s[2] < rsi_12_limit)
                & (rsi_12_s[3] < rsi_12_limit)
            )

            buy_condition = cond_1 | cond_rsi
            dataframe.loc[buy_condition, 'enter_long'] = 1

            # Signal statistics over the whole dataframe.
            signal_count = int(buy_condition.sum())
            if signal_count > 0:
                logger.info(f"Generated {signal_count} buy signals for {pair}")
            return dataframe
        except Exception as e:
            logger.exception(f"Error in populate_entry_trend: {e}")
            return dataframe

    def populate_exit_trend(self, dataframe: pd.DataFrame, metadata: Dict) -> pd.DataFrame:
        """
        Set ``exit_long`` = 1 on candles that match the exit rules.

        A candle qualifies when either
        1. (EMA(10) below EMA(20) on this and the previous two candles after
           being above it five candles ago, OR EMA(10) fell more than 0.5 % in
           one candle, was already falling, and is below EMA(20)) AND EMA(10)
           is below each of its last five values AND the (EMA20 - EMA10) gap
           is larger than on each of the last five candles; or
        2. the spread ratio is below ``ema_separation_threshold`` and has the
           same sign as four candles ago with less than 30 % of its magnitude,
           EMA(10) was above EMA(20) four candles ago, EMA(10) is below each of
           its last four values, and the gap is larger than on each of the
           last four candles.

        :param dataframe: Output of ``populate_indicators``.
        :param metadata: Pair metadata; ``pair`` is used for logging.
        :return: The dataframe with an ``exit_long`` column (0/1). If required
            columns are missing or an error occurs, the column stays all 0.
        """
        try:
            pair = metadata.get('pair', 'unknown')
            dataframe.loc[:, 'exit_long'] = 0

            required_columns = [
                'ema_10', 'ema_20',
                'ema_10_shift1', 'ema_10_shift2', 'ema_10_shift3', 'ema_10_shift4', 'ema_10_shift5',
                'ema_20_shift1', 'ema_20_shift2', 'ema_20_shift3', 'ema_20_shift4', 'ema_20_shift5',
                'ema_10_20_ratio', 'ema_10_20_ratio_shift4',
            ]
            missing = self._missing_columns(dataframe, required_columns)
            if missing:
                logger.warning(f"Missing required columns: {missing}")
                return dataframe

            ema_10 = dataframe['ema_10']
            ema_20 = dataframe['ema_20']
            ema_10_s = {i: dataframe[f'ema_10_shift{i}'] for i in range(1, 6)}
            ema_20_s = {i: dataframe[f'ema_20_shift{i}'] for i in range(1, 6)}

            # (EMA10 - EMA20) / EMA20 now and four candles ago (already
            # zero-guarded by populate_indicators).
            ratio_now = dataframe['ema_10_20_ratio']
            ratio_s4 = dataframe['ema_10_20_ratio_shift4']
            # ratio_now / ratio_s4 in (0, 0.3): same sign, < 30 % of the magnitude.
            rel_s4 = self._safe_ratio(ratio_now, ratio_s4)
            ratio_pos_and_lt03 = (ratio_s4 != 0) & (rel_s4 > 0) & (rel_s4 < 0.3)

            gap = ema_20 - ema_10
            # widening[i]: EMA(20) - EMA(10) gap is larger now than i candles ago.
            widening = {i: gap > (ema_20_s[i] - ema_10_s[i]) for i in range(1, 6)}
            # falling[i]: EMA(10) is lower now than i candles ago.
            falling = {i: ema_10 < ema_10_s[i] for i in range(1, 6)}

            cond_block_1 = (
                (
                    (ema_10 < ema_20)
                    & (ema_10_s[1] < ema_20_s[1])
                    & (ema_10_s[2] < ema_20_s[2])
                    & (ema_10_s[5] > ema_20_s[5])
                )
                | (
                    # EMA(10) dropped by more than 0.5 % in one candle.
                    (ema_10 < 0.995 * ema_10_s[1])
                    & (ema_10_s[1] < ema_10_s[2])
                    & (ema_10 < ema_20)
                )
            )
            cond_block_2 = (
                (ratio_now < self.ema_separation_threshold)
                & ratio_pos_and_lt03
                & (ema_10_s[4] > ema_20_s[4])
                & self._all_true([falling[i] for i in range(1, 5)])
            )

            sell_condition = (
                (
                    cond_block_1
                    & self._all_true([falling[i] for i in range(1, 6)])
                    & self._all_true([widening[i] for i in range(1, 6)])
                )
                | (
                    cond_block_2
                    & self._all_true([widening[i] for i in range(1, 5)])
                )
            )
            dataframe.loc[sell_condition, 'exit_long'] = 1

            # Signal statistics over the whole dataframe.
            signal_count = int(sell_condition.sum())
            if signal_count > 0:
                logger.info(f"Generated {signal_count} sell signals for {pair}")
            return dataframe
        except Exception as e:
            logger.exception(f"Error in populate_exit_trend: {e}")
            return dataframe

    @staticmethod
    def _safe_ratio(numerator, denominator) -> np.ndarray:
        """
        Return ``numerator / denominator`` element-wise as a float array.

        Positions where the denominator is exactly 0 yield 0.0. NaN
        denominators yield NaN, because ``NaN != 0`` is True. No numpy
        divide-by-zero warning is emitted.
        """
        num = np.asarray(numerator, dtype=float)
        den = np.asarray(denominator, dtype=float)
        return np.divide(num, den, out=np.zeros_like(num), where=den != 0)

    @staticmethod
    def _missing_columns(dataframe: pd.DataFrame, required: List[str]) -> List[str]:
        """Return the entries of ``required`` that are not columns of ``dataframe``."""
        return [col for col in required if col not in dataframe.columns]

    @staticmethod
    def _all_true(conditions: List[pd.Series]) -> pd.Series:
        """Combine a non-empty list of boolean Series with element-wise AND."""
        return reduce(lambda acc, cond: acc & cond, conditions)

    # NOTE(review): disabled prototype of an RSI-based dynamic stoploss, kept
    # for reference. freqtrade is expected to call custom_stoploss only when
    # the strategy sets ``use_custom_stoploss = True`` - confirm before
    # re-enabling. As written, the RSI < 30 branch returns -0.05, which is
    # the same value as ``stoploss``.
    #
    # def custom_stoploss(self, pair: str, trade: Trade, current_time, current_rate, current_profit, **kwargs) -> float:
    #     """
    #     Custom stoploss logic.
    #     """
    #     try:
    #         # Dynamic stoploss: adjust the stop distance based on RSI.
    #         dataframe, _ = self.dp.get_analyzed_dataframe(pair, self.timeframe)
    #         if len(dataframe) > 0:
    #             last_candle = dataframe.iloc[-1]
    #             rsi_short = last_candle.get('rsi_6', 50)
    #             # Higher RSI -> looser stoploss.
    #             if rsi_short > 70:
    #                 return -0.15  # looser stoploss
    #             elif rsi_short < 30:
    #                 return -0.05  # tighter stoploss
    #             else:
    #                 return self.stoploss  # default stoploss
    #         else:
    #             return self.stoploss
    #     except Exception as e:
    #         logger.error(f"Error in custom_stoploss: {e}")
    #         return self.stoploss