OpenSound committed
Commit dd9600d · verified · 1 parent: 2b391d0

Upload 518 files

This view is limited to 50 of the 518 changed files; see the raw diff for the full change set.
Files changed (50):
  1. capspeech/__init__.py +0 -0
  2. capspeech/ar/README.md +44 -0
  3. capspeech/ar/__init__.py +0 -0
  4. capspeech/ar/events.txt +395 -0
  5. capspeech/ar/finetune_acccaptts.sh +64 -0
  6. capspeech/ar/finetune_agenttts.sh +61 -0
  7. capspeech/ar/finetune_captts.sh +64 -0
  8. capspeech/ar/finetune_capttsse.sh +62 -0
  9. capspeech/ar/finetune_emocaptts.sh +64 -0
  10. capspeech/ar/parler_tts/__init__.py +25 -0
  11. capspeech/ar/parler_tts/configuration_parler_tts.py +291 -0
  12. capspeech/ar/parler_tts/dac_wrapper/__init__.py +2 -0
  13. capspeech/ar/parler_tts/dac_wrapper/configuration_dac.py +27 -0
  14. capspeech/ar/parler_tts/dac_wrapper/modeling_dac.py +164 -0
  15. capspeech/ar/parler_tts/logits_processors.py +54 -0
  16. capspeech/ar/parler_tts/modeling_parler_tts.py +0 -0
  17. capspeech/ar/parler_tts/streamer.py +147 -0
  18. capspeech/ar/pretrain.sh +68 -0
  19. capspeech/ar/training/__init__.py +0 -0
  20. capspeech/ar/training/arguments.py +403 -0
  21. capspeech/ar/training/arguments_captts.py +391 -0
  22. capspeech/ar/training/arguments_capttsse.py +387 -0
  23. capspeech/ar/training/data.py +277 -0
  24. capspeech/ar/training/data_captts.py +255 -0
  25. capspeech/ar/training/data_capttsse.py +253 -0
  26. capspeech/ar/training/finetune_captts.py +1270 -0
  27. capspeech/ar/training/finetune_capttsse.py +1267 -0
  28. capspeech/ar/training/run_parler_tts_training.py +1279 -0
  29. capspeech/ar/training/utils.py +203 -0
  30. capspeech/eval/README.md +42 -0
  31. capspeech/eval/__init__.py +0 -0
  32. capspeech/eval/age_gender.py +35 -0
  33. capspeech/eval/asr_eval.py +24 -0
  34. capspeech/eval/base_eval.py +32 -0
  35. capspeech/eval/bin.json +10 -0
  36. capspeech/eval/pitch.py +30 -0
  37. capspeech/eval/requirements.txt +16 -0
  38. capspeech/eval/speed.py +29 -0
  39. capspeech/eval/src/__init__.py +0 -0
  40. capspeech/eval/src/example/__init__.py +0 -0
  41. capspeech/eval/src/example/categorized_emotion.py +92 -0
  42. capspeech/eval/src/example/dialect_world_dialect.py +87 -0
  43. capspeech/eval/src/model/__init__.py +0 -0
  44. capspeech/eval/src/model/adapter.py +73 -0
  45. capspeech/eval/src/model/dialect/__init__.py +0 -0
  46. capspeech/eval/src/model/dialect/wavlm_dialect.py +300 -0
  47. capspeech/eval/src/model/dialect/whisper_dialect.py +301 -0
  48. capspeech/eval/src/model/emotion/__init__.py +0 -0
  49. capspeech/eval/src/model/emotion/wavlm_emotion.py +315 -0
  50. capspeech/eval/src/model/emotion/wavlm_emotion_dim.py +318 -0
capspeech/__init__.py ADDED
File without changes
capspeech/ar/README.md ADDED
@@ -0,0 +1,44 @@
+ # CapSpeech-AR
+
+ ## Pretrain
+
+ ```bash
+ bash pretrain.sh
+ ```
+ Make sure to change paths and keys in `pretrain.sh` to yours.
+
+ ## Finetune on CapTTS
+
+ ```bash
+ bash finetune_captts.sh
+ ```
+ Make sure to change paths and keys in `finetune_captts.sh` to yours.
+
+ ## Finetune on EmoCapTTS
+
+ ```bash
+ bash finetune_emocaptts.sh
+ ```
+ Make sure to change paths and keys in `finetune_emocaptts.sh` to yours.
+
+ ## Finetune on AccCapTTS
+
+ ```bash
+ bash finetune_acccaptts.sh
+ ```
+ Make sure to change paths and keys in `finetune_acccaptts.sh` to yours.
+
+ ## Finetune on CapTTS-SE
+
+ ```bash
+ bash finetune_capttsse.sh
+ ```
+ Make sure to change paths and keys in `finetune_capttsse.sh` to yours.
+
+
+ ## Finetune on AgentTTS
+
+ ```bash
+ bash finetune_agenttts.sh
+ ```
+ Make sure to change paths and keys in `finetune_agenttts.sh` to yours.
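The README above covers training only. For reference, a minimal inference sketch (not included in this commit), assuming a checkpoint saved to one of the scripts' OUTPUT_DIR paths and that the fine-tuned model keeps the upstream Parler-TTS generation interface; the checkpoint path, caption, and transcript are placeholders:

```python
import torch
import soundfile as sf
from transformers import AutoTokenizer
from capspeech.ar.parler_tts import ParlerTTSForConditionalGeneration

device = "cuda:0" if torch.cuda.is_available() else "cpu"
ckpt = "./output_finetuning_captts/"  # placeholder: OUTPUT_DIR from finetune_captts.sh
model = ParlerTTSForConditionalGeneration.from_pretrained(ckpt).to(device)
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")  # tokenizer used by pretrain.sh

caption = "A female speaker talks slowly in a sad tone."  # style description ("caption" column)
text = "I can't believe it is already over."              # transcript ("text" column)

input_ids = tokenizer(caption, return_tensors="pt").input_ids.to(device)
prompt_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)

with torch.no_grad():
    audio = model.generate(input_ids=input_ids, prompt_input_ids=prompt_ids)
sf.write("sample.wav", audio.cpu().numpy().squeeze(), model.config.sampling_rate)
```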
capspeech/ar/__init__.py ADDED
File without changes
capspeech/ar/events.txt ADDED
@@ -0,0 +1,395 @@
1
+ people whispering
2
+ Microwave oven
3
+ extending ladders
4
+ mosquito buzzing
5
+ dog whimpering
6
+ coyote howling
7
+ hair dryer drying
8
+ Writing
9
+ rapping
10
+ machine gun shooting
11
+ dog bow-wow
12
+ dog howling
13
+ barn swallow calling
14
+ baby babbling
15
+ Fireworks
16
+ church bell ringing
17
+ car horn
18
+ cat caterwauling
19
+ subway, metro, underground
20
+ waterfall burbling
21
+ lions roaring
22
+ toilet flushing
23
+ skateboarding
24
+ wind
25
+ ripping paper
26
+ vacuum cleaner cleaning floors
27
+ mouse squeaking
28
+ keyboard typing
29
+ playing timpani
30
+ playing harp
31
+ sheep bleating
32
+ eletric blender running
33
+ people slapping
34
+ playing ukulele
35
+ frog
36
+ car engine knocking
37
+ cat purring
38
+ chainsaw
39
+ Violin or fiddle
40
+ people hiccup
41
+ playing acoustic guitar
42
+ donkey, ass braying
43
+ playing french horn
44
+ playing squash
45
+ gibbon howling
46
+ playing harmonica
47
+ playing shofar
48
+ hedge trimmer running
49
+ playing washboard
50
+ running electric fan
51
+ splashing water
52
+ playing bassoon
53
+ people slurping
54
+ playing accordion
55
+ playing oboe
56
+ popping popcorn
57
+ glass breaking
58
+ alarm clock ringing
59
+ mouse click
60
+ Laughter
61
+ magpie calling
62
+ playing snare drum
63
+ people finger snapping
64
+ ferret dooking
65
+ tornado roaring
66
+ Hi-hat
67
+ lawn mowing
68
+ church bells
69
+ cat growling
70
+ cheetah chirrup
71
+ heart sounds, heartbeat
72
+ firing muskets
73
+ vehicle horn, car horn, honking
74
+ turkey gobbling
75
+ ice cream truck, ice cream van
76
+ underwater bubbling
77
+ footsteps on snow
78
+ water drops
79
+ people sobbing
80
+ basketball bounce
81
+ Applause
82
+ playing sitar
83
+ playing gong
84
+ train
85
+ coughing
86
+ people screaming
87
+ Gunshot or gunfire
88
+ chinchilla barking
89
+ cat hissing
90
+ horse clip-clop
91
+ engine
92
+ people battle cry
93
+ typing on computer keyboard
94
+ playing clarinet
95
+ driving motorcycle
96
+ male singing
97
+ singing bowl
98
+ skiing
99
+ driving buses
100
+ alligators, crocodiles hissing
101
+ people eating apple
102
+ door slamming
103
+ Flute
104
+ raining
105
+ Electric piano
106
+ sliding door
107
+ washing machine
108
+ opening or closing car electric windows
109
+ baby crying
110
+ people babbling
111
+ snake hissing
112
+ brushing teeth
113
+ playing tambourine
114
+ Acoustic guitar
115
+ clock tick
116
+ playing castanets
117
+ thunder
118
+ playing didgeridoo
119
+ playing synthesizer
120
+ mouse clicking
121
+ lathe spinning
122
+ spraying water
123
+ hen
124
+ stream burbling
125
+ door wood creaks
126
+ sailing
127
+ dog
128
+ car engine idling
129
+ bowling impact
130
+ driving snowmobile
131
+ toilet flush
132
+ bird squawking
133
+ playing timbales
134
+ playing drum kit
135
+ owl hooting
136
+ striking pool
137
+ Oboe
138
+ duck quacking
139
+ people belly laughing
140
+ lighting firecrackers
141
+ roller coaster running
142
+ blowtorch igniting
143
+ wood thrush calling
144
+ Glockenspiel
145
+ frog croaking
146
+ playing harpsichord
147
+ train horning
148
+ plastic bottle crushing
149
+ playing tabla
150
+ fire crackling
151
+ dog barking
152
+ thunderstorm
153
+ playing banjo
154
+ swimming
155
+ volcano explosion
156
+ playing table tennis
157
+ sea lion barking
158
+ rowboat, canoe, kayak rowing
159
+ Meow
160
+ pouring water
161
+ playing tympani
162
+ rooster
163
+ siren
164
+ parrot talking
165
+ Finger snapping
166
+ playing steel guitar, slide guitar
167
+ Trumpet
168
+ tractor digging
169
+ people coughing
170
+ cat meowing
171
+ Snare drum
172
+ playing erhu
173
+ crow cawing
174
+ playing djembe
175
+ whale calling
176
+ mynah bird singing
177
+ playing tennis
178
+ chopping food
179
+ golf driving
180
+ tapping guitar
181
+ playing cello
182
+ dog growling
183
+ elephant trumpeting
184
+ sea waves
185
+ police radio chatter
186
+ lions growling
187
+ playing lacrosse
188
+ children shouting
189
+ missile launch
190
+ baby laughter
191
+ air conditioning noise
192
+ playing saxophone
193
+ typing on typewriter
194
+ printer printing
195
+ race car, auto racing
196
+ Bus
197
+ pigeon, dove cooing
198
+ playing violin, fiddle
199
+ Double bass
200
+ striking bowling
201
+ fireworks banging
202
+ Harmonica
203
+ playing glockenspiel
204
+ reversing beeps
205
+ playing piano
206
+ breathing
207
+ people marching
208
+ electric shaver, electric razor shaving
209
+ chimpanzee pant-hooting
210
+ cricket chirping
211
+ bird chirping, tweeting
212
+ using sewing machines
213
+ crickets
214
+ cow lowing
215
+ playing cymbal
216
+ vacuum cleaner
217
+ playing zither
218
+ train whistling
219
+ goat bleating
220
+ eating with cutlery
221
+ black capped chickadee calling
222
+ ambulance siren
223
+ playing hockey
224
+ dog baying
225
+ Burping or eructation
226
+ cupboard opening or closing
227
+ air horn
228
+ crying baby
229
+ people eating crisps
230
+ sloshing water
231
+ goose honking
232
+ orchestra
233
+ people giggling
234
+ warbler chirping
235
+ child singing
236
+ dinosaurs bellowing
237
+ motorboat, speedboat acceleration
238
+ airplane
239
+ chicken clucking
240
+ woodpecker pecking tree
241
+ Drawer open or close
242
+ people eating
243
+ drinking sipping
244
+ singing choir
245
+ playing bass guitar
246
+ playing bass drum
247
+ car passing by
248
+ playing tuning fork
249
+ Squeak
250
+ pig oinking
251
+ Computer keyboard
252
+ yodelling
253
+ playing trombone
254
+ clapping
255
+ people sneezing
256
+ pheasant crowing
257
+ writing on blackboard with chalk
258
+ Tambourine
259
+ opening or closing car doors
260
+ sharpen knife
261
+ people whistling
262
+ fireworks
263
+ playing bagpipes
264
+ chainsawing trees
265
+ squishing water
266
+ people farting
267
+ playing electric guitar
268
+ people booing
269
+ female singing
270
+ ocean burbling
271
+ cattle mooing
272
+ footsteps
273
+ Knock
274
+ wind rustling leaves
275
+ cattle, bovinae cowbell
276
+ Clarinet
277
+ police car (siren)
278
+ Fart
279
+ cat
280
+ sheep
281
+ chopping wood
282
+ tap dancing
283
+ playing mandolin
284
+ wind chime
285
+ can opening
286
+ playing hammond organ
287
+ zebra braying
288
+ scuba diving
289
+ chirping birds
290
+ playing steelpan
291
+ playing theremin
292
+ Keys jangling
293
+ beat boxing
294
+ firing cannon
295
+ bouncing on trampoline
296
+ door wood knock
297
+ bathroom ventilation fan running
298
+ snake rattling
299
+ bull bellowing
300
+ electric grinder grinding
301
+ penguins braying
302
+ otter growling
303
+ civil defense siren
304
+ wind noise
305
+ people humming
306
+ clock alarm
307
+ disc scratching
308
+ fire truck siren
309
+ telephone bell ringing
310
+ people sniggering
311
+ playing bongo
312
+ cap gun shooting
313
+ opening or closing drawers
314
+ cow
315
+ hammering nails
316
+ ice cracking
317
+ foghorn
318
+ rain
319
+ playing badminton
320
+ eagle screaming
321
+ playing double bass
322
+ insects
323
+ people running
324
+ planing timber
325
+ cutting hair with electric trimmers
326
+ Cello
327
+ people clapping
328
+ smoke detector beeping
329
+ mouse pattering
330
+ bee, wasp, etc. buzzing
331
+ canary calling
332
+ people burping
333
+ Shatter
334
+ baltimore oriole calling
335
+ cuckoo bird calling
336
+ snoring
337
+ strike lighter
338
+ people cheering
339
+ playing bugle
340
+ playing congas
341
+ playing vibraphone
342
+ hail
343
+ rope skipping
344
+ playing trumpet
345
+ pig
346
+ hand saw
347
+ people gargling
348
+ Scissors
349
+ metronome
350
+ chipmunk chirping
351
+ playing flute
352
+ fox barking
353
+ crackling fire
354
+ playing volleyball
355
+ skidding
356
+ Bass drum
357
+ crow
358
+ elk bugling
359
+ Telephone
360
+ Bark
361
+ chicken crowing
362
+ people nose blowing
363
+ car engine starting
364
+ pumping water
365
+ Saxophone
366
+ fly, housefly buzzing
367
+ Cough
368
+ people eating noodle
369
+ francolin calling
370
+ arc welding
371
+ horse neighing
372
+ Tearing
373
+ helicopter
374
+ playing electronic organ
375
+ Cowbell
376
+ railroad car, train wagon
377
+ cell phone buzzing
378
+ playing cornet
379
+ sneezing
380
+ engine accelerating, revving, vroom
381
+ bird wings flapping
382
+ playing marimba, xylophone
383
+ playing guiro
384
+ people crowd
385
+ train wheels squealing
386
+ slot machine
387
+ laughing
388
+ lip smacking
389
+ forging swords
390
+ Chime
391
+ playing darts
392
+ people shuffling
393
+ Gong
394
+ airplane flyby
395
+ None
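events.txt is a flat list of 395 sound-event labels, one per line, ending with a literal "None" entry (presumably meaning "no sound event"). A minimal loading sketch, assuming the repository layout in this commit:

```python
from pathlib import Path

# One label per line; the final entry is the literal string "None".
events = [line for line in Path("capspeech/ar/events.txt").read_text().splitlines() if line]
assert len(events) == 395 and events[-1] == "None"
event_to_id = {name: idx for idx, name in enumerate(events)}
```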
capspeech/ar/finetune_acccaptts.sh ADDED
@@ -0,0 +1,64 @@
+ # Please log in to huggingface first
+
+ LIBRITTSR_WAV_DIR='' # downloaded libritts-r wav dir
+ OTHER_WAV_DIR='' # downloaded other wav dirs
+ OUTPUT_DIR="./output_finetuning_acccaptts/" # output dir, to save checkpoints
+ TEMPORY_SAVE_TO_DISK="./audio_code_finetuning_acccaptts/" # dac codec saved dir
+ SAVE_TO_DISK="./dataset_finetuning_acccaptts/" # huggingface metadata saved dir
+ WANDB_KEY='' # your wandb key for logging
+
+ PRETRAINED_MODEL_PATH="" # your pretrained model path
+
+ export CUDA_LAUNCH_BLOCKING=1
+ export TORCH_USE_CUDA_DSA=1
+
+ accelerate launch ./training/finetune_captts.py \
+ --model_name_or_path ${PRETRAINED_MODEL_PATH} \
+ --feature_extractor_name "parler-tts/dac_44khZ_8kbps" \
+ --description_tokenizer_name ${PRETRAINED_MODEL_PATH} \
+ --prompt_tokenizer_name ${PRETRAINED_MODEL_PATH} \
+ --report_to "wandb" \
+ --wandb_key ${WANDB_KEY} \
+ --overwrite_output_dir true \
+ --train_dataset_name "OpenSound/CapSpeech" \
+ --train_split_name "train_SFT_AccCapTTS" \
+ --eval_dataset_name "OpenSound/CapSpeech" \
+ --eval_split_name "validation_SFT_AccCapTTS" \
+ --librittsr_dir ${LIBRITTSR_WAV_DIR} \
+ --other_dir ${OTHER_WAV_DIR} \
+ --max_eval_samples 96 \
+ --per_device_eval_batch_size 32 \
+ --target_audio_column_name "audio_path" \
+ --description_column_name "caption" \
+ --source_column_name "source" \
+ --prompt_column_name "text" \
+ --max_duration_in_seconds 20 \
+ --min_duration_in_seconds 3 \
+ --max_text_length 600 \
+ --preprocessing_num_workers 32 \
+ --do_train true \
+ --num_train_epochs 5 \
+ --gradient_accumulation_steps 6 \
+ --gradient_checkpointing false \
+ --per_device_train_batch_size 4 \
+ --learning_rate 0.0001 \
+ --adam_beta1 0.9 \
+ --adam_beta2 0.99 \
+ --weight_decay 0.01 \
+ --lr_scheduler_type "constant_with_warmup" \
+ --warmup_steps 1000 \
+ --logging_steps 200 \
+ --freeze_text_encoder true \
+ --per_device_eval_batch_size 4 \
+ --audio_encoder_per_device_batch_size 24 \
+ --dtype "float16" \
+ --seed 456 \
+ --output_dir ${OUTPUT_DIR} \
+ --temporary_save_to_disk ${TEMPORY_SAVE_TO_DISK} \
+ --save_to_disk ${SAVE_TO_DISK} \
+ --dataloader_num_workers 32 \
+ --do_eval \
+ --evaluation_strategy steps \
+ --eval_steps 500 \
+ --save_steps 500 \
+ --group_by_length true
capspeech/ar/finetune_agenttts.sh ADDED
@@ -0,0 +1,61 @@
+ # Please log in to huggingface first
+
+ OTHER_WAV_DIR='' # downloaded capspeech-agentdb wav dir
+ OUTPUT_DIR="./output_finetuning_agenttts/" # output dir, to save checkpoints
+ TEMPORY_SAVE_TO_DISK="./audio_code_finetuning_agenttts/" # dac codec saved dir
+ SAVE_TO_DISK="./dataset_finetuning_agenttts/" # huggingface metadata saved dir
+ WANDB_KEY='' # your wandb key for logging
+ PRETRAINED_MODEL_PATH="" # your pretrained model path
+
+ export CUDA_LAUNCH_BLOCKING=1
+ export TORCH_USE_CUDA_DSA=1
+
+ accelerate launch ./training/finetune_captts.py \
+ --model_name_or_path ${PRETRAINED_MODEL_PATH} \
+ --feature_extractor_name "parler-tts/dac_44khZ_8kbps" \
+ --description_tokenizer_name ${PRETRAINED_MODEL_PATH} \
+ --prompt_tokenizer_name ${PRETRAINED_MODEL_PATH} \
+ --report_to "wandb" \
+ --wandb_key ${WANDB_KEY} \
+ --overwrite_output_dir true \
+ --train_dataset_name "OpenSound/CapSpeech" \
+ --train_split_name "train_AgentDB" \
+ --eval_dataset_name "OpenSound/CapSpeech" \
+ --eval_split_name "test_AgentDB" \
+ --other_dir ${OTHER_WAV_DIR} \
+ --max_eval_samples 96 \
+ --per_device_eval_batch_size 32 \
+ --target_audio_column_name "audio_path" \
+ --description_column_name "caption" \
+ --source_column_name "source" \
+ --prompt_column_name "text" \
+ --max_duration_in_seconds 20 \
+ --min_duration_in_seconds 3 \
+ --max_text_length 600 \
+ --preprocessing_num_workers 32 \
+ --do_train true \
+ --num_train_epochs 50 \
+ --gradient_accumulation_steps 6 \
+ --gradient_checkpointing false \
+ --per_device_train_batch_size 4 \
+ --learning_rate 0.0001 \
+ --adam_beta1 0.9 \
+ --adam_beta2 0.99 \
+ --weight_decay 0.01 \
+ --lr_scheduler_type "constant_with_warmup" \
+ --warmup_steps 500 \
+ --logging_steps 100 \
+ --freeze_text_encoder true \
+ --per_device_eval_batch_size 4 \
+ --audio_encoder_per_device_batch_size 24 \
+ --dtype "float16" \
+ --seed 456 \
+ --output_dir ${OUTPUT_DIR} \
+ --temporary_save_to_disk ${TEMPORY_SAVE_TO_DISK} \
+ --save_to_disk ${SAVE_TO_DISK} \
+ --dataloader_num_workers 32 \
+ --do_eval \
+ --evaluation_strategy steps \
+ --eval_steps 500 \
+ --save_steps 500 \
+ --group_by_length true
capspeech/ar/finetune_captts.sh ADDED
@@ -0,0 +1,64 @@
+ # Please log in to huggingface first
+
+ LIBRITTSR_WAV_DIR='' # downloaded libritts-r wav dir
+ OTHER_WAV_DIR='' # downloaded other wav dirs
+ OUTPUT_DIR="./output_finetuning_captts/" # output dir, to save checkpoints
+ TEMPORY_SAVE_TO_DISK="./audio_code_finetuning_captts/" # dac codec saved dir
+ SAVE_TO_DISK="./dataset_finetuning_captts/" # huggingface metadata saved dir
+ WANDB_KEY='' # your wandb key for logging
+
+ PRETRAINED_MODEL_PATH="" # your pretrained model path
+
+ export CUDA_LAUNCH_BLOCKING=1
+ export TORCH_USE_CUDA_DSA=1
+
+ accelerate launch ./training/finetune_captts.py \
+ --model_name_or_path ${PRETRAINED_MODEL_PATH} \
+ --feature_extractor_name "parler-tts/dac_44khZ_8kbps" \
+ --description_tokenizer_name ${PRETRAINED_MODEL_PATH} \
+ --prompt_tokenizer_name ${PRETRAINED_MODEL_PATH} \
+ --report_to "wandb" \
+ --wandb_key ${WANDB_KEY} \
+ --overwrite_output_dir true \
+ --train_dataset_name "OpenSound/CapSpeech" \
+ --train_split_name "train_SFT_CapTTS" \
+ --eval_dataset_name "OpenSound/CapSpeech" \
+ --eval_split_name "validation_SFT_CapTTS" \
+ --librittsr_dir ${LIBRITTSR_WAV_DIR} \
+ --other_dir ${OTHER_WAV_DIR} \
+ --max_eval_samples 96 \
+ --per_device_eval_batch_size 32 \
+ --target_audio_column_name "audio_path" \
+ --description_column_name "caption" \
+ --source_column_name "source" \
+ --prompt_column_name "text" \
+ --max_duration_in_seconds 20 \
+ --min_duration_in_seconds 3 \
+ --max_text_length 600 \
+ --preprocessing_num_workers 32 \
+ --do_train true \
+ --num_train_epochs 5 \
+ --gradient_accumulation_steps 6 \
+ --gradient_checkpointing false \
+ --per_device_train_batch_size 4 \
+ --learning_rate 0.0001 \
+ --adam_beta1 0.9 \
+ --adam_beta2 0.99 \
+ --weight_decay 0.01 \
+ --lr_scheduler_type "constant_with_warmup" \
+ --warmup_steps 1000 \
+ --logging_steps 200 \
+ --freeze_text_encoder true \
+ --per_device_eval_batch_size 4 \
+ --audio_encoder_per_device_batch_size 24 \
+ --dtype "float16" \
+ --seed 456 \
+ --output_dir ${OUTPUT_DIR} \
+ --temporary_save_to_disk ${TEMPORY_SAVE_TO_DISK} \
+ --save_to_disk ${SAVE_TO_DISK} \
+ --dataloader_num_workers 32 \
+ --do_eval \
+ --evaluation_strategy steps \
+ --eval_steps 2000 \
+ --save_steps 2000 \
+ --group_by_length true
capspeech/ar/finetune_capttsse.sh ADDED
@@ -0,0 +1,62 @@
+ # Please log in to huggingface first
+
+ LIBRITTSRMIX_WAV_DIR='' # downloaded capspeech-sedb wav dir
+ OUTPUT_DIR="./output_finetuning_capttsse/" # output dir, to save checkpoints
+ TEMPORY_SAVE_TO_DISK="./audio_code_finetuning_capttsse/" # dac codec saved dir
+ SAVE_TO_DISK="./dataset_finetuning_capttsse/" # huggingface metadata saved dir
+ WANDB_KEY='' # your wandb key for logging
+
+ PRETRAINED_MODEL_PATH="" # your pretrained model path
+
+ export CUDA_LAUNCH_BLOCKING=1
+ export TORCH_USE_CUDA_DSA=1
+
+ accelerate launch ./training/finetune_capttsse.py \
+ --model_name_or_path ${PRETRAINED_MODEL_PATH} \
+ --feature_extractor_name "parler-tts/dac_44khZ_8kbps" \
+ --description_tokenizer_name ${PRETRAINED_MODEL_PATH} \
+ --prompt_tokenizer_name ${PRETRAINED_MODEL_PATH} \
+ --report_to "wandb" \
+ --wandb_key ${WANDB_KEY} \
+ --overwrite_output_dir true \
+ --train_dataset_name "OpenSound/CapSpeech" \
+ --train_split_name "train_SEDB" \
+ --eval_dataset_name "OpenSound/CapSpeech" \
+ --eval_split_name "test_SEDB" \
+ --librittsrmix_dir ${LIBRITTSRMIX_WAV_DIR} \
+ --max_eval_samples 96 \
+ --per_device_eval_batch_size 32 \
+ --target_audio_column_name "audio_path" \
+ --description_column_name "caption" \
+ --source_column_name "source" \
+ --prompt_column_name "text" \
+ --max_duration_in_seconds 20 \
+ --min_duration_in_seconds 3 \
+ --max_text_length 600 \
+ --preprocessing_num_workers 32 \
+ --do_train true \
+ --num_train_epochs 50 \
+ --gradient_accumulation_steps 6 \
+ --gradient_checkpointing false \
+ --per_device_train_batch_size 4 \
+ --learning_rate 0.0001 \
+ --adam_beta1 0.9 \
+ --adam_beta2 0.99 \
+ --weight_decay 0.01 \
+ --lr_scheduler_type "constant_with_warmup" \
+ --warmup_steps 50 \
+ --logging_steps 20 \
+ --freeze_text_encoder true \
+ --per_device_eval_batch_size 4 \
+ --audio_encoder_per_device_batch_size 24 \
+ --dtype "float16" \
+ --seed 456 \
+ --output_dir ${OUTPUT_DIR} \
+ --temporary_save_to_disk ${TEMPORY_SAVE_TO_DISK} \
+ --save_to_disk ${SAVE_TO_DISK} \
+ --dataloader_num_workers 32 \
+ --do_eval \
+ --evaluation_strategy steps \
+ --eval_steps 50 \
+ --save_steps 50 \
+ --group_by_length true
capspeech/ar/finetune_emocaptts.sh ADDED
@@ -0,0 +1,64 @@
+ # Please log in to huggingface first
+
+ LIBRITTSR_WAV_DIR='' # downloaded libritts-r wav dir
+ OTHER_WAV_DIR='' # downloaded other wav dirs
+ OUTPUT_DIR="./output_finetuning_emocaptts/" # output dir, to save checkpoints
+ TEMPORY_SAVE_TO_DISK="./audio_code_finetuning_emocaptts/" # dac codec saved dir
+ SAVE_TO_DISK="./dataset_finetuning_emocaptts/" # huggingface metadata saved dir
+ WANDB_KEY='' # your wandb key for logging
+
+ PRETRAINED_MODEL_PATH="" # your pretrained model path
+
+ export CUDA_LAUNCH_BLOCKING=1
+ export TORCH_USE_CUDA_DSA=1
+
+ accelerate launch ./training/finetune_captts.py \
+ --model_name_or_path ${PRETRAINED_MODEL_PATH} \
+ --feature_extractor_name "parler-tts/dac_44khZ_8kbps" \
+ --description_tokenizer_name ${PRETRAINED_MODEL_PATH} \
+ --prompt_tokenizer_name ${PRETRAINED_MODEL_PATH} \
+ --report_to "wandb" \
+ --wandb_key ${WANDB_KEY} \
+ --overwrite_output_dir true \
+ --train_dataset_name "OpenSound/CapSpeech" \
+ --train_split_name "train_SFT_EmoCapTTS" \
+ --eval_dataset_name "OpenSound/CapSpeech" \
+ --eval_split_name "validation_SFT_EmoCapTTS" \
+ --librittsr_dir ${LIBRITTSR_WAV_DIR} \
+ --other_dir ${OTHER_WAV_DIR} \
+ --max_eval_samples 96 \
+ --per_device_eval_batch_size 32 \
+ --target_audio_column_name "audio_path" \
+ --description_column_name "caption" \
+ --source_column_name "source" \
+ --prompt_column_name "text" \
+ --max_duration_in_seconds 20 \
+ --min_duration_in_seconds 3 \
+ --max_text_length 600 \
+ --preprocessing_num_workers 32 \
+ --do_train true \
+ --num_train_epochs 5 \
+ --gradient_accumulation_steps 6 \
+ --gradient_checkpointing false \
+ --per_device_train_batch_size 4 \
+ --learning_rate 0.0001 \
+ --adam_beta1 0.9 \
+ --adam_beta2 0.99 \
+ --weight_decay 0.01 \
+ --lr_scheduler_type "constant_with_warmup" \
+ --warmup_steps 1000 \
+ --logging_steps 200 \
+ --freeze_text_encoder true \
+ --per_device_eval_batch_size 4 \
+ --audio_encoder_per_device_batch_size 24 \
+ --dtype "float16" \
+ --seed 456 \
+ --output_dir ${OUTPUT_DIR} \
+ --temporary_save_to_disk ${TEMPORY_SAVE_TO_DISK} \
+ --save_to_disk ${SAVE_TO_DISK} \
+ --dataloader_num_workers 32 \
+ --do_eval \
+ --evaluation_strategy steps \
+ --eval_steps 400 \
+ --save_steps 400 \
+ --group_by_length true
capspeech/ar/parler_tts/__init__.py ADDED
@@ -0,0 +1,25 @@
+ __version__ = "0.2.2"
+
+
+ from transformers import AutoConfig, AutoModel
+
+ from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
+ from .dac_wrapper import DACConfig, DACModel
+ from .modeling_parler_tts import (
+     ParlerTTSForCausalLM,
+     ParlerTTSForConditionalGeneration,
+     apply_delay_pattern_mask,
+     build_delay_pattern_mask,
+ )
+
+ from .streamer import ParlerTTSStreamer
+
+ from importlib.metadata import version
+ from packaging.version import Version
+
+ if Version(version("transformers"))<= Version("4.44.2dev"):
+     AutoConfig.register("dac", DACConfig)
+ else:
+     AutoConfig.register("dac_on_the_hub", DACConfig)
+
+ AutoModel.register(DACConfig, DACModel)
capspeech/ar/parler_tts/configuration_parler_tts.py ADDED
@@ -0,0 +1,291 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Parler-TTS model configuration"""
16
+
17
+ from transformers import AutoConfig, logging
18
+ from transformers.configuration_utils import PretrainedConfig
19
+
20
+ from importlib.metadata import version
21
+ from packaging.version import Version
22
+
23
+ use_dac_on_the_hub = Version(version("transformers")) > Version("4.44.2dev")
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ PARLER_TTS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
28
+ "parler-tts/parler-tts-mini-v1": "https://huggingface.co/parler-tts/parler-tts-mini-v1/resolve/main/config.json",
29
+ # See all ParlerTTS models at https://huggingface.co/models?filter=parler_tts
30
+ }
31
+
32
+
33
+ class ParlerTTSDecoderConfig(PretrainedConfig):
34
+ r"""
35
+ This is the configuration class to store the configuration of an [`ParlerTTSDecoder`]. It is used to instantiate a
36
+ Parler-TTS decoder according to the specified arguments, defining the model architecture. Instantiating a
37
+ configuration with the defaults will yield a similar configuration to that of the Parler-TTS
38
+ [parler-tts/parler-tts-mini-v1](https://huggingface.co/parler-tts/parler-tts-mini-v1) architecture.
39
+
40
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
41
+ documentation from [`PretrainedConfig`] for more information.
42
+
43
+
44
+ Args:
45
+ vocab_size (`int`, *optional*, defaults to 2049):
46
+ Vocabulary size of the ParlerTTSDecoder model. Defines the number of different tokens that can be
47
+ represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
48
+ hidden_size (`int`, *optional*, defaults to 1024):
49
+ Dimensionality of the layers and the pooler layer.
50
+ num_hidden_layers (`int`, *optional*, defaults to 24):
51
+ Number of decoder layers.
52
+ num_attention_heads (`int`, *optional*, defaults to 16):
53
+ Number of attention heads for each attention layer in the Transformer block.
54
+ num_key_value_heads (`int`, *optional*):
55
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
56
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
57
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
58
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
59
+ by meanpooling all the original heads within that group. For more details checkout [this
60
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
61
+ `num_attention_heads`.
62
+ num_cross_attention_key_value_heads (`int`, *optional*):
63
+ This is the number of key_value heads that should be used to implement Grouped Query Attention in the cross-attention layers.
64
+ If it is not specified, will default to `num_key_value_heads`.
65
+ ffn_dim (`int`, *optional*, defaults to 4096):
66
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block.
67
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
68
+ The non-linear activation function (function or string) in the decoder and pooler. If string, `"gelu"`,
69
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
70
+ dropout (`float`, *optional*, defaults to 0.1):
71
+ The dropout probability for all fully connected layers in the embeddings, text_encoder, and pooler.
72
+ attention_dropout (`float`, *optional*, defaults to 0.0):
73
+ The dropout ratio for the attention probabilities.
74
+ activation_dropout (`float`, *optional*, defaults to 0.0):
75
+ The dropout ratio for activations inside the fully connected layer.
76
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
77
+ The maximum sequence length that this model might ever be used with. Typically, set this to something large
78
+ just in case (e.g., 512 or 1024 or 2048).
79
+ initializer_factor (`float`, *optional*, defaults to 0.02):
80
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
81
+ layerdrop (`float`, *optional*, defaults to 0.0):
82
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
83
+ for more details.
84
+ scale_embedding (`bool`, *optional*, defaults to `False`):
85
+ Scale embeddings by diving by sqrt(hidden_size).
86
+ use_cache (`bool`, *optional*, defaults to `True`):
87
+ Whether the model should return the last key/values attentions (not used by all models)
88
+ num_codebooks (`int`, *optional*, defaults to 4):
89
+ The number of parallel codebooks forwarded to the model.
90
+ tie_word_embeddings(`bool`, *optional*, defaults to `False`):
91
+ Whether input and output word embeddings should be tied.
92
+ rope_embeddings (`bool`, *optional*, defaults to `False`):
93
+ Whether to use ROPE or absolute positional embeddings.
94
+ rope_theta (`float`, *optional*, defaults to 100000.0):
95
+ The base period of the RoPE embeddings.
96
+ cross_attention_implementation_strategy (`str`, *optional*):
97
+ If not specified, the cross-attention implementation will be the same as `_attn_implementation`. If `always_eager`, it will always be the eager implementation. If `always_sdpa`, it will always be the sdpa implementation.
98
+ use_fused_lm_heads(`bool`, *optional*, defaults to `False`):
99
+ Whether to fuse audio LM heads instead of applying them sequentially.
100
+ codebook_weights(`List[int]`, *optional*):
101
+ Weights applied to each codebook when computing the loss.
102
+ """
103
+
104
+ model_type = "parler_tts_decoder"
105
+ keys_to_ignore_at_inference = ["past_key_values"]
106
+
107
+ def __init__(
108
+ self,
109
+ vocab_size=2049, # vocab size = 2048 (encodec vocab size) + 1 (eos)
110
+ max_position_embeddings=2048,
111
+ num_hidden_layers=24,
112
+ ffn_dim=4096,
113
+ num_attention_heads=16,
114
+ num_key_value_heads=None,
115
+ num_cross_attention_key_value_heads=None,
116
+ layerdrop=0.0,
117
+ use_cache=True,
118
+ activation_function="gelu",
119
+ hidden_size=1024,
120
+ dropout=0.1,
121
+ attention_dropout=0.0,
122
+ activation_dropout=0.0,
123
+ initializer_factor=0.02,
124
+ scale_embedding=False,
125
+ num_codebooks=4,
126
+ pad_token_id=2048,
127
+ bos_token_id=2049,
128
+ eos_token_id=2048,
129
+ tie_word_embeddings=False,
130
+ rope_embeddings=False,
131
+ rope_theta=10_000.0,
132
+ cross_attention_implementation_strategy=None,
133
+ use_fused_lm_heads=False,
134
+ codebook_weights=None,
135
+ **kwargs,
136
+ ):
137
+ self.vocab_size = vocab_size
138
+ self.max_position_embeddings = max_position_embeddings
139
+ self.hidden_size = hidden_size
140
+ self.ffn_dim = ffn_dim
141
+ self.num_hidden_layers = num_hidden_layers
142
+ self.num_attention_heads = num_attention_heads
143
+ if num_key_value_heads is None:
144
+ num_key_value_heads = num_attention_heads
145
+ self.num_key_value_heads = num_key_value_heads
146
+ if num_cross_attention_key_value_heads is None:
147
+ num_cross_attention_key_value_heads = num_key_value_heads
148
+ self.num_cross_attention_key_value_heads = num_cross_attention_key_value_heads
149
+ self.dropout = dropout
150
+ self.attention_dropout = attention_dropout
151
+ self.activation_dropout = activation_dropout
152
+ self.activation_function = activation_function
153
+ self.initializer_factor = initializer_factor
154
+ self.layerdrop = layerdrop
155
+ self.use_cache = use_cache
156
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
157
+ self.num_codebooks = num_codebooks
158
+ self.rope_embeddings = rope_embeddings
159
+ self.rope_theta = rope_theta
160
+ self.cross_attention_implementation_strategy = cross_attention_implementation_strategy
161
+ self.use_fused_lm_heads = use_fused_lm_heads
162
+ self.codebook_weights = codebook_weights
163
+
164
+ if codebook_weights is not None and len(codebook_weights) != num_codebooks:
165
+ raise ValueError(f"`codebook_weights` has length {len(codebook_weights)} when it should be of length {num_codebooks}.")
166
+ super().__init__(
167
+ pad_token_id=pad_token_id,
168
+ bos_token_id=bos_token_id,
169
+ eos_token_id=eos_token_id,
170
+ tie_word_embeddings=tie_word_embeddings,
171
+ **kwargs,
172
+ )
173
+
174
+
175
+ class ParlerTTSConfig(PretrainedConfig):
176
+ r"""
177
+ This is the configuration class to store the configuration of a [`ParlerTTSModel`]. It is used to instantiate a
178
+ Parler-TTS model according to the specified arguments, defining the text encoder, audio encoder and Parler-TTS decoder
179
+ configs.
180
+
181
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
182
+ documentation from [`PretrainedConfig`] for more information.
183
+
184
+ Args:
185
+ vocab_size (`int`, *optional*, defaults to 1024):
186
+ Vocabulary size of the prompt token ids. Defines the number of different tokens that can be
187
+ represented by the `prompt_inputs_ids`.
188
+ prompt_cross_attention (`bool`, *optional*, defaults to `False`):
189
+ Whether to use cross-attention conditioning for the prompt (as well as the description).
190
+ kwargs (*optional*):
191
+ Dictionary of keyword arguments. Notably:
192
+
193
+ - **text_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
194
+ defines the text encoder config.
195
+ - **audio_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
196
+ defines the audio encoder config.
197
+ - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
198
+ the decoder config.
199
+
200
+ Example:
201
+
202
+ ```python
203
+ >>> from transformers import (
204
+ ... ParlerTTSConfig,
205
+ ... ParlerTTSDecoderConfig,
206
+ ... T5Config,
207
+ ... EncodecConfig,
208
+ ... ParlerTTSForConditionalGeneration,
209
+ ... )
210
+
211
+ >>> # Initializing text encoder, audio encoder, and decoder model configurations
212
+ >>> text_encoder_config = T5Config()
213
+ >>> audio_encoder_config = EncodecConfig()
214
+ >>> decoder_config = ParlerTTSDecoderConfig()
215
+
216
+ >>> configuration = ParlerTTSConfig.from_sub_models_config(
217
+ ... text_encoder_config, audio_encoder_config, decoder_config
218
+ ... )
219
+
220
+ >>> # Initializing a ParlerTTSForConditionalGeneration (with random weights) from the parler-tts/parler-tts-mini-v1 style configuration
221
+ >>> model = ParlerTTSForConditionalGeneration(configuration)
222
+
223
+ >>> # Accessing the model configuration
224
+ >>> configuration = model.config
225
+ >>> config_text_encoder = model.config.text_encoder
226
+ >>> config_audio_encoder = model.config.audio_encoder
227
+ >>> config_decoder = model.config.decoder
228
+
229
+ >>> # Saving the model, including its configuration
230
+ >>> model.save_pretrained("parler_tts-model")
231
+
232
+ >>> # loading model and config from pretrained folder
233
+ >>> parler_tts_config = ParlerTTSConfig.from_pretrained("parler_tts-model")
234
+ >>> model = ParlerTTSForConditionalGeneration.from_pretrained("parler_tts-model", config=parler_tts_config)
235
+ ```"""
236
+
237
+ model_type = "parler_tts"
238
+ is_composition = True
239
+
240
+ def __init__(self, vocab_size=1024, prompt_cross_attention=False, **kwargs):
241
+ super().__init__(**kwargs)
242
+ if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
243
+ raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")
244
+
245
+ text_encoder_config = kwargs.pop("text_encoder")
246
+ text_encoder_model_type = text_encoder_config.pop("model_type")
247
+
248
+ audio_encoder_config = kwargs.pop("audio_encoder")
249
+ audio_encoder_model_type = audio_encoder_config.pop("model_type")
250
+
251
+ model_version = kwargs.get("transformers_version", None)
252
+ if model_version is not None and Version(model_version) <= Version("4.44.2dev") and use_dac_on_the_hub and audio_encoder_model_type=="dac":
253
+ # here we have to manually change model type if DAC based on transformers version
254
+ audio_encoder_model_type = "dac_on_the_hub"
255
+
256
+ decoder_config = kwargs.pop("decoder")
257
+
258
+ self.vocab_size = vocab_size
259
+ self.prompt_cross_attention = prompt_cross_attention
260
+ self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
261
+ self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
262
+ self.decoder = ParlerTTSDecoderConfig(**decoder_config)
263
+ self.is_encoder_decoder = True
264
+
265
+ @classmethod
266
+ def from_sub_models_config(
267
+ cls,
268
+ text_encoder_config: PretrainedConfig,
269
+ audio_encoder_config: PretrainedConfig,
270
+ decoder_config: ParlerTTSDecoderConfig,
271
+ **kwargs,
272
+ ):
273
+ r"""
274
+ Instantiate a [`ParlerTTSConfig`] (or a derived class) from text encoder, audio encoder and decoder
275
+ configurations.
276
+
277
+ Returns:
278
+ [`ParlerTTSConfig`]: An instance of a configuration object
279
+ """
280
+
281
+ return cls(
282
+ text_encoder=text_encoder_config.to_dict(),
283
+ audio_encoder=audio_encoder_config.to_dict(),
284
+ decoder=decoder_config.to_dict(),
285
+ **kwargs,
286
+ )
287
+
288
+ @property
289
+ # This is a property because you might want to change the codec model on the fly
290
+ def sampling_rate(self):
291
+ return self.audio_encoder.sampling_rate
capspeech/ar/parler_tts/dac_wrapper/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .configuration_dac import DACConfig
+ from .modeling_dac import DACModel
capspeech/ar/parler_tts/dac_wrapper/configuration_dac.py ADDED
@@ -0,0 +1,27 @@
+
+ from transformers import PretrainedConfig
+ from importlib.metadata import version
+ from packaging.version import Version
+
+
+ class DACConfig(PretrainedConfig):
+     model_type = "dac" if Version(version("transformers"))<= Version("4.44.2dev") else "dac_on_the_hub"
+
+     def __init__(
+         self,
+         num_codebooks: int = 9,
+         model_bitrate: int = 8, # kbps
+         codebook_size: int = 1024,
+         latent_dim: int = 1024,
+         frame_rate: int = 86,
+         sampling_rate: int = 44100,
+         **kwargs,
+     ):
+         self.codebook_size = codebook_size
+         self.model_bitrate = model_bitrate
+         self.latent_dim = latent_dim
+         self.num_codebooks = num_codebooks
+         self.frame_rate = frame_rate
+         self.sampling_rate = sampling_rate
+
+         super().__init__(**kwargs)
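Note that `model_type` is resolved at import time from the installed transformers version, mirroring the registration gate in `parler_tts/__init__.py`. A small sketch of what that looks like in practice (the import path is assumed from this commit's layout):

```python
from importlib.metadata import version
from packaging.version import Version
from capspeech.ar.parler_tts.dac_wrapper import DACConfig

print(Version(version("transformers")))
print(DACConfig.model_type)  # "dac" if transformers <= 4.44.2dev, else "dac_on_the_hub"

cfg = DACConfig()  # defaults: 9 codebooks, 1024-entry codebooks, 86 Hz frame rate, 44.1 kHz
print(cfg.frame_rate, cfg.sampling_rate)
```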
capspeech/ar/parler_tts/dac_wrapper/modeling_dac.py ADDED
@@ -0,0 +1,164 @@
1
+ import torch
2
+ from dac.model import DAC
3
+ from torch import nn
4
+
5
+ from transformers import PreTrainedModel
6
+ from transformers.models.encodec.modeling_encodec import EncodecDecoderOutput, EncodecEncoderOutput
7
+
8
+ from .configuration_dac import DACConfig
9
+
10
+
11
+ # model doesn't support batching yet
12
+
13
+
14
+ class DACModel(PreTrainedModel):
15
+ config_class = DACConfig
16
+ main_input_name = "input_values"
17
+
18
+ # Set main input to 'input_values' for voice steering
19
+ main_input_name = "input_values"
20
+
21
+ def __init__(self, config):
22
+ super().__init__(config)
23
+
24
+ self.model = DAC(
25
+ n_codebooks=config.num_codebooks,
26
+ latent_dim=config.latent_dim,
27
+ codebook_size=config.codebook_size,
28
+ )
29
+
30
+ self.remove_weight_norm()
31
+ self.apply_weight_norm()
32
+
33
+ def encode(
34
+ self, input_values, padding_mask=None, bandwidth=None, return_dict=None, n_quantizers=None, sample_rate=None
35
+ ):
36
+ """
37
+ Encodes the input audio waveform into discrete codes.
38
+
39
+ Args:
40
+ input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
41
+ Float values of the input audio waveform.
42
+ padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
43
+ Padding mask used to pad the `input_values`.
44
+ bandwidth (`float`, *optional*):
45
+ Not used, kept to have the same inferface as HF encodec.
46
+ n_quantizers (`int`, *optional*) :
47
+ Number of quantizers to use, by default None
48
+ If None, all quantizers are used.
49
+ sample_rate (`int`, *optional*) :
50
+ Signal sampling_rate
51
+
52
+ Returns:
53
+ A list of frames containing the discrete encoded codes for the input audio waveform, along with rescaling
54
+ factors for each chunk when `normalize` is True. Each frames is a tuple `(codebook, scale)`, with
55
+ `codebook` of shape `[batch_size, num_codebooks, frames]`.
56
+ Scale is not used here.
57
+
58
+ """
59
+ _, channels, input_length = input_values.shape
60
+
61
+ if channels < 1 or channels > 2:
62
+ raise ValueError(f"Number of audio channels must be 1 or 2, but got {channels}")
63
+
64
+ audio_data = self.model.preprocess(input_values, sample_rate)
65
+
66
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
67
+
68
+ # TODO: for now, no chunk length
69
+
70
+ chunk_length = None # self.config.chunk_length
71
+ if chunk_length is None:
72
+ chunk_length = input_length
73
+ stride = input_length
74
+ else:
75
+ stride = self.config.chunk_stride
76
+
77
+ if padding_mask is None:
78
+ padding_mask = torch.ones_like(input_values).bool()
79
+
80
+ encoded_frames = []
81
+ scales = []
82
+
83
+ step = chunk_length - stride
84
+ if (input_length % stride) - step != 0:
85
+ raise ValueError(
86
+ "The input length is not properly padded for batched chunked decoding. Make sure to pad the input correctly."
87
+ )
88
+
89
+ for offset in range(0, input_length - step, stride):
90
+ mask = padding_mask[..., offset : offset + chunk_length].bool()
91
+ frame = audio_data[:, :, offset : offset + chunk_length]
92
+
93
+ scale = None
94
+
95
+ _, encoded_frame, _, _, _ = self.model.encode(frame, n_quantizers=n_quantizers)
96
+ encoded_frames.append(encoded_frame)
97
+ scales.append(scale)
98
+
99
+ encoded_frames = torch.stack(encoded_frames)
100
+
101
+ if not return_dict:
102
+ return (encoded_frames, scales)
103
+
104
+ return EncodecEncoderOutput(encoded_frames, scales)
105
+
106
+ def decode(
107
+ self,
108
+ audio_codes,
109
+ audio_scales,
110
+ padding_mask=None,
111
+ return_dict=None,
112
+ ):
113
+ """
114
+ Decodes the given frames into an output audio waveform.
115
+
116
+ Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be
117
+ trimmed.
118
+
119
+ Args:
120
+ audio_codes (`torch.FloatTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
121
+ Discret code embeddings computed using `model.encode`.
122
+ audio_scales (`torch.Tensor` of shape `(batch_size, nb_chunks)`, *optional*):
123
+ Not used, kept to have the same inferface as HF encodec.
124
+ padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
125
+ Padding mask used to pad the `input_values`.
126
+ Not used yet, kept to have the same inferface as HF encodec.
127
+ return_dict (`bool`, *optional*):
128
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
129
+
130
+ """
131
+ return_dict = return_dict or self.config.return_dict
132
+
133
+ # TODO: for now, no chunk length
134
+
135
+ if len(audio_codes) != 1:
136
+ raise ValueError(f"Expected one frame, got {len(audio_codes)}")
137
+
138
+ audio_values = self.model.quantizer.from_codes(audio_codes.squeeze(0))[0]
139
+ audio_values = self.model.decode(audio_values)
140
+ if not return_dict:
141
+ return (audio_values,)
142
+ return EncodecDecoderOutput(audio_values)
143
+
144
+ def forward(self, tensor):
145
+ raise ValueError("`DACModel.forward` not implemented yet")
146
+
147
+
148
+ def apply_weight_norm(self):
149
+ weight_norm = nn.utils.weight_norm
150
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
151
+ weight_norm = nn.utils.parametrizations.weight_norm
152
+
153
+ def _apply_weight_norm(module):
154
+ if isinstance(module, nn.Conv1d) or isinstance(module, nn.ConvTranspose1d):
155
+ weight_norm(module)
156
+
157
+ self.apply(_apply_weight_norm)
158
+
159
+
160
+ def remove_weight_norm(self):
161
+ def _remove_weight_norm(module):
162
+ if isinstance(module, nn.Conv1d) or isinstance(module, nn.ConvTranspose1d):
163
+ nn.utils.remove_weight_norm(module)
164
+ self.apply(_remove_weight_norm)
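A hedged round-trip sketch for the wrapper above (not part of the commit): it assumes the `descript-audio-codec` package that provides `dac.model.DAC` is installed, uses a dummy waveform, and relies on the Encodec-style output objects that `encode`/`decode` return in this file.

```python
import torch
from capspeech.ar.parler_tts.dac_wrapper import DACConfig, DACModel

codec = DACModel(DACConfig())  # requires the descript-audio-codec ("dac") package
codec.eval()

waveform = torch.randn(1, 1, 44100)  # 1 s of dummy mono audio at 44.1 kHz
with torch.no_grad():
    enc = codec.encode(waveform, sample_rate=codec.config.sampling_rate)
    # enc.audio_codes: (frames=1, batch, num_codebooks, codes); audio_scales are unused by DAC
    dec = codec.decode(enc.audio_codes, enc.audio_scales)
print(dec.audio_values.shape)  # roughly (1, 1, 44100) after codec padding
```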
capspeech/ar/parler_tts/logits_processors.py ADDED
@@ -0,0 +1,54 @@
+ from transformers import LogitsProcessor, LogitsProcessorList
+ from transformers.pytorch_utils import isin_mps_friendly
+ import math
+ import torch
+
+ class ParlerTTSLogitsProcessor(LogitsProcessor):
+     r"""This processor ensures that the delayed pattern mask constraints are respected.
+
+     <Tip warning={true}>
+
+     This logits processor is exclusively compatible with Parler-TTS.
+     See the model documentation for examples.
+
+     </Tip>
+
+     Args:
+         eos_token_id (`Union[int, List[int], torch.Tensor]`):
+             The id(s) of the *end-of-sequence* token.
+         min_eos_p (`float`, *optional*):
+             Minimum end of speech threshold.
+     """
+
+     def __init__(self, eos_token_id, num_codebooks: int, batch_size: int, device: str = "cpu"):
+         if not isinstance(eos_token_id, torch.Tensor):
+             if isinstance(eos_token_id, int):
+                 eos_token_id = [eos_token_id]
+             eos_token_id = torch.tensor(eos_token_id, device=device)
+         self.eos_token_id = eos_token_id
+         self.batch_size = batch_size
+
+         if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
+             raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
+
+         self.num_codebooks = num_codebooks
+         self.device = device
+
+
+         self.codebook_idx = torch.arange(self.batch_size*self.num_codebooks, device=self.device)
+         self.first_codebooks_unfinished = torch.arange(batch_size, device=device)*num_codebooks
+
+         max_codebooks = torch.arange(self.batch_size, device=self.device)*self.num_codebooks + self.num_codebooks -1
+         self.max_codebooks = max_codebooks
+
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+
+         is_eos = isin_mps_friendly(input_ids, self.eos_token_id).sum(1)
+
+         self.first_codebooks_unfinished = torch.where((is_eos[self.first_codebooks_unfinished]>0) & (self.first_codebooks_unfinished<self.max_codebooks), self.first_codebooks_unfinished+1, self.first_codebooks_unfinished)
+
+         # every codebook higher than the first one unfinished will never be eos
+         eos_token_mask = self.codebook_idx > self.first_codebooks_unfinished.repeat_interleave(self.num_codebooks)
+         scores[eos_token_mask, self.eos_token_id] = -math.inf
+
+         return scores
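A small usage sketch for the processor above (not part of the commit): the ids and codebook count are placeholder values consistent with the DAC defaults elsewhere in this commit, and in real use the processor would be passed to `generate` through a `LogitsProcessorList`.

```python
import torch
from transformers import LogitsProcessorList
from capspeech.ar.parler_tts.logits_processors import ParlerTTSLogitsProcessor

# Placeholder values: eos id 2048 (decoder default) and 9 codebooks (DAC default).
processor = ParlerTTSLogitsProcessor(eos_token_id=2048, num_codebooks=9, batch_size=1, device="cpu")
logits_processor = LogitsProcessorList([processor])

# Toy call: one row of scores per codebook stream (batch_size * num_codebooks rows).
input_ids = torch.zeros((9, 5), dtype=torch.long)
scores = torch.zeros((9, 2049))
masked = logits_processor(input_ids, scores)
print(masked[:, 2048])  # eos stays finite only for the first unfinished codebook
```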
capspeech/ar/parler_tts/modeling_parler_tts.py ADDED
The diff for this file is too large to render. See raw diff
 
capspeech/ar/parler_tts/streamer.py ADDED
@@ -0,0 +1,147 @@
1
+
2
+ from .modeling_parler_tts import ParlerTTSForConditionalGeneration
3
+ from transformers.generation.streamers import BaseStreamer
4
+ from typing import Optional
5
+ import torch
6
+ import numpy as np
7
+ import math
8
+ from queue import Queue
9
+
10
+
11
+ class ParlerTTSStreamer(BaseStreamer):
12
+ def __init__(
13
+ self,
14
+ model: ParlerTTSForConditionalGeneration,
15
+ device: Optional[str] = None,
16
+ play_steps: Optional[int] = 10,
17
+ stride: Optional[int] = None,
18
+ timeout: Optional[float] = None,
19
+ ):
20
+ """
21
+ Streamer that stores playback-ready audio in a queue, to be used by a downstream application as an iterator. This is
22
+ useful for applications that benefit from accessing the generated audio in a non-blocking way (e.g. in an interactive
23
+ Gradio demo).
24
+ Parameters:
25
+ model (`ParlerTTSForConditionalGeneration`):
26
+ The Parler-TTS model used to generate the audio waveform.
27
+ device (`str`, *optional*):
28
+ The torch device on which to run the computation. If `None`, will default to the device of the model.
29
+ play_steps (`int`, *optional*, defaults to 10):
30
+ The number of generation steps with which to return the generated audio array. Using fewer steps will
31
+ mean the first chunk is ready faster, but will require more codec decoding steps overall. This value
32
+ should be tuned to your device and latency requirements.
33
+ stride (`int`, *optional*):
34
+ The window (stride) between adjacent audio samples. Using a stride between adjacent audio samples reduces
35
+ the hard boundary between them, giving smoother playback. If `None`, will default to a value equivalent to
36
+ play_steps // 6 in the audio space.
37
+ timeout (`int`, *optional*):
38
+ The timeout for the audio queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
39
+ in `.generate()`, when it is called in a separate thread.
40
+ """
41
+ self.decoder = model.decoder
42
+ self.audio_encoder = model.audio_encoder
43
+ self.generation_config = model.generation_config
44
+ self.device = device if device is not None else model.device
45
+ self.use_audio_scales = model.use_audio_scales
46
+ self.use_4dim_audio_codes = model.use_4dim_audio_codes
47
+ self.audio_kwargs = {}
48
+ if self.use_audio_scales:
49
+ self.audio_kwargs["audio_scales"] = [None]
50
+
51
+ # variables used in the streaming process
52
+ self.play_steps = play_steps
53
+ if stride is not None:
54
+ self.stride = stride
55
+ else:
56
+ hop_length = math.floor(self.audio_encoder.config.sampling_rate / self.audio_encoder.config.frame_rate)
57
+ self.stride = hop_length * (play_steps - self.decoder.num_codebooks) // 6
58
+ self.token_cache = None
59
+ self.to_yield = 0
60
+
61
+ # varibles used in the thread process
62
+ self.audio_queue = Queue()
63
+ self.stop_signal = None
64
+ self.timeout = timeout
65
+
66
+ def apply_delay_pattern_mask(self, input_ids):
67
+ # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler)
68
+ _, delay_pattern_mask = self.decoder.build_delay_pattern_mask(
69
+ input_ids[:, :1],
70
+ bos_token_id=self.generation_config.bos_token_id,
71
+ pad_token_id=self.generation_config.decoder_start_token_id,
72
+ max_length=input_ids.shape[-1],
73
+ )
74
+ # apply the pattern mask to the input ids
75
+ input_ids = self.decoder.apply_delay_pattern_mask(input_ids, delay_pattern_mask)
76
+
77
+ # revert the pattern delay mask by filtering the pad token id
78
+ mask = (delay_pattern_mask != self.generation_config.bos_token_id) & (delay_pattern_mask != self.generation_config.pad_token_id)
79
+ input_ids = input_ids[mask].reshape(1, self.decoder.num_codebooks, -1)
80
+
81
+ if self.use_4dim_audio_codes:
82
+ # append the frame dimension back to the audio codes
83
+ input_ids = input_ids[None, ...]
84
+
85
+ # send the input_ids to the correct device
86
+ input_ids = input_ids.to(self.audio_encoder.device)
87
+
88
+ decode_sequentially = (
89
+ self.generation_config.bos_token_id in input_ids
90
+ or self.generation_config.pad_token_id in input_ids
91
+ or self.generation_config.eos_token_id in input_ids
92
+ )
93
+ if not decode_sequentially:
94
+ sample = self.audio_encoder.decode(
95
+ audio_codes=input_ids,
96
+ **self.audio_kwargs,
97
+ ).audio_values
98
+ output_values = sample if sample.ndim == 3 else sample.unsqueeze(0)
99
+ else:
100
+ sample = input_ids[:, 0] if self.use_4dim_audio_codes else input_ids[0]
101
+ sample_mask = ((sample >= self.audio_encoder.config.codebook_size).sum(dim=(0, 1)) == 0) if self.use_4dim_audio_codes else ((sample >= self.audio_encoder.config.codebook_size).sum(dim=0) == 0)
102
+ sample = sample[:, :, sample_mask] if self.use_4dim_audio_codes else sample[:, sample_mask]
103
+ sample = self.audio_encoder.decode(audio_codes=sample[None, ...], **self.audio_kwargs).audio_values
104
+ output_values = sample if sample.ndim == 3 else sample.unsqueeze(0)
105
+
106
+ audio_values = output_values[0, 0]
107
+ return audio_values.cpu().float().numpy()
108
+
109
+ def put(self, value):
110
+ batch_size = value.shape[0] // self.decoder.num_codebooks
111
+ if batch_size > 1:
112
+ raise ValueError("ParlerTTSStreamer only supports batch size 1")
113
+
114
+ if self.token_cache is None:
115
+ self.token_cache = value
116
+ else:
117
+ self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1)
118
+
119
+ if self.token_cache.shape[-1] % self.play_steps == 0:
120
+ audio_values = self.apply_delay_pattern_mask(self.token_cache)
121
+ self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
122
+ self.to_yield += len(audio_values) - self.to_yield - self.stride
123
+
124
+ def end(self):
125
+ """Flushes any remaining cache and appends the stop symbol."""
126
+ if self.token_cache is not None:
127
+ audio_values = self.apply_delay_pattern_mask(self.token_cache)
128
+ else:
129
+ audio_values = np.zeros(self.to_yield)
130
+
131
+ self.on_finalized_audio(audio_values[self.to_yield :], stream_end=True)
132
+
133
+ def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
134
+ """Put the new audio in the queue. If the stream is ending, also put a stop signal in the queue."""
135
+ self.audio_queue.put(audio, timeout=self.timeout)
136
+ if stream_end:
137
+ self.audio_queue.put(self.stop_signal, timeout=self.timeout)
138
+
139
+ def __iter__(self):
140
+ return self
141
+
142
+ def __next__(self):
143
+ value = self.audio_queue.get(timeout=self.timeout)
144
+ if not isinstance(value, np.ndarray) and value == self.stop_signal:
145
+ raise StopIteration()
146
+ else:
147
+ return value
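
For reference, this streamer is meant to be driven from a background thread running `model.generate(...)` while the main thread consumes audio chunks as they become ready. A minimal usage sketch, assuming the classes are exposed by this repository's local `parler_tts` package as in the upstream Parler-TTS API (checkpoint name, caption, and prompt text below are placeholders):

```python
from threading import Thread

import torch
from transformers import AutoTokenizer

# assumed import path; adjust to wherever this package is importable from in your setup
from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1").to(device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")

caption = "A calm female voice, close-up and clear."   # style description (placeholder)
text = "Hello, this is a quick streaming test."        # transcript to synthesize (placeholder)

desc_inputs = tokenizer(caption, return_tensors="pt").to(device)
prompt_inputs = tokenizer(text, return_tensors="pt").to(device)

# smaller play_steps -> lower first-chunk latency, but more codec decode calls overall
streamer = ParlerTTSStreamer(model, device=device, play_steps=400)

generation_kwargs = dict(
    input_ids=desc_inputs.input_ids,
    attention_mask=desc_inputs.attention_mask,
    prompt_input_ids=prompt_inputs.input_ids,
    prompt_attention_mask=prompt_inputs.attention_mask,
    streamer=streamer,
    do_sample=True,
)

# generation runs in a background thread; the streamer yields numpy float32 chunks
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
for audio_chunk in streamer:     # chunk sampling rate is model.audio_encoder.config.sampling_rate
    print(audio_chunk.shape)     # e.g. hand off to a playback or file buffer here
thread.join()
```

The `play_steps`/`stride` trade-off described in the class docstring is what controls the chunk size and the overlap between consecutive chunks in this loop.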
capspeech/ar/pretrain.sh ADDED
@@ -0,0 +1,68 @@
1
+ # Please log in to huggingface first
2
+
3
+ MLS_WAV_DIR='' # downloaded mls wav path
4
+ LIBRITTSRMIX_WAV_DIR='' # downloaded librittsrmix wav path
5
+ GIGASPEECH_WAV_DIR='' # downloaded gigaspeech wav path
6
+ COMMONVOICE_WAV_DIR='' # downloaded commonvoice wav path
7
+ EMILIA_WAV_DIR='' # downloaded emilia wav path
8
+ OUTPUT_DIR="./output_pretraining/" # output dir, to save checkpoints
9
+ TEMPORARY_SAVE_TO_DISK="./audio_code_pretraining/" # dac codec saved dir
10
+ SAVE_TO_DISK="./dataset_pretraining/" # huggingface metadata saved dir
11
+ WANDB_KEY='' # your wandb key for logging
12
+
13
+ export CUDA_LAUNCH_BLOCKING=1
14
+ export TORCH_USE_CUDA_DSA=1
15
+
16
+ accelerate launch ./training/run_parler_tts_training.py \
17
+ --model_name_or_path "parler-tts/parler-tts-mini-v1" \
18
+ --feature_extractor_name "parler-tts/dac_44khZ_8kbps" \
19
+ --description_tokenizer_name "google/flan-t5-large" \
20
+ --prompt_tokenizer_name "google/flan-t5-large" \
21
+ --report_to "wandb" \
22
+ --wandb_key ${WANDB_KEY} \
23
+ --overwrite_output_dir true \
24
+ --train_dataset_name "OpenSound/CapSpeech" \
25
+ --train_split_name "train_PT" \
26
+ --eval_dataset_name "OpenSound/CapSpeech" \
27
+ --eval_split_name "validation_PT" \
28
+ --mls_dir ${MLS_WAV_DIR} \
29
+ --librittsrmix_dir ${LIBRITTSRMIX_WAV_DIR} \
30
+ --gigaspeech_dir ${GIGASPEECH_WAV_DIR} \
31
+ --commonvoice_dir ${COMMONVOICE_WAV_DIR} \
32
+ --emilia_dir ${EMILIA_WAV_DIR} \
33
+ --max_eval_samples 96 \
34
+ --per_device_eval_batch_size 32 \
35
+ --target_audio_column_name "audio_path" \
36
+ --description_column_name "caption" \
37
+ --source_column_name "source" \
38
+ --prompt_column_name "text" \
39
+ --max_duration_in_seconds 20 \
40
+ --min_duration_in_seconds 3 \
41
+ --max_text_length 600 \
42
+ --preprocessing_num_workers 32 \
43
+ --do_train true \
44
+ --num_train_epochs 10 \
45
+ --gradient_accumulation_steps 6 \
46
+ --gradient_checkpointing false \
47
+ --per_device_train_batch_size 4 \
48
+ --learning_rate 0.001 \
49
+ --adam_beta1 0.9 \
50
+ --adam_beta2 0.99 \
51
+ --weight_decay 0.01 \
52
+ --lr_scheduler_type "constant_with_warmup" \
53
+ --warmup_steps 5000 \
54
+ --logging_steps 200 \
55
+ --freeze_text_encoder false \
56
+ --per_device_eval_batch_size 4 \
57
+ --audio_encoder_per_device_batch_size 24 \
58
+ --dtype "float16" \
59
+ --seed 456 \
60
+ --output_dir ${OUTPUT_DIR} \
61
+ --temporary_save_to_disk ${TEMPORARY_SAVE_TO_DISK} \
62
+ --save_to_disk ${SAVE_TO_DISK} \
63
+ --dataloader_num_workers 32 \
64
+ --do_eval \
65
+ --evaluation_strategy steps \
66
+ --eval_steps 5000 \
67
+ --save_steps 5000 \
68
+ --group_by_length true
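
The comment at the top of the script assumes you are already authenticated with the Hugging Face Hub and have a W&B key on hand. If you prefer to authenticate programmatically rather than via `huggingface-cli login` and the `WANDB_KEY` variable, a small sketch (token and key values are placeholders):

```python
# optional helper: authenticate once before launching pretrain.sh (values are placeholders)
from huggingface_hub import login
import wandb

login(token="hf_xxx")   # equivalent to running `huggingface-cli login` in the shell
wandb.login(key="xxx")  # the same key can also be passed to the script through WANDB_KEY
```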
capspeech/ar/training/__init__.py ADDED
File without changes
capspeech/ar/training/arguments.py ADDED
@@ -0,0 +1,403 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Optional, List
3
+
4
+ from transformers import Seq2SeqTrainingArguments
5
+
6
+
7
+ @dataclass
8
+ class ModelArguments:
9
+ """
10
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
11
+ """
12
+
13
+ model_name_or_path: str = field(
14
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
15
+ )
16
+ config_name: Optional[str] = field(
17
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
18
+ )
19
+ feature_extractor_name: Optional[str] = field(
20
+ default=None, metadata={"help": "Pretrained feature extractor name or path if not the same as model_name"}
21
+ )
22
+ description_tokenizer_name: Optional[str] = field(
23
+ default=None, metadata={"help": "Pretrained description tokenizer name or path if not the same as model_name"}
24
+ )
25
+ prompt_tokenizer_name: Optional[str] = field(
26
+ default=None,
27
+ metadata={"help": "Pretrained prompt tokenizer name or path if not the same as description_tokenizer_name"},
28
+ )
29
+ cache_dir: Optional[str] = field(
30
+ default=None,
31
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
32
+ )
33
+ use_fast_tokenizer: bool = field(
34
+ default=True,
35
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
36
+ )
37
+ model_revision: str = field(
38
+ default="main",
39
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
40
+ )
41
+ pad_token_id: int = field(
42
+ default=None,
43
+ metadata={"help": "If specified, change the model pad token id."},
44
+ )
45
+ decoder_start_token_id: int = field(
46
+ default=None,
47
+ metadata={"help": "If specified, change the model decoder start token id."},
48
+ )
49
+ freeze_text_encoder: bool = field(
50
+ default=False,
51
+ metadata={"help": "Whether to freeze the text encoder."},
52
+ )
53
+ do_sample: bool = field(
54
+ default=True,
55
+ metadata={"help": "Whether to do sampling or greedy decoding."},
56
+ )
57
+ temperature: float = field(
58
+ default=1.0,
59
+ metadata={"help": "Temperature if sampling."},
60
+ )
61
+ max_length: int = field(
62
+ default=2580,
63
+ metadata={"help": "Generation max length."},
64
+ )
65
+ bandwidth: float = field(
66
+ default=6,
67
+ metadata={"help": "Audio encoder bandwidth."},
68
+ )
69
+ asr_model_name_or_path: str = field(
70
+ default="distil-whisper/distil-large-v2",
71
+ metadata={
72
+ "help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
73
+ },
74
+ )
75
+ clap_model_name_or_path: str = field(
76
+ default="laion/larger_clap_music_and_speech",
77
+ metadata={
78
+ "help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
79
+ },
80
+ )
81
+ attn_implementation: str = field(
82
+ default="eager",
83
+ metadata={"help": "Attention implementation used. One of `eager`, `sdpa`, `flash_attention_2`"},
84
+ )
85
+ cross_attention_implementation_strategy: str = field(
86
+ default=None,
87
+ metadata={
88
+ "help": "If not specified, the cross-attention implementation will be the same as `_attn_implementation`. If `always_eager`, it will always be the eager implementation. If `always_sdpa`, it will always be the sdpa implementation."
89
+ },
90
+ )
91
+ prompt_padding_side: Optional[str] = field(
92
+ default="left",
93
+ metadata={
94
+ "help": "Prompt tokenizer padding side. Defaults to `left`. If the prompt is pre-pended to the codebooks hidden states, it should be padded on the left."
95
+ },
96
+ )
97
+
98
+
99
+ @dataclass
100
+ class DataTrainingArguments:
101
+ """
102
+ Arguments pertaining to what data we are going to input to our model for training and eval.
103
+
104
+ Using `HfArgumentParser` we can turn this class
105
+ into argparse arguments to be able to specify them on
106
+ the command line.
107
+ """
108
+
109
+ train_dataset_name: str = field(
110
+ default=None,
111
+ metadata={
112
+ "help": "The name of the training dataset to use (via the datasets library). Load and combine "
113
+ "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
114
+ " librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
115
+ },
116
+ )
117
+ train_dataset_config_name: Optional[str] = field(
118
+ default=None,
119
+ metadata={
120
+ "help": "The configuration name of the training dataset to use (via the datasets library). Load and combine "
121
+ "multiple datasets by separating dataset configs by a '+' symbol."
122
+ },
123
+ )
124
+ train_split_name: str = field(
125
+ default="train",
126
+ metadata={
127
+ "help": ("The name of the training data set split to use (via the datasets library). Defaults to 'train'")
128
+ },
129
+ )
130
+ train_dataset_samples: str = field(
131
+ default=None,
132
+ metadata={
133
+ "help": "Number of samples in the training data. Load and combine "
134
+ "multiple datasets by separating dataset samples by a '+' symbol."
135
+ },
136
+ )
137
+ train_metadata_dataset_name: str = field(
138
+ default=None,
139
+ metadata={
140
+ "help": "The name of the metadata training dataset to use (via the datasets library). Load and combine "
141
+ "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
142
+ " librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
143
+ },
144
+ )
145
+ eval_dataset_name: str = field(
146
+ default=None,
147
+ metadata={
148
+ "help": "The name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset name if unspecified."
149
+ },
150
+ )
151
+ eval_dataset_config_name: Optional[str] = field(
152
+ default=None,
153
+ metadata={
154
+ "help": "The configuration name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset config name if unspecified"
155
+ },
156
+ )
157
+ eval_split_name: str = field(
158
+ default="test",
159
+ metadata={
160
+ "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'test'"
161
+ },
162
+ )
163
+ eval_metadata_dataset_name: str = field(
164
+ default=None,
165
+ metadata={
166
+ "help": "The name of the metadata training dataset to use (via the datasets library). Load and combine "
167
+ "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
168
+ " librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
169
+ },
170
+ )
171
+ target_audio_column_name: str = field(
172
+ default="audio",
173
+ metadata={"help": "The name of the dataset column containing the target audio data. Defaults to 'audio'"},
174
+ )
175
+ description_column_name: str = field(
176
+ default=None,
177
+ metadata={"help": "The name of the dataset column containing the description text data. Defaults to 'None'."},
178
+ )
179
+ prompt_column_name: str = field(
180
+ default=None,
181
+ metadata={"help": "The name of the dataset column containing the prompt text data. Defaults to 'None'."},
182
+ )
183
+ overwrite_cache: bool = field(
184
+ default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
185
+ )
186
+ preprocessing_num_workers: Optional[int] = field(
187
+ default=None,
188
+ metadata={"help": "The number of processes to use for the preprocessing."},
189
+ )
190
+ max_train_samples: Optional[int] = field(
191
+ default=None,
192
+ metadata={
193
+ "help": (
194
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
195
+ "value if set."
196
+ )
197
+ },
198
+ )
199
+ max_eval_samples: Optional[int] = field(
200
+ default=None,
201
+ metadata={
202
+ "help": (
203
+ "For debugging purposes or quicker training, truncate the number of validation examples to this "
204
+ "value if set."
205
+ )
206
+ },
207
+ )
208
+ max_duration_in_seconds: float = field(
209
+ default=35.0,
210
+ metadata={
211
+ "help": (
212
+ "Filter audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`."
213
+ "Also, used to set maximum audio length if `pad_to_max_length=True`."
214
+ )
215
+ },
216
+ )
217
+ min_duration_in_seconds: float = field(
218
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
219
+ )
220
+ max_text_length: int = field(
221
+ default=500, metadata={"help": "If set, max description lengths in number of characters."}
222
+ )
223
+ max_prompt_token_length: int = field(
224
+ default=None,
225
+ metadata={
226
+ "help": (
227
+ "If set, filter samples with prompts that are longer than `max_prompt_token_length` tokens."
228
+ "Also, used to set maximum prompt token length if `pad_to_max_length=True`."
229
+ )
230
+ },
231
+ )
232
+ max_description_token_length: int = field(
233
+ default=None,
234
+ metadata={
235
+ "help": (
236
+ "If set, filter samples with descriptions that are longer than `max_description_token_length` tokens."
237
+ "Also, used to set maximum description token length if `pad_to_max_length=True`."
238
+ )
239
+ },
240
+ )
241
+ pad_to_max_length: bool = field(
242
+ default=False,
243
+ metadata={
244
+ "help": (
245
+ "If `True`, pad audio, prompt and description to a maximum length set with respectively "
246
+ "`max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`."
247
+ )
248
+ },
249
+ )
250
+ preprocessing_only: bool = field(
251
+ default=False,
252
+ metadata={
253
+ "help": (
254
+ "Whether to only do data preprocessing and skip training. This is especially useful when data"
255
+ " preprocessing errors out in distributed training due to timeout. In this case, one should run the"
256
+ " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
257
+ " can consequently be loaded in distributed training."
258
+ " In this training script, `save_to_disk` must be set to the path in which the dataset should be saved. "
259
+ )
260
+ },
261
+ )
262
+ token: str = field(
263
+ default=None,
264
+ metadata={
265
+ "help": (
266
+ "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
267
+ "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
268
+ )
269
+ },
270
+ )
271
+ use_auth_token: bool = field(
272
+ default=None,
273
+ metadata={
274
+ "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
275
+ },
276
+ )
277
+ trust_remote_code: bool = field(
278
+ default=False,
279
+ metadata={
280
+ "help": (
281
+ "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
282
+ "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
283
+ "execute code present on the Hub on your local machine."
284
+ )
285
+ },
286
+ )
287
+ add_audio_samples_to_wandb: bool = field(
288
+ default=False,
289
+ metadata={"help": "If set and if `wandb` in args.report_to, will add generated audio samples to wandb logs."},
290
+ )
291
+ id_column_name: str = field(default=None, metadata={"help": "id column name."})
292
+ wandb_project: str = field(
293
+ default="parler-speech",
294
+ metadata={"help": "The name of the wandb project."},
295
+ )
296
+ wandb_run_name: str = field(
297
+ default=None,
298
+ metadata={
299
+ "help": "If specified, the name of the run. If not specified, wandb will give a random name to this run."
300
+ },
301
+ )
302
+ save_to_disk: str = field(
303
+ default=None,
304
+ metadata={
305
+ "help": "If set, will save the dataset to this path if this is an empyt folder. If not empty, will load the datasets from it."
306
+ },
307
+ )
308
+ temporary_save_to_disk: str = field(default=None, metadata={"help": "Temporarily save audio labels here."})
309
+ save_codec_steps: Optional[int] = field(
310
+ default=500,
311
+ metadata={"help": "Temporarily save the audio labels every `save_steps`."},
312
+ )
313
+ pad_to_multiple_of: Optional[int] = field(
314
+ default=2,
315
+ metadata={"help": ("Pad to multiple of for tokenizers.")},
316
+ )
317
+ mls_dir: str = field(
318
+ default=None,
319
+ metadata={"help": "mls audio dir"},
320
+ )
321
+ librittsrmix_dir: str = field(
322
+ default=None,
323
+ metadata={"help": "librittsrmix audio dir"},
324
+ )
325
+ gigaspeech_dir: str = field(
326
+ default=None,
327
+ metadata={"help": "gigaspeech audio dir"},
328
+ )
329
+ commonvoice_dir: str = field(
330
+ default=None,
331
+ metadata={"help": "commonvoice audio dir"},
332
+ )
333
+ emilia_dir: str = field(
334
+ default=None,
335
+ metadata={"help": "emilia audio dir"},
336
+ )
337
+ source_column_name: str = field(
338
+ default="source",
339
+ metadata={"help": "The name of the source column."},
340
+ )
341
+ wandb_key: str = field(
342
+ default=None,
343
+ metadata={"help": "wandb key name"},
344
+ )
345
+
346
+
347
+ @dataclass
348
+ class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
349
+ dtype: Optional[str] = field(
350
+ default="float32",
351
+ metadata={
352
+ "help": (
353
+ "The data type (dtype) in which to run training. One of `float32` (full-precision), "
354
+ "`float16` or `bfloat16` (both half-precision)."
355
+ )
356
+ },
357
+ )
358
+ audio_encoder_per_device_batch_size: int = field(
359
+ default=8,
360
+ metadata={"help": ("Specify the batch size of the audio encoding pre-processing steps.")},
361
+ )
362
+ eval_dataloader_num_workers: Optional[int] = field(
363
+ default=0,
364
+ metadata={
365
+ "help": (
366
+ "Number of subprocesses to use for evaluation data loading (PyTorch only). 0 means that the data will be loaded in the main process."
367
+ )
368
+ },
369
+ )
370
+ compute_clap_similarity_metric: bool = field(
371
+ default=True,
372
+ metadata={
373
+ "help": (
374
+ "Whether or not to compute the clap similarity metric between the description and the generation during evalution."
375
+ )
376
+ },
377
+ )
378
+ compute_noise_level_metric: bool = field(
379
+ default=True,
380
+ metadata={"help": ("Whether or not to compute the squim si-sdr measure of the generations.")},
381
+ )
382
+ noise_level_to_compute_clean_wer: float = field(
383
+ default=25,
384
+ metadata={
385
+ "help": (
386
+ "if `compute_noise_level_metric=True`, will compute a 'clean' WER on samples with generated noise higher than `noise_level_to_compute_clean_wer`."
387
+ "This is a proxy measure to compute WER on clean audios, provided that the model learn to generate clean audios."
388
+ )
389
+ },
390
+ )
391
+ eval_generation_steps: Optional[int] = field(
392
+ default=None,
393
+ metadata={
394
+ "help": (
395
+ "Number of update steps between two generation evaluation. Will default to the same"
396
+ "value as `eval_steps` if not set. Should be an integer and a multiple of `eval_steps`."
397
+ )
398
+ },
399
+ )
400
+ codebook_weights: Optional[List[float]] = field(
401
+ default=None,
402
+ metadata={"help": "Weights applied to each codebook."},
403
+ )
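
These three dataclasses form the argument schema for the training entry point; they are typically parsed with `HfArgumentParser`, which is what maps the flags in `pretrain.sh` onto typed fields. A minimal sketch (the bare `arguments` import assumes the script is run from `capspeech/ar/training/`):

```python
# sketch: parsing CLI flags such as those in pretrain.sh into the dataclasses defined above
# run e.g.: python parse_demo.py --model_name_or_path parler-tts/parler-tts-mini-v1 --output_dir ./out
from transformers import HfArgumentParser

from arguments import ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

print(model_args.model_name_or_path)  # e.g. "parler-tts/parler-tts-mini-v1"
print(data_args.train_dataset_name)   # e.g. "OpenSound/CapSpeech"
print(training_args.output_dir)
```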
capspeech/ar/training/arguments_captts.py ADDED
@@ -0,0 +1,391 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Optional, List
3
+
4
+ from transformers import Seq2SeqTrainingArguments
5
+
6
+
7
+ @dataclass
8
+ class ModelArguments:
9
+ """
10
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
11
+ """
12
+
13
+ model_name_or_path: str = field(
14
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
15
+ )
16
+ config_name: Optional[str] = field(
17
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
18
+ )
19
+ feature_extractor_name: Optional[str] = field(
20
+ default=None, metadata={"help": "Pretrained feature extractor name or path if not the same as model_name"}
21
+ )
22
+ description_tokenizer_name: Optional[str] = field(
23
+ default=None, metadata={"help": "Pretrained description tokenizer name or path if not the same as model_name"}
24
+ )
25
+ prompt_tokenizer_name: Optional[str] = field(
26
+ default=None,
27
+ metadata={"help": "Pretrained prompt tokenizer name or path if not the same as description_tokenizer_name"},
28
+ )
29
+ cache_dir: Optional[str] = field(
30
+ default=None,
31
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
32
+ )
33
+ use_fast_tokenizer: bool = field(
34
+ default=True,
35
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
36
+ )
37
+ model_revision: str = field(
38
+ default="main",
39
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
40
+ )
41
+ pad_token_id: int = field(
42
+ default=None,
43
+ metadata={"help": "If specified, change the model pad token id."},
44
+ )
45
+ decoder_start_token_id: int = field(
46
+ default=None,
47
+ metadata={"help": "If specified, change the model decoder start token id."},
48
+ )
49
+ freeze_text_encoder: bool = field(
50
+ default=False,
51
+ metadata={"help": "Whether to freeze the text encoder."},
52
+ )
53
+ do_sample: bool = field(
54
+ default=True,
55
+ metadata={"help": "Whether to do sampling or greedy decoding."},
56
+ )
57
+ temperature: float = field(
58
+ default=1.0,
59
+ metadata={"help": "Temperature if sampling."},
60
+ )
61
+ max_length: int = field(
62
+ default=2580,
63
+ metadata={"help": "Generation max length."},
64
+ )
65
+ bandwidth: float = field(
66
+ default=6,
67
+ metadata={"help": "Audio encoder bandwidth."},
68
+ )
69
+ asr_model_name_or_path: str = field(
70
+ default="distil-whisper/distil-large-v2",
71
+ metadata={
72
+ "help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
73
+ },
74
+ )
75
+ clap_model_name_or_path: str = field(
76
+ default="laion/larger_clap_music_and_speech",
77
+ metadata={
78
+ "help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
79
+ },
80
+ )
81
+ attn_implementation: str = field(
82
+ default="eager",
83
+ metadata={"help": "Attention implementation used. One of `eager`, `sdpa`, `flash_attention_2`"},
84
+ )
85
+ cross_attention_implementation_strategy: str = field(
86
+ default=None,
87
+ metadata={
88
+ "help": "If not specified, the cross-attention implementation will be the same as `_attn_implementation`. If `always_eager`, it will always be the eager implementation. If `always_sdpa`, it will always be the sdpa implementation."
89
+ },
90
+ )
91
+ prompt_padding_side: Optional[str] = field(
92
+ default="left",
93
+ metadata={
94
+ "help": "Prompt tokenizer padding side. Defaults to `left`. If the prompt is pre-pended to the codebooks hidden states, it should be padded on the left."
95
+ },
96
+ )
97
+
98
+
99
+ @dataclass
100
+ class DataTrainingArguments:
101
+ """
102
+ Arguments pertaining to what data we are going to input to our model for training and eval.
103
+
104
+ Using `HfArgumentParser` we can turn this class
105
+ into argparse arguments to be able to specify them on
106
+ the command line.
107
+ """
108
+
109
+ train_dataset_name: str = field(
110
+ default=None,
111
+ metadata={
112
+ "help": "The name of the training dataset to use (via the datasets library). Load and combine "
113
+ "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
114
+ " librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
115
+ },
116
+ )
117
+ train_dataset_config_name: Optional[str] = field(
118
+ default=None,
119
+ metadata={
120
+ "help": "The configuration name of the training dataset to use (via the datasets library). Load and combine "
121
+ "multiple datasets by separating dataset configs by a '+' symbol."
122
+ },
123
+ )
124
+ train_split_name: str = field(
125
+ default="train",
126
+ metadata={
127
+ "help": ("The name of the training data set split to use (via the datasets library). Defaults to 'train'")
128
+ },
129
+ )
130
+ train_dataset_samples: str = field(
131
+ default=None,
132
+ metadata={
133
+ "help": "Number of samples in the training data. Load and combine "
134
+ "multiple datasets by separating dataset samples by a '+' symbol."
135
+ },
136
+ )
137
+ train_metadata_dataset_name: str = field(
138
+ default=None,
139
+ metadata={
140
+ "help": "The name of the metadata training dataset to use (via the datasets library). Load and combine "
141
+ "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
142
+ " librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
143
+ },
144
+ )
145
+ eval_dataset_name: str = field(
146
+ default=None,
147
+ metadata={
148
+ "help": "The name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset name if unspecified."
149
+ },
150
+ )
151
+ eval_dataset_config_name: Optional[str] = field(
152
+ default=None,
153
+ metadata={
154
+ "help": "The configuration name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset config name if unspecified"
155
+ },
156
+ )
157
+ eval_split_name: str = field(
158
+ default="test",
159
+ metadata={
160
+ "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'test'"
161
+ },
162
+ )
163
+ eval_metadata_dataset_name: str = field(
164
+ default=None,
165
+ metadata={
166
+ "help": "The name of the metadata training dataset to use (via the datasets library). Load and combine "
167
+ "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
168
+ " librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
169
+ },
170
+ )
171
+ target_audio_column_name: str = field(
172
+ default="audio",
173
+ metadata={"help": "The name of the dataset column containing the target audio data. Defaults to 'audio'"},
174
+ )
175
+ description_column_name: str = field(
176
+ default=None,
177
+ metadata={"help": "The name of the dataset column containing the description text data. Defaults to 'None'."},
178
+ )
179
+ prompt_column_name: str = field(
180
+ default=None,
181
+ metadata={"help": "The name of the dataset column containing the prompt text data. Defaults to 'None'."},
182
+ )
183
+ overwrite_cache: bool = field(
184
+ default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
185
+ )
186
+ preprocessing_num_workers: Optional[int] = field(
187
+ default=None,
188
+ metadata={"help": "The number of processes to use for the preprocessing."},
189
+ )
190
+ max_train_samples: Optional[int] = field(
191
+ default=None,
192
+ metadata={
193
+ "help": (
194
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
195
+ "value if set."
196
+ )
197
+ },
198
+ )
199
+ max_eval_samples: Optional[int] = field(
200
+ default=None,
201
+ metadata={
202
+ "help": (
203
+ "For debugging purposes or quicker training, truncate the number of validation examples to this "
204
+ "value if set."
205
+ )
206
+ },
207
+ )
208
+ max_duration_in_seconds: float = field(
209
+ default=35.0,
210
+ metadata={
211
+ "help": (
212
+ "Filter audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`."
213
+ "Also, used to set maximum audio length if `pad_to_max_length=True`."
214
+ )
215
+ },
216
+ )
217
+ min_duration_in_seconds: float = field(
218
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
219
+ )
220
+ max_text_length: int = field(
221
+ default=500, metadata={"help": "If set, max description lengths in number of characters."}
222
+ )
223
+ max_prompt_token_length: int = field(
224
+ default=None,
225
+ metadata={
226
+ "help": (
227
+ "If set, filter samples with prompts that are longer than `max_prompt_token_length` tokens."
228
+ "Also, used to set maximum prompt token length if `pad_to_max_length=True`."
229
+ )
230
+ },
231
+ )
232
+ max_description_token_length: int = field(
233
+ default=None,
234
+ metadata={
235
+ "help": (
236
+ "If set, filter samples with descriptions that are longer than `max_description_token_length` tokens."
237
+ "Also, used to set maximum description token length if `pad_to_max_length=True`."
238
+ )
239
+ },
240
+ )
241
+ pad_to_max_length: bool = field(
242
+ default=False,
243
+ metadata={
244
+ "help": (
245
+ "If `True`, pad audio, prompt and description to a maximum length set with respectively "
246
+ "`max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`."
247
+ )
248
+ },
249
+ )
250
+ preprocessing_only: bool = field(
251
+ default=False,
252
+ metadata={
253
+ "help": (
254
+ "Whether to only do data preprocessing and skip training. This is especially useful when data"
255
+ " preprocessing errors out in distributed training due to timeout. In this case, one should run the"
256
+ " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
257
+ " can consequently be loaded in distributed training."
258
+ " In this training script, `save_to_disk` must be set to the path in which the dataset should be saved. "
259
+ )
260
+ },
261
+ )
262
+ token: str = field(
263
+ default=None,
264
+ metadata={
265
+ "help": (
266
+ "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
267
+ "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
268
+ )
269
+ },
270
+ )
271
+ use_auth_token: bool = field(
272
+ default=None,
273
+ metadata={
274
+ "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
275
+ },
276
+ )
277
+ trust_remote_code: bool = field(
278
+ default=False,
279
+ metadata={
280
+ "help": (
281
+ "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
282
+ "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
283
+ "execute code present on the Hub on your local machine."
284
+ )
285
+ },
286
+ )
287
+ add_audio_samples_to_wandb: bool = field(
288
+ default=False,
289
+ metadata={"help": "If set and if `wandb` in args.report_to, will add generated audio samples to wandb logs."},
290
+ )
291
+ id_column_name: str = field(default=None, metadata={"help": "id column name."})
292
+ wandb_project: str = field(
293
+ default="parler-speech",
294
+ metadata={"help": "The name of the wandb project."},
295
+ )
296
+ wandb_run_name: str = field(
297
+ default=None,
298
+ metadata={
299
+ "help": "If specified, the name of the run. If not specified, wandb will give a random name to this run."
300
+ },
301
+ )
302
+ save_to_disk: str = field(
303
+ default=None,
304
+ metadata={
305
+ "help": "If set, will save the dataset to this path if this is an empyt folder. If not empty, will load the datasets from it."
306
+ },
307
+ )
308
+ temporary_save_to_disk: str = field(default=None, metadata={"help": "Temporarily save audio labels here."})
309
+ save_codec_steps: Optional[int] = field(
310
+ default=500,
311
+ metadata={"help": "Temporarily save the audio labels every `save_steps`."},
312
+ )
313
+ pad_to_multiple_of: Optional[int] = field(
314
+ default=2,
315
+ metadata={"help": ("Pad to multiple of for tokenizers.")},
316
+ )
317
+ librittsr_dir: str = field(
318
+ default=None,
319
+ metadata={"help": "librittsr audio dir"},
320
+ )
321
+ other_dir: str = field(
322
+ default=None,
323
+ metadata={"help": "other audio dir"},
324
+ )
325
+ source_column_name: str = field(
326
+ default="source",
327
+ metadata={"help": "The name of the source column."},
328
+ )
329
+ wandb_key: str = field(
330
+ default=None,
331
+ metadata={"help": "wandb key name"},
332
+ )
333
+
334
+
335
+ @dataclass
336
+ class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
337
+ dtype: Optional[str] = field(
338
+ default="float32",
339
+ metadata={
340
+ "help": (
341
+ "The data type (dtype) in which to run training. One of `float32` (full-precision), "
342
+ "`float16` or `bfloat16` (both half-precision)."
343
+ )
344
+ },
345
+ )
346
+ audio_encoder_per_device_batch_size: int = field(
347
+ default=8,
348
+ metadata={"help": ("Specify the batch size of the audio encoding pre-processing steps.")},
349
+ )
350
+ eval_dataloader_num_workers: Optional[int] = field(
351
+ default=0,
352
+ metadata={
353
+ "help": (
354
+ "Number of subprocesses to use for evaluation data loading (PyTorch only). 0 means that the data will be loaded in the main process."
355
+ )
356
+ },
357
+ )
358
+ compute_clap_similarity_metric: bool = field(
359
+ default=True,
360
+ metadata={
361
+ "help": (
362
+ "Whether or not to compute the clap similarity metric between the description and the generation during evalution."
363
+ )
364
+ },
365
+ )
366
+ compute_noise_level_metric: bool = field(
367
+ default=True,
368
+ metadata={"help": ("Whether or not to compute the squim si-sdr measure of the generations.")},
369
+ )
370
+ noise_level_to_compute_clean_wer: float = field(
371
+ default=25,
372
+ metadata={
373
+ "help": (
374
+ "if `compute_noise_level_metric=True`, will compute a 'clean' WER on samples with generated noise higher than `noise_level_to_compute_clean_wer`."
375
+ "This is a proxy measure to compute WER on clean audios, provided that the model learn to generate clean audios."
376
+ )
377
+ },
378
+ )
379
+ eval_generation_steps: Optional[int] = field(
380
+ default=None,
381
+ metadata={
382
+ "help": (
383
+ "Number of update steps between two generation evaluation. Will default to the same"
384
+ "value as `eval_steps` if not set. Should be an integer and a multiple of `eval_steps`."
385
+ )
386
+ },
387
+ )
388
+ codebook_weights: Optional[List[float]] = field(
389
+ default=None,
390
+ metadata={"help": "Weights applied to each codebook."},
391
+ )
capspeech/ar/training/arguments_capttsse.py ADDED
@@ -0,0 +1,387 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Optional, List
3
+
4
+ from transformers import Seq2SeqTrainingArguments
5
+
6
+
7
+ @dataclass
8
+ class ModelArguments:
9
+ """
10
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
11
+ """
12
+
13
+ model_name_or_path: str = field(
14
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
15
+ )
16
+ config_name: Optional[str] = field(
17
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
18
+ )
19
+ feature_extractor_name: Optional[str] = field(
20
+ default=None, metadata={"help": "Pretrained feature extractor name or path if not the same as model_name"}
21
+ )
22
+ description_tokenizer_name: Optional[str] = field(
23
+ default=None, metadata={"help": "Pretrained description tokenizer name or path if not the same as model_name"}
24
+ )
25
+ prompt_tokenizer_name: Optional[str] = field(
26
+ default=None,
27
+ metadata={"help": "Pretrained prompt tokenizer name or path if not the same as description_tokenizer_name"},
28
+ )
29
+ cache_dir: Optional[str] = field(
30
+ default=None,
31
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
32
+ )
33
+ use_fast_tokenizer: bool = field(
34
+ default=True,
35
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
36
+ )
37
+ model_revision: str = field(
38
+ default="main",
39
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
40
+ )
41
+ pad_token_id: int = field(
42
+ default=None,
43
+ metadata={"help": "If specified, change the model pad token id."},
44
+ )
45
+ decoder_start_token_id: int = field(
46
+ default=None,
47
+ metadata={"help": "If specified, change the model decoder start token id."},
48
+ )
49
+ freeze_text_encoder: bool = field(
50
+ default=False,
51
+ metadata={"help": "Whether to freeze the text encoder."},
52
+ )
53
+ do_sample: bool = field(
54
+ default=True,
55
+ metadata={"help": "Whether to do sampling or greedy decoding."},
56
+ )
57
+ temperature: float = field(
58
+ default=1.0,
59
+ metadata={"help": "Temperature if sampling."},
60
+ )
61
+ max_length: int = field(
62
+ default=2580,
63
+ metadata={"help": "Generation max length."},
64
+ )
65
+ bandwidth: float = field(
66
+ default=6,
67
+ metadata={"help": "Audio encoder bandwidth."},
68
+ )
69
+ asr_model_name_or_path: str = field(
70
+ default="distil-whisper/distil-large-v2",
71
+ metadata={
72
+ "help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
73
+ },
74
+ )
75
+ clap_model_name_or_path: str = field(
76
+ default="laion/larger_clap_music_and_speech",
77
+ metadata={
78
+ "help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
79
+ },
80
+ )
81
+ attn_implementation: str = field(
82
+ default="eager",
83
+ metadata={"help": "Attention implementation used. One of `eager`, `sdpa`, `flash_attention_2`"},
84
+ )
85
+ cross_attention_implementation_strategy: str = field(
86
+ default=None,
87
+ metadata={
88
+ "help": "If not specified, the cross-attention implementation will be the same as `_attn_implementation`. If `always_eager`, it will always be the eager implementation. If `always_sdpa`, it will always be the sdpa implementation."
89
+ },
90
+ )
91
+ prompt_padding_side: Optional[str] = field(
92
+ default="left",
93
+ metadata={
94
+ "help": "Prompt tokenizer padding side. Defaults to `left`. If the prompt is pre-pended to the codebooks hidden states, it should be padded on the left."
95
+ },
96
+ )
97
+
98
+
99
+ @dataclass
100
+ class DataTrainingArguments:
101
+ """
102
+ Arguments pertaining to what data we are going to input to our model for training and eval.
103
+
104
+ Using `HfArgumentParser` we can turn this class
105
+ into argparse arguments to be able to specify them on
106
+ the command line.
107
+ """
108
+
109
+ train_dataset_name: str = field(
110
+ default=None,
111
+ metadata={
112
+ "help": "The name of the training dataset to use (via the datasets library). Load and combine "
113
+ "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
114
+ " librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
115
+ },
116
+ )
117
+ train_dataset_config_name: Optional[str] = field(
118
+ default=None,
119
+ metadata={
120
+ "help": "The configuration name of the training dataset to use (via the datasets library). Load and combine "
121
+ "multiple datasets by separating dataset configs by a '+' symbol."
122
+ },
123
+ )
124
+ train_split_name: str = field(
125
+ default="train",
126
+ metadata={
127
+ "help": ("The name of the training data set split to use (via the datasets library). Defaults to 'train'")
128
+ },
129
+ )
130
+ train_dataset_samples: str = field(
131
+ default=None,
132
+ metadata={
133
+ "help": "Number of samples in the training data. Load and combine "
134
+ "multiple datasets by separating dataset samples by a '+' symbol."
135
+ },
136
+ )
137
+ train_metadata_dataset_name: str = field(
138
+ default=None,
139
+ metadata={
140
+ "help": "The name of the metadata training dataset to use (via the datasets library). Load and combine "
141
+ "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
142
+ " librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
143
+ },
144
+ )
145
+ eval_dataset_name: str = field(
146
+ default=None,
147
+ metadata={
148
+ "help": "The name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset name if unspecified."
149
+ },
150
+ )
151
+ eval_dataset_config_name: Optional[str] = field(
152
+ default=None,
153
+ metadata={
154
+ "help": "The configuration name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset config name if unspecified"
155
+ },
156
+ )
157
+ eval_split_name: str = field(
158
+ default="test",
159
+ metadata={
160
+ "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'test'"
161
+ },
162
+ )
163
+ eval_metadata_dataset_name: str = field(
164
+ default=None,
165
+ metadata={
166
+ "help": "The name of the metadata training dataset to use (via the datasets library). Load and combine "
167
+ "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
168
+ " librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
169
+ },
170
+ )
171
+ target_audio_column_name: str = field(
172
+ default="audio",
173
+ metadata={"help": "The name of the dataset column containing the target audio data. Defaults to 'audio'"},
174
+ )
175
+ description_column_name: str = field(
176
+ default=None,
177
+ metadata={"help": "The name of the dataset column containing the description text data. Defaults to 'None'."},
178
+ )
179
+ prompt_column_name: str = field(
180
+ default=None,
181
+ metadata={"help": "The name of the dataset column containing the prompt text data. Defaults to 'None'."},
182
+ )
183
+ overwrite_cache: bool = field(
184
+ default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
185
+ )
186
+ preprocessing_num_workers: Optional[int] = field(
187
+ default=None,
188
+ metadata={"help": "The number of processes to use for the preprocessing."},
189
+ )
190
+ max_train_samples: Optional[int] = field(
191
+ default=None,
192
+ metadata={
193
+ "help": (
194
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
195
+ "value if set."
196
+ )
197
+ },
198
+ )
199
+ max_eval_samples: Optional[int] = field(
200
+ default=None,
201
+ metadata={
202
+ "help": (
203
+ "For debugging purposes or quicker training, truncate the number of validation examples to this "
204
+ "value if set."
205
+ )
206
+ },
207
+ )
208
+ max_duration_in_seconds: float = field(
209
+ default=35.0,
210
+ metadata={
211
+ "help": (
212
+ "Filter audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`."
213
+ "Also, used to set maximum audio length if `pad_to_max_length=True`."
214
+ )
215
+ },
216
+ )
217
+ min_duration_in_seconds: float = field(
218
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
219
+ )
220
+ max_text_length: int = field(
221
+ default=500, metadata={"help": "If set, max description lengths in number of characters."}
222
+ )
223
+ max_prompt_token_length: int = field(
224
+ default=None,
225
+ metadata={
226
+ "help": (
227
+ "If set, filter samples with prompts that are longer than `max_prompt_token_length` tokens."
228
+ "Also, used to set maximum prompt token length if `pad_to_max_length=True`."
229
+ )
230
+ },
231
+ )
232
+ max_description_token_length: int = field(
233
+ default=None,
234
+ metadata={
235
+ "help": (
236
+ "If set, filter samples with descriptions that are longer than `max_description_token_length` tokens."
237
+ "Also, used to set maximum description token length if `pad_to_max_length=True`."
238
+ )
239
+ },
240
+ )
241
+ pad_to_max_length: bool = field(
242
+ default=False,
243
+ metadata={
244
+ "help": (
245
+ "If `True`, pad audio, prompt and description to a maximum length set with respectively "
246
+ "`max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`."
247
+ )
248
+ },
249
+ )
250
+ preprocessing_only: bool = field(
251
+ default=False,
252
+ metadata={
253
+ "help": (
254
+ "Whether to only do data preprocessing and skip training. This is especially useful when data"
255
+ " preprocessing errors out in distributed training due to timeout. In this case, one should run the"
256
+ " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
257
+ " can consequently be loaded in distributed training."
258
+ " In this training script, `save_to_disk` must be set to the path in which the dataset should be saved. "
259
+ )
260
+ },
261
+ )
262
+ token: str = field(
263
+ default=None,
264
+ metadata={
265
+ "help": (
266
+ "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
267
+ "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
268
+ )
269
+ },
270
+ )
271
+ use_auth_token: bool = field(
272
+ default=None,
273
+ metadata={
274
+ "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
275
+ },
276
+ )
277
+ trust_remote_code: bool = field(
278
+ default=False,
279
+ metadata={
280
+ "help": (
281
+ "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
282
+ "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
283
+ "execute code present on the Hub on your local machine."
284
+ )
285
+ },
286
+ )
287
+ add_audio_samples_to_wandb: bool = field(
288
+ default=False,
289
+ metadata={"help": "If set and if `wandb` in args.report_to, will add generated audio samples to wandb logs."},
290
+ )
291
+ id_column_name: str = field(default=None, metadata={"help": "id column name."})
292
+ wandb_project: str = field(
293
+ default="parler-speech",
294
+ metadata={"help": "The name of the wandb project."},
295
+ )
296
+ wandb_run_name: str = field(
297
+ default=None,
298
+ metadata={
299
+ "help": "If specified, the name of the run. If not specified, wandb will give a random name to this run."
300
+ },
301
+ )
302
+ save_to_disk: str = field(
303
+ default=None,
304
+ metadata={
305
+ "help": "If set, will save the dataset to this path if this is an empyt folder. If not empty, will load the datasets from it."
306
+ },
307
+ )
308
+ temporary_save_to_disk: str = field(default=None, metadata={"help": "Temporarily save audio labels here."})
309
+ save_codec_steps: Optional[int] = field(
310
+ default=500,
311
+ metadata={"help": "Temporarily save the audio labels every `save_steps`."},
312
+ )
313
+ pad_to_multiple_of: Optional[int] = field(
314
+ default=2,
315
+ metadata={"help": ("Pad to multiple of for tokenizers.")},
316
+ )
317
+ librittsrmix_dir: str = field(
318
+ default=None,
319
+ metadata={"help": "librittsrmix audio dir"},
320
+ )
321
+ source_column_name: str = field(
322
+ default="source",
323
+ metadata={"help": "The name of the source column."},
324
+ )
325
+ wandb_key: str = field(
326
+ default=None,
327
+ metadata={"help": "wandb key name"},
328
+ )
329
+
330
+
331
+ @dataclass
332
+ class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
333
+ dtype: Optional[str] = field(
334
+ default="float32",
335
+ metadata={
336
+ "help": (
337
+ "The data type (dtype) in which to run training. One of `float32` (full-precision), "
338
+ "`float16` or `bfloat16` (both half-precision)."
339
+ )
340
+ },
341
+ )
342
+ audio_encoder_per_device_batch_size: int = field(
343
+ default=8,
344
+ metadata={"help": ("Specify the batch size of the audio encoding pre-processing steps.")},
345
+ )
346
+ eval_dataloader_num_workers: Optional[int] = field(
347
+ default=0,
348
+ metadata={
349
+ "help": (
350
+ "Number of subprocesses to use for evaluation data loading (PyTorch only). 0 means that the data will be loaded in the main process."
351
+ )
352
+ },
353
+ )
354
+ compute_clap_similarity_metric: bool = field(
355
+ default=True,
356
+ metadata={
357
+ "help": (
358
+ "Whether or not to compute the clap similarity metric between the description and the generation during evalution."
359
+ )
360
+ },
361
+ )
362
+ compute_noise_level_metric: bool = field(
363
+ default=True,
364
+ metadata={"help": ("Whether or not to compute the squim si-sdr measure of the generations.")},
365
+ )
366
+ noise_level_to_compute_clean_wer: float = field(
367
+ default=25,
368
+ metadata={
369
+ "help": (
370
+ "if `compute_noise_level_metric=True`, will compute a 'clean' WER on samples with generated noise higher than `noise_level_to_compute_clean_wer`."
371
+ "This is a proxy measure to compute WER on clean audios, provided that the model learn to generate clean audios."
372
+ )
373
+ },
374
+ )
375
+ eval_generation_steps: Optional[int] = field(
376
+ default=None,
377
+ metadata={
378
+ "help": (
379
+ "Number of update steps between two generation evaluation. Will default to the same"
380
+ "value as `eval_steps` if not set. Should be an integer and a multiple of `eval_steps`."
381
+ )
382
+ },
383
+ )
384
+ codebook_weights: Optional[List[float]] = field(
385
+ default=None,
386
+ metadata={"help": "Weights applied to each codebook."},
387
+ )
capspeech/ar/training/data.py ADDED
@@ -0,0 +1,277 @@
1
+ import logging
2
+ from dataclasses import dataclass
3
+ from typing import Dict, List, Optional, Set, Union
4
+ import os
5
+ import datasets
6
+ import numpy as np
7
+ import torch
8
+ from accelerate import Accelerator
9
+ from datasets import Dataset, IterableDataset, concatenate_datasets, interleave_datasets, load_dataset
10
+ from tqdm import tqdm
11
+ from transformers import AutoFeatureExtractor, AutoTokenizer
12
+ import torchaudio
13
+ import torchaudio.transforms as T
14
+
15
+ @dataclass
16
+ class DataCollatorEncodecWithPadding:
17
+ """
18
+ Data collator that will dynamically pad the inputs received to the longest sequence in the batch or
19
+ to `max_length` if `max_length` is set and `padding=max_length`.
20
+ """
21
+
22
+ feature_extractor: AutoFeatureExtractor
23
+ audio_column_name: str
24
+ mls_dir: Optional[str] = None
25
+ librittsrmix_dir: Optional[str] = None
26
+ gigaspeech_dir: Optional[str] = None
27
+ commonvoice_dir: Optional[str] = None
28
+ emilia_dir: Optional[str] = None
29
+ feature_extractor_input_name: Optional[str] = "input_values"
30
+ max_length: Optional[int] = None
31
+ padding: Optional[str] = "longest"
32
+
33
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
34
+ # split inputs and labels since they have to be of different lengths and need
35
+ # different padding methods
36
+ sampling_rate = self.feature_extractor.sampling_rate
37
+ # load audio
38
+ audios = []
39
+ for f in features:
40
+ path = f[self.audio_column_name]
41
+ source = f["source"]
42
+ if source == "libritts-r":
43
+ path = os.path.join(self.librittsrmix_dir, path)
44
+ elif source == "mls":
45
+ path = os.path.join(self.mls_dir, path)
46
+ elif source == "gigaspeech":
47
+ path = os.path.join(self.gigaspeech_dir, path)
48
+ elif source == "commonvoice":
49
+ path = os.path.join(self.commonvoice_dir, path)
50
+ elif source == "emilia":
51
+ path = os.path.join(self.emilia_dir, path)
52
+ else:
53
+ raise ValueError(source)
54
+
55
+ if os.path.exists(path):
56
+ waveform, sr = torchaudio.load(path)
57
+ if sr != sampling_rate:
58
+ resampler = T.Resample(orig_freq=sr, new_freq=sampling_rate)
59
+ waveform = resampler(waveform)
60
+ if waveform.shape[0] > 1:
61
+ waveform = waveform.mean(dim=0, keepdim=True)
62
+ audios.append(waveform.squeeze())
63
+ else:
64
+ print(f"Read error: {path}")
65
+
66
+
67
+ len_audio = [len(audio) for audio in audios]
68
+ if self.max_length is not None:
69
+ audios = [audio[: min(l, self.max_length)] for audio, l in zip(audios, len_audio)]
70
+
71
+ # the waveforms loaded above have already been resampled to the feature extractor's sampling rate,
72
+ # so a fixed sampling_rate (44100 Hz for the DAC feature extractor) is passed to the feature_extractor.
73
+ batch = self.feature_extractor(
74
+ [np.asarray(a, dtype=np.float32) for a in audios], sampling_rate=sampling_rate, return_tensors="pt", padding=self.padding, max_length=self.max_length
75
+ )
76
+ batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
77
+ return batch
78
+
79
+
80
+ @dataclass
81
+ class DataCollatorParlerTTSWithPadding:
82
+ """
83
+ Data collator that will dynamically pad the inputs received.
84
+ Args:
85
+ prompt_tokenizer (:class:`~transformers.AutoTokenizer`)
86
+ The prompt_tokenizer used for processing the data.
87
+ description_tokenizer (:class:`~transformers.AutoTokenizer`)
88
+ The description_tokenizer used for processing the data.
89
+ padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
90
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
91
+ among:
92
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
93
+ sequence is provided).
94
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
95
+ maximum acceptable input length for the model if that argument is not provided.
96
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
97
+ different lengths).
98
+ pad_to_multiple_of (:obj:`int`, `optional`):
99
+ If set will pad the sequence to a multiple of the provided value.
100
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
101
+ 7.5 (Volta).
102
+ """
103
+
104
+ prompt_tokenizer: AutoTokenizer
105
+ description_tokenizer: AutoTokenizer
106
+ padding: Union[bool, str] = "longest"
107
+ pad_to_multiple_of: Optional[int] = None
108
+ prompt_max_length: Optional[int] = None
109
+ description_max_length: Optional[int] = None
110
+ audio_max_length: Optional[int] = None
111
+
112
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
113
+ # split inputs and labels since they have to be of different lengths and need
114
+ # different padding methods
115
+
116
+ labels = [torch.tensor(feature["labels"]).transpose(0, 1) for feature in features]
117
+ # (bsz, seq_len, num_codebooks)
118
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
119
+ if self.audio_max_length is not None and self.padding == "max_length":
120
+ labels = torch.nn.functional.pad(
121
+ labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0)), value=-100
122
+ )
123
+
124
+ input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
125
+
126
+ input_ids = self.description_tokenizer.pad(
127
+ input_ids,
128
+ return_tensors="pt",
129
+ padding=self.padding,
130
+ pad_to_multiple_of=self.pad_to_multiple_of,
131
+ max_length=self.description_max_length,
132
+ )
133
+
134
+ batch = {"labels": labels, **input_ids}
135
+
136
+ prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
137
+ prompt_input_ids = self.prompt_tokenizer.pad(
138
+ prompt_input_ids,
139
+ return_tensors="pt",
140
+ padding=self.padding,
141
+ pad_to_multiple_of=self.pad_to_multiple_of,
142
+ max_length=self.prompt_max_length,
143
+ )
144
+
145
+ batch["prompt_input_ids"] = prompt_input_ids["input_ids"]
146
+ if "attention_mask" in prompt_input_ids:
147
+ batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]
148
+
149
+ return batch
150
+
151
+
152
+ def convert_dataset_str_to_list(
153
+ dataset_names,
154
+ splits=None,
155
+ dataset_samples=None,
156
+ default_split="train",
157
+ ):
158
+ if isinstance(dataset_names, str):
159
+ dataset_names = dataset_names.split("+")
160
+ splits = splits.split("+") if splits is not None else None
161
+ dataset_samples = dataset_samples.split("+") if dataset_samples is not None else None
162
+
163
+ if splits is not None and len(splits) != len(dataset_names):
164
+ raise ValueError(
165
+ f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
166
+ )
167
+
168
+ if dataset_samples is not None:
169
+ if len(dataset_samples) != len(dataset_names):
170
+ raise ValueError(
171
+ f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and "
172
+ f"{len(dataset_samples)} samples."
173
+ )
174
+ dataset_samples = [float(ds_sample) for ds_sample in dataset_samples]
175
+ else:
176
+ dataset_samples = [None] * len(dataset_names)
177
+
178
+ splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]
179
+
180
+ dataset_names_dict = []
181
+ for i, ds_name in enumerate(dataset_names):
182
+ dataset_names_dict.append(
183
+ {
184
+ "name": ds_name,
185
+ "split": splits[i],
186
+ "samples": dataset_samples[i],
187
+ }
188
+ )
189
+ return dataset_names_dict
190
+
191
+
192
+ def load_multiple_datasets(
193
+ accelerator: Accelerator,
194
+ dataset_names: Union[List, str],
195
+ splits: Optional[Union[List, str]] = None,
196
+ label_column_names: Optional[List] = None,
197
+ stopping_strategy: Optional[str] = "first_exhausted",
198
+ dataset_samples: Optional[Union[List, np.array]] = None,
199
+ streaming: Optional[bool] = False,
200
+ seed: Optional[int] = None,
201
+ id_column_name: Optional[str] = None,
202
+ columns_to_keep: Optional[Set[str]] = None,
203
+ prompt_column_name: Optional[str] = None,
204
+ sampling_rate: Optional[int] = None,
205
+ audio_column_name: Optional[str] = None,
206
+ logger: Optional[logging.Logger] = None,
207
+ librittsrmix_dir: Optional[Union[List, str]] = None,
208
+ mls_dir: Optional[Union[List, str]] = None,
209
+ gigaspeech_dir: Optional[Union[List, str]] = None,
210
+ commonvoice_dir: Optional[Union[List, str]] = None,
211
+ emilia_dir: Optional[Union[List, str]] = None,
212
+ **kwargs,
213
+ ) -> Union[Dataset, IterableDataset]:
214
+ dataset_names_dict = convert_dataset_str_to_list(
215
+ dataset_names, splits, dataset_samples=dataset_samples
216
+ )
217
+
218
+ if dataset_samples is not None:
219
+ dataset_samples = [ds_dict["samples"] for ds_dict in dataset_names_dict]
220
+ probabilities = np.array(dataset_samples) / np.sum(dataset_samples)
221
+ else:
222
+ probabilities = None
223
+
224
+ all_datasets = []
225
+ # iterate over the datasets we want to interleave
226
+ for dataset_dict in tqdm(dataset_names_dict, desc="Combining datasets..."):
227
+ with accelerator.local_main_process_first():
228
+ dataset = load_dataset(
229
+ dataset_dict["name"],
230
+ split=dataset_dict["split"],
231
+ streaming=streaming,
232
+ **kwargs,
233
+ )
234
+ dataset_features = dataset.features.keys()
235
+
236
+ if columns_to_keep is not None:
237
+ dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
238
+
239
+ def resolve_path(example):
240
+ path = example["audio_path"]
241
+ source = example["source"]
242
+
243
+ if source == "libritts-r":
244
+ full_path = os.path.join(librittsrmix_dir, path)
245
+ elif source == "mls":
246
+ full_path = os.path.join(mls_dir, path)
247
+ elif source == "gigaspeech":
248
+ full_path = os.path.join(gigaspeech_dir, path)
249
+ elif source == "commonvoice":
250
+ full_path = os.path.join(commonvoice_dir, path)
251
+ elif source == "emilia":
252
+ full_path = os.path.join(emilia_dir, path)
253
+ else:
254
+ return False # unknown source
255
+
256
+ return os.path.exists(full_path)
257
+
258
+ dataset = dataset.filter(resolve_path, num_proc=16)
259
+
260
+ all_datasets.append(dataset)
261
+
262
+ if len(all_datasets) == 1:
263
+ # we have a single dataset so just return it as is
264
+ return all_datasets[0]
265
+
266
+ if streaming:
267
+ interleaved_dataset = interleave_datasets(
268
+ all_datasets,
269
+ stopping_strategy=stopping_strategy,
270
+ probabilities=probabilities,
271
+ seed=seed,
272
+ )
273
+ else:
274
+ with accelerator.local_main_process_first():
275
+ interleaved_dataset = concatenate_datasets(all_datasets)
276
+
277
+ return interleaved_dataset
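
For readers skimming the diff, the sketch below shows one plausible way the helpers in `data.py` fit together. It is an illustration only: the codec checkpoint, dataset id and local audio roots are placeholders, not values used by the project, and it assumes the split only references the two roots given here.

```python
# Hypothetical wiring of the helpers defined in capspeech/ar/training/data.py.
from accelerate import Accelerator
from torch.utils.data import DataLoader
from transformers import AutoFeatureExtractor

from training.data import DataCollatorEncodecWithPadding, load_multiple_datasets

accelerator = Accelerator()
feature_extractor = AutoFeatureExtractor.from_pretrained("org/audio-codec")  # placeholder checkpoint

train_set = load_multiple_datasets(
    accelerator,
    dataset_names="org/capspeech-pretrain",   # placeholder HF dataset id
    splits="train",
    audio_column_name="audio_path",
    librittsrmix_dir="/data/librittsr-mix",   # placeholder local roots
    gigaspeech_dir="/data/gigaspeech",
)

collator = DataCollatorEncodecWithPadding(
    feature_extractor=feature_extractor,
    audio_column_name="audio_path",
    librittsrmix_dir="/data/librittsr-mix",
    gigaspeech_dir="/data/gigaspeech",
    max_length=30 * feature_extractor.sampling_rate,  # cap clips at ~30 s
)
loader = DataLoader(train_set, batch_size=8, collate_fn=collator)
# each batch holds padded "input_values" plus "len_audio" for the codec encoder
```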
capspeech/ar/training/data_captts.py ADDED
@@ -0,0 +1,255 @@
1
+ import logging
2
+ from dataclasses import dataclass
3
+ from typing import Dict, List, Optional, Set, Union
4
+ import os
5
+ import datasets
6
+ import numpy as np
7
+ import torch
8
+ from accelerate import Accelerator
9
+ from datasets import Dataset, IterableDataset, concatenate_datasets, interleave_datasets, load_dataset
10
+ from tqdm import tqdm
11
+ from transformers import AutoFeatureExtractor, AutoTokenizer
12
+ import torchaudio
13
+ import torchaudio.transforms as T
14
+
15
+ @dataclass
16
+ class DataCollatorEncodecWithPadding:
17
+ """
18
+ Data collator that will dynamically pad the inputs received to the longest sequence in the batch or
19
+ to `max_length` if `max_length` is set and `padding=max_length`.
20
+ """
21
+
22
+ feature_extractor: AutoFeatureExtractor
23
+ audio_column_name: str
24
+ librittsr_dir: Optional[str] = None
25
+ other_dir: Optional[str] = None
26
+ feature_extractor_input_name: Optional[str] = "input_values"
27
+ max_length: Optional[int] = None
28
+ padding: Optional[str] = "longest"
29
+
30
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
31
+ # split inputs and labels since they have to be of different lengths and need
32
+ # different padding methods
33
+ sampling_rate = self.feature_extractor.sampling_rate
34
+ # load audio
35
+ audios = []
36
+ for f in features:
37
+ path = f[self.audio_column_name]
38
+ source = f["source"]
39
+ if source == "libritts-r":
40
+ path = os.path.join(self.librittsr_dir, path)
41
+ else:
42
+ path = os.path.join(self.other_dir, path)
43
+
44
+ if os.path.exists(path):
45
+ waveform, sr = torchaudio.load(path)
46
+ if sr != sampling_rate:
47
+ resampler = T.Resample(orig_freq=sr, new_freq=sampling_rate)
48
+ waveform = resampler(waveform)
49
+ if waveform.shape[0] > 1:
50
+ waveform = waveform.mean(dim=0, keepdim=True)
51
+ audios.append(waveform.squeeze())
52
+ else:
53
+ print(f"Read error: {path}")
54
+
55
+
56
+ len_audio = [len(audio) for audio in audios]
57
+ if self.max_length is not None:
58
+ audios = [audio[: min(l, self.max_length)] for audio, l in zip(audios, len_audio)]
59
+
60
+ # the waveforms have already been resampled above to the feature extractor's rate,
61
+ # so `sampling_rate` is passed explicitly to the feature_extractor here.
62
+ batch = self.feature_extractor(
63
+ [np.asarray(a, dtype=np.float32) for a in audios], sampling_rate=sampling_rate, return_tensors="pt", padding=self.padding, max_length=self.max_length
64
+ )
65
+ batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
66
+ return batch
67
+
68
+
69
+ @dataclass
70
+ class DataCollatorParlerTTSWithPadding:
71
+ """
72
+ Data collator that will dynamically pad the inputs received.
73
+ Args:
74
+ prompt_tokenizer (:class:`~transformers.AutoTokenizer`)
75
+ The prompt_tokenizer used for processing the data.
76
+ description_tokenizer (:class:`~transformers.AutoTokenizer`)
77
+ The description_tokenizer used for processing the data.
78
+ padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
79
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
80
+ among:
81
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
82
+ sequence is provided).
83
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
84
+ maximum acceptable input length for the model if that argument is not provided.
85
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
86
+ different lengths).
87
+ pad_to_multiple_of (:obj:`int`, `optional`):
88
+ If set will pad the sequence to a multiple of the provided value.
89
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
90
+ 7.5 (Volta).
91
+ """
92
+
93
+ prompt_tokenizer: AutoTokenizer
94
+ description_tokenizer: AutoTokenizer
95
+ padding: Union[bool, str] = "longest"
96
+ pad_to_multiple_of: Optional[int] = None
97
+ prompt_max_length: Optional[int] = None
98
+ description_max_length: Optional[int] = None
99
+ audio_max_length: Optional[int] = None
100
+
101
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
102
+ # split inputs and labels since they have to be of different lengths and need
103
+ # different padding methods
104
+
105
+ labels = [torch.tensor(feature["labels"]).transpose(0, 1) for feature in features]
106
+ # (bsz, seq_len, num_codebooks)
107
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
108
+ if self.audio_max_length is not None and self.padding == "max_length":
109
+ labels = torch.nn.functional.pad(
110
+ labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0)), value=-100
111
+ )
112
+
113
+ input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
114
+
115
+ input_ids = self.description_tokenizer.pad(
116
+ input_ids,
117
+ return_tensors="pt",
118
+ padding=self.padding,
119
+ pad_to_multiple_of=self.pad_to_multiple_of,
120
+ max_length=self.description_max_length,
121
+ )
122
+
123
+ batch = {"labels": labels, **input_ids}
124
+
125
+ prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
126
+ prompt_input_ids = self.prompt_tokenizer.pad(
127
+ prompt_input_ids,
128
+ return_tensors="pt",
129
+ padding=self.padding,
130
+ pad_to_multiple_of=self.pad_to_multiple_of,
131
+ max_length=self.prompt_max_length,
132
+ )
133
+
134
+ batch["prompt_input_ids"] = prompt_input_ids["input_ids"]
135
+ if "attention_mask" in prompt_input_ids:
136
+ batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]
137
+
138
+ return batch
139
+
140
+
141
+ def convert_dataset_str_to_list(
142
+ dataset_names,
143
+ splits=None,
144
+ dataset_samples=None,
145
+ default_split="train",
146
+ ):
147
+ if isinstance(dataset_names, str):
148
+ dataset_names = dataset_names.split("+")
149
+ splits = splits.split("+") if splits is not None else None
150
+ dataset_samples = dataset_samples.split("+") if dataset_samples is not None else None
151
+
152
+ if splits is not None and len(splits) != len(dataset_names):
153
+ raise ValueError(
154
+ f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
155
+ )
156
+
157
+ if dataset_samples is not None:
158
+ if len(dataset_samples) != len(dataset_names):
159
+ raise ValueError(
160
+ f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and "
161
+ f"{len(dataset_samples)} samples."
162
+ )
163
+ dataset_samples = [float(ds_sample) for ds_sample in dataset_samples]
164
+ else:
165
+ dataset_samples = [None] * len(dataset_names)
166
+
167
+ splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]
168
+
169
+ dataset_names_dict = []
170
+ for i, ds_name in enumerate(dataset_names):
171
+ dataset_names_dict.append(
172
+ {
173
+ "name": ds_name,
174
+ "split": splits[i],
175
+ "samples": dataset_samples[i],
176
+ }
177
+ )
178
+ return dataset_names_dict
179
+
180
+
181
+ def load_multiple_datasets(
182
+ accelerator: Accelerator,
183
+ dataset_names: Union[List, str],
184
+ splits: Optional[Union[List, str]] = None,
185
+ label_column_names: Optional[List] = None,
186
+ stopping_strategy: Optional[str] = "first_exhausted",
187
+ dataset_samples: Optional[Union[List, np.array]] = None,
188
+ streaming: Optional[bool] = False,
189
+ seed: Optional[int] = None,
190
+ id_column_name: Optional[str] = None,
191
+ columns_to_keep: Optional[Set[str]] = None,
192
+ prompt_column_name: Optional[str] = None,
193
+ sampling_rate: Optional[int] = None,
194
+ audio_column_name: Optional[str] = None,
195
+ logger: Optional[logging.Logger] = None,
196
+ librittsr_dir: Optional[Union[List, str]] = None,
197
+ other_dir: Optional[Union[List, str]] = None,
198
+ **kwargs,
199
+ ) -> Union[Dataset, IterableDataset]:
200
+ dataset_names_dict = convert_dataset_str_to_list(
201
+ dataset_names, splits, dataset_samples=dataset_samples
202
+ )
203
+
204
+ if dataset_samples is not None:
205
+ dataset_samples = [ds_dict["samples"] for ds_dict in dataset_names_dict]
206
+ probabilities = np.array(dataset_samples) / np.sum(dataset_samples)
207
+ else:
208
+ probabilities = None
209
+
210
+ all_datasets = []
211
+ # iterate over the datasets we want to interleave
212
+ for dataset_dict in tqdm(dataset_names_dict, desc="Combining datasets..."):
213
+ with accelerator.local_main_process_first():
214
+ dataset = load_dataset(
215
+ dataset_dict["name"],
216
+ split=dataset_dict["split"],
217
+ streaming=streaming,
218
+ **kwargs,
219
+ )
220
+ dataset_features = dataset.features.keys()
221
+
222
+ if columns_to_keep is not None:
223
+ dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
224
+
225
+ def resolve_path(example):
226
+ path = example["audio_path"]
227
+ source = example["source"]
228
+
229
+ if source == "libritts-r":
230
+ full_path = os.path.join(librittsr_dir, path)
231
+ else:
232
+ full_path = os.path.join(other_dir, path)
233
+
234
+ return os.path.exists(full_path)
235
+
236
+ dataset = dataset.filter(resolve_path, num_proc=16)
237
+
238
+ all_datasets.append(dataset)
239
+
240
+ if len(all_datasets) == 1:
241
+ # we have a single dataset so just return it as is
242
+ return all_datasets[0]
243
+
244
+ if streaming:
245
+ interleaved_dataset = interleave_datasets(
246
+ all_datasets,
247
+ stopping_strategy=stopping_strategy,
248
+ probabilities=probabilities,
249
+ seed=seed,
250
+ )
251
+ else:
252
+ with accelerator.local_main_process_first():
253
+ interleaved_dataset = concatenate_datasets(all_datasets)
254
+
255
+ return interleaved_dataset
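
The `+`-separated string convention parsed by `convert_dataset_str_to_list` (shared by all three data modules) is easiest to see on a small, made-up example; the dataset ids below are placeholders.

```python
convert_dataset_str_to_list(
    dataset_names="org/captts+org/emocaptts",  # hypothetical ids
    splits="train+train",
    dataset_samples="3.0+1.0",
)
# -> [{"name": "org/captts",    "split": "train", "samples": 3.0},
#     {"name": "org/emocaptts", "split": "train", "samples": 1.0}]
#
# When streaming is enabled, load_multiple_datasets turns the samples into
# interleaving probabilities, here 0.75 and 0.25; otherwise the datasets are
# simply concatenated.
```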
capspeech/ar/training/data_capttsse.py ADDED
@@ -0,0 +1,253 @@
1
+ import logging
2
+ from dataclasses import dataclass
3
+ from typing import Dict, List, Optional, Set, Union
4
+ import os
5
+ import datasets
6
+ import numpy as np
7
+ import torch
8
+ from accelerate import Accelerator
9
+ from datasets import Dataset, IterableDataset, concatenate_datasets, interleave_datasets, load_dataset
10
+ from tqdm import tqdm
11
+ from transformers import AutoFeatureExtractor, AutoTokenizer
12
+ import torchaudio
13
+ import torchaudio.transforms as T
14
+
15
+ @dataclass
16
+ class DataCollatorEncodecWithPadding:
17
+ """
18
+ Data collator that will dynamically pad the inputs received to the longest sequence in the batch or
19
+ to `max_length` if `max_length` is set and `padding=max_length`.
20
+ """
21
+
22
+ feature_extractor: AutoFeatureExtractor
23
+ audio_column_name: str
24
+ librittsrmix_dir: Optional[str] = None
25
+ feature_extractor_input_name: Optional[str] = "input_values"
26
+ max_length: Optional[int] = None
27
+ padding: Optional[str] = "longest"
28
+
29
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
30
+ # split inputs and labels since they have to be of different lengths and need
31
+ # different padding methods
32
+ sampling_rate = self.feature_extractor.sampling_rate
33
+ # load audio
34
+ audios = []
35
+ for f in features:
36
+ path = f[self.audio_column_name]
37
+ source = f["source"]
38
+ if source == "libritts-r":
39
+ path = os.path.join(self.librittsrmix_dir, path)
40
+ else:
41
+ raise ValueError(source)
42
+
43
+ if os.path.exists(path):
44
+ waveform, sr = torchaudio.load(path)
45
+ if sr != sampling_rate:
46
+ resampler = T.Resample(orig_freq=sr, new_freq=sampling_rate)
47
+ waveform = resampler(waveform)
48
+ if waveform.shape[0] > 1:
49
+ waveform = waveform.mean(dim=0, keepdim=True)
50
+ audios.append(waveform.squeeze())
51
+ else:
52
+ print(f"Read error: {path}")
53
+
54
+
55
+ len_audio = [len(audio) for audio in audios]
56
+ if self.max_length is not None:
57
+ audios = [audio[: min(l, self.max_length)] for audio, l in zip(audios, len_audio)]
58
+
59
+ # the waveforms have already been resampled above to the feature extractor's rate,
60
+ # so `sampling_rate` is passed explicitly to the feature_extractor here.
61
+ batch = self.feature_extractor(
62
+ [np.asarray(a, dtype=np.float32) for a in audios], sampling_rate=sampling_rate, return_tensors="pt", padding=self.padding, max_length=self.max_length
63
+ )
64
+ batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
65
+ return batch
66
+
67
+
68
+ @dataclass
69
+ class DataCollatorParlerTTSWithPadding:
70
+ """
71
+ Data collator that will dynamically pad the inputs received.
72
+ Args:
73
+ prompt_tokenizer (:class:`~transformers.AutoTokenizer`)
74
+ The prompt_tokenizer used for processing the data.
75
+ description_tokenizer (:class:`~transformers.AutoTokenizer`)
76
+ The description_tokenizer used for processing the data.
77
+ padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
78
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
79
+ among:
80
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
81
+ sequence is provided).
82
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
83
+ maximum acceptable input length for the model if that argument is not provided.
84
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
85
+ different lengths).
86
+ pad_to_multiple_of (:obj:`int`, `optional`):
87
+ If set will pad the sequence to a multiple of the provided value.
88
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
89
+ 7.5 (Volta).
90
+ """
91
+
92
+ prompt_tokenizer: AutoTokenizer
93
+ description_tokenizer: AutoTokenizer
94
+ padding: Union[bool, str] = "longest"
95
+ pad_to_multiple_of: Optional[int] = None
96
+ prompt_max_length: Optional[int] = None
97
+ description_max_length: Optional[int] = None
98
+ audio_max_length: Optional[int] = None
99
+
100
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
101
+ # split inputs and labels since they have to be of different lengths and need
102
+ # different padding methods
103
+
104
+ labels = [torch.tensor(feature["labels"]).transpose(0, 1) for feature in features]
105
+ # (bsz, seq_len, num_codebooks)
106
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
107
+ if self.audio_max_length is not None and self.padding == "max_length":
108
+ labels = torch.nn.functional.pad(
109
+ labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0)), value=-100
110
+ )
111
+
112
+ input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
113
+
114
+ input_ids = self.description_tokenizer.pad(
115
+ input_ids,
116
+ return_tensors="pt",
117
+ padding=self.padding,
118
+ pad_to_multiple_of=self.pad_to_multiple_of,
119
+ max_length=self.description_max_length,
120
+ )
121
+
122
+ batch = {"labels": labels, **input_ids}
123
+
124
+ prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
125
+ prompt_input_ids = self.prompt_tokenizer.pad(
126
+ prompt_input_ids,
127
+ return_tensors="pt",
128
+ padding=self.padding,
129
+ pad_to_multiple_of=self.pad_to_multiple_of,
130
+ max_length=self.prompt_max_length,
131
+ )
132
+
133
+ batch["prompt_input_ids"] = prompt_input_ids["input_ids"]
134
+ if "attention_mask" in prompt_input_ids:
135
+ batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]
136
+
137
+ return batch
138
+
139
+
140
+ def convert_dataset_str_to_list(
141
+ dataset_names,
142
+ splits=None,
143
+ dataset_samples=None,
144
+ default_split="train",
145
+ ):
146
+ if isinstance(dataset_names, str):
147
+ dataset_names = dataset_names.split("+")
148
+ splits = splits.split("+") if splits is not None else None
149
+ dataset_samples = dataset_samples.split("+") if dataset_samples is not None else None
150
+
151
+ if splits is not None and len(splits) != len(dataset_names):
152
+ raise ValueError(
153
+ f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
154
+ )
155
+
156
+ if dataset_samples is not None:
157
+ if len(dataset_samples) != len(dataset_names):
158
+ raise ValueError(
159
+ f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and "
160
+ f"{len(dataset_samples)} samples."
161
+ )
162
+ dataset_samples = [float(ds_sample) for ds_sample in dataset_samples]
163
+ else:
164
+ dataset_samples = [None] * len(dataset_names)
165
+
166
+ splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]
167
+
168
+ dataset_names_dict = []
169
+ for i, ds_name in enumerate(dataset_names):
170
+ dataset_names_dict.append(
171
+ {
172
+ "name": ds_name,
173
+ "split": splits[i],
174
+ "samples": dataset_samples[i],
175
+ }
176
+ )
177
+ return dataset_names_dict
178
+
179
+
180
+ def load_multiple_datasets(
181
+ accelerator: Accelerator,
182
+ dataset_names: Union[List, str],
183
+ splits: Optional[Union[List, str]] = None,
184
+ label_column_names: Optional[List] = None,
185
+ stopping_strategy: Optional[str] = "first_exhausted",
186
+ dataset_samples: Optional[Union[List, np.array]] = None,
187
+ streaming: Optional[bool] = False,
188
+ seed: Optional[int] = None,
189
+ id_column_name: Optional[str] = None,
190
+ columns_to_keep: Optional[Set[str]] = None,
191
+ prompt_column_name: Optional[str] = None,
192
+ sampling_rate: Optional[int] = None,
193
+ audio_column_name: Optional[str] = None,
194
+ logger: Optional[logging.Logger] = None,
195
+ librittsrmix_dir: Optional[Union[List, str]] = None,
196
+ **kwargs,
197
+ ) -> Union[Dataset, IterableDataset]:
198
+ dataset_names_dict = convert_dataset_str_to_list(
199
+ dataset_names, splits, dataset_samples=dataset_samples
200
+ )
201
+
202
+ if dataset_samples is not None:
203
+ dataset_samples = [ds_dict["samples"] for ds_dict in dataset_names_dict]
204
+ probabilities = np.array(dataset_samples) / np.sum(dataset_samples)
205
+ else:
206
+ probabilities = None
207
+
208
+ all_datasets = []
209
+ # iterate over the datasets we want to interleave
210
+ for dataset_dict in tqdm(dataset_names_dict, desc="Combining datasets..."):
211
+ with accelerator.local_main_process_first():
212
+ dataset = load_dataset(
213
+ dataset_dict["name"],
214
+ split=dataset_dict["split"],
215
+ streaming=streaming,
216
+ **kwargs,
217
+ )
218
+ dataset_features = dataset.features.keys()
219
+
220
+ if columns_to_keep is not None:
221
+ dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
222
+
223
+ def resolve_path(example):
224
+ path = example["audio_path"]
225
+ source = example["source"]
226
+
227
+ if source == "libritts-r":
228
+ full_path = os.path.join(librittsrmix_dir, path)
229
+ else:
230
+ return False # unknown source
231
+
232
+ return os.path.exists(full_path)
233
+
234
+ dataset = dataset.filter(resolve_path, num_proc=16)
235
+
236
+ all_datasets.append(dataset)
237
+
238
+ if len(all_datasets) == 1:
239
+ # we have a single dataset so just return it as is
240
+ return all_datasets[0]
241
+
242
+ if streaming:
243
+ interleaved_dataset = interleave_datasets(
244
+ all_datasets,
245
+ stopping_strategy=stopping_strategy,
246
+ probabilities=probabilities,
247
+ seed=seed,
248
+ )
249
+ else:
250
+ with accelerator.local_main_process_first():
251
+ interleaved_dataset = concatenate_datasets(all_datasets)
252
+
253
+ return interleaved_dataset
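
All three collators load and normalise audio the same way; the standalone snippet below mirrors that logic. The file path and target rate are placeholders; in the collators the rate comes from `feature_extractor.sampling_rate`.

```python
import torchaudio
import torchaudio.transforms as T

waveform, sr = torchaudio.load("/data/librittsr-mix/sample.wav")  # placeholder path
target_sr = 44_100  # the collator uses feature_extractor.sampling_rate
if sr != target_sr:
    waveform = T.Resample(orig_freq=sr, new_freq=target_sr)(waveform)
if waveform.shape[0] > 1:          # down-mix multi-channel audio to mono
    waveform = waveform.mean(dim=0, keepdim=True)
audio = waveform.squeeze()         # 1-D tensor handed to the feature extractor
```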
capspeech/ar/training/finetune_captts.py ADDED
@@ -0,0 +1,1270 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """ Train Parler-TTS using 🤗 Accelerate"""
18
+
19
+ import logging
20
+ import os
21
+ import re
22
+ import sys
23
+ import time
24
+ import math
25
+ import contextlib
26
+ from multiprocess import set_start_method
27
+ from datetime import timedelta
28
+ import inspect
29
+ from tqdm import tqdm
30
+ from pathlib import Path
31
+ import wandb
32
+
33
+ import torch
34
+ from torch.utils.data import DataLoader
35
+
36
+ import datasets
37
+ from datasets import DatasetDict, Dataset, IterableDataset, concatenate_datasets
38
+
39
+ from huggingface_hub import HfApi
40
+
41
+ import transformers
42
+ from transformers import AutoFeatureExtractor, AutoTokenizer, HfArgumentParser
43
+ from transformers.trainer_pt_utils import LengthGroupedSampler
44
+ from transformers.optimization import get_scheduler
45
+ from transformers.utils import send_example_telemetry
46
+
47
+
48
+ from accelerate import Accelerator, skip_first_batches
49
+ from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin, DistributedDataParallelKwargs
50
+ from accelerate.utils.memory import release_memory
51
+
52
+ from parler_tts import (
53
+ ParlerTTSConfig,
54
+ ParlerTTSForConditionalGeneration,
55
+ build_delay_pattern_mask,
56
+ )
57
+
58
+ from training.utils import (
59
+ get_last_checkpoint,
60
+ rotate_checkpoints,
61
+ log_pred,
62
+ log_metric,
63
+ load_all_codec_checkpoints,
64
+ save_codec_checkpoint,
65
+ get_last_codec_checkpoint_step,
66
+ )
67
+ from training.arguments_captts import ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments
68
+ from training.data_captts import load_multiple_datasets, DataCollatorParlerTTSWithPadding, DataCollatorEncodecWithPadding
69
+ from training.eval import clap_similarity, wer, si_sdr
70
+
71
+ logger = logging.getLogger(__name__)
72
+
73
+
74
+ def main():
75
+ # See all possible arguments in src/transformers/training_args.py
76
+ # or by passing the --help flag to this script.
77
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
78
+
79
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments))
80
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
81
+ # If we pass only one argument to the script and it's the path to a json file,
82
+ # let's parse it to get our arguments.
83
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
84
+ else:
85
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
86
+
87
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
88
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
89
+ send_example_telemetry("run_parler_tts", model_args, data_args)
90
+
91
+ if data_args.wandb_key is not None:
92
+ wandb.login(key=data_args.wandb_key)
93
+
94
+ if training_args.dtype == "float16":
95
+ mixed_precision = "fp16"
96
+ torch_dtype = torch.float16
97
+ elif training_args.dtype == "bfloat16":
98
+ mixed_precision = "bf16"
99
+ torch_dtype = torch.bfloat16
100
+ else:
101
+ mixed_precision = "no"
102
+ torch_dtype = torch.float32
103
+
104
+ if data_args.pad_to_max_length and (
105
+ data_args.max_duration_in_seconds is None
106
+ or data_args.max_prompt_token_length is None
107
+ or data_args.max_description_token_length is None
108
+ ):
109
+ raise ValueError(
110
+ "`pad_to_max_length` is `True` but one of the following parameters has not been set: `max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`"
111
+ )
112
+
113
+ padding = "max_length" if data_args.pad_to_max_length else "longest"
114
+
115
+ ####### A. Preparation
116
+ kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=120)), DistributedDataParallelKwargs(find_unused_parameters=False)]
117
+
118
+ accelerator = Accelerator(
119
+ gradient_accumulation_steps=training_args.gradient_accumulation_steps,
120
+ mixed_precision=mixed_precision,
121
+ log_with=training_args.report_to,
122
+ project_dir=training_args.output_dir,
123
+ kwargs_handlers=kwargs_handlers,
124
+ )
125
+
126
+ accelerator.init_trackers(
127
+ project_name=data_args.wandb_project,
128
+ config={
129
+ "learning_rate": training_args.learning_rate,
130
+ "model_name_or_path": model_args.model_name_or_path,
131
+ "num_train_epochs": training_args.num_train_epochs,
132
+ "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
133
+ "per_device_train_batch_size": training_args.per_device_train_batch_size,
134
+ "global_batch_size": training_args.per_device_train_batch_size * accelerator.num_processes,
135
+ "mixed_precision": mixed_precision,
136
+ "lr_scheduler_type": training_args.lr_scheduler_type,
137
+ "warmup_steps": training_args.warmup_steps,
138
+ "freeze_text_encoder": model_args.freeze_text_encoder,
139
+ "max_duration_in_seconds": data_args.max_duration_in_seconds,
140
+ "weight_decay": training_args.weight_decay,
141
+ "adam_beta1": training_args.adam_beta1,
142
+ "adam_beta2": training_args.adam_beta2,
143
+ "temperature": model_args.temperature,
144
+ },
145
+ init_kwargs={"wandb": {"name": data_args.wandb_run_name}} if data_args.wandb_run_name else {},
146
+ )
147
+
148
+ # Detect the last checkpoint and, if one exists, resume from it
149
+ last_checkpoint = None
150
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
151
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
152
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
153
+ raise ValueError(
154
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
155
+ "Use --overwrite_output_dir to overcome."
156
+ )
157
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
158
+ logger.info(
159
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
160
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
161
+ )
162
+
163
+ # Setup logging
164
+ logging.basicConfig(
165
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
166
+ datefmt="%m/%d/%Y %H:%M:%S",
167
+ handlers=[logging.StreamHandler(sys.stdout)],
168
+ )
169
+ logger.setLevel(logging.INFO if accelerator.is_main_process else logging.WARN)
170
+
171
+ # Log a small summary on each process
172
+ logger.warning(
173
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
174
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
175
+ )
176
+
177
+ # Set the verbosity to info of the Transformers logger (on main process only)
178
+ if accelerator.is_local_main_process:
179
+ datasets.utils.logging.set_verbosity_warning()
180
+ transformers.utils.logging.set_verbosity_info()
181
+ else:
182
+ datasets.utils.logging.set_verbosity_error()
183
+ transformers.utils.logging.set_verbosity_error()
184
+
185
+ logger.info("Training/evaluation parameters %s", training_args)
186
+
187
+ # Set seed before initializing model.
188
+ set_seed(training_args.seed)
189
+ num_workers = data_args.preprocessing_num_workers
190
+
191
+ # 1. First, let's instantiate the feature extractor, tokenizers and model
192
+ # Note for distributed training, the .from_pretrained methods guarantee that only
193
+ # one local process can concurrently download model & vocab.
194
+
195
+ # load feature extractor
196
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
197
+ model_args.feature_extractor_name or model_args.model_name_or_path,
198
+ cache_dir=model_args.cache_dir,
199
+ token=data_args.token,
200
+ trust_remote_code=data_args.trust_remote_code,
201
+ )
202
+ sampling_rate = feature_extractor.sampling_rate
203
+
204
+ # load prompt tokenizer
205
+ prompt_tokenizer = AutoTokenizer.from_pretrained(
206
+ model_args.prompt_tokenizer_name or model_args.description_tokenizer_name or model_args.model_name_or_path,
207
+ cache_dir=model_args.cache_dir,
208
+ token=data_args.token,
209
+ trust_remote_code=data_args.trust_remote_code,
210
+ use_fast=model_args.use_fast_tokenizer,
211
+ padding_side=model_args.prompt_padding_side,
212
+ )
213
+
214
+ # load description tokenizer
215
+ description_tokenizer = AutoTokenizer.from_pretrained(
216
+ model_args.description_tokenizer_name or model_args.model_name_or_path,
217
+ cache_dir=model_args.cache_dir,
218
+ token=data_args.token,
219
+ trust_remote_code=data_args.trust_remote_code,
220
+ use_fast=model_args.use_fast_tokenizer,
221
+ )
222
+
223
+ if model_args.use_fast_tokenizer:
224
+ logger.warning(
225
+ "Disabling fast tokenizer warning: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3231-L3235"
226
+ )
227
+ prompt_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
228
+ description_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
229
+
230
+ # 2. Now, let's load the dataset
231
+
232
+ if data_args.save_to_disk is not None:
233
+ os.makedirs(data_args.save_to_disk, exist_ok=True)
234
+
235
+ # assume that the dataset has been saved to `save_to_disk` if the latter is not empty
236
+ dataset_was_precomputed = len(os.listdir(data_args.save_to_disk)) > 0
237
+ if dataset_was_precomputed:
238
+ with accelerator.local_main_process_first():
239
+ vectorized_datasets = datasets.load_from_disk(data_args.save_to_disk)
240
+ else:
241
+ raw_datasets = DatasetDict()
242
+
243
+ columns_to_keep = {
244
+ "target_audio_column_name": data_args.target_audio_column_name,
245
+ "prompt_column_name": data_args.prompt_column_name,
246
+ "source": data_args.source_column_name,
247
+ }
248
+ if data_args.description_column_name is not None:
249
+ columns_to_keep["description_column_name"] = data_args.description_column_name
250
+
251
+ if training_args.do_train:
252
+ raw_datasets["train"] = load_multiple_datasets(
253
+ accelerator,
254
+ data_args.train_dataset_name,
255
+ splits=data_args.train_split_name,
256
+ dataset_samples=data_args.train_dataset_samples,
257
+ seed=training_args.seed,
258
+ cache_dir=model_args.cache_dir,
259
+ num_proc=data_args.preprocessing_num_workers,
260
+ id_column_name=data_args.id_column_name,
261
+ columns_to_keep=columns_to_keep.values(),
262
+ prompt_column_name=data_args.prompt_column_name,
263
+ audio_column_name=data_args.target_audio_column_name,
264
+ sampling_rate=sampling_rate,
265
+ logger=logger,
266
+ librittsr_dir=data_args.librittsr_dir,
267
+ other_dir=data_args.other_dir,
268
+ # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
269
+ )
270
+
271
+ for key in columns_to_keep:
272
+ if columns_to_keep[key] not in raw_datasets["train"].column_names:
273
+ raise ValueError(
274
+ f"--{key} '{columns_to_keep[key]}' not found in dataset '{data_args.train_dataset_name}'."
275
+ f" Make sure to set `--{key}` to the correct audio column - one of"
276
+ f" {', '.join(raw_datasets['train'].column_names)}."
277
+ )
278
+
279
+ if data_args.max_train_samples is not None:
280
+ raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
281
+
282
+ if training_args.do_eval:
283
+ raw_datasets["eval"] = load_multiple_datasets(
284
+ accelerator,
285
+ data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
286
+ splits=data_args.eval_split_name,
287
+ cache_dir=model_args.cache_dir,
288
+ num_proc=data_args.preprocessing_num_workers,
289
+ id_column_name=data_args.id_column_name,
290
+ columns_to_keep=columns_to_keep.values(),
291
+ prompt_column_name=data_args.prompt_column_name,
292
+ audio_column_name=data_args.target_audio_column_name,
293
+ sampling_rate=sampling_rate,
294
+ logger=logger,
295
+ librittsr_dir=data_args.librittsr_dir,
296
+ other_dir=data_args.other_dir,
297
+ # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
298
+ )
299
+
300
+ if data_args.max_eval_samples is not None:
301
+ with accelerator.local_main_process_first():
302
+ raw_datasets["eval"] = (
303
+ raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
304
+ )
305
+
306
+ # 3. Next, let's load the config.
307
+ config = ParlerTTSConfig.from_pretrained(
308
+ model_args.model_name_or_path,
309
+ cache_dir=model_args.cache_dir,
310
+ token=data_args.token,
311
+ trust_remote_code=data_args.trust_remote_code,
312
+ )
313
+
314
+ if training_args.codebook_weights is not None and len(training_args.codebook_weights) != config.decoder.num_codebooks:
315
+ raise ValueError(f"`codebook_weights` has length {len(training_args.codebook_weights)} when it should be of length {config.decoder.num_codebooks}.")
316
+
317
+ # update pad token id and decoder_start_token_id
318
+ config.decoder.update(
319
+ {
320
+ "cross_attention_implementation_strategy": model_args.cross_attention_implementation_strategy
321
+ if model_args.cross_attention_implementation_strategy is not None
322
+ else None,
323
+ "codebook_weights": training_args.codebook_weights if training_args.codebook_weights is not None else config.decoder.codebook_weights
324
+ }
325
+ )
326
+ config.update(
327
+ {
328
+ "pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id,
329
+ "decoder_start_token_id": model_args.decoder_start_token_id
330
+ if model_args.decoder_start_token_id is not None
331
+ else config.decoder_start_token_id,
332
+ }
333
+ )
334
+
335
+ with open("events.txt", "r") as f:
336
+ events = [line.strip() for line in f]
337
+ events = ["<"+event.lower().replace(" ", "_")+">" for event in events]
338
+ events.append("<B_start>")
339
+ events.append("<B_end>")
340
+ events.append("<I_start>")
341
+ events.append("<I_end>")
342
+
343
+ special_tokens = {"additional_special_tokens": events}
344
+ prompt_tokenizer.add_special_tokens(special_tokens)
345
+ description_tokenizer.add_special_tokens(special_tokens)
346
+ padded_vocab_size = ((len(prompt_tokenizer) + 127) // 128) * 128
347
+ config.vocab_size = padded_vocab_size
348
+
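+ # NOTE: rounding the vocabulary up to the next multiple of 128 keeps the resized
+ # embedding matrix a hardware-friendly size; for example, 32,104 tokens after
+ # adding the event tags would be padded up to 32,128 (251 * 128).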
349
+ # create model
350
+ model = ParlerTTSForConditionalGeneration.from_pretrained(
351
+ model_args.model_name_or_path,
352
+ ignore_mismatched_sizes=True,
353
+ cache_dir=model_args.cache_dir,
354
+ config=config,
355
+ token=data_args.token,
356
+ trust_remote_code=data_args.trust_remote_code,
357
+ attn_implementation={"decoder": model_args.attn_implementation, "text_encoder": "eager"},
358
+ )
359
+ model.text_encoder.resize_token_embeddings(padded_vocab_size)
360
+
361
+ # enable gradient checkpointing if necessary
362
+ if training_args.gradient_checkpointing:
363
+ model.gradient_checkpointing_enable()
364
+
365
+ # 4. Now we preprocess the datasets including loading the audio, resampling and normalization
366
+ # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
367
+ # so that we just need to set the correct target sampling rate and normalize the input
368
+ # via the `feature_extractor`
369
+
370
+ # derive max & min input length for sample rate & max duration
371
+ sampling_rate = feature_extractor.sampling_rate
372
+ max_target_length = int(data_args.max_duration_in_seconds * sampling_rate)
373
+ min_target_length = int(data_args.min_duration_in_seconds * sampling_rate)
374
+ target_audio_column_name = data_args.target_audio_column_name
375
+ description_column_name = data_args.description_column_name
376
+ prompt_column_name = data_args.prompt_column_name
377
+ feature_extractor_input_name = feature_extractor.model_input_names[0]
378
+ audio_encoder_pad_token_id = config.decoder.pad_token_id
379
+ audio_encoder_eos_token_id = config.decoder.eos_token_id
380
+ audio_encoder_bos_token_id = model.generation_config.decoder_start_token_id
381
+ max_length = model.generation_config.max_length
382
+ num_codebooks = model.decoder.config.num_codebooks
383
+ bandwidth = model_args.bandwidth
384
+ attn_implementation = model_args.attn_implementation
385
+
386
+ # Freeze Encoders
387
+ model.freeze_encoders(model_args.freeze_text_encoder)
388
+
389
+ # Test all-gather - used as a warm-up and to avoid timeouts
390
+ logger.debug(str(accelerator.process_index), main_process_only=False, in_order=True)
391
+ test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
392
+ gathered_tensor = accelerator.gather(test_tensor)
393
+ print("gathered_tensor", gathered_tensor)
394
+ accelerator.wait_for_everyone()
395
+
396
+ if not dataset_was_precomputed:
397
+ # Filter on text length
398
+ if description_column_name is not None and data_args.max_text_length is not None:
399
+ with accelerator.local_main_process_first():
400
+ # keep only descriptions shorter than max_text_length (in characters)
401
+ raw_datasets = raw_datasets.filter(
402
+ lambda x: len(x) < data_args.max_text_length,
403
+ num_proc=num_workers,
404
+ input_columns=[description_column_name],
405
+ )
406
+
407
+ # Preprocessing the dataset.
408
+ # We need to tokenize the texts.
409
+ def pass_through_processors(description, prompt):
410
+ batch = {}
411
+
412
+ batch["input_ids"] = description_tokenizer(description.strip())["input_ids"]
413
+ batch["prompt_input_ids"] = prompt_tokenizer(prompt.strip())["input_ids"]
414
+
415
+ return batch
416
+
417
+ with accelerator.local_main_process_first():
418
+ # this is a trick to avoid rewriting the entire audio column, which takes ages
419
+ vectorized_datasets = raw_datasets.map(
420
+ pass_through_processors,
421
+ remove_columns=next(iter(raw_datasets.values())).column_names,
422
+ input_columns=[description_column_name, prompt_column_name],
423
+ num_proc=num_workers,
424
+ desc="preprocess datasets",
425
+ )
426
+
427
+ # We use Accelerate to perform distributed inference
428
+ # T5 doesn't support fp16
429
+ autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
430
+
431
+ # Now we encode the audio labels with encodec.
432
+ ####### B. Encode audio
433
+
434
+ logger.info("*** Encode target audio with encodec ***")
435
+
436
+ # no need to prepare audio_decoder because used for inference without mixed precision
437
+ # see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
438
+ if training_args.torch_compile:
439
+ audio_decoder = accelerator.prepare_model(model.audio_encoder, evaluation_mode=True)
440
+ else:
441
+ audio_decoder = model.audio_encoder
442
+
443
+ encoder_data_collator = DataCollatorEncodecWithPadding(
444
+ feature_extractor,
445
+ audio_column_name=target_audio_column_name,
446
+ librittsr_dir=data_args.librittsr_dir,
447
+ other_dir=data_args.other_dir,
448
+ feature_extractor_input_name=feature_extractor_input_name,
449
+ max_length=max_target_length,
450
+ padding=padding,
451
+ )
452
+ encoder_signature = set(inspect.signature(audio_decoder.forward).parameters)
453
+
454
+ def apply_audio_decoder(batch):
455
+ len_audio = batch.pop("len_audio")
456
+ audio_decoder.to(batch["input_values"].device).eval()
457
+ if bandwidth is not None:
458
+ batch["bandwidth"] = bandwidth
459
+ elif "num_quantizers" in encoder_signature:
460
+ batch["num_quantizers"] = num_codebooks
461
+ elif "num_codebooks" in encoder_signature:
462
+ batch["num_codebooks"] = num_codebooks
463
+ elif "n_quantizers" in encoder_signature:
464
+ batch["n_quantizers"] = num_codebooks
465
+
466
+ with torch.no_grad():
467
+ labels = audio_decoder.encode(**batch)["audio_codes"]
468
+ output = {}
469
+ output["len_audio"] = len_audio
470
+ # (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks)
471
+ output["labels"] = labels.squeeze(0).transpose(1, 2)
472
+
473
+ # if `pad_to_max_length`, the maximum corresponding audio length of the current batch is max_duration*sampling_rate
474
+ max_length = len_audio.max() if padding != "max_length" else max_target_length
475
+ output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / max_length
476
+ return output
477
+
478
+ # (1, codebooks, seq_len) where seq_len=1
479
+ bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id
480
+
481
+ def postprocess_dataset(labels):
482
+ # (1, codebooks, seq_len)
483
+ labels = torch.tensor(labels).unsqueeze(0)
484
+ # add bos
485
+ labels = torch.cat([bos_labels, labels], dim=-1)
486
+
487
+ labels, delay_pattern_mask = build_delay_pattern_mask(
488
+ labels,
489
+ bos_token_id=audio_encoder_bos_token_id,
490
+ pad_token_id=audio_encoder_eos_token_id,
491
+ max_length=labels.shape[-1] + num_codebooks,
492
+ num_codebooks=num_codebooks,
493
+ )
494
+
495
+ # the first ids of the delay pattern mask are precisely labels, we use the rest of the labels mask
496
+ # to take care of EOS
497
+ # we want labels to look like this:
498
+ # - [B, a, b, E, E, E, E]
499
+ # - [B, B, c, d, E, E, E]
500
+ # - [B, B, B, e, f, E, E]
501
+ # - [B, B, B, B, g, h, E]
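+ # (i.e. codebook k is shifted right by k extra steps; B is the decoder BOS/start
+ # token and E the EOS token used to fill the remaining positions)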
502
+ labels = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask)
503
+
504
+ # the first timestamp is associated with a row full of BOS, so we get rid of it
505
+ # we also remove the last timestamps (full of PAD)
506
+ output = {"labels": labels[:, 1:]}
507
+ return output
508
+
509
+ for split in vectorized_datasets:
510
+ data_loader = DataLoader(
511
+ raw_datasets[split],
512
+ batch_size=training_args.audio_encoder_per_device_batch_size,
513
+ collate_fn=encoder_data_collator,
514
+ num_workers=training_args.dataloader_num_workers,
515
+ pin_memory=True,
516
+ )
517
+ data_loader = accelerator.prepare(data_loader)
518
+ total_inference_steps = len(data_loader)
519
+
520
+ start_step = get_last_codec_checkpoint_step(os.path.join(data_args.temporary_save_to_disk, split))
521
+ accelerator.wait_for_everyone()
522
+ if start_step > 0:
523
+ logger.info(f"Resuming {split} from step {start_step}")
524
+ # efficiently skip the first n batches
525
+ start_step += 1
526
+ data_loader = skip_first_batches(data_loader, start_step)
527
+
528
+ all_generated_labels = []
529
+ all_lens = []
530
+ if start_step < total_inference_steps:
531
+ for i, batch in enumerate(tqdm(data_loader, disable=not accelerator.is_local_main_process)):
532
+ cur_step = start_step + i
533
+ generate_labels = apply_audio_decoder(batch)
534
+ generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0)
535
+ generate_labels = accelerator.gather_for_metrics(generate_labels)
536
+
537
+ if accelerator.is_main_process:
538
+ lab = generate_labels["labels"].cpu().transpose(1, 2).to(torch.int16)
539
+ rat = generate_labels["ratio"].cpu().squeeze(1)
540
+ lens = generate_labels["len_audio"].cpu().squeeze(1)
541
+ lab = [l[:, : int(ratio * length)] for (l, ratio, length) in zip(lab, rat, lens)]
542
+
543
+ all_generated_labels.extend(lab)
544
+ all_lens.extend(lens)
545
+
546
+ if ((cur_step + 1) % data_args.save_codec_steps == 0) or (
547
+ cur_step == total_inference_steps - 1
548
+ ):
549
+ tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
550
+ tmp_labels = tmp_labels.map(
551
+ postprocess_dataset,
552
+ num_proc=data_args.preprocessing_num_workers, # this step is resource-consuming when many processes are used.
553
+ input_columns=["labels"],
554
+ desc="Postprocessing labeling",
555
+ )
556
+ save_codec_checkpoint(
557
+ os.path.join(data_args.temporary_save_to_disk, split), tmp_labels, cur_step
558
+ )
559
+ all_generated_labels = []
560
+ all_lens = []
561
+
562
+ accelerator.wait_for_everyone()
563
+
564
+ if accelerator.is_main_process and len(all_generated_labels) > 0:
565
+ tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
566
+ tmp_labels = tmp_labels.map(
567
+ postprocess_dataset,
568
+ num_proc=data_args.preprocessing_num_workers, # this step is resource-consuming when many processes are used.
569
+ input_columns=["labels"],
570
+ desc="Postprocessing labeling",
571
+ )
572
+ save_codec_checkpoint(os.path.join(data_args.temporary_save_to_disk, split), tmp_labels, cur_step)
573
+ all_generated_labels = []
574
+ all_lens = []
575
+ accelerator.wait_for_everyone()
576
+
577
+ del all_generated_labels
578
+ accelerator.wait_for_everyone()
579
+
580
+ with accelerator.local_main_process_first():
581
+ tmp_labels = load_all_codec_checkpoints(os.path.join(data_args.temporary_save_to_disk, split)).select(
582
+ range(len(vectorized_datasets[split]))
583
+ )
584
+ logger.info(f"Concatenating {split}: {tmp_labels} with {vectorized_datasets[split]}")
585
+ vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1)
586
+
587
+ accelerator.free_memory()
588
+ del generate_labels, all_lens
589
+
590
+ with accelerator.local_main_process_first():
591
+ # NOTE: filtering is done at the end because in the `datasets` library, caching audio files is done after most operations
592
+ # caching audio files is time and disk-space consuming, so we want to avoid it at all costs, especially for large (>1Kh) audio datasets.
593
+ # That's also why we avoid concatenating the processed datasets (vectorized_datasets) with the audio column present in raw_datasets.
594
+
595
+ def is_audio_in_length_range(length):
596
+ return length > min_target_length and length < max_target_length
597
+
598
+ # keep only audio whose length is strictly between min_target_length and max_target_length
599
+ vectorized_datasets = vectorized_datasets.filter(
600
+ is_audio_in_length_range,
601
+ num_proc=num_workers,
602
+ input_columns=["target_length"],
603
+ )
604
+
605
+ if description_column_name is not None and data_args.max_description_token_length is not None:
606
+ with accelerator.local_main_process_first():
607
+ # keep only descriptions shorter than max_description_token_length
608
+ vectorized_datasets = vectorized_datasets.filter(
609
+ lambda x: len(x) < data_args.max_description_token_length,
610
+ num_proc=num_workers,
611
+ input_columns=["input_ids"],
612
+ )
613
+
614
+ if data_args.max_prompt_token_length is not None:
615
+ with accelerator.local_main_process_first():
616
+ # keep only prompts shorter than max_prompt_token_length
617
+ vectorized_datasets = vectorized_datasets.filter(
618
+ lambda x: len(x) < data_args.max_prompt_token_length,
619
+ num_proc=num_workers,
620
+ input_columns=["prompt_input_ids"],
621
+ )
622
+
623
+ if data_args.save_to_disk is not None and not dataset_was_precomputed:
624
+ if accelerator.is_main_process:
625
+ vectorized_datasets.save_to_disk(
626
+ data_args.save_to_disk,
627
+ num_proc=min(data_args.preprocessing_num_workers, len(vectorized_datasets["eval"]) - 1),
628
+ )
629
+ accelerator.wait_for_everyone()
630
+ logger.info(f"Dataset saved at {data_args.save_to_disk}")
631
+
632
+ audio_max_length = None
633
+ if padding == "max_length":
634
+ audio_max_length = max(vectorized_datasets["train"]["target_length"])
635
+ with accelerator.local_main_process_first():
636
+ max_sample = vectorized_datasets["train"].filter(
637
+ lambda x: x == audio_max_length,
638
+ num_proc=num_workers,
639
+ input_columns=["target_length"],
640
+ )
641
+ audio_max_length = max([len(l[0]) for l in max_sample["labels"]])
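+ # Note: `target_length` above is measured in raw audio samples; this second pass looks up the codec-token length of the longest sample so the collator can pad label sequences to a fixed size.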
642
+
643
+ if description_column_name is not None and data_args.max_description_token_length is not None:
644
+ with accelerator.local_main_process_first():
645
+ # keep only descriptions shorter than max_description_token_length
646
+ vectorized_datasets = vectorized_datasets.filter(
647
+ lambda x: len(x) < data_args.max_description_token_length,
648
+ num_proc=num_workers,
649
+ input_columns=["input_ids"],
650
+ )
651
+
652
+ if data_args.max_prompt_token_length is not None:
653
+ with accelerator.local_main_process_first():
654
+ # keep only prompts shorter than max_prompt_token_length
655
+ vectorized_datasets = vectorized_datasets.filter(
656
+ lambda x: len(x) < data_args.max_prompt_token_length,
657
+ num_proc=num_workers,
658
+ input_columns=["prompt_input_ids"],
659
+ )
660
+
661
+ if training_args.group_by_length:
662
+ # apply a simple heuristic to take into account audio and text lengths
663
+ def add_target_lengths(target_length, prompt, description):
664
+ return {"target_length": target_length + len(prompt) + len(description)}
665
+
666
+ with accelerator.local_main_process_first():
667
+ vectorized_datasets = vectorized_datasets.map(
668
+ add_target_lengths,
669
+ num_proc=num_workers,
670
+ input_columns=["target_length", "prompt_input_ids", "input_ids"],
671
+ )
672
+
673
+ # for large datasets it is advised to run the preprocessing on a
674
+ # single machine first with ``args.preprocessing_only`` since there will most likely
675
+ # be a timeout when running the script in distributed mode.
676
+ # In a second step ``args.preprocessing_only`` can then be set to `False` to load the
677
+ # cached dataset
678
+ if data_args.preprocessing_only and data_args.save_to_disk is None:
679
+ raise ValueError(
680
+ "`preprocessing_only=True` but `save_to_disk` is not set. The latter should indicate where to save the dataset locally."
681
+ )
682
+ elif data_args.preprocessing_only:
683
+ logger.info(f"Data preprocessing finished. Files save at {data_args.save_to_disk}")
684
+ return
685
+
686
+ # 6. Next, we can prepare the training.
687
+
688
+ # Let's use CLAP similarity and WER as our evaluation metrics,
689
+ def compute_metrics(
690
+ audios,
691
+ descriptions,
692
+ prompts,
693
+ device="cpu",
694
+ compute_clap_similarity_metric=False,
695
+ compute_noise_level_metric=False,
696
+ noise_level_to_compute_clean_wer=None,
697
+ ):
698
+ results = {}
699
+ input_ids = descriptions
700
+ texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
701
+ prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
702
+ audios = [a.float().cpu().numpy() for a in audios]
703
+
704
+ if compute_clap_similarity_metric:
705
+ clap_score = clap_similarity(
706
+ model_args.clap_model_name_or_path, texts, audios, device, input_sampling_rate=sampling_rate
707
+ )
708
+ results["clap"] = clap_score
709
+
710
+ si_sdr_measures = None
711
+ if compute_noise_level_metric:
712
+ si_sdr_measures = si_sdr(audios, device, input_sampling_rate=sampling_rate)
713
+
714
+ word_error, transcriptions, clean_word_error, noisy_word_error, percent_clean_samples = wer(
715
+ model_args.asr_model_name_or_path,
716
+ prompts,
717
+ audios,
718
+ device,
719
+ training_args.per_device_eval_batch_size,
720
+ sampling_rate,
721
+ noise_level_to_compute_clean_wer,
722
+ si_sdr_measures,
723
+ )
724
+ results["wer"] = word_error
725
+ if clean_word_error is not None:
726
+ results["clean_wer"] = clean_word_error
727
+ results["noisy_word_error"] = noisy_word_error
728
+ results["percent_clean_samples"] = percent_clean_samples
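+ # When SI-SDR is available, samples above `noise_level_to_compute_clean_wer` are treated as clean and their WER is reported separately from the noisy ones.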
729
+
730
+ return results, texts, prompts, audios, transcriptions, si_sdr_measures
731
+
732
+ # Define Training Schedule
733
+ # Store some constants
734
+ per_device_train_batch_size = int(training_args.per_device_train_batch_size)
735
+ train_batch_size = per_device_train_batch_size * accelerator.num_processes
736
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
737
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
738
+
739
+ if training_args.max_steps < 0:
740
+ num_epochs = int(training_args.num_train_epochs)
741
+ steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
742
+ total_train_steps = steps_per_epoch * num_epochs
743
+ elif training_args.max_steps > 0:
744
+ logger.info("max_steps is given, it will override any value given in num_train_epochs")
745
+ total_train_steps = int(training_args.max_steps)
746
+ # Setting a very large number of epochs so we go as many times as necessary over the iterator.
747
+ num_epochs = sys.maxsize
748
+ steps_per_epoch = total_train_steps
749
+
750
+ if training_args.eval_steps is None:
751
+ logger.info(f"eval_steps is not set, evaluating at the end of each epoch")
752
+ eval_steps = steps_per_epoch
753
+ else:
754
+ eval_steps = training_args.eval_steps
755
+
756
+ if training_args.eval_generation_steps is None:
757
+ eval_generation_steps = eval_steps
758
+ else:
759
+ eval_generation_steps = training_args.eval_generation_steps
760
+
761
+ # T5 doesn't support fp16
762
+ autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
763
+
764
+ # Define optimizer, LR scheduler, collator
765
+ optimizer = torch.optim.AdamW(
766
+ params=model.parameters(),
767
+ lr=training_args.learning_rate,
768
+ betas=(training_args.adam_beta1, training_args.adam_beta2),
769
+ eps=training_args.adam_epsilon,
770
+ weight_decay=training_args.weight_decay,
771
+ )
772
+
773
+ # LR scheduler gets stepped by `num_processes` each time -> account for this in warmup / total steps
774
+ lr_scheduler = get_scheduler(
775
+ name=training_args.lr_scheduler_type,
776
+ optimizer=optimizer,
777
+ num_warmup_steps=training_args.get_warmup_steps(total_train_steps) * accelerator.num_processes,
778
+ num_training_steps=total_train_steps * accelerator.num_processes,
779
+ )
780
+
781
+ # Instantiate custom data collator
782
+ data_collator = DataCollatorParlerTTSWithPadding(
783
+ prompt_tokenizer=prompt_tokenizer,
784
+ description_tokenizer=description_tokenizer,
785
+ pad_to_multiple_of=data_args.pad_to_multiple_of,
786
+ padding=padding,
787
+ prompt_max_length=data_args.max_prompt_token_length,
788
+ description_max_length=data_args.max_description_token_length,
789
+ audio_max_length=audio_max_length,
790
+ )
791
+
792
+ # Prepare everything with accelerate
793
+ model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
794
+
795
+ num_examples = total_train_steps * train_batch_size * gradient_accumulation_steps
796
+ logger.info("***** Running training *****")
797
+ logger.info(f" Num examples = {num_examples}")
798
+ logger.info(" Instantaneous batch size per device =" f" {per_device_train_batch_size}")
799
+ logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
800
+ logger.info(
801
+ f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
802
+ )
803
+ logger.info(f" Total optimization steps = {total_train_steps}")
804
+
805
+ # ======================== Training ================================
806
+ train_time = 0
807
+ train_start = time.time()
808
+ steps_trained_progress_bar = tqdm(
809
+ range(total_train_steps), desc="Train steps ... ", position=0, disable=not accelerator.is_local_main_process
810
+ )
811
+ continue_training = True
812
+ epochs_trained = 0
813
+ cur_step = 0
814
+
815
+ checkpoint = None
816
+ if training_args.resume_from_checkpoint is not None:
817
+ checkpoint = training_args.resume_from_checkpoint
818
+ elif last_checkpoint is not None:
819
+ checkpoint = last_checkpoint
820
+
821
+ if accelerator.is_main_process:
822
+ if training_args.push_to_hub:
823
+ api = HfApi(token=training_args.hub_token)
824
+
825
+ # Create repo (repo_name from args or inferred)
826
+ repo_name = training_args.hub_model_id
827
+ if repo_name is None:
828
+ repo_name = Path(training_args.output_dir).absolute().name
829
+ repo_id = api.create_repo(repo_name, exist_ok=True).repo_id
830
+
831
+ with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
832
+ if "wandb" not in gitignore:
833
+ gitignore.write("wandb\n")
834
+ elif training_args.output_dir is not None:
835
+ os.makedirs(training_args.output_dir, exist_ok=True)
836
+ accelerator.wait_for_everyone()
837
+
838
+ # Now save everything to be able to create a single processor later
839
+ # make sure all processes wait until data is saved
840
+ # only the main process saves them
841
+ if accelerator.is_main_process:
842
+ # save feature extractor, tokenizer and config
843
+ if (
844
+ model_args.prompt_tokenizer_name is None
845
+ and model_args.description_tokenizer_name
846
+ or (model_args.prompt_tokenizer_name == model_args.description_tokenizer_name)
847
+ ):
848
+ prompt_tokenizer.save_pretrained(training_args.output_dir)
849
+ else:
850
+ logger.warning(
851
+ f"Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer."
852
+ )
853
+ prompt_tokenizer.save_pretrained(training_args.output_dir)
854
+
855
+ feature_extractor.save_pretrained(training_args.output_dir)
856
+ config.save_pretrained(training_args.output_dir)
857
+ accelerator.wait_for_everyone()
858
+
859
+ if checkpoint is not None:
860
+ accelerator.load_state(checkpoint)
861
+ # Find num steps and epoch from saved state string pattern
862
+ pattern = r"checkpoint-(\d+)-epoch-(\d+)"
863
+ match = re.search(pattern, checkpoint)
864
+ cur_step = int(match.group(1))
865
+ epochs_trained = int(match.group(2))
866
+
867
+ logger.info(" Continuing training from checkpoint, will skip to saved global_step")
868
+ logger.info(f" Continuing training from epoch {epochs_trained}")
869
+ logger.info(f" Continuing training from global step {cur_step}")
870
+
871
+ steps_trained_progress_bar.update(cur_step)
872
+
873
+ for epoch in range(0, epochs_trained):
874
+ with accelerator.local_main_process_first():
875
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
876
+
877
+ if training_args.max_steps < 0:
878
+ # we know exactly the number of steps per epoch, so can skip through the required number of batches
879
+ resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
880
+ else:
881
+ # Currently we don't know how many steps we've taken in the current epoch
882
+ # So we just shuffle the dataset one extra time and start from a fresh epoch
883
+ # This is "good enough" for our purposes but not fully correct
884
+ resume_step = None
885
+ with accelerator.local_main_process_first():
886
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
887
+ else:
888
+ resume_step = None
889
+
890
+ gen_kwargs = {
891
+ "do_sample": model_args.do_sample,
892
+ "temperature": model_args.temperature,
893
+ "max_length": model_args.max_length,
894
+ # Because of the delayed pattern mask, generation might stop earlier because of unexpected behaviour
895
+ # on the first tokens of the codebooks that are delayed.
896
+ # Setting `min_new_tokens` fixes the issue.
897
+ "min_new_tokens": num_codebooks + 1,
898
+ }
899
+
900
+ # Define gradient update step fn
901
+ def train_step(
902
+ batch,
903
+ accelerator,
904
+ autocast_kwargs,
905
+ num_items_in_batch,
906
+ gradient_accumulation_steps,
907
+ ):
908
+ if mixed_precision == "fp16":
909
+ # fp16 doesn't work with T5-like models
910
+ with accelerator.autocast(autocast_handler=autocast_kwargs):
911
+ if training_args.parallel_mode.value != "distributed":
912
+ encoder_outputs = model.text_encoder(
913
+ input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
914
+ )
915
+ else:
916
+ encoder_outputs = model.module.text_encoder(
917
+ input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
918
+ )
919
+ # we optionally project last_hidden_state to avoid recomputing every time
920
+ encoder_hidden_states = encoder_outputs.last_hidden_state
921
+ if (
922
+ config.text_encoder.hidden_size != config.decoder.hidden_size
923
+ and config.decoder.cross_attention_hidden_size is None
924
+ ):
925
+ encoder_hidden_states = (
926
+ model.enc_to_dec_proj(encoder_hidden_states)
927
+ if training_args.parallel_mode.value != "distributed"
928
+ else model.module.enc_to_dec_proj(encoder_hidden_states)
929
+ )
930
+
931
+ if batch.get("attention_mask", None) is not None:
932
+ encoder_hidden_states = encoder_hidden_states * batch.get("attention_mask", None)[..., None]
933
+
934
+ encoder_outputs.last_hidden_state = encoder_hidden_states
935
+ batch["encoder_outputs"] = encoder_outputs
936
+
937
+ outputs = model(**batch, loss_reduction="sum")
938
+ # CE (data) loss
939
+ ce_loss = (outputs.loss * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch
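+ # The rescaling above compensates for the 1/gradient_accumulation_steps scaling applied inside `accelerator.backward` and for DDP averaging gradients across processes, so the effective gradient is a plain average over the `num_items_in_batch` label tokens.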
940
+
941
+ metrics = {"loss": ce_loss}
942
+
943
+ # per CE loss
944
+ per_codebook_losses = outputs.per_codebook_losses
945
+ metrics.update({f"codebook_{i}_loss": ((l * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch) for (i,l) in enumerate(per_codebook_losses)})
946
+ return ce_loss, metrics
947
+
948
+ # Define eval fn
949
+ def eval_step(
950
+ batch,
951
+ accelerator,
952
+ autocast_kwargs,
953
+ ):
954
+ eval_model = model if not training_args.torch_compile else model._orig_mod
955
+
956
+ if mixed_precision == "fp16":
957
+ # fp16 doesn't work with T5-like models
958
+ with accelerator.autocast(autocast_handler=autocast_kwargs):
959
+ if training_args.parallel_mode.value != "distributed":
960
+ encoder_outputs = model.text_encoder(
961
+ input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
962
+ )
963
+ else:
964
+ encoder_outputs = model.module.text_encoder(
965
+ input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
966
+ )
967
+ # we optionally project last_hidden_state to avoid recomputing every time
968
+ encoder_hidden_states = encoder_outputs.last_hidden_state
969
+ if (
970
+ config.text_encoder.hidden_size != config.decoder.hidden_size
971
+ and config.decoder.cross_attention_hidden_size is None
972
+ ):
973
+ encoder_hidden_states = (
974
+ model.enc_to_dec_proj(encoder_hidden_states)
975
+ if training_args.parallel_mode.value != "distributed"
976
+ else model.module.enc_to_dec_proj(encoder_hidden_states)
977
+ )
978
+
979
+ if batch.get("attention_mask", None) is not None:
980
+ encoder_hidden_states = encoder_hidden_states * batch.get("attention_mask", None)[..., None]
981
+
982
+ encoder_outputs.last_hidden_state = encoder_hidden_states
983
+ batch["encoder_outputs"] = encoder_outputs
984
+
985
+ with torch.no_grad():
986
+ outputs = eval_model(**batch)
987
+ # CE (data) loss
988
+ ce_loss = outputs.loss
989
+ metrics = {"loss": ce_loss}
990
+
991
+ # per CE loss
992
+ per_codebook_losses = outputs.per_codebook_losses
993
+ metrics.update({f"codebook_{i}_loss": l for (i,l) in enumerate(per_codebook_losses)})
994
+ return metrics
995
+
996
+ def generate_step(batch, accelerator):
997
+ batch.pop("decoder_attention_mask", None)
998
+ eval_model = accelerator.unwrap_model(model, keep_fp32_wrapper=True)
999
+ if training_args.torch_compile:
1000
+ # if the model is compiled, we use the original model because torch.compile is not compatible with .generate
1001
+ eval_model = model._orig_mod
1002
+
1003
+ # since we might have loaded the weights in fp32, we have to autocast to ensure FA2 weights are in half-precision.
1004
+ # with accelerator.autocast(autocast_handler=AutocastKwargs(enabled=(attn_implementation=="flash_attention_2"))):
1005
+ output_audios = eval_model.generate(**batch, **gen_kwargs)
1006
+ output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
1007
+ return output_audios
1008
+
1009
+ model.train()
1010
+
1011
+ total_batched_samples = resume_step if resume_step is not None else 0
1012
+ for epoch in range(epochs_trained, num_epochs):
1013
+ with accelerator.local_main_process_first():
1014
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1015
+ sampler = None
1016
+ if training_args.group_by_length:
1017
+ sampler = LengthGroupedSampler(train_batch_size, lengths=vectorized_datasets["train"]["target_length"])
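+ # LengthGroupedSampler batches samples of similar `target_length` together, which reduces padding inside each batch.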
1018
+ train_dataloader = DataLoader(
1019
+ vectorized_datasets["train"],
1020
+ collate_fn=data_collator,
1021
+ batch_size=per_device_train_batch_size,
1022
+ sampler=sampler,
1023
+ shuffle=not training_args.group_by_length,
1024
+ num_workers=training_args.dataloader_num_workers,
1025
+ pin_memory=training_args.dataloader_pin_memory,
1026
+ )
1027
+ train_dataloader = accelerator.prepare(train_dataloader)
1028
+ if hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDataset):
1029
+ train_dataloader.dataset.set_epoch(epoch)
1030
+
1031
+ if resume_step is not None:
1032
+ # Skip the first N batches in the dataloader when resuming from a checkpoint
1033
+ logger.info(f" Skip first {resume_step} batches")
1034
+ train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
1035
+ resume_step = None
1036
+ accelerator.wait_for_everyone()
1037
+
1038
+ # We chunk the epoch iterator into groups of `gradient_accumulation_steps` batches
1039
+ train_iterator = iter(train_dataloader)
1040
+ num_steps_in_epoch = len(train_dataloader)
1041
+ remainder = num_steps_in_epoch % gradient_accumulation_steps
1042
+ remainder = remainder if remainder != 0 else gradient_accumulation_steps
1043
+ total_updates = math.ceil(num_steps_in_epoch / gradient_accumulation_steps)
1044
+
1045
+ update_step = -1
1046
+ for _ in range(total_updates):
1047
+ update_step += 1
1048
+
1049
+ # preload the total batch per step
1050
+ batch_samples = []
1051
+ num_batches_in_step = gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
1052
+ for _ in range(num_batches_in_step):
1053
+ batch_samples += [next(train_iterator)]
1054
+
1055
+ # get the number of items in the batch - label tokens different from BOS, EOS and -100
1056
+ num_items_in_batch = sum([(batch["labels"].ne(audio_encoder_bos_token_id) | batch["labels"].ne(-100) | batch["labels"].ne(audio_encoder_eos_token_id)).sum((0,1))[0] for batch in batch_samples])
1057
+ num_items_in_batch = accelerator.gather(num_items_in_batch).sum().item()
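+ # `num_items_in_batch` is the global label-token count over all processes and the whole accumulation window; it is passed to `train_step` to normalise the summed CE loss.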
1058
+
1059
+ # losses = []
1060
+ for i,batch in enumerate(batch_samples):
1061
+ total_batched_samples += 1
1062
+ ctx = model.no_sync if (i < len(batch_samples) - 1 and accelerator.num_processes > 1) else contextlib.nullcontext
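+ # `model.no_sync()` skips the DDP gradient all-reduce for all but the last micro-batch, so gradients are synchronised only once per optimisation step.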
1063
+
1064
+ with ctx():
1065
+ loss, train_metric = train_step(batch, accelerator, autocast_kwargs, num_items_in_batch, gradient_accumulation_steps)
1066
+ accelerator.backward(loss)
1067
+ # losses.append(loss.detach())
1068
+
1069
+ grad_norm = accelerator.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
1070
+ optimizer.step()
1071
+ lr_scheduler.step()
1072
+ optimizer.zero_grad()
1073
+
1074
+ # The accelerator has performed an optimization step behind the scenes
1075
+ steps_trained_progress_bar.update(1)
1076
+ cur_step += 1
1077
+
1078
+ # losses = accelerator.gather(sum(losses)).sum().item() / (accelerator.num_processes * gradient_accumulation_steps)
1079
+
1080
+ if cur_step % training_args.logging_steps == 0:
1081
+ steps_trained_progress_bar.write(
1082
+ f"Step... ({cur_step} / {total_train_steps} | Loss:"
1083
+ f" {train_metric['loss']}, Learning Rate:"
1084
+ f" {lr_scheduler.get_last_lr()[0]})"
1085
+ )
1086
+ train_metric["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
1087
+ log_metric(
1088
+ accelerator,
1089
+ metrics=train_metric,
1090
+ learning_rate=lr_scheduler.get_last_lr()[0],
1091
+ train_time=train_time + time.time() - train_start,
1092
+ step=cur_step,
1093
+ epoch=epoch,
1094
+ prefix="train",
1095
+ )
1096
+
1097
+ # save checkpoint and weights after each save_steps and at the end of training
1098
+ if (cur_step % training_args.save_steps == 0) or cur_step == total_train_steps:
1099
+ intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
1100
+ # safe_serialization=False to avoid shared tensors saving issue (TODO(YL): it's a temporary fix)
1101
+ # https://github.com/huggingface/transformers/issues/27293#issuecomment-1872560074
1102
+ accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
1103
+ accelerator.wait_for_everyone()
1104
+ if accelerator.is_main_process:
1105
+ rotate_checkpoints(
1106
+ training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger
1107
+ )
1108
+
1109
+ if cur_step == total_train_steps:
1110
+ # un-wrap student model for save
1111
+ unwrapped_model = accelerator.unwrap_model(model)
1112
+ unwrapped_model.save_pretrained(training_args.output_dir)
1113
+
1114
+ if training_args.push_to_hub:
1115
+ api.upload_folder(
1116
+ repo_id=repo_id,
1117
+ folder_path=training_args.output_dir,
1118
+ commit_message=f"Saving train state of step {cur_step}",
1119
+ run_as_future=True,
1120
+ )
1121
+ accelerator.wait_for_everyone()
1122
+
1123
+ if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
1124
+ train_time += time.time() - train_start
1125
+ # ======================== Evaluating ==============================
1126
+ model.eval()
1127
+ eval_metrics = []
1128
+ eval_preds = []
1129
+ eval_descriptions = []
1130
+ eval_prompts = []
1131
+ eval_start = time.time()
1132
+
1133
+ # release training input batch
1134
+ batch = release_memory(batch)
1135
+
1136
+ validation_dataloader = DataLoader(
1137
+ vectorized_datasets["eval"],
1138
+ collate_fn=data_collator,
1139
+ batch_size=per_device_eval_batch_size,
1140
+ drop_last=False,
1141
+ num_workers=training_args.eval_dataloader_num_workers,
1142
+ pin_memory=training_args.dataloader_pin_memory,
1143
+ )
1144
+ validation_dataloader = accelerator.prepare(validation_dataloader)
1145
+
1146
+ for batch in tqdm(
1147
+ validation_dataloader,
1148
+ desc=f"Evaluating - Inference ...",
1149
+ position=2,
1150
+ disable=not accelerator.is_local_main_process,
1151
+ ):
1152
+ # Model forward
1153
+ eval_metric = eval_step(batch, accelerator, autocast_kwargs)
1154
+ eval_metric = accelerator.gather_for_metrics(eval_metric)
1155
+ eval_metric = {key: val.unsqueeze(0) if val.ndim == 0 else val for (key,val) in eval_metric.items()}
1156
+ eval_metrics.append(eval_metric)
1157
+
1158
+ if training_args.predict_with_generate and (cur_step % eval_generation_steps == 0 or cur_step == total_train_steps):
1159
+ validation_dataloader = DataLoader(
1160
+ vectorized_datasets["eval"],
1161
+ collate_fn=data_collator,
1162
+ batch_size=per_device_eval_batch_size,
1163
+ drop_last=False,
1164
+ num_workers=training_args.eval_dataloader_num_workers,
1165
+ pin_memory=training_args.dataloader_pin_memory,
1166
+ )
1167
+ validation_dataloader = accelerator.prepare(validation_dataloader)
1168
+ # generation
1169
+ for batch in tqdm(
1170
+ validation_dataloader,
1171
+ desc=f"Evaluating - Generation ...",
1172
+ position=2,
1173
+ disable=not accelerator.is_local_main_process,
1174
+ ):
1175
+ generated_audios = generate_step(batch, accelerator)
1176
+ # Gather all predictions and targets
1177
+ generated_audios, input_ids, prompts = accelerator.pad_across_processes(
1178
+ (generated_audios, batch["input_ids"], batch["prompt_input_ids"]), dim=1, pad_index=0
1179
+ )
1180
+ generated_audios, input_ids, prompts = accelerator.gather_for_metrics(
1181
+ (generated_audios, input_ids, prompts)
1182
+ )
1183
+ eval_preds.extend(generated_audios.to("cpu"))
1184
+ eval_descriptions.extend(input_ids.to("cpu"))
1185
+ eval_prompts.extend(prompts.to("cpu"))
1186
+
1187
+ eval_time = time.time() - eval_start
1188
+ # normalize eval metrics
1189
+ eval_metrics = {
1190
+ key: torch.mean(torch.cat([d[key] for d in eval_metrics])).to("cpu") for key in eval_metrics[0]
1191
+ }
1192
+
1193
+ # compute metrics
1194
+ metrics_desc = ""
1195
+ if training_args.predict_with_generate and (cur_step % eval_generation_steps == 0 or cur_step == total_train_steps):
1196
+ if accelerator.is_local_main_process:
1197
+ (
1198
+ metric_values,
1199
+ pred_descriptions,
1200
+ pred_prompts,
1201
+ audios,
1202
+ transcriptions,
1203
+ si_sdr_measures,
1204
+ ) = compute_metrics(
1205
+ eval_preds,
1206
+ eval_descriptions,
1207
+ eval_prompts,
1208
+ accelerator.device,
1209
+ training_args.compute_clap_similarity_metric,
1210
+ training_args.compute_noise_level_metric,
1211
+ training_args.noise_level_to_compute_clean_wer,
1212
+ )
1213
+ eval_metrics.update(metric_values)
1214
+ metrics_desc = " ".join([f"Eval {key}: {value} |" for key, value in metric_values.items()])
1215
+ if "wandb" in training_args.report_to:
1216
+ log_pred(
1217
+ accelerator,
1218
+ pred_descriptions,
1219
+ pred_prompts,
1220
+ transcriptions,
1221
+ audios,
1222
+ si_sdr_measures,
1223
+ sampling_rate=sampling_rate,
1224
+ step=cur_step,
1225
+ prefix="eval",
1226
+ )
1227
+ accelerator.wait_for_everyone()
1228
+
1229
+ # Print metrics and update progress bar
1230
+ if accelerator.is_local_main_process:
1231
+ steps_trained_progress_bar.write(
1232
+ f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |"
1233
+ f" {metrics_desc})"
1234
+ )
1235
+
1236
+ log_metric(
1237
+ accelerator,
1238
+ metrics=eval_metrics,
1239
+ train_time=eval_time,
1240
+ step=cur_step,
1241
+ epoch=epoch,
1242
+ prefix="eval",
1243
+ )
1244
+
1245
+ # release eval batches and metrics
1246
+ eval_metrics, eval_preds, eval_descriptions, eval_prompts, batch, eval_metric = release_memory(
1247
+ eval_metrics, eval_preds, eval_descriptions, eval_prompts, batch, eval_metric
1248
+ )
1249
+ if training_args.predict_with_generate and (cur_step % eval_generation_steps == 0 or cur_step == total_train_steps):
1250
+ generated_audios, input_ids, prompts = release_memory(generated_audios, input_ids, prompts)
1251
+
1252
+ # train mode
1253
+ model.train()
1254
+
1255
+ # flush the train metrics
1256
+ train_start = time.time()
1257
+
1258
+ # break condition
1259
+ if cur_step == total_train_steps:
1260
+ continue_training = False
1261
+ break
1262
+
1263
+ if not continue_training:
1264
+ break
1265
+
1266
+ accelerator.end_training()
1267
+
1268
+
1269
+ if __name__ == "__main__":
1270
+ main()
capspeech/ar/training/finetune_capttsse.py ADDED
@@ -0,0 +1,1267 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """ Train Parler-TTS using 🤗 Accelerate"""
18
+
19
+ import logging
20
+ import os
21
+ import re
22
+ import sys
23
+ import time
24
+ import math
25
+ import contextlib
26
+ from multiprocess import set_start_method
27
+ from datetime import timedelta
28
+ import inspect
29
+ from tqdm import tqdm
30
+ from pathlib import Path
31
+ import wandb
32
+
33
+ import torch
34
+ from torch.utils.data import DataLoader
35
+
36
+ import datasets
37
+ from datasets import DatasetDict, Dataset, IterableDataset, concatenate_datasets
38
+
39
+ from huggingface_hub import HfApi
40
+
41
+ import transformers
42
+ from transformers import AutoFeatureExtractor, AutoTokenizer, HfArgumentParser
43
+ from transformers.trainer_pt_utils import LengthGroupedSampler
44
+ from transformers.optimization import get_scheduler
45
+ from transformers.utils import send_example_telemetry
46
+
47
+
48
+ from accelerate import Accelerator, skip_first_batches
49
+ from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin, DistributedDataParallelKwargs
50
+ from accelerate.utils.memory import release_memory
51
+
52
+ from parler_tts import (
53
+ ParlerTTSConfig,
54
+ ParlerTTSForConditionalGeneration,
55
+ build_delay_pattern_mask,
56
+ )
57
+
58
+ from training.utils import (
59
+ get_last_checkpoint,
60
+ rotate_checkpoints,
61
+ log_pred,
62
+ log_metric,
63
+ load_all_codec_checkpoints,
64
+ save_codec_checkpoint,
65
+ get_last_codec_checkpoint_step,
66
+ )
67
+ from training.arguments_capttsse import ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments
68
+ from training.data_capttsse import load_multiple_datasets, DataCollatorParlerTTSWithPadding, DataCollatorEncodecWithPadding
69
+ from training.eval import clap_similarity, wer, si_sdr
70
+
71
+ logger = logging.getLogger(__name__)
72
+
73
+
74
+ def main():
75
+ # See all possible arguments in src/transformers/training_args.py
76
+ # or by passing the --help flag to this script.
77
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
78
+
79
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments))
80
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
81
+ # If we pass only one argument to the script and it's the path to a json file,
82
+ # let's parse it to get our arguments.
83
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
84
+ else:
85
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
86
+
87
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
88
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
89
+ send_example_telemetry("run_parler_tts", model_args, data_args)
90
+
91
+ if data_args.wandb_key is not None:
92
+ wandb.login(key=data_args.wandb_key)
93
+
94
+ if training_args.dtype == "float16":
95
+ mixed_precision = "fp16"
96
+ torch_dtype = torch.float16
97
+ elif training_args.dtype == "bfloat16":
98
+ mixed_precision = "bf16"
99
+ torch_dtype = torch.bfloat16
100
+ else:
101
+ mixed_precision = "no"
102
+ torch_dtype = torch.float32
103
+
104
+ if data_args.pad_to_max_length and (
105
+ data_args.max_duration_in_seconds is None
106
+ or data_args.max_prompt_token_length is None
107
+ or data_args.max_description_token_length is None
108
+ ):
109
+ raise ValueError(
110
+ "`pad_to_max_length` is `True` but one of the following parameters has not been set: `max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`"
111
+ )
112
+
113
+ padding = "max_length" if data_args.pad_to_max_length else "longest"
114
+
115
+ ####### A. Preparation
116
+ kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=120)), DistributedDataParallelKwargs(find_unused_parameters=False)]
117
+
118
+ accelerator = Accelerator(
119
+ gradient_accumulation_steps=training_args.gradient_accumulation_steps,
120
+ mixed_precision=mixed_precision,
121
+ log_with=training_args.report_to,
122
+ project_dir=training_args.output_dir,
123
+ kwargs_handlers=kwargs_handlers,
124
+ )
125
+
126
+ accelerator.init_trackers(
127
+ project_name=data_args.wandb_project,
128
+ config={
129
+ "learning_rate": training_args.learning_rate,
130
+ "model_name_or_path": model_args.model_name_or_path,
131
+ "num_train_epochs": training_args.num_train_epochs,
132
+ "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
133
+ "per_device_train_batch_size": training_args.per_device_train_batch_size,
134
+ "global_batch_size": training_args.per_device_train_batch_size * accelerator.num_processes,
135
+ "mixed_precision": mixed_precision,
136
+ "lr_scheduler_type": training_args.lr_scheduler_type,
137
+ "warmup_steps": training_args.warmup_steps,
138
+ "freeze_text_encoder": model_args.freeze_text_encoder,
139
+ "max_duration_in_seconds": data_args.max_duration_in_seconds,
140
+ "weight_decay": training_args.weight_decay,
141
+ "adam_beta1": training_args.adam_beta1,
142
+ "adam_beta2": training_args.adam_beta2,
143
+ "temperature": model_args.temperature,
144
+ },
145
+ init_kwargs={"wandb": {"name": data_args.wandb_run_name}} if data_args.wandb_run_name else {},
146
+ )
147
+
148
+ # Detecting last checkpoint and eventually continue from last checkpoint
149
+ last_checkpoint = None
150
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
151
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
152
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
153
+ raise ValueError(
154
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
155
+ "Use --overwrite_output_dir to overcome."
156
+ )
157
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
158
+ logger.info(
159
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
160
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
161
+ )
162
+
163
+ # Setup logging
164
+ logging.basicConfig(
165
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
166
+ datefmt="%m/%d/%Y %H:%M:%S",
167
+ handlers=[logging.StreamHandler(sys.stdout)],
168
+ )
169
+ logger.setLevel(logging.INFO if accelerator.is_main_process else logging.WARN)
170
+
171
+ # Log a small summary on each process
172
+ logger.warning(
173
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
174
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
175
+ )
176
+
177
+ # Set the verbosity to info of the Transformers logger (on main process only)
178
+ if accelerator.is_local_main_process:
179
+ datasets.utils.logging.set_verbosity_warning()
180
+ transformers.utils.logging.set_verbosity_info()
181
+ else:
182
+ datasets.utils.logging.set_verbosity_error()
183
+ transformers.utils.logging.set_verbosity_error()
184
+
185
+ logger.info("Training/evaluation parameters %s", training_args)
186
+
187
+ # Set seed before initializing model.
188
+ set_seed(training_args.seed)
189
+ num_workers = data_args.preprocessing_num_workers
190
+
191
+ # 1. First, let's instantiate the feature extractor, tokenizers and model
192
+ # Note for distributed training, the .from_pretrained methods guarantee that only
193
+ # one local process can concurrently download model & vocab.
194
+
195
+ # load feature extractor
196
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
197
+ model_args.feature_extractor_name or model_args.model_name_or_path,
198
+ cache_dir=model_args.cache_dir,
199
+ token=data_args.token,
200
+ trust_remote_code=data_args.trust_remote_code,
201
+ )
202
+ sampling_rate = feature_extractor.sampling_rate
203
+
204
+ # load prompt tokenizer
205
+ prompt_tokenizer = AutoTokenizer.from_pretrained(
206
+ model_args.prompt_tokenizer_name or model_args.description_tokenizer_name or model_args.model_name_or_path,
207
+ cache_dir=model_args.cache_dir,
208
+ token=data_args.token,
209
+ trust_remote_code=data_args.trust_remote_code,
210
+ use_fast=model_args.use_fast_tokenizer,
211
+ padding_side=model_args.prompt_padding_side,
212
+ )
213
+
214
+ # load description tokenizer
215
+ description_tokenizer = AutoTokenizer.from_pretrained(
216
+ model_args.description_tokenizer_name or model_args.model_name_or_path,
217
+ cache_dir=model_args.cache_dir,
218
+ token=data_args.token,
219
+ trust_remote_code=data_args.trust_remote_code,
220
+ use_fast=model_args.use_fast_tokenizer,
221
+ )
222
+
223
+ if model_args.use_fast_tokenizer:
224
+ logger.warning(
225
+ "Disabling fast tokenizer warning: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3231-L3235"
226
+ )
227
+ prompt_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
228
+ description_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
229
+
230
+ # 2. Now, let's load the dataset
231
+
232
+ if data_args.save_to_disk is not None:
233
+ os.makedirs(data_args.save_to_disk, exist_ok=True)
234
+
235
+ # assume that the dataset has been saved to `save_to_disk` if the latter is not empty
236
+ dataset_was_precomputed = len(os.listdir(data_args.save_to_disk)) > 0
237
+ if dataset_was_precomputed:
238
+ with accelerator.local_main_process_first():
239
+ vectorized_datasets = datasets.load_from_disk(data_args.save_to_disk)
240
+ else:
241
+ raw_datasets = DatasetDict()
242
+
243
+ columns_to_keep = {
244
+ "target_audio_column_name": data_args.target_audio_column_name,
245
+ "prompt_column_name": data_args.prompt_column_name,
246
+ "source": data_args.source_column_name,
247
+ }
248
+ if data_args.description_column_name is not None:
249
+ columns_to_keep["description_column_name"] = data_args.description_column_name
250
+
251
+ if training_args.do_train:
252
+ raw_datasets["train"] = load_multiple_datasets(
253
+ accelerator,
254
+ data_args.train_dataset_name,
255
+ splits=data_args.train_split_name,
256
+ dataset_samples=data_args.train_dataset_samples,
257
+ seed=training_args.seed,
258
+ cache_dir=model_args.cache_dir,
259
+ num_proc=data_args.preprocessing_num_workers,
260
+ id_column_name=data_args.id_column_name,
261
+ columns_to_keep=columns_to_keep.values(),
262
+ prompt_column_name=data_args.prompt_column_name,
263
+ audio_column_name=data_args.target_audio_column_name,
264
+ sampling_rate=sampling_rate,
265
+ logger=logger,
266
+ librittsrmix_dir=data_args.librittsrmix_dir,
267
+ # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
268
+ )
269
+
270
+ for key in columns_to_keep:
271
+ if columns_to_keep[key] not in raw_datasets["train"].column_names:
272
+ raise ValueError(
273
+ f"--{key} '{columns_to_keep[key]}' not found in dataset '{data_args.train_dataset_name}'."
274
+ f" Make sure to set `--{key}` to the correct audio column - one of"
275
+ f" {', '.join(raw_datasets['train'].column_names)}."
276
+ )
277
+
278
+ if data_args.max_train_samples is not None:
279
+ raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
280
+
281
+ if training_args.do_eval:
282
+ raw_datasets["eval"] = load_multiple_datasets(
283
+ accelerator,
284
+ data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
285
+ splits=data_args.eval_split_name,
286
+ cache_dir=model_args.cache_dir,
287
+ num_proc=data_args.preprocessing_num_workers,
288
+ id_column_name=data_args.id_column_name,
289
+ columns_to_keep=columns_to_keep.values(),
290
+ prompt_column_name=data_args.prompt_column_name,
291
+ audio_column_name=data_args.target_audio_column_name,
292
+ sampling_rate=sampling_rate,
293
+ logger=logger,
294
+ librittsrmix_dir=data_args.librittsrmix_dir,
295
+ # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
296
+ )
297
+
298
+ if data_args.max_eval_samples is not None:
299
+ with accelerator.local_main_process_first():
300
+ raw_datasets["eval"] = (
301
+ raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
302
+ )
303
+
304
+ # 3. Next, let's load the config.
305
+ config = ParlerTTSConfig.from_pretrained(
306
+ model_args.model_name_or_path,
307
+ cache_dir=model_args.cache_dir,
308
+ token=data_args.token,
309
+ trust_remote_code=data_args.trust_remote_code,
310
+ )
311
+
312
+ if training_args.codebook_weights is not None and len(training_args.codebook_weights) != config.decoder.num_codebooks:
313
+ raise ValueError(f"`codebook_weights` has length {len(training_args.codebook_weights)} when it should be of length {config.decoder.num_codebooks}.")
314
+
315
+ # update pad token id and decoder_start_token_id
316
+ config.decoder.update(
317
+ {
318
+ "cross_attention_implementation_strategy": model_args.cross_attention_implementation_strategy
319
+ if model_args.cross_attention_implementation_strategy is not None
320
+ else None,
321
+ "codebook_weights": training_args.codebook_weights if training_args.codebook_weights is not None else config.decoder.codebook_weights
322
+ }
323
+ )
324
+ config.update(
325
+ {
326
+ "pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id,
327
+ "decoder_start_token_id": model_args.decoder_start_token_id
328
+ if model_args.decoder_start_token_id is not None
329
+ else config.decoder_start_token_id,
330
+ }
331
+ )
332
+
333
+ with open("events.txt", "r") as f:
334
+ events = [line.strip() for line in f]
335
+ events = ["<"+event.lower().replace(" ", "_")+">" for event in events]
336
+ events.append("<B_start>")
337
+ events.append("<B_end>")
338
+ events.append("<I_start>")
339
+ events.append("<I_end>")
340
+
341
+ special_tokens = {"additional_special_tokens": events}
342
+ prompt_tokenizer.add_special_tokens(special_tokens)
343
+ description_tokenizer.add_special_tokens(special_tokens)
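+ # Each entry of events.txt becomes a special token (e.g. "<door_slam>", purely illustrative), alongside <B_start>/<B_end>/<I_start>/<I_end>, which presumably delimit sound-event spans; registering them as special tokens keeps each event tag as a single token in prompts and descriptions.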
344
+ padded_vocab_size = ((len(prompt_tokenizer) + 127) // 128) * 128
345
+ config.vocab_size = padded_vocab_size
346
+
347
+ # create model
348
+ model = ParlerTTSForConditionalGeneration.from_pretrained(
349
+ model_args.model_name_or_path,
350
+ ignore_mismatched_sizes=True,
351
+ cache_dir=model_args.cache_dir,
352
+ config=config,
353
+ token=data_args.token,
354
+ trust_remote_code=data_args.trust_remote_code,
355
+ attn_implementation={"decoder": model_args.attn_implementation, "text_encoder": "eager"},
356
+ )
357
+ model.text_encoder.resize_token_embeddings(padded_vocab_size)
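+ # The vocabulary is rounded up to a multiple of 128 (illustratively, ~32,100 T5 tokens plus ~400 event tokens pad up to 32,512) so the resized embedding matrix keeps GPU-friendly dimensions; `ignore_mismatched_sizes=True` lets the pretrained checkpoint load despite the larger embedding table.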
358
+
359
+ # enable gradient checkpointing if necessary
360
+ if training_args.gradient_checkpointing:
361
+ model.gradient_checkpointing_enable()
362
+
363
+ # 4. Now we preprocess the datasets including loading the audio, resampling and normalization
364
+ # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
365
+ # so that we just need to set the correct target sampling rate and normalize the input
366
+ # via the `feature_extractor`
367
+
368
+ # derive max & min input length for sample rate & max duration
369
+ sampling_rate = feature_extractor.sampling_rate
370
+ max_target_length = int(data_args.max_duration_in_seconds * sampling_rate)
371
+ min_target_length = int(data_args.min_duration_in_seconds * sampling_rate)
372
+ target_audio_column_name = data_args.target_audio_column_name
373
+ description_column_name = data_args.description_column_name
374
+ prompt_column_name = data_args.prompt_column_name
375
+ feature_extractor_input_name = feature_extractor.model_input_names[0]
376
+ audio_encoder_pad_token_id = config.decoder.pad_token_id
377
+ audio_encoder_eos_token_id = config.decoder.eos_token_id
378
+ audio_encoder_bos_token_id = model.generation_config.decoder_start_token_id
379
+ max_length = model.generation_config.max_length
380
+ num_codebooks = model.decoder.config.num_codebooks
381
+ bandwidth = model_args.bandwidth
382
+ attn_implementation = model_args.attn_implementation
383
+
384
+ # Freeze Encoders
385
+ model.freeze_encoders(model_args.freeze_text_encoder)
386
+
387
+ # Test all gather - used for warmup and avoiding timeouts
388
+ logger.debug(str(accelerator.process_index), main_process_only=False, in_order=True)
389
+ test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
390
+ gathered_tensor = accelerator.gather(test_tensor)
391
+ print("gathered_tensor", gathered_tensor)
392
+ accelerator.wait_for_everyone()
393
+
394
+ if not dataset_was_precomputed:
395
+ # Filter on text length
396
+ if description_column_name is not None and data_args.max_text_length is not None:
397
+ with accelerator.local_main_process_first():
398
+ # filter description that is shorter than max_text_length
399
+ raw_datasets = raw_datasets.filter(
400
+ lambda x: len(x) < data_args.max_text_length,
401
+ num_proc=num_workers,
402
+ input_columns=[description_column_name],
403
+ )
404
+
405
+ # Preprocessing the dataset.
406
+ # We need to tokenize the texts.
407
+ def pass_through_processors(description, prompt):
408
+ batch = {}
409
+
410
+ batch["input_ids"] = description_tokenizer(description.strip())["input_ids"]
411
+ batch["prompt_input_ids"] = prompt_tokenizer(prompt.strip())["input_ids"]
412
+
413
+ return batch
414
+
415
+ with accelerator.local_main_process_first():
416
+ # this is a trick to avoid rewriting the entire audio column, which takes ages
417
+ vectorized_datasets = raw_datasets.map(
418
+ pass_through_processors,
419
+ remove_columns=next(iter(raw_datasets.values())).column_names,
420
+ input_columns=[description_column_name, prompt_column_name],
421
+ num_proc=num_workers,
422
+ desc="preprocess datasets",
423
+ )
424
+
425
+ # We use Accelerate to perform distributed inference
426
+ # T5 doesn't support fp16
427
+ autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
428
+
429
+ # Now we encode the audio labels with encodec.
430
+ ####### B. Encode audio
431
+
432
+ logger.info("*** Encode target audio with encodec ***")
433
+
434
+ # no need to prepare audio_decoder because used for inference without mixed precision
435
+ # see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
436
+ if training_args.torch_compile:
437
+ audio_decoder = accelerator.prepare_model(model.audio_encoder, evaluation_mode=True)
438
+ else:
439
+ audio_decoder = model.audio_encoder
440
+
441
+ encoder_data_collator = DataCollatorEncodecWithPadding(
442
+ feature_extractor,
443
+ audio_column_name=target_audio_column_name,
444
+ librittsrmix_dir=data_args.librittsrmix_dir,
445
+ feature_extractor_input_name=feature_extractor_input_name,
446
+ max_length=max_target_length,
447
+ padding=padding,
448
+ )
449
+ encoder_signature = set(inspect.signature(audio_decoder.forward).parameters)
450
+
451
+ def apply_audio_decoder(batch):
452
+ len_audio = batch.pop("len_audio")
453
+ audio_decoder.to(batch["input_values"].device).eval()
454
+ if bandwidth is not None:
455
+ batch["bandwidth"] = bandwidth
456
+ elif "num_quantizers" in encoder_signature:
457
+ batch["num_quantizers"] = num_codebooks
458
+ elif "num_codebooks" in encoder_signature:
459
+ batch["num_codebooks"] = num_codebooks
460
+ elif "n_quantizers" in encoder_signature:
461
+ batch["n_quantizers"] = num_codebooks
462
+
463
+ with torch.no_grad():
464
+ labels = audio_decoder.encode(**batch)["audio_codes"]
465
+ output = {}
466
+ output["len_audio"] = len_audio
467
+ # (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks)
468
+ output["labels"] = labels.squeeze(0).transpose(1, 2)
469
+
470
+ # if `pad_to_max_length`, the maximum corresponding audio length of the current batch is max_duration*sampling_rate
471
+ max_length = len_audio.max() if padding != "max_length" else max_target_length
472
+ output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / max_length
473
+ return output
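+ # `ratio` converts raw-audio sample counts into codec-frame counts, so each padded label sequence can later be trimmed back to its true length with `int(ratio * len_audio)`.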
474
+
475
+ # (1, codebooks, seq_len) where seq_len=1
476
+ bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id
477
+
478
+ def postprocess_dataset(labels):
479
+ # (1, codebooks, seq_len)
480
+ labels = torch.tensor(labels).unsqueeze(0)
481
+ # add bos
482
+ labels = torch.cat([bos_labels, labels], dim=-1)
483
+
484
+ labels, delay_pattern_mask = build_delay_pattern_mask(
485
+ labels,
486
+ bos_token_id=audio_encoder_bos_token_id,
487
+ pad_token_id=audio_encoder_eos_token_id,
488
+ max_length=labels.shape[-1] + num_codebooks,
489
+ num_codebooks=num_codebooks,
490
+ )
491
+
492
+ # the first ids of the delay pattern mask are precisely labels, we use the rest of the labels mask
493
+ # to take care of EOS
494
+ # we want labels to look like this:
495
+ # - [B, a, b, E, E, E, E]
496
+ # - [B, B, c, d, E, E, E]
497
+ # - [B, B, B, e, f, E, E]
498
+ # - [B, B, B, B, g, h, E]
499
+ labels = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask)
500
+
501
+ # the first timestamp is associated to a row full of BOS, let's get rid of it
502
+ # we also remove the last timestamps (full of PAD)
503
+ output = {"labels": labels[:, 1:]}
504
+ return output
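+ # Illustrative example: with 4 codebooks and 2 codec frames (tokens a-h in the diagram above), the delayed labels span 7 columns; dropping the all-BOS first column leaves a (4, 6) label matrix.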
505
+
506
+ for split in vectorized_datasets:
507
+ data_loader = DataLoader(
508
+ raw_datasets[split],
509
+ batch_size=training_args.audio_encoder_per_device_batch_size,
510
+ collate_fn=encoder_data_collator,
511
+ num_workers=training_args.dataloader_num_workers,
512
+ pin_memory=True,
513
+ )
514
+ data_loader = accelerator.prepare(data_loader)
515
+ total_inference_steps = len(data_loader)
516
+
517
+ start_step = get_last_codec_checkpoint_step(os.path.join(data_args.temporary_save_to_disk, split))
518
+ accelerator.wait_for_everyone()
519
+ if start_step > 0:
520
+ logger.info(f"Resuming {split} from step {start_step}")
521
+ # efficiently skip the first n batches
522
+ start_step += 1
523
+ data_loader = skip_first_batches(data_loader, start_step)
524
+
525
+ all_generated_labels = []
526
+ all_lens = []
527
+ if start_step < total_inference_steps:
528
+ for i, batch in enumerate(tqdm(data_loader, disable=not accelerator.is_local_main_process)):
529
+ cur_step = start_step + i
530
+ generate_labels = apply_audio_decoder(batch)
531
+ generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0)
532
+ generate_labels = accelerator.gather_for_metrics(generate_labels)
533
+
534
+ if accelerator.is_main_process:
535
+ lab = generate_labels["labels"].cpu().transpose(1, 2).to(torch.int16)
536
+ rat = generate_labels["ratio"].cpu().squeeze(1)
537
+ lens = generate_labels["len_audio"].cpu().squeeze(1)
538
+ lab = [l[:, : int(ratio * length)] for (l, ratio, length) in zip(lab, rat, lens)]
539
+
540
+ all_generated_labels.extend(lab)
541
+ all_lens.extend(lens)
542
+
543
+ if ((cur_step + 1) % data_args.save_codec_steps == 0) or (
544
+ cur_step == total_inference_steps - 1
545
+ ):
546
+ tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
547
+ tmp_labels = tmp_labels.map(
548
+ postprocess_dataset,
549
+ num_proc=data_args.preprocessing_num_workers, # this step is resource-consuming with many worker processes.
550
+ input_columns=["labels"],
551
+ desc="Postprocessing labeling",
552
+ )
553
+ save_codec_checkpoint(
554
+ os.path.join(data_args.temporary_save_to_disk, split), tmp_labels, cur_step
555
+ )
556
+ all_generated_labels = []
557
+ all_lens = []
558
+
559
+ accelerator.wait_for_everyone()
560
+
561
+ if accelerator.is_main_process and len(all_generated_labels) > 0:
562
+ tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
563
+ tmp_labels = tmp_labels.map(
564
+ postprocess_dataset,
565
+ num_proc=data_args.preprocessing_num_workers, # this step is resource-consuming with many worker processes.
566
+ input_columns=["labels"],
567
+ desc="Postprocessing labeling",
568
+ )
569
+ save_codec_checkpoint(os.path.join(data_args.temporary_save_to_disk, split), tmp_labels, cur_step)
570
+ all_generated_labels = []
571
+ all_lens = []
572
+ accelerator.wait_for_everyone()
573
+
574
+ del all_generated_labels
575
+ accelerator.wait_for_everyone()
576
+
577
+ with accelerator.local_main_process_first():
578
+ tmp_labels = load_all_codec_checkpoints(os.path.join(data_args.temporary_save_to_disk, split)).select(
579
+ range(len(vectorized_datasets[split]))
580
+ )
581
+ logger.info(f"Concatenating {split}: {tmp_labels} with {vectorized_datasets[split]}")
582
+ vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1)
583
+
584
+ accelerator.free_memory()
585
+ del generate_labels, all_lens
586
+
587
+ with accelerator.local_main_process_first():
588
+ # NOTE: filtering is done at the end because in the `datasets` library, caching audio files is done after most operations
589
+ # caching audio files is time and disk-space consuming, so we want to avoid it at all costs, especially for large (>1Kh) audio datasets.
590
+ # That's also why we avoid concatenating the processed datasets (vectorized_datasets) with the audio column present in raw_datasets.
591
+
592
+ def is_audio_in_length_range(length):
593
+ return length > min_target_length and length < max_target_length
594
+
595
+ # keep only samples whose audio length is between min_target_length and max_target_length
596
+ vectorized_datasets = vectorized_datasets.filter(
597
+ is_audio_in_length_range,
598
+ num_proc=num_workers,
599
+ input_columns=["target_length"],
600
+ )
601
+
602
+ if description_column_name is not None and data_args.max_description_token_length is not None:
603
+ with accelerator.local_main_process_first():
604
+ # keep only descriptions shorter than max_description_token_length
605
+ vectorized_datasets = vectorized_datasets.filter(
606
+ lambda x: len(x) < data_args.max_description_token_length,
607
+ num_proc=num_workers,
608
+ input_columns=["input_ids"],
609
+ )
610
+
611
+ if data_args.max_prompt_token_length is not None:
612
+ with accelerator.local_main_process_first():
613
+ # keep only prompts shorter than max_prompt_token_length
614
+ vectorized_datasets = vectorized_datasets.filter(
615
+ lambda x: len(x) < data_args.max_prompt_token_length,
616
+ num_proc=num_workers,
617
+ input_columns=["prompt_input_ids"],
618
+ )
619
+
620
+ if data_args.save_to_disk is not None and not dataset_was_precomputed:
621
+ if accelerator.is_main_process:
622
+ vectorized_datasets.save_to_disk(
623
+ data_args.save_to_disk,
624
+ num_proc=min(data_args.preprocessing_num_workers, len(vectorized_datasets["eval"]) - 1),
625
+ )
626
+ accelerator.wait_for_everyone()
627
+ logger.info(f"Dataset saved at {data_args.save_to_disk}")
628
+
629
+ audio_max_length = None
630
+ if padding == "max_length":
631
+ audio_max_length = max(vectorized_datasets["train"]["target_length"])
632
+ with accelerator.local_main_process_first():
633
+ max_sample = vectorized_datasets["train"].filter(
634
+ lambda x: x == audio_max_length,
635
+ num_proc=num_workers,
636
+ input_columns=["target_length"],
637
+ )
638
+ audio_max_length = max([len(l[0]) for l in max_sample["labels"]])
639
+
640
+ if description_column_name is not None and data_args.max_description_token_length is not None:
641
+ with accelerator.local_main_process_first():
642
+ # keep only descriptions shorter than max_description_token_length
643
+ vectorized_datasets = vectorized_datasets.filter(
644
+ lambda x: len(x) < data_args.max_description_token_length,
645
+ num_proc=num_workers,
646
+ input_columns=["input_ids"],
647
+ )
648
+
649
+ if data_args.max_prompt_token_length is not None:
650
+ with accelerator.local_main_process_first():
651
+ # keep only prompts shorter than max_prompt_token_length
652
+ vectorized_datasets = vectorized_datasets.filter(
653
+ lambda x: len(x) < data_args.max_prompt_token_length,
654
+ num_proc=num_workers,
655
+ input_columns=["prompt_input_ids"],
656
+ )
657
+
658
+ if training_args.group_by_length:
659
+ # apply a simple heuristic to take into account audio and text lengths
660
+ def add_target_lengths(target_length, prompt, description):
661
+ return {"target_length": target_length + len(prompt) + len(description)}
662
+
663
+ with accelerator.local_main_process_first():
664
+ vectorized_datasets = vectorized_datasets.map(
665
+ add_target_lengths,
666
+ num_proc=num_workers,
667
+ input_columns=["target_length", "prompt_input_ids", "input_ids"],
668
+ )
669
+
670
+ # for large datasets it is advised to run the preprocessing on a
671
+ # single machine first with ``args.preprocessing_only`` since there will most likely
672
+ # be a timeout when running the script in distributed mode.
673
+ # In a second step ``args.preprocessing_only`` can then be set to `False` to load the
674
+ # cached dataset
675
+ if data_args.preprocessing_only and data_args.save_to_disk is None:
676
+ raise ValueError(
677
+ "`preprocessing_only=True` but `save_to_disk` is not set. The latter should indicates where to save the dataset locally."
678
+ )
679
+ elif data_args.preprocessing_only:
680
+ logger.info(f"Data preprocessing finished. Files save at {data_args.save_to_disk}")
681
+ return
682
+
683
+ # 6. Next, we can prepare the training.
684
+
685
+ # Let's use CLAP similarity and WER as our evaluation metrics.
686
+ def compute_metrics(
687
+ audios,
688
+ descriptions,
689
+ prompts,
690
+ device="cpu",
691
+ compute_clap_similarity_metric=False,
692
+ compute_noise_level_metric=False,
693
+ noise_level_to_compute_clean_wer=None,
694
+ ):
695
+ results = {}
696
+ input_ids = descriptions
697
+ texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
698
+ prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
699
+ audios = [a.float().cpu().numpy() for a in audios]
700
+
701
+ if compute_clap_similarity_metric:
702
+ clap_score = clap_similarity(
703
+ model_args.clap_model_name_or_path, texts, audios, device, input_sampling_rate=sampling_rate
704
+ )
705
+ results["clap"] = clap_score
706
+
707
+ si_sdr_measures = None
708
+ if compute_noise_level_metric:
709
+ si_sdr_measures = si_sdr(audios, device, input_sampling_rate=sampling_rate)
710
+
711
+ word_error, transcriptions, clean_word_error, noisy_word_error, percent_clean_samples = wer(
712
+ model_args.asr_model_name_or_path,
713
+ prompts,
714
+ audios,
715
+ device,
716
+ training_args.per_device_eval_batch_size,
717
+ sampling_rate,
718
+ noise_level_to_compute_clean_wer,
719
+ si_sdr_measures,
720
+ )
721
+ results["wer"] = word_error
722
+ if clean_word_error is not None:
723
+ results["clean_wer"] = clean_word_error
724
+ results["noisy_word_error"] = noisy_word_error
725
+ results["percent_clean_samples"] = percent_clean_samples
726
+
727
+ return results, texts, prompts, audios, transcriptions, si_sdr_measures
728
+
729
+ # Define Training Schedule
730
+ # Store some constants
731
+ per_device_train_batch_size = int(training_args.per_device_train_batch_size)
732
+ train_batch_size = per_device_train_batch_size * accelerator.num_processes
733
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
734
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
735
+
736
+ if training_args.max_steps < 0:
737
+ num_epochs = int(training_args.num_train_epochs)
738
+ steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
739
+ total_train_steps = steps_per_epoch * num_epochs
740
+ elif training_args.max_steps > 0:
741
+ logger.info("max_steps is given, it will override any value given in num_train_epochs")
742
+ total_train_steps = int(training_args.max_steps)
743
+ # Setting a very large number of epochs so we go as many times as necessary over the iterator.
744
+ num_epochs = sys.maxsize
745
+ steps_per_epoch = total_train_steps
746
+
747
+ if training_args.eval_steps is None:
748
+ logger.info(f"eval_steps is not set, evaluating at the end of each epoch")
749
+ eval_steps = steps_per_epoch
750
+ else:
751
+ eval_steps = training_args.eval_steps
752
+
753
+ if training_args.eval_generation_steps is None:
754
+ eval_generation_steps = eval_steps
755
+ else:
756
+ eval_generation_steps = training_args.eval_generation_steps
757
+
758
+ # T5 doesn't support fp16
759
+ autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
760
+
761
+ # Define optimizer, LR scheduler, collator
762
+ optimizer = torch.optim.AdamW(
763
+ params=model.parameters(),
764
+ lr=training_args.learning_rate,
765
+ betas=(training_args.adam_beta1, training_args.adam_beta2),
766
+ eps=training_args.adam_epsilon,
767
+ weight_decay=training_args.weight_decay,
768
+ )
769
+
770
+ # LR scheduler gets stepped by `num_processes` each time -> account for this in warmup / total steps
771
+ lr_scheduler = get_scheduler(
772
+ name=training_args.lr_scheduler_type,
773
+ optimizer=optimizer,
774
+ num_warmup_steps=training_args.get_warmup_steps(total_train_steps) * accelerator.num_processes,
775
+ num_training_steps=total_train_steps * accelerator.num_processes,
776
+ )
777
+
778
+ # Instantiate custom data collator
779
+ data_collator = DataCollatorParlerTTSWithPadding(
780
+ prompt_tokenizer=prompt_tokenizer,
781
+ description_tokenizer=description_tokenizer,
782
+ pad_to_multiple_of=data_args.pad_to_multiple_of,
783
+ padding=padding,
784
+ prompt_max_length=data_args.max_prompt_token_length,
785
+ description_max_length=data_args.max_description_token_length,
786
+ audio_max_length=audio_max_length,
787
+ )
788
+
789
+ # Prepare everything with accelerate
790
+ model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
791
+
792
+ num_examples = total_train_steps * train_batch_size * gradient_accumulation_steps
793
+ logger.info("***** Running training *****")
794
+ logger.info(f" Num examples = {num_examples}")
795
+ logger.info(" Instantaneous batch size per device =" f" {per_device_train_batch_size}")
796
+ logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
797
+ logger.info(
798
+ f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
799
+ )
800
+ logger.info(f" Total optimization steps = {total_train_steps}")
801
+
802
+ # ======================== Training ================================
803
+ train_time = 0
804
+ train_start = time.time()
805
+ steps_trained_progress_bar = tqdm(
806
+ range(total_train_steps), desc="Train steps ... ", position=0, disable=not accelerator.is_local_main_process
807
+ )
808
+ continue_training = True
809
+ epochs_trained = 0
810
+ cur_step = 0
811
+
812
+ checkpoint = None
813
+ if training_args.resume_from_checkpoint is not None:
814
+ checkpoint = training_args.resume_from_checkpoint
815
+ elif last_checkpoint is not None:
816
+ checkpoint = last_checkpoint
817
+
818
+ if accelerator.is_main_process:
819
+ if training_args.push_to_hub:
820
+ api = HfApi(token=training_args.hub_token)
821
+
822
+ # Create repo (repo_name from args or inferred)
823
+ repo_name = training_args.hub_model_id
824
+ if repo_name is None:
825
+ repo_name = Path(training_args.output_dir).absolute().name
826
+ repo_id = api.create_repo(repo_name, exist_ok=True).repo_id
827
+
828
+ with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
829
+ if "wandb" not in gitignore:
830
+ gitignore.write("wandb\n")
831
+ elif training_args.output_dir is not None:
832
+ os.makedirs(training_args.output_dir, exist_ok=True)
833
+ accelerator.wait_for_everyone()
834
+
835
+ # Now save everything to be able to create a single processor later
836
+ # make sure all processes wait until data is saved
837
+ # only the main process saves them
838
+ if accelerator.is_main_process:
839
+ # save feature extractor, tokenizer and config
840
+ if (
841
+ model_args.prompt_tokenizer_name is None
842
+ and model_args.description_tokenizer_name
843
+ or (model_args.prompt_tokenizer_name == model_args.description_tokenizer_name)
844
+ ):
845
+ prompt_tokenizer.save_pretrained(training_args.output_dir)
846
+ else:
847
+ logger.warning(
848
+ f"Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer."
849
+ )
850
+ prompt_tokenizer.save_pretrained(training_args.output_dir)
851
+
852
+ feature_extractor.save_pretrained(training_args.output_dir)
853
+ config.save_pretrained(training_args.output_dir)
854
+ accelerator.wait_for_everyone()
855
+
856
+ if checkpoint is not None:
857
+ accelerator.load_state(checkpoint)
858
+ # Find num steps and epoch from saved state string pattern
859
+ pattern = r"checkpoint-(\d+)-epoch-(\d+)"
860
+ match = re.search(pattern, checkpoint)
861
+ cur_step = int(match.group(1))
862
+ epochs_trained = int(match.group(2))
863
+
864
+ logger.info(" Continuing training from checkpoint, will skip to saved global_step")
865
+ logger.info(f" Continuing training from epoch {epochs_trained}")
866
+ logger.info(f" Continuing training from global step {cur_step}")
867
+
868
+ steps_trained_progress_bar.update(cur_step)
869
+
870
+ for epoch in range(0, epochs_trained):
871
+ with accelerator.local_main_process_first():
872
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
873
+
874
+ if training_args.max_steps < 0:
875
+ # we know exactly the number of steps per epoch, so can skip through the required number of batches
876
+ resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
877
+ else:
878
+ # Currently we don't know how many steps we've taken in the current epoch
879
+ # So we just shuffle the dataset one extra time and start from a fresh epoch
880
+ # This is "good enough" for our purposes but not fully correct
881
+ resume_step = None
882
+ with accelerator.local_main_process_first():
883
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
884
+ else:
885
+ resume_step = None
886
+
887
+ gen_kwargs = {
888
+ "do_sample": model_args.do_sample,
889
+ "temperature": model_args.temperature,
890
+ "max_length": model_args.max_length,
891
+ # Because of the delayed pattern mask, generation might stop earlier because of unexpected behaviour
892
+ # on the first tokens of the codebooks that are delayed.
893
+ # This fixes the issue.
894
+ "min_new_tokens": num_codebooks + 1,
895
+ }
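+ # Sketch of why `min_new_tokens = num_codebooks + 1` (reasoning added here, assuming the
+ # standard delay pattern): with e.g. 9 codebooks, codebook k only starts emitting real
+ # tokens at generation step k, so forcing at least num_codebooks + 1 new tokens guarantees
+ # every codebook produces at least one non-delayed token before EOS can be sampled.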
896
+
897
+ # Define gradient update step fn
898
+ def train_step(
899
+ batch,
900
+ accelerator,
901
+ autocast_kwargs,
902
+ num_items_in_batch,
903
+ gradient_accumulation_steps,
904
+ ):
905
+ if mixed_precision == "fp16":
906
+ # fp16 doesn't work with T5-like models
907
+ with accelerator.autocast(autocast_handler=autocast_kwargs):
908
+ if training_args.parallel_mode.value != "distributed":
909
+ encoder_outputs = model.text_encoder(
910
+ input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
911
+ )
912
+ else:
913
+ encoder_outputs = model.module.text_encoder(
914
+ input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
915
+ )
916
+ # we optionally project last_hidden_state to avoid recomputing it every time
917
+ encoder_hidden_states = encoder_outputs.last_hidden_state
918
+ if (
919
+ config.text_encoder.hidden_size != config.decoder.hidden_size
920
+ and config.decoder.cross_attention_hidden_size is None
921
+ ):
922
+ encoder_hidden_states = (
923
+ model.enc_to_dec_proj(encoder_hidden_states)
924
+ if training_args.parallel_mode.value != "distributed"
925
+ else model.module.enc_to_dec_proj(encoder_hidden_states)
926
+ )
927
+
928
+ if batch.get("attention_mask", None) is not None:
929
+ encoder_hidden_states = encoder_hidden_states * batch.get("attention_mask", None)[..., None]
930
+
931
+ encoder_outputs.last_hidden_state = encoder_hidden_states
932
+ batch["encoder_outputs"] = encoder_outputs
933
+
934
+ outputs = model(**batch, loss_reduction="sum")
935
+ # CE (data) loss
936
+ ce_loss = (outputs.loss * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch
937
+
938
+ metrics = {"loss": ce_loss}
939
+
940
+ # per CE loss
941
+ per_codebook_losses = outputs.per_codebook_losses
942
+ metrics.update({f"codebook_{i}_loss": ((l * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch) for (i,l) in enumerate(per_codebook_losses)})
943
+ return ce_loss, metrics
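+ # Note on the scaling above (a sketch of the reasoning, not part of the original comments):
+ # `loss_reduction="sum"` returns a summed CE loss; `accelerator.backward` later divides by
+ # `gradient_accumulation_steps` and DDP averages gradients over `num_processes`, so
+ # multiplying by both factors here and dividing by the gathered global token count
+ # `num_items_in_batch` makes the effective objective a per-token mean over the global batch.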
944
+
945
+ # Define eval fn
946
+ def eval_step(
947
+ batch,
948
+ accelerator,
949
+ autocast_kwargs,
950
+ ):
951
+ eval_model = model if not training_args.torch_compile else model._orig_mod
952
+
953
+ if mixed_precision == "fp16":
954
+ # fp16 doesn't work with T5-like models
955
+ with accelerator.autocast(autocast_handler=autocast_kwargs):
956
+ if training_args.parallel_mode.value != "distributed":
957
+ encoder_outputs = model.text_encoder(
958
+ input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
959
+ )
960
+ else:
961
+ encoder_outputs = model.module.text_encoder(
962
+ input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
963
+ )
964
+ # we optionally project last_hidden_state to avoid recomputing it every time
965
+ encoder_hidden_states = encoder_outputs.last_hidden_state
966
+ if (
967
+ config.text_encoder.hidden_size != config.decoder.hidden_size
968
+ and config.decoder.cross_attention_hidden_size is None
969
+ ):
970
+ encoder_hidden_states = (
971
+ model.enc_to_dec_proj(encoder_hidden_states)
972
+ if training_args.parallel_mode.value != "distributed"
973
+ else model.module.enc_to_dec_proj(encoder_hidden_states)
974
+ )
975
+
976
+ if batch.get("attention_mask", None) is not None:
977
+ encoder_hidden_states = encoder_hidden_states * batch.get("attention_mask", None)[..., None]
978
+
979
+ encoder_outputs.last_hidden_state = encoder_hidden_states
980
+ batch["encoder_outputs"] = encoder_outputs
981
+
982
+ with torch.no_grad():
983
+ outputs = eval_model(**batch)
984
+ # CE (data) loss
985
+ ce_loss = outputs.loss
986
+ metrics = {"loss": ce_loss}
987
+
988
+ # per CE loss
989
+ per_codebook_losses = outputs.per_codebook_losses
990
+ metrics.update({f"codebook_{i}_loss": l for (i,l) in enumerate(per_codebook_losses)})
991
+ return metrics
992
+
993
+ def generate_step(batch, accelerator):
994
+ batch.pop("decoder_attention_mask", None)
995
+ eval_model = accelerator.unwrap_model(model, keep_fp32_wrapper=True)
996
+ if training_args.torch_compile:
997
+ # if the model is compiled, we use the original model because compile is not compatible with .generate
998
+ eval_model = model._orig_mod
999
+
1000
+ # since we might have loaded the weights in fp32, we have to autocast to ensure FA2 weights are in half-precision.
1001
+ # with accelerator.autocast(autocast_handler=AutocastKwargs(enabled=(attn_implementation=="flash_attention_2"))):
1002
+ output_audios = eval_model.generate(**batch, **gen_kwargs)
1003
+ output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
1004
+ return output_audios
1005
+
1006
+ model.train()
1007
+
1008
+ total_batched_samples = resume_step if resume_step is not None else 0
1009
+ for epoch in range(epochs_trained, num_epochs):
1010
+ with accelerator.local_main_process_first():
1011
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1012
+ sampler = None
1013
+ if training_args.group_by_length:
1014
+ sampler = LengthGroupedSampler(train_batch_size, lengths=vectorized_datasets["train"]["target_length"])
1015
+ train_dataloader = DataLoader(
1016
+ vectorized_datasets["train"],
1017
+ collate_fn=data_collator,
1018
+ batch_size=per_device_train_batch_size,
1019
+ sampler=sampler,
1020
+ shuffle=not training_args.group_by_length,
1021
+ num_workers=training_args.dataloader_num_workers,
1022
+ pin_memory=training_args.dataloader_pin_memory,
1023
+ )
1024
+ train_dataloader = accelerator.prepare(train_dataloader)
1025
+ if hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDataset):
1026
+ train_dataloader.dataset.set_epoch(epoch)
1027
+
1028
+ if resume_step is not None:
1029
+ # Skip the first N batches in the dataloader when resuming from a checkpoint
1030
+ logger.info(f" Skip first {resume_step} batches")
1031
+ train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
1032
+ resume_step = None
1033
+ accelerator.wait_for_everyone()
1034
+
1035
+ # We chunk the epoch iterator into groups of `gradient_accumulation_steps` batches
1036
+ train_iterator = iter(train_dataloader)
1037
+ num_steps_in_epoch = len(train_dataloader)
1038
+ remainder = num_steps_in_epoch % gradient_accumulation_steps
1039
+ remainder = remainder if remainder != 0 else gradient_accumulation_steps
1040
+ total_updates = math.ceil(num_steps_in_epoch / gradient_accumulation_steps)
1041
+
1042
+ update_step = -1
1043
+ for _ in range(total_updates):
1044
+ update_step += 1
1045
+
1046
+ # preload the total batch per step
1047
+ batch_samples = []
1048
+ num_batches_in_step = gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
1049
+ for _ in range(num_batches_in_step):
1050
+ batch_samples += [next(train_iterator)]
1051
+
1052
+ # get num items in batch, i.e. label tokens different from BOS and from -100
1053
+ num_items_in_batch = sum([(batch["labels"].ne(audio_encoder_bos_token_id) | batch["labels"].ne(-100) | batch["labels"].ne(audio_encoder_eos_token_id)).sum((0,1))[0] for batch in batch_samples])
1054
+ num_items_in_batch = accelerator.gather(num_items_in_batch).sum().item()
1055
+
1056
+ # losses = []
1057
+ for i,batch in enumerate(batch_samples):
1058
+ total_batched_samples += 1
1059
+ ctx = model.no_sync if (i < len(batch_samples) - 1 and accelerator.num_processes > 1) else contextlib.nullcontext
1060
+
1061
+ with ctx():
1062
+ loss, train_metric = train_step(batch, accelerator, autocast_kwargs, num_items_in_batch, gradient_accumulation_steps)
1063
+ accelerator.backward(loss)
1064
+ # losses.append(loss.detach())
1065
+
1066
+ grad_norm = accelerator.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
1067
+ optimizer.step()
1068
+ lr_scheduler.step()
1069
+ optimizer.zero_grad()
1070
+
1071
+ # The accelerator has performed an optimization step behind the scenes
1072
+ steps_trained_progress_bar.update(1)
1073
+ cur_step += 1
1074
+
1075
+ # losses = accelerator.gather(sum(losses)).sum().item() / (accelerator.num_processes * gradient_accumulation_steps)
1076
+
1077
+ if cur_step % training_args.logging_steps == 0:
1078
+ steps_trained_progress_bar.write(
1079
+ f"Step... ({cur_step} / {total_train_steps} | Loss:"
1080
+ f" {train_metric['loss']}, Learning Rate:"
1081
+ f" {lr_scheduler.get_last_lr()[0]})"
1082
+ )
1083
+ train_metric["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
1084
+ log_metric(
1085
+ accelerator,
1086
+ metrics=train_metric,
1087
+ learning_rate=lr_scheduler.get_last_lr()[0],
1088
+ train_time=train_time + time.time() - train_start,
1089
+ step=cur_step,
1090
+ epoch=epoch,
1091
+ prefix="train",
1092
+ )
1093
+
1094
+ # save checkpoint and weights after each save_steps and at the end of training
1095
+ if (cur_step % training_args.save_steps == 0) or cur_step == total_train_steps:
1096
+ intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
1097
+ # safe_serialization=False to avoid shared tensors saving issue (TODO(YL): it's a temporary fix)
1098
+ # https://github.com/huggingface/transformers/issues/27293#issuecomment-1872560074
1099
+ accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
1100
+ accelerator.wait_for_everyone()
1101
+ if accelerator.is_main_process:
1102
+ rotate_checkpoints(
1103
+ training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger
1104
+ )
1105
+
1106
+ if cur_step == total_train_steps:
1107
+ # un-wrap model for saving
1108
+ unwrapped_model = accelerator.unwrap_model(model)
1109
+ unwrapped_model.save_pretrained(training_args.output_dir)
1110
+
1111
+ if training_args.push_to_hub:
1112
+ api.upload_folder(
1113
+ repo_id=repo_id,
1114
+ folder_path=training_args.output_dir,
1115
+ commit_message=f"Saving train state of step {cur_step}",
1116
+ run_as_future=True,
1117
+ )
1118
+ accelerator.wait_for_everyone()
1119
+
1120
+ if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
1121
+ train_time += time.time() - train_start
1122
+ # ======================== Evaluating ==============================
1123
+ model.eval()
1124
+ eval_metrics = []
1125
+ eval_preds = []
1126
+ eval_descriptions = []
1127
+ eval_prompts = []
1128
+ eval_start = time.time()
1129
+
1130
+ # release training input batch
1131
+ batch = release_memory(batch)
1132
+
1133
+ validation_dataloader = DataLoader(
1134
+ vectorized_datasets["eval"],
1135
+ collate_fn=data_collator,
1136
+ batch_size=per_device_eval_batch_size,
1137
+ drop_last=False,
1138
+ num_workers=training_args.eval_dataloader_num_workers,
1139
+ pin_memory=training_args.dataloader_pin_memory,
1140
+ )
1141
+ validation_dataloader = accelerator.prepare(validation_dataloader)
1142
+
1143
+ for batch in tqdm(
1144
+ validation_dataloader,
1145
+ desc=f"Evaluating - Inference ...",
1146
+ position=2,
1147
+ disable=not accelerator.is_local_main_process,
1148
+ ):
1149
+ # Model forward
1150
+ eval_metric = eval_step(batch, accelerator, autocast_kwargs)
1151
+ eval_metric = accelerator.gather_for_metrics(eval_metric)
1152
+ eval_metric = {key: val.unsqueeze(0) if val.ndim == 0 else val for (key,val) in eval_metric.items()}
1153
+ eval_metrics.append(eval_metric)
1154
+
1155
+ if training_args.predict_with_generate and (cur_step % eval_generation_steps == 0 or cur_step == total_train_steps):
1156
+ validation_dataloader = DataLoader(
1157
+ vectorized_datasets["eval"],
1158
+ collate_fn=data_collator,
1159
+ batch_size=per_device_eval_batch_size,
1160
+ drop_last=False,
1161
+ num_workers=training_args.eval_dataloader_num_workers,
1162
+ pin_memory=training_args.dataloader_pin_memory,
1163
+ )
1164
+ validation_dataloader = accelerator.prepare(validation_dataloader)
1165
+ # generation
1166
+ for batch in tqdm(
1167
+ validation_dataloader,
1168
+ desc=f"Evaluating - Generation ...",
1169
+ position=2,
1170
+ disable=not accelerator.is_local_main_process,
1171
+ ):
1172
+ generated_audios = generate_step(batch, accelerator)
1173
+ # Gather all predictions and targets
1174
+ generated_audios, input_ids, prompts = accelerator.pad_across_processes(
1175
+ (generated_audios, batch["input_ids"], batch["prompt_input_ids"]), dim=1, pad_index=0
1176
+ )
1177
+ generated_audios, input_ids, prompts = accelerator.gather_for_metrics(
1178
+ (generated_audios, input_ids, prompts)
1179
+ )
1180
+ eval_preds.extend(generated_audios.to("cpu"))
1181
+ eval_descriptions.extend(input_ids.to("cpu"))
1182
+ eval_prompts.extend(prompts.to("cpu"))
1183
+
1184
+ eval_time = time.time() - eval_start
1185
+ # normalize eval metrics
1186
+ eval_metrics = {
1187
+ key: torch.mean(torch.cat([d[key] for d in eval_metrics])).to("cpu") for key in eval_metrics[0]
1188
+ }
1189
+
1190
+ # compute metrics
1191
+ metrics_desc = ""
1192
+ if training_args.predict_with_generate and (cur_step % eval_generation_steps == 0 or cur_step == total_train_steps):
1193
+ if accelerator.is_local_main_process:
1194
+ (
1195
+ metric_values,
1196
+ pred_descriptions,
1197
+ pred_prompts,
1198
+ audios,
1199
+ transcriptions,
1200
+ si_sdr_measures,
1201
+ ) = compute_metrics(
1202
+ eval_preds,
1203
+ eval_descriptions,
1204
+ eval_prompts,
1205
+ accelerator.device,
1206
+ training_args.compute_clap_similarity_metric,
1207
+ training_args.compute_noise_level_metric,
1208
+ training_args.noise_level_to_compute_clean_wer,
1209
+ )
1210
+ eval_metrics.update(metric_values)
1211
+ metrics_desc = " ".join([f"Eval {key}: {value} |" for key, value in metric_values.items()])
1212
+ if "wandb" in training_args.report_to:
1213
+ log_pred(
1214
+ accelerator,
1215
+ pred_descriptions,
1216
+ pred_prompts,
1217
+ transcriptions,
1218
+ audios,
1219
+ si_sdr_measures,
1220
+ sampling_rate=sampling_rate,
1221
+ step=cur_step,
1222
+ prefix="eval",
1223
+ )
1224
+ accelerator.wait_for_everyone()
1225
+
1226
+ # Print metrics and update progress bar
1227
+ if accelerator.is_local_main_process:
1228
+ steps_trained_progress_bar.write(
1229
+ f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |"
1230
+ f" {metrics_desc})"
1231
+ )
1232
+
1233
+ log_metric(
1234
+ accelerator,
1235
+ metrics=eval_metrics,
1236
+ train_time=eval_time,
1237
+ step=cur_step,
1238
+ epoch=epoch,
1239
+ prefix="eval",
1240
+ )
1241
+
1242
+ # release eval batch and metrics
1243
+ eval_metrics, eval_preds, eval_descriptions, eval_prompts, batch, eval_metric = release_memory(
1244
+ eval_metrics, eval_preds, eval_descriptions, eval_prompts, batch, eval_metric
1245
+ )
1246
+ if training_args.predict_with_generate and (cur_step % eval_generation_steps == 0 or cur_step == total_train_steps):
1247
+ generated_audios, input_ids, prompts = release_memory(generated_audios, input_ids, prompts)
1248
+
1249
+ # train mode
1250
+ model.train()
1251
+
1252
+ # flush the train metrics
1253
+ train_start = time.time()
1254
+
1255
+ # break condition
1256
+ if cur_step == total_train_steps:
1257
+ continue_training = False
1258
+ break
1259
+
1260
+ if not continue_training:
1261
+ break
1262
+
1263
+ accelerator.end_training()
1264
+
1265
+
1266
+ if __name__ == "__main__":
1267
+ main()
capspeech/ar/training/run_parler_tts_training.py ADDED
@@ -0,0 +1,1279 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """ Train Parler-TTS using 🤗 Accelerate"""
18
+
19
+ import logging
20
+ import os
21
+ import re
22
+ import sys
23
+ import time
24
+ import math
25
+ import contextlib
26
+ from multiprocess import set_start_method
27
+ from datetime import timedelta
28
+ import inspect
29
+ from tqdm import tqdm
30
+ from pathlib import Path
31
+ import wandb
32
+
33
+ import torch
34
+ from torch.utils.data import DataLoader
35
+
36
+ import datasets
37
+ from datasets import DatasetDict, Dataset, IterableDataset, concatenate_datasets
38
+
39
+ from huggingface_hub import HfApi
40
+
41
+ import transformers
42
+ from transformers import AutoFeatureExtractor, AutoTokenizer, HfArgumentParser
43
+ from transformers.trainer_pt_utils import LengthGroupedSampler
44
+ from transformers.optimization import get_scheduler
45
+ from transformers.utils import send_example_telemetry
46
+
47
+
48
+ from accelerate import Accelerator, skip_first_batches
49
+ from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin, DistributedDataParallelKwargs
50
+ from accelerate.utils.memory import release_memory
51
+
52
+ from parler_tts import (
53
+ ParlerTTSConfig,
54
+ ParlerTTSForConditionalGeneration,
55
+ build_delay_pattern_mask,
56
+ )
57
+
58
+ from training.utils import (
59
+ get_last_checkpoint,
60
+ rotate_checkpoints,
61
+ log_pred,
62
+ log_metric,
63
+ load_all_codec_checkpoints,
64
+ save_codec_checkpoint,
65
+ get_last_codec_checkpoint_step,
66
+ )
67
+ from training.arguments import ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments
68
+ from training.data import load_multiple_datasets, DataCollatorParlerTTSWithPadding, DataCollatorEncodecWithPadding
69
+ from training.eval import clap_similarity, wer, si_sdr
70
+
71
+ logger = logging.getLogger(__name__)
72
+
73
+
74
+ def main():
75
+ # See all possible arguments in src/transformers/training_args.py
76
+ # or by passing the --help flag to this script.
77
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
78
+
79
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments))
80
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
81
+ # If we pass only one argument to the script and it's the path to a json file,
82
+ # let's parse it to get our arguments.
83
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
84
+ else:
85
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
86
+
87
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
88
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
89
+ send_example_telemetry("run_parler_tts", model_args, data_args)
90
+
91
+ if data_args.wandb_key is not None:
92
+ wandb.login(key=data_args.wandb_key)
93
+
94
+ if training_args.dtype == "float16":
95
+ mixed_precision = "fp16"
96
+ torch_dtype = torch.float16
97
+ elif training_args.dtype == "bfloat16":
98
+ mixed_precision = "bf16"
99
+ torch_dtype = torch.bfloat16
100
+ else:
101
+ mixed_precision = "no"
102
+ torch_dtype = torch.float32
103
+
104
+ if data_args.pad_to_max_length and (
105
+ data_args.max_duration_in_seconds is None
106
+ or data_args.max_prompt_token_length is None
107
+ or data_args.max_description_token_length is None
108
+ ):
109
+ raise ValueError(
110
+ "`pad_to_max_length` is `True` but one of the following parameters has not been set: `max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`"
111
+ )
112
+
113
+ padding = "max_length" if data_args.pad_to_max_length else "longest"
114
+
115
+ ####### A. Preparation
116
+ kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=120)), DistributedDataParallelKwargs(find_unused_parameters=False)]
117
+
118
+ accelerator = Accelerator(
119
+ gradient_accumulation_steps=training_args.gradient_accumulation_steps,
120
+ mixed_precision=mixed_precision,
121
+ log_with=training_args.report_to,
122
+ project_dir=training_args.output_dir,
123
+ kwargs_handlers=kwargs_handlers,
124
+ )
125
+
126
+ accelerator.init_trackers(
127
+ project_name=data_args.wandb_project,
128
+ config={
129
+ "learning_rate": training_args.learning_rate,
130
+ "model_name_or_path": model_args.model_name_or_path,
131
+ "num_train_epochs": training_args.num_train_epochs,
132
+ "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
133
+ "per_device_train_batch_size": training_args.per_device_train_batch_size,
134
+ "global_batch_size": training_args.per_device_train_batch_size * accelerator.num_processes,
135
+ "mixed_precision": mixed_precision,
136
+ "lr_scheduler_type": training_args.lr_scheduler_type,
137
+ "warmup_steps": training_args.warmup_steps,
138
+ "freeze_text_encoder": model_args.freeze_text_encoder,
139
+ "max_duration_in_seconds": data_args.max_duration_in_seconds,
140
+ "weight_decay": training_args.weight_decay,
141
+ "adam_beta1": training_args.adam_beta1,
142
+ "adam_beta2": training_args.adam_beta2,
143
+ "temperature": model_args.temperature,
144
+ },
145
+ init_kwargs={"wandb": {"name": data_args.wandb_run_name}} if data_args.wandb_run_name else {},
146
+ )
147
+
148
+ # Detecting last checkpoint and eventually continue from last checkpoint
149
+ last_checkpoint = None
150
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
151
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
152
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
153
+ raise ValueError(
154
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
155
+ "Use --overwrite_output_dir to overcome."
156
+ )
157
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
158
+ logger.info(
159
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
160
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
161
+ )
162
+
163
+ # Setup logging
164
+ logging.basicConfig(
165
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
166
+ datefmt="%m/%d/%Y %H:%M:%S",
167
+ handlers=[logging.StreamHandler(sys.stdout)],
168
+ )
169
+ logger.setLevel(logging.INFO if accelerator.is_main_process else logging.WARN)
170
+
171
+ # Log a small summary on each process
172
+ logger.warning(
173
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
174
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
175
+ )
176
+
177
+ # Set the verbosity to info of the Transformers logger (on main process only)
178
+ if accelerator.is_local_main_process:
179
+ datasets.utils.logging.set_verbosity_warning()
180
+ transformers.utils.logging.set_verbosity_info()
181
+ else:
182
+ datasets.utils.logging.set_verbosity_error()
183
+ transformers.utils.logging.set_verbosity_error()
184
+
185
+ logger.info("Training/evaluation parameters %s", training_args)
186
+
187
+ # Set seed before initializing model.
188
+ set_seed(training_args.seed)
189
+ num_workers = data_args.preprocessing_num_workers
190
+
191
+ # 1. First, let's instantiate the feature extractor, tokenizers and model
192
+ # Note for distributed training, the .from_pretrained methods guarantee that only
193
+ # one local process can concurrently download model & vocab.
194
+
195
+ # load feature extractor
196
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
197
+ model_args.feature_extractor_name or model_args.model_name_or_path,
198
+ cache_dir=model_args.cache_dir,
199
+ token=data_args.token,
200
+ trust_remote_code=data_args.trust_remote_code,
201
+ )
202
+ sampling_rate = feature_extractor.sampling_rate
203
+
204
+ # load prompt tokenizer
205
+ prompt_tokenizer = AutoTokenizer.from_pretrained(
206
+ model_args.prompt_tokenizer_name or model_args.description_tokenizer_name or model_args.model_name_or_path,
207
+ cache_dir=model_args.cache_dir,
208
+ token=data_args.token,
209
+ trust_remote_code=data_args.trust_remote_code,
210
+ use_fast=model_args.use_fast_tokenizer,
211
+ padding_side=model_args.prompt_padding_side,
212
+ )
213
+
214
+ # load description tokenizer
215
+ description_tokenizer = AutoTokenizer.from_pretrained(
216
+ model_args.description_tokenizer_name or model_args.model_name_or_path,
217
+ cache_dir=model_args.cache_dir,
218
+ token=data_args.token,
219
+ trust_remote_code=data_args.trust_remote_code,
220
+ use_fast=model_args.use_fast_tokenizer,
221
+ )
222
+
223
+ if model_args.use_fast_tokenizer:
224
+ logger.warning(
225
+ "Disabling fast tokenizer warning: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3231-L3235"
226
+ )
227
+ prompt_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
228
+ description_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
229
+
230
+ # 2. Now, let's load the dataset
231
+
232
+ if data_args.save_to_disk is not None:
233
+ os.makedirs(data_args.save_to_disk, exist_ok=True)
234
+
235
+ # assume that the dataset has been saved to `save_to_disk` if the latter is not empty
236
+ dataset_was_precomputed = len(os.listdir(data_args.save_to_disk)) > 0
237
+ if dataset_was_precomputed:
238
+ with accelerator.local_main_process_first():
239
+ vectorized_datasets = datasets.load_from_disk(data_args.save_to_disk)
240
+ else:
241
+ raw_datasets = DatasetDict()
242
+
243
+ columns_to_keep = {
244
+ "target_audio_column_name": data_args.target_audio_column_name,
245
+ "prompt_column_name": data_args.prompt_column_name,
246
+ "source": data_args.source_column_name,
247
+ }
248
+ if data_args.description_column_name is not None:
249
+ columns_to_keep["description_column_name"] = data_args.description_column_name
250
+
251
+ if training_args.do_train:
252
+ raw_datasets["train"] = load_multiple_datasets(
253
+ accelerator,
254
+ data_args.train_dataset_name,
255
+ splits=data_args.train_split_name,
256
+ dataset_samples=data_args.train_dataset_samples,
257
+ seed=training_args.seed,
258
+ cache_dir=model_args.cache_dir,
259
+ num_proc=data_args.preprocessing_num_workers,
260
+ id_column_name=data_args.id_column_name,
261
+ columns_to_keep=columns_to_keep.values(),
262
+ prompt_column_name=data_args.prompt_column_name,
263
+ audio_column_name=data_args.target_audio_column_name,
264
+ sampling_rate=sampling_rate,
265
+ logger=logger,
266
+ mls_dir=data_args.mls_dir,
267
+ librittsrmix_dir=data_args.librittsrmix_dir,
268
+ gigaspeech_dir=data_args.gigaspeech_dir,
269
+ commonvoice_dir=data_args.commonvoice_dir,
270
+ emilia_dir=data_args.emilia_dir,
271
+ # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
272
+ )
273
+
274
+ for key in columns_to_keep:
275
+ if columns_to_keep[key] not in raw_datasets["train"].column_names:
276
+ raise ValueError(
277
+ f"--{key} '{columns_to_keep[key]}' not found in dataset '{data_args.train_dataset_name}'."
278
+ f" Make sure to set `--{key}` to the correct audio column - one of"
279
+ f" {', '.join(raw_datasets['train'].column_names)}."
280
+ )
281
+
282
+ if data_args.max_train_samples is not None:
283
+ raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
284
+
285
+ if training_args.do_eval:
286
+ raw_datasets["eval"] = load_multiple_datasets(
287
+ accelerator,
288
+ data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
289
+ splits=data_args.eval_split_name,
290
+ cache_dir=model_args.cache_dir,
291
+ num_proc=data_args.preprocessing_num_workers,
292
+ id_column_name=data_args.id_column_name,
293
+ columns_to_keep=columns_to_keep.values(),
294
+ prompt_column_name=data_args.prompt_column_name,
295
+ audio_column_name=data_args.target_audio_column_name,
296
+ sampling_rate=sampling_rate,
297
+ logger=logger,
298
+ mls_dir=data_args.mls_dir,
299
+ librittsrmix_dir=data_args.librittsrmix_dir,
300
+ gigaspeech_dir=data_args.gigaspeech_dir,
301
+ commonvoice_dir=data_args.commonvoice_dir,
302
+ emilia_dir=data_args.emilia_dir
303
+ # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
304
+ )
305
+
306
+ if data_args.max_eval_samples is not None:
307
+ with accelerator.local_main_process_first():
308
+ raw_datasets["eval"] = (
309
+ raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
310
+ )
311
+
312
+ # 3. Next, let's load the config.
313
+ config = ParlerTTSConfig.from_pretrained(
314
+ model_args.model_name_or_path,
315
+ cache_dir=model_args.cache_dir,
316
+ token=data_args.token,
317
+ trust_remote_code=data_args.trust_remote_code,
318
+ )
319
+
320
+ if training_args.codebook_weights is not None and len(training_args.codebook_weights) != config.decoder.num_codebooks:
321
+ raise ValueError(f"`codebook_weights` has length {len(training_args.codebook_weights)} when it should be of length {config.decoder.num_codebooks}.")
322
+
323
+ # update pad token id and decoder_start_token_id
324
+ config.decoder.update(
325
+ {
326
+ "cross_attention_implementation_strategy": model_args.cross_attention_implementation_strategy
327
+ if model_args.cross_attention_implementation_strategy is not None
328
+ else None,
329
+ "codebook_weights": training_args.codebook_weights if training_args.codebook_weights is not None else config.decoder.codebook_weights
330
+ }
331
+ )
332
+ config.update(
333
+ {
334
+ "pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id,
335
+ "decoder_start_token_id": model_args.decoder_start_token_id
336
+ if model_args.decoder_start_token_id is not None
337
+ else config.decoder_start_token_id,
338
+ }
339
+ )
340
+
341
+ with open("events.txt", "r") as f:
342
+ events = [line.strip() for line in f]
343
+ events = ["<"+event.lower().replace(" ", "_")+">" for event in events]
344
+ events.append("<B_start>")
345
+ events.append("<B_end>")
346
+ events.append("<I_start>")
347
+ events.append("<I_end>")
348
+
349
+ special_tokens = {"additional_special_tokens": events}
350
+ prompt_tokenizer.add_special_tokens(special_tokens)
351
+ description_tokenizer.add_special_tokens(special_tokens)
352
+ padded_vocab_size = ((len(prompt_tokenizer) + 127) // 128) * 128
353
+ config.vocab_size = padded_vocab_size
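+ # e.g. (hypothetical count) a tokenizer with 32,135 entries after adding the event tokens
+ # would be padded to 32,256 (= 252 * 128); rounding the embedding size up to a multiple of
+ # 128 is a common throughput optimisation, and the extra rows are simply never used.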
354
+
355
+ # create model
356
+ model = ParlerTTSForConditionalGeneration.from_pretrained(
357
+ model_args.model_name_or_path,
358
+ ignore_mismatched_sizes=True,
359
+ cache_dir=model_args.cache_dir,
360
+ config=config,
361
+ token=data_args.token,
362
+ trust_remote_code=data_args.trust_remote_code,
363
+ attn_implementation={"decoder": model_args.attn_implementation, "text_encoder": "eager"},
364
+ )
365
+ model.text_encoder.resize_token_embeddings(padded_vocab_size)
366
+
367
+ # enable gradient checkpointing if necessary
368
+ if training_args.gradient_checkpointing:
369
+ model.gradient_checkpointing_enable()
370
+
371
+ # 4. Now we preprocess the datasets including loading the audio, resampling and normalization
372
+ # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
373
+ # so that we just need to set the correct target sampling rate and normalize the input
374
+ # via the `feature_extractor`
375
+
376
+ # derive max & min input length for sample rate & max duration
377
+ sampling_rate = feature_extractor.sampling_rate
378
+ max_target_length = int(data_args.max_duration_in_seconds * sampling_rate)
379
+ min_target_length = int(data_args.min_duration_in_seconds * sampling_rate)
380
+ target_audio_column_name = data_args.target_audio_column_name
381
+ description_column_name = data_args.description_column_name
382
+ prompt_column_name = data_args.prompt_column_name
383
+ feature_extractor_input_name = feature_extractor.model_input_names[0]
384
+ audio_encoder_pad_token_id = config.decoder.pad_token_id
385
+ audio_encoder_eos_token_id = config.decoder.eos_token_id
386
+ audio_encoder_bos_token_id = model.generation_config.decoder_start_token_id
387
+ max_length = model.generation_config.max_length
388
+ num_codebooks = model.decoder.config.num_codebooks
389
+ bandwidth = model_args.bandwidth
390
+ attn_implementation = model_args.attn_implementation
391
+
392
+ # Freeze Encoders
393
+ model.freeze_encoders(model_args.freeze_text_encoder)
394
+
395
+ # Test all gather - used for warmup and avoiding timeout
396
+ logger.debug(str(accelerator.process_index), main_process_only=False, in_order=True)
397
+ test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
398
+ gathered_tensor = accelerator.gather(test_tensor)
399
+ print("gathered_tensor", gathered_tensor)
400
+ accelerator.wait_for_everyone()
401
+
402
+ if not dataset_was_precomputed:
403
+ # Filter on text length
404
+ if description_column_name is not None and data_args.max_text_length is not None:
405
+ with accelerator.local_main_process_first():
406
+ # keep only descriptions shorter than max_text_length
407
+ raw_datasets = raw_datasets.filter(
408
+ lambda x: len(x) < data_args.max_text_length,
409
+ num_proc=num_workers,
410
+ input_columns=[description_column_name],
411
+ )
412
+
413
+ # Preprocessing the dataset.
414
+ # We need to tokenize the texts.
415
+ def pass_through_processors(description, prompt):
416
+ batch = {}
417
+
418
+ batch["input_ids"] = description_tokenizer(description.strip())["input_ids"]
419
+ batch["prompt_input_ids"] = prompt_tokenizer(prompt.strip())["input_ids"]
420
+
421
+ return batch
422
+
423
+ with accelerator.local_main_process_first():
424
+ # this is a trick to avoid rewriting the entire audio column, which takes ages
425
+ vectorized_datasets = raw_datasets.map(
426
+ pass_through_processors,
427
+ remove_columns=next(iter(raw_datasets.values())).column_names,
428
+ input_columns=[description_column_name, prompt_column_name],
429
+ num_proc=num_workers,
430
+ desc="preprocess datasets",
431
+ )
432
+
433
+ # We use Accelerate to perform distributed inference
434
+ # T5 doesn't support fp16
435
+ autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
436
+
437
+ # Now we encode the audio labels with encodec.
438
+ ####### B. Encode audio
439
+
440
+ logger.info("*** Encode target audio with encodec ***")
441
+
442
+ # no need to prepare audio_decoder because used for inference without mixed precision
443
+ # see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
444
+ if training_args.torch_compile:
445
+ audio_decoder = accelerator.prepare_model(model.audio_encoder, evaluation_mode=True)
446
+ else:
447
+ audio_decoder = model.audio_encoder
448
+
449
+ encoder_data_collator = DataCollatorEncodecWithPadding(
450
+ feature_extractor,
451
+ audio_column_name=target_audio_column_name,
452
+ mls_dir=data_args.mls_dir,
453
+ librittsrmix_dir=data_args.librittsrmix_dir,
454
+ gigaspeech_dir=data_args.gigaspeech_dir,
455
+ commonvoice_dir=data_args.commonvoice_dir,
456
+ emilia_dir=data_args.emilia_dir,
457
+ feature_extractor_input_name=feature_extractor_input_name,
458
+ max_length=max_target_length,
459
+ padding=padding,
460
+ )
461
+ encoder_signature = set(inspect.signature(audio_decoder.forward).parameters)
462
+
463
+ def apply_audio_decoder(batch):
464
+ len_audio = batch.pop("len_audio")
465
+ audio_decoder.to(batch["input_values"].device).eval()
466
+ if bandwidth is not None:
467
+ batch["bandwidth"] = bandwidth
468
+ elif "num_quantizers" in encoder_signature:
469
+ batch["num_quantizers"] = num_codebooks
470
+ elif "num_codebooks" in encoder_signature:
471
+ batch["num_codebooks"] = num_codebooks
472
+ elif "n_quantizers" in encoder_signature:
473
+ batch["n_quantizers"] = num_codebooks
474
+
475
+ with torch.no_grad():
476
+ labels = audio_decoder.encode(**batch)["audio_codes"]
477
+ output = {}
478
+ output["len_audio"] = len_audio
479
+ # (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks)
480
+ output["labels"] = labels.squeeze(0).transpose(1, 2)
481
+
482
+ # if `pad_to_max_length`, the maximum corresponding audio length of the current batch is max_duration*sampling_rate
483
+ max_length = len_audio.max() if padding != "max_length" else max_target_length
484
+ output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / max_length
485
+ return output
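+ # Illustrative (hypothetical numbers): for a 2 s clip at 44.1 kHz with a ~86 Hz codec frame
+ # rate, len_audio = 88200 and labels.shape[-1] is ~172, so ratio is ~172 / 88200; multiplying
+ # ratio by each sample's stored length downstream recovers the unpadded number of codec frames.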
486
+
487
+ # (1, codebooks, seq_len) where seq_len=1
488
+ bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id
489
+
490
+ def postprocess_dataset(labels):
491
+ # (1, codebooks, seq_len)
492
+ labels = torch.tensor(labels).unsqueeze(0)
493
+ # add bos
494
+ labels = torch.cat([bos_labels, labels], dim=-1)
495
+
496
+ labels, delay_pattern_mask = build_delay_pattern_mask(
497
+ labels,
498
+ bos_token_id=audio_encoder_bos_token_id,
499
+ pad_token_id=audio_encoder_eos_token_id,
500
+ max_length=labels.shape[-1] + num_codebooks,
501
+ num_codebooks=num_codebooks,
502
+ )
503
+
504
+ # the first ids of the delay pattern mask are precisely the labels; we use the rest of the mask
505
+ # to take care of EOS
506
+ # we want labels to look like this:
507
+ # - [B, a, b, E, E, E, E]
508
+ # - [B, B, c, d, E, E, E]
509
+ # - [B, B, B, e, f, E, E]
510
+ # - [B, B, B, B, g, h, E]
511
+ labels = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask)
512
+
513
+ # the first timestamp is associated with a row full of BOS, let's get rid of it
514
+ # we also remove the last timestamps (full of PAD)
515
+ output = {"labels": labels[:, 1:]}
516
+ return output
517
+
518
+ for split in vectorized_datasets:
519
+ data_loader = DataLoader(
520
+ raw_datasets[split],
521
+ batch_size=training_args.audio_encoder_per_device_batch_size,
522
+ collate_fn=encoder_data_collator,
523
+ num_workers=training_args.dataloader_num_workers,
524
+ pin_memory=True,
525
+ )
526
+ data_loader = accelerator.prepare(data_loader)
527
+ total_inference_steps = len(data_loader)
528
+
529
+ start_step = get_last_codec_checkpoint_step(os.path.join(data_args.temporary_save_to_disk, split))
530
+ accelerator.wait_for_everyone()
531
+ if start_step > 0:
532
+ logger.info(f"Resuming {split} from step {start_step}")
533
+ # efficiently skip the first n batches
534
+ start_step += 1
535
+ data_loader = skip_first_batches(data_loader, start_step)
536
+
537
+ all_generated_labels = []
538
+ all_lens = []
539
+ if start_step < total_inference_steps:
540
+ for i, batch in enumerate(tqdm(data_loader, disable=not accelerator.is_local_main_process)):
541
+ cur_step = start_step + i
542
+ generate_labels = apply_audio_decoder(batch)
543
+ generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0)
544
+ generate_labels = accelerator.gather_for_metrics(generate_labels)
545
+
546
+ if accelerator.is_main_process:
547
+ lab = generate_labels["labels"].cpu().transpose(1, 2).to(torch.int16)
548
+ rat = generate_labels["ratio"].cpu().squeeze(1)
549
+ lens = generate_labels["len_audio"].cpu().squeeze(1)
550
+ lab = [l[:, : int(ratio * length)] for (l, ratio, length) in zip(lab, rat, lens)]
551
+
552
+ all_generated_labels.extend(lab)
553
+ all_lens.extend(lens)
554
+
555
+ if ((cur_step + 1) % data_args.save_codec_steps == 0) or (
556
+ cur_step == total_inference_steps - 1
557
+ ):
558
+ tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
559
+ tmp_labels = tmp_labels.map(
560
+ postprocess_dataset,
561
+ num_proc=data_args.preprocessing_num_workers, # this step is resource-consuming when using many processes.
562
+ input_columns=["labels"],
563
+ desc="Postprocessing labeling",
564
+ )
565
+ save_codec_checkpoint(
566
+ os.path.join(data_args.temporary_save_to_disk, split), tmp_labels, cur_step
567
+ )
568
+ all_generated_labels = []
569
+ all_lens = []
570
+
571
+ accelerator.wait_for_everyone()
572
+
573
+ if accelerator.is_main_process and len(all_generated_labels) > 0:
574
+ tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
575
+ tmp_labels = tmp_labels.map(
576
+ postprocess_dataset,
577
+ num_proc=data_args.preprocessing_num_workers, # this step is resource-consuming when using many processes.
578
+ input_columns=["labels"],
579
+ desc="Postprocessing labeling",
580
+ )
581
+ save_codec_checkpoint(os.path.join(data_args.temporary_save_to_disk, split), tmp_labels, cur_step)
582
+ all_generated_labels = []
583
+ all_lens = []
584
+ accelerator.wait_for_everyone()
585
+
586
+ del all_generated_labels
587
+ accelerator.wait_for_everyone()
588
+
589
+ with accelerator.local_main_process_first():
590
+ tmp_labels = load_all_codec_checkpoints(os.path.join(data_args.temporary_save_to_disk, split)).select(
591
+ range(len(vectorized_datasets[split]))
592
+ )
593
+ logger.info(f"Concatenating {split}: {tmp_labels} with {vectorized_datasets[split]}")
594
+ vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1)
595
+
596
+ accelerator.free_memory()
597
+ del generate_labels, all_lens
598
+
599
+ with accelerator.local_main_process_first():
600
+ # NOTE: filtering is done at the end because in the `datasets` library, caching audio files is done after most operations
601
+ # caching audio files is time and disk-space consuming, so we want to avoid it at all costs, especially for large (>1Kh) audio datasets.
602
+ # That's also why we avoid concatenating the processed datasets (vectorized_datasets) with the audio column present in raw_datasets.
603
+
604
+ def is_audio_in_length_range(length):
605
+ return length > min_target_length and length < max_target_length
606
+
607
+ # keep only samples whose audio length lies between min_target_length and max_target_length
608
+ vectorized_datasets = vectorized_datasets.filter(
609
+ is_audio_in_length_range,
610
+ num_proc=num_workers,
611
+ input_columns=["target_length"],
612
+ )
613
+
614
+ if description_column_name is not None and data_args.max_description_token_length is not None:
615
+ with accelerator.local_main_process_first():
616
+ # keep only descriptions shorter than max_description_token_length
617
+ vectorized_datasets = vectorized_datasets.filter(
618
+ lambda x: len(x) < data_args.max_description_token_length,
619
+ num_proc=num_workers,
620
+ input_columns=["input_ids"],
621
+ )
622
+
623
+ if data_args.max_prompt_token_length is not None:
624
+ with accelerator.local_main_process_first():
625
+ # keep only prompts shorter than max_prompt_token_length
626
+ vectorized_datasets = vectorized_datasets.filter(
627
+ lambda x: len(x) < data_args.max_prompt_token_length,
628
+ num_proc=num_workers,
629
+ input_columns=["prompt_input_ids"],
630
+ )
631
+
632
+ if data_args.save_to_disk is not None and not dataset_was_precomputed:
633
+ if accelerator.is_main_process:
634
+ vectorized_datasets.save_to_disk(
635
+ data_args.save_to_disk,
636
+ num_proc=min(data_args.preprocessing_num_workers, len(vectorized_datasets["eval"]) - 1),
637
+ )
638
+ accelerator.wait_for_everyone()
639
+ logger.info(f"Dataset saved at {data_args.save_to_disk}")
640
+
641
+ audio_max_length = None
642
+ if padding == "max_length":
643
+ audio_max_length = max(vectorized_datasets["train"]["target_length"])
644
+ with accelerator.local_main_process_first():
645
+ max_sample = vectorized_datasets["train"].filter(
646
+ lambda x: x == audio_max_length,
647
+ num_proc=num_workers,
648
+ input_columns=["target_length"],
649
+ )
650
+ audio_max_length = max([len(l[0]) for l in max_sample["labels"]])
651
+
652
+ if description_column_name is not None and data_args.max_description_token_length is not None:
653
+ with accelerator.local_main_process_first():
654
+ # keep only descriptions shorter than max_description_token_length
655
+ vectorized_datasets = vectorized_datasets.filter(
656
+ lambda x: len(x) < data_args.max_description_token_length,
657
+ num_proc=num_workers,
658
+ input_columns=["input_ids"],
659
+ )
660
+
661
+ if data_args.max_prompt_token_length is not None:
662
+ with accelerator.local_main_process_first():
663
+ # keep only prompts shorter than max_prompt_token_length
664
+ vectorized_datasets = vectorized_datasets.filter(
665
+ lambda x: len(x) < data_args.max_prompt_token_length,
666
+ num_proc=num_workers,
667
+ input_columns=["prompt_input_ids"],
668
+ )
669
+
670
+ if training_args.group_by_length:
671
+ # apply a simple heuristic to take into account audio and text lengths
672
+ def add_target_lengths(target_length, prompt, description):
673
+ return {"target_length": target_length + len(prompt) + len(description)}
674
+
675
+ with accelerator.local_main_process_first():
676
+ vectorized_datasets = vectorized_datasets.map(
677
+ add_target_lengths,
678
+ num_proc=num_workers,
679
+ input_columns=["target_length", "prompt_input_ids", "input_ids"],
680
+ )
681
+
682
+ # for large datasets it is advised to run the preprocessing on a
683
+ # single machine first with ``args.preprocessing_only`` since there will most likely
684
+ # be a timeout when running the script in distributed mode.
685
+ # In a second step ``args.preprocessing_only`` can then be set to `False` to load the
686
+ # cached dataset
687
+ if data_args.preprocessing_only and data_args.save_to_disk is None:
688
+ raise ValueError(
689
+ "`preprocessing_only=True` but `save_to_disk` is not set. The latter should indicates where to save the dataset locally."
690
+ )
691
+ elif data_args.preprocessing_only:
692
+ logger.info(f"Data preprocessing finished. Files save at {data_args.save_to_disk}")
693
+ return
694
+
695
+ # 6. Next, we can prepare the training.
696
+
697
+ # Let's use CLAP similarity and WER as our evaluation metrics.
698
+ def compute_metrics(
699
+ audios,
700
+ descriptions,
701
+ prompts,
702
+ device="cpu",
703
+ compute_clap_similarity_metric=False,
704
+ compute_noise_level_metric=False,
705
+ noise_level_to_compute_clean_wer=None,
706
+ ):
707
+ results = {}
708
+ input_ids = descriptions
709
+ texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
710
+ prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
711
+ audios = [a.float().cpu().numpy() for a in audios]
712
+
713
+ if compute_clap_similarity_metric:
714
+ clap_score = clap_similarity(
715
+ model_args.clap_model_name_or_path, texts, audios, device, input_sampling_rate=sampling_rate
716
+ )
717
+ results["clap"] = clap_score
718
+
719
+ si_sdr_measures = None
720
+ if compute_noise_level_metric:
721
+ si_sdr_measures = si_sdr(audios, device, input_sampling_rate=sampling_rate)
722
+
723
+ word_error, transcriptions, clean_word_error, noisy_word_error, percent_clean_samples = wer(
724
+ model_args.asr_model_name_or_path,
725
+ prompts,
726
+ audios,
727
+ device,
728
+ training_args.per_device_eval_batch_size,
729
+ sampling_rate,
730
+ noise_level_to_compute_clean_wer,
731
+ si_sdr_measures,
732
+ )
733
+ results["wer"] = word_error
734
+ if clean_word_error is not None:
735
+ results["clean_wer"] = clean_word_error
736
+ results["noisy_word_error"] = noisy_word_error
737
+ results["percent_clean_samples"] = percent_clean_samples
738
+
739
+ return results, texts, prompts, audios, transcriptions, si_sdr_measures
740
+
741
+ # Define Training Schedule
742
+ # Store some constants
743
+ per_device_train_batch_size = int(training_args.per_device_train_batch_size)
744
+ train_batch_size = per_device_train_batch_size * accelerator.num_processes
745
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
746
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
747
+
748
+ if training_args.max_steps < 0:
749
+ num_epochs = int(training_args.num_train_epochs)
750
+ steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
751
+ total_train_steps = steps_per_epoch * num_epochs
752
+ elif training_args.max_steps > 0:
753
+ logger.info("max_steps is given, it will override any value given in num_train_epochs")
754
+ total_train_steps = int(training_args.max_steps)
755
+ # Setting a very large number of epochs so we go as many times as necessary over the iterator.
756
+ num_epochs = sys.maxsize
757
+ steps_per_epoch = total_train_steps
758
+
759
+ if training_args.eval_steps is None:
760
+ logger.info(f"eval_steps is not set, evaluating at the end of each epoch")
761
+ eval_steps = steps_per_epoch
762
+ else:
763
+ eval_steps = training_args.eval_steps
764
+
765
+ if training_args.eval_generation_steps is None:
766
+ eval_generation_steps = eval_steps
767
+ else:
768
+ eval_generation_steps = training_args.eval_generation_steps
769
+
770
+ # T5 doesn't support fp16
771
+ autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
772
+
773
+ # Define optimizer, LR scheduler, collator
774
+ optimizer = torch.optim.AdamW(
775
+ params=model.parameters(),
776
+ lr=training_args.learning_rate,
777
+ betas=(training_args.adam_beta1, training_args.adam_beta2),
778
+ eps=training_args.adam_epsilon,
779
+ weight_decay=training_args.weight_decay,
780
+ )
781
+
782
+ # LR scheduler gets stepped by `num_processes` each time -> account for this in warmup / total steps
783
+ lr_scheduler = get_scheduler(
784
+ name=training_args.lr_scheduler_type,
785
+ optimizer=optimizer,
786
+ num_warmup_steps=training_args.get_warmup_steps(total_train_steps) * accelerator.num_processes,
787
+ num_training_steps=total_train_steps * accelerator.num_processes,
788
+ )
789
+
790
+ # Instantiate custom data collator
791
+ data_collator = DataCollatorParlerTTSWithPadding(
792
+ prompt_tokenizer=prompt_tokenizer,
793
+ description_tokenizer=description_tokenizer,
794
+ pad_to_multiple_of=data_args.pad_to_multiple_of,
795
+ padding=padding,
796
+ prompt_max_length=data_args.max_prompt_token_length,
797
+ description_max_length=data_args.max_description_token_length,
798
+ audio_max_length=audio_max_length,
799
+ )
800
+
801
+ # Prepare everything with accelerate
802
+ model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
803
+
804
+ num_examples = total_train_steps * train_batch_size * gradient_accumulation_steps
805
+ logger.info("***** Running training *****")
806
+ logger.info(f" Num examples = {num_examples}")
807
+ logger.info(" Instantaneous batch size per device =" f" {per_device_train_batch_size}")
808
+ logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
809
+ logger.info(
810
+ f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
811
+ )
812
+ logger.info(f" Total optimization steps = {total_train_steps}")
813
+
814
+ # ======================== Training ================================
815
+ train_time = 0
816
+ train_start = time.time()
817
+ steps_trained_progress_bar = tqdm(
818
+ range(total_train_steps), desc="Train steps ... ", position=0, disable=not accelerator.is_local_main_process
819
+ )
820
+ continue_training = True
821
+ epochs_trained = 0
822
+ cur_step = 0
823
+
824
+ checkpoint = None
825
+ if training_args.resume_from_checkpoint is not None:
826
+ checkpoint = training_args.resume_from_checkpoint
827
+ elif last_checkpoint is not None:
828
+ checkpoint = last_checkpoint
829
+
830
+ if accelerator.is_main_process:
831
+ if training_args.push_to_hub:
832
+ api = HfApi(token=training_args.hub_token)
833
+
834
+ # Create repo (repo_name from args or inferred)
835
+ repo_name = training_args.hub_model_id
836
+ if repo_name is None:
837
+ repo_name = Path(training_args.output_dir).absolute().name
838
+ repo_id = api.create_repo(repo_name, exist_ok=True).repo_id
839
+
840
+ with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
841
+ if "wandb" not in gitignore:
842
+ gitignore.write("wandb\n")
843
+ elif training_args.output_dir is not None:
844
+ os.makedirs(training_args.output_dir, exist_ok=True)
845
+ accelerator.wait_for_everyone()
846
+
847
+ # Now save everything to be able to create a single processor later
848
+ # make sure all processes wait until data is saved
849
+ # only the main process saves them
850
+ if accelerator.is_main_process:
851
+ # save feature extractor, tokenizer and config
852
+ if (
853
+ model_args.prompt_tokenizer_name is None
854
+ and model_args.description_tokenizer_name
855
+ or (model_args.prompt_tokenizer_name == model_args.description_tokenizer_name)
856
+ ):
857
+ prompt_tokenizer.save_pretrained(training_args.output_dir)
858
+ else:
859
+ logger.warning(
860
+ f"Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer."
861
+ )
862
+ prompt_tokenizer.save_pretrained(training_args.output_dir)
863
+
864
+ feature_extractor.save_pretrained(training_args.output_dir)
865
+ config.save_pretrained(training_args.output_dir)
866
+ accelerator.wait_for_everyone()
867
+
868
+ if checkpoint is not None:
869
+ accelerator.load_state(checkpoint)
870
+ # Find the number of steps and the epoch from the saved state string pattern
871
+ pattern = r"checkpoint-(\d+)-epoch-(\d+)"
872
+ match = re.search(pattern, checkpoint)
873
+ cur_step = int(match.group(1))
874
+ epochs_trained = int(match.group(2))
875
+
876
+ logger.info(" Continuing training from checkpoint, will skip to saved global_step")
877
+ logger.info(f" Continuing training from epoch {epochs_trained}")
878
+ logger.info(f" Continuing training from global step {cur_step}")
879
+
880
+ steps_trained_progress_bar.update(cur_step)
881
+
882
+ for epoch in range(0, epochs_trained):
883
+ with accelerator.local_main_process_first():
884
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
885
+
886
+ if training_args.max_steps < 0:
887
+ # we know exactly the number of steps per epoch, so can skip through the required number of batches
888
+ resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
889
+ else:
890
+ # Currently we don't know how many steps we've taken in the current epoch
891
+ # So we just shuffle the dataset one extra time and start from a fresh epoch
892
+ # This is "good enough" for our purposes but not fully correct
893
+ resume_step = None
894
+ with accelerator.local_main_process_first():
895
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
896
+ else:
897
+ resume_step = None
898
+
899
+ gen_kwargs = {
900
+ "do_sample": model_args.do_sample,
901
+ "temperature": model_args.temperature,
902
+ "max_length": model_args.max_length,
903
+ # Because of the delayed pattern mask, generation might stop earlier because of unexpected behaviour
904
+ # on the first tokens of the codebooks that are delayed.
905
+ # This fixes the issue.
906
+ "min_new_tokens": num_codebooks + 1,
907
+ }
908
+
909
+ # Define gradient update step fn
910
+ def train_step(
911
+ batch,
912
+ accelerator,
913
+ autocast_kwargs,
914
+ num_items_in_batch,
915
+ gradient_accumulation_steps,
916
+ ):
917
+ if mixed_precision == "fp16":
918
+ # fp16 doesn't work with T5-like models
919
+ with accelerator.autocast(autocast_handler=autocast_kwargs):
920
+ if training_args.parallel_mode.value != "distributed":
921
+ encoder_outputs = model.text_encoder(
922
+ input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
923
+ )
924
+ else:
925
+ encoder_outputs = model.module.text_encoder(
926
+ input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
927
+ )
928
+ # we optionally project last_hidden_state to avoid recomputing it every time
929
+ encoder_hidden_states = encoder_outputs.last_hidden_state
930
+ if (
931
+ config.text_encoder.hidden_size != config.decoder.hidden_size
932
+ and config.decoder.cross_attention_hidden_size is None
933
+ ):
934
+ encoder_hidden_states = (
935
+ model.enc_to_dec_proj(encoder_hidden_states)
936
+ if training_args.parallel_mode.value != "distributed"
937
+ else model.module.enc_to_dec_proj(encoder_hidden_states)
938
+ )
939
+
940
+ if batch.get("attention_mask", None) is not None:
941
+ encoder_hidden_states = encoder_hidden_states * batch.get("attention_mask", None)[..., None]
942
+
943
+ encoder_outputs.last_hidden_state = encoder_hidden_states
944
+ batch["encoder_outputs"] = encoder_outputs
945
+
946
+ outputs = model(**batch, loss_reduction="sum")
947
+ # CE (data) loss
948
+ ce_loss = (outputs.loss * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch
949
+
950
+ metrics = {"loss": ce_loss}
951
+
952
+ # per-codebook CE losses
953
+ per_codebook_losses = outputs.per_codebook_losses
954
+ metrics.update({f"codebook_{i}_loss": ((l * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch) for (i,l) in enumerate(per_codebook_losses)})
955
+ return ce_loss, metrics
956
+
957
+ # Define eval fn
958
+ def eval_step(
959
+ batch,
960
+ accelerator,
961
+ autocast_kwargs,
962
+ ):
963
+ eval_model = model if not training_args.torch_compile else model._orig_mod
964
+
965
+ if mixed_precision == "fp16":
966
+ # fp16 doesn't work with T5-like models
967
+ with accelerator.autocast(autocast_handler=autocast_kwargs):
968
+ if training_args.parallel_mode.value != "distributed":
969
+ encoder_outputs = model.text_encoder(
970
+ input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
971
+ )
972
+ else:
973
+ encoder_outputs = model.module.text_encoder(
974
+ input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
975
+ )
976
+ # we optionally project last_hidden_state to avoid recomputing it every time
977
+ encoder_hidden_states = encoder_outputs.last_hidden_state
978
+ if (
979
+ config.text_encoder.hidden_size != config.decoder.hidden_size
980
+ and config.decoder.cross_attention_hidden_size is None
981
+ ):
982
+ encoder_hidden_states = (
983
+ model.enc_to_dec_proj(encoder_hidden_states)
984
+ if training_args.parallel_mode.value != "distributed"
985
+ else model.module.enc_to_dec_proj(encoder_hidden_states)
986
+ )
987
+
988
+ if batch.get("attention_mask", None) is not None:
989
+ encoder_hidden_states = encoder_hidden_states * batch.get("attention_mask", None)[..., None]
990
+
991
+ encoder_outputs.last_hidden_state = encoder_hidden_states
992
+ batch["encoder_outputs"] = encoder_outputs
993
+
994
+ with torch.no_grad():
995
+ outputs = eval_model(**batch)
996
+ # CE (data) loss
997
+ ce_loss = outputs.loss
998
+ metrics = {"loss": ce_loss}
999
+
1000
+ # per-codebook CE losses
1001
+ per_codebook_losses = outputs.per_codebook_losses
1002
+ metrics.update({f"codebook_{i}_loss": l for (i,l) in enumerate(per_codebook_losses)})
1003
+ return metrics
1004
+
1005
+ def generate_step(batch, accelerator):
1006
+ batch.pop("decoder_attention_mask", None)
1007
+ eval_model = accelerator.unwrap_model(model, keep_fp32_wrapper=True)
1008
+ if training_args.torch_compile:
1009
+ # if the model is compiled, we use the original model because compile is not compatible with .generate
1010
+ eval_model = model._orig_mod
1011
+
1012
+ # since we might have loaded the weights in fp32, we have to autocast to ensure FA2 weights are in half-precision.
1013
+ # with accelerator.autocast(autocast_handler=AutocastKwargs(enabled=(attn_implementation=="flash_attention_2"))):
1014
+ output_audios = eval_model.generate(**batch, **gen_kwargs)
1015
+ output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
1016
+ return output_audios
1017
+
1018
+ model.train()
1019
+
1020
+ total_batched_samples = resume_step if resume_step is not None else 0
1021
+ for epoch in range(epochs_trained, num_epochs):
1022
+ with accelerator.local_main_process_first():
1023
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1024
+ sampler = None
1025
+ if training_args.group_by_length:
1026
+ sampler = LengthGroupedSampler(train_batch_size, lengths=vectorized_datasets["train"]["target_length"])
1027
+ train_dataloader = DataLoader(
1028
+ vectorized_datasets["train"],
1029
+ collate_fn=data_collator,
1030
+ batch_size=per_device_train_batch_size,
1031
+ sampler=sampler,
1032
+ shuffle=not training_args.group_by_length,
1033
+ num_workers=training_args.dataloader_num_workers,
1034
+ pin_memory=training_args.dataloader_pin_memory,
1035
+ )
1036
+ train_dataloader = accelerator.prepare(train_dataloader)
1037
+ if hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDataset):
1038
+ train_dataloader.dataset.set_epoch(epoch)
1039
+
1040
+ if resume_step is not None:
1041
+ # Skip the first N batches in the dataloader when resuming from a checkpoint
1042
+ logger.info(f" Skip first {resume_step} batches")
1043
+ train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
1044
+ resume_step = None
1045
+ accelerator.wait_for_everyone()
1046
+
1047
+ # We split the epoch iterator into chunks of `gradient_accumulation_steps` batches
1048
+ train_iterator = iter(train_dataloader)
1049
+ num_steps_in_epoch = len(train_dataloader)
1050
+ remainder = num_steps_in_epoch % gradient_accumulation_steps
1051
+ remainder = remainder if remainder != 0 else gradient_accumulation_steps
1052
+ total_updates = math.ceil(num_steps_in_epoch / gradient_accumulation_steps)
1053
+
1054
+ update_step = -1
1055
+ for _ in range(total_updates):
1056
+ update_step += 1
1057
+
1058
+ # preload the total batch per step
1059
+ batch_samples = []
1060
+ num_batches_in_step = gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
1061
+ for _ in range(num_batches_in_step):
1062
+ batch_samples += [next(train_iterator)]
1063
+
1064
+ # count the label tokens in the batch, i.e. entries different from BOS, EOS and -100
1065
+ num_items_in_batch = sum([(batch["labels"].ne(audio_encoder_bos_token_id) | batch["labels"].ne(-100) | batch["labels"].ne(audio_encoder_eos_token_id)).sum((0,1))[0] for batch in batch_samples])
1066
+ num_items_in_batch = accelerator.gather(num_items_in_batch).sum().item()
1067
+
1068
+ # losses = []
1069
+ for i,batch in enumerate(batch_samples):
1070
+ total_batched_samples += 1
1071
+ ctx = model.no_sync if (i < len(batch_samples) - 1 and accelerator.num_processes > 1) else contextlib.nullcontext
1072
+
1073
+ with ctx():
1074
+ loss, train_metric = train_step(batch, accelerator, autocast_kwargs, num_items_in_batch, gradient_accumulation_steps)
1075
+ accelerator.backward(loss)
1076
+ # losses.append(loss.detach())
1077
+
1078
+ grad_norm = accelerator.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
1079
+ optimizer.step()
1080
+ lr_scheduler.step()
1081
+ optimizer.zero_grad()
1082
+
1083
+ # The accelerator has performed an optimization step behind the scenes
1084
+ steps_trained_progress_bar.update(1)
1085
+ cur_step += 1
1086
+
1087
+ # losses = accelerator.gather(sum(losses)).sum().item() / (accelerator.num_processes * gradient_accumulation_steps)
1088
+
1089
+ if cur_step % training_args.logging_steps == 0:
1090
+ steps_trained_progress_bar.write(
1091
+ f"Step... ({cur_step} / {total_train_steps} | Loss:"
1092
+ f" {train_metric['loss']}, Learning Rate:"
1093
+ f" {lr_scheduler.get_last_lr()[0]})"
1094
+ )
1095
+ train_metric["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
1096
+ log_metric(
1097
+ accelerator,
1098
+ metrics=train_metric,
1099
+ learning_rate=lr_scheduler.get_last_lr()[0],
1100
+ train_time=train_time + time.time() - train_start,
1101
+ step=cur_step,
1102
+ epoch=epoch,
1103
+ prefix="train",
1104
+ )
1105
+
1106
+ # save checkpoint and weights after each save_steps and at the end of training
1107
+ if (cur_step % training_args.save_steps == 0) or cur_step == total_train_steps:
1108
+ intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
1109
+ # safe_serialization=False to avoid shared tensors saving issue (TODO(YL): it's a temporary fix)
1110
+ # https://github.com/huggingface/transformers/issues/27293#issuecomment-1872560074
1111
+ accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
1112
+ accelerator.wait_for_everyone()
1113
+ if accelerator.is_main_process:
1114
+ rotate_checkpoints(
1115
+ training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger
1116
+ )
1117
+
1118
+ if cur_step == total_train_steps:
1119
+ # un-wrap the model for saving
1120
+ unwrapped_model = accelerator.unwrap_model(model)
1121
+ unwrapped_model.save_pretrained(training_args.output_dir)
1122
+
1123
+ if training_args.push_to_hub:
1124
+ api.upload_folder(
1125
+ repo_id=repo_id,
1126
+ folder_path=training_args.output_dir,
1127
+ commit_message=f"Saving train state of step {cur_step}",
1128
+ run_as_future=True,
1129
+ )
1130
+ accelerator.wait_for_everyone()
1131
+
1132
+ if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
1133
+ train_time += time.time() - train_start
1134
+ # ======================== Evaluating ==============================
1135
+ model.eval()
1136
+ eval_metrics = []
1137
+ eval_preds = []
1138
+ eval_descriptions = []
1139
+ eval_prompts = []
1140
+ eval_start = time.time()
1141
+
1142
+ # release training input batch
1143
+ batch = release_memory(batch)
1144
+
1145
+ validation_dataloader = DataLoader(
1146
+ vectorized_datasets["eval"],
1147
+ collate_fn=data_collator,
1148
+ batch_size=per_device_eval_batch_size,
1149
+ drop_last=False,
1150
+ num_workers=training_args.eval_dataloader_num_workers,
1151
+ pin_memory=training_args.dataloader_pin_memory,
1152
+ )
1153
+ validation_dataloader = accelerator.prepare(validation_dataloader)
1154
+
1155
+ for batch in tqdm(
1156
+ validation_dataloader,
1157
+ desc=f"Evaluating - Inference ...",
1158
+ position=2,
1159
+ disable=not accelerator.is_local_main_process,
1160
+ ):
1161
+ # Model forward
1162
+ eval_metric = eval_step(batch, accelerator, autocast_kwargs)
1163
+ eval_metric = accelerator.gather_for_metrics(eval_metric)
1164
+ eval_metric = {key: val.unsqueeze(0) if val.ndim == 0 else val for (key,val) in eval_metric.items()}
1165
+ eval_metrics.append(eval_metric)
1166
+
1167
+ if training_args.predict_with_generate and (cur_step % eval_generation_steps == 0 or cur_step == total_train_steps):
1168
+ validation_dataloader = DataLoader(
1169
+ vectorized_datasets["eval"],
1170
+ collate_fn=data_collator,
1171
+ batch_size=per_device_eval_batch_size,
1172
+ drop_last=False,
1173
+ num_workers=training_args.eval_dataloader_num_workers,
1174
+ pin_memory=training_args.dataloader_pin_memory,
1175
+ )
1176
+ validation_dataloader = accelerator.prepare(validation_dataloader)
1177
+ # generation
1178
+ for batch in tqdm(
1179
+ validation_dataloader,
1180
+ desc=f"Evaluating - Generation ...",
1181
+ position=2,
1182
+ disable=not accelerator.is_local_main_process,
1183
+ ):
1184
+ generated_audios = generate_step(batch, accelerator)
1185
+ # Gather all predictions and targets
1186
+ generated_audios, input_ids, prompts = accelerator.pad_across_processes(
1187
+ (generated_audios, batch["input_ids"], batch["prompt_input_ids"]), dim=1, pad_index=0
1188
+ )
1189
+ generated_audios, input_ids, prompts = accelerator.gather_for_metrics(
1190
+ (generated_audios, input_ids, prompts)
1191
+ )
1192
+ eval_preds.extend(generated_audios.to("cpu"))
1193
+ eval_descriptions.extend(input_ids.to("cpu"))
1194
+ eval_prompts.extend(prompts.to("cpu"))
1195
+
1196
+ eval_time = time.time() - eval_start
1197
+ # normalize eval metrics
1198
+ eval_metrics = {
1199
+ key: torch.mean(torch.cat([d[key] for d in eval_metrics])).to("cpu") for key in eval_metrics[0]
1200
+ }
1201
+
1202
+ # compute metrics
1203
+ metrics_desc = ""
1204
+ if training_args.predict_with_generate and (cur_step % eval_generation_steps == 0 or cur_step == total_train_steps):
1205
+ if accelerator.is_local_main_process:
1206
+ (
1207
+ metric_values,
1208
+ pred_descriptions,
1209
+ pred_prompts,
1210
+ audios,
1211
+ transcriptions,
1212
+ si_sdr_measures,
1213
+ ) = compute_metrics(
1214
+ eval_preds,
1215
+ eval_descriptions,
1216
+ eval_prompts,
1217
+ accelerator.device,
1218
+ training_args.compute_clap_similarity_metric,
1219
+ training_args.compute_noise_level_metric,
1220
+ training_args.noise_level_to_compute_clean_wer,
1221
+ )
1222
+ eval_metrics.update(metric_values)
1223
+ metrics_desc = " ".join([f"Eval {key}: {value} |" for key, value in metric_values.items()])
1224
+ if "wandb" in training_args.report_to:
1225
+ log_pred(
1226
+ accelerator,
1227
+ pred_descriptions,
1228
+ pred_prompts,
1229
+ transcriptions,
1230
+ audios,
1231
+ si_sdr_measures,
1232
+ sampling_rate=sampling_rate,
1233
+ step=cur_step,
1234
+ prefix="eval",
1235
+ )
1236
+ accelerator.wait_for_everyone()
1237
+
1238
+ # Print metrics and update progress bar
1239
+ if accelerator.is_local_main_process:
1240
+ steps_trained_progress_bar.write(
1241
+ f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |"
1242
+ f" {metrics_desc})"
1243
+ )
1244
+
1245
+ log_metric(
1246
+ accelerator,
1247
+ metrics=eval_metrics,
1248
+ train_time=eval_time,
1249
+ step=cur_step,
1250
+ epoch=epoch,
1251
+ prefix="eval",
1252
+ )
1253
+
1254
+ # release eval batches and metrics
1255
+ eval_metrics, eval_preds, eval_descriptions, eval_prompts, batch, eval_metric = release_memory(
1256
+ eval_metrics, eval_preds, eval_descriptions, eval_prompts, batch, eval_metric
1257
+ )
1258
+ if training_args.predict_with_generate and (cur_step % eval_generation_steps == 0 or cur_step == total_train_steps):
1259
+ generated_audios, input_ids, prompts = release_memory(generated_audios, input_ids, prompts)
1260
+
1261
+ # train mode
1262
+ model.train()
1263
+
1264
+ # flush the train metrics
1265
+ train_start = time.time()
1266
+
1267
+ # break condition
1268
+ if cur_step == total_train_steps:
1269
+ continue_training = False
1270
+ break
1271
+
1272
+ if not continue_training:
1273
+ break
1274
+
1275
+ accelerator.end_training()
1276
+
1277
+
1278
+ if __name__ == "__main__":
1279
+ main()
capspeech/ar/training/utils.py ADDED
@@ -0,0 +1,203 @@
1
+ import os
2
+ import re
3
+ import shutil
4
+ from dataclasses import field
5
+ from pathlib import Path
6
+ from typing import Dict, List
7
+
8
+ import torch
9
+ from datasets import concatenate_datasets, load_from_disk
10
+ from wandb import Audio
11
+ from datasets import load_from_disk, concatenate_datasets
12
+
13
+
14
+ def list_field(default=None, metadata=None):
15
+ return field(default_factory=lambda: default, metadata=metadata)
16
+
17
+
18
+ _RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")
19
+ CHECKPOINT_CODEC_PREFIX = "checkpoint"
20
+ _RE_CODEC_CHECKPOINT = re.compile(r"^checkpoint-(\d+)$")
21
+
22
+
23
+ def get_last_checkpoint(folder):
24
+ content = os.listdir(folder)
25
+ checkpoints = [
26
+ path
27
+ for path in content
28
+ if _RE_CHECKPOINT.search(path) is not None and os.path.isdir(os.path.join(folder, path))
29
+ ]
30
+ if len(checkpoints) == 0:
31
+ return
32
+ return os.path.join(folder, max(checkpoints, key=lambda x: int(_RE_CHECKPOINT.search(x).groups()[0])))
33
+
34
+
35
+ def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint") -> List[str]:
36
+ """Helper function to sort saved checkpoints from oldest to newest."""
37
+ ordering_and_checkpoint_path = []
38
+
39
+ glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
40
+
41
+ for path in glob_checkpoints:
42
+ regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
43
+ if regex_match is not None and regex_match.groups() is not None:
44
+ ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
45
+
46
+ checkpoints_sorted = sorted(ordering_and_checkpoint_path)
47
+ checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
48
+ return checkpoints_sorted
49
+
50
+
51
+ def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix="checkpoint", logger=None) -> None:
52
+ """Helper function to delete old checkpoints."""
53
+ if save_total_limit is None or save_total_limit <= 0:
54
+ return
55
+ # Check if we should delete older checkpoint(s)
56
+ checkpoints_sorted = sorted_checkpoints(output_dir=output_dir, checkpoint_prefix=checkpoint_prefix)
57
+ if len(checkpoints_sorted) <= save_total_limit:
58
+ return
59
+
60
+ number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
61
+ checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
62
+ for checkpoint in checkpoints_to_be_deleted:
63
+ logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
64
+ shutil.rmtree(checkpoint, ignore_errors=True)
65
+
66
+
67
+ def save_codec_checkpoint(output_dir, dataset, step):
68
+ checkpoint_path = f"{CHECKPOINT_CODEC_PREFIX}-{step}"
69
+ output_path = os.path.join(output_dir, checkpoint_path)
70
+ dataset.save_to_disk(output_path)
71
+
72
+
73
+ def load_codec_checkpoint(checkpoint_path):
74
+ dataset = load_from_disk(checkpoint_path)
75
+ return dataset
76
+
77
+
78
+ def sorted_codec_checkpoints(output_dir=None) -> List[str]:
79
+ """Helper function to sort saved checkpoints from oldest to newest."""
80
+ ordering_and_checkpoint_path = []
81
+
82
+ glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{CHECKPOINT_CODEC_PREFIX}-*")]
83
+
84
+ for path in glob_checkpoints:
85
+ regex_match = re.match(f".*{CHECKPOINT_CODEC_PREFIX}-([0-9]+)", path)
86
+ if regex_match is not None and regex_match.groups() is not None:
87
+ ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
88
+
89
+ checkpoints_sorted = sorted(ordering_and_checkpoint_path)
90
+ checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
91
+ return checkpoints_sorted
92
+
93
+
94
+ def load_all_codec_checkpoints(output_dir=None) -> List[str]:
95
+ """Helper function to load and concat all checkpoints."""
96
+ checkpoints_sorted = sorted_codec_checkpoints(output_dir=output_dir)
97
+ datasets = [load_from_disk(checkpoint) for checkpoint in checkpoints_sorted]
98
+ datasets = concatenate_datasets(datasets, axis=0)
99
+ return datasets
100
+
101
+
102
+ def get_last_codec_checkpoint_step(folder) -> int:
103
+ if not os.path.exists(folder) or not os.path.isdir(folder):
104
+ os.makedirs(folder, exist_ok=True)
105
+ return 0
106
+ content = os.listdir(folder)
107
+ checkpoints = [path for path in content if _RE_CODEC_CHECKPOINT.search(path) is not None]
108
+ if len(checkpoints) == 0:
109
+ return 0
110
+ last_checkpoint = os.path.join(
111
+ folder, max(checkpoints, key=lambda x: int(_RE_CODEC_CHECKPOINT.search(x).groups()[0]))
112
+ )
113
+ # Find the number of steps from the saved state string pattern
114
+ pattern = r"checkpoint-(\d+)"
115
+ match = re.search(pattern, last_checkpoint)
116
+ cur_step = int(match.group(1))
117
+ return cur_step
118
+
119
+
120
+ def log_metric(
121
+ accelerator,
122
+ metrics: Dict,
123
+ train_time: float,
124
+ step: int,
125
+ epoch: int,
126
+ learning_rate: float = None,
127
+ prefix: str = "train",
128
+ ):
129
+ """Helper function to log all training/evaluation metrics with the correct prefixes and styling."""
130
+ log_metrics = {}
131
+ for k, v in metrics.items():
132
+ if "codebook" in k:
133
+ log_metrics[f"codebook_{prefix}/{k}"] = v
134
+ else:
135
+ log_metrics[f"{prefix}/{k}"] = v
136
+ log_metrics[f"{prefix}/time"] = train_time
137
+ log_metrics[f"{prefix}/epoch"] = epoch
138
+ if learning_rate is not None:
139
+ log_metrics[f"{prefix}/learning_rate"] = learning_rate
140
+ accelerator.log(log_metrics, step=step)
141
+
142
+
143
+ def log_pred(
144
+ accelerator,
145
+ pred_descriptions: List[str],
146
+ pred_prompts: List[str],
147
+ transcriptions: List[str],
148
+ audios: List[torch.Tensor],
149
+ si_sdr_measures: List[float],
150
+ sampling_rate: int,
151
+ step: int,
152
+ prefix: str = "eval",
153
+ num_lines: int = 200000,
154
+ ):
155
+ """Helper function to log target/predicted transcriptions to weights and biases (wandb)."""
156
+ if accelerator.is_main_process:
157
+ wandb_tracker = accelerator.get_tracker("wandb")
158
+ # pretty name for current step: step 50000 -> step 50k
159
+ cur_step_pretty = f"{int(step // 1000)}k" if step > 1000 else step
160
+ prefix_pretty = prefix.replace("/", "-")
161
+
162
+ if si_sdr_measures is None:
163
+ # convert str data to a wandb compatible format
164
+ str_data = [
165
+ [pred_descriptions[i], pred_prompts[i], transcriptions[i]] for i in range(len(pred_descriptions))
166
+ ]
167
+ # log as a table with the appropriate headers
168
+ wandb_tracker.log_table(
169
+ table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
170
+ columns=["Target descriptions", "Target prompts", "Predicted transcriptions"],
171
+ data=str_data[:num_lines],
172
+ step=step,
173
+ commit=False,
174
+ )
175
+ else:
176
+ # convert str data to a wandb compatible format
177
+ str_data = [
178
+ [pred_descriptions[i], pred_prompts[i], transcriptions[i], si_sdr_measures[i]]
179
+ for i in range(len(pred_descriptions))
180
+ ]
181
+ # log as a table with the appropriate headers
182
+ wandb_tracker.log_table(
183
+ table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
184
+ columns=["Target descriptions", "Target prompts", "Predicted transcriptions", "Noise estimation"],
185
+ data=str_data[:num_lines],
186
+ step=step,
187
+ commit=False,
188
+ )
189
+
190
+ # wandb can only load 100 audios per step
191
+ wandb_tracker.log(
192
+ {
193
+ "Speech samples": [
194
+ Audio(
195
+ audio,
196
+ caption=f"{pred_prompts[i]} --- DESCRIPTION: {pred_descriptions[i]}",
197
+ sample_rate=sampling_rate,
198
+ )
199
+ for (i, audio) in enumerate(audios[: min(len(audios), 100)])
200
+ ]
201
+ },
202
+ step=step,
203
+ )
capspeech/eval/README.md ADDED
@@ -0,0 +1,42 @@
1
+ # CapSpeech Evaluation Tools
2
+
3
+ ## Getting Started
4
+ Install dependencies:
5
+ ```bash
6
+ conda create -n capeval python=3.9
7
+ conda activate capeval
8
+ pip install -r requirements.txt
9
+ pip install git+https://github.com/sarulab-speech/UTMOSv2.git
10
+ ```
11
+
12
+ For ASR, we need:
13
+ ```bash
14
+ conda install ffmpeg
15
+ ```
16
+
17
+ ## Evaluate pitch, monotony, speed, age, gender
18
+ RUN:
19
+ ```bash
20
+ python base_eval.py
21
+ ```
22
+
23
+ ## Evaluate UTMOSv2
24
+ RUN:
25
+ ```bash
26
+ python mos_eval.py
27
+ ```
28
+
29
+ ## Evaluate ASR Results
30
+ RUN:
31
+ ```bash
32
+ python asr_eval.py
33
+ ```
34
+
35
+ ## Evaluate emotion, accent
36
+ RUN:
37
+ ```bash
38
+ cd src/example/
39
+ python categorized_emotion.py
40
+ python dialect_world_dialect.py
41
+ ```
42
+ Please refer to [Vox-profile](https://github.com/tiantiaf0627/vox-profile-release.git) for more evaluation tools.
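
The scripts above each evaluate a single utterance. Below is a minimal batch-evaluation sketch; it is not part of the released tools, assumes it is run from `capspeech/eval/` (so that `pitch.py`, `speed.py`, `age_gender.py` and `bin.json` are importable/readable), and uses a hypothetical `wav_dir` folder of generated WAV files.

```python
# Batch-evaluation sketch (assumptions: run from capspeech/eval/, `wav_dir` is hypothetical).
import bisect
import glob
import json

import librosa

from age_gender import age_gender_apply
from pitch import pitch_apply
from speed import speed_apply

PITCH_BINS = ["very low-pitch", "low-pitch", "slightly low-pitch", "moderate pitch",
              "slightly high-pitch", "high-pitch", "very high-pitch"]

with open("bin.json") as f:
    text_bins = json.load(f)

wav_dir = "generated_wavs"  # hypothetical output folder
for path in sorted(glob.glob(f"{wav_dir}/*.wav")):
    waveform, _ = librosa.load(path, sr=16000)
    age, gender = age_gender_apply(waveform)
    pitch_mean, _ = pitch_apply(waveform)
    # map the mean pitch (Hz) to a textual label, clamped to the label range
    bins = text_bins["pitch_bins_male" if gender == "male" else "pitch_bins_female"]
    idx = min(max(bisect.bisect_right(bins, float(pitch_mean)) - 1, 0), len(PITCH_BINS) - 1)
    speech_duration = speed_apply(waveform)  # voiced speech duration in seconds
    print(path, age, gender, PITCH_BINS[idx], round(float(speech_duration), 2))
```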
capspeech/eval/__init__.py ADDED
File without changes
capspeech/eval/age_gender.py ADDED
@@ -0,0 +1,35 @@
1
+ import audeer
2
+ import audonnx
3
+ import numpy as np
4
+
5
+ def age_gender_apply(waveform):
6
+ age_labels = ['child', 'teenager', 'young adult', 'middle-aged adult', 'elderly']
7
+ gender_labels = ['female', 'male']
8
+ url = 'https://zenodo.org/record/7761387/files/w2v2-L-robust-6-age-gender.25c844af-1.1.1.zip'
9
+ cache_root = audeer.mkdir('cache')
10
+ model_root = audeer.mkdir('model')
11
+ sampling_rate = 16000
12
+ archive_path = audeer.download_url(url, cache_root, verbose=True)
13
+ audeer.extract_archive(archive_path, model_root)
14
+ model = audonnx.load(model_root)
15
+
16
+ result = model(waveform, sampling_rate)
17
+ # Process age
18
+ age_label = result['logits_age'].squeeze() * 100.0
19
+ if age_label <= 12:
20
+ age_label = 'child'
21
+ elif age_label <= 19:
22
+ age_label = 'teenager'
23
+ elif age_label <= 39:
24
+ age_label = 'young adult'
25
+ elif age_label <= 64:
26
+ age_label = 'middle-aged adult'
27
+ else:
28
+ age_label = 'elderly'
29
+
30
+ # Process gender
31
+ gender_label = result['logits_gender'].squeeze()
32
+ gender_label = gender_label[:2] # Remove child
33
+ gender_label = np.argmax(gender_label)
34
+
35
+ return age_label, gender_labels[gender_label]
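
Note that `age_gender_apply` downloads and loads the audEERING ONNX model on every call, which is convenient for a one-off check but slow when scoring many files. A possible variant, shown here only as a sketch (it is not part of the repo and assumes the same checkpoint URL as above), loads the model once and reuses it:

```python
# Sketch of a cached variant (assumption: same audEERING age/gender checkpoint as above).
import audeer
import audonnx
import numpy as np

def load_age_gender_model(cache_root='cache', model_root='model'):
    # download and extract once, then keep the loaded ONNX model around
    url = 'https://zenodo.org/record/7761387/files/w2v2-L-robust-6-age-gender.25c844af-1.1.1.zip'
    archive_path = audeer.download_url(url, audeer.mkdir(cache_root), verbose=True)
    audeer.extract_archive(archive_path, audeer.mkdir(model_root))
    return audonnx.load(model_root)

def predict_age_gender(model, waveform, sampling_rate=16000):
    result = model(waveform, sampling_rate)
    age_years = float(result['logits_age'].squeeze()) * 100.0  # bucket as in age_gender_apply if labels are needed
    gender = ['female', 'male'][int(np.argmax(result['logits_gender'].squeeze()[:2]))]
    return age_years, gender
```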
capspeech/eval/asr_eval.py ADDED
@@ -0,0 +1,24 @@
1
+ from jiwer import wer as calculate_wer
2
+ from jiwer import cer as calculate_cer
3
+ from whisper.normalizers import EnglishTextNormalizer
4
+ import whisper
5
+ import torch
6
+
7
+ normalizer = EnglishTextNormalizer()
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ whisper_model = whisper.load_model("large-v3-turbo", device=device)
10
+
11
+ def asr(wav_path):
12
+ result = whisper_model.transcribe(wav_path)
13
+ pred = result['text'].strip()
14
+ pred = normalizer(pred)
15
+ return pred
16
+
17
+ if __name__ == '__main__':
18
+ gt_text="Hey, how are you doing today? I like it."
19
+ wav_path="your-audio"
20
+ gt_text = normalizer(gt_text.strip())
21
+ pred_asr = asr(wav_path)
22
+ wer = round(calculate_wer(gt_text, pred_asr), 3)
23
+ cer = round(calculate_cer(gt_text, pred_asr), 3)
24
+ print(wer, cer)
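
The `__main__` block above scores a single utterance. For a whole test set it is usually better to pass lists to jiwer so WER/CER are aggregated over all words rather than averaged per file. A minimal sketch reusing the helpers from this script (the file paths and transcripts are placeholders; it assumes it is run from `capspeech/eval/` so `asr_eval` is importable, which will load the Whisper model on import):

```python
# Corpus-level ASR evaluation sketch (hypothetical test_set entries).
from asr_eval import asr, calculate_cer, calculate_wer, normalizer

test_set = [  # (wav_path, ground-truth text) pairs
    ("your-audio-1.wav", "Hey, how are you doing today?"),
    ("your-audio-2.wav", "I like it."),
]

refs, hyps = [], []
for wav_path, gt_text in test_set:
    refs.append(normalizer(gt_text.strip()))
    hyps.append(asr(wav_path))

print("corpus WER:", round(calculate_wer(refs, hyps), 3))
print("corpus CER:", round(calculate_cer(refs, hyps), 3))
```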
capspeech/eval/base_eval.py ADDED
@@ -0,0 +1,32 @@
1
+ from pitch import pitch_apply
2
+ from speed import speed_apply
3
+ from age_gender import age_gender_apply
4
+ import librosa
5
+ import json
6
+ import bisect
7
+
8
+ SPEAKER_RATE_BINS = ["very slowly", "slowly", "slightly slowly", "moderate speed", "slightly fast", "fast", "very fast"]
9
+ UTTERANCE_LEVEL_STD = ["very monotone", "monotone", "slightly expressive and animated", "expressive and animated", "very expressive and animated"]
10
+ SPEAKER_LEVEL_PITCH_BINS = ["very low-pitch", "low-pitch", "slightly low-pitch", "moderate pitch", "slightly high-pitch", "high-pitch", "very high-pitch"]
11
+ with open("bin.json") as json_file:
12
+ text_bins_dict = json.load(json_file)
13
+
14
+ audiopath = "YOUR_AUDIO_PATH"
15
+ waveform, _ = librosa.load(audiopath, sr=16000)
16
+ age, gender = age_gender_apply(waveform)
17
+ pitch_mean, pitch_std = pitch_apply(waveform)
18
+ if gender == "male":
19
+ index = bisect.bisect_right(text_bins_dict["pitch_bins_male"], pitch_mean) - 1
20
+ pitch = SPEAKER_LEVEL_PITCH_BINS[index]
21
+ else:
22
+ index = bisect.bisect_right(text_bins_dict["pitch_bins_female"], pitch_mean) - 1
23
+ pitch = SPEAKER_LEVEL_PITCH_BINS[index]
24
+
25
+ index = bisect.bisect_right(text_bins_dict["speech_monotony"], pitch_std) - 1
26
+ monotony = UTTERANCE_LEVEL_STD[index]
27
+ speech_duration = speed_apply(waveform)
28
+
29
+ index = bisect.bisect_right(text_bins_dict["speaking_rate"], speech_duration) - 1
30
+ speed = SPEAKER_RATE_BINS[index]
31
+
32
+ print(pitch, monotony, speed, age, gender)
capspeech/eval/bin.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "speaking_rate": [0.0, 3.8258038258038254, 7.651607651607651, 11.477411477411476, 15.303215303215302, 19.129019129019127, 22.95482295482295, 26.78062678062678],
3
+ "noise": [17.12751579284668, 25.4012325831822, 33.67494937351772, 41.94866616385323, 50.22238295418875, 58.49609974452427, 66.76981653485979, 75.04353332519531],
4
+ "reverberation": [10, 35, 45, 55, 59, 60],
5
+ "speech_monotony": [0.0, 20.37920924595424, 40.75841849190848, 70, 90, 142.6544647216797],
6
+ "pitch_bins_male": [64.6531982421875, 81.66683959960938, 98.68048095703125, 115.69412231445312, 132.707763671875, 149.72140502929688, 166.73504638671875, 183.74868774414062],
7
+ "pitch_bins_female": [120.17855072021484, 141.6242690945264, 163.06998746883795, 184.51570584314953, 205.96142421746106, 227.40714259177264, 248.8528609660842, 270.29857934039575],
8
+ "si-sdr": [-17.804332733154297, -0.40644073486328125, 10, 20, 25, 28, 34.38934326171875],
9
+ "pesq": [1, 1.7, 2.4, 3.1, 3.6, 4, 4.499948978424072]
10
+ }
capspeech/eval/pitch.py ADDED
@@ -0,0 +1,30 @@
1
+ import torch
2
+ import penn
3
+
4
+ def pitch_apply(waveform):
5
+ hopsize = .01
6
+ fmin = 30.
7
+ fmax = 1000.
8
+ checkpoint = None
9
+ center = 'half-hop'
10
+ interp_unvoiced_at = .065
11
+ sampling_rate = 16000
12
+ penn_batch_size = 4096
13
+ waveform = torch.Tensor(waveform).unsqueeze(0)
14
+ pitch, periodicity = penn.from_audio(
15
+ waveform.float(),
16
+ sampling_rate,
17
+ hopsize=hopsize,
18
+ fmin=fmin,
19
+ fmax=fmax,
20
+ checkpoint=checkpoint,
21
+ batch_size=penn_batch_size,
22
+ center=center,
23
+ interp_unvoiced_at=interp_unvoiced_at,
24
+ gpu=None
25
+ )
26
+
27
+ pitch_mean = pitch.mean().cpu().numpy()
28
+ pitch_std = pitch.std().cpu().numpy()
29
+
30
+ return pitch_mean, pitch_std
capspeech/eval/requirements.txt ADDED
@@ -0,0 +1,16 @@
1
+ datasets[audio]
2
+ https://github.com/marianne-m/brouhaha-vad/archive/main.zip
3
+ penn
4
+ g2p
5
+ demucs
6
+ transformers
7
+ bitsandbytes
8
+ git+https://github.com/sarulab-speech/UTMOSv2.git
9
+ openai-whisper
10
+ jiwer
11
+ numpy==1.26.4
12
+ audeer
13
+ audonnx
14
+ laion_clap
15
+ numpy==1.26.4
16
+ onnxruntime
capspeech/eval/speed.py ADDED
@@ -0,0 +1,29 @@
1
+ from pyannote.audio import Model
2
+ from pathlib import Path
3
+ from brouhaha.pipeline import RegressiveActivityDetectionPipeline
4
+ import torch
5
+ from huggingface_hub import hf_hub_download
6
+ import numpy as np
7
+
8
+ def speed_apply(waveform):
9
+ ratio = 16000/270
10
+ sampling_rate = 16000
11
+ device = "cpu"
12
+ waveform = torch.Tensor(waveform).unsqueeze(0)
13
+ model = Model.from_pretrained(
14
+ Path(hf_hub_download(repo_id="ylacombe/brouhaha-best", filename="best.ckpt")),
15
+ strict=False,
16
+ )
17
+ model.to(device)
18
+
19
+ pipeline = RegressiveActivityDetectionPipeline(segmentation=model, batch_size=1)
20
+ pipeline.to(torch.device(device))
21
+
22
+ device = pipeline._models["segmentation"].device
23
+
24
+ res = pipeline({"sample_rate": sampling_rate,
25
+ "waveform": waveform.to(device).float()})
26
+
27
+ speech_duration = sum(map(lambda x: x[0].duration, res["annotation"].itertracks()))
28
+
29
+ return speech_duration
capspeech/eval/src/__init__.py ADDED
File without changes
capspeech/eval/src/example/__init__.py ADDED
File without changes
capspeech/eval/src/example/categorized_emotion.py ADDED
@@ -0,0 +1,92 @@
1
+ import torch
2
+ import logging
3
+ import sys, os, pdb
4
+ import torch.nn.functional as F
5
+
6
+ from pathlib import Path
7
+
8
+ sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1])))
9
+ sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1]), 'model', 'emotion'))
10
+
11
+ from wavlm_emotion import WavLMWrapper
12
+ from whisper_emotion import WhisperWrapper
13
+
14
+
15
+ # define logging console
16
+ import logging
17
+ logging.basicConfig(
18
+ format='%(asctime)s %(levelname)-3s ==> %(message)s',
19
+ level=logging.INFO,
20
+ datefmt='%Y-%m-%d %H:%M:%S'
21
+ )
22
+
23
+ os.environ["MKL_NUM_THREADS"] = "1"
24
+ os.environ["NUMEXPR_NUM_THREADS"] = "1"
25
+ os.environ["OMP_NUM_THREADS"] = "1"
26
+
27
+
28
+ if __name__ == '__main__':
29
+
30
+ label_list = [
31
+ 'Anger',
32
+ 'Contempt',
33
+ 'Disgust',
34
+ 'Fear',
35
+ 'Happiness',
36
+ 'Neutral',
37
+ 'Sadness',
38
+ 'Surprise',
39
+ 'Other'
40
+ ]
41
+
42
+ # Find device
43
+ device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
44
+ if torch.cuda.is_available(): print('GPU available, use GPU')
45
+
46
+ # Define the model
47
+ # Note that ensemble yields the better performance than the single model
48
+ # Define the model wrapper
49
+ model_path = "model"
50
+ wavlm_model = model = WavLMWrapper(
51
+ pretrain_model="wavlm_large",
52
+ finetune_method="finetune",
53
+ output_class_num=9,
54
+ freeze_params=True,
55
+ use_conv_output=True,
56
+ detailed_class_num=17
57
+ ).to(device)
58
+
59
+ whisper_model = WhisperWrapper(
60
+ pretrain_model="whisper_large",
61
+ finetune_method="lora",
62
+ lora_rank=16,
63
+ output_class_num=9,
64
+ freeze_params=True,
65
+ use_conv_output=True,
66
+ detailed_class_num=17
67
+ ).to(device)
68
+
69
+ whisper_model.load_state_dict(torch.load(os.path.join(model_path, f"whisper_emotion.pt"), weights_only=True), strict=False)
70
+ whisper_model.load_state_dict(torch.load(os.path.join(model_path, f"whisper_emotion_lora.pt")), strict=False)
71
+ wavlm_model.load_state_dict(torch.load(os.path.join(model_path, f"wavlm_emotion.pt"), weights_only=True), strict=False)
72
+
73
+ wavlm_model.eval()
74
+ whisper_model.eval()
75
+
76
+ # Audio must be sampled at 16 kHz
77
+ data = torch.zeros([1, 16000]).to(device)
78
+ whisper_logits, whisper_embedding, _, _, _, _ = whisper_model(
79
+ data, return_feature=True
80
+ )
81
+ wavlm_logits, wavlm_embedding, _, _, _, _ = wavlm_model(
82
+ data, return_feature=True
83
+ )
84
+
85
+ ensemble_logits = (whisper_logits + wavlm_logits) / 2
86
+ ensemble_prob = F.softmax(ensemble_logits, dim=1)
87
+
88
+ print(ensemble_prob.shape)
89
+ print(whisper_embedding.shape)
90
+ print(wavlm_embedding.shape)
91
+ print(label_list[torch.argmax(ensemble_prob).detach().cpu().item()])
92
+
capspeech/eval/src/example/dialect_world_dialect.py ADDED
@@ -0,0 +1,87 @@
1
+ import torch
2
+ import sys, os, pdb
3
+ import argparse, logging
4
+ import torch.nn.functional as F
5
+
6
+ from pathlib import Path
7
+
8
+
9
+ sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1])))
10
+ sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1]), 'model', 'dialect'))
11
+
12
+ from wavlm_dialect import WavLMWrapper
13
+ from whisper_dialect import WhisperWrapper
14
+
15
+
16
+ # define logging console
17
+ import logging
18
+ logging.basicConfig(
19
+ format='%(asctime)s %(levelname)-3s ==> %(message)s',
20
+ level=logging.INFO,
21
+ datefmt='%Y-%m-%d %H:%M:%S'
22
+ )
23
+
24
+ os.environ["MKL_NUM_THREADS"] = "1"
25
+ os.environ["NUMEXPR_NUM_THREADS"] = "1"
26
+ os.environ["OMP_NUM_THREADS"] = "1"
27
+
28
+
29
+ if __name__ == '__main__':
30
+
31
+
32
+ label_list = [
33
+ 'East Asia', 'English', 'Germanic', 'Irish',
34
+ 'North America', 'Northern Irish', 'Oceania',
35
+ 'Other', 'Romance', 'Scottish', 'Semitic', 'Slavic',
36
+ 'South African', 'Southeast Asia', 'South Asia', 'Welsh'
37
+ ]
38
+
39
+ # Find device
40
+ device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
41
+ if torch.cuda.is_available(): print('GPU available, use GPU')
42
+
43
+ # Define the model
44
+ # Note that the ensemble yields better performance than a single model
45
+ model_path = "YOUR_PATH"
46
+ # Define the model wrapper
47
+ wavlm_model = model = WavLMWrapper(
48
+ pretrain_model="wavlm_large",
49
+ finetune_method="lora",
50
+ lora_rank=16,
51
+ output_class_num=16,
52
+ freeze_params=False,
53
+ use_conv_output=True,
54
+ apply_gradient_reversal=False,
55
+ num_dataset=3
56
+ ).to(device)
57
+
58
+ whisper_model = WhisperWrapper(
59
+ pretrain_model="whisper_large",
60
+ finetune_method="lora",
61
+ lora_rank=16,
62
+ output_class_num=16,
63
+ freeze_params=False,
64
+ use_conv_output=True,
65
+ apply_gradient_reversal=False,
66
+ num_dataset=11
67
+ ).to(device)
68
+
69
+ wavlm_model.load_state_dict(torch.load(os.path.join(model_path, f"wavlm_world_dialect.pt"), weights_only=True), strict=False)
70
+ wavlm_model.load_state_dict(torch.load(os.path.join(model_path, f"wavlm_world_dialect_lora.pt")), strict=False)
71
+
72
+ whisper_model.load_state_dict(torch.load(os.path.join(model_path, f"whisper_world_dialect.pt"), weights_only=True), strict=False)
73
+ whisper_model.load_state_dict(torch.load(os.path.join(model_path, f"whisper_world_dialect_lora.pt")), strict=False)
74
+
75
+ wavlm_model.eval()
76
+ whisper_model.eval()
77
+
78
+ data = torch.zeros([1, 16000]).to(device)
79
+ wavlm_logits, wavlm_embeddings = wavlm_model(data, return_feature=True)
80
+ whisper_logits, whisper_embeddings = whisper_model(data, return_feature=True)
81
+
82
+ ensemble_logits = (wavlm_logits + whisper_logits) / 2
83
+ ensemble_prob = F.softmax(ensemble_logits, dim=1)
84
+
85
+ pred = label_list[ensemble_prob.argmax(-1)]
86
+ print(pred)
87
+
capspeech/eval/src/model/__init__.py ADDED
File without changes
capspeech/eval/src/model/adapter.py ADDED
@@ -0,0 +1,73 @@
1
+ # --------------------------------------------------------
2
+ # References:
3
+ # https://github.com/jxhe/unify-parameter-efficient-tuning
4
+ # --------------------------------------------------------
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ class Adapter(nn.Module):
12
+ def __init__(
13
+ self,
14
+ config=None,
15
+ d_model=768,
16
+ bottleneck=None,
17
+ dropout=0.0,
18
+ init_option="lora",
19
+ adapter_scalar="1.0",
20
+ adapter_layernorm_option="none"
21
+ ):
22
+ super().__init__()
23
+ self.n_embd = config.d_model if d_model is None else d_model
24
+ self.down_size = config.attn_bn if bottleneck is None else bottleneck
25
+
26
+ #_before
27
+ self.adapter_layernorm_option = adapter_layernorm_option
28
+
29
+ self.adapter_layer_norm_before = None
30
+ if adapter_layernorm_option == "in" or adapter_layernorm_option == "out":
31
+ self.adapter_layer_norm_before = nn.LayerNorm(self.n_embd)
32
+
33
+ if adapter_scalar == "learnable_scalar":
34
+ self.scale = nn.Parameter(torch.ones(1))
35
+ else:
36
+ self.scale = float(adapter_scalar)
37
+
38
+ self.down_proj = nn.Linear(self.n_embd, self.down_size)
39
+ self.non_linear_func = nn.ReLU()
40
+ self.up_proj = nn.Linear(self.down_size, self.n_embd)
41
+
42
+ self.dropout = dropout
43
+ if init_option == "bert":
44
+ raise NotImplementedError
45
+ elif init_option == "lora":
46
+ with torch.no_grad():
47
+ nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
48
+ nn.init.zeros_(self.up_proj.weight)
49
+ nn.init.zeros_(self.down_proj.bias)
50
+ nn.init.zeros_(self.up_proj.bias)
51
+
52
+ def forward(self, x, add_residual=True, residual=None):
53
+ residual = x if residual is None else residual
54
+ if self.adapter_layernorm_option == 'in':
55
+ x = self.adapter_layer_norm_before(x)
56
+
57
+ down = self.down_proj(x)
58
+
59
+ down = self.non_linear_func(down)
60
+ down = nn.functional.dropout(down, p=self.dropout, training=self.training)
61
+ up = self.up_proj(down)
62
+
63
+ up = up * self.scale
64
+
65
+ if self.adapter_layernorm_option == 'out':
66
+ up = self.adapter_layer_norm_before(up)
67
+
68
+ if add_residual:
69
+ output = up + residual
70
+ else:
71
+ output = up
72
+
73
+ return output
capspeech/eval/src/model/dialect/__init__.py ADDED
File without changes
capspeech/eval/src/model/dialect/wavlm_dialect.py ADDED
@@ -0,0 +1,300 @@
1
+ import os
2
+ import pdb
3
+ import copy
4
+ import torch
5
+ import argparse
6
+ import loralib as lora
7
+ import transformers.models.wavlm.modeling_wavlm as wavlm
8
+ from speechbrain.nnet.normalization import LayerNorm
9
+ from speechbrain.lobes.models.huggingface_transformers.huggingface import make_padding_masks
10
+
11
+ from torch import nn
12
+ from torch.nn import functional as F
13
+ from transformers import Wav2Vec2FeatureExtractor
14
+ from transformers import WavLMModel
15
+
16
+ import sys
17
+ from pathlib import Path
18
+ sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1])))
19
+ from revgrad import RevGrad
20
+
21
+ class WavLMEncoderLayer(nn.Module):
22
+ def __init__(self, layer_idx, config, has_relative_position_bias: bool = True):
23
+ super().__init__()
24
+ self.attention = wavlm.WavLMAttention(
25
+ embed_dim=config.hidden_size,
26
+ num_heads=config.num_attention_heads,
27
+ dropout=config.attention_dropout,
28
+ num_buckets=config.num_buckets,
29
+ max_distance=config.max_bucket_distance,
30
+ has_relative_position_bias=has_relative_position_bias,
31
+ )
32
+ self.dropout = nn.Dropout(config.hidden_dropout)
33
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
34
+ self.feed_forward = wavlm.WavLMFeedForward(config)
35
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
36
+ self.config = config
37
+
38
+ if layer_idx > config.num_hidden_layers // 2:
39
+ if self.config.finetune_method == "lora" or self.config.finetune_method == "combined":
40
+ self.feed_forward.intermediate_dense = lora.Linear(config.hidden_size, config.intermediate_size, r=config.lora_rank)
41
+ self.feed_forward.output_dense = lora.Linear(config.intermediate_size, config.hidden_size, r=config.lora_rank)
42
+
43
+ def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
44
+
45
+ attn_residual = hidden_states
46
+ hidden_states, attn_weights, position_bias = self.attention(
47
+ hidden_states,
48
+ attention_mask=attention_mask,
49
+ position_bias=position_bias,
50
+ output_attentions=output_attentions,
51
+ index=index,
52
+ )
53
+ hidden_states = self.dropout(hidden_states)
54
+ hidden_states = attn_residual + hidden_states
55
+
56
+ hidden_states = self.layer_norm(hidden_states)
57
+ hidden_states = hidden_states + self.feed_forward(hidden_states)
58
+ hidden_states = self.final_layer_norm(hidden_states)
59
+ outputs = (hidden_states, position_bias)
60
+
61
+ if output_attentions:
62
+ outputs += (attn_weights,)
63
+
64
+ return outputs
65
+
66
+
67
+ class WavLMEncoderLayerStableLayerNorm(nn.Module):
68
+ def __init__(self, layer_idx, config, has_relative_position_bias: bool = True):
69
+ super().__init__()
70
+ self.attention = wavlm.WavLMAttention(
71
+ embed_dim=config.hidden_size,
72
+ num_heads=config.num_attention_heads,
73
+ dropout=config.attention_dropout,
74
+ num_buckets=config.num_buckets,
75
+ max_distance=config.max_bucket_distance,
76
+ has_relative_position_bias=has_relative_position_bias,
77
+ )
78
+ self.dropout = nn.Dropout(config.hidden_dropout)
79
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
80
+ self.feed_forward = wavlm.WavLMFeedForward(config)
81
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
82
+ self.config = config
83
+
84
+ if layer_idx > config.num_hidden_layers // 2:
85
+ if self.config.finetune_method == "lora" or self.config.finetune_method == "combined":
86
+ self.feed_forward.intermediate_dense = lora.Linear(config.hidden_size, config.intermediate_size, r=config.lora_rank)
87
+ self.feed_forward.output_dense = lora.Linear(config.intermediate_size, config.hidden_size, r=config.lora_rank)
88
+
89
+
90
+ def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False):
91
+ attn_residual = hidden_states
92
+ hidden_states = self.layer_norm(hidden_states)
93
+ hidden_states, attn_weights, position_bias = self.attention(
94
+ hidden_states,
95
+ attention_mask=attention_mask,
96
+ position_bias=position_bias,
97
+ output_attentions=output_attentions,
98
+ )
99
+ hidden_states = self.dropout(hidden_states)
100
+ hidden_states = attn_residual + hidden_states
101
+ hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
102
+
103
+ outputs = (hidden_states, position_bias)
104
+
105
+ if output_attentions:
106
+ outputs += (attn_weights,)
107
+
108
+ return outputs
109
+
110
+
111
+ class WavLMWrapper(nn.Module):
112
+ def __init__(
113
+ self,
114
+ pretrain_model="wavlm_large",
115
+ hidden_dim=256,
116
+ finetune_method="lora",
117
+ lora_rank=16,
118
+ freeze_params=True,
119
+ output_class_num=4,
120
+ use_conv_output=True,
121
+ apply_gradient_reversal=False,
122
+ num_dataset=4
123
+ ):
124
+ super(WavLMWrapper, self).__init__()
125
+ # 1. We Load the model first with weights
126
+ if pretrain_model == "wavlm":
127
+ self.backbone_model = WavLMModel.from_pretrained(
128
+ "microsoft/wavlm-base-plus",
129
+ output_hidden_states=True,
130
+ )
131
+ elif pretrain_model == "wavlm_large":
132
+ self.processor = Wav2Vec2FeatureExtractor.from_pretrained('microsoft/wavlm-large')
133
+ self.backbone_model = WavLMModel.from_pretrained(
134
+ "microsoft/wavlm-large",
135
+ output_hidden_states=True,
136
+ )
137
+ self.pretrain_model = pretrain_model
138
+ self.finetune_method = finetune_method
139
+ self.apply_gradient_reversal = apply_gradient_reversal
140
+ self.use_conv_output = use_conv_output
141
+
142
+ state_dict = self.backbone_model.state_dict()
143
+ # 2. Read the model config
144
+ self.model_config = self.backbone_model.config
145
+ self.model_config.finetune_method = finetune_method
146
+ self.model_config.lora_rank = lora_rank
147
+
148
+ # 3. Config encoder layers with adapter or embedding prompt
149
+ if self.pretrain_model == "wavlm":
150
+ self.backbone_model.encoder.layers = nn.ModuleList(
151
+ [WavLMEncoderLayer(i, self.model_config, has_relative_position_bias=(i == 0)) for i in range(self.model_config.num_hidden_layers)]
152
+ )
153
+ elif self.pretrain_model == "wavlm_large":
154
+ self.backbone_model.encoder.layers = nn.ModuleList(
155
+ [WavLMEncoderLayerStableLayerNorm(i, self.model_config, has_relative_position_bias=(i == 0)) for i in range(self.model_config.num_hidden_layers)]
156
+ )
157
+ # 4. Load the weights back
158
+ msg = self.backbone_model.load_state_dict(state_dict, strict=False)
159
+
160
+ # 5. Freeze the weights
161
+ self.freeze_params = freeze_params
162
+ if self.freeze_params and self.finetune_method != "lora":
163
+ for _, p in self.backbone_model.named_parameters(): p.requires_grad = False
164
+ elif self.freeze_params and self.finetune_method == "lora":
165
+ for name, p in self.backbone_model.named_parameters():
166
+ if name in msg.missing_keys: p.requires_grad = True
167
+ else: p.requires_grad = False
168
+ else:
169
+ for _, p in self.backbone_model.named_parameters(): p.requires_grad = True
170
+
171
+ # 6. Downstream models
172
+ self.model_seq = nn.Sequential(
173
+ nn.Conv1d(self.model_config.hidden_size, hidden_dim, 1, padding=0),
174
+ nn.ReLU(),
175
+ nn.Dropout(p=0.1),
176
+ nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0),
177
+ nn.ReLU(),
178
+ nn.Dropout(p=0.1),
179
+ nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0)
180
+ )
181
+
182
+ if self.use_conv_output:
183
+ num_layers = self.model_config.num_hidden_layers + 1 # transformer layers + input embeddings
184
+ self.weights = nn.Parameter(torch.ones(num_layers)/num_layers)
185
+ else:
186
+ num_layers = self.model_config.num_hidden_layers
187
+ self.weights = nn.Parameter(torch.zeros(num_layers))
188
+
189
+ if apply_gradient_reversal:
190
+ self.dataset_layer = nn.Sequential(
191
+ RevGrad(),
192
+ nn.Linear(hidden_dim, hidden_dim),
193
+ nn.ReLU(),
194
+ nn.Linear(hidden_dim, num_dataset),
195
+ )
196
+
197
+ self.out_layer = nn.Sequential(
198
+ nn.Linear(hidden_dim, hidden_dim),
199
+ nn.ReLU(),
200
+ nn.Linear(hidden_dim, output_class_num),
201
+ )
202
+
203
+ def forward(self, x, length=None, return_feature=False):
204
+ # 1. feature extraction and projections
205
+ if self.pretrain_model == "wavlm_large":
206
+ with torch.no_grad():
207
+ signal, attention_mask = list(), list()
208
+ if length is not None: attention_mask = make_padding_masks(x, wav_len=length/length.max()).to(x.device)
209
+ else: attention_mask = make_padding_masks(x, wav_len=torch.tensor([1]).to(x.device)).to(x.device)
210
+
211
+ for idx in range(len(x)):
212
+ input = self.processor(x[idx], sampling_rate=16_000, return_tensors="pt", padding=True)
213
+ signal.append(input["input_values"][0].to(x.device))
214
+ signal = torch.stack(signal)
215
+
216
+ # 2. get length and mask
217
+ if length is not None:
218
+ length = self.get_feat_extract_output_lengths(length.detach().cpu())
219
+ length = length.cuda()
220
+
221
+ if self.pretrain_model == "wavlm":
222
+ x = self.backbone_model(
223
+ x, output_hidden_states=True
224
+ ).hidden_states
225
+ else:
226
+ x = self.backbone_model(
227
+ signal,
228
+ attention_mask=attention_mask,
229
+ output_hidden_states=True
230
+ ).hidden_states
231
+
232
+ # 4. stacked feature
233
+ if self.use_conv_output: stacked_feature = torch.stack(x, dim=0)
234
+ else: stacked_feature = torch.stack(x, dim=0)[1:]
235
+
236
+ # 5. Weighted sum
237
+ _, *origin_shape = stacked_feature.shape
238
+ # Return transformer enc outputs [num_enc_layers, B, T, D]
239
+ if self.use_conv_output:
240
+ stacked_feature = stacked_feature.view(self.backbone_model.config.num_hidden_layers+1, -1)
241
+ else:
242
+ stacked_feature = stacked_feature.view(self.backbone_model.config.num_hidden_layers, -1)
243
+ norm_weights = F.softmax(self.weights, dim=-1)
244
+
245
+ # Perform weighted average
246
+ weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0)
247
+ features = weighted_feature.view(*origin_shape)
248
+
249
+ # 6. Pass the weighted average to point-wise 1D Conv
250
+ # B x T x D
251
+ features = features.transpose(1, 2)
252
+ features = self.model_seq(features)
253
+ features = features.transpose(1, 2)
254
+
255
+ # 7. Pooling
256
+ if length is not None:
257
+ mean, std = list(), list()
258
+ for snt_id in range(features.shape[0]):
259
+ # Avoiding padded time steps
260
+ actual_size = length[snt_id]
261
+ mean.append(torch.mean(features[snt_id, 0:actual_size, ...], dim=0))
262
+ features = torch.stack(mean)
263
+ else:
264
+ features = torch.mean(features, dim=1)
265
+
266
+ # 8. Output predictions
267
+ # B x D
268
+ predicted = self.out_layer(features)
269
+ if self.apply_gradient_reversal:
270
+ dataset_predicted = self.dataset_layer(features)
271
+ if return_feature: return predicted, dataset_predicted, features
272
+ return predicted, dataset_predicted
273
+ if return_feature: return predicted, features
274
+ return predicted
275
+
276
+ # From huggingface
277
+ def get_feat_extract_output_lengths(self, input_length):
278
+ """
279
+ Computes the output length of the convolutional layers
280
+ """
281
+ def _conv_out_length(input_length, kernel_size, stride):
282
+ # 1D convolutional layer output length formula taken
283
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
284
+ return (input_length - kernel_size) // stride + 1
285
+ for kernel_size, stride in zip(self.backbone_model.config.conv_kernel, self.backbone_model.config.conv_stride):
286
+ input_length = _conv_out_length(input_length, kernel_size, stride)
287
+ return input_length
288
+
289
+ def prepare_mask(length, shape, dtype):
290
+ # Modified from huggingface
291
+ mask = torch.zeros(
292
+ shape, dtype=dtype
293
+ )
294
+ # these two operations makes sure that all values
295
+ # before the output lengths indices are attended to
296
+ mask[(torch.arange(mask.shape[0]), length.cpu() - 1)] = 1
297
+ mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
298
+ return mask
299
+
300
+
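A hedged usage sketch for the `WavLMWrapper` dialect classifier defined above follows. The import path is an assumption, the classification head is random until a trained checkpoint is loaded, and the call uses the single-utterance path (`length=None`), which keeps everything on CPU:

```python
import torch
from capspeech.eval.src.model.dialect.wavlm_dialect import WavLMWrapper  # assumed import path

# Sketch only: this downloads microsoft/wavlm-large; the dialect head stays
# randomly initialized until trained weights are loaded.
model = WavLMWrapper(pretrain_model="wavlm_large", output_class_num=4).eval()

wav = torch.randn(1, 3 * 16000)      # one 3-second utterance at 16 kHz, shape (B, T)
with torch.no_grad():
    logits = model(wav)              # length=None -> no .cuda() call inside forward
print(logits.shape)                  # torch.Size([1, 4])
```

With `apply_gradient_reversal=True` the forward pass instead returns `(dialect_logits, dataset_logits)`, and passing `length` (a tensor of per-utterance sample counts) enables masked mean pooling but moves the lengths to CUDA.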
capspeech/eval/src/model/dialect/whisper_dialect.py ADDED
@@ -0,0 +1,301 @@
1
+ import os
2
+ import pdb
3
+ import copy
4
+ import torch
5
+ import argparse
6
+ import numpy as np
7
+ import loralib as lora
8
+ import transformers.models.whisper.modeling_whisper as whisper
9
+
10
+ from torch import nn
11
+ from torch.nn import functional as F
12
+ from transformers.activations import ACT2FN
13
+ from transformers import WhisperModel, AutoFeatureExtractor
14
+
15
+ import sys
16
+ from pathlib import Path
17
+ sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1])))
18
+ from revgrad import RevGrad
19
+
20
+ class WhisperEncoderLayer(nn.Module):
21
+ def __init__(self, config, layer_idx):
22
+ super().__init__()
23
+ self.embed_dim = config.d_model
24
+ self.self_attn = whisper.WhisperAttention(
25
+ embed_dim=self.embed_dim,
26
+ num_heads=config.encoder_attention_heads,
27
+ dropout=config.attention_dropout,
28
+ )
29
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
30
+ self.dropout = config.dropout
31
+ self.activation_fn = ACT2FN[config.activation_function]
32
+ self.activation_dropout = config.activation_dropout
33
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
34
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
35
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
36
+ self.config = config
37
+
38
+ if layer_idx > config.encoder_layers // 2:
39
+ if self.config.finetune_method == "lora" or self.config.finetune_method == "combined":
40
+ self.fc1 = lora.Linear(self.embed_dim, config.encoder_ffn_dim, r=config.lora_rank)
41
+ self.fc2 = lora.Linear(config.encoder_ffn_dim, self.embed_dim, r=config.lora_rank)
42
+
43
+ def forward(
44
+ self,
45
+ hidden_states: torch.Tensor,
46
+ attention_mask: torch.Tensor,
47
+ layer_head_mask: torch.Tensor,
48
+ output_attentions: bool = False,
49
+ ):
50
+ """
51
+ Args:
52
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
53
+ attention_mask (`torch.FloatTensor`): attention mask of size
54
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
55
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
56
+ `(encoder_attention_heads,)`.
57
+ output_attentions (`bool`, *optional*):
58
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
59
+ returned tensors for more detail.
60
+ """
61
+ residual = hidden_states
62
+ hidden_states = self.self_attn_layer_norm(hidden_states)
63
+ hidden_states, attn_weights, _ = self.self_attn(
64
+ hidden_states=hidden_states,
65
+ attention_mask=attention_mask,
66
+ layer_head_mask=layer_head_mask,
67
+ output_attentions=output_attentions,
68
+ )
69
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
70
+ hidden_states = residual + hidden_states
71
+ residual = hidden_states
72
+
73
+ hidden_states = self.final_layer_norm(hidden_states)
74
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
75
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
76
+ hidden_states = self.fc2(hidden_states)
77
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
78
+ hidden_states = residual + hidden_states
79
+
80
+ if hidden_states.dtype == torch.float16 and (
81
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
82
+ ):
83
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
84
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
85
+ outputs = (hidden_states,)
86
+
87
+ if output_attentions:
88
+ outputs += (attn_weights,)
89
+
90
+ return outputs
91
+
92
+ class WhisperWrapper(nn.Module):
93
+ def __init__(
94
+ self,
95
+ pretrain_model="whisper_large",
96
+ output_class_num=4,
97
+ hidden_dim=256,
98
+ finetune_method="lora",
99
+ lora_rank=16,
100
+ freeze_params=True,
101
+ use_conv_output=True,
102
+ apply_gradient_reversal=False,
103
+ num_dataset=4
104
+ ):
105
+ super(WhisperWrapper, self).__init__()
106
+ # 1. We Load the model first with weights
107
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny", chunk_length=15)
108
+ self.pretrain_model = pretrain_model
109
+ if self.pretrain_model == "whisper_tiny":
110
+ self.backbone_model = WhisperModel.from_pretrained(
111
+ "openai/whisper-tiny",
112
+ output_hidden_states=True,
113
+ ignore_mismatched_sizes=True,
114
+ max_source_positions=750,
115
+ )
116
+ elif self.pretrain_model == "whisper_base":
117
+ self.backbone_model = WhisperModel.from_pretrained(
118
+ "openai/whisper-base",
119
+ output_hidden_states=True,
120
+ ignore_mismatched_sizes=True,
121
+ max_source_positions=750,
122
+ )
123
+ elif self.pretrain_model == "whisper_small":
124
+ self.backbone_model = WhisperModel.from_pretrained(
125
+ "openai/whisper-small",
126
+ output_hidden_states=True,
127
+ max_source_positions=750,
128
+ ignore_mismatched_sizes=True
129
+ )
130
+ elif self.pretrain_model == "whisper_medium":
131
+ self.backbone_model = WhisperModel.from_pretrained(
132
+ "openai/whisper-medium",
133
+ output_hidden_states=True,
134
+ ignore_mismatched_sizes=True
135
+ )
136
+ elif self.pretrain_model == "whisper_large":
137
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large-v3", chunk_length=15)
138
+ self.backbone_model = WhisperModel.from_pretrained(
139
+ "openai/whisper-large-v3",
140
+ output_hidden_states=True,
141
+ ignore_mismatched_sizes=True,
142
+ max_source_positions=750,
143
+ )
144
+ self.embed_positions = copy.deepcopy(self.backbone_model.encoder.embed_positions.weight)
145
+ self.embed_positions.requires_grad = False
146
+
147
+ state_dict = self.backbone_model.state_dict()
148
+ # 2. Read the model config
149
+ self.model_config = self.backbone_model.config
150
+ self.model_config.finetune_method = finetune_method
151
+ self.model_config.lora_rank = lora_rank
152
+ self.finetune_method = finetune_method
153
+ self.apply_gradient_reversal = apply_gradient_reversal
154
+ self.use_conv_output = use_conv_output
155
+
156
+ if self.finetune_method == "lora":
157
+ # 3. Config encoder layers with adapter or embedding prompt
158
+ self.backbone_model.encoder.layers = nn.ModuleList(
159
+ [WhisperEncoderLayer(self.model_config, layer_idx) for layer_idx in range(self.model_config.encoder_layers)]
160
+ )
161
+ # 4. Load the weights back
162
+ msg = self.backbone_model.load_state_dict(state_dict, strict=False)
163
+
164
+ # 2. Freeze the weights
165
+ self.freeze_params = freeze_params
166
+ if self.freeze_params and self.finetune_method != "lora":
167
+ for _, p in self.backbone_model.named_parameters(): p.requires_grad = False
168
+ elif self.freeze_params and self.finetune_method == "lora":
169
+ for name, p in self.backbone_model.named_parameters():
170
+ if name in msg.missing_keys: p.requires_grad = True
171
+ else: p.requires_grad = False
172
+ else:
173
+ for name, p in self.backbone_model.named_parameters():
174
+ if "decoder" not in name and "conv1" not in name and "conv2" not in name and "embed_positions" not in name: p.requires_grad = True
175
+ else: p.requires_grad = False
176
+
177
+ # 6. Downstream models
178
+ self.model_seq = nn.Sequential(
179
+ nn.Conv1d(self.model_config.hidden_size, hidden_dim, 1, padding=0),
180
+ nn.ReLU(),
181
+ nn.Dropout(p=0.1),
182
+ nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0),
183
+ nn.ReLU(),
184
+ nn.Dropout(p=0.1),
185
+ nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0)
186
+ )
187
+
188
+ if use_conv_output:
189
+ num_layers = self.model_config.num_hidden_layers + 1 # transformer layers + input embeddings
190
+ self.weights = nn.Parameter(torch.ones(num_layers)/num_layers)
191
+ else:
192
+ num_layers = self.model_config.num_hidden_layers
193
+ self.weights = nn.Parameter(torch.zeros(num_layers))
194
+
195
+ if apply_gradient_reversal:
196
+ self.dataset_layer = nn.Sequential(
197
+ RevGrad(),
198
+ nn.Linear(hidden_dim, hidden_dim),
199
+ nn.ReLU(),
200
+ nn.Linear(hidden_dim, num_dataset),
201
+ )
202
+ self.out_layer = nn.Sequential(
203
+ nn.Linear(hidden_dim, hidden_dim),
204
+ nn.ReLU(),
205
+ nn.Linear(hidden_dim, output_class_num),
206
+ )
207
+
208
+
209
+ def forward(self, x, length=None, return_feature=False):
210
+ # 1. feature extraction and projections
211
+ if length is not None:
212
+ max_audio_len = 15*16000
213
+ # Append to list for feature_extractor to work
214
+ new_x = list()
215
+ for idx in range(len(length)):
216
+ new_x.append(x[idx].detach().cpu().numpy())
217
+
218
+ # Max length is max audio len in a batch
219
+ features = self.feature_extractor(
220
+ new_x,
221
+ return_tensors="pt",
222
+ sampling_rate=16000,
223
+ max_length=max_audio_len
224
+ )
225
+ features = features.input_features.cuda()
226
+ else:
227
+ max_audio_len = 15*16000
228
+ features = self.feature_extractor(
229
+ x[0].detach().cpu(),
230
+ return_tensors="pt",
231
+ sampling_rate=16000,
232
+ max_length=max_audio_len
233
+ )
234
+ features = features.input_features.cuda()
235
+
236
+ # 2. get length and mask
237
+ if length is not None:
238
+ length = self._get_feat_extract_output_lengths(length.detach().cpu())
239
+ # Replace positional embeddings
240
+ self.backbone_model.encoder.embed_positions = self.backbone_model.encoder.embed_positions.from_pretrained(self.embed_positions[:750])
241
+ else:
242
+ # Replace positional embeddings
243
+ length = torch.tensor([len(x[0])])
244
+ length = self._get_feat_extract_output_lengths(length)
245
+ self.backbone_model.encoder.embed_positions = self.backbone_model.encoder.embed_positions.from_pretrained(self.embed_positions[:750])
246
+
247
+ # 3. transformer encoding features
248
+ # compute reduced attention_mask corresponding to feature vectors
249
+ features = self.backbone_model.encoder(
250
+ features, output_hidden_states=True
251
+ ).hidden_states
252
+
253
+ features = torch.stack(features, dim=0)[-1]
254
+
255
+ # 6. Pass the last encoder hidden state to point-wise 1D Conv
256
+ # B x T x D
257
+ features = features.transpose(1, 2)
258
+ features = self.model_seq(features)
259
+ features = features.transpose(1, 2)
260
+
261
+ # 7. Pooling
262
+ if length is not None:
263
+ mean, std = list(), list()
264
+ for snt_id in range(features.shape[0]):
265
+ # Avoiding padded time steps
266
+ actual_size = length[snt_id]
267
+ mean.append(torch.mean(features[snt_id, 0:actual_size, ...], dim=0))
268
+ features = torch.stack(mean)
269
+ else:
270
+ features = torch.mean(features, dim=1)
271
+
272
+ # 8. Output predictions
273
+ # B x D
274
+ predicted = self.out_layer(features)
275
+ if self.apply_gradient_reversal:
276
+ dataset_predicted = self.dataset_layer(features)
277
+ if return_feature: return predicted, dataset_predicted, features
278
+ return predicted, dataset_predicted
279
+ if return_feature: return predicted, features
280
+ return predicted
281
+
282
+ # From huggingface
283
+ def _get_feat_extract_output_lengths(self, input_lengths):
284
+ """
285
+ Computes the output length of the convolutional layers
286
+ """
287
+ input_lengths = input_lengths // 160
288
+ input_lengths = (input_lengths - 1) // 2 + 1
289
+ return input_lengths
290
+
291
+ def prepare_mask(length, shape, dtype):
292
+ # Modified from huggingface
293
+ mask = torch.zeros(
294
+ shape, dtype=dtype
295
+ )
296
+ # these two operations makes sure that all values
297
+ # before the output lengths indices are attended to
298
+ mask[(torch.arange(mask.shape[0]), length.cpu() - 1)] = 1
299
+ mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
300
+ return mask
301
+
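Below is a hedged usage sketch for the `WhisperWrapper` above. The import path is an assumption, and a GPU is assumed because `forward()` moves the extracted log-mel features to CUDA unconditionally; positional embeddings are truncated to 750 positions (15 s of audio), matching the `max_source_positions=750` override:

```python
import torch
from capspeech.eval.src.model.dialect.whisper_dialect import WhisperWrapper  # assumed import path

# Sketch only: the classifier head (and the resized positional embeddings) are
# untrained until a checkpoint is loaded.
model = WhisperWrapper(pretrain_model="whisper_tiny", output_class_num=4).cuda().eval()

wav = torch.randn(1, 3 * 16000).cuda()   # one 3-second utterance at 16 kHz
with torch.no_grad():
    logits = model(wav)                  # single-utterance path (length=None)
print(logits.shape)                      # torch.Size([1, 4])
```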
capspeech/eval/src/model/emotion/__init__.py ADDED
File without changes
capspeech/eval/src/model/emotion/wavlm_emotion.py ADDED
@@ -0,0 +1,315 @@
1
+ import os
2
+ import torch
3
+ import loralib as lora
4
+ import transformers.models.wavlm.modeling_wavlm as wavlm
5
+ from speechbrain.lobes.models.huggingface_transformers.huggingface import make_padding_masks
6
+
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from transformers import Wav2Vec2FeatureExtractor
10
+ from transformers import WavLMModel
11
+
12
+ import sys
13
+ from pathlib import Path
14
+ sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1])))
15
+
16
+ class WavLMEncoderLayer(nn.Module):
17
+ def __init__(self, layer_idx, config, has_relative_position_bias: bool = True):
18
+ super().__init__()
19
+ self.attention = wavlm.WavLMAttention(
20
+ embed_dim=config.hidden_size,
21
+ num_heads=config.num_attention_heads,
22
+ dropout=config.attention_dropout,
23
+ num_buckets=config.num_buckets,
24
+ max_distance=config.max_bucket_distance,
25
+ has_relative_position_bias=has_relative_position_bias,
26
+ )
27
+ self.dropout = nn.Dropout(config.hidden_dropout)
28
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
29
+ self.feed_forward = wavlm.WavLMFeedForward(config)
30
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
31
+ self.config = config
32
+
33
+ if layer_idx > config.num_hidden_layers // 2:
34
+ if self.config.finetune_method == "lora" or self.config.finetune_method == "combined":
35
+ self.feed_forward.intermediate_dense = lora.Linear(config.hidden_size, config.intermediate_size, r=config.lora_rank)
36
+ self.feed_forward.output_dense = lora.Linear(config.intermediate_size, config.hidden_size, r=config.lora_rank)
37
+
38
+ def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
39
+ attn_residual = hidden_states
40
+ hidden_states, attn_weights, position_bias = self.attention(
41
+ hidden_states,
42
+ attention_mask=attention_mask,
43
+ position_bias=position_bias,
44
+ output_attentions=output_attentions,
45
+ index=index,
46
+ )
47
+ hidden_states = self.dropout(hidden_states)
48
+ hidden_states = attn_residual + hidden_states
49
+
50
+ # Adapter
51
+ if self.config.finetune_method == "adapter":
52
+ adapt_h = self.adapter(hidden_states)
53
+
54
+ hidden_states = self.layer_norm(hidden_states)
55
+ hidden_states = hidden_states + self.feed_forward(hidden_states)
56
+ hidden_states = self.final_layer_norm(hidden_states)
57
+ outputs = (hidden_states, position_bias)
58
+
59
+ if output_attentions:
60
+ outputs += (attn_weights,)
61
+
62
+ return outputs
63
+
64
+
65
+ class WavLMEncoderLayerStableLayerNorm(nn.Module):
66
+ def __init__(self, layer_idx, config, has_relative_position_bias: bool = True):
67
+ super().__init__()
68
+ self.attention = wavlm.WavLMAttention(
69
+ embed_dim=config.hidden_size,
70
+ num_heads=config.num_attention_heads,
71
+ dropout=config.attention_dropout,
72
+ num_buckets=config.num_buckets,
73
+ max_distance=config.max_bucket_distance,
74
+ has_relative_position_bias=has_relative_position_bias,
75
+ )
76
+ self.dropout = nn.Dropout(config.hidden_dropout)
77
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
78
+ self.feed_forward = wavlm.WavLMFeedForward(config)
79
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
80
+ self.config = config
81
+
82
+ if layer_idx > config.num_hidden_layers // 2:
83
+ if self.config.finetune_method == "lora" or self.config.finetune_method == "combined":
84
+ self.feed_forward.intermediate_dense = lora.Linear(config.hidden_size, config.intermediate_size, r=config.lora_rank)
85
+ self.feed_forward.output_dense = lora.Linear(config.intermediate_size, config.hidden_size, r=config.lora_rank)
86
+
87
+
88
+ def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False):
89
+ attn_residual = hidden_states
90
+ hidden_states = self.layer_norm(hidden_states)
91
+ hidden_states, attn_weights, position_bias = self.attention(
92
+ hidden_states,
93
+ attention_mask=attention_mask,
94
+ position_bias=position_bias,
95
+ output_attentions=output_attentions,
96
+ )
97
+ hidden_states = self.dropout(hidden_states)
98
+ hidden_states = attn_residual + hidden_states
99
+ hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
100
+
101
+ outputs = (hidden_states, position_bias)
102
+
103
+ if output_attentions:
104
+ outputs += (attn_weights,)
105
+
106
+ return outputs
107
+
108
+
109
+ class WavLMWrapper(nn.Module):
110
+ def __init__(
111
+ self,
112
+ pretrain_model="wavlm_large",
113
+ hidden_dim=256,
114
+ finetune_method="lora",
115
+ lora_rank=16,
116
+ freeze_params=True,
117
+ output_class_num=4,
118
+ use_conv_output=True,
119
+ detailed_class_num=17
120
+ ):
121
+ super(WavLMWrapper, self).__init__()
122
+ # 1. We Load the model first with weights
123
+ self.pretrain_model = pretrain_model
124
+ self.finetune_method = finetune_method
125
+ self.freeze_params = freeze_params
126
+ self.use_conv_output = use_conv_output
127
+ self.lora_rank = lora_rank
128
+ if self.pretrain_model == "wavlm":
129
+ self.backbone_model = WavLMModel.from_pretrained(
130
+ "microsoft/wavlm-base-plus",
131
+ output_hidden_states=True,
132
+ )
133
+ elif self.pretrain_model == "wavlm_large":
134
+ self.processor = Wav2Vec2FeatureExtractor.from_pretrained('microsoft/wavlm-large')
135
+ self.backbone_model = WavLMModel.from_pretrained(
136
+ "microsoft/wavlm-large",
137
+ output_hidden_states=True,
138
+ )
139
+ state_dict = self.backbone_model.state_dict()
140
+ # 2. Read the model config
141
+ self.model_config = self.backbone_model.config
142
+ self.model_config.finetune_method = self.finetune_method
143
+ self.model_config.lora_rank = self.lora_rank
144
+
145
+ # 3. Config encoder layers with adapter or embedding prompt
146
+ if self.pretrain_model == "wavlm":
147
+ self.backbone_model.encoder.layers = nn.ModuleList(
148
+ [WavLMEncoderLayer(i, self.model_config, has_relative_position_bias=(i == 0)) for i in range(self.model_config.num_hidden_layers)]
149
+ )
150
+ elif self.pretrain_model == "wavlm_large":
151
+ self.backbone_model.encoder.layers = nn.ModuleList(
152
+ [WavLMEncoderLayerStableLayerNorm(i, self.model_config, has_relative_position_bias=(i == 0)) for i in range(self.model_config.num_hidden_layers)]
153
+ )
154
+ # 4. Load the weights back
155
+ msg = self.backbone_model.load_state_dict(state_dict, strict=False)
156
+
157
+ # 5. Freeze the weights
158
+ self.freeze_params = freeze_params
159
+ if self.freeze_params and self.finetune_method != "lora":
160
+ for _, p in self.backbone_model.named_parameters(): p.requires_grad = False
161
+ elif self.freeze_params and self.finetune_method == "lora":
162
+ for name, p in self.backbone_model.named_parameters():
163
+ if name in msg.missing_keys: p.requires_grad = True
164
+ else: p.requires_grad = False
165
+ else:
166
+ for _, p in self.backbone_model.named_parameters(): p.requires_grad = True
167
+
168
+ # 6. Downstream models
169
+ self.model_seq = nn.Sequential(
170
+ nn.Conv1d(self.model_config.hidden_size, hidden_dim, 1, padding=0),
171
+ nn.ReLU(),
172
+ nn.Dropout(p=0.1),
173
+ nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0),
174
+ nn.ReLU(),
175
+ nn.Dropout(p=0.1),
176
+ nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0)
177
+ )
178
+
179
+ if self.use_conv_output:
180
+ num_layers = self.model_config.num_hidden_layers + 1 # transformer layers + input embeddings
181
+ self.weights = nn.Parameter(torch.ones(num_layers)/num_layers)
182
+ else:
183
+ num_layers = self.model_config.num_hidden_layers
184
+ self.weights = nn.Parameter(torch.zeros(num_layers))
185
+
186
+ self.emotion_layer = nn.Sequential(
187
+ nn.Linear(hidden_dim, hidden_dim),
188
+ nn.ReLU(),
189
+ nn.Linear(hidden_dim, output_class_num),
190
+ )
191
+
192
+ self.detailed_out_layer = nn.Sequential(
193
+ nn.Linear(hidden_dim, hidden_dim),
194
+ nn.ReLU(),
195
+ nn.Linear(hidden_dim, detailed_class_num),
196
+ )
197
+
198
+ self.arousal_layer = nn.Sequential(
199
+ nn.Linear(hidden_dim, hidden_dim),
200
+ nn.ReLU(),
201
+ nn.Linear(hidden_dim, 1),
202
+ nn.Sigmoid()
203
+ )
204
+
205
+ self.valence_layer = nn.Sequential(
206
+ nn.Linear(hidden_dim, hidden_dim),
207
+ nn.ReLU(),
208
+ nn.Linear(hidden_dim, 1),
209
+ nn.Sigmoid()
210
+ )
211
+
212
+ self.dominance_layer = nn.Sequential(
213
+ nn.Linear(hidden_dim, hidden_dim),
214
+ nn.ReLU(),
215
+ nn.Linear(hidden_dim, 1),
216
+ nn.Sigmoid()
217
+ )
218
+
219
+ def forward(self, x, length=None, return_feature=False):
220
+ # 1. feature extraction and projections
221
+ if self.pretrain_model == "wavlm_large":
222
+ with torch.no_grad():
223
+ signal, attention_mask = list(), list()
224
+ if length is not None: attention_mask = make_padding_masks(x, wav_len=length/length.max()).to(x.device)
225
+ else: attention_mask = make_padding_masks(x, wav_len=torch.tensor([1]).to(x.device)).to(x.device)
226
+
227
+ for idx in range(len(x)):
228
+ input = self.processor(x[idx], sampling_rate=16_000, return_tensors="pt", padding=True)
229
+ signal.append(input["input_values"][0].to(x.device))
230
+ signal = torch.stack(signal)
231
+
232
+ # 2. get length and mask
233
+ if length is not None:
234
+ length = self.get_feat_extract_output_lengths(length.detach().cpu())
235
+ length = length.cuda()
236
+
237
+ if self.pretrain_model == "wavlm":
238
+ x = self.backbone_model(
239
+ x, output_hidden_states=True
240
+ ).hidden_states
241
+ else:
242
+ x = self.backbone_model(
243
+ signal,
244
+ attention_mask=attention_mask,
245
+ output_hidden_states=True
246
+ ).hidden_states
247
+
248
+ # 4. stacked feature
249
+ if self.use_conv_output: stacked_feature = torch.stack(x, dim=0)
250
+ else: stacked_feature = torch.stack(x, dim=0)[1:]
251
+
252
+ # 5. Weighted sum
253
+ _, *origin_shape = stacked_feature.shape
254
+ # Return transformer enc outputs [num_enc_layers, B, T, D]
255
+ if self.use_conv_output:
256
+ stacked_feature = stacked_feature.view(self.backbone_model.config.num_hidden_layers+1, -1)
257
+ else:
258
+ stacked_feature = stacked_feature.view(self.backbone_model.config.num_hidden_layers, -1)
259
+ norm_weights = F.softmax(self.weights, dim=-1)
260
+
261
+ # Perform weighted average
262
+ weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0)
263
+ features = weighted_feature.view(*origin_shape)
264
+
265
+ # 6. Pass the weighted average to point-wise 1D Conv
266
+ # B x T x D
267
+ features = features.transpose(1, 2)
268
+ features = self.model_seq(features)
269
+ features = features.transpose(1, 2)
270
+
271
+ # 7. Pooling
272
+ if length is not None:
273
+ mean, std = list(), list()
274
+ for snt_id in range(features.shape[0]):
275
+ # Avoiding padded time steps
276
+ actual_size = length[snt_id]
277
+ mean.append(torch.mean(features[snt_id, 0:actual_size, ...], dim=0))
278
+ features = torch.stack(mean)
279
+ else:
280
+ features = torch.mean(features, dim=1)
281
+
282
+ # Output predictions
283
+ # B x D
284
+ predicted = self.emotion_layer(features)
285
+ detailed_predicted = self.detailed_out_layer(features)
286
+ arousal = self.arousal_layer(features)
287
+ valence = self.valence_layer(features)
288
+ dominance = self.dominance_layer(features)
289
+ if return_feature: return predicted, features, detailed_predicted, arousal, valence, dominance
290
+ return predicted, detailed_predicted, arousal, valence, dominance
291
+
292
+ # From huggingface
293
+ def get_feat_extract_output_lengths(self, input_length):
294
+ """
295
+ Computes the output length of the convolutional layers
296
+ """
297
+ def _conv_out_length(input_length, kernel_size, stride):
298
+ # 1D convolutional layer output length formula taken
299
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
300
+ return (input_length - kernel_size) // stride + 1
301
+ for kernel_size, stride in zip(self.backbone_model.config.conv_kernel, self.backbone_model.config.conv_stride):
302
+ input_length = _conv_out_length(input_length, kernel_size, stride)
303
+ return input_length
304
+
305
+ def prepare_mask(length, shape, dtype):
306
+ # Modified from huggingface
307
+ mask = torch.zeros(
308
+ shape, dtype=dtype
309
+ )
310
+ # these two operations makes sure that all values
311
+ # before the output lengths indices are attended to
312
+ mask[(torch.arange(mask.shape[0]), length.cpu() - 1)] = 1
313
+ mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
314
+ return mask
315
+
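A hedged usage sketch for the emotion `WavLMWrapper` above: one waveform in, categorical emotion logits plus fine-grained logits and arousal/valence/dominance scores out. The import path and class counts are assumptions, and the heads are untrained until a checkpoint is loaded:

```python
import torch
from capspeech.eval.src.model.emotion.wavlm_emotion import WavLMWrapper  # assumed import path

model = WavLMWrapper(pretrain_model="wavlm_large",
                     output_class_num=4, detailed_class_num=17).eval()

wav = torch.randn(1, 3 * 16000)          # 3 seconds at 16 kHz
with torch.no_grad():
    emo, emo_detailed, arousal, valence, dominance = model(wav)
print(emo.shape, emo_detailed.shape, arousal.shape)   # (1, 4) (1, 17) (1, 1)
```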
capspeech/eval/src/model/emotion/wavlm_emotion_dim.py ADDED
@@ -0,0 +1,318 @@
1
+ import os
2
+ import pdb
3
+ import torch
4
+ import argparse
5
+ import numpy as np
6
+ import loralib as lora
7
+ import transformers.models.wav2vec2.modeling_wav2vec2 as w2v2
8
+ import transformers.models.wavlm.modeling_wavlm as wavlm
9
+ from speechbrain.lobes.models.huggingface_transformers.huggingface import make_padding_masks
10
+
11
+ from torch import nn
12
+ from torch.nn import functional as F
13
+ from transformers import Wav2Vec2FeatureExtractor
14
+ from transformers import WavLMModel
15
+
16
+ import sys
17
+ from pathlib import Path
18
+ sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1])))
19
+
20
+ class WavLMEncoderLayer(nn.Module):
21
+ def __init__(self, layer_idx, config, has_relative_position_bias: bool = True):
22
+ super().__init__()
23
+ self.attention = wavlm.WavLMAttention(
24
+ embed_dim=config.hidden_size,
25
+ num_heads=config.num_attention_heads,
26
+ dropout=config.attention_dropout,
27
+ num_buckets=config.num_buckets,
28
+ max_distance=config.max_bucket_distance,
29
+ has_relative_position_bias=has_relative_position_bias,
30
+ )
31
+ self.dropout = nn.Dropout(config.hidden_dropout)
32
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
33
+ self.feed_forward = wavlm.WavLMFeedForward(config)
34
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
35
+ self.config = config
36
+
37
+ if layer_idx > config.num_hidden_layers // 2:
38
+ if self.config.finetune_method == "lora" or self.config.finetune_method == "combined":
39
+ self.feed_forward.intermediate_dense = lora.Linear(config.hidden_size, config.intermediate_size, r=config.lora_rank)
40
+ self.feed_forward.output_dense = lora.Linear(config.intermediate_size, config.hidden_size, r=config.lora_rank)
41
+
42
+ def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
43
+ attn_residual = hidden_states
44
+ hidden_states, attn_weights, position_bias = self.attention(
45
+ hidden_states,
46
+ attention_mask=attention_mask,
47
+ position_bias=position_bias,
48
+ output_attentions=output_attentions,
49
+ index=index,
50
+ )
51
+ hidden_states = self.dropout(hidden_states)
52
+ hidden_states = attn_residual + hidden_states
53
+
54
+ # Adapter
55
+ if self.config.finetune_method == "adapter":
56
+ adapt_h = self.adapter(hidden_states)
57
+
58
+ hidden_states = self.layer_norm(hidden_states)
59
+ hidden_states = hidden_states + self.feed_forward(hidden_states)
60
+ hidden_states = self.final_layer_norm(hidden_states)
61
+ outputs = (hidden_states, position_bias)
62
+
63
+ if output_attentions:
64
+ outputs += (attn_weights,)
65
+
66
+ return outputs
67
+
68
+
69
+ class WavLMEncoderLayerStableLayerNorm(nn.Module):
70
+ def __init__(self, layer_idx, config, has_relative_position_bias: bool = True):
71
+ super().__init__()
72
+ self.attention = wavlm.WavLMAttention(
73
+ embed_dim=config.hidden_size,
74
+ num_heads=config.num_attention_heads,
75
+ dropout=config.attention_dropout,
76
+ num_buckets=config.num_buckets,
77
+ max_distance=config.max_bucket_distance,
78
+ has_relative_position_bias=has_relative_position_bias,
79
+ )
80
+ self.dropout = nn.Dropout(config.hidden_dropout)
81
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
82
+ self.feed_forward = wavlm.WavLMFeedForward(config)
83
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
84
+ self.config = config
85
+
86
+ if layer_idx > config.num_hidden_layers // 2:
87
+ if self.config.finetune_method == "lora" or self.config.finetune_method == "combined":
88
+ self.feed_forward.intermediate_dense = lora.Linear(config.hidden_size, config.intermediate_size, r=config.lora_rank)
89
+ self.feed_forward.output_dense = lora.Linear(config.intermediate_size, config.hidden_size, r=config.lora_rank)
90
+
91
+
92
+ def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False):
93
+ attn_residual = hidden_states
94
+ hidden_states = self.layer_norm(hidden_states)
95
+ hidden_states, attn_weights, position_bias = self.attention(
96
+ hidden_states,
97
+ attention_mask=attention_mask,
98
+ position_bias=position_bias,
99
+ output_attentions=output_attentions,
100
+ )
101
+ hidden_states = self.dropout(hidden_states)
102
+ hidden_states = attn_residual + hidden_states
103
+ hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
104
+
105
+ outputs = (hidden_states, position_bias)
106
+
107
+ if output_attentions:
108
+ outputs += (attn_weights,)
109
+
110
+ return outputs
111
+
112
+
113
+ class WavLMWrapper(nn.Module):
114
+ def __init__(
115
+ self,
116
+ pretrain_model="wavlm_large",
117
+ hidden_dim=256,
118
+ finetune_method="lora",
119
+ lora_rank=16,
120
+ freeze_params=True,
121
+ output_class_num=4,
122
+ use_conv_output=True,
123
+ detailed_class_num=17,
124
+ predict_gender=False
125
+ ):
126
+ super(WavLMWrapper, self).__init__()
127
+ # 1. We Load the model first with weights
128
+ self.pretrain_model = pretrain_model
129
+ self.finetune_method = finetune_method
130
+ self.freeze_params = freeze_params
131
+ self.use_conv_output = use_conv_output
132
+ self.lora_rank = lora_rank
133
+ self.predict_gender = predict_gender
134
+ if self.pretrain_model == "wavlm":
135
+ self.backbone_model = WavLMModel.from_pretrained(
136
+ "microsoft/wavlm-base-plus",
137
+ output_hidden_states=True,
138
+ )
139
+ elif self.pretrain_model == "wavlm_large":
140
+ self.processor = Wav2Vec2FeatureExtractor.from_pretrained('microsoft/wavlm-large')
141
+ self.backbone_model = WavLMModel.from_pretrained(
142
+ "microsoft/wavlm-large",
143
+ output_hidden_states=True,
144
+ )
145
+ state_dict = self.backbone_model.state_dict()
146
+ # 2. Read the model config
147
+ self.model_config = self.backbone_model.config
148
+ self.model_config.finetune_method = self.finetune_method
149
+ self.model_config.lora_rank = self.lora_rank
150
+
151
+ # 3. Config encoder layers with adapter or embedding prompt
152
+ if self.pretrain_model == "wavlm":
153
+ self.backbone_model.encoder.layers = nn.ModuleList(
154
+ [WavLMEncoderLayer(i, self.model_config, has_relative_position_bias=(i == 0)) for i in range(self.model_config.num_hidden_layers)]
155
+ )
156
+ elif self.pretrain_model == "wavlm_large":
157
+ self.backbone_model.encoder.layers = nn.ModuleList(
158
+ [WavLMEncoderLayerStableLayerNorm(i, self.model_config, has_relative_position_bias=(i == 0)) for i in range(self.model_config.num_hidden_layers)]
159
+ )
160
+ # 4. Load the weights back
161
+ msg = self.backbone_model.load_state_dict(state_dict, strict=False)
162
+
163
+ # 5. Freeze the weights
164
+ self.freeze_params = freeze_params
165
+ if self.freeze_params and self.finetune_method != "lora":
166
+ for _, p in self.backbone_model.named_parameters(): p.requires_grad = False
167
+ elif self.freeze_params and self.finetune_method == "lora":
168
+ for name, p in self.backbone_model.named_parameters():
169
+ if name in msg.missing_keys: p.requires_grad = True
170
+ else: p.requires_grad = False
171
+ else:
172
+ for _, p in self.backbone_model.named_parameters(): p.requires_grad = True
173
+
174
+ # 6. Downstream models
175
+ self.model_seq = nn.Sequential(
176
+ nn.Conv1d(self.model_config.hidden_size, hidden_dim, 1, padding=0),
177
+ nn.ReLU(),
178
+ nn.Dropout(p=0.1),
179
+ nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0),
180
+ nn.ReLU(),
181
+ nn.Dropout(p=0.1),
182
+ nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0)
183
+ )
184
+
185
+ if self.use_conv_output:
186
+ num_layers = self.model_config.num_hidden_layers + 1 # transformer layers + input embeddings
187
+ self.weights = nn.Parameter(torch.ones(num_layers)/num_layers)
188
+ else:
189
+ num_layers = self.model_config.num_hidden_layers
190
+ self.weights = nn.Parameter(torch.zeros(num_layers))
191
+
192
+ self.arousal_layer = nn.Sequential(
193
+ nn.Linear(hidden_dim, hidden_dim),
194
+ nn.ReLU(),
195
+ nn.Linear(hidden_dim, 1),
196
+ nn.Sigmoid()
197
+ )
198
+
199
+ self.valence_layer = nn.Sequential(
200
+ nn.Linear(hidden_dim, hidden_dim),
201
+ nn.ReLU(),
202
+ nn.Linear(hidden_dim, 1),
203
+ nn.Sigmoid()
204
+ )
205
+
206
+ self.dominance_layer = nn.Sequential(
207
+ nn.Linear(hidden_dim, hidden_dim),
208
+ nn.ReLU(),
209
+ nn.Linear(hidden_dim, 1),
210
+ nn.Sigmoid()
211
+ )
212
+
213
+ if self.predict_gender:
214
+ self.gender_layer = nn.Sequential(
215
+ nn.Linear(hidden_dim, hidden_dim),
216
+ nn.ReLU(),
217
+ nn.Linear(hidden_dim, 2)
218
+ )
219
+
220
+ def forward(self, x, length=None, return_feature=False):
221
+ # 1. feature extraction and projections
222
+ if self.pretrain_model == "wavlm_large":
223
+ with torch.no_grad():
224
+ signal, attention_mask = list(), list()
225
+ if length is not None: attention_mask = make_padding_masks(x, wav_len=length/length.max()).to(x.device)
226
+ else: attention_mask = make_padding_masks(x, wav_len=torch.tensor([1]).to(x.device)).to(x.device)
227
+
228
+ for idx in range(len(x)):
229
+ input = self.processor(x[idx], sampling_rate=16_000, return_tensors="pt", padding=True)
230
+ signal.append(input["input_values"][0].to(x.device))
231
+ signal = torch.stack(signal)
232
+
233
+ # 2. get length and mask
234
+ if length is not None:
235
+ length = self.get_feat_extract_output_lengths(length.detach().cpu())
236
+ length = length.cuda()
237
+
238
+ if self.pretrain_model == "wavlm":
239
+ x = self.backbone_model(
240
+ x, output_hidden_states=True
241
+ ).hidden_states
242
+ else:
243
+ x = self.backbone_model(
244
+ signal,
245
+ attention_mask=attention_mask,
246
+ output_hidden_states=True
247
+ ).hidden_states
248
+
249
+ # 4. stacked feature
250
+ if self.use_conv_output: stacked_feature = torch.stack(x, dim=0)
251
+ else: stacked_feature = torch.stack(x, dim=0)[1:]
252
+
253
+ # 5. Weighted sum
254
+ _, *origin_shape = stacked_feature.shape
255
+ # Return transformer enc outputs [num_enc_layers, B, T, D]
256
+ if self.use_conv_output:
257
+ stacked_feature = stacked_feature.view(self.backbone_model.config.num_hidden_layers+1, -1)
258
+ else:
259
+ stacked_feature = stacked_feature.view(self.backbone_model.config.num_hidden_layers, -1)
260
+ norm_weights = F.softmax(self.weights, dim=-1)
261
+
262
+ # Perform weighted average
263
+ weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0)
264
+ features = weighted_feature.view(*origin_shape)
265
+
266
+ # 6. Pass the weighted average to point-wise 1D Conv
267
+ # B x T x D
268
+ features = features.transpose(1, 2)
269
+ features = self.model_seq(features)
270
+ features = features.transpose(1, 2)
271
+
272
+ # 7. Pooling
273
+ if length is not None:
274
+ mean, std = list(), list()
275
+ for snt_id in range(features.shape[0]):
276
+ # Avoiding padded time steps
277
+ actual_size = length[snt_id]
278
+ mean.append(torch.mean(features[snt_id, 0:actual_size, ...], dim=0))
279
+ features = torch.stack(mean)
280
+ else:
281
+ features = torch.mean(features, dim=1)
282
+
283
+ # 8. Output predictions
284
+ # B x D
285
+ arousal = self.arousal_layer(features)
286
+ valence = self.valence_layer(features)
287
+ dominance = self.dominance_layer(features)
288
+
289
+ if self.predict_gender:
290
+ gender_outputs = self.gender_layer(features)
291
+ return arousal, valence, dominance, gender_outputs
292
+
293
+ return arousal, valence, dominance
294
+
295
+ # From huggingface
296
+ def get_feat_extract_output_lengths(self, input_length):
297
+ """
298
+ Computes the output length of the convolutional layers
299
+ """
300
+ def _conv_out_length(input_length, kernel_size, stride):
301
+ # 1D convolutional layer output length formula taken
302
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
303
+ return (input_length - kernel_size) // stride + 1
304
+ for kernel_size, stride in zip(self.backbone_model.config.conv_kernel, self.backbone_model.config.conv_stride):
305
+ input_length = _conv_out_length(input_length, kernel_size, stride)
306
+ return input_length
307
+
308
+ def prepare_mask(length, shape, dtype):
309
+ # Modified from huggingface
310
+ mask = torch.zeros(
311
+ shape, dtype=dtype
312
+ )
313
+ # these two operations makes sure that all values
314
+ # before the output lengths indices are attended to
315
+ mask[(torch.arange(mask.shape[0]), length.cpu() - 1)] = 1
316
+ mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
317
+ return mask
318
+
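Finally, a hedged usage sketch for the dimensional-emotion `WavLMWrapper` above. Each head ends in a Sigmoid, so arousal/valence/dominance come back as values in [0, 1]; the import path is an assumption and the heads are untrained until a checkpoint is loaded:

```python
import torch
from capspeech.eval.src.model.emotion.wavlm_emotion_dim import WavLMWrapper  # assumed import path

model = WavLMWrapper(pretrain_model="wavlm_large", predict_gender=False).eval()

wav = torch.randn(1, 3 * 16000)              # 3 seconds at 16 kHz
with torch.no_grad():
    arousal, valence, dominance = model(wav)  # each of shape (1, 1), in [0, 1]
print(float(arousal), float(valence), float(dominance))
```

With `predict_gender=True` the forward pass returns a fourth output, a `(B, 2)` tensor of gender logits.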