diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index 2b0485fdd..bee15ceab 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -8,7 +8,7 @@ import asyncio import io import wave from abc import abstractmethod -from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple +from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Tuple from loguru import logger @@ -69,7 +69,7 @@ class AIService(FrameProcessor): async def cancel(self, frame: CancelFrame): pass - async def _update_settings(self, settings: Dict[str, Any]): + async def _update_settings(self, settings: Mapping[str, Any]): from pipecat.services.openai_realtime_beta.events import ( SessionProperties, ) @@ -267,7 +267,7 @@ class TTSService(AIService): await self.cancel_task(self._stop_frame_task) self._stop_frame_task = None - async def _update_settings(self, settings: Dict[str, Any]): + async def _update_settings(self, settings: Mapping[str, Any]): for key, value in settings.items(): if key in self._settings: logger.info(f"Updating TTS setting {key} to: [{value}]") @@ -468,7 +468,7 @@ class STTService(AIService): """Returns transcript as a string""" pass - async def _update_settings(self, settings: Dict[str, Any]): + async def _update_settings(self, settings: Mapping[str, Any]): logger.info(f"Updating STT settings: {self._settings}") for key, value in settings.items(): if key in self._settings: diff --git a/src/pipecat/services/gemini_multimodal_live/gemini.py b/src/pipecat/services/gemini_multimodal_live/gemini.py index ef509e4ae..c2fc103b7 100644 --- a/src/pipecat/services/gemini_multimodal_live/gemini.py +++ b/src/pipecat/services/gemini_multimodal_live/gemini.py @@ -51,6 +51,7 @@ from pipecat.services.openai import ( OpenAIUserContextAggregator, ) from pipecat.utils.time import time_now_iso8601 +from pipecat.utils.utils import wait_for_task from . import events from .audio_transcriber import AudioTranscriber @@ -182,9 +183,13 @@ class GeminiMultimodalLiveLLMService(LLMService): self._audio_input_paused = start_audio_paused self._video_input_paused = start_video_paused + self._context = None self._websocket = None self._receive_task = None - self._context = None + self._transcribe_audio_task = None + self._transcribe_model_audio_task = None + self._transcribe_audio_queue = asyncio.Queue() + self._transcribe_model_audio_queue = asyncio.Queue() self._disconnecting = False self._api_session_ready = False @@ -275,7 +280,7 @@ class GeminiMultimodalLiveLLMService(LLMService): ) await self.send_client_event(evt) if self._transcribe_user_audio and self._context: - self.create_task(self._handle_transcribe_user_audio(audio, self._context)) + await self._transcribe_audio_queue.put(audio) async def _handle_transcribe_user_audio(self, audio, context): text = await self._transcribe_audio(audio, context) @@ -392,6 +397,10 @@ class GeminiMultimodalLiveLLMService(LLMService): logger.info(f"Connecting to {uri}") self._websocket = await websockets.connect(uri=uri) self._receive_task = self.create_task(self._receive_task_handler()) + self._transcribe_audio_task = self.create_task(self._transcribe_audio_handler()) + self._transcribe_model_audio_task = self.create_task( + self._transcribe_model_audio_handler() + ) config = events.Config.model_validate( { "setup": { @@ -443,6 +452,12 @@ class GeminiMultimodalLiveLLMService(LLMService): if self._receive_task: await self.cancel_task(self._receive_task, timeout=1.0) self._receive_task = None + if self._transcribe_audio_task: + await self.cancel_task(self._transcribe_audio_task) + self._transcribe_audio_task = None + if self._transcribe_model_audio_task: + await self.cancel_task(self._transcribe_model_audio_task) + self._transcribe_model_audio_task = None self._disconnecting = False except Exception as e: logger.error(f"{self} error disconnecting: {e}") @@ -469,33 +484,35 @@ class GeminiMultimodalLiveLLMService(LLMService): # async def _receive_task_handler(self): - try: - async for message in self._websocket: - evt = events.parse_server_event(message) - # logger.debug(f"Received event: {message[:500]}") - # logger.debug(f"Received event: {evt}") + async for message in self._websocket: + evt = events.parse_server_event(message) + # logger.debug(f"Received event: {message[:500]}") + # logger.debug(f"Received event: {evt}") - if evt.setupComplete: - await self._handle_evt_setup_complete(evt) - elif evt.serverContent and evt.serverContent.modelTurn: - await self._handle_evt_model_turn(evt) - elif evt.serverContent and evt.serverContent.turnComplete: - await self._handle_evt_turn_complete(evt) - elif evt.toolCall: - await self._handle_evt_tool_call(evt) + if evt.setupComplete: + await self._handle_evt_setup_complete(evt) + elif evt.serverContent and evt.serverContent.modelTurn: + await self._handle_evt_model_turn(evt) + elif evt.serverContent and evt.serverContent.turnComplete: + await self._handle_evt_turn_complete(evt) + elif evt.toolCall: + await self._handle_evt_tool_call(evt) + elif False: # !!! todo: error events? + await self._handle_evt_error(evt) + # errors are fatal, so exit the receive loop + return + else: + pass - elif False: # !!! todo: error events? - await self._handle_evt_error(evt) - # errors are fatal, so exit the receive loop - return + async def _transcribe_audio_handler(self): + while True: + audio = await self._transcribe_audio_queue.get() + await self._handle_transcribe_user_audio(audio, self._context) - else: - pass - except asyncio.CancelledError: - logger.debug("websocket receive task cancelled") - raise - except Exception as e: - logger.error(f"{self} exception: {e}") + async def _transcribe_model_audio_handler(self): + while True: + audio = await self._transcribe_model_audio_queue.get() + await self._handle_transcribe_model_audio(audio, self._context) # # @@ -676,7 +693,7 @@ class GeminiMultimodalLiveLLMService(LLMService): self._bot_text_buffer = "" if audio and self._transcribe_model_audio and self._context: - self.create_task(self._handle_transcribe_model_audio(audio, self._context)) + await self._transcribe_model_audio.put(audio) elif text: await self.push_frame(LLMFullResponseEndFrame())