Compare commits

..

12 Commits

Author SHA1 Message Date
James Hush
858e305c40 Get the Daily session id 2025-05-12 09:39:25 +08:00
Mark Backman
20498fb47f Merge pull request #1790 from AngeloGiacco/angelo/fix-api-key
[elevenlabs tts ] fix api key
2025-05-10 19:16:27 -04:00
Angelo Giacco
b57dfb3b5d fix lint 2025-05-10 16:36:26 +01:00
Angelo Giacco
0355ed4aa1 move api key to ws header 2025-05-10 16:34:01 +01:00
Angelo Giacco
1e76cc7bdc fix: elevenlabs api key 2025-05-10 16:09:20 +01:00
Vanessa Pyne
18c0374126 Merge pull request #1785 from pipecat-ai/vp-small-filenmae-change
39-aws-nova-sonic.py -> 40-aws-nova-sonic.py
2025-05-09 12:19:09 -05:00
Aleix Conchillo Flaqué
7072fba7e7 Merge pull request #1780 from pipecat-ai/aleix/deprecate-google-generativeai
GoogleLLMService: deprecate google-generativeai
2025-05-09 09:18:30 -07:00
Aleix Conchillo Flaqué
3d702a5c39 minor examples cleanup 2025-05-09 09:16:10 -07:00
Aleix Conchillo Flaqué
f31efa42c9 GoogleLLMService: deprecate google-generativeai 2025-05-09 09:14:43 -07:00
vipyne
74b369ff20 39-aws-nova-sonic.py -> 40-aws-nova-sonic.py 2025-05-09 08:30:59 -05:00
kompfner
9643296e29 Merge pull request #1779 from pipecat-ai/pk/aws-nova-sonic-missing-params-export
Add missing `Params` export to AWS Nova Sonic module
2025-05-08 16:04:38 -04:00
Paul Kompfner
c83c5b5a34 Add missing Params export to AWS Nova Sonic module 2025-05-08 15:23:25 -04:00
15 changed files with 186 additions and 582 deletions

View File

@@ -5,6 +5,13 @@ All notable changes to **Pipecat** will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
### Changed
- `GoogleLLMService` has been updated to use `google-genai` instead of the
deprecated `google-generativeai`.
## [0.0.67] - 2025-05-07
### Added

View File

@@ -11,18 +11,17 @@ from pathlib import Path
from dotenv import load_dotenv
from loguru import logger
from openai import audio
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import Frame
from pipecat.observers.base_observer import BaseObserver, FramePushed
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.google.llm import GoogleLLMService, LLMSearchResponseFrame
from pipecat.services.llm_service import LLMService
from pipecat.transports.base_transport import TransportParams
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
@@ -33,7 +32,7 @@ load_dotenv(override=True)
# Function handlers for the LLM
search_tool = {"google_search_retrieval": {}}
search_tool = {"google_search": {}}
tools = [search_tool]
system_instruction = """
@@ -50,14 +49,22 @@ Start each interaction by asking the user about which place they would like to k
"""
class LLMSearchLoggerProcessor(FrameProcessor):
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
class LLMSearchLoggerObserver(BaseObserver):
async def on_push_frame(self, data: FramePushed):
src = data.source
dst = data.destination
frame = data.frame
timestamp = data.timestamp
if not isinstance(src, LLMService) and not isinstance(dst, LLMService):
return
time_sec = timestamp / 1_000_000_000
arrow = ""
if isinstance(frame, LLMSearchResponseFrame):
print(f"LLMSearchLoggerProcessor: {frame}")
await self.push_frame(frame)
logger.debug(f"🧠 {arrow} {dst} LLM SEARCH RESPONSE FRAME: {frame} at {time_sec:.2f}s")
async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespace):
@@ -84,7 +91,6 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
api_key=os.getenv("GOOGLE_API_KEY"),
system_instruction=system_instruction,
tools=tools,
model="gemini-1.5-flash-002",
)
context = OpenAILLMContext(
@@ -97,22 +103,23 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
)
context_aggregator = llm.create_context_aggregator(context)
llm_search_logger = LLMSearchLoggerProcessor()
pipeline = Pipeline(
[
transport.input(),
stt,
context_aggregator.user(),
llm,
llm_search_logger,
tts,
transport.output(),
context_aggregator.assistant(),
]
)
task = PipelineTask(pipeline, params=PipelineParams(allow_interruptions=True))
task = PipelineTask(
pipeline,
params=PipelineParams(allow_interruptions=True),
observers=[LLMSearchLoggerObserver()],
)
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):

View File

@@ -1,274 +0,0 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import argparse
import os
from abc import ABC, abstractmethod
from dataclasses import field
from typing import List, Literal, Optional
import httpx
from agents import Agent, Runner
from dotenv import load_dotenv
from loguru import logger
from openai import AsyncStream, BaseModel
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import (
Frame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesFrame,
LLMTextFrame,
LLMUpdateSettingsFrame,
VisionImageRawFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMUserAggregatorParams,
)
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.ai_service import AIService
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.openai.base_llm import BaseOpenAILLMService
from pipecat.services.openai.llm import (
OpenAIAssistantContextAggregator,
OpenAIContextAggregatorPair,
OpenAILLMService,
OpenAIUserContextAggregator,
)
from pipecat.transports.base_transport import TransportParams
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
load_dotenv(override=True)
class LlmMessage(BaseModel):
    """One chat message: who said it and what was said."""

    # Speaker of the message; limited to the four literal roles listed here.
    role: Literal["system", "user", "assistant", "tool"]
    # Text of the message; the annotation permits None.
    content: Optional[str]
class AgentResponse(BaseModel):
    """Reply returned by a backend: the reply text plus a list of messages."""

    content: str
    # NOTE(review): `field` is `dataclasses.field`, not a pydantic `Field`;
    # confirm the installed pydantic version honours it as a default factory.
    msgs: list[LlmMessage] = field(default_factory=list)
class BackendBase(ABC):
    """Abstract interface for a backend that answers a list of messages."""

    @abstractmethod
    async def get_resp(self, messages: list[LlmMessage], extra_params) -> AgentResponse:
        """Produce a response for ``messages``.

        Args:
            messages: Messages to answer.
            extra_params: Additional backend-specific parameters.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("The method get_resp is not implemented.")
class ChoiceDelta(BaseModel):
    """Incremental piece of a streamed message."""

    # The contents of the chunk message; defaults to None.
    content: Optional[str] = None
class Choice(BaseModel):
    """One choice of a streamed chunk: its delta and its position."""

    # Delta payload carried by this choice.
    delta: ChoiceDelta
    # The index of the choice in the list of choices.
    index: int
class CustomLLMService(BaseOpenAILLMService):
    """LLM service that streams text from an ``agents`` ``Agent`` run.

    The client stored in ``self._client`` is an ``Agent``. For each context
    received, the agent is run in streaming mode and every text delta is
    pushed as an ``LLMTextFrame``.
    """

    def __init__(self, **kwargs):
        """Initialize the base service, then install the agent as the client.

        Args:
            **kwargs: Forwarded unchanged to ``BaseOpenAILLMService``.
        """
        super().__init__(**kwargs)
        # Single source of truth for the agent configuration.
        self._client = self.create_client()

    def create_client(
        self,
        api_key=None,
        base_url=None,
        organization=None,
        project=None,
        default_headers=None,
        **kwargs,
    ):
        """Build the ``Agent`` used as this service's client.

        All parameters are accepted for signature compatibility and are not
        used when constructing the agent.

        Returns:
            Agent: A new agent instructed to respond with haikus.
        """
        return Agent(
            name="Assistant agent",
            instructions="Respond with haikus.",
            # tools=[get_weather],
        )

    def create_context_aggregator(
        self,
        context: OpenAILLMContext,
        *,
        user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
        assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
    ) -> OpenAIContextAggregatorPair:
        """Create an OpenAIContextAggregatorPair from an OpenAILLMContext.

        Constructor keyword arguments for both the user and assistant
        aggregators can be provided.

        Args:
            context (OpenAILLMContext): The LLM context.
            user_params (LLMUserAggregatorParams, optional): User aggregator parameters.
            assistant_params (LLMAssistantAggregatorParams, optional): Assistant
                aggregator parameters.

        Returns:
            OpenAIContextAggregatorPair: A pair of context aggregators, one for
            the user and one for the assistant, encapsulated in an
            OpenAIContextAggregatorPair.
        """
        context.set_llm_adapter(self.get_llm_adapter())
        user = OpenAIUserContextAggregator(context, params=user_params)
        assistant = OpenAIAssistantContextAggregator(context, params=assistant_params)
        return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)

    async def _process_context(self, context: OpenAILLMContext):
        """Run the agent on the context messages and push streamed text.

        Args:
            context: Context whose ``messages`` are given to the agent as input.
        """
        await self.start_ttfb_metrics()
        result = Runner.run_streamed(
            starting_agent=self._client,
            input=context.messages,
        )
        logger.info(f"get_chat_completions: {result}")
        if result is None:
            logger.error("Runner.run_streamed returned None")
            return

        async for event in result.stream_events():
            if event.type != "raw_response_event":
                continue
            if event.data.type == "response.output_text.delta":
                # The TTFB timer was started above; stop it once text arrives
                # so the measurement is actually completed.
                await self.stop_ttfb_metrics()
                await self.push_frame(LLMTextFrame(event.data.delta))

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Handle context/message frames; forward every other frame.

        Args:
            frame: Incoming frame.
            direction: Direction the frame is travelling in.
        """
        # NOTE(review): the base class `process_frame` may already dispatch
        # context frames to `_process_context` and forward other frames;
        # confirm this override does not handle each frame twice.
        await super().process_frame(frame, direction)

        context = None
        if isinstance(frame, OpenAILLMContextFrame):
            context = frame.context
        elif isinstance(frame, LLMMessagesFrame):
            context = OpenAILLMContext.from_messages(frame.messages)
        else:
            await self.push_frame(frame, direction)

        if context:
            try:
                await self.push_frame(LLMFullResponseStartFrame())
                await self.start_processing_metrics()
                await self._process_context(context)
            except httpx.TimeoutException:
                await self._call_event_handler("on_completion_timeout")
            finally:
                # Always close the response, even when processing failed.
                await self.stop_processing_metrics()
                await self.push_frame(LLMFullResponseEndFrame())
async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespace):
    """Assemble and run the voice pipeline for one WebRTC connection.

    The pipeline is: transport input, speech-to-text, user context
    aggregator, LLM, text-to-speech, transport output, assistant context
    aggregator. The pipeline task is cancelled when the client closes the
    connection.

    Args:
        webrtc_connection: Connection wrapped by the ``SmallWebRTCTransport``.
        _: Parsed command-line arguments; not used by this function.
    """
    logger.info("Starting bot")

    transport = SmallWebRTCTransport(
        webrtc_connection=webrtc_connection,
        params=TransportParams(
            audio_in_enabled=True,
            audio_out_enabled=True,
            vad_analyzer=SileroVADAnalyzer(),
        ),
    )

    stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))

    tts = CartesiaTTSService(
        api_key=os.getenv("CARTESIA_API_KEY"),
        voice_id="71a7ad14-091c-4e8e-a314-022ece01c121",  # British Reading Lady
    )

    llm = CustomLLMService(model="gpt-4.1", api_key=os.getenv("OPENAI_API_KEY"))

    messages = [
        {
            "role": "system",
            "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
        },
    ]

    context = OpenAILLMContext(messages=messages)
    context_aggregator = llm.create_context_aggregator(context)

    pipeline = Pipeline(
        [
            transport.input(),  # Transport user input
            stt,
            context_aggregator.user(),  # User responses
            llm,  # LLM
            tts,  # TTS
            transport.output(),  # Transport bot output
            context_aggregator.assistant(),  # Assistant spoken responses
        ]
    )

    task = PipelineTask(
        pipeline,
        params=PipelineParams(
            allow_interruptions=True,
            enable_metrics=True,
            enable_usage_metrics=True,
            report_only_initial_ttfb=True,
        ),
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        logger.info("Client connected")
        # Kick off the conversation.
        # messages.append({"role": "system", "content": "Please introduce yourself to the user."})
        # await task.queue_frames([context_aggregator.user().get_context_frame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        logger.info("Client disconnected")

    @transport.event_handler("on_client_closed")
    async def on_client_closed(transport, client):
        logger.info("Client closed connection")
        # Stop the pipeline once the client has closed the connection.
        await task.cancel()

    runner = PipelineRunner(handle_sigint=False)
    await runner.run(task)
if __name__ == "__main__":
    # `run` is imported here, inside the guard, so it is only loaded when this
    # file is executed as a script.
    from run import main

    main()

View File

@@ -1,122 +0,0 @@
import asyncio
import logging
import os
from datetime import datetime
from agents import (
Agent,
FunctionTool,
HandoffOutputItem,
ItemHelpers,
MessageOutputItem,
RunContextWrapper,
Runner,
ToolCallItem,
ToolCallOutputItem,
function_tool,
set_default_openai_api,
set_default_openai_client,
set_tracing_disabled,
trace,
)
from httpx import get
@function_tool
async def get_weather(location: str) -> str:
    """Fetch the weather for today.

    Args:
        location: The location to fetch the weather for.
    """
    # Stub: no lookup is performed; every location is reported as sunny.
    report = f"{location} is sunny"
    return report
# Instructions passed to the agent defined below.
system_prompt = """
you are a helpful assistant for a real estate brokerage AI assistant.
"""

# Agent used by `main`; the weather tool is left disabled.
bot = Agent(
    name="Assistant agent",
    instructions=system_prompt,
    # tools=[get_weather],
)
async def main():
    """Stream one agent run, printing events and the assembled final text."""
    # res = await Runner.run(
    #     starting_agent=bot,
    #     input="What is the weather today?",
    # )
    # print(res)
    result = Runner.run_streamed(
        starting_agent=bot,
        # ---
        # with func tool
        input="Tell a joke about pirates.",
        # ---
        # no func tool
        # input="give me a 2 sentences about life",
    )

    # Text deltas in arrival order; joined once at the end.
    final = []
    async for event in result.stream_events():
        if event.type == "raw_response_event":
            if event.data.type == "response.output_text.delta":
                final.append(event.data.delta)
                print(f"raw resp: {event}")
        # When the agent updates, print that
        elif event.type == "agent_updated_stream_event":
            print(f"Agent updated: {event.new_agent.name}")
            continue
        # When items are generated, print them
        elif event.type == "run_item_stream_event":
            item = event.item
            if item.type == "tool_call_item":
                print("-- Tool was called")
            elif item.type == "tool_call_output_item":
                print(f"-- Tool output: {item.output}")
            elif item.type == "message_output_item":
                print(f"-- Message output:\n {ItemHelpers.text_message_output(item)}")
            else:
                print(f"-- Unknown item type: {item.type}")
        else:
            # Only `run_item_stream_event` is known here to carry `.item`, so
            # report the event's own type instead of reading `event.item`.
            print(f"-- Unknown out item type: {event.type}")

    print("----------------------")
    print(f"FinalFinalFinal: {''.join(final)}")
if __name__ == "__main__":
    # Script entry point: run the async demo to completion.
    asyncio.run(main())
# no func tool:
#
# Event: agent_updated_stream_event - name None
# Event: raw_response_event - name None
# ...
# Event: raw_response_event - name None
# Event: run_item_stream_event - name message_output_created
# with func tool:
#
# Event: agent_updated_stream_event - name None
# Event: raw_response_event - name None
# ...
# Event: raw_response_event - name None
# Event: run_item_stream_event - name tool_called
# Event: run_item_stream_event - name tool_output
# Event: raw_response_event - name None
# ...
# Event: raw_response_event - name None
# Event: run_item_stream_event - name message_output_created

View File

@@ -102,9 +102,9 @@ async def main():
llm = GoogleLLMService(
api_key=os.getenv("GOOGLE_API_KEY"),
model="gemini-1.5-flash-002",
system_instruction=system_instruction,
tools=tools,
model="gemini-1.5-flash",
)
context = OpenAILLMContext(
@@ -153,7 +153,6 @@ async def main():
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
logger.debug("First participant joined: {}", participant["id"])
await transport.capture_participant_transcription(participant["id"])
@transport.event_handler("on_participant_left")
async def on_participant_left(transport, participant, reason):

View File

@@ -8,6 +8,7 @@ import {
} from '@pipecat-ai/client-js';
import { useRTVIClient, useRTVIClientEvent } from '@pipecat-ai/client-react';
import './DebugDisplay.css';
import { DailyTransport } from '@pipecat-ai/daily-transport';
export function DebugDisplay() {
const debugLogRef = useRef<HTMLDivElement>(null);
@@ -52,6 +53,17 @@ export function DebugDisplay() {
)
);
// Log connection events
useRTVIClientEvent(
RTVIEvent.Connected,
useCallback(() => {
if (!client) return;
const dailyCallClient = (client.transport as DailyTransport)
.dailyCallClient;
console.log(`Session ID: ${dailyCallClient.meetingSessionSummary().id}`);
}, [client])
);
useRTVIClientEvent(
RTVIEvent.BotDisconnected,
useCallback(

View File

@@ -187,7 +187,7 @@ async def main():
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
await transport.capture_participant_transcription(participant["id"])
print(f"Participant joined: {participant}")
@transport.event_handler("on_participant_left")
async def on_participant_left(transport, participant, reason):

View File

@@ -215,6 +215,7 @@ async def main():
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
print(f"Participant joined: {participant}")
await transport.capture_participant_transcription(participant["id"])
@transport.event_handler("on_participant_left")

View File

@@ -30,7 +30,7 @@ from loguru import logger
from pipecatcloud.agent import DailySessionArguments
from word_list import generate_game_words
from pipecat.audio.resamplers.soxr_resampler import SOXRAudioResampler
from pipecat.audio.utils import create_default_resampler
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
@@ -524,7 +524,7 @@ async def tts_audio_raw_frame_filter(frame: Frame):
# Create a resampler instance once
resampler = SOXRAudioResampler()
resampler = create_default_resampler()
async def tts_to_input_audio_transformer(frame: Frame):
@@ -689,8 +689,6 @@ Important guidelines:
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
logger.info("First participant joined: {}", participant["id"])
# Capture the participant's transcription
await transport.capture_participant_transcription(participant["id"])
# Kick off the conversation
await task.queue_frames([context_aggregator.user().get_context_frame()])
# Start the game timer

View File

@@ -54,7 +54,7 @@ fal = [ "fal-client~=0.5.9" ]
fireworks = []
fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ]
gladia = [ "websockets~=13.1" ]
google = [ "google-cloud-speech~=2.31.1", "google-cloud-texttospeech~=2.25.1", "google-genai~=1.7.0", "google-generativeai~=0.8.4", "websockets~=13.1" ]
google = [ "google-cloud-speech~=2.32.0", "google-cloud-texttospeech~=2.26.0", "google-genai~=1.14.0", "websockets~=13.1" ]
grok = []
groq = [ "groq~=0.23.0" ]
gstreamer = [ "pygobject~=3.50.0" ]

View File

@@ -1 +1 @@
from .aws import AWSNovaSonicLLMService
from .aws import AWSNovaSonicLLMService, Params

View File

@@ -334,7 +334,9 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
)
# Set max websocket message size to 16MB for large audio responses
self._websocket = await websockets.connect(url, max_size=16 * 1024 * 1024)
self._websocket = await websockets.connect(
url, max_size=16 * 1024 * 1024, extra_headers={"xi-api-key": self._api_key}
)
except Exception as e:
logger.error(f"{self} initialization error: {e}")
@@ -425,7 +427,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
if self._websocket:
if not self._context_id:
# First message for a new context - need a space to initialize
msg = {"text": " ", "context_id": str(uuid.uuid4()), "xi_api_key": self._api_key}
msg = {"text": " ", "context_id": str(uuid.uuid4())}
# Add voice settings only in first message for a context
if self._voice_settings:

View File

@@ -52,10 +52,16 @@ from pipecat.services.openai.llm import (
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
try:
import google.ai.generativelanguage as glm
import google.generativeai as gai
from google import genai
from google.api_core.exceptions import DeadlineExceeded
from google.generativeai.types import GenerationConfig
from google.genai.types import (
Blob,
Content,
FunctionCall,
FunctionResponse,
GenerateContentConfig,
Part,
)
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
@@ -65,9 +71,7 @@ except ModuleNotFoundError as e:
class GoogleUserContextAggregator(OpenAIUserContextAggregator):
async def push_aggregation(self):
if len(self._aggregation) > 0:
self._context.add_message(
glm.Content(role="user", parts=[glm.Part(text=self._aggregation)])
)
self._context.add_message(Content(role="user", parts=[Part(text=self._aggregation)]))
# Reset the aggregation. Reset it before pushing it down, otherwise
# if the tasks gets cancelled we won't be able to clear things up.
@@ -83,15 +87,15 @@ class GoogleUserContextAggregator(OpenAIUserContextAggregator):
class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
async def handle_aggregation(self, aggregation: str):
self._context.add_message(glm.Content(role="model", parts=[glm.Part(text=aggregation)]))
self._context.add_message(Content(role="model", parts=[Part(text=aggregation)]))
async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
self._context.add_message(
glm.Content(
Content(
role="model",
parts=[
glm.Part(
function_call=glm.FunctionCall(
Part(
function_call=FunctionCall(
id=frame.tool_call_id, name=frame.function_name, args=frame.arguments
)
)
@@ -99,11 +103,11 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
)
)
self._context.add_message(
glm.Content(
Content(
role="user",
parts=[
glm.Part(
function_response=glm.FunctionResponse(
Part(
function_response=FunctionResponse(
id=frame.tool_call_id,
name=frame.function_name,
response={"response": "IN_PROGRESS"},
@@ -187,7 +191,7 @@ class GoogleLLMContext(OpenAILLMContext):
# Convert each message individually
converted_messages = []
for msg in messages:
if isinstance(msg, glm.Content):
if isinstance(msg, Content):
# Already in Gemini format
converted_messages.append(msg)
else:
@@ -202,7 +206,7 @@ class GoogleLLMContext(OpenAILLMContext):
def get_messages_for_logging(self):
msgs = []
for message in self.messages:
obj = glm.Content.to_dict(message)
obj = message.to_json_dict()
try:
if "parts" in obj:
for part in obj["parts"]:
@@ -221,10 +225,10 @@ class GoogleLLMContext(OpenAILLMContext):
parts = []
if text:
parts.append(glm.Part(text=text))
parts.append(glm.Part(inline_data=glm.Blob(mime_type="image/jpeg", data=buffer.getvalue())))
parts.append(Part(text=text))
parts.append(Part(inline_data=Blob(mime_type="image/jpeg", data=buffer.getvalue())))
self.add_message(glm.Content(role="user", parts=parts))
self.add_message(Content(role="user", parts=parts))
def add_audio_frames_message(
self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows"
@@ -239,10 +243,10 @@ class GoogleLLMContext(OpenAILLMContext):
data = b"".join(frame.audio for frame in audio_frames)
# NOTE(aleix): According to the docs only text or inline_data should be needed.
# (see https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference)
parts.append(glm.Part(text=text))
parts.append(Part(text=text))
parts.append(
glm.Part(
inline_data=glm.Blob(
Part(
inline_data=Blob(
mime_type="audio/wav",
data=(
bytes(
@@ -252,7 +256,7 @@ class GoogleLLMContext(OpenAILLMContext):
)
),
)
self.add_message(glm.Content(role="user", parts=parts))
self.add_message(Content(role="user", parts=parts))
# message = {"mime_type": "audio/mp3", "data": bytes(data + create_wav_header(sample_rate, num_channels, 16, len(data)))}
# self.add_message(message)
@@ -271,7 +275,7 @@ class GoogleLLMContext(OpenAILLMContext):
}
Returns:
glm.Content object with:
Content object with:
- role: "user" or "model" (converted from "assistant")
- parts: List[Part] containing text, inline_data, or function calls
Returns None for system messages.
@@ -288,8 +292,8 @@ class GoogleLLMContext(OpenAILLMContext):
if message.get("tool_calls"):
for tc in message["tool_calls"]:
parts.append(
glm.Part(
function_call=glm.FunctionCall(
Part(
function_call=FunctionCall(
name=tc["function"]["name"],
args=json.loads(tc["function"]["arguments"]),
)
@@ -298,30 +302,30 @@ class GoogleLLMContext(OpenAILLMContext):
elif role == "tool":
role = "model"
parts.append(
glm.Part(
function_response=glm.FunctionResponse(
Part(
function_response=FunctionResponse(
name="tool_call_result", # seems to work to hard-code the same name every time
response=json.loads(message["content"]),
)
)
)
elif isinstance(content, str):
parts.append(glm.Part(text=content))
parts.append(Part(text=content))
elif isinstance(content, list):
for c in content:
if c["type"] == "text":
parts.append(glm.Part(text=c["text"]))
parts.append(Part(text=c["text"]))
elif c["type"] == "image_url":
parts.append(
glm.Part(
inline_data=glm.Blob(
Part(
inline_data=Blob(
mime_type="image/jpeg",
data=base64.b64decode(c["image_url"]["url"].split(",")[1]),
)
)
)
message = glm.Content(role=role, parts=parts)
message = Content(role=role, parts=parts)
return message
def to_standard_messages(self, obj) -> list:
@@ -409,7 +413,7 @@ class GoogleLLMContext(OpenAILLMContext):
# Process each message, preserving Google-formatted messages and converting others
for message in self._messages:
if isinstance(message, glm.Content):
if isinstance(message, Content):
# Keep existing Google-formatted messages (e.g., function calls/responses)
converted_messages.append(message)
continue
@@ -433,9 +437,7 @@ class GoogleLLMContext(OpenAILLMContext):
# Add system message back as a user message if we only have function messages
if self.system_message and not has_regular_messages:
self._messages.append(
glm.Content(role="user", parts=[glm.Part(text=self.system_message)])
)
self._messages.append(Content(role="user", parts=[Part(text=self.system_message)]))
# Remove any empty messages
self._messages = [m for m in self._messages if m.parts]
@@ -463,7 +465,7 @@ class GoogleLLMService(LLMService):
self,
*,
api_key: str,
model: str = "gemini-2.0-flash-001",
model: str = "gemini-2.0-flash",
params: InputParams = InputParams(),
system_instruction: Optional[str] = None,
tools: Optional[List[Dict[str, Any]]] = None,
@@ -471,10 +473,10 @@ class GoogleLLMService(LLMService):
**kwargs,
):
super().__init__(**kwargs)
gai.configure(api_key=api_key)
self.set_model_name(model)
self._api_key = api_key
self._system_instruction = system_instruction
self._create_client()
self._create_client(api_key)
self._settings = {
"max_tokens": params.max_tokens,
"temperature": params.temperature,
@@ -488,10 +490,8 @@ class GoogleLLMService(LLMService):
def can_generate_metrics(self) -> bool:
return True
def _create_client(self):
self._client = gai.GenerativeModel(
self._model_name, system_instruction=self._system_instruction
)
def _create_client(self, api_key: str):
self._client = genai.Client(api_key=api_key)
async def _process_context(self, context: OpenAILLMContext):
await self.push_frame(LLMFullResponseStartFrame())
@@ -513,23 +513,7 @@ class GoogleLLMService(LLMService):
if context.system_message and self._system_instruction != context.system_message:
logger.debug(f"System instruction changed: {context.system_message}")
self._system_instruction = context.system_message
self._create_client()
# Filter out None values and create GenerationConfig
generation_params = {
k: v
for k, v in {
"temperature": self._settings["temperature"],
"top_p": self._settings["top_p"],
"top_k": self._settings["top_k"],
"max_output_tokens": self._settings["max_tokens"],
}.items()
if v is not None
}
generation_config = GenerationConfig(**generation_params) if generation_params else None
await self.start_ttfb_metrics()
tools = []
if context.tools:
tools = context.tools
@@ -538,112 +522,104 @@ class GoogleLLMService(LLMService):
tool_config = None
if self._tool_config:
tool_config = self._tool_config
response = await self._client.generate_content_async(
# Filter out None values and create GenerationContentConfig
generation_params = {
k: v
for k, v in {
"system_instruction": self._system_instruction,
"temperature": self._settings["temperature"],
"top_p": self._settings["top_p"],
"top_k": self._settings["top_k"],
"max_output_tokens": self._settings["max_tokens"],
"tools": tools,
"tool_config": tool_config,
}.items()
if v is not None
}
generation_config = (
GenerateContentConfig(**generation_params) if generation_params else None
)
await self.start_ttfb_metrics()
response = await self._client.aio.models.generate_content_stream(
model=self._model_name,
contents=messages,
tools=tools,
stream=True,
generation_config=generation_config,
tool_config=tool_config,
config=generation_config,
)
await self.stop_ttfb_metrics()
if response.usage_metadata:
# Use only the prompt token count from the response object
prompt_tokens = response.usage_metadata.prompt_token_count
total_tokens = prompt_tokens
async for chunk in response:
if chunk.usage_metadata:
# Use only the completion_tokens from the chunks. Prompt tokens are already counted and
# are repeated here.
completion_tokens += chunk.usage_metadata.candidates_token_count
total_tokens += chunk.usage_metadata.candidates_token_count
try:
for c in chunk.parts:
if c.text:
search_result += c.text
await self.push_frame(LLMTextFrame(c.text))
elif c.function_call:
logger.debug(f"Function call: {c.function_call}")
args = type(c.function_call).to_dict(c.function_call).get("args", {})
await self.call_function(
context=context,
tool_call_id=str(uuid.uuid4()),
function_name=c.function_call.name,
arguments=args,
)
# Handle grounding metadata
# It seems only the last chunk that we receive may contain this information
# If the response doesn't include groundingMetadata, this means the response wasn't grounded.
if chunk.candidates:
for candidate in chunk.candidates:
# logger.debug(f"candidate received: {candidate}")
# Extract grounding metadata
grounding_metadata = (
{
"rendered_content": getattr(
getattr(candidate, "grounding_metadata", None),
"search_entry_point",
None,
).rendered_content
if hasattr(
getattr(candidate, "grounding_metadata", None),
"search_entry_point",
)
else None,
"origins": [
{
"site_uri": getattr(grounding_chunk.web, "uri", None),
"site_title": getattr(
grounding_chunk.web, "title", None
),
"results": [
{
"text": getattr(
grounding_support.segment, "text", ""
),
"confidence": getattr(
grounding_support, "confidence_scores", None
),
}
for grounding_support in getattr(
getattr(candidate, "grounding_metadata", None),
"grounding_supports",
[],
)
if index
in getattr(
grounding_support, "grounding_chunk_indices", []
)
],
}
for index, grounding_chunk in enumerate(
getattr(
getattr(candidate, "grounding_metadata", None),
"grounding_chunks",
[],
)
)
],
}
if getattr(candidate, "grounding_metadata", None)
else None
)
except Exception as e:
# Google LLMs seem to flag safety issues a lot!
if chunk.candidates[0].finish_reason == 3:
logger.debug(
f"LLM refused to generate content for safety reasons - {messages}."
)
else:
logger.exception(f"{self} error: {e}")
prompt_tokens += chunk.usage_metadata.prompt_token_count or 0
completion_tokens += chunk.usage_metadata.candidates_token_count or 0
total_tokens += chunk.usage_metadata.total_token_count or 0
if not chunk.candidates:
continue
for candidate in chunk.candidates:
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
if not part.thought and part.text:
search_result += part.text
await self.push_frame(LLMTextFrame(part.text))
elif part.function_call:
function_call = part.function_call
id = function_call.id or str(uuid.uuid4())
logger.debug(f"Function call: {function_call.name}:{id}")
await self.call_function(
context=context,
tool_call_id=id,
function_name=function_call.name,
arguments=function_call.args or {},
)
if (
candidate.grounding_metadata
and candidate.grounding_metadata.grounding_chunks
):
m = candidate.grounding_metadata
rendered_content = (
m.search_entry_point.rendered_content if m.search_entry_point else None
)
origins = [
{
"site_uri": grounding_chunk.web.uri
if grounding_chunk.web
else None,
"site_title": grounding_chunk.web.title
if grounding_chunk.web
else None,
"results": [
{
"text": grounding_support.segment.text
if grounding_support.segment
else "",
"confidence": grounding_support.confidence_scores,
}
for grounding_support in (
m.grounding_supports if m.grounding_supports else []
)
if grounding_support.grounding_chunk_indices
and index in grounding_support.grounding_chunk_indices
],
}
for index, grounding_chunk in enumerate(
m.grounding_chunks if m.grounding_chunks else []
)
]
grounding_metadata = {
"rendered_content": rendered_content,
"origins": origins,
}
except DeadlineExceeded:
await self._call_event_handler("on_completion_timeout")
except Exception as e:
logger.exception(f"{self} exception: {e}")
finally:
if grounding_metadata is not None and isinstance(grounding_metadata, dict):
if grounding_metadata and isinstance(grounding_metadata, dict):
llm_search_frame = LLMSearchResponseFrame(
search_result=search_result,
origins=grounding_metadata["origins"],

View File

@@ -8,8 +8,6 @@ import json
import unittest
from typing import Any
import google.ai.generativelanguage as glm
from pipecat.frames.frames import (
EmulateUserStartedSpeakingFrame,
EmulateUserStoppedSpeakingFrame,
@@ -758,13 +756,13 @@ class TestGoogleUserContextAggregator(
AGGREGATOR_CLASS = GoogleUserContextAggregator
def check_message_content(self, context: OpenAILLMContext, index: int, content: str):
obj = glm.Content.to_dict(context.messages[index])
obj = context.messages[index].to_json_dict()
assert obj["parts"][0]["text"] == content
def check_message_multi_content(
self, context: OpenAILLMContext, content_index: int, index: int, content: str
):
obj = glm.Content.to_dict(context.messages[index])
obj = context.messages[index].to_json_dict()
assert obj["parts"][0]["text"] == content
@@ -776,17 +774,17 @@ class TestGoogleAssistantContextAggregator(
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
def check_message_content(self, context: OpenAILLMContext, index: int, content: str):
obj = glm.Content.to_dict(context.messages[index])
obj = context.messages[index].to_json_dict()
assert obj["parts"][0]["text"] == content
def check_message_multi_content(
self, context: OpenAILLMContext, content_index: int, index: int, content: str
):
obj = glm.Content.to_dict(context.messages[index])
obj = context.messages[index].to_json_dict()
assert obj["parts"][0]["text"] == content
def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any):
obj = glm.Content.to_dict(context.messages[index])
obj = context.messages[index].to_json_dict()
assert obj["parts"][0]["function_response"]["response"]["value"] == json.dumps(content)