| from typing import List, Union |
| from PIL import Image, ImageDraw |
| import torch |
|
|
| from diffusers.modular_pipelines import ( |
| PipelineState, |
| ModularPipelineBlocks, |
| InputParam, |
| ComponentSpec, |
| OutputParam, |
| ) |
| from transformers import AutoProcessor, AutoModelForCausalLM |
|
|
|
|
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
    """Annotate images with Florence-2 and rasterize the results into masks.

    For every (image, task prompt) pair the Florence-2 model generates an
    annotation, which the processor post-processes into a parsed dict. Polygon
    and bounding-box annotations are then drawn in white onto a black
    single-channel ("L") canvas the size of the image, and the resulting masks
    are stored in the ``mask`` block-state field.
    """

    @property
    def expected_components(self):
        # NOTE(review): Florence-2 may require ``trust_remote_code=True`` to
        # load; confirm how ComponentSpec forwards that option.
        return [
            ComponentSpec(
                name="image_annotator",
                type_hint=AutoModelForCausalLM,
                repo="microsoft/Florence-2-large",
            ),
            ComponentSpec(
                name="image_annotator_processor",
                type_hint=AutoProcessor,
                repo="microsoft/Florence-2-large",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "image",
                Union[Image.Image, List[Image.Image]],
                required=True,
                description="Image(s) to annotate",
            ),
            InputParam(
                "annotation_task_prompt",
                Union[str, List[str]],
                required=True,
                description=(
                    "Florence-2 task prompt(s) to run on the image(s), e.g. "
                    "'<REFERRING_EXPRESSION_SEGMENTATION>a dog'. Pass one prompt "
                    "per image, or a single prompt to use for every image."
                ),
            ),
        ]

    @property
    def intermediates_outputs(self) -> List[OutputParam]:
        # NOTE(review): newer diffusers releases may read this property as
        # ``intermediate_outputs``; confirm against the installed version.
        return [
            OutputParam(
                "mask",
                type_hint=List[Image.Image],
                description=(
                    "Mask image(s) ('L' mode, annotated regions white) "
                    "for the input image(s)"
                ),
            ),
        ]

    @staticmethod
    def _task_token(prompt):
        """Return the leading ``<TASK>`` token of a Florence-2 prompt.

        A prompt may carry free text after its task token (for example
        ``"<CAPTION_TO_PHRASE_GROUNDING>a dog"``). Post-processing is selected
        by the token alone, so the text is stripped. Prompts without a leading
        ``<...>`` token are returned unchanged.
        """
        if prompt.startswith("<"):
            end = prompt.find(">")
            if end != -1:
                return prompt[: end + 1]
        return prompt

    def annotate_image(self, image, prompt, components=None):
        """Run Florence-2 on a single image and return the parsed annotation.

        Args:
            image: PIL image to annotate.
            prompt: Florence-2 task prompt, optionally followed by text input.
            components: Object exposing ``image_annotator`` and
                ``image_annotator_processor`` (the pipeline/components object
                given to ``__call__``). Falls back to ``self`` when omitted.

        Returns:
            The result of the processor's ``post_process_generation`` for this
            image, keyed by the prompt's task token.
        """
        components = self if components is None else components
        model = components.image_annotator
        processor = components.image_annotator_processor

        # Move inputs to the model's device; BatchFeature.to only casts
        # floating-point tensors to the dtype, so input_ids stay integer.
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(
            model.device, model.dtype
        )
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        # Keep special tokens: Florence-2 post-processing parses the location
        # tokens embedded in the generated text.
        generated_text = processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )[0]
        # post_process_generation expects image_size as (width, height).
        return processor.post_process_generation(
            generated_text,
            task=self._task_token(prompt),
            image_size=(image.width, image.height),
        )

    def prepare_mask(self, images, annotations):
        """Rasterize annotations into one single-channel mask per image.

        Args:
            images: Images the annotations belong to; each mask has the size
                of its image.
            annotations: One parsed Florence-2 result per image: a dict mapping
                a task token to a dict that may contain ``"polygons"`` and/or
                ``"bboxes"``. A bare ``{"polygon": [...]}`` dict is also
                accepted.

        Returns:
            List of ``"L"``-mode PIL images with a black background and the
            annotated regions filled white.
        """
        masks = []
        for image, annotation in zip(images, annotations):
            mask_image = Image.new("L", image.size, 0)
            draw = ImageDraw.Draw(mask_image)

            if "polygon" in annotation:
                # Legacy input shape: a single flat polygon.
                draw.polygon(annotation["polygon"], fill="white")

            for result in annotation.values():
                # Non-dict values (e.g. plain text results) carry no geometry.
                if not isinstance(result, dict):
                    continue
                # "polygons" holds, per detected instance, a list of flat
                # [x1, y1, x2, y2, ...] coordinate lists.
                for instance_polygons in result.get("polygons", []):
                    for polygon in instance_polygons:
                        coords = list(polygon)[: len(polygon) // 2 * 2]
                        if len(coords) >= 6:  # a polygon needs >= 3 vertices
                            draw.polygon(coords, fill="white")
                for bbox in result.get("bboxes", []):
                    draw.rectangle(bbox, fill="white")

            masks.append(mask_image)

        return masks

    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        """Annotate every input image and store the resulting masks.

        Raises:
            ValueError: If the number of prompts (after broadcasting a single
                prompt) does not match the number of images.
        """
        block_state = self.get_block_state(state)

        images = block_state.image
        if not isinstance(images, (list, tuple)):
            images = [images]
        images = list(images)

        prompts = block_state.annotation_task_prompt
        if not isinstance(prompts, (list, tuple)):
            prompts = [prompts]
        prompts = list(prompts)
        # A single prompt applies to every image.
        if len(prompts) == 1 and len(images) > 1:
            prompts = prompts * len(images)

        if len(images) != len(prompts):
            raise ValueError("Number of images and annotation prompts must match")

        # Annotate image/prompt pairs one at a time so each gets its own task
        # token and image size during post-processing.
        annotations = [
            self.annotate_image(image, prompt, components=pipeline)
            for image, prompt in zip(images, prompts)
        ]
        block_state.mask = self.prepare_mask(images, annotations)

        self.set_block_state(state, block_state)
        return pipeline, state
|
|