Training in progress, epoch 1
Browse files- README.md +7 -7
- adapter_config.json +7 -7
- adapter_model.safetensors +2 -2
- chat_template.jinja +47 -0
- runs/Jul19_23-56-17_meedgxh100a/events.out.tfevents.1752983779.meedgxh100a.1020669.0 +3 -0
- tokenizer_config.json +0 -1
- train_medgemma_ft_copy.py +402 -0
- training_args.bin +2 -2
README.md
CHANGED
|
@@ -4,8 +4,8 @@ library_name: transformers
|
|
| 4 |
model_name: medgemma-27b-it-dr4
|
| 5 |
tags:
|
| 6 |
- generated_from_trainer
|
| 7 |
-
- trl
|
| 8 |
- sft
|
|
|
|
| 9 |
licence: license
|
| 10 |
---
|
| 11 |
|
|
@@ -27,18 +27,18 @@ print(output["generated_text"])
|
|
| 27 |
|
| 28 |
## Training procedure
|
| 29 |
|
| 30 |
-
[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>](https://wandb.ai/yoon307-kaist/medgemma-27b-it-dr4-Project/runs/
|
| 31 |
|
| 32 |
|
| 33 |
This model was trained with SFT.
|
| 34 |
|
| 35 |
### Framework versions
|
| 36 |
|
| 37 |
-
- TRL: 0.19.
|
| 38 |
-
- Transformers: 4.
|
| 39 |
-
- Pytorch: 2.
|
| 40 |
-
- Datasets:
|
| 41 |
-
- Tokenizers: 0.21.
|
| 42 |
|
| 43 |
## Citations
|
| 44 |
|
|
|
|
| 4 |
model_name: medgemma-27b-it-dr4
|
| 5 |
tags:
|
| 6 |
- generated_from_trainer
|
|
|
|
| 7 |
- sft
|
| 8 |
+
- trl
|
| 9 |
licence: license
|
| 10 |
---
|
| 11 |
|
|
|
|
| 27 |
|
| 28 |
## Training procedure
|
| 29 |
|
| 30 |
+
[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>](https://wandb.ai/yoon307-kaist/medgemma-27b-it-dr4-Project/runs/y71p22um)
|
| 31 |
|
| 32 |
|
| 33 |
This model was trained with SFT.
|
| 34 |
|
| 35 |
### Framework versions
|
| 36 |
|
| 37 |
+
- TRL: 0.19.1
|
| 38 |
+
- Transformers: 4.53.2
|
| 39 |
+
- Pytorch: 2.6.0+cu124
|
| 40 |
+
- Datasets: 4.0.0
|
| 41 |
+
- Tokenizers: 0.21.2
|
| 42 |
|
| 43 |
## Citations
|
| 44 |
|
adapter_config.json
CHANGED
|
@@ -24,20 +24,20 @@
|
|
| 24 |
],
|
| 25 |
"peft_type": "LORA",
|
| 26 |
"qalora_group_size": 16,
|
| 27 |
-
"r":
|
| 28 |
"rank_pattern": {},
|
| 29 |
"revision": null,
|
| 30 |
"target_modules": [
|
| 31 |
-
"
|
| 32 |
"up_proj",
|
| 33 |
-
"
|
|
|
|
|
|
|
|
|
|
| 34 |
"o_proj",
|
| 35 |
"fc2",
|
| 36 |
-
"fc1",
|
| 37 |
-
"out_proj",
|
| 38 |
"k_proj",
|
| 39 |
-
"
|
| 40 |
-
"q_proj"
|
| 41 |
],
|
| 42 |
"task_type": "CAUSAL_LM",
|
| 43 |
"trainable_token_indices": null,
|
|
|
|
| 24 |
],
|
| 25 |
"peft_type": "LORA",
|
| 26 |
"qalora_group_size": 16,
|
| 27 |
+
"r": 16,
|
| 28 |
"rank_pattern": {},
|
| 29 |
"revision": null,
|
| 30 |
"target_modules": [
|
| 31 |
+
"out_proj",
|
| 32 |
"up_proj",
|
| 33 |
+
"fc1",
|
| 34 |
+
"q_proj",
|
| 35 |
+
"gate_proj",
|
| 36 |
+
"v_proj",
|
| 37 |
"o_proj",
|
| 38 |
"fc2",
|
|
|
|
|
|
|
| 39 |
"k_proj",
|
| 40 |
+
"down_proj"
|
|
|
|
| 41 |
],
|
| 42 |
"task_type": "CAUSAL_LM",
|
| 43 |
"trainable_token_indices": null,
|
adapter_model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:46088cfd14dbb5cfc8432351881e54ce4f916e9bb12fb44f8f1c04005524d622
|
| 3 |
+
size 6127553104
|
chat_template.jinja
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{{ bos_token }}
|
| 2 |
+
{%- if messages[0]['role'] == 'system' -%}
|
| 3 |
+
{%- if messages[0]['content'] is string -%}
|
| 4 |
+
{%- set first_user_prefix = messages[0]['content'] + '
|
| 5 |
+
|
| 6 |
+
' -%}
|
| 7 |
+
{%- else -%}
|
| 8 |
+
{%- set first_user_prefix = messages[0]['content'][0]['text'] + '
|
| 9 |
+
|
| 10 |
+
' -%}
|
| 11 |
+
{%- endif -%}
|
| 12 |
+
{%- set loop_messages = messages[1:] -%}
|
| 13 |
+
{%- else -%}
|
| 14 |
+
{%- set first_user_prefix = "" -%}
|
| 15 |
+
{%- set loop_messages = messages -%}
|
| 16 |
+
{%- endif -%}
|
| 17 |
+
{%- for message in loop_messages -%}
|
| 18 |
+
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
|
| 19 |
+
{{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
|
| 20 |
+
{%- endif -%}
|
| 21 |
+
{%- if (message['role'] == 'assistant') -%}
|
| 22 |
+
{%- set role = "model" -%}
|
| 23 |
+
{%- else -%}
|
| 24 |
+
{%- set role = message['role'] -%}
|
| 25 |
+
{%- endif -%}
|
| 26 |
+
{{ '<start_of_turn>' + role + '
|
| 27 |
+
' + (first_user_prefix if loop.first else "") }}
|
| 28 |
+
{%- if message['content'] is string -%}
|
| 29 |
+
{{ message['content'] | trim }}
|
| 30 |
+
{%- elif message['content'] is iterable -%}
|
| 31 |
+
{%- for item in message['content'] -%}
|
| 32 |
+
{%- if item['type'] == 'image' -%}
|
| 33 |
+
{{ '<start_of_image>' }}
|
| 34 |
+
{%- elif item['type'] == 'text' -%}
|
| 35 |
+
{{ item['text'] | trim }}
|
| 36 |
+
{%- endif -%}
|
| 37 |
+
{%- endfor -%}
|
| 38 |
+
{%- else -%}
|
| 39 |
+
{{ raise_exception("Invalid content type") }}
|
| 40 |
+
{%- endif -%}
|
| 41 |
+
{{ '<end_of_turn>
|
| 42 |
+
' }}
|
| 43 |
+
{%- endfor -%}
|
| 44 |
+
{%- if add_generation_prompt -%}
|
| 45 |
+
{{'<start_of_turn>model
|
| 46 |
+
'}}
|
| 47 |
+
{%- endif -%}
|
runs/Jul19_23-56-17_meedgxh100a/events.out.tfevents.1752983779.meedgxh100a.1020669.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0702869b05003d340911d40565a0adbea440ad368add0a194753e18b069fe298
|
| 3 |
+
size 9916
|
tokenizer_config.json
CHANGED
|
@@ -51325,7 +51325,6 @@
|
|
| 51325 |
},
|
| 51326 |
"boi_token": "<start_of_image>",
|
| 51327 |
"bos_token": "<bos>",
|
| 51328 |
-
"chat_template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'<start_of_turn>model\n'}}\n{%- endif -%}\n",
|
| 51329 |
"clean_up_tokenization_spaces": false,
|
| 51330 |
"eoi_token": "<end_of_image>",
|
| 51331 |
"eos_token": "<eos>",
|
|
|
|
| 51325 |
},
|
| 51326 |
"boi_token": "<start_of_image>",
|
| 51327 |
"bos_token": "<bos>",
|
|
|
|
| 51328 |
"clean_up_tokenization_spaces": false,
|
| 51329 |
"eoi_token": "<end_of_image>",
|
| 51330 |
"eos_token": "<eos>",
|
train_medgemma_ft_copy.py
ADDED
|
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import division, print_function
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
# medqa gsm8k openbookqa bioasq pubmedqa squad_v2
|
| 5 |
+
|
| 6 |
+
# === Base ===
|
| 7 |
+
import os
|
| 8 |
+
import os.path as osp
|
| 9 |
+
import random
|
| 10 |
+
import argparse
|
| 11 |
+
import logging
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
from matplotlib import pyplot as plt
|
| 14 |
+
import pdb
|
| 15 |
+
from PIL import Image
|
| 16 |
+
import shutil
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
# === DL ===
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
import torch.backends.cudnn as cudnn
|
| 23 |
+
from torch.utils.data import DataLoader
|
| 24 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 25 |
+
|
| 26 |
+
# === Custom ===
|
| 27 |
+
# import tools.imutils as imutils
|
| 28 |
+
# import tools.utils as utils
|
| 29 |
+
# import tools.pyutils as pyutils
|
| 30 |
+
# from tools.utils import compute_es_auc, compute_group_auc, ImprovedBalancedBatchSampler, compute_es_auc_multi
|
| 31 |
+
|
| 32 |
+
# === Evaluation ===
|
| 33 |
+
from sklearn.metrics import roc_curve, accuracy_score, roc_auc_score
|
| 34 |
+
|
| 35 |
+
# === Transformers ===
|
| 36 |
+
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig, pipeline, AutoModelForCausalLM
|
| 37 |
+
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
| 38 |
+
from trl import SFTTrainer, SFTConfig
|
| 39 |
+
import wandb
|
| 40 |
+
|
| 41 |
+
# === Label Masking Function ===
|
| 42 |
+
def mask_until_after_assistant(labels: torch.Tensor, tokenizer, assistant_token_ids: list):
|
| 43 |
+
for i in range(labels.size(0)):
|
| 44 |
+
for j in range(labels.size(1) - len(assistant_token_ids) + 1):
|
| 45 |
+
if torch.equal(labels[i, j:j+len(assistant_token_ids)], torch.tensor(assistant_token_ids, device=labels.device)):
|
| 46 |
+
labels[i, :j + len(assistant_token_ids)] = -100 # ASSISTANT: 까지 마스킹
|
| 47 |
+
break
|
| 48 |
+
return labels
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# === Collate Function ===
|
| 52 |
+
def collate_fn(examples):
|
| 53 |
+
texts = []
|
| 54 |
+
images = []
|
| 55 |
+
for example in examples:
|
| 56 |
+
image = example["image"].convert("RGB")
|
| 57 |
+
image = image.resize((IM_SIZE,IM_SIZE))
|
| 58 |
+
images.append([image])
|
| 59 |
+
texts.append(processor.apply_chat_template(
|
| 60 |
+
example["messages"], add_generation_prompt=False, tokenize=False
|
| 61 |
+
).strip())
|
| 62 |
+
|
| 63 |
+
# Tokenize the texts and process the images
|
| 64 |
+
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
| 65 |
+
|
| 66 |
+
# The labels are the input_ids, with the padding and image tokens masked in
|
| 67 |
+
# the loss computation
|
| 68 |
+
labels = batch["input_ids"].clone()
|
| 69 |
+
|
| 70 |
+
# Mask image tokens
|
| 71 |
+
image_token_id = [
|
| 72 |
+
processor.tokenizer.convert_tokens_to_ids(
|
| 73 |
+
processor.tokenizer.special_tokens_map["boi_token"]
|
| 74 |
+
)
|
| 75 |
+
]
|
| 76 |
+
# Mask tokens that are not used in the loss computation
|
| 77 |
+
labels[labels == processor.tokenizer.pad_token_id] = -100
|
| 78 |
+
labels[labels == image_token_id] = -100
|
| 79 |
+
labels[labels == 262144] = -100
|
| 80 |
+
|
| 81 |
+
labels = mask_until_after_assistant(labels, processor.tokenizer, ASST_ID)
|
| 82 |
+
labels[:,-1] = -100
|
| 83 |
+
|
| 84 |
+
batch["labels"] = labels
|
| 85 |
+
# pdb.set_trace()
|
| 86 |
+
return batch
|
| 87 |
+
|
| 88 |
+
def format_data(sample):
|
| 89 |
+
label = 'negative' if sample[task_idx] == '0.0' else 'positive'
|
| 90 |
+
prompt = f"Please diagnose whether the {disease_name} exist or not based on the given image.\n"
|
| 91 |
+
|
| 92 |
+
# pdb.set_trace()
|
| 93 |
+
example = {}
|
| 94 |
+
example["image"] = Image.open(os.path.join(img_root_path, sample[1]))
|
| 95 |
+
example["label"] = 0 if sample[task_idx]== '0,0' else 1
|
| 96 |
+
example["messages"] = [
|
| 97 |
+
{"role": "system", "content": [{"type": "text", "text": system_message}]},
|
| 98 |
+
{"role": "user", "content": [
|
| 99 |
+
# {"type": "image", "image": os.path.join(img_root_path, sample[1])},
|
| 100 |
+
{"type": "image"},
|
| 101 |
+
{"type": "text", "text": prompt},
|
| 102 |
+
]},
|
| 103 |
+
{"role": "assistant", "content": [{"type": "text", "text": str(label)}]}
|
| 104 |
+
]
|
| 105 |
+
|
| 106 |
+
return example
|
| 107 |
+
|
| 108 |
+
def format_data_for_inference(sample):
|
| 109 |
+
prompt = f"Please diagnose whether the {disease_name} exist or not based on the given image.\n"
|
| 110 |
+
|
| 111 |
+
# pdb.set_trace()
|
| 112 |
+
example = {}
|
| 113 |
+
example["image"] = Image.open(os.path.join(img_root_path, sample[1]))
|
| 114 |
+
# example["label"] = 0 if sample[task_idx]== '0,0' else 1
|
| 115 |
+
example["messages"] = [
|
| 116 |
+
{"role": "system", "content": [{"type": "text", "text": system_message}]},
|
| 117 |
+
{"role": "user", "content": [
|
| 118 |
+
# {"type": "image", "image": os.path.join(img_root_path, sample[1])},
|
| 119 |
+
{"type": "image"},
|
| 120 |
+
{"type": "text", "text": prompt+"\n"},
|
| 121 |
+
]},
|
| 122 |
+
# {"role": "assistant", "content": [{"type": "text", "text": str(label)}]}
|
| 123 |
+
]
|
| 124 |
+
|
| 125 |
+
return example
|
| 126 |
+
|
| 127 |
+
# === Logit Preprocessing ===
|
| 128 |
+
def slice_logits(logits, labels):
|
| 129 |
+
if isinstance(logits, (tuple, list)):
|
| 130 |
+
logits = logits[0]
|
| 131 |
+
return logits.detach().cpu()
|
| 132 |
+
|
| 133 |
+
def compute_metrics(eval_pred):
|
| 134 |
+
logits = torch.tensor(eval_pred.predictions)
|
| 135 |
+
|
| 136 |
+
token_ids = logits.argmax(dim=-1) # (B, L): predicted token at each position
|
| 137 |
+
|
| 138 |
+
batch_logits = []
|
| 139 |
+
for b in range(logits.size(0)):
|
| 140 |
+
seq = token_ids[b] # (L,)
|
| 141 |
+
idxs = torch.where((seq == POS_ID[0]) | (seq == NEG_ID[0]))[0]
|
| 142 |
+
if len(idxs) == 0:
|
| 143 |
+
raise ValueError(f"Neither pos_id nor neg_id found in sequence {b}")
|
| 144 |
+
t = idxs[0].item() # first position where pos or neg appears
|
| 145 |
+
tok_id = seq[t].item() # should be either pos_id or neg_id
|
| 146 |
+
batch_logits.append(logits[b, t, tok_id]) # scalar
|
| 147 |
+
|
| 148 |
+
batch_logits = torch.stack(batch_logits) # shape: [B]
|
| 149 |
+
pred_texts = processor.tokenizer.batch_decode(token_ids[:,-1], skip_special_tokens=True)
|
| 150 |
+
|
| 151 |
+
# print(pred_texts)
|
| 152 |
+
# pdb.set_trace()
|
| 153 |
+
probs = torch.sigmoid(logits[:,-1, POS_ID[0]] - logits[:,-1, NEG_ID[0]]).numpy()
|
| 154 |
+
|
| 155 |
+
# probs = torch.sigmoid(batch_logits).numpy()
|
| 156 |
+
labels = torch.tensor(eval_pred.label_ids)
|
| 157 |
+
gt_ids = labels[labels != -100].view(logits.size(0), -1)[:, 0]
|
| 158 |
+
y_true = (gt_ids == POS_ID[0]).int().cpu().numpy()
|
| 159 |
+
auc_val = roc_auc_score(y_true, probs)
|
| 160 |
+
fpr, tpr, thr = roc_curve(y_true, probs)
|
| 161 |
+
best = thr[np.argmax(tpr - fpr)]
|
| 162 |
+
acc = accuracy_score(y_true, probs >= best)
|
| 163 |
+
return {"roc_auc": auc_val, "accuracy": acc}
|
| 164 |
+
|
| 165 |
+
def run_custom_evaluation(trainer, val_dataset, val_labels):
|
| 166 |
+
outputs = trainer.predict(val_dataset)
|
| 167 |
+
logits = torch.from_numpy(outputs.predictions) # (B, S, L)
|
| 168 |
+
# pdb.set_trace()
|
| 169 |
+
probs = torch.sigmoid(logits[:,-1, POS_ID[0]] - logits[:,-1, NEG_ID[0]]).numpy()
|
| 170 |
+
|
| 171 |
+
# decoded = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
| 172 |
+
# y_pred = [1 if "positive" in t.lower() else 0 for t in decoded]
|
| 173 |
+
|
| 174 |
+
auc_val = roc_auc_score(val_labels, probs)
|
| 175 |
+
# acc = accuracy_score(val_labels, y_pred)
|
| 176 |
+
print(f"[Custom Eval] AUC: {auc_val:.4f}")
|
| 177 |
+
# print(f"[Custom Eval] AUC: {auc_val:.4f}, ACC: {acc:.4f}")
|
| 178 |
+
return {"auc": auc_val}
|
| 179 |
+
|
| 180 |
+
# === Main ===
|
| 181 |
+
if __name__ == '__main__':
|
| 182 |
+
parser = argparse.ArgumentParser()
|
| 183 |
+
parser.add_argument("--task", required=True, help='amd, dr, glaucoma')
|
| 184 |
+
parser.add_argument("--name", required=True)
|
| 185 |
+
parser.add_argument("--use_subset", action='store_true')
|
| 186 |
+
args = parser.parse_args()
|
| 187 |
+
|
| 188 |
+
random.seed(42)
|
| 189 |
+
|
| 190 |
+
# pyutils.same_seeds(0)
|
| 191 |
+
|
| 192 |
+
task_map = {'dr': (-3, 'Diabetic Retinopathy'), 'amd': (-2, 'Aged Macular Degeneration'), 'glaucoma': (-1, 'Glaucoma')}
|
| 193 |
+
task_idx, disease_name = task_map[args.task]
|
| 194 |
+
system_message = f"""You are an expert AI in ophthalmology.\n
|
| 195 |
+
Your primary role is to provide accurate, reliable, and up-to-date medical knowledge based on credible sources.\n
|
| 196 |
+
You must follow these guidelines:\n
|
| 197 |
+
1. Be accurate, concise, and clinically relevant.\n
|
| 198 |
+
2. Use proper medical terms.\n
|
| 199 |
+
3. Avoid overexplaining unless requested.\n
|
| 200 |
+
4. Tone: confident, professional, precise.\n
|
| 201 |
+
Do not include any explanation or thought.\n
|
| 202 |
+
Diabetic Retinopathy (DR) is a diabetes-related eye disease that affects the retina — the light-sensitive tissue at the back of the eye. It occurs when chronically high blood sugar levels damage the small blood vessels in the retina, leading to leakage, blockage, or abnormal blood vessel growth.\n
|
| 203 |
+
If {disease_name} is present, answer exactly 'positive'. Otherwise answer 'negative'."""
|
| 204 |
+
|
| 205 |
+
cudnn.benchmark = True
|
| 206 |
+
img_root_path = '/PHShome/sy1081/exeye/data'
|
| 207 |
+
train_dataset = np.load('/PHShome/sy1081/exeye/data/train_final.npy')
|
| 208 |
+
val_dataset_raw = np.load('/PHShome/sy1081/exeye/data/val_final.npy')
|
| 209 |
+
|
| 210 |
+
if args.use_subset:
|
| 211 |
+
def subset(data,train=True):
|
| 212 |
+
neg = [s for s in data if s[task_idx] == '0.0']
|
| 213 |
+
pos = [s for s in data if s[task_idx] != '0.0']
|
| 214 |
+
num_sample = len(pos)
|
| 215 |
+
if train:
|
| 216 |
+
return random.sample(neg, 5*num_sample), random.sample(pos, num_sample)
|
| 217 |
+
else:
|
| 218 |
+
# return random.sample(neg, 5*num_sample), pos
|
| 219 |
+
# return random.sample(neg, 15), random.sample(pos, 15)
|
| 220 |
+
return neg, pos
|
| 221 |
+
train_dataset = sum(subset(train_dataset,train=True), [])
|
| 222 |
+
val_dataset_raw = sum(subset(val_dataset_raw,train=False), [])
|
| 223 |
+
|
| 224 |
+
train_dataset = [format_data(s) for s in tqdm(train_dataset)]
|
| 225 |
+
random.shuffle(train_dataset)
|
| 226 |
+
val_dataset = [format_data_for_inference(s) for s in tqdm(val_dataset_raw)]
|
| 227 |
+
val_labels = [1 if s[task_idx] != '0.0' else 0 for s in val_dataset_raw]
|
| 228 |
+
# val_dataset = [format_data(s) for s in tqdm(val_dataset)]
|
| 229 |
+
print("="*50)
|
| 230 |
+
print(f"Total number of Data| Train: {len(train_dataset)} | Val : {len(val_dataset)}")
|
| 231 |
+
print("="*50)
|
| 232 |
+
|
| 233 |
+
# model_id = "google/medgemma-4b-it"
|
| 234 |
+
model_id = "google/medgemma-27b-it"
|
| 235 |
+
model_kwargs = dict(
|
| 236 |
+
attn_implementation="eager",
|
| 237 |
+
torch_dtype=torch.bfloat16,
|
| 238 |
+
device_map="auto",
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 242 |
+
load_in_4bit=True,
|
| 243 |
+
bnb_4bit_use_double_quant=True,
|
| 244 |
+
bnb_4bit_quant_type="nf4",
|
| 245 |
+
bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
|
| 246 |
+
bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
# model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
|
| 250 |
+
|
| 251 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 252 |
+
model_id,
|
| 253 |
+
**model_kwargs
|
| 254 |
+
# torch_dtype=torch.bfloat16,
|
| 255 |
+
# device_map="auto",
|
| 256 |
+
)
|
| 257 |
+
processor = AutoProcessor.from_pretrained(model_id)
|
| 258 |
+
|
| 259 |
+
# Use right padding to avoid issues during training
|
| 260 |
+
processor.tokenizer.padding_side = "right"
|
| 261 |
+
|
| 262 |
+
POS_ID = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize("positive")) #30558
|
| 263 |
+
NEG_ID = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize("negative")) #27851
|
| 264 |
+
ASST_ID = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize("model\n"))
|
| 265 |
+
|
| 266 |
+
IM_SIZE = 512
|
| 267 |
+
|
| 268 |
+
peft_config = LoraConfig(
|
| 269 |
+
lora_alpha=16,
|
| 270 |
+
lora_dropout=0.05,
|
| 271 |
+
r=16,
|
| 272 |
+
bias="none",
|
| 273 |
+
target_modules="all-linear",
|
| 274 |
+
# target_modules=["q_proj", "v_proj"],
|
| 275 |
+
task_type="CAUSAL_LM",
|
| 276 |
+
modules_to_save=[
|
| 277 |
+
"lm_head",
|
| 278 |
+
"embed_tokens",
|
| 279 |
+
],
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
exp_name = f"{model_id.split('/')[-1]}-{args.name}"
|
| 284 |
+
|
| 285 |
+
if os.path.exists(exp_name):
|
| 286 |
+
from peft import PeftModel
|
| 287 |
+
print("🔁 Loading trained PEFT weights...")
|
| 288 |
+
# model = PeftModel.from_pretrained(model, exp_name)
|
| 289 |
+
model = PeftModel.from_pretrained(model, exp_name+"/checkpoint-598")
|
| 290 |
+
# model = PeftModel.from_pretrained(model, "llava-1.5-7b-hf-dr-all/checkpoint-80")
|
| 291 |
+
phase= "val"
|
| 292 |
+
else:
|
| 293 |
+
print("🚀 Initializing new LoRA model...")
|
| 294 |
+
# model = prepare_model_for_kbit_training(model)
|
| 295 |
+
model = get_peft_model(model, peft_config)
|
| 296 |
+
model.print_trainable_parameters()
|
| 297 |
+
phase= "train"
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
training_args = SFTConfig(
|
| 301 |
+
output_dir=exp_name,
|
| 302 |
+
num_train_epochs= 15, # Number of training epochs
|
| 303 |
+
per_device_train_batch_size=2, # Batch size per device during training
|
| 304 |
+
per_device_eval_batch_size=4, # Batch size per device during evaluation
|
| 305 |
+
gradient_accumulation_steps=8, # Number of steps before performing a backward/update pass
|
| 306 |
+
gradient_checkpointing=True, # Enable gradient checkpointing to reduce memory usage
|
| 307 |
+
optim="adamw_torch_fused", # Use fused AdamW optimizer for better performance
|
| 308 |
+
logging_steps=10, # Number of steps between logs
|
| 309 |
+
save_strategy="epoch", # Save checkpoint every epoch
|
| 310 |
+
eval_strategy="steps", # Evaluate every `eval_steps`
|
| 311 |
+
eval_steps=10000, # Number of steps between evaluations
|
| 312 |
+
learning_rate=1e-3, # Learning rate based on QLoRA paper
|
| 313 |
+
bf16=True, # Use bfloat16 precision
|
| 314 |
+
max_grad_norm=0.3, # Max gradient norm based on QLoRA paper
|
| 315 |
+
warmup_ratio=0.03, # Warmup ratio based on QLoRA paper
|
| 316 |
+
lr_scheduler_type="linear", # Use linear learning rate scheduler
|
| 317 |
+
# lr_scheduler_type="constant", # Use linear learning rate scheduler
|
| 318 |
+
push_to_hub=True, # Push model to Hub
|
| 319 |
+
report_to="tensorboard", # Report metrics to tensorboard
|
| 320 |
+
gradient_checkpointing_kwargs={"use_reentrant": False}, # Set gradient checkpointing to non-reentrant to avoid issues
|
| 321 |
+
dataset_kwargs={"skip_prepare_dataset": True}, # Skip default dataset preparation to preprocess manually
|
| 322 |
+
remove_unused_columns = False, # Columns are unused for training but needed for data collator
|
| 323 |
+
label_names=["labels"],
|
| 324 |
+
)
|
| 325 |
+
# training_args.remove_unused_columns = False
|
| 326 |
+
|
| 327 |
+
wandb.init(project=f"{exp_name}-Project", name=exp_name, config=training_args)
|
| 328 |
+
|
| 329 |
+
trainer = SFTTrainer(
|
| 330 |
+
model=model,
|
| 331 |
+
args=training_args,
|
| 332 |
+
train_dataset=train_dataset,
|
| 333 |
+
eval_dataset=val_dataset,
|
| 334 |
+
data_collator=collate_fn,
|
| 335 |
+
peft_config=peft_config,
|
| 336 |
+
processing_class=processor.tokenizer,
|
| 337 |
+
# compute_metrics=compute_metrics,
|
| 338 |
+
# preprocess_logits_for_metrics=slice_logits,
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
# if not os.path.exists(exp_name):
|
| 342 |
+
shutil.copy("/PHShome/sy1081/exeye/train_medgemma_ft.py",os.path.join(".",exp_name,"train_medgemma_ft_copy.py"))
|
| 343 |
+
|
| 344 |
+
if phase == 'train':
|
| 345 |
+
trainer.train()
|
| 346 |
+
trainer.save_model(training_args.output_dir)
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
batch_size = 1
|
| 350 |
+
model.eval()
|
| 351 |
+
all_logits = []
|
| 352 |
+
|
| 353 |
+
for i in tqdm(range(0, len(val_dataset), batch_size), desc="Running inference with logits"):
|
| 354 |
+
batch = val_dataset[i:i + batch_size]
|
| 355 |
+
|
| 356 |
+
# prepare inputs
|
| 357 |
+
texts = []
|
| 358 |
+
images = []
|
| 359 |
+
for example in batch:
|
| 360 |
+
text = processor.apply_chat_template(
|
| 361 |
+
example["messages"], add_generation_prompt=True, tokenize=False
|
| 362 |
+
).strip()
|
| 363 |
+
texts.append(text)
|
| 364 |
+
image = example["image"].convert("RGB").resize((IM_SIZE, IM_SIZE))
|
| 365 |
+
images.append([image])
|
| 366 |
+
|
| 367 |
+
# tokenizer & image processor
|
| 368 |
+
with torch.no_grad():
|
| 369 |
+
texts[0] += "\n"
|
| 370 |
+
inputs = processor(
|
| 371 |
+
text=texts,
|
| 372 |
+
images=images,
|
| 373 |
+
return_tensors="pt",
|
| 374 |
+
padding=True
|
| 375 |
+
).to(model.device)
|
| 376 |
+
|
| 377 |
+
outputs = model(**inputs, output_hidden_states=False, return_dict=True)
|
| 378 |
+
|
| 379 |
+
print("==> ",processor.tokenizer.decode(outputs.logits[0].argmax(-1)[-1]))
|
| 380 |
+
|
| 381 |
+
logits = outputs.logits
|
| 382 |
+
# pdb.set_trace()
|
| 383 |
+
probs = torch.sigmoid(logits[0,-1, POS_ID] - logits[0,-1, NEG_ID])
|
| 384 |
+
# logits: (B, L, V)
|
| 385 |
+
# all_logits.append(outputs.logits.to(torch.float32).detach().cpu().numpy())
|
| 386 |
+
all_logits.append(probs)
|
| 387 |
+
|
| 388 |
+
# pdb.set_trace()
|
| 389 |
+
|
| 390 |
+
probs_all = torch.stack(all_logits,dim=0)
|
| 391 |
+
probs_all = [prob.to(torch.float32).detach().cpu() for prob in probs_all]
|
| 392 |
+
# logits= torch.from_numpy(np.stack(all_logits,axis=0)).squeeze(1)
|
| 393 |
+
|
| 394 |
+
# probs = torch.sigmoid(logits[:,-1, POS_ID] - logits[:,-1, NEG_ID])
|
| 395 |
+
|
| 396 |
+
# decoded = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
| 397 |
+
# y_pred = [1 if "positive" in t.lower() else 0 for t in decoded]
|
| 398 |
+
# pdb.set_trace()
|
| 399 |
+
auc_val = roc_auc_score(val_labels, probs_all)
|
| 400 |
+
print(auc_val)
|
| 401 |
+
|
| 402 |
+
# print(trainer.evaluate())
|
training_args.bin
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:26753afe5611dc69c3ec3c59e8748980463f74e5c7c31a5103a131c32c91af02
|
| 3 |
+
size 5816
|