Compare commits: 4c439d2d7e...00083627f7 (7 commits)

| SHA1 |
|---|
| 00083627f7 |
| d7155e98c3 |
| 508118cc50 |
| 943fc9dcc0 |
| a6a471c5c4 |
| 72345871c6 |
| 2afe976a31 |
README.md (normal file, 281 lines changed)
@@ -0,0 +1,281 @@
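# LLM Compass

An intelligent LLM routing service. It uses NVIDIA's multi-head classifier with Apple Silicon MPS acceleration to pick the best-fit model for each query automatically, balancing quality against cost.

---

## Background

When LLMs are used at scale, queries of different complexity suit models of different sizes:
- A simple greeting only needs `qwen-flash`: low cost, low latency
- Code generation needs `qwen-plus` to ensure quality
- Only complex analysis tasks justify calling `qwen-max`

Choosing a model by hand is inefficient, while sending everything to the strongest model wastes money. **LLM Compass** aims to pick the "just right" model for every query automatically.

Inspired by the three-layer routing architecture of [tx402.ai](https://tx402.ai), this project implements a similar capability with NVIDIA's open-source multi-head classifier.

---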

## Core Method

### NVIDIA Multi-Head Classifier

Uses [nvidia/prompt-task-and-complexity-classifier](https://huggingface.co/nvidia/prompt-task-and-complexity-classifier) (184M parameters, DeBERTa-v3-base architecture):

```
User query → DeBERTa backbone → 8 classification heads → combined score → 3-tier routing
                                          ↓
                                task_type (12 classes)
                                creativity (3 classes)
                                reasoning (2 classes)
                                domain_knowledge (4 classes)
                                contextual_knowledge
                                number_of_few_shots
                                no_label_reason
                                constraint_ct
```

### Complexity Score Formula

```python
score = (
    0.4 * domain_knowledge +  # High=1.0, Medium=0.6, Low=0.3, No=0.0
    0.3 * reasoning +         # Yes=1.0, No=0.0
    0.2 * creativity +        # High=1.0, Low=0.4, No=0.0
    0.1 * task_type           # Code=0.8, QA=0.5, Chatbot=0.2, ...
)

# 3-tier routing
#   score < 0.35          -> simple  -> qwen-flash
#   0.35 <= score < 0.65  -> medium  -> qwen-plus
#   score >= 0.65         -> complex -> qwen-max
```
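
The formula and thresholds can be turned into a small runnable sketch. The weight tables below are copied from the comments in the formula; the function names, the short task-type keys (`"Code"`, `"QA"`, `"Chatbot"`), and `DEFAULT_TASK_WEIGHT` for unlisted task types are assumptions made for illustration and may not match `nvidia_router.py`.

```python
# Per-label values taken from the comments in the README formula.
DOMAIN_KNOWLEDGE = {"High": 1.0, "Medium": 0.6, "Low": 0.3, "No": 0.0}
CREATIVITY = {"High": 1.0, "Low": 0.4, "No": 0.0}
# Only three task-type values are listed in the README; the rest are elided.
TASK_TYPE = {"Code": 0.8, "QA": 0.5, "Chatbot": 0.2}
DEFAULT_TASK_WEIGHT = 0.5  # assumption: fallback for task types not listed


def complexity_score(domain_knowledge: str, reasoning: bool,
                     creativity: str, task_type: str) -> float:
    """Weighted sum of the four signals used by the README formula."""
    return (
        0.4 * DOMAIN_KNOWLEDGE[domain_knowledge]
        + 0.3 * (1.0 if reasoning else 0.0)
        + 0.2 * CREATIVITY[creativity]
        + 0.1 * TASK_TYPE.get(task_type, DEFAULT_TASK_WEIGHT)
    )


def route(score: float) -> tuple[str, str]:
    """Map a score to (tier, model) using the 0.35 / 0.65 thresholds."""
    if score < 0.35:
        return "simple", "qwen-flash"
    if score < 0.65:
        return "medium", "qwen-plus"
    return "complex", "qwen-max"


# Low domain knowledge, no reasoning, no creativity, QA task:
# 0.4 * 0.3 + 0.1 * 0.5 = 0.17
score = complexity_score("Low", False, "No", "QA")
print(round(score, 2), route(score))
```

If the classifier's "Open QA" label is given the QA value, these inputs reproduce the 0.17 score that appears in the sample response and routing table.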

---

## Architecture

```
┌──────────────────────────────────────────────────────────┐
│                       LLM Compass                        │
├──────────────────────────────────────────────────────────┤
│  API Layer: FastAPI (OpenAI-compatible)                  │
│  ├─ POST /v1/chat/completions (stream / non-stream)      │
│  ├─ GET  /v1/models                                      │
│  ├─ GET  /stats                                          │
│  └─ GET  /docs (Swagger UI)                              │
├──────────────────────────────────────────────────────────┤
│  Routing Layer: NVIDIA Multi-Head Classifier (184M)      │
│  ├─ 8-dimension query analysis                           │
│  ├─ Combined complexity score                            │
│  └─ 3-tier smart routing                                 │
├──────────────────────────────────────────────────────────┤
│  LLM Backend: LiteLLM (unified multi-provider interface) │
│  ├─ DashScope (Qwen)                                     │
│  ├─ OpenAI (GPT)                                         │
│  ├─ Anthropic (Claude)                                   │
│  └─ Google (Gemini)                                      │
└──────────────────────────────────────────────────────────┘
```

### Apple Silicon Optimization (M4 Pro)

- **MPS acceleration**: uses the Metal Performance Shaders GPU backend
- **FP16 inference**: half-precision floats, which avoids dtype conflicts in MPS matrix multiplication
- **Unified memory**: 64 GB of unified memory on the M4 Pro, so model loading needs no copy

---

## Results

### Routing Accuracy

| Example query | Tier | Score | Routed model |
|---------|------|-------|---------|
| "你好" ("Hello") | simple | 0.17 | qwen-flash |
| "1+1等于几" ("What is 1+1?") | simple | 0.17 | qwen-flash |
| "Write quicksort in Python" | medium | 0.45 | qwen-plus |
| "分析深度学习的注意力机制原理" ("Analyze how the attention mechanism in deep learning works") | medium | 0.47 | qwen-plus |
| "请详细分析量子计算对密码学的影响" ("Analyze in detail the impact of quantum computing on cryptography") | complex | 0.72 | qwen-max |

### Cost Optimization

Based on actual call statistics:
- **Total cost of 16 calls**: $0.011
- **Model distribution**: qwen-flash 87.5% (14 calls), qwen-plus 12.5% (2 calls)
- **Task type**: mostly Open QA (13 calls)
- **Complexity distribution**: simple 73%, medium 13%

---
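
## Routing Latency

| Environment | First load | Steady-state latency | Notes |
|------|---------|---------|------|
| **M4 Pro MPS + FP16** | ~2s (model load) | **~60-90ms** | Current production environment |
| x86 CPU | ~3s | ~100-150ms | Docker container |
| NVIDIA official report | - | 5-15ms | Data-center CPU |

**Notes**:
- The first load includes the model download and MPS kernel compilation; later requests do not reload the model
- Steady-state latency is about 60-90ms, of which classifier inference takes ~53ms and the rest is FastAPI overhead
- Relative to the LLM call itself (typically 2-10s), routing overhead is small: about 0.6% (60ms on a 10s call) to 4.5% (90ms on a 2s call)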
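
A quick way to check how routing latency compares with the LLM call is to divide the two ranges quoted above. This is plain arithmetic on the README's own figures; the function name is made up for the example.

```python
def routing_overhead_share(routing_ms: float, llm_ms: float) -> float:
    """Routing latency expressed as a fraction of the LLM call latency."""
    return routing_ms / llm_ms


# Best case: fastest routing (60 ms) against the slowest LLM call (10 s).
best = routing_overhead_share(60, 10_000)
# Worst case: slowest routing (90 ms) against the fastest LLM call (2 s).
worst = routing_overhead_share(90, 2_000)

print(f"{best:.1%} to {worst:.1%}")
```

The share drops below 2% once the LLM call takes longer than about 4.5 s at 90 ms of routing latency.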

---

## Quick Start

### Prerequisites

- Python 3.12+
- macOS (Apple Silicon recommended, for MPS acceleration)
- DashScope API key (Alibaba Cloud Qwen)

### Installation

```bash
# 1. Clone the project
git clone <repo-url>
cd llm-compass

# 2. Create a virtual environment
python3 -m venv .venv
source .venv/bin/activate

# 3. Install dependencies
pip install -r requirements.txt

# 4. Configure the API key
cp .env.example .env
# Edit .env and fill in DASHSCOPE_API_KEY
```

### Start the Service

```bash
./start.sh          # default port 8402
./start.sh 9000     # custom port
```

### Test

```bash
# Health check
curl http://localhost:8402/health

# API test (automatic routing)
curl http://localhost:8402/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"messages":[{"role":"user","content":"Hello"}]}'

# Swagger UI
# Open http://localhost:8402/docs
```

---

## API Usage

### OpenAI-Compatible Interface

```python
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8402/v1",
    api_key="not-needed",  # optional for the server; the SDK still needs a value
)

# Automatic routing (recommended)
# The OpenAI SDK requires `model`; an empty string is assumed to trigger
# automatic routing, matching the "" entry in the request schema examples.
response = client.chat.completions.create(
    model="",
    messages=[{"role": "user", "content": "Explain quantum computing"}],
)
print(response.choices[0].message.content)
print(response.routing)  # routing details

# Pin a specific model
response = client.chat.completions.create(
    model="qwen-plus",
    messages=[{"role": "user", "content": "Write a sorting algorithm"}],
)
```

### Response Format

```json
{
  "id": "chatcmpl-xxx",
  "model": "qwen-flash",
  "choices": [{"message": {"content": "..."}}],
  "usage": {"prompt_tokens": 13, "completion_tokens": 25, "total_tokens": 38},
  "routing": {
    "method": "nvidia_classifier",
    "tier": "simple",
    "complexity_score": 0.17,
    "task_type": "Open QA",
    "domain_knowledge": "Low",
    "reasoning": false,
    "creativity": "No",
    "routing_latency_ms": 63.27
  }
}
```
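
As a minimal sketch of how a client might read the `routing` block, the snippet below parses a trimmed copy of the sample response with the standard library. `describe_routing` is a hypothetical helper, and the fallback for a response without a `routing` key is an assumption, since the README does not say whether the key is always present.

```python
import json

# Trimmed copy of the sample response shown above.
SAMPLE = """
{
  "id": "chatcmpl-xxx",
  "model": "qwen-flash",
  "routing": {
    "method": "nvidia_classifier",
    "tier": "simple",
    "complexity_score": 0.17,
    "task_type": "Open QA",
    "routing_latency_ms": 63.27
  }
}
"""


def describe_routing(response: dict) -> str:
    """Return a one-line summary of the routing decision in a response."""
    routing = response.get("routing")
    if not routing:
        # Assumption: the key may be absent, e.g. when no routing took place.
        return f"{response.get('model')} (no routing metadata)"
    return (
        f"{response['model']} via {routing['method']}: "
        f"tier={routing['tier']}, "
        f"score={routing['complexity_score']:.2f}, "
        f"latency={routing['routing_latency_ms']:.0f} ms"
    )


print(describe_routing(json.loads(SAMPLE)))
```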

---

## Tech Stack

- **Web framework**: FastAPI + Uvicorn
- **Routing model**: NVIDIA Multi-Head Classifier (DeBERTa-v3-base)
- **LLM calls**: LiteLLM (unified multi-provider interface)
- **GPU acceleration**: PyTorch MPS (Metal Performance Shaders)
- **Token counting**: tiktoken

---

## Project Structure

```
llm-compass/
├── main.py                 # FastAPI main service
├── nvidia_router.py        # NVIDIA classifier implementation
├── config.py               # Model configuration and routing thresholds
├── start.sh                # Startup script
├── .env                    # Environment variables (not committed to Git)
├── .env.example            # Environment variable template
├── requirements.txt        # Python dependencies
├── Dockerfile              # Docker build file
├── docker-compose.yml      # Docker Compose configuration
├── data/                   # Call history logs (created automatically)
└── docs/                   # Technical documentation
    └── llm-router-open-source-research.md
```

---

## Known Limitations

1. **Routing scope**: only 3 Qwen models are supported for now; the pool can be extended to 40+ models
2. **Online learning**: no online-learning mechanism such as multi-armed bandits
3. **Semantic cache**: query caching is not implemented
4. **Docker MPS**: Docker containers on macOS cannot use the Metal GPU, so the service must run natively

---

## Roadmap

- [ ] Layer 2: multi-armed bandit online learning (Thompson Sampling)
- [ ] Layer 3: semantic cache + batch optimization
- [ ] Expand the model pool to 40+ models
- [ ] Fine-tune the NVIDIA classifier on business data

---

## License

Apache 2.0

---

**LLM Compass** - helping every query find its best model.
docker-compose.yml
@@ -3,11 +3,11 @@ services:
     build: .
     container_name: llm-compass
     ports:
-      - "8000:8000"
+      - "402:8000"
     environment:
       - DASHSCOPE_API_KEY=${DASHSCOPE_API_KEY}
     volumes:
-      - compass-data:/app/data
+      - ./data:/app/data
     restart: unless-stopped
     healthcheck:
       test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
main.py (73 lines changed)
@@ -58,23 +58,33 @@ _load_history()


 # ── OpenAI-compatible request/response models ────────────────────────────────
+from pydantic import BaseModel, Field
 class ChatMessage(BaseModel):
-    role: str
-    content: Optional[str] = None
-    name: Optional[str] = None
+    role: str = Field(..., description="Role: system, user, assistant", example="user")
+    content: Optional[str] = Field(None, description="Message content", example="Hello, introduce yourself")
+    name: Optional[str] = Field(None, description="Optional name")

 class ChatCompletionRequest(BaseModel):
-    model: Optional[str] = None
-    messages: List[ChatMessage]
-    temperature: Optional[float] = 0.7
-    max_tokens: Optional[int] = None
-    stream: Optional[bool] = False
-    top_p: Optional[float] = 1.0
-    n: Optional[int] = 1
-    stop: Optional[Any] = None
-    presence_penalty: Optional[float] = 0.0
-    frequency_penalty: Optional[float] = 0.0
-    user: Optional[str] = None
+    model: Optional[str] = Field(
+        None,
+        description="Model name (leave empty to route automatically with the NVIDIA classifier)",
+        example="qwen-plus",
+        json_schema_extra={"examples": ["", "qwen-flash", "qwen-plus", "qwen-max"]}
+    )
+    messages: List[ChatMessage] = Field(
+        ...,
+        description="List of conversation messages",
+        example=[{"role": "user", "content": "Hello, introduce yourself"}]
+    )
+    temperature: Optional[float] = Field(0.7, ge=0, le=2, description="Randomness (0-2)")
+    max_tokens: Optional[int] = Field(None, description="Maximum number of generated tokens (leave empty to use the model default)", example=2048)
+    stream: Optional[bool] = Field(False, description="Whether to stream the output")
+    top_p: Optional[float] = Field(1.0, ge=0, le=1, description="Nucleus sampling parameter")
+    n: Optional[int] = Field(1, ge=1, le=10, description="Number of completions to generate")
+    stop: Optional[Any] = Field(None, description="Stop sequences")
+    presence_penalty: Optional[float] = Field(0.0, ge=-2, le=2, description="Presence penalty")
+    frequency_penalty: Optional[float] = Field(0.0, ge=-2, le=2, description="Frequency penalty")
+    user: Optional[str] = Field(None, description="User identifier")


 # ── FastAPI App ──────────────────────────────────────────────
@@ -281,13 +291,18 @@ async def chat_completions(request: ChatCompletionRequest):

     # 3. Non-streaming response
     start_time = time.time()
+
+    # Build the request kwargs (omit max_tokens when it is None or 0)
+    completion_kwargs = {
+        "model": provider_model,
+        "messages": messages_raw,
+        "temperature": request.temperature,
+    }
+    if request.max_tokens and request.max_tokens > 0:
+        completion_kwargs["max_tokens"] = request.max_tokens
+
     try:
-        response = await acompletion(
-            model=provider_model,
-            messages=messages_raw,
-            temperature=request.temperature,
-            max_tokens=request.max_tokens,
-        )
+        response = await acompletion(**completion_kwargs)
         latency_ms = (time.time() - start_time) * 1000

         input_tokens = response.usage.prompt_tokens
@@ -328,13 +343,17 @@ async def _stream_response(
     output_tokens = 0

     try:
-        response = await acompletion(
-            model=provider_model,
-            messages=messages_raw,
-            temperature=request.temperature,
-            max_tokens=request.max_tokens,
-            stream=True,
-        )
+        # Build the request kwargs (omit max_tokens when it is None or 0)
+        completion_kwargs = {
+            "model": provider_model,
+            "messages": messages_raw,
+            "temperature": request.temperature,
+            "stream": True,
+        }
+        if request.max_tokens and request.max_tokens > 0:
+            completion_kwargs["max_tokens"] = request.max_tokens
+
+        response = await acompletion(**completion_kwargs)

         async for chunk in response:
             delta = chunk.choices[0].delta
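
The streaming and non-streaming hunks above build nearly identical keyword dictionaries. One way that logic could be shared is sketched below; `build_completion_kwargs` is a hypothetical helper that does not appear in the commit, and it only covers the fields the two hunks actually pass.

```python
from typing import Any, Dict, List, Optional


def build_completion_kwargs(
    model: str,
    messages: List[Dict[str, Any]],
    temperature: Optional[float],
    max_tokens: Optional[int] = None,
    stream: bool = False,
) -> Dict[str, Any]:
    """Assemble keyword arguments, leaving max_tokens out when it is None or 0."""
    kwargs: Dict[str, Any] = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
    }
    if stream:
        kwargs["stream"] = True
    # Same guard as in the diff: only forward a positive max_tokens.
    if max_tokens and max_tokens > 0:
        kwargs["max_tokens"] = max_tokens
    return kwargs


msgs = [{"role": "user", "content": "Hello"}]
print(build_completion_kwargs("qwen-flash", msgs, 0.7, max_tokens=0))
print(build_completion_kwargs("qwen-flash", msgs, 0.7, max_tokens=256, stream=True))
```

Both call sites could then reduce to `await acompletion(**build_completion_kwargs(...))`, which keeps the `max_tokens` rule in one place.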
nvidia_router.py
@@ -10,11 +10,12 @@ NVIDIA Prompt Task & Complexity Classifier Router

 import torch
 import torch.nn as nn
-from transformers import AutoTokenizer, DebertaV2Model, AutoConfig
+from transformers import AutoTokenizer, DebertaV2Model, DebertaV2Config
 from safetensors.torch import load_file
 from huggingface_hub import hf_hub_download
 from typing import Dict, Optional
 import logging
+import json

 logger = logging.getLogger(__name__)

@@ -115,7 +116,16 @@ class NvidiaComplexityRouter:
     # Creativity mapping
     CREATIVITY_MAP = {0: "High", 1: "Low", 2: "No"}

-    def __init__(self, device: str = "cpu"):
+    def __init__(self, device: str = "auto"):
+        if device == "auto":
+            if torch.backends.mps.is_available():
+                device = "mps"
+                logger.info("MPS (Metal GPU) detected, using MPS acceleration")
+            elif torch.cuda.is_available():
+                device = "cuda"
+            else:
+                device = "cpu"
+                logger.info("No GPU detected, using CPU")
         self.device = device
         self.tokenizer = None
         self.model = None
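
The device fallback order in this hunk (MPS, then CUDA, then CPU) can be isolated as a pure function, which makes it easy to test without PyTorch installed. `resolve_device` is a hypothetical name; in the real constructor the two flags would come from `torch.backends.mps.is_available()` and `torch.cuda.is_available()`, as in the diff.

```python
def resolve_device(requested: str, mps_available: bool, cuda_available: bool) -> str:
    """Pick a device string; an explicit request always wins over detection."""
    if requested != "auto":
        return requested
    if mps_available:
        return "mps"
    if cuda_available:
        return "cuda"
    return "cpu"


print(resolve_device("auto", mps_available=True, cuda_available=False))
print(resolve_device("auto", mps_available=False, cuda_available=False))
print(resolve_device("cpu", mps_available=True, cuda_available=True))
```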
@@ -129,8 +139,18 @@ class NvidiaComplexityRouter:

         logger.info(f"Loading NVIDIA classifier: {self.MODEL_NAME}")

-        # 1. Load config
-        self.config = AutoConfig.from_pretrained(self.MODEL_NAME)
+        # 1. Load the custom config.json manually (this model has no model_type, so AutoConfig cannot handle it)
+        config_path = hf_hub_download(self.MODEL_NAME, "config.json")
+        with open(config_path, "r") as f:
+            custom_config = json.load(f)
+
+        # Build the backbone's DeBERTa config (loaded from base_model)
+        base_model = custom_config.get("base_model", "microsoft/DeBERTa-v3-base")
+        self.config = DebertaV2Config.from_pretrained(base_model)
+        # Keep the custom classification-head parameters
+        self.config.target_sizes = custom_config["target_sizes"]
+        self.config.fc_dropout = custom_config.get("fc_dropout", 0.2)
+        self.config.base_model = base_model

         # 2. Load the tokenizer (slow mode, for better compatibility)
         self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME, use_fast=False)
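
The defaulting rules this hunk applies to the custom `config.json` can be sketched on their own, without downloading anything. `head_params` is a hypothetical helper, and the `target_sizes` values in the example are made up for illustration; they are not the model's real head sizes.

```python
import json


def head_params(custom_config: dict) -> dict:
    """Extract classifier-head settings with the same defaults as the diff."""
    return {
        # Falls back to the DeBERTa-v3-base backbone when the key is missing.
        "base_model": custom_config.get("base_model", "microsoft/DeBERTa-v3-base"),
        # Required: a missing key raises KeyError, as in the diff.
        "target_sizes": custom_config["target_sizes"],
        "fc_dropout": custom_config.get("fc_dropout", 0.2),
    }


# Hypothetical config content, for illustration only.
raw = '{"target_sizes": {"task_type": 12, "reasoning": 2}}'
print(head_params(json.loads(raw)))
```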
@@ -143,9 +163,13 @@ class NvidiaComplexityRouter:
         self.model.load_state_dict(state_dict, strict=False)

         self.model.to(self.device)
+        # MPS needs float16 to avoid dtype conflicts in matrix multiplication
+        if self.device == "mps":
+            self.model.half()
         self.model.eval()
         self._initialized = True
-        logger.info("NVIDIA classifier loaded successfully")
+        dtype = "float16" if self.device == "mps" else "float32"
+        logger.info(f"NVIDIA classifier loaded successfully on {self.device} ({dtype})")

     def predict(self, query: str) -> Dict:
         """
start.sh (executable file, 48 lines changed)
@@ -0,0 +1,48 @@
#!/bin/bash
# LLM Compass startup script
# Usage: ./start.sh [port]

set -e

# Default port
PORT=${1:-8402}

# Resolve the directory that contains this script
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
cd "$SCRIPT_DIR"

# Check the virtual environment
if [ ! -d ".venv" ]; then
    echo "❌ Virtual environment not found. Run first:"
    echo "   python3 -m venv .venv"
    echo "   .venv/bin/pip install -r requirements.txt"
    exit 1
fi

# Check the .env file
if [ ! -f ".env" ]; then
    echo "❌ .env file not found. Create it and configure the API key"
    echo "   cp .env.example .env"
    echo "   Edit .env and fill in DASHSCOPE_API_KEY"
    exit 1
fi

# Load environment variables
export DASHSCOPE_API_KEY=$(grep DASHSCOPE_API_KEY .env | cut -d= -f2)

if [ -z "$DASHSCOPE_API_KEY" ]; then
    echo "❌ DASHSCOPE_API_KEY is not set. Check the .env file"
    exit 1
fi

echo "🚀 Starting the LLM Compass service..."
echo "📍 Address: http://localhost:${PORT}"
echo "📖 API docs: http://localhost:${PORT}/docs"
echo "🔧 Routing: NVIDIA MPS acceleration (M4 Pro GPU)"
echo ""

# Start the service
exec .venv/bin/python -m uvicorn main:app \
    --host 0.0.0.0 \
    --port "$PORT" \
    --log-level info
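
The `grep ... | cut -d= -f2` lookup in the script matches any line that mentions the variable name, including commented-out ones, and keeps only the text between the first and second `=`. A more defensive version of the same lookup is sketched below in Python; `read_env_value` is a hypothetical helper, and the quote stripping is an assumption about how values may be written in `.env`.

```python
from typing import Optional


def read_env_value(text: str, key: str) -> Optional[str]:
    """Return the value for `key` from .env-style text, or None if absent."""
    for raw in text.splitlines():
        line = raw.strip()
        # Skip blank lines and comments such as "# DASHSCOPE_API_KEY=old".
        if not line or line.startswith("#"):
            continue
        # partition splits on the first "=" only, so "=" inside the value survives.
        name, sep, value = line.partition("=")
        if sep and name.strip() == key:
            return value.strip().strip('"').strip("'")
    return None


sample = "# DASHSCOPE_API_KEY=old\nDASHSCOPE_API_KEY=sk-abc=def\n"
print(read_env_value(sample, "DASHSCOPE_API_KEY"))
```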