From 2afe976a31af3abd5b58ae811fbd7f873ce0daf9 Mon Sep 17 00:00:00 2001 From: aszerW Date: Sun, 19 Apr 2026 00:17:38 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=90=AF=E7=94=A8=20Apple=20Silicon=20?= =?UTF-8?q?MPS=20=E5=8A=A0=E9=80=9F=20+=20=E5=85=BC=E5=AE=B9=20transformer?= =?UTF-8?q?s=205.x=20+=20=E6=9C=AC=E5=9C=B0=E8=BF=90=E8=A1=8C=E9=85=8D?= =?UTF-8?q?=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit nvidia_router.py 变更: - device 默认值从 'cpu' 改为 'auto',自动检测 MPS/CUDA/CPU - AutoConfig 替换为 DebertaV2Config + 手动解析 config.json (nvidia/prompt-task-and-complexity-classifier 的 config.json 无 model_type, transformers 5.x 的 AutoConfig 会直接报错) - MPS 设备自动转换 float16,修复 MPS 矩阵乘法数据类型冲突崩溃 (MPS NDArrayMatrixMultiplication 要求 dst/accumulator 同类型) - 日志增加设备和精度信息输出 docker-compose.yml 变更: - 端口映射改为 402:8000 (本地开发端口) - volume 从 named volume 改为 ./data 本地目录映射 - API Key 继续使用环境变量引用,本次提交未改动该行 (密钥存 .env 文件,已在 .gitignore 中) 测试环境: Mac Mini M4 Pro / 64GB / macOS 15.3.1 运行方式: .venv/bin/python -m uvicorn main:app --host 0.0.0.0 --port 402 测试结果: - MPS + FP16 分类器正常工作,稳态路由延迟 ~53ms - NVIDIA 3-tier 路由决策正确 (simple/medium/complex) - OpenAI 兼容 API 正常响应,DashScope Qwen 模型调用正常 --- docker-compose.yml | 4 ++-- nvidia_router.py | 34 +++++++++++++++++++++++++++++----- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index 9b4ab67..311313d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -3,11 +3,11 @@ services: build: . 
container_name: llm-compass ports: - - "8000:8000" + - "402:8000" environment: - DASHSCOPE_API_KEY=${DASHSCOPE_API_KEY} volumes: - - compass-data:/app/data + - ./data:/app/data restart: unless-stopped healthcheck: test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"] diff --git a/nvidia_router.py b/nvidia_router.py index 7548a84..382385d 100644 --- a/nvidia_router.py +++ b/nvidia_router.py @@ -10,11 +10,12 @@ NVIDIA Prompt Task & Complexity Classifier Router import torch import torch.nn as nn -from transformers import AutoTokenizer, DebertaV2Model, AutoConfig +from transformers import AutoTokenizer, DebertaV2Model, DebertaV2Config from safetensors.torch import load_file from huggingface_hub import hf_hub_download from typing import Dict, Optional import logging +import json logger = logging.getLogger(__name__) @@ -115,7 +116,16 @@ class NvidiaComplexityRouter: # Creativity 映射 CREATIVITY_MAP = {0: "High", 1: "Low", 2: "No"} - def __init__(self, device: str = "cpu"): + def __init__(self, device: str = "auto"): + if device == "auto": + if torch.backends.mps.is_available(): + device = "mps" + logger.info("MPS (Metal GPU) detected, using MPS acceleration") + elif torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + logger.info("No GPU detected, using CPU") self.device = device self.tokenizer = None self.model = None @@ -129,8 +139,18 @@ class NvidiaComplexityRouter: logger.info(f"Loading NVIDIA classifier: {self.MODEL_NAME}") - # 1. 加载 config - self.config = AutoConfig.from_pretrained(self.MODEL_NAME) + # 1. 
手动加载自定义 config.json(该模型无 model_type,AutoConfig 不兼容) + config_path = hf_hub_download(self.MODEL_NAME, "config.json") + with open(config_path, "r") as f: + custom_config = json.load(f) + + # 构建 backbone 的 DeBERTa config(从 base_model 加载) + base_model = custom_config.get("base_model", "microsoft/DeBERTa-v3-base") + self.config = DebertaV2Config.from_pretrained(base_model) + # 保存自定义分类头参数 + self.config.target_sizes = custom_config["target_sizes"] + self.config.fc_dropout = custom_config.get("fc_dropout", 0.2) + self.config.base_model = base_model # 2. 加载 tokenizer (slow模式,兼容性好) self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME, use_fast=False) @@ -143,9 +163,13 @@ class NvidiaComplexityRouter: self.model.load_state_dict(state_dict, strict=False) self.model.to(self.device) + # MPS 需要 float16 以避免矩阵乘法数据类型冲突 + if self.device == "mps": + self.model.half() self.model.eval() self._initialized = True - logger.info("NVIDIA classifier loaded successfully") + dtype = "float16" if self.device == "mps" else "float32" + logger.info(f"NVIDIA classifier loaded successfully on {self.device} ({dtype})") def predict(self, query: str) -> Dict: """