#!/usr/bin/env python3
"""
Example inference script for the Multi-Head SigLIP2 Classifier from Hugging Face Hub.

Usage examples:

    # Multiple images, single text
    python example.py --image img1.png --image img2.jpg --repo fal/multihead_cls --text "an example caption"

    # N images, N texts (returns an N x N similarity matrix)
    python example.py \
        --image img1.png --image img2.jpg \
        --text "a cat" --text "a dog" --repo fal/multihead_cls

Requires: torch, transformers, huggingface_hub, Pillow, click
"""

import json

import click
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import AutoProcessor, SiglipModel
from huggingface_hub import hf_hub_download

CKPT = "google/siglip-base-patch16-256"


# Local model definition replicated from training for easy inference
class MultiHeadSiglipClassifier(nn.Module):
    """Dynamic multi-head classifier based on task configuration."""

    def __init__(self, task_config: dict, model_name: str = CKPT):
        super().__init__()
        self.task_config = task_config
        self.siglip = SiglipModel.from_pretrained(model_name)

        # Freeze SigLIP parameters; only the classification heads are trained
        for param in self.siglip.parameters():
            param.requires_grad = False

        # Create classification heads dynamically based on the task config
        hidden_size = self.siglip.config.vision_config.hidden_size
        self.classification_heads = nn.ModuleDict()
        for task in task_config["tasks"]:
            task_key = task["key"]
            num_classes = len(task["labels"])
            # One linear head per task
            self.classification_heads[task_key] = nn.Linear(hidden_size, num_classes)

    def forward(self, pixel_values):
        # Get SigLIP image embeddings only
        image_embeds = self.siglip.get_image_features(pixel_values=pixel_values)

        # Apply all classification heads
        outputs = {}
        for task_key, head in self.classification_heads.items():
            outputs[task_key] = head(image_embeds)
        return outputs


def load_model_from_hf(repo_id: str):
    """Load model, processor, and task config from the Hugging Face Hub."""
    # Download the task configuration
    try:
        task_config_path = hf_hub_download(repo_id=repo_id, filename="task_config.json", repo_type="model")
        with open(task_config_path, "r") as f:
            task_config = json.load(f)
    except Exception as e:
        raise RuntimeError(f"Could not load task_config.json from {repo_id}: {e}")

    # Load the processor from the base checkpoint
    processor = AutoProcessor.from_pretrained(CKPT)

    # Build the model from the task config
    model = MultiHeadSiglipClassifier(task_config)

    # Load the trained weights
    try:
        ckpt_path = hf_hub_download(repo_id=repo_id, filename="model.pth", repo_type="model")
        state_dict = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict)
    except Exception as e:
        raise RuntimeError(f"Could not load model.pth from {repo_id}: {e}")

    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    return model, processor, device, task_config
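
# The classification heads above are driven entirely by `task_config.json`.
# The actual config shipped with the checkpoint is not reproduced here; as an
# illustration only, a file this script can consume might look like the
# following (keys and labels below are assumptions, not the real
# fal/multihead_cls tasks):
#
#   {
#     "tasks": [
#       {"key": "is_photo", "name": "Is Photo", "type": "binary",
#        "labels": ["no", "yes"],
#        "description": "Whether the image is a photograph"},
#       {"key": "style", "name": "Style", "type": "multi_class",
#        "labels": ["realistic", "cartoon", "abstract"],
#        "description": "Dominant visual style of the image"}
#     ]
#   }
#
# Note the label order for binary tasks: predict_batch below treats index 0
# as "no" and index 1 as "yes".
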
def predict_batch(model, processor, device, task_config, image_paths, texts: list[str] | None = None):
    """Run predictions on a batch of images using the dynamic task configuration."""
    images = [Image.open(p).convert("RGB") for p in image_paths]
    if texts is not None and len(texts) == 0:
        texts = None

    # Process images
    image_inputs = processor(images=images, return_tensors="pt")
    pixel_values = image_inputs["pixel_values"].to(device)

    with torch.no_grad():
        outputs = model(pixel_values)

        # Compute image embeddings for similarity
        image_embeds = model.siglip.get_image_features(pixel_values=pixel_values)
        image_embeds = F.normalize(image_embeds, p=2, dim=-1)

        # Prepare text inputs if provided
        text_embeds = None
        if texts is not None:
            text_inputs = processor(text=texts, padding="max_length", return_tensors="pt")
            input_ids = text_inputs["input_ids"].to(device)
            attention_mask = text_inputs.get("attention_mask")
            attention_mask = attention_mask.to(device) if attention_mask is not None else None
            text_embeds = model.siglip.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
            text_embeds = F.normalize(text_embeds, p=2, dim=-1)

    # Map task keys to their configs
    tasks = {task["key"]: task for task in task_config["tasks"]}

    batch_results = []
    batch_size = pixel_values.shape[0]
    for i in range(batch_size):
        item = {"image": str(image_paths[i])}

        # Process each task dynamically
        for task_key, task_info in tasks.items():
            logits = outputs[task_key][i]
            probs = torch.softmax(logits, dim=0)
            pred_idx = torch.argmax(probs).item()

            if task_info["type"] == "binary":
                # Binary classification: index 0 is "no", index 1 is "yes"
                item[f"{task_key}_prediction"] = task_info["labels"][pred_idx]
                item[f"{task_key}_confidence"] = float(probs[pred_idx].item())
                item[f"{task_key}_prob_yes"] = float(probs[1].item()) if len(task_info["labels"]) > 1 else 0.0
                item[f"{task_key}_prob_no"] = float(probs[0].item())
            elif task_info["type"] == "multi_class":
                # Multi-class classification
                item[f"{task_key}_prediction"] = task_info["labels"][pred_idx]
                item[f"{task_key}_confidence"] = float(probs[pred_idx].item())
                # Probabilities for all classes
                for idx, label in enumerate(task_info["labels"]):
                    item[f"{task_key}_prob_{label}"] = float(probs[idx].item())

        batch_results.append(item)

    cosine_matrix = None
    if text_embeds is not None:
        # Both embedding sets were L2-normalized above, so the dot product
        # is the cosine similarity (rows: images, columns: texts)
        cosine = torch.matmul(image_embeds, text_embeds.T)
        cosine_matrix = cosine.cpu().tolist()

    return {
        "images": [str(p) for p in image_paths],
        "texts": texts or [],
        "task_config": task_config,
        "predictions": batch_results,
        "cosine_similarity": cosine_matrix,
    }


@click.command()
@click.option("--image", "images", multiple=True,
              type=click.Path(exists=True, dir_okay=False, readable=True),
              help="Path(s) to image file(s). Can be passed multiple times.")
@click.option("--repo", default="fal/multihead_cls", show_default=True,
              help="Hugging Face repo id with model checkpoint.")
@click.option("--text", "texts", multiple=True,
              help="Text prompt(s). Can be passed multiple times to build an N x N image-text similarity matrix.")
@click.option("--show-tasks", is_flag=True, help="Show available classification tasks and exit.")
def cli(images, repo, texts, show_tasks):
    """Multi-head SigLIP2 classifier inference from the Hugging Face Hub."""
    # Load model and task config
    model, processor, device, task_config = load_model_from_hf(repo)

    if show_tasks:
        click.echo("Available classification tasks:")
        for i, task in enumerate(task_config["tasks"], 1):
            click.echo(f"  {i}. {task['name']} ({task['key']})")
            click.echo(f"     Type: {task['type']}")
            click.echo(f"     Labels: {', '.join(task['labels'])}")
            click.echo(f"     Description: {task['description']}")
            click.echo()
        return

    if not images:
        # Fall back to a sample image when none is provided
        images = ("img.png",)

    results = predict_batch(model, processor, device, task_config, list(images),
                            texts=list(texts) if texts else None)
    click.echo(json.dumps(results, indent=2))


if __name__ == "__main__":
    cli()
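
# Programmatic usage (a sketch with placeholder image paths; load the model
# once and reuse it across batches):
#
#   model, processor, device, task_config = load_model_from_hf("fal/multihead_cls")
#   results = predict_batch(model, processor, device, task_config,
#                           ["img1.png", "img2.jpg"], texts=["a cat", "a dog"])
#   print(results["cosine_similarity"])  # 2 x 2 image-text matrix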