#!/usr/bin/env python3
"""
Example inference script for the Multi-Head SigLIP Classifier from Hugging Face Hub.

Usage examples:

    # Multiple images, single text
    python example.py --image img1.png --image img2.jpg --repo fal/multihead_cls --text "an example caption"

    # N images, N texts (returns an N x N similarity matrix)
    python example.py \
        --image img1.png --image img2.jpg \
        --text "a cat" --text "a dog" --repo fal/multihead_cls

Requires: torch, transformers, huggingface_hub, Pillow, click
"""
import json

import click
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import AutoProcessor, SiglipModel

# Base SigLIP checkpoint used for the processor and the frozen backbone
CKPT = "google/siglip-base-patch16-256"
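
# Illustrative task_config.json layout, inferred from the fields this script
# reads; the task keys, labels, and descriptions below are hypothetical:
#
# {
#   "tasks": [
#     {"key": "is_photo", "name": "Is Photo", "type": "binary",
#      "labels": ["no", "yes"], "description": "Whether the image is a photograph"},
#     {"key": "style", "name": "Style", "type": "multi_class",
#      "labels": ["anime", "render", "sketch"], "description": "Dominant visual style"}
#   ]
# }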

# Local model definition replicated from training for easy inference
class MultiHeadSiglipClassifier(nn.Module):
    """Dynamic multi-head classifier based on task configuration"""

    def __init__(self, task_config: dict, model_name: str = CKPT):
        super().__init__()
        self.task_config = task_config
        self.siglip = SiglipModel.from_pretrained(model_name)

        # Freeze SigLIP parameters; only the linear heads carry trained weights
        for param in self.siglip.parameters():
            param.requires_grad = False

        # Create classification heads dynamically based on task config
        hidden_size = self.siglip.config.vision_config.hidden_size
        self.classification_heads = nn.ModuleDict()
        for task in task_config['tasks']:
            task_key = task['key']
            num_classes = len(task['labels'])
            # One linear head per task
            self.classification_heads[task_key] = nn.Linear(hidden_size, num_classes)

    def forward(self, pixel_values):
        # Get SigLIP image embeddings only
        image_embeds = self.siglip.get_image_features(pixel_values=pixel_values)
        # Apply all classification heads
        outputs = {}
        for task_key, head in self.classification_heads.items():
            outputs[task_key] = head(image_embeds)
        return outputs
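
# Illustrative shapes, assuming the hypothetical config sketched above: for a
# batch of B images, forward() returns
#   {"is_photo": FloatTensor[B, 2], "style": FloatTensor[B, 3]}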

def load_model_from_hf(repo_id: str):
    """Load model, processor, and task config from Hugging Face Hub"""
    # Download task configuration
    try:
        task_config_path = hf_hub_download(repo_id=repo_id, filename="task_config.json", repo_type="model")
        with open(task_config_path, 'r') as f:
            task_config = json.load(f)
    except Exception as e:
        raise RuntimeError(f"Could not load task_config.json from {repo_id}: {e}")

    # Load processor
    processor = AutoProcessor.from_pretrained(CKPT)

    # Create model with task config
    model = MultiHeadSiglipClassifier(task_config)

    # Load trained weights
    try:
        ckpt_path = hf_hub_download(repo_id=repo_id, filename="model.pth", repo_type="model")
        # weights_only=True restricts unpickling to plain tensors (PyTorch >= 1.13)
        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
    except Exception as e:
        raise RuntimeError(f"Could not load model.pth from {repo_id}: {e}")

    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    return model, processor, device, task_config
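
# Illustrative call using this script's default repo id:
#   model, processor, device, task_config = load_model_from_hf("fal/multihead_cls")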

def predict_batch(model, processor, device, task_config, image_paths, texts: list[str] | None = None):
    """Run predictions on a batch of images using dynamic task configuration"""
    images = [Image.open(p).convert("RGB") for p in image_paths]
    if texts is not None and len(texts) == 0:
        texts = None

    # Process images
    image_inputs = processor(images=images, return_tensors="pt")
    pixel_values = image_inputs["pixel_values"].to(device)

    with torch.no_grad():
        outputs = model(pixel_values)

        # Compute L2-normalized image embeddings for cosine similarity
        image_embeds = model.siglip.get_image_features(pixel_values=pixel_values)
        image_embeds = F.normalize(image_embeds, p=2, dim=-1)

        # Encode texts if provided (SigLIP uses max-length padding)
        text_embeds = None
        if texts is not None:
            text_inputs = processor(text=texts, padding="max_length", return_tensors="pt")
            input_ids = text_inputs["input_ids"].to(device)
            attention_mask = text_inputs.get("attention_mask")
            attention_mask = attention_mask.to(device) if attention_mask is not None else None
            text_embeds = model.siglip.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
            text_embeds = F.normalize(text_embeds, p=2, dim=-1)

    # Map task keys to their task definitions
    tasks = {task['key']: task for task in task_config['tasks']}
    batch_results = []
    batch_size = pixel_values.shape[0]
    for i in range(batch_size):
        item = {"image": str(image_paths[i])}
        # Process each task dynamically
        for task_key, task_info in tasks.items():
            logits = outputs[task_key][i]
            probs = torch.softmax(logits, dim=0)
            pred_idx = torch.argmax(probs).item()
            if task_info['type'] == 'binary':
                # Binary classification; assumes the label order is [no, yes]
                item[f"{task_key}_prediction"] = task_info['labels'][pred_idx]
                item[f"{task_key}_confidence"] = probs[pred_idx].item()
                item[f"{task_key}_prob_yes"] = probs[1].item() if len(task_info['labels']) > 1 else 0.0
                item[f"{task_key}_prob_no"] = probs[0].item()
            elif task_info['type'] == 'multi_class':
                # Multi-class classification: report probabilities for all classes
                item[f"{task_key}_prediction"] = task_info['labels'][pred_idx]
                item[f"{task_key}_confidence"] = probs[pred_idx].item()
                for idx, label in enumerate(task_info['labels']):
                    item[f"{task_key}_prob_{label}"] = probs[idx].item()
        batch_results.append(item)
    cosine_matrix = None
    if text_embeds is not None:
        # Both embedding sets were L2-normalized above, so the dot product is cosine similarity
        cosine = torch.matmul(image_embeds, text_embeds.T)  # [num_images, num_texts]
        cosine_matrix = cosine.cpu().tolist()

    return {
        "images": [str(p) for p in image_paths],
        "texts": texts or [],
        "task_config": task_config,
        "predictions": batch_results,
        "cosine_similarity": cosine_matrix,
    }
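
# Illustrative result for two images and one text with the hypothetical
# "is_photo" task (all values made up):
# {
#   "images": ["img1.png", "img2.jpg"],
#   "texts": ["a cat"],
#   "predictions": [{"image": "img1.png", "is_photo_prediction": "yes", ...}, ...],
#   "cosine_similarity": [[0.21], [0.08]]
# }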

@click.command()
@click.option("--image", "images", multiple=True, type=click.Path(exists=True, dir_okay=False, readable=True), help="Path(s) to image file(s). Can be passed multiple times.")
@click.option("--repo", default="fal/multihead_cls", show_default=True, help="Hugging Face repo id with model checkpoint.")
@click.option("--text", "texts", multiple=True, help="Text prompt(s). Can be passed multiple times to build an N x N image-text similarity matrix.")
@click.option("--show-tasks", is_flag=True, help="Show available classification tasks and exit.")
def cli(images, repo, texts, show_tasks):
    """Multi-head SigLIP classifier inference from Hugging Face Hub"""
    # Load model and task config
    model, processor, device, task_config = load_model_from_hf(repo)

    if show_tasks:
        click.echo("Available classification tasks:")
        for i, task in enumerate(task_config['tasks'], 1):
            click.echo(f"  {i}. {task['name']} ({task['key']})")
            click.echo(f"     Type: {task['type']}")
            click.echo(f"     Labels: {', '.join(task['labels'])}")
            click.echo(f"     Description: {task['description']}")
            click.echo()
        return

    if not images:
        raise click.UsageError("Provide at least one --image.")

    results = predict_batch(model, processor, device, task_config, list(images), texts=list(texts) if texts else None)
    click.echo(json.dumps(results, indent=2))

if __name__ == "__main__":
    cli()