Compare commits

...

2 Commits

Author SHA1 Message Date
Kwindla Hultman Kramer
b36a466798 wip together function calling and vision improvements 2024-08-29 16:25:49 -07:00
Kwindla Hultman Kramer
58f3965cdc first working llama-vision 2024-08-29 11:19:00 -07:00
5 changed files with 87 additions and 40 deletions

View File

@@ -93,6 +93,8 @@ class UserImageRawFrame(ImageRawFrame):
"""
user_id: str
context: Any = None
description: str | None = None
def __str__(self):
return f"{self.name}(user: {self.user_id}, size: {self.size}, format: {self.format})"
@@ -420,7 +422,7 @@ class TTSStoppedFrame(ControlFrame):
class UserImageRequestFrame(ControlFrame):
"""A frame user to request an image from the given user."""
user_id: str
context: Optional[Any] = None
context: Any = None
def __str__(self):
return f"{self.name}, user: {self.user_id}"

View File

@@ -26,7 +26,8 @@ from pipecat.frames.frames import (
TTSVoiceUpdateFrame,
TextFrame,
UserImageRequestFrame,
VisionImageRawFrame
VisionImageRawFrame,
UserImageRawFrame
)
from pipecat.processors.async_frame_processor import AsyncFrameProcessor
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
@@ -399,13 +400,14 @@ class VisionService(AIService):
self._describe_text = None
@abstractmethod
async def run_vision(self, frame: VisionImageRawFrame) -> AsyncGenerator[Frame, None]:
async def run_vision(self, frame: VisionImageRawFrame |
UserImageRawFrame) -> AsyncGenerator[Frame, None]:
pass
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, VisionImageRawFrame):
if isinstance(frame, VisionImageRawFrame) or isinstance(frame, UserImageRawFrame):
await self.start_processing_metrics()
await self.process_generator(self.run_vision(frame))
await self.stop_processing_metrics()

View File

@@ -10,7 +10,13 @@ from PIL import Image
from typing import AsyncGenerator
from pipecat.frames.frames import ErrorFrame, Frame, TextFrame, VisionImageRawFrame
from pipecat.frames.frames import (
ErrorFrame,
Frame,
TextFrame,
ImageRawFrame,
VisionImageRawFrame,
UserImageRawFrame)
from pipecat.services.ai_services import VisionService
from loguru import logger
@@ -48,7 +54,7 @@ class MoondreamService(VisionService):
self,
*,
model="vikhyatk/moondream2",
revision="2024-04-02",
revision="2024-08-26",
use_cpu=False
):
super().__init__()
@@ -70,23 +76,30 @@ class MoondreamService(VisionService):
logger.debug("Loaded Moondream model")
async def run_vision(self, frame: VisionImageRawFrame) -> AsyncGenerator[Frame, None]:
async def run_vision(self, frame: VisionImageRawFrame |
UserImageRawFrame) -> AsyncGenerator[Frame, None]:
if not self._model:
logger.error(f"{self} error: Moondream model not available")
yield ErrorFrame("Moondream model not available")
return
question = getattr(frame, "context", None) or getattr(frame, "text", None)
logger.debug(f"Analyzing image: {frame}")
def get_image_description(frame: VisionImageRawFrame):
def get_image_description(frame: ImageRawFrame):
image = Image.frombytes(frame.format, frame.size, frame.image)
image_embeds = self._model.encode_image(image)
description = self._model.answer_question(
image_embeds=image_embeds,
question=frame.text,
question=question,
tokenizer=self._tokenizer)
return description
description = await asyncio.to_thread(get_image_description, frame)
yield TextFrame(text=description)
if isinstance(frame, VisionImageRawFrame):
yield TextFrame(text=description)
elif isinstance(frame, UserImageRawFrame):
frame.description = description
yield frame

View File

@@ -18,8 +18,6 @@ from pipecat.frames.frames import (
Frame,
LLMModelUpdateFrame,
TextFrame,
VisionImageRawFrame,
UserImageRequestFrame,
UserImageRawFrame,
LLMMessagesFrame,
LLMFullResponseStartFrame,
@@ -100,8 +98,12 @@ class TogetherLLMService(LLMService):
stream=True,
)
# Function calling
got_first_chunk = False
# Function calling. We should be able to prompt Llama 3.1 to always return either plain
# text or a function call. However, occasionally we see a function call after plain text.
# Try to account for that.
most_recent_chunk_was_function_call_start_char = False # function call start char is '<'
accumulating_function_call = False
function_call_accumulator = ""
@@ -131,10 +133,24 @@ class TogetherLLMService(LLMService):
if accumulating_function_call:
function_call_accumulator += chunk.choices[0].delta.content
else:
await self.push_frame(TextFrame(chunk.choices[0].delta.content))
text = chunk.choices[0].delta.content
if most_recent_chunk_was_function_call_start_char:
most_recent_chunk_was_function_call_start_char = False
if text == "function":
accumulating_function_call = True
function_call_accumulator = "<function"
else:
await self.push_frame("<" + TextFrame(chunk.choices[0].delta.content))
elif text == '<':
most_recent_chunk_was_function_call_start_char = True
else:
await self.push_frame(TextFrame(chunk.choices[0].delta.content))
if chunk.choices[0].finish_reason == 'eos' and accumulating_function_call:
await self._extract_function_call(context, function_call_accumulator)
if chunk.choices[0].finish_reason == 'eos':
if accumulating_function_call:
await self._extract_function_call(context, function_call_accumulator)
elif most_recent_chunk_was_function_call_start_char:
await self.push_frame(TextFrame("<"))
except CancelledError as e:
# todo: implement token counting estimates for use when the user interrupts a long generation
@@ -164,13 +180,26 @@ class TogetherLLMService(LLMService):
await self._process_context(context)
async def _extract_function_call(self, context, function_call_accumulator):
# logger.debug(f"Extracting function call: {function_call_accumulator}")
context.add_message({"role": "assistant", "content": function_call_accumulator})
function_regex = r"<function=(\w+)>(.*?)</function>"
# Function format regex. Llama 3.1 sometimes adds an extra " or space just before the
# </function> tag. This regexp just ignores the extra characters if they are there. (That's
# the [\s"]? part of the regex.) Occasionally the </function> close tag is also missing.
function_regex = r'<function=(\w+)>(.*?)<\/function>|<function=(\w+)>(.*)'
match = re.search(function_regex, function_call_accumulator)
if match:
function_name, args_string = match.groups()
function_name = ""
args_string = ""
if match.group(1): # Case with closing tag
function_name = match.group(1)
args_string = match.group(2)
else: # Case without closing tag
function_name = match.group(3)
args_string = match.group(4)
try:
args_string = re.sub(r'[\s"]+$', '', args_string)
arguments = json.loads(args_string)
await self.call_function(context=context,
tool_call_id=str(uuid.uuid4()),
@@ -181,7 +210,8 @@ class TogetherLLMService(LLMService):
# We get here if the LLM returns a function call with invalid JSON arguments. This could happen
# because of LLM non-determinism, or maybe more often because of user error in the prompt.
# Should we do anything more than log a warning?
logger.debug(f"Error parsing function arguments: {error}")
logger.debug(
f"Error parsing function arguments: {error} - {function_call_accumulator}")
class TogetherLLMContext(OpenAILLMContext):
@@ -219,9 +249,17 @@ class TogetherUserContextAggregator(LLMUserContextAggregator):
if isinstance(context, OpenAILLMContext):
self._context = TogetherLLMContext.from_openai_context(context)
def get_messages_frame(self):
return OpenAILLMContextFrame(self._context)
async def push_messages_frame(self):
frame = OpenAILLMContextFrame(self._context)
await self.push_frame(frame)
await self.push_frame(self.get_messages_frame())
def append_image_description_tool_message(self, description):
self._context.add_message({
"role": "tool",
"content": json.dumps({"image_description": description})
})
async def process_frame(self, frame, direction):
await super().process_frame(frame, direction)
@@ -230,20 +268,10 @@ class TogetherUserContextAggregator(LLMUserContextAggregator):
# to talk through (tagging @aleix). At some point we might need to refactor these
# context aggregators.
try:
if isinstance(frame, UserImageRequestFrame):
# The LLM sends a UserImageRequestFrame upstream. Cache any context provided with
# that frame so we can use it when we assemble the image message in the assistant
# context aggregator.
if (frame.context):
if isinstance(frame.context, str):
self._context._user_image_request_context[frame.user_id] = frame.context
else:
logger.error(
f"Unexpected UserImageRequestFrame context type: {type(frame.context)}")
del self._context._user_image_request_context[frame.user_id]
else:
if frame.user_id in self._context._user_image_request_context:
del self._context._user_image_request_context[frame.user_id]
if isinstance(frame, UserImageRawFrame):
if frame.description:
self.append_image_description_tool_message(frame.description)
await self.push_messages_frame()
except Exception as e:
logger.error(f"Error processing frame: {e}")

View File

@@ -611,7 +611,7 @@ class DailyInputTransport(BaseInputTransport):
await super().process_frame(frame, direction)
if isinstance(frame, UserImageRequestFrame):
self.request_participant_image(frame.user_id)
self.request_participant_image(frame.user_id, frame.context)
#
# Frames
@@ -661,9 +661,10 @@ class DailyInputTransport(BaseInputTransport):
color_format
)
def request_participant_image(self, participant_id: str):
def request_participant_image(self, participant_id: str, context: Any = None):
if participant_id in self._video_renderers:
self._video_renderers[participant_id]["render_next_frame"] = True
truthy = context if context else True
self._video_renderers[participant_id]["render_next_frame"] = truthy
async def _on_participant_video_frame(self, participant_id: str, buffer, size, format):
render_frame = False
@@ -676,15 +677,16 @@ class DailyInputTransport(BaseInputTransport):
next_time = prev_time + 1 / framerate
render_frame = (curr_time - next_time) < 0.1
elif self._video_renderers[participant_id]["render_next_frame"]:
render_frame = self._video_renderers[participant_id]["render_next_frame"]
self._video_renderers[participant_id]["render_next_frame"] = False
render_frame = True
if render_frame:
frame = UserImageRawFrame(
user_id=participant_id,
image=buffer,
size=size,
format=format)
format=format,
context=None if render_frame is True else render_frame)
await self._internal_push_frame(frame)
self._video_renderers[participant_id]["timestamp"] = curr_time