import json
import torch
import torch.nn as nn
from transformers import AutoTokenizer, T5ForConditionalGeneration


def load_glyph_byT5_v2(args, device):
    """
    Loads ByT5 tokenizer and encoder model for glyph encoding.

    Args:
        args (dict): Configuration dictionary containing paths and settings.
        device (str or torch.device): Device to load the model onto.

    Returns:
        dict: Dictionary with keys 'byt5_tokenizer', 'byt5_model', 'byt5_max_length'.
    """
    byt5_tokenizer, byt5_model, byt5_max_length = create_byt5(args, device)
    byt5_model = byt5_model.to(device=device)
    return {
        "byt5_tokenizer": byt5_tokenizer,
        "byt5_model": byt5_model,
        "byt5_max_length": byt5_max_length,
    }
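
# Illustrative sketch of the configuration expected by load_glyph_byT5_v2 /
# create_byt5. The keys are the ones read below; the values (paths, max length)
# are hypothetical placeholders, not defaults shipped with any checkpoint.
#
#   args = {
#       "byt5_max_length": 256,                                   # assumed value
#       "byT5_google_path": "google/byt5-small",                  # HF model id or local path
#       "byT5_ckpt_path": None,                                   # or "path/to/byt5_ckpt.pt"
#       "multilingual_prompt_format_color_path": "path/to/color_idx.json",
#       "multilingual_prompt_format_font_path": "path/to/font_idx.json",
#   }
#   glyph_encoder = load_glyph_byT5_v2(args, device="cuda:0")
#   # -> {"byt5_tokenizer": ..., "byt5_model": ..., "byt5_max_length": 256}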


def create_byt5(args, device):
    """
    Create the ByT5 tokenizer and encoder, loading custom checkpoint weights if provided.

    Args:
        args (dict): Configuration dictionary.
        device (str or torch.device): Device to load the model onto.

    Returns:
        tuple: (byt5_tokenizer, byt5_model, byt5_max_length)
    """
    byt5_max_length = args['byt5_max_length']
    byt5_config = dict(
        byt5_name=args['byT5_google_path'],
        special_token=True,
        color_special_token=True,
        font_special_token=True,
        color_ann_path=args['multilingual_prompt_format_color_path'],
        font_ann_path=args['multilingual_prompt_format_font_path'],
        multilingual=True,
    )
    huggingface_cache_dir = None
    byt5_model, byt5_tokenizer = load_byt5_and_byt5_tokenizer(
        **byt5_config,
        huggingface_cache_dir=huggingface_cache_dir,
        device=device,
    )

    # Load a custom checkpoint if provided. An integer GPU index (e.g. 0) is
    # normalized to a "cuda:{index}" map_location; a full device string such as
    # "cuda:0" is passed through unchanged.
    if args['byT5_ckpt_path'] is not None:
        if "cuda" not in str(device):
            byt5_state_dict = torch.load(args['byT5_ckpt_path'], map_location=f"cuda:{device}")
        else:
            byt5_state_dict = torch.load(args['byT5_ckpt_path'], map_location=device)
        if 'state_dict' in byt5_state_dict:
            # Some checkpoints store the encoder weights under a
            # 'module.text_tower.encoder.' prefix inside 'state_dict';
            # strip that prefix before loading.
            sd = byt5_state_dict["state_dict"]
            newsd = {}
            for k, v in sd.items():
                if k.startswith('module.text_tower.encoder.'):
                    newsd[k[len('module.text_tower.encoder.'):]] = v
            byt5_state_dict = newsd
        byt5_model.load_state_dict(byt5_state_dict)
    # Freeze the glyph encoder weights.
    byt5_model.requires_grad_(False)
    return byt5_tokenizer, byt5_model, byt5_max_length


def add_special_token(
    tokenizer,
    text_encoder,
    add_color,
    add_font,
    color_ann_path,
    font_ann_path,
    multilingual=False,
):
    """
    Add special tokens for color and font to tokenizer and text encoder.

    Args:
        tokenizer: Huggingface tokenizer.
        text_encoder: Huggingface T5 encoder.
        add_color (bool): Whether to add color tokens.
        add_font (bool): Whether to add font tokens.
        color_ann_path (str): Path to color annotation JSON.
        font_ann_path (str): Path to font annotation JSON.
        multilingual (bool): Whether to use multilingual font tokens.
    """
    with open(font_ann_path, 'r') as f:
        idx_font_dict = json.load(f)
    with open(color_ann_path, 'r') as f:
        idx_color_dict = json.load(f)

    if multilingual:
        font_token = [f'<{font_code[:2]}-font-{idx_font_dict[font_code]}>' for font_code in idx_font_dict]
    else:
        font_token = [f'<font-{i}>' for i in range(len(idx_font_dict))]
    color_token = [f'<color-{i}>' for i in range(len(idx_color_dict))]
    additional_special_tokens = []
    if add_color:
        additional_special_tokens += color_token
    if add_font:
        additional_special_tokens += font_token

    tokenizer.add_tokens(additional_special_tokens, special_tokens=True)
    # Set mean_resizing=False to avoid PyTorch LAPACK dependency
    text_encoder.resize_token_embeddings(len(tokenizer), mean_resizing=False)
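
# Assumed structure of the annotation files, inferred from the token strings
# built above (illustrative only; real annotation files may use different keys):
#
#   font_ann_path (multilingual=True):   {"en-Arial": 0, "cn-SimHei": 1, ...}
#       -> font tokens  "<en-font-0>", "<cn-font-1>", ...
#   font_ann_path (multilingual=False):  any dict of N fonts
#       -> font tokens  "<font-0>", ..., "<font-N-1>"
#   color_ann_path:                      any dict of M colors
#       -> color tokens "<color-0>", ..., "<color-M-1>"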


def load_byt5_and_byt5_tokenizer(
    byt5_name='google/byt5-small',
    special_token=False,
    color_special_token=False,
    font_special_token=False,
    color_ann_path='assets/color_idx.json',
    font_ann_path='assets/font_idx_512.json',
    huggingface_cache_dir=None,
    multilingual=False,
    device=None,
):
    """
    Load ByT5 encoder and tokenizer from Huggingface, and add special tokens if needed.

    Args:
        byt5_name (str): Model name or path.
        special_token (bool): Whether to add special tokens.
        color_special_token (bool): Whether to add color tokens.
        font_special_token (bool): Whether to add font tokens.
        color_ann_path (str): Path to color annotation JSON.
        font_ann_path (str): Path to font annotation JSON.
        huggingface_cache_dir (str): Huggingface cache directory.
        multilingual (bool): Whether to use multilingual font tokens.
        device (str or torch.device): Device to load the model onto.

    Returns:
        tuple: (byt5_text_encoder, byt5_tokenizer)
    """
    byt5_tokenizer = AutoTokenizer.from_pretrained(
        byt5_name,
        cache_dir=huggingface_cache_dir,
    )
    byt5_text_encoder = T5ForConditionalGeneration.from_pretrained(
        byt5_name,
        cache_dir=huggingface_cache_dir,
    ).get_encoder()

    if "cuda" not in str(device):
        device = torch.device(f"cuda:{device}")
    else:
        device = torch.device(device)
    byt5_text_encoder = byt5_text_encoder.to(device)

    if special_token:
        add_special_token(
            byt5_tokenizer,
            byt5_text_encoder,
            add_color=color_special_token,
            add_font=font_special_token,
            color_ann_path=color_ann_path,
            font_ann_path=font_ann_path,
            multilingual=multilingual,
        )
    return byt5_text_encoder, byt5_tokenizer
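
# Usage sketch (assumptions: the default annotation JSONs exist on disk, a CUDA
# device is available, and the prompt string below is only a made-up example
# showing the special tokens being consumed, not a documented prompt format):
#
#   encoder, tokenizer = load_byt5_and_byt5_tokenizer(
#       byt5_name="google/byt5-small",
#       special_token=True, color_special_token=True, font_special_token=True,
#       device="cuda:0",
#   )
#   ids = tokenizer(
#       'Text "Sale" in <color-3>, <font-12>',
#       max_length=256, padding="max_length", truncation=True, return_tensors="pt",
#   ).input_ids.to(encoder.device)
#   hidden = encoder(ids).last_hidden_state   # (1, 256, d_model)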


class ByT5Mapper(nn.Module):
    """
    ByT5Mapper: Maps ByT5 encoder outputs to a new space, with optional residual connection.

    Args:
        in_dim (int): Input dimension (must equal out_dim if use_residual).
        out_dim (int): Output dimension after second linear layer.
        hidden_dim (int): Hidden dimension for intermediate layer.
        out_dim1 (int): Final output dimension.
        use_residual (bool): If True, add the input back to the fc2 output
            (requires in_dim == out_dim; default: True).
    """

    def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_residual=True):
        super().__init__()
        if use_residual:
            assert in_dim == out_dim
        self.layernorm = nn.LayerNorm(in_dim)
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.fc3 = nn.Linear(out_dim, out_dim1)
        self.use_residual = use_residual
        self.act_fn = nn.GELU()

    def forward(self, x):
        """
        Forward pass for ByT5Mapper.

        Args:
            x (Tensor): Input tensor of shape (..., in_dim).

        Returns:
            Tensor: Output tensor of shape (..., out_dim1).
        """
        residual = x
        x = self.layernorm(x)
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.fc2(x)
        if self.use_residual:
            # The skip connection joins the input with the fc2 output; this is
            # why the constructor asserts in_dim == out_dim.
            x = x + residual
        x2 = self.act_fn(x)
        x2 = self.fc3(x2)
        return x2
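

if __name__ == "__main__":
    # Minimal shape check for ByT5Mapper. The dimensions are illustrative
    # placeholders (1472 is byt5-small's d_model), not values tied to any
    # particular pipeline configuration.
    mapper = ByT5Mapper(in_dim=1472, out_dim=1472, hidden_dim=2048, out_dim1=2048)
    dummy = torch.randn(2, 77, 1472)    # (batch, tokens, in_dim)
    print(mapper(dummy).shape)          # torch.Size([2, 77, 2048])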