deepfake / utils /utils.py
Dharshaneshwaran
Full updated code with finding ai generated images too
ddcedb5
raw
history blame
1.46 kB
import contextlib
import numpy as np
import random
import shutil
import os
import torch
def set_seed(seed):
    """Seed every RNG this module uses and force deterministic cuDNN kernels.

    Python's ``random``, NumPy's global generator and PyTorch's CPU generator
    are always seeded. When CUDA is available, the CUDA generators are seeded
    as well.

    Args:
        seed: Integer seed forwarded unchanged to each seeding function.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
def save_checkpoint(state, is_best, checkpoint_path, filename="checkpoint.pt"):
    """Write ``state`` to ``checkpoint_path/filename`` with ``torch.save``.

    If ``is_best`` is truthy, the written file is also copied to
    ``checkpoint_path/model_best.pt``.

    Args:
        state: Object to serialize. Presumably a dict containing a
            ``"state_dict"`` entry, since ``load_checkpoint`` reads that key.
            TODO confirm with callers.
        is_best: Whether to additionally keep a copy as ``model_best.pt``.
        checkpoint_path: Directory to write into. Created, including parents,
            if it does not exist. An empty string means the current directory.
        filename: File name of the checkpoint inside ``checkpoint_path``.
    """
    if checkpoint_path:
        # torch.save does not create missing parent directories.
        os.makedirs(checkpoint_path, exist_ok=True)
    filename = os.path.join(checkpoint_path, filename)
    torch.save(state, filename)
    if is_best:
        best_path = os.path.join(checkpoint_path, "model_best.pt")
        # shutil.copyfile raises SameFileError when source and destination
        # are the same file. In that case the best copy was just written.
        if os.path.abspath(best_path) != os.path.abspath(filename):
            shutil.copyfile(filename, best_path)
def load_checkpoint(model, path, map_location=None, strict=True):
    """Load the ``"state_dict"`` entry of a saved checkpoint into ``model``.

    Args:
        model: Object whose ``load_state_dict`` receives the stored weights,
            such as a ``torch.nn.Module``.
        path: Path to a file written by ``torch.save``.
        map_location: Forwarded to ``torch.load``. Pass ``"cpu"`` to load a
            checkpoint saved on GPU onto a CPU-only machine. ``None`` keeps
            torch's default behaviour.
        strict: Forwarded to ``model.load_state_dict``. Controls whether the
            checkpoint keys must match the model's keys exactly.

    Returns:
        The whole loaded checkpoint object, so callers can read any extra
        entries stored alongside ``"state_dict"``.

    Raises:
        KeyError: If the loaded mapping has no ``"state_dict"`` key.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects (depending on
    # the torch version's weights_only default). Only load trusted files.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint["state_dict"], strict=strict)
    return checkpoint
def log_metrics(set_name, metrics, logger):
    """Log one INFO line with the loss and both accuracy metrics.

    Args:
        set_name: Label that prefixes the logged line.
        metrics: Mapping that must provide ``"loss"``, ``"spec_acc"`` and
            ``"rgb_acc"``. Each value must support ``:.5f`` formatting.
        logger: Object exposing an ``info`` method, such as a
            ``logging.Logger``.
    """
    loss = metrics["loss"]
    spec_acc = metrics["spec_acc"]
    rgb_acc = metrics["rgb_acc"]
    line = (
        f"{set_name}: Loss: {loss:.5f} | "
        f"spec_acc: {spec_acc:.5f}, rgb_acc: {rgb_acc:.5f}"
    )
    logger.info(line)
@contextlib.contextmanager
def numpy_seed(seed, *addl_seeds):
    """Temporarily seed NumPy's global PRNG, then restore its previous state.

    Args:
        seed: Seed for ``np.random.seed``. ``None`` makes this a no-op
            context manager that leaves the RNG untouched.
        *addl_seeds: Optional extra values. When present, they are combined
            with ``seed`` as ``int(hash((seed, *addl_seeds)) % 1e6)``.

    NOTE: ``hash`` of ``str``/``bytes`` values is randomized per process
    unless ``PYTHONHASHSEED`` is set. The combined seed is therefore only
    reproducible across runs when all components hash deterministically
    (e.g. ints).
    """
    if seed is not None:
        combined_seed = seed
        if addl_seeds:
            combined_seed = int(hash((seed, *addl_seeds)) % 1e6)
        saved_state = np.random.get_state()
        np.random.seed(combined_seed)
        try:
            yield
        finally:
            # Restore even if the managed block raised.
            np.random.set_state(saved_state)
    else:
        yield