Add tts/text output schema
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.orm import sessionmaker, DeclarativeBase
|
||||
import os
|
||||
|
||||
@@ -14,6 +14,32 @@ class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
def ensure_schema_compatibility() -> None:
    """Best-effort lightweight migrations for SQLite deployments.

    Adds the ``voice_output_enabled`` column to the ``assistants`` table when
    it is missing and backfills NULL values with 1. Does nothing when the
    engine dialect is not SQLite or when the ``assistants`` table has no
    columns to report (i.e. it does not exist yet).
    """
    if engine.dialect.name != "sqlite":
        return

    with engine.begin() as conn:
        # PRAGMA table_info yields one row per column; index 1 holds the
        # column name.
        columns = {
            row[1]
            for row in conn.execute(text("PRAGMA table_info(assistants)"))
        }

        if not columns:
            # An empty result means the table is absent. ALTER/UPDATE would
            # raise on a missing table, which defeats the best-effort intent.
            return

        if "voice_output_enabled" not in columns:
            conn.execute(
                text(
                    "ALTER TABLE assistants "
                    "ADD COLUMN voice_output_enabled BOOLEAN DEFAULT 1"
                )
            )

        # Idempotent backfill: only rows still holding NULL are touched.
        conn.execute(
            text(
                "UPDATE assistants "
                "SET voice_output_enabled = 1 "
                "WHERE voice_output_enabled IS NULL"
            )
        )
|
||||
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
|
||||
@@ -3,7 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
import os
|
||||
|
||||
from .db import Base, engine
|
||||
from .db import Base, engine, ensure_schema_compatibility
|
||||
from .routers import assistants, voices, workflows, history, knowledge, llm, asr, tools
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from .routers import assistants, voices, workflows, history, knowledge, llm, asr
|
||||
async def lifespan(app: FastAPI):
|
||||
# 启动时创建表
|
||||
Base.metadata.create_all(bind=engine)
|
||||
ensure_schema_compatibility()
|
||||
yield
|
||||
|
||||
|
||||
|
||||
@@ -112,6 +112,7 @@ class Assistant(Base):
|
||||
prompt: Mapped[str] = mapped_column(Text, default="")
|
||||
knowledge_base_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||
language: Mapped[str] = mapped_column(String(16), default="zh")
|
||||
voice_output_enabled: Mapped[bool] = mapped_column(default=True)
|
||||
voice: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||
speed: Mapped[float] = mapped_column(Float, default=1.0)
|
||||
hotwords: Mapped[dict] = mapped_column(JSON, default=list)
|
||||
|
||||
@@ -21,6 +21,7 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
|
||||
metadata = {
|
||||
"systemPrompt": assistant.prompt or "",
|
||||
"greeting": assistant.opener or "",
|
||||
"output": {"mode": "audio" if assistant.voice_output_enabled else "text"},
|
||||
"services": {},
|
||||
}
|
||||
warnings = []
|
||||
@@ -49,11 +50,14 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
|
||||
else:
|
||||
warnings.append(f"ASR model not found: {assistant.asr_model_id}")
|
||||
|
||||
if assistant.voice:
|
||||
if not assistant.voice_output_enabled:
|
||||
metadata["services"]["tts"] = {"enabled": False}
|
||||
elif assistant.voice:
|
||||
voice = db.query(Voice).filter(Voice.id == assistant.voice).first()
|
||||
if voice:
|
||||
tts_provider = "siliconflow" if _is_siliconflow_vendor(voice.vendor) else "edge"
|
||||
metadata["services"]["tts"] = {
|
||||
"enabled": True,
|
||||
"provider": tts_provider,
|
||||
"model": voice.model,
|
||||
"apiKey": voice.api_key if tts_provider == "siliconflow" else None,
|
||||
@@ -63,6 +67,7 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
|
||||
else:
|
||||
# Keep assistant.voice as direct voice identifier fallback
|
||||
metadata["services"]["tts"] = {
|
||||
"enabled": True,
|
||||
"voice": assistant.voice,
|
||||
"speed": assistant.speed or 1.0,
|
||||
}
|
||||
@@ -98,6 +103,7 @@ def assistant_to_dict(assistant: Assistant) -> dict:
|
||||
"prompt": assistant.prompt or "",
|
||||
"knowledgeBaseId": assistant.knowledge_base_id,
|
||||
"language": assistant.language,
|
||||
"voiceOutputEnabled": assistant.voice_output_enabled,
|
||||
"voice": assistant.voice,
|
||||
"speed": assistant.speed,
|
||||
"hotwords": assistant.hotwords or [],
|
||||
@@ -120,6 +126,7 @@ def _apply_assistant_update(assistant: Assistant, update_data: dict) -> None:
|
||||
"knowledgeBaseId": "knowledge_base_id",
|
||||
"interruptionSensitivity": "interruption_sensitivity",
|
||||
"configMode": "config_mode",
|
||||
"voiceOutputEnabled": "voice_output_enabled",
|
||||
"apiUrl": "api_url",
|
||||
"apiKey": "api_key",
|
||||
"llmModelId": "llm_model_id",
|
||||
@@ -180,6 +187,7 @@ def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)):
|
||||
prompt=data.prompt,
|
||||
knowledge_base_id=data.knowledgeBaseId,
|
||||
language=data.language,
|
||||
voice_output_enabled=data.voiceOutputEnabled,
|
||||
voice=data.voice,
|
||||
speed=data.speed,
|
||||
hotwords=data.hotwords,
|
||||
|
||||
@@ -268,6 +268,7 @@ class AssistantBase(BaseModel):
|
||||
prompt: str = ""
|
||||
knowledgeBaseId: Optional[str] = None
|
||||
language: str = "zh"
|
||||
voiceOutputEnabled: bool = True
|
||||
voice: Optional[str] = None
|
||||
speed: float = 1.0
|
||||
hotwords: List[str] = []
|
||||
@@ -293,6 +294,7 @@ class AssistantUpdate(BaseModel):
|
||||
prompt: Optional[str] = None
|
||||
knowledgeBaseId: Optional[str] = None
|
||||
language: Optional[str] = None
|
||||
voiceOutputEnabled: Optional[bool] = None
|
||||
voice: Optional[str] = None
|
||||
speed: Optional[float] = None
|
||||
hotwords: Optional[List[str]] = None
|
||||
|
||||
@@ -85,6 +85,7 @@ def sample_assistant_data():
|
||||
"opener": "Hello, welcome!",
|
||||
"prompt": "You are a helpful assistant.",
|
||||
"language": "zh",
|
||||
"voiceOutputEnabled": True,
|
||||
"speed": 1.0,
|
||||
"hotwords": ["test", "hello"],
|
||||
"tools": [],
|
||||
|
||||
@@ -23,6 +23,7 @@ class TestAssistantAPI:
|
||||
assert data["opener"] == sample_assistant_data["opener"]
|
||||
assert data["prompt"] == sample_assistant_data["prompt"]
|
||||
assert data["language"] == sample_assistant_data["language"]
|
||||
assert data["voiceOutputEnabled"] is True
|
||||
assert "id" in data
|
||||
assert data["callCount"] == 0
|
||||
|
||||
@@ -61,7 +62,8 @@ class TestAssistantAPI:
|
||||
update_data = {
|
||||
"name": "Updated Assistant",
|
||||
"prompt": "You are an updated assistant.",
|
||||
"speed": 1.5
|
||||
"speed": 1.5,
|
||||
"voiceOutputEnabled": False,
|
||||
}
|
||||
response = client.put(f"/api/assistants/{assistant_id}", json=update_data)
|
||||
assert response.status_code == 200
|
||||
@@ -69,6 +71,7 @@ class TestAssistantAPI:
|
||||
assert data["name"] == "Updated Assistant"
|
||||
assert data["prompt"] == "You are an updated assistant."
|
||||
assert data["speed"] == 1.5
|
||||
assert data["voiceOutputEnabled"] is False
|
||||
|
||||
def test_delete_assistant(self, client, sample_assistant_data):
|
||||
"""Test deleting an assistant"""
|
||||
@@ -210,3 +213,15 @@ class TestAssistantAPI:
|
||||
assert metadata["services"]["llm"]["model"] == sample_llm_model_data["model_name"]
|
||||
assert metadata["services"]["asr"]["model"] == sample_asr_model_data["model_name"]
|
||||
assert metadata["services"]["tts"]["voice"] == sample_voice_data["voice_key"]
|
||||
|
||||
def test_runtime_config_text_mode_when_voice_output_disabled(self, client, sample_assistant_data):
    """Runtime config reports text output and disabled TTS when voice output is off."""
    # Create an assistant with voice output switched off.
    sample_assistant_data.update({"voiceOutputEnabled": False})
    created = client.post("/api/assistants", json=sample_assistant_data)
    assert created.status_code == 200
    new_id = created.json()["id"]

    # Fetch the runtime configuration generated for that assistant.
    config = client.get(f"/api/assistants/{new_id}/runtime-config")
    assert config.status_code == 200

    session_meta = config.json()["sessionStartMetadata"]
    assert session_meta["output"]["mode"] == "text"
    assert session_meta["services"]["tts"]["enabled"] is False
|
||||
|
||||
Reference in New Issue
Block a user