From 4a8de8925e553e1102338656e4841d93a8e9a351 Mon Sep 17 00:00:00 2001 From: aszerW Date: Fri, 17 Apr 2026 23:33:43 +0800 Subject: [PATCH] feat: implement MVP LLM router service MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 实现基于 token 长度的简单规则路由服务: - FastAPI 基础服务 (/v1/chat/completions) - 根据 token 长度自动选择模型 (gpt-3.5/gpt-4o-mini/gpt-4o) - 成本追踪和统计 (/stats) - 健康检查端点 (/health) - 总计 224 行代码 --- .gitignore | 21 +++++ config.py | 35 ++++++++ main.py | 223 +++++++++++++++++++++++++++++++++++++++++++++++ requirements.txt | 8 ++ 4 files changed, 287 insertions(+) create mode 100644 .gitignore create mode 100644 config.py create mode 100644 main.py create mode 100644 requirements.txt diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e8b4a3d --- /dev/null +++ b/.gitignore @@ -0,0 +1,21 @@ +# Python +venv/ +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python + +# Environment +.env +.venv + +# IDE +.vscode/ +.idea/ +*.swp +*.swo + +# OS +.DS_Store +Thumbs.db diff --git a/config.py b/config.py new file mode 100644 index 0000000..3275d4e --- /dev/null +++ b/config.py @@ -0,0 +1,35 @@ +""" +简单配置管理 +""" +import os +from typing import Literal + + +# 模型配置 +MODEL_CONFIG = { + "gpt-3.5-turbo": { + "input_cost_per_1k": 0.0005, + "output_cost_per_1k": 0.0015, + "max_tokens": 4096, + }, + "gpt-4o-mini": { + "input_cost_per_1k": 0.00015, + "output_cost_per_1k": 0.0006, + "max_tokens": 128000, + }, + "gpt-4o": { + "input_cost_per_1k": 0.005, + "output_cost_per_1k": 0.015, + "max_tokens": 128000, + }, +} + +# 路由阈值 +ROUTING_THRESHOLDS = { + "simple": 100, # < 100 tokens -> gpt-3.5-turbo + "medium": 500, # < 500 tokens -> gpt-4o-mini + # >= 500 tokens -> gpt-4o +} + +# API Key +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "") diff --git a/main.py b/main.py new file mode 100644 index 0000000..9f90382 --- /dev/null +++ b/main.py @@ -0,0 +1,223 @@ +""" +MVP版 LLM 路由服务 +基于 token 长度的简单规则路由 +""" +import time +import tiktoken +from typing import List, Dict, Any, Optional +from contextlib import asynccontextmanager + +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel, Field +from openai import AsyncOpenAI + +from config import MODEL_CONFIG, ROUTING_THRESHOLDS, OPENAI_API_KEY + + +# 调用历史记录 +call_history: List[Dict[str, Any]] = [] + + +class Message(BaseModel): + role: str + content: str + + +class ChatRequest(BaseModel): + messages: List[Message] + model: Optional[str] = None # 可选,如果指定则跳过路由 + temperature: float = 0.7 + max_tokens: Optional[int] = None + + +class ChatResponse(BaseModel): + id: str + model: str + content: str + usage: Dict[str, int] + cost_usd: float + latency_ms: float + + +class StatsResponse(BaseModel): + total_calls: int + total_cost_usd: float + avg_latency_ms: float + model_distribution: Dict[str, int] + recent_calls: List[Dict[str, Any]] + + +# 初始化 OpenAI 客户端 +client: Optional[AsyncOpenAI] = None + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """应用生命周期管理""" + global client + if not OPENAI_API_KEY: + raise RuntimeError("OPENAI_API_KEY environment variable is required") + client = AsyncOpenAI(api_key=OPENAI_API_KEY) + yield + client = None + + +app = FastAPI( + title="LLM Router MVP", + description="基于 token 长度的简单规则路由服务", + version="0.1.0", + lifespan=lifespan, +) + + +def estimate_tokens(messages: List[Message]) -> int: + """估算 token 数量""" + try: + encoding = tiktoken.encoding_for_model("gpt-4") + except KeyError: + encoding = tiktoken.get_encoding("cl100k_base") + + total_tokens = 0 + for msg in messages: + total_tokens += 4 # 每条消息的开销 + total_tokens += len(encoding.encode(msg.content)) + total_tokens += len(encoding.encode(msg.role)) + total_tokens += 2 # 回复的开销 + return total_tokens + + +def select_model_by_length(messages: List[Message]) -> str: + """基于 token 长度选择模型""" + token_count = estimate_tokens(messages) + + if token_count < ROUTING_THRESHOLDS["simple"]: + return "gpt-3.5-turbo" + elif token_count < ROUTING_THRESHOLDS["medium"]: + return "gpt-4o-mini" + else: + return "gpt-4o" + + +def calculate_cost(model: str, input_tokens: int, output_tokens: int) -> float: + """计算调用成本""" + config = MODEL_CONFIG.get(model, MODEL_CONFIG["gpt-4o"]) + input_cost = (input_tokens / 1000) * config["input_cost_per_1k"] + output_cost = (output_tokens / 1000) * config["output_cost_per_1k"] + return input_cost + output_cost + + +def log_call(model: str, cost: float, latency_ms: float, tokens: int): + """记录调用历史""" + call_history.append({ + "model": model, + "cost_usd": cost, + "latency_ms": latency_ms, + "tokens": tokens, + "timestamp": time.time(), + }) + + +@app.post("/v1/chat/completions", response_model=ChatResponse) +async def chat_completions(request: ChatRequest): + """ + 聊天完成接口 + 如果 request.model 未指定,则根据 token 长度自动路由 + """ + if client is None: + raise HTTPException(status_code=500, detail="OpenAI client not initialized") + + # 选择模型 + if request.model: + model = request.model + else: + model = select_model_by_length(request.messages) + + start_time = time.time() + + try: + # 调用 OpenAI + response = await client.chat.completions.create( + model=model, + messages=[{"role": m.role, "content": m.content} for m in request.messages], + temperature=request.temperature, + max_tokens=request.max_tokens, + ) + + latency_ms = (time.time() - start_time) * 1000 + + # 计算成本 + input_tokens = response.usage.prompt_tokens + output_tokens = response.usage.completion_tokens + cost = calculate_cost(model, input_tokens, output_tokens) + + # 记录调用 + log_call(model, cost, latency_ms, input_tokens + output_tokens) + + return ChatResponse( + id=response.id, + model=model, + content=response.choices[0].message.content, + usage={ + "prompt_tokens": input_tokens, + "completion_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, + }, + cost_usd=cost, + latency_ms=round(latency_ms, 2), + ) + + except Exception as e: + raise HTTPException(status_code=500, detail=f"OpenAI API error: {str(e)}") + + +@app.get("/stats", response_model=StatsResponse) +async def get_stats(): + """获取调用统计""" + if not call_history: + return StatsResponse( + total_calls=0, + total_cost_usd=0.0, + avg_latency_ms=0.0, + model_distribution={}, + recent_calls=[], + ) + + total_calls = len(call_history) + total_cost = sum(c["cost_usd"] for c in call_history) + avg_latency = sum(c["latency_ms"] for c in call_history) / total_calls + + # 模型分布 + model_dist: Dict[str, int] = {} + for call in call_history: + model = call["model"] + model_dist[model] = model_dist.get(model, 0) + 1 + + # 最近 10 条记录 + recent = [ + { + "model": c["model"], + "cost_usd": round(c["cost_usd"], 6), + "latency_ms": round(c["latency_ms"], 2), + "tokens": c["tokens"], + } + for c in call_history[-10:] + ] + + return StatsResponse( + total_calls=total_calls, + total_cost_usd=round(total_cost, 6), + avg_latency_ms=round(avg_latency, 2), + model_distribution=model_dist, + recent_calls=recent, + ) + + +@app.get("/health") +async def health_check(): + """健康检查""" + return {"status": "healthy", "client_initialized": client is not None} + + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..c9d22f4 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +fastapi>=0.104.0 +uvicorn[standard]>=0.24.0 +pydantic>=2.5.0 +openai>=1.6.0 +tiktoken>=0.5.0 +httpx>=0.25.0 +pytest>=7.4.0 +pytest-asyncio>=0.21.0