356 lines
13 KiB
Python
356 lines
13 KiB
Python
#
|
|
# Copyright (c) 2024, Daily
|
|
#
|
|
# SPDX-License-Identifier: BSD 2-Clause License
|
|
#
|
|
|
|
import aiohttp
|
|
import asyncio
|
|
import os
|
|
import sys
|
|
import time
|
|
|
|
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
|
from pipecat.frames.frames import LLMMessagesFrame, TextFrame
|
|
from pipecat.pipeline.pipeline import Pipeline
|
|
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
|
|
from pipecat.pipeline.runner import PipelineRunner
|
|
from pipecat.services.deepgram import DeepgramSTTService
|
|
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
|
from pipecat.processors.aggregators.openai_llm_context import (
|
|
OpenAILLMContext,
|
|
)
|
|
from pipecat.services.cartesia import CartesiaTTSService
|
|
from pipecat.services.google import GoogleLLMService, GoogleLLMContext
|
|
from pipecat.sync.event_notifier import EventNotifier
|
|
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
|
from pipecat.processors.frame_processor import FrameProcessor, FrameDirection
|
|
from pipecat.frames.frames import (
|
|
CancelFrame,
|
|
EndFrame,
|
|
Frame,
|
|
InputAudioRawFrame,
|
|
StartFrame,
|
|
StartInterruptionFrame,
|
|
StopInterruptionFrame,
|
|
SystemFrame,
|
|
TranscriptionFrame,
|
|
UserStartedSpeakingFrame,
|
|
UserStoppedSpeakingFrame,
|
|
)
|
|
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
|
from pipecat.sync.base_notifier import BaseNotifier
|
|
from pipecat.processors.filters.function_filter import FunctionFilter
|
|
from pipecat.processors.user_idle_processor import UserIdleProcessor
|
|
|
|
|
|
from runner import configure
|
|
|
|
from loguru import logger
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
# Load settings (API keys etc.) from a local .env file; override=True lets the
# .env values replace variables that are already set in the process environment.
load_dotenv(override=True)

# Swap loguru's default handler (id 0) for one that writes DEBUG and above to stderr.
logger.remove(handler_id=0)
logger.add(sink=sys.stderr, level="DEBUG")
|
|
|
|
|
|
# System prompt for the turn-completion classifier LLM. It is sent together with
# the user's buffered audio; a response starting with 'YES' is treated by
# CompletenessCheck as "the user has finished their turn".
classifier_statement = """You are an audio language classifier model. You are receiving audio from a user in a WebRTC call. Your job is to decide whether the user has finished speaking or not.

Categorize the input you receive as either:

1. a complete thought, statement, or question, or
2. an incomplete thought, statement, or question

Output 'YES' if the input is likely to be a completed thought, statement, or question.

Output 'NO' if the input indicates that the user is still speaking and does not yet expect a response.

If you are unsure, output 'YES'.
"""
|
|
|
|
# System prompt for the conversational LLM that actually answers the user.
conversational_system_message = """You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.

Please be very concise in your responses. Unless you are explicitly asked to do otherwise, give me the shortest complete answer possible without unnecessary elaboration. Generally you should answer with a single sentence.
"""
|
|
|
|
|
|
class StatementJudgeAudioContextAccumulator(FrameProcessor):
    """Buffer the user's speech audio and turn it into a classifier request.

    While the user is not speaking, only a short rolling window of recent audio
    (``start_secs`` long) is kept, so the audio that precedes the VAD's
    UserStartedSpeakingFrame can be backfilled into the utterance. While the
    user is speaking, every audio frame is kept. On UserStoppedSpeakingFrame the
    buffered audio is wrapped in a GoogleLLMContext with the classifier system
    prompt and pushed downstream as an OpenAILLMContextFrame.

    Incoming OpenAILLMContextFrame and TranscriptionFrame are dropped; all other
    frames are forwarded unchanged.
    """

    def __init__(self, *, notifier: BaseNotifier, **kwargs):
        super().__init__(**kwargs)
        # Stored for parity with the other processors; not used by this class.
        self._notifier = notifier
        # Buffered InputAudioRawFrames, oldest first.
        self._audio_frames = []
        # Length of the rolling pre-speech window in seconds. This should match
        # the VAD start_secs (hardcoding for now).
        self._start_secs = 0.2
        # True between UserStartedSpeakingFrame and UserStoppedSpeakingFrame.
        self._user_speaking = False

    async def reset(self):
        """Discard all buffered audio and mark the user as not speaking."""
        self._audio_frames = []
        self._user_speaking = False

    @staticmethod
    def _frame_duration_secs(frame: InputAudioRawFrame) -> float:
        """Return the duration of ``frame`` in seconds.

        Assumes interleaved 16-bit (2-byte) PCM samples. Returns 0.0 for a frame
        with no channels or no sample rate instead of dividing by zero.
        """
        bytes_per_second = 2 * frame.num_channels * frame.sample_rate
        if bytes_per_second <= 0:
            return 0.0
        return len(frame.audio) / bytes_per_second

    def _trim_to_start_window(self, frame: InputAudioRawFrame):
        """Drop the oldest frames until the buffer spans at most ``start_secs``.

        Treats the buffer as a ring buffer and assumes every buffered frame has
        the same duration as ``frame``.
        """
        frame_duration = self._frame_duration_secs(frame)
        buffer_duration = frame_duration * len(self._audio_frames)
        while buffer_duration > self._start_secs:
            self._audio_frames.pop(0)
            buffer_duration -= frame_duration

    def _build_classifier_frame(self) -> OpenAILLMContextFrame:
        """Wrap the buffered audio in a classifier request for the judge LLM."""
        context = GoogleLLMContext()
        context.set_messages([{"role": "system", "content": classifier_statement}])
        context.add_audio_frames_message(audio_frames=self._audio_frames)
        return OpenAILLMContextFrame(context=context)

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        # Drop context frames coming from upstream (e.g. the one queued when the
        # first participant joins) so the judge LLM only sees requests built here.
        if isinstance(frame, OpenAILLMContextFrame):
            return

        if isinstance(frame, TranscriptionFrame):
            # We could gracefully handle both audio input and text/transcription input ...
            # but let's leave that as an exercise to the reader. :-)
            return

        if isinstance(frame, UserStartedSpeakingFrame):
            self._user_speaking = True
        elif isinstance(frame, UserStoppedSpeakingFrame):
            self._user_speaking = False
            await self.push_frame(self._build_classifier_frame())
        elif isinstance(frame, InputAudioRawFrame):
            self._audio_frames.append(frame)
            if not self._user_speaking:
                # NOTE(review): after a NO verdict nothing calls reset(), but the
                # first silent frame trims the buffer back to the rolling window,
                # so the next request will not include the earlier fragment.
                # Confirm whether that is intended.
                self._trim_to_start_window(frame)

        await self.push_frame(frame, direction)
|
|
|
|
|
class CompletenessCheck(FrameProcessor):
    """Turn the judge LLM's text verdict into a turn-completion signal.

    When a TextFrame's text starts with "YES" (ignoring surrounding whitespace
    and letter case), this processor does three things:
    - pushes a UserStoppedSpeakingFrame;
    - resets the audio accumulator;
    - notifies ``notifier``. In main() this is the same notifier that releases
      OutputGate.

    Other non-empty text is only logged. No incoming frame is forwarded; in
    main() the other ParallelPipeline branches carry the remaining frames.
    """

    def __init__(
        self,
        notifier: BaseNotifier,
        audio_accumulator: StatementJudgeAudioContextAccumulator,
        **kwargs,
    ):
        # **kwargs is forwarded like the other processors in this file do
        # (e.g. a processor name); existing callers are unaffected.
        super().__init__(**kwargs)
        self._notifier = notifier
        self._audio_accumulator = audio_accumulator

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        if not isinstance(frame, TextFrame):
            return

        # Normalise before matching: a leading space or "Yes" would otherwise
        # be treated as NO and the gate would stay shut until the idle timeout.
        verdict = frame.text.strip().upper()
        if verdict.startswith("YES"):
            logger.debug("Completeness check YES")
            await self.push_frame(UserStoppedSpeakingFrame())
            await self._audio_accumulator.reset()
            await self._notifier.notify()
        elif verdict:
            logger.debug(f"Completeness check NO - '{frame.text}'")
|
|
|
|
|
|
class OutputGate(FrameProcessor):
    """Hold back downstream output until the notifier fires.

    While the gate is closed, downstream non-system frames are buffered. When
    ``notifier`` is triggered, the gate opens and the buffered frames are pushed
    in order; later frames pass straight through.

    A StartInterruptionFrame discards the buffer and closes the gate again.
    System frames and upstream frames are never held. StartFrame starts the
    background gate task. EndFrame and CancelFrame stop it and are forwarded
    immediately; frames still buffered at that point are not flushed.
    """

    def __init__(self, notifier: BaseNotifier, **kwargs):
        super().__init__(**kwargs)
        self._gate_open = False
        self._frames_buffer = []
        self._notifier = notifier
        # Background task waiting on the notifier; created on StartFrame.
        self._gate_task = None
        # Serialises the flush with pass-through pushes, so frames arriving
        # while the buffer is being flushed cannot overtake buffered frames.
        self._push_lock = asyncio.Lock()

    def close_gate(self):
        self._gate_open = False

    def open_gate(self):
        self._gate_open = True

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        # Lifecycle frames are handled before the SystemFrame check. EndFrame is
        # a ControlFrame in pipecat, so a check nested under SystemFrame never
        # saw it, and the EndFrame could end up stuck in the buffer.
        if isinstance(frame, StartFrame):
            await self._start()
            await self.push_frame(frame, direction)
            return

        if isinstance(frame, (EndFrame, CancelFrame)):
            await self._stop()
            await self.push_frame(frame, direction)
            return

        # We must not block system frames.
        if isinstance(frame, SystemFrame):
            if isinstance(frame, StartInterruptionFrame):
                self._frames_buffer = []
                self.close_gate()
            await self.push_frame(frame, direction)
            return

        # Ignore frames that are not following the direction of this gate.
        if direction != FrameDirection.DOWNSTREAM:
            await self.push_frame(frame, direction)
            return

        async with self._push_lock:
            if self._gate_open:
                await self.push_frame(frame, direction)
            else:
                self._frames_buffer.append((frame, direction))

    async def _start(self):
        self._frames_buffer = []
        # Guard against a repeated StartFrame leaking a second gate task.
        if self._gate_task is None:
            self._gate_task = self.get_event_loop().create_task(self._gate_task_handler())

    async def _stop(self):
        # Safe to call more than once, or without a prior StartFrame.
        task, self._gate_task = self._gate_task, None
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            # Raised when the task was cancelled before it ever started running.
            pass

    async def _release_buffered_frames(self):
        """Open the gate and push everything buffered so far, in order."""
        async with self._push_lock:
            self.open_gate()
            # Detach the buffer first, so frames buffered after a mid-flush
            # interruption land in a fresh list and are not wiped by this flush.
            frames, self._frames_buffer = self._frames_buffer, []
            for frame, direction in frames:
                if not self._gate_open:
                    # An interruption closed the gate mid-flush; drop the rest,
                    # just as the interruption dropped the buffer.
                    break
                await self.push_frame(frame, direction)

    async def _gate_task_handler(self):
        while True:
            try:
                await self._notifier.wait()
                await self._release_buffered_frames()
            except asyncio.CancelledError:
                break
|
|
|
|
|
|
async def main():
    """Run the natural-conversation bot in a Daily room until the pipeline ends.

    Frames from the transport fan out into three parallel branches:

    1. a pass-through that forwards everything except UserStoppedSpeakingFrame;
    2. the judge branch (audio accumulator -> Gemini -> CompletenessCheck), which
       emits UserStoppedSpeakingFrame and triggers the notifier when the user's
       audio is classified as a complete turn;
    3. the conversation branch (STT -> user context -> Gemini -> OutputGate),
       whose output is buffered by the gate until the notifier fires.

    A UserIdleProcessor (5-second timeout) also triggers the notifier, as a
    fallback for when the judge never answers YES.
    """
    async with aiohttp.ClientSession() as session:
        (room_url, _) = await configure(session)

        transport = DailyTransport(
            room_url,
            None,
            "Respond bot",
            DailyParams(
                audio_out_enabled=True,
                vad_enabled=True,
                vad_analyzer=SileroVADAnalyzer(),
                vad_audio_passthrough=True,
            ),
        )

        stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))

        tts = CartesiaTTSService(
            api_key=os.getenv("CARTESIA_API_KEY"),
            voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
        )

        # This LLM judges, from the user's buffered audio, whether they have
        # finished a statement.
        statement_llm = GoogleLLMService(
            model="gemini-1.5-flash-latest", api_key=os.getenv("GOOGLE_API_KEY")
        )

        # This is the regular LLM.
        llm = GoogleLLMService(model="gemini-1.5-flash-latest", api_key=os.getenv("GOOGLE_API_KEY"))

        messages = [
            {
                "role": "system",
                "content": conversational_system_message,
            },
        ]

        context = OpenAILLMContext(messages)
        context_aggregator = llm.create_context_aggregator(context)

        # This is a notifier that we use to synchronize the two LLMs: it is
        # triggered by CompletenessCheck (or the idle fallback) and releases
        # OutputGate.
        notifier = EventNotifier()

        # This turns the user's buffered audio into an inference request that
        # classifies the speech as complete or incomplete.
        statement_judge_context_filter = StatementJudgeAudioContextAccumulator(notifier=notifier)

        # This sends a UserStoppedSpeakingFrame and triggers the notifier event.
        completeness_check = CompletenessCheck(
            notifier=notifier, audio_accumulator=statement_judge_context_filter
        )

        # Notify if the user hasn't said anything.
        async def user_idle_notifier(frame):
            await notifier.notify()

        # Sometimes the LLM will fail detecting if a user has completed a
        # sentence, this will wake up the notifier if that happens.
        user_idle = UserIdleProcessor(callback=user_idle_notifier, timeout=5.0)

        bot_output_gate = OutputGate(notifier=notifier)

        async def block_user_stopped_speaking(frame):
            return not isinstance(frame, UserStoppedSpeakingFrame)

        async def pass_only_llm_trigger_frames(frame):
            return isinstance(
                frame,
                (
                    OpenAILLMContextFrame,
                    LLMMessagesFrame,
                    StartInterruptionFrame,
                    StopInterruptionFrame,
                ),
            )

        pipeline = Pipeline(
            [
                transport.input(),
                ParallelPipeline(
                    [
                        # Pass everything except UserStoppedSpeaking to the elements
                        # after this ParallelPipeline; that frame comes from
                        # CompletenessCheck instead.
                        FunctionFilter(filter=block_user_stopped_speaking),
                    ],
                    [
                        statement_judge_context_filter,
                        statement_llm,
                        completeness_check,
                    ],
                    [
                        stt,
                        context_aggregator.user(),
                        # Only let LLM triggers (context / messages frames) and
                        # interruption frames through to the LLM.
                        FunctionFilter(filter=pass_only_llm_trigger_frames),
                        llm,
                        bot_output_gate,  # Buffer all llm/tts output until notified.
                    ],
                ),
                tts,
                user_idle,
                transport.output(),
                context_aggregator.assistant(),
            ]
        )

        task = PipelineTask(
            pipeline,
            PipelineParams(
                allow_interruptions=True,
                enable_metrics=True,
                enable_usage_metrics=True,
            ),
        )

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            await transport.capture_participant_transcription(participant["id"])
            # Kick off the conversation.
            await task.queue_frames([context_aggregator.user().get_context_frame()])

        @transport.event_handler("on_app_message")
        async def on_app_message(transport, message, sender):
            logger.debug(f"Received app message: {message} - {sender}")
            # App messages are arbitrary payloads. For a string, `"message" in
            # message` is a substring test and the indexing below would raise.
            if not isinstance(message, dict) or "message" not in message:
                return

            await task.queue_frames(
                [
                    UserStartedSpeakingFrame(),
                    TranscriptionFrame(
                        user_id=sender, timestamp=time.time(), text=message["message"]
                    ),
                    UserStoppedSpeakingFrame(),
                ]
            )

        runner = PipelineRunner()
        await runner.run(task)


if __name__ == "__main__":
    asyncio.run(main())
|