From e97de43de26b1743c4b3da4a65d3c76527d35684 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Mon, 14 Apr 2025 13:08:13 -0700 Subject: [PATCH 1/2] add LLMUserAggregatorParams and LLMAssistantAggregatorParams --- CHANGELOG.md | 12 ++- .../22d-natural-conversation-gemini-audio.py | 7 +- .../processors/aggregators/llm_response.py | 75 ++++++++++++++++--- src/pipecat/services/anthropic/llm.py | 27 +++---- .../services/gemini_multimodal_live/gemini.py | 27 ++++--- src/pipecat/services/google/llm.py | 38 +++++----- src/pipecat/services/grok/llm.py | 23 +++--- src/pipecat/services/llm_service.py | 10 ++- src/pipecat/services/openai/llm.py | 20 +++-- .../services/openai_realtime_beta/openai.py | 42 +++++------ tests/test_context_aggregators.py | 57 ++++++++++---- tests/test_langchain.py | 5 +- 12 files changed, 218 insertions(+), 125 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f8eb48217..80b1369d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `DeepgramTTSService` accepts `base_url` argument again, allowing you to connect to an on-prem service. +- Added `LLMUserAggregatorParams` and `LLMAssistantAggregatorParams` which allow + you to control aggregator settings. You can now pass these arguments when + creating aggregator pairs with `create_context_aggregator()`. + - It is now possible to disable `SoundfileMixer` when created. You can then use `MixerEnableFrame` to dynamically enable it when necessary. @@ -38,6 +42,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `DeepgramSTTService` parameter `url` is now deprecated, use `base_url` instead. +### Removed + +- Parameters `user_kwargs` and `assistant_kwargs` when creating a context + aggregator pair using `create_context_aggregator()` have been removed. Use + `user_params` and `assistant_params` instead. 
+ ### Fixed - Fixed a `TavusVideoService` issue that was causing audio choppiness. @@ -45,7 +55,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed an issue in `SmallWebRTCTransport` where an error was thrown if the client did not create a video transceiver. -- Fixed an issue where LLM input parameters were not working and applied correctly in `GoogleVertexLLMService`, causing +- Fixed an issue where LLM input parameters were not working and applied correctly in `GoogleVertexLLMService`, causing unexpected behavior during inference. ## [0.0.63] - 2025-04-11 diff --git a/examples/foundational/22d-natural-conversation-gemini-audio.py b/examples/foundational/22d-natural-conversation-gemini-audio.py index a0035e11c..40bc33e48 100644 --- a/examples/foundational/22d-natural-conversation-gemini-audio.py +++ b/examples/foundational/22d-natural-conversation-gemini-audio.py @@ -33,7 +33,10 @@ from pipecat.pipeline.parallel_pipeline import ParallelPipeline from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask -from pipecat.processors.aggregators.llm_response import LLMAssistantResponseAggregator +from pipecat.processors.aggregators.llm_response import ( + LLMAssistantAggregatorParams, + LLMAssistantResponseAggregator, +) from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, OpenAILLMContextFrame, @@ -478,7 +481,7 @@ class LLMAggregatorBuffer(LLMAssistantResponseAggregator): """Buffers the output of the transcription LLM. 
Used by the bot output gate.""" def __init__(self, **kwargs): - super().__init__(expect_stripped_words=False) + super().__init__(params=LLMAssistantAggregatorParams(expect_stripped_words=False)) self._transcription = "" async def process_frame(self, frame: Frame, direction: FrameDirection): diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index dccceea1f..1f7a33ed1 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -6,6 +6,7 @@ import asyncio from abc import abstractmethod +from dataclasses import dataclass from typing import Dict, List, Literal, Set from loguru import logger @@ -46,6 +47,16 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor from pipecat.utils.time import time_now_iso8601 +@dataclass +class LLMUserAggregatorParams: + aggregation_timeout: float = 1.0 + + +@dataclass +class LLMAssistantAggregatorParams: + expect_stripped_words: bool = True + + class LLMFullResponseAggregator(FrameProcessor): """This is an LLM aggregator that aggregates a full LLM completion. 
It aggregates LLM text frames (tokens) received between @@ -230,11 +241,23 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): def __init__( self, context: OpenAILLMContext, - aggregation_timeout: float = 1.0, + *, + params: LLMUserAggregatorParams = LLMUserAggregatorParams(), **kwargs, ): super().__init__(context=context, role="user", **kwargs) - self._aggregation_timeout = aggregation_timeout + self._params = params + if "aggregation_timeout" in kwargs: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "Parameter 'aggregation_timeout' is deprecated, use 'params' instead.", + DeprecationWarning, + ) + + self._params.aggregation_timeout = kwargs["aggregation_timeout"] self._seen_interim_results = False self._user_speaking = False @@ -357,7 +380,9 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): async def _aggregation_task_handler(self): while True: try: - await asyncio.wait_for(self._aggregation_event.wait(), self._aggregation_timeout) + await asyncio.wait_for( + self._aggregation_event.wait(), self._params.aggregation_timeout + ) await self._maybe_push_bot_interruption() except asyncio.TimeoutError: if not self._user_speaking: @@ -394,9 +419,27 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator): """ - def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True, **kwargs): + def __init__( + self, + context: OpenAILLMContext, + *, + params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), + **kwargs, + ): super().__init__(context=context, role="assistant", **kwargs) - self._expect_stripped_words = expect_stripped_words + self._params = params + + if "expect_stripped_words" in kwargs: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "Parameter 'expect_stripped_words' is deprecated, use 'params' instead.", + DeprecationWarning, + ) + + self._params.expect_stripped_words = 
kwargs["expect_stripped_words"] self._started = 0 self._function_calls_in_progress: Dict[str, FunctionCallInProgressFrame] = {} @@ -558,7 +601,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator): if not self._started: return - if self._expect_stripped_words: + if self._params.expect_stripped_words: self._aggregation += f" {frame.text}" if self._aggregation else frame.text else: self._aggregation += frame.text @@ -572,8 +615,14 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator): class LLMUserResponseAggregator(LLMUserContextAggregator): - def __init__(self, messages: List[dict] = [], **kwargs): - super().__init__(context=OpenAILLMContext(messages), **kwargs) + def __init__( + self, + messages: List[dict] = [], + *, + params: LLMUserAggregatorParams = LLMUserAggregatorParams(), + **kwargs, + ): + super().__init__(context=OpenAILLMContext(messages), params=params, **kwargs) async def push_aggregation(self): if len(self._aggregation) > 0: @@ -588,8 +637,14 @@ class LLMUserResponseAggregator(LLMUserContextAggregator): class LLMAssistantResponseAggregator(LLMAssistantContextAggregator): - def __init__(self, messages: List[dict] = [], **kwargs): - super().__init__(context=OpenAILLMContext(messages), **kwargs) + def __init__( + self, + messages: List[dict] = [], + *, + params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), + **kwargs, + ): + super().__init__(context=OpenAILLMContext(messages), params=params, **kwargs) async def push_aggregation(self): if len(self._aggregation) > 0: diff --git a/src/pipecat/services/anthropic/llm.py b/src/pipecat/services/anthropic/llm.py index 9e75e198b..277e29f83 100644 --- a/src/pipecat/services/anthropic/llm.py +++ b/src/pipecat/services/anthropic/llm.py @@ -11,7 +11,7 @@ import io import json import re from dataclasses import dataclass -from typing import Any, Dict, List, Mapping, Optional, Union +from typing import Any, Dict, List, Optional, Union import httpx from loguru import 
logger @@ -35,7 +35,9 @@ from pipecat.frames.frames import ( ) from pipecat.metrics.metrics import LLMTokenUsage from pipecat.processors.aggregators.llm_response import ( + LLMAssistantAggregatorParams, LLMAssistantContextAggregator, + LLMUserAggregatorParams, LLMUserContextAggregator, ) from pipecat.processors.aggregators.openai_llm_context import ( @@ -49,10 +51,7 @@ try: from anthropic import NOT_GIVEN, AsyncAnthropic, NotGiven except ModuleNotFoundError as e: logger.error(f"Exception: {e}") - logger.error( - "In order to use Anthropic, you need to `pip install pipecat-ai[anthropic]`. " - + "Also, set `ANTHROPIC_API_KEY` environment variable." - ) + logger.error("In order to use Anthropic, you need to `pip install pipecat-ai[anthropic]`.") raise Exception(f"Missing module: {e}") @@ -120,8 +119,8 @@ class AnthropicLLMService(LLMService): self, context: OpenAILLMContext, *, - user_kwargs: Mapping[str, Any] = {}, - assistant_kwargs: Mapping[str, Any] = {}, + user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(), + assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), ) -> AnthropicContextAggregatorPair: """Create an instance of AnthropicContextAggregatorPair from an OpenAILLMContext. Constructor keyword arguments for both the user and @@ -129,12 +128,10 @@ class AnthropicLLMService(LLMService): Args: context (OpenAILLMContext): The LLM context. - user_kwargs (Mapping[str, Any], optional): Additional keyword - arguments for the user context aggregator constructor. Defaults - to an empty mapping. - assistant_kwargs (Mapping[str, Any], optional): Additional keyword - arguments for the assistant context aggregator - constructor. Defaults to an empty mapping. + user_params (LLMUserAggregatorParams, optional): User aggregator + parameters. + assistant_params (LLMAssistantAggregatorParams, optional): Assistant + aggregator parameters.
Returns: AnthropicContextAggregatorPair: A pair of context aggregators, one @@ -146,8 +143,8 @@ class AnthropicLLMService(LLMService): if isinstance(context, OpenAILLMContext): context = AnthropicLLMContext.from_openai_context(context) - user = AnthropicUserContextAggregator(context, **user_kwargs) - assistant = AnthropicAssistantContextAggregator(context, **assistant_kwargs) + user = AnthropicUserContextAggregator(context, params=user_params) + assistant = AnthropicAssistantContextAggregator(context, params=assistant_params) return AnthropicContextAggregatorPair(_user=user, _assistant=assistant) async def _process_context(self, context: OpenAILLMContext): diff --git a/src/pipecat/services/gemini_multimodal_live/gemini.py b/src/pipecat/services/gemini_multimodal_live/gemini.py index 1051d2f60..f89e97bdb 100644 --- a/src/pipecat/services/gemini_multimodal_live/gemini.py +++ b/src/pipecat/services/gemini_multimodal_live/gemini.py @@ -45,6 +45,10 @@ from pipecat.frames.frames import ( UserStoppedSpeakingFrame, ) from pipecat.metrics.metrics import LLMTokenUsage +from pipecat.processors.aggregators.llm_response import ( + LLMAssistantAggregatorParams, + LLMUserAggregatorParams, +) from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, OpenAILLMContextFrame, @@ -871,8 +875,8 @@ class GeminiMultimodalLiveLLMService(LLMService): self, context: OpenAILLMContext, *, - user_kwargs: Mapping[str, Any] = {}, - assistant_kwargs: Mapping[str, Any] = {}, + user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(), + assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), ) -> GeminiMultimodalLiveContextAggregatorPair: """Create an instance of GeminiMultimodalLiveContextAggregatorPair from an OpenAILLMContext. Constructor keyword arguments for both the user and @@ -880,12 +884,10 @@ class GeminiMultimodalLiveLLMService(LLMService): Args: context (OpenAILLMContext): The LLM context. 
- user_kwargs (Mapping[str, Any], optional): Additional keyword - arguments for the user context aggregator constructor. Defaults - to an empty mapping. - assistant_kwargs (Mapping[str, Any], optional): Additional keyword - arguments for the assistant context aggregator - constructor. Defaults to an empty mapping. + user_params (LLMUserAggregatorParams, optional): User aggregator + parameters. + assistant_params (LLMAssistantAggregatorParams, optional): Assistant + aggregator parameters. Returns: GeminiMultimodalLiveContextAggregatorPair: A pair of context @@ -896,11 +898,8 @@ class GeminiMultimodalLiveLLMService(LLMService): context.set_llm_adapter(self.get_llm_adapter()) GeminiMultimodalLiveContext.upgrade(context) - user = GeminiMultimodalLiveUserContextAggregator(context, **user_kwargs) + user = GeminiMultimodalLiveUserContextAggregator(context, params=user_params) - default_assistant_kwargs = {"expect_stripped_words": True} - default_assistant_kwargs.update(assistant_kwargs) - assistant = GeminiMultimodalLiveAssistantContextAggregator( - context, **default_assistant_kwargs - ) + assistant_params.expect_stripped_words = True + assistant = GeminiMultimodalLiveAssistantContextAggregator(context, params=assistant_params) return GeminiMultimodalLiveContextAggregatorPair(_user=user, _assistant=assistant) diff --git a/src/pipecat/services/google/llm.py b/src/pipecat/services/google/llm.py index a9dd0cb3a..bf9714817 100644 --- a/src/pipecat/services/google/llm.py +++ b/src/pipecat/services/google/llm.py @@ -9,21 +9,14 @@ import io import json import os import uuid - -from google.api_core.exceptions import DeadlineExceeded - -from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter - -# Suppress gRPC fork warnings -os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false" - from dataclasses import dataclass -from typing import Any, Dict, List, Mapping, Optional, Union +from typing import Any, Dict, List, Optional from loguru import logger from PIL import Image from
pydantic import BaseModel, Field +from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter from pipecat.frames.frames import ( AudioRawFrame, Frame, @@ -39,6 +32,10 @@ from pipecat.frames.frames import ( VisionImageRawFrame, ) from pipecat.metrics.metrics import LLMTokenUsage +from pipecat.processors.aggregators.llm_response import ( + LLMAssistantAggregatorParams, + LLMUserAggregatorParams, +) from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, OpenAILLMContextFrame, @@ -51,11 +48,14 @@ from pipecat.services.openai.llm import ( OpenAIUserContextAggregator, ) +# Suppress gRPC fork warnings +os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false" + try: import google.ai.generativelanguage as glm import google.generativeai as gai + from google.api_core.exceptions import DeadlineExceeded from google.generativeai.types import GenerationConfig - except ModuleNotFoundError as e: logger.error(f"Exception: {e}") logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.") @@ -686,8 +686,8 @@ class GoogleLLMService(LLMService): self, context: OpenAILLMContext, *, - user_kwargs: Mapping[str, Any] = {}, - assistant_kwargs: Mapping[str, Any] = {}, + user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(), + assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), ) -> GoogleContextAggregatorPair: """Create an instance of GoogleContextAggregatorPair from an OpenAILLMContext. Constructor keyword arguments for both the user and @@ -695,12 +695,10 @@ class GoogleLLMService(LLMService): Args: context (OpenAILLMContext): The LLM context. - user_kwargs (Mapping[str, Any], optional): Additional keyword - arguments for the user context aggregator constructor. Defaults - to an empty mapping. - assistant_kwargs (Mapping[str, Any], optional): Additional keyword - arguments for the assistant context aggregator - constructor. Defaults to an empty mapping. 
+ user_params (LLMUserAggregatorParams, optional): User aggregator + parameters. + assistant_params (LLMAssistantAggregatorParams, optional): Assistant + aggregator parameters. Returns: GoogleContextAggregatorPair: A pair of context aggregators, one for @@ -712,6 +710,6 @@ class GoogleLLMService(LLMService): if isinstance(context, OpenAILLMContext): context = GoogleLLMContext.upgrade_to_google(context) - user = GoogleUserContextAggregator(context, **user_kwargs) - assistant = GoogleAssistantContextAggregator(context, **assistant_kwargs) + user = GoogleUserContextAggregator(context, params=user_params) + assistant = GoogleAssistantContextAggregator(context, params=assistant_params) return GoogleContextAggregatorPair(_user=user, _assistant=assistant) diff --git a/src/pipecat/services/grok/llm.py b/src/pipecat/services/grok/llm.py index 57517eb3e..90c8df14f 100644 --- a/src/pipecat/services/grok/llm.py +++ b/src/pipecat/services/grok/llm.py @@ -5,11 +5,14 @@ # from dataclasses import dataclass -from typing import Any, Mapping from loguru import logger from pipecat.metrics.metrics import LLMTokenUsage +from pipecat.processors.aggregators.llm_response import ( + LLMAssistantAggregatorParams, + LLMUserAggregatorParams, +) from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext from pipecat.services.openai.llm import ( OpenAIAssistantContextAggregator, @@ -124,8 +127,8 @@ class GrokLLMService(OpenAILLMService): self, context: OpenAILLMContext, *, - user_kwargs: Mapping[str, Any] = {}, - assistant_kwargs: Mapping[str, Any] = {}, + user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(), + assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), ) -> GrokContextAggregatorPair: """Create an instance of GrokContextAggregatorPair from an OpenAILLMContext. Constructor keyword arguments for both the user and @@ -133,12 +136,10 @@ class GrokLLMService(OpenAILLMService): Args: context (OpenAILLMContext): The LLM context.
- user_kwargs (Mapping[str, Any], optional): Additional keyword - arguments for the user context aggregator constructor. Defaults - to an empty mapping. - assistant_kwargs (Mapping[str, Any], optional): Additional keyword - arguments for the assistant context aggregator - constructor. Defaults to an empty mapping. + user_params (LLMUserAggregatorParams, optional): User aggregator + parameters. + assistant_params (LLMAssistantAggregatorParams, optional): Assistant + aggregator parameters. Returns: GrokContextAggregatorPair: A pair of context aggregators, one for @@ -148,6 +149,6 @@ class GrokLLMService(OpenAILLMService): """ context.set_llm_adapter(self.get_llm_adapter()) - user = OpenAIUserContextAggregator(context, **user_kwargs) - assistant = OpenAIAssistantContextAggregator(context, **assistant_kwargs) + user = OpenAIUserContextAggregator(context, params=user_params) + assistant = OpenAIAssistantContextAggregator(context, params=assistant_params) return GrokContextAggregatorPair(_user=user, _assistant=assistant) diff --git a/src/pipecat/services/llm_service.py b/src/pipecat/services/llm_service.py index 7f7238b47..6ac841f25 100644 --- a/src/pipecat/services/llm_service.py +++ b/src/pipecat/services/llm_service.py @@ -6,7 +6,7 @@ import asyncio from dataclasses import dataclass -from typing import Any, Mapping, Optional, Set, Tuple, Type +from typing import Any, Optional, Set, Tuple, Type from loguru import logger @@ -20,6 +20,10 @@ from pipecat.frames.frames import ( StartInterruptionFrame, UserImageRequestFrame, ) +from pipecat.processors.aggregators.llm_response import ( + LLMAssistantAggregatorParams, + LLMUserAggregatorParams, +) from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext from pipecat.processors.frame_processor import FrameDirection from pipecat.services.ai_service import AIService @@ -55,8 +59,8 @@ class LLMService(AIService): self, context: OpenAILLMContext, *, - user_kwargs: Mapping[str, Any] = {}, - assistant_kwargs:
Mapping[str, Any] = {}, + user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(), + assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), ) -> Any: pass diff --git a/src/pipecat/services/openai/llm.py b/src/pipecat/services/openai/llm.py index 0b634c01b..07b564fe1 100644 --- a/src/pipecat/services/openai/llm.py +++ b/src/pipecat/services/openai/llm.py @@ -6,7 +6,7 @@ import json from dataclasses import dataclass -from typing import Any, Mapping +from typing import Any from pipecat.frames.frames import ( FunctionCallCancelFrame, @@ -15,7 +15,9 @@ from pipecat.frames.frames import ( UserImageRawFrame, ) from pipecat.processors.aggregators.llm_response import ( + LLMAssistantAggregatorParams, LLMAssistantContextAggregator, + LLMUserAggregatorParams, LLMUserContextAggregator, ) from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext @@ -48,8 +50,8 @@ class OpenAILLMService(BaseOpenAILLMService): self, context: OpenAILLMContext, *, - user_kwargs: Mapping[str, Any] = {}, - assistant_kwargs: Mapping[str, Any] = {}, + user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(), + assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), ) -> OpenAIContextAggregatorPair: """Create an instance of OpenAIContextAggregatorPair from an OpenAILLMContext. Constructor keyword arguments for both the user and @@ -57,12 +59,8 @@ class OpenAILLMService(BaseOpenAILLMService): Args: context (OpenAILLMContext): The LLM context. - user_kwargs (Mapping[str, Any], optional): Additional keyword - arguments for the user context aggregator constructor. Defaults - to an empty mapping. - assistant_kwargs (Mapping[str, Any], optional): Additional keyword - arguments for the assistant context aggregator - constructor. Defaults to an empty mapping. + user_params (LLMUserAggregatorParams, optional): User aggregator parameters. + assistant_params (LLMAssistantAggregatorParams, optional): Assistant aggregator parameters.
Returns: OpenAIContextAggregatorPair: A pair of context aggregators, one for @@ -71,8 +69,8 @@ class OpenAILLMService(BaseOpenAILLMService): """ context.set_llm_adapter(self.get_llm_adapter()) - user = OpenAIUserContextAggregator(context, **user_kwargs) - assistant = OpenAIAssistantContextAggregator(context, **assistant_kwargs) + user = OpenAIUserContextAggregator(context, params=user_params) + assistant = OpenAIAssistantContextAggregator(context, params=assistant_params) return OpenAIContextAggregatorPair(_user=user, _assistant=assistant) diff --git a/src/pipecat/services/openai_realtime_beta/openai.py b/src/pipecat/services/openai_realtime_beta/openai.py index 6da1a3109..94d848b77 100644 --- a/src/pipecat/services/openai_realtime_beta/openai.py +++ b/src/pipecat/services/openai_realtime_beta/openai.py @@ -8,19 +8,9 @@ import base64 import json import time from dataclasses import dataclass -from typing import Any, Mapping from loguru import logger -try: - import websockets -except ModuleNotFoundError as e: - logger.error(f"Exception: {e}") - logger.error( - "In order to use OpenAI, you need to `pip install pipecat-ai[openai]`. Also, set `OPENAI_API_KEY` environment variable." 
- ) - raise Exception(f"Missing module: {e}") - from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter from pipecat.frames.frames import ( BotStoppedSpeakingFrame, @@ -48,6 +38,10 @@ from pipecat.frames.frames import ( UserStoppedSpeakingFrame, ) from pipecat.metrics.metrics import LLMTokenUsage +from pipecat.processors.aggregators.llm_response import ( + LLMAssistantAggregatorParams, + LLMUserAggregatorParams, +) from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, OpenAILLMContextFrame, @@ -65,6 +59,13 @@ from .context import ( ) from .frames import RealtimeFunctionCallResultFrame, RealtimeMessagesUpdateFrame +try: + import websockets +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error("In order to use OpenAI, you need to `pip install pipecat-ai[openai]`.") + raise Exception(f"Missing module: {e}") + @dataclass class CurrentAudioResponse: @@ -650,8 +651,8 @@ class OpenAIRealtimeBetaLLMService(LLMService): self, context: OpenAILLMContext, *, - user_kwargs: Mapping[str, Any] = {}, - assistant_kwargs: Mapping[str, Any] = {}, + user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(), + assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), ) -> OpenAIContextAggregatorPair: """Create an instance of OpenAIContextAggregatorPair from an OpenAILLMContext. Constructor keyword arguments for both the user and @@ -659,12 +660,10 @@ class OpenAIRealtimeBetaLLMService(LLMService): Args: context (OpenAILLMContext): The LLM context. - user_kwargs (Mapping[str, Any], optional): Additional keyword - arguments for the user context aggregator constructor. Defaults - to an empty mapping. - assistant_kwargs (Mapping[str, Any], optional): Additional keyword - arguments for the assistant context aggregator - constructor. Defaults to an empty mapping. + user_params (LLMUserAggregatorParams, optional): User aggregator + parameters. 
+ assistant_params (LLMAssistantAggregatorParams, optional): Assistant + aggregator parameters. Returns: OpenAIContextAggregatorPair: A pair of context aggregators, one for @@ -675,9 +674,8 @@ class OpenAIRealtimeBetaLLMService(LLMService): context.set_llm_adapter(self.get_llm_adapter()) OpenAIRealtimeLLMContext.upgrade_to_realtime(context) - user = OpenAIRealtimeUserContextAggregator(context, **user_kwargs) + user = OpenAIRealtimeUserContextAggregator(context, params=user_params) - default_assistant_kwargs = {"expect_stripped_words": False} - default_assistant_kwargs.update(assistant_kwargs) - assistant = OpenAIRealtimeAssistantContextAggregator(context, **default_assistant_kwargs) + assistant_params.expect_stripped_words = False + assistant = OpenAIRealtimeAssistantContextAggregator(context, params=assistant_params) return OpenAIContextAggregatorPair(_user=user, _assistant=assistant) diff --git a/tests/test_context_aggregators.py b/tests/test_context_aggregators.py index 05734a64e..dfe210e07 100644 --- a/tests/test_context_aggregators.py +++ b/tests/test_context_aggregators.py @@ -26,7 +26,11 @@ from pipecat.frames.frames import ( UserStartedSpeakingFrame, UserStoppedSpeakingFrame, ) -from pipecat.processors.aggregators.llm_response import LLMUserContextAggregator +from pipecat.processors.aggregators.llm_response import ( + LLMAssistantAggregatorParams, + LLMUserAggregatorParams, + LLMUserContextAggregator, +) from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, OpenAILLMContextFrame, @@ -163,7 +167,9 @@ class BaseTestUserContextAggregator: assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) + aggregator = self.AGGREGATOR_CLASS( + context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) + ) frames_to_send = [ UserStartedSpeakingFrame(),
InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""), @@ -189,7 +195,9 @@ class BaseTestUserContextAggregator: assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) + aggregator = self.AGGREGATOR_CLASS( + context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) + ) frames_to_send = [ UserStartedSpeakingFrame(), InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""), @@ -216,7 +224,9 @@ class BaseTestUserContextAggregator: assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) + aggregator = self.AGGREGATOR_CLASS( + context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) + ) frames_to_send = [ UserStartedSpeakingFrame(), UserStoppedSpeakingFrame(), @@ -240,7 +250,9 @@ class BaseTestUserContextAggregator: assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) + aggregator = self.AGGREGATOR_CLASS( + context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) + ) frames_to_send = [ UserStartedSpeakingFrame(), UserStoppedSpeakingFrame(), @@ -265,7 +277,9 @@ class BaseTestUserContextAggregator: assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) + aggregator = self.AGGREGATOR_CLASS( + context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) + ) frames_to_send = [ UserStartedSpeakingFrame(), TranscriptionFrame(text="Hello Pipecat!", user_id="cat", 
timestamp=""), @@ -293,7 +307,9 @@ class BaseTestUserContextAggregator: assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) + aggregator = self.AGGREGATOR_CLASS( + context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) + ) frames_to_send = [ UserStartedSpeakingFrame(), UserStoppedSpeakingFrame(), @@ -318,7 +334,9 @@ class BaseTestUserContextAggregator: assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) + aggregator = self.AGGREGATOR_CLASS( + context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) + ) frames_to_send = [ UserStartedSpeakingFrame(), InterimTranscriptionFrame(text="Hello ", user_id="cat", timestamp=""), @@ -346,7 +364,9 @@ class BaseTestUserContextAggregator: assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) + aggregator = self.AGGREGATOR_CLASS( + context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) + ) frames_to_send = [ TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""), SleepFrame(sleep=AGGREGATION_SLEEP), @@ -366,7 +386,9 @@ class BaseTestUserContextAggregator: assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) + aggregator = self.AGGREGATOR_CLASS( + context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) + ) frames_to_send = [ InterimTranscriptionFrame(text="Hello ", user_id="cat", timestamp=""), SleepFrame(), @@ -389,8 +411,7 @@ 
class BaseTestUserContextAggregator: context = self.CONTEXT_CLASS() aggregator = self.AGGREGATOR_CLASS( - context, - aggregation_timeout=AGGREGATION_TIMEOUT, + context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) ) frames_to_send = [ UserStartedSpeakingFrame(), @@ -469,7 +490,9 @@ class BaseTestAssistantContextAggreagator: assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False) + aggregator = self.AGGREGATOR_CLASS( + context, params=LLMAssistantAggregatorParams(expect_stripped_words=False) + ) frames_to_send = [ LLMFullResponseStartFrame(), TextFrame(text="Hello "), @@ -513,7 +536,9 @@ class BaseTestAssistantContextAggreagator: assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False) + aggregator = self.AGGREGATOR_CLASS( + context, params=LLMAssistantAggregatorParams(expect_stripped_words=False) + ) frames_to_send = [ LLMFullResponseStartFrame(), TextFrame(text="Hello "), @@ -538,7 +563,9 @@ class BaseTestAssistantContextAggreagator: assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False) + aggregator = self.AGGREGATOR_CLASS( + context, params=LLMAssistantAggregatorParams(expect_stripped_words=False) + ) frames_to_send = [ LLMFullResponseStartFrame(), TextFrame(text="Hello "), diff --git a/tests/test_langchain.py b/tests/test_langchain.py index 6534f6bb0..3d907e084 100644 --- a/tests/test_langchain.py +++ b/tests/test_langchain.py @@ -20,6 +20,7 @@ from pipecat.frames.frames import ( ) from pipecat.pipeline.pipeline import Pipeline from pipecat.processors.aggregators.llm_response import ( + LLMAssistantAggregatorParams, 
LLMAssistantResponseAggregator, LLMUserResponseAggregator, ) @@ -63,7 +64,9 @@ class TestLangchain(unittest.IsolatedAsyncioTestCase): self.mock_proc = self.MockProcessor("token_collector") tma_in = LLMUserResponseAggregator(messages) - tma_out = LLMAssistantResponseAggregator(messages, expect_stripped_words=False) + tma_out = LLMAssistantResponseAggregator( + messages, params=LLMAssistantAggregatorParams(expect_stripped_words=False) + ) pipeline = Pipeline([tma_in, proc, self.mock_proc, tma_out]) From f385cc04608a099d3e7131340202db0f1c8271f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Mon, 14 Apr 2025 13:08:27 -0700 Subject: [PATCH 2/2] pyproject: add websockets as google dependency --- pyproject.toml | 2 +- src/pipecat/services/gemini_multimodal_live/gemini.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8b7c8546a..cb0cd3520 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ fal = [ "fal-client~=0.5.9" ] fireworks = [] fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ] gladia = [ "websockets~=13.1" ] -google = [ "google-cloud-speech~=2.31.1", "google-cloud-texttospeech~=2.25.1", "google-genai~=1.7.0", "google-generativeai~=0.8.4" ] +google = [ "google-cloud-speech~=2.31.1", "google-cloud-texttospeech~=2.25.1", "google-genai~=1.7.0", "google-generativeai~=0.8.4", "websockets~=13.1" ] grok = [] groq = [ "groq~=0.20.0" ] gstreamer = [ "pygobject~=3.50.0" ] diff --git a/src/pipecat/services/gemini_multimodal_live/gemini.py b/src/pipecat/services/gemini_multimodal_live/gemini.py index f89e97bdb..d953e5065 100644 --- a/src/pipecat/services/gemini_multimodal_live/gemini.py +++ b/src/pipecat/services/gemini_multimodal_live/gemini.py @@ -10,9 +10,8 @@ import json import time from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, List, Mapping, Optional, Union +from typing import Any, Dict, List, Optional, Union -import 
websockets from loguru import logger from pydantic import BaseModel, Field @@ -65,6 +64,13 @@ from pipecat.utils.time import time_now_iso8601 from . import events from .audio_transcriber import AudioTranscriber +try: + import websockets +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.") + raise Exception(f"Missing module: {e}") + def language_to_gemini_language(language: Language) -> Optional[str]: """Maps a Language enum value to a Gemini Live supported language code.