"""Generate an image from a sketch with the stochastic pix2pix-turbo model."""

import torch
from PIL import Image
import torchvision.transforms.functional as F
from src.pix2pix_turbo import Pix2Pix_Turbo
import numpy as np


def process_sketch(sketch_path, output_path, prompt, val_r=0.4, seed=42, *, model=None):
    """Render the sketch at *sketch_path* into an image guided by *prompt* and save it.

    Args:
        sketch_path: Path of the input sketch image; it is converted to RGB.
        output_path: Destination path for the generated image.
        prompt: Text prompt passed to the model.
        val_r: Forwarded to the model as ``r``. Presumably this controls how
            strongly the injected noise is blended in — TODO confirm against
            ``Pix2Pix_Turbo``'s forward signature.
        seed: Passed to ``torch.manual_seed`` right before the noise map is
            drawn, so the same seed reproduces the same noise.
        model: Optional already-constructed ``Pix2Pix_Turbo`` instance to
            reuse. When ``None`` (default) a new
            ``Pix2Pix_Turbo("sketch_to_image_stochastic")`` is built, as before.
            Passing one in avoids reloading the weights when processing many
            sketches.

    Returns:
        The generated PIL image, which is also saved to *output_path*.
    """
    if model is None:
        model = Pix2Pix_Turbo("sketch_to_image_stochastic")

    # Seed the global RNG so the noise map drawn below is reproducible.
    torch.manual_seed(seed)

    # Use a context manager so the source file handle is closed promptly.
    with Image.open(sketch_path) as src_img:
        image = src_img.convert("RGB")

    # Binarize: channel values above 0.5 become True/1, everything else 0.
    # NOTE(review): for a dark-strokes-on-white sketch this makes the strokes 0;
    # confirm this polarity matches what the model expects.
    image_t = F.to_tensor(image) > 0.5

    with torch.no_grad():
        # Add a batch dimension and move to the GPU as a float tensor.
        c_t = image_t.unsqueeze(0).cuda().float()
        _, _, height, width = c_t.shape
        # 4 channels at 1/8 resolution — presumably the model's latent shape;
        # TODO confirm against the model's VAE.
        noise = torch.randn((1, 4, height // 8, width // 8), device=c_t.device)
        output_image = model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)

    # Rescale to [0, 1] (assumes the model emits values in [-1, 1]) and clamp so
    # to_pil_image's float -> uint8 conversion can never wrap out-of-range values.
    output_pil = F.to_pil_image((output_image[0].cpu() * 0.5 + 0.5).clamp(0, 1))
    output_pil.save(output_path)
    print(f"Output image saved to {output_path}")
    return output_pil


if __name__ == "__main__":
    sketch_path = "sketch.png"
    output_path = "output.png"
    prompt = ("a fantasy concept art of a magical castle in the sky, ")
    process_sketch(sketch_path, output_path, prompt)