Add WebsocketLLMService as a base class for WebSocket-based LLM services. It parallels WebsocketTTSService/WebsocketSTTService but codifies a transactional request-response model rather than a continuous background receive loop. WebsocketLLMService provides: (1) connection lifecycle (start/stop/cancel → connect/disconnect); (2) _ws_send/_ws_recv with transparent ConnectionClosed handling (auto-reconnect via exponential backoff → WebsocketReconnectedError); (3) _ensure_connected with retry via _try_reconnect. OpenAIResponsesLLMService now inherits from WebsocketLLMService, which removes duplicated connection-management code (_connect, _disconnect, _reconnect, _ensure_connected, _ws_send, start, stop, cancel) and simplifies _process_context from a loop with attempt tracking to a flat try/except with a single retry.
726 lines
26 KiB
Python
726 lines
26 KiB
Python
#
|
|
# Copyright (c) 2024-2026, Daily
|
|
#
|
|
# SPDX-License-Identifier: BSD 2-Clause License
|
|
#
|
|
|
|
"""Tests for the WebSocket variant of OpenAIResponsesLLMService."""
|
|
|
|
import asyncio
|
|
import json
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from pipecat.processors.aggregators.llm_context import LLMContext
|
|
from pipecat.processors.frame_processor import FrameDirection
|
|
from pipecat.services.openai.responses.llm import OpenAIResponsesLLMService
|
|
|
|
|
|
def _make_service(**kwargs):
    """Build an ``OpenAIResponsesLLMService`` with its client mocked out.

    ``_create_client`` is patched on the class while the service is
    constructed, and the ``_client`` attribute is then set to an ``AsyncMock``.

    Args:
        **kwargs: Extra keyword arguments forwarded to the service constructor.

    Returns:
        The constructed service instance.
    """
    client_factory_patch = patch.object(OpenAIResponsesLLMService, "_create_client")
    with client_factory_patch:
        svc = OpenAIResponsesLLMService(api_key="test-key", **kwargs)
    svc._client = AsyncMock()
    return svc
|
|
|
|
|
|
def _ws_events(*events):
|
|
"""Build a mock WebSocket that yields the given events from recv()."""
|
|
ws = AsyncMock()
|
|
# .recv() returns each event in order, then raises StopAsyncIteration
|
|
ws.recv = AsyncMock(side_effect=[json.dumps(e) for e in events])
|
|
ws.send = AsyncMock()
|
|
ws.close = AsyncMock()
|
|
ws.close_code = None
|
|
return ws
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Hash determinism
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestHashInputItems:
    """Checks on the hash produced by ``_hash_input_items``."""

    def test_same_input_same_hash(self):
        """Hashing the same list twice gives equal results."""
        hash_items = OpenAIResponsesLLMService._hash_input_items
        payload = [{"role": "user", "content": "hello"}]

        first = hash_items(payload)
        second = hash_items(payload)

        assert first == second

    def test_different_input_different_hash(self):
        """Lists that differ in message content hash differently."""
        hash_items = OpenAIResponsesLLMService._hash_input_items

        hello_hash = hash_items([{"role": "user", "content": "hello"}])
        world_hash = hash_items([{"role": "user", "content": "world"}])

        assert hello_hash != world_hash

    def test_order_independent_keys(self):
        """Key order inside a dict should not affect the hash (sort_keys=True)."""
        hash_items = OpenAIResponsesLLMService._hash_input_items

        a_then_b = hash_items([{"a": 1, "b": 2}])
        b_then_a = hash_items([{"b": 2, "a": 1}])

        assert a_then_b == b_then_a
|
|
|
|
|
|
class TestStartsWithResponseOutput:
    """Cases for ``_starts_with_response_output(items, response_output)``."""

    def test_text_message_matches_by_role(self):
        """An assistant message output matches a leading assistant item."""
        starts_with = OpenAIResponsesLLMService._starts_with_response_output
        output = [
            {
                "type": "message",
                "role": "assistant",
                "content": [{"type": "output_text", "text": "Hello!"}],
            }
        ]
        # The adapter emits a different shape for the same assistant turn.
        adapter_items = [
            {"role": "assistant", "content": "Hello!"},
            {"role": "user", "content": "hi"},
        ]

        assert starts_with(adapter_items, output)

    def test_function_call_matches_by_call_id(self):
        """A function_call output matches an item carrying the same call_id."""
        starts_with = OpenAIResponsesLLMService._starts_with_response_output
        output = [
            {
                "type": "function_call",
                "id": "fc_1",
                "call_id": "call_1",
                "name": "get_weather",
                "arguments": '{"location": "SF"}',
            }
        ]
        # Adapter-format items carry no "id" field.
        adapter_items = [
            {"type": "function_call", "call_id": "call_1", "name": "get_weather", "arguments": "{}"},
            {"type": "function_call_output", "call_id": "call_1", "output": "sunny"},
        ]

        assert starts_with(adapter_items, output)

    def test_mixed_output(self):
        """A message followed by a function call matches the same leading items."""
        starts_with = OpenAIResponsesLLMService._starts_with_response_output
        output = [
            {
                "type": "message",
                "role": "assistant",
                "content": [{"type": "output_text", "text": "Let me check."}],
            },
            {
                "type": "function_call",
                "id": "fc_1",
                "call_id": "call_1",
                "name": "get_weather",
                "arguments": "{}",
            },
        ]
        adapter_items = [
            {"role": "assistant", "content": "Let me check."},
            {"type": "function_call", "call_id": "call_1", "name": "get_weather", "arguments": "{}"},
            {"type": "function_call_output", "call_id": "call_1", "output": "sunny"},
        ]

        assert starts_with(adapter_items, output)

    def test_role_mismatch(self):
        """A user item does not match an assistant message output."""
        starts_with = OpenAIResponsesLLMService._starts_with_response_output
        output = [{"type": "message", "role": "assistant", "content": []}]
        adapter_items = [{"role": "user", "content": "hi"}]

        assert not starts_with(adapter_items, output)

    def test_text_content_mismatch(self):
        """Same role but different text is rejected."""
        starts_with = OpenAIResponsesLLMService._starts_with_response_output
        output = [
            {
                "type": "message",
                "role": "assistant",
                "content": [{"type": "output_text", "text": "Hello!"}],
            }
        ]
        adapter_items = [{"role": "assistant", "content": "Something completely different"}]

        assert not starts_with(adapter_items, output)

    def test_call_id_mismatch(self):
        """Function calls with different call_ids are rejected."""
        starts_with = OpenAIResponsesLLMService._starts_with_response_output
        output = [{"type": "function_call", "call_id": "call_1", "name": "f"}]
        adapter_items = [{"type": "function_call", "call_id": "call_999", "name": "f"}]

        assert not starts_with(adapter_items, output)

    def test_too_few_items(self):
        """Fewer items than output entries is rejected."""
        starts_with = OpenAIResponsesLLMService._starts_with_response_output
        output = [
            {"type": "message", "role": "assistant", "content": []},
            {"type": "function_call", "call_id": "call_1", "name": "f"},
        ]
        adapter_items = [{"role": "assistant", "content": "hi"}]

        assert not starts_with(adapter_items, output)

    def test_empty_output_always_matches(self):
        """An empty output list matches both empty and non-empty items."""
        starts_with = OpenAIResponsesLLMService._starts_with_response_output

        assert starts_with([], [])
        assert starts_with([{"role": "user"}], [])

    def test_unknown_output_type_rejects(self):
        """An output entry of an unrecognized type is rejected."""
        starts_with = OpenAIResponsesLLMService._starts_with_response_output
        output = [{"type": "unknown_thing", "data": "something"}]
        adapter_items = [{"role": "assistant", "content": "hi"}]

        assert not starts_with(adapter_items, output)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# previous_response_id optimization
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPreviousResponseOptimization:
    """Cases for ``_apply_previous_response_optimization`` and its stored state."""

    def test_no_previous_state_sends_full_input(self):
        """With nothing stored, the full input is sent and no response id is set."""
        svc = _make_service()
        conversation = [{"role": "user", "content": "hi"}]
        request = {"input": conversation, "model": "gpt-4.1"}

        optimized = svc._apply_previous_response_optimization(request, conversation)

        assert optimized["input"] == conversation
        assert "previous_response_id" not in optimized

    def test_matching_prefix_sends_incremental(self):
        """When stored input and output prefix the new input, only the tail is sent."""
        svc = _make_service()
        # Earlier turn: one user message went out, the assistant answered "hello".
        earlier_input = [{"role": "user", "content": "hi"}]
        earlier_output = [
            {
                "type": "message",
                "role": "assistant",
                "content": [{"type": "output_text", "text": "hello"}],
            }
        ]
        svc._store_previous_response_state("resp_123", earlier_input, earlier_output)

        # This turn: the adapter yields the whole context, i.e. the earlier
        # exchange plus one new user message.
        conversation = [
            {"role": "user", "content": "hi"},
            {"role": "assistant", "content": "hello"},
            {"role": "user", "content": "how are you?"},
        ]
        request = {"input": list(conversation), "model": "gpt-4.1"}

        optimized = svc._apply_previous_response_optimization(request, conversation)

        assert optimized["previous_response_id"] == "resp_123"
        # Nothing but the new user message should remain in the input.
        assert optimized["input"] == [{"role": "user", "content": "how are you?"}]

    def test_mismatched_prefix_sends_full(self):
        """A changed first message disables the optimization."""
        svc = _make_service()
        earlier_input = [{"role": "user", "content": "hi"}]
        svc._store_previous_response_state("resp_123", earlier_input, [])

        # The first message differs from what was stored.
        conversation = [
            {"role": "user", "content": "different"},
            {"role": "assistant", "content": "hello"},
        ]
        request = {"input": list(conversation), "model": "gpt-4.1"}

        optimized = svc._apply_previous_response_optimization(request, conversation)

        assert "previous_response_id" not in optimized
        assert optimized["input"] == conversation

    def test_same_length_sends_full(self):
        """When new input is the same length as the previous one, no optimization."""
        svc = _make_service()
        earlier_input = [{"role": "user", "content": "hi"}]
        svc._store_previous_response_state("resp_123", earlier_input, [])

        conversation = [{"role": "user", "content": "hi"}]
        request = {"input": list(conversation), "model": "gpt-4.1"}

        optimized = svc._apply_previous_response_optimization(request, conversation)

        assert "previous_response_id" not in optimized

    def test_output_mismatch_sends_full_context(self):
        """When the prefix matches but the output doesn't, fall back to full context."""
        svc = _make_service()
        earlier_input = [{"role": "user", "content": "hi"}]
        earlier_output = [
            {
                "type": "message",
                "role": "assistant",
                "content": [{"type": "output_text", "text": "hello"}],
            }
        ]
        svc._store_previous_response_state("resp_123", earlier_input, earlier_output)

        # The aggregator recorded the reply differently (here: another role).
        conversation = [
            {"role": "user", "content": "hi"},
            {"role": "developer", "content": "something unexpected"},
            {"role": "user", "content": "how are you?"},
        ]
        request = {"input": list(conversation), "model": "gpt-4.1"}

        optimized = svc._apply_previous_response_optimization(request, conversation)

        assert "previous_response_id" not in optimized
        assert optimized["input"] == conversation

    def test_clear_state(self):
        """Clearing resets the stored id, input hash and input length to None."""
        svc = _make_service()
        svc._store_previous_response_state("resp_123", [{"role": "user", "content": "hi"}], [])

        svc._clear_previous_response_state()

        assert svc._previous_response_id is None
        assert svc._previous_input_hash is None
        assert svc._previous_input_length is None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _receive_response_events — text streaming
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestReceiveResponseEventsText:
    """``_receive_response_events`` driven by text and completion events."""

    @pytest.mark.asyncio
    async def test_text_deltas_pushed(self):
        """Two text deltas produce two ``_push_llm_text`` awaits with their text."""
        svc = _make_service()
        svc._push_llm_text = AsyncMock()
        svc.stop_ttfb_metrics = AsyncMock()
        svc.start_llm_usage_metrics = AsyncMock()

        usage = {
            "input_tokens": 10,
            "output_tokens": 5,
            "total_tokens": 15,
            "input_tokens_details": {"cached_tokens": 0},
            "output_tokens_details": {"reasoning_tokens": 0},
        }
        completed = {
            "type": "response.completed",
            "response": {"id": "resp_1", "model": "gpt-4.1", "usage": usage},
        }
        svc._websocket = _ws_events(
            {"type": "response.output_text.delta", "delta": "Hello"},
            {"type": "response.output_text.delta", "delta": " world"},
            completed,
        )

        ctx = MagicMock(spec=LLMContext)
        sent_input = [{"role": "user", "content": "hi"}]
        await svc._receive_response_events(ctx, sent_input)

        assert svc._push_llm_text.call_count == 2
        svc._push_llm_text.assert_any_await("Hello")
        svc._push_llm_text.assert_any_await(" world")

    @pytest.mark.asyncio
    async def test_response_completed_stores_state(self):
        """``response.completed`` records previous-response state and reports usage."""
        svc = _make_service()
        svc._push_llm_text = AsyncMock()
        svc.stop_ttfb_metrics = AsyncMock()
        svc.start_llm_usage_metrics = AsyncMock()

        assistant_message = {
            "type": "message",
            "role": "assistant",
            "content": [{"type": "output_text", "text": "Hello!"}],
        }
        usage = {
            "input_tokens": 10,
            "output_tokens": 5,
            "total_tokens": 15,
            "input_tokens_details": {"cached_tokens": 2},
            "output_tokens_details": {"reasoning_tokens": 1},
        }
        svc._websocket = _ws_events(
            {
                "type": "response.completed",
                "response": {
                    "id": "resp_42",
                    "model": "gpt-4.1",
                    "output": [assistant_message],
                    "usage": usage,
                },
            },
        )

        ctx = MagicMock(spec=LLMContext)
        sent_input = [{"role": "user", "content": "hi"}]
        await svc._receive_response_events(ctx, sent_input)

        assert svc._previous_response_id == "resp_42"
        assert svc._previous_input_length == 1
        assert svc._previous_input_hash is not None
        assert len(svc._previous_response_output) == 1
        assert svc.start_llm_usage_metrics.called

    @pytest.mark.asyncio
    async def test_token_usage_metrics(self):
        """Usage numbers in ``response.completed`` reach ``start_llm_usage_metrics``."""
        svc = _make_service()
        svc._push_llm_text = AsyncMock()
        svc.stop_ttfb_metrics = AsyncMock()
        svc.start_llm_usage_metrics = AsyncMock()

        usage = {
            "input_tokens": 100,
            "output_tokens": 50,
            "total_tokens": 150,
            "input_tokens_details": {"cached_tokens": 20},
            "output_tokens_details": {"reasoning_tokens": 10},
        }
        svc._websocket = _ws_events(
            {
                "type": "response.completed",
                "response": {"id": "resp_1", "model": "gpt-4.1", "usage": usage},
            },
        )

        ctx = MagicMock(spec=LLMContext)
        await svc._receive_response_events(ctx, [])

        # The usage object is the first positional argument of the metrics call.
        reported = svc.start_llm_usage_metrics.call_args.args[0]
        assert reported.prompt_tokens == 100
        assert reported.completion_tokens == 50
        assert reported.total_tokens == 150
        assert reported.cache_read_input_tokens == 20
        assert reported.reasoning_tokens == 10
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _receive_response_events — function calls
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestReceiveResponseEventsFunctionCalls:
    """``_receive_response_events`` driven by a streamed function call."""

    @pytest.mark.asyncio
    async def test_function_call_sequence(self):
        """A full function-call event sequence triggers one ``run_function_calls``."""
        svc = _make_service()
        svc._push_llm_text = AsyncMock()
        svc.stop_ttfb_metrics = AsyncMock()
        svc.start_llm_usage_metrics = AsyncMock()
        svc.run_function_calls = AsyncMock()

        # Events in wire order: item added, two argument deltas, arguments
        # done, item done, then response completed.
        item_added = {
            "type": "response.output_item.added",
            "item": {
                "type": "function_call",
                "id": "fc_1",
                "name": "get_weather",
                "call_id": "call_1",
            },
        }
        first_delta = {
            "type": "response.function_call_arguments.delta",
            "item_id": "fc_1",
            "delta": '{"loc',
        }
        second_delta = {
            "type": "response.function_call_arguments.delta",
            "item_id": "fc_1",
            "delta": 'ation": "SF"}',
        }
        arguments_done = {
            "type": "response.function_call_arguments.done",
            "item_id": "fc_1",
            "arguments": '{"location": "SF"}',
        }
        item_done = {
            "type": "response.output_item.done",
            "item": {
                "type": "function_call",
                "id": "fc_1",
                "name": "get_weather",
                "call_id": "call_1",
                "arguments": '{"location": "SF"}',
            },
        }
        completed = {
            "type": "response.completed",
            "response": {"id": "resp_1", "model": "gpt-4.1", "usage": None},
        }
        svc._websocket = _ws_events(
            item_added,
            first_delta,
            second_delta,
            arguments_done,
            item_done,
            completed,
        )

        ctx = MagicMock(spec=LLMContext)
        await svc._receive_response_events(ctx, [])

        svc.run_function_calls.assert_called_once()
        # The list of calls is the first positional argument.
        calls = svc.run_function_calls.call_args.args[0]
        assert len(calls) == 1
        only_call = calls[0]
        assert only_call.function_name == "get_weather"
        assert only_call.tool_call_id == "call_1"
        assert only_call.arguments == {"location": "SF"}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _receive_response_events — errors
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestReceiveResponseEventsErrors:
    """``_receive_response_events`` driven by failure and error events."""

    @pytest.mark.asyncio
    async def test_response_failed_pushes_error(self):
        """``response.failed`` pushes one error carrying the server's message."""
        svc = _make_service()
        svc.stop_ttfb_metrics = AsyncMock()
        svc.start_llm_usage_metrics = AsyncMock()
        svc.push_error = AsyncMock()

        failure_details = {"error": {"message": "Content filter triggered"}}
        svc._websocket = _ws_events(
            {
                "type": "response.failed",
                "response": {"id": "resp_1", "status_details": failure_details},
            },
        )

        ctx = MagicMock(spec=LLMContext)
        await svc._receive_response_events(ctx, [])

        svc.push_error.assert_called_once()
        pushed_message = svc.push_error.call_args.kwargs["error_msg"]
        assert "Content filter triggered" in pushed_message

    @pytest.mark.asyncio
    async def test_response_incomplete_pushes_error(self):
        """``response.incomplete`` with no status details still pushes one error."""
        svc = _make_service()
        svc.stop_ttfb_metrics = AsyncMock()
        svc.start_llm_usage_metrics = AsyncMock()
        svc.push_error = AsyncMock()

        svc._websocket = _ws_events(
            {
                "type": "response.incomplete",
                "response": {"id": "resp_1", "status_details": None},
            },
        )

        ctx = MagicMock(spec=LLMContext)
        await svc._receive_response_events(ctx, [])

        svc.push_error.assert_called_once()

    @pytest.mark.asyncio
    async def test_previous_response_not_found_raises(self):
        """Error code ``previous_response_not_found`` raises its dedicated exception."""
        from pipecat.services.openai.responses.llm import _PreviousResponseNotFoundError

        svc = _make_service()
        svc.stop_ttfb_metrics = AsyncMock()

        server_error = {
            "code": "previous_response_not_found",
            "message": "Previous response with id 'resp_abc' not found.",
        }
        svc._websocket = _ws_events({"type": "error", "error": server_error})

        ctx = MagicMock(spec=LLMContext)
        with pytest.raises(_PreviousResponseNotFoundError):
            await svc._receive_response_events(ctx, [])

    @pytest.mark.asyncio
    async def test_connection_limit_reached_raises(self):
        """Error code ``websocket_connection_limit_reached`` raises its exception."""
        from pipecat.services.openai.responses.llm import _ConnectionLimitReachedError

        svc = _make_service()
        svc.stop_ttfb_metrics = AsyncMock()

        server_error = {
            "code": "websocket_connection_limit_reached",
            "message": "Connection limit reached.",
        }
        svc._websocket = _ws_events({"type": "error", "error": server_error})

        ctx = MagicMock(spec=LLMContext)
        with pytest.raises(_ConnectionLimitReachedError):
            await svc._receive_response_events(ctx, [])

    @pytest.mark.asyncio
    async def test_generic_error_pushes_error(self):
        """Any other error code is pushed as an error with the server's message."""
        svc = _make_service()
        svc.stop_ttfb_metrics = AsyncMock()
        svc.start_llm_usage_metrics = AsyncMock()
        svc.push_error = AsyncMock()

        server_error = {"code": "server_error", "message": "Internal server error"}
        svc._websocket = _ws_events({"type": "error", "error": server_error})

        ctx = MagicMock(spec=LLMContext)
        await svc._receive_response_events(ctx, [])

        svc.push_error.assert_called_once()
        pushed_message = svc.push_error.call_args.kwargs["error_msg"]
        assert "Internal server error" in pushed_message
|
|
|
|
|
|
class TestDrainCancelledResponse:
    """Cases for ``_drain_cancelled_response``."""

    @pytest.mark.asyncio
    async def test_drain_discards_events_until_terminal(self):
        """Draining should discard events until a terminal event arrives."""
        svc = _make_service()
        svc._needs_drain = True

        svc._websocket = _ws_events(
            {"type": "response.output_text.delta", "delta": "stale"},
            {"type": "response.output_text.delta", "delta": "also stale"},
            {"type": "response.completed", "response": {"id": "resp_old"}},
        )

        await svc._drain_cancelled_response()

        assert not svc._needs_drain

    @pytest.mark.asyncio
    async def test_drain_handles_pending_cancel(self):
        """A pending cancel is sent once ``response.created`` is seen.

        If cancelled before response.created, drain should send cancel once it
        sees the response.created, then continue draining.
        """
        svc = _make_service()
        svc._needs_drain = True
        svc._cancel_pending_response = True

        stale_events = [
            {"type": "response.created", "response": {"id": "resp_late"}},
            {"type": "response.output_text.delta", "delta": "stale"},
            {"type": "response.failed", "response": {"id": "resp_late"}},
        ]
        socket = AsyncMock()
        socket.recv = AsyncMock(side_effect=[json.dumps(event) for event in stale_events])
        socket.send = AsyncMock()
        svc._websocket = socket

        await svc._drain_cancelled_response()

        assert not svc._needs_drain
        assert not svc._cancel_pending_response
        # Exactly one outgoing message should be a response.cancel.
        cancel_count = sum(
            1 for sent in socket.send.call_args_list if "response.cancel" in sent.args[0]
        )
        assert cancel_count == 1

    @pytest.mark.asyncio
    async def test_drain_timeout_clears_state(self):
        """If draining times out, should clear cancellation state."""
        svc = _make_service()
        svc._needs_drain = True

        socket = AsyncMock()
        # recv() raises TimeoutError instead of ever delivering a terminal event.
        socket.recv = AsyncMock(side_effect=asyncio.TimeoutError)
        svc._websocket = socket

        await svc._drain_cancelled_response()

        assert not svc._needs_drain
        assert not svc._cancel_pending_response
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Connection lifecycle
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestConnectionLifecycle:
    """Connect, disconnect, cancellation and ``_ensure_connected`` behavior."""

    @pytest.mark.asyncio
    async def test_disconnect_clears_previous_response_state(self):
        """``_disconnect`` resets the stored id, input hash and input length."""
        svc = _make_service()
        svc._store_previous_response_state("resp_1", [{"role": "user", "content": "hi"}], [])
        svc.stop_all_metrics = AsyncMock()

        await svc._disconnect()

        assert svc._previous_response_id is None
        assert svc._previous_input_hash is None
        assert svc._previous_input_length is None

    @pytest.mark.asyncio
    async def test_reconnect_clears_state_and_reconnects(self):
        """Disconnect then connect closes the old socket and clears the stored id."""
        svc = _make_service()
        svc._store_previous_response_state("resp_1", [{"role": "user", "content": "hi"}], [])
        svc.stop_all_metrics = AsyncMock()
        svc.push_error = AsyncMock()

        # Start from an already-open socket so its close() can be checked.
        old_socket = AsyncMock()
        old_socket.close = AsyncMock()
        svc._websocket = old_socket

        connect_patch = patch(
            "pipecat.services.openai.responses.llm.websocket_connect",
            new_callable=AsyncMock,
            return_value=AsyncMock(),
        )
        with connect_patch:
            # _disconnect + _connect is the lifecycle equivalent of the old _reconnect
            await svc._disconnect()
            await svc._connect()

        assert svc._previous_response_id is None
        old_socket.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_cancellation_preserves_connection_and_sets_drain(self):
        """Cancellation keeps the socket open and flags it for draining.

        When process_frame is cancelled (e.g. interruption), the WebSocket
        connection should be preserved and _needs_drain set.
        """
        from pipecat.frames.frames import LLMContextFrame

        svc = _make_service()
        svc.stop_processing_metrics = AsyncMock()
        svc.push_frame = AsyncMock()

        socket = AsyncMock()
        socket.recv = AsyncMock(side_effect=asyncio.CancelledError)
        socket.send = AsyncMock()
        svc._websocket = socket

        ctx = MagicMock(spec=LLMContext)
        ctx.tools = None
        ctx.tool_choice = None
        ctx.messages = [{"role": "user", "content": "hi"}]

        frame = LLMContextFrame(context=ctx)
        with pytest.raises(asyncio.CancelledError):
            await svc.process_frame(frame, FrameDirection.DOWNSTREAM)

        # The same socket object must still be attached, i.e. it was not closed.
        assert svc._websocket is socket
        # The next inference has to drain the cancelled response first.
        assert svc._needs_drain

    @pytest.mark.asyncio
    async def test_ensure_connected_raises_on_failure(self):
        """``_ensure_connected`` raises ``ConnectionError`` when reconnecting fails."""
        svc = _make_service()
        svc._websocket = None
        # A False return from _try_reconnect simulates exhausted retries.
        svc._try_reconnect = AsyncMock(return_value=False)

        with pytest.raises(ConnectionError):
            await svc._ensure_connected()
|