diff --git a/CHANGELOG.md b/CHANGELOG.md index 74f5aa18c..f0ba80455 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- It is now possible to register synchronous event handlers. By default, all + event handlers are executed in a separate task. However, in some cases we want + to guarantee order of execution, for example, executing something before + disconnecting a transport. + + ```python + self._register_event_handler("on_event_name", sync=True) + ``` + - Added support for global location in `GoogleVertexLLMService`. The service now supports both regional locations (e.g., "us-east4") and the "global" location for Vertex AI endpoints. When using "global" location, the service will use diff --git a/src/pipecat/utils/base_object.py b/src/pipecat/utils/base_object.py index bf34515a6..b40911fec 100644 --- a/src/pipecat/utils/base_object.py +++ b/src/pipecat/utils/base_object.py @@ -14,13 +14,33 @@ and async cleanup for all Pipecat components. import asyncio import inspect from abc import ABC -from typing import Optional +from dataclasses import dataclass +from typing import Any, Dict, List, Optional from loguru import logger from pipecat.utils.utils import obj_count, obj_id +@dataclass +class EventHandler: + """Data class to store event handlers information. + + This data class stores the event name, a list of handlers to run for this + event, and whether these handlers will be executed in a task. + + Attributes: + name (str): The name of the event handler. + handlers (List[Any]): A list of functions to be called when this event is triggered. + is_sync (bool): Indicates whether the functions are executed in a task. + + """ + + name: str + handlers: List[Any] + is_sync: bool + + class BaseObject(ABC): """Abstract base class providing common functionality for Pipecat objects. @@ -41,7 +61,7 @@ class BaseObject(ABC): self._name = name or f"{self.__class__.__name__}#{obj_count(self)}" # Registered event handlers. - self._event_handlers: dict = {} + self._event_handlers: Dict[str, EventHandler] = {} # Set of tasks being executed. When a task finishes running it gets # automatically removed from the set. When we cleanup we wait for all @@ -103,18 +123,21 @@ class BaseObject(ABC): Can be sync or async. """ if event_name in self._event_handlers: - self._event_handlers[event_name].append(handler) + self._event_handlers[event_name].handlers.append(handler) else: logger.warning(f"Event handler {event_name} not registered") - def _register_event_handler(self, event_name: str): + def _register_event_handler(self, event_name: str, sync: bool = False): """Register an event handler type. Args: event_name: The name of the event type to register. + sync: Whether this event handler will be executed in a task. """ if event_name not in self._event_handlers: - self._event_handlers[event_name] = [] + self._event_handlers[event_name] = EventHandler( + name=event_name, handlers=[], is_sync=sync + ) else: logger.warning(f"Event handler {event_name} not registered") @@ -126,36 +149,40 @@ class BaseObject(ABC): *args: Positional arguments to pass to event handlers. **kwargs: Keyword arguments to pass to event handlers. """ - # If we haven't registered an event handler, we don't need to do - # anything. - if not self._event_handlers.get(event_name): + if event_name not in self._event_handlers: return - # Create the task. - task = asyncio.create_task(self._run_task(event_name, *args, **kwargs)) + event_handler = self._event_handlers[event_name] - # Add it to our list of event tasks. - self._event_tasks.add((event_name, task)) + if event_handler.is_sync: + # Just run the handler. + await self._run_handler(event_handler, *args, **kwargs) + else: + # Create the task. + task = asyncio.create_task(self._run_handler(event_handler, *args, **kwargs)) - # Remove the task from the event tasks list when the task completes. - task.add_done_callback(self._event_task_finished) + # Add it to our list of event tasks. + self._event_tasks.add((event_name, task)) - async def _run_task(self, event_name: str, *args, **kwargs): + # Remove the task from the event tasks list when the task completes. + task.add_done_callback(self._event_task_finished) + + async def _run_handler(self, event_handler: EventHandler, *args, **kwargs): """Execute all handlers for an event. Args: - event_name: The name of the event being handled. + event_handler: The event handler to run. *args: Positional arguments to pass to handlers. **kwargs: Keyword arguments to pass to handlers. """ try: - for handler in self._event_handlers[event_name]: + for handler in event_handler.handlers: if inspect.iscoroutinefunction(handler): await handler(self, *args, **kwargs) else: handler(self, *args, **kwargs) except Exception as e: - logger.exception(f"Exception in event handler {event_name}: {e}") + logger.exception(f"Exception in event handler {event_handler.name}: {e}") def _event_task_finished(self, task: asyncio.Task): """Clean up completed event handler tasks.