92 lines
3.1 KiB
Python
92 lines
3.1 KiB
Python
from __future__ import annotations
|
|
|
|
import re
|
|
|
|
from loguru import logger
|
|
from pipecat.frames.frames import (
|
|
BotStartedSpeakingFrame,
|
|
BotStoppedSpeakingFrame,
|
|
Frame,
|
|
InterimTranscriptionFrame,
|
|
TranscriptionFrame,
|
|
)
|
|
from pipecat.turns.types import ProcessFrameResult
|
|
from pipecat.turns.user_start.base_user_turn_start_strategy import BaseUserTurnStartStrategy
|
|
|
|
|
|
_COUNTABLE_TEXT_RE = re.compile(r"[\w\u4e00-\u9fff]", re.UNICODE)
|
|
|
|
|
|
class InterruptionGateUserTurnStartStrategy(BaseUserTurnStartStrategy):
    """Starts user turns only after likely intentional speech.

    When the assistant is speaking, short background speech should not barge in
    unless it is a common answer to a yes/no style question. When the assistant
    is not speaking, any non-empty transcript can start a normal user turn.

    Transcripts are measured and compared after normalization (see
    ``_normalize_text``): they are case-folded and reduced to the characters
    matched by ``_COUNTABLE_TEXT_RE``. Whitespace and punctuation therefore
    neither count toward the length threshold nor affect reply matching.

    Args:
        min_chars_when_bot_speaking: Minimum normalized length a transcript
            needs in order to start a user turn while the bot is speaking.
        allowed_short_replies: Replies that may start a user turn while the
            bot is speaking regardless of their length. They are normalized
            the same way as transcripts; entries that normalize to an empty
            string are ignored.
        use_interim: If True, interim transcriptions are evaluated in addition
            to final ones. If False, interim transcriptions are passed through
            without being evaluated.
        **kwargs: Forwarded to ``BaseUserTurnStartStrategy``.
    """

    def __init__(
        self,
        *,
        min_chars_when_bot_speaking: int,
        allowed_short_replies: list[str],
        use_interim: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._min_chars_when_bot_speaking = min_chars_when_bot_speaking
        # Normalize first, then filter. An entry such as "..." survives
        # ``strip()`` but normalizes to "", which is never a useful match key.
        normalized_replies = (self._normalize_text(reply) for reply in allowed_short_replies)
        self._allowed_short_replies = {reply for reply in normalized_replies if reply}
        self._use_interim = use_interim
        # Set/cleared by BotStartedSpeakingFrame / BotStoppedSpeakingFrame.
        self._bot_speaking = False

    async def reset(self):
        """Reset the strategy by delegating to the base class."""
        # NOTE(review): ``_bot_speaking`` is not cleared here; it changes only
        # on Bot*SpeakingFrame events. Confirm this is intended.
        await super().reset()

    async def process_frame(self, frame: Frame) -> ProcessFrameResult:
        """Track bot speaking state and gate transcriptions.

        Args:
            frame: Any frame seen by the strategy.

        Returns:
            ``ProcessFrameResult.STOP`` when this call triggered a user turn
            start, otherwise ``ProcessFrameResult.CONTINUE``.
        """
        if isinstance(frame, BotStartedSpeakingFrame):
            self._bot_speaking = True
            return ProcessFrameResult.CONTINUE

        if isinstance(frame, BotStoppedSpeakingFrame):
            self._bot_speaking = False
            return ProcessFrameResult.CONTINUE

        # Handle interim frames first and on their own. With
        # ``use_interim=False`` they must never reach the final-transcript
        # branch below, whatever the frame class hierarchy looks like.
        if isinstance(frame, InterimTranscriptionFrame):
            if not self._use_interim:
                return ProcessFrameResult.CONTINUE
            return await self._handle_transcription(frame.text, interim=True)

        if isinstance(frame, TranscriptionFrame):
            return await self._handle_transcription(frame.text, interim=False)

        return ProcessFrameResult.CONTINUE

    async def _handle_transcription(self, text: str, *, interim: bool) -> ProcessFrameResult:
        """Decide whether a transcript starts a user turn.

        Args:
            text: Raw transcript text from the frame.
            interim: Whether the transcript is interim; used only for logging.

        Returns:
            ``ProcessFrameResult.STOP`` after triggering a user turn start.
            ``ProcessFrameResult.CONTINUE`` when the transcript is empty after
            normalization, or when it was rejected while the bot was speaking.
        """
        normalized = self._normalize_text(text)
        if not normalized:
            return ProcessFrameResult.CONTINUE

        # Bot is silent: any non-empty transcript starts a turn.
        if not self._bot_speaking:
            await self.trigger_user_turn_started()
            return ProcessFrameResult.STOP

        should_interrupt = self._should_interrupt(normalized)
        logger.debug(
            f"{self} interruption_gate text={text!r} normalized={normalized!r} "
            f"should_interrupt={should_interrupt} interim={interim}"
        )

        if should_interrupt:
            await self.trigger_user_turn_started()
            return ProcessFrameResult.STOP

        # Rejected speech while the bot is talking.
        # NOTE(review): presumably this discards the rejected text from the
        # user aggregation; confirm against BaseUserTurnStartStrategy.
        await self.trigger_reset_aggregation()
        return ProcessFrameResult.CONTINUE

    def _should_interrupt(self, normalized: str) -> bool:
        """Return True if *normalized* is an allowed reply or meets the length threshold."""
        return (
            normalized in self._allowed_short_replies
            or len(normalized) >= self._min_chars_when_bot_speaking
        )

    @staticmethod
    def _normalize_text(text: str | None) -> str:
        """Case-fold *text* and keep only the characters matched by ``_COUNTABLE_TEXT_RE``.

        ``casefold()`` is used instead of ``lower()`` because it is intended
        for caseless matching of Unicode text (e.g. "ß" and "SS" both become
        "ss"). An empty or ``None`` input yields "".
        """
        if not text:
            return ""
        return "".join(_COUNTABLE_TEXT_RE.findall(text.casefold()))
|