# deepfake/data/dfdt_dataset.py
# Author: Dharshaneshwaran
# Commit ddcedb5: "Full updated code with finding ai generated images too"
'''Load the FakeAVCeleb dataset from TFRecord files into batched tf.data pipelines.'''
import numpy as np
import tensorflow as tf
from data.augmentation_utils import create_frame_transforms, create_spec_transforms
# Parsing spec for one serialized tf.train.Example: every feature is a
# scalar (shape []) of the listed dtype.
FEATURE_DESCRIPTION = {
    # Path string; parsed but not used by _parse_function.
    'video_path': tf.io.FixedLenFeature([], tf.string),
    # Raw frame bytes, decoded with tf.io.decode_raw.
    'image/encoded': tf.io.FixedLenFeature([], tf.string),
    # Integer class index (returned as label_map).
    'clip/label/index': tf.io.FixedLenFeature([], tf.int64),
    # Human-readable class name; parsed but not used.
    'clip/label/text': tf.io.FixedLenFeature([], tf.string),
    # Raw float32 bytes, decoded as the "spectrogram" tensor.
    'WAVEFORM/feature/floats': tf.io.FixedLenFeature([], tf.string),
}
@tf.function
def _parse_function(example_proto):
    '''Parse one serialized tf.train.Example into (video, spectrogram, label_map).

    Args:
        example_proto: scalar string tensor holding a serialized
            tf.train.Example laid out as in FEATURE_DESCRIPTION.

    Returns:
        video: flat uint8 tensor of the bytes stored under 'image/encoded'.
        spectrogram: flat float32 tensor decoded from 'WAVEFORM/feature/floats'.
        label_map: int64 scalar taken from 'clip/label/index'.
    '''
    example = tf.io.parse_single_example(example_proto, FEATURE_DESCRIPTION)
    # Decode as unsigned bytes: with tf.int8 any pixel value above 127 wraps to
    # a negative number, so the later "/ 255" no longer lands in [0, 1].
    # (Assumes the writer stored 8-bit pixel values -- confirm against the
    # TFRecord writer if the dataset was produced differently.)
    video = tf.io.decode_raw(example['image/encoded'], tf.uint8)
    spectrogram = tf.io.decode_raw(example['WAVEFORM/feature/floats'], tf.float32)
    label_map = example['clip/label/index']
    return video, spectrogram, label_map
@tf.function
def decode_inputs(video, spectrogram, label_map):
    '''Decode tensors to arrays with desired shape (no augmentation).

    Args:
        video: flat frame tensor from _parse_function; reshaped here to
            [10, 3, 256, 256].
        spectrogram: flat float32 tensor, passed through unchanged.
        label_map: scalar class index.

    Returns:
        dict with 'video_reshaped' (first frame divided by 255),
        'spectrogram', and 'label_map' expanded to shape [1].
    '''
    # NOTE(review): decode_train_inputs reads the same bytes as
    # [10, 256, 256, 3] and transposes to channels-first after augmentation,
    # whereas this reshapes straight to [10, 3, 256, 256]. A reshape is not a
    # transpose, so at most one of the two matches the stored layout --
    # confirm against the TFRecord writer.
    first_frame = tf.reshape(video, [10, 3, 256, 256])[0]
    return {
        'video_reshaped': first_frame / 255,
        'spectrogram': spectrogram,
        'label_map': tf.expand_dims(label_map, axis=0),
    }
def decode_train_inputs(video, spectrogram, label_map):
    '''Decode one training example and apply spectrogram and frame augmentation.

    Args:
        video: flat frame tensor from _parse_function; reshaped here to
            [10, 256, 256, 3] and only the first frame is kept.
        spectrogram: flat float32 tensor, augmented via aug_spec_fn.
        label_map: scalar class index.

    Returns:
        dict with:
            'video_reshaped': float32 [3, 256, 256] augmented frame scaled to [0, 1],
            'spectrogram': augmented float32 spectrogram (input's static shape),
            'label_map': label expanded to shape [1].
    '''
    # py_function loses static shape information; restore it after the call.
    spec_augmented = tf.py_function(aug_spec_fn, [spectrogram], tf.float32)
    spec_augmented.set_shape(spectrogram.shape)

    first_frame = tf.reshape(video, [10, 256, 256, 3])[0]
    # Augment the raw pixel values. aug_img_fn casts its input to uint8, so
    # normalizing beforehand would truncate every [0, 1] value to 0 or 1.
    # (For int8 input, astype(uint8) reinterprets the bytes back to 0..255.)
    frame_augmented = tf.py_function(aug_img_fn, [first_frame], tf.uint8)
    frame_augmented.set_shape([3, 256, 256])
    # Normalize after augmentation so the output is float in [0, 1], matching
    # the "/ 255" scaling done on the validation path.
    frame_normalized = tf.cast(frame_augmented, tf.float32) / 255.0

    label_map = tf.expand_dims(label_map, axis=0)
    return {
        'video_reshaped': frame_normalized,
        'spectrogram': spec_augmented,
        'label_map': label_map,
    }
def aug_img_fn(frame):
    '''Augment one frame eagerly (called through tf.py_function).

    Args:
        frame: eager tensor; its values are cast to uint8 before augmentation.

    Returns:
        The augmented image as a NumPy array with its last axis moved to the
        front (axes reordered by transpose(2, 0, 1)).
    '''
    pixels = frame.numpy().astype(np.uint8)
    augmented = create_frame_transforms(image=pixels)['image']
    return augmented.transpose(2, 0, 1)
def aug_spec_fn(spec):
    '''Augment one spectrogram eagerly (called through tf.py_function).

    Args:
        spec: eager tensor, converted to a NumPy array for the augmenter.

    Returns:
        The 'spec' entry of the dict returned by create_spec_transforms.
    '''
    return create_spec_transforms(spec=spec.numpy())['spec']
def _load_tfrecord_dataset(args, decode_fn):
    '''Build the batched tf.data pipeline shared by the train and val datasets.

    Args:
        args: object exposing ``data_dir`` (a file pattern passed to
            tf.io.matching_files) and ``batch_size``.
        decode_fn: per-example map function applied after _parse_function.

    Returns:
        tf.data.Dataset of padded batches of the dicts produced by decode_fn.
    '''
    # Shard order is shuffled once per pipeline build; records are then
    # interleaved across shards and shuffled with a small buffer.
    files = tf.random.shuffle(tf.io.matching_files(args.data_dir))
    dataset = tf.data.Dataset.from_tensor_slices(files).interleave(tf.data.TFRecordDataset)
    dataset = dataset.shuffle(buffer_size=100)
    dataset = dataset.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.map(decode_fn, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.padded_batch(batch_size=args.batch_size)


def _count_elements(dataset):
    '''Return the number of elements in ``dataset`` by iterating it once.'''
    count = dataset.reduce(np.int64(0), lambda total, _: total + 1)
    return int(count.numpy())


class FakeAVCelebDatasetTrain:
    '''Training split: parsed examples with spectrogram and frame augmentation.'''

    def __init__(self, args):
        self.args = args
        self.samples = self.load_features_from_tfrec()

    def load_features_from_tfrec(self):
        '''Loads raw features from a tfrecord file and returns them as batched, augmented inputs'''
        return _load_tfrecord_dataset(self.args, decode_train_inputs)

    def __len__(self):
        '''Number of batches (not examples) in the dataset.

        Iterates the whole pipeline, including augmentation, so it is expensive.
        '''
        # Previously called load_features_from_tfrec(self.args.data_dir), which
        # raised TypeError because the method takes no argument.
        return _count_elements(self.samples)


class FakeAVCelebDatasetVal:
    '''Validation split: parsed examples without augmentation.'''

    def __init__(self, args):
        self.args = args
        self.samples = self.load_features_from_tfrec()

    def load_features_from_tfrec(self):
        '''Loads raw features from a tfrecord file and returns them as batched inputs'''
        return _load_tfrecord_dataset(self.args, decode_inputs)

    def __len__(self):
        '''Number of batches (not examples) in the dataset; iterates it once.'''
        return _count_elements(self.samples)