|
from typing import List, Union |
|
from PIL import Image, ImageDraw |
|
import torch |
|
|
|
from diffusers.modular_pipelines import ( |
|
PipelineState, |
|
ModularPipelineBlocks, |
|
InputParam, |
|
ComponentSpec, |
|
OutputParam, |
|
) |
|
from transformers import AutoProcessor, AutoModelForCausalLM |
|
|
|
|
|
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
    """Annotate images with Florence-2 and rasterise polygon results as masks.

    For every (image, task prompt) pair the block runs the
    ``image_annotator`` model through ``image_annotator_processor``,
    post-processes the generated text into an annotation dict and draws every
    polygon found in that dict, in white, onto a black ``"L"`` image of the
    same size. The resulting list of mask images is stored on the block state
    as ``mask``.
    """

    @property
    def expected_components(self):
        """Return the component specs (model and processor) this block uses."""
        return [
            ComponentSpec(
                name="image_annotator",
                type_hint=AutoModelForCausalLM,
                repo="microsoft/Florence-2-large",
            ),
            ComponentSpec(
                name="image_annotator_processor",
                type_hint=AutoProcessor,
                repo="microsoft/Florence-2-large",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        """Return the inputs read from the pipeline state."""
        return [
            InputParam(
                "image",
                # ``Image`` is the PIL module; the class is ``Image.Image``.
                Union[Image.Image, List[Image.Image]],
                required=True,
                description="Image(s) to annotate",
            ),
            InputParam(
                "annotation_task_prompt",
                Union[str, List[str]],
                required=True,
                description="Annotation Task to perform on the image.\n",
            ),
        ]

    @property
    def intermediates_outputs(self) -> List[OutputParam]:
        """Return the outputs written back to the pipeline state."""
        return [
            OutputParam(
                "mask",
                # ``prepare_mask`` produces PIL images, not tensors.
                type_hint=List[Image.Image],
                description="Mask image(s) drawn from the annotations of the input Image(s)",
            ),
        ]

    def _resolve_component(self, name, components=None):
        """Look up component ``name`` on ``components``, else on the block.

        NOTE(review): modular pipelines normally expose loaded components on
        the object passed to ``__call__``; the attribute lookup on ``self`` is
        kept as a fallback because that is where the block used to look.
        """
        component = None
        if components is not None:
            component = getattr(components, name, None)
        if component is None:
            component = getattr(self, name)
        return component

    def annotate_image(self, image, prompt, components=None):
        """Run one Florence-2 task on one image and return the parsed result.

        Args:
            image: A single PIL image.
            prompt: The task prompt; it is also passed as ``task`` to the
                processor's post-processing step.
            components: Optional object holding ``image_annotator`` and
                ``image_annotator_processor``. When omitted, both are looked
                up on the block itself.

        Returns:
            Whatever ``post_process_generation`` returns for the first
            generated sequence.
        """
        processor = self._resolve_component("image_annotator_processor", components)
        model = self._resolve_component("image_annotator", components)

        inputs = processor(text=prompt, images=image, return_tensors="pt")
        # Keep the inputs on the same device/dtype as the model so generation
        # does not fail when the model was moved off the CPU or to half
        # precision. Assumes the model exposes ``device``/``dtype`` as
        # transformers models do — TODO confirm for custom models.
        inputs = inputs.to(model.device, model.dtype)

        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        generated_text = processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )[0]

        # Florence-2's reference usage passes the size as (width, height);
        # swapping the two distorts coordinates on non-square images.
        return processor.post_process_generation(
            generated_text, task=prompt, image_size=(image.width, image.height)
        )

    @staticmethod
    def _flat_polygons(node):
        """Yield each innermost coordinate list of a nested list structure."""
        if not isinstance(node, (list, tuple)) or not node:
            return
        if not any(isinstance(item, (list, tuple)) for item in node):
            yield list(node)
            return
        for child in node:
            yield from Florence2ImageAnnotatorBlock._flat_polygons(child)

    def _polygons_from(self, annotation):
        """Collect the drawable polygons contained in one annotation dict.

        Two layouts are understood: a ``"polygon"`` entry, handed to the
        drawing call untouched, and a ``"polygons"`` entry holding nested
        lists of flat ``[x0, y0, x1, y1, ...]`` coordinates. Both are searched
        at the top level and one level down, so results keyed by task name
        are found as well.

        Raises:
            KeyError: If the annotation has neither entry.
        """
        entries = []
        if isinstance(annotation, dict):
            entries.append(annotation)
            entries.extend(v for v in annotation.values() if isinstance(v, dict))

        polygons = []
        found = False
        for entry in entries:
            if "polygon" in entry:
                found = True
                polygons.append(entry["polygon"])
            if "polygons" in entry:
                found = True
                for flat in self._flat_polygons(entry["polygons"]):
                    # Fewer than three points cannot enclose an area.
                    if len(flat) >= 6:
                        polygons.append(flat)

        if not found:
            raise KeyError("annotation contains no 'polygon' or 'polygons' entry")
        return polygons

    def prepare_mask(self, images, annotations):
        """Draw each image's annotation polygons onto a black mask.

        Args:
            images: Sequence of PIL images; only their sizes are used.
            annotations: Sequence of annotation dicts, one per image.

        Returns:
            A list of ``"L"`` mode PIL images, one per (image, annotation)
            pair, with the polygons filled in white. Only polygon data is
            drawn.

        Raises:
            KeyError: If an annotation contains no polygon data at all.
        """
        masks = []
        for image, annotation in zip(images, annotations):
            mask_image = Image.new("L", image.size, 0)
            draw = ImageDraw.Draw(mask_image)
            for polygon in self._polygons_from(annotation):
                draw.polygon(polygon, fill="white")
            masks.append(mask_image)

        return masks

    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> tuple:
        """Annotate the input image(s) and store the resulting masks.

        Args:
            pipeline: The object the block is executed with; components are
                looked up on it first.
            state: Pipeline state providing ``image`` and
                ``annotation_task_prompt``.

        Returns:
            ``(pipeline, state)`` with ``mask`` written to the state.

        Raises:
            ValueError: If the number of images and prompts differ.
        """
        block_state = self.get_block_state(state)

        images = block_state.image
        annotation_task_prompt = block_state.annotation_task_prompt

        # A single image / prompt is accepted as well as lists of them.
        if isinstance(images, (list, tuple)):
            images = list(images)
        else:
            images = [images]
        if isinstance(annotation_task_prompt, (list, tuple)):
            annotation_task_prompt = list(annotation_task_prompt)
        else:
            annotation_task_prompt = [annotation_task_prompt]

        if len(images) != len(annotation_task_prompt):
            raise ValueError("Number of images and annotation prompts must match")

        # ``annotate_image`` handles exactly one image and one prompt, so the
        # pairs are processed one by one.
        annotations = [
            self.annotate_image(image, prompt, components=pipeline)
            for image, prompt in zip(images, annotation_task_prompt)
        ]
        block_state.mask = self.prepare_mask(images, annotations)

        # NOTE(review): diffusers blocks persist with
        # ``set_block_state(state, block_state)`` and hand back
        # ``(pipeline, state)`` — confirm against the installed version.
        self.set_block_state(state, block_state)

        return pipeline, state
|
|