diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 0000000..7b3a7ef --- /dev/null +++ b/core/__init__.py @@ -0,0 +1 @@ +# 核心模块 diff --git a/core/common/__init__.py b/core/common/__init__.py new file mode 100644 index 0000000..b69e7e4 --- /dev/null +++ b/core/common/__init__.py @@ -0,0 +1 @@ +# 公共模块 diff --git a/core/common/notify.py b/core/common/notify.py new file mode 100644 index 0000000..7ac9cd8 --- /dev/null +++ b/core/common/notify.py @@ -0,0 +1,606 @@ +""" +通知模块 - 支持钉钉、日志等多种通知方式 +""" + +import requests +import time +import hmac +import hashlib +import base64 +import urllib.parse +import os +from loguru import logger +from typing import Optional + +from core.common.oss_utils import upload_image_to_oss + + +def _get_dingtalk_config() -> dict: + """获取钉钉机器人配置(第一群)""" + return { + "webhook": os.getenv("DINGTALK_WEBHOOK", ""), + "secret": os.getenv("DINGTALK_SECRET", ""), + } + + +def _get_all_dingtalk_configs() -> list: + """获取所有钉钉机器人配置(支持多群) + 环境变量格式: DINGTALK_WEBHOOK_1 + DINGTALK_SECRET_1, ... + """ + configs = [] + i = 1 + while True: + webhook = os.getenv(f"DINGTALK_WEBHOOK_{i}", "") + secret = os.getenv(f"DINGTALK_SECRET_{i}", "") + if not webhook: + break + configs.append({"webhook": webhook, "secret": secret}) + i += 1 + return configs + + +class DingTalkBot: + """钉钉机器人类""" + + def __init__(self, webhook: str = None, secret: str = None): + """ + 初始化钉钉机器人 + + Args: + webhook: 钉钉自定义机器人webhook地址 + secret: 加签密钥(可选) + """ + config = _get_dingtalk_config() + self.webhook = webhook or config.get("webhook", "") + self.secret = secret or config.get("secret", "") + + if not self.webhook: + logger.warning("钉钉webhook未配置,消息将不会被发送") + + def _gen_signed_url(self) -> str: + """生成带签名的URL""" + if not self.secret: + return self.webhook + + timestamp = str(round(time.time() * 1000)) + secret_enc = self.secret.encode("utf-8") + string_to_sign = f"{timestamp}\n{self.secret}" + string_to_sign_enc = string_to_sign.encode("utf-8") + hmac_code = hmac.new( + secret_enc, string_to_sign_enc, digestmod=hashlib.sha256 + ).digest() + sign = urllib.parse.quote_plus(base64.b64encode(hmac_code)) + return f"{self.webhook}×tamp={timestamp}&sign={sign}" + + def send_text( + self, content: str, at_mobiles: list = None, is_at_all: bool = False + ) -> bool: + """ + 发送文本消息 + + Args: + content: 消息内容 + at_mobiles: 需要@的手机号列表 + is_at_all: 是否@所有人 + + Returns: + bool: 是否发送成功 + """ + if not self.webhook: + logger.warning(f"[钉钉消息未发送] {content[:100]}...") + return False + + at_mobiles = at_mobiles or [] + data = { + "msgtype": "text", + "text": {"content": content}, + "at": {"atMobiles": at_mobiles, "isAtAll": is_at_all}, + } + + url = self._gen_signed_url() + + try: + response = requests.post(url, json=data, timeout=5) + response.raise_for_status() + result = response.json() + if result.get("errcode", -1) != 0: + logger.error(f"钉钉消息发送失败: {result}") + return False + logger.info("钉钉消息发送成功") + return True + except Exception as e: + logger.error(f"钉钉消息发送异常: {e}") + return False + + def send_markdown( + self, + title: str, + text: str, + at_mobiles: list = None, + is_at_all: bool = False, + ) -> bool: + """ + 发送markdown消息 + + Args: + title: 消息标题 + text: markdown格式的消息内容 + at_mobiles: 需要@的手机号列表 + is_at_all: 是否@所有人 + + Returns: + bool: 是否发送成功 + """ + if not self.webhook: + logger.warning(f"[钉钉Markdown未发送] {title}") + return False + + at_mobiles = at_mobiles or [] + data = { + "msgtype": "markdown", + "markdown": {"title": title, "text": text}, + "at": {"atMobiles": at_mobiles, "isAtAll": is_at_all}, + } + + url = self._gen_signed_url() + + try: + response = requests.post(url, json=data, timeout=5) + response.raise_for_status() + result = response.json() + if result.get("errcode", -1) != 0: + logger.error(f"钉钉markdown消息发送失败: {result}") + return False + logger.info("钉钉markdown消息发送成功") + return True + except Exception as e: + logger.error(f"钉钉markdown消息发送异常: {e}") + return False + + def send_image(self, image_path: str, title: str = "图片", max_size_kb: int = 6) -> bool: + """ + 发送图片消息(自动压缩以适应钉钉20KB限制) + 注意:钉钉限制的是整个请求body大小,base64编码会增加约33%体积,所以图片本身需要更小 + + Args: + image_path: 图片文件路径(本地路径) + title: 消息标题 + max_size_kb: 最大图片大小(KB),默认6KB(base64后约8KB,加上其他字段约10KB) + + Returns: + bool: 是否发送成功 + """ + if not self.webhook: + logger.warning(f"[钉钉图片未发送] {title}") + return False + + if not os.path.exists(image_path): + logger.error(f"图片文件不存在: {image_path}") + return False + + try: + # 读取并压缩图片 + image_data = self._compress_image(image_path, max_size_kb) + + if not image_data: + logger.error(f"图片压缩失败: {image_path}") + return False + + # 转为base64 + image_base64 = base64.b64encode(image_data).decode("utf-8") + # 计算图片md5(钉钉需要) + image_md5 = hashlib.md5(image_data).hexdigest() + + data = { + "msgtype": "image", + "image": { + "base64": image_base64, + "md5": image_md5 + } + } + + url = self._gen_signed_url() + response = requests.post(url, json=data, timeout=10) + response.raise_for_status() + result = response.json() + + if result.get("errcode", -1) != 0: + logger.error(f"钉钉图片发送失败: {result}") + return False + + logger.info(f"钉钉图片发送成功: {image_path}") + return True + + except Exception as e: + logger.error(f"钉钉图片发送异常: {e}") + return False + + def _compress_image(self, image_path: str, max_size_kb: int) -> bytes: + """ + 压缩图片到指定大小以下 + + Args: + image_path: 图片路径 + max_size_kb: 最大大小(KB) + + Returns: + bytes: 压缩后的图片数据 + """ + from PIL import Image + import io + + max_size_bytes = max_size_kb * 1024 + + # 先尝试读取原图 + with open(image_path, "rb") as f: + image_data = f.read() + + # 如果已经小于限制,直接返回 + if len(image_data) <= max_size_bytes: + return image_data + + # 需要压缩,使用PIL重新保存 + img = Image.open(image_path) + + # 转换为RGB(去除alpha通道,减小大小) + if img.mode in ('RGBA', 'P'): + img = img.convert('RGB') + + # 逐步降低质量直到满足大小要求 + quality = 85 + min_quality = 30 + + while quality >= min_quality: + buffer = io.BytesIO() + img.save(buffer, format='JPEG', quality=quality, optimize=True) + compressed_data = buffer.getvalue() + + if len(compressed_data) <= max_size_bytes: + logger.info(f"图片压缩成功: {len(image_data)/1024:.1f}KB -> {len(compressed_data)/1024:.1f}KB (quality={quality})") + return compressed_data + + quality -= 10 + + # 如果质量降到最小还不行,尝试缩小尺寸 + logger.warning(f"降低质量无法满足要求,尝试缩小尺寸") + width, height = img.size + + while width > 200 and height > 200: + width = int(width * 0.8) + height = int(height * 0.8) + resized_img = img.resize((width, height), Image.Resampling.LANCZOS) + + buffer = io.BytesIO() + resized_img.save(buffer, format='JPEG', quality=min_quality, optimize=True) + compressed_data = buffer.getvalue() + + if len(compressed_data) <= max_size_bytes: + logger.info(f"图片压缩成功: {len(image_data)/1024:.1f}KB -> {len(compressed_data)/1024:.1f}KB ({width}x{height})") + return compressed_data + + logger.error(f"无法将图片压缩到 {max_size_kb}KB 以下") + return None + + def send_image_with_text(self, image_path: str, title: str = "图片", text: str = "") -> bool: + """ + 发送图文消息(markdown格式嵌入图片链接) + 注意:需要使用钉钉的媒体文件上传接口获取URL,这里使用简化的markdown图片语法 + + Args: + image_path: 图片文件路径 + title: 消息标题 + text: accompanying text + + Returns: + bool: 是否发送成功 + """ + if not self.webhook: + logger.warning(f"[钉钉图文未发送] {title}") + return False + + if not os.path.exists(image_path): + logger.error(f"图片文件不存在: {image_path}") + return False + + # 先尝试直接发送图片 + success = self.send_image(image_path, title) + + # 如果图片发送成功且有文字,再发送文字 + if success and text: + time.sleep(0.5) # 避免发送过快 + return self.send_text(f"{title}\n{text}") + + return success + + def send_file(self, file_path: str, title: str = "文件") -> bool: + """ + 发送文件(通过钉钉文件上传接口) + 注意:需要企业版钉钉机器人,个人版可能不支持 + + Args: + file_path: 文件路径 + title: 文件标题 + + Returns: + bool: 是否发送成功 + """ + if not self.webhook: + logger.warning(f"[钉钉文件未发送] {title}") + return False + + if not os.path.exists(file_path): + logger.error(f"文件不存在: {file_path}") + return False + + try: + # 获取文件大小 + file_size = os.path.getsize(file_path) + file_name = os.path.basename(file_path) + + # 钉钉文件大小限制:20MB + if file_size > 20 * 1024 * 1024: + logger.error(f"文件过大: {file_size/1024/1024:.1f}MB > 20MB") + return False + + # 读取文件并转为base64 + with open(file_path, "rb") as f: + file_data = f.read() + + file_base64 = base64.b64encode(file_data).decode("utf-8") + + # 构建消息 + data = { + "msgtype": "file", + "file": { + "base64": file_base64, + "name": file_name + } + } + + url = self._gen_signed_url() + response = requests.post(url, json=data, timeout=30) + response.raise_for_status() + result = response.json() + + if result.get("errcode", -1) != 0: + # 如果文件发送失败,尝试作为图片发送 + if file_path.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')): + logger.warning(f"文件发送失败,尝试作为图片发送: {result}") + return self.send_image(file_path, title) + logger.error(f"钉钉文件发送失败: {result}") + return False + + logger.info(f"钉钉文件发送成功: {file_name}") + return True + + except Exception as e: + logger.error(f"钉钉文件发送异常: {e}") + return False + + def send_local_image_as_link(self, image_path: str, title: str = "图片", text: str = "") -> bool: + """ + 发送本地图片(转换为base64嵌入markdown) + 注意:钉钉不支持直接显示base64图片,此方法会发送图片链接文本 + + Args: + image_path: 图片文件路径 + title: 消息标题 + text: 附加文本 + + Returns: + bool: 是否发送成功 + """ + if not self.webhook: + logger.warning(f"[钉钉图片链接未发送] {title}") + return False + + if not os.path.exists(image_path): + logger.error(f"图片文件不存在: {image_path}") + return False + + try: + # 读取图片 + with open(image_path, "rb") as f: + image_data = f.read() + + # 转为base64 + image_base64 = base64.b64encode(image_data).decode("utf-8") + + # 获取图片格式 + ext = os.path.splitext(image_path)[1].lower().replace('.', '') + if ext == 'jpg': + ext = 'jpeg' + if ext not in ['png', 'jpeg', 'gif', 'bmp']: + ext = 'png' + + # 构建data URL(钉钉可能不支持直接显示,但可以作为链接) + data_url = f"data:image/{ext};base64,{image_base64[:100]}..." + + # 发送markdown消息 + markdown = f"## {title}\n\n" + if text: + markdown += f"{text}\n\n" + markdown += f"**图片**: {os.path.basename(image_path)}\n\n" + markdown += f"大小: {len(image_data)/1024:.1f}KB\n" + + return self.send_markdown(title, markdown) + + except Exception as e: + logger.error(f"发送图片链接异常: {e}") + return False + + def send_image_via_oss( + self, + image_path: str, + title: str = "策略图表", + text: str = "", + expire_days: int = 7, + ) -> bool: + """ + 上传图片到 OSS 并通过 Markdown 发送到钉钉 + 这是发送图片的推荐方式,不受 20KB 限制 + + Args: + image_path: 本地图片路径 + title: 消息标题 + text: 附加文本 + expire_days: OSS 链接有效期(天) + + Returns: + bool: 是否发送成功 + """ + if not self.webhook: + logger.warning(f"[钉钉OSS图片未发送] {title}") + return False + + if not os.path.exists(image_path): + logger.error(f"图片文件不存在: {image_path}") + return False + + try: + # 上传图片到 OSS + image_url = upload_image_to_oss(image_path, expire_days) + + if not image_url: + logger.error("图片上传到 OSS 失败") + # 尝试直接发送压缩后的图片 + return self.send_image(image_path, title) + + # 构建 Markdown 消息 + markdown = f"## {title}\n\n" + + if text: + markdown += f"{text}\n\n" + + # 添加图片(钉钉 Markdown 支持图片语法) + markdown += f"![{title}]({image_url})\n\n" + + # 添加图片信息 + file_size = os.path.getsize(image_path) / 1024 + markdown += f"---\n" + markdown += f"**图片**: {os.path.basename(image_path)} ({file_size:.1f}KB)\n" + markdown += f"**有效期**: {expire_days}天\n" + + # 发送 Markdown 消息 + return self.send_markdown(title, markdown) + + except Exception as e: + logger.error(f"发送 OSS 图片异常: {e}") + # 失败时尝试直接发送 + return self.send_image(image_path, title) + + +def send_to_all_groups( + send_func_name: str, + **kwargs, +) -> bool: + """ + 向所有已配置的钉钉群发送消息 + + Args: + send_func_name: DingTalkBot 的发送方法名,如 'send_text', 'send_markdown', 'send_image_via_oss' + **kwargs: 传递给发送方法的参数 + + Returns: + bool: 是否全部发送成功 + """ + configs = _get_all_dingtalk_configs() + if not configs: + logger.warning("没有配置任何钉钉群,消息未发送") + return False + + all_success = True + for i, cfg in enumerate(configs, 1): + bot = DingTalkBot(webhook=cfg["webhook"], secret=cfg["secret"]) + method = getattr(bot, send_func_name, None) + if method is None: + logger.error(f"DingTalkBot 没有方法: {send_func_name}") + return False + try: + success = method(**kwargs) + if success: + logger.info(f"群{i} 发送成功") + else: + logger.error(f"群{i} 发送失败") + all_success = False + except Exception as e: + logger.error(f"群{i} 发送异常: {e}") + all_success = False + return all_success + + +class NotificationManager: + """通知管理器 - 统一管理多种通知渠道""" + + def __init__(self): + self.dingtalk = DingTalkBot() + + def notify(self, message: str, title: str = "系统通知", use_markdown: bool = False): + """ + 发送通知(优先使用钉钉,失败则记录日志) + + Args: + message: 消息内容 + title: 消息标题(markdown模式使用) + use_markdown: 是否使用markdown格式 + """ + if use_markdown: + success = self.dingtalk.send_markdown(title, message) + else: + success = self.dingtalk.send_text(message) + + if not success: + # 钉钉发送失败,记录到日志 + logger.info(f"[通知] {title}: {message}") + + def notify_error(self, error_msg: str): + """发送错误通知""" + markdown = f"""## 错误告警 + +**时间**: {time.strftime('%Y-%m-%d %H:%M:%S')} + +**错误信息**: +``` +{error_msg} +``` +""" + self.notify(markdown, title="系统错误", use_markdown=True) + + def notify_signal(self, signals: list, signal_type: str = "CCI超卖"): + """ + 发送交易信号通知 + + Args: + signals: 信号列表,每项为dict包含code, name等指标 + signal_type: 信号类型名称 + """ + if not signals: + logger.info(f"[{signal_type}] 无信号") + return + + # 构建markdown表格 + if signals: + headers = signals[0].keys() + header_line = " | ".join(headers) + separator = " | ".join(["---"] * len(headers)) + + rows = [] + for s in signals: + row = " | ".join(str(v) for v in s.values()) + rows.append(row) + + table = f"{header_line}\n{separator}\n" + "\n".join(rows) + else: + table = "无" + + markdown = f"""## {signal_type}信号 + +**时间**: {time.strftime('%Y-%m-%d %H:%M:%S')} + +**筛选结果**: + +{table} + +共 {len(signals)} 个标的符合筛选条件。 +""" + self.notify(markdown, title=f"{signal_type}信号", use_markdown=True) diff --git a/core/common/oss_utils.py b/core/common/oss_utils.py new file mode 100644 index 0000000..a30976a --- /dev/null +++ b/core/common/oss_utils.py @@ -0,0 +1,172 @@ +""" +阿里云 OSS 工具模块 +用于上传文件到 OSS 并生成访问链接 +""" + +import oss2 +import os +from datetime import datetime +from typing import Optional +from loguru import logger + + +class OSSUploader: + """OSS 文件上传器""" + + def __init__( + self, + access_key_id: str = None, + access_key_secret: str = None, + bucket_name: str = None, + endpoint: str = None, + ): + """ + 初始化 OSS 上传器 + + Args: + access_key_id: 阿里云 AccessKey ID + access_key_secret: 阿里云 AccessKey Secret + bucket_name: OSS Bucket 名称 + endpoint: OSS 区域 Endpoint + """ + # 从环境变量或参数获取配置 + self.access_key_id = access_key_id or os.getenv("OSS_ACCESS_KEY_ID") + self.access_key_secret = access_key_secret or os.getenv("OSS_ACCESS_KEY_SECRET") + self.bucket_name = bucket_name or os.getenv("OSS_BUCKET_NAME", "value-investing") + self.endpoint = endpoint or os.getenv("OSS_ENDPOINT", "https://oss-cn-wulanchabu.aliyuncs.com") + + self.bucket = None + self._init_bucket() + + def _init_bucket(self): + """初始化 OSS Bucket""" + if not all([self.access_key_id, self.access_key_secret, self.bucket_name, self.endpoint]): + logger.warning("OSS 配置不完整,无法初始化") + return + + try: + auth = oss2.Auth(self.access_key_id, self.access_key_secret) + self.bucket = oss2.Bucket(auth, self.endpoint, self.bucket_name) + logger.info(f"OSS Bucket 初始化成功: {self.bucket_name}") + except Exception as e: + logger.error(f"OSS Bucket 初始化失败: {e}") + self.bucket = None + + def upload_file( + self, + local_path: str, + oss_key: str = None, + expire_seconds: int = 3600 * 24 * 7, # 默认7天有效期 + ) -> Optional[str]: + """ + 上传文件到 OSS + + Args: + local_path: 本地文件路径 + oss_key: OSS 中的目标路径,如果不指定则自动生成 + expire_seconds: 预签名URL有效期(秒) + + Returns: + str: 可访问的 URL,失败返回 None + """ + if not self.bucket: + logger.error("OSS Bucket 未初始化") + return None + + if not os.path.exists(local_path): + logger.error(f"本地文件不存在: {local_path}") + return None + + try: + # 自动生成 OSS 路径 + if not oss_key: + file_name = os.path.basename(local_path) + date_str = datetime.now().strftime("%Y%m%d") + oss_key = f"etf-signals/{date_str}/{file_name}" + + # 上传文件 + self.bucket.put_object_from_file(oss_key, local_path) + logger.info(f"文件上传成功: {local_path} -> {oss_key}") + + # 生成预签名 URL + url = self.bucket.sign_url("GET", oss_key, expire_seconds) + return url + + except Exception as e: + logger.error(f"文件上传失败: {e}") + return None + + def upload_image( + self, + image_path: str, + expire_days: int = 7, + ) -> Optional[str]: + """ + 上传图片到 OSS(专门用于钉钉通知) + + Args: + image_path: 图片文件路径 + expire_days: URL 有效期(天) + + Returns: + str: 图片访问 URL + """ + if not os.path.exists(image_path): + logger.error(f"图片文件不存在: {image_path}") + return None + + # 生成带时间戳的 OSS 路径 + file_name = os.path.basename(image_path) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + oss_key = f"etf-charts/{timestamp}_{file_name}" + + return self.upload_file(image_path, oss_key, expire_days * 24 * 3600) + + def delete_file(self, oss_key: str) -> bool: + """ + 删除 OSS 文件 + + Args: + oss_key: OSS 文件路径 + + Returns: + bool: 是否删除成功 + """ + if not self.bucket: + logger.error("OSS Bucket 未初始化") + return False + + try: + self.bucket.delete_object(oss_key) + logger.info(f"文件删除成功: {oss_key}") + return True + except Exception as e: + logger.error(f"文件删除失败: {e}") + return False + + +# 全局单例 +_oss_uploader: Optional[OSSUploader] = None + + +def get_oss_uploader() -> OSSUploader: + """获取 OSS 上传器单例""" + global _oss_uploader + if _oss_uploader is None: + _oss_uploader = OSSUploader() + return _oss_uploader + + +def upload_image_to_oss(image_path: str, expire_days: int = 7) -> Optional[str]: + """ + 便捷函数:上传图片到 OSS + + Args: + image_path: 图片路径 + expire_days: URL 有效期(天) + + Returns: + str: 图片访问 URL + """ + uploader = get_oss_uploader() + return uploader.upload_image(image_path, expire_days) diff --git a/rotation/daily_scheduler.py b/rotation/daily_scheduler.py index 98d2d70..967bb29 100644 --- a/rotation/daily_scheduler.py +++ b/rotation/daily_scheduler.py @@ -30,10 +30,6 @@ from pathlib import Path project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) -# 添加归档模块路径(使用原有的 notify 和 oss_utils) -archive_path = project_root / 'archive' / 'legacy_core' -sys.path.insert(0, str(archive_path)) - # 加载环境变量 from dotenv import load_dotenv load_dotenv() @@ -154,7 +150,7 @@ def run_strategy(config_path: str = "strategies/rotation/config.yaml") -> dict: return {"success": False, "error": str(e)} -def run_simple_rotation(config_path: str = None) -> dict: +def run_simple_rotation(config_path: str = None, no_detail: bool = False, no_report: bool = False) -> dict: """ 执行 simple_rotation.py 策略回测并生成报告 @@ -173,6 +169,10 @@ def run_simple_rotation(config_path: str = None) -> dict: ] if config_path: cmd.extend(["--config", config_path]) + if no_detail: + cmd.append("--no-detail") + if no_report: + cmd.append("--no-report") logger.info(f"执行命令: {' '.join(cmd)}") @@ -284,7 +284,9 @@ def setup_schedule(target_time: str = "15:30", def daily_task(config_path: str = "strategies/rotation/config.yaml", strategy: str = "all", - simple_config: str = None): + simple_config: str = None, + no_detail: bool = False, + no_report: bool = False): """ 每日任务主流程 @@ -305,7 +307,7 @@ def daily_task(config_path: str = "strategies/rotation/config.yaml", # 2. 执行 Simple Rotation 策略 if strategy in ("simple", "all"): - result = run_simple_rotation(simple_config) + result = run_simple_rotation(simple_config, no_detail=no_detail, no_report=no_report) if result["success"]: if result.get("chart_path"): send_report_to_dingtalk( @@ -384,6 +386,16 @@ def main(): action='store_true', help='非后台模式:执行一次后进入定时循环(测试用)' ) + parser.add_argument( + '--no-detail', + action='store_true', + help='跳过 detail JSON 导出(加速日常运行)' + ) + parser.add_argument( + '--no-report', + action='store_true', + help='跳过 report PNG 生成' + ) args = parser.parse_args() @@ -392,12 +404,12 @@ def main(): if args.now: # 立即执行一次并退出 - daily_task(args.config, args.strategy, args.simple_config) + daily_task(args.config, args.strategy, args.simple_config, args.no_detail, args.no_report) elif args.no_daemon: # 非后台模式:执行一次后进入定时循环 setup_schedule(args.time, args.config, args.strategy, args.simple_config) logger.info("执行一次测试...") - daily_task(args.config, args.strategy, args.simple_config) + daily_task(args.config, args.strategy, args.simple_config, args.no_detail, args.no_report) logger.info("测试完成,启动定时任务循环(Ctrl+C 停止)...") run_scheduler_loop() else: diff --git a/rotation/simple_rotation.py b/rotation/simple_rotation.py index 6a21feb..cf9f415 100644 --- a/rotation/simple_rotation.py +++ b/rotation/simple_rotation.py @@ -1013,8 +1013,14 @@ class SimpleRotationStrategy: # Export # ============================================================ - def export_results(self, output_dir: str = None): - """Export backtest results to CSV and JSON (V2-compatible detail format)""" + def export_results(self, output_dir: str = None, detail: bool = True): + """Export backtest results to CSV and JSON (V2-compatible detail format) + + Args: + output_dir: Output directory path. + detail: If True, export detail JSON (large file for backtest_viewer). + Set to False for daily runs to skip expensive detail generation. + """ if not self.daily_records: print(" x No results to export") return @@ -1036,81 +1042,82 @@ class SimpleRotationStrategy: df[['date', 'holdings', 'is_rebalance', 'added', 'removed']].to_csv(sig_path, index=False) print(f" + Signals: {sig_path}") - # Detail JSON (V2-compatible format) - detail_path = output_dir / 'simple_rotation_detail.json' - days_out = [] - # Track entry_info across days for asset detail reconstruction - tracked_entry: Dict[str, dict] = {} - prev_holdings = [] + # Detail JSON (V2-compatible format, optional — large file for backtest_viewer) + if detail: + detail_path = output_dir / 'simple_rotation_detail.json' + days_out = [] + # Track entry_info across days for asset detail reconstruction + tracked_entry: Dict[str, dict] = {} + prev_holdings = [] - # Build date index map for signal_date lookup (calendar T-1, not A-share prev trading day) - date_list = [pd.Timestamp(rec['date']) for rec in self.daily_records] - date_to_signal_date = {d: d - timedelta(days=1) for d in date_list} + # Build date index map for signal_date lookup (calendar T-1, not A-share prev trading day) + date_list = [pd.Timestamp(rec['date']) for rec in self.daily_records] + date_to_signal_date = {d: d - timedelta(days=1) for d in date_list} - for rec in self.daily_records: - date = pd.Timestamp(rec['date']) - signal_date = date_to_signal_date[date] # T-1 for signal - holdings = rec['holdings'] - added = set(holdings) - set(prev_holdings) - removed = set(prev_holdings) - set(holdings) + for rec in self.daily_records: + date = pd.Timestamp(rec['date']) + signal_date = date_to_signal_date[date] # T-1 for signal + holdings = rec['holdings'] + added = set(holdings) - set(prev_holdings) + removed = set(prev_holdings) - set(holdings) - # Update entry tracking (consistent with run() logic) - for code in added: - trade_code = self.signal_to_trade.get(code, code) - etf_prices = self._get_etf_prices(trade_code, date) - # Entry price = actual buy price at T's open - entry_etf = etf_prices['open'] if etf_prices else None - # Index close at T-1 (signal data used for decision) - idx_close = self._get_index_close(code, signal_date) - tracked_entry[code] = { - 'entry_date': date.strftime('%Y-%m-%d'), - 'entry_price_etf': entry_etf, - 'entry_price_idx': idx_close, - } - for code in removed: - tracked_entry.pop(code, None) + # Update entry tracking (consistent with run() logic) + for code in added: + trade_code = self.signal_to_trade.get(code, code) + etf_prices = self._get_etf_prices(trade_code, date) + # Entry price = actual buy price at T's open + entry_etf = etf_prices['open'] if etf_prices else None + # Index close at T-1 (signal data used for decision) + idx_close = self._get_index_close(code, signal_date) + tracked_entry[code] = { + 'entry_date': date.strftime('%Y-%m-%d'), + 'entry_price_etf': entry_etf, + 'entry_price_idx': idx_close, + } + for code in removed: + tracked_entry.pop(code, None) - # Build signals dict: {code: 1} for selected holdings - signals = {c: 1 for c in holdings} + # Build signals dict: {code: 1} for selected holdings + signals = {c: 1 for c in holdings} - # Build per-asset details - assets = self._build_day_assets(rec, date, tracked_entry) + # Build per-asset details + assets = self._build_day_assets(rec, date, tracked_entry) - days_out.append({ - 'date': rec['date'], - 'nav': rec['nav'], - 'daily_return': rec['daily_return'], - 'is_rebalance': rec['is_rebalance'], - 'signals': signals, - 'holdings': holdings, - 'added': rec['added'], - 'removed': rec['removed'], - 'assets': assets, - }) - prev_holdings = holdings + days_out.append({ + 'date': rec['date'], + 'nav': rec['nav'], + 'daily_return': rec['daily_return'], + 'is_rebalance': rec['is_rebalance'], + 'signals': signals, + 'holdings': holdings, + 'added': rec['added'], + 'removed': rec['removed'], + 'assets': assets, + }) + prev_holdings = holdings - detail = { - 'meta': { - 'mode': 'Simple: Daily Iteration', - 'start_date': self.config.backtest.start_date, - 'end_date': self.daily_records[-1]['date'] if self.daily_records else self.config.backtest.end_date or 'now', - 'total_days': len(self.daily_records), - 'select_num': self.select_num, - 'n_days': self.n_days, - 'trade_cost': self.trade_cost, - 'bond_threshold': { - 'enabled': self.use_dynamic_threshold, - 'bond_code': self.bond_code, - 'ratio': self.bond_ratio, + detail_data = { + 'meta': { + 'mode': 'Simple: Daily Iteration', + 'start_date': self.config.backtest.start_date, + 'end_date': self.daily_records[-1]['date'] if self.daily_records else self.config.backtest.end_date or 'now', + 'total_days': len(self.daily_records), + 'select_num': self.select_num, + 'n_days': self.n_days, + 'trade_cost': self.trade_cost, + 'bond_threshold': { + 'enabled': self.use_dynamic_threshold, + 'bond_code': self.bond_code, + 'ratio': self.bond_ratio, + }, + 'codes': self._build_meta_codes(), }, - 'codes': self._build_meta_codes(), - }, - 'days': days_out, - } - _sanitize_json(detail) - with open(detail_path, 'w', encoding='utf-8') as f: - json.dump(detail, f, ensure_ascii=False, indent=2) - print(f" + Detail: {detail_path} ({len(days_out)} days)") + 'days': days_out, + } + _sanitize_json(detail_data) + with open(detail_path, 'w', encoding='utf-8') as f: + json.dump(detail_data, f, ensure_ascii=False, indent=2) + print(f" + Detail: {detail_path} ({len(days_out)} days)") # Metrics JSON metrics = self._compute_metrics(sum(1 for r in self.daily_records if r['is_rebalance'])) @@ -1275,6 +1282,7 @@ class SimpleRotationStrategy: 'premium': premium, 'action': action, 'entry_date': entry_date, 'entry_price': entry_price, 'holding_days': holding_days, 'pnl': pnl, + 'exit_date': None, 'exit_price': None, }) # Build exit positions: ONLY when the last day is a rebalance day @@ -1301,6 +1309,7 @@ class SimpleRotationStrategy: exit_entry_price = None exit_holding_days = 0 exit_pnl = None + sell_price = None for rec in reversed(self.daily_records[:last_idx]): if code in rec['holdings']: exit_entry_date = pd.Timestamp(rec['date']) @@ -1327,6 +1336,7 @@ class SimpleRotationStrategy: 'premium': premium, 'action': '调出', 'entry_date': exit_entry_date, 'entry_price': exit_entry_price, 'holding_days': exit_holding_days, 'pnl': exit_pnl, + 'exit_date': last_date, 'exit_price': sell_price, }) # Build unselected (未入选) positions: all signal codes not held and not exited @@ -1351,6 +1361,7 @@ class SimpleRotationStrategy: 'premium': premium, 'action': '未入选', 'entry_date': None, 'entry_price': None, 'holding_days': 0, 'pnl': None, + 'exit_date': None, 'exit_price': None, }) # ==================== Plot ==================== @@ -1374,7 +1385,8 @@ class SimpleRotationStrategy: rank_map = {c: i + 1 for i, c in enumerate(ranked_all)} col_labels = ["排名", "标的名称", "市场", "指数代码", "ETF代码", "仓位", "得分", - "指数最新价", "ETF收盘价", "溢价率", "状态", "进场日期", "持有天数", "盈亏"] + "指数最新价", "ETF收盘价", "溢价率", "状态", + "进场日期", "持有天数", "盈亏", "退场日期", "退场价格"] table_data = [] row_actions = [] # track action for coloring @@ -1391,9 +1403,12 @@ class SimpleRotationStrategy: pnl_s = f"{p['pnl']:+.2%}" if p['pnl'] is not None else "—" weight_s = f"{p['weight']:.0%}" if p['weight'] > 0 else "—" market = code_config.get(p['code'], {}).get('market', '—') + exit_date_s = p['exit_date'].strftime('%Y-%m-%d') if p.get('exit_date') else "—" + exit_price_s = f"{p['exit_price']:.3f}" if p.get('exit_price') else "—" table_data.append([ rank, p['name'], market, p['code'], p['etf'], weight_s, - score_s, idx_s, etf_s, prem_s, p['action'], entry_date_s, days_s, pnl_s + score_s, idx_s, etf_s, prem_s, p['action'], + entry_date_s, days_s, pnl_s, exit_date_s, exit_price_s ]) row_actions.append(p['action']) @@ -1571,12 +1586,21 @@ class SimpleRotationStrategy: # ============================================================ if __name__ == "__main__": + import argparse if 'FLASK_API_URL' not in os.environ: os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz' + parser = argparse.ArgumentParser(description='Simple Rotation Strategy Backtest') + parser.add_argument('--no-detail', action='store_true', + help='Skip detail JSON export (faster, for daily runs)') + parser.add_argument('--no-report', action='store_true', + help='Skip report PNG generation') + args = parser.parse_args() + strategy = SimpleRotationStrategy() result = strategy.run() if result: - strategy.export_results() - strategy.generate_report() + strategy.export_results(detail=not args.no_detail) + if not args.no_report: + strategy.generate_report()