diff --git a/main.py b/main.py index ffd03a2..abc6b9a 100644 --- a/main.py +++ b/main.py @@ -13,7 +13,7 @@ from litellm import acompletion import litellm from config import MODEL_CONFIG, ROUTING_THRESHOLDS, DEFAULT_ROUTING, DASHSCOPE_API_KEY -from bert_router import get_bert_router, route_with_bert +from nvidia_router import get_nvidia_router, select_model_by_nvidia # 配置 LiteLLM 使用 DashScope (Qwen) if DASHSCOPE_API_KEY: @@ -21,15 +21,15 @@ if DASHSCOPE_API_KEY: # Qwen 使用 OpenAI 兼容接口,但需要通过 api_base 指定 litellm.api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1" -# BERT Router 实例(延迟加载) -_bert_router = None +# NVIDIA Router 实例(延迟加载) +_nvidia_router = None def get_router(): - """获取 BERT Router 实例(延迟加载)""" - global _bert_router - if _bert_router is None: - _bert_router = get_bert_router() - return _bert_router + """获取 NVIDIA Router 实例(延迟加载)""" + global _nvidia_router + if _nvidia_router is None: + _nvidia_router = get_nvidia_router() + return _nvidia_router # 调用历史记录 @@ -69,8 +69,8 @@ class StatsResponse(BaseModel): app = FastAPI( title="LLM Router MVP", - description="基于 LiteLLM 的多提供商路由服务", - version="0.2.0", + description="基于 LiteLLM + NVIDIA 分类器的多提供商路由服务(支持3-tier智能路由)", + version="0.3.0", ) @@ -102,30 +102,26 @@ def select_model_by_length(messages: List[Message]) -> str: return DEFAULT_ROUTING["complex"] -def select_model_by_bert(messages: List[Message]) -> str: +def select_model_by_nvidia_classifier(messages: List[Message]) -> str: """ - 基于 BERT 分类器选择模型 + 基于 NVIDIA 多头分类器选择模型(3-tier路由) - BERT 输出: strong / weak + NVIDIA 输出: 多维度复杂度评分 映射到 Qwen 模型: - - strong -> qwen-max (复杂任务) - - weak -> qwen-flash (简单任务) + - simple -> qwen-flash (简单任务) + - medium -> qwen-plus (中等任务) + - complex -> qwen-max (复杂任务) """ # 取最后一条用户消息作为查询 query = messages[-1].content if messages else "" try: router = get_router() - complexity = router.predict(query) - - # BERT 二分类映射到三模型 - if complexity == "strong": - return "qwen-max" - else: - return "qwen-flash" + model = router.select_model(query) + return model except Exception as e: - # BERT 失败时回退到 token 长度策略 - print(f"BERT routing failed: {e}, falling back to token length") + # NVIDIA 分类器失败时回退到 token 长度策略 + print(f"NVIDIA routing failed: {e}, falling back to token length") return select_model_by_length(messages) @@ -180,8 +176,8 @@ async def chat_completions(request: ChatRequest): if request.model: model_key = request.model else: - # 使用 BERT 智能路由(替代原来的 token 长度路由) - model_key = select_model_by_bert(request.messages) + # 使用 NVIDIA 多头分类器智能路由(支持3-tier) + model_key = select_model_by_nvidia_classifier(request.messages) # 获取 LiteLLM 模型名称 provider_model = get_provider_model(model_key) @@ -293,7 +289,7 @@ async def get_stats(): @app.get("/health") async def health_check(): """健康检查""" - return {"status": "healthy", "version": "0.2.0"} + return {"status": "healthy", "version": "0.3.0", "router": "nvidia-3tier"} if __name__ == "__main__": diff --git a/nvidia_router.py b/nvidia_router.py new file mode 100644 index 0000000..8b51412 --- /dev/null +++ b/nvidia_router.py @@ -0,0 +1,318 @@ +""" +NVIDIA Prompt Task & Complexity Classifier Router +手动加载自定义多头模型,支持3-tier路由 + +模型: nvidia/prompt-task-and-complexity-classifier (184M参数) +架构: DeBERTa-v3-base backbone + 8个分类头 +输出: task_type(12类), creativity(3类), reasoning(2类), + domain_knowledge(4类), complexity_score 等多维度 +""" + +import torch +import torch.nn as nn +from transformers import AutoTokenizer, DebertaV2Model, AutoConfig +from safetensors.torch import load_file +from huggingface_hub import hf_hub_download +from typing import Dict, Optional +import logging + +logger = logging.getLogger(__name__) + + +class ClassificationHead(nn.Module): + """单个分类头""" + def __init__(self, input_dim: int, num_classes: int, dropout: float = 0.2): + super().__init__() + self.dropout = nn.Dropout(dropout) + self.fc = nn.Linear(input_dim, num_classes) + + def forward(self, x): + x = self.dropout(x) + return self.fc(x) + + +class NvidiaMultiHeadClassifier(nn.Module): + """ + NVIDIA 多头分类器 + DeBERTa backbone + 8个独立分类头 + """ + def __init__(self, config): + super().__init__() + self.config = config + + # DeBERTa backbone + self.backbone = DebertaV2Model.from_pretrained( + config.base_model, + ignore_mismatched_sizes=True + ) + + hidden_size = 768 # DeBERTa-v3-base + dropout = config.fc_dropout if hasattr(config, 'fc_dropout') else 0.2 + + # 8个分类头 (与 state_dict 中的 head_0 ~ head_7 对应) + target_sizes = config.target_sizes + self.head_0 = ClassificationHead(hidden_size, target_sizes["task_type"], dropout) # 12类 + self.head_1 = ClassificationHead(hidden_size, target_sizes["creativity_scope"], dropout) # 3类 + self.head_2 = ClassificationHead(hidden_size, target_sizes["reasoning"], dropout) # 2类 + self.head_3 = ClassificationHead(hidden_size, target_sizes["contextual_knowledge"], dropout) # 2类 + self.head_4 = ClassificationHead(hidden_size, target_sizes["number_of_few_shots"], dropout) # 6类 + self.head_5 = ClassificationHead(hidden_size, target_sizes["domain_knowledge"], dropout) # 4类 + self.head_6 = ClassificationHead(hidden_size, target_sizes["no_label_reason"], dropout) # 1类 + self.head_7 = ClassificationHead(hidden_size, target_sizes["constraint_ct"], dropout) # 2类 + + # Head 名称映射 + self.head_names = [ + "task_type", # head_0: 12类 + "creativity_scope", # head_1: 3类 + "reasoning", # head_2: 2类 + "contextual_knowledge", # head_3: 2类 + "number_of_few_shots", # head_4: 6类 + "domain_knowledge", # head_5: 4类 + "no_label_reason", # head_6: 1类 + "constraint_ct", # head_7: 2类 + ] + + def forward(self, input_ids, attention_mask=None, token_type_ids=None): + outputs = self.backbone( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids + ) + # 使用 [CLS] token 的隐层 + cls_output = outputs.last_hidden_state[:, 0] + + # 各头输出 + head_outputs = { + "task_type": self.head_0(cls_output), + "creativity_scope": self.head_1(cls_output), + "reasoning": self.head_2(cls_output), + "contextual_knowledge": self.head_3(cls_output), + "number_of_few_shots": self.head_4(cls_output), + "domain_knowledge": self.head_5(cls_output), + "no_label_reason": self.head_6(cls_output), + "constraint_ct": self.head_7(cls_output), + } + return head_outputs + + +class NvidiaComplexityRouter: + """NVIDIA 多头分类器路由封装""" + + MODEL_NAME = "nvidia/prompt-task-and-complexity-classifier" + + # Task type 映射 + TASK_TYPE_MAP = { + 0: "Brainstorming", 1: "Chatbot", 2: "Classification", + 3: "Closed QA", 4: "Code Generation", 5: "Extraction", + 6: "Open QA", 7: "Other", 8: "Rewrite", + 9: "Summarization", 10: "Text Generation", 11: "Unknown" + } + + # Domain knowledge 映射 + DOMAIN_MAP = {0: "High", 1: "Low", 2: "Medium", 3: "No"} + + # Creativity 映射 + CREATIVITY_MAP = {0: "High", 1: "Low", 2: "No"} + + def __init__(self, device: str = "cpu"): + self.device = device + self.tokenizer = None + self.model = None + self.config = None + self._initialized = False + + def initialize(self): + """延迟加载模型""" + if self._initialized: + return + + logger.info(f"Loading NVIDIA classifier: {self.MODEL_NAME}") + + # 1. 加载 config + self.config = AutoConfig.from_pretrained(self.MODEL_NAME) + + # 2. 加载 tokenizer (slow模式,兼容性好) + self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME, use_fast=False) + + # 3. 构建模型并加载权重 + self.model = NvidiaMultiHeadClassifier(self.config) + + model_path = hf_hub_download(self.MODEL_NAME, "model.safetensors") + state_dict = load_file(model_path) + self.model.load_state_dict(state_dict, strict=False) + + self.model.to(self.device) + self.model.eval() + self._initialized = True + logger.info("NVIDIA classifier loaded successfully") + + def predict(self, query: str) -> Dict: + """ + 预测查询的多维度特征 + + Returns: + { + "tier": "simple" | "medium" | "complex", + "complexity_score": float (0-1), + "task_type": str, + "domain_knowledge": str, + "reasoning": bool, + "creativity": str + } + """ + if not self._initialized: + self.initialize() + + inputs = self.tokenizer( + query, return_tensors="pt", truncation=True, max_length=512, padding=True + ) + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + with torch.no_grad(): + outputs = self.model(**inputs) + + # 解析各头输出 + task_type_idx = torch.argmax(outputs["task_type"], dim=-1).item() + task_type = self.TASK_TYPE_MAP.get(task_type_idx, "Unknown") + + domain_idx = torch.argmax(outputs["domain_knowledge"], dim=-1).item() + domain = self.DOMAIN_MAP.get(domain_idx, "Unknown") + + creativity_idx = torch.argmax(outputs["creativity_scope"], dim=-1).item() + creativity = self.CREATIVITY_MAP.get(creativity_idx, "Unknown") + + reasoning_idx = torch.argmax(outputs["reasoning"], dim=-1).item() + needs_reasoning = reasoning_idx == 1 + + # 计算综合复杂度评分 (0-1) + complexity_score = self._compute_complexity_score( + domain=domain, + creativity=creativity, + needs_reasoning=needs_reasoning, + task_type=task_type + ) + + tier = self._score_to_tier(complexity_score) + + return { + "tier": tier, + "complexity_score": complexity_score, + "task_type": task_type, + "domain_knowledge": domain, + "reasoning": needs_reasoning, + "creativity": creativity, + } + + def _compute_complexity_score(self, domain, creativity, needs_reasoning, task_type) -> float: + """ + 综合多维度计算复杂度评分 (0-1) + + 权重: + - domain_knowledge: 40% (High=1.0, Medium=0.6, Low=0.3, No=0.0) + - reasoning: 30% (Yes=1.0, No=0.0) + - creativity: 20% (High=1.0, Low=0.4, No=0.0) + - task_type: 10% (Code=0.8, QA=0.5, Chatbot=0.2, ...) + """ + domain_scores = {"High": 1.0, "Medium": 0.6, "Low": 0.3, "No": 0.0} + creativity_scores = {"High": 1.0, "Low": 0.4, "No": 0.0} + task_complexity = { + "Code Generation": 0.8, "Text Generation": 0.7, + "Summarization": 0.6, "Rewrite": 0.5, + "Open QA": 0.5, "Closed QA": 0.4, + "Classification": 0.3, "Extraction": 0.3, + "Brainstorming": 0.6, "Chatbot": 0.2, + "Other": 0.5, "Unknown": 0.5, + } + + score = ( + 0.4 * domain_scores.get(domain, 0.5) + + 0.3 * (1.0 if needs_reasoning else 0.0) + + 0.2 * creativity_scores.get(creativity, 0.5) + + 0.1 * task_complexity.get(task_type, 0.5) + ) + return round(score, 3) + + def _score_to_tier(self, score: float) -> str: + if score < 0.35: + return "simple" + elif score < 0.65: + return "medium" + else: + return "complex" + + def select_model(self, query: str) -> str: + """直接返回推荐的模型名称""" + result = self.predict(query) + model_map = { + "simple": "qwen-flash", + "medium": "qwen-plus", + "complex": "qwen-max" + } + return model_map[result["tier"]] + + def benchmark(self, queries: list) -> Dict: + """批量测试""" + import time + results = [] + for query in queries: + start = time.time() + result = self.predict(query) + elapsed = (time.time() - start) * 1000 + results.append({ + "query": query[:50], + "tier": result["tier"], + "score": result["complexity_score"], + "task": result["task_type"], + "domain": result["domain_knowledge"], + "reasoning": result["reasoning"], + "time_ms": round(elapsed, 1) + }) + + times = [r["time_ms"] for r in results] + return { + "avg_ms": round(sum(times) / len(times), 1), + "results": results + } + + +# 全局单例 +_router_instance: Optional[NvidiaComplexityRouter] = None + +def get_nvidia_router() -> NvidiaComplexityRouter: + global _router_instance + if _router_instance is None: + _router_instance = NvidiaComplexityRouter() + return _router_instance + +def select_model_by_nvidia(query: str) -> str: + return get_nvidia_router().select_model(query) + + +if __name__ == "__main__": + test_queries = [ + "你好", + "What is 2+2?", + "Explain quantum computing principles in detail", + "Write a quicksort algorithm in Python with error handling", + "Analyze this 10-page research paper and summarize the key innovations", + "Rewrite this sentence to be more concise", + "Generate a creative story about a robot", + ] + + router = NvidiaComplexityRouter() + + print("=" * 80) + print("NVIDIA Prompt Task & Complexity Classifier - 3-Tier Router Test") + print("=" * 80) + + for query in test_queries: + result = router.predict(query) + model = router.select_model(query) + print(f"\nQuery: {query}") + print(f" Tier: {result['tier']}") + print(f" Score: {result['complexity_score']}") + print(f" Task: {result['task_type']}") + print(f" Domain: {result['domain_knowledge']}") + print(f" Reasoning: {result['reasoning']}") + print(f" Creativity: {result['creativity']}") + print(f" -> Model: {model}") diff --git a/requirements.txt b/requirements.txt index a14d412..0539dd6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,7 @@ httpx>=0.25.0 python-dotenv>=1.0.0 transformers>=4.30.0 torch>=2.0.0 +# NVIDIA Multi-head Classifier for 3-tier routing +# nvidia/prompt-task-and-complexity-classifier will be loaded via transformers pytest>=7.4.0 pytest-asyncio>=0.21.0