Files
llm-compass/main.py
aszerW 59c03516e4 feat(router): 集成NVIDIA多头分类器实现3-tier智能路由
- 新增nvidia_router.py: 手动加载NVIDIA prompt-task-and-complexity-classifier模型
- DeBERTa-v3-base backbone + 8个分类头(task_type/creativity/reasoning/domain等)
- 综合多维度评分实现simple/medium/complex三级路由
- 映射: simple->qwen-flash, medium->qwen-plus, complex->qwen-max
- main.py切换到NVIDIA路由替代RouteLLM BERT二分类
- 移除LiteLLM依赖解决版本冲突,使用原生httpx调用
- 版本升级至v0.3.0
2026-04-18 01:21:31 +08:00

298 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
MVP LLM routing service.
Unified multi-provider interface built on LiteLLM.
Supports 100+ providers, including OpenAI, Anthropic, Gemini, and Ollama.
"""
import time
import tiktoken
from typing import List, Dict, Any, Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from litellm import acompletion
import litellm
from config import MODEL_CONFIG, ROUTING_THRESHOLDS, DEFAULT_ROUTING, DASHSCOPE_API_KEY
from nvidia_router import get_nvidia_router, select_model_by_nvidia
# Point LiteLLM at DashScope (Qwen) when an API key has been configured.
# NOTE(review): the source indentation was lost; both assignments are assumed
# to belong to the `if` body -- confirm against the original file.
if DASHSCOPE_API_KEY:
    # Qwen is served through an OpenAI-compatible endpoint, which has to be
    # selected explicitly via api_base.
    litellm.api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
    litellm.api_key = DASHSCOPE_API_KEY
# Module-wide NVIDIA router instance; stays None until first requested.
_nvidia_router = None


def get_router():
    """Return the shared NVIDIA router, creating it on the first call.

    The instance is obtained from ``get_nvidia_router()`` once and then
    cached in the module-level ``_nvidia_router`` variable.
    """
    global _nvidia_router
    if _nvidia_router is not None:
        return _nvidia_router
    _nvidia_router = get_nvidia_router()
    return _nvidia_router
# In-memory call history: log_call() appends one dict per completed call and
# get_stats() aggregates over it. The visible code never trims or persists it.
call_history: List[Dict[str, Any]] = []
class Message(BaseModel):
    """One chat message of a conversation."""

    # Speaker of the message; forwarded to the provider unchanged.
    role: str
    # Text of the message.
    content: str
class ChatRequest(BaseModel):
    """Request body accepted by the chat completion endpoint."""

    # Conversation to send to the model.
    messages: List[Message]
    # Optional model key; when provided, automatic routing is skipped.
    model: Optional[str] = None
    # Sampling temperature forwarded to the provider.
    temperature: float = 0.7
    # Optional completion length limit forwarded to the provider.
    max_tokens: Optional[int] = None
class ChatResponse(BaseModel):
    """Response body returned by the chat completion endpoint."""

    # Identifier of the upstream completion.
    id: str
    # Model key the request was served with.
    model: str
    # Provider name derived from the provider model string.
    provider: str
    # Generated text.
    content: str
    # Token counts: prompt_tokens, completion_tokens, total_tokens.
    usage: Dict[str, int]
    # Cost of the call in US dollars.
    cost_usd: float
    # Round-trip time of the upstream call in milliseconds.
    latency_ms: float
class StatsResponse(BaseModel):
    """Aggregated usage statistics returned by the stats endpoint."""

    # Number of recorded calls.
    total_calls: int
    # Sum of the recorded per-call costs, in US dollars.
    total_cost_usd: float
    # Mean recorded latency in milliseconds.
    avg_latency_ms: float
    # Call count per model key.
    model_distribution: Dict[str, int]
    # Call count per provider name.
    provider_distribution: Dict[str, int]
    # The most recent call records.
    recent_calls: List[Dict[str, Any]]
# FastAPI application object on which every endpoint in this module is registered.
app = FastAPI(
    version="0.3.0",
    title="LLM Router MVP",
    description="基于 LiteLLM + NVIDIA 分类器的多提供商路由服务支持3-tier智能路由",
)
def estimate_tokens(messages: List[Message]) -> int:
    """Estimate how many tokens a conversation occupies.

    The tiktoken encoding registered for "gpt-4" is used, falling back to
    "cl100k_base" when that lookup raises ``KeyError``. Every message
    contributes a fixed 4 tokens plus the encoded length of its content and
    of its role; 2 further tokens are added for the conversation.

    NOTE(review): the source indentation was lost; the final ``+ 2`` is
    assumed to be applied once after the loop rather than per message --
    confirm against the original file.

    Args:
        messages: Conversation to measure.

    Returns:
        The estimated token count.
    """
    try:
        tokenizer = tiktoken.encoding_for_model("gpt-4")
    except KeyError:
        tokenizer = tiktoken.get_encoding("cl100k_base")

    per_message = (
        4 + len(tokenizer.encode(m.content)) + len(tokenizer.encode(m.role))
        for m in messages
    )
    return sum(per_message) + 2
def select_model_by_length(messages: List[Message]) -> str:
    """Choose a model from the estimated token count (fallback strategy).

    The estimate is compared against the "simple" and then the "medium"
    entry of ``ROUTING_THRESHOLDS``; anything at or above the "medium"
    threshold is treated as "complex".

    Args:
        messages: Conversation to route.

    Returns:
        The ``DEFAULT_ROUTING`` entry for the selected tier.
    """
    n_tokens = estimate_tokens(messages)
    if n_tokens < ROUTING_THRESHOLDS["simple"]:
        tier = "simple"
    elif n_tokens < ROUTING_THRESHOLDS["medium"]:
        tier = "medium"
    else:
        tier = "complex"
    return DEFAULT_ROUTING[tier]
def select_model_by_nvidia_classifier(messages: List[Message]) -> str:
    """Choose a model with the NVIDIA multi-head classifier (3-tier routing).

    The classifier scores the query along several complexity dimensions and
    the router maps the result onto a Qwen model:
        - simple  -> qwen-flash
        - medium  -> qwen-plus
        - complex -> qwen-max

    The query is the content of the most recent message whose role is
    "user". If no message has that role, the last message is used; an empty
    conversation yields an empty query.

    Args:
        messages: Conversation to route.

    Returns:
        The model chosen by the router, or by ``select_model_by_length``
        when the classifier raises.
    """
    query = ""
    if messages:
        # Walk backwards so a trailing assistant/system message is not
        # mistaken for the user's request.
        query = next(
            (m.content for m in reversed(messages) if m.role == "user"),
            messages[-1].content,
        )
    try:
        return get_router().select_model(query)
    except Exception as e:
        # Deliberate best-effort boundary: any classifier failure degrades
        # to token-length routing instead of failing the request.
        print(f"NVIDIA routing failed: {e}, falling back to token length")
        return select_model_by_length(messages)
def get_provider_model(model_key: str) -> str:
    """Translate a model key into the model name passed to LiteLLM.

    Args:
        model_key: Key to look up in ``MODEL_CONFIG``.

    Returns:
        The "provider" value of the matching configuration entry.

    Raises:
        HTTPException: 400 when the key has no (non-empty) configuration.
    """
    entry = MODEL_CONFIG.get(model_key)
    if entry:
        return entry["provider"]
    raise HTTPException(status_code=400, detail=f"Unknown model: {model_key}")
def calculate_cost(model_key: str, input_tokens: int, output_tokens: int) -> float:
    """Compute the cost of one call from its token counts.

    Prices are read from the ``MODEL_CONFIG`` entry for ``model_key`` as
    "input_cost" and "output_cost" per 1000 tokens. Unknown keys fall back
    to the "gpt-4o" entry.

    Args:
        model_key: Key of the model that served the call.
        input_tokens: Number of prompt tokens.
        output_tokens: Number of completion tokens.

    Returns:
        Input cost plus output cost.

    Raises:
        KeyError: If ``model_key`` is unknown and no "gpt-4o" entry exists.
    """
    # Resolve the fallback lazily: passing MODEL_CONFIG["gpt-4o"] as the
    # default of dict.get() evaluates it on every call, so a config without
    # a "gpt-4o" entry would raise KeyError even for known model keys.
    if model_key in MODEL_CONFIG:
        pricing = MODEL_CONFIG[model_key]
    else:
        pricing = MODEL_CONFIG["gpt-4o"]
    input_cost = (input_tokens / 1000) * pricing["input_cost"]
    output_cost = (output_tokens / 1000) * pricing["output_cost"]
    return input_cost + output_cost
def get_provider_from_model(model_name: str) -> str:
    """Infer the provider from a model name.

    Known name prefixes are checked first, in order; otherwise the text
    before the first "/" is used as the provider.

    Args:
        model_name: Model name, optionally in "provider/model" form.

    Returns:
        The provider name, or "unknown" when nothing matches.
    """
    known_prefixes = (
        ("gpt", "openai"),
        ("claude", "anthropic"),
        ("gemini", "google"),
    )
    for prefix, vendor in known_prefixes:
        if model_name.startswith(prefix):
            return vendor
    head, slash, _ = model_name.partition("/")
    return head if slash else "unknown"
def log_call(model: str, provider: str, cost: float, latency_ms: float, tokens: int):
    """Append one call record to the module-level ``call_history`` list.

    Args:
        model: Model key that served the call.
        provider: Provider name of the call.
        cost: Cost of the call in US dollars.
        latency_ms: Latency of the call in milliseconds.
        tokens: Total number of tokens used by the call.
    """
    record: Dict[str, Any] = {
        "model": model,
        "provider": provider,
        "cost_usd": cost,
        "latency_ms": latency_ms,
        "tokens": tokens,
        # Wall-clock time at which the record was created.
        "timestamp": time.time(),
    }
    call_history.append(record)
@app.post("/v1/chat/completions", response_model=ChatResponse)
async def chat_completions(request: ChatRequest):
    """Serve a chat completion, routing to a model when none is requested.

    When ``request.model`` is set it is used as the model key; otherwise the
    key is chosen by ``select_model_by_nvidia_classifier``. The call is sent
    through LiteLLM, its cost is computed, and it is recorded via
    ``log_call``.

    Args:
        request: Chat request body.

    Returns:
        A ``ChatResponse`` with the generated text, token usage, cost and
        latency.

    Raises:
        HTTPException: 400 from ``get_provider_model`` for an unknown model
            key; 500 when the upstream call or its post-processing fails.
    """
    # Pick the model: an explicit request wins over automatic routing.
    # NOTE(review): the classifier is called synchronously inside this
    # coroutine; if it is slow it will hold up the event loop -- confirm.
    if request.model:
        model_key = request.model
    else:
        model_key = select_model_by_nvidia_classifier(request.messages)

    # Resolve the LiteLLM model name and the provider label.
    provider_model = get_provider_model(model_key)
    provider = get_provider_from_model(provider_model)

    # Monotonic clock: elapsed time must not be affected by system clock
    # adjustments.
    start_time = time.perf_counter()
    try:
        response = await acompletion(
            model=provider_model,
            messages=[{"role": m.role, "content": m.content} for m in request.messages],
            temperature=request.temperature,
            max_tokens=request.max_tokens,
        )
        latency_ms = (time.perf_counter() - start_time) * 1000

        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        cost = calculate_cost(model_key, input_tokens, output_tokens)

        log_call(model_key, provider, cost, latency_ms, input_tokens + output_tokens)

        return ChatResponse(
            id=response.id,
            model=model_key,
            provider=provider,
            # The upstream content may be None; ChatResponse.content is a
            # required str, so normalise to "" instead of failing validation.
            content=response.choices[0].message.content or "",
            usage={
                "prompt_tokens": input_tokens,
                "completion_tokens": output_tokens,
                "total_tokens": input_tokens + output_tokens,
            },
            cost_usd=cost,
            latency_ms=round(latency_ms, 2),
        )
    except Exception as e:
        # Top-level boundary: surface any failure as a 500 while keeping the
        # original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=f"API error: {str(e)}") from e
@app.get("/models")
async def list_models():
    """List every configured model with its provider name and pricing.

    Returns:
        A dict with a "models" list holding one entry per ``MODEL_CONFIG``
        item.
    """
    catalog = []
    for key, config in MODEL_CONFIG.items():
        catalog.append(
            {
                "key": key,
                "provider": config["provider"],
                "input_cost_per_1k": config["input_cost"],
                "output_cost_per_1k": config["output_cost"],
            }
        )
    return {"models": catalog}
@app.get("/stats", response_model=StatsResponse)
async def get_stats():
    """Summarise the recorded call history.

    Returns:
        A ``StatsResponse`` with the call count, total cost, mean latency,
        per-model and per-provider call counts, and the last 10 records.
        All values are zero/empty when nothing has been recorded.
    """
    total_calls = len(call_history)
    if total_calls == 0:
        return StatsResponse(
            total_calls=0,
            total_cost_usd=0.0,
            avg_latency_ms=0.0,
            model_distribution={},
            provider_distribution={},
            recent_calls=[],
        )

    total_cost = sum(entry["cost_usd"] for entry in call_history)
    latency_sum = sum(entry["latency_ms"] for entry in call_history)
    avg_latency = latency_sum / total_calls

    # Count calls per model key and per provider name.
    model_dist: Dict[str, int] = {}
    provider_dist: Dict[str, int] = {}
    for entry in call_history:
        counted = (
            (model_dist, entry["model"]),
            (provider_dist, entry["provider"]),
        )
        for tally, name in counted:
            tally[name] = tally.get(name, 0) + 1

    # Rounded view of the 10 most recent records.
    recent = []
    for entry in call_history[-10:]:
        recent.append(
            {
                "model": entry["model"],
                "provider": entry["provider"],
                "cost_usd": round(entry["cost_usd"], 6),
                "latency_ms": round(entry["latency_ms"], 2),
                "tokens": entry["tokens"],
            }
        )

    return StatsResponse(
        total_calls=total_calls,
        total_cost_usd=round(total_cost, 6),
        avg_latency_ms=round(avg_latency, 2),
        model_distribution=model_dist,
        provider_distribution=provider_dist,
        recent_calls=recent,
    )
@app.get("/health")
async def health_check():
    """Report service health together with the version and router name."""
    payload = dict(status="healthy", version="0.3.0", router="nvidia-3tier")
    return payload
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)