diff --git a/docs/llm-router-open-source-research.md b/docs/llm-router-open-source-research.md index 7dc16c4..338b6f2 100644 --- a/docs/llm-router-open-source-research.md +++ b/docs/llm-router-open-source-research.md @@ -2,7 +2,8 @@ > **调研日期**: 2026-04-17 > **调研目的**: 寻找可替代 tx402 BERT 路由器的开源方案 -> **报告版本**: v1.0 +> **报告版本**: v2.0 +> **最新更新**: 技术选型已从 RouteLLM BERT 切换至 NVIDIA 多头分类器 --- @@ -12,27 +13,144 @@ 当前开源 LLM 路由模型生态已较为成熟,主要方案包括: -| 方案 | 准确率 | 延迟 | 成本降低 | 推荐指数 | +| 方案 | 准确率 | 延迟 | 路由能力 | 推荐指数 | |------|--------|------|---------|---------| -| **RouteLLM BERT** | 85-92% | 1-5ms | 85% | ⭐⭐⭐⭐⭐ | -| **Arch-Router 1.5B** | 93% | 50-100ms | - | ⭐⭐⭐⭐ | -| **RoRF (Random Forest)** | - | - | - | ⭐⭐⭐ | +| **NVIDIA Multi-Head Classifier** ⭐ 已采用 | ~90% | 5-15ms | 多维度 3-tier | ⭐⭐⭐⭐⭐ | +| **RouteLLM BERT** (已弃用) | 85-92% | 1-5ms | 二分类 (强/弱) | ⭐⭐⭐ | +| **Arch-Router 1.5B** | 93% | 50-100ms | 动态多策略 | ⭐⭐⭐⭐ | +| **RoRF (Random Forest)** | - | - | Pairwise | ⭐⭐⭐ | -**关键洞察**: RouteLLM BERT 是现阶段最成熟的方案,已在生产环境验证,社区支持完善。 +**关键决策**: NVIDIA `prompt-task-and-complexity-classifier` 是最终选型方案。相比 RouteLLM BERT 的二分类局限,NVIDIA 多头分类器提供 8 个维度的分析能力,支持 3-tier 路由(simple/medium/complex),更接近 tx402.ai 的生产实现。 + +### 选型变更记录 + +| 版本 | 日期 | 选型 | 变更原因 | +|------|------|------|---------| +| v0.1 | 2026-04-17 | Token 长度规则路由 | 初始 MVP | +| v0.2 | 2026-04-17 | RouteLLM BERT | 引入 ML 路由 | +| **v0.3** | **2026-04-17** | **NVIDIA Multi-Head** | **支持 3-tier,多维度分析** | --- -## 1. 主流开源路由方案详解 +## 1. 当前选型: NVIDIA Multi-Head Classifier -### 1.1 RouteLLM (LMSYS/UC Berkeley) +### 1.1 项目信息 + +- **模型**: [nvidia/prompt-task-and-complexity-classifier](https://huggingface.co/nvidia/prompt-task-and-complexity-classifier) +- **机构**: NVIDIA +- **参数量**: 184M +- **架构**: DeBERTa-v3-base backbone + 8 个独立分类头 +- **许可**: Apache 2.0 + +### 1.2 技术架构 + +``` +┌─────────────────────────────────────────────────────────┐ +│ NVIDIA Multi-Head Classifier (184M) │ +├─────────────────────────────────────────────────────────┤ +│ Backbone: DeBERTa-v3-base (768维隐层) │ +├─────────────────────────────────────────────────────────┤ +│ 8 个分类头: │ +│ ├─ head_0: task_type (12类) │ +│ ├─ head_1: creativity_scope (3类: High/Low/No) │ +│ ├─ head_2: reasoning (2类: Yes/No) │ +│ ├─ head_3: contextual_knowledge (2类) │ +│ ├─ head_4: number_of_few_shots (6类) │ +│ ├─ head_5: domain_knowledge (4类: High/Medium/Low/No) │ +│ ├─ head_6: no_label_reason (1类) │ +│ └─ head_7: constraint_ct (2类) │ +├─────────────────────────────────────────────────────────┤ +│ 综合评分 → 3-Tier 路由: │ +│ simple (<0.35) → qwen-flash │ +│ medium (0.35-0.65) → qwen-plus │ +│ complex (>0.65) → qwen-max │ +└─────────────────────────────────────────────────────────┘ +``` + +### 1.3 复杂度评分公式 + +```python +score = ( + 0.4 * domain_knowledge + # High=1.0, Medium=0.6, Low=0.3, No=0.0 + 0.3 * reasoning + # Yes=1.0, No=0.0 + 0.2 * creativity + # High=1.0, Low=0.4, No=0.0 + 0.1 * task_type # Code=0.8, QA=0.5, Chatbot=0.2, ... +) +``` + +### 1.4 Task Type 分类 (12类) + +| ID | 类型 | 复杂度权重 | +|----|------|-----------| +| 0 | Brainstorming | 0.6 | +| 1 | Chatbot | 0.2 | +| 2 | Classification | 0.3 | +| 3 | Closed QA | 0.4 | +| 4 | Code Generation | 0.8 | +| 5 | Extraction | 0.3 | +| 6 | Open QA | 0.5 | +| 7 | Other | 0.5 | +| 8 | Rewrite | 0.5 | +| 9 | Summarization | 0.6 | +| 10 | Text Generation | 0.7 | +| 11 | Unknown | 0.5 | + +### 1.5 测试结果 + +| 查询 | Tier | Score | Task | Model | +|------|------|-------|------|-------| +| "你好" | simple | 0.17 | Chatbot | qwen-flash | +| "What is 2+2?" | simple | 0.17 | Chatbot | qwen-flash | +| "Write quicksort in Python" | medium | 0.45 | Code Generation | qwen-plus | +| "Analyze 10-page paper" | medium | 0.47 | Summarization | qwen-plus | + +### 1.6 依赖版本 (已验证可用) + +``` +torch==2.2.2 +transformers==4.44.2 +tokenizers==0.19.1 +safetensors==0.4.3 +numpy==1.26.4 +sentencepiece==0.2.1 +``` + +> **注意**: 该模型使用自定义多头架构,无法通过 `AutoModelForSequenceClassification` 直接加载,需手动构建模型并用 `safetensors.torch.load_file` 加载权重。Tokenizer 需使用 slow 模式 (`use_fast=False`)。 + +### 1.7 优势 + +- ✅ 多维度分析(task_type/reasoning/creativity/domain 等 8 个维度) +- ✅ 原生支持 3-tier 路由,不限于二分类 +- ✅ DeBERTa 架构,语义理解能力优于 BERT +- ✅ NVIDIA 出品,模型质量有保障 +- ✅ CPU 可运行,延迟 5-15ms + +### 1.8 劣势 + +- ⚠️ 自定义架构,不能直接用 HuggingFace AutoModel 加载 +- ⚠️ 依赖版本要求较严格(transformers/tokenizers/torch 需要特定组合) +- ⚠️ 对 reasoning/creativity 判断偏保守,评分权重可能需要根据业务调优 +- ⚠️ 与 LiteLLM 存在依赖冲突(tokenizers 版本),已移除 LiteLLM + +--- + +## 2. 已弃用方案: RouteLLM BERT + +### 2.1 弃用原因 + +1. **仅支持二分类** (strong/weak),无法实现 3-tier 路由 +2. 中间模型(如 qwen-plus)永远不会被选中 +3. 无法提供查询的多维度分析(task type/domain/reasoning 等) +4. 与 tx402.ai 的三层架构差距过大 + +### 2.2 项目信息 -**项目信息** - **论文**: [RouteLLM: Learning to Route LLMs with Preference Data](https://arxiv.org/abs/2406.18665) - **代码**: https://github.com/lm-sys/RouteLLM - **机构**: LMSYS, UC Berkeley - **发布时间**: 2024年7月 -**技术架构** +### 2.3 技术架构 RouteLLM 提供三种路由器实现: @@ -48,14 +166,22 @@ RouteLLM 提供三种路由器实现: │ - 矩阵分解学习查询-模型评分函数 │ │ - 论文报告最佳性能 │ ├─────────────────────────────────────────────────────────┤ -│ 3. BERT Classifier ⭐ 推荐 │ +│ 3. BERT Classifier │ │ - 基于 BERT 的二分类器 │ │ - 预测强模型 vs 弱模型 │ │ - 延迟: 1-5ms (CPU) │ └─────────────────────────────────────────────────────────┘ ``` -**性能指标** +### 2.4 模型规格 + +- **基础模型**: BERT-base-uncased +- **参数量**: ~110M +- **输入长度**: 512 tokens +- **输出**: 二分类 (0=弱模型, 1=强模型) +- **推理延迟**: 1-5ms (CPU) + +### 2.5 性能指标 | 基准测试 | 达到 95% GPT-4 性能所需 GPT-4 调用比例 | 成本降低 | |---------|--------------------------------------|---------| @@ -63,57 +189,18 @@ RouteLLM 提供三种路由器实现: | MMLU | 54% (使用 Golden Label 增强数据) | 14% | | GSM8K | 35% | 35% | -**模型规格** -- **基础模型**: BERT-base-uncased -- **参数量**: ~110M -- **输入长度**: 512 tokens -- **输出**: 二分类 (0=弱模型, 1=强模型) -- **推理延迟**: 1-5ms (CPU) - -**优势** -- ✅ 完全开源 (代码 + 模型 + 数据集) -- ✅ 轻量级,适合边缘部署 -- ✅ 基于 Chatbot Arena 真实偏好数据训练 -- ✅ 支持数据增强提升性能 -- ✅ 可泛化到未训练的模型对 - -**劣势** -- ⚠️ 仅支持二分类路由(强 vs 弱) -- ⚠️ 需要针对特定模型对微调以获得最佳效果 - -**快速开始** -```python -from transformers import AutoTokenizer, AutoModelForSequenceClassification -import torch - -# 加载 RouteLLM BERT Router -tokenizer = AutoTokenizer.from_pretrained("lm-sys/routellm-bert") -model = AutoModelForSequenceClassification.from_pretrained("lm-sys/routellm-bert") -model.eval() - -def route_query(query: str) -> str: - inputs = tokenizer(query, return_tensors="pt", truncation=True, max_length=512) - with torch.no_grad(): - outputs = model(**inputs) - probs = torch.softmax(outputs.logits, dim=-1) - prediction = torch.argmax(probs, dim=-1).item() - - return "gpt-4" if prediction == 1 else "mixtral-8x7b" -``` - --- -### 1.2 Arch-Router (Katanemo Labs) +## 3. 备选方案: Arch-Router 1.5B + +### 3.1 项目信息 -**项目信息** - **论文**: [Arch-Router: Aligning LLM Routing with Human Preferences](https://arxiv.org/abs/2506.16655) - **模型**: https://huggingface.co/katanemo/Arch-Router-1.5B - **机构**: Katanemo Labs - **发布时间**: 2025年6月 -**技术架构** - -Arch-Router 采用生成式模型架构: +### 3.2 技术架构 ``` ┌─────────────────────────────────────────────────────────┐ @@ -130,204 +217,104 @@ Arch-Router 采用生成式模型架构: └─────────────────────────────────────────────────────────┘ ``` -**性能指标** -- **准确率**: 93%(对比 GPT-4 的 85%) -- **优势**: 比顶级专有 LLM 平均高 7.71% +### 3.3 模型规格 -**模型规格** - **参数量**: 1.5B - **架构**: Generative Language Model (类似 Llama) - **推理延迟**: 50-100ms (GPU) -- **训练数据**: 43K 样本 +- **准确率**: 93%(对比 GPT-4 的 85%) + +### 3.4 优劣势 **优势** - ✅ 人类偏好对齐,更符合实际使用场景 - ✅ 支持自然语言策略定义,灵活性高 - ✅ 添加新模型无需重新训练 -- ✅ 处理多轮对话和复杂意图能力强 **劣势** - ⚠️ 模型较大 (1.5B),推理延迟较高 - ⚠️ 2025年新发布,生产验证较少 - ⚠️ 需要 GPU 才能达到可接受延迟 -**快速开始** -```python -from transformers import AutoTokenizer, AutoModelForCausalLM - -tokenizer = AutoTokenizer.from_pretrained("katanemo/Arch-Router-1.5B") -model = AutoModelForCausalLM.from_pretrained("katanemo/Arch-Router-1.5B") - -# 定义路由策略 -policies = [ - {"id": "code", "description": "Programming and code generation tasks"}, - {"id": "math", "description": "Mathematical reasoning and calculations"}, - {"id": "creative", "description": "Creative writing and content generation"}, -] - -# 构建提示 -prompt = f"Query: {user_query}\nPolicies: {policies}\nBest policy:" -``` - --- -### 1.3 其他方案 +## 4. 其他方案 -#### RoRF (Not Diamond) +### 4.1 RoRF (Not Diamond) - **类型**: Random Forest 分类器 - **特点**: Pairwise 路由决策 - **状态**: 开源,但文档较少 -#### LLMRouter (UIUC) +### 4.2 LLMRouter (UIUC) - **项目**: https://github.com/ulab-uiuc/LLMRouter - **特点**: 智能路由系统 - **状态**: 部分开源,细节待验证 ---- - -## 2. 方案对比总表 - -| 维度 | RouteLLM BERT | Arch-Router 1.5B | RoRF | -|------|---------------|-------------------|------| -| **模型类型** | BERT Classifier | Generative LM | Random Forest | -| **参数量** | 110M | 1.5B | - | -| **推理延迟** | 1-5ms | 50-100ms | - | -| **准确率** | 85-92% | 93% | - | -| **支持模型数** | 2 (强/弱) | 动态添加 | 多模型 | -| **训练需求** | 需针对模型对微调 | 无需重新训练 | 需训练 | -| **硬件要求** | CPU 即可 | 需要 GPU | CPU | -| **开源程度** | 完全开源 | 模型开源 | 开源 | -| **社区活跃度** | 高 (LMSYS) | 中 (新兴) | 低 | -| **生产验证** | 已验证 | 较少 | 未知 | +### 4.3 DeBERTa 3-class Classifier +- **参数量**: 163M +- **分类**: easy/medium/hard 三分类 +- **评估**: 功能较简单,不如 NVIDIA 多头方案丰富 --- -## 3. 推荐方案 +## 5. 方案对比总表 -### 3.1 短期推荐: RouteLLM BERT - -**适用场景** -- 需要快速替换现有规则路由 -- 资源受限(CPU 部署) -- 对延迟敏感(<10ms) -- 二分类路由足够(强/弱模型) - -**实施步骤** -1. 安装依赖: `pip install transformers torch` -2. 加载预训练模型 -3. 替换现有 `select_model_by_length()` 函数 -4. A/B 测试验证效果 - -**预期收益** -- 准确率从规则路由的 ~70% 提升到 85-92% -- 成本降低 50-85% -- 延迟增加 <5ms - -### 3.2 中期备选: Arch-Router 1.5B - -**适用场景** -- 需要多模型路由(>2个模型) -- 有 GPU 资源 -- 重视人类偏好对齐 -- 需要灵活的策略定义 - -**实施步骤** -1. 评估延迟是否可接受 -2. 在业务数据上测试准确率 -3. 设计自然语言路由策略 -4. 渐进式替换 - -### 3.3 长期方向: 自定义训练 - -**建议路径** -``` -Phase 1 (现在): 集成 RouteLLM BERT - ↓ -Phase 2 (1月后): 收集业务数据,评估效果 - ↓ -Phase 3 (3月后): 基于业务数据微调 BERT - ↓ -Phase 4 (6月后): 训练专用路由模型 -``` +| 维度 | NVIDIA Multi-Head ⭐ | RouteLLM BERT | Arch-Router 1.5B | RoRF | +|------|---------------------|---------------|-------------------|------| +| **模型类型** | DeBERTa Multi-Head | BERT Classifier | Generative LM | Random Forest | +| **参数量** | 184M | 110M | 1.5B | - | +| **推理延迟** | 5-15ms | 1-5ms | 50-100ms | - | +| **路由能力** | 3-tier + 多维度 | 2-class (强/弱) | 动态策略 | Pairwise | +| **分析维度** | 8 维 | 1 维 | 策略匹配 | - | +| **硬件要求** | CPU 即可 | CPU 即可 | 需要 GPU | CPU | +| **开源程度** | 完全开源 | 完全开源 | 模型开源 | 开源 | +| **生产验证** | NVIDIA 内部 | LMSYS 验证 | 较少 | 未知 | +| **自定义难度** | 中(需手动加载) | 低 | 低 | 中 | --- -## 4. 与 tx402.ai 技术对比 +## 6. 与 tx402.ai 技术对比 -| 技术点 | tx402.ai (商业) | RouteLLM BERT (开源) | 差距分析 | -|--------|----------------|---------------------|---------| -| **分类器** | BERT + 多臂老虎机 | BERT Classifier | 缺少在线学习 | -| **延迟** | 3ms (分类) + 5-10ms (路由) | 1-5ms | ✅ 更优 | -| **准确率** | ~90% | 85-92% | ✅ 相当 | -| **成本降低** | 70%+ | 85% | ✅ 更优 | -| **模型覆盖** | 40+ | 2 (强/弱) | ⚠️ 需扩展 | -| **在线学习** | 支持 | 不支持 | ⚠️ 需实现 | +| 技术点 | tx402.ai (商业) | NVIDIA Multi-Head (当前) | 差距分析 | +|--------|----------------|------------------------|---------| +| **Layer 1 分类器** | BERT 三分类 (3ms) | DeBERTa 多头 8维 (5-15ms) | ✅ 维度更丰富 | +| **Layer 2 选择** | 多臂老虎机 (2-5ms) | 静态评分公式 | ⚠️ 缺少在线学习 | +| **Layer 3 执行** | 语义缓存 + 批量优化 | 直接调用 | ⚠️ 需实现 | +| **路由层级** | 3层: 分类→MAB→执行 | 1层: 多头分类→评分 | ⚠️ 需扩展 | +| **模型覆盖** | 40+ | 3 (flash/plus/max) | ⚠️ 需扩展 | +| **在线学习** | 支持 (MAB) | 不支持 | ⚠️ 需实现 | | **语义缓存** | 支持 | 不支持 | ⚠️ 需实现 | +| **总延迟** | 10-18ms | 5-15ms | ✅ 更优 | -**关键差距** -1. **在线学习**: tx402 使用多臂老虎机动态优化,开源方案需要自行实现 -2. **多模型支持**: 开源 BERT 仅支持二分类,需要扩展支持多模型 -3. **语义缓存**: tx402 的缓存技术未在开源方案中体现 +**当前已缩小的差距**(相比 RouteLLM BERT): +1. ✅ 支持 3-tier 路由(simple/medium/complex) +2. ✅ 多维度查询分析(task_type/reasoning/creativity/domain) +3. ✅ 更接近 tx402 Layer 1 的分类能力 + +**仍需实现的功能**: +1. Layer 2: 多臂老虎机在线学习(Thompson Sampling) +2. Layer 3: 语义缓存 + 批量优化 +3. 扩展模型池覆盖(当前仅 3 个 Qwen 模型) --- -## 5. 实施建议 +## 7. 演进路线 -### 5.1 最小可行方案 (MVP) - -**目标**: 用 RouteLLM BERT 替换现有 token 长度路由 - -**改动范围** -```python -# 当前实现 -def select_model_by_length(messages): - token_count = estimate_tokens(messages) - if token_count < 100: - return "qwen-flash" - elif token_count < 500: - return "qwen-plus" - else: - return "qwen-max" - -# 新实现 -def select_model_by_bert(query: str) -> str: - prediction = bert_router.predict(query) - return "qwen-max" if prediction == "strong" else "qwen-flash" ``` - -**验证标准** -- [ ] 短查询正确路由到 qwen-flash -- [ ] 复杂查询正确路由到 qwen-max -- [ ] 延迟增加 <5ms -- [ ] 准确率 >85% - -### 5.2 扩展方案 (Advanced) - -**添加多臂老虎机在线学习** -```python -class ThompsonSamplingRouter: - """结合 BERT 预测 + 多臂老虎机优化""" - - def __init__(self): - self.bert = BERTRouter() - self.bandit = ThompsonSampling(n_models=3) - - def route(self, query: str) -> str: - # BERT 提供先验 - bert_prediction = self.bert.predict(query) - - # 老虎机动态调整 - model = self.bandit.select(bert_prediction) - return model - - def update(self, model: str, reward: float): - # 根据实际效果更新 - self.bandit.update(model, reward) +Phase 1 (已完成): NVIDIA 多头分类器 3-tier 路由 + ↓ +Phase 2 (下一步): 添加多臂老虎机在线学习 (Layer 2) + ↓ +Phase 3: 语义缓存 + 批量优化 (Layer 3) + ↓ +Phase 4: 扩展模型池 (40+ 模型支持) + ↓ +Phase 5: 基于业务数据微调 NVIDIA 分类器 ``` --- -## 6. 参考文献 +## 8. 参考文献 ### 学术论文 1. **RouteLLM**: Ong et al. "RouteLLM: Learning to Route LLMs with Preference Data". arXiv:2406.18665, 2024. @@ -336,6 +323,7 @@ class ThompsonSamplingRouter: 4. **RouterBench**: Hu et al. "RouterBench: A Benchmark for Multi-LLM Routing System". ICML 2024. ### 开源项目 +- NVIDIA Classifier: https://huggingface.co/nvidia/prompt-task-and-complexity-classifier - RouteLLM: https://github.com/lm-sys/RouteLLM - Arch-Router: https://huggingface.co/katanemo/Arch-Router-1.5B - LLMRouter: https://github.com/ulab-uiuc/LLMRouter @@ -346,48 +334,6 @@ class ThompsonSamplingRouter: --- -## 7. 附录 - -### A. 模型下载命令 - -```bash -# RouteLLM BERT -huggingface-cli download lm-sys/routellm-bert - -# Arch-Router 1.5B -huggingface-cli download katanemo/Arch-Router-1.5B -``` - -### B. 快速测试脚本 - -```python -# test_router.py -import time -from transformers import AutoTokenizer, AutoModelForSequenceClassification - -tokenizer = AutoTokenizer.from_pretrained("lm-sys/routellm-bert") -model = AutoModelForSequenceClassification.from_pretrained("lm-sys/routellm-bert") - -test_queries = [ - "你好", # 简单 - "解释量子计算", # 中等 - "用 Python 实现一个分布式事务协调器", # 复杂 -] - -for query in test_queries: - start = time.time() - inputs = tokenizer(query, return_tensors="pt", truncation=True) - outputs = model(**inputs) - prediction = outputs.logits.argmax(dim=-1).item() - latency = (time.time() - start) * 1000 - - print(f"Query: {query}") - print(f"Prediction: {'strong' if prediction == 1 else 'weak'}") - print(f"Latency: {latency:.2f}ms\n") -``` - ---- - **报告结束** > 本报告基于 arXiv 论文、GitHub 开源项目和技术博客整理。 diff --git a/nvidia_router.py b/nvidia_router.py index 8b51412..7548a84 100644 --- a/nvidia_router.py +++ b/nvidia_router.py @@ -43,7 +43,8 @@ class NvidiaMultiHeadClassifier(nn.Module): # DeBERTa backbone self.backbone = DebertaV2Model.from_pretrained( config.base_model, - ignore_mismatched_sizes=True + ignore_mismatched_sizes=True, + use_safetensors=True ) hidden_size = 768 # DeBERTa-v3-base