feat: implement MVP LLM router service
实现基于 token 长度的简单规则路由服务: - FastAPI 基础服务 (/v1/chat/completions) - 根据 token 长度自动选择模型 (gpt-3.5/gpt-4o-mini/gpt-4o) - 成本追踪和统计 (/stats) - 健康检查端点 (/health) - 总计 224 行代码
This commit is contained in:
21
.gitignore
vendored
Normal file
21
.gitignore
vendored
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# Python
|
||||||
|
venv/
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.so
|
||||||
|
.Python
|
||||||
|
|
||||||
|
# Environment
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
|
||||||
|
# OS
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
35
config.py
Normal file
35
config.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""
|
||||||
|
简单配置管理
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
|
||||||
|
# 模型配置
|
||||||
|
MODEL_CONFIG = {
|
||||||
|
"gpt-3.5-turbo": {
|
||||||
|
"input_cost_per_1k": 0.0005,
|
||||||
|
"output_cost_per_1k": 0.0015,
|
||||||
|
"max_tokens": 4096,
|
||||||
|
},
|
||||||
|
"gpt-4o-mini": {
|
||||||
|
"input_cost_per_1k": 0.00015,
|
||||||
|
"output_cost_per_1k": 0.0006,
|
||||||
|
"max_tokens": 128000,
|
||||||
|
},
|
||||||
|
"gpt-4o": {
|
||||||
|
"input_cost_per_1k": 0.005,
|
||||||
|
"output_cost_per_1k": 0.015,
|
||||||
|
"max_tokens": 128000,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# 路由阈值
|
||||||
|
ROUTING_THRESHOLDS = {
|
||||||
|
"simple": 100, # < 100 tokens -> gpt-3.5-turbo
|
||||||
|
"medium": 500, # < 500 tokens -> gpt-4o-mini
|
||||||
|
# >= 500 tokens -> gpt-4o
|
||||||
|
}
|
||||||
|
|
||||||
|
# API Key
|
||||||
|
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
||||||
223
main.py
Normal file
223
main.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
"""
|
||||||
|
MVP版 LLM 路由服务
|
||||||
|
基于 token 长度的简单规则路由
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
import tiktoken
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
from config import MODEL_CONFIG, ROUTING_THRESHOLDS, OPENAI_API_KEY
|
||||||
|
|
||||||
|
|
||||||
|
# 调用历史记录
|
||||||
|
call_history: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
|
||||||
|
class Message(BaseModel):
|
||||||
|
role: str
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRequest(BaseModel):
|
||||||
|
messages: List[Message]
|
||||||
|
model: Optional[str] = None # 可选,如果指定则跳过路由
|
||||||
|
temperature: float = 0.7
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
model: str
|
||||||
|
content: str
|
||||||
|
usage: Dict[str, int]
|
||||||
|
cost_usd: float
|
||||||
|
latency_ms: float
|
||||||
|
|
||||||
|
|
||||||
|
class StatsResponse(BaseModel):
|
||||||
|
total_calls: int
|
||||||
|
total_cost_usd: float
|
||||||
|
avg_latency_ms: float
|
||||||
|
model_distribution: Dict[str, int]
|
||||||
|
recent_calls: List[Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
# 初始化 OpenAI 客户端
|
||||||
|
client: Optional[AsyncOpenAI] = None
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""应用生命周期管理"""
|
||||||
|
global client
|
||||||
|
if not OPENAI_API_KEY:
|
||||||
|
raise RuntimeError("OPENAI_API_KEY environment variable is required")
|
||||||
|
client = AsyncOpenAI(api_key=OPENAI_API_KEY)
|
||||||
|
yield
|
||||||
|
client = None
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="LLM Router MVP",
|
||||||
|
description="基于 token 长度的简单规则路由服务",
|
||||||
|
version="0.1.0",
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_tokens(messages: List[Message]) -> int:
|
||||||
|
"""估算 token 数量"""
|
||||||
|
try:
|
||||||
|
encoding = tiktoken.encoding_for_model("gpt-4")
|
||||||
|
except KeyError:
|
||||||
|
encoding = tiktoken.get_encoding("cl100k_base")
|
||||||
|
|
||||||
|
total_tokens = 0
|
||||||
|
for msg in messages:
|
||||||
|
total_tokens += 4 # 每条消息的开销
|
||||||
|
total_tokens += len(encoding.encode(msg.content))
|
||||||
|
total_tokens += len(encoding.encode(msg.role))
|
||||||
|
total_tokens += 2 # 回复的开销
|
||||||
|
return total_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def select_model_by_length(messages: List[Message]) -> str:
|
||||||
|
"""基于 token 长度选择模型"""
|
||||||
|
token_count = estimate_tokens(messages)
|
||||||
|
|
||||||
|
if token_count < ROUTING_THRESHOLDS["simple"]:
|
||||||
|
return "gpt-3.5-turbo"
|
||||||
|
elif token_count < ROUTING_THRESHOLDS["medium"]:
|
||||||
|
return "gpt-4o-mini"
|
||||||
|
else:
|
||||||
|
return "gpt-4o"
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
|
||||||
|
"""计算调用成本"""
|
||||||
|
config = MODEL_CONFIG.get(model, MODEL_CONFIG["gpt-4o"])
|
||||||
|
input_cost = (input_tokens / 1000) * config["input_cost_per_1k"]
|
||||||
|
output_cost = (output_tokens / 1000) * config["output_cost_per_1k"]
|
||||||
|
return input_cost + output_cost
|
||||||
|
|
||||||
|
|
||||||
|
def log_call(model: str, cost: float, latency_ms: float, tokens: int):
|
||||||
|
"""记录调用历史"""
|
||||||
|
call_history.append({
|
||||||
|
"model": model,
|
||||||
|
"cost_usd": cost,
|
||||||
|
"latency_ms": latency_ms,
|
||||||
|
"tokens": tokens,
|
||||||
|
"timestamp": time.time(),
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/chat/completions", response_model=ChatResponse)
|
||||||
|
async def chat_completions(request: ChatRequest):
|
||||||
|
"""
|
||||||
|
聊天完成接口
|
||||||
|
如果 request.model 未指定,则根据 token 长度自动路由
|
||||||
|
"""
|
||||||
|
if client is None:
|
||||||
|
raise HTTPException(status_code=500, detail="OpenAI client not initialized")
|
||||||
|
|
||||||
|
# 选择模型
|
||||||
|
if request.model:
|
||||||
|
model = request.model
|
||||||
|
else:
|
||||||
|
model = select_model_by_length(request.messages)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 调用 OpenAI
|
||||||
|
response = await client.chat.completions.create(
|
||||||
|
model=model,
|
||||||
|
messages=[{"role": m.role, "content": m.content} for m in request.messages],
|
||||||
|
temperature=request.temperature,
|
||||||
|
max_tokens=request.max_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
latency_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
|
# 计算成本
|
||||||
|
input_tokens = response.usage.prompt_tokens
|
||||||
|
output_tokens = response.usage.completion_tokens
|
||||||
|
cost = calculate_cost(model, input_tokens, output_tokens)
|
||||||
|
|
||||||
|
# 记录调用
|
||||||
|
log_call(model, cost, latency_ms, input_tokens + output_tokens)
|
||||||
|
|
||||||
|
return ChatResponse(
|
||||||
|
id=response.id,
|
||||||
|
model=model,
|
||||||
|
content=response.choices[0].message.content,
|
||||||
|
usage={
|
||||||
|
"prompt_tokens": input_tokens,
|
||||||
|
"completion_tokens": output_tokens,
|
||||||
|
"total_tokens": input_tokens + output_tokens,
|
||||||
|
},
|
||||||
|
cost_usd=cost,
|
||||||
|
latency_ms=round(latency_ms, 2),
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"OpenAI API error: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/stats", response_model=StatsResponse)
|
||||||
|
async def get_stats():
|
||||||
|
"""获取调用统计"""
|
||||||
|
if not call_history:
|
||||||
|
return StatsResponse(
|
||||||
|
total_calls=0,
|
||||||
|
total_cost_usd=0.0,
|
||||||
|
avg_latency_ms=0.0,
|
||||||
|
model_distribution={},
|
||||||
|
recent_calls=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
total_calls = len(call_history)
|
||||||
|
total_cost = sum(c["cost_usd"] for c in call_history)
|
||||||
|
avg_latency = sum(c["latency_ms"] for c in call_history) / total_calls
|
||||||
|
|
||||||
|
# 模型分布
|
||||||
|
model_dist: Dict[str, int] = {}
|
||||||
|
for call in call_history:
|
||||||
|
model = call["model"]
|
||||||
|
model_dist[model] = model_dist.get(model, 0) + 1
|
||||||
|
|
||||||
|
# 最近 10 条记录
|
||||||
|
recent = [
|
||||||
|
{
|
||||||
|
"model": c["model"],
|
||||||
|
"cost_usd": round(c["cost_usd"], 6),
|
||||||
|
"latency_ms": round(c["latency_ms"], 2),
|
||||||
|
"tokens": c["tokens"],
|
||||||
|
}
|
||||||
|
for c in call_history[-10:]
|
||||||
|
]
|
||||||
|
|
||||||
|
return StatsResponse(
|
||||||
|
total_calls=total_calls,
|
||||||
|
total_cost_usd=round(total_cost, 6),
|
||||||
|
avg_latency_ms=round(avg_latency, 2),
|
||||||
|
model_distribution=model_dist,
|
||||||
|
recent_calls=recent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check():
|
||||||
|
"""健康检查"""
|
||||||
|
return {"status": "healthy", "client_initialized": client is not None}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||||
8
requirements.txt
Normal file
8
requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
fastapi>=0.104.0
|
||||||
|
uvicorn[standard]>=0.24.0
|
||||||
|
pydantic>=2.5.0
|
||||||
|
openai>=1.6.0
|
||||||
|
tiktoken>=0.5.0
|
||||||
|
httpx>=0.25.0
|
||||||
|
pytest>=7.4.0
|
||||||
|
pytest-asyncio>=0.21.0
|
||||||
Reference in New Issue
Block a user