Files
pipecat/tests/test_transcript_processor.py
mattie ruth backman fe9aa3383e Adding support for new bot-output RTVI Message:
1. TTSTextFrames now include metadata about whether the text was spoken,
   along with a type string describing what the text represents,
   e.g. "sentence", "word", "custom aggregation".
2. Expanded how aggregators work so that the aggregate method returns
   aggregated text along with the type of aggregation used to create it
3. Deprecated the RTVI bot-transcription event in favor of the new
   bot-output event (see 4).
4. Introduced support for a new bot-output event. This event is meant
   to be the one-stop shop for communicating what the bot actually "says".
   It is based on TTSTextFrames and communicates both sentence by sentence
   (or whatever aggregation is used) and word by word. In addition,
   it will include LLMTextFrames, aggregated by sentence, when TTS is
   turned off (i.e., skip_tts is true).

Resolves pipecat-ai/pipecat-client-web#158
2025-11-14 13:51:45 -05:00

487 lines
18 KiB
Python

#
# Copyright (c) 2024-2025 Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import unittest
from datetime import datetime, timezone
from typing import List, Tuple, cast
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
InterruptionFrame,
TranscriptionFrame,
TranscriptionMessage,
TranscriptionUpdateFrame,
TTSTextFrame,
)
from pipecat.processors.transcript_processor import (
AssistantTranscriptProcessor,
UserTranscriptProcessor,
)
from pipecat.tests.utils import SleepFrame, run_test
class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
    """Tests for the transcript processors.

    Despite the class name, this suite covers ``UserTranscriptProcessor``,
    ``AssistantTranscriptProcessor`` and the ``TranscriptProcessor`` factory.

    Frame-ordering notes that apply to the assistant tests below:
    ``BotStartedSpeakingFrame`` / ``BotStoppedSpeakingFrame`` are system frames
    handled immediately, while ``TTSTextFrame``s are queued. A ``SleepFrame``
    gives the queued frames time to be processed before the next frame is sent;
    without one, ``BotStoppedSpeakingFrame`` overtakes the queued text frames
    downstream.
    """

    async def test_basic_transcription(self):
        """A single TranscriptionFrame yields one user TranscriptionUpdateFrame."""
        processor = UserTranscriptProcessor()
        timestamp = datetime.now(timezone.utc).isoformat()

        frames_to_send = [
            TranscriptionFrame(text="Hello, world!", user_id="test_user", timestamp=timestamp)
        ]

        # The processor pushes the TranscriptionUpdateFrame *before* passing the
        # original TranscriptionFrame through.
        expected_down_frames = [TranscriptionUpdateFrame, TranscriptionFrame]

        received_frames, _ = await run_test(
            processor,
            frames_to_send=frames_to_send,
            expected_down_frames=expected_down_frames,
        )

        # Given the order above, the update is the first downstream frame.
        update_frame = cast(TranscriptionUpdateFrame, received_frames[0])
        self.assertIsInstance(update_frame, TranscriptionUpdateFrame)
        self.assertEqual(len(update_frame.messages), 1)

        message = update_frame.messages[0]
        self.assertEqual(message.role, "user")
        self.assertEqual(message.content, "Hello, world!")
        self.assertEqual(message.user_id, "test_user")
        self.assertEqual(message.timestamp, timestamp)

    async def test_event_handler(self):
        """on_transcript_update fires once per TranscriptionFrame with its message."""
        processor = UserTranscriptProcessor()

        # Messages collected by the event handler, in arrival order.
        received_updates: List[TranscriptionMessage] = []

        @processor.event_handler("on_transcript_update")
        async def handle_update(proc, frame: TranscriptionUpdateFrame):
            received_updates.extend(frame.messages)

        timestamp = datetime.now(timezone.utc).isoformat()
        frames_to_send = [
            TranscriptionFrame(text="First message", user_id="test_user", timestamp=timestamp),
            TranscriptionFrame(text="Second message", user_id="test_user", timestamp=timestamp),
        ]

        # Each TranscriptionFrame is preceded downstream by its own update.
        expected_down_frames = [
            TranscriptionUpdateFrame,
            TranscriptionFrame,  # First message
            TranscriptionUpdateFrame,
            TranscriptionFrame,  # Second message
        ]

        await run_test(
            processor,
            frames_to_send=frames_to_send,
            expected_down_frames=expected_down_frames,
        )

        self.assertEqual(len(received_updates), 2)

        first, second = received_updates
        self.assertEqual(first.role, "user")
        self.assertEqual(first.content, "First message")
        self.assertEqual(first.timestamp, timestamp)

        self.assertEqual(second.role, "user")
        self.assertEqual(second.content, "Second message")
        self.assertEqual(second.timestamp, timestamp)

    async def test_text_aggregation(self):
        """Word-level TTSTextFrames are joined into a single assistant message."""
        processor = AssistantTranscriptProcessor()

        received_updates: List[TranscriptionUpdateFrame] = []

        @processor.event_handler("on_transcript_update")
        async def handle_update(proc, frame: TranscriptionUpdateFrame):
            received_updates.append(frame)

        # Simulate the bot speaking several word-level text chunks.
        frames_to_send = [
            BotStartedSpeakingFrame(),
            SleepFrame(),  # Wait for StartedSpeaking to process
            TTSTextFrame(text="Hello", aggregated_by="word"),
            TTSTextFrame(text="world!", aggregated_by="word"),
            TTSTextFrame(text="How", aggregated_by="word"),
            TTSTextFrame(text="are", aggregated_by="word"),
            TTSTextFrame(text="you?", aggregated_by="word"),
            SleepFrame(),  # Let the queued text frames go through before stopping
            BotStoppedSpeakingFrame(),
        ]

        # Expected downstream order:
        # 1. BotStartedSpeakingFrame (system frame, handled immediately)
        # 2. The five queued TTSTextFrames (drained thanks to the SleepFrame)
        # 3. TranscriptionUpdateFrame with the aggregated text -- pushed *before*
        #    the BotStoppedSpeakingFrame that ends the utterance
        # 4. BotStoppedSpeakingFrame
        expected_down_frames = [
            BotStartedSpeakingFrame,
            TTSTextFrame,
            TTSTextFrame,
            TTSTextFrame,
            TTSTextFrame,
            TTSTextFrame,
            TranscriptionUpdateFrame,
            BotStoppedSpeakingFrame,
        ]

        received_frames, _ = await run_test(
            processor,
            frames_to_send=frames_to_send,
            expected_down_frames=expected_down_frames,
        )

        # The event handler saw exactly one update carrying one message.
        self.assertEqual(len(received_updates), 1)
        update_frame = received_updates[0]
        self.assertEqual(len(update_frame.messages), 1)

        message = update_frame.messages[0]
        self.assertEqual(message.role, "assistant")
        self.assertEqual(message.content, "Hello world! How are you?")
        self.assertIsNotNone(message.timestamp)

        # Per the expected order, the downstream update is second-to-last,
        # just ahead of the trailing BotStoppedSpeakingFrame.
        downstream_update = cast(TranscriptionUpdateFrame, received_frames[-2])
        self.assertEqual(downstream_update.messages[0].content, "Hello world! How are you?")

    async def test_empty_text_handling(self):
        """Empty or whitespace-only text never produces a transcript update."""
        processor = AssistantTranscriptProcessor()

        received_updates: List[TranscriptionUpdateFrame] = []

        @processor.event_handler("on_transcript_update")
        async def handle_update(proc, frame: TranscriptionUpdateFrame):
            received_updates.append(frame)

        frames_to_send = [
            BotStartedSpeakingFrame(),
            SleepFrame(),
            TTSTextFrame(text="", aggregated_by="word"),  # Empty text
            TTSTextFrame(text=" ", aggregated_by="word"),  # Just whitespace
            TTSTextFrame(text="\n", aggregated_by="word"),  # Just newline
            BotStoppedSpeakingFrame(),
            # Pipeline ends here; run_test will automatically send EndFrame
        ]

        # No SleepFrame precedes BotStoppedSpeakingFrame, so that system frame
        # overtakes the still-queued TTSTextFrames downstream.
        expected_down_frames = [
            BotStartedSpeakingFrame,
            BotStoppedSpeakingFrame,
            TTSTextFrame,  # empty
            TTSTextFrame,  # whitespace
            TTSTextFrame,  # newline
            # No TranscriptionUpdateFrame since content is empty after stripping
        ]
        # NOTE(review): because the stop frame arrives ahead of the text, the
        # empty aggregation is presumably only evaluated when the automatic
        # EndFrame flushes it. A SleepFrame before BotStoppedSpeakingFrame would
        # also exercise the stop-speaking path (expected order would change).

        await run_test(
            processor,
            frames_to_send=frames_to_send,
            expected_down_frames=expected_down_frames,
        )

        self.assertEqual(len(received_updates), 0, "No updates should be emitted for empty content")

    async def test_interruption_handling(self):
        """An interruption flushes the partial message; the next utterance starts fresh."""
        processor = AssistantTranscriptProcessor()

        received_updates: List[TranscriptionUpdateFrame] = []

        @processor.event_handler("on_transcript_update")
        async def handle_update(proc, frame: TranscriptionUpdateFrame):
            received_updates.append(frame)

        # Simulate the bot being interrupted mid-sentence, then responding again.
        frames_to_send = [
            BotStartedSpeakingFrame(),
            SleepFrame(),
            TTSTextFrame(text="Hello", aggregated_by="word"),
            TTSTextFrame(text="world!", aggregated_by="word"),
            SleepFrame(),
            InterruptionFrame(),  # User interrupts here
            SleepFrame(),
            BotStartedSpeakingFrame(),
            TTSTextFrame(text="New", aggregated_by="word"),
            TTSTextFrame(text="response", aggregated_by="word"),
            SleepFrame(),
            BotStoppedSpeakingFrame(),
        ]

        expected_down_frames = [
            BotStartedSpeakingFrame,
            TTSTextFrame,  # "Hello"
            TTSTextFrame,  # "world!"
            InterruptionFrame,
            TranscriptionUpdateFrame,  # First message (emitted due to interruption)
            BotStartedSpeakingFrame,
            TTSTextFrame,  # "New"
            TTSTextFrame,  # "response"
            TranscriptionUpdateFrame,  # Second message
            BotStoppedSpeakingFrame,
        ]

        await run_test(
            processor,
            frames_to_send=frames_to_send,
            expected_down_frames=expected_down_frames,
        )

        # One update for the interrupted message, one for the new response.
        self.assertEqual(len(received_updates), 2)

        first_message = received_updates[0].messages[0]
        self.assertEqual(first_message.role, "assistant")
        self.assertEqual(first_message.content, "Hello world!")
        self.assertIsNotNone(first_message.timestamp)

        second_message = received_updates[1].messages[0]
        self.assertEqual(second_message.role, "assistant")
        self.assertEqual(second_message.content, "New response")
        self.assertIsNotNone(second_message.timestamp)

        # The two utterances must not share a timestamp.
        self.assertNotEqual(first_message.timestamp, second_message.timestamp)

    async def test_end_frame_handling(self):
        """Pending text is emitted when the pipeline ends normally (EndFrame)."""
        processor = AssistantTranscriptProcessor()

        received_updates: List[TranscriptionUpdateFrame] = []

        @processor.event_handler("on_transcript_update")
        async def handle_update(proc, frame: TranscriptionUpdateFrame):
            received_updates.append(frame)

        # Note: no BotStoppedSpeakingFrame -- only the EndFrame can flush the text.
        frames_to_send = [
            BotStartedSpeakingFrame(),
            SleepFrame(),
            TTSTextFrame(text="Hello", aggregated_by="word"),
            TTSTextFrame(text="world", aggregated_by="word"),
            # Pipeline ends here; run_test will automatically send EndFrame
        ]

        expected_down_frames = [
            BotStartedSpeakingFrame,
            TTSTextFrame,
            TTSTextFrame,
            TranscriptionUpdateFrame,  # Final message emitted due to EndFrame
        ]

        # EndFrame is sent automatically by run_test.
        await run_test(
            processor,
            frames_to_send=frames_to_send,
            expected_down_frames=expected_down_frames,
        )

        self.assertEqual(len(received_updates), 1)
        message = received_updates[0].messages[0]
        self.assertEqual(message.role, "assistant")
        self.assertEqual(message.content, "Hello world")

    async def test_cancel_frame_handling(self):
        """Pending text is still reported when the pipeline is cancelled."""
        processor = AssistantTranscriptProcessor()

        # (content, event-loop time) per update. Only the content is asserted
        # below; the receive time is recorded but not checked.
        received_updates: List[Tuple[str, float]] = []

        @processor.event_handler("on_transcript_update")
        async def handle_update(proc, frame: TranscriptionUpdateFrame):
            # get_running_loop() is the preferred way to reach the loop from
            # inside a coroutine (get_event_loop() is discouraged there).
            received_updates.append((frame.messages[0].content, asyncio.get_running_loop().time()))

        frames_to_send = [
            BotStartedSpeakingFrame(),
            SleepFrame(),
            TTSTextFrame(text="Hello", aggregated_by="word"),
            TTSTextFrame(text="world", aggregated_by="word"),
            SleepFrame(),  # Ensure messages are processed
            CancelFrame(),
        ]

        # These downstream frame types are still checked by run_test. The update
        # event fires on cancellation (asserted below), but no
        # TranscriptionUpdateFrame is expected downstream.
        expected_down_frames = [
            BotStartedSpeakingFrame,
            TTSTextFrame,
            TTSTextFrame,
            CancelFrame,
        ]

        # The CancelFrame terminates the pipeline, so no EndFrame is sent.
        await run_test(
            processor,
            frames_to_send=frames_to_send,
            expected_down_frames=expected_down_frames,
            send_end_frame=False,
        )

        self.assertEqual(len(received_updates), 1, "Should receive one update before cancellation")
        content, _ = received_updates[0]
        self.assertEqual(content, "Hello world")

    async def test_transcript_processor_factory(self):
        """The factory reuses its processors and routes both to one event handler."""
        from pipecat.processors.transcript_processor import TranscriptProcessor

        factory = TranscriptProcessor()
        received_updates: List[TranscriptionMessage] = []

        # A handler registered on the factory should see updates from both sides.
        @factory.event_handler("on_transcript_update")
        async def handle_update(proc, frame: TranscriptionUpdateFrame):
            received_updates.extend(frame.messages)

        # Repeated accessor calls must return the same processor instance.
        user_proc1 = factory.user()
        user_proc2 = factory.user()
        self.assertIs(user_proc1, user_proc2, "User processor should be reused")

        asst_proc1 = factory.assistant()
        asst_proc2 = factory.assistant()
        self.assertIs(asst_proc1, asst_proc2, "Assistant processor should be reused")

        # --- User side ---
        timestamp = datetime.now(timezone.utc).isoformat()
        frames_to_send = [
            TranscriptionFrame(text="User message", user_id="user1", timestamp=timestamp)
        ]
        await run_test(
            user_proc1,
            frames_to_send=frames_to_send,
            expected_down_frames=[TranscriptionUpdateFrame, TranscriptionFrame],
        )

        # --- Assistant side ---
        frames_to_send = [
            BotStartedSpeakingFrame(),
            SleepFrame(),
            TTSTextFrame(text="Assistant", aggregated_by="word"),
            TTSTextFrame(text="message", aggregated_by="word"),
            BotStoppedSpeakingFrame(),
        ]
        # No SleepFrame precedes BotStoppedSpeakingFrame, so it overtakes the
        # queued TTSTextFrames and the update only appears at the very end.
        await run_test(
            asst_proc1,
            frames_to_send=frames_to_send,
            expected_down_frames=[
                BotStartedSpeakingFrame,
                BotStoppedSpeakingFrame,
                TTSTextFrame,
                TTSTextFrame,
                TranscriptionUpdateFrame,
            ],
        )

        # Both processors fed the single factory-level handler, in order.
        self.assertEqual(len(received_updates), 2)
        self.assertEqual(received_updates[0].role, "user")
        self.assertEqual(received_updates[0].content, "User message")
        self.assertEqual(received_updates[1].role, "assistant")
        self.assertEqual(received_updates[1].content, "Assistant message")

    async def test_text_fragments_with_spaces(self):
        """Fragments that carry their own spacing are concatenated as-is."""
        processor = AssistantTranscriptProcessor()

        received_updates: List[TranscriptionUpdateFrame] = []

        @processor.event_handler("on_transcript_update")
        async def handle_update(proc, frame: TranscriptionUpdateFrame):
            received_updates.append(frame)

        def make_tts_text_frame(text: str) -> TTSTextFrame:
            """Build a word-level TTSTextFrame flagged as already containing its spacing."""
            frame = TTSTextFrame(text=text, aggregated_by="word")
            frame.includes_inter_frame_spaces = True
            return frame

        # Mixed fragments: leading-space words, bare punctuation and a
        # contraction suffix; joined verbatim they form the expected sentence.
        frames_to_send = [
            BotStartedSpeakingFrame(),
            SleepFrame(),
            make_tts_text_frame("Hello"),
            make_tts_text_frame(" there"),
            make_tts_text_frame("!"),
            make_tts_text_frame(" How"),
            make_tts_text_frame("'s"),
            make_tts_text_frame(" it"),
            make_tts_text_frame(" going"),
            make_tts_text_frame("?"),
            BotStoppedSpeakingFrame(),
        ]

        # No SleepFrame precedes BotStoppedSpeakingFrame, so it overtakes the
        # eight queued TTSTextFrames and the update only appears at the end.
        expected_down_frames = [
            BotStartedSpeakingFrame,
            BotStoppedSpeakingFrame,
            TTSTextFrame,
            TTSTextFrame,
            TTSTextFrame,
            TTSTextFrame,
            TTSTextFrame,
            TTSTextFrame,
            TTSTextFrame,
            TTSTextFrame,
            TranscriptionUpdateFrame,
        ]

        await run_test(
            processor,
            frames_to_send=frames_to_send,
            expected_down_frames=expected_down_frames,
        )

        self.assertEqual(len(received_updates), 1)
        message = received_updates[0].messages[0]
        self.assertEqual(message.role, "assistant")
        # Should be properly joined without extra spaces
        self.assertEqual(message.content, "Hello there! How's it going?")