"""Inference helper: load a pretrained model from the current directory and
expose ``predict`` to turn an input image into a class-index mask image."""

from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoConfig, AutoModel

# Not referenced by the active code path below; it is only used by the
# commented-out checkpoint-loading alternative. Kept so that path still works
# if re-enabled.
from unet import UNet  # noqa: F401

# Define model
# Alternative loading path (raw checkpoint instead of the pretrained directory):
# model = UNet(in_channels=1, out_channels=3)
# state_dict = torch.load("unet_epoch20.pth", map_location="cpu")
# model.load_state_dict(state_dict)
# model.eval()
config = AutoConfig.from_pretrained(".")
model = AutoModel.from_pretrained(".", config=config)
model.eval()  # inference mode: disables dropout / batch-norm updates

# Define preprocessing
# Convert to single-channel grayscale, resize to the square size stored in the
# config (`config.image_size`), and convert to a float tensor.
preprocess = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((config.image_size, config.image_size)),
    transforms.ToTensor(),
])

# Built once and reused by every `predict` call.
_to_pil = transforms.ToPILImage()


# Define inference function
def predict(image: Image.Image) -> Image.Image:
    """Run the model on *image* and return the per-pixel argmax as an image.

    Args:
        image: Input PIL image. It is converted to grayscale and resized to
            ``config.image_size`` x ``config.image_size`` by ``preprocess``.

    Returns:
        A PIL image built from a 2-D uint8 tensor whose pixel values are the
        index of the highest-scoring channel at each location (not a
        display-scaled intensity).
    """
    # Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
    input_tensor = preprocess(image).unsqueeze(0)

    with torch.no_grad():
        # NOTE(review): assumes the model returns a plain tensor of shape
        # (1, 3, H, W); if it returns a structured output object instead,
        # `.squeeze` will fail -- confirm against the model class.
        output = model(input_tensor).squeeze(0)  # expected shape: (3, H, W)

    # Collapse the channel axis to the winning channel index per pixel.
    prediction = torch.argmax(output, dim=0).byte()  # shape: (H, W)
    return _to_pil(prediction)