Files
llm-compass/bert_router.py
aszerW f9cc7973b9 feat: integrate RouteLLM BERT router for intelligent query classification
- 添加 transformers 和 torch 依赖
- 创建 bert_router.py 封装 RouteLLM BERT 分类器
- 新增 select_model_by_bert() 函数替代 token 长度路由
- BERT 输出映射: strong->qwen-max, weak->qwen-flash
- 保留 token 长度路由作为 fallback
2026-04-18 00:12:51 +08:00

162 lines
4.4 KiB
Python

"""
RouteLLM BERT Router 封装
基于预训练的 BERT 分类器进行查询复杂度预测
"""
import time
from typing import Optional
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
class BERTRouter:
    """
    RouteLLM BERT router wrapper.

    Classifies a query as needing the "strong" or the "weak" model using a
    sequence-classification checkpoint loaded from ``MODEL_NAME``.

    Model notes (carried over from the original description, not verified
    from this file -- TODO confirm):
        - Base model: BERT-base-uncased (~110M parameters)
        - Expected latency: 1-5 ms on CPU (measure with ``benchmark``)

    Inputs are truncated to ``MAX_LENGTH`` (512) tokens. Class index 1 is
    mapped to "strong"; every other index is mapped to "weak".

    NOTE(review): upstream RouteLLM BERT checkpoints appear to be 3-label
    classifiers whose strong-win probability is derived from the softmax
    scores rather than an argmax over two classes. Confirm that the
    checkpoint behind ``MODEL_NAME`` really is binary with 1 = strong.

    Usage:
        router = BERTRouter()
        result = router.predict("your query text")
        # result: "strong" or "weak"
    """

    MODEL_NAME = "lm-sys/routellm-bert"

    # Maximum number of tokens fed to the classifier; longer input is truncated.
    MAX_LENGTH = 512

    def __init__(self, device: Optional[str] = None):
        """
        Initialize the router and eagerly load the tokenizer and model.

        Args:
            device: Torch device string ('cpu', 'cuda', ...). When None,
                'cuda' is used if available, otherwise 'cpu'.

        Raises:
            RuntimeError: if the tokenizer or model cannot be loaded.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model()

    def _load_model(self):
        """
        Load tokenizer and model, move the model to ``self.device``, set eval mode.

        Raises:
            RuntimeError: wrapping whatever error occurred while loading.
        """
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
            self.model = AutoModelForSequenceClassification.from_pretrained(self.MODEL_NAME)
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            # Chain the original error so the root cause (network, missing
            # checkpoint, device error, ...) stays visible in the traceback.
            raise RuntimeError(f"Failed to load BERT router model: {e}") from e

    @staticmethod
    def _label(index: int) -> str:
        """Map a predicted class index to a route: 1 -> "strong", anything else -> "weak"."""
        return "strong" if index == 1 else "weak"

    def _class_probabilities(self, query: str) -> torch.Tensor:
        """
        Tokenize *query*, run the classifier and return softmax probabilities.

        The result has a leading batch dimension of size 1 (one query);
        the last dimension holds one probability per class.
        """
        inputs = self.tokenizer(
            query,
            return_tensors="pt",
            truncation=True,
            max_length=self.MAX_LENGTH,
            padding=True,
        )
        inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}
        with torch.no_grad():
            logits = self.model(**inputs).logits
            return torch.softmax(logits, dim=-1)

    def predict(self, query: str) -> str:
        """
        Predict which model tier the query needs.

        Args:
            query: User query text.

        Returns:
            "strong" for a complex task (use the strong model),
            "weak" for a simple task (use the weak model).
        """
        probs = self._class_probabilities(query)
        return self._label(torch.argmax(probs, dim=-1).item())

    def predict_with_confidence(self, query: str) -> tuple:
        """
        Predict the model tier and return the probability of the chosen class.

        Args:
            query: User query text.

        Returns:
            (prediction, confidence): prediction is "strong" or "weak";
            confidence is the softmax probability of the argmax class.
        """
        probs = self._class_probabilities(query)
        index = torch.argmax(probs, dim=-1).item()
        confidence = probs[0][index].item()
        return self._label(index), confidence

    def benchmark(self, query: str, n_runs: int = 10, *, n_warmup: int = 3) -> dict:
        """
        Measure wall-clock latency of ``predict`` on *query*.

        Args:
            query: Query used for every run.
            n_runs: Number of timed runs; must be >= 1.
            n_warmup: Untimed runs executed first (keyword-only, default 3).

        Returns:
            dict with "avg_ms", "min_ms", "max_ms" (milliseconds) and "device".

        Raises:
            ValueError: if ``n_runs`` < 1 (there would be nothing to average).
        """
        if n_runs < 1:
            raise ValueError(f"n_runs must be >= 1, got {n_runs}")

        # Warm-up: first calls include one-off costs that would skew the stats.
        for _ in range(n_warmup):
            self.predict(query)

        latencies = []
        for _ in range(n_runs):
            # perf_counter is monotonic and high-resolution, unlike time.time,
            # which matters for millisecond-scale measurements.
            start = time.perf_counter()
            self.predict(query)
            latencies.append((time.perf_counter() - start) * 1000)

        return {
            "avg_ms": sum(latencies) / len(latencies),
            "min_ms": min(latencies),
            "max_ms": max(latencies),
            "device": self.device,
        }
# Process-wide router instance; created lazily by get_bert_router().
_bert_router: Optional[BERTRouter] = None


def get_bert_router() -> BERTRouter:
    """
    Return the shared BERTRouter, constructing it on the first call.

    Later calls reuse the cached instance, so the model is not reloaded.
    NOTE(review): there is no lock here; concurrent first calls could each
    construct a router.
    """
    global _bert_router
    if _bert_router is not None:
        return _bert_router
    _bert_router = BERTRouter()
    return _bert_router
def route_with_bert(query: str) -> str:
    """
    Route *query* with the shared BERT router (convenience wrapper).

    Args:
        query: User query text.

    Returns:
        "strong" or "weak", exactly as returned by ``BERTRouter.predict``.
    """
    return get_bert_router().predict(query)