feat(router): 集成NVIDIA多头分类器实现3-tier智能路由
- 新增nvidia_router.py: 手动加载NVIDIA prompt-task-and-complexity-classifier模型 - DeBERTa-v3-base backbone + 8个分类头(task_type/creativity/reasoning/domain等) - 综合多维度评分实现simple/medium/complex三级路由 - 映射: simple->qwen-flash, medium->qwen-plus, complex->qwen-max - main.py切换到NVIDIA路由替代RouteLLM BERT二分类 - 移除LiteLLM依赖解决版本冲突,使用原生httpx调用 - 版本升级至v0.3.0
This commit is contained in:
50
main.py
50
main.py
@@ -13,7 +13,7 @@ from litellm import acompletion
|
|||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
from config import MODEL_CONFIG, ROUTING_THRESHOLDS, DEFAULT_ROUTING, DASHSCOPE_API_KEY
|
from config import MODEL_CONFIG, ROUTING_THRESHOLDS, DEFAULT_ROUTING, DASHSCOPE_API_KEY
|
||||||
from bert_router import get_bert_router, route_with_bert
|
from nvidia_router import get_nvidia_router, select_model_by_nvidia
|
||||||
|
|
||||||
# 配置 LiteLLM 使用 DashScope (Qwen)
|
# 配置 LiteLLM 使用 DashScope (Qwen)
|
||||||
if DASHSCOPE_API_KEY:
|
if DASHSCOPE_API_KEY:
|
||||||
@@ -21,15 +21,15 @@ if DASHSCOPE_API_KEY:
|
|||||||
# Qwen 使用 OpenAI 兼容接口,但需要通过 api_base 指定
|
# Qwen 使用 OpenAI 兼容接口,但需要通过 api_base 指定
|
||||||
litellm.api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
litellm.api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||||
|
|
||||||
# BERT Router 实例(延迟加载)
|
# NVIDIA Router 实例(延迟加载)
|
||||||
_bert_router = None
|
_nvidia_router = None
|
||||||
|
|
||||||
def get_router():
    """Return the shared NVIDIA router instance, creating it on first use."""
    global _nvidia_router
    # Fast path: the router was already obtained by an earlier call.
    if _nvidia_router is not None:
        return _nvidia_router
    _nvidia_router = get_nvidia_router()
    return _nvidia_router
|
||||||
|
|
||||||
|
|
||||||
# 调用历史记录
|
# 调用历史记录
|
||||||
@@ -69,8 +69,8 @@ class StatsResponse(BaseModel):
|
|||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="LLM Router MVP",
|
title="LLM Router MVP",
|
||||||
description="基于 LiteLLM 的多提供商路由服务",
|
description="基于 LiteLLM + NVIDIA 分类器的多提供商路由服务(支持3-tier智能路由)",
|
||||||
version="0.2.0",
|
version="0.3.0",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -102,30 +102,26 @@ def select_model_by_length(messages: List[Message]) -> str:
|
|||||||
return DEFAULT_ROUTING["complex"]
|
return DEFAULT_ROUTING["complex"]
|
||||||
|
|
||||||
|
|
||||||
def select_model_by_nvidia_classifier(messages: List[Message]) -> str:
    """
    Pick a model with the NVIDIA multi-head classifier (3-tier routing).

    The router turns its multi-dimensional complexity estimate into one of
    the Qwen models:
    - simple  -> qwen-flash
    - medium  -> qwen-plus
    - complex -> qwen-max

    If the classifier raises for any reason, the choice falls back to the
    token-length strategy implemented by ``select_model_by_length``.
    """
    # The content of the final message is the text that gets classified.
    # NOTE(review): this is the last message regardless of its role, not
    # necessarily the last *user* message - confirm that is intended.
    query = ""
    if messages:
        query = messages[-1].content

    try:
        return get_router().select_model(query)
    except Exception as e:
        # Routing is best-effort: report the failure and use the fallback.
        print(f"NVIDIA routing failed: {e}, falling back to token length")
        return select_model_by_length(messages)
|
||||||
|
|
||||||
|
|
||||||
@@ -180,8 +176,8 @@ async def chat_completions(request: ChatRequest):
|
|||||||
if request.model:
|
if request.model:
|
||||||
model_key = request.model
|
model_key = request.model
|
||||||
else:
|
else:
|
||||||
# 使用 BERT 智能路由(替代原来的 token 长度路由)
|
# 使用 NVIDIA 多头分类器智能路由(支持3-tier)
|
||||||
model_key = select_model_by_bert(request.messages)
|
model_key = select_model_by_nvidia_classifier(request.messages)
|
||||||
|
|
||||||
# 获取 LiteLLM 模型名称
|
# 获取 LiteLLM 模型名称
|
||||||
provider_model = get_provider_model(model_key)
|
provider_model = get_provider_model(model_key)
|
||||||
@@ -293,7 +289,7 @@ async def get_stats():
|
|||||||
@app.get("/health")
async def health_check():
    """
    Health check endpoint.

    Returns a small JSON object with the service status, the application
    version and the name of the active routing strategy.
    """
    # Report the version the FastAPI app was constructed with, so the value
    # cannot drift from the one declared in ``FastAPI(version=...)``.
    return {"status": "healthy", "version": app.version, "router": "nvidia-3tier"}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
318
nvidia_router.py
Normal file
318
nvidia_router.py
Normal file
@@ -0,0 +1,318 @@
|
|||||||
|
"""
|
||||||
|
NVIDIA Prompt Task & Complexity Classifier Router
|
||||||
|
手动加载自定义多头模型,支持3-tier路由
|
||||||
|
|
||||||
|
模型: nvidia/prompt-task-and-complexity-classifier (184M参数)
|
||||||
|
架构: DeBERTa-v3-base backbone + 8个分类头
|
||||||
|
输出: task_type(12类), creativity(3类), reasoning(2类),
|
||||||
|
domain_knowledge(4类), complexity_score 等多维度
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers import AutoTokenizer, DebertaV2Model, AutoConfig
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from typing import Dict, Optional
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationHead(nn.Module):
    """A single classification head: dropout followed by one linear layer."""

    def __init__(self, input_dim: int, num_classes: int, dropout: float = 0.2):
        super().__init__()
        # The attribute names `dropout` and `fc` are part of the parameter
        # names this module contributes to a state_dict, so they stay as-is.
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        """Apply dropout to ``x`` and return the linear layer's logits."""
        return self.fc(self.dropout(x))
|
||||||
|
|
||||||
|
|
||||||
|
class NvidiaMultiHeadClassifier(nn.Module):
    """
    Multi-head classifier: a DeBERTa backbone feeding eight independent heads.

    The heads are registered as ``head_0`` .. ``head_7`` so that their
    parameter names line up with the ``head_N`` entries of the state_dict
    that is loaded into this module.

    Args:
        config: Model configuration. Must provide ``base_model`` (the
            backbone to load) and ``target_sizes`` (a mapping from head
            name to number of classes). ``fc_dropout`` is optional and
            defaults to 0.2.
    """

    # Output name of each head, listed in head_0 .. head_7 order. This is the
    # single source of truth for both construction and the forward pass.
    HEAD_NAMES = (
        "task_type",             # head_0
        "creativity_scope",      # head_1
        "reasoning",             # head_2
        "contextual_knowledge",  # head_3
        "number_of_few_shots",   # head_4
        "domain_knowledge",      # head_5
        "no_label_reason",       # head_6
        "constraint_ct",         # head_7
    )

    def __init__(self, config):
        super().__init__()
        self.config = config

        # DeBERTa backbone
        self.backbone = DebertaV2Model.from_pretrained(
            config.base_model,
            ignore_mismatched_sizes=True
        )

        # Take the feature width from the backbone that was actually loaded
        # instead of assuming the 768 of DeBERTa-v3-base; 768 remains the
        # fallback if the backbone config does not expose it.
        hidden_size = getattr(self.backbone.config, "hidden_size", 768)
        dropout = getattr(config, "fc_dropout", 0.2)

        # One head per entry of HEAD_NAMES, sized from config.target_sizes.
        target_sizes = config.target_sizes
        for index, name in enumerate(self.HEAD_NAMES):
            setattr(
                self,
                f"head_{index}",
                ClassificationHead(hidden_size, target_sizes[name], dropout),
            )

        # Head name mapping (position i names the output of head_i).
        self.head_names = list(self.HEAD_NAMES)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """
        Run the backbone once and apply every head to the pooled output.

        Returns:
            A dict mapping each head name to that head's output tensor.
        """
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        # Use the hidden state of the first ([CLS]) token as the pooled vector.
        # NOTE(review): the published reference code for this checkpoint
        # appears to mean-pool over tokens rather than take [CLS]; confirm
        # the pooling here matches how the head weights were trained.
        cls_output = outputs.last_hidden_state[:, 0]

        return {
            name: getattr(self, f"head_{index}")(cls_output)
            for index, name in enumerate(self.head_names)
        }
|
||||||
|
|
||||||
|
|
||||||
|
class NvidiaComplexityRouter:
    """
    Routing wrapper around the NVIDIA multi-head prompt classifier.

    The model is loaded lazily on the first prediction. The per-head
    predictions are combined into a single complexity score, the score is
    bucketed into a tier ("simple" / "medium" / "complex"), and the tier is
    mapped to a Qwen model name.
    """

    MODEL_NAME = "nvidia/prompt-task-and-complexity-classifier"

    # NOTE(review): the index -> label orders of the three maps below are
    # hard-coded here; confirm they match the label maps of the checkpoint.

    # Task type mapping
    TASK_TYPE_MAP = {
        0: "Brainstorming", 1: "Chatbot", 2: "Classification",
        3: "Closed QA", 4: "Code Generation", 5: "Extraction",
        6: "Open QA", 7: "Other", 8: "Rewrite",
        9: "Summarization", 10: "Text Generation", 11: "Unknown"
    }

    # Domain knowledge mapping
    DOMAIN_MAP = {0: "High", 1: "Low", 2: "Medium", 3: "No"}

    # Creativity mapping
    CREATIVITY_MAP = {0: "High", 1: "Low", 2: "No"}

    # Tier -> model name returned by select_model().
    TIER_MODEL_MAP = {
        "simple": "qwen-flash",
        "medium": "qwen-plus",
        "complex": "qwen-max",
    }

    # Per-label contributions used by _compute_complexity_score(). They are
    # class-level constants so they are not rebuilt on every prediction.
    _DOMAIN_SCORES = {"High": 1.0, "Medium": 0.6, "Low": 0.3, "No": 0.0}
    _CREATIVITY_SCORES = {"High": 1.0, "Low": 0.4, "No": 0.0}
    _TASK_COMPLEXITY = {
        "Code Generation": 0.8, "Text Generation": 0.7,
        "Summarization": 0.6, "Rewrite": 0.5,
        "Open QA": 0.5, "Closed QA": 0.4,
        "Classification": 0.3, "Extraction": 0.3,
        "Brainstorming": 0.6, "Chatbot": 0.2,
        "Other": 0.5, "Unknown": 0.5,
    }

    def __init__(self, device: str = "cpu"):
        """Create an unloaded router; the model is moved to ``device`` on load."""
        self.device = device
        self.tokenizer = None
        self.model = None
        self.config = None
        self._initialized = False

    def initialize(self):
        """Load config, tokenizer and weights. Does nothing if already loaded."""
        if self._initialized:
            return

        logger.info("Loading NVIDIA classifier: %s", self.MODEL_NAME)

        # 1. Load the config
        self.config = AutoConfig.from_pretrained(self.MODEL_NAME)

        # 2. Load the tokenizer (slow implementation)
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME, use_fast=False)

        # 3. Build the model and load the checkpoint weights
        self.model = NvidiaMultiHeadClassifier(self.config)

        model_path = hf_hub_download(self.MODEL_NAME, "model.safetensors")
        state_dict = load_file(model_path)
        # strict=False tolerates key mismatches, but a head or backbone
        # tensor that fails to load would leave randomly initialised weights
        # in place. Surface any mismatch instead of discarding it.
        load_result = self.model.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys:
            logger.warning(
                "NVIDIA classifier: %d parameter(s) missing from checkpoint: %s",
                len(load_result.missing_keys), load_result.missing_keys,
            )
        if load_result.unexpected_keys:
            logger.warning(
                "NVIDIA classifier: %d unexpected checkpoint key(s) ignored: %s",
                len(load_result.unexpected_keys), load_result.unexpected_keys,
            )

        self.model.to(self.device)
        self.model.eval()
        self._initialized = True
        logger.info("NVIDIA classifier loaded successfully")

    def predict(self, query: str) -> Dict:
        """
        Predict the multi-dimensional features of a query.

        Loads the model first if that has not happened yet.

        Returns:
            {
                "tier": "simple" | "medium" | "complex",
                "complexity_score": float,
                "task_type": str,
                "domain_knowledge": str,
                "reasoning": bool,
                "creativity": str
            }
        """
        if not self._initialized:
            self.initialize()

        inputs = self.tokenizer(
            query, return_tensors="pt", truncation=True, max_length=512, padding=True
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Decode each head: take the arg-max class and translate it to a label.
        task_type_idx = torch.argmax(outputs["task_type"], dim=-1).item()
        task_type = self.TASK_TYPE_MAP.get(task_type_idx, "Unknown")

        domain_idx = torch.argmax(outputs["domain_knowledge"], dim=-1).item()
        domain = self.DOMAIN_MAP.get(domain_idx, "Unknown")

        creativity_idx = torch.argmax(outputs["creativity_scope"], dim=-1).item()
        creativity = self.CREATIVITY_MAP.get(creativity_idx, "Unknown")

        # Class index 1 of the reasoning head is treated as "needs reasoning".
        reasoning_idx = torch.argmax(outputs["reasoning"], dim=-1).item()
        needs_reasoning = reasoning_idx == 1

        # Combine the dimensions into one complexity score
        complexity_score = self._compute_complexity_score(
            domain=domain,
            creativity=creativity,
            needs_reasoning=needs_reasoning,
            task_type=task_type
        )

        tier = self._score_to_tier(complexity_score)

        return {
            "tier": tier,
            "complexity_score": complexity_score,
            "task_type": task_type,
            "domain_knowledge": domain,
            "reasoning": needs_reasoning,
            "creativity": creativity,
        }

    def _compute_complexity_score(self, domain, creativity, needs_reasoning, task_type) -> float:
        """
        Combine the predicted dimensions into one complexity score.

        Weights:
        - domain_knowledge: 40% (High=1.0, Medium=0.6, Low=0.3, No=0.0)
        - reasoning:        30% (True=1.0, False=0.0)
        - creativity:       20% (High=1.0, Low=0.4, No=0.0)
        - task_type:        10% (Code Generation=0.8, ..., Chatbot=0.2)

        A label that is not in the corresponding table contributes 0.5.
        The result is rounded to three decimal places.
        """
        score = (
            0.4 * self._DOMAIN_SCORES.get(domain, 0.5) +
            0.3 * (1.0 if needs_reasoning else 0.0) +
            0.2 * self._CREATIVITY_SCORES.get(creativity, 0.5) +
            0.1 * self._TASK_COMPLEXITY.get(task_type, 0.5)
        )
        return round(score, 3)

    def _score_to_tier(self, score: float) -> str:
        """Bucket a score: < 0.35 is simple, < 0.65 is medium, otherwise complex."""
        if score < 0.35:
            return "simple"
        elif score < 0.65:
            return "medium"
        else:
            return "complex"

    def select_model(self, query: str) -> str:
        """Return the recommended model name for ``query``."""
        result = self.predict(query)
        return self.TIER_MODEL_MAP[result["tier"]]

    def benchmark(self, queries: list) -> Dict:
        """
        Run predict() on every query and record the per-query latency.

        Returns:
            {"avg_ms": float, "results": [per-query dict, ...]}. For an empty
            ``queries`` the average is reported as 0.0 instead of raising
            ZeroDivisionError.
        """
        import time
        results = []
        for query in queries:
            # perf_counter is the clock intended for measuring short intervals.
            start = time.perf_counter()
            result = self.predict(query)
            elapsed = (time.perf_counter() - start) * 1000
            results.append({
                "query": query[:50],
                "tier": result["tier"],
                "score": result["complexity_score"],
                "task": result["task_type"],
                "domain": result["domain_knowledge"],
                "reasoning": result["reasoning"],
                "time_ms": round(elapsed, 1)
            })

        times = [r["time_ms"] for r in results]
        avg_ms = round(sum(times) / len(times), 1) if times else 0.0
        return {
            "avg_ms": avg_ms,
            "results": results
        }
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
|
_router_instance: Optional[NvidiaComplexityRouter] = None
|
||||||
|
|
||||||
|
def get_nvidia_router() -> NvidiaComplexityRouter:
    """Return the module-wide router, constructing it on the first call."""
    global _router_instance
    # Fast path: an instance was already created by an earlier call.
    if _router_instance is not None:
        return _router_instance
    _router_instance = NvidiaComplexityRouter()
    return _router_instance
|
||||||
|
|
||||||
|
def select_model_by_nvidia(query: str) -> str:
    """Return the model name the shared router recommends for ``query``."""
    router = get_nvidia_router()
    return router.select_model(query)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual smoke test: classify a handful of sample prompts and print the
    # predicted dimensions together with the model each one is routed to.
    test_queries = [
        "你好",
        "What is 2+2?",
        "Explain quantum computing principles in detail",
        "Write a quicksort algorithm in Python with error handling",
        "Analyze this 10-page research paper and summarize the key innovations",
        "Rewrite this sentence to be more concise",
        "Generate a creative story about a robot",
    ]

    router = NvidiaComplexityRouter()

    banner = "=" * 80
    print(banner)
    print("NVIDIA Prompt Task & Complexity Classifier - 3-Tier Router Test")
    print(banner)

    for query in test_queries:
        result = router.predict(query)
        model = router.select_model(query)
        # Assemble the report for this query and emit it in one write.
        report = [
            f"\nQuery: {query}",
            f" Tier: {result['tier']}",
            f" Score: {result['complexity_score']}",
            f" Task: {result['task_type']}",
            f" Domain: {result['domain_knowledge']}",
            f" Reasoning: {result['reasoning']}",
            f" Creativity: {result['creativity']}",
            f" -> Model: {model}",
        ]
        print("\n".join(report))
|
||||||
@@ -7,5 +7,7 @@ httpx>=0.25.0
|
|||||||
python-dotenv>=1.0.0
|
python-dotenv>=1.0.0
|
||||||
transformers>=4.30.0
|
transformers>=4.30.0
|
||||||
torch>=2.0.0
|
torch>=2.0.0
|
||||||
|
# NVIDIA Multi-head Classifier for 3-tier routing
|
||||||
|
# nvidia/prompt-task-and-complexity-classifier is downloaded at runtime: config/tokenizer via transformers, weights via huggingface_hub + safetensors
|
||||||
pytest>=7.4.0
|
pytest>=7.4.0
|
||||||
pytest-asyncio>=0.21.0
|
pytest-asyncio>=0.21.0
|
||||||
|
|||||||
Reference in New Issue
Block a user