| import transformers |
| from datasets import REFAVS |
| from configs import args |
| from torch.utils.data import DataLoader |
| from functools import partial |
| from models.llava import conversation as conversation_lib |
| |
| from models.avs_model import Simtoken_ForCausalLM |
|
|
| import torch |
| from transformers import AutoConfig |
| from peft import LoraConfig, get_peft_model |
| from torch import optim |
| from torch.optim import AdamW |
| from transformers import get_cosine_schedule_with_warmup |
| from tqdm import tqdm |
|
|
| from utils import utility |
| import random |
| import numpy as np |
| import re |
| import time |
| import os |
|
|
|
|
| import warnings |
| warnings.filterwarnings("ignore") |
|
|
| from transformers import logging |
| logging.set_verbosity_error() |
|
|
|
|
| IGNORE_INDEX = -100 |
| IMAGE_TOKEN_INDEX = -200 |
| DEFAULT_IMAGE_TOKEN = "<image>" |
| DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" |
| DEFAULT_IM_START_TOKEN = "<im_start>" |
| DEFAULT_IM_END_TOKEN = "<im_end>" |
| DEFAULT_VIDEO_TOKEN = "<video>" |
|
|
| AUDIO_TOKEN_INDEX = -300 |
| DEFAULT_AUDIO_TOKEN = "<audio>" |
|
|
| def set_seed(seed: int = 42): |
| random.seed(seed) |
| np.random.seed(seed) |
| torch.manual_seed(seed) |
| torch.cuda.manual_seed(seed) |
| torch.cuda.manual_seed_all(seed) |
|
|
| os.environ["PYTHONHASHSEED"] = str(seed) |
| os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" |
|
|
| torch.backends.cudnn.deterministic = True |
| torch.backends.cudnn.benchmark = False |
|
|
|
|
| def seed_worker(worker_id): |
|
|
| worker_seed = torch.initial_seed() % 2**32 |
| np.random.seed(worker_seed) |
| random.seed(worker_seed) |
|
|
|
|
| def dict_to_cuda(input_dict): |
| for k, v in input_dict.items(): |
| if isinstance(input_dict[k], torch.Tensor): |
| input_dict[k] = v.cuda(non_blocking=True) |
| elif ( |
| isinstance(input_dict[k], list) |
| and len(input_dict[k]) > 0 |
| and isinstance(input_dict[k][0], torch.Tensor) |
| ): |
| input_dict[k] = [ele.cuda(non_blocking=True) for ele in v] |
| return input_dict |
|
|
| def tokenizer_image_audio_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, audio_token_index=AUDIO_TOKEN_INDEX, num_frames=10, return_tensors=None): |
|
|
| prompt_chunks = re.split(r'(<image>|<audio>|<video>)', prompt) |
|
|
|
|
| prompt_chunks = [chunk for chunk in prompt_chunks if chunk] |
|
|
| text_chunks = [] |
| token_types = [] |
| for chunk in prompt_chunks: |
| if chunk == "<image>": |
| token_types.append("image") |
| elif chunk == "<audio>": |
| token_types.append("audio") |
| elif chunk == "<video>": |
| token_types.append("video") |
| else: |
| text_chunks.append(chunk) |
|
|
| tokenized_chunks = [tokenizer(chunk).input_ids for chunk in text_chunks] |
|
|
| def insert_separators(text_chunks, tokenized_chunks, token_types, image_token_index, audio_token_index, num_frames): |
| input_ids = [] |
| offset = 0 |
| if ( |
| len(tokenized_chunks) > 0 |
| and len(tokenized_chunks[0]) > 0 |
| and tokenized_chunks[0][0] == tokenizer.bos_token_id |
| ): |
| offset = 1 |
| input_ids.append(tokenized_chunks[0][0]) |
|
|
| min_length = min(len(text_chunks), len(token_types)) |
| for i in range(min_length): |
|
|
| input_ids.extend(tokenized_chunks[i][offset:]) |
|
|
| if token_types[i] == "image": |
| input_ids.append(image_token_index) |
| elif token_types[i] == "audio": |
| input_ids.append(audio_token_index) |
| elif token_types[i] == "video": |
| input_ids.extend([image_token_index] * num_frames) |
|
|
|
|
| if len(text_chunks) > min_length: |
| input_ids.extend(tokenized_chunks[min_length][offset:]) |
|
|
| return input_ids |
|
|
| input_ids = insert_separators(text_chunks, tokenized_chunks, token_types, image_token_index, audio_token_index, num_frames) |
|
|
| if return_tensors is not None: |
| if return_tensors == "pt": |
| return torch.tensor(input_ids, dtype=torch.long) |
| raise ValueError(f"Unsupported tensor type: {return_tensors}") |
| return input_ids |
|
|
| def collate_fn(batch, tokenizer=None): |
| vids = [] |
| images = [] |
| image_clips = [] |
| masks = [] |
| conversations = [] |
| audio_feats = [] |
| image_feats = [] |
| resizes = [] |
| orgsizes = [] |
| refs = [] |
| refs_num = [] |
| fids = [] |
|
|
|
|
| for data in batch: |
| vids.append(data['vid']) |
| images.append(data['image']) |
| image_clips.append(data['img_clip']) |
| masks.append(data['mask']) |
| conversations.append(data['conversation']) |
| audio_feats.append(data['feat_aud']) |
| resizes.append(data['resize']) |
| orgsizes.append(data['orgsize']) |
| image_feats.append(data['feat_sam']) |
| refs_num.append(len(data['ref'])) |
| fids.append(data['fids']) |
|
|
| refs.append(data['ref'][0]) |
|
|
|
|
| |
| input_ids = [tokenizer_image_audio_token(conv, tokenizer, return_tensors="pt") for conv in conversations] |
| input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id) |
| attention_masks = input_ids.ne(tokenizer.pad_token_id) |
|
|
| ref_ids = [tokenizer_image_audio_token(ref, tokenizer, return_tensors="pt") for ref in refs] |
|
|
| conv = conversation_lib.default_conversation.copy() |
| labels = input_ids.clone() |
|
|
| |
| sep = 'Sure, it is [SEG]' |
|
|
| for conversation, target in zip(conversations, labels): |
|
|
| parts = conversation.split(sep) |
| |
|
|
| cur_len = 1 |
| target[:cur_len] = IGNORE_INDEX |
|
|
| sep_len = len(tokenizer_image_audio_token(sep, tokenizer)) - 1 |
|
|
|
|
| for i in range(len(parts)-1): |
| part_len = len(tokenizer_image_audio_token(parts[i], tokenizer)) - 2 |
| target[cur_len: cur_len + part_len] = IGNORE_INDEX |
| cur_len += part_len + sep_len |
|
|
| target[cur_len:] = IGNORE_INDEX |
|
|
|
|
| return {"vids": vids, |
| "images": images, |
| "images_clip": image_clips, |
| "masks": masks, |
| "convs": conversations, |
| "input_ids": input_ids, |
| "attention_masks": attention_masks, |
| "labels": labels, |
| "audio_feats": audio_feats, |
| "resizes": resizes, |
| "orgsizes": orgsizes, |
| "image_feats": image_feats, |
| "ref_ids": ref_ids, |
| "refs_num": refs_num, |
| "fids": fids |
| } |
|
|
|
|
| import torch.multiprocessing as mp |
| if __name__ == "__main__": |
| mp.set_start_method("spawn") |
| set_seed(42) |
| tokenizer = transformers.AutoTokenizer.from_pretrained( |
| args.mllm, |
| cache_dir=None, |
| model_max_length=2048, |
| padding_side="right", |
| use_fast=False, |
| ) |
|
|
| tokenizer.pad_token = tokenizer.unk_token |
| num_added_tokens = tokenizer.add_tokens("[SEG]") |
| seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0] |
| print("seg_token_idx: ", seg_token_idx) |
|
|
| train_dataset = REFAVS('train', args, tokenizer, input_type='refer') |
| val_dataset_s_refer = REFAVS('test_s', args, tokenizer, input_type='refer') |
| val_dataset_u_refer = REFAVS('test_u', args, tokenizer, input_type='refer') |
| val_dataset_n_refer = REFAVS('test_n', args, tokenizer, input_type='refer') |
|
|
|
|
| g = torch.Generator() |
| g.manual_seed(42) |
|
|
| train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8, worker_init_fn=seed_worker,collate_fn=partial(collate_fn, tokenizer=tokenizer), generator=g) |
|
|
| val_dataloader_s_refer = DataLoader(val_dataset_s_refer, batch_size=4, shuffle=False, num_workers=0, collate_fn=partial(collate_fn, tokenizer=tokenizer)) |
| val_dataloader_u_refer = DataLoader(val_dataset_u_refer, batch_size=4, shuffle=False, num_workers=0, collate_fn=partial(collate_fn, tokenizer=tokenizer)) |
| val_dataloader_n_refer = DataLoader(val_dataset_n_refer, batch_size=4, shuffle=False, num_workers=0, collate_fn=partial(collate_fn, tokenizer=tokenizer)) |
|
|
|
|
| model_args = { |
| "train_mask_decoder": True, |
| "out_dim": 256, |
| "ce_loss_weight": 1.0, |
| "dice_loss_weight": 0.5, |
| "bce_loss_weight": 2.0, |
| "seg_token_idx": seg_token_idx, |
| "vision_pretrained": args.vision_pretrained, |
| "vision_tower": args.vision_tower, |
| "use_im_start_end": False, |
| "compress": args.compress, |
| "start": args.start, |
| } |
|
|
| model = Simtoken_ForCausalLM.from_pretrained(args.mllm, torch_dtype=torch.float32, low_cpu_mem_usage=True, **model_args) |
| print("\nmodel loaded") |
|
|
| model.config.eos_token_id = tokenizer.eos_token_id |
| model.config.bos_token_id = tokenizer.bos_token_id |
| model.config.pad_token_id = tokenizer.pad_token_id |
|
|
| model.enable_input_require_grads() |
| model.gradient_checkpointing_enable() |
|
|
| model.get_model().initialize_vision_modules(model.get_model().config) |
| vision_tower = model.get_model().get_vision_tower() |
| vision_tower.to(dtype=torch.float32, device="cuda") |
|
|
| model_args_from_pt = AutoConfig.from_pretrained(args.mllm) |
| model_args_from_pt.use_cluster = True |
| model_args_from_pt.freeze = False |
| model_args_from_pt.mm_tune = True |
| model_args_from_pt.spatial_cluster_rate0 = 64 |
| model_args_from_pt.spatial_cluster_rate1 = 32 |
| model_args_from_pt.spatial_cluster_rate2 = 16 |
| model_args_from_pt.temporal_cluster_rate = 0.0625 |
| model_args_from_pt.use_cluster = True |
| model_args_from_pt.vision_tune = False |
| model.get_model().initialize_cluster_modules(model_args_from_pt) |
|
|
| model.get_model().initialize_lisa_modules(model.get_model().config) |
|
|
| for p in vision_tower.parameters(): |
| p.requires_grad = False |
| for p in model.get_model().mm_projector.parameters(): |
| p.requires_grad = False |
|
|
| lora_r = 8 |
| target_modules = "q_proj,v_proj" |
| if lora_r > 0: |
|
|
| def find_linear_layers(model, lora_target_modules): |
| cls = torch.nn.Linear |
| lora_module_names = set() |
|
|
| for name, module in model.named_modules(): |
| if ( |
|
|
| isinstance(module, cls) |
|
|
| and all( |
| [ |
| x not in name |
| for x in [ |
| "visual_model", |
| "vision_tower", |
| "mm_projector", |
| "text_hidden_fcs", |
| "audio_feature_layer", |
| ] |
| ] |
| ) |
|
|
| and any([x in name for x in lora_target_modules]) |
| ): |
|
|
| lora_module_names.add(name) |
| return sorted(list(lora_module_names)) |
|
|
|
|
| lora_alpha = 16 |
| lora_dropout = 0.05 |
|
|
| lora_target_modules = find_linear_layers( |
| model, target_modules.split(",") |
| ) |
| lora_config = LoraConfig( |
| r=lora_r, |
| lora_alpha=lora_alpha, |
| target_modules=lora_target_modules, |
| lora_dropout=lora_dropout, |
| bias="none", |
| task_type="CAUSAL_LM", |
| ) |
|
|
| model = get_peft_model(model, lora_config) |
| print("\nLora deployed") |
|
|
| model.print_trainable_parameters() |
|
|
| model = model.to("cuda") |
| model.resize_token_embeddings(len(tokenizer)) |
|
|
|
|
| for name, param in model.audio_feature_layer.named_parameters(): |
| param.requires_grad = True |
| |
| |
| |
|
|
|
|
| for n, p in model.named_parameters(): |
| if any( |
| [ |
| x in n |
| for x in ["lm_head", "embed_tokens", "mask_decoder", "text_hidden_fcs"] |
| ] |
| ): |
| p.requires_grad = True |
|
|
|
|
| print("will save train model") |
|
|
| def valuate(model, dataloader, args, name): |
| model.eval() |
|
|
| total_iou = 0 |
| total_fscore = 0 |
| count = 0 |
|
|
| for batch in tqdm(dataloader, desc=f"Evaluating on {name}"): |
| input_dict = dict_to_cuda(batch) |
| with torch.no_grad(): |
| output_dict = model.forward(images=input_dict["images"], |
| images_clip=input_dict["images_clip"], |
| audio_features=input_dict["audio_feats"], |
| image_features=input_dict["image_feats"], |
| input_ids=input_dict["input_ids"], |
| labels=input_dict["labels"], |
| attention_masks=input_dict["attention_masks"], |
| masks_list=input_dict["masks"], |
| resize_list=input_dict["resizes"], |
| orgsize_list=input_dict["orgsizes"], |
| conversation_list=input_dict["convs"], |
| refs_num=input_dict["refs_num"], |
| fids=input_dict["fids"], |
| vids=input_dict["vids"], |
| contrast=args.ct_weight, |
| ref_ids=input_dict["ref_ids"], |
| inference=True) |
| pred_masks = output_dict["pred_masks"] |
| gt_masks = output_dict["gt_masks"] |
| for i in range(len(pred_masks)): |
| num_seg = pred_masks[i].shape[0] |
| T = pred_masks[i].shape[1] |
| iou = utility.mask_iou(pred_masks[i], gt_masks[i]) |
| fscore = utility.Eval_Fmeasure(pred_masks[i], gt_masks[i], None) |
|
|
| total_iou += iou * num_seg * T |
| total_fscore += fscore * num_seg * T |
| count += num_seg * T |
|
|
| print(f"\n valuate on {name}: miou: {total_iou/count} fscore: {total_fscore/count}") |
|
|
| with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f: |
| f.write(f"valuate on {name}: miou {total_iou/count} true fscore {total_fscore/count} \n") |
|
|
|
|
| |
|
|
| model.train() |
| epochs = args.epochs |
| print("init lr:", args.lr) |
| optimizer = AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.01) |
|
|
| gradient_accumulation_steps = int(16 // args.batch_size) |
| step_per_epoch = len(train_dataloader) // gradient_accumulation_steps |
| total_steps = epochs * step_per_epoch |
| warmup_steps = int(total_steps * 0.1) |
|
|
| scheduler = get_cosine_schedule_with_warmup( |
| optimizer, |
| num_warmup_steps=warmup_steps, |
| num_training_steps=total_steps, |
| ) |
|
|
|
|
| for epoch in range(epochs): |
|
|
| model.train() |
| optimizer.zero_grad() |
| running_loss = 0.0 |
|
|
| loop = tqdm(train_dataloader, desc=f"Training Epoch {epoch + 1}/{epochs}") |
| for step, batch in enumerate(loop): |
| input_dict = dict_to_cuda(batch) |
| output_dict = model.forward(images=input_dict["images"], |
| images_clip=input_dict["images_clip"], |
| audio_features=input_dict["audio_feats"], |
| image_features=input_dict["image_feats"], |
| input_ids=input_dict["input_ids"], |
| labels=input_dict["labels"], |
| attention_masks=input_dict["attention_masks"], |
| masks_list=input_dict["masks"], |
| resize_list=input_dict["resizes"], |
| orgsize_list=input_dict["orgsizes"], |
| conversation_list=input_dict["convs"], |
| refs_num=input_dict["refs_num"], |
| fids=input_dict["fids"], |
| vids=input_dict["vids"], |
| contrast=args.ct_weight, |
| ref_ids=input_dict["ref_ids"], |
| epoch=epoch, |
| inference=False) |
|
|
| loss = output_dict["loss"] |
| loss = loss / gradient_accumulation_steps |
| loss.backward() |
| running_loss += loss.item() |
|
|
|
|
| if (step + 1) % gradient_accumulation_steps == 0: |
| optimizer.step() |
| scheduler.step() |
| optimizer.zero_grad() |
|
|
| current_lr = scheduler.get_lr()[0] |
| loop.set_postfix(lr=current_lr, loss=running_loss / ((step + 1) / gradient_accumulation_steps)) |
|
|
| print(f" Epoch {epoch + 1}, Loss:{running_loss / ((step + 1) / gradient_accumulation_steps) :.4f}, Learning Rate:{scheduler.get_last_lr()[0]:.6f}") |
|
|
|
|
| with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f: |
| f.write(f"Epoch {epoch}: running_loss {running_loss / len(train_dataloader) * gradient_accumulation_steps} Learning Rate:{scheduler.get_last_lr()[0]:.6f}\n") |
|
|
|
|
| torch.save(model.state_dict(), os.path.join(args.checkpoint_root, f"{args.name}.pth")) |
| print(f"trained model saved as {args.name}.pth") |
|
|
| |
| model.eval() |
|
|
| valuate(model, val_dataloader_s_refer, args, 'test_s_refer') |
| valuate(model, val_dataloader_u_refer, args, 'test_u_refer') |
|
|
| |
| model.eval() |
|
|
| total_metric = 0 |
| count = 0 |
|
|
| for batch in tqdm(val_dataloader_n_refer, desc=f"Evaluating on test_n_refer"): |
| input_dict = dict_to_cuda(batch) |
| with torch.no_grad(): |
| output_dict = model.forward(images=input_dict["images"], |
| images_clip=input_dict["images_clip"], |
| audio_features=input_dict["audio_feats"], |
| image_features=input_dict["image_feats"], |
| input_ids=input_dict["input_ids"], |
| labels=input_dict["labels"], |
| attention_masks=input_dict["attention_masks"], |
| masks_list=input_dict["masks"], |
| resize_list=input_dict["resizes"], |
| orgsize_list=input_dict["orgsizes"], |
| conversation_list=input_dict["convs"], |
| refs_num=input_dict["refs_num"], |
| fids=input_dict["fids"], |
| vids=input_dict["vids"], |
| contrast=args.ct_weight, |
| ref_ids=input_dict["ref_ids"], |
| inference=True) |
| pred_masks = output_dict["pred_masks"] |
| gt_masks = output_dict["gt_masks"] |
| for i in range(len(pred_masks)): |
| num_seg = pred_masks[i].shape[0] |
| T = pred_masks[i].shape[1] |
| null_metric = utility.metric_s_for_null(pred_masks[i]) |
|
|
| total_metric += null_metric * num_seg * T |
| count += num_seg * T |
|
|
|
|
| print(f"\n valuate on test_n_refer, metric: {total_metric/count}") |
|
|
| with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f: |
| f.write(f"\n valuate on test_n_refer: metric {total_metric/count} \n") |