diff --git a/changelog/3629.fixed.md b/changelog/3629.fixed.md new file mode 100644 index 000000000..12792f47d --- /dev/null +++ b/changelog/3629.fixed.md @@ -0,0 +1 @@ +- Fixed `StopAsyncIteration` exceptions in `parse_telephony_websocket()` when WebSocket connections close before sending expected messages. diff --git a/src/pipecat/runner/utils.py b/src/pipecat/runner/utils.py index f54453d93..d0bb44a88 100644 --- a/src/pipecat/runner/utils.py +++ b/src/pipecat/runner/utils.py @@ -96,6 +96,9 @@ def _detect_transport_type_from_message(message_data: dict) -> str: async def parse_telephony_websocket(websocket: WebSocket): """Parse telephony WebSocket messages and return transport type and call data. + Args: + websocket: FastAPI WebSocket connection from telephony provider. + Returns: tuple: (transport_type: str, call_data: dict) @@ -136,6 +139,9 @@ async def parse_telephony_websocket(websocket: WebSocket): "to": str, } + Raises: + ValueError: If WebSocket closes before sending any messages. + Example usage:: transport_type, call_data = await parse_telephony_websocket(websocket) @@ -143,25 +149,31 @@ async def parse_telephony_websocket(websocket: WebSocket): user_id = call_data["body"]["user_id"] """ # Read first two messages - start_data = websocket.iter_text() + message_stream = websocket.iter_text() + first_message = {} + second_message = {} try: - # First message - first_message_raw = await start_data.__anext__() + # First message - required + first_message_raw = await message_stream.__anext__() logger.trace(f"First message: {first_message_raw}") - try: - first_message = json.loads(first_message_raw) - except json.JSONDecodeError: - first_message = {} + first_message = json.loads(first_message_raw) if first_message_raw else {} + except json.JSONDecodeError: + pass + except StopAsyncIteration: + raise ValueError("WebSocket closed before receiving telephony handshake messages") - # Second message - second_message_raw = await start_data.__anext__() + try: + # Second message - optional, some providers may only send one + second_message_raw = await message_stream.__anext__() logger.trace(f"Second message: {second_message_raw}") - try: - second_message = json.loads(second_message_raw) - except json.JSONDecodeError: - second_message = {} + second_message = json.loads(second_message_raw) if second_message_raw else {} + except json.JSONDecodeError: + pass + except StopAsyncIteration: + logger.warning("Only received one WebSocket message, expected two") + try: # Try auto-detection on both messages detected_type_first = _detect_transport_type_from_message(first_message) detected_type_second = _detect_transport_type_from_message(second_message) diff --git a/tests/test_runner_utils.py b/tests/test_runner_utils.py new file mode 100644 index 000000000..18f156cbb --- /dev/null +++ b/tests/test_runner_utils.py @@ -0,0 +1,153 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import json +import unittest +from unittest.mock import MagicMock + +from pipecat.runner.utils import parse_telephony_websocket + + +class MockAsyncIterator: + """Mock async iterator for WebSocket messages.""" + + def __init__(self, messages): + self.messages = messages + self.index = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + if self.index >= len(self.messages): + raise StopAsyncIteration + message = self.messages[self.index] + self.index += 1 + return message + + +class TestParseTelephonyWebSocket(unittest.IsolatedAsyncioTestCase): + async def test_no_messages_raises_value_error(self): + """Test that no messages raises ValueError.""" + mock_websocket = MagicMock() + mock_websocket.iter_text.return_value = MockAsyncIterator([]) + + with self.assertRaises(ValueError) as context: + await parse_telephony_websocket(mock_websocket) + + self.assertIn("WebSocket closed before receiving", str(context.exception)) + + async def test_one_message_logs_warning_and_continues(self): + """Test that one message logs warning but continues processing.""" + twilio_message = json.dumps( + { + "event": "start", + "start": { + "streamSid": "MZ123", + "callSid": "CA123", + "customParameters": {"user_id": "test_user"}, + }, + } + ) + + mock_websocket = MagicMock() + mock_websocket.iter_text.return_value = MockAsyncIterator([twilio_message]) + + transport_type, call_data = await parse_telephony_websocket(mock_websocket) + + self.assertEqual(transport_type, "twilio") + self.assertEqual(call_data["stream_id"], "MZ123") + self.assertEqual(call_data["call_id"], "CA123") + + async def test_two_messages_normal_operation(self): + """Test normal operation with two messages.""" + first_message = json.dumps({"event": "connected"}) + twilio_message = json.dumps( + { + "event": "start", + "start": { + "streamSid": "MZ456", + "callSid": "CA456", + "customParameters": {}, + }, + } + ) + + mock_websocket = MagicMock() + mock_websocket.iter_text.return_value = MockAsyncIterator([first_message, twilio_message]) + + transport_type, call_data = await parse_telephony_websocket(mock_websocket) + + self.assertEqual(transport_type, "twilio") + self.assertEqual(call_data["stream_id"], "MZ456") + self.assertEqual(call_data["call_id"], "CA456") + + async def test_telnyx_detection(self): + """Test Telnyx provider detection.""" + telnyx_message = json.dumps( + { + "stream_id": "stream_123", + "start": { + "call_control_id": "cc_123", + "media_format": {"encoding": "PCMU"}, + "from": "+15551234567", + "to": "+15559876543", + }, + } + ) + + mock_websocket = MagicMock() + mock_websocket.iter_text.return_value = MockAsyncIterator([telnyx_message]) + + transport_type, call_data = await parse_telephony_websocket(mock_websocket) + + self.assertEqual(transport_type, "telnyx") + self.assertEqual(call_data["stream_id"], "stream_123") + self.assertEqual(call_data["call_control_id"], "cc_123") + + async def test_plivo_detection(self): + """Test Plivo provider detection.""" + plivo_message = json.dumps( + {"start": {"streamId": "stream_plivo_123", "callId": "call_plivo_123"}} + ) + + mock_websocket = MagicMock() + mock_websocket.iter_text.return_value = MockAsyncIterator([plivo_message]) + + transport_type, call_data = await parse_telephony_websocket(mock_websocket) + + self.assertEqual(transport_type, "plivo") + self.assertEqual(call_data["stream_id"], "stream_plivo_123") + self.assertEqual(call_data["call_id"], "call_plivo_123") + + async def test_exotel_detection(self): + """Test Exotel provider detection.""" + exotel_message = json.dumps( + { + "event": "start", + "start": { + "stream_sid": "stream_exo_123", + "call_sid": "call_exo_123", + "account_sid": "acc_123", + "from": "+15551111111", + "to": "+15552222222", + }, + } + ) + + mock_websocket = MagicMock() + mock_websocket.iter_text.return_value = MockAsyncIterator([exotel_message]) + + transport_type, call_data = await parse_telephony_websocket(mock_websocket) + + self.assertEqual(transport_type, "exotel") + self.assertEqual(call_data["stream_id"], "stream_exo_123") + self.assertEqual(call_data["call_id"], "call_exo_123") + self.assertEqual(call_data["account_sid"], "acc_123") + + +if __name__ == "__main__": + unittest.main()