|
import os
import glob
import yaml

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torchvision import models
from pytorch_lightning.loggers import WandbLogger
import wandb

from VideoReconstructionModel.code.v2a_model import V2AModel
from lib.model.v2a import V2A
from VideoReconstructionModel.code.lib.model.networks import ImplicitNet, RenderingNet, GeometryEncodingNet
from VideoReconstructionModel.code.lib.datasets import create_dataset
from Callbacks import TimeLimitCallback
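
# `TimeLimitCallback` comes from the local `Callbacks` module, which is not
# shown here. A minimal sketch of a callback with this interface (an
# assumption, not the actual implementation) would stop an inner trainer once
# a wall-clock budget is spent:
#
#     import time
#     from pytorch_lightning.callbacks import Callback
#
#     class TimeLimitCallback(Callback):
#         def __init__(self, max_duration_seconds):
#             self.max_duration_seconds = max_duration_seconds
#
#         def on_fit_start(self, trainer, pl_module):
#             self._start_time = time.monotonic()
#
#         def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
#             # Ask Lightning to stop gracefully once the budget is exhausted.
#             if time.monotonic() - self._start_time > self.max_duration_seconds:
#                 trainer.should_stop = True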
|
|
|
class AGen_model(pl.LightningModule):

    def __init__(self, opt):
        super().__init__()

        self.opt = opt

        # Networks shared across all videos: they are passed to every
        # per-video V2AModel run, so their weights are updated video by video.
        self.implicit_network = ImplicitNet(opt.model.implicit_network)
        self.rendering_network = RenderingNet(opt.model.rendering_network)
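
    # Note on the overall flow: the `*_step` hooks below do not return a loss
    # to this outer Lightning loop. Each one spawns an inner `pl.Trainer` that
    # fits or evaluates a per-video V2AModel built around the shared
    # `implicit_network` and `rendering_network` above, so the shared weights
    # accumulate training signal across videos.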
    def training_step(self, batch, batch_idx):
        torch.cuda.set_device(self.device)
        video_path = batch[0]

        # Load this video's metainfo and attach it to the shared options.
        metainfo_path = os.path.join(video_path, 'confs', 'video_metainfo.yaml')
        with open(metainfo_path, 'r') as file:
            loaded_config = yaml.safe_load(file)
        self.opt.dataset.metainfo = loaded_config.get('metainfo', {})

        video_outputs_folder = os.path.abspath(os.path.join(os.getcwd(), 'Video', self.opt.dataset.metainfo.data_dir))
        video_checkpoints_folder = os.path.abspath(os.path.join(os.getcwd(), 'Video', self.opt.dataset.metainfo.data_dir, 'checkpoints'))
        # If checkpoints already exist for this video, skip the SMPL
        # initialisation and resume from them below.
        if os.path.exists(video_checkpoints_folder):
            self.opt.model.smpl_init = False

        # During this step, video reconstruction is performed on each
        # individual video, training the implicit network, the rendering
        # network, and the corresponding encodings.

        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            dirpath=video_checkpoints_folder,
            filename="{epoch:04d}-{loss}",
            save_on_train_epoch_end=True,
            save_last=True)
        time_limit_callback = TimeLimitCallback(max_duration_seconds=self.opt.max_duration_seconds)

        trainset = create_dataset(self.opt.dataset.metainfo, self.opt.dataset.train)
        validset = create_dataset(self.opt.dataset.metainfo, self.opt.dataset.valid)

        if os.path.exists(video_checkpoints_folder):
            # Resume the per-video reconstruction from its last checkpoint,
            # re-attaching the shared networks so their weights keep training.
            checkpoint = os.path.join(video_checkpoints_folder, 'last.ckpt')
            v2a_trainer = pl.Trainer(
                gpus=[self.device.index],
                accelerator="gpu",
                callbacks=[checkpoint_callback, time_limit_callback],
                max_epochs=8000,
                logger=self.logger,
                log_every_n_steps=1,
                num_sanity_val_steps=0,
                resume_from_checkpoint=checkpoint,
                enable_progress_bar=False,
                enable_model_summary=False
            )
            model = V2AModel.load_from_checkpoint(checkpoint, opt=self.opt, implicit_network=self.implicit_network,
                                                  rendering_network=self.rendering_network)
            model.model.implicit_network = self.implicit_network
            model.model.rendering_network = self.rendering_network
        else:
            # First pass over this video: train a fresh V2AModel from scratch.
            v2a_trainer = pl.Trainer(
                gpus=[self.device.index],
                accelerator="gpu",
                callbacks=[checkpoint_callback, time_limit_callback],
                max_epochs=8000,
                logger=self.logger,
                log_every_n_steps=1,
                num_sanity_val_steps=0,
                enable_progress_bar=False,
                enable_model_summary=False
            )
            model = V2AModel(self.opt, implicit_network=self.implicit_network,
                             rendering_network=self.rendering_network)

        v2a_trainer.fit(model, trainset)

        v2a_trainer.validate(model, validset)
        validation_metrics = v2a_trainer.callback_metrics

        # Optimisation happens inside the inner trainer, so no loss is
        # returned to the outer loop.
        return

    def configure_optimizers(self):
        # Optimiser over the outer module's parameters (the shared networks).
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def validation_step(self, batch, *args, **kwargs):
        '''This validation is performed at inference time on the unseen videos in the validset.'''
        torch.cuda.set_device(self.device)
        video_path = batch[0]

        metainfo_path = os.path.join(video_path, 'confs', 'video_metainfo.yaml')
        with open(metainfo_path, 'r') as file:
            loaded_config = yaml.safe_load(file)
        self.opt.dataset.metainfo = loaded_config.get('metainfo', {})

        video_outputs_folder = os.path.abspath(os.path.join(os.getcwd(), 'Video', self.opt.dataset.metainfo.data_dir))
        video_checkpoints_folder = os.path.abspath(os.path.join(os.getcwd(), 'Video', self.opt.dataset.metainfo.data_dir, 'checkpoints'))
        if os.path.exists(video_checkpoints_folder):
            self.opt.model.smpl_init = False

        trainset = create_dataset(self.opt.dataset.metainfo, self.opt.dataset.train)
        validset = create_dataset(self.opt.dataset.metainfo, self.opt.dataset.valid)

        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            dirpath=video_checkpoints_folder,
            filename="{epoch:04d}-{loss}",
            save_on_train_epoch_end=True,
            save_last=True)
        time_limit_callback = TimeLimitCallback(max_duration_seconds=self.opt.max_duration_seconds)

        v2a_trainer = pl.Trainer(
            gpus=[self.device.index],
            accelerator="gpu",
            callbacks=[checkpoint_callback, time_limit_callback],
            max_epochs=8000,
            logger=self.logger,
            log_every_n_steps=1,
            num_sanity_val_steps=0,
            enable_progress_bar=False,
            enable_model_summary=False
        )
        model = V2AModel(self.opt, implicit_network=self.implicit_network,
                         rendering_network=self.rendering_network)

        # Evaluate the current shared weights on the unseen video without
        # updating them.
        self.implicit_network.eval()
        self.rendering_network.eval()
        model.eval()

        v2a_trainer.validate(model, validset)
        validation_metrics = v2a_trainer.callback_metrics

        return

    def validation_step_end(self, outputs):
        pass

    def validation_epoch_end(self, outputs):
        # No-op: validation metrics are produced by the inner v2a_trainer.
        pass

    def test_step(self, batch, *args, **kwargs):
        torch.cuda.set_device(self.device)
        video_path = batch[0]

        metainfo_path = os.path.join(video_path, 'confs', 'video_metainfo.yaml')
        with open(metainfo_path, 'r') as file:
            loaded_config = yaml.safe_load(file)
        self.opt.dataset.metainfo = loaded_config.get('metainfo', {})

        video_outputs_folder = os.path.abspath(os.path.join(os.getcwd(), 'Video', self.opt.dataset.metainfo.data_dir))
        video_checkpoints_folder = os.path.abspath(os.path.join(os.getcwd(), 'Video', self.opt.dataset.metainfo.data_dir, 'checkpoints'))

        trainset = create_dataset(self.opt.dataset.metainfo, self.opt.dataset.train)
        validset = create_dataset(self.opt.dataset.metainfo, self.opt.dataset.valid)
        testset = create_dataset(self.opt.dataset.metainfo, self.opt.dataset.test)

        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            dirpath=video_checkpoints_folder,
            filename="{epoch:04d}-{loss}",
            save_on_train_epoch_end=True,
            save_last=True)

        v2a_trainer = pl.Trainer(
            gpus=[self.device.index],
            accelerator="gpu",
            callbacks=[checkpoint_callback],
            max_epochs=self.opt.refinement_epochs,
            logger=self.logger,
            log_every_n_steps=1,
            num_sanity_val_steps=0,
            enable_progress_bar=False,
            enable_model_summary=False
        )

        model = V2AModel(self.opt, implicit_network=self.implicit_network,
                         rendering_network=self.rendering_network)

        # When testing without the pretrained shared networks, re-initialise
        # them from scratch so the evaluated baseline is untrained.
        if not self.opt.videos_dataset.test.pretrained:
            self.implicit_network = ImplicitNet(self.opt.model.implicit_network)
            self.rendering_network = RenderingNet(self.opt.model.rendering_network)

            model = V2AModel(self.opt, implicit_network=self.implicit_network,
                             rendering_network=self.rendering_network)

            self.opt.dataset.metainfo.type = "test-non-pretrained"

        # Optionally refine on the test video for `refinement_epochs` before
        # evaluating.
        if self.opt.videos_dataset.test.mode == "short_time":
            v2a_trainer.fit(model, trainset)

        self.implicit_network.eval()
        self.rendering_network.eval()
        model.eval()

        if self.opt.videos_dataset.test.size == "reduced":
            v2a_trainer.validate(model, validset)
            validation_metrics = v2a_trainer.callback_metrics
        elif self.opt.videos_dataset.test.size == "full":
            v2a_trainer.test(model, testset)

        return
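
# Illustrative driver (hypothetical names; the real entry point is not part of
# this file). The outer dataset is assumed to yield video directory paths, so
# that `batch[0]` in the steps above receives a single path string:
#
#     # opt = ...  # experiment options loaded elsewhere (e.g. from a yaml config)
#     # video_paths = sorted(glob.glob(os.path.join(os.getcwd(), 'Video', '*')))
#     # videos_loader = torch.utils.data.DataLoader(video_paths, batch_size=1)
#     # agen = AGen_model(opt)
#     # pl.Trainer(gpus=1, accelerator="gpu", max_epochs=1).fit(agen, videos_loader)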