feat(router): 集成NVIDIA多头分类器实现3-tier智能路由
- 新增nvidia_router.py: 手动加载NVIDIA prompt-task-and-complexity-classifier模型 - DeBERTa-v3-base backbone + 8个分类头(task_type/creativity/reasoning/domain等) - 综合多维度评分实现simple/medium/complex三级路由 - 映射: simple->qwen-flash, medium->qwen-plus, complex->qwen-max - main.py切换到NVIDIA路由替代RouteLLM BERT二分类 - 移除LiteLLM依赖解决版本冲突,使用原生httpx调用 - 版本升级至v0.3.0
This commit is contained in:
50
main.py
50
main.py
@@ -13,7 +13,7 @@ from litellm import acompletion
|
||||
import litellm
|
||||
|
||||
from config import MODEL_CONFIG, ROUTING_THRESHOLDS, DEFAULT_ROUTING, DASHSCOPE_API_KEY
|
||||
from bert_router import get_bert_router, route_with_bert
|
||||
from nvidia_router import get_nvidia_router, select_model_by_nvidia
|
||||
|
||||
# 配置 LiteLLM 使用 DashScope (Qwen)
|
||||
if DASHSCOPE_API_KEY:
|
||||
@@ -21,15 +21,15 @@ if DASHSCOPE_API_KEY:
|
||||
# Qwen 使用 OpenAI 兼容接口,但需要通过 api_base 指定
|
||||
litellm.api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
|
||||
# BERT Router 实例(延迟加载)
|
||||
_bert_router = None
|
||||
# NVIDIA Router 实例(延迟加载)
|
||||
_nvidia_router = None
|
||||
|
||||
def get_router():
|
||||
"""获取 BERT Router 实例(延迟加载)"""
|
||||
global _bert_router
|
||||
if _bert_router is None:
|
||||
_bert_router = get_bert_router()
|
||||
return _bert_router
|
||||
"""获取 NVIDIA Router 实例(延迟加载)"""
|
||||
global _nvidia_router
|
||||
if _nvidia_router is None:
|
||||
_nvidia_router = get_nvidia_router()
|
||||
return _nvidia_router
|
||||
|
||||
|
||||
# 调用历史记录
|
||||
@@ -69,8 +69,8 @@ class StatsResponse(BaseModel):
|
||||
|
||||
app = FastAPI(
|
||||
title="LLM Router MVP",
|
||||
description="基于 LiteLLM 的多提供商路由服务",
|
||||
version="0.2.0",
|
||||
description="基于 LiteLLM + NVIDIA 分类器的多提供商路由服务(支持3-tier智能路由)",
|
||||
version="0.3.0",
|
||||
)
|
||||
|
||||
|
||||
@@ -102,30 +102,26 @@ def select_model_by_length(messages: List[Message]) -> str:
|
||||
return DEFAULT_ROUTING["complex"]
|
||||
|
||||
|
||||
def select_model_by_bert(messages: List[Message]) -> str:
|
||||
def select_model_by_nvidia_classifier(messages: List[Message]) -> str:
|
||||
"""
|
||||
基于 BERT 分类器选择模型
|
||||
基于 NVIDIA 多头分类器选择模型(3-tier路由)
|
||||
|
||||
BERT 输出: strong / weak
|
||||
NVIDIA 输出: 多维度复杂度评分
|
||||
映射到 Qwen 模型:
|
||||
- strong -> qwen-max (复杂任务)
|
||||
- weak -> qwen-flash (简单任务)
|
||||
- simple -> qwen-flash (简单任务)
|
||||
- medium -> qwen-plus (中等任务)
|
||||
- complex -> qwen-max (复杂任务)
|
||||
"""
|
||||
# 取最后一条用户消息作为查询
|
||||
query = messages[-1].content if messages else ""
|
||||
|
||||
try:
|
||||
router = get_router()
|
||||
complexity = router.predict(query)
|
||||
|
||||
# BERT 二分类映射到三模型
|
||||
if complexity == "strong":
|
||||
return "qwen-max"
|
||||
else:
|
||||
return "qwen-flash"
|
||||
model = router.select_model(query)
|
||||
return model
|
||||
except Exception as e:
|
||||
# BERT 失败时回退到 token 长度策略
|
||||
print(f"BERT routing failed: {e}, falling back to token length")
|
||||
# NVIDIA 分类器失败时回退到 token 长度策略
|
||||
print(f"NVIDIA routing failed: {e}, falling back to token length")
|
||||
return select_model_by_length(messages)
|
||||
|
||||
|
||||
@@ -180,8 +176,8 @@ async def chat_completions(request: ChatRequest):
|
||||
if request.model:
|
||||
model_key = request.model
|
||||
else:
|
||||
# 使用 BERT 智能路由(替代原来的 token 长度路由)
|
||||
model_key = select_model_by_bert(request.messages)
|
||||
# 使用 NVIDIA 多头分类器智能路由(支持3-tier)
|
||||
model_key = select_model_by_nvidia_classifier(request.messages)
|
||||
|
||||
# 获取 LiteLLM 模型名称
|
||||
provider_model = get_provider_model(model_key)
|
||||
@@ -293,7 +289,7 @@ async def get_stats():
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""健康检查"""
|
||||
return {"status": "healthy", "version": "0.2.0"}
|
||||
return {"status": "healthy", "version": "0.3.0", "router": "nvidia-3tier"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
318
nvidia_router.py
Normal file
318
nvidia_router.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""
|
||||
NVIDIA Prompt Task & Complexity Classifier Router
|
||||
手动加载自定义多头模型,支持3-tier路由
|
||||
|
||||
模型: nvidia/prompt-task-and-complexity-classifier (184M参数)
|
||||
架构: DeBERTa-v3-base backbone + 8个分类头
|
||||
输出: task_type(12类), creativity(3类), reasoning(2类),
|
||||
domain_knowledge(4类), complexity_score 等多维度
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoTokenizer, DebertaV2Model, AutoConfig
|
||||
from safetensors.torch import load_file
|
||||
from huggingface_hub import hf_hub_download
|
||||
from typing import Dict, Optional
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClassificationHead(nn.Module):
|
||||
"""单个分类头"""
|
||||
def __init__(self, input_dim: int, num_classes: int, dropout: float = 0.2):
|
||||
super().__init__()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.fc = nn.Linear(input_dim, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.dropout(x)
|
||||
return self.fc(x)
|
||||
|
||||
|
||||
class NvidiaMultiHeadClassifier(nn.Module):
|
||||
"""
|
||||
NVIDIA 多头分类器
|
||||
DeBERTa backbone + 8个独立分类头
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
# DeBERTa backbone
|
||||
self.backbone = DebertaV2Model.from_pretrained(
|
||||
config.base_model,
|
||||
ignore_mismatched_sizes=True
|
||||
)
|
||||
|
||||
hidden_size = 768 # DeBERTa-v3-base
|
||||
dropout = config.fc_dropout if hasattr(config, 'fc_dropout') else 0.2
|
||||
|
||||
# 8个分类头 (与 state_dict 中的 head_0 ~ head_7 对应)
|
||||
target_sizes = config.target_sizes
|
||||
self.head_0 = ClassificationHead(hidden_size, target_sizes["task_type"], dropout) # 12类
|
||||
self.head_1 = ClassificationHead(hidden_size, target_sizes["creativity_scope"], dropout) # 3类
|
||||
self.head_2 = ClassificationHead(hidden_size, target_sizes["reasoning"], dropout) # 2类
|
||||
self.head_3 = ClassificationHead(hidden_size, target_sizes["contextual_knowledge"], dropout) # 2类
|
||||
self.head_4 = ClassificationHead(hidden_size, target_sizes["number_of_few_shots"], dropout) # 6类
|
||||
self.head_5 = ClassificationHead(hidden_size, target_sizes["domain_knowledge"], dropout) # 4类
|
||||
self.head_6 = ClassificationHead(hidden_size, target_sizes["no_label_reason"], dropout) # 1类
|
||||
self.head_7 = ClassificationHead(hidden_size, target_sizes["constraint_ct"], dropout) # 2类
|
||||
|
||||
# Head 名称映射
|
||||
self.head_names = [
|
||||
"task_type", # head_0: 12类
|
||||
"creativity_scope", # head_1: 3类
|
||||
"reasoning", # head_2: 2类
|
||||
"contextual_knowledge", # head_3: 2类
|
||||
"number_of_few_shots", # head_4: 6类
|
||||
"domain_knowledge", # head_5: 4类
|
||||
"no_label_reason", # head_6: 1类
|
||||
"constraint_ct", # head_7: 2类
|
||||
]
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None):
|
||||
outputs = self.backbone(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids
|
||||
)
|
||||
# 使用 [CLS] token 的隐层
|
||||
cls_output = outputs.last_hidden_state[:, 0]
|
||||
|
||||
# 各头输出
|
||||
head_outputs = {
|
||||
"task_type": self.head_0(cls_output),
|
||||
"creativity_scope": self.head_1(cls_output),
|
||||
"reasoning": self.head_2(cls_output),
|
||||
"contextual_knowledge": self.head_3(cls_output),
|
||||
"number_of_few_shots": self.head_4(cls_output),
|
||||
"domain_knowledge": self.head_5(cls_output),
|
||||
"no_label_reason": self.head_6(cls_output),
|
||||
"constraint_ct": self.head_7(cls_output),
|
||||
}
|
||||
return head_outputs
|
||||
|
||||
|
||||
class NvidiaComplexityRouter:
|
||||
"""NVIDIA 多头分类器路由封装"""
|
||||
|
||||
MODEL_NAME = "nvidia/prompt-task-and-complexity-classifier"
|
||||
|
||||
# Task type 映射
|
||||
TASK_TYPE_MAP = {
|
||||
0: "Brainstorming", 1: "Chatbot", 2: "Classification",
|
||||
3: "Closed QA", 4: "Code Generation", 5: "Extraction",
|
||||
6: "Open QA", 7: "Other", 8: "Rewrite",
|
||||
9: "Summarization", 10: "Text Generation", 11: "Unknown"
|
||||
}
|
||||
|
||||
# Domain knowledge 映射
|
||||
DOMAIN_MAP = {0: "High", 1: "Low", 2: "Medium", 3: "No"}
|
||||
|
||||
# Creativity 映射
|
||||
CREATIVITY_MAP = {0: "High", 1: "Low", 2: "No"}
|
||||
|
||||
def __init__(self, device: str = "cpu"):
|
||||
self.device = device
|
||||
self.tokenizer = None
|
||||
self.model = None
|
||||
self.config = None
|
||||
self._initialized = False
|
||||
|
||||
def initialize(self):
|
||||
"""延迟加载模型"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
logger.info(f"Loading NVIDIA classifier: {self.MODEL_NAME}")
|
||||
|
||||
# 1. 加载 config
|
||||
self.config = AutoConfig.from_pretrained(self.MODEL_NAME)
|
||||
|
||||
# 2. 加载 tokenizer (slow模式,兼容性好)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME, use_fast=False)
|
||||
|
||||
# 3. 构建模型并加载权重
|
||||
self.model = NvidiaMultiHeadClassifier(self.config)
|
||||
|
||||
model_path = hf_hub_download(self.MODEL_NAME, "model.safetensors")
|
||||
state_dict = load_file(model_path)
|
||||
self.model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
self._initialized = True
|
||||
logger.info("NVIDIA classifier loaded successfully")
|
||||
|
||||
def predict(self, query: str) -> Dict:
|
||||
"""
|
||||
预测查询的多维度特征
|
||||
|
||||
Returns:
|
||||
{
|
||||
"tier": "simple" | "medium" | "complex",
|
||||
"complexity_score": float (0-1),
|
||||
"task_type": str,
|
||||
"domain_knowledge": str,
|
||||
"reasoning": bool,
|
||||
"creativity": str
|
||||
}
|
||||
"""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
|
||||
inputs = self.tokenizer(
|
||||
query, return_tensors="pt", truncation=True, max_length=512, padding=True
|
||||
)
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs)
|
||||
|
||||
# 解析各头输出
|
||||
task_type_idx = torch.argmax(outputs["task_type"], dim=-1).item()
|
||||
task_type = self.TASK_TYPE_MAP.get(task_type_idx, "Unknown")
|
||||
|
||||
domain_idx = torch.argmax(outputs["domain_knowledge"], dim=-1).item()
|
||||
domain = self.DOMAIN_MAP.get(domain_idx, "Unknown")
|
||||
|
||||
creativity_idx = torch.argmax(outputs["creativity_scope"], dim=-1).item()
|
||||
creativity = self.CREATIVITY_MAP.get(creativity_idx, "Unknown")
|
||||
|
||||
reasoning_idx = torch.argmax(outputs["reasoning"], dim=-1).item()
|
||||
needs_reasoning = reasoning_idx == 1
|
||||
|
||||
# 计算综合复杂度评分 (0-1)
|
||||
complexity_score = self._compute_complexity_score(
|
||||
domain=domain,
|
||||
creativity=creativity,
|
||||
needs_reasoning=needs_reasoning,
|
||||
task_type=task_type
|
||||
)
|
||||
|
||||
tier = self._score_to_tier(complexity_score)
|
||||
|
||||
return {
|
||||
"tier": tier,
|
||||
"complexity_score": complexity_score,
|
||||
"task_type": task_type,
|
||||
"domain_knowledge": domain,
|
||||
"reasoning": needs_reasoning,
|
||||
"creativity": creativity,
|
||||
}
|
||||
|
||||
def _compute_complexity_score(self, domain, creativity, needs_reasoning, task_type) -> float:
|
||||
"""
|
||||
综合多维度计算复杂度评分 (0-1)
|
||||
|
||||
权重:
|
||||
- domain_knowledge: 40% (High=1.0, Medium=0.6, Low=0.3, No=0.0)
|
||||
- reasoning: 30% (Yes=1.0, No=0.0)
|
||||
- creativity: 20% (High=1.0, Low=0.4, No=0.0)
|
||||
- task_type: 10% (Code=0.8, QA=0.5, Chatbot=0.2, ...)
|
||||
"""
|
||||
domain_scores = {"High": 1.0, "Medium": 0.6, "Low": 0.3, "No": 0.0}
|
||||
creativity_scores = {"High": 1.0, "Low": 0.4, "No": 0.0}
|
||||
task_complexity = {
|
||||
"Code Generation": 0.8, "Text Generation": 0.7,
|
||||
"Summarization": 0.6, "Rewrite": 0.5,
|
||||
"Open QA": 0.5, "Closed QA": 0.4,
|
||||
"Classification": 0.3, "Extraction": 0.3,
|
||||
"Brainstorming": 0.6, "Chatbot": 0.2,
|
||||
"Other": 0.5, "Unknown": 0.5,
|
||||
}
|
||||
|
||||
score = (
|
||||
0.4 * domain_scores.get(domain, 0.5) +
|
||||
0.3 * (1.0 if needs_reasoning else 0.0) +
|
||||
0.2 * creativity_scores.get(creativity, 0.5) +
|
||||
0.1 * task_complexity.get(task_type, 0.5)
|
||||
)
|
||||
return round(score, 3)
|
||||
|
||||
def _score_to_tier(self, score: float) -> str:
|
||||
if score < 0.35:
|
||||
return "simple"
|
||||
elif score < 0.65:
|
||||
return "medium"
|
||||
else:
|
||||
return "complex"
|
||||
|
||||
def select_model(self, query: str) -> str:
|
||||
"""直接返回推荐的模型名称"""
|
||||
result = self.predict(query)
|
||||
model_map = {
|
||||
"simple": "qwen-flash",
|
||||
"medium": "qwen-plus",
|
||||
"complex": "qwen-max"
|
||||
}
|
||||
return model_map[result["tier"]]
|
||||
|
||||
def benchmark(self, queries: list) -> Dict:
|
||||
"""批量测试"""
|
||||
import time
|
||||
results = []
|
||||
for query in queries:
|
||||
start = time.time()
|
||||
result = self.predict(query)
|
||||
elapsed = (time.time() - start) * 1000
|
||||
results.append({
|
||||
"query": query[:50],
|
||||
"tier": result["tier"],
|
||||
"score": result["complexity_score"],
|
||||
"task": result["task_type"],
|
||||
"domain": result["domain_knowledge"],
|
||||
"reasoning": result["reasoning"],
|
||||
"time_ms": round(elapsed, 1)
|
||||
})
|
||||
|
||||
times = [r["time_ms"] for r in results]
|
||||
return {
|
||||
"avg_ms": round(sum(times) / len(times), 1),
|
||||
"results": results
|
||||
}
|
||||
|
||||
|
||||
# 全局单例
|
||||
_router_instance: Optional[NvidiaComplexityRouter] = None
|
||||
|
||||
def get_nvidia_router() -> NvidiaComplexityRouter:
|
||||
global _router_instance
|
||||
if _router_instance is None:
|
||||
_router_instance = NvidiaComplexityRouter()
|
||||
return _router_instance
|
||||
|
||||
def select_model_by_nvidia(query: str) -> str:
|
||||
return get_nvidia_router().select_model(query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_queries = [
|
||||
"你好",
|
||||
"What is 2+2?",
|
||||
"Explain quantum computing principles in detail",
|
||||
"Write a quicksort algorithm in Python with error handling",
|
||||
"Analyze this 10-page research paper and summarize the key innovations",
|
||||
"Rewrite this sentence to be more concise",
|
||||
"Generate a creative story about a robot",
|
||||
]
|
||||
|
||||
router = NvidiaComplexityRouter()
|
||||
|
||||
print("=" * 80)
|
||||
print("NVIDIA Prompt Task & Complexity Classifier - 3-Tier Router Test")
|
||||
print("=" * 80)
|
||||
|
||||
for query in test_queries:
|
||||
result = router.predict(query)
|
||||
model = router.select_model(query)
|
||||
print(f"\nQuery: {query}")
|
||||
print(f" Tier: {result['tier']}")
|
||||
print(f" Score: {result['complexity_score']}")
|
||||
print(f" Task: {result['task_type']}")
|
||||
print(f" Domain: {result['domain_knowledge']}")
|
||||
print(f" Reasoning: {result['reasoning']}")
|
||||
print(f" Creativity: {result['creativity']}")
|
||||
print(f" -> Model: {model}")
|
||||
@@ -7,5 +7,7 @@ httpx>=0.25.0
|
||||
python-dotenv>=1.0.0
|
||||
transformers>=4.30.0
|
||||
torch>=2.0.0
|
||||
# NVIDIA Multi-head Classifier for 3-tier routing
|
||||
# nvidia/prompt-task-and-complexity-classifier will be loaded via transformers
|
||||
pytest>=7.4.0
|
||||
pytest-asyncio>=0.21.0
|
||||
|
||||
Reference in New Issue
Block a user