omnisealbench / backend /config.py
antoinetran's picture
Fix the relative vs. absolute path issues
cc391f1 verified
raw
history blame
8.32 kB
# Change these values to match your dataset structure if loading locally or from a different source.
# IMPORTANT: When running from Docker (e.g. on Hugging Face), additional setup is required.
import os
from collections import defaultdict
from pathlib import Path
# Default remote domain from which the predefined datasets are loaded.
ABS_DATASET_DOMAIN = "https://dl.fbaipublicfiles.com"
# Sample domain for local loading; loading examples from local files may
# require some additional configuration.
# ABS_DATASET_DOMAIN = "./data"
# Root location shared by every predefined dataset (note the trailing slash).
ABS_DATASET_PATH = ABS_DATASET_DOMAIN + "/omnisealbench/"
# Per-modality constants, keyed by modality name ("audio", "image", "video").
# Every modality entry carries the same four keys:
#   - "first_cols": metric column names (presumably the quality metrics
#     shown first in the leaderboard -- confirm with the consumer),
#   - "attack_scores": score names reported per attack,
#   - "categories": mapping of attack name -> category label,
#   - "attacks_with_variations": attack names (presumably those evaluated
#     with several parameter settings -- confirm with the consumer).
# Insertion order is kept exactly as authored because consumers may rely on it.
MODALITY_CONFIG_CONSTANTS = {
    "audio": {
        "first_cols": ["snr", "sisnr", "stoi", "pesq"],
        "attack_scores": ["bit_acc", "log10_p_value", "TPR", "FPR"],
        "categories": {
            # Time
            "speed": "Time",
            "updownresample": "Time",
            "echo": "Time",
            # Amplitude
            "random_noise": "Amplitude",
            "lowpass_filter": "Amplitude",
            "highpass_filter": "Amplitude",
            "bandpass_filter": "Amplitude",
            "smooth": "Amplitude",
            "boost_audio": "Amplitude",
            "duck_audio": "Amplitude",
            "shush": "Amplitude",
            "pink_noise": "Amplitude",
            # Compression
            "aac_compression": "Compression",
            "mp3_compression": "Compression",
        },
        "attacks_with_variations": [
            "random_noise",
            "lowpass_filter",
            "highpass_filter",
            "boost_audio",
            "duck_audio",
            "shush",
        ],
    },
    "image": {
        "first_cols": ["psnr", "ssim", "lpips", "decoder_time"],
        "attack_scores": ["bit_acc", "log10_p_value", "TPR", "FPR"],
        "categories": {
            "proportion": "Geometric",
            "collage": "Inpainting",
            "center_crop": "Geometric",
            "rotate": "Geometric",
            "jpeg": "Compression",
            "brightness": "Visual",
            "contrast": "Visual",
            "saturation": "Visual",
            "sharpness": "Visual",
            "resize": "Geometric",
            "overlay_text": "Inpainting",
            "hflip": "Geometric",
            "perspective": "Geometric",
            "median_filter": "Visual",
            "hue": "Visual",
            "gaussian_blur": "Visual",
            # Entries that are not single attacks
            "comb": "Mixed",
            "avg": "Averages",
            "none": "Baseline",
        },
        "attacks_with_variations": [
            "center_crop",
            "jpeg",
            "brightness",
            "contrast",
            "saturation",
            "sharpness",
            "resize",
            "perspective",
            "median_filter",
            "hue",
            "gaussian_blur",
        ],
    },
    "video": {
        "first_cols": ["psnr", "ssim", "msssim", "lpips", "vmaf", "decoder_time"],
        "attack_scores": ["bit_acc", "log10_p_value", "TPR", "FPR"],
        "categories": {
            # Geometric
            "HorizontalFlip": "Geometric",
            "Rotate": "Geometric",
            "Resize": "Geometric",
            "Crop": "Geometric",
            "Perspective": "Geometric",
            # Visual / compression
            "Brightness": "Visual",
            "Contrast": "Visual",
            "Saturation": "Visual",
            "Grayscale": "Visual",
            "Hue": "Visual",
            "JPEG": "Compression",
            "GaussianBlur": "Visual",
            "MedianFilter": "Visual",
            "H264": "Compression",
            "H264rgb": "Compression",
            "H265": "Compression",
            "VP9": "Compression",
            # Combined attacks
            "H264_Crop_Brightness0": "Mixed",
            "H264_Crop_Brightness1": "Mixed",
            "H264_Crop_Brightness2": "Mixed",
            "H264_Crop_Brightness3": "Mixed",
        },
        "attacks_with_variations": [
            "Rotate",
            "Resize",
            "Crop",
            "Brightness",
            "Contrast",
            "Saturation",
            "H264",
            "H264rgb",
            "H265",
        ],
    },
}
# Predefined datasets, keyed by "<dataset>/<modality>". Each entry records the
# modality under "type" and the location to load from under "path"; all of
# them currently share ABS_DATASET_PATH.
DATASET_CONFIGS = {
    dataset_key: {"type": modality, "path": ABS_DATASET_PATH}
    for dataset_key, modality in (
        ("voxpopuli_1k/audio", "audio"),
        ("ravdess_1k/audio", "audio"),
        ("val2014_1k_v2/image", "image"),
        ("sa_1b_val_1k/image", "image"),
        ("sav_val_full_v2/video", "video"),
    )
}
def get_user_dataset():
    """Discover user-provided datasets on local disk, grouped by modality.

    The data root is read from the ``OMNISEAL_LEADERBOARD_DATA`` environment
    variable and falls back to a ``data`` directory located next to this
    file. Every sub-directory of the root that contains a non-empty
    ``examples`` folder is treated as one dataset; its modality is the name
    of the first non-hidden entry (in sorted order) of that folder.

    Returns:
        A ``defaultdict(list)`` mapping a modality name to the list of
        ``"<dataset>/<modality>"`` identifiers found for it. The mapping is
        empty when the data root is unset, does not exist, or holds no usable
        dataset folder.
    """
    datasets = defaultdict(list)
    default_local_data_dir = str(Path(__file__).parent.joinpath("data"))
    user_data_dir = os.getenv("OMNISEAL_LEADERBOARD_DATA", default_local_data_dir)

    # Having no user data at all is a normal situation, not an error, so a
    # missing directory must not raise from os.listdir().
    if not user_data_dir or not os.path.isdir(user_data_dir):
        return datasets

    # os.listdir() returns entries in arbitrary order; sort so that the
    # resulting dataset lists are deterministic across runs and platforms.
    for user_data in sorted(os.listdir(user_data_dir)):
        examples_dir = os.path.join(user_data_dir, user_data, "examples")
        # Also skips plain files in the root and folders without examples.
        if not os.path.isdir(examples_dir):
            continue

        # Ignore hidden entries (e.g. ".DS_Store") so that filesystem
        # metadata is never mistaken for a modality name.
        entries = sorted(
            entry for entry in os.listdir(examples_dir) if not entry.startswith(".")
        )
        if not entries:
            continue

        user_dtype = entries[0]
        datasets[user_dtype].append(user_data + "/" + user_dtype)
    return datasets
def get_datasets():
    """List every known dataset identifier, grouped by modality.

    Predefined datasets from ``DATASET_CONFIGS`` come first, followed by the
    user datasets reported by ``get_user_dataset()`` that are not already
    listed for the same modality. Only the "audio", "image" and "video"
    modalities are reported; anything else is ignored.

    Returns:
        A dict with exactly the keys "audio", "image" and "video", each
        mapping to a list of dataset identifiers.
    """
    by_modality = {modality: [] for modality in ("audio", "image", "video")}

    # Predefined datasets.
    for dataset_name, dataset_cfg in DATASET_CONFIGS.items():
        bucket = by_modality.get(dataset_cfg.get("type"))
        if bucket is not None:
            bucket.append(dataset_name)

    # User datasets, skipping names already registered for the modality.
    for modality, candidates in get_user_dataset().items():
        if modality not in by_modality:
            continue
        known = by_modality[modality]
        known.extend([candidate for candidate in candidates if candidate not in known])

    return by_modality
def get_example_config(type, dataset_name):
    """Build the example-loading configuration for one dataset.

    Args:
        type: Modality to look the dataset up under (a key of the mapping
            returned by ``get_datasets()``).
        dataset_name: Dataset identifier in ``"<dataset>/<modality>"`` form.

    Returns:
        A dict with the keys ``"dataset_name"`` (the identifier up to its
        first ``/``), ``"path"`` and ``"db_key"``. For user datasets the path
        is the user data directory plus a trailing slash and the db key is
        the base name; for predefined datasets the path comes from
        ``DATASET_CONFIGS`` and the db key from ``_get_db_key_for_dataset``.

    Raises:
        ValueError: If ``dataset_name`` is empty, is not listed for ``type``,
            or is neither a user dataset nor a predefined one.
    """
    if not dataset_name:
        raise ValueError("Dataset name is required")

    # The dataset must be known for the requested modality.
    if dataset_name not in get_datasets().get(type, []):
        raise ValueError(f"Unknown dataset {dataset_name} for type {type}")

    # Strip the modality suffix: "<dataset>/<modality>" -> "<dataset>".
    base_name = dataset_name.split("/")[0]

    user_owned = get_user_dataset().get(type, [])
    # NOTE(review): this fallback is the relative "./data", whereas
    # get_user_dataset() falls back to a "data" directory next to this file.
    # Confirm whether the difference is intentional.
    data_root = os.getenv("OMNISEAL_LEADERBOARD_DATA", "./data")

    # User datasets take precedence over predefined ones.
    if dataset_name in user_owned:
        return {
            "dataset_name": base_name,
            "path": data_root + "/",
            "db_key": base_name,
        }

    if dataset_name not in DATASET_CONFIGS:
        raise ValueError(f"Dataset {dataset_name} not found in configurations")

    return {
        "dataset_name": base_name,
        "path": DATASET_CONFIGS[dataset_name]["path"],
        "db_key": _get_db_key_for_dataset(base_name),
    }
def _get_db_key_for_dataset(dataset_base_name):
    """Translate a dataset base name into its database key.

    Predefined datasets have dedicated keys; any other name is used as its
    own key and returned unchanged.
    """
    # Dataset base name -> database key, for the predefined datasets.
    overrides = {
        "voxpopuli_1k": "voxpopuli",
        "val2014_1k_v2": "local_val2014",
        "sa_1b_val_1k": "local_valid",
        "sav_val_full_v2": "sa-v_sav_val_videos",
        "ravdess_1k": "ravdess",
    }
    if dataset_base_name in overrides:
        return overrides[dataset_base_name]
    return dataset_base_name
def get_dataset_config(dataset_name):
    """Return the complete configuration for a dataset.

    The result combines the dataset's own settings (``"type"`` and
    ``"path"``) with the constants registered for its modality in
    ``MODALITY_CONFIG_CONSTANTS``.

    Args:
        dataset_name: Dataset identifier in ``"<dataset>/<modality>"`` form.
            Predefined datasets are looked up in ``DATASET_CONFIGS``; any
            other name must be a user dataset reported by
            ``get_user_dataset()`` for the modality given by its last
            ``/``-separated component.

    Returns:
        A new dict on every call. Its nested values (lists and dicts taken
        from ``MODALITY_CONFIG_CONSTANTS``) are still shared with the
        module-level constants and should be treated as read-only.

    Raises:
        ValueError: If the dataset is neither predefined nor a user dataset.
    """
    if dataset_name in DATASET_CONFIGS:
        base_cfg = DATASET_CONFIGS[dataset_name]
        extra_cfg = MODALITY_CONFIG_CONSTANTS.get(base_cfg["type"], {})
        # Merge into a fresh dict: calling update() on the DATASET_CONFIGS
        # entry would permanently alter the shared module-level config and
        # hand callers a reference to it.
        return {**base_cfg, **extra_cfg}

    modality = dataset_name.split("/")[-1]
    if dataset_name not in get_user_dataset().get(modality, []):
        raise ValueError(f"Unknown dataset: {dataset_name}")

    default_local_data_dir = str(Path(__file__).parent.joinpath("data"))
    user_data_dir = os.getenv("OMNISEAL_LEADERBOARD_DATA", default_local_data_dir)
    extra_cfg = MODALITY_CONFIG_CONSTANTS.get(modality, {})
    return {"type": modality, "path": user_data_dir, **extra_cfg}