Add TranscriptionProcessor
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
@@ -14,12 +15,13 @@ from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.frames.frames import Frame, LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.anthropic import AnthropicLLMService
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.anthropic import AnthropicLLMContext, AnthropicLLMService
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
|
||||
@@ -29,6 +31,28 @@ logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
class TestAnthropicLLMService(AnthropicLLMService):
    """Anthropic LLM service that logs a message-format round trip.

    For every ``LLMMessagesFrame`` it logs the incoming messages, the messages
    of the ``AnthropicLLMContext`` built from them, and the list obtained by
    passing each of those messages back through ``to_standard_messages``. The
    frame itself is then handed to the parent implementation unchanged.
    """

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Log the conversion round trip for message frames, then delegate.

        Args:
            frame: The frame travelling through the pipeline.
            direction: The direction the frame is travelling in.
        """
        if isinstance(frame, LLMMessagesFrame):
            incoming = frame.messages
            logger.info("Original OpenAI format messages:")
            logger.info(incoming)

            # Incoming messages -> Anthropic context
            anthropic_context = AnthropicLLMContext.from_messages(incoming)
            logger.info("Converted to Anthropic format:")
            logger.info(anthropic_context.messages)

            # Anthropic context -> one flat list of standard-format messages.
            # Each context message may expand to several standard messages.
            round_tripped = [
                standard
                for message in anthropic_context.messages
                for standard in anthropic_context.to_standard_messages(message)
            ]
            logger.info("Converted back to OpenAI format:")
            logger.info(round_tripped)

        await super().process_frame(frame, direction)
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
@@ -50,18 +74,24 @@ async def main():
|
||||
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
|
||||
)
|
||||
|
||||
llm = AnthropicLLMService(
|
||||
llm = TestAnthropicLLMService(
|
||||
api_key=os.getenv("ANTHROPIC_API_KEY"), model="claude-3-opus-20240229"
|
||||
)
|
||||
|
||||
# todo: think more about how to handle system prompts in a more general way. OpenAI,
|
||||
# Google, and Anthropic all have slightly different approaches to providing a system
|
||||
# prompt.
|
||||
# Test messages including various formats
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative, helpful, and brief way. Say hello.",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Hello! How can I help you today?"},
|
||||
{"type": "text", "text": "I'm ready to assist."},
|
||||
],
|
||||
},
|
||||
{"role": "user", "content": "Hi there!"},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages)
|
||||
|
||||
128
examples/foundational/28a-transcription-update-openai.py
Normal file
128
examples/foundational/28a-transcription-update-openai.py
Normal file
@@ -0,0 +1,128 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMMessagesFrame, TranscriptionMessage, TranscriptionUpdateFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
class TranscriptHandler:
    """Demo sink that accumulates transcript messages and logs them."""

    def __init__(self):
        # Every message received so far, in arrival order.
        self.messages: List[TranscriptionMessage] = []

    async def on_transcript_update(
        self, processor: TranscriptProcessor, frame: TranscriptionUpdateFrame
    ):
        """Record the messages carried by ``frame`` and log the transcript.

        Args:
            processor: The processor that emitted the update (unused here).
            frame: Update frame whose ``messages`` are the newly added entries.
        """
        incoming = frame.messages
        self.messages.extend(incoming)

        # First the delta, then the whole history, each under its own heading.
        sections = (
            ("New transcript messages:", incoming),
            ("Full transcript:", self.messages),
        )
        for heading, entries in sections:
            logger.info(heading)
            for entry in entries:
                logger.info(f"{entry.role}: {entry.content}")
|
||||
|
||||
|
||||
async def main():
    """Run the OpenAI transcript-update example bot.

    Builds a Daily transport, a Cartesia TTS service and an OpenAI LLM
    service, wires them into a pipeline that ends with a
    ``TranscriptProcessor``, and runs that pipeline until the runner returns.
    Transcript updates are forwarded to a ``TranscriptHandler``.
    """
    async with aiohttp.ClientSession() as session:
        # `configure` is provided by the example runner module; its result is
        # unpacked into the room URL and token handed to the transport.
        (room_url, token) = await configure(session)

        transport = DailyTransport(
            room_url,
            token,
            "Respond bot",
            DailyParams(
                audio_out_enabled=True,
                transcription_enabled=True,
                vad_enabled=True,
                vad_analyzer=SileroVADAnalyzer(),
            ),
        )

        tts = CartesiaTTSService(
            api_key=os.getenv("CARTESIA_API_KEY"),
            voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
        )

        llm = OpenAILLMService(
            api_key=os.getenv("OPENAI_API_KEY"),
            model="gpt-4o",
        )

        # Initial conversation: a single system prompt that also asks the bot
        # to greet the user.
        messages = [
            {
                "role": "system",
                "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative, helpful, and brief way. Say hello.",
            },
        ]

        context = OpenAILLMContext(messages)
        context_aggregator = llm.create_context_aggregator(context)

        # Create transcript processor and handler
        transcript_processor = TranscriptProcessor()
        transcript_handler = TranscriptHandler()

        # Register event handler for transcript updates
        @transcript_processor.event_handler("on_transcript_update")
        async def on_transcript_update(processor, frame):
            await transcript_handler.on_transcript_update(processor, frame)

        # The transcript processor sits last, after the assistant aggregator.
        pipeline = Pipeline(
            [
                transport.input(),  # Transport user input
                context_aggregator.user(),  # User responses
                llm,  # LLM
                tts,  # TTS
                transport.output(),  # Transport bot output
                context_aggregator.assistant(),  # Assistant spoken responses
                transcript_processor,  # Process transcripts
            ]
        )

        task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True))

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            # Start capturing this participant's transcription before the
            # first LLM turn is queued.
            await transport.capture_participant_transcription(participant["id"])
            # Kick off the conversation.
            await task.queue_frames([LLMMessagesFrame(messages)])

        runner = PipelineRunner()

        await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
128
examples/foundational/28b-transcription-update-anthropic.py
Normal file
128
examples/foundational/28b-transcription-update-anthropic.py
Normal file
@@ -0,0 +1,128 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMMessagesFrame, TranscriptionMessage, TranscriptionUpdateFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.services.anthropic import AnthropicLLMService
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
class TranscriptHandler:
    """Demo sink that accumulates transcript messages and logs them."""

    def __init__(self):
        # Every message received so far, in arrival order.
        self.messages: List[TranscriptionMessage] = []

    async def on_transcript_update(
        self, processor: TranscriptProcessor, frame: TranscriptionUpdateFrame
    ):
        """Record the messages carried by ``frame`` and log the transcript.

        Args:
            processor: The processor that emitted the update (unused here).
            frame: Update frame whose ``messages`` are the newly added entries.
        """
        fresh = frame.messages
        self.messages.extend(fresh)

        # Log the delta first, then the complete history.
        for title, batch in (
            ("New transcript messages:", fresh),
            ("Full transcript:", self.messages),
        ):
            logger.info(title)
            for item in batch:
                logger.info(f"{item.role}: {item.content}")
|
||||
|
||||
|
||||
async def main():
    """Run the Anthropic transcript-update example bot.

    Builds a Daily transport, a Cartesia TTS service and an Anthropic LLM
    service, wires them into a pipeline that ends with a
    ``TranscriptProcessor``, and runs that pipeline until the runner returns.
    Transcript updates are forwarded to a ``TranscriptHandler``.
    """
    async with aiohttp.ClientSession() as session:
        # `configure` is provided by the example runner module; its result is
        # unpacked into the room URL and token handed to the transport.
        (room_url, token) = await configure(session)

        transport = DailyTransport(
            room_url,
            token,
            "Respond bot",
            DailyParams(
                audio_out_enabled=True,
                transcription_enabled=True,
                vad_enabled=True,
                vad_analyzer=SileroVADAnalyzer(),
            ),
        )

        tts = CartesiaTTSService(
            api_key=os.getenv("CARTESIA_API_KEY"),
            voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
        )

        llm = AnthropicLLMService(
            api_key=os.getenv("ANTHROPIC_API_KEY"), model="claude-3-5-sonnet-20241022"
        )

        # Initial conversation: a system prompt followed by a user message
        # asking for a greeting.
        messages = [
            {
                "role": "system",
                "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative, helpful, and brief way.",
            },
            {"role": "user", "content": "Say hello."},
        ]

        context = OpenAILLMContext(messages)
        context_aggregator = llm.create_context_aggregator(context)

        # Create transcript processor and handler
        transcript_processor = TranscriptProcessor()
        transcript_handler = TranscriptHandler()

        # Register event handler for transcript updates
        @transcript_processor.event_handler("on_transcript_update")
        async def on_transcript_update(processor, frame):
            await transcript_handler.on_transcript_update(processor, frame)

        # The transcript processor sits last, after the assistant aggregator.
        pipeline = Pipeline(
            [
                transport.input(),  # Transport user input
                context_aggregator.user(),  # User responses
                llm,  # LLM
                tts,  # TTS
                transport.output(),  # Transport bot output
                context_aggregator.assistant(),  # Assistant spoken responses
                transcript_processor,  # Process transcripts
            ]
        )

        task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True))

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            # Start capturing this participant's transcription before the
            # first LLM turn is queued.
            await transport.capture_participant_transcription(participant["id"])
            # Kick off the conversation.
            await task.queue_frames([LLMMessagesFrame(messages)])

        runner = PipelineRunner()

        await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
138
examples/foundational/28c-transcription-update-gemini.py
Normal file
138
examples/foundational/28c-transcription-update-gemini.py
Normal file
@@ -0,0 +1,138 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import TranscriptionMessage, TranscriptionUpdateFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.google import GoogleLLMService
|
||||
from pipecat.services.openai import OpenAILLMContext
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
class TranscriptHandler:
    """Demo sink that accumulates transcript messages and logs them."""

    def __init__(self):
        # Every message received so far, in arrival order.
        self.messages: List[TranscriptionMessage] = []

    async def on_transcript_update(
        self, processor: TranscriptProcessor, frame: TranscriptionUpdateFrame
    ):
        """Record the messages carried by ``frame`` and log the transcript.

        Args:
            processor: The processor that emitted the update (unused here).
            frame: Update frame whose ``messages`` are the newly added entries.
        """
        added = frame.messages
        self.messages.extend(added)

        # Same output format for both listings: "<role>: <content>".
        listings = (
            ("New transcript messages:", added),
            ("Full transcript:", self.messages),
        )
        for caption, rows in listings:
            logger.info(caption)
            for row in rows:
                logger.info(f"{row.role}: {row.content}")
|
||||
|
||||
|
||||
async def main():
    """Run the Gemini transcript-update example bot.

    Builds a Daily transport, a Cartesia TTS service and a Google LLM
    service, wires them into a pipeline that ends with a
    ``TranscriptProcessor``, and runs that pipeline (with metrics enabled)
    until the runner returns. Transcript updates are forwarded to a
    ``TranscriptHandler``.
    """
    async with aiohttp.ClientSession() as session:
        # `configure` is provided by the example runner module; its result is
        # unpacked into the room URL and token handed to the transport.
        (room_url, token) = await configure(session)

        transport = DailyTransport(
            room_url,
            token,
            "Respond bot",
            DailyParams(
                audio_out_enabled=True,
                transcription_enabled=True,
                vad_enabled=True,
                vad_analyzer=SileroVADAnalyzer(),
            ),
        )

        tts = CartesiaTTSService(
            api_key=os.getenv("CARTESIA_API_KEY"),
            voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
        )

        llm = GoogleLLMService(
            model="models/gemini-2.0-flash-exp",
            # model="gemini-exp-1114",
            api_key=os.getenv("GOOGLE_API_KEY"),
        )

        # Initial conversation: a system prompt followed by a user message
        # asking for a greeting.
        messages = [
            {
                "role": "system",
                "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative, helpful, and brief way.",
            },
            {"role": "user", "content": "Say hello."},
        ]

        context = OpenAILLMContext(messages)
        context_aggregator = llm.create_context_aggregator(context)

        # Create transcript processor and handler
        transcript_processor = TranscriptProcessor()
        transcript_handler = TranscriptHandler()

        # Register event handler for transcript updates
        @transcript_processor.event_handler("on_transcript_update")
        async def on_transcript_update(processor, frame):
            await transcript_handler.on_transcript_update(processor, frame)

        # Processing order: transport input -> user aggregator -> LLM -> TTS
        # -> transport output -> assistant aggregator -> transcript processor.
        pipeline = Pipeline(
            [
                transport.input(),
                context_aggregator.user(),
                llm,
                tts,
                transport.output(),
                context_aggregator.assistant(),
                transcript_processor,
            ]
        )

        task = PipelineTask(
            pipeline,
            PipelineParams(
                allow_interruptions=True,
                enable_metrics=True,
                enable_usage_metrics=True,
            ),
        )

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            # Start capturing this participant's transcription before the
            # first LLM turn is queued.
            await transport.capture_participant_transcription(participant["id"])
            # Kick off the conversation.
            await task.queue_frames([context_aggregator.user().get_context_frame()])

        runner = PipelineRunner()

        await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -5,7 +5,7 @@
|
||||
#
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, List, Mapping, Optional, Tuple
|
||||
from typing import Any, List, Literal, Mapping, Optional, Tuple, TypeAlias
|
||||
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.clocks.base_clock import BaseClock
|
||||
@@ -195,7 +195,8 @@ class TranscriptionFrame(TextFrame):
|
||||
@dataclass
|
||||
class InterimTranscriptionFrame(TextFrame):
|
||||
"""A text frame with interim transcription-specific data. Will be placed in
|
||||
the transport's receive queue when a participant speaks."""
|
||||
the transport's receive queue when a participant speaks.
|
||||
"""
|
||||
|
||||
text: str
|
||||
user_id: str
|
||||
@@ -206,6 +207,34 @@ class InterimTranscriptionFrame(TextFrame):
|
||||
return f"{self.name}(user: {self.user_id}, text: [{self.text}], language: {self.language}, timestamp: {self.timestamp})"
|
||||
|
||||
|
||||
@dataclass
class TranscriptionMessage:
    """A single entry of a conversation transcript.

    Entries use the standard message shape, with the speaker role normalized
    to either ``"user"`` or ``"assistant"``.

    Attributes:
        role: Who produced the message, ``"user"`` or ``"assistant"``.
        content: The text of the message.
        timestamp: Optional timestamp string for the message; ``None`` when
            not provided.
    """

    role: Literal["user", "assistant"]
    content: str
    timestamp: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
class TranscriptionUpdateFrame(DataFrame):
    """Frame carrying the messages newly added to a conversation transcript.

    Emitted whenever the conversation history gains messages. It holds only
    the additions, not the complete transcript, and message roles are
    normalized to user/assistant regardless of the LLM service in use.

    Attributes:
        messages: The transcript messages added since the previous update.
    """

    messages: List[TranscriptionMessage]

    def __str__(self):
        count = len(self.messages)
        return f"{self.name}(pts: {format_pts(self.pts)}, messages: {count})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMMessagesFrame(DataFrame):
|
||||
"""A frame containing a list of LLM messages. Used to signal that an LLM
|
||||
@@ -546,7 +575,8 @@ class EndFrame(ControlFrame):
|
||||
@dataclass
|
||||
class LLMFullResponseStartFrame(ControlFrame):
|
||||
"""Used to indicate the beginning of an LLM response. Following by one or
|
||||
more TextFrame and a final LLMFullResponseEndFrame."""
|
||||
more TextFrame and a final LLMFullResponseEndFrame.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
150
src/pipecat/processors/transcript_processor.py
Normal file
150
src/pipecat/processors/transcript_processor.py
Normal file
@@ -0,0 +1,150 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from typing import List
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TranscriptionMessage,
|
||||
TranscriptionUpdateFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class TranscriptProcessor(FrameProcessor):
    """Processes LLM context frames to generate conversation transcripts.

    This processor monitors OpenAILLMContextFrame frames and extracts conversation
    content, keeping only the text of user and assistant messages (system
    messages and non-text content parts are ignored). When new messages are
    detected, it emits a TranscriptionUpdateFrame containing only the new
    messages.

    Each context message is passed through the context's
    ``to_standard_messages`` and is expected to come back in the standard format:
    [
        {
            "role": "user",
            "content": [{"type": "text", "text": "Hi, how are you?"}]
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": "Great! And you?"}]
        }
    ]

    Messages whose ``content`` is a plain string are accepted as well and
    treated as a single text part.

    Events:
        on_transcript_update: Emitted when new transcript messages are available.
            Args: TranscriptionUpdateFrame containing new messages.

    Example:
        ```python
        transcript_processor = TranscriptProcessor()

        @transcript_processor.event_handler("on_transcript_update")
        async def on_transcript_update(processor, frame):
            for msg in frame.messages:
                print(f"{msg.role}: {msg.content}")
        ```
    """

    def __init__(self, **kwargs):
        """Initialize the transcript processor.

        Args:
            **kwargs: Additional arguments passed to FrameProcessor
        """
        super().__init__(**kwargs)
        # Transcript messages already reported, in conversation order.
        self._processed_messages: List[TranscriptionMessage] = []
        self._register_event_handler("on_transcript_update")

    def _extract_messages(self, messages: List[dict]) -> List[TranscriptionMessage]:
        """Extract conversation messages from standard format.

        Only ``user`` and ``assistant`` messages that carry text are kept.
        ``content`` may be a list of parts (the text of every
        ``{"type": "text"}`` part is joined with single spaces) or a plain
        non-empty string. Entries without a role, without text, or with
        malformed parts are skipped instead of raising.

        Args:
            messages: List of messages in standard format

        Returns:
            List[TranscriptionMessage]: Normalized conversation messages
        """
        result = []
        for msg in messages:
            # Only process user and assistant messages; a missing role is
            # treated like any other non-conversation message.
            role = msg.get("role")
            if role not in ("user", "assistant"):
                continue

            content = msg.get("content")
            if isinstance(content, str):
                # Plain-string content counts as one text part.
                text_parts = [content] if content else []
            elif isinstance(content, list):
                # Structured content: keep well-formed text parts only.
                text_parts = [
                    part["text"]
                    for part in content
                    if isinstance(part, dict)
                    and part.get("type") == "text"
                    and isinstance(part.get("text"), str)
                ]
            else:
                text_parts = []

            if text_parts:
                result.append(TranscriptionMessage(role=role, content=" ".join(text_parts)))

        return result

    def _find_new_messages(self, current: List[TranscriptionMessage]) -> List[TranscriptionMessage]:
        """Find messages in current that aren't in self._processed_messages.

        The comparison is positional: ``current`` is assumed to start with the
        messages already processed, so everything past that count is new.
        Message contents are not compared, and a ``current`` list that is not
        longer than the processed history yields no new messages.

        Args:
            current: List of current messages

        Returns:
            List[TranscriptionMessage]: New messages not yet processed
        """
        # Slicing past the end yields [], which covers the "nothing new" case.
        return current[len(self._processed_messages) :]

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Process incoming frames, watching for OpenAILLMContextFrame.

        For a context frame, the context messages are converted to standard
        format and any messages not yet reported are sent to the
        ``on_transcript_update`` handlers and pushed as a
        TranscriptionUpdateFrame. If that processing fails, the error is
        logged and an ErrorFrame is pushed via ``push_error``; no exception is
        propagated from that step. The original frame is always pushed on in
        its original direction.

        Args:
            frame: The frame to process
            direction: Frame processing direction
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, OpenAILLMContextFrame):
            try:
                # Convert context messages to standard format; one context
                # message may expand to several standard messages.
                context = frame.context
                standard_messages = [
                    standard
                    for msg in context.messages
                    for standard in context.to_standard_messages(msg)
                ]

                # Extract and process messages
                current_messages = self._extract_messages(standard_messages)
                new_messages = self._find_new_messages(current_messages)

                if new_messages:
                    # Update state and notify listeners
                    self._processed_messages.extend(new_messages)
                    update_frame = TranscriptionUpdateFrame(messages=new_messages)
                    await self._call_event_handler("on_transcript_update", update_frame)
                    await self.push_frame(update_frame)

            except Exception as e:
                # Boundary handler: a transcript failure must not stop the
                # original frame from flowing; keep the traceback in the log.
                logger.exception(f"Error processing transcript in {self}: {e}")
                await self.push_error(ErrorFrame(str(e)))

        # Always push the original frame downstream
        await self.push_frame(frame, direction)
|
||||
Reference in New Issue
Block a user