From 3cd7d882fbdbe8d4bb2bf75e96eb6f91f09ecc0a Mon Sep 17 00:00:00 2001 From: Kwindla Hultman Kramer Date: Wed, 25 Mar 2026 18:56:00 -0700 Subject: [PATCH] Fix bundled Gemini Live transcription ordering --- .../services/google/gemini_live/llm.py | 25 ++- tests/test_google_gemini_live_llm.py | 190 ++++++++++++++++++ 2 files changed, 206 insertions(+), 9 deletions(-) create mode 100644 tests/test_google_gemini_live_llm.py diff --git a/src/pipecat/services/google/gemini_live/llm.py b/src/pipecat/services/google/gemini_live/llm.py index 27827d5c8..f870419be 100644 --- a/src/pipecat/services/google/gemini_live/llm.py +++ b/src/pipecat/services/google/gemini_live/llm.py @@ -1361,9 +1361,9 @@ class GeminiLiveLLMService(LLMService): self._check_and_reset_failure_counter() # server_content fields are NOT mutually exclusive — - # Gemini 3.x can bundle e.g. model_turn and - # output_transcription on the same message — so check - # each field independently. + # Gemini 3.x can bundle multiple content fields and + # turn_complete on the same message, so process the + # content-bearing fields before closing the turn. sc = message.server_content if sc and sc.interrupted: # NOTE: while the service triggers interruptions in @@ -1380,18 +1380,25 @@ class GeminiLiveLLMService(LLMService): await self.broadcast_interruption() if sc and sc.model_turn: await self._handle_msg_model_turn(message) + if sc and sc.input_transcription: + await self._handle_msg_input_transcription(message) + if sc and sc.output_transcription: + await self._handle_msg_output_transcription(message) + if ( + sc + and sc.grounding_metadata + and not sc.model_turn + and not sc.output_transcription + ): + # model_turn/output_transcription already defer + # bundled grounding metadata to turn_complete. + await self._handle_msg_grounding_metadata(message) if sc and sc.turn_complete: if not message.usage_metadata: logger.warning("Received turn_complete without usage_metadata") await self._handle_msg_turn_complete(message) if message.usage_metadata: await self._handle_msg_usage_metadata(message) - if sc and sc.input_transcription: - await self._handle_msg_input_transcription(message) - if sc and sc.output_transcription: - await self._handle_msg_output_transcription(message) - if sc and sc.grounding_metadata: - await self._handle_msg_grounding_metadata(message) if message.tool_call: await self._handle_msg_tool_call(message) if message.session_resumption_update: diff --git a/tests/test_google_gemini_live_llm.py b/tests/test_google_gemini_live_llm.py new file mode 100644 index 000000000..1d94322f4 --- /dev/null +++ b/tests/test_google_gemini_live_llm.py @@ -0,0 +1,190 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Unit tests for Gemini Live message dispatch ordering.""" + +import warnings +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch + +import pytest + +try: + from pipecat.services.google.gemini_live.llm import GeminiLiveLLMService + + google_available = True +except Exception: + google_available = False + + +class _SingleMessageTurn: + def __init__(self, message): + self._message = message + self._sent = False + + def __aiter__(self): + return self + + async def __anext__(self): + if self._sent: + raise StopAsyncIteration + self._sent = True + return self._message + + +class _SingleMessageSession: + def __init__(self, message): + self._message = message + self._received = False + + def receive(self): + if self._received: + raise RuntimeError("stop test session") + self._received = True + return _SingleMessageTurn(self._message) + + +class _MockConnect: + def __init__(self, session): + self._session = session + + async def __aenter__(self): + return self._session + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return False + + +def _make_service(): + with patch.object(GeminiLiveLLMService, "create_client"): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + return GeminiLiveLLMService(api_key="test-key", model="test-model") + + +def _make_message( + *, + model_turn=None, + turn_complete=False, + output_transcription=None, + grounding_metadata=None, + usage_metadata=None, +): + server_content = SimpleNamespace( + interrupted=False, + model_turn=model_turn, + turn_complete=turn_complete, + input_transcription=None, + output_transcription=output_transcription, + grounding_metadata=grounding_metadata, + ) + return SimpleNamespace( + server_content=server_content, + usage_metadata=usage_metadata, + tool_call=None, + session_resumption_update=None, + ) + + +async def _run_single_message(service, message): + session = _SingleMessageSession(message) + service._client = SimpleNamespace( + aio=SimpleNamespace(live=SimpleNamespace(connect=lambda **kwargs: _MockConnect(session))) + ) + service._disconnecting = True + await service._connection_task_handler(config=SimpleNamespace()) + + +def _async_recorder(order, name): + async def _record(message): + order.append(name) + + return _record + + +@pytest.mark.asyncio +@pytest.mark.skipif(not google_available, reason="Google dependencies not installed") +async def test_output_transcription_is_processed_before_turn_complete(): + service = _make_service() + order = [] + + service._handle_msg_model_turn = AsyncMock(side_effect=_async_recorder(order, "model_turn")) + service._handle_msg_output_transcription = AsyncMock( + side_effect=_async_recorder(order, "output_transcription") + ) + service._handle_msg_turn_complete = AsyncMock( + side_effect=_async_recorder(order, "turn_complete") + ) + service._handle_msg_usage_metadata = AsyncMock( + side_effect=_async_recorder(order, "usage_metadata") + ) + + message = _make_message( + model_turn=object(), + output_transcription=object(), + turn_complete=True, + usage_metadata=object(), + ) + await _run_single_message(service, message) + + assert order == ["model_turn", "output_transcription", "turn_complete", "usage_metadata"] + + +@pytest.mark.asyncio +@pytest.mark.skipif(not google_available, reason="Google dependencies not installed") +async def test_bundled_grounding_metadata_is_not_emitted_immediately(): + service = _make_service() + order = [] + + service._handle_msg_output_transcription = AsyncMock( + side_effect=_async_recorder(order, "output_transcription") + ) + service._handle_msg_grounding_metadata = AsyncMock( + side_effect=_async_recorder(order, "grounding_metadata") + ) + service._handle_msg_turn_complete = AsyncMock( + side_effect=_async_recorder(order, "turn_complete") + ) + service._handle_msg_usage_metadata = AsyncMock( + side_effect=_async_recorder(order, "usage_metadata") + ) + + message = _make_message( + output_transcription=object(), + grounding_metadata=object(), + turn_complete=True, + usage_metadata=object(), + ) + await _run_single_message(service, message) + + assert order == ["output_transcription", "turn_complete", "usage_metadata"] + service._handle_msg_grounding_metadata.assert_not_called() + + +@pytest.mark.asyncio +@pytest.mark.skipif(not google_available, reason="Google dependencies not installed") +async def test_standalone_grounding_metadata_is_still_emitted(): + service = _make_service() + order = [] + + service._handle_msg_grounding_metadata = AsyncMock( + side_effect=_async_recorder(order, "grounding_metadata") + ) + service._handle_msg_turn_complete = AsyncMock( + side_effect=_async_recorder(order, "turn_complete") + ) + service._handle_msg_usage_metadata = AsyncMock( + side_effect=_async_recorder(order, "usage_metadata") + ) + + message = _make_message( + grounding_metadata=object(), + turn_complete=True, + usage_metadata=object(), + ) + await _run_single_message(service, message) + + assert order == ["grounding_metadata", "turn_complete", "usage_metadata"]