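"""Numerical parity check between two implementations of NVILA-Lite-2B.

Loads the HF `trust_remote_code` port (`Efficient-Large-Model/nvila_lite_2b_dev`)
and the reference `llava` build (`Efficient-Large-Model/NVILA-Lite-2B`) on
separate GPUs, feeds both the same image/text samples from UniMM-Chat, and
reports how closely their logits and gradients agree.
"""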
import json
import random

import torch
import torch.nn.functional as F
import tqdm
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import AutoConfig, AutoModel

import llava
from llava.mm_utils import process_image
from nvila_lite_2b_dev.tokenizer_utils import tokenize_conversation

DEFAULT_IMAGE_TOKEN = "<image>"


class EasyDataset(Dataset):
    """Pairs each dataset row's image tiles with tokenized conversation ids."""

    def __init__(self, dataset, config, tokenizer, device="cuda", dtype=torch.float16):
        super().__init__()
        self.dataset = dataset
        self.config = config
        self.device = device
        self.dtype = dtype
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, index):
        image = self.dataset[index]["image"]
        # Keep only the first user/assistant turn of the conversation.
        conversation = json.loads(self.dataset[index]["conversation"])[:2]
        # Dynamic-resolution preprocessing can split one image into several
        # tiles; the prompt then needs one <image> token per tile.
        images = process_image(image, self.config, None, enable_dynamic_res=True)
        conversation[0]["value"] = conversation[0]["value"].replace(
            DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0]
        )
        input_ids = tokenize_conversation(conversation, self.tokenizer).unsqueeze(0)
        return list(images), input_ids
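

# Illustrative example (tile counts are not guaranteed; they depend on the
# image size and the dynamic-res settings): a single sample might produce,
# say, 5 tiles, in which case __getitem__ returns ([5 image tensors],
# input_ids of shape (1, seq_len)) with 5 "<image>" tokens in the prompt.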


def main():
    model_path = "Efficient-Large-Model/nvila_lite_2b_dev"
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

    # Keep the two implementations on separate GPUs so their graphs and
    # gradients never interact.
    device1 = torch.device("cuda:0")
    device2 = torch.device("cuda:1")
    model_hf = AutoModel.from_config(config, trust_remote_code=True, device="cuda:0").to(device1)
    model_vila = llava.load("Efficient-Large-Model/NVILA-Lite-2B", device="cuda:0").to(device2)

    # Sample a handful of parameters for the gradient comparison; this assumes
    # both implementations expose identical parameter names.
    parameter_names = [name for name, _ in model_hf.named_parameters()]
    parameter_names_select = random.sample(parameter_names, 10)
    grad_diff = {name: {"Grad_L1": [], "Grad_L2": []} for name in parameter_names_select}

    # EasyDataset reads the image processor off the config during preprocessing.
    config = model_hf.config
    config.image_processor = model_hf.vision_tower.image_processor
    dataset = load_dataset("Yirany/UniMM-Chat", split="train")
    image_text_dataset = EasyDataset(dataset, config, model_hf.tokenizer)
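
    # For up to ~100 samples: run the same batch through both implementations,
    # record how far apart their logits are (L1 / L2 / cosine), then backprop a
    # shared random-label loss and record gradient distances on the sampled
    # parameters.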
    results = {"L1_diff": [], "L2_diff": [], "Cosine_similarity": []}
    for images, input_ids in tqdm.tqdm(image_text_dataset):
        media1 = {"image": [image.to(device1).half() for image in images]}
        media2 = {"image": [image.to(device2).half() for image in images]}

        # Random labels are enough here: any loss works for probing gradients,
        # but both models must see the *same* labels for the gradient
        # comparison to be meaningful.
        input_ids1 = input_ids.to(device1)
        labels1 = torch.randint(0, len(model_hf.tokenizer), input_ids1.shape, dtype=input_ids.dtype).to(device1)
        output1 = model_hf(input_ids=input_ids1, media=media1, labels=labels1)

        input_ids2 = input_ids.to(device2)
        labels2 = labels1.to(device2)
        output2 = model_vila(input_ids=input_ids2, media=media2, labels=labels2)

        # Logit-level agreement.
        logits1 = output1.logits
        logits2 = output2.logits.to(device1)
        results["L1_diff"].append(F.l1_loss(logits1, logits2).item())
        results["L2_diff"].append(F.mse_loss(logits1, logits2).item())
        results["Cosine_similarity"].append(F.cosine_similarity(logits1, logits2, dim=-1).mean().item())
        # Gradient-level agreement on the sampled parameters. Each loss has
        # its own graph and is backpropagated once.
        loss1 = output1.loss
        loss2 = output2.loss
        loss1.backward()
        loss2.backward()
        params_hf = dict(model_hf.named_parameters())
        params_vila = dict(model_vila.named_parameters())
        for name in parameter_names_select:
            grad1 = params_hf[name].grad
            grad2 = params_vila[name].grad.to(device1)
            grad_diff[name]["Grad_L1"].append(F.l1_loss(grad1, grad2).item())
            grad_diff[name]["Grad_L2"].append(F.mse_loss(grad1, grad2).item())

        # Release activations and gradients before the next sample.
        del output1, output2, logits1, logits2, input_ids, input_ids1, input_ids2
        del media1, media2, labels1, labels2, loss1, loss2
        torch.cuda.empty_cache()
        model_hf.zero_grad()
        model_vila.zero_grad()

        if len(results["L1_diff"]) > 100:
            break
    # Average each metric over the collected samples, then report.
    for name in parameter_names_select:
        grad_diff[name] = {key: sum(values) / len(values) for key, values in grad_diff[name].items()}
    final_results = {key: sum(values) / len(values) for key, values in results.items()}

    for key, value in final_results.items():
        print(f"{key}: {value:.6f}")
    for name in parameter_names_select:
        for key, value in grad_diff[name].items():
            print(f"{name} {key}: {value:.6f}")

if __name__ == "__main__":
    main()