Merge pull request #1046 from pipecat-ai/mb/transcript-tts
Modified `TranscriptProcessor` to use `TTSTextFrame`s
This commit is contained in:
12
CHANGELOG.md
12
CHANGELOG.md
@@ -12,6 +12,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- It is now possible to specify the period of the `PipelineTask` heartbeat
|
||||
frames with `heartbeats_period_secs`.
|
||||
|
||||
### Changed
|
||||
|
||||
- Modified `TranscriptProcessor` to use TTS text frames for more accurate assistant
|
||||
transcripts. Assistant messages are now aggregated based on bot speaking boundaries
|
||||
rather than LLM context, providing better handling of interruptions and partial
|
||||
utterances.
|
||||
|
||||
- Updated foundational examples `28a-transcription-processor-openai.py`,
|
||||
`28b-transcript-processor-anthropic.py`, and
|
||||
`28c-transcription-processor-gemini.py` to use the updated
|
||||
`TranscriptProcessor`.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a type error when using `voice_settings` in `ElevenLabsHttpTTSService`.
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
@@ -15,7 +15,11 @@ from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import EndFrame, TranscriptionMessage, TranscriptionUpdateFrame
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
TranscriptionMessage,
|
||||
TranscriptionUpdateFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
@@ -33,13 +37,49 @@ logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
class TranscriptHandler:
|
||||
"""Simple handler to demonstrate transcript processing.
|
||||
"""Handles real-time transcript processing and output.
|
||||
|
||||
Maintains a list of conversation messages and logs them with timestamps.
|
||||
Maintains a list of conversation messages and outputs them either to a log
|
||||
or to a file as they are received. Each message includes its timestamp and role.
|
||||
|
||||
Attributes:
|
||||
messages: List of all processed transcript messages
|
||||
output_file: Optional path to file where transcript is saved. If None, outputs to log only.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, output_file: Optional[str] = None):
|
||||
"""Initialize handler with optional file output.
|
||||
|
||||
Args:
|
||||
output_file: Path to output file. If None, outputs to log only.
|
||||
"""
|
||||
self.messages: List[TranscriptionMessage] = []
|
||||
self.output_file: Optional[str] = output_file
|
||||
logger.debug(
|
||||
f"TranscriptHandler initialized {'with output_file=' + output_file if output_file else 'with log output only'}"
|
||||
)
|
||||
|
||||
async def save_message(self, message: TranscriptionMessage):
|
||||
"""Save a single transcript message.
|
||||
|
||||
Outputs the message to the log and optionally to a file.
|
||||
|
||||
Args:
|
||||
message: The message to save
|
||||
"""
|
||||
timestamp = f"[{message.timestamp}] " if message.timestamp else ""
|
||||
line = f"{timestamp}{message.role}: {message.content}"
|
||||
|
||||
# Always log the message
|
||||
logger.info(f"Transcript: {line}")
|
||||
|
||||
# Optionally write to file
|
||||
if self.output_file:
|
||||
try:
|
||||
with open(self.output_file, "a", encoding="utf-8") as f:
|
||||
f.write(line + "\n")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving transcript message to file: {e}")
|
||||
|
||||
async def on_transcript_update(
|
||||
self, processor: TranscriptProcessor, frame: TranscriptionUpdateFrame
|
||||
@@ -50,13 +90,11 @@ class TranscriptHandler:
|
||||
processor: The TranscriptProcessor that emitted the update
|
||||
frame: TranscriptionUpdateFrame containing new messages
|
||||
"""
|
||||
self.messages.extend(frame.messages)
|
||||
logger.debug(f"Received transcript update with {len(frame.messages)} new messages")
|
||||
|
||||
# Log the new messages
|
||||
logger.info("New transcript messages:")
|
||||
for msg in frame.messages:
|
||||
timestamp = f"[{msg.timestamp}] " if msg.timestamp else ""
|
||||
logger.info(f"{timestamp}{msg.role}: {msg.content}")
|
||||
self.messages.append(msg)
|
||||
await self.save_message(msg)
|
||||
|
||||
|
||||
async def main():
|
||||
@@ -99,7 +137,8 @@ async def main():
|
||||
|
||||
# Create transcript processor and handler
|
||||
transcript = TranscriptProcessor()
|
||||
transcript_handler = TranscriptHandler()
|
||||
transcript_handler = TranscriptHandler() # Output to log only
|
||||
# transcript_handler = TranscriptHandler(output_file="transcript.txt") # Output to file and log
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
@@ -110,8 +149,8 @@ async def main():
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
transcript.assistant(), # Assistant transcripts
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
@@ -130,7 +169,8 @@ async def main():
|
||||
|
||||
@transport.event_handler("on_participant_left")
|
||||
async def on_participant_left(transport, participant, reason):
|
||||
await task.queue_frame(EndFrame())
|
||||
# Stop the pipeline immediately when the participant leaves
|
||||
await task.queue_frame(CancelFrame())
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
@@ -15,7 +15,11 @@ from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import EndFrame, TranscriptionMessage, TranscriptionUpdateFrame
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
TranscriptionMessage,
|
||||
TranscriptionUpdateFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
@@ -33,13 +37,49 @@ logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
class TranscriptHandler:
|
||||
"""Simple handler to demonstrate transcript processing.
|
||||
"""Handles real-time transcript processing and output.
|
||||
|
||||
Maintains a list of conversation messages and logs them with timestamps.
|
||||
Maintains a list of conversation messages and outputs them either to a log
|
||||
or to a file as they are received. Each message includes its timestamp and role.
|
||||
|
||||
Attributes:
|
||||
messages: List of all processed transcript messages
|
||||
output_file: Optional path to file where transcript is saved. If None, outputs to log only.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, output_file: Optional[str] = None):
|
||||
"""Initialize handler with optional file output.
|
||||
|
||||
Args:
|
||||
output_file: Path to output file. If None, outputs to log only.
|
||||
"""
|
||||
self.messages: List[TranscriptionMessage] = []
|
||||
self.output_file: Optional[str] = output_file
|
||||
logger.debug(
|
||||
f"TranscriptHandler initialized {'with output_file=' + output_file if output_file else 'with log output only'}"
|
||||
)
|
||||
|
||||
async def save_message(self, message: TranscriptionMessage):
|
||||
"""Save a single transcript message.
|
||||
|
||||
Outputs the message to the log and optionally to a file.
|
||||
|
||||
Args:
|
||||
message: The message to save
|
||||
"""
|
||||
timestamp = f"[{message.timestamp}] " if message.timestamp else ""
|
||||
line = f"{timestamp}{message.role}: {message.content}"
|
||||
|
||||
# Always log the message
|
||||
logger.info(f"Transcript: {line}")
|
||||
|
||||
# Optionally write to file
|
||||
if self.output_file:
|
||||
try:
|
||||
with open(self.output_file, "a", encoding="utf-8") as f:
|
||||
f.write(line + "\n")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving transcript message to file: {e}")
|
||||
|
||||
async def on_transcript_update(
|
||||
self, processor: TranscriptProcessor, frame: TranscriptionUpdateFrame
|
||||
@@ -50,13 +90,11 @@ class TranscriptHandler:
|
||||
processor: The TranscriptProcessor that emitted the update
|
||||
frame: TranscriptionUpdateFrame containing new messages
|
||||
"""
|
||||
self.messages.extend(frame.messages)
|
||||
logger.debug(f"Received transcript update with {len(frame.messages)} new messages")
|
||||
|
||||
# Log the new messages
|
||||
logger.info("New transcript messages:")
|
||||
for msg in frame.messages:
|
||||
timestamp = f"[{msg.timestamp}] " if msg.timestamp else ""
|
||||
logger.info(f"{timestamp}{msg.role}: {msg.content}")
|
||||
self.messages.append(msg)
|
||||
await self.save_message(msg)
|
||||
|
||||
|
||||
async def main():
|
||||
@@ -99,7 +137,8 @@ async def main():
|
||||
|
||||
# Create transcript processor and handler
|
||||
transcript = TranscriptProcessor()
|
||||
transcript_handler = TranscriptHandler()
|
||||
transcript_handler = TranscriptHandler() # Output to log only
|
||||
# transcript_handler = TranscriptHandler(output_file="transcript.txt") # Output to file and log
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
@@ -110,8 +149,8 @@ async def main():
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
transcript.assistant(), # Assistant transcripts
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
@@ -130,7 +169,8 @@ async def main():
|
||||
|
||||
@transport.event_handler("on_participant_left")
|
||||
async def on_participant_left(transport, participant, reason):
|
||||
await task.queue_frame(EndFrame())
|
||||
# Stop the pipeline immediately when the participant leaves
|
||||
await task.queue_frame(CancelFrame())
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
@@ -15,7 +15,11 @@ from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import EndFrame, TranscriptionMessage, TranscriptionUpdateFrame
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
TranscriptionMessage,
|
||||
TranscriptionUpdateFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
@@ -34,13 +38,49 @@ logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
class TranscriptHandler:
|
||||
"""Simple handler to demonstrate transcript processing.
|
||||
"""Handles real-time transcript processing and output.
|
||||
|
||||
Maintains a list of conversation messages and logs them with timestamps.
|
||||
Maintains a list of conversation messages and outputs them either to a log
|
||||
or to a file as they are received. Each message includes its timestamp and role.
|
||||
|
||||
Attributes:
|
||||
messages: List of all processed transcript messages
|
||||
output_file: Optional path to file where transcript is saved. If None, outputs to log only.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, output_file: Optional[str] = None):
|
||||
"""Initialize handler with optional file output.
|
||||
|
||||
Args:
|
||||
output_file: Path to output file. If None, outputs to log only.
|
||||
"""
|
||||
self.messages: List[TranscriptionMessage] = []
|
||||
self.output_file: Optional[str] = output_file
|
||||
logger.debug(
|
||||
f"TranscriptHandler initialized {'with output_file=' + output_file if output_file else 'with log output only'}"
|
||||
)
|
||||
|
||||
async def save_message(self, message: TranscriptionMessage):
|
||||
"""Save a single transcript message.
|
||||
|
||||
Outputs the message to the log and optionally to a file.
|
||||
|
||||
Args:
|
||||
message: The message to save
|
||||
"""
|
||||
timestamp = f"[{message.timestamp}] " if message.timestamp else ""
|
||||
line = f"{timestamp}{message.role}: {message.content}"
|
||||
|
||||
# Always log the message
|
||||
logger.info(f"Transcript: {line}")
|
||||
|
||||
# Optionally write to file
|
||||
if self.output_file:
|
||||
try:
|
||||
with open(self.output_file, "a", encoding="utf-8") as f:
|
||||
f.write(line + "\n")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving transcript message to file: {e}")
|
||||
|
||||
async def on_transcript_update(
|
||||
self, processor: TranscriptProcessor, frame: TranscriptionUpdateFrame
|
||||
@@ -51,13 +91,11 @@ class TranscriptHandler:
|
||||
processor: The TranscriptProcessor that emitted the update
|
||||
frame: TranscriptionUpdateFrame containing new messages
|
||||
"""
|
||||
self.messages.extend(frame.messages)
|
||||
logger.debug(f"Received transcript update with {len(frame.messages)} new messages")
|
||||
|
||||
# Log the new messages
|
||||
logger.info("New transcript messages:")
|
||||
for msg in frame.messages:
|
||||
timestamp = f"[{msg.timestamp}] " if msg.timestamp else ""
|
||||
logger.info(f"{timestamp}{msg.role}: {msg.content}")
|
||||
self.messages.append(msg)
|
||||
await self.save_message(msg)
|
||||
|
||||
|
||||
async def main():
|
||||
@@ -102,7 +140,8 @@ async def main():
|
||||
|
||||
# Create transcript processor and handler
|
||||
transcript = TranscriptProcessor()
|
||||
transcript_handler = TranscriptHandler()
|
||||
transcript_handler = TranscriptHandler() # Output to log only
|
||||
# transcript_handler = TranscriptHandler(output_file="transcript.txt") # Output to file and log
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
@@ -113,8 +152,8 @@ async def main():
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
transcript.assistant(), # Assistant transcripts
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
@@ -140,7 +179,8 @@ async def main():
|
||||
|
||||
@transport.event_handler("on_participant_left")
|
||||
async def on_participant_left(transport, participant, reason):
|
||||
await task.queue_frame(EndFrame())
|
||||
# Stop the pipeline immediately when the participant leaves
|
||||
await task.queue_frame(CancelFrame())
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
|
||||
@@ -4,17 +4,23 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
OpenAILLMContextAssistantTimestampFrame,
|
||||
StartInterruptionFrame,
|
||||
TranscriptionFrame,
|
||||
TranscriptionMessage,
|
||||
TranscriptionUpdateFrame,
|
||||
TTSTextFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
class BaseTranscriptProcessor(FrameProcessor):
|
||||
@@ -64,89 +70,74 @@ class UserTranscriptProcessor(BaseTranscriptProcessor):
|
||||
|
||||
|
||||
class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
"""Processes assistant LLM context frames into timestamped conversation messages."""
|
||||
"""Processes assistant TTS text frames into timestamped conversation messages.
|
||||
|
||||
This processor aggregates TTS text frames into complete utterances and emits them as
|
||||
transcript messages. Utterances are completed when:
|
||||
- The bot stops speaking (BotStoppedSpeakingFrame)
|
||||
- The bot is interrupted (StartInterruptionFrame)
|
||||
- The pipeline ends (EndFrame)
|
||||
|
||||
Attributes:
|
||||
_current_text_parts: List of text fragments being aggregated for current utterance
|
||||
_aggregation_start_time: Timestamp when the current utterance began
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize processor with empty message stores."""
|
||||
"""Initialize processor with aggregation state."""
|
||||
super().__init__(**kwargs)
|
||||
self._pending_assistant_messages: List[TranscriptionMessage] = []
|
||||
self._current_text_parts: List[str] = []
|
||||
self._aggregation_start_time: Optional[str] | None = None
|
||||
|
||||
def _extract_messages(self, messages: List[dict]) -> List[TranscriptionMessage]:
|
||||
"""Extract assistant messages from the OpenAI standard message format.
|
||||
async def _emit_aggregated_text(self):
|
||||
"""Emit aggregated text as a transcript message."""
|
||||
if self._current_text_parts and self._aggregation_start_time:
|
||||
content = " ".join(self._current_text_parts).strip()
|
||||
if content:
|
||||
logger.debug(f"Emitting aggregated assistant message: {content}")
|
||||
message = TranscriptionMessage(
|
||||
role="assistant",
|
||||
content=content,
|
||||
timestamp=self._aggregation_start_time,
|
||||
)
|
||||
await self._emit_update([message])
|
||||
else:
|
||||
logger.debug("No content to emit after stripping whitespace")
|
||||
|
||||
Args:
|
||||
messages: List of messages in OpenAI format, which can be either:
|
||||
- Simple format: {"role": "user", "content": "Hello"}
|
||||
- Content list: {"role": "user", "content": [{"type": "text", "text": "Hello"}]}
|
||||
|
||||
Returns:
|
||||
List[TranscriptionMessage]: Normalized conversation messages
|
||||
"""
|
||||
result = []
|
||||
for msg in messages:
|
||||
if msg["role"] != "assistant":
|
||||
continue
|
||||
|
||||
content = msg.get("content")
|
||||
if isinstance(content, str):
|
||||
if content:
|
||||
result.append(TranscriptionMessage(role="assistant", content=content))
|
||||
elif isinstance(content, list):
|
||||
text_parts = []
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
text_parts.append(part["text"])
|
||||
|
||||
if text_parts:
|
||||
result.append(
|
||||
TranscriptionMessage(role="assistant", content=" ".join(text_parts))
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _find_new_messages(self, current: List[TranscriptionMessage]) -> List[TranscriptionMessage]:
|
||||
"""Find unprocessed messages from current list.
|
||||
|
||||
Args:
|
||||
current: List of current messages
|
||||
|
||||
Returns:
|
||||
List[TranscriptionMessage]: New messages not yet processed
|
||||
"""
|
||||
if not self._processed_messages:
|
||||
return current
|
||||
|
||||
processed_len = len(self._processed_messages)
|
||||
if len(current) <= processed_len:
|
||||
return []
|
||||
|
||||
return current[processed_len:]
|
||||
# Reset aggregation state
|
||||
self._current_text_parts = []
|
||||
self._aggregation_start_time = None
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames into assistant conversation messages.
|
||||
|
||||
Handles different frame types:
|
||||
- TTSTextFrame: Aggregates text for current utterance
|
||||
- BotStoppedSpeakingFrame: Completes current utterance
|
||||
- StartInterruptionFrame: Completes current utterance due to interruption
|
||||
- EndFrame: Completes current utterance at pipeline end
|
||||
- CancelFrame: Completes current utterance due to cancellation
|
||||
|
||||
Args:
|
||||
frame: Input frame to process
|
||||
direction: Frame processing direction
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
standard_messages = []
|
||||
for msg in frame.context.messages:
|
||||
converted = frame.context.to_standard_messages(msg)
|
||||
standard_messages.extend(converted)
|
||||
if isinstance(frame, TTSTextFrame):
|
||||
# Start timestamp on first text part
|
||||
if not self._aggregation_start_time:
|
||||
self._aggregation_start_time = time_now_iso8601()
|
||||
|
||||
current_messages = self._extract_messages(standard_messages)
|
||||
new_messages = self._find_new_messages(current_messages)
|
||||
self._pending_assistant_messages.extend(new_messages)
|
||||
self._current_text_parts.append(frame.text)
|
||||
|
||||
elif isinstance(frame, OpenAILLMContextAssistantTimestampFrame):
|
||||
if self._pending_assistant_messages:
|
||||
for msg in self._pending_assistant_messages:
|
||||
msg.timestamp = frame.timestamp
|
||||
await self._emit_update(self._pending_assistant_messages)
|
||||
self._pending_assistant_messages = []
|
||||
elif isinstance(frame, (BotStoppedSpeakingFrame, StartInterruptionFrame, CancelFrame)):
|
||||
# Emit accumulated text when bot finishes speaking or is interrupted
|
||||
await self._emit_aggregated_text()
|
||||
|
||||
elif isinstance(frame, EndFrame):
|
||||
# Emit any remaining text when pipeline ends
|
||||
await self._emit_aggregated_text()
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -170,8 +161,8 @@ class TranscriptProcessor:
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
transcript.assistant_tts(), # Assistant transcripts
|
||||
context_aggregator.assistant(),
|
||||
transcript.assistant(), # Assistant transcripts
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user