| """ | |
| Copyright (c) 2022, salesforce.com, inc. | |
| All rights reserved. | |
| SPDX-License-Identifier: BSD-3-Clause | |
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| """ | |
| import os | |
| import torch | |
| import torch.nn.functional as F | |
| from lavis.common.dist_utils import download_cached_file | |
| from lavis.common.registry import registry | |
| from lavis.common.utils import get_abs_path, is_url | |
| from lavis.models.base_model import MomentumDistilationMixin | |
| from lavis.models.blip_models.blip import BlipBase | |
| from lavis.models.blip_models.blip_outputs import BlipIntermediateOutput, BlipOutput | |
| from lavis.models.blip_models.nlvr_encoder import BertModel | |
| from lavis.models.vit import VisionTransformerEncoder, interpolate_pos_embed | |
| from torch import nn | |
| from transformers import BertConfig | |


@registry.register_model("blip_nlvr")
class BlipNLVR(BlipBase, MomentumDistilationMixin):
| """ | |
| Class for BLIP NLVR model. | |
| Supported model types: | |
| - base: model with pre-trained BLIP weights, used as initialization for fine-tuning. | |
| - nlvr: finetuned model on NLVR2 dataset. | |
| Usage: | |
| >>> from lavis.models import load_model | |
| >>> model = load_model("blip_nlvr", "nlvr") | |
| """ | |
| PRETRAINED_MODEL_CONFIG_DICT = { | |
| "nlvr": "configs/models/blip_nlvr.yaml", | |
| } | |

    def __init__(self, image_encoder, text_encoder, num_classes):
        super().__init__()

        self.tokenizer = self.init_tokenizer()
        self.visual_encoder = image_encoder
        self.text_encoder = text_encoder

        hidden_size = text_encoder.config.hidden_size
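        # A small two-layer MLP head that maps the multimodal [CLS]
        # representation from the text encoder to `num_classes` logits.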
        self.cls_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, samples, is_train=True):
        """
        Forward function for training and evaluation.

        Args:
            samples (dict): a dict of input samples, which contains the following keys:
                - image0 (torch.Tensor): input image 0, shape (batch_size, 3, H, W), default H=384, W=384.
                - image1 (torch.Tensor): input image 1, shape (batch_size, 3, H, W), default H=384, W=384.
                - text_input (list): list of strings, each string is a natural language sentence.
                - label (torch.LongTensor): ground truth label with shape (batch_size,).
            is_train (bool): whether the model is in training mode.
                If True, the model will return the loss;
                If False, the model will return the prediction.

        Examples:
            >>> import torch
            >>> from lavis.models import load_model
            >>> model = load_model("blip_nlvr", "nlvr")
            >>> samples = {
            ...     "image0": torch.randn(2, 3, 384, 384),
            ...     "image1": torch.randn(2, 3, 384, 384),
            ...     "text_input": ["there is a ferret in tall grass", "there are lips in one of the images"],
            ...     "label": torch.tensor([0, 1]),
            ... }
            >>> output = model(samples)
            >>> output.keys()
            odict_keys(['intermediate_output', 'loss'])
        """
        text = samples["text_input"]
        text = self.tokenizer(text, padding="longest", return_tensors="pt").to(
            self.device
        )
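        # Overwrite the [CLS] token with BLIP's [ENC] token so the BERT
        # encoder runs in multimodal (image-grounded) encoding mode.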
        text.input_ids[:, 0] = self.tokenizer.enc_token_id

        targets = samples["label"]

        image0 = samples["image0"]
        image1 = samples["image1"]
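        # Stack both images along the batch dimension so the ViT encodes
        # them in a single forward pass.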
        images = torch.cat([image0, image1], dim=0)

        image_embeds = self.visual_encoder.forward_features(images)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            self.device
        )
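        # Split the embeddings back into the two per-example image streams;
        # the NLVR text encoder cross-attends to each stream separately.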
        image0_embeds, image1_embeds = torch.split(image_embeds, targets.size(0))

        encoder_output = self.text_encoder(
            text.input_ids,
            attention_mask=text.attention_mask,
            encoder_hidden_states=[image0_embeds, image1_embeds],
            encoder_attention_mask=[
                image_atts[: image0_embeds.size(0)],
                image_atts[image0_embeds.size(0) :],
            ],
            return_dict=True,
        )

        prediction = self.cls_head(encoder_output.last_hidden_state[:, 0, :])
        if is_train:
            loss = F.cross_entropy(prediction, targets)
            return BlipOutput(
                loss=loss,
                intermediate_output=BlipIntermediateOutput(
                    image_embeds=torch.stack([image0_embeds, image1_embeds], dim=0),
                    encoder_output=encoder_output,
                ),
            )
        else:
            return {"predictions": prediction, "targets": targets}

    def predict(self, samples):
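        """
        Run the model in evaluation mode.

        Returns a dict with class logits under "predictions" and the
        ground truth labels under "targets".
        """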
        output = self.forward(samples, is_train=False)
        return output

    @classmethod
    def from_config(cls, cfg=None):
        image_encoder = VisionTransformerEncoder.from_config(cfg)

        # text encoder + multimodal encoder
        bert_config = BertConfig.from_json_file(get_abs_path(cfg["med_config_path"]))
        text_encoder = BertModel(config=bert_config, add_pooling_layer=False)

        num_classes = cfg.get("num_classes", 3)
        assert num_classes > 1, "Invalid number of classes provided, found {}".format(
            num_classes
        )

        model = cls(
            image_encoder=image_encoder,
            text_encoder=text_encoder,
            num_classes=num_classes,
        )

        model.load_checkpoint_from_config(cfg)

        return model

    def load_from_pretrained(self, url_or_filename):
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]
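        # Resize the checkpoint's positional embeddings in case the current
        # visual encoder uses a different image resolution / patch grid.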
| state_dict["visual_encoder.pos_embed"] = interpolate_pos_embed( | |
| state_dict["visual_encoder.pos_embed"], self.visual_encoder | |
| ) | |
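        # The NLVR2 text encoder has two cross-attention branches (one per
        # image stream), so each pretrained cross-attention weight is copied
        # into both the "*0" and "*1" branches.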
        for key in list(state_dict.keys()):
            if "crossattention.self." in key:
                new_key0 = key.replace("self", "self0")
                new_key1 = key.replace("self", "self1")
                state_dict[new_key0] = state_dict[key]
                state_dict[new_key1] = state_dict[key]
            elif "crossattention.output.dense." in key:
                new_key0 = key.replace("dense", "dense0")
                new_key1 = key.replace("dense", "dense1")
                state_dict[new_key0] = state_dict[key]
                state_dict[new_key1] = state_dict[key]

        msg = self.load_state_dict(state_dict, strict=False)

        print("load checkpoint from %s" % url_or_filename)
        print(f"missing keys {msg.missing_keys}")
        return msg
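

# Example usage: a minimal sketch, assuming LAVIS is installed and the "nlvr"
# checkpoint is available via `load_model` as in the docstrings above. Random
# tensors stand in for real, preprocessed images.
if __name__ == "__main__":
    from lavis.models import load_model

    model = load_model("blip_nlvr", "nlvr")
    model.eval()

    samples = {
        "image0": torch.randn(1, 3, 384, 384),
        "image1": torch.randn(1, 3, 384, 384),
        "text_input": ["there is a ferret in tall grass"],
        "label": torch.tensor([0]),
    }
    out = model.predict(samples)
    # out["predictions"] holds the class logits; argmax gives the predicted label.
    print(out["predictions"].argmax(dim=-1))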