Add TranscriptionProcessor

This commit is contained in:
Mark Backman
2024-12-13 15:38:59 -05:00
parent fb9f72d38b
commit 55879bf365
6 changed files with 613 additions and 9 deletions

View File

@@ -7,6 +7,7 @@
import asyncio
import os
import sys
from typing import List
import aiohttp
from dotenv import load_dotenv
@@ -14,12 +15,13 @@ from loguru import logger
from runner import configure
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMMessagesFrame
from pipecat.frames.frames import Frame, LLMMessagesFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.anthropic import AnthropicLLMService
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.anthropic import AnthropicLLMContext, AnthropicLLMService
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.transports.services.daily import DailyParams, DailyTransport
@@ -29,6 +31,28 @@ logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
class TestAnthropicLLMService(AnthropicLLMService):
    """Anthropic LLM service that logs a round trip of message-format conversions.

    For each ``LLMMessagesFrame`` it logs the incoming messages, the messages of
    the ``AnthropicLLMContext`` built from them, and the result of passing each
    of those back through ``to_standard_messages``. Frame handling itself is
    left to the parent class.
    """

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Log the format conversions for message frames, then delegate.

        Args:
            frame: The frame being processed.
            direction: The direction the frame is travelling in.
        """
        if isinstance(frame, LLMMessagesFrame):
            logger.info("Original OpenAI format messages:")
            logger.info(frame.messages)

            # OpenAI-style messages -> Anthropic context.
            anthropic_context = AnthropicLLMContext.from_messages(frame.messages)
            logger.info("Converted to Anthropic format:")
            logger.info(anthropic_context.messages)

            # Anthropic context -> standard messages; one native message may
            # expand to several standard ones, so the results are flattened.
            round_tripped = [
                standard
                for native in anthropic_context.messages
                for standard in anthropic_context.to_standard_messages(native)
            ]
            logger.info("Converted back to OpenAI format:")
            logger.info(round_tripped)

        await super().process_frame(frame, direction)
async def main():
async with aiohttp.ClientSession() as session:
(room_url, token) = await configure(session)
@@ -50,18 +74,24 @@ async def main():
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
)
llm = AnthropicLLMService(
llm = TestAnthropicLLMService(
api_key=os.getenv("ANTHROPIC_API_KEY"), model="claude-3-opus-20240229"
)
# todo: think more about how to handle system prompts in a more general way. OpenAI,
# Google, and Anthropic all have slightly different approaches to providing a system
# prompt.
# Test messages including various formats
messages = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative, helpful, and brief way. Say hello.",
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Hello! How can I help you today?"},
{"type": "text", "text": "I'm ready to assist."},
],
},
{"role": "user", "content": "Hi there!"},
]
context = OpenAILLMContext(messages)

View File

@@ -0,0 +1,128 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import os
import sys
from typing import List
import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMMessagesFrame, TranscriptionMessage, TranscriptionUpdateFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.transcript_processor import TranscriptProcessor
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.openai import OpenAILLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport
load_dotenv(override=True)
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
class TranscriptHandler:
    """Accumulates transcript updates and logs them, for demonstration purposes."""

    def __init__(self):
        # Every message received so far, in arrival order.
        self.messages: List[TranscriptionMessage] = []

    async def on_transcript_update(
        self, processor: TranscriptProcessor, frame: TranscriptionUpdateFrame
    ):
        """Store the messages carried by ``frame`` and log them.

        The new messages are logged first, followed by the whole history
        collected so far.

        Args:
            processor: The processor that produced the update (not used here).
            frame: Update frame whose ``messages`` are appended to the history.
        """
        incoming = frame.messages
        self.messages.extend(incoming)

        logger.info("New transcript messages:")
        for entry in incoming:
            logger.info(f"{entry.role}: {entry.content}")

        logger.info("Full transcript:")
        for entry in self.messages:
            logger.info(f"{entry.role}: {entry.content}")
async def main():
    """Wire up and run the OpenAI transcript demo bot.

    Creates a Daily transport, a Cartesia TTS service, an OpenAI LLM service
    and a ``TranscriptProcessor``, assembles them into a single pipeline and
    hands the resulting task to a ``PipelineRunner``.
    """
    async with aiohttp.ClientSession() as session:
        room_url, token = await configure(session)

        daily_params = DailyParams(
            audio_out_enabled=True,
            transcription_enabled=True,
            vad_enabled=True,
            vad_analyzer=SileroVADAnalyzer(),
        )
        transport = DailyTransport(room_url, token, "Respond bot", daily_params)

        tts = CartesiaTTSService(
            api_key=os.getenv("CARTESIA_API_KEY"),
            voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
        )

        llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")

        system_prompt = "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative, helpful, and brief way. Say hello."
        messages = [{"role": "system", "content": system_prompt}]

        context = OpenAILLMContext(messages)
        context_aggregator = llm.create_context_aggregator(context)

        # The processor emits transcript updates; the handler stores and logs them.
        transcript_processor = TranscriptProcessor()
        transcript_handler = TranscriptHandler()

        @transcript_processor.event_handler("on_transcript_update")
        async def on_transcript_update(processor, frame):
            await transcript_handler.on_transcript_update(processor, frame)

        stages = [
            transport.input(),  # user input from the transport
            context_aggregator.user(),  # user side of the context
            llm,
            tts,
            transport.output(),  # bot output to the transport
            context_aggregator.assistant(),  # assistant side of the context
            transcript_processor,  # transcript extraction
        ]
        pipeline = Pipeline(stages)

        task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True))

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            await transport.capture_participant_transcription(participant["id"])
            # Start the conversation by queueing the initial messages.
            await task.queue_frames([LLMMessagesFrame(messages)])

        await PipelineRunner().run(task)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,128 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import os
import sys
from typing import List
import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMMessagesFrame, TranscriptionMessage, TranscriptionUpdateFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.transcript_processor import TranscriptProcessor
from pipecat.services.anthropic import AnthropicLLMService
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.transports.services.daily import DailyParams, DailyTransport
load_dotenv(override=True)
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
class TranscriptHandler:
    """Accumulates transcript updates and logs them, for demonstration purposes."""

    def __init__(self):
        # Every message received so far, in arrival order.
        self.messages: List[TranscriptionMessage] = []

    async def on_transcript_update(
        self, processor: TranscriptProcessor, frame: TranscriptionUpdateFrame
    ):
        """Store the messages carried by ``frame`` and log them.

        The new messages are logged first, followed by the whole history
        collected so far.

        Args:
            processor: The processor that produced the update (not used here).
            frame: Update frame whose ``messages`` are appended to the history.
        """
        incoming = frame.messages
        self.messages.extend(incoming)

        logger.info("New transcript messages:")
        for entry in incoming:
            logger.info(f"{entry.role}: {entry.content}")

        logger.info("Full transcript:")
        for entry in self.messages:
            logger.info(f"{entry.role}: {entry.content}")
async def main():
    """Wire up and run the Anthropic transcript demo bot.

    Creates a Daily transport, a Cartesia TTS service, an Anthropic LLM service
    and a ``TranscriptProcessor``, assembles them into a single pipeline and
    hands the resulting task to a ``PipelineRunner``.
    """
    async with aiohttp.ClientSession() as session:
        room_url, token = await configure(session)

        daily_params = DailyParams(
            audio_out_enabled=True,
            transcription_enabled=True,
            vad_enabled=True,
            vad_analyzer=SileroVADAnalyzer(),
        )
        transport = DailyTransport(room_url, token, "Respond bot", daily_params)

        tts = CartesiaTTSService(
            api_key=os.getenv("CARTESIA_API_KEY"),
            voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
        )

        llm = AnthropicLLMService(
            api_key=os.getenv("ANTHROPIC_API_KEY"),
            model="claude-3-5-sonnet-20241022",
        )

        system_prompt = "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative, helpful, and brief way."
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": "Say hello."},
        ]

        context = OpenAILLMContext(messages)
        context_aggregator = llm.create_context_aggregator(context)

        # The processor emits transcript updates; the handler stores and logs them.
        transcript_processor = TranscriptProcessor()
        transcript_handler = TranscriptHandler()

        @transcript_processor.event_handler("on_transcript_update")
        async def on_transcript_update(processor, frame):
            await transcript_handler.on_transcript_update(processor, frame)

        stages = [
            transport.input(),  # user input from the transport
            context_aggregator.user(),  # user side of the context
            llm,
            tts,
            transport.output(),  # bot output to the transport
            context_aggregator.assistant(),  # assistant side of the context
            transcript_processor,  # transcript extraction
        ]
        pipeline = Pipeline(stages)

        task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True))

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            await transport.capture_participant_transcription(participant["id"])
            # Start the conversation by queueing the initial messages.
            await task.queue_frames([LLMMessagesFrame(messages)])

        await PipelineRunner().run(task)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,138 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import os
import sys
from typing import List
import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import TranscriptionMessage, TranscriptionUpdateFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.transcript_processor import TranscriptProcessor
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.google import GoogleLLMService
from pipecat.services.openai import OpenAILLMContext
from pipecat.transports.services.daily import DailyParams, DailyTransport
load_dotenv(override=True)
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
class TranscriptHandler:
    """Accumulates transcript updates and logs them, for demonstration purposes."""

    def __init__(self):
        # Every message received so far, in arrival order.
        self.messages: List[TranscriptionMessage] = []

    async def on_transcript_update(
        self, processor: TranscriptProcessor, frame: TranscriptionUpdateFrame
    ):
        """Store the messages carried by ``frame`` and log them.

        The new messages are logged first, followed by the whole history
        collected so far.

        Args:
            processor: The processor that produced the update (not used here).
            frame: Update frame whose ``messages`` are appended to the history.
        """
        incoming = frame.messages
        self.messages.extend(incoming)

        logger.info("New transcript messages:")
        for entry in incoming:
            logger.info(f"{entry.role}: {entry.content}")

        logger.info("Full transcript:")
        for entry in self.messages:
            logger.info(f"{entry.role}: {entry.content}")
async def main():
    """Wire up and run the Google transcript demo bot.

    Creates a Daily transport, a Cartesia TTS service, a Google LLM service and
    a ``TranscriptProcessor``, assembles them into a single pipeline (with
    metrics and usage metrics enabled) and hands the resulting task to a
    ``PipelineRunner``.
    """
    async with aiohttp.ClientSession() as session:
        room_url, token = await configure(session)

        daily_params = DailyParams(
            audio_out_enabled=True,
            transcription_enabled=True,
            vad_enabled=True,
            vad_analyzer=SileroVADAnalyzer(),
        )
        transport = DailyTransport(room_url, token, "Respond bot", daily_params)

        tts = CartesiaTTSService(
            api_key=os.getenv("CARTESIA_API_KEY"),
            voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
        )

        llm = GoogleLLMService(
            model="models/gemini-2.0-flash-exp",
            # model="gemini-exp-1114",
            api_key=os.getenv("GOOGLE_API_KEY"),
        )

        system_prompt = "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative, helpful, and brief way."
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": "Say hello."},
        ]

        context = OpenAILLMContext(messages)
        context_aggregator = llm.create_context_aggregator(context)

        # The processor emits transcript updates; the handler stores and logs them.
        transcript_processor = TranscriptProcessor()
        transcript_handler = TranscriptHandler()

        @transcript_processor.event_handler("on_transcript_update")
        async def on_transcript_update(processor, frame):
            await transcript_handler.on_transcript_update(processor, frame)

        stages = [
            transport.input(),
            context_aggregator.user(),
            llm,
            tts,
            transport.output(),
            context_aggregator.assistant(),
            transcript_processor,
        ]
        pipeline = Pipeline(stages)

        task_params = PipelineParams(
            allow_interruptions=True,
            enable_metrics=True,
            enable_usage_metrics=True,
        )
        task = PipelineTask(pipeline, task_params)

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            await transport.capture_participant_transcription(participant["id"])
            # Start the conversation from the context held by the user aggregator.
            await task.queue_frames([context_aggregator.user().get_context_frame()])

        await PipelineRunner().run(task)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -5,7 +5,7 @@
#
from dataclasses import dataclass, field
from typing import Any, List, Mapping, Optional, Tuple
from typing import Any, List, Literal, Mapping, Optional, Tuple, TypeAlias
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.clocks.base_clock import BaseClock
@@ -195,7 +195,8 @@ class TranscriptionFrame(TextFrame):
@dataclass
class InterimTranscriptionFrame(TextFrame):
"""A text frame with interim transcription-specific data. Will be placed in
the transport's receive queue when a participant speaks."""
the transport's receive queue when a participant speaks.
"""
text: str
user_id: str
@@ -206,6 +207,34 @@ class InterimTranscriptionFrame(TextFrame):
return f"{self.name}(user: {self.user_id}, text: [{self.text}], language: {self.language}, timestamp: {self.timestamp})"
@dataclass
class TranscriptionMessage:
    """A single message of a conversation transcript.

    Holds who spoke (``role``) and what was said (``content``). Messages are in
    standard format with roles normalized to user/assistant.
    """

    # Speaker of the message; the annotation restricts it to "user" or "assistant".
    role: Literal["user", "assistant"]
    # Text of the message.
    content: str
    # Optional timestamp string; defaults to None. The string format is not
    # defined in this block.
    timestamp: str | None = None
@dataclass
class TranscriptionUpdateFrame(DataFrame):
    """Frame carrying the messages newly added to the conversation transcript.

    It is emitted when messages are added to the conversation history and holds
    only those additions, not the full transcript. Message roles are normalized
    to user/assistant regardless of the LLM service used.
    """

    messages: List[TranscriptionMessage]

    def __str__(self):
        """Return the frame name with its formatted pts and message count."""
        return f"{self.name}(pts: {format_pts(self.pts)}, messages: {len(self.messages)})"
@dataclass
class LLMMessagesFrame(DataFrame):
"""A frame containing a list of LLM messages. Used to signal that an LLM
@@ -546,7 +575,8 @@ class EndFrame(ControlFrame):
@dataclass
class LLMFullResponseStartFrame(ControlFrame):
    """Used to indicate the beginning of an LLM response. Followed by one or
    more TextFrame and a final LLMFullResponseEndFrame.
    """

    pass

View File

@@ -0,0 +1,150 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
from typing import List
from loguru import logger
from pipecat.frames.frames import (
ErrorFrame,
Frame,
TranscriptionMessage,
TranscriptionUpdateFrame,
)
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
class TranscriptProcessor(FrameProcessor):
    """Processes LLM context frames to generate conversation transcripts.

    This processor monitors OpenAILLMContextFrame frames and extracts
    conversation content, filtering out system messages and function calls.
    When new messages are detected, it emits a TranscriptionUpdateFrame
    containing only the new messages.

    Each LLM context (OpenAI, Anthropic, Google) provides conversion to the
    standard format:

        [
            {
                "role": "user",
                "content": [{"type": "text", "text": "Hi, how are you?"}]
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": "Great! And you?"}]
            }
        ]

    A message whose ``content`` is a plain, non-empty string is accepted as
    well and used as the transcript text directly.

    New messages are detected by position: the messages extracted from a
    context are compared by count against those already reported, and anything
    past that count is treated as new. A context that shrinks or rewrites
    earlier messages is therefore not re-reported.

    Events:
        on_transcript_update: Emitted when new transcript messages are available.
            Args: TranscriptionUpdateFrame containing new messages.

    Example:
        ```python
        transcript_processor = TranscriptProcessor()

        @transcript_processor.event_handler("on_transcript_update")
        async def on_transcript_update(processor, frame):
            for msg in frame.messages:
                print(f"{msg.role}: {msg.content}")
        ```
    """

    def __init__(self, **kwargs):
        """Initialize the transcript processor.

        Args:
            **kwargs: Additional arguments passed to FrameProcessor
        """
        super().__init__(**kwargs)
        # Messages already reported through a TranscriptionUpdateFrame.
        self._processed_messages: List[TranscriptionMessage] = []
        self._register_event_handler("on_transcript_update")

    @staticmethod
    def _text_from_content(content):
        """Return the transcript text held by a message's ``content``, or None.

        Args:
            content: Either a plain string or a list of content parts.

        Returns:
            The string itself when ``content`` is a non-empty string; the
            ``text`` values of all ``{"type": "text"}`` parts joined by single
            spaces when ``content`` is a list containing at least one such
            part; None otherwise (e.g. missing content or no text parts).
        """
        if isinstance(content, str):
            return content or None
        if not isinstance(content, list):
            return None

        # Parts without a string "text" value are skipped instead of raising:
        # a malformed part stays in the context, so raising here would make
        # every later update for that context fail as well.
        text_parts = [
            part["text"]
            for part in content
            if isinstance(part, dict)
            and part.get("type") == "text"
            and isinstance(part.get("text"), str)
        ]
        return " ".join(text_parts) if text_parts else None

    def _extract_messages(self, messages: List[dict]) -> List[TranscriptionMessage]:
        """Extract conversation messages from standard format.

        Only messages with role "user" or "assistant" that carry text are kept;
        everything else (system messages, messages without text) is skipped.

        Args:
            messages: List of messages in standard format, with structured or
                plain-string content

        Returns:
            List[TranscriptionMessage]: Normalized conversation messages
        """
        result = []
        for msg in messages:
            role = msg.get("role")
            if role not in ("user", "assistant"):
                continue

            text = self._text_from_content(msg.get("content"))
            if text is not None:
                result.append(TranscriptionMessage(role=role, content=text))

        return result

    def _find_new_messages(self, current: List[TranscriptionMessage]) -> List[TranscriptionMessage]:
        """Find messages in current that aren't in self._processed_messages.

        The comparison is positional: everything in ``current`` beyond the
        number of already processed messages is considered new.

        Args:
            current: List of current messages

        Returns:
            List[TranscriptionMessage]: New messages not yet processed; empty
            when ``current`` is not longer than the processed history
        """
        # Slicing past the end yields [], which covers the "nothing new" case.
        return current[len(self._processed_messages):]

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Process incoming frames, watching for OpenAILLMContextFrame.

        For an OpenAILLMContextFrame, the context messages are converted to
        standard format and any new user/assistant messages are reported via
        the ``on_transcript_update`` event and a pushed TranscriptionUpdateFrame.
        The original frame is always pushed on afterwards.

        No exception from transcript processing propagates to the caller: a
        failure is logged and reported through ``push_error`` with an
        ErrorFrame.

        Args:
            frame: The frame to process
            direction: Frame processing direction
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, OpenAILLMContextFrame):
            try:
                # Convert context messages to standard format; one context
                # message may expand to several standard messages.
                context = frame.context
                standard_messages = [
                    standard
                    for msg in context.messages
                    for standard in context.to_standard_messages(msg)
                ]

                # Extract and process messages
                current_messages = self._extract_messages(standard_messages)
                new_messages = self._find_new_messages(current_messages)

                if new_messages:
                    # Update state and notify listeners
                    self._processed_messages.extend(new_messages)
                    update_frame = TranscriptionUpdateFrame(messages=new_messages)
                    await self._call_event_handler("on_transcript_update", update_frame)
                    await self.push_frame(update_frame)

            except Exception as e:
                logger.error(f"Error processing transcript in {self}: {e}")
                await self.push_error(ErrorFrame(str(e)))

        # Always push the original frame downstream
        await self.push_frame(frame, direction)