
# Training Pipeline

```mermaid
graph TB
    %% Input Data Sources
    subgraph "Data Sources"
        JSONL["JSONL Dataset<br/>{'audio_path': '...', 'text': '...'}"]
        GRANARY[NVIDIA Granary Dataset<br/>Multilingual ASR Data]
        HFDATA[HF Hub Datasets<br/>Community Datasets]
    end

    %% Data Processing
    subgraph "Data Processing"
        LOADER["Dataset Loader<br/>_load_jsonl_dataset()"]
        CASTER[Audio Casting<br/>16kHz resampling]
        COLLATOR[VoxtralDataCollator<br/>Audio + Text Processing]
    end

    %% Training Scripts
    subgraph "Training Scripts"
        TRAIN_FULL[Full Fine-tuning<br/>scripts/train.py]
        TRAIN_LORA[LoRA Fine-tuning<br/>scripts/train_lora.py]

        subgraph "Training Components"
            MODEL_INIT[Model Initialization<br/>VoxtralForConditionalGeneration]
            LORA_CONFIG[LoRA Configuration<br/>LoraConfig + get_peft_model]
            PROCESSOR_INIT[Processor Initialization<br/>VoxtralProcessor]
        end
    end

    %% Training Infrastructure
    subgraph "Training Infrastructure"
        TRACKIO_INIT[Trackio Integration<br/>Experiment Tracking]
        HF_TRAINER[Hugging Face Trainer<br/>TrainingArguments + Trainer]
        TORCH_DEVICE[Torch Device Setup<br/>GPU/CPU Detection]
    end

    %% Training Process
    subgraph "Training Process"
        FORWARD_PASS[Forward Pass<br/>Audio Processing + Generation]
        LOSS_CALC[Loss Calculation<br/>Masked Language Modeling]
        BACKWARD_PASS[Backward Pass<br/>Gradient Computation]
        OPTIMIZER_STEP[Optimizer Step<br/>Parameter Updates]
        LOGGING[Metrics Logging<br/>Loss, Perplexity, etc.]
    end

    %% Model Management
    subgraph "Model Management"
        CHECKPOINT_SAVING[Checkpoint Saving<br/>Model snapshots]
        MODEL_SAVING[Final Model Saving<br/>Processor + Model]
        LOCAL_STORAGE[Local Storage<br/>outputs/ directory]
    end

    %% Flow Connections
    JSONL --> LOADER
    GRANARY --> LOADER
    HFDATA --> LOADER

    LOADER --> CASTER
    CASTER --> COLLATOR

    COLLATOR --> TRAIN_FULL
    COLLATOR --> TRAIN_LORA

    TRAIN_FULL --> MODEL_INIT
    TRAIN_LORA --> MODEL_INIT
    TRAIN_LORA --> LORA_CONFIG

    MODEL_INIT --> PROCESSOR_INIT
    LORA_CONFIG --> PROCESSOR_INIT

    PROCESSOR_INIT --> TRACKIO_INIT
    PROCESSOR_INIT --> HF_TRAINER
    PROCESSOR_INIT --> TORCH_DEVICE

    TRACKIO_INIT --> HF_TRAINER
    TORCH_DEVICE --> HF_TRAINER

    HF_TRAINER --> FORWARD_PASS
    FORWARD_PASS --> LOSS_CALC
    LOSS_CALC --> BACKWARD_PASS
    BACKWARD_PASS --> OPTIMIZER_STEP
    OPTIMIZER_STEP --> LOGGING

    LOGGING --> CHECKPOINT_SAVING
    LOGGING --> TRACKIO_INIT

    HF_TRAINER --> MODEL_SAVING
    MODEL_SAVING --> LOCAL_STORAGE

    %% Styling
    classDef input fill:#e3f2fd,stroke:#1976d2,stroke-width:2px
    classDef processing fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px
    classDef training fill:#e8f5e8,stroke:#388e3c,stroke-width:2px
    classDef infrastructure fill:#fff3e0,stroke:#f57c00,stroke-width:2px
    classDef execution fill:#fce4ec,stroke:#c2185b,stroke-width:2px
    classDef output fill:#f5f5f5,stroke:#424242,stroke-width:2px

    class JSONL,GRANARY,HFDATA input
    class LOADER,CASTER,COLLATOR processing
    class TRAIN_FULL,TRAIN_LORA,MODEL_INIT,LORA_CONFIG,PROCESSOR_INIT training
    class TRACKIO_INIT,HF_TRAINER,TORCH_DEVICE infrastructure
    class FORWARD_PASS,LOSS_CALC,BACKWARD_PASS,OPTIMIZER_STEP,LOGGING execution
    class CHECKPOINT_SAVING,MODEL_SAVING,LOCAL_STORAGE output
```

## Training Pipeline Overview

This diagram illustrates the complete training pipeline for Voxtral ASR fine-tuning, showing how data flows through the training scripts and supporting infrastructure.

## Data Input Sources

### JSONL Datasets

- **Local Datasets**: User-created datasets from recordings or uploads
- **Format**: `{"audio_path": "path/to/audio.wav", "text": "transcription"}`
- **Processing**: Loaded via the `_load_jsonl_dataset()` function, sketched below
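A minimal sketch of what such a loader can look like, assuming one JSON object per line with `audio_path` and `text` keys; the actual `_load_jsonl_dataset()` in the training scripts may differ in details:

```python
import json
from datasets import Dataset

def _load_jsonl_dataset(jsonl_path: str) -> Dataset:
    """Illustrative JSONL loader: one {"audio_path", "text"} object per line."""
    records = []
    with open(jsonl_path, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                records.append(json.loads(line))
    # Expose the file path under "audio" so it can be cast to the Audio
    # feature type later (see "Audio Processing" below).
    return Dataset.from_list(
        [{"audio": r["audio_path"], "text": r["text"]} for r in records]
    )
```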

### NVIDIA Granary Dataset

- **Multilingual Support**: 25+ European languages
- **High Quality**: Curated ASR training data
- **Streaming**: Efficient loading without a full download, as shown below
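Streaming from the Hub might look like the following; the repository id and language config here are illustrative placeholders, not verified names:

```python
from itertools import islice
from datasets import load_dataset

# streaming=True yields examples lazily instead of downloading the corpus
granary = load_dataset(
    "nvidia/Granary",  # illustrative repo id; check the actual Hub name
    "de",              # illustrative language config
    split="train",
    streaming=True,
)
for example in islice(granary, 4):
    print(example)
```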

### Hugging Face Hub Datasets

- **Community Datasets**: Public datasets from the HF Hub
- **Standard Formats**: Compatible with Voxtral training requirements

## Data Processing Pipeline

### Dataset Loading

```python
from datasets import load_dataset

# Load a local JSONL dataset or one from the HF Hub
ds = _load_jsonl_dataset(jsonl_path)
# or
ds = load_dataset(ds_name, ds_cfg, split="test")
```

### Audio Processing

```python
from datasets import Audio

# Cast the audio column to the Audio feature type with 16 kHz resampling
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
```

### Data Collation

- **VoxtralDataCollator**: Custom collator for Voxtral training
- **Audio Processing**: Converts audio into model inputs
- **Text Tokenization**: Tokenizes the transcription text
- **Masking**: Masks prompt tokens during training, as in the sketch below
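The sketch below shows the general pattern such a collator follows, assuming the Voxtral processor accepts paired `audio`/`text` inputs the way most multimodal processors do; the real `VoxtralDataCollator` handles more edge cases:

```python
import torch

class SimpleVoxtralCollator:
    """Illustrative collator: batch (audio, text) pairs and mask prompt tokens."""

    def __init__(self, processor, prompt_len: int):
        self.processor = processor
        self.prompt_len = prompt_len  # number of leading prompt/audio tokens

    def __call__(self, features):
        audio = [f["audio"]["array"] for f in features]
        texts = [f["text"] for f in features]
        batch = self.processor(  # assumed signature; check the processor docs
            audio=audio,
            text=texts,
            sampling_rate=16000,
            padding=True,
            return_tensors="pt",
        )
        labels = batch["input_ids"].clone()
        labels[batch["attention_mask"] == 0] = -100  # ignore padding
        labels[:, : self.prompt_len] = -100          # ignore prompt tokens
        batch["labels"] = labels
        return batch
```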

## Training Script Architecture

### Full Fine-tuning (train.py)

- **Complete Model Updates**: All parameters are trainable
- **Higher Memory Requirements**: The full model and optimizer state stay in memory
- **Better Convergence**: Can achieve higher accuracy

### LoRA Fine-tuning (train_lora.py)

- **Parameter Efficient**: Only the LoRA adapters are trained
- **Lower Memory Usage**: The base model stays frozen
- **Faster Training**: Far fewer parameters to update
- **Configurable**: `r`, `alpha`, and `dropout` parameters (see the example below)
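A representative PEFT configuration; the `target_modules` names are an assumption based on common Llama-style attention projections and should be checked against Voxtral's actual module names:

```python
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,             # rank of the low-rank update matrices
    lora_alpha=32,    # scaling factor applied to the update
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # assumed names
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # prints trainable vs. total parameter counts
```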

## Training Infrastructure

### Trackio Integration

```python
import trackio

trackio.init(
    project="voxtral-finetuning",
    config={...},  # training hyperparameters
    space_id=trackio_space,
)
```

### Hugging Face Trainer

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=batch_size,
    learning_rate=learning_rate,
    num_train_epochs=epochs,
    bf16=True,  # BFloat16 for efficiency
    report_to=["trackio"],
    # ... other args
)
```
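Wiring the arguments into a trainer then looks roughly like this; `train_ds`, `eval_ds`, and `collator` are placeholder names for the objects built earlier in the pipeline:

```python
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,  # placeholder: processed training split
    eval_dataset=eval_ds,    # placeholder: processed eval split
    data_collator=collator,  # the VoxtralDataCollator instance
)
trainer.train()
```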

### Device Management

- **GPU Detection**: Automatic CUDA/GPU detection, as in the snippet below
- **Fallback**: CPU training if no GPU is available
- **Memory Optimization**: Model sharding and gradient checkpointing
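A minimal version of the detection logic:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)  # `model` is the Voxtral model loaded earlier
print(f"Training on {device}")
```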

## Training Process Flow

### Forward Pass

1. **Audio Input**: Raw audio waveforms
2. **Audio Tower**: Audio feature extraction
3. **Text Generation**: Autoregressive text generation conditioned on the audio features (see the sketch below)
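Assuming the collator emits the usual multimodal batch (text `input_ids`, audio `input_features`, and `labels`), a single training step's forward pass reduces to one call; `train_dataloader` is a placeholder for a DataLoader built over the collator:

```python
import torch
from transformers import VoxtralForConditionalGeneration

model = VoxtralForConditionalGeneration.from_pretrained(
    "mistralai/Voxtral-Mini-3B-2507", torch_dtype=torch.bfloat16
)
batch = next(iter(train_dataloader))  # placeholder dataloader
outputs = model(**batch)              # audio tower -> language-model decoder
print(outputs.loss)                   # cross-entropy over the unmasked labels
```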

### Loss Calculation

- **Masked Language Modeling**: Only transcription tokens contribute to the loss
- **Audio Prompt Masking**: Audio and prompt tokens are masked out
- **Cross-Entropy Loss**: Standard language-modeling loss (illustrated below)
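A toy illustration of the masking convention: positions labeled `-100` are ignored by the cross-entropy loss, which is how prompt and audio tokens are excluded:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(1, 6, 32)  # (batch, seq_len, vocab_size)
labels = torch.tensor([[-100, -100, -100, 7, 12, 3]])  # prompt masked out

loss = F.cross_entropy(
    logits.view(-1, logits.size(-1)),
    labels.view(-1),
    ignore_index=-100,  # the same convention the HF Trainer relies on
)
print(loss)  # averaged over the three unmasked transcription tokens
```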

### Backward Pass & Optimization

- **Gradient Computation**: Backpropagation through the model
- **LoRA Updates**: Only adapter parameters are updated (LoRA mode)
- **Full Updates**: All parameters are updated (full fine-tuning)

## Model Management

### Checkpoint Saving

- **Regular Checkpoints**: Saved every N steps
- **Best Model Tracking**: The best model is kept based on validation loss
- **Resume Capability**: Training can continue from a checkpoint (see below)
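Checkpoint behavior is controlled through `TrainingArguments`; a sketch with illustrative values (argument names follow recent transformers releases):

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="outputs/voxtral-finetuned",
    save_strategy="steps",
    save_steps=500,               # checkpoint every 500 steps
    save_total_limit=3,           # keep only the newest checkpoints
    eval_strategy="steps",
    eval_steps=500,
    load_best_model_at_end=True,  # restore the best checkpoint by eval loss
    metric_for_best_model="loss",
)

# Resuming later picks up the most recent checkpoint in output_dir:
# trainer.train(resume_from_checkpoint=True)
```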

### Final Model Saving

```python
trainer.save_model()                   # saves the model and tokenizer
processor.save_pretrained(output_dir)  # saves the processor
```

### Local Storage Structure

```
outputs/
├── voxtral-finetuned-{timestamp}/
│   ├── config.json
│   ├── model.safetensors
│   ├── tokenizer.json
│   ├── training_config.json
│   ├── train_results.json
│   └── eval_results.json
```

## Integration Points

### With Interface (interface.py)

- **Parameter Passing**: Training parameters come from the UI
- **Log Streaming**: Real-time training logs are streamed to the UI
- **Progress Monitoring**: Training progress updates

### With Model Publishing (push_to_huggingface.py)

- **Model Upload**: Pushes the trained model to the HF Hub
- **Metadata**: Training config and results
- **Model Cards**: Automatic model card generation

### With Demo Deployment (deploy_demo_space.py)

- **Space Creation**: Creates HF Spaces for demos
- **Model Integration**: Deploys the trained model in a demo
- **Configuration**: Demo-specific settings

## Performance Considerations

### Memory Optimization

- **LoRA**: Significantly reduces memory requirements
- **Gradient Checkpointing**: Trades compute for memory (see the snippet below)
- **Mixed Precision**: BF16/FP16 training
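Gradient checkpointing is a one-line switch on any transformers model; a sketch, with `model` standing in for the loaded Voxtral model:

```python
model.gradient_checkpointing_enable()  # recompute activations during backward
model.config.use_cache = False         # the KV cache conflicts with checkpointing

# Equivalent knobs exist on TrainingArguments:
#   gradient_checkpointing=True, bf16=True (or fp16=True on older GPUs)
```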

### Training Efficiency

- **Batch Size**: Balanced against gradient accumulation (example below)
- **Learning Rate**: Warmup and decay schedules
- **Early Stopping**: Prevents overfitting
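One way these pieces combine, with illustrative values; `EarlyStoppingCallback` requires periodic evaluation and `load_best_model_at_end`, and the dataset/collator names are placeholders from earlier:

```python
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="outputs/voxtral-finetuned",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,  # effective batch size of 16
    learning_rate=5e-5,
    warmup_ratio=0.1,               # warm up over the first 10% of steps
    lr_scheduler_type="cosine",
    eval_strategy="steps",
    eval_steps=500,
    save_steps=500,
    load_best_model_at_end=True,    # required by EarlyStoppingCallback
    metric_for_best_model="loss",
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=collator,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)
```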

### Monitoring & Debugging

- **Metrics Tracking**: Loss, perplexity, learning rate
- **GPU Utilization**: Memory and compute monitoring
- **Error Handling**: Graceful failure recovery

See also: