Add a switch to the assistant aggregator for whether to insert spaces between words
This commit is contained in:
@@ -39,6 +39,7 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
accumulator_frame: Type[TextFrame],
|
||||
interim_accumulator_frame: Type[TextFrame] | None = None,
|
||||
handle_interruptions: bool = False,
|
||||
expect_stripped_words: bool = True, # if True, need to add spaces between words
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -49,6 +50,7 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
self._accumulator_frame = accumulator_frame
|
||||
self._interim_accumulator_frame = interim_accumulator_frame
|
||||
self._handle_interruptions = handle_interruptions
|
||||
self._expect_stripped_words = expect_stripped_words
|
||||
|
||||
# Reset our accumulator state.
|
||||
self._reset()
|
||||
@@ -110,7 +112,10 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, self._accumulator_frame):
|
||||
if self._aggregating:
|
||||
self._aggregation += frame.text if self._aggregation else frame.text
|
||||
if self._expect_stripped_words:
|
||||
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
|
||||
else:
|
||||
self._aggregation += frame.text if self._aggregation else frame.text
|
||||
# We have received a complete sentence, so if we have seen the
|
||||
# end frame and we were still aggregating, it means we should
|
||||
# send the aggregation.
|
||||
@@ -289,7 +294,7 @@ class LLMContextAggregator(LLMResponseAggregator):
|
||||
|
||||
|
||||
class LLMAssistantContextAggregator(LLMContextAggregator):
|
||||
def __init__(self, context: OpenAILLMContext):
|
||||
def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True):
|
||||
super().__init__(
|
||||
messages=[],
|
||||
context=context,
|
||||
@@ -298,6 +303,7 @@ class LLMAssistantContextAggregator(LLMContextAggregator):
|
||||
end_frame=LLMFullResponseEndFrame,
|
||||
accumulator_frame=TextFrame,
|
||||
handle_interruptions=True,
|
||||
expect_stripped_words=expect_stripped_words,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -110,9 +110,13 @@ class AnthropicLLMService(LLMService):
|
||||
return self._enable_prompt_caching_beta
|
||||
|
||||
@staticmethod
|
||||
def create_context_aggregator(context: OpenAILLMContext) -> AnthropicContextAggregatorPair:
|
||||
def create_context_aggregator(
|
||||
context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
|
||||
) -> AnthropicContextAggregatorPair:
|
||||
user = AnthropicUserContextAggregator(context)
|
||||
assistant = AnthropicAssistantContextAggregator(user)
|
||||
assistant = AnthropicAssistantContextAggregator(
|
||||
user, expect_stripped_words=assistant_expect_stripped_words
|
||||
)
|
||||
return AnthropicContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
async def set_enable_prompt_caching_beta(self, enable_prompt_caching_beta: bool):
|
||||
@@ -541,8 +545,8 @@ class AnthropicUserContextAggregator(LLMUserContextAggregator):
|
||||
|
||||
|
||||
class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
def __init__(self, user_context_aggregator: AnthropicUserContextAggregator):
|
||||
super().__init__(context=user_context_aggregator._context)
|
||||
def __init__(self, user_context_aggregator: AnthropicUserContextAggregator, **kwargs):
|
||||
super().__init__(context=user_context_aggregator._context, **kwargs)
|
||||
self._user_context_aggregator = user_context_aggregator
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_result = None
|
||||
|
||||
@@ -336,9 +336,13 @@ class OpenAILLMService(BaseOpenAILLMService):
|
||||
super().__init__(model=model, params=params, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def create_context_aggregator(context: OpenAILLMContext) -> OpenAIContextAggregatorPair:
|
||||
def create_context_aggregator(
|
||||
context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
|
||||
) -> OpenAIContextAggregatorPair:
|
||||
user = OpenAIUserContextAggregator(context)
|
||||
assistant = OpenAIAssistantContextAggregator(user)
|
||||
assistant = OpenAIAssistantContextAggregator(
|
||||
user, expect_stripped_words=assistant_expect_stripped_words
|
||||
)
|
||||
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
|
||||
@@ -458,8 +462,8 @@ class OpenAIUserContextAggregator(LLMUserContextAggregator):
|
||||
|
||||
|
||||
class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
def __init__(self, user_context_aggregator: OpenAIUserContextAggregator):
|
||||
super().__init__(context=user_context_aggregator._context)
|
||||
def __init__(self, user_context_aggregator: OpenAIUserContextAggregator, **kwargs):
|
||||
super().__init__(context=user_context_aggregator._context, **kwargs)
|
||||
self._user_context_aggregator = user_context_aggregator
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_result = None
|
||||
|
||||
@@ -95,9 +95,13 @@ class TogetherLLMService(LLMService):
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def create_context_aggregator(context: OpenAILLMContext) -> TogetherContextAggregatorPair:
|
||||
def create_context_aggregator(
|
||||
context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
|
||||
) -> TogetherContextAggregatorPair:
|
||||
user = TogetherUserContextAggregator(context)
|
||||
assistant = TogetherAssistantContextAggregator(user)
|
||||
assistant = TogetherAssistantContextAggregator(
|
||||
user, expect_stripped_words=assistant_expect_stripped_words
|
||||
)
|
||||
return TogetherContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
async def set_frequency_penalty(self, frequency_penalty: float):
|
||||
@@ -331,8 +335,8 @@ class TogetherUserContextAggregator(LLMUserContextAggregator):
|
||||
|
||||
|
||||
class TogetherAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
def __init__(self, user_context_aggregator: TogetherUserContextAggregator):
|
||||
super().__init__(context=user_context_aggregator._context)
|
||||
def __init__(self, user_context_aggregator: TogetherUserContextAggregator, **kwargs):
|
||||
super().__init__(context=user_context_aggregator._context, **kwargs)
|
||||
self._user_context_aggregator = user_context_aggregator
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_result = None
|
||||
|
||||
Reference in New Issue
Block a user