问题: max_tokens 设置了 ge=1 约束,导致 Swagger UI 自动生成默认值 1,
响应内容被严重截断
修复:
- 移除 ge=1 约束,允许 null 值
- example 改为 2048,符合常规使用场景
- 描述更新为'留空时使用模型默认值'
效果: Swagger UI 测试时 max_tokens 默认显示 2048,可返回完整响应
494 lines
18 KiB
Python
494 lines
18 KiB
Python
"""
|
||
LLM Compass - 智能LLM路由服务 (OpenAI 兼容 API)
|
||
基于 LiteLLM + NVIDIA 多头分类器的智能路由服务
|
||
支持 OpenAI Chat Completions API 格式(含流式返回)
|
||
"""
|
||
import time
|
||
import json
|
||
import uuid
|
||
import tiktoken
|
||
from typing import List, Dict, Any, Optional
|
||
from pathlib import Path
|
||
|
||
from fastapi import FastAPI, HTTPException
|
||
from fastapi.responses import StreamingResponse
|
||
from pydantic import BaseModel, Field
|
||
from litellm import acompletion
|
||
import litellm
|
||
|
||
from config import MODEL_CONFIG, ROUTING_THRESHOLDS, DEFAULT_ROUTING, DASHSCOPE_API_KEY
|
||
from nvidia_router import get_nvidia_router, select_model_by_nvidia
|
||
|
||
# Configure LiteLLM to talk to DashScope (Qwen) through its OpenAI-compatible
# endpoint. Nothing is configured when no API key is available.
if DASHSCOPE_API_KEY:
    litellm.api_key = DASHSCOPE_API_KEY
    litellm.api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"

# NVIDIA router instance (lazily created by get_router()).
_nvidia_router = None
|
||
|
||
def get_router():
    """Return the shared NVIDIA router, building it on the first call."""
    global _nvidia_router
    if _nvidia_router is not None:
        return _nvidia_router
    _nvidia_router = get_nvidia_router()
    return _nvidia_router
|
||
|
||
|
||
# ── Call history - JSONL persistence ────────────────────────
# The log lives in a "data" directory next to this file; the directory is
# created at import time if it is missing.
CALL_LOG_DIR = Path(__file__).parent / "data"
CALL_LOG_DIR.mkdir(exist_ok=True)
CALL_LOG_FILE = CALL_LOG_DIR / "call_history.jsonl"

# In-memory list of call records: filled by _load_history() and appended to
# by log_call().
call_history: List[Dict[str, Any]] = []
|
||
|
||
def _load_history():
    """Fill ``call_history`` from the JSONL log file when that file exists.

    Blank lines and lines that are not valid JSON are skipped.
    """
    if not CALL_LOG_FILE.exists():
        return
    with open(CALL_LOG_FILE, "r", encoding="utf-8") as log_file:
        for raw_line in log_file:
            text = raw_line.strip()
            if not text:
                continue
            try:
                entry = json.loads(text)
            except json.JSONDecodeError:
                # Best effort: a corrupt line must not block startup.
                continue
            call_history.append(entry)
    print(f"Loaded {len(call_history)} historical records from {CALL_LOG_FILE}")


# Load persisted records once, at import time.
_load_history()
|
||
|
||
|
||
# ── OpenAI 兼容请求/响应模型 ────────────────────────────────
|
||
from pydantic import BaseModel, Field
|
||
class ChatMessage(BaseModel):
    """One chat message in the OpenAI Chat Completions request format."""

    # Schema examples go through ``json_schema_extra``: passing ``example=``
    # directly to ``Field`` is a deprecated extra keyword in Pydantic v2
    # (this file already relies on ``json_schema_extra`` elsewhere).
    role: str = Field(
        ...,
        description="角色:system, user, assistant",
        json_schema_extra={"example": "user"},
    )
    content: Optional[str] = Field(
        None,
        description="消息内容",
        json_schema_extra={"example": "你好,介绍一下你自己"},
    )
    name: Optional[str] = Field(None, description="可选的名称")
|
||
|
||
class ChatCompletionRequest(BaseModel):
    """Request body for the OpenAI-compatible chat completions endpoint.

    Schema examples are supplied through ``json_schema_extra`` only. In
    Pydantic v2 an ``example=`` keyword on ``Field`` is a deprecated extra
    argument, and it is discarded when ``json_schema_extra`` is also given,
    which is what happened to the ``model`` field's single example.
    """

    # Empty/None model means "route automatically".
    model: Optional[str] = Field(
        None,
        description="模型名称(留空时自动使用 NVIDIA 分类器智能路由)",
        json_schema_extra={
            "example": "qwen-plus",
            "examples": ["", "qwen-flash", "qwen-plus", "qwen-max"],
        },
    )
    messages: List[ChatMessage] = Field(
        ...,
        description="对话消息列表",
        json_schema_extra={
            "example": [{"role": "user", "content": "你好,介绍一下你自己"}],
        },
    )
    temperature: Optional[float] = Field(0.7, ge=0, le=2, description="随机性 (0-2)")
    # Deliberately unconstrained and nullable: None means "do not send
    # max_tokens upstream", so the model's own default applies.
    max_tokens: Optional[int] = Field(
        None,
        description="最大生成 token 数(留空时使用模型默认值)",
        json_schema_extra={"example": 2048},
    )
    stream: Optional[bool] = Field(False, description="是否使用流式输出")
    top_p: Optional[float] = Field(1.0, ge=0, le=1, description="核采样参数")
    n: Optional[int] = Field(1, ge=1, le=10, description="生成回复数量")
    stop: Optional[Any] = Field(None, description="停止词")
    presence_penalty: Optional[float] = Field(0.0, ge=-2, le=2, description="存在惩罚")
    frequency_penalty: Optional[float] = Field(0.0, ge=-2, le=2, description="频率惩罚")
    user: Optional[str] = Field(None, description="用户标识")
|
||
|
||
|
||
# ── FastAPI App ──────────────────────────────────────────────
|
||
# FastAPI application instance; the route decorators in this module register
# their handlers on it.
app = FastAPI(
    title="LLM Compass",
    description="智能LLM路由服务,为请求指引最优模型,兼顾质量与成本(NVIDIA 3-tier 分类 + LiteLLM 多提供商)",
    version="0.4.0",
)
|
||
|
||
|
||
# ── 辅助函数 ─────────────────────────────────────────────────
|
||
def estimate_tokens(messages: List[ChatMessage]) -> int:
    """Estimate the prompt token count of a message list with tiktoken.

    Each message contributes a fixed overhead of 4 plus the encoded length
    of its content (when non-empty) and of its role; 2 more are added once
    for the whole list.
    """
    try:
        tokenizer = tiktoken.encoding_for_model("gpt-4")
    except KeyError:
        tokenizer = tiktoken.get_encoding("cl100k_base")

    def _message_tokens(message) -> int:
        body = len(tokenizer.encode(message.content)) if message.content else 0
        return 4 + body + len(tokenizer.encode(message.role))

    return sum(_message_tokens(message) for message in messages) + 2
|
||
|
||
|
||
def select_model_by_length(messages: List[ChatMessage]) -> str:
    """Pick a model from ``DEFAULT_ROUTING`` based on estimated prompt size.

    The estimate is compared against the "simple" and then the "medium"
    entry of ``ROUTING_THRESHOLDS``; anything at or above "medium" is
    treated as "complex".
    """
    prompt_tokens = estimate_tokens(messages)
    if prompt_tokens < ROUTING_THRESHOLDS["simple"]:
        tier = "simple"
    elif prompt_tokens < ROUTING_THRESHOLDS["medium"]:
        tier = "medium"
    else:
        tier = "complex"
    return DEFAULT_ROUTING[tier]
|
||
|
||
|
||
def select_model_by_nvidia_classifier(messages: List[ChatMessage]) -> tuple:
    """Choose a model with the NVIDIA multi-head classifier.

    The most recent non-empty user message is classified. If anything in
    the classifier path raises, the token-length heuristic is used instead
    and the error text is recorded in the routing detail.

    Returns:
        (model_key, routing_detail)
    """
    # Latest user message with content, or "" when there is none.
    query = next(
        (m.content for m in reversed(messages) if m.role == "user" and m.content),
        "",
    )

    try:
        classifier = get_router()
        started = time.time()
        prediction = classifier.predict(query)
        elapsed_ms = (time.time() - started) * 1000

        tier_to_model = {"simple": "qwen-flash", "medium": "qwen-plus", "complex": "qwen-max"}
        chosen = tier_to_model[prediction["tier"]]

        detail = {
            "method": "nvidia_classifier",
            "query": query,
            "routing_latency_ms": round(elapsed_ms, 2),
        }
        # Copy the classifier outputs through unchanged, in this order.
        for field in (
            "tier",
            "complexity_score",
            "task_type",
            "domain_knowledge",
            "reasoning",
            "creativity",
        ):
            detail[field] = prediction[field]
        return chosen, detail
    except Exception as e:
        print(f"NVIDIA routing failed: {e}, falling back to token length")
        fallback_detail = {
            "method": "fallback_token_length",
            "query": query,
            "routing_latency_ms": 0,
            "error": str(e),
        }
        return select_model_by_length(messages), fallback_detail
|
||
|
||
|
||
def get_provider_model(model_key: str) -> str:
    """Return the ``provider`` entry configured for ``model_key``.

    Raises:
        HTTPException: 400 when ``model_key`` has no (non-empty) entry in
            ``MODEL_CONFIG``.
    """
    entry = MODEL_CONFIG.get(model_key)
    if entry:
        return entry["provider"]
    raise HTTPException(status_code=400, detail=f"Unknown model: {model_key}")
|
||
|
||
|
||
def calculate_cost(model_key: str, input_tokens: int, output_tokens: int) -> float:
    """Compute the cost of a call from per-1000-token prices in ``MODEL_CONFIG``.

    Pricing is taken from the entry for ``model_key``; if there is none, the
    "gpt-4o" entry is used as a fallback. The fallback is looked up lazily:
    an eager ``MODEL_CONFIG["gpt-4o"]`` default would raise ``KeyError`` on
    every call whenever that entry is not configured, even for known models.

    Returns:
        The cost as a float, or 0.0 when neither entry exists.
    """
    config = MODEL_CONFIG.get(model_key)
    if config is None:
        config = MODEL_CONFIG.get("gpt-4o")
    if config is None:
        # No pricing information at all: report zero instead of failing the
        # request after the completion has already been produced.
        return 0.0
    return (input_tokens / 1000) * config["input_cost"] + (output_tokens / 1000) * config["output_cost"]
|
||
|
||
|
||
def log_call(
    model: str,
    cost: float,
    latency_ms: float,
    input_tokens: int,
    output_tokens: int,
    messages_raw: List[Dict],
    response_content: str,
    response_id: str,
    routing_detail: Optional[Dict],
    request_params: Dict,
    stream: bool = False,
):
    """Record one completed call in memory and append it to the JSONL log.

    The record has four sections: a timestamp, the request (messages plus
    the "temperature", "max_tokens" and "user_specified_model" entries of
    ``request_params``), the routing detail as given, and the LLM result
    (model, ids, content, token counts, cost, rounded latency).
    """
    timestamp = time.time()
    request_section = {
        "messages": messages_raw,
        "temperature": request_params.get("temperature"),
        "max_tokens": request_params.get("max_tokens"),
        "stream": stream,
        "user_specified_model": request_params.get("user_specified_model"),
    }
    llm_section = {
        "model": model,
        "response_id": response_id,
        "response_content": response_content,
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": input_tokens + output_tokens,
        "cost_usd": cost,
        "llm_latency_ms": round(latency_ms, 2),
    }
    entry = {
        "timestamp": timestamp,
        "request": request_section,
        "routing": routing_detail,
        "llm": llm_section,
    }

    call_history.append(entry)
    with open(CALL_LOG_FILE, "a", encoding="utf-8") as sink:
        sink.write(f"{json.dumps(entry, ensure_ascii=False)}\n")
|
||
|
||
|
||
def build_openai_response(
    response_id: str,
    model: str,
    content: str,
    input_tokens: int,
    output_tokens: int,
    routing_detail: Optional[Dict] = None,
    finish_reason: str = "stop",
) -> Dict:
    """Build a non-streaming response dict in OpenAI chat.completion format.

    Args:
        response_id: Value for the "id" field.
        model: Value for the "model" field.
        content: Assistant message text.
        input_tokens: Reported as usage.prompt_tokens.
        output_tokens: Reported as usage.completion_tokens.
        routing_detail: When non-empty, attached under the extension key
            "routing".
        finish_reason: Reason reported on the single choice. Defaults to
            "stop" (the previously hard-coded value) so existing callers are
            unaffected; pass e.g. "length" for a truncated completion.

    Returns:
        The response dict with one choice at index 0.
    """
    resp = {
        "id": response_id,
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": content},
            "finish_reason": finish_reason,
        }],
        "usage": {
            "prompt_tokens": input_tokens,
            "completion_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
        },
    }
    # Routing detail is a non-standard extension field.
    if routing_detail:
        resp["routing"] = routing_detail
    return resp
|
||
|
||
|
||
# ── 核心 API: /v1/chat/completions ──────────────────────────
|
||
@app.post("/v1/chat/completions")
|
||
async def chat_completions(request: ChatCompletionRequest):
|
||
"""
|
||
OpenAI 兼容的 Chat Completions API
|
||
- model 为空时自动使用 NVIDIA 分类器路由
|
||
- stream=true 时返回 SSE 流式响应
|
||
"""
|
||
# 1. 路由决策
|
||
routing_detail = None
|
||
if request.model:
|
||
model_key = request.model
|
||
routing_detail = {
|
||
"method": "user_specified",
|
||
"query": next((m.content for m in reversed(request.messages) if m.role == "user" and m.content), ""),
|
||
}
|
||
else:
|
||
model_key, routing_detail = select_model_by_nvidia_classifier(request.messages)
|
||
|
||
provider_model = get_provider_model(model_key)
|
||
messages_raw = [{"role": m.role, "content": m.content} for m in request.messages]
|
||
response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
|
||
|
||
# 2. 流式响应
|
||
if request.stream:
|
||
return StreamingResponse(
|
||
_stream_response(
|
||
provider_model=provider_model,
|
||
model_key=model_key,
|
||
messages_raw=messages_raw,
|
||
request=request,
|
||
response_id=response_id,
|
||
routing_detail=routing_detail,
|
||
),
|
||
media_type="text/event-stream",
|
||
headers={
|
||
"Cache-Control": "no-cache",
|
||
"Connection": "keep-alive",
|
||
"X-Accel-Buffering": "no",
|
||
},
|
||
)
|
||
|
||
# 3. 非流式响应
|
||
start_time = time.time()
|
||
|
||
# 构建请求参数(过滤掉 None 和 0 的 max_tokens)
|
||
completion_kwargs = {
|
||
"model": provider_model,
|
||
"messages": messages_raw,
|
||
"temperature": request.temperature,
|
||
}
|
||
if request.max_tokens and request.max_tokens > 0:
|
||
completion_kwargs["max_tokens"] = request.max_tokens
|
||
|
||
try:
|
||
response = await acompletion(**completion_kwargs)
|
||
latency_ms = (time.time() - start_time) * 1000
|
||
|
||
input_tokens = response.usage.prompt_tokens
|
||
output_tokens = response.usage.completion_tokens
|
||
cost = calculate_cost(model_key, input_tokens, output_tokens)
|
||
content = response.choices[0].message.content
|
||
|
||
log_call(
|
||
model=model_key, cost=cost, latency_ms=latency_ms,
|
||
input_tokens=input_tokens, output_tokens=output_tokens,
|
||
messages_raw=messages_raw, response_content=content,
|
||
response_id=response_id, routing_detail=routing_detail,
|
||
request_params={
|
||
"temperature": request.temperature,
|
||
"max_tokens": request.max_tokens,
|
||
"user_specified_model": request.model,
|
||
},
|
||
)
|
||
|
||
return build_openai_response(response_id, model_key, content, input_tokens, output_tokens, routing_detail)
|
||
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=f"API error: {str(e)}")
|
||
|
||
|
||
async def _stream_response(
    provider_model: str,
    model_key: str,
    messages_raw: List[Dict],
    request: ChatCompletionRequest,
    response_id: str,
    routing_detail: Optional[Dict],
):
    """Yield the completion as OpenAI-style SSE events, then log the call.

    Each upstream chunk is re-emitted as a ``chat.completion.chunk`` event,
    followed by a final ``data: [DONE]``. If the upstream call or the
    iteration fails, one error event and ``[DONE]`` are sent instead.
    Bookkeeping (token estimate, cost, log record) runs only after a
    successful stream, and a failure there is printed rather than sent to
    the client, because the stream has already been terminated.
    """
    start_time = time.time()
    content_parts: List[str] = []
    input_tokens = 0
    output_tokens = 0

    try:
        # Build the upstream arguments; max_tokens is sent only when it is
        # a positive number.
        completion_kwargs = {
            "model": provider_model,
            "messages": messages_raw,
            "temperature": request.temperature,
            "stream": True,
        }
        if request.max_tokens and request.max_tokens > 0:
            completion_kwargs["max_tokens"] = request.max_tokens

        response = await acompletion(**completion_kwargs)

        async for chunk in response:
            # Some chunks (e.g. usage-only ones) carry no choices; indexing
            # them would raise IndexError and abort the stream.
            if not chunk.choices:
                continue
            choice = chunk.choices[0]
            delta = choice.delta

            # Collect content for the log record.
            if delta.content:
                content_parts.append(delta.content)
                output_tokens += 1  # approximate: one per content chunk

            # Build the SSE payload.
            chunk_data = {
                "id": response_id,
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": model_key,
                "choices": [{
                    "index": 0,
                    "delta": {},
                    "finish_reason": None,
                }],
            }

            if delta.content:
                chunk_data["choices"][0]["delta"] = {"content": delta.content}
            elif delta.role:
                chunk_data["choices"][0]["delta"] = {"role": delta.role}

            if choice.finish_reason:
                chunk_data["choices"][0]["finish_reason"] = choice.finish_reason

            yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"

        # Send [DONE]
        yield "data: [DONE]\n\n"

    except Exception as e:
        error_data = {"error": {"message": str(e), "type": "api_error"}}
        yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
        yield "data: [DONE]\n\n"
        return

    # Bookkeeping happens outside the streaming try block: the client has
    # already received [DONE], so nothing more may be written to the stream.
    try:
        latency_ms = (time.time() - start_time) * 1000
        # In streaming mode input_tokens is approximated with tiktoken.
        try:
            encoding = tiktoken.get_encoding("cl100k_base")
            input_tokens = sum(len(encoding.encode(m.get("content", "") or "")) for m in messages_raw) + len(messages_raw) * 4
        except Exception:
            input_tokens = 0

        cost = calculate_cost(model_key, input_tokens, output_tokens)
        log_call(
            model=model_key, cost=cost, latency_ms=latency_ms,
            input_tokens=input_tokens, output_tokens=output_tokens,
            messages_raw=messages_raw, response_content="".join(content_parts),
            response_id=response_id, routing_detail=routing_detail,
            request_params={
                "temperature": request.temperature,
                "max_tokens": request.max_tokens,
                "user_specified_model": request.model,
            },
            stream=True,
        )
    except Exception as e:
        print(f"Failed to log streaming call {response_id}: {e}")
|
||
|
||
|
||
# ── OpenAI 兼容: /v1/models ─────────────────────────────────
|
||
@app.get("/v1/models")
|
||
async def list_models():
|
||
"""OpenAI 兼容的模型列表接口"""
|
||
data = []
|
||
for key, config in MODEL_CONFIG.items():
|
||
data.append({
|
||
"id": key,
|
||
"object": "model",
|
||
"created": 1700000000,
|
||
"owned_by": config["provider"].split("/")[0] if "/" in config["provider"] else "unknown",
|
||
})
|
||
return {"object": "list", "data": data}
|
||
|
||
|
||
# ── 管理接口 ─────────────────────────────────────────────────
|
||
@app.get("/stats")
|
||
async def get_stats():
|
||
"""获取调用统计摘要"""
|
||
if not call_history:
|
||
return {
|
||
"total_calls": 0, "total_cost_usd": 0.0,
|
||
"avg_latency_ms": 0.0, "model_distribution": {},
|
||
"tier_distribution": {}, "task_type_distribution": {},
|
||
}
|
||
|
||
total_calls = len(call_history)
|
||
total_cost = sum(c["llm"]["cost_usd"] for c in call_history)
|
||
avg_latency = sum(c["llm"]["llm_latency_ms"] for c in call_history) / total_calls
|
||
|
||
model_dist: Dict[str, int] = {}
|
||
tier_dist: Dict[str, int] = {}
|
||
task_dist: Dict[str, int] = {}
|
||
|
||
for call in call_history:
|
||
model = call["llm"]["model"]
|
||
model_dist[model] = model_dist.get(model, 0) + 1
|
||
routing = call.get("routing") or {}
|
||
if routing.get("tier"):
|
||
tier_dist[routing["tier"]] = tier_dist.get(routing["tier"], 0) + 1
|
||
if routing.get("task_type"):
|
||
task_dist[routing["task_type"]] = task_dist.get(routing["task_type"], 0) + 1
|
||
|
||
return {
|
||
"total_calls": total_calls,
|
||
"total_cost_usd": round(total_cost, 6),
|
||
"avg_latency_ms": round(avg_latency, 2),
|
||
"avg_routing_ms": round(
|
||
sum(c.get("routing", {}).get("routing_latency_ms", 0) for c in call_history) / total_calls, 2
|
||
),
|
||
"model_distribution": model_dist,
|
||
"tier_distribution": tier_dist,
|
||
"task_type_distribution": task_dist,
|
||
}
|
||
|
||
|
||
@app.get("/stats/raw")
|
||
async def get_stats_raw(limit: int = 50, offset: int = 0):
|
||
"""获取原始调用记录"""
|
||
total = len(call_history)
|
||
records = list(reversed(call_history))
|
||
return {"total": total, "limit": limit, "offset": offset, "records": records[offset:offset + limit]}
|
||
|
||
|
||
@app.get("/health")
|
||
async def health_check():
|
||
"""健康检查"""
|
||
return {"status": "healthy", "version": "0.4.0", "router": "llm-compass"}
|
||
|
||
|
||
if __name__ == "__main__":
|
||
import uvicorn
|
||
uvicorn.run(app, host="0.0.0.0", port=8000)
|