feat: integrate LiteLLM for multi-provider support

使用 LiteLLM 统一接口支持多 LLM 提供商:
- 支持 OpenAI, Anthropic, Gemini, Ollama 等 100+ 提供商
- 统一模型配置 (MODEL_CONFIG)
- 新增 /models 端点列出可用模型
- 统计增加提供商分布
- 简化代码,移除 OpenAI 客户端初始化
This commit is contained in:
2026-04-17 23:42:31 +08:00
parent 2380dd4617
commit 4259478a37
4 changed files with 117 additions and 69 deletions

View File

@@ -1,5 +1,14 @@
# OpenAI API Key
OPENAI_API_KEY=sk-your-api-key-here
OPENAI_API_KEY=sk-your-openai-key-here
# Anthropic API Key (Claude)
ANTHROPIC_API_KEY=sk-ant-your-anthropic-key-here
# Google API Key (Gemini)
GEMINI_API_KEY=your-gemini-key-here
# Ollama (本地模型,不需要 API Key)
# OLLAMA_HOST=http://localhost:11434
# 可选:自定义路由阈值
# ROUTE_SIMPLE_THRESHOLD=100

View File

@@ -8,31 +8,38 @@ from dotenv import load_dotenv
# 加载 .env 文件
load_dotenv()
# 模型配置
# 统一模型配置(支持多提供商)
# 格式: "统一模型名": {"provider": "litellm格式", "input_cost": x, "output_cost": y}
MODEL_CONFIG = {
"gpt-3.5-turbo": {
"input_cost_per_1k": 0.0005,
"output_cost_per_1k": 0.0015,
"max_tokens": 4096,
},
"gpt-4o-mini": {
"input_cost_per_1k": 0.00015,
"output_cost_per_1k": 0.0006,
"max_tokens": 128000,
},
"gpt-4o": {
"input_cost_per_1k": 0.005,
"output_cost_per_1k": 0.015,
"max_tokens": 128000,
},
# OpenAI
"gpt-3.5": {"provider": "gpt-3.5-turbo", "input_cost": 0.0005, "output_cost": 0.0015},
"gpt-4o-mini": {"provider": "gpt-4o-mini", "input_cost": 0.00015, "output_cost": 0.0006},
"gpt-4o": {"provider": "gpt-4o", "input_cost": 0.005, "output_cost": 0.015},
# Anthropic
"claude-3-haiku": {"provider": "claude-3-haiku-20240307", "input_cost": 0.00025, "output_cost": 0.00125},
"claude-3-sonnet": {"provider": "claude-3-sonnet-20240229", "input_cost": 0.003, "output_cost": 0.015},
"claude-3-opus": {"provider": "claude-3-opus-20240229", "input_cost": 0.015, "output_cost": 0.075},
# Gemini
"gemini-flash": {"provider": "gemini/gemini-1.5-flash", "input_cost": 0.000075, "output_cost": 0.0003},
"gemini-pro": {"provider": "gemini/gemini-1.5-pro", "input_cost": 0.00125, "output_cost": 0.005},
# 本地/开源
"llama3": {"provider": "ollama/llama3", "input_cost": 0, "output_cost": 0},
}
# 路由阈值
# 路由阈值(token 数 -> 推荐模型)
ROUTING_THRESHOLDS = {
"simple": 100, # < 100 tokens -> gpt-3.5-turbo
"medium": 500, # < 500 tokens -> gpt-4o-mini
# >= 500 tokens -> gpt-4o
"simple": 100, # < 100 tokens
"medium": 500, # < 500 tokens
}
# API Key
# Default model selection strategy: maps each routing tier to a MODEL_CONFIG key.
# The trailing comments list alternative MODEL_CONFIG keys that could be used instead.
DEFAULT_ROUTING = {
"simple": "gpt-3.5", # or "claude-3-haiku", "gemini-flash"
"medium": "gpt-4o-mini", # or "claude-3-haiku"
"complex": "gpt-4o", # or "claude-3-sonnet", "gemini-pro"
}
# API Keys(litellm 自动读取环境变量)
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")

124
main.py
View File

@@ -1,17 +1,17 @@
"""
MVP版 LLM 路由服务
基于 token 长度的简单规则路由
基于 LiteLLM 的多提供商统一接口
支持: OpenAI, Anthropic, Gemini, Ollama 等 100+ 提供商
"""
import time
import tiktoken
from typing import List, Dict, Any, Optional
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from openai import AsyncOpenAI
from pydantic import BaseModel
from litellm import acompletion
from config import MODEL_CONFIG, ROUTING_THRESHOLDS, OPENAI_API_KEY
from config import MODEL_CONFIG, ROUTING_THRESHOLDS, DEFAULT_ROUTING
# 调用历史记录
@@ -33,6 +33,7 @@ class ChatRequest(BaseModel):
class ChatResponse(BaseModel):
id: str
model: str
provider: str
content: str
usage: Dict[str, int]
cost_usd: float
@@ -44,29 +45,14 @@ class StatsResponse(BaseModel):
total_cost_usd: float
avg_latency_ms: float
model_distribution: Dict[str, int]
provider_distribution: Dict[str, int]
recent_calls: List[Dict[str, Any]]
# 初始化 OpenAI 客户端
client: Optional[AsyncOpenAI] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理"""
global client
if not OPENAI_API_KEY:
raise RuntimeError("OPENAI_API_KEY environment variable is required")
client = AsyncOpenAI(api_key=OPENAI_API_KEY)
yield
client = None
app = FastAPI(
title="LLM Router MVP",
description="基于 token 长度的简单规则路由服务",
version="0.1.0",
lifespan=lifespan,
description="基于 LiteLLM 的多提供商路由服务",
version="0.2.0",
)
@@ -79,10 +65,10 @@ def estimate_tokens(messages: List[Message]) -> int:
total_tokens = 0
for msg in messages:
total_tokens += 4 # 每条消息的开销
total_tokens += 4
total_tokens += len(encoding.encode(msg.content))
total_tokens += len(encoding.encode(msg.role))
total_tokens += 2 # 回复的开销
total_tokens += 2
return total_tokens
@@ -91,25 +77,47 @@ def select_model_by_length(messages: List[Message]) -> str:
token_count = estimate_tokens(messages)
if token_count < ROUTING_THRESHOLDS["simple"]:
return "gpt-3.5-turbo"
return DEFAULT_ROUTING["simple"]
elif token_count < ROUTING_THRESHOLDS["medium"]:
return "gpt-4o-mini"
return DEFAULT_ROUTING["medium"]
else:
return "gpt-4o"
return DEFAULT_ROUTING["complex"]
def calculate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
def get_provider_model(model_key: str) -> str:
    """Translate a unified model key into its LiteLLM model name.

    Args:
        model_key: A key of ``MODEL_CONFIG`` (e.g. ``"gpt-4o"``).

    Returns:
        The ``"provider"`` value stored for that key, which is the model
        string passed to LiteLLM.

    Raises:
        HTTPException: With status 400 when the key has no (non-empty)
            entry in ``MODEL_CONFIG``.
    """
    entry = MODEL_CONFIG.get(model_key)
    if entry:
        return entry["provider"]
    # Missing or empty entry: report it as a client error, not a server fault.
    raise HTTPException(status_code=400, detail=f"Unknown model: {model_key}")
def calculate_cost(model_key: str, input_tokens: int, output_tokens: int) -> float:
"""计算调用成本"""
config = MODEL_CONFIG.get(model, MODEL_CONFIG["gpt-4o"])
input_cost = (input_tokens / 1000) * config["input_cost_per_1k"]
output_cost = (output_tokens / 1000) * config["output_cost_per_1k"]
config = MODEL_CONFIG.get(model_key, MODEL_CONFIG["gpt-4o"])
input_cost = (input_tokens / 1000) * config["input_cost"]
output_cost = (output_tokens / 1000) * config["output_cost"]
return input_cost + output_cost
def log_call(model: str, cost: float, latency_ms: float, tokens: int):
def get_provider_from_model(model_name: str) -> str:
    """Infer a provider label from a model name.

    Well-known name prefixes are checked first; otherwise, a name of the
    form ``"<prefix>/<model>"`` yields ``<prefix>``.

    Args:
        model_name: The model name to classify.

    Returns:
        ``"openai"``, ``"anthropic"`` or ``"google"`` for names starting with
        ``gpt``, ``claude`` or ``gemini`` respectively; the text before the
        first ``/`` for other slash-qualified names; ``"unknown"`` otherwise.
    """
    # Order matters: these prefixes win over the generic "<prefix>/" rule,
    # so "gemini/..." is reported as "google" rather than "gemini".
    known_prefixes = (
        ("gpt", "openai"),
        ("claude", "anthropic"),
        ("gemini", "google"),
    )
    for prefix, provider in known_prefixes:
        if model_name.startswith(prefix):
            return provider

    head, slash, _rest = model_name.partition("/")
    return head if slash else "unknown"
def log_call(model: str, provider: str, cost: float, latency_ms: float, tokens: int):
"""记录调用历史"""
call_history.append({
"model": model,
"provider": provider,
"cost_usd": cost,
"latency_ms": latency_ms,
"tokens": tokens,
@@ -123,21 +131,22 @@ async def chat_completions(request: ChatRequest):
聊天完成接口
如果 request.model 未指定,则根据 token 长度自动路由
"""
if client is None:
raise HTTPException(status_code=500, detail="OpenAI client not initialized")
# 选择模型
if request.model:
model = request.model
model_key = request.model
else:
model = select_model_by_length(request.messages)
model_key = select_model_by_length(request.messages)
# 获取 LiteLLM 模型名称
provider_model = get_provider_model(model_key)
provider = get_provider_from_model(provider_model)
start_time = time.time()
try:
# OpenAI
response = await client.chat.completions.create(
model=model,
# 使用 LiteLLM 统一调用
response = await acompletion(
model=provider_model,
messages=[{"role": m.role, "content": m.content} for m in request.messages],
temperature=request.temperature,
max_tokens=request.max_tokens,
@@ -148,14 +157,15 @@ async def chat_completions(request: ChatRequest):
# 计算成本
input_tokens = response.usage.prompt_tokens
output_tokens = response.usage.completion_tokens
cost = calculate_cost(model, input_tokens, output_tokens)
cost = calculate_cost(model_key, input_tokens, output_tokens)
# 记录调用
log_call(model, cost, latency_ms, input_tokens + output_tokens)
log_call(model_key, provider, cost, latency_ms, input_tokens + output_tokens)
return ChatResponse(
id=response.id,
model=model,
model=model_key,
provider=provider,
content=response.choices[0].message.content,
usage={
"prompt_tokens": input_tokens,
@@ -167,7 +177,23 @@ async def chat_completions(request: ChatRequest):
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"OpenAI API error: {str(e)}")
raise HTTPException(status_code=500, detail=f"API error: {str(e)}")
@app.get("/models")
async def list_models():
    """List the supported models.

    Returns:
        A dict with a single ``"models"`` key holding one entry per
        ``MODEL_CONFIG`` item: its key, its ``provider`` value, and its input
        and output costs (exposed as ``input_cost_per_1k`` /
        ``output_cost_per_1k``).
    """
    catalog = []
    for key, config in MODEL_CONFIG.items():
        entry = {
            "key": key,
            "provider": config["provider"],
            "input_cost_per_1k": config["input_cost"],
            "output_cost_per_1k": config["output_cost"],
        }
        catalog.append(entry)
    return {"models": catalog}
@app.get("/stats", response_model=StatsResponse)
@@ -179,6 +205,7 @@ async def get_stats():
total_cost_usd=0.0,
avg_latency_ms=0.0,
model_distribution={},
provider_distribution={},
recent_calls=[],
)
@@ -188,14 +215,18 @@ async def get_stats():
# 模型分布
model_dist: Dict[str, int] = {}
provider_dist: Dict[str, int] = {}
for call in call_history:
model = call["model"]
provider = call["provider"]
model_dist[model] = model_dist.get(model, 0) + 1
provider_dist[provider] = provider_dist.get(provider, 0) + 1
# 最近 10 条记录
recent = [
{
"model": c["model"],
"provider": c["provider"],
"cost_usd": round(c["cost_usd"], 6),
"latency_ms": round(c["latency_ms"], 2),
"tokens": c["tokens"],
@@ -208,6 +239,7 @@ async def get_stats():
total_cost_usd=round(total_cost, 6),
avg_latency_ms=round(avg_latency, 2),
model_distribution=model_dist,
provider_distribution=provider_dist,
recent_calls=recent,
)
@@ -215,7 +247,7 @@ async def get_stats():
@app.get("/health")
async def health_check():
"""健康检查"""
return {"status": "healthy", "client_initialized": client is not None}
return {"status": "healthy", "version": "0.2.0"}
if __name__ == "__main__":

View File

@@ -1,7 +1,7 @@
fastapi>=0.104.0
uvicorn[standard]>=0.24.0
pydantic>=2.5.0
openai>=1.6.0
litellm>=1.0.0
tiktoken>=0.5.0
httpx>=0.25.0
python-dotenv>=1.0.0