From 17a1f305729e24c720b292bc2f51f81d1b338494 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Fri, 28 Feb 2025 18:02:11 -0800 Subject: [PATCH] LLMService: add user/assistant args to create_context_aggregator() --- CHANGELOG.md | 4 +++ src/pipecat/services/anthropic.py | 32 +++++++++++++---- .../services/gemini_multimodal_live/gemini.py | 34 +++++++++++++++--- src/pipecat/services/google/google.py | 32 +++++++++++++---- src/pipecat/services/grok.py | 32 +++++++++++++---- src/pipecat/services/openai.py | 32 +++++++++++++---- .../services/openai_realtime_beta/openai.py | 36 +++++++++++++++---- 7 files changed, 168 insertions(+), 34 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 61a95b3c2..4e6ae5c21 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Allow passing user (`user_kwargs`) and assistant (`assistant_kwargs`) context + aggregator parameters when using `create_context_aggregator()`. The values are + passed as a mapping that will then be converted to arguments. + - Added `speed` as an `InputParam` for both `ElevenLabsTTSService` and `ElevenLabsHttpTTSService`. 
diff --git a/src/pipecat/services/anthropic.py b/src/pipecat/services/anthropic.py index f10ee7ccb..bae780e62 100644 --- a/src/pipecat/services/anthropic.py +++ b/src/pipecat/services/anthropic.py @@ -11,7 +11,7 @@ import io import json import re from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Mapping, Optional, Union import httpx from loguru import logger @@ -125,14 +125,34 @@ class AnthropicLLMService(LLMService): @staticmethod def create_context_aggregator( - context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True + context: OpenAILLMContext, + *, + user_kwargs: Mapping[str, Any] = {}, + assistant_kwargs: Mapping[str, Any] = {}, ) -> AnthropicContextAggregatorPair: + """Create an instance of AnthropicContextAggregatorPair from an + OpenAILLMContext. Constructor keyword arguments for both the user and + assistant aggregators can be provided. + + Args: + context (OpenAILLMContext): The LLM context. + user_kwargs (Mapping[str, Any], optional): Additional keyword + arguments for the user context aggregator constructor. Defaults + to an empty mapping. + assistant_kwargs (Mapping[str, Any], optional): Additional keyword + arguments for the assistant context aggregator + constructor. Defaults to an empty mapping. + + Returns: + AnthropicContextAggregatorPair: A pair of context aggregators, one + for the user and one for the assistant, encapsulated in an + AnthropicContextAggregatorPair. 
+ + """ if isinstance(context, OpenAILLMContext): context = AnthropicLLMContext.from_openai_context(context) - user = AnthropicUserContextAggregator(context) - assistant = AnthropicAssistantContextAggregator( - context, expect_stripped_words=assistant_expect_stripped_words - ) + user = AnthropicUserContextAggregator(context, **user_kwargs) + assistant = AnthropicAssistantContextAggregator(context, **assistant_kwargs) return AnthropicContextAggregatorPair(_user=user, _assistant=assistant) async def _process_context(self, context: OpenAILLMContext): diff --git a/src/pipecat/services/gemini_multimodal_live/gemini.py b/src/pipecat/services/gemini_multimodal_live/gemini.py index 6e7a1c0fa..934117c52 100644 --- a/src/pipecat/services/gemini_multimodal_live/gemini.py +++ b/src/pipecat/services/gemini_multimodal_live/gemini.py @@ -9,7 +9,7 @@ import base64 import json from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Mapping, Optional import websockets from loguru import logger @@ -701,11 +701,37 @@ class GeminiMultimodalLiveLLMService(LLMService): await self.push_frame(TTSStoppedFrame()) def create_context_aggregator( - self, context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = False + self, + context: OpenAILLMContext, + *, + user_kwargs: Mapping[str, Any] = {}, + assistant_kwargs: Mapping[str, Any] = {}, ) -> GeminiMultimodalLiveContextAggregatorPair: + """Create an instance of GeminiMultimodalLiveContextAggregatorPair from + an OpenAILLMContext. Constructor keyword arguments for both the user and + assistant aggregators can be provided. + + Args: + context (OpenAILLMContext): The LLM context. + user_kwargs (Mapping[str, Any], optional): Additional keyword + arguments for the user context aggregator constructor. Defaults + to an empty mapping. 
+ assistant_kwargs (Mapping[str, Any], optional): Additional keyword + arguments for the assistant context aggregator + constructor. Defaults to an empty mapping. + + Returns: + GeminiMultimodalLiveContextAggregatorPair: A pair of context + aggregators, one for the user and one for the assistant, + encapsulated in a GeminiMultimodalLiveContextAggregatorPair. + + """ GeminiMultimodalLiveContext.upgrade(context) - user = GeminiMultimodalLiveUserContextAggregator(context) + user = GeminiMultimodalLiveUserContextAggregator(context, **user_kwargs) + + default_assistant_kwargs = {"expect_stripped_words": False} + default_assistant_kwargs.update(assistant_kwargs) assistant = GeminiMultimodalLiveAssistantContextAggregator( - context, expect_stripped_words=assistant_expect_stripped_words + context, **default_assistant_kwargs ) return GeminiMultimodalLiveContextAggregatorPair(_user=user, _assistant=assistant) diff --git a/src/pipecat/services/google/google.py b/src/pipecat/services/google/google.py index 50d0aaf83..cbbc73b47 100644 --- a/src/pipecat/services/google/google.py +++ b/src/pipecat/services/google/google.py @@ -19,7 +19,7 @@ from openai.types.chat import ChatCompletionChunk os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false" from dataclasses import dataclass -from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Union +from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Union from loguru import logger from PIL import Image @@ -1182,14 +1182,34 @@ class GoogleLLMService(LLMService): @staticmethod def create_context_aggregator( - context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True + context: OpenAILLMContext, + *, + user_kwargs: Mapping[str, Any] = {}, + assistant_kwargs: Mapping[str, Any] = {}, ) -> GoogleContextAggregatorPair: + """Create an instance of GoogleContextAggregatorPair from an + OpenAILLMContext. Constructor keyword arguments for both the user and + assistant aggregators can be provided. 
+ + Args: + context (OpenAILLMContext): The LLM context. + user_kwargs (Mapping[str, Any], optional): Additional keyword + arguments for the user context aggregator constructor. Defaults + to an empty mapping. + assistant_kwargs (Mapping[str, Any], optional): Additional keyword + arguments for the assistant context aggregator + constructor. Defaults to an empty mapping. + + Returns: + GoogleContextAggregatorPair: A pair of context aggregators, one for + the user and one for the assistant, encapsulated in a + GoogleContextAggregatorPair. + + """ if isinstance(context, OpenAILLMContext): context = GoogleLLMContext.upgrade_to_google(context) - user = GoogleUserContextAggregator(context) - assistant = GoogleAssistantContextAggregator( - context, expect_stripped_words=assistant_expect_stripped_words - ) + user = GoogleUserContextAggregator(context, **user_kwargs) + assistant = GoogleAssistantContextAggregator(context, **assistant_kwargs) return GoogleContextAggregatorPair(_user=user, _assistant=assistant) diff --git a/src/pipecat/services/grok.py b/src/pipecat/services/grok.py index 3a4a5fb5e..1f1661cf4 100644 --- a/src/pipecat/services/grok.py +++ b/src/pipecat/services/grok.py @@ -7,7 +7,7 @@ import json from dataclasses import dataclass -from typing import Optional +from typing import Any, Mapping, Optional from loguru import logger @@ -208,10 +208,30 @@ class GrokLLMService(OpenAILLMService): @staticmethod def create_context_aggregator( - context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True + context: OpenAILLMContext, + *, + user_kwargs: Mapping[str, Any] = {}, + assistant_kwargs: Mapping[str, Any] = {}, ) -> GrokContextAggregatorPair: - user = OpenAIUserContextAggregator(context) - assistant = GrokAssistantContextAggregator( - context, expect_stripped_words=assistant_expect_stripped_words - ) + """Create an instance of GrokContextAggregatorPair from an + OpenAILLMContext. 
Constructor keyword arguments for both the user and + assistant aggregators can be provided. + + Args: + context (OpenAILLMContext): The LLM context. + user_kwargs (Mapping[str, Any], optional): Additional keyword + arguments for the user context aggregator constructor. Defaults + to an empty mapping. + assistant_kwargs (Mapping[str, Any], optional): Additional keyword + arguments for the assistant context aggregator + constructor. Defaults to an empty mapping. + + Returns: + GrokContextAggregatorPair: A pair of context aggregators, one for + the user and one for the assistant, encapsulated in a + GrokContextAggregatorPair. + + """ + user = OpenAIUserContextAggregator(context, **user_kwargs) + assistant = GrokAssistantContextAggregator(context, **assistant_kwargs) return GrokContextAggregatorPair(_user=user, _assistant=assistant) diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index 71da3f5e2..425882d6f 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -8,7 +8,7 @@ import base64 import io import json from dataclasses import dataclass -from typing import Any, AsyncGenerator, Dict, List, Literal, Optional +from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional import aiohttp import httpx @@ -345,12 +345,32 @@ class OpenAILLMService(BaseOpenAILLMService): @staticmethod def create_context_aggregator( - context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True + context: OpenAILLMContext, + *, + user_kwargs: Mapping[str, Any] = {}, + assistant_kwargs: Mapping[str, Any] = {}, ) -> OpenAIContextAggregatorPair: - user = OpenAIUserContextAggregator(context) - assistant = OpenAIAssistantContextAggregator( - context, expect_stripped_words=assistant_expect_stripped_words - ) + """Create an instance of OpenAIContextAggregatorPair from an + OpenAILLMContext. Constructor keyword arguments for both the user and + assistant aggregators can be provided. 
+ + Args: + context (OpenAILLMContext): The LLM context. + user_kwargs (Mapping[str, Any], optional): Additional keyword + arguments for the user context aggregator constructor. Defaults + to an empty mapping. + assistant_kwargs (Mapping[str, Any], optional): Additional keyword + arguments for the assistant context aggregator + constructor. Defaults to an empty mapping. + + Returns: + OpenAIContextAggregatorPair: A pair of context aggregators, one for + the user and one for the assistant, encapsulated in an + OpenAIContextAggregatorPair. + + """ + user = OpenAIUserContextAggregator(context, **user_kwargs) + assistant = OpenAIAssistantContextAggregator(context, **assistant_kwargs) return OpenAIContextAggregatorPair(_user=user, _assistant=assistant) diff --git a/src/pipecat/services/openai_realtime_beta/openai.py b/src/pipecat/services/openai_realtime_beta/openai.py index ff6f24c66..44ce45dd7 100644 --- a/src/pipecat/services/openai_realtime_beta/openai.py +++ b/src/pipecat/services/openai_realtime_beta/openai.py @@ -4,11 +4,11 @@ # SPDX-License-Identifier: BSD 2-Clause License # -import asyncio import base64 import json import time from dataclasses import dataclass +from typing import Any, Mapping from loguru import logger @@ -571,11 +571,35 @@ class OpenAIRealtimeBetaLLMService(LLMService): await self.send_client_event(events.InputAudioBufferAppendEvent(audio=payload)) def create_context_aggregator( - self, context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = False + self, + context: OpenAILLMContext, + *, + user_kwargs: Mapping[str, Any] = {}, + assistant_kwargs: Mapping[str, Any] = {}, ) -> OpenAIContextAggregatorPair: + """Create an instance of OpenAIContextAggregatorPair from an + OpenAILLMContext. Constructor keyword arguments for both the user and + assistant aggregators can be provided. + + Args: + context (OpenAILLMContext): The LLM context. 
+ user_kwargs (Mapping[str, Any], optional): Additional keyword + arguments for the user context aggregator constructor. Defaults + to an empty mapping. + assistant_kwargs (Mapping[str, Any], optional): Additional keyword + arguments for the assistant context aggregator + constructor. Defaults to an empty mapping. + + Returns: + OpenAIContextAggregatorPair: A pair of context aggregators, one for + the user and one for the assistant, encapsulated in an + OpenAIContextAggregatorPair. + + """ OpenAIRealtimeLLMContext.upgrade_to_realtime(context) - user = OpenAIRealtimeUserContextAggregator(context) - assistant = OpenAIRealtimeAssistantContextAggregator( - context, expect_stripped_words=assistant_expect_stripped_words - ) + user = OpenAIRealtimeUserContextAggregator(context, **user_kwargs) + + default_assistant_kwargs = {"expect_stripped_words": False} + default_assistant_kwargs.update(assistant_kwargs) + assistant = OpenAIRealtimeAssistantContextAggregator(context, **default_assistant_kwargs) return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)