Fix bundled Gemini Live transcription ordering
This commit is contained in:
@@ -1361,9 +1361,9 @@ class GeminiLiveLLMService(LLMService):
|
||||
self._check_and_reset_failure_counter()
|
||||
|
||||
# server_content fields are NOT mutually exclusive —
|
||||
# Gemini 3.x can bundle e.g. model_turn and
|
||||
# output_transcription on the same message — so check
|
||||
# each field independently.
|
||||
# Gemini 3.x can bundle multiple content fields and
|
||||
# turn_complete on the same message, so process the
|
||||
# content-bearing fields before closing the turn.
|
||||
sc = message.server_content
|
||||
if sc and sc.interrupted:
|
||||
# NOTE: while the service triggers interruptions in
|
||||
@@ -1380,18 +1380,25 @@ class GeminiLiveLLMService(LLMService):
|
||||
await self.broadcast_interruption()
|
||||
if sc and sc.model_turn:
|
||||
await self._handle_msg_model_turn(message)
|
||||
if sc and sc.input_transcription:
|
||||
await self._handle_msg_input_transcription(message)
|
||||
if sc and sc.output_transcription:
|
||||
await self._handle_msg_output_transcription(message)
|
||||
if (
|
||||
sc
|
||||
and sc.grounding_metadata
|
||||
and not sc.model_turn
|
||||
and not sc.output_transcription
|
||||
):
|
||||
# model_turn/output_transcription already defer
|
||||
# bundled grounding metadata to turn_complete.
|
||||
await self._handle_msg_grounding_metadata(message)
|
||||
if sc and sc.turn_complete:
|
||||
if not message.usage_metadata:
|
||||
logger.warning("Received turn_complete without usage_metadata")
|
||||
await self._handle_msg_turn_complete(message)
|
||||
if message.usage_metadata:
|
||||
await self._handle_msg_usage_metadata(message)
|
||||
if sc and sc.input_transcription:
|
||||
await self._handle_msg_input_transcription(message)
|
||||
if sc and sc.output_transcription:
|
||||
await self._handle_msg_output_transcription(message)
|
||||
if sc and sc.grounding_metadata:
|
||||
await self._handle_msg_grounding_metadata(message)
|
||||
if message.tool_call:
|
||||
await self._handle_msg_tool_call(message)
|
||||
if message.session_resumption_update:
|
||||
|
||||
190
tests/test_google_gemini_live_llm.py
Normal file
190
tests/test_google_gemini_live_llm.py
Normal file
@@ -0,0 +1,190 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Unit tests for Gemini Live message dispatch ordering."""
|
||||
|
||||
import warnings
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
try:
|
||||
from pipecat.services.google.gemini_live.llm import GeminiLiveLLMService
|
||||
|
||||
google_available = True
|
||||
except Exception:
|
||||
google_available = False
|
||||
|
||||
|
||||
class _SingleMessageTurn:
    """Async iterator that yields exactly one message and is then exhausted."""

    def __init__(self, message):
        """Store the message to yield and mark it as not yet delivered."""
        self._message = message
        self._sent = False

    def __aiter__(self):
        """Return the iterator itself so it can be used with ``async for``."""
        return self

    async def __anext__(self):
        """Return the stored message on the first call, then stop iteration."""
        if not self._sent:
            self._sent = True
            return self._message
        raise StopAsyncIteration
|
||||
|
||||
|
||||
class _SingleMessageSession:
    """Fake session whose first ``receive()`` yields one turn; later calls raise."""

    def __init__(self, message):
        """Store the message to deliver and mark the session as not yet read."""
        self._message = message
        self._received = False

    def receive(self):
        """Return a one-message turn on the first call.

        Raises:
            RuntimeError: On every call after the first, to stop the test session.
        """
        if not self._received:
            self._received = True
            return _SingleMessageTurn(self._message)
        raise RuntimeError("stop test session")
|
||||
|
||||
|
||||
class _MockConnect:
    """Async context manager that hands back a pre-built session object."""

    def __init__(self, session):
        """Remember the session to return from ``__aenter__``."""
        self._session = session

    async def __aenter__(self):
        """Enter the context and return the stored session."""
        return self._session

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit the context without suppressing any in-flight exception."""
        suppress_exception = False
        return suppress_exception
|
||||
|
||||
|
||||
def _make_service():
    """Construct a ``GeminiLiveLLMService`` with ``create_client`` patched out.

    ``DeprecationWarning`` is ignored while the service is being constructed.
    """
    with patch.object(GeminiLiveLLMService, "create_client"), warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)
        service = GeminiLiveLLMService(api_key="test-key", model="test-model")
    return service
|
||||
|
||||
|
||||
def _make_message(
    *,
    model_turn=None,
    turn_complete=False,
    input_transcription=None,
    output_transcription=None,
    grounding_metadata=None,
    usage_metadata=None,
    interrupted=False,
    tool_call=None,
    session_resumption_update=None,
):
    """Build a lightweight stand-in for a Gemini Live server message.

    Every field defaults to an "absent" value, so a test only spells out the
    fields it wants bundled on the message. All arguments are keyword-only.

    Args:
        model_turn: Value for ``server_content.model_turn``.
        turn_complete: Value for ``server_content.turn_complete``.
        input_transcription: Value for ``server_content.input_transcription``.
        output_transcription: Value for ``server_content.output_transcription``.
        grounding_metadata: Value for ``server_content.grounding_metadata``.
        usage_metadata: Value for the message-level ``usage_metadata``.
        interrupted: Value for ``server_content.interrupted``.
        tool_call: Value for the message-level ``tool_call``.
        session_resumption_update: Value for the message-level
            ``session_resumption_update``.

    Returns:
        A ``SimpleNamespace`` with ``server_content``, ``usage_metadata``,
        ``tool_call`` and ``session_resumption_update`` attributes.
    """
    server_content = SimpleNamespace(
        interrupted=interrupted,
        model_turn=model_turn,
        turn_complete=turn_complete,
        input_transcription=input_transcription,
        output_transcription=output_transcription,
        grounding_metadata=grounding_metadata,
    )
    return SimpleNamespace(
        server_content=server_content,
        usage_metadata=usage_metadata,
        tool_call=tool_call,
        session_resumption_update=session_resumption_update,
    )
|
||||
|
||||
|
||||
async def _run_single_message(service, message):
    """Run ``service._connection_task_handler`` against a one-message fake session.

    The service's ``_client`` is replaced with a namespace whose
    ``aio.live.connect(**kwargs)`` returns a ``_MockConnect`` wrapping a
    ``_SingleMessageSession`` built from ``message``.
    """
    session = _SingleMessageSession(message)

    def _connect(**kwargs):
        # Connection arguments are accepted but ignored; the same session is always used.
        return _MockConnect(session)

    live = SimpleNamespace(connect=_connect)
    service._client = SimpleNamespace(aio=SimpleNamespace(live=live))
    # NOTE(review): presumably set so the handler exits instead of reconnecting
    # once the fake session raises on its second receive() — confirm against the handler.
    service._disconnecting = True
    await service._connection_task_handler(config=SimpleNamespace())
|
||||
|
||||
|
||||
def _async_recorder(order, name):
    """Return a coroutine function that appends ``name`` to ``order`` when awaited.

    The returned coroutine takes a single ``message`` argument, which it ignores.
    """

    async def _append_name(message):
        order.append(name)

    return _append_name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(not google_available, reason="Google dependencies not installed")
async def test_output_transcription_is_processed_before_turn_complete():
    """A bundled message dispatches content handlers before turn_complete and usage."""
    service = _make_service()
    order = []

    # Replace each handler with a stub that records its name when awaited.
    for handler in ("model_turn", "output_transcription", "turn_complete", "usage_metadata"):
        stub = AsyncMock(side_effect=_async_recorder(order, handler))
        setattr(service, f"_handle_msg_{handler}", stub)

    bundled = _make_message(
        model_turn=object(),
        output_transcription=object(),
        turn_complete=True,
        usage_metadata=object(),
    )
    await _run_single_message(service, bundled)

    expected = ["model_turn", "output_transcription", "turn_complete", "usage_metadata"]
    assert order == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(not google_available, reason="Google dependencies not installed")
async def test_bundled_grounding_metadata_is_not_emitted_immediately():
    """Grounding metadata bundled with output_transcription skips its own handler."""
    service = _make_service()
    order = []

    # Replace each handler with a stub that records its name when awaited.
    for handler in (
        "output_transcription",
        "grounding_metadata",
        "turn_complete",
        "usage_metadata",
    ):
        stub = AsyncMock(side_effect=_async_recorder(order, handler))
        setattr(service, f"_handle_msg_{handler}", stub)

    bundled = _make_message(
        output_transcription=object(),
        grounding_metadata=object(),
        turn_complete=True,
        usage_metadata=object(),
    )
    await _run_single_message(service, bundled)

    expected = ["output_transcription", "turn_complete", "usage_metadata"]
    assert order == expected
    service._handle_msg_grounding_metadata.assert_not_called()
|
||||
service._handle_msg_grounding_metadata.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(not google_available, reason="Google dependencies not installed")
async def test_standalone_grounding_metadata_is_still_emitted():
    """Grounding metadata without model_turn/output_transcription is handled first."""
    service = _make_service()
    order = []

    # Replace each handler with a stub that records its name when awaited.
    for handler in ("grounding_metadata", "turn_complete", "usage_metadata"):
        stub = AsyncMock(side_effect=_async_recorder(order, handler))
        setattr(service, f"_handle_msg_{handler}", stub)

    standalone = _make_message(
        grounding_metadata=object(),
        turn_complete=True,
        usage_metadata=object(),
    )
    await _run_single_message(service, standalone)

    expected = ["grounding_metadata", "turn_complete", "usage_metadata"]
    assert order == expected
|
||||
Reference in New Issue
Block a user