# Training script: fits a UNet3D to predict the Gaussian noise that was mixed
# into each video batch, using an MSE loss between predicted and true noise.

NUM_EPOCHS = 50
NUM_TIMESTEPS = 1000  # timesteps are drawn uniformly from [0, NUM_TIMESTEPS)
LEARNING_RATE = 1e-4

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet3D().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(NUM_EPOCHS):
    # Stays None when the dataloader yields nothing, so the epoch summary can
    # fail with a clear message instead of a NameError on an unbound `loss`.
    loss = None

    for batch in tqdm(dataloader):
        video = batch["video"].to(device)
        text = batch["text"].to(device)

        # One integer timestep per sample, allocated directly on the target
        # device to avoid a CPU allocation plus host-to-device copy per step.
        # NOTE: this draws from the device RNG rather than the CPU RNG.
        t = torch.randint(
            0, NUM_TIMESTEPS, (video.shape[0], 1), device=device
        )
        # Normalised timestep in [0, 1); shared by the schedule and the model.
        t_frac = t / NUM_TIMESTEPS

        noise = torch.randn_like(video)

        # Linear schedule: alpha_t = 1 - t / NUM_TIMESTEPS. The view adds four
        # trailing singleton axes so it broadcasts per sample against `video`
        # (assumes `video` is 5-D, e.g. batch x C x T x H x W -- TODO confirm).
        alpha_t = (1 - t_frac).view(-1, 1, 1, 1, 1)
        noisy_video = (
            torch.sqrt(alpha_t) * video + torch.sqrt(1 - alpha_t) * noise
        )

        pred_noise = model(noisy_video, t_frac, text)
        loss = F.mse_loss(pred_noise, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if loss is None:
        raise RuntimeError("dataloader produced no batches; nothing to train on")

    # Reports the loss of the final batch of the epoch, not an epoch average.
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")