diff --git a/bert_router.py b/bert_router.py deleted file mode 100644 index 09c213b..0000000 --- a/bert_router.py +++ /dev/null @@ -1,161 +0,0 @@ -""" -RouteLLM BERT Router 封装 -基于预训练的 BERT 分类器进行查询复杂度预测 -""" -import time -from typing import Optional -import torch -from transformers import AutoTokenizer, AutoModelForSequenceClassification - - -class BERTRouter: - """ - RouteLLM BERT 路由器 - - 模型信息: - - 基础模型: BERT-base-uncased - - 参数量: ~110M - - 输入长度: 512 tokens - - 输出: 二分类 (0=弱模型, 1=强模型) - - 预期延迟: 1-5ms (CPU) - - 使用方法: - router = BERTRouter() - result = router.predict("你的查询文本") - # result: "strong" 或 "weak" - """ - - MODEL_NAME = "lm-sys/routellm-bert" - - def __init__(self, device: Optional[str] = None): - """ - 初始化 BERT Router - - Args: - device: 运行设备 ('cpu', 'cuda', 或 None自动选择) - """ - self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") - self._load_model() - - def _load_model(self): - """加载模型和tokenizer""" - try: - self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME) - self.model = AutoModelForSequenceClassification.from_pretrained(self.MODEL_NAME) - self.model.to(self.device) - self.model.eval() - except Exception as e: - raise RuntimeError(f"Failed to load BERT router model: {e}") - - def predict(self, query: str) -> str: - """ - 预测查询复杂度 - - Args: - query: 用户查询文本 - - Returns: - "strong": 复杂任务,应使用强模型 - "weak": 简单任务,应使用弱模型 - """ - # 编码输入 - inputs = self.tokenizer( - query, - return_tensors="pt", - truncation=True, - max_length=512, - padding=True - ) - - # 移动到设备 - inputs = {k: v.to(self.device) for k, v in inputs.items()} - - # 推理 - with torch.no_grad(): - outputs = self.model(**inputs) - probs = torch.softmax(outputs.logits, dim=-1) - prediction = torch.argmax(probs, dim=-1).item() - - # 0 = 弱模型, 1 = 强模型 - return "strong" if prediction == 1 else "weak" - - def predict_with_confidence(self, query: str) -> tuple: - """ - 预测并返回置信度 - - Returns: - (prediction, confidence): ("strong"/"weak", 置信度分数) - """ - inputs = self.tokenizer( - query, - return_tensors="pt", - truncation=True, - max_length=512, - padding=True - ) - inputs = {k: v.to(self.device) for k, v in inputs.items()} - - with torch.no_grad(): - outputs = self.model(**inputs) - probs = torch.softmax(outputs.logits, dim=-1) - prediction = torch.argmax(probs, dim=-1).item() - confidence = probs[0][prediction].item() - - result = "strong" if prediction == 1 else "weak" - return result, confidence - - def benchmark(self, query: str, n_runs: int = 10) -> dict: - """ - 基准测试推理延迟 - - Args: - query: 测试查询 - n_runs: 运行次数 - - Returns: - 延迟统计信息 - """ - latencies = [] - - # 预热 - for _ in range(3): - self.predict(query) - - # 正式测试 - for _ in range(n_runs): - start = time.time() - self.predict(query) - latencies.append((time.time() - start) * 1000) - - return { - "avg_ms": sum(latencies) / len(latencies), - "min_ms": min(latencies), - "max_ms": max(latencies), - "device": self.device, - } - - -# 全局路由器实例(延迟加载) -_bert_router: Optional[BERTRouter] = None - - -def get_bert_router() -> BERTRouter: - """获取全局 BERT Router 实例""" - global _bert_router - if _bert_router is None: - _bert_router = BERTRouter() - return _bert_router - - -def route_with_bert(query: str) -> str: - """ - 使用 BERT 进行路由决策的便捷函数 - - Args: - query: 用户查询 - - Returns: - "strong" 或 "weak" - """ - router = get_bert_router() - return router.predict(query)