# M2-Encoder-1B / modeling_m2_encoder.py
# Uploaded by malusama — safetensors export (revision ea0524d, verified).
import os
import sys
import importlib
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from .configuration_m2_encoder import M2EncoderConfig
@dataclass
class M2EncoderOutput(ModelOutput):
    """CLIP-style output container returned by `M2EncoderModel.forward`."""

    # Contrastive loss slot; never populated by `M2EncoderModel.forward`
    # in this file (kept for ModelOutput/API compatibility).
    loss: Optional[torch.FloatTensor] = None
    # Text embeddings (`cls_vlffn_feats` from the VLMo text tower); None
    # when no `input_ids` were given.
    text_embeds: Optional[torch.FloatTensor] = None
    # Image embeddings (`cls_vlffn_feats` from the VLMo image tower); None
    # when no `pixel_values` were given.
    image_embeds: Optional[torch.FloatTensor] = None
    # logit_scale * image_embeds @ text_embeds.T — only set when both
    # modalities were provided; `logits_per_text` is its transpose.
    logits_per_image: Optional[torch.FloatTensor] = None
    logits_per_text: Optional[torch.FloatTensor] = None
class M2EncoderModel(PreTrainedModel):
    """Hugging Face wrapper around the VLMo-based M2-Encoder model.

    The network implementation lives in a `vlmo` package shipped inside the
    checkpoint directory; it is imported dynamically at construction time,
    which is why instances must be created through `from_pretrained` (it is
    the only place the checkpoint directory gets resolved and recorded on
    the config as `_model_dir`).
    """

    config_class = M2EncoderConfig
    base_model_prefix = "m2_encoder"
    main_input_name = "pixel_values"

    def __init__(self, config: M2EncoderConfig):
        """Build the wrapped VLMo model from a resolved checkpoint directory.

        Raises:
            ValueError: if `config` lacks `_model_dir` (i.e. the model was
                constructed directly instead of via `from_pretrained`).
        """
        super().__init__(config)
        # `_model_dir` is injected by `from_pretrained` below and points at
        # the local snapshot containing both the `vlmo` package and weights.
        model_dir = getattr(config, "_model_dir", None)
        if model_dir is None:
            raise ValueError(
                "M2EncoderConfig is missing `_model_dir`. Use "
                "`M2EncoderModel.from_pretrained(...)` so the checkpoint path can be resolved."
            )
        # Make the bundled `vlmo` package importable.
        if model_dir not in sys.path:
            sys.path.insert(0, model_dir)
        vlmo_default_config = importlib.import_module("vlmo.config").config
        VLMo = importlib.import_module("vlmo.modules").VLMo
        vlmo_config = vlmo_default_config()
        vlmo_config.update(config.to_vlmo_overrides(model_dir))
        load_path = vlmo_config["load_path"]
        use_safetensors = load_path.endswith(".safetensors")
        if use_safetensors:
            # VLMo's own loader is bypassed for safetensors checkpoints:
            # clear the path so VLMo skips loading, then load the weights
            # ourselves after construction.
            vlmo_config["load_path"] = ""
        if vlmo_config["flash_attn"]:
            # Optional flash-attention monkey-patch shipped with the package.
            patch_torch_scale_with_flash_attn = importlib.import_module(
                "vlmo.utils.patch_utils"
            ).patch_torch_scale_with_flash_attn
            patch_torch_scale_with_flash_attn()
        self.model = VLMo(vlmo_config)
        if use_safetensors:
            state_dict = load_file(load_path)
            # NOTE(review): strict=False silently tolerates missing or
            # unexpected keys — confirm the export matches the architecture.
            self.model.load_state_dict(state_dict, strict=False)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        *model_args,
        config: Optional[M2EncoderConfig] = None,
        **kwargs,
    ):
        """Resolve a local dir or Hub repo and construct the model.

        Accepts an extra keyword `m2_checkpoint_name` to override the
        checkpoint filename (defaults to `config.model_file`).
        """
        model_dir = pretrained_model_name_or_path
        if not os.path.isdir(model_dir):
            model_dir = snapshot_download(repo_id=pretrained_model_name_or_path)
        # Pop our private kwarg BEFORE forwarding kwargs to the config
        # loader; otherwise `m2_checkpoint_name` leaks into
        # `M2EncoderConfig.from_pretrained` and gets absorbed as a spurious
        # config attribute.
        checkpoint_name = kwargs.pop("m2_checkpoint_name", None)
        if config is None:
            config = M2EncoderConfig.from_pretrained(model_dir, **kwargs)
        if checkpoint_name is None:
            checkpoint_name = config.model_file
        checkpoint_path = os.path.join(model_dir, checkpoint_name)
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(
                f"Missing M2-Encoder checkpoint: {checkpoint_path}"
            )
        # Record the resolved directory so __init__ can import `vlmo` and
        # locate the weights.
        config._model_dir = model_dir
        return cls(config, *model_args)

    def get_text_features(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
    ) -> torch.FloatTensor:
        """Return text embeddings (`cls_vlffn_feats`) for a tokenized batch."""
        outputs = self.model.infer_text(
            {
                "text_ids": input_ids,
                "text_masks": attention_mask,
                "text_labels": None,
            }
        )
        return outputs["cls_vlffn_feats"]

    def get_image_features(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
        """Return image embeddings (`cls_vlffn_feats`) for a pixel batch."""
        # VLMo's infer_image expects the image wrapped in a list.
        outputs = self.model.infer_image({"image": [pixel_values]})
        return outputs["cls_vlffn_feats"]

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = True,
        **kwargs,
    ) -> Union[M2EncoderOutput, Tuple[torch.FloatTensor, ...]]:
        """Encode text and/or images; compute CLIP logits when both given.

        Returns an `M2EncoderOutput` (or, with `return_dict=False`, a tuple
        of the non-None values among text_embeds, image_embeds,
        logits_per_image, logits_per_text). `loss` is never populated here.
        """
        text_embeds = None
        image_embeds = None
        if input_ids is not None:
            if attention_mask is None:
                # Default to attending over every token.
                attention_mask = torch.ones_like(input_ids)
            text_embeds = self.get_text_features(
                input_ids=input_ids, attention_mask=attention_mask
            )
        if pixel_values is not None:
            image_embeds = self.get_image_features(pixel_values=pixel_values)
        logits_per_image = None
        logits_per_text = None
        if image_embeds is not None and text_embeds is not None:
            logit_scale = self.model.logit_scale.exp()
            logits_per_image = logit_scale * image_embeds @ text_embeds.t()
            logits_per_text = logits_per_image.t()
        if not return_dict:
            return tuple(
                value
                for value in (
                    text_embeds,
                    image_embeds,
                    logits_per_image,
                    logits_per_text,
                )
                if value is not None
            )
        return M2EncoderOutput(
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
        )