Files
llm-compass/main.py
aszerW 4a8de8925e feat: implement MVP LLM router service
实现基于 token 长度的简单规则路由服务:
- FastAPI 基础服务 (/v1/chat/completions)
- 根据 token 长度自动选择模型 (gpt-3.5/gpt-4o-mini/gpt-4o)
- 成本追踪和统计 (/stats)
- 健康检查端点 (/health)
- 总计 224 行代码
2026-04-17 23:33:43 +08:00

224 lines
6.0 KiB
Python

"""
MVP LLM routing service.

Routes chat-completion requests with a simple rule based on prompt token length.
"""
import time
import tiktoken
from typing import List, Dict, Any, Optional
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from openai import AsyncOpenAI
from config import MODEL_CONFIG, ROUTING_THRESHOLDS, OPENAI_API_KEY
# In-memory log of completed calls: appended to by log_call(), read by /stats.
# NOTE(review): entries are never removed, so this grows for the process lifetime.
call_history: List[Dict[str, Any]] = list()
class Message(BaseModel):
    """One chat message; forwarded to OpenAI as a {"role": ..., "content": ...} dict."""

    role: str
    content: str
class ChatRequest(BaseModel):
    """Request body for ``/v1/chat/completions``.

    Numeric sampling parameters are range-checked here so that invalid
    input is rejected as a validation error (422) instead of being sent
    upstream and surfacing as a generic 500.
    """

    messages: List[Message]
    # Optional: when set, routing is skipped and this model is used as-is.
    model: Optional[str] = None
    # OpenAI's API documents temperature as a value between 0 and 2.
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    # Upper bound on generated tokens; None is forwarded to the API unchanged.
    max_tokens: Optional[int] = Field(default=None, gt=0)
class ChatResponse(BaseModel):
    """Response body for ``/v1/chat/completions``."""

    # Completion id as returned by OpenAI.
    id: str
    # Model the request was actually sent to (explicitly requested or routed).
    model: str
    # Text of the first returned choice.
    content: str
    # Token counts with keys "prompt_tokens", "completion_tokens", "total_tokens".
    usage: Dict[str, int]
    # Cost computed from the per-1k-token prices in MODEL_CONFIG.
    cost_usd: float
    # Duration of the upstream call in milliseconds (rounded to 2 decimals).
    latency_ms: float
class StatsResponse(BaseModel):
    """Response body for ``/stats``: aggregates over the in-memory call history."""

    total_calls: int
    total_cost_usd: float
    # Mean latency across all recorded calls, in milliseconds.
    avg_latency_ms: float
    # Number of recorded calls per model name.
    model_distribution: Dict[str, int]
    # Compact views of the most recent records (at most 10), oldest first.
    recent_calls: List[Dict[str, Any]]
# Shared async OpenAI client. It stays None until `lifespan` creates it at
# startup, and is reset to None at shutdown.
client: Optional[AsyncOpenAI]
client = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the shared OpenAI client across the application's lifetime.

    On startup, checks ``OPENAI_API_KEY`` and creates the module-level
    ``client``. On shutdown (including when the app exits with an error),
    clears ``client`` and closes the AsyncOpenAI instance so its HTTP
    connections are released.

    Raises:
        RuntimeError: at startup if ``OPENAI_API_KEY`` is falsy (e.g. empty).
    """
    global client
    if not OPENAI_API_KEY:
        raise RuntimeError("OPENAI_API_KEY environment variable is required")
    openai_client = AsyncOpenAI(api_key=OPENAI_API_KEY)
    client = openai_client
    try:
        yield
    finally:
        # Unpublish first so no new request picks up a client that is closing.
        # Then close it: dropping the reference alone leaves its HTTP
        # connection pool open.
        client = None
        await openai_client.close()
# The FastAPI application; `lifespan` owns the OpenAI client's startup/shutdown.
app = FastAPI(
    lifespan=lifespan,
    version="0.1.0",
    title="LLM Router MVP",
    description="基于 token 长度的简单规则路由服务",
)
def estimate_tokens(messages: List[Message]) -> int:
    """Approximate the prompt token count of a list of chat messages.

    Uses tiktoken's encoding for "gpt-4", falling back to "cl100k_base" if
    tiktoken does not recognize that model name.

    Each message contributes:
    - a fixed overhead of 4 tokens;
    - the encoded length of its content;
    - the encoded length of its role.

    A further 2 tokens are added for the assistant's reply.
    """
    try:
        enc = tiktoken.encoding_for_model("gpt-4")
    except KeyError:
        enc = tiktoken.get_encoding("cl100k_base")

    per_message_overhead = 4
    reply_overhead = 2
    return reply_overhead + sum(
        per_message_overhead + len(enc.encode(m.content)) + len(enc.encode(m.role))
        for m in messages
    )
def select_model_by_length(messages: List[Message]) -> str:
    """Choose a model name from the estimated prompt size.

    Tiers are checked in order:
    - fewer than ROUTING_THRESHOLDS["simple"] tokens -> "gpt-3.5-turbo";
    - otherwise, fewer than ROUTING_THRESHOLDS["medium"] tokens -> "gpt-4o-mini";
    - otherwise -> "gpt-4o".
    """
    tokens = estimate_tokens(messages)
    tiers = (("simple", "gpt-3.5-turbo"), ("medium", "gpt-4o-mini"))
    for threshold_key, candidate in tiers:
        if tokens < ROUTING_THRESHOLDS[threshold_key]:
            return candidate
    return "gpt-4o"
def calculate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
    """Return the USD cost of one call from per-1k-token prices in MODEL_CONFIG.

    Models missing from MODEL_CONFIG are priced using the "gpt-4o" entry.
    """
    rates = MODEL_CONFIG.get(model, MODEL_CONFIG["gpt-4o"])
    return (
        input_tokens / 1000 * rates["input_cost_per_1k"]
        + output_tokens / 1000 * rates["output_cost_per_1k"]
    )
def log_call(model: str, cost: float, latency_ms: float, tokens: int):
    """Append one call record (with the current wall-clock timestamp) to call_history.

    NOTE(review): records are never trimmed, so memory use grows with the
    number of calls served by this process.
    """
    record = dict(
        model=model,
        cost_usd=cost,
        latency_ms=latency_ms,
        tokens=tokens,
        timestamp=time.time(),
    )
    call_history.append(record)
@app.post("/v1/chat/completions", response_model=ChatResponse)
async def chat_completions(request: ChatRequest):
    """Chat-completion endpoint with optional automatic model routing.

    If ``request.model`` is set it is used as-is; otherwise the model is
    chosen by ``select_model_by_length``. The request is forwarded to OpenAI,
    then the call's cost and latency are computed and recorded via
    ``log_call``.

    Raises:
        HTTPException(500): if the OpenAI client is not initialized, if the
            upstream OpenAI call fails, or if it returns no choices.
    """
    if client is None:
        raise HTTPException(status_code=500, detail="OpenAI client not initialized")

    # An explicitly requested model bypasses routing.
    model = request.model or select_model_by_length(request.messages)

    # perf_counter is monotonic, so measured latency is immune to clock adjustments.
    start = time.perf_counter()
    try:
        response = await client.chat.completions.create(
            model=model,
            messages=[{"role": m.role, "content": m.content} for m in request.messages],
            temperature=request.temperature,
            max_tokens=request.max_tokens,
        )
    except Exception as e:
        # Only the upstream call is wrapped, so local bookkeeping failures are
        # not misreported as OpenAI errors; chain the cause for server logs.
        raise HTTPException(status_code=500, detail=f"OpenAI API error: {str(e)}") from e
    latency_ms = (time.perf_counter() - start) * 1000

    if not response.choices:
        raise HTTPException(
            status_code=500, detail="OpenAI API error: response contained no choices"
        )

    usage = response.usage
    if usage is not None:
        input_tokens = usage.prompt_tokens
        output_tokens = usage.completion_tokens
    else:
        # NOTE(review): the SDK types `usage` as optional. Fall back to the
        # local prompt estimate so the call is not recorded as free.
        # Completion tokens are unknown in this case.
        input_tokens = estimate_tokens(request.messages)
        output_tokens = 0
    total_tokens = input_tokens + output_tokens

    cost = calculate_cost(model, input_tokens, output_tokens)
    log_call(model, cost, latency_ms, total_tokens)

    # The SDK types message content as optional, but ChatResponse.content is a
    # required str.
    content = response.choices[0].message.content or ""

    return ChatResponse(
        id=response.id,
        model=model,
        content=content,
        usage={
            "prompt_tokens": input_tokens,
            "completion_tokens": output_tokens,
            "total_tokens": total_tokens,
        },
        cost_usd=cost,
        latency_ms=round(latency_ms, 2),
    )
@app.get("/stats", response_model=StatsResponse)
async def get_stats():
    """Return aggregate statistics over the in-memory call history.

    Covers every recorded call:
    - the call count;
    - the summed cost (rounded to 6 decimals);
    - the mean latency (rounded to 2 decimals);
    - a per-model call count.

    ``recent_calls`` holds a compact view of the last 10 records. With an
    empty history, zero/empty values are returned.
    """
    if not call_history:
        return StatsResponse(
            total_calls=0,
            total_cost_usd=0.0,
            avg_latency_ms=0.0,
            model_distribution={},
            recent_calls=[],
        )

    n_calls = len(call_history)
    cost_total = sum(entry["cost_usd"] for entry in call_history)
    latency_mean = sum(entry["latency_ms"] for entry in call_history) / n_calls

    # Calls per model, keyed in first-seen order.
    per_model: Dict[str, int] = {}
    for name in (entry["model"] for entry in call_history):
        per_model[name] = per_model.get(name, 0) + 1

    def _summarize(entry: Dict[str, Any]) -> Dict[str, Any]:
        # Compact, rounded view of one record for the "recent calls" list.
        return {
            "model": entry["model"],
            "cost_usd": round(entry["cost_usd"], 6),
            "latency_ms": round(entry["latency_ms"], 2),
            "tokens": entry["tokens"],
        }

    return StatsResponse(
        total_calls=n_calls,
        total_cost_usd=round(cost_total, 6),
        avg_latency_ms=round(latency_mean, 2),
        model_distribution=per_model,
        recent_calls=[_summarize(entry) for entry in call_history[-10:]],
    )
@app.get("/health")
async def health_check():
    """Health-check endpoint: always "healthy", plus whether the OpenAI client is set."""
    initialized = client is not None
    return {"status": "healthy", "client_initialized": initialized}
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on 0.0.0.0:8000.
    # uvicorn is imported here because it is only needed when run as a script.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")