from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import requests
import torch

model_id = "google/medgemma-4b-it"
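
# Access to google/medgemma-4b-it may require accepting the model's terms
# on Hugging Face and authenticating, e.g. with `huggingface-cli login`.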

# Load the model in bfloat16 and place it on the available device(s)
# (device_map="auto" requires the accelerate package).
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(model_id)

# Fetch a sample chest X-ray from Wikimedia Commons
image_url = "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png"
image = Image.open(requests.get(image_url, headers={"User-Agent": "example"}, stream=True).raw)
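# A local image works just as well; the path below is an illustrative placeholder:
# image = Image.open("path/to/xray.png").convert("RGB")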

messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are an expert radiologist."}]
    },
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this X-ray"},
            {"type": "image", "image": image}
        ]
    }
]
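# The chat format also accepts text-only turns (the prompt here is illustrative):
# messages = [{"role": "user", "content": [
#     {"type": "text", "text": "How do you differentiate bacterial from viral pneumonia?"}]}]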

inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)

# Prompt length, used below to strip prompt tokens from the generated sequence
input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=200, do_sample=False)
    generation = generation[0][input_len:]

decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)
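
# Sampling can be enabled for more varied phrasing; the decoding values
# below are illustrative assumptions, not tuned settings:
# generation = model.generate(
#     **inputs, max_new_tokens=200, do_sample=True, temperature=0.7, top_p=0.95
# )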