import abc
import logging
from typing import Callable, List, Optional

import torch
import torch.nn.functional as F
from adapters import AutoAdapterModel
from pie_modules.models import SequencePairSimilarityModelWithPooler
from pie_modules.models.sequence_classification_with_pooler import (
    InputType,
    OutputType,
    SequenceClassificationModelWithPooler,
    SequenceClassificationModelWithPoolerBase,
    TargetType,
    separate_arguments_by_prefix,
)
from pytorch_ie import PyTorchIEModel
from torch import FloatTensor, Tensor
from transformers import AutoConfig, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

logger = logging.getLogger(__name__)


class SequenceClassificationModelWithPoolerAndAdapterBase(
    SequenceClassificationModelWithPoolerBase, abc.ABC
):
    """Base class that optionally loads an adapter on top of the base transformer model."""

    def __init__(self, adapter_name_or_path: Optional[str] = None, **kwargs):
        self.adapter_name_or_path = adapter_name_or_path
        super().__init__(**kwargs)

    def setup_base_model(self) -> PreTrainedModel:
        if self.adapter_name_or_path is None:
            return super().setup_base_model()
        else:
            config = AutoConfig.from_pretrained(self.model_name_or_path)
            if self.is_from_pretrained:
                # The weights will be restored from the serialized state afterwards,
                # so we only need to construct the architecture here.
                model = AutoAdapterModel.from_config(config=config)
            else:
                model = AutoAdapterModel.from_pretrained(self.model_name_or_path, config=config)
            # Load the adapter in any case: it does not appear to be saved to or
            # restored from the serialized model state.
            logger.info(f"load adapter: {self.adapter_name_or_path}")
            model.load_adapter(self.adapter_name_or_path, source="hf", set_active=True)
            return model


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithPoolerAndAdapter(
    SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
):
    pass


@PyTorchIEModel.register()
class SequenceClassificationModelWithPoolerAndAdapter(
    SequenceClassificationModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
):
    pass


def get_max_cosine_sim(embeddings: Tensor, embeddings_pair: Tensor) -> Tensor:
    """Return the maximum pairwise cosine similarity between two sets of embeddings."""
    # Normalize the embeddings
    embeddings_normalized = F.normalize(embeddings, p=2, dim=1)  # Shape: (n, k)
    embeddings_normalized_pair = F.normalize(embeddings_pair, p=2, dim=1)  # Shape: (m, k)
    # Compute the cosine similarity matrix
    cosine_sim = torch.mm(embeddings_normalized, embeddings_normalized_pair.T)  # Shape: (n, m)
    # Get the overall maximum cosine similarity value
    max_cosine_sim = torch.max(cosine_sim)  # This returns a scalar
    return max_cosine_sim


def get_span_embeddings(
    embeddings: FloatTensor, start_indices: Tensor, end_indices: Tensor
) -> List[FloatTensor]:
    """Extract, for each example in the batch, the token embeddings of its span.

    Note that only the first start/end index per example is used.
    """
    result = []
    for embeds, starts, ends in zip(embeddings, start_indices, end_indices):
        span_embeds = embeds[starts[0] : ends[0]]
        result.append(span_embeds)
    return result
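
# A minimal, self-contained sketch of how the two helpers above compose; the
# tensor shapes (batch of 2, sequence length 10, hidden size 8) and the span
# indices are made up for illustration, and this function is not used by any
# model in this module.
def _demo_span_similarity() -> Tensor:
    hidden = torch.randn(2, 10, 8)  # (batch_size, seq_len, hidden_size)
    start_indices = torch.tensor([[1], [2]])  # one span start per example
    end_indices = torch.tensor([[4], [7]])  # one span end per example (exclusive)
    spans = get_span_embeddings(hidden, start_indices, end_indices)  # [(3, 8), (5, 8)]
    # Scalar in [-1, 1]: the best cosine match between any token embedding of
    # the first span and any token embedding of the second span.
    return get_max_cosine_sim(spans[0], spans[1])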
@PyTorchIEModel.register()
class SequencePairSimilarityModelWithMaxCosineSim(SequencePairSimilarityModelWithPooler):
    """Scores a pair of spans via the maximum cosine similarity of their token embeddings."""

    def get_pooled_output(self, model_inputs, pooler_inputs) -> List[FloatTensor]:
        output = self.model(**model_inputs)
        hidden_state = output.last_hidden_state
        # pooled_output = self.pooler(hidden_state, **pooler_inputs)
        # pooled_output = self.dropout(pooled_output)
        span_embeds = get_span_embeddings(hidden_state, **pooler_inputs)
        return span_embeds

    def forward(
        self,
        inputs: InputType,
        targets: Optional[TargetType] = None,
        return_hidden_states: bool = False,
    ) -> OutputType:
        sanitized_inputs = separate_arguments_by_prefix(
            # Note that the order of the prefixes is important because one is a prefix
            # of the other, so we need to start with the longer one!
            arguments=inputs,
            prefixes=["pooler_pair_", "pooler_"],
        )

        span_embeddings = self.get_pooled_output(
            model_inputs=sanitized_inputs["remaining"]["encoding"],
            pooler_inputs=sanitized_inputs["pooler_"],
        )
        span_embeddings_pair = self.get_pooled_output(
            model_inputs=sanitized_inputs["remaining"]["encoding_pair"],
            pooler_inputs=sanitized_inputs["pooler_pair_"],
        )

        logits_list = [
            get_max_cosine_sim(span_embeds, span_embeds_pair)
            for span_embeds, span_embeds_pair in zip(span_embeddings, span_embeddings_pair)
        ]
        logits = torch.stack(logits_list)

        result = {"logits": logits}
        if targets is not None:
            labels = targets["scores"]
            loss = self.loss_fct(logits, labels)
            result["loss"] = loss
        if return_hidden_states:
            raise NotImplementedError("return_hidden_states is not yet implemented")

        return SequenceClassifierOutput(**result)


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithMaxCosineSimAndAdapter(
    SequencePairSimilarityModelWithMaxCosineSim, SequencePairSimilarityModelWithPoolerAndAdapter
):
    pass


@PyTorchIEModel.register()
class SequencePairSimilarityModelDummy(SequencePairSimilarityModelWithPooler):
    """Baseline model that returns random or all-zero similarity logits; not trainable."""

    def __init__(
        self,
        method: str = "random",
        random_seed: Optional[int] = None,
        **kwargs,
    ):
        self.method = method
        self.random_seed = random_seed
        super().__init__(**kwargs)

    def setup_classifier(
        self, pooler_output_dim: int
    ) -> Callable[[torch.FloatTensor, torch.FloatTensor], torch.FloatTensor]:
        if self.method == "random":
            generator = torch.Generator(device=self.device)
            if self.random_seed is not None:
                generator = generator.manual_seed(self.random_seed)

            def binary_classify_random(
                inputs: torch.FloatTensor,
                inputs_pair: torch.FloatTensor,
            ) -> torch.FloatTensor:
                """Randomly classifies pairs of inputs as similar or not similar."""
                # Generate random logits in the range [0, 1)
                logits = torch.rand(inputs.size(0), device=self.device, generator=generator)
                return logits

            return binary_classify_random
        elif self.method == "zero":

            def binary_classify_zero(
                inputs: torch.FloatTensor,
                inputs_pair: torch.FloatTensor,
            ) -> torch.FloatTensor:
                """Classifies all pairs of inputs as not similar (logit = 0)."""
                # Return a tensor of zeros with the same batch size
                logits = torch.zeros(inputs.size(0), device=self.device)
                return logits

            return binary_classify_zero
        else:
            raise ValueError(
                f"Unknown method: {self.method}. Supported methods are 'random' and 'zero'."
            )

    def setup_loss_fct(self) -> Callable:
        def loss_fct(logits: FloatTensor, labels: FloatTensor) -> FloatTensor:
            raise NotImplementedError(
                "The dummy model does not support a loss function, as it is not used for training."
            )

        return loss_fct

    def get_pooled_output(self, model_inputs, pooler_inputs) -> torch.FloatTensor:
        # Just return a tensor of zeros in the shape of the batch size
        # so that the classifier can construct dummy logits of the correct shape.
        bs = pooler_inputs["start_indices"].size(0)
        return torch.zeros(bs, device=self.device)
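
# Usage sketch, under assumptions: "bert-base-uncased" is a placeholder
# checkpoint (the dummy model still instantiates the base transformer), and any
# further constructor arguments are assumed to be covered by the defaults of
# SequencePairSimilarityModelWithPooler.
if __name__ == "__main__":
    print(f"demo max cosine similarity: {_demo_span_similarity().item():.4f}")
    dummy_model = SequencePairSimilarityModelDummy(
        model_name_or_path="bert-base-uncased", method="zero"
    )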