import torch
from PIL import Image
from safetensors.torch import load_file
from torch import einsum
from transformers import AutoTokenizer, CLIPImageProcessor
from typing import List, Union

from .trainer.models import clip_model
from .config import MODEL_PATHS


class MPScore(torch.nn.Module):
    def __init__(self, device: Union[str, torch.device], path: dict = MODEL_PATHS, condition: str = 'overall'):
        """Initialize the MPScore model with an image processor, tokenizer, and CLIP backbone.

        Args:
            device (Union[str, torch.device]): The device to load the model on.
            path (dict): Mapping with the keys "clip" (processor/tokenizer name or path)
                and "mps" (safetensors checkpoint path). Defaults to MODEL_PATHS.
            condition (str): Scoring condition, one of 'overall', 'aesthetics',
                'quality', or 'semantic'. Defaults to 'overall'.
        """
        super().__init__()
        self.device = device
        processor_name_or_path = path.get("clip")
        self.image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)
        self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
        self.model = clip_model.CLIPModel(processor_name_or_path, config_file=True)
        state_dict = load_file(path.get("mps"))
        self.model.load_state_dict(state_dict, strict=False)
        self.model.to(device)
        self.condition = condition

    def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
        """Calculate the reward score for a single image and prompt.

        Args:
            image (torch.Tensor): The processed image tensor.
            prompt (str): The prompt text.

        Returns:
            float: The reward score.
        """
        def _tokenize(caption):
            input_ids = self.tokenizer(
                caption,
                max_length=self.tokenizer.model_max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).input_ids
            return input_ids

        text_input = _tokenize(prompt).to(self.device)

        # The condition prompt lists the visual aspects the score should focus on.
        if self.condition == 'overall':
            condition_prompt = ('light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, '
                                'hands, limbs, structure, instance, texture, quantity, attributes, position, '
                                'number, location, word, things')
        elif self.condition == 'aesthetics':
            condition_prompt = 'light, color, clarity, tone, style, ambiance, artistry'
        elif self.condition == 'quality':
            condition_prompt = 'shape, face, hair, hands, limbs, structure, instance, texture'
        elif self.condition == 'semantic':
            condition_prompt = 'quantity, attributes, position, number, location'
        else:
            raise ValueError(
                f"Unsupported condition: {self.condition}. "
                "Choose 'overall', 'aesthetics', 'quality', or 'semantic'."
            )
        condition_batch = _tokenize(condition_prompt).repeat(text_input.shape[0], 1).to(self.device)

        with torch.no_grad():
            # Per-token text features and pooled text features from the CLIP backbone.
            text_f, text_features = self.model.model.get_text_features(text_input)
            image_f = self.model.model.get_image_features(image.half())
            condition_f, _ = self.model.model.get_text_features(condition_batch)

            # Build an attention mask that keeps only the prompt tokens similar to the
            # condition tokens, then fuse image and text features with the cross model.
            sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
            sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
            sim_text_condition = sim_text_condition / sim_text_condition.max()
            mask = torch.where(sim_text_condition > 0.3, 0, float('-inf'))
            mask = mask.repeat(1, image_f.shape[1], 1)
            image_features = self.model.cross_model(image_f, text_f, mask.half())[:, 0, :]

        # Scaled cosine similarity between pooled text features and fused image features.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        image_score = self.model.logit_scale.exp() * text_features @ image_features.T

        return image_score[0].cpu().item()

    @torch.no_grad()
    def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
        """Score the images against the prompt.

        Args:
            images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
            prompt (str): The prompt text.

        Returns:
            List[float]: List of reward scores for the images.
        """
        if isinstance(images, (str, Image.Image)):
            # Single image: load from disk if a path was given.
            if isinstance(images, str):
                image = self.image_processor(Image.open(images), return_tensors="pt")["pixel_values"].to(self.device)
            else:
                image = self.image_processor(images, return_tensors="pt")["pixel_values"].to(self.device)
            return [self._calculate_score(image, prompt)]
        elif isinstance(images, list):
            # Multiple images: score each one against the same prompt.
            scores = []
            for one_image in images:
                if isinstance(one_image, str):
                    image = self.image_processor(Image.open(one_image), return_tensors="pt")["pixel_values"].to(self.device)
                elif isinstance(one_image, Image.Image):
                    image = self.image_processor(one_image, return_tensors="pt")["pixel_values"].to(self.device)
                else:
                    raise TypeError("Unsupported item in `images`: expected str or PIL.Image.Image.")
                scores.append(self._calculate_score(image, prompt))
            return scores
        else:
            raise TypeError("Unsupported type for `images`: expected str, PIL.Image.Image, or a list of those.")
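

if __name__ == "__main__":
    # Usage sketch (illustrative only): assumes this module is run as part of its
    # package (e.g. `python -m <package>.<module>`) so the relative imports resolve,
    # that MODEL_PATHS points at a valid CLIP processor ("clip") and MPS checkpoint
    # ("mps"), and that "example.png" / "example_2.png" are hypothetical image files.
    # A CUDA device is recommended, since image features are cast to half precision.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    scorer = MPScore(device=device, condition="overall")

    # Score a single image, then a small batch, against the same prompt.
    single = scorer.score("example.png", "a watercolor painting of a fox in a meadow")
    batch = scorer.score(["example.png", "example_2.png"], "a watercolor painting of a fox in a meadow")
    print(single, batch)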