Getting started on interruptible transport pipeline runner
This commit is contained in:
@@ -139,7 +139,6 @@ class LLMFullResponseAggregator(FrameProcessor):
|
||||
yield frame
|
||||
|
||||
|
||||
|
||||
class StatelessTextTransformer(FrameProcessor):
|
||||
def __init__(self, transform_fn):
|
||||
self.transform_fn = transform_fn
|
||||
@@ -159,7 +158,11 @@ class ParallelPipeline(FrameProcessor):
|
||||
self.sources = [asyncio.Queue() for _ in pipeline_definitions]
|
||||
self.sink: asyncio.Queue[QueueFrame] = asyncio.Queue()
|
||||
self.pipelines: list[Pipeline] = [
|
||||
Pipeline(source, self.sink, pipeline_definition)
|
||||
Pipeline(
|
||||
pipeline_definition,
|
||||
source,
|
||||
self.sink,
|
||||
)
|
||||
for source, pipeline_definition in zip(self.sources, pipeline_definitions)
|
||||
]
|
||||
|
||||
|
||||
@@ -12,20 +12,32 @@ instantiate and run a pipeline with the Transport's sink and source queues.
|
||||
"""
|
||||
|
||||
class Pipeline:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source: asyncio.Queue,
|
||||
sink: asyncio.Queue[QueueFrame],
|
||||
processors: List[FrameProcessor],
|
||||
source: asyncio.Queue | None = None,
|
||||
sink: asyncio.Queue[QueueFrame] | None = None,
|
||||
):
|
||||
self.source: asyncio.Queue[QueueFrame] = source
|
||||
self.sink: asyncio.Queue[QueueFrame] = sink
|
||||
self.processors = processors
|
||||
self.source: asyncio.Queue[QueueFrame] | None = source
|
||||
self.sink: asyncio.Queue[QueueFrame] | None = sink
|
||||
|
||||
def set_source(self, source: asyncio.Queue[QueueFrame]):
|
||||
self.source = source
|
||||
|
||||
def set_sink(self, sink: asyncio.Queue[QueueFrame]):
|
||||
self.sink = sink
|
||||
|
||||
async def get_next_source_frame(self) -> AsyncGenerator[QueueFrame, None]:
|
||||
if self.source is None:
|
||||
raise ValueError("Source queue not set")
|
||||
yield await self.source.get()
|
||||
|
||||
async def run_pipeline(self):
|
||||
if self.source is None or self.sink is None:
|
||||
raise ValueError("Source or sink queue not set")
|
||||
|
||||
try:
|
||||
while True:
|
||||
frame_generators = [self.get_next_source_frame()]
|
||||
|
||||
@@ -21,6 +21,7 @@ from dailyai.pipeline.frames import (
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame
|
||||
)
|
||||
from dailyai.pipeline.pipeline import Pipeline
|
||||
|
||||
torch.set_num_threads(1)
|
||||
|
||||
@@ -166,6 +167,21 @@ class BaseTransportService():
|
||||
if self._vad_enabled:
|
||||
self._vad_thread.join()
|
||||
|
||||
async def run_pipeline(self, pipeline:Pipeline, allow_interruptions=True):
|
||||
pipeline.set_sink(self.send_queue)
|
||||
if not allow_interruptions:
|
||||
pipeline.set_source(self.receive_queue)
|
||||
await pipeline.run_pipeline()
|
||||
else:
|
||||
source_queue = asyncio.Queue()
|
||||
pipeline.set_source(source_queue)
|
||||
pipeline.set_sink(self.send_queue)
|
||||
pipeline_task = asyncio.create_task(pipeline.run_pipeline())
|
||||
|
||||
async for frame in self.get_receive_frames():
|
||||
await source_queue.put(frame)
|
||||
|
||||
|
||||
|
||||
def _post_run(self):
|
||||
# Note that this function must be idempotent! It can be called multiple times
|
||||
|
||||
@@ -87,9 +87,14 @@ class TestDailyFrameAggregators(unittest.IsolatedAsyncioTestCase):
|
||||
source = asyncio.Queue()
|
||||
sink = asyncio.Queue()
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
ParallelPipeline(
|
||||
[[pipe1_annotation], [sentence_aggregator, pipe2_annotation]]
|
||||
),
|
||||
add_dots,
|
||||
],
|
||||
source,
|
||||
sink,
|
||||
[ParallelPipeline([[pipe1_annotation], [sentence_aggregator, pipe2_annotation]]), add_dots],
|
||||
)
|
||||
|
||||
frames = [
|
||||
|
||||
@@ -14,7 +14,7 @@ class TestDailyPipeline(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
outgoing_queue = asyncio.Queue()
|
||||
incoming_queue = asyncio.Queue()
|
||||
pipeline = Pipeline(incoming_queue, outgoing_queue, [aggregator])
|
||||
pipeline = Pipeline([aggregator], incoming_queue, outgoing_queue)
|
||||
|
||||
await incoming_queue.put(TextQueueFrame("Hello, "))
|
||||
await incoming_queue.put(TextQueueFrame("world."))
|
||||
@@ -33,7 +33,9 @@ class TestDailyPipeline(unittest.IsolatedAsyncioTestCase):
|
||||
outgoing_queue = asyncio.Queue()
|
||||
incoming_queue = asyncio.Queue()
|
||||
pipeline = Pipeline(
|
||||
incoming_queue, outgoing_queue, [add_space, sentence_aggregator, to_upper]
|
||||
[add_space, sentence_aggregator, to_upper],
|
||||
incoming_queue,
|
||||
outgoing_queue
|
||||
)
|
||||
|
||||
sentence = "Hello, world. It's me, a pipeline."
|
||||
|
||||
Reference in New Issue
Block a user