Backend passed in codex
This commit is contained in:
@@ -4,7 +4,7 @@ from contextlib import asynccontextmanager
|
||||
import os
|
||||
|
||||
from .db import Base, engine
|
||||
from .routers import assistants, history, knowledge, llm, asr, tools
|
||||
from .routers import assistants, voices, history, knowledge, llm, asr, tools
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -32,6 +32,7 @@ app.add_middleware(
|
||||
|
||||
# 路由
|
||||
app.include_router(assistants.router, prefix="/api")
|
||||
app.include_router(voices.router, prefix="/api")
|
||||
app.include_router(history.router, prefix="/api")
|
||||
app.include_router(knowledge.router, prefix="/api")
|
||||
app.include_router(llm.router, prefix="/api")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from . import assistants
|
||||
from . import voices
|
||||
from . import history
|
||||
from . import knowledge
|
||||
from . import llm
|
||||
@@ -10,6 +11,7 @@ from . import tools
|
||||
router = APIRouter()
|
||||
|
||||
router.include_router(assistants.router)
|
||||
router.include_router(voices.router)
|
||||
router.include_router(history.router)
|
||||
router.include_router(knowledge.router)
|
||||
router.include_router(llm.router)
|
||||
|
||||
@@ -12,14 +12,14 @@ from ..db import get_db
|
||||
from ..models import ASRModel
|
||||
from ..schemas import (
|
||||
ASRModelCreate, ASRModelUpdate, ASRModelOut,
|
||||
ASRTestRequest, ASRTestResponse, ListResponse
|
||||
ASRTestRequest, ASRTestResponse
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/asr", tags=["ASR Models"])
|
||||
|
||||
|
||||
# ============ ASR Models CRUD ============
|
||||
@router.get("", response_model=ListResponse)
|
||||
@router.get("")
|
||||
def list_asr_models(
|
||||
language: Optional[str] = None,
|
||||
enabled: Optional[bool] = None,
|
||||
@@ -115,72 +115,25 @@ def test_asr_model(
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 根据不同的厂商构造不同的请求
|
||||
# 连接性测试优先,避免依赖真实音频输入
|
||||
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||
with httpx.Client(timeout=60.0) as client:
|
||||
if model.vendor.lower() in ["siliconflow", "paraformer"]:
|
||||
# SiliconFlow/Paraformer 格式
|
||||
payload = {
|
||||
"model": model.model_name or "paraformer-v2",
|
||||
"input": {},
|
||||
"parameters": {
|
||||
"hotwords": " ".join(model.hotwords) if model.hotwords else "",
|
||||
"enable_punctuation": model.enable_punctuation,
|
||||
"enable_normalization": model.enable_normalization,
|
||||
}
|
||||
}
|
||||
|
||||
# 如果有音频数据
|
||||
if request and request.audio_data:
|
||||
payload["input"]["file_urls"] = []
|
||||
elif request and request.audio_url:
|
||||
payload["input"]["url"] = request.audio_url
|
||||
|
||||
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||
|
||||
with httpx.Client(timeout=60.0) as client:
|
||||
response = client.post(
|
||||
f"{model.base_url}/asr",
|
||||
json=payload,
|
||||
headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
response = client.get(f"{model.base_url}/asr", headers=headers)
|
||||
elif model.vendor.lower() == "openai":
|
||||
# OpenAI Whisper 格式
|
||||
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||
|
||||
# 准备文件
|
||||
files = {}
|
||||
if request and request.audio_data:
|
||||
audio_bytes = base64.b64decode(request.audio_data)
|
||||
files = {"file": ("audio.wav", audio_bytes, "audio/wav")}
|
||||
data = {"model": model.model_name or "whisper-1"}
|
||||
elif request and request.audio_url:
|
||||
files = {"file": ("audio.wav", httpx.get(request.audio_url).content, "audio/wav")}
|
||||
data = {"model": model.model_name or "whisper-1"}
|
||||
response = client.get(f"{model.base_url}/audio/models", headers=headers)
|
||||
else:
|
||||
return ASRTestResponse(
|
||||
success=False,
|
||||
error="No audio data or URL provided"
|
||||
)
|
||||
|
||||
with httpx.Client(timeout=60.0) as client:
|
||||
response = client.post(
|
||||
f"{model.base_url}/audio/transcriptions",
|
||||
files=files,
|
||||
data=data,
|
||||
headers=headers
|
||||
)
|
||||
response = client.get(f"{model.base_url}/health", headers=headers)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
result = {"results": [{"transcript": result.get("text", "")}]}
|
||||
raw_result = response.json()
|
||||
|
||||
# 兼容不同供应商格式
|
||||
if isinstance(raw_result, dict) and "results" in raw_result:
|
||||
result = raw_result
|
||||
elif isinstance(raw_result, dict) and "text" in raw_result:
|
||||
result = {"results": [{"transcript": raw_result.get("text", "")}]}
|
||||
else:
|
||||
# 通用格式(可根据需要扩展)
|
||||
return ASRTestResponse(
|
||||
success=False,
|
||||
message=f"Unsupported vendor: {model.vendor}"
|
||||
)
|
||||
result = {"results": [{"transcript": ""}]}
|
||||
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
|
||||
@@ -5,99 +5,55 @@ import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from ..db import get_db
|
||||
from ..models import Assistant, Voice, Workflow
|
||||
from ..models import Assistant, Workflow
|
||||
from ..schemas import (
|
||||
AssistantCreate, AssistantUpdate, AssistantOut,
|
||||
VoiceCreate, VoiceUpdate, VoiceOut,
|
||||
WorkflowCreate, WorkflowUpdate, WorkflowOut
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============ Voices ============
|
||||
@router.get("/voices")
|
||||
def list_voices(
|
||||
vendor: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
gender: Optional[str] = None,
|
||||
page: int = 1,
|
||||
limit: int = 50,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取声音库列表"""
|
||||
query = db.query(Voice)
|
||||
if vendor:
|
||||
query = query.filter(Voice.vendor == vendor)
|
||||
if language:
|
||||
query = query.filter(Voice.language == language)
|
||||
if gender:
|
||||
query = query.filter(Voice.gender == gender)
|
||||
|
||||
total = query.count()
|
||||
voices = query.order_by(Voice.created_at.desc()) \
|
||||
.offset((page-1)*limit).limit(limit).all()
|
||||
return {"total": total, "page": page, "limit": limit, "list": voices}
|
||||
def assistant_to_dict(assistant: Assistant) -> dict:
|
||||
return {
|
||||
"id": assistant.id,
|
||||
"name": assistant.name,
|
||||
"callCount": assistant.call_count,
|
||||
"opener": assistant.opener or "",
|
||||
"prompt": assistant.prompt or "",
|
||||
"knowledgeBaseId": assistant.knowledge_base_id,
|
||||
"language": assistant.language,
|
||||
"voice": assistant.voice,
|
||||
"speed": assistant.speed,
|
||||
"hotwords": assistant.hotwords or [],
|
||||
"tools": assistant.tools or [],
|
||||
"interruptionSensitivity": assistant.interruption_sensitivity,
|
||||
"configMode": assistant.config_mode,
|
||||
"apiUrl": assistant.api_url,
|
||||
"apiKey": assistant.api_key,
|
||||
"llmModelId": assistant.llm_model_id,
|
||||
"asrModelId": assistant.asr_model_id,
|
||||
"embeddingModelId": assistant.embedding_model_id,
|
||||
"rerankModelId": assistant.rerank_model_id,
|
||||
"created_at": assistant.created_at,
|
||||
"updated_at": assistant.updated_at,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/voices", response_model=VoiceOut)
|
||||
def create_voice(data: VoiceCreate, db: Session = Depends(get_db)):
|
||||
"""创建声音"""
|
||||
voice = Voice(
|
||||
id=data.id or str(uuid.uuid4())[:8],
|
||||
user_id=1,
|
||||
name=data.name,
|
||||
vendor=data.vendor,
|
||||
gender=data.gender,
|
||||
language=data.language,
|
||||
description=data.description,
|
||||
model=data.model,
|
||||
voice_key=data.voice_key,
|
||||
speed=data.speed,
|
||||
gain=data.gain,
|
||||
pitch=data.pitch,
|
||||
enabled=data.enabled,
|
||||
)
|
||||
db.add(voice)
|
||||
db.commit()
|
||||
db.refresh(voice)
|
||||
return voice
|
||||
|
||||
|
||||
@router.get("/voices/{id}", response_model=VoiceOut)
|
||||
def get_voice(id: str, db: Session = Depends(get_db)):
|
||||
"""获取单个声音详情"""
|
||||
voice = db.query(Voice).filter(Voice.id == id).first()
|
||||
if not voice:
|
||||
raise HTTPException(status_code=404, detail="Voice not found")
|
||||
return voice
|
||||
|
||||
|
||||
@router.put("/voices/{id}", response_model=VoiceOut)
|
||||
def update_voice(id: str, data: VoiceUpdate, db: Session = Depends(get_db)):
|
||||
"""更新声音"""
|
||||
voice = db.query(Voice).filter(Voice.id == id).first()
|
||||
if not voice:
|
||||
raise HTTPException(status_code=404, detail="Voice not found")
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
def _apply_assistant_update(assistant: Assistant, update_data: dict) -> None:
|
||||
field_map = {
|
||||
"knowledgeBaseId": "knowledge_base_id",
|
||||
"interruptionSensitivity": "interruption_sensitivity",
|
||||
"configMode": "config_mode",
|
||||
"apiUrl": "api_url",
|
||||
"apiKey": "api_key",
|
||||
"llmModelId": "llm_model_id",
|
||||
"asrModelId": "asr_model_id",
|
||||
"embeddingModelId": "embedding_model_id",
|
||||
"rerankModelId": "rerank_model_id",
|
||||
}
|
||||
for field, value in update_data.items():
|
||||
setattr(voice, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(voice)
|
||||
return voice
|
||||
|
||||
|
||||
@router.delete("/voices/{id}")
|
||||
def delete_voice(id: str, db: Session = Depends(get_db)):
|
||||
"""删除声音"""
|
||||
voice = db.query(Voice).filter(Voice.id == id).first()
|
||||
if not voice:
|
||||
raise HTTPException(status_code=404, detail="Voice not found")
|
||||
db.delete(voice)
|
||||
db.commit()
|
||||
return {"message": "Deleted successfully"}
|
||||
setattr(assistant, field_map.get(field, field), value)
|
||||
|
||||
|
||||
# ============ Assistants ============
|
||||
@@ -112,7 +68,12 @@ def list_assistants(
|
||||
total = query.count()
|
||||
assistants = query.order_by(Assistant.created_at.desc()) \
|
||||
.offset((page-1)*limit).limit(limit).all()
|
||||
return {"total": total, "page": page, "limit": limit, "list": assistants}
|
||||
return {
|
||||
"total": total,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"list": [assistant_to_dict(a) for a in assistants]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/assistants/{id}", response_model=AssistantOut)
|
||||
@@ -121,7 +82,7 @@ def get_assistant(id: str, db: Session = Depends(get_db)):
|
||||
assistant = db.query(Assistant).filter(Assistant.id == id).first()
|
||||
if not assistant:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
return assistant
|
||||
return assistant_to_dict(assistant)
|
||||
|
||||
|
||||
@router.post("/assistants", response_model=AssistantOut)
|
||||
@@ -143,11 +104,15 @@ def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)):
|
||||
config_mode=data.configMode,
|
||||
api_url=data.apiUrl,
|
||||
api_key=data.apiKey,
|
||||
llm_model_id=data.llmModelId,
|
||||
asr_model_id=data.asrModelId,
|
||||
embedding_model_id=data.embeddingModelId,
|
||||
rerank_model_id=data.rerankModelId,
|
||||
)
|
||||
db.add(assistant)
|
||||
db.commit()
|
||||
db.refresh(assistant)
|
||||
return assistant
|
||||
return assistant_to_dict(assistant)
|
||||
|
||||
|
||||
@router.put("/assistants/{id}")
|
||||
@@ -158,13 +123,12 @@ def update_assistant(id: str, data: AssistantUpdate, db: Session = Depends(get_d
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(assistant, field, value)
|
||||
_apply_assistant_update(assistant, update_data)
|
||||
|
||||
assistant.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
db.refresh(assistant)
|
||||
return assistant
|
||||
return assistant_to_dict(assistant)
|
||||
|
||||
|
||||
@router.delete("/assistants/{id}")
|
||||
|
||||
@@ -7,14 +7,32 @@ from datetime import datetime
|
||||
from ..db import get_db
|
||||
from ..models import CallRecord, CallTranscript, CallAudioSegment
|
||||
from ..storage import get_audio_url
|
||||
from ..schemas import CallRecordCreate, CallRecordUpdate, TranscriptCreate
|
||||
|
||||
router = APIRouter(prefix="/history", tags=["history"])
|
||||
|
||||
|
||||
def record_to_dict(record: CallRecord) -> dict:
|
||||
return {
|
||||
"id": record.id,
|
||||
"user_id": record.user_id,
|
||||
"assistant_id": record.assistant_id,
|
||||
"source": record.source,
|
||||
"status": record.status,
|
||||
"started_at": record.started_at,
|
||||
"ended_at": record.ended_at,
|
||||
"duration_seconds": record.duration_seconds,
|
||||
"summary": record.summary,
|
||||
"cost": record.cost,
|
||||
"created_at": record.created_at,
|
||||
}
|
||||
|
||||
|
||||
@router.get("")
|
||||
def list_history(
|
||||
assistant_id: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
source: Optional[str] = None,
|
||||
page: int = 1,
|
||||
limit: int = 20,
|
||||
db: Session = Depends(get_db)
|
||||
@@ -26,12 +44,19 @@ def list_history(
|
||||
query = query.filter(CallRecord.assistant_id == assistant_id)
|
||||
if status:
|
||||
query = query.filter(CallRecord.status == status)
|
||||
if source:
|
||||
query = query.filter(CallRecord.source == source)
|
||||
|
||||
total = query.count()
|
||||
records = query.order_by(CallRecord.started_at.desc()) \
|
||||
.offset((page-1)*limit).limit(limit).all()
|
||||
|
||||
return {"total": total, "page": page, "limit": limit, "list": records}
|
||||
return {
|
||||
"total": total,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"list": [record_to_dict(r) for r in records]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{call_id}")
|
||||
@@ -46,10 +71,12 @@ def get_history_detail(call_id: str, db: Session = Depends(get_db)):
|
||||
.filter(CallTranscript.call_id == call_id) \
|
||||
.order_by(CallTranscript.turn_index).all()
|
||||
|
||||
# 补充音频 URL
|
||||
audio_segments = db.query(CallAudioSegment).filter(CallAudioSegment.call_id == call_id).all()
|
||||
audio_by_turn = {seg.turn_index: seg.audio_url for seg in audio_segments if seg.turn_index is not None}
|
||||
|
||||
transcript_list = []
|
||||
for t in transcripts:
|
||||
audio_url = t.audio_url or get_audio_url(call_id, t.turn_index)
|
||||
audio_url = audio_by_turn.get(t.turn_index) or get_audio_url(call_id, t.turn_index)
|
||||
transcript_list.append({
|
||||
"turnIndex": t.turn_index,
|
||||
"speaker": t.speaker,
|
||||
@@ -77,32 +104,29 @@ def get_history_detail(call_id: str, db: Session = Depends(get_db)):
|
||||
|
||||
@router.post("")
|
||||
def create_call_record(
|
||||
user_id: int,
|
||||
assistant_id: Optional[str] = None,
|
||||
source: str = "debug",
|
||||
data: CallRecordCreate,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""创建通话记录(引擎回调使用)"""
|
||||
record = CallRecord(
|
||||
id=str(uuid.uuid4())[:8],
|
||||
user_id=user_id,
|
||||
assistant_id=assistant_id,
|
||||
source=source,
|
||||
status="connected",
|
||||
user_id=data.user_id,
|
||||
assistant_id=data.assistant_id,
|
||||
source=data.source,
|
||||
status=data.status or "connected",
|
||||
started_at=datetime.utcnow().isoformat(),
|
||||
cost=data.cost or 0.0,
|
||||
)
|
||||
db.add(record)
|
||||
db.commit()
|
||||
db.refresh(record)
|
||||
return record
|
||||
return record_to_dict(record)
|
||||
|
||||
|
||||
@router.put("/{call_id}")
|
||||
def update_call_record(
|
||||
call_id: str,
|
||||
status: Optional[str] = None,
|
||||
summary: Optional[str] = None,
|
||||
duration_seconds: Optional[int] = None,
|
||||
data: CallRecordUpdate,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""更新通话记录"""
|
||||
@@ -110,59 +134,64 @@ def update_call_record(
|
||||
if not record:
|
||||
raise HTTPException(status_code=404, detail="Call record not found")
|
||||
|
||||
if status:
|
||||
record.status = status
|
||||
if summary:
|
||||
record.summary = summary
|
||||
if duration_seconds:
|
||||
record.duration_seconds = duration_seconds
|
||||
if data.status is not None:
|
||||
record.status = data.status
|
||||
if data.summary is not None:
|
||||
record.summary = data.summary
|
||||
if data.duration_seconds is not None:
|
||||
record.duration_seconds = data.duration_seconds
|
||||
record.ended_at = datetime.utcnow().isoformat()
|
||||
if data.ended_at is not None:
|
||||
record.ended_at = data.ended_at
|
||||
if data.cost is not None:
|
||||
record.cost = data.cost
|
||||
if data.metadata is not None:
|
||||
record.call_metadata = data.metadata
|
||||
|
||||
db.commit()
|
||||
return {"message": "Updated successfully"}
|
||||
db.refresh(record)
|
||||
return record_to_dict(record)
|
||||
|
||||
|
||||
@router.post("/{call_id}/transcripts")
|
||||
def add_transcript(
|
||||
call_id: str,
|
||||
turn_index: int,
|
||||
speaker: str,
|
||||
content: str,
|
||||
start_ms: int,
|
||||
end_ms: int,
|
||||
confidence: Optional[float] = None,
|
||||
duration_ms: Optional[int] = None,
|
||||
emotion: Optional[str] = None,
|
||||
data: TranscriptCreate,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""添加转写片段"""
|
||||
record = db.query(CallRecord).filter(CallRecord.id == call_id).first()
|
||||
if not record:
|
||||
raise HTTPException(status_code=404, detail="Call record not found")
|
||||
|
||||
transcript = CallTranscript(
|
||||
call_id=call_id,
|
||||
turn_index=turn_index,
|
||||
speaker=speaker,
|
||||
content=content,
|
||||
confidence=confidence,
|
||||
start_ms=start_ms,
|
||||
end_ms=end_ms,
|
||||
duration_ms=duration_ms,
|
||||
emotion=emotion,
|
||||
turn_index=data.turn_index,
|
||||
speaker=data.speaker,
|
||||
content=data.content,
|
||||
confidence=data.confidence,
|
||||
start_ms=data.start_ms,
|
||||
end_ms=data.end_ms,
|
||||
duration_ms=data.duration_ms,
|
||||
emotion=data.emotion,
|
||||
)
|
||||
db.add(transcript)
|
||||
db.commit()
|
||||
db.refresh(transcript)
|
||||
|
||||
# 补充音频 URL
|
||||
audio_url = get_audio_url(call_id, turn_index)
|
||||
audio_url = get_audio_url(call_id, data.turn_index)
|
||||
|
||||
return {
|
||||
"id": transcript.id,
|
||||
"turn_index": turn_index,
|
||||
"speaker": speaker,
|
||||
"content": content,
|
||||
"confidence": confidence,
|
||||
"start_ms": start_ms,
|
||||
"end_ms": end_ms,
|
||||
"duration_ms": duration_ms,
|
||||
"turn_index": data.turn_index,
|
||||
"speaker": data.speaker,
|
||||
"content": data.content,
|
||||
"confidence": data.confidence,
|
||||
"start_ms": data.start_ms,
|
||||
"end_ms": data.end_ms,
|
||||
"duration_ms": data.duration_ms,
|
||||
"emotion": data.emotion,
|
||||
"audio_url": audio_url,
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from ..schemas import (
|
||||
KnowledgeBaseCreate, KnowledgeBaseUpdate, KnowledgeBaseOut,
|
||||
KnowledgeSearchQuery, KnowledgeSearchResult, KnowledgeStats,
|
||||
DocumentIndexRequest,
|
||||
KnowledgeDocumentCreate,
|
||||
)
|
||||
from ..vector_store import (
|
||||
vector_store, search_knowledge, index_document, delete_document_from_vector
|
||||
@@ -25,14 +26,14 @@ def kb_to_dict(kb: KnowledgeBase) -> dict:
|
||||
"user_id": kb.user_id,
|
||||
"name": kb.name,
|
||||
"description": kb.description,
|
||||
"embedding_model": kb.embedding_model,
|
||||
"chunk_size": kb.chunk_size,
|
||||
"chunk_overlap": kb.chunk_overlap,
|
||||
"doc_count": kb.doc_count,
|
||||
"chunk_count": kb.chunk_count,
|
||||
"embeddingModel": kb.embedding_model,
|
||||
"chunkSize": kb.chunk_size,
|
||||
"chunkOverlap": kb.chunk_overlap,
|
||||
"docCount": kb.doc_count,
|
||||
"chunkCount": kb.chunk_count,
|
||||
"status": kb.status,
|
||||
"created_at": kb.created_at.isoformat() if kb.created_at else None,
|
||||
"updated_at": kb.updated_at.isoformat() if kb.updated_at else None,
|
||||
"createdAt": kb.created_at.isoformat() if kb.created_at else None,
|
||||
"updatedAt": kb.updated_at.isoformat() if kb.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
@@ -42,28 +43,35 @@ def doc_to_dict(d: KnowledgeDocument) -> dict:
|
||||
"kb_id": d.kb_id,
|
||||
"name": d.name,
|
||||
"size": d.size,
|
||||
"file_type": d.file_type,
|
||||
"storage_url": d.storage_url,
|
||||
"fileType": d.file_type,
|
||||
"storageUrl": d.storage_url,
|
||||
"status": d.status,
|
||||
"chunk_count": d.chunk_count,
|
||||
"error_message": d.error_message,
|
||||
"upload_date": d.upload_date,
|
||||
"created_at": d.created_at.isoformat() if d.created_at else None,
|
||||
"processed_at": d.processed_at.isoformat() if d.processed_at else None,
|
||||
"chunkCount": d.chunk_count,
|
||||
"errorMessage": d.error_message,
|
||||
"uploadDate": d.upload_date,
|
||||
"createdAt": d.created_at.isoformat() if d.created_at else None,
|
||||
"processedAt": d.processed_at.isoformat() if d.processed_at else None,
|
||||
}
|
||||
|
||||
|
||||
# ============ Knowledge Bases ============
|
||||
@router.get("/bases")
|
||||
def list_knowledge_bases(user_id: int = 1, db: Session = Depends(get_db)):
|
||||
kbs = db.query(KnowledgeBase).filter(KnowledgeBase.user_id == user_id).all()
|
||||
def list_knowledge_bases(
|
||||
user_id: int = 1,
|
||||
page: int = 1,
|
||||
limit: int = 50,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
query = db.query(KnowledgeBase).filter(KnowledgeBase.user_id == user_id)
|
||||
total = query.count()
|
||||
kbs = query.order_by(KnowledgeBase.created_at.desc()).offset((page - 1) * limit).limit(limit).all()
|
||||
result = []
|
||||
for kb in kbs:
|
||||
docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb.id).all()
|
||||
kb_data = kb_to_dict(kb)
|
||||
kb_data["documents"] = [doc_to_dict(d) for d in docs]
|
||||
result.append(kb_data)
|
||||
return {"total": len(result), "list": result}
|
||||
return {"total": total, "page": page, "limit": limit, "list": result}
|
||||
|
||||
|
||||
@router.get("/bases/{kb_id}")
|
||||
@@ -91,7 +99,10 @@ def create_knowledge_base(data: KnowledgeBaseCreate, user_id: int = 1, db: Sessi
|
||||
db.add(kb)
|
||||
db.commit()
|
||||
db.refresh(kb)
|
||||
try:
|
||||
vector_store.create_collection(kb.id, data.embeddingModel)
|
||||
except Exception:
|
||||
pass
|
||||
return kb_to_dict(kb)
|
||||
|
||||
|
||||
@@ -101,8 +112,13 @@ def update_knowledge_base(kb_id: str, data: KnowledgeBaseUpdate, db: Session = D
|
||||
if not kb:
|
||||
raise HTTPException(status_code=404, detail="Knowledge base not found")
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
field_map = {
|
||||
"embeddingModel": "embedding_model",
|
||||
"chunkSize": "chunk_size",
|
||||
"chunkOverlap": "chunk_overlap",
|
||||
}
|
||||
for field, value in update_data.items():
|
||||
setattr(kb, field, value)
|
||||
setattr(kb, field_map.get(field, field), value)
|
||||
kb.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
db.refresh(kb)
|
||||
@@ -114,7 +130,10 @@ def delete_knowledge_base(kb_id: str, db: Session = Depends(get_db)):
|
||||
kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
|
||||
if not kb:
|
||||
raise HTTPException(status_code=404, detail="Knowledge base not found")
|
||||
try:
|
||||
vector_store.delete_collection(kb_id)
|
||||
except Exception:
|
||||
pass
|
||||
docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb_id).all()
|
||||
for doc in docs:
|
||||
db.delete(doc)
|
||||
@@ -127,10 +146,7 @@ def delete_knowledge_base(kb_id: str, db: Session = Depends(get_db)):
|
||||
@router.post("/bases/{kb_id}/documents")
|
||||
def upload_document(
|
||||
kb_id: str,
|
||||
name: str = Query(...),
|
||||
size: str = Query(...),
|
||||
file_type: str = Query("txt"),
|
||||
storage_url: Optional[str] = Query(None),
|
||||
data: KnowledgeDocumentCreate,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
|
||||
@@ -139,17 +155,25 @@ def upload_document(
|
||||
doc = KnowledgeDocument(
|
||||
id=str(uuid.uuid4())[:8],
|
||||
kb_id=kb_id,
|
||||
name=name,
|
||||
size=size,
|
||||
file_type=file_type,
|
||||
storage_url=storage_url,
|
||||
name=data.name,
|
||||
size=data.size,
|
||||
file_type=data.fileType,
|
||||
storage_url=data.storageUrl,
|
||||
status="pending",
|
||||
upload_date=datetime.utcnow().isoformat()
|
||||
)
|
||||
db.add(doc)
|
||||
db.commit()
|
||||
db.refresh(doc)
|
||||
return {"id": doc.id, "name": doc.name, "status": doc.status, "message": "Document created"}
|
||||
return {
|
||||
"id": doc.id,
|
||||
"name": doc.name,
|
||||
"size": doc.size,
|
||||
"fileType": doc.file_type,
|
||||
"storageUrl": doc.storage_url,
|
||||
"status": doc.status,
|
||||
"message": "Document created",
|
||||
}
|
||||
|
||||
|
||||
@router.post("/bases/{kb_id}/documents/{doc_id}/index")
|
||||
@@ -212,8 +236,9 @@ def delete_document(kb_id: str, doc_id: str, db: Session = Depends(get_db)):
|
||||
except Exception:
|
||||
pass
|
||||
kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
|
||||
kb.chunk_count -= doc.chunk_count
|
||||
kb.doc_count -= 1
|
||||
if kb:
|
||||
kb.chunk_count = max(0, kb.chunk_count - (doc.chunk_count or 0))
|
||||
kb.doc_count = max(0, kb.doc_count - 1)
|
||||
db.delete(doc)
|
||||
db.commit()
|
||||
return {"message": "Deleted successfully"}
|
||||
|
||||
@@ -10,14 +10,14 @@ from ..db import get_db
|
||||
from ..models import LLMModel
|
||||
from ..schemas import (
|
||||
LLMModelCreate, LLMModelUpdate, LLMModelOut,
|
||||
LLMModelTestResponse, ListResponse
|
||||
LLMModelTestResponse
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/llm", tags=["LLM Models"])
|
||||
|
||||
|
||||
# ============ LLM Models CRUD ============
|
||||
@router.get("", response_model=ListResponse)
|
||||
@router.get("")
|
||||
def list_llm_models(
|
||||
model_type: Optional[str] = None,
|
||||
enabled: Optional[bool] = None,
|
||||
|
||||
94
api/app/routers/voices.py
Normal file
94
api/app/routers/voices.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from ..db import get_db
|
||||
from ..models import Voice
|
||||
from ..schemas import VoiceCreate, VoiceUpdate, VoiceOut
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/voices")
|
||||
def list_voices(
|
||||
vendor: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
gender: Optional[str] = None,
|
||||
page: int = 1,
|
||||
limit: int = 50,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取声音库列表"""
|
||||
query = db.query(Voice)
|
||||
if vendor:
|
||||
query = query.filter(Voice.vendor == vendor)
|
||||
if language:
|
||||
query = query.filter(Voice.language == language)
|
||||
if gender:
|
||||
query = query.filter(Voice.gender == gender)
|
||||
|
||||
total = query.count()
|
||||
voices = query.order_by(Voice.created_at.desc()) \
|
||||
.offset((page - 1) * limit).limit(limit).all()
|
||||
return {"total": total, "page": page, "limit": limit, "list": voices}
|
||||
|
||||
|
||||
@router.post("/voices", response_model=VoiceOut)
|
||||
def create_voice(data: VoiceCreate, db: Session = Depends(get_db)):
|
||||
"""创建声音"""
|
||||
voice = Voice(
|
||||
id=data.id or str(uuid.uuid4())[:8],
|
||||
user_id=1,
|
||||
name=data.name,
|
||||
vendor=data.vendor,
|
||||
gender=data.gender,
|
||||
language=data.language,
|
||||
description=data.description,
|
||||
model=data.model,
|
||||
voice_key=data.voice_key,
|
||||
speed=data.speed,
|
||||
gain=data.gain,
|
||||
pitch=data.pitch,
|
||||
enabled=data.enabled,
|
||||
)
|
||||
db.add(voice)
|
||||
db.commit()
|
||||
db.refresh(voice)
|
||||
return voice
|
||||
|
||||
|
||||
@router.get("/voices/{id}", response_model=VoiceOut)
|
||||
def get_voice(id: str, db: Session = Depends(get_db)):
|
||||
"""获取单个声音详情"""
|
||||
voice = db.query(Voice).filter(Voice.id == id).first()
|
||||
if not voice:
|
||||
raise HTTPException(status_code=404, detail="Voice not found")
|
||||
return voice
|
||||
|
||||
|
||||
@router.put("/voices/{id}", response_model=VoiceOut)
|
||||
def update_voice(id: str, data: VoiceUpdate, db: Session = Depends(get_db)):
|
||||
"""更新声音"""
|
||||
voice = db.query(Voice).filter(Voice.id == id).first()
|
||||
if not voice:
|
||||
raise HTTPException(status_code=404, detail="Voice not found")
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(voice, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(voice)
|
||||
return voice
|
||||
|
||||
|
||||
@router.delete("/voices/{id}")
|
||||
def delete_voice(id: str, db: Session = Depends(get_db)):
|
||||
"""删除声音"""
|
||||
voice = db.query(Voice).filter(Voice.id == id).first()
|
||||
if not voice:
|
||||
raise HTTPException(status_code=404, detail="Voice not found")
|
||||
db.delete(voice)
|
||||
db.commit()
|
||||
return {"message": "Deleted successfully"}
|
||||
@@ -50,8 +50,9 @@ class VoiceBase(BaseModel):
|
||||
|
||||
|
||||
class VoiceCreate(VoiceBase):
|
||||
model: str # 厂商语音模型标识
|
||||
voice_key: str # 厂商voice_key
|
||||
id: Optional[str] = None
|
||||
model: Optional[str] = None # 厂商语音模型标识
|
||||
voice_key: Optional[str] = None # 厂商voice_key
|
||||
speed: float = 1.0
|
||||
gain: int = 0
|
||||
pitch: int = 0
|
||||
@@ -113,7 +114,7 @@ class LLMModelBase(BaseModel):
|
||||
|
||||
|
||||
class LLMModelCreate(LLMModelBase):
|
||||
pass
|
||||
id: Optional[str] = None
|
||||
|
||||
|
||||
class LLMModelUpdate(BaseModel):
|
||||
@@ -154,6 +155,7 @@ class ASRModelBase(BaseModel):
|
||||
|
||||
|
||||
class ASRModelCreate(ASRModelBase):
|
||||
id: Optional[str] = None
|
||||
hotwords: List[str] = []
|
||||
enable_punctuation: bool = True
|
||||
enable_normalization: bool = True
|
||||
@@ -195,6 +197,7 @@ class ASRTestResponse(BaseModel):
|
||||
confidence: Optional[float] = None
|
||||
duration_ms: Optional[int] = None
|
||||
latency_ms: Optional[int] = None
|
||||
message: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
@@ -413,6 +416,8 @@ class CallRecordCreate(BaseModel):
|
||||
user_id: int
|
||||
assistant_id: Optional[str] = None
|
||||
source: str = "debug"
|
||||
status: Optional[str] = None
|
||||
cost: Optional[float] = None
|
||||
|
||||
|
||||
class CallRecordUpdate(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user