From 15523d9ec2cce7ee45b63c96576ea8ae9cc932f0 Mon Sep 17 00:00:00 2001 From: Xin Wang Date: Wed, 11 Feb 2026 09:50:46 +0800 Subject: [PATCH] Add tts/text output schema --- api/app/db.py | 28 ++++++- api/app/main.py | 3 +- api/app/models.py | 1 + api/app/routers/assistants.py | 10 ++- api/app/schemas.py | 2 + api/tests/conftest.py | 1 + api/tests/test_assistants.py | 17 ++++- engine/core/duplex_pipeline.py | 114 +++++++++++++++++++--------- engine/docs/ws_v1_schema.md | 8 ++ engine/tests/test_tool_call_flow.py | 30 ++++++++ web/pages/Assistants.tsx | 51 ++++++++++--- web/services/backendApi.ts | 3 + web/types.ts | 1 + 13 files changed, 219 insertions(+), 50 deletions(-) diff --git a/api/app/db.py b/api/app/db.py index 3b04bc8..22a25af 100644 --- a/api/app/db.py +++ b/api/app/db.py @@ -1,4 +1,4 @@ -from sqlalchemy import create_engine +from sqlalchemy import create_engine, text from sqlalchemy.orm import sessionmaker, DeclarativeBase import os @@ -14,6 +14,32 @@ class Base(DeclarativeBase): pass +def ensure_schema_compatibility() -> None: + """Best-effort lightweight migrations for SQLite deployments.""" + if engine.dialect.name != "sqlite": + return + + with engine.begin() as conn: + columns = { + row[1] + for row in conn.execute(text("PRAGMA table_info(assistants)")) + } + if "voice_output_enabled" not in columns: + conn.execute( + text( + "ALTER TABLE assistants " + "ADD COLUMN voice_output_enabled BOOLEAN DEFAULT 1" + ) + ) + conn.execute( + text( + "UPDATE assistants " + "SET voice_output_enabled = 1 " + "WHERE voice_output_enabled IS NULL" + ) + ) + + def get_db(): db = SessionLocal() try: diff --git a/api/app/main.py b/api/app/main.py index a193ff9..055d53c 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -3,7 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware from contextlib import asynccontextmanager import os -from .db import Base, engine +from .db import Base, engine, ensure_schema_compatibility from .routers import assistants, voices, workflows, history, knowledge, llm, asr, tools @@ -11,6 +11,7 @@ from .routers import assistants, voices, workflows, history, knowledge, llm, asr async def lifespan(app: FastAPI): # 启动时创建表 Base.metadata.create_all(bind=engine) + ensure_schema_compatibility() yield diff --git a/api/app/models.py b/api/app/models.py index c69bd78..cc253aa 100644 --- a/api/app/models.py +++ b/api/app/models.py @@ -112,6 +112,7 @@ class Assistant(Base): prompt: Mapped[str] = mapped_column(Text, default="") knowledge_base_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) language: Mapped[str] = mapped_column(String(16), default="zh") + voice_output_enabled: Mapped[bool] = mapped_column(default=True) voice: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) speed: Mapped[float] = mapped_column(Float, default=1.0) hotwords: Mapped[dict] = mapped_column(JSON, default=list) diff --git a/api/app/routers/assistants.py b/api/app/routers/assistants.py index 0e5227e..2368cee 100644 --- a/api/app/routers/assistants.py +++ b/api/app/routers/assistants.py @@ -21,6 +21,7 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict: metadata = { "systemPrompt": assistant.prompt or "", "greeting": assistant.opener or "", + "output": {"mode": "audio" if assistant.voice_output_enabled else "text"}, "services": {}, } warnings = [] @@ -49,11 +50,14 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict: else: warnings.append(f"ASR model not found: {assistant.asr_model_id}") - if assistant.voice: + if not assistant.voice_output_enabled: + metadata["services"]["tts"] = {"enabled": False} + elif assistant.voice: voice = db.query(Voice).filter(Voice.id == assistant.voice).first() if voice: tts_provider = "siliconflow" if _is_siliconflow_vendor(voice.vendor) else "edge" metadata["services"]["tts"] = { + "enabled": True, "provider": tts_provider, "model": voice.model, "apiKey": voice.api_key if tts_provider == "siliconflow" else None, @@ -63,6 +67,7 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict: else: # Keep assistant.voice as direct voice identifier fallback metadata["services"]["tts"] = { + "enabled": True, "voice": assistant.voice, "speed": assistant.speed or 1.0, } @@ -98,6 +103,7 @@ def assistant_to_dict(assistant: Assistant) -> dict: "prompt": assistant.prompt or "", "knowledgeBaseId": assistant.knowledge_base_id, "language": assistant.language, + "voiceOutputEnabled": assistant.voice_output_enabled, "voice": assistant.voice, "speed": assistant.speed, "hotwords": assistant.hotwords or [], @@ -120,6 +126,7 @@ def _apply_assistant_update(assistant: Assistant, update_data: dict) -> None: "knowledgeBaseId": "knowledge_base_id", "interruptionSensitivity": "interruption_sensitivity", "configMode": "config_mode", + "voiceOutputEnabled": "voice_output_enabled", "apiUrl": "api_url", "apiKey": "api_key", "llmModelId": "llm_model_id", @@ -180,6 +187,7 @@ def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)): prompt=data.prompt, knowledge_base_id=data.knowledgeBaseId, language=data.language, + voice_output_enabled=data.voiceOutputEnabled, voice=data.voice, speed=data.speed, hotwords=data.hotwords, diff --git a/api/app/schemas.py b/api/app/schemas.py index 457476b..6b7645e 100644 --- a/api/app/schemas.py +++ b/api/app/schemas.py @@ -268,6 +268,7 @@ class AssistantBase(BaseModel): prompt: str = "" knowledgeBaseId: Optional[str] = None language: str = "zh" + voiceOutputEnabled: bool = True voice: Optional[str] = None speed: float = 1.0 hotwords: List[str] = [] @@ -293,6 +294,7 @@ class AssistantUpdate(BaseModel): prompt: Optional[str] = None knowledgeBaseId: Optional[str] = None language: Optional[str] = None + voiceOutputEnabled: Optional[bool] = None voice: Optional[str] = None speed: Optional[float] = None hotwords: Optional[List[str]] = None diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 015cef7..9435444 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -85,6 +85,7 @@ def sample_assistant_data(): "opener": "Hello, welcome!", "prompt": "You are a helpful assistant.", "language": "zh", + "voiceOutputEnabled": True, "speed": 1.0, "hotwords": ["test", "hello"], "tools": [], diff --git a/api/tests/test_assistants.py b/api/tests/test_assistants.py index 106ff10..a0140d1 100644 --- a/api/tests/test_assistants.py +++ b/api/tests/test_assistants.py @@ -23,6 +23,7 @@ class TestAssistantAPI: assert data["opener"] == sample_assistant_data["opener"] assert data["prompt"] == sample_assistant_data["prompt"] assert data["language"] == sample_assistant_data["language"] + assert data["voiceOutputEnabled"] is True assert "id" in data assert data["callCount"] == 0 @@ -61,7 +62,8 @@ class TestAssistantAPI: update_data = { "name": "Updated Assistant", "prompt": "You are an updated assistant.", - "speed": 1.5 + "speed": 1.5, + "voiceOutputEnabled": False, } response = client.put(f"/api/assistants/{assistant_id}", json=update_data) assert response.status_code == 200 @@ -69,6 +71,7 @@ class TestAssistantAPI: assert data["name"] == "Updated Assistant" assert data["prompt"] == "You are an updated assistant." assert data["speed"] == 1.5 + assert data["voiceOutputEnabled"] is False def test_delete_assistant(self, client, sample_assistant_data): """Test deleting an assistant""" @@ -210,3 +213,15 @@ class TestAssistantAPI: assert metadata["services"]["llm"]["model"] == sample_llm_model_data["model_name"] assert metadata["services"]["asr"]["model"] == sample_asr_model_data["model_name"] assert metadata["services"]["tts"]["voice"] == sample_voice_data["voice_key"] + + def test_runtime_config_text_mode_when_voice_output_disabled(self, client, sample_assistant_data): + sample_assistant_data["voiceOutputEnabled"] = False + assistant_resp = client.post("/api/assistants", json=sample_assistant_data) + assert assistant_resp.status_code == 200 + assistant_id = assistant_resp.json()["id"] + + runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config") + assert runtime_resp.status_code == 200 + metadata = runtime_resp.json()["sessionStartMetadata"] + assert metadata["output"]["mode"] == "text" + assert metadata["services"]["tts"]["enabled"] is False diff --git a/engine/core/duplex_pipeline.py b/engine/core/duplex_pipeline.py index 2fa2948..8527144 100644 --- a/engine/core/duplex_pipeline.py +++ b/engine/core/duplex_pipeline.py @@ -211,6 +211,7 @@ class DuplexPipeline: self._runtime_llm: Dict[str, Any] = {} self._runtime_asr: Dict[str, Any] = {} self._runtime_tts: Dict[str, Any] = {} + self._runtime_output: Dict[str, Any] = {} self._runtime_system_prompt: Optional[str] = None self._runtime_greeting: Optional[str] = None self._runtime_knowledge: Dict[str, Any] = {} @@ -257,6 +258,9 @@ class DuplexPipeline: self._runtime_asr = services["asr"] if isinstance(services.get("tts"), dict): self._runtime_tts = services["tts"] + output = metadata.get("output") or {} + if isinstance(output, dict): + self._runtime_output = output knowledge_base_id = metadata.get("knowledgeBaseId") if knowledge_base_id is not None: @@ -283,6 +287,31 @@ class DuplexPipeline: if self.llm_service and hasattr(self.llm_service, "set_tool_schemas"): self.llm_service.set_tool_schemas(self._resolved_tool_schemas()) + @staticmethod + def _coerce_bool(value: Any) -> Optional[bool]: + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + if isinstance(value, str): + normalized = value.strip().lower() + if normalized in {"1", "true", "yes", "on", "enabled"}: + return True + if normalized in {"0", "false", "no", "off", "disabled"}: + return False + return None + + def _tts_output_enabled(self) -> bool: + enabled = self._coerce_bool(self._runtime_tts.get("enabled")) + if enabled is not None: + return enabled + + output_mode = str(self._runtime_output.get("mode") or "").strip().lower() + if output_mode in {"text", "text_only", "text-only"}: + return False + + return True + async def start(self) -> None: """Start the pipeline and connect services.""" try: @@ -311,38 +340,44 @@ class DuplexPipeline: await self.llm_service.connect() - # Connect TTS service - if not self.tts_service: - tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower() - tts_api_key = self._runtime_tts.get("apiKey") or settings.siliconflow_api_key - tts_voice = self._runtime_tts.get("voice") or settings.tts_voice - tts_model = self._runtime_tts.get("model") or settings.siliconflow_tts_model - tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed) + tts_output_enabled = self._tts_output_enabled() - if tts_provider == "siliconflow" and tts_api_key: - self.tts_service = SiliconFlowTTSService( - api_key=tts_api_key, - voice=tts_voice, - model=tts_model, - sample_rate=settings.sample_rate, - speed=tts_speed - ) - logger.info("Using SiliconFlow TTS service") - else: - self.tts_service = EdgeTTSService( - voice=tts_voice, + # Connect TTS service only when audio output is enabled. + if tts_output_enabled: + if not self.tts_service: + tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower() + tts_api_key = self._runtime_tts.get("apiKey") or settings.siliconflow_api_key + tts_voice = self._runtime_tts.get("voice") or settings.tts_voice + tts_model = self._runtime_tts.get("model") or settings.siliconflow_tts_model + tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed) + + if tts_provider == "siliconflow" and tts_api_key: + self.tts_service = SiliconFlowTTSService( + api_key=tts_api_key, + voice=tts_voice, + model=tts_model, + sample_rate=settings.sample_rate, + speed=tts_speed + ) + logger.info("Using SiliconFlow TTS service") + else: + self.tts_service = EdgeTTSService( + voice=tts_voice, + sample_rate=settings.sample_rate + ) + logger.info("Using Edge TTS service") + + try: + await self.tts_service.connect() + except Exception as e: + logger.warning(f"TTS backend unavailable ({e}); falling back to MockTTS") + self.tts_service = MockTTSService( sample_rate=settings.sample_rate ) - logger.info("Using Edge TTS service") - - try: - await self.tts_service.connect() - except Exception as e: - logger.warning(f"TTS backend unavailable ({e}); falling back to MockTTS") - self.tts_service = MockTTSService( - sample_rate=settings.sample_rate - ) - await self.tts_service.connect() + await self.tts_service.connect() + else: + self.tts_service = None + logger.info("TTS output disabled by runtime metadata") # Connect ASR service if not self.asr_service: @@ -375,7 +410,7 @@ class DuplexPipeline: self._outbound_task = asyncio.create_task(self._outbound_loop()) # Speak greeting if configured - if self.conversation.greeting: + if self.conversation.greeting and tts_output_enabled: await self._speak(self.conversation.greeting) except Exception as e: @@ -932,7 +967,7 @@ class DuplexPipeline: pending_punctuation = sentence continue - if not self._interrupt_event.is_set(): + if self._tts_output_enabled() and not self._interrupt_event.is_set(): if not first_audio_sent: await self._send_event( { @@ -952,7 +987,12 @@ class DuplexPipeline: ) remaining_text = f"{pending_punctuation}{sentence_buffer}".strip() - if remaining_text and has_spoken_content(remaining_text) and not self._interrupt_event.is_set(): + if ( + self._tts_output_enabled() + and remaining_text + and has_spoken_content(remaining_text) + and not self._interrupt_event.is_set() + ): if not first_audio_sent: await self._send_event( { @@ -1066,7 +1106,10 @@ class DuplexPipeline: fade_in_ms: Fade-in duration for sentence start chunks fade_out_ms: Fade-out duration for sentence end chunks """ - if not text.strip() or self._interrupt_event.is_set(): + if not self._tts_output_enabled(): + return + + if not text.strip() or self._interrupt_event.is_set() or not self.tts_service: return logger.info(f"[TTS] split sentence: {text!r}") @@ -1153,7 +1196,10 @@ class DuplexPipeline: Args: text: Text to speak """ - if not text.strip(): + if not self._tts_output_enabled(): + return + + if not text.strip() or not self.tts_service: return try: diff --git a/engine/docs/ws_v1_schema.md b/engine/docs/ws_v1_schema.md index 0d21432..88ac755 100644 --- a/engine/docs/ws_v1_schema.md +++ b/engine/docs/ws_v1_schema.md @@ -53,6 +53,9 @@ Rules: }, "metadata": { "client": "web-debug", + "output": { + "mode": "audio" + }, "systemPrompt": "You are concise.", "greeting": "Hi, how can I help?", "services": { @@ -70,6 +73,7 @@ Rules: "minAudioMs": 300 }, "tts": { + "enabled": true, "provider": "siliconflow", "model": "FunAudioLLM/CosyVoice2-0.5B", "apiKey": "sf-...", @@ -83,6 +87,10 @@ Rules: `metadata.services` is optional. If omitted, server defaults to environment configuration. +Text-only mode: +- Set `metadata.output.mode = "text"` OR `metadata.services.tts.enabled = false`. +- In this mode server still sends `assistant.response.delta/final`, but will not emit audio frames or `output.audio.start/end`. + ### `input.text` ```json diff --git a/engine/tests/test_tool_call_flow.py b/engine/tests/test_tool_call_flow.py index d21e73a..e5f241b 100644 --- a/engine/tests/test_tool_call_flow.py +++ b/engine/tests/test_tool_call_flow.py @@ -125,6 +125,36 @@ async def test_turn_without_tool_keeps_streaming(monkeypatch): assert "assistant.tool_call" not in event_types +@pytest.mark.asyncio +@pytest.mark.parametrize( + "metadata", + [ + {"output": {"mode": "text"}}, + {"services": {"tts": {"enabled": False}}}, + ], +) +async def test_text_output_mode_skips_audio_events(monkeypatch, metadata): + pipeline, events = _build_pipeline( + monkeypatch, + [ + [ + LLMStreamEvent(type="text_delta", text="hello "), + LLMStreamEvent(type="text_delta", text="world."), + LLMStreamEvent(type="done"), + ] + ], + ) + pipeline.apply_runtime_overrides(metadata) + + await pipeline._handle_turn("hi") + + event_types = [e.get("type") for e in events] + assert "assistant.response.delta" in event_types + assert "assistant.response.final" in event_types + assert "output.audio.start" not in event_types + assert "output.audio.end" not in event_types + + @pytest.mark.asyncio async def test_turn_with_tool_call_then_results(monkeypatch): pipeline, events = _build_pipeline( diff --git a/web/pages/Assistants.tsx b/web/pages/Assistants.tsx index 87646a7..58b09ab 100644 --- a/web/pages/Assistants.tsx +++ b/web/pages/Assistants.tsx @@ -118,6 +118,7 @@ export const AssistantsPage: React.FC = () => { prompt: '', knowledgeBaseId: '', language: 'zh', + voiceOutputEnabled: true, voice: voices[0]?.id || '', speed: 1, hotwords: [], @@ -531,6 +532,7 @@ export const AssistantsPage: React.FC = () => { placeholder="设定小助手的人设、语气、行为规范以及业务逻辑..." /> + )} @@ -624,15 +626,32 @@ export const AssistantsPage: React.FC = () => {

+
+ + +

关闭后将进入纯文本输出模式,不会产生语音音频。

+
+
setTextTtsEnabled(e.target.checked)} - className="accent-primary" - /> - TTS - + + TTS: {textTtsEnabled ? 'ON' : 'OFF'} +

Audio 3A

diff --git a/web/services/backendApi.ts b/web/services/backendApi.ts index 72c5720..5ea47c3 100644 --- a/web/services/backendApi.ts +++ b/web/services/backendApi.ts @@ -33,6 +33,7 @@ const mapAssistant = (raw: AnyRecord): Assistant => ({ prompt: readField(raw, ['prompt'], ''), knowledgeBaseId: readField(raw, ['knowledgeBaseId', 'knowledge_base_id'], ''), language: readField(raw, ['language'], 'zh') as 'zh' | 'en', + voiceOutputEnabled: Boolean(readField(raw, ['voiceOutputEnabled', 'voice_output_enabled'], true)), voice: readField(raw, ['voice'], ''), speed: Number(readField(raw, ['speed'], 1)), hotwords: readField(raw, ['hotwords'], []), @@ -210,6 +211,7 @@ export const createAssistant = async (data: Partial): Promise): Pro prompt: data.prompt, knowledgeBaseId: data.knowledgeBaseId, language: data.language, + voiceOutputEnabled: data.voiceOutputEnabled, voice: data.voice, speed: data.speed, hotwords: data.hotwords, diff --git a/web/types.ts b/web/types.ts index 9862e73..697f15d 100644 --- a/web/types.ts +++ b/web/types.ts @@ -7,6 +7,7 @@ export interface Assistant { prompt: string; knowledgeBaseId: string; language: 'zh' | 'en'; + voiceOutputEnabled?: boolean; voice: string; // This will now store the ID of the voice from Voice Library speed: number; hotwords: string[];