# Training Pipeline
```mermaid
graph TB
    %% Input Data Sources
    subgraph "Data Sources"
        JSONL["JSONL Dataset<br/>{#quot;audio_path#quot;: #quot;...#quot;, #quot;text#quot;: #quot;...#quot;}"]
        GRANARY["NVIDIA Granary Dataset<br/>Multilingual ASR Data"]
        HFDATA["HF Hub Datasets<br/>Community Datasets"]
    end

    %% Data Processing
    subgraph "Data Processing"
        LOADER["Dataset Loader<br/>_load_jsonl_dataset()"]
        CASTER["Audio Casting<br/>16kHz resampling"]
        COLLATOR["VoxtralDataCollator<br/>Audio + Text Processing"]
    end

    %% Training Scripts
    subgraph "Training Scripts"
        TRAIN_FULL["Full Fine-tuning<br/>scripts/train.py"]
        TRAIN_LORA["LoRA Fine-tuning<br/>scripts/train_lora.py"]
        subgraph "Training Components"
            MODEL_INIT["Model Initialization<br/>VoxtralForConditionalGeneration"]
            LORA_CONFIG["LoRA Configuration<br/>LoraConfig + get_peft_model"]
            PROCESSOR_INIT["Processor Initialization<br/>VoxtralProcessor"]
        end
    end

    %% Training Infrastructure
    subgraph "Training Infrastructure"
        TRACKIO_INIT["Trackio Integration<br/>Experiment Tracking"]
        HF_TRAINER["Hugging Face Trainer<br/>TrainingArguments + Trainer"]
        TORCH_DEVICE["Torch Device Setup<br/>GPU/CPU Detection"]
    end

    %% Training Process
    subgraph "Training Process"
        FORWARD_PASS["Forward Pass<br/>Audio Processing + Generation"]
        LOSS_CALC["Loss Calculation<br/>Masked Language Modeling"]
        BACKWARD_PASS["Backward Pass<br/>Gradient Computation"]
        OPTIMIZER_STEP["Optimizer Step<br/>Parameter Updates"]
        LOGGING["Metrics Logging<br/>Loss, Perplexity, etc."]
    end

    %% Model Management
    subgraph "Model Management"
        CHECKPOINT_SAVING["Checkpoint Saving<br/>Model snapshots"]
        MODEL_SAVING["Final Model Saving<br/>Processor + Model"]
        LOCAL_STORAGE["Local Storage<br/>outputs/ directory"]
    end

    %% Flow Connections
    JSONL --> LOADER
    GRANARY --> LOADER
    HFDATA --> LOADER
    LOADER --> CASTER
    CASTER --> COLLATOR
    COLLATOR --> TRAIN_FULL
    COLLATOR --> TRAIN_LORA
    TRAIN_FULL --> MODEL_INIT
    TRAIN_LORA --> MODEL_INIT
    TRAIN_LORA --> LORA_CONFIG
    MODEL_INIT --> PROCESSOR_INIT
    LORA_CONFIG --> PROCESSOR_INIT
    PROCESSOR_INIT --> TRACKIO_INIT
    PROCESSOR_INIT --> HF_TRAINER
    PROCESSOR_INIT --> TORCH_DEVICE
    TRACKIO_INIT --> HF_TRAINER
    TORCH_DEVICE --> HF_TRAINER
    HF_TRAINER --> FORWARD_PASS
    FORWARD_PASS --> LOSS_CALC
    LOSS_CALC --> BACKWARD_PASS
    BACKWARD_PASS --> OPTIMIZER_STEP
    OPTIMIZER_STEP --> LOGGING
    LOGGING --> CHECKPOINT_SAVING
    LOGGING --> TRACKIO_INIT
    HF_TRAINER --> MODEL_SAVING
    MODEL_SAVING --> LOCAL_STORAGE

    %% Styling
    classDef input fill:#e3f2fd,stroke:#1976d2,stroke-width:2px
    classDef processing fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px
    classDef training fill:#e8f5e8,stroke:#388e3c,stroke-width:2px
    classDef infrastructure fill:#fff3e0,stroke:#f57c00,stroke-width:2px
    classDef execution fill:#fce4ec,stroke:#c2185b,stroke-width:2px
    classDef output fill:#f5f5f5,stroke:#424242,stroke-width:2px

    class JSONL,GRANARY,HFDATA input
    class LOADER,CASTER,COLLATOR processing
    class TRAIN_FULL,TRAIN_LORA,MODEL_INIT,LORA_CONFIG,PROCESSOR_INIT training
    class TRACKIO_INIT,HF_TRAINER,TORCH_DEVICE infrastructure
    class FORWARD_PASS,LOSS_CALC,BACKWARD_PASS,OPTIMIZER_STEP,LOGGING execution
    class CHECKPOINT_SAVING,MODEL_SAVING,LOCAL_STORAGE output
```
## Training Pipeline Overview
This diagram illustrates the complete training pipeline for Voxtral ASR fine-tuning, showing how data flows through the training scripts and supporting infrastructure.
### Data Input Sources
#### JSONL Datasets
- **Local Datasets**: User-created datasets from recordings or uploads
- **Format**: `{"audio_path": "path/to/audio.wav", "text": "transcription"}`
- **Processing**: Loaded via `_load_jsonl_dataset()` function
#### NVIDIA Granary Dataset
- **Multilingual Support**: 25+ European languages
- **High Quality**: Curated ASR training data
- **Streaming**: Efficient loading without downloading the full corpus (sketched below)
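Streaming can be illustrated with the `datasets` library. The snippet below is a hedged sketch: the dataset id (`nvidia/Granary`), config name, and split are assumptions, not values taken from the training scripts.

```python
from datasets import load_dataset

# Hedged sketch: stream a subset without downloading the full corpus.
# The dataset id, config ("de"), and split are assumptions.
granary = load_dataset("nvidia/Granary", "de", split="train", streaming=True)

for example in granary.take(3):  # IterableDataset.take yields only the first rows
    print(example.keys())
```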
#### Hugging Face Hub Datasets
- **Community Datasets**: Public datasets from HF Hub
- **Standard Formats**: Compatible with Voxtral training requirements
### Data Processing Pipeline
#### Dataset Loading
```python
from datasets import load_dataset

# Load a local JSONL dataset (helper defined in the training scripts)
ds = _load_jsonl_dataset(jsonl_path)
# or pull a dataset from the Hugging Face Hub
ds = load_dataset(ds_name, ds_cfg, split="test")
```
#### Audio Processing
```python
from datasets import Audio

# Cast the audio column to 16 kHz so the processor receives a consistent sampling rate
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
```
#### Data Collation
- **VoxtralDataCollator**: Custom collator for Voxtral training
- **Audio Processing**: Converts audio to model inputs
- **Text Tokenization**: Processes transcription text
- **Masking**: Masks prompt tokens during training (see the sketch below)
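The masking step can be illustrated with a small helper. This is an illustrative sketch, not the actual `VoxtralDataCollator` code; `prompt_lengths` is an assumed per-example value holding the number of prompt (audio + instruction) tokens.

```python
import torch

def mask_prompt_tokens(input_ids: torch.Tensor, prompt_lengths: torch.Tensor) -> torch.Tensor:
    """Build labels that ignore prompt positions during loss computation.

    Positions before each example's prompt length are set to -100, the ignore
    index used by cross-entropy, so only transcription tokens are scored.
    """
    labels = input_ids.clone()
    positions = torch.arange(input_ids.size(1), device=input_ids.device).unsqueeze(0)
    labels[positions < prompt_lengths.unsqueeze(1)] = -100
    return labels
```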
### Training Script Architecture
#### Full Fine-tuning (`train.py`)
- **Complete Model Updates**: All parameters trainable
- **Higher Memory Requirements**: Full model in memory
- **Better Convergence**: Can achieve higher accuracy
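Model and processor initialization for full fine-tuning might look like the sketch below; the checkpoint id is an assumption and may differ from the script's default.

```python
import torch
from transformers import AutoProcessor, VoxtralForConditionalGeneration

model_id = "mistralai/Voxtral-Mini-3B-2507"  # assumed checkpoint id

processor = AutoProcessor.from_pretrained(model_id)
model = VoxtralForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # matches the bf16 training setup below
)
```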
#### LoRA Fine-tuning (`train_lora.py`)
- **Parameter Efficient**: Only LoRA adapters trained
- **Lower Memory Usage**: Base model frozen
- **Faster Training**: Fewer parameters to update
- **Configurable**: `r`, `alpha`, and `dropout` parameters (example below)
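A hedged sketch of the LoRA setup follows; the target modules and hyperparameter values are assumptions, not necessarily what `scripts/train_lora.py` uses.

```python
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,                                  # rank of the low-rank update
    lora_alpha=32,                         # scaling applied to the update
    lora_dropout=0.05,                     # dropout on the LoRA branch
    target_modules=["q_proj", "v_proj"],   # assumed attention projections
    task_type="CAUSAL_LM",
)

# `model` carries over from the initialization sketch above.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # confirms only the adapters are trainable
```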
### Training Infrastructure
#### Trackio Integration
```python
import trackio

trackio.init(
    project="voxtral-finetuning",
    config={...},  # Training parameters
    space_id=trackio_space,
)
```
#### Hugging Face Trainer
```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=batch_size,
    learning_rate=learning_rate,
    num_train_epochs=epochs,
    bf16=True,  # BFloat16 for efficiency
    report_to=["trackio"],
    # ... other args
)
```
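Wiring these pieces together might look like the sketch below; the variable names (`model`, `train_ds`, `eval_ds`, `collator`) are assumptions carried over from the earlier snippets rather than the scripts' exact names.

```python
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=collator,
)

trainer.train()
trainer.save_model()
```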
#### Device Management
- **GPU Detection**: Automatic CUDA/GPU detection
- **Fallback**: CPU training if no GPU available
- **Memory Optimization**: Model sharding and gradient checkpointing
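Device selection reduces to a standard PyTorch check:

```python
import torch

# Prefer a CUDA GPU when available; otherwise fall back to CPU training.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on: {device}")
```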
### Training Process Flow
#### Forward Pass
1. **Audio Input**: Raw audio waveforms
2. **Audio Tower**: Audio feature extraction
3. **Text Generation**: Autoregressive text generation from audio features
#### Loss Calculation
- **Masked Language Modeling**: Only transcription tokens contribute to loss
- **Audio Prompt Masking**: Audio processing tokens are masked out
- **Cross-Entropy Loss**: Standard language modeling loss
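The effect of the masked labels can be seen in a toy example: positions labelled -100 are simply excluded from the cross-entropy average.

```python
import torch
import torch.nn.functional as F

logits = torch.randn(1, 4, 10)               # (batch, seq_len, vocab_size)
labels = torch.tensor([[-100, -100, 3, 7]])   # first two positions are masked prompt tokens

loss = F.cross_entropy(logits.view(-1, 10), labels.view(-1), ignore_index=-100)
print(loss)  # averaged only over the two unmasked positions
```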
#### Backward Pass & Optimization
- **Gradient Computation**: Backpropagation through the model
- **LoRA Updates**: Only adapter parameters updated (LoRA mode)
- **Full Updates**: All parameters updated (full fine-tuning)
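A quick way to verify which mode you are in is to count parameters with `requires_grad=True`:

```python
def count_trainable(model) -> tuple[int, int]:
    """Return (trainable, total) parameter counts."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total

# In LoRA mode the trainable fraction is typically well under 1%; with full
# fine-tuning it is 100%. Assuming `model` from the sketches above:
# trainable, total = count_trainable(model)
# print(f"{trainable:,} / {total:,} parameters trainable")
```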
### Model Management
#### Checkpoint Saving
- **Regular Checkpoints**: Saved every N steps
- **Best Model Tracking**: Save best model based on validation loss
- **Resume Capability**: Continue training from saved checkpoints (see below)
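Checkpointing and resumption are configured through `TrainingArguments` and `Trainer`; the values below are illustrative, not the scripts' defaults.

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir=output_dir,
    save_steps=500,          # write a checkpoint every 500 optimizer steps
    save_total_limit=3,      # keep only the three most recent checkpoints
)

# Later, continue from the most recent checkpoint in output_dir:
# trainer.train(resume_from_checkpoint=True)
```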
#### Final Model Saving
```python
trainer.save_model() # Saves model and tokenizer
processor.save_pretrained(output_dir) # Saves processor
```
#### Local Storage Structure
```
outputs/
└── voxtral-finetuned-{timestamp}/
    ├── config.json
    ├── model.safetensors
    ├── tokenizer.json
    ├── training_config.json
    ├── train_results.json
    └── eval_results.json
### Integration Points
#### With Interface (`interface.py`)
- **Parameter Passing**: Training parameters from UI
- **Log Streaming**: Real-time training logs to UI
- **Progress Monitoring**: Training progress updates
#### With Model Publishing (`push_to_huggingface.py`)
- **Model Upload**: Trained model to HF Hub
- **Metadata**: Training config and results
- **Model Cards**: Automatic model card generation
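A hedged sketch of the upload step performed by `push_to_huggingface.py`, using `huggingface_hub`; the repository id and output folder below are placeholders.

```python
from huggingface_hub import HfApi

api = HfApi()
api.create_repo("your-username/voxtral-asr-finetuned", repo_type="model", exist_ok=True)
api.upload_folder(
    folder_path="outputs/voxtral-finetuned-<timestamp>",   # placeholder path
    repo_id="your-username/voxtral-asr-finetuned",
    repo_type="model",
)
```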
#### With Demo Deployment (`deploy_demo_space.py`)
- **Space Creation**: HF Spaces for demos
- **Model Integration**: Deploy trained model in demo
- **Configuration**: Demo-specific settings
### Performance Considerations
#### Memory Optimization
- **LoRA**: Significantly reduces memory requirements
- **Gradient Checkpointing**: Trade compute for memory
- **Mixed Precision**: BF16/FP16 training
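These optimizations map to a couple of standard calls; the checkpoint id below is an assumption.

```python
import torch
from transformers import VoxtralForConditionalGeneration

model = VoxtralForConditionalGeneration.from_pretrained(
    "mistralai/Voxtral-Mini-3B-2507",      # assumed checkpoint id
    torch_dtype=torch.bfloat16,            # mixed-precision weights
)
model.gradient_checkpointing_enable()      # trade recomputation for activation memory
```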
#### Training Efficiency
- **Batch Size**: Balanced with gradient accumulation
- **Learning Rate**: Warmup and decay schedules
- **Early Stopping**: Prevent overfitting
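The efficiency knobs above are all exposed by `TrainingArguments`; the values here are illustrative assumptions.

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="outputs/voxtral-finetuned",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,    # effective batch size of 16
    learning_rate=2e-4,
    warmup_ratio=0.1,                 # warm up over the first 10% of steps
    lr_scheduler_type="cosine",       # cosine decay after warmup
)
```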
#### Monitoring & Debugging
- **Metrics Tracking**: Loss, perplexity, learning rate
- **GPU Utilization**: Memory and compute monitoring
- **Error Handling**: Graceful failure recovery
See also:
- [Architecture Overview](architecture.md)
- [Interface Workflow](interface-workflow.md)
- [Data Flow](data-flow.md)