diff --git a/README.md b/README.md index 9efff83..165b869 100644 --- a/README.md +++ b/README.md @@ -154,6 +154,9 @@ before declaring the turn done) is configurable per environment: "stop_secs": 0.6, "min_volume": 0.6 }, + "interruption_min_chars": 3, + "interruption_use_interim": true, + "interruption_short_replies": ["是的", "行", "可以"], "user_speech_timeout_sec": 1.0 } ``` @@ -162,6 +165,13 @@ before declaring the turn done) is configurable per environment: controls the Silero VAD. `stop_secs` is the duration of silence required before VAD reports the user stopped speaking; raise it if VAD is cutting users off mid-clause, lower it for snappier turn-taking. +- `interruption_min_chars`, `interruption_use_interim`, and + `interruption_short_replies` configure the custom turn-start gate used + while the assistant is speaking. Short replies in the allowlist (for + example, `是的`, `行`, `可以`) can barge in immediately; other text must + contain at least `interruption_min_chars` countable characters after + punctuation and spaces are removed. This keeps common yes/no answers while + filtering brief background speech. - `user_speech_timeout_sec` is the additional grace window (used by `SpeechTimeoutUserTurnStopStrategy`) during which the user may resume speaking before the aggregator finalizes the turn. The timer is diff --git a/config.json b/config.json index c14ab4e..57e3a93 100644 --- a/config.json +++ b/config.json @@ -21,6 +21,27 @@ "stop_secs": 0.4, "min_volume": 0.6 }, + "interruption_min_chars": 3, + "interruption_use_interim": true, + "interruption_short_replies": [ + "是", + "是的", + "对", + "对的", + "嗯", + "好", + "好的", + "行", + "可以", + "没问题", + "不是", + "不", + "不行", + "不用", + "不要", + "没有", + "否" + ], "user_speech_timeout_sec": 0.8 }, "agent": { diff --git a/config/xfyun.json b/config/xfyun.json index 2c37c8c..e0cc11b 100644 --- a/config/xfyun.json +++ b/config/xfyun.json @@ -19,6 +19,27 @@ "stop_secs": 0.6, "min_volume": 0.6 }, + "interruption_min_chars": 3, + "interruption_use_interim": true, + "interruption_short_replies": [ + "是", + "是的", + "对", + "对的", + "嗯", + "好", + "好的", + "行", + "可以", + "没问题", + "不是", + "不", + "不行", + "不用", + "不要", + "没有", + "否" + ], "user_speech_timeout_sec": 1.0 }, "agent": { diff --git a/engine/config.py b/engine/config.py index 26b7f23..962180b 100644 --- a/engine/config.py +++ b/engine/config.py @@ -61,6 +61,33 @@ class TurnConfig: vad: VADConfig = field(default_factory=VADConfig) user_speech_timeout_sec: float = 1.0 + interruption_min_chars: int = 3 + interruption_use_interim: bool = True + interruption_short_replies: list[str] = field( + default_factory=lambda: [ + "是", + "是的", + "对", + "对的", + "嗯", + "好", + "好的", + "行", + "可以", + "没问题", + "不是", + "不", + "不行", + "不用", + "不要", + "没有", + "否", + "no", + "yes", + "ok", + "okay", + ] + ) @dataclass(frozen=True) @@ -165,6 +192,18 @@ def config_from_dict(data: dict) -> EngineConfig: user_speech_timeout_sec=float( turn.get("user_speech_timeout_sec", TurnConfig().user_speech_timeout_sec) ), + interruption_min_chars=int( + turn.get("interruption_min_chars", TurnConfig().interruption_min_chars) + ), + interruption_use_interim=bool( + turn.get("interruption_use_interim", TurnConfig().interruption_use_interim) + ), + interruption_short_replies=list( + turn.get( + "interruption_short_replies", + TurnConfig().interruption_short_replies, + ) + ), ), agent=AgentConfig(**agent), services=ServicesConfig( diff --git a/engine/pipeline.py b/engine/pipeline.py index 517e54b..1dcb6b7 100644 --- a/engine/pipeline.py +++ b/engine/pipeline.py @@ -36,6 +36,7 @@ from .services import create_llm_service, create_stt_service, create_tts_service from .text_input import ProductTextInputProcessor from .text_stream import ProductTextStreamProcessor from .transcript_stream import ProductTranscriptStreamProcessor +from .turn_start import InterruptionGateUserTurnStartStrategy async def run_voice_pipeline(websocket, config: EngineConfig) -> None: @@ -104,6 +105,13 @@ async def run_pipeline_with_serializer( # (re-armed every time the user resumes speaking) before declaring the # turn finished — which is what we actually want for streaming ASRs. user_turn_strategies = UserTurnStrategies( + start=[ + InterruptionGateUserTurnStartStrategy( + min_chars_when_bot_speaking=config.turn.interruption_min_chars, + allowed_short_replies=config.turn.interruption_short_replies, + use_interim=config.turn.interruption_use_interim, + ), + ], stop=[ SpeechTimeoutUserTurnStopStrategy( user_speech_timeout=config.turn.user_speech_timeout_sec, diff --git a/engine/turn_start.py b/engine/turn_start.py new file mode 100644 index 0000000..929ac75 --- /dev/null +++ b/engine/turn_start.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import re + +from loguru import logger +from pipecat.frames.frames import ( + BotStartedSpeakingFrame, + BotStoppedSpeakingFrame, + Frame, + InterimTranscriptionFrame, + TranscriptionFrame, +) +from pipecat.turns.types import ProcessFrameResult +from pipecat.turns.user_start.base_user_turn_start_strategy import BaseUserTurnStartStrategy + + +_COUNTABLE_TEXT_RE = re.compile(r"[\w\u4e00-\u9fff]", re.UNICODE) + + +class InterruptionGateUserTurnStartStrategy(BaseUserTurnStartStrategy): + """Starts user turns only after likely intentional speech. + + When the assistant is speaking, short background speech should not barge in + unless it is a common answer to a yes/no style question. When the assistant + is not speaking, any non-empty transcript can start a normal user turn. + """ + + def __init__( + self, + *, + min_chars_when_bot_speaking: int, + allowed_short_replies: list[str], + use_interim: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self._min_chars_when_bot_speaking = min_chars_when_bot_speaking + self._allowed_short_replies = { + self._normalize_text(reply) for reply in allowed_short_replies if reply.strip() + } + self._use_interim = use_interim + self._bot_speaking = False + + async def reset(self): + await super().reset() + + async def process_frame(self, frame: Frame) -> ProcessFrameResult: + if isinstance(frame, BotStartedSpeakingFrame): + self._bot_speaking = True + return ProcessFrameResult.CONTINUE + if isinstance(frame, BotStoppedSpeakingFrame): + self._bot_speaking = False + return ProcessFrameResult.CONTINUE + if isinstance(frame, InterimTranscriptionFrame) and self._use_interim: + return await self._handle_transcription(frame.text, interim=True) + if isinstance(frame, TranscriptionFrame): + return await self._handle_transcription(frame.text, interim=False) + + return ProcessFrameResult.CONTINUE + + async def _handle_transcription(self, text: str, *, interim: bool) -> ProcessFrameResult: + normalized = self._normalize_text(text) + if not normalized: + return ProcessFrameResult.CONTINUE + + if not self._bot_speaking: + await self.trigger_user_turn_started() + return ProcessFrameResult.STOP + + should_interrupt = self._should_interrupt(normalized) + logger.debug( + f"{self} interruption_gate text={text!r} normalized={normalized!r} " + f"should_interrupt={should_interrupt} interim={interim}" + ) + + if should_interrupt: + await self.trigger_user_turn_started() + return ProcessFrameResult.STOP + + await self.trigger_reset_aggregation() + return ProcessFrameResult.CONTINUE + + def _should_interrupt(self, normalized: str) -> bool: + return ( + normalized in self._allowed_short_replies + or len(normalized) >= self._min_chars_when_bot_speaking + ) + + @staticmethod + def _normalize_text(text: str) -> str: + return "".join(_COUNTABLE_TEXT_RE.findall(text.lower()))