- 新增nvidia_router.py: 手动加载NVIDIA prompt-task-and-complexity-classifier模型 - DeBERTa-v3-base backbone + 8个分类头(task_type/creativity/reasoning/domain等) - 综合多维度评分实现simple/medium/complex三级路由 - 映射: simple->qwen-flash, medium->qwen-plus, complex->qwen-max - main.py切换到NVIDIA路由替代RouteLLM BERT二分类 - 移除LiteLLM依赖解决版本冲突,使用原生httpx调用 - 版本升级至v0.3.0
319 lines
11 KiB
Python
319 lines
11 KiB
Python
"""
NVIDIA Prompt Task & Complexity Classifier Router

Manually loads the custom multi-head model and supports 3-tier routing.

Model: nvidia/prompt-task-and-complexity-classifier (184M parameters)
Architecture: DeBERTa-v3-base backbone + 8 classification heads
Outputs: task_type (12 classes), creativity (3 classes), reasoning (2 classes),
         domain_knowledge (4 classes), complexity_score, and other dimensions
"""
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
from transformers import AutoTokenizer, DebertaV2Model, AutoConfig
|
||
from safetensors.torch import load_file
|
||
from huggingface_hub import hf_hub_download
|
||
from typing import Dict, Optional
|
||
import logging
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class ClassificationHead(nn.Module):
    """A single classification head: dropout followed by one linear layer.

    The sub-module names ``dropout`` and ``fc`` determine the state_dict keys
    (``fc.weight`` / ``fc.bias``), so they must not be renamed.

    Args:
        input_dim: Size of the incoming feature vector.
        num_classes: Number of logits produced by the head.
        dropout: Dropout probability applied before the linear layer.
    """

    def __init__(self, input_dim: int, num_classes: int, dropout: float = 0.2):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        """Return the logits for ``x`` (last dimension becomes ``num_classes``)."""
        return self.fc(self.dropout(x))
|
||
|
||
|
||
class NvidiaMultiHeadClassifier(nn.Module):
    """NVIDIA multi-head prompt classifier.

    A DeBERTa backbone feeds one pooled sentence vector into eight independent
    classification heads. The heads are registered as ``head_0`` .. ``head_7``
    so that their parameter names line up with the ``head_N.fc.*`` keys of the
    checkpoint's state_dict.

    Args:
        config: Configuration object exposing ``base_model`` (backbone
            checkpoint name), ``target_sizes`` (mapping head name -> number of
            classes) and optionally ``fc_dropout`` (defaults to 0.2).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # DeBERTa backbone
        self.backbone = DebertaV2Model.from_pretrained(
            config.base_model,
            ignore_mismatched_sizes=True,
        )

        # Read the width from the loaded backbone instead of hard-coding 768,
        # so a differently sized base model still produces matching heads.
        hidden_size = self.backbone.config.hidden_size
        dropout = getattr(config, "fc_dropout", 0.2)

        # Head names in checkpoint order: position i corresponds to head_i.
        self.head_names = [
            "task_type",             # head_0
            "creativity_scope",      # head_1
            "reasoning",             # head_2
            "contextual_knowledge",  # head_3
            "number_of_few_shots",   # head_4
            "domain_knowledge",      # head_5
            "no_label_reason",       # head_6
            "constraint_ct",         # head_7
        ]

        target_sizes = config.target_sizes
        for index, name in enumerate(self.head_names):
            # setattr on an nn.Module registers the head as a sub-module, so
            # the parameter names stay ``head_<index>.fc.weight`` / ``.bias``.
            setattr(
                self,
                f"head_{index}",
                ClassificationHead(hidden_size, target_sizes[name], dropout),
            )

    @staticmethod
    def _masked_mean(last_hidden_state, attention_mask=None):
        """Average token states over the sequence axis, ignoring padding.

        Args:
            last_hidden_state: Tensor of per-token hidden states whose
                second axis is the sequence axis.
            attention_mask: Optional 0/1 mask over tokens. When ``None`` every
                token is counted.

        Returns:
            One pooled vector per sequence.
        """
        if attention_mask is None:
            return last_hidden_state.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        summed = (last_hidden_state * mask).sum(dim=1)
        # Clamp so a fully masked row yields zeros instead of NaN.
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return summed / counts

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Run the backbone and every head.

        Returns:
            Dict mapping each name in ``head_names`` to that head's raw logits.
        """
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # NOTE(review): the published reference implementation for this
        # checkpoint feeds the heads an attention-masked mean of all token
        # states rather than the first-token ([CLS]) state; the heads' weights
        # expect that input. Confirm against the model card when upgrading.
        pooled = self._masked_mean(outputs.last_hidden_state, attention_mask)

        return {
            name: getattr(self, f"head_{index}")(pooled)
            for index, name in enumerate(self.head_names)
        }
|
||
|
||
|
||
class NvidiaComplexityRouter:
    """Three-tier prompt router wrapping the NVIDIA multi-head classifier.

    A query is classified along several dimensions (task type, domain
    knowledge, creativity, reasoning). Those labels are blended into a single
    complexity score, the score is bucketed into ``simple`` / ``medium`` /
    ``complex``, and each tier maps to a model name.

    The model is loaded lazily on the first call to :meth:`predict` (or by an
    explicit :meth:`initialize`).
    """

    MODEL_NAME = "nvidia/prompt-task-and-complexity-classifier"

    # Class index -> label for the task_type head.
    TASK_TYPE_MAP = {
        0: "Brainstorming", 1: "Chatbot", 2: "Classification",
        3: "Closed QA", 4: "Code Generation", 5: "Extraction",
        6: "Open QA", 7: "Other", 8: "Rewrite",
        9: "Summarization", 10: "Text Generation", 11: "Unknown",
    }

    # Class index -> label for the domain_knowledge head.
    DOMAIN_MAP = {0: "High", 1: "Low", 2: "Medium", 3: "No"}

    # Class index -> label for the creativity_scope head.
    CREATIVITY_MAP = {0: "High", 1: "Low", 2: "No"}

    # Per-label contributions used by _compute_complexity_score. Kept at class
    # level so they are built once rather than on every prediction.
    DOMAIN_SCORES = {"High": 1.0, "Medium": 0.6, "Low": 0.3, "No": 0.0}
    CREATIVITY_SCORES = {"High": 1.0, "Low": 0.4, "No": 0.0}
    TASK_COMPLEXITY = {
        "Code Generation": 0.8, "Text Generation": 0.7,
        "Summarization": 0.6, "Rewrite": 0.5,
        "Open QA": 0.5, "Closed QA": 0.4,
        "Classification": 0.3, "Extraction": 0.3,
        "Brainstorming": 0.6, "Chatbot": 0.2,
        "Other": 0.5, "Unknown": 0.5,
    }

    # Tier -> model name returned by select_model.
    TIER_MODEL_MAP = {
        "simple": "qwen-flash",
        "medium": "qwen-plus",
        "complex": "qwen-max",
    }

    def __init__(self, device: str = "cpu"):
        """Create an unloaded router.

        Args:
            device: Torch device string the model and inputs are moved to.
        """
        self.device = device
        self.tokenizer = None
        self.model = None
        self.config = None
        self._initialized = False

    def initialize(self):
        """Load config, tokenizer and weights; a no-op once loaded."""
        if self._initialized:
            return

        logger.info("Loading NVIDIA classifier: %s", self.MODEL_NAME)

        # 1. Load the config
        self.config = AutoConfig.from_pretrained(self.MODEL_NAME)

        # 2. Load the tokenizer (slow tokenizer, chosen for compatibility)
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME, use_fast=False)

        # 3. Build the model and load the checkpoint weights
        self.model = NvidiaMultiHeadClassifier(self.config)

        model_path = hf_hub_download(self.MODEL_NAME, "model.safetensors")
        state_dict = load_file(model_path)
        # strict=False tolerates key mismatches, but a silent mismatch would
        # leave layers randomly initialised, so surface it in the log.
        load_result = self.model.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys:
            logger.warning(
                "Checkpoint is missing %d parameter(s), e.g. %s",
                len(load_result.missing_keys),
                load_result.missing_keys[:5],
            )
        if load_result.unexpected_keys:
            logger.warning(
                "Checkpoint has %d unexpected parameter(s), e.g. %s",
                len(load_result.unexpected_keys),
                load_result.unexpected_keys[:5],
            )

        self.model.to(self.device)
        self.model.eval()
        # Set last so a failure above leaves the router retryable.
        self._initialized = True
        logger.info("NVIDIA classifier loaded successfully")

    @staticmethod
    def _top_label(logits, label_map) -> str:
        """Map the arg-max class of single-query ``logits`` to its label."""
        index = torch.argmax(logits, dim=-1).item()
        return label_map.get(index, "Unknown")

    def predict(self, query: str) -> Dict:
        """Classify ``query`` and derive its routing tier.

        Returns:
            A dict with the keys::

                "tier": "simple" | "medium" | "complex"
                "complexity_score": float in [0, 1]
                "task_type": str
                "domain_knowledge": str
                "reasoning": bool
                "creativity": str
        """
        if not self._initialized:
            self.initialize()

        inputs = self.tokenizer(
            query, return_tensors="pt", truncation=True, max_length=512, padding=True
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Decode each head's arg-max class.
        task_type = self._top_label(outputs["task_type"], self.TASK_TYPE_MAP)
        domain = self._top_label(outputs["domain_knowledge"], self.DOMAIN_MAP)
        creativity = self._top_label(outputs["creativity_scope"], self.CREATIVITY_MAP)
        # Class index 1 of the reasoning head means "reasoning required".
        needs_reasoning = torch.argmax(outputs["reasoning"], dim=-1).item() == 1

        complexity_score = self._compute_complexity_score(
            domain=domain,
            creativity=creativity,
            needs_reasoning=needs_reasoning,
            task_type=task_type,
        )

        return {
            "tier": self._score_to_tier(complexity_score),
            "complexity_score": complexity_score,
            "task_type": task_type,
            "domain_knowledge": domain,
            "reasoning": needs_reasoning,
            "creativity": creativity,
        }

    def _compute_complexity_score(self, domain, creativity, needs_reasoning, task_type) -> float:
        """Blend the predicted labels into one complexity score in [0, 1].

        Weights:
            - domain_knowledge: 40% (High=1.0, Medium=0.6, Low=0.3, No=0.0)
            - reasoning:        30% (yes=1.0, no=0.0)
            - creativity:       20% (High=1.0, Low=0.4, No=0.0)
            - task_type:        10% (see ``TASK_COMPLEXITY``)

        A label missing from its table contributes the neutral value 0.5.
        The result is rounded to three decimals.
        """
        score = (
            0.4 * self.DOMAIN_SCORES.get(domain, 0.5)
            + 0.3 * (1.0 if needs_reasoning else 0.0)
            + 0.2 * self.CREATIVITY_SCORES.get(creativity, 0.5)
            + 0.1 * self.TASK_COMPLEXITY.get(task_type, 0.5)
        )
        return round(score, 3)

    def _score_to_tier(self, score: float) -> str:
        """Bucket a score: ``< 0.35`` simple, ``< 0.65`` medium, else complex."""
        if score < 0.35:
            return "simple"
        if score < 0.65:
            return "medium"
        return "complex"

    def select_model(self, query: str) -> str:
        """Return the model name recommended for ``query``."""
        return self.TIER_MODEL_MAP[self.predict(query)["tier"]]

    def benchmark(self, queries: list) -> Dict:
        """Classify each query and report per-query latency.

        Args:
            queries: Queries to classify. An empty collection is allowed.

        Returns:
            ``{"avg_ms": float, "results": [...]}`` where each result holds the
            first 50 characters of the query, its tier, score, task, domain,
            reasoning flag and elapsed milliseconds. ``avg_ms`` is ``0.0`` when
            there are no queries.
        """
        import time

        queries = list(queries)
        if not queries:
            # Avoid dividing by zero (and loading the model) for no work.
            return {"avg_ms": 0.0, "results": []}

        # Load up front so model start-up time is not billed to the first query.
        self.initialize()

        results = []
        for query in queries:
            # perf_counter is the clock intended for measuring short intervals.
            start = time.perf_counter()
            result = self.predict(query)
            elapsed = (time.perf_counter() - start) * 1000
            results.append({
                "query": query[:50],
                "tier": result["tier"],
                "score": result["complexity_score"],
                "task": result["task_type"],
                "domain": result["domain_knowledge"],
                "reasoning": result["reasoning"],
                "time_ms": round(elapsed, 1),
            })

        times = [r["time_ms"] for r in results]
        return {
            "avg_ms": round(sum(times) / len(times), 1),
            "results": results,
        }
|
||
|
||
|
||
# Module-wide singleton, created on first request.
_router_instance: Optional[NvidiaComplexityRouter] = None


def get_nvidia_router() -> NvidiaComplexityRouter:
    """Return the shared router, constructing it on the first call."""
    global _router_instance
    shared = _router_instance
    if shared is None:
        shared = NvidiaComplexityRouter()
        _router_instance = shared
    return shared
|
||
|
||
def select_model_by_nvidia(query: str) -> str:
    """Return the model name the shared NVIDIA router selects for ``query``."""
    shared_router = get_nvidia_router()
    return shared_router.select_model(query)
|
||
|
||
|
||
if __name__ == "__main__":
    # Manual smoke test: classify a handful of prompts and print the routing.
    test_queries = [
        "你好",
        "What is 2+2?",
        "Explain quantum computing principles in detail",
        "Write a quicksort algorithm in Python with error handling",
        "Analyze this 10-page research paper and summarize the key innovations",
        "Rewrite this sentence to be more concise",
        "Generate a creative story about a robot",
    ]

    router = NvidiaComplexityRouter()

    banner = "=" * 80
    print(banner)
    print("NVIDIA Prompt Task & Complexity Classifier - 3-Tier Router Test")
    print(banner)

    # (printed label, key in the prediction dict), in display order.
    report_fields = (
        ("Tier", "tier"),
        ("Score", "complexity_score"),
        ("Task", "task_type"),
        ("Domain", "domain_knowledge"),
        ("Reasoning", "reasoning"),
        ("Creativity", "creativity"),
    )

    for query in test_queries:
        result = router.predict(query)
        model = router.select_model(query)
        print(f"\nQuery: {query}")
        for label, key in report_fields:
            print(f" {label}: {result[key]}")
        print(f" -> Model: {model}")
|