348 lines
13 KiB
Python
348 lines
13 KiB
Python
# strategies/SimpleRSIStrategy_Fixed.py
|
||
|
||
import talib.abstract as ta
|
||
import pandas as pd
|
||
import numpy as np
|
||
from freqtrade.strategy import IStrategy
|
||
from freqtrade.persistence import Trade
|
||
from typing import Dict, List, Optional
|
||
from functools import reduce
|
||
import logging
|
||
|
||
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
|
||
|
||
|
||
class SimpleRSIStrategyFixed(IStrategy):
    """
    Long-only strategy combining EMA crossover/convergence rules with an RSI oversold rebound.

    Indicators are computed on 24h bars resampled from the 1h candles (bins aligned
    to Asia/Shanghai midnight) and broadcast back onto every 1h candle:

    * ``rsi_6`` / ``rsi_12``: RSI with ``rsi_period_short`` / ``rsi_period_long``;
    * ``ema_10`` / ``ema_20`` / ``ema_30``: EMA with ``ema_period_short`` /
      ``ema_period_long`` / 30;
    * ``ema_10_20_diff`` (EMA10 - EMA20) and ``ema_10_20_ratio`` (diff / EMA20);
    * ``<column>_shiftN`` for N = 1..5: the same value N daily bars earlier.

    Entry and exit rules are documented on ``populate_entry_trend`` and
    ``populate_exit_trend``.
    """

    INTERFACE_VERSION = 3

    # Long-only: the strategy never opens short positions.
    can_short: bool = False

    # An ROI target of 100 (10000%) effectively disables ROI exits, so trades are
    # closed by exit signals or the stoploss.
    minimal_roi = {
        "0": 100
    }

    # Fixed stoploss at -5%.
    stoploss = -0.05

    timeframe = '1h'

    # The indicators run on daily bars: EMA(30) plus five daily shifts needs about
    # 35 days of data, so load 40 days of 1h candles ahead of the first analysed candle.
    startup_candle_count: int = 24 * 40

    # Indicator periods. Column names keep the default-period labels (``rsi_6``,
    # ``ema_10``, ...) so the entry/exit rules and plots do not depend on them.
    rsi_period_short = 6
    rsi_period_long = 12
    ema_period_short = 10
    ema_period_long = 20

    # Rule thresholds (used by populate_entry_trend / populate_exit_trend).
    rsi_oversold_threshold = 20        # RSI(6) oversold level
    rsi_long_oversold_threshold = 25   # RSI(12) oversold level
    rsi_oversold_tolerance = 1.05      # multiplier applied to both oversold levels
    ema_trend_threshold = 0.01         # minimum EMA10/EMA20 relative spread (trend entry)
    ema_breakout_threshold = 0.02      # minimum day-over-day EMA10 rise (breakout entry)
    ema_separation_threshold = 0.02    # maximum EMA10/EMA20 relative spread (convergence exit)

    # Optional debug dump of the analysed dataframe. Off by default: it rewrites
    # the file on every analysis call, and every pair writes to the same path.
    export_indicators_csv: bool = False
    export_indicators_path: str = "rsi_eth2.csv"

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _missing_columns(dataframe: pd.DataFrame, required: List[str]) -> List[str]:
        """Return the entries of ``required`` that are not columns of ``dataframe``."""
        return [col for col in required if col not in dataframe.columns]

    @staticmethod
    def _all_true(conditions) -> pd.Series:
        """AND together a non-empty iterable of boolean Series."""
        return reduce(lambda left, right: left & right, conditions)

    @staticmethod
    def _ema_gap(dataframe: pd.DataFrame, shift: int = 0) -> pd.Series:
        """Return ``ema_20 - ema_10`` for the current daily bar (shift 0) or ``shift`` bars back."""
        suffix = f'_shift{shift}' if shift else ''
        return dataframe[f'ema_20{suffix}'] - dataframe[f'ema_10{suffix}']

    @staticmethod
    def _quotient_between(numerator: pd.Series, denominator: pd.Series, upper: float) -> pd.Series:
        """
        Return True where ``0 < numerator / denominator < upper``.

        Rows with a zero or NaN denominator are False. Masking the denominator
        before dividing avoids producing inf values and divide-by-zero warnings.
        """
        quotient = numerator / denominator.where(denominator != 0)
        return (quotient > 0) & (quotient < upper)

    @staticmethod
    def _resample_daily(dataframe: pd.DataFrame) -> pd.DataFrame:
        """
        Aggregate the candles into 24h OHLCV bars aligned to Asia/Shanghai midnight.

        Bars are labelled with the right edge of their bin, so a day's values only
        line up with the candle that opens once the day has fully elapsed. Bins
        containing no candles (NaN close) are dropped: talib's recursive RSI/EMA
        would otherwise carry that NaN into every later value.
        """
        hourly = dataframe[['date', 'open', 'high', 'low', 'close', 'volume']].copy()
        hourly['date'] = hourly['date'].dt.tz_convert('Asia/Shanghai')
        return (
            hourly.set_index('date')
            .resample('24h', label='right')
            .agg({
                'open': 'first',
                'high': 'max',
                'low': 'min',
                'close': 'last',
                'volume': 'sum',
            })
            .dropna(subset=['close'])
        )

    def _daily_indicators(self, close: pd.Series) -> pd.DataFrame:
        """
        Compute RSI/EMA columns, their 1..5 bar shifts and the EMA10/EMA20 spread.

        :param close: daily closes indexed by bar label.
        :return: a DataFrame on the same index, columns in the order they are built.
        """
        index = close.index
        columns: Dict[str, pd.Series] = {}

        # RSI (short/long) and their shifts.
        columns['rsi_6'] = pd.Series(ta.RSI(close, timeperiod=self.rsi_period_short), index=index)
        columns['rsi_12'] = pd.Series(ta.RSI(close, timeperiod=self.rsi_period_long), index=index)
        for i in range(1, 6):
            columns[f'rsi_6_shift{i}'] = columns['rsi_6'].shift(i)
            columns[f'rsi_12_shift{i}'] = columns['rsi_12'].shift(i)

        # EMA (short/long/30) and their shifts.
        columns['ema_10'] = pd.Series(ta.EMA(close, timeperiod=self.ema_period_short), index=index)
        columns['ema_20'] = pd.Series(ta.EMA(close, timeperiod=self.ema_period_long), index=index)
        columns['ema_30'] = pd.Series(ta.EMA(close, timeperiod=30), index=index)
        for i in range(1, 6):
            for name in ('ema_10', 'ema_20', 'ema_30'):
                columns[f'{name}_shift{i}'] = columns[name].shift(i)

        # Spread and relative spread; the ratio is defined as 0 where EMA20 is exactly 0.
        diff = columns['ema_10'] - columns['ema_20']
        ratio = (diff / columns['ema_20']).where(columns['ema_20'] != 0, 0.0)
        columns['ema_10_20_diff'] = diff
        columns['ema_10_20_ratio'] = ratio
        for i in range(1, 6):
            columns[f'ema_10_20_diff_shift{i}'] = diff.shift(i)
            columns[f'ema_10_20_ratio_shift{i}'] = ratio.shift(i)

        return pd.DataFrame(columns, index=index)

    # ------------------------------------------------------------ freqtrade API

    def populate_indicators(self, dataframe: pd.DataFrame, metadata: Dict) -> pd.DataFrame:
        """
        Add the daily RSI/EMA indicator columns to the 1h dataframe.

        :param dataframe: candles with a tz-aware ``date`` column and OHLCV columns.
        :param metadata: pair metadata; only ``pair`` is read (for logging).
        :return: ``dataframe`` with the indicator columns joined on ``date`` and
                 forward-filled between daily bar labels. If no daily bar can be
                 built, the input is returned without indicator columns.
        """
        pair = metadata.get('pair', 'unknown')
        if dataframe.empty:
            logger.warning("Empty dataframe, no indicators computed for %s", pair)
            return dataframe

        daily = self._resample_daily(dataframe)
        if daily.empty:
            logger.warning("No daily bars available, no indicators computed for %s", pair)
            return dataframe

        indicators = self._daily_indicators(daily['close'])
        # Back to UTC so the bar labels match the dataframe's ``date`` values in the join.
        indicators.index = indicators.index.tz_convert('UTC')

        dataframe = dataframe.join(indicators, on='date', how='left')
        # Only the indicator columns are forward-filled; gaps in OHLCV stay visible.
        indicator_columns = list(indicators.columns)
        dataframe[indicator_columns] = dataframe[indicator_columns].ffill()

        if self.export_indicators_csv:
            dataframe.to_csv(self.export_indicators_path, index=False, encoding='utf-8-sig')

        logger.debug("Indicators populated successfully for %s", pair)
        return dataframe

    def populate_entry_trend(self, dataframe: pd.DataFrame, metadata: Dict) -> pd.DataFrame:
        """
        Set ``enter_long`` = 1 on candles matching any EMA or RSI entry rule.

        ``_shiftN`` columns are the daily values N days back. A candle is an entry when
        ANY of the following holds:

        1. Trend: EMA10 above EMA20 today and yesterday, EMA10 rising over the last
           two days, and the relative spread above ``ema_trend_threshold``.
        2. Breakout: EMA10 above ``1 + ema_breakout_threshold`` times yesterday's value,
           EMA10 also rose the day before, and EMA10 is above EMA20.
        3. Convergence: EMA10 was below EMA20 five days ago, the relative spread kept
           its sign but shrank below 20% of its 5-day-ago value, the EMA20-EMA10 gap
           is smaller than on each of the last 5 days, and close is above EMA10.
        4. RSI rebound: RSI(6) above 1.5x yesterday's and above 0.95x RSI(12), after
           three days with RSI(6) and RSI(12) below their tolerance-adjusted oversold
           levels.

        If required columns are missing, a warning is logged and ``enter_long`` stays 0.
        Any exception is logged with its traceback and the dataframe is returned.
        """
        pair = metadata.get('pair', 'unknown')
        try:
            dataframe.loc[:, 'enter_long'] = 0

            required = [
                'close', 'ema_10', 'ema_20', 'rsi_6', 'rsi_12',
                'ema_10_20_ratio', 'ema_10_20_ratio_shift5',
                *[f'ema_10_shift{i}' for i in range(1, 6)],
                *[f'ema_20_shift{i}' for i in range(1, 6)],
                *[f'rsi_6_shift{i}' for i in range(1, 4)],
                *[f'rsi_12_shift{i}' for i in range(1, 4)],
            ]
            missing = self._missing_columns(dataframe, required)
            if missing:
                logger.warning("Missing required columns for %s: %s", pair, missing)
                return dataframe

            close = dataframe['close']
            ema_10 = dataframe['ema_10']
            ema_20 = dataframe['ema_20']
            ema_10_s1 = dataframe['ema_10_shift1']
            ema_10_s2 = dataframe['ema_10_shift2']
            ema_20_s1 = dataframe['ema_20_shift1']
            ema_10_s5 = dataframe['ema_10_shift5']
            ema_20_s5 = dataframe['ema_20_shift5']
            ratio_now = dataframe['ema_10_20_ratio']
            ratio_s5 = dataframe['ema_10_20_ratio_shift5']
            rsi_6 = dataframe['rsi_6']

            trend = (
                (ema_10 > ema_20)
                & (ema_10_s1 > ema_20_s1)
                & (ema_10_s1 > ema_10_s2)
                & (ema_10 > ema_10_s1)
                & (ratio_now > self.ema_trend_threshold)
            )

            breakout = (
                (ema_10 > (1 + self.ema_breakout_threshold) * ema_10_s1)
                & (ema_10_s1 > ema_10_s2)
                & (ema_10 > ema_20)
            )

            gap_now = self._ema_gap(dataframe)
            convergence = (
                self._quotient_between(ratio_now, ratio_s5, 0.2)
                & self._all_true(gap_now < self._ema_gap(dataframe, i) for i in range(1, 6))
                & (ema_20_s5 > ema_10_s5)
                & (ema_10 < close)
            )

            rsi_6_limit = self.rsi_oversold_threshold * self.rsi_oversold_tolerance
            rsi_12_limit = self.rsi_long_oversold_threshold * self.rsi_oversold_tolerance
            rsi_rebound = (
                (rsi_6 > 1.5 * dataframe['rsi_6_shift1'])
                & (rsi_6 > 0.95 * dataframe['rsi_12'])
                & self._all_true(dataframe[f'rsi_6_shift{i}'] < rsi_6_limit for i in range(1, 4))
                & self._all_true(dataframe[f'rsi_12_shift{i}'] < rsi_12_limit for i in range(1, 4))
            )

            buy_condition = trend | breakout | convergence | rsi_rebound
            dataframe.loc[buy_condition, 'enter_long'] = 1

            signal_count = buy_condition.sum()
            if signal_count > 0:
                logger.info("Generated %s buy signals for %s", signal_count, pair)

            return dataframe

        except Exception:
            logger.exception("Error in populate_entry_trend for %s", pair)
            return dataframe

    def populate_exit_trend(self, dataframe: pd.DataFrame, metadata: Dict) -> pd.DataFrame:
        """
        Set ``exit_long`` = 1 on candles matching either EMA exit rule.

        ``_shiftN`` columns are the daily values N days back.

        1. EMA10 is below each of its last 5 daily values, the EMA20-EMA10 gap is
           wider than on each of the last 5 days, and either
           a. EMA10 is below EMA20 today and the previous two days but was above it
              five days ago, or
           b. EMA10 fell more than 0.5% day over day, also fell the day before, and
              is below EMA20.
        2. EMA10 was above EMA20 four days ago, the relative spread kept its sign but
           shrank below 30% of that value and below ``ema_separation_threshold``,
           EMA10 is below each of its last 4 daily values, and the EMA20-EMA10 gap
           is wider than on each of the last 4 days.

        If required columns are missing, a warning is logged and ``exit_long`` stays 0.
        Any exception is logged with its traceback and the dataframe is returned.
        """
        pair = metadata.get('pair', 'unknown')
        try:
            dataframe.loc[:, 'exit_long'] = 0

            required = [
                'ema_10', 'ema_20',
                *[f'ema_10_shift{i}' for i in range(1, 6)],
                *[f'ema_20_shift{i}' for i in range(1, 6)],
                'ema_10_20_ratio', 'ema_10_20_ratio_shift4',
            ]
            missing = self._missing_columns(dataframe, required)
            if missing:
                logger.warning("Missing required columns for %s: %s", pair, missing)
                return dataframe

            ema_10 = dataframe['ema_10']
            ema_20 = dataframe['ema_20']
            ema_10_s1 = dataframe['ema_10_shift1']
            ema_10_s2 = dataframe['ema_10_shift2']
            ratio_now = dataframe['ema_10_20_ratio']
            ratio_s4 = dataframe['ema_10_20_ratio_shift4']

            # falling[k] / widening[k] compare today against k + 1 days back.
            gap_now = self._ema_gap(dataframe)
            falling = [ema_10 < dataframe[f'ema_10_shift{i}'] for i in range(1, 6)]
            widening = [gap_now > self._ema_gap(dataframe, i) for i in range(1, 6)]

            crossed_below = (
                (ema_10 < ema_20)
                & (ema_10_s1 < dataframe['ema_20_shift1'])
                & (ema_10_s2 < dataframe['ema_20_shift2'])
                & (dataframe['ema_10_shift5'] > dataframe['ema_20_shift5'])
            )
            sharp_drop = (
                (ema_10 < 0.995 * ema_10_s1)
                & (ema_10_s1 < ema_10_s2)
                & (ema_10 < ema_20)
            )
            downturn = (
                (crossed_below | sharp_drop)
                & self._all_true(falling)
                & self._all_true(widening)
            )

            fading_spread = (
                (ratio_now < self.ema_separation_threshold)
                & self._quotient_between(ratio_now, ratio_s4, 0.3)
                & (dataframe['ema_10_shift4'] > dataframe['ema_20_shift4'])
                & self._all_true(falling[:4])
                & self._all_true(widening[:4])
            )

            sell_condition = downturn | fading_spread
            dataframe.loc[sell_condition, 'exit_long'] = 1

            signal_count = sell_condition.sum()
            if signal_count > 0:
                logger.info("Generated %s sell signals for %s", signal_count, pair)

            return dataframe

        except Exception:
            logger.exception("Error in populate_exit_trend for %s", pair)
            return dataframe

    # Disabled prototype of a dynamic, RSI-based stoploss. NOTE: freqtrade only
    # calls custom_stoploss when ``use_custom_stoploss = True`` is set.
    #
    # def custom_stoploss(self, pair: str, trade: Trade, current_time, current_rate,
    #                     current_profit, **kwargs) -> float:
    #     """Custom stoploss logic."""
    #     try:
    #         # Dynamic stoploss: adjust the stop based on RSI.
    #         dataframe, _ = self.dp.get_analyzed_dataframe(pair, self.timeframe)
    #         if len(dataframe) > 0:
    #             last_candle = dataframe.iloc[-1]
    #             rsi_short = last_candle.get('rsi_6', 50)
    #
    #             # The higher the RSI, the looser the stoploss.
    #             if rsi_short > 70:
    #                 return -0.15  # looser stoploss
    #             elif rsi_short < 30:
    #                 return -0.05  # tighter stoploss
    #             else:
    #                 return self.stoploss  # default stoploss
    #         else:
    #             return self.stoploss
    #     except Exception as e:
    #         logger.error(f"Error in custom_stoploss: {e}")
    #         return self.stoploss