reimplement LLM response aggregators
This commit is contained in:
17
CHANGELOG.md
17
CHANGELOG.md
@@ -38,10 +38,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
OpenAI-compatible interface. Also, added foundational example
|
||||
`14n-function-calling-perplexity.py`.
|
||||
|
||||
- Added `DailyTransport.update_remote_participants()`. This allows you to
|
||||
update remote participant's settings, like their permissions or which of
|
||||
their devices are enabled. Requires that the local participant have
|
||||
participant admin permission.
|
||||
- Added `DailyTransport.update_remote_participants()`. This allows you to update
|
||||
remote participant's settings, like their permissions or which of their
|
||||
devices are enabled. Requires that the local participant have participant
|
||||
admin permission.
|
||||
|
||||
### Changed
|
||||
|
||||
@@ -91,6 +91,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed multiple issue where user transcriptions where not being handled
|
||||
properly. It was possible for short utterances to not trigger VAD which would
|
||||
cause user transcriptions to be ignored. It was also possible for one or more
|
||||
transcriptions to be generated after VAD in which case they would also be
|
||||
ignored.
|
||||
|
||||
- Fixed an issue that was causing `BotStoppedSpeakingFrame` to be generated too
|
||||
late. This could then cause issues unblocking `STTMuteFilter` later than
|
||||
desired.
|
||||
@@ -283,7 +289,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
- Added `enable_recording` and `geo` parameters to `DailyRoomProperties`.
|
||||
|
||||
- Added `RecordingsBucketConfig` to `DailyRoomProperties` to upload recordings to a custom AWS bucket.
|
||||
- Added `RecordingsBucketConfig` to `DailyRoomProperties` to upload recordings
|
||||
to a custom AWS bucket.
|
||||
|
||||
### Changed
|
||||
|
||||
|
||||
@@ -497,7 +497,7 @@ class UserAggregatorBuffer(LLMResponseAggregator):
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
self._transcription = ""
|
||||
|
||||
async def _push_aggregation(self):
|
||||
async def push_aggregation(self):
|
||||
if self._aggregation:
|
||||
self._transcription = self._aggregation
|
||||
self._aggregation = ""
|
||||
|
||||
@@ -4,9 +4,13 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from typing import List, Optional, Type
|
||||
import asyncio
|
||||
from abc import abstractmethod
|
||||
from typing import List
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
@@ -15,6 +19,7 @@ from pipecat.frames.frames import (
|
||||
LLMMessagesFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolsFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
@@ -28,121 +33,81 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class LLMResponseAggregator(FrameProcessor):
|
||||
class BaseLLMResponseAggregator(FrameProcessor):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def messages(self) -> List[dict]:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def role(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_messages(self, messages):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_messages(self, messages):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_tools(self, tools):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def push_aggregation(self):
|
||||
pass
|
||||
|
||||
|
||||
class LLMResponseAggregator(BaseLLMResponseAggregator):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
messages: List[dict],
|
||||
role: str,
|
||||
start_frame,
|
||||
end_frame,
|
||||
accumulator_frame: Type[TextFrame],
|
||||
interim_accumulator_frame: Optional[Type[TextFrame]] = None,
|
||||
handle_interruptions: bool = False,
|
||||
expect_stripped_words: bool = True, # if True, need to add spaces between words
|
||||
role: str = "user",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._messages = messages
|
||||
self._role = role
|
||||
self._start_frame = start_frame
|
||||
self._end_frame = end_frame
|
||||
self._accumulator_frame = accumulator_frame
|
||||
self._interim_accumulator_frame = interim_accumulator_frame
|
||||
self._handle_interruptions = handle_interruptions
|
||||
self._expect_stripped_words = expect_stripped_words
|
||||
|
||||
# Reset our accumulator state.
|
||||
self._reset()
|
||||
self._aggregation = ""
|
||||
|
||||
self.reset()
|
||||
|
||||
@property
|
||||
def messages(self):
|
||||
def messages(self) -> List[dict]:
|
||||
return self._messages
|
||||
|
||||
@property
|
||||
def role(self):
|
||||
def role(self) -> str:
|
||||
return self._role
|
||||
|
||||
#
|
||||
# Frame processor
|
||||
#
|
||||
def add_messages(self, messages):
|
||||
self._messages.extend(messages)
|
||||
|
||||
# Use cases implemented:
|
||||
#
|
||||
# S: Start, E: End, T: Transcription, I: Interim, X: Text
|
||||
#
|
||||
# S E -> None
|
||||
# S T E -> X
|
||||
# S I T E -> X
|
||||
# S I E T -> X
|
||||
# S I E I T -> X
|
||||
# S E T -> X
|
||||
# S E I T -> X
|
||||
#
|
||||
# The following case would not be supported:
|
||||
#
|
||||
# S I E T1 I T2 -> X
|
||||
#
|
||||
# and T2 would be dropped.
|
||||
def set_messages(self, messages):
|
||||
self.reset()
|
||||
self._messages.clear()
|
||||
self._messages.extend(messages)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
def set_tools(self, tools):
|
||||
pass
|
||||
|
||||
send_aggregation = False
|
||||
def reset(self):
|
||||
self._aggregation = ""
|
||||
|
||||
if isinstance(frame, self._start_frame):
|
||||
self._aggregation = ""
|
||||
self._aggregating = True
|
||||
self._seen_start_frame = True
|
||||
self._seen_end_frame = False
|
||||
self._seen_interim_results = False
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, self._end_frame):
|
||||
self._seen_end_frame = True
|
||||
self._seen_start_frame = False
|
||||
|
||||
# We might have received the end frame but we might still be
|
||||
# aggregating (i.e. we have seen interim results but not the final
|
||||
# text).
|
||||
self._aggregating = self._seen_interim_results or len(self._aggregation) == 0
|
||||
|
||||
# Send the aggregation if we are not aggregating anymore (i.e. no
|
||||
# more interim results received).
|
||||
send_aggregation = not self._aggregating
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, self._accumulator_frame):
|
||||
if self._aggregating:
|
||||
if self._expect_stripped_words:
|
||||
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
|
||||
else:
|
||||
self._aggregation += frame.text
|
||||
# We have recevied a complete sentence, so if we have seen the
|
||||
# end frame and we were still aggregating, it means we should
|
||||
# send the aggregation.
|
||||
send_aggregation = self._seen_end_frame
|
||||
|
||||
# We just got our final result, so let's reset interim results.
|
||||
self._seen_interim_results = False
|
||||
elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame):
|
||||
self._seen_interim_results = True
|
||||
elif self._handle_interruptions and isinstance(frame, StartInterruptionFrame):
|
||||
await self._push_aggregation()
|
||||
# Reset anyways
|
||||
self._reset()
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMMessagesAppendFrame):
|
||||
self._add_messages(frame.messages)
|
||||
elif isinstance(frame, LLMMessagesUpdateFrame):
|
||||
self._set_messages(frame.messages)
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
self._set_tools(frame.tools)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if send_aggregation:
|
||||
await self._push_aggregation()
|
||||
|
||||
async def _push_aggregation(self):
|
||||
async def push_aggregation(self):
|
||||
if len(self._aggregation) > 0:
|
||||
self._messages.append({"role": self._role, "content": self._aggregation})
|
||||
|
||||
@@ -153,109 +118,22 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
frame = LLMMessagesFrame(self._messages)
|
||||
await self.push_frame(frame)
|
||||
|
||||
# TODO-CB: Types
|
||||
def _add_messages(self, messages):
|
||||
self._messages.extend(messages)
|
||||
|
||||
def _set_messages(self, messages):
|
||||
self._reset()
|
||||
self._messages.clear()
|
||||
self._messages.extend(messages)
|
||||
|
||||
def _set_tools(self, tools):
|
||||
# noop in the base class
|
||||
pass
|
||||
|
||||
def _reset(self):
|
||||
self._aggregation = ""
|
||||
self._aggregating = False
|
||||
self._seen_start_frame = False
|
||||
self._seen_end_frame = False
|
||||
self._seen_interim_results = False
|
||||
|
||||
|
||||
class LLMAssistantResponseAggregator(LLMResponseAggregator):
|
||||
def __init__(self, messages: List[dict] = []):
|
||||
super().__init__(
|
||||
messages=messages,
|
||||
role="assistant",
|
||||
start_frame=LLMFullResponseStartFrame,
|
||||
end_frame=LLMFullResponseEndFrame,
|
||||
accumulator_frame=TextFrame,
|
||||
handle_interruptions=True,
|
||||
)
|
||||
|
||||
|
||||
class LLMUserResponseAggregator(LLMResponseAggregator):
|
||||
def __init__(self, messages: List[dict] = []):
|
||||
super().__init__(
|
||||
messages=messages,
|
||||
role="user",
|
||||
start_frame=UserStartedSpeakingFrame,
|
||||
end_frame=UserStoppedSpeakingFrame,
|
||||
accumulator_frame=TranscriptionFrame,
|
||||
interim_accumulator_frame=InterimTranscriptionFrame,
|
||||
)
|
||||
|
||||
|
||||
class LLMFullResponseAggregator(FrameProcessor):
|
||||
"""This class aggregates Text frames until it receives a
|
||||
LLMFullResponseEndFrame, then emits the concatenated text as
|
||||
a single text frame.
|
||||
|
||||
given the following frames:
|
||||
|
||||
TextFrame("Hello,")
|
||||
TextFrame(" world.")
|
||||
TextFrame(" I am")
|
||||
TextFrame(" an LLM.")
|
||||
LLMFullResponseEndFrame()]
|
||||
|
||||
this processor will yield nothing for the first 4 frames, then
|
||||
|
||||
TextFrame("Hello, world. I am an LLM.")
|
||||
LLMFullResponseEndFrame()
|
||||
|
||||
when passed the last frame.
|
||||
|
||||
>>> async def print_frames(aggregator, frame):
|
||||
... async for frame in aggregator.process_frame(frame):
|
||||
... if isinstance(frame, TextFrame):
|
||||
... print(frame.text)
|
||||
... else:
|
||||
... print(frame.__class__.__name__)
|
||||
|
||||
>>> aggregator = LLMFullResponseAggregator()
|
||||
>>> asyncio.run(print_frames(aggregator, TextFrame("Hello,")))
|
||||
>>> asyncio.run(print_frames(aggregator, TextFrame(" world.")))
|
||||
>>> asyncio.run(print_frames(aggregator, TextFrame(" I am")))
|
||||
>>> asyncio.run(print_frames(aggregator, TextFrame(" an LLM.")))
|
||||
>>> asyncio.run(print_frames(aggregator, LLMFullResponseEndFrame()))
|
||||
Hello, world. I am an LLM.
|
||||
LLMFullResponseEndFrame
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._aggregation = ""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TextFrame):
|
||||
self._aggregation += frame.text
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
await self.push_frame(TextFrame(self._aggregation))
|
||||
await self.push_frame(frame)
|
||||
self._aggregation = ""
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class LLMContextAggregator(LLMResponseAggregator):
|
||||
def __init__(self, *, context: OpenAILLMContext, **kwargs):
|
||||
class LLMContextResponseAggregator(BaseLLMResponseAggregator):
|
||||
def __init__(self, *, context: OpenAILLMContext, role: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._context = context
|
||||
self._role = role
|
||||
|
||||
self._aggregation = ""
|
||||
|
||||
@property
|
||||
def messages(self) -> List[dict]:
|
||||
return self._context.get_messages()
|
||||
|
||||
@property
|
||||
def role(self) -> str:
|
||||
return self._role
|
||||
|
||||
@property
|
||||
def context(self):
|
||||
@@ -268,19 +146,18 @@ class LLMContextAggregator(LLMResponseAggregator):
|
||||
frame = self.get_context_frame()
|
||||
await self.push_frame(frame)
|
||||
|
||||
# TODO-CB: Types
|
||||
def _add_messages(self, messages):
|
||||
def add_messages(self, messages):
|
||||
self._context.add_messages(messages)
|
||||
|
||||
def _set_messages(self, messages):
|
||||
def set_messages(self, messages):
|
||||
self._context.set_messages(messages)
|
||||
|
||||
def _set_tools(self, tools: List):
|
||||
def set_tools(self, tools: List):
|
||||
self._context.set_tools(tools)
|
||||
|
||||
async def _push_aggregation(self):
|
||||
async def push_aggregation(self):
|
||||
if len(self._aggregation) > 0:
|
||||
self._context.add_message({"role": self._role, "content": self._aggregation})
|
||||
self._context.add_message({"role": self.role, "content": self._aggregation})
|
||||
|
||||
# Reset the aggregation. Reset it before pushing it down, otherwise
|
||||
# if the tasks gets cancelled we won't be able to clear things up.
|
||||
@@ -290,31 +167,171 @@ class LLMContextAggregator(LLMResponseAggregator):
|
||||
await self.push_frame(frame)
|
||||
|
||||
# Reset our accumulator state.
|
||||
self._reset()
|
||||
self.reset()
|
||||
|
||||
|
||||
class LLMAssistantContextAggregator(LLMContextAggregator):
|
||||
def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True):
|
||||
super().__init__(
|
||||
messages=[],
|
||||
context=context,
|
||||
role="assistant",
|
||||
start_frame=LLMFullResponseStartFrame,
|
||||
end_frame=LLMFullResponseEndFrame,
|
||||
accumulator_frame=TextFrame,
|
||||
handle_interruptions=True,
|
||||
expect_stripped_words=expect_stripped_words,
|
||||
)
|
||||
class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
def __init__(self, context: OpenAILLMContext, aggregation_timeout: float = 1.0, **kwargs):
|
||||
super().__init__(context=context, role="user", **kwargs)
|
||||
self._aggregation_timeout = aggregation_timeout
|
||||
|
||||
self._seen_interim_results = False
|
||||
self._user_speaking = False
|
||||
|
||||
self._aggregation_event = asyncio.Event()
|
||||
self._aggregation_task = None
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
super().reset()
|
||||
self._seen_interim_results = False
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartFrame):
|
||||
await self._start(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, EndFrame):
|
||||
await self._stop(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, CancelFrame):
|
||||
await self._cancel(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self._handle_user_started_speaking(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
await self._handle_user_stopped_speaking(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, TranscriptionFrame):
|
||||
await self._handle_transcription(frame)
|
||||
elif isinstance(frame, InterimTranscriptionFrame):
|
||||
await self._handle_interim_transcription(frame)
|
||||
elif isinstance(frame, LLMMessagesAppendFrame):
|
||||
self.add_messages(frame.messages)
|
||||
elif isinstance(frame, LLMMessagesUpdateFrame):
|
||||
self.set_messages(frame.messages)
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
self.set_tools(frame.tools)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _start(self, frame: StartFrame):
|
||||
self._aggregation_task = self.create_task(self._aggregation_task_handler())
|
||||
|
||||
async def _stop(self, frame: EndFrame):
|
||||
if self._aggregation_task:
|
||||
await self.cancel_task(self._aggregation_task)
|
||||
self._aggregation_task = None
|
||||
|
||||
async def _cancel(self, frame: CancelFrame):
|
||||
if self._aggregation_task:
|
||||
await self.cancel_task(self._aggregation_task)
|
||||
self._aggregation_task = None
|
||||
|
||||
async def _handle_user_started_speaking(self, _: UserStartedSpeakingFrame):
|
||||
self._user_speaking = True
|
||||
|
||||
async def _handle_user_stopped_speaking(self, _: UserStoppedSpeakingFrame):
|
||||
self._user_speaking = False
|
||||
if not self._seen_interim_results:
|
||||
await self.push_aggregation()
|
||||
|
||||
async def _handle_transcription(self, frame: TranscriptionFrame):
|
||||
self._aggregation += frame.text
|
||||
# We just got our final result, so let's reset interim results.
|
||||
self._seen_interim_results = False
|
||||
# Wakeup our task.
|
||||
self._aggregation_event.set()
|
||||
|
||||
async def _handle_interim_transcription(self, _: InterimTranscriptionFrame):
|
||||
self._seen_interim_results = True
|
||||
|
||||
async def _aggregation_task_handler(self):
|
||||
while True:
|
||||
await self._aggregation_event.wait()
|
||||
await asyncio.sleep(self._aggregation_timeout)
|
||||
if not self._user_speaking:
|
||||
await self.push_aggregation()
|
||||
self._aggregation_event.clear()
|
||||
|
||||
|
||||
class LLMUserContextAggregator(LLMContextAggregator):
|
||||
def __init__(self, context: OpenAILLMContext):
|
||||
super().__init__(
|
||||
messages=[],
|
||||
context=context,
|
||||
role="user",
|
||||
start_frame=UserStartedSpeakingFrame,
|
||||
end_frame=UserStoppedSpeakingFrame,
|
||||
accumulator_frame=TranscriptionFrame,
|
||||
interim_accumulator_frame=InterimTranscriptionFrame,
|
||||
)
|
||||
class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True, **kwargs):
|
||||
super().__init__(context=context, role="assistant", **kwargs)
|
||||
self._expect_stripped_words = expect_stripped_words
|
||||
|
||||
self.reset()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
await self.push_aggregation()
|
||||
# Reset anyways
|
||||
self.reset()
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMFullResponseStartFrame):
|
||||
await self._handle_llm_start(frame)
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
await self._handle_llm_end(frame)
|
||||
elif isinstance(frame, TextFrame):
|
||||
await self._handle_text(frame)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _handle_llm_start(self, _: LLMFullResponseStartFrame):
|
||||
self._started = True
|
||||
|
||||
async def _handle_llm_end(self, _: LLMFullResponseEndFrame):
|
||||
self._started = False
|
||||
await self.push_aggregation()
|
||||
|
||||
async def _handle_text(self, frame: TextFrame):
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
if self._expect_stripped_words:
|
||||
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
|
||||
else:
|
||||
self._aggregation += frame.text
|
||||
|
||||
|
||||
class LLMUserResponseAggregator(LLMUserContextAggregator):
|
||||
def __init__(self, messages: List[dict] = [], **kwargs):
|
||||
super().__init__(context=OpenAILLMContext(messages), **kwargs)
|
||||
|
||||
async def push_aggregation(self):
|
||||
if len(self._aggregation) > 0:
|
||||
self._context.add_message({"role": self.role, "content": self._aggregation})
|
||||
|
||||
# Reset the aggregation. Reset it before pushing it down, otherwise
|
||||
# if the tasks gets cancelled we won't be able to clear things up.
|
||||
self._aggregation = ""
|
||||
|
||||
frame = LLMMessagesFrame(self._context.messages)
|
||||
await self.push_frame(frame)
|
||||
|
||||
# Reset our accumulator state.
|
||||
self.reset()
|
||||
|
||||
|
||||
class LLMAssistantResponseAggregator(LLMAssistantContextAggregator):
|
||||
def __init__(self, messages: List[dict], **kwargs):
|
||||
super().__init__(context=OpenAILLMContext(messages), **kwargs)
|
||||
|
||||
async def push_aggregation(self):
|
||||
if len(self._aggregation) > 0:
|
||||
self._context.add_message({"role": self.role, "content": self._aggregation})
|
||||
|
||||
# Reset the aggregation. Reset it before pushing it down, otherwise
|
||||
# if the tasks gets cancelled we won't be able to clear things up.
|
||||
self._aggregation = ""
|
||||
|
||||
frame = LLMMessagesFrame(self._context.messages)
|
||||
await self.push_frame(frame)
|
||||
|
||||
# Reset our accumulator state.
|
||||
self.reset()
|
||||
|
||||
@@ -725,7 +725,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
):
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_result = frame
|
||||
await self._push_aggregation()
|
||||
await self.push_aggregation()
|
||||
else:
|
||||
logger.warning(
|
||||
"FunctionCallResultFrame tool_call_id != InProgressFrame tool_call_id"
|
||||
@@ -734,9 +734,9 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
self._function_call_result = None
|
||||
elif isinstance(frame, AnthropicImageMessageFrame):
|
||||
self._pending_image_frame_message = frame
|
||||
await self._push_aggregation()
|
||||
await self.push_aggregation()
|
||||
|
||||
async def _push_aggregation(self):
|
||||
async def push_aggregation(self):
|
||||
if not (
|
||||
self._aggregation or self._function_call_result or self._pending_image_frame_message
|
||||
):
|
||||
@@ -746,7 +746,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
properties: Optional[FunctionCallResultProperties] = None
|
||||
|
||||
aggregation = self._aggregation
|
||||
self._reset()
|
||||
self.reset()
|
||||
|
||||
try:
|
||||
if self._function_call_result:
|
||||
|
||||
@@ -115,10 +115,10 @@ class GeminiMultimodalLiveUserContextAggregator(OpenAIUserContextAggregator):
|
||||
|
||||
|
||||
class GeminiMultimodalLiveAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
async def _push_aggregation(self):
|
||||
async def push_aggregation(self):
|
||||
# We don't want to store any images in the context. Revisit this later when the API evolves.
|
||||
self._pending_image_frame_message = None
|
||||
await super()._push_aggregation()
|
||||
await super().push_aggregation()
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -537,7 +537,7 @@ def language_to_google_stt_language(language: Language) -> Optional[str]:
|
||||
|
||||
|
||||
class GoogleUserContextAggregator(OpenAIUserContextAggregator):
|
||||
async def _push_aggregation(self):
|
||||
async def push_aggregation(self):
|
||||
if len(self._aggregation) > 0:
|
||||
self._context.add_message(
|
||||
glm.Content(role="user", parts=[glm.Part(text=self._aggregation)])
|
||||
@@ -552,11 +552,11 @@ class GoogleUserContextAggregator(OpenAIUserContextAggregator):
|
||||
await self.push_frame(frame)
|
||||
|
||||
# Reset our accumulator state.
|
||||
self._reset()
|
||||
self.reset()
|
||||
|
||||
|
||||
class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
async def _push_aggregation(self):
|
||||
async def push_aggregation(self):
|
||||
if not (
|
||||
self._aggregation or self._function_call_result or self._pending_image_frame_message
|
||||
):
|
||||
@@ -566,7 +566,7 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
properties: Optional[FunctionCallResultProperties] = None
|
||||
|
||||
aggregation = self._aggregation
|
||||
self._reset()
|
||||
self.reset()
|
||||
|
||||
try:
|
||||
if self._function_call_result:
|
||||
|
||||
@@ -27,7 +27,7 @@ from pipecat.services.openai import (
|
||||
class GrokAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
"""Custom assistant context aggregator for Grok that handles empty content requirement."""
|
||||
|
||||
async def _push_aggregation(self):
|
||||
async def push_aggregation(self):
|
||||
if not (
|
||||
self._aggregation or self._function_call_result or self._pending_image_frame_message
|
||||
):
|
||||
@@ -37,7 +37,7 @@ class GrokAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
properties: Optional[FunctionCallResultProperties] = None
|
||||
|
||||
aggregation = self._aggregation
|
||||
self._reset()
|
||||
self.reset()
|
||||
|
||||
try:
|
||||
if self._function_call_result:
|
||||
|
||||
@@ -614,7 +614,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
del self._function_calls_in_progress[frame.tool_call_id]
|
||||
self._function_call_result = frame
|
||||
# TODO-CB: Kwin wants us to refactor this out of here but I REFUSE
|
||||
await self._push_aggregation()
|
||||
await self.push_aggregation()
|
||||
else:
|
||||
logger.warning(
|
||||
"FunctionCallResultFrame tool_call_id does not match any function call in progress"
|
||||
@@ -622,9 +622,9 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
self._function_call_result = None
|
||||
elif isinstance(frame, OpenAIImageMessageFrame):
|
||||
self._pending_image_frame_message = frame
|
||||
await self._push_aggregation()
|
||||
await self.push_aggregation()
|
||||
|
||||
async def _push_aggregation(self):
|
||||
async def push_aggregation(self):
|
||||
if not (
|
||||
self._aggregation or self._function_call_result or self._pending_image_frame_message
|
||||
):
|
||||
@@ -634,7 +634,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
properties: Optional[FunctionCallResultProperties] = None
|
||||
|
||||
aggregation = self._aggregation
|
||||
self._reset()
|
||||
self.reset()
|
||||
|
||||
try:
|
||||
if self._function_call_result:
|
||||
|
||||
@@ -166,7 +166,7 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
|
||||
if isinstance(frame, LLMSetToolsFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _push_aggregation(self):
|
||||
async def push_aggregation(self):
|
||||
# for the moment, ignore all user input coming into the pipeline.
|
||||
# todo: think about whether/how to fix this to allow for text input from
|
||||
# upstream (transport/transcription, or other sources)
|
||||
@@ -174,7 +174,7 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
|
||||
|
||||
|
||||
class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
async def _push_aggregation(self):
|
||||
async def push_aggregation(self):
|
||||
# the only thing we implement here is function calling. in all other cases, messages
|
||||
# are added to the context when we receive openai realtime api events
|
||||
if not self._function_call_result:
|
||||
@@ -182,7 +182,7 @@ class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator)
|
||||
|
||||
properties: Optional[FunctionCallResultProperties] = None
|
||||
|
||||
self._reset()
|
||||
self.reset()
|
||||
try:
|
||||
run_llm = True
|
||||
frame = self._function_call_result
|
||||
|
||||
Reference in New Issue
Block a user