# SimToken / train.py
import transformers
from datasets import REFAVS
from configs import args
from torch.utils.data import DataLoader
from functools import partial
from models.llava import conversation as conversation_lib
# from models.avs_model import VISAForCausalLM
from models.avs_model import Simtoken_ForCausalLM
import torch
from transformers import AutoConfig
from peft import LoraConfig, get_peft_model
from torch import optim
from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup
from tqdm import tqdm
from utils import utility
import random
import numpy as np
import re
import time
import os
import warnings
warnings.filterwarnings("ignore")
from transformers import logging
logging.set_verbosity_error()
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
DEFAULT_VIDEO_TOKEN = "<video>"
AUDIO_TOKEN_INDEX = -300
DEFAULT_AUDIO_TOKEN = "<audio>"
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
def dict_to_cuda(input_dict):
    for k, v in input_dict.items():
        if isinstance(v, torch.Tensor):
            input_dict[k] = v.cuda(non_blocking=True)
        elif (
            isinstance(v, list)
            and len(v) > 0
            and isinstance(v[0], torch.Tensor)
        ):
            input_dict[k] = [ele.cuda(non_blocking=True) for ele in v]
    return input_dict
def tokenizer_image_audio_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX,
                                audio_token_index=AUDIO_TOKEN_INDEX, num_frames=10,
                                return_tensors=None):
    # Split the prompt on multimodal placeholders, keeping the placeholders.
    # Assumes the prompt interleaves text and placeholders starting with a text chunk.
    prompt_chunks = re.split(r'(<image>|<audio>|<video>)', prompt)
    prompt_chunks = [chunk for chunk in prompt_chunks if chunk]
    text_chunks = []
    token_types = []
    for chunk in prompt_chunks:
        if chunk == "<image>":
            token_types.append("image")
        elif chunk == "<audio>":
            token_types.append("audio")
        elif chunk == "<video>":
            token_types.append("video")
        else:
            text_chunks.append(chunk)
    tokenized_chunks = [tokenizer(chunk).input_ids for chunk in text_chunks]

    def insert_separators(text_chunks, tokenized_chunks, token_types,
                          image_token_index, audio_token_index, num_frames):
        input_ids = []
        offset = 0
        # If the tokenizer prepends BOS, keep it once and strip it from every chunk.
        if (
            len(tokenized_chunks) > 0
            and len(tokenized_chunks[0]) > 0
            and tokenized_chunks[0][0] == tokenizer.bos_token_id
        ):
            offset = 1
            input_ids.append(tokenized_chunks[0][0])
        min_length = min(len(text_chunks), len(token_types))
        for i in range(min_length):
            input_ids.extend(tokenized_chunks[i][offset:])
            if token_types[i] == "image":
                input_ids.append(image_token_index)
            elif token_types[i] == "audio":
                input_ids.append(audio_token_index)
            elif token_types[i] == "video":
                input_ids.extend([image_token_index] * num_frames)
        if len(text_chunks) > min_length:
            input_ids.extend(tokenized_chunks[min_length][offset:])
        return input_ids

    input_ids = insert_separators(text_chunks, tokenized_chunks, token_types,
                                  image_token_index, audio_token_index, num_frames)
    if return_tensors is not None:
        if return_tensors == "pt":
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f"Unsupported tensor type: {return_tensors}")
    return input_ids
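# Illustrative example (not part of the original script), assuming a LLaMA-style
# tokenizer with bos_token_id = 1:
#   tokenizer_image_audio_token("Frame: <image> Sound: <audio>", tokenizer)
#   -> [1, ...'Frame: ' ids..., -200, ...' Sound: ' ids..., -300]
# "<video>" instead expands into num_frames copies of IMAGE_TOKEN_INDEX (-200),
# one per sampled frame; the negative sentinel ids are later replaced by visual
# and audio embeddings inside the model's forward pass.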
def collate_fn(batch, tokenizer=None):
    vids = []
    images = []
    image_clips = []
    masks = []
    conversations = []
    audio_feats = []
    image_feats = []
    resizes = []
    orgsizes = []
    refs = []
    refs_num = []
    fids = []
    for data in batch:
        vids.append(data['vid'])
        images.append(data['image'])
        image_clips.append(data['img_clip'])
        masks.append(data['mask'])
        conversations.append(data['conversation'])
        audio_feats.append(data['feat_aud'])
        resizes.append(data['resize'])
        orgsizes.append(data['orgsize'])
        image_feats.append(data['feat_sam'])
        refs_num.append(len(data['ref']))
        fids.append(data['fids'])
        refs.append(data['ref'][0])
    # input_ids = [tokenizer_image_token(conv, tokenizer, return_tensors="pt") for conv in conversations]
    input_ids = [tokenizer_image_audio_token(conv, tokenizer, return_tensors="pt") for conv in conversations]
    input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    attention_masks = input_ids.ne(tokenizer.pad_token_id)
    ref_ids = [tokenizer_image_audio_token(ref, tokenizer, return_tensors="pt") for ref in refs]
    conv = conversation_lib.default_conversation.copy()
    labels = input_ids.clone()
    # sep = conv.sep + conv.roles[1] + ": "  # "###Assistant:"
    sep = 'Sure, it is [SEG]'
    # Mask out everything except the answer template "Sure, it is [SEG]" so the
    # language-modeling loss is only computed on those tokens.
    for conversation, target in zip(conversations, labels):
        parts = conversation.split(sep)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        sep_len = len(tokenizer_image_audio_token(sep, tokenizer)) - 1
        for i in range(len(parts) - 1):
            part_len = len(tokenizer_image_audio_token(parts[i], tokenizer)) - 2
            target[cur_len: cur_len + part_len] = IGNORE_INDEX
            cur_len += part_len + sep_len
        target[cur_len:] = IGNORE_INDEX
    return {"vids": vids,
            "images": images,                    # list[B]: [T, 3, 1024, 1024]
            "images_clip": image_clips,          # list[B]: [T, 3, 224, 224]
            "masks": masks,                      # list[B]: [num_ref, T, H, W]
            "convs": conversations,              # list[B]: str
            "input_ids": input_ids,              # [B, max_len]
            "attention_masks": attention_masks,  # [B, max_len]
            "labels": labels,                    # [B, max_len]
            "audio_feats": audio_feats,          # list[B]: [10, 128]
            "resizes": resizes,                  # list[B]
            "orgsizes": orgsizes,                # list[B]
            "image_feats": image_feats,
            "ref_ids": ref_ids,                  # list[B]: [ref_id_len]
            "refs_num": refs_num,
            "fids": fids
            }
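# Illustrative sketch (hypothetical token layout) of the masking above for one
# conversation ending in "... Sure, it is [SEG]":
#   labels: [IGNORE ... IGNORE | ids of "Sure, it is [SEG]" | IGNORE ...]
# Only the answer template stays supervised; the hidden state at the [SEG]
# position is what text_hidden_fcs projects for the mask decoder, as in
# LISA-style segmentation models.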
import torch.multiprocessing as mp

if __name__ == "__main__":
    mp.set_start_method("spawn")
    set_seed(42)
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.mllm,
        cache_dir=None,
        model_max_length=2048,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    num_added_tokens = tokenizer.add_tokens("[SEG]")
    seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]  # 32000
    print("seg_token_idx: ", seg_token_idx)
    train_dataset = REFAVS('train', args, tokenizer, input_type='refer')
    val_dataset_s_refer = REFAVS('test_s', args, tokenizer, input_type='refer')
    val_dataset_u_refer = REFAVS('test_u', args, tokenizer, input_type='refer')
    val_dataset_n_refer = REFAVS('test_n', args, tokenizer, input_type='refer')
    g = torch.Generator()
    g.manual_seed(42)
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8,
                                  worker_init_fn=seed_worker, collate_fn=partial(collate_fn, tokenizer=tokenizer),
                                  generator=g)
    val_dataloader_s_refer = DataLoader(val_dataset_s_refer, batch_size=4, shuffle=False, num_workers=0,
                                        collate_fn=partial(collate_fn, tokenizer=tokenizer))
    val_dataloader_u_refer = DataLoader(val_dataset_u_refer, batch_size=4, shuffle=False, num_workers=0,
                                        collate_fn=partial(collate_fn, tokenizer=tokenizer))
    val_dataloader_n_refer = DataLoader(val_dataset_n_refer, batch_size=4, shuffle=False, num_workers=0,
                                        collate_fn=partial(collate_fn, tokenizer=tokenizer))
    model_args = {
        "train_mask_decoder": True,
        "out_dim": 256,
        "ce_loss_weight": 1.0,
        "dice_loss_weight": 0.5,
        "bce_loss_weight": 2.0,
        "seg_token_idx": seg_token_idx,
        "vision_pretrained": args.vision_pretrained,  # e.g. sam_vit_h_xxx.pth
        "vision_tower": args.vision_tower,
        "use_im_start_end": False,
        "compress": args.compress,
        "start": args.start,
    }
    model = Simtoken_ForCausalLM.from_pretrained(args.mllm, torch_dtype=torch.float32,
                                                 low_cpu_mem_usage=True, **model_args)
    print("\nmodel loaded")
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    model.enable_input_require_grads()
    model.gradient_checkpointing_enable()
    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch.float32, device="cuda")
    model_args_from_pt = AutoConfig.from_pretrained(args.mllm)
    model_args_from_pt.use_cluster = True
    model_args_from_pt.freeze = False
    model_args_from_pt.mm_tune = True
    model_args_from_pt.spatial_cluster_rate0 = 64
    model_args_from_pt.spatial_cluster_rate1 = 32
    model_args_from_pt.spatial_cluster_rate2 = 16
    model_args_from_pt.temporal_cluster_rate = 0.0625  # 1/16
    model_args_from_pt.vision_tune = False
    model.get_model().initialize_cluster_modules(model_args_from_pt)
    model.get_model().initialize_lisa_modules(model.get_model().config)
    # Freeze the vision tower and the multimodal projector.
    for p in vision_tower.parameters():
        p.requires_grad = False
    for p in model.get_model().mm_projector.parameters():
        p.requires_grad = False
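    # Optional sanity check (illustrative, not in the original script):
    #   n_train = sum(p.numel() for p in model.parameters() if p.requires_grad)
    #   n_total = sum(p.numel() for p in model.parameters())
    #   print(f"trainable fraction: {n_train / n_total:.2%}")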
    lora_r = 8
    target_modules = "q_proj,v_proj"
    if lora_r > 0:
        # Collect the names of all Linear layers in the LLM backbone that match
        # the target projections, skipping the vision/audio/segmentation heads.
        def find_linear_layers(model, lora_target_modules):
            cls = torch.nn.Linear
            lora_module_names = set()
            for name, module in model.named_modules():
                if (
                    isinstance(module, cls)
                    and all(
                        x not in name
                        for x in [
                            "visual_model",
                            "vision_tower",
                            "mm_projector",
                            "text_hidden_fcs",
                            "audio_feature_layer",
                        ]
                    )
                    and any(x in name for x in lora_target_modules)
                ):
                    lora_module_names.add(name)
            return sorted(list(lora_module_names))
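        # Illustrative output (assuming a LLaMA-style backbone): names such as
        # "model.layers.0.self_attn.q_proj" and "model.layers.0.self_attn.v_proj",
        # one per attention projection that will receive a LoRA adapter.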
        lora_alpha = 16
        lora_dropout = 0.05
        lora_target_modules = find_linear_layers(
            model, target_modules.split(",")
        )
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=lora_target_modules,
            lora_dropout=lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        print("\nLoRA deployed")
        model.print_trainable_parameters()
    model = model.to("cuda")
    # Account for the newly added [SEG] token.
    model.resize_token_embeddings(len(tokenizer))
    # Keep the audio projection trainable alongside the LoRA weights.
    for name, param in model.audio_feature_layer.named_parameters():
        param.requires_grad = True
        # print(name, param.requires_grad)
    # for name, param in model.token_compressor.named_parameters():
    #     param.requires_grad = True
    for n, p in model.named_parameters():
        if any(
            x in n
            for x in ["lm_head", "embed_tokens", "mask_decoder", "text_hidden_fcs"]
        ):
            p.requires_grad = True
    print("trainable parameters configured")
    def evaluate(model, dataloader, args, name):
        model.eval()
        total_iou = 0
        total_fscore = 0
        count = 0
        for batch in tqdm(dataloader, desc=f"Evaluating on {name}"):
            input_dict = dict_to_cuda(batch)
            with torch.no_grad():
                output_dict = model.forward(images=input_dict["images"],
                                            images_clip=input_dict["images_clip"],
                                            audio_features=input_dict["audio_feats"],
                                            image_features=input_dict["image_feats"],
                                            input_ids=input_dict["input_ids"],
                                            labels=input_dict["labels"],
                                            attention_masks=input_dict["attention_masks"],
                                            masks_list=input_dict["masks"],
                                            resize_list=input_dict["resizes"],
                                            orgsize_list=input_dict["orgsizes"],
                                            conversation_list=input_dict["convs"],
                                            refs_num=input_dict["refs_num"],
                                            fids=input_dict["fids"],
                                            vids=input_dict["vids"],
                                            contrast=args.ct_weight,
                                            ref_ids=input_dict["ref_ids"],
                                            inference=True)
            pred_masks = output_dict["pred_masks"]  # list[B]: [num_seg, T, H, W]
            gt_masks = output_dict["gt_masks"]      # list[B]: [num_seg, T, H, W]
            # Weight each clip's scores by its number of (segment, frame) pairs.
            for i in range(len(pred_masks)):
                num_seg = pred_masks[i].shape[0]
                T = pred_masks[i].shape[1]
                iou = utility.mask_iou(pred_masks[i], gt_masks[i])
                fscore = utility.Eval_Fmeasure(pred_masks[i], gt_masks[i], None)
                total_iou += iou * num_seg * T
                total_fscore += fscore * num_seg * T
                count += num_seg * T
        print(f"\nevaluate on {name}: miou: {total_iou/count} fscore: {total_fscore/count}")
        with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f:
            f.write(f"evaluate on {name}: miou {total_iou/count} fscore {total_fscore/count}\n")
    # ---------------train------------------------------------------
    model.train()
    epochs = args.epochs
    print("init lr:", args.lr)
    optimizer = AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.01)
    # Accumulate gradients so the effective batch size is 16.
    gradient_accumulation_steps = 16 // args.batch_size
    step_per_epoch = len(train_dataloader) // gradient_accumulation_steps
    total_steps = epochs * step_per_epoch
    warmup_steps = int(total_steps * 0.1)  # 10% warmup
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
    )
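    # Worked example with hypothetical numbers: if args.batch_size == 2, gradients
    # accumulate over 16 // 2 = 8 micro-batches per optimizer step. With, say,
    # 4000 batches per epoch that gives 500 optimizer steps per epoch; over 10
    # epochs, total_steps = 5000 and the first 500 steps are warmup before the
    # cosine decay begins.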
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        running_loss = 0.0
        loop = tqdm(train_dataloader, desc=f"Training Epoch {epoch + 1}/{epochs}")
        for step, batch in enumerate(loop):
            input_dict = dict_to_cuda(batch)
            output_dict = model.forward(images=input_dict["images"],
                                        images_clip=input_dict["images_clip"],
                                        audio_features=input_dict["audio_feats"],
                                        image_features=input_dict["image_feats"],
                                        input_ids=input_dict["input_ids"],
                                        labels=input_dict["labels"],
                                        attention_masks=input_dict["attention_masks"],
                                        masks_list=input_dict["masks"],
                                        resize_list=input_dict["resizes"],
                                        orgsize_list=input_dict["orgsizes"],
                                        conversation_list=input_dict["convs"],
                                        refs_num=input_dict["refs_num"],
                                        fids=input_dict["fids"],
                                        vids=input_dict["vids"],
                                        contrast=args.ct_weight,
                                        ref_ids=input_dict["ref_ids"],
                                        epoch=epoch,
                                        inference=False)
            # Scale the loss so the accumulated gradient matches a full batch.
            loss = output_dict["loss"] / gradient_accumulation_steps
            loss.backward()
            running_loss += loss.item()
            if (step + 1) % gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
            current_lr = scheduler.get_last_lr()[0]
            loop.set_postfix(lr=current_lr, loss=running_loss / ((step + 1) / gradient_accumulation_steps))
        print(f"Epoch {epoch + 1}, Loss: {running_loss / (len(train_dataloader) / gradient_accumulation_steps):.4f}, Learning Rate: {scheduler.get_last_lr()[0]:.6f}")
        with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f:
            f.write(f"Epoch {epoch + 1}: running_loss {running_loss / len(train_dataloader) * gradient_accumulation_steps} Learning Rate: {scheduler.get_last_lr()[0]:.6f}\n")
        torch.save(model.state_dict(), os.path.join(args.checkpoint_root, f"{args.name}.pth"))
        print(f"trained model saved as {args.name}.pth")
    # ---------------test on seen & unseen ------------------------------------------
    model.eval()
    evaluate(model, val_dataloader_s_refer, args, 'test_s_refer')
    evaluate(model, val_dataloader_u_refer, args, 'test_u_refer')
    # ---------------test on null ------------------------------------------
    model.eval()
    total_metric = 0
    count = 0
    for batch in tqdm(val_dataloader_n_refer, desc="Evaluating on test_n_refer"):
        input_dict = dict_to_cuda(batch)
        with torch.no_grad():
            output_dict = model.forward(images=input_dict["images"],
                                        images_clip=input_dict["images_clip"],
                                        audio_features=input_dict["audio_feats"],
                                        image_features=input_dict["image_feats"],
                                        input_ids=input_dict["input_ids"],
                                        labels=input_dict["labels"],
                                        attention_masks=input_dict["attention_masks"],
                                        masks_list=input_dict["masks"],
                                        resize_list=input_dict["resizes"],
                                        orgsize_list=input_dict["orgsizes"],
                                        conversation_list=input_dict["convs"],
                                        refs_num=input_dict["refs_num"],
                                        fids=input_dict["fids"],
                                        vids=input_dict["vids"],
                                        contrast=args.ct_weight,
                                        ref_ids=input_dict["ref_ids"],
                                        inference=True)
        pred_masks = output_dict["pred_masks"]  # list[B]: [num_seg, T, H, W]
        gt_masks = output_dict["gt_masks"]      # unused on the null split
        # The null split has no referred object, so only predictions are scored.
        for i in range(len(pred_masks)):
            num_seg = pred_masks[i].shape[0]
            T = pred_masks[i].shape[1]
            null_metric = utility.metric_s_for_null(pred_masks[i])
            total_metric += null_metric * num_seg * T
            count += num_seg * T
    print(f"\nevaluate on test_n_refer, metric: {total_metric/count}")
    with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f:
        f.write(f"\nevaluate on test_n_refer: metric {total_metric/count}\n")