From 1e273e36709358336318974c7ddcb5a125d9da3e Mon Sep 17 00:00:00 2001 From: aszerW Date: Sat, 18 Apr 2026 01:58:33 +0800 Subject: [PATCH] =?UTF-8?q?feat(stats):=20=E5=AE=8C=E5=96=84=E8=B0=83?= =?UTF-8?q?=E7=94=A8=E8=AE=B0=E5=BD=95=E8=AF=A6=E6=83=85=E5=B9=B6=E6=8C=81?= =?UTF-8?q?=E4=B9=85=E5=8C=96=E5=88=B0JSONL=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - log_call保存完整request/routing/llm三层数据(含NVIDIA分类原始输出) - 新增/stats/raw接口返回原始调用记录(支持分页) - /stats摘要新增tier_distribution、task_type_distribution、avg_routing_ms - 调用历史持久化到data/call_history.jsonl,重启自动恢复 - data/目录加入.gitignore --- .gitignore | 3 + main.py | 259 +++++++++++++++++++++++++++++++++++++---------------- 2 files changed, 186 insertions(+), 76 deletions(-) diff --git a/.gitignore b/.gitignore index e8b4a3d..cc4750b 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,9 @@ __pycache__/ .env .venv +# Data (call history logs) +data/ + # IDE .vscode/ .idea/ diff --git a/main.py b/main.py index abc6b9a..01cb844 100644 --- a/main.py +++ b/main.py @@ -4,8 +4,11 @@ MVP版 LLM 路由服务 支持: OpenAI, Anthropic, Gemini, Ollama 等 100+ 提供商 """ import time +import json +import os import tiktoken from typing import List, Dict, Any, Optional +from pathlib import Path from fastapi import FastAPI, HTTPException from pydantic import BaseModel @@ -32,9 +35,29 @@ def get_router(): return _nvidia_router -# 调用历史记录 +# 调用历史 - JSON 文件持久化 +CALL_LOG_DIR = Path(__file__).parent / "data" +CALL_LOG_DIR.mkdir(exist_ok=True) +CALL_LOG_FILE = CALL_LOG_DIR / "call_history.jsonl" + +# 内存缓存(启动时从文件加载) call_history: List[Dict[str, Any]] = [] +def _load_history(): + """启动时从 JSONL 文件加载历史记录""" + if CALL_LOG_FILE.exists(): + with open(CALL_LOG_FILE, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + try: + call_history.append(json.loads(line)) + except json.JSONDecodeError: + continue + print(f"Loaded {len(call_history)} historical records from {CALL_LOG_FILE}") + +_load_history() + class Message(BaseModel): role: str @@ -58,15 +81,6 @@ class ChatResponse(BaseModel): latency_ms: float -class StatsResponse(BaseModel): - total_calls: int - total_cost_usd: float - avg_latency_ms: float - model_distribution: Dict[str, int] - provider_distribution: Dict[str, int] - recent_calls: List[Dict[str, Any]] - - app = FastAPI( title="LLM Router MVP", description="基于 LiteLLM + NVIDIA 分类器的多提供商路由服务(支持3-tier智能路由)", @@ -102,27 +116,46 @@ def select_model_by_length(messages: List[Message]) -> str: return DEFAULT_ROUTING["complex"] -def select_model_by_nvidia_classifier(messages: List[Message]) -> str: +def select_model_by_nvidia_classifier(messages: List[Message]) -> tuple: """ 基于 NVIDIA 多头分类器选择模型(3-tier路由) - NVIDIA 输出: 多维度复杂度评分 - 映射到 Qwen 模型: - - simple -> qwen-flash (简单任务) - - medium -> qwen-plus (中等任务) - - complex -> qwen-max (复杂任务) + Returns: + (model_key, routing_detail) - 模型名称 + 路由分类细节 """ - # 取最后一条用户消息作为查询 query = messages[-1].content if messages else "" try: router = get_router() - model = router.select_model(query) - return model + start = time.time() + result = router.predict(query) + routing_ms = (time.time() - start) * 1000 + + model_map = {"simple": "qwen-flash", "medium": "qwen-plus", "complex": "qwen-max"} + model_key = model_map[result["tier"]] + + routing_detail = { + "method": "nvidia_classifier", + "query": query, + "routing_latency_ms": round(routing_ms, 2), + "tier": result["tier"], + "complexity_score": result["complexity_score"], + "task_type": result["task_type"], + "domain_knowledge": result["domain_knowledge"], + "reasoning": result["reasoning"], + "creativity": result["creativity"], + } + return model_key, routing_detail except Exception as e: - # NVIDIA 分类器失败时回退到 token 长度策略 print(f"NVIDIA routing failed: {e}, falling back to token length") - return select_model_by_length(messages) + model_key = select_model_by_length(messages) + routing_detail = { + "method": "fallback_token_length", + "query": query, + "routing_latency_ms": 0, + "error": str(e), + } + return model_key, routing_detail def get_provider_model(model_key: str) -> str: @@ -154,61 +187,114 @@ def get_provider_from_model(model_name: str) -> str: return "unknown" -def log_call(model: str, provider: str, cost: float, latency_ms: float, tokens: int): - """记录调用历史""" - call_history.append({ - "model": model, - "provider": provider, - "cost_usd": cost, - "latency_ms": latency_ms, - "tokens": tokens, +def log_call( + model: str, + provider: str, + cost: float, + latency_ms: float, + input_tokens: int, + output_tokens: int, + messages: List[Dict[str, str]], + response_content: str, + response_id: str, + routing_detail: Optional[Dict[str, Any]], + request_params: Dict[str, Any], +): + """记录完整调用历史(含路由细节 + LLM 原始数据,供后续调优)""" + record = { "timestamp": time.time(), - }) + # 请求 + "request": { + "messages": messages, + "temperature": request_params.get("temperature"), + "max_tokens": request_params.get("max_tokens"), + "user_specified_model": request_params.get("user_specified_model"), + }, + # 路由决策 + "routing": routing_detail, + # LLM 调用 + "llm": { + "model": model, + "provider": provider, + "response_id": response_id, + "response_content": response_content, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, + "cost_usd": cost, + "llm_latency_ms": round(latency_ms, 2), + }, + } + call_history.append(record) + + # 追加写入 JSONL 文件 + with open(CALL_LOG_FILE, "a", encoding="utf-8") as f: + f.write(json.dumps(record, ensure_ascii=False) + "\n") @app.post("/v1/chat/completions", response_model=ChatResponse) async def chat_completions(request: ChatRequest): """ 聊天完成接口 - 如果 request.model 未指定,则根据 token 长度自动路由 + 如果 request.model 未指定,则使用 NVIDIA 分类器智能路由 """ - # 选择模型 + routing_detail = None + if request.model: model_key = request.model + routing_detail = { + "method": "user_specified", + "query": request.messages[-1].content if request.messages else "", + } else: - # 使用 NVIDIA 多头分类器智能路由(支持3-tier) - model_key = select_model_by_nvidia_classifier(request.messages) + model_key, routing_detail = select_model_by_nvidia_classifier(request.messages) # 获取 LiteLLM 模型名称 provider_model = get_provider_model(model_key) provider = get_provider_from_model(provider_model) + messages_raw = [{"role": m.role, "content": m.content} for m in request.messages] start_time = time.time() try: - # 使用 LiteLLM 统一调用 response = await acompletion( model=provider_model, - messages=[{"role": m.role, "content": m.content} for m in request.messages], + messages=messages_raw, temperature=request.temperature, max_tokens=request.max_tokens, ) latency_ms = (time.time() - start_time) * 1000 - # 计算成本 input_tokens = response.usage.prompt_tokens output_tokens = response.usage.completion_tokens cost = calculate_cost(model_key, input_tokens, output_tokens) + response_content = response.choices[0].message.content - # 记录调用 - log_call(model_key, provider, cost, latency_ms, input_tokens + output_tokens) + # 记录完整调用数据 + log_call( + model=model_key, + provider=provider, + cost=cost, + latency_ms=latency_ms, + input_tokens=input_tokens, + output_tokens=output_tokens, + messages=messages_raw, + response_content=response_content, + response_id=response.id, + routing_detail=routing_detail, + request_params={ + "temperature": request.temperature, + "max_tokens": request.max_tokens, + "user_specified_model": request.model, + }, + ) return ChatResponse( id=response.id, model=model_key, provider=provider, - content=response.choices[0].message.content, + content=response_content, usage={ "prompt_tokens": input_tokens, "completion_tokens": output_tokens, @@ -238,52 +324,73 @@ async def list_models(): } -@app.get("/stats", response_model=StatsResponse) +@app.get("/stats") async def get_stats(): - """获取调用统计""" + """获取调用统计摘要""" if not call_history: - return StatsResponse( - total_calls=0, - total_cost_usd=0.0, - avg_latency_ms=0.0, - model_distribution={}, - provider_distribution={}, - recent_calls=[], - ) + return { + "total_calls": 0, + "total_cost_usd": 0.0, + "avg_latency_ms": 0.0, + "model_distribution": {}, + "tier_distribution": {}, + "task_type_distribution": {}, + } total_calls = len(call_history) - total_cost = sum(c["cost_usd"] for c in call_history) - avg_latency = sum(c["latency_ms"] for c in call_history) / total_calls + total_cost = sum(c["llm"]["cost_usd"] for c in call_history) + avg_latency = sum(c["llm"]["llm_latency_ms"] for c in call_history) / total_calls - # 模型分布 model_dist: Dict[str, int] = {} - provider_dist: Dict[str, int] = {} + tier_dist: Dict[str, int] = {} + task_dist: Dict[str, int] = {} + for call in call_history: - model = call["model"] - provider = call["provider"] + model = call["llm"]["model"] model_dist[model] = model_dist.get(model, 0) + 1 - provider_dist[provider] = provider_dist.get(provider, 0) + 1 + + routing = call.get("routing") or {} + if routing.get("tier"): + tier = routing["tier"] + tier_dist[tier] = tier_dist.get(tier, 0) + 1 + if routing.get("task_type"): + task = routing["task_type"] + task_dist[task] = task_dist.get(task, 0) + 1 - # 最近 10 条记录 - recent = [ - { - "model": c["model"], - "provider": c["provider"], - "cost_usd": round(c["cost_usd"], 6), - "latency_ms": round(c["latency_ms"], 2), - "tokens": c["tokens"], - } - for c in call_history[-10:] - ] + return { + "total_calls": total_calls, + "total_cost_usd": round(total_cost, 6), + "avg_latency_ms": round(avg_latency, 2), + "avg_routing_ms": round( + sum(c.get("routing", {}).get("routing_latency_ms", 0) for c in call_history) / total_calls, 2 + ), + "model_distribution": model_dist, + "tier_distribution": tier_dist, + "task_type_distribution": task_dist, + } + + +@app.get("/stats/raw") +async def get_stats_raw(limit: int = 50, offset: int = 0): + """ + 获取原始调用记录(含路由分类细节 + LLM 完整数据) + 用于后续调优和分析 - return StatsResponse( - total_calls=total_calls, - total_cost_usd=round(total_cost, 6), - avg_latency_ms=round(avg_latency, 2), - model_distribution=model_dist, - provider_distribution=provider_dist, - recent_calls=recent, - ) + 参数: + - limit: 返回条数(默认50) + - offset: 偏移量(默认0,从最新开始) + """ + total = len(call_history) + # 倒序返回(最新在前) + records = list(reversed(call_history)) + page = records[offset:offset + limit] + + return { + "total": total, + "limit": limit, + "offset": offset, + "records": page, + } @app.get("/health")