Training in progress, epoch 0
- .gitattributes +1 -0
- README.md +58 -0
- adapter_config.json +47 -0
- adapter_model.safetensors +3 -0
- added_tokens.json +3 -0
- runs/Jul18_11-30-32_seribizon/events.out.tfevents.1752852638.seribizon.200344.0 +3 -0
- special_tokens_map.json +33 -0
- tokenizer.json +3 -0
- tokenizer.model +3 -0
- tokenizer_config.json +0 -0
- train_medgemma_ft_copy.py +432 -0
- training_args.bin +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,58 @@
---
base_model: google/medgemma-27b-it
library_name: transformers
model_name: medgemma-27b-it-dr5
tags:
- generated_from_trainer
- trl
- sft
licence: license
---

# Model Card for medgemma-27b-it-dr5

This model is a fine-tuned version of [google/medgemma-27b-it](https://huggingface.co/google/medgemma-27b-it).
It has been trained using [TRL](https://github.com/huggingface/trl).

## Quick start

```python
from transformers import pipeline

question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?"
generator = pipeline("text-generation", model="berkamphoon/medgemma-27b-it-dr5", device="cuda")
output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0]
print(output["generated_text"])
```

## Training procedure

[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>](https://wandb.ai/yoon307-kaist/medgemma-27b-it-dr5-Project/runs/mbxoj7k5)

This model was trained with SFT.

### Framework versions

- TRL: 0.19.0
- Transformers: 4.51.3
- Pytorch: 2.5.0
- Datasets: 3.6.0
- Tokenizers: 0.21.1

## Citations

Cite TRL as:

```bibtex
@misc{vonwerra2022trl,
    title = {{TRL: Transformer Reinforcement Learning}},
    author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec},
    year = 2020,
    journal = {GitHub repository},
    publisher = {GitHub},
    howpublished = {\url{https://github.com/huggingface/trl}}
}
```
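The weights pushed in this commit are a PEFT LoRA adapter (see adapter_model.safetensors and adapter_config.json below) rather than merged weights, so a loading sketch may be useful alongside the card's pipeline example. This is an illustration, not part of the card: it assumes the `transformers`/`peft` calls shown, assumes enough memory for the 27B base checkpoint, and omits the 4-bit quantization the training script uses.

```python
# Illustration only (not from the model card): load the LoRA adapter in this repo
# on top of the base checkpoint with PEFT.
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from peft import PeftModel

base = AutoModelForImageTextToText.from_pretrained(
    "google/medgemma-27b-it", torch_dtype=torch.bfloat16, device_map="auto"
)
model = PeftModel.from_pretrained(base, "berkamphoon/medgemma-27b-it-dr5")
processor = AutoProcessor.from_pretrained("google/medgemma-27b-it")
```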
adapter_config.json
ADDED
@@ -0,0 +1,47 @@
{
  "alpha_pattern": {},
  "auto_mapping": null,
  "base_model_name_or_path": "google/medgemma-27b-it",
  "bias": "none",
  "corda_config": null,
  "eva_config": null,
  "exclude_modules": null,
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layer_replication": null,
  "layers_pattern": null,
  "layers_to_transform": null,
  "loftq_config": {},
  "lora_alpha": 16,
  "lora_bias": false,
  "lora_dropout": 0.05,
  "megatron_config": null,
  "megatron_core": "megatron.core",
  "modules_to_save": [
    "lm_head",
    "embed_tokens"
  ],
  "peft_type": "LORA",
  "qalora_group_size": 16,
  "r": 16,
  "rank_pattern": {},
  "revision": null,
  "target_modules": [
    "v_proj",
    "gate_proj",
    "fc2",
    "k_proj",
    "out_proj",
    "q_proj",
    "o_proj",
    "down_proj",
    "up_proj",
    "fc1"
  ],
  "task_type": "CAUSAL_LM",
  "trainable_token_indices": null,
  "use_dora": false,
  "use_qalora": false,
  "use_rslora": false
}
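Read as a PEFT LoRA setup: rank 16, alpha 16, dropout 0.05, adapters on the attention and MLP projection layers, with `lm_head` and `embed_tokens` trained in full. A rough Python equivalent is sketched below; the values are copied from the JSON above, and this is not the exact call in the training script further down (which targets `all-linear` with r=32).

```python
# Illustrative reconstruction of the core of adapter_config.json as a peft.LoraConfig.
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules=[
        "v_proj", "gate_proj", "fc2", "k_proj", "out_proj",
        "q_proj", "o_proj", "down_proj", "up_proj", "fc1",
    ],
    modules_to_save=["lm_head", "embed_tokens"],
    task_type="CAUSAL_LM",
)
```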
adapter_model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:193d5bc56b229dcd16a327b58b3d06056ba3d4a25c915706b577c5185a762759
size 11766077184
added_tokens.json
ADDED
@@ -0,0 +1,3 @@
{
  "<image_soft_token>": 262144
}
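The single added token is the image placeholder used by the multimodal chat template; its id, 262144, is the one the training script below masks out of the loss. A small check, assuming the base model's processor, is sketched here.

```python
# Hedged check: the token registered in added_tokens.json is the image placeholder
# whose id the training script sets to -100 so it is ignored by the loss.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("google/medgemma-27b-it")
print(processor.tokenizer.convert_tokens_to_ids("<image_soft_token>"))  # expected: 262144
```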
runs/Jul18_11-30-32_seribizon/events.out.tfevents.1752852638.seribizon.200344.0
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2974ec26d6853b6c599a617228a9030894a8d4f99e3aec1935aa4d2dda5b0979
size 9306
special_tokens_map.json
ADDED
@@ -0,0 +1,33 @@
{
  "boi_token": "<start_of_image>",
  "bos_token": {
    "content": "<bos>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eoi_token": "<end_of_image>",
  "eos_token": {
    "content": "<eos>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "image_token": "<image_soft_token>",
  "pad_token": {
    "content": "<pad>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
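These are the standard Gemma-style special tokens. The pad token declared here is the one the collator in the training script masks out of the labels, and the script pads on the right during training while left padding is the usual choice for batched generation. A minimal sketch, assuming the base model's processor:

```python
# Small sketch of how the pad token and padding side are handled around fine-tuning.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("google/medgemma-27b-it")
print(processor.tokenizer.pad_token)        # "<pad>", as declared in this file
processor.tokenizer.padding_side = "right"  # setting used during fine-tuning
```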
tokenizer.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1ebf1915455f8237564395182c49e3c685cfe3533b3d50ec6d49ce65ec43c32e
size 33384723
tokenizer.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c
size 4689074
tokenizer_config.json
ADDED
The diff for this file is too large to render.
train_medgemma_ft_copy.py
ADDED
@@ -0,0 +1,432 @@
from __future__ import division, print_function


# medqa gsm8k openbookqa bioasq pubmedqa squad_v2

# === Base ===
import os
import os.path as osp
import random
import argparse
import logging
from tqdm import tqdm
from matplotlib import pyplot as plt
import pdb
from PIL import Image
import shutil
import os

# === DL ===
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# === Custom ===
import tools.imutils as imutils
import tools.utils as utils
import tools.pyutils as pyutils
from tools.utils import compute_es_auc, compute_group_auc, ImprovedBalancedBatchSampler, compute_es_auc_multi

# === Evaluation ===
from sklearn.metrics import roc_curve, accuracy_score, roc_auc_score

# === Transformers ===
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig, pipeline
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer, SFTConfig
import wandb

# === Label Masking Function ===
def mask_until_after_assistant(labels: torch.Tensor, tokenizer, assistant_token_ids: list):
    for i in range(labels.size(0)):
        for j in range(labels.size(1) - len(assistant_token_ids) + 1):
            if torch.equal(labels[i, j:j+len(assistant_token_ids)], torch.tensor(assistant_token_ids, device=labels.device)):
                labels[i, :j + len(assistant_token_ids)] = -100  # mask everything up to and including the assistant marker
                break
    return labels


# === Collate Function ===
def collate_fn(examples):
    texts = []
    images = []
    for example in examples:
        image = example["image"].convert("RGB")
        image = image.resize((512, 512))
        images.append([image])
        texts.append(processor.apply_chat_template(
            example["messages"], add_generation_prompt=False, tokenize=False
        ).strip())

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # The labels are the input_ids, with the padding and image tokens masked in
    # the loss computation
    labels = batch["input_ids"].clone()

    # Mask image tokens
    image_token_id = [
        processor.tokenizer.convert_tokens_to_ids(
            processor.tokenizer.special_tokens_map["boi_token"]
        )
    ]
    # Mask tokens that are not used in the loss computation
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    labels[labels == 262144] = -100

    labels = mask_until_after_assistant(labels, processor.tokenizer, ASST_ID)
    labels[:, -1] = -100

    batch["labels"] = labels
    # pdb.set_trace()
    return batch

def format_data(sample):
    label = 'negative' if sample[task_idx] == '0.0' else 'positive'
    prompt = f"Please diagnose whether the {disease_name} exist or not based on the given image.\n"

    # pdb.set_trace()
    example = {}
    example["image"] = Image.open(os.path.join(img_root_path, sample[1]))
    example["label"] = 0 if sample[task_idx] == '0.0' else 1
    example["messages"] = [
        {"role": "system", "content": [{"type": "text", "text": system_message}]},
        {"role": "user", "content": [
            # {"type": "image", "image": os.path.join(img_root_path, sample[1])},
            {"type": "image"},
            {"type": "text", "text": prompt},
        ]},
        {"role": "assistant", "content": [{"type": "text", "text": str(label)}]}
    ]

    return example

def format_data_for_inference(sample):
    prompt = f"Please diagnose whether the {disease_name} exist or not based on the given image.\n"

    # pdb.set_trace()
    example = {}
    example["image"] = Image.open(os.path.join(img_root_path, sample[1]))
    # example["label"] = 0 if sample[task_idx] == '0.0' else 1
    example["messages"] = [
        {"role": "system", "content": [{"type": "text", "text": system_message}]},
        {"role": "user", "content": [
            # {"type": "image", "image": os.path.join(img_root_path, sample[1])},
            {"type": "image"},
            {"type": "text", "text": prompt + "\n"},
        ]},
        # {"role": "assistant", "content": [{"type": "text", "text": str(label)}]}
    ]
    return example

# === Logit Preprocessing ===
def slice_logits(logits, labels):
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    return logits.detach().cpu()

def compute_metrics(eval_pred):
    logits = torch.tensor(eval_pred.predictions)

    token_ids = logits.argmax(dim=-1)  # (B, L): predicted token at each position

    batch_logits = []
    for b in range(logits.size(0)):
        seq = token_ids[b]  # (L,)
        idxs = torch.where((seq == POS_ID[0]) | (seq == NEG_ID[0]))[0]
        if len(idxs) == 0:
            raise ValueError(f"Neither pos_id nor neg_id found in sequence {b}")
        t = idxs[0].item()  # first position where pos or neg appears
        tok_id = seq[t].item()  # should be either pos_id or neg_id
        batch_logits.append(logits[b, t, tok_id])  # scalar

    batch_logits = torch.stack(batch_logits)  # shape: [B]
    pred_texts = processor.tokenizer.batch_decode(token_ids[:, -1], skip_special_tokens=True)

    # print(pred_texts)
    # pdb.set_trace()
    probs = torch.sigmoid(logits[:, -1, POS_ID[0]] - logits[:, -1, NEG_ID[0]]).numpy()

    # probs = torch.sigmoid(batch_logits).numpy()
    labels = torch.tensor(eval_pred.label_ids)
    gt_ids = labels[labels != -100].view(logits.size(0), -1)[:, 0]
    y_true = (gt_ids == POS_ID[0]).int().cpu().numpy()
    auc_val = roc_auc_score(y_true, probs)
    fpr, tpr, thr = roc_curve(y_true, probs)
    best = thr[np.argmax(tpr - fpr)]
    acc = accuracy_score(y_true, probs >= best)
    return {"roc_auc": auc_val, "accuracy": acc}

def run_custom_evaluation(trainer, val_dataset, val_labels):
    outputs = trainer.predict(val_dataset)
    logits = torch.from_numpy(outputs.predictions)  # (B, S, L)
    # pdb.set_trace()
    probs = torch.sigmoid(logits[:, -1, POS_ID[0]] - logits[:, -1, NEG_ID[0]]).numpy()

    # decoded = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    # y_pred = [1 if "positive" in t.lower() else 0 for t in decoded]

    auc_val = roc_auc_score(val_labels, probs)
    # acc = accuracy_score(val_labels, y_pred)
    print(f"[Custom Eval] AUC: {auc_val:.4f}")
    # print(f"[Custom Eval] AUC: {auc_val:.4f}, ACC: {acc:.4f}")
    return {"auc": auc_val}

# === Main ===
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", required=True, help='amd, dr, glaucoma')
    parser.add_argument("--name", required=True)
    parser.add_argument("--use_subset", action='store_true')
    args = parser.parse_args()

    pyutils.same_seeds(0)

    task_map = {'dr': (-3, 'Diabetic Retinopathy'), 'amd': (-2, 'Aged Macular Degeneration'), 'glaucoma': (-1, 'Glaucoma')}
    task_idx, disease_name = task_map[args.task]
    system_message = f"""You are an expert AI in ophthalmology.\n
Your primary role is to provide accurate, reliable, and up-to-date medical knowledge based on credible sources.\n
You must follow these guidelines:\n
1. Be accurate, concise, and clinically relevant.\n
2. Use proper medical terms.\n
3. Avoid overexplaining unless requested.\n
4. Tone: confident, professional, precise.\n
Do not include any explanation or thought.\n
If {disease_name} is present, answer exactly 'positive'. Otherwise answer 'negative'."""
    # Diabetic Retinopathy (DR) is a diabetes-related eye disease that affects the retina — the light-sensitive tissue at the back of the eye. It occurs when chronically high blood sugar levels damage the small blood vessels in the retina, leading to leakage, blockage, or abnormal blood vessel growth.\n

    cudnn.benchmark = True
    img_root_path = '/shared/ssd_30T/yoon/exEYE/Eyeproject/data'
    train_dataset = np.load('/shared/ssd_30T/yoon/exEYE/datasplit/train_final.npy')
    val_dataset_raw = np.load('/shared/ssd_30T/yoon/exEYE/datasplit/val_final.npy')

    if args.use_subset:
        def subset(data, train=True):
            neg = [s for s in data if s[task_idx] == '0.0']
            pos = [s for s in data if s[task_idx] != '0.0']
            num_sample = len(pos)
            if train:
                return random.sample(neg, 7*num_sample), random.sample(pos, num_sample)
            else:
                return random.sample(neg, 3*num_sample), random.sample(pos, num_sample)
            # return random.sample(neg, 15), random.sample(pos, 15)
            # return neg, random.sample(pos, num_sample)
        train_dataset = sum(subset(train_dataset, train=True), [])
        val_dataset_raw = sum(subset(val_dataset_raw, train=False), [])

    train_dataset = [format_data(s) for s in tqdm(train_dataset)]
    random.shuffle(train_dataset)
    val_dataset = [format_data_for_inference(s) for s in tqdm(val_dataset_raw)]
    val_labels = [1 if s[task_idx] != '0.0' else 0 for s in val_dataset_raw]
    # val_dataset = [format_data(s) for s in tqdm(val_dataset)]
    print("="*50)
    print(f"Total number of Data| Train: {len(train_dataset)} | Val : {len(val_dataset)}")
    print("="*50)

    model_id = "google/medgemma-4b-it"
    model_kwargs = dict(
        attn_implementation="eager",
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    model_kwargs["quantization_config"] = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
        bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
    )

    model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
    processor = AutoProcessor.from_pretrained(model_id)

    # Use right padding to avoid issues during training
    processor.tokenizer.padding_side = "right"
    # processor.image_processor.size = {"height": 512, "width": 512}
    # processor.image_processor.crop_size = {"height": 512, "width": 512}

    POS_ID = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize("positive"))  # 30558
    NEG_ID = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize("negative"))  # 27851
    ASST_ID = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize("model\n"))


    peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.05,
        r=32,
        bias="none",
        target_modules="all-linear",
        # target_modules=["q_proj", "v_proj"],
        task_type="CAUSAL_LM",
        modules_to_save=[
            "lm_head",
            "embed_tokens",
        ],
    )


    exp_name = f"{model_id.split('/')[-1]}-{args.name}"

    if os.path.exists(exp_name):
        from peft import PeftModel
        print("🔁 Loading trained PEFT weights...")
        model = PeftModel.from_pretrained(model, exp_name)
        # model = PeftModel.from_pretrained(model, exp_name + "/checkpoint-242")
        # model = PeftModel.from_pretrained(model, "llava-1.5-7b-hf-dr-all/checkpoint-80")
        phase = "val"
    else:
        print("🚀 Initializing new LoRA model...")
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
        phase = "train"


    training_args = SFTConfig(
        output_dir=exp_name,
        num_train_epochs=16,  # Number of training epochs
        per_device_train_batch_size=4,  # Batch size per device during training
        per_device_eval_batch_size=4,  # Batch size per device during evaluation
        gradient_accumulation_steps=8,  # Number of steps before performing a backward/update pass
        gradient_checkpointing=True,  # Enable gradient checkpointing to reduce memory usage
        optim="adamw_torch_fused",  # Use fused AdamW optimizer for better performance
        logging_steps=10,  # Number of steps between logs
        save_strategy="epoch",  # Save checkpoint every epoch
        eval_strategy="steps",  # Evaluate every `eval_steps`
        eval_steps=10000,  # Number of steps between evaluations
        learning_rate=8e-4,  # Learning rate based on QLoRA paper
        bf16=True,  # Use bfloat16 precision
        max_grad_norm=0.3,  # Max gradient norm based on QLoRA paper
        warmup_ratio=0.03,  # Warmup ratio based on QLoRA paper
        lr_scheduler_type="linear",  # Use linear learning rate scheduler
        push_to_hub=True,  # Push model to Hub
        report_to="tensorboard",  # Report metrics to tensorboard
        gradient_checkpointing_kwargs={"use_reentrant": False},  # Set gradient checkpointing to non-reentrant to avoid issues
        dataset_kwargs={"skip_prepare_dataset": True},  # Skip default dataset preparation to preprocess manually
        remove_unused_columns=False,  # Columns are unused for training but needed for data collator
        label_names=["labels"],
    )
    # training_args.remove_unused_columns = False

    wandb.init(project=f"{exp_name}-Project", name=exp_name, config=training_args)

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=collate_fn,
        peft_config=peft_config,
        processing_class=processor.tokenizer,
        # compute_metrics=compute_metrics,
        # preprocess_logits_for_metrics=slice_logits,
    )

    if not os.path.exists(exp_name):
        shutil.copy("/shared/ssd_30T/yoon/exEYE/Eyeproject/train_medgemma_ft.py", os.path.join(".", exp_name, "train_medgemma_ft_copy.py"))

    if phase == 'train':
        trainer.train()
        trainer.save_model(training_args.output_dir)

        # custom_eval_metrics = run_custom_evaluation(trainer, val_dataset, val_labels)
    # else:
    #     ft_pipe = pipeline(
    #         "image-text-to-text",
    #         model=exp_name,
    #         processor=processor,
    #         torch_dtype=torch.bfloat16,
    #     )

    #     # Set `do_sample = False` for deterministic responses
    #     ft_pipe.model.generation_config.do_sample = False
    #     ft_pipe.model.generation_config.pad_token_id = processor.tokenizer.eos_token_id
    #     # Use left padding during inference
    #     processor.tokenizer.padding_side = "left"

    #     texts = []
    #     images = []

    #     for example in val_dataset:
    #         text = processor.apply_chat_template(
    #             example["messages"], add_generation_prompt=True, tokenize=False
    #         ).strip()
    #         texts.append(text)
    #         image = example["image"].convert("RGB").resize((512, 512))
    #         images.append([image])  # wrap in a list to get the batched format MedGEMMA expects

    #     # pdb.set_trace()
    #     ft_outputs = ft_pipe(
    #         text=texts,
    #         images=images,
    #         max_new_tokens=5,
    #         batch_size=1,
    #         return_full_text=False,
    #     )

    batch_size = 1
    model.eval()
    all_logits = []

    for i in tqdm(range(0, len(val_dataset), batch_size), desc="Running inference with logits"):
        batch = val_dataset[i:i + batch_size]

        # prepare inputs
        texts = []
        images = []
        for example in batch:
            text = processor.apply_chat_template(
                example["messages"], add_generation_prompt=True, tokenize=False
            ).strip()
            texts.append(text)
            image = example["image"].convert("RGB").resize((512, 512))
            images.append([image])

        # tokenizer & image processor
        with torch.no_grad():
            texts[0] += "\n"
            inputs = processor(
                text=texts,
                images=images,
                return_tensors="pt",
                padding=True
            ).to(model.device)

            outputs = model(**inputs, output_hidden_states=False, return_dict=True)

            print("==> ", processor.tokenizer.decode(outputs.logits[0].argmax(-1)[-1]))

            logits = outputs.logits
            # pdb.set_trace()
            probs = torch.sigmoid(logits[0, -1, POS_ID] - logits[0, -1, NEG_ID])
            # logits: (B, L, V)
            # all_logits.append(outputs.logits.to(torch.float32).detach().cpu().numpy())
            all_logits.append(probs)

    # pdb.set_trace()

    probs_all = torch.stack(all_logits, dim=0)
    probs_all = [prob.to(torch.float32).detach().cpu() for prob in probs_all]
    # logits = torch.from_numpy(np.stack(all_logits, axis=0)).squeeze(1)

    # probs = torch.sigmoid(logits[:, -1, POS_ID] - logits[:, -1, NEG_ID])

    # decoded = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    # y_pred = [1 if "positive" in t.lower() else 0 for t in decoded]
    # pdb.set_trace()
    auc_val = roc_auc_score(val_labels, probs_all)
    print(auc_val)

    # print(trainer.evaluate())
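The script is driven by argparse (`--task` is one of `amd`, `dr`, or `glaucoma`, `--name` sets the experiment suffix, and `--use_subset` optionally rebalances the split), so an invocation presumably looks like `python train_medgemma_ft_copy.py --task dr --name dr5 --use_subset`. Its evaluation loop scores each image with the sigmoid of the logit gap between the "positive" and "negative" token ids at the final position and then computes ROC AUC over those scores; a distilled sketch with toy tensors follows.

```python
# Distilled from the evaluation loop above; tensors and labels here are toy stand-ins.
import torch
from sklearn.metrics import roc_auc_score

POS_ID, NEG_ID = 30558, 27851        # ids noted in the script's comments
logits = torch.randn(4, 12, 262208)  # stand-in for (batch, seq_len, vocab) model output
scores = torch.sigmoid(logits[:, -1, POS_ID] - logits[:, -1, NEG_ID])
val_labels = [0, 1, 0, 1]            # toy ground-truth labels
print(roc_auc_score(val_labels, scores.numpy()))
```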
training_args.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:61713d7b70980b1dac1979fbf4fa512bed3f7bbc0fa63cf78beb8efa0e918976
size 5752