From 18e7626b9f4a692191dee6d96eb72250d6445e2a Mon Sep 17 00:00:00 2001 From: Moishe Lettvin Date: Mon, 4 Mar 2024 07:51:22 -0500 Subject: [PATCH] Getting started on interruptible transport pipeline runner --- src/dailyai/pipeline/aggregators.py | 7 +++++-- src/dailyai/pipeline/pipeline.py | 20 +++++++++++++++---- .../services/base_transport_service.py | 16 +++++++++++++++ src/dailyai/tests/test_aggregators.py | 7 ++++++- src/dailyai/tests/test_pipeline.py | 6 ++++-- 5 files changed, 47 insertions(+), 9 deletions(-) diff --git a/src/dailyai/pipeline/aggregators.py b/src/dailyai/pipeline/aggregators.py index 17c7775ff..9b244ef00 100644 --- a/src/dailyai/pipeline/aggregators.py +++ b/src/dailyai/pipeline/aggregators.py @@ -139,7 +139,6 @@ class LLMFullResponseAggregator(FrameProcessor): yield frame - class StatelessTextTransformer(FrameProcessor): def __init__(self, transform_fn): self.transform_fn = transform_fn @@ -159,7 +158,11 @@ class ParallelPipeline(FrameProcessor): self.sources = [asyncio.Queue() for _ in pipeline_definitions] self.sink: asyncio.Queue[QueueFrame] = asyncio.Queue() self.pipelines: list[Pipeline] = [ - Pipeline(source, self.sink, pipeline_definition) + Pipeline( + pipeline_definition, + source, + self.sink, + ) for source, pipeline_definition in zip(self.sources, pipeline_definitions) ] diff --git a/src/dailyai/pipeline/pipeline.py b/src/dailyai/pipeline/pipeline.py index f6618a85d..0ef2ab72a 100644 --- a/src/dailyai/pipeline/pipeline.py +++ b/src/dailyai/pipeline/pipeline.py @@ -12,20 +12,32 @@ instantiate and run a pipeline with the Transport's sink and source queues. """ class Pipeline: + def __init__( self, - source: asyncio.Queue, - sink: asyncio.Queue[QueueFrame], processors: List[FrameProcessor], + source: asyncio.Queue | None = None, + sink: asyncio.Queue[QueueFrame] | None = None, ): - self.source: asyncio.Queue[QueueFrame] = source - self.sink: asyncio.Queue[QueueFrame] = sink self.processors = processors + self.source: asyncio.Queue[QueueFrame] | None = source + self.sink: asyncio.Queue[QueueFrame] | None = sink + + def set_source(self, source: asyncio.Queue[QueueFrame]): + self.source = source + + def set_sink(self, sink: asyncio.Queue[QueueFrame]): + self.sink = sink async def get_next_source_frame(self) -> AsyncGenerator[QueueFrame, None]: + if self.source is None: + raise ValueError("Source queue not set") yield await self.source.get() async def run_pipeline(self): + if self.source is None or self.sink is None: + raise ValueError("Source or sink queue not set") + try: while True: frame_generators = [self.get_next_source_frame()] diff --git a/src/dailyai/services/base_transport_service.py b/src/dailyai/services/base_transport_service.py index 005d0c8f5..490489b87 100644 --- a/src/dailyai/services/base_transport_service.py +++ b/src/dailyai/services/base_transport_service.py @@ -21,6 +21,7 @@ from dailyai.pipeline.frames import ( UserStartedSpeakingFrame, UserStoppedSpeakingFrame ) +from dailyai.pipeline.pipeline import Pipeline torch.set_num_threads(1) @@ -166,6 +167,21 @@ class BaseTransportService(): if self._vad_enabled: self._vad_thread.join() + async def run_pipeline(self, pipeline:Pipeline, allow_interruptions=True): + pipeline.set_sink(self.send_queue) + if not allow_interruptions: + pipeline.set_source(self.receive_queue) + await pipeline.run_pipeline() + else: + source_queue = asyncio.Queue() + pipeline.set_source(source_queue) + pipeline.set_sink(self.send_queue) + pipeline_task = asyncio.create_task(pipeline.run_pipeline()) + + async for frame in self.get_receive_frames(): + await source_queue.put(frame) + + def _post_run(self): # Note that this function must be idempotent! It can be called multiple times diff --git a/src/dailyai/tests/test_aggregators.py b/src/dailyai/tests/test_aggregators.py index 2dfe574e4..5a2cfb07d 100644 --- a/src/dailyai/tests/test_aggregators.py +++ b/src/dailyai/tests/test_aggregators.py @@ -87,9 +87,14 @@ class TestDailyFrameAggregators(unittest.IsolatedAsyncioTestCase): source = asyncio.Queue() sink = asyncio.Queue() pipeline = Pipeline( + [ + ParallelPipeline( + [[pipe1_annotation], [sentence_aggregator, pipe2_annotation]] + ), + add_dots, + ], source, sink, - [ParallelPipeline([[pipe1_annotation], [sentence_aggregator, pipe2_annotation]]), add_dots], ) frames = [ diff --git a/src/dailyai/tests/test_pipeline.py b/src/dailyai/tests/test_pipeline.py index 94840eb6e..2e1d4289c 100644 --- a/src/dailyai/tests/test_pipeline.py +++ b/src/dailyai/tests/test_pipeline.py @@ -14,7 +14,7 @@ class TestDailyPipeline(unittest.IsolatedAsyncioTestCase): outgoing_queue = asyncio.Queue() incoming_queue = asyncio.Queue() - pipeline = Pipeline(incoming_queue, outgoing_queue, [aggregator]) + pipeline = Pipeline([aggregator], incoming_queue, outgoing_queue) await incoming_queue.put(TextQueueFrame("Hello, ")) await incoming_queue.put(TextQueueFrame("world.")) @@ -33,7 +33,9 @@ class TestDailyPipeline(unittest.IsolatedAsyncioTestCase): outgoing_queue = asyncio.Queue() incoming_queue = asyncio.Queue() pipeline = Pipeline( - incoming_queue, outgoing_queue, [add_space, sentence_aggregator, to_upper] + [add_space, sentence_aggregator, to_upper], + incoming_queue, + outgoing_queue ) sentence = "Hello, world. It's me, a pipeline."