| | import os |
| | import torch |
| | import torch.nn as nn |
| | import torchvision.transforms as T |
| | import torchvision.datasets as datasets |
| | from torch.utils.data import DataLoader |
| | from tqdm import tqdm |
| | from torchvision.models import resnet18 |
| |
|
| |
|
| | |
class SSLModel(nn.Module):
    """Encoder backbone followed by a two-layer MLP projection head.

    The backbone's ``fc`` layer is swapped for ``nn.Identity`` so that calling
    the backbone returns the features that used to feed ``fc``; the projection
    head maps those features to ``projection_dim`` outputs.

    Args:
        backbone: Module with an ``fc`` attribute exposing ``in_features``
            (e.g. a torchvision ResNet). It is modified in place: its ``fc``
            is replaced.
        projection_dim: Output width of the projection head.
    """

    def __init__(self, backbone, projection_dim=128):
        super().__init__()
        feature_dim = backbone.fc.in_features  # read before fc is replaced below
        hidden_dim = 512
        # Registration order (backbone first, then head) fixes parameter order.
        self.backbone = backbone
        self.projection_head = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, projection_dim),
        )
        # Drop the classifier so the backbone emits features, not class scores.
        self.backbone.fc = nn.Identity()

    def forward(self, x):
        """Return the projection-head output for the backbone features of ``x``."""
        return self.projection_head(self.backbone(x))
| |
|
| |
|
| | |
def contrastive_loss(z_i, z_j, temperature=0.5, normalize=False):
    """Compute a SimCLR-style InfoNCE contrastive loss for two views.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` form a positive pair; every
    other row of the concatenated ``2 * batch_size`` batch is a negative. Each
    row is scored with a cross-entropy over its similarities to all *other*
    rows, and the per-row losses are averaged. With ``normalize=True`` this is
    the NT-Xent loss (cosine similarity / temperature).

    Args:
        z_i: Projections of the first view, shape ``(batch_size, dim)``.
        z_j: Projections of the second view, same shape as ``z_i``.
        temperature: Positive divisor applied to the similarity logits.
        normalize: If True, L2-normalise the projections first so that the
            logits are cosine similarities. Defaults to False, which keeps the
            raw dot-product similarities.

    Returns:
        Scalar tensor holding the mean loss.

    Raises:
        ValueError: If the inputs are not non-empty 2-D tensors of equal
            shape, or if ``temperature`` is not positive.
    """
    if z_i.dim() != 2 or z_i.shape != z_j.shape:
        raise ValueError(
            "z_i and z_j must be 2-D tensors of equal shape, got "
            f"{tuple(z_i.shape)} and {tuple(z_j.shape)}"
        )
    batch_size = z_i.shape[0]
    if batch_size == 0:
        raise ValueError("contrastive_loss needs at least one pair of samples")
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    z = torch.cat([z_i, z_j], dim=0)
    if normalize:
        z = nn.functional.normalize(z, dim=1)

    logits = torch.mm(z, z.T) / temperature
    # A sample must never be scored against itself.
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))

    # Row k's positive is its counterpart in the other view:
    # k + batch_size for the first half, k - batch_size for the second half.
    first_half = torch.arange(batch_size, device=z.device)
    targets = torch.cat([first_half + batch_size, first_half])

    # cross_entropy evaluates logsumexp(row) - row[target] in log-space. The
    # previous exp(pos) / sum(exp(row)) form underflowed to 0 / 0 = NaN whenever
    # the self-similarity dominated a row; this form is identical in value
    # but stays finite.
    return nn.functional.cross_entropy(logits, targets)
| |
|
| |
|
class TwoViewTransform:
    """Apply ``base_transform`` twice to one input and return both results.

    Used as the dataset transform so every sample yields two independently
    augmented views, which is what the contrastive objective compares.
    Defined at module level so DataLoader worker processes can pickle it.
    """

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        return self.base_transform(x), self.base_transform(x)


def train_ssl(
    total_epochs=15,
    resume_path="models/saves/run2/ssl_checkpoint_epoch_14.pth",
    checkpoint_dir="checkpoints",
):
    """Pre-train a ResNet-18 encoder on CIFAR-10 with a contrastive objective.

    Side effects: downloads CIFAR-10 into ``./data`` if missing, and writes
    one checkpoint per epoch to ``checkpoint_dir``.

    Args:
        total_epochs: Last epoch (inclusive) to train up to; also used as the
            cosine schedule length.
        resume_path: Checkpoint to resume from if the file exists; pass None
            to always start from scratch.
        checkpoint_dir: Directory receiving ``ssl_checkpoint_epoch_{N}.pth``.

    Raises:
        RuntimeError: If the projections of a batch contain NaN values.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    augment = T.Compose([
        T.RandomResizedCrop(32),
        T.RandomHorizontalFlip(),
        T.ColorJitter(0.4, 0.4, 0.4, 0.1),
        T.RandomGrayscale(p=0.2),
        T.GaussianBlur(kernel_size=3),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # Two independent augmentations per image. Previously both "views" were
    # the same tensor, which made the positive pair identical to the anchor.
    train_dataset = datasets.CIFAR10(
        root="./data", train=True, transform=TwoViewTransform(augment), download=True
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=256,
        shuffle=True,
        pin_memory=device.type == "cuda",  # pinning only helps host->GPU copies
        num_workers=4,
    )

    # resnet18() with no weights argument == randomly initialised (pretrained=False).
    model = SSLModel(resnet18()).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)
    # T_max matches the run length; a shorter T_max makes the LR reach 0 and
    # then climb back up before training ends.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(total_epochs, 1)
    )

    start_epoch = 1
    if resume_path and os.path.exists(resume_path):
        print(f"Resuming training from checkpoint: {resume_path}")
        checkpoint = torch.load(resume_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        if "scheduler_state_dict" in checkpoint:
            # The restored scheduler keeps the T_max it was saved with.
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        else:
            print("Checkpoint has no scheduler state; the LR schedule will not resume exactly.")
        start_epoch = checkpoint["epoch"] + 1

    os.makedirs(checkpoint_dir, exist_ok=True)

    model.train()
    for epoch in range(start_epoch, total_epochs + 1):
        epoch_loss = 0.0
        num_batches = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{total_epochs}", unit="batch")

        for (view_i, view_j), _ in progress_bar:
            view_i = view_i.to(device, non_blocking=True)
            view_j = view_j.to(device, non_blocking=True)

            z_i = model(view_i)
            z_j = model(view_j)

            # Explicit checks (unlike assert, these survive `python -O`).
            for label, projections in (("z_i", z_i), ("z_j", z_j)):
                if torch.isnan(projections).any():
                    raise RuntimeError(f"{label} contains NaN values!")

            try:
                loss = contrastive_loss(z_i, z_j)
            except (RuntimeError, ValueError) as e:
                # Best-effort: report and skip this batch rather than abort the epoch.
                print(f"Loss computation failed: {e}")
                continue

            optimizer.zero_grad()
            loss.backward()
            # Bound the update size if the loss spikes.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            loss_value = loss.item()
            epoch_loss += loss_value
            num_batches += 1
            progress_bar.set_postfix(loss=f"{loss_value:.4f}")

        scheduler.step()
        # Average over batches actually trained on, not the loader length.
        avg_loss = epoch_loss / max(num_batches, 1)
        print(f"Epoch {epoch}, Average Loss: {avg_loss:.4f}")

        save_path = os.path.join(checkpoint_dir, f"ssl_checkpoint_epoch_{epoch}.pth")
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
        }, save_path)
        print(f"Model saved to {save_path}")
| |
|
| |
|
if __name__ == "__main__":
    # Start pre-training only when this file is executed directly;
    # importing the module merely defines the model, loss and training routine.
    train_ssl()
| |
|
| |
|