Merge pull request #1376 from pipecat-ai/aleix/event-handlers-as-tasks

event handlers are now executed in separate tasks
This commit is contained in:
Aleix Conchillo Flaqué
2025-03-18 12:10:34 -07:00
committed by GitHub
10 changed files with 150 additions and 17 deletions

View File

@@ -99,6 +99,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- All event handlers are now executed in separate tasks in order to prevent
blocking the pipeline. Previously, an event handler that took some time to
execute would block the pipeline until the handler completed.
- Updated `TranscriptProcessor` to support text output from
`OpenAIRealtimeBetaLLMService`.

View File

@@ -40,12 +40,18 @@ class PipelineRunner(BaseObject):
task.set_event_loop(self._loop)
await task.run()
del self._tasks[task.name]
# Cleanup base object.
await self.cleanup()
# If we are cancelling through a signal, make sure we wait for it so
# everything gets cleaned up nicely.
if self._sig_task:
await self._sig_task
if self._force_gc:
self._gc_collect()
logger.debug(f"Runner {self} finished running {task}")
async def stop_when_done(self):

View File

@@ -354,6 +354,10 @@ class PipelineTask(BaseTask):
self._pipeline_end_event.clear()
async def _cleanup(self, cleanup_pipeline: bool):
# Cleanup base object.
await self.cleanup()
# Cleanup pipeline processors.
await self._source.cleanup()
if cleanup_pipeline:
await self._pipeline.cleanup()

View File

@@ -164,6 +164,7 @@ class FrameProcessor(BaseObject):
await self._task_manager.wait_for_task(task, timeout)
async def cleanup(self):
await super().cleanup()
await self.__cancel_input_task()
await self.__cancel_push_task()

View File

@@ -102,11 +102,13 @@ class FastAPIWebsocketClient:
class FastAPIWebsocketInputTransport(BaseInputTransport):
def __init__(
self,
transport: BaseTransport,
client: FastAPIWebsocketClient,
params: FastAPIWebsocketParams,
**kwargs,
):
super().__init__(params, **kwargs)
self._transport = transport
self._client = client
self._params = params
self._receive_task = None
@@ -139,6 +141,10 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
await self._stop_tasks()
await self._client.disconnect()
async def cleanup(self):
    """Clean up this input transport, then the transport held in ``self._transport``."""
    # Base-class cleanup runs first.
    await super().cleanup()
    # Then clean up the transport object this input was constructed with.
    await self._transport.cleanup()
async def _receive_messages(self):
try:
async for message in self._client.receive():
@@ -165,11 +171,14 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
class FastAPIWebsocketOutputTransport(BaseOutputTransport):
def __init__(
self,
transport: BaseTransport,
client: FastAPIWebsocketClient,
params: FastAPIWebsocketParams,
**kwargs,
):
super().__init__(params, **kwargs)
self._transport = transport
self._client = client
self._params = params
@@ -194,6 +203,10 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
await super().cancel(frame)
await self._client.disconnect()
async def cleanup(self):
    """Clean up this output transport, then the transport held in ``self._transport``."""
    # Base-class cleanup runs first.
    await super().cleanup()
    # Then clean up the transport object this output was constructed with.
    await self._transport.cleanup()
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
@@ -266,6 +279,7 @@ class FastAPIWebsocketTransport(BaseTransport):
output_name: Optional[str] = None,
):
super().__init__(input_name=input_name, output_name=output_name)
self._params = params
self._callbacks = FastAPIWebsocketCallbacks(
@@ -278,10 +292,10 @@ class FastAPIWebsocketTransport(BaseTransport):
self._client = FastAPIWebsocketClient(websocket, is_binary, self._callbacks)
self._input = FastAPIWebsocketInputTransport(
self._client, self._params, name=self._input_name
self, self._client, self._params, name=self._input_name
)
self._output = FastAPIWebsocketOutputTransport(
self._client, self._params, name=self._output_name
self, self._client, self._params, name=self._output_name
)
# Register supported handlers. The user will only be able to register

View File

@@ -118,9 +118,15 @@ class WebsocketClientSession:
class WebsocketClientInputTransport(BaseInputTransport):
def __init__(self, session: WebsocketClientSession, params: WebsocketClientParams):
def __init__(
self,
transport: BaseTransport,
session: WebsocketClientSession,
params: WebsocketClientParams,
):
super().__init__(params)
self._transport = transport
self._session = session
self._params = params
@@ -138,6 +144,10 @@ class WebsocketClientInputTransport(BaseInputTransport):
await super().cancel(frame)
await self._session.disconnect()
async def cleanup(self):
    """Clean up this input transport, then the transport held in ``self._transport``."""
    # Base-class cleanup runs first.
    await super().cleanup()
    # Then clean up the transport object this input was constructed with.
    await self._transport.cleanup()
async def on_message(self, websocket, message):
frame = await self._params.serializer.deserialize(message)
if not frame:
@@ -149,9 +159,15 @@ class WebsocketClientInputTransport(BaseInputTransport):
class WebsocketClientOutputTransport(BaseOutputTransport):
def __init__(self, session: WebsocketClientSession, params: WebsocketClientParams):
def __init__(
self,
transport: BaseTransport,
session: WebsocketClientSession,
params: WebsocketClientParams,
):
super().__init__(params)
self._transport = transport
self._session = session
self._params = params
@@ -178,6 +194,10 @@ class WebsocketClientOutputTransport(BaseOutputTransport):
await super().cancel(frame)
await self._session.disconnect()
async def cleanup(self):
    """Clean up this output transport, then the transport held in ``self._transport``."""
    # Base-class cleanup runs first.
    await super().cleanup()
    # Then clean up the transport object this output was constructed with.
    await self._transport.cleanup()
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
await self._write_frame(frame)
@@ -250,12 +270,12 @@ class WebsocketClientTransport(BaseTransport):
def input(self) -> WebsocketClientInputTransport:
if not self._input:
self._input = WebsocketClientInputTransport(self._session, self._params)
self._input = WebsocketClientInputTransport(self, self._session, self._params)
return self._input
def output(self) -> WebsocketClientOutputTransport:
if not self._output:
self._output = WebsocketClientOutputTransport(self._session, self._params)
self._output = WebsocketClientOutputTransport(self, self._session, self._params)
return self._output
async def _on_connected(self, websocket):

View File

@@ -55,6 +55,7 @@ class WebsocketServerCallbacks(BaseModel):
class WebsocketServerInputTransport(BaseInputTransport):
def __init__(
self,
transport: BaseTransport,
host: str,
port: int,
params: WebsocketServerParams,
@@ -63,6 +64,7 @@ class WebsocketServerInputTransport(BaseInputTransport):
):
super().__init__(params, **kwargs)
self._transport = transport
self._host = host
self._port = port
self._params = params
@@ -102,6 +104,10 @@ class WebsocketServerInputTransport(BaseInputTransport):
await self.cancel_task(self._server_task)
self._server_task = None
async def cleanup(self):
    """Clean up this input transport, then the transport held in ``self._transport``."""
    # Base-class cleanup runs first.
    await super().cleanup()
    # Then clean up the transport object this input was constructed with.
    await self._transport.cleanup()
async def _server_task_handler(self):
logger.info(f"Starting websocket server on {self._host}:{self._port}")
async with websockets.serve(self._client_handler, self._host, self._port) as server:
@@ -163,9 +169,10 @@ class WebsocketServerInputTransport(BaseInputTransport):
class WebsocketServerOutputTransport(BaseOutputTransport):
def __init__(self, params: WebsocketServerParams, **kwargs):
def __init__(self, transport: BaseTransport, params: WebsocketServerParams, **kwargs):
super().__init__(params, **kwargs)
self._transport = transport
self._params = params
self._websocket: Optional[websockets.WebSocketServerProtocol] = None
@@ -189,6 +196,10 @@ class WebsocketServerOutputTransport(BaseOutputTransport):
await self._params.serializer.setup(frame)
self._send_interval = (self._audio_chunk_size / self.sample_rate) / 2
async def cleanup(self):
    """Clean up this output transport, then the transport held in ``self._transport``."""
    # Base-class cleanup runs first.
    await super().cleanup()
    # Then clean up the transport object this output was constructed with.
    await self._transport.cleanup()
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
@@ -283,13 +294,15 @@ class WebsocketServerTransport(BaseTransport):
def input(self) -> WebsocketServerInputTransport:
if not self._input:
self._input = WebsocketServerInputTransport(
self._host, self._port, self._params, self._callbacks, name=self._input_name
self, self._host, self._port, self._params, self._callbacks, name=self._input_name
)
return self._input
def output(self) -> WebsocketServerOutputTransport:
if not self._output:
self._output = WebsocketServerOutputTransport(self._params, name=self._output_name)
self._output = WebsocketServerOutputTransport(
self, self._params, name=self._output_name
)
return self._output
async def _on_client_connected(self, websocket):

View File

@@ -811,9 +811,16 @@ class DailyInputTransport(BaseInputTransport):
params: Configuration parameters.
"""
def __init__(self, client: DailyTransportClient, params: DailyParams, **kwargs):
def __init__(
self,
transport: BaseTransport,
client: DailyTransportClient,
params: DailyParams,
**kwargs,
):
super().__init__(params, **kwargs)
self._transport = transport
self._client = client
self._params = params
@@ -881,6 +888,7 @@ class DailyInputTransport(BaseInputTransport):
async def cleanup(self):
    """Clean up this input transport, its client, and the transport in ``self._transport``.

    The three cleanups are awaited one after another, in that order.
    """
    # Base-class cleanup runs first.
    await super().cleanup()
    # Next, the Daily client this input was constructed with.
    await self._client.cleanup()
    # Finally, the transport object this input was constructed with.
    await self._transport.cleanup()
#
# FrameProcessor
@@ -971,9 +979,12 @@ class DailyOutputTransport(BaseOutputTransport):
params: Configuration parameters.
"""
def __init__(self, client: DailyTransportClient, params: DailyParams, **kwargs):
def __init__(
self, transport: BaseTransport, client: DailyTransportClient, params: DailyParams, **kwargs
):
super().__init__(params, **kwargs)
self._transport = transport
self._client = client
# Whether we have seen a StartFrame already.
@@ -1008,6 +1019,7 @@ class DailyOutputTransport(BaseOutputTransport):
async def cleanup(self):
    """Clean up this output transport, its client, and the transport in ``self._transport``.

    The three cleanups are awaited one after another, in that order.
    """
    # Base-class cleanup runs first.
    await super().cleanup()
    # Next, the Daily client this output was constructed with.
    await self._client.cleanup()
    # Finally, the transport object this output was constructed with.
    await self._transport.cleanup()
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
await self._client.send_message(frame)
@@ -1109,12 +1121,16 @@ class DailyTransport(BaseTransport):
def input(self) -> DailyInputTransport:
if not self._input:
self._input = DailyInputTransport(self._client, self._params, name=self._input_name)
self._input = DailyInputTransport(
self, self._client, self._params, name=self._input_name
)
return self._input
def output(self) -> DailyOutputTransport:
if not self._output:
self._output = DailyOutputTransport(self._client, self._params, name=self._output_name)
self._output = DailyOutputTransport(
self, self._client, self._params, name=self._output_name
)
return self._output
#

View File

@@ -345,9 +345,17 @@ class LiveKitTransportClient:
class LiveKitInputTransport(BaseInputTransport):
def __init__(self, client: LiveKitTransportClient, params: LiveKitParams, **kwargs):
def __init__(
self,
transport: BaseTransport,
client: LiveKitTransportClient,
params: LiveKitParams,
**kwargs,
):
super().__init__(params, **kwargs)
self._transport = transport
self._client = client
self._audio_in_task = None
self._vad_analyzer: Optional[VADAnalyzer] = params.vad_analyzer
self._resampler = create_default_resampler()
@@ -377,6 +385,10 @@ class LiveKitInputTransport(BaseInputTransport):
if self._audio_in_task and (self._params.audio_in_enabled or self._params.vad_enabled):
await self.cancel_task(self._audio_in_task)
async def cleanup(self):
    """Clean up this input transport, then the transport held in ``self._transport``."""
    # Base-class cleanup runs first.
    await super().cleanup()
    # Then clean up the transport object this input was constructed with.
    await self._transport.cleanup()
async def push_app_message(self, message: Any, sender: str):
frame = LiveKitTransportMessageUrgentFrame(message=message, participant_id=sender)
await self.push_frame(frame)
@@ -414,8 +426,15 @@ class LiveKitInputTransport(BaseInputTransport):
class LiveKitOutputTransport(BaseOutputTransport):
def __init__(self, client: LiveKitTransportClient, params: LiveKitParams, **kwargs):
def __init__(
self,
transport: BaseTransport,
client: LiveKitTransportClient,
params: LiveKitParams,
**kwargs,
):
super().__init__(params, **kwargs)
self._transport = transport
self._client = client
async def start(self, frame: StartFrame):
@@ -433,6 +452,10 @@ class LiveKitOutputTransport(BaseOutputTransport):
await super().cancel(frame)
await self._client.disconnect()
async def cleanup(self):
    """Clean up this output transport, then the transport held in ``self._transport``."""
    # Base-class cleanup runs first.
    await super().cleanup()
    # Then clean up the transport object this output was constructed with.
    await self._transport.cleanup()
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
if isinstance(frame, (LiveKitTransportMessageFrame, LiveKitTransportMessageUrgentFrame)):
await self._client.send_data(frame.message.encode(), frame.participant_id)
@@ -499,13 +522,15 @@ class LiveKitTransport(BaseTransport):
def input(self) -> LiveKitInputTransport:
if not self._input:
self._input = LiveKitInputTransport(self._client, self._params, name=self._input_name)
self._input = LiveKitInputTransport(
self, self._client, self._params, name=self._input_name
)
return self._input
def output(self) -> LiveKitOutputTransport:
if not self._output:
self._output = LiveKitOutputTransport(
self._client, self._params, name=self._output_name
self, self._client, self._params, name=self._output_name
)
return self._output

View File

@@ -4,6 +4,7 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import inspect
from abc import ABC
from typing import Optional
@@ -17,8 +18,15 @@ class BaseObject(ABC):
def __init__(self, *, name: Optional[str] = None):
    """Initialize the object's id, name, and event-handler bookkeeping.

    Args:
        name: Optional name for this object. When omitted (or empty), a name
            of the form ``<ClassName>#<obj_count(self)>`` is generated.
    """
    self._id: int = obj_id()

    if name:
        self._name = name
    else:
        self._name = f"{self.__class__.__name__}#{obj_count(self)}"

    # Event handlers registered on this object.
    self._event_handlers: dict = {}

    # Event tasks currently tracked by this object. Entries are removed
    # automatically as each task finishes; cleanup waits for whatever is
    # still in the set.
    self._event_tasks = set()
@property
def id(self) -> int:
    """Return the value stored in ``self._id``."""
    return self._id
@@ -27,6 +35,12 @@ class BaseObject(ABC):
def name(self) -> str:
return self._name
async def cleanup(self):
    """Wait for all in-flight event handler tasks to finish.

    ``self._event_tasks`` holds ``(event_name, task)`` pairs. Every tracked
    task is awaited except the task that is currently running: if
    ``cleanup()`` is reached from inside an event handler, waiting on that
    handler's own task would never return.
    """
    current_task = asyncio.current_task()

    # Take a snapshot so the wait operates on a stable collection even if
    # entries are removed from the set while we are suspended.
    pending = [
        (event_name, task)
        for event_name, task in self._event_tasks
        if task is not current_task
    ]
    if not pending:
        return

    event_names = [event_name for event_name, _ in pending]
    logger.debug(f"{self} waiting on event handlers to finish {event_names}...")
    await asyncio.wait([task for _, task in pending])
def event_handler(self, event_name: str):
def decorator(handler):
self.add_event_handler(event_name, handler)
@@ -45,6 +59,16 @@ class BaseObject(ABC):
self._event_handlers[event_name] = []
async def _call_event_handler(self, event_name: str, *args, **kwargs):
    """Start a task that runs ``_run_task`` for ``event_name`` and track it.

    The task is not awaited here, so this returns once it has been created
    and registered.

    Args:
        event_name: Name of the event being dispatched.
        *args: Positional arguments forwarded to ``_run_task``.
        **kwargs: Keyword arguments forwarded to ``_run_task``.
    """
    handler_task = asyncio.create_task(self._run_task(event_name, *args, **kwargs))

    # When the task completes, drop it from the tracked set.
    handler_task.add_done_callback(self._event_task_finished)

    # Track the task together with the event name it belongs to.
    self._event_tasks.add((event_name, handler_task))
async def _run_task(self, event_name: str, *args, **kwargs):
try:
for handler in self._event_handlers[event_name]:
if inspect.iscoroutinefunction(handler):
@@ -54,5 +78,10 @@ class BaseObject(ABC):
except Exception as e:
logger.exception(f"Exception in event handler {event_name}: {e}")
def _event_task_finished(self, task: asyncio.Task):
    """Remove the tracked ``(event_name, task)`` entry for ``task``, if any.

    Args:
        task: The task whose entry should be dropped from ``self._event_tasks``.
    """
    for entry in self._event_tasks:
        if entry[1] == task:
            # Stop right after removing: the set must not be iterated
            # further once it has been modified.
            self._event_tasks.discard(entry)
            break
def __str__(self):
    """Return this object's ``name`` as its string form."""
    return self.name