diff --git a/examples/foundational/14-function-calling-openai-responses.py b/examples/foundational/14-function-calling-openai-responses.py index afdb92aa8..171e7b36e 100644 --- a/examples/foundational/14-function-calling-openai-responses.py +++ b/examples/foundational/14-function-calling-openai-responses.py @@ -84,6 +84,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): llm.register_function("get_current_weather", fetch_weather_from_api) llm.register_function("get_restaurant_recommendation", fetch_restaurant_recommendation) + @llm.event_handler("on_connection_error") + async def on_connection_error(service, error): + logger.error(f"LLM connection error: {error}") + @llm.event_handler("on_function_calls_started") async def on_function_calls_started(service, function_calls): # Avoid appending this filler message to the LLM context — it would diff --git a/examples/foundational/55zi-update-settings-openai-responses-llm.py b/examples/foundational/55zi-update-settings-openai-responses-llm.py index 3ab5bacd0..8c214b639 100644 --- a/examples/foundational/55zi-update-settings-openai-responses-llm.py +++ b/examples/foundational/55zi-update-settings-openai-responses-llm.py @@ -102,9 +102,12 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): await task.queue_frames([LLMRunFrame()]) await asyncio.sleep(10) - logger.info("Updating OpenAI LLM settings: temperature=0.1") + logger.info("Updating OpenAI LLM settings: temperature=1") await task.queue_frame( - LLMUpdateSettingsFrame(delta=OpenAIResponsesLLMService.Settings(temperature=0.1)) + # Known OpenAI Python issue (as of 2026-03-31): setting temperature + # to non-integer value results in failure. + # https://github.com/openai/openai-python/issues/2919 + LLMUpdateSettingsFrame(delta=OpenAIResponsesLLMService.Settings(temperature=1)) ) @transport.event_handler("on_client_disconnected") diff --git a/src/pipecat/services/llm_service.py b/src/pipecat/services/llm_service.py index e7aaa7687..f2e247de9 100644 --- a/src/pipecat/services/llm_service.py +++ b/src/pipecat/services/llm_service.py @@ -8,6 +8,7 @@ import asyncio import inspect +import json import warnings from dataclasses import dataclass from typing import ( @@ -23,6 +24,8 @@ from typing import ( ) from loguru import logger +from websockets.exceptions import ConnectionClosed +from websockets.protocol import State from pipecat.adapters.base_llm_adapter import BaseLLMAdapter from pipecat.adapters.schemas.direct_function import DirectFunction, DirectFunctionWrapper @@ -30,6 +33,7 @@ from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter from pipecat.frames.frames import ( CancelFrame, EndFrame, + ErrorFrame, Frame, FunctionCallCancelFrame, FunctionCallFromLLM, @@ -59,6 +63,7 @@ from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext from pipecat.processors.frame_processor import FrameDirection from pipecat.services.ai_service import AIService from pipecat.services.settings import LLMSettings +from pipecat.services.websocket_service import WebsocketService from pipecat.turns.user_turn_completion_mixin import UserTurnCompletionLLMServiceMixin from pipecat.utils.context.llm_context_summarization import ( DEFAULT_SUMMARIZATION_TIMEOUT, @@ -872,3 +877,190 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): def _function_call_task_finished(self, task: asyncio.Task): if task in self._function_call_tasks: del self._function_call_tasks[task] + + +# --------------------------------------------------------------------------- +# WebSocket LLM service base +# --------------------------------------------------------------------------- + + +class WebsocketReconnectedError(Exception): + """Raised by ``_ws_send``/``_ws_recv`` after a transparent reconnection. + + Signals that the WebSocket connection was lost and automatically + re-established. The current inference should be restarted — any + connection-local state on the server (e.g. cached responses) is gone. + """ + + pass + + +class WebsocketLLMService(LLMService, WebsocketService): + """Base class for websocket-based LLM services. + + Each LLM inference is a discrete request/response exchange: send one + request, receive events inline until a terminal event, then wait for + the next frame to trigger an inference. This contrasts with + ``WebsocketTTSService`` / ``WebsocketSTTService`` which stream data + continuously via a background receive loop + (``_receive_task_handler``). This class does **not** start a + background receive loop. + + Provides connection lifecycle management (connect on start, disconnect + on stop/cancel), automatic reconnection with exponential backoff, and + three helpers for running each inference: + + 1. ``_ensure_connected()`` — verify the websocket is alive, reconnect + with exponential backoff if not. + 2. ``_ws_send(message)`` — send the inference request as JSON. + 3. ``_ws_recv()`` — receive and parse response events one at a time + until the caller sees a terminal event. + + ``_ws_send`` and ``_ws_recv`` catch ``ConnectionClosed`` transparently, + auto-reconnect via ``_try_reconnect``, and raise + ``WebsocketReconnectedError`` so callers know the inference must be + restarted. If reconnection fails, the original ``ConnectionClosed`` + propagates. + + Subclasses must implement: + ``_connect_websocket()``: Establish the websocket connection. + ``_disconnect_websocket()``: Close the websocket and clean up. + + Event handlers: + on_connection_error: Called when a websocket connection error occurs. + + Example:: + + @llm.event_handler("on_connection_error") + async def on_connection_error(llm: LLMService, error: str): + logger.error(f"LLM connection error: {error}") + """ + + def __init__(self, *, reconnect_on_error: bool = True, **kwargs): + """Initialize the Websocket LLM service. + + Args: + reconnect_on_error: Whether to automatically reconnect on websocket errors. + **kwargs: Additional arguments passed to parent classes. + """ + LLMService.__init__(self, **kwargs) + WebsocketService.__init__(self, reconnect_on_error=reconnect_on_error, **kwargs) + self._register_event_handler("on_connection_error") + + # -- lifecycle ------------------------------------------------------------ + + async def _connect(self): + """Connect: reset flags and establish the websocket.""" + await super()._connect() + await self._connect_websocket() + + async def _disconnect(self): + """Disconnect: set flags and close the websocket.""" + await super()._disconnect() + await self._disconnect_websocket() + + async def start(self, frame: StartFrame): + """Start the service and establish WebSocket connection. + + Args: + frame: The start frame triggering service initialization. + """ + await super().start(frame) + await self._connect() + + async def stop(self, frame: EndFrame): + """Stop the service and close WebSocket connection. + + Args: + frame: The end frame triggering service shutdown. + """ + await super().stop(frame) + await self._disconnect() + + async def cancel(self, frame: CancelFrame): + """Cancel the service and close WebSocket connection. + + Args: + frame: The cancel frame triggering service cancellation. + """ + await super().cancel(frame) + await self._disconnect() + + # -- per-inference helpers ------------------------------------------------ + + async def _ws_send(self, message: dict): + """Send a JSON message over the websocket. + + Guards against sends during intentional disconnect. If the send + fails with ``ConnectionClosed``, attempts to reconnect and raises + ``WebsocketReconnectedError`` on success so the caller can restart + the inference. If reconnection fails, the original + ``ConnectionClosed`` propagates. + + Args: + message: The message dict to serialize and send. + """ + if self._disconnecting or not self._websocket: + return + try: + await self._websocket.send(json.dumps(message)) + except ConnectionClosed: + if self._disconnecting: + return + success = await self._try_reconnect(report_error=self._report_error) + if success: + raise WebsocketReconnectedError() + raise + + async def _ws_recv(self) -> dict: + """Receive and parse a JSON message from the websocket. + + If the receive fails with ``ConnectionClosed``, attempts to + reconnect and raises ``WebsocketReconnectedError`` on success. + If reconnection fails, the original ``ConnectionClosed`` + propagates. + + Returns: + The parsed JSON message as a dict. + """ + try: + raw = await self._websocket.recv() + return json.loads(raw) + except ConnectionClosed: + if self._disconnecting: + raise + success = await self._try_reconnect(report_error=self._report_error) + if success: + raise WebsocketReconnectedError() + raise + + async def _ensure_connected(self): + """Ensure the websocket is connected, reconnecting if needed. + + Uses ``_try_reconnect`` with exponential backoff. + + Raises: + ConnectionError: If the connection could not be established. + """ + if self._websocket and self._websocket.state is not State.CLOSED: + return + success = await self._try_reconnect(report_error=self._report_error) + if not success: + raise ConnectionError(f"{self} failed to establish WebSocket connection") + + # -- WebsocketService interface ------------------------------------------- + + async def _receive_messages(self): + """Not used — messages are received inline during each inference. + + This satisfies the ``WebsocketService`` abstract method but is never + called because ``_receive_task_handler`` is never started. + """ + raise NotImplementedError( + "WebsocketLLMService receives messages inline during inference, " + "not via a continuous background loop" + ) + + async def _report_error(self, error: ErrorFrame): + await self._call_event_handler("on_connection_error", error.error) + await self.push_error_frame(error) diff --git a/src/pipecat/services/openai/responses/llm.py b/src/pipecat/services/openai/responses/llm.py index 02e7d2fc8..fce6b46d8 100644 --- a/src/pipecat/services/openai/responses/llm.py +++ b/src/pipecat/services/openai/responses/llm.py @@ -33,18 +33,20 @@ from pipecat.adapters.services.open_ai_responses_adapter import ( OpenAIResponsesLLMInvocationParams, ) from pipecat.frames.frames import ( - CancelFrame, - EndFrame, Frame, LLMContextFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, - StartFrame, ) from pipecat.metrics.metrics import LLMTokenUsage from pipecat.processors.aggregators.llm_context import LLMContext from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.llm_service import FunctionCallFromLLM, LLMService +from pipecat.services.llm_service import ( + FunctionCallFromLLM, + LLMService, + WebsocketLLMService, + WebsocketReconnectedError, +) from pipecat.services.settings import NOT_GIVEN as _NOT_GIVEN from pipecat.services.settings import LLMSettings, _NotGiven from pipecat.utils.tracing.service_decorators import traced_llm @@ -338,7 +340,7 @@ class _BaseOpenAIResponsesLLMService(LLMService): # --------------------------------------------------------------------------- -class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService): +class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService, WebsocketLLMService): """OpenAI Responses API LLM service using WebSocket transport. Maintains a persistent WebSocket connection to ``wss://api.openai.com/v1/responses`` @@ -384,8 +386,6 @@ class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService): super().__init__(**kwargs) self._ws_url = ws_url - self._websocket = None - self._disconnecting = False # State for previous_response_id optimization self._previous_response_id: Optional[str] = None @@ -398,40 +398,10 @@ class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService): self._cancel_pending_response: bool = False self._needs_drain: bool = False - # -- lifecycle ------------------------------------------------------------ + # -- WebsocketLLMService interface ---------------------------------------- - async def start(self, frame: StartFrame): - """Start the service and establish WebSocket connection. - - Args: - frame: The start frame triggering service initialization. - """ - await super().start(frame) - await self._connect() - - async def stop(self, frame: EndFrame): - """Stop the service and close WebSocket connection. - - Args: - frame: The end frame triggering service shutdown. - """ - await super().stop(frame) - await self._disconnect() - - async def cancel(self, frame: CancelFrame): - """Cancel the service and close WebSocket connection. - - Args: - frame: The cancel frame triggering service cancellation. - """ - await super().cancel(frame) - await self._disconnect() - - # -- connection management ------------------------------------------------ - - async def _connect(self): + async def _connect_websocket(self): """Establish the WebSocket connection.""" - self._disconnecting = False try: if self._websocket: return @@ -442,13 +412,12 @@ class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService): }, ) except Exception as e: - await self.push_error(error_msg=f"Error connecting to WebSocket: {e}", exception=e) self._websocket = None + await self.push_error(error_msg=f"Error connecting to WebSocket: {e}", exception=e) - async def _disconnect(self): + async def _disconnect_websocket(self): """Close the WebSocket connection and clear state.""" try: - self._disconnecting = True await self.stop_all_metrics() if self._websocket: await self._websocket.close() @@ -458,33 +427,6 @@ class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService): self._websocket = None self._clear_previous_response_state() self._clear_cancellation_state() - self._disconnecting = False - - async def _reconnect(self): - """Reconnect to the WebSocket, clearing previous_response_id state.""" - await self._disconnect() - await self._connect() - - async def _ensure_connected(self): - """Ensure a WebSocket connection is available, reconnecting if needed. - - Raises: - _RetryableError: If the connection could not be established. - """ - if self._websocket is None: - await self._connect() - if self._websocket is None: - raise _RetryableError("Failed to establish WebSocket connection") - - async def _ws_send(self, message: dict): - """Send a JSON message over the WebSocket. - - Args: - message: The message dict to serialize and send. - """ - if self._disconnecting or not self._websocket: - return - await self._websocket.send(json.dumps(message)) # -- previous_response_id optimization ------------------------------------ @@ -676,8 +618,14 @@ class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService): This method reads and discards events until a terminal event (``response.completed``, ``response.failed``, or ``response.incomplete``) arrives, ensuring the connection is clean. - Falls back to reconnecting if draining takes too long. + If draining times out or the connection drops, clears cancellation + state and returns — ``_ensure_connected`` will handle reconnection + before the next inference. """ + if not self._websocket: + self._clear_cancellation_state() + return + logger.debug(f"{self}: Draining cancelled response events") try: while True: @@ -711,9 +659,9 @@ class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService): ) self._clear_cancellation_state() return - except (asyncio.TimeoutError, ConnectionClosed) as e: - logger.warning(f"{self}: Error draining cancelled response: {e} — reconnecting") - await self._reconnect() + except (asyncio.TimeoutError, WebsocketReconnectedError, ConnectionClosed) as e: + logger.warning(f"{self}: Error draining cancelled response: {e}") + self._clear_cancellation_state() # -- frame processing ----------------------------------------------------- @@ -765,7 +713,7 @@ class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService): self._needs_drain = True raise except Exception as e: - await self.push_error(error_msg=f"Error during completion: {e}", exception=e) + await self.push_error(error_msg=f"Error during inference: {e}", exception=e) finally: await self.stop_processing_metrics() await self.push_frame(LLMFullResponseEndFrame()) @@ -776,6 +724,12 @@ class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService): async def _process_context(self, context: LLMContext): """Run inference over WebSocket with retry and previous_response_id. + Tries once with the ``previous_response_id`` optimization. On a + retriable error (cache miss, connection limit, connection drop), + clears state and retries once with the full context. Transport-level + ``ConnectionClosed`` errors are handled transparently by + ``_ws_send``/``_ws_recv`` (auto-reconnect → ``WebsocketReconnectedError``). + Args: context: The LLM context containing conversation history. """ @@ -796,60 +750,61 @@ class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService): full_input = invocation_params["input"] - max_attempts = 2 - for attempt in range(max_attempts): + def build_params(*, apply_optimization: bool) -> dict: params = self._build_response_params(invocation_params) - # WebSocket mode does not use the "stream" parameter + # WebSocket mode does not use the "stream" parameter. params.pop("stream", None) - - # Apply previous_response_id optimization (skipped after a retry) - if attempt == 0: + if apply_optimization: params = self._apply_previous_response_optimization(params, full_input) + return params - try: - await self._ensure_connected() - await self.start_ttfb_metrics() - await self._ws_send({"type": "response.create", **params}) - await self._receive_response_events(context, full_input) - return # Success - except _PreviousResponseNotFoundError: - logger.warning( - f"{self}: previous_response_not_found — " - f"retrying with full context ({len(full_input)} items)" - ) - self._clear_previous_response_state() - await self.stop_ttfb_metrics() - if attempt >= max_attempts - 1: - await self.push_error( - error_msg="previous_response_not_found: retry also failed" - ) - return - except _ConnectionLimitReachedError: - logger.warning( - f"{self}: WebSocket connection limit reached — " - f"reconnecting and retrying with full context ({len(full_input)} items)" - ) - self._clear_previous_response_state() - await self.stop_ttfb_metrics() - await self._reconnect() - if attempt >= max_attempts - 1: - await self.push_error(error_msg="WebSocket connection limit: retry also failed") - return - except ConnectionClosed as e: - logger.warning( - f"{self}: WebSocket connection closed during inference: {e} — " - f"reconnecting and retrying with full context ({len(full_input)} items)" - ) - self._clear_previous_response_state() - self._websocket = None - await self.stop_ttfb_metrics() - await self._reconnect() - if attempt >= max_attempts - 1: - await self.push_error( - error_msg=f"WebSocket connection closed: retry also failed: {e}", - exception=e, - ) - return + async def send_and_receive(params: dict): + await self._ensure_connected() + await self.start_ttfb_metrics() + await self._ws_send({"type": "response.create", **params}) + await self._receive_response_events(context, full_input) + + async def cleanup(): + self._clear_previous_response_state() + await self.stop_ttfb_metrics() + + # -- first attempt (with previous_response_id optimization) ----------- + + try: + await send_and_receive(build_params(apply_optimization=True)) + return # Success + except _PreviousResponseNotFoundError: + logger.warning( + f"{self}: previous_response_not_found — " + f"retrying with full context ({len(full_input)} items)" + ) + await cleanup() + except _ConnectionLimitReachedError: + logger.warning( + f"{self}: WebSocket connection limit reached — " + f"reconnecting and retrying with full context ({len(full_input)} items)" + ) + await cleanup() + await self._try_reconnect(report_error=self._report_error) + except WebsocketReconnectedError: + # ConnectionClosed was handled by the base class — connection is + # fresh, so any connection-local server state is gone. + logger.warning( + f"{self}: Connection lost and recovered — " + f"retrying with full context ({len(full_input)} items)" + ) + await cleanup() + except Exception: + await cleanup() + raise + + # -- retry with full context (no optimization) ------------------------ + + try: + await send_and_receive(build_params(apply_optimization=False)) + except Exception: + await cleanup() + raise async def _receive_response_events(self, context: LLMContext, full_input: list): """Receive and process WebSocket events until the response completes. @@ -861,14 +816,14 @@ class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService): Raises: _PreviousResponseNotFoundError: Server couldn't find previous response. _ConnectionLimitReachedError: 60-minute connection limit reached. - ConnectionClosed: WebSocket connection was closed unexpectedly. + WebsocketReconnectedError: Connection was lost and auto-recovered. + ConnectionClosed: Connection was lost and could not be recovered. """ function_calls: Dict[str, Dict[str, str]] = {} current_arguments: Dict[str, str] = {} while True: - raw = await self._websocket.recv() - event = json.loads(raw) + event = await self._ws_recv() event_type = event.get("type") if event_type == "response.created": @@ -1020,7 +975,7 @@ class OpenAIResponsesHttpLLMService(_BaseOpenAIResponsesLLMService): await self._call_event_handler("on_completion_timeout") await self.push_error(error_msg="LLM completion timeout", exception=e) except Exception as e: - await self.push_error(error_msg=f"Error during completion: {e}", exception=e) + await self.push_error(error_msg=f"Error during inference: {e}", exception=e) finally: await self.stop_processing_metrics() await self.push_frame(LLMFullResponseEndFrame()) diff --git a/tests/test_openai_responses_websocket.py b/tests/test_openai_responses_websocket.py index 8623da0ba..fb3e41cc0 100644 --- a/tests/test_openai_responses_websocket.py +++ b/tests/test_openai_responses_websocket.py @@ -628,29 +628,20 @@ class TestDrainCancelledResponse: assert len(cancel_calls) == 1 @pytest.mark.asyncio - async def test_drain_timeout_triggers_reconnect(self): - """If draining takes too long, should fall back to reconnecting.""" + async def test_drain_timeout_clears_state(self): + """If draining times out, should clear cancellation state.""" service = _make_service() service._needs_drain = True - service.stop_all_metrics = AsyncMock() - service.push_error = AsyncMock() mock_ws = AsyncMock() # recv() never returns a terminal event — times out mock_ws.recv = AsyncMock(side_effect=asyncio.TimeoutError) - mock_ws.close = AsyncMock() service._websocket = mock_ws - with patch( - "pipecat.services.openai.responses.llm.websocket_connect", - new_callable=AsyncMock, - return_value=AsyncMock(), - ): - await service._drain_cancelled_response() + await service._drain_cancelled_response() assert not service._needs_drain - # Should have reconnected (old ws closed) - mock_ws.close.assert_called_once() + assert not service._cancel_pending_response # --------------------------------------------------------------------------- @@ -688,7 +679,9 @@ class TestConnectionLifecycle: new_callable=AsyncMock, return_value=AsyncMock(), ): - await service._reconnect() + # _disconnect + _connect is the lifecycle equivalent of the old _reconnect + await service._disconnect() + await service._connect() assert service._previous_response_id is None mock_ws.close.assert_called_once() @@ -723,17 +716,10 @@ class TestConnectionLifecycle: @pytest.mark.asyncio async def test_ensure_connected_raises_on_failure(self): - from pipecat.services.openai.responses.llm import _RetryableError - service = _make_service() service._websocket = None - service.push_error = AsyncMock() + # Mock _try_reconnect to simulate exhausted retries + service._try_reconnect = AsyncMock(return_value=False) - # Mock connect to fail - with patch( - "pipecat.services.openai.responses.llm.websocket_connect", - new_callable=AsyncMock, - side_effect=Exception("Connection refused"), - ): - with pytest.raises(_RetryableError): - await service._ensure_connected() + with pytest.raises(ConnectionError): + await service._ensure_connected()