from typing import Any, List, Tuple, Union

import torch
from diffusers.modular_pipelines import (
    ComponentSpec,
    InputParam,
    ModularPipelineBlocks,
    OutputParam,
    PipelineState,
)
from PIL import Image, ImageDraw
from transformers import AutoModelForCausalLM, AutoProcessor


class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
    """Modular pipeline block that turns Florence-2 annotations into masks.

    For every input image the block runs the Florence-2 model with the
    matching task prompt, parses the generated text with the processor and
    rasterises the resulting polygons into a single-channel ("L") PIL mask
    that is stored on the block state as ``mask``.
    """

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Declare the Florence-2 model and processor this block relies on."""
        return [
            ComponentSpec(
                name="image_annotator",
                type_hint=AutoModelForCausalLM,
                repo="microsoft/Florence-2-large",
            ),
            ComponentSpec(
                name="image_annotator_processor",
                type_hint=AutoProcessor,
                repo="microsoft/Florence-2-large",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        """Declare the state values read by ``__call__``."""
        return [
            InputParam(
                "image",
                # ``Image`` alone is the PIL module; the type is ``Image.Image``.
                Union[Image.Image, List[Image.Image]],
                required=True,
                description="Image(s) to annotate",
            ),
            InputParam(
                "annotation_task_prompt",
                Union[str, List[str]],
                required=True,
                description="Annotation Task to perform on the image. ",
            ),
        ]

    @property
    def intermediates_outputs(self) -> List[OutputParam]:
        """Declare the state values written by ``__call__``."""
        return [
            OutputParam(
                "mask",
                # prepare_mask() returns PIL images, not tensors.
                type_hint=List[Image.Image],
                description="Mask image(s) drawn from the annotations of the input Image(s)",
            ),
        ]

    @staticmethod
    def _task_token(prompt: str) -> str:
        """Return the leading ``<TASK>`` token of *prompt*.

        Prompts may carry extra text after the task token (for example
        ``"<REFERRING_EXPRESSION_SEGMENTATION>a cat"``), while the processor's
        post-processing is keyed by the bare token. Prompts that do not start
        with ``<...>`` are returned unchanged.
        """
        head, closer, _ = prompt.partition(">")
        if prompt.startswith("<") and closer:
            return head + closer
        return prompt

    @staticmethod
    def _flatten_points(polygon) -> list:
        """Flatten ``[(x, y), ...]`` or ``[x, y, ...]`` into ``[x, y, ...]``."""
        flat = []
        for item in polygon:
            if isinstance(item, (list, tuple)):
                flat.extend(item)
            else:
                flat.append(item)
        return flat

    @staticmethod
    def _iter_polygons(annotation: dict):
        """Yield every polygon contained in one parsed annotation.

        Accepted shapes:

        * ``{"polygon": [...]}`` -- a single polygon (the shape the previous
          implementation of ``prepare_mask`` expected);
        * ``{"polygons": [[poly, ...], ...]}`` -- one list of polygons per
          detected instance;
        * ``{"<TASK>": {"polygons": ...}}`` -- presumably what
          ``post_process_generation`` returns for segmentation tasks, per the
          Florence-2 model card; confirm against the processor version in use.

        Raises:
            KeyError: if the annotation carries no polygon data at all, e.g.
                because the task prompt was not a segmentation task.
        """
        if "polygon" in annotation:
            yield annotation["polygon"]
            return

        if "polygons" in annotation:
            payloads = [annotation]
        else:
            payloads = [
                value
                for value in annotation.values()
                if isinstance(value, dict) and "polygons" in value
            ]
        if not payloads:
            raise KeyError("polygons")

        for payload in payloads:
            for instance_polygons in payload["polygons"]:
                yield from instance_polygons

    def annotate_image(self, image, prompt, components=None):
        """Run Florence-2 on one image and return the parsed annotation.

        Args:
            image: A single PIL image.
            prompt: Task prompt for that image; it must start with the task
                token, optionally followed by additional text input.
            components: Object exposing ``image_annotator`` and
                ``image_annotator_processor`` (the pipeline handed to
                ``__call__``). Falls back to ``self`` when omitted.

        Returns:
            Whatever ``post_process_generation`` produces for the prompt's
            task token.
        """
        owner = self if components is None else components
        processor = owner.image_annotator_processor
        model = owner.image_annotator

        # Inputs must live on the model's device and share its float dtype;
        # BatchFeature.to() only casts floating-point tensors.
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(
            model.device, model.dtype
        )
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        # Special tokens are kept because the location tokens encode the
        # coordinates that post-processing decodes.
        generated_text = processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )[0]
        return processor.post_process_generation(
            generated_text,
            task=self._task_token(prompt),
            # The Florence-2 model card passes (width, height), not (height, width).
            image_size=(image.width, image.height),
        )

    def prepare_mask(self, images, annotations):
        """Rasterise annotations into one white-on-black "L" mask per image.

        Args:
            images: PIL images; only their ``size`` is used.
            annotations: One parsed annotation per image, in any shape
                accepted by ``_iter_polygons``.

        Returns:
            A list of PIL images in mode "L", one per input image.

        Raises:
            KeyError: if an annotation contains no polygon data.
        """
        masks = []
        for image, annotation in zip(images, annotations):
            mask_image = Image.new("L", image.size, 0)
            draw = ImageDraw.Draw(mask_image)
            for polygon in self._iter_polygons(annotation):
                points = self._flatten_points(polygon)
                # Fewer than three vertices cannot enclose any area.
                if len(points) < 6:
                    continue
                draw.polygon(points, fill="white")
            masks.append(mask_image)
        return masks

    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> Tuple[Any, PipelineState]:
        """Annotate the input image(s) and store the masks on the state.

        Args:
            pipeline: Object holding the loaded components declared in
                ``expected_components``.
            state: Pipeline state providing ``image`` and
                ``annotation_task_prompt``.

        Returns:
            The ``(pipeline, state)`` pair, with ``mask`` written to ``state``.

        Raises:
            ValueError: if the number of images and prompts differ.
        """
        block_state = self.get_block_state(state)

        images = block_state.image
        if not isinstance(images, (list, tuple)):
            images = [images]

        prompts = block_state.annotation_task_prompt
        if not isinstance(prompts, (list, tuple)):
            prompts = [prompts]

        if len(images) != len(prompts):
            raise ValueError("Number of images and annotation prompts must match")

        # One generation per image: each image has its own size, which the
        # post-processing needs to rescale the predicted coordinates.
        annotations = [
            self.annotate_image(image, prompt, components=pipeline)
            for image, prompt in zip(images, prompts)
        ]
        block_state.mask = self.prepare_mask(images, annotations)

        self.set_block_state(state, block_state)
        return pipeline, state