berkamphoon committed (verified)
Commit 5913fc2 · 1 Parent(s): 5ca4b90

Training in progress, epoch 1

README.md CHANGED
@@ -4,8 +4,8 @@ library_name: transformers
 model_name: medgemma-27b-it-dr5
 tags:
 - generated_from_trainer
-- trl
 - sft
+- trl
 licence: license
 ---
 
@@ -27,18 +27,18 @@ print(output["generated_text"])
 
 ## Training procedure
 
-[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>](https://wandb.ai/yoon307-kaist/medgemma-27b-it-dr5-Project/runs/mbxoj7k5)
+[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>](https://wandb.ai/yoon307-kaist/medgemma-27b-it-dr5-Project/runs/6argv9kb)
 
 
 This model was trained with SFT.
 
 ### Framework versions
 
-- TRL: 0.19.0
-- Transformers: 4.51.3
-- Pytorch: 2.5.0
-- Datasets: 3.6.0
-- Tokenizers: 0.21.1
+- TRL: 0.19.1
+- Transformers: 4.53.2
+- Pytorch: 2.6.0+cu124
+- Datasets: 4.0.0
+- Tokenizers: 0.21.2
 
 ## Citations

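For quick reference, a minimal inference sketch against this checkpoint might look as follows. The base model id `google/medgemma-27b-it` is taken from the training script in this commit; the adapter repo id `berkamphoon/medgemma-27b-it-dr5` is an assumption inferred from `model_name`, and `AutoModelForImageTextToText` mirrors the (now commented-out) loader in the script.

```python
# Hedged sketch, not part of the model card. Assumptions: the adapter repo id
# below is inferred from model_name; the base model matches the training script.
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from peft import PeftModel

base_id = "google/medgemma-27b-it"
adapter_id = "berkamphoon/medgemma-27b-it-dr5"  # assumed Hub repo id

model = AutoModelForImageTextToText.from_pretrained(
    base_id, torch_dtype=torch.bfloat16, device_map="auto"
)
model = PeftModel.from_pretrained(model, adapter_id)  # attach the LoRA adapter
processor = AutoProcessor.from_pretrained(base_id)
```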
adapter_config.json CHANGED
@@ -29,15 +29,15 @@
     "revision": null,
     "target_modules": [
         "v_proj",
-        "gate_proj",
+        "down_proj",
+        "o_proj",
+        "fc1",
         "fc2",
-        "k_proj",
-        "out_proj",
         "q_proj",
-        "o_proj",
-        "down_proj",
+        "k_proj",
+        "gate_proj",
         "up_proj",
-        "fc1"
+        "out_proj"
     ],
     "task_type": "CAUSAL_LM",
     "trainable_token_indices": null,
adapter_model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:193d5bc56b229dcd16a327b58b3d06056ba3d4a25c915706b577c5185a762759
-size 11766077184
+oid sha256:f6b8752afa62eaf145b3ab7bcd63788ad169ed8f26b3a901c59c47ac67134b7b
+size 6127553104
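The adapter shrinks from about 11.8 GB to about 6.1 GB (6127553104 / 11766077184 ≈ 0.52), consistent with halving the LoRA rank from r=32 to r=16 in this commit, since LoRA parameter count scales linearly with r.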
chat_template.jinja ADDED
@@ -0,0 +1,47 @@
+{{ bos_token }}
+{%- if messages[0]['role'] == 'system' -%}
+    {%- if messages[0]['content'] is string -%}
+        {%- set first_user_prefix = messages[0]['content'] + '
+
+' -%}
+    {%- else -%}
+        {%- set first_user_prefix = messages[0]['content'][0]['text'] + '
+
+' -%}
+    {%- endif -%}
+    {%- set loop_messages = messages[1:] -%}
+{%- else -%}
+    {%- set first_user_prefix = "" -%}
+    {%- set loop_messages = messages -%}
+{%- endif -%}
+{%- for message in loop_messages -%}
+    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
+        {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
+    {%- endif -%}
+    {%- if (message['role'] == 'assistant') -%}
+        {%- set role = "model" -%}
+    {%- else -%}
+        {%- set role = message['role'] -%}
+    {%- endif -%}
+    {{ '<start_of_turn>' + role + '
+' + (first_user_prefix if loop.first else "") }}
+    {%- if message['content'] is string -%}
+        {{ message['content'] | trim }}
+    {%- elif message['content'] is iterable -%}
+        {%- for item in message['content'] -%}
+            {%- if item['type'] == 'image' -%}
+                {{ '<start_of_image>' }}
+            {%- elif item['type'] == 'text' -%}
+                {{ item['text'] | trim }}
+            {%- endif -%}
+        {%- endfor -%}
+    {%- else -%}
+        {{ raise_exception("Invalid content type") }}
+    {%- endif -%}
+    {{ '<end_of_turn>
+' }}
+{%- endfor -%}
+{%- if add_generation_prompt -%}
+    {{'<start_of_turn>model
+'}}
+{%- endif -%}
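This is the standard Gemma-style multimodal template: a leading system message is folded into the first user turn as a prefix, assistant turns render under the role name `model`, and image items become `<start_of_image>` placeholders. A small usage sketch, mirroring how the training script calls `apply_chat_template` (the message text here is an illustrative example, not from the repo):

```python
# Sketch: render the template above into a prompt string.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("google/medgemma-27b-it")
messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are an expert ophthalmologist."}]},
    {"role": "user", "content": [
        {"type": "image"},  # rendered as '<start_of_image>' by the template
        {"type": "text", "text": "Is Diabetic Retinopathy present?"},
    ]},
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
# -> "<bos><start_of_turn>user\nYou are an expert ophthalmologist.\n\n<start_of_image>Is Diabetic Retinopathy present?<end_of_turn>\n<start_of_turn>model\n"
```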
runs/Jul20_17-33-00_meedgxh100a/events.out.tfevents.1753047182.meedgxh100a.2023753.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0c56c796f26a94c0ecce8c68b316405d52cda87acf40d7bc948259609714e558
+size 9269
tokenizer_config.json CHANGED
@@ -51325,7 +51325,6 @@
     },
     "boi_token": "<start_of_image>",
     "bos_token": "<bos>",
-    "chat_template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n    {%- if messages[0]['content'] is string -%}\n        {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n    {%- else -%}\n        {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n    {%- endif -%}\n    {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n    {%- set first_user_prefix = \"\" -%}\n    {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n        {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n    {%- endif -%}\n    {%- if (message['role'] == 'assistant') -%}\n        {%- set role = \"model\" -%}\n    {%- else -%}\n        {%- set role = message['role'] -%}\n    {%- endif -%}\n    {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n    {%- if message['content'] is string -%}\n        {{ message['content'] | trim }}\n    {%- elif message['content'] is iterable -%}\n        {%- for item in message['content'] -%}\n            {%- if item['type'] == 'image' -%}\n                {{ '<start_of_image>' }}\n            {%- elif item['type'] == 'text' -%}\n                {{ item['text'] | trim }}\n            {%- endif -%}\n        {%- endfor -%}\n    {%- else -%}\n        {{ raise_exception(\"Invalid content type\") }}\n    {%- endif -%}\n    {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n    {{'<start_of_turn>model\n'}}\n{%- endif -%}\n",
     "clean_up_tokenization_spaces": false,
     "eoi_token": "<end_of_image>",
     "eos_token": "<eos>",
train_medgemma_ft_copy.py CHANGED
@@ -24,17 +24,17 @@ from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
 
 # === Custom ===
-import tools.imutils as imutils
-import tools.utils as utils
-import tools.pyutils as pyutils
-from tools.utils import compute_es_auc, compute_group_auc, ImprovedBalancedBatchSampler, compute_es_auc_multi
+# import tools.imutils as imutils
+# import tools.utils as utils
+# import tools.pyutils as pyutils
+# from tools.utils import compute_es_auc, compute_group_auc, ImprovedBalancedBatchSampler, compute_es_auc_multi
 
 # === Evaluation ===
 from sklearn.metrics import roc_curve, accuracy_score, roc_auc_score
 
 # === Transformers ===
-from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig, pipeline
-from peft import LoraConfig, get_peft_model
+from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig, pipeline, AutoModelForCausalLM
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
 from trl import SFTTrainer, SFTConfig
 import wandb
 
@@ -54,7 +54,7 @@ def collate_fn(examples):
     images = []
     for example in examples:
         image = example["image"].convert("RGB")
-        image = image.resize((512,512))
+        image = image.resize((IM_SIZE,IM_SIZE))
         images.append([image])
         texts.append(processor.apply_chat_template(
             example["messages"], add_generation_prompt=False, tokenize=False
@@ -121,14 +121,7 @@ def format_data_for_inference(sample):
         ]},
        # {"role": "assistant", "content": [{"type": "text", "text": str(label)}]}
    ]
-    # prompt = f"Please diagnose whether the {disease_name} exist or not based on the given image."
-    # return [
-    #     {"role": "system", "content": [{"type": "text", "text": system_message}]},
-    #     {"role": "user", "content": [
-    #         {"type": "image", "image": os.path.join(img_root_path, sample[1])},
-    #         {"type": "text", "text": prompt}
-    #     ]}
-    # ]
+
    return example
 
 # === Logit Preprocessing ===
@@ -191,8 +184,10 @@ if __name__ == '__main__':
    parser.add_argument("--name", required=True)
    parser.add_argument("--use_subset", action='store_true')
    args = parser.parse_args()
+
+    random.seed(42)
 
-    pyutils.same_seeds(0)
+    # pyutils.same_seeds(0)
 
    task_map = {'dr': (-3, 'Diabetic Retinopathy'), 'amd': (-2, 'Aged Macular Degeneration'), 'glaucoma': (-1, 'Glaucoma')}
    task_idx, disease_name = task_map[args.task]
@@ -204,13 +199,13 @@ if __name__ == '__main__':
    3. Avoid overexplaining unless requested.\n
    4. Tone: confident, professional, precise.\n
    Do not include any explanation or thought.\n
+    Diabetic Retinopathy (DR) is a diabetes-related eye disease that affects the retina — the light-sensitive tissue at the back of the eye. It occurs when chronically high blood sugar levels damage the small blood vessels in the retina, leading to leakage, blockage, or abnormal blood vessel growth.\n
    If {disease_name} is present, answer exactly 'positive'. Otherwise answer 'negative'."""
-    # Diabetic Retinopathy (DR) is a diabetes-related eye disease that affects the retina — the light-sensitive tissue at the back of the eye. It occurs when chronically high blood sugar levels damage the small blood vessels in the retina, leading to leakage, blockage, or abnormal blood vessel growth.\n
 
    cudnn.benchmark = True
-    img_root_path = '/shared/ssd_30T/yoon/exEYE/Eyeproject/data'
-    train_dataset = np.load('/shared/ssd_30T/yoon/exEYE/datasplit/train_final.npy')
-    val_dataset_raw = np.load('/shared/ssd_30T/yoon/exEYE/datasplit/val_final.npy')
+    img_root_path = '/PHShome/sy1081/exeye/data'
+    train_dataset = np.load('/PHShome/sy1081/exeye/data/train_final.npy')
+    val_dataset_raw = np.load('/PHShome/sy1081/exeye/data/val_final.npy')
 
    if args.use_subset:
        def subset(data,train=True):
@@ -218,11 +213,11 @@ if __name__ == '__main__':
            pos = [s for s in data if s[task_idx] != '0.0']
            num_sample = len(pos)
            if train:
-                return random.sample(neg, 7*num_sample), random.sample(pos, num_sample)
+                return random.sample(neg, 5*num_sample), random.sample(pos, num_sample)
            else:
-                return random.sample(neg, 3*num_sample), random.sample(pos, num_sample)
+                return random.sample(neg, num_sample), pos
            # return random.sample(neg, 15), random.sample(pos, 15)
-            # return neg, random.sample(pos, num_sample)
+            # return neg, pos
        train_dataset = sum(subset(train_dataset,train=True), [])
        val_dataset_raw = sum(subset(val_dataset_raw,train=False), [])
 
@@ -235,7 +230,8 @@ if __name__ == '__main__':
    print(f"Total number of Data| Train: {len(train_dataset)} | Val : {len(val_dataset)}")
    print("="*50)
 
-    model_id = "google/medgemma-4b-it"
+    # model_id = "google/medgemma-4b-it"
+    model_id = "google/medgemma-27b-it"
    model_kwargs = dict(
        attn_implementation="eager",
        torch_dtype=torch.bfloat16,
@@ -250,23 +246,29 @@ if __name__ == '__main__':
        bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
    )
 
-    model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
+    # model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        **model_kwargs
+        # torch_dtype=torch.bfloat16,
+        # device_map="auto",
+    )
    processor = AutoProcessor.from_pretrained(model_id)
 
    # Use right padding to avoid issues during training
    processor.tokenizer.padding_side = "right"
-    # processor.image_processor.size = {"height": 512, "width": 512}
-    # processor.image_processor.crop_size = {"height": 512, "width": 512}
 
    POS_ID = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize("positive")) #30558
    NEG_ID = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize("negative")) #27851
    ASST_ID = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize("model\n"))
 
+    IM_SIZE = 1024
 
    peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.05,
-        r=32,
+        r=16,
        bias="none",
        target_modules="all-linear",
        # target_modules=["q_proj", "v_proj"],
@@ -284,11 +286,12 @@ if __name__ == '__main__':
        from peft import PeftModel
        print("🔁 Loading trained PEFT weights...")
        model = PeftModel.from_pretrained(model, exp_name)
-        # model = PeftModel.from_pretrained(model, exp_name+"/checkpoint-242")
+        # model = PeftModel.from_pretrained(model, exp_name+"/checkpoint-690")
        # model = PeftModel.from_pretrained(model, "llava-1.5-7b-hf-dr-all/checkpoint-80")
        phase= "val"
    else:
        print("🚀 Initializing new LoRA model...")
+        # model = prepare_model_for_kbit_training(model)
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
        phase= "train"
@@ -296,7 +299,7 @@ if __name__ == '__main__':
 
    training_args = SFTConfig(
        output_dir=exp_name,
-        num_train_epochs= 16, # Number of training epochs
+        num_train_epochs= 20, # Number of training epochs
        per_device_train_batch_size=4, # Batch size per device during training
        per_device_eval_batch_size=4, # Batch size per device during evaluation
        gradient_accumulation_steps=8, # Number of steps before performing a backward/update pass
@@ -306,11 +309,12 @@ if __name__ == '__main__':
        save_strategy="epoch", # Save checkpoint every epoch
        eval_strategy="steps", # Evaluate every `eval_steps`
        eval_steps=10000, # Number of steps between evaluations
-        learning_rate=8e-4, # Learning rate based on QLoRA paper
+        learning_rate=5e-4, # Learning rate based on QLoRA paper
        bf16=True, # Use bfloat16 precision
        max_grad_norm=0.3, # Max gradient norm based on QLoRA paper
        warmup_ratio=0.03, # Warmup ratio based on QLoRA paper
        lr_scheduler_type="linear", # Use linear learning rate scheduler
+        # lr_scheduler_type="constant", # Use linear learning rate scheduler
        push_to_hub=True, # Push model to Hub
        report_to="tensorboard", # Report metrics to tensorboard
        gradient_checkpointing_kwargs={"use_reentrant": False}, # Set gradient checkpointing to non-reentrant to avoid issues
@@ -334,47 +338,13 @@ if __name__ == '__main__':
        # preprocess_logits_for_metrics=slice_logits,
    )
 
-    if not os.path.exists(exp_name):
-        shutil.copy("/shared/ssd_30T/yoon/exEYE/Eyeproject/train_medgemma_ft.py",os.path.join(".",exp_name,"train_medgemma_ft_copy.py"))
+    # if not os.path.exists(exp_name):
+    shutil.copy("/PHShome/sy1081/exeye/train_medgemma_ft.py",os.path.join(".",exp_name,"train_medgemma_ft_copy.py"))
 
    if phase == 'train':
        trainer.train()
        trainer.save_model(training_args.output_dir)
 
-        # custom_eval_metrics = run_custom_evaluation(trainer, val_dataset, val_labels)
-    # else:
-    #     ft_pipe = pipeline(
-    #         "image-text-to-text",
-    #         model=exp_name,
-    #         processor=processor,
-    #         torch_dtype=torch.bfloat16,
-    #     )
-
-    #     # Set `do_sample = False` for deterministic responses
-    #     ft_pipe.model.generation_config.do_sample = False
-    #     ft_pipe.model.generation_config.pad_token_id = processor.tokenizer.eos_token_id
-    #     # Use left padding during inference
-    #     processor.tokenizer.padding_side = "left"
-
-    #     texts = []
-    #     images = []
-
-    #     for example in val_dataset:
-    #         text = processor.apply_chat_template(
-    #             example["messages"], add_generation_prompt=True, tokenize=False
-    #         ).strip()
-    #         texts.append(text)
-    #         image = example["image"].convert("RGB").resize((512, 512))
-    #         images.append([image])  # wrap in a list: the batched format MedGEMMA expects
-
-    #     # pdb.set_trace()
-    #     ft_outputs = ft_pipe(
-    #         text=texts,
-    #         images=images,
-    #         max_new_tokens=5,
-    #         batch_size=1,
-    #         return_full_text=False,
-    #     )
 
    batch_size = 1
    model.eval()
@@ -391,7 +361,7 @@ if __name__ == '__main__':
            example["messages"], add_generation_prompt=True, tokenize=False
        ).strip()
        texts.append(text)
-        image = example["image"].convert("RGB").resize((512, 512))
+        image = example["image"].convert("RGB").resize((IM_SIZE, IM_SIZE))
        images.append([image])
 
    # tokenizer & image processor
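Not shown in this diff is how POS_ID / NEG_ID are consumed downstream. A hedged sketch of the scoring those constants suggest: pair the next-token logits for "positive" and "negative" into a probability suitable for roc_auc_score. The helper name and the use of the first sub-token id are assumptions, not the author's code.

```python
# Hypothetical helper, not from this repo: score one rendered prompt by
# comparing the logits of the "positive" vs "negative" answer tokens.
import torch

@torch.no_grad()
def positive_probability(model, inputs, pos_id, neg_id):
    # inputs: processor(text=..., images=..., return_tensors="pt") for a prompt
    # ending in "<start_of_turn>model\n"; pos_id/neg_id: first sub-token ids
    # from POS_ID / NEG_ID in the training script.
    logits = model(**inputs).logits[0, -1]       # next-token logits
    pair = torch.stack([logits[pos_id], logits[neg_id]])
    return torch.softmax(pair, dim=0)[0].item()  # P("positive" over "negative")
```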
training_args.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:61713d7b70980b1dac1979fbf4fa512bed3f7bbc0fa63cf78beb8efa0e918976
-size 5752
+oid sha256:2c0ab1f9caf759796d310240a8f917319ddf8b52bbe1f0b2c42b4b965b668b1c
+size 5816