# etf/chart.py
# Last modified 2025-10-18 16:33:50 +08:00 · 213 lines · 6.2 KiB · Python

import pandas as pd
from db_config import DatabaseManager, DatabaseConfig
from loguru import logger
import talib as ta
from lightweight_charts import Chart
# Line colours for the EMA periods this script plots. Periods not listed here
# fall back to DEFAULT_EMA_COLOR instead of raising KeyError (the default
# period=50 used to crash because it was not in the table).
EMA_PERIOD_COLORS = {
    10: "#E42C2A",
    20: "#E49D30",
    30: "#3CFF0A",
    60: "#3275E4",
    200: "#B06CE3",
}
DEFAULT_EMA_COLOR = "#9E9E9E"


def ema_color(period: int) -> str:
    """Return the line colour used for an EMA of ``period`` bars."""
    return EMA_PERIOD_COLORS.get(period, DEFAULT_EMA_COLOR)


def add_ema(df, chart, period: int = 50):
    """Compute an EMA of ``df["close"]`` and draw it as a line on ``chart``.

    Side effect: the EMA is stored in ``df[f"EMA {period}"]``.

    :param df: price frame with ``time`` and ``close`` columns.
    :param chart: chart object providing ``create_line``.
    :param period: EMA time period passed to ``ta.EMA``.
    """
    name = f"EMA {period}"
    df[name] = ta.EMA(df["close"], timeperiod=period)
    line = chart.create_line(name, color=ema_color(period), width=2)
    line.set(df[["time", name]])
def add_cci(df, chart, period: int = 14, height: float = 0.1, position: str = "bottom"):
    """Plot the Commodity Channel Index in its own time-synced subchart.

    Side effect: the CCI series is written to ``df["cci"]``.

    :param df: price frame with ``time``, ``high``, ``low`` and ``close`` columns.
    :param chart: parent chart providing ``create_subchart``.
    :param period: CCI time period passed to ``ta.CCI``.
    :param height: subchart height passed to ``create_subchart``.
    :param position: subchart position passed to ``create_subchart``.
    """
    df["cci"] = ta.CCI(df["high"], df["low"], df["close"], timeperiod=period)

    panel = chart.create_subchart(position=position, width=1, height=height, sync=True)
    panel.legend(visible=True)
    # Hide this pane's own time axis.
    panel.time_scale(visible=False)

    panel.create_line(name="cci", color="#FF5722", width=2).set(df[["time", "cci"]])
def macd_bar_color(value) -> str:
    """Return the histogram bar colour for one MACD histogram value.

    Negative values are green (``#26a69a``); everything else is red
    (``#ef5350``). NaN compares false with ``< 0``, so warm-up NaN bars get
    the red colour.
    """
    return "#26a69a" if value < 0 else "#ef5350"


def add_macd(
    df,
    chart,
    height: float = 0.2,
    position: str = "bottom",
    fast_period: int = 12,
    slow_period: int = 26,
    signal_period: int = 9,
):
    """Plot MACD (line, signal line and coloured histogram) in a synced subchart.

    Side effect: writes ``df["macd"]``, ``df["signal"]`` and ``df["histogram"]``.

    :param df: price frame with ``time`` and ``close`` columns.
    :param chart: parent chart providing ``create_subchart``.
    :param height: subchart height passed to ``create_subchart``.
    :param position: subchart position passed to ``create_subchart``.
    :param fast_period: ``fastperiod`` for ``ta.MACD`` (default 12).
    :param slow_period: ``slowperiod`` for ``ta.MACD`` (default 26).
    :param signal_period: ``signalperiod`` for ``ta.MACD`` (default 9).
    """
    macd, signal, hist = ta.MACD(
        df["close"],
        fastperiod=fast_period,
        slowperiod=slow_period,
        signalperiod=signal_period,
    )
    df["macd"] = macd
    df["signal"] = signal
    df["histogram"] = hist

    macd_chart = chart.create_subchart(
        position=position, width=1, height=height, sync=True
    )
    macd_chart.legend(visible=True)
    macd_chart.time_scale(visible=False)

    # Histogram bars carry a per-row colour column.
    histogram = macd_chart.create_histogram(name="histogram")
    hist_data = df[["time", "histogram"]].copy()
    hist_data["color"] = hist_data["histogram"].apply(macd_bar_color)
    histogram.set(hist_data)

    macd_line = macd_chart.create_line(name="macd", color="#2962FF", width=2)
    macd_line.set(df[["time", "macd"]])
    signal_line = macd_chart.create_line(name="signal", color="#FF6D00", width=2)
    signal_line.set(df[["time", "signal"]])
def TD(dataframe: pd.DataFrame, lookback: int = 4) -> list:
    """Compute a signed consecutive-bar count from ``dataframe["close"]``.

    Each bar's close is compared with the close ``lookback`` bars earlier.
    A strictly higher close extends a positive run (1, 2, 3, ...). Any other
    close (lower or equal) extends a negative run (-1, -2, ...). Switching
    direction resets the opposite run. The first ``lookback`` bars have no
    reference close and are 0.

    :param dataframe: frame with a ``close`` column.
    :param lookback: how many bars back to compare against (default 4).
    :return: list of ints with exactly one entry per row of ``dataframe``.
    :raises ValueError: if ``lookback`` is less than 1.
    """
    if lookback < 1:
        raise ValueError(f"lookback must be >= 1, got {lookback}")
    close = dataframe["close"].to_list()
    # Pad only the warm-up bars that actually exist, so the result length
    # always matches the frame (a fixed 4-zero prefix broke frames < 4 rows).
    td = [0] * min(lookback, len(close))
    up = down = 0
    for i in range(lookback, len(close)):
        if close[i] > close[i - lookback]:
            up += 1
            down = 0
            td.append(up)
        else:
            # NOTE(review): an equal close is counted as "down" here; confirm
            # this is intended rather than interrupting both runs.
            down -= 1
            up = 0
            td.append(down)
    return td
def add_TD(df, chart):
    """Attach TD counts to ``chart`` as an invisible series and mark +/-9 bars.

    Side effect: writes ``df["TD"]`` (computed by ``TD``). The series is drawn
    fully transparent on its own price scale ("td_scale") with precision 0, so
    no line is visible on the chart. Each bar whose count is exactly +9 gets a
    green up-arrow below it; exactly -9 gets a red down-arrow above it.
    """
    df["TD"] = TD(df)

    td_series = chart.create_line(
        name="TD",
        color="rgba(0, 0, 0, 0)",  # fully transparent: no visible line
        width=0,
        price_line=False,
        price_label=False,
        price_scale_id="td_scale",
    )
    td_series.precision(0)
    td_series.set(df[["time", "TD"]])

    # count -> (position, shape, color); counts not listed get no marker.
    marker_styles = {
        9: ("below", "arrow_up", "#4CAF50"),
        -9: ("above", "arrow_down", "#F44336"),
    }
    markers = []
    for row in df[["time", "TD"]].to_dict(orient="records"):
        style = marker_styles.get(row["TD"])
        if style is None:
            continue
        position, shape, color = style
        markers.append(
            {
                "time": row["time"].strftime("%Y-%m-%d"),
                "position": position,
                "shape": shape,
                "color": color,
                "text": f"{row['TD']}",
            }
        )
    chart.marker_list(markers)
def get_kline(code: str, env: str = "daily") -> pd.DataFrame:
    """Load the K-line (OHLCV) rows of one index from ``public.index_kline``.

    :param code: index code as stored in the ``code`` column (e.g. ``"399986"``).
        Only letters, digits, ``.``, ``_`` and ``-`` are accepted, because the
        value is embedded in the SQL text.
    :param env: environment name passed to ``DatabaseConfig`` (``"online"`` is
        the other value this script has used).
    :return: DataFrame with columns ``time`` (datetime64) and ``open``, ``high``,
        ``low``, ``close``, ``volume`` (float), ordered by date. It is empty,
        with those columns, when the code has no rows.
    :raises ValueError: if ``code`` is empty or contains other characters.
    """
    # The code is interpolated into the SQL string, so whitelist its characters
    # to rule out injection (e.g. a stray single quote).
    if not code or not all(ch.isalnum() or ch in "._-" for ch in code):
        raise ValueError(f"invalid index code: {code!r}")

    db_config = DatabaseConfig(env=env)
    # NOTE(review): if connection_string embeds credentials they end up in the
    # log — confirm this is acceptable.
    logger.info(f"数据库连接: {db_config.connection_string}")
    db_manager = DatabaseManager(db_config)
    sql = f"SELECT date as time, open, high, low, close, volume FROM public.index_kline where code='{code}' order by date;"
    res = db_manager.execute_query(sql)

    columns = ["time", "open", "high", "low", "close", "volume"]
    # Passing ``columns`` keeps the schema even when the query returns no rows,
    # so the conversions below don't fail with KeyError.
    df = pd.DataFrame([dict(item) for item in res], columns=columns)
    df["time"] = pd.to_datetime(df["time"])
    for col in columns[1:]:
        df[col] = pd.to_numeric(df[col], errors="coerce").astype(float)
    return df
def _month_year_end_aliases() -> tuple:
    """Return the (month-end, year-end) resample aliases this pandas accepts.

    pandas 2.2 introduced "ME"/"YE" and deprecated "M"/"Y"; older versions
    only understand "M"/"Y".
    """
    from pandas.tseries.frequencies import to_offset

    try:
        to_offset("ME")
    except ValueError:
        return "M", "Y"
    return "ME", "YE"


def resample_data(df: pd.DataFrame, timeframe: str) -> pd.DataFrame:
    """Resample daily OHLCV bars to a coarser timeframe.

    :param df: daily bars with a ``time`` column (datetime, not the index)
        plus any of ``open``/``high``/``low``/``close``/``volume``.
    :param timeframe: one of ``'1D'``, ``'1W'``, ``'1M'``, ``'3M'``, ``'1Y'``.
        Any other value returns ``df`` itself, unchanged.
    :return: new DataFrame with ``time`` as a column and one row per period,
        labelled with pandas' default bin label. Periods containing no data
        (any NaN after aggregation) are dropped.
    """
    if timeframe not in ("1D", "1W", "1M", "3M", "1Y"):
        return df

    month_end, year_end = _month_year_end_aliases()
    rule = {
        "1D": "D",  # daily
        "1W": "W",  # weekly
        "1M": month_end,  # monthly
        "3M": f"3{month_end}",  # quarterly (3-month bins)
        "1Y": year_end,  # yearly
    }[timeframe]

    # How each OHLCV column collapses into one bar; columns missing from df
    # are skipped instead of raising.
    agg_rules = {
        "open": "first",
        "high": "max",
        "low": "min",
        "close": "last",
        "volume": "sum",
    }
    indexed = df.set_index("time")
    present = {col: how for col, how in agg_rules.items() if col in indexed.columns}

    resampled = indexed.resample(rule).agg(present).dropna()
    # Move the time index back into a column.
    return resampled.reset_index()
if __name__ == "__main__":
    code = "399986"
    timeframe = "1W"

    # Index directory mapping index codes (column "代码") to names (column "名称").
    index_list = pd.read_csv(
        "/Users/aszer/Documents/vscode/etf/data/index_all_stock.csv",
        encoding="utf-8-sig",
        # Keep codes as strings: otherwise pandas parses them as integers
        # (dropping leading zeros) and they never equal the string ``code``.
        dtype={"代码": str},
    )
    matches = index_list.loc[index_list["代码"] == code, "名称"]
    if matches.empty:
        # Don't abort the whole chart just because the name lookup failed.
        logger.warning(f"index code {code} not found in index list; using code as name")
        name = code
    else:
        name = matches.iloc[0]

    df = get_kline(code=code)
    df = resample_data(df, timeframe)

    chart = Chart(toolbox=True, inner_height=0.7, maximize=True)
    chart.topbar.textbox("symbol", code)
    chart.topbar.textbox("name", name)
    chart.topbar.textbox("timeframe", timeframe)
    # chart.time_scale(visible=False)  # optionally hide the main chart's time axis
    chart.legend(visible=True)
    chart.set(df)

    # EMA overlays on the main price pane.
    for period in (10, 20, 30, 60):
        add_ema(df, chart, period=period)
    # Indicator subcharts, then TD markers on the main chart.
    add_macd(df, chart)
    add_cci(df, chart, period=14)
    add_TD(df, chart)

    chart.show(block=True)