diff --git a/CHANGELOG.md b/CHANGELOG.md index 77c66e937..d8696008a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed an issue with asynchronous STT services (Deepgram and Azure) that could + cause static audio issues and interruptions to not work properly when dealing + with multiple LLMs sentences. + - Fixed an issue that could mix new LLM responses with previous ones when handling interruptions. diff --git a/src/pipecat/services/azure.py b/src/pipecat/services/azure.py index 85ac49ae3..e8e2acd34 100644 --- a/src/pipecat/services/azure.py +++ b/src/pipecat/services/azure.py @@ -12,7 +12,18 @@ import time from PIL import Image from typing import AsyncGenerator -from pipecat.frames.frames import AudioRawFrame, CancelFrame, EndFrame, ErrorFrame, Frame, StartFrame, SystemFrame, TranscriptionFrame, URLImageRawFrame +from pipecat.frames.frames import ( + AudioRawFrame, + CancelFrame, + EndFrame, + ErrorFrame, + Frame, + StartFrame, + StartInterruptionFrame, + StopInterruptionFrame, + SystemFrame, + TranscriptionFrame, + URLImageRawFrame) from pipecat.processors.frame_processor import FrameDirection from pipecat.services.ai_services import AIService, TTSService, ImageGenService from pipecat.services.openai import BaseOpenAILLMService @@ -34,7 +45,7 @@ try: except ModuleNotFoundError as e: logger.error(f"Exception: {e}") logger.error( - "In order to use Azure TTS, you need to `pip install pipecat-ai[azure]`. Also, set `AZURE_SPEECH_API_KEY` and `AZURE_SPEECH_REGION` environment variables.") + "In order to use Azure, you need to `pip install pipecat-ai[azure]`. Also, set `AZURE_SPEECH_API_KEY` and `AZURE_SPEECH_REGION` environment variables.") raise Exception(f"Missing module: {e}") @@ -123,12 +134,18 @@ class AzureSTTService(AIService): speech_config=speech_config, audio_config=audio_config) self._speech_recognizer.recognized.connect(self._on_handle_recognized) + # This event will be used to ignore out-of-band transcriptions while we + # are itnerrupted. + self._is_interrupted_event = asyncio.Event() + self._create_push_task() async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) - if isinstance(frame, SystemFrame): + if isinstance(frame, StartInterruptionFrame) or isinstance(frame, StopInterruptionFrame): + await self._handle_interruptions(frame) + elif isinstance(frame, SystemFrame): await self.push_frame(frame, direction) elif isinstance(frame, AudioRawFrame): self._audio_stream.write(frame.audio) @@ -148,6 +165,23 @@ class AzureSTTService(AIService): self._push_frame_task.cancel() await self._push_frame_task + async def _handle_interruptions(self, frame: Frame): + if isinstance(frame, StartInterruptionFrame): + # Indicate we are interrupted, we should ignore any out-of-band + # transcriptions. + self._is_interrupted_event.set() + # Cancel the task. This will stop pushing frames downstream. + self._push_frame_task.cancel() + await self._push_frame_task + # Push an out-of-band frame (i.e. not using the ordered push + # frame task). + await self.push_frame(frame) + # Create a new queue and task. + self._create_push_task() + elif isinstance(frame, StopInterruptionFrame): + # We should now be able to receive transcriptions again. + self._is_interrupted_event.clear() + def _create_push_task(self): self._push_queue = asyncio.Queue() self._push_frame_task = self.get_event_loop().create_task(self._push_frame_task_handler()) @@ -163,6 +197,9 @@ class AzureSTTService(AIService): break def _on_handle_recognized(self, event): + if self._is_interrupted_event.is_set(): + return + if event.result.reason == ResultReason.RecognizedSpeech and len(event.result.text) > 0: direction = FrameDirection.DOWNSTREAM frame = TranscriptionFrame(event.result.text, "", int(time.time_ns() / 1000000))