460 lines
16 KiB
Python
460 lines
16 KiB
Python
import pandas as pd
|
||
|
||
from loguru import logger
|
||
import talib as ta
|
||
import numpy as np
|
||
from lightweight_charts import Chart # lightweight-charts==2.1
|
||
|
||
|
||
import random
|
||
from datetime import datetime
|
||
|
||
|
||
def TD(dataframe: pd.DataFrame, lookback: int = 4):
    """Compute a TD-style consecutive close-comparison count for every bar.

    For each bar ``i >= lookback`` the close is compared with the close
    ``lookback`` bars earlier:

    * strictly higher -> the positive run continues (1, 2, 3, ...) and the
      negative run resets;
    * otherwise (lower or equal) -> the negative run continues (-1, -2, ...)
      and the positive run resets.

    The first ``lookback`` bars have nothing to compare against and get 0.

    :param dataframe: DataFrame with a ``close`` column.
    :param lookback: how many bars back to compare against (default 4).
    :return: list of ints with exactly ``len(dataframe)`` entries, so it can
        be assigned directly as a new column.
    :raises ValueError: if ``lookback`` < 1.
    """
    if lookback < 1:
        raise ValueError("lookback must be >= 1")

    close = dataframe["close"].to_list()

    # Warm-up zeros, capped at the frame length so the output always matches
    # the number of rows (short frames previously produced a longer list).
    td = [0] * min(lookback, len(close))

    up = 0
    down = 0
    for i in range(lookback, len(close)):
        if close[i] > close[i - lookback]:
            up += 1
            down = 0
            td.append(up)
        else:
            down -= 1
            up = 0
            td.append(down)
    return td
|
||
|
||
|
||
def resample_data(df: pd.DataFrame, timeframe: str) -> pd.DataFrame:
    """Resample OHLCV bars to a coarser timeframe.

    :param df: source bars with a ``time`` column (datetime-like) and
        ``open``/``high``/``low``/``close``/``volume`` columns.
    :param timeframe: target period: '1D' (day), '1W' (week), '1M' (month),
        '3M' (calendar quarter) or '1Y' (year). Any other value returns
        ``df`` unchanged.
    :return: resampled DataFrame with ``time`` restored as a column. Periods
        that contain no bars are dropped.
    """
    # Offset objects instead of string aliases: 'M' / 'Y' are deprecated in
    # pandas >= 2.2 (renamed 'ME' / 'YE'), whereas the offset classes work on
    # every pandas version. Week/month/quarter/year rules are end-anchored,
    # so pandas labels each bar with the end of its period.
    timeframe_map = {
        "1D": pd.offsets.Day(),                          # daily
        "1W": pd.offsets.Week(weekday=6),                # weekly, same as "W" (W-SUN)
        "1M": pd.offsets.MonthEnd(),                     # monthly
        "3M": pd.offsets.QuarterEnd(startingMonth=12),   # calendar quarters (Mar/Jun/Sep/Dec)
        "1Y": pd.offsets.YearEnd(),                      # yearly
    }

    rule = timeframe_map.get(timeframe)
    if rule is None:
        return df

    resampled = (
        df.set_index("time")
        .resample(rule)
        .agg(
            {
                "open": "first",   # first open of the period
                "high": "max",     # highest high
                "low": "min",      # lowest low
                "close": "last",   # last close
                "volume": "sum",   # total volume
            }
        )
        .dropna()  # drop periods with no bars (e.g. weekends for daily data)
    )

    # Move the period label back from the index into a 'time' column.
    return resampled.reset_index()
|
||
|
||
|
||
def POC(df, bins: int = 50):
    """Compute the Point Of Control: the price level with the most traded volume.

    The overall price span ``[min(low), max(high)]`` is split into ``bins``
    equal-width price buckets. Each bar's volume is spread evenly over every
    bucket that its ``[low, high]`` range overlaps. A bar lying inside a
    single bucket contributes all of its volume to that bucket.

    :param df: DataFrame with ``low``, ``high`` and ``volume`` columns, one
        row per bar. Rows with NaN in any of these columns are ignored.
    :param bins: number of price buckets (default 50).
    :return: midpoint price of the bucket with the largest accumulated
        volume, or NaN if there are no usable rows.
    :raises ValueError: if ``bins`` < 1.
    """
    if bins < 1:
        raise ValueError("bins must be a positive integer")

    lows = df["low"].to_numpy(dtype=float)
    highs = df["high"].to_numpy(dtype=float)
    volumes = df["volume"].to_numpy(dtype=float)

    valid = ~(np.isnan(lows) | np.isnan(highs) | np.isnan(volumes))
    lows, highs, volumes = lows[valid], highs[valid], volumes[valid]
    if lows.size == 0:
        return float("nan")

    # bins + 1 edges delimit bins buckets; bucket j is [edges[j], edges[j + 1]].
    edges = np.linspace(lows.min(), highs.max(), bins + 1)

    # Bucket holding each bar's low (edges[j] <= low < edges[j + 1]) and the
    # bucket whose upper edge reaches each bar's high
    # (edges[j] < high <= edges[j + 1]). Clipping keeps the extremes inside.
    first_bins = np.clip(np.searchsorted(edges, lows, side="right") - 1, 0, bins - 1)
    last_bins = np.clip(np.searchsorted(edges, highs, side="left") - 1, 0, bins - 1)
    # A zero-width bar sitting exactly on an edge would otherwise get an empty span.
    last_bins = np.maximum(last_bins, first_bins)

    volume_per_price = np.zeros(bins)
    for first, last, volume in zip(first_bins, last_bins, volumes):
        volume_per_price[first : last + 1] += volume / (last - first + 1)

    idx = int(np.argmax(volume_per_price))
    return (edges[idx] + edges[idx + 1]) / 2
|
||
|
||
|
||
def get_fixed_color_based_on_period(num: int):
    """Return a deterministic ``#rrggbb`` color derived from ``num``.

    The same ``num`` (e.g. an indicator period) always maps to the same
    color. A private ``random.Random`` instance is used, so the global
    ``random`` module state is neither consumed nor re-seeded. The previous
    ``random.seed(None)`` reset wiped out any seed set by the caller.

    :param num: seed value, typically the indicator period.
    :return: hex color string such as ``"#1a2b3c"``.
    """
    rng = random.Random(num)
    r, g, b = (rng.randint(0, 255) for _ in range(3))
    return "#{:02x}{:02x}{:02x}".format(r, g, b)
|
||
|
||
|
||
class QuantChart:
    """Candlestick chart with indicator panes, TD/buy-sell markers and a live POC line.

    Built on ``lightweight_charts.Chart``; call :meth:`plot_chart` to open it.
    """

    # One time format for every marker so intraday bars keep their time of day.
    _MARKER_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self):
        # Horizontal line objects currently drawn, keyed by name (e.g. "POC").
        self.horizontal_lines = {}
        # Legend fragments keyed by name; rendered in sorted-key order.
        self.legend_data = {}
        # Close of the last bar in the visible range; set by on_range_change_poc.
        self.visible_latest_close = None

    def update_legend(self, chart, key, text):
        """Store ``text`` under ``key`` and redraw the legend from all fragments, sorted by key."""
        self.legend_data[key] = text
        full_text = ", ".join(value for _, value in sorted(self.legend_data.items()))
        chart.legend(visible=True, text=full_text)

    @staticmethod
    def _create_indicator_subchart(chart, height, position):
        """Create a time-synced, full-width subchart with the shared indicator-pane styling."""
        subchart = chart.create_subchart(
            position=position, width=1, height=height, sync=True
        )
        subchart.layout(font_family="Times New Roman")
        subchart.legend(
            visible=True, font_size=14, color="#FFFFFF", font_family="Times New Roman"
        )
        subchart.time_scale(visible=False)
        return subchart

    @classmethod
    def _marker(cls, time, position, shape, color, text):
        """Build one marker dict in the format passed to ``chart.marker_list``."""
        return {
            "time": time.strftime(cls._MARKER_TIME_FORMAT),
            "position": position,
            "shape": shape,
            "color": color,
            "text": text,
        }

    def add_ema(self, df, chart, period: int = 50):
        """Add column ``EMA_<period>`` to ``df`` and draw it on the main chart."""
        name = f"EMA_{period}"
        df[name] = ta.EMA(df["close"], timeperiod=period)
        line = chart.create_line(
            name,
            color=get_fixed_color_based_on_period(num=period),
            width=2,
            price_label=False,
            price_line=False,
        )
        line.set(df[["time", name]])

    def add_cci(
        self, df, chart, period: int = 14, height: float = 0.1, position: str = "bottom"
    ):
        """Add column ``CCI_<period>`` to ``df`` and draw it in its own subchart.

        Dashed reference lines are drawn at +100 and -100.
        """
        name = f"CCI_{period}"
        df[name] = ta.CCI(df["high"], df["low"], df["close"], timeperiod=period)

        cci_chart = self._create_indicator_subchart(chart, height, position)
        cci_line = cci_chart.create_line(name=name, color="#FF0000", width=2)
        cci_line.set(df[["time", name]])

        # Constant +100 / -100 series; kept in a separate frame so df is not polluted.
        levels = df[["time"]].copy()
        levels["h"] = 100
        levels["l"] = -100
        for level_name in ("h", "l"):
            level_line = cci_chart.create_line(
                name=level_name,
                color="#D4C21C",
                width=1,
                style="dashed",
                price_label=False,
                price_line=False,
            )
            level_line.set(levels[["time", level_name]])

    def add_macd(
        self,
        df,
        chart,
        fastperiod: int = 12,
        slowperiod: int = 26,
        signalperiod: int = 9,
        height: float = 0.1,
        position: str = "bottom",
    ):
        """Add MACD columns to ``df`` and draw them in their own subchart.

        Adds ``DIF`` (MACD line), ``DEA`` (signal line) and
        ``MACD_<fast>_<slow>_<signal>``. The last column holds the histogram
        value doubled, i.e. 2 * (DIF - DEA).

        Histogram colors: red for values >= 0, green for values < 0. The
        color is translucent ("hollow") when the bar moves toward zero
        compared with the previous bar.
        """
        macd, signal, hist = ta.MACD(
            df["close"],
            fastperiod=fastperiod,
            slowperiod=slowperiod,
            signalperiod=signalperiod,
        )
        df["DIF"] = macd
        df["DEA"] = signal
        macd_name = f"MACD_{fastperiod}_{slowperiod}_{signalperiod}"
        df[macd_name] = hist * 2

        macd_chart = self._create_indicator_subchart(chart, height, position)

        histogram = macd_chart.create_histogram(name=macd_name)
        hist_data = df[["time", macd_name]].copy()
        current = hist_data[macd_name]
        previous = current.shift(1)
        # NaN comparisons evaluate False, so NaN bars fall through to the default color.
        positive = current >= 0
        hollow = (positive & (current < previous)) | ((current < 0) & (current > previous))
        hist_data["color"] = np.select(
            [positive & hollow, positive, hollow],
            ["rgba(255, 0, 0, 0.3)", "#ff0000", "rgba(0, 255, 0, 0.3)"],
            default="#00FF00",
        )
        histogram.set(hist_data)

        macd_line = macd_chart.create_line(
            name="DIF", color="#2962FF", width=2, price_label=False, price_line=False
        )
        macd_line.set(df[["time", "DIF"]])
        signal_line = macd_chart.create_line(
            name="DEA", color="#FF0000", width=2, price_label=False, price_line=False
        )
        signal_line.set(df[["time", "DEA"]])

    def add_TD(self, df, chart):
        """Add a ``TD`` count column to ``df`` and mark counts of +/-9 and +/-13 on the chart."""
        df["TD"] = TD(df)
        td_line = chart.create_line(
            name="TD",
            color="rgba(0, 0, 0, 0)",  # transparent: the value shows in the legend, not as a line
            width=0,
            price_line=False,
            price_label=False,
            price_scale_id="td_scale",
        )
        td_line.precision(0)
        td_line.set(df[["time", "TD"]])

        markers = []
        for time, count in zip(df["time"], df["TD"]):
            if count in (9, 13):
                markers.append(self._marker(time, "above", "arrow_down", "#00FF00", f"{count}"))
            elif count in (-9, -13):
                markers.append(self._marker(time, "below", "arrow_up", "#FF0000", f"{count}"))
        chart.marker_list(markers)

    def add_buy_sell_signal_markers(self, df, chart):
        """Mark rows with ``buy == 1`` ("B", below the bar) or ``sell == 1`` ("S", above).

        Does nothing unless ``df`` has both ``buy`` and ``sell`` columns. If a
        row has both flags set, only the buy marker is drawn.
        """
        if "buy" not in df.columns or "sell" not in df.columns:
            return

        markers = []
        for time, buy, sell in zip(df["time"], df["buy"], df["sell"]):
            if buy == 1:
                markers.append(self._marker(time, "below", "arrow_up", "#00FF00", "B"))
            elif sell == 1:
                markers.append(self._marker(time, "above", "arrow_down", "#FF0000", "S"))
        chart.marker_list(markers)

    def on_range_change_poc(self, chart, bars_before, bars_after):
        """Range-change callback: redraw the POC line for the visible bars.

        :param chart: chart whose ``candle_data`` holds the bars.
        :param bars_before: bars hidden to the left of the visible range.
        :param bars_after: bars hidden to the right of the visible range.
            Either value may be negative (presumably when blank space is
            visible past the data); both are clamped to the data.
        """
        df = chart.candle_data
        if df is None or df.empty:
            return

        total_bars = len(df)
        start_idx = max(0, int(bars_before))
        end_idx = max(0, total_bars - max(0, int(bars_after)))
        df_range = df.iloc[start_idx:end_idx]
        # Scrolled entirely past the data: nothing to measure, keep the current line.
        if df_range.empty:
            return

        poc = POC(df_range)
        if not np.isfinite(poc) or poc == 0:
            return

        self.visible_latest_close = df_range["close"].iloc[-1]
        # Percent distance of the latest visible close from the POC.
        profit = (self.visible_latest_close - poc) / poc * 100

        # Replace the previous POC line only once the new value is known.
        poc_line_name = "POC"
        old_line = self.horizontal_lines.pop(poc_line_name, None)
        if old_line is not None:
            old_line.delete()
        poc_line = chart.horizontal_line(
            price=poc, color="#FF0000", width=4, style="solid", text=f"{poc:.2f}"
        )
        self.horizontal_lines[poc_line_name] = poc_line

        legend_text = f"POC: {poc:.2f}, poc_range_profit%: {profit:.1f}%, visible_tf_cnt: {len(df_range)}"
        self.update_legend(chart=chart, key=poc_line_name, text=legend_text)

    def check_df(self, df: pd.DataFrame):
        """Validate that ``df`` has the columns and dtypes ``plot_chart`` needs.

        :raises TypeError: if ``df`` is not a DataFrame, if ``time`` is not
            datetime-valued, or if an OHLCV column is not numeric.
        :raises ValueError: if a required column is missing.
        """
        if not isinstance(df, pd.DataFrame):
            raise TypeError("df must be a pandas DataFrame")

        required_cols = ["time", "open", "high", "low", "close", "volume"]
        missing = [c for c in required_cols if c not in df.columns]
        if missing:
            raise ValueError(
                f"Missing required columns: {missing}. Required: {required_cols}"
            )

        # Accept a datetime64 dtype, or an object column of datetime/Timestamp values.
        time_series = df["time"]
        if not (
            pd.api.types.is_datetime64_any_dtype(time_series)
            or time_series.apply(
                lambda x: isinstance(x, (pd.Timestamp, datetime))
            ).all()
        ):
            raise TypeError(
                "Column 'time' must contain datetime values (python datetime.datetime or pandas.Timestamp) "
                "or be a datetime64 dtype."
            )

        for col in ["open", "high", "low", "close", "volume"]:
            if not pd.api.types.is_numeric_dtype(df[col]):
                raise TypeError(f"Column '{col}' must be numeric (int or float)")

    def setup_crosshair_tracking(self, chart):
        """Show, in the legend, the % distance from the crosshair price to the latest visible close.

        Injects a JS crosshair subscription that sends the price under the
        cursor back to Python via the ``crosshair_price`` handler.
        """
        chart.run_script(
            f"""
            {chart.id}.chart.subscribeCrosshairMove((param) => {{
                if (!param.point) return;
                const price = {chart.id}.series.coordinateToPrice(param.point.y);
                if (price !== null) {{
                    window.callbackFunction(`crosshair_price_~_${{price}}`);
                }}
            }})
            """
        )

        def on_crosshair_price(price_str):
            price = float(price_str)
            # No reference close yet, or a zero price (would divide by zero).
            if self.visible_latest_close is None or price == 0:
                return
            profit = (self.visible_latest_close - price) / price * 100
            self.update_legend(
                chart=chart,
                key="Crosshair Price Profit%",
                text=f"crosshair_price_profit%: {profit:.2f}%",
            )

        chart.win.handlers["crosshair_price"] = on_crosshair_price

    def plot_chart(
        self,
        df,
        symbol: str,
        name: str,
        timeframe: str,
        init_visible_num_bars: int = 90,
    ):
        """Build the chart with all indicators and show it (blocks until closed).

        Note: the indicator helpers add columns to ``df`` in place.

        :param df: OHLCV bars (see :meth:`check_df`); must not be empty.
        :param symbol: shown in the top-bar "symbol" textbox.
        :param name: shown in the top-bar "name" textbox.
        :param timeframe: shown in the top-bar "timeframe" textbox.
        :param init_visible_num_bars: number of trailing bars visible at
            start; all bars are shown if there are fewer.
        :raises ValueError: if ``df`` is empty or misses required columns.
        """
        self.check_df(df)
        if df.empty:
            raise ValueError("df must contain at least one row")

        chart = Chart(toolbox=True, inner_height=0.8, maximize=True)
        chart.layout(font_family="Times New Roman")
        chart.topbar.textbox("symbol", symbol)
        chart.topbar.textbox("name", name)
        chart.topbar.textbox("timeframe", timeframe)
        chart.legend(visible=True, font_size=14, color="#FFFFFF")
        chart.set(df)

        # Initial visible range: the last init_visible_num_bars bars, clamped to
        # the data so short frames don't raise IndexError.
        end_time = df["time"].iloc[-1]
        start_pos = max(0, len(df) - init_visible_num_bars) if init_visible_num_bars > 0 else 0
        start_time = df["time"].iloc[start_pos]
        chart.set_visible_range(start_time, end_time)

        # Recompute the POC every time the visible range is zoomed or scrolled.
        chart.events.range_change += self.on_range_change_poc
        self.setup_crosshair_tracking(chart)

        # Indicators. Further EMA periods (e.g. 10, 20, 60) can be added with more add_ema calls.
        self.add_ema(df, chart, period=30)
        self.add_macd(df, chart)
        self.add_cci(df, chart, period=14)
        self.add_TD(df, chart)
        self.add_buy_sell_signal_markers(df, chart)

        chart.show(block=True)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
...
|