diff --git a/.env.example b/.env.example index d8b0faa..af2df66 100644 --- a/.env.example +++ b/.env.example @@ -1,5 +1,14 @@ # OpenAI API Key -OPENAI_API_KEY=sk-your-api-key-here +OPENAI_API_KEY=sk-your-openai-key-here + +# Anthropic API Key (Claude) +ANTHROPIC_API_KEY=sk-ant-your-anthropic-key-here + +# Google API Key (Gemini) +GEMINI_API_KEY=your-gemini-key-here + +# Ollama (本地模型,不需要 API Key) +# OLLAMA_HOST=http://localhost:11434 # 可选:自定义路由阈值 # ROUTE_SIMPLE_THRESHOLD=100 diff --git a/config.py b/config.py index e79abff..6eb7f40 100644 --- a/config.py +++ b/config.py @@ -8,31 +8,38 @@ from dotenv import load_dotenv # 加载 .env 文件 load_dotenv() -# 模型配置 +# 统一模型配置(支持多提供商) +# 格式: "统一模型名": {"provider": "litellm格式", "input_cost": x, "output_cost": y} MODEL_CONFIG = { - "gpt-3.5-turbo": { - "input_cost_per_1k": 0.0005, - "output_cost_per_1k": 0.0015, - "max_tokens": 4096, - }, - "gpt-4o-mini": { - "input_cost_per_1k": 0.00015, - "output_cost_per_1k": 0.0006, - "max_tokens": 128000, - }, - "gpt-4o": { - "input_cost_per_1k": 0.005, - "output_cost_per_1k": 0.015, - "max_tokens": 128000, - }, + # OpenAI + "gpt-3.5": {"provider": "gpt-3.5-turbo", "input_cost": 0.0005, "output_cost": 0.0015}, + "gpt-4o-mini": {"provider": "gpt-4o-mini", "input_cost": 0.00015, "output_cost": 0.0006}, + "gpt-4o": {"provider": "gpt-4o", "input_cost": 0.005, "output_cost": 0.015}, + # Anthropic + "claude-3-haiku": {"provider": "claude-3-haiku-20240307", "input_cost": 0.00025, "output_cost": 0.00125}, + "claude-3-sonnet": {"provider": "claude-3-sonnet-20240229", "input_cost": 0.003, "output_cost": 0.015}, + "claude-3-opus": {"provider": "claude-3-opus-20240229", "input_cost": 0.015, "output_cost": 0.075}, + # Gemini + "gemini-flash": {"provider": "gemini/gemini-1.5-flash", "input_cost": 0.000075, "output_cost": 0.0003}, + "gemini-pro": {"provider": "gemini/gemini-1.5-pro", "input_cost": 0.00125, "output_cost": 0.005}, + # 本地/开源 + "llama3": {"provider": "ollama/llama3", "input_cost": 0, "output_cost": 0}, } -# 路由阈值 +# 路由阈值(token 数 -> 推荐模型) ROUTING_THRESHOLDS = { - "simple": 100, # < 100 tokens -> gpt-3.5-turbo - "medium": 500, # < 500 tokens -> gpt-4o-mini - # >= 500 tokens -> gpt-4o + "simple": 100, # < 100 tokens + "medium": 500, # < 500 tokens } -# API Key +# 默认模型选择策略 +DEFAULT_ROUTING = { + "simple": "gpt-3.5", # 或 "claude-3-haiku", "gemini-flash" + "medium": "gpt-4o-mini", # 或 "claude-3-haiku" + "complex": "gpt-4o", # 或 "claude-3-sonnet", "gemini-pro" +} + +# API Keys(litellm 自动读取环境变量) OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "") +ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "") +GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "") diff --git a/main.py b/main.py index 9f90382..dfa1e0f 100644 --- a/main.py +++ b/main.py @@ -1,17 +1,17 @@ """ MVP版 LLM 路由服务 -基于 token 长度的简单规则路由 +基于 LiteLLM 的多提供商统一接口 +支持: OpenAI, Anthropic, Gemini, Ollama 等 100+ 提供商 """ import time import tiktoken from typing import List, Dict, Any, Optional -from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException -from pydantic import BaseModel, Field -from openai import AsyncOpenAI +from pydantic import BaseModel +from litellm import acompletion -from config import MODEL_CONFIG, ROUTING_THRESHOLDS, OPENAI_API_KEY +from config import MODEL_CONFIG, ROUTING_THRESHOLDS, DEFAULT_ROUTING # 调用历史记录 @@ -33,6 +33,7 @@ class ChatRequest(BaseModel): class ChatResponse(BaseModel): id: str model: str + provider: str content: str usage: Dict[str, int] cost_usd: float @@ -44,29 +45,14 @@ class StatsResponse(BaseModel): total_cost_usd: float avg_latency_ms: float model_distribution: Dict[str, int] + provider_distribution: Dict[str, int] recent_calls: List[Dict[str, Any]] -# 初始化 OpenAI 客户端 -client: Optional[AsyncOpenAI] = None - - -@asynccontextmanager -async def lifespan(app: FastAPI): - """应用生命周期管理""" - global client - if not OPENAI_API_KEY: - raise RuntimeError("OPENAI_API_KEY environment variable is required") - client = AsyncOpenAI(api_key=OPENAI_API_KEY) - yield - client = None - - app = FastAPI( title="LLM Router MVP", - description="基于 token 长度的简单规则路由服务", - version="0.1.0", - lifespan=lifespan, + description="基于 LiteLLM 的多提供商路由服务", + version="0.2.0", ) @@ -79,10 +65,10 @@ def estimate_tokens(messages: List[Message]) -> int: total_tokens = 0 for msg in messages: - total_tokens += 4 # 每条消息的开销 + total_tokens += 4 total_tokens += len(encoding.encode(msg.content)) total_tokens += len(encoding.encode(msg.role)) - total_tokens += 2 # 回复的开销 + total_tokens += 2 return total_tokens @@ -91,25 +77,47 @@ def select_model_by_length(messages: List[Message]) -> str: token_count = estimate_tokens(messages) if token_count < ROUTING_THRESHOLDS["simple"]: - return "gpt-3.5-turbo" + return DEFAULT_ROUTING["simple"] elif token_count < ROUTING_THRESHOLDS["medium"]: - return "gpt-4o-mini" + return DEFAULT_ROUTING["medium"] else: - return "gpt-4o" + return DEFAULT_ROUTING["complex"] -def calculate_cost(model: str, input_tokens: int, output_tokens: int) -> float: +def get_provider_model(model_key: str) -> str: + """获取 LiteLLM 格式的模型名称""" + config = MODEL_CONFIG.get(model_key) + if not config: + raise HTTPException(status_code=400, detail=f"Unknown model: {model_key}") + return config["provider"] + + +def calculate_cost(model_key: str, input_tokens: int, output_tokens: int) -> float: """计算调用成本""" - config = MODEL_CONFIG.get(model, MODEL_CONFIG["gpt-4o"]) - input_cost = (input_tokens / 1000) * config["input_cost_per_1k"] - output_cost = (output_tokens / 1000) * config["output_cost_per_1k"] + config = MODEL_CONFIG.get(model_key, MODEL_CONFIG["gpt-4o"]) + input_cost = (input_tokens / 1000) * config["input_cost"] + output_cost = (output_tokens / 1000) * config["output_cost"] return input_cost + output_cost -def log_call(model: str, cost: float, latency_ms: float, tokens: int): +def get_provider_from_model(model_name: str) -> str: + """从模型名称推断提供商""" + if model_name.startswith("gpt"): + return "openai" + elif model_name.startswith("claude"): + return "anthropic" + elif model_name.startswith("gemini"): + return "google" + elif "/" in model_name: + return model_name.split("/")[0] + return "unknown" + + +def log_call(model: str, provider: str, cost: float, latency_ms: float, tokens: int): """记录调用历史""" call_history.append({ "model": model, + "provider": provider, "cost_usd": cost, "latency_ms": latency_ms, "tokens": tokens, @@ -123,21 +131,22 @@ async def chat_completions(request: ChatRequest): 聊天完成接口 如果 request.model 未指定,则根据 token 长度自动路由 """ - if client is None: - raise HTTPException(status_code=500, detail="OpenAI client not initialized") - # 选择模型 if request.model: - model = request.model + model_key = request.model else: - model = select_model_by_length(request.messages) + model_key = select_model_by_length(request.messages) + + # 获取 LiteLLM 模型名称 + provider_model = get_provider_model(model_key) + provider = get_provider_from_model(provider_model) start_time = time.time() try: - # 调用 OpenAI - response = await client.chat.completions.create( - model=model, + # 使用 LiteLLM 统一调用 + response = await acompletion( + model=provider_model, messages=[{"role": m.role, "content": m.content} for m in request.messages], temperature=request.temperature, max_tokens=request.max_tokens, @@ -148,14 +157,15 @@ async def chat_completions(request: ChatRequest): # 计算成本 input_tokens = response.usage.prompt_tokens output_tokens = response.usage.completion_tokens - cost = calculate_cost(model, input_tokens, output_tokens) + cost = calculate_cost(model_key, input_tokens, output_tokens) # 记录调用 - log_call(model, cost, latency_ms, input_tokens + output_tokens) + log_call(model_key, provider, cost, latency_ms, input_tokens + output_tokens) return ChatResponse( id=response.id, - model=model, + model=model_key, + provider=provider, content=response.choices[0].message.content, usage={ "prompt_tokens": input_tokens, @@ -167,7 +177,23 @@ async def chat_completions(request: ChatRequest): ) except Exception as e: - raise HTTPException(status_code=500, detail=f"OpenAI API error: {str(e)}") + raise HTTPException(status_code=500, detail=f"API error: {str(e)}") + + +@app.get("/models") +async def list_models(): + """列出支持的模型""" + return { + "models": [ + { + "key": key, + "provider": config["provider"], + "input_cost_per_1k": config["input_cost"], + "output_cost_per_1k": config["output_cost"], + } + for key, config in MODEL_CONFIG.items() + ] + } @app.get("/stats", response_model=StatsResponse) @@ -179,6 +205,7 @@ async def get_stats(): total_cost_usd=0.0, avg_latency_ms=0.0, model_distribution={}, + provider_distribution={}, recent_calls=[], ) @@ -188,14 +215,18 @@ async def get_stats(): # 模型分布 model_dist: Dict[str, int] = {} + provider_dist: Dict[str, int] = {} for call in call_history: model = call["model"] + provider = call["provider"] model_dist[model] = model_dist.get(model, 0) + 1 + provider_dist[provider] = provider_dist.get(provider, 0) + 1 # 最近 10 条记录 recent = [ { "model": c["model"], + "provider": c["provider"], "cost_usd": round(c["cost_usd"], 6), "latency_ms": round(c["latency_ms"], 2), "tokens": c["tokens"], @@ -208,6 +239,7 @@ async def get_stats(): total_cost_usd=round(total_cost, 6), avg_latency_ms=round(avg_latency, 2), model_distribution=model_dist, + provider_distribution=provider_dist, recent_calls=recent, ) @@ -215,7 +247,7 @@ async def get_stats(): @app.get("/health") async def health_check(): """健康检查""" - return {"status": "healthy", "client_initialized": client is not None} + return {"status": "healthy", "version": "0.2.0"} if __name__ == "__main__": diff --git a/requirements.txt b/requirements.txt index ccbd1d2..9925d56 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ fastapi>=0.104.0 uvicorn[standard]>=0.24.0 pydantic>=2.5.0 -openai>=1.6.0 +litellm>=1.0.0 tiktoken>=0.5.0 httpx>=0.25.0 python-dotenv>=1.0.0