Files
llm-compass/main.py
aszerW 1705426eef feat(api): 重写为OpenAI兼容API并支持流式SSE返回
- 请求/响应完全对齐OpenAI Chat Completions API格式
- 支持 stream=true SSE流式返回 (data: {...}\n\n + [DONE])
- 新增 /v1/models 接口 (OpenAI格式 object:list)
- 非流式响应扩展 routing 字段暴露路由决策细节
- OpenAI Python SDK可直接对接 (base_url=http://localhost:8000/v1)
- 版本升级至v0.4.0
2026-04-18 08:56:12 +08:00

475 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
LLM Router 服务 - OpenAI 兼容 API
基于 LiteLLM + NVIDIA 多头分类器的智能路由服务
支持 OpenAI Chat Completions API 格式(含流式返回)
"""
import time
import json
import uuid
import tiktoken
from typing import List, Dict, Any, Optional
from pathlib import Path
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from litellm import acompletion
import litellm
from config import MODEL_CONFIG, ROUTING_THRESHOLDS, DEFAULT_ROUTING, DASHSCOPE_API_KEY
from nvidia_router import get_nvidia_router, select_model_by_nvidia
# Configure LiteLLM to use DashScope (Qwen) through its OpenAI-compatible
# endpoint, but only when an API key has been configured.
# NOTE(review): indentation was lost in this copy of the file; both
# assignments are assumed to sit under the guard — confirm against the repo.
if DASHSCOPE_API_KEY:
    litellm.api_key = DASHSCOPE_API_KEY
    litellm.api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
# Lazily created NVIDIA router instance; stays None until first requested.
_nvidia_router = None


def get_router():
    """Return the shared NVIDIA router, creating it on the first call."""
    global _nvidia_router
    if _nvidia_router is not None:
        return _nvidia_router
    _nvidia_router = get_nvidia_router()
    return _nvidia_router
# ── Call history: JSONL persistence ─────────────────────────
# Records are appended to <module dir>/data/call_history.jsonl and mirrored
# in the in-memory ``call_history`` list.
CALL_LOG_DIR = Path(__file__).parent.joinpath("data")
CALL_LOG_FILE = CALL_LOG_DIR.joinpath("call_history.jsonl")
CALL_LOG_DIR.mkdir(exist_ok=True)
call_history: List[Dict[str, Any]] = []
def _load_history():
    """Populate ``call_history`` from the JSONL log file, if it exists.

    Blank lines and lines that are not valid JSON are skipped, as are JSON
    values that are not objects: the stats endpoints index every record by
    key, so a stray scalar or list would make them fail later. A summary
    line is printed once the file has been read.
    """
    if not CALL_LOG_FILE.exists():
        return
    with open(CALL_LOG_FILE, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                # A truncated or corrupted line must not block startup.
                continue
            if isinstance(record, dict):
                call_history.append(record)
    print(f"Loaded {len(call_history)} historical records from {CALL_LOG_FILE}")


# Restore previous records at import time.
_load_history()
# ── OpenAI-compatible request/response models ───────────────
# One chat message of a Chat Completions request.
class ChatMessage(BaseModel):
    # Speaker role; routing looks for role == "user" to find the query text.
    role: str
    # Message text; may be omitted or null.
    content: Optional[str] = None
    # Optional participant name; accepted, but chat_completions only forwards
    # role and content upstream.
    name: Optional[str] = None
# Request body of POST /v1/chat/completions.
class ChatCompletionRequest(BaseModel):
    # Model key; when empty or omitted, the NVIDIA classifier picks the model.
    model: Optional[str] = None
    # Conversation so far.
    messages: List[ChatMessage]
    # Forwarded to the upstream completion call.
    temperature: Optional[float] = 0.7
    # Forwarded to the upstream completion call.
    max_tokens: Optional[int] = None
    # When true, the endpoint answers with an SSE stream.
    stream: Optional[bool] = False
    # The fields below are accepted in the request body, but chat_completions
    # does not pass them to the upstream completion call.
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stop: Optional[Any] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None
# ── FastAPI app ──────────────────────────────────────────────
# Application object on which every route in this module is registered.
# The description's opening fullwidth parenthesis was missing, leaving the
# closing one unbalanced; it is restored here.
app = FastAPI(
    title="LLM Router",
    description="OpenAI 兼容的 LLM 智能路由服务(NVIDIA 3-tier 分类 + LiteLLM 多提供商)",
    version="0.4.0",
)
# ── Helpers ──────────────────────────────────────────────────
def estimate_tokens(messages: List[ChatMessage]) -> int:
    """Approximate the prompt token count of a list of chat messages.

    Uses the tiktoken encoding for "gpt-4" (falling back to ``cl100k_base``
    when the model name is unknown to tiktoken). Each message contributes a
    fixed overhead of 4 tokens plus the tokens of its role and, when present,
    its content; 2 more tokens are added once for the whole request.

    Args:
        messages: Messages whose ``role`` and ``content`` are counted.

    Returns:
        The estimated number of tokens.
    """
    try:
        encoding = tiktoken.encoding_for_model("gpt-4")
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")

    def count(text: str) -> int:
        # By default encode() raises ValueError when the text contains a
        # special-token string such as "<|endoftext|>". User input may contain
        # anything, so encode such strings as ordinary text instead.
        return len(encoding.encode(text, disallowed_special=()))

    total = 2  # one-off overhead for the request as a whole
    for msg in messages:
        total += 4 + count(msg.role)
        if msg.content:
            total += count(msg.content)
    return total
def select_model_by_length(messages: List[ChatMessage]) -> str:
    """Pick a model from the estimated prompt size alone.

    The estimate is compared against the "simple" and then the "medium"
    threshold; anything at or above both maps to the "complex" model.
    """
    size = estimate_tokens(messages)
    for tier in ("simple", "medium"):
        if size < ROUTING_THRESHOLDS[tier]:
            return DEFAULT_ROUTING[tier]
    return DEFAULT_ROUTING["complex"]
def select_model_by_nvidia_classifier(messages: List[ChatMessage]) -> tuple:
    """Choose a model with the NVIDIA multi-head classifier.

    The most recent user message with non-empty content is used as the query
    (an empty string when there is none). If anything in the classifier path
    raises, the choice falls back to token-length routing and the error text
    is recorded in the routing detail.

    Args:
        messages: The conversation from the request.

    Returns:
        (model_key, routing_detail): the selected model key and a dict that
        describes how the decision was made.
    """
    query = next(
        (m.content for m in reversed(messages) if m.role == "user" and m.content),
        "",
    )
    try:
        router = get_router()
        # perf_counter is monotonic, so the measured duration cannot go
        # negative or jump when the system clock is adjusted.
        started = time.perf_counter()
        result = router.predict(query)
        routing_ms = (time.perf_counter() - started) * 1000
        model_map = {"simple": "qwen-flash", "medium": "qwen-plus", "complex": "qwen-max"}
        model_key = model_map[result["tier"]]
        routing_detail = {
            "method": "nvidia_classifier",
            "query": query,
            "routing_latency_ms": round(routing_ms, 2),
            "tier": result["tier"],
            "complexity_score": result["complexity_score"],
            "task_type": result["task_type"],
            "domain_knowledge": result["domain_knowledge"],
            "reasoning": result["reasoning"],
            "creativity": result["creativity"],
        }
        return model_key, routing_detail
    except Exception as e:
        # Routing must never fail the request: degrade to the length heuristic.
        print(f"NVIDIA routing failed: {e}, falling back to token length")
        model_key = select_model_by_length(messages)
        routing_detail = {
            "method": "fallback_token_length",
            "query": query,
            "routing_latency_ms": 0,
            "error": str(e),
        }
        return model_key, routing_detail
def get_provider_model(model_key: str) -> str:
    """Return the ``provider`` value configured for ``model_key``.

    Raises:
        HTTPException: 400 when the key has no (non-empty) entry in MODEL_CONFIG.
    """
    entry = MODEL_CONFIG.get(model_key)
    if entry:
        return entry["provider"]
    raise HTTPException(status_code=400, detail=f"Unknown model: {model_key}")
def calculate_cost(model_key: str, input_tokens: int, output_tokens: int) -> float:
    """Compute the cost of one call from the per-1000-token prices in MODEL_CONFIG.

    Args:
        model_key: Key of the model that served the call.
        input_tokens: Prompt token count.
        output_tokens: Completion token count.

    Returns:
        The cost using the model's ``input_cost`` / ``output_cost``. Unknown
        models are priced like "gpt-4o" when that entry exists, otherwise 0.0.
    """
    config = MODEL_CONFIG.get(model_key)
    if config is None:
        # Look the fallback up lazily. Passing MODEL_CONFIG["gpt-4o"] as the
        # default of dict.get() evaluated it on every call and raised KeyError
        # whenever "gpt-4o" was not configured, even for known models.
        config = MODEL_CONFIG.get("gpt-4o")
    if config is None:
        return 0.0
    return (input_tokens / 1000) * config["input_cost"] + (output_tokens / 1000) * config["output_cost"]
def log_call(
    model: str,
    cost: float,
    latency_ms: float,
    input_tokens: int,
    output_tokens: int,
    messages_raw: List[Dict],
    response_content: str,
    response_id: str,
    routing_detail: Optional[Dict],
    request_params: Dict,
    stream: bool = False,
):
    """Record one LLM call in memory and append it to the JSONL log file.

    The record has four sections: a timestamp, the request parameters, the
    routing detail as given, and the LLM result with token counts, cost and
    latency (rounded to 2 decimals).
    """
    request_section = {
        "messages": messages_raw,
        "temperature": request_params.get("temperature"),
        "max_tokens": request_params.get("max_tokens"),
        "stream": stream,
        "user_specified_model": request_params.get("user_specified_model"),
    }
    llm_section = {
        "model": model,
        "response_id": response_id,
        "response_content": response_content,
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": input_tokens + output_tokens,
        "cost_usd": cost,
        "llm_latency_ms": round(latency_ms, 2),
    }
    entry = {
        "timestamp": time.time(),
        "request": request_section,
        "routing": routing_detail,
        "llm": llm_section,
    }
    call_history.append(entry)
    with open(CALL_LOG_FILE, "a", encoding="utf-8") as sink:
        serialized = json.dumps(entry, ensure_ascii=False)
        sink.write(f"{serialized}\n")
def build_openai_response(
    response_id: str,
    model: str,
    content: str,
    input_tokens: int,
    output_tokens: int,
    routing_detail: Optional[Dict] = None,
) -> Dict:
    """Assemble a non-streaming ``chat.completion`` payload in OpenAI format.

    A non-empty ``routing_detail`` is attached under the extension key
    ``"routing"``; otherwise that key is left out.
    """
    usage = {
        "prompt_tokens": input_tokens,
        "completion_tokens": output_tokens,
        "total_tokens": input_tokens + output_tokens,
    }
    only_choice = {
        "index": 0,
        "message": {"role": "assistant", "content": content},
        "finish_reason": "stop",
    }
    payload = {
        "id": response_id,
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [only_choice],
        "usage": usage,
    }
    if not routing_detail:
        return payload
    # Routing details travel as an extension field.
    payload["routing"] = routing_detail
    return payload
# ── Core API: /v1/chat/completions ──────────────────────────
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """OpenAI-compatible Chat Completions endpoint.

    - When ``model`` is empty, the NVIDIA classifier chooses the model.
    - When ``stream`` is true, the answer is an SSE stream.
    - Only ``temperature`` and ``max_tokens`` are forwarded upstream.

    Raises:
        HTTPException: 400 for an unknown model key (from
            ``get_provider_model``); 500 when the upstream call or the
            processing of its result fails.
    """
    # 1. Routing decision
    if request.model:
        model_key = request.model
        routing_detail = {
            "method": "user_specified",
            "query": next((m.content for m in reversed(request.messages) if m.role == "user" and m.content), ""),
        }
    else:
        model_key, routing_detail = select_model_by_nvidia_classifier(request.messages)
    provider_model = get_provider_model(model_key)
    messages_raw = [{"role": m.role, "content": m.content} for m in request.messages]
    response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"

    # 2. Streaming response
    if request.stream:
        return StreamingResponse(
            _stream_response(
                provider_model=provider_model,
                model_key=model_key,
                messages_raw=messages_raw,
                request=request,
                response_id=response_id,
                routing_detail=routing_detail,
            ),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    # 3. Non-streaming response
    # perf_counter is monotonic, so the latency is immune to clock adjustments.
    start_time = time.perf_counter()
    try:
        response = await acompletion(
            model=provider_model,
            messages=messages_raw,
            temperature=request.temperature,
            max_tokens=request.max_tokens,
        )
        latency_ms = (time.perf_counter() - start_time) * 1000
        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        cost = calculate_cost(model_key, input_tokens, output_tokens)
        content = response.choices[0].message.content
    except Exception as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(status_code=500, detail=f"API error: {str(e)}") from e

    # Bookkeeping is best-effort: a failed log write must not discard a
    # completion that the upstream model already produced.
    try:
        log_call(
            model=model_key, cost=cost, latency_ms=latency_ms,
            input_tokens=input_tokens, output_tokens=output_tokens,
            messages_raw=messages_raw, response_content=content,
            response_id=response_id, routing_detail=routing_detail,
            request_params={
                "temperature": request.temperature,
                "max_tokens": request.max_tokens,
                "user_specified_model": request.model,
            },
        )
    except Exception as log_err:
        print(f"Failed to record call {response_id}: {log_err}")
    return build_openai_response(response_id, model_key, content, input_tokens, output_tokens, routing_detail)
async def _stream_response(
    provider_model: str,
    model_key: str,
    messages_raw: List[Dict],
    request: ChatCompletionRequest,
    response_id: str,
    routing_detail: Optional[Dict],
):
    """Yield an upstream completion as OpenAI-style SSE events.

    Every upstream chunk becomes one ``data: {...}`` event carrying a
    ``chat.completion.chunk`` object, and the stream ends with
    ``data: [DONE]``. If the upstream call fails, a single error event is
    sent, followed by ``[DONE]``, and nothing is logged. After a successful
    stream the call is recorded with tiktoken-based token estimates; a
    failure while recording is printed and never reaches the client.
    """
    start_time = time.perf_counter()
    pieces: List[str] = []  # joined once at the end instead of repeated +=
    try:
        response = await acompletion(
            model=provider_model,
            messages=messages_raw,
            temperature=request.temperature,
            max_tokens=request.max_tokens,
            stream=True,
        )
        async for chunk in response:
            # Guard against chunks that carry no choices (presumably
            # usage-only chunks from some providers); indexing [0] would raise.
            if not chunk.choices:
                continue
            choice = chunk.choices[0]
            delta = choice.delta
            if delta.content:
                pieces.append(delta.content)
                delta_payload = {"content": delta.content}
            elif delta.role:
                delta_payload = {"role": delta.role}
            else:
                delta_payload = {}
            chunk_data = {
                "id": response_id,
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": model_key,
                "choices": [{
                    "index": 0,
                    "delta": delta_payload,
                    "finish_reason": choice.finish_reason or None,
                }],
            }
            yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"
    except Exception as e:
        error_data = {"error": {"message": str(e), "type": "api_error"}}
        yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
        yield "data: [DONE]\n\n"
        return

    latency_ms = (time.perf_counter() - start_time) * 1000
    yield "data: [DONE]\n\n"

    # Bookkeeping runs outside the streaming try-block, so a failure here
    # cannot emit an error event or a second [DONE] after the stream ended.
    collected_content = "".join(pieces)
    try:
        encoding = tiktoken.get_encoding("cl100k_base")
        # disallowed_special=() encodes special-token strings as plain text
        # instead of raising ValueError.
        input_tokens = sum(
            len(encoding.encode(m.get("content") or "", disallowed_special=()))
            for m in messages_raw
        ) + len(messages_raw) * 4
        output_tokens = len(encoding.encode(collected_content, disallowed_special=()))
    except Exception:
        # Without a tokenizer, fall back to the coarse estimates: zero input
        # tokens and one output token per content chunk.
        input_tokens = 0
        output_tokens = len(pieces)
    try:
        cost = calculate_cost(model_key, input_tokens, output_tokens)
        log_call(
            model=model_key, cost=cost, latency_ms=latency_ms,
            input_tokens=input_tokens, output_tokens=output_tokens,
            messages_raw=messages_raw, response_content=collected_content,
            response_id=response_id, routing_detail=routing_detail,
            request_params={
                "temperature": request.temperature,
                "max_tokens": request.max_tokens,
                "user_specified_model": request.model,
            },
            stream=True,
        )
    except Exception as log_err:
        print(f"Failed to record streamed call {response_id}: {log_err}")
# ── OpenAI-compatible: /v1/models ───────────────────────────
@app.get("/v1/models")
async def list_models():
    """List the configured models in OpenAI ``object: list`` format.

    ``owned_by`` is the part of the provider string before the first "/",
    or "unknown" when the provider string contains no "/".
    """
    def owner_of(provider: str) -> str:
        prefix, slash, _ = provider.partition("/")
        return prefix if slash else "unknown"

    models = [
        {
            "id": key,
            "object": "model",
            "created": 1700000000,
            "owned_by": owner_of(cfg["provider"]),
        }
        for key, cfg in MODEL_CONFIG.items()
    ]
    return {"object": "list", "data": models}
# ── Admin endpoints ─────────────────────────────────────────
@app.get("/stats")
async def get_stats():
    """Return a summary of all recorded calls.

    The summary contains the call count, total cost, average LLM and routing
    latency, and per-model, per-tier and per-task-type call counts. With no
    recorded calls, every number is zero and every distribution is empty.
    """
    if not call_history:
        return {
            "total_calls": 0, "total_cost_usd": 0.0,
            "avg_latency_ms": 0.0, "avg_routing_ms": 0.0,
            "model_distribution": {},
            "tier_distribution": {}, "task_type_distribution": {},
        }
    total_calls = len(call_history)
    total_cost = 0
    total_latency = 0
    total_routing = 0
    model_dist: Dict[str, int] = {}
    tier_dist: Dict[str, int] = {}
    task_dist: Dict[str, int] = {}
    for call in call_history:
        llm = call["llm"]
        total_cost += llm["cost_usd"]
        total_latency += llm["llm_latency_ms"]
        model = llm["model"]
        model_dist[model] = model_dist.get(model, 0) + 1
        # "routing" may be present but null; .get("routing", {}) would then
        # return None and break the attribute access, so normalise with `or`.
        routing = call.get("routing") or {}
        total_routing += routing.get("routing_latency_ms", 0)
        tier = routing.get("tier")
        if tier:
            tier_dist[tier] = tier_dist.get(tier, 0) + 1
        task_type = routing.get("task_type")
        if task_type:
            task_dist[task_type] = task_dist.get(task_type, 0) + 1
    return {
        "total_calls": total_calls,
        "total_cost_usd": round(total_cost, 6),
        "avg_latency_ms": round(total_latency / total_calls, 2),
        "avg_routing_ms": round(total_routing / total_calls, 2),
        "model_distribution": model_dist,
        "tier_distribution": tier_dist,
        "task_type_distribution": task_dist,
    }
@app.get("/stats/raw")
async def get_stats_raw(limit: int = 50, offset: int = 0):
    """Return one page of raw call records, newest first.

    Args:
        limit: Maximum number of records to return; negative values count as 0.
        offset: Number of newest records to skip; negative values count as 0.

    Returns:
        A dict with the total record count, the effective ``limit`` and
        ``offset``, and the selected ``records``.
    """
    # Negative values would otherwise turn into wrap-around slice indices.
    limit = max(limit, 0)
    offset = max(offset, 0)
    total = len(call_history)
    # Slice the page out of the chronological list and reverse only that
    # page, instead of copying and reversing the whole history per request.
    end = max(total - offset, 0)
    start = max(end - limit, 0)
    records = call_history[start:end][::-1]
    return {"total": total, "limit": limit, "offset": offset, "records": records}
@app.get("/health")
async def health_check():
    """Health probe: report a fixed status, version and router name."""
    report = dict(status="healthy", version="0.4.0", router="nvidia-3tier")
    return report
# Script entry point: when run directly, serve the app with uvicorn on
# host 0.0.0.0, port 8000.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)