Merge pull request #603 from pipecat-ai/aleix/silero-vad-processor-fixes

vad: add support for interruption to SileroVAD processor
This commit is contained in:
Aleix Conchillo Flaqué
2024-10-17 10:48:39 -07:00
committed by GitHub
3 changed files with 142 additions and 2 deletions

View File

@@ -12,6 +12,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Renamed `OpenAILLMServiceRealtimeBeta` to `OpenAIRealtimeBetaLLMService` to
match other services.
### Fixed
- Fixed `SileroVAD` processor to support interruptions properly.
### Other
- Added `examples/foundational/07-interruptible-vad.py`. This is the same as
`07-interruptible.py` but uses the `SileroVAD` processor instead of passing
the `VADAnalyzer` to the transport.
## [0.0.45] - 2024-10-16
### Changed

View File

@@ -0,0 +1,106 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import aiohttp
import os
import sys
from pipecat.frames.frames import LLMMessagesFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_response import (
LLMAssistantResponseAggregator,
LLMUserResponseAggregator,
)
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.openai import OpenAILLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport
from pipecat.vad.silero import SileroVAD
from runner import configure
from loguru import logger
from dotenv import load_dotenv
# Load settings from a .env file before anything reads the environment
# (main() reads CARTESIA_API_KEY and OPENAI_API_KEY via os.getenv()).
# NOTE(review): override=True presumably lets .env values take precedence over
# variables already set in the process environment -- confirm against the
# python-dotenv documentation.
load_dotenv(override=True)
# Remove the logger handler with id 0 (presumably loguru's default sink --
# TODO confirm) and log to stderr at DEBUG level instead.
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
async def main():
    """Run the interruptible bot example using the ``SileroVAD`` processor.

    Obtains a room URL and token from ``configure``, builds a pipeline of
    transport input -> VAD -> user aggregator -> LLM -> TTS -> transport
    output -> assistant aggregator, and runs it with interruptions allowed.
    The conversation is kicked off when the first participant joins.
    """
    async with aiohttp.ClientSession() as session:
        room_url, token = await configure(session)

        daily_params = DailyParams(
            audio_in_enabled=True,
            audio_out_enabled=True,
            transcription_enabled=True,
        )
        transport = DailyTransport(room_url, token, "Respond bot", daily_params)

        # The VAD is placed in the pipeline as its own processor.
        vad = SileroVAD()

        tts = CartesiaTTSService(
            api_key=os.getenv("CARTESIA_API_KEY"),
            voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
        )

        llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")

        # Conversation history shared by both aggregators and by the
        # first-participant handler below.
        system_prompt = {
            "role": "system",
            "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
        }
        messages = [system_prompt]

        user_aggregator = LLMUserResponseAggregator(messages)
        assistant_aggregator = LLMAssistantResponseAggregator(messages)

        processors = [
            transport.input(),  # Transport user input
            vad,  # Voice activity detection
            user_aggregator,  # User responses
            llm,  # LLM
            tts,  # TTS
            transport.output(),  # Transport bot output
            assistant_aggregator,  # Assistant spoken responses
        ]
        pipeline = Pipeline(processors)

        task_params = PipelineParams(
            allow_interruptions=True,
            enable_metrics=True,
            enable_usage_metrics=True,
            report_only_initial_ttfb=True,
        )
        task = PipelineTask(pipeline, task_params)

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            transport.capture_participant_transcription(participant["id"])
            # Kick off the conversation by asking the LLM to introduce itself.
            intro_request = {"role": "system", "content": "Please introduce yourself to the user."}
            messages.append(intro_request)
            await task.queue_frames([LLMMessagesFrame(messages)])

        runner = PipelineRunner()
        await runner.run(task)
if __name__ == "__main__":
    # Script entry point: drive the async example to completion.
    asyncio.run(main())

View File

@@ -11,6 +11,8 @@ import numpy as np
from pipecat.frames.frames import (
AudioRawFrame,
Frame,
StartInterruptionFrame,
StopInterruptionFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
@@ -200,6 +202,27 @@ class SileroVAD(FrameProcessor):
else:
await self.push_frame(frame, direction)
#
# Handle interruptions
#
async def _handle_interruptions(self, frame: Frame):
    """Emit interruption signals for a speaking-state frame, then forward it.

    When ``self.interruptions_allowed`` is true:

    * a ``UserStartedSpeakingFrame`` calls ``_start_interruption()`` and then
      pushes a ``StartInterruptionFrame``;
    * a ``UserStoppedSpeakingFrame`` calls ``_stop_interruption()`` and then
      pushes a ``StopInterruptionFrame``.

    The incoming ``frame`` is always pushed last, whether or not
    interruptions are allowed.
    """
    # Read the flag once; both branches below depend on it.
    can_interrupt = self.interruptions_allowed

    # Make sure we notify about interruptions quickly out-of-band.
    if can_interrupt and isinstance(frame, UserStartedSpeakingFrame):
        logger.debug("User started speaking")
        await self._start_interruption()
        # Out-of-band push (i.e. not through the ordered push frame task)
        # so everything stops, especially at the output transport.
        await self.push_frame(StartInterruptionFrame())
    elif can_interrupt and isinstance(frame, UserStoppedSpeakingFrame):
        logger.debug("User stopped speaking")
        await self._stop_interruption()
        await self.push_frame(StopInterruptionFrame())

    # Forward the original speaking-state frame after any interruption frame.
    await self.push_frame(frame)
async def _analyze_audio(self, frame: AudioRawFrame):
# Check VAD and push event if necessary. We just care about changes
# from QUIET to SPEAKING and vice versa.
@@ -217,5 +240,6 @@ class SileroVAD(FrameProcessor):
new_frame = UserStoppedSpeakingFrame()
if new_frame:
await self.push_frame(new_frame)
self._processor_vad_state = new_vad_state
await self._handle_interruptions(new_frame)
self._processor_vad_state = new_vad_state