"""Gradio demo comparing D-FINE detections with annotated sample images."""

import logging

import gradio as gr
from datasets import load_dataset
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoModel, AutoProcessor

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Introductory text rendered at the top of the page.
_INTRO_MARKDOWN = """
## D-FINE Detection Demo

This demo compares annotated ground truth data (in **red**) and model predictions (in **blue**). Use the **slider** to visually compare both views:

- The **left image** shows the annotated labels.
- The **right image** displays predictions from the selected D-FINE model, with each bounding box and its confidence score.

📂 **Training Dataset**: All D-FINE variants were trained on the [L&A Pucks Dataset](https://huggingface.co/datasets/Laudando-Associates-LLC/pucks) available on Hugging Face.
"""

# Citation text rendered at the bottom of the page.
_CITATION_MARKDOWN = """
If you use **D-FINE** or its methods in your work, please cite the following BibTeX entry:

```latex
@misc{peng2024dfine,
    title={D-FINE: Redefine Regression Task in DETRs as Fine-grained Distribution Refinement},
    author={Yansong Peng and Hebei Li and Peixi Wu and Yueyi Zhang and Xiaoyan Sun and Feng Wu},
    year={2024},
    eprint={2410.13842},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```
"""


class DFineDemo:
    """Interactive demo that runs D-FINE models on sample images.

    On construction the shared processor, every model variant listed in
    ``model_variants`` and the ``test`` split of the sample dataset are
    loaded, so building an instance performs all of the heavy I/O up front.
    """

    def __init__(self):
        """Load the processor, all model variants and the sample images."""
        self.processor = AutoProcessor.from_pretrained(
            "Laudando-Associates-LLC/d-fine", trust_remote_code=True
        )

        # Display name -> Hugging Face repository id.
        self.model_variants = {
            "D-FINE Nano": "Laudando-Associates-LLC/d-fine-nano",
            "D-FINE Small": "Laudando-Associates-LLC/d-fine-small",
            "D-FINE Medium": "Laudando-Associates-LLC/d-fine-medium",
            "D-FINE Large": "Laudando-Associates-LLC/d-fine-large",
            "D-FINE X-Large": "Laudando-Associates-LLC/d-fine-xlarge",
        }

        logger.info("Loading all D-FINE model variants into memory...")
        self.models = {
            name: AutoModel.from_pretrained(repo, trust_remote_code=True)
            for name, repo in self.model_variants.items()
        }

        dataset = load_dataset("Laudando-Associates-LLC/pucks", split="test")
        # Label -> {"input": raw image, "annotated": image with labels drawn}.
        self.image_cache = {
            f"Test Image {i+1}": {
                "input": example["image"],
                "annotated": example["annotated_image"],
            }
            for i, example in enumerate(dataset)
        }
        # Gallery positions map onto this list by index.
        self.image_labels = list(self.image_cache)

    def _find_annotated(self, input_image):
        """Return the annotated counterpart of *input_image*.

        The cache is searched by comparing stored input images with
        ``input_image``; when nothing matches, ``input_image`` itself is
        returned so the comparison slider still has two images to show.
        """
        for pair in self.image_cache.values():
            if pair["input"] == input_image:
                return pair["annotated"]
        return input_image

    @staticmethod
    def _load_label_font(size=24):
        """Return the font used for score labels.

        Falls back to Pillow's built-in default font when the DejaVu
        TrueType file cannot be read, instead of aborting inference.
        """
        try:
            return ImageFont.truetype("DejaVuSans-Bold.ttf", size=size)
        except OSError:
            logger.warning(
                "DejaVuSans-Bold.ttf is unavailable; using Pillow's default font."
            )
            return ImageFont.load_default()

    def run_inference(self, input_image, model_name, threshold):
        """Run the chosen model on *input_image* and draw its detections.

        Args:
            input_image: Image to run detection on; it is copied before
                drawing, so the caller's image is left untouched.
            model_name: Key into ``self.models`` selecting the variant.
            threshold: Value forwarded to the model as ``conf_threshold``.

        Returns:
            A ``gr.update`` whose value is the pair
            ``(annotated_image, predicted_image)``.
        """
        annotated = self._find_annotated(input_image)

        # Predict on a copy so the cached sample is not drawn over.
        image = input_image.copy()
        inputs = self.processor(image)
        outputs = self.models[model_name](**inputs, conf_threshold=threshold)

        draw = ImageDraw.Draw(image)
        font = self._load_label_font()
        for result in outputs:
            for box, score in zip(result["boxes"], result["scores"]):
                x1, y1, x2, y2 = box.tolist()
                draw.rectangle([x1, y1, x2, y2], outline="blue", width=5)
                # Keep the score label inside the image for boxes at the top.
                draw.text(
                    (x1, max(0, y1 - 25)), f"{score:.2f}", fill="blue", font=font
                )

        # Return: (annotated_image, predicted_image)
        return gr.update(
            value=(annotated, image),
            slider_position=50,
            format="png",
            type="pil",
        )

    def select_image(self, evt: gr.SelectData):
        """Return the input image for the gallery item that was selected.

        When no selection information is available, an empty ``gr.update()``
        is returned so the current state is left as it is.
        """
        if evt is None or evt.index is None:
            return gr.update()
        label = self.image_labels[evt.index]
        return self.image_cache[label]["input"]

    def launch(self):
        """Build the Gradio interface and start the app."""
        with gr.Blocks(theme=gr.themes.Ocean()) as demo:
            gr.Markdown(_INTRO_MARKDOWN)

            output = gr.ImageSlider(
                type="pil",
                label="Detected Output",
                height=500,
                width=880,
                slider_position=50,
                format="png",
            )

            with gr.Row():
                model_selector = gr.Radio(
                    choices=list(self.model_variants.keys()),
                    label="Choose D-FINE model",
                    value="D-FINE Nano",
                )
                # NOTE(review): the maximum sits slightly above 0.95,
                # presumably so the 0.95 step stays reachable — confirm.
                threshold_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.955,
                    value=0.4,
                    step=0.05,
                    label="Confidence Threshold",
                )

            run_btn = gr.Button("Run Detection")

            # The state must hold the input image itself (the same thing
            # select_image returns), not the whole cache entry; otherwise
            # running detection before picking a sample receives a dict.
            selected_image = gr.State(
                value=self.image_cache[self.image_labels[0]]["input"]
            )

            gr.Markdown("### Select a sample image below:")
            gallery = gr.Gallery(
                value=[
                    (pair["input"], label)
                    for label, pair in self.image_cache.items()
                ],
                label=None,
                show_label=False,
                columns=[3],
                object_fit="cover",
                height="auto",
                allow_preview=False,
            )

            gallery.select(
                fn=self.select_image,
                inputs=[],
                outputs=selected_image,
            )
            run_btn.click(
                fn=self.run_inference,
                inputs=[selected_image, model_selector, threshold_slider],
                outputs=output,
            )

            gr.Markdown("### Citation")
            gr.Markdown(_CITATION_MARKDOWN)

        demo.launch()


if __name__ == "__main__":
    app = DFineDemo()
    app.launch()