Compare commits

...

3 Commits

Author SHA1 Message Date
Kwindla Hultman Kramer
90c64c3df6 wip together function calling and vision improvements 2024-09-01 14:15:46 -07:00
Kwindla Hultman Kramer
c2bc64361a first working llama-vision 2024-09-01 14:10:50 -07:00
Aleix Conchillo Flaqué
9bbb824248 introduce FrameProcessor single async push frame task
Pipecat has a pipeline-based architecture. The pipeline consists of frame
processors linked to each other. The elements travelling across the pipeline are
called frames.

To have a deterministic behavior the frames travelling through the pipeline are
awlays ordered, except system frames which are out-of-band frames.

To achieve ordering each frame processor only outputs frames from a single
task. This single task is internally created in each FrameProcessor and the
developer doesn't need to know much about it, except conceptually.

Having a single output task avoids problems such as audio overlapping.
2024-09-01 11:44:00 -07:00
19 changed files with 202 additions and 390 deletions

View File

@@ -70,7 +70,7 @@ async def main():
async def user_idle_callback(user_idle: UserIdleProcessor):
messages.append(
{"role": "system", "content": "Ask the user if they are still there and try to prompt for some input, but be short."})
await user_idle.queue_frame(LLMMessagesFrame(messages))
await user_idle.push_frame(LLMMessagesFrame(messages))
user_idle = UserIdleProcessor(callback=user_idle_callback, timeout=5.0)

View File

@@ -94,6 +94,8 @@ class UserImageRawFrame(ImageRawFrame):
"""
user_id: str
context: Any = None
description: str | None = None
def __str__(self):
return f"{self.name}(user: {self.user_id}, size: {self.size}, format: {self.format})"
@@ -423,7 +425,7 @@ class TTSStoppedFrame(ControlFrame):
class UserImageRequestFrame(ControlFrame):
"""A frame user to request an image from the given user."""
user_id: str
context: Optional[Any] = None
context: Any = None
def __str__(self):
return f"{self.name}, user: {self.user_id}"

View File

@@ -4,7 +4,6 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
import sys
from typing import List
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame, OpenAILLMContext

View File

@@ -1,64 +0,0 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
from pipecat.frames.frames import EndFrame, Frame, StartInterruptionFrame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
class AsyncFrameProcessor(FrameProcessor):
def __init__(
self,
*,
name: str | None = None,
loop: asyncio.AbstractEventLoop | None = None,
**kwargs):
super().__init__(name=name, loop=loop, **kwargs)
self._create_push_task()
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, StartInterruptionFrame):
await self._handle_interruptions(frame)
async def queue_frame(
self,
frame: Frame,
direction: FrameDirection = FrameDirection.DOWNSTREAM):
await self._push_queue.put((frame, direction))
async def cleanup(self):
self._push_frame_task.cancel()
await self._push_frame_task
async def _handle_interruptions(self, frame: Frame):
# Cancel the task. This will stop pushing frames downstream.
self._push_frame_task.cancel()
await self._push_frame_task
# Push an out-of-band frame (i.e. not using the ordered push
# frame task).
await self.push_frame(frame)
# Create a new queue and task.
self._create_push_task()
def _create_push_task(self):
self._push_queue = asyncio.Queue()
self._push_frame_task = self.get_event_loop().create_task(self._push_frame_task_handler())
async def _push_frame_task_handler(self):
running = True
while running:
try:
(frame, direction) = await self._push_queue.get()
await self.push_frame(frame, direction)
running = not isinstance(frame, EndFrame)
self._push_queue.task_done()
except asyncio.CancelledError:
break

View File

@@ -10,12 +10,14 @@ import time
from enum import Enum
from pipecat.frames.frames import (
EndFrame,
ErrorFrame,
Frame,
MetricsFrame,
StartFrame,
StartInterruptionFrame,
UserStoppedSpeakingFrame)
StopInterruptionFrame,
SystemFrame)
from pipecat.utils.utils import obj_count, obj_id
from loguru import logger
@@ -105,6 +107,13 @@ class FrameProcessor:
# Metrics
self._metrics = FrameProcessorMetrics(name=self.name)
# Every processor in Pipecat should only output frames from a single
# task. This avoid problems like audio overlapping. System frames are
# the exception to this rule.
#
# This create this task.
self.__create_push_task()
@property
def interruptions_allowed(self):
return self._allow_interruptions
@@ -184,14 +193,41 @@ class FrameProcessor:
self._enable_usage_metrics = frame.enable_usage_metrics
self._report_only_initial_ttfb = frame.report_only_initial_ttfb
elif isinstance(frame, StartInterruptionFrame):
await self._start_interruption()
await self.stop_all_metrics()
elif isinstance(frame, UserStoppedSpeakingFrame):
elif isinstance(frame, StopInterruptionFrame):
self._should_report_ttfb = True
async def push_error(self, error: ErrorFrame):
await self.push_frame(error, FrameDirection.UPSTREAM)
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
if isinstance(frame, SystemFrame):
await self.__internal_push_frame(frame, direction)
else:
await self.__push_queue.put((frame, direction))
#
# Handle interruptions
#
async def _start_interruption(self):
# Cancel the task. This will stop pushing frames downstream.
self.__push_frame_task.cancel()
await self.__push_frame_task
# Create a new queue and task.
self.__create_push_task()
async def _stop_interruption(self):
# Nothing to do right now.
pass
def __create_push_task(self):
self.__push_queue = asyncio.Queue()
self.__push_frame_task = self.get_event_loop().create_task(self.__push_frame_task_handler())
async def __internal_push_frame(self, frame: Frame, direction: FrameDirection):
try:
if direction == FrameDirection.DOWNSTREAM and self._next:
logger.trace(f"Pushing {frame} from {self} to {self._next}")
@@ -202,5 +238,16 @@ class FrameProcessor:
except Exception as e:
logger.exception(f"Uncaught exception in {self}: {e}")
async def __push_frame_task_handler(self):
running = True
while running:
try:
(frame, direction) = await self.__push_queue.get()
await self.__internal_push_frame(frame, direction)
running = not isinstance(frame, EndFrame)
self.__push_queue.task_done()
except asyncio.CancelledError:
break
def __str__(self):
return self.name

View File

@@ -286,9 +286,6 @@ class RTVIProcessor(FrameProcessor):
self._registered_actions: Dict[str, RTVIAction] = {}
self._registered_services: Dict[str, RTVIService] = {}
self._push_frame_task = self.get_event_loop().create_task(self._push_frame_task_handler())
self._push_queue = asyncio.Queue()
self._message_task = self.get_event_loop().create_task(self._message_task_handler())
self._message_queue = asyncio.Queue()
@@ -335,12 +332,6 @@ class RTVIProcessor(FrameProcessor):
message = RTVILLMFunctionCallStartMessage(data=fn)
await self._push_transport_message(message, exclude_none=False)
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
if isinstance(frame, SystemFrame):
await super().push_frame(frame, direction)
else:
await self._internal_push_frame(frame, direction)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
@@ -394,30 +385,10 @@ class RTVIProcessor(FrameProcessor):
# processing EndFrames.
self._message_task.cancel()
await self._message_task
await self._push_frame_task
async def _cancel(self, frame: CancelFrame):
self._message_task.cancel()
await self._message_task
self._push_frame_task.cancel()
await self._push_frame_task
async def _internal_push_frame(
self,
frame: Frame | None,
direction: FrameDirection | None = FrameDirection.DOWNSTREAM):
await self._push_queue.put((frame, direction))
async def _push_frame_task_handler(self):
running = True
while running:
try:
(frame, direction) = await self._push_queue.get()
await super().push_frame(frame, direction)
self._push_queue.task_done()
running = not isinstance(frame, EndFrame)
except asyncio.CancelledError:
break
async def _push_transport_message(self, model: BaseModel, exclude_none: bool = True):
frame = TransportMessageFrame(

View File

@@ -62,10 +62,6 @@ class GStreamerPipelineSource(FrameProcessor):
bus.add_signal_watch()
bus.connect("message", self._on_gstreamer_message)
# Create push frame task. This is the task that will push frames in
# order. We also guarantee that all frames are pushed in the same task.
self._create_push_task()
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
@@ -80,60 +76,28 @@ class GStreamerPipelineSource(FrameProcessor):
elif isinstance(frame, StartFrame):
# Push StartFrame before start(), because we want StartFrame to be
# processed by every processor before any other frame is processed.
await self._internal_push_frame(frame, direction)
await self.push_frame(frame, direction)
await self._start(frame)
elif isinstance(frame, EndFrame):
# Push EndFrame before stop(), because stop() waits on the task to
# finish and the task finishes when EndFrame is processed.
await self._internal_push_frame(frame, direction)
await self.push_frame(frame, direction)
await self._stop(frame)
# Other frames
else:
await self._internal_push_frame(frame, direction)
await self.push_frame(frame, direction)
async def _start(self, frame: StartFrame):
self._player.set_state(Gst.State.PLAYING)
async def _stop(self, frame: EndFrame):
self._player.set_state(Gst.State.NULL)
# Wait for the push frame task to finish. It will finish when the
# EndFrame is actually processed.
await self._push_frame_task
async def _cancel(self, frame: CancelFrame):
self._player.set_state(Gst.State.NULL)
# Cancel all the tasks and wait for them to finish.
self._push_frame_task.cancel()
await self._push_frame_task
#
# Push frames task
#
def _create_push_task(self):
loop = self.get_event_loop()
self._push_queue = asyncio.Queue()
self._push_frame_task = loop.create_task(self._push_frame_task_handler())
async def _internal_push_frame(
self,
frame: Frame | None,
direction: FrameDirection | None = FrameDirection.DOWNSTREAM):
await self._push_queue.put((frame, direction))
async def _push_frame_task_handler(self):
running = True
while running:
try:
(frame, direction) = await self._push_queue.get()
await self.push_frame(frame, direction)
running = not isinstance(frame, EndFrame)
self._push_queue.task_done()
except asyncio.CancelledError:
break
#
# GStreaner
# GStreamer
#
def _on_gstreamer_message(self, bus: Gst.Bus, message: Gst.Message):
@@ -221,7 +185,7 @@ class GStreamerPipelineSource(FrameProcessor):
frame = AudioRawFrame(audio=info.data,
sample_rate=self._out_params.audio_sample_rate,
num_channels=self._out_params.audio_channels)
asyncio.run_coroutine_threadsafe(self._internal_push_frame(frame), self.get_event_loop())
asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.get_event_loop())
buffer.unmap(info)
return Gst.FlowReturn.OK
@@ -232,6 +196,6 @@ class GStreamerPipelineSource(FrameProcessor):
image=info.data,
size=(self._out_params.video_width, self._out_params.video_height),
format="RGB")
asyncio.run_coroutine_threadsafe(self._internal_push_frame(frame), self.get_event_loop())
asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.get_event_loop())
buffer.unmap(info)
return Gst.FlowReturn.OK

View File

@@ -8,19 +8,14 @@ import asyncio
from typing import Awaitable, Callable, List
from pipecat.frames.frames import Frame, SystemFrame
from pipecat.processors.async_frame_processor import AsyncFrameProcessor
from pipecat.processors.frame_processor import FrameDirection
from pipecat.frames.frames import Frame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
class IdleFrameProcessor(AsyncFrameProcessor):
class IdleFrameProcessor(FrameProcessor):
"""This class waits to receive any frame or list of desired frames within a
given timeout. If the timeout is reached before receiving any of those
frames the provided callback will be called.
The callback can then be used to push frames downstream by using
`queue_frame()` (or `push_frame()` for system frames).
"""
def __init__(
@@ -41,10 +36,7 @@ class IdleFrameProcessor(AsyncFrameProcessor):
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, SystemFrame):
await self.push_frame(frame, direction)
else:
await self.queue_frame(frame, direction)
await self.push_frame(frame, direction)
# If we are not waiting for any specific frame set the event, otherwise
# check if we have received one of the desired frames.
@@ -55,7 +47,6 @@ class IdleFrameProcessor(AsyncFrameProcessor):
if isinstance(frame, t):
self._idle_event.set()
# If we are not waiting for any specific frame set the event, otherwise
async def cleanup(self):
self._idle_task.cancel()
await self._idle_task

View File

@@ -11,21 +11,16 @@ from typing import Awaitable, Callable
from pipecat.frames.frames import (
BotSpeakingFrame,
Frame,
SystemFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame)
from pipecat.processors.async_frame_processor import AsyncFrameProcessor
from pipecat.processors.frame_processor import FrameDirection
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
class UserIdleProcessor(AsyncFrameProcessor):
class UserIdleProcessor(FrameProcessor):
"""This class is useful to check if the user is interacting with the bot
within a given timeout. If the timeout is reached before any interaction
occurred the provided callback will be called.
The callback can then be used to push frames downstream by using
`queue_frame()` (or `push_frame()` for system frames).
"""
def __init__(
@@ -46,10 +41,7 @@ class UserIdleProcessor(AsyncFrameProcessor):
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, SystemFrame):
await self.push_frame(frame, direction)
else:
await self.queue_frame(frame, direction)
await self.push_frame(frame, direction)
# We shouldn't call the idle callback if the user or the bot are speaking.
if isinstance(frame, UserStartedSpeakingFrame):

View File

@@ -30,9 +30,9 @@ from pipecat.frames.frames import (
TTSVoiceUpdateFrame,
TextFrame,
UserImageRequestFrame,
VisionImageRawFrame
VisionImageRawFrame,
UserImageRawFrame
)
from pipecat.processors.async_frame_processor import AsyncFrameProcessor
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.transcriptions.language import Language
from pipecat.utils.audio import calculate_audio_volume
@@ -64,7 +64,7 @@ class AIService(FrameProcessor):
elif isinstance(frame, EndFrame):
await self.stop(frame)
async def process_generator(self, generator: AsyncGenerator[Frame, None]):
async def process_generator(self, generator: AsyncGenerator[Frame | None, None]):
async for f in generator:
if f:
if isinstance(f, ErrorFrame):
@@ -73,30 +73,6 @@ class AIService(FrameProcessor):
await self.push_frame(f)
class AsyncAIService(AsyncFrameProcessor):
def __init__(self, **kwargs):
super().__init__(**kwargs)
async def start(self, frame: StartFrame):
pass
async def stop(self, frame: EndFrame):
pass
async def cancel(self, frame: CancelFrame):
pass
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, StartFrame):
await self.start(frame)
elif isinstance(frame, CancelFrame):
await self.cancel(frame)
elif isinstance(frame, EndFrame):
await self.stop(frame)
class LLMService(AIService):
"""This class is a no-op but serves as a base class for LLM services."""
@@ -439,13 +415,14 @@ class VisionService(AIService):
self._describe_text = None
@abstractmethod
async def run_vision(self, frame: VisionImageRawFrame) -> AsyncGenerator[Frame, None]:
async def run_vision(self, frame: VisionImageRawFrame |
UserImageRawFrame) -> AsyncGenerator[Frame, None]:
pass
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, VisionImageRawFrame):
if isinstance(frame, VisionImageRawFrame) or isinstance(frame, UserImageRawFrame):
await self.start_processing_metrics()
await self.process_generator(self.run_vision(frame))
await self.stop_processing_metrics()

View File

@@ -18,13 +18,11 @@ from pipecat.frames.frames import (
ErrorFrame,
Frame,
StartFrame,
SystemFrame,
TTSStartedFrame,
TTSStoppedFrame,
TranscriptionFrame,
URLImageRawFrame)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import AsyncAIService, TTSService, ImageGenService
from pipecat.services.ai_services import STTService, TTSService, ImageGenService
from pipecat.services.openai import BaseOpenAILLMService
from pipecat.utils.time import time_now_iso8601
@@ -118,7 +116,7 @@ class AzureTTSService(TTSService):
logger.error(f"{self} error: {cancellation_details.error_details}")
class AzureSTTService(AsyncAIService):
class AzureSTTService(STTService):
def __init__(
self,
*,
@@ -141,15 +139,11 @@ class AzureSTTService(AsyncAIService):
speech_config=speech_config, audio_config=audio_config)
self._speech_recognizer.recognized.connect(self._on_handle_recognized)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, SystemFrame):
await self.push_frame(frame, direction)
elif isinstance(frame, AudioRawFrame):
self._audio_stream.write(frame.audio)
else:
await self._push_queue.put((frame, direction))
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
await self.start_processing_metrics()
self._audio_stream.write(audio)
await self.stop_processing_metrics()
yield None
async def start(self, frame: StartFrame):
await super().start(frame)
@@ -168,7 +162,7 @@ class AzureSTTService(AsyncAIService):
def _on_handle_recognized(self, event):
if event.result.reason == ResultReason.RecognizedSpeech and len(event.result.text) > 0:
frame = TranscriptionFrame(event.result.text, "", time_now_iso8601())
asyncio.run_coroutine_threadsafe(self.queue_frame(frame), self.get_event_loop())
asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.get_event_loop())
class AzureImageGenServiceREST(ImageGenService):

View File

@@ -161,8 +161,8 @@ class DeepgramSTTService(STTService):
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
await self.start_processing_metrics()
await self._connection.send(audio)
yield None
await self.stop_processing_metrics()
yield None
async def _connect(self):
if await self._connection.start(self._live_options):

View File

@@ -7,20 +7,17 @@
import base64
import json
from typing import Optional
from typing import AsyncGenerator, Optional
from pydantic.main import BaseModel
from pipecat.frames.frames import (
AudioRawFrame,
CancelFrame,
EndFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
SystemFrame,
TranscriptionFrame)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import AsyncAIService
from pipecat.services.ai_services import STTService
from pipecat.utils.time import time_now_iso8601
from loguru import logger
@@ -35,7 +32,7 @@ except ModuleNotFoundError as e:
raise Exception(f"Missing module: {e}")
class GladiaSTTService(AsyncAIService):
class GladiaSTTService(STTService):
class InputParams(BaseModel):
sample_rate: Optional[int] = 16000
language: Optional[str] = "english"
@@ -57,16 +54,6 @@ class GladiaSTTService(AsyncAIService):
self._params = params
self._confidence = confidence
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, SystemFrame):
await self.push_frame(frame, direction)
elif isinstance(frame, AudioRawFrame):
await self._send_audio(frame)
else:
await self.queue_frame(frame, direction)
async def start(self, frame: StartFrame):
await super().start(frame)
self._websocket = await websockets.connect(self._url)
@@ -81,6 +68,12 @@ class GladiaSTTService(AsyncAIService):
await super().cancel(frame)
await self._websocket.close()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
await self.start_processing_metrics()
await self._send_audio(audio)
await self.stop_processing_metrics()
yield None
async def _setup_gladia(self):
configuration = {
"x_gladia_key": self._api_key,
@@ -92,9 +85,9 @@ class GladiaSTTService(AsyncAIService):
await self._websocket.send(json.dumps(configuration))
async def _send_audio(self, frame: AudioRawFrame):
async def _send_audio(self, audio: bytes):
message = {
'frames': base64.b64encode(frame.audio).decode("utf-8")
'frames': base64.b64encode(audio).decode("utf-8")
}
await self._websocket.send(json.dumps(message))
@@ -113,6 +106,6 @@ class GladiaSTTService(AsyncAIService):
transcript = utterance["transcription"]
if confidence >= self._confidence:
if type == "final":
await self.queue_frame(TranscriptionFrame(transcript, "", time_now_iso8601()))
await self.push_frame(TranscriptionFrame(transcript, "", time_now_iso8601()))
else:
await self.queue_frame(InterimTranscriptionFrame(transcript, "", time_now_iso8601()))
await self.push_frame(InterimTranscriptionFrame(transcript, "", time_now_iso8601()))

View File

@@ -10,7 +10,13 @@ from PIL import Image
from typing import AsyncGenerator
from pipecat.frames.frames import ErrorFrame, Frame, TextFrame, VisionImageRawFrame
from pipecat.frames.frames import (
ErrorFrame,
Frame,
TextFrame,
ImageRawFrame,
VisionImageRawFrame,
UserImageRawFrame)
from pipecat.services.ai_services import VisionService
from loguru import logger
@@ -48,7 +54,7 @@ class MoondreamService(VisionService):
self,
*,
model="vikhyatk/moondream2",
revision="2024-04-02",
revision="2024-08-26",
use_cpu=False
):
super().__init__()
@@ -70,23 +76,30 @@ class MoondreamService(VisionService):
logger.debug("Loaded Moondream model")
async def run_vision(self, frame: VisionImageRawFrame) -> AsyncGenerator[Frame, None]:
async def run_vision(self, frame: VisionImageRawFrame |
UserImageRawFrame) -> AsyncGenerator[Frame, None]:
if not self._model:
logger.error(f"{self} error: Moondream model not available")
yield ErrorFrame("Moondream model not available")
return
question = getattr(frame, "context", None) or getattr(frame, "text", None)
logger.debug(f"Analyzing image: {frame}")
def get_image_description(frame: VisionImageRawFrame):
def get_image_description(frame: ImageRawFrame):
image = Image.frombytes(frame.format, frame.size, frame.image)
image_embeds = self._model.encode_image(image)
description = self._model.answer_question(
image_embeds=image_embeds,
question=frame.text,
question=question,
tokenizer=self._tokenizer)
return description
description = await asyncio.to_thread(get_image_description, frame)
yield TextFrame(text=description)
if isinstance(frame, VisionImageRawFrame):
yield TextFrame(text=description)
elif isinstance(frame, UserImageRawFrame):
frame.description = description
yield frame

View File

@@ -18,8 +18,6 @@ from pipecat.frames.frames import (
Frame,
LLMModelUpdateFrame,
TextFrame,
VisionImageRawFrame,
UserImageRequestFrame,
UserImageRawFrame,
LLMMessagesFrame,
LLMFullResponseStartFrame,
@@ -100,8 +98,12 @@ class TogetherLLMService(LLMService):
stream=True,
)
# Function calling
got_first_chunk = False
# Function calling. We should be able to prompt Llama 3.1 to always return either plain
# text or a function call. However, occasionally we see a function call after plain text.
# Try to account for that.
most_recent_chunk_was_function_call_start_char = False # function call start char is '<'
accumulating_function_call = False
function_call_accumulator = ""
@@ -131,10 +133,24 @@ class TogetherLLMService(LLMService):
if accumulating_function_call:
function_call_accumulator += chunk.choices[0].delta.content
else:
await self.push_frame(TextFrame(chunk.choices[0].delta.content))
text = chunk.choices[0].delta.content
if most_recent_chunk_was_function_call_start_char:
most_recent_chunk_was_function_call_start_char = False
if text == "function":
accumulating_function_call = True
function_call_accumulator = "<function"
else:
await self.push_frame("<" + TextFrame(chunk.choices[0].delta.content))
elif text == '<':
most_recent_chunk_was_function_call_start_char = True
else:
await self.push_frame(TextFrame(chunk.choices[0].delta.content))
if chunk.choices[0].finish_reason == 'eos' and accumulating_function_call:
await self._extract_function_call(context, function_call_accumulator)
if chunk.choices[0].finish_reason == 'eos':
if accumulating_function_call:
await self._extract_function_call(context, function_call_accumulator)
elif most_recent_chunk_was_function_call_start_char:
await self.push_frame(TextFrame("<"))
except CancelledError as e:
# todo: implement token counting estimates for use when the user interrupts a long generation
@@ -164,13 +180,26 @@ class TogetherLLMService(LLMService):
await self._process_context(context)
async def _extract_function_call(self, context, function_call_accumulator):
# logger.debug(f"Extracting function call: {function_call_accumulator}")
context.add_message({"role": "assistant", "content": function_call_accumulator})
function_regex = r"<function=(\w+)>(.*?)</function>"
# Function format regex. Llama 3.1 sometimes adds an extra " or space just before the
# </function> tag. This regexp just ignores the extra characters if they are there. (That's
# the [\s"]? part of the regex.) Occasionally the </function> close tag is also missing.
function_regex = r'<function=(\w+)>(.*?)<\/function>|<function=(\w+)>(.*)'
match = re.search(function_regex, function_call_accumulator)
if match:
function_name, args_string = match.groups()
function_name = ""
args_string = ""
if match.group(1): # Case with closing tag
function_name = match.group(1)
args_string = match.group(2)
else: # Case without closing tag
function_name = match.group(3)
args_string = match.group(4)
try:
args_string = re.sub(r'[\s"]+$', '', args_string)
arguments = json.loads(args_string)
await self.call_function(context=context,
tool_call_id=str(uuid.uuid4()),
@@ -181,7 +210,8 @@ class TogetherLLMService(LLMService):
# We get here if the LLM returns a function call with invalid JSON arguments. This could happen
# because of LLM non-determinism, or maybe more often because of user error in the prompt.
# Should we do anything more than log a warning?
logger.debug(f"Error parsing function arguments: {error}")
logger.debug(
f"Error parsing function arguments: {error} - {function_call_accumulator}")
class TogetherLLMContext(OpenAILLMContext):
@@ -219,9 +249,17 @@ class TogetherUserContextAggregator(LLMUserContextAggregator):
if isinstance(context, OpenAILLMContext):
self._context = TogetherLLMContext.from_openai_context(context)
def get_messages_frame(self):
return OpenAILLMContextFrame(self._context)
async def push_messages_frame(self):
frame = OpenAILLMContextFrame(self._context)
await self.push_frame(frame)
await self.push_frame(self.get_messages_frame())
def append_image_description_tool_message(self, description):
self._context.add_message({
"role": "tool",
"content": json.dumps({"image_description": description})
})
async def process_frame(self, frame, direction):
await super().process_frame(frame, direction)
@@ -230,20 +268,10 @@ class TogetherUserContextAggregator(LLMUserContextAggregator):
# to talk through (tagging @aleix). At some point we might need to refactor these
# context aggregators.
try:
if isinstance(frame, UserImageRequestFrame):
# The LLM sends a UserImageRequestFrame upstream. Cache any context provided with
# that frame so we can use it when we assemble the image message in the assistant
# context aggregator.
if (frame.context):
if isinstance(frame.context, str):
self._context._user_image_request_context[frame.user_id] = frame.context
else:
logger.error(
f"Unexpected UserImageRequestFrame context type: {type(frame.context)}")
del self._context._user_image_request_context[frame.user_id]
else:
if frame.user_id in self._context._user_image_request_context:
del self._context._user_image_request_context[frame.user_id]
if isinstance(frame, UserImageRawFrame):
if frame.description:
self.append_image_description_tool_message(frame.description)
await self.push_messages_frame()
except Exception as e:
logger.error(f"Error processing frame: {e}")

View File

@@ -37,10 +37,6 @@ class BaseInputTransport(FrameProcessor):
self._executor = ThreadPoolExecutor(max_workers=5)
# Create push frame task. This is the task that will push frames in
# order. We also guarantee that all frames are pushed in the same task.
self._create_push_task()
async def start(self, frame: StartFrame):
# Create audio input queue and task if needed.
if self._params.audio_in_enabled or self._params.vad_enabled:
@@ -53,10 +49,6 @@ class BaseInputTransport(FrameProcessor):
self._audio_task.cancel()
await self._audio_task
# Wait for the push frame task to finish. It will finish when the
# EndFrame is actually processed.
await self._push_frame_task
async def cancel(self, frame: CancelFrame):
# Cancel all the tasks and wait for them to finish.
@@ -64,9 +56,6 @@ class BaseInputTransport(FrameProcessor):
self._audio_task.cancel()
await self._audio_task
self._push_frame_task.cancel()
await self._push_frame_task
def vad_analyzer(self) -> VADAnalyzer | None:
return self._params.vad_analyzer
@@ -86,11 +75,8 @@ class BaseInputTransport(FrameProcessor):
await self.cancel(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, BotInterruptionFrame):
await self._handle_interruptions(frame, False)
elif isinstance(frame, StartInterruptionFrame):
logger.debug("Bot interruption")
await self._start_interruption()
elif isinstance(frame, StopInterruptionFrame):
await self._stop_interruption()
# All other system frames
elif isinstance(frame, SystemFrame):
await self.push_frame(frame, direction)
@@ -98,12 +84,12 @@ class BaseInputTransport(FrameProcessor):
elif isinstance(frame, StartFrame):
# Push StartFrame before start(), because we want StartFrame to be
# processed by every processor before any other frame is processed.
await self._internal_push_frame(frame, direction)
await self.push_frame(frame, direction)
await self.start(frame)
elif isinstance(frame, EndFrame):
# Push EndFrame before stop(), because stop() waits on the task to
# finish and the task finishes when EndFrame is processed.
await self._internal_push_frame(frame, direction)
await self.push_frame(frame, direction)
await self.stop(frame)
elif isinstance(frame, VADParamsUpdateFrame):
vad_analyzer = self.vad_analyzer()
@@ -111,73 +97,28 @@ class BaseInputTransport(FrameProcessor):
vad_analyzer.set_params(frame.params)
# Other frames
else:
await self._internal_push_frame(frame, direction)
#
# Push frames task
#
def _create_push_task(self):
loop = self.get_event_loop()
self._push_queue = asyncio.Queue()
self._push_frame_task = loop.create_task(self._push_frame_task_handler())
async def _internal_push_frame(
self,
frame: Frame | None,
direction: FrameDirection | None = FrameDirection.DOWNSTREAM):
await self._push_queue.put((frame, direction))
async def _push_frame_task_handler(self):
running = True
while running:
try:
(frame, direction) = await self._push_queue.get()
await self.push_frame(frame, direction)
running = not isinstance(frame, EndFrame)
self._push_queue.task_done()
except asyncio.CancelledError:
break
await self.push_frame(frame, direction)
#
# Handle interruptions
#
async def _start_interruption(self):
if not self.interruptions_allowed:
return
# Cancel the task. This will stop pushing frames downstream.
self._push_frame_task.cancel()
await self._push_frame_task
# Push an out-of-band frame (i.e. not using the ordered push
# frame task) to stop everything, specially at the output
# transport.
await self.push_frame(StartInterruptionFrame())
# Create a new queue and task.
self._create_push_task()
async def _stop_interruption(self):
if not self.interruptions_allowed:
return
await self.push_frame(StopInterruptionFrame())
async def _handle_interruptions(self, frame: Frame, push_frame: bool):
async def _handle_interruptions(self, frame: Frame):
if self.interruptions_allowed:
# Make sure we notify about interruptions quickly out-of-band
if isinstance(frame, BotInterruptionFrame):
logger.debug("Bot interruption")
await self._start_interruption()
elif isinstance(frame, UserStartedSpeakingFrame):
# Make sure we notify about interruptions quickly out-of-band.
if isinstance(frame, UserStartedSpeakingFrame):
logger.debug("User started speaking")
await self._start_interruption()
# Push an out-of-band frame (i.e. not using the ordered push
# frame task) to stop everything, specially at the output
# transport.
await self.push_frame(StartInterruptionFrame())
elif isinstance(frame, UserStoppedSpeakingFrame):
logger.debug("User stopped speaking")
await self._stop_interruption()
await self.push_frame(StopInterruptionFrame())
if push_frame:
await self._internal_push_frame(frame)
await self.push_frame(frame)
#
# Audio input
@@ -201,7 +142,7 @@ class BaseInputTransport(FrameProcessor):
frame = UserStoppedSpeakingFrame()
if frame:
await self._handle_interruptions(frame, True)
await self._handle_interruptions(frame)
vad_state = new_vad_state
return vad_state
@@ -222,7 +163,7 @@ class BaseInputTransport(FrameProcessor):
# Push audio downstream if passthrough.
if audio_passthrough:
await self._internal_push_frame(frame)
await self.push_frame(frame)
self._audio_in_queue.task_done()
except asyncio.CancelledError:

View File

@@ -66,10 +66,6 @@ class BaseOutputTransport(FrameProcessor):
# generating frames upstream while, for example, the audio is playing.
self._create_sink_task()
# Create push frame task. This is the task that will push frames in
# order. We also guarantee that all frames are pushed in the same task.
self._create_push_task()
async def start(self, frame: StartFrame):
# Create camera output queue and task if needed.
if self._params.camera_out_enabled:
@@ -91,9 +87,8 @@ class BaseOutputTransport(FrameProcessor):
self._audio_out_task.cancel()
await self._audio_out_task
# Wait for the push frame and sink tasks to finish. They will finish when
# the EndFrame is actually processed.
await self._push_frame_task
# Wait for the sink task to finish. They will finish when the EndFrame
# is actually processed.
await self._sink_task
async def cancel(self, frame: CancelFrame):
@@ -103,9 +98,6 @@ class BaseOutputTransport(FrameProcessor):
self._camera_out_task.cancel()
await self._camera_out_task
self._push_frame_task.cancel()
await self._push_frame_task
self._sink_task.cancel()
await self._sink_task
@@ -170,10 +162,6 @@ class BaseOutputTransport(FrameProcessor):
self._sink_task.cancel()
await self._sink_task
self._create_sink_task()
# Stop push task.
self._push_frame_task.cancel()
await self._push_frame_task
self._create_push_task()
# Let's send a bot stopped speaking if we have to.
if self._bot_speaking:
await self._bot_stopped_speaking()
@@ -213,7 +201,7 @@ class BaseOutputTransport(FrameProcessor):
frame = await self._sink_queue.get()
if isinstance(frame, AudioRawFrame):
await self.write_raw_audio_frames(frame.audio)
await self._internal_push_frame(frame)
await self.push_frame(frame)
await self.push_frame(BotSpeakingFrame(), FrameDirection.UPSTREAM)
elif isinstance(frame, ImageRawFrame):
await self._set_camera_image(frame)
@@ -223,12 +211,12 @@ class BaseOutputTransport(FrameProcessor):
await self.send_message(frame)
elif isinstance(frame, TTSStartedFrame):
await self._bot_started_speaking()
await self._internal_push_frame(frame)
await self.push_frame(frame)
elif isinstance(frame, TTSStoppedFrame):
await self._bot_stopped_speaking()
await self._internal_push_frame(frame)
await self.push_frame(frame)
else:
await self._internal_push_frame(frame)
await self.push_frame(frame)
running = not isinstance(frame, EndFrame)
@@ -241,38 +229,12 @@ class BaseOutputTransport(FrameProcessor):
async def _bot_started_speaking(self):
logger.debug("Bot started speaking")
self._bot_speaking = True
await self._internal_push_frame(BotStartedSpeakingFrame(), FrameDirection.UPSTREAM)
await self.push_frame(BotStartedSpeakingFrame(), FrameDirection.UPSTREAM)
async def _bot_stopped_speaking(self):
logger.debug("Bot stopped speaking")
self._bot_speaking = False
await self._internal_push_frame(BotStoppedSpeakingFrame(), FrameDirection.UPSTREAM)
#
# Push frames task
#
def _create_push_task(self):
loop = self.get_event_loop()
self._push_queue = asyncio.Queue()
self._push_frame_task = loop.create_task(self._push_frame_task_handler())
async def _internal_push_frame(
self,
frame: Frame | None,
direction: FrameDirection | None = FrameDirection.DOWNSTREAM):
await self._push_queue.put((frame, direction))
async def _push_frame_task_handler(self):
running = True
while running:
try:
(frame, direction) = await self._push_queue.get()
await self.push_frame(frame, direction)
running = not isinstance(frame, EndFrame)
self._push_queue.task_done()
except asyncio.CancelledError:
break
await self.push_frame(BotStoppedSpeakingFrame(), FrameDirection.UPSTREAM)
#
# Camera out
@@ -356,7 +318,7 @@ class BaseOutputTransport(FrameProcessor):
try:
frame = await self._audio_out_queue.get()
await self.write_raw_audio_frames(frame.audio)
await self._internal_push_frame(frame)
await self.push_frame(frame)
await self.push_frame(BotSpeakingFrame(), FrameDirection.UPSTREAM)
except asyncio.CancelledError:
break

View File

@@ -98,9 +98,9 @@ class WebsocketServerInputTransport(BaseInputTransport):
continue
if isinstance(frame, AudioRawFrame):
await self.push_audio_frame(frame)
await self.queue_audio_frame(frame)
else:
await self._internal_push_frame(frame)
await self.push_frame(frame)
# Notify disconnection
await self._callbacks.on_client_disconnected(websocket)

View File

@@ -612,18 +612,18 @@ class DailyInputTransport(BaseInputTransport):
await super().process_frame(frame, direction)
if isinstance(frame, UserImageRequestFrame):
self.request_participant_image(frame.user_id)
self.request_participant_image(frame.user_id, frame.context)
#
# Frames
#
async def push_transcription_frame(self, frame: TranscriptionFrame | InterimTranscriptionFrame):
await self._internal_push_frame(frame)
await self.push_frame(frame)
async def push_app_message(self, message: Any, sender: str):
frame = DailyTransportMessageFrame(message=message, participant_id=sender)
await self._internal_push_frame(frame)
await self.push_frame(frame)
#
# Audio in
@@ -662,9 +662,10 @@ class DailyInputTransport(BaseInputTransport):
color_format
)
def request_participant_image(self, participant_id: str):
def request_participant_image(self, participant_id: str, context: Any = None):
if participant_id in self._video_renderers:
self._video_renderers[participant_id]["render_next_frame"] = True
truthy = context if context else True
self._video_renderers[participant_id]["render_next_frame"] = truthy
async def _on_participant_video_frame(self, participant_id: str, buffer, size, format):
render_frame = False
@@ -677,16 +678,17 @@ class DailyInputTransport(BaseInputTransport):
next_time = prev_time + 1 / framerate
render_frame = (curr_time - next_time) < 0.1
elif self._video_renderers[participant_id]["render_next_frame"]:
render_frame = self._video_renderers[participant_id]["render_next_frame"]
self._video_renderers[participant_id]["render_next_frame"] = False
render_frame = True
if render_frame:
frame = UserImageRawFrame(
user_id=participant_id,
image=buffer,
size=size,
format=format)
await self._internal_push_frame(frame)
format=format,
context=None if render_frame is True else render_frame)
await self.push_frame(frame)
self._video_renderers[participant_id]["timestamp"] = curr_time