Compare commits

..

12 Commits

Author SHA1 Message Date
James Hush
858e305c40 Get the Daily session id 2025-05-12 09:39:25 +08:00
Mark Backman
20498fb47f Merge pull request #1790 from AngeloGiacco/angelo/fix-api-key
[elevenlabs tts ] fix api key
2025-05-10 19:16:27 -04:00
Angelo Giacco
b57dfb3b5d fix lint 2025-05-10 16:36:26 +01:00
Angelo Giacco
0355ed4aa1 move api key to ws header 2025-05-10 16:34:01 +01:00
Angelo Giacco
1e76cc7bdc fix: elevenlabs api key 2025-05-10 16:09:20 +01:00
Vanessa Pyne
18c0374126 Merge pull request #1785 from pipecat-ai/vp-small-filenmae-change
39-aws-nova-sonic.py -> 40-aws-nova-sonic.py
2025-05-09 12:19:09 -05:00
Aleix Conchillo Flaqué
7072fba7e7 Merge pull request #1780 from pipecat-ai/aleix/deprecate-google-generativeai
GoogleLLMService: deprecate google-generativeai
2025-05-09 09:18:30 -07:00
Aleix Conchillo Flaqué
3d702a5c39 minor examples cleanup 2025-05-09 09:16:10 -07:00
Aleix Conchillo Flaqué
f31efa42c9 GoogleLLMService: deprecate google-generativeai 2025-05-09 09:14:43 -07:00
vipyne
74b369ff20 39-aws-nova-sonic.py -> 40-aws-nova-sonic.py 2025-05-09 08:30:59 -05:00
kompfner
9643296e29 Merge pull request #1779 from pipecat-ai/pk/aws-nova-sonic-missing-params-export
Add missing `Params` export to AWS Nova Sonic module
2025-05-08 16:04:38 -04:00
Paul Kompfner
c83c5b5a34 Add missing Params export to AWS Nova Sonic module 2025-05-08 15:23:25 -04:00
16 changed files with 282 additions and 347 deletions

View File

@@ -5,6 +5,13 @@ All notable changes to **Pipecat** will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
### Changed
- `GoogleLLMService` has been updated to use `google-genai` instead of the
deprecated `google-generativeai`.
## [0.0.67] - 2025-05-07
### Added

View File

@@ -11,18 +11,17 @@ from pathlib import Path
from dotenv import load_dotenv
from loguru import logger
from openai import audio
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import Frame
from pipecat.observers.base_observer import BaseObserver, FramePushed
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.google.llm import GoogleLLMService, LLMSearchResponseFrame
from pipecat.services.llm_service import LLMService
from pipecat.transports.base_transport import TransportParams
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
@@ -33,7 +32,7 @@ load_dotenv(override=True)
# Function handlers for the LLM
search_tool = {"google_search_retrieval": {}}
search_tool = {"google_search": {}}
tools = [search_tool]
system_instruction = """
@@ -50,14 +49,22 @@ Start each interaction by asking the user about which place they would like to k
"""
class LLMSearchLoggerProcessor(FrameProcessor):
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
class LLMSearchLoggerObserver(BaseObserver):
async def on_push_frame(self, data: FramePushed):
src = data.source
dst = data.destination
frame = data.frame
timestamp = data.timestamp
if not isinstance(src, LLMService) and not isinstance(dst, LLMService):
return
time_sec = timestamp / 1_000_000_000
arrow = ""
if isinstance(frame, LLMSearchResponseFrame):
print(f"LLMSearchLoggerProcessor: {frame}")
await self.push_frame(frame)
logger.debug(f"🧠 {arrow} {dst} LLM SEARCH RESPONSE FRAME: {frame} at {time_sec:.2f}s")
async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespace):
@@ -84,7 +91,6 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
api_key=os.getenv("GOOGLE_API_KEY"),
system_instruction=system_instruction,
tools=tools,
model="gemini-1.5-flash-002",
)
context = OpenAILLMContext(
@@ -97,22 +103,23 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
)
context_aggregator = llm.create_context_aggregator(context)
llm_search_logger = LLMSearchLoggerProcessor()
pipeline = Pipeline(
[
transport.input(),
stt,
context_aggregator.user(),
llm,
llm_search_logger,
tts,
transport.output(),
context_aggregator.assistant(),
]
)
task = PipelineTask(pipeline, params=PipelineParams(allow_interruptions=True))
task = PipelineTask(
pipeline,
params=PipelineParams(allow_interruptions=True),
observers=[LLMSearchLoggerObserver()],
)
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):

View File

@@ -102,9 +102,9 @@ async def main():
llm = GoogleLLMService(
api_key=os.getenv("GOOGLE_API_KEY"),
model="gemini-1.5-flash-002",
system_instruction=system_instruction,
tools=tools,
model="gemini-1.5-flash",
)
context = OpenAILLMContext(
@@ -153,7 +153,6 @@ async def main():
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
logger.debug("First participant joined: {}", participant["id"])
await transport.capture_participant_transcription(participant["id"])
@transport.event_handler("on_participant_left")
async def on_participant_left(transport, participant, reason):

View File

@@ -8,6 +8,7 @@ import {
} from '@pipecat-ai/client-js';
import { useRTVIClient, useRTVIClientEvent } from '@pipecat-ai/client-react';
import './DebugDisplay.css';
import { DailyTransport } from '@pipecat-ai/daily-transport';
export function DebugDisplay() {
const debugLogRef = useRef<HTMLDivElement>(null);
@@ -52,6 +53,17 @@ export function DebugDisplay() {
)
);
// Log connection events
useRTVIClientEvent(
RTVIEvent.Connected,
useCallback(() => {
if (!client) return;
const dailyCallClient = (client.transport as DailyTransport)
.dailyCallClient;
console.log(`Session ID: ${dailyCallClient.meetingSessionSummary().id}`);
}, [client])
);
useRTVIClientEvent(
RTVIEvent.BotDisconnected,
useCallback(

View File

@@ -187,7 +187,7 @@ async def main():
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
await transport.capture_participant_transcription(participant["id"])
print(f"Participant joined: {participant}")
@transport.event_handler("on_participant_left")
async def on_participant_left(transport, participant, reason):

View File

@@ -215,6 +215,7 @@ async def main():
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
print(f"Participant joined: {participant}")
await transport.capture_participant_transcription(participant["id"])
@transport.event_handler("on_participant_left")

View File

@@ -30,7 +30,7 @@ from loguru import logger
from pipecatcloud.agent import DailySessionArguments
from word_list import generate_game_words
from pipecat.audio.resamplers.soxr_resampler import SOXRAudioResampler
from pipecat.audio.utils import create_default_resampler
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
@@ -524,7 +524,7 @@ async def tts_audio_raw_frame_filter(frame: Frame):
# Create a resampler instance once
resampler = SOXRAudioResampler()
resampler = create_default_resampler()
async def tts_to_input_audio_transformer(frame: Frame):
@@ -689,8 +689,6 @@ Important guidelines:
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
logger.info("First participant joined: {}", participant["id"])
# Capture the participant's transcription
await transport.capture_participant_transcription(participant["id"])
# Kick off the conversation
await task.queue_frames([context_aggregator.user().get_context_frame()])
# Start the game timer

View File

@@ -54,7 +54,7 @@ fal = [ "fal-client~=0.5.9" ]
fireworks = []
fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ]
gladia = [ "websockets~=13.1" ]
google = [ "google-cloud-speech~=2.31.1", "google-cloud-texttospeech~=2.25.1", "google-genai~=1.7.0", "google-generativeai~=0.8.4", "websockets~=13.1" ]
google = [ "google-cloud-speech~=2.32.0", "google-cloud-texttospeech~=2.26.0", "google-genai~=1.14.0", "websockets~=13.1" ]
grok = []
groq = [ "groq~=0.23.0" ]
gstreamer = [ "pygobject~=3.50.0" ]

View File

@@ -77,8 +77,8 @@ class Frame:
@dataclass
class SystemFrame(Frame):
"""A frame that takes higher priority than other frames. System frames are
handled in order and are not affected by user interruptions.
"""System frames are frames that are not internally queued by any of the
frame processors and should be processed immediately.
"""
@@ -87,9 +87,8 @@ class SystemFrame(Frame):
@dataclass
class DataFrame(Frame):
"""A frame that is processed in order and usually contains data such as LLM
context, text, audio or images. Data frames are cancelled by user
interruptions.
"""Data frames are frames that will be processed in order and usually
contain data such as LLM context, text, audio or images.
"""
@@ -98,9 +97,9 @@ class DataFrame(Frame):
@dataclass
class ControlFrame(Frame):
"""A frame that, as data frames, is processed in order and usually contains
control information such as update settings or to end the pipeline after
everything is flushed. Control frames are cancelled by user interruptions.
"""Control frames are frames that, similar to data frames, will be processed
in order and usually contain control information such as frames to update
settings or to end the pipeline.
"""
@@ -691,7 +690,7 @@ class FunctionCallResultFrame(SystemFrame):
@dataclass
class STTMuteFrame(SystemFrame):
"""A frame to mute/unmute the STT service."""
"""System frame to mute/unmute the STT service."""
mute: bool
@@ -797,7 +796,7 @@ class EndFrame(ControlFrame):
should be shut down. If the transport receives this frame, it will stop
sending frames to its output channel(s) and close all its threads. Note,
that this is a control frame, which means it will received in the order it
was sent.
was sent (unline system frames).
"""

View File

@@ -5,7 +5,6 @@
#
import asyncio
from dataclasses import dataclass
from enum import Enum
from typing import Awaitable, Callable, Coroutine, Optional
@@ -33,51 +32,6 @@ class FrameDirection(Enum):
UPSTREAM = 2
@dataclass
class FrameProcessorQueueItem:
frame: Frame
direction: FrameDirection
callback: Optional[Callable[["FrameProcessor", Frame, FrameDirection], Awaitable[None]]]
class FrameProcessorQueue:
def __init__(self):
self._queue = asyncio.Queue()
self._urgent_queue = asyncio.Queue()
self._event = asyncio.Event()
async def put(self, item: FrameProcessorQueueItem):
if isinstance(item.frame, SystemFrame):
await self._urgent_queue.put(item)
else:
await self._queue.put(item)
self._event.set()
async def get(self) -> FrameProcessorQueueItem:
# Wait for an item in any of the queues.
await self._event.wait()
if self._urgent_queue.empty():
item = await self._queue.get()
self._queue.task_done()
else:
item = await self._urgent_queue.get()
self._urgent_queue.task_done()
# Clear the event only if all queues are empty.
if self._queue.empty() and self._urgent_queue.empty():
self._event.clear()
return item
def clear(self):
self._queue = asyncio.Queue()
# Clear the event only if all queues are empty.
if self._queue.empty() and self._urgent_queue.empty():
self._event.clear()
class FrameProcessor(BaseObject):
def __init__(
self,
@@ -115,21 +69,18 @@ class FrameProcessor(BaseObject):
self._metrics = metrics or FrameProcessorMetrics()
self._metrics.set_processor_name(self.name)
# Processors receive frames on a streaming queue which are then
# processed by a streaming task. This guarantees that all frames are
# processed in the same task. By default, the streaming queue is
# processed immediately but it may block if `pause_processing_frames()`
# Processors have an input queue. The input queue will be processed
# immediately (default) or it will block if `pause_processing_frames()`
# is called. To resume processing frames we need to call
# `resume_processing_frames()` which will wake up the event.
self.__should_block_frames = False
self.__streaming_event = asyncio.Event()
self.__streaming_queue = FrameProcessorQueue()
self.__streaming_frame_task: Optional[asyncio.Task] = None
self.__input_event = asyncio.Event()
self.__input_frame_task: Optional[asyncio.Task] = None
self.__process_queue = asyncio.Queue()
self.__process_task: Optional[asyncio.Task] = None
self.__process_urgent_queue = asyncio.Queue()
self.__process_urgent_task: Optional[asyncio.Task] = None
# Every processor in Pipecat should only output frames from a single
# task. This avoid problems like audio overlapping. System frames are the
# exception to this rule. This create this task.
self.__push_frame_task: Optional[asyncio.Task] = None
@property
def id(self) -> int:
@@ -219,8 +170,7 @@ class FrameProcessor(BaseObject):
async def cleanup(self):
await super().cleanup()
await self.__cancel_input_task()
await self.__cancel_process_task()
await self.__cancel_process_urgent_task()
await self.__cancel_push_task()
def link(self, processor: "FrameProcessor"):
self._next = processor
@@ -265,7 +215,7 @@ class FrameProcessor(BaseObject):
await self.process_frame(frame, direction)
else:
# We queue everything else.
await self.__streaming_queue.put(FrameProcessorQueueItem(frame, direction, callback))
await self.__input_queue.put((frame, direction, callback))
async def pause_processing_frames(self):
logger.trace(f"{self}: pausing frame processing")
@@ -273,7 +223,7 @@ class FrameProcessor(BaseObject):
async def resume_processing_frames(self):
logger.trace(f"{self}: resuming frame processing")
self.__streaming_event.set()
self.__input_event.set()
async def process_frame(self, frame: Frame, direction: FrameDirection):
if isinstance(frame, StartFrame):
@@ -300,6 +250,47 @@ class FrameProcessor(BaseObject):
if not self._check_ready(frame):
return
if isinstance(frame, SystemFrame):
await self.__internal_push_frame(frame, direction)
else:
await self.__push_queue.put((frame, direction))
async def __start(self, frame: StartFrame):
self.__create_input_task()
self.__create_push_task()
async def __cancel(self, frame: CancelFrame):
self._cancelling = True
await self.__cancel_input_task()
await self.__cancel_push_task()
#
# Handle interruptions
#
async def _start_interruption(self):
try:
# Cancel the push frame task. This will stop pushing frames downstream.
await self.__cancel_push_task()
# Cancel the input task. This will stop processing queued frames.
await self.__cancel_input_task()
except Exception as e:
logger.exception(f"Uncaught exception in {self}: {e}")
await self.push_error(ErrorFrame(str(e)))
raise
# Create a new input queue and task.
self.__create_input_task()
# Create a new output queue and task.
self.__create_push_task()
async def _stop_interruption(self):
# Nothing to do right now.
pass
async def __internal_push_frame(self, frame: Frame, direction: FrameDirection):
try:
timestamp = self._clock.get_time() if self._clock else 0
if direction == FrameDirection.DOWNSTREAM and self._next:
@@ -332,49 +323,6 @@ class FrameProcessor(BaseObject):
await self.push_error(ErrorFrame(str(e)))
raise
async def __start(self, frame: StartFrame):
self.__create_process_task()
self.__create_process_urgent_task()
self.__create_input_task()
async def __cancel(self, frame: CancelFrame):
self._cancelling = True
await self.__cancel_input_task()
await self.__cancel_process_task()
await self.__cancel_process_urgent_task()
#
# Handle interruptions
#
async def _start_interruption(self):
try:
# Cancel the streaming task.
await self.__cancel_input_task()
# Cancel the task processing frames. We do not cancel the task that
# is processing urgent frames.
await self.__cancel_process_task()
# If there's an interruption we should not block frames anymore.
self.__should_block_frames = False
# Clear the streaming queue, since we don't want to process its
# frame anymore (except system and urgent frames).
self.__streaming_queue.clear()
except Exception as e:
logger.exception(f"Uncaught exception in {self}: {e}")
await self.push_error(ErrorFrame(str(e)))
raise
# Create a new tasks.
self.__create_process_task()
self.__create_input_task()
async def _stop_interruption(self):
# Nothing to do right now.
pass
def _check_ready(self, frame: Frame):
# If we are trying to push a frame but we still have no clock, it means
# we didn't process a StartFrame.
@@ -386,60 +334,49 @@ class FrameProcessor(BaseObject):
return True
def __create_input_task(self):
if not self.__streaming_frame_task:
self.__streaming_frame_task = self.create_task(self.__streaming_frame_task_handler())
if not self.__input_frame_task:
self.__should_block_frames = False
self.__input_event.clear()
self.__input_queue = asyncio.Queue()
self.__input_frame_task = self.create_task(self.__input_frame_task_handler())
async def __cancel_input_task(self):
if self.__streaming_frame_task:
await self.cancel_task(self.__streaming_frame_task)
self.__streaming_frame_task = None
if self.__input_frame_task:
await self.cancel_task(self.__input_frame_task)
self.__input_frame_task = None
def __create_process_task(self):
if not self.__process_task:
self.__process_queue = asyncio.Queue()
self.__process_task = self.create_task(
self.__process_task_handler(self.__process_queue)
)
async def __cancel_process_task(self):
if self.__process_task:
await self.cancel_task(self.__process_task)
self.__process_task = None
def __create_process_urgent_task(self):
if not self.__process_urgent_task:
self.__process_urgent_task = self.create_task(
self.__process_task_handler(self.__process_urgent_queue)
)
async def __cancel_process_urgent_task(self):
if self.__process_urgent_task:
await self.cancel_task(self.__process_urgent_task)
self.__process_urgent_task = None
async def __streaming_frame_task_handler(self):
async def __input_frame_task_handler(self):
while True:
if self.__should_block_frames:
logger.trace(f"{self}: frame processing paused")
await self.__streaming_event.wait()
self.__streaming_event.clear()
await self.__input_event.wait()
self.__input_event.clear()
self.__should_block_frames = False
logger.trace(f"{self}: frame processing resumed")
item = await self.__streaming_queue.get()
if isinstance(item.frame, SystemFrame):
await self.__process_urgent_queue.put(item)
else:
await self.__process_queue.put(item)
async def __process_task_handler(self, queue: asyncio.Queue):
while True:
item = await queue.get()
(frame, direction, callback) = await self.__input_queue.get()
# Process the frame.
await self.process_frame(item.frame, item.direction)
await self.process_frame(frame, direction)
# If this frame has an associated callback, call it now.
if item.callback:
await item.callback(self, item.frame, item.direction)
if callback:
await callback(self, frame, direction)
self.__input_queue.task_done()
def __create_push_task(self):
if not self.__push_frame_task:
self.__push_queue = asyncio.Queue()
self.__push_frame_task = self.create_task(self.__push_frame_task_handler())
async def __cancel_push_task(self):
if self.__push_frame_task:
await self.cancel_task(self.__push_frame_task)
self.__push_frame_task = None
async def __push_frame_task_handler(self):
while True:
(frame, direction) = await self.__push_queue.get()
await self.__internal_push_frame(frame, direction)
self.__push_queue.task_done()

View File

@@ -1 +1 @@
from .aws import AWSNovaSonicLLMService
from .aws import AWSNovaSonicLLMService, Params

View File

@@ -334,7 +334,9 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
)
# Set max websocket message size to 16MB for large audio responses
self._websocket = await websockets.connect(url, max_size=16 * 1024 * 1024)
self._websocket = await websockets.connect(
url, max_size=16 * 1024 * 1024, extra_headers={"xi-api-key": self._api_key}
)
except Exception as e:
logger.error(f"{self} initialization error: {e}")
@@ -425,7 +427,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
if self._websocket:
if not self._context_id:
# First message for a new context - need a space to initialize
msg = {"text": " ", "context_id": str(uuid.uuid4()), "xi_api_key": self._api_key}
msg = {"text": " ", "context_id": str(uuid.uuid4())}
# Add voice settings only in first message for a context
if self._voice_settings:

View File

@@ -52,10 +52,16 @@ from pipecat.services.openai.llm import (
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
try:
import google.ai.generativelanguage as glm
import google.generativeai as gai
from google import genai
from google.api_core.exceptions import DeadlineExceeded
from google.generativeai.types import GenerationConfig
from google.genai.types import (
Blob,
Content,
FunctionCall,
FunctionResponse,
GenerateContentConfig,
Part,
)
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
@@ -65,9 +71,7 @@ except ModuleNotFoundError as e:
class GoogleUserContextAggregator(OpenAIUserContextAggregator):
async def push_aggregation(self):
if len(self._aggregation) > 0:
self._context.add_message(
glm.Content(role="user", parts=[glm.Part(text=self._aggregation)])
)
self._context.add_message(Content(role="user", parts=[Part(text=self._aggregation)]))
# Reset the aggregation. Reset it before pushing it down, otherwise
# if the tasks gets cancelled we won't be able to clear things up.
@@ -83,15 +87,15 @@ class GoogleUserContextAggregator(OpenAIUserContextAggregator):
class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
async def handle_aggregation(self, aggregation: str):
self._context.add_message(glm.Content(role="model", parts=[glm.Part(text=aggregation)]))
self._context.add_message(Content(role="model", parts=[Part(text=aggregation)]))
async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
self._context.add_message(
glm.Content(
Content(
role="model",
parts=[
glm.Part(
function_call=glm.FunctionCall(
Part(
function_call=FunctionCall(
id=frame.tool_call_id, name=frame.function_name, args=frame.arguments
)
)
@@ -99,11 +103,11 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
)
)
self._context.add_message(
glm.Content(
Content(
role="user",
parts=[
glm.Part(
function_response=glm.FunctionResponse(
Part(
function_response=FunctionResponse(
id=frame.tool_call_id,
name=frame.function_name,
response={"response": "IN_PROGRESS"},
@@ -187,7 +191,7 @@ class GoogleLLMContext(OpenAILLMContext):
# Convert each message individually
converted_messages = []
for msg in messages:
if isinstance(msg, glm.Content):
if isinstance(msg, Content):
# Already in Gemini format
converted_messages.append(msg)
else:
@@ -202,7 +206,7 @@ class GoogleLLMContext(OpenAILLMContext):
def get_messages_for_logging(self):
msgs = []
for message in self.messages:
obj = glm.Content.to_dict(message)
obj = message.to_json_dict()
try:
if "parts" in obj:
for part in obj["parts"]:
@@ -221,10 +225,10 @@ class GoogleLLMContext(OpenAILLMContext):
parts = []
if text:
parts.append(glm.Part(text=text))
parts.append(glm.Part(inline_data=glm.Blob(mime_type="image/jpeg", data=buffer.getvalue())))
parts.append(Part(text=text))
parts.append(Part(inline_data=Blob(mime_type="image/jpeg", data=buffer.getvalue())))
self.add_message(glm.Content(role="user", parts=parts))
self.add_message(Content(role="user", parts=parts))
def add_audio_frames_message(
self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows"
@@ -239,10 +243,10 @@ class GoogleLLMContext(OpenAILLMContext):
data = b"".join(frame.audio for frame in audio_frames)
# NOTE(aleix): According to the docs only text or inline_data should be needed.
# (see https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference)
parts.append(glm.Part(text=text))
parts.append(Part(text=text))
parts.append(
glm.Part(
inline_data=glm.Blob(
Part(
inline_data=Blob(
mime_type="audio/wav",
data=(
bytes(
@@ -252,7 +256,7 @@ class GoogleLLMContext(OpenAILLMContext):
)
),
)
self.add_message(glm.Content(role="user", parts=parts))
self.add_message(Content(role="user", parts=parts))
# message = {"mime_type": "audio/mp3", "data": bytes(data + create_wav_header(sample_rate, num_channels, 16, len(data)))}
# self.add_message(message)
@@ -271,7 +275,7 @@ class GoogleLLMContext(OpenAILLMContext):
}
Returns:
glm.Content object with:
Content object with:
- role: "user" or "model" (converted from "assistant")
- parts: List[Part] containing text, inline_data, or function calls
Returns None for system messages.
@@ -288,8 +292,8 @@ class GoogleLLMContext(OpenAILLMContext):
if message.get("tool_calls"):
for tc in message["tool_calls"]:
parts.append(
glm.Part(
function_call=glm.FunctionCall(
Part(
function_call=FunctionCall(
name=tc["function"]["name"],
args=json.loads(tc["function"]["arguments"]),
)
@@ -298,30 +302,30 @@ class GoogleLLMContext(OpenAILLMContext):
elif role == "tool":
role = "model"
parts.append(
glm.Part(
function_response=glm.FunctionResponse(
Part(
function_response=FunctionResponse(
name="tool_call_result", # seems to work to hard-code the same name every time
response=json.loads(message["content"]),
)
)
)
elif isinstance(content, str):
parts.append(glm.Part(text=content))
parts.append(Part(text=content))
elif isinstance(content, list):
for c in content:
if c["type"] == "text":
parts.append(glm.Part(text=c["text"]))
parts.append(Part(text=c["text"]))
elif c["type"] == "image_url":
parts.append(
glm.Part(
inline_data=glm.Blob(
Part(
inline_data=Blob(
mime_type="image/jpeg",
data=base64.b64decode(c["image_url"]["url"].split(",")[1]),
)
)
)
message = glm.Content(role=role, parts=parts)
message = Content(role=role, parts=parts)
return message
def to_standard_messages(self, obj) -> list:
@@ -409,7 +413,7 @@ class GoogleLLMContext(OpenAILLMContext):
# Process each message, preserving Google-formatted messages and converting others
for message in self._messages:
if isinstance(message, glm.Content):
if isinstance(message, Content):
# Keep existing Google-formatted messages (e.g., function calls/responses)
converted_messages.append(message)
continue
@@ -433,9 +437,7 @@ class GoogleLLMContext(OpenAILLMContext):
# Add system message back as a user message if we only have function messages
if self.system_message and not has_regular_messages:
self._messages.append(
glm.Content(role="user", parts=[glm.Part(text=self.system_message)])
)
self._messages.append(Content(role="user", parts=[Part(text=self.system_message)]))
# Remove any empty messages
self._messages = [m for m in self._messages if m.parts]
@@ -463,7 +465,7 @@ class GoogleLLMService(LLMService):
self,
*,
api_key: str,
model: str = "gemini-2.0-flash-001",
model: str = "gemini-2.0-flash",
params: InputParams = InputParams(),
system_instruction: Optional[str] = None,
tools: Optional[List[Dict[str, Any]]] = None,
@@ -471,10 +473,10 @@ class GoogleLLMService(LLMService):
**kwargs,
):
super().__init__(**kwargs)
gai.configure(api_key=api_key)
self.set_model_name(model)
self._api_key = api_key
self._system_instruction = system_instruction
self._create_client()
self._create_client(api_key)
self._settings = {
"max_tokens": params.max_tokens,
"temperature": params.temperature,
@@ -488,10 +490,8 @@ class GoogleLLMService(LLMService):
def can_generate_metrics(self) -> bool:
return True
def _create_client(self):
self._client = gai.GenerativeModel(
self._model_name, system_instruction=self._system_instruction
)
def _create_client(self, api_key: str):
self._client = genai.Client(api_key=api_key)
async def _process_context(self, context: OpenAILLMContext):
await self.push_frame(LLMFullResponseStartFrame())
@@ -513,23 +513,7 @@ class GoogleLLMService(LLMService):
if context.system_message and self._system_instruction != context.system_message:
logger.debug(f"System instruction changed: {context.system_message}")
self._system_instruction = context.system_message
self._create_client()
# Filter out None values and create GenerationConfig
generation_params = {
k: v
for k, v in {
"temperature": self._settings["temperature"],
"top_p": self._settings["top_p"],
"top_k": self._settings["top_k"],
"max_output_tokens": self._settings["max_tokens"],
}.items()
if v is not None
}
generation_config = GenerationConfig(**generation_params) if generation_params else None
await self.start_ttfb_metrics()
tools = []
if context.tools:
tools = context.tools
@@ -538,112 +522,104 @@ class GoogleLLMService(LLMService):
tool_config = None
if self._tool_config:
tool_config = self._tool_config
response = await self._client.generate_content_async(
# Filter out None values and create GenerationContentConfig
generation_params = {
k: v
for k, v in {
"system_instruction": self._system_instruction,
"temperature": self._settings["temperature"],
"top_p": self._settings["top_p"],
"top_k": self._settings["top_k"],
"max_output_tokens": self._settings["max_tokens"],
"tools": tools,
"tool_config": tool_config,
}.items()
if v is not None
}
generation_config = (
GenerateContentConfig(**generation_params) if generation_params else None
)
await self.start_ttfb_metrics()
response = await self._client.aio.models.generate_content_stream(
model=self._model_name,
contents=messages,
tools=tools,
stream=True,
generation_config=generation_config,
tool_config=tool_config,
config=generation_config,
)
await self.stop_ttfb_metrics()
if response.usage_metadata:
# Use only the prompt token count from the response object
prompt_tokens = response.usage_metadata.prompt_token_count
total_tokens = prompt_tokens
async for chunk in response:
if chunk.usage_metadata:
# Use only the completion_tokens from the chunks. Prompt tokens are already counted and
# are repeated here.
completion_tokens += chunk.usage_metadata.candidates_token_count
total_tokens += chunk.usage_metadata.candidates_token_count
try:
for c in chunk.parts:
if c.text:
search_result += c.text
await self.push_frame(LLMTextFrame(c.text))
elif c.function_call:
logger.debug(f"Function call: {c.function_call}")
args = type(c.function_call).to_dict(c.function_call).get("args", {})
await self.call_function(
context=context,
tool_call_id=str(uuid.uuid4()),
function_name=c.function_call.name,
arguments=args,
)
# Handle grounding metadata
# It seems only the last chunk that we receive may contain this information
# If the response doesn't include groundingMetadata, this means the response wasn't grounded.
if chunk.candidates:
for candidate in chunk.candidates:
# logger.debug(f"candidate received: {candidate}")
# Extract grounding metadata
grounding_metadata = (
{
"rendered_content": getattr(
getattr(candidate, "grounding_metadata", None),
"search_entry_point",
None,
).rendered_content
if hasattr(
getattr(candidate, "grounding_metadata", None),
"search_entry_point",
)
else None,
"origins": [
{
"site_uri": getattr(grounding_chunk.web, "uri", None),
"site_title": getattr(
grounding_chunk.web, "title", None
),
"results": [
{
"text": getattr(
grounding_support.segment, "text", ""
),
"confidence": getattr(
grounding_support, "confidence_scores", None
),
}
for grounding_support in getattr(
getattr(candidate, "grounding_metadata", None),
"grounding_supports",
[],
)
if index
in getattr(
grounding_support, "grounding_chunk_indices", []
)
],
}
for index, grounding_chunk in enumerate(
getattr(
getattr(candidate, "grounding_metadata", None),
"grounding_chunks",
[],
)
)
],
}
if getattr(candidate, "grounding_metadata", None)
else None
)
except Exception as e:
# Google LLMs seem to flag safety issues a lot!
if chunk.candidates[0].finish_reason == 3:
logger.debug(
f"LLM refused to generate content for safety reasons - {messages}."
)
else:
logger.exception(f"{self} error: {e}")
prompt_tokens += chunk.usage_metadata.prompt_token_count or 0
completion_tokens += chunk.usage_metadata.candidates_token_count or 0
total_tokens += chunk.usage_metadata.total_token_count or 0
if not chunk.candidates:
continue
for candidate in chunk.candidates:
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
if not part.thought and part.text:
search_result += part.text
await self.push_frame(LLMTextFrame(part.text))
elif part.function_call:
function_call = part.function_call
id = function_call.id or str(uuid.uuid4())
logger.debug(f"Function call: {function_call.name}:{id}")
await self.call_function(
context=context,
tool_call_id=id,
function_name=function_call.name,
arguments=function_call.args or {},
)
if (
candidate.grounding_metadata
and candidate.grounding_metadata.grounding_chunks
):
m = candidate.grounding_metadata
rendered_content = (
m.search_entry_point.rendered_content if m.search_entry_point else None
)
origins = [
{
"site_uri": grounding_chunk.web.uri
if grounding_chunk.web
else None,
"site_title": grounding_chunk.web.title
if grounding_chunk.web
else None,
"results": [
{
"text": grounding_support.segment.text
if grounding_support.segment
else "",
"confidence": grounding_support.confidence_scores,
}
for grounding_support in (
m.grounding_supports if m.grounding_supports else []
)
if grounding_support.grounding_chunk_indices
and index in grounding_support.grounding_chunk_indices
],
}
for index, grounding_chunk in enumerate(
m.grounding_chunks if m.grounding_chunks else []
)
]
grounding_metadata = {
"rendered_content": rendered_content,
"origins": origins,
}
except DeadlineExceeded:
await self._call_event_handler("on_completion_timeout")
except Exception as e:
logger.exception(f"{self} exception: {e}")
finally:
if grounding_metadata is not None and isinstance(grounding_metadata, dict):
if grounding_metadata and isinstance(grounding_metadata, dict):
llm_search_frame = LLMSearchResponseFrame(
search_result=search_result,
origins=grounding_metadata["origins"],

View File

@@ -8,8 +8,6 @@ import json
import unittest
from typing import Any
import google.ai.generativelanguage as glm
from pipecat.frames.frames import (
EmulateUserStartedSpeakingFrame,
EmulateUserStoppedSpeakingFrame,
@@ -758,13 +756,13 @@ class TestGoogleUserContextAggregator(
AGGREGATOR_CLASS = GoogleUserContextAggregator
def check_message_content(self, context: OpenAILLMContext, index: int, content: str):
obj = glm.Content.to_dict(context.messages[index])
obj = context.messages[index].to_json_dict()
assert obj["parts"][0]["text"] == content
def check_message_multi_content(
self, context: OpenAILLMContext, content_index: int, index: int, content: str
):
obj = glm.Content.to_dict(context.messages[index])
obj = context.messages[index].to_json_dict()
assert obj["parts"][0]["text"] == content
@@ -776,17 +774,17 @@ class TestGoogleAssistantContextAggregator(
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
def check_message_content(self, context: OpenAILLMContext, index: int, content: str):
obj = glm.Content.to_dict(context.messages[index])
obj = context.messages[index].to_json_dict()
assert obj["parts"][0]["text"] == content
def check_message_multi_content(
self, context: OpenAILLMContext, content_index: int, index: int, content: str
):
obj = glm.Content.to_dict(context.messages[index])
obj = context.messages[index].to_json_dict()
assert obj["parts"][0]["text"] == content
def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any):
obj = glm.Content.to_dict(context.messages[index])
obj = context.messages[index].to_json_dict()
assert obj["parts"][0]["function_response"]["response"]["value"] == json.dumps(content)

View File

@@ -9,7 +9,6 @@ import unittest
from pipecat.frames.frames import (
EndFrame,
Frame,
StartInterruptionFrame,
TextFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
@@ -58,8 +57,8 @@ class TestFrameFilter(unittest.IsolatedAsyncioTestCase):
async def test_system_frame(self):
filter = FrameFilter(types=())
frames_to_send = [StartInterruptionFrame()]
expected_down_frames = [StartInterruptionFrame]
frames_to_send = [UserStartedSpeakingFrame()]
expected_down_frames = [UserStartedSpeakingFrame]
await run_test(
filter,
frames_to_send=frames_to_send,