feat: integrate LiteLLM for multi-provider support

使用 LiteLLM 统一接口支持多 LLM 提供商:
- 支持 OpenAI, Anthropic, Gemini, Ollama 等 100+ 提供商
- 统一模型配置 (MODEL_CONFIG)
- 新增 /models 端点列出可用模型
- 统计增加提供商分布
- 简化代码,移除 OpenAI 客户端初始化
This commit is contained in:
2026-04-17 23:42:31 +08:00
parent 2380dd4617
commit 4259478a37
4 changed files with 117 additions and 69 deletions

View File

@@ -1,5 +1,14 @@
# OpenAI API Key # OpenAI API Key
OPENAI_API_KEY=sk-your-api-key-here OPENAI_API_KEY=sk-your-openai-key-here
# Anthropic API Key (Claude)
ANTHROPIC_API_KEY=sk-ant-your-anthropic-key-here
# Google API Key (Gemini)
GEMINI_API_KEY=your-gemini-key-here
# Ollama (本地模型,不需要 API Key)
# OLLAMA_HOST=http://localhost:11434
# 可选:自定义路由阈值 # 可选:自定义路由阈值
# ROUTE_SIMPLE_THRESHOLD=100 # ROUTE_SIMPLE_THRESHOLD=100

View File

@@ -8,31 +8,38 @@ from dotenv import load_dotenv
# 加载 .env 文件 # 加载 .env 文件
load_dotenv() load_dotenv()
# 模型配置 # 统一模型配置(支持多提供商)
# 格式: "统一模型名": {"provider": "litellm格式", "input_cost": x, "output_cost": y}
MODEL_CONFIG = { MODEL_CONFIG = {
"gpt-3.5-turbo": { # OpenAI
"input_cost_per_1k": 0.0005, "gpt-3.5": {"provider": "gpt-3.5-turbo", "input_cost": 0.0005, "output_cost": 0.0015},
"output_cost_per_1k": 0.0015, "gpt-4o-mini": {"provider": "gpt-4o-mini", "input_cost": 0.00015, "output_cost": 0.0006},
"max_tokens": 4096, "gpt-4o": {"provider": "gpt-4o", "input_cost": 0.005, "output_cost": 0.015},
}, # Anthropic
"gpt-4o-mini": { "claude-3-haiku": {"provider": "claude-3-haiku-20240307", "input_cost": 0.00025, "output_cost": 0.00125},
"input_cost_per_1k": 0.00015, "claude-3-sonnet": {"provider": "claude-3-sonnet-20240229", "input_cost": 0.003, "output_cost": 0.015},
"output_cost_per_1k": 0.0006, "claude-3-opus": {"provider": "claude-3-opus-20240229", "input_cost": 0.015, "output_cost": 0.075},
"max_tokens": 128000, # Gemini
}, "gemini-flash": {"provider": "gemini/gemini-1.5-flash", "input_cost": 0.000075, "output_cost": 0.0003},
"gpt-4o": { "gemini-pro": {"provider": "gemini/gemini-1.5-pro", "input_cost": 0.00125, "output_cost": 0.005},
"input_cost_per_1k": 0.005, # 本地/开源
"output_cost_per_1k": 0.015, "llama3": {"provider": "ollama/llama3", "input_cost": 0, "output_cost": 0},
"max_tokens": 128000,
},
} }
# 路由阈值 # 路由阈值token 数 -> 推荐模型)
ROUTING_THRESHOLDS = { ROUTING_THRESHOLDS = {
"simple": 100, # < 100 tokens -> gpt-3.5-turbo "simple": 100, # < 100 tokens
"medium": 500, # < 500 tokens -> gpt-4o-mini "medium": 500, # < 500 tokens
# >= 500 tokens -> gpt-4o
} }
# API Key # 默认模型选择策略
DEFAULT_ROUTING = {
"simple": "gpt-3.5", # 或 "claude-3-haiku", "gemini-flash"
"medium": "gpt-4o-mini", # 或 "claude-3-haiku"
"complex": "gpt-4o", # 或 "claude-3-sonnet", "gemini-pro"
}
# API Keyslitellm 自动读取环境变量)
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "") OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")

124
main.py
View File

@@ -1,17 +1,17 @@
""" """
MVP版 LLM 路由服务 MVP版 LLM 路由服务
基于 token 长度的简单规则路由 基于 LiteLLM 的多提供商统一接口
支持: OpenAI, Anthropic, Gemini, Ollama 等 100+ 提供商
""" """
import time import time
import tiktoken import tiktoken
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel
from openai import AsyncOpenAI from litellm import acompletion
from config import MODEL_CONFIG, ROUTING_THRESHOLDS, OPENAI_API_KEY from config import MODEL_CONFIG, ROUTING_THRESHOLDS, DEFAULT_ROUTING
# 调用历史记录 # 调用历史记录
@@ -33,6 +33,7 @@ class ChatRequest(BaseModel):
class ChatResponse(BaseModel): class ChatResponse(BaseModel):
id: str id: str
model: str model: str
provider: str
content: str content: str
usage: Dict[str, int] usage: Dict[str, int]
cost_usd: float cost_usd: float
@@ -44,29 +45,14 @@ class StatsResponse(BaseModel):
total_cost_usd: float total_cost_usd: float
avg_latency_ms: float avg_latency_ms: float
model_distribution: Dict[str, int] model_distribution: Dict[str, int]
provider_distribution: Dict[str, int]
recent_calls: List[Dict[str, Any]] recent_calls: List[Dict[str, Any]]
# 初始化 OpenAI 客户端
client: Optional[AsyncOpenAI] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理"""
global client
if not OPENAI_API_KEY:
raise RuntimeError("OPENAI_API_KEY environment variable is required")
client = AsyncOpenAI(api_key=OPENAI_API_KEY)
yield
client = None
app = FastAPI( app = FastAPI(
title="LLM Router MVP", title="LLM Router MVP",
description="基于 token 长度的简单规则路由服务", description="基于 LiteLLM 的多提供商路由服务",
version="0.1.0", version="0.2.0",
lifespan=lifespan,
) )
@@ -79,10 +65,10 @@ def estimate_tokens(messages: List[Message]) -> int:
total_tokens = 0 total_tokens = 0
for msg in messages: for msg in messages:
total_tokens += 4 # 每条消息的开销 total_tokens += 4
total_tokens += len(encoding.encode(msg.content)) total_tokens += len(encoding.encode(msg.content))
total_tokens += len(encoding.encode(msg.role)) total_tokens += len(encoding.encode(msg.role))
total_tokens += 2 # 回复的开销 total_tokens += 2
return total_tokens return total_tokens
@@ -91,25 +77,47 @@ def select_model_by_length(messages: List[Message]) -> str:
token_count = estimate_tokens(messages) token_count = estimate_tokens(messages)
if token_count < ROUTING_THRESHOLDS["simple"]: if token_count < ROUTING_THRESHOLDS["simple"]:
return "gpt-3.5-turbo" return DEFAULT_ROUTING["simple"]
elif token_count < ROUTING_THRESHOLDS["medium"]: elif token_count < ROUTING_THRESHOLDS["medium"]:
return "gpt-4o-mini" return DEFAULT_ROUTING["medium"]
else: else:
return "gpt-4o" return DEFAULT_ROUTING["complex"]
def calculate_cost(model: str, input_tokens: int, output_tokens: int) -> float: def get_provider_model(model_key: str) -> str:
"""获取 LiteLLM 格式的模型名称"""
config = MODEL_CONFIG.get(model_key)
if not config:
raise HTTPException(status_code=400, detail=f"Unknown model: {model_key}")
return config["provider"]
def calculate_cost(model_key: str, input_tokens: int, output_tokens: int) -> float:
"""计算调用成本""" """计算调用成本"""
config = MODEL_CONFIG.get(model, MODEL_CONFIG["gpt-4o"]) config = MODEL_CONFIG.get(model_key, MODEL_CONFIG["gpt-4o"])
input_cost = (input_tokens / 1000) * config["input_cost_per_1k"] input_cost = (input_tokens / 1000) * config["input_cost"]
output_cost = (output_tokens / 1000) * config["output_cost_per_1k"] output_cost = (output_tokens / 1000) * config["output_cost"]
return input_cost + output_cost return input_cost + output_cost
def log_call(model: str, cost: float, latency_ms: float, tokens: int): def get_provider_from_model(model_name: str) -> str:
"""从模型名称推断提供商"""
if model_name.startswith("gpt"):
return "openai"
elif model_name.startswith("claude"):
return "anthropic"
elif model_name.startswith("gemini"):
return "google"
elif "/" in model_name:
return model_name.split("/")[0]
return "unknown"
def log_call(model: str, provider: str, cost: float, latency_ms: float, tokens: int):
"""记录调用历史""" """记录调用历史"""
call_history.append({ call_history.append({
"model": model, "model": model,
"provider": provider,
"cost_usd": cost, "cost_usd": cost,
"latency_ms": latency_ms, "latency_ms": latency_ms,
"tokens": tokens, "tokens": tokens,
@@ -123,21 +131,22 @@ async def chat_completions(request: ChatRequest):
聊天完成接口 聊天完成接口
如果 request.model 未指定,则根据 token 长度自动路由 如果 request.model 未指定,则根据 token 长度自动路由
""" """
if client is None:
raise HTTPException(status_code=500, detail="OpenAI client not initialized")
# 选择模型 # 选择模型
if request.model: if request.model:
model = request.model model_key = request.model
else: else:
model = select_model_by_length(request.messages) model_key = select_model_by_length(request.messages)
# 获取 LiteLLM 模型名称
provider_model = get_provider_model(model_key)
provider = get_provider_from_model(provider_model)
start_time = time.time() start_time = time.time()
try: try:
# OpenAI # 使LiteLLM 统一调用
response = await client.chat.completions.create( response = await acompletion(
model=model, model=provider_model,
messages=[{"role": m.role, "content": m.content} for m in request.messages], messages=[{"role": m.role, "content": m.content} for m in request.messages],
temperature=request.temperature, temperature=request.temperature,
max_tokens=request.max_tokens, max_tokens=request.max_tokens,
@@ -148,14 +157,15 @@ async def chat_completions(request: ChatRequest):
# 计算成本 # 计算成本
input_tokens = response.usage.prompt_tokens input_tokens = response.usage.prompt_tokens
output_tokens = response.usage.completion_tokens output_tokens = response.usage.completion_tokens
cost = calculate_cost(model, input_tokens, output_tokens) cost = calculate_cost(model_key, input_tokens, output_tokens)
# 记录调用 # 记录调用
log_call(model, cost, latency_ms, input_tokens + output_tokens) log_call(model_key, provider, cost, latency_ms, input_tokens + output_tokens)
return ChatResponse( return ChatResponse(
id=response.id, id=response.id,
model=model, model=model_key,
provider=provider,
content=response.choices[0].message.content, content=response.choices[0].message.content,
usage={ usage={
"prompt_tokens": input_tokens, "prompt_tokens": input_tokens,
@@ -167,7 +177,23 @@ async def chat_completions(request: ChatRequest):
) )
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"OpenAI API error: {str(e)}") raise HTTPException(status_code=500, detail=f"API error: {str(e)}")
@app.get("/models")
async def list_models():
"""列出支持的模型"""
return {
"models": [
{
"key": key,
"provider": config["provider"],
"input_cost_per_1k": config["input_cost"],
"output_cost_per_1k": config["output_cost"],
}
for key, config in MODEL_CONFIG.items()
]
}
@app.get("/stats", response_model=StatsResponse) @app.get("/stats", response_model=StatsResponse)
@@ -179,6 +205,7 @@ async def get_stats():
total_cost_usd=0.0, total_cost_usd=0.0,
avg_latency_ms=0.0, avg_latency_ms=0.0,
model_distribution={}, model_distribution={},
provider_distribution={},
recent_calls=[], recent_calls=[],
) )
@@ -188,14 +215,18 @@ async def get_stats():
# 模型分布 # 模型分布
model_dist: Dict[str, int] = {} model_dist: Dict[str, int] = {}
provider_dist: Dict[str, int] = {}
for call in call_history: for call in call_history:
model = call["model"] model = call["model"]
provider = call["provider"]
model_dist[model] = model_dist.get(model, 0) + 1 model_dist[model] = model_dist.get(model, 0) + 1
provider_dist[provider] = provider_dist.get(provider, 0) + 1
# 最近 10 条记录 # 最近 10 条记录
recent = [ recent = [
{ {
"model": c["model"], "model": c["model"],
"provider": c["provider"],
"cost_usd": round(c["cost_usd"], 6), "cost_usd": round(c["cost_usd"], 6),
"latency_ms": round(c["latency_ms"], 2), "latency_ms": round(c["latency_ms"], 2),
"tokens": c["tokens"], "tokens": c["tokens"],
@@ -208,6 +239,7 @@ async def get_stats():
total_cost_usd=round(total_cost, 6), total_cost_usd=round(total_cost, 6),
avg_latency_ms=round(avg_latency, 2), avg_latency_ms=round(avg_latency, 2),
model_distribution=model_dist, model_distribution=model_dist,
provider_distribution=provider_dist,
recent_calls=recent, recent_calls=recent,
) )
@@ -215,7 +247,7 @@ async def get_stats():
@app.get("/health") @app.get("/health")
async def health_check(): async def health_check():
"""健康检查""" """健康检查"""
return {"status": "healthy", "client_initialized": client is not None} return {"status": "healthy", "version": "0.2.0"}
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,7 +1,7 @@
fastapi>=0.104.0 fastapi>=0.104.0
uvicorn[standard]>=0.24.0 uvicorn[standard]>=0.24.0
pydantic>=2.5.0 pydantic>=2.5.0
openai>=1.6.0 litellm>=1.0.0
tiktoken>=0.5.0 tiktoken>=0.5.0
httpx>=0.25.0 httpx>=0.25.0
python-dotenv>=1.0.0 python-dotenv>=1.0.0