feat: integrate RouteLLM BERT router for intelligent query classification
- 添加 transformers 和 torch 依赖 - 创建 bert_router.py 封装 RouteLLM BERT 分类器 - 新增 select_model_by_bert() 函数替代 token 长度路由 - BERT 输出映射: strong->qwen-max, weak->qwen-flash - 保留 token 长度路由作为 fallback
This commit is contained in:
161
bert_router.py
Normal file
161
bert_router.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
"""
|
||||||
|
RouteLLM BERT Router 封装
|
||||||
|
基于预训练的 BERT 分类器进行查询复杂度预测
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
import torch
|
||||||
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||||
|
|
||||||
|
|
||||||
|
class BERTRouter:
|
||||||
|
"""
|
||||||
|
RouteLLM BERT 路由器
|
||||||
|
|
||||||
|
模型信息:
|
||||||
|
- 基础模型: BERT-base-uncased
|
||||||
|
- 参数量: ~110M
|
||||||
|
- 输入长度: 512 tokens
|
||||||
|
- 输出: 二分类 (0=弱模型, 1=强模型)
|
||||||
|
- 预期延迟: 1-5ms (CPU)
|
||||||
|
|
||||||
|
使用方法:
|
||||||
|
router = BERTRouter()
|
||||||
|
result = router.predict("你的查询文本")
|
||||||
|
# result: "strong" 或 "weak"
|
||||||
|
"""
|
||||||
|
|
||||||
|
MODEL_NAME = "lm-sys/routellm-bert"
|
||||||
|
|
||||||
|
def __init__(self, device: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
初始化 BERT Router
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device: 运行设备 ('cpu', 'cuda', 或 None自动选择)
|
||||||
|
"""
|
||||||
|
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
self._load_model()
|
||||||
|
|
||||||
|
def _load_model(self):
|
||||||
|
"""加载模型和tokenizer"""
|
||||||
|
try:
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
|
||||||
|
self.model = AutoModelForSequenceClassification.from_pretrained(self.MODEL_NAME)
|
||||||
|
self.model.to(self.device)
|
||||||
|
self.model.eval()
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to load BERT router model: {e}")
|
||||||
|
|
||||||
|
def predict(self, query: str) -> str:
|
||||||
|
"""
|
||||||
|
预测查询复杂度
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 用户查询文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
"strong": 复杂任务,应使用强模型
|
||||||
|
"weak": 简单任务,应使用弱模型
|
||||||
|
"""
|
||||||
|
# 编码输入
|
||||||
|
inputs = self.tokenizer(
|
||||||
|
query,
|
||||||
|
return_tensors="pt",
|
||||||
|
truncation=True,
|
||||||
|
max_length=512,
|
||||||
|
padding=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 移动到设备
|
||||||
|
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||||
|
|
||||||
|
# 推理
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self.model(**inputs)
|
||||||
|
probs = torch.softmax(outputs.logits, dim=-1)
|
||||||
|
prediction = torch.argmax(probs, dim=-1).item()
|
||||||
|
|
||||||
|
# 0 = 弱模型, 1 = 强模型
|
||||||
|
return "strong" if prediction == 1 else "weak"
|
||||||
|
|
||||||
|
def predict_with_confidence(self, query: str) -> tuple:
|
||||||
|
"""
|
||||||
|
预测并返回置信度
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(prediction, confidence): ("strong"/"weak", 置信度分数)
|
||||||
|
"""
|
||||||
|
inputs = self.tokenizer(
|
||||||
|
query,
|
||||||
|
return_tensors="pt",
|
||||||
|
truncation=True,
|
||||||
|
max_length=512,
|
||||||
|
padding=True
|
||||||
|
)
|
||||||
|
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self.model(**inputs)
|
||||||
|
probs = torch.softmax(outputs.logits, dim=-1)
|
||||||
|
prediction = torch.argmax(probs, dim=-1).item()
|
||||||
|
confidence = probs[0][prediction].item()
|
||||||
|
|
||||||
|
result = "strong" if prediction == 1 else "weak"
|
||||||
|
return result, confidence
|
||||||
|
|
||||||
|
def benchmark(self, query: str, n_runs: int = 10) -> dict:
|
||||||
|
"""
|
||||||
|
基准测试推理延迟
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 测试查询
|
||||||
|
n_runs: 运行次数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
延迟统计信息
|
||||||
|
"""
|
||||||
|
latencies = []
|
||||||
|
|
||||||
|
# 预热
|
||||||
|
for _ in range(3):
|
||||||
|
self.predict(query)
|
||||||
|
|
||||||
|
# 正式测试
|
||||||
|
for _ in range(n_runs):
|
||||||
|
start = time.time()
|
||||||
|
self.predict(query)
|
||||||
|
latencies.append((time.time() - start) * 1000)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"avg_ms": sum(latencies) / len(latencies),
|
||||||
|
"min_ms": min(latencies),
|
||||||
|
"max_ms": max(latencies),
|
||||||
|
"device": self.device,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 全局路由器实例(延迟加载)
|
||||||
|
_bert_router: Optional[BERTRouter] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_bert_router() -> BERTRouter:
|
||||||
|
"""获取全局 BERT Router 实例"""
|
||||||
|
global _bert_router
|
||||||
|
if _bert_router is None:
|
||||||
|
_bert_router = BERTRouter()
|
||||||
|
return _bert_router
|
||||||
|
|
||||||
|
|
||||||
|
def route_with_bert(query: str) -> str:
|
||||||
|
"""
|
||||||
|
使用 BERT 进行路由决策的便捷函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 用户查询
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
"strong" 或 "weak"
|
||||||
|
"""
|
||||||
|
router = get_bert_router()
|
||||||
|
return router.predict(query)
|
||||||
43
main.py
43
main.py
@@ -13,6 +13,7 @@ from litellm import acompletion
|
|||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
from config import MODEL_CONFIG, ROUTING_THRESHOLDS, DEFAULT_ROUTING, DASHSCOPE_API_KEY
|
from config import MODEL_CONFIG, ROUTING_THRESHOLDS, DEFAULT_ROUTING, DASHSCOPE_API_KEY
|
||||||
|
from bert_router import get_bert_router, route_with_bert
|
||||||
|
|
||||||
# 配置 LiteLLM 使用 DashScope (Qwen)
|
# 配置 LiteLLM 使用 DashScope (Qwen)
|
||||||
if DASHSCOPE_API_KEY:
|
if DASHSCOPE_API_KEY:
|
||||||
@@ -20,6 +21,16 @@ if DASHSCOPE_API_KEY:
|
|||||||
# Qwen 使用 OpenAI 兼容接口,但需要通过 api_base 指定
|
# Qwen 使用 OpenAI 兼容接口,但需要通过 api_base 指定
|
||||||
litellm.api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
litellm.api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||||
|
|
||||||
|
# BERT Router 实例(延迟加载)
|
||||||
|
_bert_router = None
|
||||||
|
|
||||||
|
def get_router():
|
||||||
|
"""获取 BERT Router 实例(延迟加载)"""
|
||||||
|
global _bert_router
|
||||||
|
if _bert_router is None:
|
||||||
|
_bert_router = get_bert_router()
|
||||||
|
return _bert_router
|
||||||
|
|
||||||
|
|
||||||
# 调用历史记录
|
# 调用历史记录
|
||||||
call_history: List[Dict[str, Any]] = []
|
call_history: List[Dict[str, Any]] = []
|
||||||
@@ -80,7 +91,7 @@ def estimate_tokens(messages: List[Message]) -> int:
|
|||||||
|
|
||||||
|
|
||||||
def select_model_by_length(messages: List[Message]) -> str:
|
def select_model_by_length(messages: List[Message]) -> str:
|
||||||
"""基于 token 长度选择模型"""
|
"""基于 token 长度选择模型(备用策略)"""
|
||||||
token_count = estimate_tokens(messages)
|
token_count = estimate_tokens(messages)
|
||||||
|
|
||||||
if token_count < ROUTING_THRESHOLDS["simple"]:
|
if token_count < ROUTING_THRESHOLDS["simple"]:
|
||||||
@@ -91,6 +102,33 @@ def select_model_by_length(messages: List[Message]) -> str:
|
|||||||
return DEFAULT_ROUTING["complex"]
|
return DEFAULT_ROUTING["complex"]
|
||||||
|
|
||||||
|
|
||||||
|
def select_model_by_bert(messages: List[Message]) -> str:
|
||||||
|
"""
|
||||||
|
基于 BERT 分类器选择模型
|
||||||
|
|
||||||
|
BERT 输出: strong / weak
|
||||||
|
映射到 Qwen 模型:
|
||||||
|
- strong -> qwen-max (复杂任务)
|
||||||
|
- weak -> qwen-flash (简单任务)
|
||||||
|
"""
|
||||||
|
# 取最后一条用户消息作为查询
|
||||||
|
query = messages[-1].content if messages else ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
router = get_router()
|
||||||
|
complexity = router.predict(query)
|
||||||
|
|
||||||
|
# BERT 二分类映射到三模型
|
||||||
|
if complexity == "strong":
|
||||||
|
return "qwen-max"
|
||||||
|
else:
|
||||||
|
return "qwen-flash"
|
||||||
|
except Exception as e:
|
||||||
|
# BERT 失败时回退到 token 长度策略
|
||||||
|
print(f"BERT routing failed: {e}, falling back to token length")
|
||||||
|
return select_model_by_length(messages)
|
||||||
|
|
||||||
|
|
||||||
def get_provider_model(model_key: str) -> str:
|
def get_provider_model(model_key: str) -> str:
|
||||||
"""获取 LiteLLM 格式的模型名称"""
|
"""获取 LiteLLM 格式的模型名称"""
|
||||||
config = MODEL_CONFIG.get(model_key)
|
config = MODEL_CONFIG.get(model_key)
|
||||||
@@ -142,7 +180,8 @@ async def chat_completions(request: ChatRequest):
|
|||||||
if request.model:
|
if request.model:
|
||||||
model_key = request.model
|
model_key = request.model
|
||||||
else:
|
else:
|
||||||
model_key = select_model_by_length(request.messages)
|
# 使用 BERT 智能路由(替代原来的 token 长度路由)
|
||||||
|
model_key = select_model_by_bert(request.messages)
|
||||||
|
|
||||||
# 获取 LiteLLM 模型名称
|
# 获取 LiteLLM 模型名称
|
||||||
provider_model = get_provider_model(model_key)
|
provider_model = get_provider_model(model_key)
|
||||||
|
|||||||
@@ -5,5 +5,7 @@ litellm>=1.0.0
|
|||||||
tiktoken>=0.5.0
|
tiktoken>=0.5.0
|
||||||
httpx>=0.25.0
|
httpx>=0.25.0
|
||||||
python-dotenv>=1.0.0
|
python-dotenv>=1.0.0
|
||||||
|
transformers>=4.30.0
|
||||||
|
torch>=2.0.0
|
||||||
pytest>=7.4.0
|
pytest>=7.4.0
|
||||||
pytest-asyncio>=0.21.0
|
pytest-asyncio>=0.21.0
|
||||||
|
|||||||
Reference in New Issue
Block a user