Commit cb2d655 (verified) · committed by berkamphoon · 1 Parent(s): 221b9d1

Training in progress, epoch 1
README.md CHANGED
@@ -4,8 +4,8 @@ library_name: transformers
  model_name: medgemma-27b-it-dr4
  tags:
  - generated_from_trainer
- - trl
  - sft
+ - trl
  licence: license
  ---

@@ -27,18 +27,18 @@ print(output["generated_text"])

  ## Training procedure

- [<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>](https://wandb.ai/yoon307-kaist/medgemma-27b-it-dr4-Project/runs/r39bmq47)
+ [<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>](https://wandb.ai/yoon307-kaist/medgemma-27b-it-dr4-Project/runs/y71p22um)


  This model was trained with SFT.

  ### Framework versions

- - TRL: 0.19.0
- - Transformers: 4.51.3
- - Pytorch: 2.5.0
- - Datasets: 3.6.0
- - Tokenizers: 0.21.1
+ - TRL: 0.19.1
+ - Transformers: 4.53.2
+ - Pytorch: 2.6.0+cu124
+ - Datasets: 4.0.0
+ - Tokenizers: 0.21.2

  ## Citations

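As a hedged illustration only (not part of the model card diff above), the sketch below shows one way to load this adapter on top of its base model for inference. The adapter repo id `berkamphoon/medgemma-27b-it-dr4` is assumed from the commit author and `model_name`; the base model id is taken from `train_medgemma_ft_copy.py` added later in this commit.

```python
# Hedged sketch, not from the model card: attach the LoRA adapter produced by
# this run to the base model. The adapter repo id is an assumption inferred
# from the commit metadata; the base model id matches the training script.
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from peft import PeftModel

base_id = "google/medgemma-27b-it"              # base model used in the training script
adapter_id = "berkamphoon/medgemma-27b-it-dr4"  # assumed Hub repo id for this adapter

processor = AutoProcessor.from_pretrained(base_id)
model = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=torch.bfloat16, device_map="auto")
model = PeftModel.from_pretrained(model, adapter_id)
model.eval()
```

The framework-version bumps in the card record the environment of this particular run (TRL 0.19.1, Transformers 4.53.2, PyTorch 2.6.0+cu124) rather than pinned requirements for inference.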
adapter_config.json CHANGED
@@ -24,20 +24,20 @@
  ],
  "peft_type": "LORA",
  "qalora_group_size": 16,
- "r": 8,
+ "r": 16,
  "rank_pattern": {},
  "revision": null,
  "target_modules": [
- "v_proj",
+ "out_proj",
  "up_proj",
- "down_proj",
+ "fc1",
+ "q_proj",
+ "gate_proj",
+ "v_proj",
  "o_proj",
  "fc2",
- "fc1",
- "out_proj",
  "k_proj",
- "gate_proj",
- "q_proj"
+ "down_proj"
  ],
  "task_type": "CAUSAL_LM",
  "trainable_token_indices": null,
adapter_model.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:fc1ece2f5b5d375b03e51674440e84a18891275ca6bfc1f53fc7fdc9550f96dc
- size 5883125880
+ oid sha256:46088cfd14dbb5cfc8432351881e54ce4f916e9bb12fb44f8f1c04005524d622
+ size 6127553104
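For reference, the adapter_config.json change above raises the LoRA rank from 8 to 16 and settles on ten target projection modules. A minimal PEFT sketch that mirrors those fields follows; it is illustrative only, and `lora_alpha`/`lora_dropout` are taken from the training script in this commit rather than from the JSON diff.

```python
# Hedged sketch: a LoraConfig mirroring the updated adapter_config.json
# (r=16, the listed target modules, CAUSAL_LM task). lora_alpha and
# lora_dropout come from train_medgemma_ft_copy.py, not from this diff.
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,                      # raised from 8 in this commit
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "out_proj", "up_proj", "fc1", "q_proj", "gate_proj",
        "v_proj", "o_proj", "fc2", "k_proj", "down_proj",
    ],
)
```

The growth of adapter_model.safetensors just above (5,883,125,880 to 6,127,553,104 bytes) is consistent with the doubled rank; most of the file size comes from the lm_head and embed_tokens weights kept via `modules_to_save` in the training script.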
chat_template.jinja ADDED
@@ -0,0 +1,47 @@
+ {{ bos_token }}
+ {%- if messages[0]['role'] == 'system' -%}
+     {%- if messages[0]['content'] is string -%}
+         {%- set first_user_prefix = messages[0]['content'] + '
+
+ ' -%}
+     {%- else -%}
+         {%- set first_user_prefix = messages[0]['content'][0]['text'] + '
+
+ ' -%}
+     {%- endif -%}
+     {%- set loop_messages = messages[1:] -%}
+ {%- else -%}
+     {%- set first_user_prefix = "" -%}
+     {%- set loop_messages = messages -%}
+ {%- endif -%}
+ {%- for message in loop_messages -%}
+     {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
+         {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
+     {%- endif -%}
+     {%- if (message['role'] == 'assistant') -%}
+         {%- set role = "model" -%}
+     {%- else -%}
+         {%- set role = message['role'] -%}
+     {%- endif -%}
+     {{ '<start_of_turn>' + role + '
+ ' + (first_user_prefix if loop.first else "") }}
+     {%- if message['content'] is string -%}
+         {{ message['content'] | trim }}
+     {%- elif message['content'] is iterable -%}
+         {%- for item in message['content'] -%}
+             {%- if item['type'] == 'image' -%}
+                 {{ '<start_of_image>' }}
+             {%- elif item['type'] == 'text' -%}
+                 {{ item['text'] | trim }}
+             {%- endif -%}
+         {%- endfor -%}
+     {%- else -%}
+         {{ raise_exception("Invalid content type") }}
+     {%- endif -%}
+     {{ '<end_of_turn>
+ ' }}
+ {%- endfor -%}
+ {%- if add_generation_prompt -%}
+     {{'<start_of_turn>model
+ '}}
+ {%- endif -%}
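The new standalone chat_template.jinja carries the Gemma-style formatting that was previously embedded in tokenizer_config.json (removed further below): a leading system message is folded into the first user turn, the assistant role is renamed to `model`, and turns are wrapped in `<start_of_turn>`/`<end_of_turn>`. A hedged sketch of rendering a prompt with it, reusing the message layout from `format_data_for_inference` in the training script:

```python
# Hedged sketch (illustrative, not part of the commit): render a prompt with the
# chat template added above. The message layout mirrors format_data_for_inference
# in train_medgemma_ft_copy.py; the wording of the texts is only an example.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("google/medgemma-27b-it")  # base model from the training script
messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are an expert AI in ophthalmology."}]},
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "Please diagnose whether the Diabetic Retinopathy exist or not based on the given image."},
    ]},
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
# The rendered prompt starts with the BOS token, opens "<start_of_turn>user" with the
# system text folded in, and ends with "<start_of_turn>model\n" because
# add_generation_prompt=True.
print(prompt)
```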
runs/Jul19_23-56-17_meedgxh100a/events.out.tfevents.1752983779.meedgxh100a.1020669.0 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0702869b05003d340911d40565a0adbea440ad368add0a194753e18b069fe298
+ size 9916
tokenizer_config.json CHANGED
@@ -51325,7 +51325,6 @@
  },
  "boi_token": "<start_of_image>",
  "bos_token": "<bos>",
- "chat_template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'<start_of_turn>model\n'}}\n{%- endif -%}\n",
  "clean_up_tokenization_spaces": false,
  "eoi_token": "<end_of_image>",
  "eos_token": "<eos>",
train_medgemma_ft_copy.py ADDED
@@ -0,0 +1,402 @@
+ from __future__ import division, print_function
+
+
+ # medqa gsm8k openbookqa bioasq pubmedqa squad_v2
+
+ # === Base ===
+ import os
+ import os.path as osp
+ import random
+ import argparse
+ import logging
+ from tqdm import tqdm
+ from matplotlib import pyplot as plt
+ import pdb
+ from PIL import Image
+ import shutil
+ import os
+
+ # === DL ===
+ import numpy as np
+ import torch
+ import torch.backends.cudnn as cudnn
+ from torch.utils.data import DataLoader
+ from torch.utils.tensorboard import SummaryWriter
+
+ # === Custom ===
+ # import tools.imutils as imutils
+ # import tools.utils as utils
+ # import tools.pyutils as pyutils
+ # from tools.utils import compute_es_auc, compute_group_auc, ImprovedBalancedBatchSampler, compute_es_auc_multi
+
+ # === Evaluation ===
+ from sklearn.metrics import roc_curve, accuracy_score, roc_auc_score
+
+ # === Transformers ===
+ from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig, pipeline, AutoModelForCausalLM
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+ from trl import SFTTrainer, SFTConfig
+ import wandb
+
+ # === Label Masking Function ===
+ def mask_until_after_assistant(labels: torch.Tensor, tokenizer, assistant_token_ids: list):
+     for i in range(labels.size(0)):
+         for j in range(labels.size(1) - len(assistant_token_ids) + 1):
+             if torch.equal(labels[i, j:j+len(assistant_token_ids)], torch.tensor(assistant_token_ids, device=labels.device)):
+                 labels[i, :j + len(assistant_token_ids)] = -100  # mask everything up to and including the assistant tokens
+                 break
+     return labels
+
+
+ # === Collate Function ===
+ def collate_fn(examples):
+     texts = []
+     images = []
+     for example in examples:
+         image = example["image"].convert("RGB")
+         image = image.resize((IM_SIZE, IM_SIZE))
+         images.append([image])
+         texts.append(processor.apply_chat_template(
+             example["messages"], add_generation_prompt=False, tokenize=False
+         ).strip())
+
+     # Tokenize the texts and process the images
+     batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
+
+     # The labels are the input_ids, with the padding and image tokens masked in
+     # the loss computation
+     labels = batch["input_ids"].clone()
+
+     # Mask image tokens
+     image_token_id = [
+         processor.tokenizer.convert_tokens_to_ids(
+             processor.tokenizer.special_tokens_map["boi_token"]
+         )
+     ]
+     # Mask tokens that are not used in the loss computation
+     labels[labels == processor.tokenizer.pad_token_id] = -100
+     labels[labels == image_token_id] = -100
+     labels[labels == 262144] = -100
+
+     labels = mask_until_after_assistant(labels, processor.tokenizer, ASST_ID)
+     labels[:, -1] = -100
+
+     batch["labels"] = labels
+     # pdb.set_trace()
+     return batch
+
+ def format_data(sample):
+     label = 'negative' if sample[task_idx] == '0.0' else 'positive'
+     prompt = f"Please diagnose whether the {disease_name} exist or not based on the given image.\n"
+
+     # pdb.set_trace()
+     example = {}
+     example["image"] = Image.open(os.path.join(img_root_path, sample[1]))
+     example["label"] = 0 if sample[task_idx] == '0.0' else 1
+     example["messages"] = [
+         {"role": "system", "content": [{"type": "text", "text": system_message}]},
+         {"role": "user", "content": [
+             # {"type": "image", "image": os.path.join(img_root_path, sample[1])},
+             {"type": "image"},
+             {"type": "text", "text": prompt},
+         ]},
+         {"role": "assistant", "content": [{"type": "text", "text": str(label)}]}
+     ]
+
+     return example
+
+ def format_data_for_inference(sample):
+     prompt = f"Please diagnose whether the {disease_name} exist or not based on the given image.\n"
+
+     # pdb.set_trace()
+     example = {}
+     example["image"] = Image.open(os.path.join(img_root_path, sample[1]))
+     # example["label"] = 0 if sample[task_idx]== '0,0' else 1
+     example["messages"] = [
+         {"role": "system", "content": [{"type": "text", "text": system_message}]},
+         {"role": "user", "content": [
+             # {"type": "image", "image": os.path.join(img_root_path, sample[1])},
+             {"type": "image"},
+             {"type": "text", "text": prompt + "\n"},
+         ]},
+         # {"role": "assistant", "content": [{"type": "text", "text": str(label)}]}
+     ]
+
+     return example
+
+ # === Logit Preprocessing ===
+ def slice_logits(logits, labels):
+     if isinstance(logits, (tuple, list)):
+         logits = logits[0]
+     return logits.detach().cpu()
+
+ def compute_metrics(eval_pred):
+     logits = torch.tensor(eval_pred.predictions)
+
+     token_ids = logits.argmax(dim=-1)  # (B, L): predicted token at each position
+
+     batch_logits = []
+     for b in range(logits.size(0)):
+         seq = token_ids[b]  # (L,)
+         idxs = torch.where((seq == POS_ID[0]) | (seq == NEG_ID[0]))[0]
+         if len(idxs) == 0:
+             raise ValueError(f"Neither pos_id nor neg_id found in sequence {b}")
+         t = idxs[0].item()  # first position where pos or neg appears
+         tok_id = seq[t].item()  # should be either pos_id or neg_id
+         batch_logits.append(logits[b, t, tok_id])  # scalar
+
+     batch_logits = torch.stack(batch_logits)  # shape: [B]
+     pred_texts = processor.tokenizer.batch_decode(token_ids[:, -1], skip_special_tokens=True)
+
+     # print(pred_texts)
+     # pdb.set_trace()
+     probs = torch.sigmoid(logits[:, -1, POS_ID[0]] - logits[:, -1, NEG_ID[0]]).numpy()
+
+     # probs = torch.sigmoid(batch_logits).numpy()
+     labels = torch.tensor(eval_pred.label_ids)
+     gt_ids = labels[labels != -100].view(logits.size(0), -1)[:, 0]
+     y_true = (gt_ids == POS_ID[0]).int().cpu().numpy()
+     auc_val = roc_auc_score(y_true, probs)
+     fpr, tpr, thr = roc_curve(y_true, probs)
+     best = thr[np.argmax(tpr - fpr)]
+     acc = accuracy_score(y_true, probs >= best)
+     return {"roc_auc": auc_val, "accuracy": acc}
+
+ def run_custom_evaluation(trainer, val_dataset, val_labels):
+     outputs = trainer.predict(val_dataset)
+     logits = torch.from_numpy(outputs.predictions)  # (B, S, L)
+     # pdb.set_trace()
+     probs = torch.sigmoid(logits[:, -1, POS_ID[0]] - logits[:, -1, NEG_ID[0]]).numpy()
+
+     # decoded = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
+     # y_pred = [1 if "positive" in t.lower() else 0 for t in decoded]
+
+     auc_val = roc_auc_score(val_labels, probs)
+     # acc = accuracy_score(val_labels, y_pred)
+     print(f"[Custom Eval] AUC: {auc_val:.4f}")
+     # print(f"[Custom Eval] AUC: {auc_val:.4f}, ACC: {acc:.4f}")
+     return {"auc": auc_val}
+
+ # === Main ===
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--task", required=True, help='amd, dr, glaucoma')
+     parser.add_argument("--name", required=True)
+     parser.add_argument("--use_subset", action='store_true')
+     args = parser.parse_args()
+
+     random.seed(42)
+
+     # pyutils.same_seeds(0)
+
+     task_map = {'dr': (-3, 'Diabetic Retinopathy'), 'amd': (-2, 'Aged Macular Degeneration'), 'glaucoma': (-1, 'Glaucoma')}
+     task_idx, disease_name = task_map[args.task]
+     system_message = f"""You are an expert AI in ophthalmology.\n
+ Your primary role is to provide accurate, reliable, and up-to-date medical knowledge based on credible sources.\n
+ You must follow these guidelines:\n
+ 1. Be accurate, concise, and clinically relevant.\n
+ 2. Use proper medical terms.\n
+ 3. Avoid overexplaining unless requested.\n
+ 4. Tone: confident, professional, precise.\n
+ Do not include any explanation or thought.\n
+ Diabetic Retinopathy (DR) is a diabetes-related eye disease that affects the retina — the light-sensitive tissue at the back of the eye. It occurs when chronically high blood sugar levels damage the small blood vessels in the retina, leading to leakage, blockage, or abnormal blood vessel growth.\n
+ If {disease_name} is present, answer exactly 'positive'. Otherwise answer 'negative'."""
+
+     cudnn.benchmark = True
+     img_root_path = '/PHShome/sy1081/exeye/data'
+     train_dataset = np.load('/PHShome/sy1081/exeye/data/train_final.npy')
+     val_dataset_raw = np.load('/PHShome/sy1081/exeye/data/val_final.npy')
+
+     if args.use_subset:
+         def subset(data, train=True):
+             neg = [s for s in data if s[task_idx] == '0.0']
+             pos = [s for s in data if s[task_idx] != '0.0']
+             num_sample = len(pos)
+             if train:
+                 return random.sample(neg, 5*num_sample), random.sample(pos, num_sample)
+             else:
+                 # return random.sample(neg, 5*num_sample), pos
+                 # return random.sample(neg, 15), random.sample(pos, 15)
+                 return neg, pos
+         train_dataset = sum(subset(train_dataset, train=True), [])
+         val_dataset_raw = sum(subset(val_dataset_raw, train=False), [])
+
+     train_dataset = [format_data(s) for s in tqdm(train_dataset)]
+     random.shuffle(train_dataset)
+     val_dataset = [format_data_for_inference(s) for s in tqdm(val_dataset_raw)]
+     val_labels = [1 if s[task_idx] != '0.0' else 0 for s in val_dataset_raw]
+     # val_dataset = [format_data(s) for s in tqdm(val_dataset)]
+     print("="*50)
+     print(f"Total number of Data| Train: {len(train_dataset)} | Val : {len(val_dataset)}")
+     print("="*50)
+
+     # model_id = "google/medgemma-4b-it"
+     model_id = "google/medgemma-27b-it"
+     model_kwargs = dict(
+         attn_implementation="eager",
+         torch_dtype=torch.bfloat16,
+         device_map="auto",
+     )
+
+     model_kwargs["quantization_config"] = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_use_double_quant=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
+         bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
+     )
+
+     # model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
+
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id,
+         **model_kwargs
+         # torch_dtype=torch.bfloat16,
+         # device_map="auto",
+     )
+     processor = AutoProcessor.from_pretrained(model_id)
+
+     # Use right padding to avoid issues during training
+     processor.tokenizer.padding_side = "right"
+
+     POS_ID = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize("positive"))  # 30558
+     NEG_ID = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize("negative"))  # 27851
+     ASST_ID = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize("model\n"))
+
+     IM_SIZE = 512
+
+     peft_config = LoraConfig(
+         lora_alpha=16,
+         lora_dropout=0.05,
+         r=16,
+         bias="none",
+         target_modules="all-linear",
+         # target_modules=["q_proj", "v_proj"],
+         task_type="CAUSAL_LM",
+         modules_to_save=[
+             "lm_head",
+             "embed_tokens",
+         ],
+     )
+
+
+     exp_name = f"{model_id.split('/')[-1]}-{args.name}"
+
+     if os.path.exists(exp_name):
+         from peft import PeftModel
+         print("🔁 Loading trained PEFT weights...")
+         # model = PeftModel.from_pretrained(model, exp_name)
+         model = PeftModel.from_pretrained(model, exp_name + "/checkpoint-598")
+         # model = PeftModel.from_pretrained(model, "llava-1.5-7b-hf-dr-all/checkpoint-80")
+         phase = "val"
+     else:
+         print("🚀 Initializing new LoRA model...")
+         # model = prepare_model_for_kbit_training(model)
+         model = get_peft_model(model, peft_config)
+         model.print_trainable_parameters()
+         phase = "train"
+
+
+     training_args = SFTConfig(
+         output_dir=exp_name,
+         num_train_epochs=15,                        # Number of training epochs
+         per_device_train_batch_size=2,              # Batch size per device during training
+         per_device_eval_batch_size=4,               # Batch size per device during evaluation
+         gradient_accumulation_steps=8,              # Number of steps before performing a backward/update pass
+         gradient_checkpointing=True,                # Enable gradient checkpointing to reduce memory usage
+         optim="adamw_torch_fused",                  # Use fused AdamW optimizer for better performance
+         logging_steps=10,                           # Number of steps between logs
+         save_strategy="epoch",                      # Save checkpoint every epoch
+         eval_strategy="steps",                      # Evaluate every `eval_steps`
+         eval_steps=10000,                           # Number of steps between evaluations
+         learning_rate=1e-3,                         # Learning rate based on QLoRA paper
+         bf16=True,                                  # Use bfloat16 precision
+         max_grad_norm=0.3,                          # Max gradient norm based on QLoRA paper
+         warmup_ratio=0.03,                          # Warmup ratio based on QLoRA paper
+         lr_scheduler_type="linear",                 # Use linear learning rate scheduler
+         # lr_scheduler_type="constant",
+         push_to_hub=True,                           # Push model to Hub
+         report_to="tensorboard",                    # Report metrics to tensorboard
+         gradient_checkpointing_kwargs={"use_reentrant": False},  # Set gradient checkpointing to non-reentrant to avoid issues
+         dataset_kwargs={"skip_prepare_dataset": True},  # Skip default dataset preparation to preprocess manually
+         remove_unused_columns=False,                # Columns are unused for training but needed for data collator
+         label_names=["labels"],
+     )
+     # training_args.remove_unused_columns = False
+
+     wandb.init(project=f"{exp_name}-Project", name=exp_name, config=training_args)
+
+     trainer = SFTTrainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_dataset,
+         eval_dataset=val_dataset,
+         data_collator=collate_fn,
+         peft_config=peft_config,
+         processing_class=processor.tokenizer,
+         # compute_metrics=compute_metrics,
+         # preprocess_logits_for_metrics=slice_logits,
+     )
+
+     # if not os.path.exists(exp_name):
+     shutil.copy("/PHShome/sy1081/exeye/train_medgemma_ft.py", os.path.join(".", exp_name, "train_medgemma_ft_copy.py"))
+
+     if phase == 'train':
+         trainer.train()
+         trainer.save_model(training_args.output_dir)
+
+
+     batch_size = 1
+     model.eval()
+     all_logits = []
+
+     for i in tqdm(range(0, len(val_dataset), batch_size), desc="Running inference with logits"):
+         batch = val_dataset[i:i + batch_size]
+
+         # prepare inputs
+         texts = []
+         images = []
+         for example in batch:
+             text = processor.apply_chat_template(
+                 example["messages"], add_generation_prompt=True, tokenize=False
+             ).strip()
+             texts.append(text)
+             image = example["image"].convert("RGB").resize((IM_SIZE, IM_SIZE))
+             images.append([image])
+
+         # tokenizer & image processor
+         with torch.no_grad():
+             texts[0] += "\n"
+             inputs = processor(
+                 text=texts,
+                 images=images,
+                 return_tensors="pt",
+                 padding=True
+             ).to(model.device)
+
+             outputs = model(**inputs, output_hidden_states=False, return_dict=True)
+
+             print("==> ", processor.tokenizer.decode(outputs.logits[0].argmax(-1)[-1]))
+
+             logits = outputs.logits
+             # pdb.set_trace()
+             probs = torch.sigmoid(logits[0, -1, POS_ID] - logits[0, -1, NEG_ID])
+             # logits: (B, L, V)
+             # all_logits.append(outputs.logits.to(torch.float32).detach().cpu().numpy())
+             all_logits.append(probs)
+
+     # pdb.set_trace()
+
+     probs_all = torch.stack(all_logits, dim=0)
+     probs_all = [prob.to(torch.float32).detach().cpu() for prob in probs_all]
+     # logits = torch.from_numpy(np.stack(all_logits, axis=0)).squeeze(1)
+
+     # probs = torch.sigmoid(logits[:, -1, POS_ID] - logits[:, -1, NEG_ID])
+
+     # decoded = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
+     # y_pred = [1 if "positive" in t.lower() else 0 for t in decoded]
+     # pdb.set_trace()
+     auc_val = roc_auc_score(val_labels, probs_all)
+     print(auc_val)
+
+     # print(trainer.evaluate())
training_args.bin CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:a147c6b97d2559d2483a7172fd40a027f9a50db2443fcf14e4d379ed9a216ba2
- size 5752
+ oid sha256:26753afe5611dc69c3ec3c59e8748980463f74e5c7c31a5103a131c32c91af02
+ size 5816