File size: 2,633 Bytes
b7f710c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#!/bin/bash

# Abort the whole pipeline as soon as any unchecked command fails
# (long-option spelling of `set -e`).
set -o errexit

# Function to log messages
log_message() {
    local timestamp=$(date '+%Y-%m-%d %H:%M:%S')
    echo "[${timestamp}] $1"
}

# Function to check if a command exists
command_exists() {
    command -v "$1" >/dev/null 2>&1
}

# Function to check if a directory exists
check_directory() {
    if [ ! -d "$1" ]; then
        log_message "ERROR: Directory $1 does not exist"
        exit 1
    fi
}

# Function to check Python and required tools
check_requirements() {
    log_message "Checking requirements..."

    # Check for Python
    if ! command_exists python3; then
        log_message "ERROR: Python3 is not installed"
        exit 1
    fi

    # Check for accelerate
    if ! command_exists accelerate; then
        log_message "ERROR: Accelerate is not installed. Please install it using 'pip install accelerate'"
        exit 1
    fi
}

# Main script execution
main() {
    log_message "Starting training pipeline..."

    # Set variables
    SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
    PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"

    DATASET_DIR="$PROJECT_ROOT/data/processed_ds"
    SCRIPT_DIR="$PROJECT_ROOT/scripts"
    SRC_DIR="$PROJECT_ROOT/src/slimface/training"

    # Check if required directories exist
    check_directory "$SCRIPT_DIR"
    check_directory "$SRC_DIR"

    # Check requirements
    check_requirements

    # Process dataset
    log_message "Processing dataset..."
    python3 "${SCRIPT_DIR}/process_dataset.py" \
        --random_state 42 \
        --test_split_rate 0.2 \
        --augment || {
        log_message "ERROR: Dataset processing failed"
        exit 1
    }
    check_directory "$DATASET_DIR"
    # Configure accelerate
    log_message "Configuring accelerate..."
    accelerate config default || {
        log_message "ERROR: Accelerate configuration failed"
        exit 1
    }

    # Launch training
    log_message "Starting model training..."
    accelerate launch "${SRC_DIR}/accelerate_train.py" \
        --batch_size 32 \
        --algorithm yolo \
        --learning_rate 1e-4 \
        --max_lr_factor 4 \
        --warmup_steps 0.05 \
        --num_epochs 100 \
        --dataset_dir "$DATASET_DIR" \
        --classification_model_name efficientnet_v2_s || {
        log_message "ERROR: Training failed"
        exit 1
    }

    log_message "Training pipeline completed successfully"
}

# On Ctrl+C (SIGINT), report the interruption on stderr and exit with 130
# (128 + SIGINT's signal number 2), the conventional status for a process
# stopped by SIGINT.
trap 'log_message "Script interrupted by user" >&2; exit 130' INT

# Execute main function, forwarding the script's arguments.
main "$@"