import torch


def train(model, trainloader, optimizer, criterion, DEVICE):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module to train; ``model.train()`` is called before iterating.
        trainloader: Iterable of batches. Each batch is indexed as
            ``batch[0]`` (model inputs) and ``batch[1]`` (targets).
        optimizer: Optimizer stepped once per batch after ``loss.backward()``.
        criterion: Loss function called as ``criterion(output_logits, target)``.
        DEVICE: Device that inputs and targets are moved to via ``.to(DEVICE)``.

    Returns:
        The average of ``loss.item()`` over all batches of the epoch.

    Raises:
        ValueError: If ``trainloader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    for num_batches, data in enumerate(trainloader, start=1):
        optimizer.zero_grad()
        imgs, target = data[0].to(DEVICE), data[1].to(DEVICE)
        output_logits = model(imgs)
        loss = criterion(output_logits, target)
        # Accumulate (not overwrite) so the epoch loss averages every batch.
        running_loss += loss.item()
        loss.backward()
        optimizer.step()

    if num_batches == 0:
        raise ValueError("trainloader yielded no batches; cannot compute epoch loss")

    # Divide by the batches actually seen; also works for loaders without __len__.
    epoch_loss = running_loss / num_batches
    print("epoch loss = {}".format(epoch_loss))
    return epoch_loss