import json

import torch
import torch.nn as nn
from transformers import AutoTokenizer, T5ForConditionalGeneration


def load_glyph_byT5_v2(args, device):
    """
    Load the ByT5 tokenizer and encoder model for glyph encoding.

    Args:
        args (dict): Configuration dictionary containing paths and settings.
        device (str or torch.device): Device to load the model onto.

    Returns:
        dict: Dictionary with keys 'byt5_tokenizer', 'byt5_model', 'byt5_max_length'.
    """
    byt5_tokenizer, byt5_model, byt5_max_length = create_byt5(args, device)
    byt5_model = byt5_model.to(device=device)
    return {
        "byt5_tokenizer": byt5_tokenizer,
        "byt5_model": byt5_model,
        "byt5_max_length": byt5_max_length,
    }
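

# Hypothetical usage sketch (illustrative only): the args keys mirror those
# read by create_byt5 below, but the values here are assumptions, not the
# project's real paths or settings.
#
# args = {
#     'byt5_max_length': 512,
#     'byT5_google_path': 'google/byt5-small',
#     'byT5_ckpt_path': None,  # or a path to a fine-tuned glyph checkpoint
#     'multilingual_prompt_format_color_path': 'assets/color_idx.json',
#     'multilingual_prompt_format_font_path': 'assets/font_idx_512.json',
# }
# glyph_encoder = load_glyph_byT5_v2(args, device=0)
# tokenizer = glyph_encoder['byt5_tokenizer']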


def create_byt5(args, device):
    """
    Create the ByT5 tokenizer and encoder, loading custom weights if provided.

    Args:
        args (dict): Configuration dictionary.
        device (str, int, or torch.device): CUDA device index (e.g. 0) or
            device string (e.g. "cuda:0") to load the model onto.

    Returns:
        tuple: (byt5_tokenizer, byt5_model, byt5_max_length)
    """
    byt5_max_length = args['byt5_max_length']
    byt5_config = dict(
        byt5_name=args['byT5_google_path'],
        special_token=True,
        color_special_token=True,
        font_special_token=True,
        color_ann_path=args['multilingual_prompt_format_color_path'],
        font_ann_path=args['multilingual_prompt_format_font_path'],
        multilingual=True,
    )
    huggingface_cache_dir = None
    byt5_model, byt5_tokenizer = load_byt5_and_byt5_tokenizer(
        **byt5_config,
        huggingface_cache_dir=huggingface_cache_dir,
        device=device,
    )

    # Load a custom checkpoint if provided.
    if args['byT5_ckpt_path'] is not None:
        # A bare CUDA index (e.g. 0) is turned into an explicit "cuda:N"
        # map_location; strings that already contain "cuda" are used as-is.
        if "cuda" not in str(device):
            byt5_state_dict = torch.load(args['byT5_ckpt_path'], map_location=f"cuda:{device}")
        else:
            byt5_state_dict = torch.load(args['byT5_ckpt_path'], map_location=device)
        if 'state_dict' in byt5_state_dict:
            # Unwrap DDP-style checkpoints: keep only the text-tower encoder
            # weights and strip their 'module.text_tower.encoder.' prefix.
            sd = byt5_state_dict["state_dict"]
            newsd = {}
            for k, v in sd.items():
                if k.startswith('module.text_tower.encoder.'):
                    newsd[k[len('module.text_tower.encoder.'):]] = v
            byt5_state_dict = newsd
        byt5_model.load_state_dict(byt5_state_dict)

    # The glyph encoder is used frozen; disable gradients.
    byt5_model.requires_grad_(False)
    return byt5_tokenizer, byt5_model, byt5_max_length
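

# Minimal sketch of the checkpoint key remap performed above, with a made-up
# key for illustration (the real checkpoint layout is an assumption here):
#
# sd = {'module.text_tower.encoder.block.0.layer.0.weight': ...}
# prefix = 'module.text_tower.encoder.'
# stripped = {k[len(prefix):]: v for k, v in sd.items() if k.startswith(prefix)}
# # -> {'block.0.layer.0.weight': ...}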


def add_special_token(
    tokenizer,
    text_encoder,
    add_color,
    add_font,
    color_ann_path,
    font_ann_path,
    multilingual=False,
):
    """
    Add special tokens for color and font to the tokenizer and text encoder.

    Both objects are modified in place; the encoder's token embeddings are
    resized to match the enlarged vocabulary.

    Args:
        tokenizer: Hugging Face tokenizer.
        text_encoder: Hugging Face T5 encoder.
        add_color (bool): Whether to add color tokens.
        add_font (bool): Whether to add font tokens.
        color_ann_path (str): Path to the color annotation JSON.
        font_ann_path (str): Path to the font annotation JSON.
        multilingual (bool): Whether to use multilingual font tokens.
    """
    with open(font_ann_path, 'r') as f:
        idx_font_dict = json.load(f)
    with open(color_ann_path, 'r') as f:
        idx_color_dict = json.load(f)

    if multilingual:
        # Multilingual font tokens are prefixed with the two-letter language
        # code taken from the annotation key, e.g. '<en-font-3>'.
        font_token = [f'<{font_code[:2]}-font-{idx_font_dict[font_code]}>' for font_code in idx_font_dict]
    else:
        font_token = [f'<font-{i}>' for i in range(len(idx_font_dict))]
    color_token = [f'<color-{i}>' for i in range(len(idx_color_dict))]

    additional_special_tokens = []
    if add_color:
        additional_special_tokens += color_token
    if add_font:
        additional_special_tokens += font_token

    tokenizer.add_tokens(additional_special_tokens, special_tokens=True)
    # Set mean_resizing=False to avoid PyTorch's LAPACK dependency.
    text_encoder.resize_token_embeddings(len(tokenizer), mean_resizing=False)
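

# Illustrative sketch (annotation file contents assumed): after this call the
# tokenizer gains one '<color-i>' token per color entry and one font token per
# font entry, and the encoder's embedding matrix is padded to the new size.
#
# add_special_token(
#     tokenizer, encoder, add_color=True, add_font=True,
#     color_ann_path='assets/color_idx.json',
#     font_ann_path='assets/font_idx_512.json',
# )
# print(len(tokenizer))  # 384 (stock ByT5) + n_colors + n_fonts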


def load_byt5_and_byt5_tokenizer(
    byt5_name='google/byt5-small',
    special_token=False,
    color_special_token=False,
    font_special_token=False,
    color_ann_path='assets/color_idx.json',
    font_ann_path='assets/font_idx_512.json',
    huggingface_cache_dir=None,
    multilingual=False,
    device=None,
):
    """
    Load the ByT5 encoder and tokenizer from Hugging Face, adding special
    tokens if requested.

    Args:
        byt5_name (str): Model name or path.
        special_token (bool): Whether to add special tokens.
        color_special_token (bool): Whether to add color tokens.
        font_special_token (bool): Whether to add font tokens.
        color_ann_path (str): Path to the color annotation JSON.
        font_ann_path (str): Path to the font annotation JSON.
        huggingface_cache_dir (str): Hugging Face cache directory.
        multilingual (bool): Whether to use multilingual font tokens.
        device (str, int, or torch.device): CUDA device index or device string.

    Returns:
        tuple: (byt5_text_encoder, byt5_tokenizer)
    """
    byt5_tokenizer = AutoTokenizer.from_pretrained(
        byt5_name,
        cache_dir=huggingface_cache_dir,
    )
    # Only the T5 encoder stack is kept; the decoder is discarded.
    byt5_text_encoder = T5ForConditionalGeneration.from_pretrained(
        byt5_name,
        cache_dir=huggingface_cache_dir,
    ).get_encoder()

    # As in create_byt5, a bare CUDA index (e.g. 0) becomes "cuda:N".
    if "cuda" not in str(device):
        device = torch.device(f"cuda:{device}")
    else:
        device = torch.device(device)
    byt5_text_encoder = byt5_text_encoder.to(device)

    if special_token:
        add_special_token(
            byt5_tokenizer,
            byt5_text_encoder,
            add_color=color_special_token,
            add_font=font_special_token,
            color_ann_path=color_ann_path,
            font_ann_path=font_ann_path,
            multilingual=multilingual,
        )
    return byt5_text_encoder, byt5_tokenizer
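

# Hypothetical end-to-end sketch: encode a styled glyph prompt with the loaded
# pair. The model name, device index, and special-token strings below are
# illustrative assumptions.
#
# encoder, tokenizer = load_byt5_and_byt5_tokenizer(
#     byt5_name='google/byt5-small',
#     special_token=True, color_special_token=True, font_special_token=True,
#     device=0,
# )
# batch = tokenizer(['Text "Sale" <color-3> <font-12>'],
#                   return_tensors='pt').to(encoder.device)
# with torch.no_grad():
#     hidden = encoder(**batch).last_hidden_state  # (1, seq_len, 1472)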


class ByT5Mapper(nn.Module):
    """
    Maps ByT5 encoder outputs to a new space, with an optional residual
    connection around the first two linear layers.

    Args:
        in_dim (int): Input dimension (must equal out_dim if use_residual).
        out_dim (int): Output dimension after the second linear layer.
        hidden_dim (int): Hidden dimension of the intermediate layer.
        out_dim1 (int): Final output dimension.
        use_residual (bool): Whether to use a residual connection (default: True).
    """

    def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_residual=True):
        super().__init__()
        if use_residual:
            # The residual is added to the fc2 output, so widths must match.
            assert in_dim == out_dim
        self.layernorm = nn.LayerNorm(in_dim)
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.fc3 = nn.Linear(out_dim, out_dim1)
        self.use_residual = use_residual
        self.act_fn = nn.GELU()

    def forward(self, x):
        """
        Forward pass for ByT5Mapper.

        Args:
            x (Tensor): Input tensor of shape (..., in_dim).

        Returns:
            Tensor: Output tensor of shape (..., out_dim1).
        """
        residual = x
        x = self.layernorm(x)
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.fc2(x)
        # Add the residual here, where the width matches in_dim (as enforced
        # by the constructor assert), rather than after fc3, whose out_dim1
        # output generally has a different width.
        if self.use_residual:
            x = x + residual
        x2 = self.act_fn(x)
        x2 = self.fc3(x2)
        return x2
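

# Minimal shape-check sketch for ByT5Mapper. All dimensions here are
# illustrative assumptions (1472 is ByT5-small's hidden size; 2048 stands in
# for a downstream cross-attention width).
if __name__ == '__main__':
    mapper = ByT5Mapper(in_dim=1472, out_dim=1472, hidden_dim=2048,
                        out_dim1=2048, use_residual=True)
    dummy = torch.randn(2, 64, 1472)  # (batch, byte tokens, byt5 hidden)
    print(mapper(dummy).shape)        # torch.Size([2, 64, 2048])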