BaseObject: allow synchronous event handlers
This commit is contained in:
@@ -9,6 +9,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- It is now possible to register synchronous event handlers. By default, all
|
||||
event handlers are executed in a separate task. However, in some cases we want
|
||||
to guarantee order of execution, for example, executing something before
|
||||
disconnecting a transport.
|
||||
|
||||
```python
|
||||
self._register_event_handler("on_event_name", sync=True)
|
||||
```
|
||||
|
||||
- Added support for global location in `GoogleVertexLLMService`. The service now
|
||||
supports both regional locations (e.g., "us-east4") and the "global" location
|
||||
for Vertex AI endpoints. When using "global" location, the service will use
|
||||
|
||||
@@ -14,13 +14,33 @@ and async cleanup for all Pipecat components.
|
||||
import asyncio
|
||||
import inspect
|
||||
from abc import ABC
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.utils.utils import obj_count, obj_id
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventHandler:
|
||||
"""Data class to store event handlers information.
|
||||
|
||||
This data class stores the event name, a list of handlers to run for this
|
||||
event, and whether these handlers will be executed in a task.
|
||||
|
||||
Attributes:
|
||||
name (str): The name of the event handler.
|
||||
handlers (List[Any]): A list of functions to be called when this event is triggered.
|
||||
is_sync (bool): Indicates whether the functions are executed in a task.
|
||||
|
||||
"""
|
||||
|
||||
name: str
|
||||
handlers: List[Any]
|
||||
is_sync: bool
|
||||
|
||||
|
||||
class BaseObject(ABC):
|
||||
"""Abstract base class providing common functionality for Pipecat objects.
|
||||
|
||||
@@ -41,7 +61,7 @@ class BaseObject(ABC):
|
||||
self._name = name or f"{self.__class__.__name__}#{obj_count(self)}"
|
||||
|
||||
# Registered event handlers.
|
||||
self._event_handlers: dict = {}
|
||||
self._event_handlers: Dict[str, EventHandler] = {}
|
||||
|
||||
# Set of tasks being executed. When a task finishes running it gets
|
||||
# automatically removed from the set. When we cleanup we wait for all
|
||||
@@ -103,18 +123,21 @@ class BaseObject(ABC):
|
||||
Can be sync or async.
|
||||
"""
|
||||
if event_name in self._event_handlers:
|
||||
self._event_handlers[event_name].append(handler)
|
||||
self._event_handlers[event_name].handlers.append(handler)
|
||||
else:
|
||||
logger.warning(f"Event handler {event_name} not registered")
|
||||
|
||||
def _register_event_handler(self, event_name: str):
|
||||
def _register_event_handler(self, event_name: str, sync: bool = False):
|
||||
"""Register an event handler type.
|
||||
|
||||
Args:
|
||||
event_name: The name of the event type to register.
|
||||
sync: Whether this event handler will be executed in a task.
|
||||
"""
|
||||
if event_name not in self._event_handlers:
|
||||
self._event_handlers[event_name] = []
|
||||
self._event_handlers[event_name] = EventHandler(
|
||||
name=event_name, handlers=[], is_sync=sync
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Event handler {event_name} not registered")
|
||||
|
||||
@@ -126,36 +149,40 @@ class BaseObject(ABC):
|
||||
*args: Positional arguments to pass to event handlers.
|
||||
**kwargs: Keyword arguments to pass to event handlers.
|
||||
"""
|
||||
# If we haven't registered an event handler, we don't need to do
|
||||
# anything.
|
||||
if not self._event_handlers.get(event_name):
|
||||
if event_name not in self._event_handlers:
|
||||
return
|
||||
|
||||
# Create the task.
|
||||
task = asyncio.create_task(self._run_task(event_name, *args, **kwargs))
|
||||
event_handler = self._event_handlers[event_name]
|
||||
|
||||
# Add it to our list of event tasks.
|
||||
self._event_tasks.add((event_name, task))
|
||||
if event_handler.is_sync:
|
||||
# Just run the handler.
|
||||
await self._run_handler(event_handler, *args, **kwargs)
|
||||
else:
|
||||
# Create the task.
|
||||
task = asyncio.create_task(self._run_handler(event_handler, *args, **kwargs))
|
||||
|
||||
# Remove the task from the event tasks list when the task completes.
|
||||
task.add_done_callback(self._event_task_finished)
|
||||
# Add it to our list of event tasks.
|
||||
self._event_tasks.add((event_name, task))
|
||||
|
||||
async def _run_task(self, event_name: str, *args, **kwargs):
|
||||
# Remove the task from the event tasks list when the task completes.
|
||||
task.add_done_callback(self._event_task_finished)
|
||||
|
||||
async def _run_handler(self, event_handler: EventHandler, *args, **kwargs):
|
||||
"""Execute all handlers for an event.
|
||||
|
||||
Args:
|
||||
event_name: The name of the event being handled.
|
||||
event_handler: The event handler to run.
|
||||
*args: Positional arguments to pass to handlers.
|
||||
**kwargs: Keyword arguments to pass to handlers.
|
||||
"""
|
||||
try:
|
||||
for handler in self._event_handlers[event_name]:
|
||||
for handler in event_handler.handlers:
|
||||
if inspect.iscoroutinefunction(handler):
|
||||
await handler(self, *args, **kwargs)
|
||||
else:
|
||||
handler(self, *args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception in event handler {event_name}: {e}")
|
||||
logger.exception(f"Exception in event handler {event_handler.name}: {e}")
|
||||
|
||||
def _event_task_finished(self, task: asyncio.Task):
|
||||
"""Clean up completed event handler tasks.
|
||||
|
||||
Reference in New Issue
Block a user