diff --git a/changelog/3589.fixed.md b/changelog/3589.fixed.md new file mode 100644 index 000000000..fda03ac70 --- /dev/null +++ b/changelog/3589.fixed.md @@ -0,0 +1 @@ +- Fixed OpenAI LLM stream not being closed on cancellation/exception, which could leak sockets. diff --git a/src/pipecat/services/openai/base_llm.py b/src/pipecat/services/openai/base_llm.py index ef6cfbbe9..54e514508 100644 --- a/src/pipecat/services/openai/base_llm.py +++ b/src/pipecat/services/openai/base_llm.py @@ -362,74 +362,77 @@ class BaseOpenAILLMService(LLMService): else self._stream_chat_completions_universal_context(context) ) - async for chunk in chunk_stream: - if chunk.usage: - cached_tokens = ( - chunk.usage.prompt_tokens_details.cached_tokens - if chunk.usage.prompt_tokens_details - else None - ) - reasoning_tokens = ( - chunk.usage.completion_tokens_details.reasoning_tokens - if chunk.usage.completion_tokens_details - else None - ) - tokens = LLMTokenUsage( - prompt_tokens=chunk.usage.prompt_tokens, - completion_tokens=chunk.usage.completion_tokens, - total_tokens=chunk.usage.total_tokens, - cache_read_input_tokens=cached_tokens, - reasoning_tokens=reasoning_tokens, - ) - await self.start_llm_usage_metrics(tokens) + # Use context manager to ensure stream is closed on cancellation/exception. + # Without this, CancelledError during iteration leaves the underlying socket open. + async with chunk_stream: + async for chunk in chunk_stream: + if chunk.usage: + cached_tokens = ( + chunk.usage.prompt_tokens_details.cached_tokens + if chunk.usage.prompt_tokens_details + else None + ) + reasoning_tokens = ( + chunk.usage.completion_tokens_details.reasoning_tokens + if chunk.usage.completion_tokens_details + else None + ) + tokens = LLMTokenUsage( + prompt_tokens=chunk.usage.prompt_tokens, + completion_tokens=chunk.usage.completion_tokens, + total_tokens=chunk.usage.total_tokens, + cache_read_input_tokens=cached_tokens, + reasoning_tokens=reasoning_tokens, + ) + await self.start_llm_usage_metrics(tokens) - if chunk.model and self.get_full_model_name() != chunk.model: - self.set_full_model_name(chunk.model) + if chunk.model and self.get_full_model_name() != chunk.model: + self.set_full_model_name(chunk.model) - if chunk.choices is None or len(chunk.choices) == 0: - continue + if chunk.choices is None or len(chunk.choices) == 0: + continue - await self.stop_ttfb_metrics() + await self.stop_ttfb_metrics() - if not chunk.choices[0].delta: - continue + if not chunk.choices[0].delta: + continue - if chunk.choices[0].delta.tool_calls: - # We're streaming the LLM response to enable the fastest response times. - # For text, we just yield each chunk as we receive it and count on consumers - # to do whatever coalescing they need (eg. to pass full sentences to TTS) - # - # If the LLM is a function call, we'll do some coalescing here. - # If the response contains a function name, we'll yield a frame to tell consumers - # that they can start preparing to call the function with that name. - # We accumulate all the arguments for the rest of the streamed response, then when - # the response is done, we package up all the arguments and the function name and - # yield a frame containing the function name and the arguments. + if chunk.choices[0].delta.tool_calls: + # We're streaming the LLM response to enable the fastest response times. + # For text, we just yield each chunk as we receive it and count on consumers + # to do whatever coalescing they need (eg. to pass full sentences to TTS) + # + # If the LLM is a function call, we'll do some coalescing here. + # If the response contains a function name, we'll yield a frame to tell consumers + # that they can start preparing to call the function with that name. + # We accumulate all the arguments for the rest of the streamed response, then when + # the response is done, we package up all the arguments and the function name and + # yield a frame containing the function name and the arguments. - tool_call = chunk.choices[0].delta.tool_calls[0] - if tool_call.index != func_idx: - functions_list.append(function_name) - arguments_list.append(arguments) - tool_id_list.append(tool_call_id) - function_name = "" - arguments = "" - tool_call_id = "" - func_idx += 1 - if tool_call.function and tool_call.function.name: - function_name += tool_call.function.name - tool_call_id = tool_call.id - if tool_call.function and tool_call.function.arguments: - # Keep iterating through the response to collect all the argument fragments - arguments += tool_call.function.arguments - elif chunk.choices[0].delta.content: - await self._push_llm_text(chunk.choices[0].delta.content) + tool_call = chunk.choices[0].delta.tool_calls[0] + if tool_call.index != func_idx: + functions_list.append(function_name) + arguments_list.append(arguments) + tool_id_list.append(tool_call_id) + function_name = "" + arguments = "" + tool_call_id = "" + func_idx += 1 + if tool_call.function and tool_call.function.name: + function_name += tool_call.function.name + tool_call_id = tool_call.id + if tool_call.function and tool_call.function.arguments: + # Keep iterating through the response to collect all the argument fragments + arguments += tool_call.function.arguments + elif chunk.choices[0].delta.content: + await self._push_llm_text(chunk.choices[0].delta.content) - # When gpt-4o-audio / gpt-4o-mini-audio is used for llm or stt+llm - # we need to get LLMTextFrame for the transcript - elif hasattr(chunk.choices[0].delta, "audio") and chunk.choices[0].delta.audio.get( - "transcript" - ): - await self.push_frame(LLMTextFrame(chunk.choices[0].delta.audio["transcript"])) + # When gpt-4o-audio / gpt-4o-mini-audio is used for llm or stt+llm + # we need to get LLMTextFrame for the transcript + elif hasattr(chunk.choices[0].delta, "audio") and chunk.choices[0].delta.audio.get( + "transcript" + ): + await self.push_frame(LLMTextFrame(chunk.choices[0].delta.audio["transcript"])) # if we got a function name and arguments, check to see if it's a function with # a registered handler. If so, run the registered callback, save the result to diff --git a/tests/test_openai_llm_timeout.py b/tests/test_openai_llm_timeout.py index 37e4523a9..4ba459a29 100644 --- a/tests/test_openai_llm_timeout.py +++ b/tests/test_openai_llm_timeout.py @@ -127,6 +127,72 @@ async def test_openai_llm_timeout_still_pushes_end_frame(): service.stop_processing_metrics.assert_called_once() +@pytest.mark.asyncio +async def test_openai_llm_stream_closed_on_cancellation(): + """Test that the stream is closed when CancelledError occurs during iteration. + + This prevents socket leaks when the pipeline is interrupted (e.g., user interruption). + See issue #3589. + """ + import asyncio + + with patch.object(OpenAILLMService, "create_client"): + service = OpenAILLMService(model="gpt-4") + service._client = AsyncMock() + + # Track if close was called + stream_closed = False + + class MockAsyncStream: + """Mock AsyncStream that tracks close() calls and raises CancelledError.""" + + def __init__(self): + self.iteration_count = 0 + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + nonlocal stream_closed + stream_closed = True + return False + + def __aiter__(self): + return self + + async def __anext__(self): + self.iteration_count += 1 + if self.iteration_count > 1: + # Simulate cancellation during iteration + raise asyncio.CancelledError() + # Return a minimal chunk for first iteration + mock_chunk = AsyncMock() + mock_chunk.usage = None + mock_chunk.model = None + mock_chunk.choices = [] + return mock_chunk + + mock_stream = MockAsyncStream() + + # Mock the stream creation methods + service._stream_chat_completions_specific_context = AsyncMock(return_value=mock_stream) + service._stream_chat_completions_universal_context = AsyncMock(return_value=mock_stream) + service.start_ttfb_metrics = AsyncMock() + service.stop_ttfb_metrics = AsyncMock() + service.start_llm_usage_metrics = AsyncMock() + + context = LLMContext( + messages=[{"role": "user", "content": "Hello"}], + ) + + # Process context should raise CancelledError but stream should still be closed + with pytest.raises(asyncio.CancelledError): + await service._process_context(context) + + # Verify stream was closed despite the cancellation + assert stream_closed, "Stream should be closed even when CancelledError occurs" + + @pytest.mark.asyncio async def test_openai_llm_emits_error_frame_on_exception(): """Test that OpenAI LLM service emits ErrorFrame when a general exception occurs.