import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModel, AutoConfig

from unet import UNet
# Define model
# Disabled alternative: construct the UNet directly and load weights from a checkpoint file.
# model = UNet(in_channels=1, out_channels=3)
# state_dict = torch.load("unet_epoch20.pth", map_location="cpu")
# model.load_state_dict(state_dict)
# model.eval()
# Both the configuration and the weights are read from the current working
# directory, so the script must be launched from the model folder.
_MODEL_DIR = "."
config = AutoConfig.from_pretrained(_MODEL_DIR)
model = AutoModel.from_pretrained(_MODEL_DIR, config=config)
# Switch to evaluation mode before serving predictions.
model.eval()
# Preprocessing pipeline applied to every input image, in order:
#   1. convert to grayscale,
#   2. resize to a square whose side is ``config.image_size``,
#   3. convert the PIL image to a tensor.
_target_size = (config.image_size, config.image_size)
preprocess = transforms.Compose(
    [
        transforms.Grayscale(),
        transforms.Resize(_target_size),
        transforms.ToTensor(),
    ]
)
# Inference entry point
def predict(image: Image.Image) -> Image.Image:
    """Run the model on one image and return the per-pixel argmax as an image.

    Args:
        image: Input PIL image. It is passed through the module-level
            ``preprocess`` pipeline (grayscale, resize, to-tensor).

    Returns:
        A PIL image built from the uint8 map of argmax indices taken over
        the first dimension of the model output. The pixel values are the
        indices themselves; no scaling for display is applied here.
    """
    # Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
    batch = preprocess(image).unsqueeze(0)

    # Gradients are not needed for inference.
    with torch.no_grad():
        # NOTE(review): assumes the model call returns a plain tensor rather
        # than a structured output object -- confirm against the model class.
        raw_output = model(batch)
        # Drop the batch dimension; the original author notes (3, H, W) here.
        scores = raw_output.squeeze(0)

    # Pick the highest-scoring index per pixel -> (H, W), stored as uint8.
    class_map = scores.argmax(dim=0).byte()

    to_image = transforms.ToPILImage()
    return to_image(class_map)