diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..bc4de5c --- /dev/null +++ b/.env.example @@ -0,0 +1,16 @@ +# ETF策略项目 - 环境变量配置模板 +# 复制此文件为 .env 并填入真实值 + +# ==================== Tushare API ==================== +TUSHARE_TOKEN=your_tushare_token_here + +# ==================== 钉钉机器人 ==================== +DINGTALK_WEBHOOK=https://oapi.dingtalk.com/robot/send?access_token=xxx +DINGTALK_SECRET=SECxxx + +# ==================== PostgreSQL数据库 ==================== +DB_HOST=192.168.0.115 +DB_PORT=5432 +DB_NAME=etf_db +DB_USER=admin +DB_PASS=admin diff --git a/.gitignore b/.gitignore index 4d5f47e..d0a9ac9 100644 --- a/.gitignore +++ b/.gitignore @@ -162,9 +162,11 @@ temp/ # API keys and secrets .env +.env.local config.ini secrets.json api_keys.txt +config_local.py # Database files *.db @@ -187,4 +189,12 @@ data_cache/ *.jpeg *.gif *.svg -*.csv + + +# Report files (keep examples) +report*.csv +report*.html +report*.png +!example_*.csv +!example_*.html +!example_*.png diff --git a/ETF轮动策略方案.md b/ETF轮动策略方案.md new file mode 100644 index 0000000..09294a7 --- /dev/null +++ b/ETF轮动策略方案.md @@ -0,0 +1,426 @@ +下面是一份基于你提供的四篇文档整理出的、可直接指导项目实施的**ETF轮动策略技术方案文档**。内容覆盖:策略原理、数据与因子、回测实现、CAGR纠错、QMT实盘迁移、常见坑与工程规范。 + +--- + +# 一、项目目标与整体架构 + +**项目目标**: +构建一套可从回测顺滑迁移到实盘的 ETF 轮动策略系统,实现: + +- 风格/资产/行业轮动逻辑的灵活配置 +- 日频轮动(支持未来扩展到周频) +- 有完整绩效评估(含修正后的 CAGR) +- 可在 QMT 上稳定运行的实盘版本(含仓位隔离、委托跟踪等) + +**整体架构分层**: + +1. **数据层**:行情与指数数据获取(AkShare、本地QMT历史数据)。 +2. **因子层**:动量/趋势得分计算(N日涨幅、斜率、斜率×R²)。 +3. **信号层**:基于得分的强弱排序与轮动信号生成。 +4. **回测层**:净值计算、quantstats 绩效评估(含 CAGR 修正)。 +5. **实盘层(QMT)**:数据驱动、下单逻辑、委托跟踪、仓位隔离、异常处理。 +6. 
**运维与风控**:日志、持仓文件、参数管理、错误监控。 + +--- + +# 二、策略原理与业务设计 + +## 2.1 轮动类型与本项目定位 + +文档中提到三类轮动:[1] + +- 资产轮动:股、债、商品、现金之间切换; +- 行业/板块轮动:不同行业 ETF 之间切换; +- 风格轮动:大小盘、价值/成长等风格之间切换。 + +**本项目主线**:以**风格轮动**为核心案例(大盘/小盘 + 价值/成长),同时方案设计要做到: + +- 候选池是可配置的(可替换为行业ETF、跨市场ETF组合); +- 动量计算与信号生成模块,对“标的池”保持抽象,不绑定具体标的。 + +## 2.2 候选池设计 + +基础文档中的风格轮动候选池为四只 ETF:[1][2] + +- 沪深300ETF:510300(大盘) +- 中证500ETF:510500(小盘) +- 红利ETF:510880(价值) +- 创业板ETF:159915(成长) + +**方案要求**: + +- 使用配置文件或策略参数维护 `CODE_LIST`: + ```python + CODE_LIST = ['510300', '510500', '510880', '159915'] + ``` +- 候选池可轻松扩展(文档中有将候选池扩展到纳指ETF、黄金ETF等的例子[2],可抽象为多资产轮动)。 + +**注意点**: + +- 候选池选择带有“主观视角”和后视镜效应,属于策略建模假设,应在项目文档中明确:回测收益不能简单外推到未来[1]。 +- 对于项目落地,可在候选池上设计多个版本(保守/均衡/激进),以便 AB 测试。 + +--- + +# 三、数据与因子计算方案 + +## 3.1 Python 环境与依赖 + +文档中使用的环境与库:[1][2] + +- Python 3.x(文中示例 3.13) +- 关键库:`numpy`, `pandas`, `akshare`, `matplotlib`, `scikit-learn`, `quantstats` + +方案建议: + +- 统一到 3.10+ 的稳定版本; +- 通过 `requirements.txt` 管理: + ```text + numpy + pandas + akshare + matplotlib + scikit-learn + quantstats + ``` + +## 3.2 行情数据获取(离线回测) + +示例代码结构如下:[2] + +```python +code_list = ['510300', '510500', '510880', '159915'] +start_date = '20150101' +end_date = '20250828' + +df_list = [] +for code in code_list: + df = ak.fund_etf_hist_em( + symbol=code, + period='daily', + start_date=start_date, + end_date=end_date, + adjust='hfq' # 后复权 + ) + df.insert(0, 'code', code) + df_list.append(df) + time.sleep(3) + +all_df = pd.concat(df_list, ignore_index=True) +data = all_df.pivot(index='日期', columns='code', values='收盘') +data.index = pd.to_datetime(data.index) +data = data.sort_index() +``` + +**项目要求**: + +1. 抽象出 `DataFetcher` 模块,支持: + - AkShare 拉取 ETF 历史行情; + - 可切换为本地 CSV/HDF5(防止外网依赖导致不稳定)。 +2. 
指数基准数据: + - 必须同时获取沪深300指数(000300.SH)等基准,用于绩效评估和交易日推算(QMT 实盘中利用指数序列推前一交易日,避免停牌问题[3])。 + +## 3.3 日收益率与动量(基础版) + +基础动量使用“前 N 日涨幅”:[2] + +```python +N = 10 # 动量窗口 + +for code in code_list: + data[f'日收益率_{code}'] = data[code] / data[code].shift(1) - 1 + data[f'涨幅_{code}'] = data[code] / data[code].shift(N+1) - 1.0 + +data = data.dropna() +``` + +**工程要求**: + +- N 作为可配置参数; +- 函数化封装,方便后续切换为斜率或斜率×R²得分。 + +## 3.4 趋势得分(改进版:斜率×R²) + +文档中最终表现最好的是“斜率×决定系数 R²”的趋势得分[2][4]: + +```python +def calculate_score(srs, N=25): + if srs.shape[0] < N: + return np.nan + x = np.arange(1, N+1) + y = srs.values / srs.values[0] # 归一化 + lr = LinearRegression().fit(x.reshape(-1, 1), y) + slope = lr.coef_[0] # 斜率 + r_squared = lr.score(x.reshape(-1, 1), y) # R² + score = 10000 * slope * r_squared + return score + +N = 25 # 斜率计算长度 + +for code in code_list: + data[f'日收益率_{code}'] = data[code] / data[code].shift(1) - 1 + data[f'得分_{code}'] = data[code].rolling(N).apply(lambda x: calculate_score(x, N)) + +data = data.dropna() +``` + +**项目选型**: + +- 建议**默认采用斜率×R²得分**作为主策略动量因子: + - 斜率:趋势方向与强度; + - R²:趋势拟合优度,过滤剧烈波动导致的“假斜率”; + - 实证上,从十年 7 倍提升到 19 倍收益[4]。 +- 保留“简单 N 日涨幅”的实现作为对照策略,以支持对比研究与回测验证。 + +--- + +# 四、信号生成与回测实现 + +## 4.1 信号生成逻辑 + +核心思想:**每天选得分最高的 ETF,第二天按该信号进行持仓**。 + +文档中的实现要点:[2] + +```python +# 取出每日得分最高的证券 +data['信号'] = data[[f'得分_{v}' for v in code_list]].idxmax(axis=1) +# 今日涨幅由昨日持仓产生,为避免未来函数: +data['信号'] = data['信号'].shift(1) +data = data.dropna() + +data['轮动策略日收益率'] = data.apply( + lambda x: x[f'日收益率_{x["信号"]}'], axis=1 +) +``` + +**关键点(项目必须遵守)**: + +- **信号使用 T-1 日数据**,在 T 日交易,避免“用收盘价算信号再按收盘价交易”的未来函数[5]。 +- 第一日的交易日收益不记入策略收益(当日仅建仓)。 + +## 4.2 策略净值计算 + +策略净值曲线:以 1 为初始净值,日收益连乘:[2] + +```python +data['轮动策略净值'] = (1 + data['轮动策略日收益率']).cumprod() +``` + +同时可计算各 ETF 的净值用于可视化对比。 + +## 4.3 回测指标与性能结论 + +文档给出的改进后策略表现(2015-02-10 至 2025-08-28):[4] + +| 指标 | 基准 沪深300 | 策略(斜率×R²) | +|---------------|--------------|-----------------| +| Cumulative Return | 516.53% | 1,906.09% | +| CAGR % | 12.64% | 21.67%* | +| 
Sharpe | 0.92 | 1.33 | +| Max Drawdown | -28.57% | -30.31% | +| Time in Market| 98% | 99% | + +> 注:文中后续通过修正 Quantstats CAGR Bug,得到的**正确年化在 32%+ 区间**[6],见下一节。 + +**工程含义**: + +- 策略在不加费用、以理想成交价(收盘价)回测下,显著跑赢基准; +- 最大回撤与基准相近,但收益远超,收益风险比良好; +- 实盘需考虑滑点与交易成本,实际收益会低于回测。 + +--- + +# 五、CAGR 计算错误及修正方案 + +## 5.1 问题来源 + +文档详细分析了 quantstats 中 CAGR 计算的历史 Bug[6][7]: + +- quantstats 在 `cagr` 函数中使用: + - 年数 = (自然日天数 / periods) + - `periods` 默认值 = 252; +- 实际含义变成:**用自然日天数除以“交易日数 252”**,导致年数被拉长,年化收益被严重压低; +- 更糟糕的是,quantstats 在最终 metrics 中调用 `cagr` 时**没有传入 periods 参数**,导致用户在外部修改 `periods_per_year` 也不会影响 CAGR 结果[7]。 + +因此同一净值序列,QMT 与 quantstats 年化结果严重不一致:原版量化报告年化 21.69%,而 QMT 估算合理年化约 33%[6]。 + +## 5.2 修正方法(项目必须实现) + +文档给出两种修复方式,任选一种即可,也可两种都做[7]: + +**方式一:修改 stats.py 中 cagr 默认值** + +```python +# 原 +def cagr(returns, rf=0.0, compounded=True, periods=252): + +# 改 +def cagr(returns, rf=0.0, compounded=True, periods=365): +``` + +**方式二:修改 reports.py 中 metrics 调用** + +```python +# 原 +metrics["CAGR%%"] = _stats.cagr(df, rf, compounded) * pct + +# 改 +metrics["CAGR%%"] = _stats.cagr(df, rf, compounded, 365) * pct +``` + +修改后: + +- 累计收益保持不变(仍为约 1906.09%); +- 年化收益由 ~21.67% 修正到 ~32.86%,与 QMT 结果(33.46%)接近,差异来自自然日 vs 交易日、首尾日期含不含、净值精度等细节[8]。 + +**项目要求**: + +- 在性能评估模块中集成“CAGR 自算函数”,不要完全依赖 quantstats 黑箱; +- 保留 QMT 计算口径的对比(通过自然日和交易日两种方式各计算一份年化),并输出到报告中,以减少跨平台困惑。 + +--- + +# 六、QMT 实盘迁移技术方案 + +文档中专门有两篇文章讨论轮动策略迁移到 QMT 实盘及踩坑经历[3][5][9],项目需要按“回测逻辑 → 实盘逻辑”两层设计。 + +## 6.1 回测与实盘的关键差异 + +主要有五个维度的差异[5][9]: + +1. 数据驱动方式(历史批量 vs 实时推送); +2. 下单模式(假设全部成交 vs 部分成交、排队、撤单); +3. 委托跟踪(回测通常忽略,实盘需要轮询订单状态); +4. 仓位隔离(回测中账户只服务单一策略,实盘同一账户多策略并行); +5. 
异常处理(网络、交易所异常、盘中停牌等)。 + +项目中应为 QMT 实盘实现一个单独的“执行层模块”,逻辑大体如下: + +- 定时任务在 `STRATEGY_TRADETIME`(如 14:58)触发; +- 拉取前一交易日日线收盘价,计算信号; +- 获取本策略持仓(见后面“仓位隔离”); +- 对比目标持仓与当前持仓,生成买卖委托; +- 下单后轮询状态,`ORDER_TIMEOUT` 超时则自动撤单重下。 + +## 6.2 参数配置规范(QMT) + +文档给出了比较完整的参数说明[5]: + +- `ACCOUNT_ID`:股票账户资金账号(实盘必须为真实账号)。 +- `ACCOUNT_TYPE`:固定 `"STOCK"`(ETF 走股票通道)。 +- `ACCOUNT_MODE`: + - `"MONEY"`:固定金额,如 `ACCOUNT_MONEY = 10000`; + - `"RATIO"`:占总资产比例,如 `ACCOUNT_RATIO = 0.3`。 +- `ACCOUNT_MONEY`:策略金额,建议 ≥ 能买得起候选池中任意一手 ETF,实践上建议不低于 \$1000 对应的人民币。 +- `ACCOUNT_RATIO`:0–1 之间的小数。 +- `STRATEGY_TRADETIME`:日内交易时间点(字符串),如 `"14:58"`。 +- `ORDER_TIMEOUT`:订单超时(秒),超过时间未完全成交则撤单重下。 +- `STRATEGY_PATH`:策略文件根路径,程序会在该路径下新建 `STRATEGY_NAME` 目录,用于: + - 交易日志; + - 本策略专用持仓文件。 +- `STRATEGY_NAME`:策略名称,**实盘启用后不得随意修改**,否则会导致无法识别历史持仓文件,重新启动会把旧单子视为“外来持仓”[5]。 +- `CODE_LIST`:实盘候选池,与回测保持一致。 +- `N_DAYS`:斜率/得分窗口,默认 25。 +- `SELECT_NUM`:一次轮动选中 ETF 数量(1 表示全仓单一品种,>1 表示多品种分仓)。 + +## 6.3 仓位隔离与本地文件 + +文档强调 QMT 自身只能区分“委托/成交归属哪个策略”,无法直接区分“当前总持仓中哪一部分是哪个策略开的”[9]。如果多个策略共用一个股票账户: + +- 调用 `get_trade_detail_data` 得到的是**全账户持仓总和**; +- 若策略逻辑是 “if code not in selected_list then sell”,则会把其他策略和手工单的仓位一并卖掉。 + +**项目必须实现**: + +- 为每个策略在 `STRATEGY_PATH/STRATEGY_NAME` 下维护一份**策略专属持仓文件**(如 JSON/CSV): + - 记录:证券代码、持仓数量、建仓时间、成本价等; + - 实盘每次执行前,先读该文件,再结合实时查询来对账; + - 下单成功后更新持仓文件; +- 禁止通过“全账户持仓”直接驱动策略卖出逻辑。 + +缺少本地读写权限时,策略无法创建这些文件,文档建议直接更换完整权限的 QMT 版本[9]。 + +## 6.4 常见实盘错误与防御 + +文档中列举了多类实盘报错及原因[3][9]: + +1. **没有下载 ETF 历史数据**:回测前必须在 QMT 中先下载策略相关 ETF 历史数据,否则出现 `IndexError: single positional indexer is out-of-bounds`。 +2. **没有下载指数数据**:基准指数(默认沪深300)同样需要下载,否则: + - 回测绩效计算缺少 benchmark; + - `get_previous_nth_trading_date` 使用指数作为“完整交易日序列”的参照时会出错。 +3. **候选池增加新标的后报错**: + - 新加入的 ETF 没有历史数据; + - 或数据缺口导致 rolling 计算不满 N 天。 +4. **global 变量作用域错误**: + - 在函数中修改 `CODE_LIST` 未声明 `global CODE_LIST`,导致只改了局部变量,轮动逻辑仍基于旧池子; + - 文档明确通过 `global CODE_LIST` 解决[9]。 +5. 
**无读写权限**: + - 无法写入策略持仓或日志,建议换 QMT 安装路径或版本。 + +项目中应: + +- 在策略启动时加入一套“自检流程”: + - 检查历史数据完整性(ETF + 基准指数); + - 核验文件系统读写权限; + - 打印当前 `CODE_LIST` 和 `N_DAYS` 配置; +- 对常见异常设置明确的错误类型与日志输出。 + +--- + +# 七、项目落地建议与版本规划 + +## 7.1 最小可用版本(MVP) + +包含: + +1. AkShare 数据获取 + 本地缓存; +2. 简单 N 日涨幅轮动 + 斜率×R²轮动两个版本; +3. 回测模块(含 quantstats + 自算 CAGR 修正); +4. 策略报告(净值曲线、相对基准、关键指标表)。 + +用途:内部论证与参数调优。 + +## 7.2 QMT 实盘版本(V1) + +在 MVP 基础上增加: + +1. QMT 适配:数据获取、下单接口封装; +2. 日线定时任务(按 `STRATEGY_TRADETIME` 触发); +3. 仓位隔离实现(本地持仓文件); +4. 完整日志与错误处理(包括数据缺失、下单失败、超时重试)。 + +## 7.3 强化版本(V2+) + +未来可增加: + +- 交易成本与滑点模拟; +- 多频率(周维度)轮动; +- 按波动率调节仓位(风险平价思想); +- 多策略组合框架(资产轮动 + 行业轮动 + 风格轮动)。 + +--- + +# 八、结语(给团队的执行要点) + +1. **策略层**:落地时默认优先实现“斜率×R²趋势得分”的 ETF 轮动,并保留 N 日涨幅版用作对照。 +2. **评估层**:不能直接信任 quantstats 原始 CAGR 输出来做决策,必须按文档方式修正或自算年化。 +3. **工程层**:QMT 实盘时,仓位隔离、本地持仓文件、读写权限是“生死线”,没做好会直接导致策略乱卖持仓。 +4. **产品层**:对外宣传时,明确回测假设(收盘价成交、无费用、主观候选池),避免“十年10倍/19倍”被误解为可复制的实盘收益。 +5. **团队协作**:建议先按本方案拆分为“数据&回测”、“QMT接入”、“风控&监控”三个子任务并行推进。 + +如果你愿意,我可以在下一步帮你把上述方案拆解成更细的任务列表和代码模块设计草图(按你现在团队的技术栈来适配)。 + +--- + +**References** + +[1] 一、说在前面的话 / 轮动类型与候选池说明. 手把手教你构建与改进ETF轮动策略(十年19倍,附源码).pdf +[2] 轮动策略构建:数据获取与N日涨幅动量计算. 手把手教你构建与改进ETF轮动策略(十年19倍,附源码).pdf +[3] ETF轮动策略实盘QMT报错及数据下载问题解析. 完了,ETF轮动策略实盘当中,到底会遇到多少挫折_.pdf +[4] 改进策略2:斜率×R²趋势得分与绩效对比. 手把手教你构建与改进ETF轮动策略(十年19倍,附源码).pdf +[5] ETF轮动策略迁移到实盘后,幸好守住了十年10倍+(实盘参数与未来函数讨论). ETF轮动策略迁移到实盘后,幸好守住了十年10倍+.pdf +[6] 关于量化策略中CAGR计算公式及年数 n 确定方法讨论. 其实,上期的ETF轮动策略中,隐藏着一个错误.pdf +[7] quantstats 计算CAGR时periods默认值错误与cagr函数未传参Bug说明. 其实,上期的ETF轮动策略中,隐藏着一个错误.pdf +[8] ETF轮动策略累计收益与年化收益计算及其误差分析. 其实,上期的ETF轮动策略中,隐藏着一个错误.pdf +[9] ETF轮动策略实盘中成份股代码设置与global关键字问题、仓位隔离说明. 
完了,ETF轮动策略实盘当中,到底会遇到多少挫折_.pdf \ No newline at end of file diff --git a/chart.py b/chart.py deleted file mode 100644 index d8c1102..0000000 --- a/chart.py +++ /dev/null @@ -1,463 +0,0 @@ -import pandas as pd - -from loguru import logger -import talib as ta -import numpy as np -from lightweight_charts import Chart # lightweight-charts==2.1 - - -import random -from datetime import datetime - - -def TD(dataframe: pd.DataFrame): - close = dataframe["close"].to_list() - td = [0, 0, 0, 0] - up = 0 - down = 0 - for i in range(4, len(close)): - if close[i] > close[i - 4]: - up += 1 - down = 0 - td.append(up) - else: - down -= 1 - up = 0 - td.append(down) - return td - - -def resample_data(df: pd.DataFrame, timeframe: str) -> pd.DataFrame: - """ - 对日线数据进行重采样 - :param df: 原始日线数据(time 作为索引) - :param timeframe: 目标周期 ('1D', '1W', '1M', '3M', '1Y') - :return: 重采样后的 DataFrame - """ - # 映射周期到 pandas resample 规则 - timeframe_map = { - "1D": "D", # 日线 - "1W": "W", # 周线 - "1M": "M", # 月线 - "3M": "3M", # 季线 - "1Y": "Y", # 年线 - } - - if timeframe not in timeframe_map: - return df - df = df.set_index("time") - - rule = timeframe_map[timeframe] - - # 对 OHLCV 数据进行重采样 - resampled = ( - df.resample(rule) - .agg( - { - "open": "first", # 开盘价取第一个 - "high": "max", # 最高价取最大值 - "low": "min", # 最低价取最小值 - "close": "last", # 收盘价取最后一个 - "volume": "sum", # 成交量求和 - } - ) - .dropna() - ) - - # 重置索引,将 time 变回列 - resampled = resampled.reset_index() - return resampled - - -def POC(df, bins: int = 50): - """ - 计算价格分布的成交量峰值位置(POC,Point Of Control) - - 参数: - df: 包含 'low', 'high', 'volume' 三列的 DataFrame(每行代表一根 K 线) - bins: 将价格区间划分为多少个小区间(价格桶),默认为 50 - - 返回: - poc: 代表成交量最大的价格区间的中点价格 - """ - - # 当前数据的最低价和最高价,用于构建价格区间 - low, high = df["low"].min(), df["high"].max() - # 将整个价格区间等分为 bins 个小区间,需要 bins+1 个边界点 - price_ranges = np.linspace(low, high, bins + 1) - - # 初始化每个价格区间累积的成交量数组,长度为 bins(区间个数) - volume_per_price = np.zeros(bins) - - # 遍历每一根 K 线,将该 K 线的成交量按覆盖的价格区间平均分配 - for i in range(len(df)): - high = 
df["high"].iloc[i] - low = df["low"].iloc[i] - volume = df["volume"].iloc[i] - - # 找到当前 K 线覆盖的价格边界(注意使用闭区间判断) - # price_ranges 表示边界点,若 price_ranges[j] 在 [low, high] 范围内,则第 j 个边界被覆盖 - price_coverage = (price_ranges >= low) & (price_ranges <= high) - covered_bins = np.where(price_coverage)[0] - - if len(covered_bins) > 0: - # 如果覆盖了多个边界点,则这些边界之间形成了若干完整的区间 - # 将该 K 线的成交量平均分配到这些被覆盖的区间上 - # 注意:covered_bins 里是边界索引,最后一个边界索引对应的区间索引需小于 bins - volume_per_bin = volume / len(covered_bins) - for bin_idx in covered_bins: - if bin_idx < bins: - volume_per_price[bin_idx] += volume_per_bin - - # 找到成交量最大的区间索引 - idx = np.argmax(volume_per_price) - # 以该区间的两个边界的中点作为 POC 值 - poc = (price_ranges[idx] + price_ranges[idx + 1]) / 2 - return poc - - -def get_fixed_color_based_on_period(num: int): - # 使用周期值作为随机种子,确保相同周期生成相同颜色 - random.seed(num) - - # 生成随机的RGB值 - r = random.randint(0, 255) - g = random.randint(0, 255) - b = random.randint(0, 255) - - # 将RGB值转换为十六进制颜色代码 - color_code = "#{:02x}{:02x}{:02x}".format(r, g, b) - - # 重置随机种子,避免影响其他随机操作 - random.seed(None) - - return color_code - - -class QuantChart: - - def __init__(self): - self.horizontal_lines = {} - self.legend_data = {} - self.visible_latest_close = None - - def update_legend(self, chart, key, text): - self.legend_data[key] = text - sorted_dict = dict(sorted(self.legend_data.items())) - full_text = ", ".join(sorted_dict.values()) - chart.legend(visible=True, text=full_text) - - def add_ema(self, df, chart, period: int = 50): - name = f"EMA_{period}" - df[name] = ta.EMA(df["close"], timeperiod=period) - color = get_fixed_color_based_on_period(num=period) - line = chart.create_line( - name, color=color, width=2, price_label=False, price_line=False - ) - line.set(df[["time", name]]) - - def add_cci( - self, df, chart, period: int = 14, height: float = 0.1, position: str = "bottom" - ): - cci = ta.CCI(df["high"], df["low"], df["close"], timeperiod=period) - df[f"CCI_{period}"] = cci - cci_chart = chart.create_subchart( - 
position=position, width=1, height=height, sync=True - ) - cci_chart.layout(font_family="Times New Roman") - cci_chart.legend( - visible=True, font_size=14, color="#FFFFFF", font_family="Times New Roman" - ) - cci_chart.time_scale(visible=False) - cci_line = cci_chart.create_line(name=f"CCI_{period}", color="#FF0000", width=2) - cci_line.set(df[["time", f"CCI_{period}"]]) - df = df[["time"]].copy() - df["h"] = 100 - df["l"] = -100 - cci_line = cci_chart.create_line( - name="h", - color="#D4C21C", - width=1, - style="dashed", - price_label=False, - price_line=False, - ) - cci_line.set(df[["time", "h"]]) - cci_line = cci_chart.create_line( - name="l", - color="#D4C21C", - width=1, - style="dashed", - price_label=False, - price_line=False, - ) - cci_line.set(df[["time", "l"]]) - - def add_macd( - self, - df, - chart, - fastperiod: int = 12, - slowperiod: int = 26, - signalperiod: int = 9, - height: float = 0.1, - position: str = "bottom", - ): - macd, signal, hist = ta.MACD( - df["close"], - fastperiod=fastperiod, - slowperiod=slowperiod, - signalperiod=signalperiod, - ) - df["DIF"] = macd - df["DEA"] = signal - macd_name = f"MACD_{fastperiod}_{slowperiod}_{signalperiod}" - df[macd_name] = hist * 2 - macd_chart = chart.create_subchart( - position=position, width=1, height=height, sync=True - ) - macd_chart.layout(font_family="Times New Roman") - macd_chart.legend( - visible=True, font_size=14, color="#FFFFFF", font_family="Times New Roman" - ) - macd_chart.time_scale(visible=False) - - histogram = macd_chart.create_histogram(name=macd_name) - hist_data = df[["time", macd_name]].copy() - hist_data["prev_value"] = hist_data[macd_name].shift(1) - - # 设置颜色逻辑 - def get_histogram_color(row): - current, prev = row[macd_name], row["prev_value"] - is_hollow = (current >= 0 and current < prev) or ( - current < 0 and current > prev - ) - if current >= 0: - return "rgba(255, 0, 0, 0.5)" if is_hollow else "#ff0000" - else: - return "rgba(0, 255, 0, 0.5)" if is_hollow else 
"#00FF00" - - hist_data["color"] = hist_data.apply(get_histogram_color, axis=1) - hist_data = hist_data.drop("prev_value", axis=1) - - histogram.set(hist_data) - - macd_line = macd_chart.create_line( - name="DIF", color="#2962FF", width=2, price_label=False, price_line=False - ) - macd_line.set(df[["time", "DIF"]]) - signal_line = macd_chart.create_line( - name="DEA", color="#FF0000", width=2, price_label=False, price_line=False - ) - signal_line.set(df[["time", "DEA"]]) - - def add_TD(self, df, chart): - df["TD"] = TD(df) - td_line = chart.create_line( - name="TD", - color="rgba(0, 0, 0, 0)", # 透明色,不在图表上显示线条 - width=0, - price_line=False, - price_label=False, - price_scale_id="td_scale", - ) - td_line.precision(0) - td_line.set(df[["time", "TD"]]) - TDs = df[["time", "TD"]].to_dict(orient="records") - markers = [] - for item in TDs: - if item["TD"] in [9, 13]: - markers.append( - { - "time": item["time"].strftime("%Y-%m-%d %H:%M:%S"), - "position": "above", - "shape": "arrow_down", - "color": "#00FF00", - "text": f"{item['TD']}", - } - ) - elif item["TD"] in [-9, -13]: - markers.append( - { - "time": item["time"].strftime("%Y-%m-%d %H:%M:%S"), - "position": "below", - "shape": "arrow_up", - "color": "#FF0000", - "text": f"{item['TD']}", - } - ) - - chart.marker_list(markers) - - def add_buy_sell_signal_markers(self, df, chart): - - if "buy" not in df.columns or "sell" not in df.columns: - return - - markers = [] - signals = df[["time", "buy", "sell"]].to_dict(orient="records") - for item in signals: - if item["buy"] == 1: - markers.append( - { - "time": item["time"].strftime("%Y-%m-%d"), - "position": "below", - "shape": "arrow_up", - "color": "#00FF00", - "text": f"B", - } - ) - elif item["sell"] == 1: - markers.append( - { - "time": item["time"].strftime("%Y-%m-%d"), - "position": "above", - "shape": "arrow_down", - "color": "#FF0000", - "text": f"S", - } - ) - - chart.marker_list(markers) - - def on_range_change_poc(self, chart, bars_before, bars_after): - df = 
chart.candle_data - if df is None or df.empty: - return - - total_bars = len(df) - # 计算可见范围的起始和结束索引 - # if bars_after < 0: - # start_idx = max(0, int(bars_before)) - # end_idx = total_bars - # elif bars_before < 0: - # start_idx = 0 - # end_idx = max(0, int(total_bars - bars_after)) - # else: - # start_idx = int(bars_before) - # end_idx = max(0, int(total_bars - bars_after)) - - start_idx = max(0, int(bars_before)) - end_idx = max(0, int(total_bars - max(0, int(bars_after)))) - - df_range = df.iloc[start_idx:end_idx] - # logger.info( - # f"Calculating POC for bars {start_idx} to {end_idx}, total_bars={total_bars}, bars_before={bars_before}, bars_after={bars_after}" - # ) - # logger.info(f"df_range_len: {len(df_range)}") - poc = POC(df_range) - poc_line_name = "POC" - if poc_line_name in self.horizontal_lines: - self.horizontal_lines[poc_line_name].delete() - # 添加新的 POC 水平线 - self.visible_latest_close = df_range["close"].iloc[-1] - profit = (self.visible_latest_close - poc) / poc * 100 - poc_line = chart.horizontal_line( - price=poc, color="#FF0000", width=4, style="solid", text=f"{poc:.2f}" - ) - legend_text = f"POC: {poc:.2f}, poc_range_profit%: {profit:.1f}%, visible_tf_cnt: {len(df_range)}" - self.update_legend(chart=chart, key=poc_line_name, text=legend_text) - - self.horizontal_lines[poc_line_name] = poc_line - - def check_df(self, df: pd.DataFrame): - # basic type check - if not isinstance(df, pd.DataFrame): - raise TypeError("df must be a pandas DataFrame") - - required_cols = ["time", "open", "high", "low", "close", "volume"] - missing = [c for c in required_cols if c not in df.columns] - if missing: - raise ValueError( - f"Missing required columns: {missing}. 
Required: {required_cols}" - ) - - # time column must be datetime.datetime or pd.Timestamp (or a datetime64 dtype) - time_series = df["time"] - if not ( - pd.api.types.is_datetime64_any_dtype(time_series) - or time_series.apply( - lambda x: isinstance(x, (pd.Timestamp, datetime)) - ).all() - ): - raise TypeError( - "Column 'time' must contain datetime values (python datetime.datetime or pandas.Timestamp) " - "or be a datetime64 dtype." - ) - - # other columns must be numeric (int or float) - for col in ["open", "high", "low", "close", "volume"]: - if not pd.api.types.is_numeric_dtype(df[col]): - raise TypeError(f"Column '{col}' must be numeric (int or float)") - - def setup_crosshair_tracking(self, chart): - chart.run_script( - f""" - {chart.id}.chart.subscribeCrosshairMove((param) => {{ - if (!param.point) return; - const price = {chart.id}.series.coordinateToPrice(param.point.y); - if (price !== null) {{ - window.callbackFunction(`crosshair_price_~_${{price}}`); - }} - }}) - """ - ) - - # 注册回调处理 - def on_crosshair_price(price_str): - # 实时获取鼠标位置y轴的价格 - price = float(price_str) - if self.visible_latest_close is not None: - profit = (self.visible_latest_close - price) / price * 100 - self.update_legend( - chart=chart, - key="Crosshair Price Profit%", - text=f"crosshair_price_profit%: {profit:.2f}%", - ) - - chart.win.handlers["crosshair_price"] = lambda p: on_crosshair_price(p) - - def plot_chart( - self, - df, - symbol: str, - name: str, - timeframe: str, - init_visible_num_bars: int = 90, - ): - # 校验数据是否满足 - self.check_df(df) - - chart = Chart(toolbox=True, inner_height=0.8, maximize=True) - chart.layout(font_family="Times New Roman") - chart.topbar.textbox("symbol", symbol) - chart.topbar.textbox("name", name) - chart.topbar.textbox("timeframe", timeframe) - chart.legend(visible=True, font_size=14, color="#FFFFFF") - chart.set(df) - - # 设置刚进入chart时的可见k线数量范围 - end_time = df["time"].iloc[-1] - start_time = df["time"].iloc[-init_visible_num_bars] - 
chart.set_visible_range(start_time, end_time) - - # 设置每次放缩k线范围时的回调函数计算实时计算poc - chart.events.range_change += self.on_range_change_poc - - self.setup_crosshair_tracking(chart) - - # 添加技术指标 - # add_ema(df, chart, period=10) - # add_ema(df, chart, period=20) - self.add_ema(df, chart, period=30) - # add_ema(df, chart, period=60) - self.add_macd(df, chart) - self.add_cci(df, chart, period=14) - self.add_TD(df, chart) - self.add_buy_sell_signal_markers(df, chart) - - chart.show(block=True) - - -if __name__ == "__main__": - ... diff --git a/chart_crypto.py b/chart_crypto.py deleted file mode 100644 index 9489d63..0000000 --- a/chart_crypto.py +++ /dev/null @@ -1,13 +0,0 @@ -import pandas as pd -from chart import QuantChart - -if __name__ == "__main__": - symbol = "ETH_USDT" - timeframe = "1d" - data_path = f"/Users/aszer/Documents/vscode/cta/user_data/data/okx/{symbol}-{timeframe}.feather" - df = pd.read_feather(data_path) - - df.rename(columns={"date": "time"}, inplace=True) - print(df.head()) - quant_chart = QuantChart() - quant_chart.plot_chart(df, symbol=symbol, name=symbol, timeframe=timeframe, init_visible_num_bars=180) diff --git a/chart_index.py b/chart_index.py deleted file mode 100644 index 938a6be..0000000 --- a/chart_index.py +++ /dev/null @@ -1,45 +0,0 @@ -import pandas as pd -from loguru import logger -from chart import resample_data, QuantChart -from db_config import DatabaseManager, DatabaseConfig - - -def get_kline(code: str) -> list: - """ - 获取所有指数代码 - :return: - """ - env = "online" - env = "daily" - db_config = DatabaseConfig(env=env) - logger.info(f"数据库连接: {db_config.connection_string}") - - db_manager = DatabaseManager(db_config) - sql = f"SELECT date as time, open, high, low, close, volume FROM public.index_kline where code='{code}' order by date;" - res = db_manager.execute_query(sql) - data_list = [dict(item) for item in res] - df = pd.DataFrame(data_list) - df["time"] = pd.to_datetime(df["time"]) - num_cols = ["open", "high", "low", "close", 
"volume"] - for col in num_cols: - if col in df.columns: - df[col] = pd.to_numeric(df[col], errors="coerce").astype(float) - return df - - -if __name__ == "__main__": - symbol = "399986" - timeframe = "1D" - - df = pd.read_csv( - "/Users/aszer/Documents/vscode/etf/data/index_all_stock.csv", - encoding="utf-8-sig", - ) - name = df.loc[df["代码"] == symbol, "名称"].values[0] - - df = get_kline(code=symbol) - df = resample_data(df, timeframe) - # df['buy'] = df['time'].apply(lambda x: 1 if pd.to_datetime(x).day % 2 == 1 else 0) - # df['sell'] = df['time'].apply(lambda x: 1 if pd.to_datetime(x).day % 2 == 0 else 0) - quant_chart = QuantChart() - quant_chart.plot_chart(df=df, symbol=symbol, name=name, timeframe=timeframe) diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/config/settings.py b/config/settings.py new file mode 100644 index 0000000..3659044 --- /dev/null +++ b/config/settings.py @@ -0,0 +1,100 @@ +""" +ETF策略项目 - 通用配置 + +敏感信息通过环境变量读取,非敏感配置直接定义 +""" + +import os +from pathlib import Path + +# 加载 .env 文件 +try: + from dotenv import load_dotenv + load_dotenv() +except ImportError: + pass # python-dotenv 未安装时跳过 + +# 项目根目录 +PROJECT_ROOT = Path(__file__).parent.parent + +# 数据目录 +DATA_DIR = PROJECT_ROOT / "data" +DATA_CACHE_DIR = PROJECT_ROOT / "data_cache" + +# 确保目录存在 +DATA_CACHE_DIR.mkdir(exist_ok=True) + + +# ==================== API配置 ==================== +def get_tushare_token() -> str: + """从环境变量获取Tushare Token""" + token = os.getenv("TUSHARE_TOKEN") + if not token: + raise ValueError("请设置环境变量 TUSHARE_TOKEN") + return token + + +# ==================== 钉钉配置 ==================== +def get_dingtalk_config() -> dict: + """从环境变量获取钉钉配置""" + return { + "webhook": os.getenv("DINGTALK_WEBHOOK", ""), + "secret": os.getenv("DINGTALK_SECRET", ""), + } + + +# ==================== 数据库配置 ==================== +def get_db_config() -> dict: + """从环境变量获取数据库配置""" + return { + "host": os.getenv("DB_HOST", "192.168.0.115"), + 
"port": int(os.getenv("DB_PORT", "5432")), + "database": os.getenv("DB_NAME", "etf_db"), + "username": os.getenv("DB_USER", "admin"), + "password": os.getenv("DB_PASS", "admin"), + } + + +# ==================== 代码映射 ==================== +CODE_NAME_MAP = { + # 宽基 + "000300.SH": "沪深300", + "000905.SH": "中证500", + "000852.SH": "中证1000", + "399006.SZ": "创业板指", + "000015.SH": "上证红利", + # 金融 + "399986.SZ": "中证银行", + "399975.SZ": "证券公司", + "000934.SH": "中证金融", + # 消费 + "000932.SH": "中证消费", + "399997.SZ": "中证白酒", + # 医药 + "000933.SH": "中证医药", + "399989.SZ": "中证医疗", + # 科技 + "000935.SH": "中证信息", + "399971.SZ": "中证传媒", + # 新能源 + "399808.SZ": "中证新能源", + "399976.SZ": "新能源车", + # 周期 + "399395.SZ": "国证有色", + "399440.SZ": "中证钢铁", + "399998.SZ": "中证煤炭", + "399813.SZ": "细分化工", + "000937.SH": "中证能源", + "000938.SH": "中证材料", + # 其他 + "399967.SZ": "中证军工", + "399393.SZ": "国证地产", + "000827.SH": "中证环保", + "399995.SZ": "中证基建", + "000949.SH": "中证农业", + "399702.SZ": "中证国债指数", +} + +# 基准指数 +BENCHMARK_CODE = "000300.SH" +BENCHMARK_NAME = "沪深300指数" diff --git a/config/strategies/cci.yaml b/config/strategies/cci.yaml new file mode 100644 index 0000000..70fa9b4 --- /dev/null +++ b/config/strategies/cci.yaml @@ -0,0 +1,27 @@ +# CCI技术指标筛选配置 + +# ==================== 数据源配置 ==================== +# 数据来源: "postgresql" 或 "akshare" +data_source: "postgresql" + +# ==================== 筛选参数 ==================== +# CCI指标周期 +day_period: 14 +week_period: 14 + +# 筛选阈值(低于该值视为超卖信号) +threshold: -100 + +# 数据获取天数(用于计算CCI) +lookback_days: 100 + +# ==================== 标的池 ==================== +# 指数代码列表文件路径(CSV格式,需包含"指数代码"和"指数名称"列) +index_fund_info_file: "index_fund_info.csv" + +# ==================== 定时任务 ==================== +# 运行时间(24小时制) +schedule_time: "19:00" + +# 是否跳过周末 +skip_weekend: true diff --git a/config/strategies/rotation.yaml b/config/strategies/rotation.yaml new file mode 100644 index 0000000..2dc7286 --- /dev/null +++ b/config/strategies/rotation.yaml @@ -0,0 +1,56 @@ +# ETF轮动策略配置 + +# 
==================== 候选池配置 ==================== +# A股全行业指数代码列表(Tushare格式:XXXXXX.SH / XXXXXX.SZ) +code_list: + # 宽基指数 + - "000300.SH" # 沪深300(大盘蓝筹) + - "000905.SH" # 中证500(中盘成长) + - "000852.SH" # 中证1000(小盘) + - "399006.SZ" # 创业板指(创业板龙头) + - "000015.SH" # 上证红利(高股息价值) + # 金融 + - "399986.SZ" # 中证银行 + # 消费 + - "399997.SZ" # 中证白酒 + # 医药健康 + - "399989.SZ" # 中证医疗 + # 科技信息 + - "000935.SH" # 中证信息技术 + # 新能源 + - "399976.SZ" # 新能源汽车 + # 周期资源 + - "399395.SZ" # 国证有色金属 + - "399998.SZ" # 中证煤炭 + - "399813.SZ" # 细分化工 + - "000937.SH" # 中证能源 + # 其他行业 + - "399967.SZ" # 中证军工 + - "000949.SH" # 中证农业 + - "399702.SZ" # 中证国债指数 + +# ==================== 回测参数 ==================== +start_date: "2018-01-01" +end_date: "2025-03-17" + +# ==================== 因子参数 ==================== +# 动量/趋势窗口期(天数) +n_days: 25 +# 因子类型:'momentum'(N日涨幅)或 'slope_r2'(斜率×R²) +factor_type: "slope_r2" + +# ==================== 轮动参数 ==================== +# 每次轮动选中的ETF数量(1=全仓单一品种) +select_num: 5 + +# ==================== 调仓控制 ==================== +# 最低调仓周期(交易日):持仓至少持有 N 天后才允许换仓 +rebalance_days: 1 +# 调仓得分阈值:新组合总得分需超过当前组合 X% 才触发调仓 +rebalance_threshold: 0.0 +# 单次换仓成本(双边,含佣金+滑点) +trade_cost: 0.001 + +# ==================== 数据缓存 ==================== +# 是否使用本地缓存(True=优先从本地读取) +use_cache: true diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/core/common/__init__.py b/core/common/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/db_config.py b/core/common/db.py similarity index 61% rename from db_config.py rename to core/common/db.py index 456316f..a742673 100644 --- a/db_config.py +++ b/core/common/db.py @@ -7,49 +7,27 @@ from psycopg2.extras import RealDictCursor from sqlalchemy import create_engine import pandas as pd from loguru import logger -import os from typing import Optional - -class DatabaseConfig: - """数据库配置类""" - - def __init__(self): - self.host = "192.168.0.115" - self.port = 5432 - self.database = "etf_db" - self.username = "admin" - 
self.password = "admin" - - @property - def connection_string(self) -> str: - """获取连接字符串""" - return f"postgresql://{self.username}:{self.password}@{self.host}:{self.port}/{self.database}" - - @property - def psycopg2_params(self) -> dict: - """获取psycopg2连接参数""" - return { - "host": self.host, - "port": self.port, - "database": self.database, - "user": self.username, - "password": self.password, - } +from config.settings import get_db_config class DatabaseManager: """数据库管理类""" - def __init__(self, config: DatabaseConfig = None): - self.config = config or DatabaseConfig() + def __init__(self, config: dict = None): + self.config = config or get_db_config() self.engine = None def get_engine(self): """获取SQLAlchemy引擎""" if self.engine is None: + conn_str = ( + f"postgresql://{self.config['username']}:{self.config['password']}" + f"@{self.config['host']}:{self.config['port']}/{self.config['database']}" + ) self.engine = create_engine( - self.config.connection_string, + conn_str, pool_pre_ping=True, pool_recycle=300, echo=False, @@ -58,7 +36,13 @@ class DatabaseManager: def get_connection(self): """获取psycopg2连接""" - return psycopg2.connect(**self.config.psycopg2_params) + return psycopg2.connect( + host=self.config["host"], + port=self.config["port"], + database=self.config["database"], + user=self.config["username"], + password=self.config["password"], + ) def test_connection(self) -> bool: """测试数据库连接""" @@ -73,18 +57,17 @@ class DatabaseManager: logger.error(f"数据库连接测试失败: {e}") return False - def create_table_if_not_exists(self, table_name: str, create_sql: str) -> bool: - """创建表(如果不存在)""" + def execute_query(self, query: str, params: tuple = None) -> Optional[list]: + """执行查询并返回结果""" try: with self.get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(create_sql) - conn.commit() - logger.info(f"表 {table_name} 创建成功或已存在") - return True + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute(query, params) + result = 
cursor.fetchall() + return [dict(row) for row in result] except Exception as e: - logger.error(f"创建表 {table_name} 失败: {e}") - return False + logger.error(f"执行查询失败: {e}") + return None def insert_dataframe( self, df: pd.DataFrame, table_name: str, if_exists: str = "append" @@ -106,18 +89,6 @@ class DatabaseManager: logger.error(f"插入数据到表 {table_name} 失败: {e}") return False - def execute_query(self, query: str, params: tuple = None) -> Optional[list]: - """执行查询并返回结果""" - try: - with self.get_connection() as conn: - with conn.cursor(cursor_factory=RealDictCursor) as cursor: - cursor.execute(query, params) - result = cursor.fetchall() - return result - except Exception as e: - logger.error(f"执行查询失败: {e}") - return None - def close(self): """关闭连接""" if self.engine: diff --git a/core/common/notify.py b/core/common/notify.py new file mode 100644 index 0000000..17ef112 --- /dev/null +++ b/core/common/notify.py @@ -0,0 +1,210 @@ +""" +通知模块 - 支持钉钉、日志等多种通知方式 +""" + +import requests +import time +import hmac +import hashlib +import base64 +import urllib.parse +from loguru import logger +from typing import Optional + +from config.settings import get_dingtalk_config + + +class DingTalkBot: + """钉钉机器人类""" + + def __init__(self, webhook: str = None, secret: str = None): + """ + 初始化钉钉机器人 + + Args: + webhook: 钉钉自定义机器人webhook地址 + secret: 加签密钥(可选) + """ + config = get_dingtalk_config() + self.webhook = webhook or config.get("webhook", "") + self.secret = secret or config.get("secret", "") + + if not self.webhook: + logger.warning("钉钉webhook未配置,消息将不会被发送") + + def _gen_signed_url(self) -> str: + """生成带签名的URL""" + if not self.secret: + return self.webhook + + timestamp = str(round(time.time() * 1000)) + secret_enc = self.secret.encode("utf-8") + string_to_sign = f"{timestamp}\n{self.secret}" + string_to_sign_enc = string_to_sign.encode("utf-8") + hmac_code = hmac.new( + secret_enc, string_to_sign_enc, digestmod=hashlib.sha256 + ).digest() + sign = 
urllib.parse.quote_plus(base64.b64encode(hmac_code)) + return f"{self.webhook}&timestamp={timestamp}&sign={sign}" + + def send_text( + self, content: str, at_mobiles: list = None, is_at_all: bool = False + ) -> bool: + """ + 发送文本消息 + + Args: + content: 消息内容 + at_mobiles: 需要@的手机号列表 + is_at_all: 是否@所有人 + + Returns: + bool: 是否发送成功 + """ + if not self.webhook: + logger.warning(f"[钉钉消息未发送] {content[:100]}...") + return False + + at_mobiles = at_mobiles or [] + data = { + "msgtype": "text", + "text": {"content": content}, + "at": {"atMobiles": at_mobiles, "isAtAll": is_at_all}, + } + + url = self._gen_signed_url() + + try: + response = requests.post(url, json=data, timeout=5) + response.raise_for_status() + result = response.json() + if result.get("errcode", -1) != 0: + logger.error(f"钉钉消息发送失败: {result}") + return False + logger.info("钉钉消息发送成功") + return True + except Exception as e: + logger.error(f"钉钉消息发送异常: {e}") + return False + + def send_markdown( + self, + title: str, + text: str, + at_mobiles: list = None, + is_at_all: bool = False, + ) -> bool: + """ + 发送markdown消息 + + Args: + title: 消息标题 + text: markdown格式的消息内容 + at_mobiles: 需要@的手机号列表 + is_at_all: 是否@所有人 + + Returns: + bool: 是否发送成功 + """ + if not self.webhook: + logger.warning(f"[钉钉Markdown未发送] {title}") + return False + + at_mobiles = at_mobiles or [] + data = { + "msgtype": "markdown", + "markdown": {"title": title, "text": text}, + "at": {"atMobiles": at_mobiles, "isAtAll": is_at_all}, + } + + url = self._gen_signed_url() + + try: + response = requests.post(url, json=data, timeout=5) + response.raise_for_status() + result = response.json() + if result.get("errcode", -1) != 0: + logger.error(f"钉钉markdown消息发送失败: {result}") + return False + logger.info("钉钉markdown消息发送成功") + return True + except Exception as e: + logger.error(f"钉钉markdown消息发送异常: {e}") + return False + + +class NotificationManager: + """通知管理器 - 统一管理多种通知渠道""" + + def __init__(self): + self.dingtalk = DingTalkBot() + + def notify(self, message: str, 
title: str = "系统通知", use_markdown: bool = False): + """ + 发送通知(优先使用钉钉,失败则记录日志) + + Args: + message: 消息内容 + title: 消息标题(markdown模式使用) + use_markdown: 是否使用markdown格式 + """ + if use_markdown: + success = self.dingtalk.send_markdown(title, message) + else: + success = self.dingtalk.send_text(message) + + if not success: + # 钉钉发送失败,记录到日志 + logger.info(f"[通知] {title}: {message}") + + def notify_error(self, error_msg: str): + """发送错误通知""" + markdown = f"""## 错误告警 + +**时间**: {time.strftime('%Y-%m-%d %H:%M:%S')} + +**错误信息**: +``` +{error_msg} +``` +""" + self.notify(markdown, title="系统错误", use_markdown=True) + + def notify_signal(self, signals: list, signal_type: str = "CCI超卖"): + """ + 发送交易信号通知 + + Args: + signals: 信号列表,每项为dict包含code, name等指标 + signal_type: 信号类型名称 + """ + if not signals: + logger.info(f"[{signal_type}] 无信号") + return + + # 构建markdown表格 + if signals: + headers = signals[0].keys() + header_line = " | ".join(headers) + separator = " | ".join(["---"] * len(headers)) + + rows = [] + for s in signals: + row = " | ".join(str(v) for v in s.values()) + rows.append(row) + + table = f"{header_line}\n{separator}\n" + "\n".join(rows) + else: + table = "无" + + markdown = f"""## {signal_type}信号 + +**时间**: {time.strftime('%Y-%m-%d %H:%M:%S')} + +**筛选结果**: + +{table} + +共 {len(signals)} 个标的符合筛选条件。 +""" + self.notify(markdown, title=f"{signal_type}信号", use_markdown=True) diff --git a/core/common/utils.py b/core/common/utils.py new file mode 100644 index 0000000..b742927 --- /dev/null +++ b/core/common/utils.py @@ -0,0 +1,190 @@ +""" +通用工具函数 +""" + +import pandas as pd +import numpy as np +from datetime import datetime, timedelta +from typing import Optional + + +def format_date(date_str: str, output_format: str = "%Y-%m-%d") -> str: + """ + 统一日期格式 + + Args: + date_str: 输入日期字符串(支持 YYYY-MM-DD 或 YYYYMMDD) + output_format: 输出格式 + + Returns: + str: 格式化后的日期字符串 + """ + # 尝试解析多种格式 + for fmt in ["%Y-%m-%d", "%Y%m%d", "%Y/%m/%d"]: + try: + dt = datetime.strptime(date_str, fmt) + 
return dt.strftime(output_format) + except ValueError: + continue + raise ValueError(f"无法解析日期格式: {date_str}") + + +def get_date_range( + start_date: Optional[str] = None, + end_date: Optional[str] = None, + lookback_days: int = 365, +) -> tuple[str, str]: + """ + 获取日期范围 + + Args: + start_date: 开始日期,None则根据lookback_days计算 + end_date: 结束日期,None则使用今天 + lookback_days: 回溯天数 + + Returns: + tuple: (start_date, end_date) 格式为 YYYY-MM-DD + """ + if end_date is None: + end = datetime.now() + else: + end = datetime.strptime(format_date(end_date), "%Y-%m-%d") + + if start_date is None: + start = end - timedelta(days=lookback_days) + else: + start = datetime.strptime(format_date(start_date), "%Y-%m-%d") + + return start.strftime("%Y-%m-%d"), end.strftime("%Y-%m-%d") + + +def calculate_cagr( + nav_series: pd.Series, + method: str = "natural_days", +) -> float: + """ + 计算年化收益率(CAGR) + + Args: + nav_series: 净值序列(index=日期) + method: 'natural_days' 或 'trading_days' + + Returns: + float: CAGR值 + """ + total_return = nav_series.iloc[-1] / nav_series.iloc[0] + + if method == "natural_days": + days = (nav_series.index[-1] - nav_series.index[0]).days + years = days / 365.0 + elif method == "trading_days": + years = len(nav_series) / 252.0 + else: + raise ValueError(f"不支持的CAGR计算方式: {method}") + + if years <= 0: + return 0.0 + + return total_return ** (1 / years) - 1 + + +def calculate_max_drawdown(nav_series: pd.Series) -> tuple[float, datetime, datetime]: + """ + 计算最大回撤 + + Returns: + tuple: (最大回撤比例, 回撤起始日, 回撤结束日) + """ + cummax = nav_series.cummax() + drawdown = (nav_series - cummax) / cummax + + max_dd = drawdown.min() + end_idx = drawdown.idxmin() + start_idx = nav_series[:end_idx].idxmax() + + return max_dd, start_idx, end_idx + + +def calculate_sharpe( + returns: pd.Series, + rf: float = 0.0, + periods: int = 252, +) -> float: + """ + 计算年化夏普比率 + + Args: + returns: 日收益率序列 + rf: 无风险利率(年化) + periods: 年化系数 + + Returns: + float: 夏普比率 + """ + excess_returns = returns - rf / periods + if 
excess_returns.std() == 0: + return 0.0 + return excess_returns.mean() / excess_returns.std() * np.sqrt(periods) + + +def resample_data( + df: pd.DataFrame, + timeframe: str, + time_col: str = "time", +) -> pd.DataFrame: + """ + 对数据进行重采样 + + Args: + df: 原始数据 + timeframe: 目标周期 ('1D', '1W', '1M', '1Y') + time_col: 时间列名 + + Returns: + DataFrame: 重采样后的数据 + """ + timeframe_map = { + "1D": "D", + "1W": "W", + "1M": "M", + "3M": "3M", + "1Y": "Y", + } + + if timeframe not in timeframe_map: + return df + + df = df.copy() + if time_col in df.columns: + df[time_col] = pd.to_datetime(df[time_col]) + df.set_index(time_col, inplace=True) + + rule = timeframe_map[timeframe] + + resampled = ( + df.resample(rule) + .agg( + { + "open": "first", + "high": "max", + "low": "min", + "close": "last", + "volume": "sum", + } + ) + .dropna() + ) + + return resampled.reset_index() + + +def safe_divide(a: float, b: float, default: float = 0.0) -> float: + """安全除法,避免除以0""" + return a / b if b != 0 else default + + +def truncate_string(s: str, max_length: int = 50, suffix: str = "...") -> str: + """截断字符串""" + if len(s) <= max_length: + return s + return s[: max_length - len(suffix)] + suffix diff --git a/core/factors/__init__.py b/core/factors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/core/factors/momentum.py b/core/factors/momentum.py new file mode 100644 index 0000000..9d86893 --- /dev/null +++ b/core/factors/momentum.py @@ -0,0 +1,137 @@ +""" +动量因子计算模块 + +支持两种动量因子: +1. N日涨幅(简单动量) +2. 
斜率×R²趋势得分(改进版) +""" + +import numpy as np +import pandas as pd +from sklearn.linear_model import LinearRegression + + +def calculate_momentum(price_series: pd.Series, n: int) -> pd.Series: + """ + 计算 N 日涨幅(简单动量) + + Args: + price_series: 价格序列 + n: 动量窗口天数 + + Returns: + Series: N日涨幅 + """ + return price_series / price_series.shift(n + 1) - 1.0 + + +def _slope_r2_score(srs: pd.Series, n: int = 25) -> float: + """ + 单次计算斜率×R²趋势得分 + + Args: + srs: 价格窗口序列(长度为 n) + n: 窗口长度 + + Returns: + float: 斜率 × R² × 10000 + """ + if srs.shape[0] < n: + return np.nan + + x = np.arange(1, n + 1).reshape(-1, 1) + y = srs.values / srs.values[0] # 归一化 + + lr = LinearRegression().fit(x, y) + slope = lr.coef_[0] + r_squared = lr.score(x, y) + score = 10000 * slope * r_squared + + return score + + +def calculate_slope_r2(price_series: pd.Series, n: int = 25) -> pd.Series: + """ + 计算斜率×R²趋势得分序列 + + Args: + price_series: 价格序列 + n: 滚动窗口天数 + + Returns: + Series: 趋势得分序列 + """ + return price_series.rolling(n).apply( + lambda x: _slope_r2_score(x, n), raw=False + ) + + +def calculate_daily_return(price_series: pd.Series) -> pd.Series: + """ + 计算日收益率 + + Args: + price_series: 价格序列 + + Returns: + Series: 日收益率 + """ + return price_series / price_series.shift(1) - 1 + + +def compute_factors( + etf_data: pd.DataFrame, + code_list: list, + n: int = 25, + factor_type: str = "slope_r2", +) -> tuple[pd.DataFrame, list]: + """ + 计算所有指数的因子和日收益率 + + Args: + etf_data: DataFrame, 宽表格式的收盘价 + code_list: 指数代码列表 + n: 动量/趋势窗口 + factor_type: 'momentum' 或 'slope_r2' + + Returns: + tuple: (result_df, valid_codes) + """ + result = etf_data.copy() + + # 过滤掉缺失值过多的指数 + total_rows = len(result) + valid_codes = [] + for code in code_list: + if code not in result.columns: + print(f" ⚠ 跳过 {code}: 不在数据中") + continue + null_pct = result[code].isnull().sum() / total_rows + if null_pct > 0.2: + print(f" ⚠ 剔除 {code}: 缺失率 {null_pct:.1%} 过高") + result = result.drop(columns=[code]) + else: + valid_codes.append(code) + + # 对有效指数计算因子 + 
for code in valid_codes: + result[f"日收益率_{code}"] = calculate_daily_return(result[code]) + + if factor_type == "momentum": + result[f"得分_{code}"] = calculate_momentum(result[code], n) + elif factor_type == "slope_r2": + result[f"得分_{code}"] = calculate_slope_r2(result[code], n) + else: + raise ValueError(f"不支持的因子类型: {factor_type}") + + # 按得分列做 dropna + score_cols = [f"得分_{code}" for code in valid_codes] + result = result.dropna(subset=score_cols) + + print("\n因子计算完成:") + print(f" 因子类型: {factor_type}") + print(f" 窗口天数: {n}") + print(f" 有效指数: {len(valid_codes)}/{len(code_list)}") + print(f" 有效数据: {len(result)} 行") + + return result, valid_codes diff --git a/core/factors/technical.py b/core/factors/technical.py new file mode 100644 index 0000000..fdc48c1 --- /dev/null +++ b/core/factors/technical.py @@ -0,0 +1,207 @@ +""" +技术指标计算模块 + +包含CCI、EMA、MACD等常用技术指标 +""" + +import pandas as pd +import numpy as np +import talib as ta + + +def calculate_cci( + df: pd.DataFrame, + period: int = 14, + high_col: str = "high", + low_col: str = "low", + close_col: str = "close", +) -> pd.Series: + """ + 计算CCI指标(商品通道指数) + + Args: + df: DataFrame with OHLC data + period: CCI周期 + high_col: 最高价列名 + low_col: 最低价列名 + close_col: 收盘价列名 + + Returns: + Series: CCI值 + """ + return ta.CCI( + high=df[high_col], + low=df[low_col], + close=df[close_col], + timeperiod=period, + ) + + +def calculate_ema( + price_series: pd.Series, + period: int = 20, +) -> pd.Series: + """ + 计算指数移动平均线 + + Args: + price_series: 价格序列 + period: EMA周期 + + Returns: + Series: EMA值 + """ + return ta.EMA(price_series, timeperiod=period) + + +def calculate_macd( + price_series: pd.Series, + fastperiod: int = 12, + slowperiod: int = 26, + signalperiod: int = 9, +) -> tuple[pd.Series, pd.Series, pd.Series]: + """ + 计算MACD指标 + + Args: + price_series: 价格序列 + fastperiod: 快线周期 + slowperiod: 慢线周期 + signalperiod: 信号线周期 + + Returns: + tuple: (macd, signal, hist) + """ + macd, signal, hist = ta.MACD( + price_series, + 
fastperiod=fastperiod, + slowperiod=slowperiod, + signalperiod=signalperiod, + ) + return macd, signal, hist + + +def calculate_td_sequence(close_series: pd.Series) -> pd.Series: + """ + 计算TD序列(Tom DeMark Sequential) + + Args: + close_series: 收盘价序列 + + Returns: + Series: TD序列值(正数为上涨计数,负数为下跌计数) + """ + close = close_series.to_list() + td = [0, 0, 0, 0] + up = 0 + down = 0 + + for i in range(4, len(close)): + if close[i] > close[i - 4]: + up += 1 + down = 0 + td.append(up) + else: + down -= 1 + up = 0 + td.append(down) + + return pd.Series(td, index=close_series.index) + + +def resample_to_weekly(df: pd.DataFrame) -> pd.DataFrame: + """ + 将日线数据重采样为周线数据 + + Args: + df: DataFrame with columns: date, open, high, low, close, volume + + Returns: + DataFrame: 周线数据 + """ + df = df.copy() + if "date" in df.columns: + df["date"] = pd.to_datetime(df["date"]) + df.set_index("date", inplace=True) + + weekly = pd.DataFrame( + { + "code": df["code"].resample("W").first() if "code" in df.columns else None, + "open": df["open"].resample("W").first(), + "high": df["high"].resample("W").max(), + "low": df["low"].resample("W").min(), + "close": df["close"].resample("W").last(), + "volume": df["volume"].resample("W").sum(), + } + ) + + return weekly.dropna() + + +class TechnicalScreener: + """技术指标筛选器基类""" + + def __init__(self, name: str): + self.name = name + + def screen(self, df: pd.DataFrame) -> bool: + """ + 判断数据是否符合筛选条件 + + Args: + df: DataFrame with OHLCV data + + Returns: + bool: 是否符合条件 + """ + raise NotImplementedError + + +class CCIScreener(TechnicalScreener): + """CCI超卖筛选器""" + + def __init__( + self, + day_period: int = 14, + week_period: int = 14, + threshold: float = -100, + use_weekly: bool = True, + ): + super().__init__("CCI超卖筛选") + self.day_period = day_period + self.week_period = week_period + self.threshold = threshold + self.use_weekly = use_weekly + + def screen(self, df: pd.DataFrame) -> dict: + """ + 筛选CCI超卖信号 + + Returns: + dict: { + 'triggered': bool, # 是否触发信号 
+ 'day_cci': float, # 日线CCI值 + 'week_cci': float, # 周线CCI值(如启用) + } + """ + # 计算日线CCI + day_cci = calculate_cci(df, period=self.day_period).iloc[-1] + + result = { + "triggered": day_cci < self.threshold, + "day_cci": day_cci, + "week_cci": None, + } + + # 计算周线CCI(如果启用) + if self.use_weekly: + weekly_df = resample_to_weekly(df) + if len(weekly_df) >= self.week_period: + week_cci = calculate_cci(weekly_df, period=self.week_period).iloc[-1] + result["week_cci"] = week_cci + # 日线或周线任一超卖即触发 + result["triggered"] = ( + day_cci < self.threshold or week_cci < self.threshold + ) + + return result diff --git a/dingtalk.py b/dingtalk.py deleted file mode 100644 index 06521d1..0000000 --- a/dingtalk.py +++ /dev/null @@ -1,118 +0,0 @@ -import requests -import time -import hmac -import hashlib -import base64 -import urllib.parse -from loguru import logger - - -class DingTalkBot: - """ - 钉钉机器人类,通过webhook和可选的加签token向群聊发送消息提醒 - """ - - def __init__(self, webhook: str, secret: str = None): - """ - :param webhook: 钉钉自定义机器人webhook地址 - :param secret: 加签密钥(可选) - """ - self.webhook = webhook - self.secret = secret - - def _gen_signed_url(self): - """ - 如果有加签token,根据钉钉接口算法拼接签名到url - """ - if not self.secret: - return self.webhook - timestamp = str(round(time.time() * 1000)) - secret_enc = self.secret.encode("utf-8") - string_to_sign = f"{timestamp}\n{self.secret}" - string_to_sign_enc = string_to_sign.encode("utf-8") - hmac_code = hmac.new( - secret_enc, string_to_sign_enc, digestmod=hashlib.sha256 - ).digest() - sign = urllib.parse.quote_plus(base64.b64encode(hmac_code)) - url = f"{self.webhook}&timestamp={timestamp}&sign={sign}" - return url - - def send_text(self, content: str, at_mobiles=None, is_at_all=False): - """ - 发送文本消息 - - :param content: 消息内容 - :param at_mobiles: 需要@的手机号组成的列表,可选 - :param is_at_all: 是否@所有人,默认False - """ - at_mobiles = at_mobiles or [] - data = { - "msgtype": "text", - "text": {"content": content}, - "at": {"atMobiles": at_mobiles, "isAtAll": is_at_all}, - } - - url 
= self._gen_signed_url() if self.secret else self.webhook - - try: - response = requests.post(url, json=data, timeout=5) - response.raise_for_status() - result = response.json() - if result.get("errcode", -1) != 0: - logger.error(f"钉钉消息发送失败: {result}") - else: - logger.info("钉钉消息发送成功") - except Exception as e: - logger.error(f"钉钉消息发送异常: {e}") - - def send_markdown(self, title: str, text: str, at_mobiles=None, is_at_all=False): - """ - 发送markdown消息 - - :param title: 消息标题 - :param text: markdown格式的消息内容 - :param at_mobiles: 需要@的手机号组成的列表,可选 - :param is_at_all: 是否@所有人,默认False - """ - at_mobiles = at_mobiles or [] - data = { - "msgtype": "markdown", - "markdown": {"title": title, "text": text}, - "at": {"atMobiles": at_mobiles, "isAtAll": is_at_all}, - } - - url = self._gen_signed_url() if self.secret else self.webhook - - try: - response = requests.post(url, json=data, timeout=5) - response.raise_for_status() - result = response.json() - if result.get("errcode", -1) != 0: - logger.error(f"钉钉markdown消息发送失败: {result}") - else: - logger.info("钉钉markdown消息发送成功") - except Exception as e: - logger.error(f"钉钉markdown消息发送异常: {e}") - - -if __name__ == "__main__": - - webhook = "https://oapi.dingtalk.com/robot/send?access_token=fb70c1561d8beba94b4f11568f4bb15e3ae07ccbdc8ac19676434a9d1cd17546" # 填写你的webhook - secret = "SEC1ae7cd2f1a6f9da3611af37da3e7d954c1e8533fc073c6c8cc5e5af3b6e5926b" # 填写你的加签token(如果有),否则留空 - - # CTA 群机器人 - # webhook = "https://oapi.dingtalk.com/robot/send?access_token=87c7abfcdd69b699c32da4e4f5981cd2ca6b0445474fc6ffb36f2ed0f6262fbb" - # secret = "SECf3d6b43f2f8a87ab91feffd052e71ec314fbf57a1842e483fe07af3c0a0e5aa6" - dingtalk = DingTalkBot(webhook, secret) - dingtalk.send_text("测试消息") - - # 测试markdown消息 - markdown_content = """ -## 系统通知 -- **状态**: 正常运行 -- **CPU使用率**: 65% -- **内存使用率**: 78% - -> 详细信息请查看监控面板 - """ - dingtalk.send_markdown("系统状态报告", markdown_content) diff --git a/index_downloader.py b/index_downloader.py deleted file mode 100644 index 
c77160b..0000000 --- a/index_downloader.py +++ /dev/null @@ -1,90 +0,0 @@ -import os -import re -import akshare as ak -import pandas as pd -from loguru import logger - - -# index_hist_df = ak.index_zh_a_hist( -# symbol="000001", # 指数代码,如上证指数 -# period="daily", # K线周期: daily(日K) -# start_date="19700101", # 开始日期 -# end_date="22220101", # 结束日期 -# ) - - -def get_all_stock_index(): - index_choice = ["沪深重要指数", "上证系列指数", "深证系列指数", "中证系列指数"] - index_df_list = [] - for source in index_choice: - logger.info(f"正在获取 {source}...") - index_df = ak.stock_zh_index_spot_em(symbol=source) - index_df["symbol"] = source - index_df_list.append(index_df) - logger.info(f"{source}: {index_df.shape[0]}") - df = pd.concat(index_df_list) - return df - - -def get_index_fund_info(): - # 读取指数数据和基金数据 - index_df = pd.read_csv( - "/Users/aszer/Documents/vscode/etf/data/index_all_stock.csv", - encoding="utf-8-sig", - ) - fund_df = pd.read_csv( - "/Users/aszer/Documents/vscode/etf/data/fund_info.csv", encoding="utf-8-sig" - ) - - # 构建指数名称集合去重,加快后续匹配 - index_name_set = set(index_df["名称"].astype(str).unique()) - - # 对每个基金名称,查找其是否包含某一指数名称,允许多对多匹配 - records = [] - for fund_idx, fund_row in fund_df.iterrows(): - fund_name = str(fund_row["基金名称"]) - fund_code = str(fund_row["基金代码"]) - funf_fee = fund_row["手续费"] - matched_index_list = [ - idx_name - for idx_name in index_name_set - if idx_name.lower() in fund_name.lower() - ] - for idx_name in matched_index_list: - # 找到指数的相关代码 - index_row = index_df[index_df["名称"] == idx_name] - index_code = None - if not index_row.empty: - index_code = str(index_row.iloc[0]["代码"]) - records.append( - { - "指数代码": index_code, - "指数名称": idx_name, - "基金代码": fund_code, - "基金名称": fund_name, - "手续费": funf_fee, - } - ) - index_fund_df = pd.DataFrame(records) - index_fund_df = index_fund_df.sort_values("指数代码").reset_index(drop=True) - return index_fund_df - - -if __name__ == "__main__": - df = get_all_stock_index() - # df.to_csv("index_all_stock.csv", index=False, 
encoding="utf-8-sig") - - # res = ak.fund_etf_spot_em() - # print(res) - - # df = get_index_fund_info() - # df.to_csv("index_fund_info.csv", index=False, encoding="utf-8-sig") - - # import akshare as ak - # import pandas as pd - - # for symbol in ["NVDA", "AAPL", "MSFT", "AMZN", "TSLA", "META", "GOOGL"]: - # stock_us_daily_df = ak.stock_us_daily(symbol=symbol, adjust="qfq") - # stock_us_daily_df.to_csv( - # f"{symbol}_stock_us_daily.csv", index=False, encoding="utf-8-sig" - # ) diff --git a/scripts/run_cci_screener.py b/scripts/run_cci_screener.py new file mode 100755 index 0000000..4922884 --- /dev/null +++ b/scripts/run_cci_screener.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 +""" +CCI技术指标筛选器入口 + +用法: + python scripts/run_cci_screener.py + python scripts/run_cci_screener.py --config config/strategies/cci.yaml + python scripts/run_cci_screener.py --schedule # 定时模式 +""" + +import sys +import time +import yaml +import argparse +import schedule +from pathlib import Path +from datetime import datetime + +# 添加项目根目录到路径 +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +from strategies.screener.cci import CCIScreener + + +def load_config(config_path: str) -> dict: + """加载配置文件""" + with open(config_path, "r", encoding="utf-8") as f: + return yaml.safe_load(f) + + +def run_screening(config: dict): + """执行一次筛选""" + print(f"\n[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] 开始CCI筛选...") + + screener = CCIScreener(config) + signals = screener.run_screening() + + return signals + + +def main(): + parser = argparse.ArgumentParser(description="CCI技术指标筛选") + parser.add_argument( + "--config", + type=str, + default="config/strategies/cci.yaml", + help="配置文件路径", + ) + parser.add_argument( + "--schedule", + action="store_true", + help="启用定时模式", + ) + args = parser.parse_args() + + # 加载配置 + config = load_config(args.config) + print("=" * 60) + print(" CCI技术指标筛选器") + print("=" * 60) + print(f"\n配置文件: {args.config}") + print(f"日线周期: 
{config.get('day_period', 14)}") + print(f"周线周期: {config.get('week_period', 14)}") + print(f"筛选阈值: {config.get('threshold', -100)}") + print(f"数据源: {config.get('data_source', 'postgresql')}") + + if args.schedule: + # 定时模式 + schedule_time = config.get("schedule_time", "19:00") + print(f"\n定时模式已启用,每天 {schedule_time} 执行") + print("按 Ctrl+C 停止\n") + + schedule.every().day.at(schedule_time).do(run_screening, config) + + while True: + schedule.run_pending() + time.sleep(1) + else: + # 单次执行 + run_screening(config) + + +if __name__ == "__main__": + main() diff --git a/scripts/run_rotation.py b/scripts/run_rotation.py new file mode 100755 index 0000000..c502e3d --- /dev/null +++ b/scripts/run_rotation.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +""" +ETF轮动策略回测入口 + +用法: + python scripts/run_rotation.py + python scripts/run_rotation.py --config config/strategies/rotation.yaml +""" + +import sys +import time +import yaml +import argparse +from pathlib import Path + +# 添加项目根目录到路径 +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +from strategies.rotation.engine import RotationStrategy +from strategies.rotation.portfolio import track_positions, save_trades +from strategies.rotation.report import generate_performance_report +from config.settings import CODE_NAME_MAP, BENCHMARK_NAME + + +def load_config(config_path: str) -> dict: + """加载配置文件""" + with open(config_path, "r", encoding="utf-8") as f: + return yaml.safe_load(f) + + +def main(): + parser = argparse.ArgumentParser(description="ETF轮动策略回测") + parser.add_argument( + "--config", + type=str, + default="config/strategies/rotation.yaml", + help="配置文件路径", + ) + parser.add_argument( + "--save-path", + type=str, + default="report", + help="报告保存路径前缀", + ) + args = parser.parse_args() + + start_time = time.time() + + print("=" * 60) + print(" ETF轮动策略 回测系统") + print("=" * 60) + + # 加载配置 + config = load_config(args.config) + print(f"\n配置文件: {args.config}") + print(f"候选标的: 
{len(config['code_list'])} 只") + print(f"回测区间: {config['start_date']} ~ {config['end_date']}") + print(f"因子类型: {config['factor_type']}") + print(f"窗口天数: {config['n_days']}") + print(f"选中数量: {config['select_num']}") + print(f"调仓周期: {config['rebalance_days']} 天") + print(f"交易成本: {config['trade_cost']:.2%}") + + # 创建策略实例 + strategy = RotationStrategy(config) + + # 运行回测 + print("\n" + "=" * 60) + print("开始回测...") + print("=" * 60) + + backtest_result = strategy.run() + + # 持仓跟踪 + print("\n" + "=" * 60) + print("持仓跟踪...") + print("=" * 60) + + trades_df, summary_df = track_positions( + backtest_result, + code_name_map=CODE_NAME_MAP, + select_num=config["select_num"], + ) + save_trades(trades_df, summary_df, save_path=args.save_path) + + # 生成绩效报告 + print("\n" + "=" * 60) + print("生成绩效报告...") + print("=" * 60) + + metrics = generate_performance_report( + backtest_result, + strategy.valid_codes, + code_name_map=CODE_NAME_MAP, + benchmark_name=BENCHMARK_NAME, + save_path=args.save_path, + select_num=config["select_num"], + ) + + elapsed = time.time() - start_time + print(f"\n总耗时: {elapsed:.1f}秒") + + return metrics + + +if __name__ == "__main__": + main() diff --git a/signal_calc.py b/signal_calc.py deleted file mode 100644 index ef2b449..0000000 --- a/signal_calc.py +++ /dev/null @@ -1,140 +0,0 @@ -import pandas as pd -from db_config import DatabaseManager, DatabaseConfig -from loguru import logger -from datetime import datetime -import akshare as ak -from index_downloader import get_all_stock_index -import schedule -import time -import traceback -from dingtalk import DingTalkBot -import talib as ta -from tabulate import tabulate - - -db_config = DatabaseConfig() -logger.info(f"数据库连接: {db_config.connection_string}") - -# 如果只是测试连接 -db_manager = DatabaseManager(db_config) - - -def get_all_index_code() -> list: - """ - 获取所有指数代码 - :return: - """ - sql = "SELECT distinct code FROM public.index_kline;" - res = db_manager.execute_query(sql) - code_list = [dict(item)["code"] for 
item in res] - return code_list - - -def get_index_recent_date(code: str, limit: int = None) -> pd.DataFrame: - """ - 获取指数的近期数据 - :param code: - :return: - """ - limit_clause = f" LIMIT {limit}" if limit else "" - sql = f"SELECT date, code, open, high, low, close, volume FROM public.index_kline WHERE code = '{code}' order by date desc {limit_clause};" - raw_data_list = db_manager.execute_query(sql) - data_list = [dict(item) for item in raw_data_list] - - for i, data in enumerate(data_list): - data_list[i]["date"] = data["date"].strftime("%Y-%m-%d") - data_list[i]["volume"] = int(data["volume"]) - data_list[i]["close"] = float(data["close"]) - data_list[i]["open"] = float(data["open"]) - data_list[i]["high"] = float(data["high"]) - data_list[i]["low"] = float(data["low"]) - df = pd.DataFrame(data_list) - return df - - -def main_calc_process(): - if datetime.today().weekday() >= 5: - logger.info(f"非交易日") - return - webhook = "https://oapi.dingtalk.com/robot/send?access_token=fb70c1561d8beba94b4f11568f4bb15e3ae07ccbdc8ac19676434a9d1cd17546" # 填写你的webhook - secret = "SEC1ae7cd2f1a6f9da3611af37da3e7d954c1e8533fc073c6c8cc5e5af3b6e5926b" # 填写你的加签token(如果有),否则留空 - dingtalk = DingTalkBot(webhook, secret) - index_fund_df = pd.read_csv("index_fund_info.csv", encoding="utf-8-sig") - code_df = index_fund_df.drop_duplicates(subset=["指数代码"]) - code_list = code_df.to_dict(orient="records") - signal_list = [] - for i, code_info in enumerate(code_list): - code = code_info["指数代码"] - df = get_index_recent_date(code, 100) - if len(df) < 100: - continue - # 将 'date' 列转换为 datetime 类型,并设置为索引 - df["date"] = pd.to_datetime(df["date"]) - # 判断最新日期是否为今天,如果不是则跳过 - today_str = datetime.now().strftime("%Y-%m-%d") - if df["date"].max().strftime("%Y-%m-%d") != today_str: - continue - df = df.sort_values("date") - df.set_index("date", inplace=True) - - # 按周重采样(以每周最后一天为sample),open为第一个、close为最后一个、high/low为最大/最小、volume为总和、code取第一个即可 - df_weekly = pd.DataFrame( - { - "code": 
df["code"].resample("W").first(), - "open": df["open"].resample("W").first(), - "high": df["high"].resample("W").max(), - "low": df["low"].resample("W").min(), - "close": df["close"].resample("W").last(), - "volume": df["volume"].resample("W").sum(), - } - ) - - # 计算CCI指标(以典型的20周期为例,如果有更具体周期可以调整) - df_weekly["cci"] = ta.CCI( - high=df_weekly["high"], - low=df_weekly["low"], - close=df_weekly["close"], - timeperiod=14, - ) - df_weekly = df_weekly.tail(1) - week_cci = df_weekly["cci"].values[0] - - df["cci"] = ta.CCI( - high=df["high"], - low=df["low"], - close=df["close"], - timeperiod=14, - ) - cci = df["cci"].tail(1).values[0] - logger.info(f"{i}/{len(code_list)}: {code} week_cci: {week_cci} day_cci: {cci}") - if cci < -100 or week_cci < -100: - signal_list.append( - { - "code": code, - "name": code_info["指数名称"], - "天cci": cci, - "周cci": week_cci, - } - ) - # break - signal_df = pd.DataFrame(signal_list) - # dingtalk.send_markdown( - # f"CCI信号", signal_df.to_markdown(tablefmt="simple", index=False) - # ) - if len(signal_list) > 0: - dingtalk.send_text( - tabulate(signal_df, tablefmt="plain", headers="keys", showindex=False) - ) - else: - logger.info("无信号") - - -if __name__ == "__main__": - # main_calc_process() - ... 
- # main() - # logger.info(datetime.now()) - # schedule.every().day.at("19:00").do(main) - # while True: - # schedule.run_pending() - # time.sleep(1) diff --git a/strategies/__init__.py b/strategies/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/strategies/base.py b/strategies/base.py new file mode 100644 index 0000000..62814fc --- /dev/null +++ b/strategies/base.py @@ -0,0 +1,41 @@ +""" +策略基类定义 +""" + +from abc import ABC, abstractmethod +from typing import Any + + +class Strategy(ABC): + """策略抽象基类""" + + def __init__(self, name: str, config: dict = None): + self.name = name + self.config = config or {} + + @abstractmethod + def run(self, **kwargs) -> Any: + """执行策略""" + pass + + @abstractmethod + def get_signals(self, **kwargs) -> Any: + """获取当前信号""" + pass + + +class BacktestStrategy(Strategy): + """回测策略基类""" + + def __init__(self, name: str, config: dict = None): + super().__init__(name, config) + self.results = None + + @abstractmethod + def run_backtest(self, **kwargs) -> dict: + """执行回测,返回绩效指标""" + pass + + def get_results(self) -> dict: + """获取回测结果""" + return self.results diff --git a/strategies/rotation/__init__.py b/strategies/rotation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/strategies/rotation/engine.py b/strategies/rotation/engine.py new file mode 100644 index 0000000..4c64221 --- /dev/null +++ b/strategies/rotation/engine.py @@ -0,0 +1,249 @@ +""" +ETF轮动策略引擎 + +整合信号生成和回测逻辑 +""" + +import pandas as pd +import numpy as np +from typing import Optional + +from strategies.base import BacktestStrategy +from core.data.tushare_source import TushareDataSource +from core.factors.momentum import compute_factors, calculate_daily_return + + +class RotationStrategy(BacktestStrategy): + """ETF轮动策略""" + + def __init__(self, config: dict): + super().__init__("ETF轮动策略", config) + self.data_source = TushareDataSource(use_cache=config.get("use_cache", True)) + self.data = None + self.signals = None + self.backtest_result 
= None + + def fetch_data(self) -> pd.DataFrame: + """获取数据""" + from config.settings import BENCHMARK_CODE + + etf_data, benchmark_data, valid_codes = self.data_source.fetch_all( + self.config["code_list"], + BENCHMARK_CODE, + self.config["start_date"], + self.config["end_date"], + ) + + self.etf_data = etf_data + self.benchmark_data = benchmark_data + self.valid_codes = valid_codes + + # 计算因子 + factor_data, valid_codes = compute_factors( + etf_data, + valid_codes, + n=self.config["n_days"], + factor_type=self.config["factor_type"], + ) + + self.data = factor_data + self.valid_codes = valid_codes + return factor_data + + def generate_signals(self) -> pd.DataFrame: + """生成轮动信号""" + if self.data is None: + self.fetch_data() + + result = self.data.copy() + score_cols = [f"得分_{code}" for code in self.valid_codes] + select_num = self.config["select_num"] + rebalance_days = self.config["rebalance_days"] + rebalance_threshold = self.config["rebalance_threshold"] + + # Step 1: 每日目标组合 + if select_num == 1: + daily_target = ( + result[score_cols] + .idxmax(axis=1) + .str.replace("得分_", "", regex=False) + ) + else: + def top_n_codes(row): + scores = pd.to_numeric(row[score_cols], errors="coerce") + top = scores.nlargest(select_num).index.tolist() + return ",".join([c.replace("得分_", "") for c in top]) + daily_target = result.apply(top_n_codes, axis=1) + + # Step 2: 逐日生成信号(调仓周期控制) + held_signals = [] + current_held = None + last_rebalance_idx = 0 + + for i in range(len(result)): + target = daily_target.iloc[i] + + if current_held is None: + current_held = target + last_rebalance_idx = i + held_signals.append(current_held) + continue + + days_since = i - last_rebalance_idx + if days_since >= rebalance_days: + should = self._check_rebalance( + result.iloc[i], current_held, target, + select_num, rebalance_threshold + ) + if should: + current_held = target + last_rebalance_idx = i + + held_signals.append(current_held) + + result["信号_raw"] = held_signals + result["信号"] = 
result["信号_raw"].shift(1) + result = result.drop(columns=["信号_raw"]) + result = result.dropna(subset=["信号"]) + + self.signals = result + self._print_signal_stats(result, select_num, rebalance_days, rebalance_threshold) + return result + + def _check_rebalance(self, row, current_held, target, select_num, threshold): + """检查是否应该调仓""" + if select_num == 1: + if target == current_held: + return False + new_score = float(row[f"得分_{target}"]) + old_score = float(row[f"得分_{current_held}"]) + if old_score > 0: + return (new_score / old_score - 1) >= threshold + return new_score > 0 + else: + new_codes = target.split(",") + old_codes = current_held.split(",") + if set(new_codes) == set(old_codes): + return False + new_total = sum(float(row[f"得分_{c}"]) for c in new_codes) + old_total = sum(float(row[f"得分_{c}"]) for c in old_codes) + if old_total > 0: + return (new_total / old_total - 1) >= threshold + return new_total > 0 + + def _print_signal_stats(self, result, select_num, rebalance_days, rebalance_threshold): + """打印信号统计""" + total_days = len(result) + + if select_num == 1: + rebalance_count = (result["信号"] != result["信号"].shift(1)).sum() - 1 + else: + prev = None + rebalance_count = 0 + for s in result["信号"]: + if prev is not None and s != prev: + if set(s.split(",")) != set(prev.split(",")): + rebalance_count += 1 + prev = s + + rebalance_count = max(rebalance_count, 0) + avg_hold = total_days / max(rebalance_count, 1) + years = total_days / 252 + annual_rebalances = rebalance_count / max(years, 0.1) + + print(f"\n信号生成完成:") + print(f" 调仓周期: {rebalance_days} 天 | 阈值: {rebalance_threshold:.1%}") + print(f" 交易天数: {total_days}") + print(f" 调仓次数: {rebalance_count} | 平均持仓: {avg_hold:.1f} 天 | 年均调仓: {annual_rebalances:.1f} 次") + + if select_num == 1: + signal_counts = result["信号"].value_counts() + print(f" 品种持仓分布 (前10):") + for code, count in signal_counts.head(10).items(): + pct = count / total_days * 100 + print(f" {code}: {count}天 ({pct:.1f}%)") + + def run_backtest(self) -> 
pd.DataFrame: + """执行回测""" + if self.signals is None: + self.generate_signals() + + result = self.signals.copy() + select_num = self.config["select_num"] + trade_cost = self.config["trade_cost"] + + # 计算策略日收益率 + if select_num == 1: + def calc_return(row): + return row[f"日收益率_{row['信号']}"] + result["轮动策略日收益率"] = result.apply(calc_return, axis=1) + else: + def calc_multi_return(row): + codes = row["信号"].split(",") + returns = [row[f"日收益率_{c}"] for c in codes] + return np.mean(returns) + result["轮动策略日收益率"] = result.apply(calc_multi_return, axis=1) + + # 扣除交易成本 + if trade_cost > 0: + prev_signal = result["信号"].shift(1) + + if select_num == 1: + changed = (result["信号"] != prev_signal) & prev_signal.notna() + result.loc[changed, "轮动策略日收益率"] -= trade_cost + else: + turnover_list = [] + for curr, prev in zip(result["信号"], prev_signal): + if pd.isna(prev) or curr == prev: + turnover_list.append(0.0) + else: + old = set(prev.split(",")) + new = set(curr.split(",")) + swapped = len(old - new) + turnover_list.append(swapped / len(old)) + result["换手率"] = turnover_list + result["轮动策略日收益率"] -= result["换手率"] * trade_cost + + # 计算净值 + result["轮动策略净值"] = (1 + result["轮动策略日收益率"]).cumprod() + + # 各ETF单独净值 + for code in self.valid_codes: + first_price = result[code].iloc[0] + result[f"净值_{code}"] = result[code] / first_price + + # 基准净值 + bench_ret = self.benchmark_data.pct_change().dropna() + common_dates = result.index.intersection(bench_ret.index) + bench_ret = bench_ret.loc[common_dates] + + result["基准日收益率"] = bench_ret.reindex(result.index, fill_value=0) + result["基准净值"] = (1 + result["基准日收益率"]).cumprod() + + self.backtest_result = result + + # 打印摘要 + total_days = len(result) + strategy_total_return = result["轮动策略净值"].iloc[-1] - 1 + benchmark_total_return = result["基准净值"].iloc[-1] - 1 + + print(f"\n回测完成:") + print(f" 回测区间: {result.index.min().date()} ~ {result.index.max().date()}") + print(f" 交易天数: {total_days}") + print(f" 策略累计收益: {strategy_total_return:.2%}") + print(f" 基准累计收益: 
{benchmark_total_return:.2%}") + + return result + + def run(self) -> dict: + """运行完整流程""" + self.fetch_data() + self.generate_signals() + self.run_backtest() + return self.backtest_result + + def get_signals(self) -> pd.DataFrame: + """获取当前信号""" + if self.signals is None: + self.generate_signals() + return self.signals diff --git a/strategies/rotation/portfolio.py b/strategies/rotation/portfolio.py new file mode 100644 index 0000000..dbd6e32 --- /dev/null +++ b/strategies/rotation/portfolio.py @@ -0,0 +1,233 @@ +""" +ETF轮动策略 - 持仓跟踪模块 +""" + +import pandas as pd +from typing import Optional + + +def track_positions( + backtest_result: pd.DataFrame, + code_name_map: dict = None, + select_num: int = 1, +) -> tuple[pd.DataFrame, pd.DataFrame]: + """ + 从回测结果中提取每笔持仓记录 + + Args: + backtest_result: 回测结果(含 '信号' 列) + code_name_map: 代码→名称映射 + select_num: 每次选中的品种数量 + + Returns: + tuple: (trades_df, summary_df) + """ + code_name_map = code_name_map or {} + data = backtest_result.copy() + dates = data.index.tolist() + signals = data["信号"].tolist() + trades = [] + + if select_num == 1: + # 单品种轮动 + current_code = signals[0] + entry_date = dates[0] + entry_price = data.loc[entry_date, current_code] + entry_nav = data.loc[entry_date, "轮动策略净值"] + + for i in range(1, len(dates)): + today_code = signals[i] + + if today_code != current_code: + exit_date = dates[i - 1] + exit_price = data.loc[exit_date, current_code] + exit_nav = data.loc[exit_date, "轮动策略净值"] + holding_days = (i - 1) - dates.index(entry_date) + 1 + trade_return = exit_price / entry_price - 1 if entry_price != 0 else 0 + nav_contrib = exit_nav - entry_nav + + trades.append({ + "序号": len(trades) + 1, + "品种代码": current_code, + "品种名称": code_name_map.get(current_code, current_code), + "进场日期": entry_date, + "出场日期": exit_date, + "持仓天数": holding_days, + "仓位占比": "100%", + "进场价格": round(entry_price, 2), + "出场价格": round(exit_price, 2), + "持仓收益": trade_return, + "进场净值": round(entry_nav, 4), + "出场净值": round(exit_nav, 4), + "净值贡献": 
round(nav_contrib, 4), + }) + + current_code = today_code + entry_date = dates[i] + entry_price = data.loc[entry_date, current_code] + entry_nav = data.loc[entry_date, "轮动策略净值"] + + # 最后一笔 + exit_date = dates[-1] + exit_price = data.loc[exit_date, current_code] + exit_nav = data.loc[exit_date, "轮动策略净值"] + holding_days = len(dates) - dates.index(entry_date) + trade_return = exit_price / entry_price - 1 if entry_price != 0 else 0 + nav_contrib = exit_nav - entry_nav + + trades.append({ + "序号": len(trades) + 1, + "品种代码": current_code, + "品种名称": code_name_map.get(current_code, current_code), + "进场日期": entry_date, + "出场日期": exit_date, + "持仓天数": holding_days, + "仓位占比": "100%", + "进场价格": round(entry_price, 2), + "出场价格": round(exit_price, 2), + "持仓收益": trade_return, + "进场净值": round(entry_nav, 4), + "出场净值": round(exit_nav, 4), + "净值贡献": round(nav_contrib, 4), + }) + + else: + # 多品种等权轮动 + current_signal = signals[0] + entry_date = dates[0] + codes = current_signal.split(",") + weight = 1.0 / len(codes) + entry_prices = {c: data.loc[entry_date, c] for c in codes} + entry_nav = data.loc[entry_date, "轮动策略净值"] + + for i in range(1, len(dates)): + today_signal = signals[i] + + if today_signal != current_signal: + exit_date = dates[i - 1] + exit_nav = data.loc[exit_date, "轮动策略净值"] + holding_days = (i - 1) - dates.index(entry_date) + 1 + + for c in codes: + exit_price = data.loc[exit_date, c] + ep = entry_prices[c] + trade_return = exit_price / ep - 1 if ep != 0 else 0 + + trades.append({ + "序号": len(trades) + 1, + "品种代码": c, + "品种名称": code_name_map.get(c, c), + "进场日期": entry_date, + "出场日期": exit_date, + "持仓天数": holding_days, + "仓位占比": f"{weight:.0%}", + "进场价格": round(ep, 2), + "出场价格": round(exit_price, 2), + "持仓收益": trade_return, + "进场净值": round(entry_nav, 4), + "出场净值": round(exit_nav, 4), + "净值贡献": round((exit_nav - entry_nav) * weight, 4), + }) + + current_signal = today_signal + entry_date = dates[i] + codes = current_signal.split(",") + weight = 1.0 / len(codes) + entry_prices 
= {c: data.loc[entry_date, c] for c in codes} + entry_nav = data.loc[entry_date, "轮动策略净值"] + + # 最后一笔 + exit_date = dates[-1] + exit_nav = data.loc[exit_date, "轮动策略净值"] + holding_days = len(dates) - dates.index(entry_date) + for c in codes: + exit_price = data.loc[exit_date, c] + ep = entry_prices[c] + trade_return = exit_price / ep - 1 if ep != 0 else 0 + trades.append({ + "序号": len(trades) + 1, + "品种代码": c, + "品种名称": code_name_map.get(c, c), + "进场日期": entry_date, + "出场日期": exit_date, + "持仓天数": holding_days, + "仓位占比": f"{weight:.0%}", + "进场价格": round(ep, 2), + "出场价格": round(exit_price, 2), + "持仓收益": trade_return, + "进场净值": round(entry_nav, 4), + "出场净值": round(exit_nav, 4), + "净值贡献": round((exit_nav - entry_nav) * weight, 4), + }) + + trades_df = pd.DataFrame(trades) + summary = _summarize_by_code(trades_df, code_name_map) + return trades_df, summary + + +def _summarize_by_code(trades_df: pd.DataFrame, code_name_map: dict) -> pd.DataFrame: + """按品种汇总持仓统计""" + if trades_df.empty: + return pd.DataFrame() + + groups = trades_df.groupby("品种代码") + rows = [] + + for code, grp in groups: + total_trades = len(grp) + total_days = grp["持仓天数"].sum() + avg_days = grp["持仓天数"].mean() + win_trades = (grp["持仓收益"] > 0).sum() + win_rate = win_trades / total_trades if total_trades > 0 else 0 + avg_return = grp["持仓收益"].mean() + total_return = (1 + grp["持仓收益"]).prod() - 1 + max_return = grp["持仓收益"].max() + min_return = grp["持仓收益"].min() + + rows.append({ + "品种代码": code, + "品种名称": code_name_map.get(code, code), + "调仓次数": total_trades, + "总持仓天数": total_days, + "平均持仓天数": round(avg_days, 1), + "胜率": win_rate, + "平均收益": avg_return, + "累计收益": total_return, + "最大单次收益": max_return, + "最大单次亏损": min_return, + }) + + summary = pd.DataFrame(rows) + summary = summary.sort_values("总持仓天数", ascending=False).reset_index(drop=True) + return summary + + +def save_trades( + trades_df: pd.DataFrame, + summary_df: pd.DataFrame, + save_path: str = "report", +) -> None: + """保存调仓明细和汇总到CSV""" + import os + 
os.makedirs(os.path.dirname(save_path) if os.path.dirname(save_path) else ".", exist_ok=True) + + trades_out = trades_df.copy() + trades_out["持仓收益"] = trades_out["持仓收益"].apply(lambda x: f"{x:.2%}") + trades_out["进场日期"] = trades_out["进场日期"].apply( + lambda x: x.strftime("%Y-%m-%d") if hasattr(x, "strftime") else str(x)[:10] + ) + trades_out["出场日期"] = trades_out["出场日期"].apply( + lambda x: x.strftime("%Y-%m-%d") if hasattr(x, "strftime") else str(x)[:10] + ) + + trades_path = f"{save_path}_trades.csv" + trades_out.to_csv(trades_path, index=False, encoding="utf-8-sig") + print(f"\n调仓明细已保存: {trades_path}") + + summary_out = summary_df.copy() + for col in ["胜率", "平均收益", "累计收益", "最大单次收益", "最大单次亏损"]: + summary_out[col] = summary_out[col].apply(lambda x: f"{x:.2%}") + + summary_path = f"{save_path}_summary.csv" + summary_out.to_csv(summary_path, index=False, encoding="utf-8-sig") + print(f"品种汇总已保存: {summary_path}") diff --git a/strategies/rotation/report.py b/strategies/rotation/report.py new file mode 100644 index 0000000..7e04ad5 --- /dev/null +++ b/strategies/rotation/report.py @@ -0,0 +1,175 @@ +""" +ETF轮动策略 - 绩效报告模块 +""" + +import numpy as np +import pandas as pd +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +from typing import Optional + +from core.common.utils import calculate_cagr, calculate_max_drawdown, calculate_sharpe + + +def generate_performance_report( + backtest_result: pd.DataFrame, + code_list: list, + code_name_map: dict = None, + benchmark_name: str = "沪深300指数", + save_path: str = "report", + select_num: int = 1, +) -> dict: + """ + 生成完整的绩效报告 + + Args: + backtest_result: 回测结果 + code_list: ETF代码列表 + code_name_map: 代码到名称映射 + benchmark_name: 基准名称 + save_path: 报告保存路径前缀 + select_num: 选中数量 + + Returns: + dict: 绩效指标字典 + """ + import os + os.makedirs(os.path.dirname(save_path) if os.path.dirname(save_path) else ".", exist_ok=True) + + code_name_map = code_name_map or {} + strategy_nav = backtest_result["轮动策略净值"] + strategy_ret = 
backtest_result["轮动策略日收益率"] + benchmark_nav = backtest_result["基准净值"] + benchmark_ret = backtest_result["基准日收益率"] + + # 计算绩效指标 + s_cagr_nat = calculate_cagr(strategy_nav, "natural_days") + s_cagr_trd = calculate_cagr(strategy_nav, "trading_days") + s_total_return = strategy_nav.iloc[-1] - 1 + s_sharpe = calculate_sharpe(strategy_ret) + s_max_dd, s_dd_start, s_dd_end = calculate_max_drawdown(strategy_nav) + s_win_rate = (strategy_ret > 0).sum() / len(strategy_ret) + s_calmar = s_cagr_nat / abs(s_max_dd) if s_max_dd != 0 else np.inf + + b_cagr_nat = calculate_cagr(benchmark_nav, "natural_days") + b_cagr_trd = calculate_cagr(benchmark_nav, "trading_days") + b_total_return = benchmark_nav.iloc[-1] - 1 + b_sharpe = calculate_sharpe(benchmark_ret) + b_max_dd, _, _ = calculate_max_drawdown(benchmark_nav) + + # 打印绩效表格 + print("\n" + "=" * 70) + print(" 绩效评估报告") + print("=" * 70) + print(f" 回测区间: {strategy_nav.index.min().date()} ~ {strategy_nav.index.max().date()}") + print(f" 交易天数: {len(strategy_nav)}") + print("-" * 70) + print(f' {"指标":<25} {"轮动策略":>15} {"基准(" + benchmark_name + ")":>18}') + print("-" * 70) + print(f' {"累计收益":<25} {s_total_return:>14.2%} {b_total_return:>17.2%}') + print(f' {"CAGR(自然日口径)":<25} {s_cagr_nat:>14.2%} {b_cagr_nat:>17.2%}') + print(f' {"CAGR(交易日口径)":<25} {s_cagr_trd:>14.2%} {b_cagr_trd:>17.2%}') + print(f' {"年化夏普比率":<25} {s_sharpe:>14.2f} {b_sharpe:>17.2f}') + print(f' {"最大回撤":<25} {s_max_dd:>14.2%} {b_max_dd:>17.2%}') + print(f' {"Calmar比率":<23} {s_calmar:>14.2f} {"--":>17}') + print(f' {"日胜率":<25} {s_win_rate:>14.2%} {"--":>17}') + print(f' {"最大回撤区间":<22} {str(s_dd_start.date()):>10} ~ {str(s_dd_end.date())}') + print("=" * 70) + + # 绘制图表 + _plot_report_chart( + backtest_result, code_list, code_name_map, + benchmark_name, save_path, select_num + ) + + # 返回指标字典 + return { + "累计收益": s_total_return, + "CAGR_自然日": s_cagr_nat, + "CAGR_交易日": s_cagr_trd, + "夏普比率": s_sharpe, + "最大回撤": s_max_dd, + "Calmar比率": s_calmar, + "日胜率": s_win_rate, + 
"基准累计收益": b_total_return, + "基准CAGR_自然日": b_cagr_nat, + "基准夏普比率": b_sharpe, + "基准最大回撤": b_max_dd, + } + + +def _plot_report_chart( + backtest_result: pd.DataFrame, + code_list: list, + code_name_map: dict, + benchmark_name: str, + save_path: str, + select_num: int, +): + """绘制报告图表""" + plt.rcParams["font.sans-serif"] = ["Arial Unicode MS", "SimHei", "DejaVu Sans"] + plt.rcParams["axes.unicode_minus"] = False + + strategy_nav = backtest_result["轮动策略净值"] + benchmark_nav = backtest_result["基准净值"] + + fig, axes = plt.subplots(3, 1, figsize=(14, 12)) + + # 面板1: 净值曲线 + ax1 = axes[0] + ax1.plot(strategy_nav.index, strategy_nav.values, + label="轮动策略", linewidth=2, color="#E74C3C") + ax1.plot(benchmark_nav.index, benchmark_nav.values, + label=benchmark_name, linewidth=1.5, color="#3498DB", alpha=0.8) + + chart_colors = plt.cm.tab20.colors + show_legend_n = min(len(code_list), 10) + for i, code in enumerate(code_list): + name = code_name_map.get(code, code) + lbl = name if i < show_legend_n else None + ax1.plot(backtest_result.index, backtest_result[f"净值_{code}"].values, + label=lbl, linewidth=0.8, alpha=0.4, + color=chart_colors[i % len(chart_colors)]) + + ax1.set_title("ETF轮动策略 - 净值曲线", fontsize=16, fontweight="bold") + ax1.set_ylabel("净值") + ax1.legend(loc="upper left", fontsize=8, ncol=2) + ax1.grid(True, alpha=0.3) + ax1.set_yscale("log") + + # 面板2: 回撤曲线 + ax2 = axes[1] + cummax = strategy_nav.cummax() + drawdown = (strategy_nav - cummax) / cummax + ax2.fill_between(drawdown.index, drawdown.values, 0, alpha=0.5, color="#E74C3C") + ax2.set_title("策略回撤", fontsize=12) + ax2.set_ylabel("回撤") + ax2.grid(True, alpha=0.3) + + # 面板3: 持仓分布 + ax3 = axes[2] + signal_series = backtest_result["信号"] + for i, code in enumerate(code_list): + name = code_name_map.get(code, code) + if select_num > 1: + mask = signal_series.str.contains(code, regex=False, na=False) + else: + mask = signal_series == code + if mask.any(): + ax3.fill_between(signal_series.index, i, i + 0.8, + where=mask, 
alpha=0.7, + color=chart_colors[i % len(chart_colors)], + label=name) + + ylabels = [code_name_map.get(c, c) for c in code_list] + ax3.set_title("每日持仓分布", fontsize=12) + ax3.set_yticks(range(len(ylabels))) + ax3.set_yticklabels(ylabels, fontsize=7) + ax3.grid(True, alpha=0.3) + + plt.tight_layout() + chart_path = f"{save_path}_chart.png" + plt.savefig(chart_path, dpi=150, bbox_inches="tight") + plt.close() + print(f"\n报告图表已保存: {chart_path}") diff --git a/strategies/screener/__init__.py b/strategies/screener/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/strategies/screener/base.py b/strategies/screener/base.py new file mode 100644 index 0000000..bd1bd91 --- /dev/null +++ b/strategies/screener/base.py @@ -0,0 +1,68 @@ +""" +标的筛选器基类 + +用于基于技术指标筛选符合条件的标的 +""" + +from abc import ABC, abstractmethod +from typing import Any +import pandas as pd + + +class Screener(ABC): + """筛选器抽象基类""" + + def __init__(self, name: str, config: dict = None): + self.name = name + self.config = config or {} + + @abstractmethod + def screen(self, data: Any) -> dict: + """ + 执行筛选 + + Args: + data: 输入数据(DataFrame或其他格式) + + Returns: + dict: 筛选结果,必须包含 'triggered' 键表示是否触发 + """ + pass + + @abstractmethod + def screen_batch(self, data_dict: dict) -> list: + """ + 批量筛选多个标的 + + Args: + data_dict: {code: data} 格式的字典 + + Returns: + list: 符合条件的标的列表 + """ + pass + + +class DataFrameScreener(Screener): + """基于DataFrame的筛选器基类""" + + def __init__(self, name: str, config: dict = None): + super().__init__(name, config) + + def validate_data(self, df: pd.DataFrame) -> bool: + """验证数据格式""" + required_cols = ["open", "high", "low", "close", "volume"] + return all(col in df.columns for col in required_cols) + + def screen_batch(self, data_dict: dict) -> list: + """批量筛选""" + results = [] + for code, data in data_dict.items(): + if isinstance(data, pd.DataFrame) and self.validate_data(data): + result = self.screen(data) + if result.get("triggered", False): + results.append({ + "code": code, + 
**result + }) + return results diff --git a/strategies/screener/cci.py b/strategies/screener/cci.py new file mode 100644 index 0000000..56536d1 --- /dev/null +++ b/strategies/screener/cci.py @@ -0,0 +1,186 @@ +""" +CCI技术指标筛选器 + +基于商品通道指数(CCI)筛选超卖标的 +""" + +import pandas as pd +from datetime import datetime + +from .base import DataFrameScreener +from core.factors.technical import calculate_cci, resample_to_weekly +from core.common.db import DatabaseManager +from core.common.notify import NotificationManager + + +class CCIScreener(DataFrameScreener): + """CCI超卖筛选器""" + + def __init__(self, config: dict = None): + super().__init__("CCI超卖筛选", config) + self.day_period = config.get("day_period", 14) + self.week_period = config.get("week_period", 14) + self.threshold = config.get("threshold", -100) + self.use_weekly = config.get("use_weekly", True) + self.db_manager = DatabaseManager() + self.notifier = NotificationManager() + + def screen(self, df: pd.DataFrame) -> dict: + """ + 对单只标的进行CCI筛选 + + Args: + df: DataFrame with OHLCV data + + Returns: + dict: { + 'triggered': bool, + 'day_cci': float, + 'week_cci': float or None, + 'current_price': float, + } + """ + if not self.validate_data(df): + return {"triggered": False, "error": "数据格式错误"} + + # 计算日线CCI + day_cci = calculate_cci(df, period=self.day_period).iloc[-1] + current_price = df["close"].iloc[-1] + + result = { + "triggered": day_cci < self.threshold, + "day_cci": round(day_cci, 2), + "week_cci": None, + "current_price": round(current_price, 2), + } + + # 计算周线CCI + if self.use_weekly: + weekly_df = resample_to_weekly(df) + if len(weekly_df) >= self.week_period: + week_cci = calculate_cci(weekly_df, period=self.week_period).iloc[-1] + result["week_cci"] = round(week_cci, 2) + # 日线或周线任一超卖即触发 + result["triggered"] = ( + day_cci < self.threshold or week_cci < self.threshold + ) + + return result + + def get_data_from_db(self, code: str, limit: int = 100) -> pd.DataFrame: + """从数据库获取数据""" + sql = f""" + SELECT date, 
open, high, low, close, volume + FROM public.index_kline + WHERE code = '{code}' + ORDER BY date DESC + LIMIT {limit} + """ + result = self.db_manager.execute_query(sql) + + if not result: + return pd.DataFrame() + + df = pd.DataFrame(result) + df["date"] = pd.to_datetime(df["date"]) + + # 转换数值类型 + for col in ["open", "high", "low", "close", "volume"]: + df[col] = pd.to_numeric(df[col], errors="coerce") + + df = df.sort_values("date").reset_index(drop=True) + return df + + def run_screening(self, code_list: list = None) -> list: + """ + 执行批量筛选 + + Args: + code_list: 标的代码列表,None则从配置文件读取 + + Returns: + list: 符合条件的标的列表 + """ + if code_list is None: + # 从CSV文件读取指数列表 + import os + csv_path = self.config.get("index_fund_info_file", "index_fund_info.csv") + if os.path.exists(csv_path): + df = pd.read_csv(csv_path, encoding="utf-8-sig") + code_list = df.drop_duplicates(subset=["指数代码"]).to_dict("records") + else: + raise ValueError(f"找不到标的列表文件: {csv_path}") + + signals = [] + today_str = datetime.now().strftime("%Y-%m-%d") + + print(f"开始CCI筛选,共 {len(code_list)} 个标的...") + + for i, code_info in enumerate(code_list): + if isinstance(code_info, dict): + code = code_info.get("指数代码") + name = code_info.get("指数名称", code) + else: + code = code_info + name = code + + try: + df = self.get_data_from_db(code, limit=self.config.get("lookback_days", 100)) + + if len(df) < self.day_period: + continue + + # 检查最新日期 + if df["date"].max().strftime("%Y-%m-%d") != today_str: + continue + + result = self.screen(df) + + if result["triggered"]: + signals.append({ + "code": code, + "name": name, + "day_cci": result["day_cci"], + "week_cci": result["week_cci"], + "price": result["current_price"], + }) + print(f" ✓ {code} ({name}): 日线CCI={result['day_cci']:.2f}") + + except Exception as e: + print(f" ✗ {code}: {e}") + continue + + print(f"\n筛选完成,{len(signals)} 个标的符合CCI超卖条件") + + # 发送通知 + if signals: + self.notifier.notify_signal(signals, signal_type="CCI超卖") + + return signals + + def 
run_daily(self): + """每日定时运行""" + from datetime import datetime + + # 检查是否为交易日 + if datetime.today().weekday() >= 5 and self.config.get("skip_weekend", True): + print("非交易日,跳过") + return + + self.run_screening() + + +def create_cci_screener_from_config(config_path: str = None) -> CCIScreener: + """从配置文件创建CCI筛选器""" + import yaml + import os + + if config_path is None: + config_path = os.path.join( + os.path.dirname(__file__), "..", "..", "config", "strategies", "cci.yaml" + ) + + with open(config_path, "r", encoding="utf-8") as f: + config = yaml.safe_load(f) + + return CCIScreener(config) diff --git a/test.py b/test.py new file mode 100644 index 0000000..5ad905e --- /dev/null +++ b/test.py @@ -0,0 +1,100 @@ +import pandas as pd +import numpy as np +import vectorbt as vbt +from numba import njit +import vectorbt as vbt +import pandas as pd +import numpy as np + +import pandas as pd +from loguru import logger +from chart import resample_data, QuantChart +from db_config import DatabaseManager, DatabaseConfig +from vectorbt.base.reshape_fns import to_2d_array + + +def get_kline(code: str) -> list: + """ + 获取所有指数代码 + :return: + """ + db_config = DatabaseConfig() + logger.info(f"数据库连接: {db_config.connection_string}") + + db_manager = DatabaseManager(db_config) + sql = f"SELECT date as time, open, high, low, close, volume FROM public.index_kline where code='{code}' order by date;" + res = db_manager.execute_query(sql) + data_list = [dict(item) for item in res] + df = pd.DataFrame(data_list) + df["time"] = pd.to_datetime(df["time"]) + num_cols = ["open", "high", "low", "close", "volume"] + for col in num_cols: + if col in df.columns: + df[col] = pd.to_numeric(df[col], errors="coerce").astype(float) + return df + +symbol = "399998" +timeframe = "1D" + +df = get_kline(code=symbol) +df = resample_data(df, timeframe) +df.rename(columns={'time': 'date'}, inplace=True) +print(df.head()) +if 'date' in df.columns: + df = df.set_index('date') +price = df['close'] + +# 2. 
计算90天滚动波动率(年化) +returns = price.pct_change() +volatility_90d = returns.rolling(window=90, min_periods=90).std() * np.sqrt(365) + +# 3. 计算波动率倒数作为权重 +inv_vol = 1 / volatility_90d +# 标准化权重(可选,使其更易解释) +inv_vol_normalized = inv_vol / inv_vol.rolling(window=252).mean() + +# 4. 创建每周重新平衡的信号 +# 获取每周最后一个交易日 +weekly_rebalance = pd.Series(False, index=price.index) +weekly_last_days = price.resample('W').last().index +for date in weekly_last_days: + # 找到最接近的交易日 + idx = price.index.get_indexer([date], method='ffill')[0] + if idx >= 0: + weekly_rebalance.iloc[idx] = True + +# 5. 定义订单函数 +@njit +def order_func_nb(c, inv_vol_arr, rebalance_arr): + # 获取当前的波动率倒数权重 + inv_vol_now = vbt.nb.flex_select_auto_nb(inv_vol_arr, c.i, c.col, False) + rebalance_now = vbt.nb.flex_select_auto_nb(rebalance_arr, c.i, c.col, False) + + # 只在重新平衡日调整仓位 + if not rebalance_now or np.isnan(inv_vol_now): + return vbt.nb.order_nothing_nb() + + # 目标仓位 = 总价值 * 波动率倒数权重 + # 这里使用 TargetPercent 类型,权重越高仓位越大 + target_percent = min(inv_vol_now, 1.0) # 限制最大100%仓位 + + return vbt.nb.order_nb( + size=target_percent, + size_type=vbt.SizeType.TargetPercent, + direction=vbt.Direction.LongOnly + ) + +# 6. 运行回测 +pf = vbt.Portfolio.from_order_func( + price, + order_func_nb, + to_2d_array(inv_vol_normalized), + to_2d_array(weekly_rebalance), + init_cash=100, + freq='1D' +) + +# 7. 
查看结果 +print(pf.stats()) +print(f"\n波动率倒数策略 vs 买入持有:") +print(f"总收益率: {pf.total_return():.2%}") \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/update_data.py b/update_data.py deleted file mode 100644 index 2aab544..0000000 --- a/update_data.py +++ /dev/null @@ -1,89 +0,0 @@ -import pandas as pd -from db_config import DatabaseManager, DatabaseConfig -from loguru import logger -from datetime import datetime -import akshare as ak - -import schedule -import time -import traceback -from dingtalk import DingTalkBot -import os -from signal_calc import main_calc_process -import requests -from retry import retry - -webhook = "https://oapi.dingtalk.com/robot/send?access_token=21de667159edadd33172c6ec414a2addf9c6359189350ffd36819d2a20e8a0f4" -secret = "SEC43a0fa0b29717f98637a119b92a0bd5f7b2b6da671bdd2bd1279ed8323454d5e" -dingtalk = DingTalkBot(webhook, secret) - - - -def get_latest_index_kline_date(): - - df = ak.stock_zh_index_spot_sina() - column_mapping = { - "date": "date", - "代码": "code", - "今开": "open", - "最高": "high", - "最低": "low", - "最新价": "close", - "成交量": "volume", - } - df["date"] = None - df = df.rename(columns=column_mapping) - df['code'] = df['code'].str.extract(r'(\d+)')[0] - df = df[column_mapping.values()] - df = df.drop_duplicates(subset=["code"]) - cur_date = datetime.now().strftime("%Y%m%d") - df["date"] = cur_date - df.dropna(how="any", inplace=True) - # df.to_csv(f"aaaa.csv", index=False, encoding="utf-8-sig") - return df - - -def main(): - if datetime.today().weekday() >= 5: - logger.info(f"非交易日") - return - - try: - db_config = DatabaseConfig() - logger.info(f"数据库连接: {db_config.connection_string}") - - # 如果只是测试连接 - db_manager = DatabaseManager(db_config) - if db_manager.test_connection(): - logger.info("✅ 数据库连接测试成功") - else: - logger.error("❌ 数据库连接测试失败") - raise Exception("数据库连接测试失败") - - df = get_latest_index_kline_date() - logger.info(df.head()) - latest_date = 
df["date"].values[0] - res = db_manager.execute_query( - f"SELECT date, code, open, high, low, close, volume FROM public.index_kline where code='000001' and date='{latest_date}' order by date desc limit 1;" - ) - # print(dict(res)) - logger.info(len(res)) - if len(res) == 0: - res = db_manager.insert_dataframe(df, "index_kline") - logger.info(res) - main_calc_process() - except Exception as e: - error_message = f"{e}\n{traceback.format_exc()}" - logger.error(error_message) - dingtalk.send_text(f"A股指数抓取失败: \n{error_message}") - - -if __name__ == "__main__": - # main() - logger.info(datetime.now()) - PULL_SCHEDULE: str = os.getenv("PULL_SCHEDULE", "16:00") - logger.info(f"PULL_SCHEDULE: {PULL_SCHEDULE}") - schedule.every().day.at(PULL_SCHEDULE).do(main) - while True: - schedule.run_pending() - time.sleep(1) diff --git a/visualization/__init__.py b/visualization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/visualization/charts/__init__.py b/visualization/charts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/visualization/charts/indicators.py b/visualization/charts/indicators.py new file mode 100644 index 0000000..6b5d259 --- /dev/null +++ b/visualization/charts/indicators.py @@ -0,0 +1,238 @@ +""" +技术指标绘制组件 +""" + +import pandas as pd +import numpy as np +import talib as ta +import random +from lightweight_charts import Chart + + +def get_fixed_color(num: int) -> str: + """根据数字生成固定颜色""" + random.seed(num) + r = random.randint(0, 255) + g = random.randint(0, 255) + b = random.randint(0, 255) + color = "#{:02x}{:02x}{:02x}".format(r, g, b) + random.seed(None) + return color + + +def add_ema( + chart: Chart, + df: pd.DataFrame, + period: int = 20, + color: str = None, + price_label: bool = False, +): + """添加EMA指标线""" + name = f"EMA_{period}" + df[name] = ta.EMA(df["close"], timeperiod=period) + + line_color = color or get_fixed_color(period) + line = chart.create_line( + name, color=line_color, width=2, + 
price_label=price_label, price_line=False + ) + line.set(df[["time", name]]) + return line + + +def add_cci( + chart: Chart, + df: pd.DataFrame, + period: int = 14, + height: float = 0.15, + position: str = "bottom", +): + """添加CCI副图""" + cci = ta.CCI(df["high"], df["low"], df["close"], timeperiod=period) + df[f"CCI_{period}"] = cci + + # 创建副图 + cci_chart = chart.create_subchart( + position=position, width=1, height=height, sync=True + ) + cci_chart.layout(font_family="Times New Roman") + cci_chart.legend(visible=True, font_size=14, color="#FFFFFF") + cci_chart.time_scale(visible=False) + + # CCI线 + cci_line = cci_chart.create_line( + name=f"CCI_{period}", color="#FF0000", width=2 + ) + cci_line.set(df[["time", f"CCI_{period}"]]) + + # 水平参考线 + for level, label in [(100, "+100"), (-100, "-100")]: + df[f"cci_{label}"] = level + ref_line = cci_chart.create_line( + name=label, color="#D4C21C", width=1, + style="dashed", price_label=False, price_line=False + ) + ref_line.set(df[["time", f"cci_{label}"]]) + + return cci_chart + + +def add_macd( + chart: Chart, + df: pd.DataFrame, + fastperiod: int = 12, + slowperiod: int = 26, + signalperiod: int = 9, + height: float = 0.15, + position: str = "bottom", +): + """添加MACD副图""" + macd, signal, hist = ta.MACD( + df["close"], + fastperiod=fastperiod, + slowperiod=slowperiod, + signalperiod=signalperiod, + ) + + df["DIF"] = macd + df["DEA"] = signal + macd_name = f"MACD_{fastperiod}_{slowperiod}_{signalperiod}" + df[macd_name] = hist * 2 + + # 创建副图 + macd_chart = chart.create_subchart( + position=position, width=1, height=height, sync=True + ) + macd_chart.layout(font_family="Times New Roman") + macd_chart.legend(visible=True, font_size=14, color="#FFFFFF") + macd_chart.time_scale(visible=False) + + # 柱状图 + histogram = macd_chart.create_histogram(name=macd_name) + hist_data = df[["time", macd_name]].copy() + hist_data["prev_value"] = hist_data[macd_name].shift(1) + + def get_color(row): + current, prev = row[macd_name], 
def add_td_sequence(chart: Chart, df: pd.DataFrame):
    """Compute a TD-style setup count and mark counts of 9 and 13 on ``chart``.

    Each bar's close is compared with the close four bars earlier. A higher
    close extends the positive run (1, 2, 3, ...); anything else, including
    an equal close, extends the negative run (-1, -2, -3, ...). A change of
    direction restarts the run. The counts are written to ``df["TD"]``.

    Bars without a close four bars earlier get a count of 0. The padding is
    capped at the frame length, so frames with fewer than four rows (or none)
    no longer fail with a length mismatch when the column is assigned.

    Args:
        chart: Object exposing ``marker_list(markers)``.
        df: Frame with ``close`` and ``time`` columns; ``time`` values must
            support ``strftime``. Modified in place (``TD`` column added).
    """
    close = df["close"].to_list()

    # Leading bars have nothing to compare against.
    td = [0] * min(4, len(close))
    up = 0
    down = 0
    # Pair every bar from index 4 onward with the bar four positions earlier.
    for current, reference in zip(close[4:], close):
        if current > reference:
            up += 1
            down = 0
            td.append(up)
        else:
            down -= 1
            up = 0
            td.append(down)

    df["TD"] = td

    # Positive 9/13 counts are marked above the bar, negative ones below.
    markers = []
    for moment, count in zip(df["time"], td):
        if count in (9, 13):
            markers.append({
                "time": moment.strftime("%Y-%m-%d %H:%M:%S"),
                "position": "above",
                "shape": "arrow_down",
                "color": "#00FF00",
                "text": str(count),
            })
        elif count in (-9, -13):
            markers.append({
                "time": moment.strftime("%Y-%m-%d %H:%M:%S"),
                "position": "below",
                "shape": "arrow_up",
                "color": "#FF0000",
                "text": str(abs(count)),
            })

    chart.marker_list(markers)
class IndicatorOverlay:
    """Adds a preset combination of indicators to a chart."""

    def __init__(self, chart: Chart):
        # Chart that every indicator is drawn on.
        self.chart = chart

    def add_default_indicators(self, df: pd.DataFrame):
        """Draw the default indicator set on ``self.chart``.

        The set consists of a short-term EMA group, a long-term EMA group,
        a 260-period EMA, a MACD sub-chart (30/90), a CCI sub-chart (120)
        and TD sequence markers.

        Each EMA group shares a single color. The group seeds (5 and 10) are
        resolved to hex strings through ``get_fixed_color`` before being
        handed to ``add_ema``; previously the bare integers were passed as
        ``color``, and ``add_ema`` forwards that value unchanged to
        ``create_line``, which expects a color string.

        Args:
            df: Price frame passed through to the ``add_*`` helpers, which
                add their computed columns to it.
        """
        # Short-term EMA group, one shared color.
        short_color = get_fixed_color(5)
        for period in (3, 5, 8, 10, 12, 15):
            add_ema(self.chart, df, period=period, color=short_color)

        # Long-term EMA group, one shared color.
        long_color = get_fixed_color(10)
        for period in (30, 35, 40, 45, 50, 60):
            add_ema(self.chart, df, period=period, color=long_color)

        # 260-period EMA; color derived from the period by add_ema.
        add_ema(self.chart, df, period=260)

        # MACD sub-chart
        add_macd(self.chart, df, fastperiod=30, slowperiod=90)

        # CCI sub-chart
        add_cci(self.chart, df, period=120)

        # TD sequence markers
        add_td_sequence(self.chart, df)
def create_kline_chart(
    df: pd.DataFrame,
    symbol: str,
    name: str,
    timeframe: str,
    init_visible_bars: int = 90,
) -> Chart:
    """Build a maximized candlestick chart in one call.

    Creates a ``KlineChart``, writes ``symbol``, ``name`` and ``timeframe``
    into the top bar, loads ``df``, and, when ``df`` has more rows than
    ``init_visible_bars``, limits the initial view to the span from the
    ``init_visible_bars``-th row from the end to the last row.

    Args:
        df: OHLCV frame; its ``time`` column supplies the visible range.
        symbol: Instrument code shown in the top bar.
        name: Instrument name shown in the top bar.
        timeframe: Bar period shown in the top bar.
        init_visible_bars: Number of most recent bars initially visible.

    Returns:
        The underlying ``Chart`` held by the ``KlineChart`` wrapper.
    """
    kline = KlineChart()
    kline.create(maximize=True)

    topbar_entries = (
        ("symbol", symbol),
        ("name", name),
        ("timeframe", timeframe),
    )
    for key, text in topbar_entries:
        kline.add_topbar_text(key, text)

    kline.set_data(df)

    # With too few rows there is nothing to crop; leave the default view.
    if len(df) <= init_visible_bars:
        return kline.chart

    times = df["time"]
    kline.set_visible_range(times.iloc[-init_visible_bars], times.iloc[-1])
    return kline.chart