From f9cc7973b93c9e7e0f98ff5130124687c8940fa0 Mon Sep 17 00:00:00 2001 From: aszerW Date: Sat, 18 Apr 2026 00:12:51 +0800 Subject: [PATCH] feat: integrate RouteLLM BERT router for intelligent query classification MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 添加 transformers 和 torch 依赖 - 创建 bert_router.py 封装 RouteLLM BERT 分类器 - 新增 select_model_by_bert() 函数替代 token 长度路由 - BERT 输出映射: strong->qwen-max, weak->qwen-flash - 保留 token 长度路由作为 fallback --- bert_router.py | 161 +++++++++++++++++++++++++++++++++++++++++++++++ main.py | 43 ++++++++++++- requirements.txt | 2 + 3 files changed, 204 insertions(+), 2 deletions(-) create mode 100644 bert_router.py diff --git a/bert_router.py b/bert_router.py new file mode 100644 index 0000000..09c213b --- /dev/null +++ b/bert_router.py @@ -0,0 +1,161 @@ +""" +RouteLLM BERT Router 封装 +基于预训练的 BERT 分类器进行查询复杂度预测 +""" +import time +from typing import Optional +import torch +from transformers import AutoTokenizer, AutoModelForSequenceClassification + + +class BERTRouter: + """ + RouteLLM BERT 路由器 + + 模型信息: + - 基础模型: BERT-base-uncased + - 参数量: ~110M + - 输入长度: 512 tokens + - 输出: 二分类 (0=弱模型, 1=强模型) + - 预期延迟: 1-5ms (CPU) + + 使用方法: + router = BERTRouter() + result = router.predict("你的查询文本") + # result: "strong" 或 "weak" + """ + + MODEL_NAME = "lm-sys/routellm-bert" + + def __init__(self, device: Optional[str] = None): + """ + 初始化 BERT Router + + Args: + device: 运行设备 ('cpu', 'cuda', 或 None自动选择) + """ + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + self._load_model() + + def _load_model(self): + """加载模型和tokenizer""" + try: + self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME) + self.model = AutoModelForSequenceClassification.from_pretrained(self.MODEL_NAME) + self.model.to(self.device) + self.model.eval() + except Exception as e: + raise RuntimeError(f"Failed to load BERT router model: {e}") + + def predict(self, query: str) -> str: + """ + 预测查询复杂度 + + Args: + query: 用户查询文本 + + Returns: + "strong": 复杂任务,应使用强模型 + "weak": 简单任务,应使用弱模型 + """ + # 编码输入 + inputs = self.tokenizer( + query, + return_tensors="pt", + truncation=True, + max_length=512, + padding=True + ) + + # 移动到设备 + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + # 推理 + with torch.no_grad(): + outputs = self.model(**inputs) + probs = torch.softmax(outputs.logits, dim=-1) + prediction = torch.argmax(probs, dim=-1).item() + + # 0 = 弱模型, 1 = 强模型 + return "strong" if prediction == 1 else "weak" + + def predict_with_confidence(self, query: str) -> tuple: + """ + 预测并返回置信度 + + Returns: + (prediction, confidence): ("strong"/"weak", 置信度分数) + """ + inputs = self.tokenizer( + query, + return_tensors="pt", + truncation=True, + max_length=512, + padding=True + ) + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + with torch.no_grad(): + outputs = self.model(**inputs) + probs = torch.softmax(outputs.logits, dim=-1) + prediction = torch.argmax(probs, dim=-1).item() + confidence = probs[0][prediction].item() + + result = "strong" if prediction == 1 else "weak" + return result, confidence + + def benchmark(self, query: str, n_runs: int = 10) -> dict: + """ + 基准测试推理延迟 + + Args: + query: 测试查询 + n_runs: 运行次数 + + Returns: + 延迟统计信息 + """ + latencies = [] + + # 预热 + for _ in range(3): + self.predict(query) + + # 正式测试 + for _ in range(n_runs): + start = time.time() + self.predict(query) + latencies.append((time.time() - start) * 1000) + + return { + "avg_ms": sum(latencies) / len(latencies), + "min_ms": min(latencies), + "max_ms": max(latencies), + "device": self.device, + } + + +# 全局路由器实例(延迟加载) +_bert_router: Optional[BERTRouter] = None + + +def get_bert_router() -> BERTRouter: + """获取全局 BERT Router 实例""" + global _bert_router + if _bert_router is None: + _bert_router = BERTRouter() + return _bert_router + + +def route_with_bert(query: str) -> str: + """ + 使用 BERT 进行路由决策的便捷函数 + + Args: + query: 用户查询 + + Returns: + "strong" 或 "weak" + """ + router = get_bert_router() + return router.predict(query) diff --git a/main.py b/main.py index da73cd8..ffd03a2 100644 --- a/main.py +++ b/main.py @@ -13,6 +13,7 @@ from litellm import acompletion import litellm from config import MODEL_CONFIG, ROUTING_THRESHOLDS, DEFAULT_ROUTING, DASHSCOPE_API_KEY +from bert_router import get_bert_router, route_with_bert # 配置 LiteLLM 使用 DashScope (Qwen) if DASHSCOPE_API_KEY: @@ -20,6 +21,16 @@ if DASHSCOPE_API_KEY: # Qwen 使用 OpenAI 兼容接口,但需要通过 api_base 指定 litellm.api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1" +# BERT Router 实例(延迟加载) +_bert_router = None + +def get_router(): + """获取 BERT Router 实例(延迟加载)""" + global _bert_router + if _bert_router is None: + _bert_router = get_bert_router() + return _bert_router + # 调用历史记录 call_history: List[Dict[str, Any]] = [] @@ -80,7 +91,7 @@ def estimate_tokens(messages: List[Message]) -> int: def select_model_by_length(messages: List[Message]) -> str: - """基于 token 长度选择模型""" + """基于 token 长度选择模型(备用策略)""" token_count = estimate_tokens(messages) if token_count < ROUTING_THRESHOLDS["simple"]: @@ -91,6 +102,33 @@ def select_model_by_length(messages: List[Message]) -> str: return DEFAULT_ROUTING["complex"] +def select_model_by_bert(messages: List[Message]) -> str: + """ + 基于 BERT 分类器选择模型 + + BERT 输出: strong / weak + 映射到 Qwen 模型: + - strong -> qwen-max (复杂任务) + - weak -> qwen-flash (简单任务) + """ + # 取最后一条用户消息作为查询 + query = messages[-1].content if messages else "" + + try: + router = get_router() + complexity = router.predict(query) + + # BERT 二分类映射到三模型 + if complexity == "strong": + return "qwen-max" + else: + return "qwen-flash" + except Exception as e: + # BERT 失败时回退到 token 长度策略 + print(f"BERT routing failed: {e}, falling back to token length") + return select_model_by_length(messages) + + def get_provider_model(model_key: str) -> str: """获取 LiteLLM 格式的模型名称""" config = MODEL_CONFIG.get(model_key) @@ -142,7 +180,8 @@ async def chat_completions(request: ChatRequest): if request.model: model_key = request.model else: - model_key = select_model_by_length(request.messages) + # 使用 BERT 智能路由(替代原来的 token 长度路由) + model_key = select_model_by_bert(request.messages) # 获取 LiteLLM 模型名称 provider_model = get_provider_model(model_key) diff --git a/requirements.txt b/requirements.txt index 9925d56..a14d412 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,5 +5,7 @@ litellm>=1.0.0 tiktoken>=0.5.0 httpx>=0.25.0 python-dotenv>=1.0.0 +transformers>=4.30.0 +torch>=2.0.0 pytest>=7.4.0 pytest-asyncio>=0.21.0