jamino30 commited on
Commit
0f06e3f
·
verified ·
1 Parent(s): 534b489

Upload folder using huggingface_hub

Browse files
u2net/saved_models/best-u2net-duts.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:254a475b314f78277a6eb1d8eb549f588f4ead5e2a885fb177f66173c6d753b7
3
+ size 134
u2net/train.py CHANGED
@@ -48,13 +48,13 @@ if __name__ == '__main__':
48
  valid_batch_size = 80
49
  epochs = 100
50
 
51
- lr = 1e-4
52
  loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
53
 
54
  model_name = 'u2net-duts'
55
  model = U2Net()
56
  model = torch.nn.DataParallel(model.to(device))
57
- optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
58
 
59
  train_loader = DataLoader(
60
  ConcatDataset([DUTSDataset(split='train'), MSRADataset(split='train')]),
@@ -67,6 +67,7 @@ if __name__ == '__main__':
67
  num_workers=16, persistent_workers=True
68
  )
69
 
 
70
  losses = {'train': [], 'val': []}
71
  for epoch in tqdm(range(epochs), desc='Epochs'):
72
  torch.cuda.empty_cache()
@@ -75,10 +76,12 @@ if __name__ == '__main__':
75
  losses['train'].append(train_loss)
76
  losses['val'].append(val_loss)
77
 
78
- if (epoch + 1) % 10 == 0:
79
- torch.save(model.state_dict(), f'results/inter-{model_name}.pt')
 
 
80
 
81
- print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
82
 
83
  torch.save(model.state_dict(), f'results/{model_name}.pt')
84
  with open('results/loss.txt', 'wb') as f:
 
48
  valid_batch_size = 80
49
  epochs = 100
50
 
51
+ lr = 1e-3
52
  loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
53
 
54
  model_name = 'u2net-duts'
55
  model = U2Net()
56
  model = torch.nn.DataParallel(model.to(device))
57
+ optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
58
 
59
  train_loader = DataLoader(
60
  ConcatDataset([DUTSDataset(split='train'), MSRADataset(split='train')]),
 
67
  num_workers=16, persistent_workers=True
68
  )
69
 
70
+ best_val_loss = float('inf')
71
  losses = {'train': [], 'val': []}
72
  for epoch in tqdm(range(epochs), desc='Epochs'):
73
  torch.cuda.empty_cache()
 
76
  losses['train'].append(train_loss)
77
  losses['val'].append(val_loss)
78
 
79
+ if val_loss < best_val_loss:
80
+ best_val_loss = val_loss
81
+ torch.save(model.state_dict(), f'results/best-{model_name}.pt')
82
+ print('Best model saved.')
83
 
84
+ print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f} (Best: {best_val_loss:.4f})')
85
 
86
  torch.save(model.state_dict(), f'results/{model_name}.pt')
87
  with open('results/loss.txt', 'wb') as f: