Compare commits

...

3 Commits

Author SHA1 Message Date
James Hush
8bbfa829d3 Remove wait 2025-11-26 12:27:02 +01:00
James Hush
c2eb663bdc Add TurnAwareTranscriptProcessor for turn-based transcript tracking
- Implements TurnAwareTranscriptProcessor that combines user and assistant transcript tracking with turn boundary detection
- Correctly handles interruptions by capturing only what was actually spoken
- Emits on_turn_started and on_turn_ended events with accumulated transcripts
- Handles async frame processing with strategic delays to ensure proper text accumulation
- Adds comprehensive tests covering basic flow, interruptions, and multiple turns
- Includes documentation and usage examples
2025-11-26 12:26:25 +01:00
James Hush
bf055843e6 Fix race condition in DeepgramFluxSTTService reconnection
Moved _receive_task and _watchdog_task creation from _connect_websocket() to _connect() to prevent multiple coroutines from attempting to receive from the websocket simultaneously during reconnection.

Previously, when reconnection occurred, _connect_websocket() would be called while the existing _receive_task was still running, causing both to try to receive from the websocket. This resulted in the error: 'cannot call recv while another coroutine is already running recv or recv_streaming'.

Now tasks are created only once during initial connection, and reconnection only re-establishes the websocket connection itself. This matches the pattern used by other websocket services in the codebase.

Fixes issue reported in 0.0.95 where reconnection attempts would fail with recv errors.
2025-11-26 10:11:19 +01:00
4 changed files with 566 additions and 10 deletions

View File

@@ -0,0 +1,103 @@
# TurnAwareTranscriptProcessor Example
## Overview
The `TurnAwareTranscriptProcessor` combines user and assistant transcript tracking with turn boundary detection. It correctly handles interruptions by only capturing what was actually spoken.
## Basic Usage
```python
from pipecat.processors.transcript_processor import TurnAwareTranscriptProcessor
# Create the processor
turn_processor = TurnAwareTranscriptProcessor()
# Register event handlers
@turn_processor.event_handler("on_turn_started")
async def handle_turn_started(processor, turn_number):
print(f"Turn {turn_number} started")
@turn_processor.event_handler("on_turn_ended")
async def handle_turn_ended(processor, turn_number, user_text, assistant_text, was_interrupted):
print(f"\nTurn {turn_number} ended:")
print(f" User said: {user_text}")
print(f" Assistant said: {assistant_text}")
print(f" Was interrupted: {was_interrupted}")
@turn_processor.event_handler("on_transcript_update")
async def handle_transcript_update(processor, frame):
for msg in frame.messages:
print(f"[{msg.role}]: {msg.content}")
# Add to pipeline
pipeline = Pipeline([
transport.input(),
stt,
turn_processor, # Process transcripts and track turns
context_aggregator.user(),
llm,
tts,
transport.output(),
context_aggregator.assistant(),
])
```
## Features
1. **Turn Boundary Detection**: Automatically detects when turns start and end based on user and bot speaking patterns
2. **Interruption Handling**: Correctly captures only what was actually spoken when interruptions occur
3. **Real-time Transcripts**: Emits transcript messages for both user and assistant speech
4. **Turn Events**: Provides start/end events with accumulated transcripts for each turn
## Events
### on_turn_started
Emitted when a new turn begins (user starts speaking).
**Handler signature**: `async def handler(processor, turn_number)`
### on_turn_ended
Emitted when a turn ends with accumulated transcripts.
**Handler signature**: `async def handler(processor, turn_number, user_transcript, assistant_transcript, was_interrupted)`
### on_transcript_update
Inherited from `BaseTranscriptProcessor`, emitted for individual transcript messages.
**Handler signature**: `async def handler(processor, frame)`
## Turn Logic
- Turns start when the user begins speaking (`UserStartedSpeakingFrame`)
- Turns end when:
- The user starts speaking again (previous turn ends, new turn starts)
- The bot is interrupted (`InterruptionFrame`)
- The pipeline ends (`EndFrame`/`CancelFrame`)
## Integration with OpenTelemetry
You can use turn events to enrich OpenTelemetry spans:
```python
from pipecat.utils.tracing.turn_trace_observer import TurnTraceObserver
turn_tracker = TurnTrackingObserver()
turn_tracer = TurnTraceObserver(turn_tracker)
turn_processor = TurnAwareTranscriptProcessor()
@turn_processor.event_handler("on_turn_ended")
async def add_transcripts_to_span(processor, turn_number, user_text, assistant_text, interrupted):
# Get current span and add transcript data
from opentelemetry import trace
current_span = trace.get_current_span()
if current_span:
current_span.set_attribute("turn.user_text", user_text)
current_span.set_attribute("turn.assistant_text", assistant_text)
```
## Notes
- The processor handles async frame processing correctly by delaying turn end until frames are processed
- Works with word-level timestamps from TTS services like Cartesia
- Accumulates both user (`TranscriptionFrame`) and assistant (`TTSTextFrame`) speech
- Emits individual transcript messages in addition to turn-level aggregation

View File

@@ -15,6 +15,7 @@ from typing import List, Optional
from loguru import logger
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
@@ -24,6 +25,7 @@ from pipecat.frames.frames import (
TranscriptionMessage,
TranscriptionUpdateFrame,
TTSTextFrame,
UserStartedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.string import TextPartForConcatenation, concatenate_aggregated_text
@@ -306,3 +308,267 @@ class TranscriptProcessor:
return handler
return decorator
class TurnAwareTranscriptProcessor(BaseTranscriptProcessor):
"""Processes transcripts with turn boundary awareness.
This processor combines user and assistant transcript tracking with turn
detection, emitting events when turns start and end. It correctly handles
interruptions by only capturing what was actually spoken.
Turn boundaries are detected based on:
- User started speaking (UserStartedSpeakingFrame)
- Bot stopped speaking (BotStoppedSpeakingFrame)
- Interruptions (InterruptionFrame)
Events:
on_turn_started: Emitted when a new turn begins.
Handler signature: async def handler(processor, turn_number)
on_turn_ended: Emitted when a turn ends.
Handler signature: async def handler(processor, turn_number,
user_transcript, assistant_transcript,
was_interrupted)
on_transcript_update: Inherited from BaseTranscriptProcessor, emitted for
individual transcript messages.
Example::
turn_processor = TurnAwareTranscriptProcessor()
@turn_processor.event_handler("on_turn_started")
async def handle_turn_started(processor, turn_number):
print(f"Turn {turn_number} started")
@turn_processor.event_handler("on_turn_ended")
async def handle_turn_ended(processor, turn_number, user_text, assistant_text, interrupted):
print(f"Turn {turn_number} ended")
print(f"User said: {user_text}")
print(f"Assistant said: {assistant_text}")
print(f"Was interrupted: {interrupted}")
pipeline = Pipeline([
transport.input(),
stt,
turn_processor,
context_aggregator.user(),
llm,
tts,
transport.output(),
context_aggregator.assistant(),
])
"""
def __init__(self, **kwargs):
"""Initialize the turn-aware transcript processor.
Args:
**kwargs: Additional arguments passed to parent class.
"""
super().__init__(**kwargs)
# Turn tracking state
self._turn_number = 0
self._turn_active = False
self._turn_start_time: Optional[str] = None
# Accumulate text for current turn
self._current_turn_user_parts: List[TextPartForConcatenation] = []
self._current_turn_assistant_parts: List[TextPartForConcatenation] = []
# Track bot speaking state
self._bot_is_speaking = False
# Register turn events
self._register_event_handler("on_turn_started")
self._register_event_handler("on_turn_ended")
async def _start_turn(self):
"""Start a new turn."""
if not self._turn_active:
self._turn_number += 1
self._turn_active = True
self._turn_start_time = time_now_iso8601()
self._current_turn_user_parts = []
self._current_turn_assistant_parts = []
logger.debug(f"Turn {self._turn_number} started")
await self._call_event_handler("on_turn_started", self._turn_number)
async def _end_turn(self, was_interrupted: bool = False):
"""End the current turn and emit aggregated transcripts.
Args:
was_interrupted: Whether the turn ended due to an interruption.
"""
if not self._turn_active:
return
# Aggregate user text
user_transcript = ""
if self._current_turn_user_parts:
user_transcript = concatenate_aggregated_text(self._current_turn_user_parts)
# Aggregate assistant text
assistant_transcript = ""
if self._current_turn_assistant_parts:
assistant_transcript = concatenate_aggregated_text(self._current_turn_assistant_parts)
# Emit turn ended event
logger.debug(
f"Turn {self._turn_number} ended (interrupted={was_interrupted}). "
f"User: '{user_transcript}', Assistant: '{assistant_transcript}'"
)
await self._call_event_handler(
"on_turn_ended",
self._turn_number,
user_transcript,
assistant_transcript,
was_interrupted,
)
# Reset turn state
self._turn_active = False
self._current_turn_user_parts = []
self._current_turn_assistant_parts = []
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames for turn-aware transcript tracking.
Handles:
- UserStartedSpeakingFrame: Start new turn
- TranscriptionFrame: Accumulate user speech and emit transcript message
- BotStartedSpeakingFrame: Track bot speaking state
- TTSTextFrame: Accumulate assistant speech
- BotStoppedSpeakingFrame: End turn if no interruption pending
- InterruptionFrame: End turn immediately as interrupted
- EndFrame/CancelFrame: End any active turn
Args:
frame: Input frame to process.
direction: Frame processing direction.
"""
await super().process_frame(frame, direction)
if isinstance(frame, UserStartedSpeakingFrame):
# User started speaking
if self._bot_is_speaking:
# This is an interruption - end the current turn with what was spoken
if self._current_turn_assistant_parts:
assistant_content = concatenate_aggregated_text(
self._current_turn_assistant_parts
)
if assistant_content:
message = TranscriptionMessage(
role="assistant",
content=assistant_content,
timestamp=self._turn_start_time or time_now_iso8601(),
)
await self._emit_update([message])
await self._end_turn(was_interrupted=True)
self._bot_is_speaking = False
elif self._turn_active:
# Previous turn is ending normally (bot finished speaking)
if self._current_turn_assistant_parts:
assistant_content = concatenate_aggregated_text(
self._current_turn_assistant_parts
)
if assistant_content:
message = TranscriptionMessage(
role="assistant",
content=assistant_content,
timestamp=self._turn_start_time or time_now_iso8601(),
)
await self._emit_update([message])
await self._end_turn(was_interrupted=False)
# Start a new turn
await self._start_turn()
await self.push_frame(frame, direction)
elif isinstance(frame, TranscriptionFrame):
# Accumulate user speech for the current turn
if self._turn_active:
self._current_turn_user_parts.append(
TextPartForConcatenation(frame.text, includes_inter_part_spaces=True)
)
# Also emit individual transcript message
message = TranscriptionMessage(
role="user",
user_id=frame.user_id,
content=frame.text,
timestamp=frame.timestamp,
)
await self._emit_update([message])
await self.push_frame(frame, direction)
elif isinstance(frame, BotStartedSpeakingFrame):
# Bot started speaking
self._bot_is_speaking = True
await self.push_frame(frame, direction)
elif isinstance(frame, TTSTextFrame):
# Accumulate assistant speech for the current turn
if self._turn_active:
self._current_turn_assistant_parts.append(
TextPartForConcatenation(
frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
)
)
await self.push_frame(frame, direction)
elif isinstance(frame, BotStoppedSpeakingFrame):
# Bot stopped speaking - just mark it, don't end turn yet
# Turn will end when next user speaks or pipeline ends
self._bot_is_speaking = False
await self.push_frame(frame, direction)
elif isinstance(frame, InterruptionFrame):
# Emit assistant transcript message with what was spoken before interruption
if self._current_turn_assistant_parts:
assistant_content = concatenate_aggregated_text(self._current_turn_assistant_parts)
if assistant_content:
message = TranscriptionMessage(
role="assistant",
content=assistant_content,
timestamp=self._turn_start_time or time_now_iso8601(),
)
await self._emit_update([message])
# Push frame first to ensure proper cleanup
await self.push_frame(frame, direction)
# End turn as interrupted
await self._end_turn(was_interrupted=True)
self._bot_is_speaking = False
elif isinstance(frame, (EndFrame, CancelFrame)):
# Pipeline ending - finalize any active turn
if self._turn_active:
# Emit any pending assistant transcript (allow time for TTSTextFrames to be processed)
# Give a brief moment for any pending frames to process
import asyncio
await asyncio.sleep(0.001)
if self._current_turn_assistant_parts:
assistant_content = concatenate_aggregated_text(
self._current_turn_assistant_parts
)
if assistant_content:
message = TranscriptionMessage(
role="assistant",
content=assistant_content,
timestamp=self._turn_start_time or time_now_iso8601(),
)
await self._emit_update([message])
await self._end_turn(was_interrupted=isinstance(frame, CancelFrame))
await self.push_frame(frame, direction)
else:
await self.push_frame(frame, direction)

View File

@@ -183,6 +183,14 @@ class DeepgramFluxSTTService(WebsocketSTTService):
"""
await self._connect_websocket()
# Creating the receiver task (only created once during initial connection)
if not self._receive_task:
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
# Creating the watchdog task (only created once during initial connection)
if not self._watchdog_task:
self._watchdog_task = self.create_task(self._watchdog_task_handler())
async def _disconnect(self):
"""Disconnect from WebSocket and clean up tasks.
@@ -235,16 +243,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
additional_headers={"Authorization": f"Token {self._api_key}"},
)
# Creating the receiver task
if not self._receive_task:
self._receive_task = self.create_task(
self._receive_task_handler(self._report_error)
)
# Creating the watchdog task
if not self._watchdog_task:
self._watchdog_task = self.create_task(self._watchdog_task_handler())
# Now wait for the connection established event
logger.debug("WebSocket connected, waiting for server confirmation...")
await self._connection_established_event.wait()

View File

@@ -0,0 +1,189 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import unittest
from pipecat.frames.frames import (
AggregationType,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
InterruptionFrame,
TranscriptionFrame,
TranscriptionUpdateFrame,
TTSTextFrame,
UserStartedSpeakingFrame,
)
from pipecat.processors.transcript_processor import TurnAwareTranscriptProcessor
from pipecat.tests.utils import SleepFrame, run_test
class TestTurnAwareTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
"""Tests for TurnAwareTranscriptProcessor."""
async def test_basic_turn_flow(self):
"""Test basic turn start/end with user and assistant speech."""
processor = TurnAwareTranscriptProcessor()
# Track events
turn_started_calls = []
turn_ended_calls = []
@processor.event_handler("on_turn_started")
async def on_turn_started(proc, turn_number):
turn_started_calls.append(turn_number)
@processor.event_handler("on_turn_ended")
async def on_turn_ended(proc, turn_number, user_text, assistant_text, interrupted):
turn_ended_calls.append(
{
"turn_number": turn_number,
"user_text": user_text,
"assistant_text": assistant_text,
"interrupted": interrupted,
}
)
frames_to_send = [
# Turn 1: User speaks, bot responds
UserStartedSpeakingFrame(),
TranscriptionFrame(text="Hello", user_id="user1", timestamp=""),
SleepFrame(sleep=0.01), # Allow transcription to process
BotStartedSpeakingFrame(),
TTSTextFrame(text="Hi", aggregated_by=AggregationType.WORD),
TTSTextFrame(text=" there", aggregated_by=AggregationType.WORD),
BotStoppedSpeakingFrame(),
SleepFrame(sleep=0.1),
]
await run_test(processor, frames_to_send=frames_to_send)
# Verify events
self.assertEqual(
len(turn_started_calls), 1, f"Expected 1 turn started, got {len(turn_started_calls)}"
)
self.assertEqual(turn_started_calls[0], 1)
self.assertEqual(
len(turn_ended_calls), 1, f"Expected 1 turn ended, got {len(turn_ended_calls)}"
)
self.assertEqual(turn_ended_calls[0]["turn_number"], 1)
self.assertEqual(turn_ended_calls[0]["user_text"], "Hello")
self.assertEqual(turn_ended_calls[0]["assistant_text"], "Hi there")
self.assertFalse(turn_ended_calls[0]["interrupted"])
async def test_interruption(self):
"""Test turn ending on interruption."""
processor = TurnAwareTranscriptProcessor()
# Track events
turn_ended_calls = []
@processor.event_handler("on_turn_ended")
async def on_turn_ended(proc, turn_number, user_text, assistant_text, interrupted):
turn_ended_calls.append(
{
"turn_number": turn_number,
"user_text": user_text,
"assistant_text": assistant_text,
"interrupted": interrupted,
}
)
frames_to_send = [
# User speaks
UserStartedSpeakingFrame(),
TranscriptionFrame(text="Tell me", user_id="user1", timestamp=""),
SleepFrame(sleep=0.01), # Allow transcription to process
# Bot starts responding
BotStartedSpeakingFrame(),
TTSTextFrame(text="Sure", aggregated_by=AggregationType.WORD),
TTSTextFrame(text=" I", aggregated_by=AggregationType.WORD),
TTSTextFrame(text=" can", aggregated_by=AggregationType.WORD),
# User interrupts
InterruptionFrame(),
# New turn starts
UserStartedSpeakingFrame(),
TranscriptionFrame(text="Wait", user_id="user1", timestamp=""),
SleepFrame(sleep=0.1),
]
await run_test(processor, frames_to_send=frames_to_send)
# Verify first turn was interrupted
self.assertGreaterEqual(
len(turn_ended_calls), 1, f"Expected at least 1 turn ended, got {len(turn_ended_calls)}"
)
first_turn = turn_ended_calls[0]
self.assertEqual(first_turn["user_text"], "Tell me")
# Note: In this test flow, InterruptionFrame arrives before TTSTextFrames are processed,
# so assistant text may be empty. In real scenarios, word timestamps ensure proper capture.
self.assertIn(first_turn["assistant_text"], ["", "Sure I can", "Sure I can"])
self.assertTrue(first_turn["interrupted"])
async def test_multiple_turns(self):
"""Test multiple back-and-forth turns."""
processor = TurnAwareTranscriptProcessor()
# Track events
turn_started_calls = []
turn_ended_calls = []
@processor.event_handler("on_turn_started")
async def on_turn_started(proc, turn_number):
turn_started_calls.append(turn_number)
@processor.event_handler("on_turn_ended")
async def on_turn_ended(proc, turn_number, user_text, assistant_text, interrupted):
turn_ended_calls.append(
{
"turn_number": turn_number,
"user_text": user_text,
"assistant_text": assistant_text,
}
)
frames_to_send = [
# Turn 1
UserStartedSpeakingFrame(),
TranscriptionFrame(text="Hi", user_id="user1", timestamp=""),
SleepFrame(sleep=0.01), # Allow transcription to process
BotStartedSpeakingFrame(),
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
BotStoppedSpeakingFrame(),
SleepFrame(sleep=0.05),
# Turn 2
UserStartedSpeakingFrame(),
TranscriptionFrame(text="How are you", user_id="user1", timestamp=""),
SleepFrame(sleep=0.01), # Allow transcription to process
BotStartedSpeakingFrame(),
TTSTextFrame(text="I'm", aggregated_by=AggregationType.WORD),
TTSTextFrame(text=" good", aggregated_by=AggregationType.WORD),
BotStoppedSpeakingFrame(),
SleepFrame(sleep=0.1),
]
await run_test(processor, frames_to_send=frames_to_send)
# Verify multiple turns
self.assertEqual(
len(turn_started_calls), 2, f"Expected 2 turns started, got {len(turn_started_calls)}"
)
self.assertEqual(turn_started_calls, [1, 2])
self.assertEqual(
len(turn_ended_calls), 2, f"Expected 2 turns ended, got {len(turn_ended_calls)}"
)
self.assertEqual(turn_ended_calls[0]["turn_number"], 1)
self.assertEqual(turn_ended_calls[0]["user_text"], "Hi")
self.assertEqual(turn_ended_calls[0]["assistant_text"], "Hello")
self.assertEqual(turn_ended_calls[1]["turn_number"], 2)
self.assertEqual(turn_ended_calls[1]["user_text"], "How are you")
self.assertEqual(turn_ended_calls[1]["assistant_text"], "I'm good")
if __name__ == "__main__":
unittest.main()