berkamphoon committed on
Commit 16eb3ef · verified · 1 Parent(s): 9aec233

Training in progress, epoch 0

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,58 @@
+ ---
+ base_model: google/medgemma-27b-it
+ library_name: transformers
+ model_name: medgemma-27b-it-dr5
+ tags:
+ - generated_from_trainer
+ - trl
+ - sft
+ licence: license
+ ---
+
+ # Model Card for medgemma-27b-it-dr5
+
+ This model is a fine-tuned version of [google/medgemma-27b-it](https://huggingface.co/google/medgemma-27b-it).
+ It has been trained using [TRL](https://github.com/huggingface/trl).
+
+ ## Quick start
+
+ ```python
+ from transformers import pipeline
+
+ question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?"
+ generator = pipeline("text-generation", model="berkamphoon/medgemma-27b-it-dr5", device="cuda")
+ output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0]
+ print(output["generated_text"])
+ ```
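Because this repository stores a LoRA adapter (see `adapter_config.json` and `adapter_model.safetensors` below) rather than merged weights, and the base model is an image-text-to-text model, the pipeline call above needs `peft` installed and may not work out of the box. The snippet below is a minimal, untested sketch of attaching the adapter to the base model explicitly; the repo and model ids come from this commit, while the exact loading calls and prompt are assumptions, not code from the repository.

```python
# Sketch only: load the MedGemma base model and attach the LoRA adapter
# stored in this repo (adapter_model.safetensors) via PEFT.
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from peft import PeftModel

base_id = "google/medgemma-27b-it"              # base model named in the card above
adapter_id = "berkamphoon/medgemma-27b-it-dr5"  # this repository

model = AutoModelForImageTextToText.from_pretrained(
    base_id, torch_dtype=torch.bfloat16, device_map="auto"
)
model = PeftModel.from_pretrained(model, adapter_id)  # loads the adapter without merging
processor = AutoProcessor.from_pretrained(base_id)

messages = [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]
text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
inputs = processor(text=text, return_tensors="pt").to(model.device)

with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=32)
print(processor.decode(out[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))
```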
+
+ ## Training procedure
+
+ [<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>](https://wandb.ai/yoon307-kaist/medgemma-27b-it-dr5-Project/runs/mbxoj7k5)
+
+ This model was trained with SFT.
+
+ ### Framework versions
+
+ - TRL: 0.19.0
+ - Transformers: 4.51.3
+ - Pytorch: 2.5.0
+ - Datasets: 3.6.0
+ - Tokenizers: 0.21.1
+
+ ## Citations
+
+ Cite TRL as:
+
+ ```bibtex
+ @misc{vonwerra2022trl,
+     title        = {{TRL: Transformer Reinforcement Learning}},
+     author       = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec},
+     year         = 2020,
+     journal      = {GitHub repository},
+     publisher    = {GitHub},
+     howpublished = {\url{https://github.com/huggingface/trl}}
+ }
+ ```
adapter_config.json ADDED
@@ -0,0 +1,47 @@
+ {
+   "alpha_pattern": {},
+   "auto_mapping": null,
+   "base_model_name_or_path": "google/medgemma-27b-it",
+   "bias": "none",
+   "corda_config": null,
+   "eva_config": null,
+   "exclude_modules": null,
+   "fan_in_fan_out": false,
+   "inference_mode": true,
+   "init_lora_weights": true,
+   "layer_replication": null,
+   "layers_pattern": null,
+   "layers_to_transform": null,
+   "loftq_config": {},
+   "lora_alpha": 16,
+   "lora_bias": false,
+   "lora_dropout": 0.05,
+   "megatron_config": null,
+   "megatron_core": "megatron.core",
+   "modules_to_save": [
+     "lm_head",
+     "embed_tokens"
+   ],
+   "peft_type": "LORA",
+   "qalora_group_size": 16,
+   "r": 16,
+   "rank_pattern": {},
+   "revision": null,
+   "target_modules": [
+     "v_proj",
+     "gate_proj",
+     "fc2",
+     "k_proj",
+     "out_proj",
+     "q_proj",
+     "o_proj",
+     "down_proj",
+     "up_proj",
+     "fc1"
+   ],
+   "task_type": "CAUSAL_LM",
+   "trainable_token_indices": null,
+   "use_dora": false,
+   "use_qalora": false,
+   "use_rslora": false
+ }
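For orientation, the adapter configuration above corresponds roughly to the `peft.LoraConfig` sketched below. This is a reconstruction for illustration only, not code taken from the commit; note that the training script copy added later in this commit appears to configure LoRA differently (r=32, `target_modules="all-linear"`, base `google/medgemma-4b-it`), so it may not correspond exactly to the run that produced this adapter.

```python
# Rough peft.LoraConfig equivalent of adapter_config.json above (reconstruction, not from the repo).
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",   # language-model attention projections
        "gate_proj", "up_proj", "down_proj",      # language-model MLP projections
        "out_proj", "fc1", "fc2",                 # likely vision-tower projections
    ],
    modules_to_save=["lm_head", "embed_tokens"],  # trained fully rather than low-rank
)
```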
adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:193d5bc56b229dcd16a327b58b3d06056ba3d4a25c915706b577c5185a762759
+ size 11766077184
added_tokens.json ADDED
@@ -0,0 +1,3 @@
+ {
+   "<image_soft_token>": 262144
+ }
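The single entry above is the Gemma 3 image placeholder token; the training script added later in this commit masks exactly this id (262144) out of the loss in its collator. A minimal sketch, assuming the `google/medgemma-27b-it` tokenizer, of how that id can be looked up:

```python
# Sketch: resolve the image placeholder token id that the training collator masks to -100.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/medgemma-27b-it")
image_id = tok.convert_tokens_to_ids("<image_soft_token>")
print(image_id)  # expected: 262144, matching added_tokens.json above
```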
runs/Jul18_11-30-32_seribizon/events.out.tfevents.1752852638.seribizon.200344.0 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2974ec26d6853b6c599a617228a9030894a8d4f99e3aec1935aa4d2dda5b0979
+ size 9306
special_tokens_map.json ADDED
@@ -0,0 +1,33 @@
+ {
+   "boi_token": "<start_of_image>",
+   "bos_token": {
+     "content": "<bos>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eoi_token": "<end_of_image>",
+   "eos_token": {
+     "content": "<eos>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "image_token": "<image_soft_token>",
+   "pad_token": {
+     "content": "<pad>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1ebf1915455f8237564395182c49e3c685cfe3533b3d50ec6d49ce65ec43c32e
+ size 33384723
tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c
+ size 4689074
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
train_medgemma_ft_copy.py ADDED
@@ -0,0 +1,432 @@
+ from __future__ import division, print_function
+
+
+ # medqa gsm8k openbookqa bioasq pubmedqa squad_v2
+
+ # === Base ===
+ import os
+ import os.path as osp
+ import random
+ import argparse
+ import logging
+ from tqdm import tqdm
+ from matplotlib import pyplot as plt
+ import pdb
+ from PIL import Image
+ import shutil
+ import os
+
+ # === DL ===
+ import numpy as np
+ import torch
+ import torch.backends.cudnn as cudnn
+ from torch.utils.data import DataLoader
+ from torch.utils.tensorboard import SummaryWriter
+
+ # === Custom ===
+ import tools.imutils as imutils
+ import tools.utils as utils
+ import tools.pyutils as pyutils
+ from tools.utils import compute_es_auc, compute_group_auc, ImprovedBalancedBatchSampler, compute_es_auc_multi
+
+ # === Evaluation ===
+ from sklearn.metrics import roc_curve, accuracy_score, roc_auc_score
+
+ # === Transformers ===
+ from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig, pipeline
+ from peft import LoraConfig, get_peft_model
+ from trl import SFTTrainer, SFTConfig
+ import wandb
+
+ # === Label Masking Function ===
+ def mask_until_after_assistant(labels: torch.Tensor, tokenizer, assistant_token_ids: list):
+     for i in range(labels.size(0)):
+         for j in range(labels.size(1) - len(assistant_token_ids) + 1):
+             if torch.equal(labels[i, j:j+len(assistant_token_ids)], torch.tensor(assistant_token_ids, device=labels.device)):
+                 labels[i, :j + len(assistant_token_ids)] = -100  # mask everything up to and including the assistant marker
+                 break
+     return labels
+
+
+ # === Collate Function ===
+ def collate_fn(examples):
+     texts = []
+     images = []
+     for example in examples:
+         image = example["image"].convert("RGB")
+         image = image.resize((512,512))
+         images.append([image])
+         texts.append(processor.apply_chat_template(
+             example["messages"], add_generation_prompt=False, tokenize=False
+         ).strip())
+
+     # Tokenize the texts and process the images
+     batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
+
+     # The labels are the input_ids, with the padding and image tokens masked in
+     # the loss computation
+     labels = batch["input_ids"].clone()
+
+     # Mask image tokens
+     image_token_id = [
+         processor.tokenizer.convert_tokens_to_ids(
+             processor.tokenizer.special_tokens_map["boi_token"]
+         )
+     ]
+     # Mask tokens that are not used in the loss computation
+     labels[labels == processor.tokenizer.pad_token_id] = -100
+     labels[labels == image_token_id] = -100
+     labels[labels == 262144] = -100
+
+     labels = mask_until_after_assistant(labels, processor.tokenizer, ASST_ID)
+     labels[:,-1] = -100
+
+     batch["labels"] = labels
+     # pdb.set_trace()
+     return batch
+
+ def format_data(sample):
+     label = 'negative' if sample[task_idx] == '0.0' else 'positive'
+     prompt = f"Please diagnose whether the {disease_name} exist or not based on the given image.\n"
+
+     # pdb.set_trace()
+     example = {}
+     example["image"] = Image.open(os.path.join(img_root_path, sample[1]))
+     example["label"] = 0 if sample[task_idx]== '0,0' else 1
+     example["messages"] = [
+         {"role": "system", "content": [{"type": "text", "text": system_message}]},
+         {"role": "user", "content": [
+             # {"type": "image", "image": os.path.join(img_root_path, sample[1])},
+             {"type": "image"},
+             {"type": "text", "text": prompt},
+         ]},
+         {"role": "assistant", "content": [{"type": "text", "text": str(label)}]}
+     ]
+
+     return example
+
+ def format_data_for_inference(sample):
+     prompt = f"Please diagnose whether the {disease_name} exist or not based on the given image.\n"
+
+     # pdb.set_trace()
+     example = {}
+     example["image"] = Image.open(os.path.join(img_root_path, sample[1]))
+     # example["label"] = 0 if sample[task_idx]== '0,0' else 1
+     example["messages"] = [
+         {"role": "system", "content": [{"type": "text", "text": system_message}]},
+         {"role": "user", "content": [
+             # {"type": "image", "image": os.path.join(img_root_path, sample[1])},
+             {"type": "image"},
+             {"type": "text", "text": prompt+"\n"},
+         ]},
+         # {"role": "assistant", "content": [{"type": "text", "text": str(label)}]}
+     ]
+     # prompt = f"Please diagnose whether the {disease_name} exist or not based on the given image."
+     # return [
+     #     {"role": "system", "content": [{"type": "text", "text": system_message}]},
+     #     {"role": "user", "content": [
+     #         {"type": "image", "image": os.path.join(img_root_path, sample[1])},
+     #         {"type": "text", "text": prompt}
+     #     ]}
+     # ]
+     return example
+
+ # === Logit Preprocessing ===
+ def slice_logits(logits, labels):
+     if isinstance(logits, (tuple, list)):
+         logits = logits[0]
+     return logits.detach().cpu()
+
+ def compute_metrics(eval_pred):
+     logits = torch.tensor(eval_pred.predictions)
+
+     token_ids = logits.argmax(dim=-1)  # (B, L): predicted token at each position
+
+     batch_logits = []
+     for b in range(logits.size(0)):
+         seq = token_ids[b]  # (L,)
+         idxs = torch.where((seq == POS_ID[0]) | (seq == NEG_ID[0]))[0]
+         if len(idxs) == 0:
+             raise ValueError(f"Neither pos_id nor neg_id found in sequence {b}")
+         t = idxs[0].item()  # first position where pos or neg appears
+         tok_id = seq[t].item()  # should be either pos_id or neg_id
+         batch_logits.append(logits[b, t, tok_id])  # scalar
+
+     batch_logits = torch.stack(batch_logits)  # shape: [B]
+     pred_texts = processor.tokenizer.batch_decode(token_ids[:,-1], skip_special_tokens=True)
+
+     # print(pred_texts)
+     # pdb.set_trace()
+     probs = torch.sigmoid(logits[:,-1, POS_ID[0]] - logits[:,-1, NEG_ID[0]]).numpy()
+
+     # probs = torch.sigmoid(batch_logits).numpy()
+     labels = torch.tensor(eval_pred.label_ids)
+     gt_ids = labels[labels != -100].view(logits.size(0), -1)[:, 0]
+     y_true = (gt_ids == POS_ID[0]).int().cpu().numpy()
+     auc_val = roc_auc_score(y_true, probs)
+     fpr, tpr, thr = roc_curve(y_true, probs)
+     best = thr[np.argmax(tpr - fpr)]
+     acc = accuracy_score(y_true, probs >= best)
+     return {"roc_auc": auc_val, "accuracy": acc}
+
+ def run_custom_evaluation(trainer, val_dataset, val_labels):
+     outputs = trainer.predict(val_dataset)
+     logits = torch.from_numpy(outputs.predictions)  # (B, S, L)
+     # pdb.set_trace()
+     probs = torch.sigmoid(logits[:,-1, POS_ID[0]] - logits[:,-1, NEG_ID[0]]).numpy()
+
+     # decoded = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
+     # y_pred = [1 if "positive" in t.lower() else 0 for t in decoded]
+
+     auc_val = roc_auc_score(val_labels, probs)
+     # acc = accuracy_score(val_labels, y_pred)
+     print(f"[Custom Eval] AUC: {auc_val:.4f}")
+     # print(f"[Custom Eval] AUC: {auc_val:.4f}, ACC: {acc:.4f}")
+     return {"auc": auc_val}
+
+ # === Main ===
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--task", required=True, help='amd, dr, glaucoma')
+     parser.add_argument("--name", required=True)
+     parser.add_argument("--use_subset", action='store_true')
+     args = parser.parse_args()
+
+     pyutils.same_seeds(0)
+
+     task_map = {'dr': (-3, 'Diabetic Retinopathy'), 'amd': (-2, 'Aged Macular Degeneration'), 'glaucoma': (-1, 'Glaucoma')}
+     task_idx, disease_name = task_map[args.task]
+     system_message = f"""You are an expert AI in ophthalmology.\n
+     Your primary role is to provide accurate, reliable, and up-to-date medical knowledge based on credible sources.\n
+     You must follow these guidelines:\n
+     1. Be accurate, concise, and clinically relevant.\n
+     2. Use proper medical terms.\n
+     3. Avoid overexplaining unless requested.\n
+     4. Tone: confident, professional, precise.\n
+     Do not include any explanation or thought.\n
+     If {disease_name} is present, answer exactly 'positive'. Otherwise answer 'negative'."""
+     # Diabetic Retinopathy (DR) is a diabetes-related eye disease that affects the retina — the light-sensitive tissue at the back of the eye. It occurs when chronically high blood sugar levels damage the small blood vessels in the retina, leading to leakage, blockage, or abnormal blood vessel growth.\n
+
+     cudnn.benchmark = True
+     img_root_path = '/shared/ssd_30T/yoon/exEYE/Eyeproject/data'
+     train_dataset = np.load('/shared/ssd_30T/yoon/exEYE/datasplit/train_final.npy')
+     val_dataset_raw = np.load('/shared/ssd_30T/yoon/exEYE/datasplit/val_final.npy')
+
+     if args.use_subset:
+         def subset(data, train=True):
+             neg = [s for s in data if s[task_idx] == '0.0']
+             pos = [s for s in data if s[task_idx] != '0.0']
+             num_sample = len(pos)
+             if train:
+                 return random.sample(neg, 7*num_sample), random.sample(pos, num_sample)
+             else:
+                 return random.sample(neg, 3*num_sample), random.sample(pos, num_sample)
+             # return random.sample(neg, 15), random.sample(pos, 15)
+             # return neg, random.sample(pos, num_sample)
+         train_dataset = sum(subset(train_dataset, train=True), [])
+         val_dataset_raw = sum(subset(val_dataset_raw, train=False), [])
+
+     train_dataset = [format_data(s) for s in tqdm(train_dataset)]
+     random.shuffle(train_dataset)
+     val_dataset = [format_data_for_inference(s) for s in tqdm(val_dataset_raw)]
+     val_labels = [1 if s[task_idx] != '0.0' else 0 for s in val_dataset_raw]
+     # val_dataset = [format_data(s) for s in tqdm(val_dataset)]
+     print("="*50)
+     print(f"Total number of Data| Train: {len(train_dataset)} | Val : {len(val_dataset)}")
+     print("="*50)
+
+     model_id = "google/medgemma-4b-it"
+     model_kwargs = dict(
+         attn_implementation="eager",
+         torch_dtype=torch.bfloat16,
+         device_map="auto",
+     )
+
+     model_kwargs["quantization_config"] = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_use_double_quant=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
+         bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
+     )
+
+     model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
+     processor = AutoProcessor.from_pretrained(model_id)
+
+     # Use right padding to avoid issues during training
+     processor.tokenizer.padding_side = "right"
+     # processor.image_processor.size = {"height": 512, "width": 512}
+     # processor.image_processor.crop_size = {"height": 512, "width": 512}
+
+     POS_ID = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize("positive"))  # 30558
+     NEG_ID = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize("negative"))  # 27851
+     ASST_ID = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize("model\n"))
+
+
+     peft_config = LoraConfig(
+         lora_alpha=16,
+         lora_dropout=0.05,
+         r=32,
+         bias="none",
+         target_modules="all-linear",
+         # target_modules=["q_proj", "v_proj"],
+         task_type="CAUSAL_LM",
+         modules_to_save=[
+             "lm_head",
+             "embed_tokens",
+         ],
+     )
+
+
+     exp_name = f"{model_id.split('/')[-1]}-{args.name}"
+
+     if os.path.exists(exp_name):
+         from peft import PeftModel
+         print("🔁 Loading trained PEFT weights...")
+         model = PeftModel.from_pretrained(model, exp_name)
+         # model = PeftModel.from_pretrained(model, exp_name+"/checkpoint-242")
+         # model = PeftModel.from_pretrained(model, "llava-1.5-7b-hf-dr-all/checkpoint-80")
+         phase = "val"
+     else:
+         print("🚀 Initializing new LoRA model...")
+         model = get_peft_model(model, peft_config)
+         model.print_trainable_parameters()
+         phase = "train"
+
+
+     training_args = SFTConfig(
+         output_dir=exp_name,
+         num_train_epochs=16,  # Number of training epochs
+         per_device_train_batch_size=4,  # Batch size per device during training
+         per_device_eval_batch_size=4,  # Batch size per device during evaluation
+         gradient_accumulation_steps=8,  # Number of steps before performing a backward/update pass
+         gradient_checkpointing=True,  # Enable gradient checkpointing to reduce memory usage
+         optim="adamw_torch_fused",  # Use fused AdamW optimizer for better performance
+         logging_steps=10,  # Number of steps between logs
+         save_strategy="epoch",  # Save checkpoint every epoch
+         eval_strategy="steps",  # Evaluate every `eval_steps`
+         eval_steps=10000,  # Number of steps between evaluations
+         learning_rate=8e-4,  # Learning rate based on QLoRA paper
+         bf16=True,  # Use bfloat16 precision
+         max_grad_norm=0.3,  # Max gradient norm based on QLoRA paper
+         warmup_ratio=0.03,  # Warmup ratio based on QLoRA paper
+         lr_scheduler_type="linear",  # Use linear learning rate scheduler
+         push_to_hub=True,  # Push model to Hub
+         report_to="tensorboard",  # Report metrics to tensorboard
+         gradient_checkpointing_kwargs={"use_reentrant": False},  # Set gradient checkpointing to non-reentrant to avoid issues
+         dataset_kwargs={"skip_prepare_dataset": True},  # Skip default dataset preparation to preprocess manually
+         remove_unused_columns=False,  # Columns are unused for training but needed for data collator
+         label_names=["labels"],
+     )
+     # training_args.remove_unused_columns = False
+
+     wandb.init(project=f"{exp_name}-Project", name=exp_name, config=training_args)
+
+     trainer = SFTTrainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_dataset,
+         eval_dataset=val_dataset,
+         data_collator=collate_fn,
+         peft_config=peft_config,
+         processing_class=processor.tokenizer,
+         # compute_metrics=compute_metrics,
+         # preprocess_logits_for_metrics=slice_logits,
+     )
+
+     if not os.path.exists(exp_name):
+         shutil.copy("/shared/ssd_30T/yoon/exEYE/Eyeproject/train_medgemma_ft.py", os.path.join(".", exp_name, "train_medgemma_ft_copy.py"))
+
+     if phase == 'train':
+         trainer.train()
+         trainer.save_model(training_args.output_dir)
+
+         # custom_eval_metrics = run_custom_evaluation(trainer, val_dataset, val_labels)
+     # else:
+     #     ft_pipe = pipeline(
+     #         "image-text-to-text",
+     #         model=exp_name,
+     #         processor=processor,
+     #         torch_dtype=torch.bfloat16,
+     #     )
+
+     #     # Set `do_sample = False` for deterministic responses
+     #     ft_pipe.model.generation_config.do_sample = False
+     #     ft_pipe.model.generation_config.pad_token_id = processor.tokenizer.eos_token_id
+     #     # Use left padding during inference
+     #     processor.tokenizer.padding_side = "left"
+
+     #     texts = []
+     #     images = []
+
+     #     for example in val_dataset:
+     #         text = processor.apply_chat_template(
+     #             example["messages"], add_generation_prompt=True, tokenize=False
+     #         ).strip()
+     #         texts.append(text)
+     #         image = example["image"].convert("RGB").resize((512, 512))
+     #         images.append([image])  # wrap in a list to match the batched format MedGemma expects
+
+     #     # pdb.set_trace()
+     #     ft_outputs = ft_pipe(
+     #         text=texts,
+     #         images=images,
+     #         max_new_tokens=5,
+     #         batch_size=1,
+     #         return_full_text=False,
+     #     )
+
+     batch_size = 1
+     model.eval()
+     all_logits = []
+
+     for i in tqdm(range(0, len(val_dataset), batch_size), desc="Running inference with logits"):
+         batch = val_dataset[i:i + batch_size]
+
+         # prepare inputs
+         texts = []
+         images = []
+         for example in batch:
+             text = processor.apply_chat_template(
+                 example["messages"], add_generation_prompt=True, tokenize=False
+             ).strip()
+             texts.append(text)
+             image = example["image"].convert("RGB").resize((512, 512))
+             images.append([image])
+
+         # tokenizer & image processor
+         with torch.no_grad():
+             texts[0] += "\n"
+             inputs = processor(
+                 text=texts,
+                 images=images,
+                 return_tensors="pt",
+                 padding=True
+             ).to(model.device)
+
+             outputs = model(**inputs, output_hidden_states=False, return_dict=True)
+
+         print("==> ", processor.tokenizer.decode(outputs.logits[0].argmax(-1)[-1]))
+
+         logits = outputs.logits
+         # pdb.set_trace()
+         probs = torch.sigmoid(logits[0,-1, POS_ID] - logits[0,-1, NEG_ID])
+         # logits: (B, L, V)
+         # all_logits.append(outputs.logits.to(torch.float32).detach().cpu().numpy())
+         all_logits.append(probs)
+
+         # pdb.set_trace()
+
+     probs_all = torch.stack(all_logits, dim=0)
+     probs_all = [prob.to(torch.float32).detach().cpu() for prob in probs_all]
+     # logits = torch.from_numpy(np.stack(all_logits, axis=0)).squeeze(1)
+
+     # probs = torch.sigmoid(logits[:,-1, POS_ID] - logits[:,-1, NEG_ID])
+
+     # decoded = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
+     # y_pred = [1 if "positive" in t.lower() else 0 for t in decoded]
+     # pdb.set_trace()
+     auc_val = roc_auc_score(val_labels, probs_all)
+     print(auc_val)
+
+     # print(trainer.evaluate())
training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:61713d7b70980b1dac1979fbf4fa512bed3f7bbc0fa63cf78beb8efa0e918976
+ size 5752