Files
etf/chart.py

448 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
from loguru import logger
import talib as ta
import numpy as np
from lightweight_charts import Chart # lightweight-charts==2.1
import random
from datetime import datetime
def TD(dataframe: pd.DataFrame, lookback: int = 4) -> list:
    """
    Compute a TD-sequential style consecutive-close counter.

    For every bar ``i >= lookback`` the close is compared with the close
    ``lookback`` bars earlier:

    * ``close[i] > close[i - lookback]`` extends (or starts) an up-run and
      appends the positive run length (1, 2, 3, ...);
    * otherwise (including equality) it extends (or starts) a down-run and
      appends the negative run length (-1, -2, -3, ...).

    The first ``lookback`` bars have no reference close and are filled with 0.

    :param dataframe: DataFrame with a ``close`` column.
    :param lookback: How many bars back to compare against (default 4).
    :return: A list with exactly one count per row of ``dataframe``, so it can
             be assigned directly as a new column.
    :raises ValueError: If ``lookback`` is smaller than 1.
    """
    if lookback < 1:
        raise ValueError("lookback must be >= 1")
    close = dataframe["close"].to_list()
    # Warm-up bars: never emit more zeros than there are rows, otherwise
    # assigning the result back to the frame fails with a length mismatch.
    td = [0] * min(lookback, len(close))
    up = 0
    down = 0
    for i in range(lookback, len(close)):
        if close[i] > close[i - lookback]:
            up += 1
            down = 0
            td.append(up)
        else:
            # NOTE: ties count toward the down-run (a classic TD setup would
            # break both runs on equality); kept for backward compatibility.
            down -= 1
            up = 0
            td.append(down)
    return td
def resample_data(df: pd.DataFrame, timeframe: str) -> pd.DataFrame:
    """
    Resample daily OHLCV bars to a coarser timeframe.

    :param df: Source bars with a ``time`` column (datetime) plus ``open``,
               ``high``, ``low``, ``close`` and ``volume`` columns.
    :param timeframe: Target period: ``'1D'`` (daily), ``'1W'`` (weekly),
                      ``'1M'`` (monthly), ``'3M'`` (3-month bins) or
                      ``'1Y'`` (yearly).
    :return: A new DataFrame with ``time`` restored as a column; periods with
             no source bar are dropped. For an unsupported ``timeframe`` the
             input ``df`` itself is returned unchanged.
    """
    # Local import: only needed to probe which offset aliases are supported.
    from pandas.tseries.frequencies import to_offset

    if timeframe not in ("1D", "1W", "1M", "3M", "1Y"):
        return df

    # pandas >= 2.2 renamed the period-end aliases ("M" -> "ME", "Y" -> "YE")
    # and emits a FutureWarning for the old spellings; older pandas versions
    # reject the new ones, so pick whichever this installation understands.
    try:
        to_offset("ME")
        end_suffix = "E"
    except ValueError:
        end_suffix = ""

    timeframe_map = {
        "1D": "D",  # daily
        "1W": "W",  # weekly
        "1M": f"M{end_suffix}",  # monthly
        "3M": f"3M{end_suffix}",  # 3-month bins (quarter-like)
        "1Y": f"Y{end_suffix}",  # yearly
    }
    rule = timeframe_map[timeframe]

    resampled = (
        df.set_index("time")
        .resample(rule)
        .agg(
            {
                "open": "first",  # first open of the period
                "high": "max",  # highest high
                "low": "min",  # lowest low
                "close": "last",  # last close of the period
                "volume": "sum",  # total volume
            }
        )
        # Periods without any source bar (e.g. weekends for "1D") have NaN
        # prices; drop them.
        .dropna()
    )
    # Turn the time index back into a regular column.
    return resampled.reset_index()
def POC(df, bins: int = 50):
    """
    Estimate the volume Point Of Control (POC) of a set of K-line bars.

    The overall price span ``[min(low), max(high)]`` is split into ``bins``
    equal-width price buckets. Each bar's volume is spread evenly over every
    bucket its ``[low, high]`` range touches, and the midpoint of the bucket
    with the largest accumulated volume is returned.

    :param df: DataFrame with ``low``, ``high`` and ``volume`` columns, one row
               per bar.
    :param bins: Number of price buckets (default 50, must be >= 1).
    :return: Midpoint price of the highest-volume bucket, or NaN when ``df``
             has no usable rows.
    :raises ValueError: If ``bins`` is smaller than 1.
    """
    if bins < 1:
        raise ValueError("bins must be >= 1")
    lows = df["low"].to_numpy(dtype=float)
    highs = df["high"].to_numpy(dtype=float)
    volumes = df["volume"].to_numpy(dtype=float)
    # Ignore bars with missing data instead of letting NaN poison the sums.
    valid = ~(np.isnan(lows) | np.isnan(highs) | np.isnan(volumes))
    lows, highs, volumes = lows[valid], highs[valid], volumes[valid]
    if lows.size == 0:
        return float("nan")

    # bins + 1 edges delimit bins equal-width buckets.
    edges = np.linspace(lows.min(), highs.max(), bins + 1)
    # Bucket holding each bar's low (a low exactly on an edge belongs to the
    # bucket above it) and each bar's high (a high exactly on an edge belongs
    # to the bucket below it), so touching an edge does not spill over.
    # Every bar lands in at least one bucket, even if it is narrower than a
    # bucket.
    first = np.clip(np.searchsorted(edges, lows, side="right") - 1, 0, bins - 1)
    last = np.clip(np.searchsorted(edges, highs, side="left") - 1, 0, bins - 1)
    last = np.maximum(last, first)
    share = volumes / (last - first + 1)

    # Difference array: +share at the first bucket, -share just past the last
    # one; the cumulative sum then adds ``share`` to every bucket in between.
    diff = np.zeros(bins + 1)
    np.add.at(diff, first, share)
    np.add.at(diff, last + 1, -share)
    volume_per_bucket = np.cumsum(diff[:bins])

    idx = int(np.argmax(volume_per_bucket))
    return (edges[idx] + edges[idx + 1]) / 2
def get_fixed_color_based_on_period(num: int):
    """
    Return a deterministic hex colour (``"#rrggbb"``) derived from ``num``.

    The same ``num`` (e.g. an indicator period) always yields the same
    colour, so an indicator keeps its colour between runs.

    A private ``random.Random`` instance is seeded instead of the module-level
    generator, so the caller's global random state is left untouched. The
    colour equals what seeding the global generator with ``num`` and drawing
    three ``randint(0, 255)`` values would produce.

    :param num: Seed value, typically the indicator period.
    :return: Colour string such as ``"#1f2a3b"``.
    """
    rng = random.Random(num)
    # Draw order is r, g, b, matching the previous global-seed implementation.
    r, g, b = (rng.randint(0, 255) for _ in range(3))
    return "#{:02x}{:02x}{:02x}".format(r, g, b)
class QuantChart:
    """
    Build an interactive K-line chart (``lightweight_charts.Chart``) with
    indicators, TD / buy-sell markers and a live volume POC line.

    Instance state shared with the chart event callbacks:

    * ``horizontal_lines``: name -> horizontal line currently drawn (the POC
      line is deleted and redrawn on every visible-range change).
    * ``legend_data``: legend key -> text fragment; all fragments are joined
      in key order into the main chart legend.
    * ``visible_latest_close``: close of the last visible bar, or ``None``
      until the first range-change event has been handled.
    """

    # Timestamp format used for every marker so intraday bars keep their
    # time of day (a date-only string would put them on midnight).
    _MARKER_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self):
        self.horizontal_lines = {}
        self.legend_data = {}
        self.visible_latest_close = None

    def update_legend(self, chart, key, text):
        """
        Set legend fragment ``key`` to ``text`` and redraw the legend.

        All stored fragments are joined with ``", "`` in sorted-key order.
        """
        self.legend_data[key] = text
        full_text = ", ".join(value for _, value in sorted(self.legend_data.items()))
        chart.legend(visible=True, text=full_text)

    def add_ema(self, df, chart, period: int = 50):
        """
        Add an ``EMA_<period>`` column to ``df`` and draw it on ``chart``.

        The line colour is derived deterministically from ``period``.
        """
        name = f"EMA_{period}"
        df[name] = ta.EMA(df["close"], timeperiod=period)
        color = get_fixed_color_based_on_period(num=period)
        line = chart.create_line(
            name, color=color, width=2, price_label=False, price_line=False
        )
        line.set(df[["time", name]])

    def add_cci(
        self, df, chart, period: int = 14, height: float = 0.1, position: str = "bottom"
    ):
        """
        Add a ``CCI_<period>`` column to ``df`` and draw it in a synced subchart.

        Dashed reference lines at +100 and -100 are drawn alongside the CCI.
        """
        cci_name = f"CCI_{period}"
        df[cci_name] = ta.CCI(df["high"], df["low"], df["close"], timeperiod=period)
        cci_chart = chart.create_subchart(
            position=position, width=1, height=height, sync=True
        )
        cci_chart.layout(font_family="Times New Roman")
        cci_chart.legend(
            visible=True, font_size=14, color="#FFFFFF", font_family="Times New Roman"
        )
        cci_chart.time_scale(visible=False)
        cci_line = cci_chart.create_line(name=cci_name, color="#FF0000", width=2)
        cci_line.set(df[["time", cci_name]])
        # Constant +/-100 series; the column names must match the line names.
        bands = df[["time"]].copy()
        bands["h"] = 100
        bands["l"] = -100
        for band in ("h", "l"):
            band_line = cci_chart.create_line(
                name=band,
                color="#D4C21C",
                width=1,
                style="dashed",
                price_label=False,
                price_line=False,
            )
            band_line.set(bands[["time", band]])

    def add_macd(
        self,
        df,
        chart,
        fastperiod: int = 12,
        slowperiod: int = 26,
        signalperiod: int = 9,
        height: float = 0.1,
        position: str = "bottom",
    ):
        """
        Add MACD columns to ``df`` and draw them in a synced subchart.

        Adds ``DIF`` (MACD line), ``DEA`` (signal line) and
        ``MACD_<fast>_<slow>_<signal>`` (TA-Lib histogram doubled). Histogram
        bars are green when negative and red otherwise.
        """
        macd, signal, hist = ta.MACD(
            df["close"],
            fastperiod=fastperiod,
            slowperiod=slowperiod,
            signalperiod=signalperiod,
        )
        df["DIF"] = macd
        df["DEA"] = signal
        macd_name = f"MACD_{fastperiod}_{slowperiod}_{signalperiod}"
        df[macd_name] = hist * 2
        macd_chart = chart.create_subchart(
            position=position, width=1, height=height, sync=True
        )
        macd_chart.layout(font_family="Times New Roman")
        macd_chart.legend(
            visible=True, font_size=14, color="#FFFFFF", font_family="Times New Roman"
        )
        macd_chart.time_scale(visible=False)
        histogram = macd_chart.create_histogram(name=macd_name)
        hist_data = df[["time", macd_name]].copy()
        hist_data["color"] = hist_data[macd_name].apply(
            lambda x: "#00FF00" if x < 0 else "#ff0000"  # green : red
        )
        histogram.set(hist_data)
        macd_line = macd_chart.create_line(
            name="DIF", color="#2962FF", width=2, price_label=False, price_line=False
        )
        macd_line.set(df[["time", "DIF"]])
        signal_line = macd_chart.create_line(
            name="DEA", color="#FF0000", width=2, price_label=False, price_line=False
        )
        signal_line.set(df[["time", "DEA"]])

    def add_TD(self, df, chart):
        """
        Add a ``TD`` column (see ``TD``) and mark counts of 9/13 and -9/-13.

        Counts 9 and 13 get a green down-arrow above the bar. Counts -9 and -13
        get a red up-arrow below the bar. The counts are also plotted as a
        transparent line on its own price scale (``td_scale``). Presumably this
        exposes the value in the chart's readouts without drawing it; confirm.
        """
        df["TD"] = TD(df)
        td_line = chart.create_line(
            name="TD",
            color="rgba(0, 0, 0, 0)",  # fully transparent: no visible line
            width=0,
            price_line=False,
            price_label=False,
            price_scale_id="td_scale",
        )
        td_line.precision(0)
        td_line.set(df[["time", "TD"]])
        markers = []
        for item in df[["time", "TD"]].to_dict(orient="records"):
            if item["TD"] in [9, 13]:
                markers.append(
                    {
                        "time": item["time"].strftime(self._MARKER_TIME_FORMAT),
                        "position": "above",
                        "shape": "arrow_down",
                        "color": "#00FF00",
                        "text": f"{item['TD']}",
                    }
                )
            elif item["TD"] in [-9, -13]:
                markers.append(
                    {
                        "time": item["time"].strftime(self._MARKER_TIME_FORMAT),
                        "position": "below",
                        "shape": "arrow_up",
                        "color": "#FF0000",
                        "text": f"{item['TD']}",
                    }
                )
        chart.marker_list(markers)

    def add_buy_sell_signal_markers(self, df, chart):
        """
        Mark rows where ``buy == 1`` ("B", below the bar) or ``sell == 1``.

        Sell rows get "S" above the bar. When both flags are set on a row,
        only the buy marker is drawn. The method does nothing unless ``df``
        has both ``buy`` and ``sell`` columns.
        """
        if "buy" not in df.columns or "sell" not in df.columns:
            return
        markers = []
        for item in df[["time", "buy", "sell"]].to_dict(orient="records"):
            if item["buy"] == 1:
                markers.append(
                    {
                        # Full timestamp (not date-only) so intraday signals
                        # land on their own bar.
                        "time": item["time"].strftime(self._MARKER_TIME_FORMAT),
                        "position": "below",
                        "shape": "arrow_up",
                        "color": "#00FF00",
                        "text": "B",
                    }
                )
            elif item["sell"] == 1:
                markers.append(
                    {
                        "time": item["time"].strftime(self._MARKER_TIME_FORMAT),
                        "position": "above",
                        "shape": "arrow_down",
                        "color": "#FF0000",
                        "text": "S",
                    }
                )
        chart.marker_list(markers)

    def on_range_change_poc(self, chart, bars_before, bars_after):
        """
        Visible-range callback: recompute and redraw the POC of the visible bars.

        ``bars_before`` / ``bars_after`` are the bar counts outside the visible
        range on each side. They can be negative when empty space is scrolled
        into view. The POC line is replaced, ``visible_latest_close`` is
        updated and the ``"POC"`` legend fragment is refreshed.
        """
        df = chart.candle_data
        if df is None or df.empty:
            return
        total_bars = len(df)
        # Clamp both offsets. A negative bars_before (scrolling past the
        # earliest bar) used to become a negative iloc start, which wraps
        # around, yields an empty slice and crashed on iloc[-1] below.
        start_idx = min(max(0, int(bars_before)), total_bars)
        end_idx = max(start_idx, total_bars - max(0, int(bars_after)))
        df_range = df.iloc[start_idx:end_idx]
        if df_range.empty:
            return

        poc = POC(df_range)
        poc_line_name = "POC"
        old_line = self.horizontal_lines.pop(poc_line_name, None)
        if old_line is not None:
            old_line.delete()
        self.visible_latest_close = df_range["close"].iloc[-1]
        # Percentage distance of the last visible close from the POC.
        profit = (self.visible_latest_close - poc) / poc * 100
        poc_line = chart.horizontal_line(
            price=poc, color="#FF0000", width=4, style="solid", text=f"{poc:.2f}"
        )
        legend_text = f"POC: {poc:.2f}, poc_range_profit%: {profit:.1f}%, visible_tf_cnt: {len(df_range)}"
        self.update_legend(chart=chart, key=poc_line_name, text=legend_text)
        self.horizontal_lines[poc_line_name] = poc_line

    def check_df(self, df: pd.DataFrame):
        """
        Validate that ``df`` can be plotted.

        :raises TypeError: If ``df`` is not a DataFrame, ``time`` is not
                           datetime-like, or an OHLCV column is not numeric.
        :raises ValueError: If a required column is missing or ``df`` has no rows.
        """
        if not isinstance(df, pd.DataFrame):
            raise TypeError("df must be a pandas DataFrame")
        required_cols = ["time", "open", "high", "low", "close", "volume"]
        missing = [c for c in required_cols if c not in df.columns]
        if missing:
            raise ValueError(
                f"Missing required columns: {missing}. Required: {required_cols}"
            )
        # plot_chart indexes the last bar, so an empty frame cannot be drawn.
        if df.empty:
            raise ValueError("df must contain at least one row")
        # 'time' must be a datetime64 dtype or hold datetime / Timestamp objects.
        time_series = df["time"]
        if not (
            pd.api.types.is_datetime64_any_dtype(time_series)
            or time_series.apply(
                lambda x: isinstance(x, (pd.Timestamp, datetime))
            ).all()
        ):
            raise TypeError(
                "Column 'time' must contain datetime values (python datetime.datetime or pandas.Timestamp) "
                "or be a datetime64 dtype."
            )
        for col in ["open", "high", "low", "close", "volume"]:
            if not pd.api.types.is_numeric_dtype(df[col]):
                raise TypeError(f"Column '{col}' must be numeric (int or float)")

    def setup_crosshair_tracking(self, chart):
        """
        Show the crosshair price's distance to the latest visible close.

        A JS crosshair subscription sends the price under the cursor to the
        Python ``crosshair_price`` handler. That handler updates the
        ``"Crosshair Price Profit%"`` legend fragment.
        """
        chart.run_script(
            f"""
            {chart.id}.chart.subscribeCrosshairMove((param) => {{
                if (!param.point) return;
                const price = {chart.id}.series.coordinateToPrice(param.point.y);
                if (price !== null) {{
                    window.callbackFunction(`crosshair_price_~_${{price}}`);
                }}
            }})
            """
        )

        def on_crosshair_price(price_str):
            price = float(price_str)
            # Nothing to compare against yet; and a zero price would make the
            # percentage below divide by zero.
            if self.visible_latest_close is None or price == 0:
                return
            profit = (self.visible_latest_close - price) / price * 100
            self.update_legend(
                chart=chart,
                key="Crosshair Price Profit%",
                text=f"crosshair_price_profit%: {profit:.2f}%",
            )

        chart.win.handlers["crosshair_price"] = on_crosshair_price

    def plot_chart(
        self,
        df,
        symbol: str,
        name: str,
        timeframe: str,
        init_visible_num_bars: int = 90,
    ):
        """
        Validate ``df``, build the chart with all indicators and show it (blocking).

        :param init_visible_num_bars: Number of most recent bars visible
            initially. Values larger than ``len(df)`` or not positive show
            every bar.
        """
        self.check_df(df)
        chart = Chart(toolbox=True, inner_height=0.8, maximize=True)
        chart.layout(font_family="Times New Roman")
        chart.topbar.textbox("symbol", symbol)
        chart.topbar.textbox("name", name)
        chart.topbar.textbox("timeframe", timeframe)
        chart.legend(visible=True, font_size=14, color="#FFFFFF")
        chart.set(df)

        # Initial visible window: clamp so short frames don't raise IndexError.
        n_visible = (
            init_visible_num_bars
            if 0 < init_visible_num_bars <= len(df)
            else len(df)
        )
        end_time = df["time"].iloc[-1]
        start_time = df["time"].iloc[-n_visible]
        chart.set_visible_range(start_time, end_time)

        # Recompute the POC whenever the visible range is zoomed / scrolled.
        chart.events.range_change += self.on_range_change_poc
        self.setup_crosshair_tracking(chart)

        # Indicators (uncomment extra EMAs as needed).
        # self.add_ema(df, chart, period=10)
        # self.add_ema(df, chart, period=20)
        self.add_ema(df, chart, period=30)
        # self.add_ema(df, chart, period=60)
        self.add_macd(df, chart)
        self.add_cci(df, chart, period=14)
        self.add_TD(df, chart)
        self.add_buy_sell_signal_markers(df, chart)
        chart.show(block=True)
if __name__ == "__main__":
    # Placeholder entry point: the body is just `...`, so running this module
    # directly does nothing.
    ...