Files
llm-compass/main.py
aszerW 4259478a37 feat: integrate LiteLLM for multi-provider support
使用 LiteLLM 统一接口支持多 LLM 提供商:
- 支持 OpenAI, Anthropic, Gemini, Ollama 等 100+ 提供商
- 统一模型配置 (MODEL_CONFIG)
- 新增 /models 端点列出可用模型
- 统计增加提供商分布
- 简化代码,移除 OpenAI 客户端初始化
2026-04-17 23:42:31 +08:00

256 lines
7.1 KiB
Python

"""
MVP版 LLM 路由服务
基于 LiteLLM 的多提供商统一接口
支持: OpenAI, Anthropic, Gemini, Ollama 等 100+ 提供商
"""
import time
from collections import Counter
from typing import List, Dict, Any, Optional

import tiktoken
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from litellm import acompletion

from config import MODEL_CONFIG, ROUTING_THRESHOLDS, DEFAULT_ROUTING
# In-memory call history: one dict per completed call, appended by log_call()
# and read by the /stats endpoint. Held in a module-level list, so it exists
# only for the lifetime of the process.
call_history: List[Dict[str, Any]] = []
class Message(BaseModel):
    # A single chat message; role and content are forwarded to LiteLLM as-is
    # and both are tokenized when estimating the prompt size.
    role: str
    content: str
class ChatRequest(BaseModel):
    # Request body for POST /v1/chat/completions.
    messages: List[Message]
    model: Optional[str] = None  # Optional; when provided, automatic routing is skipped
    temperature: float = 0.7
    max_tokens: Optional[int] = None  # passed through to LiteLLM unchanged (including None)
class ChatResponse(BaseModel):
    # Response body for POST /v1/chat/completions.
    id: str  # id taken from the LiteLLM response
    model: str  # the model key used for the call (requested or auto-routed)
    provider: str  # provider name inferred by get_provider_from_model()
    content: str  # message content of the first choice
    usage: Dict[str, int]  # keys: prompt_tokens, completion_tokens, total_tokens
    cost_usd: float  # cost computed by calculate_cost()
    latency_ms: float  # duration of the LiteLLM call, rounded to 2 decimals
class StatsResponse(BaseModel):
    # Response body for GET /stats, aggregated over call_history.
    total_calls: int  # number of recorded calls
    total_cost_usd: float  # sum of per-call cost, rounded to 6 decimals
    avg_latency_ms: float  # mean per-call latency, rounded to 2 decimals
    model_distribution: Dict[str, int]  # call count per model key
    provider_distribution: Dict[str, int]  # call count per provider
    recent_calls: List[Dict[str, Any]]  # at most the 10 most recent calls
# The FastAPI application object; all route decorators below attach to it.
# The version string here is repeated in the /health payload.
app = FastAPI(
    title="LLM Router MVP",
    description="基于 LiteLLM 的多提供商路由服务",
    version="0.2.0",
)
def estimate_tokens(messages: List[Message]) -> int:
    """Estimate how many tokens a list of chat messages will consume.

    The result is an approximation: every message is tokenized with the
    GPT-4 tokenizer regardless of which model is eventually called.

    Args:
        messages: Chat messages whose ``role`` and ``content`` are counted.

    Returns:
        4 tokens of overhead per message plus the encoded length of its role
        and content, plus a flat 2 tokens for the whole conversation.
    """
    try:
        encoding = tiktoken.encoding_for_model("gpt-4")
    except KeyError:
        # Model name could not be resolved; fall back to a named encoding.
        encoding = tiktoken.get_encoding("cl100k_base")

    def _count(text: str) -> int:
        # By default tiktoken raises ValueError when the text contains a
        # special-token string such as "<|endoftext|>". User-supplied text may
        # legitimately contain that, so encode it as ordinary text instead.
        return len(encoding.encode(text, disallowed_special=()))

    per_message = sum(
        4 + _count(msg.content) + _count(msg.role) for msg in messages
    )
    # NOTE(review): the flat +2 is assumed to be a one-off overhead applied
    # after the loop rather than per message -- confirm against the original.
    return per_message + 2
def select_model_by_length(messages: List[Message]) -> str:
    """Choose a model key from the estimated token length of the messages.

    The estimate is compared against the "simple" and then the "medium"
    threshold; the first tier whose threshold exceeds the estimate wins,
    and anything larger is routed to the "complex" tier.

    Args:
        messages: Chat messages to size up.

    Returns:
        The model key configured in DEFAULT_ROUTING for the selected tier.
    """
    estimated = estimate_tokens(messages)
    # Tiers are checked in ascending order; thresholds are exclusive bounds.
    for tier in ("simple", "medium"):
        if estimated < ROUTING_THRESHOLDS[tier]:
            return DEFAULT_ROUTING[tier]
    return DEFAULT_ROUTING["complex"]
def get_provider_model(model_key: str) -> str:
    """Look up the LiteLLM model name for a model key.

    Args:
        model_key: Key into MODEL_CONFIG.

    Returns:
        The value stored under "provider" in the key's config entry; the
        chat endpoint passes it to LiteLLM as the model name.

    Raises:
        HTTPException: 400 when the key has no (non-empty) config entry.
    """
    entry = MODEL_CONFIG.get(model_key)
    if entry:
        return entry["provider"]
    raise HTTPException(status_code=400, detail=f"Unknown model: {model_key}")
def calculate_cost(model_key: str, input_tokens: int, output_tokens: int) -> float:
    """Compute the cost of a call from its token usage.

    Prices are read from the model's MODEL_CONFIG entry ("input_cost" and
    "output_cost") and applied per 1000 tokens. Unknown model keys are
    priced with the "gpt-4o" entry.

    Args:
        model_key: Key into MODEL_CONFIG.
        input_tokens: Number of prompt tokens.
        output_tokens: Number of completion tokens.

    Returns:
        Input cost plus output cost.

    Raises:
        KeyError: If ``model_key`` is unknown and no "gpt-4o" entry exists.
    """
    try:
        config = MODEL_CONFIG[model_key]
    except KeyError:
        # Resolve the fallback only when it is actually needed, so a config
        # without a "gpt-4o" entry still works for every configured model.
        config = MODEL_CONFIG["gpt-4o"]

    input_cost = (input_tokens / 1000) * config["input_cost"]
    output_cost = (output_tokens / 1000) * config["output_cost"]
    return input_cost + output_cost
def get_provider_from_model(model_name: str) -> str:
    """Infer the provider name from a model name.

    Well-known name prefixes are checked first; otherwise, if the name
    contains a "/", the text before the first "/" is taken as the provider.

    Args:
        model_name: Model name, optionally in "provider/model" form.

    Returns:
        The provider name, or "unknown" when nothing matches.
    """
    # Checked in order; the first matching prefix decides the provider.
    known_prefixes = (
        ("gpt", "openai"),
        ("claude", "anthropic"),
        ("gemini", "google"),
    )
    for prefix, provider_name in known_prefixes:
        if model_name.startswith(prefix):
            return provider_name

    head, slash, _ = model_name.partition("/")
    return head if slash else "unknown"
def log_call(model: str, provider: str, cost: float, latency_ms: float, tokens: int):
    """Append one call record to the in-memory call history.

    Args:
        model: Model key used for the call.
        provider: Provider name of the call.
        cost: Cost of the call, stored under "cost_usd".
        latency_ms: Call latency in milliseconds.
        tokens: Total tokens consumed by the call.
    """
    record = dict(
        model=model,
        provider=provider,
        cost_usd=cost,
        latency_ms=latency_ms,
        tokens=tokens,
        timestamp=time.time(),  # wall-clock time at which the call was logged
    )
    call_history.append(record)
@app.post("/v1/chat/completions", response_model=ChatResponse)
async def chat_completions(request: ChatRequest):
    """Run a chat completion through LiteLLM and report usage, cost and latency.

    When ``request.model`` is not set, the model is routed automatically
    from the estimated token length of the messages.

    Args:
        request: Messages plus optional model key, temperature and max_tokens.

    Returns:
        A ChatResponse with the first choice's content, token usage,
        computed cost and the measured latency of the LiteLLM call.

    Raises:
        HTTPException: 400 for an unknown model key; 500 when the LiteLLM
            call or the handling of its response fails.
    """
    # Pick the model: an explicit request wins over automatic routing.
    model_key = request.model or select_model_by_length(request.messages)

    # Resolve the LiteLLM model name and the provider it belongs to.
    provider_model = get_provider_model(model_key)
    provider = get_provider_from_model(provider_model)

    # perf_counter() is monotonic, so the measured latency cannot be skewed
    # (or go negative) if the wall clock is adjusted mid-request.
    start_time = time.perf_counter()

    try:
        # Single unified call for every provider.
        response = await acompletion(
            model=provider_model,
            messages=[{"role": m.role, "content": m.content} for m in request.messages],
            temperature=request.temperature,
            max_tokens=request.max_tokens,
        )
        latency_ms = (time.perf_counter() - start_time) * 1000

        # Cost is derived from the usage reported by the provider.
        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        cost = calculate_cost(model_key, input_tokens, output_tokens)

        log_call(model_key, provider, cost, latency_ms, input_tokens + output_tokens)

        return ChatResponse(
            id=response.id,
            model=model_key,
            provider=provider,
            # ChatResponse.content must be a str; a missing (None) content
            # would otherwise fail validation and surface as a 500.
            # NOTE(review): presumably None occurs for non-text replies -- confirm.
            content=response.choices[0].message.content or "",
            usage={
                "prompt_tokens": input_tokens,
                "completion_tokens": output_tokens,
                "total_tokens": input_tokens + output_tokens,
            },
            cost_usd=cost,
            latency_ms=round(latency_ms, 2),
        )
    except HTTPException:
        # Keep the status code of any HTTP error raised while handling the call.
        raise
    except Exception as e:
        # Chain the cause so the original traceback stays available in logs.
        raise HTTPException(status_code=500, detail=f"API error: {str(e)}") from e
@app.get("/models")
async def list_models():
    """List every configured model.

    Returns:
        A dict with a "models" list; each item carries the model key, its
        "provider" config value and its input/output cost per 1k tokens.
    """
    catalog = [
        {
            "key": model_key,
            "provider": entry["provider"],
            "input_cost_per_1k": entry["input_cost"],
            "output_cost_per_1k": entry["output_cost"],
        }
        for model_key, entry in MODEL_CONFIG.items()
    ]
    return {"models": catalog}
@app.get("/stats", response_model=StatsResponse)
async def get_stats():
    """Summarise the in-memory call history.

    Returns:
        A StatsResponse with the number of calls, total cost, mean latency,
        per-model and per-provider call counts, and up to the 10 most
        recent calls. All aggregates are zero/empty when no call has been
        recorded yet.
    """
    if not call_history:
        # Nothing recorded yet; also avoids dividing by zero for the average.
        return StatsResponse(
            total_calls=0,
            total_cost_usd=0.0,
            avg_latency_ms=0.0,
            model_distribution={},
            provider_distribution={},
            recent_calls=[],
        )

    total_calls = len(call_history)
    total_cost = sum(c["cost_usd"] for c in call_history)
    avg_latency = sum(c["latency_ms"] for c in call_history) / total_calls

    # Call counts per model and per provider, keyed in first-seen order.
    model_dist = Counter(c["model"] for c in call_history)
    provider_dist = Counter(c["provider"] for c in call_history)

    # The 10 most recent records, without the timestamp field.
    recent = [
        {
            "model": c["model"],
            "provider": c["provider"],
            "cost_usd": round(c["cost_usd"], 6),
            "latency_ms": round(c["latency_ms"], 2),
            "tokens": c["tokens"],
        }
        for c in call_history[-10:]
    ]

    return StatsResponse(
        total_calls=total_calls,
        total_cost_usd=round(total_cost, 6),
        avg_latency_ms=round(avg_latency, 2),
        model_distribution=dict(model_dist),
        provider_distribution=dict(provider_dist),
        recent_calls=recent,
    )
@app.get("/health")
async def health_check():
    """Report that the service is up, together with its version string."""
    return dict(status="healthy", version="0.2.0")
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when this file is run directly.
    import uvicorn

    # Serve the app on port 8000; host 0.0.0.0 listens on all interfaces.
    uvicorn.run(app, host="0.0.0.0", port=8000)