import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel
from typing import List, Union

from .config import MODEL_PATHS
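
# NOTE (assumption): MODEL_PATHS is defined in .config and is not shown here. From its
# use below it is expected to be a dict mapping "clip" to the processor checkpoint and
# "pickscore" to the scoring-model checkpoint, e.g. (illustrative values only):
# MODEL_PATHS = {
#     "clip": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
#     "pickscore": "yuvalkirstain/PickScore_v1",
# }
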
class PickScore(torch.nn.Module):
    def __init__(self, device: Union[str, torch.device], path: dict = MODEL_PATHS):
        """Initialize the selector with a CLIP processor and the PickScore model.

        Args:
            device (Union[str, torch.device]): The device to load the model on.
            path (dict): Mapping with the processor ("clip") and model ("pickscore") identifiers.
        """
        super().__init__()
        self.device = device if isinstance(device, torch.device) else torch.device(device)
        processor_name_or_path = path.get("clip")
        model_pretrained_name_or_path = path.get("pickscore")
        self.processor = AutoProcessor.from_pretrained(processor_name_or_path)
        self.model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(self.device)
    def _calculate_score(self, image: torch.Tensor, prompt: str, softmax: bool = False) -> float:
        """Calculate the score for a single processed image and a prompt.

        Args:
            image (torch.Tensor): The processed image tensor (pixel values).
            prompt (str): The prompt text.
            softmax (bool): Whether to apply the logit scale and softmax to the score.

        Returns:
            float: The score for the image.
        """
        with torch.no_grad():
            # Prepare text inputs
            text_inputs = self.processor(
                text=prompt,
                padding=True,
                truncation=True,
                max_length=77,
                return_tensors="pt",
            ).to(self.device)

            # Embed the image and the text, then L2-normalize both embeddings
            image_embs = self.model.get_image_features(pixel_values=image)
            image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)

            text_embs = self.model.get_text_features(**text_inputs)
            text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)

            # Compute the cosine-similarity score
            score = (text_embs @ image_embs.T)[0]

            if softmax:
                # Apply the logit scale and softmax
                score = torch.softmax(self.model.logit_scale.exp() * score, dim=-1)

            return score.cpu().item()
    def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str, softmax: bool = False) -> List[float]:
        """Score the image(s) against the prompt.

        Args:
            images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
            prompt (str): The prompt text.
            softmax (bool): Whether to apply softmax to the scores.

        Returns:
            List[float]: List of scores, one per image.
        """
        try:
            if isinstance(images, (str, Image.Image)):
                # Single image (file path or PIL image)
                pil_image = Image.open(images) if isinstance(images, str) else images

                # Prepare image inputs
                image_inputs = self.processor(
                    images=pil_image,
                    padding=True,
                    truncation=True,
                    max_length=77,
                    return_tensors="pt",
                ).to(self.device)

                return [self._calculate_score(image_inputs["pixel_values"], prompt, softmax)]
            elif isinstance(images, list):
                # Multiple images: score each one individually
                scores = []
                for one_image in images:
                    if isinstance(one_image, str):
                        pil_image = Image.open(one_image)
                    elif isinstance(one_image, Image.Image):
                        pil_image = one_image
                    else:
                        raise TypeError("Each element of `images` must be a file path or a PIL image.")

                    # Prepare image inputs
                    image_inputs = self.processor(
                        images=pil_image,
                        padding=True,
                        truncation=True,
                        max_length=77,
                        return_tensors="pt",
                    ).to(self.device)

                    scores.append(self._calculate_score(image_inputs["pixel_values"], prompt, softmax))
                return scores
            else:
                raise TypeError("`images` must be a file path, a PIL image, or a list of either.")
        except Exception as e:
            raise RuntimeError(f"Error in scoring images: {e}") from e
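
# Minimal usage sketch: assumes MODEL_PATHS resolves to valid checkpoints and a CUDA
# device is available (use "cpu" otherwise). The image paths and prompt are placeholders.
# Because of the relative import above, run this as a module (python -m <package>.<module>).
if __name__ == "__main__":
    selector = PickScore(device="cuda")
    scores = selector.score(
        images=["candidate_0.png", "candidate_1.png"],
        prompt="a photo of an astronaut riding a horse",
        softmax=True,
    )
    print(scores)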