Merge pull request #3640 from lukepayyapilli/fix/openai-stream-close

fix: close stream on cancellation to prevent socket leaks
This commit is contained in:
Mark Backman
2026-02-05 18:00:06 -05:00
committed by GitHub
3 changed files with 131 additions and 61 deletions

1
changelog/3589.fixed.md Normal file
View File

@@ -0,0 +1 @@
- Fixed OpenAI LLM stream not being closed on cancellation/exception, which could leak sockets.

View File

@@ -362,74 +362,77 @@ class BaseOpenAILLMService(LLMService):
else self._stream_chat_completions_universal_context(context)
)
async for chunk in chunk_stream:
if chunk.usage:
cached_tokens = (
chunk.usage.prompt_tokens_details.cached_tokens
if chunk.usage.prompt_tokens_details
else None
)
reasoning_tokens = (
chunk.usage.completion_tokens_details.reasoning_tokens
if chunk.usage.completion_tokens_details
else None
)
tokens = LLMTokenUsage(
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens,
total_tokens=chunk.usage.total_tokens,
cache_read_input_tokens=cached_tokens,
reasoning_tokens=reasoning_tokens,
)
await self.start_llm_usage_metrics(tokens)
# Use context manager to ensure stream is closed on cancellation/exception.
# Without this, CancelledError during iteration leaves the underlying socket open.
async with chunk_stream:
async for chunk in chunk_stream:
if chunk.usage:
cached_tokens = (
chunk.usage.prompt_tokens_details.cached_tokens
if chunk.usage.prompt_tokens_details
else None
)
reasoning_tokens = (
chunk.usage.completion_tokens_details.reasoning_tokens
if chunk.usage.completion_tokens_details
else None
)
tokens = LLMTokenUsage(
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens,
total_tokens=chunk.usage.total_tokens,
cache_read_input_tokens=cached_tokens,
reasoning_tokens=reasoning_tokens,
)
await self.start_llm_usage_metrics(tokens)
if chunk.model and self.get_full_model_name() != chunk.model:
self.set_full_model_name(chunk.model)
if chunk.model and self.get_full_model_name() != chunk.model:
self.set_full_model_name(chunk.model)
if chunk.choices is None or len(chunk.choices) == 0:
continue
if chunk.choices is None or len(chunk.choices) == 0:
continue
await self.stop_ttfb_metrics()
await self.stop_ttfb_metrics()
if not chunk.choices[0].delta:
continue
if not chunk.choices[0].delta:
continue
if chunk.choices[0].delta.tool_calls:
# We're streaming the LLM response to enable the fastest response times.
# For text, we just yield each chunk as we receive it and count on consumers
# to do whatever coalescing they need (eg. to pass full sentences to TTS)
#
# If the LLM is a function call, we'll do some coalescing here.
# If the response contains a function name, we'll yield a frame to tell consumers
# that they can start preparing to call the function with that name.
# We accumulate all the arguments for the rest of the streamed response, then when
# the response is done, we package up all the arguments and the function name and
# yield a frame containing the function name and the arguments.
if chunk.choices[0].delta.tool_calls:
# We're streaming the LLM response to enable the fastest response times.
# For text, we just yield each chunk as we receive it and count on consumers
# to do whatever coalescing they need (eg. to pass full sentences to TTS)
#
# If the LLM is a function call, we'll do some coalescing here.
# If the response contains a function name, we'll yield a frame to tell consumers
# that they can start preparing to call the function with that name.
# We accumulate all the arguments for the rest of the streamed response, then when
# the response is done, we package up all the arguments and the function name and
# yield a frame containing the function name and the arguments.
tool_call = chunk.choices[0].delta.tool_calls[0]
if tool_call.index != func_idx:
functions_list.append(function_name)
arguments_list.append(arguments)
tool_id_list.append(tool_call_id)
function_name = ""
arguments = ""
tool_call_id = ""
func_idx += 1
if tool_call.function and tool_call.function.name:
function_name += tool_call.function.name
tool_call_id = tool_call.id
if tool_call.function and tool_call.function.arguments:
# Keep iterating through the response to collect all the argument fragments
arguments += tool_call.function.arguments
elif chunk.choices[0].delta.content:
await self._push_llm_text(chunk.choices[0].delta.content)
tool_call = chunk.choices[0].delta.tool_calls[0]
if tool_call.index != func_idx:
functions_list.append(function_name)
arguments_list.append(arguments)
tool_id_list.append(tool_call_id)
function_name = ""
arguments = ""
tool_call_id = ""
func_idx += 1
if tool_call.function and tool_call.function.name:
function_name += tool_call.function.name
tool_call_id = tool_call.id
if tool_call.function and tool_call.function.arguments:
# Keep iterating through the response to collect all the argument fragments
arguments += tool_call.function.arguments
elif chunk.choices[0].delta.content:
await self._push_llm_text(chunk.choices[0].delta.content)
# When gpt-4o-audio / gpt-4o-mini-audio is used for llm or stt+llm
# we need to get LLMTextFrame for the transcript
elif hasattr(chunk.choices[0].delta, "audio") and chunk.choices[0].delta.audio.get(
"transcript"
):
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.audio["transcript"]))
# When gpt-4o-audio / gpt-4o-mini-audio is used for llm or stt+llm
# we need to get LLMTextFrame for the transcript
elif hasattr(chunk.choices[0].delta, "audio") and chunk.choices[0].delta.audio.get(
"transcript"
):
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.audio["transcript"]))
# if we got a function name and arguments, check to see if it's a function with
# a registered handler. If so, run the registered callback, save the result to

View File

@@ -127,6 +127,72 @@ async def test_openai_llm_timeout_still_pushes_end_frame():
service.stop_processing_metrics.assert_called_once()
@pytest.mark.asyncio
async def test_openai_llm_stream_closed_on_cancellation():
"""Test that the stream is closed when CancelledError occurs during iteration.
This prevents socket leaks when the pipeline is interrupted (e.g., user interruption).
See issue #3589.
"""
import asyncio
with patch.object(OpenAILLMService, "create_client"):
service = OpenAILLMService(model="gpt-4")
service._client = AsyncMock()
# Track if close was called
stream_closed = False
class MockAsyncStream:
"""Mock AsyncStream that tracks close() calls and raises CancelledError."""
def __init__(self):
self.iteration_count = 0
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
nonlocal stream_closed
stream_closed = True
return False
def __aiter__(self):
return self
async def __anext__(self):
self.iteration_count += 1
if self.iteration_count > 1:
# Simulate cancellation during iteration
raise asyncio.CancelledError()
# Return a minimal chunk for first iteration
mock_chunk = AsyncMock()
mock_chunk.usage = None
mock_chunk.model = None
mock_chunk.choices = []
return mock_chunk
mock_stream = MockAsyncStream()
# Mock the stream creation methods
service._stream_chat_completions_specific_context = AsyncMock(return_value=mock_stream)
service._stream_chat_completions_universal_context = AsyncMock(return_value=mock_stream)
service.start_ttfb_metrics = AsyncMock()
service.stop_ttfb_metrics = AsyncMock()
service.start_llm_usage_metrics = AsyncMock()
context = LLMContext(
messages=[{"role": "user", "content": "Hello"}],
)
# Process context should raise CancelledError but stream should still be closed
with pytest.raises(asyncio.CancelledError):
await service._process_context(context)
# Verify stream was closed despite the cancellation
assert stream_closed, "Stream should be closed even when CancelledError occurs"
@pytest.mark.asyncio
async def test_openai_llm_emits_error_frame_on_exception():
"""Test that OpenAI LLM service emits ErrorFrame when a general exception occurs.