from __future__ import annotations import re from loguru import logger from pipecat.frames.frames import ( BotStartedSpeakingFrame, BotStoppedSpeakingFrame, Frame, InterimTranscriptionFrame, TranscriptionFrame, ) from pipecat.turns.types import ProcessFrameResult from pipecat.turns.user_start.base_user_turn_start_strategy import BaseUserTurnStartStrategy _COUNTABLE_TEXT_RE = re.compile(r"[\w\u4e00-\u9fff]", re.UNICODE) class InterruptionGateUserTurnStartStrategy(BaseUserTurnStartStrategy): """Starts user turns only after likely intentional speech. When the assistant is speaking, short background speech should not barge in unless it is a common answer to a yes/no style question. When the assistant is not speaking, any non-empty transcript can start a normal user turn. """ def __init__( self, *, min_chars_when_bot_speaking: int, allowed_short_replies: list[str], use_interim: bool = True, **kwargs, ): super().__init__(**kwargs) self._min_chars_when_bot_speaking = min_chars_when_bot_speaking self._allowed_short_replies = { self._normalize_text(reply) for reply in allowed_short_replies if reply.strip() } self._use_interim = use_interim self._bot_speaking = False async def reset(self): await super().reset() async def process_frame(self, frame: Frame) -> ProcessFrameResult: if isinstance(frame, BotStartedSpeakingFrame): self._bot_speaking = True return ProcessFrameResult.CONTINUE if isinstance(frame, BotStoppedSpeakingFrame): self._bot_speaking = False return ProcessFrameResult.CONTINUE if isinstance(frame, InterimTranscriptionFrame) and self._use_interim: return await self._handle_transcription(frame.text, interim=True) if isinstance(frame, TranscriptionFrame): return await self._handle_transcription(frame.text, interim=False) return ProcessFrameResult.CONTINUE async def _handle_transcription(self, text: str, *, interim: bool) -> ProcessFrameResult: normalized = self._normalize_text(text) if not normalized: return ProcessFrameResult.CONTINUE if not self._bot_speaking: await self.trigger_user_turn_started() return ProcessFrameResult.STOP should_interrupt = self._should_interrupt(normalized) logger.debug( f"{self} interruption_gate text={text!r} normalized={normalized!r} " f"should_interrupt={should_interrupt} interim={interim}" ) if should_interrupt: await self.trigger_user_turn_started() return ProcessFrameResult.STOP await self.trigger_reset_aggregation() return ProcessFrameResult.CONTINUE def _should_interrupt(self, normalized: str) -> bool: return ( normalized in self._allowed_short_replies or len(normalized) >= self._min_chars_when_bot_speaking ) @staticmethod def _normalize_text(text: str) -> str: return "".join(_COUNTABLE_TEXT_RE.findall(text.lower()))