"""
RouteLLM BERT Router wrapper.

Predicts query complexity with a pretrained BERT classifier.
"""

import time
from typing import Optional, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification


class BERTRouter:
    """
    RouteLLM BERT router.

    Model information:
    - Base model: BERT-base-uncased
    - Parameters: ~110M
    - Maximum input length: 512 tokens
    - Output: binary classification (0 = weak model, 1 = strong model)
    - Expected latency: 1-5 ms (CPU)

    Usage:
        router = BERTRouter()
        result = router.predict("your query text")
        # result: "strong" or "weak"
    """

    MODEL_NAME = "lm-sys/routellm-bert"

    def __init__(self, device: Optional[str] = None):
        """
        Initialize the BERT router.

        Args:
            device: Device to run on ('cpu', 'cuda', or None to select automatically).
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model()

    def _load_model(self) -> None:
        """Load the model and tokenizer."""
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
            self.model = AutoModelForSequenceClassification.from_pretrained(self.MODEL_NAME)
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            # Chain the original exception so the root cause stays in the traceback.
            raise RuntimeError(f"Failed to load BERT router model: {e}") from e

    def _classify(self, query: str) -> Tuple[int, float]:
        """Run one forward pass and return (class index, probability of that class)."""
        # Encode the input, truncating to the model's 512-token limit.
        inputs = self.tokenizer(
            query,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        )
        # Move the tensors to the target device.
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Inference without gradient tracking.
        with torch.no_grad():
            outputs = self.model(**inputs)
            probs = torch.softmax(outputs.logits, dim=-1)
            prediction = torch.argmax(probs, dim=-1).item()
            confidence = probs[0][prediction].item()
        return prediction, confidence

    def predict(self, query: str) -> str:
        """
        Predict query complexity.

        Args:
            query: User query text.

        Returns:
            "strong": complex task, should use the strong model.
            "weak": simple task, should use the weak model.
        """
        prediction, _ = self._classify(query)
        # 0 = weak model, 1 = strong model
        return "strong" if prediction == 1 else "weak"

    def predict_with_confidence(self, query: str) -> Tuple[str, float]:
        """
        Predict and return the confidence.

        Returns:
            (prediction, confidence): ("strong"/"weak", confidence score)
        """
        prediction, confidence = self._classify(query)
        result = "strong" if prediction == 1 else "weak"
        return result, confidence

    def benchmark(self, query: str, n_runs: int = 10) -> dict:
        """
        Benchmark inference latency.

        Args:
            query: Test query.
            n_runs: Number of timed runs (must be at least 1).

        Returns:
            Latency statistics in milliseconds, plus the device used.
        """
        if n_runs < 1:
            raise ValueError("n_runs must be at least 1")

        latencies = []

        # Warm-up runs (not timed).
        for _ in range(3):
            self.predict(query)

        # Timed runs; perf_counter is a monotonic, high-resolution clock.
        for _ in range(n_runs):
            start = time.perf_counter()
            self.predict(query)
            latencies.append((time.perf_counter() - start) * 1000)

        return {
            "avg_ms": sum(latencies) / len(latencies),
            "min_ms": min(latencies),
            "max_ms": max(latencies),
            "device": self.device,
        }


# Global router instance (lazily loaded)
_bert_router: Optional[BERTRouter] = None


def get_bert_router() -> BERTRouter:
    """Return the global BERT router instance, creating it on first use."""
    global _bert_router
    if _bert_router is None:
        _bert_router = BERTRouter()
    return _bert_router


def route_with_bert(query: str) -> str:
    """
    Convenience function that makes a routing decision with BERT.

    Args:
        query: User query.

    Returns:
        "strong" or "weak"
    """
    router = get_bert_router()
    return router.predict(query)
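

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of RouteLLM): threshold-based routing.
#
# predict() routes on argmax, which for a binary classifier is effectively a
# fixed 0.5 cut-off on P(strong). The helpers below show one possible way to
# make that cut-off tunable, starting from the (label, confidence) pair that
# predict_with_confidence() returns. The names strong_probability and
# route_with_threshold, and the default threshold of 0.5, are assumptions
# made for illustration only.
#
# Possible usage, assuming the router above loads successfully:
#     label, confidence = get_bert_router().predict_with_confidence(query)
#     target = route_with_threshold(label, confidence, threshold=0.3)
# ---------------------------------------------------------------------------


def strong_probability(label: str, confidence: float) -> float:
    """Convert a (label, confidence) pair into P(strong) for a binary classifier."""
    if label not in ("strong", "weak"):
        raise ValueError(f"Unexpected label: {label!r}")
    if not 0.0 <= confidence <= 1.0:
        raise ValueError(f"Confidence must be in [0, 1], got {confidence}")
    # With two classes the probabilities sum to 1, so P(strong) = 1 - P(weak).
    return confidence if label == "strong" else 1.0 - confidence


def route_with_threshold(label: str, confidence: float, threshold: float = 0.5) -> str:
    """Return "strong" when P(strong) >= threshold, otherwise "weak"."""
    # A lower threshold sends more queries to the strong model.
    return "strong" if strong_probability(label, confidence) >= threshold else "weak"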