File size: 1,798 Bytes
6cd7429
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
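"""
Configuration constants for federated autoencoder training.

Importing this module imports torch, queries CUDA availability and, when a
GPU is available, sets global cuDNN flags.
"""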

import os

# Dataset Configuration
# Data location plus the dataset and sub-version names used with it.
DATA_ROOT = "data"
DATASETS = [
    "Michel Daudon (w256 1k v1)",
    "Jonathan El-Beze (w256 1k v1)",
]
SUBVERSIONS = ["MIX", "SEC", "SUR"]

# Image geometry: square 256x256 images with 3 channels.
IMAGE_SIZE = (256,) * 2
CHANNELS = 3

# Federated Learning Configuration
# Default size of the federation.
NUM_CLIENTS = 10
# How many federated rounds to run.
NUM_ROUNDS = 20
# Clients taking part in each round; presumably must not exceed NUM_CLIENTS.
CLIENTS_PER_ROUND = 8

# Model Configuration
# Width of the autoencoder's latent space.
LATENT_DIM = 128
# Local training hyper-parameters.
LEARNING_RATE = 1e-3
BATCH_SIZE = 32
# Epochs each client trains locally per round.
LOCAL_EPOCHS = 3

# Data Corruption Configuration
# Default probability of corrupting data (presumably per sample; confirm
# against the code that consumes it).
CORRUPTION_PROBABILITY = 0.2
# Names of the corruption kinds.
CORRUPTION_TYPES = [
    "gaussian_noise",
    "salt_pepper",
    "blur",
    "brightness",
    "contrast",
]

# Training Configuration
import torch

# Query CUDA once and derive DEVICE from it, so the two can never disagree.
CUDA_AVAILABLE = torch.cuda.is_available()
DEVICE = "cuda" if CUDA_AVAILABLE else "cpu"
# NOTE: this module does not call any seeding function; callers must apply SEED.
SEED = 42

# CUDA Configuration
# Always defined, with CPU fallbacks, so that modules reading these names do
# not raise AttributeError on machines without a GPU.
CUDA_DEVICE_COUNT = 0
CUDA_DEVICE_NAME = None  # name of device 0 when CUDA is available
CUDA_MEMORY_GB = 0.0  # total memory of device 0, in GB (1e9 bytes)
if CUDA_AVAILABLE:
    CUDA_DEVICE_COUNT = torch.cuda.device_count()
    CUDA_DEVICE_NAME = torch.cuda.get_device_name(0)
    CUDA_MEMORY_GB = torch.cuda.get_device_properties(0).total_memory / 1e9

    # Enable cuDNN autotuning, which pays off for consistent input sizes.
    # With deterministic=False, GPU runs are not bit-reproducible even though
    # SEED is fixed: this trades reproducibility for speed.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

    # Release cached, currently unused allocator memory. This does not change
    # the allocation strategy, and it likely frees nothing if no tensors have
    # been allocated yet.
    torch.cuda.empty_cache()

# Logging Configuration
# Output locations as relative paths; nothing in this module creates them.
LOG_DIR = "logs"  # log output
MODEL_SAVE_DIR = "models"  # saved models
RESULTS_DIR = "results"  # experiment results

# Non-IID Configuration
# Concentration parameter of the Dirichlet distribution used for the non-IID
# split across clients. In the usual Dirichlet-partitioning scheme, smaller
# values give more skewed splits.
ALPHA = 0.3
# Lower bound on the number of samples each client receives.
MIN_SAMPLES_PER_CLIENT = 50