feat: integrate LiteLLM for multi-provider support
使用 LiteLLM 统一接口支持多 LLM 提供商:
- 支持 OpenAI, Anthropic, Gemini, Ollama 等 100+ 提供商
- 统一模型配置 (MODEL_CONFIG)
- 新增 /models 端点列出可用模型
- 统计增加提供商分布
- 简化代码,移除 OpenAI 客户端初始化
This commit is contained in:
11
.env.example
11
.env.example
@@ -1,5 +1,14 @@
|
|||||||
# OpenAI API Key
|
# OpenAI API Key
|
||||||
OPENAI_API_KEY=sk-your-api-key-here
|
OPENAI_API_KEY=sk-your-openai-key-here
|
||||||
|
|
||||||
|
# Anthropic API Key (Claude)
|
||||||
|
ANTHROPIC_API_KEY=sk-ant-your-anthropic-key-here
|
||||||
|
|
||||||
|
# Google API Key (Gemini)
|
||||||
|
GEMINI_API_KEY=your-gemini-key-here
|
||||||
|
|
||||||
|
# Ollama (本地模型,不需要 API Key)
|
||||||
|
# OLLAMA_HOST=http://localhost:11434
|
||||||
|
|
||||||
# 可选:自定义路由阈值
|
# 可选:自定义路由阈值
|
||||||
# ROUTE_SIMPLE_THRESHOLD=100
|
# ROUTE_SIMPLE_THRESHOLD=100
|
||||||
|
|||||||
49
config.py
49
config.py
@@ -8,31 +8,38 @@ from dotenv import load_dotenv
|
|||||||
# 加载 .env 文件
|
# 加载 .env 文件
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
# 模型配置
|
# 统一模型配置(支持多提供商)
|
||||||
|
# Format: "unified model key": {"provider": "<LiteLLM model identifier>", "input_cost": <USD per 1K input tokens>, "output_cost": <USD per 1K output tokens>}
|
||||||
MODEL_CONFIG = {
|
MODEL_CONFIG = {
|
||||||
"gpt-3.5-turbo": {
|
# OpenAI
|
||||||
"input_cost_per_1k": 0.0005,
|
"gpt-3.5": {"provider": "gpt-3.5-turbo", "input_cost": 0.0005, "output_cost": 0.0015},
|
||||||
"output_cost_per_1k": 0.0015,
|
"gpt-4o-mini": {"provider": "gpt-4o-mini", "input_cost": 0.00015, "output_cost": 0.0006},
|
||||||
"max_tokens": 4096,
|
"gpt-4o": {"provider": "gpt-4o", "input_cost": 0.005, "output_cost": 0.015},
|
||||||
},
|
# Anthropic
|
||||||
"gpt-4o-mini": {
|
"claude-3-haiku": {"provider": "claude-3-haiku-20240307", "input_cost": 0.00025, "output_cost": 0.00125},
|
||||||
"input_cost_per_1k": 0.00015,
|
"claude-3-sonnet": {"provider": "claude-3-sonnet-20240229", "input_cost": 0.003, "output_cost": 0.015},
|
||||||
"output_cost_per_1k": 0.0006,
|
"claude-3-opus": {"provider": "claude-3-opus-20240229", "input_cost": 0.015, "output_cost": 0.075},
|
||||||
"max_tokens": 128000,
|
# Gemini
|
||||||
},
|
"gemini-flash": {"provider": "gemini/gemini-1.5-flash", "input_cost": 0.000075, "output_cost": 0.0003},
|
||||||
"gpt-4o": {
|
"gemini-pro": {"provider": "gemini/gemini-1.5-pro", "input_cost": 0.00125, "output_cost": 0.005},
|
||||||
"input_cost_per_1k": 0.005,
|
# 本地/开源
|
||||||
"output_cost_per_1k": 0.015,
|
"llama3": {"provider": "ollama/llama3", "input_cost": 0, "output_cost": 0},
|
||||||
"max_tokens": 128000,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# 路由阈值
|
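# Routing thresholds: tier name -> exclusive upper bound on estimated tokens (the tier -> model mapping lives in DEFAULT_ROUTING)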
|
||||||
ROUTING_THRESHOLDS = {
|
ROUTING_THRESHOLDS = {
|
||||||
"simple": 100, # < 100 tokens -> gpt-3.5-turbo
|
"simple": 100, # < 100 tokens
|
||||||
"medium": 500, # < 500 tokens -> gpt-4o-mini
|
"medium": 500, # < 500 tokens
|
||||||
# >= 500 tokens -> gpt-4o
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# API Key
|
# 默认模型选择策略
|
||||||
|
DEFAULT_ROUTING = {
|
||||||
|
"simple": "gpt-3.5", # 或 "claude-3-haiku", "gemini-flash"
|
||||||
|
"medium": "gpt-4o-mini", # 或 "claude-3-haiku"
|
||||||
|
"complex": "gpt-4o", # 或 "claude-3-sonnet", "gemini-pro"
|
||||||
|
}
|
||||||
|
|
||||||
|
# API Keys(litellm 自动读取环境变量)
|
||||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
||||||
|
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
|
||||||
|
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
|
||||||
|
|||||||
124
main.py
124
main.py
@@ -1,17 +1,17 @@
|
|||||||
"""
|
"""
|
||||||
MVP版 LLM 路由服务
|
MVP版 LLM 路由服务
|
||||||
基于 token 长度的简单规则路由
|
基于 LiteLLM 的多提供商统一接口
|
||||||
|
支持: OpenAI, Anthropic, Gemini, Ollama 等 100+ 提供商
|
||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
import tiktoken
|
import tiktoken
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel
|
||||||
from openai import AsyncOpenAI
|
from litellm import acompletion
|
||||||
|
|
||||||
from config import MODEL_CONFIG, ROUTING_THRESHOLDS, OPENAI_API_KEY
|
from config import MODEL_CONFIG, ROUTING_THRESHOLDS, DEFAULT_ROUTING
|
||||||
|
|
||||||
|
|
||||||
# 调用历史记录
|
# 调用历史记录
|
||||||
@@ -33,6 +33,7 @@ class ChatRequest(BaseModel):
|
|||||||
class ChatResponse(BaseModel):
|
class ChatResponse(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
model: str
|
model: str
|
||||||
|
provider: str
|
||||||
content: str
|
content: str
|
||||||
usage: Dict[str, int]
|
usage: Dict[str, int]
|
||||||
cost_usd: float
|
cost_usd: float
|
||||||
@@ -44,29 +45,14 @@ class StatsResponse(BaseModel):
|
|||||||
total_cost_usd: float
|
total_cost_usd: float
|
||||||
avg_latency_ms: float
|
avg_latency_ms: float
|
||||||
model_distribution: Dict[str, int]
|
model_distribution: Dict[str, int]
|
||||||
|
provider_distribution: Dict[str, int]
|
||||||
recent_calls: List[Dict[str, Any]]
|
recent_calls: List[Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
# 初始化 OpenAI 客户端
|
|
||||||
client: Optional[AsyncOpenAI] = None
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def lifespan(app: FastAPI):
|
|
||||||
"""应用生命周期管理"""
|
|
||||||
global client
|
|
||||||
if not OPENAI_API_KEY:
|
|
||||||
raise RuntimeError("OPENAI_API_KEY environment variable is required")
|
|
||||||
client = AsyncOpenAI(api_key=OPENAI_API_KEY)
|
|
||||||
yield
|
|
||||||
client = None
|
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="LLM Router MVP",
|
title="LLM Router MVP",
|
||||||
description="基于 token 长度的简单规则路由服务",
|
description="基于 LiteLLM 的多提供商路由服务",
|
||||||
version="0.1.0",
|
version="0.2.0",
|
||||||
lifespan=lifespan,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -79,10 +65,10 @@ def estimate_tokens(messages: List[Message]) -> int:
|
|||||||
|
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
total_tokens += 4 # 每条消息的开销
|
total_tokens += 4
|
||||||
total_tokens += len(encoding.encode(msg.content))
|
total_tokens += len(encoding.encode(msg.content))
|
||||||
total_tokens += len(encoding.encode(msg.role))
|
total_tokens += len(encoding.encode(msg.role))
|
||||||
total_tokens += 2 # 回复的开销
|
total_tokens += 2
|
||||||
return total_tokens
|
return total_tokens
|
||||||
|
|
||||||
|
|
||||||
@@ -91,25 +77,47 @@ def select_model_by_length(messages: List[Message]) -> str:
|
|||||||
token_count = estimate_tokens(messages)
|
token_count = estimate_tokens(messages)
|
||||||
|
|
||||||
if token_count < ROUTING_THRESHOLDS["simple"]:
|
if token_count < ROUTING_THRESHOLDS["simple"]:
|
||||||
return "gpt-3.5-turbo"
|
return DEFAULT_ROUTING["simple"]
|
||||||
elif token_count < ROUTING_THRESHOLDS["medium"]:
|
elif token_count < ROUTING_THRESHOLDS["medium"]:
|
||||||
return "gpt-4o-mini"
|
return DEFAULT_ROUTING["medium"]
|
||||||
else:
|
else:
|
||||||
return "gpt-4o"
|
return DEFAULT_ROUTING["complex"]
|
||||||
|
|
||||||
|
|
||||||
def calculate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
|
def get_provider_model(model_key: str) -> str:
|
||||||
|
"""获取 LiteLLM 格式的模型名称"""
|
||||||
|
config = MODEL_CONFIG.get(model_key)
|
||||||
|
if not config:
|
||||||
|
raise HTTPException(status_code=400, detail=f"Unknown model: {model_key}")
|
||||||
|
return config["provider"]
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_cost(model_key: str, input_tokens: int, output_tokens: int) -> float:
|
||||||
"""计算调用成本"""
|
"""计算调用成本"""
|
||||||
config = MODEL_CONFIG.get(model, MODEL_CONFIG["gpt-4o"])
|
config = MODEL_CONFIG.get(model_key, MODEL_CONFIG["gpt-4o"])
|
||||||
input_cost = (input_tokens / 1000) * config["input_cost_per_1k"]
|
input_cost = (input_tokens / 1000) * config["input_cost"]
|
||||||
output_cost = (output_tokens / 1000) * config["output_cost_per_1k"]
|
output_cost = (output_tokens / 1000) * config["output_cost"]
|
||||||
return input_cost + output_cost
|
return input_cost + output_cost
|
||||||
|
|
||||||
|
|
||||||
def log_call(model: str, cost: float, latency_ms: float, tokens: int):
|
def get_provider_from_model(model_name: str) -> str:
|
||||||
|
"""从模型名称推断提供商"""
|
||||||
|
if model_name.startswith("gpt"):
|
||||||
|
return "openai"
|
||||||
|
elif model_name.startswith("claude"):
|
||||||
|
return "anthropic"
|
||||||
|
elif model_name.startswith("gemini"):
|
||||||
|
return "google"
|
||||||
|
elif "/" in model_name:
|
||||||
|
return model_name.split("/")[0]
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def log_call(model: str, provider: str, cost: float, latency_ms: float, tokens: int):
|
||||||
"""记录调用历史"""
|
"""记录调用历史"""
|
||||||
call_history.append({
|
call_history.append({
|
||||||
"model": model,
|
"model": model,
|
||||||
|
"provider": provider,
|
||||||
"cost_usd": cost,
|
"cost_usd": cost,
|
||||||
"latency_ms": latency_ms,
|
"latency_ms": latency_ms,
|
||||||
"tokens": tokens,
|
"tokens": tokens,
|
||||||
@@ -123,21 +131,22 @@ async def chat_completions(request: ChatRequest):
|
|||||||
聊天完成接口
|
聊天完成接口
|
||||||
如果 request.model 未指定,则根据 token 长度自动路由
|
如果 request.model 未指定,则根据 token 长度自动路由
|
||||||
"""
|
"""
|
||||||
if client is None:
|
|
||||||
raise HTTPException(status_code=500, detail="OpenAI client not initialized")
|
|
||||||
|
|
||||||
# 选择模型
|
# 选择模型
|
||||||
if request.model:
|
if request.model:
|
||||||
model = request.model
|
model_key = request.model
|
||||||
else:
|
else:
|
||||||
model = select_model_by_length(request.messages)
|
model_key = select_model_by_length(request.messages)
|
||||||
|
|
||||||
|
# 获取 LiteLLM 模型名称
|
||||||
|
provider_model = get_provider_model(model_key)
|
||||||
|
provider = get_provider_from_model(provider_model)
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 调用 OpenAI
|
# 使用 LiteLLM 统一调用
|
||||||
response = await client.chat.completions.create(
|
response = await acompletion(
|
||||||
model=model,
|
model=provider_model,
|
||||||
messages=[{"role": m.role, "content": m.content} for m in request.messages],
|
messages=[{"role": m.role, "content": m.content} for m in request.messages],
|
||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
max_tokens=request.max_tokens,
|
max_tokens=request.max_tokens,
|
||||||
@@ -148,14 +157,15 @@ async def chat_completions(request: ChatRequest):
|
|||||||
# 计算成本
|
# 计算成本
|
||||||
input_tokens = response.usage.prompt_tokens
|
input_tokens = response.usage.prompt_tokens
|
||||||
output_tokens = response.usage.completion_tokens
|
output_tokens = response.usage.completion_tokens
|
||||||
cost = calculate_cost(model, input_tokens, output_tokens)
|
cost = calculate_cost(model_key, input_tokens, output_tokens)
|
||||||
|
|
||||||
# 记录调用
|
# 记录调用
|
||||||
log_call(model, cost, latency_ms, input_tokens + output_tokens)
|
log_call(model_key, provider, cost, latency_ms, input_tokens + output_tokens)
|
||||||
|
|
||||||
return ChatResponse(
|
return ChatResponse(
|
||||||
id=response.id,
|
id=response.id,
|
||||||
model=model,
|
model=model_key,
|
||||||
|
provider=provider,
|
||||||
content=response.choices[0].message.content,
|
content=response.choices[0].message.content,
|
||||||
usage={
|
usage={
|
||||||
"prompt_tokens": input_tokens,
|
"prompt_tokens": input_tokens,
|
||||||
@@ -167,7 +177,23 @@ async def chat_completions(request: ChatRequest):
|
|||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"OpenAI API error: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"API error: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/models")
|
||||||
|
async def list_models():
|
||||||
|
"""列出支持的模型"""
|
||||||
|
return {
|
||||||
|
"models": [
|
||||||
|
{
|
||||||
|
"key": key,
|
||||||
|
"provider": config["provider"],
|
||||||
|
"input_cost_per_1k": config["input_cost"],
|
||||||
|
"output_cost_per_1k": config["output_cost"],
|
||||||
|
}
|
||||||
|
for key, config in MODEL_CONFIG.items()
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/stats", response_model=StatsResponse)
|
@app.get("/stats", response_model=StatsResponse)
|
||||||
@@ -179,6 +205,7 @@ async def get_stats():
|
|||||||
total_cost_usd=0.0,
|
total_cost_usd=0.0,
|
||||||
avg_latency_ms=0.0,
|
avg_latency_ms=0.0,
|
||||||
model_distribution={},
|
model_distribution={},
|
||||||
|
provider_distribution={},
|
||||||
recent_calls=[],
|
recent_calls=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -188,14 +215,18 @@ async def get_stats():
|
|||||||
|
|
||||||
# 模型分布
|
# 模型分布
|
||||||
model_dist: Dict[str, int] = {}
|
model_dist: Dict[str, int] = {}
|
||||||
|
provider_dist: Dict[str, int] = {}
|
||||||
for call in call_history:
|
for call in call_history:
|
||||||
model = call["model"]
|
model = call["model"]
|
||||||
|
provider = call["provider"]
|
||||||
model_dist[model] = model_dist.get(model, 0) + 1
|
model_dist[model] = model_dist.get(model, 0) + 1
|
||||||
|
provider_dist[provider] = provider_dist.get(provider, 0) + 1
|
||||||
|
|
||||||
# 最近 10 条记录
|
# 最近 10 条记录
|
||||||
recent = [
|
recent = [
|
||||||
{
|
{
|
||||||
"model": c["model"],
|
"model": c["model"],
|
||||||
|
"provider": c["provider"],
|
||||||
"cost_usd": round(c["cost_usd"], 6),
|
"cost_usd": round(c["cost_usd"], 6),
|
||||||
"latency_ms": round(c["latency_ms"], 2),
|
"latency_ms": round(c["latency_ms"], 2),
|
||||||
"tokens": c["tokens"],
|
"tokens": c["tokens"],
|
||||||
@@ -208,6 +239,7 @@ async def get_stats():
|
|||||||
total_cost_usd=round(total_cost, 6),
|
total_cost_usd=round(total_cost, 6),
|
||||||
avg_latency_ms=round(avg_latency, 2),
|
avg_latency_ms=round(avg_latency, 2),
|
||||||
model_distribution=model_dist,
|
model_distribution=model_dist,
|
||||||
|
provider_distribution=provider_dist,
|
||||||
recent_calls=recent,
|
recent_calls=recent,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -215,7 +247,7 @@ async def get_stats():
|
|||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health_check():
|
async def health_check():
|
||||||
"""健康检查"""
|
"""健康检查"""
|
||||||
return {"status": "healthy", "client_initialized": client is not None}
|
return {"status": "healthy", "version": "0.2.0"}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
fastapi>=0.104.0
|
fastapi>=0.104.0
|
||||||
uvicorn[standard]>=0.24.0
|
uvicorn[standard]>=0.24.0
|
||||||
pydantic>=2.5.0
|
pydantic>=2.5.0
|
||||||
openai>=1.6.0
|
litellm>=1.0.0
|
||||||
tiktoken>=0.5.0
|
tiktoken>=0.5.0
|
||||||
httpx>=0.25.0
|
httpx>=0.25.0
|
||||||
python-dotenv>=1.0.0
|
python-dotenv>=1.0.0
|
||||||
|
|||||||
Reference in New Issue
Block a user