gemini: create transcribe tasks only once
This commit is contained in:
@@ -8,7 +8,7 @@ import asyncio
|
||||
import io
|
||||
import wave
|
||||
from abc import abstractmethod
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -69,7 +69,7 @@ class AIService(FrameProcessor):
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
pass
|
||||
|
||||
async def _update_settings(self, settings: Dict[str, Any]):
|
||||
async def _update_settings(self, settings: Mapping[str, Any]):
|
||||
from pipecat.services.openai_realtime_beta.events import (
|
||||
SessionProperties,
|
||||
)
|
||||
@@ -267,7 +267,7 @@ class TTSService(AIService):
|
||||
await self.cancel_task(self._stop_frame_task)
|
||||
self._stop_frame_task = None
|
||||
|
||||
async def _update_settings(self, settings: Dict[str, Any]):
|
||||
async def _update_settings(self, settings: Mapping[str, Any]):
|
||||
for key, value in settings.items():
|
||||
if key in self._settings:
|
||||
logger.info(f"Updating TTS setting {key} to: [{value}]")
|
||||
@@ -468,7 +468,7 @@ class STTService(AIService):
|
||||
"""Returns transcript as a string"""
|
||||
pass
|
||||
|
||||
async def _update_settings(self, settings: Dict[str, Any]):
|
||||
async def _update_settings(self, settings: Mapping[str, Any]):
|
||||
logger.info(f"Updating STT settings: {self._settings}")
|
||||
for key, value in settings.items():
|
||||
if key in self._settings:
|
||||
|
||||
@@ -51,6 +51,7 @@ from pipecat.services.openai import (
|
||||
OpenAIUserContextAggregator,
|
||||
)
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.utils import wait_for_task
|
||||
|
||||
from . import events
|
||||
from .audio_transcriber import AudioTranscriber
|
||||
@@ -182,9 +183,13 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
|
||||
self._audio_input_paused = start_audio_paused
|
||||
self._video_input_paused = start_video_paused
|
||||
self._context = None
|
||||
self._websocket = None
|
||||
self._receive_task = None
|
||||
self._context = None
|
||||
self._transcribe_audio_task = None
|
||||
self._transcribe_model_audio_task = None
|
||||
self._transcribe_audio_queue = asyncio.Queue()
|
||||
self._transcribe_model_audio_queue = asyncio.Queue()
|
||||
|
||||
self._disconnecting = False
|
||||
self._api_session_ready = False
|
||||
@@ -275,7 +280,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
)
|
||||
await self.send_client_event(evt)
|
||||
if self._transcribe_user_audio and self._context:
|
||||
self.create_task(self._handle_transcribe_user_audio(audio, self._context))
|
||||
await self._transcribe_audio_queue.put(audio)
|
||||
|
||||
async def _handle_transcribe_user_audio(self, audio, context):
|
||||
text = await self._transcribe_audio(audio, context)
|
||||
@@ -392,6 +397,10 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
logger.info(f"Connecting to {uri}")
|
||||
self._websocket = await websockets.connect(uri=uri)
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
self._transcribe_audio_task = self.create_task(self._transcribe_audio_handler())
|
||||
self._transcribe_model_audio_task = self.create_task(
|
||||
self._transcribe_model_audio_handler()
|
||||
)
|
||||
config = events.Config.model_validate(
|
||||
{
|
||||
"setup": {
|
||||
@@ -443,6 +452,12 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task, timeout=1.0)
|
||||
self._receive_task = None
|
||||
if self._transcribe_audio_task:
|
||||
await self.cancel_task(self._transcribe_audio_task)
|
||||
self._transcribe_audio_task = None
|
||||
if self._transcribe_model_audio_task:
|
||||
await self.cancel_task(self._transcribe_model_audio_task)
|
||||
self._transcribe_model_audio_task = None
|
||||
self._disconnecting = False
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error disconnecting: {e}")
|
||||
@@ -469,33 +484,35 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
#
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
try:
|
||||
async for message in self._websocket:
|
||||
evt = events.parse_server_event(message)
|
||||
# logger.debug(f"Received event: {message[:500]}")
|
||||
# logger.debug(f"Received event: {evt}")
|
||||
async for message in self._websocket:
|
||||
evt = events.parse_server_event(message)
|
||||
# logger.debug(f"Received event: {message[:500]}")
|
||||
# logger.debug(f"Received event: {evt}")
|
||||
|
||||
if evt.setupComplete:
|
||||
await self._handle_evt_setup_complete(evt)
|
||||
elif evt.serverContent and evt.serverContent.modelTurn:
|
||||
await self._handle_evt_model_turn(evt)
|
||||
elif evt.serverContent and evt.serverContent.turnComplete:
|
||||
await self._handle_evt_turn_complete(evt)
|
||||
elif evt.toolCall:
|
||||
await self._handle_evt_tool_call(evt)
|
||||
if evt.setupComplete:
|
||||
await self._handle_evt_setup_complete(evt)
|
||||
elif evt.serverContent and evt.serverContent.modelTurn:
|
||||
await self._handle_evt_model_turn(evt)
|
||||
elif evt.serverContent and evt.serverContent.turnComplete:
|
||||
await self._handle_evt_turn_complete(evt)
|
||||
elif evt.toolCall:
|
||||
await self._handle_evt_tool_call(evt)
|
||||
elif False: # !!! todo: error events?
|
||||
await self._handle_evt_error(evt)
|
||||
# errors are fatal, so exit the receive loop
|
||||
return
|
||||
else:
|
||||
pass
|
||||
|
||||
elif False: # !!! todo: error events?
|
||||
await self._handle_evt_error(evt)
|
||||
# errors are fatal, so exit the receive loop
|
||||
return
|
||||
async def _transcribe_audio_handler(self):
|
||||
while True:
|
||||
audio = await self._transcribe_audio_queue.get()
|
||||
await self._handle_transcribe_user_audio(audio, self._context)
|
||||
|
||||
else:
|
||||
pass
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("websocket receive task cancelled")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
async def _transcribe_model_audio_handler(self):
|
||||
while True:
|
||||
audio = await self._transcribe_model_audio_queue.get()
|
||||
await self._handle_transcribe_model_audio(audio, self._context)
|
||||
|
||||
#
|
||||
#
|
||||
@@ -676,7 +693,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
self._bot_text_buffer = ""
|
||||
|
||||
if audio and self._transcribe_model_audio and self._context:
|
||||
self.create_task(self._handle_transcribe_model_audio(audio, self._context))
|
||||
await self._transcribe_model_audio.put(audio)
|
||||
elif text:
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user