feat: integrate RouteLLM BERT router for intelligent query classification

- 添加 transformers 和 torch 依赖
- 创建 bert_router.py 封装 RouteLLM BERT 分类器
- 新增 select_model_by_bert() 函数替代 token 长度路由
- BERT 输出映射: strong->qwen-max, weak->qwen-flash
- 保留 token 长度路由作为 fallback
This commit is contained in:
2026-04-18 00:12:51 +08:00
parent 88842457ea
commit f9cc7973b9
3 changed files with 204 additions and 2 deletions

43
main.py
View File

@@ -13,6 +13,7 @@ from litellm import acompletion
import litellm
from config import MODEL_CONFIG, ROUTING_THRESHOLDS, DEFAULT_ROUTING, DASHSCOPE_API_KEY
from bert_router import get_bert_router, route_with_bert
# 配置 LiteLLM 使用 DashScope (Qwen)
if DASHSCOPE_API_KEY:
@@ -20,6 +21,16 @@ if DASHSCOPE_API_KEY:
# Qwen 使用 OpenAI 兼容接口,但需要通过 api_base 指定
litellm.api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
# BERT Router 实例(延迟加载)
_bert_router = None
def get_router():
"""获取 BERT Router 实例(延迟加载)"""
global _bert_router
if _bert_router is None:
_bert_router = get_bert_router()
return _bert_router
# 调用历史记录
call_history: List[Dict[str, Any]] = []
@@ -80,7 +91,7 @@ def estimate_tokens(messages: List[Message]) -> int:
def select_model_by_length(messages: List[Message]) -> str:
"""基于 token 长度选择模型"""
"""基于 token 长度选择模型(备用策略)"""
token_count = estimate_tokens(messages)
if token_count < ROUTING_THRESHOLDS["simple"]:
@@ -91,6 +102,33 @@ def select_model_by_length(messages: List[Message]) -> str:
return DEFAULT_ROUTING["complex"]
def select_model_by_bert(messages: List[Message]) -> str:
"""
基于 BERT 分类器选择模型
BERT 输出: strong / weak
映射到 Qwen 模型:
- strong -> qwen-max (复杂任务)
- weak -> qwen-flash (简单任务)
"""
# 取最后一条用户消息作为查询
query = messages[-1].content if messages else ""
try:
router = get_router()
complexity = router.predict(query)
# BERT 二分类映射到三模型
if complexity == "strong":
return "qwen-max"
else:
return "qwen-flash"
except Exception as e:
# BERT 失败时回退到 token 长度策略
print(f"BERT routing failed: {e}, falling back to token length")
return select_model_by_length(messages)
def get_provider_model(model_key: str) -> str:
"""获取 LiteLLM 格式的模型名称"""
config = MODEL_CONFIG.get(model_key)
@@ -142,7 +180,8 @@ async def chat_completions(request: ChatRequest):
if request.model:
model_key = request.model
else:
model_key = select_model_by_length(request.messages)
# 使用 BERT 智能路由(替代原来的 token 长度路由)
model_key = select_model_by_bert(request.messages)
# 获取 LiteLLM 模型名称
provider_model = get_provider_model(model_key)