diff --git a/examples/realtime/realtime-aws-nova-sonic.py b/examples/realtime/realtime-aws-nova-sonic.py index 5e6a2037c..0e8e67eaf 100644 --- a/examples/realtime/realtime-aws-nova-sonic.py +++ b/examples/realtime/realtime-aws-nova-sonic.py @@ -30,6 +30,7 @@ from pipecat.processors.aggregators.llm_response_universal import ( from pipecat.runner.types import RunnerArguments from pipecat.runner.utils import create_transport from pipecat.services.aws.nova_sonic.llm import AWSNovaSonicLLMService +from pipecat.services.aws.nova_sonic.session_continuation import SessionContinuationParams from pipecat.services.llm_service import FunctionCallParams from pipecat.transports.base_transport import BaseTransport, TransportParams from pipecat.transports.daily.transport import DailyParams @@ -132,6 +133,16 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): voice="tiffany", system_instruction=system_instruction, ), + # Session continuation is enabled by default, allowing seamless + # conversations longer than the AWS ~8-minute session limit. + # The service rotates sessions in the background with no + # user-perceptible interruption. You can tune the threshold or + # disable it with: session_continuation=SessionContinuationParams(enabled=False) + session_continuation=SessionContinuationParams( + # When to start preparing the next session (default: 360 = 6 min). + # Lower this (e.g. 20) to see a handoff happen quickly during testing. + transition_threshold_seconds=360, + ), # you could choose to pass tools here rather than via context # tools=tools ) diff --git a/src/pipecat/services/aws/nova_sonic/llm.py b/src/pipecat/services/aws/nova_sonic/llm.py index 1d0ed7b8f..e16f288e1 100644 --- a/src/pipecat/services/aws/nova_sonic/llm.py +++ b/src/pipecat/services/aws/nova_sonic/llm.py @@ -12,6 +12,7 @@ bidirectional audio streaming, text generation, and function calling capabilitie import asyncio import base64 +import concurrent.futures import json import time import uuid @@ -50,6 +51,10 @@ from pipecat.frames.frames import ( ) from pipecat.processors.aggregators.llm_context import LLMContext from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.aws.nova_sonic.session_continuation import ( + SessionContinuationHelper, + SessionContinuationParams, +) from pipecat.services.llm_service import LLMService from pipecat.services.settings import NOT_GIVEN, LLMSettings, _NotGiven, assert_given from pipecat.utils.time import time_now_iso8601 @@ -257,6 +262,7 @@ class AWSNovaSonicLLMService(LLMService): settings: Settings | None = None, system_instruction: str | None = None, tools: ToolsSchema | None = None, + session_continuation: SessionContinuationParams | None = None, **kwargs, ): """Initializes the AWS Nova Sonic LLM service. @@ -300,6 +306,9 @@ class AWSNovaSonicLLMService(LLMService): .. deprecated:: 0.0.105 Use ``settings=AWSNovaSonicLLMService.Settings(system_instruction=...)`` instead. tools: Available tools/functions for the model to use. + session_continuation: Configuration for automatic session continuation. + When enabled (the default), sessions are seamlessly rotated before + the AWS time limit (~8 minutes) with no user-perceptible interruption. **kwargs: Additional arguments passed to the parent LLMService. """ # 1. Initialize default_settings with hardcoded defaults @@ -406,6 +415,18 @@ class AWSNovaSonicLLMService(LLMService): self._user_text_buffer = "" self._completed_tool_calls = set() self._audio_input_started = False + + # Session continuation helper. The service itself implements the + # NovaSonicSessionSender protocol (see methods below) so the helper can + # target either the current or next session without coupling to the + # service's internal config. + sc_params = session_continuation or SessionContinuationParams() + self._sc = SessionContinuationHelper( + sc_params, + sender=self, + create_task=lambda coro: self.create_task(coro), + cancel_task=lambda task, timeout: self.cancel_task(task, timeout=timeout), + ) self._pending_speculative_text: str | None = None file_path = files("pipecat.services.aws.nova_sonic").joinpath("ready.wav") @@ -533,6 +554,16 @@ class AWSNovaSonicLLMService(LLMService): if self._triggering_assistant_response: return + # Session continuation: let the helper buffer audio during the transition window + self._sc.on_audio_input(frame.audio) + + # Stop sending audio to the old stream once a handoff is in progress. + # Audio is still being buffered above and will be replayed to the new + # session. Matches reference: old session state set to CLOSING stops + # audio routing before close events are sent. + if self._sc.handoff_in_progress: + return + await self._send_user_audio_event(frame.audio) async def _handle_interruption_frame(self): @@ -555,7 +586,7 @@ class AWSNovaSonicLLMService(LLMService): self._input_audio_content_name = str(uuid.uuid4()) # Create the client - self._client = self._create_client() + self._client = self.create_client() # Start the bidirectional stream self._stream = await self._client.invoke_model_with_bidirectional_stream( @@ -645,14 +676,22 @@ class AWSNovaSonicLLMService(LLMService): text=last_user_message.text, role=last_user_message.role, interactive=True ) - # Start receiving events - self._receive_task = self.create_task(self._receive_task_handler()) + # Start receiving events (bound to the current stream) + self._receive_task = self.create_task(self._receive_task_handler(stream=self._stream)) # Record finished connecting time (must be done before sending assistant response trigger) self._connected_time = time.time() logger.info("Finished connecting") + # Notify session continuation helper of connection and start monitoring + self._sc.set_connected(self._connected_time) + # Seed the helper's history with initial context messages (these wouldn't be + # captured via real-time FINAL text events since they pre-date the session) + for message in llm_connection_params["messages"]: + self._sc.seed_history(message.role.value, message.text) + self._sc.start_monitor() + # If we need to, send assistant response trigger (depends on self._connected_time) if self._triggering_assistant_response: await self._send_assistant_response_trigger() @@ -707,11 +746,17 @@ class AWSNovaSonicLLMService(LLMService): self._audio_input_started = False self._pending_speculative_text = None + # Stop session continuation monitor and notify of disconnect + await self._sc.stop_monitor() + await self._sc.cleanup_next_session() + self._sc.set_disconnected() + logger.info("Finished disconnecting") except Exception as e: await self.push_error(error_msg=f"Error disconnecting: {e}", exception=e) - def _create_client(self) -> BedrockRuntimeClient: + def create_client(self) -> BedrockRuntimeClient: + """Create a new Bedrock runtime client (NovaSonicSessionSender protocol).""" config = Config( endpoint_uri=f"https://bedrock-runtime.{self._region}.amazonaws.com", region=self._region, @@ -723,6 +768,11 @@ class AWSNovaSonicLLMService(LLMService): ) return BedrockRuntimeClient(config=config) + @property + def audio_config(self) -> AudioConfig: + """Return the audio configuration (NovaSonicSessionSender protocol).""" + return self._audio_config + def _is_first_generation_sonic_model(self) -> bool: # Nova Sonic (the older model) is identified by "amazon.nova-sonic-v1:0" return self._settings.model == "amazon.nova-sonic-v1:0" @@ -739,98 +789,26 @@ class AWSNovaSonicLLMService(LLMService): # # LLM communication: input events (pipecat -> LLM) # + # These methods operate on the current session. They're thin wrappers over + # the NovaSonicSessionSender protocol methods (which accept an explicit + # stream/prompt_name), reusing the same Nova Sonic wire-format serialization + # for both the current session and next-session setup during a handoff. + # async def _send_session_start_event(self): - turn_detection_config = ( - f""", - "turnDetectionConfiguration": {{ - "endpointingSensitivity": "{self._settings.endpointing_sensitivity}" - }}""" - if self._settings.endpointing_sensitivity - else "" - ) - - session_start = f""" - {{ - "event": {{ - "sessionStart": {{ - "inferenceConfiguration": {{ - "maxTokens": {self._settings.max_tokens}, - "topP": {self._settings.top_p}, - "temperature": {self._settings.temperature} - }}{turn_detection_config} - }} - }} - }} - """ - await self._send_client_event(session_start) + await self._send_client_event(self.build_session_start_json()) async def _send_prompt_start_event(self, tools: list[Any]): if not self._prompt_name: return - - tools_config = ( - f""", - "toolUseOutputConfiguration": {{ - "mediaType": "application/json" - }}, - "toolConfiguration": {{ - "tools": {json.dumps(tools)} - }} - """ - if tools - else "" - ) - - prompt_start = f''' - {{ - "event": {{ - "promptStart": {{ - "promptName": "{self._prompt_name}", - "textOutputConfiguration": {{ - "mediaType": "text/plain" - }}, - "audioOutputConfiguration": {{ - "mediaType": "audio/lpcm", - "sampleRateHertz": {self._audio_config.output_sample_rate}, - "sampleSizeBits": {self._audio_config.output_sample_size}, - "channelCount": {self._audio_config.output_channel_count}, - "voiceId": "{self._settings.voice}", - "encoding": "base64", - "audioType": "SPEECH" - }}{tools_config} - }} - }} - }} - ''' - await self._send_client_event(prompt_start) + await self.send_prompt_start(tools, self._prompt_name, self._stream) async def _send_audio_input_start_event(self): if not self._prompt_name: return - - audio_content_start = f''' - {{ - "event": {{ - "contentStart": {{ - "promptName": "{self._prompt_name}", - "contentName": "{self._input_audio_content_name}", - "type": "AUDIO", - "interactive": true, - "role": "USER", - "audioInputConfiguration": {{ - "mediaType": "audio/lpcm", - "sampleRateHertz": {self._audio_config.input_sample_rate}, - "sampleSizeBits": {self._audio_config.input_sample_size}, - "channelCount": {self._audio_config.input_channel_count}, - "audioType": "SPEECH", - "encoding": "base64" - }} - }} - }} - }} - ''' - await self._send_client_event(audio_content_start) + await self.send_audio_input_start( + self._prompt_name, self._input_audio_content_name, self._stream + ) self._audio_input_started = True async def _send_text_event(self, text: str, role: Role, interactive: bool = False): @@ -845,70 +823,14 @@ class AWSNovaSonicLLMService(LLMService): """ if not self._stream or not self._prompt_name or not text: return - - content_name = str(uuid.uuid4()) - - text_content_start = f''' - {{ - "event": {{ - "contentStart": {{ - "promptName": "{self._prompt_name}", - "contentName": "{content_name}", - "type": "TEXT", - "interactive": {json.dumps(interactive)}, - "role": "{role.value}", - "textInputConfiguration": {{ - "mediaType": "text/plain" - }} - }} - }} - }} - ''' - await self._send_client_event(text_content_start) - - escaped_text = json.dumps(text) # includes quotes - text_input = f''' - {{ - "event": {{ - "textInput": {{ - "promptName": "{self._prompt_name}", - "contentName": "{content_name}", - "content": {escaped_text} - }} - }} - }} - ''' - await self._send_client_event(text_input) - - text_content_end = f''' - {{ - "event": {{ - "contentEnd": {{ - "promptName": "{self._prompt_name}", - "contentName": "{content_name}" - }} - }} - }} - ''' - await self._send_client_event(text_content_end) + await self.send_text(text, role.value, self._prompt_name, self._stream, interactive) async def _send_user_audio_event(self, audio: bytes): if not self._stream or not self._audio_input_started: return - - blob = base64.b64encode(audio) - audio_event = f''' - {{ - "event": {{ - "audioInput": {{ - "promptName": "{self._prompt_name}", - "contentName": "{self._input_audio_content_name}", - "content": "{blob.decode("utf-8")}" - }} - }} - }} - ''' - await self._send_client_event(audio_event) + await self.send_audio( + audio, self._prompt_name, self._input_audio_content_name, self._stream + ) async def _send_session_end_events(self): if not self._stream or not self._prompt_name: @@ -998,6 +920,248 @@ class AWSNovaSonicLLMService(LLMService): ) await self._stream.input_stream.send(event) + # + # NovaSonicSessionSender protocol implementation + # + # These methods expose the Nova Sonic wire protocol to the session + # continuation helper. Each accepts an explicit ``stream`` / ``prompt_name`` + # so the helper can target either the current session or a pre-created + # next session during a handoff. + # + + def build_session_start_json(self) -> str: + """Build the ``sessionStart`` event JSON. + + Shared between the current and next session setup. + """ + turn_detection_config = ( + f""", + "turnDetectionConfiguration": {{ + "endpointingSensitivity": "{self._settings.endpointing_sensitivity}" + }}""" + if self._settings.endpointing_sensitivity + else "" + ) + return f""" + {{ + "event": {{ + "sessionStart": {{ + "inferenceConfiguration": {{ + "maxTokens": {self._settings.max_tokens}, + "topP": {self._settings.top_p}, + "temperature": {self._settings.temperature} + }}{turn_detection_config} + }} + }} + }} + """ + + async def open_stream(self, client): + """Open a bidirectional stream on the given client.""" + return await client.invoke_model_with_bidirectional_stream( + InvokeModelWithBidirectionalStreamOperationInput(model_id=self._settings.model) + ) + + async def send_event(self, event_json: str, stream): + """Send a raw event JSON to the given stream.""" + if not stream: + return + event = InvokeModelWithBidirectionalStreamInputChunk( + value=BidirectionalInputPayloadPart(bytes_=event_json.encode("utf-8")) + ) + await stream.input_stream.send(event) + + async def send_text( + self, + text: str, + role: str, + prompt_name: str, + stream, + interactive: bool, + ): + """Send a text content block (contentStart/textInput/contentEnd) to the given stream.""" + if not text or not stream or not prompt_name: + return + content_name = str(uuid.uuid4()) + escaped_text = json.dumps(text) + + content_start = f''' + {{ + "event": {{ + "contentStart": {{ + "promptName": "{prompt_name}", + "contentName": "{content_name}", + "type": "TEXT", + "interactive": {json.dumps(interactive)}, + "role": "{role}", + "textInputConfiguration": {{ + "mediaType": "text/plain" + }} + }} + }} + }} + ''' + await self.send_event(content_start, stream) + + text_input = f''' + {{ + "event": {{ + "textInput": {{ + "promptName": "{prompt_name}", + "contentName": "{content_name}", + "content": {escaped_text} + }} + }} + }} + ''' + await self.send_event(text_input, stream) + + content_end = f''' + {{ + "event": {{ + "contentEnd": {{ + "promptName": "{prompt_name}", + "contentName": "{content_name}" + }} + }} + }} + ''' + await self.send_event(content_end, stream) + + async def send_audio_input_start(self, prompt_name: str, content_name: str, stream): + """Send an audio input ``contentStart`` to the given stream.""" + event_json = f''' + {{ + "event": {{ + "contentStart": {{ + "promptName": "{prompt_name}", + "contentName": "{content_name}", + "type": "AUDIO", + "interactive": true, + "role": "USER", + "audioInputConfiguration": {{ + "mediaType": "audio/lpcm", + "sampleRateHertz": {self._audio_config.input_sample_rate}, + "sampleSizeBits": {self._audio_config.input_sample_size}, + "channelCount": {self._audio_config.input_channel_count}, + "audioType": "SPEECH", + "encoding": "base64" + }} + }} + }} + }} + ''' + await self.send_event(event_json, stream) + + async def send_audio(self, audio: bytes, prompt_name: str, content_name: str, stream): + """Send an ``audioInput`` event to the given stream.""" + blob = base64.b64encode(audio) + event_json = f''' + {{ + "event": {{ + "audioInput": {{ + "promptName": "{prompt_name}", + "contentName": "{content_name}", + "content": "{blob.decode("utf-8")}" + }} + }} + }} + ''' + await self.send_event(event_json, stream) + + async def send_prompt_start(self, tools: list, prompt_name: str, stream): + """Send a ``promptStart`` event to the given stream.""" + tools_config = ( + f""", + "toolUseOutputConfiguration": {{ + "mediaType": "application/json" + }}, + "toolConfiguration": {{ + "tools": {json.dumps(tools)} + }} + """ + if tools + else "" + ) + event_json = f''' + {{ + "event": {{ + "promptStart": {{ + "promptName": "{prompt_name}", + "textOutputConfiguration": {{ + "mediaType": "text/plain" + }}, + "audioOutputConfiguration": {{ + "mediaType": "audio/lpcm", + "sampleRateHertz": {self._audio_config.output_sample_rate}, + "sampleSizeBits": {self._audio_config.output_sample_size}, + "channelCount": {self._audio_config.output_channel_count}, + "voiceId": "{self._settings.voice}", + "encoding": "base64", + "audioType": "SPEECH" + }}{tools_config} + }} + }} + }} + ''' + await self.send_event(event_json, stream) + + def get_setup_params(self): + """Return ``(system_instruction, tools)`` for the next session setup.""" + if not self._context: + return None, [] + adapter: AWSNovaSonicLLMAdapter = self.get_llm_adapter() + llm_params = adapter.get_llm_invocation_params( + self._context, system_instruction=self._settings.system_instruction + ) + tools = ( + llm_params["tools"] if llm_params["tools"] else adapter.from_standard_tools(self._tools) + ) + return llm_params["system_instruction"], tools + + async def _run_sc_handoff(self): + """Swap the current session with the pre-created next one.""" + # Snapshot the old session's resources before the helper swaps them out + old_client = self._client + old_stream = self._stream + old_receive_task = self._receive_task + old_prompt_name = self._prompt_name + old_input_audio_content_name = self._input_audio_content_name + + next_session = await self._sc.execute_handoff() + if not next_session: + return + + # Swap in the new session's stream and names. The helper already sent + # sessionStart, promptStart, system instruction, conversation history, + # audioInputStart, and buffered audio to the new stream. + self._client = next_session.client + self._stream = next_session.stream + self._prompt_name = next_session.prompt_name + self._input_audio_content_name = next_session.input_audio_content_name + self._connected_time = time.time() + self._audio_input_started = True + + # Start the main receive loop on the new stream (bound to that stream) + self._receive_task = self.create_task(self._receive_task_handler(stream=self._stream)) + + # Update the helper's connected time so the threshold timer restarts + self._sc.set_connected(self._connected_time) + + logger.info("Session continuation: swap complete, closing old session in background") + + # Close the old session in the background — do not block the pipeline + self.create_task( + self._sc.close_old_session( + old_client, + old_stream, + old_receive_task, + old_prompt_name, + old_input_audio_content_name, + ), + name="sc_close_old_session", + ) + # # LLM communication: output events (LLM -> pipecat) # @@ -1013,11 +1177,28 @@ class AWSNovaSonicLLMService(LLMService): # Each piece of content is wrapped by "contentStart" and "contentEnd" events. The content is # delivered sequentially: one piece of content will end before another starts. # The overall completion is wrapped by "completionStart" and "completionEnd" events. - async def _receive_task_handler(self): + async def _receive_task_handler(self, stream=None): + # Bind to the specific stream given at creation time. + # Do NOT re-read ``self._stream`` in the loop — during a session + # continuation handoff, ``self._stream`` gets swapped to a new session, + # and reading from the wrong stream here would cause two receive loops + # to compete on the same stream (yielding "Invalid input request" from + # the AWS event stream layer). + if stream is None: + stream = self._stream try: - while self._stream and not self._disconnecting: - output = await self._stream.await_output() - result = await output[1].receive() + while stream and not self._disconnecting: + try: + output = await stream.await_output() + result = await output[1].receive() + except concurrent.futures.InvalidStateError: + break + + # After a session continuation handoff, this receive task + # is stale — stop processing events so close_old_session + # can drain the stream without interference. + if stream is not self._stream: + return if result.value and result.value.bytes_: response_data = result.value.bytes_.decode("utf-8") @@ -1048,8 +1229,12 @@ class AWSNovaSonicLLMService(LLMService): await self._handle_completion_end_event(event_json) except Exception as e: if self._disconnecting: - # Errors are kind of expected while disconnecting, so just - # ignore them and do nothing + return + # If this receive task is for a stale (old) stream that was replaced + # by a session continuation handoff, don't reset the conversation — + # the new session is already active on self._stream. + if stream is not self._stream: + logger.debug(f"Session continuation: old receive task error (expected): {e}") return await self.push_error(error_msg=f"Error processing responses: {e}", exception=e) if self._wants_connection: @@ -1085,6 +1270,16 @@ class AWSNovaSonicLLMService(LLMService): self._assistant_is_responding = True await self._report_user_transcription_ended() # Consider user turn over await self._report_assistant_response_started() + elif content.type == ContentType.AUDIO: + # Session continuation: AUDIO contentStart from assistant is the + # trigger to start buffering user audio and creating the next session + # (if we're past the threshold). + await self._sc.on_assistant_audio_started() + elif content.role == Role.USER: + # Session continuation: USER contentStart during a forced transition + # (no assistant response yet) should complete the handoff immediately. + if self._sc.on_user_content_started(): + self.create_task(self._run_sc_handoff(), name="sc_handoff") async def _handle_text_output_event(self, event_json): if not self._content_being_received: # should never happen @@ -1097,6 +1292,12 @@ class AWSNovaSonicLLMService(LLMService): # Assumption: only one text content per content block content.text_content = text_content + # Session continuation: track speculative/final text counts for completion signal + self._sc.on_text_output( + content.role.value, + content.text_stage.value if content.text_stage else None, + ) + async def _handle_audio_output_event(self, event_json): if not self._content_being_received: # should never happen return @@ -1160,12 +1361,21 @@ class AWSNovaSonicLLMService(LLMService): if stop_reason != "INTERRUPTED": if content.text_stage == TextStage.SPECULATIVE: await self._report_llm_text(content.text_content) - elif self._assistant_is_responding: - # TEXT INTERRUPTED with no audio means the user interrupted - # before audio started. End the response here since no AUDIO - # contentEnd will arrive. - self._assistant_is_responding = False - await self._report_assistant_response_ended() + # Session continuation: ASSISTANT FINAL text — add to history + # and check for completion signal (speculative/final counts match) + if content.text_stage == TextStage.FINAL: + if self._sc.on_content_end_assistant_final_text(content.text_content): + self.create_task(self._run_sc_handoff(), name="sc_handoff") + else: + if self._assistant_is_responding: + # TEXT INTERRUPTED before audio started means no AUDIO + # contentEnd will arrive — end the response here. + self._assistant_is_responding = False + await self._report_assistant_response_ended() + # Session continuation: TEXT INTERRUPTED is a completion + # signal regardless of audio state (reference lines 650-654) + if self._sc.on_content_end_text_interrupted(): + self.create_task(self._run_sc_handoff(), name="sc_handoff") elif content.type == ContentType.AUDIO: # Emit deferred TTSTextFrame after all audio chunks have been sent await self._report_tts_text() @@ -1179,9 +1389,13 @@ class AWSNovaSonicLLMService(LLMService): if content.text_stage == TextStage.FINAL: # User transcription text added await self._report_user_transcription_text_added(content.text_content) + # Session continuation: add to real-time history + self._sc.on_content_end_user_final_text(content.text_content) async def _handle_completion_end_event(self, _): - pass + # Session continuation: completionEnd is a fallback completion signal + if self._sc.on_completion_end(): + self.create_task(self._run_sc_handoff(), name="sc_handoff") # # assistant response reporting diff --git a/src/pipecat/services/aws/nova_sonic/session_continuation.py b/src/pipecat/services/aws/nova_sonic/session_continuation.py new file mode 100644 index 000000000..045a4723d --- /dev/null +++ b/src/pipecat/services/aws/nova_sonic/session_continuation.py @@ -0,0 +1,705 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Session continuation helper for AWS Nova Sonic. + +Nova Sonic sessions have an AWS-imposed time limit (~8 minutes). This module +provides transparent session continuation that rotates sessions in the background +before the limit is reached, preserving conversation context with no +user-perceptible interruption. + +Implementation follows the AWS reference architecture: +https://github.com/aws-samples/amazon-nova-samples/tree/main/speech-to-speech/amazon-nova-2-sonic/repeatable-patterns/session-continuation/console-python +""" + +import asyncio +import time +from collections import deque +from collections.abc import Callable, Coroutine +from dataclasses import dataclass +from typing import Any, Protocol, runtime_checkable + +from loguru import logger +from pydantic import BaseModel, Field + +_MAX_HISTORY_MESSAGE_BYTES = 50 * 1024 # 50 KB per message +_MAX_HISTORY_TOTAL_BYTES = 200 * 1024 # 200 KB total history + + +@runtime_checkable +class NovaSonicSessionSender(Protocol): + """Protocol for sending events to a Nova Sonic session stream. + + The LLM service implements this to expose the Nova Sonic wire protocol to + the session continuation helper without coupling the helper to service + internals (audio config, voice, model, etc.). Each method targets an + explicit ``stream`` / ``prompt_name`` so the same implementation can write + to either the current session or the next (pre-created) session. + """ + + async def open_stream(self, client: Any) -> Any: + """Open a bidirectional stream on the given client.""" + ... + + async def send_event(self, event_json: str, stream: Any) -> None: + """Send a raw event JSON string to the given stream.""" + ... + + def build_session_start_json(self) -> str: + """Build the ``sessionStart`` event JSON string.""" + ... + + async def send_prompt_start(self, tools: list, prompt_name: str, stream: Any) -> None: + """Send a ``promptStart`` event to the given stream.""" + ... + + async def send_text( + self, text: str, role: str, prompt_name: str, stream: Any, interactive: bool + ) -> None: + """Send a text content block (contentStart/textInput/contentEnd) to the given stream.""" + ... + + async def send_audio_input_start( + self, prompt_name: str, content_name: str, stream: Any + ) -> None: + """Send an audio input ``contentStart`` to the given stream.""" + ... + + async def send_audio( + self, audio: bytes, prompt_name: str, content_name: str, stream: Any + ) -> None: + """Send an ``audioInput`` event to the given stream.""" + ... + + def create_client(self) -> Any: + """Create a new Bedrock runtime client.""" + ... + + @property + def audio_config(self) -> Any: + """Return the audio configuration (``AudioConfig`` instance).""" + ... + + def get_setup_params(self) -> "tuple[str | None, list]": + """Return ``(system_instruction, tools)`` for the next session setup.""" + ... + + +class SessionContinuationParams(BaseModel): + """Configuration for automatic session continuation. + + Nova Sonic sessions have an AWS-imposed time limit (~8 minutes). When enabled, + session continuation proactively creates a new session in the background before + the limit is reached, buffers user audio during the transition, and seamlessly + hands off — preserving conversation context with no user-perceptible gap. + + Parameters: + enabled: Whether automatic session continuation is enabled. + transition_threshold_seconds: How many seconds into a session to begin + monitoring for a transition opportunity. The transition will occur + when the assistant next starts speaking after this threshold. + audio_buffer_duration_seconds: Duration of the rolling audio buffer + (in seconds) that captures user audio during the transition window. + This audio is replayed into the new session so no user input is lost. + audio_start_timeout_seconds: Maximum time to wait for the assistant to + start speaking after the threshold is reached. If no assistant audio + arrives within this window, the transition is forced. Set to 0 to + disable the timeout (wait indefinitely). + """ + + enabled: bool = Field(default=True) + transition_threshold_seconds: float = Field(default=360.0) + audio_buffer_duration_seconds: float = Field(default=3.0) + audio_start_timeout_seconds: float = Field(default=80.0) + + +@dataclass +class _NextSession: + """Holds pre-created resources for the next session during a transition.""" + + client: Any # BedrockRuntimeClient + stream: Any # DuplexEventStream + prompt_name: str + input_audio_content_name: str + + +class SessionContinuationHelper: + """Manages proactive session rotation for Nova Sonic. + + This helper encapsulates all session continuation state and logic, providing + a clean API for the LLM service to integrate with. It delegates stream I/O + back to the LLM service via callbacks. + + The LLM service hooks into this helper at key points: + - ``on_audio_input(audio)``: called for each user audio frame + - ``on_assistant_audio_started()``: called on AUDIO contentStart from assistant + - ``on_assistant_text_output(role, text, stage)``: called on textOutput events + - ``on_content_end(role, content_type, stop_reason, text_content, text_stage)``: + called on contentEnd events + - ``on_completion_end()``: called on completionEnd events + - ``on_user_content_started()``: called on USER contentStart events + """ + + def __init__( + self, + params: SessionContinuationParams, + *, + sender: NovaSonicSessionSender, + create_task: Callable[[Coroutine], asyncio.Task], + cancel_task: Callable[[asyncio.Task, float], Coroutine[Any, Any, None]], + ): + """Initialize the session continuation helper. + + Args: + params: Configuration for session continuation behavior. + sender: Object implementing the ``NovaSonicSessionSender`` protocol. + The LLM service provides this to expose Nova Sonic wire I/O + without coupling the helper to service internals. Audio + configuration is read from ``sender.audio_config``. + create_task: Callable to spawn a task managed by the service's task + manager (typically ``self.create_task`` from the LLM service). + cancel_task: Callable to cancel a task (typically + ``self.cancel_task`` from the LLM service). + """ + self._params = params + self._sender = sender + self._create_task = create_task + self._cancel_task = cancel_task + + # Audio buffer — sized from the sender's audio config + ac = sender.audio_config + self._audio_buffer: deque[bytes] = deque() + self._audio_buffer_max_bytes: int = int( + params.audio_buffer_duration_seconds + * ac.input_sample_rate + * (ac.input_sample_size / 8) + * ac.input_channel_count + ) + + # Transition state + self._next_session: _NextSession | None = None + self._is_buffering = False + self._waiting_for_audio_start = False + self._waiting_for_completion = False + self._handoff_in_progress = False + self._audio_start_wait_time: float | None = None + self._next_session_created_time: float | None = None + self._monitor_task: asyncio.Task | None = None + self._connected_time: float | None = None + self._wants_connection = False + + # Session-level text counters — always incremented, never gated. + # Matches reference: counts live on SessionInfo from session start. + self._speculative_text_count = 0 + self._final_text_count = 0 + self._barge_in_detected = False + + # Conversation history — tracked in real-time from FINAL text events. + # TODO: Integrate with pipecat's LLMContext once the pipeline supports + # synchronous context reads or a flush mechanism. + self._conversation_history: list[dict[str, str]] = [] + + # --- Public API for the LLM service --- + + @property + def is_buffering(self) -> bool: + """Whether user audio is currently being buffered for the transition.""" + return self._is_buffering + + @property + def next_session(self) -> _NextSession | None: + """The pre-created next session, if any.""" + return self._next_session + + @property + def handoff_in_progress(self) -> bool: + """Whether a handoff is currently in progress.""" + return self._handoff_in_progress + + def set_connected(self, connected_time: float): + """Called when the current session finishes connecting. + + Resets session-level counters. In the reference these live on + SessionInfo and are zero-initialized per session. + """ + self._connected_time = connected_time + self._wants_connection = True + self._speculative_text_count = 0 + self._final_text_count = 0 + self._barge_in_detected = False + + def set_disconnected(self): + """Called when the current session disconnects.""" + self._wants_connection = False + self._connected_time = None + + def seed_history(self, role: str, text: str): + """Seed conversation history with initial context messages.""" + if text: + self._conversation_history.append({"role": role, "text": text}) + + def start_monitor(self): + """Start the session duration monitor.""" + if not self._params.enabled or self._monitor_task: + return + self._monitor_task = self._create_task(self._monitor_loop()) + + async def stop_monitor(self): + """Stop the session duration monitor.""" + if self._monitor_task: + await self._cancel_task(self._monitor_task, 1.0) + self._monitor_task = None + + def on_audio_input(self, audio: bytes): + """Called for each user audio frame. Buffers audio during transition.""" + if self._is_buffering: + self._audio_buffer.append(audio) + total = sum(len(c) for c in self._audio_buffer) + while total > self._audio_buffer_max_bytes and self._audio_buffer: + removed = self._audio_buffer.popleft() + total -= len(removed) + + async def on_assistant_audio_started(self): + """Called when assistant AUDIO contentStart is detected. + + Starts buffering and creates the next session if we're past the threshold. + Returns True if session continuation was triggered. + """ + if not self._waiting_for_audio_start or self._handoff_in_progress: + return False + + self._waiting_for_audio_start = False + self._audio_start_wait_time = None + self._is_buffering = True + self._waiting_for_completion = True + + logger.info( + "Session continuation: assistant audio started, " + "buffering user audio and creating next session" + ) + + if not self._next_session: + try: + await self._prepare_next_session() + except Exception as e: + logger.error(f"Session continuation: failed to prepare next session: {e}") + self._is_buffering = False + self._waiting_for_completion = False + return False + + return True + + def on_text_output(self, role: str, stage: str | None): + """Called on textOutput events. Always tracks speculative/final counts. + + Matches reference: counts are session-level, always incremented for + ASSISTANT text regardless of transition state. The completion check + (in on_content_end_assistant_final_text) gates on _waiting_for_completion. + """ + if role != "ASSISTANT": + return + + if stage == "SPECULATIVE": + self._speculative_text_count += 1 + logger.debug(f"Session continuation: SPECULATIVE text #{self._speculative_text_count}") + elif stage == "FINAL": + self._final_text_count += 1 + logger.debug( + f"Session continuation: FINAL text #{self._final_text_count} " + f"(speculative={self._speculative_text_count})" + ) + + def on_content_end_assistant_final_text(self, text: str | None): + """Called on contentEnd for ASSISTANT FINAL TEXT (non-interrupted). + + Adds text to history and checks for completion signal. + Returns True if handoff should be triggered. + """ + if text: + self._conversation_history.append({"role": "ASSISTANT", "text": text}) + + # Check completion signal after adding to history + if ( + self._waiting_for_completion + and self._speculative_text_count > 0 + and self._final_text_count > 0 + and self._final_text_count >= self._speculative_text_count + and not self._handoff_in_progress + ): + logger.info( + f"Session continuation: completion signal — text pairs matched " + f"(final={self._final_text_count} >= speculative={self._speculative_text_count})" + ) + self._waiting_for_completion = False + return True + return False + + def on_content_end_text_interrupted(self) -> bool: + """Called on contentEnd for TEXT with stopReason=INTERRUPTED. + + Marks barge-in detected. If we're waiting for completion, triggers + handoff immediately (matches reference lines 650-654). + Returns True if handoff should be triggered. + """ + self._barge_in_detected = True + if self._waiting_for_completion and not self._handoff_in_progress: + logger.info("Session continuation: completion signal — TEXT INTERRUPTED (barge-in)") + self._waiting_for_completion = False + return True + return False + + def on_content_end_user_final_text(self, text: str | None): + """Called on contentEnd for USER FINAL TEXT. Adds to history. + + Also handles barge-in count reconciliation: when the user speaks after + a barge-in, remaining FINAL texts for the interrupted response will + never arrive. Force final = speculative so the counts match. + Matches reference lines 602-609. + """ + if text: + self._conversation_history.append({"role": "USER", "text": text}) + + if self._barge_in_detected and self._speculative_text_count > self._final_text_count: + logger.info( + f"Session continuation: user spoke after barge-in — " + f"setting final={self._speculative_text_count} (was {self._final_text_count})" + ) + self._final_text_count = self._speculative_text_count + + def on_user_content_started(self) -> bool: + """Called on USER contentStart during transition. + + Marks barge-in during transition (matches reference lines 527-534). + Returns True if handoff should be triggered (forced transition, no + assistant response yet — matches reference lines 579-583). + """ + if self._waiting_for_completion and self._next_session: + self._barge_in_detected = True + + if ( + self._waiting_for_completion + and not self._handoff_in_progress + and self._next_session + and self._final_text_count == 0 + ): + logger.info( + "Session continuation: user spoke during forced transition " + "(no assistant response yet) — completing handoff immediately" + ) + self._waiting_for_completion = False + return True + return False + + def on_completion_end(self) -> bool: + """Called on completionEnd. Fallback completion signal. + + Returns True if handoff should be triggered. + """ + if self._waiting_for_completion and not self._handoff_in_progress: + logger.info("Session continuation: completion signal — completionEnd (fallback)") + self._waiting_for_completion = False + return True + return False + + async def execute_handoff(self) -> _NextSession | None: + """Execute the session handoff. + + Sends conversation history + audioInputStart + buffered audio to the next + session. Returns (old_client, old_stream, old_receive_task, old_prompt_name) + for the caller to swap and clean up, or None if handoff couldn't proceed. + """ + if self._handoff_in_progress: + return None + self._handoff_in_progress = True + + try: + ns = self._next_session + if not ns: + logger.warning("Session continuation: no next session available for handoff") + return None + + logger.info("Session continuation: executing handoff") + + # Build trimmed history: walk backwards to prioritize recent + # messages, truncate individual messages, and cap total size. + prepared: list[dict[str, str]] = [] + total_bytes = 0 + for msg in reversed(self._conversation_history): + text = msg["text"] + encoded = text.encode("utf-8") + if len(encoded) > _MAX_HISTORY_MESSAGE_BYTES: + encoded = encoded[:_MAX_HISTORY_MESSAGE_BYTES] + text = encoded.decode("utf-8", errors="ignore") + encoded = text.encode("utf-8") + msg_bytes = len(encoded) + if total_bytes + msg_bytes > _MAX_HISTORY_TOTAL_BYTES: + logger.debug( + f"Session continuation: dropping older history to fit " + f"{_MAX_HISTORY_TOTAL_BYTES} byte limit " + f"(total_bytes={total_bytes}, msg_bytes={msg_bytes})" + ) + break + total_bytes += msg_bytes + prepared.append({"role": msg["role"], "text": text}) + prepared.reverse() + + # Ensure history starts with a USER message + while prepared and prepared[0]["role"] != "USER": + dropped = prepared.pop(0) + logger.debug( + f"Session continuation: dropping leading {dropped['role']} message from history" + ) + + # Send conversation history + if prepared: + logger.info( + f"Session continuation: sending {len(prepared)} history " + f"messages ({total_bytes} bytes) to new session" + ) + for msg in prepared: + logger.debug( + f"Session continuation: history [{msg['role']}]: " + f"{msg['text'][:80]}{'...' if len(msg['text']) > 80 else ''}" + ) + await self._sender.send_text( + msg["text"], msg["role"], ns.prompt_name, ns.stream, False + ) + + # Send audioInputStart + await self._sender.send_audio_input_start( + ns.prompt_name, ns.input_audio_content_name, ns.stream + ) + + # Send buffered audio + buffer_chunks = list(self._audio_buffer) + ac = self._sender.audio_config + bytes_per_second = ( + ac.input_sample_rate * (ac.input_sample_size / 8) * ac.input_channel_count + ) + buffer_duration = sum(len(c) for c in buffer_chunks) / bytes_per_second + logger.info( + f"Session continuation: sending {len(buffer_chunks)} buffered audio chunks " + f"(~{buffer_duration:.1f}s) to new session" + ) + for chunk in buffer_chunks: + await self._sender.send_audio( + chunk, ns.prompt_name, ns.input_audio_content_name, ns.stream + ) + + # Return the next session info for the caller to promote + logger.info("Session continuation: promoting new session") + result = ns + self._next_session = None + self._is_buffering = False + self._audio_buffer.clear() + + return result + + except Exception as e: + logger.error(f"Session continuation: handoff error: {e}") + await self.cleanup_next_session() + return None + finally: + self._handoff_in_progress = False + self._waiting_for_audio_start = False + self._waiting_for_completion = False + self._audio_start_wait_time = None + self._next_session_created_time = None + # Note: speculative/final counts and barge_in_detected are NOT + # reset here — they are session-level and get reset in + # set_connected() when the new session starts. + + async def close_old_session( + self, client, stream, receive_task, prompt_name, input_audio_content_name=None + ): + """Close the old session's resources in the background. + + Audio input to the old stream is already stopped (handoff_in_progress + gate in _handle_input_audio_frame). Sends contentEnd (audio) → + promptEnd → sessionEnd → closes stream → cancels receive task. + The receive task is cancelled last as a safety net to avoid leaks; + by that point the stream is closed so the CRT future should already + be resolved. + """ + try: + if stream and prompt_name: + try: + import json + + if input_audio_content_name: + audio_content_end_json = json.dumps( + { + "event": { + "contentEnd": { + "promptName": prompt_name, + "contentName": input_audio_content_name, + } + } + } + ) + await self._sender.send_event(audio_content_end_json, stream) + + prompt_end_json = json.dumps( + {"event": {"promptEnd": {"promptName": prompt_name}}} + ) + session_end_json = json.dumps({"event": {"sessionEnd": {}}}) + await self._sender.send_event(prompt_end_json, stream) + await self._sender.send_event(session_end_json, stream) + except Exception: + pass + + if stream: + try: + await asyncio.wait_for(stream.input_stream.close(), timeout=2.0) + except (TimeoutError, Exception): + pass + + # Wait for the receive task to exit naturally (the stream is + # closed, so it will hit an error or the stale-stream check). + # Do NOT cancel — that cancels the underlying CRT future and + # races with native set_result() callbacks, producing an + # InvalidStateError traceback we can't catch from Python. + if receive_task: + try: + await asyncio.wait_for(asyncio.shield(receive_task), timeout=5.0) + except (TimeoutError, Exception): + pass + + logger.debug("Session continuation: old session closed") + except Exception as e: + logger.warning(f"Session continuation: error closing old session: {e}") + + async def cleanup_next_session(self): + """Clean up the pre-created next session if it exists.""" + ns = self._next_session + if not ns: + return + + if ns.stream: + try: + await ns.stream.close() + except Exception: + pass + + self._next_session = None + self._is_buffering = False + self._audio_buffer.clear() + self._next_session_created_time = None + + # --- Internal methods --- + + async def _monitor_loop(self): + """Periodically check session age and manage next session lifecycle.""" + try: + while self._wants_connection: + await asyncio.sleep(1) + + if not self._connected_time or self._handoff_in_progress: + continue + + session_age = time.time() - self._connected_time + threshold = self._params.transition_threshold_seconds + + # Threshold reached — start waiting for assistant audio + if ( + session_age >= threshold + and not self._waiting_for_audio_start + and not self._next_session + and not self._waiting_for_completion + ): + logger.info( + f"Session continuation: session age {session_age:.0f}s >= " + f"threshold {threshold:.0f}s, waiting for assistant audio" + ) + self._waiting_for_audio_start = True + self._audio_start_wait_time = time.time() + + # Audio start timeout — force transition + audio_start_timeout = self._params.audio_start_timeout_seconds + if ( + self._waiting_for_audio_start + and self._audio_start_wait_time + and audio_start_timeout > 0 + and (time.time() - self._audio_start_wait_time) > audio_start_timeout + ): + logger.info( + f"Session continuation: TIMEOUT — no assistant audio after " + f"{audio_start_timeout:.0f}s, forcing transition" + ) + self._waiting_for_audio_start = False + self._audio_start_wait_time = None + self._is_buffering = True + self._waiting_for_completion = False + try: + if not self._next_session: + await self._prepare_next_session() + self._create_task(self.execute_handoff()) + except Exception as e: + logger.error(f"Session continuation: forced transition failed: {e}") + self._is_buffering = False + + # Dead session detection — recreate if idle too long + next_session_timeout = 30.0 + if ( + self._next_session + and self._next_session_created_time + and not self._handoff_in_progress + and (time.time() - self._next_session_created_time) > next_session_timeout + ): + logger.warning( + f"Session continuation: next session idle for " + f">{next_session_timeout:.0f}s, recreating" + ) + await self.cleanup_next_session() + try: + await self._prepare_next_session() + except Exception as e: + logger.error(f"Session continuation: failed to recreate next session: {e}") + + except asyncio.CancelledError: + pass + except Exception as e: + logger.error(f"Session continuation monitor error: {e}") + + async def _prepare_next_session(self): + """Create a new session and send setup (sessionStart + promptStart + system instruction). + + Conversation history and audio are deferred to handoff time. + """ + import uuid + + prompt_name = str(uuid.uuid4()) + input_audio_content_name = str(uuid.uuid4()) + + client = self._sender.create_client() + stream = await self._sender.open_stream(client) + + self._next_session = _NextSession( + client=client, + stream=stream, + prompt_name=prompt_name, + input_audio_content_name=input_audio_content_name, + ) + self._next_session_created_time = time.time() + + ns = self._next_session + + # Send sessionStart + await self._sender.send_event(self._sender.build_session_start_json(), ns.stream) + + # Get setup params: (system_instruction, tools) + system_instruction, tools = self._sender.get_setup_params() + + # Send promptStart with tools + await self._sender.send_prompt_start(tools, ns.prompt_name, ns.stream) + + # Send system instruction + if system_instruction: + await self._sender.send_text( + system_instruction, "SYSTEM", ns.prompt_name, ns.stream, False + ) + + logger.debug(f"Session continuation: next session prepared (prompt={prompt_name[:8]}...)")