Getting started on interruptible transport pipeline runner

This commit is contained in:
Moishe Lettvin
2024-03-04 07:51:22 -05:00
parent 763a50f8ec
commit 18e7626b9f
5 changed files with 47 additions and 9 deletions

View File

@@ -139,7 +139,6 @@ class LLMFullResponseAggregator(FrameProcessor):
yield frame
class StatelessTextTransformer(FrameProcessor):
def __init__(self, transform_fn):
self.transform_fn = transform_fn
@@ -159,7 +158,11 @@ class ParallelPipeline(FrameProcessor):
self.sources = [asyncio.Queue() for _ in pipeline_definitions]
self.sink: asyncio.Queue[QueueFrame] = asyncio.Queue()
self.pipelines: list[Pipeline] = [
Pipeline(source, self.sink, pipeline_definition)
Pipeline(
pipeline_definition,
source,
self.sink,
)
for source, pipeline_definition in zip(self.sources, pipeline_definitions)
]

View File

@@ -12,20 +12,32 @@ instantiate and run a pipeline with the Transport's sink and source queues.
"""
class Pipeline:
def __init__(
self,
source: asyncio.Queue,
sink: asyncio.Queue[QueueFrame],
processors: List[FrameProcessor],
source: asyncio.Queue | None = None,
sink: asyncio.Queue[QueueFrame] | None = None,
):
self.source: asyncio.Queue[QueueFrame] = source
self.sink: asyncio.Queue[QueueFrame] = sink
self.processors = processors
self.source: asyncio.Queue[QueueFrame] | None = source
self.sink: asyncio.Queue[QueueFrame] | None = sink
def set_source(self, source: asyncio.Queue[QueueFrame]):
self.source = source
def set_sink(self, sink: asyncio.Queue[QueueFrame]):
self.sink = sink
async def get_next_source_frame(self) -> AsyncGenerator[QueueFrame, None]:
if self.source is None:
raise ValueError("Source queue not set")
yield await self.source.get()
async def run_pipeline(self):
if self.source is None or self.sink is None:
raise ValueError("Source or sink queue not set")
try:
while True:
frame_generators = [self.get_next_source_frame()]

View File

@@ -21,6 +21,7 @@ from dailyai.pipeline.frames import (
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame
)
from dailyai.pipeline.pipeline import Pipeline
torch.set_num_threads(1)
@@ -166,6 +167,21 @@ class BaseTransportService():
if self._vad_enabled:
self._vad_thread.join()
async def run_pipeline(self, pipeline:Pipeline, allow_interruptions=True):
pipeline.set_sink(self.send_queue)
if not allow_interruptions:
pipeline.set_source(self.receive_queue)
await pipeline.run_pipeline()
else:
source_queue = asyncio.Queue()
pipeline.set_source(source_queue)
pipeline.set_sink(self.send_queue)
pipeline_task = asyncio.create_task(pipeline.run_pipeline())
async for frame in self.get_receive_frames():
await source_queue.put(frame)
def _post_run(self):
# Note that this function must be idempotent! It can be called multiple times

View File

@@ -87,9 +87,14 @@ class TestDailyFrameAggregators(unittest.IsolatedAsyncioTestCase):
source = asyncio.Queue()
sink = asyncio.Queue()
pipeline = Pipeline(
[
ParallelPipeline(
[[pipe1_annotation], [sentence_aggregator, pipe2_annotation]]
),
add_dots,
],
source,
sink,
[ParallelPipeline([[pipe1_annotation], [sentence_aggregator, pipe2_annotation]]), add_dots],
)
frames = [

View File

@@ -14,7 +14,7 @@ class TestDailyPipeline(unittest.IsolatedAsyncioTestCase):
outgoing_queue = asyncio.Queue()
incoming_queue = asyncio.Queue()
pipeline = Pipeline(incoming_queue, outgoing_queue, [aggregator])
pipeline = Pipeline([aggregator], incoming_queue, outgoing_queue)
await incoming_queue.put(TextQueueFrame("Hello, "))
await incoming_queue.put(TextQueueFrame("world."))
@@ -33,7 +33,9 @@ class TestDailyPipeline(unittest.IsolatedAsyncioTestCase):
outgoing_queue = asyncio.Queue()
incoming_queue = asyncio.Queue()
pipeline = Pipeline(
incoming_queue, outgoing_queue, [add_space, sentence_aggregator, to_upper]
[add_space, sentence_aggregator, to_upper],
incoming_queue,
outgoing_queue
)
sentence = "Hello, world. It's me, a pipeline."