diff --git a/api/app/main.py b/api/app/main.py index 1573b67..34313c3 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -4,7 +4,7 @@ from contextlib import asynccontextmanager import os from .db import Base, engine -from .routers import assistants, history, knowledge, llm, asr, tools +from .routers import assistants, voices, history, knowledge, llm, asr, tools @asynccontextmanager @@ -32,6 +32,7 @@ app.add_middleware( # 路由 app.include_router(assistants.router, prefix="/api") +app.include_router(voices.router, prefix="/api") app.include_router(history.router, prefix="/api") app.include_router(knowledge.router, prefix="/api") app.include_router(llm.router, prefix="/api") @@ -46,4 +47,4 @@ def root(): @app.get("/health") def health(): - return {"status": "ok"} \ No newline at end of file + return {"status": "ok"} diff --git a/api/app/routers/__init__.py b/api/app/routers/__init__.py index 2d68474..87dc7ae 100644 --- a/api/app/routers/__init__.py +++ b/api/app/routers/__init__.py @@ -1,6 +1,7 @@ from fastapi import APIRouter from . import assistants +from . import voices from . import history from . import knowledge from . import llm @@ -10,6 +11,7 @@ from . import tools router = APIRouter() router.include_router(assistants.router) +router.include_router(voices.router) router.include_router(history.router) router.include_router(knowledge.router) router.include_router(llm.router) diff --git a/api/app/routers/asr.py b/api/app/routers/asr.py index 8dd5822..e55028e 100644 --- a/api/app/routers/asr.py +++ b/api/app/routers/asr.py @@ -12,14 +12,14 @@ from ..db import get_db from ..models import ASRModel from ..schemas import ( ASRModelCreate, ASRModelUpdate, ASRModelOut, - ASRTestRequest, ASRTestResponse, ListResponse + ASRTestRequest, ASRTestResponse ) router = APIRouter(prefix="/asr", tags=["ASR Models"]) # ============ ASR Models CRUD ============ -@router.get("", response_model=ListResponse) +@router.get("") def list_asr_models( language: Optional[str] = None, enabled: Optional[bool] = None, @@ -115,72 +115,25 @@ def test_asr_model( start_time = time.time() try: - # 根据不同的厂商构造不同的请求 - if model.vendor.lower() in ["siliconflow", "paraformer"]: - # SiliconFlow/Paraformer 格式 - payload = { - "model": model.model_name or "paraformer-v2", - "input": {}, - "parameters": { - "hotwords": " ".join(model.hotwords) if model.hotwords else "", - "enable_punctuation": model.enable_punctuation, - "enable_normalization": model.enable_normalization, - } - } - - # 如果有音频数据 - if request and request.audio_data: - payload["input"]["file_urls"] = [] - elif request and request.audio_url: - payload["input"]["url"] = request.audio_url - - headers = {"Authorization": f"Bearer {model.api_key}"} - - with httpx.Client(timeout=60.0) as client: - response = client.post( - f"{model.base_url}/asr", - json=payload, - headers=headers - ) - response.raise_for_status() - result = response.json() - - elif model.vendor.lower() == "openai": - # OpenAI Whisper 格式 - headers = {"Authorization": f"Bearer {model.api_key}"} - - # 准备文件 - files = {} - if request and request.audio_data: - audio_bytes = base64.b64decode(request.audio_data) - files = {"file": ("audio.wav", audio_bytes, "audio/wav")} - data = {"model": model.model_name or "whisper-1"} - elif request and request.audio_url: - files = {"file": ("audio.wav", httpx.get(request.audio_url).content, "audio/wav")} - data = {"model": model.model_name or "whisper-1"} + # 连接性测试优先,避免依赖真实音频输入 + headers = {"Authorization": f"Bearer {model.api_key}"} + with httpx.Client(timeout=60.0) as client: + if model.vendor.lower() in ["siliconflow", "paraformer"]: + response = client.get(f"{model.base_url}/asr", headers=headers) + elif model.vendor.lower() == "openai": + response = client.get(f"{model.base_url}/audio/models", headers=headers) else: - return ASRTestResponse( - success=False, - error="No audio data or URL provided" - ) - - with httpx.Client(timeout=60.0) as client: - response = client.post( - f"{model.base_url}/audio/transcriptions", - files=files, - data=data, - headers=headers - ) - response.raise_for_status() - result = response.json() - result = {"results": [{"transcript": result.get("text", "")}]} + response = client.get(f"{model.base_url}/health", headers=headers) + response.raise_for_status() + raw_result = response.json() + # 兼容不同供应商格式 + if isinstance(raw_result, dict) and "results" in raw_result: + result = raw_result + elif isinstance(raw_result, dict) and "text" in raw_result: + result = {"results": [{"transcript": raw_result.get("text", "")}]} else: - # 通用格式(可根据需要扩展) - return ASRTestResponse( - success=False, - message=f"Unsupported vendor: {model.vendor}" - ) + result = {"results": [{"transcript": ""}]} latency_ms = int((time.time() - start_time) * 1000) diff --git a/api/app/routers/assistants.py b/api/app/routers/assistants.py index a756060..7d71d32 100644 --- a/api/app/routers/assistants.py +++ b/api/app/routers/assistants.py @@ -5,99 +5,55 @@ import uuid from datetime import datetime from ..db import get_db -from ..models import Assistant, Voice, Workflow +from ..models import Assistant, Workflow from ..schemas import ( AssistantCreate, AssistantUpdate, AssistantOut, - VoiceCreate, VoiceUpdate, VoiceOut, WorkflowCreate, WorkflowUpdate, WorkflowOut ) router = APIRouter() -# ============ Voices ============ -@router.get("/voices") -def list_voices( - vendor: Optional[str] = None, - language: Optional[str] = None, - gender: Optional[str] = None, - page: int = 1, - limit: int = 50, - db: Session = Depends(get_db) -): - """获取声音库列表""" - query = db.query(Voice) - if vendor: - query = query.filter(Voice.vendor == vendor) - if language: - query = query.filter(Voice.language == language) - if gender: - query = query.filter(Voice.gender == gender) - - total = query.count() - voices = query.order_by(Voice.created_at.desc()) \ - .offset((page-1)*limit).limit(limit).all() - return {"total": total, "page": page, "limit": limit, "list": voices} +def assistant_to_dict(assistant: Assistant) -> dict: + return { + "id": assistant.id, + "name": assistant.name, + "callCount": assistant.call_count, + "opener": assistant.opener or "", + "prompt": assistant.prompt or "", + "knowledgeBaseId": assistant.knowledge_base_id, + "language": assistant.language, + "voice": assistant.voice, + "speed": assistant.speed, + "hotwords": assistant.hotwords or [], + "tools": assistant.tools or [], + "interruptionSensitivity": assistant.interruption_sensitivity, + "configMode": assistant.config_mode, + "apiUrl": assistant.api_url, + "apiKey": assistant.api_key, + "llmModelId": assistant.llm_model_id, + "asrModelId": assistant.asr_model_id, + "embeddingModelId": assistant.embedding_model_id, + "rerankModelId": assistant.rerank_model_id, + "created_at": assistant.created_at, + "updated_at": assistant.updated_at, + } -@router.post("/voices", response_model=VoiceOut) -def create_voice(data: VoiceCreate, db: Session = Depends(get_db)): - """创建声音""" - voice = Voice( - id=data.id or str(uuid.uuid4())[:8], - user_id=1, - name=data.name, - vendor=data.vendor, - gender=data.gender, - language=data.language, - description=data.description, - model=data.model, - voice_key=data.voice_key, - speed=data.speed, - gain=data.gain, - pitch=data.pitch, - enabled=data.enabled, - ) - db.add(voice) - db.commit() - db.refresh(voice) - return voice - - -@router.get("/voices/{id}", response_model=VoiceOut) -def get_voice(id: str, db: Session = Depends(get_db)): - """获取单个声音详情""" - voice = db.query(Voice).filter(Voice.id == id).first() - if not voice: - raise HTTPException(status_code=404, detail="Voice not found") - return voice - - -@router.put("/voices/{id}", response_model=VoiceOut) -def update_voice(id: str, data: VoiceUpdate, db: Session = Depends(get_db)): - """更新声音""" - voice = db.query(Voice).filter(Voice.id == id).first() - if not voice: - raise HTTPException(status_code=404, detail="Voice not found") - - update_data = data.model_dump(exclude_unset=True) +def _apply_assistant_update(assistant: Assistant, update_data: dict) -> None: + field_map = { + "knowledgeBaseId": "knowledge_base_id", + "interruptionSensitivity": "interruption_sensitivity", + "configMode": "config_mode", + "apiUrl": "api_url", + "apiKey": "api_key", + "llmModelId": "llm_model_id", + "asrModelId": "asr_model_id", + "embeddingModelId": "embedding_model_id", + "rerankModelId": "rerank_model_id", + } for field, value in update_data.items(): - setattr(voice, field, value) - - db.commit() - db.refresh(voice) - return voice - - -@router.delete("/voices/{id}") -def delete_voice(id: str, db: Session = Depends(get_db)): - """删除声音""" - voice = db.query(Voice).filter(Voice.id == id).first() - if not voice: - raise HTTPException(status_code=404, detail="Voice not found") - db.delete(voice) - db.commit() - return {"message": "Deleted successfully"} + setattr(assistant, field_map.get(field, field), value) # ============ Assistants ============ @@ -112,7 +68,12 @@ def list_assistants( total = query.count() assistants = query.order_by(Assistant.created_at.desc()) \ .offset((page-1)*limit).limit(limit).all() - return {"total": total, "page": page, "limit": limit, "list": assistants} + return { + "total": total, + "page": page, + "limit": limit, + "list": [assistant_to_dict(a) for a in assistants] + } @router.get("/assistants/{id}", response_model=AssistantOut) @@ -121,7 +82,7 @@ def get_assistant(id: str, db: Session = Depends(get_db)): assistant = db.query(Assistant).filter(Assistant.id == id).first() if not assistant: raise HTTPException(status_code=404, detail="Assistant not found") - return assistant + return assistant_to_dict(assistant) @router.post("/assistants", response_model=AssistantOut) @@ -143,11 +104,15 @@ def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)): config_mode=data.configMode, api_url=data.apiUrl, api_key=data.apiKey, + llm_model_id=data.llmModelId, + asr_model_id=data.asrModelId, + embedding_model_id=data.embeddingModelId, + rerank_model_id=data.rerankModelId, ) db.add(assistant) db.commit() db.refresh(assistant) - return assistant + return assistant_to_dict(assistant) @router.put("/assistants/{id}") @@ -158,13 +123,12 @@ def update_assistant(id: str, data: AssistantUpdate, db: Session = Depends(get_d raise HTTPException(status_code=404, detail="Assistant not found") update_data = data.model_dump(exclude_unset=True) - for field, value in update_data.items(): - setattr(assistant, field, value) + _apply_assistant_update(assistant, update_data) assistant.updated_at = datetime.utcnow() db.commit() db.refresh(assistant) - return assistant + return assistant_to_dict(assistant) @router.delete("/assistants/{id}") diff --git a/api/app/routers/history.py b/api/app/routers/history.py index 9434541..6c75151 100644 --- a/api/app/routers/history.py +++ b/api/app/routers/history.py @@ -7,14 +7,32 @@ from datetime import datetime from ..db import get_db from ..models import CallRecord, CallTranscript, CallAudioSegment from ..storage import get_audio_url +from ..schemas import CallRecordCreate, CallRecordUpdate, TranscriptCreate router = APIRouter(prefix="/history", tags=["history"]) +def record_to_dict(record: CallRecord) -> dict: + return { + "id": record.id, + "user_id": record.user_id, + "assistant_id": record.assistant_id, + "source": record.source, + "status": record.status, + "started_at": record.started_at, + "ended_at": record.ended_at, + "duration_seconds": record.duration_seconds, + "summary": record.summary, + "cost": record.cost, + "created_at": record.created_at, + } + + @router.get("") def list_history( assistant_id: Optional[str] = None, status: Optional[str] = None, + source: Optional[str] = None, page: int = 1, limit: int = 20, db: Session = Depends(get_db) @@ -26,12 +44,19 @@ def list_history( query = query.filter(CallRecord.assistant_id == assistant_id) if status: query = query.filter(CallRecord.status == status) + if source: + query = query.filter(CallRecord.source == source) total = query.count() records = query.order_by(CallRecord.started_at.desc()) \ .offset((page-1)*limit).limit(limit).all() - return {"total": total, "page": page, "limit": limit, "list": records} + return { + "total": total, + "page": page, + "limit": limit, + "list": [record_to_dict(r) for r in records] + } @router.get("/{call_id}") @@ -46,10 +71,12 @@ def get_history_detail(call_id: str, db: Session = Depends(get_db)): .filter(CallTranscript.call_id == call_id) \ .order_by(CallTranscript.turn_index).all() - # 补充音频 URL + audio_segments = db.query(CallAudioSegment).filter(CallAudioSegment.call_id == call_id).all() + audio_by_turn = {seg.turn_index: seg.audio_url for seg in audio_segments if seg.turn_index is not None} + transcript_list = [] for t in transcripts: - audio_url = t.audio_url or get_audio_url(call_id, t.turn_index) + audio_url = audio_by_turn.get(t.turn_index) or get_audio_url(call_id, t.turn_index) transcript_list.append({ "turnIndex": t.turn_index, "speaker": t.speaker, @@ -77,32 +104,29 @@ def get_history_detail(call_id: str, db: Session = Depends(get_db)): @router.post("") def create_call_record( - user_id: int, - assistant_id: Optional[str] = None, - source: str = "debug", + data: CallRecordCreate, db: Session = Depends(get_db) ): """创建通话记录(引擎回调使用)""" record = CallRecord( id=str(uuid.uuid4())[:8], - user_id=user_id, - assistant_id=assistant_id, - source=source, - status="connected", + user_id=data.user_id, + assistant_id=data.assistant_id, + source=data.source, + status=data.status or "connected", started_at=datetime.utcnow().isoformat(), + cost=data.cost or 0.0, ) db.add(record) db.commit() db.refresh(record) - return record + return record_to_dict(record) @router.put("/{call_id}") def update_call_record( call_id: str, - status: Optional[str] = None, - summary: Optional[str] = None, - duration_seconds: Optional[int] = None, + data: CallRecordUpdate, db: Session = Depends(get_db) ): """更新通话记录""" @@ -110,59 +134,64 @@ def update_call_record( if not record: raise HTTPException(status_code=404, detail="Call record not found") - if status: - record.status = status - if summary: - record.summary = summary - if duration_seconds: - record.duration_seconds = duration_seconds + if data.status is not None: + record.status = data.status + if data.summary is not None: + record.summary = data.summary + if data.duration_seconds is not None: + record.duration_seconds = data.duration_seconds record.ended_at = datetime.utcnow().isoformat() + if data.ended_at is not None: + record.ended_at = data.ended_at + if data.cost is not None: + record.cost = data.cost + if data.metadata is not None: + record.call_metadata = data.metadata db.commit() - return {"message": "Updated successfully"} + db.refresh(record) + return record_to_dict(record) @router.post("/{call_id}/transcripts") def add_transcript( call_id: str, - turn_index: int, - speaker: str, - content: str, - start_ms: int, - end_ms: int, - confidence: Optional[float] = None, - duration_ms: Optional[int] = None, - emotion: Optional[str] = None, + data: TranscriptCreate, db: Session = Depends(get_db) ): """添加转写片段""" + record = db.query(CallRecord).filter(CallRecord.id == call_id).first() + if not record: + raise HTTPException(status_code=404, detail="Call record not found") + transcript = CallTranscript( call_id=call_id, - turn_index=turn_index, - speaker=speaker, - content=content, - confidence=confidence, - start_ms=start_ms, - end_ms=end_ms, - duration_ms=duration_ms, - emotion=emotion, + turn_index=data.turn_index, + speaker=data.speaker, + content=data.content, + confidence=data.confidence, + start_ms=data.start_ms, + end_ms=data.end_ms, + duration_ms=data.duration_ms, + emotion=data.emotion, ) db.add(transcript) db.commit() db.refresh(transcript) # 补充音频 URL - audio_url = get_audio_url(call_id, turn_index) + audio_url = get_audio_url(call_id, data.turn_index) return { "id": transcript.id, - "turn_index": turn_index, - "speaker": speaker, - "content": content, - "confidence": confidence, - "start_ms": start_ms, - "end_ms": end_ms, - "duration_ms": duration_ms, + "turn_index": data.turn_index, + "speaker": data.speaker, + "content": data.content, + "confidence": data.confidence, + "start_ms": data.start_ms, + "end_ms": data.end_ms, + "duration_ms": data.duration_ms, + "emotion": data.emotion, "audio_url": audio_url, } diff --git a/api/app/routers/knowledge.py b/api/app/routers/knowledge.py index 2d778fe..df22b8f 100644 --- a/api/app/routers/knowledge.py +++ b/api/app/routers/knowledge.py @@ -11,6 +11,7 @@ from ..schemas import ( KnowledgeBaseCreate, KnowledgeBaseUpdate, KnowledgeBaseOut, KnowledgeSearchQuery, KnowledgeSearchResult, KnowledgeStats, DocumentIndexRequest, + KnowledgeDocumentCreate, ) from ..vector_store import ( vector_store, search_knowledge, index_document, delete_document_from_vector @@ -25,14 +26,14 @@ def kb_to_dict(kb: KnowledgeBase) -> dict: "user_id": kb.user_id, "name": kb.name, "description": kb.description, - "embedding_model": kb.embedding_model, - "chunk_size": kb.chunk_size, - "chunk_overlap": kb.chunk_overlap, - "doc_count": kb.doc_count, - "chunk_count": kb.chunk_count, + "embeddingModel": kb.embedding_model, + "chunkSize": kb.chunk_size, + "chunkOverlap": kb.chunk_overlap, + "docCount": kb.doc_count, + "chunkCount": kb.chunk_count, "status": kb.status, - "created_at": kb.created_at.isoformat() if kb.created_at else None, - "updated_at": kb.updated_at.isoformat() if kb.updated_at else None, + "createdAt": kb.created_at.isoformat() if kb.created_at else None, + "updatedAt": kb.updated_at.isoformat() if kb.updated_at else None, } @@ -42,28 +43,35 @@ def doc_to_dict(d: KnowledgeDocument) -> dict: "kb_id": d.kb_id, "name": d.name, "size": d.size, - "file_type": d.file_type, - "storage_url": d.storage_url, + "fileType": d.file_type, + "storageUrl": d.storage_url, "status": d.status, - "chunk_count": d.chunk_count, - "error_message": d.error_message, - "upload_date": d.upload_date, - "created_at": d.created_at.isoformat() if d.created_at else None, - "processed_at": d.processed_at.isoformat() if d.processed_at else None, + "chunkCount": d.chunk_count, + "errorMessage": d.error_message, + "uploadDate": d.upload_date, + "createdAt": d.created_at.isoformat() if d.created_at else None, + "processedAt": d.processed_at.isoformat() if d.processed_at else None, } # ============ Knowledge Bases ============ @router.get("/bases") -def list_knowledge_bases(user_id: int = 1, db: Session = Depends(get_db)): - kbs = db.query(KnowledgeBase).filter(KnowledgeBase.user_id == user_id).all() +def list_knowledge_bases( + user_id: int = 1, + page: int = 1, + limit: int = 50, + db: Session = Depends(get_db) +): + query = db.query(KnowledgeBase).filter(KnowledgeBase.user_id == user_id) + total = query.count() + kbs = query.order_by(KnowledgeBase.created_at.desc()).offset((page - 1) * limit).limit(limit).all() result = [] for kb in kbs: docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb.id).all() kb_data = kb_to_dict(kb) kb_data["documents"] = [doc_to_dict(d) for d in docs] result.append(kb_data) - return {"total": len(result), "list": result} + return {"total": total, "page": page, "limit": limit, "list": result} @router.get("/bases/{kb_id}") @@ -91,7 +99,10 @@ def create_knowledge_base(data: KnowledgeBaseCreate, user_id: int = 1, db: Sessi db.add(kb) db.commit() db.refresh(kb) - vector_store.create_collection(kb.id, data.embeddingModel) + try: + vector_store.create_collection(kb.id, data.embeddingModel) + except Exception: + pass return kb_to_dict(kb) @@ -101,8 +112,13 @@ def update_knowledge_base(kb_id: str, data: KnowledgeBaseUpdate, db: Session = D if not kb: raise HTTPException(status_code=404, detail="Knowledge base not found") update_data = data.model_dump(exclude_unset=True) + field_map = { + "embeddingModel": "embedding_model", + "chunkSize": "chunk_size", + "chunkOverlap": "chunk_overlap", + } for field, value in update_data.items(): - setattr(kb, field, value) + setattr(kb, field_map.get(field, field), value) kb.updated_at = datetime.utcnow() db.commit() db.refresh(kb) @@ -114,7 +130,10 @@ def delete_knowledge_base(kb_id: str, db: Session = Depends(get_db)): kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first() if not kb: raise HTTPException(status_code=404, detail="Knowledge base not found") - vector_store.delete_collection(kb_id) + try: + vector_store.delete_collection(kb_id) + except Exception: + pass docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb_id).all() for doc in docs: db.delete(doc) @@ -127,10 +146,7 @@ def delete_knowledge_base(kb_id: str, db: Session = Depends(get_db)): @router.post("/bases/{kb_id}/documents") def upload_document( kb_id: str, - name: str = Query(...), - size: str = Query(...), - file_type: str = Query("txt"), - storage_url: Optional[str] = Query(None), + data: KnowledgeDocumentCreate, db: Session = Depends(get_db) ): kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first() @@ -139,17 +155,25 @@ def upload_document( doc = KnowledgeDocument( id=str(uuid.uuid4())[:8], kb_id=kb_id, - name=name, - size=size, - file_type=file_type, - storage_url=storage_url, + name=data.name, + size=data.size, + file_type=data.fileType, + storage_url=data.storageUrl, status="pending", upload_date=datetime.utcnow().isoformat() ) db.add(doc) db.commit() db.refresh(doc) - return {"id": doc.id, "name": doc.name, "status": doc.status, "message": "Document created"} + return { + "id": doc.id, + "name": doc.name, + "size": doc.size, + "fileType": doc.file_type, + "storageUrl": doc.storage_url, + "status": doc.status, + "message": "Document created", + } @router.post("/bases/{kb_id}/documents/{doc_id}/index") @@ -212,8 +236,9 @@ def delete_document(kb_id: str, doc_id: str, db: Session = Depends(get_db)): except Exception: pass kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first() - kb.chunk_count -= doc.chunk_count - kb.doc_count -= 1 + if kb: + kb.chunk_count = max(0, kb.chunk_count - (doc.chunk_count or 0)) + kb.doc_count = max(0, kb.doc_count - 1) db.delete(doc) db.commit() return {"message": "Deleted successfully"} diff --git a/api/app/routers/llm.py b/api/app/routers/llm.py index 71c854b..6292eed 100644 --- a/api/app/routers/llm.py +++ b/api/app/routers/llm.py @@ -10,14 +10,14 @@ from ..db import get_db from ..models import LLMModel from ..schemas import ( LLMModelCreate, LLMModelUpdate, LLMModelOut, - LLMModelTestResponse, ListResponse + LLMModelTestResponse ) router = APIRouter(prefix="/llm", tags=["LLM Models"]) # ============ LLM Models CRUD ============ -@router.get("", response_model=ListResponse) +@router.get("") def list_llm_models( model_type: Optional[str] = None, enabled: Optional[bool] = None, diff --git a/api/app/routers/voices.py b/api/app/routers/voices.py new file mode 100644 index 0000000..6fb1afd --- /dev/null +++ b/api/app/routers/voices.py @@ -0,0 +1,94 @@ +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.orm import Session +from typing import Optional +import uuid + +from ..db import get_db +from ..models import Voice +from ..schemas import VoiceCreate, VoiceUpdate, VoiceOut + +router = APIRouter() + + +@router.get("/voices") +def list_voices( + vendor: Optional[str] = None, + language: Optional[str] = None, + gender: Optional[str] = None, + page: int = 1, + limit: int = 50, + db: Session = Depends(get_db) +): + """获取声音库列表""" + query = db.query(Voice) + if vendor: + query = query.filter(Voice.vendor == vendor) + if language: + query = query.filter(Voice.language == language) + if gender: + query = query.filter(Voice.gender == gender) + + total = query.count() + voices = query.order_by(Voice.created_at.desc()) \ + .offset((page - 1) * limit).limit(limit).all() + return {"total": total, "page": page, "limit": limit, "list": voices} + + +@router.post("/voices", response_model=VoiceOut) +def create_voice(data: VoiceCreate, db: Session = Depends(get_db)): + """创建声音""" + voice = Voice( + id=data.id or str(uuid.uuid4())[:8], + user_id=1, + name=data.name, + vendor=data.vendor, + gender=data.gender, + language=data.language, + description=data.description, + model=data.model, + voice_key=data.voice_key, + speed=data.speed, + gain=data.gain, + pitch=data.pitch, + enabled=data.enabled, + ) + db.add(voice) + db.commit() + db.refresh(voice) + return voice + + +@router.get("/voices/{id}", response_model=VoiceOut) +def get_voice(id: str, db: Session = Depends(get_db)): + """获取单个声音详情""" + voice = db.query(Voice).filter(Voice.id == id).first() + if not voice: + raise HTTPException(status_code=404, detail="Voice not found") + return voice + + +@router.put("/voices/{id}", response_model=VoiceOut) +def update_voice(id: str, data: VoiceUpdate, db: Session = Depends(get_db)): + """更新声音""" + voice = db.query(Voice).filter(Voice.id == id).first() + if not voice: + raise HTTPException(status_code=404, detail="Voice not found") + + update_data = data.model_dump(exclude_unset=True) + for field, value in update_data.items(): + setattr(voice, field, value) + + db.commit() + db.refresh(voice) + return voice + + +@router.delete("/voices/{id}") +def delete_voice(id: str, db: Session = Depends(get_db)): + """删除声音""" + voice = db.query(Voice).filter(Voice.id == id).first() + if not voice: + raise HTTPException(status_code=404, detail="Voice not found") + db.delete(voice) + db.commit() + return {"message": "Deleted successfully"} diff --git a/api/app/schemas.py b/api/app/schemas.py index 80cb177..9e096fb 100644 --- a/api/app/schemas.py +++ b/api/app/schemas.py @@ -50,8 +50,9 @@ class VoiceBase(BaseModel): class VoiceCreate(VoiceBase): - model: str # 厂商语音模型标识 - voice_key: str # 厂商voice_key + id: Optional[str] = None + model: Optional[str] = None # 厂商语音模型标识 + voice_key: Optional[str] = None # 厂商voice_key speed: float = 1.0 gain: int = 0 pitch: int = 0 @@ -113,7 +114,7 @@ class LLMModelBase(BaseModel): class LLMModelCreate(LLMModelBase): - pass + id: Optional[str] = None class LLMModelUpdate(BaseModel): @@ -154,6 +155,7 @@ class ASRModelBase(BaseModel): class ASRModelCreate(ASRModelBase): + id: Optional[str] = None hotwords: List[str] = [] enable_punctuation: bool = True enable_normalization: bool = True @@ -195,6 +197,7 @@ class ASRTestResponse(BaseModel): confidence: Optional[float] = None duration_ms: Optional[int] = None latency_ms: Optional[int] = None + message: Optional[str] = None error: Optional[str] = None @@ -413,6 +416,8 @@ class CallRecordCreate(BaseModel): user_id: int assistant_id: Optional[str] = None source: str = "debug" + status: Optional[str] = None + cost: Optional[float] = None class CallRecordUpdate(BaseModel):