Upload folder using huggingface_hub
Browse files- u2net/saved_models/best-u2net-duts.pt +3 -0
- u2net/train.py +8 -5
u2net/saved_models/best-u2net-duts.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:254a475b314f78277a6eb1d8eb549f588f4ead5e2a885fb177f66173c6d753b7
|
3 |
+
size 134
|
u2net/train.py
CHANGED
@@ -48,13 +48,13 @@ if __name__ == '__main__':
|
|
48 |
valid_batch_size = 80
|
49 |
epochs = 100
|
50 |
|
51 |
-
lr = 1e-
|
52 |
loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
|
53 |
|
54 |
model_name = 'u2net-duts'
|
55 |
model = U2Net()
|
56 |
model = torch.nn.DataParallel(model.to(device))
|
57 |
-
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-
|
58 |
|
59 |
train_loader = DataLoader(
|
60 |
ConcatDataset([DUTSDataset(split='train'), MSRADataset(split='train')]),
|
@@ -67,6 +67,7 @@ if __name__ == '__main__':
|
|
67 |
num_workers=16, persistent_workers=True
|
68 |
)
|
69 |
|
|
|
70 |
losses = {'train': [], 'val': []}
|
71 |
for epoch in tqdm(range(epochs), desc='Epochs'):
|
72 |
torch.cuda.empty_cache()
|
@@ -75,10 +76,12 @@ if __name__ == '__main__':
|
|
75 |
losses['train'].append(train_loss)
|
76 |
losses['val'].append(val_loss)
|
77 |
|
78 |
-
if
|
79 |
-
|
|
|
|
|
80 |
|
81 |
-
print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
|
82 |
|
83 |
torch.save(model.state_dict(), f'results/{model_name}.pt')
|
84 |
with open('results/loss.txt', 'wb') as f:
|
|
|
48 |
valid_batch_size = 80
|
49 |
epochs = 100
|
50 |
|
51 |
+
lr = 1e-3
|
52 |
loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
|
53 |
|
54 |
model_name = 'u2net-duts'
|
55 |
model = U2Net()
|
56 |
model = torch.nn.DataParallel(model.to(device))
|
57 |
+
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
|
58 |
|
59 |
train_loader = DataLoader(
|
60 |
ConcatDataset([DUTSDataset(split='train'), MSRADataset(split='train')]),
|
|
|
67 |
num_workers=16, persistent_workers=True
|
68 |
)
|
69 |
|
70 |
+
best_val_loss = float('inf')
|
71 |
losses = {'train': [], 'val': []}
|
72 |
for epoch in tqdm(range(epochs), desc='Epochs'):
|
73 |
torch.cuda.empty_cache()
|
|
|
76 |
losses['train'].append(train_loss)
|
77 |
losses['val'].append(val_loss)
|
78 |
|
79 |
+
if val_loss < best_val_loss:
|
80 |
+
best_val_loss = val_loss
|
81 |
+
torch.save(model.state_dict(), f'results/best-{model_name}.pt')
|
82 |
+
print('Best model saved.')
|
83 |
|
84 |
+
print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f} (Best: {best_val_loss:.4f})')
|
85 |
|
86 |
torch.save(model.state_dict(), f'results/{model_name}.pt')
|
87 |
with open('results/loss.txt', 'wb') as f:
|