[WIP] AWS Nova Sonic service - added LLMFullResponseStartFrame, LLMTextFrame, and LLMFullResponseEndFrame

This commit is contained in:
Paul Kompfner
2025-04-25 15:12:37 -04:00
parent e40aa4f99a
commit de294caed9

View File

@@ -1,6 +1,7 @@
import base64
import json
import uuid
from dataclasses import dataclass
from enum import Enum
from typing import Optional
from aws_sdk_bedrock_runtime.client import (
@@ -25,6 +26,9 @@ from pipecat.frames.frames import (
EndFrame,
Frame,
InputAudioRawFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMTextFrame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
@@ -37,6 +41,36 @@ from pipecat.services.llm_service import LLMService
class Role(Enum):
    """Role attached to a piece of content, as named in the service's events."""

    SYSTEM = "SYSTEM"
    USER = "USER"
    ASSISTANT = "ASSISTANT"
    TOOL = "TOOL"
class ContentType(Enum):
    """Kind of payload carried by a piece of content."""

    AUDIO = "AUDIO"
    TEXT = "TEXT"
    TOOL = "TOOL"
class TextStage(Enum):
    """Generation stage reported for a piece of text content."""

    FINAL = "FINAL"  # what has been said
    SPECULATIVE = "SPECULATIVE"  # what's planned to be said
@dataclass
class CurrentContent:
    """Bookkeeping record describing the piece of content currently being received.

    ``type`` and ``role`` are always required. The two text-related fields are
    optional and default to ``None`` so that non-text content can be described
    without passing them.
    """

    type: ContentType
    role: Role
    text_stage: Optional[TextStage] = None  # None if not text
    text_content: Optional[str] = None  # starts as None, then fills in if text

    def __str__(self):
        """Return a multi-line summary of type, role and text stage (not the text itself)."""
        return (
            f"CurrentContent(\n"
            f" type={self.type.name},\n"
            f" role={self.role.name},\n"
            f" text_stage={self.text_stage.name if self.text_stage else 'None'}\n"
            f")"
        )
class AWSNovaSonicService(LLMService):
@@ -65,7 +99,8 @@ class AWSNovaSonicService(LLMService):
self._receive_task = None
self._prompt_name = str(uuid.uuid4())
self._input_audio_content_name = str(uuid.uuid4())
self._audio_response_ongoing = False
self._content_being_received = None # TODO: clean this up on error or when finished
self._assistant_is_responding = False
#
# standard AIService frame handling
@@ -314,7 +349,7 @@ class AWSNovaSonicService(LLMService):
if "event" in json_data:
event_json = json_data["event"]
if "completionStart" in event_json:
# Handle the LLM response starting
# Handle the LLM completion starting
await self._handle_completion_start_event(event_json)
elif "contentStart" in event_json:
# Handle a piece of content starting
@@ -329,7 +364,7 @@ class AWSNovaSonicService(LLMService):
# Handle a piece of content ending
await self._handle_content_end_event(event_json)
elif "completionStart" in event_json:
# Handle the LLM response ending
# Handle the LLM completion ending
await self._handle_completion_end_event(event_json)
except Exception as e:
@@ -347,24 +382,35 @@ class AWSNovaSonicService(LLMService):
if "additionalModelFields" in content_start:
additional_model_fields = json.loads(content_start["additionalModelFields"])
generation_stage = additional_model_fields.get("generationStage")
# print(
# f"[pk] content start. type: {type}, role: {role}, generation_stage: {generation_stage}"
# )
# Bookkeeping: track current content being received
content = CurrentContent(
type=ContentType(type),
role=Role(role),
text_stage=TextStage(generation_stage) if generation_stage else None,
text_content=None
)
self._content_being_received = content
if content.role == Role.ASSISTANT:
if content.type == ContentType.AUDIO:
# Report that *equivalent* of TTS (this is a speech-to-speech model) started
# print("[pk] TTS started")
await self.push_frame(TTSStartedFrame())
print(f"[pk] content start: {self._content_being_received}")
async def _handle_text_output_event(self, event_json):
    """Attach the text of a ``textOutput`` event to the content currently being received.

    Args:
        event_json: Event payload; the text is read from
            ``event_json["textOutput"]["content"]``.

    Raises:
        KeyError: If the payload lacks the ``textOutput`` / ``content`` keys.
    """
    text_content = event_json["textOutput"]["content"]
    print(f"[pk] text output. content: {text_content}")

    # Bookkeeping: augment the current content being received with text.
    content = self._content_being_received
    if not content:
        # The tracker is reset to a falsy value once a piece of content ends, so
        # text arriving with nothing tracked has no record to attach to. Skip it
        # rather than fail with an AttributeError on the assignment below.
        return
    content.text_content = text_content
async def _handle_audio_output_event(self, event_json):
audio_content = event_json["audioOutput"]["content"]
print(f"[pk] audio output. content: {len(audio_content)}")
# Report that *equivalent* of TTS (this is a speech-to-speech model) started
if not self._audio_response_ongoing:
self._audio_response_ongoing = True
# print("[pk] starting TTS")
await self.push_frame(TTSStartedFrame())
# print(f"[pk] audio output. content: {len(audio_content)}")
# Push audio frame
audio = base64.b64decode(audio_content)
# TODO: make sample rate + channels (used in multiple places) consts
@@ -377,15 +423,49 @@ class AWSNovaSonicService(LLMService):
async def _handle_content_end_event(self, event_json):
# Handle a "contentEnd" event: read its type and stop reason, clear the tracked
# in-flight content, and push the frames that report the end of assistant audio
# (TTSStoppedFrame) or of final assistant text (LLMFullResponseStartFrame,
# LLMTextFrame, LLMFullResponseEndFrame).
# NOTE(review): this span looks like a flattened diff rendering (indentation lost,
# removed and added lines interleaved) — confirm the nesting and which lines
# actually survive against the real file before relying on it.
content_end = event_json["contentEnd"]
# NOTE(review): `type` shadows the builtin of the same name inside this method.
type = content_end["type"]
stop_reason = content_end["stopReason"]
# print(f"[pk] content end. type: {type}, stop_reason: {stop_reason}")
# print(
# f"[pk] content end: {self._content_being_received}.\n"
# f" stop_reason: {stop_reason}"
# )
# Report that *equivalent* of TTS (this is a speech-to-speech model) stopped
# NOTE(review): this `_audio_response_ongoing` path and the role/type-based path
# further down both push TTSStoppedFrame — presumably this one is the pre-change
# implementation; confirm that only one of them remains.
if type == "AUDIO" and self._audio_response_ongoing:
print("[pk] stopping TTS")
self._audio_response_ongoing = False
await self.push_frame(TTSStoppedFrame())
# Bookkeeping: clear current content being received
content = self._content_being_received
self._content_being_received = None
if content and content.role == Role.ASSISTANT:
if content.type == ContentType.AUDIO:
# We got to the end of a chunk of the assistant's audio.
# Report that *equivalent* of TTS (this is a speech-to-speech model) stopped.
# print("[pk] TTS stopped")
await self.push_frame(TTSStoppedFrame())
elif content.type == ContentType.TEXT:
# Ignore non-final text, and the "interrupted" message (which isn't meaningful text)
if content.text_stage == TextStage.FINAL and stop_reason != "INTERRUPTED":
# TODO: the way we're tracking the start and stop of the assistant response here
# is rather busted, and results in way too many "responses" being put into the
# context (every final text content block is treated as its own response).
# We *should* only record that an assistant response has ended when:
# - the assistant truly finished its turn (stop_reason is END_TURN)
# - when this is the next text content block after an INTERRUPTED has occurred
# BUT it seems like there's a bug where, if there are multiple assistant text
# content blocks, the *first* one gets marked END_TURN rather than the last.
print("[pk] LLM full response started")
self._assistant_is_responding = True
await self.push_frame(LLMFullResponseStartFrame())
# NOTE(review): as written here the flag is set to True just before it is
# tested, so this check would always pass — confirm whether a guard around
# the start frame was lost when the text was flattened.
if self._assistant_is_responding:
# Add text to the ongoing reported assistant response
print(f"[pk] LLM text: {content.text_content}")
await self.push_frame(LLMTextFrame(content.text_content))
# Report that the assistant has finished their response.
# TODO: kinda busted. see TODO comment above.
print("[pk] LLM full response ended")
await self.push_frame(LLMFullResponseEndFrame())
self._assistant_is_responding = False
# NOTE(review): the tracker was cleared to None earlier in this method but is set
# to False here — confirm which value is the intended "no content" sentinel.
self._content_being_received = False
async def _handle_completion_end_event(self, event_json):
# print("[pk] completion end")