diff --git a/chart.py b/chart.py new file mode 100644 index 0000000..80a498e --- /dev/null +++ b/chart.py @@ -0,0 +1,203 @@ +import pandas as pd +from db_config import DatabaseManager, DatabaseConfig +from loguru import logger +import talib as ta +from lightweight_charts import Chart + + +def add_ema(df, chart, period: int = 50): + ema = ta.EMA(df["close"], timeperiod=period) + name = f"EMA {period}" + df[name] = ema + PERIOD_COLORS = { + 10: "#E42C2A", + 20: "#E49D30", + 30: "#3CFF0A", + 60: "#3275E4", + 200: "#B06CE3", + } + line = chart.create_line(name, color=PERIOD_COLORS[period], width=2) + line.set(df[["time", name]]) + + +def add_cci(df, chart, period: int = 14, height: float = 0.1, position: str = "bottom"): + cci = ta.CCI(df["high"], df["low"], df["close"], timeperiod=period) + df["cci"] = cci + cci_chart = chart.create_subchart( + position=position, width=1, height=height, sync=True + ) + cci_chart.legend(visible=True) + cci_chart.time_scale(visible=False) + cci_line = cci_chart.create_line(name="cci", color="#FF5722", width=2) + cci_line.set(df[["time", "cci"]]) + + +def add_macd(df, chart, height: float = 0.2, position: str = "bottom"): + macd, signal, hist = ta.MACD( + df["close"], fastperiod=12, slowperiod=26, signalperiod=9 + ) + df["macd"] = macd + df["signal"] = signal + df["histogram"] = hist + macd_chart = chart.create_subchart( + position=position, width=1, height=height, sync=True + ) + macd_chart.legend(visible=True) + macd_chart.time_scale(visible=False) + + histogram = macd_chart.create_histogram(name="histogram") + hist_data = df[["time", "histogram"]].copy() + hist_data["color"] = hist_data["histogram"].apply( + lambda x: "#26a69a" if x < 0 else "#ef5350" # 绿色 : 红色 + ) + histogram.set(hist_data) + + macd_line = macd_chart.create_line(name="macd", color="#2962FF", width=2) + macd_line.set(df[["time", "macd"]]) + signal_line = macd_chart.create_line(name="signal", color="#FF6D00", width=2) + signal_line.set(df[["time", "signal"]]) + + +def TD(dataframe: pd.DataFrame): + close = dataframe["close"].to_list() + td = [0, 0, 0, 0] + up = 0 + down = 0 + for i in range(4, len(close)): + if close[i] > close[i - 4]: + up += 1 + down = 0 + td.append(up) + else: + down -= 1 + up = 0 + td.append(down) + return td + + +def add_TD(df, chart): + df["TD"] = TD(df) + td_line = chart.create_line( + name="TD", + color="rgba(0, 0, 0, 0)", # 透明色,不在图表上显示线条 + width=0, + price_line=False, + price_label=False, + price_scale_id="td_scale", + ) + td_line.precision(0) + td_line.set(df[["time", "TD"]]) + TDs = df[["time", "TD"]].to_dict(orient="records") + markers = [] + for item in TDs: + if item["TD"] == 9: + markers.append( + { + "time": item["time"].strftime("%Y-%m-%d"), + "position": "below", + "shape": "arrow_up", + "color": "#4CAF50", + "text": f"{item['TD']}", + } + ) + elif item["TD"] == -9: + markers.append( + { + "time": item["time"].strftime("%Y-%m-%d"), + "position": "above", + "shape": "arrow_down", + "color": "#F44336", + "text": f"{item['TD']}", + } + ) + + chart.marker_list(markers) + + +def get_kline(code: str) -> list: + """ + 获取所有指数代码 + :return: + """ + env = "online" + env = "daily" + db_config = DatabaseConfig(env=env) + logger.info(f"数据库连接: {db_config.connection_string}") + + db_manager = DatabaseManager(db_config) + sql = f"SELECT date as time, open, high, low, close, volume FROM public.index_kline where code='{code}' order by date;" + res = db_manager.execute_query(sql) + data_list = [dict(item) for item in res] + df = pd.DataFrame(data_list) + df["time"] = pd.to_datetime(df["time"]) + num_cols = ["open", "high", "low", "close", "volume"] + for col in num_cols: + if col in df.columns: + df[col] = pd.to_numeric(df[col], errors="coerce").astype(float) + return df + + +def resample_data(df: pd.DataFrame, timeframe: str) -> pd.DataFrame: + """ + 对日线数据进行重采样 + :param df: 原始日线数据(time 作为索引) + :param timeframe: 目标周期 ('1D', '1W', '1M', '3M', '1Y') + :return: 重采样后的 DataFrame + """ + # 映射周期到 pandas resample 规则 + timeframe_map = { + "1D": "D", # 日线 + "1W": "W", # 周线 + "1M": "M", # 月线 + "3M": "3M", # 季线 + "1Y": "Y", # 年线 + } + + if timeframe not in timeframe_map: + return df + df = df.set_index("time") + + rule = timeframe_map[timeframe] + + # 对 OHLCV 数据进行重采样 + resampled = ( + df.resample(rule) + .agg( + { + "open": "first", # 开盘价取第一个 + "high": "max", # 最高价取最大值 + "low": "min", # 最低价取最小值 + "close": "last", # 收盘价取最后一个 + "volume": "sum", # 成交量求和 + } + ) + .dropna() + ) + + # 重置索引,将 time 变回列 + resampled = resampled.reset_index() + return resampled + + +if __name__ == "__main__": + code = "399986" + timeframe = "1W" + df = get_kline(code=code) + df = resample_data(df, timeframe) + chart = Chart(toolbox=True, inner_height=0.7, maximize=True) + + chart.topbar.textbox("symbol", code) + chart.topbar.textbox("timeframe", timeframe) + # chart.time_scale(visible=False) # 将主图的时间轴隐藏 + chart.legend(visible=True) + + chart.set(df) + add_ema(df, chart, period=10) + add_ema(df, chart, period=20) + add_ema(df, chart, period=30) + add_ema(df, chart, period=60) + add_macd(df, chart) + add_cci(df, chart, period=14) + add_TD(df, chart) + + chart.show(block=True)