import abc
import logging
from typing import Callable, List, Optional
import torch
import torch.nn.functional as F
from adapters import AutoAdapterModel
from pie_modules.models import SequencePairSimilarityModelWithPooler
from pie_modules.models.sequence_classification_with_pooler import (
InputType,
OutputType,
SequenceClassificationModelWithPooler,
SequenceClassificationModelWithPoolerBase,
TargetType,
separate_arguments_by_prefix,
)
from pytorch_ie import PyTorchIEModel
from torch import FloatTensor, Tensor
from transformers import AutoConfig, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

logger = logging.getLogger(__name__)


class SequenceClassificationModelWithPoolerAndAdapterBase(
    SequenceClassificationModelWithPoolerBase, abc.ABC
):
    """Base model that can optionally load an adapter (via the adapters library)
    into its backbone model, see setup_base_model()."""

    def __init__(self, adapter_name_or_path: Optional[str] = None, **kwargs):
self.adapter_name_or_path = adapter_name_or_path
        super().__init__(**kwargs)

    def setup_base_model(self) -> PreTrainedModel:
if self.adapter_name_or_path is None:
return super().setup_base_model()
else:
            config = AutoConfig.from_pretrained(self.model_name_or_path)
            if self.is_from_pretrained:
                # the weights get restored from a serialized state afterwards, so it is
                # sufficient to construct the architecture from the config
                model = AutoAdapterModel.from_config(config=config)
            else:
                model = AutoAdapterModel.from_pretrained(self.model_name_or_path, config=config)
            # Load the adapter in both cases: the adapter weights do not seem to be saved
            # to, or restored from, the serialized model state, so they need to be loaded
            # explicitly.
            logger.info(f"load adapter: {self.adapter_name_or_path}")
model.load_adapter(self.adapter_name_or_path, source="hf", set_active=True)
            return model


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithPoolerAndAdapter(
SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
):
    pass


@PyTorchIEModel.register()
class SequenceClassificationModelWithPoolerAndAdapter(
SequenceClassificationModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
):
pass
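# Usage sketch (illustrative, the identifiers are placeholders): the adapter variants
# are constructed like their parent models, with an additional adapter_name_or_path:
#
#   model = SequencePairSimilarityModelWithPoolerAndAdapter(
#       model_name_or_path="bert-base-cased",
#       adapter_name_or_path="AdapterHub/bert-base-uncased-pf-mrpc",
#   )
#
# The mixin pattern works because, in the method resolution order, the adapter base
# class comes after the respective parent model but before the shared
# SequenceClassificationModelWithPoolerBase, so only setup_base_model() is overridden
# and everything else behaves as in the parent model.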
def get_max_cosine_sim(embeddings: Tensor, embeddings_pair: Tensor) -> Tensor:
    """Return the maximum pairwise cosine similarity between two sets of embeddings."""
# Normalize the embeddings
embeddings_normalized = F.normalize(embeddings, p=2, dim=1) # Shape: (n, k)
embeddings_normalized_pair = F.normalize(embeddings_pair, p=2, dim=1) # Shape: (m, k)
# Compute the cosine similarity matrix
cosine_sim = torch.mm(embeddings_normalized, embeddings_normalized_pair.T) # Shape: (n, m)
# Get the overall maximum cosine similarity value
max_cosine_sim = torch.max(cosine_sim) # This will return a scalar
return max_cosine_sim
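# Worked example (sketch): for 3 query and 5 candidate token embeddings with
# hidden size 8, the 3 x 5 pairwise similarities are reduced to a single scalar:
#
#   embeds = torch.randn(3, 8)
#   embeds_pair = torch.randn(5, 8)
#   sim = get_max_cosine_sim(embeds, embeds_pair)  # scalar tensor in [-1.0, 1.0]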
def get_span_embeddings(
    embeddings: FloatTensor, start_indices: Tensor, end_indices: Tensor
) -> List[FloatTensor]:
    """Extract the token embeddings of one span per batch entry.

    Note that only the first start/end index pair of each entry is used.
    """
    result = []
    for embeds, starts, ends in zip(embeddings, start_indices, end_indices):
        span_embeds = embeds[starts[0] : ends[0]]
        result.append(span_embeds)
    return result
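# Example (sketch): for a batch of two sequences of length 10 with hidden size 4,
# this returns a list of two span embedding tensors:
#
#   hidden = torch.randn(2, 10, 4)
#   starts = torch.tensor([[1], [3]])
#   ends = torch.tensor([[4], [6]])
#   spans = get_span_embeddings(hidden, starts, ends)  # shapes: (3, 4) and (3, 4)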
@PyTorchIEModel.register()
class SequencePairSimilarityModelWithMaxCosineSim(SequencePairSimilarityModelWithPooler):
    """Scores a text pair with the maximum cosine similarity between the token
    embeddings of the respective spans."""

    def get_pooled_output(self, model_inputs, pooler_inputs) -> List[FloatTensor]:
output = self.model(**model_inputs)
hidden_state = output.last_hidden_state
        # In contrast to the parent class, the hidden state is not pooled (and no
        # dropout is applied):
        # pooled_output = self.pooler(hidden_state, **pooler_inputs)
        # pooled_output = self.dropout(pooled_output)
        # Instead, the raw token embeddings of each span are returned.
        span_embeds = get_span_embeddings(hidden_state, **pooler_inputs)
        return span_embeds

    def forward(
self,
inputs: InputType,
targets: Optional[TargetType] = None,
return_hidden_states: bool = False,
) -> OutputType:
sanitized_inputs = separate_arguments_by_prefix(
            # Note that the order of the prefixes matters: "pooler_" is itself a prefix
            # of "pooler_pair_", so the longer one needs to be matched first!
arguments=inputs,
prefixes=["pooler_pair_", "pooler_"],
)
span_embeddings = self.get_pooled_output(
model_inputs=sanitized_inputs["remaining"]["encoding"],
pooler_inputs=sanitized_inputs["pooler_"],
)
span_embeddings_pair = self.get_pooled_output(
model_inputs=sanitized_inputs["remaining"]["encoding_pair"],
pooler_inputs=sanitized_inputs["pooler_pair_"],
)
logits_list = [
get_max_cosine_sim(span_embeds, span_embeds_pair)
for span_embeds, span_embeds_pair in zip(span_embeddings, span_embeddings_pair)
]
logits = torch.stack(logits_list)
result = {"logits": logits}
if targets is not None:
labels = targets["scores"]
loss = self.loss_fct(logits, labels)
result["loss"] = loss
if return_hidden_states:
raise NotImplementedError("return_hidden_states is not yet implemented")
return SequenceClassifierOutput(**result)
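# Expected input layout (an assumption inferred from the forward pass above; the
# exact key names depend on the taskmodule that produces the batch):
#
#   inputs = {
#       "encoding": {...},            # model inputs for the first text
#       "encoding_pair": {...},       # model inputs for the second text
#       "pooler_start_indices": ...,  # becomes pooler_inputs["start_indices"]
#       "pooler_end_indices": ...,    # becomes pooler_inputs["end_indices"]
#       "pooler_pair_start_indices": ...,
#       "pooler_pair_end_indices": ...,
#   }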
@PyTorchIEModel.register()
class SequencePairSimilarityModelWithMaxCosineSimAndAdapter(
SequencePairSimilarityModelWithMaxCosineSim, SequencePairSimilarityModelWithPoolerAndAdapter
):
    pass


@PyTorchIEModel.register()
class SequencePairSimilarityModelDummy(SequencePairSimilarityModelWithPooler):
    """Dummy similarity model that predicts random logits or all zeros.

    It is meant as an untrained baseline and does not support training.
    """

    def __init__(
self,
method: str = "random",
random_seed: Optional[int] = None,
**kwargs,
):
self.method = method
self.random_seed = random_seed
        super().__init__(**kwargs)

    def setup_classifier(
self, pooler_output_dim: int
) -> Callable[[torch.FloatTensor, torch.FloatTensor], torch.FloatTensor]:
if self.method == "random":
generator = torch.Generator(device=self.device)
if self.random_seed is not None:
                generator = generator.manual_seed(self.random_seed)

            def binary_classify_random(
                inputs: torch.FloatTensor,
                inputs_pair: torch.FloatTensor,
            ) -> torch.FloatTensor:
                """Randomly classifies pairs of inputs as similar or not similar."""
                # generate one random logit per batch element, uniform in [0, 1)
                logits = torch.rand(inputs.size(0), device=self.device, generator=generator)
                return logits

            return binary_classify_random
elif self.method == "zero":
def binary_classify_zero(
inputs: torch.FloatTensor,
inputs_pair: torch.FloatTensor,
) -> Callable[[torch.FloatTensor, torch.FloatTensor], torch.FloatTensor]:
"""Classifies pairs of inputs as not similar (logit = 0)."""
# Return a tensor of zeros with the same batch size
logits = torch.zeros(inputs.size(0), device=self.device)
return logits
return binary_classify_zero
else:
raise ValueError(
f"Unknown method: {self.method}. Supported methods are 'random' and 'zero'."
            )

    def setup_loss_fct(self) -> Callable:
        def loss_fct(logits: FloatTensor, labels: FloatTensor) -> FloatTensor:
            raise NotImplementedError(
                "The dummy model does not define a loss function because it is not meant "
                "to be trained."
            )

        return loss_fct

    def get_pooled_output(self, model_inputs, pooler_inputs) -> torch.FloatTensor:
        # Return a zero tensor with one entry per batch element so that the
        # classifier can construct dummy logits of the correct shape.
bs = pooler_inputs["start_indices"].size(0)
return torch.zeros(bs, device=self.device)
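# Usage sketch: the dummy model can serve as an untrained baseline for evaluation.
# Any further keyword arguments are forwarded to the parent model; the identifier
# below is a placeholder.
#
#   model = SequencePairSimilarityModelDummy(
#       model_name_or_path="bert-base-cased",
#       method="random",
#       random_seed=42,
#   )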