"""
MVP LLM routing service.

A unified multi-provider interface built on LiteLLM.
Supports OpenAI, Anthropic, Gemini, Ollama and 100+ other providers.
"""

import time
from collections import Counter
from functools import lru_cache
from typing import List, Dict, Any, Optional

import tiktoken
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from litellm import acompletion

from config import MODEL_CONFIG, ROUTING_THRESHOLDS, DEFAULT_ROUTING

# Single source of truth for the version reported by the app and /health.
APP_VERSION = "0.2.0"

# In-memory call history; one dict per successful completion (see log_call).
# NOTE(review): this list grows without bound for the life of the process.
# Capping it would change what /stats reports, so it is left as-is here.
call_history: List[Dict[str, Any]] = []


class Message(BaseModel):
    """A single chat message."""

    role: str
    content: str


class ChatRequest(BaseModel):
    """Request body for the chat completion endpoint."""

    messages: List[Message]
    model: Optional[str] = None  # Optional; when set, automatic routing is skipped
    temperature: float = 0.7
    max_tokens: Optional[int] = None


class ChatResponse(BaseModel):
    """Response body for the chat completion endpoint."""

    id: str
    model: str
    provider: str
    content: str
    usage: Dict[str, int]
    cost_usd: float
    latency_ms: float


class StatsResponse(BaseModel):
    """Aggregated statistics over the recorded call history."""

    total_calls: int
    total_cost_usd: float
    avg_latency_ms: float
    model_distribution: Dict[str, int]
    provider_distribution: Dict[str, int]
    recent_calls: List[Dict[str, Any]]


app = FastAPI(
    title="LLM Router MVP",
    description="基于 LiteLLM 的多提供商路由服务",
    version=APP_VERSION,
)


@lru_cache(maxsize=1)
def _get_encoding():
    """Return the tokenizer used for token estimation, resolved only once.

    Falls back to the ``cl100k_base`` encoding when tiktoken does not
    recognise the ``gpt-4`` model name.
    """
    try:
        return tiktoken.encoding_for_model("gpt-4")
    except KeyError:
        return tiktoken.get_encoding("cl100k_base")


def estimate_tokens(messages: List[Message]) -> int:
    """Estimate the number of tokens in a list of chat messages.

    Each message contributes a fixed overhead of 4 tokens plus the encoded
    length of its content and role; 2 more tokens are added once for the
    whole conversation.

    Args:
        messages: The chat messages to measure.

    Returns:
        The estimated token count.
    """
    encoding = _get_encoding()

    def count(text: str) -> int:
        # disallowed_special=() makes tiktoken encode special-token literals
        # (e.g. "<|endoftext|>") as ordinary text instead of raising
        # ValueError, so arbitrary user input cannot crash the estimate.
        return len(encoding.encode(text, disallowed_special=()))

    per_message = sum(4 + count(msg.content) + count(msg.role) for msg in messages)
    return per_message + 2


def select_model_by_length(messages: List[Message]) -> str:
    """Pick a model key based on the estimated token length of the messages.

    Args:
        messages: The chat messages of the request.

    Returns:
        The ``DEFAULT_ROUTING`` entry for ``"simple"``, ``"medium"`` or
        ``"complex"``, depending on where the token estimate falls relative
        to ``ROUTING_THRESHOLDS``.
    """
    token_count = estimate_tokens(messages)

    if token_count < ROUTING_THRESHOLDS["simple"]:
        return DEFAULT_ROUTING["simple"]
    if token_count < ROUTING_THRESHOLDS["medium"]:
        return DEFAULT_ROUTING["medium"]
    return DEFAULT_ROUTING["complex"]


def get_provider_model(model_key: str) -> str:
    """Return the LiteLLM-format model name for a model key.

    The value comes from the ``"provider"`` field of the model's
    ``MODEL_CONFIG`` entry and is passed to LiteLLM as ``model=``.

    Args:
        model_key: Key into ``MODEL_CONFIG``.

    Returns:
        The LiteLLM model name configured for the key.

    Raises:
        HTTPException: 400 when the key has no (non-empty) config entry.
    """
    config = MODEL_CONFIG.get(model_key)
    if not config:
        raise HTTPException(status_code=400, detail=f"Unknown model: {model_key}")
    return config["provider"]


def calculate_cost(model_key: str, input_tokens: int, output_tokens: int) -> float:
    """Compute the cost of a call from its token counts.

    Prices are read from ``MODEL_CONFIG`` as cost per 1000 tokens. Unknown
    model keys are priced with the ``"gpt-4o"`` entry.

    Args:
        model_key: Key into ``MODEL_CONFIG``.
        input_tokens: Number of prompt tokens.
        output_tokens: Number of completion tokens.

    Returns:
        The sum of the input and output cost.
    """
    config = MODEL_CONFIG.get(model_key)
    if config is None:
        # Look the fallback up lazily so a config without "gpt-4o" only
        # fails when the fallback is actually needed.
        config = MODEL_CONFIG["gpt-4o"]

    input_cost = (input_tokens / 1000) * config["input_cost"]
    output_cost = (output_tokens / 1000) * config["output_cost"]
    return input_cost + output_cost


def get_provider_from_model(model_name: str) -> str:
    """Infer the provider name from a model name.

    Well-known prefixes are checked first; otherwise the part before the
    first ``/`` is used.

    Args:
        model_name: The LiteLLM-format model name.

    Returns:
        ``"openai"``, ``"anthropic"``, ``"google"``, the prefix before ``/``,
        or ``"unknown"`` when nothing matches.
    """
    if model_name.startswith("gpt"):
        return "openai"
    if model_name.startswith("claude"):
        return "anthropic"
    if model_name.startswith("gemini"):
        return "google"
    if "/" in model_name:
        return model_name.split("/")[0]
    return "unknown"


def log_call(model: str, provider: str, cost: float, latency_ms: float, tokens: int):
    """Append one call record, stamped with the current time, to call_history."""
    call_history.append({
        "model": model,
        "provider": provider,
        "cost_usd": cost,
        "latency_ms": latency_ms,
        "tokens": tokens,
        "timestamp": time.time(),
    })


@app.post("/v1/chat/completions", response_model=ChatResponse)
async def chat_completions(request: ChatRequest):
    """
    Chat completion endpoint.

    If request.model is not specified, the model is routed automatically
    based on the estimated token length of the messages.
    """
    # Select the model: an explicit choice wins over automatic routing.
    if request.model:
        model_key = request.model
    else:
        model_key = select_model_by_length(request.messages)

    # Resolve the LiteLLM model name (raises 400 for unknown keys).
    provider_model = get_provider_model(model_key)
    provider = get_provider_from_model(provider_model)

    # perf_counter is monotonic, so the measured latency is not affected by
    # system clock adjustments.
    start_time = time.perf_counter()

    try:
        # Unified call through LiteLLM.
        response = await acompletion(
            model=provider_model,
            messages=[{"role": m.role, "content": m.content} for m in request.messages],
            temperature=request.temperature,
            max_tokens=request.max_tokens,
        )

        latency_ms = (time.perf_counter() - start_time) * 1000

        # Compute the cost from the reported usage.
        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        cost = calculate_cost(model_key, input_tokens, output_tokens)

        # Record the call.
        log_call(model_key, provider, cost, latency_ms, input_tokens + output_tokens)

        return ChatResponse(
            id=response.id,
            model=model_key,
            provider=provider,
            # The message content may be None; ChatResponse.content must be
            # a str, so fall back to an empty string.
            content=response.choices[0].message.content or "",
            usage={
                "prompt_tokens": input_tokens,
                "completion_tokens": output_tokens,
                "total_tokens": input_tokens + output_tokens,
            },
            cost_usd=cost,
            latency_ms=round(latency_ms, 2),
        )

    except Exception as e:
        # Boundary handler: surface any failure as a 500, keeping the cause
        # chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=f"API error: {str(e)}") from e


@app.get("/models")
async def list_models():
    """List the supported models and their per-1k-token prices."""
    return {
        "models": [
            {
                "key": key,
                "provider": config["provider"],
                "input_cost_per_1k": config["input_cost"],
                "output_cost_per_1k": config["output_cost"],
            }
            for key, config in MODEL_CONFIG.items()
        ]
    }


@app.get("/stats", response_model=StatsResponse)
async def get_stats():
    """Return aggregated statistics over the recorded calls."""
    if not call_history:
        # Guard clause: avoids dividing by zero for the average latency.
        return StatsResponse(
            total_calls=0,
            total_cost_usd=0.0,
            avg_latency_ms=0.0,
            model_distribution={},
            provider_distribution={},
            recent_calls=[],
        )

    total_calls = len(call_history)
    total_cost = sum(c["cost_usd"] for c in call_history)
    avg_latency = sum(c["latency_ms"] for c in call_history) / total_calls

    # Per-model and per-provider call counts.
    model_dist = Counter(c["model"] for c in call_history)
    provider_dist = Counter(c["provider"] for c in call_history)

    # The 10 most recent records (timestamp omitted, numbers rounded).
    recent = [
        {
            "model": c["model"],
            "provider": c["provider"],
            "cost_usd": round(c["cost_usd"], 6),
            "latency_ms": round(c["latency_ms"], 2),
            "tokens": c["tokens"],
        }
        for c in call_history[-10:]
    ]

    return StatsResponse(
        total_calls=total_calls,
        total_cost_usd=round(total_cost, 6),
        avg_latency_ms=round(avg_latency, 2),
        model_distribution=dict(model_dist),
        provider_distribution=dict(provider_dist),
        recent_calls=recent,
    )


@app.get("/health")
async def health_check():
    """Health check."""
    return {"status": "healthy", "version": APP_VERSION}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)