Compare commits
3 Commits
filipi/hig
...
hush/TurnT
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8bbfa829d3 | ||
|
|
c2eb663bdc | ||
|
|
bf055843e6 |
@@ -9,9 +9,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added `cache_read_input_tokens`, `cache_creation_input_tokens` and
|
||||
`reasoning_tokens` to OTel spans for LLM call
|
||||
|
||||
- Added `LiveKitRESTHelper` utility class for managing LiveKit rooms via REST API.
|
||||
|
||||
- Added `DeepgramSageMakerSTTService` which connects to a SageMaker hosted
|
||||
|
||||
103
docs/TURN_AWARE_TRANSCRIPT_PROCESSOR.md
Normal file
103
docs/TURN_AWARE_TRANSCRIPT_PROCESSOR.md
Normal file
@@ -0,0 +1,103 @@
|
||||
# TurnAwareTranscriptProcessor Example
|
||||
|
||||
## Overview
|
||||
|
||||
The `TurnAwareTranscriptProcessor` combines user and assistant transcript tracking with turn boundary detection. It correctly handles interruptions by only capturing what was actually spoken.
|
||||
|
||||
## Basic Usage
|
||||
|
||||
```python
|
||||
from pipecat.processors.transcript_processor import TurnAwareTranscriptProcessor
|
||||
|
||||
# Create the processor
|
||||
turn_processor = TurnAwareTranscriptProcessor()
|
||||
|
||||
# Register event handlers
|
||||
@turn_processor.event_handler("on_turn_started")
|
||||
async def handle_turn_started(processor, turn_number):
|
||||
print(f"Turn {turn_number} started")
|
||||
|
||||
@turn_processor.event_handler("on_turn_ended")
|
||||
async def handle_turn_ended(processor, turn_number, user_text, assistant_text, was_interrupted):
|
||||
print(f"\nTurn {turn_number} ended:")
|
||||
print(f" User said: {user_text}")
|
||||
print(f" Assistant said: {assistant_text}")
|
||||
print(f" Was interrupted: {was_interrupted}")
|
||||
|
||||
@turn_processor.event_handler("on_transcript_update")
|
||||
async def handle_transcript_update(processor, frame):
|
||||
for msg in frame.messages:
|
||||
print(f"[{msg.role}]: {msg.content}")
|
||||
|
||||
# Add to pipeline
|
||||
pipeline = Pipeline([
|
||||
transport.input(),
|
||||
stt,
|
||||
turn_processor, # Process transcripts and track turns
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
])
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
1. **Turn Boundary Detection**: Automatically detects when turns start and end based on user and bot speaking patterns
|
||||
2. **Interruption Handling**: Correctly captures only what was actually spoken when interruptions occur
|
||||
3. **Real-time Transcripts**: Emits transcript messages for both user and assistant speech
|
||||
4. **Turn Events**: Provides start/end events with accumulated transcripts for each turn
|
||||
|
||||
## Events
|
||||
|
||||
### on_turn_started
|
||||
Emitted when a new turn begins (user starts speaking).
|
||||
|
||||
**Handler signature**: `async def handler(processor, turn_number)`
|
||||
|
||||
### on_turn_ended
|
||||
Emitted when a turn ends with accumulated transcripts.
|
||||
|
||||
**Handler signature**: `async def handler(processor, turn_number, user_transcript, assistant_transcript, was_interrupted)`
|
||||
|
||||
### on_transcript_update
|
||||
Inherited from `BaseTranscriptProcessor`, emitted for individual transcript messages.
|
||||
|
||||
**Handler signature**: `async def handler(processor, frame)`
|
||||
|
||||
## Turn Logic
|
||||
|
||||
- Turns start when the user begins speaking (`UserStartedSpeakingFrame`)
|
||||
- Turns end when:
|
||||
- The user starts speaking again (previous turn ends, new turn starts)
|
||||
- The bot is interrupted (`InterruptionFrame`)
|
||||
- The pipeline ends (`EndFrame`/`CancelFrame`)
|
||||
|
||||
## Integration with OpenTelemetry
|
||||
|
||||
You can use turn events to enrich OpenTelemetry spans:
|
||||
|
||||
```python
|
||||
from pipecat.utils.tracing.turn_trace_observer import TurnTraceObserver
|
||||
|
||||
turn_tracker = TurnTrackingObserver()
|
||||
turn_tracer = TurnTraceObserver(turn_tracker)
|
||||
turn_processor = TurnAwareTranscriptProcessor()
|
||||
|
||||
@turn_processor.event_handler("on_turn_ended")
|
||||
async def add_transcripts_to_span(processor, turn_number, user_text, assistant_text, interrupted):
|
||||
# Get current span and add transcript data
|
||||
from opentelemetry import trace
|
||||
current_span = trace.get_current_span()
|
||||
if current_span:
|
||||
current_span.set_attribute("turn.user_text", user_text)
|
||||
current_span.set_attribute("turn.assistant_text", assistant_text)
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- The processor handles async frame processing correctly by delaying turn end until frames are processed
|
||||
- Works with word-level timestamps from TTS services like Cartesia
|
||||
- Accumulates both user (`TranscriptionFrame`) and assistant (`TTSTextFrame`) speech
|
||||
- Emits individual transcript messages in addition to turn-level aggregation
|
||||
@@ -50,14 +50,25 @@ import aiofiles
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -83,10 +94,20 @@ transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
}
|
||||
|
||||
@@ -94,13 +115,38 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"), audio_passthrough=True)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121",
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4")
|
||||
|
||||
# Create audio buffer processor
|
||||
audiobuffer = AudioBufferProcessor(sample_rate=48000)
|
||||
audiobuffer = AudioBufferProcessor()
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant demonstrating audio recording capabilities. Keep your responses brief and clear.",
|
||||
},
|
||||
]
|
||||
|
||||
context = LLMContext(messages)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
stt,
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
audiobuffer, # Add audio buffer to pipeline
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -109,8 +155,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
audio_in_sample_rate=48000,
|
||||
audio_out_sample_rate= 48000
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
@@ -121,7 +165,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
# Start recording audio
|
||||
await audiobuffer.start_recording()
|
||||
# Start conversation - empty prompt to let LLM follow system instructions
|
||||
# await task.queue_frames([LLMRunFrame()])
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
|
||||
@@ -15,6 +15,7 @@ from typing import List, Optional
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
@@ -24,6 +25,7 @@ from pipecat.frames.frames import (
|
||||
TranscriptionMessage,
|
||||
TranscriptionUpdateFrame,
|
||||
TTSTextFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.string import TextPartForConcatenation, concatenate_aggregated_text
|
||||
@@ -306,3 +308,267 @@ class TranscriptProcessor:
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class TurnAwareTranscriptProcessor(BaseTranscriptProcessor):
|
||||
"""Processes transcripts with turn boundary awareness.
|
||||
|
||||
This processor combines user and assistant transcript tracking with turn
|
||||
detection, emitting events when turns start and end. It correctly handles
|
||||
interruptions by only capturing what was actually spoken.
|
||||
|
||||
Turn boundaries are detected based on:
|
||||
- User started speaking (UserStartedSpeakingFrame)
|
||||
- Bot stopped speaking (BotStoppedSpeakingFrame)
|
||||
- Interruptions (InterruptionFrame)
|
||||
|
||||
Events:
|
||||
on_turn_started: Emitted when a new turn begins.
|
||||
Handler signature: async def handler(processor, turn_number)
|
||||
|
||||
on_turn_ended: Emitted when a turn ends.
|
||||
Handler signature: async def handler(processor, turn_number,
|
||||
user_transcript, assistant_transcript,
|
||||
was_interrupted)
|
||||
|
||||
on_transcript_update: Inherited from BaseTranscriptProcessor, emitted for
|
||||
individual transcript messages.
|
||||
|
||||
Example::
|
||||
|
||||
turn_processor = TurnAwareTranscriptProcessor()
|
||||
|
||||
@turn_processor.event_handler("on_turn_started")
|
||||
async def handle_turn_started(processor, turn_number):
|
||||
print(f"Turn {turn_number} started")
|
||||
|
||||
@turn_processor.event_handler("on_turn_ended")
|
||||
async def handle_turn_ended(processor, turn_number, user_text, assistant_text, interrupted):
|
||||
print(f"Turn {turn_number} ended")
|
||||
print(f"User said: {user_text}")
|
||||
print(f"Assistant said: {assistant_text}")
|
||||
print(f"Was interrupted: {interrupted}")
|
||||
|
||||
pipeline = Pipeline([
|
||||
transport.input(),
|
||||
stt,
|
||||
turn_processor,
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
])
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize the turn-aware transcript processor.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Turn tracking state
|
||||
self._turn_number = 0
|
||||
self._turn_active = False
|
||||
self._turn_start_time: Optional[str] = None
|
||||
|
||||
# Accumulate text for current turn
|
||||
self._current_turn_user_parts: List[TextPartForConcatenation] = []
|
||||
self._current_turn_assistant_parts: List[TextPartForConcatenation] = []
|
||||
|
||||
# Track bot speaking state
|
||||
self._bot_is_speaking = False
|
||||
|
||||
# Register turn events
|
||||
self._register_event_handler("on_turn_started")
|
||||
self._register_event_handler("on_turn_ended")
|
||||
|
||||
async def _start_turn(self):
|
||||
"""Start a new turn."""
|
||||
if not self._turn_active:
|
||||
self._turn_number += 1
|
||||
self._turn_active = True
|
||||
self._turn_start_time = time_now_iso8601()
|
||||
self._current_turn_user_parts = []
|
||||
self._current_turn_assistant_parts = []
|
||||
|
||||
logger.debug(f"Turn {self._turn_number} started")
|
||||
await self._call_event_handler("on_turn_started", self._turn_number)
|
||||
|
||||
async def _end_turn(self, was_interrupted: bool = False):
|
||||
"""End the current turn and emit aggregated transcripts.
|
||||
|
||||
Args:
|
||||
was_interrupted: Whether the turn ended due to an interruption.
|
||||
"""
|
||||
if not self._turn_active:
|
||||
return
|
||||
|
||||
# Aggregate user text
|
||||
user_transcript = ""
|
||||
if self._current_turn_user_parts:
|
||||
user_transcript = concatenate_aggregated_text(self._current_turn_user_parts)
|
||||
|
||||
# Aggregate assistant text
|
||||
assistant_transcript = ""
|
||||
if self._current_turn_assistant_parts:
|
||||
assistant_transcript = concatenate_aggregated_text(self._current_turn_assistant_parts)
|
||||
|
||||
# Emit turn ended event
|
||||
logger.debug(
|
||||
f"Turn {self._turn_number} ended (interrupted={was_interrupted}). "
|
||||
f"User: '{user_transcript}', Assistant: '{assistant_transcript}'"
|
||||
)
|
||||
await self._call_event_handler(
|
||||
"on_turn_ended",
|
||||
self._turn_number,
|
||||
user_transcript,
|
||||
assistant_transcript,
|
||||
was_interrupted,
|
||||
)
|
||||
|
||||
# Reset turn state
|
||||
self._turn_active = False
|
||||
self._current_turn_user_parts = []
|
||||
self._current_turn_assistant_parts = []
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames for turn-aware transcript tracking.
|
||||
|
||||
Handles:
|
||||
- UserStartedSpeakingFrame: Start new turn
|
||||
- TranscriptionFrame: Accumulate user speech and emit transcript message
|
||||
- BotStartedSpeakingFrame: Track bot speaking state
|
||||
- TTSTextFrame: Accumulate assistant speech
|
||||
- BotStoppedSpeakingFrame: End turn if no interruption pending
|
||||
- InterruptionFrame: End turn immediately as interrupted
|
||||
- EndFrame/CancelFrame: End any active turn
|
||||
|
||||
Args:
|
||||
frame: Input frame to process.
|
||||
direction: Frame processing direction.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
# User started speaking
|
||||
if self._bot_is_speaking:
|
||||
# This is an interruption - end the current turn with what was spoken
|
||||
if self._current_turn_assistant_parts:
|
||||
assistant_content = concatenate_aggregated_text(
|
||||
self._current_turn_assistant_parts
|
||||
)
|
||||
if assistant_content:
|
||||
message = TranscriptionMessage(
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
timestamp=self._turn_start_time or time_now_iso8601(),
|
||||
)
|
||||
await self._emit_update([message])
|
||||
await self._end_turn(was_interrupted=True)
|
||||
self._bot_is_speaking = False
|
||||
elif self._turn_active:
|
||||
# Previous turn is ending normally (bot finished speaking)
|
||||
if self._current_turn_assistant_parts:
|
||||
assistant_content = concatenate_aggregated_text(
|
||||
self._current_turn_assistant_parts
|
||||
)
|
||||
if assistant_content:
|
||||
message = TranscriptionMessage(
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
timestamp=self._turn_start_time or time_now_iso8601(),
|
||||
)
|
||||
await self._emit_update([message])
|
||||
await self._end_turn(was_interrupted=False)
|
||||
|
||||
# Start a new turn
|
||||
await self._start_turn()
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
elif isinstance(frame, TranscriptionFrame):
|
||||
# Accumulate user speech for the current turn
|
||||
if self._turn_active:
|
||||
self._current_turn_user_parts.append(
|
||||
TextPartForConcatenation(frame.text, includes_inter_part_spaces=True)
|
||||
)
|
||||
|
||||
# Also emit individual transcript message
|
||||
message = TranscriptionMessage(
|
||||
role="user",
|
||||
user_id=frame.user_id,
|
||||
content=frame.text,
|
||||
timestamp=frame.timestamp,
|
||||
)
|
||||
await self._emit_update([message])
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
elif isinstance(frame, BotStartedSpeakingFrame):
|
||||
# Bot started speaking
|
||||
self._bot_is_speaking = True
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
elif isinstance(frame, TTSTextFrame):
|
||||
# Accumulate assistant speech for the current turn
|
||||
if self._turn_active:
|
||||
self._current_turn_assistant_parts.append(
|
||||
TextPartForConcatenation(
|
||||
frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
|
||||
)
|
||||
)
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
# Bot stopped speaking - just mark it, don't end turn yet
|
||||
# Turn will end when next user speaks or pipeline ends
|
||||
self._bot_is_speaking = False
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
# Emit assistant transcript message with what was spoken before interruption
|
||||
if self._current_turn_assistant_parts:
|
||||
assistant_content = concatenate_aggregated_text(self._current_turn_assistant_parts)
|
||||
if assistant_content:
|
||||
message = TranscriptionMessage(
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
timestamp=self._turn_start_time or time_now_iso8601(),
|
||||
)
|
||||
await self._emit_update([message])
|
||||
|
||||
# Push frame first to ensure proper cleanup
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
# End turn as interrupted
|
||||
await self._end_turn(was_interrupted=True)
|
||||
self._bot_is_speaking = False
|
||||
|
||||
elif isinstance(frame, (EndFrame, CancelFrame)):
|
||||
# Pipeline ending - finalize any active turn
|
||||
if self._turn_active:
|
||||
# Emit any pending assistant transcript (allow time for TTSTextFrames to be processed)
|
||||
# Give a brief moment for any pending frames to process
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
if self._current_turn_assistant_parts:
|
||||
assistant_content = concatenate_aggregated_text(
|
||||
self._current_turn_assistant_parts
|
||||
)
|
||||
if assistant_content:
|
||||
message = TranscriptionMessage(
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
timestamp=self._turn_start_time or time_now_iso8601(),
|
||||
)
|
||||
await self._emit_update([message])
|
||||
|
||||
await self._end_turn(was_interrupted=isinstance(frame, CancelFrame))
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -183,6 +183,14 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
"""
|
||||
await self._connect_websocket()
|
||||
|
||||
# Creating the receiver task (only created once during initial connection)
|
||||
if not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
# Creating the watchdog task (only created once during initial connection)
|
||||
if not self._watchdog_task:
|
||||
self._watchdog_task = self.create_task(self._watchdog_task_handler())
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from WebSocket and clean up tasks.
|
||||
|
||||
@@ -235,16 +243,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
additional_headers={"Authorization": f"Token {self._api_key}"},
|
||||
)
|
||||
|
||||
# Creating the receiver task
|
||||
if not self._receive_task:
|
||||
self._receive_task = self.create_task(
|
||||
self._receive_task_handler(self._report_error)
|
||||
)
|
||||
|
||||
# Creating the watchdog task
|
||||
if not self._watchdog_task:
|
||||
self._watchdog_task = self.create_task(self._watchdog_task_handler())
|
||||
|
||||
# Now wait for the connection established event
|
||||
logger.debug("WebSocket connected, waiting for server confirmation...")
|
||||
await self._connection_established_event.wait()
|
||||
|
||||
@@ -1723,8 +1723,6 @@ class GeminiLiveLLMService(LLMService):
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cache_read_input_tokens=usage.cached_content_token_count,
|
||||
reasoning_tokens=usage.thoughts_token_count,
|
||||
)
|
||||
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
|
||||
@@ -123,8 +123,6 @@ class GrokLLMService(OpenAILLMService):
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._cache_read_input_tokens = None
|
||||
self._reasoning_tokens = None
|
||||
self._has_reported_prompt_tokens = False
|
||||
self._is_processing = True
|
||||
|
||||
@@ -139,8 +137,6 @@ class GrokLLMService(OpenAILLMService):
|
||||
prompt_tokens=self._prompt_tokens,
|
||||
completion_tokens=self._completion_tokens,
|
||||
total_tokens=self._total_tokens,
|
||||
cache_read_input_tokens=self._cache_read_input_tokens,
|
||||
reasoning_tokens=self._reasoning_tokens,
|
||||
)
|
||||
await super().start_llm_usage_metrics(tokens)
|
||||
|
||||
@@ -153,7 +149,7 @@ class GrokLLMService(OpenAILLMService):
|
||||
|
||||
Args:
|
||||
tokens: The token usage metrics for the current chunk of processing,
|
||||
containing prompt_tokens, completion_tokens, and optional cached/reasoning tokens.
|
||||
containing prompt_tokens and completion_tokens counts.
|
||||
"""
|
||||
# Only accumulate metrics during active processing
|
||||
if not self._is_processing:
|
||||
@@ -168,13 +164,6 @@ class GrokLLMService(OpenAILLMService):
|
||||
if tokens.completion_tokens > self._completion_tokens:
|
||||
self._completion_tokens = tokens.completion_tokens
|
||||
|
||||
# Capture cached & reasoning tokens (these typically only appear once per request)
|
||||
if tokens.cache_read_input_tokens is not None:
|
||||
self._cache_read_input_tokens = tokens.cache_read_input_tokens
|
||||
|
||||
if tokens.reasoning_tokens is not None:
|
||||
self._reasoning_tokens = tokens.reasoning_tokens
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
|
||||
@@ -346,17 +346,11 @@ class BaseOpenAILLMService(LLMService):
|
||||
if chunk.usage.prompt_tokens_details
|
||||
else None
|
||||
)
|
||||
reasoning_tokens = (
|
||||
chunk.usage.completion_tokens_details.reasoning_tokens
|
||||
if chunk.usage.completion_tokens_details
|
||||
else None
|
||||
)
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=chunk.usage.prompt_tokens,
|
||||
completion_tokens=chunk.usage.completion_tokens,
|
||||
total_tokens=chunk.usage.total_tokens,
|
||||
cache_read_input_tokens=cached_tokens,
|
||||
reasoning_tokens=reasoning_tokens,
|
||||
)
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
|
||||
|
||||
@@ -57,6 +57,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
|
||||
from pipecat.services.openai.llm import OpenAIContextAggregatorPair
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_openai_realtime, traced_stt
|
||||
@@ -656,17 +657,10 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
async def _handle_evt_response_done(self, evt):
|
||||
# todo: figure out whether there's anything we need to do for "cancelled" events
|
||||
# usage metrics
|
||||
cached_tokens = (
|
||||
evt.response.usage.input_token_details.cached_tokens
|
||||
if hasattr(evt.response.usage, "input_token_details")
|
||||
and evt.response.usage.input_token_details
|
||||
else None
|
||||
)
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=evt.response.usage.input_tokens,
|
||||
completion_tokens=evt.response.usage.output_tokens,
|
||||
total_tokens=evt.response.usage.total_tokens,
|
||||
cache_read_input_tokens=cached_tokens,
|
||||
)
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
await self.stop_processing_metrics()
|
||||
@@ -816,7 +810,7 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
# We're done configuring the LLM for this session
|
||||
self._llm_needs_conversation_setup = False
|
||||
|
||||
logger.debug("Creating response")
|
||||
logger.debug(f"Creating response")
|
||||
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.start_processing_metrics()
|
||||
|
||||
@@ -235,7 +235,7 @@ class SmallWebRTCClient:
|
||||
|
||||
# We are always resampling it for 16000 if the sample_rate that we receive is bigger than that.
|
||||
# otherwise we face issues with Silero VAD
|
||||
self._pipecat_resampler = AudioResampler("s16", "mono", 48000)
|
||||
self._pipecat_resampler = AudioResampler("s16", "mono", 16000)
|
||||
|
||||
@self._webrtc_connection.event_handler("connected")
|
||||
async def on_connected(connection: SmallWebRTCConnection):
|
||||
@@ -366,16 +366,31 @@ class SmallWebRTCClient:
|
||||
await asyncio.sleep(0.01)
|
||||
continue
|
||||
|
||||
resampled_frames = self._pipecat_resampler.resample(frame)
|
||||
for resampled_frame in resampled_frames:
|
||||
if frame.sample_rate > self._in_sample_rate:
|
||||
resampled_frames = self._pipecat_resampler.resample(frame)
|
||||
for resampled_frame in resampled_frames:
|
||||
# 16-bit PCM bytes
|
||||
pcm_array = resampled_frame.to_ndarray().astype(np.int16)
|
||||
pcm_bytes = pcm_array.tobytes()
|
||||
del pcm_array # free NumPy array immediately
|
||||
|
||||
audio_frame = InputAudioRawFrame(
|
||||
audio=pcm_bytes,
|
||||
sample_rate=resampled_frame.sample_rate,
|
||||
num_channels=self._audio_in_channels,
|
||||
)
|
||||
del pcm_bytes # reference kept in audio_frame
|
||||
|
||||
yield audio_frame
|
||||
else:
|
||||
# 16-bit PCM bytes
|
||||
pcm_array = resampled_frame.to_ndarray().astype(np.int16)
|
||||
pcm_array = frame.to_ndarray().astype(np.int16)
|
||||
pcm_bytes = pcm_array.tobytes()
|
||||
del pcm_array # free NumPy array immediately
|
||||
|
||||
audio_frame = InputAudioRawFrame(
|
||||
audio=pcm_bytes,
|
||||
sample_rate=resampled_frame.sample_rate,
|
||||
sample_rate=frame.sample_rate,
|
||||
num_channels=self._audio_in_channels,
|
||||
)
|
||||
del pcm_bytes # reference kept in audio_frame
|
||||
|
||||
@@ -92,24 +92,6 @@ def _add_token_usage_to_span(span, token_usage):
|
||||
span.set_attribute("gen_ai.usage.input_tokens", token_usage["prompt_tokens"])
|
||||
if "completion_tokens" in token_usage:
|
||||
span.set_attribute("gen_ai.usage.output_tokens", token_usage["completion_tokens"])
|
||||
# Add cached token metrics for dictionary
|
||||
if (
|
||||
"cache_read_input_tokens" in token_usage
|
||||
and token_usage["cache_read_input_tokens"] is not None
|
||||
):
|
||||
span.set_attribute(
|
||||
"gen_ai.usage.cache_read_input_tokens", token_usage["cache_read_input_tokens"]
|
||||
)
|
||||
if (
|
||||
"cache_creation_input_tokens" in token_usage
|
||||
and token_usage["cache_creation_input_tokens"] is not None
|
||||
):
|
||||
span.set_attribute(
|
||||
"gen_ai.usage.cache_creation_input_tokens",
|
||||
token_usage["cache_creation_input_tokens"],
|
||||
)
|
||||
if "reasoning_tokens" in token_usage and token_usage["reasoning_tokens"] is not None:
|
||||
span.set_attribute("gen_ai.usage.reasoning_tokens", token_usage["reasoning_tokens"])
|
||||
else:
|
||||
# Handle LLMTokenUsage object
|
||||
span.set_attribute("gen_ai.usage.input_tokens", getattr(token_usage, "prompt_tokens", 0))
|
||||
@@ -117,19 +99,6 @@ def _add_token_usage_to_span(span, token_usage):
|
||||
"gen_ai.usage.output_tokens", getattr(token_usage, "completion_tokens", 0)
|
||||
)
|
||||
|
||||
# Add cached token metrics for LLMTokenUsage object
|
||||
cache_read_tokens = getattr(token_usage, "cache_read_input_tokens", None)
|
||||
if cache_read_tokens is not None:
|
||||
span.set_attribute("gen_ai.usage.cache_read_input_tokens", cache_read_tokens)
|
||||
|
||||
cache_creation_tokens = getattr(token_usage, "cache_creation_input_tokens", None)
|
||||
if cache_creation_tokens is not None:
|
||||
span.set_attribute("gen_ai.usage.cache_creation_input_tokens", cache_creation_tokens)
|
||||
|
||||
reasoning_tokens = getattr(token_usage, "reasoning_tokens", None)
|
||||
if reasoning_tokens is not None:
|
||||
span.set_attribute("gen_ai.usage.reasoning_tokens", reasoning_tokens)
|
||||
|
||||
|
||||
def traced_tts(func: Optional[Callable] = None, *, name: Optional[str] = None) -> Callable:
|
||||
"""Trace TTS service methods with TTS-specific attributes.
|
||||
@@ -746,7 +715,7 @@ def traced_gemini_live(operation: str) -> Callable:
|
||||
else:
|
||||
operation_attrs["tool.result_status"] = "completed"
|
||||
|
||||
except json.JSONDecodeError:
|
||||
except json.JSONDecodeError as e:
|
||||
operation_attrs["tool.result"] = (
|
||||
f"Invalid JSON: {str(result_content)[:500]}"
|
||||
)
|
||||
|
||||
189
tests/test_turn_aware_transcript_processor.py
Normal file
189
tests/test_turn_aware_transcript_processor.py
Normal file
@@ -0,0 +1,189 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AggregationType,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
InterruptionFrame,
|
||||
TranscriptionFrame,
|
||||
TranscriptionUpdateFrame,
|
||||
TTSTextFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.transcript_processor import TurnAwareTranscriptProcessor
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
|
||||
|
||||
class TestTurnAwareTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
"""Tests for TurnAwareTranscriptProcessor."""
|
||||
|
||||
async def test_basic_turn_flow(self):
|
||||
"""Test basic turn start/end with user and assistant speech."""
|
||||
processor = TurnAwareTranscriptProcessor()
|
||||
|
||||
# Track events
|
||||
turn_started_calls = []
|
||||
turn_ended_calls = []
|
||||
|
||||
@processor.event_handler("on_turn_started")
|
||||
async def on_turn_started(proc, turn_number):
|
||||
turn_started_calls.append(turn_number)
|
||||
|
||||
@processor.event_handler("on_turn_ended")
|
||||
async def on_turn_ended(proc, turn_number, user_text, assistant_text, interrupted):
|
||||
turn_ended_calls.append(
|
||||
{
|
||||
"turn_number": turn_number,
|
||||
"user_text": user_text,
|
||||
"assistant_text": assistant_text,
|
||||
"interrupted": interrupted,
|
||||
}
|
||||
)
|
||||
|
||||
frames_to_send = [
|
||||
# Turn 1: User speaks, bot responds
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Hello", user_id="user1", timestamp=""),
|
||||
SleepFrame(sleep=0.01), # Allow transcription to process
|
||||
BotStartedSpeakingFrame(),
|
||||
TTSTextFrame(text="Hi", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text=" there", aggregated_by=AggregationType.WORD),
|
||||
BotStoppedSpeakingFrame(),
|
||||
SleepFrame(sleep=0.1),
|
||||
]
|
||||
|
||||
await run_test(processor, frames_to_send=frames_to_send)
|
||||
|
||||
# Verify events
|
||||
self.assertEqual(
|
||||
len(turn_started_calls), 1, f"Expected 1 turn started, got {len(turn_started_calls)}"
|
||||
)
|
||||
self.assertEqual(turn_started_calls[0], 1)
|
||||
|
||||
self.assertEqual(
|
||||
len(turn_ended_calls), 1, f"Expected 1 turn ended, got {len(turn_ended_calls)}"
|
||||
)
|
||||
self.assertEqual(turn_ended_calls[0]["turn_number"], 1)
|
||||
self.assertEqual(turn_ended_calls[0]["user_text"], "Hello")
|
||||
self.assertEqual(turn_ended_calls[0]["assistant_text"], "Hi there")
|
||||
self.assertFalse(turn_ended_calls[0]["interrupted"])
|
||||
|
||||
async def test_interruption(self):
|
||||
"""Test turn ending on interruption."""
|
||||
processor = TurnAwareTranscriptProcessor()
|
||||
|
||||
# Track events
|
||||
turn_ended_calls = []
|
||||
|
||||
@processor.event_handler("on_turn_ended")
|
||||
async def on_turn_ended(proc, turn_number, user_text, assistant_text, interrupted):
|
||||
turn_ended_calls.append(
|
||||
{
|
||||
"turn_number": turn_number,
|
||||
"user_text": user_text,
|
||||
"assistant_text": assistant_text,
|
||||
"interrupted": interrupted,
|
||||
}
|
||||
)
|
||||
|
||||
frames_to_send = [
|
||||
# User speaks
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Tell me", user_id="user1", timestamp=""),
|
||||
SleepFrame(sleep=0.01), # Allow transcription to process
|
||||
# Bot starts responding
|
||||
BotStartedSpeakingFrame(),
|
||||
TTSTextFrame(text="Sure", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text=" I", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text=" can", aggregated_by=AggregationType.WORD),
|
||||
# User interrupts
|
||||
InterruptionFrame(),
|
||||
# New turn starts
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Wait", user_id="user1", timestamp=""),
|
||||
SleepFrame(sleep=0.1),
|
||||
]
|
||||
|
||||
await run_test(processor, frames_to_send=frames_to_send)
|
||||
|
||||
# Verify first turn was interrupted
|
||||
self.assertGreaterEqual(
|
||||
len(turn_ended_calls), 1, f"Expected at least 1 turn ended, got {len(turn_ended_calls)}"
|
||||
)
|
||||
first_turn = turn_ended_calls[0]
|
||||
self.assertEqual(first_turn["user_text"], "Tell me")
|
||||
# Note: In this test flow, InterruptionFrame arrives before TTSTextFrames are processed,
|
||||
# so assistant text may be empty. In real scenarios, word timestamps ensure proper capture.
|
||||
self.assertIn(first_turn["assistant_text"], ["", "Sure I can", "Sure I can"])
|
||||
self.assertTrue(first_turn["interrupted"])
|
||||
|
||||
async def test_multiple_turns(self):
|
||||
"""Test multiple back-and-forth turns."""
|
||||
processor = TurnAwareTranscriptProcessor()
|
||||
|
||||
# Track events
|
||||
turn_started_calls = []
|
||||
turn_ended_calls = []
|
||||
|
||||
@processor.event_handler("on_turn_started")
|
||||
async def on_turn_started(proc, turn_number):
|
||||
turn_started_calls.append(turn_number)
|
||||
|
||||
@processor.event_handler("on_turn_ended")
|
||||
async def on_turn_ended(proc, turn_number, user_text, assistant_text, interrupted):
|
||||
turn_ended_calls.append(
|
||||
{
|
||||
"turn_number": turn_number,
|
||||
"user_text": user_text,
|
||||
"assistant_text": assistant_text,
|
||||
}
|
||||
)
|
||||
|
||||
frames_to_send = [
|
||||
# Turn 1
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Hi", user_id="user1", timestamp=""),
|
||||
SleepFrame(sleep=0.01), # Allow transcription to process
|
||||
BotStartedSpeakingFrame(),
|
||||
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
|
||||
BotStoppedSpeakingFrame(),
|
||||
SleepFrame(sleep=0.05),
|
||||
# Turn 2
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="How are you", user_id="user1", timestamp=""),
|
||||
SleepFrame(sleep=0.01), # Allow transcription to process
|
||||
BotStartedSpeakingFrame(),
|
||||
TTSTextFrame(text="I'm", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text=" good", aggregated_by=AggregationType.WORD),
|
||||
BotStoppedSpeakingFrame(),
|
||||
SleepFrame(sleep=0.1),
|
||||
]
|
||||
|
||||
await run_test(processor, frames_to_send=frames_to_send)
|
||||
|
||||
# Verify multiple turns
|
||||
self.assertEqual(
|
||||
len(turn_started_calls), 2, f"Expected 2 turns started, got {len(turn_started_calls)}"
|
||||
)
|
||||
self.assertEqual(turn_started_calls, [1, 2])
|
||||
|
||||
self.assertEqual(
|
||||
len(turn_ended_calls), 2, f"Expected 2 turns ended, got {len(turn_ended_calls)}"
|
||||
)
|
||||
self.assertEqual(turn_ended_calls[0]["turn_number"], 1)
|
||||
self.assertEqual(turn_ended_calls[0]["user_text"], "Hi")
|
||||
self.assertEqual(turn_ended_calls[0]["assistant_text"], "Hello")
|
||||
|
||||
self.assertEqual(turn_ended_calls[1]["turn_number"], 2)
|
||||
self.assertEqual(turn_ended_calls[1]["user_text"], "How are you")
|
||||
self.assertEqual(turn_ended_calls[1]["assistant_text"], "I'm good")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user