gemini audio-in with no transcription
This commit is contained in:
276
examples/foundational/07p-interruptible-google-audio-in.py
Normal file
276
examples/foundational/07p-interruptible-google-audio-in.py
Normal file
@@ -0,0 +1,276 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
import google.ai.generativelanguage as glm
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.google import GoogleLLMService
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.frames.frames import (
|
||||
LLMFullResponseStartFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
InputAudioRawFrame,
|
||||
Frame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
marker = "|----|"
|
||||
system_message = f"""
|
||||
You are a helpful LLM in a WebRTC call. Your goals are to be helpful and brief in your responses.
|
||||
|
||||
You are expert at transcribing audio to text. You will receive a mixture of audio and text input. When
|
||||
asked to transcribe what the user said, output an exact, word-for-word transcription.
|
||||
|
||||
Your output will be converted to audio so don't include special characters in your answers.
|
||||
|
||||
Each time you answer, you should respond in three parts.
|
||||
|
||||
1. Transcribe exactly what the user said.
|
||||
2. Output the separator field '{marker}'.
|
||||
3. Respond to the user's input in a succinct, helpful, creative way using only simple text and punctuation.
|
||||
|
||||
Example:
|
||||
|
||||
User: How many ounces are in a pound?
|
||||
|
||||
You: How many ounces are in a pound?
|
||||
{marker}
|
||||
There are 16 ounces in a pound.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class MagicDemoTranscriptionFrame(Frame):
|
||||
text: str
|
||||
|
||||
|
||||
class UserAudioCollector(FrameProcessor):
|
||||
def __init__(self, context, user_context_aggregator):
|
||||
super().__init__()
|
||||
self._context = context
|
||||
self._user_context_aggregator = user_context_aggregator
|
||||
self._audio_frames = []
|
||||
self._start_secs = (
|
||||
0.2 # this should match VAD_START_SECS but we'll just hardcode it for now
|
||||
)
|
||||
self._user_speaking = False
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
# We could gracefully handle both audio input and text/transcription input ...
|
||||
# but let's leave that as an exercise to the reader. :-)
|
||||
return
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
self._user_speaking = True
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
self._user_speaking = False
|
||||
self._context.add_audio_frames_message(audio_frames=self._audio_frames)
|
||||
await self._user_context_aggregator.push_frame(
|
||||
self._user_context_aggregator.get_context_frame()
|
||||
)
|
||||
elif isinstance(frame, InputAudioRawFrame):
|
||||
if self._user_speaking:
|
||||
self._audio_frames.append(frame)
|
||||
else:
|
||||
# Append the audio frame to our buffer. Treat the buffer as a ring buffer, dropping the oldest
|
||||
# frames as necessary. Assume all audio frames have the same duration.
|
||||
self._audio_frames.append(frame)
|
||||
frame_duration = len(frame.audio) / 16 * frame.num_channels / frame.sample_rate
|
||||
buffer_duration = frame_duration * len(self._audio_frames)
|
||||
while buffer_duration > self._start_secs:
|
||||
self._audio_frames.pop(0)
|
||||
buffer_duration -= frame_duration
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class TranscriptExtractor(FrameProcessor):
|
||||
def __init__(self, context):
|
||||
super().__init__()
|
||||
self._context = context
|
||||
self._accumulator = ""
|
||||
self._processing_llm_response = False
|
||||
self._accumulating_transcript = False
|
||||
|
||||
def reset(self):
|
||||
self._accumulator = ""
|
||||
self._processing_llm_response = False
|
||||
self._accumulating_transcript = False
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, LLMFullResponseStartFrame):
|
||||
self._processing_llm_response = True
|
||||
self._accumulating_transcript = True
|
||||
elif isinstance(frame, TextFrame) and self._processing_llm_response:
|
||||
if self._accumulating_transcript:
|
||||
text = frame.text
|
||||
split_index = text.find(marker)
|
||||
if split_index < 0:
|
||||
self._accumulator += frame.text
|
||||
# do not push this frame
|
||||
return
|
||||
else:
|
||||
self._accumulating_transcript = False
|
||||
self._accumulator += text[:split_index]
|
||||
frame.text = text[split_index + len(marker) :]
|
||||
await self.push_frame(frame)
|
||||
return
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
await self.push_frame(MagicDemoTranscriptionFrame(text=self._accumulator.strip()))
|
||||
self.reset()
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class TanscriptionContextFixup(FrameProcessor):
|
||||
def __init__(self, context):
|
||||
super().__init__()
|
||||
self._context = context
|
||||
self._transcript = "THIS IS A TRANSCRIPT"
|
||||
|
||||
def swap_user_audio(self):
|
||||
if not self._transcript:
|
||||
return
|
||||
message = self._context.messages[-2]
|
||||
last_part = message.parts[-1]
|
||||
if (
|
||||
message.role == "user"
|
||||
and last_part.inline_data
|
||||
and last_part.inline_data.mime_type == "audio/wav"
|
||||
):
|
||||
self._context.messages[-2] = glm.Content(
|
||||
role="user", parts=[glm.Part(text=self._transcript)]
|
||||
)
|
||||
|
||||
def add_transcript_back_to_inference_output(self):
|
||||
if not self._transcript:
|
||||
return
|
||||
message = self._context.messages[-1]
|
||||
last_part = message.parts[-1]
|
||||
if message.role == "model" and last_part.text:
|
||||
self._context.messages[-1].parts[-1].text += f"\n\n{marker}\n{self._transcript}\n"
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, MagicDemoTranscriptionFrame):
|
||||
self._transcript = frame.text
|
||||
elif isinstance(frame, LLMFullResponseEndFrame) or isinstance(
|
||||
frame, StartInterruptionFrame
|
||||
):
|
||||
self.swap_user_audio()
|
||||
self.add_transcript_back_to_inference_output()
|
||||
self._transcript = ""
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
token,
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
# No transcription at all. just audio input to Gemini!
|
||||
# transcription_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
vad_audio_passthrough=True,
|
||||
),
|
||||
)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
|
||||
)
|
||||
|
||||
llm = GoogleLLMService(model="gemini-1.5-flash-latest", api_key=os.getenv("GOOGLE_API_KEY"))
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_message,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Start by saying hello.",
|
||||
},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
audio_collector = UserAudioCollector(context, context_aggregator.user())
|
||||
pull_transcript_out_of_llm_output = TranscriptExtractor(context)
|
||||
fixup_context_messages = TanscriptionContextFixup(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
audio_collector,
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
pull_transcript_out_of_llm_output,
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
fixup_context_messages,
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
PipelineParams(
|
||||
allow_interruptions=True,
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
await transport.capture_participant_transcription(participant["id"])
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -15,6 +15,7 @@ from loguru import logger
|
||||
from PIL import Image
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
Frame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
@@ -174,6 +175,10 @@ class OpenAILLMContext:
|
||||
content.append({"type": "text", "text": text})
|
||||
self.add_message({"role": "user", "content": content})
|
||||
|
||||
def add_audio_frames_message(self, *, audio_frames: list[AudioRawFrame], text: str = None):
|
||||
# todo: implement for OpenAI models and others
|
||||
pass
|
||||
|
||||
async def call_function(
|
||||
self,
|
||||
f: Callable[
|
||||
@@ -213,6 +218,29 @@ class OpenAILLMContext:
|
||||
|
||||
await f(function_name, tool_call_id, arguments, llm, self, function_call_result_callback)
|
||||
|
||||
def create_wav_header(self, sample_rate, num_channels, bits_per_sample, data_size):
|
||||
# RIFF chunk descriptor
|
||||
header = bytearray()
|
||||
header.extend(b"RIFF") # ChunkID
|
||||
header.extend((data_size + 36).to_bytes(4, "little")) # ChunkSize: total size - 8
|
||||
header.extend(b"WAVE") # Format
|
||||
# "fmt " sub-chunk
|
||||
header.extend(b"fmt ") # Subchunk1ID
|
||||
header.extend((16).to_bytes(4, "little")) # Subchunk1Size (16 for PCM)
|
||||
header.extend((1).to_bytes(2, "little")) # AudioFormat (1 for PCM)
|
||||
header.extend(num_channels.to_bytes(2, "little")) # NumChannels
|
||||
header.extend(sample_rate.to_bytes(4, "little")) # SampleRate
|
||||
# Calculate byte rate and block align
|
||||
byte_rate = sample_rate * num_channels * (bits_per_sample // 8)
|
||||
block_align = num_channels * (bits_per_sample // 8)
|
||||
header.extend(byte_rate.to_bytes(4, "little")) # ByteRate
|
||||
header.extend(block_align.to_bytes(2, "little")) # BlockAlign
|
||||
header.extend(bits_per_sample.to_bytes(2, "little")) # BitsPerSample
|
||||
# "data" sub-chunk
|
||||
header.extend(b"data") # Subchunk2ID
|
||||
header.extend(data_size.to_bytes(4, "little")) # Subchunk2Size
|
||||
return header
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenAILLMContextFrame(Frame):
|
||||
|
||||
@@ -16,6 +16,7 @@ from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
LLMFullResponseEndFrame,
|
||||
@@ -184,11 +185,53 @@ class GoogleLLMContext(OpenAILLMContext):
|
||||
msgs.append(obj)
|
||||
return msgs
|
||||
|
||||
def add_image_frame_message(
|
||||
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
|
||||
):
|
||||
buffer = io.BytesIO()
|
||||
Image.frombytes(format, size, image).save(buffer, format="JPEG")
|
||||
|
||||
parts = []
|
||||
if text:
|
||||
parts.append(glm.Part(text=text))
|
||||
parts.append(
|
||||
glm.Part(inline_data=glm.Blob(mime_type="image/jpeg", data=buffer.getvalue())),
|
||||
)
|
||||
self.add_message(glm.Content(role="user", parts=parts))
|
||||
|
||||
def add_audio_frames_message(self, *, audio_frames: list[AudioRawFrame], text: str = None):
|
||||
if not audio_frames:
|
||||
return
|
||||
|
||||
sample_rate = audio_frames[0].sample_rate
|
||||
num_channels = audio_frames[0].num_channels
|
||||
|
||||
parts = []
|
||||
data = b"".join(frame.audio for frame in audio_frames)
|
||||
if text:
|
||||
parts.append(glm.Part(text=text))
|
||||
parts.append(
|
||||
glm.Part(
|
||||
inline_data=glm.Blob(
|
||||
mime_type="audio/wav",
|
||||
data=(
|
||||
bytes(
|
||||
self.create_wav_header(sample_rate, num_channels, 16, len(data)) + data
|
||||
)
|
||||
),
|
||||
)
|
||||
),
|
||||
)
|
||||
self.add_message(glm.Content(role="user", parts=parts))
|
||||
# message = {"mime_type": "audio/mp3", "data": bytes(data + create_wav_header(sample_rate, num_channels, 16, len(data)))}
|
||||
# self.add_message(message)
|
||||
|
||||
def from_standard_message(self, message):
|
||||
role = message["role"]
|
||||
content = message.get("content", [])
|
||||
if role == "system":
|
||||
role = "user"
|
||||
self.system_message = content
|
||||
return None
|
||||
elif role == "assistant":
|
||||
role = "model"
|
||||
|
||||
@@ -232,20 +275,6 @@ class GoogleLLMContext(OpenAILLMContext):
|
||||
message = glm.Content(role=role, parts=parts)
|
||||
return message
|
||||
|
||||
def add_image_frame_message(
|
||||
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
|
||||
):
|
||||
buffer = io.BytesIO()
|
||||
Image.frombytes(format, size, image).save(buffer, format="JPEG")
|
||||
|
||||
parts = []
|
||||
if text:
|
||||
parts.append(glm.Part(text=text))
|
||||
parts.append(
|
||||
glm.Part(inline_data=glm.Blob(mime_type="image/jpeg", data=buffer.getvalue())),
|
||||
)
|
||||
self.add_message(glm.Content(role="user", parts=parts))
|
||||
|
||||
def to_standard_messages(self, obj) -> list:
|
||||
msg = {"role": obj.role, "content": []}
|
||||
if msg["role"] == "model":
|
||||
@@ -289,9 +318,20 @@ class GoogleLLMContext(OpenAILLMContext):
|
||||
return [msg]
|
||||
|
||||
def _restructure_from_openai_messages(self):
|
||||
self.system_message = None
|
||||
# first, map across self._messages calling self.from_standard_message(m) to modify messages in place
|
||||
try:
|
||||
self._messages[:] = [self.from_standard_message(m) for m in self._messages]
|
||||
self._messages[:] = [
|
||||
msg
|
||||
for msg in (self.from_standard_message(m) for m in self._messages)
|
||||
if msg is not None
|
||||
]
|
||||
# We might have been given a messages list with only a system message. If so, let's put that back in
|
||||
# the messages list as a user message.
|
||||
if self.system_message and not self._messages:
|
||||
self.add_message(
|
||||
glm.Content(role="user", parts=[glm.Part(text=self.system_message)])
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error mapping messages: {e}")
|
||||
# iterate over messages and remove any messages that have an empty content list
|
||||
@@ -319,11 +359,14 @@ class GoogleLLMService(LLMService):
|
||||
api_key: str,
|
||||
model: str = "gemini-1.5-flash-latest",
|
||||
params: InputParams = InputParams(),
|
||||
system_instruction: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
gai.configure(api_key=api_key)
|
||||
self._create_client(model)
|
||||
self.set_model_name(model)
|
||||
self._system_instruction = system_instruction
|
||||
self._create_client()
|
||||
self._settings = {
|
||||
"max_tokens": params.max_tokens,
|
||||
"temperature": params.temperature,
|
||||
@@ -335,34 +378,10 @@ class GoogleLLMService(LLMService):
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
def _create_client(self, model: str):
|
||||
self.set_model_name(model)
|
||||
self._client = gai.GenerativeModel(model)
|
||||
|
||||
def _get_messages_from_openai_context(self, context: OpenAILLMContext) -> List[glm.Content]:
|
||||
openai_messages = context.get_messages()
|
||||
google_messages = []
|
||||
|
||||
for message in openai_messages:
|
||||
role = message["role"]
|
||||
content = message["content"]
|
||||
if role == "system":
|
||||
role = "user"
|
||||
elif role == "assistant":
|
||||
role = "model"
|
||||
|
||||
parts = [glm.Part(text=content)]
|
||||
if "mime_type" in message:
|
||||
parts.append(
|
||||
glm.Part(
|
||||
inline_data=glm.Blob(
|
||||
mime_type=message["mime_type"], data=message["data"].getvalue()
|
||||
)
|
||||
)
|
||||
)
|
||||
google_messages.append({"role": role, "parts": parts})
|
||||
|
||||
return google_messages
|
||||
def _create_client(self):
|
||||
self._client = gai.GenerativeModel(
|
||||
self._model_name, system_instruction=self._system_instruction
|
||||
)
|
||||
|
||||
async def _async_generator_wrapper(self, sync_generator):
|
||||
for item in sync_generator:
|
||||
@@ -374,10 +393,11 @@ class GoogleLLMService(LLMService):
|
||||
try:
|
||||
logger.debug(f"Generating chat: {context.get_messages_for_logging()}")
|
||||
|
||||
# todo: move this into the new context code structure, convert from openai context one time
|
||||
# todo: add system instructions
|
||||
# messages = self._get_messages_from_openai_context(context)
|
||||
messages = context.messages
|
||||
if self._system_instruction != context.system_message:
|
||||
logger.debug(f"System instruction changed: {context.system_message}")
|
||||
self._system_instruction = context.system_message
|
||||
self._create_client()
|
||||
|
||||
# Filter out None values and create GenerationConfig
|
||||
generation_params = {
|
||||
@@ -394,24 +414,21 @@ class GoogleLLMService(LLMService):
|
||||
generation_config = GenerationConfig(**generation_params) if generation_params else None
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
tools = context.tools if context.tools else []
|
||||
response = self._client.generate_content(
|
||||
contents=messages, tools=tools, stream=True, generation_config=generation_config
|
||||
)
|
||||
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=response.usage_metadata.prompt_token_count,
|
||||
completion_tokens=response.usage_metadata.candidates_token_count,
|
||||
total_tokens=response.usage_metadata.total_token_count,
|
||||
)
|
||||
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
prompt_tokens = response.usage_metadata.prompt_token_count
|
||||
completion_tokens = response.usage_metadata.candidates_token_count
|
||||
total_tokens = response.usage_metadata.total_token_count
|
||||
|
||||
async for chunk in self._async_generator_wrapper(response):
|
||||
# todo: usage
|
||||
if chunk.usage_metadata:
|
||||
prompt_tokens += response.usage_metadata.prompt_token_count
|
||||
completion_tokens += response.usage_metadata.candidates_token_count
|
||||
total_tokens += response.usage_metadata.total_token_count
|
||||
try:
|
||||
for c in chunk.parts:
|
||||
if c.text:
|
||||
@@ -436,6 +453,13 @@ class GoogleLLMService(LLMService):
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
finally:
|
||||
await self.start_llm_usage_metrics(
|
||||
LLMTokenUsage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
)
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
|
||||
Reference in New Issue
Block a user