feat(api): 重写为OpenAI兼容API并支持流式SSE返回
- 请求/响应完全对齐OpenAI Chat Completions API格式
- 支持 stream=true SSE流式返回 (data: {...}\n\n + [DONE])
- 新增 /v1/models 接口 (OpenAI格式 object:list)
- 非流式响应扩展 routing 字段暴露路由决策细节
- OpenAI Python SDK可直接对接 (base_url=http://localhost:8000/v1)
- 版本升级至v0.4.0
This commit is contained in:
382
main.py
382
main.py
@@ -1,17 +1,18 @@
|
|||||||
"""
|
"""
|
||||||
MVP版 LLM 路由服务
|
LLM Router 服务 - OpenAI 兼容 API
|
||||||
基于 LiteLLM 的多提供商统一接口
|
基于 LiteLLM + NVIDIA 多头分类器的智能路由服务
|
||||||
支持: OpenAI, Anthropic, Gemini, Ollama 等 100+ 提供商
|
支持 OpenAI Chat Completions API 格式(含流式返回)
|
||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
import os
|
import uuid
|
||||||
import tiktoken
|
import tiktoken
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from pydantic import BaseModel
|
from fastapi.responses import StreamingResponse
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from litellm import acompletion
|
from litellm import acompletion
|
||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
@@ -21,7 +22,6 @@ from nvidia_router import get_nvidia_router, select_model_by_nvidia
|
|||||||
# 配置 LiteLLM 使用 DashScope (Qwen)
|
# 配置 LiteLLM 使用 DashScope (Qwen)
|
||||||
if DASHSCOPE_API_KEY:
|
if DASHSCOPE_API_KEY:
|
||||||
litellm.api_key = DASHSCOPE_API_KEY
|
litellm.api_key = DASHSCOPE_API_KEY
|
||||||
# Qwen 使用 OpenAI 兼容接口,但需要通过 api_base 指定
|
|
||||||
litellm.api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
litellm.api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||||
|
|
||||||
# NVIDIA Router 实例(延迟加载)
|
# NVIDIA Router 实例(延迟加载)
|
||||||
@@ -35,16 +35,14 @@ def get_router():
|
|||||||
return _nvidia_router
|
return _nvidia_router
|
||||||
|
|
||||||
|
|
||||||
# 调用历史 - JSON 文件持久化
|
# ── 调用历史 - JSONL 持久化 ────────────────────────────────
|
||||||
CALL_LOG_DIR = Path(__file__).parent / "data"
|
CALL_LOG_DIR = Path(__file__).parent / "data"
|
||||||
CALL_LOG_DIR.mkdir(exist_ok=True)
|
CALL_LOG_DIR.mkdir(exist_ok=True)
|
||||||
CALL_LOG_FILE = CALL_LOG_DIR / "call_history.jsonl"
|
CALL_LOG_FILE = CALL_LOG_DIR / "call_history.jsonl"
|
||||||
|
|
||||||
# 内存缓存(启动时从文件加载)
|
|
||||||
call_history: List[Dict[str, Any]] = []
|
call_history: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
def _load_history():
|
def _load_history():
|
||||||
"""启动时从 JSONL 文件加载历史记录"""
|
|
||||||
if CALL_LOG_FILE.exists():
|
if CALL_LOG_FILE.exists():
|
||||||
with open(CALL_LOG_FILE, "r", encoding="utf-8") as f:
|
with open(CALL_LOG_FILE, "r", encoding="utf-8") as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
@@ -59,55 +57,52 @@ def _load_history():
|
|||||||
_load_history()
|
_load_history()
|
||||||
|
|
||||||
|
|
||||||
class Message(BaseModel):
|
# ── OpenAI 兼容请求/响应模型 ────────────────────────────────
|
||||||
|
class ChatMessage(BaseModel):
|
||||||
role: str
|
role: str
|
||||||
content: str
|
content: Optional[str] = None
|
||||||
|
name: Optional[str] = None
|
||||||
|
|
||||||
|
class ChatCompletionRequest(BaseModel):
|
||||||
class ChatRequest(BaseModel):
|
model: Optional[str] = None
|
||||||
messages: List[Message]
|
messages: List[ChatMessage]
|
||||||
model: Optional[str] = None # 可选,如果指定则跳过路由
|
temperature: Optional[float] = 0.7
|
||||||
temperature: float = 0.7
|
|
||||||
max_tokens: Optional[int] = None
|
max_tokens: Optional[int] = None
|
||||||
|
stream: Optional[bool] = False
|
||||||
|
top_p: Optional[float] = 1.0
|
||||||
|
n: Optional[int] = 1
|
||||||
|
stop: Optional[Any] = None
|
||||||
|
presence_penalty: Optional[float] = 0.0
|
||||||
|
frequency_penalty: Optional[float] = 0.0
|
||||||
|
user: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class ChatResponse(BaseModel):
|
# ── FastAPI App ──────────────────────────────────────────────
|
||||||
id: str
|
|
||||||
model: str
|
|
||||||
provider: str
|
|
||||||
content: str
|
|
||||||
usage: Dict[str, int]
|
|
||||||
cost_usd: float
|
|
||||||
latency_ms: float
|
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="LLM Router MVP",
|
title="LLM Router",
|
||||||
description="基于 LiteLLM + NVIDIA 分类器的多提供商路由服务(支持3-tier智能路由)",
|
description="OpenAI 兼容的 LLM 智能路由服务(NVIDIA 3-tier 分类 + LiteLLM 多提供商)",
|
||||||
version="0.3.0",
|
version="0.4.0",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def estimate_tokens(messages: List[Message]) -> int:
|
# ── 辅助函数 ─────────────────────────────────────────────────
|
||||||
"""估算 token 数量"""
|
def estimate_tokens(messages: List[ChatMessage]) -> int:
|
||||||
try:
|
try:
|
||||||
encoding = tiktoken.encoding_for_model("gpt-4")
|
encoding = tiktoken.encoding_for_model("gpt-4")
|
||||||
except KeyError:
|
except KeyError:
|
||||||
encoding = tiktoken.get_encoding("cl100k_base")
|
encoding = tiktoken.get_encoding("cl100k_base")
|
||||||
|
total = 0
|
||||||
total_tokens = 0
|
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
total_tokens += 4
|
total += 4
|
||||||
total_tokens += len(encoding.encode(msg.content))
|
if msg.content:
|
||||||
total_tokens += len(encoding.encode(msg.role))
|
total += len(encoding.encode(msg.content))
|
||||||
total_tokens += 2
|
total += len(encoding.encode(msg.role))
|
||||||
return total_tokens
|
total += 2
|
||||||
|
return total
|
||||||
|
|
||||||
|
|
||||||
def select_model_by_length(messages: List[Message]) -> str:
|
def select_model_by_length(messages: List[ChatMessage]) -> str:
|
||||||
"""基于 token 长度选择模型(备用策略)"""
|
|
||||||
token_count = estimate_tokens(messages)
|
token_count = estimate_tokens(messages)
|
||||||
|
|
||||||
if token_count < ROUTING_THRESHOLDS["simple"]:
|
if token_count < ROUTING_THRESHOLDS["simple"]:
|
||||||
return DEFAULT_ROUTING["simple"]
|
return DEFAULT_ROUTING["simple"]
|
||||||
elif token_count < ROUTING_THRESHOLDS["medium"]:
|
elif token_count < ROUTING_THRESHOLDS["medium"]:
|
||||||
@@ -116,14 +111,16 @@ def select_model_by_length(messages: List[Message]) -> str:
|
|||||||
return DEFAULT_ROUTING["complex"]
|
return DEFAULT_ROUTING["complex"]
|
||||||
|
|
||||||
|
|
||||||
def select_model_by_nvidia_classifier(messages: List[Message]) -> tuple:
|
def select_model_by_nvidia_classifier(messages: List[ChatMessage]) -> tuple:
|
||||||
"""
|
"""
|
||||||
基于 NVIDIA 多头分类器选择模型(3-tier路由)
|
基于 NVIDIA 多头分类器选择模型
|
||||||
|
Returns: (model_key, routing_detail)
|
||||||
Returns:
|
|
||||||
(model_key, routing_detail) - 模型名称 + 路由分类细节
|
|
||||||
"""
|
"""
|
||||||
query = messages[-1].content if messages else ""
|
query = ""
|
||||||
|
for msg in reversed(messages):
|
||||||
|
if msg.role == "user" and msg.content:
|
||||||
|
query = msg.content
|
||||||
|
break
|
||||||
|
|
||||||
try:
|
try:
|
||||||
router = get_router()
|
router = get_router()
|
||||||
@@ -159,7 +156,6 @@ def select_model_by_nvidia_classifier(messages: List[Message]) -> tuple:
|
|||||||
|
|
||||||
|
|
||||||
def get_provider_model(model_key: str) -> str:
|
def get_provider_model(model_key: str) -> str:
|
||||||
"""获取 LiteLLM 格式的模型名称"""
|
|
||||||
config = MODEL_CONFIG.get(model_key)
|
config = MODEL_CONFIG.get(model_key)
|
||||||
if not config:
|
if not config:
|
||||||
raise HTTPException(status_code=400, detail=f"Unknown model: {model_key}")
|
raise HTTPException(status_code=400, detail=f"Unknown model: {model_key}")
|
||||||
@@ -167,55 +163,35 @@ def get_provider_model(model_key: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def calculate_cost(model_key: str, input_tokens: int, output_tokens: int) -> float:
|
def calculate_cost(model_key: str, input_tokens: int, output_tokens: int) -> float:
|
||||||
"""计算调用成本"""
|
|
||||||
config = MODEL_CONFIG.get(model_key, MODEL_CONFIG["gpt-4o"])
|
config = MODEL_CONFIG.get(model_key, MODEL_CONFIG["gpt-4o"])
|
||||||
input_cost = (input_tokens / 1000) * config["input_cost"]
|
return (input_tokens / 1000) * config["input_cost"] + (output_tokens / 1000) * config["output_cost"]
|
||||||
output_cost = (output_tokens / 1000) * config["output_cost"]
|
|
||||||
return input_cost + output_cost
|
|
||||||
|
|
||||||
|
|
||||||
def get_provider_from_model(model_name: str) -> str:
|
|
||||||
"""从模型名称推断提供商"""
|
|
||||||
if model_name.startswith("gpt"):
|
|
||||||
return "openai"
|
|
||||||
elif model_name.startswith("claude"):
|
|
||||||
return "anthropic"
|
|
||||||
elif model_name.startswith("gemini"):
|
|
||||||
return "google"
|
|
||||||
elif "/" in model_name:
|
|
||||||
return model_name.split("/")[0]
|
|
||||||
return "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
def log_call(
|
def log_call(
|
||||||
model: str,
|
model: str,
|
||||||
provider: str,
|
|
||||||
cost: float,
|
cost: float,
|
||||||
latency_ms: float,
|
latency_ms: float,
|
||||||
input_tokens: int,
|
input_tokens: int,
|
||||||
output_tokens: int,
|
output_tokens: int,
|
||||||
messages: List[Dict[str, str]],
|
messages_raw: List[Dict],
|
||||||
response_content: str,
|
response_content: str,
|
||||||
response_id: str,
|
response_id: str,
|
||||||
routing_detail: Optional[Dict[str, Any]],
|
routing_detail: Optional[Dict],
|
||||||
request_params: Dict[str, Any],
|
request_params: Dict,
|
||||||
|
stream: bool = False,
|
||||||
):
|
):
|
||||||
"""记录完整调用历史(含路由细节 + LLM 原始数据,供后续调优)"""
|
|
||||||
record = {
|
record = {
|
||||||
"timestamp": time.time(),
|
"timestamp": time.time(),
|
||||||
# 请求
|
|
||||||
"request": {
|
"request": {
|
||||||
"messages": messages,
|
"messages": messages_raw,
|
||||||
"temperature": request_params.get("temperature"),
|
"temperature": request_params.get("temperature"),
|
||||||
"max_tokens": request_params.get("max_tokens"),
|
"max_tokens": request_params.get("max_tokens"),
|
||||||
|
"stream": stream,
|
||||||
"user_specified_model": request_params.get("user_specified_model"),
|
"user_specified_model": request_params.get("user_specified_model"),
|
||||||
},
|
},
|
||||||
# 路由决策
|
|
||||||
"routing": routing_detail,
|
"routing": routing_detail,
|
||||||
# LLM 调用
|
|
||||||
"llm": {
|
"llm": {
|
||||||
"model": model,
|
"model": model,
|
||||||
"provider": provider,
|
|
||||||
"response_id": response_id,
|
"response_id": response_id,
|
||||||
"response_content": response_content,
|
"response_content": response_content,
|
||||||
"input_tokens": input_tokens,
|
"input_tokens": input_tokens,
|
||||||
@@ -226,36 +202,85 @@ def log_call(
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
call_history.append(record)
|
call_history.append(record)
|
||||||
|
|
||||||
# 追加写入 JSONL 文件
|
|
||||||
with open(CALL_LOG_FILE, "a", encoding="utf-8") as f:
|
with open(CALL_LOG_FILE, "a", encoding="utf-8") as f:
|
||||||
f.write(json.dumps(record, ensure_ascii=False) + "\n")
|
f.write(json.dumps(record, ensure_ascii=False) + "\n")
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions", response_model=ChatResponse)
|
def build_openai_response(
|
||||||
async def chat_completions(request: ChatRequest):
|
response_id: str,
|
||||||
"""
|
model: str,
|
||||||
聊天完成接口
|
content: str,
|
||||||
如果 request.model 未指定,则使用 NVIDIA 分类器智能路由
|
input_tokens: int,
|
||||||
"""
|
output_tokens: int,
|
||||||
routing_detail = None
|
routing_detail: Optional[Dict] = None,
|
||||||
|
) -> Dict:
|
||||||
|
"""构建 OpenAI 格式的非流式响应"""
|
||||||
|
resp = {
|
||||||
|
"id": response_id,
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": int(time.time()),
|
||||||
|
"model": model,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"message": {"role": "assistant", "content": content},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": input_tokens,
|
||||||
|
"completion_tokens": output_tokens,
|
||||||
|
"total_tokens": input_tokens + output_tokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
# 路由细节作为扩展字段
|
||||||
|
if routing_detail:
|
||||||
|
resp["routing"] = routing_detail
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
# ── 核心 API: /v1/chat/completions ──────────────────────────
|
||||||
|
@app.post("/v1/chat/completions")
|
||||||
|
async def chat_completions(request: ChatCompletionRequest):
|
||||||
|
"""
|
||||||
|
OpenAI 兼容的 Chat Completions API
|
||||||
|
- model 为空时自动使用 NVIDIA 分类器路由
|
||||||
|
- stream=true 时返回 SSE 流式响应
|
||||||
|
"""
|
||||||
|
# 1. 路由决策
|
||||||
|
routing_detail = None
|
||||||
if request.model:
|
if request.model:
|
||||||
model_key = request.model
|
model_key = request.model
|
||||||
routing_detail = {
|
routing_detail = {
|
||||||
"method": "user_specified",
|
"method": "user_specified",
|
||||||
"query": request.messages[-1].content if request.messages else "",
|
"query": next((m.content for m in reversed(request.messages) if m.role == "user" and m.content), ""),
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
model_key, routing_detail = select_model_by_nvidia_classifier(request.messages)
|
model_key, routing_detail = select_model_by_nvidia_classifier(request.messages)
|
||||||
|
|
||||||
# 获取 LiteLLM 模型名称
|
|
||||||
provider_model = get_provider_model(model_key)
|
provider_model = get_provider_model(model_key)
|
||||||
provider = get_provider_from_model(provider_model)
|
|
||||||
messages_raw = [{"role": m.role, "content": m.content} for m in request.messages]
|
messages_raw = [{"role": m.role, "content": m.content} for m in request.messages]
|
||||||
|
response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
|
||||||
|
|
||||||
|
# 2. 流式响应
|
||||||
|
if request.stream:
|
||||||
|
return StreamingResponse(
|
||||||
|
_stream_response(
|
||||||
|
provider_model=provider_model,
|
||||||
|
model_key=model_key,
|
||||||
|
messages_raw=messages_raw,
|
||||||
|
request=request,
|
||||||
|
response_id=response_id,
|
||||||
|
routing_detail=routing_detail,
|
||||||
|
),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 非流式响应
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await acompletion(
|
response = await acompletion(
|
||||||
model=provider_model,
|
model=provider_model,
|
||||||
@@ -263,26 +288,18 @@ async def chat_completions(request: ChatRequest):
|
|||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
max_tokens=request.max_tokens,
|
max_tokens=request.max_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
latency_ms = (time.time() - start_time) * 1000
|
latency_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
input_tokens = response.usage.prompt_tokens
|
input_tokens = response.usage.prompt_tokens
|
||||||
output_tokens = response.usage.completion_tokens
|
output_tokens = response.usage.completion_tokens
|
||||||
cost = calculate_cost(model_key, input_tokens, output_tokens)
|
cost = calculate_cost(model_key, input_tokens, output_tokens)
|
||||||
response_content = response.choices[0].message.content
|
content = response.choices[0].message.content
|
||||||
|
|
||||||
# 记录完整调用数据
|
|
||||||
log_call(
|
log_call(
|
||||||
model=model_key,
|
model=model_key, cost=cost, latency_ms=latency_ms,
|
||||||
provider=provider,
|
input_tokens=input_tokens, output_tokens=output_tokens,
|
||||||
cost=cost,
|
messages_raw=messages_raw, response_content=content,
|
||||||
latency_ms=latency_ms,
|
response_id=response_id, routing_detail=routing_detail,
|
||||||
input_tokens=input_tokens,
|
|
||||||
output_tokens=output_tokens,
|
|
||||||
messages=messages_raw,
|
|
||||||
response_content=response_content,
|
|
||||||
response_id=response.id,
|
|
||||||
routing_detail=routing_detail,
|
|
||||||
request_params={
|
request_params={
|
||||||
"temperature": request.temperature,
|
"temperature": request.temperature,
|
||||||
"max_tokens": request.max_tokens,
|
"max_tokens": request.max_tokens,
|
||||||
@@ -290,51 +307,122 @@ async def chat_completions(request: ChatRequest):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return ChatResponse(
|
return build_openai_response(response_id, model_key, content, input_tokens, output_tokens, routing_detail)
|
||||||
id=response.id,
|
|
||||||
model=model_key,
|
|
||||||
provider=provider,
|
|
||||||
content=response_content,
|
|
||||||
usage={
|
|
||||||
"prompt_tokens": input_tokens,
|
|
||||||
"completion_tokens": output_tokens,
|
|
||||||
"total_tokens": input_tokens + output_tokens,
|
|
||||||
},
|
|
||||||
cost_usd=cost,
|
|
||||||
latency_ms=round(latency_ms, 2),
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"API error: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"API error: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/models")
|
async def _stream_response(
|
||||||
|
provider_model: str,
|
||||||
|
model_key: str,
|
||||||
|
messages_raw: List[Dict],
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
response_id: str,
|
||||||
|
routing_detail: Optional[Dict],
|
||||||
|
):
|
||||||
|
"""生成 SSE 流式响应"""
|
||||||
|
start_time = time.time()
|
||||||
|
collected_content = ""
|
||||||
|
input_tokens = 0
|
||||||
|
output_tokens = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await acompletion(
|
||||||
|
model=provider_model,
|
||||||
|
messages=messages_raw,
|
||||||
|
temperature=request.temperature,
|
||||||
|
max_tokens=request.max_tokens,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async for chunk in response:
|
||||||
|
delta = chunk.choices[0].delta
|
||||||
|
|
||||||
|
# 收集内容用于日志
|
||||||
|
if delta.content:
|
||||||
|
collected_content += delta.content
|
||||||
|
output_tokens += 1 # 近似计数
|
||||||
|
|
||||||
|
# 构建 SSE 数据
|
||||||
|
chunk_data = {
|
||||||
|
"id": response_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": int(time.time()),
|
||||||
|
"model": model_key,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {},
|
||||||
|
"finish_reason": None,
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
|
||||||
|
if delta.content:
|
||||||
|
chunk_data["choices"][0]["delta"] = {"content": delta.content}
|
||||||
|
elif delta.role:
|
||||||
|
chunk_data["choices"][0]["delta"] = {"role": delta.role}
|
||||||
|
|
||||||
|
if chunk.choices[0].finish_reason:
|
||||||
|
chunk_data["choices"][0]["finish_reason"] = chunk.choices[0].finish_reason
|
||||||
|
|
||||||
|
yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
# 发送 [DONE]
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
# 记录日志
|
||||||
|
latency_ms = (time.time() - start_time) * 1000
|
||||||
|
# 流式模式下用 tiktoken 近似计算 input_tokens
|
||||||
|
try:
|
||||||
|
encoding = tiktoken.get_encoding("cl100k_base")
|
||||||
|
input_tokens = sum(len(encoding.encode(m.get("content", "") or "")) for m in messages_raw) + len(messages_raw) * 4
|
||||||
|
except Exception:
|
||||||
|
input_tokens = 0
|
||||||
|
|
||||||
|
cost = calculate_cost(model_key, input_tokens, output_tokens)
|
||||||
|
log_call(
|
||||||
|
model=model_key, cost=cost, latency_ms=latency_ms,
|
||||||
|
input_tokens=input_tokens, output_tokens=output_tokens,
|
||||||
|
messages_raw=messages_raw, response_content=collected_content,
|
||||||
|
response_id=response_id, routing_detail=routing_detail,
|
||||||
|
request_params={
|
||||||
|
"temperature": request.temperature,
|
||||||
|
"max_tokens": request.max_tokens,
|
||||||
|
"user_specified_model": request.model,
|
||||||
|
},
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_data = {"error": {"message": str(e), "type": "api_error"}}
|
||||||
|
yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
# ── OpenAI 兼容: /v1/models ─────────────────────────────────
|
||||||
|
@app.get("/v1/models")
|
||||||
async def list_models():
|
async def list_models():
|
||||||
"""列出支持的模型"""
|
"""OpenAI 兼容的模型列表接口"""
|
||||||
return {
|
data = []
|
||||||
"models": [
|
for key, config in MODEL_CONFIG.items():
|
||||||
{
|
data.append({
|
||||||
"key": key,
|
"id": key,
|
||||||
"provider": config["provider"],
|
"object": "model",
|
||||||
"input_cost_per_1k": config["input_cost"],
|
"created": 1700000000,
|
||||||
"output_cost_per_1k": config["output_cost"],
|
"owned_by": config["provider"].split("/")[0] if "/" in config["provider"] else "unknown",
|
||||||
}
|
})
|
||||||
for key, config in MODEL_CONFIG.items()
|
return {"object": "list", "data": data}
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
# ── 管理接口 ─────────────────────────────────────────────────
|
||||||
@app.get("/stats")
|
@app.get("/stats")
|
||||||
async def get_stats():
|
async def get_stats():
|
||||||
"""获取调用统计摘要"""
|
"""获取调用统计摘要"""
|
||||||
if not call_history:
|
if not call_history:
|
||||||
return {
|
return {
|
||||||
"total_calls": 0,
|
"total_calls": 0, "total_cost_usd": 0.0,
|
||||||
"total_cost_usd": 0.0,
|
"avg_latency_ms": 0.0, "model_distribution": {},
|
||||||
"avg_latency_ms": 0.0,
|
"tier_distribution": {}, "task_type_distribution": {},
|
||||||
"model_distribution": {},
|
|
||||||
"tier_distribution": {},
|
|
||||||
"task_type_distribution": {},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
total_calls = len(call_history)
|
total_calls = len(call_history)
|
||||||
@@ -348,14 +436,11 @@ async def get_stats():
|
|||||||
for call in call_history:
|
for call in call_history:
|
||||||
model = call["llm"]["model"]
|
model = call["llm"]["model"]
|
||||||
model_dist[model] = model_dist.get(model, 0) + 1
|
model_dist[model] = model_dist.get(model, 0) + 1
|
||||||
|
|
||||||
routing = call.get("routing") or {}
|
routing = call.get("routing") or {}
|
||||||
if routing.get("tier"):
|
if routing.get("tier"):
|
||||||
tier = routing["tier"]
|
tier_dist[routing["tier"]] = tier_dist.get(routing["tier"], 0) + 1
|
||||||
tier_dist[tier] = tier_dist.get(tier, 0) + 1
|
|
||||||
if routing.get("task_type"):
|
if routing.get("task_type"):
|
||||||
task = routing["task_type"]
|
task_dist[routing["task_type"]] = task_dist.get(routing["task_type"], 0) + 1
|
||||||
task_dist[task] = task_dist.get(task, 0) + 1
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"total_calls": total_calls,
|
"total_calls": total_calls,
|
||||||
@@ -372,31 +457,16 @@ async def get_stats():
|
|||||||
|
|
||||||
@app.get("/stats/raw")
|
@app.get("/stats/raw")
|
||||||
async def get_stats_raw(limit: int = 50, offset: int = 0):
|
async def get_stats_raw(limit: int = 50, offset: int = 0):
|
||||||
"""
|
"""获取原始调用记录"""
|
||||||
获取原始调用记录(含路由分类细节 + LLM 完整数据)
|
|
||||||
用于后续调优和分析
|
|
||||||
|
|
||||||
参数:
|
|
||||||
- limit: 返回条数(默认50)
|
|
||||||
- offset: 偏移量(默认0,从最新开始)
|
|
||||||
"""
|
|
||||||
total = len(call_history)
|
total = len(call_history)
|
||||||
# 倒序返回(最新在前)
|
|
||||||
records = list(reversed(call_history))
|
records = list(reversed(call_history))
|
||||||
page = records[offset:offset + limit]
|
return {"total": total, "limit": limit, "offset": offset, "records": records[offset:offset + limit]}
|
||||||
|
|
||||||
return {
|
|
||||||
"total": total,
|
|
||||||
"limit": limit,
|
|
||||||
"offset": offset,
|
|
||||||
"records": page,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health_check():
|
async def health_check():
|
||||||
"""健康检查"""
|
"""健康检查"""
|
||||||
return {"status": "healthy", "version": "0.3.0", "router": "nvidia-3tier"}
|
return {"status": "healthy", "version": "0.4.0", "router": "nvidia-3tier"}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user