feat(api): 重写为OpenAI兼容API并支持流式SSE返回

- 请求/响应完全对齐OpenAI Chat Completions API格式
- 支持 stream=true SSE流式返回 (data: {...}\n\n + [DONE])
- 新增 /v1/models 接口 (OpenAI格式 object:list)
- 非流式响应扩展 routing 字段暴露路由决策细节
- OpenAI Python SDK可直接对接 (base_url=http://localhost:8000/v1)
- 版本升级至v0.4.0
This commit is contained in:
2026-04-18 08:56:12 +08:00
parent 1e273e3670
commit 1705426eef

380
main.py
View File

@@ -1,17 +1,18 @@
"""
MVP版 LLM 路由服务
基于 LiteLLM 的多提供商统一接口
支持: OpenAI, Anthropic, Gemini, Ollama 等 100+ 提供商
LLM Router 服务 - OpenAI 兼容 API
基于 LiteLLM + NVIDIA 多头分类器的智能路由服务
支持 OpenAI Chat Completions API 格式(含流式返回)
"""
import time
import json
import os
import uuid
import tiktoken
from typing import List, Dict, Any, Optional
from pathlib import Path
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from litellm import acompletion
import litellm
@@ -21,7 +22,6 @@ from nvidia_router import get_nvidia_router, select_model_by_nvidia
# 配置 LiteLLM 使用 DashScope (Qwen)
if DASHSCOPE_API_KEY:
litellm.api_key = DASHSCOPE_API_KEY
# Qwen 使用 OpenAI 兼容接口,但需要通过 api_base 指定
litellm.api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
# NVIDIA Router 实例(延迟加载)
@@ -35,16 +35,14 @@ def get_router():
return _nvidia_router
# 调用历史 - JSON 文件持久化
# ── 调用历史 - JSONL 持久化 ────────────────────────────────
CALL_LOG_DIR = Path(__file__).parent / "data"
CALL_LOG_DIR.mkdir(exist_ok=True)
CALL_LOG_FILE = CALL_LOG_DIR / "call_history.jsonl"
# 内存缓存(启动时从文件加载)
call_history: List[Dict[str, Any]] = []
def _load_history():
"""启动时从 JSONL 文件加载历史记录"""
if CALL_LOG_FILE.exists():
with open(CALL_LOG_FILE, "r", encoding="utf-8") as f:
for line in f:
@@ -59,55 +57,52 @@ def _load_history():
_load_history()
class Message(BaseModel):
# ── OpenAI 兼容请求/响应模型 ────────────────────────────────
class ChatMessage(BaseModel):
role: str
content: str
content: Optional[str] = None
name: Optional[str] = None
class ChatRequest(BaseModel):
messages: List[Message]
model: Optional[str] = None # 可选,如果指定则跳过路由
temperature: float = 0.7
class ChatCompletionRequest(BaseModel):
model: Optional[str] = None
messages: List[ChatMessage]
temperature: Optional[float] = 0.7
max_tokens: Optional[int] = None
stream: Optional[bool] = False
top_p: Optional[float] = 1.0
n: Optional[int] = 1
stop: Optional[Any] = None
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
user: Optional[str] = None
class ChatResponse(BaseModel):
id: str
model: str
provider: str
content: str
usage: Dict[str, int]
cost_usd: float
latency_ms: float
# ── FastAPI App ──────────────────────────────────────────────
app = FastAPI(
title="LLM Router MVP",
description="基于 LiteLLM + NVIDIA 分类器的多提供商路由服务支持3-tier智能路由",
version="0.3.0",
title="LLM Router",
description="OpenAI 兼容的 LLM 智能路由服务NVIDIA 3-tier 分类 + LiteLLM 多提供商",
version="0.4.0",
)
def estimate_tokens(messages: List[Message]) -> int:
"""估算 token 数量"""
# ── 辅助函数 ─────────────────────────────────────────────────
def estimate_tokens(messages: List[ChatMessage]) -> int:
try:
encoding = tiktoken.encoding_for_model("gpt-4")
except KeyError:
encoding = tiktoken.get_encoding("cl100k_base")
total_tokens = 0
total = 0
for msg in messages:
total_tokens += 4
total_tokens += len(encoding.encode(msg.content))
total_tokens += len(encoding.encode(msg.role))
total_tokens += 2
return total_tokens
total += 4
if msg.content:
total += len(encoding.encode(msg.content))
total += len(encoding.encode(msg.role))
total += 2
return total
def select_model_by_length(messages: List[Message]) -> str:
"""基于 token 长度选择模型(备用策略)"""
def select_model_by_length(messages: List[ChatMessage]) -> str:
token_count = estimate_tokens(messages)
if token_count < ROUTING_THRESHOLDS["simple"]:
return DEFAULT_ROUTING["simple"]
elif token_count < ROUTING_THRESHOLDS["medium"]:
@@ -116,14 +111,16 @@ def select_model_by_length(messages: List[Message]) -> str:
return DEFAULT_ROUTING["complex"]
def select_model_by_nvidia_classifier(messages: List[Message]) -> tuple:
def select_model_by_nvidia_classifier(messages: List[ChatMessage]) -> tuple:
"""
基于 NVIDIA 多头分类器选择模型3-tier路由
Returns:
(model_key, routing_detail) - 模型名称 + 路由分类细节
基于 NVIDIA 多头分类器选择模型
Returns: (model_key, routing_detail)
"""
query = messages[-1].content if messages else ""
query = ""
for msg in reversed(messages):
if msg.role == "user" and msg.content:
query = msg.content
break
try:
router = get_router()
@@ -159,7 +156,6 @@ def select_model_by_nvidia_classifier(messages: List[Message]) -> tuple:
def get_provider_model(model_key: str) -> str:
"""获取 LiteLLM 格式的模型名称"""
config = MODEL_CONFIG.get(model_key)
if not config:
raise HTTPException(status_code=400, detail=f"Unknown model: {model_key}")
@@ -167,55 +163,35 @@ def get_provider_model(model_key: str) -> str:
def calculate_cost(model_key: str, input_tokens: int, output_tokens: int) -> float:
"""计算调用成本"""
config = MODEL_CONFIG.get(model_key, MODEL_CONFIG["gpt-4o"])
input_cost = (input_tokens / 1000) * config["input_cost"]
output_cost = (output_tokens / 1000) * config["output_cost"]
return input_cost + output_cost
def get_provider_from_model(model_name: str) -> str:
"""从模型名称推断提供商"""
if model_name.startswith("gpt"):
return "openai"
elif model_name.startswith("claude"):
return "anthropic"
elif model_name.startswith("gemini"):
return "google"
elif "/" in model_name:
return model_name.split("/")[0]
return "unknown"
return (input_tokens / 1000) * config["input_cost"] + (output_tokens / 1000) * config["output_cost"]
def log_call(
model: str,
provider: str,
cost: float,
latency_ms: float,
input_tokens: int,
output_tokens: int,
messages: List[Dict[str, str]],
messages_raw: List[Dict],
response_content: str,
response_id: str,
routing_detail: Optional[Dict[str, Any]],
request_params: Dict[str, Any],
routing_detail: Optional[Dict],
request_params: Dict,
stream: bool = False,
):
"""记录完整调用历史(含路由细节 + LLM 原始数据,供后续调优)"""
record = {
"timestamp": time.time(),
# 请求
"request": {
"messages": messages,
"messages": messages_raw,
"temperature": request_params.get("temperature"),
"max_tokens": request_params.get("max_tokens"),
"stream": stream,
"user_specified_model": request_params.get("user_specified_model"),
},
# 路由决策
"routing": routing_detail,
# LLM 调用
"llm": {
"model": model,
"provider": provider,
"response_id": response_id,
"response_content": response_content,
"input_tokens": input_tokens,
@@ -226,36 +202,85 @@ def log_call(
},
}
call_history.append(record)
# 追加写入 JSONL 文件
with open(CALL_LOG_FILE, "a", encoding="utf-8") as f:
f.write(json.dumps(record, ensure_ascii=False) + "\n")
@app.post("/v1/chat/completions", response_model=ChatResponse)
async def chat_completions(request: ChatRequest):
def build_openai_response(
response_id: str,
model: str,
content: str,
input_tokens: int,
output_tokens: int,
routing_detail: Optional[Dict] = None,
) -> Dict:
"""构建 OpenAI 格式的非流式响应"""
resp = {
"id": response_id,
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [{
"index": 0,
"message": {"role": "assistant", "content": content},
"finish_reason": "stop",
}],
"usage": {
"prompt_tokens": input_tokens,
"completion_tokens": output_tokens,
"total_tokens": input_tokens + output_tokens,
},
}
# 路由细节作为扩展字段
if routing_detail:
resp["routing"] = routing_detail
return resp
# ── 核心 API: /v1/chat/completions ──────────────────────────
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
"""
聊天完成接口
如果 request.model 未指定,则使用 NVIDIA 分类器智能路由
OpenAI 兼容的 Chat Completions API
- model 为空时自动使用 NVIDIA 分类器路由
- stream=true 时返回 SSE 流式响应
"""
# 1. 路由决策
routing_detail = None
if request.model:
model_key = request.model
routing_detail = {
"method": "user_specified",
"query": request.messages[-1].content if request.messages else "",
"query": next((m.content for m in reversed(request.messages) if m.role == "user" and m.content), ""),
}
else:
model_key, routing_detail = select_model_by_nvidia_classifier(request.messages)
# 获取 LiteLLM 模型名称
provider_model = get_provider_model(model_key)
provider = get_provider_from_model(provider_model)
messages_raw = [{"role": m.role, "content": m.content} for m in request.messages]
response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
# 2. 流式响应
if request.stream:
return StreamingResponse(
_stream_response(
provider_model=provider_model,
model_key=model_key,
messages_raw=messages_raw,
request=request,
response_id=response_id,
routing_detail=routing_detail,
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
# 3. 非流式响应
start_time = time.time()
try:
response = await acompletion(
model=provider_model,
@@ -263,26 +288,18 @@ async def chat_completions(request: ChatRequest):
temperature=request.temperature,
max_tokens=request.max_tokens,
)
latency_ms = (time.time() - start_time) * 1000
input_tokens = response.usage.prompt_tokens
output_tokens = response.usage.completion_tokens
cost = calculate_cost(model_key, input_tokens, output_tokens)
response_content = response.choices[0].message.content
content = response.choices[0].message.content
# 记录完整调用数据
log_call(
model=model_key,
provider=provider,
cost=cost,
latency_ms=latency_ms,
input_tokens=input_tokens,
output_tokens=output_tokens,
messages=messages_raw,
response_content=response_content,
response_id=response.id,
routing_detail=routing_detail,
model=model_key, cost=cost, latency_ms=latency_ms,
input_tokens=input_tokens, output_tokens=output_tokens,
messages_raw=messages_raw, response_content=content,
response_id=response_id, routing_detail=routing_detail,
request_params={
"temperature": request.temperature,
"max_tokens": request.max_tokens,
@@ -290,51 +307,122 @@ async def chat_completions(request: ChatRequest):
},
)
return ChatResponse(
id=response.id,
model=model_key,
provider=provider,
content=response_content,
usage={
"prompt_tokens": input_tokens,
"completion_tokens": output_tokens,
"total_tokens": input_tokens + output_tokens,
},
cost_usd=cost,
latency_ms=round(latency_ms, 2),
)
return build_openai_response(response_id, model_key, content, input_tokens, output_tokens, routing_detail)
except Exception as e:
raise HTTPException(status_code=500, detail=f"API error: {str(e)}")
@app.get("/models")
async def list_models():
"""列出支持的模型"""
return {
"models": [
{
"key": key,
"provider": config["provider"],
"input_cost_per_1k": config["input_cost"],
"output_cost_per_1k": config["output_cost"],
async def _stream_response(
provider_model: str,
model_key: str,
messages_raw: List[Dict],
request: ChatCompletionRequest,
response_id: str,
routing_detail: Optional[Dict],
):
"""生成 SSE 流式响应"""
start_time = time.time()
collected_content = ""
input_tokens = 0
output_tokens = 0
try:
response = await acompletion(
model=provider_model,
messages=messages_raw,
temperature=request.temperature,
max_tokens=request.max_tokens,
stream=True,
)
async for chunk in response:
delta = chunk.choices[0].delta
# 收集内容用于日志
if delta.content:
collected_content += delta.content
output_tokens += 1 # 近似计数
# 构建 SSE 数据
chunk_data = {
"id": response_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_key,
"choices": [{
"index": 0,
"delta": {},
"finish_reason": None,
}],
}
for key, config in MODEL_CONFIG.items()
]
}
if delta.content:
chunk_data["choices"][0]["delta"] = {"content": delta.content}
elif delta.role:
chunk_data["choices"][0]["delta"] = {"role": delta.role}
if chunk.choices[0].finish_reason:
chunk_data["choices"][0]["finish_reason"] = chunk.choices[0].finish_reason
yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"
# 发送 [DONE]
yield "data: [DONE]\n\n"
# 记录日志
latency_ms = (time.time() - start_time) * 1000
# 流式模式下用 tiktoken 近似计算 input_tokens
try:
encoding = tiktoken.get_encoding("cl100k_base")
input_tokens = sum(len(encoding.encode(m.get("content", "") or "")) for m in messages_raw) + len(messages_raw) * 4
except Exception:
input_tokens = 0
cost = calculate_cost(model_key, input_tokens, output_tokens)
log_call(
model=model_key, cost=cost, latency_ms=latency_ms,
input_tokens=input_tokens, output_tokens=output_tokens,
messages_raw=messages_raw, response_content=collected_content,
response_id=response_id, routing_detail=routing_detail,
request_params={
"temperature": request.temperature,
"max_tokens": request.max_tokens,
"user_specified_model": request.model,
},
stream=True,
)
except Exception as e:
error_data = {"error": {"message": str(e), "type": "api_error"}}
yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
# ── OpenAI 兼容: /v1/models ─────────────────────────────────
@app.get("/v1/models")
async def list_models():
"""OpenAI 兼容的模型列表接口"""
data = []
for key, config in MODEL_CONFIG.items():
data.append({
"id": key,
"object": "model",
"created": 1700000000,
"owned_by": config["provider"].split("/")[0] if "/" in config["provider"] else "unknown",
})
return {"object": "list", "data": data}
# ── 管理接口 ─────────────────────────────────────────────────
@app.get("/stats")
async def get_stats():
"""获取调用统计摘要"""
if not call_history:
return {
"total_calls": 0,
"total_cost_usd": 0.0,
"avg_latency_ms": 0.0,
"model_distribution": {},
"tier_distribution": {},
"task_type_distribution": {},
"total_calls": 0, "total_cost_usd": 0.0,
"avg_latency_ms": 0.0, "model_distribution": {},
"tier_distribution": {}, "task_type_distribution": {},
}
total_calls = len(call_history)
@@ -348,14 +436,11 @@ async def get_stats():
for call in call_history:
model = call["llm"]["model"]
model_dist[model] = model_dist.get(model, 0) + 1
routing = call.get("routing") or {}
if routing.get("tier"):
tier = routing["tier"]
tier_dist[tier] = tier_dist.get(tier, 0) + 1
tier_dist[routing["tier"]] = tier_dist.get(routing["tier"], 0) + 1
if routing.get("task_type"):
task = routing["task_type"]
task_dist[task] = task_dist.get(task, 0) + 1
task_dist[routing["task_type"]] = task_dist.get(routing["task_type"], 0) + 1
return {
"total_calls": total_calls,
@@ -372,31 +457,16 @@ async def get_stats():
@app.get("/stats/raw")
async def get_stats_raw(limit: int = 50, offset: int = 0):
"""
获取原始调用记录(含路由分类细节 + LLM 完整数据)
用于后续调优和分析
参数:
- limit: 返回条数默认50
- offset: 偏移量默认0从最新开始
"""
"""获取原始调用记录"""
total = len(call_history)
# 倒序返回(最新在前)
records = list(reversed(call_history))
page = records[offset:offset + limit]
return {
"total": total,
"limit": limit,
"offset": offset,
"records": page,
}
return {"total": total, "limit": limit, "offset": offset, "records": records[offset:offset + limit]}
@app.get("/health")
async def health_check():
"""健康检查"""
return {"status": "healthy", "version": "0.3.0", "router": "nvidia-3tier"}
return {"status": "healthy", "version": "0.4.0", "router": "nvidia-3tier"}
if __name__ == "__main__":