import os

import toml
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from peft import LoraConfig, get_peft_model
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

if __name__ == "__main__":
    # Read configuration
    config = toml.load("config_lora.toml")

    # Accelerator (uses the GPU if available)
    accelerator = Accelerator()
    device = accelerator.device

    # Custom dataset: resize, convert to tensor, normalize to [-1, 1]
    class CustomImageDataset(Dataset):
        def __init__(self, dataset, size=512):
            self.dataset = dataset
            self.transform = transforms.Compose([
                transforms.Resize((size, size)),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ])

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, idx):
            image = self.dataset[idx]["image"]
            if not isinstance(image, Image.Image):
                image = Image.fromarray(image)
            return self.transform(image.convert("RGB"))

    # Load the dataset from the Hugging Face Hub
    dataset = load_dataset(config["dataset_dir"], split="train")
    dataset = CustomImageDataset(dataset, size=config["resolution"])
    dataloader = DataLoader(dataset, batch_size=config["train_batch_size"], shuffle=True)

    # Base model components
    tokenizer = CLIPTokenizer.from_pretrained(config["pretrained_model_name_or_path"], subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(config["pretrained_model_name_or_path"], subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(config["pretrained_model_name_or_path"], subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(config["pretrained_model_name_or_path"], subfolder="unet")

    # Only the LoRA layers are trained; freeze the text encoder and VAE and
    # move them to the device (accelerator.prepare handles the UNet)
    text_encoder.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder.to(device)
    vae.to(device)

    # LoRA on the UNet's attention projection layers
    lora_config = LoraConfig(
        r=config["rank"],
        lora_alpha=config["lora_alpha"],
        lora_dropout=config["lora_dropout"],
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
        bias="none",
    )
    unet = get_peft_model(unet, lora_config)

    # Noise scheduler (epsilon prediction, so the loss target is the noise)
    noise_scheduler = DDPMScheduler.from_pretrained(config["pretrained_model_name_or_path"], subfolder="scheduler")

    # Optimizer over the trainable (LoRA) parameters only; the LR scheduler
    # decays the learning rate by 10% every 5 optimization steps
    optimizer = torch.optim.AdamW([p for p in unet.parameters() if p.requires_grad], lr=1e-4)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)

    # Prepare everything for training
    unet, optimizer, dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, dataloader, lr_scheduler)
    unet.train()

    # Training loop; the epoch count is just an upper bound, training stops
    # once max_train_steps is reached
    global_step = 0
    progress_bar = tqdm(range(config["max_train_steps"]), desc="Training", leave=True)
    for epoch in range(100):
        for batch in dataloader:
            clean_images = batch.to(device)

            # Stable Diffusion's UNet operates in latent space: encode the
            # pixel images with the VAE and apply its scaling factor
            with torch.no_grad():
                latents = vae.encode(clean_images).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

            # Sample noise and a random timestep per image, then add noise
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps,
                (latents.shape[0],), device=device,
            ).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Fixed caption for every image ("a Latina woman in a green dress")
            prompt = ["una mujer latina con vestido verde"] * latents.shape[0]
            text_input = tokenizer(
                prompt, padding="max_length", max_length=tokenizer.model_max_length,
                truncation=True, return_tensors="pt",
            ).to(device)
            with torch.no_grad():
                encoder_hidden_states = text_encoder(text_input.input_ids)[0]

            # Predict the noise and regress against the true noise
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
            loss = F.mse_loss(model_pred, noise)

            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            progress_bar.update(1)
            progress_bar.set_postfix({"loss": loss.item()})
            global_step += 1

            # Periodic checkpoints: save only the LoRA adapter weights
            if global_step % config["checkpointing_steps"] == 0:
                save_dir = os.path.join(config["output_dir"], f"checkpoint_{global_step}")
                os.makedirs(save_dir, exist_ok=True)
                accelerator.unwrap_model(unet).save_pretrained(save_dir)

            if global_step >= config["max_train_steps"]:
                break
        if global_step >= config["max_train_steps"]:
            break

    # Save the final adapter weights
    os.makedirs(config["output_dir"], exist_ok=True)
    accelerator.unwrap_model(unet).save_pretrained(config["output_dir"])
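
# For reference, a minimal config_lora.toml covering every key this script
# reads. The values shown are illustrative assumptions, not tuned defaults:
#
#   pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
#   dataset_dir = "user/my-image-dataset"  # Hub dataset id (or local path)
#   resolution = 512
#   train_batch_size = 1
#   rank = 8
#   lora_alpha = 16
#   lora_dropout = 0.1
#   max_train_steps = 1000
#   checkpointing_steps = 200
#   output_dir = "lora_output"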
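
# A minimal inference sketch, assuming the adapter directory written by
# save_pretrained() above; it loads the PEFT adapter back onto the
# pipeline's UNet:
#
#   from diffusers import StableDiffusionPipeline
#   from peft import PeftModel
#
#   pipe = StableDiffusionPipeline.from_pretrained(config["pretrained_model_name_or_path"]).to("cuda")
#   pipe.unet = PeftModel.from_pretrained(pipe.unet, config["output_dir"])
#   image = pipe("una mujer latina con vestido verde").images[0]
#   image.save("sample.png")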