feat: 启用 Apple Silicon MPS 加速 + 兼容 transformers 5.x + 本地运行配置

nvidia_router.py 变更:
- device 默认值从 'cpu' 改为 'auto',自动检测 MPS/CUDA/CPU
- AutoConfig 替换为 DebertaV2Config + 手动解析 config.json
  (nvidia/prompt-task-and-complexity-classifier 的 config.json 无 model_type,
   transformers 5.x 的 AutoConfig 会直接报错)
- MPS 设备自动转换 float16,修复 MPS 矩阵乘法数据类型冲突崩溃
  (MPS NDArrayMatrixMultiplication 要求 dst/accumulator 同类型)
- 日志增加设备和精度信息输出

docker-compose.yml 变更:
- 端口映射改为 402:8000 (本地开发端口)
- volume 从 named volume 改为 ./data 本地目录映射
- API Key 改回环境变量引用 (密钥存 .env 文件,已在 .gitignore 中)

测试环境: Mac Mini M4 Pro / 64GB / macOS 15.3.1
运行方式: .venv/bin/python -m uvicorn main:app --host 0.0.0.0 --port 402
测试结果:
- MPS + FP16 分类器正常工作,稳态路由延迟 ~53ms
- NVIDIA 3-tier 路由决策正确 (simple/medium/complex)
- OpenAI 兼容 API 正常响应,DashScope Qwen 模型调用正常
This commit is contained in:
2026-04-19 00:17:38 +08:00
parent 4c439d2d7e
commit 2afe976a31
2 changed files with 31 additions and 7 deletions

View File

@@ -3,11 +3,11 @@ services:
build: . build: .
container_name: llm-compass container_name: llm-compass
ports: ports:
- "8000:8000" - "402:8000"
environment: environment:
- DASHSCOPE_API_KEY=${DASHSCOPE_API_KEY} - DASHSCOPE_API_KEY=${DASHSCOPE_API_KEY}
volumes: volumes:
- compass-data:/app/data - ./data:/app/data
restart: unless-stopped restart: unless-stopped
healthcheck: healthcheck:
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"] test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]

View File

@@ -10,11 +10,12 @@ NVIDIA Prompt Task & Complexity Classifier Router
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import AutoTokenizer, DebertaV2Model, AutoConfig from transformers import AutoTokenizer, DebertaV2Model, DebertaV2Config
from safetensors.torch import load_file from safetensors.torch import load_file
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from typing import Dict, Optional from typing import Dict, Optional
import logging import logging
import json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -115,7 +116,16 @@ class NvidiaComplexityRouter:
# Creativity 映射 # Creativity 映射
CREATIVITY_MAP = {0: "High", 1: "Low", 2: "No"} CREATIVITY_MAP = {0: "High", 1: "Low", 2: "No"}
def __init__(self, device: str = "cpu"): def __init__(self, device: str = "auto"):
if device == "auto":
if torch.backends.mps.is_available():
device = "mps"
logger.info("MPS (Metal GPU) detected, using MPS acceleration")
elif torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
logger.info("No GPU detected, using CPU")
self.device = device self.device = device
self.tokenizer = None self.tokenizer = None
self.model = None self.model = None
@@ -129,8 +139,18 @@ class NvidiaComplexityRouter:
logger.info(f"Loading NVIDIA classifier: {self.MODEL_NAME}") logger.info(f"Loading NVIDIA classifier: {self.MODEL_NAME}")
# 1. 加载 config # 1. 手动加载自定义 config.json (该模型无 model_type,AutoConfig 不兼容)
self.config = AutoConfig.from_pretrained(self.MODEL_NAME) config_path = hf_hub_download(self.MODEL_NAME, "config.json")
with open(config_path, "r") as f:
custom_config = json.load(f)
# 构建 backbone 的 DeBERTa config (从 base_model 加载)
base_model = custom_config.get("base_model", "microsoft/DeBERTa-v3-base")
self.config = DebertaV2Config.from_pretrained(base_model)
# 保存自定义分类头参数
self.config.target_sizes = custom_config["target_sizes"]
self.config.fc_dropout = custom_config.get("fc_dropout", 0.2)
self.config.base_model = base_model
# 2. 加载 tokenizer (slow模式兼容性好) # 2. 加载 tokenizer (slow模式兼容性好)
self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME, use_fast=False) self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME, use_fast=False)
@@ -143,9 +163,13 @@ class NvidiaComplexityRouter:
self.model.load_state_dict(state_dict, strict=False) self.model.load_state_dict(state_dict, strict=False)
self.model.to(self.device) self.model.to(self.device)
# MPS 需要 float16 以避免矩阵乘法数据类型冲突
if self.device == "mps":
self.model.half()
self.model.eval() self.model.eval()
self._initialized = True self._initialized = True
logger.info("NVIDIA classifier loaded successfully") dtype = "float16" if self.device == "mps" else "float32"
logger.info(f"NVIDIA classifier loaded successfully on {self.device} ({dtype})")
def predict(self, query: str) -> Dict: def predict(self, query: str) -> Dict:
""" """