Training in progress, epoch 1
- README.md +7 -7
- adapter_config.json +6 -6
- adapter_model.safetensors +2 -2
- chat_template.jinja +47 -0
- runs/Jul20_17-33-00_meedgxh100a/events.out.tfevents.1753047182.meedgxh100a.2023753.0 +3 -0
- tokenizer_config.json +0 -1
- train_medgemma_ft_copy.py +38 -68
- training_args.bin +2 -2
README.md
CHANGED
@@ -4,8 +4,8 @@ library_name: transformers
 model_name: medgemma-27b-it-dr5
 tags:
 - generated_from_trainer
-- trl
 - sft
+- trl
 licence: license
 ---
 
@@ -27,18 +27,18 @@ print(output["generated_text"])
 
 ## Training procedure
 
-[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>](https://wandb.ai/yoon307-kaist/medgemma-27b-it-dr5-Project/runs/
+[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>](https://wandb.ai/yoon307-kaist/medgemma-27b-it-dr5-Project/runs/6argv9kb)
 
 
 This model was trained with SFT.
 
 ### Framework versions
 
-- TRL: 0.19.
-- Transformers: 4.
-- Pytorch: 2.
-- Datasets:
-- Tokenizers: 0.21.
+- TRL: 0.19.1
+- Transformers: 4.53.2
+- Pytorch: 2.6.0+cu124
+- Datasets: 4.0.0
+- Tokenizers: 0.21.2
 
 ## Citations
 
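For orientation, here is a minimal sketch of the quick-start usage that the `print(output["generated_text"])` context line above comes from. The hub id is a placeholder for this repository (not confirmed by the diff), and it assumes `peft` is installed so the pipeline can resolve the adapter's base model from `adapter_config.json`:

```python
# Hypothetical quick-start; "yoon307/medgemma-27b-it-dr5" is a placeholder
# id for this adapter repo, and peft must be installed so transformers can
# load the LoRA weights on top of the base model automatically.
from transformers import pipeline

generator = pipeline("text-generation", model="yoon307/medgemma-27b-it-dr5", device="cuda")
output = generator(
    [{"role": "user", "content": "Is diabetic retinopathy visible in this case?"}],
    max_new_tokens=64,
    return_full_text=False,
)[0]
print(output["generated_text"])
```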
adapter_config.json
CHANGED
@@ -29,15 +29,15 @@
     "revision": null,
     "target_modules": [
         "v_proj",
-        "
+        "down_proj",
+        "o_proj",
+        "fc1",
         "fc2",
-        "k_proj",
-        "out_proj",
         "q_proj",
-        "
-        "
+        "k_proj",
+        "gate_proj",
         "up_proj",
-        "
+        "out_proj"
     ],
     "task_type": "CAUSAL_LM",
     "trainable_token_indices": null,
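This list is what `target_modules="all-linear"` in the training script expands to when the adapter is saved: `q/k/v/o`, `gate/up/down` projections from the language model, plus `fc1`, `fc2`, and `out_proj`, which are consistent with vision-tower linears. A small sketch to read the resolved list back from the saved config (assumes this commit's files are checked out in the current directory):

```python
# Sketch: read back adapter_config.json and list the resolved target modules.
# Assumes the adapter files from this commit are in the current directory.
from peft import PeftConfig

cfg = PeftConfig.from_pretrained(".")  # parses adapter_config.json
print(sorted(cfg.target_modules))
# ['down_proj', 'fc1', 'fc2', 'gate_proj', 'k_proj',
#  'o_proj', 'out_proj', 'q_proj', 'up_proj', 'v_proj']
```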
adapter_model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:f6b8752afa62eaf145b3ab7bcd63788ad169ed8f26b3a901c59c47ac67134b7b
+size 6127553104
chat_template.jinja
ADDED
@@ -0,0 +1,47 @@
+{{ bos_token }}
+{%- if messages[0]['role'] == 'system' -%}
+    {%- if messages[0]['content'] is string -%}
+        {%- set first_user_prefix = messages[0]['content'] + '
+
+' -%}
+    {%- else -%}
+        {%- set first_user_prefix = messages[0]['content'][0]['text'] + '
+
+' -%}
+    {%- endif -%}
+    {%- set loop_messages = messages[1:] -%}
+{%- else -%}
+    {%- set first_user_prefix = "" -%}
+    {%- set loop_messages = messages -%}
+{%- endif -%}
+{%- for message in loop_messages -%}
+    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
+        {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
+    {%- endif -%}
+    {%- if (message['role'] == 'assistant') -%}
+        {%- set role = "model" -%}
+    {%- else -%}
+        {%- set role = message['role'] -%}
+    {%- endif -%}
+    {{ '<start_of_turn>' + role + '
+' + (first_user_prefix if loop.first else "") }}
+    {%- if message['content'] is string -%}
+        {{ message['content'] | trim }}
+    {%- elif message['content'] is iterable -%}
+        {%- for item in message['content'] -%}
+            {%- if item['type'] == 'image' -%}
+                {{ '<start_of_image>' }}
+            {%- elif item['type'] == 'text' -%}
+                {{ item['text'] | trim }}
+            {%- endif -%}
+        {%- endfor -%}
+    {%- else -%}
+        {{ raise_exception("Invalid content type") }}
+    {%- endif -%}
+    {{ '<end_of_turn>
+' }}
+{%- endfor -%}
+{%- if add_generation_prompt -%}
+    {{'<start_of_turn>model
+'}}
+{%- endif -%}
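This template follows the Gemma turn format: system content is folded into the first user turn, the `assistant` role is renamed `model`, and image items render as `<start_of_image>` placeholders. A sketch of how it renders once shipped with a tokenizer (the checkpoint id is illustrative; the repo is gated, so access is assumed):

```python
# Rendering sketch; assumes access to a checkpoint that carries this template.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/medgemma-27b-it")
messages = [
    {"role": "system", "content": "Answer exactly 'positive' or 'negative'."},
    {"role": "user", "content": "Is Diabetic Retinopathy present?"},
]
print(tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False))
# Expected output:
# <bos><start_of_turn>user
# Answer exactly 'positive' or 'negative'.
#
# Is Diabetic Retinopathy present?<end_of_turn>
# <start_of_turn>model
```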
runs/Jul20_17-33-00_meedgxh100a/events.out.tfevents.1753047182.meedgxh100a.2023753.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0c56c796f26a94c0ecce8c68b316405d52cda87acf40d7bc948259609714e558
+size 9269
tokenizer_config.json
CHANGED
@@ -51325,7 +51325,6 @@
     },
     "boi_token": "<start_of_image>",
     "bos_token": "<bos>",
-    "chat_template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n    {%- if messages[0]['content'] is string -%}\n        {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n    {%- else -%}\n        {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n    {%- endif -%}\n    {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n    {%- set first_user_prefix = \"\" -%}\n    {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n        {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n    {%- endif -%}\n    {%- if (message['role'] == 'assistant') -%}\n        {%- set role = \"model\" -%}\n    {%- else -%}\n        {%- set role = message['role'] -%}\n    {%- endif -%}\n    {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n    {%- if message['content'] is string -%}\n        {{ message['content'] | trim }}\n    {%- elif message['content'] is iterable -%}\n        {%- for item in message['content'] -%}\n            {%- if item['type'] == 'image' -%}\n                {{ '<start_of_image>' }}\n            {%- elif item['type'] == 'text' -%}\n                {{ item['text'] | trim }}\n            {%- endif -%}\n        {%- endfor -%}\n    {%- else -%}\n        {{ raise_exception(\"Invalid content type\") }}\n    {%- endif -%}\n    {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n    {{'<start_of_turn>model\n'}}\n{%- endif -%}\n",
     "clean_up_tokenization_spaces": false,
     "eoi_token": "<end_of_image>",
     "eos_token": "<eos>",
train_medgemma_ft_copy.py
CHANGED
@@ -24,17 +24,17 @@ from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
 
 # === Custom ===
-import tools.imutils as imutils
-import tools.utils as utils
-import tools.pyutils as pyutils
-from tools.utils import compute_es_auc, compute_group_auc, ImprovedBalancedBatchSampler, compute_es_auc_multi
+# import tools.imutils as imutils
+# import tools.utils as utils
+# import tools.pyutils as pyutils
+# from tools.utils import compute_es_auc, compute_group_auc, ImprovedBalancedBatchSampler, compute_es_auc_multi
 
 # === Evaluation ===
 from sklearn.metrics import roc_curve, accuracy_score, roc_auc_score
 
 # === Transformers ===
-from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig, pipeline
-from peft import LoraConfig, get_peft_model
+from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig, pipeline, AutoModelForCausalLM
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
 from trl import SFTTrainer, SFTConfig
 import wandb
 
@@ -54,7 +54,7 @@ def collate_fn(examples):
     images = []
     for example in examples:
         image = example["image"].convert("RGB")
-        image = image.resize((
+        image = image.resize((IM_SIZE,IM_SIZE))
         images.append([image])
         texts.append(processor.apply_chat_template(
             example["messages"], add_generation_prompt=False, tokenize=False
@@ -121,14 +121,7 @@ def format_data_for_inference(sample):
         ]},
         # {"role": "assistant", "content": [{"type": "text", "text": str(label)}]}
     ]
-
-    # return [
-    #     {"role": "system", "content": [{"type": "text", "text": system_message}]},
-    #     {"role": "user", "content": [
-    #         {"type": "image", "image": os.path.join(img_root_path, sample[1])},
-    #         {"type": "text", "text": prompt}
-    #     ]}
-    # ]
+
     return example
 
 # === Logit Preprocessing ===
@@ -191,8 +184,10 @@ if __name__ == '__main__':
     parser.add_argument("--name", required=True)
    parser.add_argument("--use_subset", action='store_true')
    args = parser.parse_args()
+
+    random.seed(42)
 
-    pyutils.same_seeds(0)
+    # pyutils.same_seeds(0)
 
    task_map = {'dr': (-3, 'Diabetic Retinopathy'), 'amd': (-2, 'Aged Macular Degeneration'), 'glaucoma': (-1, 'Glaucoma')}
    task_idx, disease_name = task_map[args.task]
@@ -204,13 +199,13 @@ if __name__ == '__main__':
     3. Avoid overexplaining unless requested.\n
     4. Tone: confident, professional, precise.\n
     Do not include any explanation or thought.\n
+    Diabetic Retinopathy (DR) is a diabetes-related eye disease that affects the retina — the light-sensitive tissue at the back of the eye. It occurs when chronically high blood sugar levels damage the small blood vessels in the retina, leading to leakage, blockage, or abnormal blood vessel growth.\n
     If {disease_name} is present, answer exactly 'positive'. Otherwise answer 'negative'."""
-    # Diabetic Retinopathy (DR) is a diabetes-related eye disease that affects the retina — the light-sensitive tissue at the back of the eye. It occurs when chronically high blood sugar levels damage the small blood vessels in the retina, leading to leakage, blockage, or abnormal blood vessel growth.\n
 
     cudnn.benchmark = True
-    img_root_path = '/
-    train_dataset = np.load('/
-    val_dataset_raw = np.load('/
+    img_root_path = '/PHShome/sy1081/exeye/data'
+    train_dataset = np.load('/PHShome/sy1081/exeye/data/train_final.npy')
+    val_dataset_raw = np.load('/PHShome/sy1081/exeye/data/val_final.npy')
 
     if args.use_subset:
         def subset(data,train=True):
@@ -218,11 +213,11 @@ if __name__ == '__main__':
             pos = [s for s in data if s[task_idx] != '0.0']
             num_sample = len(pos)
             if train:
-                return random.sample(neg,
+                return random.sample(neg, 5*num_sample), random.sample(pos, num_sample)
             else:
-                return random.sample(neg,
+                return random.sample(neg, num_sample), pos
             # return random.sample(neg, 15), random.sample(pos, 15)
-            # return neg,
+            # return neg, pos
         train_dataset = sum(subset(train_dataset,train=True), [])
         val_dataset_raw = sum(subset(val_dataset_raw,train=False), [])
 
@@ -235,7 +230,8 @@ if __name__ == '__main__':
     print(f"Total number of Data| Train: {len(train_dataset)} | Val : {len(val_dataset)}")
     print("="*50)
 
-    model_id = "google/medgemma-4b-it"
+    # model_id = "google/medgemma-4b-it"
+    model_id = "google/medgemma-27b-it"
     model_kwargs = dict(
         attn_implementation="eager",
         torch_dtype=torch.bfloat16,
@@ -250,23 +246,29 @@ if __name__ == '__main__':
         bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
     )
 
-    model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
+    # model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        **model_kwargs
+        # torch_dtype=torch.bfloat16,
+        # device_map="auto",
+    )
     processor = AutoProcessor.from_pretrained(model_id)
 
     # Use right padding to avoid issues during training
     processor.tokenizer.padding_side = "right"
-    # processor.image_processor.size = {"height": 512, "width": 512}
-    # processor.image_processor.crop_size = {"height": 512, "width": 512}
 
     POS_ID = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize("positive")) #30558
     NEG_ID = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize("negative")) #27851
     ASST_ID = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize("model\n"))
 
+    IM_SIZE = 1024
 
     peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.05,
-        r=
+        r=16,
        bias="none",
        target_modules="all-linear",
        # target_modules=["q_proj", "v_proj"],
@@ -284,11 +286,12 @@ if __name__ == '__main__':
         from peft import PeftModel
         print("🔁 Loading trained PEFT weights...")
         model = PeftModel.from_pretrained(model, exp_name)
-        # model = PeftModel.from_pretrained(model, exp_name+"/checkpoint-
+        # model = PeftModel.from_pretrained(model, exp_name+"/checkpoint-690")
         # model = PeftModel.from_pretrained(model, "llava-1.5-7b-hf-dr-all/checkpoint-80")
         phase= "val"
     else:
         print("🚀 Initializing new LoRA model...")
+        # model = prepare_model_for_kbit_training(model)
         model = get_peft_model(model, peft_config)
         model.print_trainable_parameters()
         phase= "train"
@@ -296,7 +299,7 @@ if __name__ == '__main__':
 
     training_args = SFTConfig(
         output_dir=exp_name,
-        num_train_epochs=
+        num_train_epochs= 20, # Number of training epochs
         per_device_train_batch_size=4, # Batch size per device during training
         per_device_eval_batch_size=4, # Batch size per device during evaluation
         gradient_accumulation_steps=8, # Number of steps before performing a backward/update pass
@@ -306,11 +309,12 @@ if __name__ == '__main__':
         save_strategy="epoch", # Save checkpoint every epoch
         eval_strategy="steps", # Evaluate every `eval_steps`
         eval_steps=10000, # Number of steps between evaluations
-        learning_rate=
+        learning_rate=5e-4, # Learning rate based on QLoRA paper
         bf16=True, # Use bfloat16 precision
         max_grad_norm=0.3, # Max gradient norm based on QLoRA paper
         warmup_ratio=0.03, # Warmup ratio based on QLoRA paper
         lr_scheduler_type="linear", # Use linear learning rate scheduler
+        # lr_scheduler_type="constant", # Use linear learning rate scheduler
         push_to_hub=True, # Push model to Hub
         report_to="tensorboard", # Report metrics to tensorboard
         gradient_checkpointing_kwargs={"use_reentrant": False}, # Set gradient checkpointing to non-reentrant to avoid issues
@@ -334,47 +338,13 @@ if __name__ == '__main__':
         # preprocess_logits_for_metrics=slice_logits,
     )
 
-    if not os.path.exists(exp_name):
-
+    # if not os.path.exists(exp_name):
+    shutil.copy("/PHShome/sy1081/exeye/train_medgemma_ft.py",os.path.join(".",exp_name,"train_medgemma_ft_copy.py"))
 
     if phase == 'train':
         trainer.train()
         trainer.save_model(training_args.output_dir)
 
-        # custom_eval_metrics = run_custom_evaluation(trainer, val_dataset, val_labels)
-    # else:
-    #     ft_pipe = pipeline(
-    #         "image-text-to-text",
-    #         model=exp_name,
-    #         processor=processor,
-    #         torch_dtype=torch.bfloat16,
-    #     )
-
-    #     # Set `do_sample = False` for deterministic responses
-    #     ft_pipe.model.generation_config.do_sample = False
-    #     ft_pipe.model.generation_config.pad_token_id = processor.tokenizer.eos_token_id
-    #     # Use left padding during inference
-    #     processor.tokenizer.padding_side = "left"
-
-    #     texts = []
-    #     images = []
-
-    #     for example in val_dataset:
-    #         text = processor.apply_chat_template(
-    #             example["messages"], add_generation_prompt=True, tokenize=False
-    #         ).strip()
-    #         texts.append(text)
-    #         image = example["image"].convert("RGB").resize((512, 512))
-    #         images.append([image])  # must wrap in a list: the batched format MedGEMMA expects
-
-    #         # pdb.set_trace()
-    #     ft_outputs = ft_pipe(
-    #         text=texts,
-    #         images=images,
-    #         max_new_tokens=5,
-    #         batch_size=1,
-    #         return_full_text=False,
-    #     )
 
     batch_size = 1
     model.eval()
@@ -391,7 +361,7 @@ if __name__ == '__main__':
             example["messages"], add_generation_prompt=True, tokenize=False
         ).strip()
         texts.append(text)
-        image = example["image"].convert("RGB").resize((
+        image = example["image"].convert("RGB").resize((IM_SIZE, IM_SIZE))
         images.append([image])
 
     # tokenizer & image processor
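The script tokenizes "positive" and "negative" into POS_ID / NEG_ID, but the diff does not show how they are consumed downstream. A hedged sketch of the closed-set readout those ids enable: score each case from the next-token logits at the generation position and feed the scores to roc_auc_score. Variable names mirror the script; `model`, `processor`, `texts`, `images`, and binary ground-truth `labels` are assumed in scope, with a multimodal checkpoint loaded via AutoModelForImageTextToText (the class the script originally used):

```python
# Hedged sketch of a closed-set "positive"/"negative" readout using the
# POS_ID / NEG_ID token ids the script defines. Assumes a multimodal model
# that accepts pixel_values, plus texts/images/labels already prepared.
import torch
from sklearn.metrics import roc_auc_score

pos_id, neg_id = POS_ID[0], NEG_ID[0]  # first sub-token of each answer word

scores = []
for text, image in zip(texts, images):
    inputs = processor(text=text, images=image, return_tensors="pt").to(model.device)
    with torch.no_grad():
        next_logits = model(**inputs).logits[0, -1]  # logits for the next token
    pair = torch.stack([next_logits[neg_id], next_logits[pos_id]])
    scores.append(torch.softmax(pair, dim=0)[1].item())  # P(positive | {neg, pos})

print("val AUC:", roc_auc_score(labels, scores))
```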
training_args.bin
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:2c0ab1f9caf759796d310240a8f917319ddf8b52bbe1f0b2c42b4b965b668b1c
+size 5816