Merge pull request #1376 from pipecat-ai/aleix/event-handlers-as-tasks
event handlers are now executed in separate tasks
This commit is contained in:
@@ -99,6 +99,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Changed
|
||||
|
||||
- All event handlers are now executed in separate tasks in order to prevent
|
||||
blocking the pipeline. It is possible that event handlers take some time to
|
||||
execute in which case the pipeline would be blocked waiting for the event
|
||||
handler to complete.
|
||||
|
||||
- Updated `TranscriptProcessor` to support text output from
|
||||
`OpenAIRealtimeBetaLLMService`.
|
||||
|
||||
|
||||
@@ -40,12 +40,18 @@ class PipelineRunner(BaseObject):
|
||||
task.set_event_loop(self._loop)
|
||||
await task.run()
|
||||
del self._tasks[task.name]
|
||||
|
||||
# Cleanup base object.
|
||||
await self.cleanup()
|
||||
|
||||
# If we are cancelling through a signal, make sure we wait for it so
|
||||
# everything gets cleaned up nicely.
|
||||
if self._sig_task:
|
||||
await self._sig_task
|
||||
|
||||
if self._force_gc:
|
||||
self._gc_collect()
|
||||
|
||||
logger.debug(f"Runner {self} finished running {task}")
|
||||
|
||||
async def stop_when_done(self):
|
||||
|
||||
@@ -354,6 +354,10 @@ class PipelineTask(BaseTask):
|
||||
self._pipeline_end_event.clear()
|
||||
|
||||
async def _cleanup(self, cleanup_pipeline: bool):
|
||||
# Cleanup base object.
|
||||
await self.cleanup()
|
||||
|
||||
# Cleanup pipeline processors.
|
||||
await self._source.cleanup()
|
||||
if cleanup_pipeline:
|
||||
await self._pipeline.cleanup()
|
||||
|
||||
@@ -164,6 +164,7 @@ class FrameProcessor(BaseObject):
|
||||
await self._task_manager.wait_for_task(task, timeout)
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
await self.__cancel_input_task()
|
||||
await self.__cancel_push_task()
|
||||
|
||||
|
||||
@@ -102,11 +102,13 @@ class FastAPIWebsocketClient:
|
||||
class FastAPIWebsocketInputTransport(BaseInputTransport):
|
||||
def __init__(
|
||||
self,
|
||||
transport: BaseTransport,
|
||||
client: FastAPIWebsocketClient,
|
||||
params: FastAPIWebsocketParams,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(params, **kwargs)
|
||||
self._transport = transport
|
||||
self._client = client
|
||||
self._params = params
|
||||
self._receive_task = None
|
||||
@@ -139,6 +141,10 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
|
||||
await self._stop_tasks()
|
||||
await self._client.disconnect()
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
await self._transport.cleanup()
|
||||
|
||||
async def _receive_messages(self):
|
||||
try:
|
||||
async for message in self._client.receive():
|
||||
@@ -165,11 +171,14 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
|
||||
class FastAPIWebsocketOutputTransport(BaseOutputTransport):
|
||||
def __init__(
|
||||
self,
|
||||
transport: BaseTransport,
|
||||
client: FastAPIWebsocketClient,
|
||||
params: FastAPIWebsocketParams,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(params, **kwargs)
|
||||
|
||||
self._transport = transport
|
||||
self._client = client
|
||||
self._params = params
|
||||
|
||||
@@ -194,6 +203,10 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
|
||||
await super().cancel(frame)
|
||||
await self._client.disconnect()
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
await self._transport.cleanup()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@@ -266,6 +279,7 @@ class FastAPIWebsocketTransport(BaseTransport):
|
||||
output_name: Optional[str] = None,
|
||||
):
|
||||
super().__init__(input_name=input_name, output_name=output_name)
|
||||
|
||||
self._params = params
|
||||
|
||||
self._callbacks = FastAPIWebsocketCallbacks(
|
||||
@@ -278,10 +292,10 @@ class FastAPIWebsocketTransport(BaseTransport):
|
||||
self._client = FastAPIWebsocketClient(websocket, is_binary, self._callbacks)
|
||||
|
||||
self._input = FastAPIWebsocketInputTransport(
|
||||
self._client, self._params, name=self._input_name
|
||||
self, self._client, self._params, name=self._input_name
|
||||
)
|
||||
self._output = FastAPIWebsocketOutputTransport(
|
||||
self._client, self._params, name=self._output_name
|
||||
self, self._client, self._params, name=self._output_name
|
||||
)
|
||||
|
||||
# Register supported handlers. The user will only be able to register
|
||||
|
||||
@@ -118,9 +118,15 @@ class WebsocketClientSession:
|
||||
|
||||
|
||||
class WebsocketClientInputTransport(BaseInputTransport):
|
||||
def __init__(self, session: WebsocketClientSession, params: WebsocketClientParams):
|
||||
def __init__(
|
||||
self,
|
||||
transport: BaseTransport,
|
||||
session: WebsocketClientSession,
|
||||
params: WebsocketClientParams,
|
||||
):
|
||||
super().__init__(params)
|
||||
|
||||
self._transport = transport
|
||||
self._session = session
|
||||
self._params = params
|
||||
|
||||
@@ -138,6 +144,10 @@ class WebsocketClientInputTransport(BaseInputTransport):
|
||||
await super().cancel(frame)
|
||||
await self._session.disconnect()
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
await self._transport.cleanup()
|
||||
|
||||
async def on_message(self, websocket, message):
|
||||
frame = await self._params.serializer.deserialize(message)
|
||||
if not frame:
|
||||
@@ -149,9 +159,15 @@ class WebsocketClientInputTransport(BaseInputTransport):
|
||||
|
||||
|
||||
class WebsocketClientOutputTransport(BaseOutputTransport):
|
||||
def __init__(self, session: WebsocketClientSession, params: WebsocketClientParams):
|
||||
def __init__(
|
||||
self,
|
||||
transport: BaseTransport,
|
||||
session: WebsocketClientSession,
|
||||
params: WebsocketClientParams,
|
||||
):
|
||||
super().__init__(params)
|
||||
|
||||
self._transport = transport
|
||||
self._session = session
|
||||
self._params = params
|
||||
|
||||
@@ -178,6 +194,10 @@ class WebsocketClientOutputTransport(BaseOutputTransport):
|
||||
await super().cancel(frame)
|
||||
await self._session.disconnect()
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
await self._transport.cleanup()
|
||||
|
||||
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
|
||||
await self._write_frame(frame)
|
||||
|
||||
@@ -250,12 +270,12 @@ class WebsocketClientTransport(BaseTransport):
|
||||
|
||||
def input(self) -> WebsocketClientInputTransport:
|
||||
if not self._input:
|
||||
self._input = WebsocketClientInputTransport(self._session, self._params)
|
||||
self._input = WebsocketClientInputTransport(self, self._session, self._params)
|
||||
return self._input
|
||||
|
||||
def output(self) -> WebsocketClientOutputTransport:
|
||||
if not self._output:
|
||||
self._output = WebsocketClientOutputTransport(self._session, self._params)
|
||||
self._output = WebsocketClientOutputTransport(self, self._session, self._params)
|
||||
return self._output
|
||||
|
||||
async def _on_connected(self, websocket):
|
||||
|
||||
@@ -55,6 +55,7 @@ class WebsocketServerCallbacks(BaseModel):
|
||||
class WebsocketServerInputTransport(BaseInputTransport):
|
||||
def __init__(
|
||||
self,
|
||||
transport: BaseTransport,
|
||||
host: str,
|
||||
port: int,
|
||||
params: WebsocketServerParams,
|
||||
@@ -63,6 +64,7 @@ class WebsocketServerInputTransport(BaseInputTransport):
|
||||
):
|
||||
super().__init__(params, **kwargs)
|
||||
|
||||
self._transport = transport
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._params = params
|
||||
@@ -102,6 +104,10 @@ class WebsocketServerInputTransport(BaseInputTransport):
|
||||
await self.cancel_task(self._server_task)
|
||||
self._server_task = None
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
await self._transport.cleanup()
|
||||
|
||||
async def _server_task_handler(self):
|
||||
logger.info(f"Starting websocket server on {self._host}:{self._port}")
|
||||
async with websockets.serve(self._client_handler, self._host, self._port) as server:
|
||||
@@ -163,9 +169,10 @@ class WebsocketServerInputTransport(BaseInputTransport):
|
||||
|
||||
|
||||
class WebsocketServerOutputTransport(BaseOutputTransport):
|
||||
def __init__(self, params: WebsocketServerParams, **kwargs):
|
||||
def __init__(self, transport: BaseTransport, params: WebsocketServerParams, **kwargs):
|
||||
super().__init__(params, **kwargs)
|
||||
|
||||
self._transport = transport
|
||||
self._params = params
|
||||
|
||||
self._websocket: Optional[websockets.WebSocketServerProtocol] = None
|
||||
@@ -189,6 +196,10 @@ class WebsocketServerOutputTransport(BaseOutputTransport):
|
||||
await self._params.serializer.setup(frame)
|
||||
self._send_interval = (self._audio_chunk_size / self.sample_rate) / 2
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
await self._transport.cleanup()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@@ -283,13 +294,15 @@ class WebsocketServerTransport(BaseTransport):
|
||||
def input(self) -> WebsocketServerInputTransport:
|
||||
if not self._input:
|
||||
self._input = WebsocketServerInputTransport(
|
||||
self._host, self._port, self._params, self._callbacks, name=self._input_name
|
||||
self, self._host, self._port, self._params, self._callbacks, name=self._input_name
|
||||
)
|
||||
return self._input
|
||||
|
||||
def output(self) -> WebsocketServerOutputTransport:
|
||||
if not self._output:
|
||||
self._output = WebsocketServerOutputTransport(self._params, name=self._output_name)
|
||||
self._output = WebsocketServerOutputTransport(
|
||||
self, self._params, name=self._output_name
|
||||
)
|
||||
return self._output
|
||||
|
||||
async def _on_client_connected(self, websocket):
|
||||
|
||||
@@ -811,9 +811,16 @@ class DailyInputTransport(BaseInputTransport):
|
||||
params: Configuration parameters.
|
||||
"""
|
||||
|
||||
def __init__(self, client: DailyTransportClient, params: DailyParams, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
transport: BaseTransport,
|
||||
client: DailyTransportClient,
|
||||
params: DailyParams,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(params, **kwargs)
|
||||
|
||||
self._transport = transport
|
||||
self._client = client
|
||||
self._params = params
|
||||
|
||||
@@ -881,6 +888,7 @@ class DailyInputTransport(BaseInputTransport):
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
await self._client.cleanup()
|
||||
await self._transport.cleanup()
|
||||
|
||||
#
|
||||
# FrameProcessor
|
||||
@@ -971,9 +979,12 @@ class DailyOutputTransport(BaseOutputTransport):
|
||||
params: Configuration parameters.
|
||||
"""
|
||||
|
||||
def __init__(self, client: DailyTransportClient, params: DailyParams, **kwargs):
|
||||
def __init__(
|
||||
self, transport: BaseTransport, client: DailyTransportClient, params: DailyParams, **kwargs
|
||||
):
|
||||
super().__init__(params, **kwargs)
|
||||
|
||||
self._transport = transport
|
||||
self._client = client
|
||||
|
||||
# Whether we have seen a StartFrame already.
|
||||
@@ -1008,6 +1019,7 @@ class DailyOutputTransport(BaseOutputTransport):
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
await self._client.cleanup()
|
||||
await self._transport.cleanup()
|
||||
|
||||
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
|
||||
await self._client.send_message(frame)
|
||||
@@ -1109,12 +1121,16 @@ class DailyTransport(BaseTransport):
|
||||
|
||||
def input(self) -> DailyInputTransport:
|
||||
if not self._input:
|
||||
self._input = DailyInputTransport(self._client, self._params, name=self._input_name)
|
||||
self._input = DailyInputTransport(
|
||||
self, self._client, self._params, name=self._input_name
|
||||
)
|
||||
return self._input
|
||||
|
||||
def output(self) -> DailyOutputTransport:
|
||||
if not self._output:
|
||||
self._output = DailyOutputTransport(self._client, self._params, name=self._output_name)
|
||||
self._output = DailyOutputTransport(
|
||||
self, self._client, self._params, name=self._output_name
|
||||
)
|
||||
return self._output
|
||||
|
||||
#
|
||||
|
||||
@@ -345,9 +345,17 @@ class LiveKitTransportClient:
|
||||
|
||||
|
||||
class LiveKitInputTransport(BaseInputTransport):
|
||||
def __init__(self, client: LiveKitTransportClient, params: LiveKitParams, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
transport: BaseTransport,
|
||||
client: LiveKitTransportClient,
|
||||
params: LiveKitParams,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(params, **kwargs)
|
||||
self._transport = transport
|
||||
self._client = client
|
||||
|
||||
self._audio_in_task = None
|
||||
self._vad_analyzer: Optional[VADAnalyzer] = params.vad_analyzer
|
||||
self._resampler = create_default_resampler()
|
||||
@@ -377,6 +385,10 @@ class LiveKitInputTransport(BaseInputTransport):
|
||||
if self._audio_in_task and (self._params.audio_in_enabled or self._params.vad_enabled):
|
||||
await self.cancel_task(self._audio_in_task)
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
await self._transport.cleanup()
|
||||
|
||||
async def push_app_message(self, message: Any, sender: str):
|
||||
frame = LiveKitTransportMessageUrgentFrame(message=message, participant_id=sender)
|
||||
await self.push_frame(frame)
|
||||
@@ -414,8 +426,15 @@ class LiveKitInputTransport(BaseInputTransport):
|
||||
|
||||
|
||||
class LiveKitOutputTransport(BaseOutputTransport):
|
||||
def __init__(self, client: LiveKitTransportClient, params: LiveKitParams, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
transport: BaseTransport,
|
||||
client: LiveKitTransportClient,
|
||||
params: LiveKitParams,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(params, **kwargs)
|
||||
self._transport = transport
|
||||
self._client = client
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
@@ -433,6 +452,10 @@ class LiveKitOutputTransport(BaseOutputTransport):
|
||||
await super().cancel(frame)
|
||||
await self._client.disconnect()
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
await self._transport.cleanup()
|
||||
|
||||
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
|
||||
if isinstance(frame, (LiveKitTransportMessageFrame, LiveKitTransportMessageUrgentFrame)):
|
||||
await self._client.send_data(frame.message.encode(), frame.participant_id)
|
||||
@@ -499,13 +522,15 @@ class LiveKitTransport(BaseTransport):
|
||||
|
||||
def input(self) -> LiveKitInputTransport:
|
||||
if not self._input:
|
||||
self._input = LiveKitInputTransport(self._client, self._params, name=self._input_name)
|
||||
self._input = LiveKitInputTransport(
|
||||
self, self._client, self._params, name=self._input_name
|
||||
)
|
||||
return self._input
|
||||
|
||||
def output(self) -> LiveKitOutputTransport:
|
||||
if not self._output:
|
||||
self._output = LiveKitOutputTransport(
|
||||
self._client, self._params, name=self._output_name
|
||||
self, self._client, self._params, name=self._output_name
|
||||
)
|
||||
return self._output
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from abc import ABC
|
||||
from typing import Optional
|
||||
@@ -17,8 +18,15 @@ class BaseObject(ABC):
|
||||
def __init__(self, *, name: Optional[str] = None):
|
||||
self._id: int = obj_id()
|
||||
self._name = name or f"{self.__class__.__name__}#{obj_count(self)}"
|
||||
|
||||
# Registered event handlers.
|
||||
self._event_handlers: dict = {}
|
||||
|
||||
# Set of tasks being executed. When a task finishes running it gets
|
||||
# automatically removed from the set. When we cleanup we wait for all
|
||||
# event tasks still being executed.
|
||||
self._event_tasks = set()
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._id
|
||||
@@ -27,6 +35,12 @@ class BaseObject(ABC):
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
async def cleanup(self):
|
||||
if self._event_tasks:
|
||||
event_names, tasks = zip(*self._event_tasks)
|
||||
logger.debug(f"{self} wating on event handlers to finish {list(event_names)}...")
|
||||
await asyncio.wait(tasks)
|
||||
|
||||
def event_handler(self, event_name: str):
|
||||
def decorator(handler):
|
||||
self.add_event_handler(event_name, handler)
|
||||
@@ -45,6 +59,16 @@ class BaseObject(ABC):
|
||||
self._event_handlers[event_name] = []
|
||||
|
||||
async def _call_event_handler(self, event_name: str, *args, **kwargs):
|
||||
# Create the task.
|
||||
task = asyncio.create_task(self._run_task(event_name, *args, **kwargs))
|
||||
|
||||
# Add it to our list of event tasks.
|
||||
self._event_tasks.add((event_name, task))
|
||||
|
||||
# Remove the task from the event tasks list when the task completes.
|
||||
task.add_done_callback(self._event_task_finished)
|
||||
|
||||
async def _run_task(self, event_name: str, *args, **kwargs):
|
||||
try:
|
||||
for handler in self._event_handlers[event_name]:
|
||||
if inspect.iscoroutinefunction(handler):
|
||||
@@ -54,5 +78,10 @@ class BaseObject(ABC):
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception in event handler {event_name}: {e}")
|
||||
|
||||
def _event_task_finished(self, task: asyncio.Task):
|
||||
tuple_to_remove = next((t for t in self._event_tasks if t[1] == task), None)
|
||||
if tuple_to_remove:
|
||||
self._event_tasks.discard(tuple_to_remove)
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
Reference in New Issue
Block a user