services(azure): cancel tasks when interrupted and ignore incoming transcriptions

This commit is contained in:
Aleix Conchillo Flaqué
2024-06-25 10:54:56 -07:00
parent 64198313c6
commit 38aee7d8f2
2 changed files with 44 additions and 3 deletions

View File

@@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Fixed an issue with asynchronous STT services (Deepgram and Azure) that could
cause static audio issues and interruptions to not work properly when dealing
with multiple LLM sentences.
- Fixed an issue that could mix new LLM responses with previous ones when
handling interruptions.

View File

@@ -12,7 +12,18 @@ import time
from PIL import Image
from typing import AsyncGenerator
from pipecat.frames.frames import AudioRawFrame, CancelFrame, EndFrame, ErrorFrame, Frame, StartFrame, SystemFrame, TranscriptionFrame, URLImageRawFrame
from pipecat.frames.frames import (
AudioRawFrame,
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
StartFrame,
StartInterruptionFrame,
StopInterruptionFrame,
SystemFrame,
TranscriptionFrame,
URLImageRawFrame)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import AIService, TTSService, ImageGenService
from pipecat.services.openai import BaseOpenAILLMService
@@ -34,7 +45,7 @@ try:
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use Azure TTS, you need to `pip install pipecat-ai[azure]`. Also, set `AZURE_SPEECH_API_KEY` and `AZURE_SPEECH_REGION` environment variables.")
"In order to use Azure, you need to `pip install pipecat-ai[azure]`. Also, set `AZURE_SPEECH_API_KEY` and `AZURE_SPEECH_REGION` environment variables.")
raise Exception(f"Missing module: {e}")
@@ -123,12 +134,18 @@ class AzureSTTService(AIService):
speech_config=speech_config, audio_config=audio_config)
self._speech_recognizer.recognized.connect(self._on_handle_recognized)
# This event will be used to ignore out-of-band transcriptions while we
# are interrupted.
self._is_interrupted_event = asyncio.Event()
self._create_push_task()
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, SystemFrame):
if isinstance(frame, StartInterruptionFrame) or isinstance(frame, StopInterruptionFrame):
await self._handle_interruptions(frame)
elif isinstance(frame, SystemFrame):
await self.push_frame(frame, direction)
elif isinstance(frame, AudioRawFrame):
self._audio_stream.write(frame.audio)
@@ -148,6 +165,23 @@ class AzureSTTService(AIService):
self._push_frame_task.cancel()
await self._push_frame_task
async def _handle_interruptions(self, frame: Frame):
    """React to the start or the end of an interruption.

    On a ``StartInterruptionFrame`` the interrupted flag is raised, the
    current push task is cancelled and awaited, the frame itself is pushed
    directly, and a new queue and push task are created. On a
    ``StopInterruptionFrame`` the interrupted flag is lowered again. Any
    other frame type is ignored.

    Args:
        frame: The interruption frame to handle.
    """
    if not isinstance(frame, StartInterruptionFrame):
        if isinstance(frame, StopInterruptionFrame):
            # Interruption is over: transcriptions may be accepted again.
            # NOTE(review): the stop frame is not pushed by this method --
            # confirm it is forwarded elsewhere if downstream needs it.
            self._is_interrupted_event.clear()
        return

    # Raise the flag first so that out-of-band transcriptions arriving
    # while we tear the task down are dropped.
    self._is_interrupted_event.set()

    # Stop the ordered push task so nothing else is sent downstream, and
    # wait until it has actually finished.
    pending_task = self._push_frame_task
    pending_task.cancel()
    await pending_task

    # The interruption frame bypasses the ordered push queue and is pushed
    # directly.
    await self.push_frame(frame)

    # Start over with a fresh queue and task.
    self._create_push_task()
def _create_push_task(self):
self._push_queue = asyncio.Queue()
self._push_frame_task = self.get_event_loop().create_task(self._push_frame_task_handler())
@@ -163,6 +197,9 @@ class AzureSTTService(AIService):
break
def _on_handle_recognized(self, event):
if self._is_interrupted_event.is_set():
return
if event.result.reason == ResultReason.RecognizedSpeech and len(event.result.text) > 0:
direction = FrameDirection.DOWNSTREAM
frame = TranscriptionFrame(event.result.text, "", int(time.time_ns() / 1000000))