import os
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict

import numpy as np

import lightning.pytorch as pl
from lightning.pytorch import Callback
from lightning.pytorch.utilities.types import STEP_OUTPUT
from lightning_utilities.core.rank_zero import rank_zero_info
class SaveImagesHook(Callback):
    """Save per-sample images during validation/prediction and, optionally,
    collect all gathered samples into a single compressed ``.npz`` archive."""

    def __init__(self, save_dir="val", save_compressed=False):
        self.save_dir = save_dir
        self.save_compressed = save_compressed
        # Runtime state, (re)initialized in save_start().
        self.samples = []
        self.target_dir = None
        self.executor_pool = None
        self._saved_num = 0
    def save_start(self, target_dir):
        self.samples = []
        self.target_dir = target_dir
        # Image writes are I/O-bound, so offload them to a small thread pool.
        self.executor_pool = ThreadPoolExecutor(max_workers=8)
        os.makedirs(self.target_dir, exist_ok=True)
        # Refuse to overwrite a previous run unless this is a debug directory.
        if os.listdir(self.target_dir) and "debug" not in str(self.target_dir):
            raise FileExistsError(f"{self.target_dir} already exists and is not empty!")
        rank_zero_info(f"Saving images to {self.target_dir}")
        self._saved_num = 0
    def save_image(self, trainer, pl_module, images, metadatas):
        # Convert BCHW tensors to HWC numpy arrays for the save functions.
        images = images.permute(0, 2, 3, 1).cpu().numpy()
        for sample, metadata in zip(images, metadatas):
            # Each metadata dict supplies its own writer; skip samples without one.
            save_fn = metadata.pop("save_fn", None)
            if save_fn is not None:
                self.executor_pool.submit(save_fn, sample, metadata, self.target_dir)
    def process_batch(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        samples: STEP_OUTPUT,
        batch: Any,
    ) -> None:
        # Only the metadata is needed for saving; xT and y are unused here.
        xT, y, metadata = batch
        b, c, h, w = samples.shape
        # Always save individual images in uncompressed mode, but only the
        # first few samples when a compressed archive will be written anyway.
        if not self.save_compressed or self._saved_num < 10:
            self._saved_num += b
            self.save_image(trainer, pl_module, samples, metadata)
        if self.save_compressed:
            # all_gather is a collective op, so every rank must call it;
            # only rank zero keeps the result for the final archive.
            all_samples = pl_module.all_gather(samples).view(-1, c, h, w)
            if trainer.is_global_zero:
                self.samples.append(all_samples.permute(0, 2, 3, 1).cpu().numpy())
    def save_end(self):
        if self.save_compressed and len(self.samples) > 0:
            # Concatenate every gathered batch into one archive; the samples
            # are stored under numpy's default key "arr_0".
            samples = np.concatenate(self.samples)
            np.savez(f"{self.target_dir}/output.npz", arr_0=samples)
        # Block until all queued image writes have finished, then reset state.
        self.executor_pool.shutdown(wait=True)
        self.target_dir = None
        self.executor_pool = None
        self.samples = []
    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Validation outputs go under <root>/<save_dir>/iter_<global_step>.
        target_dir = os.path.join(trainer.default_root_dir, self.save_dir, f"iter_{trainer.global_step}")
        self.save_start(target_dir)

    def on_validation_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        return self.process_batch(trainer, pl_module, outputs, batch)

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.save_end()

    def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Prediction outputs go under <root>/<save_dir>/predict.
        target_dir = os.path.join(trainer.default_root_dir, self.save_dir, "predict")
        self.save_start(target_dir)

    def on_predict_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        samples: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        return self.process_batch(trainer, pl_module, samples, batch)

    def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.save_end()
    def state_dict(self) -> Dict[str, Any]:
        # The callback keeps no state worth checkpointing.
        return dict()
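

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The callback expects each metadata dict
# coming out of the dataloader to carry a picklable "save_fn" with the
# signature shown below; "example_save_fn" and the "filename" key are
# hypothetical names chosen for this sketch, not part of the module above.
def example_save_fn(sample, metadata, target_dir):
    """Write one HWC float array in [0, 1] as a PNG (assumes 3-channel RGB)."""
    from PIL import Image  # assumed dependency for this sketch only

    array = (sample * 255.0).clip(0, 255).astype("uint8")
    Image.fromarray(array).save(os.path.join(target_dir, metadata["filename"]))


if __name__ == "__main__":
    # Minimal wiring: during validation, images land under
    # <default_root_dir>/<save_dir>/iter_<global_step>; a path containing
    # "debug" bypasses the non-empty-directory check in save_start().
    trainer = pl.Trainer(
        default_root_dir="runs/debug",
        callbacks=[SaveImagesHook(save_dir="val", save_compressed=False)],
    )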