File size: 646 Bytes
e41b635
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import time
import pytorch_lightning as pl

class TimeLimitCallback(pl.Callback):
    """Stop training once a wall-clock time budget has been used up.

    The clock is started in ``on_train_start``. After every training batch
    the elapsed time is compared with the budget; once the budget is
    reached, ``trainer.should_stop`` is set to ``True``.

    Args:
        max_duration_seconds: Time budget in seconds, measured with
            ``time.time()`` from the moment training starts.
    """

    def __init__(self, max_duration_seconds: float) -> None:
        super().__init__()
        self.max_duration_seconds = max_duration_seconds
        # Timestamp (``time.time()``) taken when training starts; ``None``
        # until the clock has been started.
        self.start_time = None
        # Ensures the stop message is printed only once per training run,
        # even if more than one batch-end hook fires for the same batch.
        self._limit_announced = False

    def on_train_start(self, trainer, pl_module) -> None:
        """Start (or restart) the clock at the beginning of training."""
        self.start_time = time.time()
        self._limit_announced = False

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs) -> None:
        """Check the time budget after each training batch.

        The remaining hook arguments (outputs, batch, batch index, ...)
        differ between Lightning releases and are not needed here, so they
        are accepted generically and ignored.
        """
        self._check_time_limit(trainer)

    def on_batch_end(self, trainer, pl_module) -> None:
        """Legacy hook kept for backward compatibility.

        ``Callback.on_batch_end`` was deprecated in Lightning v1.6 and
        removed in v1.8 in favour of ``on_train_batch_end``, so current
        releases never dispatch it.

        NOTE(review): some Lightning releases are believed to reject
        callbacks that still define this removed hook -- confirm against
        the targeted version and drop this method if so.
        """
        self._check_time_limit(trainer)

    def _check_time_limit(self, trainer) -> None:
        """Request a stop on ``trainer`` if the time budget is exhausted."""
        if self.start_time is None:
            # The hook ran without ``on_train_start`` having been called;
            # start the clock now instead of failing on ``float - None``.
            self.start_time = time.time()

        elapsed_time = time.time() - self.start_time
        if elapsed_time < self.max_duration_seconds:
            return

        if not self._limit_announced:
            print(f"Training stopped due to time constraint ({self.max_duration_seconds} seconds).")
            self._limit_announced = True
        trainer.should_stop = True  # Stops training