Add gated turn start strategy
This commit is contained in:
10
README.md
10
README.md
@@ -154,6 +154,9 @@ before declaring the turn done) is configurable per environment:
|
||||
"stop_secs": 0.6,
|
||||
"min_volume": 0.6
|
||||
},
|
||||
"interruption_min_chars": 3,
|
||||
"interruption_use_interim": true,
|
||||
"interruption_short_replies": ["是的", "行", "可以"],
|
||||
"user_speech_timeout_sec": 1.0
|
||||
}
|
||||
```
|
||||
@@ -162,6 +165,13 @@ before declaring the turn done) is configurable per environment:
|
||||
controls the Silero VAD. `stop_secs` is the duration of silence required
|
||||
before VAD reports the user stopped speaking; raise it if VAD is
|
||||
cutting users off mid-clause, lower it for snappier turn-taking.
|
||||
- `interruption_min_chars`, `interruption_use_interim`, and
|
||||
`interruption_short_replies` configure the custom turn-start gate used
|
||||
while the assistant is speaking. Short replies in the allowlist (for
|
||||
example, `是的`, `行`, `可以`) can barge in immediately; other text must
|
||||
contain at least `interruption_min_chars` countable characters after
|
||||
punctuation and spaces are removed. This keeps common yes/no answers while
|
||||
filtering brief background speech.
|
||||
- `user_speech_timeout_sec` is the additional grace window (used by
|
||||
`SpeechTimeoutUserTurnStopStrategy`) during which the user may resume
|
||||
speaking before the aggregator finalizes the turn. The timer is
|
||||
|
||||
21
config.json
21
config.json
@@ -21,6 +21,27 @@
|
||||
"stop_secs": 0.4,
|
||||
"min_volume": 0.6
|
||||
},
|
||||
"interruption_min_chars": 3,
|
||||
"interruption_use_interim": true,
|
||||
"interruption_short_replies": [
|
||||
"是",
|
||||
"是的",
|
||||
"对",
|
||||
"对的",
|
||||
"嗯",
|
||||
"好",
|
||||
"好的",
|
||||
"行",
|
||||
"可以",
|
||||
"没问题",
|
||||
"不是",
|
||||
"不",
|
||||
"不行",
|
||||
"不用",
|
||||
"不要",
|
||||
"没有",
|
||||
"否"
|
||||
],
|
||||
"user_speech_timeout_sec": 0.8
|
||||
},
|
||||
"agent": {
|
||||
|
||||
@@ -19,6 +19,27 @@
|
||||
"stop_secs": 0.6,
|
||||
"min_volume": 0.6
|
||||
},
|
||||
"interruption_min_chars": 3,
|
||||
"interruption_use_interim": true,
|
||||
"interruption_short_replies": [
|
||||
"是",
|
||||
"是的",
|
||||
"对",
|
||||
"对的",
|
||||
"嗯",
|
||||
"好",
|
||||
"好的",
|
||||
"行",
|
||||
"可以",
|
||||
"没问题",
|
||||
"不是",
|
||||
"不",
|
||||
"不行",
|
||||
"不用",
|
||||
"不要",
|
||||
"没有",
|
||||
"否"
|
||||
],
|
||||
"user_speech_timeout_sec": 1.0
|
||||
},
|
||||
"agent": {
|
||||
|
||||
@@ -61,6 +61,33 @@ class TurnConfig:
|
||||
|
||||
vad: VADConfig = field(default_factory=VADConfig)
|
||||
user_speech_timeout_sec: float = 1.0
|
||||
interruption_min_chars: int = 3
|
||||
interruption_use_interim: bool = True
|
||||
interruption_short_replies: list[str] = field(
|
||||
default_factory=lambda: [
|
||||
"是",
|
||||
"是的",
|
||||
"对",
|
||||
"对的",
|
||||
"嗯",
|
||||
"好",
|
||||
"好的",
|
||||
"行",
|
||||
"可以",
|
||||
"没问题",
|
||||
"不是",
|
||||
"不",
|
||||
"不行",
|
||||
"不用",
|
||||
"不要",
|
||||
"没有",
|
||||
"否",
|
||||
"no",
|
||||
"yes",
|
||||
"ok",
|
||||
"okay",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -165,6 +192,18 @@ def config_from_dict(data: dict) -> EngineConfig:
|
||||
user_speech_timeout_sec=float(
|
||||
turn.get("user_speech_timeout_sec", TurnConfig().user_speech_timeout_sec)
|
||||
),
|
||||
interruption_min_chars=int(
|
||||
turn.get("interruption_min_chars", TurnConfig().interruption_min_chars)
|
||||
),
|
||||
interruption_use_interim=bool(
|
||||
turn.get("interruption_use_interim", TurnConfig().interruption_use_interim)
|
||||
),
|
||||
interruption_short_replies=list(
|
||||
turn.get(
|
||||
"interruption_short_replies",
|
||||
TurnConfig().interruption_short_replies,
|
||||
)
|
||||
),
|
||||
),
|
||||
agent=AgentConfig(**agent),
|
||||
services=ServicesConfig(
|
||||
|
||||
@@ -36,6 +36,7 @@ from .services import create_llm_service, create_stt_service, create_tts_service
|
||||
from .text_input import ProductTextInputProcessor
|
||||
from .text_stream import ProductTextStreamProcessor
|
||||
from .transcript_stream import ProductTranscriptStreamProcessor
|
||||
from .turn_start import InterruptionGateUserTurnStartStrategy
|
||||
|
||||
|
||||
async def run_voice_pipeline(websocket, config: EngineConfig) -> None:
|
||||
@@ -104,6 +105,13 @@ async def run_pipeline_with_serializer(
|
||||
# (re-armed every time the user resumes speaking) before declaring the
|
||||
# turn finished — which is what we actually want for streaming ASRs.
|
||||
user_turn_strategies = UserTurnStrategies(
|
||||
start=[
|
||||
InterruptionGateUserTurnStartStrategy(
|
||||
min_chars_when_bot_speaking=config.turn.interruption_min_chars,
|
||||
allowed_short_replies=config.turn.interruption_short_replies,
|
||||
use_interim=config.turn.interruption_use_interim,
|
||||
),
|
||||
],
|
||||
stop=[
|
||||
SpeechTimeoutUserTurnStopStrategy(
|
||||
user_speech_timeout=config.turn.user_speech_timeout_sec,
|
||||
|
||||
91
engine/turn_start.py
Normal file
91
engine/turn_start.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.turns.types import ProcessFrameResult
|
||||
from pipecat.turns.user_start.base_user_turn_start_strategy import BaseUserTurnStartStrategy
|
||||
|
||||
|
||||
_COUNTABLE_TEXT_RE = re.compile(r"[\w\u4e00-\u9fff]", re.UNICODE)
|
||||
|
||||
|
||||
class InterruptionGateUserTurnStartStrategy(BaseUserTurnStartStrategy):
|
||||
"""Starts user turns only after likely intentional speech.
|
||||
|
||||
When the assistant is speaking, short background speech should not barge in
|
||||
unless it is a common answer to a yes/no style question. When the assistant
|
||||
is not speaking, any non-empty transcript can start a normal user turn.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
min_chars_when_bot_speaking: int,
|
||||
allowed_short_replies: list[str],
|
||||
use_interim: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._min_chars_when_bot_speaking = min_chars_when_bot_speaking
|
||||
self._allowed_short_replies = {
|
||||
self._normalize_text(reply) for reply in allowed_short_replies if reply.strip()
|
||||
}
|
||||
self._use_interim = use_interim
|
||||
self._bot_speaking = False
|
||||
|
||||
async def reset(self):
|
||||
await super().reset()
|
||||
|
||||
async def process_frame(self, frame: Frame) -> ProcessFrameResult:
|
||||
if isinstance(frame, BotStartedSpeakingFrame):
|
||||
self._bot_speaking = True
|
||||
return ProcessFrameResult.CONTINUE
|
||||
if isinstance(frame, BotStoppedSpeakingFrame):
|
||||
self._bot_speaking = False
|
||||
return ProcessFrameResult.CONTINUE
|
||||
if isinstance(frame, InterimTranscriptionFrame) and self._use_interim:
|
||||
return await self._handle_transcription(frame.text, interim=True)
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
return await self._handle_transcription(frame.text, interim=False)
|
||||
|
||||
return ProcessFrameResult.CONTINUE
|
||||
|
||||
async def _handle_transcription(self, text: str, *, interim: bool) -> ProcessFrameResult:
|
||||
normalized = self._normalize_text(text)
|
||||
if not normalized:
|
||||
return ProcessFrameResult.CONTINUE
|
||||
|
||||
if not self._bot_speaking:
|
||||
await self.trigger_user_turn_started()
|
||||
return ProcessFrameResult.STOP
|
||||
|
||||
should_interrupt = self._should_interrupt(normalized)
|
||||
logger.debug(
|
||||
f"{self} interruption_gate text={text!r} normalized={normalized!r} "
|
||||
f"should_interrupt={should_interrupt} interim={interim}"
|
||||
)
|
||||
|
||||
if should_interrupt:
|
||||
await self.trigger_user_turn_started()
|
||||
return ProcessFrameResult.STOP
|
||||
|
||||
await self.trigger_reset_aggregation()
|
||||
return ProcessFrameResult.CONTINUE
|
||||
|
||||
def _should_interrupt(self, normalized: str) -> bool:
|
||||
return (
|
||||
normalized in self._allowed_short_replies
|
||||
or len(normalized) >= self._min_chars_when_bot_speaking
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_text(text: str) -> str:
|
||||
return "".join(_COUNTABLE_TEXT_RE.findall(text.lower()))
|
||||
Reference in New Issue
Block a user