import time

import pytorch_lightning as pl
class TimeLimitCallback(pl.Callback):
    """Request that training stop once a time budget has been used up.

    The clock is started in ``on_train_start``. After every training batch
    the elapsed time is compared with ``max_duration_seconds``; once the
    budget is reached, ``trainer.should_stop`` is set to ``True``.

    Args:
        max_duration_seconds: Time budget for training, in seconds.
    """

    def __init__(self, max_duration_seconds):
        super().__init__()
        self.max_duration_seconds = max_duration_seconds
        # Reading of ``time.monotonic()`` taken when training starts;
        # ``None`` until the clock has been started.
        self.start_time = None

    def on_train_start(self, trainer, pl_module):
        """Start the clock when training begins."""
        # A monotonic clock is used so that system clock adjustments
        # cannot shorten or lengthen the measured duration.
        self.start_time = time.monotonic()

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        """Check the time budget after a training batch.

        The remaining hook arguments (outputs, batch, batch index, ...)
        are accepted and ignored, so the callback does not depend on the
        exact hook signature.
        """
        if self.start_time is None:
            # ``on_train_start`` has not run for this instance; start the
            # clock now rather than failing on ``float - None``.
            self.start_time = time.monotonic()
            return

        elapsed_time = time.monotonic() - self.start_time
        if elapsed_time >= self.max_duration_seconds:
            print(f"Training stopped due to time constraint ({self.max_duration_seconds} seconds).")
            trainer.should_stop = True  # Stops training