Compare commits

...

5 Commits

Author SHA1 Message Date
Aleix Conchillo Flaqué
eef3f320b1 update README with KoalaFilter 2024-12-23 18:15:30 -08:00
Aleix Conchillo Flaqué
dfdd536b20 improved unit tests / add a run_test function to test processors 2024-12-23 18:15:30 -08:00
Aleix Conchillo Flaqué
ba6e9ed9ad processors(frame_processors): add a try/except when cancelling tasks
This seems necessary because of how pytest works. If a task is cancelled, pytest
will know the task has been cancelled even if `asyncio.CancelledError` is
handled internally in the task.
2024-12-23 18:15:25 -08:00
Aleix Conchillo Flaqué
2dd56ba992 processors(llm_response): unify new use cases into base class 2024-12-23 18:15:25 -08:00
edgar_git
c989c9c16d LLM user frame processor with tests 2024-12-22 15:18:33 -08:00
21 changed files with 612 additions and 103 deletions

View File

@@ -49,4 +49,4 @@ jobs:
- name: Test with pytest
run: |
source .venv/bin/activate
pytest --ignore-glob="*to_be_updated*" --ignore-glob=*pipeline_source* src tests
pytest

View File

@@ -75,6 +75,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Fixed LLM response aggregators to support more use cases such as delayed
transcriptions.
- Fixed an issue that could cause the bot to stop talking if there was a user
interruption before getting any audio from the TTS service.

View File

@@ -64,7 +64,7 @@ Available options include:
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), WebSocket, Local | `pip install "pipecat-ai[daily]"` |
| Video | [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) | `pip install "pipecat-ai[tavus,simli]"` |
| Vision & Image | [Moondream](https://docs.pipecat.ai/server/services/vision/moondream), [fal](https://docs.pipecat.ai/server/services/image-generation/fal) | `pip install "pipecat-ai[moondream]"` |
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Noisereduce](https://docs.pipecat.ai/server/utilities/audio/noisereduce-filter) | `pip install "pipecat-ai[silero]"` |
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter) | `pip install "pipecat-ai[silero]"` |
| Analytics & Metrics | [Canonical AI](https://docs.pipecat.ai/server/services/analytics/canonical), [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) | `pip install "pipecat-ai[canonical]"` |
📚 [View full services documentation →](https://docs.pipecat.ai/server/services/supported-services)

View File

@@ -82,7 +82,10 @@ whisper = [ "faster-whisper~=1.1.0" ]
where = ["src"]
[tool.pytest.ini_options]
addopts = "--verbose --disable-warnings"
testpaths = ["tests"]
pythonpath = ["src"]
asyncio_default_fixture_loop_scope = "function"
[tool.setuptools_scm]
local_scheme = "no-local-version"

View File

@@ -7,6 +7,7 @@
from typing import List, Type
from pipecat.frames.frames import (
BotInterruptionFrame,
Frame,
InterimTranscriptionFrame,
LLMFullResponseEndFrame,
@@ -40,6 +41,7 @@ class LLMResponseAggregator(FrameProcessor):
interim_accumulator_frame: Type[TextFrame] | None = None,
handle_interruptions: bool = False,
expect_stripped_words: bool = True, # if True, need to add spaces between words
interrupt_double_accumulator: bool = True, # if True, interrupt if two or more accumulators are received
):
super().__init__()
@@ -51,8 +53,8 @@ class LLMResponseAggregator(FrameProcessor):
self._interim_accumulator_frame = interim_accumulator_frame
self._handle_interruptions = handle_interruptions
self._expect_stripped_words = expect_stripped_words
self._interrupt_double_accumulator = interrupt_double_accumulator
# Reset our accumulator state.
self._reset()
@property
@@ -69,21 +71,20 @@ class LLMResponseAggregator(FrameProcessor):
# Use cases implemented:
#
# S: Start, E: End, T: Transcription, I: Interim, X: Text
# S: Start, E: End, T: Transcription, I: Interim
#
# S E -> None
# S T E -> X
# S I T E -> X
# S I E T -> X
# S I E I T -> X
# S E T -> X
# S E I T -> X
#
# The following case would not be supported:
#
# S I E T1 I T2 -> X
#
# and T2 would be dropped.
# S E -> None -> User started speaking but no transcription.
# S T E -> T -> Transcription between user started and stopped speaking.
# S E T -> T -> Transcription after user stopped speaking.
# S I T E -> T -> Transcription between user started and stopped speaking (with interims).
# S I E T -> T -> Transcription after user stopped speaking (with interims).
# S I E I T -> T -> Transcription after user stopped speaking (with interims).
# S E I T -> T -> Transcription after user stopped speaking (with interims).
# S T1 I E S T2 E -> "T1 T2" -> Merge two transcriptions if we got a first interim.
# S I E T1 I T2 -> T1 [Interruption] T2 -> Single user started/stopped, double transcription.
# S T1 E T2 -> T1 [Interruption] T2 -> Single user started/stopped, double transcription.
# S E T1 B T2 -> T1 [Interruption] T2 -> Single user started/stopped, double transcription.
# S E T1 T2 -> T1 [Interruption] T2 -> Single user started/stopped, double transcription.
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
@@ -91,11 +92,9 @@ class LLMResponseAggregator(FrameProcessor):
send_aggregation = False
if isinstance(frame, self._start_frame):
self._aggregation = ""
self._aggregating = True
self._seen_start_frame = True
self._seen_end_frame = False
self._seen_interim_results = False
await self.push_frame(frame, direction)
elif isinstance(frame, self._end_frame):
self._seen_end_frame = True
@@ -109,23 +108,36 @@ class LLMResponseAggregator(FrameProcessor):
# Send the aggregation if we are not aggregating anymore (i.e. no
# more interim results received).
send_aggregation = not self._aggregating
await self.push_frame(frame, direction)
elif isinstance(frame, self._accumulator_frame):
if self._aggregating:
if self._expect_stripped_words:
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
else:
self._aggregation += frame.text
# We have recevied a complete sentence, so if we have seen the
# end frame and we were still aggregating, it means we should
# send the aggregation.
send_aggregation = self._seen_end_frame
if (
self._interrupt_double_accumulator
and self._sent_aggregation_after_last_interruption
):
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
self._sent_aggregation_after_last_interruption = False
if self._expect_stripped_words:
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
else:
self._aggregation += frame.text
# If we haven't seen the start frame but we got an accumulator frame
# it means one of two things: it was delivered before the end frame or it
# was delivered late. In both cases we want to send the
# aggregation.
send_aggregation = not self._seen_start_frame
# We just got our final result, so let's reset interim results.
self._seen_interim_results = False
elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame):
if (
self._interrupt_double_accumulator
and self._sent_aggregation_after_last_interruption
):
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
self._sent_aggregation_after_last_interruption = False
self._seen_interim_results = True
elif self._handle_interruptions and isinstance(frame, StartInterruptionFrame):
elif isinstance(frame, StartInterruptionFrame) and self._handle_interruptions:
await self._push_aggregation()
# Reset anyways
self._reset()
@@ -142,6 +154,9 @@ class LLMResponseAggregator(FrameProcessor):
if send_aggregation:
await self._push_aggregation()
if isinstance(frame, self._end_frame):
await self.push_frame(frame, direction)
async def _push_aggregation(self):
if len(self._aggregation) > 0:
self._messages.append({"role": self._role, "content": self._aggregation})
@@ -150,6 +165,8 @@ class LLMResponseAggregator(FrameProcessor):
# if the task gets cancelled we won't be able to clear things up.
self._aggregation = ""
self._sent_aggregation_after_last_interruption = True
frame = LLMMessagesFrame(self._messages)
await self.push_frame(frame)
@@ -172,22 +189,11 @@ class LLMResponseAggregator(FrameProcessor):
self._seen_start_frame = False
self._seen_end_frame = False
self._seen_interim_results = False
class LLMAssistantResponseAggregator(LLMResponseAggregator):
def __init__(self, messages: List[dict] = []):
super().__init__(
messages=messages,
role="assistant",
start_frame=LLMFullResponseStartFrame,
end_frame=LLMFullResponseEndFrame,
accumulator_frame=TextFrame,
handle_interruptions=True,
)
self._sent_aggregation_after_last_interruption = False
class LLMUserResponseAggregator(LLMResponseAggregator):
def __init__(self, messages: List[dict] = []):
def __init__(self, messages: List[dict] = [], **kwargs):
super().__init__(
messages=messages,
role="user",
@@ -195,61 +201,21 @@ class LLMUserResponseAggregator(LLMResponseAggregator):
end_frame=UserStoppedSpeakingFrame,
accumulator_frame=TranscriptionFrame,
interim_accumulator_frame=InterimTranscriptionFrame,
**kwargs,
)
class LLMFullResponseAggregator(FrameProcessor):
"""This class aggregates Text frames until it receives a
LLMFullResponseEndFrame, then emits the concatenated text as
a single text frame.
given the following frames:
TextFrame("Hello,")
TextFrame(" world.")
TextFrame(" I am")
TextFrame(" an LLM.")
LLMFullResponseEndFrame()]
this processor will yield nothing for the first 4 frames, then
TextFrame("Hello, world. I am an LLM.")
LLMFullResponseEndFrame()
when passed the last frame.
>>> async def print_frames(aggregator, frame):
... async for frame in aggregator.process_frame(frame):
... if isinstance(frame, TextFrame):
... print(frame.text)
... else:
... print(frame.__class__.__name__)
>>> aggregator = LLMFullResponseAggregator()
>>> asyncio.run(print_frames(aggregator, TextFrame("Hello,")))
>>> asyncio.run(print_frames(aggregator, TextFrame(" world.")))
>>> asyncio.run(print_frames(aggregator, TextFrame(" I am")))
>>> asyncio.run(print_frames(aggregator, TextFrame(" an LLM.")))
>>> asyncio.run(print_frames(aggregator, LLMFullResponseEndFrame()))
Hello, world. I am an LLM.
LLMFullResponseEndFrame
"""
def __init__(self):
super().__init__()
self._aggregation = ""
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, TextFrame):
self._aggregation += frame.text
elif isinstance(frame, LLMFullResponseEndFrame):
await self.push_frame(TextFrame(self._aggregation))
await self.push_frame(frame)
self._aggregation = ""
else:
await self.push_frame(frame, direction)
class LLMAssistantResponseAggregator(LLMResponseAggregator):
    """Response aggregator preconfigured for the assistant side of a conversation.

    Aggregates ``TextFrame`` text between ``LLMFullResponseStartFrame`` and
    ``LLMFullResponseEndFrame`` under the ``"assistant"`` role, with
    interruption handling enabled.

    Args:
        messages: Message list handed to the base aggregator. When omitted, a
            fresh empty list is created for this instance.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``LLMResponseAggregator.__init__``.
    """

    def __init__(self, messages: List[dict] | None = None, **kwargs):
        # A literal ``[]`` default is evaluated once and would be shared by
        # every instance constructed without ``messages``. The base class
        # appends to its message list, so (assuming it keeps a reference to the
        # list it is given -- TODO confirm) unrelated aggregators would end up
        # writing into the same history. A caller-supplied list is passed
        # through untouched so intentional sharing keeps working.
        super().__init__(
            messages=[] if messages is None else messages,
            role="assistant",
            start_frame=LLMFullResponseStartFrame,
            end_frame=LLMFullResponseEndFrame,
            accumulator_frame=TextFrame,
            handle_interruptions=True,
            **kwargs,
        )
class LLMContextAggregator(LLMResponseAggregator):
@@ -286,15 +252,14 @@ class LLMContextAggregator(LLMResponseAggregator):
# if the task gets cancelled we won't be able to clear things up.
self._aggregation = ""
self._sent_aggregation_after_last_interruption = True
frame = OpenAILLMContextFrame(self._context)
await self.push_frame(frame)
# Reset our accumulator state.
self._reset()
class LLMAssistantContextAggregator(LLMContextAggregator):
def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True):
def __init__(self, context: OpenAILLMContext, **kwargs):
super().__init__(
messages=[],
context=context,
@@ -303,12 +268,12 @@ class LLMAssistantContextAggregator(LLMContextAggregator):
end_frame=LLMFullResponseEndFrame,
accumulator_frame=TextFrame,
handle_interruptions=True,
expect_stripped_words=expect_stripped_words,
**kwargs,
)
class LLMUserContextAggregator(LLMContextAggregator):
def __init__(self, context: OpenAILLMContext):
def __init__(self, context: OpenAILLMContext, **kwargs):
super().__init__(
messages=[],
context=context,
@@ -317,4 +282,69 @@ class LLMUserContextAggregator(LLMContextAggregator):
end_frame=UserStoppedSpeakingFrame,
accumulator_frame=TranscriptionFrame,
interim_accumulator_frame=InterimTranscriptionFrame,
**kwargs,
)
class LLMFullResponseAggregator(FrameProcessor):
    """Collapse a streamed LLM response into a single ``TextFrame``.

    ``TextFrame``s received after an ``LLMFullResponseStartFrame`` are buffered
    instead of being forwarded. When an ``LLMFullResponseEndFrame`` arrives, the
    buffered text is pushed as one ``TextFrame``, the end frame is pushed after
    it, and the buffer is cleared.

    Given the following frames:

        LLMFullResponseStartFrame()
        TextFrame("Hello,")
        TextFrame(" world.")
        TextFrame(" I am")
        TextFrame(" an LLM.")
        LLMFullResponseEndFrame()

    this processor pushes:

        LLMFullResponseStartFrame()
        TextFrame("Hello, world. I am an LLM.")
        LLMFullResponseEndFrame()

    The aggregated ``TextFrame`` is only emitted once the end frame has been
    processed. A ``TextFrame`` is pushed on every end frame, so its text is
    empty when nothing was buffered. Text frames seen outside a start/end pair,
    and every other frame type, are forwarded unchanged in the direction they
    arrived with.

    ``process_frame`` is a coroutine: results are delivered through
    ``push_frame`` rather than returned or yielded.
    """

    def __init__(self):
        super().__init__()
        # Text collected since the last end frame.
        self._aggregation = ""
        # True while we are between a start frame and its end frame.
        self._seen_start_frame = False

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Buffer, flush or forward ``frame`` as described on the class."""
        await super().process_frame(frame, direction)

        if isinstance(frame, LLMFullResponseStartFrame):
            # Open the aggregation window and let the start frame through.
            self._seen_start_frame = True
            await self.push_frame(frame, direction)
            return

        if isinstance(frame, LLMFullResponseEndFrame):
            # Close the window, flush what was buffered, then the end frame.
            # Both pushes rely on push_frame's default direction.
            self._seen_start_frame = False
            await self.push_frame(TextFrame(self._aggregation))
            await self.push_frame(frame)
            self._aggregation = ""
            return

        if isinstance(frame, TextFrame) and self._seen_start_frame:
            # Inside the window: swallow the frame and keep its text.
            self._aggregation += frame.text
            return

        # Anything else is not ours to handle.
        await self.push_frame(frame, direction)

View File

@@ -311,8 +311,15 @@ class FrameProcessor:
self.__push_frame_task = self.get_event_loop().create_task(self.__push_frame_task_handler())
async def __cancel_push_task(self):
self.__push_frame_task.cancel()
await self.__push_frame_task
try:
self.__push_frame_task.cancel()
await self.__push_frame_task
except asyncio.CancelledError:
# TODO(aleix): Investigate why this is really needed. So far, this is
# necessary because of how pytest works. If a task is cancelled,
# pytest will know the task has been cancelled even if
# `asyncio.CancelledError` is handled internally in the task.
pass
async def __push_frame_task_handler(self):
running = True

0
tests/__init__.py Normal file
View File

View File

View File

View File

@@ -0,0 +1,370 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import unittest
from pipecat.frames.frames import (
BotInterruptionFrame,
InterimTranscriptionFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
StartInterruptionFrame,
StopInterruptionFrame,
TextFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.processors.aggregators.llm_response import (
LLMAssistantContextAggregator,
LLMFullResponseAggregator,
LLMUserContextAggregator,
)
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
)
from tests.utils import run_test
class TestLLMUserContextAggregator(unittest.IsolatedAsyncioTestCase):
    """Exercise ``LLMUserContextAggregator`` against different frame orderings.

    Legend used in test names and comments:

        S: user started speaking (sent after a ``StartInterruptionFrame``)
        E: user stopped speaking (sent after a ``StopInterruptionFrame``)
        T: ``TranscriptionFrame``
        I: ``InterimTranscriptionFrame``

    Every test feeds frames through ``run_test`` and checks the types of the
    frames seen downstream (and, where relevant, upstream).
    """

    # Frames that every single start/stop turn lets through downstream.
    _TURN_FRAMES = (
        StartInterruptionFrame,
        UserStartedSpeakingFrame,
        StopInterruptionFrame,
        UserStoppedSpeakingFrame,
    )

    @staticmethod
    def _make_aggregator() -> LLMUserContextAggregator:
        """Build an aggregator over a context holding one placeholder message."""
        return LLMUserContextAggregator(
            OpenAILLMContext(messages=[{"role": "", "content": ""}])
        )

    # S E ->
    async def test_s_e(self):
        """S E case"""
        frames_to_send = [
            StartInterruptionFrame(),
            UserStartedSpeakingFrame(),
            StopInterruptionFrame(),
            UserStoppedSpeakingFrame(),
        ]
        expected_returned_frames = [*self._TURN_FRAMES]
        await run_test(self._make_aggregator(), frames_to_send, expected_returned_frames)

    # S T E -> T
    async def test_s_t_e(self):
        """S T E case"""
        frames_to_send = [
            StartInterruptionFrame(),
            UserStartedSpeakingFrame(),
            TranscriptionFrame("Hello", "", ""),
            StopInterruptionFrame(),
            UserStoppedSpeakingFrame(),
        ]
        expected_returned_frames = [*self._TURN_FRAMES, OpenAILLMContextFrame]
        await run_test(self._make_aggregator(), frames_to_send, expected_returned_frames)

    # S I T E -> T
    async def test_s_i_t_e(self):
        """S I T E case"""
        frames_to_send = [
            StartInterruptionFrame(),
            UserStartedSpeakingFrame(),
            InterimTranscriptionFrame("This", "", ""),
            TranscriptionFrame("This is a test", "", ""),
            StopInterruptionFrame(),
            UserStoppedSpeakingFrame(),
        ]
        expected_returned_frames = [*self._TURN_FRAMES, OpenAILLMContextFrame]
        await run_test(self._make_aggregator(), frames_to_send, expected_returned_frames)

    # S I E T -> T
    async def test_s_i_e_t(self):
        """S I E T case"""
        frames_to_send = [
            StartInterruptionFrame(),
            UserStartedSpeakingFrame(),
            InterimTranscriptionFrame("This", "", ""),
            StopInterruptionFrame(),
            UserStoppedSpeakingFrame(),
            TranscriptionFrame("This is a test", "", ""),
        ]
        expected_returned_frames = [*self._TURN_FRAMES, OpenAILLMContextFrame]
        await run_test(self._make_aggregator(), frames_to_send, expected_returned_frames)

    # S I E I T -> T
    async def test_s_i_e_i_t(self):
        """S I E I T case"""
        frames_to_send = [
            StartInterruptionFrame(),
            UserStartedSpeakingFrame(),
            InterimTranscriptionFrame("This", "", ""),
            StopInterruptionFrame(),
            UserStoppedSpeakingFrame(),
            InterimTranscriptionFrame("This is", "", ""),
            TranscriptionFrame("This is a test", "", ""),
        ]
        expected_returned_frames = [*self._TURN_FRAMES, OpenAILLMContextFrame]
        await run_test(self._make_aggregator(), frames_to_send, expected_returned_frames)

    # S E T -> T
    async def test_s_e_t(self):
        """S E T case"""
        # The docstring doubles as the test description in unittest's verbose
        # output, so it must not collide with test_s_e's "S E case".
        frames_to_send = [
            StartInterruptionFrame(),
            UserStartedSpeakingFrame(),
            StopInterruptionFrame(),
            UserStoppedSpeakingFrame(),
            TranscriptionFrame("This is a test", "", ""),
        ]
        expected_returned_frames = [*self._TURN_FRAMES, OpenAILLMContextFrame]
        await run_test(self._make_aggregator(), frames_to_send, expected_returned_frames)

    # S E I T -> T
    async def test_s_e_i_t(self):
        """S E I T case"""
        frames_to_send = [
            StartInterruptionFrame(),
            UserStartedSpeakingFrame(),
            StopInterruptionFrame(),
            UserStoppedSpeakingFrame(),
            InterimTranscriptionFrame("This", "", ""),
            TranscriptionFrame("This is a test", "", ""),
        ]
        expected_returned_frames = [*self._TURN_FRAMES, OpenAILLMContextFrame]
        await run_test(self._make_aggregator(), frames_to_send, expected_returned_frames)

    # S T1 I E S T2 E -> "T1 T2"
    async def test_s_t1_i_e_s_t2_e(self):
        """S T1 I E S T2 E case"""
        frames_to_send = [
            StartInterruptionFrame(),
            UserStartedSpeakingFrame(),
            TranscriptionFrame("T1", "", ""),
            InterimTranscriptionFrame("", "", ""),
            StopInterruptionFrame(),
            UserStoppedSpeakingFrame(),
            StartInterruptionFrame(),
            UserStartedSpeakingFrame(),
            TranscriptionFrame("T2", "", ""),
            StopInterruptionFrame(),
            UserStoppedSpeakingFrame(),
        ]
        # Two full turns pass through before the single merged context frame.
        expected_returned_frames = [
            *self._TURN_FRAMES,
            *self._TURN_FRAMES,
            OpenAILLMContextFrame,
        ]
        received_down, _ = await run_test(
            self._make_aggregator(), frames_to_send, expected_returned_frames
        )
        self.assertEqual(received_down[-1].context.messages[-1]["content"], "T1 T2")

    # S I E T1 I T2 -> T1 Interruption T2
    async def test_s_i_e_t1_i_t2(self):
        """S I E T1 I T2 case"""
        frames_to_send = [
            StartInterruptionFrame(),
            UserStartedSpeakingFrame(),
            InterimTranscriptionFrame("", "", ""),
            StopInterruptionFrame(),
            UserStoppedSpeakingFrame(),
            TranscriptionFrame("T1", "", ""),
            InterimTranscriptionFrame("", "", ""),
            TranscriptionFrame("T2", "", ""),
        ]
        expected_down_frames = [
            *self._TURN_FRAMES,
            OpenAILLMContextFrame,
            OpenAILLMContextFrame,
        ]
        expected_up_frames = [BotInterruptionFrame]
        received_down, _ = await run_test(
            self._make_aggregator(), frames_to_send, expected_down_frames, expected_up_frames
        )
        messages = received_down[-1].context.messages
        self.assertEqual(messages[-2]["content"], "T1")
        self.assertEqual(messages[-1]["content"], "T2")

    # S T1 E T2 -> T1 Interruption T2
    async def test_s_t1_e_t2(self):
        """S T1 E T2 case"""
        frames_to_send = [
            StartInterruptionFrame(),
            UserStartedSpeakingFrame(),
            TranscriptionFrame("T1", "", ""),
            StopInterruptionFrame(),
            UserStoppedSpeakingFrame(),
            TranscriptionFrame("T2", "", ""),
        ]
        expected_down_frames = [
            *self._TURN_FRAMES,
            OpenAILLMContextFrame,
            OpenAILLMContextFrame,
        ]
        expected_up_frames = [BotInterruptionFrame]
        received_down, _ = await run_test(
            self._make_aggregator(), frames_to_send, expected_down_frames, expected_up_frames
        )
        messages = received_down[-1].context.messages
        self.assertEqual(messages[-2]["content"], "T1")
        self.assertEqual(messages[-1]["content"], "T2")

    # S E T1 T2 -> T1 Interruption T2
    async def test_s_e_t1_t2(self):
        """S E T1 T2 case"""
        frames_to_send = [
            StartInterruptionFrame(),
            UserStartedSpeakingFrame(),
            StopInterruptionFrame(),
            UserStoppedSpeakingFrame(),
            TranscriptionFrame("T1", "", ""),
            TranscriptionFrame("T2", "", ""),
        ]
        expected_down_frames = [
            *self._TURN_FRAMES,
            OpenAILLMContextFrame,
            OpenAILLMContextFrame,
        ]
        expected_up_frames = [BotInterruptionFrame]
        received_down, _ = await run_test(
            self._make_aggregator(), frames_to_send, expected_down_frames, expected_up_frames
        )
        messages = received_down[-1].context.messages
        self.assertEqual(messages[-2]["content"], "T1")
        self.assertEqual(messages[-1]["content"], "T2")
class TestLLMAssistantContextAggregator(unittest.IsolatedAsyncioTestCase):
    """Check that ``LLMAssistantContextAggregator`` merges streamed text frames."""

    # S T E -> T
    async def test_s_t_e(self):
        """S T E case"""
        aggregator = LLMAssistantContextAggregator(
            OpenAILLMContext(messages=[{"role": "", "content": ""}])
        )

        # One LLM response made of two text chunks.
        outgoing = [
            LLMFullResponseStartFrame(),
            TextFrame("Hello this is Pipecat speaking!"),
            TextFrame("How are you?"),
            LLMFullResponseEndFrame(),
        ]
        # The context frame is expected between the start and end frames.
        expected_types = [
            LLMFullResponseStartFrame,
            OpenAILLMContextFrame,
            LLMFullResponseEndFrame,
        ]

        down_frames, _ = await run_test(aggregator, outgoing, expected_types)

        context_frame = down_frames[-2]
        last_message = context_frame.context.messages[-1]
        assert last_message["content"] == "Hello this is Pipecat speaking! How are you?"
class TestLLMFullResponseAggregator(unittest.IsolatedAsyncioTestCase):
    """Check that ``LLMFullResponseAggregator`` emits one concatenated text frame."""

    # S T E -> T
    async def test_s_t_e(self):
        """S T E case"""
        aggregator = LLMFullResponseAggregator()

        # The chunks carry their own trailing spaces; they are joined verbatim.
        chunks = ("Hello ", "this ", "is ", "Pipecat!")
        outgoing = [
            LLMFullResponseStartFrame(),
            *(TextFrame(chunk) for chunk in chunks),
            LLMFullResponseEndFrame(),
        ]
        expected_types = [
            LLMFullResponseStartFrame,
            TextFrame,
            LLMFullResponseEndFrame,
        ]

        down_frames, _ = await run_test(aggregator, outgoing, expected_types)

        aggregated = down_frames[-2]
        assert aggregated.text == "Hello this is Pipecat!"

View File

View File

96
tests/utils.py Normal file
View File

@@ -0,0 +1,96 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
from dataclasses import dataclass
from typing import List, Tuple
from pipecat.clocks.system_clock import SystemClock
from pipecat.frames.frames import (
ControlFrame,
Frame,
StartFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
@dataclass
class EndTestFrame(ControlFrame):
    """Sentinel control frame that ``run_test`` queues to mark the end of a run."""
class QueuedFrameProcessor(FrameProcessor):
    """Test processor that records every frame it processes onto a queue.

    Args:
        queue: Queue that receives each processed frame.
        ignore_start: When true (the default), ``StartFrame``s are processed
            but not put on the queue.
    """

    def __init__(self, queue: asyncio.Queue, ignore_start: bool = True):
        super().__init__()
        self._queue = queue
        self._ignore_start = ignore_start

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Run the base processing, then enqueue ``frame`` unless it is filtered."""
        await super().process_frame(frame, direction)

        filtered_out = isinstance(frame, StartFrame) and self._ignore_start
        if not filtered_out:
            await self._queue.put(frame)
async def _collect_until_end(queue: asyncio.Queue) -> List[Frame]:
    """Return the frames read from ``queue`` before the first ``EndTestFrame``."""
    collected: List[Frame] = []
    while True:
        frame = await queue.get()
        if isinstance(frame, EndTestFrame):
            return collected
        collected.append(frame)


def _check_frame_types(label: str, received: List[Frame], expected: List[type]) -> None:
    """Assert that ``received`` matches ``expected`` in length and per-frame type."""
    print(f"received {label} frames =", received)

    assert len(received) == len(expected), (
        f"expected {len(expected)} {label} frames, got {len(received)}: {received}"
    )
    for index, (frame, expected_type) in enumerate(zip(received, expected)):
        assert isinstance(frame, expected_type), (
            f"{label} frame #{index} is {type(frame).__name__}, "
            f"expected {expected_type.__name__}"
        )


async def run_test(
    processor: FrameProcessor,
    frames_to_send: List[Frame],
    expected_down_frames: List[type],
    expected_up_frames: List[type] | None = None,
) -> Tuple[List[Frame], List[Frame]]:
    """Drive ``processor`` with frames and verify what it pushes in each direction.

    The processor is linked between two ``QueuedFrameProcessor``s, one before
    it and one after it. A ``StartFrame`` is queued first, each frame in
    ``frames_to_send`` is then handed to ``processor.process_frame`` in the
    downstream direction, and finally an ``EndTestFrame`` is queued both in the
    default direction and upstream so that each collector knows when to stop.

    Args:
        processor: The frame processor under test.
        frames_to_send: Frames to feed to the processor, in order.
        expected_down_frames: Expected types of the frames received by the
            processor linked after ``processor``, in order.
        expected_up_frames: Expected types of the frames received by the
            processor linked before ``processor``, in order. Defaults to
            expecting none.

    Returns:
        A ``(down_frames, up_frames)`` tuple with the frames actually received
        on each side, excluding the ``EndTestFrame`` sentinels.

    Raises:
        AssertionError: If the number or the types of the received frames
            differ from what was expected.
    """
    # Avoid a mutable default argument: build the "expect nothing" list per call.
    if expected_up_frames is None:
        expected_up_frames = []

    received_up = asyncio.Queue()
    received_down = asyncio.Queue()

    up_processor = QueuedFrameProcessor(received_up)
    down_processor = QueuedFrameProcessor(received_down)

    up_processor.link(processor)
    processor.link(down_processor)

    await processor.queue_frame(StartFrame(clock=SystemClock()))

    for frame in frames_to_send:
        await processor.process_frame(frame, FrameDirection.DOWNSTREAM)

    await processor.queue_frame(EndTestFrame())
    await processor.queue_frame(EndTestFrame(), FrameDirection.UPSTREAM)

    #
    # Down frames
    #
    received_down_frames = await _collect_until_end(received_down)
    _check_frame_types("DOWN", received_down_frames, expected_down_frames)

    #
    # Up frames
    #
    received_up_frames = await _collect_until_end(received_up)
    _check_frame_types("UP", received_up_frames, expected_up_frames)

    return (received_down_frames, received_up_frames)