gemini: create transcribe tasks only once

This commit is contained in:
Aleix Conchillo Flaqué
2025-01-26 15:07:14 -08:00
parent a3a6adbd17
commit 2a2928d96c
2 changed files with 48 additions and 31 deletions

View File

@@ -8,7 +8,7 @@ import asyncio
import io
import wave
from abc import abstractmethod
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Tuple
from loguru import logger
@@ -69,7 +69,7 @@ class AIService(FrameProcessor):
async def cancel(self, frame: CancelFrame):
pass
async def _update_settings(self, settings: Dict[str, Any]):
async def _update_settings(self, settings: Mapping[str, Any]):
from pipecat.services.openai_realtime_beta.events import (
SessionProperties,
)
@@ -267,7 +267,7 @@ class TTSService(AIService):
await self.cancel_task(self._stop_frame_task)
self._stop_frame_task = None
async def _update_settings(self, settings: Dict[str, Any]):
async def _update_settings(self, settings: Mapping[str, Any]):
for key, value in settings.items():
if key in self._settings:
logger.info(f"Updating TTS setting {key} to: [{value}]")
@@ -468,7 +468,7 @@ class STTService(AIService):
"""Returns transcript as a string"""
pass
async def _update_settings(self, settings: Dict[str, Any]):
async def _update_settings(self, settings: Mapping[str, Any]):
logger.info(f"Updating STT settings: {self._settings}")
for key, value in settings.items():
if key in self._settings:

View File

@@ -51,6 +51,7 @@ from pipecat.services.openai import (
OpenAIUserContextAggregator,
)
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.utils import wait_for_task
from . import events
from .audio_transcriber import AudioTranscriber
@@ -182,9 +183,13 @@ class GeminiMultimodalLiveLLMService(LLMService):
self._audio_input_paused = start_audio_paused
self._video_input_paused = start_video_paused
self._context = None
self._websocket = None
self._receive_task = None
self._context = None
self._transcribe_audio_task = None
self._transcribe_model_audio_task = None
self._transcribe_audio_queue = asyncio.Queue()
self._transcribe_model_audio_queue = asyncio.Queue()
self._disconnecting = False
self._api_session_ready = False
@@ -275,7 +280,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
)
await self.send_client_event(evt)
if self._transcribe_user_audio and self._context:
self.create_task(self._handle_transcribe_user_audio(audio, self._context))
await self._transcribe_audio_queue.put(audio)
async def _handle_transcribe_user_audio(self, audio, context):
text = await self._transcribe_audio(audio, context)
@@ -392,6 +397,10 @@ class GeminiMultimodalLiveLLMService(LLMService):
logger.info(f"Connecting to {uri}")
self._websocket = await websockets.connect(uri=uri)
self._receive_task = self.create_task(self._receive_task_handler())
self._transcribe_audio_task = self.create_task(self._transcribe_audio_handler())
self._transcribe_model_audio_task = self.create_task(
self._transcribe_model_audio_handler()
)
config = events.Config.model_validate(
{
"setup": {
@@ -443,6 +452,12 @@ class GeminiMultimodalLiveLLMService(LLMService):
if self._receive_task:
await self.cancel_task(self._receive_task, timeout=1.0)
self._receive_task = None
if self._transcribe_audio_task:
await self.cancel_task(self._transcribe_audio_task)
self._transcribe_audio_task = None
if self._transcribe_model_audio_task:
await self.cancel_task(self._transcribe_model_audio_task)
self._transcribe_model_audio_task = None
self._disconnecting = False
except Exception as e:
logger.error(f"{self} error disconnecting: {e}")
@@ -469,33 +484,35 @@ class GeminiMultimodalLiveLLMService(LLMService):
#
async def _receive_task_handler(self):
try:
async for message in self._websocket:
evt = events.parse_server_event(message)
# logger.debug(f"Received event: {message[:500]}")
# logger.debug(f"Received event: {evt}")
async for message in self._websocket:
evt = events.parse_server_event(message)
# logger.debug(f"Received event: {message[:500]}")
# logger.debug(f"Received event: {evt}")
if evt.setupComplete:
await self._handle_evt_setup_complete(evt)
elif evt.serverContent and evt.serverContent.modelTurn:
await self._handle_evt_model_turn(evt)
elif evt.serverContent and evt.serverContent.turnComplete:
await self._handle_evt_turn_complete(evt)
elif evt.toolCall:
await self._handle_evt_tool_call(evt)
if evt.setupComplete:
await self._handle_evt_setup_complete(evt)
elif evt.serverContent and evt.serverContent.modelTurn:
await self._handle_evt_model_turn(evt)
elif evt.serverContent and evt.serverContent.turnComplete:
await self._handle_evt_turn_complete(evt)
elif evt.toolCall:
await self._handle_evt_tool_call(evt)
elif False: # !!! todo: error events?
await self._handle_evt_error(evt)
# errors are fatal, so exit the receive loop
return
else:
pass
elif False: # !!! todo: error events?
await self._handle_evt_error(evt)
# errors are fatal, so exit the receive loop
return
async def _transcribe_audio_handler(self):
while True:
audio = await self._transcribe_audio_queue.get()
await self._handle_transcribe_user_audio(audio, self._context)
else:
pass
except asyncio.CancelledError:
logger.debug("websocket receive task cancelled")
raise
except Exception as e:
logger.error(f"{self} exception: {e}")
async def _transcribe_model_audio_handler(self):
while True:
audio = await self._transcribe_model_audio_queue.get()
await self._handle_transcribe_model_audio(audio, self._context)
#
#
@@ -676,7 +693,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
self._bot_text_buffer = ""
if audio and self._transcribe_model_audio and self._context:
self.create_task(self._handle_transcribe_model_audio(audio, self._context))
await self._transcribe_model_audio.put(audio)
elif text:
await self.push_frame(LLMFullResponseEndFrame())