diff --git a/api/app/routers/llm.py b/api/app/routers/llm.py index 0658de4..8454f54 100644 --- a/api/app/routers/llm.py +++ b/api/app/routers/llm.py @@ -214,7 +214,7 @@ def preview_llm_model( request: LLMPreviewRequest, db: Session = Depends(get_db) ): - """预览 LLM 输出,基于 OpenAI-compatible /chat/completions。""" + """预览模型输出,支持 text(chat) 与 embedding 两类模型。""" model = db.query(LLMModel).filter(LLMModel.id == id).first() if not model: raise HTTPException(status_code=404, detail="LLM Model not found") @@ -223,24 +223,35 @@ def preview_llm_model( if not user_message: raise HTTPException(status_code=400, detail="Preview message cannot be empty") - messages = [] - if request.system_prompt and request.system_prompt.strip(): - messages.append({"role": "system", "content": request.system_prompt.strip()}) - messages.append({"role": "user", "content": user_message}) - - payload = { - "model": model.model_name or "gpt-3.5-turbo", - "messages": messages, - "max_tokens": request.max_tokens or 512, - "temperature": request.temperature if request.temperature is not None else (model.temperature or 0.7), - } + model_id = model.model_name or "gpt-3.5-turbo" headers = {"Authorization": f"Bearer {(request.api_key or model.api_key).strip()}"} start_time = time.time() + endpoint = "/chat/completions" + payload = {} + + if model.type == "embedding": + endpoint = "/embeddings" + payload = { + "model": model_id, + "input": user_message, + } + else: + messages = [] + if request.system_prompt and request.system_prompt.strip(): + messages.append({"role": "system", "content": request.system_prompt.strip()}) + messages.append({"role": "user", "content": user_message}) + payload = { + "model": model_id, + "messages": messages, + "max_tokens": request.max_tokens or 512, + "temperature": request.temperature if request.temperature is not None else (model.temperature or 0.7), + } + try: with httpx.Client(timeout=60.0) as client: response = client.post( - f"{model.base_url.rstrip('/')}/chat/completions", + f"{model.base_url.rstrip('/')}{endpoint}", json=payload, headers=headers ) @@ -258,9 +269,23 @@ def preview_llm_model( result = response.json() reply = "" - choices = result.get("choices", []) - if choices: - reply = choices[0].get("message", {}).get("content", "") or "" + if model.type == "embedding": + data_list = result.get("data", []) + embedding = [] + if data_list and isinstance(data_list, list): + embedding = data_list[0].get("embedding", []) or [] + dims = len(embedding) if isinstance(embedding, list) else 0 + preview_values = [] + if isinstance(embedding, list): + preview_values = embedding[:8] + values_text = ", ".join( + [f"{float(v):.6f}" if isinstance(v, (float, int)) else str(v) for v in preview_values] + ) + reply = f"Embedding generated successfully. dims={dims}. head=[{values_text}]" + else: + choices = result.get("choices", []) + if choices: + reply = choices[0].get("message", {}).get("content", "") or "" return LLMPreviewResponse( success=bool(reply), diff --git a/api/tests/test_llm.py b/api/tests/test_llm.py index 0108b86..2b72fea 100644 --- a/api/tests/test_llm.py +++ b/api/tests/test_llm.py @@ -300,3 +300,53 @@ class TestLLMModelAPI: response = client.post(f"/api/llm/{model_id}/preview", json={"message": " "}) assert response.status_code == 400 + + def test_preview_embedding_model_success(self, client, monkeypatch): + """Test embedding model preview endpoint returns embedding summary.""" + from app.routers import llm as llm_router + + embedding_model_data = { + "id": "preview-emb", + "name": "Preview Embedding", + "vendor": "OpenAI", + "type": "embedding", + "base_url": "https://api.openai.com/v1", + "api_key": "test-key", + "model_name": "text-embedding-3-small" + } + create_response = client.post("/api/llm", json=embedding_model_data) + model_id = create_response.json()["id"] + + class DummyResponse: + status_code = 200 + + def json(self): + return {"data": [{"embedding": [0.1, 0.2, 0.3, 0.4]}], "usage": {"total_tokens": 7}} + + @property + def text(self): + return '{"ok":true}' + + class DummyClient: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def post(self, url, json=None, headers=None): + assert url.endswith("/embeddings") + assert json["input"] == "hello embedding" + assert headers["Authorization"] == "Bearer test-key" + return DummyResponse() + + monkeypatch.setattr(llm_router.httpx, "Client", DummyClient) + + response = client.post(f"/api/llm/{model_id}/preview", json={"message": "hello embedding"}) + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert "dims=4" in data["reply"] diff --git a/web/pages/LLMLibrary.tsx b/web/pages/LLMLibrary.tsx index e6252eb..89e24d6 100644 --- a/web/pages/LLMLibrary.tsx +++ b/web/pages/LLMLibrary.tsx @@ -144,7 +144,13 @@ export const LLMLibraryPage: React.FC = () => { {model.baseUrl} {maskApiKey(model.apiKey)} - setPreviewingModel(model)} disabled={model.type !== 'text'} title={model.type !== 'text' ? '仅 text 模型可预览' : '预览模型'}> + setPreviewingModel(model)} + disabled={model.type === 'rerank'} + title={model.type === 'rerank' ? '暂不支持 rerank 预览' : (model.type === 'embedding' ? '预览 embedding 向量' : '预览模型')} + > setEditingModel(model)}> @@ -358,6 +364,7 @@ const LLMPreviewModal: React.FC<{ onClose: () => void; model: LLMModel | null; }> = ({ isOpen, onClose, model }) => { + const isEmbeddingModel = model?.type === 'embedding'; const [systemPrompt, setSystemPrompt] = useState('You are a concise helpful assistant.'); const [message, setMessage] = useState('Hello, please introduce yourself in one sentence.'); const [temperature, setTemperature] = useState(0.7); @@ -419,28 +426,29 @@ const LLMPreviewModal: React.FC<{ value={systemPrompt} onChange={(e) => setSystemPrompt(e.target.value)} className="flex min-h-[70px] w-full rounded-md border-0 bg-white/5 px-3 py-2 text-sm shadow-sm placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-primary/50 text-white" - placeholder="可选系统提示词" + placeholder={isEmbeddingModel ? 'embedding 预览无需 system prompt(可留空)' : '可选系统提示词'} + disabled={isEmbeddingModel} /> - User Message + {isEmbeddingModel ? 'Input Text' : 'User Message'} setMessage(e.target.value)} className="flex min-h-[90px] w-full rounded-md border-0 bg-white/5 px-3 py-2 text-sm shadow-sm placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-primary/50 text-white" - placeholder="输入用户消息" + placeholder={isEmbeddingModel ? '输入需要生成向量的文本' : '输入用户消息'} /> Temperature - setTemperature(parseFloat(e.target.value || '0'))} /> + setTemperature(parseFloat(e.target.value || '0'))} disabled={isEmbeddingModel} /> Max Tokens - setMaxTokens(parseInt(e.target.value || '1', 10))} /> + setMaxTokens(parseInt(e.target.value || '1', 10))} disabled={isEmbeddingModel} />