# brain-unet-model / inference.py
# Author: AndaiMD
# Uploaded with huggingface_hub (commit 2b13561, verified).
from pathlib import Path

import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoConfig, AutoModel

# NOTE(review): UNet is not used by the live code shown here; the import may be
# kept for its side effects (e.g. registering the architecture) — confirm
# before removing.
from unet import UNet
# Define model.
# Alternative loader (previously kept here as commented-out code): construct
# UNet(in_channels=1, out_channels=3) directly, load the "unet_epoch20.pth"
# state dict with map_location="cpu", then call .eval().
#
# Resolve the model directory from this file's location instead of the current
# working directory ("."), so config and weights are found no matter where the
# process was started. Fall back to the CWD when __file__ is unavailable
# (e.g. when this code is exec'd interactively).
_this_file = globals().get("__file__")
MODEL_DIR = Path(_this_file).resolve().parent if _this_file else Path.cwd()

# NOTE(review): loading a custom architecture through AutoConfig/AutoModel
# presumably relies on the config's model type being registered with
# transformers — confirm how that registration happens.
config = AutoConfig.from_pretrained(MODEL_DIR)
model = AutoModel.from_pretrained(MODEL_DIR, config=config)
model.eval()  # switch to evaluation mode (affects dropout/batch-norm if present)
# Preprocessing pipeline: PIL image -> single-channel tensor resized to a
# square of side config.image_size (config is loaded from the model directory).
preprocess = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize(size=(config.image_size,) * 2),
        transforms.ToTensor(),  # PIL image -> (C, H, W) float tensor
    ]
)
# Inference entry point.
def predict(image: Image.Image) -> Image.Image:
    """Run the segmentation model on one image and return its label mask.

    The image goes through the module-level ``preprocess`` pipeline and is fed
    to ``model`` as a batch of one. The output is reduced with an argmax over
    the channel axis, so each output pixel holds the index of the
    highest-scoring channel.

    Args:
        image: Input PIL image; ``preprocess`` converts it to grayscale and
            resizes it before inference.

    Returns:
        A single-channel PIL image of uint8 channel indices, with the spatial
        size of the model output (not necessarily the size of ``image``).
    """
    batch = preprocess(image)[None]  # prepend a batch axis: (1, C, H, W)
    with torch.no_grad():  # no autograd bookkeeping needed for inference
        scores = model(batch)
        # Drop the batch axis; assumed (num_channels, H, W) afterwards.
        # TODO confirm the model returns a plain tensor, not an output object.
        scores = scores.squeeze(0)
    label_map = scores.argmax(dim=0).to(torch.uint8)  # (H, W) channel indices
    to_pil = transforms.ToPILImage()
    return to_pil(label_map)