|
|
|
|
|
import torch |
|
import pandas as pd |
|
import numpy as np |
|
import os |
|
import PIL.Image |
|
import random |
|
import custom_transforms as ctrans |
|
import math |
|
import utils as ut |
|
|
|
from torchvision import transforms |
|
|
|
from torch.utils.data.distributed import DistributedSampler |
|
from custom_sampler import DistributedEvalSampler |
|
from functools import partial |
|
import datasets as ds |
|
import io |
|
import logging |
|
|
|
class dataset_huggingface(torch.utils.data.Dataset):
    """
    Dataset for Community Forensics, backed by a Hugging Face ``datasets`` split.

    Each item is ``(img, label, generator_name)``:
      * ``img``: PIL RGB image decoded from the ``image_data`` bytes column,
      * ``label``: ``int(row['label'])``,
      * ``generator_name``: the ``model_name`` column.
    """

    # Legacy, site-specific CSV (columns ``label`` / ``token``) that used to be read
    # unconditionally, crashing on any machine without it. It is now only a fallback,
    # consulted when it exists and no token was given via ``args.hf_token``.
    _LEGACY_TOKEN_CSV = "/nfs/turbo/coe-ahowens/jespark/tokens.csv"

    def __init__(
        self,
        args,
        repo_id='OwensLab/CommunityForensics',
        split='Systematic+Manual',
        mode='train',
        cache_dir='',
        dtype=torch.float32,
    ):
        """
        args: Namespace of argument parser (reads ``seed`` and, optionally, ``hf_token``)
        repo_id: Hugging Face dataset repository id
        split: split of the dataset to use
        mode: 'train' (split is shuffled with ``args.seed``) or 'eval' (no shuffling)
        cache_dir: directory to cache the dataset
        dtype: data type (stored on the instance; not used by this class itself)

        Raises:
            ValueError: if ``mode`` is neither 'train' nor 'eval'.
        """
        # The original `super(dataset_huggingface).__init__()` created an unbound super
        # object and never invoked the parent initializer.
        super().__init__()
        self.args = args
        self.repo_id = repo_id
        self.split = split
        self.mode = mode
        self.cache_dir = cache_dir
        self.dtype = dtype
        self.dataset = self.get_hf_dataset()

    def __getitem__(self, index):
        """
        Returns ``(img, label, generator_name)`` for the given index.
        """
        data = self.dataset[index]
        label = int(data['label'])
        generator_name = data['model_name']
        img = PIL.Image.open(io.BytesIO(data['image_data'])).convert("RGB")
        return img, label, generator_name

    def _resolve_hf_token(self):
        """
        Return the Hugging Face token to pass to ``load_dataset``.

        Resolution order:
          1. ``args.hf_token`` if present and non-empty;
          2. the ``huggingface_write_token`` row of the legacy token CSV, if that file exists;
          3. ``None``, which leaves token resolution to the Hugging Face libraries' defaults.
        """
        token = getattr(self.args, 'hf_token', None)
        if token:
            return token
        if os.path.exists(self._LEGACY_TOKEN_CSV):
            token_df = pd.read_csv(self._LEGACY_TOKEN_CSV)
            matches = token_df.loc[token_df['label'] == 'huggingface_write_token', 'token'].values
            if len(matches) > 0:
                return matches[0]
        return None

    def get_hf_dataset(self):
        """
        Returns the huggingface dataset object.

        Loads ``self.split`` of ``self.repo_id`` and, in 'train' mode, shuffles it
        deterministically with ``self.args.seed``.

        Raises:
            ValueError: if ``self.mode`` is neither 'train' nor 'eval' (previously this
                surfaced as an ``UnboundLocalError`` on ``shuffle``).
        """
        if self.mode == 'train':
            shuffle = True
        elif self.mode == 'eval':
            shuffle = False
        else:
            raise ValueError(f"Unknown mode '{self.mode}'. Expected 'train' or 'eval'.")

        hf_dataset = ds.load_dataset(
            self.repo_id,
            split=self.split,
            cache_dir=self.cache_dir,
            token=self._resolve_hf_token(),
        )
        if shuffle:
            # writer_batch_size: rows per cache-file write (see `datasets.Dataset.shuffle`).
            hf_dataset = hf_dataset.shuffle(seed=self.args.seed, writer_batch_size=3000)
        return hf_dataset

    def __len__(self):
        """
        Returns the length of the dataset.
        """
        return len(self.dataset)
|
|
|
class dataset_folder_based(torch.utils.data.Dataset):
    """
    Dataset for sourcing images from a directory; designed to be used with the huggingface datasets library.

    Each item is ``(img, label, generator_name)`` with ``img`` a PIL RGB image. The
    directory is indexed into ``<dir>/index.csv`` (columns ``ImagePath``, ``Label``,
    ``GeneratorName``); an existing index file is reused instead of re-walking the tree.
    """

    # File extensions treated as images by `index_directory` (compared case-insensitively,
    # so e.g. ImageNet-style ".JPEG" files are included).
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp', '.gif')
    _INDEX_COLUMNS = ['ImagePath', 'Label', 'GeneratorName']

    def __init__(
        self,
        args,
        dir,
        labels="real:0,fake:1",
        logger: logging.Logger = None,
        dtype=torch.float32,
    ):
        """
        args: Namespace of argument parser (reads ``rank`` and ``seed``)
        dir: directory to index
        labels: labels for the dataset. Default: "real:0,fake:1" -- assigns integer label 0 to images under "real" and 1 to images under "fake".
        logger: logger to use; defaults to ``ut.logger``
        dtype: data type (stored on the instance; not used by this class itself)

        The directory must be formatted as follows:
        - <generator_or_dataset_name>
            ∟ <label -- "real" or "fake">
                ∟ <image_name>.{jpg,png,...}
        `dir` should point to the parent directory of the `generator_or_dataset_name` folders.
        """
        # The original `super(dataset_folder_based).__init__()` never called the parent initializer.
        super().__init__()
        self.args = args
        self.dir = dir
        self.labels = self.parse_labels(labels)
        assert len(self.labels) == 2, f"Labels must be in the format 'label1:int,label2:int'. It only supports two labels. Instead, it is: {labels}."

        self.logger = logger
        if self.logger is None:
            self.logger = ut.logger
        self.dtype = dtype
        self.df = self.get_index(dir)

    def __getitem__(self, index):
        """
        Returns the image and label for the given index.
        """
        # Look the row up once instead of three separate `iloc` calls.
        row = self.df.iloc[index]
        img_path = row['ImagePath']
        label = int(row['Label'])
        generator_name = row['GeneratorName']

        img = PIL.Image.open(img_path).convert("RGB")
        return img, label, generator_name

    def __len__(self):
        """
        Returns the length of the dataset.
        """
        return len(self.df)

    def parse_labels(self, labels):
        """
        Parses the labels string (e.g. "real:0,fake:1") and returns a dictionary of labels.

        Whitespace around names and values is ignored, so "real:0, fake:1" also works.
        """
        labels_dict = {}
        for label in labels.split(','):
            label_name, label_value = label.split(':')
            labels_dict[label_name.strip()] = int(label_value)
        return labels_dict

    def get_label_int(self, label):
        """
        Returns the integer label for the given label name.

        Raises:
            ValueError: if ``label`` is not one of the configured label names.
        """
        if label in self.labels:
            return self.labels[label]
        raise ValueError(f"Label {label} not found in labels: {self.labels}. Please check the labels.")

    def get_index(self, dir):
        """
        Check the `dir` for the index file. If it exists, load it. If not, index the directory and save the index file.
        """
        index_path = os.path.join(dir, 'index.csv')
        if os.path.exists(index_path):
            df = pd.read_csv(index_path)
            if self.args.rank == 0:
                self.logger.info(f"Loaded index file from {index_path}")
        else:
            if self.args.rank == 0:
                self.logger.info(f"Index file not found. Indexing the directory {dir}. This may take a while...")
            df = self.index_directory(dir)
        return df

    def index_directory(self, dir, report_every=1000):
        """
        Indexes the given directory and returns a dataframe with the image paths, labels, and generator names.

        The directory must be formatted as follows:
        - <generator_or_dataset_name>
            ∟ <label -- "real" or "fake">
                ∟ <image_name>.{jpg,png,...}
        `dir` should point to the parent directory of the `generator_or_dataset_name` folders.

        The result is sorted by (GeneratorName, Label, ImagePath) and written to
        ``<dir>/index.csv`` by rank 0 only. An empty directory yields an empty dataframe
        with the index columns (previously `pd.concat([])` raised).

        Raises:
            ValueError: if an image sits in a folder whose name is not a configured label.
        """
        is_main = self.args.rank == 0
        # Collect plain tuples and build one DataFrame at the end, instead of one
        # single-row DataFrame per image.
        rows = []
        for root, _dirs, files in os.walk(dir):
            for file in files:
                if not file.lower().endswith(self._IMAGE_EXTENSIONS):
                    continue
                generator_name = os.path.basename(os.path.dirname(root))
                label_int = self.get_label_int(os.path.basename(root))
                rows.append((os.path.join(root, file), label_int, generator_name))
                if is_main and len(rows) % report_every == 0:
                    print(f"\rIndexed {len(rows)} images... ", end='', flush=True)
        if is_main:
            print("")

        df = pd.DataFrame(rows, columns=self._INDEX_COLUMNS)
        df = df.sort_values(by=['GeneratorName', 'Label', 'ImagePath'])
        df = df.reset_index(drop=True)

        index_path = os.path.join(dir, 'index.csv')
        # Only one process writes the shared index file; every rank still returns its own df.
        if is_main:
            df.to_csv(index_path, index=False)
            self.logger.info(f"Indexed the directory {dir} and saved the index file to {index_path}")
        return df

    def limit_real_data(self, df, num_max_images):
        """
        Limits the real data to contain `num_max_images` total images by preserving the smallest datasets first.

        Real rows are those with ``Label == 0`` and fake rows those with ``Label == 1``.
        Whole generator groups are kept, smallest first, while they fit. The remaining
        (larger) groups are each capped at a common per-group size, sampled with
        ``self.args.seed``. Fake rows are appended unchanged. Negative limits are
        treated as 0.
        """
        num_max_images = max(0, num_max_images)
        real_df = df[df['Label'] == 0]
        fake_df = df[df['Label'] == 1]

        if len(real_df) <= num_max_images:
            self.logger.info(f"The size of real data: {len(real_df)} is less than or equal to the target size: {num_max_images}. No need to limit the real data. Note that the original model is trained with near 50/50 real/fake to avoid bias -- too much deviation from this may lead to unwanted detection bias.")
            return df

        dataset_counts = real_df['GeneratorName'].value_counts().sort_values(ascending=True)

        # Walk groups from smallest to largest. Stop at the first group where giving every
        # remaining group this group's size would reach the budget; then split the
        # leftover budget evenly (rounded up) across the remaining groups.
        smallest_sum = 0
        smallest_idx = 0
        num_not_appended_datasets = len(dataset_counts)
        while True:
            perModelLen = dataset_counts.iloc[smallest_idx]
            if (perModelLen * num_not_appended_datasets + smallest_sum) >= num_max_images:
                perModelLen = math.ceil((num_max_images - smallest_sum) / num_not_appended_datasets)
                break
            elif smallest_idx == len(dataset_counts) - 1:
                break
            else:
                smallest_sum += dataset_counts.iloc[smallest_idx]
                smallest_idx += 1
                num_not_appended_datasets -= 1

        # Larger groups, capped at perModelLen.
        parts = []
        for dataset_name in dataset_counts.index[smallest_idx:]:
            dataset_df = real_df[real_df['GeneratorName'] == dataset_name]
            if len(dataset_df) > perModelLen:
                dataset_df = dataset_df.sample(n=perModelLen, random_state=self.args.seed)
            parts.append(dataset_df)
        # Smaller groups, kept in full.
        for dataset_name in dataset_counts.index[:smallest_idx]:
            parts.append(real_df[real_df['GeneratorName'] == dataset_name])
        new_df = pd.concat(parts, ignore_index=True)

        if self.args.rank == 0:
            self.logger.info(f"Max images per dataset limited to {perModelLen}. Affected datasets: {dataset_counts.index[smallest_idx:]}")
            composition = new_df['GeneratorName'].value_counts()
            composition = composition / composition.sum() * 100
            # Scoped option: the original set pd.options.display.float_format globally
            # and never restored it, changing float formatting process-wide.
            with pd.option_context('display.float_format', '{:.2f} %'.format):
                self.logger.info(f"Dataset composition: \n{composition}")

        return pd.concat([new_df, fake_df], ignore_index=True)
|
|
|
def determine_resize_crop_sizes(args):
    """
    Determine resize and crop sizes based on input size.

    Args:
        args: Namespace of argument parser; reads ``args.input_size``.

    Returns:
        ``(resize_size, crop_size)``: ``(256, 224)`` for input size 224 and
        ``(440, 384)`` for input size 384.

    Raises:
        ValueError: for any other ``input_size`` (previously an ``UnboundLocalError``).
    """
    supported_sizes = {
        224: (256, 224),
        384: (440, 384),
    }
    try:
        return supported_sizes[args.input_size]
    except KeyError:
        raise ValueError(
            f"Unsupported input_size {args.input_size!r}; expected one of {sorted(supported_sizes)}."
        ) from None
|
|
|
def get_transform(args, mode="train", dtype=torch.float32):
    """
    Build the preprocessing pipeline applied to each PIL image.

    Args:
        args: Namespace of argument parser. Reads ``input_size`` (via
            ``determine_resize_crop_sizes``) and, in train mode, ``rsa_ops``,
            ``rsa_min_num_ops`` and ``rsa_max_num_ops``.
        mode: "train" -> Resize, optional RandomStateAugmentation, RandomCrop;
              "val" / "test" -> Resize, CenterCrop.
              Both then apply ToTensor_range(0, 1), Normalize and ConvertImageDtype(dtype).
        dtype: dtype the output tensor is converted to as the last step.

    Returns:
        ``transforms.Compose`` of the steps above.

    Raises:
        ValueError: if ``mode`` is not "train", "val" or "test" (previously an empty,
            identity pipeline was returned silently).
    """
    if mode not in ("train", "val", "test"):
        raise ValueError(f"Unknown transform mode '{mode}'. Expected 'train', 'val' or 'test'.")

    # Standard ImageNet per-channel normalization constants.
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
    resize_size, crop_size = determine_resize_crop_sizes(args)

    augment_list = [transforms.Resize(resize_size)]
    if mode == "train":
        # Random augmentations are inserted between the resize and the random crop.
        if args.rsa_ops != '':
            augment_list.append(ctrans.RandomStateAugmentation(
                resize_size=resize_size,
                crop_size=crop_size,
                auglist=args.rsa_ops,
                min_augs=args.rsa_min_num_ops,
                max_augs=args.rsa_max_num_ops,
            ))
        augment_list.append(transforms.RandomCrop(crop_size))
    else:
        augment_list.append(transforms.CenterCrop(crop_size))

    augment_list.extend([
        ctrans.ToTensor_range(val_min=0, val_max=1),
        transforms.Normalize(mean=norm_mean, std=norm_std),
        transforms.ConvertImageDtype(dtype),
    ])
    return transforms.Compose(augment_list)
|
|
|
class SubsetWithTransform(torch.utils.data.Dataset):
    """
    Wrap a dataset (e.g. a split from ``random_split``) so it gets its own transform.

    Items of the wrapped dataset must be ``(img, label, generator_name)`` triples; only
    ``img`` is passed through ``transform`` (when ``transform`` is truthy). The length
    is read from the wrapped dataset once, at construction time.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform
        self.subset_len = len(subset)

    def __getitem__(self, index):
        image, label, source_name = self.subset[index]
        output = self.transform(image) if self.transform else image
        return output, label, source_name

    def __len__(self):
        return self.subset_len
|
|
|
def set_seeds_for_data(seed=11997733):
    """
    Seed PyTorch, Python's ``random`` and NumPy's global RNG with the same value.

    Used to split the dataset consistently across DDP instances.
    """
    # Seeding order matches the original: torch, then random, then numpy.
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)
|
|
|
def set_seeds_for_worker(seed=11997733, id=0):
    """
    Seed Python's ``random`` and NumPy with ``(seed mod 2**31) + id``.

    PyTorch seeding is handled by the torch.Generator passed into the DataLoader.
    """
    worker_seed = seed % (2**31) + id
    random.seed(worker_seed)
    np.random.seed(worker_seed)
|
|
|
def worker_seed_reporter(id=None):
    """
    Debug: print the seeds of the current DataLoader worker.

    Args:
        id: zero-based worker id to display; defaults to the id reported by
            ``torch.utils.data.get_worker_info()`` (previously ``None`` crashed on ``id+1``).

    Raises:
        RuntimeError: if called outside a DataLoader worker process, where
            ``get_worker_info()`` returns ``None``.
    """
    info = torch.utils.data.get_worker_info()
    if info is None:
        raise RuntimeError("worker_seed_reporter must be called from inside a DataLoader worker process.")
    if id is None:
        id = info.id
    workerseed = info.seed
    baseseed = torch.initial_seed()
    print(f"Worker id: {id+1}/{info.num_workers}, worker seed: {workerseed}, baseseed: {baseseed}, workerseed % (2**31): {workerseed % (2**31)}")
|
|
|
def set_seeds_and_report(report=True, id=0): |
|
""" |
|
Debug: set seeds and report |
|
""" |
|
workerseed = torch.utils.data.get_worker_info().seed |
|
set_seeds_for_worker(workerseed, id) |
|
if report: |
|
worker_seed_reporter(id) |
|
|
|
def get_seedftn_and_generator(args, seed=None):
    """
    Get the seed function and generator for the dataloader.

    Args:
        args: Namespace of argument parser (reads ``rank``)
        seed: seed for random number generation

    Returns:
        ``(worker_init_fn, generator, seed)``. With an explicit ``seed``, the worker init
        function is ``partial(set_seeds_and_report, False)`` and the CPU generator is
        seeded with ``seed + args.rank``. Without one, both are ``None`` and a seed is
        drawn from Python's global ``random`` in ``[0, 1000000000]``.
    """
    rank = args.rank
    if seed is None:
        return None, None, random.randint(0, 1000000000)

    init_fn = partial(set_seeds_and_report, False)
    generator = torch.Generator(device='cpu')
    generator.manual_seed(seed + rank)
    return init_fn, generator, seed
|
|
|
def get_train_dataloaders(
    args,
    huggingface_repo_id='',
    huggingface_split='Systematic+Manual',
    additional_data_path='',
    additional_data_label_format='real:0,fake:1',
    batch_size=128,
    num_workers=4,
    val_frac=0.01,
    logger: logging.Logger = None,
    seed=None,
):
    """
    Get train and validation dataloaders for the dataset.

    Args:
        args: Namespace of argument parser (reads ``rank``, ``world_size``, ``cache_dir``
            and, when both data sources are given, ``dont_limit_real_data_to_fake``)
        huggingface_repo_id: huggingface repo id for the dataset ('' to skip)
        huggingface_split: split of the dataset to use
        additional_data_path: path to the folder containing the dataset ('' to skip)
        additional_data_label_format: label mapping for the folder dataset, e.g. "real:0,fake:1"
        batch_size: size of batch
        num_workers: number of subprocesses to spawn
        val_frac: fraction of data to use for validation (default: 0.01); 0 disables validation
        logger: logger to use; defaults to ``ut.logger``
        seed: seed for random number generation

    Returns:
        ``(trainloader, valloader)``; ``valloader`` is ``None`` when there is no validation split.

    Raises:
        ValueError: if neither a huggingface repo id nor a folder path is provided.
    """
    rank = args.rank
    world_size = args.world_size

    seedftn, seed_generator, seed = get_seedftn_and_generator(args, seed)
    if logger is None:
        logger = ut.logger

    hf_dataset = None
    if huggingface_repo_id != '':
        hf_dataset = dataset_huggingface(args, huggingface_repo_id, split=huggingface_split, mode='train', cache_dir=args.cache_dir, dtype=torch.float32)

    folder_dataset = None
    if additional_data_path != '':
        folder_dataset = dataset_folder_based(args, additional_data_path, additional_data_label_format, logger=logger, dtype=torch.float32)
        num_fake_images = int((folder_dataset.df['Label'] == 1).sum())
        if hf_dataset is not None and not args.dont_limit_real_data_to_fake:
            # Count HF fakes from the label column alone. The previous
            # `Dataset.filter(..., num_proc=num_workers)` decoded every row (image bytes
            # included), wrote a cache file, and may reject num_proc=0 (num_workers=0).
            # int() mirrors how dataset_huggingface.__getitem__ interprets labels.
            num_hf_fake_images = sum(1 for lab in hf_dataset.dataset['label'] if int(lab) == 1)
            num_hf_real_images = len(hf_dataset.dataset) - num_hf_fake_images
            num_fake_images = num_fake_images + num_hf_fake_images

            # Real-image budget left for the folder dataset once HF reals are counted.
            folder_based_real_limit = num_fake_images - num_hf_real_images
            if folder_based_real_limit < 0:
                # NOTE(review): preserved original behavior -- when HF reals alone exceed
                # all fakes, folder-based reals are left untouched (the old code clamped
                # the limit to 0 but never applied it).
                if rank == 0:
                    logger.info(f"Real images in the huggingface dataset already outnumber all fake images (budget {folder_based_real_limit}); folder-based real data is not limited.")
            else:
                if rank == 0:
                    logger.info(f"Limiting folder-based real data to {folder_based_real_limit} images to match the number of fake images.")
                folder_dataset.df = folder_dataset.limit_real_data(folder_dataset.df, folder_based_real_limit)

    if hf_dataset is not None and folder_dataset is not None:
        dataset_object = torch.utils.data.ConcatDataset([hf_dataset, folder_dataset])
    elif hf_dataset is not None:
        dataset_object = hf_dataset
    elif folder_dataset is not None:
        dataset_object = folder_dataset
    else:
        raise ValueError("No dataset provided. Please provide a huggingface repo id or a folder path.")

    # Same seed on every rank so random_split yields an identical split everywhere.
    set_seeds_for_data(seed)
    train_frac = 1 - val_frac
    if val_frac > 0:
        traindata_split, valdata_split = torch.utils.data.random_split(dataset_object, (train_frac, val_frac))
    else:
        traindata_split = dataset_object
        valdata_split = []

    # Per-rank seed from here on (augmentation randomness differs across ranks).
    set_seeds_for_data(seed + rank)

    traindata_split = SubsetWithTransform(traindata_split, transform=get_transform(args, mode='train', dtype=torch.float32))
    train_sampler = DistributedSampler(
        traindata_split, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False,
    )
    trainloader = torch.utils.data.DataLoader(
        traindata_split, batch_size=batch_size, pin_memory=True,
        shuffle=False, num_workers=num_workers, sampler=train_sampler,
        worker_init_fn=seedftn, generator=seed_generator,
    )

    if len(valdata_split) > 0:
        valdata_split = SubsetWithTransform(valdata_split, transform=get_transform(args, mode='val', dtype=torch.float32))
        val_sampler = DistributedEvalSampler(
            valdata_split, num_replicas=world_size, rank=rank, shuffle=False,
        )
        valloader = torch.utils.data.DataLoader(
            valdata_split, batch_size=batch_size, pin_memory=True,
            shuffle=False, num_workers=num_workers, sampler=val_sampler,
            worker_init_fn=seedftn, generator=seed_generator,
        )
    else:
        valloader = None

    if rank == 0:
        if huggingface_repo_id != '':
            logger.info(f"Loaded huggingface dataset from {huggingface_repo_id}. Split: {huggingface_split}.")
        if additional_data_path != '':
            logger.info(f"Loaded folder dataset from {additional_data_path}.")
        logger.info(f"Train/Val split: num_total: {len(dataset_object)}, num_train: {len(traindata_split)}, num_val: {len(valdata_split)} ")

    return trainloader, valloader
|
|
|
def get_test_dataloader(
    args,
    huggingface_repo_id='',
    huggingface_split='PublicEval',
    additional_data_path='',
    additional_data_label_format='real:0,fake:1',
    batch_size=128,
    num_workers=4,
    logger: logging.Logger = None,
    seed=None,
):
    """
    Build the evaluation DataLoader from a Hugging Face split and/or an image folder.

    Args:
        args: Namespace of argument parser (reads ``rank``, ``world_size``, ``cache_dir``)
        huggingface_repo_id: huggingface repo id for the dataset ('' to skip)
        huggingface_split: split of the dataset to use
        additional_data_path: path to the folder containing the dataset ('' to skip)
        additional_data_label_format: label mapping for the folder dataset, e.g. "real:0,fake:1"
        batch_size: size of batch
        num_workers: number of subprocesses to spawn
        logger: logger to use; defaults to ``ut.logger``
        seed: seed for random number generation

    Returns:
        A DataLoader using ``get_transform(mode='val')`` preprocessing and a
        ``DistributedEvalSampler`` constructed with ``shuffle=True``.

    Raises:
        ValueError: if neither a huggingface repo id nor a folder path is provided.
    """
    rank, world_size = args.rank, args.world_size
    logger = ut.logger if logger is None else logger
    seedftn, seed_generator, seed = get_seedftn_and_generator(args, seed)

    use_hf = huggingface_repo_id != ''
    use_folder = additional_data_path != ''

    # HF data first, then folder data -- the same order ConcatDataset receives them in.
    sources = []
    if use_hf:
        sources.append(dataset_huggingface(args, huggingface_repo_id, split=huggingface_split, mode='eval', cache_dir=args.cache_dir, dtype=torch.float32))
    if use_folder:
        sources.append(dataset_folder_based(args, additional_data_path, additional_data_label_format, logger=logger, dtype=torch.float32))
    if not sources:
        raise ValueError("No dataset provided. Please provide a huggingface repo id or a folder path.")
    combined = sources[0] if len(sources) == 1 else torch.utils.data.ConcatDataset(sources)

    set_seeds_for_data(seed + rank)

    eval_set = SubsetWithTransform(combined, transform=get_transform(args, mode='val', dtype=torch.float32))
    eval_sampler = DistributedEvalSampler(eval_set, num_replicas=world_size, rank=rank, shuffle=True)
    testloader = torch.utils.data.DataLoader(
        eval_set,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
        num_workers=num_workers,
        sampler=eval_sampler,
        worker_init_fn=seedftn,
        generator=seed_generator,
    )

    if rank == 0:
        if use_hf:
            logger.info(f"Loaded huggingface dataset from {huggingface_repo_id}. Split: {huggingface_split}.")
        if use_folder:
            logger.info(f"Loaded folder dataset from {additional_data_path}.")
        logger.info(f"Test set size: {len(eval_set)} ")
    return testloader