Compare commits
5 Commits
hush/usage
...
aleix/edga
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
eef3f320b1 | ||
|
|
dfdd536b20 | ||
|
|
ba6e9ed9ad | ||
|
|
2dd56ba992 | ||
|
|
c989c9c16d |
2
.github/workflows/tests.yaml
vendored
2
.github/workflows/tests.yaml
vendored
@@ -49,4 +49,4 @@ jobs:
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
pytest --ignore-glob="*to_be_updated*" --ignore-glob=*pipeline_source* src tests
|
||||
pytest
|
||||
|
||||
@@ -75,6 +75,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed LLM response aggregators to support more use cases such as delayed
|
||||
transcriptions.
|
||||
|
||||
- Fixed an issue that could cause the bot to stop talking if there was a user
|
||||
interruption before getting any audio from the TTS service.
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ Available options include:
|
||||
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), WebSocket, Local | `pip install "pipecat-ai[daily]"` |
|
||||
| Video | [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) | `pip install "pipecat-ai[tavus,simli]"` |
|
||||
| Vision & Image | [Moondream](https://docs.pipecat.ai/server/services/vision/moondream), [fal](https://docs.pipecat.ai/server/services/image-generation/fal) | `pip install "pipecat-ai[moondream]"` |
|
||||
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Noisereduce](https://docs.pipecat.ai/server/utilities/audio/noisereduce-filter) | `pip install "pipecat-ai[silero]"` |
|
||||
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter) | `pip install "pipecat-ai[silero]"` |
|
||||
| Analytics & Metrics | [Canonical AI](https://docs.pipecat.ai/server/services/analytics/canonical), [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) | `pip install "pipecat-ai[canonical]"` |
|
||||
|
||||
📚 [View full services documentation →](https://docs.pipecat.ai/server/services/supported-services)
|
||||
|
||||
@@ -82,7 +82,10 @@ whisper = [ "faster-whisper~=1.1.0" ]
|
||||
where = ["src"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--verbose --disable-warnings"
|
||||
testpaths = ["tests"]
|
||||
pythonpath = ["src"]
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
|
||||
[tool.setuptools_scm]
|
||||
local_scheme = "no-local-version"
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
from typing import List, Type
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
@@ -40,6 +41,7 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
interim_accumulator_frame: Type[TextFrame] | None = None,
|
||||
handle_interruptions: bool = False,
|
||||
expect_stripped_words: bool = True, # if True, need to add spaces between words
|
||||
interrupt_double_accumulator: bool = True, # if True, interrupt if two or more accumulators are received
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -51,8 +53,8 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
self._interim_accumulator_frame = interim_accumulator_frame
|
||||
self._handle_interruptions = handle_interruptions
|
||||
self._expect_stripped_words = expect_stripped_words
|
||||
self._interrupt_double_accumulator = interrupt_double_accumulator
|
||||
|
||||
# Reset our accumulator state.
|
||||
self._reset()
|
||||
|
||||
@property
|
||||
@@ -69,21 +71,20 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
|
||||
# Use cases implemented:
|
||||
#
|
||||
# S: Start, E: End, T: Transcription, I: Interim, X: Text
|
||||
# S: Start, E: End, T: Transcription, I: Interim
|
||||
#
|
||||
# S E -> None
|
||||
# S T E -> X
|
||||
# S I T E -> X
|
||||
# S I E T -> X
|
||||
# S I E I T -> X
|
||||
# S E T -> X
|
||||
# S E I T -> X
|
||||
#
|
||||
# The following case would not be supported:
|
||||
#
|
||||
# S I E T1 I T2 -> X
|
||||
#
|
||||
# and T2 would be dropped.
|
||||
# S E -> None -> User started speaking but no transcription.
|
||||
# S T E -> T -> Transcription between user started and stopped speaking.
|
||||
# S E T -> T -> Transcription after user stopped speaking.
|
||||
# S I T E -> T -> Transcription between user started and stopped speaking (with interims).
|
||||
# S I E T -> T -> Transcription after user stopped speaking (with interims).
|
||||
# S I E I T -> T -> Transcription after user stopped speaking (with interims).
|
||||
# S E I T -> T -> Transcription after user stopped speaking (with interims).
|
||||
# S T1 I E S T2 E -> "T1 T2" -> Merge two transcriptions if we got a first interim.
|
||||
# S I E T1 I T2 -> T1 [Interruption] T2 -> Single user started/stopped, double transcription.
|
||||
# S T1 E T2 -> T1 [Interruption] T2 -> Single user started/stopped, double transcription.
|
||||
# S E T1 B T2 -> T1 [Interruption] T2 -> Single user started/stopped, double transcription.
|
||||
# S E T1 T2 -> T1 [Interruption] T2 -> Single user started/stopped, double transcription.
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
@@ -91,11 +92,9 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
send_aggregation = False
|
||||
|
||||
if isinstance(frame, self._start_frame):
|
||||
self._aggregation = ""
|
||||
self._aggregating = True
|
||||
self._seen_start_frame = True
|
||||
self._seen_end_frame = False
|
||||
self._seen_interim_results = False
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, self._end_frame):
|
||||
self._seen_end_frame = True
|
||||
@@ -109,23 +108,36 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
# Send the aggregation if we are not aggregating anymore (i.e. no
|
||||
# more interim results received).
|
||||
send_aggregation = not self._aggregating
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, self._accumulator_frame):
|
||||
if self._aggregating:
|
||||
if self._expect_stripped_words:
|
||||
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
|
||||
else:
|
||||
self._aggregation += frame.text
|
||||
# We have received a complete sentence, so if we have seen the
|
||||
# end frame and we were still aggregating, it means we should
|
||||
# send the aggregation.
|
||||
send_aggregation = self._seen_end_frame
|
||||
if (
|
||||
self._interrupt_double_accumulator
|
||||
and self._sent_aggregation_after_last_interruption
|
||||
):
|
||||
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
|
||||
self._sent_aggregation_after_last_interruption = False
|
||||
|
||||
if self._expect_stripped_words:
|
||||
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
|
||||
else:
|
||||
self._aggregation += frame.text
|
||||
|
||||
# If we haven't seen the start frame but we got an accumulator frame
|
||||
# it means one of two things: it was delivered before the end frame or it
|
||||
# was delivered late. In both cases we want to send the
|
||||
# aggregation.
|
||||
send_aggregation = not self._seen_start_frame
|
||||
|
||||
# We just got our final result, so let's reset interim results.
|
||||
self._seen_interim_results = False
|
||||
elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame):
|
||||
if (
|
||||
self._interrupt_double_accumulator
|
||||
and self._sent_aggregation_after_last_interruption
|
||||
):
|
||||
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
|
||||
self._sent_aggregation_after_last_interruption = False
|
||||
self._seen_interim_results = True
|
||||
elif self._handle_interruptions and isinstance(frame, StartInterruptionFrame):
|
||||
elif isinstance(frame, StartInterruptionFrame) and self._handle_interruptions:
|
||||
await self._push_aggregation()
|
||||
# Reset anyways
|
||||
self._reset()
|
||||
@@ -142,6 +154,9 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
if send_aggregation:
|
||||
await self._push_aggregation()
|
||||
|
||||
if isinstance(frame, self._end_frame):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _push_aggregation(self):
|
||||
if len(self._aggregation) > 0:
|
||||
self._messages.append({"role": self._role, "content": self._aggregation})
|
||||
@@ -150,6 +165,8 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
# if the tasks gets cancelled we won't be able to clear things up.
|
||||
self._aggregation = ""
|
||||
|
||||
self._sent_aggregation_after_last_interruption = True
|
||||
|
||||
frame = LLMMessagesFrame(self._messages)
|
||||
await self.push_frame(frame)
|
||||
|
||||
@@ -172,22 +189,11 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
self._seen_start_frame = False
|
||||
self._seen_end_frame = False
|
||||
self._seen_interim_results = False
|
||||
|
||||
|
||||
class LLMAssistantResponseAggregator(LLMResponseAggregator):
|
||||
def __init__(self, messages: List[dict] = []):
|
||||
super().__init__(
|
||||
messages=messages,
|
||||
role="assistant",
|
||||
start_frame=LLMFullResponseStartFrame,
|
||||
end_frame=LLMFullResponseEndFrame,
|
||||
accumulator_frame=TextFrame,
|
||||
handle_interruptions=True,
|
||||
)
|
||||
self._sent_aggregation_after_last_interruption = False
|
||||
|
||||
|
||||
class LLMUserResponseAggregator(LLMResponseAggregator):
|
||||
def __init__(self, messages: List[dict] = []):
|
||||
def __init__(self, messages: List[dict] = [], **kwargs):
|
||||
super().__init__(
|
||||
messages=messages,
|
||||
role="user",
|
||||
@@ -195,61 +201,21 @@ class LLMUserResponseAggregator(LLMResponseAggregator):
|
||||
end_frame=UserStoppedSpeakingFrame,
|
||||
accumulator_frame=TranscriptionFrame,
|
||||
interim_accumulator_frame=InterimTranscriptionFrame,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class LLMFullResponseAggregator(FrameProcessor):
|
||||
"""This class aggregates Text frames until it receives a
|
||||
LLMFullResponseEndFrame, then emits the concatenated text as
|
||||
a single text frame.
|
||||
|
||||
given the following frames:
|
||||
|
||||
TextFrame("Hello,")
|
||||
TextFrame(" world.")
|
||||
TextFrame(" I am")
|
||||
TextFrame(" an LLM.")
|
||||
LLMFullResponseEndFrame()]
|
||||
|
||||
this processor will yield nothing for the first 4 frames, then
|
||||
|
||||
TextFrame("Hello, world. I am an LLM.")
|
||||
LLMFullResponseEndFrame()
|
||||
|
||||
when passed the last frame.
|
||||
|
||||
>>> async def print_frames(aggregator, frame):
|
||||
... async for frame in aggregator.process_frame(frame):
|
||||
... if isinstance(frame, TextFrame):
|
||||
... print(frame.text)
|
||||
... else:
|
||||
... print(frame.__class__.__name__)
|
||||
|
||||
>>> aggregator = LLMFullResponseAggregator()
|
||||
>>> asyncio.run(print_frames(aggregator, TextFrame("Hello,")))
|
||||
>>> asyncio.run(print_frames(aggregator, TextFrame(" world.")))
|
||||
>>> asyncio.run(print_frames(aggregator, TextFrame(" I am")))
|
||||
>>> asyncio.run(print_frames(aggregator, TextFrame(" an LLM.")))
|
||||
>>> asyncio.run(print_frames(aggregator, LLMFullResponseEndFrame()))
|
||||
Hello, world. I am an LLM.
|
||||
LLMFullResponseEndFrame
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._aggregation = ""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TextFrame):
|
||||
self._aggregation += frame.text
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
await self.push_frame(TextFrame(self._aggregation))
|
||||
await self.push_frame(frame)
|
||||
self._aggregation = ""
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
class LLMAssistantResponseAggregator(LLMResponseAggregator):
|
||||
def __init__(self, messages: List[dict] = [], **kwargs):
|
||||
super().__init__(
|
||||
messages=messages,
|
||||
role="assistant",
|
||||
start_frame=LLMFullResponseStartFrame,
|
||||
end_frame=LLMFullResponseEndFrame,
|
||||
accumulator_frame=TextFrame,
|
||||
handle_interruptions=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class LLMContextAggregator(LLMResponseAggregator):
|
||||
@@ -286,15 +252,14 @@ class LLMContextAggregator(LLMResponseAggregator):
|
||||
# if the tasks gets cancelled we won't be able to clear things up.
|
||||
self._aggregation = ""
|
||||
|
||||
self._sent_aggregation_after_last_interruption = True
|
||||
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
|
||||
# Reset our accumulator state.
|
||||
self._reset()
|
||||
|
||||
|
||||
class LLMAssistantContextAggregator(LLMContextAggregator):
|
||||
def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True):
|
||||
def __init__(self, context: OpenAILLMContext, **kwargs):
|
||||
super().__init__(
|
||||
messages=[],
|
||||
context=context,
|
||||
@@ -303,12 +268,12 @@ class LLMAssistantContextAggregator(LLMContextAggregator):
|
||||
end_frame=LLMFullResponseEndFrame,
|
||||
accumulator_frame=TextFrame,
|
||||
handle_interruptions=True,
|
||||
expect_stripped_words=expect_stripped_words,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class LLMUserContextAggregator(LLMContextAggregator):
|
||||
def __init__(self, context: OpenAILLMContext):
|
||||
def __init__(self, context: OpenAILLMContext, **kwargs):
|
||||
super().__init__(
|
||||
messages=[],
|
||||
context=context,
|
||||
@@ -317,4 +282,69 @@ class LLMUserContextAggregator(LLMContextAggregator):
|
||||
end_frame=UserStoppedSpeakingFrame,
|
||||
accumulator_frame=TranscriptionFrame,
|
||||
interim_accumulator_frame=InterimTranscriptionFrame,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class LLMFullResponseAggregator(FrameProcessor):
|
||||
"""This class aggregates Text frames between LLMFullResponseStartFrame and
|
||||
LLMFullResponseEndFrame, then emits the concatenated text as a single text
|
||||
frame.
|
||||
|
||||
given the following frames:
|
||||
|
||||
LLMFullResponseStartFrame()
|
||||
TextFrame("Hello,")
|
||||
TextFrame(" world.")
|
||||
TextFrame(" I am")
|
||||
TextFrame(" an LLM.")
|
||||
LLMFullResponseEndFrame()
|
||||
|
||||
this processor will push,
|
||||
|
||||
LLMFullResponseStartFrame()
|
||||
TextFrame("Hello, world. I am an LLM.")
|
||||
LLMFullResponseEndFrame()
|
||||
|
||||
when passed the last frame.
|
||||
|
||||
>>> async def print_frames(aggregator, frame):
|
||||
... async for frame in aggregator.process_frame(frame):
|
||||
... if isinstance(frame, TextFrame):
|
||||
... print(frame.text)
|
||||
... else:
|
||||
... print(frame.__class__.__name__)
|
||||
|
||||
>>> aggregator = LLMFullResponseAggregator()
|
||||
>>> asyncio.run(print_frames(aggregator, LLMFullResponseStartFrame()))
|
||||
>>> asyncio.run(print_frames(aggregator, TextFrame("Hello,")))
|
||||
>>> asyncio.run(print_frames(aggregator, TextFrame(" world.")))
|
||||
>>> asyncio.run(print_frames(aggregator, TextFrame(" I am")))
|
||||
>>> asyncio.run(print_frames(aggregator, TextFrame(" an LLM.")))
|
||||
>>> asyncio.run(print_frames(aggregator, LLMFullResponseEndFrame()))
|
||||
LLMFullResponseStartFrame
|
||||
Hello, world. I am an LLM.
|
||||
LLMFullResponseEndFrame
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._aggregation = ""
|
||||
self._seen_start_frame = False
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, LLMFullResponseStartFrame):
|
||||
self._seen_start_frame = True
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
self._seen_start_frame = False
|
||||
await self.push_frame(TextFrame(self._aggregation))
|
||||
await self.push_frame(frame)
|
||||
self._aggregation = ""
|
||||
elif isinstance(frame, TextFrame) and self._seen_start_frame:
|
||||
self._aggregation += frame.text
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -311,8 +311,15 @@ class FrameProcessor:
|
||||
self.__push_frame_task = self.get_event_loop().create_task(self.__push_frame_task_handler())
|
||||
|
||||
async def __cancel_push_task(self):
|
||||
self.__push_frame_task.cancel()
|
||||
await self.__push_frame_task
|
||||
try:
|
||||
self.__push_frame_task.cancel()
|
||||
await self.__push_frame_task
|
||||
except asyncio.CancelledError:
|
||||
# TODO(aleix): Investigate why this is really needed. So far, this is
|
||||
# necessary because of how pytest works. If a task is cancelled,
|
||||
# pytest will know the task has been cancelled even if
|
||||
# `asyncio.CancelledError` is handled internally in the task.
|
||||
pass
|
||||
|
||||
async def __push_frame_task_handler(self):
|
||||
running = True
|
||||
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
0
tests/processors/__init__.py
Normal file
0
tests/processors/__init__.py
Normal file
0
tests/processors/aggregators/__init__.py
Normal file
0
tests/processors/aggregators/__init__.py
Normal file
370
tests/processors/aggregators/test_llm_response.py
Normal file
370
tests/processors/aggregators/test_llm_response.py
Normal file
@@ -0,0 +1,370 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
InterimTranscriptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
StartInterruptionFrame,
|
||||
StopInterruptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantContextAggregator,
|
||||
LLMFullResponseAggregator,
|
||||
LLMUserContextAggregator,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from tests.utils import run_test
|
||||
|
||||
|
||||
class TestLLMUserContextAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
# S E ->
|
||||
async def test_s_e(self):
|
||||
"""S E case"""
|
||||
context_aggregator = LLMUserContextAggregator(
|
||||
OpenAILLMContext(messages=[{"role": "", "content": ""}])
|
||||
)
|
||||
frames_to_send = [
|
||||
StartInterruptionFrame(),
|
||||
UserStartedSpeakingFrame(),
|
||||
StopInterruptionFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
]
|
||||
expected_returned_frames = [
|
||||
StartInterruptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
StopInterruptionFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
]
|
||||
await run_test(context_aggregator, frames_to_send, expected_returned_frames)
|
||||
|
||||
# S T E -> T
|
||||
async def test_s_t_e(self):
|
||||
"""S T E case"""
|
||||
context_aggregator = LLMUserContextAggregator(
|
||||
OpenAILLMContext(messages=[{"role": "", "content": ""}])
|
||||
)
|
||||
frames_to_send = [
|
||||
StartInterruptionFrame(),
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame("Hello", "", ""),
|
||||
StopInterruptionFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
]
|
||||
expected_returned_frames = [
|
||||
StartInterruptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
StopInterruptionFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
OpenAILLMContextFrame,
|
||||
]
|
||||
await run_test(context_aggregator, frames_to_send, expected_returned_frames)
|
||||
|
||||
# S I T E -> T
|
||||
async def test_s_i_t_e(self):
|
||||
"""S I T E case"""
|
||||
context_aggregator = LLMUserContextAggregator(
|
||||
OpenAILLMContext(messages=[{"role": "", "content": ""}])
|
||||
)
|
||||
frames_to_send = [
|
||||
StartInterruptionFrame(),
|
||||
UserStartedSpeakingFrame(),
|
||||
InterimTranscriptionFrame("This", "", ""),
|
||||
TranscriptionFrame("This is a test", "", ""),
|
||||
StopInterruptionFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
]
|
||||
expected_returned_frames = [
|
||||
StartInterruptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
StopInterruptionFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
OpenAILLMContextFrame,
|
||||
]
|
||||
await run_test(context_aggregator, frames_to_send, expected_returned_frames)
|
||||
|
||||
# S I E T -> T
|
||||
async def test_s_i_e_t(self):
|
||||
"""S I E T case"""
|
||||
context_aggregator = LLMUserContextAggregator(
|
||||
OpenAILLMContext(messages=[{"role": "", "content": ""}])
|
||||
)
|
||||
frames_to_send = [
|
||||
StartInterruptionFrame(),
|
||||
UserStartedSpeakingFrame(),
|
||||
InterimTranscriptionFrame("This", "", ""),
|
||||
StopInterruptionFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
TranscriptionFrame("This is a test", "", ""),
|
||||
]
|
||||
expected_returned_frames = [
|
||||
StartInterruptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
StopInterruptionFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
OpenAILLMContextFrame,
|
||||
]
|
||||
await run_test(context_aggregator, frames_to_send, expected_returned_frames)
|
||||
|
||||
# S I E I T -> T
|
||||
async def test_s_i_e_i_t(self):
|
||||
"""S I E I T case"""
|
||||
context_aggregator = LLMUserContextAggregator(
|
||||
OpenAILLMContext(messages=[{"role": "", "content": ""}])
|
||||
)
|
||||
frames_to_send = [
|
||||
StartInterruptionFrame(),
|
||||
UserStartedSpeakingFrame(),
|
||||
InterimTranscriptionFrame("This", "", ""),
|
||||
StopInterruptionFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
InterimTranscriptionFrame("This is", "", ""),
|
||||
TranscriptionFrame("This is a test", "", ""),
|
||||
]
|
||||
expected_returned_frames = [
|
||||
StartInterruptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
StopInterruptionFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
OpenAILLMContextFrame,
|
||||
]
|
||||
await run_test(context_aggregator, frames_to_send, expected_returned_frames)
|
||||
|
||||
# S E T -> T
|
||||
async def test_s_e_t(self):
|
||||
"""S E case"""
|
||||
context_aggregator = LLMUserContextAggregator(
|
||||
OpenAILLMContext(messages=[{"role": "", "content": ""}])
|
||||
)
|
||||
frames_to_send = [
|
||||
StartInterruptionFrame(),
|
||||
UserStartedSpeakingFrame(),
|
||||
StopInterruptionFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
TranscriptionFrame("This is a test", "", ""),
|
||||
]
|
||||
expected_returned_frames = [
|
||||
StartInterruptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
StopInterruptionFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
OpenAILLMContextFrame,
|
||||
]
|
||||
await run_test(context_aggregator, frames_to_send, expected_returned_frames)
|
||||
|
||||
# S E I T -> T
|
||||
async def test_s_e_i_t(self):
|
||||
"""S E I T case"""
|
||||
context_aggregator = LLMUserContextAggregator(
|
||||
OpenAILLMContext(messages=[{"role": "", "content": ""}])
|
||||
)
|
||||
frames_to_send = [
|
||||
StartInterruptionFrame(),
|
||||
UserStartedSpeakingFrame(),
|
||||
StopInterruptionFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
InterimTranscriptionFrame("This", "", ""),
|
||||
TranscriptionFrame("This is a test", "", ""),
|
||||
]
|
||||
expected_returned_frames = [
|
||||
StartInterruptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
StopInterruptionFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
OpenAILLMContextFrame,
|
||||
]
|
||||
await run_test(context_aggregator, frames_to_send, expected_returned_frames)
|
||||
|
||||
# S T1 I E S T2 E -> "T1 T2"
|
||||
async def test_s_t1_i_e_s_t2_e(self):
|
||||
"""S T1 I E S T2 E case"""
|
||||
context_aggregator = LLMUserContextAggregator(
|
||||
OpenAILLMContext(messages=[{"role": "", "content": ""}])
|
||||
)
|
||||
frames_to_send = [
|
||||
StartInterruptionFrame(),
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame("T1", "", ""),
|
||||
InterimTranscriptionFrame("", "", ""),
|
||||
StopInterruptionFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
StartInterruptionFrame(),
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame("T2", "", ""),
|
||||
StopInterruptionFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
]
|
||||
expected_returned_frames = [
|
||||
StartInterruptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
StopInterruptionFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
StartInterruptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
StopInterruptionFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
OpenAILLMContextFrame,
|
||||
]
|
||||
(received_down, _) = await run_test(
|
||||
context_aggregator, frames_to_send, expected_returned_frames
|
||||
)
|
||||
assert received_down[-1].context.messages[-1]["content"] == "T1 T2"
|
||||
|
||||
# S I E T1 I T2 -> T1 Interruption T2
|
||||
async def test_s_i_e_t1_i_t2(self):
|
||||
"""S I E T1 I T2 case"""
|
||||
context_aggregator = LLMUserContextAggregator(
|
||||
OpenAILLMContext(messages=[{"role": "", "content": ""}])
|
||||
)
|
||||
frames_to_send = [
|
||||
StartInterruptionFrame(),
|
||||
UserStartedSpeakingFrame(),
|
||||
InterimTranscriptionFrame("", "", ""),
|
||||
StopInterruptionFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
TranscriptionFrame("T1", "", ""),
|
||||
InterimTranscriptionFrame("", "", ""),
|
||||
TranscriptionFrame("T2", "", ""),
|
||||
]
|
||||
expected_down_frames = [
|
||||
StartInterruptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
StopInterruptionFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
OpenAILLMContextFrame,
|
||||
OpenAILLMContextFrame,
|
||||
]
|
||||
expected_up_frames = [
|
||||
BotInterruptionFrame,
|
||||
]
|
||||
(received_down, _) = await run_test(
|
||||
context_aggregator, frames_to_send, expected_down_frames, expected_up_frames
|
||||
)
|
||||
assert received_down[-1].context.messages[-2]["content"] == "T1"
|
||||
assert received_down[-1].context.messages[-1]["content"] == "T2"
|
||||
|
||||
# S T1 E T2 -> T1 Interruption T2
|
||||
async def test_s_t1_e_t2(self):
|
||||
"""S T1 E T2 case"""
|
||||
context_aggregator = LLMUserContextAggregator(
|
||||
OpenAILLMContext(messages=[{"role": "", "content": ""}])
|
||||
)
|
||||
frames_to_send = [
|
||||
StartInterruptionFrame(),
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame("T1", "", ""),
|
||||
StopInterruptionFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
TranscriptionFrame("T2", "", ""),
|
||||
]
|
||||
expected_down_frames = [
|
||||
StartInterruptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
StopInterruptionFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
OpenAILLMContextFrame,
|
||||
OpenAILLMContextFrame,
|
||||
]
|
||||
expected_up_frames = [
|
||||
BotInterruptionFrame,
|
||||
]
|
||||
(received_down, _) = await run_test(
|
||||
context_aggregator, frames_to_send, expected_down_frames, expected_up_frames
|
||||
)
|
||||
assert received_down[-1].context.messages[-2]["content"] == "T1"
|
||||
assert received_down[-1].context.messages[-1]["content"] == "T2"
|
||||
|
||||
# S E T1 T2 -> T1 Interruption T2
|
||||
async def test_s_e_t1_t2(self):
|
||||
"""S E T1 T2 case"""
|
||||
context_aggregator = LLMUserContextAggregator(
|
||||
OpenAILLMContext(messages=[{"role": "", "content": ""}])
|
||||
)
|
||||
frames_to_send = [
|
||||
StartInterruptionFrame(),
|
||||
UserStartedSpeakingFrame(),
|
||||
StopInterruptionFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
TranscriptionFrame("T1", "", ""),
|
||||
TranscriptionFrame("T2", "", ""),
|
||||
]
|
||||
expected_down_frames = [
|
||||
StartInterruptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
StopInterruptionFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
OpenAILLMContextFrame,
|
||||
OpenAILLMContextFrame,
|
||||
]
|
||||
expected_up_frames = [
|
||||
BotInterruptionFrame,
|
||||
]
|
||||
(received_down, _) = await run_test(
|
||||
context_aggregator, frames_to_send, expected_down_frames, expected_up_frames
|
||||
)
|
||||
assert received_down[-1].context.messages[-2]["content"] == "T1"
|
||||
assert received_down[-1].context.messages[-1]["content"] == "T2"
|
||||
|
||||
|
||||
class TestLLMAssistantContextAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
# S T E -> T
|
||||
async def test_s_t_e(self):
|
||||
"""S T E case"""
|
||||
context_aggregator = LLMAssistantContextAggregator(
|
||||
OpenAILLMContext(messages=[{"role": "", "content": ""}])
|
||||
)
|
||||
frames_to_send = [
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame("Hello this is Pipecat speaking!"),
|
||||
TextFrame("How are you?"),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
expected_returned_frames = [
|
||||
LLMFullResponseStartFrame,
|
||||
OpenAILLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
]
|
||||
(received_down, _) = await run_test(
|
||||
context_aggregator, frames_to_send, expected_returned_frames
|
||||
)
|
||||
assert (
|
||||
received_down[-2].context.messages[-1]["content"]
|
||||
== "Hello this is Pipecat speaking! How are you?"
|
||||
)
|
||||
|
||||
|
||||
class TestLLMFullResponseAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
# S T E -> T
|
||||
async def test_s_t_e(self):
|
||||
"""S T E case"""
|
||||
response_aggregator = LLMFullResponseAggregator()
|
||||
frames_to_send = [
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame("Hello "),
|
||||
TextFrame("this "),
|
||||
TextFrame("is "),
|
||||
TextFrame("Pipecat!"),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
expected_returned_frames = [
|
||||
LLMFullResponseStartFrame,
|
||||
TextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
]
|
||||
(received_down, _) = await run_test(
|
||||
response_aggregator, frames_to_send, expected_returned_frames
|
||||
)
|
||||
assert received_down[-2].text == "Hello this is Pipecat!"
|
||||
0
tests/processors/frameworks/__init__.py
Normal file
0
tests/processors/frameworks/__init__.py
Normal file
0
tests/services/__init__.py
Normal file
0
tests/services/__init__.py
Normal file
96
tests/utils.py
Normal file
96
tests/utils.py
Normal file
@@ -0,0 +1,96 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple
|
||||
|
||||
from pipecat.clocks.system_clock import SystemClock
|
||||
from pipecat.frames.frames import (
|
||||
ControlFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
@dataclass
|
||||
class EndTestFrame(ControlFrame):
|
||||
pass
|
||||
|
||||
|
||||
class QueuedFrameProcessor(FrameProcessor):
|
||||
def __init__(self, queue: asyncio.Queue, ignore_start: bool = True):
|
||||
super().__init__()
|
||||
self._queue = queue
|
||||
self._ignore_start = ignore_start
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
if self._ignore_start and isinstance(frame, StartFrame):
|
||||
return
|
||||
await self._queue.put(frame)
|
||||
|
||||
|
||||
async def run_test(
|
||||
processor: FrameProcessor,
|
||||
frames_to_send: List[Frame],
|
||||
expected_down_frames: List[type],
|
||||
expected_up_frames: List[type] = [],
|
||||
) -> Tuple[List[Frame], List[Frame]]:
|
||||
received_up = asyncio.Queue()
|
||||
received_down = asyncio.Queue()
|
||||
up_processor = QueuedFrameProcessor(received_up)
|
||||
down_processor = QueuedFrameProcessor(received_down)
|
||||
|
||||
up_processor.link(processor)
|
||||
processor.link(down_processor)
|
||||
|
||||
await processor.queue_frame(StartFrame(clock=SystemClock()))
|
||||
|
||||
for frame in frames_to_send:
|
||||
await processor.process_frame(frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
await processor.queue_frame(EndTestFrame())
|
||||
await processor.queue_frame(EndTestFrame(), FrameDirection.UPSTREAM)
|
||||
|
||||
#
|
||||
# Down frames
|
||||
#
|
||||
received_down_frames: List[Frame] = []
|
||||
running = True
|
||||
while running:
|
||||
frame = await received_down.get()
|
||||
running = not isinstance(frame, EndTestFrame)
|
||||
if running:
|
||||
received_down_frames.append(frame)
|
||||
|
||||
print("received DOWN frames =", received_down_frames)
|
||||
|
||||
assert len(received_down_frames) == len(expected_down_frames)
|
||||
|
||||
for real, expected in zip(received_down_frames, expected_down_frames):
|
||||
assert isinstance(real, expected)
|
||||
|
||||
#
|
||||
# Up frames
|
||||
#
|
||||
received_up_frames: List[Frame] = []
|
||||
running = True
|
||||
while running:
|
||||
frame = await received_up.get()
|
||||
running = not isinstance(frame, EndTestFrame)
|
||||
if running:
|
||||
received_up_frames.append(frame)
|
||||
|
||||
print("received UP frames =", received_up_frames)
|
||||
|
||||
assert len(received_up_frames) == len(expected_up_frames)
|
||||
|
||||
for real, expected in zip(received_up_frames, expected_up_frames):
|
||||
assert isinstance(real, expected)
|
||||
|
||||
return (received_down_frames, received_up_frames)
|
||||
Reference in New Issue
Block a user