Merge pull request #1588 from pipecat-ai/aleix/llm-aggregator-params
LLM aggregator params
This commit is contained in:
12
CHANGELOG.md
12
CHANGELOG.md
@@ -12,6 +12,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- `DeepgramTTSService` accepts `base_url` argument again, allowing you to
|
||||
connect to an on-prem service.
|
||||
|
||||
- Added `LLMUserAggregatorParams` and `LLMAssistantAggregatorParams` which allow
|
||||
you to control aggregator settings. You can now pass these arguments when
|
||||
creating aggregator pairs with `create_context_aggregator()`.
|
||||
|
||||
- It is now possible to disable `SoundfileMixer` when created. You can then use
|
||||
`MixerEnableFrame` to dynamically enable it when necessary.
|
||||
|
||||
@@ -38,6 +42,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- `DeepgramSTTService` parameter `url` is now deprecated, use `base_url`
|
||||
instead.
|
||||
|
||||
### Removed
|
||||
|
||||
- Parameters `user_kwargs` and `assistant_kwargs` when creating a context
|
||||
aggregator pair using `create_context_aggregator()` have been removed. Use
|
||||
`user_params` and `assistant_params` instead.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a `TavusVideoService` issue that was causing audio choppiness.
|
||||
@@ -45,7 +55,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Fixed an issue in `SmallWebRTCTransport` where an error was thrown if the
|
||||
client did not create a video transceiver.
|
||||
|
||||
- Fixed an issue where LLM input parameters were not working and applied correctly in `GoogleVertexLLMService`, causing
|
||||
- Fixed an issue where LLM input parameters were not working and applied correctly in `GoogleVertexLLMService`, causing
|
||||
unexpected behavior during inference.
|
||||
|
||||
## [0.0.63] - 2025-04-11
|
||||
|
||||
@@ -33,7 +33,10 @@ from pipecat.pipeline.parallel_pipeline import ParallelPipeline
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantResponseAggregator
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMAssistantResponseAggregator,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
@@ -478,7 +481,7 @@ class LLMAggregatorBuffer(LLMAssistantResponseAggregator):
|
||||
"""Buffers the output of the transcription LLM. Used by the bot output gate."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(expect_stripped_words=False)
|
||||
super().__init__(params=LLMAssistantAggregatorParams(expect_stripped_words=False))
|
||||
self._transcription = ""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
|
||||
@@ -54,7 +54,7 @@ fal = [ "fal-client~=0.5.9" ]
|
||||
fireworks = []
|
||||
fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ]
|
||||
gladia = [ "websockets~=13.1" ]
|
||||
google = [ "google-cloud-speech~=2.31.1", "google-cloud-texttospeech~=2.25.1", "google-genai~=1.7.0", "google-generativeai~=0.8.4" ]
|
||||
google = [ "google-cloud-speech~=2.31.1", "google-cloud-texttospeech~=2.25.1", "google-genai~=1.7.0", "google-generativeai~=0.8.4", "websockets~=13.1" ]
|
||||
grok = []
|
||||
groq = [ "groq~=0.20.0" ]
|
||||
gstreamer = [ "pygobject~=3.50.0" ]
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
import asyncio
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Literal, Set
|
||||
|
||||
from loguru import logger
|
||||
@@ -46,6 +47,16 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMUserAggregatorParams:
|
||||
aggregation_timeout: float = 1.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMAssistantAggregatorParams:
|
||||
expect_stripped_words: bool = True
|
||||
|
||||
|
||||
class LLMFullResponseAggregator(FrameProcessor):
|
||||
"""This is an LLM aggregator that aggregates a full LLM completion. It
|
||||
aggregates LLM text frames (tokens) received between
|
||||
@@ -230,11 +241,23 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
def __init__(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
aggregation_timeout: float = 1.0,
|
||||
*,
|
||||
params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(context=context, role="user", **kwargs)
|
||||
self._aggregation_timeout = aggregation_timeout
|
||||
self._params = params
|
||||
if "aggregation_timeout" in kwargs:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Parameter 'aggregation_timeout' is deprecated, use 'params' instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
self._params.aggregation_timeout = kwargs["aggregation_timeout"]
|
||||
|
||||
self._seen_interim_results = False
|
||||
self._user_speaking = False
|
||||
@@ -357,7 +380,9 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
async def _aggregation_task_handler(self):
|
||||
while True:
|
||||
try:
|
||||
await asyncio.wait_for(self._aggregation_event.wait(), self._aggregation_timeout)
|
||||
await asyncio.wait_for(
|
||||
self._aggregation_event.wait(), self._params.aggregation_timeout
|
||||
)
|
||||
await self._maybe_push_bot_interruption()
|
||||
except asyncio.TimeoutError:
|
||||
if not self._user_speaking:
|
||||
@@ -394,9 +419,27 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(context=context, role="assistant", **kwargs)
|
||||
self._expect_stripped_words = expect_stripped_words
|
||||
self._params = params
|
||||
|
||||
if "expect_stripped_words" in kwargs:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Parameter 'expect_stripped_words' is deprecated, use 'params' instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
self._params.expect_stripped_words = kwargs["expect_stripped_words"]
|
||||
|
||||
self._started = 0
|
||||
self._function_calls_in_progress: Dict[str, FunctionCallInProgressFrame] = {}
|
||||
@@ -558,7 +601,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
if self._expect_stripped_words:
|
||||
if self._params.expect_stripped_words:
|
||||
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
|
||||
else:
|
||||
self._aggregation += frame.text
|
||||
@@ -572,8 +615,14 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
|
||||
|
||||
class LLMUserResponseAggregator(LLMUserContextAggregator):
|
||||
def __init__(self, messages: List[dict] = [], **kwargs):
|
||||
super().__init__(context=OpenAILLMContext(messages), **kwargs)
|
||||
def __init__(
|
||||
self,
|
||||
messages: List[dict] = [],
|
||||
*,
|
||||
params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(context=OpenAILLMContext(messages), params=params, **kwargs)
|
||||
|
||||
async def push_aggregation(self):
|
||||
if len(self._aggregation) > 0:
|
||||
@@ -588,8 +637,14 @@ class LLMUserResponseAggregator(LLMUserContextAggregator):
|
||||
|
||||
|
||||
class LLMAssistantResponseAggregator(LLMAssistantContextAggregator):
|
||||
def __init__(self, messages: List[dict] = [], **kwargs):
|
||||
super().__init__(context=OpenAILLMContext(messages), **kwargs)
|
||||
def __init__(
|
||||
self,
|
||||
messages: List[dict] = [],
|
||||
*,
|
||||
params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(context=OpenAILLMContext(messages), params=params, **kwargs)
|
||||
|
||||
async def push_aggregation(self):
|
||||
if len(self._aggregation) > 0:
|
||||
|
||||
@@ -11,7 +11,7 @@ import io
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Mapping, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
@@ -35,7 +35,9 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMAssistantContextAggregator,
|
||||
LLMUserAggregatorParams,
|
||||
LLMUserContextAggregator,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
@@ -49,10 +51,7 @@ try:
|
||||
from anthropic import NOT_GIVEN, AsyncAnthropic, NotGiven
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Anthropic, you need to `pip install pipecat-ai[anthropic]`. "
|
||||
+ "Also, set `ANTHROPIC_API_KEY` environment variable."
|
||||
)
|
||||
logger.error("In order to use Anthropic, you need to `pip install pipecat-ai[anthropic]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@@ -120,8 +119,8 @@ class AnthropicLLMService(LLMService):
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_kwargs: Mapping[str, Any] = {},
|
||||
assistant_kwargs: Mapping[str, Any] = {},
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> AnthropicContextAggregatorPair:
|
||||
"""Create an instance of AnthropicContextAggregatorPair from an
|
||||
OpenAILLMContext. Constructor keyword arguments for both the user and
|
||||
@@ -129,12 +128,10 @@ class AnthropicLLMService(LLMService):
|
||||
|
||||
Args:
|
||||
context (OpenAILLMContext): The LLM context.
|
||||
user_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||
arguments for the user context aggregator constructor. Defaults
|
||||
to an empty mapping.
|
||||
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||
arguments for the assistant context aggregator
|
||||
constructor. Defaults to an empty mapping.
|
||||
user_params (LLMUserAggregatorParams, optional): User aggregator
|
||||
parameters.
|
||||
assistant_params (LLMAssistantAggregatorParams, optional): User
|
||||
aggregator parameters.
|
||||
|
||||
Returns:
|
||||
AnthropicContextAggregatorPair: A pair of context aggregators, one
|
||||
@@ -146,8 +143,8 @@ class AnthropicLLMService(LLMService):
|
||||
|
||||
if isinstance(context, OpenAILLMContext):
|
||||
context = AnthropicLLMContext.from_openai_context(context)
|
||||
user = AnthropicUserContextAggregator(context, **user_kwargs)
|
||||
assistant = AnthropicAssistantContextAggregator(context, **assistant_kwargs)
|
||||
user = AnthropicUserContextAggregator(context, params=user_params)
|
||||
assistant = AnthropicAssistantContextAggregator(context, params=assistant_params)
|
||||
return AnthropicContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext):
|
||||
|
||||
@@ -10,9 +10,8 @@ import json
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Mapping, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import websockets
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -45,6 +44,10 @@ from pipecat.frames.frames import (
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
@@ -61,6 +64,13 @@ from pipecat.utils.time import time_now_iso8601
|
||||
from . import events
|
||||
from .audio_transcriber import AudioTranscriber
|
||||
|
||||
try:
|
||||
import websockets
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
def language_to_gemini_language(language: Language) -> Optional[str]:
|
||||
"""Maps a Language enum value to a Gemini Live supported language code.
|
||||
@@ -871,8 +881,8 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_kwargs: Mapping[str, Any] = {},
|
||||
assistant_kwargs: Mapping[str, Any] = {},
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> GeminiMultimodalLiveContextAggregatorPair:
|
||||
"""Create an instance of GeminiMultimodalLiveContextAggregatorPair from
|
||||
an OpenAILLMContext. Constructor keyword arguments for both the user and
|
||||
@@ -880,12 +890,10 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
|
||||
Args:
|
||||
context (OpenAILLMContext): The LLM context.
|
||||
user_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||
arguments for the user context aggregator constructor. Defaults
|
||||
to an empty mapping.
|
||||
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||
arguments for the assistant context aggregator
|
||||
constructor. Defaults to an empty mapping.
|
||||
user_params (LLMUserAggregatorParams, optional): User aggregator
|
||||
parameters.
|
||||
assistant_params (LLMAssistantAggregatorParams, optional): User
|
||||
aggregator parameters.
|
||||
|
||||
Returns:
|
||||
GeminiMultimodalLiveContextAggregatorPair: A pair of context
|
||||
@@ -896,11 +904,8 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
context.set_llm_adapter(self.get_llm_adapter())
|
||||
|
||||
GeminiMultimodalLiveContext.upgrade(context)
|
||||
user = GeminiMultimodalLiveUserContextAggregator(context, **user_kwargs)
|
||||
user = GeminiMultimodalLiveUserContextAggregator(context, params=user_params)
|
||||
|
||||
default_assistant_kwargs = {"expect_stripped_words": True}
|
||||
default_assistant_kwargs.update(assistant_kwargs)
|
||||
assistant = GeminiMultimodalLiveAssistantContextAggregator(
|
||||
context, **default_assistant_kwargs
|
||||
)
|
||||
assistant_params.expect_stripped_words = True
|
||||
assistant = GeminiMultimodalLiveAssistantContextAggregator(context, params=assistant_params)
|
||||
return GeminiMultimodalLiveContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
@@ -9,21 +9,14 @@ import io
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from google.api_core.exceptions import DeadlineExceeded
|
||||
|
||||
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
|
||||
|
||||
# Suppress gRPC fork warnings
|
||||
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Mapping, Optional, Union
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
|
||||
from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
Frame,
|
||||
@@ -39,6 +32,10 @@ from pipecat.frames.frames import (
|
||||
VisionImageRawFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
@@ -51,11 +48,14 @@ from pipecat.services.openai.llm import (
|
||||
OpenAIUserContextAggregator,
|
||||
)
|
||||
|
||||
# Suppress gRPC fork warnings
|
||||
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
|
||||
|
||||
try:
|
||||
import google.ai.generativelanguage as glm
|
||||
import google.generativeai as gai
|
||||
from google.api_core.exceptions import DeadlineExceeded
|
||||
from google.generativeai.types import GenerationConfig
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
|
||||
@@ -686,8 +686,8 @@ class GoogleLLMService(LLMService):
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_kwargs: Mapping[str, Any] = {},
|
||||
assistant_kwargs: Mapping[str, Any] = {},
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> GoogleContextAggregatorPair:
|
||||
"""Create an instance of GoogleContextAggregatorPair from an
|
||||
OpenAILLMContext. Constructor keyword arguments for both the user and
|
||||
@@ -695,12 +695,10 @@ class GoogleLLMService(LLMService):
|
||||
|
||||
Args:
|
||||
context (OpenAILLMContext): The LLM context.
|
||||
user_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||
arguments for the user context aggregator constructor. Defaults
|
||||
to an empty mapping.
|
||||
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||
arguments for the assistant context aggregator
|
||||
constructor. Defaults to an empty mapping.
|
||||
user_params (LLMUserAggregatorParams, optional): User aggregator
|
||||
parameters.
|
||||
assistant_params (LLMAssistantAggregatorParams, optional): User
|
||||
aggregator parameters.
|
||||
|
||||
Returns:
|
||||
GoogleContextAggregatorPair: A pair of context aggregators, one for
|
||||
@@ -712,6 +710,6 @@ class GoogleLLMService(LLMService):
|
||||
|
||||
if isinstance(context, OpenAILLMContext):
|
||||
context = GoogleLLMContext.upgrade_to_google(context)
|
||||
user = GoogleUserContextAggregator(context, **user_kwargs)
|
||||
assistant = GoogleAssistantContextAggregator(context, **assistant_kwargs)
|
||||
user = GoogleUserContextAggregator(context, params=user_params)
|
||||
assistant = GoogleAssistantContextAggregator(context, params=assistant_params)
|
||||
return GoogleContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
@@ -5,11 +5,14 @@
|
||||
#
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Mapping
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
@@ -124,8 +127,8 @@ class GrokLLMService(OpenAILLMService):
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_kwargs: Mapping[str, Any] = {},
|
||||
assistant_kwargs: Mapping[str, Any] = {},
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> GrokContextAggregatorPair:
|
||||
"""Create an instance of GrokContextAggregatorPair from an
|
||||
OpenAILLMContext. Constructor keyword arguments for both the user and
|
||||
@@ -133,12 +136,10 @@ class GrokLLMService(OpenAILLMService):
|
||||
|
||||
Args:
|
||||
context (OpenAILLMContext): The LLM context.
|
||||
user_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||
arguments for the user context aggregator constructor. Defaults
|
||||
to an empty mapping.
|
||||
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||
arguments for the assistant context aggregator
|
||||
constructor. Defaults to an empty mapping.
|
||||
user_params (LLMUserAggregatorParams, optional): User aggregator
|
||||
parameters.
|
||||
assistant_params (LLMAssistantAggregatorParams, optional): User
|
||||
aggregator parameters.
|
||||
|
||||
Returns:
|
||||
GrokContextAggregatorPair: A pair of context aggregators, one for
|
||||
@@ -148,6 +149,6 @@ class GrokLLMService(OpenAILLMService):
|
||||
"""
|
||||
context.set_llm_adapter(self.get_llm_adapter())
|
||||
|
||||
user = OpenAIUserContextAggregator(context, **user_kwargs)
|
||||
assistant = OpenAIAssistantContextAggregator(context, **assistant_kwargs)
|
||||
user = OpenAIUserContextAggregator(context, params=user_params)
|
||||
assistant = OpenAIAssistantContextAggregator(context, params=assistant_params)
|
||||
return GrokContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Mapping, Optional, Set, Tuple, Type
|
||||
from typing import Any, Optional, Set, Tuple, Type
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -20,6 +20,10 @@ from pipecat.frames.frames import (
|
||||
StartInterruptionFrame,
|
||||
UserImageRequestFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_service import AIService
|
||||
@@ -55,8 +59,8 @@ class LLMService(AIService):
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_kwargs: Mapping[str, Any] = {},
|
||||
assistant_kwargs: Mapping[str, Any] = {},
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> Any:
|
||||
pass
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Mapping
|
||||
from typing import Any
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
FunctionCallCancelFrame,
|
||||
@@ -15,7 +15,9 @@ from pipecat.frames.frames import (
|
||||
UserImageRawFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMAssistantContextAggregator,
|
||||
LLMUserAggregatorParams,
|
||||
LLMUserContextAggregator,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
@@ -48,8 +50,8 @@ class OpenAILLMService(BaseOpenAILLMService):
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_kwargs: Mapping[str, Any] = {},
|
||||
assistant_kwargs: Mapping[str, Any] = {},
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> OpenAIContextAggregatorPair:
|
||||
"""Create an instance of OpenAIContextAggregatorPair from an
|
||||
OpenAILLMContext. Constructor keyword arguments for both the user and
|
||||
@@ -57,12 +59,8 @@ class OpenAILLMService(BaseOpenAILLMService):
|
||||
|
||||
Args:
|
||||
context (OpenAILLMContext): The LLM context.
|
||||
user_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||
arguments for the user context aggregator constructor. Defaults
|
||||
to an empty mapping.
|
||||
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||
arguments for the assistant context aggregator
|
||||
constructor. Defaults to an empty mapping.
|
||||
user_params (LLMUserAggregatorParams, optional): User aggregator parameters.
|
||||
assistant_params (LLMAssistantAggregatorParams, optional): User aggregator parameters.
|
||||
|
||||
Returns:
|
||||
OpenAIContextAggregatorPair: A pair of context aggregators, one for
|
||||
@@ -71,8 +69,8 @@ class OpenAILLMService(BaseOpenAILLMService):
|
||||
|
||||
"""
|
||||
context.set_llm_adapter(self.get_llm_adapter())
|
||||
user = OpenAIUserContextAggregator(context, **user_kwargs)
|
||||
assistant = OpenAIAssistantContextAggregator(context, **assistant_kwargs)
|
||||
user = OpenAIUserContextAggregator(context, params=user_params)
|
||||
assistant = OpenAIAssistantContextAggregator(context, params=assistant_params)
|
||||
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
|
||||
|
||||
@@ -8,19 +8,9 @@ import base64
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Mapping
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
import websockets
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use OpenAI, you need to `pip install pipecat-ai[openai]`. Also, set `OPENAI_API_KEY` environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
@@ -48,6 +38,10 @@ from pipecat.frames.frames import (
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
@@ -65,6 +59,13 @@ from .context import (
|
||||
)
|
||||
from .frames import RealtimeFunctionCallResultFrame, RealtimeMessagesUpdateFrame
|
||||
|
||||
try:
|
||||
import websockets
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use OpenAI, you need to `pip install pipecat-ai[openai]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CurrentAudioResponse:
|
||||
@@ -650,8 +651,8 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_kwargs: Mapping[str, Any] = {},
|
||||
assistant_kwargs: Mapping[str, Any] = {},
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> OpenAIContextAggregatorPair:
|
||||
"""Create an instance of OpenAIContextAggregatorPair from an
|
||||
OpenAILLMContext. Constructor keyword arguments for both the user and
|
||||
@@ -659,12 +660,10 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
|
||||
Args:
|
||||
context (OpenAILLMContext): The LLM context.
|
||||
user_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||
arguments for the user context aggregator constructor. Defaults
|
||||
to an empty mapping.
|
||||
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||
arguments for the assistant context aggregator
|
||||
constructor. Defaults to an empty mapping.
|
||||
user_params (LLMUserAggregatorParams, optional): User aggregator
|
||||
parameters.
|
||||
assistant_params (LLMAssistantAggregatorParams, optional): User
|
||||
aggregator parameters.
|
||||
|
||||
Returns:
|
||||
OpenAIContextAggregatorPair: A pair of context aggregators, one for
|
||||
@@ -675,9 +674,8 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
context.set_llm_adapter(self.get_llm_adapter())
|
||||
|
||||
OpenAIRealtimeLLMContext.upgrade_to_realtime(context)
|
||||
user = OpenAIRealtimeUserContextAggregator(context, **user_kwargs)
|
||||
user = OpenAIRealtimeUserContextAggregator(context, params=user_params)
|
||||
|
||||
default_assistant_kwargs = {"expect_stripped_words": False}
|
||||
default_assistant_kwargs.update(assistant_kwargs)
|
||||
assistant = OpenAIRealtimeAssistantContextAggregator(context, **default_assistant_kwargs)
|
||||
assistant_params.expect_stripped_words = False
|
||||
assistant = OpenAIRealtimeAssistantContextAggregator(context, params=assistant_params)
|
||||
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
@@ -26,7 +26,11 @@ from pipecat.frames.frames import (
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response import LLMUserContextAggregator
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
LLMUserContextAggregator,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
@@ -163,7 +167,9 @@ class BaseTestUserContextAggregator:
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
)
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""),
|
||||
@@ -189,7 +195,9 @@ class BaseTestUserContextAggregator:
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
)
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""),
|
||||
@@ -216,7 +224,9 @@ class BaseTestUserContextAggregator:
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
)
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
@@ -240,7 +250,9 @@ class BaseTestUserContextAggregator:
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
)
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
@@ -265,7 +277,9 @@ class BaseTestUserContextAggregator:
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
)
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""),
|
||||
@@ -293,7 +307,9 @@ class BaseTestUserContextAggregator:
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
)
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
@@ -318,7 +334,9 @@ class BaseTestUserContextAggregator:
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
)
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
InterimTranscriptionFrame(text="Hello ", user_id="cat", timestamp=""),
|
||||
@@ -346,7 +364,9 @@ class BaseTestUserContextAggregator:
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
)
|
||||
frames_to_send = [
|
||||
TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""),
|
||||
SleepFrame(sleep=AGGREGATION_SLEEP),
|
||||
@@ -366,7 +386,9 @@ class BaseTestUserContextAggregator:
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
)
|
||||
frames_to_send = [
|
||||
InterimTranscriptionFrame(text="Hello ", user_id="cat", timestamp=""),
|
||||
SleepFrame(),
|
||||
@@ -389,8 +411,7 @@ class BaseTestUserContextAggregator:
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context,
|
||||
aggregation_timeout=AGGREGATION_TIMEOUT,
|
||||
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
)
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
@@ -469,7 +490,9 @@ class BaseTestAssistantContextAggreagator:
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False)
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=LLMAssistantAggregatorParams(expect_stripped_words=False)
|
||||
)
|
||||
frames_to_send = [
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame(text="Hello "),
|
||||
@@ -513,7 +536,9 @@ class BaseTestAssistantContextAggreagator:
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False)
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=LLMAssistantAggregatorParams(expect_stripped_words=False)
|
||||
)
|
||||
frames_to_send = [
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame(text="Hello "),
|
||||
@@ -538,7 +563,9 @@ class BaseTestAssistantContextAggreagator:
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False)
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=LLMAssistantAggregatorParams(expect_stripped_words=False)
|
||||
)
|
||||
frames_to_send = [
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame(text="Hello "),
|
||||
|
||||
@@ -20,6 +20,7 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMAssistantResponseAggregator,
|
||||
LLMUserResponseAggregator,
|
||||
)
|
||||
@@ -63,7 +64,9 @@ class TestLangchain(unittest.IsolatedAsyncioTestCase):
|
||||
self.mock_proc = self.MockProcessor("token_collector")
|
||||
|
||||
tma_in = LLMUserResponseAggregator(messages)
|
||||
tma_out = LLMAssistantResponseAggregator(messages, expect_stripped_words=False)
|
||||
tma_out = LLMAssistantResponseAggregator(
|
||||
messages, params=LLMAssistantAggregatorParams(expect_stripped_words=False)
|
||||
)
|
||||
|
||||
pipeline = Pipeline([tma_in, proc, self.mock_proc, tma_out])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user