BaseInputTransport: push UserStartedSpeakingFrame/UserStoppedSpeakingFrame upstream
This commit is contained in:
@@ -82,6 +82,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Changed
|
||||
|
||||
- `UserStartedSpeakingFrame` and `UserStoppedSpeakingFrame` are also pushed
|
||||
upstream.
|
||||
|
||||
- `ParallelPipeline` now waits for `CancelFrame` to finish in all branches
|
||||
before pushing it downstream.
|
||||
|
||||
|
||||
@@ -299,10 +299,10 @@ class BaseInputTransport(FrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, EmulateUserStartedSpeakingFrame):
|
||||
logger.debug("Emulating user started speaking")
|
||||
await self._handle_user_interruption(UserStartedSpeakingFrame(emulated=True))
|
||||
await self._handle_user_interruption(VADState.SPEAKING, emulated=True)
|
||||
elif isinstance(frame, EmulateUserStoppedSpeakingFrame):
|
||||
logger.debug("Emulating user stopped speaking")
|
||||
await self._handle_user_interruption(UserStoppedSpeakingFrame(emulated=True))
|
||||
await self._handle_user_interruption(VADState.QUIET, emulated=True)
|
||||
# All other system frames
|
||||
elif isinstance(frame, SystemFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -342,12 +342,16 @@ class BaseInputTransport(FrameProcessor):
|
||||
await self._start_interruption()
|
||||
await self.push_frame(StartInterruptionFrame())
|
||||
|
||||
async def _handle_user_interruption(self, frame: Frame):
|
||||
async def _handle_user_interruption(self, vad_state: VADState, emulated: bool = False):
|
||||
"""Handle user interruption events based on speaking state."""
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
if vad_state == VADState.SPEAKING:
|
||||
logger.debug("User started speaking")
|
||||
self._user_speaking = True
|
||||
await self.push_frame(frame)
|
||||
|
||||
upstream_frame = UserStartedSpeakingFrame(emulated=emulated)
|
||||
downstream_frame = UserStartedSpeakingFrame(emulated=emulated)
|
||||
await self.push_frame(downstream_frame)
|
||||
await self.push_frame(upstream_frame, FrameDirection.UPSTREAM)
|
||||
|
||||
# Only push StartInterruptionFrame if:
|
||||
# 1. No interruption config is set, OR
|
||||
@@ -368,10 +372,15 @@ class BaseInputTransport(FrameProcessor):
|
||||
"User started speaking while bot is speaking with interruption config - "
|
||||
"deferring interruption to aggregator"
|
||||
)
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
elif vad_state == VADState.QUIET:
|
||||
logger.debug("User stopped speaking")
|
||||
self._user_speaking = False
|
||||
await self.push_frame(frame)
|
||||
|
||||
upstream_frame = UserStoppedSpeakingFrame(emulated=emulated)
|
||||
downstream_frame = UserStoppedSpeakingFrame(emulated=emulated)
|
||||
await self.push_frame(downstream_frame)
|
||||
await self.push_frame(upstream_frame, FrameDirection.UPSTREAM)
|
||||
|
||||
if self.interruptions_allowed:
|
||||
await self._stop_interruption()
|
||||
|
||||
@@ -420,7 +429,8 @@ class BaseInputTransport(FrameProcessor):
|
||||
and new_vad_state != VADState.STARTING
|
||||
and new_vad_state != VADState.STOPPING
|
||||
):
|
||||
frame = None
|
||||
interruption_state = None
|
||||
|
||||
# If the turn analyser is enabled, this will prevent:
|
||||
# - Creating the UserStoppedSpeakingFrame
|
||||
# - Creating the UserStartedSpeakingFrame multiple times
|
||||
@@ -431,14 +441,14 @@ class BaseInputTransport(FrameProcessor):
|
||||
if new_vad_state == VADState.SPEAKING:
|
||||
await self.push_frame(VADUserStartedSpeakingFrame())
|
||||
if can_create_user_frames:
|
||||
frame = UserStartedSpeakingFrame()
|
||||
interruption_state = VADState.SPEAKING
|
||||
elif new_vad_state == VADState.QUIET:
|
||||
await self.push_frame(VADUserStoppedSpeakingFrame())
|
||||
if can_create_user_frames:
|
||||
frame = UserStoppedSpeakingFrame()
|
||||
interruption_state = VADState.QUIET
|
||||
|
||||
if frame:
|
||||
await self._handle_user_interruption(frame)
|
||||
if interruption_state:
|
||||
await self._handle_user_interruption(interruption_state)
|
||||
|
||||
vad_state = new_vad_state
|
||||
return vad_state
|
||||
@@ -453,7 +463,7 @@ class BaseInputTransport(FrameProcessor):
|
||||
async def _handle_end_of_turn_complete(self, state: EndOfTurnState):
|
||||
"""Handle completion of end-of-turn analysis."""
|
||||
if state == EndOfTurnState.COMPLETE:
|
||||
await self._handle_user_interruption(UserStoppedSpeakingFrame())
|
||||
await self._handle_user_interruption(VADState.QUIET)
|
||||
|
||||
async def _run_turn_analyzer(
|
||||
self, frame: InputAudioRawFrame, vad_state: VADState, previous_vad_state: VADState
|
||||
@@ -507,7 +517,7 @@ class BaseInputTransport(FrameProcessor):
|
||||
vad_state = VADState.QUIET
|
||||
if self._params.turn_analyzer:
|
||||
self._params.turn_analyzer.clear()
|
||||
await self._handle_user_interruption(UserStoppedSpeakingFrame())
|
||||
await self._handle_user_interruption(VADState.QUIET)
|
||||
|
||||
async def _handle_prediction_result(self, result: MetricsData):
|
||||
"""Handle a prediction result event from the turn analyzer."""
|
||||
|
||||
Reference in New Issue
Block a user