|
import torch
|
|
from tqdm import tqdm
|
|
from copy import deepcopy
|
|
import pickle
|
|
import clip
|
|
|
|
|
|
|
|
def convert_models_to_fp32(model):
    """Cast every parameter of *model* — and its gradient, when one exists — to float32 in place.

    Used around ``optimizer.step()`` for fp16 (CLIP) weights: the optimizer
    update is done in fp32, after which the caller converts back to fp16.
    """
    for param in model.parameters():
        param.data = param.data.float()
        grad = param.grad
        if grad is not None:
            grad.data = grad.data.float()
|
|
|
|
|
|
def train_or_test(model, optimizer, iterator, device, mode="train"):
    """Run one full pass of *model* over *iterator*, training or evaluating.

    Args:
        model: module exposing ``compute_loss(batch) -> (scalar_loss, dict)``
            and a ``clip_training`` flag (with ``clip_model`` when it is set).
        optimizer: optimizer stepped once per batch in ``"train"`` mode.
        iterator: yields dict batches; tensor values are moved to *device*.
            NOTE(review): assumes every batch yields the same loss keys —
            accumulation iterates the keys captured from the first batch.
        device: target device for the tensors in each batch.
        mode: ``"train"`` (grad enabled, optimizer stepped) or ``"test"``
            (eval mode, no grad, no optimizer step).

    Returns:
        dict of per-key losses summed over all batches (empty if the
        iterator yields nothing).

    Raises:
        ValueError: if *mode* is neither ``"train"`` nor ``"test"``.
    """
    training = mode == "train"
    if training:
        model.train()
        grad_env = torch.enable_grad
    elif mode == "test":
        model.eval()
        grad_env = torch.no_grad
    else:
        raise ValueError("This mode is not recognized.")

    loss_totals = {}
    with grad_env():
        for step, batch in tqdm(enumerate(iterator), desc="Computing batch"):
            # Move only the tensor entries to the device; leave other values as-is.
            batch = {
                key: (val.to(device) if torch.is_tensor(val) else val)
                for key, val in batch.items()
            }

            if training:
                optimizer.zero_grad()

            # The model consumes and returns the batch dict (presumably
            # augmented with its outputs — confirm against the model class).
            batch = model(batch)
            mixed_loss, losses = model.compute_loss(batch)

            # Accumulate each individual loss term across batches.
            if step == 0:
                loss_totals = deepcopy(losses)
            else:
                for name in loss_totals.keys():
                    loss_totals[name] += losses[name]

            if training:
                mixed_loss.backward()
                if model.clip_training:
                    # CLIP weights are kept in fp16; step the optimizer in
                    # fp32, then convert the weights back to fp16.
                    convert_models_to_fp32(model.clip_model)
                    optimizer.step()
                    clip.model.convert_weights(model.clip_model)
                else:
                    optimizer.step()

    return loss_totals
|
|
|
|
|
|
def train(model, optimizer, iterator, device):
    """Run one training epoch; returns the summed per-key losses."""
    return train_or_test(model, optimizer, iterator, device, "train")
|
|
|
|
|
|
def test(model, optimizer, iterator, device):
    """Run one evaluation pass (no grad, no optimizer step); returns the summed per-key losses.

    *optimizer* is accepted only to mirror :func:`train`'s signature — it is
    never stepped in test mode.
    """
    return train_or_test(model, optimizer, iterator, device, "test")
|
|
|