dn6's picture
dn6 HF Staff
Upload folder using huggingface_hub
95e2523 verified
from typing import List, Union
from PIL import Image, ImageDraw
import torch
from diffusers.modular_pipelines import (
PipelineState,
ModularPipelineBlocks,
InputParam,
ComponentSpec,
OutputParam,
)
from transformers import AutoProcessor, AutoModelForCausalLM
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
@property
def expected_components(self):
return [
ComponentSpec(
name="image_annotator",
type_hint=AutoModelForCausalLM,
repo="microsoft/Florence-2-large",
),
ComponentSpec(
name="image_annotator_processor",
type_hint=AutoProcessor,
repo="microsoft/Florence-2-large",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"image",
Image,
required=True,
description="Image(s) to annotate",
),
InputParam(
"annotation_task_prompt",
Union[str, List[str]],
required=True,
description="""Annotation Task to perform on the image.
""",
),
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"mask",
type_hint=torch.Tensor,
description="Depth Map(s) of input Image(s)",
),
]
def annotate_image(self, image, prompt):
inputs = self.image_annotator_processor(
text=prompt, images=image, return_tensors="pt"
)
generated_ids = self.annotator.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
early_stopping=False,
do_sample=False,
num_beams=3,
)
annotations = self.image_annotator_processor.batch_decode(
generated_ids, skip_special_tokens=False
)[0]
annotations = self.image_annotator_processor.post_process_generation(
annotations, task=prompt, image_size=(image.height, image.width)
)
return annotations
def prepare_mask(self, images, annotations):
masks = []
for image, annotation in zip(images, annotations):
mask_image = Image.new("L", image.size, 0)
draw = ImageDraw.Draw(mask_image)
draw.polygon(annotation["polygon"], fill="white")
masks.append(mask_image)
return masks
@torch.no_grad()
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
images = block_state.image
annotation_task_prompt = block_state.annotation_task_prompt
if not isinstance(annotation_task_prompt, list):
annotation_task_prompt = [annotation_task_prompt]
if len(images) != len(annotation_task_prompt):
raise ValueError("Number of images and annotation prompts must match")
annotations = self.annotate_image(images, annotation_task_prompt)
block_state.mask = self.prepare_mask(images, annotations)
self.set_block_state(block_state)