Compare commits
1 Commits
filipi/sma
...
mb/tts-spe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b8dbf9728b |
1
changelog/3465.changed.md
Normal file
1
changelog/3465.changed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added an `append_to_context` parameter to `TTSSpeakFrame` that controls whether the spoken text is appended to the LLM context. Defaults to `False`.
|
||||
@@ -341,6 +341,11 @@ class TextFrame(DataFrame):
|
||||
|
||||
Parameters:
|
||||
text: The text content.
|
||||
skip_tts: Whether this text should be skipped by the TTS service.
|
||||
includes_inter_frame_spaces: Whether any necessary inter-frame (leading/trailing) spaces are already
|
||||
included in the text.
|
||||
append_to_context: Whether this text should be appended to the LLM context.
|
||||
Defaults to True.
|
||||
"""
|
||||
|
||||
text: str
|
||||
@@ -918,9 +923,12 @@ class TTSSpeakFrame(DataFrame):
|
||||
|
||||
Parameters:
|
||||
text: The text to be spoken.
|
||||
append_to_context: Whether this text should be appended to the LLM context.
|
||||
Defaults to False.
|
||||
"""
|
||||
|
||||
text: str
|
||||
append_to_context: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -49,6 +49,7 @@ from pipecat.frames.frames import (
|
||||
StartFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
TranslationFrame,
|
||||
UserImageRawFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
@@ -639,7 +640,6 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
self._started = 0
|
||||
self._function_calls_in_progress: Dict[str, Optional[FunctionCallInProgressFrame]] = {}
|
||||
self._function_calls_image_results: Dict[str, UserImageRawFrame] = {}
|
||||
self._context_updated_tasks: Set[asyncio.Task] = set()
|
||||
@@ -758,7 +758,6 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
|
||||
async def _handle_interruptions(self, frame: InterruptionFrame):
|
||||
await self._trigger_assistant_turn_stopped()
|
||||
self._started = 0
|
||||
await self.reset()
|
||||
|
||||
async def _handle_function_calls_started(self, frame: FunctionCallsStartedFrame):
|
||||
@@ -904,15 +903,17 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
)
|
||||
|
||||
async def _handle_llm_start(self, _: LLMFullResponseStartFrame):
|
||||
self._started += 1
|
||||
await self._trigger_assistant_turn_started()
|
||||
|
||||
async def _handle_llm_end(self, _: LLMFullResponseEndFrame):
|
||||
self._started -= 1
|
||||
await self._trigger_assistant_turn_stopped()
|
||||
|
||||
async def _handle_text(self, frame: TextFrame):
|
||||
if not self._started or not frame.append_to_context:
|
||||
# Skip TextFrame types not intended to build the assistant context
|
||||
if isinstance(frame, (TranscriptionFrame, TranslationFrame, InterimTranscriptionFrame)):
|
||||
return
|
||||
|
||||
if not frame.append_to_context:
|
||||
return
|
||||
|
||||
# Make sure we really have text (spaces count, too!)
|
||||
@@ -926,18 +927,12 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
)
|
||||
|
||||
async def _handle_thought_start(self, frame: LLMThoughtStartFrame):
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
await self._reset_thought_aggregation()
|
||||
self._thought_append_to_context = frame.append_to_context
|
||||
self._thought_llm = frame.llm
|
||||
self._thought_start_time = time_now_iso8601()
|
||||
|
||||
async def _handle_thought_text(self, frame: LLMThoughtTextFrame):
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
# Make sure we really have text (spaces count, too!)
|
||||
if len(frame.text) == 0:
|
||||
return
|
||||
@@ -949,9 +944,6 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
)
|
||||
|
||||
async def _handle_thought_end(self, frame: LLMThoughtEndFrame):
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
thought = concatenate_aggregated_text(self._thought_aggregation)
|
||||
|
||||
if self._thought_append_to_context:
|
||||
|
||||
@@ -207,6 +207,9 @@ class TTSService(AIService):
|
||||
|
||||
self._processing_text: bool = False
|
||||
|
||||
# Track append_to_context for the current TTS generation (used by WordTTSService subclasses)
|
||||
self._current_append_to_context: Optional[bool] = None
|
||||
|
||||
self._register_event_handler("on_connected")
|
||||
self._register_event_handler("on_disconnected")
|
||||
self._register_event_handler("on_connection_error")
|
||||
@@ -460,7 +463,10 @@ class TTSService(AIService):
|
||||
# Store if we were processing text or not so we can set it back.
|
||||
processing_text = self._processing_text
|
||||
# Assumption: text in TTSSpeakFrame does not include inter-frame spaces
|
||||
await self._push_tts_frames(AggregatedTextFrame(frame.text, AggregationType.SENTENCE))
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(frame.text, AggregationType.SENTENCE),
|
||||
append_to_context=frame.append_to_context,
|
||||
)
|
||||
# We pause processing incoming frames because we are sending data to
|
||||
# the TTS. We pause to avoid audio overlapping.
|
||||
await self._maybe_pause_frame_processing()
|
||||
@@ -571,7 +577,10 @@ class TTSService(AIService):
|
||||
)
|
||||
|
||||
async def _push_tts_frames(
|
||||
self, src_frame: AggregatedTextFrame, includes_inter_frame_spaces: Optional[bool] = False
|
||||
self,
|
||||
src_frame: AggregatedTextFrame,
|
||||
includes_inter_frame_spaces: Optional[bool] = False,
|
||||
append_to_context: Optional[bool] = None,
|
||||
):
|
||||
type = src_frame.aggregated_by
|
||||
text = src_frame.text
|
||||
@@ -623,6 +632,10 @@ class TTSService(AIService):
|
||||
if aggregation_type == type or aggregation_type == "*":
|
||||
transformed_text = await transform(transformed_text, type)
|
||||
|
||||
# Store append_to_context for use by WordTTSService subclasses
|
||||
if append_to_context is not None:
|
||||
self._current_append_to_context = append_to_context
|
||||
|
||||
# Apply any final text preparation (e.g., trailing space)
|
||||
prepared_text = self._prepare_text_for_tts(transformed_text)
|
||||
await self.process_generator(self.run_tts(prepared_text))
|
||||
@@ -639,7 +652,15 @@ class TTSService(AIService):
|
||||
# or transformations.
|
||||
frame = TTSTextFrame(text, aggregated_by=type)
|
||||
frame.includes_inter_frame_spaces = includes_inter_frame_spaces
|
||||
# If append_to_context was explicitly specified (e.g., from TTSSpeakFrame),
|
||||
# use that value; otherwise, use TTSTextFrame's default (True).
|
||||
if self._current_append_to_context is not None:
|
||||
frame.append_to_context = self._current_append_to_context
|
||||
await self.push_frame(frame)
|
||||
# Reset after pushing the frame to avoid affecting subsequent TTS operations.
|
||||
# Note: WordTTSService subclasses don't use _push_text_frames so this reset
|
||||
# does not affect that class.
|
||||
self._current_append_to_context = None
|
||||
|
||||
async def _stop_frame_handler(self):
|
||||
has_started = False
|
||||
@@ -690,6 +711,7 @@ class WordTTSService(TTSService):
|
||||
async def reset_word_timestamps(self):
|
||||
"""Reset word timestamp tracking."""
|
||||
self._initial_word_timestamp = -1
|
||||
self._current_append_to_context = None
|
||||
|
||||
async def add_word_timestamps(self, word_times: List[Tuple[str, float]]):
|
||||
"""Add word timestamps to the processing queue.
|
||||
@@ -783,6 +805,10 @@ class WordTTSService(TTSService):
|
||||
# we can rely on the default includes_inter_frame_spaces=False
|
||||
frame = TTSTextFrame(word, aggregated_by=AggregationType.WORD)
|
||||
frame.pts = self._initial_word_timestamp + timestamp
|
||||
# Apply the append_to_context setting from the original TTSSpeakFrame
|
||||
# if it was explicitly set; otherwise rely on TTSTextFrame's default (True).
|
||||
if self._current_append_to_context is not None:
|
||||
frame.append_to_context = self._current_append_to_context
|
||||
if frame:
|
||||
last_pts = frame.pts
|
||||
await self.push_frame(frame)
|
||||
|
||||
Reference in New Issue
Block a user