Files
pipecat/src/pipecat/processors/transcript_processor.py
2025-04-01 10:21:21 -07:00

301 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
from typing import List, Optional
from loguru import logger
from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
Frame,
StartInterruptionFrame,
TranscriptionFrame,
TranscriptionMessage,
TranscriptionUpdateFrame,
TTSTextFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.time import time_now_iso8601
class BaseTranscriptProcessor(FrameProcessor):
    """Shared base for processors that turn frames into transcript messages.

    Keeps a running list of every message emitted so far and exposes the
    "on_transcript_update" event, which fires each time new messages are
    emitted.
    """

    def __init__(self, **kwargs):
        """Set up the empty message history and register the update event.

        Args:
            **kwargs: Forwarded unchanged to the parent FrameProcessor.
        """
        super().__init__(**kwargs)
        self._processed_messages: List[TranscriptionMessage] = []
        self._register_event_handler("on_transcript_update")

    async def _emit_update(self, messages: List[TranscriptionMessage]):
        """Record new messages and publish them as a transcript update.

        Does nothing when ``messages`` is empty. Otherwise the messages are
        appended to the history, wrapped in a TranscriptionUpdateFrame, handed
        to the "on_transcript_update" event handlers, and finally pushed as a
        frame.

        Args:
            messages: New messages to emit in the update.
        """
        if not messages:
            return

        self._processed_messages.extend(messages)
        update_frame = TranscriptionUpdateFrame(messages=messages)
        # Handlers are invoked before the frame is pushed.
        await self._call_event_handler("on_transcript_update", update_frame)
        await self.push_frame(update_frame)
class UserTranscriptProcessor(BaseTranscriptProcessor):
    """Converts user TranscriptionFrames into "user" transcript messages."""

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Emit a user message for each TranscriptionFrame, then forward the frame.

        Every frame, whatever its type, is pushed onward in the direction it
        arrived. For a TranscriptionFrame the transcript update is emitted
        before the frame itself is pushed.

        Args:
            frame: Input frame to process.
            direction: Frame processing direction.
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, TranscriptionFrame):
            # Text and timestamp are taken from the frame as-is.
            user_message = TranscriptionMessage(
                role="user",
                content=frame.text,
                timestamp=frame.timestamp,
            )
            await self._emit_update([user_message])

        await self.push_frame(frame, direction)
class AssistantTranscriptProcessor(BaseTranscriptProcessor):
    """Turns streamed assistant TTS text into whole-utterance transcript messages.

    Text from each TTSTextFrame is buffered. The buffer is flushed as a single
    "assistant" message when one of these frames arrives:

    - BotStoppedSpeakingFrame (the bot stopped speaking)
    - StartInterruptionFrame (the bot was interrupted)
    - EndFrame (the pipeline ended)
    - CancelFrame (the pipeline was cancelled)

    Attributes:
        _current_text_parts: Text fragments buffered for the utterance in progress.
        _aggregation_start_time: Timestamp (from time_now_iso8601()) taken when
            the first fragment of the current utterance arrived; None while
            nothing is buffered.
    """

    def __init__(self, **kwargs):
        """Start with an empty fragment buffer and no start timestamp.

        Args:
            **kwargs: Forwarded unchanged to BaseTranscriptProcessor.
        """
        super().__init__(**kwargs)
        self._current_text_parts: List[str] = []
        self._aggregation_start_time: Optional[str] = None

    @staticmethod
    def _join_fragments(parts: List[str]) -> str:
        """Combine buffered fragments into one string, without trimming.

        Heuristic: if any fragment after the first begins with whitespace, the
        fragments are assumed to carry their own spacing and are concatenated
        directly. Otherwise they are treated as separate words and joined with
        single spaces.

        Examples:
            Pre-spaced fragments (concatenated):
            ```
            ["Hello", " there", "!", " How", "'s", " it", " going", "?"]
            ```
            Result: "Hello there! How's it going?"

            Word-by-word fragments (joined with spaces):
            ```
            ["Hello", "there!", "How", "is", "it", "going?"]
            ```
            Result: "Hello there! How is it going?"

        Args:
            parts: Fragments in arrival order.

        Returns:
            The joined text.
        """
        # Slicing with [:1] yields "" for an empty fragment, and "".isspace()
        # is False, so empty fragments never count as pre-spaced.
        prespaced = any(fragment[:1].isspace() for fragment in parts[1:])
        separator = "" if prespaced else " "
        return separator.join(parts)

    async def _emit_aggregated_text(self):
        """Flush the buffered fragments as one assistant transcript message.

        A message is emitted only when there are buffered fragments, a start
        timestamp, and the joined text is non-empty after stripping leading
        and trailing whitespace. The buffer and timestamp are cleared
        afterwards in every case.
        """
        if self._current_text_parts and self._aggregation_start_time:
            content = self._join_fragments(self._current_text_parts).strip()

            if not content:
                logger.trace("No content to emit after stripping whitespace")
            else:
                logger.trace(f"Emitting aggregated assistant message: {content}")
                utterance = TranscriptionMessage(
                    role="assistant",
                    content=content,
                    timestamp=self._aggregation_start_time,
                )
                await self._emit_update([utterance])

        # Clear the aggregation state for the next utterance.
        self._current_text_parts = []
        self._aggregation_start_time = None

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Buffer TTS text and flush it on utterance-ending frames.

        Handling by frame type:

        - StartInterruptionFrame / CancelFrame: forward the frame first, then
          flush the buffered text.
        - TTSTextFrame: buffer the text (stamping the start time on the first
          fragment), then forward the frame.
        - BotStoppedSpeakingFrame / EndFrame: flush the buffered text, then
          forward the frame.
        - Anything else: forward the frame unchanged.

        Args:
            frame: Input frame to process.
            direction: Frame processing direction.
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, (StartInterruptionFrame, CancelFrame)):
            # Push frame first otherwise our emitted transcription update frame
            # might get cleaned up.
            await self.push_frame(frame, direction)
            await self._emit_aggregated_text()
            return

        if isinstance(frame, TTSTextFrame):
            # The utterance timestamp is taken when its first fragment arrives.
            if not self._aggregation_start_time:
                self._aggregation_start_time = time_now_iso8601()
            self._current_text_parts.append(frame.text)
        elif isinstance(frame, (BotStoppedSpeakingFrame, EndFrame)):
            # Bot finished speaking or the pipeline is ending: flush first.
            await self._emit_aggregated_text()

        await self.push_frame(frame, direction)
class TranscriptProcessor:
    """Factory for creating and managing transcript processors.

    Provides unified access to user and assistant transcript processors
    with shared event handling. Each processor is created lazily on first
    request and then reused. Event handlers registered on the factory are
    attached to both processors, whether those are created before or after
    the handler is registered.

    Example:
        ```python
        transcript = TranscriptProcessor()

        pipeline = Pipeline(
            [
                transport.input(),
                stt,
                transcript.user(),              # User transcripts
                context_aggregator.user(),
                llm,
                tts,
                transport.output(),
                transcript.assistant(),         # Assistant transcripts
                context_aggregator.assistant(),
            ]
        )

        @transcript.event_handler("on_transcript_update")
        async def handle_update(processor, frame):
            print(f"New messages: {frame.messages}")
        ```
    """

    def __init__(self):
        """Initialize factory with no processors and no handlers."""
        self._user_processor = None
        self._assistant_processor = None
        # Maps event name -> handler; one handler is kept per event name.
        self._event_handlers = {}

    @staticmethod
    def _attach_handler(target, event_name: str, handler) -> None:
        """Register a wrapper on ``target`` that forwards to ``handler``.

        The wrapper is created inside this function so it closes over this
        call's ``handler``. Defining it directly inside a ``for`` loop would
        make every wrapper share the loop variable and call whichever handler
        was iterated last.

        Args:
            target: Processor exposing an ``event_handler(name)`` decorator.
            event_name: Name of the event to handle.
            handler: Awaitable callable invoked as ``handler(processor, frame)``.
        """

        @target.event_handler(event_name)
        async def forward(processor, frame):
            return await handler(processor, frame)

    def user(self, **kwargs) -> UserTranscriptProcessor:
        """Get the user transcript processor, creating it on first use.

        Args:
            **kwargs: Arguments specific to UserTranscriptProcessor. Only used
                on the call that creates the processor.

        Returns:
            The shared UserTranscriptProcessor instance.
        """
        if self._user_processor is None:
            self._user_processor = UserTranscriptProcessor(**kwargs)
            # Apply any event handlers registered before creation.
            for event_name, handler in self._event_handlers.items():
                self._attach_handler(self._user_processor, event_name, handler)
        return self._user_processor

    def assistant(self, **kwargs) -> AssistantTranscriptProcessor:
        """Get the assistant transcript processor, creating it on first use.

        Args:
            **kwargs: Arguments specific to AssistantTranscriptProcessor. Only
                used on the call that creates the processor.

        Returns:
            The shared AssistantTranscriptProcessor instance.
        """
        if self._assistant_processor is None:
            self._assistant_processor = AssistantTranscriptProcessor(**kwargs)
            # Apply any event handlers registered before creation.
            for event_name, handler in self._event_handlers.items():
                self._attach_handler(self._assistant_processor, event_name, handler)
        return self._assistant_processor

    def event_handler(self, event_name: str):
        """Register event handler for both processors.

        The handler is remembered so processors created later receive it too,
        and it is attached immediately to any processor that already exists.

        Args:
            event_name: Name of event to handle.

        Returns:
            Decorator function that registers the handler with both processors
            and returns the handler unchanged.
        """

        def decorator(handler):
            self._event_handlers[event_name] = handler

            # Apply handler to existing processors if they exist.
            if self._user_processor is not None:
                self._attach_handler(self._user_processor, event_name, handler)
            if self._assistant_processor is not None:
                self._attach_handler(self._assistant_processor, event_name, handler)

            return handler

        return decorator