""" LLM Compass - 智能LLM路由服务 (OpenAI 兼容 API) 基于 LiteLLM + NVIDIA 多头分类器的智能路由服务 支持 OpenAI Chat Completions API 格式(含流式返回) """ import time import json import uuid import tiktoken from typing import List, Dict, Any, Optional from pathlib import Path from fastapi import FastAPI, HTTPException from fastapi.responses import StreamingResponse from pydantic import BaseModel, Field from litellm import acompletion import litellm from config import MODEL_CONFIG, ROUTING_THRESHOLDS, DEFAULT_ROUTING, DASHSCOPE_API_KEY from nvidia_router import get_nvidia_router, select_model_by_nvidia # 配置 LiteLLM 使用 DashScope (Qwen) if DASHSCOPE_API_KEY: litellm.api_key = DASHSCOPE_API_KEY litellm.api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1" # NVIDIA Router 实例(延迟加载) _nvidia_router = None def get_router(): """获取 NVIDIA Router 实例(延迟加载)""" global _nvidia_router if _nvidia_router is None: _nvidia_router = get_nvidia_router() return _nvidia_router # ── 调用历史 - JSONL 持久化 ──────────────────────────────── CALL_LOG_DIR = Path(__file__).parent / "data" CALL_LOG_DIR.mkdir(exist_ok=True) CALL_LOG_FILE = CALL_LOG_DIR / "call_history.jsonl" call_history: List[Dict[str, Any]] = [] def _load_history(): if CALL_LOG_FILE.exists(): with open(CALL_LOG_FILE, "r", encoding="utf-8") as f: for line in f: line = line.strip() if line: try: call_history.append(json.loads(line)) except json.JSONDecodeError: continue print(f"Loaded {len(call_history)} historical records from {CALL_LOG_FILE}") _load_history() # ── OpenAI 兼容请求/响应模型 ──────────────────────────────── from pydantic import BaseModel, Field class ChatMessage(BaseModel): role: str = Field(..., description="角色:system, user, assistant", example="user") content: Optional[str] = Field(None, description="消息内容", example="你好,介绍一下你自己") name: Optional[str] = Field(None, description="可选的名称") class ChatCompletionRequest(BaseModel): model: Optional[str] = Field( None, description="模型名称(留空时自动使用 NVIDIA 分类器智能路由)", example="qwen-plus", json_schema_extra={"examples": ["", "qwen-flash", "qwen-plus", "qwen-max"]} ) messages: List[ChatMessage] = Field( ..., description="对话消息列表", example=[{"role": "user", "content": "你好,介绍一下你自己"}] ) temperature: Optional[float] = Field(0.7, ge=0, le=2, description="随机性 (0-2)") max_tokens: Optional[int] = Field(None, description="最大生成 token 数(留空时使用模型默认值)", example=2048) stream: Optional[bool] = Field(False, description="是否使用流式输出") top_p: Optional[float] = Field(1.0, ge=0, le=1, description="核采样参数") n: Optional[int] = Field(1, ge=1, le=10, description="生成回复数量") stop: Optional[Any] = Field(None, description="停止词") presence_penalty: Optional[float] = Field(0.0, ge=-2, le=2, description="存在惩罚") frequency_penalty: Optional[float] = Field(0.0, ge=-2, le=2, description="频率惩罚") user: Optional[str] = Field(None, description="用户标识") # ── FastAPI App ────────────────────────────────────────────── app = FastAPI( title="LLM Compass", description="智能LLM路由服务,为请求指引最优模型,兼顾质量与成本(NVIDIA 3-tier 分类 + LiteLLM 多提供商)", version="0.4.0", ) # ── 辅助函数 ───────────────────────────────────────────────── def estimate_tokens(messages: List[ChatMessage]) -> int: try: encoding = tiktoken.encoding_for_model("gpt-4") except KeyError: encoding = tiktoken.get_encoding("cl100k_base") total = 0 for msg in messages: total += 4 if msg.content: total += len(encoding.encode(msg.content)) total += len(encoding.encode(msg.role)) total += 2 return total def select_model_by_length(messages: List[ChatMessage]) -> str: token_count = estimate_tokens(messages) if token_count < ROUTING_THRESHOLDS["simple"]: return DEFAULT_ROUTING["simple"] elif token_count < ROUTING_THRESHOLDS["medium"]: return DEFAULT_ROUTING["medium"] else: return DEFAULT_ROUTING["complex"] def select_model_by_nvidia_classifier(messages: List[ChatMessage]) -> tuple: """ 基于 NVIDIA 多头分类器选择模型 Returns: (model_key, routing_detail) """ query = "" for msg in reversed(messages): if msg.role == "user" and msg.content: query = msg.content break try: router = get_router() start = time.time() result = router.predict(query) routing_ms = (time.time() - start) * 1000 model_map = {"simple": "qwen-flash", "medium": "qwen-plus", "complex": "qwen-max"} model_key = model_map[result["tier"]] routing_detail = { "method": "nvidia_classifier", "query": query, "routing_latency_ms": round(routing_ms, 2), "tier": result["tier"], "complexity_score": result["complexity_score"], "task_type": result["task_type"], "domain_knowledge": result["domain_knowledge"], "reasoning": result["reasoning"], "creativity": result["creativity"], } return model_key, routing_detail except Exception as e: print(f"NVIDIA routing failed: {e}, falling back to token length") model_key = select_model_by_length(messages) routing_detail = { "method": "fallback_token_length", "query": query, "routing_latency_ms": 0, "error": str(e), } return model_key, routing_detail def get_provider_model(model_key: str) -> str: config = MODEL_CONFIG.get(model_key) if not config: raise HTTPException(status_code=400, detail=f"Unknown model: {model_key}") return config["provider"] def calculate_cost(model_key: str, input_tokens: int, output_tokens: int) -> float: config = MODEL_CONFIG.get(model_key, MODEL_CONFIG["gpt-4o"]) return (input_tokens / 1000) * config["input_cost"] + (output_tokens / 1000) * config["output_cost"] def log_call( model: str, cost: float, latency_ms: float, input_tokens: int, output_tokens: int, messages_raw: List[Dict], response_content: str, response_id: str, routing_detail: Optional[Dict], request_params: Dict, stream: bool = False, ): record = { "timestamp": time.time(), "request": { "messages": messages_raw, "temperature": request_params.get("temperature"), "max_tokens": request_params.get("max_tokens"), "stream": stream, "user_specified_model": request_params.get("user_specified_model"), }, "routing": routing_detail, "llm": { "model": model, "response_id": response_id, "response_content": response_content, "input_tokens": input_tokens, "output_tokens": output_tokens, "total_tokens": input_tokens + output_tokens, "cost_usd": cost, "llm_latency_ms": round(latency_ms, 2), }, } call_history.append(record) with open(CALL_LOG_FILE, "a", encoding="utf-8") as f: f.write(json.dumps(record, ensure_ascii=False) + "\n") def build_openai_response( response_id: str, model: str, content: str, input_tokens: int, output_tokens: int, routing_detail: Optional[Dict] = None, ) -> Dict: """构建 OpenAI 格式的非流式响应""" resp = { "id": response_id, "object": "chat.completion", "created": int(time.time()), "model": model, "choices": [{ "index": 0, "message": {"role": "assistant", "content": content}, "finish_reason": "stop", }], "usage": { "prompt_tokens": input_tokens, "completion_tokens": output_tokens, "total_tokens": input_tokens + output_tokens, }, } # 路由细节作为扩展字段 if routing_detail: resp["routing"] = routing_detail return resp # ── 核心 API: /v1/chat/completions ────────────────────────── @app.post("/v1/chat/completions") async def chat_completions(request: ChatCompletionRequest): """ OpenAI 兼容的 Chat Completions API - model 为空时自动使用 NVIDIA 分类器路由 - stream=true 时返回 SSE 流式响应 """ # 1. 路由决策 routing_detail = None if request.model: model_key = request.model routing_detail = { "method": "user_specified", "query": next((m.content for m in reversed(request.messages) if m.role == "user" and m.content), ""), } else: model_key, routing_detail = select_model_by_nvidia_classifier(request.messages) provider_model = get_provider_model(model_key) messages_raw = [{"role": m.role, "content": m.content} for m in request.messages] response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}" # 2. 流式响应 if request.stream: return StreamingResponse( _stream_response( provider_model=provider_model, model_key=model_key, messages_raw=messages_raw, request=request, response_id=response_id, routing_detail=routing_detail, ), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no", }, ) # 3. 非流式响应 start_time = time.time() # 构建请求参数(过滤掉 None 和 0 的 max_tokens) completion_kwargs = { "model": provider_model, "messages": messages_raw, "temperature": request.temperature, } if request.max_tokens and request.max_tokens > 0: completion_kwargs["max_tokens"] = request.max_tokens try: response = await acompletion(**completion_kwargs) latency_ms = (time.time() - start_time) * 1000 input_tokens = response.usage.prompt_tokens output_tokens = response.usage.completion_tokens cost = calculate_cost(model_key, input_tokens, output_tokens) content = response.choices[0].message.content log_call( model=model_key, cost=cost, latency_ms=latency_ms, input_tokens=input_tokens, output_tokens=output_tokens, messages_raw=messages_raw, response_content=content, response_id=response_id, routing_detail=routing_detail, request_params={ "temperature": request.temperature, "max_tokens": request.max_tokens, "user_specified_model": request.model, }, ) return build_openai_response(response_id, model_key, content, input_tokens, output_tokens, routing_detail) except Exception as e: raise HTTPException(status_code=500, detail=f"API error: {str(e)}") async def _stream_response( provider_model: str, model_key: str, messages_raw: List[Dict], request: ChatCompletionRequest, response_id: str, routing_detail: Optional[Dict], ): """生成 SSE 流式响应""" start_time = time.time() collected_content = "" input_tokens = 0 output_tokens = 0 try: # 构建请求参数(过滤掉 None 和 0 的 max_tokens) completion_kwargs = { "model": provider_model, "messages": messages_raw, "temperature": request.temperature, "stream": True, } if request.max_tokens and request.max_tokens > 0: completion_kwargs["max_tokens"] = request.max_tokens response = await acompletion(**completion_kwargs) async for chunk in response: delta = chunk.choices[0].delta # 收集内容用于日志 if delta.content: collected_content += delta.content output_tokens += 1 # 近似计数 # 构建 SSE 数据 chunk_data = { "id": response_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": model_key, "choices": [{ "index": 0, "delta": {}, "finish_reason": None, }], } if delta.content: chunk_data["choices"][0]["delta"] = {"content": delta.content} elif delta.role: chunk_data["choices"][0]["delta"] = {"role": delta.role} if chunk.choices[0].finish_reason: chunk_data["choices"][0]["finish_reason"] = chunk.choices[0].finish_reason yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n" # 发送 [DONE] yield "data: [DONE]\n\n" # 记录日志 latency_ms = (time.time() - start_time) * 1000 # 流式模式下用 tiktoken 近似计算 input_tokens try: encoding = tiktoken.get_encoding("cl100k_base") input_tokens = sum(len(encoding.encode(m.get("content", "") or "")) for m in messages_raw) + len(messages_raw) * 4 except Exception: input_tokens = 0 cost = calculate_cost(model_key, input_tokens, output_tokens) log_call( model=model_key, cost=cost, latency_ms=latency_ms, input_tokens=input_tokens, output_tokens=output_tokens, messages_raw=messages_raw, response_content=collected_content, response_id=response_id, routing_detail=routing_detail, request_params={ "temperature": request.temperature, "max_tokens": request.max_tokens, "user_specified_model": request.model, }, stream=True, ) except Exception as e: error_data = {"error": {"message": str(e), "type": "api_error"}} yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n" yield "data: [DONE]\n\n" # ── OpenAI 兼容: /v1/models ───────────────────────────────── @app.get("/v1/models") async def list_models(): """OpenAI 兼容的模型列表接口""" data = [] for key, config in MODEL_CONFIG.items(): data.append({ "id": key, "object": "model", "created": 1700000000, "owned_by": config["provider"].split("/")[0] if "/" in config["provider"] else "unknown", }) return {"object": "list", "data": data} # ── 管理接口 ───────────────────────────────────────────────── @app.get("/stats") async def get_stats(): """获取调用统计摘要""" if not call_history: return { "total_calls": 0, "total_cost_usd": 0.0, "avg_latency_ms": 0.0, "model_distribution": {}, "tier_distribution": {}, "task_type_distribution": {}, } total_calls = len(call_history) total_cost = sum(c["llm"]["cost_usd"] for c in call_history) avg_latency = sum(c["llm"]["llm_latency_ms"] for c in call_history) / total_calls model_dist: Dict[str, int] = {} tier_dist: Dict[str, int] = {} task_dist: Dict[str, int] = {} for call in call_history: model = call["llm"]["model"] model_dist[model] = model_dist.get(model, 0) + 1 routing = call.get("routing") or {} if routing.get("tier"): tier_dist[routing["tier"]] = tier_dist.get(routing["tier"], 0) + 1 if routing.get("task_type"): task_dist[routing["task_type"]] = task_dist.get(routing["task_type"], 0) + 1 return { "total_calls": total_calls, "total_cost_usd": round(total_cost, 6), "avg_latency_ms": round(avg_latency, 2), "avg_routing_ms": round( sum(c.get("routing", {}).get("routing_latency_ms", 0) for c in call_history) / total_calls, 2 ), "model_distribution": model_dist, "tier_distribution": tier_dist, "task_type_distribution": task_dist, } @app.get("/stats/raw") async def get_stats_raw(limit: int = 50, offset: int = 0): """获取原始调用记录""" total = len(call_history) records = list(reversed(call_history)) return {"total": total, "limit": limit, "offset": offset, "records": records[offset:offset + limit]} @app.get("/health") async def health_check(): """健康检查""" return {"status": "healthy", "version": "0.4.0", "router": "llm-compass"} if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000)