From fef393dcacf4e1c94d7a819cf3ea35722fbc8a67 Mon Sep 17 00:00:00 2001 From: Kwindla Hultman Kramer Date: Sun, 29 Sep 2024 14:11:17 -0700 Subject: [PATCH] assistant aggregator switch for space padding or not --- src/pipecat/processors/aggregators/llm_response.py | 10 ++++++++-- src/pipecat/services/anthropic.py | 12 ++++++++---- src/pipecat/services/openai.py | 12 ++++++++---- src/pipecat/services/together.py | 12 ++++++++---- 4 files changed, 32 insertions(+), 14 deletions(-) diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index a3cd63cbd..4ea38b89f 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -39,6 +39,7 @@ class LLMResponseAggregator(FrameProcessor): accumulator_frame: Type[TextFrame], interim_accumulator_frame: Type[TextFrame] | None = None, handle_interruptions: bool = False, + expect_stripped_words: bool = True, # if True, need to add spaces between words ): super().__init__() @@ -49,6 +50,7 @@ class LLMResponseAggregator(FrameProcessor): self._accumulator_frame = accumulator_frame self._interim_accumulator_frame = interim_accumulator_frame self._handle_interruptions = handle_interruptions + self._expect_stripped_words = expect_stripped_words # Reset our accumulator state. self._reset() @@ -110,7 +112,10 @@ class LLMResponseAggregator(FrameProcessor): await self.push_frame(frame, direction) elif isinstance(frame, self._accumulator_frame): if self._aggregating: - self._aggregation += frame.text if self._aggregation else frame.text + if self._expect_stripped_words: + self._aggregation += f" {frame.text}" if self._aggregation else frame.text + else: + self._aggregation += frame.text if self._aggregation else frame.text # We have recevied a complete sentence, so if we have seen the # end frame and we were still aggregating, it means we should # send the aggregation. @@ -289,7 +294,7 @@ class LLMContextAggregator(LLMResponseAggregator): class LLMAssistantContextAggregator(LLMContextAggregator): - def __init__(self, context: OpenAILLMContext): + def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True): super().__init__( messages=[], context=context, @@ -298,6 +303,7 @@ class LLMAssistantContextAggregator(LLMContextAggregator): end_frame=LLMFullResponseEndFrame, accumulator_frame=TextFrame, handle_interruptions=True, + expect_stripped_words=expect_stripped_words, ) diff --git a/src/pipecat/services/anthropic.py b/src/pipecat/services/anthropic.py index 6a535ef15..86e1e3726 100644 --- a/src/pipecat/services/anthropic.py +++ b/src/pipecat/services/anthropic.py @@ -110,9 +110,13 @@ class AnthropicLLMService(LLMService): return self._enable_prompt_caching_beta @staticmethod - def create_context_aggregator(context: OpenAILLMContext) -> AnthropicContextAggregatorPair: + def create_context_aggregator( + context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True + ) -> AnthropicContextAggregatorPair: user = AnthropicUserContextAggregator(context) - assistant = AnthropicAssistantContextAggregator(user) + assistant = AnthropicAssistantContextAggregator( + user, expect_stripped_words=assistant_expect_stripped_words + ) return AnthropicContextAggregatorPair(_user=user, _assistant=assistant) async def set_enable_prompt_caching_beta(self, enable_prompt_caching_beta: bool): @@ -541,8 +545,8 @@ class AnthropicUserContextAggregator(LLMUserContextAggregator): class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator): - def __init__(self, user_context_aggregator: AnthropicUserContextAggregator): - super().__init__(context=user_context_aggregator._context) + def __init__(self, user_context_aggregator: AnthropicUserContextAggregator, **kwargs): + super().__init__(context=user_context_aggregator._context, **kwargs) self._user_context_aggregator = user_context_aggregator self._function_call_in_progress = None self._function_call_result = None diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index 99d2d7497..c17916f2d 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -336,9 +336,13 @@ class OpenAILLMService(BaseOpenAILLMService): super().__init__(model=model, params=params, **kwargs) @staticmethod - def create_context_aggregator(context: OpenAILLMContext) -> OpenAIContextAggregatorPair: + def create_context_aggregator( + context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True + ) -> OpenAIContextAggregatorPair: user = OpenAIUserContextAggregator(context) - assistant = OpenAIAssistantContextAggregator(user) + assistant = OpenAIAssistantContextAggregator( + user, expect_stripped_words=assistant_expect_stripped_words + ) return OpenAIContextAggregatorPair(_user=user, _assistant=assistant) @@ -458,8 +462,8 @@ class OpenAIUserContextAggregator(LLMUserContextAggregator): class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator): - def __init__(self, user_context_aggregator: OpenAIUserContextAggregator): - super().__init__(context=user_context_aggregator._context) + def __init__(self, user_context_aggregator: OpenAIUserContextAggregator, **kwargs): + super().__init__(context=user_context_aggregator._context, **kwargs) self._user_context_aggregator = user_context_aggregator self._function_call_in_progress = None self._function_call_result = None diff --git a/src/pipecat/services/together.py b/src/pipecat/services/together.py index 935f625ad..3f4d97964 100644 --- a/src/pipecat/services/together.py +++ b/src/pipecat/services/together.py @@ -95,9 +95,13 @@ class TogetherLLMService(LLMService): return True @staticmethod - def create_context_aggregator(context: OpenAILLMContext) -> TogetherContextAggregatorPair: + def create_context_aggregator( + context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True + ) -> TogetherContextAggregatorPair: user = TogetherUserContextAggregator(context) - assistant = TogetherAssistantContextAggregator(user) + assistant = TogetherAssistantContextAggregator( + user, expect_stripped_words=assistant_expect_stripped_words + ) return TogetherContextAggregatorPair(_user=user, _assistant=assistant) async def set_frequency_penalty(self, frequency_penalty: float): @@ -331,8 +335,8 @@ class TogetherUserContextAggregator(LLMUserContextAggregator): class TogetherAssistantContextAggregator(LLMAssistantContextAggregator): - def __init__(self, user_context_aggregator: TogetherUserContextAggregator): - super().__init__(context=user_context_aggregator._context) + def __init__(self, user_context_aggregator: TogetherUserContextAggregator, **kwargs): + super().__init__(context=user_context_aggregator._context, **kwargs) self._user_context_aggregator = user_context_aggregator self._function_call_in_progress = None self._function_call_result = None