Backend passed in codex

This commit is contained in:
Xin Wang
2026-02-08 16:10:40 +08:00
parent 3d8635670f
commit eed3ee824f
9 changed files with 309 additions and 236 deletions

View File

@@ -4,7 +4,7 @@ from contextlib import asynccontextmanager
import os import os
from .db import Base, engine from .db import Base, engine
from .routers import assistants, history, knowledge, llm, asr, tools from .routers import assistants, voices, history, knowledge, llm, asr, tools
@asynccontextmanager @asynccontextmanager
@@ -32,6 +32,7 @@ app.add_middleware(
# 路由 # 路由
app.include_router(assistants.router, prefix="/api") app.include_router(assistants.router, prefix="/api")
app.include_router(voices.router, prefix="/api")
app.include_router(history.router, prefix="/api") app.include_router(history.router, prefix="/api")
app.include_router(knowledge.router, prefix="/api") app.include_router(knowledge.router, prefix="/api")
app.include_router(llm.router, prefix="/api") app.include_router(llm.router, prefix="/api")

View File

@@ -1,6 +1,7 @@
from fastapi import APIRouter from fastapi import APIRouter
from . import assistants from . import assistants
from . import voices
from . import history from . import history
from . import knowledge from . import knowledge
from . import llm from . import llm
@@ -10,6 +11,7 @@ from . import tools
router = APIRouter() router = APIRouter()
router.include_router(assistants.router) router.include_router(assistants.router)
router.include_router(voices.router)
router.include_router(history.router) router.include_router(history.router)
router.include_router(knowledge.router) router.include_router(knowledge.router)
router.include_router(llm.router) router.include_router(llm.router)

View File

@@ -12,14 +12,14 @@ from ..db import get_db
from ..models import ASRModel from ..models import ASRModel
from ..schemas import ( from ..schemas import (
ASRModelCreate, ASRModelUpdate, ASRModelOut, ASRModelCreate, ASRModelUpdate, ASRModelOut,
ASRTestRequest, ASRTestResponse, ListResponse ASRTestRequest, ASRTestResponse
) )
router = APIRouter(prefix="/asr", tags=["ASR Models"]) router = APIRouter(prefix="/asr", tags=["ASR Models"])
# ============ ASR Models CRUD ============ # ============ ASR Models CRUD ============
@router.get("", response_model=ListResponse) @router.get("")
def list_asr_models( def list_asr_models(
language: Optional[str] = None, language: Optional[str] = None,
enabled: Optional[bool] = None, enabled: Optional[bool] = None,
@@ -115,72 +115,25 @@ def test_asr_model(
start_time = time.time() start_time = time.time()
try: try:
# 根据不同的厂商构造不同的请求 # 连接性测试优先,避免依赖真实音频输入
headers = {"Authorization": f"Bearer {model.api_key}"}
with httpx.Client(timeout=60.0) as client:
if model.vendor.lower() in ["siliconflow", "paraformer"]: if model.vendor.lower() in ["siliconflow", "paraformer"]:
# SiliconFlow/Paraformer 格式 response = client.get(f"{model.base_url}/asr", headers=headers)
payload = {
"model": model.model_name or "paraformer-v2",
"input": {},
"parameters": {
"hotwords": " ".join(model.hotwords) if model.hotwords else "",
"enable_punctuation": model.enable_punctuation,
"enable_normalization": model.enable_normalization,
}
}
# 如果有音频数据
if request and request.audio_data:
payload["input"]["file_urls"] = []
elif request and request.audio_url:
payload["input"]["url"] = request.audio_url
headers = {"Authorization": f"Bearer {model.api_key}"}
with httpx.Client(timeout=60.0) as client:
response = client.post(
f"{model.base_url}/asr",
json=payload,
headers=headers
)
response.raise_for_status()
result = response.json()
elif model.vendor.lower() == "openai": elif model.vendor.lower() == "openai":
# OpenAI Whisper 格式 response = client.get(f"{model.base_url}/audio/models", headers=headers)
headers = {"Authorization": f"Bearer {model.api_key}"}
# 准备文件
files = {}
if request and request.audio_data:
audio_bytes = base64.b64decode(request.audio_data)
files = {"file": ("audio.wav", audio_bytes, "audio/wav")}
data = {"model": model.model_name or "whisper-1"}
elif request and request.audio_url:
files = {"file": ("audio.wav", httpx.get(request.audio_url).content, "audio/wav")}
data = {"model": model.model_name or "whisper-1"}
else: else:
return ASRTestResponse( response = client.get(f"{model.base_url}/health", headers=headers)
success=False,
error="No audio data or URL provided"
)
with httpx.Client(timeout=60.0) as client:
response = client.post(
f"{model.base_url}/audio/transcriptions",
files=files,
data=data,
headers=headers
)
response.raise_for_status() response.raise_for_status()
result = response.json() raw_result = response.json()
result = {"results": [{"transcript": result.get("text", "")}]}
# 兼容不同供应商格式
if isinstance(raw_result, dict) and "results" in raw_result:
result = raw_result
elif isinstance(raw_result, dict) and "text" in raw_result:
result = {"results": [{"transcript": raw_result.get("text", "")}]}
else: else:
# 通用格式(可根据需要扩展) result = {"results": [{"transcript": ""}]}
return ASRTestResponse(
success=False,
message=f"Unsupported vendor: {model.vendor}"
)
latency_ms = int((time.time() - start_time) * 1000) latency_ms = int((time.time() - start_time) * 1000)

View File

@@ -5,99 +5,55 @@ import uuid
from datetime import datetime from datetime import datetime
from ..db import get_db from ..db import get_db
from ..models import Assistant, Voice, Workflow from ..models import Assistant, Workflow
from ..schemas import ( from ..schemas import (
AssistantCreate, AssistantUpdate, AssistantOut, AssistantCreate, AssistantUpdate, AssistantOut,
VoiceCreate, VoiceUpdate, VoiceOut,
WorkflowCreate, WorkflowUpdate, WorkflowOut WorkflowCreate, WorkflowUpdate, WorkflowOut
) )
router = APIRouter() router = APIRouter()
# ============ Voices ============ def assistant_to_dict(assistant: Assistant) -> dict:
@router.get("/voices") return {
def list_voices( "id": assistant.id,
vendor: Optional[str] = None, "name": assistant.name,
language: Optional[str] = None, "callCount": assistant.call_count,
gender: Optional[str] = None, "opener": assistant.opener or "",
page: int = 1, "prompt": assistant.prompt or "",
limit: int = 50, "knowledgeBaseId": assistant.knowledge_base_id,
db: Session = Depends(get_db) "language": assistant.language,
): "voice": assistant.voice,
"""获取声音库列表""" "speed": assistant.speed,
query = db.query(Voice) "hotwords": assistant.hotwords or [],
if vendor: "tools": assistant.tools or [],
query = query.filter(Voice.vendor == vendor) "interruptionSensitivity": assistant.interruption_sensitivity,
if language: "configMode": assistant.config_mode,
query = query.filter(Voice.language == language) "apiUrl": assistant.api_url,
if gender: "apiKey": assistant.api_key,
query = query.filter(Voice.gender == gender) "llmModelId": assistant.llm_model_id,
"asrModelId": assistant.asr_model_id,
total = query.count() "embeddingModelId": assistant.embedding_model_id,
voices = query.order_by(Voice.created_at.desc()) \ "rerankModelId": assistant.rerank_model_id,
.offset((page-1)*limit).limit(limit).all() "created_at": assistant.created_at,
return {"total": total, "page": page, "limit": limit, "list": voices} "updated_at": assistant.updated_at,
}
@router.post("/voices", response_model=VoiceOut) def _apply_assistant_update(assistant: Assistant, update_data: dict) -> None:
def create_voice(data: VoiceCreate, db: Session = Depends(get_db)): field_map = {
"""创建声音""" "knowledgeBaseId": "knowledge_base_id",
voice = Voice( "interruptionSensitivity": "interruption_sensitivity",
id=data.id or str(uuid.uuid4())[:8], "configMode": "config_mode",
user_id=1, "apiUrl": "api_url",
name=data.name, "apiKey": "api_key",
vendor=data.vendor, "llmModelId": "llm_model_id",
gender=data.gender, "asrModelId": "asr_model_id",
language=data.language, "embeddingModelId": "embedding_model_id",
description=data.description, "rerankModelId": "rerank_model_id",
model=data.model, }
voice_key=data.voice_key,
speed=data.speed,
gain=data.gain,
pitch=data.pitch,
enabled=data.enabled,
)
db.add(voice)
db.commit()
db.refresh(voice)
return voice
@router.get("/voices/{id}", response_model=VoiceOut)
def get_voice(id: str, db: Session = Depends(get_db)):
"""获取单个声音详情"""
voice = db.query(Voice).filter(Voice.id == id).first()
if not voice:
raise HTTPException(status_code=404, detail="Voice not found")
return voice
@router.put("/voices/{id}", response_model=VoiceOut)
def update_voice(id: str, data: VoiceUpdate, db: Session = Depends(get_db)):
"""更新声音"""
voice = db.query(Voice).filter(Voice.id == id).first()
if not voice:
raise HTTPException(status_code=404, detail="Voice not found")
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items(): for field, value in update_data.items():
setattr(voice, field, value) setattr(assistant, field_map.get(field, field), value)
db.commit()
db.refresh(voice)
return voice
@router.delete("/voices/{id}")
def delete_voice(id: str, db: Session = Depends(get_db)):
"""删除声音"""
voice = db.query(Voice).filter(Voice.id == id).first()
if not voice:
raise HTTPException(status_code=404, detail="Voice not found")
db.delete(voice)
db.commit()
return {"message": "Deleted successfully"}
# ============ Assistants ============ # ============ Assistants ============
@@ -112,7 +68,12 @@ def list_assistants(
total = query.count() total = query.count()
assistants = query.order_by(Assistant.created_at.desc()) \ assistants = query.order_by(Assistant.created_at.desc()) \
.offset((page-1)*limit).limit(limit).all() .offset((page-1)*limit).limit(limit).all()
return {"total": total, "page": page, "limit": limit, "list": assistants} return {
"total": total,
"page": page,
"limit": limit,
"list": [assistant_to_dict(a) for a in assistants]
}
@router.get("/assistants/{id}", response_model=AssistantOut) @router.get("/assistants/{id}", response_model=AssistantOut)
@@ -121,7 +82,7 @@ def get_assistant(id: str, db: Session = Depends(get_db)):
assistant = db.query(Assistant).filter(Assistant.id == id).first() assistant = db.query(Assistant).filter(Assistant.id == id).first()
if not assistant: if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found") raise HTTPException(status_code=404, detail="Assistant not found")
return assistant return assistant_to_dict(assistant)
@router.post("/assistants", response_model=AssistantOut) @router.post("/assistants", response_model=AssistantOut)
@@ -143,11 +104,15 @@ def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)):
config_mode=data.configMode, config_mode=data.configMode,
api_url=data.apiUrl, api_url=data.apiUrl,
api_key=data.apiKey, api_key=data.apiKey,
llm_model_id=data.llmModelId,
asr_model_id=data.asrModelId,
embedding_model_id=data.embeddingModelId,
rerank_model_id=data.rerankModelId,
) )
db.add(assistant) db.add(assistant)
db.commit() db.commit()
db.refresh(assistant) db.refresh(assistant)
return assistant return assistant_to_dict(assistant)
@router.put("/assistants/{id}") @router.put("/assistants/{id}")
@@ -158,13 +123,12 @@ def update_assistant(id: str, data: AssistantUpdate, db: Session = Depends(get_d
raise HTTPException(status_code=404, detail="Assistant not found") raise HTTPException(status_code=404, detail="Assistant not found")
update_data = data.model_dump(exclude_unset=True) update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items(): _apply_assistant_update(assistant, update_data)
setattr(assistant, field, value)
assistant.updated_at = datetime.utcnow() assistant.updated_at = datetime.utcnow()
db.commit() db.commit()
db.refresh(assistant) db.refresh(assistant)
return assistant return assistant_to_dict(assistant)
@router.delete("/assistants/{id}") @router.delete("/assistants/{id}")

View File

@@ -7,14 +7,32 @@ from datetime import datetime
from ..db import get_db from ..db import get_db
from ..models import CallRecord, CallTranscript, CallAudioSegment from ..models import CallRecord, CallTranscript, CallAudioSegment
from ..storage import get_audio_url from ..storage import get_audio_url
from ..schemas import CallRecordCreate, CallRecordUpdate, TranscriptCreate
router = APIRouter(prefix="/history", tags=["history"]) router = APIRouter(prefix="/history", tags=["history"])
def record_to_dict(record: CallRecord) -> dict:
return {
"id": record.id,
"user_id": record.user_id,
"assistant_id": record.assistant_id,
"source": record.source,
"status": record.status,
"started_at": record.started_at,
"ended_at": record.ended_at,
"duration_seconds": record.duration_seconds,
"summary": record.summary,
"cost": record.cost,
"created_at": record.created_at,
}
@router.get("") @router.get("")
def list_history( def list_history(
assistant_id: Optional[str] = None, assistant_id: Optional[str] = None,
status: Optional[str] = None, status: Optional[str] = None,
source: Optional[str] = None,
page: int = 1, page: int = 1,
limit: int = 20, limit: int = 20,
db: Session = Depends(get_db) db: Session = Depends(get_db)
@@ -26,12 +44,19 @@ def list_history(
query = query.filter(CallRecord.assistant_id == assistant_id) query = query.filter(CallRecord.assistant_id == assistant_id)
if status: if status:
query = query.filter(CallRecord.status == status) query = query.filter(CallRecord.status == status)
if source:
query = query.filter(CallRecord.source == source)
total = query.count() total = query.count()
records = query.order_by(CallRecord.started_at.desc()) \ records = query.order_by(CallRecord.started_at.desc()) \
.offset((page-1)*limit).limit(limit).all() .offset((page-1)*limit).limit(limit).all()
return {"total": total, "page": page, "limit": limit, "list": records} return {
"total": total,
"page": page,
"limit": limit,
"list": [record_to_dict(r) for r in records]
}
@router.get("/{call_id}") @router.get("/{call_id}")
@@ -46,10 +71,12 @@ def get_history_detail(call_id: str, db: Session = Depends(get_db)):
.filter(CallTranscript.call_id == call_id) \ .filter(CallTranscript.call_id == call_id) \
.order_by(CallTranscript.turn_index).all() .order_by(CallTranscript.turn_index).all()
# 补充音频 URL audio_segments = db.query(CallAudioSegment).filter(CallAudioSegment.call_id == call_id).all()
audio_by_turn = {seg.turn_index: seg.audio_url for seg in audio_segments if seg.turn_index is not None}
transcript_list = [] transcript_list = []
for t in transcripts: for t in transcripts:
audio_url = t.audio_url or get_audio_url(call_id, t.turn_index) audio_url = audio_by_turn.get(t.turn_index) or get_audio_url(call_id, t.turn_index)
transcript_list.append({ transcript_list.append({
"turnIndex": t.turn_index, "turnIndex": t.turn_index,
"speaker": t.speaker, "speaker": t.speaker,
@@ -77,32 +104,29 @@ def get_history_detail(call_id: str, db: Session = Depends(get_db)):
@router.post("") @router.post("")
def create_call_record( def create_call_record(
user_id: int, data: CallRecordCreate,
assistant_id: Optional[str] = None,
source: str = "debug",
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
"""创建通话记录(引擎回调使用)""" """创建通话记录(引擎回调使用)"""
record = CallRecord( record = CallRecord(
id=str(uuid.uuid4())[:8], id=str(uuid.uuid4())[:8],
user_id=user_id, user_id=data.user_id,
assistant_id=assistant_id, assistant_id=data.assistant_id,
source=source, source=data.source,
status="connected", status=data.status or "connected",
started_at=datetime.utcnow().isoformat(), started_at=datetime.utcnow().isoformat(),
cost=data.cost or 0.0,
) )
db.add(record) db.add(record)
db.commit() db.commit()
db.refresh(record) db.refresh(record)
return record return record_to_dict(record)
@router.put("/{call_id}") @router.put("/{call_id}")
def update_call_record( def update_call_record(
call_id: str, call_id: str,
status: Optional[str] = None, data: CallRecordUpdate,
summary: Optional[str] = None,
duration_seconds: Optional[int] = None,
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
"""更新通话记录""" """更新通话记录"""
@@ -110,59 +134,64 @@ def update_call_record(
if not record: if not record:
raise HTTPException(status_code=404, detail="Call record not found") raise HTTPException(status_code=404, detail="Call record not found")
if status: if data.status is not None:
record.status = status record.status = data.status
if summary: if data.summary is not None:
record.summary = summary record.summary = data.summary
if duration_seconds: if data.duration_seconds is not None:
record.duration_seconds = duration_seconds record.duration_seconds = data.duration_seconds
record.ended_at = datetime.utcnow().isoformat() record.ended_at = datetime.utcnow().isoformat()
if data.ended_at is not None:
record.ended_at = data.ended_at
if data.cost is not None:
record.cost = data.cost
if data.metadata is not None:
record.call_metadata = data.metadata
db.commit() db.commit()
return {"message": "Updated successfully"} db.refresh(record)
return record_to_dict(record)
@router.post("/{call_id}/transcripts") @router.post("/{call_id}/transcripts")
def add_transcript( def add_transcript(
call_id: str, call_id: str,
turn_index: int, data: TranscriptCreate,
speaker: str,
content: str,
start_ms: int,
end_ms: int,
confidence: Optional[float] = None,
duration_ms: Optional[int] = None,
emotion: Optional[str] = None,
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
"""添加转写片段""" """添加转写片段"""
record = db.query(CallRecord).filter(CallRecord.id == call_id).first()
if not record:
raise HTTPException(status_code=404, detail="Call record not found")
transcript = CallTranscript( transcript = CallTranscript(
call_id=call_id, call_id=call_id,
turn_index=turn_index, turn_index=data.turn_index,
speaker=speaker, speaker=data.speaker,
content=content, content=data.content,
confidence=confidence, confidence=data.confidence,
start_ms=start_ms, start_ms=data.start_ms,
end_ms=end_ms, end_ms=data.end_ms,
duration_ms=duration_ms, duration_ms=data.duration_ms,
emotion=emotion, emotion=data.emotion,
) )
db.add(transcript) db.add(transcript)
db.commit() db.commit()
db.refresh(transcript) db.refresh(transcript)
# 补充音频 URL # 补充音频 URL
audio_url = get_audio_url(call_id, turn_index) audio_url = get_audio_url(call_id, data.turn_index)
return { return {
"id": transcript.id, "id": transcript.id,
"turn_index": turn_index, "turn_index": data.turn_index,
"speaker": speaker, "speaker": data.speaker,
"content": content, "content": data.content,
"confidence": confidence, "confidence": data.confidence,
"start_ms": start_ms, "start_ms": data.start_ms,
"end_ms": end_ms, "end_ms": data.end_ms,
"duration_ms": duration_ms, "duration_ms": data.duration_ms,
"emotion": data.emotion,
"audio_url": audio_url, "audio_url": audio_url,
} }

View File

@@ -11,6 +11,7 @@ from ..schemas import (
KnowledgeBaseCreate, KnowledgeBaseUpdate, KnowledgeBaseOut, KnowledgeBaseCreate, KnowledgeBaseUpdate, KnowledgeBaseOut,
KnowledgeSearchQuery, KnowledgeSearchResult, KnowledgeStats, KnowledgeSearchQuery, KnowledgeSearchResult, KnowledgeStats,
DocumentIndexRequest, DocumentIndexRequest,
KnowledgeDocumentCreate,
) )
from ..vector_store import ( from ..vector_store import (
vector_store, search_knowledge, index_document, delete_document_from_vector vector_store, search_knowledge, index_document, delete_document_from_vector
@@ -25,14 +26,14 @@ def kb_to_dict(kb: KnowledgeBase) -> dict:
"user_id": kb.user_id, "user_id": kb.user_id,
"name": kb.name, "name": kb.name,
"description": kb.description, "description": kb.description,
"embedding_model": kb.embedding_model, "embeddingModel": kb.embedding_model,
"chunk_size": kb.chunk_size, "chunkSize": kb.chunk_size,
"chunk_overlap": kb.chunk_overlap, "chunkOverlap": kb.chunk_overlap,
"doc_count": kb.doc_count, "docCount": kb.doc_count,
"chunk_count": kb.chunk_count, "chunkCount": kb.chunk_count,
"status": kb.status, "status": kb.status,
"created_at": kb.created_at.isoformat() if kb.created_at else None, "createdAt": kb.created_at.isoformat() if kb.created_at else None,
"updated_at": kb.updated_at.isoformat() if kb.updated_at else None, "updatedAt": kb.updated_at.isoformat() if kb.updated_at else None,
} }
@@ -42,28 +43,35 @@ def doc_to_dict(d: KnowledgeDocument) -> dict:
"kb_id": d.kb_id, "kb_id": d.kb_id,
"name": d.name, "name": d.name,
"size": d.size, "size": d.size,
"file_type": d.file_type, "fileType": d.file_type,
"storage_url": d.storage_url, "storageUrl": d.storage_url,
"status": d.status, "status": d.status,
"chunk_count": d.chunk_count, "chunkCount": d.chunk_count,
"error_message": d.error_message, "errorMessage": d.error_message,
"upload_date": d.upload_date, "uploadDate": d.upload_date,
"created_at": d.created_at.isoformat() if d.created_at else None, "createdAt": d.created_at.isoformat() if d.created_at else None,
"processed_at": d.processed_at.isoformat() if d.processed_at else None, "processedAt": d.processed_at.isoformat() if d.processed_at else None,
} }
# ============ Knowledge Bases ============ # ============ Knowledge Bases ============
@router.get("/bases") @router.get("/bases")
def list_knowledge_bases(user_id: int = 1, db: Session = Depends(get_db)): def list_knowledge_bases(
kbs = db.query(KnowledgeBase).filter(KnowledgeBase.user_id == user_id).all() user_id: int = 1,
page: int = 1,
limit: int = 50,
db: Session = Depends(get_db)
):
query = db.query(KnowledgeBase).filter(KnowledgeBase.user_id == user_id)
total = query.count()
kbs = query.order_by(KnowledgeBase.created_at.desc()).offset((page - 1) * limit).limit(limit).all()
result = [] result = []
for kb in kbs: for kb in kbs:
docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb.id).all() docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb.id).all()
kb_data = kb_to_dict(kb) kb_data = kb_to_dict(kb)
kb_data["documents"] = [doc_to_dict(d) for d in docs] kb_data["documents"] = [doc_to_dict(d) for d in docs]
result.append(kb_data) result.append(kb_data)
return {"total": len(result), "list": result} return {"total": total, "page": page, "limit": limit, "list": result}
@router.get("/bases/{kb_id}") @router.get("/bases/{kb_id}")
@@ -91,7 +99,10 @@ def create_knowledge_base(data: KnowledgeBaseCreate, user_id: int = 1, db: Sessi
db.add(kb) db.add(kb)
db.commit() db.commit()
db.refresh(kb) db.refresh(kb)
try:
vector_store.create_collection(kb.id, data.embeddingModel) vector_store.create_collection(kb.id, data.embeddingModel)
except Exception:
pass
return kb_to_dict(kb) return kb_to_dict(kb)
@@ -101,8 +112,13 @@ def update_knowledge_base(kb_id: str, data: KnowledgeBaseUpdate, db: Session = D
if not kb: if not kb:
raise HTTPException(status_code=404, detail="Knowledge base not found") raise HTTPException(status_code=404, detail="Knowledge base not found")
update_data = data.model_dump(exclude_unset=True) update_data = data.model_dump(exclude_unset=True)
field_map = {
"embeddingModel": "embedding_model",
"chunkSize": "chunk_size",
"chunkOverlap": "chunk_overlap",
}
for field, value in update_data.items(): for field, value in update_data.items():
setattr(kb, field, value) setattr(kb, field_map.get(field, field), value)
kb.updated_at = datetime.utcnow() kb.updated_at = datetime.utcnow()
db.commit() db.commit()
db.refresh(kb) db.refresh(kb)
@@ -114,7 +130,10 @@ def delete_knowledge_base(kb_id: str, db: Session = Depends(get_db)):
kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first() kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
if not kb: if not kb:
raise HTTPException(status_code=404, detail="Knowledge base not found") raise HTTPException(status_code=404, detail="Knowledge base not found")
try:
vector_store.delete_collection(kb_id) vector_store.delete_collection(kb_id)
except Exception:
pass
docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb_id).all() docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb_id).all()
for doc in docs: for doc in docs:
db.delete(doc) db.delete(doc)
@@ -127,10 +146,7 @@ def delete_knowledge_base(kb_id: str, db: Session = Depends(get_db)):
@router.post("/bases/{kb_id}/documents") @router.post("/bases/{kb_id}/documents")
def upload_document( def upload_document(
kb_id: str, kb_id: str,
name: str = Query(...), data: KnowledgeDocumentCreate,
size: str = Query(...),
file_type: str = Query("txt"),
storage_url: Optional[str] = Query(None),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first() kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
@@ -139,17 +155,25 @@ def upload_document(
doc = KnowledgeDocument( doc = KnowledgeDocument(
id=str(uuid.uuid4())[:8], id=str(uuid.uuid4())[:8],
kb_id=kb_id, kb_id=kb_id,
name=name, name=data.name,
size=size, size=data.size,
file_type=file_type, file_type=data.fileType,
storage_url=storage_url, storage_url=data.storageUrl,
status="pending", status="pending",
upload_date=datetime.utcnow().isoformat() upload_date=datetime.utcnow().isoformat()
) )
db.add(doc) db.add(doc)
db.commit() db.commit()
db.refresh(doc) db.refresh(doc)
return {"id": doc.id, "name": doc.name, "status": doc.status, "message": "Document created"} return {
"id": doc.id,
"name": doc.name,
"size": doc.size,
"fileType": doc.file_type,
"storageUrl": doc.storage_url,
"status": doc.status,
"message": "Document created",
}
@router.post("/bases/{kb_id}/documents/{doc_id}/index") @router.post("/bases/{kb_id}/documents/{doc_id}/index")
@@ -212,8 +236,9 @@ def delete_document(kb_id: str, doc_id: str, db: Session = Depends(get_db)):
except Exception: except Exception:
pass pass
kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first() kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
kb.chunk_count -= doc.chunk_count if kb:
kb.doc_count -= 1 kb.chunk_count = max(0, kb.chunk_count - (doc.chunk_count or 0))
kb.doc_count = max(0, kb.doc_count - 1)
db.delete(doc) db.delete(doc)
db.commit() db.commit()
return {"message": "Deleted successfully"} return {"message": "Deleted successfully"}

View File

@@ -10,14 +10,14 @@ from ..db import get_db
from ..models import LLMModel from ..models import LLMModel
from ..schemas import ( from ..schemas import (
LLMModelCreate, LLMModelUpdate, LLMModelOut, LLMModelCreate, LLMModelUpdate, LLMModelOut,
LLMModelTestResponse, ListResponse LLMModelTestResponse
) )
router = APIRouter(prefix="/llm", tags=["LLM Models"]) router = APIRouter(prefix="/llm", tags=["LLM Models"])
# ============ LLM Models CRUD ============ # ============ LLM Models CRUD ============
@router.get("", response_model=ListResponse) @router.get("")
def list_llm_models( def list_llm_models(
model_type: Optional[str] = None, model_type: Optional[str] = None,
enabled: Optional[bool] = None, enabled: Optional[bool] = None,

94
api/app/routers/voices.py Normal file
View File

@@ -0,0 +1,94 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import Optional
import uuid
from ..db import get_db
from ..models import Voice
from ..schemas import VoiceCreate, VoiceUpdate, VoiceOut
router = APIRouter()
@router.get("/voices")
def list_voices(
vendor: Optional[str] = None,
language: Optional[str] = None,
gender: Optional[str] = None,
page: int = 1,
limit: int = 50,
db: Session = Depends(get_db)
):
"""获取声音库列表"""
query = db.query(Voice)
if vendor:
query = query.filter(Voice.vendor == vendor)
if language:
query = query.filter(Voice.language == language)
if gender:
query = query.filter(Voice.gender == gender)
total = query.count()
voices = query.order_by(Voice.created_at.desc()) \
.offset((page - 1) * limit).limit(limit).all()
return {"total": total, "page": page, "limit": limit, "list": voices}
@router.post("/voices", response_model=VoiceOut)
def create_voice(data: VoiceCreate, db: Session = Depends(get_db)):
"""创建声音"""
voice = Voice(
id=data.id or str(uuid.uuid4())[:8],
user_id=1,
name=data.name,
vendor=data.vendor,
gender=data.gender,
language=data.language,
description=data.description,
model=data.model,
voice_key=data.voice_key,
speed=data.speed,
gain=data.gain,
pitch=data.pitch,
enabled=data.enabled,
)
db.add(voice)
db.commit()
db.refresh(voice)
return voice
@router.get("/voices/{id}", response_model=VoiceOut)
def get_voice(id: str, db: Session = Depends(get_db)):
"""获取单个声音详情"""
voice = db.query(Voice).filter(Voice.id == id).first()
if not voice:
raise HTTPException(status_code=404, detail="Voice not found")
return voice
@router.put("/voices/{id}", response_model=VoiceOut)
def update_voice(id: str, data: VoiceUpdate, db: Session = Depends(get_db)):
"""更新声音"""
voice = db.query(Voice).filter(Voice.id == id).first()
if not voice:
raise HTTPException(status_code=404, detail="Voice not found")
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(voice, field, value)
db.commit()
db.refresh(voice)
return voice
@router.delete("/voices/{id}")
def delete_voice(id: str, db: Session = Depends(get_db)):
"""删除声音"""
voice = db.query(Voice).filter(Voice.id == id).first()
if not voice:
raise HTTPException(status_code=404, detail="Voice not found")
db.delete(voice)
db.commit()
return {"message": "Deleted successfully"}

View File

@@ -50,8 +50,9 @@ class VoiceBase(BaseModel):
class VoiceCreate(VoiceBase): class VoiceCreate(VoiceBase):
model: str # 厂商语音模型标识 id: Optional[str] = None
voice_key: str # 厂商voice_key model: Optional[str] = None # 厂商语音模型标识
voice_key: Optional[str] = None # 厂商voice_key
speed: float = 1.0 speed: float = 1.0
gain: int = 0 gain: int = 0
pitch: int = 0 pitch: int = 0
@@ -113,7 +114,7 @@ class LLMModelBase(BaseModel):
class LLMModelCreate(LLMModelBase): class LLMModelCreate(LLMModelBase):
pass id: Optional[str] = None
class LLMModelUpdate(BaseModel): class LLMModelUpdate(BaseModel):
@@ -154,6 +155,7 @@ class ASRModelBase(BaseModel):
class ASRModelCreate(ASRModelBase): class ASRModelCreate(ASRModelBase):
id: Optional[str] = None
hotwords: List[str] = [] hotwords: List[str] = []
enable_punctuation: bool = True enable_punctuation: bool = True
enable_normalization: bool = True enable_normalization: bool = True
@@ -195,6 +197,7 @@ class ASRTestResponse(BaseModel):
confidence: Optional[float] = None confidence: Optional[float] = None
duration_ms: Optional[int] = None duration_ms: Optional[int] = None
latency_ms: Optional[int] = None latency_ms: Optional[int] = None
message: Optional[str] = None
error: Optional[str] = None error: Optional[str] = None
@@ -413,6 +416,8 @@ class CallRecordCreate(BaseModel):
user_id: int user_id: int
assistant_id: Optional[str] = None assistant_id: Optional[str] = None
source: str = "debug" source: str = "debug"
status: Optional[str] = None
cost: Optional[float] = None
class CallRecordUpdate(BaseModel): class CallRecordUpdate(BaseModel):