Compare commits
2 Commits
hush/realt
...
khk/togeth
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b36a466798 | ||
|
|
58f3965cdc |
@@ -93,6 +93,8 @@ class UserImageRawFrame(ImageRawFrame):
|
||||
|
||||
"""
|
||||
user_id: str
|
||||
context: Any = None
|
||||
description: str | None = None
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}(user: {self.user_id}, size: {self.size}, format: {self.format})"
|
||||
@@ -420,7 +422,7 @@ class TTSStoppedFrame(ControlFrame):
|
||||
class UserImageRequestFrame(ControlFrame):
|
||||
"""A frame user to request an image from the given user."""
|
||||
user_id: str
|
||||
context: Optional[Any] = None
|
||||
context: Any = None
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}, user: {self.user_id}"
|
||||
|
||||
@@ -26,7 +26,8 @@ from pipecat.frames.frames import (
|
||||
TTSVoiceUpdateFrame,
|
||||
TextFrame,
|
||||
UserImageRequestFrame,
|
||||
VisionImageRawFrame
|
||||
VisionImageRawFrame,
|
||||
UserImageRawFrame
|
||||
)
|
||||
from pipecat.processors.async_frame_processor import AsyncFrameProcessor
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
@@ -399,13 +400,14 @@ class VisionService(AIService):
|
||||
self._describe_text = None
|
||||
|
||||
@abstractmethod
|
||||
async def run_vision(self, frame: VisionImageRawFrame) -> AsyncGenerator[Frame, None]:
|
||||
async def run_vision(self, frame: VisionImageRawFrame |
|
||||
UserImageRawFrame) -> AsyncGenerator[Frame, None]:
|
||||
pass
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, VisionImageRawFrame):
|
||||
if isinstance(frame, VisionImageRawFrame) or isinstance(frame, UserImageRawFrame):
|
||||
await self.start_processing_metrics()
|
||||
await self.process_generator(self.run_vision(frame))
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
@@ -10,7 +10,13 @@ from PIL import Image
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from pipecat.frames.frames import ErrorFrame, Frame, TextFrame, VisionImageRawFrame
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TextFrame,
|
||||
ImageRawFrame,
|
||||
VisionImageRawFrame,
|
||||
UserImageRawFrame)
|
||||
from pipecat.services.ai_services import VisionService
|
||||
|
||||
from loguru import logger
|
||||
@@ -48,7 +54,7 @@ class MoondreamService(VisionService):
|
||||
self,
|
||||
*,
|
||||
model="vikhyatk/moondream2",
|
||||
revision="2024-04-02",
|
||||
revision="2024-08-26",
|
||||
use_cpu=False
|
||||
):
|
||||
super().__init__()
|
||||
@@ -70,23 +76,30 @@ class MoondreamService(VisionService):
|
||||
|
||||
logger.debug("Loaded Moondream model")
|
||||
|
||||
async def run_vision(self, frame: VisionImageRawFrame) -> AsyncGenerator[Frame, None]:
|
||||
async def run_vision(self, frame: VisionImageRawFrame |
|
||||
UserImageRawFrame) -> AsyncGenerator[Frame, None]:
|
||||
if not self._model:
|
||||
logger.error(f"{self} error: Moondream model not available")
|
||||
yield ErrorFrame("Moondream model not available")
|
||||
return
|
||||
|
||||
question = getattr(frame, "context", None) or getattr(frame, "text", None)
|
||||
|
||||
logger.debug(f"Analyzing image: {frame}")
|
||||
|
||||
def get_image_description(frame: VisionImageRawFrame):
|
||||
def get_image_description(frame: ImageRawFrame):
|
||||
image = Image.frombytes(frame.format, frame.size, frame.image)
|
||||
image_embeds = self._model.encode_image(image)
|
||||
description = self._model.answer_question(
|
||||
image_embeds=image_embeds,
|
||||
question=frame.text,
|
||||
question=question,
|
||||
tokenizer=self._tokenizer)
|
||||
return description
|
||||
|
||||
description = await asyncio.to_thread(get_image_description, frame)
|
||||
|
||||
yield TextFrame(text=description)
|
||||
if isinstance(frame, VisionImageRawFrame):
|
||||
yield TextFrame(text=description)
|
||||
elif isinstance(frame, UserImageRawFrame):
|
||||
frame.description = description
|
||||
yield frame
|
||||
|
||||
@@ -18,8 +18,6 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
LLMModelUpdateFrame,
|
||||
TextFrame,
|
||||
VisionImageRawFrame,
|
||||
UserImageRequestFrame,
|
||||
UserImageRawFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
@@ -100,8 +98,12 @@ class TogetherLLMService(LLMService):
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Function calling
|
||||
got_first_chunk = False
|
||||
|
||||
# Function calling. We should be able to prompt Llama 3.1 to always return either plain
|
||||
# text or a function call. However, occasionally we see a function call after plain text.
|
||||
# Try to account for that.
|
||||
most_recent_chunk_was_function_call_start_char = False # function call start char is '<'
|
||||
accumulating_function_call = False
|
||||
function_call_accumulator = ""
|
||||
|
||||
@@ -131,10 +133,24 @@ class TogetherLLMService(LLMService):
|
||||
if accumulating_function_call:
|
||||
function_call_accumulator += chunk.choices[0].delta.content
|
||||
else:
|
||||
await self.push_frame(TextFrame(chunk.choices[0].delta.content))
|
||||
text = chunk.choices[0].delta.content
|
||||
if most_recent_chunk_was_function_call_start_char:
|
||||
most_recent_chunk_was_function_call_start_char = False
|
||||
if text == "function":
|
||||
accumulating_function_call = True
|
||||
function_call_accumulator = "<function"
|
||||
else:
|
||||
await self.push_frame("<" + TextFrame(chunk.choices[0].delta.content))
|
||||
elif text == '<':
|
||||
most_recent_chunk_was_function_call_start_char = True
|
||||
else:
|
||||
await self.push_frame(TextFrame(chunk.choices[0].delta.content))
|
||||
|
||||
if chunk.choices[0].finish_reason == 'eos' and accumulating_function_call:
|
||||
await self._extract_function_call(context, function_call_accumulator)
|
||||
if chunk.choices[0].finish_reason == 'eos':
|
||||
if accumulating_function_call:
|
||||
await self._extract_function_call(context, function_call_accumulator)
|
||||
elif most_recent_chunk_was_function_call_start_char:
|
||||
await self.push_frame(TextFrame("<"))
|
||||
|
||||
except CancelledError as e:
|
||||
# todo: implement token counting estimates for use when the user interrupts a long generation
|
||||
@@ -164,13 +180,26 @@ class TogetherLLMService(LLMService):
|
||||
await self._process_context(context)
|
||||
|
||||
async def _extract_function_call(self, context, function_call_accumulator):
|
||||
# logger.debug(f"Extracting function call: {function_call_accumulator}")
|
||||
context.add_message({"role": "assistant", "content": function_call_accumulator})
|
||||
|
||||
function_regex = r"<function=(\w+)>(.*?)</function>"
|
||||
# Function format regex. Llama 3.1 sometimes adds an extra " or space just before the
|
||||
# </function> tag. This regexp just ignores the extra characters if they are there. (That's
|
||||
# the [\s"]? part of the regex.) Occasionally the </function> close tag is also missing.
|
||||
function_regex = r'<function=(\w+)>(.*?)<\/function>|<function=(\w+)>(.*)'
|
||||
match = re.search(function_regex, function_call_accumulator)
|
||||
if match:
|
||||
function_name, args_string = match.groups()
|
||||
function_name = ""
|
||||
args_string = ""
|
||||
if match.group(1): # Case with closing tag
|
||||
function_name = match.group(1)
|
||||
args_string = match.group(2)
|
||||
else: # Case without closing tag
|
||||
function_name = match.group(3)
|
||||
args_string = match.group(4)
|
||||
|
||||
try:
|
||||
args_string = re.sub(r'[\s"]+$', '', args_string)
|
||||
arguments = json.loads(args_string)
|
||||
await self.call_function(context=context,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
@@ -181,7 +210,8 @@ class TogetherLLMService(LLMService):
|
||||
# We get here if the LLM returns a function call with invalid JSON arguments. This could happen
|
||||
# because of LLM non-determinism, or maybe more often because of user error in the prompt.
|
||||
# Should we do anything more than log a warning?
|
||||
logger.debug(f"Error parsing function arguments: {error}")
|
||||
logger.debug(
|
||||
f"Error parsing function arguments: {error} - {function_call_accumulator}")
|
||||
|
||||
|
||||
class TogetherLLMContext(OpenAILLMContext):
|
||||
@@ -219,9 +249,17 @@ class TogetherUserContextAggregator(LLMUserContextAggregator):
|
||||
if isinstance(context, OpenAILLMContext):
|
||||
self._context = TogetherLLMContext.from_openai_context(context)
|
||||
|
||||
def get_messages_frame(self):
|
||||
return OpenAILLMContextFrame(self._context)
|
||||
|
||||
async def push_messages_frame(self):
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
await self.push_frame(self.get_messages_frame())
|
||||
|
||||
def append_image_description_tool_message(self, description):
|
||||
self._context.add_message({
|
||||
"role": "tool",
|
||||
"content": json.dumps({"image_description": description})
|
||||
})
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
@@ -230,20 +268,10 @@ class TogetherUserContextAggregator(LLMUserContextAggregator):
|
||||
# to talk through (tagging @aleix). At some point we might need to refactor these
|
||||
# context aggregators.
|
||||
try:
|
||||
if isinstance(frame, UserImageRequestFrame):
|
||||
# The LLM sends a UserImageRequestFrame upstream. Cache any context provided with
|
||||
# that frame so we can use it when we assemble the image message in the assistant
|
||||
# context aggregator.
|
||||
if (frame.context):
|
||||
if isinstance(frame.context, str):
|
||||
self._context._user_image_request_context[frame.user_id] = frame.context
|
||||
else:
|
||||
logger.error(
|
||||
f"Unexpected UserImageRequestFrame context type: {type(frame.context)}")
|
||||
del self._context._user_image_request_context[frame.user_id]
|
||||
else:
|
||||
if frame.user_id in self._context._user_image_request_context:
|
||||
del self._context._user_image_request_context[frame.user_id]
|
||||
if isinstance(frame, UserImageRawFrame):
|
||||
if frame.description:
|
||||
self.append_image_description_tool_message(frame.description)
|
||||
await self.push_messages_frame()
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
|
||||
|
||||
@@ -611,7 +611,7 @@ class DailyInputTransport(BaseInputTransport):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, UserImageRequestFrame):
|
||||
self.request_participant_image(frame.user_id)
|
||||
self.request_participant_image(frame.user_id, frame.context)
|
||||
|
||||
#
|
||||
# Frames
|
||||
@@ -661,9 +661,10 @@ class DailyInputTransport(BaseInputTransport):
|
||||
color_format
|
||||
)
|
||||
|
||||
def request_participant_image(self, participant_id: str):
|
||||
def request_participant_image(self, participant_id: str, context: Any = None):
|
||||
if participant_id in self._video_renderers:
|
||||
self._video_renderers[participant_id]["render_next_frame"] = True
|
||||
truthy = context if context else True
|
||||
self._video_renderers[participant_id]["render_next_frame"] = truthy
|
||||
|
||||
async def _on_participant_video_frame(self, participant_id: str, buffer, size, format):
|
||||
render_frame = False
|
||||
@@ -676,15 +677,16 @@ class DailyInputTransport(BaseInputTransport):
|
||||
next_time = prev_time + 1 / framerate
|
||||
render_frame = (curr_time - next_time) < 0.1
|
||||
elif self._video_renderers[participant_id]["render_next_frame"]:
|
||||
render_frame = self._video_renderers[participant_id]["render_next_frame"]
|
||||
self._video_renderers[participant_id]["render_next_frame"] = False
|
||||
render_frame = True
|
||||
|
||||
if render_frame:
|
||||
frame = UserImageRawFrame(
|
||||
user_id=participant_id,
|
||||
image=buffer,
|
||||
size=size,
|
||||
format=format)
|
||||
format=format,
|
||||
context=None if render_frame is True else render_frame)
|
||||
await self._internal_push_frame(frame)
|
||||
|
||||
self._video_renderers[participant_id]["timestamp"] = curr_time
|
||||
|
||||
Reference in New Issue
Block a user