File size: 12,139 Bytes
473c3a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
from __future__ import annotations

import logging
import os
import re
from typing import cast

import numpy as np
from huggingface_hub import model_info
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerFast

from distiller.model2vec.distill.inference import PCADimType, create_embeddings, post_process_embeddings
from distiller.model2vec.distill.utils import select_optimal_device
from distiller.model2vec.model import StaticModel
from distiller.model2vec.quantization import DType, quantize_embeddings
from distiller.model2vec.tokenizer import clean_and_create_vocabulary, replace_vocabulary, turn_tokens_into_ids

logger = logging.getLogger(__name__)


def distill_from_model(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerFast,
    vocabulary: list[str] | None = None,
    device: str | None = None,
    pca_dims: PCADimType = 256,
    apply_zipf: bool | None = None,
    sif_coefficient: float | None = 1e-4,
    token_remove_pattern: str | None = r"\[unused\d+\]",
    quantize_to: DType | str = DType.Float16,
    use_subword: bool | None = None,
) -> StaticModel:
    """
    Distill a staticmodel from a sentence transformer.

    This function creates a set of embeddings from a sentence transformer. It does this by doing either
    a forward pass for all subword tokens in the tokenizer, or by doing a forward pass for all tokens in a passed vocabulary.

    If you pass through a vocabulary, we create a custom word tokenizer for that vocabulary.
    If you don't pass a vocabulary, we use the model's tokenizer directly.

    :param model: The model to use.
    :param tokenizer: The tokenizer to use.
    :param vocabulary: The vocabulary to use. If this is None, we use the model's vocabulary.
    :param device: The device to use. The value is resolved through `select_optimal_device`.
    :param pca_dims: The number of components to use for PCA.
        If this is None, we don't apply PCA.
        If this is 'auto', we don't reduce dimensionality, but still apply PCA.
    :param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
        Zipf weighting is now controlled by the sif_coefficient parameter. If sif_coefficient is None, no weighting is applied.
    :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
        Should be a value > 0 and < 1.0. A value of 1e-4 is a good default.
    :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
        If the pattern is so general that it removes all tokens, we throw an error. If the pattern can't be compiled into a valid regex, we also throw an error.
    :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
    :param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
    :return: A StaticModel
    :raises ValueError: If `quantize_to` is not a valid DType, if `sif_coefficient` is not in the open interval (0, 1),
        if `token_remove_pattern` can't be compiled, or if the vocabulary is empty after preprocessing.

    """
    if use_subword is not None:
        logger.warning(
            "The `use_subword` parameter is deprecated and will be removed in the next release. It doesn't do anything."
        )
    quantize_to = DType(quantize_to)
    sif_coefficient, token_remove_regex = _validate_parameters(apply_zipf, sif_coefficient, token_remove_pattern)

    if vocabulary is None:
        vocabulary = []

    device = select_optimal_device(device)

    n_tokens_before = len(vocabulary)
    # Clean the vocabulary by removing duplicate tokens and tokens that are in the internal vocabulary.
    # This also yields the backend tokenizer that the final model is built on.
    all_tokens, backend_tokenizer = clean_and_create_vocabulary(
        tokenizer, vocabulary, token_remove_regex=token_remove_regex
    )
    n_tokens_after = sum(1 for token in all_tokens if not token.is_internal)
    if n_tokens_before:
        logger.info(
            f"Adding {n_tokens_after} tokens to the vocabulary. Removed {n_tokens_before - n_tokens_after} tokens during preprocessing."
        )

    if not all_tokens:
        msg = "The vocabulary is empty after preprocessing. Please check your token_remove_pattern."
        raise ValueError(msg)

    unk_token = cast("str | None", tokenizer.special_tokens_map.get("unk_token"))
    # NOTE(review): the fallback's id is looked up in `tokenizer.get_vocab()` below, so this relies on the
    # first cleaned token being part of the original vocabulary — confirm against clean_and_create_vocabulary.
    pad_token = _resolve_pad_token(
        cast("str | None", tokenizer.special_tokens_map.get("pad_token")),
        unk_token,
        fallback=all_tokens[0].form,
    )

    # Replace the vocabulary in the tokenizer with the new vocabulary.
    backend_tokenizer = replace_vocabulary(backend_tokenizer, all_tokens, unk_token=unk_token, pad_token=pad_token)

    logger.info(f"Creating embeddings for {len(all_tokens)} tokens")
    # Convert tokens to IDs
    token_ids = turn_tokens_into_ids(all_tokens, tokenizer, unk_token)

    # Create the embeddings
    embeddings = create_embeddings(
        tokenized=token_ids, model=model, device=device, pad_token_id=tokenizer.get_vocab()[pad_token]
    )

    # Post process the embeddings by applying PCA and Zipf weighting.
    embeddings = post_process_embeddings(np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient)
    # Quantize the embeddings.
    embeddings = quantize_embeddings(embeddings, quantize_to)

    model_name = getattr(model, "name_or_path", "")

    config = {
        "model_type": "model2vec",
        "architectures": ["StaticModel"],
        "tokenizer_name": model_name,
        "apply_pca": pca_dims,
        "apply_zipf": apply_zipf,
        "sif_coefficient": sif_coefficient,
        "hidden_dim": embeddings.shape[1],
        "seq_length": 1000000,  # Set this to a high value since we don't have a sequence length limit.
        "normalize": True,
    }

    base_model_name, language = _resolve_base_model_name_and_language(model_name)

    return StaticModel(
        vectors=embeddings,
        tokenizer=backend_tokenizer,
        config=config,
        base_model_name=base_model_name,
        language=language,
        normalize=True,
    )


def _resolve_pad_token(pad_token: str | None, unk_token: str | None, fallback: str) -> str:
    """
    Choose the token used for padding when creating embeddings.

    :param pad_token: The tokenizer's pad token, if it has one.
    :param unk_token: The tokenizer's unk token, if it has one.
    :param fallback: The token to use when neither a pad nor an unk token is set.
    :return: `pad_token` if set, otherwise `unk_token` if set, otherwise `fallback`.
    """
    if pad_token is not None:
        return pad_token
    if unk_token is not None:
        logger.warning(
            "The pad token is not set. Setting it to the unk token. This is a workaround for models that don't have a pad token."
        )
        return unk_token
    logger.warning(
        "The pad token is not set. Setting it to the first token in the vocabulary. This is a workaround for models that don't have a pad token."
    )
    return fallback


def _resolve_base_model_name_and_language(model_name: str) -> tuple[str, str | list[str] | None]:
    """
    Derive the base model name and language to store on the distilled model.

    For an existing local path, the name is the final path component and no language is looked up.
    Otherwise `model_name` is treated as a Hugging Face Hub id, and the language is read from its model card.

    :param model_name: The `name_or_path` of the source model; may be empty.
    :return: A tuple of (base model name, language or None).
    """
    if not model_name:
        # Nothing to look up: querying the Hub with an empty id can only fail and log a misleading warning.
        return model_name, None

    if os.path.exists(model_name):
        # Using a local model. normpath strips trailing separators so "models/foo/" yields "foo", not "".
        return os.path.basename(os.path.normpath(model_name)), None

    # Get the language from the model card.
    try:
        info = model_info(model_name)
        language = info.cardData.get("language", None) if info.cardData is not None else None
    except Exception as e:
        # NOTE: broad except because there's many reasons this can fail.
        logger.warning(f"Couldn't get the model info from the Hugging Face Hub: {e}. Setting language to None.")
        language = None
    return model_name, language


def _validate_parameters(
    apply_zipf: bool | None,
    sif_coefficient: float | None,
    token_remove_pattern: str | None,
) -> tuple[float | None, re.Pattern | None]:
    """
    Validate the parameters passed to the distillation function.

    :param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
        Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
    :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
        Should be a value >= 0 and < 1.0. A value of 1e-4 is a good default.
    :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
    :return: The SIF coefficient to use.
    :raises: ValueError if the regex can't be compiled.

    """
    if apply_zipf is not None:
        logger.warning(
            "The `apply_zipf` parameter is deprecated and will be removed in the next release. "
            "Zipf weighting is applied based on the sif_coefficient parameter. If this is set to None, "
            "no weighting is applied."
        )
        if apply_zipf and sif_coefficient is None:
            logger.warning("You set apply_zipf to True, but sif_coefficient is None. Setting sif_coefficient to 1e-4.")
            sif_coefficient = 1e-4
        elif not apply_zipf:
            logger.warning("Because you set apply_zipf to False, we ignore the sif_coefficient parameter.")
            sif_coefficient = None

    if sif_coefficient is not None and not 0 < sif_coefficient < 1.0:
        msg = "SIF coefficient must be a value > 0 and < 1.0."
        raise ValueError(msg)

    token_remove_regex: re.Pattern | None = None
    if token_remove_pattern is not None:
        try:
            token_remove_regex = re.compile(token_remove_pattern)
        except re.error as e:
            msg = f"Couldn't compile the regex pattern: {e}"
            raise ValueError(msg)

    return sif_coefficient, token_remove_regex


def distill(
    model_name: str,
    vocabulary: list[str] | None = None,
    device: str | None = None,
    pca_dims: PCADimType = 256,
    apply_zipf: bool | None = None,
    sif_coefficient: float | None = 1e-4,
    token_remove_pattern: str | None = r"\[unused\d+\]",
    trust_remote_code: bool = False,
    quantize_to: DType | str = DType.Float16,
    use_subword: bool | None = None,
) -> StaticModel:
    """
    Distill a staticmodel from a sentence transformer.

    This function creates a set of embeddings from a sentence transformer. It does this by doing either
    a forward pass for all subword tokens in the tokenizer, or by doing a forward pass for all tokens in a passed vocabulary.

    If you pass through a vocabulary, we create a custom word tokenizer for that vocabulary.
    If you don't pass a vocabulary, we use the model's tokenizer directly.

    :param model_name: The model name or local path to use. It is loaded with `AutoModel` and a fast `AutoTokenizer`;
        any sentencetransformer compatible model works.
    :param vocabulary: The vocabulary to use. If this is None, we use the model's vocabulary.
    :param device: The device to use.
    :param pca_dims: The number of components to use for PCA.
        If this is None, we don't apply PCA.
        If this is 'auto', we don't reduce dimensionality, but still apply PCA.
    :param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
        Zipf weighting is now controlled by the sif_coefficient parameter. If sif_coefficient is None, no weighting is applied.
    :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
        Should be a value > 0 and < 1.0. A value of 1e-4 is a good default.
    :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
    :param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming from `transformers`. If this is True, we will load all components.
    :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
    :param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
    :return: A StaticModel
    :raises: ValueError (from `distill_from_model`) if the parameters are invalid or the vocabulary is empty after preprocessing.

    """
    model: PreTrainedModel = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code)
    # The cast only informs the type checker; `use_fast=True` requests the fast (Rust-backed) tokenizer.
    tokenizer = cast(
        "PreTrainedTokenizerFast",
        AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code, use_fast=True),
    )

    return distill_from_model(
        model=model,
        tokenizer=tokenizer,
        vocabulary=vocabulary,
        device=device,
        pca_dims=pca_dims,
        apply_zipf=apply_zipf,
        token_remove_pattern=token_remove_pattern,
        sif_coefficient=sif_coefficient,
        quantize_to=quantize_to,
        use_subword=use_subword,
    )