refactor: 移除RouteLLM BERT路由模块
已切换到NVIDIA多头分类器,不再需要bert_router.py
This commit is contained in:
161
bert_router.py
161
bert_router.py
@@ -1,161 +0,0 @@
|
|||||||
"""
|
|
||||||
RouteLLM BERT Router 封装
|
|
||||||
基于预训练的 BERT 分类器进行查询复杂度预测
|
|
||||||
"""
|
|
||||||
import time
|
|
||||||
from typing import Optional
|
|
||||||
import torch
|
|
||||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
||||||
|
|
||||||
|
|
||||||
class BERTRouter:
|
|
||||||
"""
|
|
||||||
RouteLLM BERT 路由器
|
|
||||||
|
|
||||||
模型信息:
|
|
||||||
- 基础模型: BERT-base-uncased
|
|
||||||
- 参数量: ~110M
|
|
||||||
- 输入长度: 512 tokens
|
|
||||||
- 输出: 二分类 (0=弱模型, 1=强模型)
|
|
||||||
- 预期延迟: 1-5ms (CPU)
|
|
||||||
|
|
||||||
使用方法:
|
|
||||||
router = BERTRouter()
|
|
||||||
result = router.predict("你的查询文本")
|
|
||||||
# result: "strong" 或 "weak"
|
|
||||||
"""
|
|
||||||
|
|
||||||
MODEL_NAME = "lm-sys/routellm-bert"
|
|
||||||
|
|
||||||
def __init__(self, device: Optional[str] = None):
|
|
||||||
"""
|
|
||||||
初始化 BERT Router
|
|
||||||
|
|
||||||
Args:
|
|
||||||
device: 运行设备 ('cpu', 'cuda', 或 None自动选择)
|
|
||||||
"""
|
|
||||||
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
self._load_model()
|
|
||||||
|
|
||||||
def _load_model(self):
|
|
||||||
"""加载模型和tokenizer"""
|
|
||||||
try:
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
|
|
||||||
self.model = AutoModelForSequenceClassification.from_pretrained(self.MODEL_NAME)
|
|
||||||
self.model.to(self.device)
|
|
||||||
self.model.eval()
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"Failed to load BERT router model: {e}")
|
|
||||||
|
|
||||||
def predict(self, query: str) -> str:
|
|
||||||
"""
|
|
||||||
预测查询复杂度
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: 用户查询文本
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
"strong": 复杂任务,应使用强模型
|
|
||||||
"weak": 简单任务,应使用弱模型
|
|
||||||
"""
|
|
||||||
# 编码输入
|
|
||||||
inputs = self.tokenizer(
|
|
||||||
query,
|
|
||||||
return_tensors="pt",
|
|
||||||
truncation=True,
|
|
||||||
max_length=512,
|
|
||||||
padding=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 移动到设备
|
|
||||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
|
||||||
|
|
||||||
# 推理
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = self.model(**inputs)
|
|
||||||
probs = torch.softmax(outputs.logits, dim=-1)
|
|
||||||
prediction = torch.argmax(probs, dim=-1).item()
|
|
||||||
|
|
||||||
# 0 = 弱模型, 1 = 强模型
|
|
||||||
return "strong" if prediction == 1 else "weak"
|
|
||||||
|
|
||||||
def predict_with_confidence(self, query: str) -> tuple:
|
|
||||||
"""
|
|
||||||
预测并返回置信度
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(prediction, confidence): ("strong"/"weak", 置信度分数)
|
|
||||||
"""
|
|
||||||
inputs = self.tokenizer(
|
|
||||||
query,
|
|
||||||
return_tensors="pt",
|
|
||||||
truncation=True,
|
|
||||||
max_length=512,
|
|
||||||
padding=True
|
|
||||||
)
|
|
||||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = self.model(**inputs)
|
|
||||||
probs = torch.softmax(outputs.logits, dim=-1)
|
|
||||||
prediction = torch.argmax(probs, dim=-1).item()
|
|
||||||
confidence = probs[0][prediction].item()
|
|
||||||
|
|
||||||
result = "strong" if prediction == 1 else "weak"
|
|
||||||
return result, confidence
|
|
||||||
|
|
||||||
def benchmark(self, query: str, n_runs: int = 10) -> dict:
|
|
||||||
"""
|
|
||||||
基准测试推理延迟
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: 测试查询
|
|
||||||
n_runs: 运行次数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
延迟统计信息
|
|
||||||
"""
|
|
||||||
latencies = []
|
|
||||||
|
|
||||||
# 预热
|
|
||||||
for _ in range(3):
|
|
||||||
self.predict(query)
|
|
||||||
|
|
||||||
# 正式测试
|
|
||||||
for _ in range(n_runs):
|
|
||||||
start = time.time()
|
|
||||||
self.predict(query)
|
|
||||||
latencies.append((time.time() - start) * 1000)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"avg_ms": sum(latencies) / len(latencies),
|
|
||||||
"min_ms": min(latencies),
|
|
||||||
"max_ms": max(latencies),
|
|
||||||
"device": self.device,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# 全局路由器实例(延迟加载)
|
|
||||||
_bert_router: Optional[BERTRouter] = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_bert_router() -> BERTRouter:
|
|
||||||
"""获取全局 BERT Router 实例"""
|
|
||||||
global _bert_router
|
|
||||||
if _bert_router is None:
|
|
||||||
_bert_router = BERTRouter()
|
|
||||||
return _bert_router
|
|
||||||
|
|
||||||
|
|
||||||
def route_with_bert(query: str) -> str:
|
|
||||||
"""
|
|
||||||
使用 BERT 进行路由决策的便捷函数
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: 用户查询
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
"strong" 或 "weak"
|
|
||||||
"""
|
|
||||||
router = get_bert_router()
|
|
||||||
return router.predict(query)
|
|
||||||
Reference in New Issue
Block a user