reimplement LLM response aggregators

This commit is contained in:
Aleix Conchillo Flaqué
2025-02-11 22:16:10 -08:00
parent 8bdd7ed0ed
commit e1f2bbceb3
9 changed files with 275 additions and 251 deletions

View File

@@ -38,10 +38,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
OpenAI-compatible interface. Also, added foundational example
`14n-function-calling-perplexity.py`.
- Added `DailyTransport.update_remote_participants()`. This allows you to
update remote participant's settings, like their permissions or which of
their devices are enabled. Requires that the local participant have
participant admin permission.
- Added `DailyTransport.update_remote_participants()`. This allows you to update
remote participant's settings, like their permissions or which of their
devices are enabled. Requires that the local participant have participant
admin permission.
### Changed
@@ -91,6 +91,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Fixed multiple issue where user transcriptions where not being handled
properly. It was possible for short utterances to not trigger VAD which would
cause user transcriptions to be ignored. It was also possible for one or more
transcriptions to be generated after VAD in which case they would also be
ignored.
- Fixed an issue that was causing `BotStoppedSpeakingFrame` to be generated too
late. This could then cause issues unblocking `STTMuteFilter` later than
desired.
@@ -283,7 +289,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `enable_recording` and `geo` parameters to `DailyRoomProperties`.
- Added `RecordingsBucketConfig` to `DailyRoomProperties` to upload recordings to a custom AWS bucket.
- Added `RecordingsBucketConfig` to `DailyRoomProperties` to upload recordings
to a custom AWS bucket.
### Changed

View File

@@ -497,7 +497,7 @@ class UserAggregatorBuffer(LLMResponseAggregator):
if isinstance(frame, UserStartedSpeakingFrame):
self._transcription = ""
async def _push_aggregation(self):
async def push_aggregation(self):
if self._aggregation:
self._transcription = self._aggregation
self._aggregation = ""

View File

@@ -4,9 +4,13 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
from typing import List, Optional, Type
import asyncio
from abc import abstractmethod
from typing import List
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
Frame,
InterimTranscriptionFrame,
LLMFullResponseEndFrame,
@@ -15,6 +19,7 @@ from pipecat.frames.frames import (
LLMMessagesFrame,
LLMMessagesUpdateFrame,
LLMSetToolsFrame,
StartFrame,
StartInterruptionFrame,
TextFrame,
TranscriptionFrame,
@@ -28,121 +33,81 @@ from pipecat.processors.aggregators.openai_llm_context import (
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
class LLMResponseAggregator(FrameProcessor):
class BaseLLMResponseAggregator(FrameProcessor):
def __init__(self, **kwargs):
super().__init__(**kwargs)
@property
@abstractmethod
def messages(self) -> List[dict]:
pass
@property
@abstractmethod
def role(self) -> str:
pass
@abstractmethod
def add_messages(self, messages):
pass
@abstractmethod
def set_messages(self, messages):
pass
@abstractmethod
def set_tools(self, tools):
pass
@abstractmethod
def reset(self):
pass
@abstractmethod
async def push_aggregation(self):
pass
class LLMResponseAggregator(BaseLLMResponseAggregator):
def __init__(
self,
*,
messages: List[dict],
role: str,
start_frame,
end_frame,
accumulator_frame: Type[TextFrame],
interim_accumulator_frame: Optional[Type[TextFrame]] = None,
handle_interruptions: bool = False,
expect_stripped_words: bool = True, # if True, need to add spaces between words
role: str = "user",
**kwargs,
):
super().__init__()
super().__init__(**kwargs)
self._messages = messages
self._role = role
self._start_frame = start_frame
self._end_frame = end_frame
self._accumulator_frame = accumulator_frame
self._interim_accumulator_frame = interim_accumulator_frame
self._handle_interruptions = handle_interruptions
self._expect_stripped_words = expect_stripped_words
# Reset our accumulator state.
self._reset()
self._aggregation = ""
self.reset()
@property
def messages(self):
def messages(self) -> List[dict]:
return self._messages
@property
def role(self):
def role(self) -> str:
return self._role
#
# Frame processor
#
def add_messages(self, messages):
self._messages.extend(messages)
# Use cases implemented:
#
# S: Start, E: End, T: Transcription, I: Interim, X: Text
#
# S E -> None
# S T E -> X
# S I T E -> X
# S I E T -> X
# S I E I T -> X
# S E T -> X
# S E I T -> X
#
# The following case would not be supported:
#
# S I E T1 I T2 -> X
#
# and T2 would be dropped.
def set_messages(self, messages):
self.reset()
self._messages.clear()
self._messages.extend(messages)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
def set_tools(self, tools):
pass
send_aggregation = False
def reset(self):
self._aggregation = ""
if isinstance(frame, self._start_frame):
self._aggregation = ""
self._aggregating = True
self._seen_start_frame = True
self._seen_end_frame = False
self._seen_interim_results = False
await self.push_frame(frame, direction)
elif isinstance(frame, self._end_frame):
self._seen_end_frame = True
self._seen_start_frame = False
# We might have received the end frame but we might still be
# aggregating (i.e. we have seen interim results but not the final
# text).
self._aggregating = self._seen_interim_results or len(self._aggregation) == 0
# Send the aggregation if we are not aggregating anymore (i.e. no
# more interim results received).
send_aggregation = not self._aggregating
await self.push_frame(frame, direction)
elif isinstance(frame, self._accumulator_frame):
if self._aggregating:
if self._expect_stripped_words:
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
else:
self._aggregation += frame.text
# We have recevied a complete sentence, so if we have seen the
# end frame and we were still aggregating, it means we should
# send the aggregation.
send_aggregation = self._seen_end_frame
# We just got our final result, so let's reset interim results.
self._seen_interim_results = False
elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame):
self._seen_interim_results = True
elif self._handle_interruptions and isinstance(frame, StartInterruptionFrame):
await self._push_aggregation()
# Reset anyways
self._reset()
await self.push_frame(frame, direction)
elif isinstance(frame, LLMMessagesAppendFrame):
self._add_messages(frame.messages)
elif isinstance(frame, LLMMessagesUpdateFrame):
self._set_messages(frame.messages)
elif isinstance(frame, LLMSetToolsFrame):
self._set_tools(frame.tools)
else:
await self.push_frame(frame, direction)
if send_aggregation:
await self._push_aggregation()
async def _push_aggregation(self):
async def push_aggregation(self):
if len(self._aggregation) > 0:
self._messages.append({"role": self._role, "content": self._aggregation})
@@ -153,109 +118,22 @@ class LLMResponseAggregator(FrameProcessor):
frame = LLMMessagesFrame(self._messages)
await self.push_frame(frame)
# TODO-CB: Types
def _add_messages(self, messages):
self._messages.extend(messages)
def _set_messages(self, messages):
self._reset()
self._messages.clear()
self._messages.extend(messages)
def _set_tools(self, tools):
# noop in the base class
pass
def _reset(self):
self._aggregation = ""
self._aggregating = False
self._seen_start_frame = False
self._seen_end_frame = False
self._seen_interim_results = False
class LLMAssistantResponseAggregator(LLMResponseAggregator):
def __init__(self, messages: List[dict] = []):
super().__init__(
messages=messages,
role="assistant",
start_frame=LLMFullResponseStartFrame,
end_frame=LLMFullResponseEndFrame,
accumulator_frame=TextFrame,
handle_interruptions=True,
)
class LLMUserResponseAggregator(LLMResponseAggregator):
def __init__(self, messages: List[dict] = []):
super().__init__(
messages=messages,
role="user",
start_frame=UserStartedSpeakingFrame,
end_frame=UserStoppedSpeakingFrame,
accumulator_frame=TranscriptionFrame,
interim_accumulator_frame=InterimTranscriptionFrame,
)
class LLMFullResponseAggregator(FrameProcessor):
"""This class aggregates Text frames until it receives a
LLMFullResponseEndFrame, then emits the concatenated text as
a single text frame.
given the following frames:
TextFrame("Hello,")
TextFrame(" world.")
TextFrame(" I am")
TextFrame(" an LLM.")
LLMFullResponseEndFrame()]
this processor will yield nothing for the first 4 frames, then
TextFrame("Hello, world. I am an LLM.")
LLMFullResponseEndFrame()
when passed the last frame.
>>> async def print_frames(aggregator, frame):
... async for frame in aggregator.process_frame(frame):
... if isinstance(frame, TextFrame):
... print(frame.text)
... else:
... print(frame.__class__.__name__)
>>> aggregator = LLMFullResponseAggregator()
>>> asyncio.run(print_frames(aggregator, TextFrame("Hello,")))
>>> asyncio.run(print_frames(aggregator, TextFrame(" world.")))
>>> asyncio.run(print_frames(aggregator, TextFrame(" I am")))
>>> asyncio.run(print_frames(aggregator, TextFrame(" an LLM.")))
>>> asyncio.run(print_frames(aggregator, LLMFullResponseEndFrame()))
Hello, world. I am an LLM.
LLMFullResponseEndFrame
"""
def __init__(self):
super().__init__()
self._aggregation = ""
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, TextFrame):
self._aggregation += frame.text
elif isinstance(frame, LLMFullResponseEndFrame):
await self.push_frame(TextFrame(self._aggregation))
await self.push_frame(frame)
self._aggregation = ""
else:
await self.push_frame(frame, direction)
class LLMContextAggregator(LLMResponseAggregator):
def __init__(self, *, context: OpenAILLMContext, **kwargs):
class LLMContextResponseAggregator(BaseLLMResponseAggregator):
def __init__(self, *, context: OpenAILLMContext, role: str, **kwargs):
super().__init__(**kwargs)
self._context = context
self._role = role
self._aggregation = ""
@property
def messages(self) -> List[dict]:
return self._context.get_messages()
@property
def role(self) -> str:
return self._role
@property
def context(self):
@@ -268,19 +146,18 @@ class LLMContextAggregator(LLMResponseAggregator):
frame = self.get_context_frame()
await self.push_frame(frame)
# TODO-CB: Types
def _add_messages(self, messages):
def add_messages(self, messages):
self._context.add_messages(messages)
def _set_messages(self, messages):
def set_messages(self, messages):
self._context.set_messages(messages)
def _set_tools(self, tools: List):
def set_tools(self, tools: List):
self._context.set_tools(tools)
async def _push_aggregation(self):
async def push_aggregation(self):
if len(self._aggregation) > 0:
self._context.add_message({"role": self._role, "content": self._aggregation})
self._context.add_message({"role": self.role, "content": self._aggregation})
# Reset the aggregation. Reset it before pushing it down, otherwise
# if the tasks gets cancelled we won't be able to clear things up.
@@ -290,31 +167,171 @@ class LLMContextAggregator(LLMResponseAggregator):
await self.push_frame(frame)
# Reset our accumulator state.
self._reset()
self.reset()
class LLMAssistantContextAggregator(LLMContextAggregator):
def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True):
super().__init__(
messages=[],
context=context,
role="assistant",
start_frame=LLMFullResponseStartFrame,
end_frame=LLMFullResponseEndFrame,
accumulator_frame=TextFrame,
handle_interruptions=True,
expect_stripped_words=expect_stripped_words,
)
class LLMUserContextAggregator(LLMContextResponseAggregator):
def __init__(self, context: OpenAILLMContext, aggregation_timeout: float = 1.0, **kwargs):
super().__init__(context=context, role="user", **kwargs)
self._aggregation_timeout = aggregation_timeout
self._seen_interim_results = False
self._user_speaking = False
self._aggregation_event = asyncio.Event()
self._aggregation_task = None
self.reset()
def reset(self):
super().reset()
self._seen_interim_results = False
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, StartFrame):
await self._start(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, EndFrame):
await self._stop(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, CancelFrame):
await self._cancel(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, UserStartedSpeakingFrame):
await self._handle_user_started_speaking(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, UserStoppedSpeakingFrame):
await self._handle_user_stopped_speaking(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, TranscriptionFrame):
await self._handle_transcription(frame)
elif isinstance(frame, InterimTranscriptionFrame):
await self._handle_interim_transcription(frame)
elif isinstance(frame, LLMMessagesAppendFrame):
self.add_messages(frame.messages)
elif isinstance(frame, LLMMessagesUpdateFrame):
self.set_messages(frame.messages)
elif isinstance(frame, LLMSetToolsFrame):
self.set_tools(frame.tools)
else:
await self.push_frame(frame, direction)
async def _start(self, frame: StartFrame):
self._aggregation_task = self.create_task(self._aggregation_task_handler())
async def _stop(self, frame: EndFrame):
if self._aggregation_task:
await self.cancel_task(self._aggregation_task)
self._aggregation_task = None
async def _cancel(self, frame: CancelFrame):
if self._aggregation_task:
await self.cancel_task(self._aggregation_task)
self._aggregation_task = None
async def _handle_user_started_speaking(self, _: UserStartedSpeakingFrame):
self._user_speaking = True
async def _handle_user_stopped_speaking(self, _: UserStoppedSpeakingFrame):
self._user_speaking = False
if not self._seen_interim_results:
await self.push_aggregation()
async def _handle_transcription(self, frame: TranscriptionFrame):
self._aggregation += frame.text
# We just got our final result, so let's reset interim results.
self._seen_interim_results = False
# Wakeup our task.
self._aggregation_event.set()
async def _handle_interim_transcription(self, _: InterimTranscriptionFrame):
self._seen_interim_results = True
async def _aggregation_task_handler(self):
while True:
await self._aggregation_event.wait()
await asyncio.sleep(self._aggregation_timeout)
if not self._user_speaking:
await self.push_aggregation()
self._aggregation_event.clear()
class LLMUserContextAggregator(LLMContextAggregator):
def __init__(self, context: OpenAILLMContext):
super().__init__(
messages=[],
context=context,
role="user",
start_frame=UserStartedSpeakingFrame,
end_frame=UserStoppedSpeakingFrame,
accumulator_frame=TranscriptionFrame,
interim_accumulator_frame=InterimTranscriptionFrame,
)
class LLMAssistantContextAggregator(LLMContextResponseAggregator):
def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True, **kwargs):
super().__init__(context=context, role="assistant", **kwargs)
self._expect_stripped_words = expect_stripped_words
self.reset()
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, StartInterruptionFrame):
await self.push_aggregation()
# Reset anyways
self.reset()
await self.push_frame(frame, direction)
elif isinstance(frame, LLMFullResponseStartFrame):
await self._handle_llm_start(frame)
elif isinstance(frame, LLMFullResponseEndFrame):
await self._handle_llm_end(frame)
elif isinstance(frame, TextFrame):
await self._handle_text(frame)
else:
await self.push_frame(frame, direction)
async def _handle_llm_start(self, _: LLMFullResponseStartFrame):
self._started = True
async def _handle_llm_end(self, _: LLMFullResponseEndFrame):
self._started = False
await self.push_aggregation()
async def _handle_text(self, frame: TextFrame):
if not self._started:
return
if self._expect_stripped_words:
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
else:
self._aggregation += frame.text
class LLMUserResponseAggregator(LLMUserContextAggregator):
def __init__(self, messages: List[dict] = [], **kwargs):
super().__init__(context=OpenAILLMContext(messages), **kwargs)
async def push_aggregation(self):
if len(self._aggregation) > 0:
self._context.add_message({"role": self.role, "content": self._aggregation})
# Reset the aggregation. Reset it before pushing it down, otherwise
# if the tasks gets cancelled we won't be able to clear things up.
self._aggregation = ""
frame = LLMMessagesFrame(self._context.messages)
await self.push_frame(frame)
# Reset our accumulator state.
self.reset()
class LLMAssistantResponseAggregator(LLMAssistantContextAggregator):
def __init__(self, messages: List[dict], **kwargs):
super().__init__(context=OpenAILLMContext(messages), **kwargs)
async def push_aggregation(self):
if len(self._aggregation) > 0:
self._context.add_message({"role": self.role, "content": self._aggregation})
# Reset the aggregation. Reset it before pushing it down, otherwise
# if the tasks gets cancelled we won't be able to clear things up.
self._aggregation = ""
frame = LLMMessagesFrame(self._context.messages)
await self.push_frame(frame)
# Reset our accumulator state.
self.reset()

View File

@@ -725,7 +725,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
):
self._function_call_in_progress = None
self._function_call_result = frame
await self._push_aggregation()
await self.push_aggregation()
else:
logger.warning(
"FunctionCallResultFrame tool_call_id != InProgressFrame tool_call_id"
@@ -734,9 +734,9 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
self._function_call_result = None
elif isinstance(frame, AnthropicImageMessageFrame):
self._pending_image_frame_message = frame
await self._push_aggregation()
await self.push_aggregation()
async def _push_aggregation(self):
async def push_aggregation(self):
if not (
self._aggregation or self._function_call_result or self._pending_image_frame_message
):
@@ -746,7 +746,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
properties: Optional[FunctionCallResultProperties] = None
aggregation = self._aggregation
self._reset()
self.reset()
try:
if self._function_call_result:

View File

@@ -115,10 +115,10 @@ class GeminiMultimodalLiveUserContextAggregator(OpenAIUserContextAggregator):
class GeminiMultimodalLiveAssistantContextAggregator(OpenAIAssistantContextAggregator):
async def _push_aggregation(self):
async def push_aggregation(self):
# We don't want to store any images in the context. Revisit this later when the API evolves.
self._pending_image_frame_message = None
await super()._push_aggregation()
await super().push_aggregation()
@dataclass

View File

@@ -537,7 +537,7 @@ def language_to_google_stt_language(language: Language) -> Optional[str]:
class GoogleUserContextAggregator(OpenAIUserContextAggregator):
async def _push_aggregation(self):
async def push_aggregation(self):
if len(self._aggregation) > 0:
self._context.add_message(
glm.Content(role="user", parts=[glm.Part(text=self._aggregation)])
@@ -552,11 +552,11 @@ class GoogleUserContextAggregator(OpenAIUserContextAggregator):
await self.push_frame(frame)
# Reset our accumulator state.
self._reset()
self.reset()
class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
async def _push_aggregation(self):
async def push_aggregation(self):
if not (
self._aggregation or self._function_call_result or self._pending_image_frame_message
):
@@ -566,7 +566,7 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
properties: Optional[FunctionCallResultProperties] = None
aggregation = self._aggregation
self._reset()
self.reset()
try:
if self._function_call_result:

View File

@@ -27,7 +27,7 @@ from pipecat.services.openai import (
class GrokAssistantContextAggregator(OpenAIAssistantContextAggregator):
"""Custom assistant context aggregator for Grok that handles empty content requirement."""
async def _push_aggregation(self):
async def push_aggregation(self):
if not (
self._aggregation or self._function_call_result or self._pending_image_frame_message
):
@@ -37,7 +37,7 @@ class GrokAssistantContextAggregator(OpenAIAssistantContextAggregator):
properties: Optional[FunctionCallResultProperties] = None
aggregation = self._aggregation
self._reset()
self.reset()
try:
if self._function_call_result:

View File

@@ -614,7 +614,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
del self._function_calls_in_progress[frame.tool_call_id]
self._function_call_result = frame
# TODO-CB: Kwin wants us to refactor this out of here but I REFUSE
await self._push_aggregation()
await self.push_aggregation()
else:
logger.warning(
"FunctionCallResultFrame tool_call_id does not match any function call in progress"
@@ -622,9 +622,9 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
self._function_call_result = None
elif isinstance(frame, OpenAIImageMessageFrame):
self._pending_image_frame_message = frame
await self._push_aggregation()
await self.push_aggregation()
async def _push_aggregation(self):
async def push_aggregation(self):
if not (
self._aggregation or self._function_call_result or self._pending_image_frame_message
):
@@ -634,7 +634,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
properties: Optional[FunctionCallResultProperties] = None
aggregation = self._aggregation
self._reset()
self.reset()
try:
if self._function_call_result:

View File

@@ -166,7 +166,7 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
if isinstance(frame, LLMSetToolsFrame):
await self.push_frame(frame, direction)
async def _push_aggregation(self):
async def push_aggregation(self):
# for the moment, ignore all user input coming into the pipeline.
# todo: think about whether/how to fix this to allow for text input from
# upstream (transport/transcription, or other sources)
@@ -174,7 +174,7 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator):
async def _push_aggregation(self):
async def push_aggregation(self):
# the only thing we implement here is function calling. in all other cases, messages
# are added to the context when we receive openai realtime api events
if not self._function_call_result:
@@ -182,7 +182,7 @@ class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator)
properties: Optional[FunctionCallResultProperties] = None
self._reset()
self.reset()
try:
run_llm = True
frame = self._function_call_result