"""FastAPI routes for LLM model management: CRUD, connectivity test, chat, and preview."""
import time
from datetime import datetime, timezone
from typing import List, Optional

import httpx
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session

from ..db import get_db
from ..id_generator import unique_short_id
from ..models import LLMModel
from ..schemas import (
    LLMModelCreate, LLMModelUpdate, LLMModelOut,
    LLMModelTestResponse, LLMPreviewRequest, LLMPreviewResponse
)
|
|
|
|
# Router exposing all LLM model management endpoints under the /llm prefix.
router = APIRouter(prefix="/llm", tags=["LLM Models"])
|
|
|
|
|
|
# ============ LLM Models CRUD ============
|
|
@router.get("")
def list_llm_models(
    model_type: Optional[str] = None,
    enabled: Optional[bool] = None,
    page: int = 1,
    limit: int = 50,
    db: Session = Depends(get_db)
):
    """List LLM models with optional filters and pagination.

    Args:
        model_type: When given, only models whose ``type`` matches.
        enabled: When given, only models whose ``enabled`` flag matches.
        page: 1-based page number; values below 1 are clamped to 1.
        limit: Page size; values below 1 are clamped to 1.
        db: Database session (injected).

    Returns:
        Dict with ``total`` count, echoed ``page``/``limit``, and the
        page of models under ``list``.
    """
    # Clamp to sane lower bounds: page=0 or a negative page would
    # otherwise produce a negative OFFSET and a database error.
    page = max(page, 1)
    limit = max(limit, 1)

    query = db.query(LLMModel)
    if model_type:
        query = query.filter(LLMModel.type == model_type)
    if enabled is not None:
        query = query.filter(LLMModel.enabled == enabled)

    total = query.count()
    models = (
        query.order_by(LLMModel.created_at.desc())
        .offset((page - 1) * limit)
        .limit(limit)
        .all()
    )

    return {"total": total, "page": page, "limit": limit, "list": models}
|
|
|
|
|
|
@router.get("/{id}", response_model=LLMModelOut)
def get_llm_model(id: str, db: Session = Depends(get_db)):
    """Return a single LLM model by id, or 404 when it does not exist."""
    found = db.query(LLMModel).filter(LLMModel.id == id).first()
    if found is None:
        raise HTTPException(status_code=404, detail="LLM Model not found")
    return found
|
|
|
|
|
|
@router.post("", response_model=LLMModelOut)
def create_llm_model(data: LLMModelCreate, db: Session = Depends(get_db)):
    """Create a new LLM model record from the payload and return it."""
    # Enum members are persisted by their raw value; plain strings pass through.
    raw_type = data.type.value if hasattr(data.type, 'value') else data.type

    record = LLMModel(
        id=unique_short_id("llm", db, LLMModel),
        user_id=1,  # 默认用户 (default user)
        name=data.name,
        vendor=data.vendor,
        type=raw_type,
        base_url=data.base_url,
        api_key=data.api_key,
        model_name=data.model_name,
        temperature=data.temperature,
        context_length=data.context_length,
        enabled=data.enabled,
    )
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
|
|
|
|
|
|
@router.put("/{id}", response_model=LLMModelOut)
def update_llm_model(id: str, data: LLMModelUpdate, db: Session = Depends(get_db)):
    """Partially update an LLM model; only fields present in the payload change.

    Raises:
        HTTPException: 404 when no model with ``id`` exists.
    """
    model = db.query(LLMModel).filter(LLMModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="LLM Model not found")

    update_data = data.model_dump(exclude_unset=True)
    # Enum members are persisted by their raw value, matching create_llm_model.
    type_value = update_data.get("type")
    if type_value is not None and hasattr(type_value, "value"):
        update_data["type"] = type_value.value
    for field, value in update_data.items():
        setattr(model, field, value)

    # datetime.utcnow() is deprecated since Python 3.12; compute an aware
    # UTC timestamp and drop tzinfo so the stored value stays naive-UTC,
    # matching existing rows.
    model.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
    db.commit()
    db.refresh(model)
    return model
|
|
|
|
|
|
@router.delete("/{id}")
def delete_llm_model(id: str, db: Session = Depends(get_db)):
    """Delete an LLM model by id, returning a confirmation message (404 if absent)."""
    target = db.query(LLMModel).filter(LLMModel.id == id).first()
    if target is None:
        raise HTTPException(status_code=404, detail="LLM Model not found")

    db.delete(target)
    db.commit()
    return {"message": "Deleted successfully"}
|
|
|
|
|
|
@router.post("/{id}/test", response_model=LLMModelTestResponse)
def test_llm_model(id: str, db: Session = Depends(get_db)):
    """Probe an LLM model's connectivity with a minimal chat completion request.

    Sends a tiny prompt to the vendor's OpenAI-compatible
    ``/chat/completions`` endpoint and reports success plus round-trip
    latency. Network/HTTP failures are returned as an unsuccessful
    response body rather than raised.

    Raises:
        HTTPException: 404 when no model with ``id`` exists.
    """
    model = db.query(LLMModel).filter(LLMModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="LLM Model not found")

    start_time = time.time()
    try:
        # Minimal request to keep the probe cheap and fast.
        test_messages = [{"role": "user", "content": "Hello, please reply with 'OK'."}]
        payload = {
            "model": model.model_name or "gpt-3.5-turbo",
            "messages": test_messages,
            "max_tokens": 10,
            "temperature": 0.1,
        }
        headers = {"Authorization": f"Bearer {model.api_key}"}

        # rstrip('/') keeps the URL well-formed whether or not base_url ends
        # with a slash — consistent with preview_llm_model.
        with httpx.Client(timeout=30.0) as client:
            response = client.post(
                f"{model.base_url.rstrip('/')}/chat/completions",
                json=payload,
                headers=headers
            )
            response.raise_for_status()

        latency_ms = int((time.time() - start_time) * 1000)
        result = response.json()

        if result.get("choices"):
            return LLMModelTestResponse(
                success=True,
                latency_ms=latency_ms,
                message="Connection successful"
            )
        return LLMModelTestResponse(
            success=False,
            latency_ms=latency_ms,
            message="Unexpected response format"
        )

    except httpx.HTTPStatusError as e:
        # Include latency on failures too, so slow-but-erroring endpoints
        # are visible to the caller.
        return LLMModelTestResponse(
            success=False,
            latency_ms=int((time.time() - start_time) * 1000),
            message=f"HTTP Error: {e.response.status_code} - {e.response.text[:200]}"
        )
    except Exception as e:
        return LLMModelTestResponse(
            success=False,
            latency_ms=int((time.time() - start_time) * 1000),
            message=str(e)[:200]
        )
|
|
|
|
|
|
@router.post("/{id}/chat")
def chat_with_llm(
    id: str,
    message: str,
    system_prompt: Optional[str] = None,
    max_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
    db: Session = Depends(get_db)
):
    """Send one chat message to the model and return the assistant's reply.

    Args:
        id: LLM model id.
        message: User message content.
        system_prompt: Optional system prompt prepended to the conversation.
        max_tokens: Completion token cap (defaults to 1000).
        temperature: Sampling temperature; falls back to the model's
            configured temperature, then 0.7.
        db: Database session (injected).

    Raises:
        HTTPException: 404 when the model is missing; 500 on request failure.
    """
    model = db.query(LLMModel).filter(LLMModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="LLM Model not found")

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": message})

    payload = {
        "model": model.model_name or "gpt-3.5-turbo",
        "messages": messages,
        "max_tokens": max_tokens or 1000,
        "temperature": temperature if temperature is not None else model.temperature or 0.7,
    }
    headers = {"Authorization": f"Bearer {model.api_key}"}

    try:
        # rstrip('/') keeps the URL well-formed whether or not base_url ends
        # with a slash — consistent with preview_llm_model.
        with httpx.Client(timeout=60.0) as client:
            response = client.post(
                f"{model.base_url.rstrip('/')}/chat/completions",
                json=payload,
                headers=headers
            )
            response.raise_for_status()

        result = response.json()
        # Guard against an empty choices list: indexing
        # result.get("choices", [{}])[0] raised IndexError (a generic 500)
        # when the vendor returned "choices": [].
        choices = result.get("choices") or []
        if choices and choices[0]:
            return {
                "success": True,
                "reply": choices[0].get("message", {}).get("content", ""),
                "usage": result.get("usage", {})
            }
        return {"success": False, "reply": "", "error": "No response"}

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@router.post("/{id}/preview", response_model=LLMPreviewResponse)
def preview_llm_model(
    id: str,
    request: LLMPreviewRequest,
    db: Session = Depends(get_db)
):
    """Preview a model's output; supports text (chat) and embedding model types.

    For ``embedding`` models the message is embedded and a short summary of
    the vector (dimension count plus first values) is returned as the reply;
    otherwise a chat completion is requested and its content returned.

    Raises:
        HTTPException: 404 missing model, 400 empty message,
            502 on transport failure or a non-200 vendor response.
    """
    model = db.query(LLMModel).filter(LLMModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="LLM Model not found")

    user_message = (request.message or "").strip()
    if not user_message:
        raise HTTPException(status_code=400, detail="Preview message cannot be empty")

    model_id = model.model_name or "gpt-3.5-turbo"
    # "or ''" guards the case where neither the request nor the stored model
    # carries a key — the original raised AttributeError on None.strip().
    headers = {"Authorization": f"Bearer {(request.api_key or model.api_key or '').strip()}"}

    start_time = time.time()
    endpoint = "/chat/completions"
    payload = {}

    if model.type == "embedding":
        endpoint = "/embeddings"
        payload = {
            "model": model_id,
            "input": user_message,
        }
    else:
        messages = []
        if request.system_prompt and request.system_prompt.strip():
            messages.append({"role": "system", "content": request.system_prompt.strip()})
        messages.append({"role": "user", "content": user_message})
        payload = {
            "model": model_id,
            "messages": messages,
            "max_tokens": request.max_tokens or 512,
            "temperature": request.temperature if request.temperature is not None else (model.temperature or 0.7),
        }

    try:
        with httpx.Client(timeout=60.0) as client:
            response = client.post(
                f"{model.base_url.rstrip('/')}{endpoint}",
                json=payload,
                headers=headers
            )
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"LLM request failed: {exc}") from exc

    if response.status_code != 200:
        # Prefer a structured vendor error message; fall back to raw body text.
        detail = response.text
        try:
            detail_json = response.json()
            detail = detail_json.get("error", {}).get("message") or detail_json.get("detail") or detail
        except Exception:
            pass
        raise HTTPException(status_code=502, detail=f"LLM vendor error: {detail}")

    result = response.json()
    reply = ""
    if model.type == "embedding":
        data_list = result.get("data", [])
        embedding = []
        if data_list and isinstance(data_list, list):
            embedding = data_list[0].get("embedding", []) or []
        dims = len(embedding) if isinstance(embedding, list) else 0
        preview_values = []
        if isinstance(embedding, list):
            preview_values = embedding[:8]  # show only the first few components
        values_text = ", ".join(
            [f"{float(v):.6f}" if isinstance(v, (float, int)) else str(v) for v in preview_values]
        )
        reply = f"Embedding generated successfully. dims={dims}. head=[{values_text}]"
    else:
        choices = result.get("choices", [])
        if choices:
            reply = choices[0].get("message", {}).get("content", "") or ""

    return LLMPreviewResponse(
        success=bool(reply),
        reply=reply,
        usage=result.get("usage"),
        latency_ms=int((time.time() - start_time) * 1000),
        error=None if reply else "No response content",
    )
|