Compare commits

3 Commits

Author SHA1 Message Date
James Hush
8bbfa829d3 Remove wait 2025-11-26 12:27:02 +01:00
James Hush
c2eb663bdc Add TurnAwareTranscriptProcessor for turn-based transcript tracking
- Implements TurnAwareTranscriptProcessor that combines user and assistant transcript tracking with turn boundary detection
- Correctly handles interruptions by capturing only what was actually spoken
- Emits on_turn_started and on_turn_ended events with accumulated transcripts
- Handles async frame processing with strategic delays to ensure proper text accumulation
- Adds comprehensive tests covering basic flow, interruptions, and multiple turns
- Includes documentation and usage examples
2025-11-26 12:26:25 +01:00
James Hush
bf055843e6 Fix race condition in DeepgramFluxSTTService reconnection
Moved _receive_task and _watchdog_task creation from _connect_websocket() to _connect() to prevent multiple coroutines from attempting to receive from the websocket simultaneously during reconnection.

Previously, when reconnection occurred, _connect_websocket() would be called while the existing _receive_task was still running, causing both to try to receive from the websocket. This resulted in the error: 'cannot call recv while another coroutine is already running recv or recv_streaming'.

Now tasks are created only once during initial connection, and reconnection only re-establishes the websocket connection itself. This matches the pattern used by other websocket services in the codebase.

Fixes issue reported in 0.0.95 where reconnection attempts would fail with recv errors.
2025-11-26 10:11:19 +01:00
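The reconnection fix described in commit bf055843e6 can be illustrated with a small, framework-free sketch. It is not the `DeepgramFluxSTTService` code: `ReconnectingClient`, `socket_generation` and `tasks_created` are hypothetical names, and the socket is simulated with a counter. It shows only the pattern the commit message describes, where the long-lived receive task is created once in `connect()` and reconnection merely re-establishes the socket.

```python
import asyncio
from typing import Optional, Tuple


class ReconnectingClient:
    """Hypothetical client: long-lived tasks live in connect(), not in _connect_socket()."""

    def __init__(self) -> None:
        self.socket_generation = 0  # incremented each time the socket is (re)opened
        self.tasks_created = 0  # number of receive tasks ever spawned
        self._receive_task: Optional[asyncio.Task] = None

    async def _connect_socket(self) -> None:
        # Only (re)establishes the connection; it never spawns tasks.
        self.socket_generation += 1

    async def _receive_loop(self) -> None:
        # Placeholder for the single reader; a real loop would call recv() here.
        while True:
            await asyncio.sleep(3600)

    async def connect(self) -> None:
        await self._connect_socket()
        # Guarded creation: at most one receive task exists at a time.
        if self._receive_task is None:
            self._receive_task = asyncio.create_task(self._receive_loop())
            self.tasks_created += 1

    async def reconnect(self) -> None:
        # Reconnection swaps the socket underneath the existing reader.
        await self._connect_socket()

    async def disconnect(self) -> None:
        if self._receive_task is not None:
            self._receive_task.cancel()
            try:
                await self._receive_task
            except asyncio.CancelledError:
                pass
            self._receive_task = None


async def demo() -> Tuple[int, int]:
    client = ReconnectingClient()
    await client.connect()
    await client.reconnect()
    await client.reconnect()
    result = (client.socket_generation, client.tasks_created)
    await client.disconnect()
    return result


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Because task creation sits outside the socket-level method, any number of reconnections leaves exactly one reader, which is what avoids two coroutines calling `recv` on the same websocket.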
12 changed files with 638 additions and 82 deletions

@@ -9,9 +9,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added `cache_read_input_tokens`, `cache_creation_input_tokens` and
`reasoning_tokens` to OTel spans for LLM call
- Added `LiveKitRESTHelper` utility class for managing LiveKit rooms via REST API.
- Added `DeepgramSageMakerSTTService` which connects to a SageMaker hosted

@@ -0,0 +1,103 @@
# TurnAwareTranscriptProcessor Example
## Overview
The `TurnAwareTranscriptProcessor` combines user and assistant transcript tracking with turn boundary detection. It correctly handles interruptions by only capturing what was actually spoken.
## Basic Usage
```python
from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.transcript_processor import TurnAwareTranscriptProcessor

# Create the processor
turn_processor = TurnAwareTranscriptProcessor()


# Register event handlers
@turn_processor.event_handler("on_turn_started")
async def handle_turn_started(processor, turn_number):
    print(f"Turn {turn_number} started")


@turn_processor.event_handler("on_turn_ended")
async def handle_turn_ended(processor, turn_number, user_text, assistant_text, was_interrupted):
    print(f"\nTurn {turn_number} ended:")
    print(f"  User said: {user_text}")
    print(f"  Assistant said: {assistant_text}")
    print(f"  Was interrupted: {was_interrupted}")


@turn_processor.event_handler("on_transcript_update")
async def handle_transcript_update(processor, frame):
    for msg in frame.messages:
        print(f"[{msg.role}]: {msg.content}")


# Add to pipeline (transport, stt, context_aggregator, llm and tts are created elsewhere)
pipeline = Pipeline([
    transport.input(),
    stt,
    turn_processor,  # Process transcripts and track turns
    context_aggregator.user(),
    llm,
    tts,
    transport.output(),
    context_aggregator.assistant(),
])
```
## Features
1. **Turn Boundary Detection**: Automatically detects when turns start and end based on user and bot speaking patterns
2. **Interruption Handling**: Correctly captures only what was actually spoken when interruptions occur
3. **Real-time Transcripts**: Emits transcript messages for both user and assistant speech
4. **Turn Events**: Provides start/end events with accumulated transcripts for each turn
## Events
### on_turn_started
Emitted when a new turn begins (user starts speaking).
**Handler signature**: `async def handler(processor, turn_number)`
### on_turn_ended
Emitted when a turn ends with accumulated transcripts.
**Handler signature**: `async def handler(processor, turn_number, user_transcript, assistant_transcript, was_interrupted)`
### on_transcript_update
Inherited from `BaseTranscriptProcessor`, emitted for individual transcript messages.
**Handler signature**: `async def handler(processor, frame)`
## Turn Logic
- Turns start when the user begins speaking (`UserStartedSpeakingFrame`)
- Turns end when:
- The user starts speaking again (previous turn ends, new turn starts)
- The bot is interrupted (`InterruptionFrame`)
- The pipeline ends (`EndFrame`/`CancelFrame`)
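The rules above can be modelled without any pipecat dependency. The following is a minimal sketch, not the processor's implementation: `TurnTracker` and its string event names are hypothetical stand-ins for the frame types, and it treats every pipeline end as a normal (non-interrupted) end for simplicity.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class TurnTracker:
    """Hypothetical, framework-free model of the turn rules listed above."""

    turn_number: int = 0
    active: bool = False
    bot_speaking: bool = False
    user_parts: List[str] = field(default_factory=list)
    assistant_parts: List[str] = field(default_factory=list)
    ended: List[Tuple[int, str, str, bool]] = field(default_factory=list)

    def _end(self, interrupted: bool) -> None:
        if not self.active:
            return
        self.ended.append(
            (self.turn_number, " ".join(self.user_parts), "".join(self.assistant_parts), interrupted)
        )
        self.active = False
        self.user_parts, self.assistant_parts = [], []

    def handle(self, event: str, text: str = "") -> None:
        if event == "user_started":
            # Speaking over the bot ends the previous turn as interrupted.
            self._end(interrupted=self.bot_speaking)
            self.bot_speaking = False
            self.turn_number += 1
            self.active = True
        elif event == "user_text" and self.active:
            self.user_parts.append(text)
        elif event == "bot_started":
            self.bot_speaking = True
        elif event == "bot_text" and self.active:
            self.assistant_parts.append(text)
        elif event == "bot_stopped":
            # Does not end the turn; the next user speech or pipeline end does.
            self.bot_speaking = False
        elif event == "interruption":
            self._end(interrupted=True)
            self.bot_speaking = False
        elif event == "end":
            self._end(interrupted=False)


tracker = TurnTracker()
for event, text in [
    ("user_started", ""), ("user_text", "Hi"),
    ("bot_started", ""), ("bot_text", "Hello"), ("bot_stopped", ""),
    ("user_started", ""), ("user_text", "Tell me"),
    ("bot_started", ""), ("bot_text", "Sure"), ("bot_text", " I"),
    ("interruption", ""),
    ("end", ""),
]:
    tracker.handle(event, text)

print(tracker.ended)
```

In this run the first turn ends normally when the user speaks again, and the second ends as interrupted with only the assistant text accumulated before the interruption (`"Sure I"`).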
## Integration with OpenTelemetry
You can use turn events to enrich OpenTelemetry spans:
```python
from opentelemetry import trace

# TurnTrackingObserver import path assumed from the pipecat source tree
from pipecat.observers.turn_tracking_observer import TurnTrackingObserver
from pipecat.processors.transcript_processor import TurnAwareTranscriptProcessor
from pipecat.utils.tracing.turn_trace_observer import TurnTraceObserver

turn_tracker = TurnTrackingObserver()
turn_tracer = TurnTraceObserver(turn_tracker)
turn_processor = TurnAwareTranscriptProcessor()


@turn_processor.event_handler("on_turn_ended")
async def add_transcripts_to_span(processor, turn_number, user_text, assistant_text, interrupted):
    # get_current_span() always returns a span object, so check that it is recording
    current_span = trace.get_current_span()
    if current_span.is_recording():
        current_span.set_attribute("turn.user_text", user_text)
        current_span.set_attribute("turn.assistant_text", assistant_text)
```
## Notes
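- The processor does not end a turn on `BotStoppedSpeakingFrame`; it defers the turn end until the next `UserStartedSpeakingFrame`, an `InterruptionFrame`, or pipeline end, so `TTSTextFrame`s still in flight are accumulated into the turn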
- Works with word-level timestamps from TTS services like Cartesia
- Accumulates both user (`TranscriptionFrame`) and assistant (`TTSTextFrame`) speech
- Emits individual transcript messages in addition to turn-level aggregation

@@ -50,14 +50,25 @@ import aiofiles
from dotenv import load_dotenv
from loguru import logger
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import LLMRunFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
load_dotenv(override=True)
@@ -83,10 +94,20 @@ transport_params = {
"daily": lambda: DailyParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
),
"twilio": lambda: FastAPIWebsocketParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
),
"webrtc": lambda: TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
),
}
@@ -94,13 +115,38 @@ transport_params = {
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info(f"Starting bot")
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"), audio_passthrough=True)
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121",
)
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4")
# Create audio buffer processor
audiobuffer = AudioBufferProcessor(sample_rate=48000)
audiobuffer = AudioBufferProcessor()
messages = [
{
"role": "system",
"content": "You are a helpful assistant demonstrating audio recording capabilities. Keep your responses brief and clear.",
},
]
context = LLMContext(messages)
context_aggregator = LLMContextAggregatorPair(context)
pipeline = Pipeline(
[
transport.input(),
stt,
context_aggregator.user(),
llm,
tts,
transport.output(),
audiobuffer, # Add audio buffer to pipeline
context_aggregator.assistant(),
]
)
@@ -109,8 +155,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
params=PipelineParams(
enable_metrics=True,
enable_usage_metrics=True,
audio_in_sample_rate=48000,
audio_out_sample_rate= 48000
),
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
)
@@ -121,7 +165,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
# Start recording audio
await audiobuffer.start_recording()
# Start conversation - empty prompt to let LLM follow system instructions
# await task.queue_frames([LLMRunFrame()])
await task.queue_frames([LLMRunFrame()])
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):

@@ -15,6 +15,7 @@ from typing import List, Optional
from loguru import logger
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
@@ -24,6 +25,7 @@ from pipecat.frames.frames import (
TranscriptionMessage,
TranscriptionUpdateFrame,
TTSTextFrame,
UserStartedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.string import TextPartForConcatenation, concatenate_aggregated_text
@@ -306,3 +308,267 @@ class TranscriptProcessor:
return handler
return decorator
class TurnAwareTranscriptProcessor(BaseTranscriptProcessor):
"""Processes transcripts with turn boundary awareness.
This processor combines user and assistant transcript tracking with turn
detection, emitting events when turns start and end. It correctly handles
interruptions by only capturing what was actually spoken.
Turn boundaries are detected based on:
- User started speaking (UserStartedSpeakingFrame)
- Bot stopped speaking (BotStoppedSpeakingFrame)
- Interruptions (InterruptionFrame)
Events:
on_turn_started: Emitted when a new turn begins.
Handler signature: async def handler(processor, turn_number)
on_turn_ended: Emitted when a turn ends.
Handler signature: async def handler(processor, turn_number,
user_transcript, assistant_transcript,
was_interrupted)
on_transcript_update: Inherited from BaseTranscriptProcessor, emitted for
individual transcript messages.
Example::
turn_processor = TurnAwareTranscriptProcessor()
@turn_processor.event_handler("on_turn_started")
async def handle_turn_started(processor, turn_number):
print(f"Turn {turn_number} started")
@turn_processor.event_handler("on_turn_ended")
async def handle_turn_ended(processor, turn_number, user_text, assistant_text, interrupted):
print(f"Turn {turn_number} ended")
print(f"User said: {user_text}")
print(f"Assistant said: {assistant_text}")
print(f"Was interrupted: {interrupted}")
pipeline = Pipeline([
transport.input(),
stt,
turn_processor,
context_aggregator.user(),
llm,
tts,
transport.output(),
context_aggregator.assistant(),
])
"""
def __init__(self, **kwargs):
"""Initialize the turn-aware transcript processor.
Args:
**kwargs: Additional arguments passed to parent class.
"""
super().__init__(**kwargs)
# Turn tracking state
self._turn_number = 0
self._turn_active = False
self._turn_start_time: Optional[str] = None
# Accumulate text for current turn
self._current_turn_user_parts: List[TextPartForConcatenation] = []
self._current_turn_assistant_parts: List[TextPartForConcatenation] = []
# Track bot speaking state
self._bot_is_speaking = False
# Register turn events
self._register_event_handler("on_turn_started")
self._register_event_handler("on_turn_ended")
async def _start_turn(self):
"""Start a new turn."""
if not self._turn_active:
self._turn_number += 1
self._turn_active = True
self._turn_start_time = time_now_iso8601()
self._current_turn_user_parts = []
self._current_turn_assistant_parts = []
logger.debug(f"Turn {self._turn_number} started")
await self._call_event_handler("on_turn_started", self._turn_number)
async def _end_turn(self, was_interrupted: bool = False):
"""End the current turn and emit aggregated transcripts.
Args:
was_interrupted: Whether the turn ended due to an interruption.
"""
if not self._turn_active:
return
# Aggregate user text
user_transcript = ""
if self._current_turn_user_parts:
user_transcript = concatenate_aggregated_text(self._current_turn_user_parts)
# Aggregate assistant text
assistant_transcript = ""
if self._current_turn_assistant_parts:
assistant_transcript = concatenate_aggregated_text(self._current_turn_assistant_parts)
# Emit turn ended event
logger.debug(
f"Turn {self._turn_number} ended (interrupted={was_interrupted}). "
f"User: '{user_transcript}', Assistant: '{assistant_transcript}'"
)
await self._call_event_handler(
"on_turn_ended",
self._turn_number,
user_transcript,
assistant_transcript,
was_interrupted,
)
# Reset turn state
self._turn_active = False
self._current_turn_user_parts = []
self._current_turn_assistant_parts = []
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames for turn-aware transcript tracking.
Handles:
- UserStartedSpeakingFrame: Start new turn
- TranscriptionFrame: Accumulate user speech and emit transcript message
- BotStartedSpeakingFrame: Track bot speaking state
- TTSTextFrame: Accumulate assistant speech
- BotStoppedSpeakingFrame: Clear bot speaking state (turn ends on next user speech or pipeline end)
- InterruptionFrame: End turn immediately as interrupted
- EndFrame/CancelFrame: End any active turn
Args:
frame: Input frame to process.
direction: Frame processing direction.
"""
await super().process_frame(frame, direction)
if isinstance(frame, UserStartedSpeakingFrame):
# User started speaking
if self._bot_is_speaking:
# This is an interruption - end the current turn with what was spoken
if self._current_turn_assistant_parts:
assistant_content = concatenate_aggregated_text(
self._current_turn_assistant_parts
)
if assistant_content:
message = TranscriptionMessage(
role="assistant",
content=assistant_content,
timestamp=self._turn_start_time or time_now_iso8601(),
)
await self._emit_update([message])
await self._end_turn(was_interrupted=True)
self._bot_is_speaking = False
elif self._turn_active:
# Previous turn is ending normally (bot finished speaking)
if self._current_turn_assistant_parts:
assistant_content = concatenate_aggregated_text(
self._current_turn_assistant_parts
)
if assistant_content:
message = TranscriptionMessage(
role="assistant",
content=assistant_content,
timestamp=self._turn_start_time or time_now_iso8601(),
)
await self._emit_update([message])
await self._end_turn(was_interrupted=False)
# Start a new turn
await self._start_turn()
await self.push_frame(frame, direction)
elif isinstance(frame, TranscriptionFrame):
# Accumulate user speech for the current turn
if self._turn_active:
self._current_turn_user_parts.append(
TextPartForConcatenation(frame.text, includes_inter_part_spaces=True)
)
# Also emit individual transcript message
message = TranscriptionMessage(
role="user",
user_id=frame.user_id,
content=frame.text,
timestamp=frame.timestamp,
)
await self._emit_update([message])
await self.push_frame(frame, direction)
elif isinstance(frame, BotStartedSpeakingFrame):
# Bot started speaking
self._bot_is_speaking = True
await self.push_frame(frame, direction)
elif isinstance(frame, TTSTextFrame):
# Accumulate assistant speech for the current turn
if self._turn_active:
self._current_turn_assistant_parts.append(
TextPartForConcatenation(
frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
)
)
await self.push_frame(frame, direction)
elif isinstance(frame, BotStoppedSpeakingFrame):
# Bot stopped speaking - just mark it, don't end turn yet
# Turn will end when next user speaks or pipeline ends
self._bot_is_speaking = False
await self.push_frame(frame, direction)
elif isinstance(frame, InterruptionFrame):
# Emit assistant transcript message with what was spoken before interruption
if self._current_turn_assistant_parts:
assistant_content = concatenate_aggregated_text(self._current_turn_assistant_parts)
if assistant_content:
message = TranscriptionMessage(
role="assistant",
content=assistant_content,
timestamp=self._turn_start_time or time_now_iso8601(),
)
await self._emit_update([message])
# Push frame first to ensure proper cleanup
await self.push_frame(frame, direction)
# End turn as interrupted
await self._end_turn(was_interrupted=True)
self._bot_is_speaking = False
elif isinstance(frame, (EndFrame, CancelFrame)):
# Pipeline ending - finalize any active turn
if self._turn_active:
# Emit any pending assistant transcript (allow time for TTSTextFrames to be processed)
# Give a brief moment for any pending frames to process
import asyncio
await asyncio.sleep(0.001)
if self._current_turn_assistant_parts:
assistant_content = concatenate_aggregated_text(
self._current_turn_assistant_parts
)
if assistant_content:
message = TranscriptionMessage(
role="assistant",
content=assistant_content,
timestamp=self._turn_start_time or time_now_iso8601(),
)
await self._emit_update([message])
await self._end_turn(was_interrupted=isinstance(frame, CancelFrame))
await self.push_frame(frame, direction)
else:
await self.push_frame(frame, direction)

@@ -183,6 +183,14 @@ class DeepgramFluxSTTService(WebsocketSTTService):
"""
await self._connect_websocket()
# Creating the receiver task (only created once during initial connection)
if not self._receive_task:
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
# Creating the watchdog task (only created once during initial connection)
if not self._watchdog_task:
self._watchdog_task = self.create_task(self._watchdog_task_handler())
async def _disconnect(self):
"""Disconnect from WebSocket and clean up tasks.
@@ -235,16 +243,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
additional_headers={"Authorization": f"Token {self._api_key}"},
)
# Creating the receiver task
if not self._receive_task:
self._receive_task = self.create_task(
self._receive_task_handler(self._report_error)
)
# Creating the watchdog task
if not self._watchdog_task:
self._watchdog_task = self.create_task(self._watchdog_task_handler())
# Now wait for the connection established event
logger.debug("WebSocket connected, waiting for server confirmation...")
await self._connection_established_event.wait()

@@ -1723,8 +1723,6 @@ class GeminiLiveLLMService(LLMService):
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
cache_read_input_tokens=usage.cached_content_token_count,
reasoning_tokens=usage.thoughts_token_count,
)
await self.start_llm_usage_metrics(tokens)

@@ -123,8 +123,6 @@ class GrokLLMService(OpenAILLMService):
self._prompt_tokens = 0
self._completion_tokens = 0
self._total_tokens = 0
self._cache_read_input_tokens = None
self._reasoning_tokens = None
self._has_reported_prompt_tokens = False
self._is_processing = True
@@ -139,8 +137,6 @@ class GrokLLMService(OpenAILLMService):
prompt_tokens=self._prompt_tokens,
completion_tokens=self._completion_tokens,
total_tokens=self._total_tokens,
cache_read_input_tokens=self._cache_read_input_tokens,
reasoning_tokens=self._reasoning_tokens,
)
await super().start_llm_usage_metrics(tokens)
@@ -153,7 +149,7 @@ class GrokLLMService(OpenAILLMService):
Args:
tokens: The token usage metrics for the current chunk of processing,
containing prompt_tokens, completion_tokens, and optional cached/reasoning tokens.
containing prompt_tokens and completion_tokens counts.
"""
# Only accumulate metrics during active processing
if not self._is_processing:
@@ -168,13 +164,6 @@ class GrokLLMService(OpenAILLMService):
if tokens.completion_tokens > self._completion_tokens:
self._completion_tokens = tokens.completion_tokens
# Capture cached & reasoning tokens (these typically only appear once per request)
if tokens.cache_read_input_tokens is not None:
self._cache_read_input_tokens = tokens.cache_read_input_tokens
if tokens.reasoning_tokens is not None:
self._reasoning_tokens = tokens.reasoning_tokens
def create_context_aggregator(
self,
context: OpenAILLMContext,

@@ -346,17 +346,11 @@ class BaseOpenAILLMService(LLMService):
if chunk.usage.prompt_tokens_details
else None
)
reasoning_tokens = (
chunk.usage.completion_tokens_details.reasoning_tokens
if chunk.usage.completion_tokens_details
else None
)
tokens = LLMTokenUsage(
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens,
total_tokens=chunk.usage.total_tokens,
cache_read_input_tokens=cached_tokens,
reasoning_tokens=reasoning_tokens,
)
await self.start_llm_usage_metrics(tokens)

@@ -57,6 +57,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
from pipecat.services.openai.llm import OpenAIContextAggregatorPair
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_openai_realtime, traced_stt
@@ -656,17 +657,10 @@ class OpenAIRealtimeLLMService(LLMService):
async def _handle_evt_response_done(self, evt):
# todo: figure out whether there's anything we need to do for "cancelled" events
# usage metrics
cached_tokens = (
evt.response.usage.input_token_details.cached_tokens
if hasattr(evt.response.usage, "input_token_details")
and evt.response.usage.input_token_details
else None
)
tokens = LLMTokenUsage(
prompt_tokens=evt.response.usage.input_tokens,
completion_tokens=evt.response.usage.output_tokens,
total_tokens=evt.response.usage.total_tokens,
cache_read_input_tokens=cached_tokens,
)
await self.start_llm_usage_metrics(tokens)
await self.stop_processing_metrics()
@@ -816,7 +810,7 @@ class OpenAIRealtimeLLMService(LLMService):
# We're done configuring the LLM for this session
self._llm_needs_conversation_setup = False
logger.debug("Creating response")
logger.debug(f"Creating response")
await self.push_frame(LLMFullResponseStartFrame())
await self.start_processing_metrics()

@@ -235,7 +235,7 @@ class SmallWebRTCClient:
# We are always resampling it for 16000 if the sample_rate that we receive is bigger than that.
# otherwise we face issues with Silero VAD
self._pipecat_resampler = AudioResampler("s16", "mono", 48000)
self._pipecat_resampler = AudioResampler("s16", "mono", 16000)
@self._webrtc_connection.event_handler("connected")
async def on_connected(connection: SmallWebRTCConnection):
@@ -366,16 +366,31 @@ class SmallWebRTCClient:
await asyncio.sleep(0.01)
continue
resampled_frames = self._pipecat_resampler.resample(frame)
for resampled_frame in resampled_frames:
if frame.sample_rate > self._in_sample_rate:
resampled_frames = self._pipecat_resampler.resample(frame)
for resampled_frame in resampled_frames:
# 16-bit PCM bytes
pcm_array = resampled_frame.to_ndarray().astype(np.int16)
pcm_bytes = pcm_array.tobytes()
del pcm_array # free NumPy array immediately
audio_frame = InputAudioRawFrame(
audio=pcm_bytes,
sample_rate=resampled_frame.sample_rate,
num_channels=self._audio_in_channels,
)
del pcm_bytes # reference kept in audio_frame
yield audio_frame
else:
# 16-bit PCM bytes
pcm_array = resampled_frame.to_ndarray().astype(np.int16)
pcm_array = frame.to_ndarray().astype(np.int16)
pcm_bytes = pcm_array.tobytes()
del pcm_array # free NumPy array immediately
audio_frame = InputAudioRawFrame(
audio=pcm_bytes,
sample_rate=resampled_frame.sample_rate,
sample_rate=frame.sample_rate,
num_channels=self._audio_in_channels,
)
del pcm_bytes # reference kept in audio_frame

@@ -92,24 +92,6 @@ def _add_token_usage_to_span(span, token_usage):
span.set_attribute("gen_ai.usage.input_tokens", token_usage["prompt_tokens"])
if "completion_tokens" in token_usage:
span.set_attribute("gen_ai.usage.output_tokens", token_usage["completion_tokens"])
# Add cached token metrics for dictionary
if (
"cache_read_input_tokens" in token_usage
and token_usage["cache_read_input_tokens"] is not None
):
span.set_attribute(
"gen_ai.usage.cache_read_input_tokens", token_usage["cache_read_input_tokens"]
)
if (
"cache_creation_input_tokens" in token_usage
and token_usage["cache_creation_input_tokens"] is not None
):
span.set_attribute(
"gen_ai.usage.cache_creation_input_tokens",
token_usage["cache_creation_input_tokens"],
)
if "reasoning_tokens" in token_usage and token_usage["reasoning_tokens"] is not None:
span.set_attribute("gen_ai.usage.reasoning_tokens", token_usage["reasoning_tokens"])
else:
# Handle LLMTokenUsage object
span.set_attribute("gen_ai.usage.input_tokens", getattr(token_usage, "prompt_tokens", 0))
@@ -117,19 +99,6 @@ def _add_token_usage_to_span(span, token_usage):
"gen_ai.usage.output_tokens", getattr(token_usage, "completion_tokens", 0)
)
# Add cached token metrics for LLMTokenUsage object
cache_read_tokens = getattr(token_usage, "cache_read_input_tokens", None)
if cache_read_tokens is not None:
span.set_attribute("gen_ai.usage.cache_read_input_tokens", cache_read_tokens)
cache_creation_tokens = getattr(token_usage, "cache_creation_input_tokens", None)
if cache_creation_tokens is not None:
span.set_attribute("gen_ai.usage.cache_creation_input_tokens", cache_creation_tokens)
reasoning_tokens = getattr(token_usage, "reasoning_tokens", None)
if reasoning_tokens is not None:
span.set_attribute("gen_ai.usage.reasoning_tokens", reasoning_tokens)
def traced_tts(func: Optional[Callable] = None, *, name: Optional[str] = None) -> Callable:
"""Trace TTS service methods with TTS-specific attributes.
@@ -746,7 +715,7 @@ def traced_gemini_live(operation: str) -> Callable:
else:
operation_attrs["tool.result_status"] = "completed"
except json.JSONDecodeError:
except json.JSONDecodeError as e:
operation_attrs["tool.result"] = (
f"Invalid JSON: {str(result_content)[:500]}"
)

@@ -0,0 +1,189 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import unittest
from pipecat.frames.frames import (
AggregationType,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
InterruptionFrame,
TranscriptionFrame,
TranscriptionUpdateFrame,
TTSTextFrame,
UserStartedSpeakingFrame,
)
from pipecat.processors.transcript_processor import TurnAwareTranscriptProcessor
from pipecat.tests.utils import SleepFrame, run_test
class TestTurnAwareTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
"""Tests for TurnAwareTranscriptProcessor."""
async def test_basic_turn_flow(self):
"""Test basic turn start/end with user and assistant speech."""
processor = TurnAwareTranscriptProcessor()
# Track events
turn_started_calls = []
turn_ended_calls = []
@processor.event_handler("on_turn_started")
async def on_turn_started(proc, turn_number):
turn_started_calls.append(turn_number)
@processor.event_handler("on_turn_ended")
async def on_turn_ended(proc, turn_number, user_text, assistant_text, interrupted):
turn_ended_calls.append(
{
"turn_number": turn_number,
"user_text": user_text,
"assistant_text": assistant_text,
"interrupted": interrupted,
}
)
frames_to_send = [
# Turn 1: User speaks, bot responds
UserStartedSpeakingFrame(),
TranscriptionFrame(text="Hello", user_id="user1", timestamp=""),
SleepFrame(sleep=0.01), # Allow transcription to process
BotStartedSpeakingFrame(),
TTSTextFrame(text="Hi", aggregated_by=AggregationType.WORD),
TTSTextFrame(text=" there", aggregated_by=AggregationType.WORD),
BotStoppedSpeakingFrame(),
SleepFrame(sleep=0.1),
]
await run_test(processor, frames_to_send=frames_to_send)
# Verify events
self.assertEqual(
len(turn_started_calls), 1, f"Expected 1 turn started, got {len(turn_started_calls)}"
)
self.assertEqual(turn_started_calls[0], 1)
self.assertEqual(
len(turn_ended_calls), 1, f"Expected 1 turn ended, got {len(turn_ended_calls)}"
)
self.assertEqual(turn_ended_calls[0]["turn_number"], 1)
self.assertEqual(turn_ended_calls[0]["user_text"], "Hello")
self.assertEqual(turn_ended_calls[0]["assistant_text"], "Hi there")
self.assertFalse(turn_ended_calls[0]["interrupted"])
async def test_interruption(self):
"""Test turn ending on interruption."""
processor = TurnAwareTranscriptProcessor()
# Track events
turn_ended_calls = []
@processor.event_handler("on_turn_ended")
async def on_turn_ended(proc, turn_number, user_text, assistant_text, interrupted):
turn_ended_calls.append(
{
"turn_number": turn_number,
"user_text": user_text,
"assistant_text": assistant_text,
"interrupted": interrupted,
}
)
frames_to_send = [
# User speaks
UserStartedSpeakingFrame(),
TranscriptionFrame(text="Tell me", user_id="user1", timestamp=""),
SleepFrame(sleep=0.01), # Allow transcription to process
# Bot starts responding
BotStartedSpeakingFrame(),
TTSTextFrame(text="Sure", aggregated_by=AggregationType.WORD),
TTSTextFrame(text=" I", aggregated_by=AggregationType.WORD),
TTSTextFrame(text=" can", aggregated_by=AggregationType.WORD),
# User interrupts
InterruptionFrame(),
# New turn starts
UserStartedSpeakingFrame(),
TranscriptionFrame(text="Wait", user_id="user1", timestamp=""),
SleepFrame(sleep=0.1),
]
await run_test(processor, frames_to_send=frames_to_send)
# Verify first turn was interrupted
self.assertGreaterEqual(
len(turn_ended_calls), 1, f"Expected at least 1 turn ended, got {len(turn_ended_calls)}"
)
first_turn = turn_ended_calls[0]
self.assertEqual(first_turn["user_text"], "Tell me")
# Note: In this test flow, InterruptionFrame arrives before TTSTextFrames are processed,
# so assistant text may be empty. In real scenarios, word timestamps ensure proper capture.
self.assertIn(first_turn["assistant_text"], ["", "Sure I can", "Sure I can"])
self.assertTrue(first_turn["interrupted"])
async def test_multiple_turns(self):
"""Test multiple back-and-forth turns."""
processor = TurnAwareTranscriptProcessor()
# Track events
turn_started_calls = []
turn_ended_calls = []
@processor.event_handler("on_turn_started")
async def on_turn_started(proc, turn_number):
turn_started_calls.append(turn_number)
@processor.event_handler("on_turn_ended")
async def on_turn_ended(proc, turn_number, user_text, assistant_text, interrupted):
turn_ended_calls.append(
{
"turn_number": turn_number,
"user_text": user_text,
"assistant_text": assistant_text,
}
)
frames_to_send = [
# Turn 1
UserStartedSpeakingFrame(),
TranscriptionFrame(text="Hi", user_id="user1", timestamp=""),
SleepFrame(sleep=0.01), # Allow transcription to process
BotStartedSpeakingFrame(),
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
BotStoppedSpeakingFrame(),
SleepFrame(sleep=0.05),
# Turn 2
UserStartedSpeakingFrame(),
TranscriptionFrame(text="How are you", user_id="user1", timestamp=""),
SleepFrame(sleep=0.01), # Allow transcription to process
BotStartedSpeakingFrame(),
TTSTextFrame(text="I'm", aggregated_by=AggregationType.WORD),
TTSTextFrame(text=" good", aggregated_by=AggregationType.WORD),
BotStoppedSpeakingFrame(),
SleepFrame(sleep=0.1),
]
await run_test(processor, frames_to_send=frames_to_send)
# Verify multiple turns
self.assertEqual(
len(turn_started_calls), 2, f"Expected 2 turns started, got {len(turn_started_calls)}"
)
self.assertEqual(turn_started_calls, [1, 2])
self.assertEqual(
len(turn_ended_calls), 2, f"Expected 2 turns ended, got {len(turn_ended_calls)}"
)
self.assertEqual(turn_ended_calls[0]["turn_number"], 1)
self.assertEqual(turn_ended_calls[0]["user_text"], "Hi")
self.assertEqual(turn_ended_calls[0]["assistant_text"], "Hello")
self.assertEqual(turn_ended_calls[1]["turn_number"], 2)
self.assertEqual(turn_ended_calls[1]["user_text"], "How are you")
self.assertEqual(turn_ended_calls[1]["assistant_text"], "I'm good")
if __name__ == "__main__":
unittest.main()