Compare commits
3 Commits
aleix/fram
...
khk/togeth
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
90c64c3df6 | ||
|
|
c2bc64361a | ||
|
|
9bbb824248 |
@@ -70,7 +70,7 @@ async def main():
|
||||
async def user_idle_callback(user_idle: UserIdleProcessor):
|
||||
messages.append(
|
||||
{"role": "system", "content": "Ask the user if they are still there and try to prompt for some input, but be short."})
|
||||
await user_idle.queue_frame(LLMMessagesFrame(messages))
|
||||
await user_idle.push_frame(LLMMessagesFrame(messages))
|
||||
|
||||
user_idle = UserIdleProcessor(callback=user_idle_callback, timeout=5.0)
|
||||
|
||||
|
||||
@@ -94,6 +94,8 @@ class UserImageRawFrame(ImageRawFrame):
|
||||
|
||||
"""
|
||||
user_id: str
|
||||
context: Any = None
|
||||
description: str | None = None
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}(user: {self.user_id}, size: {self.size}, format: {self.format})"
|
||||
@@ -423,7 +425,7 @@ class TTSStoppedFrame(ControlFrame):
|
||||
class UserImageRequestFrame(ControlFrame):
|
||||
"""A frame used to request an image from the given user."""
|
||||
user_id: str
|
||||
context: Optional[Any] = None
|
||||
context: Any = None
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}, user: {self.user_id}"
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame, OpenAILLMContext
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
|
||||
from pipecat.frames.frames import EndFrame, Frame, StartInterruptionFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class AsyncFrameProcessor(FrameProcessor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: str | None = None,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
**kwargs):
|
||||
super().__init__(name=name, loop=loop, **kwargs)
|
||||
|
||||
self._create_push_task()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
await self._handle_interruptions(frame)
|
||||
|
||||
async def queue_frame(
|
||||
self,
|
||||
frame: Frame,
|
||||
direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
await self._push_queue.put((frame, direction))
|
||||
|
||||
async def cleanup(self):
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
|
||||
async def _handle_interruptions(self, frame: Frame):
|
||||
# Cancel the task. This will stop pushing frames downstream.
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
# Push an out-of-band frame (i.e. not using the ordered push
|
||||
# frame task).
|
||||
await self.push_frame(frame)
|
||||
# Create a new queue and task.
|
||||
self._create_push_task()
|
||||
|
||||
def _create_push_task(self):
|
||||
self._push_queue = asyncio.Queue()
|
||||
self._push_frame_task = self.get_event_loop().create_task(self._push_frame_task_handler())
|
||||
|
||||
async def _push_frame_task_handler(self):
|
||||
running = True
|
||||
while running:
|
||||
try:
|
||||
(frame, direction) = await self._push_queue.get()
|
||||
await self.push_frame(frame, direction)
|
||||
running = not isinstance(frame, EndFrame)
|
||||
self._push_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
@@ -10,12 +10,14 @@ import time
|
||||
from enum import Enum
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
MetricsFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
UserStoppedSpeakingFrame)
|
||||
StopInterruptionFrame,
|
||||
SystemFrame)
|
||||
from pipecat.utils.utils import obj_count, obj_id
|
||||
|
||||
from loguru import logger
|
||||
@@ -105,6 +107,13 @@ class FrameProcessor:
|
||||
# Metrics
|
||||
self._metrics = FrameProcessorMetrics(name=self.name)
|
||||
|
||||
# Every processor in Pipecat should only output frames from a single
|
||||
# task. This avoids problems like audio overlapping. System frames are
|
||||
# the exception to this rule.
|
||||
#
|
||||
# This creates this task.
|
||||
self.__create_push_task()
|
||||
|
||||
@property
|
||||
def interruptions_allowed(self):
|
||||
return self._allow_interruptions
|
||||
@@ -184,14 +193,41 @@ class FrameProcessor:
|
||||
self._enable_usage_metrics = frame.enable_usage_metrics
|
||||
self._report_only_initial_ttfb = frame.report_only_initial_ttfb
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
await self._start_interruption()
|
||||
await self.stop_all_metrics()
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
elif isinstance(frame, StopInterruptionFrame):
|
||||
self._should_report_ttfb = True
|
||||
|
||||
async def push_error(self, error: ErrorFrame):
|
||||
await self.push_frame(error, FrameDirection.UPSTREAM)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
if isinstance(frame, SystemFrame):
|
||||
await self.__internal_push_frame(frame, direction)
|
||||
else:
|
||||
await self.__push_queue.put((frame, direction))
|
||||
|
||||
#
|
||||
# Handle interruptions
|
||||
#
|
||||
|
||||
async def _start_interruption(self):
|
||||
# Cancel the task. This will stop pushing frames downstream.
|
||||
self.__push_frame_task.cancel()
|
||||
await self.__push_frame_task
|
||||
|
||||
# Create a new queue and task.
|
||||
self.__create_push_task()
|
||||
|
||||
async def _stop_interruption(self):
|
||||
# Nothing to do right now.
|
||||
pass
|
||||
|
||||
def __create_push_task(self):
|
||||
self.__push_queue = asyncio.Queue()
|
||||
self.__push_frame_task = self.get_event_loop().create_task(self.__push_frame_task_handler())
|
||||
|
||||
async def __internal_push_frame(self, frame: Frame, direction: FrameDirection):
|
||||
try:
|
||||
if direction == FrameDirection.DOWNSTREAM and self._next:
|
||||
logger.trace(f"Pushing {frame} from {self} to {self._next}")
|
||||
@@ -202,5 +238,16 @@ class FrameProcessor:
|
||||
except Exception as e:
|
||||
logger.exception(f"Uncaught exception in {self}: {e}")
|
||||
|
||||
async def __push_frame_task_handler(self):
|
||||
running = True
|
||||
while running:
|
||||
try:
|
||||
(frame, direction) = await self.__push_queue.get()
|
||||
await self.__internal_push_frame(frame, direction)
|
||||
running = not isinstance(frame, EndFrame)
|
||||
self.__push_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
@@ -286,9 +286,6 @@ class RTVIProcessor(FrameProcessor):
|
||||
self._registered_actions: Dict[str, RTVIAction] = {}
|
||||
self._registered_services: Dict[str, RTVIService] = {}
|
||||
|
||||
self._push_frame_task = self.get_event_loop().create_task(self._push_frame_task_handler())
|
||||
self._push_queue = asyncio.Queue()
|
||||
|
||||
self._message_task = self.get_event_loop().create_task(self._message_task_handler())
|
||||
self._message_queue = asyncio.Queue()
|
||||
|
||||
@@ -335,12 +332,6 @@ class RTVIProcessor(FrameProcessor):
|
||||
message = RTVILLMFunctionCallStartMessage(data=fn)
|
||||
await self._push_transport_message(message, exclude_none=False)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
if isinstance(frame, SystemFrame):
|
||||
await super().push_frame(frame, direction)
|
||||
else:
|
||||
await self._internal_push_frame(frame, direction)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@@ -394,30 +385,10 @@ class RTVIProcessor(FrameProcessor):
|
||||
# processing EndFrames.
|
||||
self._message_task.cancel()
|
||||
await self._message_task
|
||||
await self._push_frame_task
|
||||
|
||||
async def _cancel(self, frame: CancelFrame):
|
||||
self._message_task.cancel()
|
||||
await self._message_task
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
|
||||
async def _internal_push_frame(
|
||||
self,
|
||||
frame: Frame | None,
|
||||
direction: FrameDirection | None = FrameDirection.DOWNSTREAM):
|
||||
await self._push_queue.put((frame, direction))
|
||||
|
||||
async def _push_frame_task_handler(self):
|
||||
running = True
|
||||
while running:
|
||||
try:
|
||||
(frame, direction) = await self._push_queue.get()
|
||||
await super().push_frame(frame, direction)
|
||||
self._push_queue.task_done()
|
||||
running = not isinstance(frame, EndFrame)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
async def _push_transport_message(self, model: BaseModel, exclude_none: bool = True):
|
||||
frame = TransportMessageFrame(
|
||||
|
||||
@@ -62,10 +62,6 @@ class GStreamerPipelineSource(FrameProcessor):
|
||||
bus.add_signal_watch()
|
||||
bus.connect("message", self._on_gstreamer_message)
|
||||
|
||||
# Create push frame task. This is the task that will push frames in
|
||||
# order. We also guarantee that all frames are pushed in the same task.
|
||||
self._create_push_task()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@@ -80,60 +76,28 @@ class GStreamerPipelineSource(FrameProcessor):
|
||||
elif isinstance(frame, StartFrame):
|
||||
# Push StartFrame before start(), because we want StartFrame to be
|
||||
# processed by every processor before any other frame is processed.
|
||||
await self._internal_push_frame(frame, direction)
|
||||
await self.push_frame(frame, direction)
|
||||
await self._start(frame)
|
||||
elif isinstance(frame, EndFrame):
|
||||
# Push EndFrame before stop(), because stop() waits on the task to
|
||||
# finish and the task finishes when EndFrame is processed.
|
||||
await self._internal_push_frame(frame, direction)
|
||||
await self.push_frame(frame, direction)
|
||||
await self._stop(frame)
|
||||
# Other frames
|
||||
else:
|
||||
await self._internal_push_frame(frame, direction)
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _start(self, frame: StartFrame):
|
||||
self._player.set_state(Gst.State.PLAYING)
|
||||
|
||||
async def _stop(self, frame: EndFrame):
|
||||
self._player.set_state(Gst.State.NULL)
|
||||
# Wait for the push frame task to finish. It will finish when the
|
||||
# EndFrame is actually processed.
|
||||
await self._push_frame_task
|
||||
|
||||
async def _cancel(self, frame: CancelFrame):
|
||||
self._player.set_state(Gst.State.NULL)
|
||||
# Cancel all the tasks and wait for them to finish.
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
|
||||
#
|
||||
# Push frames task
|
||||
#
|
||||
|
||||
def _create_push_task(self):
|
||||
loop = self.get_event_loop()
|
||||
self._push_queue = asyncio.Queue()
|
||||
self._push_frame_task = loop.create_task(self._push_frame_task_handler())
|
||||
|
||||
async def _internal_push_frame(
|
||||
self,
|
||||
frame: Frame | None,
|
||||
direction: FrameDirection | None = FrameDirection.DOWNSTREAM):
|
||||
await self._push_queue.put((frame, direction))
|
||||
|
||||
async def _push_frame_task_handler(self):
|
||||
running = True
|
||||
while running:
|
||||
try:
|
||||
(frame, direction) = await self._push_queue.get()
|
||||
await self.push_frame(frame, direction)
|
||||
running = not isinstance(frame, EndFrame)
|
||||
self._push_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
#
|
||||
# GStreaner
|
||||
# GStreamer
|
||||
#
|
||||
|
||||
def _on_gstreamer_message(self, bus: Gst.Bus, message: Gst.Message):
|
||||
@@ -221,7 +185,7 @@ class GStreamerPipelineSource(FrameProcessor):
|
||||
frame = AudioRawFrame(audio=info.data,
|
||||
sample_rate=self._out_params.audio_sample_rate,
|
||||
num_channels=self._out_params.audio_channels)
|
||||
asyncio.run_coroutine_threadsafe(self._internal_push_frame(frame), self.get_event_loop())
|
||||
asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.get_event_loop())
|
||||
buffer.unmap(info)
|
||||
return Gst.FlowReturn.OK
|
||||
|
||||
@@ -232,6 +196,6 @@ class GStreamerPipelineSource(FrameProcessor):
|
||||
image=info.data,
|
||||
size=(self._out_params.video_width, self._out_params.video_height),
|
||||
format="RGB")
|
||||
asyncio.run_coroutine_threadsafe(self._internal_push_frame(frame), self.get_event_loop())
|
||||
asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.get_event_loop())
|
||||
buffer.unmap(info)
|
||||
return Gst.FlowReturn.OK
|
||||
|
||||
@@ -8,19 +8,14 @@ import asyncio
|
||||
|
||||
from typing import Awaitable, Callable, List
|
||||
|
||||
from pipecat.frames.frames import Frame, SystemFrame
|
||||
from pipecat.processors.async_frame_processor import AsyncFrameProcessor
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.frames.frames import Frame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class IdleFrameProcessor(AsyncFrameProcessor):
|
||||
class IdleFrameProcessor(FrameProcessor):
|
||||
"""This class waits to receive any frame or list of desired frames within a
|
||||
given timeout. If the timeout is reached before receiving any of those
|
||||
frames the provided callback will be called.
|
||||
|
||||
The callback can then be used to push frames downstream by using
|
||||
`queue_frame()` (or `push_frame()` for system frames).
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -41,10 +36,7 @@ class IdleFrameProcessor(AsyncFrameProcessor):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, SystemFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.queue_frame(frame, direction)
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
# If we are not waiting for any specific frame set the event, otherwise
|
||||
# check if we have received one of the desired frames.
|
||||
@@ -55,7 +47,6 @@ class IdleFrameProcessor(AsyncFrameProcessor):
|
||||
if isinstance(frame, t):
|
||||
self._idle_event.set()
|
||||
|
||||
# If we are not waiting for any specific frame set the event, otherwise
|
||||
async def cleanup(self):
|
||||
self._idle_task.cancel()
|
||||
await self._idle_task
|
||||
|
||||
@@ -11,21 +11,16 @@ from typing import Awaitable, Callable
|
||||
from pipecat.frames.frames import (
|
||||
BotSpeakingFrame,
|
||||
Frame,
|
||||
SystemFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame)
|
||||
from pipecat.processors.async_frame_processor import AsyncFrameProcessor
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class UserIdleProcessor(AsyncFrameProcessor):
|
||||
class UserIdleProcessor(FrameProcessor):
|
||||
"""This class is useful to check if the user is interacting with the bot
|
||||
within a given timeout. If the timeout is reached before any interaction
|
||||
occurred the provided callback will be called.
|
||||
|
||||
The callback can then be used to push frames downstream by using
|
||||
`queue_frame()` (or `push_frame()` for system frames).
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -46,10 +41,7 @@ class UserIdleProcessor(AsyncFrameProcessor):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, SystemFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.queue_frame(frame, direction)
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
# We shouldn't call the idle callback if the user or the bot are speaking.
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
|
||||
@@ -30,9 +30,9 @@ from pipecat.frames.frames import (
|
||||
TTSVoiceUpdateFrame,
|
||||
TextFrame,
|
||||
UserImageRequestFrame,
|
||||
VisionImageRawFrame
|
||||
VisionImageRawFrame,
|
||||
UserImageRawFrame
|
||||
)
|
||||
from pipecat.processors.async_frame_processor import AsyncFrameProcessor
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.audio import calculate_audio_volume
|
||||
@@ -64,7 +64,7 @@ class AIService(FrameProcessor):
|
||||
elif isinstance(frame, EndFrame):
|
||||
await self.stop(frame)
|
||||
|
||||
async def process_generator(self, generator: AsyncGenerator[Frame, None]):
|
||||
async def process_generator(self, generator: AsyncGenerator[Frame | None, None]):
|
||||
async for f in generator:
|
||||
if f:
|
||||
if isinstance(f, ErrorFrame):
|
||||
@@ -73,30 +73,6 @@ class AIService(FrameProcessor):
|
||||
await self.push_frame(f)
|
||||
|
||||
|
||||
class AsyncAIService(AsyncFrameProcessor):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
pass
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
pass
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
pass
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartFrame):
|
||||
await self.start(frame)
|
||||
elif isinstance(frame, CancelFrame):
|
||||
await self.cancel(frame)
|
||||
elif isinstance(frame, EndFrame):
|
||||
await self.stop(frame)
|
||||
|
||||
|
||||
class LLMService(AIService):
|
||||
"""This class is a no-op but serves as a base class for LLM services."""
|
||||
|
||||
@@ -439,13 +415,14 @@ class VisionService(AIService):
|
||||
self._describe_text = None
|
||||
|
||||
@abstractmethod
|
||||
async def run_vision(self, frame: VisionImageRawFrame) -> AsyncGenerator[Frame, None]:
|
||||
async def run_vision(self, frame: VisionImageRawFrame |
|
||||
UserImageRawFrame) -> AsyncGenerator[Frame, None]:
|
||||
pass
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, VisionImageRawFrame):
|
||||
if isinstance(frame, VisionImageRawFrame) or isinstance(frame, UserImageRawFrame):
|
||||
await self.start_processing_metrics()
|
||||
await self.process_generator(self.run_vision(frame))
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
@@ -18,13 +18,11 @@ from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
SystemFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
TranscriptionFrame,
|
||||
URLImageRawFrame)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import AsyncAIService, TTSService, ImageGenService
|
||||
from pipecat.services.ai_services import STTService, TTSService, ImageGenService
|
||||
from pipecat.services.openai import BaseOpenAILLMService
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
@@ -118,7 +116,7 @@ class AzureTTSService(TTSService):
|
||||
logger.error(f"{self} error: {cancellation_details.error_details}")
|
||||
|
||||
|
||||
class AzureSTTService(AsyncAIService):
|
||||
class AzureSTTService(STTService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -141,15 +139,11 @@ class AzureSTTService(AsyncAIService):
|
||||
speech_config=speech_config, audio_config=audio_config)
|
||||
self._speech_recognizer.recognized.connect(self._on_handle_recognized)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, SystemFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, AudioRawFrame):
|
||||
self._audio_stream.write(frame.audio)
|
||||
else:
|
||||
await self._push_queue.put((frame, direction))
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
await self.start_processing_metrics()
|
||||
self._audio_stream.write(audio)
|
||||
await self.stop_processing_metrics()
|
||||
yield None
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
@@ -168,7 +162,7 @@ class AzureSTTService(AsyncAIService):
|
||||
def _on_handle_recognized(self, event):
|
||||
if event.result.reason == ResultReason.RecognizedSpeech and len(event.result.text) > 0:
|
||||
frame = TranscriptionFrame(event.result.text, "", time_now_iso8601())
|
||||
asyncio.run_coroutine_threadsafe(self.queue_frame(frame), self.get_event_loop())
|
||||
asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.get_event_loop())
|
||||
|
||||
|
||||
class AzureImageGenServiceREST(ImageGenService):
|
||||
|
||||
@@ -161,8 +161,8 @@ class DeepgramSTTService(STTService):
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
await self.start_processing_metrics()
|
||||
await self._connection.send(audio)
|
||||
yield None
|
||||
await self.stop_processing_metrics()
|
||||
yield None
|
||||
|
||||
async def _connect(self):
|
||||
if await self._connection.start(self._live_options):
|
||||
|
||||
@@ -7,20 +7,17 @@
|
||||
import base64
|
||||
import json
|
||||
|
||||
from typing import Optional
|
||||
from typing import AsyncGenerator, Optional
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
SystemFrame,
|
||||
TranscriptionFrame)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import AsyncAIService
|
||||
from pipecat.services.ai_services import STTService
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
from loguru import logger
|
||||
@@ -35,7 +32,7 @@ except ModuleNotFoundError as e:
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class GladiaSTTService(AsyncAIService):
|
||||
class GladiaSTTService(STTService):
|
||||
class InputParams(BaseModel):
|
||||
sample_rate: Optional[int] = 16000
|
||||
language: Optional[str] = "english"
|
||||
@@ -57,16 +54,6 @@ class GladiaSTTService(AsyncAIService):
|
||||
self._params = params
|
||||
self._confidence = confidence
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, SystemFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, AudioRawFrame):
|
||||
await self._send_audio(frame)
|
||||
else:
|
||||
await self.queue_frame(frame, direction)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
self._websocket = await websockets.connect(self._url)
|
||||
@@ -81,6 +68,12 @@ class GladiaSTTService(AsyncAIService):
|
||||
await super().cancel(frame)
|
||||
await self._websocket.close()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
await self.start_processing_metrics()
|
||||
await self._send_audio(audio)
|
||||
await self.stop_processing_metrics()
|
||||
yield None
|
||||
|
||||
async def _setup_gladia(self):
|
||||
configuration = {
|
||||
"x_gladia_key": self._api_key,
|
||||
@@ -92,9 +85,9 @@ class GladiaSTTService(AsyncAIService):
|
||||
|
||||
await self._websocket.send(json.dumps(configuration))
|
||||
|
||||
async def _send_audio(self, frame: AudioRawFrame):
|
||||
async def _send_audio(self, audio: bytes):
|
||||
message = {
|
||||
'frames': base64.b64encode(frame.audio).decode("utf-8")
|
||||
'frames': base64.b64encode(audio).decode("utf-8")
|
||||
}
|
||||
await self._websocket.send(json.dumps(message))
|
||||
|
||||
@@ -113,6 +106,6 @@ class GladiaSTTService(AsyncAIService):
|
||||
transcript = utterance["transcription"]
|
||||
if confidence >= self._confidence:
|
||||
if type == "final":
|
||||
await self.queue_frame(TranscriptionFrame(transcript, "", time_now_iso8601()))
|
||||
await self.push_frame(TranscriptionFrame(transcript, "", time_now_iso8601()))
|
||||
else:
|
||||
await self.queue_frame(InterimTranscriptionFrame(transcript, "", time_now_iso8601()))
|
||||
await self.push_frame(InterimTranscriptionFrame(transcript, "", time_now_iso8601()))
|
||||
|
||||
@@ -10,7 +10,13 @@ from PIL import Image
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from pipecat.frames.frames import ErrorFrame, Frame, TextFrame, VisionImageRawFrame
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TextFrame,
|
||||
ImageRawFrame,
|
||||
VisionImageRawFrame,
|
||||
UserImageRawFrame)
|
||||
from pipecat.services.ai_services import VisionService
|
||||
|
||||
from loguru import logger
|
||||
@@ -48,7 +54,7 @@ class MoondreamService(VisionService):
|
||||
self,
|
||||
*,
|
||||
model="vikhyatk/moondream2",
|
||||
revision="2024-04-02",
|
||||
revision="2024-08-26",
|
||||
use_cpu=False
|
||||
):
|
||||
super().__init__()
|
||||
@@ -70,23 +76,30 @@ class MoondreamService(VisionService):
|
||||
|
||||
logger.debug("Loaded Moondream model")
|
||||
|
||||
async def run_vision(self, frame: VisionImageRawFrame) -> AsyncGenerator[Frame, None]:
|
||||
async def run_vision(self, frame: VisionImageRawFrame |
|
||||
UserImageRawFrame) -> AsyncGenerator[Frame, None]:
|
||||
if not self._model:
|
||||
logger.error(f"{self} error: Moondream model not available")
|
||||
yield ErrorFrame("Moondream model not available")
|
||||
return
|
||||
|
||||
question = getattr(frame, "context", None) or getattr(frame, "text", None)
|
||||
|
||||
logger.debug(f"Analyzing image: {frame}")
|
||||
|
||||
def get_image_description(frame: VisionImageRawFrame):
|
||||
def get_image_description(frame: ImageRawFrame):
|
||||
image = Image.frombytes(frame.format, frame.size, frame.image)
|
||||
image_embeds = self._model.encode_image(image)
|
||||
description = self._model.answer_question(
|
||||
image_embeds=image_embeds,
|
||||
question=frame.text,
|
||||
question=question,
|
||||
tokenizer=self._tokenizer)
|
||||
return description
|
||||
|
||||
description = await asyncio.to_thread(get_image_description, frame)
|
||||
|
||||
yield TextFrame(text=description)
|
||||
if isinstance(frame, VisionImageRawFrame):
|
||||
yield TextFrame(text=description)
|
||||
elif isinstance(frame, UserImageRawFrame):
|
||||
frame.description = description
|
||||
yield frame
|
||||
|
||||
@@ -18,8 +18,6 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
LLMModelUpdateFrame,
|
||||
TextFrame,
|
||||
VisionImageRawFrame,
|
||||
UserImageRequestFrame,
|
||||
UserImageRawFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
@@ -100,8 +98,12 @@ class TogetherLLMService(LLMService):
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Function calling
|
||||
got_first_chunk = False
|
||||
|
||||
# Function calling. We should be able to prompt Llama 3.1 to always return either plain
|
||||
# text or a function call. However, occasionally we see a function call after plain text.
|
||||
# Try to account for that.
|
||||
most_recent_chunk_was_function_call_start_char = False # function call start char is '<'
|
||||
accumulating_function_call = False
|
||||
function_call_accumulator = ""
|
||||
|
||||
@@ -131,10 +133,24 @@ class TogetherLLMService(LLMService):
|
||||
if accumulating_function_call:
|
||||
function_call_accumulator += chunk.choices[0].delta.content
|
||||
else:
|
||||
await self.push_frame(TextFrame(chunk.choices[0].delta.content))
|
||||
text = chunk.choices[0].delta.content
|
||||
if most_recent_chunk_was_function_call_start_char:
|
||||
most_recent_chunk_was_function_call_start_char = False
|
||||
if text == "function":
|
||||
accumulating_function_call = True
|
||||
function_call_accumulator = "<function"
|
||||
else:
|
||||
await self.push_frame("<" + TextFrame(chunk.choices[0].delta.content))
|
||||
elif text == '<':
|
||||
most_recent_chunk_was_function_call_start_char = True
|
||||
else:
|
||||
await self.push_frame(TextFrame(chunk.choices[0].delta.content))
|
||||
|
||||
if chunk.choices[0].finish_reason == 'eos' and accumulating_function_call:
|
||||
await self._extract_function_call(context, function_call_accumulator)
|
||||
if chunk.choices[0].finish_reason == 'eos':
|
||||
if accumulating_function_call:
|
||||
await self._extract_function_call(context, function_call_accumulator)
|
||||
elif most_recent_chunk_was_function_call_start_char:
|
||||
await self.push_frame(TextFrame("<"))
|
||||
|
||||
except CancelledError as e:
|
||||
# todo: implement token counting estimates for use when the user interrupts a long generation
|
||||
@@ -164,13 +180,26 @@ class TogetherLLMService(LLMService):
|
||||
await self._process_context(context)
|
||||
|
||||
async def _extract_function_call(self, context, function_call_accumulator):
|
||||
# logger.debug(f"Extracting function call: {function_call_accumulator}")
|
||||
context.add_message({"role": "assistant", "content": function_call_accumulator})
|
||||
|
||||
function_regex = r"<function=(\w+)>(.*?)</function>"
|
||||
# Function format regex. Llama 3.1 sometimes adds an extra " or space just before the
|
||||
# </function> tag. Those extra characters are stripped from the argument string afterwards (that's
|
||||
# the re.sub(r'[\s"]+$', ...) call below). Occasionally the </function> close tag is also missing; the second alternative of the regex handles that case.
|
||||
function_regex = r'<function=(\w+)>(.*?)<\/function>|<function=(\w+)>(.*)'
|
||||
match = re.search(function_regex, function_call_accumulator)
|
||||
if match:
|
||||
function_name, args_string = match.groups()
|
||||
function_name = ""
|
||||
args_string = ""
|
||||
if match.group(1): # Case with closing tag
|
||||
function_name = match.group(1)
|
||||
args_string = match.group(2)
|
||||
else: # Case without closing tag
|
||||
function_name = match.group(3)
|
||||
args_string = match.group(4)
|
||||
|
||||
try:
|
||||
args_string = re.sub(r'[\s"]+$', '', args_string)
|
||||
arguments = json.loads(args_string)
|
||||
await self.call_function(context=context,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
@@ -181,7 +210,8 @@ class TogetherLLMService(LLMService):
|
||||
# We get here if the LLM returns a function call with invalid JSON arguments. This could happen
|
||||
# because of LLM non-determinism, or maybe more often because of user error in the prompt.
|
||||
# Should we do anything more than log a warning?
|
||||
logger.debug(f"Error parsing function arguments: {error}")
|
||||
logger.debug(
|
||||
f"Error parsing function arguments: {error} - {function_call_accumulator}")
|
||||
|
||||
|
||||
class TogetherLLMContext(OpenAILLMContext):
|
||||
@@ -219,9 +249,17 @@ class TogetherUserContextAggregator(LLMUserContextAggregator):
|
||||
if isinstance(context, OpenAILLMContext):
|
||||
self._context = TogetherLLMContext.from_openai_context(context)
|
||||
|
||||
def get_messages_frame(self):
|
||||
return OpenAILLMContextFrame(self._context)
|
||||
|
||||
async def push_messages_frame(self):
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
await self.push_frame(self.get_messages_frame())
|
||||
|
||||
def append_image_description_tool_message(self, description):
|
||||
self._context.add_message({
|
||||
"role": "tool",
|
||||
"content": json.dumps({"image_description": description})
|
||||
})
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
@@ -230,20 +268,10 @@ class TogetherUserContextAggregator(LLMUserContextAggregator):
|
||||
# to talk through (tagging @aleix). At some point we might need to refactor these
|
||||
# context aggregators.
|
||||
try:
|
||||
if isinstance(frame, UserImageRequestFrame):
|
||||
# The LLM sends a UserImageRequestFrame upstream. Cache any context provided with
|
||||
# that frame so we can use it when we assemble the image message in the assistant
|
||||
# context aggregator.
|
||||
if (frame.context):
|
||||
if isinstance(frame.context, str):
|
||||
self._context._user_image_request_context[frame.user_id] = frame.context
|
||||
else:
|
||||
logger.error(
|
||||
f"Unexpected UserImageRequestFrame context type: {type(frame.context)}")
|
||||
del self._context._user_image_request_context[frame.user_id]
|
||||
else:
|
||||
if frame.user_id in self._context._user_image_request_context:
|
||||
del self._context._user_image_request_context[frame.user_id]
|
||||
if isinstance(frame, UserImageRawFrame):
|
||||
if frame.description:
|
||||
self.append_image_description_tool_message(frame.description)
|
||||
await self.push_messages_frame()
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
|
||||
|
||||
@@ -37,10 +37,6 @@ class BaseInputTransport(FrameProcessor):
|
||||
|
||||
self._executor = ThreadPoolExecutor(max_workers=5)
|
||||
|
||||
# Create push frame task. This is the task that will push frames in
|
||||
# order. We also guarantee that all frames are pushed in the same task.
|
||||
self._create_push_task()
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
# Create audio input queue and task if needed.
|
||||
if self._params.audio_in_enabled or self._params.vad_enabled:
|
||||
@@ -53,10 +49,6 @@ class BaseInputTransport(FrameProcessor):
|
||||
self._audio_task.cancel()
|
||||
await self._audio_task
|
||||
|
||||
# Wait for the push frame task to finish. It will finish when the
|
||||
# EndFrame is actually processed.
|
||||
await self._push_frame_task
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
# Cancel all the tasks and wait for them to finish.
|
||||
|
||||
@@ -64,9 +56,6 @@ class BaseInputTransport(FrameProcessor):
|
||||
self._audio_task.cancel()
|
||||
await self._audio_task
|
||||
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
|
||||
def vad_analyzer(self) -> VADAnalyzer | None:
|
||||
return self._params.vad_analyzer
|
||||
|
||||
@@ -86,11 +75,8 @@ class BaseInputTransport(FrameProcessor):
|
||||
await self.cancel(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, BotInterruptionFrame):
|
||||
await self._handle_interruptions(frame, False)
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
logger.debug("Bot interruption")
|
||||
await self._start_interruption()
|
||||
elif isinstance(frame, StopInterruptionFrame):
|
||||
await self._stop_interruption()
|
||||
# All other system frames
|
||||
elif isinstance(frame, SystemFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -98,12 +84,12 @@ class BaseInputTransport(FrameProcessor):
|
||||
elif isinstance(frame, StartFrame):
|
||||
# Push StartFrame before start(), because we want StartFrame to be
|
||||
# processed by every processor before any other frame is processed.
|
||||
await self._internal_push_frame(frame, direction)
|
||||
await self.push_frame(frame, direction)
|
||||
await self.start(frame)
|
||||
elif isinstance(frame, EndFrame):
|
||||
# Push EndFrame before stop(), because stop() waits on the task to
|
||||
# finish and the task finishes when EndFrame is processed.
|
||||
await self._internal_push_frame(frame, direction)
|
||||
await self.push_frame(frame, direction)
|
||||
await self.stop(frame)
|
||||
elif isinstance(frame, VADParamsUpdateFrame):
|
||||
vad_analyzer = self.vad_analyzer()
|
||||
@@ -111,73 +97,28 @@ class BaseInputTransport(FrameProcessor):
|
||||
vad_analyzer.set_params(frame.params)
|
||||
# Other frames
|
||||
else:
|
||||
await self._internal_push_frame(frame, direction)
|
||||
|
||||
#
|
||||
# Push frames task
|
||||
#
|
||||
|
||||
def _create_push_task(self):
|
||||
loop = self.get_event_loop()
|
||||
self._push_queue = asyncio.Queue()
|
||||
self._push_frame_task = loop.create_task(self._push_frame_task_handler())
|
||||
|
||||
async def _internal_push_frame(
|
||||
self,
|
||||
frame: Frame | None,
|
||||
direction: FrameDirection | None = FrameDirection.DOWNSTREAM):
|
||||
await self._push_queue.put((frame, direction))
|
||||
|
||||
async def _push_frame_task_handler(self):
|
||||
running = True
|
||||
while running:
|
||||
try:
|
||||
(frame, direction) = await self._push_queue.get()
|
||||
await self.push_frame(frame, direction)
|
||||
running = not isinstance(frame, EndFrame)
|
||||
self._push_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
#
|
||||
# Handle interruptions
|
||||
#
|
||||
|
||||
async def _start_interruption(self):
|
||||
if not self.interruptions_allowed:
|
||||
return
|
||||
|
||||
# Cancel the task. This will stop pushing frames downstream.
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
# Push an out-of-band frame (i.e. not using the ordered push
|
||||
# frame task) to stop everything, especially at the output
|
||||
# transport.
|
||||
await self.push_frame(StartInterruptionFrame())
|
||||
# Create a new queue and task.
|
||||
self._create_push_task()
|
||||
|
||||
async def _stop_interruption(self):
|
||||
if not self.interruptions_allowed:
|
||||
return
|
||||
|
||||
await self.push_frame(StopInterruptionFrame())
|
||||
|
||||
async def _handle_interruptions(self, frame: Frame, push_frame: bool):
|
||||
async def _handle_interruptions(self, frame: Frame):
|
||||
if self.interruptions_allowed:
|
||||
# Make sure we notify about interruptions quickly out-of-band
|
||||
if isinstance(frame, BotInterruptionFrame):
|
||||
logger.debug("Bot interruption")
|
||||
await self._start_interruption()
|
||||
elif isinstance(frame, UserStartedSpeakingFrame):
|
||||
# Make sure we notify about interruptions quickly out-of-band.
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
logger.debug("User started speaking")
|
||||
await self._start_interruption()
|
||||
# Push an out-of-band frame (i.e. not using the ordered push
|
||||
# frame task) to stop everything, especially at the output
|
||||
# transport.
|
||||
await self.push_frame(StartInterruptionFrame())
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
logger.debug("User stopped speaking")
|
||||
await self._stop_interruption()
|
||||
await self.push_frame(StopInterruptionFrame())
|
||||
|
||||
if push_frame:
|
||||
await self._internal_push_frame(frame)
|
||||
await self.push_frame(frame)
|
||||
|
||||
#
|
||||
# Audio input
|
||||
@@ -201,7 +142,7 @@ class BaseInputTransport(FrameProcessor):
|
||||
frame = UserStoppedSpeakingFrame()
|
||||
|
||||
if frame:
|
||||
await self._handle_interruptions(frame, True)
|
||||
await self._handle_interruptions(frame)
|
||||
|
||||
vad_state = new_vad_state
|
||||
return vad_state
|
||||
@@ -222,7 +163,7 @@ class BaseInputTransport(FrameProcessor):
|
||||
|
||||
# Push audio downstream if passthrough.
|
||||
if audio_passthrough:
|
||||
await self._internal_push_frame(frame)
|
||||
await self.push_frame(frame)
|
||||
|
||||
self._audio_in_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
|
||||
@@ -66,10 +66,6 @@ class BaseOutputTransport(FrameProcessor):
|
||||
# generating frames upstream while, for example, the audio is playing.
|
||||
self._create_sink_task()
|
||||
|
||||
# Create push frame task. This is the task that will push frames in
|
||||
# order. We also guarantee that all frames are pushed in the same task.
|
||||
self._create_push_task()
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
# Create camera output queue and task if needed.
|
||||
if self._params.camera_out_enabled:
|
||||
@@ -91,9 +87,8 @@ class BaseOutputTransport(FrameProcessor):
|
||||
self._audio_out_task.cancel()
|
||||
await self._audio_out_task
|
||||
|
||||
# Wait for the push frame and sink tasks to finish. They will finish when
|
||||
# the EndFrame is actually processed.
|
||||
await self._push_frame_task
|
||||
# Wait for the sink task to finish. It will finish when the EndFrame
|
||||
# is actually processed.
|
||||
await self._sink_task
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
@@ -103,9 +98,6 @@ class BaseOutputTransport(FrameProcessor):
|
||||
self._camera_out_task.cancel()
|
||||
await self._camera_out_task
|
||||
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
|
||||
self._sink_task.cancel()
|
||||
await self._sink_task
|
||||
|
||||
@@ -170,10 +162,6 @@ class BaseOutputTransport(FrameProcessor):
|
||||
self._sink_task.cancel()
|
||||
await self._sink_task
|
||||
self._create_sink_task()
|
||||
# Stop push task.
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
self._create_push_task()
|
||||
# Let's send a bot stopped speaking if we have to.
|
||||
if self._bot_speaking:
|
||||
await self._bot_stopped_speaking()
|
||||
@@ -213,7 +201,7 @@ class BaseOutputTransport(FrameProcessor):
|
||||
frame = await self._sink_queue.get()
|
||||
if isinstance(frame, AudioRawFrame):
|
||||
await self.write_raw_audio_frames(frame.audio)
|
||||
await self._internal_push_frame(frame)
|
||||
await self.push_frame(frame)
|
||||
await self.push_frame(BotSpeakingFrame(), FrameDirection.UPSTREAM)
|
||||
elif isinstance(frame, ImageRawFrame):
|
||||
await self._set_camera_image(frame)
|
||||
@@ -223,12 +211,12 @@ class BaseOutputTransport(FrameProcessor):
|
||||
await self.send_message(frame)
|
||||
elif isinstance(frame, TTSStartedFrame):
|
||||
await self._bot_started_speaking()
|
||||
await self._internal_push_frame(frame)
|
||||
await self.push_frame(frame)
|
||||
elif isinstance(frame, TTSStoppedFrame):
|
||||
await self._bot_stopped_speaking()
|
||||
await self._internal_push_frame(frame)
|
||||
await self.push_frame(frame)
|
||||
else:
|
||||
await self._internal_push_frame(frame)
|
||||
await self.push_frame(frame)
|
||||
|
||||
running = not isinstance(frame, EndFrame)
|
||||
|
||||
@@ -241,38 +229,12 @@ class BaseOutputTransport(FrameProcessor):
|
||||
async def _bot_started_speaking(self):
|
||||
logger.debug("Bot started speaking")
|
||||
self._bot_speaking = True
|
||||
await self._internal_push_frame(BotStartedSpeakingFrame(), FrameDirection.UPSTREAM)
|
||||
await self.push_frame(BotStartedSpeakingFrame(), FrameDirection.UPSTREAM)
|
||||
|
||||
async def _bot_stopped_speaking(self):
|
||||
logger.debug("Bot stopped speaking")
|
||||
self._bot_speaking = False
|
||||
await self._internal_push_frame(BotStoppedSpeakingFrame(), FrameDirection.UPSTREAM)
|
||||
|
||||
#
|
||||
# Push frames task
|
||||
#
|
||||
|
||||
def _create_push_task(self):
|
||||
loop = self.get_event_loop()
|
||||
self._push_queue = asyncio.Queue()
|
||||
self._push_frame_task = loop.create_task(self._push_frame_task_handler())
|
||||
|
||||
async def _internal_push_frame(
|
||||
self,
|
||||
frame: Frame | None,
|
||||
direction: FrameDirection | None = FrameDirection.DOWNSTREAM):
|
||||
await self._push_queue.put((frame, direction))
|
||||
|
||||
async def _push_frame_task_handler(self):
|
||||
running = True
|
||||
while running:
|
||||
try:
|
||||
(frame, direction) = await self._push_queue.get()
|
||||
await self.push_frame(frame, direction)
|
||||
running = not isinstance(frame, EndFrame)
|
||||
self._push_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
await self.push_frame(BotStoppedSpeakingFrame(), FrameDirection.UPSTREAM)
|
||||
|
||||
#
|
||||
# Camera out
|
||||
@@ -356,7 +318,7 @@ class BaseOutputTransport(FrameProcessor):
|
||||
try:
|
||||
frame = await self._audio_out_queue.get()
|
||||
await self.write_raw_audio_frames(frame.audio)
|
||||
await self._internal_push_frame(frame)
|
||||
await self.push_frame(frame)
|
||||
await self.push_frame(BotSpeakingFrame(), FrameDirection.UPSTREAM)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
@@ -98,9 +98,9 @@ class WebsocketServerInputTransport(BaseInputTransport):
|
||||
continue
|
||||
|
||||
if isinstance(frame, AudioRawFrame):
|
||||
await self.push_audio_frame(frame)
|
||||
await self.queue_audio_frame(frame)
|
||||
else:
|
||||
await self._internal_push_frame(frame)
|
||||
await self.push_frame(frame)
|
||||
|
||||
# Notify disconnection
|
||||
await self._callbacks.on_client_disconnected(websocket)
|
||||
|
||||
@@ -612,18 +612,18 @@ class DailyInputTransport(BaseInputTransport):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, UserImageRequestFrame):
|
||||
self.request_participant_image(frame.user_id)
|
||||
self.request_participant_image(frame.user_id, frame.context)
|
||||
|
||||
#
|
||||
# Frames
|
||||
#
|
||||
|
||||
async def push_transcription_frame(self, frame: TranscriptionFrame | InterimTranscriptionFrame):
|
||||
await self._internal_push_frame(frame)
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def push_app_message(self, message: Any, sender: str):
|
||||
frame = DailyTransportMessageFrame(message=message, participant_id=sender)
|
||||
await self._internal_push_frame(frame)
|
||||
await self.push_frame(frame)
|
||||
|
||||
#
|
||||
# Audio in
|
||||
@@ -662,9 +662,10 @@ class DailyInputTransport(BaseInputTransport):
|
||||
color_format
|
||||
)
|
||||
|
||||
def request_participant_image(self, participant_id: str):
|
||||
def request_participant_image(self, participant_id: str, context: Any = None):
|
||||
if participant_id in self._video_renderers:
|
||||
self._video_renderers[participant_id]["render_next_frame"] = True
|
||||
truthy = context if context else True
|
||||
self._video_renderers[participant_id]["render_next_frame"] = truthy
|
||||
|
||||
async def _on_participant_video_frame(self, participant_id: str, buffer, size, format):
|
||||
render_frame = False
|
||||
@@ -677,16 +678,17 @@ class DailyInputTransport(BaseInputTransport):
|
||||
next_time = prev_time + 1 / framerate
|
||||
render_frame = (curr_time - next_time) < 0.1
|
||||
elif self._video_renderers[participant_id]["render_next_frame"]:
|
||||
render_frame = self._video_renderers[participant_id]["render_next_frame"]
|
||||
self._video_renderers[participant_id]["render_next_frame"] = False
|
||||
render_frame = True
|
||||
|
||||
if render_frame:
|
||||
frame = UserImageRawFrame(
|
||||
user_id=participant_id,
|
||||
image=buffer,
|
||||
size=size,
|
||||
format=format)
|
||||
await self._internal_push_frame(frame)
|
||||
format=format,
|
||||
context=None if render_frame is True else render_frame)
|
||||
await self.push_frame(frame)
|
||||
|
||||
self._video_renderers[participant_id]["timestamp"] = curr_time
|
||||
|
||||
|
||||
Reference in New Issue
Block a user