[WIP] AWS Nova Sonic service - added LLMFullResponseStartFrame, LLMTextFrame, and LLMFullResponseEndFrame
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from aws_sdk_bedrock_runtime.client import (
|
||||
@@ -25,6 +26,9 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMTextFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
@@ -37,6 +41,36 @@ from pipecat.services.llm_service import LLMService
|
||||
class Role(Enum):
|
||||
SYSTEM = "SYSTEM"
|
||||
USER = "USER"
|
||||
ASSISTANT = "ASSISTANT"
|
||||
TOOL = "TOOL"
|
||||
|
||||
|
||||
class ContentType(Enum):
|
||||
AUDIO = "AUDIO"
|
||||
TEXT = "TEXT"
|
||||
TOOL = "TOOL"
|
||||
|
||||
|
||||
class TextStage(Enum):
|
||||
FINAL = "FINAL" # what has been said
|
||||
SPECULATIVE = "SPECULATIVE" # what's planned to be said
|
||||
|
||||
|
||||
@dataclass
|
||||
class CurrentContent:
|
||||
type: ContentType
|
||||
role: Role
|
||||
text_stage: TextStage # None if not text
|
||||
text_content: str # starts as None, then fills in if text
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"CurrentContent(\n"
|
||||
f" type={self.type.name},\n"
|
||||
f" role={self.role.name},\n"
|
||||
f" text_stage={self.text_stage.name if self.text_stage else 'None'}\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
|
||||
class AWSNovaSonicService(LLMService):
|
||||
@@ -65,7 +99,8 @@ class AWSNovaSonicService(LLMService):
|
||||
self._receive_task = None
|
||||
self._prompt_name = str(uuid.uuid4())
|
||||
self._input_audio_content_name = str(uuid.uuid4())
|
||||
self._audio_response_ongoing = False
|
||||
self._content_being_received = None # TODO: clean this up on error or when finished
|
||||
self._assistant_is_responding = False
|
||||
|
||||
#
|
||||
# standard AIService frame handling
|
||||
@@ -314,7 +349,7 @@ class AWSNovaSonicService(LLMService):
|
||||
if "event" in json_data:
|
||||
event_json = json_data["event"]
|
||||
if "completionStart" in event_json:
|
||||
# Handle the LLM response starting
|
||||
# Handle the LLM completion starting
|
||||
await self._handle_completion_start_event(event_json)
|
||||
elif "contentStart" in event_json:
|
||||
# Handle a piece of content starting
|
||||
@@ -329,7 +364,7 @@ class AWSNovaSonicService(LLMService):
|
||||
# Handle a piece of content ending
|
||||
await self._handle_content_end_event(event_json)
|
||||
elif "completionStart" in event_json:
|
||||
# Handle the LLM response ending
|
||||
# Handle the LLM completion ending
|
||||
await self._handle_completion_end_event(event_json)
|
||||
|
||||
except Exception as e:
|
||||
@@ -347,24 +382,35 @@ class AWSNovaSonicService(LLMService):
|
||||
if "additionalModelFields" in content_start:
|
||||
additional_model_fields = json.loads(content_start["additionalModelFields"])
|
||||
generation_stage = additional_model_fields.get("generationStage")
|
||||
# print(
|
||||
# f"[pk] content start. type: {type}, role: {role}, generation_stage: {generation_stage}"
|
||||
# )
|
||||
|
||||
# Bookkeeping: track current content being received
|
||||
content = CurrentContent(
|
||||
type=ContentType(type),
|
||||
role=Role(role),
|
||||
text_stage=TextStage(generation_stage) if generation_stage else None,
|
||||
text_content=None
|
||||
)
|
||||
self._content_being_received = content
|
||||
|
||||
if content.role == Role.ASSISTANT:
|
||||
if content.type == ContentType.AUDIO:
|
||||
# Report that *equivalent* of TTS (this is a speech-to-speech model) started
|
||||
# print("[pk] TTS started")
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
|
||||
print(f"[pk] content start: {self._content_being_received}")
|
||||
|
||||
async def _handle_text_output_event(self, event_json):
|
||||
text_content = event_json["textOutput"]["content"]
|
||||
# print(f"[pk] text output. content: {text_content}")
|
||||
print(f"[pk] text output. content: {text_content}")
|
||||
|
||||
# Bookkeeping: augment the current content being received with text
|
||||
content = self._content_being_received
|
||||
content.text_content = text_content
|
||||
|
||||
async def _handle_audio_output_event(self, event_json):
|
||||
audio_content = event_json["audioOutput"]["content"]
|
||||
print(f"[pk] audio output. content: {len(audio_content)}")
|
||||
|
||||
# Report that *equivalent* of TTS (this is a speech-to-speech model) started
|
||||
if not self._audio_response_ongoing:
|
||||
self._audio_response_ongoing = True
|
||||
# print("[pk] starting TTS")
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
|
||||
# print(f"[pk] audio output. content: {len(audio_content)}")
|
||||
# Push audio frame
|
||||
audio = base64.b64decode(audio_content)
|
||||
# TODO: make sample rate + channels (used in multiple places) consts
|
||||
@@ -377,15 +423,49 @@ class AWSNovaSonicService(LLMService):
|
||||
|
||||
async def _handle_content_end_event(self, event_json):
|
||||
content_end = event_json["contentEnd"]
|
||||
type = content_end["type"]
|
||||
stop_reason = content_end["stopReason"]
|
||||
# print(f"[pk] content end. type: {type}, stop_reason: {stop_reason}")
|
||||
# print(
|
||||
# f"[pk] content end: {self._content_being_received}.\n"
|
||||
# f" stop_reason: {stop_reason}"
|
||||
# )
|
||||
|
||||
# Report that *equivalent* of TTS (this is a speech-to-speech model) stopped
|
||||
if type == "AUDIO" and self._audio_response_ongoing:
|
||||
print("[pk] stopping TTS")
|
||||
self._audio_response_ongoing = False
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
# Bookkeeping: clear current content being received
|
||||
content = self._content_being_received
|
||||
self._content_being_received = None
|
||||
|
||||
if content and content.role == Role.ASSISTANT:
|
||||
if content.type == ContentType.AUDIO:
|
||||
# We got to the end of a chunk of the assistant's audio.
|
||||
# Report that *equivalent* of TTS (this is a speech-to-speech model) stopped.
|
||||
# print("[pk] TTS stopped")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
elif content.type == ContentType.TEXT:
|
||||
# Ignore non-final text, and the "interrupted" message (which isn't meaningful text)
|
||||
if content.text_stage == TextStage.FINAL and stop_reason != "INTERRUPTED":
|
||||
# TODO: the way we're tracking the start and stop of the assistant response here
|
||||
# is rather busted, and results in way too many "responses" being put into the
|
||||
# context (every final text content block is treated as its own response).
|
||||
# We *should* only record that an assistant response has ended when:
|
||||
# - the assistant truly finished its turn (stop_reason is END_TURN)
|
||||
# - when this is the next text content block after an INTERRUPTED has occurred
|
||||
# BUT it seems like there's a bug where, if there are multiple assistant text
|
||||
# content blocks, the *first* one gets marked END_TURN rather than the last.
|
||||
print("[pk] LLM full response started")
|
||||
self._assistant_is_responding = True
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
|
||||
if self._assistant_is_responding:
|
||||
# Add text to the ongoing reported assistant response
|
||||
print(f"[pk] LLM text: {content.text_content}")
|
||||
await self.push_frame(LLMTextFrame(content.text_content))
|
||||
|
||||
# Report that the assistant has finished their response.
|
||||
# TODO: kinda busted. see TODO comment above.
|
||||
print("[pk] LLM full response ended")
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
self._assistant_is_responding = False
|
||||
|
||||
self._content_being_received = False
|
||||
|
||||
async def _handle_completion_end_event(self, event_json):
|
||||
# print("[pk] completion end")
|
||||
|
||||
Reference in New Issue
Block a user