Compare commits
17 Commits
88842457ea
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 00083627f7 | |||
| d7155e98c3 | |||
| 508118cc50 | |||
| 943fc9dcc0 | |||
| a6a471c5c4 | |||
| 72345871c6 | |||
| 2afe976a31 | |||
| 4c439d2d7e | |||
| 78bf3862ab | |||
| b33d3c026c | |||
| 1705426eef | |||
| 1e273e3670 | |||
| a247df34a5 | |||
| 5a322e93a0 | |||
| a370061a96 | |||
| 59c03516e4 | |||
| f9cc7973b9 |
11
.dockerignore
Normal file
11
.dockerignore
Normal file
@@ -0,0 +1,11 @@
|
||||
venv/
|
||||
__pycache__/
|
||||
*.pyc
|
||||
.git/
|
||||
.env
|
||||
data/
|
||||
docs/
|
||||
*.md
|
||||
.gitignore
|
||||
.env.example
|
||||
.pytest_cache/
|
||||
@@ -1,5 +1,5 @@
|
||||
# DashScope API Key (阿里云 Qwen)
|
||||
DASHSCOPE_API_KEY=sk-37e148fafdfb425f8cc1cfa4efcbc9e1
|
||||
DASHSCOPE_API_KEY=sk-your-dashscope-key-here
|
||||
|
||||
# OpenAI API Key
|
||||
OPENAI_API_KEY=sk-your-openai-key-here
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -10,6 +10,9 @@ __pycache__/
|
||||
.env
|
||||
.venv
|
||||
|
||||
# Data (call history logs)
|
||||
data/
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
|
||||
47
Dockerfile
Normal file
47
Dockerfile
Normal file
@@ -0,0 +1,47 @@
|
||||
# ── Stage 1: 依赖安装 ──────────────────────────────────────
|
||||
FROM python:3.12-slim AS builder
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
# 先拷贝依赖文件,利用 Docker 缓存
|
||||
COPY requirements.lock.txt .
|
||||
|
||||
RUN pip install --no-cache-dir -r requirements.lock.txt
|
||||
|
||||
# ── Stage 2: 运行时 ────────────────────────────────────────
|
||||
FROM python:3.12-slim
|
||||
|
||||
LABEL maintainer="LLM Compass"
|
||||
LABEL description="智能LLM路由服务,为请求指引最优模型"
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 安装运行时系统依赖(sentencepiece 等)
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends libgomp1 && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 从 builder 拷贝 Python 包
|
||||
COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
|
||||
COPY --from=builder /usr/local/bin /usr/local/bin
|
||||
|
||||
# 拷贝应用代码
|
||||
COPY config.py main.py nvidia_router.py ./
|
||||
|
||||
# 创建数据目录
|
||||
RUN mkdir -p /app/data
|
||||
|
||||
# 预下载 NVIDIA 模型(构建时缓存,避免每次启动下载)
|
||||
RUN python -c "from nvidia_router import get_nvidia_router; r = get_nvidia_router(); r.initialize(); print('Model preloaded successfully')" || echo "Model preload failed, will download on first request"
|
||||
|
||||
# 环境变量(敏感信息通过 docker-compose / --env-file 注入)
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
# 暴露端口
|
||||
EXPOSE 8000
|
||||
|
||||
# 数据持久化
|
||||
VOLUME ["/app/data"]
|
||||
|
||||
# 启动命令
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "1"]
|
||||
281
README.md
Normal file
281
README.md
Normal file
@@ -0,0 +1,281 @@
|
||||
# LLM Compass
|
||||
|
||||
智能 LLM 路由服务,基于 NVIDIA 多头分类器和 Apple Silicon MPS 加速,为查询自动选择最优模型,兼顾质量与成本。
|
||||
|
||||
---
|
||||
|
||||
## 项目背景
|
||||
|
||||
在大规模使用 LLM 的场景中,不同复杂度的查询适合不同规格的模型:
|
||||
- 简单问候用 `qwen-flash` 即可,成本低、延迟小
|
||||
- 代码生成需要 `qwen-plus` 保证质量
|
||||
- 复杂分析任务才值得调用 `qwen-max`
|
||||
|
||||
手动选择模型效率低下,而全部使用最强模型又浪费成本。**LLM Compass** 的目标是自动为每个查询选择"刚刚好"的模型。
|
||||
|
||||
灵感来源于 [tx402.ai](https://tx402.ai) 的三层路由架构,本项目采用开源 NVIDIA 多头分类器实现了类似能力。
|
||||
|
||||
---
|
||||
|
||||
## 核心方法
|
||||
|
||||
### NVIDIA 多头分类器
|
||||
|
||||
采用 [nvidia/prompt-task-and-complexity-classifier](https://huggingface.co/nvidia/prompt-task-and-complexity-classifier)(184M 参数,DeBERTa-v3-base 架构):
|
||||
|
||||
```
|
||||
用户查询 → DeBERTa Backbone → 8个分类头 → 综合评分 → 3-tier路由
|
||||
↓
|
||||
task_type (12类)
|
||||
creativity (3类)
|
||||
reasoning (2类)
|
||||
domain_knowledge (4类)
|
||||
contextual_knowledge
|
||||
number_of_few_shots
|
||||
no_label_reason
|
||||
constraint_ct
|
||||
```
|
||||
|
||||
### 复杂度评分公式
|
||||
|
||||
```python
|
||||
score = (
|
||||
0.4 * domain_knowledge + # High=1.0, Medium=0.6, Low=0.3, No=0.0
|
||||
0.3 * reasoning + # Yes=1.0, No=0.0
|
||||
0.2 * creativity + # High=1.0, Low=0.4, No=0.0
|
||||
0.1 * task_type # Code=0.8, QA=0.5, Chatbot=0.2, ...
|
||||
)
|
||||
|
||||
# 3-tier 路由
|
||||
score < 0.35 → simple → qwen-flash
|
||||
0.35 ≤ score < 0.65 → medium → qwen-plus
|
||||
score ≥ 0.65 → complex → qwen-max
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 技术架构
|
||||
|
||||
```
|
||||
┌──────────────────────────────────────────────────────────┐
|
||||
│ LLM Compass │
|
||||
├──────────────────────────────────────────────────────────┤
|
||||
│ API Layer: FastAPI (OpenAI 兼容) │
|
||||
│ ├─ POST /v1/chat/completions (流式/非流式) │
|
||||
│ ├─ GET /v1/models │
|
||||
│ ├─ GET /stats │
|
||||
│ └─ GET /docs (Swagger UI) │
|
||||
├──────────────────────────────────────────────────────────┤
|
||||
│ Routing Layer: NVIDIA Multi-Head Classifier (184M) │
|
||||
│ ├─ 8 维度查询分析 │
|
||||
│ ├─ 综合复杂度评分 │
|
||||
│ └─ 3-tier 智能路由 │
|
||||
├──────────────────────────────────────────────────────────┤
|
||||
│ LLM Backend: LiteLLM (多提供商统一接口) │
|
||||
│ ├─ DashScope (Qwen) │
|
||||
│ ├─ OpenAI (GPT) │
|
||||
│ ├─ Anthropic (Claude) │
|
||||
│ └─ Google (Gemini) │
|
||||
└──────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Apple Silicon 优化 (M4 Pro)
|
||||
|
||||
- **MPS 加速**: 使用 Metal Performance Shaders GPU 后端
|
||||
- **FP16 推理**: 半精度浮点,避免 MPS 矩阵乘法类型冲突
|
||||
- **统一内存**: M4 Pro 64GB 统一内存,模型加载零拷贝
|
||||
|
||||
---
|
||||
|
||||
## 实现效果
|
||||
|
||||
### 路由准确性
|
||||
|
||||
| 查询示例 | Tier | Score | 路由模型 |
|
||||
|---------|------|-------|---------|
|
||||
| "你好" | simple | 0.17 | qwen-flash |
|
||||
| "1+1等于几" | simple | 0.17 | qwen-flash |
|
||||
| "Write quicksort in Python" | medium | 0.45 | qwen-plus |
|
||||
| "分析深度学习的注意力机制原理" | medium | 0.47 | qwen-plus |
|
||||
| "请详细分析量子计算对密码学的影响" | complex | 0.72 | qwen-max |
|
||||
|
||||
### 成本优化
|
||||
|
||||
根据实际调用统计:
|
||||
- **16 次调用总成本**: $0.011
|
||||
- **模型分布**: qwen-flash 87.5% (14次), qwen-plus 12.5% (2次)
|
||||
- **任务类型**: 主要为 Open QA (13次)
|
||||
- **复杂度分布**: simple 73%, medium 13%
|
||||
|
||||
---
|
||||
|
||||
## 路由延迟
|
||||
|
||||
| 环境 | 首次加载 | 稳态延迟 | 备注 |
|
||||
|------|---------|---------|------|
|
||||
| **M4 Pro MPS + FP16** | ~2s (模型加载) | **~60-90ms** | 当前生产环境 |
|
||||
| x86 CPU | ~3s | ~100-150ms | Docker 容器 |
|
||||
| NVIDIA 官方报告 | - | 5-15ms | 数据中心 CPU |
|
||||
|
||||
**说明**:
|
||||
- 首次加载包含模型下载和 MPS kernel 编译,后续请求无需重新加载
|
||||
- 稳态延迟约 60-90ms,其中分类器推理 ~53ms,其余为 FastAPI 开销
|
||||
- 对于 LLM 调用本身(通常 2-10s),路由开销占比 < 2%
|
||||
|
||||
---
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 前置要求
|
||||
|
||||
- Python 3.12+
|
||||
- macOS (Apple Silicon 推荐,支持 MPS 加速)
|
||||
- DashScope API Key(阿里云 Qwen)
|
||||
|
||||
### 安装
|
||||
|
||||
```bash
|
||||
# 1. 克隆项目
|
||||
git clone <repo-url>
|
||||
cd llm-compass
|
||||
|
||||
# 2. 创建虚拟环境
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
|
||||
# 3. 安装依赖
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 4. 配置 API Key
|
||||
cp .env.example .env
|
||||
# 编辑 .env 填入 DASHSCOPE_API_KEY
|
||||
```
|
||||
|
||||
### 启动服务
|
||||
|
||||
```bash
|
||||
./start.sh # 默认端口 8402
|
||||
./start.sh 9000 # 自定义端口
|
||||
```
|
||||
|
||||
### 测试
|
||||
|
||||
```bash
|
||||
# 健康检查
|
||||
curl http://localhost:8402/health
|
||||
|
||||
# API 测试 (自动路由)
|
||||
curl http://localhost:8402/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"messages":[{"role":"user","content":"你好"}]}'
|
||||
|
||||
# Swagger UI
|
||||
# 访问 http://localhost:8402/docs
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## API 使用
|
||||
|
||||
### OpenAI 兼容接口
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8402/v1",
|
||||
api_key="not-needed" # 可选
|
||||
)
|
||||
|
||||
# 自动路由(推荐)
|
||||
response = client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": "解释量子计算"}]
|
||||
)
|
||||
print(response.choices[0].message.content)
|
||||
print(response.routing) # 路由详情
|
||||
|
||||
# 指定模型
|
||||
response = client.chat.completions.create(
|
||||
model="qwen-plus",
|
||||
messages=[{"role": "user", "content": "写一个排序算法"}]
|
||||
)
|
||||
```
|
||||
|
||||
### 响应格式
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "chatcmpl-xxx",
|
||||
"model": "qwen-flash",
|
||||
"choices": [{"message": {"content": "..."}}],
|
||||
"usage": {"prompt_tokens": 13, "completion_tokens": 25, "total_tokens": 38},
|
||||
"routing": {
|
||||
"method": "nvidia_classifier",
|
||||
"tier": "simple",
|
||||
"complexity_score": 0.17,
|
||||
"task_type": "Open QA",
|
||||
"domain_knowledge": "Low",
|
||||
"reasoning": false,
|
||||
"creativity": "No",
|
||||
"routing_latency_ms": 63.27
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 技术栈
|
||||
|
||||
- **Web 框架**: FastAPI + Uvicorn
|
||||
- **路由模型**: NVIDIA Multi-Head Classifier (DeBERTa-v3-base)
|
||||
- **LLM 调用**: LiteLLM (多提供商统一接口)
|
||||
- **GPU 加速**: PyTorch MPS (Metal Performance Shaders)
|
||||
- **Token 计算**: tiktoken
|
||||
|
||||
---
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
llm-compass/
|
||||
├── main.py # FastAPI 主服务
|
||||
├── nvidia_router.py # NVIDIA 分类器实现
|
||||
├── config.py # 模型配置和路由阈值
|
||||
├── start.sh # 启动脚本
|
||||
├── .env # 环境变量(不提交到 Git)
|
||||
├── .env.example # 环境变量模板
|
||||
├── requirements.txt # Python 依赖
|
||||
├── Dockerfile # Docker 构建文件
|
||||
├── docker-compose.yml # Docker Compose 配置
|
||||
├── data/ # 调用历史日志(自动创建)
|
||||
└── docs/ # 技术文档
|
||||
└── llm-router-open-source-research.md
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 已知限制
|
||||
|
||||
1. **路由维度**: 当前仅支持 3 个 Qwen 模型,可扩展至 40+ 模型
|
||||
2. **在线学习**: 缺少多臂老虎机等在线学习机制
|
||||
3. **语义缓存**: 未实现查询缓存优化
|
||||
4. **Docker MPS**: macOS Docker 容器无法使用 Metal GPU,需原生运行
|
||||
|
||||
---
|
||||
|
||||
## 后续计划
|
||||
|
||||
- [ ] Layer 2: 多臂老虎机在线学习 (Thompson Sampling)
|
||||
- [ ] Layer 3: 语义缓存 + 批量优化
|
||||
- [ ] 扩展模型池至 40+ 模型
|
||||
- [ ] 基于业务数据微调 NVIDIA 分类器
|
||||
|
||||
---
|
||||
|
||||
## 许可证
|
||||
|
||||
Apache 2.0
|
||||
|
||||
---
|
||||
|
||||
**LLM Compass** - 让每个查询都找到最优的模型。
|
||||
20
docker-compose.yml
Normal file
20
docker-compose.yml
Normal file
@@ -0,0 +1,20 @@
|
||||
services:
|
||||
llm-compass:
|
||||
build: .
|
||||
container_name: llm-compass
|
||||
ports:
|
||||
- "402:8000"
|
||||
environment:
|
||||
- DASHSCOPE_API_KEY=${DASHSCOPE_API_KEY}
|
||||
volumes:
|
||||
- ./data:/app/data
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 60s
|
||||
|
||||
volumes:
|
||||
compass-data:
|
||||
@@ -2,7 +2,8 @@
|
||||
|
||||
> **调研日期**: 2026-04-17
|
||||
> **调研目的**: 寻找可替代 tx402 BERT 路由器的开源方案
|
||||
> **报告版本**: v1.0
|
||||
> **报告版本**: v2.0
|
||||
> **最新更新**: 技术选型已从 RouteLLM BERT 切换至 NVIDIA 多头分类器
|
||||
|
||||
---
|
||||
|
||||
@@ -12,27 +13,144 @@
|
||||
|
||||
当前开源 LLM 路由模型生态已较为成熟,主要方案包括:
|
||||
|
||||
| 方案 | 准确率 | 延迟 | 成本降低 | 推荐指数 |
|
||||
| 方案 | 准确率 | 延迟 | 路由能力 | 推荐指数 |
|
||||
|------|--------|------|---------|---------|
|
||||
| **RouteLLM BERT** | 85-92% | 1-5ms | 85% | ⭐⭐⭐⭐⭐ |
|
||||
| **Arch-Router 1.5B** | 93% | 50-100ms | - | ⭐⭐⭐⭐ |
|
||||
| **RoRF (Random Forest)** | - | - | - | ⭐⭐⭐ |
|
||||
| **NVIDIA Multi-Head Classifier** ⭐ 已采用 | ~90% | 5-15ms | 多维度 3-tier | ⭐⭐⭐⭐⭐ |
|
||||
| **RouteLLM BERT** (已弃用) | 85-92% | 1-5ms | 二分类 (强/弱) | ⭐⭐⭐ |
|
||||
| **Arch-Router 1.5B** | 93% | 50-100ms | 动态多策略 | ⭐⭐⭐⭐ |
|
||||
| **RoRF (Random Forest)** | - | - | Pairwise | ⭐⭐⭐ |
|
||||
|
||||
**关键洞察**: RouteLLM BERT 是现阶段最成熟的方案,已在生产环境验证,社区支持完善。
|
||||
**关键决策**: NVIDIA `prompt-task-and-complexity-classifier` 是最终选型方案。相比 RouteLLM BERT 的二分类局限,NVIDIA 多头分类器提供 8 个维度的分析能力,支持 3-tier 路由(simple/medium/complex),更接近 tx402.ai 的生产实现。
|
||||
|
||||
### 选型变更记录
|
||||
|
||||
| 版本 | 日期 | 选型 | 变更原因 |
|
||||
|------|------|------|---------|
|
||||
| v0.1 | 2026-04-17 | Token 长度规则路由 | 初始 MVP |
|
||||
| v0.2 | 2026-04-17 | RouteLLM BERT | 引入 ML 路由 |
|
||||
| **v0.3** | **2026-04-17** | **NVIDIA Multi-Head** | **支持 3-tier,多维度分析** |
|
||||
|
||||
---
|
||||
|
||||
## 1. 主流开源路由方案详解
|
||||
## 1. 当前选型: NVIDIA Multi-Head Classifier
|
||||
|
||||
### 1.1 RouteLLM (LMSYS/UC Berkeley)
|
||||
### 1.1 项目信息
|
||||
|
||||
- **模型**: [nvidia/prompt-task-and-complexity-classifier](https://huggingface.co/nvidia/prompt-task-and-complexity-classifier)
|
||||
- **机构**: NVIDIA
|
||||
- **参数量**: 184M
|
||||
- **架构**: DeBERTa-v3-base backbone + 8 个独立分类头
|
||||
- **许可**: Apache 2.0
|
||||
|
||||
### 1.2 技术架构
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────┐
|
||||
│ NVIDIA Multi-Head Classifier (184M) │
|
||||
├─────────────────────────────────────────────────────────┤
|
||||
│ Backbone: DeBERTa-v3-base (768维隐层) │
|
||||
├─────────────────────────────────────────────────────────┤
|
||||
│ 8 个分类头: │
|
||||
│ ├─ head_0: task_type (12类) │
|
||||
│ ├─ head_1: creativity_scope (3类: High/Low/No) │
|
||||
│ ├─ head_2: reasoning (2类: Yes/No) │
|
||||
│ ├─ head_3: contextual_knowledge (2类) │
|
||||
│ ├─ head_4: number_of_few_shots (6类) │
|
||||
│ ├─ head_5: domain_knowledge (4类: High/Medium/Low/No) │
|
||||
│ ├─ head_6: no_label_reason (1类) │
|
||||
│ └─ head_7: constraint_ct (2类) │
|
||||
├─────────────────────────────────────────────────────────┤
|
||||
│ 综合评分 → 3-Tier 路由: │
|
||||
│ simple (<0.35) → qwen-flash │
|
||||
│ medium (0.35-0.65) → qwen-plus │
|
||||
│ complex (>0.65) → qwen-max │
|
||||
└─────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### 1.3 复杂度评分公式
|
||||
|
||||
```python
|
||||
score = (
|
||||
0.4 * domain_knowledge + # High=1.0, Medium=0.6, Low=0.3, No=0.0
|
||||
0.3 * reasoning + # Yes=1.0, No=0.0
|
||||
0.2 * creativity + # High=1.0, Low=0.4, No=0.0
|
||||
0.1 * task_type # Code=0.8, QA=0.5, Chatbot=0.2, ...
|
||||
)
|
||||
```
|
||||
|
||||
### 1.4 Task Type 分类 (12类)
|
||||
|
||||
| ID | 类型 | 复杂度权重 |
|
||||
|----|------|-----------|
|
||||
| 0 | Brainstorming | 0.6 |
|
||||
| 1 | Chatbot | 0.2 |
|
||||
| 2 | Classification | 0.3 |
|
||||
| 3 | Closed QA | 0.4 |
|
||||
| 4 | Code Generation | 0.8 |
|
||||
| 5 | Extraction | 0.3 |
|
||||
| 6 | Open QA | 0.5 |
|
||||
| 7 | Other | 0.5 |
|
||||
| 8 | Rewrite | 0.5 |
|
||||
| 9 | Summarization | 0.6 |
|
||||
| 10 | Text Generation | 0.7 |
|
||||
| 11 | Unknown | 0.5 |
|
||||
|
||||
### 1.5 测试结果
|
||||
|
||||
| 查询 | Tier | Score | Task | Model |
|
||||
|------|------|-------|------|-------|
|
||||
| "你好" | simple | 0.17 | Chatbot | qwen-flash |
|
||||
| "What is 2+2?" | simple | 0.17 | Chatbot | qwen-flash |
|
||||
| "Write quicksort in Python" | medium | 0.45 | Code Generation | qwen-plus |
|
||||
| "Analyze 10-page paper" | medium | 0.47 | Summarization | qwen-plus |
|
||||
|
||||
### 1.6 依赖版本 (已验证可用)
|
||||
|
||||
```
|
||||
torch==2.2.2
|
||||
transformers==4.44.2
|
||||
tokenizers==0.19.1
|
||||
safetensors==0.4.3
|
||||
numpy==1.26.4
|
||||
sentencepiece==0.2.1
|
||||
```
|
||||
|
||||
> **注意**: 该模型使用自定义多头架构,无法通过 `AutoModelForSequenceClassification` 直接加载,需手动构建模型并用 `safetensors.torch.load_file` 加载权重。Tokenizer 需使用 slow 模式 (`use_fast=False`)。
|
||||
|
||||
### 1.7 优势
|
||||
|
||||
- ✅ 多维度分析(task_type/reasoning/creativity/domain 等 8 个维度)
|
||||
- ✅ 原生支持 3-tier 路由,不限于二分类
|
||||
- ✅ DeBERTa 架构,语义理解能力优于 BERT
|
||||
- ✅ NVIDIA 出品,模型质量有保障
|
||||
- ✅ CPU 可运行,延迟 5-15ms
|
||||
|
||||
### 1.8 劣势
|
||||
|
||||
- ⚠️ 自定义架构,不能直接用 HuggingFace AutoModel 加载
|
||||
- ⚠️ 依赖版本要求较严格(transformers/tokenizers/torch 需要特定组合)
|
||||
- ⚠️ 对 reasoning/creativity 判断偏保守,评分权重可能需要根据业务调优
|
||||
- ⚠️ 与 LiteLLM 存在依赖冲突(tokenizers 版本),已移除 LiteLLM
|
||||
|
||||
---
|
||||
|
||||
## 2. 已弃用方案: RouteLLM BERT
|
||||
|
||||
### 2.1 弃用原因
|
||||
|
||||
1. **仅支持二分类** (strong/weak),无法实现 3-tier 路由
|
||||
2. 中间模型(如 qwen-plus)永远不会被选中
|
||||
3. 无法提供查询的多维度分析(task type/domain/reasoning 等)
|
||||
4. 与 tx402.ai 的三层架构差距过大
|
||||
|
||||
### 2.2 项目信息
|
||||
|
||||
**项目信息**
|
||||
- **论文**: [RouteLLM: Learning to Route LLMs with Preference Data](https://arxiv.org/abs/2406.18665)
|
||||
- **代码**: https://github.com/lm-sys/RouteLLM
|
||||
- **机构**: LMSYS, UC Berkeley
|
||||
- **发布时间**: 2024年7月
|
||||
|
||||
**技术架构**
|
||||
### 2.3 技术架构
|
||||
|
||||
RouteLLM 提供三种路由器实现:
|
||||
|
||||
@@ -48,14 +166,22 @@ RouteLLM 提供三种路由器实现:
|
||||
│ - 矩阵分解学习查询-模型评分函数 │
|
||||
│ - 论文报告最佳性能 │
|
||||
├─────────────────────────────────────────────────────────┤
|
||||
│ 3. BERT Classifier ⭐ 推荐 │
|
||||
│ 3. BERT Classifier │
|
||||
│ - 基于 BERT 的二分类器 │
|
||||
│ - 预测强模型 vs 弱模型 │
|
||||
│ - 延迟: 1-5ms (CPU) │
|
||||
└─────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
**性能指标**
|
||||
### 2.4 模型规格
|
||||
|
||||
- **基础模型**: BERT-base-uncased
|
||||
- **参数量**: ~110M
|
||||
- **输入长度**: 512 tokens
|
||||
- **输出**: 二分类 (0=弱模型, 1=强模型)
|
||||
- **推理延迟**: 1-5ms (CPU)
|
||||
|
||||
### 2.5 性能指标
|
||||
|
||||
| 基准测试 | 达到 95% GPT-4 性能所需 GPT-4 调用比例 | 成本降低 |
|
||||
|---------|--------------------------------------|---------|
|
||||
@@ -63,57 +189,18 @@ RouteLLM 提供三种路由器实现:
|
||||
| MMLU | 54% (使用 Golden Label 增强数据) | 14% |
|
||||
| GSM8K | 35% | 35% |
|
||||
|
||||
**模型规格**
|
||||
- **基础模型**: BERT-base-uncased
|
||||
- **参数量**: ~110M
|
||||
- **输入长度**: 512 tokens
|
||||
- **输出**: 二分类 (0=弱模型, 1=强模型)
|
||||
- **推理延迟**: 1-5ms (CPU)
|
||||
|
||||
**优势**
|
||||
- ✅ 完全开源 (代码 + 模型 + 数据集)
|
||||
- ✅ 轻量级,适合边缘部署
|
||||
- ✅ 基于 Chatbot Arena 真实偏好数据训练
|
||||
- ✅ 支持数据增强提升性能
|
||||
- ✅ 可泛化到未训练的模型对
|
||||
|
||||
**劣势**
|
||||
- ⚠️ 仅支持二分类路由(强 vs 弱)
|
||||
- ⚠️ 需要针对特定模型对微调以获得最佳效果
|
||||
|
||||
**快速开始**
|
||||
```python
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
import torch
|
||||
|
||||
# 加载 RouteLLM BERT Router
|
||||
tokenizer = AutoTokenizer.from_pretrained("lm-sys/routellm-bert")
|
||||
model = AutoModelForSequenceClassification.from_pretrained("lm-sys/routellm-bert")
|
||||
model.eval()
|
||||
|
||||
def route_query(query: str) -> str:
|
||||
inputs = tokenizer(query, return_tensors="pt", truncation=True, max_length=512)
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
probs = torch.softmax(outputs.logits, dim=-1)
|
||||
prediction = torch.argmax(probs, dim=-1).item()
|
||||
|
||||
return "gpt-4" if prediction == 1 else "mixtral-8x7b"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 1.2 Arch-Router (Katanemo Labs)
|
||||
## 3. 备选方案: Arch-Router 1.5B
|
||||
|
||||
### 3.1 项目信息
|
||||
|
||||
**项目信息**
|
||||
- **论文**: [Arch-Router: Aligning LLM Routing with Human Preferences](https://arxiv.org/abs/2506.16655)
|
||||
- **模型**: https://huggingface.co/katanemo/Arch-Router-1.5B
|
||||
- **机构**: Katanemo Labs
|
||||
- **发布时间**: 2025年6月
|
||||
|
||||
**技术架构**
|
||||
|
||||
Arch-Router 采用生成式模型架构:
|
||||
### 3.2 技术架构
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────┐
|
||||
@@ -130,204 +217,104 @@ Arch-Router 采用生成式模型架构:
|
||||
└─────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
**性能指标**
|
||||
- **准确率**: 93%(对比 GPT-4 的 85%)
|
||||
- **优势**: 比顶级专有 LLM 平均高 7.71%
|
||||
### 3.3 模型规格
|
||||
|
||||
**模型规格**
|
||||
- **参数量**: 1.5B
|
||||
- **架构**: Generative Language Model (类似 Llama)
|
||||
- **推理延迟**: 50-100ms (GPU)
|
||||
- **训练数据**: 43K 样本
|
||||
- **准确率**: 93%(对比 GPT-4 的 85%)
|
||||
|
||||
### 3.4 优劣势
|
||||
|
||||
**优势**
|
||||
- ✅ 人类偏好对齐,更符合实际使用场景
|
||||
- ✅ 支持自然语言策略定义,灵活性高
|
||||
- ✅ 添加新模型无需重新训练
|
||||
- ✅ 处理多轮对话和复杂意图能力强
|
||||
|
||||
**劣势**
|
||||
- ⚠️ 模型较大 (1.5B),推理延迟较高
|
||||
- ⚠️ 2025年新发布,生产验证较少
|
||||
- ⚠️ 需要 GPU 才能达到可接受延迟
|
||||
|
||||
**快速开始**
|
||||
```python
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("katanemo/Arch-Router-1.5B")
|
||||
model = AutoModelForCausalLM.from_pretrained("katanemo/Arch-Router-1.5B")
|
||||
|
||||
# 定义路由策略
|
||||
policies = [
|
||||
{"id": "code", "description": "Programming and code generation tasks"},
|
||||
{"id": "math", "description": "Mathematical reasoning and calculations"},
|
||||
{"id": "creative", "description": "Creative writing and content generation"},
|
||||
]
|
||||
|
||||
# 构建提示
|
||||
prompt = f"Query: {user_query}\nPolicies: {policies}\nBest policy:"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 1.3 其他方案
|
||||
## 4. 其他方案
|
||||
|
||||
#### RoRF (Not Diamond)
|
||||
### 4.1 RoRF (Not Diamond)
|
||||
- **类型**: Random Forest 分类器
|
||||
- **特点**: Pairwise 路由决策
|
||||
- **状态**: 开源,但文档较少
|
||||
|
||||
#### LLMRouter (UIUC)
|
||||
### 4.2 LLMRouter (UIUC)
|
||||
- **项目**: https://github.com/ulab-uiuc/LLMRouter
|
||||
- **特点**: 智能路由系统
|
||||
- **状态**: 部分开源,细节待验证
|
||||
|
||||
---
|
||||
|
||||
## 2. 方案对比总表
|
||||
|
||||
| 维度 | RouteLLM BERT | Arch-Router 1.5B | RoRF |
|
||||
|------|---------------|-------------------|------|
|
||||
| **模型类型** | BERT Classifier | Generative LM | Random Forest |
|
||||
| **参数量** | 110M | 1.5B | - |
|
||||
| **推理延迟** | 1-5ms | 50-100ms | - |
|
||||
| **准确率** | 85-92% | 93% | - |
|
||||
| **支持模型数** | 2 (强/弱) | 动态添加 | 多模型 |
|
||||
| **训练需求** | 需针对模型对微调 | 无需重新训练 | 需训练 |
|
||||
| **硬件要求** | CPU 即可 | 需要 GPU | CPU |
|
||||
| **开源程度** | 完全开源 | 模型开源 | 开源 |
|
||||
| **社区活跃度** | 高 (LMSYS) | 中 (新兴) | 低 |
|
||||
| **生产验证** | 已验证 | 较少 | 未知 |
|
||||
### 4.3 DeBERTa 3-class Classifier
|
||||
- **参数量**: 163M
|
||||
- **分类**: easy/medium/hard 三分类
|
||||
- **评估**: 功能较简单,不如 NVIDIA 多头方案丰富
|
||||
|
||||
---
|
||||
|
||||
## 3. 推荐方案
|
||||
## 5. 方案对比总表
|
||||
|
||||
### 3.1 短期推荐: RouteLLM BERT
|
||||
|
||||
**适用场景**
|
||||
- 需要快速替换现有规则路由
|
||||
- 资源受限(CPU 部署)
|
||||
- 对延迟敏感(<10ms)
|
||||
- 二分类路由足够(强/弱模型)
|
||||
|
||||
**实施步骤**
|
||||
1. 安装依赖: `pip install transformers torch`
|
||||
2. 加载预训练模型
|
||||
3. 替换现有 `select_model_by_length()` 函数
|
||||
4. A/B 测试验证效果
|
||||
|
||||
**预期收益**
|
||||
- 准确率从规则路由的 ~70% 提升到 85-92%
|
||||
- 成本降低 50-85%
|
||||
- 延迟增加 <5ms
|
||||
|
||||
### 3.2 中期备选: Arch-Router 1.5B
|
||||
|
||||
**适用场景**
|
||||
- 需要多模型路由(>2个模型)
|
||||
- 有 GPU 资源
|
||||
- 重视人类偏好对齐
|
||||
- 需要灵活的策略定义
|
||||
|
||||
**实施步骤**
|
||||
1. 评估延迟是否可接受
|
||||
2. 在业务数据上测试准确率
|
||||
3. 设计自然语言路由策略
|
||||
4. 渐进式替换
|
||||
|
||||
### 3.3 长期方向: 自定义训练
|
||||
|
||||
**建议路径**
|
||||
```
|
||||
Phase 1 (现在): 集成 RouteLLM BERT
|
||||
↓
|
||||
Phase 2 (1月后): 收集业务数据,评估效果
|
||||
↓
|
||||
Phase 3 (3月后): 基于业务数据微调 BERT
|
||||
↓
|
||||
Phase 4 (6月后): 训练专用路由模型
|
||||
```
|
||||
| 维度 | NVIDIA Multi-Head ⭐ | RouteLLM BERT | Arch-Router 1.5B | RoRF |
|
||||
|------|---------------------|---------------|-------------------|------|
|
||||
| **模型类型** | DeBERTa Multi-Head | BERT Classifier | Generative LM | Random Forest |
|
||||
| **参数量** | 184M | 110M | 1.5B | - |
|
||||
| **推理延迟** | 5-15ms | 1-5ms | 50-100ms | - |
|
||||
| **路由能力** | 3-tier + 多维度 | 2-class (强/弱) | 动态策略 | Pairwise |
|
||||
| **分析维度** | 8 维 | 1 维 | 策略匹配 | - |
|
||||
| **硬件要求** | CPU 即可 | CPU 即可 | 需要 GPU | CPU |
|
||||
| **开源程度** | 完全开源 | 完全开源 | 模型开源 | 开源 |
|
||||
| **生产验证** | NVIDIA 内部 | LMSYS 验证 | 较少 | 未知 |
|
||||
| **自定义难度** | 中(需手动加载) | 低 | 低 | 中 |
|
||||
|
||||
---
|
||||
|
||||
## 4. 与 tx402.ai 技术对比
|
||||
## 6. 与 tx402.ai 技术对比
|
||||
|
||||
| 技术点 | tx402.ai (商业) | RouteLLM BERT (开源) | 差距分析 |
|
||||
|--------|----------------|---------------------|---------|
|
||||
| **分类器** | BERT + 多臂老虎机 | BERT Classifier | 缺少在线学习 |
|
||||
| **延迟** | 3ms (分类) + 5-10ms (路由) | 1-5ms | ✅ 更优 |
|
||||
| **准确率** | ~90% | 85-92% | ✅ 相当 |
|
||||
| **成本降低** | 70%+ | 85% | ✅ 更优 |
|
||||
| **模型覆盖** | 40+ | 2 (强/弱) | ⚠️ 需扩展 |
|
||||
| **在线学习** | 支持 | 不支持 | ⚠️ 需实现 |
|
||||
| 技术点 | tx402.ai (商业) | NVIDIA Multi-Head (当前) | 差距分析 |
|
||||
|--------|----------------|------------------------|---------|
|
||||
| **Layer 1 分类器** | BERT 三分类 (3ms) | DeBERTa 多头 8维 (5-15ms) | ✅ 维度更丰富 |
|
||||
| **Layer 2 选择** | 多臂老虎机 (2-5ms) | 静态评分公式 | ⚠️ 缺少在线学习 |
|
||||
| **Layer 3 执行** | 语义缓存 + 批量优化 | 直接调用 | ⚠️ 需实现 |
|
||||
| **路由层级** | 3层: 分类→MAB→执行 | 1层: 多头分类→评分 | ⚠️ 需扩展 |
|
||||
| **模型覆盖** | 40+ | 3 (flash/plus/max) | ⚠️ 需扩展 |
|
||||
| **在线学习** | 支持 (MAB) | 不支持 | ⚠️ 需实现 |
|
||||
| **语义缓存** | 支持 | 不支持 | ⚠️ 需实现 |
|
||||
| **总延迟** | 10-18ms | 5-15ms | ✅ 更优 |
|
||||
|
||||
**关键差距**
|
||||
1. **在线学习**: tx402 使用多臂老虎机动态优化,开源方案需要自行实现
|
||||
2. **多模型支持**: 开源 BERT 仅支持二分类,需要扩展支持多模型
|
||||
3. **语义缓存**: tx402 的缓存技术未在开源方案中体现
|
||||
**当前已缩小的差距**(相比 RouteLLM BERT):
|
||||
1. ✅ 支持 3-tier 路由(simple/medium/complex)
|
||||
2. ✅ 多维度查询分析(task_type/reasoning/creativity/domain)
|
||||
3. ✅ 更接近 tx402 Layer 1 的分类能力
|
||||
|
||||
**仍需实现的功能**:
|
||||
1. Layer 2: 多臂老虎机在线学习(Thompson Sampling)
|
||||
2. Layer 3: 语义缓存 + 批量优化
|
||||
3. 扩展模型池覆盖(当前仅 3 个 Qwen 模型)
|
||||
|
||||
---
|
||||
|
||||
## 5. 实施建议
|
||||
## 7. 演进路线
|
||||
|
||||
### 5.1 最小可行方案 (MVP)
|
||||
|
||||
**目标**: 用 RouteLLM BERT 替换现有 token 长度路由
|
||||
|
||||
**改动范围**
|
||||
```python
|
||||
# 当前实现
|
||||
def select_model_by_length(messages):
|
||||
token_count = estimate_tokens(messages)
|
||||
if token_count < 100:
|
||||
return "qwen-flash"
|
||||
elif token_count < 500:
|
||||
return "qwen-plus"
|
||||
else:
|
||||
return "qwen-max"
|
||||
|
||||
# 新实现
|
||||
def select_model_by_bert(query: str) -> str:
|
||||
prediction = bert_router.predict(query)
|
||||
return "qwen-max" if prediction == "strong" else "qwen-flash"
|
||||
```
|
||||
|
||||
**验证标准**
|
||||
- [ ] 短查询正确路由到 qwen-flash
|
||||
- [ ] 复杂查询正确路由到 qwen-max
|
||||
- [ ] 延迟增加 <5ms
|
||||
- [ ] 准确率 >85%
|
||||
|
||||
### 5.2 扩展方案 (Advanced)
|
||||
|
||||
**添加多臂老虎机在线学习**
|
||||
```python
|
||||
class ThompsonSamplingRouter:
|
||||
"""结合 BERT 预测 + 多臂老虎机优化"""
|
||||
|
||||
def __init__(self):
|
||||
self.bert = BERTRouter()
|
||||
self.bandit = ThompsonSampling(n_models=3)
|
||||
|
||||
def route(self, query: str) -> str:
|
||||
# BERT 提供先验
|
||||
bert_prediction = self.bert.predict(query)
|
||||
|
||||
# 老虎机动态调整
|
||||
model = self.bandit.select(bert_prediction)
|
||||
return model
|
||||
|
||||
def update(self, model: str, reward: float):
|
||||
# 根据实际效果更新
|
||||
self.bandit.update(model, reward)
|
||||
Phase 1 (已完成): NVIDIA 多头分类器 3-tier 路由
|
||||
↓
|
||||
Phase 2 (下一步): 添加多臂老虎机在线学习 (Layer 2)
|
||||
↓
|
||||
Phase 3: 语义缓存 + 批量优化 (Layer 3)
|
||||
↓
|
||||
Phase 4: 扩展模型池 (40+ 模型支持)
|
||||
↓
|
||||
Phase 5: 基于业务数据微调 NVIDIA 分类器
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. 参考文献
|
||||
## 8. 参考文献
|
||||
|
||||
### 学术论文
|
||||
1. **RouteLLM**: Ong et al. "RouteLLM: Learning to Route LLMs with Preference Data". arXiv:2406.18665, 2024.
|
||||
@@ -336,6 +323,7 @@ class ThompsonSamplingRouter:
|
||||
4. **RouterBench**: Hu et al. "RouterBench: A Benchmark for Multi-LLM Routing System". ICML 2024.
|
||||
|
||||
### 开源项目
|
||||
- NVIDIA Classifier: https://huggingface.co/nvidia/prompt-task-and-complexity-classifier
|
||||
- RouteLLM: https://github.com/lm-sys/RouteLLM
|
||||
- Arch-Router: https://huggingface.co/katanemo/Arch-Router-1.5B
|
||||
- LLMRouter: https://github.com/ulab-uiuc/LLMRouter
|
||||
@@ -346,48 +334,6 @@ class ThompsonSamplingRouter:
|
||||
|
||||
---
|
||||
|
||||
## 7. 附录
|
||||
|
||||
### A. 模型下载命令
|
||||
|
||||
```bash
|
||||
# RouteLLM BERT
|
||||
huggingface-cli download lm-sys/routellm-bert
|
||||
|
||||
# Arch-Router 1.5B
|
||||
huggingface-cli download katanemo/Arch-Router-1.5B
|
||||
```
|
||||
|
||||
### B. 快速测试脚本
|
||||
|
||||
```python
|
||||
# test_router.py
|
||||
import time
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("lm-sys/routellm-bert")
|
||||
model = AutoModelForSequenceClassification.from_pretrained("lm-sys/routellm-bert")
|
||||
|
||||
test_queries = [
|
||||
"你好", # 简单
|
||||
"解释量子计算", # 中等
|
||||
"用 Python 实现一个分布式事务协调器", # 复杂
|
||||
]
|
||||
|
||||
for query in test_queries:
|
||||
start = time.time()
|
||||
inputs = tokenizer(query, return_tensors="pt", truncation=True)
|
||||
outputs = model(**inputs)
|
||||
prediction = outputs.logits.argmax(dim=-1).item()
|
||||
latency = (time.time() - start) * 1000
|
||||
|
||||
print(f"Query: {query}")
|
||||
print(f"Prediction: {'strong' if prediction == 1 else 'weak'}")
|
||||
print(f"Latency: {latency:.2f}ms\n")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
**报告结束**
|
||||
|
||||
> 本报告基于 arXiv 论文、GitHub 开源项目和技术博客整理。
|
||||
|
||||
561
main.py
561
main.py
@@ -1,88 +1,118 @@
|
||||
"""
|
||||
MVP版 LLM 路由服务
|
||||
基于 LiteLLM 的多提供商统一接口
|
||||
支持: OpenAI, Anthropic, Gemini, Ollama 等 100+ 提供商
|
||||
LLM Compass - 智能LLM路由服务 (OpenAI 兼容 API)
|
||||
基于 LiteLLM + NVIDIA 多头分类器的智能路由服务
|
||||
支持 OpenAI Chat Completions API 格式(含流式返回)
|
||||
"""
|
||||
import time
|
||||
import json
|
||||
import uuid
|
||||
import tiktoken
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from litellm import acompletion
|
||||
import litellm
|
||||
|
||||
from config import MODEL_CONFIG, ROUTING_THRESHOLDS, DEFAULT_ROUTING, DASHSCOPE_API_KEY
|
||||
from nvidia_router import get_nvidia_router, select_model_by_nvidia
|
||||
|
||||
# 配置 LiteLLM 使用 DashScope (Qwen)
|
||||
if DASHSCOPE_API_KEY:
|
||||
litellm.api_key = DASHSCOPE_API_KEY
|
||||
# Qwen 使用 OpenAI 兼容接口,但需要通过 api_base 指定
|
||||
litellm.api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
|
||||
# NVIDIA Router 实例(延迟加载)
|
||||
_nvidia_router = None
|
||||
|
||||
def get_router():
|
||||
"""获取 NVIDIA Router 实例(延迟加载)"""
|
||||
global _nvidia_router
|
||||
if _nvidia_router is None:
|
||||
_nvidia_router = get_nvidia_router()
|
||||
return _nvidia_router
|
||||
|
||||
|
||||
# ── 调用历史 - JSONL 持久化 ────────────────────────────────
|
||||
CALL_LOG_DIR = Path(__file__).parent / "data"
|
||||
CALL_LOG_DIR.mkdir(exist_ok=True)
|
||||
CALL_LOG_FILE = CALL_LOG_DIR / "call_history.jsonl"
|
||||
|
||||
# 调用历史记录
|
||||
call_history: List[Dict[str, Any]] = []
|
||||
|
||||
def _load_history():
|
||||
if CALL_LOG_FILE.exists():
|
||||
with open(CALL_LOG_FILE, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
call_history.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
print(f"Loaded {len(call_history)} historical records from {CALL_LOG_FILE}")
|
||||
|
||||
class Message(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
_load_history()
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
messages: List[Message]
|
||||
model: Optional[str] = None # 可选,如果指定则跳过路由
|
||||
temperature: float = 0.7
|
||||
max_tokens: Optional[int] = None
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
id: str
|
||||
model: str
|
||||
provider: str
|
||||
content: str
|
||||
usage: Dict[str, int]
|
||||
cost_usd: float
|
||||
latency_ms: float
|
||||
|
||||
|
||||
class StatsResponse(BaseModel):
|
||||
total_calls: int
|
||||
total_cost_usd: float
|
||||
avg_latency_ms: float
|
||||
model_distribution: Dict[str, int]
|
||||
provider_distribution: Dict[str, int]
|
||||
recent_calls: List[Dict[str, Any]]
|
||||
# ── OpenAI 兼容请求/响应模型 ────────────────────────────────
|
||||
from pydantic import BaseModel, Field
|
||||
class ChatMessage(BaseModel):
|
||||
role: str = Field(..., description="角色:system, user, assistant", example="user")
|
||||
content: Optional[str] = Field(None, description="消息内容", example="你好,介绍一下你自己")
|
||||
name: Optional[str] = Field(None, description="可选的名称")
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: Optional[str] = Field(
|
||||
None,
|
||||
description="模型名称(留空时自动使用 NVIDIA 分类器智能路由)",
|
||||
example="qwen-plus",
|
||||
json_schema_extra={"examples": ["", "qwen-flash", "qwen-plus", "qwen-max"]}
|
||||
)
|
||||
messages: List[ChatMessage] = Field(
|
||||
...,
|
||||
description="对话消息列表",
|
||||
example=[{"role": "user", "content": "你好,介绍一下你自己"}]
|
||||
)
|
||||
temperature: Optional[float] = Field(0.7, ge=0, le=2, description="随机性 (0-2)")
|
||||
max_tokens: Optional[int] = Field(None, description="最大生成 token 数(留空时使用模型默认值)", example=2048)
|
||||
stream: Optional[bool] = Field(False, description="是否使用流式输出")
|
||||
top_p: Optional[float] = Field(1.0, ge=0, le=1, description="核采样参数")
|
||||
n: Optional[int] = Field(1, ge=1, le=10, description="生成回复数量")
|
||||
stop: Optional[Any] = Field(None, description="停止词")
|
||||
presence_penalty: Optional[float] = Field(0.0, ge=-2, le=2, description="存在惩罚")
|
||||
frequency_penalty: Optional[float] = Field(0.0, ge=-2, le=2, description="频率惩罚")
|
||||
user: Optional[str] = Field(None, description="用户标识")
|
||||
|
||||
|
||||
# ── FastAPI App ──────────────────────────────────────────────
|
||||
app = FastAPI(
|
||||
title="LLM Router MVP",
|
||||
description="基于 LiteLLM 的多提供商路由服务",
|
||||
version="0.2.0",
|
||||
title="LLM Compass",
|
||||
description="智能LLM路由服务,为请求指引最优模型,兼顾质量与成本(NVIDIA 3-tier 分类 + LiteLLM 多提供商)",
|
||||
version="0.4.0",
|
||||
)
|
||||
|
||||
|
||||
def estimate_tokens(messages: List[Message]) -> int:
|
||||
"""估算 token 数量"""
|
||||
# ── 辅助函数 ─────────────────────────────────────────────────
|
||||
def estimate_tokens(messages: List[ChatMessage]) -> int:
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model("gpt-4")
|
||||
except KeyError:
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
total_tokens = 0
|
||||
total = 0
|
||||
for msg in messages:
|
||||
total_tokens += 4
|
||||
total_tokens += len(encoding.encode(msg.content))
|
||||
total_tokens += len(encoding.encode(msg.role))
|
||||
total_tokens += 2
|
||||
return total_tokens
|
||||
total += 4
|
||||
if msg.content:
|
||||
total += len(encoding.encode(msg.content))
|
||||
total += len(encoding.encode(msg.role))
|
||||
total += 2
|
||||
return total
|
||||
|
||||
|
||||
def select_model_by_length(messages: List[Message]) -> str:
|
||||
"""基于 token 长度选择模型"""
|
||||
def select_model_by_length(messages: List[ChatMessage]) -> str:
|
||||
token_count = estimate_tokens(messages)
|
||||
|
||||
if token_count < ROUTING_THRESHOLDS["simple"]:
|
||||
return DEFAULT_ROUTING["simple"]
|
||||
elif token_count < ROUTING_THRESHOLDS["medium"]:
|
||||
@@ -91,8 +121,51 @@ def select_model_by_length(messages: List[Message]) -> str:
|
||||
return DEFAULT_ROUTING["complex"]
|
||||
|
||||
|
||||
def select_model_by_nvidia_classifier(messages: List[ChatMessage]) -> tuple:
|
||||
"""
|
||||
基于 NVIDIA 多头分类器选择模型
|
||||
Returns: (model_key, routing_detail)
|
||||
"""
|
||||
query = ""
|
||||
for msg in reversed(messages):
|
||||
if msg.role == "user" and msg.content:
|
||||
query = msg.content
|
||||
break
|
||||
|
||||
try:
|
||||
router = get_router()
|
||||
start = time.time()
|
||||
result = router.predict(query)
|
||||
routing_ms = (time.time() - start) * 1000
|
||||
|
||||
model_map = {"simple": "qwen-flash", "medium": "qwen-plus", "complex": "qwen-max"}
|
||||
model_key = model_map[result["tier"]]
|
||||
|
||||
routing_detail = {
|
||||
"method": "nvidia_classifier",
|
||||
"query": query,
|
||||
"routing_latency_ms": round(routing_ms, 2),
|
||||
"tier": result["tier"],
|
||||
"complexity_score": result["complexity_score"],
|
||||
"task_type": result["task_type"],
|
||||
"domain_knowledge": result["domain_knowledge"],
|
||||
"reasoning": result["reasoning"],
|
||||
"creativity": result["creativity"],
|
||||
}
|
||||
return model_key, routing_detail
|
||||
except Exception as e:
|
||||
print(f"NVIDIA routing failed: {e}, falling back to token length")
|
||||
model_key = select_model_by_length(messages)
|
||||
routing_detail = {
|
||||
"method": "fallback_token_length",
|
||||
"query": query,
|
||||
"routing_latency_ms": 0,
|
||||
"error": str(e),
|
||||
}
|
||||
return model_key, routing_detail
|
||||
|
||||
|
||||
def get_provider_model(model_key: str) -> str:
|
||||
"""获取 LiteLLM 格式的模型名称"""
|
||||
config = MODEL_CONFIG.get(model_key)
|
||||
if not config:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown model: {model_key}")
|
||||
@@ -100,161 +173,319 @@ def get_provider_model(model_key: str) -> str:
|
||||
|
||||
|
||||
def calculate_cost(model_key: str, input_tokens: int, output_tokens: int) -> float:
|
||||
"""计算调用成本"""
|
||||
config = MODEL_CONFIG.get(model_key, MODEL_CONFIG["gpt-4o"])
|
||||
input_cost = (input_tokens / 1000) * config["input_cost"]
|
||||
output_cost = (output_tokens / 1000) * config["output_cost"]
|
||||
return input_cost + output_cost
|
||||
return (input_tokens / 1000) * config["input_cost"] + (output_tokens / 1000) * config["output_cost"]
|
||||
|
||||
|
||||
def get_provider_from_model(model_name: str) -> str:
|
||||
"""从模型名称推断提供商"""
|
||||
if model_name.startswith("gpt"):
|
||||
return "openai"
|
||||
elif model_name.startswith("claude"):
|
||||
return "anthropic"
|
||||
elif model_name.startswith("gemini"):
|
||||
return "google"
|
||||
elif "/" in model_name:
|
||||
return model_name.split("/")[0]
|
||||
return "unknown"
|
||||
|
||||
|
||||
def log_call(model: str, provider: str, cost: float, latency_ms: float, tokens: int):
|
||||
"""记录调用历史"""
|
||||
call_history.append({
|
||||
"model": model,
|
||||
"provider": provider,
|
||||
"cost_usd": cost,
|
||||
"latency_ms": latency_ms,
|
||||
"tokens": tokens,
|
||||
def log_call(
|
||||
model: str,
|
||||
cost: float,
|
||||
latency_ms: float,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
messages_raw: List[Dict],
|
||||
response_content: str,
|
||||
response_id: str,
|
||||
routing_detail: Optional[Dict],
|
||||
request_params: Dict,
|
||||
stream: bool = False,
|
||||
):
|
||||
record = {
|
||||
"timestamp": time.time(),
|
||||
})
|
||||
"request": {
|
||||
"messages": messages_raw,
|
||||
"temperature": request_params.get("temperature"),
|
||||
"max_tokens": request_params.get("max_tokens"),
|
||||
"stream": stream,
|
||||
"user_specified_model": request_params.get("user_specified_model"),
|
||||
},
|
||||
"routing": routing_detail,
|
||||
"llm": {
|
||||
"model": model,
|
||||
"response_id": response_id,
|
||||
"response_content": response_content,
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"total_tokens": input_tokens + output_tokens,
|
||||
"cost_usd": cost,
|
||||
"llm_latency_ms": round(latency_ms, 2),
|
||||
},
|
||||
}
|
||||
call_history.append(record)
|
||||
with open(CALL_LOG_FILE, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(record, ensure_ascii=False) + "\n")
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions", response_model=ChatResponse)
|
||||
async def chat_completions(request: ChatRequest):
|
||||
def build_openai_response(
|
||||
response_id: str,
|
||||
model: str,
|
||||
content: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
routing_detail: Optional[Dict] = None,
|
||||
) -> Dict:
|
||||
"""构建 OpenAI 格式的非流式响应"""
|
||||
resp = {
|
||||
"id": response_id,
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": content},
|
||||
"finish_reason": "stop",
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": input_tokens,
|
||||
"completion_tokens": output_tokens,
|
||||
"total_tokens": input_tokens + output_tokens,
|
||||
},
|
||||
}
|
||||
# 路由细节作为扩展字段
|
||||
if routing_detail:
|
||||
resp["routing"] = routing_detail
|
||||
return resp
|
||||
|
||||
|
||||
# ── 核心 API: /v1/chat/completions ──────────────────────────
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(request: ChatCompletionRequest):
|
||||
"""
|
||||
聊天完成接口
|
||||
如果 request.model 未指定,则根据 token 长度自动路由
|
||||
OpenAI 兼容的 Chat Completions API
|
||||
- model 为空时自动使用 NVIDIA 分类器路由
|
||||
- stream=true 时返回 SSE 流式响应
|
||||
"""
|
||||
# 选择模型
|
||||
# 1. 路由决策
|
||||
routing_detail = None
|
||||
if request.model:
|
||||
model_key = request.model
|
||||
routing_detail = {
|
||||
"method": "user_specified",
|
||||
"query": next((m.content for m in reversed(request.messages) if m.role == "user" and m.content), ""),
|
||||
}
|
||||
else:
|
||||
model_key = select_model_by_length(request.messages)
|
||||
model_key, routing_detail = select_model_by_nvidia_classifier(request.messages)
|
||||
|
||||
# 获取 LiteLLM 模型名称
|
||||
provider_model = get_provider_model(model_key)
|
||||
provider = get_provider_from_model(provider_model)
|
||||
messages_raw = [{"role": m.role, "content": m.content} for m in request.messages]
|
||||
response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
|
||||
|
||||
# 2. 流式响应
|
||||
if request.stream:
|
||||
return StreamingResponse(
|
||||
_stream_response(
|
||||
provider_model=provider_model,
|
||||
model_key=model_key,
|
||||
messages_raw=messages_raw,
|
||||
request=request,
|
||||
response_id=response_id,
|
||||
routing_detail=routing_detail,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
# 3. 非流式响应
|
||||
start_time = time.time()
|
||||
|
||||
# 构建请求参数(过滤掉 None 和 0 的 max_tokens)
|
||||
completion_kwargs = {
|
||||
"model": provider_model,
|
||||
"messages": messages_raw,
|
||||
"temperature": request.temperature,
|
||||
}
|
||||
if request.max_tokens and request.max_tokens > 0:
|
||||
completion_kwargs["max_tokens"] = request.max_tokens
|
||||
|
||||
try:
|
||||
# 使用 LiteLLM 统一调用
|
||||
response = await acompletion(
|
||||
model=provider_model,
|
||||
messages=[{"role": m.role, "content": m.content} for m in request.messages],
|
||||
temperature=request.temperature,
|
||||
max_tokens=request.max_tokens,
|
||||
)
|
||||
|
||||
response = await acompletion(**completion_kwargs)
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
|
||||
# 计算成本
|
||||
input_tokens = response.usage.prompt_tokens
|
||||
output_tokens = response.usage.completion_tokens
|
||||
cost = calculate_cost(model_key, input_tokens, output_tokens)
|
||||
content = response.choices[0].message.content
|
||||
|
||||
# 记录调用
|
||||
log_call(model_key, provider, cost, latency_ms, input_tokens + output_tokens)
|
||||
|
||||
return ChatResponse(
|
||||
id=response.id,
|
||||
model=model_key,
|
||||
provider=provider,
|
||||
content=response.choices[0].message.content,
|
||||
usage={
|
||||
"prompt_tokens": input_tokens,
|
||||
"completion_tokens": output_tokens,
|
||||
"total_tokens": input_tokens + output_tokens,
|
||||
log_call(
|
||||
model=model_key, cost=cost, latency_ms=latency_ms,
|
||||
input_tokens=input_tokens, output_tokens=output_tokens,
|
||||
messages_raw=messages_raw, response_content=content,
|
||||
response_id=response_id, routing_detail=routing_detail,
|
||||
request_params={
|
||||
"temperature": request.temperature,
|
||||
"max_tokens": request.max_tokens,
|
||||
"user_specified_model": request.model,
|
||||
},
|
||||
cost_usd=cost,
|
||||
latency_ms=round(latency_ms, 2),
|
||||
)
|
||||
|
||||
return build_openai_response(response_id, model_key, content, input_tokens, output_tokens, routing_detail)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"API error: {str(e)}")
|
||||
|
||||
|
||||
@app.get("/models")
|
||||
async def list_models():
|
||||
"""列出支持的模型"""
|
||||
return {
|
||||
"models": [
|
||||
{
|
||||
"key": key,
|
||||
"provider": config["provider"],
|
||||
"input_cost_per_1k": config["input_cost"],
|
||||
"output_cost_per_1k": config["output_cost"],
|
||||
async def _stream_response(
|
||||
provider_model: str,
|
||||
model_key: str,
|
||||
messages_raw: List[Dict],
|
||||
request: ChatCompletionRequest,
|
||||
response_id: str,
|
||||
routing_detail: Optional[Dict],
|
||||
):
|
||||
"""生成 SSE 流式响应"""
|
||||
start_time = time.time()
|
||||
collected_content = ""
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
|
||||
try:
|
||||
# 构建请求参数(过滤掉 None 和 0 的 max_tokens)
|
||||
completion_kwargs = {
|
||||
"model": provider_model,
|
||||
"messages": messages_raw,
|
||||
"temperature": request.temperature,
|
||||
"stream": True,
|
||||
}
|
||||
if request.max_tokens and request.max_tokens > 0:
|
||||
completion_kwargs["max_tokens"] = request.max_tokens
|
||||
|
||||
response = await acompletion(**completion_kwargs)
|
||||
|
||||
async for chunk in response:
|
||||
delta = chunk.choices[0].delta
|
||||
|
||||
# 收集内容用于日志
|
||||
if delta.content:
|
||||
collected_content += delta.content
|
||||
output_tokens += 1 # 近似计数
|
||||
|
||||
# 构建 SSE 数据
|
||||
chunk_data = {
|
||||
"id": response_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model_key,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {},
|
||||
"finish_reason": None,
|
||||
}],
|
||||
}
|
||||
for key, config in MODEL_CONFIG.items()
|
||||
]
|
||||
|
||||
if delta.content:
|
||||
chunk_data["choices"][0]["delta"] = {"content": delta.content}
|
||||
elif delta.role:
|
||||
chunk_data["choices"][0]["delta"] = {"role": delta.role}
|
||||
|
||||
if chunk.choices[0].finish_reason:
|
||||
chunk_data["choices"][0]["finish_reason"] = chunk.choices[0].finish_reason
|
||||
|
||||
yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 发送 [DONE]
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
# 记录日志
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
# 流式模式下用 tiktoken 近似计算 input_tokens
|
||||
try:
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
input_tokens = sum(len(encoding.encode(m.get("content", "") or "")) for m in messages_raw) + len(messages_raw) * 4
|
||||
except Exception:
|
||||
input_tokens = 0
|
||||
|
||||
cost = calculate_cost(model_key, input_tokens, output_tokens)
|
||||
log_call(
|
||||
model=model_key, cost=cost, latency_ms=latency_ms,
|
||||
input_tokens=input_tokens, output_tokens=output_tokens,
|
||||
messages_raw=messages_raw, response_content=collected_content,
|
||||
response_id=response_id, routing_detail=routing_detail,
|
||||
request_params={
|
||||
"temperature": request.temperature,
|
||||
"max_tokens": request.max_tokens,
|
||||
"user_specified_model": request.model,
|
||||
},
|
||||
stream=True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_data = {"error": {"message": str(e), "type": "api_error"}}
|
||||
yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
# ── OpenAI 兼容: /v1/models ─────────────────────────────────
|
||||
@app.get("/v1/models")
|
||||
async def list_models():
|
||||
"""OpenAI 兼容的模型列表接口"""
|
||||
data = []
|
||||
for key, config in MODEL_CONFIG.items():
|
||||
data.append({
|
||||
"id": key,
|
||||
"object": "model",
|
||||
"created": 1700000000,
|
||||
"owned_by": config["provider"].split("/")[0] if "/" in config["provider"] else "unknown",
|
||||
})
|
||||
return {"object": "list", "data": data}
|
||||
|
||||
|
||||
# ── 管理接口 ─────────────────────────────────────────────────
|
||||
@app.get("/stats")
|
||||
async def get_stats():
|
||||
"""获取调用统计摘要"""
|
||||
if not call_history:
|
||||
return {
|
||||
"total_calls": 0, "total_cost_usd": 0.0,
|
||||
"avg_latency_ms": 0.0, "model_distribution": {},
|
||||
"tier_distribution": {}, "task_type_distribution": {},
|
||||
}
|
||||
|
||||
total_calls = len(call_history)
|
||||
total_cost = sum(c["llm"]["cost_usd"] for c in call_history)
|
||||
avg_latency = sum(c["llm"]["llm_latency_ms"] for c in call_history) / total_calls
|
||||
|
||||
model_dist: Dict[str, int] = {}
|
||||
tier_dist: Dict[str, int] = {}
|
||||
task_dist: Dict[str, int] = {}
|
||||
|
||||
for call in call_history:
|
||||
model = call["llm"]["model"]
|
||||
model_dist[model] = model_dist.get(model, 0) + 1
|
||||
routing = call.get("routing") or {}
|
||||
if routing.get("tier"):
|
||||
tier_dist[routing["tier"]] = tier_dist.get(routing["tier"], 0) + 1
|
||||
if routing.get("task_type"):
|
||||
task_dist[routing["task_type"]] = task_dist.get(routing["task_type"], 0) + 1
|
||||
|
||||
return {
|
||||
"total_calls": total_calls,
|
||||
"total_cost_usd": round(total_cost, 6),
|
||||
"avg_latency_ms": round(avg_latency, 2),
|
||||
"avg_routing_ms": round(
|
||||
sum(c.get("routing", {}).get("routing_latency_ms", 0) for c in call_history) / total_calls, 2
|
||||
),
|
||||
"model_distribution": model_dist,
|
||||
"tier_distribution": tier_dist,
|
||||
"task_type_distribution": task_dist,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/stats", response_model=StatsResponse)
|
||||
async def get_stats():
|
||||
"""获取调用统计"""
|
||||
if not call_history:
|
||||
return StatsResponse(
|
||||
total_calls=0,
|
||||
total_cost_usd=0.0,
|
||||
avg_latency_ms=0.0,
|
||||
model_distribution={},
|
||||
provider_distribution={},
|
||||
recent_calls=[],
|
||||
)
|
||||
|
||||
total_calls = len(call_history)
|
||||
total_cost = sum(c["cost_usd"] for c in call_history)
|
||||
avg_latency = sum(c["latency_ms"] for c in call_history) / total_calls
|
||||
|
||||
# 模型分布
|
||||
model_dist: Dict[str, int] = {}
|
||||
provider_dist: Dict[str, int] = {}
|
||||
for call in call_history:
|
||||
model = call["model"]
|
||||
provider = call["provider"]
|
||||
model_dist[model] = model_dist.get(model, 0) + 1
|
||||
provider_dist[provider] = provider_dist.get(provider, 0) + 1
|
||||
|
||||
# 最近 10 条记录
|
||||
recent = [
|
||||
{
|
||||
"model": c["model"],
|
||||
"provider": c["provider"],
|
||||
"cost_usd": round(c["cost_usd"], 6),
|
||||
"latency_ms": round(c["latency_ms"], 2),
|
||||
"tokens": c["tokens"],
|
||||
}
|
||||
for c in call_history[-10:]
|
||||
]
|
||||
|
||||
return StatsResponse(
|
||||
total_calls=total_calls,
|
||||
total_cost_usd=round(total_cost, 6),
|
||||
avg_latency_ms=round(avg_latency, 2),
|
||||
model_distribution=model_dist,
|
||||
provider_distribution=provider_dist,
|
||||
recent_calls=recent,
|
||||
)
|
||||
@app.get("/stats/raw")
|
||||
async def get_stats_raw(limit: int = 50, offset: int = 0):
|
||||
"""获取原始调用记录"""
|
||||
total = len(call_history)
|
||||
records = list(reversed(call_history))
|
||||
return {"total": total, "limit": limit, "offset": offset, "records": records[offset:offset + limit]}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""健康检查"""
|
||||
return {"status": "healthy", "version": "0.2.0"}
|
||||
return {"status": "healthy", "version": "0.4.0", "router": "llm-compass"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
343
nvidia_router.py
Normal file
343
nvidia_router.py
Normal file
@@ -0,0 +1,343 @@
|
||||
"""
|
||||
NVIDIA Prompt Task & Complexity Classifier Router
|
||||
手动加载自定义多头模型,支持3-tier路由
|
||||
|
||||
模型: nvidia/prompt-task-and-complexity-classifier (184M参数)
|
||||
架构: DeBERTa-v3-base backbone + 8个分类头
|
||||
输出: task_type(12类), creativity(3类), reasoning(2类),
|
||||
domain_knowledge(4类), complexity_score 等多维度
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoTokenizer, DebertaV2Model, DebertaV2Config
|
||||
from safetensors.torch import load_file
|
||||
from huggingface_hub import hf_hub_download
|
||||
from typing import Dict, Optional
|
||||
import logging
|
||||
import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClassificationHead(nn.Module):
|
||||
"""单个分类头"""
|
||||
def __init__(self, input_dim: int, num_classes: int, dropout: float = 0.2):
|
||||
super().__init__()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.fc = nn.Linear(input_dim, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.dropout(x)
|
||||
return self.fc(x)
|
||||
|
||||
|
||||
class NvidiaMultiHeadClassifier(nn.Module):
|
||||
"""
|
||||
NVIDIA 多头分类器
|
||||
DeBERTa backbone + 8个独立分类头
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
# DeBERTa backbone
|
||||
self.backbone = DebertaV2Model.from_pretrained(
|
||||
config.base_model,
|
||||
ignore_mismatched_sizes=True,
|
||||
use_safetensors=True
|
||||
)
|
||||
|
||||
hidden_size = 768 # DeBERTa-v3-base
|
||||
dropout = config.fc_dropout if hasattr(config, 'fc_dropout') else 0.2
|
||||
|
||||
# 8个分类头 (与 state_dict 中的 head_0 ~ head_7 对应)
|
||||
target_sizes = config.target_sizes
|
||||
self.head_0 = ClassificationHead(hidden_size, target_sizes["task_type"], dropout) # 12类
|
||||
self.head_1 = ClassificationHead(hidden_size, target_sizes["creativity_scope"], dropout) # 3类
|
||||
self.head_2 = ClassificationHead(hidden_size, target_sizes["reasoning"], dropout) # 2类
|
||||
self.head_3 = ClassificationHead(hidden_size, target_sizes["contextual_knowledge"], dropout) # 2类
|
||||
self.head_4 = ClassificationHead(hidden_size, target_sizes["number_of_few_shots"], dropout) # 6类
|
||||
self.head_5 = ClassificationHead(hidden_size, target_sizes["domain_knowledge"], dropout) # 4类
|
||||
self.head_6 = ClassificationHead(hidden_size, target_sizes["no_label_reason"], dropout) # 1类
|
||||
self.head_7 = ClassificationHead(hidden_size, target_sizes["constraint_ct"], dropout) # 2类
|
||||
|
||||
# Head 名称映射
|
||||
self.head_names = [
|
||||
"task_type", # head_0: 12类
|
||||
"creativity_scope", # head_1: 3类
|
||||
"reasoning", # head_2: 2类
|
||||
"contextual_knowledge", # head_3: 2类
|
||||
"number_of_few_shots", # head_4: 6类
|
||||
"domain_knowledge", # head_5: 4类
|
||||
"no_label_reason", # head_6: 1类
|
||||
"constraint_ct", # head_7: 2类
|
||||
]
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None):
|
||||
outputs = self.backbone(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids
|
||||
)
|
||||
# 使用 [CLS] token 的隐层
|
||||
cls_output = outputs.last_hidden_state[:, 0]
|
||||
|
||||
# 各头输出
|
||||
head_outputs = {
|
||||
"task_type": self.head_0(cls_output),
|
||||
"creativity_scope": self.head_1(cls_output),
|
||||
"reasoning": self.head_2(cls_output),
|
||||
"contextual_knowledge": self.head_3(cls_output),
|
||||
"number_of_few_shots": self.head_4(cls_output),
|
||||
"domain_knowledge": self.head_5(cls_output),
|
||||
"no_label_reason": self.head_6(cls_output),
|
||||
"constraint_ct": self.head_7(cls_output),
|
||||
}
|
||||
return head_outputs
|
||||
|
||||
|
||||
class NvidiaComplexityRouter:
|
||||
"""NVIDIA 多头分类器路由封装"""
|
||||
|
||||
MODEL_NAME = "nvidia/prompt-task-and-complexity-classifier"
|
||||
|
||||
# Task type 映射
|
||||
TASK_TYPE_MAP = {
|
||||
0: "Brainstorming", 1: "Chatbot", 2: "Classification",
|
||||
3: "Closed QA", 4: "Code Generation", 5: "Extraction",
|
||||
6: "Open QA", 7: "Other", 8: "Rewrite",
|
||||
9: "Summarization", 10: "Text Generation", 11: "Unknown"
|
||||
}
|
||||
|
||||
# Domain knowledge 映射
|
||||
DOMAIN_MAP = {0: "High", 1: "Low", 2: "Medium", 3: "No"}
|
||||
|
||||
# Creativity 映射
|
||||
CREATIVITY_MAP = {0: "High", 1: "Low", 2: "No"}
|
||||
|
||||
def __init__(self, device: str = "auto"):
|
||||
if device == "auto":
|
||||
if torch.backends.mps.is_available():
|
||||
device = "mps"
|
||||
logger.info("MPS (Metal GPU) detected, using MPS acceleration")
|
||||
elif torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
else:
|
||||
device = "cpu"
|
||||
logger.info("No GPU detected, using CPU")
|
||||
self.device = device
|
||||
self.tokenizer = None
|
||||
self.model = None
|
||||
self.config = None
|
||||
self._initialized = False
|
||||
|
||||
def initialize(self):
|
||||
"""延迟加载模型"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
logger.info(f"Loading NVIDIA classifier: {self.MODEL_NAME}")
|
||||
|
||||
# 1. 手动加载自定义 config.json(该模型无 model_type,AutoConfig 不兼容)
|
||||
config_path = hf_hub_download(self.MODEL_NAME, "config.json")
|
||||
with open(config_path, "r") as f:
|
||||
custom_config = json.load(f)
|
||||
|
||||
# 构建 backbone 的 DeBERTa config(从 base_model 加载)
|
||||
base_model = custom_config.get("base_model", "microsoft/DeBERTa-v3-base")
|
||||
self.config = DebertaV2Config.from_pretrained(base_model)
|
||||
# 保存自定义分类头参数
|
||||
self.config.target_sizes = custom_config["target_sizes"]
|
||||
self.config.fc_dropout = custom_config.get("fc_dropout", 0.2)
|
||||
self.config.base_model = base_model
|
||||
|
||||
# 2. 加载 tokenizer (slow模式,兼容性好)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME, use_fast=False)
|
||||
|
||||
# 3. 构建模型并加载权重
|
||||
self.model = NvidiaMultiHeadClassifier(self.config)
|
||||
|
||||
model_path = hf_hub_download(self.MODEL_NAME, "model.safetensors")
|
||||
state_dict = load_file(model_path)
|
||||
self.model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
self.model.to(self.device)
|
||||
# MPS 需要 float16 以避免矩阵乘法数据类型冲突
|
||||
if self.device == "mps":
|
||||
self.model.half()
|
||||
self.model.eval()
|
||||
self._initialized = True
|
||||
dtype = "float16" if self.device == "mps" else "float32"
|
||||
logger.info(f"NVIDIA classifier loaded successfully on {self.device} ({dtype})")
|
||||
|
||||
def predict(self, query: str) -> Dict:
|
||||
"""
|
||||
预测查询的多维度特征
|
||||
|
||||
Returns:
|
||||
{
|
||||
"tier": "simple" | "medium" | "complex",
|
||||
"complexity_score": float (0-1),
|
||||
"task_type": str,
|
||||
"domain_knowledge": str,
|
||||
"reasoning": bool,
|
||||
"creativity": str
|
||||
}
|
||||
"""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
|
||||
inputs = self.tokenizer(
|
||||
query, return_tensors="pt", truncation=True, max_length=512, padding=True
|
||||
)
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs)
|
||||
|
||||
# 解析各头输出
|
||||
task_type_idx = torch.argmax(outputs["task_type"], dim=-1).item()
|
||||
task_type = self.TASK_TYPE_MAP.get(task_type_idx, "Unknown")
|
||||
|
||||
domain_idx = torch.argmax(outputs["domain_knowledge"], dim=-1).item()
|
||||
domain = self.DOMAIN_MAP.get(domain_idx, "Unknown")
|
||||
|
||||
creativity_idx = torch.argmax(outputs["creativity_scope"], dim=-1).item()
|
||||
creativity = self.CREATIVITY_MAP.get(creativity_idx, "Unknown")
|
||||
|
||||
reasoning_idx = torch.argmax(outputs["reasoning"], dim=-1).item()
|
||||
needs_reasoning = reasoning_idx == 1
|
||||
|
||||
# 计算综合复杂度评分 (0-1)
|
||||
complexity_score = self._compute_complexity_score(
|
||||
domain=domain,
|
||||
creativity=creativity,
|
||||
needs_reasoning=needs_reasoning,
|
||||
task_type=task_type
|
||||
)
|
||||
|
||||
tier = self._score_to_tier(complexity_score)
|
||||
|
||||
return {
|
||||
"tier": tier,
|
||||
"complexity_score": complexity_score,
|
||||
"task_type": task_type,
|
||||
"domain_knowledge": domain,
|
||||
"reasoning": needs_reasoning,
|
||||
"creativity": creativity,
|
||||
}
|
||||
|
||||
def _compute_complexity_score(self, domain, creativity, needs_reasoning, task_type) -> float:
|
||||
"""
|
||||
综合多维度计算复杂度评分 (0-1)
|
||||
|
||||
权重:
|
||||
- domain_knowledge: 40% (High=1.0, Medium=0.6, Low=0.3, No=0.0)
|
||||
- reasoning: 30% (Yes=1.0, No=0.0)
|
||||
- creativity: 20% (High=1.0, Low=0.4, No=0.0)
|
||||
- task_type: 10% (Code=0.8, QA=0.5, Chatbot=0.2, ...)
|
||||
"""
|
||||
domain_scores = {"High": 1.0, "Medium": 0.6, "Low": 0.3, "No": 0.0}
|
||||
creativity_scores = {"High": 1.0, "Low": 0.4, "No": 0.0}
|
||||
task_complexity = {
|
||||
"Code Generation": 0.8, "Text Generation": 0.7,
|
||||
"Summarization": 0.6, "Rewrite": 0.5,
|
||||
"Open QA": 0.5, "Closed QA": 0.4,
|
||||
"Classification": 0.3, "Extraction": 0.3,
|
||||
"Brainstorming": 0.6, "Chatbot": 0.2,
|
||||
"Other": 0.5, "Unknown": 0.5,
|
||||
}
|
||||
|
||||
score = (
|
||||
0.4 * domain_scores.get(domain, 0.5) +
|
||||
0.3 * (1.0 if needs_reasoning else 0.0) +
|
||||
0.2 * creativity_scores.get(creativity, 0.5) +
|
||||
0.1 * task_complexity.get(task_type, 0.5)
|
||||
)
|
||||
return round(score, 3)
|
||||
|
||||
def _score_to_tier(self, score: float) -> str:
|
||||
if score < 0.35:
|
||||
return "simple"
|
||||
elif score < 0.65:
|
||||
return "medium"
|
||||
else:
|
||||
return "complex"
|
||||
|
||||
def select_model(self, query: str) -> str:
|
||||
"""直接返回推荐的模型名称"""
|
||||
result = self.predict(query)
|
||||
model_map = {
|
||||
"simple": "qwen-flash",
|
||||
"medium": "qwen-plus",
|
||||
"complex": "qwen-max"
|
||||
}
|
||||
return model_map[result["tier"]]
|
||||
|
||||
def benchmark(self, queries: list) -> Dict:
|
||||
"""批量测试"""
|
||||
import time
|
||||
results = []
|
||||
for query in queries:
|
||||
start = time.time()
|
||||
result = self.predict(query)
|
||||
elapsed = (time.time() - start) * 1000
|
||||
results.append({
|
||||
"query": query[:50],
|
||||
"tier": result["tier"],
|
||||
"score": result["complexity_score"],
|
||||
"task": result["task_type"],
|
||||
"domain": result["domain_knowledge"],
|
||||
"reasoning": result["reasoning"],
|
||||
"time_ms": round(elapsed, 1)
|
||||
})
|
||||
|
||||
times = [r["time_ms"] for r in results]
|
||||
return {
|
||||
"avg_ms": round(sum(times) / len(times), 1),
|
||||
"results": results
|
||||
}
|
||||
|
||||
|
||||
# 全局单例
|
||||
_router_instance: Optional[NvidiaComplexityRouter] = None
|
||||
|
||||
def get_nvidia_router() -> NvidiaComplexityRouter:
|
||||
global _router_instance
|
||||
if _router_instance is None:
|
||||
_router_instance = NvidiaComplexityRouter()
|
||||
return _router_instance
|
||||
|
||||
def select_model_by_nvidia(query: str) -> str:
|
||||
return get_nvidia_router().select_model(query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_queries = [
|
||||
"你好",
|
||||
"What is 2+2?",
|
||||
"Explain quantum computing principles in detail",
|
||||
"Write a quicksort algorithm in Python with error handling",
|
||||
"Analyze this 10-page research paper and summarize the key innovations",
|
||||
"Rewrite this sentence to be more concise",
|
||||
"Generate a creative story about a robot",
|
||||
]
|
||||
|
||||
router = NvidiaComplexityRouter()
|
||||
|
||||
print("=" * 80)
|
||||
print("NVIDIA Prompt Task & Complexity Classifier - 3-Tier Router Test")
|
||||
print("=" * 80)
|
||||
|
||||
for query in test_queries:
|
||||
result = router.predict(query)
|
||||
model = router.select_model(query)
|
||||
print(f"\nQuery: {query}")
|
||||
print(f" Tier: {result['tier']}")
|
||||
print(f" Score: {result['complexity_score']}")
|
||||
print(f" Task: {result['task_type']}")
|
||||
print(f" Domain: {result['domain_knowledge']}")
|
||||
print(f" Reasoning: {result['reasoning']}")
|
||||
print(f" Creativity: {result['creativity']}")
|
||||
print(f" -> Model: {model}")
|
||||
17
requirements.lock.txt
Normal file
17
requirements.lock.txt
Normal file
@@ -0,0 +1,17 @@
|
||||
# LLM Compass - Docker 锁定依赖 (CPU)
|
||||
# 使用 CPU 版 PyTorch,大幅减小镜像体积
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
fastapi==0.136.0
|
||||
uvicorn[standard]==0.44.0
|
||||
pydantic==2.12.5
|
||||
litellm==1.83.9
|
||||
tiktoken==0.12.0
|
||||
httpx==0.28.1
|
||||
python-dotenv==1.0.1
|
||||
torch==2.2.2+cpu
|
||||
transformers==4.57.6
|
||||
tokenizers==0.22.2
|
||||
safetensors==0.4.3
|
||||
numpy==1.26.4
|
||||
sentencepiece==0.2.1
|
||||
huggingface_hub>=0.28.0
|
||||
@@ -5,5 +5,9 @@ litellm>=1.0.0
|
||||
tiktoken>=0.5.0
|
||||
httpx>=0.25.0
|
||||
python-dotenv>=1.0.0
|
||||
transformers>=4.30.0
|
||||
torch>=2.0.0
|
||||
# NVIDIA Multi-head Classifier for 3-tier routing
|
||||
# nvidia/prompt-task-and-complexity-classifier will be loaded via transformers
|
||||
pytest>=7.4.0
|
||||
pytest-asyncio>=0.21.0
|
||||
|
||||
48
start.sh
Executable file
48
start.sh
Executable file
@@ -0,0 +1,48 @@
|
||||
#!/bin/bash
|
||||
# LLM Compass 启动脚本
|
||||
# 用法: ./start.sh [端口号]
|
||||
|
||||
set -e
|
||||
|
||||
# 默认端口
|
||||
PORT=${1:-8402}
|
||||
|
||||
# 获取脚本所在目录
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
# 检查虚拟环境
|
||||
if [ ! -d ".venv" ]; then
|
||||
echo "❌ 虚拟环境不存在,请先运行:"
|
||||
echo " python3 -m venv .venv"
|
||||
echo " .venv/bin/pip install -r requirements.txt"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 检查 .env 文件
|
||||
if [ ! -f ".env" ]; then
|
||||
echo "❌ .env 文件不存在,请创建并配置 API Key"
|
||||
echo " cp .env.example .env"
|
||||
echo " 编辑 .env 填入 DASHSCOPE_API_KEY"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 加载环境变量
|
||||
export DASHSCOPE_API_KEY=$(grep DASHSCOPE_API_KEY .env | cut -d= -f2)
|
||||
|
||||
if [ -z "$DASHSCOPE_API_KEY" ]; then
|
||||
echo "❌ DASHSCOPE_API_KEY 未设置,请检查 .env 文件"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "🚀 启动 LLM Compass 服务..."
|
||||
echo "📍 地址: http://localhost:${PORT}"
|
||||
echo "📖 API 文档: http://localhost:${PORT}/docs"
|
||||
echo "🔧 路由方式: NVIDIA MPS 加速 (M4 Pro GPU)"
|
||||
echo ""
|
||||
|
||||
# 启动服务
|
||||
exec .venv/bin/python -m uvicorn main:app \
|
||||
--host 0.0.0.0 \
|
||||
--port "$PORT" \
|
||||
--log-level info
|
||||
Reference in New Issue
Block a user