Merge pull request #1371 from pipecat-ai/mb/openai-realtime-transcription
Add TranscriptProcessor support for OpenAIRealtimeBetaLLMService
This commit is contained in:
10
CHANGELOG.md
10
CHANGELOG.md
@@ -90,6 +90,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Changed
|
||||
|
||||
- Updated `TranscriptProcessor` to support text output from
|
||||
`OpenAIRealtimeBetaLLMService`.
|
||||
|
||||
- `OpenAIRealtimeBetaLLMService` and `GeminiMultimodalLiveLLMService` now push
|
||||
a `TTSTextFrame`.
|
||||
|
||||
- Updated the default mode for `CartesiaTTSService` and
|
||||
`CartesiaHttpTTSService` to `sonic-2`.
|
||||
|
||||
@@ -120,6 +126,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
- Added a Pipecat Cloud deployment example to the `examples` directory.
|
||||
|
||||
- Removed foundational examples 28b and 28c as the TranscriptProcessor no
|
||||
longer has an LLM depedency. Renamed foundational example 28a to
|
||||
`28-transcript-processor.py`.
|
||||
|
||||
## [0.0.58] - 2025-02-26
|
||||
|
||||
### Added
|
||||
|
||||
@@ -147,8 +147,8 @@ Remember, your responses should be short. Just one or two sentences, usually."""
|
||||
transport.input(), # Transport user input
|
||||
context_aggregator.user(),
|
||||
llm, # LLM
|
||||
context_aggregator.assistant(),
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -1,177 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import TranscriptionMessage, TranscriptionUpdateFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.services.anthropic import AnthropicLLMService
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.deepgram import DeepgramSTTService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
class TranscriptHandler:
|
||||
"""Handles real-time transcript processing and output.
|
||||
|
||||
Maintains a list of conversation messages and outputs them either to a log
|
||||
or to a file as they are received. Each message includes its timestamp and role.
|
||||
|
||||
Attributes:
|
||||
messages: List of all processed transcript messages
|
||||
output_file: Optional path to file where transcript is saved. If None, outputs to log only.
|
||||
"""
|
||||
|
||||
def __init__(self, output_file: Optional[str] = None):
|
||||
"""Initialize handler with optional file output.
|
||||
|
||||
Args:
|
||||
output_file: Path to output file. If None, outputs to log only.
|
||||
"""
|
||||
self.messages: List[TranscriptionMessage] = []
|
||||
self.output_file: Optional[str] = output_file
|
||||
logger.debug(
|
||||
f"TranscriptHandler initialized {'with output_file=' + output_file if output_file else 'with log output only'}"
|
||||
)
|
||||
|
||||
async def save_message(self, message: TranscriptionMessage):
|
||||
"""Save a single transcript message.
|
||||
|
||||
Outputs the message to the log and optionally to a file.
|
||||
|
||||
Args:
|
||||
message: The message to save
|
||||
"""
|
||||
timestamp = f"[{message.timestamp}] " if message.timestamp else ""
|
||||
line = f"{timestamp}{message.role}: {message.content}"
|
||||
|
||||
# Always log the message
|
||||
logger.info(f"Transcript: {line}")
|
||||
|
||||
# Optionally write to file
|
||||
if self.output_file:
|
||||
try:
|
||||
with open(self.output_file, "a", encoding="utf-8") as f:
|
||||
f.write(line + "\n")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving transcript message to file: {e}")
|
||||
|
||||
async def on_transcript_update(
|
||||
self, processor: TranscriptProcessor, frame: TranscriptionUpdateFrame
|
||||
):
|
||||
"""Handle new transcript messages.
|
||||
|
||||
Args:
|
||||
processor: The TranscriptProcessor that emitted the update
|
||||
frame: TranscriptionUpdateFrame containing new messages
|
||||
"""
|
||||
logger.debug(f"Received transcript update with {len(frame.messages)} new messages")
|
||||
|
||||
for msg in frame.messages:
|
||||
self.messages.append(msg)
|
||||
await self.save_message(msg)
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
None,
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
vad_audio_passthrough=True,
|
||||
),
|
||||
)
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
llm = AnthropicLLMService(
|
||||
api_key=os.getenv("ANTHROPIC_API_KEY"), model="claude-3-5-sonnet-20241022"
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative, helpful, and brief way.",
|
||||
},
|
||||
{"role": "user", "content": "Say hello."},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
# Create transcript processor and handler
|
||||
transcript = TranscriptProcessor()
|
||||
transcript_handler = TranscriptHandler() # Output to log only
|
||||
# transcript_handler = TranscriptHandler(output_file="transcript.txt") # Output to file and log
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt, # STT
|
||||
transcript.user(), # User transcripts
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
transcript.assistant(), # Assistant transcripts
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline, params=PipelineParams(allow_interruptions=True))
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
await transport.capture_participant_transcription(participant["id"])
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
# Register event handler for transcript updates
|
||||
@transcript.event_handler("on_transcript_update")
|
||||
async def on_transcript_update(processor, frame):
|
||||
await transcript_handler.on_transcript_update(processor, frame)
|
||||
|
||||
@transport.event_handler("on_participant_left")
|
||||
async def on_participant_left(transport, participant, reason):
|
||||
# Stop the pipeline immediately when the participant leaves
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,210 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sqlite3
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import TranscriptionMessage, TranscriptionUpdateFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.deepgram import DeepgramSTTService
|
||||
from pipecat.services.google import GoogleLLMService
|
||||
from pipecat.services.openai import OpenAILLMContext
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
class TranscriptHandler:
|
||||
"""Handles real-time transcript processing and output.
|
||||
|
||||
Maintains a list of conversation messages and outputs them either to a log
|
||||
or to a file as they are received. Each message includes its timestamp and role.
|
||||
|
||||
Attributes:
|
||||
messages: List of all processed transcript messages
|
||||
output_file: Optional path to file where transcript is saved. If None, outputs to log only.
|
||||
"""
|
||||
|
||||
def __init__(self, output_file: Optional[str] = None, output_db: Optional[str] = None):
|
||||
"""Initialize handler with optional file or database output.
|
||||
|
||||
Args:
|
||||
output_file: Path to output file. If None, outputs to log only.
|
||||
"""
|
||||
self.messages: List[TranscriptionMessage] = []
|
||||
self.output_file: Optional[str] = output_file
|
||||
self.output_db: Optional[str] = output_db
|
||||
|
||||
if self.output_db:
|
||||
self.con = sqlite3.connect("example.db")
|
||||
self.db = self.con.cursor()
|
||||
|
||||
table = self.db.execute("SELECT name FROM sqlite_master WHERE name='messages'")
|
||||
if not (table.fetchone()):
|
||||
self.db.execute(
|
||||
"CREATE TABLE messages(role TEXT, content TEXT, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP )"
|
||||
)
|
||||
logger.debug(
|
||||
f"TranscriptHandler initialized; output file: {output_file}, output DB: {output_db}"
|
||||
)
|
||||
|
||||
async def save_message(self, message: TranscriptionMessage):
|
||||
"""Save a single transcript message.
|
||||
|
||||
Outputs the message to the log and optionally to a SQLite database or file.
|
||||
|
||||
Args:
|
||||
message: The message to save
|
||||
"""
|
||||
timestamp = f"[{message.timestamp}] " if message.timestamp else ""
|
||||
line = f"{timestamp}{message.role}: {message.content}"
|
||||
|
||||
# Always log the message
|
||||
logger.info(f"Transcript: {line}")
|
||||
|
||||
# Optionally write to file
|
||||
if self.output_file:
|
||||
try:
|
||||
with open(self.output_file, "a", encoding="utf-8") as f:
|
||||
f.write(line + "\n")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving transcript message to file: {e}")
|
||||
|
||||
# and/or to a SQLite database
|
||||
if self.output_db:
|
||||
self.db.execute(
|
||||
"INSERT INTO messages VALUES (?, ?, ?)",
|
||||
(message.role, message.content, message.timestamp),
|
||||
)
|
||||
self.con.commit()
|
||||
|
||||
async def on_transcript_update(
|
||||
self, processor: TranscriptProcessor, frame: TranscriptionUpdateFrame
|
||||
):
|
||||
"""Handle new transcript messages.
|
||||
|
||||
Args:
|
||||
processor: The TranscriptProcessor that emitted the update
|
||||
frame: TranscriptionUpdateFrame containing new messages
|
||||
"""
|
||||
logger.debug(f"Received transcript update with {len(frame.messages)} new messages")
|
||||
|
||||
for msg in frame.messages:
|
||||
self.messages.append(msg)
|
||||
await self.save_message(msg)
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
None,
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
vad_audio_passthrough=True,
|
||||
),
|
||||
)
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
llm = GoogleLLMService(
|
||||
model="models/gemini-2.0-flash-exp",
|
||||
# model="gemini-exp-1114",
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative, helpful, and brief way.",
|
||||
},
|
||||
{"role": "user", "content": "Say hello."},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
# Create transcript processor and handler
|
||||
transcript = TranscriptProcessor()
|
||||
# Select a TranscriptHandler output method
|
||||
# Uncomment out only one of the following lines:
|
||||
transcript_handler = TranscriptHandler() # Output to log only
|
||||
# transcript_handler = TranscriptHandler(output_file="transcript.txt") # Output to file and log
|
||||
# transcript_handler = TranscriptHandler(output_db="example.db") # Output to SQLite DB and log
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt, # STT
|
||||
transcript.user(), # User transcripts
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
transcript.assistant(), # Assistant transcripts
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
allow_interruptions=True,
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
await transport.capture_participant_transcription(participant["id"])
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
# Register event handler for transcript updates
|
||||
@transcript.event_handler("on_transcript_update")
|
||||
async def on_transcript_update(processor, frame):
|
||||
await transcript_handler.on_transcript_update(processor, frame)
|
||||
|
||||
@transport.event_handler("on_participant_left")
|
||||
async def on_participant_left(transport, participant, reason):
|
||||
# Stop the pipeline immediately when the participant leaves
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -90,11 +90,62 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
self._aggregation_start_time: Optional[str] = None
|
||||
|
||||
async def _emit_aggregated_text(self):
|
||||
"""Emit aggregated text as a transcript message."""
|
||||
"""Aggregates and emits text fragments as a transcript message.
|
||||
|
||||
This method uses a heuristic to automatically detect whether text fragments
|
||||
use pre-spacing (spaces at the beginning of fragments) or not, and applies
|
||||
the appropriate joining strategy. It handles fragments from different TTS
|
||||
services with different formatting patterns.
|
||||
|
||||
Examples:
|
||||
Pre-spaced fragments (concatenated):
|
||||
```
|
||||
TTSTextFrame: ["Hello"]
|
||||
TTSTextFrame: [" there"]
|
||||
TTSTextFrame: ["!"]
|
||||
TTSTextFrame: [" How"]
|
||||
TTSTextFrame: ["'s"]
|
||||
TTSTextFrame: [" it"]
|
||||
TTSTextFrame: [" going"]
|
||||
TTSTextFrame: ["?"]
|
||||
```
|
||||
Result: "Hello there! How's it going?"
|
||||
|
||||
Word-by-word fragments (joined with spaces):
|
||||
```
|
||||
TTSTextFrame: ["Hello"]
|
||||
TTSTextFrame: ["there!"]
|
||||
TTSTextFrame: ["How"]
|
||||
TTSTextFrame: ["is"]
|
||||
TTSTextFrame: ["it"]
|
||||
TTSTextFrame: ["going?"]
|
||||
```
|
||||
Result: "Hello there! How is it going?"
|
||||
"""
|
||||
if self._current_text_parts and self._aggregation_start_time:
|
||||
content = " ".join(self._current_text_parts).strip()
|
||||
# Heuristic to detect pre-spaced fragments
|
||||
uses_prespacing = False
|
||||
if len(self._current_text_parts) > 1:
|
||||
# Check if any fragment after the first one starts with whitespace
|
||||
has_spaced_parts = any(
|
||||
part and part[0].isspace() for part in self._current_text_parts[1:]
|
||||
)
|
||||
if has_spaced_parts:
|
||||
uses_prespacing = True
|
||||
|
||||
# Apply appropriate joining method
|
||||
if uses_prespacing:
|
||||
# Pre-spaced fragments - just concatenate
|
||||
content = "".join(self._current_text_parts)
|
||||
else:
|
||||
# Word-by-word fragments - join with spaces
|
||||
content = " ".join(self._current_text_parts)
|
||||
|
||||
# Clean up any excessive whitespace
|
||||
content = content.strip()
|
||||
|
||||
if content:
|
||||
logger.debug(f"Emitting aggregated assistant message: {content}")
|
||||
logger.trace(f"Emitting aggregated assistant message: {content}")
|
||||
message = TranscriptionMessage(
|
||||
role="assistant",
|
||||
content=content,
|
||||
@@ -102,7 +153,7 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
)
|
||||
await self._emit_update([message])
|
||||
else:
|
||||
logger.debug("No content to emit after stripping whitespace")
|
||||
logger.trace("No content to emit after stripping whitespace")
|
||||
|
||||
# Reset aggregation state
|
||||
self._current_text_parts = []
|
||||
|
||||
@@ -38,6 +38,7 @@ from pipecat.frames.frames import (
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
TTSTextFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
@@ -312,6 +313,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
# context.add_message({"role": "assistant", "content": [{"type": "text", "text": text}]})
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.push_frame(LLMTextFrame(text=text))
|
||||
await self.push_frame(TTSTextFrame(text=text))
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
async def _transcribe_audio(self, audio, context):
|
||||
|
||||
@@ -43,6 +43,7 @@ from pipecat.frames.frames import (
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
TTSTextFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
@@ -471,6 +472,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
async def _handle_evt_audio_transcript_delta(self, evt):
|
||||
if evt.delta:
|
||||
await self.push_frame(LLMTextFrame(evt.delta))
|
||||
await self.push_frame(TTSTextFrame(evt.delta))
|
||||
|
||||
async def _handle_evt_speech_started(self, evt):
|
||||
await self._truncate_current_audio_response()
|
||||
|
||||
@@ -235,8 +235,7 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(sleep=0.1),
|
||||
TTSTextFrame(text="Hello"),
|
||||
TTSTextFrame(text="world"),
|
||||
TTSTextFrame(text="!"),
|
||||
TTSTextFrame(text="world!"),
|
||||
SleepFrame(sleep=0.1),
|
||||
StartInterruptionFrame(), # User interrupts here
|
||||
BotStartedSpeakingFrame(),
|
||||
@@ -251,8 +250,7 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
expected_down_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
TTSTextFrame, # "Hello"
|
||||
TTSTextFrame, # "world"
|
||||
TTSTextFrame, # "!"
|
||||
TTSTextFrame, # "world!"
|
||||
TranscriptionUpdateFrame, # First message (emitted due to interruption)
|
||||
StartInterruptionFrame, # Interruption frame comes after the update
|
||||
BotStartedSpeakingFrame,
|
||||
@@ -275,7 +273,7 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
# First update should be interrupted message
|
||||
first_message = received_updates[0].messages[0]
|
||||
self.assertEqual(first_message.role, "assistant")
|
||||
self.assertEqual(first_message.content, "Hello world !")
|
||||
self.assertEqual(first_message.content, "Hello world!")
|
||||
self.assertIsNotNone(first_message.timestamp)
|
||||
|
||||
# Second update should be new response
|
||||
@@ -426,3 +424,57 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(received_updates[0].content, "User message")
|
||||
self.assertEqual(received_updates[1].role, "assistant")
|
||||
self.assertEqual(received_updates[1].content, "Assistant message")
|
||||
|
||||
async def test_text_fragments_with_spaces(self):
|
||||
"""Test aggregating text fragments with various spacing patterns"""
|
||||
processor = AssistantTranscriptProcessor()
|
||||
|
||||
# Track received updates
|
||||
received_updates = []
|
||||
|
||||
@processor.event_handler("on_transcript_update")
|
||||
async def handle_update(proc, frame: TranscriptionUpdateFrame):
|
||||
received_updates.append(frame)
|
||||
|
||||
# Test the specific pattern shared
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(sleep=0.1),
|
||||
TTSTextFrame(text="Hello"),
|
||||
TTSTextFrame(text=" there"),
|
||||
TTSTextFrame(text="!"),
|
||||
TTSTextFrame(text=" How"),
|
||||
TTSTextFrame(text="'s"),
|
||||
TTSTextFrame(text=" it"),
|
||||
TTSTextFrame(text=" going"),
|
||||
TTSTextFrame(text="?"),
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
|
||||
expected_down_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
TTSTextFrame,
|
||||
TTSTextFrame,
|
||||
TTSTextFrame,
|
||||
TTSTextFrame,
|
||||
TTSTextFrame,
|
||||
TTSTextFrame,
|
||||
TTSTextFrame,
|
||||
TTSTextFrame,
|
||||
TranscriptionUpdateFrame,
|
||||
]
|
||||
|
||||
# Run test
|
||||
received_frames, _ = await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
# Verify result
|
||||
self.assertEqual(len(received_updates), 1)
|
||||
message = received_updates[0].messages[0]
|
||||
self.assertEqual(message.role, "assistant")
|
||||
# Should be properly joined without extra spaces
|
||||
self.assertEqual(message.content, "Hello there! How's it going?")
|
||||
|
||||
Reference in New Issue
Block a user