File size: 2,304 Bytes
473c3a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from __future__ import annotations

from enum import Enum

import numpy as np


class DType(str, Enum):
    """
    Target data types that embeddings can be quantized to.

    Members also subclass ``str``, so each member compares equal to its plain string value
    and can be constructed from that string (e.g. ``DType("int8")``).
    """

    # Cast to numpy float16.
    Float16 = "float16"
    # Cast to numpy float32.
    Float32 = "float32"
    # Cast to numpy float64.
    Float64 = "float64"
    # Scaled and rounded to signed 8-bit integers.
    Int8 = "int8"


def quantize_embeddings(embeddings: np.ndarray, quantize_to: DType) -> np.ndarray:
    """
    Quantize embeddings to a specified data type to reduce memory usage.

    Float targets are a plain cast. Int8 quantization rescales the whole array by a single
    factor so that the largest absolute value maps to 127, then rounds to the nearest integer.
    Empty or all-zero input is returned as int8 zeros of the same shape.

    :param embeddings: The embeddings to quantize, as a numpy array.
    :param quantize_to: The data type to quantize to.
    :return: The quantized embeddings.
    :raises ValueError: If the quantization type is not valid.
    """
    if quantize_to == DType.Float16:
        return embeddings.astype(np.float16)
    if quantize_to == DType.Float32:
        return embeddings.astype(np.float32)
    if quantize_to == DType.Float64:
        return embeddings.astype(np.float64)
    if quantize_to == DType.Int8:
        # Scale symmetrically into [-127, 127]; -128 is deliberately left unused.
        max_abs = np.max(np.abs(embeddings)) if embeddings.size else 0.0
        if max_abs == 0:
            # Empty or all-zero input: a zero scale would turn every value into NaN,
            # and casting NaN to int8 is undefined. The correct quantization is all zeros.
            return np.zeros(embeddings.shape, dtype=np.int8)
        scale = max_abs / 127.0
        return np.round(embeddings / scale).astype(np.int8)
    msg = "Not a valid enum member of DType."
    raise ValueError(msg)


def quantize_and_reduce_dim(
    embeddings: np.ndarray, quantize_to: str | DType | None, dimensionality: int | None
) -> np.ndarray:
    """
    Quantize embeddings to a datatype and reduce dimensionality.

    Quantization is applied first, then the leading ``dimensionality`` columns are kept.

    :param embeddings: The embeddings to quantize and reduce, as a 2D numpy array.
    :param quantize_to: The data type to quantize to. If None, no quantization is performed.
    :param dimensionality: The number of dimensions to keep. If None, no dimensionality reduction is performed.
    :return: The quantized and reduced embeddings.
    :raises ValueError: If the passed dimensionality is not None and is smaller than 1 or greater than the
        model dimensionality.
    """
    if quantize_to is not None:
        # Accept either an enum member or its string value.
        quantize_to = DType(quantize_to)
        embeddings = quantize_embeddings(embeddings, quantize_to)

    if dimensionality is not None:
        model_dim = embeddings.shape[1]
        if dimensionality < 1:
            # A zero or negative value would slice from the end (or to nothing) and
            # silently return the wrong number of columns.
            msg = f"Dimensionality {dimensionality} must be at least 1"
            raise ValueError(msg)
        if dimensionality > model_dim:
            msg = f"Dimensionality {dimensionality} is greater than the model dimensionality {model_dim}"
            raise ValueError(msg)
        embeddings = embeddings[:, :dimensionality]

    return embeddings