Files
llm-compass/main.py
aszerW 508118cc50 fix: 修复 max_tokens 为 0 或 None 时响应内容被截断的问题
问题: Swagger UI 测试时 max_tokens 默认值为 0,导致 DashScope API
      返回的响应内容只有 1 个 token(被截断)

修复:
- 非流式和流式响应中,当 max_tokens 为 None 或 ≤0 时不传给后端 API
- 让 DashScope 使用自己的默认 max_tokens 值(通常 2048/4096)
- 使用 completion_kwargs 字典动态构建请求参数

效果:
- Swagger UI 中 max_tokens 留空或设为 0 都能返回完整响应
- 需要限制输出时可手动设置合理的 max_tokens 值
2026-04-19 00:58:51 +08:00

494 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
LLM Compass - 智能LLM路由服务 (OpenAI 兼容 API)
基于 LiteLLM + NVIDIA 多头分类器的智能路由服务
支持 OpenAI Chat Completions API 格式(含流式返回)
"""
import time
import json
import uuid
import tiktoken
from typing import List, Dict, Any, Optional
from pathlib import Path
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from litellm import acompletion
import litellm
from config import MODEL_CONFIG, ROUTING_THRESHOLDS, DEFAULT_ROUTING, DASHSCOPE_API_KEY
from nvidia_router import get_nvidia_router, select_model_by_nvidia
# Point LiteLLM at DashScope (Qwen) through its OpenAI-compatible endpoint.
# Nothing is configured when no DashScope API key is available.
# NOTE(review): the original indentation was lost; both assignments are assumed
# to belong to the `if` body — confirm.
if DASHSCOPE_API_KEY:
    litellm.api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
    litellm.api_key = DASHSCOPE_API_KEY
# Module-level NVIDIA router instance, created lazily on first use.
_nvidia_router = None


def get_router():
    """Return the shared NVIDIA router, creating it on the first call."""
    global _nvidia_router
    if _nvidia_router is not None:
        return _nvidia_router
    _nvidia_router = get_nvidia_router()
    return _nvidia_router
# ── Call history: JSONL persistence ──────────────────────────
# Records are stored in <directory of this file>/data/call_history.jsonl.
# The data directory is created at import time when it does not exist yet.
CALL_LOG_DIR = Path(__file__).parent.joinpath("data")
CALL_LOG_DIR.mkdir(exist_ok=True)
CALL_LOG_FILE = CALL_LOG_DIR.joinpath("call_history.jsonl")
# In-memory list of call records.
call_history: List[Dict[str, Any]] = []
def _load_history():
    """Fill ``call_history`` from the JSONL log file when that file exists.

    Blank lines and lines that are not valid JSON are skipped.
    """
    # NOTE(review): the original indentation was lost; the summary print is
    # assumed to run only when the log file exists — confirm.
    if not CALL_LOG_FILE.exists():
        return
    with open(CALL_LOG_FILE, "r", encoding="utf-8") as log_file:
        for raw_line in log_file:
            text = raw_line.strip()
            if not text:
                continue
            try:
                record = json.loads(text)
            except json.JSONDecodeError:
                # Best effort: a corrupt line must not block loading the rest.
                continue
            call_history.append(record)
    print(f"Loaded {len(call_history)} historical records from {CALL_LOG_FILE}")


# Restore persisted records once, at import time.
_load_history()
# ── OpenAI 兼容请求/响应模型 ────────────────────────────────
from pydantic import BaseModel, Field
class ChatMessage(BaseModel):
    """One chat message in OpenAI Chat Completions format."""

    # Schema examples are passed through `json_schema_extra` (as the request
    # model in this file already does) instead of the free-form `example=`
    # keyword, which Pydantic v2 reports as deprecated. The generated schema
    # keeps the same "example" entries.
    role: str = Field(
        ...,
        description="角色system, user, assistant",
        json_schema_extra={"example": "user"},
    )
    content: Optional[str] = Field(
        None,
        description="消息内容",
        json_schema_extra={"example": "你好,介绍一下你自己"},
    )
    name: Optional[str] = Field(None, description="可选的名称")
class ChatCompletionRequest(BaseModel):
    """Request body of ``POST /v1/chat/completions`` (OpenAI-compatible).

    An empty or missing ``model`` lets the service pick the model itself.
    """

    # Schema examples go through `json_schema_extra`; the free-form `example=`
    # keyword of `Field` is deprecated in Pydantic v2.
    model: Optional[str] = Field(
        None,
        description="模型名称(留空时自动使用 NVIDIA 分类器智能路由)",
        json_schema_extra={
            "example": "qwen-plus",
            "examples": ["", "qwen-flash", "qwen-plus", "qwen-max"],
        },
    )
    messages: List[ChatMessage] = Field(
        ...,
        description="对话消息列表",
        json_schema_extra={
            "example": [{"role": "user", "content": "你好,介绍一下你自己"}],
        },
    )
    temperature: Optional[float] = Field(0.7, ge=0, le=2, description="随机性 (0-2)")
    # 0 must pass validation: the handlers treat None and 0 as "do not send
    # max_tokens, use the backend default". With ge=1 a request carrying
    # max_tokens=0 was rejected with 422 before that handling could apply.
    max_tokens: Optional[int] = Field(None, ge=0, description="最大生成 token 数")
    stream: Optional[bool] = Field(False, description="是否使用流式输出")
    top_p: Optional[float] = Field(1.0, ge=0, le=1, description="核采样参数")
    n: Optional[int] = Field(1, ge=1, le=10, description="生成回复数量")
    stop: Optional[Any] = Field(None, description="停止词")
    presence_penalty: Optional[float] = Field(0.0, ge=-2, le=2, description="存在惩罚")
    frequency_penalty: Optional[float] = Field(0.0, ge=-2, le=2, description="频率惩罚")
    user: Optional[str] = Field(None, description="用户标识")
# ── FastAPI application ──────────────────────────────────────
# The single application object that all route decorators below attach to.
app = FastAPI(
    version="0.4.0",
    title="LLM Compass",
    description="智能LLM路由服务为请求指引最优模型兼顾质量与成本NVIDIA 3-tier 分类 + LiteLLM 多提供商)",
)
# ── Helper functions ─────────────────────────────────────────
def estimate_tokens(messages: List[ChatMessage]) -> int:
    """Estimate the prompt token count of a list of chat messages.

    The count uses the tiktoken encoding registered for "gpt-4", falling back
    to ``cl100k_base`` when that lookup raises ``KeyError``. Each message
    contributes 4 overhead tokens plus the tokens of its role and (when
    non-empty) its content; 2 more tokens are added for the whole prompt.

    Args:
        messages: Messages whose ``role`` and ``content`` are counted.

    Returns:
        The estimated number of tokens.
    """
    # NOTE(review): the original indentation was lost; the role is assumed to
    # be counted for every message and the +2 to be added once — confirm.
    try:
        encoding = tiktoken.encoding_for_model("gpt-4")
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")

    def count(text: str) -> int:
        # disallowed_special=() makes tiktoken encode special-token literals
        # such as "<|endoftext|>" as ordinary text instead of raising
        # ValueError on user-supplied content.
        return len(encoding.encode(text, disallowed_special=()))

    total = 2
    for msg in messages:
        total += 4 + count(msg.role)
        if msg.content:
            total += count(msg.content)
    return total
def select_model_by_length(messages: List[ChatMessage]) -> str:
    """Pick a model key from the estimated prompt length.

    Prompts below ``ROUTING_THRESHOLDS["simple"]`` map to the "simple" entry
    of ``DEFAULT_ROUTING``, prompts below ``ROUTING_THRESHOLDS["medium"]`` to
    the "medium" entry, and everything else to the "complex" entry.
    """
    estimated = estimate_tokens(messages)
    if estimated < ROUTING_THRESHOLDS["simple"]:
        tier = "simple"
    elif estimated < ROUTING_THRESHOLDS["medium"]:
        tier = "medium"
    else:
        tier = "complex"
    return DEFAULT_ROUTING[tier]
def select_model_by_nvidia_classifier(messages: List[ChatMessage]) -> tuple:
    """Choose a model with the NVIDIA multi-head classifier.

    The most recent user message with non-empty content is used as the query
    (an empty string when there is none). If anything in the classifier path
    raises, the function falls back to ``select_model_by_length`` and reports
    the error in the routing detail.

    Args:
        messages: The conversation to route.

    Returns:
        ``(model_key, routing_detail)`` where ``routing_detail`` is a dict
        describing how the decision was made.
    """
    query = next(
        (m.content for m in reversed(messages) if m.role == "user" and m.content),
        "",
    )
    try:
        router = get_router()
        # perf_counter is monotonic, so the measured latency cannot be skewed
        # (or turn negative) by wall-clock adjustments.
        started = time.perf_counter()
        result = router.predict(query)
        routing_ms = (time.perf_counter() - started) * 1000
        model_map = {"simple": "qwen-flash", "medium": "qwen-plus", "complex": "qwen-max"}
        model_key = model_map[result["tier"]]
        routing_detail = {
            "method": "nvidia_classifier",
            "query": query,
            "routing_latency_ms": round(routing_ms, 2),
            "tier": result["tier"],
            "complexity_score": result["complexity_score"],
            "task_type": result["task_type"],
            "domain_knowledge": result["domain_knowledge"],
            "reasoning": result["reasoning"],
            "creativity": result["creativity"],
        }
        return model_key, routing_detail
    except Exception as e:
        # Deliberate catch-all: routing must never fail the request, so any
        # classifier problem degrades to the token-length heuristic.
        print(f"NVIDIA routing failed: {e}, falling back to token length")
        model_key = select_model_by_length(messages)
        routing_detail = {
            "method": "fallback_token_length",
            "query": query,
            "routing_latency_ms": 0,
            "error": str(e),
        }
        return model_key, routing_detail
def get_provider_model(model_key: str) -> str:
    """Return the ``"provider"`` entry configured for ``model_key``.

    Raises:
        HTTPException: 400 when ``model_key`` has no (non-empty) entry in
            ``MODEL_CONFIG``.
    """
    entry = MODEL_CONFIG.get(model_key)
    if entry:
        return entry["provider"]
    raise HTTPException(status_code=400, detail=f"Unknown model: {model_key}")
def calculate_cost(model_key: str, input_tokens: int, output_tokens: int) -> float:
    """Compute the cost of one call from the per-1000-token prices in MODEL_CONFIG.

    Args:
        model_key: Key into ``MODEL_CONFIG``.
        input_tokens: Number of prompt tokens.
        output_tokens: Number of completion tokens.

    Returns:
        ``input_tokens / 1000 * input_cost + output_tokens / 1000 * output_cost``.
        When ``model_key`` is not configured the "gpt-4o" prices are used if
        present; when neither exists the cost is reported as ``0.0``.
    """
    config = MODEL_CONFIG.get(model_key)
    if config is None:
        # Resolve the fallback lazily: indexing MODEL_CONFIG["gpt-4o"] as a
        # default argument would raise KeyError for every call (even for
        # known models) when the config has no "gpt-4o" entry.
        config = MODEL_CONFIG.get("gpt-4o")
    if config is None:
        return 0.0
    return (input_tokens / 1000) * config["input_cost"] + (output_tokens / 1000) * config["output_cost"]
def log_call(
    model: str,
    cost: float,
    latency_ms: float,
    input_tokens: int,
    output_tokens: int,
    messages_raw: List[Dict],
    response_content: str,
    response_id: str,
    routing_detail: Optional[Dict],
    request_params: Dict,
    stream: bool = False,
):
    """Record one LLM call in memory and append it to the JSONL log file.

    Args:
        model: Model key that served the request.
        cost: Cost of the call, stored as ``cost_usd``.
        latency_ms: LLM latency in milliseconds (stored rounded to 2 places).
        input_tokens: Prompt token count.
        output_tokens: Completion token count.
        messages_raw: Request messages as plain dicts.
        response_content: Text returned by the model.
        response_id: Identifier of the response.
        routing_detail: Routing information, or ``None``.
        request_params: Dict read for "temperature", "max_tokens" and
            "user_specified_model".
        stream: Whether the call was served as a stream.
    """
    timestamp = time.time()
    request_section = {
        "messages": messages_raw,
        "temperature": request_params.get("temperature"),
        "max_tokens": request_params.get("max_tokens"),
        "stream": stream,
        "user_specified_model": request_params.get("user_specified_model"),
    }
    llm_section = {
        "model": model,
        "response_id": response_id,
        "response_content": response_content,
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": input_tokens + output_tokens,
        "cost_usd": cost,
        "llm_latency_ms": round(latency_ms, 2),
    }
    record = {
        "timestamp": timestamp,
        "request": request_section,
        "routing": routing_detail,
        "llm": llm_section,
    }
    call_history.append(record)
    # One JSON document per line; ensure_ascii=False keeps non-ASCII text readable.
    serialized = json.dumps(record, ensure_ascii=False)
    with open(CALL_LOG_FILE, "a", encoding="utf-8") as log_file:
        log_file.write(serialized + "\n")
def build_openai_response(
    response_id: str,
    model: str,
    content: str,
    input_tokens: int,
    output_tokens: int,
    routing_detail: Optional[Dict] = None,
) -> Dict:
    """Assemble a non-streaming response in OpenAI ``chat.completion`` format.

    The result always contains a single choice with ``finish_reason`` "stop".
    A non-empty ``routing_detail`` is attached under the extension key
    "routing".
    """
    assistant_message = {"role": "assistant", "content": content}
    choice = {
        "index": 0,
        "message": assistant_message,
        "finish_reason": "stop",
    }
    usage = {
        "prompt_tokens": input_tokens,
        "completion_tokens": output_tokens,
        "total_tokens": input_tokens + output_tokens,
    }
    payload = {
        "id": response_id,
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [choice],
        "usage": usage,
    }
    # Routing details travel as a non-standard extension field.
    if routing_detail:
        payload["routing"] = routing_detail
    return payload
# ── Core API: /v1/chat/completions ───────────────────────────
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """
    OpenAI-compatible Chat Completions endpoint.

    - When ``model`` is empty, the NVIDIA classifier picks the model.
    - When ``stream`` is true, a Server-Sent Events stream is returned.

    Raises:
        HTTPException: 400 (from ``get_provider_model``) for an unknown model,
            500 when the backend call or its post-processing fails.
    """
    # 1. Routing decision
    if request.model:
        model_key = request.model
        routing_detail = {
            "method": "user_specified",
            "query": next((m.content for m in reversed(request.messages) if m.role == "user" and m.content), ""),
        }
    else:
        model_key, routing_detail = select_model_by_nvidia_classifier(request.messages)
    provider_model = get_provider_model(model_key)
    messages_raw = [{"role": m.role, "content": m.content} for m in request.messages]
    response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"

    # 2. Streaming response
    if request.stream:
        return StreamingResponse(
            _stream_response(
                provider_model=provider_model,
                model_key=model_key,
                messages_raw=messages_raw,
                request=request,
                response_id=response_id,
                routing_detail=routing_detail,
            ),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    # 3. Non-streaming response
    # perf_counter is monotonic, so latency is immune to wall-clock changes.
    start_time = time.perf_counter()
    completion_kwargs = {
        "model": provider_model,
        "messages": messages_raw,
        "temperature": request.temperature,
    }
    # max_tokens is forwarded only when positive; None or 0 lets the backend
    # apply its own default instead of truncating the answer.
    if request.max_tokens is not None and request.max_tokens > 0:
        completion_kwargs["max_tokens"] = request.max_tokens
    try:
        response = await acompletion(**completion_kwargs)
        latency_ms = (time.perf_counter() - start_time) * 1000
        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        cost = calculate_cost(model_key, input_tokens, output_tokens)
        content = response.choices[0].message.content
        log_call(
            model=model_key, cost=cost, latency_ms=latency_ms,
            input_tokens=input_tokens, output_tokens=output_tokens,
            messages_raw=messages_raw, response_content=content,
            response_id=response_id, routing_detail=routing_detail,
            request_params={
                "temperature": request.temperature,
                "max_tokens": request.max_tokens,
                "user_specified_model": request.model,
            },
        )
        return build_openai_response(response_id, model_key, content, input_tokens, output_tokens, routing_detail)
    except Exception as e:
        # Chain the cause so the original traceback stays visible in logs.
        raise HTTPException(status_code=500, detail=f"API error: {str(e)}") from e
async def _stream_response(
    provider_model: str,
    model_key: str,
    messages_raw: List[Dict],
    request: ChatCompletionRequest,
    response_id: str,
    routing_detail: Optional[Dict],
):
    """Generate the SSE stream for one completion and log the call afterwards.

    Yields ``data: {...}`` events in OpenAI ``chat.completion.chunk`` format,
    followed by ``data: [DONE]``. If the backend call or the streaming loop
    raises, a single error event and ``[DONE]`` are emitted and nothing is
    logged. Token counts recorded for streamed calls are approximations.

    Args:
        provider_model: Model identifier passed to the backend.
        model_key: Model key reported to the client and used for pricing.
        messages_raw: Request messages as plain dicts.
        request: The original request (temperature, max_tokens, model).
        response_id: Identifier placed in every chunk.
        routing_detail: Routing information stored with the log record.
    """
    # perf_counter is monotonic, so latency is immune to wall-clock changes.
    start_time = time.perf_counter()
    content_parts: List[str] = []
    output_tokens = 0
    try:
        completion_kwargs = {
            "model": provider_model,
            "messages": messages_raw,
            "temperature": request.temperature,
            "stream": True,
        }
        # max_tokens is forwarded only when positive; None or 0 lets the
        # backend apply its own default.
        if request.max_tokens is not None and request.max_tokens > 0:
            completion_kwargs["max_tokens"] = request.max_tokens
        response = await acompletion(**completion_kwargs)
        async for chunk in response:
            # Guard against chunks that carry no choices; indexing them would
            # raise IndexError and abort the stream.
            if not chunk.choices:
                continue
            choice = chunk.choices[0]
            delta = choice.delta
            # Collect content for the log record.
            if delta.content:
                content_parts.append(delta.content)
                output_tokens += 1  # approximation: one token per content chunk
            chunk_data = {
                "id": response_id,
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": model_key,
                "choices": [{
                    "index": 0,
                    "delta": {},
                    "finish_reason": None,
                }],
            }
            if delta.content:
                chunk_data["choices"][0]["delta"] = {"content": delta.content}
            elif delta.role:
                chunk_data["choices"][0]["delta"] = {"role": delta.role}
            if choice.finish_reason:
                chunk_data["choices"][0]["finish_reason"] = choice.finish_reason
            yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"
        yield "data: [DONE]\n\n"
    except Exception as e:
        error_data = {"error": {"message": str(e), "type": "api_error"}}
        yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
        yield "data: [DONE]\n\n"
        return

    # Logging happens outside the streaming try-block: the client has already
    # received [DONE], so a logging failure must not emit an error event and
    # a second [DONE].
    latency_ms = (time.perf_counter() - start_time) * 1000
    # Approximate the prompt tokens with tiktoken (4 overhead tokens per message).
    try:
        encoding = tiktoken.get_encoding("cl100k_base")
        input_tokens = sum(
            len(encoding.encode(m.get("content", "") or "", disallowed_special=()))
            for m in messages_raw
        ) + len(messages_raw) * 4
    except Exception:
        input_tokens = 0
    try:
        cost = calculate_cost(model_key, input_tokens, output_tokens)
        log_call(
            model=model_key, cost=cost, latency_ms=latency_ms,
            input_tokens=input_tokens, output_tokens=output_tokens,
            messages_raw=messages_raw, response_content="".join(content_parts),
            response_id=response_id, routing_detail=routing_detail,
            request_params={
                "temperature": request.temperature,
                "max_tokens": request.max_tokens,
                "user_specified_model": request.model,
            },
            stream=True,
        )
    except Exception as log_error:
        print(f"Failed to log streamed call {response_id}: {log_error}")
# ── OpenAI-compatible: /v1/models ────────────────────────────
@app.get("/v1/models")
async def list_models():
    """List every key of ``MODEL_CONFIG`` in OpenAI model-list format."""

    def owner_of(provider: str) -> str:
        # The part before the first "/" names the owner; "unknown" otherwise.
        return provider.split("/")[0] if "/" in provider else "unknown"

    return {
        "object": "list",
        "data": [
            {
                "id": key,
                "object": "model",
                "created": 1700000000,
                "owned_by": owner_of(config["provider"]),
            }
            for key, config in MODEL_CONFIG.items()
        ],
    }
# ── Admin endpoints ──────────────────────────────────────────
@app.get("/stats")
async def get_stats():
    """Return summary statistics over all recorded calls.

    The response contains the call count, total cost, average LLM and routing
    latency, and per-model / per-tier / per-task-type counts. With no
    recorded calls every value is zero or empty.
    """
    if not call_history:
        return {
            "total_calls": 0,
            "total_cost_usd": 0.0,
            "avg_latency_ms": 0.0,
            # Present in both branches so clients see one response shape.
            "avg_routing_ms": 0.0,
            "model_distribution": {},
            "tier_distribution": {},
            "task_type_distribution": {},
        }

    total_calls = len(call_history)
    total_cost = 0.0
    total_latency = 0.0
    total_routing = 0.0
    model_dist: Dict[str, int] = {}
    tier_dist: Dict[str, int] = {}
    task_dist: Dict[str, int] = {}
    for call in call_history:
        llm = call["llm"]
        total_cost += llm["cost_usd"]
        total_latency += llm["llm_latency_ms"]
        model_dist[llm["model"]] = model_dist.get(llm["model"], 0) + 1
        # "routing" may be stored as null; `or {}` covers both a missing key
        # and an explicit None, which .get("routing", {}) does not.
        routing = call.get("routing") or {}
        total_routing += routing.get("routing_latency_ms", 0)
        if routing.get("tier"):
            tier_dist[routing["tier"]] = tier_dist.get(routing["tier"], 0) + 1
        if routing.get("task_type"):
            task_dist[routing["task_type"]] = task_dist.get(routing["task_type"], 0) + 1

    return {
        "total_calls": total_calls,
        "total_cost_usd": round(total_cost, 6),
        "avg_latency_ms": round(total_latency / total_calls, 2),
        "avg_routing_ms": round(total_routing / total_calls, 2),
        "model_distribution": model_dist,
        "tier_distribution": tier_dist,
        "task_type_distribution": task_dist,
    }
@app.get("/stats/raw")
async def get_stats_raw(limit: int = 50, offset: int = 0):
    """Return raw call records, newest first, as one page.

    Args:
        limit: Maximum number of records to return; negative values are
            treated as 0.
        offset: Number of newest records to skip; negative values are
            treated as 0.

    Returns:
        A dict with the total record count, the effective ``limit`` and
        ``offset``, and the selected ``records``.
    """
    # Clamp first: negative values would otherwise turn into from-the-end
    # slice indices and return an unrelated part of the history.
    limit = max(limit, 0)
    offset = max(offset, 0)
    total = len(call_history)
    newest_first = call_history[::-1]
    return {
        "total": total,
        "limit": limit,
        "offset": offset,
        "records": newest_first[offset:offset + limit],
    }
@app.get("/health")
async def health_check():
    """Health probe: report a static status, version and router name."""
    payload = dict(status="healthy", version="0.4.0", router="llm-compass")
    return payload
# Script entry point: serve the app with uvicorn on all interfaces, port 8000.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")