vision(moondream): process VisionImageRawFrame
This commit is contained in:
@@ -167,6 +167,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Changed
|
||||
|
||||
- Updated `MoondreamService` to process `VisionImageRawFrame`.
|
||||
|
||||
- `VisionService` now expects a `VisionImageRawFrame` (instead of an `LLMContextFrame`) in order to analyze images.
|
||||
|
||||
- `DailyTransport` triggers the `on_error` event if transcription can't be started
|
||||
or stopped.
|
||||
|
||||
|
||||
@@ -11,15 +11,12 @@ for image analysis and description generation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
from PIL import Image
|
||||
|
||||
from pipecat.frames.frames import ErrorFrame, Frame, TextFrame
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.frames.frames import ErrorFrame, Frame, TextFrame, VisionImageRawFrame
|
||||
from pipecat.services.vision_service import VisionService
|
||||
|
||||
try:
|
||||
@@ -92,16 +89,16 @@ class MoondreamService(VisionService):
|
||||
trust_remote_code=True,
|
||||
revision=revision,
|
||||
device_map={"": device},
|
||||
torch_dtype=dtype,
|
||||
dtype=dtype,
|
||||
).eval()
|
||||
|
||||
logger.debug("Loaded Moondream model")
|
||||
|
||||
async def run_vision(self, context: LLMContext) -> AsyncGenerator[Frame, None]:
|
||||
async def run_vision(self, frame: VisionImageRawFrame) -> AsyncGenerator[Frame, None]:
|
||||
"""Analyze an image and generate a description.
|
||||
|
||||
Args:
|
||||
context: The context to process, containing image data.
|
||||
frame: The vision image frame to process.
|
||||
|
||||
Yields:
|
||||
Frame: TextFrame containing the generated image description, or ErrorFrame
|
||||
@@ -112,45 +109,14 @@ class MoondreamService(VisionService):
|
||||
yield ErrorFrame("Moondream model not available")
|
||||
return
|
||||
|
||||
image_bytes = None
|
||||
text = None
|
||||
try:
|
||||
messages = context.get_messages()
|
||||
last_message = messages[-1]
|
||||
last_message_content = last_message.get("content")
|
||||
logger.debug(f"Analyzing image (bytes length: {len(frame.image)})")
|
||||
|
||||
for item in last_message_content:
|
||||
if isinstance(item, dict):
|
||||
if (
|
||||
"image_url" in item
|
||||
and isinstance(item["image_url"], dict)
|
||||
and item["image_url"].get("url")
|
||||
):
|
||||
image_bytes = base64.b64decode(item["image_url"]["url"].split(",")[1])
|
||||
elif "text" in item and isinstance(item["text"], str):
|
||||
text = item["text"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during image extraction: {e}")
|
||||
yield ErrorFrame("Failed to extract image from context")
|
||||
return
|
||||
|
||||
if not image_bytes:
|
||||
logger.error("No image found in context")
|
||||
yield ErrorFrame("No image found in context")
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Analyzing image (bytes length: {len(image_bytes) if image_bytes else 'None'})"
|
||||
)
|
||||
|
||||
def get_image_description(bytes: bytes, text: Optional[str]) -> str:
|
||||
image_buffer = BytesIO(bytes)
|
||||
image = Image.open(image_buffer)
|
||||
def get_image_description(image_bytes: bytes, text: Optional[str]) -> str:
|
||||
image = Image.frombytes(frame.format, frame.size, image_bytes)
|
||||
image_embeds = self._model.encode_image(image)
|
||||
description = self._model.query(image_embeds, text)["answer"]
|
||||
return description
|
||||
|
||||
description = await asyncio.to_thread(get_image_description, image_bytes, text)
|
||||
description = await asyncio.to_thread(get_image_description, frame.image, frame.text)
|
||||
|
||||
yield TextFrame(text=description)
|
||||
|
||||
@@ -14,8 +14,7 @@ visual content.
|
||||
from abc import abstractmethod
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from pipecat.frames.frames import Frame, LLMContextFrame
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.frames.frames import Frame, VisionImageRawFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_service import AIService
|
||||
|
||||
@@ -38,15 +37,15 @@ class VisionService(AIService):
|
||||
self._describe_text = None
|
||||
|
||||
@abstractmethod
|
||||
async def run_vision(self, context: LLMContext) -> AsyncGenerator[Frame, None]:
|
||||
"""Process the latest image in the context and generate results.
|
||||
async def run_vision(self, frame: VisionImageRawFrame) -> AsyncGenerator[Frame, None]:
|
||||
"""Process the given vision image and generate results.
|
||||
|
||||
This method must be implemented by subclasses to provide actual computer
|
||||
vision functionality such as image description, object detection, or
|
||||
visual question answering.
|
||||
|
||||
Args:
|
||||
context: The context to process, containing image data.
|
||||
frame: The vision image frame to process.
|
||||
|
||||
Yields:
|
||||
Frame: Frames containing the vision analysis results, typically TextFrame
|
||||
@@ -66,9 +65,9 @@ class VisionService(AIService):
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, LLMContextFrame):
|
||||
if isinstance(frame, VisionImageRawFrame):
|
||||
await self.start_processing_metrics()
|
||||
await self.process_generator(self.run_vision(frame.context))
|
||||
await self.process_generator(self.run_vision(frame))
|
||||
await self.stop_processing_metrics()
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
Reference in New Issue
Block a user