Add should_interrupt + broadcast user events
This commit is contained in:
@@ -204,6 +204,7 @@ class GladiaSTTService(WebsocketSTTService):
|
||||
model: str = "solaria-1",
|
||||
params: Optional[GladiaInputParams] = None,
|
||||
max_buffer_size: int = 1024 * 1024 * 20, # 20MB default buffer
|
||||
should_interrupt: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Gladia STT service.
|
||||
@@ -222,6 +223,8 @@ class GladiaSTTService(WebsocketSTTService):
|
||||
model: Model to use for transcription. Defaults to "solaria-1".
|
||||
params: Additional configuration parameters for Gladia service.
|
||||
max_buffer_size: Maximum size of audio buffer in bytes. Defaults to 20MB.
|
||||
should_interrupt: Whether the bot should be interrupted when
|
||||
Gladia VAD detects user speech. Defaults to True.
|
||||
**kwargs: Additional arguments passed to the STTService parent class.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
@@ -270,6 +273,7 @@ class GladiaSTTService(WebsocketSTTService):
|
||||
|
||||
# VAD state tracking
|
||||
self._is_speaking = False
|
||||
self._should_interrupt = should_interrupt
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name} [{self._session_id}]"
|
||||
@@ -515,25 +519,28 @@ class GladiaSTTService(WebsocketSTTService):
|
||||
async def _on_speech_started(self):
|
||||
"""Handle speech start event from Gladia.
|
||||
|
||||
Triggers interruption and emits UserStartedSpeakingFrame when VAD is enabled.
|
||||
Broadcasts UserStartedSpeakingFrame and optionally triggers interruption
|
||||
when VAD is enabled.
|
||||
"""
|
||||
if not self._params.enable_vad or self._is_speaking:
|
||||
return
|
||||
|
||||
logger.debug(f"{self} User started speaking")
|
||||
self._is_speaking = True
|
||||
# Push interruption first to stop the bot, then notify about user speaking
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.push_frame(UserStartedSpeakingFrame())
|
||||
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
if self._should_interrupt:
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
|
||||
async def _on_speech_ended(self):
|
||||
"""Handle speech end event from Gladia.
|
||||
|
||||
Emits UserStoppedSpeakingFrame when VAD is enabled.
|
||||
Broadcasts UserStoppedSpeakingFrame when VAD is enabled.
|
||||
"""
|
||||
if not self._params.enable_vad or not self._is_speaking:
|
||||
return
|
||||
self._is_speaking = False
|
||||
await self.push_frame(UserStoppedSpeakingFrame())
|
||||
await self.broadcast_frame(UserStoppedSpeakingFrame)
|
||||
logger.debug(f"{self} User stopped speaking")
|
||||
|
||||
async def _send_audio(self, audio: bytes):
|
||||
|
||||
Reference in New Issue
Block a user