"""Hugging Face `transformers` wrapper around the M2-Encoder (VLMo) bi-encoder.

The underlying VLMo implementation lives inside the downloaded checkpoint
directory (a `vlmo` package), which is imported dynamically at construction
time. This module only adapts it to the `PreTrainedModel` interface.
"""

import importlib
import os
import sys
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput

from .configuration_m2_encoder import M2EncoderConfig


@dataclass
class M2EncoderOutput(ModelOutput):
    # `loss` is part of the output contract but is never populated by
    # `M2EncoderModel.forward` below; kept for API compatibility.
    loss: Optional[torch.FloatTensor] = None
    text_embeds: Optional[torch.FloatTensor] = None
    image_embeds: Optional[torch.FloatTensor] = None
    logits_per_image: Optional[torch.FloatTensor] = None
    logits_per_text: Optional[torch.FloatTensor] = None


class M2EncoderModel(PreTrainedModel):
    """CLIP-style text/image bi-encoder backed by a dynamically imported VLMo.

    Use :meth:`from_pretrained` rather than the constructor directly: the
    config must carry ``_model_dir`` so the VLMo package and checkpoint
    inside the snapshot directory can be resolved.
    """

    config_class = M2EncoderConfig
    base_model_prefix = "m2_encoder"
    main_input_name = "pixel_values"

    def __init__(self, config: M2EncoderConfig):
        super().__init__(config)

        model_dir = getattr(config, "_model_dir", None)
        if model_dir is None:
            raise ValueError(
                "M2EncoderConfig is missing `_model_dir`. Use "
                "`M2EncoderModel.from_pretrained(...)` so the checkpoint path can be resolved."
            )

        # The `vlmo` package ships inside the checkpoint directory, so it must
        # be importable. NOTE(review): this mutates process-global state and
        # Python caches modules — loading two different model dirs in one
        # process would reuse the first `vlmo` package. Confirm acceptable.
        if model_dir not in sys.path:
            sys.path.insert(0, model_dir)

        vlmo_default_config = importlib.import_module("vlmo.config").config
        VLMo = importlib.import_module("vlmo.modules").VLMo

        # Start from VLMo's defaults, then overlay this config's settings
        # (including the checkpoint `load_path`).
        vlmo_config = vlmo_default_config()
        vlmo_config.update(config.to_vlmo_overrides(model_dir))

        load_path = vlmo_config["load_path"]
        use_safetensors = load_path.endswith(".safetensors")
        if use_safetensors:
            # VLMo's own loader only understands torch checkpoints; blank the
            # path so it skips loading, and load the safetensors file manually
            # after construction.
            vlmo_config["load_path"] = ""

        if vlmo_config["flash_attn"]:
            patch_torch_scale_with_flash_attn = importlib.import_module(
                "vlmo.utils.patch_utils"
            ).patch_torch_scale_with_flash_attn
            patch_torch_scale_with_flash_attn()

        self.model = VLMo(vlmo_config)

        if use_safetensors:
            state_dict = load_file(load_path)
            # NOTE(review): strict=False silently ignores missing/unexpected
            # keys — a wrong checkpoint would load without any warning.
            self.model.load_state_dict(state_dict, strict=False)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        *model_args,
        config: Optional[M2EncoderConfig] = None,
        **kwargs,
    ):
        """Resolve a local directory (or download the Hub snapshot), validate
        the checkpoint file, and construct the model.

        Supports ``m2_checkpoint_name=...`` in ``kwargs`` to override the
        checkpoint filename from ``config.model_file``.
        """
        model_dir = pretrained_model_name_or_path
        if not os.path.isdir(model_dir):
            model_dir = snapshot_download(repo_id=pretrained_model_name_or_path)

        # Pop BEFORE building the config: otherwise the stray key is forwarded
        # to `M2EncoderConfig.from_pretrained`, which would attach it to the
        # config as an unexpected attribute.
        checkpoint_name = kwargs.pop("m2_checkpoint_name", None)

        if config is None:
            config = M2EncoderConfig.from_pretrained(model_dir, **kwargs)

        if checkpoint_name is None:
            checkpoint_name = config.model_file
        else:
            # Make the override effective: the checkpoint actually loaded comes
            # from `config.to_vlmo_overrides(...)` in `__init__`, so the custom
            # name must be recorded on the config, not just validated here.
            # NOTE(review): assumes `to_vlmo_overrides` derives the load path
            # from `model_file` — confirm against configuration_m2_encoder.
            config.model_file = checkpoint_name

        checkpoint_path = os.path.join(model_dir, checkpoint_name)
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(
                f"Missing M2-Encoder checkpoint: {checkpoint_path}"
            )

        config._model_dir = model_dir
        return cls(config, *model_args)

    def get_text_features(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
    ) -> torch.FloatTensor:
        """Return the text-side embedding (VLMo's ``cls_vlffn_feats``)."""
        outputs = self.model.infer_text(
            {
                "text_ids": input_ids,
                "text_masks": attention_mask,
                "text_labels": None,
            }
        )
        return outputs["cls_vlffn_feats"]

    def get_image_features(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
        """Return the image-side embedding (VLMo's ``cls_vlffn_feats``).

        VLMo expects the image tensor wrapped in a single-element list.
        """
        outputs = self.model.infer_image({"image": [pixel_values]})
        return outputs["cls_vlffn_feats"]

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = True,
        **kwargs,
    ) -> Union[M2EncoderOutput, Tuple[torch.FloatTensor, ...]]:
        """Encode text and/or images; compute contrastive logits when both
        modalities are present.

        Returns an :class:`M2EncoderOutput` (default) or, when
        ``return_dict=False``, a tuple of the non-``None`` values in the order
        (text_embeds, image_embeds, logits_per_image, logits_per_text).
        """
        # `None` means "use the default" per the Optional[bool] = True
        # declaration; previously it silently fell into the tuple branch.
        if return_dict is None:
            return_dict = True

        text_embeds = None
        image_embeds = None

        if input_ids is not None:
            if attention_mask is None:
                # No mask supplied: attend to every token.
                attention_mask = torch.ones_like(input_ids)
            text_embeds = self.get_text_features(
                input_ids=input_ids, attention_mask=attention_mask
            )

        if pixel_values is not None:
            image_embeds = self.get_image_features(pixel_values=pixel_values)

        logits_per_image = None
        logits_per_text = None
        if image_embeds is not None and text_embeds is not None:
            # Temperature-scaled similarity; `logit_scale` is stored as a log
            # value on the VLMo model, hence the exp().
            logit_scale = self.model.logit_scale.exp()
            logits_per_image = logit_scale * image_embeds @ text_embeds.t()
            logits_per_text = logits_per_image.t()

        if not return_dict:
            return tuple(
                value
                for value in (
                    text_embeds,
                    image_embeds,
                    logits_per_image,
                    logits_per_text,
                )
                if value is not None
            )

        return M2EncoderOutput(
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
        )