feat(stats): 完善调用记录详情并持久化到JSONL文件
- log_call保存完整request/routing/llm三层数据(含NVIDIA分类原始输出) - 新增/stats/raw接口返回原始调用记录(支持分页) - /stats摘要新增tier_distribution、task_type_distribution、avg_routing_ms - 调用历史持久化到data/call_history.jsonl,重启自动恢复 - data/目录加入.gitignore
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -10,6 +10,9 @@ __pycache__/
|
|||||||
.env
|
.env
|
||||||
.venv
|
.venv
|
||||||
|
|
||||||
|
# Data (call history logs)
|
||||||
|
data/
|
||||||
|
|
||||||
# IDE
|
# IDE
|
||||||
.vscode/
|
.vscode/
|
||||||
.idea/
|
.idea/
|
||||||
|
|||||||
259
main.py
259
main.py
@@ -4,8 +4,11 @@ MVP版 LLM 路由服务
|
|||||||
支持: OpenAI, Anthropic, Gemini, Ollama 等 100+ 提供商
|
支持: OpenAI, Anthropic, Gemini, Ollama 等 100+ 提供商
|
||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
|
import json
|
||||||
|
import os
|
||||||
import tiktoken
|
import tiktoken
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -32,9 +35,29 @@ def get_router():
|
|||||||
return _nvidia_router
|
return _nvidia_router
|
||||||
|
|
||||||
|
|
||||||
# 调用历史记录
|
# 调用历史 - JSON 文件持久化
|
||||||
|
CALL_LOG_DIR = Path(__file__).parent / "data"
|
||||||
|
CALL_LOG_DIR.mkdir(exist_ok=True)
|
||||||
|
CALL_LOG_FILE = CALL_LOG_DIR / "call_history.jsonl"
|
||||||
|
|
||||||
|
# 内存缓存(启动时从文件加载)
|
||||||
call_history: List[Dict[str, Any]] = []
|
call_history: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
def _load_history():
|
||||||
|
"""启动时从 JSONL 文件加载历史记录"""
|
||||||
|
if CALL_LOG_FILE.exists():
|
||||||
|
with open(CALL_LOG_FILE, "r", encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if line:
|
||||||
|
try:
|
||||||
|
call_history.append(json.loads(line))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
print(f"Loaded {len(call_history)} historical records from {CALL_LOG_FILE}")
|
||||||
|
|
||||||
|
_load_history()
|
||||||
|
|
||||||
|
|
||||||
class Message(BaseModel):
|
class Message(BaseModel):
|
||||||
role: str
|
role: str
|
||||||
@@ -58,15 +81,6 @@ class ChatResponse(BaseModel):
|
|||||||
latency_ms: float
|
latency_ms: float
|
||||||
|
|
||||||
|
|
||||||
class StatsResponse(BaseModel):
|
|
||||||
total_calls: int
|
|
||||||
total_cost_usd: float
|
|
||||||
avg_latency_ms: float
|
|
||||||
model_distribution: Dict[str, int]
|
|
||||||
provider_distribution: Dict[str, int]
|
|
||||||
recent_calls: List[Dict[str, Any]]
|
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="LLM Router MVP",
|
title="LLM Router MVP",
|
||||||
description="基于 LiteLLM + NVIDIA 分类器的多提供商路由服务(支持3-tier智能路由)",
|
description="基于 LiteLLM + NVIDIA 分类器的多提供商路由服务(支持3-tier智能路由)",
|
||||||
@@ -102,27 +116,46 @@ def select_model_by_length(messages: List[Message]) -> str:
|
|||||||
return DEFAULT_ROUTING["complex"]
|
return DEFAULT_ROUTING["complex"]
|
||||||
|
|
||||||
|
|
||||||
def select_model_by_nvidia_classifier(messages: List[Message]) -> str:
|
def select_model_by_nvidia_classifier(messages: List[Message]) -> tuple:
|
||||||
"""
|
"""
|
||||||
基于 NVIDIA 多头分类器选择模型(3-tier路由)
|
基于 NVIDIA 多头分类器选择模型(3-tier路由)
|
||||||
|
|
||||||
NVIDIA 输出: 多维度复杂度评分
|
Returns:
|
||||||
映射到 Qwen 模型:
|
(model_key, routing_detail) - 模型名称 + 路由分类细节
|
||||||
- simple -> qwen-flash (简单任务)
|
|
||||||
- medium -> qwen-plus (中等任务)
|
|
||||||
- complex -> qwen-max (复杂任务)
|
|
||||||
"""
|
"""
|
||||||
# 取最后一条用户消息作为查询
|
|
||||||
query = messages[-1].content if messages else ""
|
query = messages[-1].content if messages else ""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
router = get_router()
|
router = get_router()
|
||||||
model = router.select_model(query)
|
start = time.time()
|
||||||
return model
|
result = router.predict(query)
|
||||||
|
routing_ms = (time.time() - start) * 1000
|
||||||
|
|
||||||
|
model_map = {"simple": "qwen-flash", "medium": "qwen-plus", "complex": "qwen-max"}
|
||||||
|
model_key = model_map[result["tier"]]
|
||||||
|
|
||||||
|
routing_detail = {
|
||||||
|
"method": "nvidia_classifier",
|
||||||
|
"query": query,
|
||||||
|
"routing_latency_ms": round(routing_ms, 2),
|
||||||
|
"tier": result["tier"],
|
||||||
|
"complexity_score": result["complexity_score"],
|
||||||
|
"task_type": result["task_type"],
|
||||||
|
"domain_knowledge": result["domain_knowledge"],
|
||||||
|
"reasoning": result["reasoning"],
|
||||||
|
"creativity": result["creativity"],
|
||||||
|
}
|
||||||
|
return model_key, routing_detail
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# NVIDIA 分类器失败时回退到 token 长度策略
|
|
||||||
print(f"NVIDIA routing failed: {e}, falling back to token length")
|
print(f"NVIDIA routing failed: {e}, falling back to token length")
|
||||||
return select_model_by_length(messages)
|
model_key = select_model_by_length(messages)
|
||||||
|
routing_detail = {
|
||||||
|
"method": "fallback_token_length",
|
||||||
|
"query": query,
|
||||||
|
"routing_latency_ms": 0,
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
return model_key, routing_detail
|
||||||
|
|
||||||
|
|
||||||
def get_provider_model(model_key: str) -> str:
|
def get_provider_model(model_key: str) -> str:
|
||||||
@@ -154,61 +187,114 @@ def get_provider_from_model(model_name: str) -> str:
|
|||||||
return "unknown"
|
return "unknown"
|
||||||
|
|
||||||
|
|
||||||
def log_call(model: str, provider: str, cost: float, latency_ms: float, tokens: int):
|
def log_call(
|
||||||
"""记录调用历史"""
|
model: str,
|
||||||
call_history.append({
|
provider: str,
|
||||||
"model": model,
|
cost: float,
|
||||||
"provider": provider,
|
latency_ms: float,
|
||||||
"cost_usd": cost,
|
input_tokens: int,
|
||||||
"latency_ms": latency_ms,
|
output_tokens: int,
|
||||||
"tokens": tokens,
|
messages: List[Dict[str, str]],
|
||||||
|
response_content: str,
|
||||||
|
response_id: str,
|
||||||
|
routing_detail: Optional[Dict[str, Any]],
|
||||||
|
request_params: Dict[str, Any],
|
||||||
|
):
|
||||||
|
"""记录完整调用历史(含路由细节 + LLM 原始数据,供后续调优)"""
|
||||||
|
record = {
|
||||||
"timestamp": time.time(),
|
"timestamp": time.time(),
|
||||||
})
|
# 请求
|
||||||
|
"request": {
|
||||||
|
"messages": messages,
|
||||||
|
"temperature": request_params.get("temperature"),
|
||||||
|
"max_tokens": request_params.get("max_tokens"),
|
||||||
|
"user_specified_model": request_params.get("user_specified_model"),
|
||||||
|
},
|
||||||
|
# 路由决策
|
||||||
|
"routing": routing_detail,
|
||||||
|
# LLM 调用
|
||||||
|
"llm": {
|
||||||
|
"model": model,
|
||||||
|
"provider": provider,
|
||||||
|
"response_id": response_id,
|
||||||
|
"response_content": response_content,
|
||||||
|
"input_tokens": input_tokens,
|
||||||
|
"output_tokens": output_tokens,
|
||||||
|
"total_tokens": input_tokens + output_tokens,
|
||||||
|
"cost_usd": cost,
|
||||||
|
"llm_latency_ms": round(latency_ms, 2),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
call_history.append(record)
|
||||||
|
|
||||||
|
# 追加写入 JSONL 文件
|
||||||
|
with open(CALL_LOG_FILE, "a", encoding="utf-8") as f:
|
||||||
|
f.write(json.dumps(record, ensure_ascii=False) + "\n")
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions", response_model=ChatResponse)
|
@app.post("/v1/chat/completions", response_model=ChatResponse)
|
||||||
async def chat_completions(request: ChatRequest):
|
async def chat_completions(request: ChatRequest):
|
||||||
"""
|
"""
|
||||||
聊天完成接口
|
聊天完成接口
|
||||||
如果 request.model 未指定,则根据 token 长度自动路由
|
如果 request.model 未指定,则使用 NVIDIA 分类器智能路由
|
||||||
"""
|
"""
|
||||||
# 选择模型
|
routing_detail = None
|
||||||
|
|
||||||
if request.model:
|
if request.model:
|
||||||
model_key = request.model
|
model_key = request.model
|
||||||
|
routing_detail = {
|
||||||
|
"method": "user_specified",
|
||||||
|
"query": request.messages[-1].content if request.messages else "",
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
# 使用 NVIDIA 多头分类器智能路由(支持3-tier)
|
model_key, routing_detail = select_model_by_nvidia_classifier(request.messages)
|
||||||
model_key = select_model_by_nvidia_classifier(request.messages)
|
|
||||||
|
|
||||||
# 获取 LiteLLM 模型名称
|
# 获取 LiteLLM 模型名称
|
||||||
provider_model = get_provider_model(model_key)
|
provider_model = get_provider_model(model_key)
|
||||||
provider = get_provider_from_model(provider_model)
|
provider = get_provider_from_model(provider_model)
|
||||||
|
messages_raw = [{"role": m.role, "content": m.content} for m in request.messages]
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用 LiteLLM 统一调用
|
|
||||||
response = await acompletion(
|
response = await acompletion(
|
||||||
model=provider_model,
|
model=provider_model,
|
||||||
messages=[{"role": m.role, "content": m.content} for m in request.messages],
|
messages=messages_raw,
|
||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
max_tokens=request.max_tokens,
|
max_tokens=request.max_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
latency_ms = (time.time() - start_time) * 1000
|
latency_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
# 计算成本
|
|
||||||
input_tokens = response.usage.prompt_tokens
|
input_tokens = response.usage.prompt_tokens
|
||||||
output_tokens = response.usage.completion_tokens
|
output_tokens = response.usage.completion_tokens
|
||||||
cost = calculate_cost(model_key, input_tokens, output_tokens)
|
cost = calculate_cost(model_key, input_tokens, output_tokens)
|
||||||
|
response_content = response.choices[0].message.content
|
||||||
|
|
||||||
# 记录调用
|
# 记录完整调用数据
|
||||||
log_call(model_key, provider, cost, latency_ms, input_tokens + output_tokens)
|
log_call(
|
||||||
|
model=model_key,
|
||||||
|
provider=provider,
|
||||||
|
cost=cost,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
input_tokens=input_tokens,
|
||||||
|
output_tokens=output_tokens,
|
||||||
|
messages=messages_raw,
|
||||||
|
response_content=response_content,
|
||||||
|
response_id=response.id,
|
||||||
|
routing_detail=routing_detail,
|
||||||
|
request_params={
|
||||||
|
"temperature": request.temperature,
|
||||||
|
"max_tokens": request.max_tokens,
|
||||||
|
"user_specified_model": request.model,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
return ChatResponse(
|
return ChatResponse(
|
||||||
id=response.id,
|
id=response.id,
|
||||||
model=model_key,
|
model=model_key,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
content=response.choices[0].message.content,
|
content=response_content,
|
||||||
usage={
|
usage={
|
||||||
"prompt_tokens": input_tokens,
|
"prompt_tokens": input_tokens,
|
||||||
"completion_tokens": output_tokens,
|
"completion_tokens": output_tokens,
|
||||||
@@ -238,52 +324,73 @@ async def list_models():
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/stats", response_model=StatsResponse)
|
@app.get("/stats")
|
||||||
async def get_stats():
|
async def get_stats():
|
||||||
"""获取调用统计"""
|
"""获取调用统计摘要"""
|
||||||
if not call_history:
|
if not call_history:
|
||||||
return StatsResponse(
|
return {
|
||||||
total_calls=0,
|
"total_calls": 0,
|
||||||
total_cost_usd=0.0,
|
"total_cost_usd": 0.0,
|
||||||
avg_latency_ms=0.0,
|
"avg_latency_ms": 0.0,
|
||||||
model_distribution={},
|
"model_distribution": {},
|
||||||
provider_distribution={},
|
"tier_distribution": {},
|
||||||
recent_calls=[],
|
"task_type_distribution": {},
|
||||||
)
|
}
|
||||||
|
|
||||||
total_calls = len(call_history)
|
total_calls = len(call_history)
|
||||||
total_cost = sum(c["cost_usd"] for c in call_history)
|
total_cost = sum(c["llm"]["cost_usd"] for c in call_history)
|
||||||
avg_latency = sum(c["latency_ms"] for c in call_history) / total_calls
|
avg_latency = sum(c["llm"]["llm_latency_ms"] for c in call_history) / total_calls
|
||||||
|
|
||||||
# 模型分布
|
|
||||||
model_dist: Dict[str, int] = {}
|
model_dist: Dict[str, int] = {}
|
||||||
provider_dist: Dict[str, int] = {}
|
tier_dist: Dict[str, int] = {}
|
||||||
|
task_dist: Dict[str, int] = {}
|
||||||
|
|
||||||
for call in call_history:
|
for call in call_history:
|
||||||
model = call["model"]
|
model = call["llm"]["model"]
|
||||||
provider = call["provider"]
|
|
||||||
model_dist[model] = model_dist.get(model, 0) + 1
|
model_dist[model] = model_dist.get(model, 0) + 1
|
||||||
provider_dist[provider] = provider_dist.get(provider, 0) + 1
|
|
||||||
|
routing = call.get("routing") or {}
|
||||||
|
if routing.get("tier"):
|
||||||
|
tier = routing["tier"]
|
||||||
|
tier_dist[tier] = tier_dist.get(tier, 0) + 1
|
||||||
|
if routing.get("task_type"):
|
||||||
|
task = routing["task_type"]
|
||||||
|
task_dist[task] = task_dist.get(task, 0) + 1
|
||||||
|
|
||||||
# 最近 10 条记录
|
return {
|
||||||
recent = [
|
"total_calls": total_calls,
|
||||||
{
|
"total_cost_usd": round(total_cost, 6),
|
||||||
"model": c["model"],
|
"avg_latency_ms": round(avg_latency, 2),
|
||||||
"provider": c["provider"],
|
"avg_routing_ms": round(
|
||||||
"cost_usd": round(c["cost_usd"], 6),
|
sum(c.get("routing", {}).get("routing_latency_ms", 0) for c in call_history) / total_calls, 2
|
||||||
"latency_ms": round(c["latency_ms"], 2),
|
),
|
||||||
"tokens": c["tokens"],
|
"model_distribution": model_dist,
|
||||||
}
|
"tier_distribution": tier_dist,
|
||||||
for c in call_history[-10:]
|
"task_type_distribution": task_dist,
|
||||||
]
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/stats/raw")
|
||||||
|
async def get_stats_raw(limit: int = 50, offset: int = 0):
|
||||||
|
"""
|
||||||
|
获取原始调用记录(含路由分类细节 + LLM 完整数据)
|
||||||
|
用于后续调优和分析
|
||||||
|
|
||||||
return StatsResponse(
|
参数:
|
||||||
total_calls=total_calls,
|
- limit: 返回条数(默认50)
|
||||||
total_cost_usd=round(total_cost, 6),
|
- offset: 偏移量(默认0,从最新开始)
|
||||||
avg_latency_ms=round(avg_latency, 2),
|
"""
|
||||||
model_distribution=model_dist,
|
total = len(call_history)
|
||||||
provider_distribution=provider_dist,
|
# 倒序返回(最新在前)
|
||||||
recent_calls=recent,
|
records = list(reversed(call_history))
|
||||||
)
|
page = records[offset:offset + limit]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total": total,
|
||||||
|
"limit": limit,
|
||||||
|
"offset": offset,
|
||||||
|
"records": page,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
|
|||||||
Reference in New Issue
Block a user