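"""ImageBind modality for multi_token.

Wraps a frozen ImageBind encoder (Meta AI's imagebind package) so that audio,
image, and text inputs can be embedded into a shared 1024-dim space and
projected into the language model's token embedding space.
"""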
from typing import Dict, List
import os

import torch
import torch.nn as nn

from multi_token.modalities.base_modality import Modality
from multi_token.modalities.projectors import build_mlp_vector_projector
from multi_token.data_tools import with_local_files

IMAGE_BIND_FORCE_CPU = "IMAGE_BIND_FORCE_CPU"  # env var: keep the encoder on CPU
IMAGE_BIND_EMBEDDING_SIZE = 1024  # output dimension of imagebind_huge (all modalities)


class ImageBindModule(nn.Module):
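    """Frozen ImageBind encoder: maps a {ModalityType: tensor} input dict to a
    dict of embeddings, one per modality present."""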
    def __init__(self):
        super().__init__()
        from imagebind.models import imagebind_model
        from imagebind import data

        # Resolve ImageBind's BPE vocab relative to the installed package rather
        # than the library's default working-directory-relative path.
        data.BPE_PATH = os.path.join(
            os.path.dirname(data.__file__), "..", "bpe", "bpe_simple_vocab_16e6.txt.gz"
        )
        self.model = imagebind_model.imagebind_huge(pretrained=True)
        self.model.eval()
        self.model.requires_grad_(False)

    @torch.no_grad()
    def forward(self, items: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        return self.model(items)

    @property
    def embedding_size(self):
        return IMAGE_BIND_EMBEDDING_SIZE


class ImageBindModality(Modality):
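    """Modality backed by ImageBind: encodes audio/image/text items into 1024-dim
    vectors and projects each into `num_tokens` language-model token embeddings."""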
    def __init__(
        self,
        num_projector_layers: int = 2,
        num_tokens: int = 4,
        preprocess_device: str = "cpu",
    ):
        super().__init__()
        self.module = ImageBindModule()
        self.dtype = torch.float32
        self.device = "cpu"  # used for outputs
        self.imagebind_device = "cpu"  # used for imagebind model itself
        self.preprocess_device = preprocess_device  # used for preprocessing
        self.num_projector_layers = num_projector_layers
        self.num_tokens = num_tokens

    def build_projector(self, lm_hidden_size: int) -> nn.Module:
        return build_mlp_vector_projector(
            self.module.embedding_size,
            lm_hidden_size,
            num_layers=self.num_projector_layers,
            num_tokens=self.num_tokens,
        )

    @property
    def name(self) -> str:
        return "imagebind"

    @property
    def token(self) -> str:
        return "<imagebind>"

    @property
    def data_key(self) -> str:
        return "imagebinds"

    @property
    def token_width(self) -> int:
        return self.num_tokens

    def to(self, dtype: torch.dtype, device: torch.device) -> "ImageBindModality":
        # dtype (and, when IMAGE_BIND_FORCE_CPU is set, device) is only recorded
        # for casting the output embeddings; the ImageBind weights stay float32
        self.device = device
        self.dtype = dtype
        if IMAGE_BIND_FORCE_CPU not in os.environ:
            # set IMAGE_BIND_FORCE_CPU to keep the encoder on CPU when it would
            # otherwise exhaust VRAM (it can OOM a 24 GB GPU)
            self.module.to(device=device)
            self.imagebind_device = device
        return self

    def preprocess_rows(self, rows: List[Dict]) -> List[List[Dict]]:
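        """Run ImageBind's per-modality preprocessing on every file in each row.

        Returns one list per row; each element is a single-entry
        {ModalityType: tensor} dict ready for `ImageBindModule.forward`.
        """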
        from imagebind.models.imagebind_model import ModalityType
        from imagebind import data

        row_values = []
        for row in rows:
            items = []
            with with_local_files(row[self.data_key]) as item_paths:
                for item_path in item_paths:
                    ib_modality = filename_to_imagebind_modality(item_path)
                    if ib_modality == ModalityType.TEXT:
                        items.append(
                            {
                                ModalityType.TEXT: data.load_and_transform_text(
                                    [item_path], self.preprocess_device
                                )
                            }
                        )
                    elif ib_modality == ModalityType.VISION:
                        items.append(
                            {
                                ModalityType.VISION: data.load_and_transform_vision_data(
                                    [item_path], self.preprocess_device
                                )
                            }
                        )
                    elif ib_modality == ModalityType.AUDIO:
                        items.append(
                            {
                                ModalityType.AUDIO: data.load_and_transform_audio_data(
                                    [item_path], self.preprocess_device
                                )
                            }
                        )
                    else:
                        raise ValueError(f"Unknown modality type: {ib_modality}")
            row_values.append(items)
        return row_values

    @torch.no_grad()
    def forward(self, encoded_values: List[List[Dict]]) -> List[torch.Tensor]:
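        """Embed each row's preprocessed items and stack them on the output device."""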
        item_features = []
        for item_batch in encoded_values:
            item_batch_emb = []
            for item in item_batch:
                item = {
                    # inputs are cast to float32 to match the frozen encoder
                    k: v.to(device=self.imagebind_device, dtype=torch.float32)
                    for k, v in item.items()
                }
                item_batch_emb.extend(list(self.module.forward(item).values()))
            item_features.append(
                torch.stack(item_batch_emb).to(device=self.device, dtype=self.dtype)
            )
        # list of length batch_size; each tensor has shape (num_items, 1, embedding_size)
        return item_features


def filename_to_imagebind_modality(fn: str) -> str:
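    """Infer the ImageBind modality from a file extension; anything that is not a
    known audio or image extension is treated as text."""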
    from imagebind.models.imagebind_model import ModalityType

    ext = os.path.splitext(fn)[1].lower()  # case-insensitive extension match
    if ext in {".wav"}:
        return ModalityType.AUDIO
    elif ext in {".jpg", ".png", ".jpeg"}:
        return ModalityType.VISION
    else:
        return ModalityType.TEXT
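

# Usage sketch (illustrative only, not part of the original module). The file
# names and the hidden size 4096 below are placeholders; actual integration is
# handled elsewhere in multi_token.
#
#   modality = ImageBindModality(num_tokens=4).to(torch.float16, torch.device("cuda"))
#   projector = modality.build_projector(lm_hidden_size=4096)
#   rows = [{"imagebinds": ["dog.jpg", "bark.wav"]}]
#   encoded = modality.preprocess_rows(rows)  # per-file {ModalityType: tensor} dicts
#   features = modality.forward(encoded)      # [tensor of shape (2, 1, 1024)]
#   tokens = projector(features[0])           # embeddings in the LM's token space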