diff --git a/CHANGELOG.md b/CHANGELOG.md index ca5d55089..c413d16fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added a peer connection monitor to the `SmallWebRTCConnection` that + automatically disconnects if the connection fails to establish within + the timeout (1 minute by default). + - Added memory cleanup improvements to reduce memory peaks. - Added `on_before_process_frame`, `on_after_process_frame`, diff --git a/src/pipecat/transports/smallwebrtc/connection.py b/src/pipecat/transports/smallwebrtc/connection.py index decd8ca58..45b1cffb2 100644 --- a/src/pipecat/transports/smallwebrtc/connection.py +++ b/src/pipecat/transports/smallwebrtc/connection.py @@ -206,11 +206,16 @@ class SmallWebRTCConnection(BaseObject): for real-time audio/video communication. """ - def __init__(self, ice_servers: Optional[Union[List[str], List[IceServer]]] = None): + def __init__( + self, + ice_servers: Optional[Union[List[str], List[IceServer]]] = None, + connection_timeout_secs: int = 60, + ): """Initialize the WebRTC connection. Args: ice_servers: List of ICE servers as URLs or IceServer objects. + connection_timeout_secs: Timeout in seconds for connecting to the peer. Raises: TypeError: If ice_servers contains mixed types or unsupported types. @@ -231,6 +236,7 @@ class SmallWebRTCConnection(BaseObject): VIDEO_TRANSCEIVER_INDEX: self.video_input_track, SCREEN_VIDEO_TRANSCEIVER_INDEX: self.screen_video_input_track, } + self.connection_timeout_secs = connection_timeout_secs self._initialize() @@ -279,6 +285,7 @@ class SmallWebRTCConnection(BaseObject): self._last_received_time = None self._message_queue = [] self._pending_app_messages = [] + self._connecting_timeout_task = None def _setup_listeners(self): """Set up event listeners for the peer connection.""" @@ -499,6 +506,7 @@ class SmallWebRTCConnection(BaseObject): self._message_queue.clear() self._pending_app_messages.clear() self._track_map = {} + self._cancel_monitoring_connecting_state() def get_answer(self): """Get the SDP answer for the current connection. @@ -516,9 +524,45 @@ class SmallWebRTCConnection(BaseObject): "pc_id": self._pc_id, } + def _monitoring_connecting_state(self) -> None: + """Start monitoring the peer connection while it is in the *connecting* state. + + This method schedules a timeout task that will automatically close the + connection if it remains in the connecting state for more than the specified + timeout, default to 60 seconds. + """ + logger.debug("Monitoring connecting state") + + async def timeout_handler(): + # We will close the connection in case we have remained in the connecting state for over 1 minute + await asyncio.sleep(self.connection_timeout_secs) + logger.warning("Timeout establishing the connection to the remote peer. Closing.") + + await self._close() + + # Create and store the timeout task + self._connecting_timeout_task = asyncio.create_task(timeout_handler()) + + def _cancel_monitoring_connecting_state(self) -> None: + """Cancel the ongoing connecting-state timeout task, if any. + + This method should be called once the connection has either succeeded or + transitioned out of the connecting state. If the timeout task is still + pending, it will be canceled and the reference cleared. + """ + if self._connecting_timeout_task and not self._connecting_timeout_task.done(): + logger.debug("Cancelling the connecting timeout task") + self._connecting_timeout_task.cancel() + self._connecting_timeout_task = None + async def _handle_new_connection_state(self): """Handle changes in the peer connection state.""" state = self._pc.connectionState + if state == "connecting": + self._monitoring_connecting_state() + else: + self._cancel_monitoring_connecting_state() + if state == "connected" and not self._connect_invoked: # We are going to wait until the pipeline is ready before triggering the event return