Spaces:
Running
Running
# Exit on any error | |
set -e | |
# Function to log messages | |
log_message() { | |
local timestamp=$(date '+%Y-%m-%d %H:%M:%S') | |
echo "[${timestamp}] $1" | |
} | |
# Function to check if a command exists | |
command_exists() { | |
command -v "$1" >/dev/null 2>&1 | |
} | |
# Function to check if a directory exists | |
check_directory() { | |
if [ ! -d "$1" ]; then | |
log_message "ERROR: Directory $1 does not exist" | |
exit 1 | |
fi | |
} | |
# Function to check Python and required tools | |
check_requirements() { | |
log_message "Checking requirements..." | |
# Check for Python | |
if ! command_exists python3; then | |
log_message "ERROR: Python3 is not installed" | |
exit 1 | |
fi | |
# Check for accelerate | |
if ! command_exists accelerate; then | |
log_message "ERROR: Accelerate is not installed. Please install it using 'pip install accelerate'" | |
exit 1 | |
fi | |
} | |
# Main script execution | |
main() { | |
log_message "Starting training pipeline..." | |
# Set variables | |
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" | |
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")" | |
DATASET_DIR="$PROJECT_ROOT/data/processed_ds" | |
SCRIPT_DIR="$PROJECT_ROOT/scripts" | |
SRC_DIR="$PROJECT_ROOT/src/slimface/training" | |
# Check if required directories exist | |
check_directory "$SCRIPT_DIR" | |
check_directory "$SRC_DIR" | |
# Check requirements | |
check_requirements | |
# Process dataset | |
log_message "Processing dataset..." | |
python3 "${SCRIPT_DIR}/process_dataset.py" \ | |
--random_state 42 \ | |
--test_split_rate 0.2 \ | |
--augment || { | |
log_message "ERROR: Dataset processing failed" | |
exit 1 | |
} | |
check_directory "$DATASET_DIR" | |
# Configure accelerate | |
log_message "Configuring accelerate..." | |
accelerate config default || { | |
log_message "ERROR: Accelerate configuration failed" | |
exit 1 | |
} | |
# Launch training | |
log_message "Starting model training..." | |
accelerate launch "${SRC_DIR}/accelerate_train.py" \ | |
--batch_size 32 \ | |
--algorithm yolo \ | |
--learning_rate 1e-4 \ | |
--max_lr_factor 4 \ | |
--warmup_steps 0.05 \ | |
--num_epochs 100 \ | |
--dataset_dir "$DATASET_DIR" \ | |
--classification_model_name efficientnet_v2_s || { | |
log_message "ERROR: Training failed" | |
exit 1 | |
} | |
log_message "Training pipeline completed successfully" | |
} | |
# Trap Ctrl+C and exit gracefully | |
trap 'log_message "Script interrupted by user"; exit 1' INT | |
# Execute main function | |
main "$@" |