"""
MVP LLM routing service.

Unified multi-provider interface built on LiteLLM.
Supports OpenAI, Anthropic, Gemini, Ollama and the other 100+ providers
that LiteLLM handles.
"""

import json
import os
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import litellm
import tiktoken
from fastapi import FastAPI, HTTPException
from litellm import acompletion
from pydantic import BaseModel

from config import MODEL_CONFIG, ROUTING_THRESHOLDS, DEFAULT_ROUTING, DASHSCOPE_API_KEY
from nvidia_router import get_nvidia_router, select_model_by_nvidia

# Configure LiteLLM to use DashScope (Qwen).
if DASHSCOPE_API_KEY:
    litellm.api_key = DASHSCOPE_API_KEY
    # Qwen is served through an OpenAI-compatible endpoint selected via api_base.
    # NOTE(review): these are process-wide litellm globals and may also be picked
    # up by calls to non-DashScope providers; confirm whether per-call
    # api_key/api_base is needed once other providers are enabled.
    litellm.api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"

# NVIDIA Router instance (created lazily on first use).
_nvidia_router = None


def get_router():
    """Return the shared NVIDIA router, creating it on first call."""
    global _nvidia_router
    if _nvidia_router is None:
        _nvidia_router = get_nvidia_router()
    return _nvidia_router


# Call history, persisted as JSON Lines (one record per line).
CALL_LOG_DIR = Path(__file__).parent / "data"
CALL_LOG_DIR.mkdir(exist_ok=True)
CALL_LOG_FILE = CALL_LOG_DIR / "call_history.jsonl"

# In-memory copy of the call history (populated from CALL_LOG_FILE at startup).
call_history: List[Dict[str, Any]] = []


def _load_history() -> None:
    """Load persisted call records from CALL_LOG_FILE into ``call_history``.

    Blank lines and lines that are not valid JSON are skipped.
    """
    if CALL_LOG_FILE.exists():
        with open(CALL_LOG_FILE, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    try:
                        call_history.append(json.loads(line))
                    except json.JSONDecodeError:
                        continue
        print(f"Loaded {len(call_history)} historical records from {CALL_LOG_FILE}")


_load_history()


class Message(BaseModel):
    """A single chat message."""

    role: str
    content: str


class ChatRequest(BaseModel):
    """Request body for /v1/chat/completions."""

    messages: List[Message]
    # Optional explicit model key; when set, automatic routing is skipped.
    model: Optional[str] = None
    temperature: float = 0.7
    max_tokens: Optional[int] = None


class ChatResponse(BaseModel):
    """Response body for /v1/chat/completions."""

    id: str
    # Model key from MODEL_CONFIG (not the provider-qualified LiteLLM name).
    model: str
    provider: str
    content: str
    usage: Dict[str, int]
    cost_usd: float
    latency_ms: float


app = FastAPI(
    title="LLM Router MVP",
    description="基于 LiteLLM + NVIDIA 分类器的多提供商路由服务(支持3-tier智能路由)",
    version="0.3.0",
)


def estimate_tokens(messages: List[Message]) -> int:
    """Estimate the prompt token count of *messages*.

    Uses the tiktoken gpt-4 encoding (falling back to cl100k_base). It adds a
    fixed 4-token overhead per message plus 2 tokens at the end, so the result
    is only an approximation, especially for non-OpenAI models.
    """
    try:
        encoding = tiktoken.encoding_for_model("gpt-4")
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")

    total_tokens = 0
    for msg in messages:
        total_tokens += 4
        total_tokens += len(encoding.encode(msg.content))
        total_tokens += len(encoding.encode(msg.role))
    total_tokens += 2
    return total_tokens


def select_model_by_length(messages: List[Message]) -> str:
    """Pick a model key from the estimated token count (fallback strategy).

    Thresholds come from ROUTING_THRESHOLDS and the chosen model keys from
    DEFAULT_ROUTING.
    """
    token_count = estimate_tokens(messages)
    if token_count < ROUTING_THRESHOLDS["simple"]:
        return DEFAULT_ROUTING["simple"]
    elif token_count < ROUTING_THRESHOLDS["medium"]:
        return DEFAULT_ROUTING["medium"]
    else:
        return DEFAULT_ROUTING["complex"]


def select_model_by_nvidia_classifier(messages: List[Message]) -> Tuple[str, Dict[str, Any]]:
    """Pick a model key with the NVIDIA multi-head classifier (3-tier routing).

    Only the last message's content is classified. Any failure (router load,
    prediction, or an unexpected tier) falls back to token-length routing.

    Returns:
        ``(model_key, routing_detail)``: the chosen model key and a dict
        describing how the decision was made.
    """
    query = messages[-1].content if messages else ""
    try:
        router = get_router()
        start = time.perf_counter()
        result = router.predict(query)
        routing_ms = (time.perf_counter() - start) * 1000

        # NOTE(review): this tier→model map is hard-coded, while the fallback
        # path reads DEFAULT_ROUTING; confirm the two are meant to differ.
        model_map = {"simple": "qwen-flash", "medium": "qwen-plus", "complex": "qwen-max"}
        model_key = model_map[result["tier"]]

        routing_detail = {
            "method": "nvidia_classifier",
            "query": query,
            "routing_latency_ms": round(routing_ms, 2),
            "tier": result["tier"],
            "complexity_score": result["complexity_score"],
            "task_type": result["task_type"],
            "domain_knowledge": result["domain_knowledge"],
            "reasoning": result["reasoning"],
            "creativity": result["creativity"],
        }
        return model_key, routing_detail
    except Exception as e:
        # Deliberate best-effort boundary: routing must never fail the request.
        print(f"NVIDIA routing failed: {e}, falling back to token length")
        model_key = select_model_by_length(messages)
        routing_detail = {
            "method": "fallback_token_length",
            "query": query,
            "routing_latency_ms": 0,
            "error": str(e),
        }
        return model_key, routing_detail


def get_provider_model(model_key: str) -> str:
    """Return the LiteLLM model name configured for *model_key*.

    Raises:
        HTTPException(400): if *model_key* is not in MODEL_CONFIG.
    """
    config = MODEL_CONFIG.get(model_key)
    if not config:
        raise HTTPException(status_code=400, detail=f"Unknown model: {model_key}")
    return config["provider"]


def calculate_cost(model_key: str, input_tokens: int, output_tokens: int) -> float:
    """Compute the USD cost of a call from per-1k-token prices in MODEL_CONFIG.

    Unknown keys are priced like "gpt-4o". If that entry is also missing, the
    cost is reported as 0.0 with a warning instead of failing the request.
    """
    # Look up the fallback lazily: dict.get(k, default) evaluates `default`
    # eagerly, which raised KeyError whenever "gpt-4o" was not configured.
    config = MODEL_CONFIG.get(model_key) or MODEL_CONFIG.get("gpt-4o")
    if config is None:
        print(f"No pricing found for model '{model_key}', reporting cost as 0.0")
        return 0.0
    input_cost = (input_tokens / 1000) * config["input_cost"]
    output_cost = (output_tokens / 1000) * config["output_cost"]
    return input_cost + output_cost


def get_provider_from_model(model_name: str) -> str:
    """Infer the provider name from a LiteLLM model name."""
    if model_name.startswith("gpt"):
        return "openai"
    elif model_name.startswith("claude"):
        return "anthropic"
    elif model_name.startswith("gemini"):
        return "google"
    elif "/" in model_name:
        # "provider/model" form, e.g. "openai/qwen-plus".
        return model_name.split("/")[0]
    return "unknown"


def log_call(
    model: str,
    provider: str,
    cost: float,
    latency_ms: float,
    input_tokens: int,
    output_tokens: int,
    messages: List[Dict[str, str]],
    response_content: str,
    response_id: str,
    routing_detail: Optional[Dict[str, Any]],
    request_params: Dict[str, Any],
) -> None:
    """Record a full call (request, routing decision, LLM result) for later tuning.

    The record is appended to the in-memory ``call_history`` and to
    CALL_LOG_FILE. A failed file write is reported but not raised, so a
    successful LLM call is not turned into an error.
    """
    record = {
        "timestamp": time.time(),
        # Request
        "request": {
            "messages": messages,
            "temperature": request_params.get("temperature"),
            "max_tokens": request_params.get("max_tokens"),
            "user_specified_model": request_params.get("user_specified_model"),
        },
        # Routing decision
        "routing": routing_detail,
        # LLM call
        "llm": {
            "model": model,
            "provider": provider,
            "response_id": response_id,
            "response_content": response_content,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
            "cost_usd": cost,
            "llm_latency_ms": round(latency_ms, 2),
        },
    }
    call_history.append(record)

    # Append to the JSONL file.
    try:
        with open(CALL_LOG_FILE, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    except OSError as e:
        print(f"Failed to persist call record to {CALL_LOG_FILE}: {e}")


@app.post("/v1/chat/completions", response_model=ChatResponse)
async def chat_completions(request: ChatRequest):
    """Chat completion endpoint.

    If ``request.model`` is not set, the model is chosen by the NVIDIA
    classifier (falling back to token-length routing on failure).

    Raises:
        HTTPException(400): empty ``messages`` or an unknown model key.
        HTTPException(500): the LiteLLM call or response processing failed.
    """
    if not request.messages:
        raise HTTPException(status_code=400, detail="messages must not be empty")

    if request.model:
        model_key = request.model
        routing_detail = {
            "method": "user_specified",
            "query": request.messages[-1].content,
        }
    else:
        model_key, routing_detail = select_model_by_nvidia_classifier(request.messages)

    # Resolve the LiteLLM model name (raises 400 for unknown keys).
    provider_model = get_provider_model(model_key)
    provider = get_provider_from_model(provider_model)

    messages_raw = [{"role": m.role, "content": m.content} for m in request.messages]

    start_time = time.perf_counter()
    try:
        response = await acompletion(
            model=provider_model,
            messages=messages_raw,
            temperature=request.temperature,
            max_tokens=request.max_tokens,
        )
        latency_ms = (time.perf_counter() - start_time) * 1000

        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        cost = calculate_cost(model_key, input_tokens, output_tokens)
        # Message content may be None for some responses; ChatResponse needs a str.
        response_content = response.choices[0].message.content or ""

        # Record the full call data.
        log_call(
            model=model_key,
            provider=provider,
            cost=cost,
            latency_ms=latency_ms,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            messages=messages_raw,
            response_content=response_content,
            response_id=response.id,
            routing_detail=routing_detail,
            request_params={
                "temperature": request.temperature,
                "max_tokens": request.max_tokens,
                "user_specified_model": request.model,
            },
        )

        return ChatResponse(
            id=response.id,
            model=model_key,
            provider=provider,
            content=response_content,
            usage={
                "prompt_tokens": input_tokens,
                "completion_tokens": output_tokens,
                "total_tokens": input_tokens + output_tokens,
            },
            cost_usd=cost,
            latency_ms=round(latency_ms, 2),
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"API error: {str(e)}") from e


@app.get("/models")
async def list_models():
    """List the configured models with their per-1k-token prices."""
    return {
        "models": [
            {
                "key": key,
                "provider": config["provider"],
                "input_cost_per_1k": config["input_cost"],
                "output_cost_per_1k": config["output_cost"],
            }
            for key, config in MODEL_CONFIG.items()
        ]
    }


@app.get("/stats")
async def get_stats():
    """Return summary statistics over the recorded call history.

    The response always has the same keys, whether or not the history is empty.
    """
    if not call_history:
        return {
            "total_calls": 0,
            "total_cost_usd": 0.0,
            "avg_latency_ms": 0.0,
            "avg_routing_ms": 0.0,
            "model_distribution": {},
            "tier_distribution": {},
            "task_type_distribution": {},
        }

    total_calls = len(call_history)
    total_cost = 0.0
    total_latency = 0.0
    total_routing = 0.0
    model_dist: Dict[str, int] = {}
    tier_dist: Dict[str, int] = {}
    task_dist: Dict[str, int] = {}

    for call in call_history:
        llm = call["llm"]
        total_cost += llm["cost_usd"]
        total_latency += llm["llm_latency_ms"]
        model = llm["model"]
        model_dist[model] = model_dist.get(model, 0) + 1

        # `routing` may be missing or null in stored records; normalise to {}.
        routing = call.get("routing") or {}
        total_routing += routing.get("routing_latency_ms", 0)
        if routing.get("tier"):
            tier = routing["tier"]
            tier_dist[tier] = tier_dist.get(tier, 0) + 1
        if routing.get("task_type"):
            task = routing["task_type"]
            task_dist[task] = task_dist.get(task, 0) + 1

    return {
        "total_calls": total_calls,
        "total_cost_usd": round(total_cost, 6),
        "avg_latency_ms": round(total_latency / total_calls, 2),
        "avg_routing_ms": round(total_routing / total_calls, 2),
        "model_distribution": model_dist,
        "tier_distribution": tier_dist,
        "task_type_distribution": task_dist,
    }


@app.get("/stats/raw")
async def get_stats_raw(limit: int = 50, offset: int = 0):
    """Return raw call records (routing details + full LLM data), newest first.

    Intended for later tuning and analysis.

    Args:
        limit: maximum number of records to return (default 50; negative is
            treated as 0).
        offset: number of newest records to skip (default 0; negative is
            treated as 0).
    """
    total = len(call_history)
    limit = max(limit, 0)
    offset = max(offset, 0)

    # Slice the chronological list directly instead of copying/reversing all of it:
    # newest-first positions [offset, offset+limit) map to [total-offset-limit, total-offset).
    end = max(total - offset, 0)
    start = max(end - limit, 0)
    page = call_history[start:end][::-1]

    return {
        "total": total,
        "limit": limit,
        "offset": offset,
        "records": page,
    }


@app.get("/health")
async def health_check():
    """Health check."""
    return {"status": "healthy", "version": "0.3.0", "router": "nvidia-3tier"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)