"""
MVP LLM routing service.

A unified multi-provider interface built on LiteLLM.
Supports OpenAI, Anthropic, Gemini, Ollama and 100+ other providers.
"""
import time
from collections import Counter
from typing import List, Dict, Any, Optional

import tiktoken
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from litellm import acompletion
import litellm

from config import MODEL_CONFIG, ROUTING_THRESHOLDS, DEFAULT_ROUTING, DASHSCOPE_API_KEY

# Configure LiteLLM to use DashScope (Qwen) when a key is provided.
if DASHSCOPE_API_KEY:
    litellm.api_key = DASHSCOPE_API_KEY
    # Qwen is reached through an OpenAI-compatible endpoint, which has to be
    # selected via api_base.
    # NOTE(review): these are module-level LiteLLM settings; presumably they
    # apply to every call made through litellm -- confirm this is intended
    # when non-DashScope models are configured.
    litellm.api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"

# In-memory call history; one entry is appended per successful completion.
# NOTE(review): nothing in this module ever trims this list, so it grows for
# the lifetime of the process.
call_history: List[Dict[str, Any]] = []


class Message(BaseModel):
    """A single chat message (role + text content)."""

    role: str
    content: str


class ChatRequest(BaseModel):
    """Request body for the chat completion endpoint."""

    messages: List[Message]
    model: Optional[str] = None  # optional; when set, routing is skipped
    temperature: float = 0.7
    max_tokens: Optional[int] = None


class ChatResponse(BaseModel):
    """Response body for the chat completion endpoint."""

    id: str
    model: str
    provider: str
    content: str
    usage: Dict[str, int]
    cost_usd: float
    latency_ms: float


class StatsResponse(BaseModel):
    """Aggregated statistics over the recorded call history."""

    total_calls: int
    total_cost_usd: float
    avg_latency_ms: float
    model_distribution: Dict[str, int]
    provider_distribution: Dict[str, int]
    recent_calls: List[Dict[str, Any]]


app = FastAPI(
    title="LLM Router MVP",
    description="基于 LiteLLM 的多提供商路由服务",
    version="0.2.0",
)


def estimate_tokens(messages: List[Message]) -> int:
    """Estimate the number of prompt tokens for a list of messages.

    Uses the tiktoken encoding for "gpt-4", falling back to "cl100k_base"
    when that model name is not known to tiktoken. Each message contributes
    a fixed overhead of 4 plus the encoded lengths of its content and role;
    a further 2 is added once for the whole request.
    """
    try:
        encoding = tiktoken.encoding_for_model("gpt-4")
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")

    total_tokens = 0
    for msg in messages:
        total_tokens += 4
        total_tokens += len(encoding.encode(msg.content))
        total_tokens += len(encoding.encode(msg.role))
    total_tokens += 2
    return total_tokens


def select_model_by_length(messages: List[Message]) -> str:
    """Pick a model key from DEFAULT_ROUTING based on estimated token count.

    Counts below ROUTING_THRESHOLDS["simple"] map to the "simple" route,
    counts below ROUTING_THRESHOLDS["medium"] to "medium", and everything
    else to "complex".
    """
    token_count = estimate_tokens(messages)

    if token_count < ROUTING_THRESHOLDS["simple"]:
        return DEFAULT_ROUTING["simple"]
    elif token_count < ROUTING_THRESHOLDS["medium"]:
        return DEFAULT_ROUTING["medium"]
    else:
        return DEFAULT_ROUTING["complex"]


def get_provider_model(model_key: str) -> str:
    """Return the LiteLLM model name configured for ``model_key``.

    The value is read from the "provider" field of the model's entry in
    MODEL_CONFIG.

    Raises:
        HTTPException: 400 when ``model_key`` has no entry in MODEL_CONFIG.
    """
    config = MODEL_CONFIG.get(model_key)
    if not config:
        raise HTTPException(status_code=400, detail=f"Unknown model: {model_key}")
    return config["provider"]


def calculate_cost(model_key: str, input_tokens: int, output_tokens: int) -> float:
    """Compute the cost of a call from per-1000-token prices.

    Prices come from the "input_cost" and "output_cost" fields of the
    model's MODEL_CONFIG entry. Unknown model keys fall back to the
    "gpt-4o" entry.

    Raises:
        KeyError: when ``model_key`` is unknown and MODEL_CONFIG has no
            "gpt-4o" entry to fall back to.
    """
    config = MODEL_CONFIG.get(model_key)
    if config is None:
        # Look the fallback up lazily: passing MODEL_CONFIG["gpt-4o"] as the
        # .get() default would evaluate it on every call and raise KeyError
        # for *known* models whenever "gpt-4o" is not configured.
        config = MODEL_CONFIG["gpt-4o"]

    input_cost = (input_tokens / 1000) * config["input_cost"]
    output_cost = (output_tokens / 1000) * config["output_cost"]
    return input_cost + output_cost


def get_provider_from_model(model_name: str) -> str:
    """Infer the provider name from a model name.

    Known prefixes ("gpt", "claude", "gemini") are checked first; otherwise
    a "provider/model" style name yields the part before the first slash.
    Returns "unknown" when nothing matches.
    """
    if model_name.startswith("gpt"):
        return "openai"
    elif model_name.startswith("claude"):
        return "anthropic"
    elif model_name.startswith("gemini"):
        return "google"
    elif "/" in model_name:
        return model_name.split("/")[0]
    return "unknown"


def log_call(model: str, provider: str, cost: float, latency_ms: float, tokens: int):
    """Append one call record (stamped with the current time) to call_history."""
    call_history.append({
        "model": model,
        "provider": provider,
        "cost_usd": cost,
        "latency_ms": latency_ms,
        "tokens": tokens,
        "timestamp": time.time(),
    })


@app.post("/v1/chat/completions", response_model=ChatResponse)
async def chat_completions(request: ChatRequest):
    """
    Chat completion endpoint.

    When request.model is not specified, a model is selected automatically
    based on the estimated token length of the messages.

    Raises:
        HTTPException: 400 for an unknown model key; 500 when the provider
            call or the handling of its response fails.
    """
    # Select the model: an explicit request.model skips routing.
    model_key = request.model or select_model_by_length(request.messages)

    # Resolve the LiteLLM model name and the provider label.
    provider_model = get_provider_model(model_key)
    provider = get_provider_from_model(provider_model)

    # perf_counter is monotonic, so the measured latency cannot be skewed by
    # wall-clock adjustments the way time.time() can.
    start_time = time.perf_counter()

    try:
        # Unified call through LiteLLM.
        response = await acompletion(
            model=provider_model,
            messages=[{"role": m.role, "content": m.content} for m in request.messages],
            temperature=request.temperature,
            max_tokens=request.max_tokens,
        )

        latency_ms = (time.perf_counter() - start_time) * 1000

        # Compute the cost from the reported usage.
        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        cost = calculate_cost(model_key, input_tokens, output_tokens)

        # Record the call.
        log_call(model_key, provider, cost, latency_ms, input_tokens + output_tokens)

        return ChatResponse(
            id=response.id,
            model=model_key,
            provider=provider,
            # The message content may be None; ChatResponse.content is a
            # required str, so normalise to an empty string instead of
            # failing validation.
            content=response.choices[0].message.content or "",
            usage={
                "prompt_tokens": input_tokens,
                "completion_tokens": output_tokens,
                "total_tokens": input_tokens + output_tokens,
            },
            cost_usd=cost,
            latency_ms=round(latency_ms, 2),
        )

    except Exception as e:
        # Service boundary: surface any failure as a 500 while keeping the
        # original exception chained for logs and debugging.
        raise HTTPException(status_code=500, detail=f"API error: {str(e)}") from e


@app.get("/models")
async def list_models():
    """List the configured models together with their per-1k token prices."""
    return {
        "models": [
            {
                "key": key,
                "provider": config["provider"],
                "input_cost_per_1k": config["input_cost"],
                "output_cost_per_1k": config["output_cost"],
            }
            for key, config in MODEL_CONFIG.items()
        ]
    }


@app.get("/stats", response_model=StatsResponse)
async def get_stats():
    """Return aggregate statistics over call_history.

    Includes the total call count, total cost, average latency, per-model
    and per-provider call counts, and the 10 most recent calls. All values
    are zero/empty when no call has been recorded yet.
    """
    if not call_history:
        return StatsResponse(
            total_calls=0,
            total_cost_usd=0.0,
            avg_latency_ms=0.0,
            model_distribution={},
            provider_distribution={},
            recent_calls=[],
        )

    total_calls = len(call_history)
    total_cost = sum(c["cost_usd"] for c in call_history)
    avg_latency = sum(c["latency_ms"] for c in call_history) / total_calls

    # Call counts per model and per provider.
    model_dist = Counter(c["model"] for c in call_history)
    provider_dist = Counter(c["provider"] for c in call_history)

    # The 10 most recent records (timestamps are not exposed).
    recent = [
        {
            "model": c["model"],
            "provider": c["provider"],
            "cost_usd": round(c["cost_usd"], 6),
            "latency_ms": round(c["latency_ms"], 2),
            "tokens": c["tokens"],
        }
        for c in call_history[-10:]
    ]

    return StatsResponse(
        total_calls=total_calls,
        total_cost_usd=round(total_cost, 6),
        avg_latency_ms=round(avg_latency, 2),
        model_distribution=dict(model_dist),
        provider_distribution=dict(provider_dist),
        recent_calls=recent,
    )


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {"status": "healthy", "version": "0.2.0"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)