Add tts/text output schema
This commit is contained in:
@@ -211,6 +211,7 @@ class DuplexPipeline:
|
||||
self._runtime_llm: Dict[str, Any] = {}
|
||||
self._runtime_asr: Dict[str, Any] = {}
|
||||
self._runtime_tts: Dict[str, Any] = {}
|
||||
self._runtime_output: Dict[str, Any] = {}
|
||||
self._runtime_system_prompt: Optional[str] = None
|
||||
self._runtime_greeting: Optional[str] = None
|
||||
self._runtime_knowledge: Dict[str, Any] = {}
|
||||
@@ -257,6 +258,9 @@ class DuplexPipeline:
|
||||
self._runtime_asr = services["asr"]
|
||||
if isinstance(services.get("tts"), dict):
|
||||
self._runtime_tts = services["tts"]
|
||||
output = metadata.get("output") or {}
|
||||
if isinstance(output, dict):
|
||||
self._runtime_output = output
|
||||
|
||||
knowledge_base_id = metadata.get("knowledgeBaseId")
|
||||
if knowledge_base_id is not None:
|
||||
@@ -283,6 +287,31 @@ class DuplexPipeline:
|
||||
if self.llm_service and hasattr(self.llm_service, "set_tool_schemas"):
|
||||
self.llm_service.set_tool_schemas(self._resolved_tool_schemas())
|
||||
|
||||
@staticmethod
def _coerce_bool(value: Any) -> Optional[bool]:
    """Interpret a loosely-typed flag as a boolean.

    Args:
        value: Candidate flag; may be a bool, a number, a string, or
            anything else.

    Returns:
        ``value`` itself when it is already a bool, the truthiness of an
        int or float, ``True``/``False`` for a recognised string spelling
        (case-insensitive, surrounding whitespace ignored), and ``None``
        when the value cannot be interpreted.
    """
    # bool is checked first because it is a subclass of int.
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        return bool(value)
    if not isinstance(value, str):
        return None

    # Recognised textual spellings mapped to their boolean meaning.
    spellings = {
        **dict.fromkeys(("1", "true", "yes", "on", "enabled"), True),
        **dict.fromkeys(("0", "false", "no", "off", "disabled"), False),
    }
    # Unknown strings fall through to None via dict.get.
    return spellings.get(value.strip().lower())
|
||||
|
||||
def _tts_output_enabled(self) -> bool:
    """Decide whether synthesized audio output should be produced.

    An explicit, interpretable ``enabled`` value in the runtime TTS
    settings takes precedence. When that value is missing or cannot be
    interpreted, audio stays enabled unless the runtime output ``mode``
    is one of the text-only spellings.

    Returns:
        ``True`` when TTS audio should be emitted, ``False`` otherwise.
    """
    explicit_flag = self._coerce_bool(self._runtime_tts.get("enabled"))
    if explicit_flag is None:
        # No usable explicit flag: fall back to the output mode.
        raw_mode = self._runtime_output.get("mode") or ""
        mode = str(raw_mode).strip().lower()
        return mode not in {"text", "text_only", "text-only"}
    return explicit_flag
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the pipeline and connect services."""
|
||||
try:
|
||||
@@ -311,38 +340,44 @@ class DuplexPipeline:
|
||||
|
||||
await self.llm_service.connect()
|
||||
|
||||
# Connect TTS service
|
||||
if not self.tts_service:
|
||||
tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower()
|
||||
tts_api_key = self._runtime_tts.get("apiKey") or settings.siliconflow_api_key
|
||||
tts_voice = self._runtime_tts.get("voice") or settings.tts_voice
|
||||
tts_model = self._runtime_tts.get("model") or settings.siliconflow_tts_model
|
||||
tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed)
|
||||
tts_output_enabled = self._tts_output_enabled()
|
||||
|
||||
if tts_provider == "siliconflow" and tts_api_key:
|
||||
self.tts_service = SiliconFlowTTSService(
|
||||
api_key=tts_api_key,
|
||||
voice=tts_voice,
|
||||
model=tts_model,
|
||||
sample_rate=settings.sample_rate,
|
||||
speed=tts_speed
|
||||
)
|
||||
logger.info("Using SiliconFlow TTS service")
|
||||
else:
|
||||
self.tts_service = EdgeTTSService(
|
||||
voice=tts_voice,
|
||||
# Connect TTS service only when audio output is enabled.
|
||||
if tts_output_enabled:
|
||||
if not self.tts_service:
|
||||
tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower()
|
||||
tts_api_key = self._runtime_tts.get("apiKey") or settings.siliconflow_api_key
|
||||
tts_voice = self._runtime_tts.get("voice") or settings.tts_voice
|
||||
tts_model = self._runtime_tts.get("model") or settings.siliconflow_tts_model
|
||||
tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed)
|
||||
|
||||
if tts_provider == "siliconflow" and tts_api_key:
|
||||
self.tts_service = SiliconFlowTTSService(
|
||||
api_key=tts_api_key,
|
||||
voice=tts_voice,
|
||||
model=tts_model,
|
||||
sample_rate=settings.sample_rate,
|
||||
speed=tts_speed
|
||||
)
|
||||
logger.info("Using SiliconFlow TTS service")
|
||||
else:
|
||||
self.tts_service = EdgeTTSService(
|
||||
voice=tts_voice,
|
||||
sample_rate=settings.sample_rate
|
||||
)
|
||||
logger.info("Using Edge TTS service")
|
||||
|
||||
try:
|
||||
await self.tts_service.connect()
|
||||
except Exception as e:
|
||||
logger.warning(f"TTS backend unavailable ({e}); falling back to MockTTS")
|
||||
self.tts_service = MockTTSService(
|
||||
sample_rate=settings.sample_rate
|
||||
)
|
||||
logger.info("Using Edge TTS service")
|
||||
|
||||
try:
|
||||
await self.tts_service.connect()
|
||||
except Exception as e:
|
||||
logger.warning(f"TTS backend unavailable ({e}); falling back to MockTTS")
|
||||
self.tts_service = MockTTSService(
|
||||
sample_rate=settings.sample_rate
|
||||
)
|
||||
await self.tts_service.connect()
|
||||
await self.tts_service.connect()
|
||||
else:
|
||||
self.tts_service = None
|
||||
logger.info("TTS output disabled by runtime metadata")
|
||||
|
||||
# Connect ASR service
|
||||
if not self.asr_service:
|
||||
@@ -375,7 +410,7 @@ class DuplexPipeline:
|
||||
self._outbound_task = asyncio.create_task(self._outbound_loop())
|
||||
|
||||
# Speak greeting if configured
|
||||
if self.conversation.greeting:
|
||||
if self.conversation.greeting and tts_output_enabled:
|
||||
await self._speak(self.conversation.greeting)
|
||||
|
||||
except Exception as e:
|
||||
@@ -932,7 +967,7 @@ class DuplexPipeline:
|
||||
pending_punctuation = sentence
|
||||
continue
|
||||
|
||||
if not self._interrupt_event.is_set():
|
||||
if self._tts_output_enabled() and not self._interrupt_event.is_set():
|
||||
if not first_audio_sent:
|
||||
await self._send_event(
|
||||
{
|
||||
@@ -952,7 +987,12 @@ class DuplexPipeline:
|
||||
)
|
||||
|
||||
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
|
||||
if remaining_text and has_spoken_content(remaining_text) and not self._interrupt_event.is_set():
|
||||
if (
|
||||
self._tts_output_enabled()
|
||||
and remaining_text
|
||||
and has_spoken_content(remaining_text)
|
||||
and not self._interrupt_event.is_set()
|
||||
):
|
||||
if not first_audio_sent:
|
||||
await self._send_event(
|
||||
{
|
||||
@@ -1066,7 +1106,10 @@ class DuplexPipeline:
|
||||
fade_in_ms: Fade-in duration for sentence start chunks
|
||||
fade_out_ms: Fade-out duration for sentence end chunks
|
||||
"""
|
||||
if not text.strip() or self._interrupt_event.is_set():
|
||||
if not self._tts_output_enabled():
|
||||
return
|
||||
|
||||
if not text.strip() or self._interrupt_event.is_set() or not self.tts_service:
|
||||
return
|
||||
|
||||
logger.info(f"[TTS] split sentence: {text!r}")
|
||||
@@ -1153,7 +1196,10 @@ class DuplexPipeline:
|
||||
Args:
|
||||
text: Text to speak
|
||||
"""
|
||||
if not text.strip():
|
||||
if not self._tts_output_enabled():
|
||||
return
|
||||
|
||||
if not text.strip() or not self.tts_service:
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
@@ -53,6 +53,9 @@ Rules:
|
||||
},
|
||||
"metadata": {
|
||||
"client": "web-debug",
|
||||
"output": {
|
||||
"mode": "audio"
|
||||
},
|
||||
"systemPrompt": "You are concise.",
|
||||
"greeting": "Hi, how can I help?",
|
||||
"services": {
|
||||
@@ -70,6 +73,7 @@ Rules:
|
||||
"minAudioMs": 300
|
||||
},
|
||||
"tts": {
|
||||
"enabled": true,
|
||||
"provider": "siliconflow",
|
||||
"model": "FunAudioLLM/CosyVoice2-0.5B",
|
||||
"apiKey": "sf-...",
|
||||
@@ -83,6 +87,10 @@ Rules:
|
||||
|
||||
`metadata.services` is optional. If it is omitted, the server falls back to its environment configuration.
|
||||
|
||||
Text-only mode:
|
||||
- Set `metadata.output.mode = "text"` OR `metadata.services.tts.enabled = false`. An explicit `metadata.services.tts.enabled` value takes precedence over `metadata.output.mode`.
|
||||
- In this mode the server still sends `assistant.response.delta/final`, but does not emit audio frames or `output.audio.start/end`.
|
||||
|
||||
### `input.text`
|
||||
|
||||
```json
|
||||
|
||||
@@ -125,6 +125,36 @@ async def test_turn_without_tool_keeps_streaming(monkeypatch):
|
||||
assert "assistant.tool_call" not in event_types
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "metadata",
    [
        {"output": {"mode": "text"}},
        {"services": {"tts": {"enabled": False}}},
    ],
)
async def test_text_output_mode_skips_audio_events(monkeypatch, metadata):
    """Text responses are still emitted but audio events are suppressed.

    Runs once per metadata variant: text output mode, and TTS explicitly
    disabled in the services section.
    """
    # One LLM round that streams a single short sentence.
    llm_rounds = [
        [
            LLMStreamEvent(type="text_delta", text="hello "),
            LLMStreamEvent(type="text_delta", text="world."),
            LLMStreamEvent(type="done"),
        ]
    ]
    pipeline, events = _build_pipeline(monkeypatch, llm_rounds)
    pipeline.apply_runtime_overrides(metadata)

    await pipeline._handle_turn("hi")

    seen_types = [event.get("type") for event in events]

    # The text channel must keep working...
    for expected in ("assistant.response.delta", "assistant.response.final"):
        assert expected in seen_types
    # ...while nothing audio-related is announced.
    for forbidden in ("output.audio.start", "output.audio.end"):
        assert forbidden not in seen_types
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_turn_with_tool_call_then_results(monkeypatch):
|
||||
pipeline, events = _build_pipeline(
|
||||
|
||||
Reference in New Issue
Block a user