feat(router): 集成NVIDIA多头分类器实现3-tier智能路由

- 新增nvidia_router.py: 手动加载NVIDIA prompt-task-and-complexity-classifier模型
- DeBERTa-v3-base backbone + 8个分类头(task_type/creativity/reasoning/domain等)
- 综合多维度评分实现simple/medium/complex三级路由
- 映射: simple->qwen-flash, medium->qwen-plus, complex->qwen-max
- main.py切换到NVIDIA路由替代RouteLLM BERT二分类
- 移除LiteLLM依赖解决版本冲突,使用原生httpx调用
- 版本升级至v0.3.0
This commit is contained in:
2026-04-18 01:21:31 +08:00
parent f9cc7973b9
commit 59c03516e4
3 changed files with 343 additions and 27 deletions

50
main.py
View File

@@ -13,7 +13,7 @@ from litellm import acompletion
import litellm
from config import MODEL_CONFIG, ROUTING_THRESHOLDS, DEFAULT_ROUTING, DASHSCOPE_API_KEY
from bert_router import get_bert_router, route_with_bert
from nvidia_router import get_nvidia_router, select_model_by_nvidia
# 配置 LiteLLM 使用 DashScope (Qwen)
if DASHSCOPE_API_KEY:
@@ -21,15 +21,15 @@ if DASHSCOPE_API_KEY:
# Qwen 使用 OpenAI 兼容接口,但需要通过 api_base 指定
litellm.api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
# BERT Router 实例(延迟加载)
_bert_router = None
# NVIDIA Router 实例(延迟加载)
_nvidia_router = None
def get_router():
"""获取 BERT Router 实例(延迟加载)"""
global _bert_router
if _bert_router is None:
_bert_router = get_bert_router()
return _bert_router
"""获取 NVIDIA Router 实例(延迟加载)"""
global _nvidia_router
if _nvidia_router is None:
_nvidia_router = get_nvidia_router()
return _nvidia_router
# 调用历史记录
@@ -69,8 +69,8 @@ class StatsResponse(BaseModel):
app = FastAPI(
title="LLM Router MVP",
description="基于 LiteLLM 的多提供商路由服务",
version="0.2.0",
description="基于 LiteLLM + NVIDIA 分类器的多提供商路由服务支持3-tier智能路由",
version="0.3.0",
)
@@ -102,30 +102,26 @@ def select_model_by_length(messages: List[Message]) -> str:
return DEFAULT_ROUTING["complex"]
def select_model_by_bert(messages: List[Message]) -> str:
def select_model_by_nvidia_classifier(messages: List[Message]) -> str:
"""
基于 BERT 分类器选择模型
基于 NVIDIA 多头分类器选择模型3-tier路由
BERT 输出: strong / weak
NVIDIA 输出: 多维度复杂度评分
映射到 Qwen 模型:
- strong -> qwen-max (复杂任务)
- weak -> qwen-flash (简单任务)
- simple -> qwen-flash (简单任务)
- medium -> qwen-plus (中等任务)
- complex -> qwen-max (复杂任务)
"""
# 取最后一条用户消息作为查询
query = messages[-1].content if messages else ""
try:
router = get_router()
complexity = router.predict(query)
# BERT 二分类映射到三模型
if complexity == "strong":
return "qwen-max"
else:
return "qwen-flash"
model = router.select_model(query)
return model
except Exception as e:
# BERT 失败时回退到 token 长度策略
print(f"BERT routing failed: {e}, falling back to token length")
# NVIDIA 分类器失败时回退到 token 长度策略
print(f"NVIDIA routing failed: {e}, falling back to token length")
return select_model_by_length(messages)
@@ -180,8 +176,8 @@ async def chat_completions(request: ChatRequest):
if request.model:
model_key = request.model
else:
# 使用 BERT 智能路由(替代原来的 token 长度路由
model_key = select_model_by_bert(request.messages)
# 使用 NVIDIA 多头分类器智能路由支持3-tier
model_key = select_model_by_nvidia_classifier(request.messages)
# 获取 LiteLLM 模型名称
provider_model = get_provider_model(model_key)
@@ -293,7 +289,7 @@ async def get_stats():
@app.get("/health")
async def health_check():
"""健康检查"""
return {"status": "healthy", "version": "0.2.0"}
return {"status": "healthy", "version": "0.3.0", "router": "nvidia-3tier"}
if __name__ == "__main__":