File size: 3,343 Bytes
6755a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import sys
from multiprocessing.pool import Pool
import os
import logging
from typing import Union, List, Tuple

import torch
import numpy as np
import pandas as pd
import h5py
import diffusers
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from einops import rearrange
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer

from ...data.extract_feature.base_extract_feature import BaseFeatureExtractor

from .save_text_emb import save_text_emb_with_h5py


class ClipTextFeatureExtractor(BaseFeatureExtractor):
    """Extract text embeddings with a frozen CLIP text encoder.

    The tokenizer and the text encoder are loaded from the ``tokenizer`` and
    ``text_encoder`` subfolders of ``pretrained_model_name_or_path``.
    """

    # Storage backends that ``extract`` knows how to write to.
    _SUPPORTED_SAVE_TYPES = ("h5py",)

    def __init__(
        self,
        pretrained_model_name_or_path: str,
        device: str = "cpu",
        dtype: torch.dtype = None,
        name: str = "CLIPEncoderLayer",
    ):
        """Load the tokenizer and text encoder and move the encoder to ``device``.

        Args:
            pretrained_model_name_or_path: Model identifier or local path passed
                to ``from_pretrained`` for both the tokenizer and the encoder.
            device: Device the text encoder is moved to.
            dtype: Dtype the text encoder is cast to; ``None`` keeps the
                loaded dtype.
            name: Extractor name forwarded to the base class. ``extract`` can
                prepend it to the embedding key.
        """
        super().__init__(device, dtype, name)
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        self.tokenizer = CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path, subfolder="tokenizer"
        )
        text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path, subfolder="text_encoder"
        )
        # The encoder is only used for inference, so its weights are frozen.
        text_encoder.requires_grad_(False)
        self.text_encoder = text_encoder.to(device=device, dtype=dtype)

    def extract(
        self,
        text: Union[str, List[str]],
        return_type: str = "numpy",
        save_emb_path: str = None,
        save_type: str = "h5py",
        text_emb_key: str = None,
        text_key: str = "text",
        text_tuple_length: int = 20,
        text_index: int = 0,
        insert_name_to_key: bool = False,
    ) -> Union[np.ndarray, torch.Tensor]:
        """Encode ``text`` and optionally persist the embedding.

        Args:
            text: A prompt or a list of prompts to encode.
            return_type: ``"numpy"`` converts the result to a CPU numpy array;
                any other value leaves it as a ``torch.Tensor``.
            save_emb_path: Destination file. When ``None`` nothing is saved.
            save_type: Storage backend; only ``"h5py"`` is supported.
            text_emb_key: Base key for the stored embedding. When given it is
                suffixed with ``_{text_index}``.
            text_key: Forwarded to ``save_text_emb_with_h5py``.
            text_tuple_length: Forwarded to ``save_text_emb_with_h5py``.
            text_index: Index appended to ``text_emb_key`` and forwarded to
                ``save_text_emb_with_h5py``.
            insert_name_to_key: When true and ``self.name`` is set, the
                extractor name is prepended to ``text_emb_key``.

        Returns:
            The first element of the text encoder output (the last hidden
            state), as a numpy array or a tensor depending on ``return_type``.

        Raises:
            ValueError: If ``save_emb_path`` is given with an unsupported
                ``save_type``.
        """
        # Fail fast instead of silently skipping the save and returning None
        # after the (expensive) forward pass.
        if (
            save_emb_path is not None
            and save_type not in self._SUPPORTED_SAVE_TYPES
        ):
            raise ValueError(
                f"Unsupported save_type {save_type!r}; "
                f"expected one of {self._SUPPORTED_SAVE_TYPES}"
            )

        if text_emb_key is not None:
            text_emb_key = f"{text_emb_key}_{text_index}"
            if self.name is not None and insert_name_to_key:
                text_emb_key = f"{self.name}_{text_emb_key}"

        # Pad/truncate every prompt to the tokenizer's maximum length so the
        # encoder always receives fixed-size input.
        text_inputs = self.tokenizer(
            text,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        # Only pass an attention mask when the encoder config asks for one.
        if getattr(self.text_encoder.config, "use_attention_mask", False):
            attention_mask = text_inputs.attention_mask.to(self.device)
        else:
            attention_mask = None

        # The encoder returns a BaseModelOutputWithPooling
        # ('last_hidden_state', 'pooler_output'); we keep the first entry.
        # no_grad keeps this a pure inference call, so the numpy conversion
        # below works even if gradients were re-enabled on the encoder.
        with torch.no_grad():
            text_embeds = self.text_encoder(
                text_input_ids.to(device=self.device),
                attention_mask=attention_mask,
            )[0]

        if return_type == "numpy":
            text_embeds = text_embeds.cpu().numpy()

        if save_emb_path is not None:
            # save_type was validated above, so "h5py" is the only case here.
            save_text_emb_with_h5py(
                path=save_emb_path,
                emb=text_embeds,
                text_emb_key=text_emb_key,
                text=text,
                text_key=text_key,
                text_tuple_length=text_tuple_length,
                text_index=text_index,
            )
        return text_embeds