204 lines
5.9 KiB
Python
204 lines
5.9 KiB
Python
import pandas as pd
|
|
from db_config import DatabaseManager, DatabaseConfig
|
|
from loguru import logger
|
|
import talib as ta
|
|
from lightweight_charts import Chart
|
|
|
|
|
|
# Line colours for the EMA periods this script plots. Periods without an entry
# (including add_ema's own default of 50) fall back to _EMA_DEFAULT_COLOR
# instead of raising KeyError.
_EMA_PERIOD_COLORS = {
    10: "#E42C2A",
    20: "#E49D30",
    30: "#3CFF0A",
    60: "#3275E4",
    200: "#B06CE3",
}
_EMA_DEFAULT_COLOR = "#9E9E9E"


def add_ema(df, chart, period: int = 50, color=None):
    """Compute an EMA of ``df["close"]`` with TA-Lib and draw it as a line on *chart*.

    Side effect: the EMA series is stored in *df* (in place) under the column
    ``f"EMA {period}"``, which is also used as the line's name.

    :param df: price frame with ``"time"`` and ``"close"`` columns.
    :param chart: lightweight_charts chart (or subchart) to draw on.
    :param period: look-back period passed to ``talib.EMA``.
    :param color: explicit line colour. When omitted, the per-period palette
        entry is used, or a neutral grey for periods that have no entry.
    """
    name = f"EMA {period}"
    df[name] = ta.EMA(df["close"], timeperiod=period)

    if color is None:
        color = _EMA_PERIOD_COLORS.get(period, _EMA_DEFAULT_COLOR)

    line = chart.create_line(name, color=color, width=2)
    line.set(df[["time", name]])
|
|
|
|
|
|
def add_cci(df, chart, period: int = 14, height: float = 0.1, position: str = "bottom"):
    """Plot the Commodity Channel Index in its own time-synced subchart.

    Side effect: the CCI series (from ``talib.CCI`` over high/low/close) is
    written into *df* in place as the ``"cci"`` column.

    :param df: price frame with ``"time"``, ``"high"``, ``"low"`` and ``"close"``.
    :param chart: lightweight_charts ``Chart`` that owns the new subchart.
    :param period: CCI look-back period.
    :param height: subchart height passed to ``create_subchart``.
    :param position: subchart position passed to ``create_subchart``.
    """
    df["cci"] = ta.CCI(df["high"], df["low"], df["close"], timeperiod=period)

    pane = chart.create_subchart(position=position, width=1, height=height, sync=True)
    pane.legend(visible=True)
    # Hide this pane's own time axis.
    pane.time_scale(visible=False)

    pane.create_line(name="cci", color="#FF5722", width=2).set(df[["time", "cci"]])
|
|
|
|
|
|
def add_macd(
    df,
    chart,
    height: float = 0.2,
    position: str = "bottom",
    fastperiod: int = 12,
    slowperiod: int = 26,
    signalperiod: int = 9,
):
    """Plot MACD (coloured histogram plus MACD and signal lines) in a synced subchart.

    Side effect: the ``"macd"``, ``"signal"`` and ``"histogram"`` columns are
    written into *df* in place.

    :param df: price frame with ``"time"`` and ``"close"`` columns.
    :param chart: lightweight_charts ``Chart`` that owns the new subchart.
    :param height: subchart height passed to ``create_subchart``.
    :param position: subchart position passed to ``create_subchart``.
    :param fastperiod: fast period passed to ``talib.MACD`` (default 12).
    :param slowperiod: slow period passed to ``talib.MACD`` (default 26).
    :param signalperiod: signal period passed to ``talib.MACD`` (default 9).
    """
    macd, signal, hist = ta.MACD(
        df["close"],
        fastperiod=fastperiod,
        slowperiod=slowperiod,
        signalperiod=signalperiod,
    )
    df["macd"] = macd
    df["signal"] = signal
    df["histogram"] = hist

    macd_chart = chart.create_subchart(
        position=position, width=1, height=height, sync=True
    )
    macd_chart.legend(visible=True)
    macd_chart.time_scale(visible=False)

    histogram = macd_chart.create_histogram(name="histogram")
    hist_data = df[["time", "histogram"]].copy()
    # Per-bar colour: negative -> green (#26a69a), otherwise red (#ef5350),
    # presumably following the red-up / green-down convention. NaN values
    # compare False against 0 and are therefore coloured red.
    hist_data["color"] = hist_data["histogram"].apply(
        lambda x: "#26a69a" if x < 0 else "#ef5350"
    )
    histogram.set(hist_data)

    macd_line = macd_chart.create_line(name="macd", color="#2962FF", width=2)
    macd_line.set(df[["time", "macd"]])

    signal_line = macd_chart.create_line(name="signal", color="#FF6D00", width=2)
    signal_line.set(df[["time", "signal"]])
|
|
|
|
|
|
def TD(dataframe: pd.DataFrame, lookback: int = 4) -> list:
    """Return a signed streak count comparing each close with the close ``lookback`` bars earlier.

    For every bar ``i >= lookback``:

    * if ``close[i] > close[i - lookback]`` the up-streak grows and the count
      is ``1, 2, 3, ...``;
    * otherwise (lower *or equal*) the down-streak grows and the count is
      ``-1, -2, -3, ...``.

    A streak in one direction resets the other. Counts are not capped, so they
    keep growing past 9. The first ``lookback`` bars have no reference close
    and get 0.

    :param dataframe: frame with a ``"close"`` column.
    :param lookback: how many bars back to compare against (default 4).
    :return: list of ints with exactly one entry per row of *dataframe*, so it
        can be assigned directly as a column even for very short frames.
    :raises ValueError: if ``lookback`` is less than 1.
    """
    if lookback < 1:
        raise ValueError(f"lookback must be >= 1, got {lookback}")

    close = dataframe["close"].to_list()
    # Pad the warm-up bars with 0, but never beyond the number of rows;
    # otherwise `df["TD"] = TD(df)` fails on frames shorter than `lookback`.
    td = [0] * min(lookback, len(close))
    up = 0
    down = 0
    for i in range(lookback, len(close)):
        if close[i] > close[i - lookback]:
            up += 1
            down = 0
            td.append(up)
        else:
            down -= 1
            up = 0
            td.append(down)
    return td
|
|
|
|
|
|
def add_TD(df, chart):
    """Attach the TD count to *chart* and mark every bar whose count is exactly +9 or -9.

    Side effect: ``TD(df)`` is stored in *df* (in place) as the ``"TD"`` column.
    The series is added as an invisible line on its own ``"td_scale"`` price
    scale. Markers are then placed: a green up-arrow below the bar for +9, and
    a red down-arrow above the bar for -9.
    """
    # TD value -> (marker position, marker shape, marker colour)
    marker_styles = {
        9: ("below", "arrow_up", "#4CAF50"),
        -9: ("above", "arrow_down", "#F44336"),
    }

    df["TD"] = TD(df)

    # Transparent, zero-width line so nothing is drawn on the chart
    # (presumably kept only so the value is available, e.g. in the legend).
    td_series = chart.create_line(
        name="TD",
        color="rgba(0, 0, 0, 0)",
        width=0,
        price_line=False,
        price_label=False,
        price_scale_id="td_scale",
    )
    td_series.precision(0)
    td_series.set(df[["time", "TD"]])

    markers = []
    for row in df[["time", "TD"]].to_dict(orient="records"):
        count = row["TD"]
        style = marker_styles.get(count)
        if style is None:
            continue
        position, shape, color = style
        markers.append(
            {
                "time": row["time"].strftime("%Y-%m-%d"),
                "position": position,
                "shape": shape,
                "color": color,
                "text": f"{count}",
            }
        )

    chart.marker_list(markers)
|
|
|
|
|
|
def get_kline(code: str, env: str = "daily") -> pd.DataFrame:
    """Load the K-line (OHLCV) history of one index from ``public.index_kline``.

    :param code: index code to select, e.g. ``"399986"``.
    :param env: environment name handed to ``DatabaseConfig`` (e.g. ``"daily"``
        or ``"online"``); defaults to ``"daily"``, the value previously hard-coded.
    :return: DataFrame with columns ``time, open, high, low, close, volume``,
        ordered by date. ``time`` is converted with ``pd.to_datetime`` and the
        price/volume columns are coerced to float (unparseable values become
        NaN). When the code has no rows, an empty frame with those columns is
        returned instead of raising ``KeyError``.
    """
    db_config = DatabaseConfig(env=env)
    # NOTE(review): if connection_string embeds a password this logs it — confirm.
    logger.info(f"数据库连接: {db_config.connection_string}")

    db_manager = DatabaseManager(db_config)

    # Double embedded single quotes so `code` cannot terminate the SQL string
    # literal (standard SQL / PostgreSQL with standard_conforming_strings=on).
    # TODO: switch to bind parameters if DatabaseManager.execute_query supports them.
    quoted_code = str(code).replace("'", "''")
    sql = (
        "SELECT date as time, open, high, low, close, volume "
        f"FROM public.index_kline where code='{quoted_code}' order by date;"
    )
    rows = db_manager.execute_query(sql)

    columns = ["time", "open", "high", "low", "close", "volume"]
    # Passing `columns` keeps the schema even when the query returns no rows.
    df = pd.DataFrame([dict(row) for row in rows], columns=columns)
    df["time"] = pd.to_datetime(df["time"])
    for col in columns[1:]:
        df[col] = pd.to_numeric(df[col], errors="coerce").astype(float)
    return df
|
|
|
|
|
|
def _resample_rule(preferred: str, legacy: str) -> str:
    """Return the pandas frequency alias *preferred* if this pandas accepts it, else *legacy*.

    pandas 2.2 renamed the period-end aliases (``M``->``ME``, ``Q``->``QE``,
    ``Y``->``YE``) and deprecated the old spellings. Older pandas only
    understands the old ones.
    """
    from pandas.tseries.frequencies import to_offset

    try:
        to_offset(preferred)
    except ValueError:
        return legacy
    return preferred


# How each OHLCV column is combined when several bars are merged into one.
_OHLCV_AGG = {
    "open": "first",  # first open in the period
    "high": "max",  # highest high
    "low": "min",  # lowest low
    "close": "last",  # last close
    "volume": "sum",  # total volume
}


def resample_data(df: pd.DataFrame, timeframe: str) -> pd.DataFrame:
    """Aggregate K-line bars into a coarser timeframe.

    :param df: bars with a ``"time"`` *column* (datetime-like) plus any of
        ``open``/``high``/``low``/``close``/``volume``. ``time`` is moved to the
        index internally, so it must be a column, not already the index.
    :param timeframe: ``'1D'`` (day), ``'1W'`` (week), ``'1M'`` (month),
        ``'3M'`` (calendar quarter) or ``'1Y'`` (year). Any other value returns
        *df* itself unchanged.
    :return: a new DataFrame with ``time`` back as a column and one row per
        non-empty period. Rows are labelled the way pandas labels that rule
        (period end for week/month/quarter/year).
    """
    # timeframe -> (alias for pandas >= 2.2, alias for older pandas)
    timeframe_rules = {
        "1D": ("D", "D"),
        "1W": ("W", "W"),
        "1M": ("ME", "M"),
        # Calendar quarters (ending Mar/Jun/Sep/Dec). A "3M" rule would anchor
        # its bins at the first data month instead.
        "3M": ("QE", "Q"),
        "1Y": ("YE", "Y"),
    }
    if timeframe not in timeframe_rules:
        return df

    rule = _resample_rule(*timeframe_rules[timeframe])

    # Sort so "first"/"last" really pick the earliest/latest bar in each period.
    indexed = df.set_index("time").sort_index()
    spec = {col: how for col, how in _OHLCV_AGG.items() if col in indexed.columns}

    # dropna() removes periods that contained no bars (e.g. weekends for "1D").
    resampled = indexed.resample(rule).agg(spec).dropna()
    return resampled.reset_index()
|
|
|
|
|
|
if __name__ == "__main__":
    # Index to chart and the bar size to aggregate to.
    code = "399986"
    timeframe = "1W"

    df = resample_data(get_kline(code=code), timeframe)

    # Main price pane; the MACD and CCI subcharts added below use their own
    # height defaults (0.2 and 0.1).
    chart = Chart(toolbox=True, inner_height=0.7, maximize=True)
    chart.topbar.textbox("symbol", code)
    chart.topbar.textbox("timeframe", timeframe)
    chart.legend(visible=True)
    chart.set(df)

    for ema_period in (10, 20, 30, 60):
        add_ema(df, chart, period=ema_period)
    add_macd(df, chart)
    add_cci(df, chart, period=14)
    add_TD(df, chart)

    chart.show(block=True)
|