Compare commits
7 Commits
4c439d2d7e
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 00083627f7 | |||
| d7155e98c3 | |||
| 508118cc50 | |||
| 943fc9dcc0 | |||
| a6a471c5c4 | |||
| 72345871c6 | |||
| 2afe976a31 |
281
README.md
Normal file
281
README.md
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
# LLM Compass
|
||||||
|
|
||||||
|
智能 LLM 路由服务,基于 NVIDIA 多头分类器和 Apple Silicon MPS 加速,为查询自动选择最优模型,兼顾质量与成本。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 项目背景
|
||||||
|
|
||||||
|
在大规模使用 LLM 的场景中,不同复杂度的查询适合不同规格的模型:
|
||||||
|
- 简单问候用 `qwen-flash` 即可,成本低、延迟小
|
||||||
|
- 代码生成需要 `qwen-plus` 保证质量
|
||||||
|
- 复杂分析任务才值得调用 `qwen-max`
|
||||||
|
|
||||||
|
手动选择模型效率低下,而全部使用最强模型又浪费成本。**LLM Compass** 的目标是自动为每个查询选择"刚刚好"的模型。
|
||||||
|
|
||||||
|
灵感来源于 [tx402.ai](https://tx402.ai) 的三层路由架构,本项目采用开源 NVIDIA 多头分类器实现了类似能力。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 核心方法
|
||||||
|
|
||||||
|
### NVIDIA 多头分类器
|
||||||
|
|
||||||
|
采用 [nvidia/prompt-task-and-complexity-classifier](https://huggingface.co/nvidia/prompt-task-and-complexity-classifier)(184M 参数,DeBERTa-v3-base 架构):
|
||||||
|
|
||||||
|
```
|
||||||
|
用户查询 → DeBERTa Backbone → 8个分类头 → 综合评分 → 3-tier路由
|
||||||
|
↓
|
||||||
|
task_type (12类)
|
||||||
|
creativity (3类)
|
||||||
|
reasoning (2类)
|
||||||
|
domain_knowledge (4类)
|
||||||
|
contextual_knowledge
|
||||||
|
number_of_few_shots
|
||||||
|
no_label_reason
|
||||||
|
constraint_ct
|
||||||
|
```
|
||||||
|
|
||||||
|
### 复杂度评分公式
|
||||||
|
|
||||||
|
```python
|
||||||
|
score = (
|
||||||
|
0.4 * domain_knowledge + # High=1.0, Medium=0.6, Low=0.3, No=0.0
|
||||||
|
0.3 * reasoning + # Yes=1.0, No=0.0
|
||||||
|
0.2 * creativity + # High=1.0, Low=0.4, No=0.0
|
||||||
|
0.1 * task_type # Code=0.8, QA=0.5, Chatbot=0.2, ...
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3-tier 路由
|
||||||
|
score < 0.35 → simple → qwen-flash
|
||||||
|
0.35 ≤ score < 0.65 → medium → qwen-plus
|
||||||
|
score ≥ 0.65 → complex → qwen-max
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 技术架构
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────────────────────────────────────────────────┐
|
||||||
|
│ LLM Compass │
|
||||||
|
├──────────────────────────────────────────────────────────┤
|
||||||
|
│ API Layer: FastAPI (OpenAI 兼容) │
|
||||||
|
│ ├─ POST /v1/chat/completions (流式/非流式) │
|
||||||
|
│ ├─ GET /v1/models │
|
||||||
|
│ ├─ GET /stats │
|
||||||
|
│ └─ GET /docs (Swagger UI) │
|
||||||
|
├──────────────────────────────────────────────────────────┤
|
||||||
|
│ Routing Layer: NVIDIA Multi-Head Classifier (184M) │
|
||||||
|
│ ├─ 8 维度查询分析 │
|
||||||
|
│ ├─ 综合复杂度评分 │
|
||||||
|
│ └─ 3-tier 智能路由 │
|
||||||
|
├──────────────────────────────────────────────────────────┤
|
||||||
|
│ LLM Backend: LiteLLM (多提供商统一接口) │
|
||||||
|
│ ├─ DashScope (Qwen) │
|
||||||
|
│ ├─ OpenAI (GPT) │
|
||||||
|
│ ├─ Anthropic (Claude) │
|
||||||
|
│ └─ Google (Gemini) │
|
||||||
|
└──────────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### Apple Silicon 优化 (M4 Pro)
|
||||||
|
|
||||||
|
- **MPS 加速**: 使用 Metal Performance Shaders GPU 后端
|
||||||
|
- **FP16 推理**: 半精度浮点,避免 MPS 矩阵乘法类型冲突
|
||||||
|
- **统一内存**: M4 Pro 64GB 统一内存,模型加载零拷贝
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 实现效果
|
||||||
|
|
||||||
|
### 路由准确性
|
||||||
|
|
||||||
|
| 查询示例 | Tier | Score | 路由模型 |
|
||||||
|
|---------|------|-------|---------|
|
||||||
|
| "你好" | simple | 0.17 | qwen-flash |
|
||||||
|
| "1+1等于几" | simple | 0.17 | qwen-flash |
|
||||||
|
| "Write quicksort in Python" | medium | 0.45 | qwen-plus |
|
||||||
|
| "分析深度学习的注意力机制原理" | medium | 0.47 | qwen-plus |
|
||||||
|
| "请详细分析量子计算对密码学的影响" | complex | 0.72 | qwen-max |
|
||||||
|
|
||||||
|
### 成本优化
|
||||||
|
|
||||||
|
根据实际调用统计:
|
||||||
|
- **16 次调用总成本**: $0.011
|
||||||
|
- **模型分布**: qwen-flash 87.5% (14次), qwen-plus 12.5% (2次)
|
||||||
|
- **任务类型**: 主要为 Open QA (13次)
|
||||||
|
- **复杂度分布**: simple 73%, medium 13%
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 路由延迟
|
||||||
|
|
||||||
|
| 环境 | 首次加载 | 稳态延迟 | 备注 |
|
||||||
|
|------|---------|---------|------|
|
||||||
|
| **M4 Pro MPS + FP16** | ~2s (模型加载) | **~60-90ms** | 当前生产环境 |
|
||||||
|
| x86 CPU | ~3s | ~100-150ms | Docker 容器 |
|
||||||
|
| NVIDIA 官方报告 | - | 5-15ms | 数据中心 CPU |
|
||||||
|
|
||||||
|
**说明**:
|
||||||
|
- 首次加载包含模型下载和 MPS kernel 编译,后续请求无需重新加载
|
||||||
|
- 稳态延迟约 60-90ms,其中分类器推理 ~53ms,其余为 FastAPI 开销
|
||||||
|
- 对于 LLM 调用本身(通常 2-10s),路由开销占比 < 2%
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
### 前置要求
|
||||||
|
|
||||||
|
- Python 3.12+
|
||||||
|
- macOS (Apple Silicon 推荐,支持 MPS 加速)
|
||||||
|
- DashScope API Key(阿里云 Qwen)
|
||||||
|
|
||||||
|
### 安装
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 1. 克隆项目
|
||||||
|
git clone <repo-url>
|
||||||
|
cd llm-compass
|
||||||
|
|
||||||
|
# 2. 创建虚拟环境
|
||||||
|
python3 -m venv .venv
|
||||||
|
source .venv/bin/activate
|
||||||
|
|
||||||
|
# 3. 安装依赖
|
||||||
|
pip install -r requirements.txt
|
||||||
|
|
||||||
|
# 4. 配置 API Key
|
||||||
|
cp .env.example .env
|
||||||
|
# 编辑 .env 填入 DASHSCOPE_API_KEY
|
||||||
|
```
|
||||||
|
|
||||||
|
### 启动服务
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./start.sh # 默认端口 8402
|
||||||
|
./start.sh 9000 # 自定义端口
|
||||||
|
```
|
||||||
|
|
||||||
|
### 测试
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 健康检查
|
||||||
|
curl http://localhost:8402/health
|
||||||
|
|
||||||
|
# API 测试 (自动路由)
|
||||||
|
curl http://localhost:8402/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"messages":[{"role":"user","content":"你好"}]}'
|
||||||
|
|
||||||
|
# Swagger UI
|
||||||
|
# 访问 http://localhost:8402/docs
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## API 使用
|
||||||
|
|
||||||
|
### OpenAI 兼容接口
|
||||||
|
|
||||||
|
```python
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
client = OpenAI(
|
||||||
|
base_url="http://localhost:8402/v1",
|
||||||
|
api_key="not-needed" # 可选
|
||||||
|
)
|
||||||
|
|
||||||
|
# 自动路由(推荐)
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
messages=[{"role": "user", "content": "解释量子计算"}]
|
||||||
|
)
|
||||||
|
print(response.choices[0].message.content)
|
||||||
|
print(response.routing) # 路由详情
|
||||||
|
|
||||||
|
# 指定模型
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="qwen-plus",
|
||||||
|
messages=[{"role": "user", "content": "写一个排序算法"}]
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 响应格式
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "chatcmpl-xxx",
|
||||||
|
"model": "qwen-flash",
|
||||||
|
"choices": [{"message": {"content": "..."}}],
|
||||||
|
"usage": {"prompt_tokens": 13, "completion_tokens": 25, "total_tokens": 38},
|
||||||
|
"routing": {
|
||||||
|
"method": "nvidia_classifier",
|
||||||
|
"tier": "simple",
|
||||||
|
"complexity_score": 0.17,
|
||||||
|
"task_type": "Open QA",
|
||||||
|
"domain_knowledge": "Low",
|
||||||
|
"reasoning": false,
|
||||||
|
"creativity": "No",
|
||||||
|
"routing_latency_ms": 63.27
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 技术栈
|
||||||
|
|
||||||
|
- **Web 框架**: FastAPI + Uvicorn
|
||||||
|
- **路由模型**: NVIDIA Multi-Head Classifier (DeBERTa-v3-base)
|
||||||
|
- **LLM 调用**: LiteLLM (多提供商统一接口)
|
||||||
|
- **GPU 加速**: PyTorch MPS (Metal Performance Shaders)
|
||||||
|
- **Token 计算**: tiktoken
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 项目结构
|
||||||
|
|
||||||
|
```
|
||||||
|
llm-compass/
|
||||||
|
├── main.py # FastAPI 主服务
|
||||||
|
├── nvidia_router.py # NVIDIA 分类器实现
|
||||||
|
├── config.py # 模型配置和路由阈值
|
||||||
|
├── start.sh # 启动脚本
|
||||||
|
├── .env # 环境变量(不提交到 Git)
|
||||||
|
├── .env.example # 环境变量模板
|
||||||
|
├── requirements.txt # Python 依赖
|
||||||
|
├── Dockerfile # Docker 构建文件
|
||||||
|
├── docker-compose.yml # Docker Compose 配置
|
||||||
|
├── data/ # 调用历史日志(自动创建)
|
||||||
|
└── docs/ # 技术文档
|
||||||
|
└── llm-router-open-source-research.md
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 已知限制
|
||||||
|
|
||||||
|
1. **路由维度**: 当前仅支持 3 个 Qwen 模型,可扩展至 40+ 模型
|
||||||
|
2. **在线学习**: 缺少多臂老虎机等在线学习机制
|
||||||
|
3. **语义缓存**: 未实现查询缓存优化
|
||||||
|
4. **Docker MPS**: macOS Docker 容器无法使用 Metal GPU,需原生运行
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 后续计划
|
||||||
|
|
||||||
|
- [ ] Layer 2: 多臂老虎机在线学习 (Thompson Sampling)
|
||||||
|
- [ ] Layer 3: 语义缓存 + 批量优化
|
||||||
|
- [ ] 扩展模型池至 40+ 模型
|
||||||
|
- [ ] 基于业务数据微调 NVIDIA 分类器
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 许可证
|
||||||
|
|
||||||
|
Apache 2.0
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**LLM Compass** - 让每个查询都找到最优的模型。
|
||||||
@@ -3,11 +3,11 @@ services:
|
|||||||
build: .
|
build: .
|
||||||
container_name: llm-compass
|
container_name: llm-compass
|
||||||
ports:
|
ports:
|
||||||
- "8000:8000"
|
- "402:8000"
|
||||||
environment:
|
environment:
|
||||||
- DASHSCOPE_API_KEY=${DASHSCOPE_API_KEY}
|
- DASHSCOPE_API_KEY=${DASHSCOPE_API_KEY}
|
||||||
volumes:
|
volumes:
|
||||||
- compass-data:/app/data
|
- ./data:/app/data
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
|
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
|
||||||
|
|||||||
73
main.py
73
main.py
@@ -58,23 +58,33 @@ _load_history()
|
|||||||
|
|
||||||
|
|
||||||
# ── OpenAI 兼容请求/响应模型 ────────────────────────────────
|
# ── OpenAI 兼容请求/响应模型 ────────────────────────────────
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
class ChatMessage(BaseModel):
|
class ChatMessage(BaseModel):
|
||||||
role: str
|
role: str = Field(..., description="角色:system, user, assistant", example="user")
|
||||||
content: Optional[str] = None
|
content: Optional[str] = Field(None, description="消息内容", example="你好,介绍一下你自己")
|
||||||
name: Optional[str] = None
|
name: Optional[str] = Field(None, description="可选的名称")
|
||||||
|
|
||||||
class ChatCompletionRequest(BaseModel):
|
class ChatCompletionRequest(BaseModel):
|
||||||
model: Optional[str] = None
|
model: Optional[str] = Field(
|
||||||
messages: List[ChatMessage]
|
None,
|
||||||
temperature: Optional[float] = 0.7
|
description="模型名称(留空时自动使用 NVIDIA 分类器智能路由)",
|
||||||
max_tokens: Optional[int] = None
|
example="qwen-plus",
|
||||||
stream: Optional[bool] = False
|
json_schema_extra={"examples": ["", "qwen-flash", "qwen-plus", "qwen-max"]}
|
||||||
top_p: Optional[float] = 1.0
|
)
|
||||||
n: Optional[int] = 1
|
messages: List[ChatMessage] = Field(
|
||||||
stop: Optional[Any] = None
|
...,
|
||||||
presence_penalty: Optional[float] = 0.0
|
description="对话消息列表",
|
||||||
frequency_penalty: Optional[float] = 0.0
|
example=[{"role": "user", "content": "你好,介绍一下你自己"}]
|
||||||
user: Optional[str] = None
|
)
|
||||||
|
temperature: Optional[float] = Field(0.7, ge=0, le=2, description="随机性 (0-2)")
|
||||||
|
max_tokens: Optional[int] = Field(None, description="最大生成 token 数(留空时使用模型默认值)", example=2048)
|
||||||
|
stream: Optional[bool] = Field(False, description="是否使用流式输出")
|
||||||
|
top_p: Optional[float] = Field(1.0, ge=0, le=1, description="核采样参数")
|
||||||
|
n: Optional[int] = Field(1, ge=1, le=10, description="生成回复数量")
|
||||||
|
stop: Optional[Any] = Field(None, description="停止词")
|
||||||
|
presence_penalty: Optional[float] = Field(0.0, ge=-2, le=2, description="存在惩罚")
|
||||||
|
frequency_penalty: Optional[float] = Field(0.0, ge=-2, le=2, description="频率惩罚")
|
||||||
|
user: Optional[str] = Field(None, description="用户标识")
|
||||||
|
|
||||||
|
|
||||||
# ── FastAPI App ──────────────────────────────────────────────
|
# ── FastAPI App ──────────────────────────────────────────────
|
||||||
@@ -281,13 +291,18 @@ async def chat_completions(request: ChatCompletionRequest):
|
|||||||
|
|
||||||
# 3. 非流式响应
|
# 3. 非流式响应
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 构建请求参数(过滤掉 None 和 0 的 max_tokens)
|
||||||
|
completion_kwargs = {
|
||||||
|
"model": provider_model,
|
||||||
|
"messages": messages_raw,
|
||||||
|
"temperature": request.temperature,
|
||||||
|
}
|
||||||
|
if request.max_tokens and request.max_tokens > 0:
|
||||||
|
completion_kwargs["max_tokens"] = request.max_tokens
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await acompletion(
|
response = await acompletion(**completion_kwargs)
|
||||||
model=provider_model,
|
|
||||||
messages=messages_raw,
|
|
||||||
temperature=request.temperature,
|
|
||||||
max_tokens=request.max_tokens,
|
|
||||||
)
|
|
||||||
latency_ms = (time.time() - start_time) * 1000
|
latency_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
input_tokens = response.usage.prompt_tokens
|
input_tokens = response.usage.prompt_tokens
|
||||||
@@ -328,13 +343,17 @@ async def _stream_response(
|
|||||||
output_tokens = 0
|
output_tokens = 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await acompletion(
|
# 构建请求参数(过滤掉 None 和 0 的 max_tokens)
|
||||||
model=provider_model,
|
completion_kwargs = {
|
||||||
messages=messages_raw,
|
"model": provider_model,
|
||||||
temperature=request.temperature,
|
"messages": messages_raw,
|
||||||
max_tokens=request.max_tokens,
|
"temperature": request.temperature,
|
||||||
stream=True,
|
"stream": True,
|
||||||
)
|
}
|
||||||
|
if request.max_tokens and request.max_tokens > 0:
|
||||||
|
completion_kwargs["max_tokens"] = request.max_tokens
|
||||||
|
|
||||||
|
response = await acompletion(**completion_kwargs)
|
||||||
|
|
||||||
async for chunk in response:
|
async for chunk in response:
|
||||||
delta = chunk.choices[0].delta
|
delta = chunk.choices[0].delta
|
||||||
|
|||||||
@@ -10,11 +10,12 @@ NVIDIA Prompt Task & Complexity Classifier Router
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from transformers import AutoTokenizer, DebertaV2Model, AutoConfig
|
from transformers import AutoTokenizer, DebertaV2Model, DebertaV2Config
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
import logging
|
import logging
|
||||||
|
import json
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -115,7 +116,16 @@ class NvidiaComplexityRouter:
|
|||||||
# Creativity 映射
|
# Creativity 映射
|
||||||
CREATIVITY_MAP = {0: "High", 1: "Low", 2: "No"}
|
CREATIVITY_MAP = {0: "High", 1: "Low", 2: "No"}
|
||||||
|
|
||||||
def __init__(self, device: str = "cpu"):
|
def __init__(self, device: str = "auto"):
|
||||||
|
if device == "auto":
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
device = "mps"
|
||||||
|
logger.info("MPS (Metal GPU) detected, using MPS acceleration")
|
||||||
|
elif torch.cuda.is_available():
|
||||||
|
device = "cuda"
|
||||||
|
else:
|
||||||
|
device = "cpu"
|
||||||
|
logger.info("No GPU detected, using CPU")
|
||||||
self.device = device
|
self.device = device
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
self.model = None
|
self.model = None
|
||||||
@@ -129,8 +139,18 @@ class NvidiaComplexityRouter:
|
|||||||
|
|
||||||
logger.info(f"Loading NVIDIA classifier: {self.MODEL_NAME}")
|
logger.info(f"Loading NVIDIA classifier: {self.MODEL_NAME}")
|
||||||
|
|
||||||
# 1. 加载 config
|
# 1. 手动加载自定义 config.json(该模型无 model_type,AutoConfig 不兼容)
|
||||||
self.config = AutoConfig.from_pretrained(self.MODEL_NAME)
|
config_path = hf_hub_download(self.MODEL_NAME, "config.json")
|
||||||
|
with open(config_path, "r") as f:
|
||||||
|
custom_config = json.load(f)
|
||||||
|
|
||||||
|
# 构建 backbone 的 DeBERTa config(从 base_model 加载)
|
||||||
|
base_model = custom_config.get("base_model", "microsoft/DeBERTa-v3-base")
|
||||||
|
self.config = DebertaV2Config.from_pretrained(base_model)
|
||||||
|
# 保存自定义分类头参数
|
||||||
|
self.config.target_sizes = custom_config["target_sizes"]
|
||||||
|
self.config.fc_dropout = custom_config.get("fc_dropout", 0.2)
|
||||||
|
self.config.base_model = base_model
|
||||||
|
|
||||||
# 2. 加载 tokenizer (slow模式,兼容性好)
|
# 2. 加载 tokenizer (slow模式,兼容性好)
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME, use_fast=False)
|
self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME, use_fast=False)
|
||||||
@@ -143,9 +163,13 @@ class NvidiaComplexityRouter:
|
|||||||
self.model.load_state_dict(state_dict, strict=False)
|
self.model.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
self.model.to(self.device)
|
self.model.to(self.device)
|
||||||
|
# MPS 需要 float16 以避免矩阵乘法数据类型冲突
|
||||||
|
if self.device == "mps":
|
||||||
|
self.model.half()
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
logger.info("NVIDIA classifier loaded successfully")
|
dtype = "float16" if self.device == "mps" else "float32"
|
||||||
|
logger.info(f"NVIDIA classifier loaded successfully on {self.device} ({dtype})")
|
||||||
|
|
||||||
def predict(self, query: str) -> Dict:
|
def predict(self, query: str) -> Dict:
|
||||||
"""
|
"""
|
||||||
|
|||||||
48
start.sh
Executable file
48
start.sh
Executable file
@@ -0,0 +1,48 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# LLM Compass 启动脚本
|
||||||
|
# 用法: ./start.sh [端口号]
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# 默认端口
|
||||||
|
PORT=${1:-8402}
|
||||||
|
|
||||||
|
# 获取脚本所在目录
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||||
|
cd "$SCRIPT_DIR"
|
||||||
|
|
||||||
|
# 检查虚拟环境
|
||||||
|
if [ ! -d ".venv" ]; then
|
||||||
|
echo "❌ 虚拟环境不存在,请先运行:"
|
||||||
|
echo " python3 -m venv .venv"
|
||||||
|
echo " .venv/bin/pip install -r requirements.txt"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 检查 .env 文件
|
||||||
|
if [ ! -f ".env" ]; then
|
||||||
|
echo "❌ .env 文件不存在,请创建并配置 API Key"
|
||||||
|
echo " cp .env.example .env"
|
||||||
|
echo " 编辑 .env 填入 DASHSCOPE_API_KEY"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 加载环境变量
|
||||||
|
export DASHSCOPE_API_KEY=$(grep DASHSCOPE_API_KEY .env | cut -d= -f2)
|
||||||
|
|
||||||
|
if [ -z "$DASHSCOPE_API_KEY" ]; then
|
||||||
|
echo "❌ DASHSCOPE_API_KEY 未设置,请检查 .env 文件"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "🚀 启动 LLM Compass 服务..."
|
||||||
|
echo "📍 地址: http://localhost:${PORT}"
|
||||||
|
echo "📖 API 文档: http://localhost:${PORT}/docs"
|
||||||
|
echo "🔧 路由方式: NVIDIA MPS 加速 (M4 Pro GPU)"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# 启动服务
|
||||||
|
exec .venv/bin/python -m uvicorn main:app \
|
||||||
|
--host 0.0.0.0 \
|
||||||
|
--port "$PORT" \
|
||||||
|
--log-level info
|
||||||
Reference in New Issue
Block a user