Files
pipecat/examples/foundational/22d-natural-conversation-gemini-audio.py
2024-11-11 21:04:50 -08:00

356 lines
13 KiB
Python

#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import aiohttp
import asyncio
import os
import sys
import time
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMMessagesFrame, TextFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.services.deepgram import DeepgramSTTService
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
)
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.google import GoogleLLMService, GoogleLLMContext
from pipecat.sync.event_notifier import EventNotifier
from pipecat.transports.services.daily import DailyParams, DailyTransport
from pipecat.processors.frame_processor import FrameProcessor, FrameDirection
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
Frame,
InputAudioRawFrame,
StartFrame,
StartInterruptionFrame,
StopInterruptionFrame,
SystemFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
from pipecat.sync.base_notifier import BaseNotifier
from pipecat.processors.filters.function_filter import FunctionFilter
from pipecat.processors.user_idle_processor import UserIdleProcessor
from runner import configure
from loguru import logger
from dotenv import load_dotenv
load_dotenv(override=True)

# Replace loguru's default stderr handler (id 0) with one that logs at DEBUG.
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")

# System prompt for the "statement judge" LLM. It is given the user's audio and
# must answer YES (the user finished their turn) or NO (still speaking).
classifier_statement = """You are an audio language classifier model. You are receiving audio from a user in a WebRTC call. Your job is to decide whether the user has finished speaking or not.
Categorize the input you receive as either:
1. a complete thought, statement, or question, or
2. an incomplete thought, statement, or question
Output 'YES' if the input is likely to be a completed thought, statement, or question.
Output 'NO' if the input indicates that the user is still speaking and does not yet expect a response.
If you are unsure, output 'YES'.
"""

# System prompt for the conversational LLM that produces the bot's replies.
conversational_system_message = """You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.
Please be very concise in your responses. Unless you are explicitly asked to do otherwise, give me the shortest complete answer possible without unnecessary elaboration. Generally you should answer with a single sentence.
"""
class StatementJudgeAudioContextAccumulator(FrameProcessor):
    """Buffer the user's audio and turn each utterance into a classifier request.

    While the user is speaking, every ``InputAudioRawFrame`` is kept. While the
    user is silent, only the most recent ``_start_secs`` of audio is kept. On a
    ``UserStoppedSpeakingFrame`` a fresh ``GoogleLLMContext`` holding the
    classifier system prompt plus the buffered audio is pushed downstream as an
    ``OpenAILLMContextFrame``.

    Incoming ``OpenAILLMContextFrame`` and ``TranscriptionFrame`` instances are
    dropped; every other frame is forwarded unchanged.
    """

    def __init__(self, *, notifier: BaseNotifier, **kwargs):
        super().__init__(**kwargs)
        # Stored but not read by this class.
        self._notifier = notifier
        self._audio_frames = []
        self._start_secs = 0.2  # this should match VAD start_secs (hardcoding for now)
        self._user_speaking = False

    async def reset(self):
        """Discard all buffered audio and mark the user as not speaking."""
        self._audio_frames = []
        self._user_speaking = False

    @staticmethod
    def _frame_duration_secs(frame: InputAudioRawFrame) -> float:
        """Return the playback duration of ``frame`` in seconds.

        Raw audio is treated as 16-bit PCM: 2 bytes per sample per channel.
        """
        bytes_per_second = 2 * frame.num_channels * frame.sample_rate
        return len(frame.audio) / bytes_per_second

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        # ignore context frame
        if isinstance(frame, OpenAILLMContextFrame):
            return

        if isinstance(frame, TranscriptionFrame):
            # We could gracefully handle both audio input and text/transcription input ...
            # but let's leave that as an exercise to the reader. :-)
            return

        if isinstance(frame, UserStartedSpeakingFrame):
            self._user_speaking = True
        elif isinstance(frame, UserStoppedSpeakingFrame):
            self._user_speaking = False
            context = GoogleLLMContext()
            context.set_messages([{"role": "system", "content": classifier_statement}])
            context.add_audio_frames_message(audio_frames=self._audio_frames)
            await self.push_frame(OpenAILLMContextFrame(context=context))
        elif isinstance(frame, InputAudioRawFrame):
            self._audio_frames.append(frame)
            if not self._user_speaking:
                # Outside of speech, treat the buffer as a ring buffer that keeps
                # only the most recent `_start_secs` of audio (presumably so the
                # speech onset VAD needs `start_secs` to confirm is not lost).
                # Assume all audio frames have the same duration.
                frame_duration = self._frame_duration_secs(frame)
                buffer_duration = frame_duration * len(self._audio_frames)
                while buffer_duration > self._start_secs:
                    self._audio_frames.pop(0)
                    buffer_duration -= frame_duration

        await self.push_frame(frame, direction)
class CompletenessCheck(FrameProcessor):
    """Turn the classifier LLM's YES/NO text output into turn-taking signals.

    When a ``TextFrame`` starts with "YES" (leading whitespace ignored), this
    pushes a ``UserStoppedSpeakingFrame``, resets the audio accumulator, and
    fires the notifier. Any other non-blank ``TextFrame`` is logged as a NO.
    No other frame is forwarded downstream.
    """

    def __init__(
        self,
        notifier: BaseNotifier,
        audio_accumulator: StatementJudgeAudioContextAccumulator,
        **kwargs,
    ):
        # Forward kwargs like the other processors in this file do.
        super().__init__(**kwargs)
        self._notifier = notifier
        self._audio_accumulator = audio_accumulator

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        if not isinstance(frame, TextFrame):
            # Deliberately swallowed: this branch should only emit the
            # UserStoppedSpeakingFrame produced below.
            # NOTE(review): this relies on sibling ParallelPipeline branches to
            # carry lifecycle frames downstream.
            return

        # Streamed LLM text may carry leading whitespace before the answer.
        if frame.text.lstrip().startswith("YES"):
            logger.debug("Completeness check YES")
            await self.push_frame(UserStoppedSpeakingFrame())
            await self._audio_accumulator.reset()
            await self._notifier.notify()
        elif frame.text.strip():
            logger.debug(f"Completeness check NO - '{frame.text}'")
class OutputGate(FrameProcessor):
    """Hold downstream output until the notifier fires.

    While the gate is closed, downstream non-system frames are buffered. A
    background task waits on the notifier. Each time it fires, the gate opens
    and the buffered frames are pushed in order. A ``StartInterruptionFrame``
    clears the buffer and closes the gate again. System frames, upstream
    frames, and ``EndFrame`` are never held back.
    """

    def __init__(self, notifier: BaseNotifier, **kwargs):
        super().__init__(**kwargs)
        self._gate_open = False
        self._frames_buffer = []
        self._notifier = notifier
        # Background task created on StartFrame; None before start / after stop.
        self._gate_task = None

    def close_gate(self):
        """Start buffering downstream frames."""
        self._gate_open = False

    def open_gate(self):
        """Let downstream frames pass straight through."""
        self._gate_open = True

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        # We must not block system frames.
        if isinstance(frame, SystemFrame):
            if isinstance(frame, StartFrame):
                await self._start()
            if isinstance(frame, (EndFrame, CancelFrame)):
                await self._stop()
            if isinstance(frame, StartInterruptionFrame):
                self._frames_buffer = []
                self.close_gate()
            await self.push_frame(frame, direction)
            return

        # EndFrame is a control frame, not a system frame, so it does not reach
        # the branch above. Stop the gate task and let it through even when the
        # gate is closed; otherwise it would sit in the buffer forever. Any
        # output still waiting for the notifier is discarded.
        if isinstance(frame, EndFrame):
            self._frames_buffer = []
            await self._stop()
            await self.push_frame(frame, direction)
            return

        # Ignore frames that are not following the direction of this gate.
        if direction != FrameDirection.DOWNSTREAM:
            await self.push_frame(frame, direction)
            return

        if self._gate_open:
            await self.push_frame(frame, direction)
            return

        self._frames_buffer.append((frame, direction))

    async def _start(self):
        self._frames_buffer = []
        self._gate_task = self.get_event_loop().create_task(self._gate_task_handler())

    async def _stop(self):
        # Safe to call before _start() or more than once.
        task, self._gate_task = self._gate_task, None
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            # Raised when the task is cancelled before its body starts running,
            # so its own CancelledError handler never executes.
            pass

    async def _gate_task_handler(self):
        while True:
            try:
                await self._notifier.wait()
                self.open_gate()
                # Take ownership of the current buffer. Anything buffered later,
                # e.g. after an interruption closes the gate mid-flush, stays in
                # the new buffer instead of being wiped.
                pending, self._frames_buffer = self._frames_buffer, []
                for frame, direction in pending:
                    # A StartInterruptionFrame handled while we awaited closed
                    # the gate: the rest is stale output, so drop it.
                    if not self._gate_open:
                        break
                    await self.push_frame(frame, direction)
            except asyncio.CancelledError:
                break
async def main():
    """Join a Daily room and run the two-LLM natural-conversation pipeline.

    Input frames from the transport are fanned out to three parallel branches:
      1. a passthrough that drops only UserStoppedSpeakingFrame;
      2. the audio turn-completion judge (accumulator -> Gemini -> CompletenessCheck),
         which emits UserStoppedSpeakingFrame and fires the notifier on YES;
      3. STT -> conversational LLM, whose output is held by OutputGate until
         the notifier fires.
    """
    async with aiohttp.ClientSession() as session:
        (room_url, _) = await configure(session)

        transport = DailyTransport(
            room_url,
            None,
            "Respond bot",
            DailyParams(
                audio_out_enabled=True,
                vad_enabled=True,
                vad_analyzer=SileroVADAnalyzer(),
                vad_audio_passthrough=True,
            ),
        )

        stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))

        tts = CartesiaTTSService(
            api_key=os.getenv("CARTESIA_API_KEY"),
            voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
        )

        # This LLM listens to the user's audio and decides whether the user has
        # finished a statement (prompted with `classifier_statement`).
        statement_llm = GoogleLLMService(
            model="gemini-1.5-flash-latest", api_key=os.getenv("GOOGLE_API_KEY")
        )

        # This is the regular LLM.
        llm = GoogleLLMService(model="gemini-1.5-flash-latest", api_key=os.getenv("GOOGLE_API_KEY"))

        messages = [
            {
                "role": "system",
                "content": conversational_system_message,
            },
        ]

        context = OpenAILLMContext(messages)
        context_aggregator = llm.create_context_aggregator(context)

        # This is a notifier that we use to synchronize the two LLMs.
        notifier = EventNotifier()

        # Buffers the user's audio and, when the user stops speaking, turns it
        # into a YES/NO classification request for `statement_llm`.
        statement_judge_context_filter = StatementJudgeAudioContextAccumulator(notifier=notifier)

        # On a YES from `statement_llm`, sends a UserStoppedSpeakingFrame and
        # triggers the notifier.
        completeness_check = CompletenessCheck(
            notifier=notifier, audio_accumulator=statement_judge_context_filter
        )

        # Notify if the user hasn't said anything.
        async def user_idle_notifier(frame):
            await notifier.notify()

        # Sometimes the LLM will fail to detect that a user has completed a
        # sentence; this wakes up the notifier if that happens.
        user_idle = UserIdleProcessor(callback=user_idle_notifier, timeout=5.0)

        bot_output_gate = OutputGate(notifier=notifier)

        async def block_user_stopped_speaking(frame):
            return not isinstance(frame, UserStoppedSpeakingFrame)

        async def pass_only_llm_trigger_frames(frame):
            return isinstance(
                frame,
                (
                    OpenAILLMContextFrame,
                    LLMMessagesFrame,
                    StartInterruptionFrame,
                    StopInterruptionFrame,
                ),
            )

        pipeline = Pipeline(
            [
                transport.input(),
                ParallelPipeline(
                    [
                        # Pass everything except UserStoppedSpeaking to the elements after
                        # this ParallelPipeline
                        FunctionFilter(filter=block_user_stopped_speaking),
                    ],
                    [
                        statement_judge_context_filter,
                        statement_llm,
                        completeness_check,
                    ],
                    [
                        stt,
                        context_aggregator.user(),
                        # Pass only LLM trigger frames (context / messages) and
                        # interruption frames to the conversational LLM.
                        FunctionFilter(filter=pass_only_llm_trigger_frames),
                        llm,
                        bot_output_gate,  # Buffer all llm/tts output until notified.
                    ],
                ),
                tts,
                user_idle,
                transport.output(),
                context_aggregator.assistant(),
            ]
        )

        task = PipelineTask(
            pipeline,
            PipelineParams(
                allow_interruptions=True,
                enable_metrics=True,
                enable_usage_metrics=True,
            ),
        )

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            await transport.capture_participant_transcription(participant["id"])
            # Kick off the conversation.
            await task.queue_frames([context_aggregator.user().get_context_frame()])

        @transport.event_handler("on_app_message")
        async def on_app_message(transport, message, sender):
            logger.debug(f"Received app message: {message} - {sender}")
            if "message" not in message:
                return

            await task.queue_frames(
                [
                    UserStartedSpeakingFrame(),
                    TranscriptionFrame(
                        user_id=sender, timestamp=time.time(), text=message["message"]
                    ),
                    UserStoppedSpeakingFrame(),
                ]
            )

        runner = PipelineRunner()

        await runner.run(task)


if __name__ == "__main__":
    asyncio.run(main())