import abc
import logging
from typing import Callable, List, Optional

import torch
import torch.nn.functional as F
from adapters import AutoAdapterModel
from pie_modules.models import SequencePairSimilarityModelWithPooler
from pie_modules.models.sequence_classification_with_pooler import (
    InputType,
    OutputType,
    SequenceClassificationModelWithPooler,
    SequenceClassificationModelWithPoolerBase,
    TargetType,
    separate_arguments_by_prefix,
)
from pytorch_ie import PyTorchIEModel
from torch import FloatTensor, Tensor
from transformers import AutoConfig, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

logger = logging.getLogger(__name__)


class SequenceClassificationModelWithPoolerAndAdapterBase(
    SequenceClassificationModelWithPoolerBase, abc.ABC
):
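    """Extends SequenceClassificationModelWithPoolerBase so that a pre-trained
    adapter (via the `adapters` library) can be loaded on top of the base model.
    """
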
    def __init__(self, adapter_name_or_path: Optional[str] = None, **kwargs):
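        # Set before super().__init__(), which builds the base model via setup_base_model().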
        self.adapter_name_or_path = adapter_name_or_path
        super().__init__(**kwargs)

    def setup_base_model(self) -> PreTrainedModel:
        if self.adapter_name_or_path is None:
            return super().setup_base_model()
        else:
            config = AutoConfig.from_pretrained(self.model_name_or_path)
            if self.is_from_pretrained:
                model = AutoAdapterModel.from_config(config=config)
            else:
                model = AutoAdapterModel.from_pretrained(self.model_name_or_path, config=config)
            # Load the adapter in any case: it does not appear to be saved with the
            # model state or restored from a serialized state.
            logger.info(f"load adapter: {self.adapter_name_or_path}")
            model.load_adapter(self.adapter_name_or_path, source="hf", set_active=True)
            return model


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithPoolerAndAdapter(
    SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
):
    pass


@PyTorchIEModel.register()
class SequenceClassificationModelWithPoolerAndAdapter(
    SequenceClassificationModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
):
    pass


def get_max_cosine_sim(embeddings: Tensor, embeddings_pair: Tensor) -> Tensor:
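    """Return the maximum pairwise cosine similarity between two sets of embeddings.

    A doctest-style sketch (the shapes are illustrative assumptions):

        >>> a = torch.eye(2, 4)  # two 4-dim embeddings
        >>> b = torch.eye(3, 4)  # three 4-dim embeddings
        >>> get_max_cosine_sim(a, b)
        tensor(1.)
    """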
    # Normalize the embeddings
    embeddings_normalized = F.normalize(embeddings, p=2, dim=1)  # Shape: (n, k)
    embeddings_normalized_pair = F.normalize(embeddings_pair, p=2, dim=1)  # Shape: (m, k)

    # Compute the cosine similarity matrix
    cosine_sim = torch.mm(embeddings_normalized, embeddings_normalized_pair.T)  # Shape: (n, m)

    # Get the overall maximum cosine similarity value
    max_cosine_sim = torch.max(cosine_sim)  # This will return a scalar
    return max_cosine_sim


def get_span_embeddings(
    embeddings: FloatTensor, start_indices: Tensor, end_indices: Tensor
) -> List[FloatTensor]:
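    """Slice out, for each batch item, the token embeddings of its first span.

    Note: only the first start/end index per batch item is used, i.e. with
    embeddings of shape (batch, seq_len, hidden) and start/end indices of shape
    (batch, num_spans), the result is a list of (span_len_i, hidden) tensors.
    """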
    result = []
    for embeds, starts, ends in zip(embeddings, start_indices, end_indices):
        span_embeds = embeds[starts[0] : ends[0]]
        result.append(span_embeds)
    return result


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithMaxCosineSim(SequencePairSimilarityModelWithPooler):
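    """Scores each pair by the maximum cosine similarity between any token
    embedding of the first span and any token embedding of the second span,
    instead of pooling and classifying.
    """
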
    def get_pooled_output(self, model_inputs, pooler_inputs) -> List[FloatTensor]:
        output = self.model(**model_inputs)
        hidden_state = output.last_hidden_state
        # Unlike the base implementation, skip the pooler and dropout and return
        # the raw token embeddings of the spans instead.
        span_embeds = get_span_embeddings(hidden_state, **pooler_inputs)
        return span_embeds

    def forward(
        self,
        inputs: InputType,
        targets: Optional[TargetType] = None,
        return_hidden_states: bool = False,
    ) -> OutputType:
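        """Embed both encodings, extract the span token embeddings, and score
        each pair by its maximum cosine similarity.
        """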
        sanitized_inputs = separate_arguments_by_prefix(
            # Note that the order of the prefixes is important: since one is a prefix
            # of the other, we must start with the longer one!
            arguments=inputs,
            prefixes=["pooler_pair_", "pooler_"],
        )

        span_embeddings = self.get_pooled_output(
            model_inputs=sanitized_inputs["remaining"]["encoding"],
            pooler_inputs=sanitized_inputs["pooler_"],
        )
        span_embeddings_pair = self.get_pooled_output(
            model_inputs=sanitized_inputs["remaining"]["encoding_pair"],
            pooler_inputs=sanitized_inputs["pooler_pair_"],
        )

        logits_list = [
            get_max_cosine_sim(span_embeds, span_embeds_pair)
            for span_embeds, span_embeds_pair in zip(span_embeddings, span_embeddings_pair)
        ]
        logits = torch.stack(logits_list)

        result = {"logits": logits}
        if targets is not None:
            labels = targets["scores"]
            loss = self.loss_fct(logits, labels)
            result["loss"] = loss
        if return_hidden_states:
            raise NotImplementedError("return_hidden_states is not yet implemented")

        return SequenceClassifierOutput(**result)


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithMaxCosineSimAndAdapter(
    SequencePairSimilarityModelWithMaxCosineSim, SequencePairSimilarityModelWithPoolerAndAdapter
):
    pass


@PyTorchIEModel.register()
class SequencePairSimilarityModelDummy(SequencePairSimilarityModelWithPooler):
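    """Baseline model that produces dummy similarity logits: uniformly random in
    [0, 1) (method="random") or all zeros (method="zero"). It is meant for
    evaluation baselines only and does not support training.
    """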

    def __init__(
        self,
        method: str = "random",
        random_seed: Optional[int] = None,
        **kwargs,
    ):
        self.method = method
        self.random_seed = random_seed
        super().__init__(**kwargs)

    def setup_classifier(
        self, pooler_output_dim: int
    ) -> Callable[[torch.FloatTensor, torch.FloatTensor], torch.FloatTensor]:
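        # Instead of a trained classification head, return a function that produces
        # dummy logits, one per pair in the batch.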
        if self.method == "random":
            generator = torch.Generator(device=self.device)
            if self.random_seed is not None:
                generator = generator.manual_seed(self.random_seed)

            def binary_classify_random(
                inputs: torch.FloatTensor,
                inputs_pair: torch.FloatTensor,
            ) -> torch.FloatTensor:
                """Randomly classifies pairs of inputs as similar or not similar."""
                # Generate random logits in the range [0, 1)
                logits = torch.rand(inputs.size(0), device=self.device, generator=generator)
                return logits

            return binary_classify_random
        elif self.method == "zero":

            def binary_classify_zero(
                inputs: torch.FloatTensor,
                inputs_pair: torch.FloatTensor,
            ) -> torch.FloatTensor:
                """Classifies all pairs of inputs as not similar (logit = 0)."""
                # Return a tensor of zeros with the same batch size
                logits = torch.zeros(inputs.size(0), device=self.device)
                return logits

            return binary_classify_zero
        else:
            raise ValueError(
                f"Unknown method: {self.method}. Supported methods are 'random' and 'zero'."
            )

    def setup_loss_fct(self) -> Callable:
        def loss_fct(logits: FloatTensor, labels: FloatTensor) -> FloatTensor:
            raise NotImplementedError(
                "The dummy model does not implement a loss function, as it is not meant for training."
            )

        return loss_fct

    def get_pooled_output(self, model_inputs, pooler_inputs) -> torch.FloatTensor:
        # Just return a tensor of zeros in the shape of the batch size
        # so that the classifier can construct dummy logits in the correct shape.
        bs = pooler_inputs["start_indices"].size(0)
        return torch.zeros(bs, device=self.device)
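

# Minimal usage sketch (hypothetical: the backbone name and the construction of
# `inputs` are assumptions, not taken from this file; `inputs` would normally be
# produced by the corresponding taskmodule):
#
#   model = SequencePairSimilarityModelDummy(
#       model_name_or_path="bert-base-uncased", method="zero"
#   )
#   output = model(inputs)
#   scores = torch.sigmoid(output.logits)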