From 1705426eef7a129d47d2112470b81754876d7865 Mon Sep 17 00:00:00 2001 From: aszerW Date: Sat, 18 Apr 2026 08:56:12 +0800 Subject: [PATCH] =?UTF-8?q?feat(api):=20=E9=87=8D=E5=86=99=E4=B8=BAOpenAI?= =?UTF-8?q?=E5=85=BC=E5=AE=B9API=E5=B9=B6=E6=94=AF=E6=8C=81=E6=B5=81?= =?UTF-8?q?=E5=BC=8FSSE=E8=BF=94=E5=9B=9E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 请求/响应完全对齐OpenAI Chat Completions API格式 - 支持 stream=true SSE流式返回 (data: {...}\n\n + [DONE]) - 新增 /v1/models 接口 (OpenAI格式 object:list) - 非流式响应扩展 routing 字段暴露路由决策细节 - OpenAI Python SDK可直接对接 (base_url=http://localhost:8000/v1) - 版本升级至v0.4.0 --- main.py | 380 +++++++++++++++++++++++++++++++++----------------------- 1 file changed, 225 insertions(+), 155 deletions(-) diff --git a/main.py b/main.py index 01cb844..89fa8a1 100644 --- a/main.py +++ b/main.py @@ -1,17 +1,18 @@ """ -MVP版 LLM 路由服务 -基于 LiteLLM 的多提供商统一接口 -支持: OpenAI, Anthropic, Gemini, Ollama 等 100+ 提供商 +LLM Router 服务 - OpenAI 兼容 API +基于 LiteLLM + NVIDIA 多头分类器的智能路由服务 +支持 OpenAI Chat Completions API 格式(含流式返回) """ import time import json -import os +import uuid import tiktoken from typing import List, Dict, Any, Optional from pathlib import Path from fastapi import FastAPI, HTTPException -from pydantic import BaseModel +from fastapi.responses import StreamingResponse +from pydantic import BaseModel, Field from litellm import acompletion import litellm @@ -21,7 +22,6 @@ from nvidia_router import get_nvidia_router, select_model_by_nvidia # 配置 LiteLLM 使用 DashScope (Qwen) if DASHSCOPE_API_KEY: litellm.api_key = DASHSCOPE_API_KEY - # Qwen 使用 OpenAI 兼容接口,但需要通过 api_base 指定 litellm.api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1" # NVIDIA Router 实例(延迟加载) @@ -35,16 +35,14 @@ def get_router(): return _nvidia_router -# 调用历史 - JSON 文件持久化 +# ── 调用历史 - JSONL 持久化 ──────────────────────────────── CALL_LOG_DIR = Path(__file__).parent / "data" CALL_LOG_DIR.mkdir(exist_ok=True) CALL_LOG_FILE = CALL_LOG_DIR / "call_history.jsonl" -# 内存缓存(启动时从文件加载) call_history: List[Dict[str, Any]] = [] def _load_history(): - """启动时从 JSONL 文件加载历史记录""" if CALL_LOG_FILE.exists(): with open(CALL_LOG_FILE, "r", encoding="utf-8") as f: for line in f: @@ -59,55 +57,52 @@ def _load_history(): _load_history() -class Message(BaseModel): +# ── OpenAI 兼容请求/响应模型 ──────────────────────────────── +class ChatMessage(BaseModel): role: str - content: str + content: Optional[str] = None + name: Optional[str] = None - -class ChatRequest(BaseModel): - messages: List[Message] - model: Optional[str] = None # 可选,如果指定则跳过路由 - temperature: float = 0.7 +class ChatCompletionRequest(BaseModel): + model: Optional[str] = None + messages: List[ChatMessage] + temperature: Optional[float] = 0.7 max_tokens: Optional[int] = None + stream: Optional[bool] = False + top_p: Optional[float] = 1.0 + n: Optional[int] = 1 + stop: Optional[Any] = None + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + user: Optional[str] = None -class ChatResponse(BaseModel): - id: str - model: str - provider: str - content: str - usage: Dict[str, int] - cost_usd: float - latency_ms: float - - +# ── FastAPI App ────────────────────────────────────────────── app = FastAPI( - title="LLM Router MVP", - description="基于 LiteLLM + NVIDIA 分类器的多提供商路由服务(支持3-tier智能路由)", - version="0.3.0", + title="LLM Router", + description="OpenAI 兼容的 LLM 智能路由服务(NVIDIA 3-tier 分类 + LiteLLM 多提供商)", + version="0.4.0", ) -def estimate_tokens(messages: List[Message]) -> int: - """估算 token 数量""" +# ── 辅助函数 ───────────────────────────────────────────────── +def estimate_tokens(messages: List[ChatMessage]) -> int: try: encoding = tiktoken.encoding_for_model("gpt-4") except KeyError: encoding = tiktoken.get_encoding("cl100k_base") - - total_tokens = 0 + total = 0 for msg in messages: - total_tokens += 4 - total_tokens += len(encoding.encode(msg.content)) - total_tokens += len(encoding.encode(msg.role)) - total_tokens += 2 - return total_tokens + total += 4 + if msg.content: + total += len(encoding.encode(msg.content)) + total += len(encoding.encode(msg.role)) + total += 2 + return total -def select_model_by_length(messages: List[Message]) -> str: - """基于 token 长度选择模型(备用策略)""" +def select_model_by_length(messages: List[ChatMessage]) -> str: token_count = estimate_tokens(messages) - if token_count < ROUTING_THRESHOLDS["simple"]: return DEFAULT_ROUTING["simple"] elif token_count < ROUTING_THRESHOLDS["medium"]: @@ -116,14 +111,16 @@ def select_model_by_length(messages: List[Message]) -> str: return DEFAULT_ROUTING["complex"] -def select_model_by_nvidia_classifier(messages: List[Message]) -> tuple: +def select_model_by_nvidia_classifier(messages: List[ChatMessage]) -> tuple: """ - 基于 NVIDIA 多头分类器选择模型(3-tier路由) - - Returns: - (model_key, routing_detail) - 模型名称 + 路由分类细节 + 基于 NVIDIA 多头分类器选择模型 + Returns: (model_key, routing_detail) """ - query = messages[-1].content if messages else "" + query = "" + for msg in reversed(messages): + if msg.role == "user" and msg.content: + query = msg.content + break try: router = get_router() @@ -159,7 +156,6 @@ def select_model_by_nvidia_classifier(messages: List[Message]) -> tuple: def get_provider_model(model_key: str) -> str: - """获取 LiteLLM 格式的模型名称""" config = MODEL_CONFIG.get(model_key) if not config: raise HTTPException(status_code=400, detail=f"Unknown model: {model_key}") @@ -167,55 +163,35 @@ def get_provider_model(model_key: str) -> str: def calculate_cost(model_key: str, input_tokens: int, output_tokens: int) -> float: - """计算调用成本""" config = MODEL_CONFIG.get(model_key, MODEL_CONFIG["gpt-4o"]) - input_cost = (input_tokens / 1000) * config["input_cost"] - output_cost = (output_tokens / 1000) * config["output_cost"] - return input_cost + output_cost - - -def get_provider_from_model(model_name: str) -> str: - """从模型名称推断提供商""" - if model_name.startswith("gpt"): - return "openai" - elif model_name.startswith("claude"): - return "anthropic" - elif model_name.startswith("gemini"): - return "google" - elif "/" in model_name: - return model_name.split("/")[0] - return "unknown" + return (input_tokens / 1000) * config["input_cost"] + (output_tokens / 1000) * config["output_cost"] def log_call( model: str, - provider: str, cost: float, latency_ms: float, input_tokens: int, output_tokens: int, - messages: List[Dict[str, str]], + messages_raw: List[Dict], response_content: str, response_id: str, - routing_detail: Optional[Dict[str, Any]], - request_params: Dict[str, Any], + routing_detail: Optional[Dict], + request_params: Dict, + stream: bool = False, ): - """记录完整调用历史(含路由细节 + LLM 原始数据,供后续调优)""" record = { "timestamp": time.time(), - # 请求 "request": { - "messages": messages, + "messages": messages_raw, "temperature": request_params.get("temperature"), "max_tokens": request_params.get("max_tokens"), + "stream": stream, "user_specified_model": request_params.get("user_specified_model"), }, - # 路由决策 "routing": routing_detail, - # LLM 调用 "llm": { "model": model, - "provider": provider, "response_id": response_id, "response_content": response_content, "input_tokens": input_tokens, @@ -226,36 +202,85 @@ def log_call( }, } call_history.append(record) - - # 追加写入 JSONL 文件 with open(CALL_LOG_FILE, "a", encoding="utf-8") as f: f.write(json.dumps(record, ensure_ascii=False) + "\n") -@app.post("/v1/chat/completions", response_model=ChatResponse) -async def chat_completions(request: ChatRequest): +def build_openai_response( + response_id: str, + model: str, + content: str, + input_tokens: int, + output_tokens: int, + routing_detail: Optional[Dict] = None, +) -> Dict: + """构建 OpenAI 格式的非流式响应""" + resp = { + "id": response_id, + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": content}, + "finish_reason": "stop", + }], + "usage": { + "prompt_tokens": input_tokens, + "completion_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, + }, + } + # 路由细节作为扩展字段 + if routing_detail: + resp["routing"] = routing_detail + return resp + + +# ── 核心 API: /v1/chat/completions ────────────────────────── +@app.post("/v1/chat/completions") +async def chat_completions(request: ChatCompletionRequest): """ - 聊天完成接口 - 如果 request.model 未指定,则使用 NVIDIA 分类器智能路由 + OpenAI 兼容的 Chat Completions API + - model 为空时自动使用 NVIDIA 分类器路由 + - stream=true 时返回 SSE 流式响应 """ + # 1. 路由决策 routing_detail = None - if request.model: model_key = request.model routing_detail = { "method": "user_specified", - "query": request.messages[-1].content if request.messages else "", + "query": next((m.content for m in reversed(request.messages) if m.role == "user" and m.content), ""), } else: model_key, routing_detail = select_model_by_nvidia_classifier(request.messages) - # 获取 LiteLLM 模型名称 provider_model = get_provider_model(model_key) - provider = get_provider_from_model(provider_model) messages_raw = [{"role": m.role, "content": m.content} for m in request.messages] + response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}" + # 2. 流式响应 + if request.stream: + return StreamingResponse( + _stream_response( + provider_model=provider_model, + model_key=model_key, + messages_raw=messages_raw, + request=request, + response_id=response_id, + routing_detail=routing_detail, + ), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + + # 3. 非流式响应 start_time = time.time() - try: response = await acompletion( model=provider_model, @@ -263,26 +288,18 @@ async def chat_completions(request: ChatRequest): temperature=request.temperature, max_tokens=request.max_tokens, ) - latency_ms = (time.time() - start_time) * 1000 input_tokens = response.usage.prompt_tokens output_tokens = response.usage.completion_tokens cost = calculate_cost(model_key, input_tokens, output_tokens) - response_content = response.choices[0].message.content + content = response.choices[0].message.content - # 记录完整调用数据 log_call( - model=model_key, - provider=provider, - cost=cost, - latency_ms=latency_ms, - input_tokens=input_tokens, - output_tokens=output_tokens, - messages=messages_raw, - response_content=response_content, - response_id=response.id, - routing_detail=routing_detail, + model=model_key, cost=cost, latency_ms=latency_ms, + input_tokens=input_tokens, output_tokens=output_tokens, + messages_raw=messages_raw, response_content=content, + response_id=response_id, routing_detail=routing_detail, request_params={ "temperature": request.temperature, "max_tokens": request.max_tokens, @@ -290,51 +307,122 @@ async def chat_completions(request: ChatRequest): }, ) - return ChatResponse( - id=response.id, - model=model_key, - provider=provider, - content=response_content, - usage={ - "prompt_tokens": input_tokens, - "completion_tokens": output_tokens, - "total_tokens": input_tokens + output_tokens, - }, - cost_usd=cost, - latency_ms=round(latency_ms, 2), - ) - + return build_openai_response(response_id, model_key, content, input_tokens, output_tokens, routing_detail) + except Exception as e: raise HTTPException(status_code=500, detail=f"API error: {str(e)}") -@app.get("/models") -async def list_models(): - """列出支持的模型""" - return { - "models": [ - { - "key": key, - "provider": config["provider"], - "input_cost_per_1k": config["input_cost"], - "output_cost_per_1k": config["output_cost"], +async def _stream_response( + provider_model: str, + model_key: str, + messages_raw: List[Dict], + request: ChatCompletionRequest, + response_id: str, + routing_detail: Optional[Dict], +): + """生成 SSE 流式响应""" + start_time = time.time() + collected_content = "" + input_tokens = 0 + output_tokens = 0 + + try: + response = await acompletion( + model=provider_model, + messages=messages_raw, + temperature=request.temperature, + max_tokens=request.max_tokens, + stream=True, + ) + + async for chunk in response: + delta = chunk.choices[0].delta + + # 收集内容用于日志 + if delta.content: + collected_content += delta.content + output_tokens += 1 # 近似计数 + + # 构建 SSE 数据 + chunk_data = { + "id": response_id, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model_key, + "choices": [{ + "index": 0, + "delta": {}, + "finish_reason": None, + }], } - for key, config in MODEL_CONFIG.items() - ] - } + + if delta.content: + chunk_data["choices"][0]["delta"] = {"content": delta.content} + elif delta.role: + chunk_data["choices"][0]["delta"] = {"role": delta.role} + + if chunk.choices[0].finish_reason: + chunk_data["choices"][0]["finish_reason"] = chunk.choices[0].finish_reason + + yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n" + + # 发送 [DONE] + yield "data: [DONE]\n\n" + + # 记录日志 + latency_ms = (time.time() - start_time) * 1000 + # 流式模式下用 tiktoken 近似计算 input_tokens + try: + encoding = tiktoken.get_encoding("cl100k_base") + input_tokens = sum(len(encoding.encode(m.get("content", "") or "")) for m in messages_raw) + len(messages_raw) * 4 + except Exception: + input_tokens = 0 + + cost = calculate_cost(model_key, input_tokens, output_tokens) + log_call( + model=model_key, cost=cost, latency_ms=latency_ms, + input_tokens=input_tokens, output_tokens=output_tokens, + messages_raw=messages_raw, response_content=collected_content, + response_id=response_id, routing_detail=routing_detail, + request_params={ + "temperature": request.temperature, + "max_tokens": request.max_tokens, + "user_specified_model": request.model, + }, + stream=True, + ) + + except Exception as e: + error_data = {"error": {"message": str(e), "type": "api_error"}} + yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" +# ── OpenAI 兼容: /v1/models ───────────────────────────────── +@app.get("/v1/models") +async def list_models(): + """OpenAI 兼容的模型列表接口""" + data = [] + for key, config in MODEL_CONFIG.items(): + data.append({ + "id": key, + "object": "model", + "created": 1700000000, + "owned_by": config["provider"].split("/")[0] if "/" in config["provider"] else "unknown", + }) + return {"object": "list", "data": data} + + +# ── 管理接口 ───────────────────────────────────────────────── @app.get("/stats") async def get_stats(): """获取调用统计摘要""" if not call_history: return { - "total_calls": 0, - "total_cost_usd": 0.0, - "avg_latency_ms": 0.0, - "model_distribution": {}, - "tier_distribution": {}, - "task_type_distribution": {}, + "total_calls": 0, "total_cost_usd": 0.0, + "avg_latency_ms": 0.0, "model_distribution": {}, + "tier_distribution": {}, "task_type_distribution": {}, } total_calls = len(call_history) @@ -348,14 +436,11 @@ async def get_stats(): for call in call_history: model = call["llm"]["model"] model_dist[model] = model_dist.get(model, 0) + 1 - routing = call.get("routing") or {} if routing.get("tier"): - tier = routing["tier"] - tier_dist[tier] = tier_dist.get(tier, 0) + 1 + tier_dist[routing["tier"]] = tier_dist.get(routing["tier"], 0) + 1 if routing.get("task_type"): - task = routing["task_type"] - task_dist[task] = task_dist.get(task, 0) + 1 + task_dist[routing["task_type"]] = task_dist.get(routing["task_type"], 0) + 1 return { "total_calls": total_calls, @@ -372,31 +457,16 @@ async def get_stats(): @app.get("/stats/raw") async def get_stats_raw(limit: int = 50, offset: int = 0): - """ - 获取原始调用记录(含路由分类细节 + LLM 完整数据) - 用于后续调优和分析 - - 参数: - - limit: 返回条数(默认50) - - offset: 偏移量(默认0,从最新开始) - """ + """获取原始调用记录""" total = len(call_history) - # 倒序返回(最新在前) records = list(reversed(call_history)) - page = records[offset:offset + limit] - - return { - "total": total, - "limit": limit, - "offset": offset, - "records": page, - } + return {"total": total, "limit": limit, "offset": offset, "records": records[offset:offset + limit]} @app.get("/health") async def health_check(): """健康检查""" - return {"status": "healthy", "version": "0.3.0", "router": "nvidia-3tier"} + return {"status": "healthy", "version": "0.4.0", "router": "nvidia-3tier"} if __name__ == "__main__":