- 新增nvidia_router.py: 手动加载NVIDIA prompt-task-and-complexity-classifier模型 - DeBERTa-v3-base backbone + 8个分类头(task_type/creativity/reasoning/domain等) - 综合多维度评分实现simple/medium/complex三级路由 - 映射: simple->qwen-flash, medium->qwen-plus, complex->qwen-max - main.py切换到NVIDIA路由替代RouteLLM BERT二分类 - 移除LiteLLM依赖解决版本冲突,使用原生httpx调用 - 版本升级至v0.3.0
298 lines
8.7 KiB
Python
298 lines
8.7 KiB
Python
"""
|
||
MVP版 LLM 路由服务
|
||
基于 LiteLLM 的多提供商统一接口
|
||
支持: OpenAI, Anthropic, Gemini, Ollama 等 100+ 提供商
|
||
"""
|
||
import time
|
||
import tiktoken
|
||
from typing import List, Dict, Any, Optional
|
||
|
||
from fastapi import FastAPI, HTTPException
|
||
from pydantic import BaseModel
|
||
from litellm import acompletion
|
||
import litellm
|
||
|
||
from config import MODEL_CONFIG, ROUTING_THRESHOLDS, DEFAULT_ROUTING, DASHSCOPE_API_KEY
|
||
from nvidia_router import get_nvidia_router, select_model_by_nvidia
|
||
|
||
# Point LiteLLM at DashScope (Qwen) when an API key has been configured.
if DASHSCOPE_API_KEY:
    # Qwen is reached through an OpenAI-compatible endpoint, which has to be
    # selected explicitly via api_base.
    litellm.api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
    litellm.api_key = DASHSCOPE_API_KEY

# NVIDIA router instance; stays None until get_router() creates it on first use.
_nvidia_router = None
|
||
|
||
def get_router():
    """Return the shared NVIDIA router, creating it on the first call.

    The instance is cached in the module-level ``_nvidia_router`` so that
    ``get_nvidia_router()`` is invoked at most once per process (as long as
    it returns a non-None value).
    """
    global _nvidia_router
    if _nvidia_router is not None:
        return _nvidia_router
    _nvidia_router = get_nvidia_router()
    return _nvidia_router
|
||
|
||
|
||
# In-memory record of completed calls; appended to by log_call() and read by
# the /stats endpoint. Grows for the lifetime of the process.
call_history: List[Dict[str, Any]] = []
|
||
|
||
|
||
class Message(BaseModel):
    # One chat message. Both fields are forwarded unchanged to the provider.

    # Role label of the message author.
    role: str
    # Text body of the message.
    content: str
|
||
|
||
|
||
class ChatRequest(BaseModel):
    # Request payload for the chat completion endpoint.

    # Conversation to send to the model.
    messages: List[Message]
    # Optional model key; when set, automatic routing is skipped.
    model: Optional[str] = None
    # Sampling temperature passed through to the provider.
    temperature: float = 0.7
    # Optional completion-length cap passed through to the provider.
    max_tokens: Optional[int] = None
|
||
|
||
|
||
class ChatResponse(BaseModel):
    # Response payload returned by the chat completion endpoint.

    # Identifier of the upstream completion.
    id: str
    # Model key that served the request (routed or caller-specified).
    model: str
    # Provider name derived from the provider-qualified model string.
    provider: str
    # Text of the first completion choice.
    content: str
    # Token counts: prompt_tokens, completion_tokens, total_tokens.
    usage: Dict[str, int]
    # Cost of the call as computed from the configured per-1k-token prices.
    cost_usd: float
    # Time spent waiting for the upstream call, in milliseconds.
    latency_ms: float
|
||
|
||
|
||
class StatsResponse(BaseModel):
    # Aggregate view over the in-memory call history.

    # Number of recorded calls.
    total_calls: int
    # Sum of the recorded per-call costs.
    total_cost_usd: float
    # Mean latency across recorded calls, in milliseconds.
    avg_latency_ms: float
    # Call count per model key.
    model_distribution: Dict[str, int]
    # Call count per provider name.
    provider_distribution: Dict[str, int]
    # Summaries of the most recent calls (up to 10).
    recent_calls: List[Dict[str, Any]]
|
||
|
||
|
||
# FastAPI application object; the route decorators below register on it.
app = FastAPI(
    version="0.3.0",
    title="LLM Router MVP",
    description="基于 LiteLLM + NVIDIA 分类器的多提供商路由服务(支持3-tier智能路由)",
)
|
||
|
||
|
||
def estimate_tokens(messages: List[Message]) -> int:
    """Estimate how many tokens a list of chat messages occupies.

    Uses the tiktoken encoding for "gpt-4", falling back to "cl100k_base"
    when that model name is not known to tiktoken. Each message contributes
    a fixed overhead of 4 plus the encoded lengths of its content and role.

    Args:
        messages: Messages whose ``content`` and ``role`` are counted.

    Returns:
        The estimated token count.
    """
    try:
        tokenizer = tiktoken.encoding_for_model("gpt-4")
    except KeyError:
        tokenizer = tiktoken.get_encoding("cl100k_base")

    per_message = sum(
        4 + len(tokenizer.encode(item.content)) + len(tokenizer.encode(item.role))
        for item in messages
    )
    # NOTE(review): the trailing +2 is applied once for the whole conversation,
    # not per message — confirm against the original indentation.
    return per_message + 2
|
||
|
||
|
||
def select_model_by_length(messages: List[Message]) -> str:
    """Pick a model tier from the estimated token count (fallback strategy).

    Args:
        messages: Conversation whose size decides the tier.

    Returns:
        ``DEFAULT_ROUTING["simple"]`` below the "simple" threshold,
        ``DEFAULT_ROUTING["medium"]`` below the "medium" threshold,
        otherwise ``DEFAULT_ROUTING["complex"]``.
    """
    size = estimate_tokens(messages)

    # Guard clauses, smallest tier first; thresholds are exclusive upper bounds.
    if size < ROUTING_THRESHOLDS["simple"]:
        return DEFAULT_ROUTING["simple"]
    if size < ROUTING_THRESHOLDS["medium"]:
        return DEFAULT_ROUTING["medium"]
    return DEFAULT_ROUTING["complex"]
|
||
|
||
|
||
def select_model_by_nvidia_classifier(messages: List[Message]) -> str:
    """Pick a model by running the NVIDIA classifier on the latest user message.

    The query handed to the router is the content of the most recent message
    whose role is "user". If the conversation contains no user message, the
    last message of any role is used; an empty conversation yields "".

    The router is expected to map its complexity assessment onto the three
    tiers (simple / medium / complex) and return the matching model key.

    Args:
        messages: Conversation to route.

    Returns:
        The model key chosen by the router, or the token-length based choice
        when the router raises.
    """
    # Classify what the user actually asked: a trailing assistant or system
    # message must not replace the user's query as classifier input.
    fallback_query = messages[-1].content if messages else ""
    query = next(
        (m.content for m in reversed(messages) if m.role == "user"),
        fallback_query,
    )

    try:
        return get_router().select_model(query)
    except Exception as e:
        # Deliberate best-effort: any classifier failure (model load, inference)
        # degrades to the token-length strategy instead of failing the request.
        print(f"NVIDIA routing failed: {e}, falling back to token length")
        return select_model_by_length(messages)
|
||
|
||
|
||
def get_provider_model(model_key: str) -> str:
    """Resolve a model key to the model string passed to LiteLLM.

    Args:
        model_key: Key into ``MODEL_CONFIG``.

    Returns:
        The ``"provider"`` value of the matching config entry.

    Raises:
        HTTPException: 400 when the key has no (non-empty) config entry.
    """
    entry = MODEL_CONFIG.get(model_key)
    if entry:
        return entry["provider"]
    raise HTTPException(status_code=400, detail=f"Unknown model: {model_key}")
|
||
|
||
|
||
def calculate_cost(model_key: str, input_tokens: int, output_tokens: int) -> float:
    """Compute the cost of one call from the configured per-1k-token prices.

    Pricing is taken from ``MODEL_CONFIG[model_key]``. For an unknown model
    the "gpt-4o" entry is used as a fallback when it exists; if neither entry
    is configured the cost is reported as 0.0.

    Args:
        model_key: Key into ``MODEL_CONFIG``.
        input_tokens: Number of prompt tokens.
        output_tokens: Number of completion tokens.

    Returns:
        Input cost plus output cost.
    """
    config = MODEL_CONFIG.get(model_key)
    if config is None:
        # Look the fallback up lazily: a dict.get() default is evaluated
        # eagerly, so indexing MODEL_CONFIG["gpt-4o"] there raised KeyError
        # for every call whenever "gpt-4o" was absent from the config.
        config = MODEL_CONFIG.get("gpt-4o")
    if config is None:
        return 0.0

    input_cost = (input_tokens / 1000) * config["input_cost"]
    output_cost = (output_tokens / 1000) * config["output_cost"]
    return input_cost + output_cost
|
||
|
||
|
||
def get_provider_from_model(model_name: str) -> str:
    """Infer the provider name from a model string.

    Known name prefixes are checked first, in order; otherwise the text
    before the first "/" is taken as the provider.

    Args:
        model_name: Model string, optionally in "provider/model" form.

    Returns:
        "openai", "anthropic" or "google" for the known prefixes, the part
        before the first "/" when one is present, else "unknown".
    """
    known_prefixes = (
        ("gpt", "openai"),
        ("claude", "anthropic"),
        ("gemini", "google"),
    )
    for prefix, vendor in known_prefixes:
        if model_name.startswith(prefix):
            return vendor

    # partition() yields an empty separator when no "/" is present.
    namespace, slash, _ = model_name.partition("/")
    return namespace if slash else "unknown"
|
||
|
||
|
||
def log_call(model: str, provider: str, cost: float, latency_ms: float, tokens: int):
    """Append one call record to the in-memory history.

    Args:
        model: Model key that served the call.
        provider: Provider name for the call.
        cost: Cost of the call, stored under "cost_usd".
        latency_ms: Latency of the call in milliseconds.
        tokens: Token count for the call.
    """
    record = dict(
        model=model,
        provider=provider,
        cost_usd=cost,
        latency_ms=latency_ms,
        tokens=tokens,
        timestamp=time.time(),  # Wall-clock time at which the record is made.
    )
    call_history.append(record)
|
||
|
||
|
||
@app.post("/v1/chat/completions", response_model=ChatResponse)
async def chat_completions(request: ChatRequest):
    """
    Chat completion endpoint.

    When ``request.model`` is not given, the model is chosen by the NVIDIA
    classifier router (which itself falls back to token-length routing on
    failure). The request is then sent through LiteLLM, and the call's cost,
    latency and token usage are recorded in the call history.

    Raises:
        HTTPException: 400 for an unknown model key; 500 when the upstream
            call or the processing of its response fails.
    """
    # Choose the model: explicit key wins, otherwise route automatically.
    if request.model:
        model_key = request.model
    else:
        model_key = select_model_by_nvidia_classifier(request.messages)

    # Resolve the LiteLLM model string and the provider label.
    provider_model = get_provider_model(model_key)
    provider = get_provider_from_model(provider_model)

    # perf_counter() is monotonic, so the measured latency cannot be skewed
    # by wall-clock adjustments the way time.time() can.
    start_time = time.perf_counter()

    try:
        # Single unified call through LiteLLM.
        response = await acompletion(
            model=provider_model,
            messages=[{"role": m.role, "content": m.content} for m in request.messages],
            temperature=request.temperature,
            max_tokens=request.max_tokens,
        )

        latency_ms = (time.perf_counter() - start_time) * 1000

        # Cost is derived from the token usage reported by the provider.
        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        cost = calculate_cost(model_key, input_tokens, output_tokens)

        log_call(model_key, provider, cost, latency_ms, input_tokens + output_tokens)

        # A completion may carry no text content (None); ChatResponse.content
        # is a required str, so normalise to an empty string.
        content = response.choices[0].message.content or ""

        return ChatResponse(
            id=response.id,
            model=model_key,
            provider=provider,
            content=content,
            usage={
                "prompt_tokens": input_tokens,
                "completion_tokens": output_tokens,
                "total_tokens": input_tokens + output_tokens,
            },
            cost_usd=cost,
            latency_ms=round(latency_ms, 2),
        )

    except Exception as e:
        # Boundary handler: surface any failure as a 500 while keeping the
        # original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=f"API error: {str(e)}") from e
|
||
|
||
|
||
@app.get("/models")
async def list_models():
    """List the configured models with their provider string and prices."""
    catalog = []
    for name, spec in MODEL_CONFIG.items():
        catalog.append(
            {
                "key": name,
                "provider": spec["provider"],
                "input_cost_per_1k": spec["input_cost"],
                "output_cost_per_1k": spec["output_cost"],
            }
        )
    return {"models": catalog}
|
||
|
||
|
||
@app.get("/stats", response_model=StatsResponse)
async def get_stats():
    """Return aggregate statistics over the recorded call history.

    Reports the call count, total cost, mean latency, per-model and
    per-provider call counts, and summaries of the 10 most recent calls.
    With no recorded calls, every figure is zero or empty.
    """
    total_calls = len(call_history)
    if total_calls == 0:
        return StatsResponse(
            total_calls=0,
            total_cost_usd=0.0,
            avg_latency_ms=0.0,
            model_distribution={},
            provider_distribution={},
            recent_calls=[],
        )

    def _tally(field: str) -> Dict[str, int]:
        # Count how many history entries share each value of `field`.
        counts: Dict[str, int] = {}
        for entry in call_history:
            label = entry[field]
            counts[label] = counts.get(label, 0) + 1
        return counts

    cost_sum = sum(entry["cost_usd"] for entry in call_history)
    latency_sum = sum(entry["latency_ms"] for entry in call_history)

    # Summaries of the last 10 entries (timestamps are omitted).
    tail = []
    for entry in call_history[-10:]:
        tail.append(
            {
                "model": entry["model"],
                "provider": entry["provider"],
                "cost_usd": round(entry["cost_usd"], 6),
                "latency_ms": round(entry["latency_ms"], 2),
                "tokens": entry["tokens"],
            }
        )

    return StatsResponse(
        total_calls=total_calls,
        total_cost_usd=round(cost_sum, 6),
        avg_latency_ms=round(latency_sum / total_calls, 2),
        model_distribution=_tally("model"),
        provider_distribution=_tally("provider"),
        recent_calls=tail,
    )
|
||
|
||
|
||
@app.get("/health")
async def health_check():
    """Health check: report a static status, service version and router name."""
    payload = {
        "status": "healthy",
        "version": "0.3.0",
        "router": "nvidia-3tier",
    }
    return payload
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the app with uvicorn when this file is executed as a script.
    # Listens on all interfaces, port 8000.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)
|