diff --git a/examples/foundational/06a-image-sync.py b/examples/foundational/06a-image-sync.py index 5a5264dec..1b8df9a28 100644 --- a/examples/foundational/06a-image-sync.py +++ b/examples/foundational/06a-image-sync.py @@ -51,7 +51,7 @@ class ImageSyncAggregator(FrameProcessor): async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) - if not isinstance(frame, SystemFrame): + if not isinstance(frame, SystemFrame) and direction == FrameDirection.DOWNSTREAM: await self.push_frame(ImageRawFrame(image=self._speaking_image_bytes, size=(1024, 1024), format=self._speaking_image_format)) await self.push_frame(frame) await self.push_frame(ImageRawFrame(image=self._waiting_image_bytes, size=(1024, 1024), format=self._waiting_image_format)) diff --git a/src/pipecat/processors/frameworks/rtvi.py b/src/pipecat/processors/frameworks/rtvi.py index 343f0a217..cd4e2e8db 100644 --- a/src/pipecat/processors/frameworks/rtvi.py +++ b/src/pipecat/processors/frameworks/rtvi.py @@ -12,6 +12,8 @@ from pydantic import PrivateAttr, BaseModel, ValidationError from pipecat.frames.frames import ( BotInterruptionFrame, + CancelFrame, + EndFrame, Frame, InterimTranscriptionFrame, LLMFullResponseEndFrame, @@ -343,32 +345,64 @@ class RTVIProcessor(FrameProcessor): self._ctor_args = ctor_args async def update_config(self, config: RTVIConfig): - await self._handle_config_update(config) + if self._pipeline: + await self._handle_config_update(config) + self._config = config async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) - if isinstance(frame, SystemFrame): + # Specific system frames + if isinstance(frame, CancelFrame): + await self._cancel(frame) await self.push_frame(frame, direction) + # All other system frames + elif isinstance(frame, SystemFrame): + await self.push_frame(frame, direction) + # Control frames + elif isinstance(frame, StartFrame): + await self._start(frame) + await self._internal_push_frame(frame, direction) + elif isinstance(frame, EndFrame): + # Push EndFrame before stop(), because stop() waits on the task to + # finish and the task finishes when EndFrame is processed. + await self._internal_push_frame(frame, direction) + await self._stop(frame) + # Other frames else: - await self._frame_queue.put((frame, direction)) - - if isinstance(frame, StartFrame): - try: - await self._handle_pipeline_setup(frame, self._config) - except Exception as e: - await self._send_error(f"unable to setup RTVI pipeline: {e}") + await self._internal_push_frame(frame, direction) async def cleanup(self): + if self._pipeline: + await self._pipeline.cleanup() + + async def _start(self, frame: StartFrame): + try: + await self._handle_pipeline_setup(frame, self._config) + except Exception as e: + await self._send_error(f"unable to setup RTVI pipeline: {e}") + + async def _stop(self, frame: EndFrame): + await self._frame_handler_task + + async def _cancel(self, frame: CancelFrame): self._frame_handler_task.cancel() await self._frame_handler_task + async def _internal_push_frame( + self, + frame: Frame | None, + direction: FrameDirection | None = FrameDirection.DOWNSTREAM): + await self._frame_queue.put((frame, direction)) + async def _frame_handler(self): - while True: + running = True + while running: try: (frame, direction) = await self._frame_queue.get() await self._handle_frame(frame, direction) self._frame_queue.task_done() + running = not isinstance(frame, EndFrame) except asyncio.CancelledError: break diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index bc00accdf..8dae7ed0b 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -283,14 +283,17 @@ class STTService(AIService): await self.stop_processing_metrics() (self._content, self._wave) = self._new_wave() + async def stop(self, frame: EndFrame): + self._wave.close() + + async def cancel(self, frame: CancelFrame): + self._wave.close() + async def process_frame(self, frame: Frame, direction: FrameDirection): """Processes a frame of audio data, either buffering or transcribing it.""" await super().process_frame(frame, direction) - if isinstance(frame, CancelFrame) or isinstance(frame, EndFrame): - self._wave.close() - await self.push_frame(frame, direction) - elif isinstance(frame, AudioRawFrame): + if isinstance(frame, AudioRawFrame): # In this service we accumulate audio internally and at the end we # push a TextFrame. We don't really want to push audio frames down. await self._append_audio(frame) diff --git a/src/pipecat/services/azure.py b/src/pipecat/services/azure.py index 8991f154a..2d43d9a8c 100644 --- a/src/pipecat/services/azure.py +++ b/src/pipecat/services/azure.py @@ -147,13 +147,16 @@ class AzureSTTService(AsyncAIService): await self._push_queue.put((frame, direction)) async def start(self, frame: StartFrame): + await super().start(frame) self._speech_recognizer.start_continuous_recognition_async() async def stop(self, frame: EndFrame): + await super().stop(frame) self._speech_recognizer.stop_continuous_recognition_async() self._audio_stream.close() async def cancel(self, frame: CancelFrame): + await super().cancel(frame) self._speech_recognizer.stop_continuous_recognition_async() self._audio_stream.close() diff --git a/src/pipecat/services/cartesia.py b/src/pipecat/services/cartesia.py index 86ae10e99..592215cd6 100644 --- a/src/pipecat/services/cartesia.py +++ b/src/pipecat/services/cartesia.py @@ -14,6 +14,7 @@ from typing import AsyncGenerator from pipecat.processors.frame_processor import FrameDirection from pipecat.frames.frames import ( + CancelFrame, Frame, AudioRawFrame, StartInterruptionFrame, @@ -98,6 +99,10 @@ class CartesiaTTSService(TTSService): await super().stop(frame) await self._disconnect() + async def cancel(self, frame: CancelFrame): + await super().cancel(frame) + await self._disconnect() + async def _connect(self): try: self._websocket = await websockets.connect( @@ -111,6 +116,8 @@ class CartesiaTTSService(TTSService): async def _disconnect(self): try: + await self.stop_all_metrics() + if self._context_appending_task: self._context_appending_task.cancel() await self._context_appending_task @@ -120,13 +127,12 @@ class CartesiaTTSService(TTSService): await self._receive_task self._receive_task = None if self._websocket: - ws = self._websocket + await self._websocket.close() self._websocket = None - await ws.close() + self._context_id = None self._context_id_start_timestamp = None self._timestamped_words_buffer = [] - await self.stop_all_metrics() except Exception as e: logger.exception(f"{self} error closing websocket: {e}") @@ -142,13 +148,13 @@ class CartesiaTTSService(TTSService): try: async for message in self._websocket: msg = json.loads(message) - # logger.debug(f"Received message: {msg['type']} {msg['context_id']}") if not msg or msg["context_id"] != self._context_id: continue if msg["type"] == "done": await self.stop_ttfb_metrics() - # unset _context_id but not the _context_id_start_timestamp because we are likely still - # playing out audio and need the timestamp to set send context frames + # Unset _context_id but not the _context_id_start_timestamp + # because we are likely still playing out audio and need the + # timestamp to set send context frames. self._context_id = None self._timestamped_words_buffer.append(("LLMFullResponseEndFrame", 0)) elif msg["type"] == "timestamps": @@ -166,6 +172,8 @@ class CartesiaTTSService(TTSService): num_channels=1 ) await self.push_frame(frame) + except asyncio.CancelledError: + pass except Exception as e: logger.exception(f"{self} exception: {e}") @@ -176,15 +184,17 @@ class CartesiaTTSService(TTSService): if not self._context_id_start_timestamp: continue elapsed_seconds = time.time() - self._context_id_start_timestamp - # pop all words from self._timestamped_words_buffer that are older than the - # elapsed time and print a message about them to the console + # Pop all words from self._timestamped_words_buffer that are + # older than the elapsed time and print a message about them to + # the console. while self._timestamped_words_buffer and self._timestamped_words_buffer[0][1] <= elapsed_seconds: word, timestamp = self._timestamped_words_buffer.pop(0) if word == "LLMFullResponseEndFrame" and timestamp == 0: await self.push_frame(LLMFullResponseEndFrame()) continue - # print(f"Word '{word}' with timestamp {timestamp:.2f}s has been spoken.") await self.push_frame(TextFrame(word)) + except asyncio.CancelledError: + pass except Exception as e: logger.exception(f"{self} exception: {e}") @@ -212,7 +222,6 @@ class CartesiaTTSService(TTSService): "language": self._language, "add_timestamps": True, } - # logger.debug(f"SENDING MESSAGE {json.dumps(msg)}") try: await self._websocket.send(json.dumps(msg)) except Exception as e: diff --git a/src/pipecat/services/deepgram.py b/src/pipecat/services/deepgram.py index f582664c3..1c20df45c 100644 --- a/src/pipecat/services/deepgram.py +++ b/src/pipecat/services/deepgram.py @@ -136,15 +136,18 @@ class DeepgramSTTService(AsyncAIService): await self.queue_frame(frame, direction) async def start(self, frame: StartFrame): + await super().start(frame) if await self._connection.start(self._live_options): logger.debug(f"{self}: Connected to Deepgram") else: logger.error(f"{self}: Unable to connect to Deepgram") async def stop(self, frame: EndFrame): + await super().stop(frame) await self._connection.finish() async def cancel(self, frame: CancelFrame): + await super().cancel(frame) await self._connection.finish() async def _on_message(self, *args, **kwargs): diff --git a/src/pipecat/services/gladia.py b/src/pipecat/services/gladia.py index 4043e1283..886300897 100644 --- a/src/pipecat/services/gladia.py +++ b/src/pipecat/services/gladia.py @@ -68,14 +68,17 @@ class GladiaSTTService(AsyncAIService): await self.queue_frame(frame, direction) async def start(self, frame: StartFrame): + await super().start(frame) self._websocket = await websockets.connect(self._url) self._receive_task = self.get_event_loop().create_task(self._receive_task_handler()) await self._setup_gladia() async def stop(self, frame: EndFrame): + await super().stop(frame) await self._websocket.close() async def cancel(self, frame: CancelFrame): + await super().cancel(frame) await self._websocket.close() async def _setup_gladia(self): diff --git a/src/pipecat/transports/base_input.py b/src/pipecat/transports/base_input.py index fd51334f5..5eb5da16c 100644 --- a/src/pipecat/transports/base_input.py +++ b/src/pipecat/transports/base_input.py @@ -46,12 +46,26 @@ class BaseInputTransport(FrameProcessor): self._audio_in_queue = asyncio.Queue() self._audio_task = self.get_event_loop().create_task(self._audio_task_handler()) - async def stop(self): - # Wait for the task to finish. + async def stop(self, frame: EndFrame): + # Cancel and wait for the audio input task to finish. if self._params.audio_in_enabled or self._params.vad_enabled: self._audio_task.cancel() await self._audio_task + # Wait for the push frame task to finish. It will finish when the + # EndFrame is actually processed. + await self._push_frame_task + + async def cancel(self, frame: CancelFrame): + # Cancel all the tasks and wait for them to finish. + + if self._params.audio_in_enabled or self._params.vad_enabled: + self._audio_task.cancel() + await self._audio_task + + self._push_frame_task.cancel() + await self._push_frame_task + def vad_analyzer(self) -> VADAnalyzer | None: return self._params.vad_analyzer @@ -63,17 +77,12 @@ class BaseInputTransport(FrameProcessor): # Frame processor # - async def cleanup(self): - self._push_frame_task.cancel() - await self._push_frame_task - async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) # Specific system frames if isinstance(frame, CancelFrame): - await self.stop() - # We don't queue a CancelFrame since we want to stop ASAP. + await self.cancel(frame) await self.push_frame(frame, direction) elif isinstance(frame, BotInterruptionFrame): await self._handle_interruptions(frame, False) @@ -89,8 +98,10 @@ class BaseInputTransport(FrameProcessor): await self.start(frame) await self._internal_push_frame(frame, direction) elif isinstance(frame, EndFrame): + # Push EndFrame before stop(), because stop() waits on the task to + # finish and the task finishes when EndFrame is processed. await self._internal_push_frame(frame, direction) - await self.stop() + await self.stop(frame) # Other frames else: await self._internal_push_frame(frame, direction) @@ -111,10 +122,12 @@ class BaseInputTransport(FrameProcessor): await self._push_queue.put((frame, direction)) async def _push_frame_task_handler(self): - while True: + running = True + while running: try: (frame, direction) = await self._push_queue.get() await self.push_frame(frame, direction) + running = not isinstance(frame, EndFrame) self._push_queue.task_done() except asyncio.CancelledError: break diff --git a/src/pipecat/transports/base_output.py b/src/pipecat/transports/base_output.py index 8a99d94a8..cc37b8934 100644 --- a/src/pipecat/transports/base_output.py +++ b/src/pipecat/transports/base_output.py @@ -64,18 +64,34 @@ class BaseOutputTransport(FrameProcessor): self._create_push_task() async def start(self, frame: StartFrame): - # Create media threads queues. + # Create camera output queue and task if needed. if self._params.camera_out_enabled: self._camera_out_queue = asyncio.Queue() self._camera_out_task = self.get_event_loop().create_task(self._camera_out_task_handler()) - async def stop(self): - # Wait on the threads to finish. + async def stop(self, frame: EndFrame): + # Cancel and wait for the camera output task to finish. if self._params.camera_out_enabled: self._camera_out_task.cancel() await self._camera_out_task - self._stopped_event.set() + # Wait for the push frame and sink tasks to finish. They will finish when + # the EndFrame is actually processed. + await self._push_frame_task + await self._sink_task + + async def cancel(self, frame: CancelFrame): + # Cancel all the tasks and wait for them to finish. + + if self._params.camera_out_enabled: + self._camera_out_task.cancel() + await self._camera_out_task + + self._push_frame_task.cancel() + await self._push_frame_task + + self._sink_task.cancel() + await self._sink_task async def send_message(self, frame: TransportMessageFrame): pass @@ -93,48 +109,38 @@ class BaseOutputTransport(FrameProcessor): # Frame processor # - async def cleanup(self): - if self._sink_task: - self._sink_task.cancel() - await self._sink_task - - self._push_frame_task.cancel() - await self._push_frame_task - async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) # - # Out-of-band frames like (CancelFrame or StartInterruptionFrame) are - # pushed immediately. Other frames require order so they are put in the - # sink queue. + # System frames (like StartInterruptionFrame) are pushed + # immediately. Other frames require order so they are put in the sink + # queue. # - if isinstance(frame, StartFrame): - await self.start(frame) - await self.push_frame(frame, direction) - # EndFrame is managed in the sink queue handler. - elif isinstance(frame, CancelFrame): - await self.stop() + if isinstance(frame, CancelFrame): await self.push_frame(frame, direction) + await self.cancel(frame) elif isinstance(frame, StartInterruptionFrame) or isinstance(frame, StopInterruptionFrame): + await self.push_frame(frame, direction) await self._handle_interruptions(frame) - await self.push_frame(frame, direction) elif isinstance(frame, MetricsFrame): - await self.send_metrics(frame) await self.push_frame(frame, direction) + await self.send_metrics(frame) elif isinstance(frame, SystemFrame): await self.push_frame(frame, direction) + # Control frames. + elif isinstance(frame, StartFrame): + await self._sink_queue.put(frame) + await self.start(frame) + elif isinstance(frame, EndFrame): + await self._sink_queue.put(frame) + await self.stop(frame) + # Other frames. elif isinstance(frame, AudioRawFrame): await self._handle_audio(frame) else: await self._sink_queue.put(frame) - # If we are finishing, wait here until we have stopped, otherwise we might - # close things too early upstream. We need this event because we don't - # know when the internal threads will finish. - if isinstance(frame, CancelFrame) or isinstance(frame, EndFrame): - await self._stopped_event.wait() - async def _handle_interruptions(self, frame: Frame): if not self.interruptions_allowed: return @@ -164,7 +170,9 @@ class BaseOutputTransport(FrameProcessor): async def _sink_task_handler(self): # Audio accumlation buffer buffer = bytearray() - while True: + + running = True + while running: try: frame = await self._sink_queue.get() if isinstance(frame, AudioRawFrame) and self._params.audio_out_enabled: @@ -185,8 +193,7 @@ class BaseOutputTransport(FrameProcessor): else: await self._internal_push_frame(frame) - if isinstance(frame, EndFrame): - await self.stop() + running = not isinstance(frame, EndFrame) self._sink_queue.task_done() except asyncio.CancelledError: @@ -210,10 +217,12 @@ class BaseOutputTransport(FrameProcessor): await self._push_queue.put((frame, direction)) async def _push_frame_task_handler(self): - while True: + running = True + while running: try: (frame, direction) = await self._push_queue.get() await self.push_frame(frame, direction) + running = not isinstance(frame, EndFrame) self._push_queue.task_done() except asyncio.CancelledError: break diff --git a/src/pipecat/transports/network/fastapi_websocket.py b/src/pipecat/transports/network/fastapi_websocket.py index 8b9877d09..32857a696 100644 --- a/src/pipecat/transports/network/fastapi_websocket.py +++ b/src/pipecat/transports/network/fastapi_websocket.py @@ -12,7 +12,7 @@ import wave from typing import Awaitable, Callable from pydantic.main import BaseModel -from pipecat.frames.frames import AudioRawFrame, StartFrame +from pipecat.frames.frames import AudioRawFrame, CancelFrame, EndFrame, StartFrame from pipecat.processors.frame_processor import FrameProcessor from pipecat.serializers.base_serializer import FrameSerializer from pipecat.transports.base_input import BaseInputTransport @@ -57,14 +57,19 @@ class FastAPIWebsocketInputTransport(BaseInputTransport): self._callbacks = callbacks async def start(self, frame: StartFrame): - await self._callbacks.on_client_connected(self._websocket) await super().start(frame) + await self._callbacks.on_client_connected(self._websocket) self._receive_task = self.get_event_loop().create_task(self._receive_messages()) - async def stop(self): + async def stop(self, frame: EndFrame): + await super().stop(frame) + if self._websocket.client_state != WebSocketState.DISCONNECTED: + await self._websocket.close() + + async def cancel(self, frame: CancelFrame): + await super().cancel(frame) if self._websocket.client_state != WebSocketState.DISCONNECTED: await self._websocket.close() - await super().stop() async def _receive_messages(self): async for message in self._websocket.iter_text(): diff --git a/src/pipecat/transports/network/websocket_server.py b/src/pipecat/transports/network/websocket_server.py index fe775ce45..e231c2d77 100644 --- a/src/pipecat/transports/network/websocket_server.py +++ b/src/pipecat/transports/network/websocket_server.py @@ -11,7 +11,7 @@ import wave from typing import Awaitable, Callable from pydantic.main import BaseModel -from pipecat.frames.frames import AudioRawFrame, StartFrame +from pipecat.frames.frames import AudioRawFrame, CancelFrame, EndFrame, StartFrame from pipecat.processors.frame_processor import FrameProcessor from pipecat.serializers.base_serializer import FrameSerializer from pipecat.serializers.protobuf import ProtobufFrameSerializer @@ -64,10 +64,15 @@ class WebsocketServerInputTransport(BaseInputTransport): self._server_task = self.get_event_loop().create_task(self._server_task_handler()) await super().start(frame) - async def stop(self): + async def stop(self, frame: EndFrame): + await super().stop(frame) self._stop_server_event.set() await self._server_task - await super().stop() + + async def cancel(self, frame: CancelFrame): + await super().cancel(frame) + self._server_task.cancel() + await self._server_task async def _server_task_handler(self): logger.info(f"Starting websocket server on {self._host}:{self._port}") diff --git a/src/pipecat/transports/services/daily.py b/src/pipecat/transports/services/daily.py index 386d5deb4..3047e4fb2 100644 --- a/src/pipecat/transports/services/daily.py +++ b/src/pipecat/transports/services/daily.py @@ -23,6 +23,8 @@ from pydantic.main import BaseModel from pipecat.frames.frames import ( AudioRawFrame, + CancelFrame, + EndFrame, Frame, ImageRawFrame, InterimTranscriptionFrame, @@ -125,11 +127,15 @@ class DailyCallbacks(BaseModel): def completion_callback(future): def _callback(*args): - if not future.cancelled(): - if len(args) > 1: - future.get_loop().call_soon_threadsafe(future.set_result, args) - else: - future.get_loop().call_soon_threadsafe(future.set_result, *args) + def set_result(future, *args): + try: + if len(args) > 1: + future.set_result(args) + else: + future.set_result(*args) + except asyncio.InvalidStateError: + pass + future.get_loop().call_soon_threadsafe(set_result, future, *args) return _callback @@ -541,9 +547,19 @@ class DailyInputTransport(BaseInputTransport): if self._params.audio_in_enabled or self._params.vad_enabled: self._audio_in_task = self.get_event_loop().create_task(self._audio_in_task_handler()) - async def stop(self): + async def stop(self, frame: EndFrame): # Parent stop. - await super().stop() + await super().stop(frame) + # Leave the room. + await self._client.leave() + # Stop audio thread. + if self._params.audio_in_enabled or self._params.vad_enabled: + self._audio_in_task.cancel() + await self._audio_in_task + + async def cancel(self, frame: CancelFrame): + # Parent stop. + await super().cancel(frame) # Leave the room. await self._client.leave() # Stop audio thread. @@ -658,9 +674,15 @@ class DailyOutputTransport(BaseOutputTransport): # Join the room. await self._client.join() - async def stop(self): + async def stop(self, frame: EndFrame): # Parent stop. - await super().stop() + await super().stop(frame) + # Leave the room. + await self._client.leave() + + async def cancel(self, frame: CancelFrame): + # Parent stop. + await super().cancel(frame) # Leave the room. await self._client.leave()