BaseInputTransport: push UserStartedSpeakingFrame/UserStoppedSpeakingFrame upstream

This commit is contained in:
Aleix Conchillo Flaqué
2025-09-03 13:51:20 -07:00
parent 09d6ec1098
commit 5a4c6b9618
2 changed files with 27 additions and 14 deletions

View File

@@ -82,6 +82,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- `UserStartedSpeakingFrame` and `UserStoppedSpeakingFrame` are also pushed
upstream.
- `ParallelPipeline` now waits for `CancelFrame` to finish in all branches
before pushing it downstream.

View File

@@ -299,10 +299,10 @@ class BaseInputTransport(FrameProcessor):
await self.push_frame(frame, direction)
elif isinstance(frame, EmulateUserStartedSpeakingFrame):
logger.debug("Emulating user started speaking")
await self._handle_user_interruption(UserStartedSpeakingFrame(emulated=True))
await self._handle_user_interruption(VADState.SPEAKING, emulated=True)
elif isinstance(frame, EmulateUserStoppedSpeakingFrame):
logger.debug("Emulating user stopped speaking")
await self._handle_user_interruption(UserStoppedSpeakingFrame(emulated=True))
await self._handle_user_interruption(VADState.QUIET, emulated=True)
# All other system frames
elif isinstance(frame, SystemFrame):
await self.push_frame(frame, direction)
@@ -342,12 +342,16 @@ class BaseInputTransport(FrameProcessor):
await self._start_interruption()
await self.push_frame(StartInterruptionFrame())
async def _handle_user_interruption(self, frame: Frame):
async def _handle_user_interruption(self, vad_state: VADState, emulated: bool = False):
"""Handle user interruption events based on speaking state."""
if isinstance(frame, UserStartedSpeakingFrame):
if vad_state == VADState.SPEAKING:
logger.debug("User started speaking")
self._user_speaking = True
await self.push_frame(frame)
upstream_frame = UserStartedSpeakingFrame(emulated=emulated)
downstream_frame = UserStartedSpeakingFrame(emulated=emulated)
await self.push_frame(downstream_frame)
await self.push_frame(upstream_frame, FrameDirection.UPSTREAM)
# Only push StartInterruptionFrame if:
# 1. No interruption config is set, OR
@@ -368,10 +372,15 @@ class BaseInputTransport(FrameProcessor):
"User started speaking while bot is speaking with interruption config - "
"deferring interruption to aggregator"
)
elif isinstance(frame, UserStoppedSpeakingFrame):
elif vad_state == VADState.QUIET:
logger.debug("User stopped speaking")
self._user_speaking = False
await self.push_frame(frame)
upstream_frame = UserStoppedSpeakingFrame(emulated=emulated)
downstream_frame = UserStoppedSpeakingFrame(emulated=emulated)
await self.push_frame(downstream_frame)
await self.push_frame(upstream_frame, FrameDirection.UPSTREAM)
if self.interruptions_allowed:
await self._stop_interruption()
@@ -420,7 +429,8 @@ class BaseInputTransport(FrameProcessor):
and new_vad_state != VADState.STARTING
and new_vad_state != VADState.STOPPING
):
frame = None
interruption_state = None
# If the turn analyser is enabled, this will prevent:
# - Creating the UserStoppedSpeakingFrame
# - Creating the UserStartedSpeakingFrame multiple times
@@ -431,14 +441,14 @@ class BaseInputTransport(FrameProcessor):
if new_vad_state == VADState.SPEAKING:
await self.push_frame(VADUserStartedSpeakingFrame())
if can_create_user_frames:
frame = UserStartedSpeakingFrame()
interruption_state = VADState.SPEAKING
elif new_vad_state == VADState.QUIET:
await self.push_frame(VADUserStoppedSpeakingFrame())
if can_create_user_frames:
frame = UserStoppedSpeakingFrame()
interruption_state = VADState.QUIET
if frame:
await self._handle_user_interruption(frame)
if interruption_state:
await self._handle_user_interruption(interruption_state)
vad_state = new_vad_state
return vad_state
@@ -453,7 +463,7 @@ class BaseInputTransport(FrameProcessor):
async def _handle_end_of_turn_complete(self, state: EndOfTurnState):
"""Handle completion of end-of-turn analysis."""
if state == EndOfTurnState.COMPLETE:
await self._handle_user_interruption(UserStoppedSpeakingFrame())
await self._handle_user_interruption(VADState.QUIET)
async def _run_turn_analyzer(
self, frame: InputAudioRawFrame, vad_state: VADState, previous_vad_state: VADState
@@ -507,7 +517,7 @@ class BaseInputTransport(FrameProcessor):
vad_state = VADState.QUIET
if self._params.turn_analyzer:
self._params.turn_analyzer.clear()
await self._handle_user_interruption(UserStoppedSpeakingFrame())
await self._handle_user_interruption(VADState.QUIET)
async def _handle_prediction_result(self, result: MetricsData):
"""Handle a prediction result event from the turn analyzer."""