Merge pull request #330 from pipecat-ai/aleix/stop-and-cancel-are-different
EndFrame tries to end gracefully CancelFrame cancels tasks
This commit is contained in:
@@ -51,7 +51,7 @@ class ImageSyncAggregator(FrameProcessor):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if not isinstance(frame, SystemFrame):
|
||||
if not isinstance(frame, SystemFrame) and direction == FrameDirection.DOWNSTREAM:
|
||||
await self.push_frame(ImageRawFrame(image=self._speaking_image_bytes, size=(1024, 1024), format=self._speaking_image_format))
|
||||
await self.push_frame(frame)
|
||||
await self.push_frame(ImageRawFrame(image=self._waiting_image_bytes, size=(1024, 1024), format=self._waiting_image_format))
|
||||
|
||||
@@ -12,6 +12,8 @@ from pydantic import PrivateAttr, BaseModel, ValidationError
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
@@ -343,32 +345,64 @@ class RTVIProcessor(FrameProcessor):
|
||||
self._ctor_args = ctor_args
|
||||
|
||||
async def update_config(self, config: RTVIConfig):
|
||||
await self._handle_config_update(config)
|
||||
if self._pipeline:
|
||||
await self._handle_config_update(config)
|
||||
self._config = config
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, SystemFrame):
|
||||
# Specific system frames
|
||||
if isinstance(frame, CancelFrame):
|
||||
await self._cancel(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
# All other system frames
|
||||
elif isinstance(frame, SystemFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
# Control frames
|
||||
elif isinstance(frame, StartFrame):
|
||||
await self._start(frame)
|
||||
await self._internal_push_frame(frame, direction)
|
||||
elif isinstance(frame, EndFrame):
|
||||
# Push EndFrame before stop(), because stop() waits on the task to
|
||||
# finish and the task finishes when EndFrame is processed.
|
||||
await self._internal_push_frame(frame, direction)
|
||||
await self._stop(frame)
|
||||
# Other frames
|
||||
else:
|
||||
await self._frame_queue.put((frame, direction))
|
||||
|
||||
if isinstance(frame, StartFrame):
|
||||
try:
|
||||
await self._handle_pipeline_setup(frame, self._config)
|
||||
except Exception as e:
|
||||
await self._send_error(f"unable to setup RTVI pipeline: {e}")
|
||||
await self._internal_push_frame(frame, direction)
|
||||
|
||||
async def cleanup(self):
|
||||
if self._pipeline:
|
||||
await self._pipeline.cleanup()
|
||||
|
||||
async def _start(self, frame: StartFrame):
|
||||
try:
|
||||
await self._handle_pipeline_setup(frame, self._config)
|
||||
except Exception as e:
|
||||
await self._send_error(f"unable to setup RTVI pipeline: {e}")
|
||||
|
||||
async def _stop(self, frame: EndFrame):
|
||||
await self._frame_handler_task
|
||||
|
||||
async def _cancel(self, frame: CancelFrame):
|
||||
self._frame_handler_task.cancel()
|
||||
await self._frame_handler_task
|
||||
|
||||
async def _internal_push_frame(
|
||||
self,
|
||||
frame: Frame | None,
|
||||
direction: FrameDirection | None = FrameDirection.DOWNSTREAM):
|
||||
await self._frame_queue.put((frame, direction))
|
||||
|
||||
async def _frame_handler(self):
|
||||
while True:
|
||||
running = True
|
||||
while running:
|
||||
try:
|
||||
(frame, direction) = await self._frame_queue.get()
|
||||
await self._handle_frame(frame, direction)
|
||||
self._frame_queue.task_done()
|
||||
running = not isinstance(frame, EndFrame)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
|
||||
@@ -283,14 +283,17 @@ class STTService(AIService):
|
||||
await self.stop_processing_metrics()
|
||||
(self._content, self._wave) = self._new_wave()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
self._wave.close()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
self._wave.close()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Processes a frame of audio data, either buffering or transcribing it."""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, CancelFrame) or isinstance(frame, EndFrame):
|
||||
self._wave.close()
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, AudioRawFrame):
|
||||
if isinstance(frame, AudioRawFrame):
|
||||
# In this service we accumulate audio internally and at the end we
|
||||
# push a TextFrame. We don't really want to push audio frames down.
|
||||
await self._append_audio(frame)
|
||||
|
||||
@@ -147,13 +147,16 @@ class AzureSTTService(AsyncAIService):
|
||||
await self._push_queue.put((frame, direction))
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
self._speech_recognizer.start_continuous_recognition_async()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
self._speech_recognizer.stop_continuous_recognition_async()
|
||||
self._audio_stream.close()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
self._speech_recognizer.stop_continuous_recognition_async()
|
||||
self._audio_stream.close()
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from typing import AsyncGenerator
|
||||
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
Frame,
|
||||
AudioRawFrame,
|
||||
StartInterruptionFrame,
|
||||
@@ -98,6 +99,10 @@ class CartesiaTTSService(TTSService):
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def _connect(self):
|
||||
try:
|
||||
self._websocket = await websockets.connect(
|
||||
@@ -111,6 +116,8 @@ class CartesiaTTSService(TTSService):
|
||||
|
||||
async def _disconnect(self):
|
||||
try:
|
||||
await self.stop_all_metrics()
|
||||
|
||||
if self._context_appending_task:
|
||||
self._context_appending_task.cancel()
|
||||
await self._context_appending_task
|
||||
@@ -120,13 +127,12 @@ class CartesiaTTSService(TTSService):
|
||||
await self._receive_task
|
||||
self._receive_task = None
|
||||
if self._websocket:
|
||||
ws = self._websocket
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
await ws.close()
|
||||
|
||||
self._context_id = None
|
||||
self._context_id_start_timestamp = None
|
||||
self._timestamped_words_buffer = []
|
||||
await self.stop_all_metrics()
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error closing websocket: {e}")
|
||||
|
||||
@@ -142,13 +148,13 @@ class CartesiaTTSService(TTSService):
|
||||
try:
|
||||
async for message in self._websocket:
|
||||
msg = json.loads(message)
|
||||
# logger.debug(f"Received message: {msg['type']} {msg['context_id']}")
|
||||
if not msg or msg["context_id"] != self._context_id:
|
||||
continue
|
||||
if msg["type"] == "done":
|
||||
await self.stop_ttfb_metrics()
|
||||
# unset _context_id but not the _context_id_start_timestamp because we are likely still
|
||||
# playing out audio and need the timestamp to set send context frames
|
||||
# Unset _context_id but not the _context_id_start_timestamp
|
||||
# because we are likely still playing out audio and need the
|
||||
# timestamp to set send context frames.
|
||||
self._context_id = None
|
||||
self._timestamped_words_buffer.append(("LLMFullResponseEndFrame", 0))
|
||||
elif msg["type"] == "timestamps":
|
||||
@@ -166,6 +172,8 @@ class CartesiaTTSService(TTSService):
|
||||
num_channels=1
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
|
||||
@@ -176,15 +184,17 @@ class CartesiaTTSService(TTSService):
|
||||
if not self._context_id_start_timestamp:
|
||||
continue
|
||||
elapsed_seconds = time.time() - self._context_id_start_timestamp
|
||||
# pop all words from self._timestamped_words_buffer that are older than the
|
||||
# elapsed time and print a message about them to the console
|
||||
# Pop all words from self._timestamped_words_buffer that are
|
||||
# older than the elapsed time and print a message about them to
|
||||
# the console.
|
||||
while self._timestamped_words_buffer and self._timestamped_words_buffer[0][1] <= elapsed_seconds:
|
||||
word, timestamp = self._timestamped_words_buffer.pop(0)
|
||||
if word == "LLMFullResponseEndFrame" and timestamp == 0:
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
continue
|
||||
# print(f"Word '{word}' with timestamp {timestamp:.2f}s has been spoken.")
|
||||
await self.push_frame(TextFrame(word))
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
|
||||
@@ -212,7 +222,6 @@ class CartesiaTTSService(TTSService):
|
||||
"language": self._language,
|
||||
"add_timestamps": True,
|
||||
}
|
||||
# logger.debug(f"SENDING MESSAGE {json.dumps(msg)}")
|
||||
try:
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
except Exception as e:
|
||||
|
||||
@@ -136,15 +136,18 @@ class DeepgramSTTService(AsyncAIService):
|
||||
await self.queue_frame(frame, direction)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
if await self._connection.start(self._live_options):
|
||||
logger.debug(f"{self}: Connected to Deepgram")
|
||||
else:
|
||||
logger.error(f"{self}: Unable to connect to Deepgram")
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._connection.finish()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._connection.finish()
|
||||
|
||||
async def _on_message(self, *args, **kwargs):
|
||||
|
||||
@@ -68,14 +68,17 @@ class GladiaSTTService(AsyncAIService):
|
||||
await self.queue_frame(frame, direction)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
self._websocket = await websockets.connect(self._url)
|
||||
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
|
||||
await self._setup_gladia()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._websocket.close()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._websocket.close()
|
||||
|
||||
async def _setup_gladia(self):
|
||||
|
||||
@@ -46,12 +46,26 @@ class BaseInputTransport(FrameProcessor):
|
||||
self._audio_in_queue = asyncio.Queue()
|
||||
self._audio_task = self.get_event_loop().create_task(self._audio_task_handler())
|
||||
|
||||
async def stop(self):
|
||||
# Wait for the task to finish.
|
||||
async def stop(self, frame: EndFrame):
|
||||
# Cancel and wait for the audio input task to finish.
|
||||
if self._params.audio_in_enabled or self._params.vad_enabled:
|
||||
self._audio_task.cancel()
|
||||
await self._audio_task
|
||||
|
||||
# Wait for the push frame task to finish. It will finish when the
|
||||
# EndFrame is actually processed.
|
||||
await self._push_frame_task
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
# Cancel all the tasks and wait for them to finish.
|
||||
|
||||
if self._params.audio_in_enabled or self._params.vad_enabled:
|
||||
self._audio_task.cancel()
|
||||
await self._audio_task
|
||||
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
|
||||
def vad_analyzer(self) -> VADAnalyzer | None:
|
||||
return self._params.vad_analyzer
|
||||
|
||||
@@ -63,17 +77,12 @@ class BaseInputTransport(FrameProcessor):
|
||||
# Frame processor
|
||||
#
|
||||
|
||||
async def cleanup(self):
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# Specific system frames
|
||||
if isinstance(frame, CancelFrame):
|
||||
await self.stop()
|
||||
# We don't queue a CancelFrame since we want to stop ASAP.
|
||||
await self.cancel(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, BotInterruptionFrame):
|
||||
await self._handle_interruptions(frame, False)
|
||||
@@ -89,8 +98,10 @@ class BaseInputTransport(FrameProcessor):
|
||||
await self.start(frame)
|
||||
await self._internal_push_frame(frame, direction)
|
||||
elif isinstance(frame, EndFrame):
|
||||
# Push EndFrame before stop(), because stop() waits on the task to
|
||||
# finish and the task finishes when EndFrame is processed.
|
||||
await self._internal_push_frame(frame, direction)
|
||||
await self.stop()
|
||||
await self.stop(frame)
|
||||
# Other frames
|
||||
else:
|
||||
await self._internal_push_frame(frame, direction)
|
||||
@@ -111,10 +122,12 @@ class BaseInputTransport(FrameProcessor):
|
||||
await self._push_queue.put((frame, direction))
|
||||
|
||||
async def _push_frame_task_handler(self):
|
||||
while True:
|
||||
running = True
|
||||
while running:
|
||||
try:
|
||||
(frame, direction) = await self._push_queue.get()
|
||||
await self.push_frame(frame, direction)
|
||||
running = not isinstance(frame, EndFrame)
|
||||
self._push_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
@@ -64,18 +64,34 @@ class BaseOutputTransport(FrameProcessor):
|
||||
self._create_push_task()
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
# Create media threads queues.
|
||||
# Create camera output queue and task if needed.
|
||||
if self._params.camera_out_enabled:
|
||||
self._camera_out_queue = asyncio.Queue()
|
||||
self._camera_out_task = self.get_event_loop().create_task(self._camera_out_task_handler())
|
||||
|
||||
async def stop(self):
|
||||
# Wait on the threads to finish.
|
||||
async def stop(self, frame: EndFrame):
|
||||
# Cancel and wait for the camera output task to finish.
|
||||
if self._params.camera_out_enabled:
|
||||
self._camera_out_task.cancel()
|
||||
await self._camera_out_task
|
||||
|
||||
self._stopped_event.set()
|
||||
# Wait for the push frame and sink tasks to finish. They will finish when
|
||||
# the EndFrame is actually processed.
|
||||
await self._push_frame_task
|
||||
await self._sink_task
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
# Cancel all the tasks and wait for them to finish.
|
||||
|
||||
if self._params.camera_out_enabled:
|
||||
self._camera_out_task.cancel()
|
||||
await self._camera_out_task
|
||||
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
|
||||
self._sink_task.cancel()
|
||||
await self._sink_task
|
||||
|
||||
async def send_message(self, frame: TransportMessageFrame):
|
||||
pass
|
||||
@@ -93,48 +109,38 @@ class BaseOutputTransport(FrameProcessor):
|
||||
# Frame processor
|
||||
#
|
||||
|
||||
async def cleanup(self):
|
||||
if self._sink_task:
|
||||
self._sink_task.cancel()
|
||||
await self._sink_task
|
||||
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
#
|
||||
# Out-of-band frames like (CancelFrame or StartInterruptionFrame) are
|
||||
# pushed immediately. Other frames require order so they are put in the
|
||||
# sink queue.
|
||||
# System frames (like StartInterruptionFrame) are pushed
|
||||
# immediately. Other frames require order so they are put in the sink
|
||||
# queue.
|
||||
#
|
||||
if isinstance(frame, StartFrame):
|
||||
await self.start(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
# EndFrame is managed in the sink queue handler.
|
||||
elif isinstance(frame, CancelFrame):
|
||||
await self.stop()
|
||||
if isinstance(frame, CancelFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
await self.cancel(frame)
|
||||
elif isinstance(frame, StartInterruptionFrame) or isinstance(frame, StopInterruptionFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
await self._handle_interruptions(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, MetricsFrame):
|
||||
await self.send_metrics(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
await self.send_metrics(frame)
|
||||
elif isinstance(frame, SystemFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
# Control frames.
|
||||
elif isinstance(frame, StartFrame):
|
||||
await self._sink_queue.put(frame)
|
||||
await self.start(frame)
|
||||
elif isinstance(frame, EndFrame):
|
||||
await self._sink_queue.put(frame)
|
||||
await self.stop(frame)
|
||||
# Other frames.
|
||||
elif isinstance(frame, AudioRawFrame):
|
||||
await self._handle_audio(frame)
|
||||
else:
|
||||
await self._sink_queue.put(frame)
|
||||
|
||||
# If we are finishing, wait here until we have stopped, otherwise we might
|
||||
# close things too early upstream. We need this event because we don't
|
||||
# know when the internal threads will finish.
|
||||
if isinstance(frame, CancelFrame) or isinstance(frame, EndFrame):
|
||||
await self._stopped_event.wait()
|
||||
|
||||
async def _handle_interruptions(self, frame: Frame):
|
||||
if not self.interruptions_allowed:
|
||||
return
|
||||
@@ -164,7 +170,9 @@ class BaseOutputTransport(FrameProcessor):
|
||||
async def _sink_task_handler(self):
|
||||
# Audio accumlation buffer
|
||||
buffer = bytearray()
|
||||
while True:
|
||||
|
||||
running = True
|
||||
while running:
|
||||
try:
|
||||
frame = await self._sink_queue.get()
|
||||
if isinstance(frame, AudioRawFrame) and self._params.audio_out_enabled:
|
||||
@@ -185,8 +193,7 @@ class BaseOutputTransport(FrameProcessor):
|
||||
else:
|
||||
await self._internal_push_frame(frame)
|
||||
|
||||
if isinstance(frame, EndFrame):
|
||||
await self.stop()
|
||||
running = not isinstance(frame, EndFrame)
|
||||
|
||||
self._sink_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
@@ -210,10 +217,12 @@ class BaseOutputTransport(FrameProcessor):
|
||||
await self._push_queue.put((frame, direction))
|
||||
|
||||
async def _push_frame_task_handler(self):
|
||||
while True:
|
||||
running = True
|
||||
while running:
|
||||
try:
|
||||
(frame, direction) = await self._push_queue.get()
|
||||
await self.push_frame(frame, direction)
|
||||
running = not isinstance(frame, EndFrame)
|
||||
self._push_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
@@ -12,7 +12,7 @@ import wave
|
||||
from typing import Awaitable, Callable
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
from pipecat.frames.frames import AudioRawFrame, StartFrame
|
||||
from pipecat.frames.frames import AudioRawFrame, CancelFrame, EndFrame, StartFrame
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
from pipecat.serializers.base_serializer import FrameSerializer
|
||||
from pipecat.transports.base_input import BaseInputTransport
|
||||
@@ -57,14 +57,19 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
|
||||
self._callbacks = callbacks
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await self._callbacks.on_client_connected(self._websocket)
|
||||
await super().start(frame)
|
||||
await self._callbacks.on_client_connected(self._websocket)
|
||||
self._receive_task = self.get_event_loop().create_task(self._receive_messages())
|
||||
|
||||
async def stop(self):
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
if self._websocket.client_state != WebSocketState.DISCONNECTED:
|
||||
await self._websocket.close()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
if self._websocket.client_state != WebSocketState.DISCONNECTED:
|
||||
await self._websocket.close()
|
||||
await super().stop()
|
||||
|
||||
async def _receive_messages(self):
|
||||
async for message in self._websocket.iter_text():
|
||||
|
||||
@@ -11,7 +11,7 @@ import wave
|
||||
from typing import Awaitable, Callable
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
from pipecat.frames.frames import AudioRawFrame, StartFrame
|
||||
from pipecat.frames.frames import AudioRawFrame, CancelFrame, EndFrame, StartFrame
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
from pipecat.serializers.base_serializer import FrameSerializer
|
||||
from pipecat.serializers.protobuf import ProtobufFrameSerializer
|
||||
@@ -64,10 +64,15 @@ class WebsocketServerInputTransport(BaseInputTransport):
|
||||
self._server_task = self.get_event_loop().create_task(self._server_task_handler())
|
||||
await super().start(frame)
|
||||
|
||||
async def stop(self):
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
self._stop_server_event.set()
|
||||
await self._server_task
|
||||
await super().stop()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
self._server_task.cancel()
|
||||
await self._server_task
|
||||
|
||||
async def _server_task_handler(self):
|
||||
logger.info(f"Starting websocket server on {self._host}:{self._port}")
|
||||
|
||||
@@ -23,6 +23,8 @@ from pydantic.main import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
ImageRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
@@ -125,11 +127,15 @@ class DailyCallbacks(BaseModel):
|
||||
|
||||
def completion_callback(future):
|
||||
def _callback(*args):
|
||||
if not future.cancelled():
|
||||
if len(args) > 1:
|
||||
future.get_loop().call_soon_threadsafe(future.set_result, args)
|
||||
else:
|
||||
future.get_loop().call_soon_threadsafe(future.set_result, *args)
|
||||
def set_result(future, *args):
|
||||
try:
|
||||
if len(args) > 1:
|
||||
future.set_result(args)
|
||||
else:
|
||||
future.set_result(*args)
|
||||
except asyncio.InvalidStateError:
|
||||
pass
|
||||
future.get_loop().call_soon_threadsafe(set_result, future, *args)
|
||||
return _callback
|
||||
|
||||
|
||||
@@ -541,9 +547,19 @@ class DailyInputTransport(BaseInputTransport):
|
||||
if self._params.audio_in_enabled or self._params.vad_enabled:
|
||||
self._audio_in_task = self.get_event_loop().create_task(self._audio_in_task_handler())
|
||||
|
||||
async def stop(self):
|
||||
async def stop(self, frame: EndFrame):
|
||||
# Parent stop.
|
||||
await super().stop()
|
||||
await super().stop(frame)
|
||||
# Leave the room.
|
||||
await self._client.leave()
|
||||
# Stop audio thread.
|
||||
if self._params.audio_in_enabled or self._params.vad_enabled:
|
||||
self._audio_in_task.cancel()
|
||||
await self._audio_in_task
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
# Parent stop.
|
||||
await super().cancel(frame)
|
||||
# Leave the room.
|
||||
await self._client.leave()
|
||||
# Stop audio thread.
|
||||
@@ -658,9 +674,15 @@ class DailyOutputTransport(BaseOutputTransport):
|
||||
# Join the room.
|
||||
await self._client.join()
|
||||
|
||||
async def stop(self):
|
||||
async def stop(self, frame: EndFrame):
|
||||
# Parent stop.
|
||||
await super().stop()
|
||||
await super().stop(frame)
|
||||
# Leave the room.
|
||||
await self._client.leave()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
# Parent stop.
|
||||
await super().cancel(frame)
|
||||
# Leave the room.
|
||||
await self._client.leave()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user