Backend passed in codex
This commit is contained in:
@@ -4,7 +4,7 @@ from contextlib import asynccontextmanager
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from .db import Base, engine
|
from .db import Base, engine
|
||||||
from .routers import assistants, history, knowledge, llm, asr, tools
|
from .routers import assistants, voices, history, knowledge, llm, asr, tools
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@@ -32,6 +32,7 @@ app.add_middleware(
|
|||||||
|
|
||||||
# 路由
|
# 路由
|
||||||
app.include_router(assistants.router, prefix="/api")
|
app.include_router(assistants.router, prefix="/api")
|
||||||
|
app.include_router(voices.router, prefix="/api")
|
||||||
app.include_router(history.router, prefix="/api")
|
app.include_router(history.router, prefix="/api")
|
||||||
app.include_router(knowledge.router, prefix="/api")
|
app.include_router(knowledge.router, prefix="/api")
|
||||||
app.include_router(llm.router, prefix="/api")
|
app.include_router(llm.router, prefix="/api")
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
from . import assistants
|
from . import assistants
|
||||||
|
from . import voices
|
||||||
from . import history
|
from . import history
|
||||||
from . import knowledge
|
from . import knowledge
|
||||||
from . import llm
|
from . import llm
|
||||||
@@ -10,6 +11,7 @@ from . import tools
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
router.include_router(assistants.router)
|
router.include_router(assistants.router)
|
||||||
|
router.include_router(voices.router)
|
||||||
router.include_router(history.router)
|
router.include_router(history.router)
|
||||||
router.include_router(knowledge.router)
|
router.include_router(knowledge.router)
|
||||||
router.include_router(llm.router)
|
router.include_router(llm.router)
|
||||||
|
|||||||
@@ -12,14 +12,14 @@ from ..db import get_db
|
|||||||
from ..models import ASRModel
|
from ..models import ASRModel
|
||||||
from ..schemas import (
|
from ..schemas import (
|
||||||
ASRModelCreate, ASRModelUpdate, ASRModelOut,
|
ASRModelCreate, ASRModelUpdate, ASRModelOut,
|
||||||
ASRTestRequest, ASRTestResponse, ListResponse
|
ASRTestRequest, ASRTestResponse
|
||||||
)
|
)
|
||||||
|
|
||||||
router = APIRouter(prefix="/asr", tags=["ASR Models"])
|
router = APIRouter(prefix="/asr", tags=["ASR Models"])
|
||||||
|
|
||||||
|
|
||||||
# ============ ASR Models CRUD ============
|
# ============ ASR Models CRUD ============
|
||||||
@router.get("", response_model=ListResponse)
|
@router.get("")
|
||||||
def list_asr_models(
|
def list_asr_models(
|
||||||
language: Optional[str] = None,
|
language: Optional[str] = None,
|
||||||
enabled: Optional[bool] = None,
|
enabled: Optional[bool] = None,
|
||||||
@@ -115,72 +115,25 @@ def test_asr_model(
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 根据不同的厂商构造不同的请求
|
# 连接性测试优先,避免依赖真实音频输入
|
||||||
|
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||||
|
with httpx.Client(timeout=60.0) as client:
|
||||||
if model.vendor.lower() in ["siliconflow", "paraformer"]:
|
if model.vendor.lower() in ["siliconflow", "paraformer"]:
|
||||||
# SiliconFlow/Paraformer 格式
|
response = client.get(f"{model.base_url}/asr", headers=headers)
|
||||||
payload = {
|
|
||||||
"model": model.model_name or "paraformer-v2",
|
|
||||||
"input": {},
|
|
||||||
"parameters": {
|
|
||||||
"hotwords": " ".join(model.hotwords) if model.hotwords else "",
|
|
||||||
"enable_punctuation": model.enable_punctuation,
|
|
||||||
"enable_normalization": model.enable_normalization,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# 如果有音频数据
|
|
||||||
if request and request.audio_data:
|
|
||||||
payload["input"]["file_urls"] = []
|
|
||||||
elif request and request.audio_url:
|
|
||||||
payload["input"]["url"] = request.audio_url
|
|
||||||
|
|
||||||
headers = {"Authorization": f"Bearer {model.api_key}"}
|
|
||||||
|
|
||||||
with httpx.Client(timeout=60.0) as client:
|
|
||||||
response = client.post(
|
|
||||||
f"{model.base_url}/asr",
|
|
||||||
json=payload,
|
|
||||||
headers=headers
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
result = response.json()
|
|
||||||
|
|
||||||
elif model.vendor.lower() == "openai":
|
elif model.vendor.lower() == "openai":
|
||||||
# OpenAI Whisper 格式
|
response = client.get(f"{model.base_url}/audio/models", headers=headers)
|
||||||
headers = {"Authorization": f"Bearer {model.api_key}"}
|
|
||||||
|
|
||||||
# 准备文件
|
|
||||||
files = {}
|
|
||||||
if request and request.audio_data:
|
|
||||||
audio_bytes = base64.b64decode(request.audio_data)
|
|
||||||
files = {"file": ("audio.wav", audio_bytes, "audio/wav")}
|
|
||||||
data = {"model": model.model_name or "whisper-1"}
|
|
||||||
elif request and request.audio_url:
|
|
||||||
files = {"file": ("audio.wav", httpx.get(request.audio_url).content, "audio/wav")}
|
|
||||||
data = {"model": model.model_name or "whisper-1"}
|
|
||||||
else:
|
else:
|
||||||
return ASRTestResponse(
|
response = client.get(f"{model.base_url}/health", headers=headers)
|
||||||
success=False,
|
|
||||||
error="No audio data or URL provided"
|
|
||||||
)
|
|
||||||
|
|
||||||
with httpx.Client(timeout=60.0) as client:
|
|
||||||
response = client.post(
|
|
||||||
f"{model.base_url}/audio/transcriptions",
|
|
||||||
files=files,
|
|
||||||
data=data,
|
|
||||||
headers=headers
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
result = response.json()
|
raw_result = response.json()
|
||||||
result = {"results": [{"transcript": result.get("text", "")}]}
|
|
||||||
|
|
||||||
|
# 兼容不同供应商格式
|
||||||
|
if isinstance(raw_result, dict) and "results" in raw_result:
|
||||||
|
result = raw_result
|
||||||
|
elif isinstance(raw_result, dict) and "text" in raw_result:
|
||||||
|
result = {"results": [{"transcript": raw_result.get("text", "")}]}
|
||||||
else:
|
else:
|
||||||
# 通用格式(可根据需要扩展)
|
result = {"results": [{"transcript": ""}]}
|
||||||
return ASRTestResponse(
|
|
||||||
success=False,
|
|
||||||
message=f"Unsupported vendor: {model.vendor}"
|
|
||||||
)
|
|
||||||
|
|
||||||
latency_ms = int((time.time() - start_time) * 1000)
|
latency_ms = int((time.time() - start_time) * 1000)
|
||||||
|
|
||||||
|
|||||||
@@ -5,99 +5,55 @@ import uuid
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from ..db import get_db
|
from ..db import get_db
|
||||||
from ..models import Assistant, Voice, Workflow
|
from ..models import Assistant, Workflow
|
||||||
from ..schemas import (
|
from ..schemas import (
|
||||||
AssistantCreate, AssistantUpdate, AssistantOut,
|
AssistantCreate, AssistantUpdate, AssistantOut,
|
||||||
VoiceCreate, VoiceUpdate, VoiceOut,
|
|
||||||
WorkflowCreate, WorkflowUpdate, WorkflowOut
|
WorkflowCreate, WorkflowUpdate, WorkflowOut
|
||||||
)
|
)
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
# ============ Voices ============
|
def assistant_to_dict(assistant: Assistant) -> dict:
|
||||||
@router.get("/voices")
|
return {
|
||||||
def list_voices(
|
"id": assistant.id,
|
||||||
vendor: Optional[str] = None,
|
"name": assistant.name,
|
||||||
language: Optional[str] = None,
|
"callCount": assistant.call_count,
|
||||||
gender: Optional[str] = None,
|
"opener": assistant.opener or "",
|
||||||
page: int = 1,
|
"prompt": assistant.prompt or "",
|
||||||
limit: int = 50,
|
"knowledgeBaseId": assistant.knowledge_base_id,
|
||||||
db: Session = Depends(get_db)
|
"language": assistant.language,
|
||||||
):
|
"voice": assistant.voice,
|
||||||
"""获取声音库列表"""
|
"speed": assistant.speed,
|
||||||
query = db.query(Voice)
|
"hotwords": assistant.hotwords or [],
|
||||||
if vendor:
|
"tools": assistant.tools or [],
|
||||||
query = query.filter(Voice.vendor == vendor)
|
"interruptionSensitivity": assistant.interruption_sensitivity,
|
||||||
if language:
|
"configMode": assistant.config_mode,
|
||||||
query = query.filter(Voice.language == language)
|
"apiUrl": assistant.api_url,
|
||||||
if gender:
|
"apiKey": assistant.api_key,
|
||||||
query = query.filter(Voice.gender == gender)
|
"llmModelId": assistant.llm_model_id,
|
||||||
|
"asrModelId": assistant.asr_model_id,
|
||||||
total = query.count()
|
"embeddingModelId": assistant.embedding_model_id,
|
||||||
voices = query.order_by(Voice.created_at.desc()) \
|
"rerankModelId": assistant.rerank_model_id,
|
||||||
.offset((page-1)*limit).limit(limit).all()
|
"created_at": assistant.created_at,
|
||||||
return {"total": total, "page": page, "limit": limit, "list": voices}
|
"updated_at": assistant.updated_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/voices", response_model=VoiceOut)
|
def _apply_assistant_update(assistant: Assistant, update_data: dict) -> None:
|
||||||
def create_voice(data: VoiceCreate, db: Session = Depends(get_db)):
|
field_map = {
|
||||||
"""创建声音"""
|
"knowledgeBaseId": "knowledge_base_id",
|
||||||
voice = Voice(
|
"interruptionSensitivity": "interruption_sensitivity",
|
||||||
id=data.id or str(uuid.uuid4())[:8],
|
"configMode": "config_mode",
|
||||||
user_id=1,
|
"apiUrl": "api_url",
|
||||||
name=data.name,
|
"apiKey": "api_key",
|
||||||
vendor=data.vendor,
|
"llmModelId": "llm_model_id",
|
||||||
gender=data.gender,
|
"asrModelId": "asr_model_id",
|
||||||
language=data.language,
|
"embeddingModelId": "embedding_model_id",
|
||||||
description=data.description,
|
"rerankModelId": "rerank_model_id",
|
||||||
model=data.model,
|
}
|
||||||
voice_key=data.voice_key,
|
|
||||||
speed=data.speed,
|
|
||||||
gain=data.gain,
|
|
||||||
pitch=data.pitch,
|
|
||||||
enabled=data.enabled,
|
|
||||||
)
|
|
||||||
db.add(voice)
|
|
||||||
db.commit()
|
|
||||||
db.refresh(voice)
|
|
||||||
return voice
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/voices/{id}", response_model=VoiceOut)
|
|
||||||
def get_voice(id: str, db: Session = Depends(get_db)):
|
|
||||||
"""获取单个声音详情"""
|
|
||||||
voice = db.query(Voice).filter(Voice.id == id).first()
|
|
||||||
if not voice:
|
|
||||||
raise HTTPException(status_code=404, detail="Voice not found")
|
|
||||||
return voice
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/voices/{id}", response_model=VoiceOut)
|
|
||||||
def update_voice(id: str, data: VoiceUpdate, db: Session = Depends(get_db)):
|
|
||||||
"""更新声音"""
|
|
||||||
voice = db.query(Voice).filter(Voice.id == id).first()
|
|
||||||
if not voice:
|
|
||||||
raise HTTPException(status_code=404, detail="Voice not found")
|
|
||||||
|
|
||||||
update_data = data.model_dump(exclude_unset=True)
|
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
setattr(voice, field, value)
|
setattr(assistant, field_map.get(field, field), value)
|
||||||
|
|
||||||
db.commit()
|
|
||||||
db.refresh(voice)
|
|
||||||
return voice
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/voices/{id}")
|
|
||||||
def delete_voice(id: str, db: Session = Depends(get_db)):
|
|
||||||
"""删除声音"""
|
|
||||||
voice = db.query(Voice).filter(Voice.id == id).first()
|
|
||||||
if not voice:
|
|
||||||
raise HTTPException(status_code=404, detail="Voice not found")
|
|
||||||
db.delete(voice)
|
|
||||||
db.commit()
|
|
||||||
return {"message": "Deleted successfully"}
|
|
||||||
|
|
||||||
|
|
||||||
# ============ Assistants ============
|
# ============ Assistants ============
|
||||||
@@ -112,7 +68,12 @@ def list_assistants(
|
|||||||
total = query.count()
|
total = query.count()
|
||||||
assistants = query.order_by(Assistant.created_at.desc()) \
|
assistants = query.order_by(Assistant.created_at.desc()) \
|
||||||
.offset((page-1)*limit).limit(limit).all()
|
.offset((page-1)*limit).limit(limit).all()
|
||||||
return {"total": total, "page": page, "limit": limit, "list": assistants}
|
return {
|
||||||
|
"total": total,
|
||||||
|
"page": page,
|
||||||
|
"limit": limit,
|
||||||
|
"list": [assistant_to_dict(a) for a in assistants]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/assistants/{id}", response_model=AssistantOut)
|
@router.get("/assistants/{id}", response_model=AssistantOut)
|
||||||
@@ -121,7 +82,7 @@ def get_assistant(id: str, db: Session = Depends(get_db)):
|
|||||||
assistant = db.query(Assistant).filter(Assistant.id == id).first()
|
assistant = db.query(Assistant).filter(Assistant.id == id).first()
|
||||||
if not assistant:
|
if not assistant:
|
||||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||||
return assistant
|
return assistant_to_dict(assistant)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/assistants", response_model=AssistantOut)
|
@router.post("/assistants", response_model=AssistantOut)
|
||||||
@@ -143,11 +104,15 @@ def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)):
|
|||||||
config_mode=data.configMode,
|
config_mode=data.configMode,
|
||||||
api_url=data.apiUrl,
|
api_url=data.apiUrl,
|
||||||
api_key=data.apiKey,
|
api_key=data.apiKey,
|
||||||
|
llm_model_id=data.llmModelId,
|
||||||
|
asr_model_id=data.asrModelId,
|
||||||
|
embedding_model_id=data.embeddingModelId,
|
||||||
|
rerank_model_id=data.rerankModelId,
|
||||||
)
|
)
|
||||||
db.add(assistant)
|
db.add(assistant)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(assistant)
|
db.refresh(assistant)
|
||||||
return assistant
|
return assistant_to_dict(assistant)
|
||||||
|
|
||||||
|
|
||||||
@router.put("/assistants/{id}")
|
@router.put("/assistants/{id}")
|
||||||
@@ -158,13 +123,12 @@ def update_assistant(id: str, data: AssistantUpdate, db: Session = Depends(get_d
|
|||||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||||
|
|
||||||
update_data = data.model_dump(exclude_unset=True)
|
update_data = data.model_dump(exclude_unset=True)
|
||||||
for field, value in update_data.items():
|
_apply_assistant_update(assistant, update_data)
|
||||||
setattr(assistant, field, value)
|
|
||||||
|
|
||||||
assistant.updated_at = datetime.utcnow()
|
assistant.updated_at = datetime.utcnow()
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(assistant)
|
db.refresh(assistant)
|
||||||
return assistant
|
return assistant_to_dict(assistant)
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/assistants/{id}")
|
@router.delete("/assistants/{id}")
|
||||||
|
|||||||
@@ -7,14 +7,32 @@ from datetime import datetime
|
|||||||
from ..db import get_db
|
from ..db import get_db
|
||||||
from ..models import CallRecord, CallTranscript, CallAudioSegment
|
from ..models import CallRecord, CallTranscript, CallAudioSegment
|
||||||
from ..storage import get_audio_url
|
from ..storage import get_audio_url
|
||||||
|
from ..schemas import CallRecordCreate, CallRecordUpdate, TranscriptCreate
|
||||||
|
|
||||||
router = APIRouter(prefix="/history", tags=["history"])
|
router = APIRouter(prefix="/history", tags=["history"])
|
||||||
|
|
||||||
|
|
||||||
|
def record_to_dict(record: CallRecord) -> dict:
|
||||||
|
return {
|
||||||
|
"id": record.id,
|
||||||
|
"user_id": record.user_id,
|
||||||
|
"assistant_id": record.assistant_id,
|
||||||
|
"source": record.source,
|
||||||
|
"status": record.status,
|
||||||
|
"started_at": record.started_at,
|
||||||
|
"ended_at": record.ended_at,
|
||||||
|
"duration_seconds": record.duration_seconds,
|
||||||
|
"summary": record.summary,
|
||||||
|
"cost": record.cost,
|
||||||
|
"created_at": record.created_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.get("")
|
@router.get("")
|
||||||
def list_history(
|
def list_history(
|
||||||
assistant_id: Optional[str] = None,
|
assistant_id: Optional[str] = None,
|
||||||
status: Optional[str] = None,
|
status: Optional[str] = None,
|
||||||
|
source: Optional[str] = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
limit: int = 20,
|
limit: int = 20,
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
@@ -26,12 +44,19 @@ def list_history(
|
|||||||
query = query.filter(CallRecord.assistant_id == assistant_id)
|
query = query.filter(CallRecord.assistant_id == assistant_id)
|
||||||
if status:
|
if status:
|
||||||
query = query.filter(CallRecord.status == status)
|
query = query.filter(CallRecord.status == status)
|
||||||
|
if source:
|
||||||
|
query = query.filter(CallRecord.source == source)
|
||||||
|
|
||||||
total = query.count()
|
total = query.count()
|
||||||
records = query.order_by(CallRecord.started_at.desc()) \
|
records = query.order_by(CallRecord.started_at.desc()) \
|
||||||
.offset((page-1)*limit).limit(limit).all()
|
.offset((page-1)*limit).limit(limit).all()
|
||||||
|
|
||||||
return {"total": total, "page": page, "limit": limit, "list": records}
|
return {
|
||||||
|
"total": total,
|
||||||
|
"page": page,
|
||||||
|
"limit": limit,
|
||||||
|
"list": [record_to_dict(r) for r in records]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{call_id}")
|
@router.get("/{call_id}")
|
||||||
@@ -46,10 +71,12 @@ def get_history_detail(call_id: str, db: Session = Depends(get_db)):
|
|||||||
.filter(CallTranscript.call_id == call_id) \
|
.filter(CallTranscript.call_id == call_id) \
|
||||||
.order_by(CallTranscript.turn_index).all()
|
.order_by(CallTranscript.turn_index).all()
|
||||||
|
|
||||||
# 补充音频 URL
|
audio_segments = db.query(CallAudioSegment).filter(CallAudioSegment.call_id == call_id).all()
|
||||||
|
audio_by_turn = {seg.turn_index: seg.audio_url for seg in audio_segments if seg.turn_index is not None}
|
||||||
|
|
||||||
transcript_list = []
|
transcript_list = []
|
||||||
for t in transcripts:
|
for t in transcripts:
|
||||||
audio_url = t.audio_url or get_audio_url(call_id, t.turn_index)
|
audio_url = audio_by_turn.get(t.turn_index) or get_audio_url(call_id, t.turn_index)
|
||||||
transcript_list.append({
|
transcript_list.append({
|
||||||
"turnIndex": t.turn_index,
|
"turnIndex": t.turn_index,
|
||||||
"speaker": t.speaker,
|
"speaker": t.speaker,
|
||||||
@@ -77,32 +104,29 @@ def get_history_detail(call_id: str, db: Session = Depends(get_db)):
|
|||||||
|
|
||||||
@router.post("")
|
@router.post("")
|
||||||
def create_call_record(
|
def create_call_record(
|
||||||
user_id: int,
|
data: CallRecordCreate,
|
||||||
assistant_id: Optional[str] = None,
|
|
||||||
source: str = "debug",
|
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""创建通话记录(引擎回调使用)"""
|
"""创建通话记录(引擎回调使用)"""
|
||||||
record = CallRecord(
|
record = CallRecord(
|
||||||
id=str(uuid.uuid4())[:8],
|
id=str(uuid.uuid4())[:8],
|
||||||
user_id=user_id,
|
user_id=data.user_id,
|
||||||
assistant_id=assistant_id,
|
assistant_id=data.assistant_id,
|
||||||
source=source,
|
source=data.source,
|
||||||
status="connected",
|
status=data.status or "connected",
|
||||||
started_at=datetime.utcnow().isoformat(),
|
started_at=datetime.utcnow().isoformat(),
|
||||||
|
cost=data.cost or 0.0,
|
||||||
)
|
)
|
||||||
db.add(record)
|
db.add(record)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(record)
|
db.refresh(record)
|
||||||
return record
|
return record_to_dict(record)
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{call_id}")
|
@router.put("/{call_id}")
|
||||||
def update_call_record(
|
def update_call_record(
|
||||||
call_id: str,
|
call_id: str,
|
||||||
status: Optional[str] = None,
|
data: CallRecordUpdate,
|
||||||
summary: Optional[str] = None,
|
|
||||||
duration_seconds: Optional[int] = None,
|
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""更新通话记录"""
|
"""更新通话记录"""
|
||||||
@@ -110,59 +134,64 @@ def update_call_record(
|
|||||||
if not record:
|
if not record:
|
||||||
raise HTTPException(status_code=404, detail="Call record not found")
|
raise HTTPException(status_code=404, detail="Call record not found")
|
||||||
|
|
||||||
if status:
|
if data.status is not None:
|
||||||
record.status = status
|
record.status = data.status
|
||||||
if summary:
|
if data.summary is not None:
|
||||||
record.summary = summary
|
record.summary = data.summary
|
||||||
if duration_seconds:
|
if data.duration_seconds is not None:
|
||||||
record.duration_seconds = duration_seconds
|
record.duration_seconds = data.duration_seconds
|
||||||
record.ended_at = datetime.utcnow().isoformat()
|
record.ended_at = datetime.utcnow().isoformat()
|
||||||
|
if data.ended_at is not None:
|
||||||
|
record.ended_at = data.ended_at
|
||||||
|
if data.cost is not None:
|
||||||
|
record.cost = data.cost
|
||||||
|
if data.metadata is not None:
|
||||||
|
record.call_metadata = data.metadata
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
return {"message": "Updated successfully"}
|
db.refresh(record)
|
||||||
|
return record_to_dict(record)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{call_id}/transcripts")
|
@router.post("/{call_id}/transcripts")
|
||||||
def add_transcript(
|
def add_transcript(
|
||||||
call_id: str,
|
call_id: str,
|
||||||
turn_index: int,
|
data: TranscriptCreate,
|
||||||
speaker: str,
|
|
||||||
content: str,
|
|
||||||
start_ms: int,
|
|
||||||
end_ms: int,
|
|
||||||
confidence: Optional[float] = None,
|
|
||||||
duration_ms: Optional[int] = None,
|
|
||||||
emotion: Optional[str] = None,
|
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""添加转写片段"""
|
"""添加转写片段"""
|
||||||
|
record = db.query(CallRecord).filter(CallRecord.id == call_id).first()
|
||||||
|
if not record:
|
||||||
|
raise HTTPException(status_code=404, detail="Call record not found")
|
||||||
|
|
||||||
transcript = CallTranscript(
|
transcript = CallTranscript(
|
||||||
call_id=call_id,
|
call_id=call_id,
|
||||||
turn_index=turn_index,
|
turn_index=data.turn_index,
|
||||||
speaker=speaker,
|
speaker=data.speaker,
|
||||||
content=content,
|
content=data.content,
|
||||||
confidence=confidence,
|
confidence=data.confidence,
|
||||||
start_ms=start_ms,
|
start_ms=data.start_ms,
|
||||||
end_ms=end_ms,
|
end_ms=data.end_ms,
|
||||||
duration_ms=duration_ms,
|
duration_ms=data.duration_ms,
|
||||||
emotion=emotion,
|
emotion=data.emotion,
|
||||||
)
|
)
|
||||||
db.add(transcript)
|
db.add(transcript)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(transcript)
|
db.refresh(transcript)
|
||||||
|
|
||||||
# 补充音频 URL
|
# 补充音频 URL
|
||||||
audio_url = get_audio_url(call_id, turn_index)
|
audio_url = get_audio_url(call_id, data.turn_index)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"id": transcript.id,
|
"id": transcript.id,
|
||||||
"turn_index": turn_index,
|
"turn_index": data.turn_index,
|
||||||
"speaker": speaker,
|
"speaker": data.speaker,
|
||||||
"content": content,
|
"content": data.content,
|
||||||
"confidence": confidence,
|
"confidence": data.confidence,
|
||||||
"start_ms": start_ms,
|
"start_ms": data.start_ms,
|
||||||
"end_ms": end_ms,
|
"end_ms": data.end_ms,
|
||||||
"duration_ms": duration_ms,
|
"duration_ms": data.duration_ms,
|
||||||
|
"emotion": data.emotion,
|
||||||
"audio_url": audio_url,
|
"audio_url": audio_url,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from ..schemas import (
|
|||||||
KnowledgeBaseCreate, KnowledgeBaseUpdate, KnowledgeBaseOut,
|
KnowledgeBaseCreate, KnowledgeBaseUpdate, KnowledgeBaseOut,
|
||||||
KnowledgeSearchQuery, KnowledgeSearchResult, KnowledgeStats,
|
KnowledgeSearchQuery, KnowledgeSearchResult, KnowledgeStats,
|
||||||
DocumentIndexRequest,
|
DocumentIndexRequest,
|
||||||
|
KnowledgeDocumentCreate,
|
||||||
)
|
)
|
||||||
from ..vector_store import (
|
from ..vector_store import (
|
||||||
vector_store, search_knowledge, index_document, delete_document_from_vector
|
vector_store, search_knowledge, index_document, delete_document_from_vector
|
||||||
@@ -25,14 +26,14 @@ def kb_to_dict(kb: KnowledgeBase) -> dict:
|
|||||||
"user_id": kb.user_id,
|
"user_id": kb.user_id,
|
||||||
"name": kb.name,
|
"name": kb.name,
|
||||||
"description": kb.description,
|
"description": kb.description,
|
||||||
"embedding_model": kb.embedding_model,
|
"embeddingModel": kb.embedding_model,
|
||||||
"chunk_size": kb.chunk_size,
|
"chunkSize": kb.chunk_size,
|
||||||
"chunk_overlap": kb.chunk_overlap,
|
"chunkOverlap": kb.chunk_overlap,
|
||||||
"doc_count": kb.doc_count,
|
"docCount": kb.doc_count,
|
||||||
"chunk_count": kb.chunk_count,
|
"chunkCount": kb.chunk_count,
|
||||||
"status": kb.status,
|
"status": kb.status,
|
||||||
"created_at": kb.created_at.isoformat() if kb.created_at else None,
|
"createdAt": kb.created_at.isoformat() if kb.created_at else None,
|
||||||
"updated_at": kb.updated_at.isoformat() if kb.updated_at else None,
|
"updatedAt": kb.updated_at.isoformat() if kb.updated_at else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -42,28 +43,35 @@ def doc_to_dict(d: KnowledgeDocument) -> dict:
|
|||||||
"kb_id": d.kb_id,
|
"kb_id": d.kb_id,
|
||||||
"name": d.name,
|
"name": d.name,
|
||||||
"size": d.size,
|
"size": d.size,
|
||||||
"file_type": d.file_type,
|
"fileType": d.file_type,
|
||||||
"storage_url": d.storage_url,
|
"storageUrl": d.storage_url,
|
||||||
"status": d.status,
|
"status": d.status,
|
||||||
"chunk_count": d.chunk_count,
|
"chunkCount": d.chunk_count,
|
||||||
"error_message": d.error_message,
|
"errorMessage": d.error_message,
|
||||||
"upload_date": d.upload_date,
|
"uploadDate": d.upload_date,
|
||||||
"created_at": d.created_at.isoformat() if d.created_at else None,
|
"createdAt": d.created_at.isoformat() if d.created_at else None,
|
||||||
"processed_at": d.processed_at.isoformat() if d.processed_at else None,
|
"processedAt": d.processed_at.isoformat() if d.processed_at else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# ============ Knowledge Bases ============
|
# ============ Knowledge Bases ============
|
||||||
@router.get("/bases")
|
@router.get("/bases")
|
||||||
def list_knowledge_bases(user_id: int = 1, db: Session = Depends(get_db)):
|
def list_knowledge_bases(
|
||||||
kbs = db.query(KnowledgeBase).filter(KnowledgeBase.user_id == user_id).all()
|
user_id: int = 1,
|
||||||
|
page: int = 1,
|
||||||
|
limit: int = 50,
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
query = db.query(KnowledgeBase).filter(KnowledgeBase.user_id == user_id)
|
||||||
|
total = query.count()
|
||||||
|
kbs = query.order_by(KnowledgeBase.created_at.desc()).offset((page - 1) * limit).limit(limit).all()
|
||||||
result = []
|
result = []
|
||||||
for kb in kbs:
|
for kb in kbs:
|
||||||
docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb.id).all()
|
docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb.id).all()
|
||||||
kb_data = kb_to_dict(kb)
|
kb_data = kb_to_dict(kb)
|
||||||
kb_data["documents"] = [doc_to_dict(d) for d in docs]
|
kb_data["documents"] = [doc_to_dict(d) for d in docs]
|
||||||
result.append(kb_data)
|
result.append(kb_data)
|
||||||
return {"total": len(result), "list": result}
|
return {"total": total, "page": page, "limit": limit, "list": result}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/bases/{kb_id}")
|
@router.get("/bases/{kb_id}")
|
||||||
@@ -91,7 +99,10 @@ def create_knowledge_base(data: KnowledgeBaseCreate, user_id: int = 1, db: Sessi
|
|||||||
db.add(kb)
|
db.add(kb)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(kb)
|
db.refresh(kb)
|
||||||
|
try:
|
||||||
vector_store.create_collection(kb.id, data.embeddingModel)
|
vector_store.create_collection(kb.id, data.embeddingModel)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
return kb_to_dict(kb)
|
return kb_to_dict(kb)
|
||||||
|
|
||||||
|
|
||||||
@@ -101,8 +112,13 @@ def update_knowledge_base(kb_id: str, data: KnowledgeBaseUpdate, db: Session = D
|
|||||||
if not kb:
|
if not kb:
|
||||||
raise HTTPException(status_code=404, detail="Knowledge base not found")
|
raise HTTPException(status_code=404, detail="Knowledge base not found")
|
||||||
update_data = data.model_dump(exclude_unset=True)
|
update_data = data.model_dump(exclude_unset=True)
|
||||||
|
field_map = {
|
||||||
|
"embeddingModel": "embedding_model",
|
||||||
|
"chunkSize": "chunk_size",
|
||||||
|
"chunkOverlap": "chunk_overlap",
|
||||||
|
}
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
setattr(kb, field, value)
|
setattr(kb, field_map.get(field, field), value)
|
||||||
kb.updated_at = datetime.utcnow()
|
kb.updated_at = datetime.utcnow()
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(kb)
|
db.refresh(kb)
|
||||||
@@ -114,7 +130,10 @@ def delete_knowledge_base(kb_id: str, db: Session = Depends(get_db)):
|
|||||||
kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
|
kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
|
||||||
if not kb:
|
if not kb:
|
||||||
raise HTTPException(status_code=404, detail="Knowledge base not found")
|
raise HTTPException(status_code=404, detail="Knowledge base not found")
|
||||||
|
try:
|
||||||
vector_store.delete_collection(kb_id)
|
vector_store.delete_collection(kb_id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb_id).all()
|
docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb_id).all()
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
db.delete(doc)
|
db.delete(doc)
|
||||||
@@ -127,10 +146,7 @@ def delete_knowledge_base(kb_id: str, db: Session = Depends(get_db)):
|
|||||||
@router.post("/bases/{kb_id}/documents")
|
@router.post("/bases/{kb_id}/documents")
|
||||||
def upload_document(
|
def upload_document(
|
||||||
kb_id: str,
|
kb_id: str,
|
||||||
name: str = Query(...),
|
data: KnowledgeDocumentCreate,
|
||||||
size: str = Query(...),
|
|
||||||
file_type: str = Query("txt"),
|
|
||||||
storage_url: Optional[str] = Query(None),
|
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
|
kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
|
||||||
@@ -139,17 +155,25 @@ def upload_document(
|
|||||||
doc = KnowledgeDocument(
|
doc = KnowledgeDocument(
|
||||||
id=str(uuid.uuid4())[:8],
|
id=str(uuid.uuid4())[:8],
|
||||||
kb_id=kb_id,
|
kb_id=kb_id,
|
||||||
name=name,
|
name=data.name,
|
||||||
size=size,
|
size=data.size,
|
||||||
file_type=file_type,
|
file_type=data.fileType,
|
||||||
storage_url=storage_url,
|
storage_url=data.storageUrl,
|
||||||
status="pending",
|
status="pending",
|
||||||
upload_date=datetime.utcnow().isoformat()
|
upload_date=datetime.utcnow().isoformat()
|
||||||
)
|
)
|
||||||
db.add(doc)
|
db.add(doc)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(doc)
|
db.refresh(doc)
|
||||||
return {"id": doc.id, "name": doc.name, "status": doc.status, "message": "Document created"}
|
return {
|
||||||
|
"id": doc.id,
|
||||||
|
"name": doc.name,
|
||||||
|
"size": doc.size,
|
||||||
|
"fileType": doc.file_type,
|
||||||
|
"storageUrl": doc.storage_url,
|
||||||
|
"status": doc.status,
|
||||||
|
"message": "Document created",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/bases/{kb_id}/documents/{doc_id}/index")
|
@router.post("/bases/{kb_id}/documents/{doc_id}/index")
|
||||||
@@ -212,8 +236,9 @@ def delete_document(kb_id: str, doc_id: str, db: Session = Depends(get_db)):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
|
kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
|
||||||
kb.chunk_count -= doc.chunk_count
|
if kb:
|
||||||
kb.doc_count -= 1
|
kb.chunk_count = max(0, kb.chunk_count - (doc.chunk_count or 0))
|
||||||
|
kb.doc_count = max(0, kb.doc_count - 1)
|
||||||
db.delete(doc)
|
db.delete(doc)
|
||||||
db.commit()
|
db.commit()
|
||||||
return {"message": "Deleted successfully"}
|
return {"message": "Deleted successfully"}
|
||||||
|
|||||||
@@ -10,14 +10,14 @@ from ..db import get_db
|
|||||||
from ..models import LLMModel
|
from ..models import LLMModel
|
||||||
from ..schemas import (
|
from ..schemas import (
|
||||||
LLMModelCreate, LLMModelUpdate, LLMModelOut,
|
LLMModelCreate, LLMModelUpdate, LLMModelOut,
|
||||||
LLMModelTestResponse, ListResponse
|
LLMModelTestResponse
|
||||||
)
|
)
|
||||||
|
|
||||||
router = APIRouter(prefix="/llm", tags=["LLM Models"])
|
router = APIRouter(prefix="/llm", tags=["LLM Models"])
|
||||||
|
|
||||||
|
|
||||||
# ============ LLM Models CRUD ============
|
# ============ LLM Models CRUD ============
|
||||||
@router.get("", response_model=ListResponse)
|
@router.get("")
|
||||||
def list_llm_models(
|
def list_llm_models(
|
||||||
model_type: Optional[str] = None,
|
model_type: Optional[str] = None,
|
||||||
enabled: Optional[bool] = None,
|
enabled: Optional[bool] = None,
|
||||||
|
|||||||
94
api/app/routers/voices.py
Normal file
94
api/app/routers/voices.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import Optional
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from ..db import get_db
|
||||||
|
from ..models import Voice
|
||||||
|
from ..schemas import VoiceCreate, VoiceUpdate, VoiceOut
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/voices")
|
||||||
|
def list_voices(
|
||||||
|
vendor: Optional[str] = None,
|
||||||
|
language: Optional[str] = None,
|
||||||
|
gender: Optional[str] = None,
|
||||||
|
page: int = 1,
|
||||||
|
limit: int = 50,
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""获取声音库列表"""
|
||||||
|
query = db.query(Voice)
|
||||||
|
if vendor:
|
||||||
|
query = query.filter(Voice.vendor == vendor)
|
||||||
|
if language:
|
||||||
|
query = query.filter(Voice.language == language)
|
||||||
|
if gender:
|
||||||
|
query = query.filter(Voice.gender == gender)
|
||||||
|
|
||||||
|
total = query.count()
|
||||||
|
voices = query.order_by(Voice.created_at.desc()) \
|
||||||
|
.offset((page - 1) * limit).limit(limit).all()
|
||||||
|
return {"total": total, "page": page, "limit": limit, "list": voices}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/voices", response_model=VoiceOut)
|
||||||
|
def create_voice(data: VoiceCreate, db: Session = Depends(get_db)):
|
||||||
|
"""创建声音"""
|
||||||
|
voice = Voice(
|
||||||
|
id=data.id or str(uuid.uuid4())[:8],
|
||||||
|
user_id=1,
|
||||||
|
name=data.name,
|
||||||
|
vendor=data.vendor,
|
||||||
|
gender=data.gender,
|
||||||
|
language=data.language,
|
||||||
|
description=data.description,
|
||||||
|
model=data.model,
|
||||||
|
voice_key=data.voice_key,
|
||||||
|
speed=data.speed,
|
||||||
|
gain=data.gain,
|
||||||
|
pitch=data.pitch,
|
||||||
|
enabled=data.enabled,
|
||||||
|
)
|
||||||
|
db.add(voice)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(voice)
|
||||||
|
return voice
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/voices/{id}", response_model=VoiceOut)
|
||||||
|
def get_voice(id: str, db: Session = Depends(get_db)):
|
||||||
|
"""获取单个声音详情"""
|
||||||
|
voice = db.query(Voice).filter(Voice.id == id).first()
|
||||||
|
if not voice:
|
||||||
|
raise HTTPException(status_code=404, detail="Voice not found")
|
||||||
|
return voice
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/voices/{id}", response_model=VoiceOut)
|
||||||
|
def update_voice(id: str, data: VoiceUpdate, db: Session = Depends(get_db)):
|
||||||
|
"""更新声音"""
|
||||||
|
voice = db.query(Voice).filter(Voice.id == id).first()
|
||||||
|
if not voice:
|
||||||
|
raise HTTPException(status_code=404, detail="Voice not found")
|
||||||
|
|
||||||
|
update_data = data.model_dump(exclude_unset=True)
|
||||||
|
for field, value in update_data.items():
|
||||||
|
setattr(voice, field, value)
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
db.refresh(voice)
|
||||||
|
return voice
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/voices/{id}")
|
||||||
|
def delete_voice(id: str, db: Session = Depends(get_db)):
|
||||||
|
"""删除声音"""
|
||||||
|
voice = db.query(Voice).filter(Voice.id == id).first()
|
||||||
|
if not voice:
|
||||||
|
raise HTTPException(status_code=404, detail="Voice not found")
|
||||||
|
db.delete(voice)
|
||||||
|
db.commit()
|
||||||
|
return {"message": "Deleted successfully"}
|
||||||
@@ -50,8 +50,9 @@ class VoiceBase(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class VoiceCreate(VoiceBase):
|
class VoiceCreate(VoiceBase):
|
||||||
model: str # 厂商语音模型标识
|
id: Optional[str] = None
|
||||||
voice_key: str # 厂商voice_key
|
model: Optional[str] = None # 厂商语音模型标识
|
||||||
|
voice_key: Optional[str] = None # 厂商voice_key
|
||||||
speed: float = 1.0
|
speed: float = 1.0
|
||||||
gain: int = 0
|
gain: int = 0
|
||||||
pitch: int = 0
|
pitch: int = 0
|
||||||
@@ -113,7 +114,7 @@ class LLMModelBase(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class LLMModelCreate(LLMModelBase):
|
class LLMModelCreate(LLMModelBase):
|
||||||
pass
|
id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class LLMModelUpdate(BaseModel):
|
class LLMModelUpdate(BaseModel):
|
||||||
@@ -154,6 +155,7 @@ class ASRModelBase(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ASRModelCreate(ASRModelBase):
|
class ASRModelCreate(ASRModelBase):
|
||||||
|
id: Optional[str] = None
|
||||||
hotwords: List[str] = []
|
hotwords: List[str] = []
|
||||||
enable_punctuation: bool = True
|
enable_punctuation: bool = True
|
||||||
enable_normalization: bool = True
|
enable_normalization: bool = True
|
||||||
@@ -195,6 +197,7 @@ class ASRTestResponse(BaseModel):
|
|||||||
confidence: Optional[float] = None
|
confidence: Optional[float] = None
|
||||||
duration_ms: Optional[int] = None
|
duration_ms: Optional[int] = None
|
||||||
latency_ms: Optional[int] = None
|
latency_ms: Optional[int] = None
|
||||||
|
message: Optional[str] = None
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
@@ -413,6 +416,8 @@ class CallRecordCreate(BaseModel):
|
|||||||
user_id: int
|
user_id: int
|
||||||
assistant_id: Optional[str] = None
|
assistant_id: Optional[str] = None
|
||||||
source: str = "debug"
|
source: str = "debug"
|
||||||
|
status: Optional[str] = None
|
||||||
|
cost: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
class CallRecordUpdate(BaseModel):
|
class CallRecordUpdate(BaseModel):
|
||||||
|
|||||||
Reference in New Issue
Block a user