Merge pull request #1588 from pipecat-ai/aleix/llm-aggregator-params

LLM aggregator params
This commit is contained in:
Aleix Conchillo Flaqué
2025-04-16 15:25:21 -07:00
committed by GitHub
13 changed files with 227 additions and 128 deletions

View File

@@ -12,6 +12,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `DeepgramTTSService` accepts `base_url` argument again, allowing you to
connect to an on-prem service.
- Added `LLMUserAggregatorParams` and `LLMAssistantAggregatorParams` which allow
you to control aggregator settings. You can now pass these arguments when
creating aggregator pairs with `create_context_aggregator()`.
- It is now possible to disable `SoundfileMixer` when created. You can then use
`MixerEnableFrame` to dynamically enable it when necessary.
@@ -38,6 +42,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `DeepgramSTTService` parameter `url` is now deprecated, use `base_url`
instead.
### Removed
- Parameters `user_kwargs` and `assistant_kwargs` when creating a context
aggregator pair using `create_context_aggregator()` have been removed. Use
`user_params` and `assistant_params` instead.
### Fixed
- Fixed a `TavusVideoService` issue that was causing audio choppiness.
@@ -45,7 +55,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed an issue in `SmallWebRTCTransport` where an error was thrown if the
client did not create a video transceiver.
- Fixed an issue where LLM input parameters were not working and applied correctly in `GoogleVertexLLMService`, causing
- Fixed an issue where LLM input parameters were not being applied correctly in `GoogleVertexLLMService`, causing
unexpected behavior during inference.
## [0.0.63] - 2025-04-11

View File

@@ -33,7 +33,10 @@ from pipecat.pipeline.parallel_pipeline import ParallelPipeline
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_response import LLMAssistantResponseAggregator
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMAssistantResponseAggregator,
)
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
@@ -478,7 +481,7 @@ class LLMAggregatorBuffer(LLMAssistantResponseAggregator):
"""Buffers the output of the transcription LLM. Used by the bot output gate."""
def __init__(self, **kwargs):
super().__init__(expect_stripped_words=False)
super().__init__(params=LLMAssistantAggregatorParams(expect_stripped_words=False))
self._transcription = ""
async def process_frame(self, frame: Frame, direction: FrameDirection):

View File

@@ -54,7 +54,7 @@ fal = [ "fal-client~=0.5.9" ]
fireworks = []
fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ]
gladia = [ "websockets~=13.1" ]
google = [ "google-cloud-speech~=2.31.1", "google-cloud-texttospeech~=2.25.1", "google-genai~=1.7.0", "google-generativeai~=0.8.4" ]
google = [ "google-cloud-speech~=2.31.1", "google-cloud-texttospeech~=2.25.1", "google-genai~=1.7.0", "google-generativeai~=0.8.4", "websockets~=13.1" ]
grok = []
groq = [ "groq~=0.20.0" ]
gstreamer = [ "pygobject~=3.50.0" ]

View File

@@ -6,6 +6,7 @@
import asyncio
from abc import abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Literal, Set
from loguru import logger
@@ -46,6 +47,16 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.time import time_now_iso8601
# Settings for the user context aggregator, passed as `params` when the
# aggregator is constructed.
#
# aggregation_timeout: timeout (in seconds, as given to asyncio.wait_for) that
# the user aggregator's aggregation task waits on its aggregation event before
# taking the timeout path. Defaults to 1.0.
@dataclass
class LLMUserAggregatorParams:
aggregation_timeout: float = 1.0
# Settings for the assistant context aggregator, passed as `params` when the
# aggregator is constructed.
#
# expect_stripped_words: when True, each incoming text frame is appended to the
# aggregation with a single space separator (no separator for the first one);
# when False, frame text is concatenated as-is. Defaults to True.
@dataclass
class LLMAssistantAggregatorParams:
expect_stripped_words: bool = True
class LLMFullResponseAggregator(FrameProcessor):
"""This is an LLM aggregator that aggregates a full LLM completion. It
aggregates LLM text frames (tokens) received between
@@ -230,11 +241,23 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
def __init__(
self,
context: OpenAILLMContext,
aggregation_timeout: float = 1.0,
*,
params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
**kwargs,
):
super().__init__(context=context, role="user", **kwargs)
self._aggregation_timeout = aggregation_timeout
self._params = params
if "aggregation_timeout" in kwargs:
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"Parameter 'aggregation_timeout' is deprecated, use 'params' instead.",
DeprecationWarning,
)
self._params.aggregation_timeout = kwargs["aggregation_timeout"]
self._seen_interim_results = False
self._user_speaking = False
@@ -357,7 +380,9 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
async def _aggregation_task_handler(self):
while True:
try:
await asyncio.wait_for(self._aggregation_event.wait(), self._aggregation_timeout)
await asyncio.wait_for(
self._aggregation_event.wait(), self._params.aggregation_timeout
)
await self._maybe_push_bot_interruption()
except asyncio.TimeoutError:
if not self._user_speaking:
@@ -394,9 +419,27 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
"""
def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True, **kwargs):
def __init__(
self,
context: OpenAILLMContext,
*,
params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
**kwargs,
):
super().__init__(context=context, role="assistant", **kwargs)
self._expect_stripped_words = expect_stripped_words
self._params = params
if "expect_stripped_words" in kwargs:
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"Parameter 'expect_stripped_words' is deprecated, use 'params' instead.",
DeprecationWarning,
)
self._params.expect_stripped_words = kwargs["expect_stripped_words"]
self._started = 0
self._function_calls_in_progress: Dict[str, FunctionCallInProgressFrame] = {}
@@ -558,7 +601,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
if not self._started:
return
if self._expect_stripped_words:
if self._params.expect_stripped_words:
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
else:
self._aggregation += frame.text
@@ -572,8 +615,14 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
class LLMUserResponseAggregator(LLMUserContextAggregator):
def __init__(self, messages: List[dict] = [], **kwargs):
super().__init__(context=OpenAILLMContext(messages), **kwargs)
def __init__(
self,
messages: List[dict] = [],
*,
params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
**kwargs,
):
super().__init__(context=OpenAILLMContext(messages), params=params, **kwargs)
async def push_aggregation(self):
if len(self._aggregation) > 0:
@@ -588,8 +637,14 @@ class LLMUserResponseAggregator(LLMUserContextAggregator):
class LLMAssistantResponseAggregator(LLMAssistantContextAggregator):
def __init__(self, messages: List[dict] = [], **kwargs):
super().__init__(context=OpenAILLMContext(messages), **kwargs)
def __init__(
self,
messages: List[dict] = [],
*,
params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
**kwargs,
):
super().__init__(context=OpenAILLMContext(messages), params=params, **kwargs)
async def push_aggregation(self):
if len(self._aggregation) > 0:

View File

@@ -11,7 +11,7 @@ import io
import json
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Mapping, Optional, Union
from typing import Any, Dict, List, Optional, Union
import httpx
from loguru import logger
@@ -35,7 +35,9 @@ from pipecat.frames.frames import (
)
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMAssistantContextAggregator,
LLMUserAggregatorParams,
LLMUserContextAggregator,
)
from pipecat.processors.aggregators.openai_llm_context import (
@@ -49,10 +51,7 @@ try:
from anthropic import NOT_GIVEN, AsyncAnthropic, NotGiven
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use Anthropic, you need to `pip install pipecat-ai[anthropic]`. "
+ "Also, set `ANTHROPIC_API_KEY` environment variable."
)
logger.error("In order to use Anthropic, you need to `pip install pipecat-ai[anthropic]`.")
raise Exception(f"Missing module: {e}")
@@ -120,8 +119,8 @@ class AnthropicLLMService(LLMService):
self,
context: OpenAILLMContext,
*,
user_kwargs: Mapping[str, Any] = {},
assistant_kwargs: Mapping[str, Any] = {},
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> AnthropicContextAggregatorPair:
"""Create an instance of AnthropicContextAggregatorPair from an
OpenAILLMContext. Constructor keyword arguments for both the user and
@@ -129,12 +128,10 @@ class AnthropicLLMService(LLMService):
Args:
context (OpenAILLMContext): The LLM context.
user_kwargs (Mapping[str, Any], optional): Additional keyword
arguments for the user context aggregator constructor. Defaults
to an empty mapping.
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
arguments for the assistant context aggregator
constructor. Defaults to an empty mapping.
user_params (LLMUserAggregatorParams, optional): User aggregator
parameters.
assistant_params (LLMAssistantAggregatorParams, optional): Assistant
aggregator parameters.
Returns:
AnthropicContextAggregatorPair: A pair of context aggregators, one
@@ -146,8 +143,8 @@ class AnthropicLLMService(LLMService):
if isinstance(context, OpenAILLMContext):
context = AnthropicLLMContext.from_openai_context(context)
user = AnthropicUserContextAggregator(context, **user_kwargs)
assistant = AnthropicAssistantContextAggregator(context, **assistant_kwargs)
user = AnthropicUserContextAggregator(context, params=user_params)
assistant = AnthropicAssistantContextAggregator(context, params=assistant_params)
return AnthropicContextAggregatorPair(_user=user, _assistant=assistant)
async def _process_context(self, context: OpenAILLMContext):

View File

@@ -10,9 +10,8 @@ import json
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Mapping, Optional, Union
from typing import Any, Dict, List, Optional, Union
import websockets
from loguru import logger
from pydantic import BaseModel, Field
@@ -45,6 +44,10 @@ from pipecat.frames.frames import (
UserStoppedSpeakingFrame,
)
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMUserAggregatorParams,
)
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
@@ -61,6 +64,13 @@ from pipecat.utils.time import time_now_iso8601
from . import events
from .audio_transcriber import AudioTranscriber
try:
import websockets
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
raise Exception(f"Missing module: {e}")
def language_to_gemini_language(language: Language) -> Optional[str]:
"""Maps a Language enum value to a Gemini Live supported language code.
@@ -871,8 +881,8 @@ class GeminiMultimodalLiveLLMService(LLMService):
self,
context: OpenAILLMContext,
*,
user_kwargs: Mapping[str, Any] = {},
assistant_kwargs: Mapping[str, Any] = {},
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> GeminiMultimodalLiveContextAggregatorPair:
"""Create an instance of GeminiMultimodalLiveContextAggregatorPair from
an OpenAILLMContext. Constructor keyword arguments for both the user and
@@ -880,12 +890,10 @@ class GeminiMultimodalLiveLLMService(LLMService):
Args:
context (OpenAILLMContext): The LLM context.
user_kwargs (Mapping[str, Any], optional): Additional keyword
arguments for the user context aggregator constructor. Defaults
to an empty mapping.
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
arguments for the assistant context aggregator
constructor. Defaults to an empty mapping.
user_params (LLMUserAggregatorParams, optional): User aggregator
parameters.
assistant_params (LLMAssistantAggregatorParams, optional): Assistant
aggregator parameters.
Returns:
GeminiMultimodalLiveContextAggregatorPair: A pair of context
@@ -896,11 +904,8 @@ class GeminiMultimodalLiveLLMService(LLMService):
context.set_llm_adapter(self.get_llm_adapter())
GeminiMultimodalLiveContext.upgrade(context)
user = GeminiMultimodalLiveUserContextAggregator(context, **user_kwargs)
user = GeminiMultimodalLiveUserContextAggregator(context, params=user_params)
default_assistant_kwargs = {"expect_stripped_words": True}
default_assistant_kwargs.update(assistant_kwargs)
assistant = GeminiMultimodalLiveAssistantContextAggregator(
context, **default_assistant_kwargs
)
assistant_params.expect_stripped_words = True
assistant = GeminiMultimodalLiveAssistantContextAggregator(context, params=assistant_params)
return GeminiMultimodalLiveContextAggregatorPair(_user=user, _assistant=assistant)

View File

@@ -9,21 +9,14 @@ import io
import json
import os
import uuid
from google.api_core.exceptions import DeadlineExceeded
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
# Suppress gRPC fork warnings
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
from dataclasses import dataclass
from typing import Any, Dict, List, Mapping, Optional, Union
from typing import Any, Dict, List, Optional
from loguru import logger
from PIL import Image
from pydantic import BaseModel, Field
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
from pipecat.frames.frames import (
AudioRawFrame,
Frame,
@@ -39,6 +32,10 @@ from pipecat.frames.frames import (
VisionImageRawFrame,
)
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMUserAggregatorParams,
)
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
@@ -51,11 +48,14 @@ from pipecat.services.openai.llm import (
OpenAIUserContextAggregator,
)
# Suppress gRPC fork warnings
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
try:
import google.ai.generativelanguage as glm
import google.generativeai as gai
from google.api_core.exceptions import DeadlineExceeded
from google.generativeai.types import GenerationConfig
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
@@ -686,8 +686,8 @@ class GoogleLLMService(LLMService):
self,
context: OpenAILLMContext,
*,
user_kwargs: Mapping[str, Any] = {},
assistant_kwargs: Mapping[str, Any] = {},
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> GoogleContextAggregatorPair:
"""Create an instance of GoogleContextAggregatorPair from an
OpenAILLMContext. Constructor keyword arguments for both the user and
@@ -695,12 +695,10 @@ class GoogleLLMService(LLMService):
Args:
context (OpenAILLMContext): The LLM context.
user_kwargs (Mapping[str, Any], optional): Additional keyword
arguments for the user context aggregator constructor. Defaults
to an empty mapping.
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
arguments for the assistant context aggregator
constructor. Defaults to an empty mapping.
user_params (LLMUserAggregatorParams, optional): User aggregator
parameters.
assistant_params (LLMAssistantAggregatorParams, optional): Assistant
aggregator parameters.
Returns:
GoogleContextAggregatorPair: A pair of context aggregators, one for
@@ -712,6 +710,6 @@ class GoogleLLMService(LLMService):
if isinstance(context, OpenAILLMContext):
context = GoogleLLMContext.upgrade_to_google(context)
user = GoogleUserContextAggregator(context, **user_kwargs)
assistant = GoogleAssistantContextAggregator(context, **assistant_kwargs)
user = GoogleUserContextAggregator(context, params=user_params)
assistant = GoogleAssistantContextAggregator(context, params=assistant_params)
return GoogleContextAggregatorPair(_user=user, _assistant=assistant)

View File

@@ -5,11 +5,14 @@
#
from dataclasses import dataclass
from typing import Any, Mapping
from loguru import logger
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMUserAggregatorParams,
)
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.openai.llm import (
OpenAIAssistantContextAggregator,
@@ -124,8 +127,8 @@ class GrokLLMService(OpenAILLMService):
self,
context: OpenAILLMContext,
*,
user_kwargs: Mapping[str, Any] = {},
assistant_kwargs: Mapping[str, Any] = {},
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> GrokContextAggregatorPair:
"""Create an instance of GrokContextAggregatorPair from an
OpenAILLMContext. Constructor keyword arguments for both the user and
@@ -133,12 +136,10 @@ class GrokLLMService(OpenAILLMService):
Args:
context (OpenAILLMContext): The LLM context.
user_kwargs (Mapping[str, Any], optional): Additional keyword
arguments for the user context aggregator constructor. Defaults
to an empty mapping.
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
arguments for the assistant context aggregator
constructor. Defaults to an empty mapping.
user_params (LLMUserAggregatorParams, optional): User aggregator
parameters.
assistant_params (LLMAssistantAggregatorParams, optional): Assistant
aggregator parameters.
Returns:
GrokContextAggregatorPair: A pair of context aggregators, one for
@@ -148,6 +149,6 @@ class GrokLLMService(OpenAILLMService):
"""
context.set_llm_adapter(self.get_llm_adapter())
user = OpenAIUserContextAggregator(context, **user_kwargs)
assistant = OpenAIAssistantContextAggregator(context, **assistant_kwargs)
user = OpenAIUserContextAggregator(context, params=user_params)
assistant = OpenAIAssistantContextAggregator(context, params=assistant_params)
return GrokContextAggregatorPair(_user=user, _assistant=assistant)

View File

@@ -6,7 +6,7 @@
import asyncio
from dataclasses import dataclass
from typing import Any, Mapping, Optional, Set, Tuple, Type
from typing import Any, Optional, Set, Tuple, Type
from loguru import logger
@@ -20,6 +20,10 @@ from pipecat.frames.frames import (
StartInterruptionFrame,
UserImageRequestFrame,
)
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMUserAggregatorParams,
)
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_service import AIService
@@ -55,8 +59,8 @@ class LLMService(AIService):
self,
context: OpenAILLMContext,
*,
user_kwargs: Mapping[str, Any] = {},
assistant_kwargs: Mapping[str, Any] = {},
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> Any:
pass

View File

@@ -6,7 +6,7 @@
import json
from dataclasses import dataclass
from typing import Any, Mapping
from typing import Any
from pipecat.frames.frames import (
FunctionCallCancelFrame,
@@ -15,7 +15,9 @@ from pipecat.frames.frames import (
UserImageRawFrame,
)
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMAssistantContextAggregator,
LLMUserAggregatorParams,
LLMUserContextAggregator,
)
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
@@ -48,8 +50,8 @@ class OpenAILLMService(BaseOpenAILLMService):
self,
context: OpenAILLMContext,
*,
user_kwargs: Mapping[str, Any] = {},
assistant_kwargs: Mapping[str, Any] = {},
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> OpenAIContextAggregatorPair:
"""Create an instance of OpenAIContextAggregatorPair from an
OpenAILLMContext. Constructor keyword arguments for both the user and
@@ -57,12 +59,8 @@ class OpenAILLMService(BaseOpenAILLMService):
Args:
context (OpenAILLMContext): The LLM context.
user_kwargs (Mapping[str, Any], optional): Additional keyword
arguments for the user context aggregator constructor. Defaults
to an empty mapping.
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
arguments for the assistant context aggregator
constructor. Defaults to an empty mapping.
user_params (LLMUserAggregatorParams, optional): User aggregator parameters.
assistant_params (LLMAssistantAggregatorParams, optional): Assistant aggregator parameters.
Returns:
OpenAIContextAggregatorPair: A pair of context aggregators, one for
@@ -71,8 +69,8 @@ class OpenAILLMService(BaseOpenAILLMService):
"""
context.set_llm_adapter(self.get_llm_adapter())
user = OpenAIUserContextAggregator(context, **user_kwargs)
assistant = OpenAIAssistantContextAggregator(context, **assistant_kwargs)
user = OpenAIUserContextAggregator(context, params=user_params)
assistant = OpenAIAssistantContextAggregator(context, params=assistant_params)
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)

View File

@@ -8,19 +8,9 @@ import base64
import json
import time
from dataclasses import dataclass
from typing import Any, Mapping
from loguru import logger
try:
import websockets
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use OpenAI, you need to `pip install pipecat-ai[openai]`. Also, set `OPENAI_API_KEY` environment variable."
)
raise Exception(f"Missing module: {e}")
from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
@@ -48,6 +38,10 @@ from pipecat.frames.frames import (
UserStoppedSpeakingFrame,
)
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMUserAggregatorParams,
)
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
@@ -65,6 +59,13 @@ from .context import (
)
from .frames import RealtimeFunctionCallResultFrame, RealtimeMessagesUpdateFrame
try:
import websockets
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use OpenAI, you need to `pip install pipecat-ai[openai]`.")
raise Exception(f"Missing module: {e}")
@dataclass
class CurrentAudioResponse:
@@ -650,8 +651,8 @@ class OpenAIRealtimeBetaLLMService(LLMService):
self,
context: OpenAILLMContext,
*,
user_kwargs: Mapping[str, Any] = {},
assistant_kwargs: Mapping[str, Any] = {},
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> OpenAIContextAggregatorPair:
"""Create an instance of OpenAIContextAggregatorPair from an
OpenAILLMContext. Constructor keyword arguments for both the user and
@@ -659,12 +660,10 @@ class OpenAIRealtimeBetaLLMService(LLMService):
Args:
context (OpenAILLMContext): The LLM context.
user_kwargs (Mapping[str, Any], optional): Additional keyword
arguments for the user context aggregator constructor. Defaults
to an empty mapping.
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
arguments for the assistant context aggregator
constructor. Defaults to an empty mapping.
user_params (LLMUserAggregatorParams, optional): User aggregator
parameters.
assistant_params (LLMAssistantAggregatorParams, optional): Assistant
aggregator parameters.
Returns:
OpenAIContextAggregatorPair: A pair of context aggregators, one for
@@ -675,9 +674,8 @@ class OpenAIRealtimeBetaLLMService(LLMService):
context.set_llm_adapter(self.get_llm_adapter())
OpenAIRealtimeLLMContext.upgrade_to_realtime(context)
user = OpenAIRealtimeUserContextAggregator(context, **user_kwargs)
user = OpenAIRealtimeUserContextAggregator(context, params=user_params)
default_assistant_kwargs = {"expect_stripped_words": False}
default_assistant_kwargs.update(assistant_kwargs)
assistant = OpenAIRealtimeAssistantContextAggregator(context, **default_assistant_kwargs)
assistant_params.expect_stripped_words = False
assistant = OpenAIRealtimeAssistantContextAggregator(context, params=assistant_params)
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)

View File

@@ -26,7 +26,11 @@ from pipecat.frames.frames import (
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.processors.aggregators.llm_response import LLMUserContextAggregator
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMUserAggregatorParams,
LLMUserContextAggregator,
)
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
@@ -163,7 +167,9 @@ class BaseTestUserContextAggregator:
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
aggregator = self.AGGREGATOR_CLASS(
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
)
frames_to_send = [
UserStartedSpeakingFrame(),
InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""),
@@ -189,7 +195,9 @@ class BaseTestUserContextAggregator:
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
aggregator = self.AGGREGATOR_CLASS(
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
)
frames_to_send = [
UserStartedSpeakingFrame(),
InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""),
@@ -216,7 +224,9 @@ class BaseTestUserContextAggregator:
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
aggregator = self.AGGREGATOR_CLASS(
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
)
frames_to_send = [
UserStartedSpeakingFrame(),
UserStoppedSpeakingFrame(),
@@ -240,7 +250,9 @@ class BaseTestUserContextAggregator:
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
aggregator = self.AGGREGATOR_CLASS(
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
)
frames_to_send = [
UserStartedSpeakingFrame(),
UserStoppedSpeakingFrame(),
@@ -265,7 +277,9 @@ class BaseTestUserContextAggregator:
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
aggregator = self.AGGREGATOR_CLASS(
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
)
frames_to_send = [
UserStartedSpeakingFrame(),
TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""),
@@ -293,7 +307,9 @@ class BaseTestUserContextAggregator:
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
aggregator = self.AGGREGATOR_CLASS(
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
)
frames_to_send = [
UserStartedSpeakingFrame(),
UserStoppedSpeakingFrame(),
@@ -318,7 +334,9 @@ class BaseTestUserContextAggregator:
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
aggregator = self.AGGREGATOR_CLASS(
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
)
frames_to_send = [
UserStartedSpeakingFrame(),
InterimTranscriptionFrame(text="Hello ", user_id="cat", timestamp=""),
@@ -346,7 +364,9 @@ class BaseTestUserContextAggregator:
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
aggregator = self.AGGREGATOR_CLASS(
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
)
frames_to_send = [
TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""),
SleepFrame(sleep=AGGREGATION_SLEEP),
@@ -366,7 +386,9 @@ class BaseTestUserContextAggregator:
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
aggregator = self.AGGREGATOR_CLASS(
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
)
frames_to_send = [
InterimTranscriptionFrame(text="Hello ", user_id="cat", timestamp=""),
SleepFrame(),
@@ -389,8 +411,7 @@ class BaseTestUserContextAggregator:
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(
context,
aggregation_timeout=AGGREGATION_TIMEOUT,
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
)
frames_to_send = [
UserStartedSpeakingFrame(),
@@ -469,7 +490,9 @@ class BaseTestAssistantContextAggreagator:
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False)
aggregator = self.AGGREGATOR_CLASS(
context, params=LLMAssistantAggregatorParams(expect_stripped_words=False)
)
frames_to_send = [
LLMFullResponseStartFrame(),
TextFrame(text="Hello "),
@@ -513,7 +536,9 @@ class BaseTestAssistantContextAggreagator:
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False)
aggregator = self.AGGREGATOR_CLASS(
context, params=LLMAssistantAggregatorParams(expect_stripped_words=False)
)
frames_to_send = [
LLMFullResponseStartFrame(),
TextFrame(text="Hello "),
@@ -538,7 +563,9 @@ class BaseTestAssistantContextAggreagator:
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False)
aggregator = self.AGGREGATOR_CLASS(
context, params=LLMAssistantAggregatorParams(expect_stripped_words=False)
)
frames_to_send = [
LLMFullResponseStartFrame(),
TextFrame(text="Hello "),

View File

@@ -20,6 +20,7 @@ from pipecat.frames.frames import (
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMAssistantResponseAggregator,
LLMUserResponseAggregator,
)
@@ -63,7 +64,9 @@ class TestLangchain(unittest.IsolatedAsyncioTestCase):
self.mock_proc = self.MockProcessor("token_collector")
tma_in = LLMUserResponseAggregator(messages)
tma_out = LLMAssistantResponseAggregator(messages, expect_stripped_words=False)
tma_out = LLMAssistantResponseAggregator(
messages, params=LLMAssistantAggregatorParams(expect_stripped_words=False)
)
pipeline = Pipeline([tma_in, proc, self.mock_proc, tma_out])