assistant aggregator switch for space padding or not

This commit is contained in:
Kwindla Hultman Kramer
2024-09-29 14:11:17 -07:00
parent ed607d5c4b
commit fef393dcac
4 changed files with 32 additions and 14 deletions

View File

@@ -39,6 +39,7 @@ class LLMResponseAggregator(FrameProcessor):
accumulator_frame: Type[TextFrame],
interim_accumulator_frame: Type[TextFrame] | None = None,
handle_interruptions: bool = False,
expect_stripped_words: bool = True, # if True, need to add spaces between words
):
super().__init__()
@@ -49,6 +50,7 @@ class LLMResponseAggregator(FrameProcessor):
self._accumulator_frame = accumulator_frame
self._interim_accumulator_frame = interim_accumulator_frame
self._handle_interruptions = handle_interruptions
self._expect_stripped_words = expect_stripped_words
# Reset our accumulator state.
self._reset()
@@ -110,7 +112,10 @@ class LLMResponseAggregator(FrameProcessor):
await self.push_frame(frame, direction)
elif isinstance(frame, self._accumulator_frame):
if self._aggregating:
self._aggregation += frame.text if self._aggregation else frame.text
if self._expect_stripped_words:
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
else:
self._aggregation += frame.text if self._aggregation else frame.text
# We have recevied a complete sentence, so if we have seen the
# end frame and we were still aggregating, it means we should
# send the aggregation.
@@ -289,7 +294,7 @@ class LLMContextAggregator(LLMResponseAggregator):
class LLMAssistantContextAggregator(LLMContextAggregator):
def __init__(self, context: OpenAILLMContext):
def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True):
super().__init__(
messages=[],
context=context,
@@ -298,6 +303,7 @@ class LLMAssistantContextAggregator(LLMContextAggregator):
end_frame=LLMFullResponseEndFrame,
accumulator_frame=TextFrame,
handle_interruptions=True,
expect_stripped_words=expect_stripped_words,
)

View File

@@ -110,9 +110,13 @@ class AnthropicLLMService(LLMService):
return self._enable_prompt_caching_beta
@staticmethod
def create_context_aggregator(context: OpenAILLMContext) -> AnthropicContextAggregatorPair:
def create_context_aggregator(
context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
) -> AnthropicContextAggregatorPair:
user = AnthropicUserContextAggregator(context)
assistant = AnthropicAssistantContextAggregator(user)
assistant = AnthropicAssistantContextAggregator(
user, expect_stripped_words=assistant_expect_stripped_words
)
return AnthropicContextAggregatorPair(_user=user, _assistant=assistant)
async def set_enable_prompt_caching_beta(self, enable_prompt_caching_beta: bool):
@@ -541,8 +545,8 @@ class AnthropicUserContextAggregator(LLMUserContextAggregator):
class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
def __init__(self, user_context_aggregator: AnthropicUserContextAggregator):
super().__init__(context=user_context_aggregator._context)
def __init__(self, user_context_aggregator: AnthropicUserContextAggregator, **kwargs):
super().__init__(context=user_context_aggregator._context, **kwargs)
self._user_context_aggregator = user_context_aggregator
self._function_call_in_progress = None
self._function_call_result = None

View File

@@ -336,9 +336,13 @@ class OpenAILLMService(BaseOpenAILLMService):
super().__init__(model=model, params=params, **kwargs)
@staticmethod
def create_context_aggregator(context: OpenAILLMContext) -> OpenAIContextAggregatorPair:
def create_context_aggregator(
context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
) -> OpenAIContextAggregatorPair:
user = OpenAIUserContextAggregator(context)
assistant = OpenAIAssistantContextAggregator(user)
assistant = OpenAIAssistantContextAggregator(
user, expect_stripped_words=assistant_expect_stripped_words
)
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
@@ -458,8 +462,8 @@ class OpenAIUserContextAggregator(LLMUserContextAggregator):
class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
def __init__(self, user_context_aggregator: OpenAIUserContextAggregator):
super().__init__(context=user_context_aggregator._context)
def __init__(self, user_context_aggregator: OpenAIUserContextAggregator, **kwargs):
super().__init__(context=user_context_aggregator._context, **kwargs)
self._user_context_aggregator = user_context_aggregator
self._function_call_in_progress = None
self._function_call_result = None

View File

@@ -95,9 +95,13 @@ class TogetherLLMService(LLMService):
return True
@staticmethod
def create_context_aggregator(context: OpenAILLMContext) -> TogetherContextAggregatorPair:
def create_context_aggregator(
context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
) -> TogetherContextAggregatorPair:
user = TogetherUserContextAggregator(context)
assistant = TogetherAssistantContextAggregator(user)
assistant = TogetherAssistantContextAggregator(
user, expect_stripped_words=assistant_expect_stripped_words
)
return TogetherContextAggregatorPair(_user=user, _assistant=assistant)
async def set_frequency_penalty(self, frequency_penalty: float):
@@ -331,8 +335,8 @@ class TogetherUserContextAggregator(LLMUserContextAggregator):
class TogetherAssistantContextAggregator(LLMAssistantContextAggregator):
def __init__(self, user_context_aggregator: TogetherUserContextAggregator):
super().__init__(context=user_context_aggregator._context)
def __init__(self, user_context_aggregator: TogetherUserContextAggregator, **kwargs):
super().__init__(context=user_context_aggregator._context, **kwargs)
self._user_context_aggregator = user_context_aggregator
self._function_call_in_progress = None
self._function_call_result = None