from copy import deepcopy

import clip
import torch
from tqdm import tqdm


# CLIP ships its weights in fp16; optimizer updates must happen in fp32,
# so we cast the parameters (and their gradients) up before stepping and
# convert back afterwards. See https://github.com/openai/CLIP/issues/57.
def convert_models_to_fp32(model):
    for p in model.parameters():
        p.data = p.data.float()
        if p.grad is not None:
            p.grad.data = p.grad.data.float()
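
# Illustrative sanity check of the cast (an assumption for demonstration,
# not part of the original module): a throwaway fp16 Linear layer stands
# in for CLIP.
#
#     m = torch.nn.Linear(4, 4).half()
#     convert_models_to_fp32(m)
#     assert all(p.dtype == torch.float32 for p in m.parameters())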


def train_or_test(model, optimizer, iterator, device, mode="train"):
    if mode == "train":
        model.train()
        grad_env = torch.enable_grad
    elif mode == "test":
        model.eval()
        grad_env = torch.no_grad
    else:
        raise ValueError(f"Unrecognized mode: {mode!r}; expected 'train' or 'test'.")

    # Running sum of each loss term over the epoch
    dict_loss = {}

    with grad_env():
        for i, batch in tqdm(enumerate(iterator), desc="Computing batch"):
            # Move tensors to the target device; leave non-tensor entries
            # untouched ('clip_text' is a list of strings, not a tensor).
            batch = {key: val.to(device) if torch.is_tensor(val) else val
                     for key, val in batch.items()}

            if mode == "train":
                # Reset accumulated gradients
                optimizer.zero_grad()

            # Forward pass
            batch = model(batch)
            mixed_loss, losses = model.compute_loss(batch)

            # Accumulate the per-term losses across batches
            if i == 0:
                dict_loss = deepcopy(losses)
            else:
                for key in dict_loss.keys():
                    dict_loss[key] += losses[key]

            if mode == "train":
                # Backward pass
                mixed_loss.backward()
                # Update the weights; when fine-tuning CLIP, step in fp32
                # and convert back to fp16 afterwards (see helper above)
                if model.clip_training:
                    convert_models_to_fp32(model.clip_model)
                    optimizer.step()
                    clip.model.convert_weights(model.clip_model)
                else:
                    optimizer.step()

    return dict_loss


def train(model, optimizer, iterator, device):
    return train_or_test(model, optimizer, iterator, device, mode="train")


def test(model, optimizer, iterator, device):
    return train_or_test(model, optimizer, iterator, device, mode="test")