Merge pull request #4211 from pipecat-ai/pk/openai-responses-websocket-service-refactor

Introduce WebsocketLLMService and refactor OpenAIResponsesLLMService …
This commit is contained in:
kompfner
2026-03-31 13:02:45 -04:00
committed by GitHub
5 changed files with 295 additions and 155 deletions

View File

@@ -84,6 +84,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
llm.register_function("get_current_weather", fetch_weather_from_api)
llm.register_function("get_restaurant_recommendation", fetch_restaurant_recommendation)
@llm.event_handler("on_connection_error")
async def on_connection_error(service, error):
logger.error(f"LLM connection error: {error}")
@llm.event_handler("on_function_calls_started")
async def on_function_calls_started(service, function_calls):
# Avoid appending this filler message to the LLM context — it would

View File

@@ -102,9 +102,12 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
await task.queue_frames([LLMRunFrame()])
await asyncio.sleep(10)
logger.info("Updating OpenAI LLM settings: temperature=0.1")
logger.info("Updating OpenAI LLM settings: temperature=1")
await task.queue_frame(
LLMUpdateSettingsFrame(delta=OpenAIResponsesLLMService.Settings(temperature=0.1))
# Known OpenAI Python issue (as of 2026-03-31): setting temperature
# to non-integer value results in failure.
# https://github.com/openai/openai-python/issues/2919
LLMUpdateSettingsFrame(delta=OpenAIResponsesLLMService.Settings(temperature=1))
)
@transport.event_handler("on_client_disconnected")

View File

@@ -8,6 +8,7 @@
import asyncio
import inspect
import json
import warnings
from dataclasses import dataclass
from typing import (
@@ -23,6 +24,8 @@ from typing import (
)
from loguru import logger
from websockets.exceptions import ConnectionClosed
from websockets.protocol import State
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
from pipecat.adapters.schemas.direct_function import DirectFunction, DirectFunctionWrapper
@@ -30,6 +33,7 @@ from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
FunctionCallCancelFrame,
FunctionCallFromLLM,
@@ -59,6 +63,7 @@ from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_service import AIService
from pipecat.services.settings import LLMSettings
from pipecat.services.websocket_service import WebsocketService
from pipecat.turns.user_turn_completion_mixin import UserTurnCompletionLLMServiceMixin
from pipecat.utils.context.llm_context_summarization import (
DEFAULT_SUMMARIZATION_TIMEOUT,
@@ -872,3 +877,190 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
def _function_call_task_finished(self, task: asyncio.Task):
if task in self._function_call_tasks:
del self._function_call_tasks[task]
# ---------------------------------------------------------------------------
# WebSocket LLM service base
# ---------------------------------------------------------------------------
class WebsocketReconnectedError(Exception):
"""Raised by ``_ws_send``/``_ws_recv`` after a transparent reconnection.
Signals that the WebSocket connection was lost and automatically
re-established. The current inference should be restarted — any
connection-local state on the server (e.g. cached responses) is gone.
"""
pass
class WebsocketLLMService(LLMService, WebsocketService):
"""Base class for websocket-based LLM services.
Each LLM inference is a discrete request/response exchange: send one
request, receive events inline until a terminal event, then wait for
the next frame to trigger an inference. This contrasts with
``WebsocketTTSService`` / ``WebsocketSTTService`` which stream data
continuously via a background receive loop
(``_receive_task_handler``). This class does **not** start a
background receive loop.
Provides connection lifecycle management (connect on start, disconnect
on stop/cancel), automatic reconnection with exponential backoff, and
three helpers for running each inference:
1. ``_ensure_connected()`` — verify the websocket is alive, reconnect
with exponential backoff if not.
2. ``_ws_send(message)`` — send the inference request as JSON.
3. ``_ws_recv()`` — receive and parse response events one at a time
until the caller sees a terminal event.
``_ws_send`` and ``_ws_recv`` catch ``ConnectionClosed`` transparently,
auto-reconnect via ``_try_reconnect``, and raise
``WebsocketReconnectedError`` so callers know the inference must be
restarted. If reconnection fails, the original ``ConnectionClosed``
propagates.
Subclasses must implement:
``_connect_websocket()``: Establish the websocket connection.
``_disconnect_websocket()``: Close the websocket and clean up.
Event handlers:
on_connection_error: Called when a websocket connection error occurs.
Example::
@llm.event_handler("on_connection_error")
async def on_connection_error(llm: LLMService, error: str):
logger.error(f"LLM connection error: {error}")
"""
def __init__(self, *, reconnect_on_error: bool = True, **kwargs):
"""Initialize the Websocket LLM service.
Args:
reconnect_on_error: Whether to automatically reconnect on websocket errors.
**kwargs: Additional arguments passed to parent classes.
"""
LLMService.__init__(self, **kwargs)
WebsocketService.__init__(self, reconnect_on_error=reconnect_on_error, **kwargs)
self._register_event_handler("on_connection_error")
# -- lifecycle ------------------------------------------------------------
async def _connect(self):
"""Connect: reset flags and establish the websocket."""
await super()._connect()
await self._connect_websocket()
async def _disconnect(self):
"""Disconnect: set flags and close the websocket."""
await super()._disconnect()
await self._disconnect_websocket()
async def start(self, frame: StartFrame):
"""Start the service and establish WebSocket connection.
Args:
frame: The start frame triggering service initialization.
"""
await super().start(frame)
await self._connect()
async def stop(self, frame: EndFrame):
"""Stop the service and close WebSocket connection.
Args:
frame: The end frame triggering service shutdown.
"""
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
"""Cancel the service and close WebSocket connection.
Args:
frame: The cancel frame triggering service cancellation.
"""
await super().cancel(frame)
await self._disconnect()
# -- per-inference helpers ------------------------------------------------
async def _ws_send(self, message: dict):
"""Send a JSON message over the websocket.
Guards against sends during intentional disconnect. If the send
fails with ``ConnectionClosed``, attempts to reconnect and raises
``WebsocketReconnectedError`` on success so the caller can restart
the inference. If reconnection fails, the original
``ConnectionClosed`` propagates.
Args:
message: The message dict to serialize and send.
"""
if self._disconnecting or not self._websocket:
return
try:
await self._websocket.send(json.dumps(message))
except ConnectionClosed:
if self._disconnecting:
return
success = await self._try_reconnect(report_error=self._report_error)
if success:
raise WebsocketReconnectedError()
raise
async def _ws_recv(self) -> dict:
"""Receive and parse a JSON message from the websocket.
If the receive fails with ``ConnectionClosed``, attempts to
reconnect and raises ``WebsocketReconnectedError`` on success.
If reconnection fails, the original ``ConnectionClosed``
propagates.
Returns:
The parsed JSON message as a dict.
"""
try:
raw = await self._websocket.recv()
return json.loads(raw)
except ConnectionClosed:
if self._disconnecting:
raise
success = await self._try_reconnect(report_error=self._report_error)
if success:
raise WebsocketReconnectedError()
raise
async def _ensure_connected(self):
"""Ensure the websocket is connected, reconnecting if needed.
Uses ``_try_reconnect`` with exponential backoff.
Raises:
ConnectionError: If the connection could not be established.
"""
if self._websocket and self._websocket.state is not State.CLOSED:
return
success = await self._try_reconnect(report_error=self._report_error)
if not success:
raise ConnectionError(f"{self} failed to establish WebSocket connection")
# -- WebsocketService interface -------------------------------------------
async def _receive_messages(self):
"""Not used — messages are received inline during each inference.
This satisfies the ``WebsocketService`` abstract method but is never
called because ``_receive_task_handler`` is never started.
"""
raise NotImplementedError(
"WebsocketLLMService receives messages inline during inference, "
"not via a continuous background loop"
)
async def _report_error(self, error: ErrorFrame):
await self._call_event_handler("on_connection_error", error.error)
await self.push_error_frame(error)

View File

@@ -33,18 +33,20 @@ from pipecat.adapters.services.open_ai_responses_adapter import (
OpenAIResponsesLLMInvocationParams,
)
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
Frame,
LLMContextFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
StartFrame,
)
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
from pipecat.services.llm_service import (
FunctionCallFromLLM,
LLMService,
WebsocketLLMService,
WebsocketReconnectedError,
)
from pipecat.services.settings import NOT_GIVEN as _NOT_GIVEN
from pipecat.services.settings import LLMSettings, _NotGiven
from pipecat.utils.tracing.service_decorators import traced_llm
@@ -338,7 +340,7 @@ class _BaseOpenAIResponsesLLMService(LLMService):
# ---------------------------------------------------------------------------
class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService):
class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService, WebsocketLLMService):
"""OpenAI Responses API LLM service using WebSocket transport.
Maintains a persistent WebSocket connection to ``wss://api.openai.com/v1/responses``
@@ -384,8 +386,6 @@ class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService):
super().__init__(**kwargs)
self._ws_url = ws_url
self._websocket = None
self._disconnecting = False
# State for previous_response_id optimization
self._previous_response_id: Optional[str] = None
@@ -398,40 +398,10 @@ class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService):
self._cancel_pending_response: bool = False
self._needs_drain: bool = False
# -- lifecycle ------------------------------------------------------------
# -- WebsocketLLMService interface ----------------------------------------
async def start(self, frame: StartFrame):
"""Start the service and establish WebSocket connection.
Args:
frame: The start frame triggering service initialization.
"""
await super().start(frame)
await self._connect()
async def stop(self, frame: EndFrame):
"""Stop the service and close WebSocket connection.
Args:
frame: The end frame triggering service shutdown.
"""
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
"""Cancel the service and close WebSocket connection.
Args:
frame: The cancel frame triggering service cancellation.
"""
await super().cancel(frame)
await self._disconnect()
# -- connection management ------------------------------------------------
async def _connect(self):
async def _connect_websocket(self):
"""Establish the WebSocket connection."""
self._disconnecting = False
try:
if self._websocket:
return
@@ -442,13 +412,12 @@ class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService):
},
)
except Exception as e:
await self.push_error(error_msg=f"Error connecting to WebSocket: {e}", exception=e)
self._websocket = None
await self.push_error(error_msg=f"Error connecting to WebSocket: {e}", exception=e)
async def _disconnect(self):
async def _disconnect_websocket(self):
"""Close the WebSocket connection and clear state."""
try:
self._disconnecting = True
await self.stop_all_metrics()
if self._websocket:
await self._websocket.close()
@@ -458,33 +427,6 @@ class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService):
self._websocket = None
self._clear_previous_response_state()
self._clear_cancellation_state()
self._disconnecting = False
async def _reconnect(self):
"""Reconnect to the WebSocket, clearing previous_response_id state."""
await self._disconnect()
await self._connect()
async def _ensure_connected(self):
"""Ensure a WebSocket connection is available, reconnecting if needed.
Raises:
_RetryableError: If the connection could not be established.
"""
if self._websocket is None:
await self._connect()
if self._websocket is None:
raise _RetryableError("Failed to establish WebSocket connection")
async def _ws_send(self, message: dict):
"""Send a JSON message over the WebSocket.
Args:
message: The message dict to serialize and send.
"""
if self._disconnecting or not self._websocket:
return
await self._websocket.send(json.dumps(message))
# -- previous_response_id optimization ------------------------------------
@@ -676,8 +618,14 @@ class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService):
This method reads and discards events until a terminal event
(``response.completed``, ``response.failed``, or
``response.incomplete``) arrives, ensuring the connection is clean.
Falls back to reconnecting if draining takes too long.
If draining times out or the connection drops, clears cancellation
state and returns — ``_ensure_connected`` will handle reconnection
before the next inference.
"""
if not self._websocket:
self._clear_cancellation_state()
return
logger.debug(f"{self}: Draining cancelled response events")
try:
while True:
@@ -711,9 +659,9 @@ class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService):
)
self._clear_cancellation_state()
return
except (asyncio.TimeoutError, ConnectionClosed) as e:
logger.warning(f"{self}: Error draining cancelled response: {e} — reconnecting")
await self._reconnect()
except (asyncio.TimeoutError, WebsocketReconnectedError, ConnectionClosed) as e:
logger.warning(f"{self}: Error draining cancelled response: {e}")
self._clear_cancellation_state()
# -- frame processing -----------------------------------------------------
@@ -765,7 +713,7 @@ class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService):
self._needs_drain = True
raise
except Exception as e:
await self.push_error(error_msg=f"Error during completion: {e}", exception=e)
await self.push_error(error_msg=f"Error during inference: {e}", exception=e)
finally:
await self.stop_processing_metrics()
await self.push_frame(LLMFullResponseEndFrame())
@@ -776,6 +724,12 @@ class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService):
async def _process_context(self, context: LLMContext):
"""Run inference over WebSocket with retry and previous_response_id.
Tries once with the ``previous_response_id`` optimization. On a
retriable error (cache miss, connection limit, connection drop),
clears state and retries once with the full context. Transport-level
``ConnectionClosed`` errors are handled transparently by
``_ws_send``/``_ws_recv`` (auto-reconnect → ``WebsocketReconnectedError``).
Args:
context: The LLM context containing conversation history.
"""
@@ -796,60 +750,61 @@ class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService):
full_input = invocation_params["input"]
max_attempts = 2
for attempt in range(max_attempts):
def build_params(*, apply_optimization: bool) -> dict:
params = self._build_response_params(invocation_params)
# WebSocket mode does not use the "stream" parameter
# WebSocket mode does not use the "stream" parameter.
params.pop("stream", None)
# Apply previous_response_id optimization (skipped after a retry)
if attempt == 0:
if apply_optimization:
params = self._apply_previous_response_optimization(params, full_input)
return params
try:
await self._ensure_connected()
await self.start_ttfb_metrics()
await self._ws_send({"type": "response.create", **params})
await self._receive_response_events(context, full_input)
return # Success
except _PreviousResponseNotFoundError:
logger.warning(
f"{self}: previous_response_not_found — "
f"retrying with full context ({len(full_input)} items)"
)
self._clear_previous_response_state()
await self.stop_ttfb_metrics()
if attempt >= max_attempts - 1:
await self.push_error(
error_msg="previous_response_not_found: retry also failed"
)
return
except _ConnectionLimitReachedError:
logger.warning(
f"{self}: WebSocket connection limit reached — "
f"reconnecting and retrying with full context ({len(full_input)} items)"
)
self._clear_previous_response_state()
await self.stop_ttfb_metrics()
await self._reconnect()
if attempt >= max_attempts - 1:
await self.push_error(error_msg="WebSocket connection limit: retry also failed")
return
except ConnectionClosed as e:
logger.warning(
f"{self}: WebSocket connection closed during inference: {e}"
f"reconnecting and retrying with full context ({len(full_input)} items)"
)
self._clear_previous_response_state()
self._websocket = None
await self.stop_ttfb_metrics()
await self._reconnect()
if attempt >= max_attempts - 1:
await self.push_error(
error_msg=f"WebSocket connection closed: retry also failed: {e}",
exception=e,
)
return
async def send_and_receive(params: dict):
await self._ensure_connected()
await self.start_ttfb_metrics()
await self._ws_send({"type": "response.create", **params})
await self._receive_response_events(context, full_input)
async def cleanup():
self._clear_previous_response_state()
await self.stop_ttfb_metrics()
# -- first attempt (with previous_response_id optimization) -----------
try:
await send_and_receive(build_params(apply_optimization=True))
return # Success
except _PreviousResponseNotFoundError:
logger.warning(
f"{self}: previous_response_not_found — "
f"retrying with full context ({len(full_input)} items)"
)
await cleanup()
except _ConnectionLimitReachedError:
logger.warning(
f"{self}: WebSocket connection limit reached — "
f"reconnecting and retrying with full context ({len(full_input)} items)"
)
await cleanup()
await self._try_reconnect(report_error=self._report_error)
except WebsocketReconnectedError:
# ConnectionClosed was handled by the base class — connection is
# fresh, so any connection-local server state is gone.
logger.warning(
f"{self}: Connection lost and recovered — "
f"retrying with full context ({len(full_input)} items)"
)
await cleanup()
except Exception:
await cleanup()
raise
# -- retry with full context (no optimization) ------------------------
try:
await send_and_receive(build_params(apply_optimization=False))
except Exception:
await cleanup()
raise
async def _receive_response_events(self, context: LLMContext, full_input: list):
"""Receive and process WebSocket events until the response completes.
@@ -861,14 +816,14 @@ class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService):
Raises:
_PreviousResponseNotFoundError: Server couldn't find previous response.
_ConnectionLimitReachedError: 60-minute connection limit reached.
ConnectionClosed: WebSocket connection was closed unexpectedly.
WebsocketReconnectedError: Connection was lost and auto-recovered.
ConnectionClosed: Connection was lost and could not be recovered.
"""
function_calls: Dict[str, Dict[str, str]] = {}
current_arguments: Dict[str, str] = {}
while True:
raw = await self._websocket.recv()
event = json.loads(raw)
event = await self._ws_recv()
event_type = event.get("type")
if event_type == "response.created":
@@ -1020,7 +975,7 @@ class OpenAIResponsesHttpLLMService(_BaseOpenAIResponsesLLMService):
await self._call_event_handler("on_completion_timeout")
await self.push_error(error_msg="LLM completion timeout", exception=e)
except Exception as e:
await self.push_error(error_msg=f"Error during completion: {e}", exception=e)
await self.push_error(error_msg=f"Error during inference: {e}", exception=e)
finally:
await self.stop_processing_metrics()
await self.push_frame(LLMFullResponseEndFrame())

View File

@@ -628,29 +628,20 @@ class TestDrainCancelledResponse:
assert len(cancel_calls) == 1
@pytest.mark.asyncio
async def test_drain_timeout_triggers_reconnect(self):
"""If draining takes too long, should fall back to reconnecting."""
async def test_drain_timeout_clears_state(self):
"""If draining times out, should clear cancellation state."""
service = _make_service()
service._needs_drain = True
service.stop_all_metrics = AsyncMock()
service.push_error = AsyncMock()
mock_ws = AsyncMock()
# recv() never returns a terminal event — times out
mock_ws.recv = AsyncMock(side_effect=asyncio.TimeoutError)
mock_ws.close = AsyncMock()
service._websocket = mock_ws
with patch(
"pipecat.services.openai.responses.llm.websocket_connect",
new_callable=AsyncMock,
return_value=AsyncMock(),
):
await service._drain_cancelled_response()
await service._drain_cancelled_response()
assert not service._needs_drain
# Should have reconnected (old ws closed)
mock_ws.close.assert_called_once()
assert not service._cancel_pending_response
# ---------------------------------------------------------------------------
@@ -688,7 +679,9 @@ class TestConnectionLifecycle:
new_callable=AsyncMock,
return_value=AsyncMock(),
):
await service._reconnect()
# _disconnect + _connect is the lifecycle equivalent of the old _reconnect
await service._disconnect()
await service._connect()
assert service._previous_response_id is None
mock_ws.close.assert_called_once()
@@ -723,17 +716,10 @@ class TestConnectionLifecycle:
@pytest.mark.asyncio
async def test_ensure_connected_raises_on_failure(self):
from pipecat.services.openai.responses.llm import _RetryableError
service = _make_service()
service._websocket = None
service.push_error = AsyncMock()
# Mock _try_reconnect to simulate exhausted retries
service._try_reconnect = AsyncMock(return_value=False)
# Mock connect to fail
with patch(
"pipecat.services.openai.responses.llm.websocket_connect",
new_callable=AsyncMock,
side_effect=Exception("Connection refused"),
):
with pytest.raises(_RetryableError):
await service._ensure_connected()
with pytest.raises(ConnectionError):
await service._ensure_connected()