|
from AGen_model import AGen_model |
|
from videos_dataset import create_video_dataloader |
|
|
|
import hydra |
|
import pytorch_lightning as pl |
|
from pytorch_lightning.loggers import WandbLogger |
|
import wandb |
|
import os |
|
import glob |
|
import torch |
|
|
|
@hydra.main(config_path="confs", config_name="base")
def main(opt):
    """Run the test loop of ``AGen_model`` from a saved checkpoint.

    Args:
        opt: Hydra config loaded from ``confs/base``. This function reads
            ``opt.project_name`` (W&B project and run name), ``opt.gpus``
            (iterable of GPU ids to expose to the process) and
            ``opt.videos_dataset.test`` (passed to
            ``create_video_dataloader``). The whole config is also handed
            to ``AGen_model``.
    """
    pl.seed_everything(42)
    print("Working dir:", os.getcwd())

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath="checkpoints/",
        filename="{epoch:04d}-{loss}",
        save_on_train_epoch_end=True,
        save_last=True,
    )

    wandb.require("service")
    logger = WandbLogger(
        project=opt.project_name,
        name=f"{opt.project_name}/test",
    )

    # Mask the visible GPUs *before* querying the device count below, so the
    # count reflects only the ids listed in the config.
    # NOTE(review): this only takes effect if CUDA has not been initialised
    # earlier in the process (e.g. at import time) -- confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, opt.gpus))

    # `devices=` replaces the `gpus=` Trainer argument, which is deprecated
    # and removed in Lightning 2.0; with accelerator="gpu" it has the same
    # meaning (number of GPUs to use).
    # NOTE(review): Lightning's docs recommend evaluating on a single device,
    # because DDP's DistributedSampler may replicate samples -- confirm that
    # multi-GPU testing is intended here.
    AGen_trainer = pl.Trainer(
        devices=torch.cuda.device_count(),
        accelerator="gpu",
        strategy="ddp",
        callbacks=[checkpoint_callback],
        max_epochs=8000,
        check_val_every_n_epoch=5,
        logger=logger,
        log_every_n_steps=1,
        num_sanity_val_steps=0,
    )

    model = AGen_model(opt)
    # TODO(review): hard-coded absolute checkpoint path; consider moving it
    # into the Hydra config.
    checkpoint = "/outputs/checkpoints/epoch=0013-loss=0.ckpt"
    testset = create_video_dataloader(opt.videos_dataset.test)

    AGen_trainer.test(model, testset, ckpt_path=checkpoint)
|
|
|
|
|
# Script entry point: Hydra's decorator on `main` supplies the config argument.
if __name__ == "__main__":
    main()