import os
import sys
sys.path.append("..")
sys.path.append(".")
import time
import json
import pprint
import random
import numpy as np
from easydict import EasyDict as EDict
from tqdm import tqdm, trange
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from utils.basic_utils import save_json
from baselines.crossmodal_moment_localization.config import BaseOptions
from baselines.crossmodal_moment_localization.model_xml import XML
from baselines.crossmodal_moment_localization.start_end_dataset import \
    StartEndDataset, start_end_collate, StartEndEvalDataset, prepare_batch_inputs
from baselines.crossmodal_moment_localization.inference import eval_epoch, start_inference
from baselines.crossmodal_moment_localization.optimization import BertAdam
from utils.basic_utils import AverageMeter, get_logger
from utils.model_utils import count_parameters


def get_eval_data(opt, data_path, data_mode):
    """Build a :class:`StartEndEvalDataset` for evaluation.

    Args:
        opt: parsed options namespace (from ``BaseOptions``).
        data_path (str): path to the query/annotation file.
        data_mode (str): dataset mode, e.g. ``"context"`` or ``"query"``.

    Returns:
        StartEndEvalDataset: dataset configured from ``opt``.
    """
    dataset = StartEndEvalDataset(
        data_path=data_path,
        desc_bert_path_or_handler=opt.desc_bert_path,
        sub_bert_path_or_handler=opt.sub_bert_path if "sub" in opt.ctx_mode else None,
        max_desc_len=opt.max_desc_l,
        max_ctx_len=opt.max_ctx_l,
        corpus_path=opt.corpus_path,
        vid_feat_path_or_handler=opt.vid_feat_path if "video" in opt.ctx_mode else None,
        clip_length=opt.clip_length,
        ctx_mode=opt.ctx_mode,
        data_mode=data_mode,
        h5driver=opt.h5driver,
        data_ratio=opt.data_ratio,
        normalize_vfeat=not opt.no_norm_vfeat,
        normalize_tfeat=not opt.no_norm_tfeat)
    return dataset


def set_seed(seed, use_cuda=True):
    """Seed python, numpy and torch RNGs for reproducibility.

    Args:
        seed (int): seed value.
        use_cuda (bool): also seed all CUDA devices when True.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.manual_seed_all(seed)


def rm_key_from_odict(odict_obj, rm_suffix):
    """Return a copy of the OrderedDict without keys containing ``rm_suffix``."""
    return OrderedDict([(k, v) for k, v in odict_obj.items() if rm_suffix not in k])


def train(model, train_dataset, val_data, test_data, context_data, opt, logger):
    """Run the full training loop with periodic validation/test evaluation.

    Evaluates ``opt.eval_num_per_epoch`` times per epoch (plus at the end of
    each epoch) and saves a checkpoint + predictions whenever the anchor
    metric (val NDCG@20, IoU=0.5) improves.

    Args:
        model: XML model instance (moved to ``opt.device`` here).
        train_dataset: training dataset (wrapped in a DataLoader here).
        val_data / test_data: query-mode eval datasets.
        context_data: context-mode eval dataset shared by both evals.
        opt: parsed options namespace.
        logger: logging.Logger for progress/metrics.
    """
    if opt.device.type == "cuda":
        logger.info("CUDA enabled.")
        model.to(opt.device)
        if len(opt.device_ids) > 1:
            # FIX: logger.info("Use multi GPU", opt.device_ids) passed a
            # positional arg with no %-placeholder, which raises a
            # formatting error inside logging at emit time.
            logger.info("Use multi GPU %s", opt.device_ids)
            # NOTE(review): after DataParallel wrapping, direct attribute
            # access such as model.set_hard_negative / model.config below
            # would need model.module — confirm multi-GPU path is exercised.
            model = torch.nn.DataParallel(model, device_ids=opt.device_ids)  # use multi GPU

    train_loader = DataLoader(train_dataset,
                              collate_fn=start_end_collate,
                              batch_size=opt.bsz,
                              num_workers=opt.num_workers,
                              shuffle=True,
                              pin_memory=opt.pin_memory)

    # Prepare optimizer: no weight decay on biases and LayerNorm params.
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         "weight_decay": 0.01},
        {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0}
    ]
    num_train_optimization_steps = len(train_loader) * opt.n_epoch
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=opt.lr,
                         weight_decay=opt.wd,
                         warmup=opt.lr_warmup_proportion,
                         t_total=num_train_optimization_steps,
                         schedule="warmup_linear")

    thresholds = [0.3, 0.5, 0.7]
    topks = [10, 20, 40]
    best_val_ndcg = 0
    for epoch_i in range(0, opt.n_epoch):
        print(f"TRAIN EPOCH: {epoch_i}|{opt.n_epoch}")
        # FIX: guard against eval_num_per_epoch > len(train_loader), which
        # made eval_step == 0 and crashed `global_step % eval_step` below.
        eval_step = max(1, len(train_loader) // opt.eval_num_per_epoch)
        if opt.hard_negtiave_start_epoch != -1 and epoch_i >= opt.hard_negtiave_start_epoch:
            model.set_hard_negative(True, opt.hard_pool_size)
        if opt.train_span_start_epoch != -1 and epoch_i >= opt.train_span_start_epoch:
            model.set_train_st_ed(opt.lw_st_ed)

        num_training_examples = len(train_loader)
        for batch_idx, batch in tqdm(enumerate(train_loader),
                                     desc="Training Iteration",
                                     total=num_training_examples):
            global_step = epoch_i * num_training_examples + batch_idx + 1
            model.train(mode=True)
            model_inputs = prepare_batch_inputs(batch[1], opt.device, non_blocking=opt.pin_memory)
            loss, loss_dict = model(**model_inputs)
            optimizer.zero_grad()
            loss.backward()
            if opt.grad_clip != -1:
                nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()

            # FIX: `batch_idx == len(train_loader)` could never be true
            # (batch_idx is 0-based), so the end-of-epoch eval never ran.
            if global_step % eval_step == 0 or batch_idx + 1 == num_training_examples:
                model.eval()
                with torch.no_grad():
                    val_performance, val_predictions = eval_epoch(
                        model, val_data, context_data, logger, opt,
                        max_after_nms=40, iou_thds=thresholds, topks=topks)
                    test_performance, test_predictions = eval_epoch(
                        model, test_data, context_data, logger, opt,
                        max_after_nms=40, iou_thds=thresholds, topks=topks)
                logger.info(f"EPOCH: {epoch_i}")

                # performance dicts are nested as {topk: {iou_thd: ndcg}}.
                line1 = ""
                line2 = "VAL: "
                line3 = "TEST: "
                anchor_ndcg = val_performance[20][0.5]  # model-selection metric
                for K, vs in val_performance.items():
                    for T, v in vs.items():
                        line1 += f"NDCG@{K}, IoU={T}\t"
                        line2 += f" {v:.6f}"
                for K, vs in test_performance.items():
                    for T, v in vs.items():
                        line3 += f" {v:.6f}"
                logger.info(line1)
                logger.info(line2)
                logger.info(line3)

                if anchor_ndcg > best_val_ndcg:
                    print("~"*40)
                    save_json(val_predictions,
                              os.path.join(opt.results_dir, "best_val_predictions.json"))
                    save_json(test_predictions,
                              os.path.join(opt.results_dir, "best_test_predictions.json"))
                    best_val_ndcg = anchor_ndcg
                    logger.info("BEST " + line2)
                    logger.info("BEST " + line3)
                    checkpoint = {"model": model.state_dict(),
                                  "model_cfg": model.config,
                                  "epoch": epoch_i}
                    torch.save(checkpoint, opt.ckpt_filepath)
                    logger.info("save checkpoint: {}".format(opt.ckpt_filepath))
                    print("~"*40)
        logger.info("")


def main():
    """Parse options, build datasets/model, and launch training."""
    opt = BaseOptions().parse()
    set_seed(opt.seed)
    logger = get_logger(opt.results_dir, opt.model_name + "_" + opt.exp_id)

    train_dataset = StartEndDataset(
        dset_name=opt.dset_name,
        data_path=opt.train_path,
        desc_bert_path_or_handler=opt.desc_bert_path,
        sub_bert_path_or_handler=opt.sub_bert_path,
        max_desc_len=opt.max_desc_l,
        max_ctx_len=opt.max_ctx_l,
        vid_feat_path_or_handler=opt.vid_feat_path,
        clip_length=opt.clip_length,
        ctx_mode=opt.ctx_mode,
        h5driver=opt.h5driver,
        data_ratio=opt.data_ratio,
        normalize_vfeat=not opt.no_norm_vfeat,
        normalize_tfeat=not opt.no_norm_tfeat,
    )
    context_data = get_eval_data(opt, opt.val_path, data_mode="context")
    val_data = get_eval_data(opt, opt.val_path, data_mode="query")
    test_data = get_eval_data(opt, opt.test_path, data_mode="query")

    model_config = EDict(
        merge_two_stream=not opt.no_merge_two_stream,  # merge video and subtitles
        cross_att=not opt.no_cross_att,  # use cross-attention when encoding video and subtitles
        span_predictor_type=opt.span_predictor_type,  # span_predictor_type
        encoder_type=opt.encoder_type,  # gru, lstm, transformer
        add_pe_rnn=opt.add_pe_rnn,  # add pe for RNNs
        pe_type=opt.pe_type,  #
        visual_input_size=opt.vid_feat_size,
        sub_input_size=opt.sub_feat_size,  # for both desc and subtitles
        query_input_size=opt.q_feat_size,  # for both desc and subtitles
        hidden_size=opt.hidden_size,  #
        stack_conv_predictor_conv_kernel_sizes=opt.stack_conv_predictor_conv_kernel_sizes,  #
        conv_kernel_size=opt.conv_kernel_size,
        conv_stride=opt.conv_stride,
        max_ctx_l=opt.max_ctx_l,
        max_desc_l=opt.max_desc_l,
        input_drop=opt.input_drop,
        cross_att_drop=opt.cross_att_drop,
        drop=opt.drop,
        n_heads=opt.n_heads,  # self-att heads
        initializer_range=opt.initializer_range,  # for linear layer
        ctx_mode=opt.ctx_mode,  # video, sub or video_sub
        margin=opt.margin,  # margin for ranking loss
        ranking_loss_type=opt.ranking_loss_type,  # loss type, 'hinge' or 'lse'
        lw_neg_q=opt.lw_neg_q,  # loss weight for neg. query and pos. context
        lw_neg_ctx=opt.lw_neg_ctx,  # loss weight for pos. query and neg. context
        lw_st_ed=0,  # will be assigned dynamically at training time
        use_hard_negative=False,  # reset at each epoch
        hard_pool_size=opt.hard_pool_size,
        use_self_attention=not opt.no_self_att,  # whether to use self attention
        no_modular=opt.no_modular
    )
    logger.info("model_config {}".format(model_config))
    model = XML(model_config)
    count_parameters(model)

    logger.info("Start Training...")
    train(model, train_dataset, val_data, test_data, context_data, opt, logger)


if __name__ == '__main__':
    main()