Support conversations with Gemini 3 Pro Image (model "gemini-3-pro-image-preview").
Prior to this change, after the model generated an image the conversation would not be able to progress. It would stall out because we were never storing the image in context, so the model would never realize it already did the work of generating an image. We didn't run into issues with Gemini 2.5 Flash Image, because that model always followed up an image with a text message.
This commit is contained in:
3
changelog/3224.fixed.2.md
Normal file
3
changelog/3224.fixed.2.md
Normal file
@@ -0,0 +1,3 @@
|
||||
- Better support conversation history with Gemini 2.5 Flash Image (model
|
||||
"gemini-2.5-flash-image"). Prior to this fix, the model had no memory of
|
||||
previous images it had generated, so it wouldn't be able to iterate on them.
|
||||
3
changelog/3224.fixed.md
Normal file
3
changelog/3224.fixed.md
Normal file
@@ -0,0 +1,3 @@
|
||||
- Support conversations with Gemini 3 Pro Image (model
|
||||
"gemini-3-pro-image-preview"). Prior to this fix, after the model generated
|
||||
an image the conversation would not be able to progress.
|
||||
@@ -89,6 +89,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
llm = GoogleLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
model="gemini-2.5-flash-image",
|
||||
# model="gemini-3-pro-image-preview", # A more powerful model, but slower
|
||||
)
|
||||
|
||||
messages = [
|
||||
|
||||
@@ -556,7 +556,10 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
if (
|
||||
hasattr(part, "inline_data")
|
||||
and part.inline_data
|
||||
and part.inline_data.data == bookmark_inline_data.data
|
||||
# Comparing length should be good enough for matching inline data,
|
||||
# especially since we're already matching thought signatures in
|
||||
# strict message order. Comparing actual data is expensive.
|
||||
and len(part.inline_data.data) == len(bookmark_inline_data.data)
|
||||
):
|
||||
logger.trace(f"Thought signature inline data match")
|
||||
return True
|
||||
|
||||
@@ -1466,6 +1466,20 @@ class UserImageRawFrame(InputImageRawFrame):
|
||||
return f"{self.name}(pts: {pts}, user: {self.user_id}, source: {self.transport_source}, size: {self.size}, format: {self.format}, text: {self.text}, append_to_context: {self.append_to_context})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AssistantImageRawFrame(OutputImageRawFrame):
|
||||
"""Frame containing image generated by the assistant.
|
||||
|
||||
An image generated by the assistant. Gets appended to the LLM context.
|
||||
|
||||
Parameters:
|
||||
original_jpeg: The already-JPEG-encoded image bytes, which may be
|
||||
appended directly to the LLM context without further encoding.
|
||||
"""
|
||||
|
||||
original_jpeg: Optional[bytes] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputDTMFFrame(DTMFFrame, SystemFrame):
|
||||
"""DTMF keypress input frame from transport."""
|
||||
|
||||
@@ -157,9 +157,15 @@ class LLMContext:
|
||||
"""
|
||||
|
||||
def encode_image():
|
||||
buffer = io.BytesIO()
|
||||
Image.frombytes(format, size, image).save(buffer, format="JPEG")
|
||||
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
if format == "JPEG":
|
||||
# Already JPEG-encoded
|
||||
bytes = image
|
||||
else:
|
||||
# Encode to JPEG
|
||||
buffer = io.BytesIO()
|
||||
Image.frombytes(format, size, image).save(buffer, format="JPEG")
|
||||
bytes = buffer.getvalue()
|
||||
encoded_image = base64.b64encode(bytes).decode("utf-8")
|
||||
return encoded_image
|
||||
|
||||
encoded_image = await asyncio.to_thread(encode_image)
|
||||
@@ -334,18 +340,26 @@ class LLMContext:
|
||||
self._tool_choice = tool_choice
|
||||
|
||||
async def add_image_frame_message(
|
||||
self, *, format: str, size: tuple[int, int], image: bytes, text: Optional[str] = None
|
||||
self,
|
||||
*,
|
||||
format: str,
|
||||
size: tuple[int, int],
|
||||
image: bytes,
|
||||
text: Optional[str] = None,
|
||||
role: str = "user",
|
||||
):
|
||||
"""Add a message containing an image frame.
|
||||
|
||||
Args:
|
||||
format: Image format (e.g., 'RGB', 'RGBA').
|
||||
format: Image format (e.g., 'RGB', 'RGBA', or, if already
|
||||
JPEG-encoded, "JPEG").
|
||||
size: Image dimensions as (width, height) tuple.
|
||||
image: Raw image bytes.
|
||||
text: Optional text to include with the image.
|
||||
role: The role of this message (defaults to "user").
|
||||
"""
|
||||
message = await LLMContext.create_image_message(
|
||||
format=format, size=size, image=image, text=text
|
||||
role=role, format=format, size=size, image=image, text=text
|
||||
)
|
||||
self.add_message(message)
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ from pipecat.audio.interruptions.base_interruption_strategy import BaseInterrupt
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import (
|
||||
AssistantImageRawFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -663,6 +664,8 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
await self._handle_function_call_cancel(frame)
|
||||
elif isinstance(frame, UserImageRawFrame):
|
||||
await self._handle_user_image_frame(frame)
|
||||
elif isinstance(frame, AssistantImageRawFrame):
|
||||
await self._handle_assistant_image_frame(frame)
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
await self.push_aggregation()
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -827,6 +830,24 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
await self.push_aggregation()
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
async def _handle_assistant_image_frame(self, frame: AssistantImageRawFrame):
|
||||
logger.debug(f"{self} Appending AssistantImageRawFrame to LLM context (size: {frame.size})")
|
||||
|
||||
if frame.original_jpeg:
|
||||
await self._context.add_image_frame_message(
|
||||
format="JPEG",
|
||||
size=frame.size, # Technically doesn't matter, since already encoded
|
||||
image=frame.original_jpeg,
|
||||
role="assistant",
|
||||
)
|
||||
else:
|
||||
await self._context.add_image_frame_message(
|
||||
format=frame.format,
|
||||
size=frame.size,
|
||||
image=frame.image,
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
async def _handle_llm_start(self, _: LLMFullResponseStartFrame):
|
||||
self._started += 1
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter, GeminiLLMInvocationParams
|
||||
from pipecat.frames.frames import (
|
||||
AssistantImageRawFrame,
|
||||
AudioRawFrame,
|
||||
Frame,
|
||||
FunctionCallCancelFrame,
|
||||
@@ -43,7 +44,7 @@ from pipecat.frames.frames import (
|
||||
UserImageRawFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
@@ -992,11 +993,25 @@ class GoogleLLMService(LLMService):
|
||||
)
|
||||
)
|
||||
elif part.inline_data and part.inline_data.data:
|
||||
# Here we assume that inline_data is an image.
|
||||
image = Image.open(io.BytesIO(part.inline_data.data))
|
||||
frame = OutputImageRawFrame(
|
||||
image=image.tobytes(), size=image.size, format="RGB"
|
||||
# NOTE: Gemini 3 Pro Image seems to always give
|
||||
# JPEGs. It expects us to send back the
|
||||
# original JPEG data in the context, along with
|
||||
# the corresponding thought signature. JPEG
|
||||
# happens to be the format our universal
|
||||
# context uses for images, so we can just pass
|
||||
# it through as-is.
|
||||
await self.push_frame(
|
||||
AssistantImageRawFrame(
|
||||
image=image.tobytes(),
|
||||
size=image.size,
|
||||
format="RGB",
|
||||
original_jpeg=part.inline_data.data
|
||||
if part.inline_data.mime_type == "image/jpeg"
|
||||
else None,
|
||||
)
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
|
||||
# Handle Gemini thought signatures.
|
||||
#
|
||||
@@ -1022,13 +1037,12 @@ class GoogleLLMService(LLMService):
|
||||
if part.function_call:
|
||||
bookmark["function_call"] = function_call_id
|
||||
elif part.inline_data and part.inline_data.data:
|
||||
# NOTE: missing feature: we don't store
|
||||
# inline_data messages (like generated
|
||||
# images) in context today, so this thought
|
||||
# signature is not fully supported yet.
|
||||
# (A conversation with
|
||||
# "gemini-3-pro-image-preview" doesn't work
|
||||
# today due to the missing context.)
|
||||
# With Gemini 3 Pro (where sending the
|
||||
# thought signature is required for images)
|
||||
# this is the JPEG-encoded image data that
|
||||
# we sent to be written to the context
|
||||
# as-is, so it is usable as a bookmark (it
|
||||
# will match the context data).
|
||||
bookmark["inline_data"] = part.inline_data
|
||||
elif part.text is not None:
|
||||
# Account for Gemini 3 Pro trailing
|
||||
|
||||
@@ -23,6 +23,7 @@ from pipecat.audio.dtmf.utils import load_dtmf_audio
|
||||
from pipecat.audio.mixers.base_audio_mixer import BaseAudioMixer
|
||||
from pipecat.audio.utils import create_stream_resampler, is_silence
|
||||
from pipecat.frames.frames import (
|
||||
AssistantImageRawFrame,
|
||||
BotSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
@@ -335,6 +336,10 @@ class BaseOutputTransport(FrameProcessor):
|
||||
await sender.handle_audio_frame(frame)
|
||||
elif isinstance(frame, (OutputImageRawFrame, SpriteFrame)):
|
||||
await sender.handle_image_frame(frame)
|
||||
if isinstance(frame, AssistantImageRawFrame):
|
||||
# This will push it further, to be handled by the assistant
|
||||
# aggregator, say
|
||||
await sender.handle_sync_frame(frame)
|
||||
elif isinstance(frame, MixerControlFrame):
|
||||
await sender.handle_mixer_control_frame(frame)
|
||||
elif frame.pts:
|
||||
@@ -753,7 +758,7 @@ class BaseOutputTransport(FrameProcessor):
|
||||
await self._handle_frame(frame)
|
||||
|
||||
# If we are not able to write to the transport we shouldn't
|
||||
# pushb downstream.
|
||||
# push downstream.
|
||||
push_downstream = True
|
||||
|
||||
# Try to send audio to the transport.
|
||||
|
||||
Reference in New Issue
Block a user