gemini audio-in with no transcription

This commit is contained in:
Kwindla Hultman Kramer
2024-11-08 08:28:54 -08:00
parent 91ac40307e
commit ee53535f41
3 changed files with 387 additions and 59 deletions

View File

@@ -0,0 +1,276 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import aiohttp
import asyncio
import os
import sys
import google.ai.generativelanguage as glm
from dataclasses import dataclass
from dotenv import load_dotenv
from loguru import logger
from runner import configure
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.google import GoogleLLMService
from pipecat.processors.frame_processor import FrameProcessor
from pipecat.transports.services.daily import DailyParams, DailyTransport
from pipecat.frames.frames import (
LLMFullResponseStartFrame,
LLMFullResponseEndFrame,
InputAudioRawFrame,
Frame,
StartInterruptionFrame,
TextFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
load_dotenv(override=True)
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
marker = "|----|"
system_message = f"""
You are a helpful LLM in a WebRTC call. Your goals are to be helpful and brief in your responses.
You are expert at transcribing audio to text. You will receive a mixture of audio and text input. When
asked to transcribe what the user said, output an exact, word-for-word transcription.
Your output will be converted to audio so don't include special characters in your answers.
Each time you answer, you should respond in three parts.
1. Transcribe exactly what the user said.
2. Output the separator field '{marker}'.
3. Respond to the user's input in a succinct, helpful, creative way using only simple text and punctuation.
Example:
User: How many ounces are in a pound?
You: How many ounces are in a pound?
{marker}
There are 16 ounces in a pound.
"""
@dataclass
class MagicDemoTranscriptionFrame(Frame):
text: str
class UserAudioCollector(FrameProcessor):
def __init__(self, context, user_context_aggregator):
super().__init__()
self._context = context
self._user_context_aggregator = user_context_aggregator
self._audio_frames = []
self._start_secs = (
0.2 # this should match VAD_START_SECS but we'll just hardcode it for now
)
self._user_speaking = False
async def process_frame(self, frame, direction):
await super().process_frame(frame, direction)
if isinstance(frame, TranscriptionFrame):
# We could gracefully handle both audio input and text/transcription input ...
# but let's leave that as an exercise to the reader. :-)
return
if isinstance(frame, UserStartedSpeakingFrame):
self._user_speaking = True
elif isinstance(frame, UserStoppedSpeakingFrame):
self._user_speaking = False
self._context.add_audio_frames_message(audio_frames=self._audio_frames)
await self._user_context_aggregator.push_frame(
self._user_context_aggregator.get_context_frame()
)
elif isinstance(frame, InputAudioRawFrame):
if self._user_speaking:
self._audio_frames.append(frame)
else:
# Append the audio frame to our buffer. Treat the buffer as a ring buffer, dropping the oldest
# frames as necessary. Assume all audio frames have the same duration.
self._audio_frames.append(frame)
frame_duration = len(frame.audio) / 16 * frame.num_channels / frame.sample_rate
buffer_duration = frame_duration * len(self._audio_frames)
while buffer_duration > self._start_secs:
self._audio_frames.pop(0)
buffer_duration -= frame_duration
await self.push_frame(frame, direction)
class TranscriptExtractor(FrameProcessor):
def __init__(self, context):
super().__init__()
self._context = context
self._accumulator = ""
self._processing_llm_response = False
self._accumulating_transcript = False
def reset(self):
self._accumulator = ""
self._processing_llm_response = False
self._accumulating_transcript = False
async def process_frame(self, frame, direction):
await super().process_frame(frame, direction)
if isinstance(frame, LLMFullResponseStartFrame):
self._processing_llm_response = True
self._accumulating_transcript = True
elif isinstance(frame, TextFrame) and self._processing_llm_response:
if self._accumulating_transcript:
text = frame.text
split_index = text.find(marker)
if split_index < 0:
self._accumulator += frame.text
# do not push this frame
return
else:
self._accumulating_transcript = False
self._accumulator += text[:split_index]
frame.text = text[split_index + len(marker) :]
await self.push_frame(frame)
return
elif isinstance(frame, LLMFullResponseEndFrame):
await self.push_frame(MagicDemoTranscriptionFrame(text=self._accumulator.strip()))
self.reset()
await self.push_frame(frame, direction)
class TanscriptionContextFixup(FrameProcessor):
def __init__(self, context):
super().__init__()
self._context = context
self._transcript = "THIS IS A TRANSCRIPT"
def swap_user_audio(self):
if not self._transcript:
return
message = self._context.messages[-2]
last_part = message.parts[-1]
if (
message.role == "user"
and last_part.inline_data
and last_part.inline_data.mime_type == "audio/wav"
):
self._context.messages[-2] = glm.Content(
role="user", parts=[glm.Part(text=self._transcript)]
)
def add_transcript_back_to_inference_output(self):
if not self._transcript:
return
message = self._context.messages[-1]
last_part = message.parts[-1]
if message.role == "model" and last_part.text:
self._context.messages[-1].parts[-1].text += f"\n\n{marker}\n{self._transcript}\n"
async def process_frame(self, frame, direction):
await super().process_frame(frame, direction)
if isinstance(frame, MagicDemoTranscriptionFrame):
self._transcript = frame.text
elif isinstance(frame, LLMFullResponseEndFrame) or isinstance(
frame, StartInterruptionFrame
):
self.swap_user_audio()
self.add_transcript_back_to_inference_output()
self._transcript = ""
await self.push_frame(frame, direction)
async def main():
async with aiohttp.ClientSession() as session:
(room_url, token) = await configure(session)
transport = DailyTransport(
room_url,
token,
"Respond bot",
DailyParams(
audio_out_enabled=True,
# No transcription at all. just audio input to Gemini!
# transcription_enabled=True,
vad_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
vad_audio_passthrough=True,
),
)
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
)
llm = GoogleLLMService(model="gemini-1.5-flash-latest", api_key=os.getenv("GOOGLE_API_KEY"))
messages = [
{
"role": "system",
"content": system_message,
},
{
"role": "user",
"content": "Start by saying hello.",
},
]
context = OpenAILLMContext(messages)
context_aggregator = llm.create_context_aggregator(context)
audio_collector = UserAudioCollector(context, context_aggregator.user())
pull_transcript_out_of_llm_output = TranscriptExtractor(context)
fixup_context_messages = TanscriptionContextFixup(context)
pipeline = Pipeline(
[
transport.input(), # Transport user input
audio_collector,
context_aggregator.user(), # User responses
llm, # LLM
pull_transcript_out_of_llm_output,
tts, # TTS
transport.output(), # Transport bot output
context_aggregator.assistant(), # Assistant spoken responses
fixup_context_messages,
]
)
task = PipelineTask(
pipeline,
PipelineParams(
allow_interruptions=True,
enable_metrics=True,
enable_usage_metrics=True,
),
)
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
await transport.capture_participant_transcription(participant["id"])
# Kick off the conversation.
await task.queue_frames([context_aggregator.user().get_context_frame()])
runner = PipelineRunner()
await runner.run(task)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -15,6 +15,7 @@ from loguru import logger
from PIL import Image
from pipecat.frames.frames import (
AudioRawFrame,
Frame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
@@ -174,6 +175,10 @@ class OpenAILLMContext:
content.append({"type": "text", "text": text})
self.add_message({"role": "user", "content": content})
def add_audio_frames_message(self, *, audio_frames: list[AudioRawFrame], text: str = None):
# todo: implement for OpenAI models and others
pass
async def call_function(
self,
f: Callable[
@@ -213,6 +218,29 @@ class OpenAILLMContext:
await f(function_name, tool_call_id, arguments, llm, self, function_call_result_callback)
def create_wav_header(self, sample_rate, num_channels, bits_per_sample, data_size):
# RIFF chunk descriptor
header = bytearray()
header.extend(b"RIFF") # ChunkID
header.extend((data_size + 36).to_bytes(4, "little")) # ChunkSize: total size - 8
header.extend(b"WAVE") # Format
# "fmt " sub-chunk
header.extend(b"fmt ") # Subchunk1ID
header.extend((16).to_bytes(4, "little")) # Subchunk1Size (16 for PCM)
header.extend((1).to_bytes(2, "little")) # AudioFormat (1 for PCM)
header.extend(num_channels.to_bytes(2, "little")) # NumChannels
header.extend(sample_rate.to_bytes(4, "little")) # SampleRate
# Calculate byte rate and block align
byte_rate = sample_rate * num_channels * (bits_per_sample // 8)
block_align = num_channels * (bits_per_sample // 8)
header.extend(byte_rate.to_bytes(4, "little")) # ByteRate
header.extend(block_align.to_bytes(2, "little")) # BlockAlign
header.extend(bits_per_sample.to_bytes(2, "little")) # BitsPerSample
# "data" sub-chunk
header.extend(b"data") # Subchunk2ID
header.extend(data_size.to_bytes(4, "little")) # Subchunk2Size
return header
@dataclass
class OpenAILLMContextFrame(Frame):

View File

@@ -16,6 +16,7 @@ from PIL import Image
from pydantic import BaseModel, Field
from pipecat.frames.frames import (
AudioRawFrame,
ErrorFrame,
Frame,
LLMFullResponseEndFrame,
@@ -184,11 +185,53 @@ class GoogleLLMContext(OpenAILLMContext):
msgs.append(obj)
return msgs
def add_image_frame_message(
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
):
buffer = io.BytesIO()
Image.frombytes(format, size, image).save(buffer, format="JPEG")
parts = []
if text:
parts.append(glm.Part(text=text))
parts.append(
glm.Part(inline_data=glm.Blob(mime_type="image/jpeg", data=buffer.getvalue())),
)
self.add_message(glm.Content(role="user", parts=parts))
def add_audio_frames_message(self, *, audio_frames: list[AudioRawFrame], text: str = None):
if not audio_frames:
return
sample_rate = audio_frames[0].sample_rate
num_channels = audio_frames[0].num_channels
parts = []
data = b"".join(frame.audio for frame in audio_frames)
if text:
parts.append(glm.Part(text=text))
parts.append(
glm.Part(
inline_data=glm.Blob(
mime_type="audio/wav",
data=(
bytes(
self.create_wav_header(sample_rate, num_channels, 16, len(data)) + data
)
),
)
),
)
self.add_message(glm.Content(role="user", parts=parts))
# message = {"mime_type": "audio/mp3", "data": bytes(data + create_wav_header(sample_rate, num_channels, 16, len(data)))}
# self.add_message(message)
def from_standard_message(self, message):
role = message["role"]
content = message.get("content", [])
if role == "system":
role = "user"
self.system_message = content
return None
elif role == "assistant":
role = "model"
@@ -232,20 +275,6 @@ class GoogleLLMContext(OpenAILLMContext):
message = glm.Content(role=role, parts=parts)
return message
def add_image_frame_message(
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
):
buffer = io.BytesIO()
Image.frombytes(format, size, image).save(buffer, format="JPEG")
parts = []
if text:
parts.append(glm.Part(text=text))
parts.append(
glm.Part(inline_data=glm.Blob(mime_type="image/jpeg", data=buffer.getvalue())),
)
self.add_message(glm.Content(role="user", parts=parts))
def to_standard_messages(self, obj) -> list:
msg = {"role": obj.role, "content": []}
if msg["role"] == "model":
@@ -289,9 +318,20 @@ class GoogleLLMContext(OpenAILLMContext):
return [msg]
def _restructure_from_openai_messages(self):
self.system_message = None
# first, map across self._messages calling self.from_standard_message(m) to modify messages in place
try:
self._messages[:] = [self.from_standard_message(m) for m in self._messages]
self._messages[:] = [
msg
for msg in (self.from_standard_message(m) for m in self._messages)
if msg is not None
]
# We might have been given a messages list with only a system message. If so, let's put that back in
# the messages list as a user message.
if self.system_message and not self._messages:
self.add_message(
glm.Content(role="user", parts=[glm.Part(text=self.system_message)])
)
except Exception as e:
logger.error(f"Error mapping messages: {e}")
# iterate over messages and remove any messages that have an empty content list
@@ -319,11 +359,14 @@ class GoogleLLMService(LLMService):
api_key: str,
model: str = "gemini-1.5-flash-latest",
params: InputParams = InputParams(),
system_instruction: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)
gai.configure(api_key=api_key)
self._create_client(model)
self.set_model_name(model)
self._system_instruction = system_instruction
self._create_client()
self._settings = {
"max_tokens": params.max_tokens,
"temperature": params.temperature,
@@ -335,34 +378,10 @@ class GoogleLLMService(LLMService):
def can_generate_metrics(self) -> bool:
return True
def _create_client(self, model: str):
self.set_model_name(model)
self._client = gai.GenerativeModel(model)
def _get_messages_from_openai_context(self, context: OpenAILLMContext) -> List[glm.Content]:
openai_messages = context.get_messages()
google_messages = []
for message in openai_messages:
role = message["role"]
content = message["content"]
if role == "system":
role = "user"
elif role == "assistant":
role = "model"
parts = [glm.Part(text=content)]
if "mime_type" in message:
parts.append(
glm.Part(
inline_data=glm.Blob(
mime_type=message["mime_type"], data=message["data"].getvalue()
)
)
)
google_messages.append({"role": role, "parts": parts})
return google_messages
def _create_client(self):
self._client = gai.GenerativeModel(
self._model_name, system_instruction=self._system_instruction
)
async def _async_generator_wrapper(self, sync_generator):
for item in sync_generator:
@@ -374,10 +393,11 @@ class GoogleLLMService(LLMService):
try:
logger.debug(f"Generating chat: {context.get_messages_for_logging()}")
# todo: move this into the new context code structure, convert from openai context one time
# todo: add system instructions
# messages = self._get_messages_from_openai_context(context)
messages = context.messages
if self._system_instruction != context.system_message:
logger.debug(f"System instruction changed: {context.system_message}")
self._system_instruction = context.system_message
self._create_client()
# Filter out None values and create GenerationConfig
generation_params = {
@@ -394,24 +414,21 @@ class GoogleLLMService(LLMService):
generation_config = GenerationConfig(**generation_params) if generation_params else None
await self.start_ttfb_metrics()
tools = context.tools if context.tools else []
response = self._client.generate_content(
contents=messages, tools=tools, stream=True, generation_config=generation_config
)
tokens = LLMTokenUsage(
prompt_tokens=response.usage_metadata.prompt_token_count,
completion_tokens=response.usage_metadata.candidates_token_count,
total_tokens=response.usage_metadata.total_token_count,
)
await self.start_llm_usage_metrics(tokens)
await self.stop_ttfb_metrics()
prompt_tokens = response.usage_metadata.prompt_token_count
completion_tokens = response.usage_metadata.candidates_token_count
total_tokens = response.usage_metadata.total_token_count
async for chunk in self._async_generator_wrapper(response):
# todo: usage
if chunk.usage_metadata:
prompt_tokens += response.usage_metadata.prompt_token_count
completion_tokens += response.usage_metadata.candidates_token_count
total_tokens += response.usage_metadata.total_token_count
try:
for c in chunk.parts:
if c.text:
@@ -436,6 +453,13 @@ class GoogleLLMService(LLMService):
except Exception as e:
logger.exception(f"{self} exception: {e}")
finally:
await self.start_llm_usage_metrics(
LLMTokenUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)
)
await self.push_frame(LLMFullResponseEndFrame())
async def process_frame(self, frame: Frame, direction: FrameDirection):