Liangrj5
		
	commited on
		
		
					Commit 
							
							·
						
						ebf5d87
	
1
								Parent(s):
							
							6dd9459
								
init
Browse filesThis view is limited to 50 files because it contains too many changes.  
							See raw diff
- baselines/__init__.py +0 -0
- baselines/__pycache__/__init__.cpython-311.pyc +0 -0
- baselines/clip_alignment_with_language/README.md +25 -0
- baselines/clip_alignment_with_language/__init__.py +0 -0
- baselines/clip_alignment_with_language/__pycache__/__init__.cpython-311.pyc +0 -0
- baselines/clip_alignment_with_language/__pycache__/config.cpython-311.pyc +0 -0
- baselines/clip_alignment_with_language/__pycache__/inference.cpython-311.pyc +0 -0
- baselines/clip_alignment_with_language/__pycache__/model.cpython-311.pyc +0 -0
- baselines/clip_alignment_with_language/__pycache__/proposal_retrieval_dataset.cpython-311.pyc +0 -0
- baselines/clip_alignment_with_language/config.py +207 -0
- baselines/clip_alignment_with_language/inference.py +672 -0
- baselines/clip_alignment_with_language/local_utils/__init__.py +0 -0
- baselines/clip_alignment_with_language/local_utils/__pycache__/__init__.cpython-311.pyc +0 -0
- baselines/clip_alignment_with_language/local_utils/__pycache__/compute_proposal_upper_bound.cpython-311.pyc +0 -0
- baselines/clip_alignment_with_language/local_utils/__pycache__/proposal.cpython-311.pyc +0 -0
- baselines/clip_alignment_with_language/local_utils/compute_proposal_upper_bound.py +117 -0
- baselines/clip_alignment_with_language/local_utils/proposal.py +181 -0
- baselines/clip_alignment_with_language/local_utils/tvr_proposal_test_log.txt +61 -0
- baselines/clip_alignment_with_language/mix_model_prediction.py +86 -0
- baselines/clip_alignment_with_language/model.py +299 -0
- baselines/clip_alignment_with_language/proposal_retrieval_dataset.py +587 -0
- baselines/clip_alignment_with_language/scripts/compute_upper_bound.sh +23 -0
- baselines/clip_alignment_with_language/scripts/inference.sh +17 -0
- baselines/clip_alignment_with_language/scripts/inference_mix.sh +27 -0
- baselines/clip_alignment_with_language/scripts/inference_with_external.sh +54 -0
- baselines/clip_alignment_with_language/scripts/re_train_cal.sh +21 -0
- baselines/clip_alignment_with_language/scripts/re_train_mcn.sh +21 -0
- baselines/clip_alignment_with_language/scripts/train.sh +80 -0
- baselines/clip_alignment_with_language/train.py +310 -0
- baselines/crossmodal_moment_localization/README.md +2 -0
- baselines/crossmodal_moment_localization/__init__.py +0 -0
- baselines/crossmodal_moment_localization/__pycache__/__init__.cpython-311.pyc +0 -0
- baselines/crossmodal_moment_localization/__pycache__/config.cpython-311.pyc +0 -0
- baselines/crossmodal_moment_localization/__pycache__/inference.cpython-311.pyc +0 -0
- baselines/crossmodal_moment_localization/__pycache__/model_components.cpython-311.pyc +0 -0
- baselines/crossmodal_moment_localization/__pycache__/model_xml.cpython-311.pyc +0 -0
- baselines/crossmodal_moment_localization/__pycache__/ndcg_iou_topk.cpython-311.pyc +0 -0
- baselines/crossmodal_moment_localization/__pycache__/optimization.cpython-311.pyc +0 -0
- baselines/crossmodal_moment_localization/__pycache__/start_end_dataset.cpython-311.pyc +0 -0
- baselines/crossmodal_moment_localization/config.py +276 -0
- baselines/crossmodal_moment_localization/inference.py +414 -0
- baselines/crossmodal_moment_localization/model_components.py +317 -0
- baselines/crossmodal_moment_localization/model_xml.py +642 -0
- baselines/crossmodal_moment_localization/ndcg_iou_topk.py +68 -0
- baselines/crossmodal_moment_localization/optimization.py +338 -0
- baselines/crossmodal_moment_localization/scripts/eval.sh +14 -0
- baselines/crossmodal_moment_localization/scripts/inference.sh +18 -0
- baselines/crossmodal_moment_localization/scripts/inference_with_external.sh +40 -0
- baselines/crossmodal_moment_localization/scripts/train.sh +70 -0
- baselines/crossmodal_moment_localization/start_end_dataset.py +393 -0
    	
        baselines/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        baselines/__pycache__/__init__.cpython-311.pyc
    ADDED
    
    | Binary file (176 Bytes). View file | 
|  | 
    	
        baselines/clip_alignment_with_language/README.md
    ADDED
    
    | @@ -0,0 +1,25 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Clip Alignment With Language
         | 
| 2 | 
            +
            This folder contains the CAL model described in the paper
         | 
| 3 | 
            +
            ```
         | 
| 4 | 
            +
            @article{Escorcia2019TemporalLO,
         | 
| 5 | 
            +
              title={Temporal Localization of Moments in Video Collections with Natural Language},
         | 
| 6 | 
            +
              author={Victor Escorcia and Mattia Soldan and Josef Sivic and Bernard Ghanem and Bryan Russell},
         | 
| 7 | 
            +
              journal={ArXiv},
         | 
| 8 | 
            +
              year={2019},
         | 
| 9 | 
            +
              volume={abs/1907.12763}
         | 
| 10 | 
            +
            }
         | 
| 11 | 
            +
            ```
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            It also resembles the MCN model in
         | 
| 14 | 
            +
            ```
         | 
| 15 | 
            +
            @article{Hendricks2017LocalizingMI,
         | 
| 16 | 
            +
              title={Localizing Moments in Video with Natural Language},
         | 
| 17 | 
            +
              author={Lisa Anne Hendricks and Oliver Wang and Eli Shechtman and Josef Sivic and Trevor Darrell and Bryan C. Russell},
         | 
| 18 | 
            +
              journal={2017 IEEE International Conference on Computer Vision (ICCV)},
         | 
| 19 | 
            +
              year={2017},
         | 
| 20 | 
            +
              pages={5804-5813}
         | 
| 21 | 
            +
            }
         | 
| 22 | 
            +
            ```
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            Disclaimer: This code is implemented by [Jie Lei](http://www.cs.unc.edu/~jielei/) for the TVR dataset, 
         | 
| 25 | 
            +
            it does not guarantee the reproducibility of the original authors' results.
         | 
    	
        baselines/clip_alignment_with_language/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        baselines/clip_alignment_with_language/__pycache__/__init__.cpython-311.pyc
    ADDED
    
    | Binary file (205 Bytes). View file | 
|  | 
    	
        baselines/clip_alignment_with_language/__pycache__/config.cpython-311.pyc
    ADDED
    
    | Binary file (17.8 kB). View file | 
|  | 
    	
        baselines/clip_alignment_with_language/__pycache__/inference.cpython-311.pyc
    ADDED
    
    | Binary file (43 kB). View file | 
|  | 
    	
        baselines/clip_alignment_with_language/__pycache__/model.cpython-311.pyc
    ADDED
    
    | Binary file (15.8 kB). View file | 
|  | 
    	
        baselines/clip_alignment_with_language/__pycache__/proposal_retrieval_dataset.cpython-311.pyc
    ADDED
    
    | Binary file (37 kB). View file | 
|  | 
    	
        baselines/clip_alignment_with_language/config.py
    ADDED
    
    | @@ -0,0 +1,207 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import time
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import argparse
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from utils.basic_utils import mkdirp, load_json, save_json, make_zipfile
         | 
| 7 | 
            +
            from baselines.clip_alignment_with_language.local_utils.proposal import ProposalConfigs
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            class BaseOptions(object):
         | 
| 11 | 
            +
                saved_option_filename = "opt.json"
         | 
| 12 | 
            +
                ckpt_filename = "model.ckpt"
         | 
| 13 | 
            +
                tensorboard_log_dir = "tensorboard_log"
         | 
| 14 | 
            +
                train_log_filename = "train.log.txt"
         | 
| 15 | 
            +
                eval_log_filename = "eval.log.txt"
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                def __init__(self):
         | 
| 18 | 
            +
                    self.parser = argparse.ArgumentParser()
         | 
| 19 | 
            +
                    self.initialized = False
         | 
| 20 | 
            +
                    self.opt = None
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                def initialize(self):
         | 
| 23 | 
            +
                    self.initialized = True
         | 
| 24 | 
            +
                    self.parser.add_argument("--dset_name", type=str, choices=["tvr"])
         | 
| 25 | 
            +
                    self.parser.add_argument("--eval_split_name", type=str, default="val",
         | 
| 26 | 
            +
                                             help="should match keys in corpus_path, must set for VCMR")
         | 
| 27 | 
            +
                    self.parser.add_argument("--debug", action="store_true",
         | 
| 28 | 
            +
                                             help="debug (fast) mode, break all loops, do not load all data into memory.")
         | 
| 29 | 
            +
                    self.parser.add_argument("--data_ratio", type=float, default=1.0,
         | 
| 30 | 
            +
                                             help="how many training and eval data to use. 1.0: use all, 0.1: use 10%."
         | 
| 31 | 
            +
                                                  "Use small portion for debug purposes. Note this is different from --debug, "
         | 
| 32 | 
            +
                                                  "which works by breaking the loops, typically they are not used together.")
         | 
| 33 | 
            +
                    self.parser.add_argument("--results_root", type=str, default="results")
         | 
| 34 | 
            +
                    self.parser.add_argument("--exp_id", type=str, default="res", help="id of the current run")
         | 
| 35 | 
            +
                    self.parser.add_argument("--seed", type=int, default=2018, help="random seed")
         | 
| 36 | 
            +
                    self.parser.add_argument("--device", type=int, default=0, help="0 cuda, -1 cpu")
         | 
| 37 | 
            +
                    self.parser.add_argument("--device_ids", type=int, nargs="+", default=[0], help="GPU ids to run the job")
         | 
| 38 | 
            +
                    self.parser.add_argument("--num_workers", type=int, default=8,
         | 
| 39 | 
            +
                                             help="num subprocesses used to load the data, 0: use main process")
         | 
| 40 | 
            +
                    self.parser.add_argument("--no_core_driver", action="store_true",
         | 
| 41 | 
            +
                                             help="hdf5 driver, default use `core` (load into RAM), if specified, use `None`")
         | 
| 42 | 
            +
                    self.parser.add_argument("--no_pin_memory", action="store_true",
         | 
| 43 | 
            +
                                             help="Don't use pin_memory=True for dataloader. "
         | 
| 44 | 
            +
                                                  "ref: https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/4")
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                    # training config
         | 
| 47 | 
            +
                    self.parser.add_argument("--lr", type=float, default=0.05, help="learning rate")
         | 
| 48 | 
            +
                    self.parser.add_argument("--wd", type=float, default=0, help="weight decay")
         | 
| 49 | 
            +
                    self.parser.add_argument("--momentum", type=float, default=0.95, help="momentum for SGD")
         | 
| 50 | 
            +
                    self.parser.add_argument("--n_epoch", type=int, default=108, help="number of epochs to run")
         | 
| 51 | 
            +
                    self.parser.add_argument("--max_es_cnt", type=int, default=108, help="number of epochs to early stop")
         | 
| 52 | 
            +
                    self.parser.add_argument("--bsz", type=int, default=128, help="mini-batch size")
         | 
| 53 | 
            +
                    self.parser.add_argument("--eval_query_bsz", type=int, default=1000,
         | 
| 54 | 
            +
                                             help="mini-batch size at inference, for query")
         | 
| 55 | 
            +
                    self.parser.add_argument("--eval_proposal_bsz", type=int, default=200,
         | 
| 56 | 
            +
                                             help="mini-batch size at inference, for proposals")
         | 
| 57 | 
            +
                    self.parser.add_argument("--eval_untrained", action="store_true", help="Evaluate on un-trained model")
         | 
| 58 | 
            +
                    self.parser.add_argument("--grad_clip", type=float, default=-1, help="perform gradient clip, -1: disable")
         | 
| 59 | 
            +
                    self.parser.add_argument("--margin", type=float, default=0.1, help="margin for hinge loss")
         | 
| 60 | 
            +
                    self.parser.add_argument("--inter_loss_weight", type=float, default=0.4, help="margin for ranking loss")
         | 
| 61 | 
            +
                    self.parser.add_argument("--loss_type", type=str, default="hinge", choices=["hinge", "lse"],
         | 
| 62 | 
            +
                                             help="att loss type, can be hinge loss or its smooth approximation LogSumExp")
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    # Model and Data config
         | 
| 65 | 
            +
                    self.parser.add_argument("--max_sub_l", type=int, default=50,
         | 
| 66 | 
            +
                                             help="max length of all sub sentence 97.71 under 50 for 3 sentences")
         | 
| 67 | 
            +
                    self.parser.add_argument("--max_desc_l", type=int, default=30, help="max length of descriptions")
         | 
| 68 | 
            +
                    self.parser.add_argument("--pos_iou_thd", type=float, default=0.7, help="moments with IoU >= as positive")
         | 
| 69 | 
            +
                    self.parser.add_argument("--neg_iou_thd", type=float, default=0.35, help="moments with IoU < as negative")
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                    self.parser.add_argument("--train_path", type=str, default=None)
         | 
| 72 | 
            +
                    self.parser.add_argument("--eval_path", type=str, default=None,
         | 
| 73 | 
            +
                                             help="Evaluating during training, for Dev set. If None, will only do training, "
         | 
| 74 | 
            +
                                                  "anet_cap and charades_sta has no dev set, so None")
         | 
| 75 | 
            +
                    self.parser.add_argument("--external_train_vr_res_path", type=str, default=None,
         | 
| 76 | 
            +
                                             help="if set, use external video retrieval results to guide "
         | 
| 77 | 
            +
                                                  "inter-nvideo negative sampling. ")
         | 
| 78 | 
            +
                    self.parser.add_argument("--init_ckpt_path", type=str, default=None,
         | 
| 79 | 
            +
                                             help="init model parameters from checkpoint. Use absolute path")
         | 
| 80 | 
            +
                    self.parser.add_argument("--external_inference_vr_res_path", type=str, default=None,
         | 
| 81 | 
            +
                                             help="if set, use external video retrieval results to guide evaluation. ")
         | 
| 82 | 
            +
                    self.parser.add_argument("--use_glove", action="store_true", help="Use GloVe instead of BERT features")
         | 
| 83 | 
            +
                    self.parser.add_argument("--word2idx_path", type=str,
         | 
| 84 | 
            +
                                             help="a dict, {word: word_idx, ...}, "
         | 
| 85 | 
            +
                                                  "special tokens are {<pad>: 0, <unk>: 1, <eos>: 2}")
         | 
| 86 | 
            +
                    self.parser.add_argument("--vocab_size", type=int, default=-1,
         | 
| 87 | 
            +
                                             help="Set automatically to len(word2idx)")
         | 
| 88 | 
            +
                    self.parser.add_argument("--glove_path", type=str,
         | 
| 89 | 
            +
                                             help="path to file containing the GloVe embeddings for words in word2idx")
         | 
| 90 | 
            +
                    self.parser.add_argument("--desc_bert_path", type=str, default=None)
         | 
| 91 | 
            +
                    self.parser.add_argument("--sub_bert_path", type=str, default=None)
         | 
| 92 | 
            +
                    self.parser.add_argument("--sub_feat_size", type=int, default=768, help="feature dim for sub feature")
         | 
| 93 | 
            +
                    self.parser.add_argument("--desc_feat_size", type=int, default=768)
         | 
| 94 | 
            +
                    self.parser.add_argument("--ctx_mode", type=str,
         | 
| 95 | 
            +
                                             choices=["video", "sub", "tef", "video_sub", "video_tef", "sub_tef", "video_sub_tef"],
         | 
| 96 | 
            +
                                             help="which context to use. a combination of [video, sub, tef]")
         | 
| 97 | 
            +
                    self.parser.add_argument("--corpus_path", type=str, default=None)
         | 
| 98 | 
            +
                    self.parser.add_argument("--vid_feat_path", type=str, default="")
         | 
| 99 | 
            +
                    self.parser.add_argument("--no_norm_vfeat", action="store_true",
         | 
| 100 | 
            +
                                             help="Do not do normalization on video feat, use it when using i3d_resnet concat feat")
         | 
| 101 | 
            +
                    self.parser.add_argument("--no_norm_tfeat", action="store_true", help="Do not do normalization on text feat")
         | 
| 102 | 
            +
                    self.parser.add_argument("--clip_length", type=float, default=None,
         | 
| 103 | 
            +
                                             help="each video will be uniformly segmented into small clips, "
         | 
| 104 | 
            +
                                                  "will automatically loaded from ProposalConfigs if None")
         | 
| 105 | 
            +
                    self.parser.add_argument("--vid_feat_size", type=int, help="feature dim for video feature")
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                    self.parser.add_argument("--model_type", default="cal", choices=["cal", "mcn"])
         | 
| 108 | 
            +
                    self.parser.add_argument("--embedding_size", type=int, default=768)
         | 
| 109 | 
            +
                    self.parser.add_argument("--lstm_hidden_size", type=int, default=256)
         | 
| 110 | 
            +
                    self.parser.add_argument("--visual_hidden_size", type=int, default=256)
         | 
| 111 | 
            +
                    self.parser.add_argument("--output_size", type=int, default=256)
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                    # post processing
         | 
| 114 | 
            +
                    self.parser.add_argument("--nms_thd", type=float, default=-1,
         | 
| 115 | 
            +
                                             help="additionally use non-maximum suppression "
         | 
| 116 | 
            +
                                                  "(or non-minimum suppression for distance)"
         | 
| 117 | 
            +
                                                  "to post-processing the predictions. "
         | 
| 118 | 
            +
                                                  "-1: do not use nms. 0.6 for charades_sta, 0.5 for anet_cap,")
         | 
| 119 | 
            +
                    self.parser.add_argument("--max_after_nms", type=int, default=100, help="Stores at max_after_nms for eval")
         | 
| 120 | 
            +
                    self.parser.add_argument("--max_before_nms", type=int, default=300, help="Max before nms")
         | 
| 121 | 
            +
                    self.parser.add_argument("--use_intermediate", action="store_true",
         | 
| 122 | 
            +
                                             help="Whether to use/save intermediate results to results directory."
         | 
| 123 | 
            +
                                                  "Might want use this if we are going to ")
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                def save_args(self, opt):
         | 
| 126 | 
            +
                    args = vars(opt)
         | 
| 127 | 
            +
                    # Save settings
         | 
| 128 | 
            +
                    if not isinstance(self, TestOptions):
         | 
| 129 | 
            +
                        option_file_path = os.path.join(opt.results_dir, self.saved_option_filename)  # not yaml file indeed
         | 
| 130 | 
            +
                        save_json(args, option_file_path, save_pretty=True)
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                def parse(self):
         | 
| 133 | 
            +
                    if not self.initialized:
         | 
| 134 | 
            +
                        self.initialize()
         | 
| 135 | 
            +
                    opt = self.parser.parse_args()
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                    if opt.debug:
         | 
| 138 | 
            +
                        opt.results_root = os.path.sep.join(opt.results_root.split(os.path.sep)[:-1] + ["debug_results", ])
         | 
| 139 | 
            +
                        opt.no_core_driver = True
         | 
| 140 | 
            +
                        opt.num_workers = 0
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                    if isinstance(self, TestOptions):
         | 
| 143 | 
            +
                        # modify model_dir to absolute path
         | 
| 144 | 
            +
                        opt.model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results", opt.model_dir)
         | 
| 145 | 
            +
                        saved_options = load_json(os.path.join(opt.model_dir, self.saved_option_filename))
         | 
| 146 | 
            +
                        for arg in saved_options:  # use saved options to overwrite all BaseOptions args.
         | 
| 147 | 
            +
                            if arg not in ["results_root", "num_workers", "nms_thd", "debug", "eval_split_name", "eval_path",
         | 
| 148 | 
            +
                                           "use_intermediate", "external_inference_vr_res_path"]:
         | 
| 149 | 
            +
                                setattr(opt, arg, saved_options[arg])
         | 
| 150 | 
            +
                        # opt.no_core_driver = True
         | 
| 151 | 
            +
                    else:
         | 
| 152 | 
            +
                        if opt.exp_id is None:
         | 
| 153 | 
            +
                            raise ValueError("--exp_id is required for at a training option!")
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                        if opt.clip_length is None:
         | 
| 156 | 
            +
                            opt.clip_length = ProposalConfigs[opt.dset_name]["clip_length"]
         | 
| 157 | 
            +
                        opt.results_dir = os.path.join(opt.results_root,
         | 
| 158 | 
            +
                                                       "-".join([opt.dset_name, opt.model_type, opt.ctx_mode, opt.exp_id,
         | 
| 159 | 
            +
                                                                 time.strftime("%Y_%m_%d_%H_%M_%S")]))
         | 
| 160 | 
            +
                        mkdirp(opt.results_dir)
         | 
| 161 | 
            +
                        # save a copy of current code
         | 
| 162 | 
            +
                        code_dir = os.path.dirname(os.path.realpath(__file__))
         | 
| 163 | 
            +
                        code_zip_filename = os.path.join(opt.results_dir, "code.zip")
         | 
| 164 | 
            +
                        make_zipfile(code_dir, code_zip_filename,
         | 
| 165 | 
            +
                                     enclosing_dir="code",
         | 
| 166 | 
            +
                                     exclude_dirs_substring="results",
         | 
| 167 | 
            +
                                     exclude_dirs=["results", "debug_results", "__pycache__"],
         | 
| 168 | 
            +
                                     exclude_extensions=[".pyc", ".ipynb", ".swap"])
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                    self.save_args(opt)
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                    if "sub" in opt.ctx_mode:
         | 
| 173 | 
            +
                        assert opt.dset_name == "tvr", "sub is only supported for tvr dataset"
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                    if "video" in opt.ctx_mode and opt.vid_feat_size > 3000:  # 3072, the normalized concatenation of resnet+i3d
         | 
| 176 | 
            +
                        assert opt.no_norm_vfeat
         | 
| 177 | 
            +
             | 
| 178 | 
            +
                    opt.ckpt_filepath = os.path.join(opt.results_dir, self.ckpt_filename)
         | 
| 179 | 
            +
                    opt.train_log_filepath = os.path.join(opt.results_dir, self.train_log_filename)
         | 
| 180 | 
            +
                    opt.eval_log_filepath = os.path.join(opt.results_dir, self.eval_log_filename)
         | 
| 181 | 
            +
                    opt.tensorboard_log_dir = os.path.join(opt.results_dir, self.tensorboard_log_dir)
         | 
| 182 | 
            +
                    opt.device = torch.device("cuda:%d" % opt.device_ids[0] if opt.device >= 0 else "cpu")
         | 
| 183 | 
            +
                    opt.h5driver = None if opt.no_core_driver else "core"
         | 
| 184 | 
            +
                    # num_workers > 1 will only work with "core" mode, i.e., memory-mapped hdf5
         | 
| 185 | 
            +
                    opt.pin_memory = not opt.no_pin_memory
         | 
| 186 | 
            +
                    opt.num_workers = 1 if opt.no_core_driver else opt.num_workers
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                    # Display settings
         | 
| 189 | 
            +
                    print("------------ Options -------------\n{}\n-------------------"
         | 
| 190 | 
            +
                          .format({str(k): str(v) for k, v in sorted(vars(opt).items())}))
         | 
| 191 | 
            +
                    self.opt = opt
         | 
| 192 | 
            +
                    return opt
         | 
| 193 | 
            +
             | 
| 194 | 
            +
             | 
| 195 | 
            +
            class TestOptions(BaseOptions):
         | 
| 196 | 
            +
                """add additional options for evaluating"""
         | 
| 197 | 
            +
                def initialize(self):
         | 
| 198 | 
            +
                    BaseOptions.initialize(self)
         | 
| 199 | 
            +
                    # also need to specify --eval_split_name
         | 
| 200 | 
            +
                    self.parser.add_argument("--eval_id", type=str, help="evaluation id")
         | 
| 201 | 
            +
                    self.parser.add_argument("--model_dir", type=str,
         | 
| 202 | 
            +
                                             help="dir contains the model file, will be converted to absolute path afterwards")
         | 
| 203 | 
            +
                    self.parser.add_argument("--tasks", type=str, nargs="+", choices=["VCMR", "SVMR", "VR"], default="SVMR",
         | 
| 204 | 
            +
                                             help="Which tasks to run."
         | 
| 205 | 
            +
                                                  "VCMR: Video Corpus Moment Retrieval;"
         | 
| 206 | 
            +
                                                  "SVMR: Single Video Moment Retrieval;"
         | 
| 207 | 
            +
                                                  "VR: regular Video Retrieval.")
         | 
    	
        baselines/clip_alignment_with_language/inference.py
    ADDED
    
    | @@ -0,0 +1,672 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import time
         | 
| 3 | 
            +
            import math
         | 
| 4 | 
            +
            import pprint
         | 
| 5 | 
            +
            import numpy as np
         | 
| 6 | 
            +
            from tqdm import tqdm, trange
         | 
| 7 | 
            +
            from collections import defaultdict, OrderedDict
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import torch
         | 
| 10 | 
            +
            import torch.backends.cudnn as cudnn
         | 
| 11 | 
            +
            from torch.utils.data import DataLoader
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from baselines.clip_alignment_with_language.config import TestOptions
         | 
| 14 | 
            +
            from baselines.clip_alignment_with_language.model import CALWithSub
         | 
| 15 | 
            +
            from baselines.clip_alignment_with_language.proposal_retrieval_dataset import \
         | 
| 16 | 
            +
                proposal_retrieval_collate, ProposalRetrievalEvalDataset, prepare_batch_inputs
         | 
| 17 | 
            +
            from utils.basic_utils import save_jsonl, save_json, load_json
         | 
| 18 | 
            +
            from utils.temporal_nms import temporal_non_maximum_suppression
         | 
| 19 | 
            +
            from utils.tensor_utils import pad_sequences_1d
         | 
| 20 | 
            +
            from standalone_eval.eval import eval_retrieval
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            import logging
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 25 | 
            +
            logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
         | 
| 26 | 
            +
                                datefmt="%Y-%m-%d %H:%M:%S",
         | 
| 27 | 
            +
                                level=logging.INFO)
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            def combine_single_video_proposal_embeddings(proposals_embedding_list, proposals_mask_list):
         | 
| 31 | 
            +
                """
         | 
| 32 | 
            +
                Args:
         | 
| 33 | 
            +
                    proposals_embedding_list: list(torch.Tensor), bsz * (N_prop, N_clips, D_o)
         | 
| 34 | 
            +
                    proposals_mask_list: list(torch.Tensor), bsz * (N_prop, N_clips)
         | 
| 35 | 
            +
                """
         | 
| 36 | 
            +
                if len(proposals_embedding_list) == 1:
         | 
| 37 | 
            +
                    return proposals_embedding_list[0], proposals_mask_list[0]
         | 
| 38 | 
            +
                else:  # > 1
         | 
| 39 | 
            +
                    max_n_clips = max([e.shape[1] for e in proposals_embedding_list])
         | 
| 40 | 
            +
                    n_proposals = sum([len(e) for e in proposals_embedding_list])
         | 
| 41 | 
            +
                    d = proposals_embedding_list[0].shape[2]
         | 
| 42 | 
            +
                    proposals_embedding = proposals_embedding_list[0].new_zeros((n_proposals, max_n_clips, d))
         | 
| 43 | 
            +
                    proposals_mask = proposals_mask_list[0].new_zeros((n_proposals, max_n_clips))
         | 
| 44 | 
            +
                    mask_lengths = [0, ] + [len(m) for m in proposals_mask_list]
         | 
| 45 | 
            +
                    mask_cumsum_lengths = np.cumsum(mask_lengths)
         | 
| 46 | 
            +
                    for idx, (e, m) in enumerate(zip(proposals_embedding_list, proposals_mask_list)):
         | 
| 47 | 
            +
                        proposals_embedding[mask_cumsum_lengths[idx]:mask_cumsum_lengths[idx + 1], :e.shape[1]] = e
         | 
| 48 | 
            +
                        proposals_mask[mask_cumsum_lengths[idx]:mask_cumsum_lengths[idx + 1], :m.shape[1]] = m
         | 
| 49 | 
            +
                    return proposals_embedding, proposals_mask
         | 
| 50 | 
            +
             | 
| 51 | 
            +
             | 
| 52 | 
            +
            def compute_query_embeddings(model, eval_dataset, opt, load_gt_vid_name):
         | 
| 53 | 
            +
                """Use val set to do evaluation, remember to run with torch.no_grad().
         | 
| 54 | 
            +
                estimated size 20,000 (query) * 100 (hsz) * 4 / (1024**2) = 7.63 MB
         | 
| 55 | 
            +
                """
         | 
| 56 | 
            +
                model.eval()
         | 
| 57 | 
            +
                eval_dataset.set_data_mode("query")
         | 
| 58 | 
            +
                eval_dataset.load_gt_vid_name_for_query(load_gt_vid_name)
         | 
| 59 | 
            +
                query_eval_loader = DataLoader(eval_dataset,
         | 
| 60 | 
            +
                                               collate_fn=proposal_retrieval_collate,
         | 
| 61 | 
            +
                                               batch_size=opt.eval_query_bsz,
         | 
| 62 | 
            +
                                               num_workers=opt.num_workers,
         | 
| 63 | 
            +
                                               shuffle=False,
         | 
| 64 | 
            +
                                               pin_memory=opt.pin_memory)
         | 
| 65 | 
            +
                global_meta_list = []  # list(dicts)
         | 
| 66 | 
            +
                # n_query = min(len(eval_dataset), opt.eval_query_bsz) if opt.debug else len(eval_dataset)
         | 
| 67 | 
            +
                n_query = len(eval_dataset)
         | 
| 68 | 
            +
                global_query_embedding = torch.empty((n_query,
         | 
| 69 | 
            +
                                                      model.config.output_size),
         | 
| 70 | 
            +
                                                     dtype=torch.float32, device=opt.device)  # (N_q, D_o)
         | 
| 71 | 
            +
                for idx, batch in tqdm(enumerate(query_eval_loader),
         | 
| 72 | 
            +
                                       desc="Computing q embedding",
         | 
| 73 | 
            +
                                       total=len(query_eval_loader)):
         | 
| 74 | 
            +
                    global_meta_list.extend(batch[0])
         | 
| 75 | 
            +
                    model_inputs = prepare_batch_inputs(batch[1], device=opt.device, non_blocking=opt.pin_memory)
         | 
| 76 | 
            +
                    global_query_embedding[idx * opt.eval_query_bsz: (idx + 1) * opt.eval_query_bsz] = \
         | 
| 77 | 
            +
                        model.query_encoder(**model_inputs)
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                    if opt.debug:
         | 
| 80 | 
            +
                        break
         | 
| 81 | 
            +
                return global_meta_list, global_query_embedding
         | 
| 82 | 
            +
             | 
| 83 | 
            +
             | 
| 84 | 
            +
            def compute_proposal_embeddings(model, eval_dataset, opt):
         | 
| 85 | 
            +
                """Use val set to do evaluation, remember to run with torch.no_grad().
         | 
| 86 | 
            +
                estimated 1000 (videos) * 300 (proposals) * 20 (clips) * 100 (hsz) * 4 / (1024 ** 3) = 2.24 GB
         | 
| 87 | 
            +
                """
         | 
| 88 | 
            +
                model.eval()
         | 
| 89 | 
            +
                eval_dataset.set_data_mode("context")
         | 
| 90 | 
            +
                global_meta_list = []  # list(dicts)
         | 
| 91 | 
            +
                global_proposal_video_embedding_list = []  # list(torch.tensor), N_videos * [N_prop, N_clips, D_o]
         | 
| 92 | 
            +
                global_proposal_sub_embedding_list = []  # list(torch.tensor), N_videos * [N_prop, N_clips, D_o]
         | 
| 93 | 
            +
                global_proposal_video_mask_list = []  # list(torch.tensor), N_videos * [N_prop, N_clips]
         | 
| 94 | 
            +
                global_proposal_sub_mask_list = []  # list(torch.tensor), N_videos * [N_prop, N_clips]
         | 
| 95 | 
            +
                for idx, single_video_info in tqdm(enumerate(eval_dataset),
         | 
| 96 | 
            +
                                                   desc="Computing prop embedding for videos",
         | 
| 97 | 
            +
                                                   total=len(eval_dataset)):
         | 
| 98 | 
            +
                    global_meta_list.append(single_video_info["meta"])
         | 
| 99 | 
            +
                    if model.use_video or model.tef_only:
         | 
| 100 | 
            +
                        proposals_features_list = single_video_info["model_inputs"]["video_moment_features_list"]
         | 
| 101 | 
            +
                        proposals_mask_list = single_video_info["model_inputs"]["video_moment_mask_list"]
         | 
| 102 | 
            +
                        proposals_mask_list = [e.to(opt.device, non_blocking=opt.pin_memory) for e in proposals_mask_list]
         | 
| 103 | 
            +
                        proposals_embedding_list = []  # (N_prop, D_o)
         | 
| 104 | 
            +
                        for feat in proposals_features_list:
         | 
| 105 | 
            +
                            proposals_embedding_list.append(
         | 
| 106 | 
            +
                                model.moment_encoder(feat.to(opt.device, non_blocking=opt.pin_memory), module_name="video"))
         | 
| 107 | 
            +
                        p, m = combine_single_video_proposal_embeddings(proposals_embedding_list, proposals_mask_list)
         | 
| 108 | 
            +
                        global_proposal_video_embedding_list.append(p)
         | 
| 109 | 
            +
                        global_proposal_video_mask_list.append(m)
         | 
| 110 | 
            +
                    else:
         | 
| 111 | 
            +
                        global_proposal_video_embedding_list.append(None)
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                    if model.use_sub:
         | 
| 114 | 
            +
                        proposals_features_list = single_video_info["model_inputs"]["sub_moment_features_list"]
         | 
| 115 | 
            +
                        proposals_mask_list = single_video_info["model_inputs"]["sub_moment_mask_list"]
         | 
| 116 | 
            +
                        proposals_mask_list = [e.to(opt.device, non_blocking=opt.pin_memory) for e in proposals_mask_list]
         | 
| 117 | 
            +
                        proposals_embedding_list = []  # (N_prop, D_o)
         | 
| 118 | 
            +
                        for feat in proposals_features_list:
         | 
| 119 | 
            +
                            proposals_embedding_list.append(
         | 
| 120 | 
            +
                                model.moment_encoder(feat.to(opt.device, non_blocking=opt.pin_memory), module_name="sub"))
         | 
| 121 | 
            +
                        p, m = combine_single_video_proposal_embeddings(proposals_embedding_list, proposals_mask_list)
         | 
| 122 | 
            +
                        global_proposal_sub_embedding_list.append(p)
         | 
| 123 | 
            +
                        global_proposal_sub_mask_list.append(m)
         | 
| 124 | 
            +
                    else:
         | 
| 125 | 
            +
                        global_proposal_sub_embedding_list.append(None)
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                    if opt.debug and idx == 100:
         | 
| 128 | 
            +
                        break
         | 
| 129 | 
            +
                global_proposal_mask_list = global_proposal_sub_mask_list if model.use_sub else global_proposal_video_mask_list
         | 
| 130 | 
            +
                return global_meta_list, global_proposal_video_embedding_list, \
         | 
| 131 | 
            +
                       global_proposal_sub_embedding_list, global_proposal_mask_list
         | 
| 132 | 
            +
             | 
| 133 | 
            +
             | 
| 134 | 
            +
            def compute_query_proposal_distance(model, eval_dataset, opt, tasks=("SVMR",)):
         | 
| 135 | 
            +
                """compute and save query and video proposal embeddings,
         | 
| 136 | 
            +
                tasks: SVMR (single video moment retrieval), VCMR (video corpus moment retrieval)
         | 
| 137 | 
            +
                """
         | 
| 138 | 
            +
                is_svmr = "SVMR" in tasks
         | 
| 139 | 
            +
                is_vcmr = "VCMR" in tasks
         | 
| 140 | 
            +
                query_meta_list, query_embed = compute_query_embeddings(model, eval_dataset, opt,
         | 
| 141 | 
            +
                                                                        load_gt_vid_name=is_svmr)
         | 
| 142 | 
            +
                video_meta_list, video_prop_embed_list, sub_prop_embed_list, prop_mask_list = \
         | 
| 143 | 
            +
                    compute_proposal_embeddings(model, eval_dataset, opt)
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                eval_res = dict(
         | 
| 146 | 
            +
                    query_meta=query_meta_list,  # N_q * dict()
         | 
| 147 | 
            +
                    video_meta=video_meta_list,  # N_videos * dict()
         | 
| 148 | 
            +
                    video2idx=eval_dataset.video2idx,  # dict {vid_name: index}
         | 
| 149 | 
            +
                    query_prop_dist_vcmr=[],  # N_videos * (N_q, N_prop), note N_prop is changing for each video.
         | 
| 150 | 
            +
                    query_prop_dist_svmr=[],  # N_q * (N_prop, ), each query has a GT video, no need to calc. for all.
         | 
| 151 | 
            +
                )
         | 
| 152 | 
            +
                if is_vcmr:
         | 
| 153 | 
            +
                    for v_prop_embed, s_prop_embed, prop_mask in tqdm(
         | 
| 154 | 
            +
                            zip(video_prop_embed_list, sub_prop_embed_list, prop_mask_list),
         | 
| 155 | 
            +
                            desc="Computing VCMR q to prop dist for videos",
         | 
| 156 | 
            +
                            total=len(video_prop_embed_list)):
         | 
| 157 | 
            +
                        query_prop_dist = model.compute_cdist_inference(
         | 
| 158 | 
            +
                            query_embed, v_prop_embed, s_prop_embed, prop_mask)  # (N_q, N_prop)
         | 
| 159 | 
            +
                        eval_res["query_prop_dist_vcmr"].append(query_prop_dist.cpu())
         | 
| 160 | 
            +
                        if opt.debug:
         | 
| 161 | 
            +
                            break
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                if is_svmr:
         | 
| 164 | 
            +
                    if opt.debug:
         | 
| 165 | 
            +
                        debug_query_meta = []
         | 
| 166 | 
            +
                    # this is different from video2idx
         | 
| 167 | 
            +
                    svmr_video2meta_idx = {e["vid_name"]: idx for idx, e in enumerate(video_meta_list)}
         | 
| 168 | 
            +
                    # logger.info("svmr_video2idx {}".format(list(svmr_video2idx.keys())[:3]))
         | 
| 169 | 
            +
                    for single_q_embed, single_q_meta in tqdm(zip(query_embed, query_meta_list),
         | 
| 170 | 
            +
                                                              desc="Computing SVMR q to prop dist for videos",
         | 
| 171 | 
            +
                                                              total=len(query_embed)):
         | 
| 172 | 
            +
                        # logger.info("single_q_meta[vid_name] {}".format(single_q_meta["vid_name"]))
         | 
| 173 | 
            +
                        if opt.debug:
         | 
| 174 | 
            +
                            if single_q_meta["vid_name"] not in svmr_video2meta_idx:
         | 
| 175 | 
            +
                                continue
         | 
| 176 | 
            +
                            debug_query_meta.append(single_q_meta)
         | 
| 177 | 
            +
                        q_gt_vid_meta_idx = svmr_video2meta_idx[single_q_meta["vid_name"]]
         | 
| 178 | 
            +
                        v_prop_embed = video_prop_embed_list[q_gt_vid_meta_idx]  # [N_prop, N_clips, D_o]
         | 
| 179 | 
            +
                        s_prop_embed = sub_prop_embed_list[q_gt_vid_meta_idx]  # [N_prop, N_clips, D_o]
         | 
| 180 | 
            +
                        prop_mask = prop_mask_list[q_gt_vid_meta_idx]  # [N_prop, N_clips]
         | 
| 181 | 
            +
                        query_prop_dist = model.compute_cdist_inference(
         | 
| 182 | 
            +
                            single_q_embed.unsqueeze(0), v_prop_embed, s_prop_embed, prop_mask)  # (1, N_prop)
         | 
| 183 | 
            +
                        eval_res["query_prop_dist_svmr"].append(query_prop_dist.squeeze(0).cpu().numpy())
         | 
| 184 | 
            +
                    if opt.debug:
         | 
| 185 | 
            +
                        eval_res["query_meta"] = debug_query_meta
         | 
| 186 | 
            +
                return eval_res
         | 
| 187 | 
            +
             | 
| 188 | 
            +
             | 
| 189 | 
            +
            def filter_vcmr_by_nms(all_video_predictions, nms_threshold=0.6,
         | 
| 190 | 
            +
                                   max_before_nms=1000, max_after_nms=100, score_col_idx=3):
         | 
| 191 | 
            +
                """ Apply non-maximum suppression for all the predictions for each video.
         | 
| 192 | 
            +
                1) group predictions by video index
         | 
| 193 | 
            +
                2) apply nms individually for each video index group
         | 
| 194 | 
            +
                3) combine and sort the predictions
         | 
| 195 | 
            +
                Args:
         | 
| 196 | 
            +
                    all_video_predictions: list(sublist),
         | 
| 197 | 
            +
                        Each sublist is [video_idx (int), st (float), ed(float), score (float)]
         | 
| 198 | 
            +
                        Note the scores are negative distances.
         | 
| 199 | 
            +
                    nms_threshold: float
         | 
| 200 | 
            +
                    max_before_nms: int
         | 
| 201 | 
            +
                    max_after_nms: int
         | 
| 202 | 
            +
                    score_col_idx: int
         | 
| 203 | 
            +
                Returns:
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                """
         | 
| 206 | 
            +
                predictions_neg_by_video_group = defaultdict(list)
         | 
| 207 | 
            +
                for pred in all_video_predictions[:max_before_nms]:
         | 
| 208 | 
            +
                    predictions_neg_by_video_group[pred[0]].append(pred[1:])  # [st (float), ed(float), score (float)]
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                predictions_by_video_group_neg_after_nms = dict()
         | 
| 211 | 
            +
                for video_idx, grouped_preds in predictions_neg_by_video_group.items():
         | 
| 212 | 
            +
                    predictions_by_video_group_neg_after_nms[video_idx] = \
         | 
| 213 | 
            +
                        temporal_non_maximum_suppression(grouped_preds, nms_threshold=nms_threshold)
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                predictions_after_nms = []
         | 
| 216 | 
            +
                for video_idx, grouped_preds in predictions_by_video_group_neg_after_nms.items():
         | 
| 217 | 
            +
                    for pred in grouped_preds:
         | 
| 218 | 
            +
                        pred = [video_idx] + pred  # [video_idx (int), st (float), ed(float), score (float)]
         | 
| 219 | 
            +
                        predictions_after_nms.append(pred)
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                # ranking happens across videos
         | 
| 222 | 
            +
                predictions_after_nms = sorted(predictions_after_nms,
         | 
| 223 | 
            +
                                               key=lambda x: x[score_col_idx],
         | 
| 224 | 
            +
                                               reverse=True)[:max_after_nms]  # descending order
         | 
| 225 | 
            +
                return predictions_after_nms
         | 
| 226 | 
            +
             | 
| 227 | 
            +
             | 
| 228 | 
            +
            def post_processing_vcmr_nms(vcmr_res, nms_thd=0.6, max_before_nms=1000, max_after_nms=100):
         | 
| 229 | 
            +
                """
         | 
| 230 | 
            +
                vcmr_res: list(dict), each dict is{
         | 
| 231 | 
            +
                    "desc": str,
         | 
| 232 | 
            +
                    "desc_id": int,
         | 
| 233 | 
            +
                    "predictions": list(sublist)  # each sublist is
         | 
| 234 | 
            +
                        [video_idx (int), st (float), ed(float), score (float)], video_idx could be different
         | 
| 235 | 
            +
                }
         | 
| 236 | 
            +
                """
         | 
| 237 | 
            +
                processed_vcmr_res = []
         | 
| 238 | 
            +
                for e in vcmr_res:
         | 
| 239 | 
            +
                    e["predictions"] = filter_vcmr_by_nms(e["predictions"],
         | 
| 240 | 
            +
                                                          nms_threshold=nms_thd,
         | 
| 241 | 
            +
                                                          max_before_nms=max_before_nms,
         | 
| 242 | 
            +
                                                          max_after_nms=max_after_nms)
         | 
| 243 | 
            +
                    processed_vcmr_res.append(e)
         | 
| 244 | 
            +
                return processed_vcmr_res
         | 
| 245 | 
            +
             | 
| 246 | 
            +
             | 
| 247 | 
            +
            def post_processing_svmr_nms(svmr_res, nms_thd=0.6, max_before_nms=1000, max_after_nms=100):
         | 
| 248 | 
            +
                """
         | 
| 249 | 
            +
                svmr_res: list(dict), each dict is
         | 
| 250 | 
            +
                    {"desc": str,
         | 
| 251 | 
            +
                     "desc_id": int,
         | 
| 252 | 
            +
                     "predictions": list(sublist)  # each sublist is
         | 
| 253 | 
            +
                        [video_idx (int), st (float), ed(float), score (float)], video_idx is the same.
         | 
| 254 | 
            +
                     }
         | 
| 255 | 
            +
                """
         | 
| 256 | 
            +
                processed_svmr_res = []
         | 
| 257 | 
            +
                for e in svmr_res:
         | 
| 258 | 
            +
                    # the predictions are sorted inside the nms func.
         | 
| 259 | 
            +
                    _predictions = [d[1:] for d in e["predictions"][:max_before_nms]]
         | 
| 260 | 
            +
                    _predictions = temporal_non_maximum_suppression(
         | 
| 261 | 
            +
                        _predictions, nms_threshold=nms_thd)[:max_after_nms]
         | 
| 262 | 
            +
                    _video_id = e["predictions"][0][0] # video_id is the same for all predictions
         | 
| 263 | 
            +
                    e["predictions"] = [[_video_id, ] + d for d in _predictions]
         | 
| 264 | 
            +
                    processed_svmr_res.append(e)
         | 
| 265 | 
            +
                return processed_svmr_res
         | 
| 266 | 
            +
             | 
| 267 | 
            +
             | 
| 268 | 
            +
            def generate_vcmr_predictions_from_res_with_external(eval_res, max_prop_per_query=300, query_bsz_in_sort=1000):
         | 
| 269 | 
            +
                """ This function is for Video Corpus Moment Retrieval (VCMR).
         | 
| 270 | 
            +
                Generate prediction file which could be evaluated using standalone_eval.eval.
         | 
| 271 | 
            +
                Args:
         | 
| 272 | 
            +
                    eval_res: dict(
         | 
| 273 | 
            +
                        query_meta=query_meta_list,  # N_q * dict(), each dict is {"desc_id": int, "desc": str}
         | 
| 274 | 
            +
                        video_meta=video_meta_list,  # N_videos * dict(), {"vid_name": str, "duration": float, "proposals": ndarray}
         | 
| 275 | 
            +
                        video2idx=eval_dataset.video2idx,  # dict {vid_name: index}
         | 
| 276 | 
            +
                        video_bsz_in_sort=[],  # N_videos * (N_q, N_prop)
         | 
| 277 | 
            +
                    )
         | 
| 278 | 
            +
                    max_prop_per_query: int or None. If None, generate ranking for all possible moments, else generate top {}.
         | 
| 279 | 
            +
                    query_bsz_in_sort: int, only sort a subset of queries at a time, it will be too large to sort all queries.
         | 
| 280 | 
            +
                return:
         | 
| 281 | 
            +
                    list(dicts): each dict is dict(desc=str, desc_id=int, predictions=list(sublist)),
         | 
| 282 | 
            +
                        each sublist is [vid_name (str), st (float), ed (float), score (float)], score is negative distance.
         | 
| 283 | 
            +
                """
         | 
| 284 | 
            +
                # video2idx
         | 
| 285 | 
            +
                video2idx = eval_res["video2idx"]
         | 
| 286 | 
            +
                video_meta = eval_res["video_meta"]
         | 
| 287 | 
            +
                query_meta = eval_res["query_meta"]
         | 
| 288 | 
            +
                video_idx2meta_idx = {video2idx[m["vid_name"]]: i for i, m in enumerate(video_meta)}
         | 
| 289 | 
            +
                external_query2video = eval_res["external_query2video"] if "external_query2video" in eval_res else None
         | 
| 290 | 
            +
                # 「query idx: [video meta idx]」
         | 
| 291 | 
            +
                external_query2video_meta_idx = {k: [video_idx2meta_idx[e] for e in v] for k, v in external_query2video.items()}
         | 
| 292 | 
            +
             | 
| 293 | 
            +
                external_ordered_video_meta_indices = torch.LongTensor(
         | 
| 294 | 
            +
                    [external_query2video_meta_idx[e["desc_id"]] for e in query_meta])  # (Nq, 5)
         | 
| 295 | 
            +
                top_n_retrieved = external_ordered_video_meta_indices.shape[1]
         | 
| 296 | 
            +
             | 
| 297 | 
            +
                # (N_videos, N_prop, N_q), (N_videos, N_prop)
         | 
| 298 | 
            +
                padded_dist, padded_mask = pad_sequences_1d([e.transpose(0, 1) for e in eval_res["query_prop_dist_vcmr"]],
         | 
| 299 | 
            +
                                                            dtype=eval_res["query_prop_dist_vcmr"][0].dtype,
         | 
| 300 | 
            +
                                                            device=eval_res["query_prop_dist_vcmr"][0].device)
         | 
| 301 | 
            +
                # putting 'NaN' into the invalid bits, torch.sort considers 'NaN' as larger than any number!!!
         | 
| 302 | 
            +
                padded_dist += (padded_mask.unsqueeze(2) == 0).float() * 1e10
         | 
| 303 | 
            +
                n_videos, n_prop, n_q = padded_dist.shape
         | 
| 304 | 
            +
                padded_dist = padded_dist.permute(2, 0, 1)  # (N_q, N_videos, N_prop)
         | 
| 305 | 
            +
             | 
| 306 | 
            +
                # get only top retrieved, N_videos now decreased to top_n_retrieved
         | 
| 307 | 
            +
                row_indices = torch.arange(n_q, device=padded_dist.device)
         | 
| 308 | 
            +
                padded_dist = torch.stack([
         | 
| 309 | 
            +
                    padded_dist[row_indices, external_ordered_video_meta_indices[:, col_idx]]
         | 
| 310 | 
            +
                    for col_idx in range(top_n_retrieved)], dim=1)  # (N_q, 5, N_prop)
         | 
| 311 | 
            +
                n_videos = top_n_retrieved
         | 
| 312 | 
            +
             | 
| 313 | 
            +
                padded_dist = padded_dist.view(n_q, -1).contiguous()  # (N_q, N_video*N_prop)
         | 
| 314 | 
            +
                print("n_videos, n_prop, n_q {}".format((n_videos, n_prop, n_q)))
         | 
| 315 | 
            +
                print("padded_dist, {}".format(padded_dist.shape))
         | 
| 316 | 
            +
             | 
| 317 | 
            +
                sorted_distances, sorted_indices = torch.topk(padded_dist.to(torch.device("cuda:0"), non_blocking=True),
         | 
| 318 | 
            +
                                                              k=min(max_prop_per_query, n_videos * n_prop),
         | 
| 319 | 
            +
                                                              dim=1, largest=False, sorted=True)  # (N_q, max_prop_per_query) * 2
         | 
| 320 | 
            +
                print("orted_distances {}, sorted_indices {}".format(sorted_distances.shape, sorted_indices.shape))
         | 
| 321 | 
            +
                sorted_distances = - sorted_distances.cpu().numpy()
         | 
| 322 | 
            +
             | 
| 323 | 
            +
                # (N_q, max_prop_per_query) * 2, prop_indices: inside video indices.
         | 
| 324 | 
            +
                video_meta_indices_retrieved = torch.floor(sorted_indices.float() / n_prop).long().cpu().numpy()
         | 
| 325 | 
            +
                # map back to original video idx (not video meta idx, but real video idx)
         | 
| 326 | 
            +
                video_indices = np.array([[external_query2video[query_meta[i]["desc_id"]][j] for j in r]
         | 
| 327 | 
            +
                                          for i, r in enumerate(video_meta_indices_retrieved)])  # (N_q, max_prop_per_query)
         | 
| 328 | 
            +
                prop_indices = torch.remainder(sorted_indices, n_prop).cpu().numpy()  # (N_q, max_prop_per_query)
         | 
| 329 | 
            +
                print("video_indices {}, prop_indices {}".format(video_indices.shape, prop_indices.shape))
         | 
| 330 | 
            +
             | 
| 331 | 
            +
                vr_res = []
         | 
| 332 | 
            +
                for i in trange(n_q, desc="[VR] Loop over queries to generate predictions"):
         | 
| 333 | 
            +
                    row = video_indices[i]
         | 
| 334 | 
            +
                    score_row = - sorted_distances[i]
         | 
| 335 | 
            +
                    cur_vr_redictions = []
         | 
| 336 | 
            +
                    for j, video_idx in enumerate(row):
         | 
| 337 | 
            +
                        cur_vr_redictions.append([int(video_idx), 0, 0, float(score_row[j])])
         | 
| 338 | 
            +
                    cur_query_pred = dict(
         | 
| 339 | 
            +
                        desc_id=query_meta[i]["desc_id"],
         | 
| 340 | 
            +
                        desc=query_meta[i]["desc"],
         | 
| 341 | 
            +
                        predictions=cur_vr_redictions
         | 
| 342 | 
            +
                    )
         | 
| 343 | 
            +
                    vr_res.append(cur_query_pred)
         | 
| 344 | 
            +
             | 
| 345 | 
            +
                vcmr_res = []
         | 
| 346 | 
            +
                logger.debug("sorted_indices {}".format(sorted_indices.shape))
         | 
| 347 | 
            +
                logger.debug("sorted_distances {}".format(sorted_distances.shape))
         | 
| 348 | 
            +
                out_bounds_cnt = 0
         | 
| 349 | 
            +
                for idx, (v_row_indices, p_row_indices) in tqdm(enumerate(zip(video_indices, prop_indices)),
         | 
| 350 | 
            +
                                                                desc="[VCMR] Loop over queries to generate predictions",
         | 
| 351 | 
            +
                                                                total=n_q):  # query
         | 
| 352 | 
            +
                    sorted_distances_row = - sorted_distances[idx]  # converted to negative distance
         | 
| 353 | 
            +
                    # [video_idx(int), st(float), ed(float), score(float)]
         | 
| 354 | 
            +
                    cur_ranked_predictions = []
         | 
| 355 | 
            +
                    for col_idx, (v_col_idx, p_col_idx) in enumerate(zip(v_row_indices, p_row_indices)):
         | 
| 356 | 
            +
                        cur_proposals = eval_res["video_meta"][video_idx2meta_idx[v_col_idx]]["proposals"]
         | 
| 357 | 
            +
                        cur_pred = []
         | 
| 358 | 
            +
                        cur_pred += [int(v_col_idx), ]
         | 
| 359 | 
            +
                        # what is wrong with the indexing below??? (out of bounds), but results seems fine???
         | 
| 360 | 
            +
                        # Not a bug. Since there might be less than max_before_nms proposals from the top retrieved videos
         | 
| 361 | 
            +
                        if p_col_idx >= len(cur_proposals):
         | 
| 362 | 
            +
                            out_bounds_cnt += 1
         | 
| 363 | 
            +
                            p_col_idx = len(cur_proposals)-1
         | 
| 364 | 
            +
                        cur_pred += cur_proposals[p_col_idx].tolist()
         | 
| 365 | 
            +
                        cur_pred += [float(sorted_distances_row[col_idx])]
         | 
| 366 | 
            +
                        cur_ranked_predictions.append(cur_pred)
         | 
| 367 | 
            +
                    cur_query_pred = dict(
         | 
| 368 | 
            +
                        desc_id=eval_res["query_meta"][idx]["desc_id"],
         | 
| 369 | 
            +
                        desc=eval_res["query_meta"][idx]["desc"],
         | 
| 370 | 
            +
                        predictions=cur_ranked_predictions
         | 
| 371 | 
            +
                    )
         | 
| 372 | 
            +
                    vcmr_res.append(cur_query_pred)
         | 
| 373 | 
            +
                logger.info("[DEBUG] out_bounds_cnt {}".format(out_bounds_cnt))
         | 
| 374 | 
            +
                return vcmr_res, vr_res
         | 
| 375 | 
            +
             | 
| 376 | 
            +
             | 
| 377 | 
            +
            def generate_vcmr_predictions_from_res(eval_res, max_prop_per_query=300, query_bsz_in_sort=1000):
         | 
| 378 | 
            +
                """ This function is for Video Corpus Moment Retrieval (VCMR).
         | 
| 379 | 
            +
                Generate prediction file which could be evaluated using standalone_eval.eval.
         | 
| 380 | 
            +
                Args:
         | 
| 381 | 
            +
                    eval_res: dict(
         | 
| 382 | 
            +
                        query_meta=query_meta_list,  # N_q * dict(), each dict is {"desc_id": int, "desc": str}
         | 
| 383 | 
            +
                        video_meta=video_meta_list,  # N_videos * dict(), {"vid_name": str, "duration": float, "proposals": ndarray}
         | 
| 384 | 
            +
                        video2idx=eval_dataset.video2idx,  # dict {vid_name: index}
         | 
| 385 | 
            +
                        video_bsz_in_sort=[],  # N_videos * (N_q, N_prop)
         | 
| 386 | 
            +
                    )
         | 
| 387 | 
            +
                    max_prop_per_query: int or None. If None, generate ranking for all possible moments, else generate top {}.
         | 
| 388 | 
            +
                    query_bsz_in_sort: int, only sort a subset of queries at a time, it will be too large to sort all queries.
         | 
| 389 | 
            +
                return:
         | 
| 390 | 
            +
                    list(dicts): each dict is dict(desc=str, desc_id=int, predictions=list(sublist)),
         | 
| 391 | 
            +
                        each sublist is [vid_name (str), st (float), ed (float), score (float)], score is negative distance.
         | 
| 392 | 
            +
                """
         | 
| 393 | 
            +
                # video2idx
         | 
| 394 | 
            +
                video2idx = eval_res["video2idx"]
         | 
| 395 | 
            +
             | 
| 396 | 
            +
                # (N_videos, N_prop, N_q), (N_videos, N_prop)
         | 
| 397 | 
            +
                padded_dist, padded_mask = pad_sequences_1d([e.transpose(0, 1) for e in eval_res["query_prop_dist_vcmr"]],
         | 
| 398 | 
            +
                                                            dtype=eval_res["query_prop_dist_vcmr"][0].dtype,
         | 
| 399 | 
            +
                                                            device=eval_res["query_prop_dist_vcmr"][0].device)
         | 
| 400 | 
            +
                # putting 'NaN' into the invalid bits, torch.sort considers 'NaN' as larger than any number!!!
         | 
| 401 | 
            +
                padded_dist += (padded_mask.unsqueeze(2) == 0).float() * 1e10
         | 
| 402 | 
            +
                n_videos, n_prop, n_q = padded_dist.shape
         | 
| 403 | 
            +
                print("n_videos, n_prop, n_q {}".format((n_videos, n_prop, n_q)))
         | 
| 404 | 
            +
                padded_dist = padded_dist.view(n_videos * n_prop, n_q).transpose(0, 1).contiguous()  # (N_q, N_video*N_prop)
         | 
| 405 | 
            +
                print("padded_dist, {}".format(padded_dist.shape))
         | 
| 406 | 
            +
             | 
| 407 | 
            +
                sorted_distances, sorted_indices = torch.topk(padded_dist.to(torch.device("cuda:0"), non_blocking=True),
         | 
| 408 | 
            +
                                                              k=min(max_prop_per_query, n_videos * n_prop),
         | 
| 409 | 
            +
                                                              dim=1, largest=False, sorted=True)  # (N_q, max_prop_per_query) * 2
         | 
| 410 | 
            +
                sorted_distances = - sorted_distances.cpu().numpy()
         | 
| 411 | 
            +
             | 
| 412 | 
            +
                # (N_q, max_prop_per_query) * 2, prop_indices: inside video indices.
         | 
| 413 | 
            +
                video_meta_indices = torch.floor(sorted_indices.float() / n_prop).long().cpu().numpy()
         | 
| 414 | 
            +
                prop_indices = torch.remainder(sorted_indices, n_prop).cpu().numpy()
         | 
| 415 | 
            +
             | 
| 416 | 
            +
                vr_res = []
         | 
| 417 | 
            +
                query_meta = eval_res["query_meta"]
         | 
| 418 | 
            +
                for i in trange(n_q, desc="[VR] Loop over queries to generate predictions"):
         | 
| 419 | 
            +
                    row = video_meta_indices[i]
         | 
| 420 | 
            +
                    score_row = - sorted_distances[i]
         | 
| 421 | 
            +
                    cur_vr_redictions = []
         | 
| 422 | 
            +
                    for j, meta_idx in enumerate(row):
         | 
| 423 | 
            +
                        video_idx = video2idx[eval_res["video_meta"][meta_idx]["vid_name"]]
         | 
| 424 | 
            +
                        cur_vr_redictions.append([video_idx, 0, 0, float(score_row[j])])
         | 
| 425 | 
            +
                    cur_query_pred = dict(
         | 
| 426 | 
            +
                        desc_id=query_meta[i]["desc_id"],
         | 
| 427 | 
            +
                        desc=query_meta[i]["desc"],
         | 
| 428 | 
            +
                        predictions=cur_vr_redictions
         | 
| 429 | 
            +
                    )
         | 
| 430 | 
            +
                    vr_res.append(cur_query_pred)
         | 
| 431 | 
            +
             | 
| 432 | 
            +
                vcmr_res = []
         | 
| 433 | 
            +
                logger.debug("sorted_indices {}".format(sorted_indices.shape))
         | 
| 434 | 
            +
                logger.debug("sorted_distances {}".format(sorted_distances.shape))
         | 
| 435 | 
            +
                for idx, (vm_row_indices, p_row_indices) in tqdm(enumerate(zip(video_meta_indices, prop_indices)),
         | 
| 436 | 
            +
                                                                 desc="[VCMR] Loop over queries to generate predictions",
         | 
| 437 | 
            +
                                                                 total=n_q):  # query
         | 
| 438 | 
            +
                    sorted_distances_row = - sorted_distances[idx]  # converted to negative distance
         | 
| 439 | 
            +
                    # [video_idx(int), st(float), ed(float), score(float)]
         | 
| 440 | 
            +
                    cur_ranked_predictions = []
         | 
| 441 | 
            +
                    for col_idx, (v_col_idx, p_col_idx) in enumerate(zip(vm_row_indices, p_row_indices)):
         | 
| 442 | 
            +
                        cur_pred = []
         | 
| 443 | 
            +
                        cur_pred += [video2idx[eval_res["video_meta"][v_col_idx]["vid_name"]], ]
         | 
| 444 | 
            +
                        cur_pred += eval_res["video_meta"][v_col_idx]["proposals"][p_col_idx].tolist()
         | 
| 445 | 
            +
                        cur_pred += [float(sorted_distances_row[col_idx])]
         | 
| 446 | 
            +
                        cur_ranked_predictions.append(cur_pred)
         | 
| 447 | 
            +
                    cur_query_pred = dict(
         | 
| 448 | 
            +
                        desc_id=eval_res["query_meta"][idx]["desc_id"],
         | 
| 449 | 
            +
                        desc=eval_res["query_meta"][idx]["desc"],
         | 
| 450 | 
            +
                        predictions=cur_ranked_predictions
         | 
| 451 | 
            +
                    )
         | 
| 452 | 
            +
                    vcmr_res.append(cur_query_pred)
         | 
| 453 | 
            +
                return vcmr_res, vr_res
         | 
| 454 | 
            +
             | 
| 455 | 
            +
             | 
| 456 | 
            +
            def generate_svmr_predictions_from_res(eval_res, max_prop_per_query=None):
         | 
| 457 | 
            +
                """ This function is for Video Corpus Moment Retrieval (VCMR).
         | 
| 458 | 
            +
                Generate prediction file which could be evaluated using standalone_eval.eval.
         | 
| 459 | 
            +
                Args:
         | 
| 460 | 
            +
                    eval_res: dict(
         | 
| 461 | 
            +
                        query_meta=query_meta_list,  # N_q * dict(), each dict is {"desc_id": int, "desc": str}
         | 
| 462 | 
            +
                        video_meta=video_meta_list,  # N_videos * dict(), {"vid_name": str, "duration": float, "proposals": ndarray}
         | 
| 463 | 
            +
                        video2idx=eval_dataset.video2idx,  # dict {vid_name: index}
         | 
| 464 | 
            +
                        query_prop_dist_svmr=[],  # N_q * (N_prop, )
         | 
| 465 | 
            +
                    )
         | 
| 466 | 
            +
                    max_prop_per_query: not used
         | 
| 467 | 
            +
                return:
         | 
| 468 | 
            +
                    list(dicts): each dict is dict(desc=str, desc_id=int, predictions=list(sublist)),
         | 
| 469 | 
            +
                        each sublist is [vid_name (str), st (float), ed (float), score (float)], score is negative distance.
         | 
| 470 | 
            +
                """
         | 
| 471 | 
            +
                video2idx = eval_res["video2idx"]
         | 
| 472 | 
            +
             | 
| 473 | 
            +
                svmr_res = []
         | 
| 474 | 
            +
                svmr_video2meta_idx = {e["vid_name"]: idx for idx, e in enumerate(eval_res["video_meta"])}
         | 
| 475 | 
            +
                for idx, (q_p_dist, q_m) in tqdm(enumerate(zip(eval_res["query_prop_dist_svmr"], eval_res["query_meta"])),
         | 
| 476 | 
            +
                                                 desc="Loop over queries to generate predictions",
         | 
| 477 | 
            +
                                                 total=len(eval_res["query_prop_dist_svmr"])):  # query
         | 
| 478 | 
            +
                    sorted_indices = np.argsort(q_p_dist)  # (N_prop, )  # ascending order, distance
         | 
| 479 | 
            +
                    if max_prop_per_query is not None:
         | 
| 480 | 
            +
                        sorted_indices = sorted_indices[:max_prop_per_query]
         | 
| 481 | 
            +
                    v_eval_idx = video2idx[q_m["vid_name"]]
         | 
| 482 | 
            +
                    v_meta_idx = svmr_video2meta_idx[q_m["vid_name"]]
         | 
| 483 | 
            +
                    proposals = eval_res["video_meta"][v_meta_idx]["proposals"]  # (N_p, 2)
         | 
| 484 | 
            +
                    # [video_idx(int), st(float), ed(float), score(float)]
         | 
| 485 | 
            +
                    cur_ranked_predictions = [
         | 
| 486 | 
            +
                        [v_eval_idx, ] + proposals[sort_idx].tolist() + [- round(float(q_p_dist[sort_idx]), 4), ]
         | 
| 487 | 
            +
                        for sort_idx in sorted_indices]
         | 
| 488 | 
            +
                    cur_query_pred = dict(
         | 
| 489 | 
            +
                        desc_id=q_m["desc_id"],
         | 
| 490 | 
            +
                        desc=q_m["desc"],
         | 
| 491 | 
            +
                        predictions=cur_ranked_predictions
         | 
| 492 | 
            +
                    )
         | 
| 493 | 
            +
                    svmr_res.append(cur_query_pred)
         | 
| 494 | 
            +
                return svmr_res
         | 
| 495 | 
            +
             | 
| 496 | 
            +
             | 
| 497 | 
            +
            POST_PROCESSING_MMS_FUNC = {
         | 
| 498 | 
            +
                "SVMR": post_processing_svmr_nms,
         | 
| 499 | 
            +
                "VCMR": post_processing_vcmr_nms
         | 
| 500 | 
            +
            }
         | 
| 501 | 
            +
             | 
| 502 | 
            +
             | 
| 503 | 
            +
            def get_submission_top_n(submission, top_n=100):
         | 
| 504 | 
            +
                def get_prediction_top_n(list_dict_predictions, top_n):
         | 
| 505 | 
            +
                    top_n_res = []
         | 
| 506 | 
            +
                    for e in list_dict_predictions:
         | 
| 507 | 
            +
                        e["predictions"] = e["predictions"][:top_n]
         | 
| 508 | 
            +
                        top_n_res.append(e)
         | 
| 509 | 
            +
                    return top_n_res
         | 
| 510 | 
            +
             | 
| 511 | 
            +
                top_n_submission = dict(video2idx=submission["video2idx"], )
         | 
| 512 | 
            +
                for k in submission:
         | 
| 513 | 
            +
                    if k != "video2idx":
         | 
| 514 | 
            +
                        top_n_submission[k] = get_prediction_top_n(submission[k], top_n)
         | 
| 515 | 
            +
                return top_n_submission
         | 
| 516 | 
            +
             | 
| 517 | 
            +
             | 
| 518 | 
            +
            def load_external_vr_res(external_vr_res_path, top_n_vr_videos=5):
         | 
| 519 | 
            +
                """return a mapping from desc_id to top retrieved video id"""
         | 
| 520 | 
            +
                external_vr_res = load_json(external_vr_res_path)
         | 
| 521 | 
            +
                external_vr_res = get_submission_top_n(external_vr_res, top_n=top_n_vr_videos)["VR"]
         | 
| 522 | 
            +
                query2video = {e["desc_id"]: [sub_e[0] for sub_e in e["predictions"]] for e in external_vr_res}
         | 
| 523 | 
            +
                return query2video
         | 
| 524 | 
            +
             | 
| 525 | 
            +
             | 
| 526 | 
            +
            def eval_epoch(model, eval_dataset, opt, save_submission_filename,
         | 
| 527 | 
            +
                           tasks=("SVMR",), max_before_nms=1000, max_after_nms=100):
         | 
| 528 | 
            +
                model.eval()
         | 
| 529 | 
            +
                logger.info("Computing scores")
         | 
| 530 | 
            +
                logger.info("Start timing")
         | 
| 531 | 
            +
                # times = []  # do not use
         | 
| 532 | 
            +
                # for _ in range(3):
         | 
| 533 | 
            +
                #     st_time = time.time()
         | 
| 534 | 
            +
                if opt.use_intermediate:
         | 
| 535 | 
            +
                    intermediate_cache_path = os.path.join(opt.results_dir, "{}_eval_res.pt".format(opt.eval_split_name))
         | 
| 536 | 
            +
                    if not os.path.exists(intermediate_cache_path):
         | 
| 537 | 
            +
                        logger.info("Saving intermediate results {}.".format(intermediate_cache_path))
         | 
| 538 | 
            +
                        eval_res = compute_query_proposal_distance(model, eval_dataset, opt, tasks=tasks)
         | 
| 539 | 
            +
                        torch.save(eval_res, intermediate_cache_path)
         | 
| 540 | 
            +
                    else:
         | 
| 541 | 
            +
                        logger.info("Loading intermediate results {}.".format(intermediate_cache_path))
         | 
| 542 | 
            +
                        eval_res = torch.load(intermediate_cache_path)
         | 
| 543 | 
            +
                else:
         | 
| 544 | 
            +
                    logger.info("Running without saving intermediate results, you might want to turn on --use_intermediate.")
         | 
| 545 | 
            +
                    eval_res = compute_query_proposal_distance(model, eval_dataset, opt, tasks=tasks)
         | 
| 546 | 
            +
                # del model  # We dont need model anymore
         | 
| 547 | 
            +
             | 
| 548 | 
            +
                # eval_res = compute_query_proposal_distance(model, eval_dataset, opt, tasks=tasks)
         | 
| 549 | 
            +
             | 
| 550 | 
            +
                logger.info("Generating predictions from scores")
         | 
| 551 | 
            +
                eval_submission_raw = dict(video2idx=eval_res["video2idx"])
         | 
| 552 | 
            +
                if "SVMR" in tasks:
         | 
| 553 | 
            +
                    eval_submission_raw["SVMR"] = generate_svmr_predictions_from_res(
         | 
| 554 | 
            +
                        eval_res, max_prop_per_query=max_before_nms)
         | 
| 555 | 
            +
                # vcmr_loading_time = 0
         | 
| 556 | 
            +
                if "VCMR" in tasks:
         | 
| 557 | 
            +
                    if opt.external_inference_vr_res_path is not None:
         | 
| 558 | 
            +
                        logger.info("Using external VR results from {}".format(opt.external_inference_vr_res_path))
         | 
| 559 | 
            +
                        # vcmr_loading_time = time.time()
         | 
| 560 | 
            +
                        eval_res["external_query2video"] = load_external_vr_res(
         | 
| 561 | 
            +
                            opt.external_inference_vr_res_path, top_n_vr_videos=5)
         | 
| 562 | 
            +
                        # vcmr_loading_time = time.time() - vcmr_loading_time
         | 
| 563 | 
            +
                        vcmr_res, vr_res = generate_vcmr_predictions_from_res_with_external(
         | 
| 564 | 
            +
                            eval_res, max_prop_per_query=max_before_nms)
         | 
| 565 | 
            +
                    else:
         | 
| 566 | 
            +
                        vcmr_res, vr_res = generate_vcmr_predictions_from_res(
         | 
| 567 | 
            +
                            eval_res, max_prop_per_query=max_before_nms)
         | 
| 568 | 
            +
                    eval_submission_raw["VCMR"] = vcmr_res
         | 
| 569 | 
            +
                    eval_submission_raw["VR"] = vr_res
         | 
| 570 | 
            +
                    # times += [time.time() - st_time - vcmr_loading_time]
         | 
| 571 | 
            +
                # times = torch.FloatTensor(times)
         | 
| 572 | 
            +
                IOU_THDS = (0.5, 0.7)
         | 
| 573 | 
            +
             | 
| 574 | 
            +
                logger.info("Saving/Evaluating before nms results")
         | 
| 575 | 
            +
                submission_path = os.path.join(opt.results_dir, save_submission_filename)
         | 
| 576 | 
            +
                eval_submission = get_submission_top_n(eval_submission_raw, top_n=max_after_nms)
         | 
| 577 | 
            +
                if max_after_nms < 1000:
         | 
| 578 | 
            +
                    save_json(eval_submission, submission_path)
         | 
| 579 | 
            +
                else:
         | 
| 580 | 
            +
                    torch.save(eval_submission, submission_path.replace(".json", ".pt"))
         | 
| 581 | 
            +
             | 
| 582 | 
            +
                metrics = eval_retrieval(eval_submission, eval_dataset.query_data,
         | 
| 583 | 
            +
                                         iou_thds=IOU_THDS, match_number=not opt.debug, verbose=opt.debug,
         | 
| 584 | 
            +
                                         use_desc_type=opt.dset_name == "tvr")
         | 
| 585 | 
            +
                # metrics["time_avg"] = float(times.mean())
         | 
| 586 | 
            +
                # metrics["time_std"] = float(times.std())
         | 
| 587 | 
            +
                save_metrics_path = submission_path.replace(".json", "_metrics.json")
         | 
| 588 | 
            +
                save_json(metrics, save_metrics_path, save_pretty=True, sort_keys=False)
         | 
| 589 | 
            +
                latest_file_paths = [submission_path, save_metrics_path]
         | 
| 590 | 
            +
             | 
| 591 | 
            +
                if opt.nms_thd != -1:
         | 
| 592 | 
            +
                    logger.info("Performing nms with nms_thd {}".format(opt.nms_thd))
         | 
| 593 | 
            +
                    eval_submission_after_nms = dict(video2idx=eval_submission_raw["video2idx"])
         | 
| 594 | 
            +
                    for k, nms_func in POST_PROCESSING_MMS_FUNC.items():
         | 
| 595 | 
            +
                        if k in eval_submission_raw:
         | 
| 596 | 
            +
                            eval_submission_after_nms[k] = nms_func(eval_submission_raw[k],
         | 
| 597 | 
            +
                                                                    nms_thd=opt.nms_thd,
         | 
| 598 | 
            +
                                                                    max_before_nms=max_before_nms,
         | 
| 599 | 
            +
                                                                    max_after_nms=max_after_nms)
         | 
| 600 | 
            +
             | 
| 601 | 
            +
                    logger.info("Saving/Evaluating nms results")
         | 
| 602 | 
            +
                    submission_nms_path = submission_path.replace(".json", "_nms_thd_{}.json".format(opt.nms_thd))
         | 
| 603 | 
            +
                    save_json(eval_submission_after_nms, submission_nms_path)
         | 
| 604 | 
            +
                    metrics_nms = eval_retrieval(eval_submission_after_nms, eval_dataset.query_data,
         | 
| 605 | 
            +
                                                 iou_thds=IOU_THDS, match_number=not opt.debug, verbose=opt.debug)
         | 
| 606 | 
            +
                    save_metrics_nms_path = submission_nms_path.replace(".json", "_metrics.json")
         | 
| 607 | 
            +
                    save_json(metrics_nms, save_metrics_nms_path, save_pretty=True, sort_keys=False)
         | 
| 608 | 
            +
                    latest_file_paths += [submission_nms_path, save_metrics_nms_path]
         | 
| 609 | 
            +
                else:
         | 
| 610 | 
            +
                    metrics_nms = None
         | 
| 611 | 
            +
                return metrics, metrics_nms, latest_file_paths
         | 
| 612 | 
            +
             | 
| 613 | 
            +
             | 
| 614 | 
            +
            def setup_model(opt):
         | 
| 615 | 
            +
                """Load model from checkpoint and move to specified device"""
         | 
| 616 | 
            +
                checkpoint = torch.load(opt.ckpt_filepath)
         | 
| 617 | 
            +
                model = CALWithSub(checkpoint["model_cfg"])
         | 
| 618 | 
            +
                model.load_state_dict(checkpoint["model"])
         | 
| 619 | 
            +
                logger.info("Loaded model saved at epoch {} from checkpoint: {}"
         | 
| 620 | 
            +
                            .format(checkpoint["epoch"], opt.ckpt_filepath))
         | 
| 621 | 
            +
             | 
| 622 | 
            +
                if opt.device.type == "cuda":
         | 
| 623 | 
            +
                    logger.info("CUDA enabled.")
         | 
| 624 | 
            +
                    model.to(opt.device)
         | 
| 625 | 
            +
                    if len(opt.device_ids) > 1:
         | 
| 626 | 
            +
                        logger.info("Use multi GPU", opt.device_ids)
         | 
| 627 | 
            +
                        model = torch.nn.DataParallel(model, device_ids=opt.device_ids)  # use multi GPU
         | 
| 628 | 
            +
                return model
         | 
| 629 | 
            +
             | 
| 630 | 
            +
             | 
| 631 | 
            +
            def start_inference():
         | 
| 632 | 
            +
                logger.info("Setup config, data and model...")
         | 
| 633 | 
            +
                opt = TestOptions().parse()
         | 
| 634 | 
            +
                cudnn.benchmark = False
         | 
| 635 | 
            +
                cudnn.deterministic = True
         | 
| 636 | 
            +
             | 
| 637 | 
            +
                assert opt.eval_path is not None
         | 
| 638 | 
            +
                eval_dataset = ProposalRetrievalEvalDataset(
         | 
| 639 | 
            +
                    dset_name=opt.dset_name,
         | 
| 640 | 
            +
                    model_type=opt.model_type,
         | 
| 641 | 
            +
                    eval_split_name=opt.eval_split_name,  # should only be val set
         | 
| 642 | 
            +
                    data_path=opt.eval_path,
         | 
| 643 | 
            +
                    desc_bert_path_or_handler=opt.desc_bert_path,
         | 
| 644 | 
            +
                    sub_bert_path_or_handler=opt.sub_bert_path,
         | 
| 645 | 
            +
                    max_desc_len=opt.max_desc_l,
         | 
| 646 | 
            +
                    corpus_path=opt.corpus_path,
         | 
| 647 | 
            +
                    vid_feat_path_or_handler=opt.vid_feat_path,
         | 
| 648 | 
            +
                    clip_length=opt.clip_length,
         | 
| 649 | 
            +
                    eval_proposal_bsz=opt.eval_proposal_bsz,
         | 
| 650 | 
            +
                    ctx_mode=opt.ctx_mode,
         | 
| 651 | 
            +
                    data_mode="query",
         | 
| 652 | 
            +
                    h5driver=opt.h5driver,
         | 
| 653 | 
            +
                    data_ratio=opt.data_ratio,
         | 
| 654 | 
            +
                    normalize_vfeat=not opt.no_norm_vfeat,
         | 
| 655 | 
            +
                    normalize_tfeat=not opt.no_norm_tfeat,
         | 
| 656 | 
            +
                )
         | 
| 657 | 
            +
             | 
| 658 | 
            +
                model = setup_model(opt)
         | 
| 659 | 
            +
                save_submission_filename = \
         | 
| 660 | 
            +
                    "inference_{}_{}_{}_predictions_{}.json".format(
         | 
| 661 | 
            +
                        opt.dset_name, opt.eval_split_name, opt.eval_id, "_".join(opt.tasks))
         | 
| 662 | 
            +
                logger.info("Starting inference...")
         | 
| 663 | 
            +
                with torch.no_grad():
         | 
| 664 | 
            +
                    metrics_no_nms, metrics_nms, latest_file_paths = \
         | 
| 665 | 
            +
                        eval_epoch(model, eval_dataset, opt, save_submission_filename, tasks=opt.tasks,
         | 
| 666 | 
            +
                                   max_before_nms=opt.max_before_nms, max_after_nms=opt.max_after_nms)
         | 
| 667 | 
            +
                logger.info("metrics_no_nms \n{}".format(pprint.pformat(metrics_no_nms, indent=4)))
         | 
| 668 | 
            +
                logger.info("metrics_nms \n{}".format(pprint.pformat(metrics_nms, indent=4)))
         | 
| 669 | 
            +
             | 
| 670 | 
            +
             | 
| 671 | 
            +
            if __name__ == '__main__':
         | 
| 672 | 
            +
                start_inference()
         | 
    	
        baselines/clip_alignment_with_language/local_utils/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        baselines/clip_alignment_with_language/local_utils/__pycache__/__init__.cpython-311.pyc
    ADDED
    
    | Binary file (217 Bytes). View file | 
|  | 
    	
        baselines/clip_alignment_with_language/local_utils/__pycache__/compute_proposal_upper_bound.cpython-311.pyc
    ADDED
    
    | Binary file (8.16 kB). View file | 
|  | 
    	
        baselines/clip_alignment_with_language/local_utils/__pycache__/proposal.cpython-311.pyc
    ADDED
    
    | Binary file (7.9 kB). View file | 
|  | 
    	
        baselines/clip_alignment_with_language/local_utils/compute_proposal_upper_bound.py
    ADDED
    
    | @@ -0,0 +1,117 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """
         | 
| 2 | 
            +
            Compute oracle upper bound for a given proposal method, which acts like
         | 
| 3 | 
            +
            a reversed recall, where we recall the GT timestamp pairs in the set of
         | 
| 4 | 
            +
            generated proposals.
         | 
| 5 | 
            +
            """
         | 
| 6 | 
            +
            import pprint
         | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
            from tqdm import tqdm
         | 
| 9 | 
            +
            from collections import Counter
         | 
| 10 | 
            +
            from utils.basic_utils import load_jsonl, save_json
         | 
| 11 | 
            +
            from standalone_eval.eval import compute_temporal_iou_batch
         | 
| 12 | 
            +
            from baselines.clip_alignment_with_language.local_utils.proposal import get_proposal_interface, ProposalConfigs
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
| 15 | 
            +
            def get_didemo_agreed_ts(times_list):
         | 
| 16 | 
            +
                """
         | 
| 17 | 
            +
                input example: [[1, 1], [1, 1], [1, 1], [0, 0]],
         | 
| 18 | 
            +
                return: [1, 1]"""
         | 
| 19 | 
            +
                times_str_list = [tuple(e) for e in times_list]
         | 
| 20 | 
            +
                times_str_list_counter = Counter(times_str_list)
         | 
| 21 | 
            +
                most_frequent_times = times_str_list_counter.most_common(1)[0][0]
         | 
| 22 | 
            +
                return most_frequent_times
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            def get_proposals_for_single_desc_video_pair(single_data, proposal_fn, dset_name):
         | 
| 26 | 
            +
                proposal_info = dict(
         | 
| 27 | 
            +
                    vid_name=single_data["vid_name"],
         | 
| 28 | 
            +
                    desc_id=single_data["desc_id"],
         | 
| 29 | 
            +
                    gt_ts=single_data["ts"] if dset_name != "didemo" else get_didemo_agreed_ts(single_data["ts"]),
         | 
| 30 | 
            +
                    proposals=proposal_fn(video_id="", metadata={"duration": single_data["duration"]}),
         | 
| 31 | 
            +
                )
         | 
| 32 | 
            +
                proposal_info["proposal_ious"] = compute_temporal_iou_batch(
         | 
| 33 | 
            +
                    proposal_info["proposals"], proposal_info["gt_ts"])
         | 
| 34 | 
            +
                return proposal_info
         | 
| 35 | 
            +
             | 
| 36 | 
            +
             | 
| 37 | 
            +
            def get_proposals_for_videos(datalist, dset_name):
         | 
| 38 | 
            +
                """datalist list(dict): each dict is
         | 
| 39 | 
            +
                {"desc_id": str/int, "duration": float, "ts": [st (float), ed (float)], ...}
         | 
| 40 | 
            +
                Note for Didemo dataset, "ts" entry is a list of [st (float), ed (float)] from different annotators,
         | 
| 41 | 
            +
                here we use the most frequent ts, we break ties by randomly sample one
         | 
| 42 | 
            +
                """
         | 
| 43 | 
            +
                proposal_interface = get_proposal_interface(dset_name)
         | 
| 44 | 
            +
                video_proposals_list = []
         | 
| 45 | 
            +
                for e in tqdm(datalist, desc="Computing video proposals"):
         | 
| 46 | 
            +
                    video_proposals_list.append(
         | 
| 47 | 
            +
                        get_proposals_for_single_desc_video_pair(e, proposal_interface, dset_name))
         | 
| 48 | 
            +
                return video_proposals_list
         | 
| 49 | 
            +
             | 
| 50 | 
            +
             | 
| 51 | 
            +
            def is_recalled_single_moment(proposal_ious, iou_thds=(0.5, 0.7)):
         | 
| 52 | 
            +
                """
         | 
| 53 | 
            +
                Args:
         | 
| 54 | 
            +
                    proposal_ious: np.ndarray, shape (N_proposal, )
         | 
| 55 | 
            +
                    iou_thds: set, temporal IoU thresholds
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                Returns:
         | 
| 58 | 
            +
                    list(bool), len == len(iou_thds), indicates whether recall under a iou_thd is found.
         | 
| 59 | 
            +
                """
         | 
| 60 | 
            +
                recalled = [False, ] * len(iou_thds)
         | 
| 61 | 
            +
                for idx, iou_thd in enumerate(iou_thds):
         | 
| 62 | 
            +
                    recalled[idx] = np.sum(proposal_ious >= iou_thd) >= 1  # at least one
         | 
| 63 | 
            +
                return recalled
         | 
| 64 | 
            +
             | 
| 65 | 
            +
             | 
| 66 | 
            +
            def compute_proposal_recall_upper_bound(video_proposals_list, iou_thds=(0.5, 0.7)):
         | 
| 67 | 
            +
                """video_proposals_list from get_proposals_for_videos()"""
         | 
| 68 | 
            +
                iou_corrects = np.empty((len(video_proposals_list), 2), dtype=np.float32)
         | 
| 69 | 
            +
                for idx, d in tqdm(enumerate(video_proposals_list),
         | 
| 70 | 
            +
                                   desc="Computing recall for videos",
         | 
| 71 | 
            +
                                   total=len(video_proposals_list)):
         | 
| 72 | 
            +
                    iou_corrects[idx] = is_recalled_single_moment(d["proposal_ious"],
         | 
| 73 | 
            +
                                                                  iou_thds=iou_thds)
         | 
| 74 | 
            +
                recall_by_iou = {iou_thd: float(np.mean(iou_corrects[:, idx]))
         | 
| 75 | 
            +
                                 for idx, iou_thd in enumerate(iou_thds)}
         | 
| 76 | 
            +
                return recall_by_iou
         | 
| 77 | 
            +
             | 
| 78 | 
            +
             | 
| 79 | 
            +
            def main_compute_upper_bound():
         | 
| 80 | 
            +
                import argparse
         | 
| 81 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 82 | 
            +
                parser.add_argument("-dset_name", type=str, choices=["tvr"])
         | 
| 83 | 
            +
                parser.add_argument("-eval_file_path", type=str, help="path to the file containing data to be evaluated")
         | 
| 84 | 
            +
                parser.add_argument("-save_path", type=str, help="path to save the results")
         | 
| 85 | 
            +
                parser.add_argument("-verbose", action="store_true")
         | 
| 86 | 
            +
                args = parser.parse_args()
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                eval_datalist = load_jsonl(args.eval_file_path)
         | 
| 89 | 
            +
                video_proposals_list = get_proposals_for_videos(eval_datalist, args.dset_name)
         | 
| 90 | 
            +
                recall_metrics = compute_proposal_recall_upper_bound(video_proposals_list, iou_thds=(0.5, 0.7))
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                video_proposals_list_by_video = {}
         | 
| 93 | 
            +
                for p in video_proposals_list:
         | 
| 94 | 
            +
                    if p["vid_name"] in video_proposals_list_by_video:
         | 
| 95 | 
            +
                        continue
         | 
| 96 | 
            +
                    else:
         | 
| 97 | 
            +
                        video_proposals_list_by_video[p["vid_name"]] = p
         | 
| 98 | 
            +
                video_proposals_list_by_video = list(video_proposals_list_by_video.values())
         | 
| 99 | 
            +
                total_n_clips_in_proposals = \
         | 
| 100 | 
            +
                    np.sum([np.sum(e["proposals"][:, 1] - e["proposals"][:, 0]) for e in video_proposals_list_by_video])
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                results = dict(
         | 
| 103 | 
            +
                    avg_num_proposals=float(np.mean([len(e["proposals"]) for e in video_proposals_list_by_video])),
         | 
| 104 | 
            +
                    total_num_proposals=int(np.sum([len(e["proposals"]) for e in video_proposals_list_by_video])),
         | 
| 105 | 
            +
                    recall_metrics=recall_metrics,
         | 
| 106 | 
            +
                    dset_name=args.dset_name,
         | 
| 107 | 
            +
                    filename=args.eval_file_path,
         | 
| 108 | 
            +
                    proposal_config=ProposalConfigs[args.dset_name]
         | 
| 109 | 
            +
                )
         | 
| 110 | 
            +
                results["avg_clip_per_proposal"] = total_n_clips_in_proposals / results["total_num_proposals"]
         | 
| 111 | 
            +
                save_json(results, args.save_path, save_pretty=True)
         | 
| 112 | 
            +
                if args.verbose:
         | 
| 113 | 
            +
                    pprint.pprint(results)
         | 
| 114 | 
            +
             | 
| 115 | 
            +
             | 
| 116 | 
            +
            if __name__ == '__main__':
         | 
| 117 | 
            +
                main_compute_upper_bound()
         | 
    	
        baselines/clip_alignment_with_language/local_utils/proposal.py
    ADDED
    
    | @@ -0,0 +1,181 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # MIT License
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # Copyright (c) 2018 Victor Escorcia Castillo
         | 
| 4 | 
            +
            #
         | 
| 5 | 
            +
            # Permission is hereby granted, free of charge, to any person obtaining a copy
         | 
| 6 | 
            +
            # of this software and associated documentation files (the "Software"), to deal
         | 
| 7 | 
            +
            # in the Software without restriction, including without limitation the rights
         | 
| 8 | 
            +
            # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
         | 
| 9 | 
            +
            # copies of the Software, and to permit persons to whom the Software is
         | 
| 10 | 
            +
            # furnished to do so, subject to the following conditions:
         | 
| 11 | 
            +
            #
         | 
| 12 | 
            +
            # The above copyright notice and this permission notice shall be included in all
         | 
| 13 | 
            +
            # copies or substantial portions of the Software.
         | 
| 14 | 
            +
            #
         | 
| 15 | 
            +
            # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
         | 
| 16 | 
            +
            # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
         | 
| 17 | 
            +
            # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
         | 
| 18 | 
            +
            # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
         | 
| 19 | 
            +
            # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
         | 
| 20 | 
            +
            # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
         | 
| 21 | 
            +
            # SOFTWARE.
         | 
| 22 | 
            +
            # ==============================================================================
         | 
| 23 | 
            +
            """
         | 
| 24 | 
            +
            Group multiple methods to generate salient temporal windows in a video"""
         | 
| 25 | 
            +
            import itertools
         | 
| 26 | 
            +
            import numpy as np
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            PROPOSAL_SCHEMES = ['DidemoICCV17SS', 'SlidingWindowMSRSS']
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            class TemporalProposalsBase:
         | 
| 32 | 
            +
                """Base class (signature) to generate temporal candidate in video"""
         | 
| 33 | 
            +
                def __call__(self, video_id, metadata=None, feature_collection=None):
         | 
| 34 | 
            +
                    raise NotImplementedError('Implement with the signature above')
         | 
| 35 | 
            +
             | 
| 36 | 
            +
             | 
| 37 | 
            +
            class DidemoICCV17SS(TemporalProposalsBase):
         | 
| 38 | 
            +
                """Original search space of moments proposed in ICCV-2017
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                Attributes:
         | 
| 41 | 
            +
                    clip_length_min (float) : minimum length, in seconds, of a video clip.
         | 
| 42 | 
            +
                    proposals (numpy array) : of shape [21, 2] representing all the
         | 
| 43 | 
            +
                        possible temporal segments of valid annotations of DiDeMo dataset.
         | 
| 44 | 
            +
                        It represents the search space of a temporal localization
         | 
| 45 | 
            +
                        algorithm.
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                Reference: Hendricks et al. Localizing Moments in Video with Natural
         | 
| 48 | 
            +
                    Language. ICCV 2017.
         | 
| 49 | 
            +
                """
         | 
| 50 | 
            +
                clip_length_min = 5.0
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                def __init__(self, *args, dtype=np.float32, **kwargs):
         | 
| 53 | 
            +
                    clips_indices = [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)]
         | 
| 54 | 
            +
                    for i in itertools.combinations(range(len(clips_indices)), 2):
         | 
| 55 | 
            +
                        clips_indices.append(i)
         | 
| 56 | 
            +
                    self.proposals = np.array(clips_indices, dtype=dtype)
         | 
| 57 | 
            +
                    self.proposals *= self.clip_length_min
         | 
| 58 | 
            +
                    self.proposals[:, 1] += self.clip_length_min
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                def __call__(self, *args, **kwargs):
         | 
| 61 | 
            +
                    return self.proposals
         | 
| 62 | 
            +
             | 
| 63 | 
            +
             | 
| 64 | 
            +
            class SlidingWindowMSRSS(TemporalProposalsBase):
         | 
| 65 | 
            +
                """Multi-scale sliding window with relative stride within the same scale
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                Attributes:
         | 
| 68 | 
            +
                    length (float) : length of smallest window.
         | 
| 69 | 
            +
                    scales (sequence of int) : duration of moments relative to
         | 
| 70 | 
            +
                        `length`.
         | 
| 71 | 
            +
                    stride (float) : relative stride between two windows with the same
         | 
| 72 | 
            +
                        duration. We used different strides for each scale rounding it
         | 
| 73 | 
            +
                        towards a multiple of `length`. Note that the minimum stride is
         | 
| 74 | 
            +
                        `length` for any window will be the `length` itself.
         | 
| 75 | 
            +
                    dtype (numpy.dtype) :
         | 
| 76 | 
            +
                """
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                def __init__(self, length, scales, stride=0.5, round_base=0.5, dtype=np.float32):
         | 
| 79 | 
            +
                    self.length = length
         | 
| 80 | 
            +
                    self.scales = scales
         | 
| 81 | 
            +
                    self.round_base = round_base
         | 
| 82 | 
            +
                    self.relative_stride = stride
         | 
| 83 | 
            +
                    # pick strides per scale that are multiples of length
         | 
| 84 | 
            +
                    self.strides = [max(round(s * stride / round_base) * round_base, round_base)
         | 
| 85 | 
            +
                                    * length for s in scales]
         | 
| 86 | 
            +
                    self.dtype = dtype
         | 
| 87 | 
            +
                    assert len(scales) > 0
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                def sliding_windows(self, t_end, t_start=0):
         | 
| 90 | 
            +
                    """sliding canonical windows over a given time interval"""
         | 
| 91 | 
            +
                    windows_ = []
         | 
| 92 | 
            +
                    for i, stride in enumerate(self.strides):
         | 
| 93 | 
            +
                        num_i = np.ceil((t_end - t_start) / stride)
         | 
| 94 | 
            +
                        windows_i = np.empty((int(num_i), 2), dtype=np.float32)
         | 
| 95 | 
            +
                        windows_i[:, 0] = np.arange(t_start, t_end, stride)
         | 
| 96 | 
            +
                        windows_i[:, 1] = windows_i[:, 0] + self.length * self.scales[i]
         | 
| 97 | 
            +
                        windows_i[windows_i[:, 1] > t_end, 1] = t_end
         | 
| 98 | 
            +
                        windows_.append(windows_i)
         | 
| 99 | 
            +
                        # print("--------------------------------{}".format(i))
         | 
| 100 | 
            +
                        # print(windows_i)
         | 
| 101 | 
            +
                    # import sys
         | 
| 102 | 
            +
                    # sys.exit(1)
         | 
| 103 | 
            +
                    windows = np.concatenate(windows_, axis=0)
         | 
| 104 | 
            +
                    # Hacky way to make windows fit inside video
         | 
| 105 | 
            +
                    # It implies windows at the end may not belong to the set spanned by
         | 
| 106 | 
            +
                    # length and scales.
         | 
| 107 | 
            +
                    return np.unique(windows, axis=0)
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                def __call__(self, video_id, metadata=None, feature_collection=None):
         | 
| 110 | 
            +
                    """return: (N_window, 2), each row contains (start, end)"""
         | 
| 111 | 
            +
                    duration = metadata.get('duration')
         | 
| 112 | 
            +
                    assert duration is not None
         | 
| 113 | 
            +
                    return self.sliding_windows(duration)
         | 
| 114 | 
            +
             | 
| 115 | 
            +
             | 
| 116 | 
            +
            ProposalConfigs = {
         | 
| 117 | 
            +
                "didemo": {
         | 
| 118 | 
            +
                    "proposal_interface": "DidemoICCV17SS",
         | 
| 119 | 
            +
                    "clip_length": 2.5,
         | 
| 120 | 
            +
                },
         | 
| 121 | 
            +
                "tvr": {
         | 
| 122 | 
            +
                    "length": 3,  # min proposal length
         | 
| 123 | 
            +
                    "scales": [1, 2, 4, 8],
         | 
| 124 | 
            +
                    "stride": 0.3,
         | 
| 125 | 
            +
                    "round_base": 1,
         | 
| 126 | 
            +
                    "min_proposal_length": 3,  # length * min(scales)
         | 
| 127 | 
            +
                    "clip_length": 1.5,  # length should be divisible by clip_length
         | 
| 128 | 
            +
                    "proposal_interface": "SlidingWindowMSRSS",
         | 
| 129 | 
            +
                },
         | 
| 130 | 
            +
                "anet_cap": {
         | 
| 131 | 
            +
                    "length": 5,
         | 
| 132 | 
            +
                    "scales": [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26],
         | 
| 133 | 
            +
                    "stride": 0.3,
         | 
| 134 | 
            +
                    "round_base": 1,
         | 
| 135 | 
            +
                    "min_proposal_length": 10,  # length * min(scales)
         | 
| 136 | 
            +
                    "clip_length": 5,  # length * min(scales) / 2
         | 
| 137 | 
            +
                    "proposal_interface": "SlidingWindowMSRSS",
         | 
| 138 | 
            +
                },
         | 
| 139 | 
            +
                "charades_sta": {
         | 
| 140 | 
            +
                    "length": 3,
         | 
| 141 | 
            +
                    "scales": [2, 3, 4, 5, 6, 7, 8],
         | 
| 142 | 
            +
                    "stride": 0.3,
         | 
| 143 | 
            +
                    "round_base": 1,
         | 
| 144 | 
            +
                    "min_proposal_length": 6,  # length * min(scales)
         | 
| 145 | 
            +
                    "clip_length": 3,  # length * min(scales) / 2
         | 
| 146 | 
            +
                    "proposal_interface": "SlidingWindowMSRSS",
         | 
| 147 | 
            +
                },
         | 
| 148 | 
            +
                "profiling": {
         | 
| 149 | 
            +
                    "length": 5,
         | 
| 150 | 
            +
                    "scales": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
         | 
| 151 | 
            +
                    "stride": 0.3,
         | 
| 152 | 
            +
                    "round_base": 1,
         | 
| 153 | 
            +
                    "clip_length": 5,  # length * min(scales) / 2
         | 
| 154 | 
            +
                    "proposal_interface": "SlidingWindowMSRSS",
         | 
| 155 | 
            +
                },
         | 
| 156 | 
            +
            }
         | 
| 157 | 
            +
            """
         | 
| 158 | 
            +
            'min_clip_length' is used to uniformly segment the video into smaller clips, it is a half of
         | 
| 159 | 
            +
            the 'min_proposal_length'. Thus we can enforce each moment has at least 2 clips.
         | 
| 160 | 
            +
            """
         | 
| 161 | 
            +
             | 
| 162 | 
            +
             | 
| 163 | 
            +
            def get_proposal_interface(dset_name):
         | 
| 164 | 
            +
                """ dset_name (str): one of ["tvr"] """
         | 
| 165 | 
            +
                assert dset_name in ProposalConfigs
         | 
| 166 | 
            +
                if dset_name == "didemo":
         | 
| 167 | 
            +
                    return DidemoICCV17SS()
         | 
| 168 | 
            +
                else:
         | 
| 169 | 
            +
                    arg_names = ["length", "scales", "stride", "round_base"]
         | 
| 170 | 
            +
                    func_args = {k: ProposalConfigs[dset_name][k] for k in arg_names}
         | 
| 171 | 
            +
                    return SlidingWindowMSRSS(**func_args)
         | 
| 172 | 
            +
             | 
| 173 | 
            +
             | 
| 174 | 
            +
            if __name__ == '__main__':
         | 
| 175 | 
            +
                test_fns_args = [(DidemoICCV17SS, (),),
         | 
| 176 | 
            +
                                 (SlidingWindowMSRSS, (1.5, [2, 4, 6, 12]))]
         | 
| 177 | 
            +
                for fn_i, args_i in test_fns_args:
         | 
| 178 | 
            +
                    proposal_fn = fn_i(*args_i)
         | 
| 179 | 
            +
                    x = proposal_fn('hola', {'duration': 15})
         | 
| 180 | 
            +
                    if fn_i == DidemoICCV17SS:
         | 
| 181 | 
            +
                        assert len(x) == 21
         | 
    	
        baselines/clip_alignment_with_language/local_utils/tvr_proposal_test_log.txt
    ADDED
    
    | @@ -0,0 +1,61 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
             | 
| 2 | 
            +
            """
         | 
| 3 | 
            +
            {'avg_num_proposals': 158.30197338228544,
         | 
| 4 | 
            +
             'dset_name': 'tvr',
         | 
| 5 | 
            +
             'filename': 'data/retrieval_release_data_with_ids/tvr_val_release.jsonl',
         | 
| 6 | 
            +
             'proposal_config': {'length': 3,
         | 
| 7 | 
            +
                                 'proposal_interface': 'SlidingWindowMSRSS',
         | 
| 8 | 
            +
                                 'round_base': 1,
         | 
| 9 | 
            +
                                 'scales': [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16],
         | 
| 10 | 
            +
                                 'stride': 0.3},
         | 
| 11 | 
            +
             'recall_metrics': {0.5: 0.8927030563354492, 0.7: 0.6690225005149841},
         | 
| 12 | 
            +
             'total_num_proposals': 344940}
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
| 15 | 
            +
            {'avg_num_proposals': 213.3295089490592,
         | 
| 16 | 
            +
             'dset_name': 'tvr',
         | 
| 17 | 
            +
             'filename': 'data/retrieval_release_data_with_ids/tvr_val_release.jsonl',
         | 
| 18 | 
            +
             'proposal_config': {'length': 3,
         | 
| 19 | 
            +
                                 'min_clip_length': 1.5,
         | 
| 20 | 
            +
                                 'min_proposal_length': 3,
         | 
| 21 | 
            +
                                 'proposal_interface': 'SlidingWindowMSRSS',
         | 
| 22 | 
            +
                                 'round_base': 0.5,
         | 
| 23 | 
            +
                                 'scales': [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16],
         | 
| 24 | 
            +
                                 'stride': 0.3},
         | 
| 25 | 
            +
             'recall_metrics': {0.5: 0.9612666368484497, 0.7: 0.8215695023536682},
         | 
| 26 | 
            +
             'total_num_proposals': 464845}
         | 
| 27 | 
            +
             --
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            {'avg_num_proposals': 213.3295089490592,
         | 
| 31 | 
            +
             'dset_name': 'tvr',
         | 
| 32 | 
            +
             'filename': '../../data/retrieval_release_data_with_ids/tvr_val_release.jsonl',
         | 
| 33 | 
            +
             'proposal_config': {'length': 3,
         | 
| 34 | 
            +
                                 'proposal_interface': 'SlidingWindowMSRSS',
         | 
| 35 | 
            +
                                 'round_base': 0.5,
         | 
| 36 | 
            +
                                 'scales': [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16],
         | 
| 37 | 
            +
                                 'stride': 0.3},
         | 
| 38 | 
            +
             'recall_metrics': {0.5: 0.9612666368484497, 0.7: 0.8215695023536682}}
         | 
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
            +
            {'avg_num_proposals': 263.3845800826067,
         | 
| 42 | 
            +
             'dset_name': 'tvr',
         | 
| 43 | 
            +
             'filename': '../../data/retrieval_release_data_with_ids/tvr_val_release.jsonl',
         | 
| 44 | 
            +
             'proposal_config': {'length': 3,
         | 
| 45 | 
            +
                                 'proposal_interface': 'SlidingWindowMSRSS',
         | 
| 46 | 
            +
                                 'round_base': 0.5,
         | 
| 47 | 
            +
                                 'scales': [0.5, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16],
         | 
| 48 | 
            +
                                 'stride': 0.3},
         | 
| 49 | 
            +
             'recall_metrics': {0.5: 0.9841211438179016, 0.7: 0.8567232489585876}}
         | 
| 50 | 
            +
             | 
| 51 | 
            +
             | 
| 52 | 
            +
            {'avg_num_proposals': 242.97246443322626,
         | 
| 53 | 
            +
             'dset_name': 'tvr',
         | 
| 54 | 
            +
             'filename': '../../data/retrieval_release_data_with_ids/tvr_val_release.jsonl',
         | 
| 55 | 
            +
             'proposal_config': {'length': 3,
         | 
| 56 | 
            +
                                 'proposal_interface': 'SlidingWindowMSRSS',
         | 
| 57 | 
            +
                                 'round_base': 0.5,
         | 
| 58 | 
            +
                                 'scales': [0.5, 1, 2, 3, 4, 5, 6, 7, 8],
         | 
| 59 | 
            +
                                 'stride': 0.3},
         | 
| 60 | 
            +
             'recall_metrics': {0.5: 0.9608076810836792, 0.7: 0.8212941884994507}}
         | 
| 61 | 
            +
            """
         | 
    	
        baselines/clip_alignment_with_language/mix_model_prediction.py
    ADDED
    
    | @@ -0,0 +1,86 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """
         | 
| 2 | 
            +
            Implement the CAL + CAL (TEF) model mentioned in
         | 
| 3 | 
            +
            ```
         | 
| 4 | 
            +
            @article{Escorcia2019TemporalLO,
         | 
| 5 | 
            +
              title={Temporal Localization of Moments in Video Collections with Natural Language},
         | 
| 6 | 
            +
              author={Victor Escorcia and Mattia Soldan and Josef Sivic and Bernard Ghanem and Bryan Russell},
         | 
| 7 | 
            +
              journal={ArXiv},
         | 
| 8 | 
            +
              year={2019},
         | 
| 9 | 
            +
              volume={abs/1907.12763}
         | 
| 10 | 
            +
            }
         | 
| 11 | 
            +
            ```
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            Methods:
         | 
| 14 | 
            +
                1, Give top200 predictions for each query in CAL then using CAL (TEF) to re-rank.
         | 
| 15 | 
            +
                2, This is approximated by re-ranking the top200 CAL using top1000 CAL(TEF) -- we assume they will be all covered.
         | 
| 16 | 
            +
            """
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            import torch
         | 
| 19 | 
            +
            import subprocess
         | 
| 20 | 
            +
            import numpy as np
         | 
| 21 | 
            +
            from tqdm import tqdm
         | 
| 22 | 
            +
            from utils.basic_utils import load_json, save_json
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            def load_saved_res(pred_path):
         | 
| 26 | 
            +
                if pred_path.endswith(".json"):
         | 
| 27 | 
            +
                    pred = load_json(pred_path)
         | 
| 28 | 
            +
                else:
         | 
| 29 | 
            +
                    pred = torch.load(pred_path)
         | 
| 30 | 
            +
                vcmr_res = {e["desc_id"]: e for e in pred["VCMR"]}
         | 
| 31 | 
            +
                video2idx = pred["video2idx"]
         | 
| 32 | 
            +
                return vcmr_res, video2idx
         | 
| 33 | 
            +
             | 
| 34 | 
            +
             | 
| 35 | 
            +
            def main_mix_results(pred_path, tef_pred_path, save_path, max_after_nms=100):
         | 
| 36 | 
            +
                """
         | 
| 37 | 
            +
                Args:
         | 
| 38 | 
            +
                    pred_path: contains top-200 VCMR predictions
         | 
| 39 | 
            +
                    tef_pred_path: contains top-1000 VCMR predictions
         | 
| 40 | 
            +
                    save_path:
         | 
| 41 | 
            +
                    max_after_nms: int,
         | 
| 42 | 
            +
                Returns:
         | 
| 43 | 
            +
                    save
         | 
| 44 | 
            +
                """
         | 
| 45 | 
            +
                vcmr_res, video2idx = load_saved_res(pred_path)
         | 
| 46 | 
            +
                tef_vcmr_res, video2idx = load_saved_res(tef_pred_path)
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                reranked_vcmr_res = {}
         | 
| 49 | 
            +
                num_valid = []
         | 
| 50 | 
            +
                for desc_id, preds in tqdm(vcmr_res.items(), desc="Loop over the predictions"):
         | 
| 51 | 
            +
                    tef_preds = tef_vcmr_res[desc_id]["predictions"]
         | 
| 52 | 
            +
                    pred_moments = set([tuple(e[:3]) for e in preds["predictions"]])
         | 
| 53 | 
            +
                    reranked_moments = [e for e in tef_preds if tuple(e[:3]) in pred_moments][:max_after_nms]
         | 
| 54 | 
            +
                    num_valid += [len(reranked_moments)]
         | 
| 55 | 
            +
                    if len(reranked_moments) != 100:
         | 
| 56 | 
            +
                        reranked_moments += reranked_moments[:100 - len(reranked_moments)]
         | 
| 57 | 
            +
                    reranked_vcmr_res[desc_id] = dict(
         | 
| 58 | 
            +
                        predictions=reranked_moments,
         | 
| 59 | 
            +
                        desc_id=desc_id,
         | 
| 60 | 
            +
                        desc=preds["desc"]
         | 
| 61 | 
            +
                    )
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                print("There are {} moments founded on average".format(np.mean(num_valid)))
         | 
| 64 | 
            +
                reranked_predictions = dict(
         | 
| 65 | 
            +
                    VCMR=list(reranked_vcmr_res.values()),
         | 
| 66 | 
            +
                    video2idx=video2idx
         | 
| 67 | 
            +
                )
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                save_json(reranked_predictions, save_path)
         | 
| 70 | 
            +
             | 
| 71 | 
            +
             | 
| 72 | 
            +
            if __name__ == '__main__':
         | 
| 73 | 
            +
                import argparse
         | 
| 74 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 75 | 
            +
                parser.add_argument("--pred_path", type=str, help="path to prediction res")
         | 
| 76 | 
            +
                parser.add_argument("--tef_pred_path", type=str, help="path to TEF prediction res")
         | 
| 77 | 
            +
                parser.add_argument("--save_path", type=str, help="path to save the re-ranked predictions, same dir as --pred_path")
         | 
| 78 | 
            +
                parser.add_argument("--gt_path", type=str, help="path to ground truth file")
         | 
| 79 | 
            +
                args = parser.parse_args()
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                main_mix_results(args.pred_path, args.tef_pred_path, args.save_path)
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                metrics_path = args.save_path.replace(".json", "_metrics.json")
         | 
| 84 | 
            +
                eval_cmd = "python standalone_eval/eval.py --submission_path " + args.save_path + " --gt_path " + args.gt_path + \
         | 
| 85 | 
            +
                    " --save_path " + metrics_path
         | 
| 86 | 
            +
                results = subprocess.run(eval_cmd, shell=True)
         | 
    	
        baselines/clip_alignment_with_language/model.py
    ADDED
    
    | @@ -0,0 +1,299 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            import torch.nn.functional as F
         | 
| 4 | 
            +
            from utils.model_utils import RNNEncoder
         | 
| 5 | 
            +
            from easydict import EasyDict as edict
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            cal_base_cfg = edict(
         | 
| 9 | 
            +
                visual_input_size=2048,  # changes based on visual input type
         | 
| 10 | 
            +
                textual_input_size=768,
         | 
| 11 | 
            +
                query_feat_size=768,
         | 
| 12 | 
            +
                visual_hidden_size=500,  #
         | 
| 13 | 
            +
                output_size=100,
         | 
| 14 | 
            +
                embedding_size=768,
         | 
| 15 | 
            +
                lstm_hidden_size=1000,
         | 
| 16 | 
            +
                margin=0.1,  # margin for ranking loss
         | 
| 17 | 
            +
                loss_type="hinge",  # loss type, 'hinge' or 'lse'
         | 
| 18 | 
            +
                inter_loss_weight=0.4,  # weight for inter negatives
         | 
| 19 | 
            +
                ctx_mode="video"
         | 
| 20 | 
            +
            )
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            class CAL(nn.Module):
         | 
| 24 | 
            +
                def __init__(self, config):
         | 
| 25 | 
            +
                    super(CAL, self).__init__()
         | 
| 26 | 
            +
                    self.config = config
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                    self.moment_mlp = nn.Sequential(
         | 
| 29 | 
            +
                        nn.Linear(config.visual_input_size, config.visual_hidden_size),
         | 
| 30 | 
            +
                        nn.ReLU(True),
         | 
| 31 | 
            +
                        nn.Linear(config.visual_hidden_size, config.output_size),
         | 
| 32 | 
            +
                    )
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                    self.query_lstm = RNNEncoder(word_embedding_size=config.embedding_size,
         | 
| 35 | 
            +
                                                 hidden_size=config.lstm_hidden_size,
         | 
| 36 | 
            +
                                                 bidirectional=False,
         | 
| 37 | 
            +
                                                 rnn_type="lstm",
         | 
| 38 | 
            +
                                                 dropout_p=0,
         | 
| 39 | 
            +
                                                 n_layers=1,
         | 
| 40 | 
            +
                                                 return_outputs=False)
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                    self.query_linear = nn.Linear(config.lstm_hidden_size, config.output_size)
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                def moment_encoder(self, moment_feat):
         | 
| 45 | 
            +
                    """moment_feat: (N, L_clip, D_v)"""
         | 
| 46 | 
            +
                    return F.normalize(self.moment_mlp(moment_feat), p=2, dim=-1)  # (N, L_clip, D_o)
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                def query_encoder(self, query_feat, query_mask):
         | 
| 49 | 
            +
                    """
         | 
| 50 | 
            +
                    Args:
         | 
| 51 | 
            +
                        query_feat: (N, L_q, D_q), torch.float32
         | 
| 52 | 
            +
                        query_mask: (N, L_q), torch.float32, with 1 indicates valid query, 0 indicates mask
         | 
| 53 | 
            +
                    """
         | 
| 54 | 
            +
                    _, hidden = self.query_lstm(query_feat, torch.sum(query_mask, dim=1).long())
         | 
| 55 | 
            +
                    return F.normalize(self.query_linear(hidden), p=2, dim=-1)  # (N, D_o)
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                def compute_pdist(self, query_embedding, moment_feat, moment_mask):
         | 
| 58 | 
            +
                    """ pairwise L2 distance
         | 
| 59 | 
            +
                    Args:
         | 
| 60 | 
            +
                        query_embedding: (N, D_o)
         | 
| 61 | 
            +
                        moment_feat: (N, L_clip, D_v)
         | 
| 62 | 
            +
                        moment_mask: (N, L_clip), torch.float32, where 1 indicates valid, 0 indicates padding
         | 
| 63 | 
            +
                    """
         | 
| 64 | 
            +
                    moment_embedding = self.moment_encoder(moment_feat)  # (N, L_clip, D_o)
         | 
| 65 | 
            +
                    moment_clip_dist = torch.sum((moment_embedding - query_embedding.unsqueeze(1)) ** 2, dim=2)  # (N, L_clip)
         | 
| 66 | 
            +
                    moment_dist = torch.sum(moment_clip_dist * moment_mask, dim=1) / moment_mask.sum(1)  # (N, )
         | 
| 67 | 
            +
                    return moment_dist  # (N, )
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                @classmethod
         | 
| 70 | 
            +
                def compute_cdist_inference(cls, query_embeddings, moment_embeddings, moment_mask):
         | 
| 71 | 
            +
                    """ Compute L2 distance for every possible pair of queries and proposals. This is different from
         | 
| 72 | 
            +
                    compute_pdist as the latter computes only pairs at each row.
         | 
| 73 | 
            +
                    Args:
         | 
| 74 | 
            +
                        query_embeddings: (N_q, D_o)
         | 
| 75 | 
            +
                        moment_embeddings: (N_prop, N_clips, D_o)
         | 
| 76 | 
            +
                        moment_mask: (N_prop, N_clips)
         | 
| 77 | 
            +
                    return:
         | 
| 78 | 
            +
                        query_moment_scores: (N_q, N_prop)
         | 
| 79 | 
            +
                    """
         | 
| 80 | 
            +
                    # sync device
         | 
| 81 | 
            +
                    query_device = query_embeddings.device  # convert to cuda if we want to use GPU
         | 
| 82 | 
            +
                    if moment_embeddings.device != query_device:
         | 
| 83 | 
            +
                        moment_embeddings = moment_embeddings.to(query_device)
         | 
| 84 | 
            +
                        moment_mask = moment_mask.to(query_device)
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                    # compute
         | 
| 87 | 
            +
                    n_query = query_embeddings.shape[0]
         | 
| 88 | 
            +
                    n_prop, n_clips, d = moment_embeddings.shape
         | 
| 89 | 
            +
                    query_clip_dist = torch.cdist(
         | 
| 90 | 
            +
                        query_embeddings, moment_embeddings.reshape(-1, d), p=2) ** 2  # (N_q, N_prop * N_clips)
         | 
| 91 | 
            +
                    query_clip_dist = query_clip_dist.reshape(n_query, n_prop, n_clips)
         | 
| 92 | 
            +
                    query_moment_dist = torch.sum(
         | 
| 93 | 
            +
                        query_clip_dist * moment_mask.unsqueeze(0), dim=2) / moment_mask.sum(1).unsqueeze(0)
         | 
| 94 | 
            +
                    return query_moment_dist  # (N_q, N_prop)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                def forward(self, query_feat, query_mask, pos_moment_feat, pos_moment_mask,
         | 
| 97 | 
            +
                            intra_neg_moment_feat, intra_neg_moment_mask,
         | 
| 98 | 
            +
                            inter_neg_moment_feat, inter_neg_moment_mask):
         | 
| 99 | 
            +
                    """
         | 
| 100 | 
            +
                    Args:
         | 
| 101 | 
            +
                        query_feat: (N, L, D_q)
         | 
| 102 | 
            +
                        query_mask: (N, L)
         | 
| 103 | 
            +
                        pos_moment_feat: (N, L_clip_1, D_v)
         | 
| 104 | 
            +
                        pos_moment_mask: (N, L_clip_1)
         | 
| 105 | 
            +
                        intra_neg_moment_feat: (N, L_clip_2, D_v)
         | 
| 106 | 
            +
                        intra_neg_moment_mask: (N, L_clip_2)
         | 
| 107 | 
            +
                        inter_neg_moment_feat: (N, L_clip_3, D_v)
         | 
| 108 | 
            +
                        inter_neg_moment_mask: (N, L_clip_2)
         | 
| 109 | 
            +
                    """
         | 
| 110 | 
            +
                    query_embed = self.query_encoder(query_feat, query_mask)  # (N, D_o)
         | 
| 111 | 
            +
                    pos_dist = self.compute_pdist(query_embed, pos_moment_feat, pos_moment_mask)  # (N, )
         | 
| 112 | 
            +
                    intra_neg_dist = self.compute_pdist(query_embed, intra_neg_moment_feat, intra_neg_moment_mask)  # (N, )
         | 
| 113 | 
            +
                    if self.config.inter_loss_weight == 0:  # should be zero for tef_only method.
         | 
| 114 | 
            +
                        loss_inter = 0.
         | 
| 115 | 
            +
                    else:
         | 
| 116 | 
            +
                        inter_neg_dist = self.compute_pdist(query_embed, inter_neg_moment_feat, inter_neg_moment_mask)  # (N, )
         | 
| 117 | 
            +
                        loss_inter = self.calc_loss(pos_dist, inter_neg_dist)
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                    loss = self.calc_loss(pos_dist, intra_neg_dist) + self.config.inter_loss_weight * loss_inter
         | 
| 120 | 
            +
                    return loss
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                def calc_loss(self, pos_dist, neg_dist):
         | 
| 123 | 
            +
                    """ Note here we encourage positive distance to be smaller than negative distance.
         | 
| 124 | 
            +
                    Args:
         | 
| 125 | 
            +
                        pos_dist: (N, ), torch.float32
         | 
| 126 | 
            +
                        neg_dist: (N, ), torch.float32
         | 
| 127 | 
            +
                    """
         | 
| 128 | 
            +
                    if self.config.loss_type == "hinge":  # max(0, m + S_pos - S_neg)
         | 
| 129 | 
            +
                        return torch.clamp(self.config.margin + pos_dist - neg_dist, min=0).sum() / len(pos_dist)
         | 
| 130 | 
            +
                    elif self.config.loss_type == "lse":  # log[1 + exp(S_pos - S_neg)]
         | 
| 131 | 
            +
                        return torch.log1p(torch.exp(pos_dist - neg_dist)).sum() / len(pos_dist)
         | 
| 132 | 
            +
                    else:
         | 
| 133 | 
            +
                        raise NotImplementedError("Only support 'hinge' and 'lse'")
         | 
| 134 | 
            +
             | 
| 135 | 
            +
             | 
| 136 | 
            +
            class CALWithSub(nn.Module):
         | 
| 137 | 
            +
                def __init__(self, config):
         | 
| 138 | 
            +
                    super(CALWithSub, self).__init__()
         | 
| 139 | 
            +
                    self.config = config
         | 
| 140 | 
            +
                    self.use_video = "video" in config.ctx_mode
         | 
| 141 | 
            +
                    self.use_sub = "sub" in config.ctx_mode
         | 
| 142 | 
            +
                    self.use_tef = "tef" in config.ctx_mode
         | 
| 143 | 
            +
                    self.tef_only = self.use_tef and not self.use_video and not self.use_sub
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                    if self.use_video or self.tef_only:
         | 
| 146 | 
            +
                        self.video_moment_mlp = nn.Sequential(
         | 
| 147 | 
            +
                            nn.Linear(config.visual_input_size, config.visual_hidden_size),
         | 
| 148 | 
            +
                            nn.ReLU(True),
         | 
| 149 | 
            +
                            nn.Linear(config.visual_hidden_size, config.output_size),
         | 
| 150 | 
            +
                        )
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                    if self.use_sub:
         | 
| 153 | 
            +
                        self.sub_moment_mlp = nn.Sequential(
         | 
| 154 | 
            +
                            nn.Linear(config.textual_input_size, config.visual_hidden_size),
         | 
| 155 | 
            +
                            nn.ReLU(True),
         | 
| 156 | 
            +
                            nn.Linear(config.visual_hidden_size, config.output_size),
         | 
| 157 | 
            +
                        )
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                    self.query_lstm = RNNEncoder(word_embedding_size=config.query_feat_size,
         | 
| 160 | 
            +
                                                 hidden_size=config.lstm_hidden_size,
         | 
| 161 | 
            +
                                                 bidirectional=False,
         | 
| 162 | 
            +
                                                 rnn_type="lstm",
         | 
| 163 | 
            +
                                                 dropout_p=0,
         | 
| 164 | 
            +
                                                 n_layers=1,
         | 
| 165 | 
            +
                                                 return_outputs=False)
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                    self.query_linear = nn.Linear(config.lstm_hidden_size, config.output_size)
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                def moment_encoder(self, moment_feat, module_name="video"):
         | 
| 170 | 
            +
                    """moment_feat: (N, L_clip, D_v)"""
         | 
| 171 | 
            +
                    if moment_feat is not None:
         | 
| 172 | 
            +
                        encoder = getattr(self, module_name + "_moment_mlp")
         | 
| 173 | 
            +
                        return F.normalize(encoder(moment_feat), p=2, dim=-1)  # (N, L_clip, D_o)
         | 
| 174 | 
            +
                    else:
         | 
| 175 | 
            +
                        return None
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                def query_encoder(self, query_feat, query_mask):
         | 
| 178 | 
            +
                    """
         | 
| 179 | 
            +
                    Args:
         | 
| 180 | 
            +
                        query_feat: (N, L_q, D_q), torch.float32
         | 
| 181 | 
            +
                        query_mask: (N, L_q), torch.float32, with 1 indicates valid query, 0 indicates mask
         | 
| 182 | 
            +
                    """
         | 
| 183 | 
            +
                    _, hidden = self.query_lstm(query_feat, torch.sum(query_mask, dim=1).long())
         | 
| 184 | 
            +
                    return F.normalize(self.query_linear(hidden), p=2, dim=-1)  # (N, D_o)
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                def _compute_pdist(self, query_embedding, moment_feat, moment_mask, module_name="video"):
         | 
| 187 | 
            +
                    """ pairwise L2 distance
         | 
| 188 | 
            +
                    Args:
         | 
| 189 | 
            +
                        query_embedding: (N, D_o)
         | 
| 190 | 
            +
                        moment_feat: (N, L_clip, D_v)
         | 
| 191 | 
            +
                        moment_mask: (N, L_clip), torch.float32, where 1 indicates valid, 0 indicates padding
         | 
| 192 | 
            +
                    """
         | 
| 193 | 
            +
                    moment_embedding = self.moment_encoder(moment_feat, module_name=module_name)  # (N, L_clip, D_o)
         | 
| 194 | 
            +
                    moment_clip_dist = torch.sum((moment_embedding - query_embedding.unsqueeze(1)) ** 2, dim=2)  # (N, L_clip)
         | 
| 195 | 
            +
                    moment_dist = torch.sum(moment_clip_dist * moment_mask, dim=1) / moment_mask.sum(1)  # (N, )
         | 
| 196 | 
            +
                    return moment_dist  # (N, )
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                def compute_pdist(self, query_embedding, moment_video_feat, moment_sub_feat, moment_mask):
         | 
| 199 | 
            +
                    """ pairwise L2 distance
         | 
| 200 | 
            +
                    Args:
         | 
| 201 | 
            +
                        query_embedding: (N, D_o)
         | 
| 202 | 
            +
                        moment_video_feat: (N, L_clip, D_v)
         | 
| 203 | 
            +
                        moment_sub_feat: (N, L_clip, D_t)
         | 
| 204 | 
            +
                        moment_mask: (N, L_clip), torch.float32, where 1 indicates valid, 0 indicates padding
         | 
| 205 | 
            +
                    """
         | 
| 206 | 
            +
                    divisor = (self.use_video or self.tef_only) + self.use_sub
         | 
| 207 | 
            +
                    video_moment_dist = self._compute_pdist(query_embedding, moment_video_feat, moment_mask, module_name="video") \
         | 
| 208 | 
            +
                        if self.use_video or self.tef_only else 0
         | 
| 209 | 
            +
                    sub_moment_dist = self._compute_pdist(query_embedding, moment_sub_feat, moment_mask, module_name="sub") \
         | 
| 210 | 
            +
                        if self.use_sub else 0
         | 
| 211 | 
            +
                    return (video_moment_dist + sub_moment_dist) / divisor  # (N, )
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                def _compute_cdist_inference(self, query_embeddings, moment_embeddings, moment_mask):
         | 
| 214 | 
            +
                    """ Compute L2 distance for every possible pair of queries and proposals. This is different from
         | 
| 215 | 
            +
                    compute_pdist as the latter computes only pairs at each row.
         | 
| 216 | 
            +
                    Args:
         | 
| 217 | 
            +
                        query_embeddings: (N_q, D_o)
         | 
| 218 | 
            +
                        moment_embeddings: (N_prop, N_clips, D_o)
         | 
| 219 | 
            +
                        moment_mask: (N_prop, N_clips)
         | 
| 220 | 
            +
                    return:
         | 
| 221 | 
            +
                        query_moment_scores: (N_q, N_prop)
         | 
| 222 | 
            +
                    """
         | 
| 223 | 
            +
                    # sync device
         | 
| 224 | 
            +
                    query_device = query_embeddings.device  # convert to cuda if we want to use GPU
         | 
| 225 | 
            +
                    if moment_embeddings.device != query_device:
         | 
| 226 | 
            +
                        moment_embeddings = moment_embeddings.to(query_device)
         | 
| 227 | 
            +
                        moment_mask = moment_mask.to(query_device)
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                    # compute
         | 
| 230 | 
            +
                    n_query = query_embeddings.shape[0]
         | 
| 231 | 
            +
                    n_prop, n_clips, d = moment_embeddings.shape
         | 
| 232 | 
            +
                    query_clip_dist = torch.cdist(
         | 
| 233 | 
            +
                        query_embeddings, moment_embeddings.reshape(-1, d), p=2) ** 2  # (N_q, N_prop * N_clips)
         | 
| 234 | 
            +
                    query_clip_dist = query_clip_dist.reshape(n_query, n_prop, n_clips)
         | 
| 235 | 
            +
                    query_moment_dist = torch.sum(
         | 
| 236 | 
            +
                        query_clip_dist * moment_mask.unsqueeze(0), dim=2) / moment_mask.sum(1).unsqueeze(0)
         | 
| 237 | 
            +
                    return query_moment_dist  # (N_q, N_prop)
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                def compute_cdist_inference(self, query_embeddings, video_moment_embeddings, sub_moment_embeddings, moment_mask):
         | 
| 240 | 
            +
                    divisor = (self.use_video or self.tef_only) + self.use_sub
         | 
| 241 | 
            +
                    video_moment_dist = self._compute_cdist_inference(query_embeddings, video_moment_embeddings, moment_mask) \
         | 
| 242 | 
            +
                        if self.use_video or self.tef_only else 0
         | 
| 243 | 
            +
                    sub_moment_dist = self._compute_cdist_inference(query_embeddings, sub_moment_embeddings, moment_mask) \
         | 
| 244 | 
            +
                        if self.use_sub else 0
         | 
| 245 | 
            +
                    return (video_moment_dist + sub_moment_dist) / divisor  # (N_q, N_prop)
         | 
| 246 | 
            +
             | 
| 247 | 
            +
                def forward(self, query_feat, query_mask, pos_moment_video_feat, pos_moment_video_mask,
         | 
| 248 | 
            +
                            intra_neg_moment_video_feat, intra_neg_moment_video_mask,
         | 
| 249 | 
            +
                            inter_neg_moment_video_feat, inter_neg_moment_video_mask,
         | 
| 250 | 
            +
                            pos_moment_sub_feat, pos_moment_sub_mask,
         | 
| 251 | 
            +
                            intra_neg_moment_sub_feat, intra_neg_moment_sub_mask,
         | 
| 252 | 
            +
                            inter_neg_moment_sub_feat, inter_neg_moment_sub_mask):
         | 
| 253 | 
            +
                    """
         | 
| 254 | 
            +
                    Args:
         | 
| 255 | 
            +
                        query_feat: (N, L, D_q)
         | 
| 256 | 
            +
                        query_mask: (N, L)
         | 
| 257 | 
            +
                        pos_moment_video_feat: (N, L_clip_1, D_v)
         | 
| 258 | 
            +
                        pos_moment_video_mask: (N, L_clip_1)
         | 
| 259 | 
            +
                        intra_neg_moment_video_feat: (N, L_clip_2, D_v)
         | 
| 260 | 
            +
                        intra_neg_moment_video_mask: (N, L_clip_2)
         | 
| 261 | 
            +
                        inter_neg_moment_video_feat: (N, L_clip_3, D_v)
         | 
| 262 | 
            +
                        inter_neg_moment_video_mask: (N, L_clip_2)
         | 
| 263 | 
            +
                        pos_moment_sub_feat:
         | 
| 264 | 
            +
                        pos_moment_sub_mask:
         | 
| 265 | 
            +
                        intra_neg_moment_sub_feat:
         | 
| 266 | 
            +
                        intra_neg_moment_sub_mask:
         | 
| 267 | 
            +
                        inter_neg_moment_sub_feat:
         | 
| 268 | 
            +
                        inter_neg_moment_sub_mask:
         | 
| 269 | 
            +
                    """
         | 
| 270 | 
            +
                    query_embed = self.query_encoder(query_feat, query_mask)  # (N, D_o)
         | 
| 271 | 
            +
                    pos_dist = self.compute_pdist(
         | 
| 272 | 
            +
                        query_embed, pos_moment_video_feat, pos_moment_sub_feat,
         | 
| 273 | 
            +
                        moment_mask=pos_moment_sub_mask if self.use_sub else pos_moment_video_mask)  # (N, )
         | 
| 274 | 
            +
                    intra_neg_dist = self.compute_pdist(
         | 
| 275 | 
            +
                        query_embed, intra_neg_moment_video_feat, intra_neg_moment_sub_feat,
         | 
| 276 | 
            +
                        moment_mask=intra_neg_moment_sub_mask if self.use_sub else intra_neg_moment_video_mask)  # (N, )
         | 
| 277 | 
            +
                    if self.config.inter_loss_weight == 0:  # should be zero for tef_only method.
         | 
| 278 | 
            +
                        loss_inter = 0.
         | 
| 279 | 
            +
                    else:
         | 
| 280 | 
            +
                        inter_neg_dist = self.compute_pdist(
         | 
| 281 | 
            +
                            query_embed, inter_neg_moment_video_feat, inter_neg_moment_sub_feat,
         | 
| 282 | 
            +
                            moment_mask=inter_neg_moment_sub_mask if self.use_sub else inter_neg_moment_video_mask)  # (N, )
         | 
| 283 | 
            +
                        loss_inter = self.calc_loss(pos_dist, inter_neg_dist)
         | 
| 284 | 
            +
             | 
| 285 | 
            +
                    loss = self.calc_loss(pos_dist, intra_neg_dist) + self.config.inter_loss_weight * loss_inter
         | 
| 286 | 
            +
                    return loss
         | 
| 287 | 
            +
             | 
| 288 | 
            +
                def calc_loss(self, pos_dist, neg_dist):
         | 
| 289 | 
            +
                    """ Note here we encourage positive distance to be smaller than negative distance.
         | 
| 290 | 
            +
                    Args:
         | 
| 291 | 
            +
                        pos_dist: (N, ), torch.float32
         | 
| 292 | 
            +
                        neg_dist: (N, ), torch.float32
         | 
| 293 | 
            +
                    """
         | 
| 294 | 
            +
                    if self.config.loss_type == "hinge":  # max(0, m + S_pos - S_neg)
         | 
| 295 | 
            +
                        return torch.clamp(self.config.margin + pos_dist - neg_dist, min=0).sum() / len(pos_dist)
         | 
| 296 | 
            +
                    elif self.config.loss_type == "lse":  # log[1 + exp(S_pos - S_neg)]
         | 
| 297 | 
            +
                        return torch.log1p(torch.exp(pos_dist - neg_dist)).sum() / len(pos_dist)
         | 
| 298 | 
            +
                    else:
         | 
| 299 | 
            +
                        raise NotImplementedError("Only support 'hinge' and 'lse'")
         | 
    	
        baselines/clip_alignment_with_language/proposal_retrieval_dataset.py
    ADDED
    
    | @@ -0,0 +1,587 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """
         | 
| 2 | 
            +
            Dataset for clip model
         | 
| 3 | 
            +
            """
         | 
| 4 | 
            +
            import logging
         | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
            from torch.utils.data import Dataset
         | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
            import h5py
         | 
| 9 | 
            +
            import math
         | 
| 10 | 
            +
            import random
         | 
| 11 | 
            +
            from utils.basic_utils import load_jsonl, load_json, l2_normalize_np_array
         | 
| 12 | 
            +
            from utils.tensor_utils import pad_sequences_1d
         | 
| 13 | 
            +
            from baselines.clip_alignment_with_language.local_utils.proposal import get_proposal_interface
         | 
| 14 | 
            +
            from baselines.clip_alignment_with_language.local_utils.compute_proposal_upper_bound import \
         | 
| 15 | 
            +
                get_didemo_agreed_ts
         | 
| 16 | 
            +
            from standalone_eval.eval import compute_temporal_iou_batch
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            class ProposalRetrievalDataset(Dataset):
         | 
| 22 | 
            +
                """
         | 
| 23 | 
            +
                Args:
         | 
| 24 | 
            +
                    dset_name, str, ["tvr"]
         | 
| 25 | 
            +
                    ctx_mode: str,
         | 
| 26 | 
            +
                    pos_iou_thd: float, in [0, 1], >= pos_iou_thd are defined as positive
         | 
| 27 | 
            +
                    neg_iou_thd: float, in [0, 1], < neg_iou_thd are defined as negative
         | 
| 28 | 
            +
                Return:
         | 
| 29 | 
            +
                    a dict: {
         | 
| 30 | 
            +
                        "meta": {
         | 
| 31 | 
            +
                            "desc_id": int,
         | 
| 32 | 
            +
                            "desc": str,
         | 
| 33 | 
            +
                            "vid_name": str,
         | 
| 34 | 
            +
                            "duration": float,
         | 
| 35 | 
            +
                            "ts": [st (float), ed (float)], seconds, ground_truth timestamps
         | 
| 36 | 
            +
                            "pos_moment": [st (float), ed (float)], seconds, IoU with "ts" >= pos_iou_thd
         | 
| 37 | 
            +
                            "intra_neg_moment": [st (float), ed (float)], seconds, IoU with "ts" < neg_iou_thd
         | 
| 38 | 
            +
                            "inter_neg_vid_name": str,
         | 
| 39 | 
            +
                            "inter_neg_duration": float,
         | 
| 40 | 
            +
                            "inter_neg_moment": [st (float), ed (float)], seconds, IoU with "ts" < neg_iou_thd
         | 
| 41 | 
            +
                        }
         | 
| 42 | 
            +
                        "model_inputs": {
         | 
| 43 | 
            +
                            "desc_feat": torch.tensor, (L, D_t)
         | 
| 44 | 
            +
                            "pos_moment_feat": torch.tensor, (n_clip_in_moment, D)
         | 
| 45 | 
            +
                            "intra_neg_moment_feat": torch.tensor, (n_clip_in_moment, D)
         | 
| 46 | 
            +
                            "inter_neg_moment_feat": torch.tensor, (n_clip_in_moment, D)
         | 
| 47 | 
            +
                        }
         | 
| 48 | 
            +
                    }
         | 
| 49 | 
            +
                """
         | 
| 50 | 
            +
                def __init__(self, dset_name, data_path, desc_bert_path, sub_bert_path, max_desc_len,
         | 
| 51 | 
            +
                             vid_feat_path, clip_length, vid_feat_size, sub_feat_size=0, ctx_mode="video_tef",
         | 
| 52 | 
            +
                             pos_iou_thd=0.7, neg_iou_thd=0.3, h5driver=None, data_ratio=1.0,
         | 
| 53 | 
            +
                             normalize_vfeat=True, normalize_tfeat=True, model_type="cal",
         | 
| 54 | 
            +
                             external_train_vr_res_path=None, corpus_path=None):
         | 
| 55 | 
            +
                    self.dset_name = dset_name
         | 
| 56 | 
            +
                    self.model_type = model_type
         | 
| 57 | 
            +
                    self.pool_local = model_type == "mcn"  # pool local feature
         | 
| 58 | 
            +
                    self.data_path = data_path
         | 
| 59 | 
            +
                    self.data_ratio = data_ratio
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                    self.desc_bert_path = desc_bert_path
         | 
| 62 | 
            +
                    self.max_desc_len = max_desc_len
         | 
| 63 | 
            +
                    self.sub_bert_path = sub_bert_path
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    self.vid_feat_path = vid_feat_path
         | 
| 66 | 
            +
                    self.clip_length = clip_length
         | 
| 67 | 
            +
                    self.ctx_mode = ctx_mode
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                    self.pos_iou_thd = pos_iou_thd
         | 
| 70 | 
            +
                    self.neg_iou_thd = neg_iou_thd
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                    self.vid_feat_output_size = 2 * vid_feat_size * ("video" in ctx_mode) + 2 * ("tef" in ctx_mode)
         | 
| 73 | 
            +
                    self.sub_feat_output_size = 2 * sub_feat_size * ("sub" in ctx_mode) + 2 * ("tef" in ctx_mode)
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                    # prepare desc data
         | 
| 76 | 
            +
                    self.data = load_jsonl(data_path)
         | 
| 77 | 
            +
                    if self.data_ratio != 1:
         | 
| 78 | 
            +
                        n_examples = int(len(self.data) * data_ratio)
         | 
| 79 | 
            +
                        self.data = self.data[:n_examples]
         | 
| 80 | 
            +
                        logger.info("Using {}% of the data: {} examples".format(data_ratio * 100, n_examples))
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                    self.proposal_fn = get_proposal_interface(dset_name)
         | 
| 83 | 
            +
                    if self.ctx_mode != "tef":
         | 
| 84 | 
            +
                        self.vid_feat_h5 = h5py.File(self.vid_feat_path, "r", driver=h5driver)
         | 
| 85 | 
            +
                    self.desc_bert_h5 = h5py.File(self.desc_bert_path, "r", driver=h5driver)
         | 
| 86 | 
            +
                    if "sub" in self.ctx_mode:
         | 
| 87 | 
            +
                        self.sub_bert_h5 = h5py.File(self.sub_bert_path, "r", driver=h5driver)
         | 
| 88 | 
            +
                    self.normalize_vfeat = normalize_vfeat
         | 
| 89 | 
            +
                    self.normalize_tfeat = normalize_tfeat
         | 
| 90 | 
            +
                    self.use_video = "video" in self.ctx_mode
         | 
| 91 | 
            +
                    self.use_sub = "sub" in self.ctx_mode
         | 
| 92 | 
            +
                    self.use_tef = "tef" in self.ctx_mode
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                    if external_train_vr_res_path is not None:
         | 
| 95 | 
            +
                        video_data = load_json(corpus_path)["train"]
         | 
| 96 | 
            +
                        # {video_idx: [vid_name, vid_duration]}
         | 
| 97 | 
            +
                        video_idx2name_dur_pair = {v[1]: [k, v[0]] for k, v in video_data.items()}
         | 
| 98 | 
            +
                        external_vr_res = load_json(external_train_vr_res_path)
         | 
| 99 | 
            +
                        # {desc_id: [(vid_name, vid_duration), ...]}
         | 
| 100 | 
            +
                        self.desc_id2video_names_dur_pairs = \
         | 
| 101 | 
            +
                            {e["desc_id"]: [video_idx2name_dur_pair[int(sub_e[0])] for sub_e in e["predictions"]]
         | 
| 102 | 
            +
                             for e in external_vr_res["VR"]}  # ordered
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                def __len__(self):
         | 
| 105 | 
            +
                    return len(self.data)
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                def __getitem__(self, index):
         | 
| 108 | 
            +
                    raw_data = self.data[index]
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                    # initialize with basic data
         | 
| 111 | 
            +
                    meta = dict(
         | 
| 112 | 
            +
                        desc_id=raw_data["desc_id"],
         | 
| 113 | 
            +
                        desc=raw_data["desc"],
         | 
| 114 | 
            +
                        vid_name=raw_data["vid_name"],
         | 
| 115 | 
            +
                        duration=raw_data["duration"],
         | 
| 116 | 
            +
                        ts=raw_data["ts"] if self.dset_name != "didemo" else get_didemo_agreed_ts(raw_data["ts"]),
         | 
| 117 | 
            +
                    )
         | 
| 118 | 
            +
                    model_inputs = dict()
         | 
| 119 | 
            +
                    query_feat = self.desc_bert_h5[str(raw_data["desc_id"])][:self.max_desc_len]
         | 
| 120 | 
            +
                    if self.normalize_tfeat:
         | 
| 121 | 
            +
                        query_feat = l2_normalize_np_array(query_feat)
         | 
| 122 | 
            +
                    model_inputs["query_feat"] = torch.from_numpy(query_feat)
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                    # sample positive and negative moments
         | 
| 125 | 
            +
                    meta["pos_moment"] = self.align_ts_to_clip_boundaries(meta["duration"], meta["ts"])
         | 
| 126 | 
            +
                    meta["intra_neg_moment"] = self.sample_intra_neg_moment(meta["duration"], meta["ts"])
         | 
| 127 | 
            +
                    meta["inter_neg_moment"], meta["inter_neg_vid_name"], meta["inter_neg_duration"] = \
         | 
| 128 | 
            +
                        self.sample_inter_video_negative(meta["vid_name"], meta["pos_moment"] / meta["duration"],
         | 
| 129 | 
            +
                                                         desc_id=meta["desc_id"])
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                    pos_tef, intra_neg_tef, inter_neg_tef = (None,) * 3
         | 
| 132 | 
            +
                    if self.use_tef:
         | 
| 133 | 
            +
                        pos_tef = meta["pos_moment"] / meta["duration"]  # temporal endpoint feature, (2, )
         | 
| 134 | 
            +
                        intra_neg_tef = meta["intra_neg_moment"] / meta["duration"]
         | 
| 135 | 
            +
                        inter_neg_tef = meta["inter_neg_moment"] / meta["inter_neg_duration"]
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                    if self.use_video:
         | 
| 138 | 
            +
                        pos_v_feat = self.vid_feat_h5[meta["vid_name"]]  # (N_frm, D)
         | 
| 139 | 
            +
                        neg_v_feat = self.vid_feat_h5[meta["inter_neg_vid_name"]]
         | 
| 140 | 
            +
                        pos_v_ctx_feat = np.mean(pos_v_feat, axis=0)
         | 
| 141 | 
            +
                        neg_v_ctx_feat = np.mean(neg_v_feat, axis=0)
         | 
| 142 | 
            +
                        if self.normalize_vfeat:
         | 
| 143 | 
            +
                            pos_v_ctx_feat = l2_normalize_np_array(pos_v_ctx_feat)
         | 
| 144 | 
            +
                            neg_v_ctx_feat = l2_normalize_np_array(neg_v_ctx_feat)
         | 
| 145 | 
            +
                        pos_moment_v_feat = self.get_moment_feat(pos_v_feat, meta["pos_moment"],
         | 
| 146 | 
            +
                                                                 normalize=self.normalize_vfeat,
         | 
| 147 | 
            +
                                                                 fix_outbound=True, pool_local=self.pool_local)
         | 
| 148 | 
            +
                        intra_neg_moment_v_feat = self.get_moment_feat(pos_v_feat, meta["intra_neg_moment"],
         | 
| 149 | 
            +
                                                                       normalize=self.normalize_vfeat,
         | 
| 150 | 
            +
                                                                       fix_outbound=True, pool_local=self.pool_local)
         | 
| 151 | 
            +
                        inter_neg_moment_v_feat = self.get_moment_feat(neg_v_feat, meta["inter_neg_moment"],
         | 
| 152 | 
            +
                                                                       normalize=self.normalize_vfeat,
         | 
| 153 | 
            +
                                                                       fix_outbound=True, pool_local=self.pool_local)
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                        # concat features, [video_clip_feat; video_context_feat; temporal_endpoint_feat]
         | 
| 156 | 
            +
                        model_inputs["pos_moment_video_feat"] = self.concat_feat_adv(
         | 
| 157 | 
            +
                            moment_feats=[pos_moment_v_feat, pos_v_ctx_feat], tef=pos_tef, ctx_mode=self.ctx_mode)
         | 
| 158 | 
            +
                        model_inputs["intra_neg_moment_video_feat"] = self.concat_feat_adv(
         | 
| 159 | 
            +
                            moment_feats=[intra_neg_moment_v_feat, pos_v_ctx_feat], tef=intra_neg_tef, ctx_mode=self.ctx_mode)
         | 
| 160 | 
            +
                        model_inputs["inter_neg_moment_video_feat"] = self.concat_feat_adv(
         | 
| 161 | 
            +
                            moment_feats=[inter_neg_moment_v_feat, neg_v_ctx_feat], tef=inter_neg_tef, ctx_mode=self.ctx_mode)
         | 
| 162 | 
            +
                    else:
         | 
| 163 | 
            +
                        for k in ["pos_moment_video_feat", "intra_neg_moment_video_feat", "inter_neg_moment_video_feat"]:
         | 
| 164 | 
            +
                            model_inputs[k] = torch.zeros((2, 2))
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                    if self.use_sub:  # no need for ctx feature, as the features are already contextulized
         | 
| 167 | 
            +
                        pos_s_feat = self.sub_bert_h5[meta["vid_name"]]  # (N_words, D_t)
         | 
| 168 | 
            +
                        neg_s_feat = self.sub_bert_h5[meta["inter_neg_vid_name"]]
         | 
| 169 | 
            +
                        pos_s_ctx_feat = np.mean(pos_s_feat, axis=0)
         | 
| 170 | 
            +
                        neg_s_ctx_feat = np.mean(neg_s_feat, axis=0)
         | 
| 171 | 
            +
                        if self.normalize_tfeat:
         | 
| 172 | 
            +
                            pos_s_ctx_feat = l2_normalize_np_array(pos_s_ctx_feat)
         | 
| 173 | 
            +
                            neg_s_ctx_feat = l2_normalize_np_array(neg_s_ctx_feat)
         | 
| 174 | 
            +
                        pos_moment_s_feat = self.get_moment_feat(pos_s_feat, meta["pos_moment"],
         | 
| 175 | 
            +
                                                                 normalize=self.normalize_tfeat,
         | 
| 176 | 
            +
                                                                 fix_outbound=True, pool_local=self.pool_local)
         | 
| 177 | 
            +
                        intra_neg_moment_s_feat = self.get_moment_feat(pos_s_feat, meta["intra_neg_moment"],
         | 
| 178 | 
            +
                                                                       normalize=self.normalize_tfeat,
         | 
| 179 | 
            +
                                                                       fix_outbound=True, pool_local=self.pool_local)
         | 
| 180 | 
            +
                        inter_neg_moment_s_feat = self.get_moment_feat(neg_s_feat, meta["inter_neg_moment"],
         | 
| 181 | 
            +
                                                                       normalize=self.normalize_tfeat,
         | 
| 182 | 
            +
                                                                       fix_outbound=True, pool_local=self.pool_local)
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                        # concat features, [sub_clip_feat; sub_context_feat; temporal_endpoint_feat]
         | 
| 185 | 
            +
                        model_inputs["pos_moment_sub_feat"] = self.concat_feat_adv(
         | 
| 186 | 
            +
                            moment_feats=[pos_moment_s_feat, pos_s_ctx_feat], tef=pos_tef, ctx_mode=self.ctx_mode)
         | 
| 187 | 
            +
                        model_inputs["intra_neg_moment_sub_feat"] = self.concat_feat_adv(
         | 
| 188 | 
            +
                            moment_feats=[intra_neg_moment_s_feat, pos_s_ctx_feat], tef=intra_neg_tef, ctx_mode=self.ctx_mode)
         | 
| 189 | 
            +
                        model_inputs["inter_neg_moment_sub_feat"] = self.concat_feat_adv(
         | 
| 190 | 
            +
                            moment_feats=[inter_neg_moment_s_feat, neg_s_ctx_feat], tef=inter_neg_tef, ctx_mode=self.ctx_mode)
         | 
| 191 | 
            +
                    else:
         | 
| 192 | 
            +
                        for k in ["pos_moment_sub_feat", "intra_neg_moment_sub_feat", "inter_neg_moment_sub_feat"]:
         | 
| 193 | 
            +
                            model_inputs[k] = torch.zeros((2, 2))
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                    if not self.use_sub and not self.use_video and self.use_tef:  # use video stream
         | 
| 196 | 
            +
                        model_inputs["pos_moment_video_feat"] = \
         | 
| 197 | 
            +
                            self.concat_feat_adv(tef=pos_tef, ctx_mode=self.ctx_mode)
         | 
| 198 | 
            +
                        model_inputs["intra_neg_moment_video_feat"] = \
         | 
| 199 | 
            +
                            self.concat_feat_adv(tef=intra_neg_tef, ctx_mode=self.ctx_mode)
         | 
| 200 | 
            +
                        model_inputs["inter_neg_moment_video_feat"] = \
         | 
| 201 | 
            +
                            self.concat_feat_adv(tef=inter_neg_tef, ctx_mode=self.ctx_mode)
         | 
| 202 | 
            +
                    return dict(meta=meta, model_inputs=model_inputs)
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                def align_ts_to_clip_boundaries(self, duration, ts):
         | 
| 205 | 
            +
                    """  # TODO Do we really need this???
         | 
| 206 | 
            +
                    Generate a moment [st, ed] that is most close to a clip boundary,
         | 
| 207 | 
            +
                    st and ed must be a multiple of self.clip_length, and ed <= duration
         | 
| 208 | 
            +
                    duration: float,
         | 
| 209 | 
            +
                    ts: [st (float), ed (float)], ground_truth ts
         | 
| 210 | 
            +
                    """
         | 
| 211 | 
            +
                    clip_aligned_ts = np.array([math.floor(ts[0] / self.clip_length),
         | 
| 212 | 
            +
                                                math.ceil(ts[1] / self.clip_length)]) * self.clip_length
         | 
| 213 | 
            +
                    clip_aligned_ts[1] = min(clip_aligned_ts[1], duration)
         | 
| 214 | 
            +
                    return clip_aligned_ts
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                def sample_intra_neg_moment(self, duration, ts):
         | 
| 217 | 
            +
                    """ Generate a intra negative moment given the video duration and the GT ts.
         | 
| 218 | 
            +
                    The returned moment will be aligned to clip boundaries.
         | 
| 219 | 
            +
                    1) neg_moment has at least 2 clips
         | 
| 220 | 
            +
                    2) its iou with ts should be < self.neg_iou_thd
         | 
| 221 | 
            +
                    Args:
         | 
| 222 | 
            +
                        duration: float
         | 
| 223 | 
            +
                        ts: [st (float), ed (float)], ground_truth ts
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                    Returns:
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                    """
         | 
| 228 | 
            +
                    max_n_search = 5  # search at most max_n_search times, so the program will not be stuck in infinite loops.
         | 
| 229 | 
            +
                    sampled_moments = self.sample_ts_at_clip_boundaries(duration, n_pairs=max_n_search)  # (n_pairs, 2)
         | 
| 230 | 
            +
                    sampled_moments_ious = compute_temporal_iou_batch(sampled_moments, ts)  # (n_pairs, )
         | 
| 231 | 
            +
                    smallest_iou_idx = np.argmin(sampled_moments_ious)
         | 
| 232 | 
            +
                    sampled_moment = sampled_moments[smallest_iou_idx]
         | 
| 233 | 
            +
                    # only a small number (<20 with max_n_search==10) of samples are wrong,
         | 
| 234 | 
            +
                    # usually when the video_duration is too short.
         | 
| 235 | 
            +
                    # if sampled_moments_ious[smallest_iou_idx] >= self.neg_iou_thd:
         | 
| 236 | 
            +
                    #     logger.warning("the sampled intra-neg might be wrong. "
         | 
| 237 | 
            +
                    #                    "v_dur {}, ts {}, sampled neg moment {}, iou {}"
         | 
| 238 | 
            +
                    #                    .format(duration, ts, sampled_moment, sampled_moments_ious[smallest_iou_idx]))
         | 
| 239 | 
            +
                    return sampled_moment
         | 
| 240 | 
            +
             | 
| 241 | 
            +
                def sample_ts_at_clip_boundaries(self, duration, n_pairs=1):
         | 
| 242 | 
            +
                    """sample n_pairs moment at clip boundaries, each has at least two clips."""
         | 
| 243 | 
            +
                    # '+ self.clip_length' since we assume indexing using [clip_st_idx, clip_ed_idx),
         | 
| 244 | 
            +
                    moments = np.random.randint(0, np.ceil(duration / self.clip_length), size=(n_pairs, 2))
         | 
| 245 | 
            +
                    moments = np.sort(moments, axis=1) * self.clip_length
         | 
| 246 | 
            +
                    less_equal = moments[:, 1] - moments[:, 0] <= self.clip_length
         | 
| 247 | 
            +
                    start_zero = moments[:, 0] == 0
         | 
| 248 | 
            +
                    moments[:, 1][less_equal * start_zero] += self.clip_length
         | 
| 249 | 
            +
                    moments[:, 0][less_equal * (start_zero == False)] -= self.clip_length  # keep as bool!!!
         | 
| 250 | 
            +
                    return moments
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                def sample_inter_video_negative(self, pos_vid_name, normalized_pos_moment, desc_id=None):
         | 
| 253 | 
            +
                    """Sample a negative moment --> negative video + similar normalized moment.
         | 
| 254 | 
            +
                    1) they are not from the same video
         | 
| 255 | 
            +
                    Args:
         | 
| 256 | 
            +
                        pos_vid_name: str,
         | 
| 257 | 
            +
                        normalized_pos_moment: np.ndarray, (2, ), value in [0, 1], normalized by duration.
         | 
| 258 | 
            +
                        desc_id: str
         | 
| 259 | 
            +
                    Returns:
         | 
| 260 | 
            +
                        moment: np.ndarray, (2, ), ts aligned to clip boundaries.
         | 
| 261 | 
            +
             | 
| 262 | 
            +
                    """
         | 
| 263 | 
            +
                    use_guided_negative = hasattr(self, "desc_id2video_names_dur_pairs")
         | 
| 264 | 
            +
                    if use_guided_negative:
         | 
| 265 | 
            +
                        top_videos = self.desc_id2video_names_dur_pairs[desc_id]
         | 
| 266 | 
            +
                        max_idx = len(top_videos) - 1
         | 
| 267 | 
            +
             | 
| 268 | 
            +
                    while True:  # usually only run once.
         | 
| 269 | 
            +
                        if use_guided_negative:
         | 
| 270 | 
            +
                            sampled_idx = min(max_idx, int(random.expovariate(0.1)))
         | 
| 271 | 
            +
                            sampled_video_name, sampled_video_dur = top_videos[sampled_idx]
         | 
| 272 | 
            +
                        else:
         | 
| 273 | 
            +
                            neg_vid_data = self.data[int(random.random() * len(self))]
         | 
| 274 | 
            +
                            sampled_video_name, sampled_video_dur = neg_vid_data["vid_name"], neg_vid_data["duration"]
         | 
| 275 | 
            +
                        if sampled_video_name != pos_vid_name:
         | 
| 276 | 
            +
                            inter_neg_moment = self.align_ts_to_clip_boundaries(
         | 
| 277 | 
            +
                                sampled_video_dur, sampled_video_dur * normalized_pos_moment)
         | 
| 278 | 
            +
                            break
         | 
| 279 | 
            +
             | 
| 280 | 
            +
                    return inter_neg_moment, sampled_video_name, sampled_video_dur
         | 
| 281 | 
            +
             | 
| 282 | 
            +
                @classmethod
         | 
| 283 | 
            +
                def get_clip_indices_from_moments(cls, moment, clip_length):
         | 
| 284 | 
            +
                    clip_st_ed_indices = moment / clip_length
         | 
| 285 | 
            +
                    return math.floor(clip_st_ed_indices[0]), math.ceil(clip_st_ed_indices[1])
         | 
| 286 | 
            +
             | 
| 287 | 
            +
                def get_moment_feat(self, vid_feat, moment, normalize=True, fix_outbound=False, pool_local=False):
         | 
| 288 | 
            +
                    """Each moment contains multiple clips.
         | 
| 289 | 
            +
                    Inside means [moment[0], moment[1]] (seconds)
         | 
| 290 | 
            +
                    Args:
         | 
| 291 | 
            +
                        vid_feat: np.ndarray, (N_clips, D)
         | 
| 292 | 
            +
                        moment: [st (float), ed (float)], np.ndarray
         | 
| 293 | 
            +
                        normalize: L2 normalize features
         | 
| 294 | 
            +
                        fix_outbound: bool,
         | 
| 295 | 
            +
                        pool_local: whether to mean pool the features
         | 
| 296 | 
            +
                    Returns:
         | 
| 297 | 
            +
                        moment_feature: np.ndarray, ((moment[1] - moment[0]) / clip_length, D) or (D, )
         | 
| 298 | 
            +
                    """
         | 
| 299 | 
            +
                    clip_st_idx, clip_ed_idx = self.get_clip_indices_from_moments(moment, self.clip_length)
         | 
| 300 | 
            +
                    if fix_outbound:
         | 
| 301 | 
            +
                        vid_feat_len = len(vid_feat)
         | 
| 302 | 
            +
                        if clip_st_idx >= vid_feat_len:
         | 
| 303 | 
            +
                            clip_st_idx = vid_feat_len - 2
         | 
| 304 | 
            +
                    moment_feat = vid_feat[clip_st_idx:clip_ed_idx]  # indexed as [st, ed)
         | 
| 305 | 
            +
                    if pool_local:
         | 
| 306 | 
            +
                        moment_feat = np.mean(moment_feat, axis=0, keepdims=True)
         | 
| 307 | 
            +
                    if normalize:
         | 
| 308 | 
            +
                        moment_feat = l2_normalize_np_array(moment_feat)
         | 
| 309 | 
            +
                    return moment_feat  # (n_clip_in_moment, D) or (D, )
         | 
| 310 | 
            +
             | 
| 311 | 
            +
                @classmethod
         | 
| 312 | 
            +
                def concat_feat_adv(cls, moment_feats=None, tef=None, to_torch=True, ctx_mode="tef"):
         | 
| 313 | 
            +
                    """ Concat moment_feat with other_feats and tef. All the features should be L2 normalized before concatenating
         | 
| 314 | 
            +
                    Args:
         | 
| 315 | 
            +
                        moment_feats: list of feats, one of them might be None. Other possible values are
         | 
| 316 | 
            +
                            ctx_feat (D, ) or sub(vid)_moment_feat (N_p, N_clips, D_t) or (N_clips, D_t).
         | 
| 317 | 
            +
                            The first non-None feature array is used as base for the rest to concatenate with.
         | 
| 318 | 
            +
                        tef: (N_p, 2) or (2, ), np.ndarray
         | 
| 319 | 
            +
                        to_torch: convert resulting np.ndarray to torch.tensor
         | 
| 320 | 
            +
                        ctx_mode:
         | 
| 321 | 
            +
                    """
         | 
| 322 | 
            +
                    if ctx_mode == "tef":
         | 
| 323 | 
            +
                        assembled_feat = np.expand_dims(tef, axis=-2)
         | 
| 324 | 
            +
                    else:  # concat moment_feat with all other_feats
         | 
| 325 | 
            +
                        moment_feats = [e for e in moment_feats if e is not None]  # remove possible None (placeholder)
         | 
| 326 | 
            +
                        extra_dims = moment_feats[0].shape[:-1]  # all others will need to broadcast to match it.
         | 
| 327 | 
            +
                        if isinstance(extra_dims, int):  # happens when len(moment_feat.shape) == 2
         | 
| 328 | 
            +
                            extra_dims = (extra_dims, )
         | 
| 329 | 
            +
                        last_dim_lengths = [0, ] + [e.shape[-1] for e in moment_feats]
         | 
| 330 | 
            +
                        if "tef" in ctx_mode:  # add tef
         | 
| 331 | 
            +
                            last_dim_lengths += [2, ]
         | 
| 332 | 
            +
                            moment_feats += [np.expand_dims(tef, axis=-2), ]
         | 
| 333 | 
            +
             | 
| 334 | 
            +
                        if len(moment_feats) > 1:
         | 
| 335 | 
            +
                            assembled_feat = np.empty(extra_dims + (sum(last_dim_lengths), ), dtype=np.float32)
         | 
| 336 | 
            +
                            last_dim_lengths_cumsum = [sum(last_dim_lengths[0:idx+1]) for idx in range(len(last_dim_lengths))]
         | 
| 337 | 
            +
                            for idx, feat in enumerate(moment_feats):
         | 
| 338 | 
            +
                                assembled_feat[..., last_dim_lengths_cumsum[idx]:last_dim_lengths_cumsum[idx+1]] = feat
         | 
| 339 | 
            +
                        else:
         | 
| 340 | 
            +
                            assembled_feat = moment_feats[0]
         | 
| 341 | 
            +
             | 
| 342 | 
            +
                    if to_torch:
         | 
| 343 | 
            +
                        return torch.from_numpy(assembled_feat)
         | 
| 344 | 
            +
                    else:
         | 
| 345 | 
            +
                        return assembled_feat  # (N_prop, N_clips, D_concat) or (N_clips, D_concat)
         | 
| 346 | 
            +
             | 
| 347 | 
            +
             | 
| 348 | 
            +
            class ProposalRetrievalEvalDataset(Dataset):
         | 
| 349 | 
            +
                """
         | 
| 350 | 
            +
                init_data_mode: `video_query` or `video_only` or `query_only`,
         | 
| 351 | 
            +
                    it indicates which data to load when initialize the Dataset object.
         | 
| 352 | 
            +
                data_mode: `context` or `query`, it indicates which data to return for self.__get_item__()
         | 
| 353 | 
            +
                desc_bert_path_or_handler: h5py.File object or str path
         | 
| 354 | 
            +
                vid_feat_path_or_handler: h5py.File object or str path
         | 
| 355 | 
            +
                eval_proposal_bsz: the proposals for a single video will be sorted in length and batched here with
         | 
| 356 | 
            +
                    max batch size to be eval_proposal_bsz. A single video might have multiple batches of proposals.
         | 
| 357 | 
            +
                load_gt_video: load GroundTruth Video, useful when evaluating single video moment retrieval.
         | 
| 358 | 
            +
                data_ratio: percentage of query data to use.
         | 
| 359 | 
            +
                """
         | 
| 360 | 
            +
                def __init__(self, dset_name, eval_split_name, data_path=None,
         | 
| 361 | 
            +
                             desc_bert_path_or_handler=None, max_desc_len=None,
         | 
| 362 | 
            +
                             sub_bert_path_or_handler=None, vid_feat_path_or_handler=None,
         | 
| 363 | 
            +
                             corpus_path=None, clip_length=None,
         | 
| 364 | 
            +
                             eval_proposal_bsz=None, ctx_mode="tef", data_mode="context",
         | 
| 365 | 
            +
                             h5driver=None, data_ratio=1.0, normalize_vfeat=True,
         | 
| 366 | 
            +
                             normalize_tfeat=True, max_n_proposals=90, model_type="cal"):
         | 
| 367 | 
            +
                    self.dset_name = dset_name
         | 
| 368 | 
            +
                    self.model_type = model_type
         | 
| 369 | 
            +
                    self.pool_local = model_type == "mcn"  # pool local feature
         | 
| 370 | 
            +
                    self.eval_split_name = eval_split_name
         | 
| 371 | 
            +
                    self.ctx_mode = ctx_mode
         | 
| 372 | 
            +
                    self.load_gt_video = False
         | 
| 373 | 
            +
                    self.data_ratio = data_ratio  # only affect query data
         | 
| 374 | 
            +
                    self.normalize_vfeat = normalize_vfeat
         | 
| 375 | 
            +
                    self.normalize_tfeat = normalize_tfeat
         | 
| 376 | 
            +
                    self.max_n_proposals = max_n_proposals
         | 
| 377 | 
            +
             | 
| 378 | 
            +
                    self.data_mode = None
         | 
| 379 | 
            +
                    self.set_data_mode(data_mode)
         | 
| 380 | 
            +
             | 
| 381 | 
            +
                    self.max_desc_len = max_desc_len
         | 
| 382 | 
            +
                    self.data_path = data_path
         | 
| 383 | 
            +
                    self.query_data = load_jsonl(data_path)
         | 
| 384 | 
            +
                    if data_ratio != 1:
         | 
| 385 | 
            +
                        n_examples = int(len(self.query_data) * data_ratio)
         | 
| 386 | 
            +
                        self.query_data = self.query_data[:n_examples]
         | 
| 387 | 
            +
                        logger.info("Using {}% of the data: {} examples".format(data_ratio * 100, n_examples))
         | 
| 388 | 
            +
                    if isinstance(desc_bert_path_or_handler, h5py.File):
         | 
| 389 | 
            +
                        self.desc_bert_h5 = desc_bert_path_or_handler
         | 
| 390 | 
            +
                    else:
         | 
| 391 | 
            +
                        self.desc_bert_h5 = h5py.File(desc_bert_path_or_handler, "r", driver=h5driver)
         | 
| 392 | 
            +
             | 
| 393 | 
            +
                    video_data = load_json(corpus_path)[self.eval_split_name]
         | 
| 394 | 
            +
                    self.video_data = [{"vid_name": k, "duration": v[0]} for k, v in video_data.items()]
         | 
| 395 | 
            +
                    self.video2idx = {k: v[1] for k, v in video_data.items()}
         | 
| 396 | 
            +
                    self.eval_proposal_bsz = eval_proposal_bsz
         | 
| 397 | 
            +
                    self.clip_length = clip_length
         | 
| 398 | 
            +
                    self.proposal_fn = get_proposal_interface(dset_name)
         | 
| 399 | 
            +
             | 
| 400 | 
            +
                    self.use_video = "video" in self.ctx_mode
         | 
| 401 | 
            +
                    self.use_sub = "sub" in self.ctx_mode
         | 
| 402 | 
            +
                    self.use_tef = "tef" in self.ctx_mode
         | 
| 403 | 
            +
             | 
| 404 | 
            +
                    if self.use_video:
         | 
| 405 | 
            +
                        if isinstance(vid_feat_path_or_handler, h5py.File):
         | 
| 406 | 
            +
                            self.vid_feat_h5 = vid_feat_path_or_handler
         | 
| 407 | 
            +
                        else:  # str path
         | 
| 408 | 
            +
                            self.vid_feat_h5 = h5py.File(vid_feat_path_or_handler, "r", driver=h5driver)
         | 
| 409 | 
            +
             | 
| 410 | 
            +
                    if self.use_sub:
         | 
| 411 | 
            +
                        if isinstance(sub_bert_path_or_handler, h5py.File):
         | 
| 412 | 
            +
                            self.sub_bert_h5 = sub_bert_path_or_handler
         | 
| 413 | 
            +
                        else:  # str path
         | 
| 414 | 
            +
                            self.sub_bert_h5 = h5py.File(sub_bert_path_or_handler, "r", driver=h5driver)
         | 
| 415 | 
            +
             | 
| 416 | 
            +
                def set_data_mode(self, data_mode):
         | 
| 417 | 
            +
                    """context or query"""
         | 
| 418 | 
            +
                    assert data_mode in ["context", "query"]
         | 
| 419 | 
            +
                    self.data_mode = data_mode
         | 
| 420 | 
            +
             | 
| 421 | 
            +
                def load_gt_vid_name_for_query(self, load_gt_video):
         | 
| 422 | 
            +
                    """load_gt_video: bool, affect the returned value of self._get_item_query"""
         | 
| 423 | 
            +
                    assert "vid_name" in self.query_data[0]
         | 
| 424 | 
            +
                    self.load_gt_video = load_gt_video
         | 
| 425 | 
            +
             | 
| 426 | 
            +
                def __len__(self):
         | 
| 427 | 
            +
                    if self.data_mode == "context":
         | 
| 428 | 
            +
                        return len(self.video_data)
         | 
| 429 | 
            +
                    else:
         | 
| 430 | 
            +
                        return len(self.query_data)
         | 
| 431 | 
            +
             | 
| 432 | 
            +
                def __getitem__(self, index):
         | 
| 433 | 
            +
                    if self.data_mode == "context":
         | 
| 434 | 
            +
                        return self._get_item_context(index)
         | 
| 435 | 
            +
                    else:
         | 
| 436 | 
            +
                        return self._get_item_query(index)
         | 
| 437 | 
            +
             | 
| 438 | 
            +
                def _get_item_query(self, index):
         | 
| 439 | 
            +
                    """Need to batch"""
         | 
| 440 | 
            +
                    raw_data = self.query_data[index]
         | 
| 441 | 
            +
             | 
| 442 | 
            +
                    meta = dict(
         | 
| 443 | 
            +
                        desc_id=raw_data["desc_id"],
         | 
| 444 | 
            +
                        desc=raw_data["desc"],
         | 
| 445 | 
            +
                        vid_name=raw_data["vid_name"] if self.load_gt_video else None
         | 
| 446 | 
            +
                    )
         | 
| 447 | 
            +
             | 
| 448 | 
            +
                    model_inputs = dict()
         | 
| 449 | 
            +
                    query_feat = self.desc_bert_h5[str(raw_data["desc_id"])][:self.max_desc_len]
         | 
| 450 | 
            +
                    if self.normalize_tfeat:
         | 
| 451 | 
            +
                        query_feat = l2_normalize_np_array(query_feat)
         | 
| 452 | 
            +
                    model_inputs["query_feat"] = torch.from_numpy(query_feat)
         | 
| 453 | 
            +
                    return dict(meta=meta, model_inputs=model_inputs)
         | 
| 454 | 
            +
             | 
| 455 | 
            +
                def _get_item_context(self, index):
         | 
| 456 | 
            +
                    """No need to batch, since it has already been batched here"""
         | 
| 457 | 
            +
                    raw_data = self.video_data[index]
         | 
| 458 | 
            +
             | 
| 459 | 
            +
                    # get proposals and sort in ascending order, to get more efficient batching
         | 
| 460 | 
            +
                    proposals = self.proposal_fn(
         | 
| 461 | 
            +
                        video_id="", metadata={"duration": raw_data["duration"]})  # np.ndarray (N_p, 2)
         | 
| 462 | 
            +
                    proposals_lengths = proposals[:, 1] - proposals[:, 0]  # seconds
         | 
| 463 | 
            +
                    sorted_proposal_indices = np.argsort(proposals_lengths)[:self.max_n_proposals]
         | 
| 464 | 
            +
                    sorted_proposals = proposals[sorted_proposal_indices]
         | 
| 465 | 
            +
             | 
| 466 | 
            +
                    # initialize with basic data
         | 
| 467 | 
            +
                    meta = dict(
         | 
| 468 | 
            +
                        vid_name=raw_data["vid_name"],
         | 
| 469 | 
            +
                        duration=raw_data["duration"],
         | 
| 470 | 
            +
                        proposals=sorted_proposals
         | 
| 471 | 
            +
                    )
         | 
| 472 | 
            +
                    model_inputs = dict()
         | 
| 473 | 
            +
             | 
| 474 | 
            +
                    n_proposal_batches = math.ceil(1.0 * len(sorted_proposals) / self.eval_proposal_bsz)
         | 
| 475 | 
            +
             | 
| 476 | 
            +
                    tef_batched_list = [None, ] * n_proposal_batches
         | 
| 477 | 
            +
                    t_moments_mask_list = [None, ] * n_proposal_batches
         | 
| 478 | 
            +
                    if self.use_tef:
         | 
| 479 | 
            +
                        tef_array = sorted_proposals / meta["duration"]  # (N_p, 2)
         | 
| 480 | 
            +
                        for batch_idx in range(n_proposal_batches):
         | 
| 481 | 
            +
                            st_m_idx = batch_idx * self.eval_proposal_bsz
         | 
| 482 | 
            +
                            ed_m_idx = (batch_idx + 1) * self.eval_proposal_bsz
         | 
| 483 | 
            +
                            tef_batched_list[batch_idx] = tef_array[st_m_idx:ed_m_idx]
         | 
| 484 | 
            +
                            t_moments_mask_list[batch_idx] = \
         | 
| 485 | 
            +
                                np.ones((len(tef_batched_list[batch_idx]), 1), dtype=np.float32)
         | 
| 486 | 
            +
                        if not self.use_video and not self.use_sub:  # use video stream
         | 
| 487 | 
            +
                            model_inputs["video_moment_features_list"] = [
         | 
| 488 | 
            +
                                ProposalRetrievalDataset.concat_feat_adv(tef=t, ctx_mode=self.ctx_mode) for t in tef_batched_list]
         | 
| 489 | 
            +
                            model_inputs["video_moment_mask_list"] = [torch.from_numpy(e) for e in t_moments_mask_list]
         | 
| 490 | 
            +
             | 
| 491 | 
            +
                    # extract/group/pad
         | 
| 492 | 
            +
                    if self.use_video:
         | 
| 493 | 
            +
                        v_feat = self.vid_feat_h5[meta["vid_name"]]  # (N_frm, D)
         | 
| 494 | 
            +
                        v_ctx_feat = np.mean(v_feat, axis=0)  # (D, )
         | 
| 495 | 
            +
                        if self.normalize_vfeat:
         | 
| 496 | 
            +
                            v_ctx_feat = l2_normalize_np_array(v_ctx_feat)
         | 
| 497 | 
            +
                        v_padded_moments_features_list, v_moments_mask_list = \
         | 
| 498 | 
            +
                            self.get_batched_moment_feat_for_all_proposals(v_feat, sorted_proposals,
         | 
| 499 | 
            +
                                                                           pool_local=self.pool_local,
         | 
| 500 | 
            +
                                                                           normalize=self.normalize_vfeat)
         | 
| 501 | 
            +
             | 
| 502 | 
            +
                        model_inputs["video_moment_features_list"] = [ProposalRetrievalDataset.concat_feat_adv(
         | 
| 503 | 
            +
                            moment_feats=[v, v_ctx_feat], tef=t, ctx_mode=self.ctx_mode)
         | 
| 504 | 
            +
                            for v, t in zip(v_padded_moments_features_list, tef_batched_list)]
         | 
| 505 | 
            +
                        model_inputs["video_moment_mask_list"] = [torch.from_numpy(e) for e in v_moments_mask_list]
         | 
| 506 | 
            +
             | 
| 507 | 
            +
                    if self.use_sub:
         | 
| 508 | 
            +
                        s_feat = self.sub_bert_h5[meta["vid_name"]]  # (N_frm, D)
         | 
| 509 | 
            +
                        s_ctx_feat = np.mean(s_feat, axis=0)  # (D, )
         | 
| 510 | 
            +
                        if self.normalize_tfeat:
         | 
| 511 | 
            +
                            s_ctx_feat = l2_normalize_np_array(s_ctx_feat)
         | 
| 512 | 
            +
                        s_padded_moments_features_list, s_moments_mask_list = \
         | 
| 513 | 
            +
                            self.get_batched_moment_feat_for_all_proposals(s_feat, sorted_proposals,
         | 
| 514 | 
            +
                                                                           pool_local=self.pool_local,
         | 
| 515 | 
            +
                                                                           normalize=self.normalize_tfeat)
         | 
| 516 | 
            +
                        model_inputs["sub_moment_features_list"] = [ProposalRetrievalDataset.concat_feat_adv(
         | 
| 517 | 
            +
                            moment_feats=[s, s_ctx_feat], tef=t, ctx_mode=self.ctx_mode)
         | 
| 518 | 
            +
                            for s, t in zip(s_padded_moments_features_list, tef_batched_list)]
         | 
| 519 | 
            +
                        model_inputs["sub_moment_mask_list"] = [torch.from_numpy(e) for e in s_moments_mask_list]
         | 
| 520 | 
            +
                    return dict(meta=meta, model_inputs=model_inputs)
         | 
| 521 | 
            +
             | 
| 522 | 
            +
                def get_batched_moment_feat_for_all_proposals(self, feature, moments, pool_local=False, normalize=True):
         | 
| 523 | 
            +
                    """proposals of the same video wil be segmented into multiple batches to accomodate GPU memory
         | 
| 524 | 
            +
                    pool_local: pool local feature into a single vector
         | 
| 525 | 
            +
                    """
         | 
| 526 | 
            +
                    n_proposal_batches = math.ceil(1.0 * len(moments) / self.eval_proposal_bsz)
         | 
| 527 | 
            +
                    padded_moments_features_list = [None, ] * n_proposal_batches
         | 
| 528 | 
            +
                    moments_mask_list = [None, ] * n_proposal_batches
         | 
| 529 | 
            +
                    moments_features = self.get_moment_feat_for_all_proposals(
         | 
| 530 | 
            +
                        feature, moments, normalize=normalize, pool_local=pool_local)  # N_p * [(N_clips, D), ]
         | 
| 531 | 
            +
                    for batch_idx in range(n_proposal_batches):
         | 
| 532 | 
            +
                        st_m_idx = batch_idx * self.eval_proposal_bsz
         | 
| 533 | 
            +
                        ed_m_idx = (batch_idx + 1) * self.eval_proposal_bsz
         | 
| 534 | 
            +
                        padded_moments_features, moments_mask = \
         | 
| 535 | 
            +
                            pad_sequences_1d(moments_features[st_m_idx:ed_m_idx], dtype=np.float32)
         | 
| 536 | 
            +
                        padded_moments_features_list[batch_idx] = padded_moments_features
         | 
| 537 | 
            +
                        moments_mask_list[batch_idx] = moments_mask
         | 
| 538 | 
            +
                        assert np.sum(np.sum(moments_mask, axis=1) == 0) == 0, " err {}".format(moments_mask)
         | 
| 539 | 
            +
                    assert np.sum(np.sum(moments_mask_list[0], axis=1) == 0) == 0, " err {}".format(moments_mask_list)
         | 
| 540 | 
            +
                    return padded_moments_features_list, moments_mask_list
         | 
| 541 | 
            +
             | 
| 542 | 
            +
                def get_moment_feat_for_all_proposals(self, vid_feat, moments, normalize=True, pool_local=False):
         | 
| 543 | 
            +
                    """Each moment is comprised of multiple clips
         | 
| 544 | 
            +
                    Args:
         | 
| 545 | 
            +
                        vid_feat: np.ndarray, (N_clips, D)
         | 
| 546 | 
            +
                        moments: np.ndarray, (N_p, 2), each row is [st (float), ed (float)],
         | 
| 547 | 
            +
                        normalize: L2 normalize
         | 
| 548 | 
            +
                        pool_local:
         | 
| 549 | 
            +
                    Returns:
         | 
| 550 | 
            +
                        moments_features: list(np.ndarray), [(N_clips, D), ] * N_p, N_clips is changing.
         | 
| 551 | 
            +
                    """
         | 
| 552 | 
            +
                    if normalize and not pool_local:
         | 
| 553 | 
            +
                        vid_feat = l2_normalize_np_array(vid_feat)
         | 
| 554 | 
            +
                    vid_feat_len = len(vid_feat)
         | 
| 555 | 
            +
                    moments_st_clip_indices = np.floor(moments[:, 0] / self.clip_length).astype(np.int64).clip(0, vid_feat_len-1)
         | 
| 556 | 
            +
                    moments_ed_clip_indices = np.ceil(moments[:, 1] / self.clip_length).astype(np.int64).clip(1, vid_feat_len)
         | 
| 557 | 
            +
                    moments_features = []
         | 
| 558 | 
            +
                    for st_idx, ed_idx, m in zip(moments_st_clip_indices, moments_ed_clip_indices, moments):
         | 
| 559 | 
            +
                        feat = vid_feat[st_idx:ed_idx]
         | 
| 560 | 
            +
                        if pool_local:
         | 
| 561 | 
            +
                            feat = np.mean(feat, axis=0, keepdims=True)
         | 
| 562 | 
            +
                            if normalize:
         | 
| 563 | 
            +
                                feat = l2_normalize_np_array(feat)
         | 
| 564 | 
            +
                        moments_features.append(feat)
         | 
| 565 | 
            +
                    return moments_features
         | 
| 566 | 
            +
             | 
| 567 | 
            +
             | 
| 568 | 
            +
            def proposal_retrieval_collate(batch):
         | 
| 569 | 
            +
                batch_meta = [e["meta"] for e in batch]  # seems no need to collate ?
         | 
| 570 | 
            +
             | 
| 571 | 
            +
                model_inputs_keys = batch[0]["model_inputs"].keys()
         | 
| 572 | 
            +
                batched_data = {k: pad_sequences_1d([e["model_inputs"][k] for e in batch], dtype=torch.float32)
         | 
| 573 | 
            +
                                for k in model_inputs_keys}
         | 
| 574 | 
            +
                return batch_meta, batched_data
         | 
| 575 | 
            +
             | 
| 576 | 
            +
             | 
| 577 | 
            +
            def prepare_batch_inputs(batched_model_inputs, device, non_blocking=False):
         | 
| 578 | 
            +
                model_inputs = {}
         | 
| 579 | 
            +
                for k, v in batched_model_inputs.items():
         | 
| 580 | 
            +
                    model_inputs[k] = v[0].to(device, non_blocking=non_blocking)
         | 
| 581 | 
            +
                    model_inputs[k.replace("feat", "mask")] = v[1].to(device, non_blocking=non_blocking)
         | 
| 582 | 
            +
                return model_inputs
         | 
| 583 | 
            +
             | 
| 584 | 
            +
             | 
| 585 | 
            +
            if __name__ == '__main__':
         | 
| 586 | 
            +
                from baselines.clip_alignment_with_language.config import BaseOptions
         | 
| 587 | 
            +
                options = BaseOptions().parse()
         | 
    	
        baselines/clip_alignment_with_language/scripts/compute_upper_bound.sh
    ADDED
    
    | @@ -0,0 +1,23 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env bash
         | 
| 2 | 
            +
            # run at project root dir
         | 
| 3 | 
            +
            dset_name=$1  # see case below
         | 
| 4 | 
            +
            split_name=$2  # train/val/test, some datasets may not support all the 3 splits
         | 
| 5 | 
            +
            result_dir="baselines/clip_alignment_with_language/results"
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            echo "Running with dataset ${dset_name} with split ${split_name}"
         | 
| 8 | 
            +
            case ${dset_name} in
         | 
| 9 | 
            +
                tvr)  # only supports train/val
         | 
| 10 | 
            +
                    eval_file_path=data/tvr_${split_name}_release.jsonl
         | 
| 11 | 
            +
                    save_path=${result_dir}/tvr_${split_name}_proposal_upper_bound.json
         | 
| 12 | 
            +
                    ;;
         | 
| 13 | 
            +
                *)
         | 
| 14 | 
            +
                    echo -n "Unknown argument"
         | 
| 15 | 
            +
                    ;;
         | 
| 16 | 
            +
            esac
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            echo "Running evaluation"
         | 
| 19 | 
            +
            python baselines/clip_alignment_with_language/local_utils/compute_proposal_upper_bound.py \
         | 
| 20 | 
            +
            -dset_name=${dset_name} \
         | 
| 21 | 
            +
            -eval_file_path=${eval_file_path} \
         | 
| 22 | 
            +
            -save_path=${save_path} \
         | 
| 23 | 
            +
            -verbose
         | 
    	
        baselines/clip_alignment_with_language/scripts/inference.sh
    ADDED
    
    | @@ -0,0 +1,17 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env bash
         | 
| 2 | 
            +
            # run at project root dir
         | 
| 3 | 
            +
            # Usage:
         | 
| 4 | 
            +
            # bash baselines/clip_alignment_with_language/scripts/inference.sh ANY_OTHER_PYTHON_ARGS
         | 
| 5 | 
            +
            model_dir=$1
         | 
| 6 | 
            +
            eval_split_name=$2
         | 
| 7 | 
            +
            eval_path=data/tvr_${eval_split_name}_release.jsonl
         | 
| 8 | 
            +
            tasks=(VR)
         | 
| 9 | 
            +
            tasks+=(SVMR)
         | 
| 10 | 
            +
            tasks+=(VCMR)
         | 
| 11 | 
            +
            echo "tasks ${tasks[@]}"
         | 
| 12 | 
            +
            python baselines/clip_alignment_with_language/inference.py \
         | 
| 13 | 
            +
            --model_dir ${model_dir} \
         | 
| 14 | 
            +
            --tasks ${tasks[@]} \
         | 
| 15 | 
            +
            --eval_split_name ${eval_split_name} \
         | 
| 16 | 
            +
            --eval_path ${eval_path} \
         | 
| 17 | 
            +
            ${@:3}
         | 
    	
        baselines/clip_alignment_with_language/scripts/inference_mix.sh
    ADDED
    
    | @@ -0,0 +1,27 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env bash
         | 
| 2 | 
            +
            # run at project root dir
         | 
| 3 | 
            +
            # Usage:
         | 
| 4 | 
            +
            # bash baselines/clip_alignment_with_language/scripts/inference_mix.sh
         | 
| 5 | 
            +
            eval_model=$1  # [mcn, cal], retrain models should only be paired with mee
         | 
| 6 | 
            +
            project_root=/net/bvisionserver14/playpen-ssd/jielei/projects/video_retrieval/baselines/clip_alignment_with_language/results
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            # setup eval model
         | 
| 9 | 
            +
            if [[ ${eval_model} == mcn ]]; then
         | 
| 10 | 
            +
                pred_dir=tvr-mcn-video_sub-res-2019_11_05_14_16_40
         | 
| 11 | 
            +
                tef_pred_dir=tvr-mcn-video_sub_tef-res-2019_11_05_14_14_57
         | 
| 12 | 
            +
            elif [[ ${eval_model} == cal ]]; then
         | 
| 13 | 
            +
                pred_dir=tvr-cal-video_sub-res-2019_11_05_14_32_59
         | 
| 14 | 
            +
                tef_pred_dir=tvr-cal-video_sub_tef-res-2019_11_05_14_25_49
         | 
| 15 | 
            +
            fi
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            pred_path=${project_root}/${pred_dir}/inference_tvr_test_public_max200_predictions_VR_SVMR_VCMR.json
         | 
| 18 | 
            +
            save_path=${project_root}/${pred_dir}/inference_tvr_test_public_max200_predictions_VR_SVMR_VCMR_rerank_${tef_pred_dir}.json
         | 
| 19 | 
            +
            tef_pred_path=${project_root}/${tef_pred_dir}/inference_tvr_test_public_max10000_predictions_VCMR.pt
         | 
| 20 | 
            +
            gt_path=data/tvr_test_public_archive.jsonl
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            python baselines/clip_alignment_with_language/mix_model_prediction.py \
         | 
| 24 | 
            +
            --pred_path=${pred_path} \
         | 
| 25 | 
            +
            --tef_pred_path=${tef_pred_path} \
         | 
| 26 | 
            +
            --gt_path=${gt_path} \
         | 
| 27 | 
            +
            --save_path=${save_path}
         | 
    	
        baselines/clip_alignment_with_language/scripts/inference_with_external.sh
    ADDED
    
    | @@ -0,0 +1,54 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env bash
         | 
| 2 | 
            +
            # run at project root dir
         | 
| 3 | 
            +
            # Usage:
         | 
| 4 | 
            +
            # bash baselines/clip_alignment_with_language/scripts/inference_with_external.sh
         | 
| 5 | 
            +
            #model_dir=$1
         | 
| 6 | 
            +
            # DO not use NMS, since it gives worse results
         | 
| 7 | 
            +
            eval_model=$1  # [mcn, mcn_tef, cal, cal_tef, mcn_retrain, cal_retrain], retrain models should only be paired with mee
         | 
| 8 | 
            +
            external_model=$2  # [mee, mcn, cal]
         | 
| 9 | 
            +
            eval_split_name=$3
         | 
| 10 | 
            +
            eval_path=data/tvr_${eval_split_name}_release.jsonl
         | 
| 11 | 
            +
            project_root=/net/bvisionserver14/playpen-ssd/jielei/projects/video_retrieval/baselines
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            # setup eval model
         | 
| 14 | 
            +
            if [[ ${eval_model} == mcn ]]; then
         | 
| 15 | 
            +
                eval_model_dir=tvr-mcn-video_sub-res-2019_11_05_14_16_40
         | 
| 16 | 
            +
            elif [[ ${eval_model} == mcn_tef ]]; then
         | 
| 17 | 
            +
                eval_model_dir=tvr-mcn-video_sub_tef-res-2019_11_05_14_14_57
         | 
| 18 | 
            +
            elif [[ ${eval_model} == cal ]]; then
         | 
| 19 | 
            +
                eval_model_dir=tvr-cal-video_sub-res-2019_11_05_14_32_59
         | 
| 20 | 
            +
            elif [[ ${eval_model} == cal_tef ]]; then
         | 
| 21 | 
            +
                eval_model_dir=tvr-cal-video_sub_tef-res-2019_11_05_14_25_49
         | 
| 22 | 
            +
            elif [[ ${eval_model} == mcn_tef_retrain ]]; then
         | 
| 23 | 
            +
                eval_model_dir=tvr-mcn-video_sub_tef-+ex_vr_mee_tvr-video_sub-res-2019_11_06_00_33_39_tvr-mcn-video_sub_tef-res-2019_11_05_14_14_57+-2019_11_06_02_26_49
         | 
| 24 | 
            +
            elif [[ ${eval_model} == cal_tef_retrain ]]; then
         | 
| 25 | 
            +
                eval_model_dir=tvr-cal-video_sub_tef-+ex_vr_mee_tvr-video_sub-res-2019_11_06_00_33_39_tvr-cal-video_sub_tef-res-2019_11_05_14_25_49+-2019_11_06_03_12_15
         | 
| 26 | 
            +
            fi
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            # setup external
         | 
| 29 | 
            +
            if [[ ${external_model} == mee ]]; then
         | 
| 30 | 
            +
                external_model_dir=tvr-video_sub-res-2019_11_06_00_33_39
         | 
| 31 | 
            +
                external_inference_vr_res_path=${project_root}/mixture_embedding_experts/results/${external_model_dir}/inference_tvr_${eval_split_name}_None_predictions_VR.json
         | 
| 32 | 
            +
            elif [[ ${external_model} == mcn ]]; then
         | 
| 33 | 
            +
                external_model_dir=tvr-mcn-video_sub-res-2019_11_05_14_16_40
         | 
| 34 | 
            +
                external_inference_vr_res_path=${project_root}/clip_alignment_with_language/results/${external_model_dir}/inference_tvr_${eval_split_name}_None_predictions_VR_SVMR_VCMR.json
         | 
| 35 | 
            +
            elif [[ ${external_model} == cal ]]; then
         | 
| 36 | 
            +
                external_model_dir=tvr-cal-video_sub-res-2019_11_05_14_32_59
         | 
| 37 | 
            +
                external_inference_vr_res_path=${project_root}/clip_alignment_with_language/results/${external_model_dir}/inference_tvr_${eval_split_name}_None_predictions_VR_SVMR_VCMR.json
         | 
| 38 | 
            +
            fi
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            tasks=(VR)
         | 
| 41 | 
            +
            tasks+=(SVMR)
         | 
| 42 | 
            +
            tasks+=(VCMR)
         | 
| 43 | 
            +
            echo "tasks ${tasks[@]}"
         | 
| 44 | 
            +
            python baselines/clip_alignment_with_language/inference.py \
         | 
| 45 | 
            +
            --model_dir ${eval_model_dir} \
         | 
| 46 | 
            +
            --tasks ${tasks[@]} \
         | 
| 47 | 
            +
            --eval_split_name ${eval_split_name} \
         | 
| 48 | 
            +
            --eval_path ${eval_path} \
         | 
| 49 | 
            +
            --external_inference_vr_res_path ${external_inference_vr_res_path} \
         | 
| 50 | 
            +
            --eval_id ${external_model_dir} \
         | 
| 51 | 
            +
            ${@:4}
         | 
| 52 | 
            +
             | 
| 53 | 
            +
            #--use_intermediate \  # temporary removed
         | 
| 54 | 
            +
             | 
    	
        baselines/clip_alignment_with_language/scripts/re_train_cal.sh
    ADDED
    
    | @@ -0,0 +1,21 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env bash
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            lr=0.00005
         | 
| 4 | 
            +
            n_epoch=20
         | 
| 5 | 
            +
            project_root=/net/bvisionserver14/playpen-ssd/jielei/projects/video_retrieval
         | 
| 6 | 
            +
            ckpt_filename="model.ckpt"
         | 
| 7 | 
            +
            init_ckpt_path=${project_root}/baselines/clip_alignment_with_language/results/tvr-cal-video_sub_tef-res-2019_11_05_14_25_49/${ckpt_filename}
         | 
| 8 | 
            +
            exp_id=+ex_vr_mee_tvr-video_sub-res-2019_11_06_00_33_39_tvr-cal-video_sub_tef-res-2019_11_05_14_25_49+
         | 
| 9 | 
            +
            external_train_vr_res_path=${project_root}/baselines/mixture_embedding_experts/results/tvr-video_sub-res-2019_11_06_00_33_39/inference_tvr_train_None_predictions_VR.json
         | 
| 10 | 
            +
            model_type=cal
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            bash baselines/clip_alignment_with_language/scripts/train.sh tvr video_sub_tef resnet_i3d \
         | 
| 13 | 
            +
            --no_norm_vfeat \
         | 
| 14 | 
            +
            --model_type ${model_type} \
         | 
| 15 | 
            +
            --exp_id ${exp_id} \
         | 
| 16 | 
            +
            --init_ckpt_path ${init_ckpt_path} \
         | 
| 17 | 
            +
            --external_train_vr_res_path ${external_train_vr_res_path} \
         | 
| 18 | 
            +
            --lr ${lr} \
         | 
| 19 | 
            +
            --n_epoch ${n_epoch} \
         | 
| 20 | 
            +
            --max_es_cnt 5 \
         | 
| 21 | 
            +
            ${@:1}
         | 
    	
        baselines/clip_alignment_with_language/scripts/re_train_mcn.sh
    ADDED
    
    | @@ -0,0 +1,21 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env bash
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            lr=0.00005
         | 
| 4 | 
            +
            n_epoch=20
         | 
| 5 | 
            +
            project_root=/net/bvisionserver14/playpen-ssd/jielei/projects/video_retrieval
         | 
| 6 | 
            +
            ckpt_filename="model.ckpt"
         | 
| 7 | 
            +
            init_ckpt_path=${project_root}/baselines/clip_alignment_with_language/results/tvr-mcn-video_sub_tef-res-2019_11_05_14_14_57/${ckpt_filename}
         | 
| 8 | 
            +
            exp_id=+ex_vr_mee_tvr-video_sub-res-2019_11_06_00_33_39_tvr-mcn-video_sub_tef-res-2019_11_05_14_14_57+
         | 
| 9 | 
            +
            external_train_vr_res_path=${project_root}/baselines/mixture_embedding_experts/results/tvr-video_sub-res-2019_11_06_00_33_39/inference_tvr_train_None_predictions_VR.json
         | 
| 10 | 
            +
            model_type=mcn
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            bash baselines/clip_alignment_with_language/scripts/train.sh tvr video_sub_tef resnet_i3d \
         | 
| 13 | 
            +
            --no_norm_vfeat \
         | 
| 14 | 
            +
            --model_type ${model_type} \
         | 
| 15 | 
            +
            --exp_id ${exp_id} \
         | 
| 16 | 
            +
            --init_ckpt_path ${init_ckpt_path} \
         | 
| 17 | 
            +
            --external_train_vr_res_path ${external_train_vr_res_path} \
         | 
| 18 | 
            +
            --lr ${lr} \
         | 
| 19 | 
            +
            --n_epoch ${n_epoch} \
         | 
| 20 | 
            +
            --max_es_cnt 5 \
         | 
| 21 | 
            +
            ${@:1}
         | 
    	
        baselines/clip_alignment_with_language/scripts/train.sh
    ADDED
    
    | @@ -0,0 +1,80 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env bash
         | 
| 2 | 
            +
            # run at project root dir
         | 
| 3 | 
            +
            # Usage:
         | 
| 4 | 
            +
            # bash baselines/clip_alignment_with_language/scripts/train.sh tvr all ANY_OTHER_PYTHON_ARGS
         | 
| 5 | 
            +
            # if re-training, please also give --init_ckpt_path and --external_train_vr_res_path, may also use lower lr ?
         | 
| 6 | 
            +
            dset_name=$1  # see case below
         | 
| 7 | 
            +
            ctx_mode=$2  # ["video", "sub", "tef", "video_sub", "video_tef", "sub_tef", "video_sub_tef"]
         | 
| 8 | 
            +
            vid_feat_type=$3  # [resnet, i3d, resnet_i3d, none] , none for subtitles only models
         | 
| 9 | 
            +
            feature_root=data/tvr_feature_release
         | 
| 10 | 
            +
            results_root=baselines/clip_alignment_with_language/results
         | 
| 11 | 
            +
            vid_feat_size=2048
         | 
| 12 | 
            +
            extra_args=()
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            if [[ ${ctx_mode} == *"sub"* ]] || [[ ${ctx_mode} == "sub" ]]; then
         | 
| 15 | 
            +
                if [[ ${dset_name} != "tvr" ]]; then
         | 
| 16 | 
            +
                    echo "The use of subtitles is only supported in tvr."
         | 
| 17 | 
            +
                    exit 1
         | 
| 18 | 
            +
                fi
         | 
| 19 | 
            +
            fi
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            case ${dset_name} in
         | 
| 23 | 
            +
                tvr)
         | 
| 24 | 
            +
                    train_path=data/tvr_train_release.jsonl
         | 
| 25 | 
            +
                    corpus_path=data/tvr_video2dur_idx.json
         | 
| 26 | 
            +
                    desc_bert_path=${feature_root}/bert_feature/query_only/tvr_query_pretrained_w_query.h5
         | 
| 27 | 
            +
                    vid_feat_path=${feature_root}/video_feature/tvr_resnet152_rgb_max_cl-1.5.h5
         | 
| 28 | 
            +
                    clip_length=1.5
         | 
| 29 | 
            +
                    eval_split_name=val
         | 
| 30 | 
            +
                    nms_thd=-1
         | 
| 31 | 
            +
                    extra_args+=(--eval_path)
         | 
| 32 | 
            +
                    extra_args+=(data/tvr_val_release.jsonl)
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                    if [[ ${vid_feat_type} == "i3d" ]]; then
         | 
| 35 | 
            +
                        echo "Using I3D feature with shape 1024"
         | 
| 36 | 
            +
                        vid_feat_path=${feature_root}/video_feature/tvr_i3d_rgb600_avg_cl-1.5.h5
         | 
| 37 | 
            +
                        vid_feat_size=1024
         | 
| 38 | 
            +
                    elif [[ ${vid_feat_type} == "resnet" ]]; then
         | 
| 39 | 
            +
                        echo "Using ResNet feature with shape 2048"
         | 
| 40 | 
            +
                        vid_feat_path=${feature_root}/video_feature/tvr_resnet152_rgb_max_cl-1.5.h5
         | 
| 41 | 
            +
                        vid_feat_size=2048
         | 
| 42 | 
            +
                    elif [[ ${vid_feat_type} == "resnet_i3d" ]]; then
         | 
| 43 | 
            +
                        echo "Using concatenated ResNet and I3D feature with shape 2048+1024"
         | 
| 44 | 
            +
                        vid_feat_path=${feature_root}/video_feature/tvr_resnet152_rgb_max_i3d_rgb600_avg_cat_cl-1.5.h5
         | 
| 45 | 
            +
                        vid_feat_size=3072
         | 
| 46 | 
            +
                        extra_args+=(--no_norm_vfeat)  # since they are already normalized.
         | 
| 47 | 
            +
                    fi
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                    if [[ ${ctx_mode} == *"sub"* ]] || [[ ${ctx_mode} == "sub" ]]; then
         | 
| 50 | 
            +
                        echo "Running with sub."
         | 
| 51 | 
            +
                        desc_bert_path=${feature_root}/bert_feature/sub_query/tvr_query_pretrained_w_sub_query.h5  # overwrite
         | 
| 52 | 
            +
                        sub_bert_path=${feature_root}/bert_feature/sub_query/tvr_sub_pretrained_w_sub_query_max_cl-1.5.h5
         | 
| 53 | 
            +
                        sub_feat_size=768
         | 
| 54 | 
            +
                        extra_args+=(--sub_feat_size)
         | 
| 55 | 
            +
                        extra_args+=(${sub_feat_size})
         | 
| 56 | 
            +
                        extra_args+=(--sub_bert_path)
         | 
| 57 | 
            +
                        extra_args+=(${sub_bert_path})
         | 
| 58 | 
            +
                    fi
         | 
| 59 | 
            +
                    ;;
         | 
| 60 | 
            +
                *)
         | 
| 61 | 
            +
                    echo -n "Unknown argument"
         | 
| 62 | 
            +
                    ;;
         | 
| 63 | 
            +
            esac
         | 
| 64 | 
            +
             | 
| 65 | 
            +
            echo "Start training with dataset [${dset_name}] in Context Mode [${ctx_mode}]"
         | 
| 66 | 
            +
            echo "Extra args ${extra_args[@]}"
         | 
| 67 | 
            +
            python baselines/clip_alignment_with_language/train.py \
         | 
| 68 | 
            +
            --dset_name=${dset_name} \
         | 
| 69 | 
            +
            --eval_split_name=${eval_split_name} \
         | 
| 70 | 
            +
            --nms_thd=${nms_thd} \
         | 
| 71 | 
            +
            --results_root=${results_root} \
         | 
| 72 | 
            +
            --train_path=${train_path} \
         | 
| 73 | 
            +
            --desc_bert_path=${desc_bert_path} \
         | 
| 74 | 
            +
            --corpus_path=${corpus_path} \
         | 
| 75 | 
            +
            --vid_feat_path=${vid_feat_path} \
         | 
| 76 | 
            +
            --clip_length=${clip_length} \
         | 
| 77 | 
            +
            --vid_feat_size=${vid_feat_size} \
         | 
| 78 | 
            +
            --ctx_mode=${ctx_mode} \
         | 
| 79 | 
            +
            ${extra_args[@]} \
         | 
| 80 | 
            +
            ${@:4}
         | 
    	
        baselines/clip_alignment_with_language/train.py
    ADDED
    
    | @@ -0,0 +1,310 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import time
         | 
| 3 | 
            +
            import json
         | 
| 4 | 
            +
            import pprint
         | 
| 5 | 
            +
            import random
         | 
| 6 | 
            +
            import numpy as np
         | 
| 7 | 
            +
            from collections import OrderedDict
         | 
| 8 | 
            +
            from easydict import EasyDict as EDict
         | 
| 9 | 
            +
            from tqdm import tqdm, trange
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import torch
         | 
| 12 | 
            +
            import torch.nn as nn
         | 
| 13 | 
            +
            import torch.backends.cudnn as cudnn
         | 
| 14 | 
            +
            from torch.utils.data import DataLoader
         | 
| 15 | 
            +
            from torch.utils.tensorboard import SummaryWriter
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            from baselines.clip_alignment_with_language.config import BaseOptions
         | 
| 18 | 
            +
            from baselines.clip_alignment_with_language.model import CALWithSub
         | 
| 19 | 
            +
            from baselines.clip_alignment_with_language.proposal_retrieval_dataset import \
         | 
| 20 | 
            +
                ProposalRetrievalDataset, proposal_retrieval_collate, ProposalRetrievalEvalDataset, prepare_batch_inputs
         | 
| 21 | 
            +
            from baselines.clip_alignment_with_language.inference import eval_epoch, start_inference
         | 
| 22 | 
            +
            from utils.basic_utils import save_jsonl, save_json, AverageMeter
         | 
| 23 | 
            +
            from utils.model_utils import count_parameters
         | 
| 24 | 
            +
             | 
| 25 | 
            +
             | 
| 26 | 
            +
            import logging
         | 
| 27 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 28 | 
            +
            logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
         | 
| 29 | 
            +
                                datefmt="%Y-%m-%d %H:%M:%S",
         | 
| 30 | 
            +
                                level=logging.INFO)
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            def set_seed(seed, use_cuda=True):
         | 
| 34 | 
            +
                random.seed(seed)
         | 
| 35 | 
            +
                np.random.seed(seed)
         | 
| 36 | 
            +
                torch.manual_seed(seed)
         | 
| 37 | 
            +
                if use_cuda:
         | 
| 38 | 
            +
                    torch.cuda.manual_seed_all(seed)
         | 
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
            +
            def train_epoch(model, train_loader, optimizer, opt, epoch_i):
         | 
| 42 | 
            +
                model.train()
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                # init meters
         | 
| 45 | 
            +
                dataloading_time = AverageMeter()
         | 
| 46 | 
            +
                prepare_inputs_time = AverageMeter()
         | 
| 47 | 
            +
                model_forward_time = AverageMeter()
         | 
| 48 | 
            +
                model_backward_time = AverageMeter()
         | 
| 49 | 
            +
                loss_meter = AverageMeter()
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                num_training_examples = len(train_loader)
         | 
| 52 | 
            +
                timer_dataloading = time.time()
         | 
| 53 | 
            +
                for batch_idx, batch in tqdm(enumerate(train_loader),
         | 
| 54 | 
            +
                                             desc="Training Iteration",
         | 
| 55 | 
            +
                                             total=num_training_examples):
         | 
| 56 | 
            +
                    dataloading_time.update(time.time() - timer_dataloading)
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                    # continue
         | 
| 59 | 
            +
                    timer_start = time.time()
         | 
| 60 | 
            +
                    model_inputs = prepare_batch_inputs(batch[1], opt.device, non_blocking=opt.pin_memory)
         | 
| 61 | 
            +
                    prepare_inputs_time.update(time.time() - timer_start)
         | 
| 62 | 
            +
                    # logger.info("model_inputs {}"
         | 
| 63 | 
            +
                    #             .format({k: (type(k), v.shape if isinstance(v, torch.Tensor) else v)
         | 
| 64 | 
            +
                    #                      for k, v in model_inputs.items()}))
         | 
| 65 | 
            +
                    # logger.info("model_inputs \n{}".format({k: (type(v), v.shape, v.dtype) for k, v in model_inputs.items()}))
         | 
| 66 | 
            +
                    timer_start = time.time()
         | 
| 67 | 
            +
                    loss = model(**model_inputs)
         | 
| 68 | 
            +
                    model_forward_time.update(time.time() - timer_start)
         | 
| 69 | 
            +
                    timer_start = time.time()
         | 
| 70 | 
            +
                    optimizer.zero_grad()
         | 
| 71 | 
            +
                    loss.backward()
         | 
| 72 | 
            +
                    if opt.grad_clip != -1:
         | 
| 73 | 
            +
                        nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
         | 
| 74 | 
            +
                    optimizer.step()
         | 
| 75 | 
            +
                    model_backward_time.update(time.time() - timer_start)
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                    global_step = epoch_i * num_training_examples + batch_idx
         | 
| 78 | 
            +
                    opt.writer.add_scalar("Train/LR", float(optimizer.param_groups[0]["lr"]), global_step)
         | 
| 79 | 
            +
                    opt.writer.add_scalar("Train/Loss", float(loss), global_step)
         | 
| 80 | 
            +
                    loss_meter.update(float(loss))
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                    timer_dataloading = time.time()
         | 
| 83 | 
            +
                    if opt.debug and batch_idx == 3:
         | 
| 84 | 
            +
                        break
         | 
| 85 | 
            +
                to_write = opt.train_log_txt_formatter.format(
         | 
| 86 | 
            +
                    time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
         | 
| 87 | 
            +
                    epoch=epoch_i,
         | 
| 88 | 
            +
                    loss_str=str(loss_meter.avg))
         | 
| 89 | 
            +
                with open(opt.train_log_filepath, "a") as f:
         | 
| 90 | 
            +
                    f.write(to_write)
         | 
| 91 | 
            +
                print("Epoch time stats:")
         | 
| 92 | 
            +
                print("dataloading_time: max {dataloading_time.max} "
         | 
| 93 | 
            +
                      "min {dataloading_time.min} avg {dataloading_time.avg}\n"
         | 
| 94 | 
            +
                      "prepare_inputs_time: max {prepare_inputs_time.max} "
         | 
| 95 | 
            +
                      "min {prepare_inputs_time.min} avg {prepare_inputs_time.avg}\n"
         | 
| 96 | 
            +
                      "model_forward_time: max {model_forward_time.max} "
         | 
| 97 | 
            +
                      "min {model_forward_time.min} avg {model_forward_time.avg}\n"
         | 
| 98 | 
            +
                      "model_backward_time: max {model_backward_time.max} "
         | 
| 99 | 
            +
                      "min {model_backward_time.min} avg {model_backward_time.avg}\n"
         | 
| 100 | 
            +
                      "".format(dataloading_time=dataloading_time, prepare_inputs_time=prepare_inputs_time,
         | 
| 101 | 
            +
                                model_forward_time=model_forward_time, model_backward_time=model_backward_time))
         | 
| 102 | 
            +
             | 
| 103 | 
            +
             | 
| 104 | 
            +
            def train(model, train_dataset, val_dataset, opt):
         | 
| 105 | 
            +
                # Prepare optimizer
         | 
| 106 | 
            +
                optimizer = torch.optim.SGD(
         | 
| 107 | 
            +
                    filter(lambda p: p.requires_grad, model.parameters()),
         | 
| 108 | 
            +
                    lr=opt.lr,
         | 
| 109 | 
            +
                    weight_decay=opt.wd,
         | 
| 110 | 
            +
                    momentum=opt.momentum)
         | 
| 111 | 
            +
                # reduce the lr by 0.1 every 30 epochs
         | 
| 112 | 
            +
                scheduler = torch.optim.lr_scheduler.StepLR(
         | 
| 113 | 
            +
                    optimizer,
         | 
| 114 | 
            +
                    step_size=30,
         | 
| 115 | 
            +
                    gamma=0.1
         | 
| 116 | 
            +
                )
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                train_loader = DataLoader(train_dataset,
         | 
| 119 | 
            +
                                          collate_fn=proposal_retrieval_collate,
         | 
| 120 | 
            +
                                          batch_size=opt.bsz,
         | 
| 121 | 
            +
                                          num_workers=opt.num_workers,
         | 
| 122 | 
            +
                                          shuffle=True,
         | 
| 123 | 
            +
                                          pin_memory=opt.pin_memory)
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                prev_best_score = 0.
         | 
| 126 | 
            +
                es_cnt = 0
         | 
| 127 | 
            +
                start_epoch = -1 if opt.eval_untrained else 0
         | 
| 128 | 
            +
                eval_tasks_at_training = ["SVMR", ]
         | 
| 129 | 
            +
                save_submission_filename = \
         | 
| 130 | 
            +
                    "latest_{}_{}_predictions_{}.json".format(opt.dset_name, opt.eval_split_name, "_".join(eval_tasks_at_training))
         | 
| 131 | 
            +
                for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
         | 
| 132 | 
            +
                    if epoch_i > -1:
         | 
| 133 | 
            +
                        with torch.autograd.detect_anomaly():
         | 
| 134 | 
            +
                            train_epoch(model, train_loader, optimizer, opt, epoch_i)
         | 
| 135 | 
            +
                    global_step = (epoch_i + 1) * len(train_loader)
         | 
| 136 | 
            +
                    scheduler.step()
         | 
| 137 | 
            +
                    if opt.eval_path is not None:
         | 
| 138 | 
            +
                        with torch.no_grad():
         | 
| 139 | 
            +
                            metrics_no_nms, metrics_nms, latest_file_paths = \
         | 
| 140 | 
            +
                                eval_epoch(model, val_dataset, opt, save_submission_filename, tasks=eval_tasks_at_training,
         | 
| 141 | 
            +
                                           max_before_nms=300, max_after_nms=100)
         | 
| 142 | 
            +
                        logger.info("metrics_no_nms {}".format(
         | 
| 143 | 
            +
                            pprint.pformat(rm_key_from_odict(metrics_no_nms, rm_suffix="by_type"), indent=4)))
         | 
| 144 | 
            +
                        logger.info("metrics_nms \n{}".format(pprint.pformat(metrics_nms, indent=4)))
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                        to_write = opt.eval_log_txt_formatter.format(
         | 
| 147 | 
            +
                            time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
         | 
| 148 | 
            +
                            epoch=epoch_i,
         | 
| 149 | 
            +
                            eval_metrics_str=json.dumps(metrics_no_nms))
         | 
| 150 | 
            +
                        with open(opt.eval_log_filepath, "a") as f:
         | 
| 151 | 
            +
                            f.write(to_write)
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                        # metrics = metrics_nms if metrics_nms is not None else metrics_no_nms
         | 
| 154 | 
            +
                        metrics = metrics_no_nms
         | 
| 155 | 
            +
                        # early stop/ log / save model
         | 
| 156 | 
            +
                        for task_type, task_metrics in metrics.items():
         | 
| 157 | 
            +
                            for iou_thd in [0.5, 0.7]:
         | 
| 158 | 
            +
                                opt.writer.add_scalars("Eval/{}-{}".format(task_type, iou_thd),
         | 
| 159 | 
            +
                                                       {k: v for k, v in task_metrics.items() if str(iou_thd) in k},
         | 
| 160 | 
            +
                                                       global_step)
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                        # use the most strict metric available
         | 
| 163 | 
            +
                        if metrics["SVMR"]["0.5-r1"] > prev_best_score:
         | 
| 164 | 
            +
                            es_cnt = 0
         | 
| 165 | 
            +
                            prev_best_score = metrics["SVMR"]["0.5-r1"]
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                            checkpoint = {
         | 
| 168 | 
            +
                                "model": model.state_dict(),
         | 
| 169 | 
            +
                                "model_cfg": model.config,
         | 
| 170 | 
            +
                                "epoch": epoch_i}
         | 
| 171 | 
            +
                            torch.save(checkpoint, opt.ckpt_filepath)
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                            best_file_paths = [e.replace("latest", "best") for e in latest_file_paths]
         | 
| 174 | 
            +
                            for src, tgt in zip(latest_file_paths, best_file_paths):
         | 
| 175 | 
            +
                                os.renames(src, tgt)
         | 
| 176 | 
            +
                            logger.info("The checkpoint file has been updated.")
         | 
| 177 | 
            +
                        else:
         | 
| 178 | 
            +
                            es_cnt += 1
         | 
| 179 | 
            +
                            if es_cnt > opt.max_es_cnt:  # early stop
         | 
| 180 | 
            +
                                with open(opt.train_log_filepath, "a") as f:
         | 
| 181 | 
            +
                                    f.write("Early Stop at epoch {}".format(epoch_i))
         | 
| 182 | 
            +
                                logger.info("Early stop at {} with SVMR 0.5-r1 {}".format(epoch_i, prev_best_score))
         | 
| 183 | 
            +
                                break
         | 
| 184 | 
            +
                    else:
         | 
| 185 | 
            +
                        checkpoint = {
         | 
| 186 | 
            +
                            "model": model.state_dict(),
         | 
| 187 | 
            +
                            "model_cfg": model.config,
         | 
| 188 | 
            +
                            "epoch": epoch_i}
         | 
| 189 | 
            +
                        torch.save(checkpoint, opt.ckpt_filepath)
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                    if opt.debug:
         | 
| 192 | 
            +
                        break
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                opt.writer.close()
         | 
| 195 | 
            +
             | 
| 196 | 
            +
             | 
| 197 | 
            +
            def rm_key_from_odict(odict_obj, rm_suffix):
         | 
| 198 | 
            +
                """remove key entry from the OrderedDict"""
         | 
| 199 | 
            +
                return OrderedDict([(k, v) for k, v in odict_obj.items() if rm_suffix not in k])
         | 
| 200 | 
            +
             | 
| 201 | 
            +
             | 
| 202 | 
            +
            def start_training():
         | 
| 203 | 
            +
                logger.info("Setup config, data and model...")
         | 
| 204 | 
            +
                opt = BaseOptions().parse()
         | 
| 205 | 
            +
                set_seed(opt.seed)
         | 
| 206 | 
            +
                if opt.debug:  # keep the model run deterministically
         | 
| 207 | 
            +
                    # 'cudnn.benchmark = True' enabled auto finding the best algorithm for a specific input/net config.
         | 
| 208 | 
            +
                    # Enable this only when input size is fixed.
         | 
| 209 | 
            +
                    cudnn.benchmark = False
         | 
| 210 | 
            +
                    cudnn.deterministic = True
         | 
| 211 | 
            +
             | 
| 212 | 
            +
                opt.writer = SummaryWriter(opt.tensorboard_log_dir)
         | 
| 213 | 
            +
                opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
         | 
| 214 | 
            +
                opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Metrics] {eval_metrics_str}\n"
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                train_dataset = ProposalRetrievalDataset(
         | 
| 217 | 
            +
                    dset_name=opt.dset_name,
         | 
| 218 | 
            +
                    model_type=opt.model_type,
         | 
| 219 | 
            +
                    data_path=opt.train_path,
         | 
| 220 | 
            +
                    desc_bert_path=opt.desc_bert_path,
         | 
| 221 | 
            +
                    sub_bert_path=opt.sub_bert_path,
         | 
| 222 | 
            +
                    max_desc_len=opt.max_desc_l,
         | 
| 223 | 
            +
                    vid_feat_path=opt.vid_feat_path,
         | 
| 224 | 
            +
                    clip_length=opt.clip_length,
         | 
| 225 | 
            +
                    vid_feat_size=opt.vid_feat_size,
         | 
| 226 | 
            +
                    sub_feat_size=opt.sub_feat_size,
         | 
| 227 | 
            +
                    ctx_mode=opt.ctx_mode,
         | 
| 228 | 
            +
                    pos_iou_thd=opt.pos_iou_thd,
         | 
| 229 | 
            +
                    neg_iou_thd=opt.neg_iou_thd,
         | 
| 230 | 
            +
                    h5driver=opt.h5driver,
         | 
| 231 | 
            +
                    data_ratio=opt.data_ratio,
         | 
| 232 | 
            +
                    normalize_vfeat=not opt.no_norm_vfeat,
         | 
| 233 | 
            +
                    normalize_tfeat=not opt.no_norm_tfeat,
         | 
| 234 | 
            +
                    external_train_vr_res_path=opt.external_train_vr_res_path,  # If not None, used to guide negative sampling
         | 
| 235 | 
            +
                    corpus_path=opt.corpus_path,
         | 
| 236 | 
            +
                )
         | 
| 237 | 
            +
             | 
| 238 | 
            +
                if opt.eval_path is not None:
         | 
| 239 | 
            +
                    eval_dataset = ProposalRetrievalEvalDataset(
         | 
| 240 | 
            +
                        dset_name=opt.dset_name,
         | 
| 241 | 
            +
                        model_type=opt.model_type,
         | 
| 242 | 
            +
                        eval_split_name=opt.eval_split_name,  # should only be val set
         | 
| 243 | 
            +
                        data_path=opt.eval_path,
         | 
| 244 | 
            +
                        desc_bert_path_or_handler=train_dataset.desc_bert_h5,
         | 
| 245 | 
            +
                        sub_bert_path_or_handler=train_dataset.sub_bert_h5 if "sub" in opt.ctx_mode else None,
         | 
| 246 | 
            +
                        max_desc_len=opt.max_desc_l,
         | 
| 247 | 
            +
                        corpus_path=opt.corpus_path,
         | 
| 248 | 
            +
                        vid_feat_path_or_handler=train_dataset.vid_feat_h5 if "video" in opt.ctx_mode else None,
         | 
| 249 | 
            +
                        clip_length=opt.clip_length,
         | 
| 250 | 
            +
                        eval_proposal_bsz=opt.eval_proposal_bsz,
         | 
| 251 | 
            +
                        ctx_mode=opt.ctx_mode,
         | 
| 252 | 
            +
                        data_mode="query",
         | 
| 253 | 
            +
                        h5driver=opt.h5driver,
         | 
| 254 | 
            +
                        data_ratio=opt.data_ratio,
         | 
| 255 | 
            +
                        normalize_vfeat=not opt.no_norm_vfeat,
         | 
| 256 | 
            +
                        normalize_tfeat=not opt.no_norm_tfeat,
         | 
| 257 | 
            +
                    )
         | 
| 258 | 
            +
                else:
         | 
| 259 | 
            +
                    eval_dataset = None
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                model_config = EDict(
         | 
| 262 | 
            +
                    visual_input_size=train_dataset.vid_feat_output_size,  # changes based on visual input type
         | 
| 263 | 
            +
                    textual_input_size=train_dataset.sub_feat_output_size,
         | 
| 264 | 
            +
                    query_feat_size=opt.desc_feat_size,
         | 
| 265 | 
            +
                    visual_hidden_size=opt.visual_hidden_size,  #
         | 
| 266 | 
            +
                    output_size=opt.output_size,
         | 
| 267 | 
            +
                    embedding_size=opt.embedding_size,
         | 
| 268 | 
            +
                    lstm_hidden_size=opt.lstm_hidden_size,
         | 
| 269 | 
            +
                    margin=opt.margin,  # margin for ranking loss
         | 
| 270 | 
            +
                    loss_type=opt.loss_type,  # loss type, 'hinge' or 'lse'
         | 
| 271 | 
            +
                    inter_loss_weight=opt.inter_loss_weight * (opt.ctx_mode == "tef"),  # weight for inter negatives
         | 
| 272 | 
            +
                    ctx_mode=opt.ctx_mode
         | 
| 273 | 
            +
                )
         | 
| 274 | 
            +
                logger.info("model_config {}".format(model_config))
         | 
| 275 | 
            +
             | 
| 276 | 
            +
                model = CALWithSub(model_config)
         | 
| 277 | 
            +
                if opt.device.type == "cuda":
         | 
| 278 | 
            +
                    logger.info("CUDA enabled.")
         | 
| 279 | 
            +
                    model.to(opt.device)
         | 
| 280 | 
            +
                    if len(opt.device_ids) > 1:
         | 
| 281 | 
            +
                        logger.info("Use multi GPU", opt.device_ids)
         | 
| 282 | 
            +
                        model = torch.nn.DataParallel(model, device_ids=opt.device_ids)  # use multi GPU
         | 
| 283 | 
            +
             | 
| 284 | 
            +
                if opt.init_ckpt_path is not None:
         | 
| 285 | 
            +
                    checkpoint = torch.load(opt.init_ckpt_path)
         | 
| 286 | 
            +
                    model.load_state_dict(checkpoint["model"])
         | 
| 287 | 
            +
                    logger.info("Loaded model saved at epoch {} from checkpoint: {}"
         | 
| 288 | 
            +
                                .format(checkpoint["epoch"], opt.init_ckpt_path))
         | 
| 289 | 
            +
                count_parameters(model)
         | 
| 290 | 
            +
             | 
| 291 | 
            +
                logger.info("Start Training...")
         | 
| 292 | 
            +
                train(model, train_dataset, eval_dataset, opt)
         | 
| 293 | 
            +
                return opt.results_dir, opt.eval_split_name, opt.eval_path, opt.debug
         | 
| 294 | 
            +
             | 
| 295 | 
            +
             | 
| 296 | 
            +
            if __name__ == '__main__':
         | 
| 297 | 
            +
                model_dir, eval_split_name, eval_path, debug = start_training()
         | 
| 298 | 
            +
                if not debug:
         | 
| 299 | 
            +
                    model_dir = model_dir.split(os.sep)[-1]
         | 
| 300 | 
            +
                    tasks = ["SVMR", "VCMR"]
         | 
| 301 | 
            +
                    input_args = ["--model_dir", model_dir,
         | 
| 302 | 
            +
                                  "--eval_split_name", eval_split_name,
         | 
| 303 | 
            +
                                  "--eval_path", eval_path,
         | 
| 304 | 
            +
                                  "--tasks"] + tasks
         | 
| 305 | 
            +
             | 
| 306 | 
            +
                    import sys
         | 
| 307 | 
            +
                    sys.argv[1:] = input_args
         | 
| 308 | 
            +
                    logger.info("\n\n\nFINISHED TRAINING!!!")
         | 
| 309 | 
            +
                    logger.info("Evaluating model in {}".format(model_dir))
         | 
| 310 | 
            +
                    start_inference()
         | 
    	
        baselines/crossmodal_moment_localization/README.md
    ADDED
    
    | @@ -0,0 +1,2 @@ | |
|  | |
|  | 
|  | |
| 1 | 
            +
            Cross-modal Moment Localization (XML)
         | 
| 2 | 
            +
            ===
         | 
    	
        baselines/crossmodal_moment_localization/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        baselines/crossmodal_moment_localization/__pycache__/__init__.cpython-311.pyc
    ADDED
    
    | Binary file (207 Bytes). View file | 
|  | 
    	
        baselines/crossmodal_moment_localization/__pycache__/config.cpython-311.pyc
    ADDED
    
    | Binary file (23.3 kB). View file | 
|  | 
    	
        baselines/crossmodal_moment_localization/__pycache__/inference.cpython-311.pyc
    ADDED
    
    | Binary file (24.1 kB). View file | 
|  | 
    	
        baselines/crossmodal_moment_localization/__pycache__/model_components.cpython-311.pyc
    ADDED
    
    | Binary file (19.8 kB). View file | 
|  | 
    	
        baselines/crossmodal_moment_localization/__pycache__/model_xml.cpython-311.pyc
    ADDED
    
    | Binary file (39.8 kB). View file | 
|  | 
    	
        baselines/crossmodal_moment_localization/__pycache__/ndcg_iou_topk.cpython-311.pyc
    ADDED
    
    | Binary file (5.64 kB). View file | 
|  | 
    	
        baselines/crossmodal_moment_localization/__pycache__/optimization.cpython-311.pyc
    ADDED
    
    | Binary file (18.8 kB). View file | 
|  | 
    	
        baselines/crossmodal_moment_localization/__pycache__/start_end_dataset.cpython-311.pyc
    ADDED
    
    | Binary file (19.5 kB). View file | 
|  | 
    	
        baselines/crossmodal_moment_localization/config.py
    ADDED
    
    | @@ -0,0 +1,276 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import time
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import argparse
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from utils.basic_utils import mkdirp, load_json, save_json, make_zipfile
         | 
| 7 | 
            +
            from baselines.clip_alignment_with_language.local_utils.proposal import ProposalConfigs
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            class BaseOptions(object):
         | 
| 11 | 
            +
                saved_option_filename = "opt.json"
         | 
| 12 | 
            +
                ckpt_filename = "model.ckpt"
         | 
| 13 | 
            +
                tensorboard_log_dir = "tensorboard_log"
         | 
| 14 | 
            +
                train_log_filename = "train.log.txt"
         | 
| 15 | 
            +
                eval_log_filename = "eval.log.txt"
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                def __init__(self):
         | 
| 18 | 
            +
                    self.parser = argparse.ArgumentParser()
         | 
| 19 | 
            +
                    self.initialized = False
         | 
| 20 | 
            +
                    self.opt = None
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                def initialize(self):
         | 
| 23 | 
            +
                    self.initialized = True
         | 
| 24 | 
            +
                    self.parser.add_argument("--dset_name", type=str, choices=["tvr"])
         | 
| 25 | 
            +
                    self.parser.add_argument("--model_name", type=str)
         | 
| 26 | 
            +
                    self.parser.add_argument("--eval_split_name", type=str, default="val",
         | 
| 27 | 
            +
                                             help="should match keys in corpus_path, must set for VCMR")
         | 
| 28 | 
            +
                    self.parser.add_argument("--debug", action="store_true",
         | 
| 29 | 
            +
                                             help="debug (fast) mode, break all loops, do not load all data into memory.")
         | 
| 30 | 
            +
                    self.parser.add_argument("--data_ratio", type=float, default=1.0,
         | 
| 31 | 
            +
                                             help="how many training and eval data to use. 1.0: use all, 0.1: use 10%."
         | 
| 32 | 
            +
                                                  "Use small portion for debug purposes. Note this is different from --debug, "
         | 
| 33 | 
            +
                                                  "which works by breaking the loops, typically they are not used together.")
         | 
| 34 | 
            +
                    self.parser.add_argument("--results_root", type=str, default="results")
         | 
| 35 | 
            +
                    self.parser.add_argument("--exp_id", type=str, default=None, help="id of this run, required at training")
         | 
| 36 | 
            +
                    self.parser.add_argument("--seed", type=int, default=2018, help="random seed")
         | 
| 37 | 
            +
                    self.parser.add_argument("--device", type=int, default=0, help="0 cuda, -1 cpu")
         | 
| 38 | 
            +
                    self.parser.add_argument("--device_ids", type=int, nargs="+", default=[0], help="GPU ids to run the job")
         | 
| 39 | 
            +
                    self.parser.add_argument("--num_workers", type=int, default=4,
         | 
| 40 | 
            +
                                             help="num subprocesses used to load the data, 0: use main process")
         | 
| 41 | 
            +
                    self.parser.add_argument("--no_core_driver", action="store_true",
         | 
| 42 | 
            +
                                             help="hdf5 driver, default use `core` (load into RAM), if specified, use `None`")
         | 
| 43 | 
            +
                    self.parser.add_argument("--no_pin_memory", action="store_true",
         | 
| 44 | 
            +
                                             help="Don't use pin_memory=True for dataloader. "
         | 
| 45 | 
            +
                                                  "ref: https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/4")
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                    # training config
         | 
| 48 | 
            +
                    self.parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
         | 
| 49 | 
            +
                    self.parser.add_argument("--lr_warmup_proportion", type=float, default=0.01,
         | 
| 50 | 
            +
                                             help="Proportion of training to perform linear learning rate warmup for. "
         | 
| 51 | 
            +
                                                  "E.g., 0.1 = 10% of training.")
         | 
| 52 | 
            +
                    self.parser.add_argument("--wd", type=float, default=0.01, help="weight decay")
         | 
| 53 | 
            +
                    self.parser.add_argument("--n_epoch", type=int, default=100, help="number of epochs to run")
         | 
| 54 | 
            +
                    self.parser.add_argument("--max_es_cnt", type=int, default=10,
         | 
| 55 | 
            +
                                             help="number of epochs to early stop, use -1 to disable early stop")
         | 
| 56 | 
            +
                    self.parser.add_argument("--stop_task", type=str, default="VCMR", choices=["VCMR", "SVMR", "VR"],
         | 
| 57 | 
            +
                                             help="Use metric associated with stop_task for early stop")
         | 
| 58 | 
            +
                    self.parser.add_argument("--eval_tasks_at_training", type=str, nargs="+",
         | 
| 59 | 
            +
                                             default=["VCMR"], choices=["VCMR", "SVMR", "VR"],
         | 
| 60 | 
            +
                                             help="evaluate and report  numbers for tasks specified here.")
         | 
| 61 | 
            +
                    self.parser.add_argument("--bsz", type=int, default=128, help="mini-batch size")
         | 
| 62 | 
            +
                    self.parser.add_argument("--eval_query_bsz", type=int, default=50,
         | 
| 63 | 
            +
                                             help="mini-batch size at inference, for query")
         | 
| 64 | 
            +
                    self.parser.add_argument("--eval_context_bsz", type=int, default=200,
         | 
| 65 | 
            +
                                             help="mini-batch size at inference, for video/sub")
         | 
| 66 | 
            +
                    self.parser.add_argument("--eval_untrained", action="store_true", help="Evaluate on un-trained model")
         | 
| 67 | 
            +
                    self.parser.add_argument("--grad_clip", type=float, default=-1, help="perform gradient clip, -1: disable")
         | 
| 68 | 
            +
                    self.parser.add_argument("--margin", type=float, default=0.1, help="margin for   hinge loss")
         | 
| 69 | 
            +
                    self.parser.add_argument("--lw_neg_q", type=float, default=1,
         | 
| 70 | 
            +
                                             help="weight for ranking loss with negative query and positive context")
         | 
| 71 | 
            +
                    self.parser.add_argument("--lw_neg_ctx", type=float, default=1,
         | 
| 72 | 
            +
                                             help="weight for ranking loss with positive query and negative context")
         | 
| 73 | 
            +
                    self.parser.add_argument("--lw_st_ed", type=float, default=0.01, help="weight for st ed prediction loss")
         | 
| 74 | 
            +
                    self.parser.add_argument("--train_span_start_epoch", type=int, default=0,
         | 
| 75 | 
            +
                                             help="which epoch to start training span prediction, -1 to disable")
         | 
| 76 | 
            +
                    self.parser.add_argument("--ranking_loss_type", type=str, default="hinge", choices=["hinge", "lse"],
         | 
| 77 | 
            +
                                             help="att loss type, can be hinge loss or its smooth approximation LogSumExp")
         | 
| 78 | 
            +
                    self.parser.add_argument("--hard_negtiave_start_epoch", type=int, default=20,
         | 
| 79 | 
            +
                                             help="which epoch to start hard negative sampling for video-level ranking loss,"
         | 
| 80 | 
            +
                                                  "use -1 to disable")
         | 
| 81 | 
            +
                    self.parser.add_argument("--hard_pool_size", type=int, default=20,
         | 
| 82 | 
            +
                                             help="hard negatives are still sampled, but from a harder pool.")
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                    # Model and Data config
         | 
| 85 | 
            +
                    self.parser.add_argument("--max_sub_l", type=int, default=50,
         | 
| 86 | 
            +
                                             help="max length of all sub sentence 97.71 under 50 for 3 sentences")
         | 
| 87 | 
            +
                    self.parser.add_argument("--max_desc_l", type=int, default=30, help="max length of descriptions")
         | 
| 88 | 
            +
                    self.parser.add_argument("--max_ctx_l", type=int, default=100,
         | 
| 89 | 
            +
                                             help="max number of snippets, 100 for tvr clip_length=1.5, oly 109/21825 > 100")
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                    self.parser.add_argument("--train_path", type=str, default=None)
         | 
| 92 | 
            +
                    self.parser.add_argument("--val_path", type=str, default=None)
         | 
| 93 | 
            +
                    self.parser.add_argument("--test_path", type=str, default=None)
         | 
| 94 | 
            +
                    self.parser.add_argument("--external_inference_vr_res_path", type=str, default=None,
         | 
| 95 | 
            +
                                             help="if set, use external video retrieval results to guide evaluation. ")
         | 
| 96 | 
            +
                    self.parser.add_argument("--use_glove", action="store_true", help="Use GloVe instead of BERT features")
         | 
| 97 | 
            +
                    self.parser.add_argument("--word2idx_path", type=str,
         | 
| 98 | 
            +
                                             help="a dict, {word: word_idx, ...}, "
         | 
| 99 | 
            +
                                                  "special tokens are {<pad>: 0, <unk>: 1, <eos>: 2}")
         | 
| 100 | 
            +
                    self.parser.add_argument("--vocab_size", type=int, default=-1,
         | 
| 101 | 
            +
                                             help="Set automatically to len(word2idx)")
         | 
| 102 | 
            +
                    self.parser.add_argument("--glove_path", type=str,
         | 
| 103 | 
            +
                                             help="path to file containing the GloVe embeddings for words in word2idx")
         | 
| 104 | 
            +
                    self.parser.add_argument("--desc_bert_path", type=str, default=None)
         | 
| 105 | 
            +
                    self.parser.add_argument("--sub_bert_path", type=str, default=None)
         | 
| 106 | 
            +
                    self.parser.add_argument("--sub_feat_size", type=int, default=768, help="feature dim for sub feature")
         | 
| 107 | 
            +
                    self.parser.add_argument("--q_feat_size", type=int, default=768, help="feature dim for sub feature")
         | 
| 108 | 
            +
                    self.parser.add_argument("--ctx_mode", type=str, choices=["video", "sub", "video_sub", "tef",
         | 
| 109 | 
            +
                                                                              "video_tef", "sub_tef", "video_sub_tef"],
         | 
| 110 | 
            +
                                             help="which context to use. a combination of [video, sub, tef]")
         | 
| 111 | 
            +
                    self.parser.add_argument("--corpus_path", type=str, default=None)
         | 
| 112 | 
            +
                    self.parser.add_argument("--vid_feat_path", type=str, default="")
         | 
| 113 | 
            +
                    self.parser.add_argument("--no_norm_vfeat", action="store_true",
         | 
| 114 | 
            +
                                             help="Do not do normalization on video feat, use it only when using resnet_i3d feat")
         | 
| 115 | 
            +
                    self.parser.add_argument("--no_norm_tfeat", action="store_true", help="Do not do normalization on text feat")
         | 
| 116 | 
            +
                    self.parser.add_argument("--clip_length", type=float, default=None,
         | 
| 117 | 
            +
                                             help="each video will be uniformly segmented into small clips, "
         | 
| 118 | 
            +
                                                  "will automatically loaded from ProposalConfigs if None")
         | 
| 119 | 
            +
                    self.parser.add_argument("--vid_feat_size", type=int, help="feature dim for video feature")
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                    self.parser.add_argument("--span_predictor_type", type=str, default="conv", choices=["conv", "cat_linear"],
         | 
| 122 | 
            +
                                             help="how to generate span predictions, "
         | 
| 123 | 
            +
                                                  "conv: apply 1D-Conv layer on top of NxL dot product of query and clips"
         | 
| 124 | 
            +
                                                  "cat_linear: cat the query and clips then use a linear layer to give output. "
         | 
| 125 | 
            +
                                                  "Note cat_linear is implemented as first project query and clips into scores, "
         | 
| 126 | 
            +
                                                  "separately, then sum them up, this should be similar to first cat then project.")
         | 
| 127 | 
            +
                    self.parser.add_argument("--stack_conv_predictor_conv_kernel_sizes", type=int, default=-1, nargs="+",
         | 
| 128 | 
            +
                                             help="combine the results from conv edge detectors of all sizes specified."
         | 
| 129 | 
            +
                                                  "-1: disable. If specified, will ignore --conv_kernel_size option."
         | 
| 130 | 
            +
                                                  "This flag is only used when --merge_two_stream and --span_predictor_type conv!")
         | 
| 131 | 
            +
                    self.parser.add_argument("--encoder_type", type=str, default="transformer",
         | 
| 132 | 
            +
                                             choices=["gru", "lstm", "transformer", "cnn"])
         | 
| 133 | 
            +
                    self.parser.add_argument("--add_pe_rnn", action="store_true",
         | 
| 134 | 
            +
                                             help="Add positional encoding for GRU and LSTM encoder as well")
         | 
| 135 | 
            +
                    self.parser.add_argument("--no_merge_two_stream", action="store_true", help="do not merge video and subtitles")
         | 
| 136 | 
            +
                    self.parser.add_argument("--no_cross_att", action="store_true",
         | 
| 137 | 
            +
                                             help="Use cross-attention for modeling video and subtitles")
         | 
| 138 | 
            +
                    self.parser.add_argument("--no_self_att", action="store_true", help="do not use self attention")
         | 
| 139 | 
            +
                    self.parser.add_argument("--no_modular", action="store_true", help="do not use modular attention")
         | 
| 140 | 
            +
                    self.parser.add_argument("--pe_type", type=str, default="cosine", choices=["none", "linear", "cosine"],
         | 
| 141 | 
            +
                                             help="Only for query encoding")
         | 
| 142 | 
            +
                    self.parser.add_argument("--max_position_embeddings", type=int, default=300)
         | 
| 143 | 
            +
                    self.parser.add_argument("--hidden_size", type=int, default=256)
         | 
| 144 | 
            +
                    self.parser.add_argument("--n_heads", type=int, default=4)
         | 
| 145 | 
            +
                    self.parser.add_argument("--input_drop", type=float, default=0.1, help="Applied to all inputs")
         | 
| 146 | 
            +
                    self.parser.add_argument("--drop", type=float, default=0.1, help="Applied to all other layers")
         | 
| 147 | 
            +
                    self.parser.add_argument("--cross_att_drop", type=float, default=0.1, help="Applied to cross-att")
         | 
| 148 | 
            +
                    self.parser.add_argument("--conv_kernel_size", type=int, default=5)
         | 
| 149 | 
            +
                    self.parser.add_argument("--conv_stride", type=int, default=1)
         | 
| 150 | 
            +
                    self.parser.add_argument("--initializer_range", type=float, default=0.02,
         | 
| 151 | 
            +
                                             help="initializer range for linear layer")
         | 
| 152 | 
            +
                    self.parser.add_argument("--eval_num_per_epoch", type=float)
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                    # post processing
         | 
| 155 | 
            +
                    self.parser.add_argument("--min_pred_l", type=int, default=2,
         | 
| 156 | 
            +
                                             help="constrain the [st, ed] with ed - st >= 2"
         | 
| 157 | 
            +
                                                  "(2 clips with length 1.5 each, 3 secs in total"
         | 
| 158 | 
            +
                                                  "this is the min length for proposal-based method)")
         | 
| 159 | 
            +
                    self.parser.add_argument("--max_pred_l", type=int, default=16,
         | 
| 160 | 
            +
                                             help="constrain the [st, ed] pairs with ed - st <= 16, 24 secs in total"
         | 
| 161 | 
            +
                                                  "(16 clips with length 1.5 each, "
         | 
| 162 | 
            +
                                                  "this is the max length for proposal-based method)")
         | 
| 163 | 
            +
                    self.parser.add_argument("--q2c_alpha", type=float, default=20,
         | 
| 164 | 
            +
                                             help="give more importance to top scored videos' spans,  "
         | 
| 165 | 
            +
                                                  "the new score will be: s_new = exp(alpha * s), "
         | 
| 166 | 
            +
                                                  "higher alpha indicates more importance. Note s in [-1, 1]")
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                    self.parser.add_argument("--max_before_nms", type=int, default=200)
         | 
| 169 | 
            +
                    self.parser.add_argument("--max_vcmr_video", type=int, default=100,
         | 
| 170 | 
            +
                                             help="re-ranking in top-max_vcmr_video")
         | 
| 171 | 
            +
                    self.parser.add_argument("--nms_thd", type=float, default=-1,
         | 
| 172 | 
            +
                                             help="additionally use non-maximum suppression "
         | 
| 173 | 
            +
                                                  "(or non-minimum suppression for distance)"
         | 
| 174 | 
            +
                                                  "to post-processing the predictions. "
         | 
| 175 | 
            +
                                                  "-1: do not use nms. 0.6 for charades_sta, 0.5 for anet_cap,")
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                def display_save(self, opt):
         | 
| 178 | 
            +
                    args = vars(opt)
         | 
| 179 | 
            +
                    # Display settings
         | 
| 180 | 
            +
                    print("------------ Options -------------\n{}\n-------------------"
         | 
| 181 | 
            +
                          .format({str(k): str(v) for k, v in sorted(args.items())}))
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                    # Save settings
         | 
| 184 | 
            +
                    if not isinstance(self, TestOptions):
         | 
| 185 | 
            +
                        option_file_path = os.path.join(opt.results_dir, self.saved_option_filename)  # not yaml file indeed
         | 
| 186 | 
            +
                        save_json(args, option_file_path, save_pretty=True)
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                def parse(self):
         | 
| 189 | 
            +
                    if not self.initialized:
         | 
| 190 | 
            +
                        self.initialize()
         | 
| 191 | 
            +
                    opt = self.parser.parse_args()
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                    if opt.debug:
         | 
| 194 | 
            +
                        opt.results_root = os.path.sep.join(opt.results_root.split(os.path.sep)[:-1] + ["debug_results", ])
         | 
| 195 | 
            +
                        opt.no_core_driver = True
         | 
| 196 | 
            +
                        opt.num_workers = 0
         | 
| 197 | 
            +
                        opt.eval_query_bsz = 100
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                    if isinstance(self, TestOptions):
         | 
| 200 | 
            +
                        # modify model_dir to absolute path
         | 
| 201 | 
            +
                        opt.model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results", opt.model_dir)
         | 
| 202 | 
            +
                        saved_options = load_json(os.path.join(opt.model_dir, self.saved_option_filename))
         | 
| 203 | 
            +
                        for arg in saved_options:  # use saved options to overwrite all BaseOptions args.
         | 
| 204 | 
            +
                            if arg not in ["results_root", "num_workers", "nms_thd", "debug",
         | 
| 205 | 
            +
                                           "eval_split_name", "eval_path", "eval_query_bsz", "eval_context_bsz",
         | 
| 206 | 
            +
                                           "max_pred_l", "min_pred_l", "external_inference_vr_res_path"]:
         | 
| 207 | 
            +
                                setattr(opt, arg, saved_options[arg])
         | 
| 208 | 
            +
                        # opt.no_core_driver = True
         | 
| 209 | 
            +
                    else:
         | 
| 210 | 
            +
                        if opt.exp_id is None:
         | 
| 211 | 
            +
                            raise ValueError("--exp_id is required for at a training option!")
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                        if opt.clip_length is None:
         | 
| 214 | 
            +
                            opt.clip_length = ProposalConfigs[opt.dset_name]["clip_length"]
         | 
| 215 | 
            +
                            print("Loaded clip_length {} from proposal config file".format(opt.clip_length))
         | 
| 216 | 
            +
                        opt.results_dir = os.path.join(opt.results_root, "_".join([opt.model_name, opt.exp_id, time.strftime("%Y%m%d_%H%M%S")]))
         | 
| 217 | 
            +
                        mkdirp(opt.results_dir)
         | 
| 218 | 
            +
                        # save a copy of current code
         | 
| 219 | 
            +
                        code_dir = os.path.dirname(os.path.realpath(__file__))
         | 
| 220 | 
            +
                        code_zip_filename = os.path.join(opt.results_dir, "code.zip")
         | 
| 221 | 
            +
                        make_zipfile(code_dir, code_zip_filename,
         | 
| 222 | 
            +
                                     enclosing_dir="code",
         | 
| 223 | 
            +
                                     exclude_dirs_substring="results",
         | 
| 224 | 
            +
                                     exclude_dirs=["results", "debug_results", "__pycache__"],
         | 
| 225 | 
            +
                                     exclude_extensions=[".pyc", ".ipynb", ".swap"],)
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                    self.display_save(opt)
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                    if "sub" in opt.ctx_mode:
         | 
| 230 | 
            +
                        assert opt.dset_name == "tvr", "sub is only supported for tvr dataset"
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                    if opt.hard_negtiave_start_epoch != -1:
         | 
| 233 | 
            +
                        if opt.hard_pool_size > opt.bsz:
         | 
| 234 | 
            +
                            print("[WARNING] hard_pool_size is larger than bsz")
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                    assert opt.stop_task in opt.eval_tasks_at_training
         | 
| 237 | 
            +
                    opt.ckpt_filepath = os.path.join(opt.results_dir, self.ckpt_filename)
         | 
| 238 | 
            +
                    opt.train_log_filepath = os.path.join(opt.results_dir, self.train_log_filename)
         | 
| 239 | 
            +
                    opt.eval_log_filepath = os.path.join(opt.results_dir, self.eval_log_filename)
         | 
| 240 | 
            +
                    opt.tensorboard_log_dir = os.path.join(opt.results_dir, self.tensorboard_log_dir)
         | 
| 241 | 
            +
                    opt.device = torch.device("cuda:%d" % opt.device_ids[0] if opt.device >= 0 else "cpu")
         | 
| 242 | 
            +
                    opt.h5driver = None if opt.no_core_driver else "core"
         | 
| 243 | 
            +
                    # num_workers > 1 will only work with "core" mode, i.e., memory-mapped hdf5
         | 
| 244 | 
            +
                    opt.num_workers = 1 if opt.no_core_driver else opt.num_workers
         | 
| 245 | 
            +
                    opt.pin_memory = not opt.no_pin_memory
         | 
| 246 | 
            +
             | 
| 247 | 
            +
                    if "video" in opt.ctx_mode and opt.vid_feat_size > 3000:  # 3072, the normalized concatenation of resnet+i3d
         | 
| 248 | 
            +
                        assert opt.no_norm_vfeat
         | 
| 249 | 
            +
             | 
| 250 | 
            +
                    if "tef" in opt.ctx_mode and "video" in opt.ctx_mode:
         | 
| 251 | 
            +
                        opt.vid_feat_size += 2
         | 
| 252 | 
            +
                    if "tef" in opt.ctx_mode and "sub" in opt.ctx_mode:
         | 
| 253 | 
            +
                        opt.sub_feat_size += 2
         | 
| 254 | 
            +
             | 
| 255 | 
            +
                    if "video" not in opt.ctx_mode or "sub" not in opt.ctx_mode:
         | 
| 256 | 
            +
                        opt.no_merge_two_stream = True
         | 
| 257 | 
            +
                        opt.no_cross_att = True
         | 
| 258 | 
            +
             | 
| 259 | 
            +
                    self.opt = opt
         | 
| 260 | 
            +
                    return opt
         | 
| 261 | 
            +
             | 
| 262 | 
            +
             | 
| 263 | 
            +
            class TestOptions(BaseOptions):
         | 
| 264 | 
            +
                """add additional options for evaluating"""
         | 
| 265 | 
            +
                def initialize(self):
         | 
| 266 | 
            +
                    BaseOptions.initialize(self)
         | 
| 267 | 
            +
                    # also need to specify --eval_split_name
         | 
| 268 | 
            +
                    self.parser.add_argument("--eval_id", type=str, help="evaluation id")
         | 
| 269 | 
            +
                    self.parser.add_argument("--model_dir", type=str,
         | 
| 270 | 
            +
                                             help="dir contains the model file, will be converted to absolute path afterwards")
         | 
| 271 | 
            +
                    self.parser.add_argument("--tasks", type=str, nargs="+",
         | 
| 272 | 
            +
                                             choices=["VCMR", "SVMR", "VR"], default=["VCMR", "SVMR", "VR"],
         | 
| 273 | 
            +
                                             help="Which tasks to run."
         | 
| 274 | 
            +
                                                  "VCMR: Video Corpus Moment Retrieval;"
         | 
| 275 | 
            +
                                                  "SVMR: Single Video Moment Retrieval;"
         | 
| 276 | 
            +
                                                  "VR: regular Video Retrieval. (will be performed automatically with VCMR)")
         | 
    	
        baselines/crossmodal_moment_localization/inference.py
    ADDED
    
    | @@ -0,0 +1,414 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import copy
         | 
| 3 | 
            +
            import math
         | 
| 4 | 
            +
            import time
         | 
| 5 | 
            +
            import pprint
         | 
| 6 | 
            +
            from tqdm import tqdm, trange
         | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import torch
         | 
| 10 | 
            +
            import torch.nn.functional as F
         | 
| 11 | 
            +
            import torch.backends.cudnn as cudnn
         | 
| 12 | 
            +
            from torch.utils.data import DataLoader
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from baselines.crossmodal_moment_localization.config import TestOptions
         | 
| 15 | 
            +
            from baselines.crossmodal_moment_localization.model_xml import XML
         | 
| 16 | 
            +
            from baselines.crossmodal_moment_localization.start_end_dataset import \
         | 
| 17 | 
            +
                start_end_collate, StartEndEvalDataset, prepare_batch_inputs
         | 
| 18 | 
            +
            from baselines.clip_alignment_with_language.inference import \
         | 
| 19 | 
            +
                get_submission_top_n, post_processing_vcmr_nms, post_processing_svmr_nms
         | 
| 20 | 
            +
            from utils.basic_utils import save_json, load_json
         | 
| 21 | 
            +
            from utils.tensor_utils import find_max_triples_from_upper_triangle_product
         | 
| 22 | 
            +
            from standalone_eval.eval import eval_retrieval
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            import logging
         | 
| 25 | 
            +
            from ndcg_iou_topk import calculate_ndcg_iou
         | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            def compute_context_info(model, eval_dataset, opt):
         | 
| 31 | 
            +
                """Use val set to do evaluation, remember to run with torch.no_grad().
         | 
| 32 | 
            +
                estimated 2200 (videos) * 100 (frm) * 500 (hsz) * 4 (B) * 2 (video/sub) * 2 (layers) / (1024 ** 2) = 1.76 GB
         | 
| 33 | 
            +
                max_n_videos: only consider max_n_videos videos for each query to return st_ed scores.
         | 
| 34 | 
            +
                """
         | 
| 35 | 
            +
                model.eval()
         | 
| 36 | 
            +
                # eval_dataset.set_data_mode("context")
         | 
| 37 | 
            +
                context_dataloader = DataLoader(eval_dataset,
         | 
| 38 | 
            +
                                                collate_fn=start_end_collate,
         | 
| 39 | 
            +
                                                batch_size=opt.eval_context_bsz,
         | 
| 40 | 
            +
                                                num_workers=opt.num_workers,
         | 
| 41 | 
            +
                                                shuffle=False,
         | 
| 42 | 
            +
                                                pin_memory=opt.pin_memory)
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                metas = []  # list(dicts)
         | 
| 45 | 
            +
                video_feat1 = []
         | 
| 46 | 
            +
                video_feat2 = []
         | 
| 47 | 
            +
                video_mask = []
         | 
| 48 | 
            +
                sub_feat1 = []
         | 
| 49 | 
            +
                sub_feat2 = []
         | 
| 50 | 
            +
                sub_mask = []
         | 
| 51 | 
            +
                for idx, batch in tqdm(enumerate(context_dataloader),
         | 
| 52 | 
            +
                                       desc="Computing query2video scores",
         | 
| 53 | 
            +
                                       total=len(context_dataloader)):
         | 
| 54 | 
            +
                    metas.extend(batch[0])
         | 
| 55 | 
            +
                    model_inputs = prepare_batch_inputs(batch[1], device=opt.device, non_blocking=opt.pin_memory)
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                    _video_feat1, _video_feat2, _sub_feat1, _sub_feat2 = model.encode_context(
         | 
| 58 | 
            +
                        model_inputs["video_feat"], model_inputs["video_mask"],
         | 
| 59 | 
            +
                        model_inputs["sub_feat"], model_inputs["sub_mask"])
         | 
| 60 | 
            +
                    if "video" in opt.ctx_mode:
         | 
| 61 | 
            +
                        video_feat1.append(_video_feat1)
         | 
| 62 | 
            +
                        video_feat2.append(_video_feat2)
         | 
| 63 | 
            +
                        video_mask.append(model_inputs["video_mask"])
         | 
| 64 | 
            +
                    if "sub" in opt.ctx_mode:
         | 
| 65 | 
            +
                        sub_feat1.append(_sub_feat1)
         | 
| 66 | 
            +
                        sub_feat2.append(_sub_feat2)
         | 
| 67 | 
            +
                        sub_mask.append(model_inputs["sub_mask"])
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                def cat_tensor(tensor_list):
         | 
| 70 | 
            +
                    if len(tensor_list) == 0:
         | 
| 71 | 
            +
                        return None
         | 
| 72 | 
            +
                    else:
         | 
| 73 | 
            +
                        seq_l = [e.shape[1] for e in tensor_list]
         | 
| 74 | 
            +
                        b_sizes = [e.shape[0] for e in tensor_list]
         | 
| 75 | 
            +
                        b_sizes_cumsum = np.cumsum([0] + b_sizes)
         | 
| 76 | 
            +
                        if len(tensor_list[0].shape) == 3:
         | 
| 77 | 
            +
                            hsz = tensor_list[0].shape[2]
         | 
| 78 | 
            +
                            res_tensor = tensor_list[0].new_zeros(sum(b_sizes), max(seq_l), hsz)
         | 
| 79 | 
            +
                        elif len(tensor_list[0].shape) == 2:
         | 
| 80 | 
            +
                            res_tensor = tensor_list[0].new_zeros(sum(b_sizes), max(seq_l))
         | 
| 81 | 
            +
                        else:
         | 
| 82 | 
            +
                            raise ValueError("Only support 2/3 dimensional tensors")
         | 
| 83 | 
            +
                        for i, e in enumerate(tensor_list):
         | 
| 84 | 
            +
                            res_tensor[b_sizes_cumsum[i]:b_sizes_cumsum[i+1], :seq_l[i]] = e
         | 
| 85 | 
            +
                        return res_tensor
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                return metas, dict(
         | 
| 88 | 
            +
                    video_feat1=cat_tensor(video_feat1),  # (N_videos, L, hsz),
         | 
| 89 | 
            +
                    video_feat2=cat_tensor(video_feat2),
         | 
| 90 | 
            +
                    video_mask=cat_tensor(video_mask),  # (N_videos, L)
         | 
| 91 | 
            +
                    sub_feat1=cat_tensor(sub_feat1),
         | 
| 92 | 
            +
                    sub_feat2=cat_tensor(sub_feat2),
         | 
| 93 | 
            +
                    sub_mask=cat_tensor(sub_mask),
         | 
| 94 | 
            +
                )
         | 
| 95 | 
            +
             | 
| 96 | 
            +
             | 
| 97 | 
            +
            def index_if_not_none(input_tensor, indices):
         | 
| 98 | 
            +
                if input_tensor is None:
         | 
| 99 | 
            +
                    return input_tensor
         | 
| 100 | 
            +
                else:
         | 
| 101 | 
            +
                    return input_tensor[indices]
         | 
| 102 | 
            +
             | 
| 103 | 
            +
             | 
| 104 | 
            +
             | 
| 105 | 
            +
             | 
| 106 | 
            +
            def generate_min_max_length_mask(array_shape, min_l, max_l):
         | 
| 107 | 
            +
                """ The last two dimension denotes matrix of upper-triangle with upper-right corner masked,
         | 
| 108 | 
            +
                below is the case for 4x4.
         | 
| 109 | 
            +
                [[0, 1, 1, 0],
         | 
| 110 | 
            +
                 [0, 0, 1, 1],
         | 
| 111 | 
            +
                 [0, 0, 0, 1],
         | 
| 112 | 
            +
                 [0, 0, 0, 0]]
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                Args:
         | 
| 115 | 
            +
                    array_shape: np.shape??? The last two dimensions should be the same
         | 
| 116 | 
            +
                    min_l: int, minimum length of predicted span
         | 
| 117 | 
            +
                    max_l: int, maximum length of predicted span
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                Returns:
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                """
         | 
| 122 | 
            +
                single_dims = (1, ) * (len(array_shape) - 2)
         | 
| 123 | 
            +
                mask_shape = single_dims + array_shape[-2:]
         | 
| 124 | 
            +
                extra_length_mask_array = np.ones(mask_shape, dtype=np.float32)  # (1, ..., 1, L, L)
         | 
| 125 | 
            +
                mask_triu = np.triu(extra_length_mask_array, k=min_l)
         | 
| 126 | 
            +
                mask_triu_reversed = 1 - np.triu(extra_length_mask_array, k=max_l)
         | 
| 127 | 
            +
                final_prob_mask = mask_triu * mask_triu_reversed
         | 
| 128 | 
            +
                return final_prob_mask  # with valid bit to be 1
         | 
| 129 | 
            +
             | 
| 130 | 
            +
             | 
| 131 | 
            +
            def get_svmr_res_from_st_ed_probs(svmr_gt_st_probs, svmr_gt_ed_probs, query_metas, video2idx,
         | 
| 132 | 
            +
                                              clip_length, min_pred_l, max_pred_l, max_before_nms):
         | 
| 133 | 
            +
                """
         | 
| 134 | 
            +
                Args:
         | 
| 135 | 
            +
                    svmr_gt_st_probs: np.ndarray (N_queries, L, L), value range [0, 1]
         | 
| 136 | 
            +
                    svmr_gt_ed_probs:
         | 
| 137 | 
            +
                    query_metas:
         | 
| 138 | 
            +
                    video2idx:
         | 
| 139 | 
            +
                    clip_length: float, how long each clip is in seconds
         | 
| 140 | 
            +
                    min_pred_l: int, minimum number of clips
         | 
| 141 | 
            +
                    max_pred_l: int, maximum number of clips
         | 
| 142 | 
            +
                    max_before_nms: get top-max_before_nms predictions for each query
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                Returns:
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                """
         | 
| 147 | 
            +
                svmr_res = []
         | 
| 148 | 
            +
                query_vid_names = [e["vid_name"] for e in query_metas]
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                # masking very long ones! Since most are relatively short.
         | 
| 151 | 
            +
                st_ed_prob_product = np.einsum("bm,bn->bmn", svmr_gt_st_probs, svmr_gt_ed_probs)  # (N, L, L)
         | 
| 152 | 
            +
                # extra_length_mask_array = np.ones(st_ed_prob_product.shape, dtype=bool)  # (N, L, L)
         | 
| 153 | 
            +
                # mask_triu = np.triu(extra_length_mask_array, k=min_pred_l)
         | 
| 154 | 
            +
                # mask_triu_reversed = np.logical_not(np.triu(extra_length_mask_array, k=max_pred_l))
         | 
| 155 | 
            +
                # final_prob_mask = np.logical_and(mask_triu, mask_triu_reversed)  # with valid bit to be 1
         | 
| 156 | 
            +
                valid_prob_mask = generate_min_max_length_mask(st_ed_prob_product.shape, min_l=min_pred_l, max_l=max_pred_l)
         | 
| 157 | 
            +
                st_ed_prob_product *= valid_prob_mask  # invalid location will become zero!
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                batched_sorted_triples = find_max_triples_from_upper_triangle_product(
         | 
| 160 | 
            +
                    st_ed_prob_product, top_n=max_before_nms, prob_thd=None)
         | 
| 161 | 
            +
                for i, q_vid_name in tqdm(enumerate(query_vid_names),
         | 
| 162 | 
            +
                                          desc="[SVMR] Loop over queries to generate predictions",
         | 
| 163 | 
            +
                                          total=len(query_vid_names)):  # i is query_id
         | 
| 164 | 
            +
                    q_m = query_metas[i]
         | 
| 165 | 
            +
                    video_idx = video2idx[q_vid_name]
         | 
| 166 | 
            +
                    _sorted_triples = batched_sorted_triples[i]
         | 
| 167 | 
            +
                    _sorted_triples[:, 1] += 1  # as we redefined ed_idx, which is inside the moment.
         | 
| 168 | 
            +
                    _sorted_triples[:, :2] = _sorted_triples[:, :2] * clip_length
         | 
| 169 | 
            +
                    # [video_idx(int), st(float), ed(float), score(float)]
         | 
| 170 | 
            +
                    cur_ranked_predictions = [[video_idx, ] + row for row in _sorted_triples.tolist()]
         | 
| 171 | 
            +
                    cur_query_pred = dict(
         | 
| 172 | 
            +
                        query_id=q_m["query_id"],
         | 
| 173 | 
            +
                        desc=q_m["desc"],
         | 
| 174 | 
            +
                        predictions=cur_ranked_predictions
         | 
| 175 | 
            +
                    )
         | 
| 176 | 
            +
                    svmr_res.append(cur_query_pred)
         | 
| 177 | 
            +
                return svmr_res
         | 
| 178 | 
            +
             | 
| 179 | 
            +
             | 
| 180 | 
            +
            def load_external_vr_res2(external_vr_res_path, top_n_vr_videos=5):
         | 
| 181 | 
            +
                """return a mapping from query_id to top retrieved video info"""
         | 
| 182 | 
            +
                external_vr_res = load_json(external_vr_res_path)
         | 
| 183 | 
            +
                external_vr_res = get_submission_top_n(external_vr_res, top_n=top_n_vr_videos)["VR"]
         | 
| 184 | 
            +
                query2video = {e["query_id"]: e["predictions"] for e in external_vr_res}
         | 
| 185 | 
            +
                return query2video
         | 
| 186 | 
            +
             | 
| 187 | 
            +
             | 
| 188 | 
            +
            def compute_query2ctx_info(model, eval_dataset, opt, video_metas, ctx_info,
         | 
| 189 | 
            +
                                       max_before_nms=1000, max_n_videos=100, maxtopk=40):
         | 
| 190 | 
            +
                """Use val set to do evaluation, remember to run with torch.no_grad().
         | 
| 191 | 
            +
                estimated size 20,000 (query) * 500 (hsz) * 4 / (1024**2) = 38.15 MB
         | 
| 192 | 
            +
                max_n_videos: int, use max_n_videos videos for computing VCMR/VR results
         | 
| 193 | 
            +
                """
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                video2idx = eval_dataset.video2idx
         | 
| 196 | 
            +
                # video_metas = ctx_info["video_metas"]
         | 
| 197 | 
            +
                if opt.external_inference_vr_res_path is not None:
         | 
| 198 | 
            +
                    video_idx2meta_idx = {video2idx[m["vid_name"]]: i for i, m in enumerate(video_metas)}
         | 
| 199 | 
            +
                    external_query2video = \
         | 
| 200 | 
            +
                        load_external_vr_res2(opt.external_inference_vr_res_path, top_n_vr_videos=max_n_videos)
         | 
| 201 | 
            +
                    # 「query idx: [video meta idx]」
         | 
| 202 | 
            +
                    external_query2video_meta_idx = \
         | 
| 203 | 
            +
                        {k: [video_idx2meta_idx[e[0]] for e in v] for k, v in external_query2video.items()}
         | 
| 204 | 
            +
                else:
         | 
| 205 | 
            +
                    external_query2video = None
         | 
| 206 | 
            +
                    external_query2video_meta_idx = None
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                model.eval()
         | 
| 209 | 
            +
                eval_dataset.set_data_mode("query")
         | 
| 210 | 
            +
                # eval_dataset.load_gt_vid_name_for_query(is_svmr)
         | 
| 211 | 
            +
                query_eval_loader = DataLoader(eval_dataset,
         | 
| 212 | 
            +
                                               collate_fn=start_end_collate,
         | 
| 213 | 
            +
                                               batch_size=opt.eval_query_bsz,
         | 
| 214 | 
            +
                                               num_workers=opt.num_workers,
         | 
| 215 | 
            +
                                               shuffle=False,
         | 
| 216 | 
            +
                                               pin_memory=opt.pin_memory)
         | 
| 217 | 
            +
                n_total_videos = len(video_metas)
         | 
| 218 | 
            +
                n_total_query = len(eval_dataset)
         | 
| 219 | 
            +
                bsz = opt.eval_query_bsz
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                flat_st_ed_scores_sorted_indices = np.empty((n_total_query, max_before_nms), dtype=int)
         | 
| 222 | 
            +
                flat_st_ed_sorted_scores = np.zeros((n_total_query, max_before_nms), dtype=np.float32)
         | 
| 223 | 
            +
                sorted_q2c_indices = np.empty((n_total_query, max_n_videos), dtype=int)
         | 
| 224 | 
            +
                sorted_q2c_scores = np.empty((n_total_query, max_n_videos), dtype=np.float32)
         | 
| 225 | 
            +
             | 
| 226 | 
            +
             | 
| 227 | 
            +
                query_metas = []
         | 
| 228 | 
            +
                for idx, batch in tqdm(
         | 
| 229 | 
            +
                        enumerate(query_eval_loader), desc="Computing q embedding", total=len(query_eval_loader)):
         | 
| 230 | 
            +
                    _query_metas = batch[0]
         | 
| 231 | 
            +
                    query_metas.extend(batch[0])
         | 
| 232 | 
            +
                    model_inputs = prepare_batch_inputs(batch[1], device=opt.device, non_blocking=opt.pin_memory)
         | 
| 233 | 
            +
                    # query_context_scores (_N_q, N_videos), st_prob, ed_prob (_N_q, N_videos, L)
         | 
| 234 | 
            +
                    _query_context_scores, _st_probs, _ed_probs = \
         | 
| 235 | 
            +
                        model.get_pred_from_raw_query(model_inputs["query_feat"], model_inputs["query_mask"],
         | 
| 236 | 
            +
                                                      ctx_info["video_feat1"], ctx_info["video_feat2"],
         | 
| 237 | 
            +
                                                      ctx_info["video_mask"],
         | 
| 238 | 
            +
                                                      ctx_info["sub_feat1"], ctx_info["sub_feat2"],
         | 
| 239 | 
            +
                                                      ctx_info["sub_mask"],
         | 
| 240 | 
            +
                                                      cross=True)
         | 
| 241 | 
            +
                    # _query_context_scores = _query_context_scores + 1  # move cosine similarity to [0, 2]
         | 
| 242 | 
            +
                    # To give more importance to top scores, the higher opt.alpha is the more importance will be given
         | 
| 243 | 
            +
                    _query_context_scores = torch.exp(opt.q2c_alpha * _query_context_scores)
         | 
| 244 | 
            +
             | 
| 245 | 
            +
                    # normalize to get true probabilities!!!
         | 
| 246 | 
            +
                    # the probabilities here are already (pad) masked, so only need to do softmax
         | 
| 247 | 
            +
                    _st_probs = F.softmax(_st_probs, dim=-1)  # (_N_q, N_videos, L)
         | 
| 248 | 
            +
                    _ed_probs = F.softmax(_ed_probs, dim=-1)
         | 
| 249 | 
            +
             | 
| 250 | 
            +
                    if external_query2video is None:
         | 
| 251 | 
            +
                        _sorted_q2c_scores, _sorted_q2c_indices = \
         | 
| 252 | 
            +
                            torch.topk(_query_context_scores, max_n_videos, dim=1, largest=True)
         | 
| 253 | 
            +
                    else:
         | 
| 254 | 
            +
                        relevant_video_info = [external_query2video[qm["query_id"]] for qm in _query_metas]
         | 
| 255 | 
            +
                        _sorted_q2c_indices = _query_context_scores.new(
         | 
| 256 | 
            +
                            [[video_idx2meta_idx[sub_e[0]] for sub_e in e] for e in relevant_video_info]).long()
         | 
| 257 | 
            +
                        _sorted_q2c_scores = _query_context_scores.new(
         | 
| 258 | 
            +
                            [[sub_e[3] for sub_e in e] for e in relevant_video_info])
         | 
| 259 | 
            +
                        _sorted_q2c_scores = torch.exp(opt.q2c_alpha * _sorted_q2c_scores)
         | 
| 260 | 
            +
                    # collect data for vr and vcmr
         | 
| 261 | 
            +
                    sorted_q2c_indices[idx * bsz:(idx + 1) * bsz] = _sorted_q2c_indices.cpu().numpy()
         | 
| 262 | 
            +
                    sorted_q2c_scores[idx * bsz:(idx + 1) * bsz] = _sorted_q2c_scores.cpu().numpy()
         | 
| 263 | 
            +
             | 
| 264 | 
            +
             | 
| 265 | 
            +
                    # Get VCMR results
         | 
| 266 | 
            +
                    # compute combined scores
         | 
| 267 | 
            +
                    row_indices = torch.arange(0, len(_st_probs), device=opt.device).unsqueeze(1)
         | 
| 268 | 
            +
                    _st_probs = _st_probs[row_indices, _sorted_q2c_indices]  # (_N_q, max_n_videos, L)
         | 
| 269 | 
            +
                    _ed_probs = _ed_probs[row_indices, _sorted_q2c_indices]
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                    # (_N_q, max_n_videos, L, L)
         | 
| 272 | 
            +
                    _st_ed_scores = torch.einsum("qvm,qv,qvn->qvmn", _st_probs, _sorted_q2c_scores, _ed_probs)
         | 
| 273 | 
            +
                    valid_prob_mask = generate_min_max_length_mask(
         | 
| 274 | 
            +
                        _st_ed_scores.shape, min_l=opt.min_pred_l, max_l=opt.max_pred_l)
         | 
| 275 | 
            +
                    _st_ed_scores *= torch.from_numpy(
         | 
| 276 | 
            +
                        valid_prob_mask).to(_st_ed_scores.device)  # invalid location will become zero!
         | 
| 277 | 
            +
             | 
| 278 | 
            +
                    # sort across the top-max_n_videos videos (by flatten from the 2nd dim)
         | 
| 279 | 
            +
                    # the indices here are local indices, not global indices
         | 
| 280 | 
            +
                    _n_q = _st_ed_scores.shape[0]
         | 
| 281 | 
            +
                    _flat_st_ed_scores = _st_ed_scores.reshape(_n_q, -1)  # (N_q, max_n_videos*L*L)
         | 
| 282 | 
            +
                    _flat_st_ed_sorted_scores, _flat_st_ed_scores_sorted_indices = \
         | 
| 283 | 
            +
                        torch.sort(_flat_st_ed_scores, dim=1, descending=True)
         | 
| 284 | 
            +
                    # collect data
         | 
| 285 | 
            +
                    flat_st_ed_sorted_scores[idx * bsz:(idx + 1) * bsz] = \
         | 
| 286 | 
            +
                        _flat_st_ed_sorted_scores[:, :max_before_nms].cpu().numpy()
         | 
| 287 | 
            +
                    flat_st_ed_scores_sorted_indices[idx * bsz:(idx + 1) * bsz] = \
         | 
| 288 | 
            +
                        _flat_st_ed_scores_sorted_indices[:, :max_before_nms].cpu().numpy()
         | 
| 289 | 
            +
             | 
| 290 | 
            +
                    if opt.debug:
         | 
| 291 | 
            +
                        break
         | 
| 292 | 
            +
             | 
| 293 | 
            +
             | 
| 294 | 
            +
                vcmr_res = {}
         | 
| 295 | 
            +
                for i, (_flat_st_ed_scores_sorted_indices, _flat_st_ed_sorted_scores) in tqdm(
         | 
| 296 | 
            +
                        enumerate(zip(flat_st_ed_scores_sorted_indices, flat_st_ed_sorted_scores)),
         | 
| 297 | 
            +
                        desc="[VCMR] Loop over queries to generate predictions", total=n_total_query):  # i is query_idx
         | 
| 298 | 
            +
                    # list([video_idx(int), st(float), ed(float), score(float)])
         | 
| 299 | 
            +
                    video_meta_indices_local, pred_st_indices, pred_ed_indices = \
         | 
| 300 | 
            +
                        np.unravel_index(_flat_st_ed_scores_sorted_indices,
         | 
| 301 | 
            +
                                            shape=(max_n_videos, opt.max_ctx_l, opt.max_ctx_l))
         | 
| 302 | 
            +
                    # video_meta_indices_local refers to the indices among the top-max_n_videos
         | 
| 303 | 
            +
                    # video_meta_indices refers to the indices in all the videos, which is the True indices
         | 
| 304 | 
            +
                    video_meta_indices = sorted_q2c_indices[i, video_meta_indices_local]
         | 
| 305 | 
            +
             | 
| 306 | 
            +
                    pred_st_in_seconds = pred_st_indices.astype(np.float32) * opt.clip_length
         | 
| 307 | 
            +
                    pred_ed_in_seconds = pred_ed_indices.astype(np.float32) * opt.clip_length + opt.clip_length
         | 
| 308 | 
            +
                    cur_vcmr_redictions = []
         | 
| 309 | 
            +
                    for j, (v_meta_idx, v_score) in enumerate(zip(video_meta_indices, _flat_st_ed_sorted_scores)):  # videos
         | 
| 310 | 
            +
                        video_idx = video2idx[video_metas[v_meta_idx]["vid_name"]]
         | 
| 311 | 
            +
                        cur_vcmr_redictions.append(
         | 
| 312 | 
            +
                                {
         | 
| 313 | 
            +
                                "video_name": video_metas[v_meta_idx]["vid_name"],
         | 
| 314 | 
            +
                                "timestamp": [float(pred_st_in_seconds[j]), float(pred_ed_in_seconds[j])],
         | 
| 315 | 
            +
                                "model_scores": float(v_score)
         | 
| 316 | 
            +
                            }
         | 
| 317 | 
            +
                        )
         | 
| 318 | 
            +
                    query_id=query_metas[i]["query_id"]
         | 
| 319 | 
            +
                    vcmr_res[query_id] = cur_vcmr_redictions[:maxtopk]
         | 
| 320 | 
            +
                return vcmr_res
         | 
| 321 | 
            +
             | 
| 322 | 
            +
             | 
| 323 | 
            +
            def get_eval_res(model,  eval_dataset, context_data, opt, maxtopk):
         | 
| 324 | 
            +
                """compute and save query and video proposal embeddings"""
         | 
| 325 | 
            +
                
         | 
| 326 | 
            +
                video_metas, context_info = compute_context_info(model, context_data, opt)
         | 
| 327 | 
            +
                eval_res = compute_query2ctx_info(model, eval_dataset, opt, video_metas, context_info, 
         | 
| 328 | 
            +
                                                max_before_nms=opt.max_before_nms, max_n_videos=opt.max_vcmr_video, maxtopk=maxtopk)
         | 
| 329 | 
            +
                return eval_res
         | 
| 330 | 
            +
             | 
| 331 | 
            +
             | 
| 332 | 
            +
            POST_PROCESSING_MMS_FUNC = {
         | 
| 333 | 
            +
                "SVMR": post_processing_svmr_nms,
         | 
| 334 | 
            +
                "VCMR": post_processing_vcmr_nms
         | 
| 335 | 
            +
            }
         | 
| 336 | 
            +
             | 
| 337 | 
            +
            # def get_prediction_top_n(list_dict_predictions, top_n):
         | 
| 338 | 
            +
            #     top_n_res = []
         | 
| 339 | 
            +
            #     for e in list_dict_predictions:
         | 
| 340 | 
            +
            #         e["predictions"] = e["predictions"][:top_n]
         | 
| 341 | 
            +
            #         top_n_res.append(e)
         | 
| 342 | 
            +
            #     return top_n_res
         | 
| 343 | 
            +
             | 
| 344 | 
            +
             | 
| 345 | 
            +
            def eval_epoch(model, eval_dataset, context_data, logger, opt, max_after_nms, iou_thds, topks):
         | 
| 346 | 
            +
                """max_after_nms: always set to 100, since the eval script only evaluate top-100"""
         | 
| 347 | 
            +
                # IOU_THDS = (0.3, 0.5, 0.7)
         | 
| 348 | 
            +
                
         | 
| 349 | 
            +
                model.eval()
         | 
| 350 | 
            +
                pred_data = get_eval_res(model, eval_dataset, context_data, opt, max(topks))
         | 
| 351 | 
            +
                # pred_data = get_prediction_top_n(eval_res, top_n=max_after_nms)
         | 
| 352 | 
            +
                gt_data = eval_dataset.ground_truth
         | 
| 353 | 
            +
                average_ndcg = calculate_ndcg_iou(gt_data, pred_data, iou_thds, topks)
         | 
| 354 | 
            +
                return average_ndcg, pred_data
         | 
| 355 | 
            +
             | 
| 356 | 
            +
            def setup_model(opt):
         | 
| 357 | 
            +
                """Load model from checkpoint and move to specified device"""
         | 
| 358 | 
            +
                checkpoint = torch.load(opt.ckpt_filepath)
         | 
| 359 | 
            +
                loaded_model_cfg = checkpoint["model_cfg"]
         | 
| 360 | 
            +
                loaded_model_cfg["stack_conv_predictor_conv_kernel_sizes"] = -1
         | 
| 361 | 
            +
                model = XML(loaded_model_cfg)
         | 
| 362 | 
            +
                model.load_state_dict(checkpoint["model"])
         | 
| 363 | 
            +
                logger.info("Loaded model saved at epoch {} from checkpoint: {}"
         | 
| 364 | 
            +
                            .format(checkpoint["epoch"], opt.ckpt_filepath))
         | 
| 365 | 
            +
             | 
| 366 | 
            +
                if opt.device.type == "cuda":
         | 
| 367 | 
            +
                    logger.info("CUDA enabled.")
         | 
| 368 | 
            +
                    model.to(opt.device)
         | 
| 369 | 
            +
                    if len(opt.device_ids) > 1:
         | 
| 370 | 
            +
                        logger.info("Use multi GPU", opt.device_ids)
         | 
| 371 | 
            +
                        model = torch.nn.DataParallel(model, device_ids=opt.device_ids)  # use multi GPU
         | 
| 372 | 
            +
                return model
         | 
| 373 | 
            +
             | 
| 374 | 
            +
             | 
| 375 | 
            +
            def start_inference():
         | 
| 376 | 
            +
                logger.info("Setup config, data and model...")
         | 
| 377 | 
            +
                opt = TestOptions().parse()
         | 
| 378 | 
            +
                cudnn.benchmark = False
         | 
| 379 | 
            +
                cudnn.deterministic = True
         | 
| 380 | 
            +
             | 
| 381 | 
            +
                assert opt.eval_path is not None
         | 
| 382 | 
            +
                eval_dataset = StartEndEvalDataset(
         | 
| 383 | 
            +
                    dset_name=opt.dset_name,
         | 
| 384 | 
            +
                    eval_split_name=opt.eval_split_name,  # should only be val set
         | 
| 385 | 
            +
                    data_path=opt.eval_path,
         | 
| 386 | 
            +
                    desc_bert_path_or_handler=opt.desc_bert_path,
         | 
| 387 | 
            +
                    sub_bert_path_or_handler=opt.sub_bert_path,
         | 
| 388 | 
            +
                    max_desc_len=opt.max_desc_l,
         | 
| 389 | 
            +
                    max_ctx_len=opt.max_ctx_l,
         | 
| 390 | 
            +
                    corpus_path=opt.corpus_path,
         | 
| 391 | 
            +
                    vid_feat_path_or_handler=opt.vid_feat_path,
         | 
| 392 | 
            +
                    clip_length=opt.clip_length,
         | 
| 393 | 
            +
                    ctx_mode=opt.ctx_mode,
         | 
| 394 | 
            +
                    data_mode="query",
         | 
| 395 | 
            +
                    h5driver=opt.h5driver,
         | 
| 396 | 
            +
                    data_ratio=opt.data_ratio,
         | 
| 397 | 
            +
                    normalize_vfeat=not opt.no_norm_vfeat,
         | 
| 398 | 
            +
                    normalize_tfeat=not opt.no_norm_tfeat
         | 
| 399 | 
            +
                )
         | 
| 400 | 
            +
             | 
| 401 | 
            +
                model = setup_model(opt)
         | 
| 402 | 
            +
                save_submission_filename = "inference_{}_{}_{}_predictions_{}.json".format(
         | 
| 403 | 
            +
                    opt.dset_name, opt.eval_split_name, opt.eval_id, "_".join(opt.tasks))
         | 
| 404 | 
            +
                logger.info("Starting inference...")
         | 
| 405 | 
            +
                with torch.no_grad():
         | 
| 406 | 
            +
                    metrics_no_nms, metrics_nms, latest_file_paths = \
         | 
| 407 | 
            +
                        eval_epoch(model, eval_dataset, opt, save_submission_filename,
         | 
| 408 | 
            +
                                   tasks=opt.tasks, max_after_nms=100)
         | 
| 409 | 
            +
                logger.info("metrics_no_nms \n{}".format(pprint.pformat(metrics_no_nms, indent=4)))
         | 
| 410 | 
            +
                logger.info("metrics_nms \n{}".format(pprint.pformat(metrics_nms, indent=4)))
         | 
| 411 | 
            +
             | 
| 412 | 
            +
             | 
| 413 | 
            +
            if __name__ == '__main__':
         | 
| 414 | 
            +
                start_inference()
         | 
    	
        baselines/crossmodal_moment_localization/model_components.py
    ADDED
    
    | @@ -0,0 +1,317 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import math
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            import torch.nn as nn
         | 
| 4 | 
            +
            import torch.nn.functional as F
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            class DepthwiseSeparableConv(nn.Module):
         | 
| 8 | 
            +
                """
         | 
| 9 | 
            +
                Depth-wise separable convolution uses less parameters to generate output by convolution.
         | 
| 10 | 
            +
                :Examples:
         | 
| 11 | 
            +
                    >>> m = DepthwiseSeparableConv(300, 200, 5, dim=1)
         | 
| 12 | 
            +
                    >>> input_tensor = torch.randn(32, 300, 20)
         | 
| 13 | 
            +
                    >>> output = m(input_tensor)
         | 
| 14 | 
            +
                """
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                def __init__(self, in_ch, out_ch, k, dim=1, relu=True):
         | 
| 17 | 
            +
                    """
         | 
| 18 | 
            +
                    :param in_ch: input hidden dimension size
         | 
| 19 | 
            +
                    :param out_ch: output hidden dimension size
         | 
| 20 | 
            +
                    :param k: kernel size
         | 
| 21 | 
            +
                    :param dim: default 1. 1D conv or 2D conv
         | 
| 22 | 
            +
                    """
         | 
| 23 | 
            +
                    super(DepthwiseSeparableConv, self).__init__()
         | 
| 24 | 
            +
                    self.relu = relu
         | 
| 25 | 
            +
                    if dim == 1:
         | 
| 26 | 
            +
                        self.depthwise_conv = nn.Conv1d(in_channels=in_ch, out_channels=in_ch,
         | 
| 27 | 
            +
                                                        kernel_size=k, groups=in_ch, padding=k//2)
         | 
| 28 | 
            +
                        self.pointwise_conv = nn.Conv1d(in_channels=in_ch, out_channels=out_ch,
         | 
| 29 | 
            +
                                                        kernel_size=1, padding=0)
         | 
| 30 | 
            +
                    elif dim == 2:
         | 
| 31 | 
            +
                        self.depthwise_conv = nn.Conv2d(in_channels=in_ch, out_channels=in_ch,
         | 
| 32 | 
            +
                                                        kernel_size=k, groups=in_ch, padding=k//2)
         | 
| 33 | 
            +
                        self.pointwise_conv = nn.Conv2d(in_channels=in_ch, out_channels=out_ch,
         | 
| 34 | 
            +
                                                        kernel_size=1, padding=0)
         | 
| 35 | 
            +
                    else:
         | 
| 36 | 
            +
                        raise Exception("Incorrect dimension!")
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                def forward(self, x):
         | 
| 39 | 
            +
                    """
         | 
| 40 | 
            +
                    :Input: (N, L_in, D)
         | 
| 41 | 
            +
                    :Output: (N, L_out, D)
         | 
| 42 | 
            +
                    """
         | 
| 43 | 
            +
                    x = x.transpose(1, 2)
         | 
| 44 | 
            +
                    if self.relu:
         | 
| 45 | 
            +
                        out = F.relu(self.pointwise_conv(self.depthwise_conv(x)), inplace=True)
         | 
| 46 | 
            +
                    else:
         | 
| 47 | 
            +
                        out = self.pointwise_conv(self.depthwise_conv(x))
         | 
| 48 | 
            +
                    return out.transpose(1, 2)  # (N, L, D)
         | 
| 49 | 
            +
             | 
| 50 | 
            +
             | 
| 51 | 
            +
            class ConvEncoder(nn.Module):
         | 
| 52 | 
            +
                def __init__(self, kernel_size=7, n_filters=128, dropout=0.1):
         | 
| 53 | 
            +
                    super(ConvEncoder, self).__init__()
         | 
| 54 | 
            +
                    self.dropout = nn.Dropout(dropout)
         | 
| 55 | 
            +
                    self.layer_norm = nn.LayerNorm(n_filters)
         | 
| 56 | 
            +
                    self.conv = DepthwiseSeparableConv(in_ch=n_filters, out_ch=n_filters, k=kernel_size, relu=True)
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                def forward(self, x, mask):
         | 
| 59 | 
            +
                    """
         | 
| 60 | 
            +
                    :param x: (N, L, D)
         | 
| 61 | 
            +
                    :param mask: (N, L), is not used.
         | 
| 62 | 
            +
                    :return: (N, L, D)
         | 
| 63 | 
            +
                    """
         | 
| 64 | 
            +
                    return self.layer_norm(self.dropout(self.conv(x)) + x)  # (N, L, D)
         | 
| 65 | 
            +
             | 
| 66 | 
            +
             | 
| 67 | 
            +
            class TrainablePositionalEncoding(nn.Module):
         | 
| 68 | 
            +
                """Construct the embeddings from word, position and token_type embeddings.
         | 
| 69 | 
            +
                """
         | 
| 70 | 
            +
                def __init__(self, max_position_embeddings, hidden_size, dropout=0.1):
         | 
| 71 | 
            +
                    super(TrainablePositionalEncoding, self).__init__()
         | 
| 72 | 
            +
                    self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
         | 
| 73 | 
            +
                    self.LayerNorm = nn.LayerNorm(hidden_size)
         | 
| 74 | 
            +
                    self.dropout = nn.Dropout(dropout)
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                def forward(self, input_feat):
         | 
| 77 | 
            +
                    """
         | 
| 78 | 
            +
                    Args:
         | 
| 79 | 
            +
                        input_feat: (N, L, D)
         | 
| 80 | 
            +
                    """
         | 
| 81 | 
            +
                    bsz, seq_length = input_feat.shape[:2]
         | 
| 82 | 
            +
                    position_ids = torch.arange(seq_length, dtype=torch.long, device=input_feat.device)
         | 
| 83 | 
            +
                    position_ids = position_ids.unsqueeze(0).repeat(bsz, 1)  # (N, L)
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                    position_embeddings = self.position_embeddings(position_ids)
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                    embeddings = self.LayerNorm(input_feat + position_embeddings)
         | 
| 88 | 
            +
                    embeddings = self.dropout(embeddings)
         | 
| 89 | 
            +
                    return embeddings
         | 
| 90 | 
            +
             | 
| 91 | 
            +
             | 
| 92 | 
            +
            class PositionEncoding(nn.Module):
         | 
| 93 | 
            +
                """
         | 
| 94 | 
            +
                Add positional information to input tensor.
         | 
| 95 | 
            +
                :Examples:
         | 
| 96 | 
            +
                    >>> model = PositionEncoding(n_filters=6, max_len=10)
         | 
| 97 | 
            +
                    >>> test_input1 = torch.zeros(3, 10, 6)
         | 
| 98 | 
            +
                    >>> output1 = model(test_input1)
         | 
| 99 | 
            +
                    >>> output1.size()
         | 
| 100 | 
            +
                    >>> test_input2 = torch.zeros(5, 3, 9, 6)
         | 
| 101 | 
            +
                    >>> output2 = model(test_input2)
         | 
| 102 | 
            +
                    >>> output2.size()
         | 
| 103 | 
            +
                """
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                def __init__(self, n_filters=128, max_len=500, pe_type="cosine"):
         | 
| 106 | 
            +
                    """
         | 
| 107 | 
            +
                    :param n_filters: same with input hidden size
         | 
| 108 | 
            +
                    :param max_len: maximum sequence length
         | 
| 109 | 
            +
                    :param pe_type: cosine or linear or None
         | 
| 110 | 
            +
                    """
         | 
| 111 | 
            +
                    super(PositionEncoding, self).__init__()
         | 
| 112 | 
            +
                    self.pe_type = pe_type
         | 
| 113 | 
            +
                    if pe_type != "none":
         | 
| 114 | 
            +
                        position = torch.arange(0, max_len).float().unsqueeze(1)
         | 
| 115 | 
            +
                        if pe_type == "cosine":
         | 
| 116 | 
            +
                            # Compute the positional encodings once in log space.
         | 
| 117 | 
            +
                            pe = torch.zeros(max_len, n_filters)  # (L, D)
         | 
| 118 | 
            +
                            div_term = torch.exp(torch.arange(0, n_filters, 2).float() * - (math.log(10000.0) / n_filters))
         | 
| 119 | 
            +
                            pe[:, 0::2] = torch.sin(position * div_term)
         | 
| 120 | 
            +
                            pe[:, 1::2] = torch.cos(position * div_term)
         | 
| 121 | 
            +
                        elif pe_type == "linear":
         | 
| 122 | 
            +
                            pe = position / max_len
         | 
| 123 | 
            +
                        else:
         | 
| 124 | 
            +
                            raise ValueError
         | 
| 125 | 
            +
                        self.register_buffer("pe", pe)  # buffer is a tensor, not a variable, (L, D)
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                def forward(self, x):
         | 
| 128 | 
            +
                    """
         | 
| 129 | 
            +
                    :Input: (*, L, D)
         | 
| 130 | 
            +
                    :Output: (*, L, D) the same size as input
         | 
| 131 | 
            +
                    """
         | 
| 132 | 
            +
                    if self.pe_type != "none":
         | 
| 133 | 
            +
                        pe = self.pe.data[:x.size(-2), :]  # (#x.size(-2), n_filters)
         | 
| 134 | 
            +
                        extra_dim = len(x.size()) - 2
         | 
| 135 | 
            +
                        for _ in range(extra_dim):
         | 
| 136 | 
            +
                            pe = pe.unsqueeze(0)
         | 
| 137 | 
            +
                        x = x + pe
         | 
| 138 | 
            +
                    return x
         | 
| 139 | 
            +
             | 
| 140 | 
            +
             | 
| 141 | 
            +
            class LinearLayer(nn.Module):
         | 
| 142 | 
            +
                """linear layer configurable with layer normalization, dropout, ReLU."""
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
         | 
| 145 | 
            +
                    super(LinearLayer, self).__init__()
         | 
| 146 | 
            +
                    self.relu = relu
         | 
| 147 | 
            +
                    self.layer_norm = layer_norm
         | 
| 148 | 
            +
                    if layer_norm:
         | 
| 149 | 
            +
                        self.LayerNorm = nn.LayerNorm(in_hsz)
         | 
| 150 | 
            +
                    layers = [
         | 
| 151 | 
            +
                        nn.Dropout(dropout),
         | 
| 152 | 
            +
                        nn.Linear(in_hsz, out_hsz)
         | 
| 153 | 
            +
                    ]
         | 
| 154 | 
            +
                    self.net = nn.Sequential(*layers)
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                def forward(self, x):
         | 
| 157 | 
            +
                    """(N, L, D)"""
         | 
| 158 | 
            +
                    if self.layer_norm:
         | 
| 159 | 
            +
                        x = self.LayerNorm(x)
         | 
| 160 | 
            +
                    x = self.net(x)
         | 
| 161 | 
            +
                    if self.relu:
         | 
| 162 | 
            +
                        x = F.relu(x, inplace=True)
         | 
| 163 | 
            +
                    return x  # (N, L, D)
         | 
| 164 | 
            +
             | 
| 165 | 
            +
             | 
| 166 | 
            +
            bert_config = dict(
         | 
| 167 | 
            +
                hidden_size=768,
         | 
| 168 | 
            +
                intermediate_size=768,
         | 
| 169 | 
            +
                hidden_dropout_prob=0.1,
         | 
| 170 | 
            +
                attention_probs_dropout_prob=0.1,
         | 
| 171 | 
            +
                num_attention_heads=4,
         | 
| 172 | 
            +
            )
         | 
| 173 | 
            +
             | 
| 174 | 
            +
             | 
| 175 | 
            +
            class BertLayer(nn.Module):
         | 
| 176 | 
            +
                def __init__(self, config, use_self_attention=True):
         | 
| 177 | 
            +
                    super(BertLayer, self).__init__()
         | 
| 178 | 
            +
                    self.use_self_attention = use_self_attention
         | 
| 179 | 
            +
                    if use_self_attention:
         | 
| 180 | 
            +
                        self.attention = BertAttention(config)
         | 
| 181 | 
            +
                    self.intermediate = BertIntermediate(config)
         | 
| 182 | 
            +
                    self.output = BertOutput(config)
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                def forward(self, hidden_states, attention_mask):
         | 
| 185 | 
            +
                    """
         | 
| 186 | 
            +
                    Args:
         | 
| 187 | 
            +
                        hidden_states:  (N, L, D)
         | 
| 188 | 
            +
                        attention_mask:  (N, L) with 1 indicate valid, 0 indicates invalid
         | 
| 189 | 
            +
                    Returns:
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                    """
         | 
| 192 | 
            +
                    if self.use_self_attention:
         | 
| 193 | 
            +
                        attention_output = self.attention(hidden_states, attention_mask)
         | 
| 194 | 
            +
                    else:
         | 
| 195 | 
            +
                        attention_output = hidden_states
         | 
| 196 | 
            +
                    intermediate_output = self.intermediate(attention_output)
         | 
| 197 | 
            +
                    layer_output = self.output(intermediate_output, attention_output)
         | 
| 198 | 
            +
                    return layer_output
         | 
| 199 | 
            +
             | 
| 200 | 
            +
             | 
| 201 | 
            +
            class BertAttention(nn.Module):
         | 
| 202 | 
            +
                def __init__(self, config):
         | 
| 203 | 
            +
                    super(BertAttention, self).__init__()
         | 
| 204 | 
            +
                    self.self = BertSelfAttention(config)
         | 
| 205 | 
            +
                    self.output = BertSelfOutput(config)
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                def forward(self, input_tensor, attention_mask):
         | 
| 208 | 
            +
                    """
         | 
| 209 | 
            +
                    Args:
         | 
| 210 | 
            +
                        input_tensor: (N, L, D)
         | 
| 211 | 
            +
                        attention_mask: (N, L)
         | 
| 212 | 
            +
                    Returns:
         | 
| 213 | 
            +
                    """
         | 
| 214 | 
            +
                    self_output = self.self(input_tensor, input_tensor, input_tensor, attention_mask)
         | 
| 215 | 
            +
                    attention_output = self.output(self_output, input_tensor)
         | 
| 216 | 
            +
                    return attention_output
         | 
| 217 | 
            +
             | 
| 218 | 
            +
             | 
| 219 | 
            +
            class BertIntermediate(nn.Module):
         | 
| 220 | 
            +
                def __init__(self, config):
         | 
| 221 | 
            +
                    super(BertIntermediate, self).__init__()
         | 
| 222 | 
            +
                    self.dense = nn.Sequential(
         | 
| 223 | 
            +
                        nn.Linear(config.hidden_size, config.intermediate_size),
         | 
| 224 | 
            +
                        nn.ReLU(True))
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                def forward(self, hidden_states):
         | 
| 227 | 
            +
                    return self.dense(hidden_states)
         | 
| 228 | 
            +
             | 
| 229 | 
            +
             | 
| 230 | 
            +
            class BertOutput(nn.Module):
         | 
| 231 | 
            +
                def __init__(self, config):
         | 
| 232 | 
            +
                    super(BertOutput, self).__init__()
         | 
| 233 | 
            +
                    self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
         | 
| 234 | 
            +
                    self.LayerNorm = nn.LayerNorm(config.hidden_size)
         | 
| 235 | 
            +
                    self.dropout = nn.Dropout(config.hidden_dropout_prob)
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                def forward(self, hidden_states, input_tensor):
         | 
| 238 | 
            +
                    hidden_states = self.dense(hidden_states)
         | 
| 239 | 
            +
                    hidden_states = self.dropout(hidden_states)
         | 
| 240 | 
            +
                    hidden_states = self.LayerNorm(hidden_states + input_tensor)
         | 
| 241 | 
            +
                    return hidden_states
         | 
| 242 | 
            +
             | 
| 243 | 
            +
             | 
| 244 | 
            +
            class BertSelfAttention(nn.Module):
         | 
| 245 | 
            +
                def __init__(self, config):
         | 
| 246 | 
            +
                    super(BertSelfAttention, self).__init__()
         | 
| 247 | 
            +
                    if config.hidden_size % config.num_attention_heads != 0:
         | 
| 248 | 
            +
                        raise ValueError(
         | 
| 249 | 
            +
                            "The hidden size (%d) is not a multiple of the number of attention "
         | 
| 250 | 
            +
                            "heads (%d)" % (config.hidden_size, config.num_attention_heads))
         | 
| 251 | 
            +
                    self.num_attention_heads = config.num_attention_heads
         | 
| 252 | 
            +
                    self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         | 
| 253 | 
            +
                    self.all_head_size = self.num_attention_heads * self.attention_head_size
         | 
| 254 | 
            +
             | 
| 255 | 
            +
                    self.query = nn.Linear(config.hidden_size, self.all_head_size)
         | 
| 256 | 
            +
                    self.key = nn.Linear(config.hidden_size, self.all_head_size)
         | 
| 257 | 
            +
                    self.value = nn.Linear(config.hidden_size, self.all_head_size)
         | 
| 258 | 
            +
             | 
| 259 | 
            +
                    self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                def transpose_for_scores(self, x):
         | 
| 262 | 
            +
                    new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)  # (N, L, nh, dh)
         | 
| 263 | 
            +
                    x = x.view(*new_x_shape)
         | 
| 264 | 
            +
                    return x.permute(0, 2, 1, 3)  # (N, nh, L, dh)
         | 
| 265 | 
            +
             | 
| 266 | 
            +
                def forward(self, query_states, key_states, value_states, attention_mask):
         | 
| 267 | 
            +
                    """
         | 
| 268 | 
            +
                    Args:
         | 
| 269 | 
            +
                        query_states: (N, Lq, D)
         | 
| 270 | 
            +
                        key_states: (N, L, D)
         | 
| 271 | 
            +
                        value_states: (N, L, D)
         | 
| 272 | 
            +
                        attention_mask: (N, Lq, L)
         | 
| 273 | 
            +
                    Returns:
         | 
| 274 | 
            +
                    """
         | 
| 275 | 
            +
                    # only need to mask the dimension where the softmax (last dim) is applied, as another dim (second last)
         | 
| 276 | 
            +
                    # will be ignored in future computation anyway
         | 
| 277 | 
            +
                    attention_mask = (1 - attention_mask.unsqueeze(1)) * -10000.  # (N, 1, Lq, L)
         | 
| 278 | 
            +
                    mixed_query_layer = self.query(query_states)
         | 
| 279 | 
            +
                    mixed_key_layer = self.key(key_states)
         | 
| 280 | 
            +
                    mixed_value_layer = self.value(value_states)
         | 
| 281 | 
            +
             | 
| 282 | 
            +
                    query_layer = self.transpose_for_scores(mixed_query_layer)  # (N, nh, Lq, dh)
         | 
| 283 | 
            +
                    key_layer = self.transpose_for_scores(mixed_key_layer)  # (N, nh, L, dh)
         | 
| 284 | 
            +
                    value_layer = self.transpose_for_scores(mixed_value_layer)  # (N, nh, L, dh)
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                    # Take the dot product between "query" and "key" to get the raw attention scores.
         | 
| 287 | 
            +
                    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))  # (N, nh, Lq, L)
         | 
| 288 | 
            +
                    attention_scores = attention_scores / math.sqrt(self.attention_head_size)
         | 
| 289 | 
            +
                    # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
         | 
| 290 | 
            +
                    attention_scores = attention_scores + attention_mask
         | 
| 291 | 
            +
             | 
| 292 | 
            +
                    # Normalize the attention scores to probabilities.
         | 
| 293 | 
            +
                    attention_probs = nn.Softmax(dim=-1)(attention_scores)
         | 
| 294 | 
            +
             | 
| 295 | 
            +
                    # This is actually dropping out entire tokens to attend to, which might
         | 
| 296 | 
            +
                    # seem a bit unusual, but is taken from the original Transformer paper.
         | 
| 297 | 
            +
                    attention_probs = self.dropout(attention_probs)
         | 
| 298 | 
            +
             | 
| 299 | 
            +
                    context_layer = torch.matmul(attention_probs, value_layer)
         | 
| 300 | 
            +
                    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         | 
| 301 | 
            +
                    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         | 
| 302 | 
            +
                    context_layer = context_layer.view(*new_context_layer_shape)
         | 
| 303 | 
            +
                    return context_layer
         | 
| 304 | 
            +
             | 
| 305 | 
            +
             | 
| 306 | 
            +
            class BertSelfOutput(nn.Module):
         | 
| 307 | 
            +
                def __init__(self, config):
         | 
| 308 | 
            +
                    super(BertSelfOutput, self).__init__()
         | 
| 309 | 
            +
                    self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         | 
| 310 | 
            +
                    self.LayerNorm = nn.LayerNorm(config.hidden_size)
         | 
| 311 | 
            +
                    self.dropout = nn.Dropout(config.hidden_dropout_prob)
         | 
| 312 | 
            +
             | 
| 313 | 
            +
                def forward(self, hidden_states, input_tensor):
         | 
| 314 | 
            +
                    hidden_states = self.dense(hidden_states)
         | 
| 315 | 
            +
                    hidden_states = self.dropout(hidden_states)
         | 
| 316 | 
            +
                    hidden_states = self.LayerNorm(hidden_states + input_tensor)
         | 
| 317 | 
            +
                    return hidden_states
         | 
    	
        baselines/crossmodal_moment_localization/model_xml.py
    ADDED
    
    | @@ -0,0 +1,642 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import math
         | 
| 2 | 
            +
            import copy
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import torch.nn as nn
         | 
| 5 | 
            +
            import torch.nn.functional as F
         | 
| 6 | 
            +
            from easydict import EasyDict as edict
         | 
| 7 | 
            +
            from baselines.crossmodal_moment_localization.model_components import \
         | 
| 8 | 
            +
                BertAttention, PositionEncoding, LinearLayer, BertSelfAttention, TrainablePositionalEncoding, ConvEncoder
         | 
| 9 | 
            +
            from utils.model_utils import RNNEncoder
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            base_bert_layer_config = dict(
         | 
| 12 | 
            +
                hidden_size=768,
         | 
| 13 | 
            +
                intermediate_size=768,
         | 
| 14 | 
            +
                hidden_dropout_prob=0.1,
         | 
| 15 | 
            +
                attention_probs_dropout_prob=0.1,
         | 
| 16 | 
            +
                num_attention_heads=4,
         | 
| 17 | 
            +
            )
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            xml_base_config = edict(
         | 
| 20 | 
            +
                merge_two_stream=True,  # merge only the scores
         | 
| 21 | 
            +
                cross_att=True,  # cross-attention for video and subtitles
         | 
| 22 | 
            +
                span_predictor_type="conv",
         | 
| 23 | 
            +
                encoder_type="transformer",  # cnn, transformer, lstm, gru
         | 
| 24 | 
            +
                add_pe_rnn=False,  # add positional encoding for RNNs, (LSTM and GRU)
         | 
| 25 | 
            +
                visual_input_size=2048,  # changes based on visual input type
         | 
| 26 | 
            +
                query_input_size=768,
         | 
| 27 | 
            +
                sub_input_size=768,
         | 
| 28 | 
            +
                hidden_size=500,  #
         | 
| 29 | 
            +
                conv_kernel_size=5,  # conv kernel_size for st_ed predictor
         | 
| 30 | 
            +
                stack_conv_predictor_conv_kernel_sizes=-1,  # Do not use
         | 
| 31 | 
            +
                conv_stride=1,  #
         | 
| 32 | 
            +
                max_ctx_l=100,
         | 
| 33 | 
            +
                max_desc_l=30,
         | 
| 34 | 
            +
                input_drop=0.1,  # dropout for input
         | 
| 35 | 
            +
                drop=0.1,  # dropout for other layers
         | 
| 36 | 
            +
                n_heads=4,  # self attention heads
         | 
| 37 | 
            +
                ctx_mode="video_sub",  # which context are used. 'video', 'sub' or 'video_sub'
         | 
| 38 | 
            +
                margin=0.1,  # margin for ranking loss
         | 
| 39 | 
            +
                ranking_loss_type="hinge",  # loss type, 'hinge' or 'lse'
         | 
| 40 | 
            +
                lw_neg_q=1,  # loss weight for neg. query and pos. context
         | 
| 41 | 
            +
                lw_neg_ctx=1,  # loss weight for pos. query and neg. context
         | 
| 42 | 
            +
                lw_st_ed=1,  # loss weight for st ed prediction
         | 
| 43 | 
            +
                use_hard_negative=False,  # use hard negative at video level, we may change it during training.
         | 
| 44 | 
            +
                hard_pool_size=20,
         | 
| 45 | 
            +
                use_self_attention=True,
         | 
| 46 | 
            +
                no_modular=False,
         | 
| 47 | 
            +
                pe_type="none",  # no positional encoding
         | 
| 48 | 
            +
                initializer_range=0.02,
         | 
| 49 | 
            +
            )
         | 
| 50 | 
            +
             | 
| 51 | 
            +
             | 
| 52 | 
            +
            class XML(nn.Module):
         | 
| 53 | 
            +
                def __init__(self, config):
         | 
| 54 | 
            +
                    super(XML, self).__init__()
         | 
| 55 | 
            +
                    self.config = config
         | 
| 56 | 
            +
                    # self.position_embeddings = PositionEncoding(n_filters=config.hidden_size,
         | 
| 57 | 
            +
                    #                                             max_len=config.max_position_embeddings,
         | 
| 58 | 
            +
                    #                                             pe_type=config.pe_type)
         | 
| 59 | 
            +
                    self.query_pos_embed = TrainablePositionalEncoding(
         | 
| 60 | 
            +
                        max_position_embeddings=config.max_desc_l,
         | 
| 61 | 
            +
                        hidden_size=config.hidden_size, dropout=config.input_drop)
         | 
| 62 | 
            +
                    self.ctx_pos_embed = TrainablePositionalEncoding(
         | 
| 63 | 
            +
                        max_position_embeddings=config.max_ctx_l,
         | 
| 64 | 
            +
                        hidden_size=config.hidden_size, dropout=config.input_drop)
         | 
| 65 | 
            +
                    self.query_input_proj = LinearLayer(config.query_input_size,
         | 
| 66 | 
            +
                                                        config.hidden_size,
         | 
| 67 | 
            +
                                                        layer_norm=True,
         | 
| 68 | 
            +
                                                        dropout=config.input_drop,
         | 
| 69 | 
            +
                                                        relu=True)
         | 
| 70 | 
            +
                    if config.encoder_type == "transformer":  # self-att encoder
         | 
| 71 | 
            +
                        self.query_encoder = BertAttention(edict(
         | 
| 72 | 
            +
                            hidden_size=config.hidden_size,
         | 
| 73 | 
            +
                            intermediate_size=config.hidden_size,
         | 
| 74 | 
            +
                            hidden_dropout_prob=config.drop,
         | 
| 75 | 
            +
                            attention_probs_dropout_prob=config.drop,
         | 
| 76 | 
            +
                            num_attention_heads=config.n_heads,
         | 
| 77 | 
            +
                        ))
         | 
| 78 | 
            +
                    elif config.encoder_type == "cnn":
         | 
| 79 | 
            +
                        self.query_encoder = ConvEncoder(
         | 
| 80 | 
            +
                            kernel_size=5,
         | 
| 81 | 
            +
                            n_filters=config.hidden_size,
         | 
| 82 | 
            +
                            dropout=config.drop
         | 
| 83 | 
            +
                        )
         | 
| 84 | 
            +
                    elif config.encoder_type in ["gru", "lstm"]:
         | 
| 85 | 
            +
                        self.query_encoder = RNNEncoder(
         | 
| 86 | 
            +
                            word_embedding_size=config.hidden_size,
         | 
| 87 | 
            +
                            hidden_size=config.hidden_size // 2,
         | 
| 88 | 
            +
                            bidirectional=True,
         | 
| 89 | 
            +
                            n_layers=1,
         | 
| 90 | 
            +
                            rnn_type=config.encoder_type,
         | 
| 91 | 
            +
                            return_outputs=True,
         | 
| 92 | 
            +
                            return_hidden=False
         | 
| 93 | 
            +
                        )
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                    conv_cfg = dict(in_channels=1,
         | 
| 96 | 
            +
                                    out_channels=1,
         | 
| 97 | 
            +
                                    kernel_size=config.conv_kernel_size,
         | 
| 98 | 
            +
                                    stride=config.conv_stride,
         | 
| 99 | 
            +
                                    padding=config.conv_kernel_size // 2,
         | 
| 100 | 
            +
                                    bias=False)
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    cross_att_cfg = edict(
         | 
| 103 | 
            +
                        hidden_size=config.hidden_size,
         | 
| 104 | 
            +
                        num_attention_heads=config.n_heads,
         | 
| 105 | 
            +
                        attention_probs_dropout_prob=config.drop
         | 
| 106 | 
            +
                    )
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                    self.use_video = "video" in config.ctx_mode
         | 
| 109 | 
            +
                    if self.use_video:
         | 
| 110 | 
            +
                        self.video_input_proj = LinearLayer(config.visual_input_size,
         | 
| 111 | 
            +
                                                            config.hidden_size,
         | 
| 112 | 
            +
                                                            layer_norm=True,
         | 
| 113 | 
            +
                                                            dropout=config.input_drop,
         | 
| 114 | 
            +
                                                            relu=True)
         | 
| 115 | 
            +
                        self.video_encoder1 = copy.deepcopy(self.query_encoder)
         | 
| 116 | 
            +
                        self.video_encoder2 = copy.deepcopy(self.query_encoder)
         | 
| 117 | 
            +
                        if self.config.cross_att:
         | 
| 118 | 
            +
                            self.video_cross_att = BertSelfAttention(cross_att_cfg)
         | 
| 119 | 
            +
                            self.video_cross_layernorm = nn.LayerNorm(config.hidden_size)
         | 
| 120 | 
            +
                        else:
         | 
| 121 | 
            +
                            if self.config.encoder_type == "transformer":
         | 
| 122 | 
            +
                                self.video_encoder3 = copy.deepcopy(self.query_encoder)
         | 
| 123 | 
            +
                        self.video_query_linear = nn.Linear(config.hidden_size, config.hidden_size)
         | 
| 124 | 
            +
                        if config.span_predictor_type == "conv":
         | 
| 125 | 
            +
                            if not config.merge_two_stream:
         | 
| 126 | 
            +
                                self.video_st_predictor = nn.Conv1d(**conv_cfg)
         | 
| 127 | 
            +
                                self.video_ed_predictor = nn.Conv1d(**conv_cfg)
         | 
| 128 | 
            +
                        elif config.span_predictor_type == "cat_linear":
         | 
| 129 | 
            +
                            self.video_st_predictor = nn.ModuleList([nn.Linear(config.hidden_size, 1) for _ in range(2)])
         | 
| 130 | 
            +
                            self.video_ed_predictor = nn.ModuleList([nn.Linear(config.hidden_size, 1) for _ in range(2)])
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                    self.use_sub = "sub" in config.ctx_mode
         | 
| 133 | 
            +
                    if self.use_sub:
         | 
| 134 | 
            +
                        self.sub_input_proj = LinearLayer(config.sub_input_size,
         | 
| 135 | 
            +
                                                          config.hidden_size,
         | 
| 136 | 
            +
                                                          layer_norm=True,
         | 
| 137 | 
            +
                                                          dropout=config.input_drop,
         | 
| 138 | 
            +
                                                          relu=True)
         | 
| 139 | 
            +
                        self.sub_encoder1 = copy.deepcopy(self.query_encoder)
         | 
| 140 | 
            +
                        self.sub_encoder2 = copy.deepcopy(self.query_encoder)
         | 
| 141 | 
            +
                        if self.config.cross_att:
         | 
| 142 | 
            +
                            self.sub_cross_att = BertSelfAttention(cross_att_cfg)
         | 
| 143 | 
            +
                            self.sub_cross_layernorm = nn.LayerNorm(config.hidden_size)
         | 
| 144 | 
            +
                        else:
         | 
| 145 | 
            +
                            if self.config.encoder_type == "transformer":
         | 
| 146 | 
            +
                                self.sub_encoder3 = copy.deepcopy(self.query_encoder)
         | 
| 147 | 
            +
                        self.sub_query_linear = nn.Linear(config.hidden_size, config.hidden_size)
         | 
| 148 | 
            +
                        if config.span_predictor_type == "conv":
         | 
| 149 | 
            +
                            if not config.merge_two_stream:
         | 
| 150 | 
            +
                                self.sub_st_predictor = nn.Conv1d(**conv_cfg)
         | 
| 151 | 
            +
                                self.sub_ed_predictor = nn.Conv1d(**conv_cfg)
         | 
| 152 | 
            +
                        elif config.span_predictor_type == "cat_linear":
         | 
| 153 | 
            +
                            self.sub_st_predictor = nn.ModuleList([nn.Linear(config.hidden_size, 1) for _ in range(2)])
         | 
| 154 | 
            +
                            self.sub_ed_predictor = nn.ModuleList([nn.Linear(config.hidden_size, 1) for _ in range(2)])
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                    self.modular_vector_mapping = nn.Linear(in_features=config.hidden_size,
         | 
| 157 | 
            +
                                                            out_features=self.use_sub + self.use_video,
         | 
| 158 | 
            +
                                                            bias=False)
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                    self.temporal_criterion = nn.CrossEntropyLoss(reduction="mean")
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                    if config.merge_two_stream and config.span_predictor_type == "conv":
         | 
| 163 | 
            +
                        if self.config.stack_conv_predictor_conv_kernel_sizes == -1:
         | 
| 164 | 
            +
                            self.merged_st_predictor = nn.Conv1d(**conv_cfg)
         | 
| 165 | 
            +
                            self.merged_ed_predictor = nn.Conv1d(**conv_cfg)
         | 
| 166 | 
            +
                        else:
         | 
| 167 | 
            +
                            print("Will be using  multiple Conv layers for prediction.")
         | 
| 168 | 
            +
                            self.merged_st_predictors = nn.ModuleList()
         | 
| 169 | 
            +
                            self.merged_ed_predictors = nn.ModuleList()
         | 
| 170 | 
            +
                            num_convs = len(self.config.stack_conv_predictor_conv_kernel_sizes)
         | 
| 171 | 
            +
                            for k in self.config.stack_conv_predictor_conv_kernel_sizes:
         | 
| 172 | 
            +
                                conv_cfg = dict(in_channels=1,
         | 
| 173 | 
            +
                                                out_channels=1,
         | 
| 174 | 
            +
                                                kernel_size=k,
         | 
| 175 | 
            +
                                                stride=config.conv_stride,
         | 
| 176 | 
            +
                                                padding=k // 2,
         | 
| 177 | 
            +
                                                bias=False)
         | 
| 178 | 
            +
                                self.merged_st_predictors.append(nn.Conv1d(**conv_cfg))
         | 
| 179 | 
            +
                                self.merged_ed_predictors.append(nn.Conv1d(**conv_cfg))
         | 
| 180 | 
            +
                            self.combine_st_conv = nn.Linear(num_convs, 1, bias=False)
         | 
| 181 | 
            +
                            self.combine_ed_conv = nn.Linear(num_convs, 1, bias=False)
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                    self.reset_parameters()
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                def reset_parameters(self):
         | 
| 186 | 
            +
                    """ Initialize the weights."""
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                    def re_init(module):
         | 
| 189 | 
            +
                        if isinstance(module, (nn.Linear, nn.Embedding)):
         | 
| 190 | 
            +
                            # Slightly different from the TF version which uses truncated_normal for initialization
         | 
| 191 | 
            +
                            # cf https://github.com/pytorch/pytorch/pull/5617
         | 
| 192 | 
            +
                            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
         | 
| 193 | 
            +
                        elif isinstance(module, nn.LayerNorm):
         | 
| 194 | 
            +
                            module.bias.data.zero_()
         | 
| 195 | 
            +
                            module.weight.data.fill_(1.0)
         | 
| 196 | 
            +
                        elif isinstance(module, nn.Conv1d):
         | 
| 197 | 
            +
                            module.reset_parameters()
         | 
| 198 | 
            +
                        if isinstance(module, nn.Linear) and module.bias is not None:
         | 
| 199 | 
            +
                            module.bias.data.zero_()
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                    self.apply(re_init)
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                def set_hard_negative(self, use_hard_negative, hard_pool_size):
         | 
| 204 | 
            +
                    """use_hard_negative: bool; hard_pool_size: int, """
         | 
| 205 | 
            +
                    self.config.use_hard_negative = use_hard_negative
         | 
| 206 | 
            +
                    self.config.hard_pool_size = hard_pool_size
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                def set_train_st_ed(self, lw_st_ed):
         | 
| 209 | 
            +
                    """pre-train video retrieval then span prediction"""
         | 
| 210 | 
            +
                    self.config.lw_st_ed = lw_st_ed
         | 
| 211 | 
            +
             | 
| 212 | 
            +
                def forward(self, query_feat, query_mask, video_feat, video_mask, sub_feat, sub_mask,
         | 
| 213 | 
            +
                            tef_feat, tef_mask, st_ed_indices):
         | 
| 214 | 
            +
                    """
         | 
| 215 | 
            +
                    Args:
         | 
| 216 | 
            +
                        query_feat: (N, Lq, Dq)
         | 
| 217 | 
            +
                        query_mask: (N, Lq)
         | 
| 218 | 
            +
                        video_feat: (N, Lv, Dv) or None
         | 
| 219 | 
            +
                        video_mask: (N, Lv) or None
         | 
| 220 | 
            +
                        sub_feat: (N, Lv, Ds) or None
         | 
| 221 | 
            +
                        sub_mask: (N, Lv) or None
         | 
| 222 | 
            +
                        tef_feat: (N, Lv, 2) or None,
         | 
| 223 | 
            +
                        tef_mask: (N, Lv) or None,
         | 
| 224 | 
            +
                        st_ed_indices: (N, 2), torch.LongTensor, 1st, 2nd columns are st, ed labels respectively.
         | 
| 225 | 
            +
                    """
         | 
| 226 | 
            +
                    video_feat1, video_feat2, sub_feat1, sub_feat2 = \
         | 
| 227 | 
            +
                        self.encode_context(video_feat, video_mask, sub_feat, sub_mask)
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                    query_context_scores, st_prob, ed_prob = \
         | 
| 230 | 
            +
                        self.get_pred_from_raw_query(query_feat, query_mask,
         | 
| 231 | 
            +
                                                     video_feat1, video_feat2, video_mask,
         | 
| 232 | 
            +
                                                     sub_feat1, sub_feat2, sub_mask, cross=False)
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                    loss_st_ed = 0
         | 
| 235 | 
            +
                    if self.config.lw_st_ed != 0:
         | 
| 236 | 
            +
                        loss_st = self.temporal_criterion(st_prob, st_ed_indices[:, 0])
         | 
| 237 | 
            +
                        loss_ed = self.temporal_criterion(ed_prob, st_ed_indices[:, 1])
         | 
| 238 | 
            +
                        loss_st_ed = loss_st + loss_ed
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                    loss_neg_ctx, loss_neg_q = 0, 0
         | 
| 241 | 
            +
                    if self.config.lw_neg_ctx != 0 or self.config.lw_neg_q != 0:
         | 
| 242 | 
            +
                        loss_neg_ctx, loss_neg_q = self.get_video_level_loss(query_context_scores)
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                    loss_st_ed = self.config.lw_st_ed * loss_st_ed
         | 
| 245 | 
            +
                    loss_neg_ctx = self.config.lw_neg_ctx * loss_neg_ctx
         | 
| 246 | 
            +
                    loss_neg_q = self.config.lw_neg_q * loss_neg_q
         | 
| 247 | 
            +
                    loss = loss_st_ed + loss_neg_ctx + loss_neg_q
         | 
| 248 | 
            +
                    return loss, {"loss_st_ed": float(loss_st_ed),
         | 
| 249 | 
            +
                                  "loss_neg_ctx": float(loss_neg_ctx),
         | 
| 250 | 
            +
                                  "loss_neg_q": float(loss_neg_q),
         | 
| 251 | 
            +
                                  "loss_overall": float(loss)}
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                def get_visualization_data(self, query_feat, query_mask, video_feat, video_mask, sub_feat, sub_mask,
         | 
| 254 | 
            +
                                           tef_feat, tef_mask, st_ed_indices):
         | 
| 255 | 
            +
                    assert self.config.merge_two_stream and self.use_video and self.use_sub and not self.config.no_modular
         | 
| 256 | 
            +
                    video_feat1, video_feat2, sub_feat1, sub_feat2 = \
         | 
| 257 | 
            +
                        self.encode_context(video_feat, video_mask, sub_feat, sub_mask)
         | 
| 258 | 
            +
                    encoded_query = self.encode_input(query_feat, query_mask,
         | 
| 259 | 
            +
                                                      self.query_input_proj, self.query_encoder, self.query_pos_embed)  # (N, Lq, D)
         | 
| 260 | 
            +
                    # (N, D), (N, D), (N, L, 2)
         | 
| 261 | 
            +
                    video_query, sub_query, modular_att_scores = \
         | 
| 262 | 
            +
                        self.get_modularized_queries(encoded_query, query_mask, return_modular_att=True)
         | 
| 263 | 
            +
                    # (N, L), (N, L), (N, L)
         | 
| 264 | 
            +
                    st_prob, ed_prob, similarity_scores, video_similarity, sub_similarity = self.get_merged_st_ed_prob(
         | 
| 265 | 
            +
                        video_query, video_feat2, sub_query, sub_feat2, video_mask, cross=False, return_similaity=True)
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                    # clean up invalid bits
         | 
| 268 | 
            +
                    data = dict(modular_att_scores=modular_att_scores.cpu().numpy(),  # (N, Lq, 2), row 0, 1 are video, sub.
         | 
| 269 | 
            +
                                st_prob=st_prob.cpu().numpy(),  # (N, L)
         | 
| 270 | 
            +
                                ed_prob=ed_prob.cpu().numpy(),  # (N, L)
         | 
| 271 | 
            +
                                similarity_scores=similarity_scores.cpu().numpy(),  # (N, L)
         | 
| 272 | 
            +
                                video_similarity=video_similarity.cpu().numpy(),  # (N, L)
         | 
| 273 | 
            +
                                sub_similarity=sub_similarity.cpu().numpy(),  # (N, L)
         | 
| 274 | 
            +
                                st_ed_indices=st_ed_indices.cpu().numpy())  # (N, L)
         | 
| 275 | 
            +
                    query_lengths = query_mask.sum(1).to(torch.long).cpu().tolist()  # (N, )
         | 
| 276 | 
            +
                    ctx_lengths = video_mask.sum(1).to(torch.long).cpu().tolist()  # (N, )
         | 
| 277 | 
            +
                    # print("query_lengths {}".format((type(query_lengths), len(query_lengths), query_lengths[:10])))
         | 
| 278 | 
            +
                    for k, v in data.items():
         | 
| 279 | 
            +
                        if k == "modular_att_scores":
         | 
| 280 | 
            +
                            # print(k, v, v.shape, type(v))
         | 
| 281 | 
            +
                            data[k] = [e[:l] for l, e in zip(query_lengths, v)]  # list(e) where e is  (Lq_i, 2)
         | 
| 282 | 
            +
                        else:
         | 
| 283 | 
            +
                            data[k] = [e[:l] for l, e in zip(ctx_lengths, v)]   # list(e) where e is (Lc_i)
         | 
| 284 | 
            +
             | 
| 285 | 
            +
                    # aggregate info for each example
         | 
| 286 | 
            +
                    datalist = []
         | 
| 287 | 
            +
                    for idx in range(len(data["modular_att_scores"])):
         | 
| 288 | 
            +
                        datalist.append({k: v[idx] for k, v in data.items()})
         | 
| 289 | 
            +
                    return datalist  # list(dicts) of length N
         | 
| 290 | 
            +
             | 
| 291 | 
            +
                def encode_query(self, query_feat, query_mask):
         | 
| 292 | 
            +
                    encoded_query = self.encode_input(query_feat, query_mask,
         | 
| 293 | 
            +
                                                      self.query_input_proj, self.query_encoder, self.query_pos_embed)  # (N, Lq, D)
         | 
| 294 | 
            +
                    video_query, sub_query = self.get_modularized_queries(encoded_query, query_mask)  # (N, D) * 2
         | 
| 295 | 
            +
                    return video_query, sub_query
         | 
| 296 | 
            +
             | 
| 297 | 
            +
                def non_cross_encode_context(self, context_feat, context_mask, module_name="video"):
         | 
| 298 | 
            +
                    encoder_layer3 = getattr(self, module_name + "_encoder3") \
         | 
| 299 | 
            +
                        if self.config.encoder_type == "transformer" else None
         | 
| 300 | 
            +
                    return self._non_cross_encode_context(context_feat, context_mask,
         | 
| 301 | 
            +
                                                          input_proj_layer=getattr(self, module_name + "_input_proj"),
         | 
| 302 | 
            +
                                                          encoder_layer1=getattr(self, module_name + "_encoder1"),
         | 
| 303 | 
            +
                                                          encoder_layer2=getattr(self, module_name + "_encoder2"),
         | 
| 304 | 
            +
                                                          encoder_layer3=encoder_layer3)
         | 
| 305 | 
            +
             | 
| 306 | 
            +
                def _non_cross_encode_context(self, context_feat, context_mask, input_proj_layer,
         | 
| 307 | 
            +
                                              encoder_layer1, encoder_layer2, encoder_layer3=None):
         | 
| 308 | 
            +
                    """
         | 
| 309 | 
            +
                    Args:
         | 
| 310 | 
            +
                        context_feat: (N, L, D)
         | 
| 311 | 
            +
                        context_mask: (N, L)
         | 
| 312 | 
            +
                        input_proj_layer:
         | 
| 313 | 
            +
                        encoder_layer1:
         | 
| 314 | 
            +
                        encoder_layer2:
         | 
| 315 | 
            +
                        encoder_layer3
         | 
| 316 | 
            +
                    """
         | 
| 317 | 
            +
                    context_feat1 = self.encode_input(
         | 
| 318 | 
            +
                        context_feat, context_mask, input_proj_layer, encoder_layer1, self.ctx_pos_embed)  # (N, L, D)
         | 
| 319 | 
            +
                    if self.config.encoder_type in ["transformer", "cnn"]:
         | 
| 320 | 
            +
                        context_mask = context_mask.unsqueeze(1)  # (N, 1, L), torch.FloatTensor
         | 
| 321 | 
            +
                        context_feat2 = encoder_layer2(context_feat1, context_mask)  # (N, L, D)
         | 
| 322 | 
            +
                        if self.config.encoder_type == "transformer":
         | 
| 323 | 
            +
                            context_feat2 = encoder_layer3(context_feat2, context_mask)
         | 
| 324 | 
            +
                    elif self.config.encoder_type in ["gru", "lstm"]:
         | 
| 325 | 
            +
                        context_mask = context_mask.sum(1).long()  # (N, ), torch.LongTensor
         | 
| 326 | 
            +
                        context_feat2 = encoder_layer2(context_feat1, context_mask)[0]  # (N, L, D)
         | 
| 327 | 
            +
                    else:
         | 
| 328 | 
            +
                        raise NotImplementedError
         | 
| 329 | 
            +
                    return context_feat1, context_feat2
         | 
| 330 | 
            +
             | 
| 331 | 
            +
                def encode_context(self, video_feat, video_mask, sub_feat, sub_mask):
         | 
| 332 | 
            +
                    if self.config.cross_att:
         | 
| 333 | 
            +
                        assert self.use_video and self.use_sub
         | 
| 334 | 
            +
             | 
| 335 | 
            +
                        return self.cross_encode_context(video_feat, video_mask, sub_feat, sub_mask)
         | 
| 336 | 
            +
                    else:
         | 
| 337 | 
            +
                        video_feat1, video_feat2 = (None,) * 2
         | 
| 338 | 
            +
                        if self.use_video:
         | 
| 339 | 
            +
                            video_feat1, video_feat2 = self.non_cross_encode_context(video_feat, video_mask, module_name="video")
         | 
| 340 | 
            +
                        sub_feat1, sub_feat2 = (None,) * 2
         | 
| 341 | 
            +
                        if self.use_sub:
         | 
| 342 | 
            +
                            sub_feat1, sub_feat2 = self.non_cross_encode_context(sub_feat, sub_mask, module_name="sub")
         | 
| 343 | 
            +
                        return video_feat1, video_feat2, sub_feat1, sub_feat2
         | 
| 344 | 
            +
             | 
| 345 | 
            +
                def cross_encode_context(self, video_feat, video_mask, sub_feat, sub_mask):
         | 
| 346 | 
            +
                    encoded_video_feat = self.encode_input(video_feat, video_mask,
         | 
| 347 | 
            +
                                                           self.video_input_proj, self.video_encoder1, self.ctx_pos_embed)
         | 
| 348 | 
            +
                    encoded_sub_feat = self.encode_input(sub_feat, sub_mask,
         | 
| 349 | 
            +
                                                         self.sub_input_proj, self.sub_encoder1, self.ctx_pos_embed)
         | 
| 350 | 
            +
                    x_encoded_video_feat = self.cross_context_encoder(
         | 
| 351 | 
            +
                        encoded_video_feat, video_mask, encoded_sub_feat, sub_mask,
         | 
| 352 | 
            +
                        self.video_cross_att, self.video_cross_layernorm, self.video_encoder2)  # (N, L, D)
         | 
| 353 | 
            +
                    x_encoded_sub_feat = self.cross_context_encoder(
         | 
| 354 | 
            +
                        encoded_sub_feat, sub_mask, encoded_video_feat, video_mask,
         | 
| 355 | 
            +
                        self.sub_cross_att, self.sub_cross_layernorm, self.sub_encoder2)  # (N, L, D)
         | 
| 356 | 
            +
                    return encoded_video_feat, x_encoded_video_feat, encoded_sub_feat, x_encoded_sub_feat
         | 
| 357 | 
            +
             | 
| 358 | 
            +
                def cross_context_encoder(self, main_context_feat, main_context_mask, side_context_feat, side_context_mask,
         | 
| 359 | 
            +
                                          cross_att_layer, norm_layer, self_att_layer):
         | 
| 360 | 
            +
                    """
         | 
| 361 | 
            +
                    Args:
         | 
| 362 | 
            +
                        main_context_feat: (N, Lq, D)
         | 
| 363 | 
            +
                        main_context_mask: (N, Lq)
         | 
| 364 | 
            +
                        side_context_feat: (N, Lk, D)
         | 
| 365 | 
            +
                        side_context_mask: (N, Lk)
         | 
| 366 | 
            +
                        cross_att_layer:
         | 
| 367 | 
            +
                        norm_layer:
         | 
| 368 | 
            +
                        self_att_layer:
         | 
| 369 | 
            +
                    """
         | 
| 370 | 
            +
                    cross_mask = torch.einsum("bm,bn->bmn", main_context_mask, side_context_mask)  # (N, Lq, Lk)
         | 
| 371 | 
            +
                    cross_out = cross_att_layer(main_context_feat, side_context_feat, side_context_feat, cross_mask)  # (N, Lq, D)
         | 
| 372 | 
            +
                    residual_out = norm_layer(cross_out + main_context_feat)
         | 
| 373 | 
            +
                    if self.config.encoder_type in ["cnn", "transformer"]:
         | 
| 374 | 
            +
                        return self_att_layer(residual_out, main_context_mask.unsqueeze(1))
         | 
| 375 | 
            +
                    elif self.config.encoder_type in ["gru", "lstm"]:
         | 
| 376 | 
            +
                        return self_att_layer(residual_out, main_context_mask.sum(1).long())[0]
         | 
| 377 | 
            +
             | 
| 378 | 
            +
                def encode_input(self, feat, mask, input_proj_layer, encoder_layer, pos_embed_layer):
         | 
| 379 | 
            +
                    """
         | 
| 380 | 
            +
                    Args:
         | 
| 381 | 
            +
                        feat: (N, L, D_input), torch.float32
         | 
| 382 | 
            +
                        mask: (N, L), torch.float32, with 1 indicates valid query, 0 indicates mask
         | 
| 383 | 
            +
                        input_proj_layer: down project input
         | 
| 384 | 
            +
                        encoder_layer: encoder layer
         | 
| 385 | 
            +
                        # add_pe: bool, whether to add positional encoding
         | 
| 386 | 
            +
                        pos_embed_layer
         | 
| 387 | 
            +
                    """
         | 
| 388 | 
            +
                    feat = input_proj_layer(feat)
         | 
| 389 | 
            +
             | 
| 390 | 
            +
                    if self.config.encoder_type in ["cnn", "transformer"]:
         | 
| 391 | 
            +
                        feat = pos_embed_layer(feat)
         | 
| 392 | 
            +
                        mask = mask.unsqueeze(1)  # (N, 1, L), torch.FloatTensor
         | 
| 393 | 
            +
                        return encoder_layer(feat, mask)  # (N, L, D_hidden)
         | 
| 394 | 
            +
                    elif self.config.encoder_type in ["gru", "lstm"]:
         | 
| 395 | 
            +
                        if self.config.add_pe_rnn:
         | 
| 396 | 
            +
                            feat = pos_embed_layer(feat)
         | 
| 397 | 
            +
                        mask = mask.sum(1).long()  # (N, ), torch.LongTensor
         | 
| 398 | 
            +
                        return encoder_layer(feat, mask)[0]  # (N, L, D_hidden)
         | 
| 399 | 
            +
             | 
| 400 | 
            +
                def get_modularized_queries(self, encoded_query, query_mask, return_modular_att=False):
         | 
| 401 | 
            +
                    """
         | 
| 402 | 
            +
                    Args:
         | 
| 403 | 
            +
                        encoded_query: (N, L, D)
         | 
| 404 | 
            +
                        query_mask: (N, L)
         | 
| 405 | 
            +
                        return_modular_att: bool
         | 
| 406 | 
            +
                    """
         | 
| 407 | 
            +
                    if self.config.no_modular:
         | 
| 408 | 
            +
                        modular_query = torch.max(mask_logits(encoded_query, query_mask.unsqueeze(2)), dim=1)[0]  # (N, D)
         | 
| 409 | 
            +
                        return modular_query, modular_query  #
         | 
| 410 | 
            +
                    else:
         | 
| 411 | 
            +
                        modular_attention_scores = self.modular_vector_mapping(encoded_query)  # (N, L, 2 or 1)
         | 
| 412 | 
            +
                        modular_attention_scores = F.softmax(
         | 
| 413 | 
            +
                            mask_logits(modular_attention_scores, query_mask.unsqueeze(2)), dim=1)
         | 
| 414 | 
            +
                        # TODO check whether it is the same
         | 
| 415 | 
            +
                        modular_queries = torch.einsum("blm,bld->bmd",
         | 
| 416 | 
            +
                                                       modular_attention_scores, encoded_query)  # (N, 2 or 1, D)
         | 
| 417 | 
            +
                        if return_modular_att:
         | 
| 418 | 
            +
                            assert modular_queries.shape[1] == 2
         | 
| 419 | 
            +
                            return modular_queries[:, 0], modular_queries[:, 1], modular_attention_scores
         | 
| 420 | 
            +
                        else:
         | 
| 421 | 
            +
                            if modular_queries.shape[1] == 2:
         | 
| 422 | 
            +
                                return modular_queries[:, 0], modular_queries[:, 1]  # (N, D) * 2
         | 
| 423 | 
            +
                            else:  # 1
         | 
| 424 | 
            +
                                return modular_queries[:, 0], modular_queries[:, 0]  # the same
         | 
| 425 | 
            +
             | 
| 426 | 
            +
                def get_modular_weights(self, encoded_query, query_mask):
         | 
| 427 | 
            +
                    """
         | 
| 428 | 
            +
                    Args:
         | 
| 429 | 
            +
                        encoded_query: (N, L, D)
         | 
| 430 | 
            +
                        query_mask: (N, L)
         | 
| 431 | 
            +
                    """
         | 
| 432 | 
            +
                    max_encoded_query, _ = torch.max(mask_logits(encoded_query, query_mask.unsqueeze(2)), dim=1)  # (N, D)
         | 
| 433 | 
            +
                    modular_weights = self.modular_weights_calculator(max_encoded_query)  # (N, 2)
         | 
| 434 | 
            +
                    modular_weights = F.softmax(modular_weights, dim=-1)
         | 
| 435 | 
            +
                    return modular_weights[:, 0:1], modular_weights[:, 1:2]  # (N, 1) * 2
         | 
| 436 | 
            +
             | 
| 437 | 
            +
                def get_video_level_scores(self, modularied_query, context_feat1, context_mask):
         | 
| 438 | 
            +
                    """ Calculate video2query scores for each pair of video and query inside the batch.
         | 
| 439 | 
            +
                    Args:
         | 
| 440 | 
            +
                        modularied_query: (N, D)
         | 
| 441 | 
            +
                        context_feat1: (N, L, D), output of the first transformer encoder layer
         | 
| 442 | 
            +
                        context_mask: (N, L)
         | 
| 443 | 
            +
                    Returns:
         | 
| 444 | 
            +
                        context_query_scores: (N, N)  score of each query w.r.t. each video inside the batch,
         | 
| 445 | 
            +
                            diagonal positions are positive. used to get negative samples.
         | 
| 446 | 
            +
                    """
         | 
| 447 | 
            +
                    modularied_query = F.normalize(modularied_query, dim=-1)
         | 
| 448 | 
            +
                    context_feat1 = F.normalize(context_feat1, dim=-1)
         | 
| 449 | 
            +
                    query_context_scores = torch.einsum("md,nld->mln", modularied_query, context_feat1)  # (N, L, N)
         | 
| 450 | 
            +
                    context_mask = context_mask.transpose(0, 1).unsqueeze(0)  # (1, L, N)
         | 
| 451 | 
            +
                    query_context_scores = mask_logits(query_context_scores, context_mask)  # (N, L, N)
         | 
| 452 | 
            +
                    query_context_scores, _ = torch.max(query_context_scores,
         | 
| 453 | 
            +
                                                        dim=1)  # (N, N) diagonal positions are positive pairs.
         | 
| 454 | 
            +
                    return query_context_scores
         | 
| 455 | 
            +
             | 
| 456 | 
            +
                def get_merged_st_ed_prob(self, video_query, video_feat, sub_query, sub_feat, context_mask,
         | 
| 457 | 
            +
                                          cross=False, return_similaity=False):
         | 
| 458 | 
            +
                    """context_mask could be either video_mask or sub_mask, since they are the same"""
         | 
| 459 | 
            +
                    assert self.use_video and self.use_sub and self.config.span_predictor_type == "conv"
         | 
| 460 | 
            +
                    video_query = self.video_query_linear(video_query)
         | 
| 461 | 
            +
                    sub_query = self.sub_query_linear(sub_query)
         | 
| 462 | 
            +
                    stack_conv = self.config.stack_conv_predictor_conv_kernel_sizes != -1
         | 
| 463 | 
            +
                    num_convs = len(self.config.stack_conv_predictor_conv_kernel_sizes) if stack_conv else None
         | 
| 464 | 
            +
                    if cross:
         | 
| 465 | 
            +
                        video_similarity = torch.einsum("md,nld->mnl", video_query, video_feat)
         | 
| 466 | 
            +
                        sub_similarity = torch.einsum("md,nld->mnl", sub_query, sub_feat)
         | 
| 467 | 
            +
                        similarity = (video_similarity + sub_similarity) / 2  # (Nq, Nv, L)  from query to all videos.
         | 
| 468 | 
            +
                        n_q, n_c, l = similarity.shape
         | 
| 469 | 
            +
                        similarity = similarity.view(n_q * n_c, 1, l)
         | 
| 470 | 
            +
                        if not stack_conv:
         | 
| 471 | 
            +
                            st_prob = self.merged_st_predictor(similarity).view(n_q, n_c, l)  # (Nq, Nv, L)
         | 
| 472 | 
            +
                            ed_prob = self.merged_ed_predictor(similarity).view(n_q, n_c, l)  # (Nq, Nv, L)
         | 
| 473 | 
            +
                        else:
         | 
| 474 | 
            +
                            st_prob_list = []
         | 
| 475 | 
            +
                            ed_prob_list = []
         | 
| 476 | 
            +
                            for idx in range(num_convs):
         | 
| 477 | 
            +
                                st_prob_list.append(self.merged_st_predictors[idx](similarity).squeeze().unsqueeze(2))
         | 
| 478 | 
            +
                                ed_prob_list.append(self.merged_ed_predictors[idx](similarity).squeeze().unsqueeze(2))
         | 
| 479 | 
            +
                            # (Nq*Nv, L, 3) --> (Nq*Nv, L) -> (Nq, Nv, L)
         | 
| 480 | 
            +
                            st_prob = self.combine_st_conv(torch.cat(st_prob_list, dim=2)).view(n_q, n_c, l)
         | 
| 481 | 
            +
                            ed_prob = self.combine_ed_conv(torch.cat(ed_prob_list, dim=2)).view(n_q, n_c, l)
         | 
| 482 | 
            +
                    else:
         | 
| 483 | 
            +
                        video_similarity = torch.einsum("bd,bld->bl", video_query, video_feat)  # (N, L)
         | 
| 484 | 
            +
                        sub_similarity = torch.einsum("bd,bld->bl", sub_query, sub_feat)  # (N, L)
         | 
| 485 | 
            +
                        similarity = (video_similarity + sub_similarity) / 2
         | 
| 486 | 
            +
                        if not stack_conv:
         | 
| 487 | 
            +
                            st_prob = self.merged_st_predictor(similarity.unsqueeze(1)).squeeze()  # (N, L)
         | 
| 488 | 
            +
                            ed_prob = self.merged_ed_predictor(similarity.unsqueeze(1)).squeeze()  # (N, L)
         | 
| 489 | 
            +
                        else:
         | 
| 490 | 
            +
                            st_prob_list = []
         | 
| 491 | 
            +
                            ed_prob_list = []
         | 
| 492 | 
            +
                            for idx in range(num_convs):
         | 
| 493 | 
            +
                                st_prob_list.append(self.merged_st_predictors[idx](similarity.unsqueeze(1)).squeeze().unsqueeze(2))
         | 
| 494 | 
            +
                                ed_prob_list.append(self.merged_ed_predictors[idx](similarity.unsqueeze(1)).squeeze().unsqueeze(2))
         | 
| 495 | 
            +
                            st_prob = self.combine_st_conv(torch.cat(st_prob_list, dim=2)).squeeze()  # (N, L, 3) --> (N, L)
         | 
| 496 | 
            +
                            ed_prob = self.combine_ed_conv(torch.cat(ed_prob_list, dim=2)).squeeze()  # (N, L, 3) --> (N, L)
         | 
| 497 | 
            +
                    st_prob = mask_logits(st_prob, context_mask)  # (N, L)
         | 
| 498 | 
            +
                    ed_prob = mask_logits(ed_prob, context_mask)
         | 
| 499 | 
            +
                    if return_similaity:
         | 
| 500 | 
            +
                        assert not cross
         | 
| 501 | 
            +
                        return st_prob, ed_prob, similarity, video_similarity, sub_similarity
         | 
| 502 | 
            +
                    else:
         | 
| 503 | 
            +
                        return st_prob, ed_prob
         | 
| 504 | 
            +
             | 
| 505 | 
            +
                def get_st_ed_prob(self, modularied_query, context_feat2, context_mask,
         | 
| 506 | 
            +
                                   module_name="video", cross=False):
         | 
| 507 | 
            +
                    return self._get_st_ed_prob(modularied_query, context_feat2, context_mask,
         | 
| 508 | 
            +
                                                module_query_linear=getattr(self, module_name + "_query_linear"),
         | 
| 509 | 
            +
                                                st_predictor=getattr(self, module_name + "_st_predictor"),
         | 
| 510 | 
            +
                                                ed_predictor=getattr(self, module_name + "_ed_predictor"),
         | 
| 511 | 
            +
                                                cross=cross)
         | 
| 512 | 
            +
             | 
| 513 | 
            +
                def _get_st_ed_prob(self, modularied_query, context_feat2, context_mask,
         | 
| 514 | 
            +
                                    module_query_linear, st_predictor, ed_predictor, cross=False):
         | 
| 515 | 
            +
                    """
         | 
| 516 | 
            +
                    Args:
         | 
| 517 | 
            +
                        modularied_query: (N, D)
         | 
| 518 | 
            +
                        context_feat2: (N, L, D), output of the first transformer encoder layer
         | 
| 519 | 
            +
                        context_mask: (N, L)
         | 
| 520 | 
            +
                        module_query_linear:
         | 
| 521 | 
            +
                        st_predictor:
         | 
| 522 | 
            +
                        ed_predictor:
         | 
| 523 | 
            +
                        cross: at inference, calculate prob for each possible pairs of query and context.
         | 
| 524 | 
            +
                    """
         | 
| 525 | 
            +
                    query = module_query_linear(modularied_query)  # (N, D) no need to normalize here.
         | 
| 526 | 
            +
                    if cross:
         | 
| 527 | 
            +
                        if self.config.span_predictor_type == "conv":
         | 
| 528 | 
            +
                            similarity = torch.einsum("md,nld->mnl", query, context_feat2)  # (Nq, Nv, L)  from query to all videos.
         | 
| 529 | 
            +
                            n_q, n_c, l = similarity.shape
         | 
| 530 | 
            +
                            similarity = similarity.view(n_q * n_c, 1, l)
         | 
| 531 | 
            +
                            st_prob = st_predictor(similarity).view(n_q, n_c, l)  # (Nq, Nv, L)
         | 
| 532 | 
            +
                            ed_prob = ed_predictor(similarity).view(n_q, n_c, l)  # (Nq, Nv, L)
         | 
| 533 | 
            +
                        elif self.config.span_predictor_type == "cat_linear":
         | 
| 534 | 
            +
                            st_prob_q = st_predictor[0](query).unsqueeze(1)  # (Nq, 1, 1)
         | 
| 535 | 
            +
                            st_prob_ctx = st_predictor[1](context_feat2).squeeze().unsqueeze(0)  # (1, Nv, L)
         | 
| 536 | 
            +
                            st_prob = st_prob_q + st_prob_ctx  # (Nq, Nv, L)
         | 
| 537 | 
            +
                            ed_prob_q = ed_predictor[0](query).unsqueeze(1)  # (Nq, 1, 1)
         | 
| 538 | 
            +
                            ed_prob_ctx = ed_predictor[1](context_feat2).squeeze().unsqueeze(0)  # (1, Nv, L)
         | 
| 539 | 
            +
                            ed_prob = ed_prob_q + ed_prob_ctx  # (Nq, Nv, L)
         | 
| 540 | 
            +
                        context_mask = context_mask.unsqueeze(0)  # (1, Nv, L)
         | 
| 541 | 
            +
                    else:
         | 
| 542 | 
            +
                        if self.config.span_predictor_type == "conv":
         | 
| 543 | 
            +
                            similarity = torch.einsum("bd,bld->bl", query, context_feat2)  # (N, L)
         | 
| 544 | 
            +
                            st_prob = st_predictor(similarity.unsqueeze(1)).squeeze()  # (N, L)
         | 
| 545 | 
            +
                            ed_prob = ed_predictor(similarity.unsqueeze(1)).squeeze()  # (N, L)
         | 
| 546 | 
            +
                        elif self.config.span_predictor_type == "cat_linear":
         | 
| 547 | 
            +
                            # avoid concatenation by break into smaller matrix multiplications.
         | 
| 548 | 
            +
                            st_prob = st_predictor[0](query) + st_predictor[1](context_feat2).squeeze()  # (N, L)
         | 
| 549 | 
            +
                            ed_prob = ed_predictor[0](query) + ed_predictor[1](context_feat2).squeeze()  # (N, L)
         | 
| 550 | 
            +
                    st_prob = mask_logits(st_prob, context_mask)  # (N, L)
         | 
| 551 | 
            +
                    ed_prob = mask_logits(ed_prob, context_mask)
         | 
| 552 | 
            +
                    return st_prob, ed_prob
         | 
| 553 | 
            +
             | 
| 554 | 
            +
                def get_pred_from_raw_query(self, query_feat, query_mask,
         | 
| 555 | 
            +
                                            video_feat1, video_feat2, video_mask,
         | 
| 556 | 
            +
                                            sub_feat1, sub_feat2, sub_mask, cross=False):
         | 
| 557 | 
            +
                    """
         | 
| 558 | 
            +
                    Args:
         | 
| 559 | 
            +
                        query_feat: (N, Lq, Dq)
         | 
| 560 | 
            +
                        query_mask: (N, Lq)
         | 
| 561 | 
            +
                        video_feat1: (N, Lv, D) or None
         | 
| 562 | 
            +
                        video_feat2:
         | 
| 563 | 
            +
                        video_mask: (N, Lv)
         | 
| 564 | 
            +
                        sub_feat1: (N, Lv, D) or None
         | 
| 565 | 
            +
                        sub_feat2:
         | 
| 566 | 
            +
                        sub_mask: (N, Lv)
         | 
| 567 | 
            +
                        cross:
         | 
| 568 | 
            +
                    """
         | 
| 569 | 
            +
                    video_query, sub_query = self.encode_query(query_feat, query_mask)
         | 
| 570 | 
            +
                    divisor = self.use_sub + self.use_video
         | 
| 571 | 
            +
             | 
| 572 | 
            +
                    # get video-level retrieval scores
         | 
| 573 | 
            +
                    video_q2ctx_scores = self.get_video_level_scores(video_query, video_feat1, video_mask) if self.use_video else 0
         | 
| 574 | 
            +
                    sub_q2ctx_scores = self.get_video_level_scores(sub_query, sub_feat1, sub_mask) if self.use_sub else 0
         | 
| 575 | 
            +
                    q2ctx_scores = (video_q2ctx_scores + sub_q2ctx_scores) / divisor  # (N, N)
         | 
| 576 | 
            +
             | 
| 577 | 
            +
                    if self.config.merge_two_stream and self.use_video and self.use_sub:
         | 
| 578 | 
            +
                        st_prob, ed_prob = self.get_merged_st_ed_prob(
         | 
| 579 | 
            +
                            video_query, video_feat2, sub_query, sub_feat2, video_mask, cross=cross)
         | 
| 580 | 
            +
                    else:
         | 
| 581 | 
            +
                        video_st_prob, video_ed_prob = self.get_st_ed_prob(
         | 
| 582 | 
            +
                            video_query, video_feat2, video_mask, module_name="video", cross=cross) if self.use_video else (0, 0)
         | 
| 583 | 
            +
                        sub_st_prob, sub_ed_prob = self.get_st_ed_prob(
         | 
| 584 | 
            +
                            sub_query, sub_feat2, sub_mask, module_name="sub", cross=cross) if self.use_sub else (0, 0)
         | 
| 585 | 
            +
                        st_prob = (video_st_prob + sub_st_prob) / divisor  # (N, Lv)
         | 
| 586 | 
            +
                        ed_prob = (video_ed_prob + sub_ed_prob) / divisor  # (N, Lv)
         | 
| 587 | 
            +
                    return q2ctx_scores, st_prob, ed_prob  # un-normalized masked probabilities!!!!!
         | 
| 588 | 
            +
             | 
| 589 | 
            +
                def get_video_level_loss(self, query_context_scores):
         | 
| 590 | 
            +
                    """ ranking loss between (pos. query + pos. video) and (pos. query + neg. video) or (neg. query + pos. video)
         | 
| 591 | 
            +
                    Args:
         | 
| 592 | 
            +
                        query_context_scores: (N, N), cosine similarity [-1, 1],
         | 
| 593 | 
            +
                            Each row contains the scores between the query to each of the videos inside the batch.
         | 
| 594 | 
            +
                    """
         | 
| 595 | 
            +
                    bsz = len(query_context_scores)
         | 
| 596 | 
            +
                    diagonal_indices = torch.arange(bsz).to(query_context_scores.device)
         | 
| 597 | 
            +
                    pos_scores = query_context_scores[diagonal_indices, diagonal_indices]  # (N, )
         | 
| 598 | 
            +
                    query_context_scores_masked = copy.deepcopy(query_context_scores.data)
         | 
| 599 | 
            +
                    # impossibly large for cosine similarity, the copy is created as modifying the original will cause error
         | 
| 600 | 
            +
                    query_context_scores_masked[diagonal_indices, diagonal_indices] = 999
         | 
| 601 | 
            +
                    pos_query_neg_context_scores = self.get_neg_scores(query_context_scores,
         | 
| 602 | 
            +
                                                                       query_context_scores_masked)
         | 
| 603 | 
            +
                    neg_query_pos_context_scores = self.get_neg_scores(query_context_scores.transpose(0, 1),
         | 
| 604 | 
            +
                                                                       query_context_scores_masked.transpose(0, 1))
         | 
| 605 | 
            +
                    loss_neg_ctx = self.get_ranking_loss(pos_scores, pos_query_neg_context_scores)
         | 
| 606 | 
            +
                    loss_neg_q = self.get_ranking_loss(pos_scores, neg_query_pos_context_scores)
         | 
| 607 | 
            +
                    return loss_neg_ctx, loss_neg_q
         | 
| 608 | 
            +
             | 
| 609 | 
            +
                def get_neg_scores(self, scores, scores_masked):
         | 
| 610 | 
            +
                    """
         | 
| 611 | 
            +
                    scores: (N, N), cosine similarity [-1, 1],
         | 
| 612 | 
            +
                        Each row are scores: query --> all videos. Transposed version: video --> all queries.
         | 
| 613 | 
            +
                    scores_masked: (N, N) the same as scores, except that the diagonal (positive) positions
         | 
| 614 | 
            +
                        are masked with a large value.
         | 
| 615 | 
            +
                    """
         | 
| 616 | 
            +
                    bsz = len(scores)
         | 
| 617 | 
            +
                    batch_indices = torch.arange(bsz).to(scores.device)
         | 
| 618 | 
            +
                    _, sorted_scores_indices = torch.sort(scores_masked, descending=True, dim=1)
         | 
| 619 | 
            +
                    sample_min_idx = 1  # skip the masked positive
         | 
| 620 | 
            +
                    sample_max_idx = min(sample_min_idx + self.config.hard_pool_size, bsz) \
         | 
| 621 | 
            +
                        if self.config.use_hard_negative else bsz
         | 
| 622 | 
            +
                    sampled_neg_score_indices = sorted_scores_indices[
         | 
| 623 | 
            +
                        batch_indices, torch.randint(sample_min_idx, sample_max_idx, size=(bsz,)).to(scores.device)]  # (N, )
         | 
| 624 | 
            +
                    sampled_neg_scores = scores[batch_indices, sampled_neg_score_indices]  # (N, )
         | 
| 625 | 
            +
                    return sampled_neg_scores
         | 
| 626 | 
            +
             | 
| 627 | 
            +
                def get_ranking_loss(self, pos_score, neg_score):
         | 
| 628 | 
            +
                    """ Note here we encourage positive scores to be larger than negative scores.
         | 
| 629 | 
            +
                    Args:
         | 
| 630 | 
            +
                        pos_score: (N, ), torch.float32
         | 
| 631 | 
            +
                        neg_score: (N, ), torch.float32
         | 
| 632 | 
            +
                    """
         | 
| 633 | 
            +
                    if self.config.ranking_loss_type == "hinge":  # max(0, m + S_neg - S_pos)
         | 
| 634 | 
            +
                        return torch.clamp(self.config.margin + neg_score - pos_score, min=0).sum() / len(pos_score)
         | 
| 635 | 
            +
                    elif self.config.ranking_loss_type == "lse":  # log[1 + exp(S_neg - S_pos)]
         | 
| 636 | 
            +
                        return torch.log1p(torch.exp(neg_score - pos_score)).sum() / len(pos_score)
         | 
| 637 | 
            +
                    else:
         | 
| 638 | 
            +
                        raise NotImplementedError("Only support 'hinge' and 'lse'")
         | 
| 639 | 
            +
             | 
| 640 | 
            +
             | 
| 641 | 
            +
            def mask_logits(target, mask):
         | 
| 642 | 
            +
                return target * mask + (1 - mask) * (-1e10)
         | 
    	
        baselines/crossmodal_moment_localization/ndcg_iou_topk.py
    ADDED
    
    | @@ -0,0 +1,68 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from utils.basic_utils import load_jsonl, save_jsonl, load_json
         | 
| 2 | 
            +
            import pandas as pd
         | 
| 3 | 
            +
            from tqdm import tqdm
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
            from collections import defaultdict
         | 
| 6 | 
            +
            import copy  
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            def calculate_iou(pred_start: float, pred_end: float, gt_start: float, gt_end: float) -> float:
         | 
| 9 | 
            +
                intersection_start = max(pred_start, gt_start)
         | 
| 10 | 
            +
                intersection_end = min(pred_end, gt_end)
         | 
| 11 | 
            +
                intersection = max(0, intersection_end - intersection_start)
         | 
| 12 | 
            +
                union = (pred_end - pred_start) + (gt_end - gt_start) - intersection
         | 
| 13 | 
            +
                return intersection / union if union > 0 else 0
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            # Function to calculate DCG
         | 
| 17 | 
            +
            def calculate_dcg(scores):
         | 
| 18 | 
            +
                return sum((2**score - 1) / np.log2(idx + 2) for idx, score in enumerate(scores))
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            # Function to calculate NDCG
         | 
| 21 | 
            +
            def calculate_ndcg(pred_scores, true_scores):
         | 
| 22 | 
            +
                dcg = calculate_dcg(pred_scores)
         | 
| 23 | 
            +
                idcg = calculate_dcg(sorted(true_scores, reverse=True))
         | 
| 24 | 
            +
                return dcg / idcg if idcg > 0 else 0
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
            def calculate_ndcg_iou(all_gt, all_pred, TS, KS):
         | 
| 29 | 
            +
                performance = defaultdict(lambda: defaultdict(list))
         | 
| 30 | 
            +
                performance_avg = defaultdict(lambda: defaultdict(float))
         | 
| 31 | 
            +
                for k in tqdm(all_pred.keys(), desc="Calculate NDCG"):
         | 
| 32 | 
            +
                    one_pred = all_pred[k]
         | 
| 33 | 
            +
                    one_gt = all_gt[k]  
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                    one_gt.sort(key=lambda x: x["relevance"], reverse=True)
         | 
| 36 | 
            +
                    for T in TS:
         | 
| 37 | 
            +
                        one_gt_drop = copy.deepcopy(one_gt)  
         | 
| 38 | 
            +
                        predictions_with_scores = []
         | 
| 39 | 
            +
                        
         | 
| 40 | 
            +
                        for pred in one_pred:
         | 
| 41 | 
            +
                            pred_video_name, pred_time = pred["video_name"], pred["timestamp"]
         | 
| 42 | 
            +
                            matched_rows = [gt for gt in one_gt_drop if gt["video_name"] == pred_video_name]
         | 
| 43 | 
            +
                            if not matched_rows:
         | 
| 44 | 
            +
                                pred["pred_relevance"] = 0
         | 
| 45 | 
            +
                            else:
         | 
| 46 | 
            +
                                ious = [calculate_iou(pred_time[0], pred_time[1], gt["timestamp"][0], gt["timestamp"][1]) for gt in matched_rows]
         | 
| 47 | 
            +
                                max_iou_idx = np.argmax(ious)
         | 
| 48 | 
            +
                                max_iou_row = matched_rows[max_iou_idx]
         | 
| 49 | 
            +
                                
         | 
| 50 | 
            +
                                if ious[max_iou_idx] > T:
         | 
| 51 | 
            +
                                    pred["pred_relevance"] = max_iou_row["relevance"]
         | 
| 52 | 
            +
                                    # Remove the matched ground truth row
         | 
| 53 | 
            +
                                    original_idx = one_gt_drop.index(max_iou_row)
         | 
| 54 | 
            +
                                    one_gt_drop.pop(original_idx)
         | 
| 55 | 
            +
                                else:
         | 
| 56 | 
            +
                                    pred["pred_relevance"] = 0
         | 
| 57 | 
            +
                            predictions_with_scores.append(pred)
         | 
| 58 | 
            +
                        for K in KS:
         | 
| 59 | 
            +
                            true_scores = [gt["relevance"] for gt in one_gt][:K]
         | 
| 60 | 
            +
                            pred_scores = [pred["pred_relevance"] for pred in predictions_with_scores][:K]
         | 
| 61 | 
            +
                            ndcg_score = calculate_ndcg(pred_scores, true_scores)
         | 
| 62 | 
            +
                            performance[K][T].append(ndcg_score)
         | 
| 63 | 
            +
                for K, vs in performance.items():
         | 
| 64 | 
            +
                    for T, v in vs.items():
         | 
| 65 | 
            +
                        performance_avg[K][T] = np.mean(v)
         | 
| 66 | 
            +
                return performance_avg
         | 
| 67 | 
            +
             | 
| 68 | 
            +
             | 
    	
        baselines/crossmodal_moment_localization/optimization.py
    ADDED
    
    | @@ -0,0 +1,338 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # coding=utf-8
         | 
| 2 | 
            +
            # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            +
            # You may obtain a copy of the License at
         | 
| 7 | 
            +
            #
         | 
| 8 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            +
            #
         | 
| 10 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            +
            # limitations under the License.
         | 
| 15 | 
            +
            """PyTorch optimization for BERT model."""
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            import math
         | 
| 18 | 
            +
            import torch
         | 
| 19 | 
            +
            from torch.optim import Optimizer
         | 
| 20 | 
            +
            from torch.optim.optimizer import required
         | 
| 21 | 
            +
            from torch.nn.utils import clip_grad_norm_
         | 
| 22 | 
            +
            import logging
         | 
| 23 | 
            +
            import abc
         | 
| 24 | 
            +
            import sys
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 27 | 
            +
             | 
| 28 | 
            +
             | 
| 29 | 
            +
            if sys.version_info >= (3, 4):
         | 
| 30 | 
            +
                ABC = abc.ABC
         | 
| 31 | 
            +
            else:
         | 
| 32 | 
            +
                ABC = abc.ABCMeta('ABC', (), {})
         | 
| 33 | 
            +
             | 
| 34 | 
            +
             | 
| 35 | 
            +
            class _LRSchedule(ABC):
         | 
| 36 | 
            +
                """ Parent of all LRSchedules here. """
         | 
| 37 | 
            +
                warn_t_total = False        # is set to True for schedules where progressing beyond t_total steps doesn't make sense
         | 
| 38 | 
            +
                def __init__(self, warmup=0.002, t_total=-1, **kw):
         | 
| 39 | 
            +
                    """
         | 
| 40 | 
            +
                    :param warmup:  what fraction of t_total steps will be used for linear warmup
         | 
| 41 | 
            +
                    :param t_total: how many training steps (updates) are planned
         | 
| 42 | 
            +
                    :param kw:
         | 
| 43 | 
            +
                    """
         | 
| 44 | 
            +
                    super(_LRSchedule, self).__init__(**kw)
         | 
| 45 | 
            +
                    if t_total < 0:
         | 
| 46 | 
            +
                        logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
         | 
| 47 | 
            +
                    if not 0.0 <= warmup < 1.0 and not warmup == -1:
         | 
| 48 | 
            +
                        raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
         | 
| 49 | 
            +
                    warmup = max(warmup, 0.)
         | 
| 50 | 
            +
                    self.warmup, self.t_total = float(warmup), float(t_total)
         | 
| 51 | 
            +
                    self.warned_for_t_total_at_progress = -1
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                def get_lr(self, step, nowarn=False):
         | 
| 54 | 
            +
                    """
         | 
| 55 | 
            +
                    :param step:    which of t_total steps we're on
         | 
| 56 | 
            +
                    :param nowarn:  set to True to suppress warning regarding training beyond specified 't_total' steps
         | 
| 57 | 
            +
                    :return:        learning rate multiplier for current update
         | 
| 58 | 
            +
                    """
         | 
| 59 | 
            +
                    if self.t_total < 0:
         | 
| 60 | 
            +
                        return 1.
         | 
| 61 | 
            +
                    progress = float(step) / self.t_total
         | 
| 62 | 
            +
                    ret = self.get_lr_(progress)
         | 
| 63 | 
            +
                    # warning for exceeding t_total (only active with warmup_linear
         | 
| 64 | 
            +
                    if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress:
         | 
| 65 | 
            +
                        logger.warning(
         | 
| 66 | 
            +
                            "Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly."
         | 
| 67 | 
            +
                                .format(ret, self.__class__.__name__))
         | 
| 68 | 
            +
                        self.warned_for_t_total_at_progress = progress
         | 
| 69 | 
            +
                    # end warning
         | 
| 70 | 
            +
                    return ret
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                @abc.abstractmethod
         | 
| 73 | 
            +
                def get_lr_(self, progress):
         | 
| 74 | 
            +
                    """
         | 
| 75 | 
            +
                    :param progress:    value between 0 and 1 (unless going beyond t_total steps) specifying training progress
         | 
| 76 | 
            +
                    :return:            learning rate multiplier for current update
         | 
| 77 | 
            +
                    """
         | 
| 78 | 
            +
                    return 1.
         | 
| 79 | 
            +
             | 
| 80 | 
            +
             | 
| 81 | 
            +
            class ConstantLR(_LRSchedule):
         | 
| 82 | 
            +
                def get_lr_(self, progress):
         | 
| 83 | 
            +
                    return 1.
         | 
| 84 | 
            +
             | 
| 85 | 
            +
             | 
| 86 | 
            +
            class WarmupCosineSchedule(_LRSchedule):
         | 
| 87 | 
            +
                """
         | 
| 88 | 
            +
                Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
         | 
| 89 | 
            +
                Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve.
         | 
| 90 | 
            +
                If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup.
         | 
| 91 | 
            +
                """
         | 
| 92 | 
            +
                warn_t_total = True
         | 
| 93 | 
            +
                def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw):
         | 
| 94 | 
            +
                    """
         | 
| 95 | 
            +
                    :param warmup:      see LRSchedule
         | 
| 96 | 
            +
                    :param t_total:     see LRSchedule
         | 
| 97 | 
            +
                    :param cycles:      number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1.
         | 
| 98 | 
            +
                    :param kw:
         | 
| 99 | 
            +
                    """
         | 
| 100 | 
            +
                    super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw)
         | 
| 101 | 
            +
                    self.cycles = cycles
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                def get_lr_(self, progress):
         | 
| 104 | 
            +
                    if progress < self.warmup:
         | 
| 105 | 
            +
                        return progress / self.warmup
         | 
| 106 | 
            +
                    else:
         | 
| 107 | 
            +
                        progress = (progress - self.warmup) / (1 - self.warmup)   # progress after warmup
         | 
| 108 | 
            +
                        return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))
         | 
| 109 | 
            +
             | 
| 110 | 
            +
             | 
| 111 | 
            +
            class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule):
         | 
| 112 | 
            +
                """
         | 
| 113 | 
            +
                Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
         | 
| 114 | 
            +
                If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying
         | 
| 115 | 
            +
                learning rate (with hard restarts).
         | 
| 116 | 
            +
                """
         | 
| 117 | 
            +
                def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
         | 
| 118 | 
            +
                    super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
         | 
| 119 | 
            +
                    assert(cycles >= 1.)
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                def get_lr_(self, progress):
         | 
| 122 | 
            +
                    if progress < self.warmup:
         | 
| 123 | 
            +
                        return progress / self.warmup
         | 
| 124 | 
            +
                    else:
         | 
| 125 | 
            +
                        progress = (progress - self.warmup) / (1 - self.warmup)     # progress after warmup
         | 
| 126 | 
            +
                        ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1)))
         | 
| 127 | 
            +
                        return ret
         | 
| 128 | 
            +
             | 
| 129 | 
            +
             | 
| 130 | 
            +
            class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule):
         | 
| 131 | 
            +
                """
         | 
| 132 | 
            +
                All training progress is divided in `cycles` (default=1.) parts of equal length.
         | 
| 133 | 
            +
                Every part follows a schedule with the first `warmup` fraction of the training steps linearly increasing from 0. to 1.,
         | 
| 134 | 
            +
                followed by a learning rate decreasing from 1. to 0. following a cosine curve.
         | 
| 135 | 
            +
                """
         | 
| 136 | 
            +
                def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
         | 
| 137 | 
            +
                    assert(warmup * cycles < 1.)
         | 
| 138 | 
            +
                    warmup = warmup * cycles if warmup >= 0 else warmup
         | 
| 139 | 
            +
                    super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                def get_lr_(self, progress):
         | 
| 142 | 
            +
                    progress = progress * self.cycles % 1.
         | 
| 143 | 
            +
                    if progress < self.warmup:
         | 
| 144 | 
            +
                        return progress / self.warmup
         | 
| 145 | 
            +
                    else:
         | 
| 146 | 
            +
                        progress = (progress - self.warmup) / (1 - self.warmup)     # progress after warmup
         | 
| 147 | 
            +
                        ret = 0.5 * (1. + math.cos(math.pi * progress))
         | 
| 148 | 
            +
                        return ret
         | 
| 149 | 
            +
             | 
| 150 | 
            +
             | 
| 151 | 
            +
            class WarmupConstantSchedule(_LRSchedule):
         | 
| 152 | 
            +
                """
         | 
| 153 | 
            +
                Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
         | 
| 154 | 
            +
                Keeps learning rate equal to 1. after warmup.
         | 
| 155 | 
            +
                """
         | 
| 156 | 
            +
                def get_lr_(self, progress):
         | 
| 157 | 
            +
                    if progress < self.warmup:
         | 
| 158 | 
            +
                        return progress / self.warmup
         | 
| 159 | 
            +
                    return 1.
         | 
| 160 | 
            +
             | 
| 161 | 
            +
             | 
| 162 | 
            +
            class WarmupLinearSchedule(_LRSchedule):
         | 
| 163 | 
            +
                """
         | 
| 164 | 
            +
                Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
         | 
| 165 | 
            +
                Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps.
         | 
| 166 | 
            +
                """
         | 
| 167 | 
            +
                warn_t_total = True
         | 
| 168 | 
            +
                def get_lr_(self, progress):
         | 
| 169 | 
            +
                    if progress < self.warmup:
         | 
| 170 | 
            +
                        return progress / self.warmup
         | 
| 171 | 
            +
                    return max((progress - 1.) / (self.warmup - 1.), 0.)
         | 
| 172 | 
            +
             | 
| 173 | 
            +
             | 
| 174 | 
            +
            SCHEDULES = {
         | 
| 175 | 
            +
                None:       ConstantLR,
         | 
| 176 | 
            +
                "none":     ConstantLR,
         | 
| 177 | 
            +
                "warmup_cosine": WarmupCosineSchedule,
         | 
| 178 | 
            +
                "warmup_constant": WarmupConstantSchedule,
         | 
| 179 | 
            +
                "warmup_linear": WarmupLinearSchedule
         | 
| 180 | 
            +
            }
         | 
| 181 | 
            +
             | 
| 182 | 
            +
             | 
| 183 | 
            +
            class EMA(object):
         | 
| 184 | 
            +
                """ Exponential Moving Average for model parameters.
         | 
| 185 | 
            +
                references:
         | 
| 186 | 
            +
                [1] https://github.com/BangLiu/QANet-PyTorch/blob/master/model/modules/ema.py
         | 
| 187 | 
            +
                [2] https://github.com/hengruo/QANet-pytorch/blob/e2de07cd2c711d525f5ffee35c3764335d4b501d/main.py"""
         | 
| 188 | 
            +
                def __init__(self, decay):
         | 
| 189 | 
            +
                    self.decay = decay
         | 
| 190 | 
            +
                    self.shadow = {}
         | 
| 191 | 
            +
                    self.original = {}
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                def register(self, name, val):
         | 
| 194 | 
            +
                    self.shadow[name] = val.clone()
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                def __call__(self, model, step):
         | 
| 197 | 
            +
                    decay = min(self.decay,  (1 + step) / (10.0 + step))
         | 
| 198 | 
            +
                    for name, param in model.named_parameters():
         | 
| 199 | 
            +
                        if param.requires_grad:
         | 
| 200 | 
            +
                            assert name in self.shadow
         | 
| 201 | 
            +
                            new_average = \
         | 
| 202 | 
            +
                                (1.0 - decay) * param.data + decay * self.shadow[name]
         | 
| 203 | 
            +
                            self.shadow[name] = new_average.clone()
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                def assign(self, model):
         | 
| 206 | 
            +
                    for name, param in model.named_parameters():
         | 
| 207 | 
            +
                        if param.requires_grad:
         | 
| 208 | 
            +
                            assert name in self.shadow
         | 
| 209 | 
            +
                            self.original[name] = param.data.clone()
         | 
| 210 | 
            +
                            param.data = self.shadow[name]
         | 
| 211 | 
            +
             | 
| 212 | 
            +
                def resume(self, model):
         | 
| 213 | 
            +
                    for name, param in model.named_parameters():
         | 
| 214 | 
            +
                        if param.requires_grad:
         | 
| 215 | 
            +
                            assert name in self.shadow
         | 
| 216 | 
            +
                            param.data = self.original[name]
         | 
| 217 | 
            +
             | 
| 218 | 
            +
             | 
| 219 | 
            +
            class BertAdam(Optimizer):
         | 
| 220 | 
            +
                """Implements BERT version of Adam algorithm with weight decay fix.
         | 
| 221 | 
            +
                Params:
         | 
| 222 | 
            +
                    lr: learning rate
         | 
| 223 | 
            +
                    warmup: portion of t_total for the warmup, -1  means no warmup. Default: -1
         | 
| 224 | 
            +
                    t_total: total number of training steps for the learning
         | 
| 225 | 
            +
                        rate schedule, -1  means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1
         | 
| 226 | 
            +
                    schedule: schedule to use for the warmup (see above).
         | 
| 227 | 
            +
                        Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a `_LRSchedule` object (see below).
         | 
| 228 | 
            +
                        If `None` or `'none'`, learning rate is always kept constant.
         | 
| 229 | 
            +
                        Default : `'warmup_linear'`
         | 
| 230 | 
            +
                    b1: Adams b1. Default: 0.9
         | 
| 231 | 
            +
                    b2: Adams b2. Default: 0.999
         | 
| 232 | 
            +
                    e: Adams epsilon. Default: 1e-6
         | 
| 233 | 
            +
                    weight_decay: Weight decay. Default: 0.01
         | 
| 234 | 
            +
                    max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
         | 
| 235 | 
            +
                """
         | 
| 236 | 
            +
                def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
         | 
| 237 | 
            +
                             b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs):
         | 
| 238 | 
            +
                    if lr is not required and lr < 0.0:
         | 
| 239 | 
            +
                        raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
         | 
| 240 | 
            +
                    if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
         | 
| 241 | 
            +
                        raise ValueError("Invalid schedule parameter: {}".format(schedule))
         | 
| 242 | 
            +
                    if not 0.0 <= b1 < 1.0:
         | 
| 243 | 
            +
                        raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
         | 
| 244 | 
            +
                    if not 0.0 <= b2 < 1.0:
         | 
| 245 | 
            +
                        raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
         | 
| 246 | 
            +
                    if not e >= 0.0:
         | 
| 247 | 
            +
                        raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
         | 
| 248 | 
            +
                    # initialize schedule object
         | 
| 249 | 
            +
                    if not isinstance(schedule, _LRSchedule):
         | 
| 250 | 
            +
                        schedule_type = SCHEDULES[schedule]
         | 
| 251 | 
            +
                        schedule = schedule_type(warmup=warmup, t_total=t_total)
         | 
| 252 | 
            +
                    else:
         | 
| 253 | 
            +
                        if warmup != -1 or t_total != -1:
         | 
| 254 | 
            +
                            logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. "
         | 
| 255 | 
            +
                                           "Please specify custom warmup and t_total in _LRSchedule object.")
         | 
| 256 | 
            +
                    defaults = dict(lr=lr, schedule=schedule,
         | 
| 257 | 
            +
                                    b1=b1, b2=b2, e=e, weight_decay=weight_decay,
         | 
| 258 | 
            +
                                    max_grad_norm=max_grad_norm)
         | 
| 259 | 
            +
                    super(BertAdam, self).__init__(params, defaults)
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                def get_lr(self):
         | 
| 262 | 
            +
                    lr = []
         | 
| 263 | 
            +
                    for group in self.param_groups:
         | 
| 264 | 
            +
                        for p in group['params']:
         | 
| 265 | 
            +
                            state = self.state[p]
         | 
| 266 | 
            +
                            if len(state) == 0:
         | 
| 267 | 
            +
                                return [0]
         | 
| 268 | 
            +
                            lr_scheduled = group['lr']
         | 
| 269 | 
            +
                            lr_scheduled *= group['schedule'].get_lr(state['step'])
         | 
| 270 | 
            +
                            lr.append(lr_scheduled)
         | 
| 271 | 
            +
                    return lr
         | 
| 272 | 
            +
             | 
| 273 | 
            +
                def step(self, closure=None):
         | 
| 274 | 
            +
                    """Performs a single optimization step.
         | 
| 275 | 
            +
             | 
| 276 | 
            +
                    Arguments:
         | 
| 277 | 
            +
                        closure (callable, optional): A closure that reevaluates the model
         | 
| 278 | 
            +
                            and returns the loss.
         | 
| 279 | 
            +
                    """
         | 
| 280 | 
            +
                    loss = None
         | 
| 281 | 
            +
                    if closure is not None:
         | 
| 282 | 
            +
                        loss = closure()
         | 
| 283 | 
            +
             | 
| 284 | 
            +
                    for group in self.param_groups:
         | 
| 285 | 
            +
                        for p in group['params']:
         | 
| 286 | 
            +
                            if p.grad is None:
         | 
| 287 | 
            +
                                continue
         | 
| 288 | 
            +
                            grad = p.grad.data
         | 
| 289 | 
            +
                            if grad.is_sparse:
         | 
| 290 | 
            +
                                raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
         | 
| 291 | 
            +
             | 
| 292 | 
            +
                            state = self.state[p]
         | 
| 293 | 
            +
             | 
| 294 | 
            +
                            # State initialization
         | 
| 295 | 
            +
                            if len(state) == 0:
         | 
| 296 | 
            +
                                state['step'] = 0
         | 
| 297 | 
            +
                                # Exponential moving average of gradient values
         | 
| 298 | 
            +
                                state['next_m'] = torch.zeros_like(p.data)
         | 
| 299 | 
            +
                                # Exponential moving average of squared gradient values
         | 
| 300 | 
            +
                                state['next_v'] = torch.zeros_like(p.data)
         | 
| 301 | 
            +
             | 
| 302 | 
            +
                            next_m, next_v = state['next_m'], state['next_v']
         | 
| 303 | 
            +
                            beta1, beta2 = group['b1'], group['b2']
         | 
| 304 | 
            +
             | 
| 305 | 
            +
                            # Add grad clipping
         | 
| 306 | 
            +
                            if group['max_grad_norm'] > 0:
         | 
| 307 | 
            +
                                clip_grad_norm_(p, group['max_grad_norm'])
         | 
| 308 | 
            +
             | 
| 309 | 
            +
                            # Decay the first and second moment running average coefficient
         | 
| 310 | 
            +
                            # In-place operations to update the averages at the same time
         | 
| 311 | 
            +
                            next_m.mul_(beta1).add_(grad, alpha=1 - beta1)
         | 
| 312 | 
            +
                            next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
         | 
| 313 | 
            +
                            update = next_m / (next_v.sqrt() + group['e'])
         | 
| 314 | 
            +
             | 
| 315 | 
            +
                            # Just adding the square of the weights to the loss function is *not*
         | 
| 316 | 
            +
                            # the correct way of using L2 regularization/weight decay with Adam,
         | 
| 317 | 
            +
                            # since that will interact with the m and v parameters in strange ways.
         | 
| 318 | 
            +
                            #
         | 
| 319 | 
            +
                            # Instead we want to decay the weights in a manner that doesn't interact
         | 
| 320 | 
            +
                            # with the m/v parameters. This is equivalent to adding the square
         | 
| 321 | 
            +
                            # of the weights to the loss with plain (non-momentum) SGD.
         | 
| 322 | 
            +
                            if group['weight_decay'] > 0.0:
         | 
| 323 | 
            +
                                update += group['weight_decay'] * p.data
         | 
| 324 | 
            +
             | 
| 325 | 
            +
                            lr_scheduled = group['lr']
         | 
| 326 | 
            +
                            lr_scheduled *= group['schedule'].get_lr(state['step'])
         | 
| 327 | 
            +
             | 
| 328 | 
            +
                            update_with_lr = lr_scheduled * update
         | 
| 329 | 
            +
                            p.data.add_(-update_with_lr)
         | 
| 330 | 
            +
             | 
| 331 | 
            +
                            state['step'] += 1
         | 
| 332 | 
            +
             | 
| 333 | 
            +
                            # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
         | 
| 334 | 
            +
                            # No bias correction
         | 
| 335 | 
            +
                            # bias_correction1 = 1 - beta1 ** state['step']
         | 
| 336 | 
            +
                            # bias_correction2 = 1 - beta2 ** state['step']
         | 
| 337 | 
            +
             | 
| 338 | 
            +
                    return loss
         | 
    	
        baselines/crossmodal_moment_localization/scripts/eval.sh
    ADDED
    
    | @@ -0,0 +1,14 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env bash
         | 
| 2 | 
            +
            # run at project root dir
         | 
| 3 | 
            +
            # Usage:
         | 
| 4 | 
            +
            # bash baselines/crossmodal_moment_localization/scripts/eval.sh ANY_OTHER_PYTHON_ARGS
         | 
| 5 | 
            +
            eval_split_name=$1
         | 
| 6 | 
            +
            submission_path=$2
         | 
| 7 | 
            +
            save_path=$3
         | 
| 8 | 
            +
            gt_path=data/tvr_${eval_split_name}_release.jsonl
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            python standalone_eval/eval.py \
         | 
| 11 | 
            +
            --gt_path ${gt_path} \
         | 
| 12 | 
            +
            --submission_path ${submission_path} \
         | 
| 13 | 
            +
            --save_path ${save_path} \
         | 
| 14 | 
            +
            ${@:4}
         | 
    	
        baselines/crossmodal_moment_localization/scripts/inference.sh
    ADDED
    
    | @@ -0,0 +1,18 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env bash
         | 
| 2 | 
            +
            # run at project root dir
         | 
| 3 | 
            +
            # Usage:
         | 
| 4 | 
            +
            # bash baselines/crossmodal_moment_localization/scripts/inference.sh ANY_OTHER_PYTHON_ARGS
         | 
| 5 | 
            +
            model_dir=$1
         | 
| 6 | 
            +
            eval_split_name=$2
         | 
| 7 | 
            +
            eval_path=data/tvr_${eval_split_name}_release.jsonl
         | 
| 8 | 
            +
            tasks=()
         | 
| 9 | 
            +
            tasks+=(VCMR)
         | 
| 10 | 
            +
            tasks+=(SVMR)
         | 
| 11 | 
            +
            tasks+=(VR)
         | 
| 12 | 
            +
            echo "tasks ${tasks[@]}"
         | 
| 13 | 
            +
            python baselines/crossmodal_moment_localization/inference.py \
         | 
| 14 | 
            +
            --model_dir ${model_dir} \
         | 
| 15 | 
            +
            --tasks ${tasks[@]} \
         | 
| 16 | 
            +
            --eval_split_name ${eval_split_name} \
         | 
| 17 | 
            +
            --eval_path ${eval_path} \
         | 
| 18 | 
            +
            ${@:3}
         | 
    	
        baselines/crossmodal_moment_localization/scripts/inference_with_external.sh
    ADDED
    
    | @@ -0,0 +1,40 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env bash
         | 
| 2 | 
            +
            # run at project root dir
         | 
| 3 | 
            +
            # Usage:
         | 
| 4 | 
            +
            # bash baselines/crossmodal_moment_localization/scripts/inference_with_external.sh
         | 
| 5 | 
            +
            #model_dir=$1
         | 
| 6 | 
            +
            # DO not use NMS, since it gives worse results
         | 
| 7 | 
            +
            eval_model=$1  # [xml, xml_tef]
         | 
| 8 | 
            +
            eval_split_name=$2
         | 
| 9 | 
            +
            external_model=mee  # [mee, mcn, cal]
         | 
| 10 | 
            +
            eval_path=data/tvr_${eval_split_name}_release.jsonl
         | 
| 11 | 
            +
            project_root=./baselines
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            # setup eval model
         | 
| 14 | 
            +
            if [[ ${eval_model} == xml ]]; then
         | 
| 15 | 
            +
                eval_model_dir=tvr-video_sub-resnet_i3d_no_norm_v-2019_11_03_12_22_19
         | 
| 16 | 
            +
            elif [[ ${eval_model} == xml_tef ]]; then
         | 
| 17 | 
            +
                eval_model_dir=tvr-video_sub_tef-resnet_i3d_no_norm_v-2019_11_03_12_53_01
         | 
| 18 | 
            +
            fi
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            # setup external
         | 
| 21 | 
            +
            if [[ ${external_model} == mee ]]; then
         | 
| 22 | 
            +
                external_model_dir=tvr-video_sub-res-2019_11_06_00_33_39
         | 
| 23 | 
            +
                external_inference_vr_res_path=${project_root}/mixture_embedding_experts/results/${external_model_dir}/inference_tvr_${eval_split_name}_None_predictions_VR.json
         | 
| 24 | 
            +
            fi
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            tasks=(VR)
         | 
| 27 | 
            +
            tasks+=(SVMR)
         | 
| 28 | 
            +
            tasks+=(VCMR)
         | 
| 29 | 
            +
            echo "tasks ${tasks[@]}"
         | 
| 30 | 
            +
            python baselines/crossmodal_moment_localization/inference.py \
         | 
| 31 | 
            +
            --model_dir ${eval_model_dir} \
         | 
| 32 | 
            +
            --tasks ${tasks[@]} \
         | 
| 33 | 
            +
            --eval_split_name ${eval_split_name} \
         | 
| 34 | 
            +
            --eval_path ${eval_path} \
         | 
| 35 | 
            +
            --external_inference_vr_res_path ${external_inference_vr_res_path} \
         | 
| 36 | 
            +
            --eval_id ${external_model_dir} \
         | 
| 37 | 
            +
            ${@:3}
         | 
| 38 | 
            +
             | 
| 39 | 
            +
            #--use_intermediate \  # temporary removed
         | 
| 40 | 
            +
             | 
    	
        baselines/crossmodal_moment_localization/scripts/train.sh
    ADDED
    
    | @@ -0,0 +1,70 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env bash
         | 
| 2 | 
            +
            # run at project root dir
         | 
| 3 | 
            +
            # Usage:
         | 
| 4 | 
            +
            # bash baselines/crossmodal_moment_localization/scripts/train.sh tvr all ANY_OTHER_PYTHON_ARGS
         | 
| 5 | 
            +
            # use --eval_tasks_at_training ["VR", "SVMR", "VCMR"] --stop_task ["VR", "SVMR", "VCMR"] for
         | 
| 6 | 
            +
            # use --lw_neg_q 0 --lw_neg_ctx 0 for training SVMR/SVMR only
         | 
| 7 | 
            +
            # use --lw_st_ed 0 for training with VR only
         | 
| 8 | 
            +
            dset_name=$1  # see case below
         | 
| 9 | 
            +
            ctx_mode=$2  # [video, sub, tef, video_sub, video_tef, sub_tef, video_sub_tef]
         | 
| 10 | 
            +
            vid_feat_type=$3  # [resnet, i3d, resnet_i3d]
         | 
| 11 | 
            +
            feature_root=data/tvr_feature_release
         | 
| 12 | 
            +
            results_root=baselines/crossmodal_moment_localization/results
         | 
| 13 | 
            +
            vid_feat_size=2048
         | 
| 14 | 
            +
            extra_args=()
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            if [[ ${ctx_mode} == *"sub"* ]] || [[ ${ctx_mode} == "sub" ]]; then
         | 
| 17 | 
            +
                if [[ ${dset_name} != "tvr" ]]; then
         | 
| 18 | 
            +
                    echo "The use of subtitles is only supported in tvr."
         | 
| 19 | 
            +
                    exit 1
         | 
| 20 | 
            +
                fi
         | 
| 21 | 
            +
            fi
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            case ${dset_name} in
         | 
| 25 | 
            +
                tvr)
         | 
| 26 | 
            +
                    train_path=data/tvr_train_release.jsonl
         | 
| 27 | 
            +
                    corpus_path=data/tvr_video2dur_idx.json
         | 
| 28 | 
            +
                    desc_bert_path=${feature_root}/bert_feature/query_only/tvr_query_pretrained_w_query.h5
         | 
| 29 | 
            +
                    if [[ ${vid_feat_type} == "i3d" ]]; then
         | 
| 30 | 
            +
                        echo "Using I3D feature with shape 1024"
         | 
| 31 | 
            +
                        vid_feat_path=${feature_root}/video_feature/tvr_i3d_rgb600_avg_cl-1.5.h5
         | 
| 32 | 
            +
                        vid_feat_size=1024
         | 
| 33 | 
            +
                    elif [[ ${vid_feat_type} == "resnet" ]]; then
         | 
| 34 | 
            +
                        echo "Using ResNet feature with shape 2048"
         | 
| 35 | 
            +
                        vid_feat_path=${feature_root}/video_feature/tvr_resnet152_rgb_max_cl-1.5.h5
         | 
| 36 | 
            +
                        vid_feat_size=2048
         | 
| 37 | 
            +
                    elif [[ ${vid_feat_type} == "resnet_i3d" ]]; then
         | 
| 38 | 
            +
                        echo "Using concatenated ResNet and I3D feature with shape 2048+1024"
         | 
| 39 | 
            +
                        vid_feat_path=${feature_root}/video_feature/tvr_resnet152_rgb_max_i3d_rgb600_avg_cat_cl-1.5.h5
         | 
| 40 | 
            +
                        vid_feat_size=3072
         | 
| 41 | 
            +
                        extra_args+=(--no_norm_vfeat)  # since they are already normalized.
         | 
| 42 | 
            +
                    fi
         | 
| 43 | 
            +
                    eval_split_name=val
         | 
| 44 | 
            +
                    nms_thd=-1
         | 
| 45 | 
            +
                    extra_args+=(--eval_path)
         | 
| 46 | 
            +
                    extra_args+=(data/tvr_val_release.jsonl)
         | 
| 47 | 
            +
                    clip_length=1.5
         | 
| 48 | 
            +
                    extra_args+=(--max_ctx_l)
         | 
| 49 | 
            +
                    extra_args+=(100)  # max_ctx_l = 100 for clip_length = 1.5, only ~109/21825 has more than 100.
         | 
| 50 | 
            +
                    extra_args+=(--max_pred_l)
         | 
| 51 | 
            +
                    extra_args+=(16)
         | 
| 52 | 
            +
                    if [[ ${ctx_mode} == *"sub"* ]] || [[ ${ctx_mode} == "sub" ]]; then
         | 
| 53 | 
            +
                        echo "Running with sub."
         | 
| 54 | 
            +
                        desc_bert_path=${feature_root}/bert_feature/sub_query/tvr_query_pretrained_w_sub_query.h5  # overwrite
         | 
| 55 | 
            +
                        sub_bert_path=${feature_root}/bert_feature/sub_query/tvr_sub_pretrained_w_sub_query_max_cl-1.5.h5
         | 
| 56 | 
            +
                        sub_feat_size=768
         | 
| 57 | 
            +
                        extra_args+=(--sub_feat_size)
         | 
| 58 | 
            +
                        extra_args+=(${sub_feat_size})
         | 
| 59 | 
            +
                        extra_args+=(--sub_bert_path)
         | 
| 60 | 
            +
                        extra_args+=(${sub_bert_path})
         | 
| 61 | 
            +
                    fi
         | 
| 62 | 
            +
                    ;;
         | 
| 63 | 
            +
                *)
         | 
| 64 | 
            +
                    echo -n "Unknown argument"
         | 
| 65 | 
            +
                    ;;
         | 
| 66 | 
            +
            esac
         | 
| 67 | 
            +
             | 
| 68 | 
            +
            echo "Start training with dataset [${dset_name}] in Context Mode [${ctx_mode}]"
         | 
| 69 | 
            +
            echo "Extra args ${extra_args[@]}"
         | 
| 70 | 
            +
            echo " python baselines/crossmodal_moment_localization/train.py     --dset_name=${dset_name}     --eval_split_name=${eval_split_name}     --nms_thd=${nms_thd}     --results_root=${results_root}     --train_path=${train_path}     --desc_bert_path=${desc_bert_path}     --corpus_path=${corpus_path}     --vid_feat_path=${vid_feat_path}     --clip_length=${clip_length}     --vid_feat_size=${vid_feat_size}     --ctx_mode=${ctx_mode}     ${extra_args[@]}     ${@:4}"
         | 
    	
        baselines/crossmodal_moment_localization/start_end_dataset.py
    ADDED
    
    | @@ -0,0 +1,393 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """
         | 
| 2 | 
            +
            Dataset for clip model
         | 
| 3 | 
            +
            """
         | 
| 4 | 
            +
            import logging
         | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
            from torch.utils.data import Dataset
         | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
            import h5py
         | 
| 9 | 
            +
            import time
         | 
| 10 | 
            +
            import math
         | 
| 11 | 
            +
            import random
         | 
| 12 | 
            +
            from tqdm import tqdm
         | 
| 13 | 
            +
            from utils.basic_utils import load_json, load_json, l2_normalize_np_array, flat_list_of_lists, merge_dicts
         | 
| 14 | 
            +
            from utils.tensor_utils import pad_sequences_1d
         | 
| 15 | 
            +
            from baselines.clip_alignment_with_language.local_utils.compute_proposal_upper_bound import \
         | 
| 16 | 
            +
                get_didemo_agreed_ts
         | 
| 17 | 
            +
            import pandas as pd
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            class StartEndDataset(Dataset):
         | 
| 23 | 
            +
                """
         | 
| 24 | 
            +
                Args:
         | 
| 25 | 
            +
                    dset_name, str, ["tvr"]
         | 
| 26 | 
            +
                    ctx_mode: str,
         | 
| 27 | 
            +
                Return:
         | 
| 28 | 
            +
                    a dict: {
         | 
| 29 | 
            +
                        "meta": {
         | 
| 30 | 
            +
                            "query_id": int,
         | 
| 31 | 
            +
                            "desc": str,
         | 
| 32 | 
            +
                            "vid_name": str,
         | 
| 33 | 
            +
                            "duration": float,
         | 
| 34 | 
            +
                            "ts": [st (float), ed (float)], seconds, ground_truth timestamps
         | 
| 35 | 
            +
                        }
         | 
| 36 | 
            +
                        "model_inputs": {
         | 
| 37 | 
            +
                            "query_feat": torch.tensor, (L, D_q)
         | 
| 38 | 
            +
                            "video_feat": torch.tensor, (n_clip_in_moment, D_video)
         | 
| 39 | 
            +
                            "sub_feat": torch.tensor, (n_clip_in_moment, D_sub)
         | 
| 40 | 
            +
                            "st_ed_indices": torch.LongTensor, (2, )
         | 
| 41 | 
            +
                        }
         | 
| 42 | 
            +
                    }
         | 
| 43 | 
            +
                """
         | 
| 44 | 
            +
                def __init__(self, dset_name, data_path, desc_bert_path_or_handler, sub_bert_path_or_handler,
         | 
| 45 | 
            +
                             max_desc_len, max_ctx_len,
         | 
| 46 | 
            +
                             vid_feat_path_or_handler, clip_length, ctx_mode="video",
         | 
| 47 | 
            +
                             normalize_vfeat=True, normalize_tfeat=True, h5driver=None, data_ratio=1.0):
         | 
| 48 | 
            +
                    self.dset_name = dset_name
         | 
| 49 | 
            +
                    self.data_path = data_path
         | 
| 50 | 
            +
                    self.data_ratio = data_ratio
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                    self.desc_bert_path_or_handler = desc_bert_path_or_handler
         | 
| 53 | 
            +
                    self.max_desc_len = max_desc_len
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                    self.sub_bert_path_or_handler = sub_bert_path_or_handler
         | 
| 56 | 
            +
                    self.max_ctx_len = max_ctx_len
         | 
| 57 | 
            +
                    self.vid_feat_path_or_handler = vid_feat_path_or_handler
         | 
| 58 | 
            +
                    self.clip_length = clip_length
         | 
| 59 | 
            +
                    self.ctx_mode = ctx_mode
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                    # prepare desc data
         | 
| 62 | 
            +
                    self.data = self.expand_annotations(load_json(data_path))
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    if self.data_ratio != 1:
         | 
| 65 | 
            +
                        n_examples = int(len(self.data) * data_ratio)
         | 
| 66 | 
            +
                        self.data = self.data[:n_examples]
         | 
| 67 | 
            +
                        logger.info("Using {}% of the data: {} examples".format(data_ratio * 100, n_examples))
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                    self.use_video = "video" in self.ctx_mode
         | 
| 70 | 
            +
                    self.use_sub = "sub" in self.ctx_mode
         | 
| 71 | 
            +
                    self.use_tef = "tef" in self.ctx_mode
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                    if self.use_video:
         | 
| 74 | 
            +
                        if isinstance(vid_feat_path_or_handler, h5py.File):
         | 
| 75 | 
            +
                            self.vid_feat_h5 = vid_feat_path_or_handler
         | 
| 76 | 
            +
                        else:  # str path
         | 
| 77 | 
            +
                            self.vid_feat_h5 = h5py.File(vid_feat_path_or_handler, "r", driver=h5driver)
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                    if isinstance(desc_bert_path_or_handler, h5py.File):
         | 
| 80 | 
            +
                        self.desc_bert_h5 = desc_bert_path_or_handler
         | 
| 81 | 
            +
                    else:
         | 
| 82 | 
            +
                        self.desc_bert_h5 = h5py.File(desc_bert_path_or_handler, "r", driver=h5driver)
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                    if self.use_sub:
         | 
| 85 | 
            +
                        if isinstance(sub_bert_path_or_handler, h5py.File):
         | 
| 86 | 
            +
                            self.sub_bert_h5 = sub_bert_path_or_handler
         | 
| 87 | 
            +
                        else:  # str path
         | 
| 88 | 
            +
                            self.sub_bert_h5 = h5py.File(sub_bert_path_or_handler, "r", driver=h5driver)
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                    self.normalize_vfeat = normalize_vfeat
         | 
| 91 | 
            +
                    self.normalize_tfeat = normalize_tfeat
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                def __len__(self):
         | 
| 94 | 
            +
                    return len(self.data)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                def expand_annotations(self, annotations):
         | 
| 97 | 
            +
                    new_annotations = []
         | 
| 98 | 
            +
                    for i in annotations:
         | 
| 99 | 
            +
                        query = i["query"]
         | 
| 100 | 
            +
                        query_id = i["query_id"]
         | 
| 101 | 
            +
                        for moment in  i["relevant_moment"]:
         | 
| 102 | 
            +
                            moment.update({'query': query, 'query_id': query_id})
         | 
| 103 | 
            +
                            new_annotations.append(moment)
         | 
| 104 | 
            +
                    return new_annotations
         | 
| 105 | 
            +
                
         | 
| 106 | 
            +
                def __getitem__(self, index):
         | 
| 107 | 
            +
                    raw_data = self.data[index]
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                    # initialize with basic data
         | 
| 110 | 
            +
                    meta = dict(
         | 
| 111 | 
            +
                        query_id=raw_data["query_id"],
         | 
| 112 | 
            +
                        desc=raw_data["query"],
         | 
| 113 | 
            +
                        vid_name=raw_data["video_name"],
         | 
| 114 | 
            +
                        duration=raw_data["duration"],
         | 
| 115 | 
            +
                        ts=raw_data["timestamp"] ,
         | 
| 116 | 
            +
                    )
         | 
| 117 | 
            +
                    model_inputs = dict()
         | 
| 118 | 
            +
                    model_inputs["query_feat"] = self.get_query_feat_by_query_id(meta["query_id"])
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                    ctx_l = 0
         | 
| 121 | 
            +
                    if self.use_video:
         | 
| 122 | 
            +
                        video_feat = self.vid_feat_h5[meta["vid_name"]][:self.max_ctx_len]  # (N_clip, D)
         | 
| 123 | 
            +
                        if self.normalize_vfeat:
         | 
| 124 | 
            +
                            video_feat = l2_normalize_np_array(video_feat)
         | 
| 125 | 
            +
                        model_inputs["video_feat"] = torch.from_numpy(video_feat)
         | 
| 126 | 
            +
                        ctx_l = len(video_feat)
         | 
| 127 | 
            +
                    else:
         | 
| 128 | 
            +
                        model_inputs["video_feat"] = torch.zeros((2, 2))
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                    if self.use_sub:  # no need for ctx feature, as the features are already contextulized
         | 
| 131 | 
            +
                        sub_feat = self.sub_bert_h5[meta["vid_name"]][:self.max_ctx_len]  # (N_clips, D_t)
         | 
| 132 | 
            +
                        if self.normalize_tfeat:
         | 
| 133 | 
            +
                            sub_feat = l2_normalize_np_array(sub_feat)
         | 
| 134 | 
            +
                        model_inputs["sub_feat"] = torch.from_numpy(sub_feat)
         | 
| 135 | 
            +
                        ctx_l = len(sub_feat)
         | 
| 136 | 
            +
                    else:
         | 
| 137 | 
            +
                        model_inputs["sub_feat"] = torch.zeros((2, 2))
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                    if self.use_tef:
         | 
| 140 | 
            +
                        # note the tef features here are normalized clip indices (1.5 secs), instead of the original time (1 sec)
         | 
| 141 | 
            +
                        ctx_l = meta["duration"] // self.clip_length + 1 if ctx_l == 0 else ctx_l
         | 
| 142 | 
            +
                        tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
         | 
| 143 | 
            +
                        tef_ed = tef_st + 1.0 / ctx_l
         | 
| 144 | 
            +
                        tef = torch.stack([tef_st, tef_ed], dim=1)  # (N_clips, 2)
         | 
| 145 | 
            +
                        model_inputs["tef_feat"] = tef
         | 
| 146 | 
            +
                    else:
         | 
| 147 | 
            +
                        model_inputs["tef_feat"] = torch.zeros((2, 2))
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                    if self.use_video and self.use_tef:
         | 
| 150 | 
            +
                        model_inputs["video_feat"] = torch.cat(
         | 
| 151 | 
            +
                            [model_inputs["video_feat"], model_inputs["tef_feat"]], dim=1)  # (N_clips, D+2)
         | 
| 152 | 
            +
                    if self.use_sub and self.use_tef:
         | 
| 153 | 
            +
                        model_inputs["sub_feat"] = torch.cat(
         | 
| 154 | 
            +
                            [model_inputs["sub_feat"], model_inputs["tef_feat"]], dim=1)  # (N_clips, D_t+2)
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                    model_inputs["st_ed_indices"] = self.get_st_ed_label(meta["ts"], max_idx=ctx_l-1)
         | 
| 157 | 
            +
                    return dict(meta=meta, model_inputs=model_inputs)
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                def get_st_ed_label(self, ts, max_idx):
         | 
| 160 | 
            +
                    """
         | 
| 161 | 
            +
                    Args:
         | 
| 162 | 
            +
                        ts: [st (float), ed (float)] in seconds, ed > st
         | 
| 163 | 
            +
                        max_idx: length of the video
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                    Returns:
         | 
| 166 | 
            +
                        [st_idx, ed_idx]: int,
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                    Given ts = [3.2, 7.6], st_idx = 2, ed_idx = 6,
         | 
| 169 | 
            +
                    clips should be indexed as [2: 6), the translated back ts should be [3:9].
         | 
| 170 | 
            +
                    # TODO which one is better, [2: 5] or [2: 6)
         | 
| 171 | 
            +
                    """
         | 
| 172 | 
            +
                    st_idx = min(math.floor(ts[0] / self.clip_length), max_idx)
         | 
| 173 | 
            +
                    ed_idx = min(math.ceil(ts[1] / self.clip_length), max_idx)
         | 
| 174 | 
            +
                    return torch.LongTensor([st_idx, ed_idx])
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                def get_query_feat_by_query_id(self, query_id):
         | 
| 177 | 
            +
                    query_feat = self.desc_bert_h5[str(query_id)][:self.max_desc_len]
         | 
| 178 | 
            +
                    if self.normalize_tfeat:
         | 
| 179 | 
            +
                        query_feat = l2_normalize_np_array(query_feat)
         | 
| 180 | 
            +
                    return torch.from_numpy(query_feat)
         | 
| 181 | 
            +
             | 
| 182 | 
            +
             | 
| 183 | 
            +
            class StartEndEvalDataset(Dataset):
         | 
| 184 | 
            +
                """
         | 
| 185 | 
            +
                init_data_mode: `video_query` or `video_only` or `query_only`,
         | 
| 186 | 
            +
                    it indicates which data to load when initialize the Dataset object.
         | 
| 187 | 
            +
                data_mode: `context` or `query`, it indicates which data to return for self.__get_item__()
         | 
| 188 | 
            +
                desc_bert_path_or_handler: h5py.File object or str path
         | 
| 189 | 
            +
                vid_feat_path_or_handler: h5py.File object or str path
         | 
| 190 | 
            +
                eval_proposal_bsz: the proposals for a single video will be sorted in length and batched here with
         | 
| 191 | 
            +
                    max batch size to be eval_proposal_bsz. A single video might have multiple batches of proposals.
         | 
| 192 | 
            +
                load_gt_video: load GroundTruth Video, useful when evaluating single video moment retrieval.
         | 
| 193 | 
            +
                data_ratio: percentage of query data to use.
         | 
| 194 | 
            +
                """
         | 
| 195 | 
            +
                def __init__(self, data_path=None,
         | 
| 196 | 
            +
                             desc_bert_path_or_handler=None, max_desc_len=None,  max_ctx_len=None,
         | 
| 197 | 
            +
                             sub_bert_path_or_handler=None, vid_feat_path_or_handler=None,
         | 
| 198 | 
            +
                             corpus_path=None, clip_length=None,
         | 
| 199 | 
            +
                             ctx_mode="video", data_mode="context",
         | 
| 200 | 
            +
                             h5driver=None, data_ratio=1.0, normalize_vfeat=True, normalize_tfeat=True):
         | 
| 201 | 
            +
                    self.ctx_mode = ctx_mode
         | 
| 202 | 
            +
                    self.load_gt_video = False
         | 
| 203 | 
            +
                    self.data_ratio = data_ratio  # only affect query data
         | 
| 204 | 
            +
                    self.normalize_vfeat = normalize_vfeat
         | 
| 205 | 
            +
                    self.normalize_tfeat = normalize_tfeat
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                    self.data_mode = None
         | 
| 208 | 
            +
                    self.set_data_mode(data_mode)
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                    self.max_desc_len = max_desc_len
         | 
| 211 | 
            +
                    self.max_ctx_len = max_ctx_len
         | 
| 212 | 
            +
                    self.data_path = data_path
         | 
| 213 | 
            +
                    
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                    self.annotations = load_json(data_path)
         | 
| 216 | 
            +
                    self.ground_truth = self.get_relevant_moment_gt()
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                    
         | 
| 219 | 
            +
                    if isinstance(desc_bert_path_or_handler, h5py.File):
         | 
| 220 | 
            +
                        self.desc_bert_h5 = desc_bert_path_or_handler
         | 
| 221 | 
            +
                    else:
         | 
| 222 | 
            +
                        self.desc_bert_h5 = h5py.File(desc_bert_path_or_handler, "r", driver=h5driver)
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                    video_data = load_json(corpus_path)
         | 
| 225 | 
            +
                    self.video_data = [{"vid_name": k, "duration": v} for k, v in video_data.items()]
         | 
| 226 | 
            +
                    self.video2idx = {k: v for k, v in video_data.items()}
         | 
| 227 | 
            +
                    self.clip_length = clip_length
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                    self.use_video = "video" in self.ctx_mode
         | 
| 230 | 
            +
                    self.use_sub = "sub" in self.ctx_mode
         | 
| 231 | 
            +
                    self.use_tef = "tef" in self.ctx_mode
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                    if self.use_video:
         | 
| 234 | 
            +
                        if isinstance(vid_feat_path_or_handler, h5py.File):
         | 
| 235 | 
            +
                            self.vid_feat_h5 = vid_feat_path_or_handler
         | 
| 236 | 
            +
                        else:  # str path
         | 
| 237 | 
            +
                            self.vid_feat_h5 = h5py.File(vid_feat_path_or_handler, "r", driver=h5driver)
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                    if self.use_sub:
         | 
| 240 | 
            +
                        if isinstance(sub_bert_path_or_handler, h5py.File):
         | 
| 241 | 
            +
                            self.sub_bert_h5 = sub_bert_path_or_handler
         | 
| 242 | 
            +
                        else:  # str path
         | 
| 243 | 
            +
                            self.sub_bert_h5 = h5py.File(sub_bert_path_or_handler, "r", driver=h5driver)
         | 
| 244 | 
            +
             | 
| 245 | 
            +
             | 
| 246 | 
            +
                def get_relevant_moment_gt(self):
         | 
| 247 | 
            +
                    gt_all = {}
         | 
| 248 | 
            +
                    for data in self.annotations:
         | 
| 249 | 
            +
                        gt_all[data["query_id"]] = data["relevant_moment"]
         | 
| 250 | 
            +
                    return gt_all
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                def set_data_mode(self, data_mode):
         | 
| 253 | 
            +
                    """context or query"""
         | 
| 254 | 
            +
                    assert data_mode in ["context", "query"]
         | 
| 255 | 
            +
                    self.data_mode = data_mode
         | 
| 256 | 
            +
             | 
| 257 | 
            +
                # def load_gt_vid_name_for_query(self, load_gt_video):
         | 
| 258 | 
            +
                #     """load_gt_video: bool, affect the returned value of self._get_item_query"""
         | 
| 259 | 
            +
                #     if load_gt_video:
         | 
| 260 | 
            +
                #         assert "vid_name" in self.query_data[0]
         | 
| 261 | 
            +
                #     self.load_gt_video = load_gt_video
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                def __len__(self):
         | 
| 264 | 
            +
                    if self.data_mode == "context":
         | 
| 265 | 
            +
                        return len(self.video_data)
         | 
| 266 | 
            +
                    else:
         | 
| 267 | 
            +
                        return len(self.annotations)
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                def __getitem__(self, index):
         | 
| 270 | 
            +
                    if self.data_mode == "context":
         | 
| 271 | 
            +
                        return self._get_item_context(index)
         | 
| 272 | 
            +
                    else:
         | 
| 273 | 
            +
                        return self._get_item_query(index)
         | 
| 274 | 
            +
             | 
| 275 | 
            +
                def get_query_feat_by_query_id(self, query_id):
         | 
| 276 | 
            +
                    query_feat = self.desc_bert_h5[str(query_id)][:self.max_desc_len]
         | 
| 277 | 
            +
                    if self.normalize_tfeat:
         | 
| 278 | 
            +
                        query_feat = l2_normalize_np_array(query_feat)
         | 
| 279 | 
            +
                    return torch.from_numpy(query_feat)
         | 
| 280 | 
            +
             | 
| 281 | 
            +
                def _get_item_query(self, index):
         | 
| 282 | 
            +
                    """Need to batch"""
         | 
| 283 | 
            +
                    raw_data = self.annotations[index]
         | 
| 284 | 
            +
             | 
| 285 | 
            +
                    meta = dict(
         | 
| 286 | 
            +
                        query_id=raw_data["query_id"],
         | 
| 287 | 
            +
                        desc=raw_data["query"],
         | 
| 288 | 
            +
                        vid_name=raw_data["video_name"] if self.load_gt_video else None
         | 
| 289 | 
            +
                    )
         | 
| 290 | 
            +
             | 
| 291 | 
            +
                    model_inputs = dict()
         | 
| 292 | 
            +
                    model_inputs["query_feat"] = self.get_query_feat_by_query_id(meta["query_id"])
         | 
| 293 | 
            +
                    return dict(meta=meta, model_inputs=model_inputs)
         | 
| 294 | 
            +
             | 
| 295 | 
            +
                def get_st_ed_label(self, ts, max_idx):
         | 
| 296 | 
            +
                    """
         | 
| 297 | 
            +
                    Args:
         | 
| 298 | 
            +
                        ts: [st (float), ed (float)] in seconds, ed > st
         | 
| 299 | 
            +
                        max_idx: length of the video
         | 
| 300 | 
            +
             | 
| 301 | 
            +
                    Returns:
         | 
| 302 | 
            +
                        [st_idx, ed_idx]: int,
         | 
| 303 | 
            +
             | 
| 304 | 
            +
                    Given ts = [3.2, 7.6], st_idx = 2, ed_idx = 6,
         | 
| 305 | 
            +
                    clips should be indexed as [2: 6), the translated back ts should be [3:9].
         | 
| 306 | 
            +
                    Given ts = [5, 9], st_idx = 3, ed_idx = 6,
         | 
| 307 | 
            +
                    clips should be indexed as [3: 6), the translated back ts should be [4.5:9].
         | 
| 308 | 
            +
                    # TODO which one is better, [2: 5] or [2: 6)
         | 
| 309 | 
            +
                    """
         | 
| 310 | 
            +
                    # TODO ed_idx -= 1, should also modify relevant code in inference.py
         | 
| 311 | 
            +
                    st_idx = min(math.floor(ts[0] / self.clip_length), max_idx)
         | 
| 312 | 
            +
                    ed_idx = min(math.ceil(ts[1] / self.clip_length) - 1, max_idx)  # st_idx could be the same as ed_idx
         | 
| 313 | 
            +
                    return torch.LongTensor([st_idx, ed_idx])
         | 
| 314 | 
            +
             | 
| 315 | 
            +
                def _get_item_context(self, index):
         | 
| 316 | 
            +
                    """No need to batch, since it has already been batched here"""
         | 
| 317 | 
            +
                    raw_data = self.video_data[index]
         | 
| 318 | 
            +
             | 
| 319 | 
            +
                    # initialize with basic data
         | 
| 320 | 
            +
                    meta = dict(
         | 
| 321 | 
            +
                        vid_name=raw_data["vid_name"],
         | 
| 322 | 
            +
                        duration=raw_data["duration"],
         | 
| 323 | 
            +
                    )
         | 
| 324 | 
            +
             | 
| 325 | 
            +
                    model_inputs = dict()
         | 
| 326 | 
            +
                    ctx_l = 0
         | 
| 327 | 
            +
             | 
| 328 | 
            +
                    if self.use_video:
         | 
| 329 | 
            +
                        video_feat = self.vid_feat_h5[meta["vid_name"]][:self.max_ctx_len]  # (N_clip, D)
         | 
| 330 | 
            +
                        if self.normalize_vfeat:
         | 
| 331 | 
            +
                            video_feat = l2_normalize_np_array(video_feat)
         | 
| 332 | 
            +
                        model_inputs["video_feat"] = torch.from_numpy(video_feat)
         | 
| 333 | 
            +
                        ctx_l = len(video_feat)
         | 
| 334 | 
            +
                    else:
         | 
| 335 | 
            +
                        model_inputs["video_feat"] = torch.zeros((2, 2))
         | 
| 336 | 
            +
             | 
| 337 | 
            +
                    if self.use_sub:  # no need for ctx feature, as the features are already contextulized
         | 
| 338 | 
            +
                        sub_feat = self.sub_bert_h5[meta["vid_name"]][:self.max_ctx_len]  # (N_clips, D_t)
         | 
| 339 | 
            +
                        if self.normalize_tfeat:
         | 
| 340 | 
            +
                            sub_feat = l2_normalize_np_array(sub_feat)
         | 
| 341 | 
            +
                        model_inputs["sub_feat"] = torch.from_numpy(sub_feat)
         | 
| 342 | 
            +
                        ctx_l = len(sub_feat)
         | 
| 343 | 
            +
                    else:
         | 
| 344 | 
            +
                        model_inputs["sub_feat"] = torch.zeros((2, 2))
         | 
| 345 | 
            +
             | 
| 346 | 
            +
                    if self.use_tef:
         | 
| 347 | 
            +
                        ctx_l = meta["duration"] // self.clip_length + 1 if ctx_l == 0 else ctx_l
         | 
| 348 | 
            +
                        tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
         | 
| 349 | 
            +
                        tef_ed = tef_st + 1.0 / ctx_l
         | 
| 350 | 
            +
                        tef = torch.stack([tef_st, tef_ed], dim=1)  # (N_clips, 2)
         | 
| 351 | 
            +
                        model_inputs["tef_feat"] = tef
         | 
| 352 | 
            +
                    else:
         | 
| 353 | 
            +
                        model_inputs["tef_feat"] = torch.zeros((2, 2))
         | 
| 354 | 
            +
             | 
| 355 | 
            +
                    if self.use_video and self.use_tef:
         | 
| 356 | 
            +
                        model_inputs["video_feat"] = torch.cat(
         | 
| 357 | 
            +
                            [model_inputs["video_feat"], model_inputs["tef_feat"]], dim=1)  # (N_clips, D+2)
         | 
| 358 | 
            +
                    if self.use_sub and self.use_tef:
         | 
| 359 | 
            +
                        model_inputs["sub_feat"] = torch.cat(
         | 
| 360 | 
            +
                            [model_inputs["sub_feat"], model_inputs["tef_feat"]], dim=1)  # (N_clips, D_t+2)
         | 
| 361 | 
            +
                    return dict(meta=meta, model_inputs=model_inputs)
         | 
| 362 | 
            +
             | 
| 363 | 
            +
             | 
| 364 | 
            +
            def start_end_collate(batch):
         | 
| 365 | 
            +
                batch_meta = [e["meta"] for e in batch]  # seems no need to collate ?
         | 
| 366 | 
            +
             | 
| 367 | 
            +
                model_inputs_keys = batch[0]["model_inputs"].keys()
         | 
| 368 | 
            +
                batched_data = dict()
         | 
| 369 | 
            +
                for k in model_inputs_keys:
         | 
| 370 | 
            +
                    if "feat" in k:
         | 
| 371 | 
            +
                        batched_data[k] = pad_sequences_1d(
         | 
| 372 | 
            +
                            [e["model_inputs"][k] for e in batch], dtype=torch.float32, fixed_length=None)
         | 
| 373 | 
            +
             | 
| 374 | 
            +
                if "st_ed_indices" in model_inputs_keys:
         | 
| 375 | 
            +
                    batched_data["st_ed_indices"] = torch.stack(
         | 
| 376 | 
            +
                        [e["model_inputs"]["st_ed_indices"] for e in batch], dim=0)
         | 
| 377 | 
            +
                return batch_meta, batched_data
         | 
| 378 | 
            +
             | 
| 379 | 
            +
             | 
| 380 | 
            +
            def prepare_batch_inputs(batched_model_inputs, device, non_blocking=False):
         | 
| 381 | 
            +
                model_inputs = {}
         | 
| 382 | 
            +
                for k, v in batched_model_inputs.items():
         | 
| 383 | 
            +
                    if "feat" in k:
         | 
| 384 | 
            +
                        model_inputs[k] = v[0].to(device, non_blocking=non_blocking)
         | 
| 385 | 
            +
                        model_inputs[k.replace("feat", "mask")] = v[1].to(device, non_blocking=non_blocking)
         | 
| 386 | 
            +
                    else:
         | 
| 387 | 
            +
                        model_inputs[k] = v.to(device, non_blocking=non_blocking)
         | 
| 388 | 
            +
                return model_inputs
         | 
| 389 | 
            +
             | 
| 390 | 
            +
             | 
| 391 | 
            +
            if __name__ == '__main__':
         | 
| 392 | 
            +
                from baselines.crossmodal_moment_localization.config import BaseOptions
         | 
| 393 | 
            +
                options = BaseOptions().parse()
         |