diff --git a/src/pipecat/services/aws_nova_sonic/aws.py b/src/pipecat/services/aws_nova_sonic/aws.py index 2e875f96f..a2437b9dd 100644 --- a/src/pipecat/services/aws_nova_sonic/aws.py +++ b/src/pipecat/services/aws_nova_sonic/aws.py @@ -1,6 +1,7 @@ import base64 import json import uuid +from dataclasses import dataclass from enum import Enum from aws_sdk_bedrock_runtime.client import ( @@ -25,6 +26,9 @@ from pipecat.frames.frames import ( EndFrame, Frame, InputAudioRawFrame, + LLMFullResponseEndFrame, + LLMFullResponseStartFrame, + LLMTextFrame, StartFrame, TTSAudioRawFrame, TTSStartedFrame, @@ -37,6 +41,36 @@ from pipecat.services.llm_service import LLMService class Role(Enum): SYSTEM = "SYSTEM" USER = "USER" + ASSISTANT = "ASSISTANT" + TOOL = "TOOL" + + +class ContentType(Enum): + AUDIO = "AUDIO" + TEXT = "TEXT" + TOOL = "TOOL" + + +class TextStage(Enum): + FINAL = "FINAL" # what has been said + SPECULATIVE = "SPECULATIVE" # what's planned to be said + + +@dataclass +class CurrentContent: + type: ContentType + role: Role + text_stage: TextStage # None if not text + text_content: str # starts as None, then fills in if text + + def __str__(self): + return ( + f"CurrentContent(\n" + f" type={self.type.name},\n" + f" role={self.role.name},\n" + f" text_stage={self.text_stage.name if self.text_stage else 'None'}\n" + f")" + ) class AWSNovaSonicService(LLMService): @@ -65,7 +99,8 @@ class AWSNovaSonicService(LLMService): self._receive_task = None self._prompt_name = str(uuid.uuid4()) self._input_audio_content_name = str(uuid.uuid4()) - self._audio_response_ongoing = False + self._content_being_received = None # TODO: clean this up on error or when finished + self._assistant_is_responding = False # # standard AIService frame handling @@ -314,7 +349,7 @@ class AWSNovaSonicService(LLMService): if "event" in json_data: event_json = json_data["event"] if "completionStart" in event_json: - # Handle the LLM response starting + # Handle the LLM completion starting await self._handle_completion_start_event(event_json) elif "contentStart" in event_json: # Handle a piece of content starting @@ -329,7 +364,7 @@ class AWSNovaSonicService(LLMService): # Handle a piece of content ending await self._handle_content_end_event(event_json) elif "completionStart" in event_json: - # Handle the LLM response ending + # Handle the LLM completion ending await self._handle_completion_end_event(event_json) except Exception as e: @@ -347,24 +382,35 @@ class AWSNovaSonicService(LLMService): if "additionalModelFields" in content_start: additional_model_fields = json.loads(content_start["additionalModelFields"]) generation_stage = additional_model_fields.get("generationStage") - # print( - # f"[pk] content start. type: {type}, role: {role}, generation_stage: {generation_stage}" - # ) + + # Bookkeeping: track current content being received + content = CurrentContent( + type=ContentType(type), + role=Role(role), + text_stage=TextStage(generation_stage) if generation_stage else None, + text_content=None + ) + self._content_being_received = content + + if content.role == Role.ASSISTANT: + if content.type == ContentType.AUDIO: + # Report that *equivalent* of TTS (this is a speech-to-speech model) started + # print("[pk] TTS started") + await self.push_frame(TTSStartedFrame()) + + print(f"[pk] content start: {self._content_being_received}") async def _handle_text_output_event(self, event_json): text_content = event_json["textOutput"]["content"] - # print(f"[pk] text output. content: {text_content}") + print(f"[pk] text output. content: {text_content}") + + # Bookkeeping: augment the current content being received with text + content = self._content_being_received + content.text_content = text_content async def _handle_audio_output_event(self, event_json): audio_content = event_json["audioOutput"]["content"] - print(f"[pk] audio output. content: {len(audio_content)}") - - # Report that *equivalent* of TTS (this is a speech-to-speech model) started - if not self._audio_response_ongoing: - self._audio_response_ongoing = True - # print("[pk] starting TTS") - await self.push_frame(TTSStartedFrame()) - + # print(f"[pk] audio output. content: {len(audio_content)}") # Push audio frame audio = base64.b64decode(audio_content) # TODO: make sample rate + channels (used in multiple places) consts @@ -377,15 +423,49 @@ class AWSNovaSonicService(LLMService): async def _handle_content_end_event(self, event_json): content_end = event_json["contentEnd"] - type = content_end["type"] stop_reason = content_end["stopReason"] - # print(f"[pk] content end. type: {type}, stop_reason: {stop_reason}") + # print( + # f"[pk] content end: {self._content_being_received}.\n" + # f" stop_reason: {stop_reason}" + # ) - # Report that *equivalent* of TTS (this is a speech-to-speech model) stopped - if type == "AUDIO" and self._audio_response_ongoing: - print("[pk] stopping TTS") - self._audio_response_ongoing = False - await self.push_frame(TTSStoppedFrame()) + # Bookkeeping: clear current content being received + content = self._content_being_received + self._content_being_received = None + + if content and content.role == Role.ASSISTANT: + if content.type == ContentType.AUDIO: + # We got to the end of a chunk of the assistant's audio. + # Report that *equivalent* of TTS (this is a speech-to-speech model) stopped. + # print("[pk] TTS stopped") + await self.push_frame(TTSStoppedFrame()) + elif content.type == ContentType.TEXT: + # Ignore non-final text, and the "interrupted" message (which isn't meaningful text) + if content.text_stage == TextStage.FINAL and stop_reason != "INTERRUPTED": + # TODO: the way we're tracking the start and stop of the assistant response here + # is rather busted, and results in way too many "responses" being put into the + # context (every final text content block is treated as its own response). + # We *should* only record that an assistant response has ended when: + # - the assistant truly finished its turn (stop_reason is END_TURN) + # - when this is the next text content block after an INTERRUPTED has occurred + # BUT it seems like there's a bug where, if there are multiple assistant text + # content blocks, the *first* one gets marked END_TURN rather than the last. + print("[pk] LLM full response started") + self._assistant_is_responding = True + await self.push_frame(LLMFullResponseStartFrame()) + + if self._assistant_is_responding: + # Add text to the ongoing reported assistant response + print(f"[pk] LLM text: {content.text_content}") + await self.push_frame(LLMTextFrame(content.text_content)) + + # Report that the assistant has finished their response. + # TODO: kinda busted. see TODO comment above. + print("[pk] LLM full response ended") + await self.push_frame(LLMFullResponseEndFrame()) + self._assistant_is_responding = False + + self._content_being_received = False async def _handle_completion_end_event(self, event_json): # print("[pk] completion end")