Add embedding preview

This commit is contained in:
Xin Wang
2026-02-10 10:18:19 +08:00
parent 436fb3c1e5
commit 1462488969
3 changed files with 105 additions and 22 deletions

View File

@@ -214,7 +214,7 @@ def preview_llm_model(
request: LLMPreviewRequest,
db: Session = Depends(get_db)
):
"""预览 LLM 输出,基于 OpenAI-compatible /chat/completions"""
"""预览模型输出,支持 text(chat) 与 embedding 两类模型"""
model = db.query(LLMModel).filter(LLMModel.id == id).first()
if not model:
raise HTTPException(status_code=404, detail="LLM Model not found")
@@ -223,24 +223,35 @@ def preview_llm_model(
if not user_message:
raise HTTPException(status_code=400, detail="Preview message cannot be empty")
messages = []
if request.system_prompt and request.system_prompt.strip():
messages.append({"role": "system", "content": request.system_prompt.strip()})
messages.append({"role": "user", "content": user_message})
payload = {
"model": model.model_name or "gpt-3.5-turbo",
"messages": messages,
"max_tokens": request.max_tokens or 512,
"temperature": request.temperature if request.temperature is not None else (model.temperature or 0.7),
}
model_id = model.model_name or "gpt-3.5-turbo"
headers = {"Authorization": f"Bearer {(request.api_key or model.api_key).strip()}"}
start_time = time.time()
endpoint = "/chat/completions"
payload = {}
if model.type == "embedding":
endpoint = "/embeddings"
payload = {
"model": model_id,
"input": user_message,
}
else:
messages = []
if request.system_prompt and request.system_prompt.strip():
messages.append({"role": "system", "content": request.system_prompt.strip()})
messages.append({"role": "user", "content": user_message})
payload = {
"model": model_id,
"messages": messages,
"max_tokens": request.max_tokens or 512,
"temperature": request.temperature if request.temperature is not None else (model.temperature or 0.7),
}
try:
with httpx.Client(timeout=60.0) as client:
response = client.post(
f"{model.base_url.rstrip('/')}/chat/completions",
f"{model.base_url.rstrip('/')}{endpoint}",
json=payload,
headers=headers
)
@@ -258,9 +269,23 @@ def preview_llm_model(
result = response.json()
reply = ""
choices = result.get("choices", [])
if choices:
reply = choices[0].get("message", {}).get("content", "") or ""
if model.type == "embedding":
data_list = result.get("data", [])
embedding = []
if data_list and isinstance(data_list, list):
embedding = data_list[0].get("embedding", []) or []
dims = len(embedding) if isinstance(embedding, list) else 0
preview_values = []
if isinstance(embedding, list):
preview_values = embedding[:8]
values_text = ", ".join(
[f"{float(v):.6f}" if isinstance(v, (float, int)) else str(v) for v in preview_values]
)
reply = f"Embedding generated successfully. dims={dims}. head=[{values_text}]"
else:
choices = result.get("choices", [])
if choices:
reply = choices[0].get("message", {}).get("content", "") or ""
return LLMPreviewResponse(
success=bool(reply),

View File

@@ -300,3 +300,53 @@ class TestLLMModelAPI:
response = client.post(f"/api/llm/{model_id}/preview", json={"message": " "})
assert response.status_code == 400
def test_preview_embedding_model_success(self, client, monkeypatch):
    """Embedding-type models are previewed via /embeddings and summarized in the reply."""
    from app.routers import llm as llm_router

    payload = {
        "id": "preview-emb",
        "name": "Preview Embedding",
        "vendor": "OpenAI",
        "type": "embedding",
        "base_url": "https://api.openai.com/v1",
        "api_key": "test-key",
        "model_name": "text-embedding-3-small",
    }
    created = client.post("/api/llm", json=payload)
    model_id = created.json()["id"]

    class FakeResponse:
        # Minimal httpx.Response surface the router touches: status, json(), text.
        status_code = 200

        def json(self):
            return {"data": [{"embedding": [0.1, 0.2, 0.3, 0.4]}], "usage": {"total_tokens": 7}}

        @property
        def text(self):
            return '{"ok":true}'

    class FakeClient:
        # Stand-in for httpx.Client: supports the context-manager protocol and post().
        def __init__(self, *args, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def post(self, url, json=None, headers=None):
            # The router must target the embeddings endpoint, forward the raw
            # user message as "input", and authenticate with the stored key.
            assert url.endswith("/embeddings")
            assert json["input"] == "hello embedding"
            assert headers["Authorization"] == "Bearer test-key"
            return FakeResponse()

    monkeypatch.setattr(llm_router.httpx, "Client", FakeClient)

    resp = client.post(f"/api/llm/{model_id}/preview", json={"message": "hello embedding"})

    assert resp.status_code == 200
    body = resp.json()
    assert body["success"] is True
    assert "dims=4" in body["reply"]