Compare commits

...

2 Commits

Author SHA1 Message Date
James Hush
8375d299bc Revert 2025-11-28 13:53:12 +01:00
James Hush
98df964e68 fix: propagate skip_tts flag through LLM response frames
- Add skip_tts as an init parameter for TextFrame, LLMFullResponseStartFrame,
  and LLMFullResponseEndFrame instead of setting it post-init
- Update all LLM services to pass skip_tts when creating frames:
  - Anthropic, AWS (Bedrock, Nova Sonic, AgentCore), Google (Gemini, Gemini Live)
  - OpenAI (base, realtime), OpenAI Realtime Beta, SambaNova
- Add _get_skip_tts() helper method in LLMService base class
- Remove push_frame override that was setting skip_tts after frame creation
2025-11-28 13:40:09 +01:00
13 changed files with 80 additions and 89 deletions

View File

@@ -327,23 +327,19 @@ class TextFrame(DataFrame):
Parameters:
text: The text content.
skip_tts: Whether this text should skip TTS processing.
"""
text: str
skip_tts: bool = field(init=False)
skip_tts: bool = field(default=False, kw_only=True)
# Whether any necessary inter-frame (leading/trailing) spaces are already
# included in the text.
# NOTE: Ideally this would be available at init time with a default value,
# but that would impact how subclasses can be initialized (it would require
# mandatory fields of theirs to have defaults to preserve
# non-default-before-default argument order)
includes_inter_frame_spaces: bool = field(init=False)
# Whether this text frame should be appended to the LLM context.
append_to_context: bool = field(init=False)
def __post_init__(self):
super().__post_init__()
self.skip_tts = False
self.includes_inter_frame_spaces = False
self.append_to_context = True
@@ -1630,24 +1626,23 @@ class LLMFullResponseStartFrame(ControlFrame):
Used to indicate the beginning of an LLM response. Followed by one or
more TextFrames and a final LLMFullResponseEndFrame.
Parameters:
skip_tts: Whether LLM output should skip TTS processing.
"""
skip_tts: bool = field(init=False)
def __post_init__(self):
super().__post_init__()
self.skip_tts = False
skip_tts: bool = field(default=False, kw_only=True)
@dataclass
class LLMFullResponseEndFrame(ControlFrame):
"""Frame indicating the end of an LLM response."""
"""Frame indicating the end of an LLM response.
skip_tts: bool = field(init=False)
Parameters:
skip_tts: Whether LLM output should skip TTS processing.
"""
def __post_init__(self):
super().__post_init__()
self.skip_tts = False
skip_tts: bool = field(default=False, kw_only=True)
@dataclass

View File

@@ -327,7 +327,7 @@ class AnthropicLLMService(LLMService):
cache_read_input_tokens = 0
try:
await self.push_frame(LLMFullResponseStartFrame())
await self.push_frame(LLMFullResponseStartFrame(skip_tts=self._get_skip_tts()))
await self.start_processing_metrics()
params_from_context = self._get_llm_invocation_params(context)
@@ -373,7 +373,9 @@ class AnthropicLLMService(LLMService):
if event.type == "content_block_delta":
if hasattr(event.delta, "text"):
await self.push_frame(LLMTextFrame(event.delta.text))
await self.push_frame(
LLMTextFrame(event.delta.text, skip_tts=self._get_skip_tts())
)
completion_tokens_estimate += self._estimate_tokens(event.delta.text)
elif hasattr(event.delta, "partial_json") and tool_use_block:
json_accumulator += event.delta.partial_json
@@ -461,7 +463,7 @@ class AnthropicLLMService(LLMService):
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
await self.stop_processing_metrics()
await self.push_frame(LLMFullResponseEndFrame())
await self.push_frame(LLMFullResponseEndFrame(skip_tts=self._get_skip_tts()))
comp_tokens = (
completion_tokens
if not use_completion_tokens_estimate

View File

@@ -172,7 +172,7 @@ class AWSAgentCoreProcessor(FrameProcessor):
await asyncio.sleep(self._output_response_timeout)
if self._output_response_open:
self._output_response_open = False
await self.push_frame(LLMFullResponseEndFrame())
await self.push_frame(LLMFullResponseEndFrame(skip_tts=self._get_skip_tts()))
async def _push_text_frame(self, text: str):
"""Push a text frame, managing output response bookends."""
@@ -182,11 +182,11 @@ class AWSAgentCoreProcessor(FrameProcessor):
# Open output response if needed
if not self._output_response_open:
await self.push_frame(LLMFullResponseStartFrame())
await self.push_frame(LLMFullResponseStartFrame(skip_tts=self._get_skip_tts()))
self._output_response_open = True
# Push the text frame
await self.push_frame(LLMTextFrame(text))
await self.push_frame(LLMTextFrame(text, skip_tts=self._get_skip_tts()))
self._last_text_frame_time = asyncio.get_event_loop().time()
# Schedule closing the output response after timeout
@@ -253,6 +253,6 @@ class AWSAgentCoreProcessor(FrameProcessor):
if self._close_task and not self._close_task.done():
await self.cancel_task(self._close_task)
self._output_response_open = False
await self.push_frame(LLMFullResponseEndFrame())
await self.push_frame(LLMFullResponseEndFrame(skip_tts=self._get_skip_tts()))
else:
await self.push_frame(frame, direction)

View File

@@ -981,7 +981,7 @@ class AWSBedrockLLMService(LLMService):
using_noop_tool = False
try:
await self.push_frame(LLMFullResponseStartFrame())
await self.push_frame(LLMFullResponseStartFrame(skip_tts=self._get_skip_tts()))
await self.start_processing_metrics()
await self.start_ttfb_metrics()
@@ -1078,7 +1078,9 @@ class AWSBedrockLLMService(LLMService):
if "contentBlockDelta" in event:
delta = event["contentBlockDelta"]["delta"]
if "text" in delta:
await self.push_frame(LLMTextFrame(delta["text"]))
await self.push_frame(
LLMTextFrame(delta["text"], skip_tts=self._get_skip_tts())
)
completion_tokens_estimate += self._estimate_tokens(delta["text"])
elif "toolUse" in delta and "input" in delta["toolUse"]:
# Handle partial JSON for tool use
@@ -1139,7 +1141,7 @@ class AWSBedrockLLMService(LLMService):
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
await self.stop_processing_metrics()
await self.push_frame(LLMFullResponseEndFrame())
await self.push_frame(LLMFullResponseEndFrame(skip_tts=self._get_skip_tts()))
comp_tokens = (
completion_tokens
if not use_completion_tokens_estimate

View File

@@ -1016,7 +1016,7 @@ class AWSNovaSonicLLMService(LLMService):
logger.debug("Assistant response started")
# Report the start of the assistant response.
await self.push_frame(LLMFullResponseStartFrame())
await self.push_frame(LLMFullResponseStartFrame(skip_tts=self._get_skip_tts()))
# Report that the equivalent of TTS (this is a speech-to-speech model) started.
await self.push_frame(TTSStartedFrame())
@@ -1062,7 +1062,7 @@ class AWSNovaSonicLLMService(LLMService):
# We also need to re-push the LLMFullResponseStartFrame since the
# TTSTextFrame would be ignored otherwise (the interruption frame
# would have cleared the assistant aggregator state).
await self.push_frame(LLMFullResponseStartFrame())
await self.push_frame(LLMFullResponseStartFrame(skip_tts=self._get_skip_tts()))
frame = TTSTextFrame(
self._assistant_text_buffer, aggregated_by=AggregationType.SENTENCE
)
@@ -1071,7 +1071,7 @@ class AWSNovaSonicLLMService(LLMService):
self._may_need_repush_assistant_text = False
# Report the end of the assistant response.
await self.push_frame(LLMFullResponseEndFrame())
await self.push_frame(LLMFullResponseEndFrame(skip_tts=self._get_skip_tts()))
# Report that the equivalent of TTS (this is a speech-to-speech model) stopped.
await self.push_frame(TTSStoppedFrame())

View File

@@ -1448,11 +1448,11 @@ class GeminiLiveLLMService(LLMService):
# Update bot responding state and send service start frame
# (AUDIO modality case)
await self._set_bot_is_responding(True)
await self.push_frame(LLMFullResponseStartFrame())
await self.push_frame(LLMFullResponseStartFrame(skip_tts=self._get_skip_tts()))
self._bot_text_buffer += text
self._search_result_buffer += text # Also accumulate for grounding
frame = LLMTextFrame(text=text)
frame = LLMTextFrame(text=text, skip_tts=self._get_skip_tts())
await self.push_frame(frame)
# Check for grounding metadata in server content
@@ -1491,7 +1491,7 @@ class GeminiLiveLLMService(LLMService):
if not self._bot_is_responding:
await self._set_bot_is_responding(True)
await self.push_frame(TTSStartedFrame())
await self.push_frame(LLMFullResponseStartFrame())
await self.push_frame(LLMFullResponseStartFrame(skip_tts=self._get_skip_tts()))
self._bot_audio_buffer.extend(audio)
frame = TTSAudioRawFrame(
@@ -1552,10 +1552,10 @@ class GeminiLiveLLMService(LLMService):
if not text:
# AUDIO modality case
await self.push_frame(TTSStoppedFrame())
await self.push_frame(LLMFullResponseEndFrame())
await self.push_frame(LLMFullResponseEndFrame(skip_tts=self._get_skip_tts()))
else:
# TEXT modality case
await self.push_frame(LLMFullResponseEndFrame())
await self.push_frame(LLMFullResponseEndFrame(skip_tts=self._get_skip_tts()))
@traced_stt
async def _handle_user_transcription(
@@ -1643,7 +1643,7 @@ class GeminiLiveLLMService(LLMService):
if not self._bot_is_responding:
await self._set_bot_is_responding(True)
await self.push_frame(TTSStartedFrame())
await self.push_frame(LLMFullResponseStartFrame())
await self.push_frame(LLMFullResponseStartFrame(skip_tts=self._get_skip_tts()))
frame = TTSTextFrame(text=text, aggregated_by=AggregationType.SENTENCE)
# Gemini Live text already includes any necessary inter-chunk spaces

View File

@@ -876,7 +876,7 @@ class GoogleLLMService(LLMService):
@traced_llm
async def _process_context(self, context: OpenAILLMContext | LLMContext):
await self.push_frame(LLMFullResponseStartFrame())
await self.push_frame(LLMFullResponseStartFrame(skip_tts=self._get_skip_tts()))
prompt_tokens = 0
completion_tokens = 0
@@ -920,7 +920,9 @@ class GoogleLLMService(LLMService):
for part in candidate.content.parts:
if not part.thought and part.text:
search_result += part.text
await self.push_frame(LLMTextFrame(part.text))
await self.push_frame(
LLMTextFrame(part.text, skip_tts=self._get_skip_tts())
)
elif part.function_call:
function_call = part.function_call
id = function_call.id or str(uuid.uuid4())
@@ -1002,7 +1004,7 @@ class GoogleLLMService(LLMService):
reasoning_tokens=reasoning_tokens,
)
)
await self.push_frame(LLMFullResponseEndFrame())
await self.push_frame(LLMFullResponseEndFrame(skip_tts=self._get_skip_tts()))
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process incoming frames and handle different frame types.

View File

@@ -136,7 +136,9 @@ class GoogleLLMOpenAIBetaService(OpenAILLMService):
# Keep iterating through the response to collect all the argument fragments
arguments += tool_call.function.arguments
elif chunk.choices[0].delta.content:
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content))
await self.push_frame(
LLMTextFrame(chunk.choices[0].delta.content, skip_tts=self._get_skip_tts())
)
# if we got a function name and arguments, check to see if it's a function with
# a registered handler. If so, run the registered callback, save the result to

View File

@@ -9,17 +9,7 @@
import asyncio
import inspect
from dataclasses import dataclass
from typing import (
Any,
Awaitable,
Callable,
Dict,
Mapping,
Optional,
Protocol,
Sequence,
Type,
)
from typing import Any, Awaitable, Callable, Dict, Mapping, Optional, Protocol, Sequence, Type
from loguru import logger
@@ -285,17 +275,13 @@ class LLMService(AIService):
elif isinstance(frame, LLMConfigureOutputFrame):
self._skip_tts = frame.skip_tts
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
"""Pushes a frame.
def _get_skip_tts(self) -> bool:
"""Get the current skip_tts configuration.
Args:
frame: The frame to push.
direction: The direction of frame pushing.
Returns:
The current skip_tts setting for frames generated by this LLM.
"""
if isinstance(frame, (LLMTextFrame, LLMFullResponseStartFrame, LLMFullResponseEndFrame)):
frame.skip_tts = self._skip_tts
await super().push_frame(frame, direction)
return self._skip_tts
async def _handle_interruptions(self, _: InterruptionFrame):
for function_name, entry in self._functions.items():

View File

@@ -13,13 +13,7 @@ from typing import Any, Dict, List, Mapping, Optional
import httpx
from loguru import logger
from openai import (
NOT_GIVEN,
APITimeoutError,
AsyncOpenAI,
AsyncStream,
DefaultAsyncHttpxClient,
)
from openai import NOT_GIVEN, APITimeoutError, AsyncOpenAI, AsyncStream, DefaultAsyncHttpxClient
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
from pydantic import BaseModel, Field
@@ -396,14 +390,20 @@ class BaseOpenAILLMService(LLMService):
# Keep iterating through the response to collect all the argument fragments
arguments += tool_call.function.arguments
elif chunk.choices[0].delta.content:
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content))
await self.push_frame(
LLMTextFrame(chunk.choices[0].delta.content, skip_tts=self._get_skip_tts())
)
# When gpt-4o-audio / gpt-4o-mini-audio is used for llm or stt+llm
# we need to get LLMTextFrame for the transcript
elif hasattr(chunk.choices[0].delta, "audio") and chunk.choices[0].delta.audio.get(
"transcript"
):
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.audio["transcript"]))
await self.push_frame(
LLMTextFrame(
chunk.choices[0].delta.audio["transcript"], skip_tts=self._get_skip_tts()
)
)
# if we got a function name and arguments, check to see if it's a function with
# a registered handler. If so, run the registered callback, save the result to
@@ -463,11 +463,11 @@ class BaseOpenAILLMService(LLMService):
if context:
try:
await self.push_frame(LLMFullResponseStartFrame())
await self.push_frame(LLMFullResponseStartFrame(skip_tts=self._get_skip_tts()))
await self.start_processing_metrics()
await self._process_context(context)
except httpx.TimeoutException:
await self._call_event_handler("on_completion_timeout")
finally:
await self.stop_processing_metrics()
await self.push_frame(LLMFullResponseEndFrame())
await self.push_frame(LLMFullResponseEndFrame(skip_tts=self._get_skip_tts()))

View File

@@ -15,9 +15,7 @@ from typing import Optional
from loguru import logger
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.adapters.services.open_ai_realtime_adapter import (
OpenAIRealtimeLLMAdapter,
)
from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
from pipecat.frames.frames import (
AggregationType,
BotStoppedSpeakingFrame,
@@ -284,7 +282,7 @@ class OpenAIRealtimeLLMService(LLMService):
await self._truncate_current_audio_response()
await self.stop_all_metrics()
if self._current_assistant_response:
await self.push_frame(LLMFullResponseEndFrame())
await self.push_frame(LLMFullResponseEndFrame(skip_tts=self._get_skip_tts()))
# Only push TTSStoppedFrame if audio modality is enabled
if self._is_modality_enabled("audio"):
await self.push_frame(TTSStoppedFrame())
@@ -608,7 +606,7 @@ class OpenAIRealtimeLLMService(LLMService):
if evt.item.role == "assistant":
self._current_assistant_response = evt.item
await self.push_frame(LLMFullResponseStartFrame())
await self.push_frame(LLMFullResponseStartFrame(skip_tts=self._get_skip_tts()))
async def _handle_evt_conversation_item_done(self, evt):
"""Handle conversation.item.done event - item is fully completed."""
@@ -669,7 +667,7 @@ class OpenAIRealtimeLLMService(LLMService):
)
await self.start_llm_usage_metrics(tokens)
await self.stop_processing_metrics()
await self.push_frame(LLMFullResponseEndFrame())
await self.push_frame(LLMFullResponseEndFrame(skip_tts=self._get_skip_tts()))
self._current_assistant_response = None
# error handling
if evt.response.status == "failed":
@@ -683,7 +681,7 @@ class OpenAIRealtimeLLMService(LLMService):
# We receive text deltas (as opposed to audio transcript deltas) when
# the output modality is "text"
if evt.delta:
frame = LLMTextFrame(evt.delta)
frame = LLMTextFrame(evt.delta, skip_tts=self._get_skip_tts())
await self.push_frame(frame)
async def _handle_evt_audio_transcript_delta(self, evt):
@@ -817,7 +815,7 @@ class OpenAIRealtimeLLMService(LLMService):
logger.debug("Creating response")
await self.push_frame(LLMFullResponseStartFrame())
await self.push_frame(LLMFullResponseStartFrame(skip_tts=self._get_skip_tts()))
await self.start_processing_metrics()
await self.start_ttfb_metrics()
await self.send_client_event(

View File

@@ -265,7 +265,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
await self._truncate_current_audio_response()
await self.stop_all_metrics()
if self._current_assistant_response:
await self.push_frame(LLMFullResponseEndFrame())
await self.push_frame(LLMFullResponseEndFrame(skip_tts=self._get_skip_tts()))
# Only push TTSStoppedFrame if audio modality is enabled
if self._is_modality_enabled("audio"):
await self.push_frame(TTSStoppedFrame())
@@ -564,7 +564,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
self._user_and_response_message_tuple = (evt.item, {"done": False, "output": []})
elif evt.item.role == "assistant":
self._current_assistant_response = evt.item
await self.push_frame(LLMFullResponseStartFrame())
await self.push_frame(LLMFullResponseStartFrame(skip_tts=self._get_skip_tts()))
async def _handle_evt_input_audio_transcription_delta(self, evt):
if self._send_transcription_frames:
@@ -623,7 +623,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
)
await self.start_llm_usage_metrics(tokens)
await self.stop_processing_metrics()
await self.push_frame(LLMFullResponseEndFrame())
await self.push_frame(LLMFullResponseEndFrame(skip_tts=self._get_skip_tts()))
self._current_assistant_response = None
# error handling
if evt.response.status == "failed":
@@ -647,11 +647,11 @@ class OpenAIRealtimeBetaLLMService(LLMService):
async def _handle_evt_text_delta(self, evt):
if evt.delta:
await self.push_frame(LLMTextFrame(evt.delta))
await self.push_frame(LLMTextFrame(evt.delta, skip_tts=self._get_skip_tts()))
async def _handle_evt_audio_transcript_delta(self, evt):
if evt.delta:
await self.push_frame(LLMTextFrame(evt.delta))
await self.push_frame(LLMTextFrame(evt.delta, skip_tts=self._get_skip_tts()))
await self.push_frame(TTSTextFrame(evt.delta, aggregated_by=AggregationType.SENTENCE))
async def _handle_evt_speech_started(self, evt):
@@ -747,7 +747,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
logger.debug(f"Creating response: {self._context.get_messages_for_logging()}")
await self.push_frame(LLMFullResponseStartFrame())
await self.push_frame(LLMFullResponseStartFrame(skip_tts=self._get_skip_tts()))
await self.start_processing_metrics()
await self.start_ttfb_metrics()
await self.send_client_event(

View File

@@ -14,9 +14,7 @@ from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk
from pipecat.adapters.services.open_ai_adapter import OpenAILLMInvocationParams
from pipecat.frames.frames import (
LLMTextFrame,
)
from pipecat.frames.frames import LLMTextFrame
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
@@ -176,14 +174,20 @@ class SambaNovaLLMService(OpenAILLMService): # type: ignore
# Keep iterating through the response to collect all the argument fragments
arguments += tool_call.function.arguments
elif chunk.choices[0].delta.content:
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content))
await self.push_frame(
LLMTextFrame(chunk.choices[0].delta.content, skip_tts=self._get_skip_tts())
)
# When gpt-4o-audio / gpt-4o-mini-audio is used for llm or stt+llm
# we need to get LLMTextFrame for the transcript
elif hasattr(chunk.choices[0].delta, "audio") and chunk.choices[0].delta.audio.get(
"transcript"
):
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.audio["transcript"]))
await self.push_frame(
LLMTextFrame(
chunk.choices[0].delta.audio["transcript"], skip_tts=self._get_skip_tts()
)
)
# if we got a function name and arguments, check to see if it's a function with
# a registered handler. If so, run the registered callback, save the result to