From ee53535f41485bc9399e21eaefe5f77cd3969f9a Mon Sep 17 00:00:00 2001 From: Kwindla Hultman Kramer Date: Fri, 8 Nov 2024 08:28:54 -0800 Subject: [PATCH] gemini audio-in with no transcription --- .../07p-interruptible-google-audio-in.py | 276 ++++++++++++++++++ .../aggregators/openai_llm_context.py | 28 ++ src/pipecat/services/google.py | 142 +++++---- 3 files changed, 387 insertions(+), 59 deletions(-) create mode 100644 examples/foundational/07p-interruptible-google-audio-in.py diff --git a/examples/foundational/07p-interruptible-google-audio-in.py b/examples/foundational/07p-interruptible-google-audio-in.py new file mode 100644 index 000000000..33bb30187 --- /dev/null +++ b/examples/foundational/07p-interruptible-google-audio-in.py @@ -0,0 +1,276 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import aiohttp +import asyncio +import os +import sys + +import google.ai.generativelanguage as glm + +from dataclasses import dataclass +from dotenv import load_dotenv +from loguru import logger +from runner import configure + +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.services.cartesia import CartesiaTTSService +from pipecat.services.google import GoogleLLMService +from pipecat.processors.frame_processor import FrameProcessor +from pipecat.transports.services.daily import DailyParams, DailyTransport +from pipecat.frames.frames import ( + LLMFullResponseStartFrame, + LLMFullResponseEndFrame, + InputAudioRawFrame, + Frame, + StartInterruptionFrame, + TextFrame, + TranscriptionFrame, + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, +) + +load_dotenv(override=True) + +logger.remove(0) +logger.add(sys.stderr, level="DEBUG") + +marker = "|----|" +system_message = f""" +You are a helpful LLM in a WebRTC call. Your goals are to be helpful and brief in your responses. + +You are expert at transcribing audio to text. You will receive a mixture of audio and text input. When +asked to transcribe what the user said, output an exact, word-for-word transcription. + +Your output will be converted to audio so don't include special characters in your answers. + +Each time you answer, you should respond in three parts. + +1. Transcribe exactly what the user said. +2. Output the separator field '{marker}'. +3. Respond to the user's input in a succinct, helpful, creative way using only simple text and punctuation. + +Example: + +User: How many ounces are in a pound? + +You: How many ounces are in a pound? +{marker} +There are 16 ounces in a pound. +""" + + +@dataclass +class MagicDemoTranscriptionFrame(Frame): + text: str + + +class UserAudioCollector(FrameProcessor): + def __init__(self, context, user_context_aggregator): + super().__init__() + self._context = context + self._user_context_aggregator = user_context_aggregator + self._audio_frames = [] + self._start_secs = ( + 0.2 # this should match VAD_START_SECS but we'll just hardcode it for now + ) + self._user_speaking = False + + async def process_frame(self, frame, direction): + await super().process_frame(frame, direction) + + if isinstance(frame, TranscriptionFrame): + # We could gracefully handle both audio input and text/transcription input ... + # but let's leave that as an exercise to the reader. :-) + return + if isinstance(frame, UserStartedSpeakingFrame): + self._user_speaking = True + elif isinstance(frame, UserStoppedSpeakingFrame): + self._user_speaking = False + self._context.add_audio_frames_message(audio_frames=self._audio_frames) + await self._user_context_aggregator.push_frame( + self._user_context_aggregator.get_context_frame() + ) + elif isinstance(frame, InputAudioRawFrame): + if self._user_speaking: + self._audio_frames.append(frame) + else: + # Append the audio frame to our buffer. Treat the buffer as a ring buffer, dropping the oldest + # frames as necessary. Assume all audio frames have the same duration. + self._audio_frames.append(frame) + frame_duration = len(frame.audio) / 16 * frame.num_channels / frame.sample_rate + buffer_duration = frame_duration * len(self._audio_frames) + while buffer_duration > self._start_secs: + self._audio_frames.pop(0) + buffer_duration -= frame_duration + + await self.push_frame(frame, direction) + + +class TranscriptExtractor(FrameProcessor): + def __init__(self, context): + super().__init__() + self._context = context + self._accumulator = "" + self._processing_llm_response = False + self._accumulating_transcript = False + + def reset(self): + self._accumulator = "" + self._processing_llm_response = False + self._accumulating_transcript = False + + async def process_frame(self, frame, direction): + await super().process_frame(frame, direction) + if isinstance(frame, LLMFullResponseStartFrame): + self._processing_llm_response = True + self._accumulating_transcript = True + elif isinstance(frame, TextFrame) and self._processing_llm_response: + if self._accumulating_transcript: + text = frame.text + split_index = text.find(marker) + if split_index < 0: + self._accumulator += frame.text + # do not push this frame + return + else: + self._accumulating_transcript = False + self._accumulator += text[:split_index] + frame.text = text[split_index + len(marker) :] + await self.push_frame(frame) + return + elif isinstance(frame, LLMFullResponseEndFrame): + await self.push_frame(MagicDemoTranscriptionFrame(text=self._accumulator.strip())) + self.reset() + + await self.push_frame(frame, direction) + + +class TanscriptionContextFixup(FrameProcessor): + def __init__(self, context): + super().__init__() + self._context = context + self._transcript = "THIS IS A TRANSCRIPT" + + def swap_user_audio(self): + if not self._transcript: + return + message = self._context.messages[-2] + last_part = message.parts[-1] + if ( + message.role == "user" + and last_part.inline_data + and last_part.inline_data.mime_type == "audio/wav" + ): + self._context.messages[-2] = glm.Content( + role="user", parts=[glm.Part(text=self._transcript)] + ) + + def add_transcript_back_to_inference_output(self): + if not self._transcript: + return + message = self._context.messages[-1] + last_part = message.parts[-1] + if message.role == "model" and last_part.text: + self._context.messages[-1].parts[-1].text += f"\n\n{marker}\n{self._transcript}\n" + + async def process_frame(self, frame, direction): + await super().process_frame(frame, direction) + + if isinstance(frame, MagicDemoTranscriptionFrame): + self._transcript = frame.text + elif isinstance(frame, LLMFullResponseEndFrame) or isinstance( + frame, StartInterruptionFrame + ): + self.swap_user_audio() + self.add_transcript_back_to_inference_output() + self._transcript = "" + + await self.push_frame(frame, direction) + + +async def main(): + async with aiohttp.ClientSession() as session: + (room_url, token) = await configure(session) + + transport = DailyTransport( + room_url, + token, + "Respond bot", + DailyParams( + audio_out_enabled=True, + # No transcription at all. just audio input to Gemini! + # transcription_enabled=True, + vad_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + vad_audio_passthrough=True, + ), + ) + + tts = CartesiaTTSService( + api_key=os.getenv("CARTESIA_API_KEY"), + voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady + ) + + llm = GoogleLLMService(model="gemini-1.5-flash-latest", api_key=os.getenv("GOOGLE_API_KEY")) + + messages = [ + { + "role": "system", + "content": system_message, + }, + { + "role": "user", + "content": "Start by saying hello.", + }, + ] + + context = OpenAILLMContext(messages) + context_aggregator = llm.create_context_aggregator(context) + audio_collector = UserAudioCollector(context, context_aggregator.user()) + pull_transcript_out_of_llm_output = TranscriptExtractor(context) + fixup_context_messages = TanscriptionContextFixup(context) + + pipeline = Pipeline( + [ + transport.input(), # Transport user input + audio_collector, + context_aggregator.user(), # User responses + llm, # LLM + pull_transcript_out_of_llm_output, + tts, # TTS + transport.output(), # Transport bot output + context_aggregator.assistant(), # Assistant spoken responses + fixup_context_messages, + ] + ) + + task = PipelineTask( + pipeline, + PipelineParams( + allow_interruptions=True, + enable_metrics=True, + enable_usage_metrics=True, + ), + ) + + @transport.event_handler("on_first_participant_joined") + async def on_first_participant_joined(transport, participant): + await transport.capture_participant_transcription(participant["id"]) + # Kick off the conversation. + await task.queue_frames([context_aggregator.user().get_context_frame()]) + + runner = PipelineRunner() + + await runner.run(task) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/pipecat/processors/aggregators/openai_llm_context.py b/src/pipecat/processors/aggregators/openai_llm_context.py index d70f0e25b..5e4f44093 100644 --- a/src/pipecat/processors/aggregators/openai_llm_context.py +++ b/src/pipecat/processors/aggregators/openai_llm_context.py @@ -15,6 +15,7 @@ from loguru import logger from PIL import Image from pipecat.frames.frames import ( + AudioRawFrame, Frame, FunctionCallInProgressFrame, FunctionCallResultFrame, @@ -174,6 +175,10 @@ class OpenAILLMContext: content.append({"type": "text", "text": text}) self.add_message({"role": "user", "content": content}) + def add_audio_frames_message(self, *, audio_frames: list[AudioRawFrame], text: str = None): + # todo: implement for OpenAI models and others + pass + async def call_function( self, f: Callable[ @@ -213,6 +218,29 @@ class OpenAILLMContext: await f(function_name, tool_call_id, arguments, llm, self, function_call_result_callback) + def create_wav_header(self, sample_rate, num_channels, bits_per_sample, data_size): + # RIFF chunk descriptor + header = bytearray() + header.extend(b"RIFF") # ChunkID + header.extend((data_size + 36).to_bytes(4, "little")) # ChunkSize: total size - 8 + header.extend(b"WAVE") # Format + # "fmt " sub-chunk + header.extend(b"fmt ") # Subchunk1ID + header.extend((16).to_bytes(4, "little")) # Subchunk1Size (16 for PCM) + header.extend((1).to_bytes(2, "little")) # AudioFormat (1 for PCM) + header.extend(num_channels.to_bytes(2, "little")) # NumChannels + header.extend(sample_rate.to_bytes(4, "little")) # SampleRate + # Calculate byte rate and block align + byte_rate = sample_rate * num_channels * (bits_per_sample // 8) + block_align = num_channels * (bits_per_sample // 8) + header.extend(byte_rate.to_bytes(4, "little")) # ByteRate + header.extend(block_align.to_bytes(2, "little")) # BlockAlign + header.extend(bits_per_sample.to_bytes(2, "little")) # BitsPerSample + # "data" sub-chunk + header.extend(b"data") # Subchunk2ID + header.extend(data_size.to_bytes(4, "little")) # Subchunk2Size + return header + @dataclass class OpenAILLMContextFrame(Frame): diff --git a/src/pipecat/services/google.py b/src/pipecat/services/google.py index c1114e44e..69864262e 100644 --- a/src/pipecat/services/google.py +++ b/src/pipecat/services/google.py @@ -16,6 +16,7 @@ from PIL import Image from pydantic import BaseModel, Field from pipecat.frames.frames import ( + AudioRawFrame, ErrorFrame, Frame, LLMFullResponseEndFrame, @@ -184,11 +185,53 @@ class GoogleLLMContext(OpenAILLMContext): msgs.append(obj) return msgs + def add_image_frame_message( + self, *, format: str, size: tuple[int, int], image: bytes, text: str = None + ): + buffer = io.BytesIO() + Image.frombytes(format, size, image).save(buffer, format="JPEG") + + parts = [] + if text: + parts.append(glm.Part(text=text)) + parts.append( + glm.Part(inline_data=glm.Blob(mime_type="image/jpeg", data=buffer.getvalue())), + ) + self.add_message(glm.Content(role="user", parts=parts)) + + def add_audio_frames_message(self, *, audio_frames: list[AudioRawFrame], text: str = None): + if not audio_frames: + return + + sample_rate = audio_frames[0].sample_rate + num_channels = audio_frames[0].num_channels + + parts = [] + data = b"".join(frame.audio for frame in audio_frames) + if text: + parts.append(glm.Part(text=text)) + parts.append( + glm.Part( + inline_data=glm.Blob( + mime_type="audio/wav", + data=( + bytes( + self.create_wav_header(sample_rate, num_channels, 16, len(data)) + data + ) + ), + ) + ), + ) + self.add_message(glm.Content(role="user", parts=parts)) + # message = {"mime_type": "audio/mp3", "data": bytes(data + create_wav_header(sample_rate, num_channels, 16, len(data)))} + # self.add_message(message) + def from_standard_message(self, message): role = message["role"] content = message.get("content", []) if role == "system": - role = "user" + self.system_message = content + return None elif role == "assistant": role = "model" @@ -232,20 +275,6 @@ class GoogleLLMContext(OpenAILLMContext): message = glm.Content(role=role, parts=parts) return message - def add_image_frame_message( - self, *, format: str, size: tuple[int, int], image: bytes, text: str = None - ): - buffer = io.BytesIO() - Image.frombytes(format, size, image).save(buffer, format="JPEG") - - parts = [] - if text: - parts.append(glm.Part(text=text)) - parts.append( - glm.Part(inline_data=glm.Blob(mime_type="image/jpeg", data=buffer.getvalue())), - ) - self.add_message(glm.Content(role="user", parts=parts)) - def to_standard_messages(self, obj) -> list: msg = {"role": obj.role, "content": []} if msg["role"] == "model": @@ -289,9 +318,20 @@ class GoogleLLMContext(OpenAILLMContext): return [msg] def _restructure_from_openai_messages(self): + self.system_message = None # first, map across self._messages calling self.from_standard_message(m) to modify messages in place try: - self._messages[:] = [self.from_standard_message(m) for m in self._messages] + self._messages[:] = [ + msg + for msg in (self.from_standard_message(m) for m in self._messages) + if msg is not None + ] + # We might have been given a messages list with only a system message. If so, let's put that back in + # the messages list as a user message. + if self.system_message and not self._messages: + self.add_message( + glm.Content(role="user", parts=[glm.Part(text=self.system_message)]) + ) except Exception as e: logger.error(f"Error mapping messages: {e}") # iterate over messages and remove any messages that have an empty content list @@ -319,11 +359,14 @@ class GoogleLLMService(LLMService): api_key: str, model: str = "gemini-1.5-flash-latest", params: InputParams = InputParams(), + system_instruction: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) gai.configure(api_key=api_key) - self._create_client(model) + self.set_model_name(model) + self._system_instruction = system_instruction + self._create_client() self._settings = { "max_tokens": params.max_tokens, "temperature": params.temperature, @@ -335,34 +378,10 @@ class GoogleLLMService(LLMService): def can_generate_metrics(self) -> bool: return True - def _create_client(self, model: str): - self.set_model_name(model) - self._client = gai.GenerativeModel(model) - - def _get_messages_from_openai_context(self, context: OpenAILLMContext) -> List[glm.Content]: - openai_messages = context.get_messages() - google_messages = [] - - for message in openai_messages: - role = message["role"] - content = message["content"] - if role == "system": - role = "user" - elif role == "assistant": - role = "model" - - parts = [glm.Part(text=content)] - if "mime_type" in message: - parts.append( - glm.Part( - inline_data=glm.Blob( - mime_type=message["mime_type"], data=message["data"].getvalue() - ) - ) - ) - google_messages.append({"role": role, "parts": parts}) - - return google_messages + def _create_client(self): + self._client = gai.GenerativeModel( + self._model_name, system_instruction=self._system_instruction + ) async def _async_generator_wrapper(self, sync_generator): for item in sync_generator: @@ -374,10 +393,11 @@ class GoogleLLMService(LLMService): try: logger.debug(f"Generating chat: {context.get_messages_for_logging()}") - # todo: move this into the new context code structure, convert from openai context one time - # todo: add system instructions - # messages = self._get_messages_from_openai_context(context) messages = context.messages + if self._system_instruction != context.system_message: + logger.debug(f"System instruction changed: {context.system_message}") + self._system_instruction = context.system_message + self._create_client() # Filter out None values and create GenerationConfig generation_params = { @@ -394,24 +414,21 @@ class GoogleLLMService(LLMService): generation_config = GenerationConfig(**generation_params) if generation_params else None await self.start_ttfb_metrics() - tools = context.tools if context.tools else [] response = self._client.generate_content( contents=messages, tools=tools, stream=True, generation_config=generation_config ) - - tokens = LLMTokenUsage( - prompt_tokens=response.usage_metadata.prompt_token_count, - completion_tokens=response.usage_metadata.candidates_token_count, - total_tokens=response.usage_metadata.total_token_count, - ) - - await self.start_llm_usage_metrics(tokens) - await self.stop_ttfb_metrics() + prompt_tokens = response.usage_metadata.prompt_token_count + completion_tokens = response.usage_metadata.candidates_token_count + total_tokens = response.usage_metadata.total_token_count + async for chunk in self._async_generator_wrapper(response): - # todo: usage + if chunk.usage_metadata: + prompt_tokens += response.usage_metadata.prompt_token_count + completion_tokens += response.usage_metadata.candidates_token_count + total_tokens += response.usage_metadata.total_token_count try: for c in chunk.parts: if c.text: @@ -436,6 +453,13 @@ class GoogleLLMService(LLMService): except Exception as e: logger.exception(f"{self} exception: {e}") finally: + await self.start_llm_usage_metrics( + LLMTokenUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ) + ) await self.push_frame(LLMFullResponseEndFrame()) async def process_frame(self, frame: Frame, direction: FrameDirection):