Spaces: Running on Zero

Upload 518 files

This view is limited to 50 files because it contains too many changes. See raw diff.
- capspeech/__init__.py +0 -0
- capspeech/ar/README.md +44 -0
- capspeech/ar/__init__.py +0 -0
- capspeech/ar/events.txt +395 -0
- capspeech/ar/finetune_acccaptts.sh +64 -0
- capspeech/ar/finetune_agenttts.sh +61 -0
- capspeech/ar/finetune_captts.sh +64 -0
- capspeech/ar/finetune_capttsse.sh +62 -0
- capspeech/ar/finetune_emocaptts.sh +64 -0
- capspeech/ar/parler_tts/__init__.py +25 -0
- capspeech/ar/parler_tts/configuration_parler_tts.py +291 -0
- capspeech/ar/parler_tts/dac_wrapper/__init__.py +2 -0
- capspeech/ar/parler_tts/dac_wrapper/configuration_dac.py +27 -0
- capspeech/ar/parler_tts/dac_wrapper/modeling_dac.py +164 -0
- capspeech/ar/parler_tts/logits_processors.py +54 -0
- capspeech/ar/parler_tts/modeling_parler_tts.py +0 -0
- capspeech/ar/parler_tts/streamer.py +147 -0
- capspeech/ar/pretrain.sh +68 -0
- capspeech/ar/training/__init__.py +0 -0
- capspeech/ar/training/arguments.py +403 -0
- capspeech/ar/training/arguments_captts.py +391 -0
- capspeech/ar/training/arguments_capttsse.py +387 -0
- capspeech/ar/training/data.py +277 -0
- capspeech/ar/training/data_captts.py +255 -0
- capspeech/ar/training/data_capttsse.py +253 -0
- capspeech/ar/training/finetune_captts.py +1270 -0
- capspeech/ar/training/finetune_capttsse.py +1267 -0
- capspeech/ar/training/run_parler_tts_training.py +1279 -0
- capspeech/ar/training/utils.py +203 -0
- capspeech/eval/README.md +42 -0
- capspeech/eval/__init__.py +0 -0
- capspeech/eval/age_gender.py +35 -0
- capspeech/eval/asr_eval.py +24 -0
- capspeech/eval/base_eval.py +32 -0
- capspeech/eval/bin.json +10 -0
- capspeech/eval/pitch.py +30 -0
- capspeech/eval/requirements.txt +16 -0
- capspeech/eval/speed.py +29 -0
- capspeech/eval/src/__init__.py +0 -0
- capspeech/eval/src/example/__init__.py +0 -0
- capspeech/eval/src/example/categorized_emotion.py +92 -0
- capspeech/eval/src/example/dialect_world_dialect.py +87 -0
- capspeech/eval/src/model/__init__.py +0 -0
- capspeech/eval/src/model/adapter.py +73 -0
- capspeech/eval/src/model/dialect/__init__.py +0 -0
- capspeech/eval/src/model/dialect/wavlm_dialect.py +300 -0
- capspeech/eval/src/model/dialect/whisper_dialect.py +301 -0
- capspeech/eval/src/model/emotion/__init__.py +0 -0
- capspeech/eval/src/model/emotion/wavlm_emotion.py +315 -0
- capspeech/eval/src/model/emotion/wavlm_emotion_dim.py +318 -0
capspeech/__init__.py
ADDED
File without changes
capspeech/ar/README.md
ADDED
@@ -0,0 +1,44 @@
# CapSpeech-AR

## Pretrain

```bash
bash pretrain.sh
```
Make sure to change paths and keys in `pretrain.sh` to yours.

## Finetune on CapTTS

```bash
bash finetune_captts.sh
```
Make sure to change paths and keys in `finetune_captts.sh` to yours.

## Finetune on EmoCapTTS

```bash
bash finetune_emocaptts.sh
```
Make sure to change paths and keys in `finetune_emocaptts.sh` to yours.

## Finetune on AccCapTTS

```bash
bash finetune_acccaptts.sh
```
Make sure to change paths and keys in `finetune_acccaptts.sh` to yours.

## Finetune on CapTTS-SE

```bash
bash finetune_capttsse.sh
```
Make sure to change paths and keys in `finetune_capttsse.sh` to yours.

## Finetune on AgentTTS

```bash
bash finetune_agenttts.sh
```
Make sure to change paths and keys in `finetune_agenttts.sh` to yours.
capspeech/ar/__init__.py
ADDED
File without changes
capspeech/ar/events.txt
ADDED
@@ -0,0 +1,395 @@
people whispering
Microwave oven
extending ladders
mosquito buzzing
dog whimpering
coyote howling
hair dryer drying
Writing
rapping
machine gun shooting
dog bow-wow
dog howling
barn swallow calling
baby babbling
Fireworks
church bell ringing
car horn
cat caterwauling
subway, metro, underground
waterfall burbling
lions roaring
toilet flushing
skateboarding
wind
ripping paper
vacuum cleaner cleaning floors
mouse squeaking
keyboard typing
playing timpani
playing harp
sheep bleating
eletric blender running
people slapping
playing ukulele
frog
car engine knocking
cat purring
chainsaw
Violin or fiddle
people hiccup
playing acoustic guitar
donkey, ass braying
playing french horn
playing squash
gibbon howling
playing harmonica
playing shofar
hedge trimmer running
playing washboard
running electric fan
splashing water
playing bassoon
people slurping
playing accordion
playing oboe
popping popcorn
glass breaking
alarm clock ringing
mouse click
Laughter
magpie calling
playing snare drum
people finger snapping
ferret dooking
tornado roaring
Hi-hat
lawn mowing
church bells
cat growling
cheetah chirrup
heart sounds, heartbeat
firing muskets
vehicle horn, car horn, honking
turkey gobbling
ice cream truck, ice cream van
underwater bubbling
footsteps on snow
water drops
people sobbing
basketball bounce
Applause
playing sitar
playing gong
train
coughing
people screaming
Gunshot or gunfire
chinchilla barking
cat hissing
horse clip-clop
engine
people battle cry
typing on computer keyboard
playing clarinet
driving motorcycle
male singing
singing bowl
skiing
driving buses
alligators, crocodiles hissing
people eating apple
door slamming
Flute
raining
Electric piano
sliding door
washing machine
opening or closing car electric windows
baby crying
people babbling
snake hissing
brushing teeth
playing tambourine
Acoustic guitar
clock tick
playing castanets
thunder
playing didgeridoo
playing synthesizer
mouse clicking
lathe spinning
spraying water
hen
stream burbling
door wood creaks
sailing
dog
car engine idling
bowling impact
driving snowmobile
toilet flush
bird squawking
playing timbales
playing drum kit
owl hooting
striking pool
Oboe
duck quacking
people belly laughing
lighting firecrackers
roller coaster running
blowtorch igniting
wood thrush calling
Glockenspiel
frog croaking
playing harpsichord
train horning
plastic bottle crushing
playing tabla
fire crackling
dog barking
thunderstorm
playing banjo
swimming
volcano explosion
playing table tennis
sea lion barking
rowboat, canoe, kayak rowing
Meow
pouring water
playing tympani
rooster
siren
parrot talking
Finger snapping
playing steel guitar, slide guitar
Trumpet
tractor digging
people coughing
cat meowing
Snare drum
playing erhu
crow cawing
playing djembe
whale calling
mynah bird singing
playing tennis
chopping food
golf driving
tapping guitar
playing cello
dog growling
elephant trumpeting
sea waves
police radio chatter
lions growling
playing lacrosse
children shouting
missile launch
baby laughter
air conditioning noise
playing saxophone
typing on typewriter
printer printing
race car, auto racing
Bus
pigeon, dove cooing
playing violin, fiddle
Double bass
striking bowling
fireworks banging
Harmonica
playing glockenspiel
reversing beeps
playing piano
breathing
people marching
electric shaver, electric razor shaving
chimpanzee pant-hooting
cricket chirping
bird chirping, tweeting
using sewing machines
crickets
cow lowing
playing cymbal
vacuum cleaner
playing zither
train whistling
goat bleating
eating with cutlery
black capped chickadee calling
ambulance siren
playing hockey
dog baying
Burping or eructation
cupboard opening or closing
air horn
crying baby
people eating crisps
sloshing water
goose honking
orchestra
people giggling
warbler chirping
child singing
dinosaurs bellowing
motorboat, speedboat acceleration
airplane
chicken clucking
woodpecker pecking tree
Drawer open or close
people eating
drinking sipping
singing choir
playing bass guitar
playing bass drum
car passing by
playing tuning fork
Squeak
pig oinking
Computer keyboard
yodelling
playing trombone
clapping
people sneezing
pheasant crowing
writing on blackboard with chalk
Tambourine
opening or closing car doors
sharpen knife
people whistling
fireworks
playing bagpipes
chainsawing trees
squishing water
people farting
playing electric guitar
people booing
female singing
ocean burbling
cattle mooing
footsteps
Knock
wind rustling leaves
cattle, bovinae cowbell
Clarinet
police car (siren)
Fart
cat
sheep
chopping wood
tap dancing
playing mandolin
wind chime
can opening
playing hammond organ
zebra braying
scuba diving
chirping birds
playing steelpan
playing theremin
Keys jangling
beat boxing
firing cannon
bouncing on trampoline
door wood knock
bathroom ventilation fan running
snake rattling
bull bellowing
electric grinder grinding
penguins braying
otter growling
civil defense siren
wind noise
people humming
clock alarm
disc scratching
fire truck siren
telephone bell ringing
people sniggering
playing bongo
cap gun shooting
opening or closing drawers
cow
hammering nails
ice cracking
foghorn
rain
playing badminton
eagle screaming
playing double bass
insects
people running
planing timber
cutting hair with electric trimmers
Cello
people clapping
smoke detector beeping
mouse pattering
bee, wasp, etc. buzzing
canary calling
people burping
Shatter
baltimore oriole calling
cuckoo bird calling
snoring
strike lighter
people cheering
playing bugle
playing congas
playing vibraphone
hail
rope skipping
playing trumpet
pig
hand saw
people gargling
Scissors
metronome
chipmunk chirping
playing flute
fox barking
crackling fire
playing volleyball
skidding
Bass drum
crow
elk bugling
Telephone
Bark
chicken crowing
people nose blowing
car engine starting
pumping water
Saxophone
fly, housefly buzzing
Cough
people eating noodle
francolin calling
arc welding
horse neighing
Tearing
helicopter
playing electronic organ
Cowbell
railroad car, train wagon
cell phone buzzing
playing cornet
sneezing
engine accelerating, revving, vroom
bird wings flapping
playing marimba, xylophone
playing guiro
people crowd
train wheels squealing
slot machine
laughing
lip smacking
forging swords
Chime
playing darts
people shuffling
Gong
airplane flyby
None
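A small, hedged sketch of consuming this file: `events.txt` appears to enumerate the sound-event vocabulary used by the CapTTS-SE branch of the code, with the final `None` entry marking the no-event case; the path below assumes the repository layout shown in this diff.

```python
# Hedged sketch: read the event labels into memory; "None" is treated as the
# "no sound event" sentinel rather than a real label.
from pathlib import Path

events = Path("capspeech/ar/events.txt").read_text(encoding="utf-8").splitlines()
labels = [e for e in events if e != "None"]
print(len(events), len(labels))  # 395 entries, 394 real sound-event labels
```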
capspeech/ar/finetune_acccaptts.sh
ADDED
@@ -0,0 +1,64 @@
# Please log in to huggingface first

LIBRITTSR_WAV_DIR='' # downloaded libritts-r wav dir
OTHER_WAV_DIR='' # downloaded other wav dirs
OUTPUT_DIR="./output_finetuning_acccaptts/" # output dir, to save checkpoints
TEMPORY_SAVE_TO_DISK="./audio_code_finetuning_acccaptts/" # dac codec saved dir
SAVE_TO_DISK="./dataset_finetuning_acccaptts/" # huggingface metadata saved dir
WANDB_KEY='' # your wandb key for logging

PRETRAINED_MODEL_PATH="" # your pretrained model path

export CUDA_LAUNCH_BLOCKING=1
export TORCH_USE_CUDA_DSA=1

accelerate launch ./training/finetune_captts.py \
    --model_name_or_path ${PRETRAINED_MODEL_PATH} \
    --feature_extractor_name "parler-tts/dac_44khZ_8kbps" \
    --description_tokenizer_name ${PRETRAINED_MODEL_PATH} \
    --prompt_tokenizer_name ${PRETRAINED_MODEL_PATH} \
    --report_to "wandb" \
    --wandb_key ${WANDB_KEY} \
    --overwrite_output_dir true \
    --train_dataset_name "OpenSound/CapSpeech" \
    --train_split_name "train_SFT_AccCapTTS" \
    --eval_dataset_name "OpenSound/CapSpeech" \
    --eval_split_name "validation_SFT_AccCapTTS" \
    --librittsr_dir ${LIBRITTSR_WAV_DIR} \
    --other_dir ${OTHER_WAV_DIR} \
    --max_eval_samples 96 \
    --per_device_eval_batch_size 32 \
    --target_audio_column_name "audio_path" \
    --description_column_name "caption" \
    --source_column_name "source" \
    --prompt_column_name "text" \
    --max_duration_in_seconds 20 \
    --min_duration_in_seconds 3 \
    --max_text_length 600 \
    --preprocessing_num_workers 32 \
    --do_train true \
    --num_train_epochs 5 \
    --gradient_accumulation_steps 6 \
    --gradient_checkpointing false \
    --per_device_train_batch_size 4 \
    --learning_rate 0.0001 \
    --adam_beta1 0.9 \
    --adam_beta2 0.99 \
    --weight_decay 0.01 \
    --lr_scheduler_type "constant_with_warmup" \
    --warmup_steps 1000 \
    --logging_steps 200 \
    --freeze_text_encoder true \
    --per_device_eval_batch_size 4 \
    --audio_encoder_per_device_batch_size 24 \
    --dtype "float16" \
    --seed 456 \
    --output_dir ${OUTPUT_DIR} \
    --temporary_save_to_disk ${TEMPORY_SAVE_TO_DISK} \
    --save_to_disk ${SAVE_TO_DISK} \
    --dataloader_num_workers 32 \
    --do_eval \
    --evaluation_strategy steps \
    --eval_steps 500 \
    --save_steps 500 \
    --group_by_length true
capspeech/ar/finetune_agenttts.sh
ADDED
@@ -0,0 +1,61 @@
# Please log in to huggingface first

OTHER_WAV_DIR='' # downloaded capspeech-agentdb wav dir
OUTPUT_DIR="./output_finetuning_agenttts/" # output dir, to save checkpoints
TEMPORY_SAVE_TO_DISK="./audio_code_finetuning_agenttts/" # dac codec saved dir
SAVE_TO_DISK="./dataset_finetuning_agenttts/" # huggingface metadata saved dir
WANDB_KEY='' # your wandb key for logging
PRETRAINED_MODEL_PATH="" # your pretrained model path

export CUDA_LAUNCH_BLOCKING=1
export TORCH_USE_CUDA_DSA=1

accelerate launch ./training/finetune_captts.py \
    --model_name_or_path ${PRETRAINED_MODEL_PATH} \
    --feature_extractor_name "parler-tts/dac_44khZ_8kbps" \
    --description_tokenizer_name ${PRETRAINED_MODEL_PATH} \
    --prompt_tokenizer_name ${PRETRAINED_MODEL_PATH} \
    --report_to "wandb" \
    --wandb_key ${WANDB_KEY} \
    --overwrite_output_dir true \
    --train_dataset_name "OpenSound/CapSpeech" \
    --train_split_name "train_AgentDB" \
    --eval_dataset_name "OpenSound/CapSpeech" \
    --eval_split_name "test_AgentDB" \
    --other_dir ${OTHER_WAV_DIR} \
    --max_eval_samples 96 \
    --per_device_eval_batch_size 32 \
    --target_audio_column_name "audio_path" \
    --description_column_name "caption" \
    --source_column_name "source" \
    --prompt_column_name "text" \
    --max_duration_in_seconds 20 \
    --min_duration_in_seconds 3 \
    --max_text_length 600 \
    --preprocessing_num_workers 32 \
    --do_train true \
    --num_train_epochs 50 \
    --gradient_accumulation_steps 6 \
    --gradient_checkpointing false \
    --per_device_train_batch_size 4 \
    --learning_rate 0.0001 \
    --adam_beta1 0.9 \
    --adam_beta2 0.99 \
    --weight_decay 0.01 \
    --lr_scheduler_type "constant_with_warmup" \
    --warmup_steps 500 \
    --logging_steps 100 \
    --freeze_text_encoder true \
    --per_device_eval_batch_size 4 \
    --audio_encoder_per_device_batch_size 24 \
    --dtype "float16" \
    --seed 456 \
    --output_dir ${OUTPUT_DIR} \
    --temporary_save_to_disk ${TEMPORY_SAVE_TO_DISK} \
    --save_to_disk ${SAVE_TO_DISK} \
    --dataloader_num_workers 32 \
    --do_eval \
    --evaluation_strategy steps \
    --eval_steps 500 \
    --save_steps 500 \
    --group_by_length true
capspeech/ar/finetune_captts.sh
ADDED
@@ -0,0 +1,64 @@
# Please log in to huggingface first

LIBRITTSR_WAV_DIR='' # downloaded libritts-r wav dir
OTHER_WAV_DIR='' # downloaded other wav dirs
OUTPUT_DIR="./output_finetuning_captts/" # output dir, to save checkpoints
TEMPORY_SAVE_TO_DISK="./audio_code_finetuning_captts/" # dac codec saved dir
SAVE_TO_DISK="./dataset_finetuning_captts/" # huggingface metadata saved dir
WANDB_KEY='' # your wandb key for logging

PRETRAINED_MODEL_PATH="" # your pretrained model path

export CUDA_LAUNCH_BLOCKING=1
export TORCH_USE_CUDA_DSA=1

accelerate launch ./training/finetune_captts.py \
    --model_name_or_path ${PRETRAINED_MODEL_PATH} \
    --feature_extractor_name "parler-tts/dac_44khZ_8kbps" \
    --description_tokenizer_name ${PRETRAINED_MODEL_PATH} \
    --prompt_tokenizer_name ${PRETRAINED_MODEL_PATH} \
    --report_to "wandb" \
    --wandb_key ${WANDB_KEY} \
    --overwrite_output_dir true \
    --train_dataset_name "OpenSound/CapSpeech" \
    --train_split_name "train_SFT_CapTTS" \
    --eval_dataset_name "OpenSound/CapSpeech" \
    --eval_split_name "validation_SFT_CapTTS" \
    --librittsr_dir ${LIBRITTSR_WAV_DIR} \
    --other_dir ${OTHER_WAV_DIR} \
    --max_eval_samples 96 \
    --per_device_eval_batch_size 32 \
    --target_audio_column_name "audio_path" \
    --description_column_name "caption" \
    --source_column_name "source" \
    --prompt_column_name "text" \
    --max_duration_in_seconds 20 \
    --min_duration_in_seconds 3 \
    --max_text_length 600 \
    --preprocessing_num_workers 32 \
    --do_train true \
    --num_train_epochs 5 \
    --gradient_accumulation_steps 6 \
    --gradient_checkpointing false \
    --per_device_train_batch_size 4 \
    --learning_rate 0.0001 \
    --adam_beta1 0.9 \
    --adam_beta2 0.99 \
    --weight_decay 0.01 \
    --lr_scheduler_type "constant_with_warmup" \
    --warmup_steps 1000 \
    --logging_steps 200 \
    --freeze_text_encoder true \
    --per_device_eval_batch_size 4 \
    --audio_encoder_per_device_batch_size 24 \
    --dtype "float16" \
    --seed 456 \
    --output_dir ${OUTPUT_DIR} \
    --temporary_save_to_disk ${TEMPORY_SAVE_TO_DISK} \
    --save_to_disk ${SAVE_TO_DISK} \
    --dataloader_num_workers 32 \
    --do_eval \
    --evaluation_strategy steps \
    --eval_steps 2000 \
    --save_steps 2000 \
    --group_by_length true
capspeech/ar/finetune_capttsse.sh
ADDED
@@ -0,0 +1,62 @@
# Please log in to huggingface first

LIBRITTSRMIX_WAV_DIR='' # downloaded capspeech-sedb wav dir
OUTPUT_DIR="./output_finetuning_capttsse/" # output dir, to save checkpoints
TEMPORY_SAVE_TO_DISK="./audio_code_finetuning_capttsse/" # dac codec saved dir
SAVE_TO_DISK="./dataset_finetuning_capttsse/" # huggingface metadata saved dir
WANDB_KEY='' # your wandb key for logging

PRETRAINED_MODEL_PATH="" # your pretrained model path

export CUDA_LAUNCH_BLOCKING=1
export TORCH_USE_CUDA_DSA=1

accelerate launch ./training/finetune_capttsse.py \
    --model_name_or_path ${PRETRAINED_MODEL_PATH} \
    --feature_extractor_name "parler-tts/dac_44khZ_8kbps" \
    --description_tokenizer_name ${PRETRAINED_MODEL_PATH} \
    --prompt_tokenizer_name ${PRETRAINED_MODEL_PATH} \
    --report_to "wandb" \
    --wandb_key ${WANDB_KEY} \
    --overwrite_output_dir true \
    --train_dataset_name "OpenSound/CapSpeech" \
    --train_split_name "train_SEDB" \
    --eval_dataset_name "OpenSound/CapSpeech" \
    --eval_split_name "test_SEDB" \
    --librittsrmix_dir ${LIBRITTSRMIX_WAV_DIR} \
    --max_eval_samples 96 \
    --per_device_eval_batch_size 32 \
    --target_audio_column_name "audio_path" \
    --description_column_name "caption" \
    --source_column_name "source" \
    --prompt_column_name "text" \
    --max_duration_in_seconds 20 \
    --min_duration_in_seconds 3 \
    --max_text_length 600 \
    --preprocessing_num_workers 32 \
    --do_train true \
    --num_train_epochs 50 \
    --gradient_accumulation_steps 6 \
    --gradient_checkpointing false \
    --per_device_train_batch_size 4 \
    --learning_rate 0.0001 \
    --adam_beta1 0.9 \
    --adam_beta2 0.99 \
    --weight_decay 0.01 \
    --lr_scheduler_type "constant_with_warmup" \
    --warmup_steps 50 \
    --logging_steps 20 \
    --freeze_text_encoder true \
    --per_device_eval_batch_size 4 \
    --audio_encoder_per_device_batch_size 24 \
    --dtype "float16" \
    --seed 456 \
    --output_dir ${OUTPUT_DIR} \
    --temporary_save_to_disk ${TEMPORY_SAVE_TO_DISK} \
    --save_to_disk ${SAVE_TO_DISK} \
    --dataloader_num_workers 32 \
    --do_eval \
    --evaluation_strategy steps \
    --eval_steps 50 \
    --save_steps 50 \
    --group_by_length true
capspeech/ar/finetune_emocaptts.sh
ADDED
@@ -0,0 +1,64 @@
# Please log in to huggingface first

LIBRITTSR_WAV_DIR='' # downloaded libritts-r wav dir
OTHER_WAV_DIR='' # downloaded other wav dirs
OUTPUT_DIR="./output_finetuning_emocaptts/" # output dir, to save checkpoints
TEMPORY_SAVE_TO_DISK="./audio_code_finetuning_emocaptts/" # dac codec saved dir
SAVE_TO_DISK="./dataset_finetuning_emocaptts/" # huggingface metadata saved dir
WANDB_KEY='' # your wandb key for logging

PRETRAINED_MODEL_PATH="" # your pretrained model path

export CUDA_LAUNCH_BLOCKING=1
export TORCH_USE_CUDA_DSA=1

accelerate launch ./training/finetune_captts.py \
    --model_name_or_path ${PRETRAINED_MODEL_PATH} \
    --feature_extractor_name "parler-tts/dac_44khZ_8kbps" \
    --description_tokenizer_name ${PRETRAINED_MODEL_PATH} \
    --prompt_tokenizer_name ${PRETRAINED_MODEL_PATH} \
    --report_to "wandb" \
    --wandb_key ${WANDB_KEY} \
    --overwrite_output_dir true \
    --train_dataset_name "OpenSound/CapSpeech" \
    --train_split_name "train_SFT_EmoCapTTS" \
    --eval_dataset_name "OpenSound/CapSpeech" \
    --eval_split_name "validation_SFT_EmoCapTTS" \
    --librittsr_dir ${LIBRITTSR_WAV_DIR} \
    --other_dir ${OTHER_WAV_DIR} \
    --max_eval_samples 96 \
    --per_device_eval_batch_size 32 \
    --target_audio_column_name "audio_path" \
    --description_column_name "caption" \
    --source_column_name "source" \
    --prompt_column_name "text" \
    --max_duration_in_seconds 20 \
    --min_duration_in_seconds 3 \
    --max_text_length 600 \
    --preprocessing_num_workers 32 \
    --do_train true \
    --num_train_epochs 5 \
    --gradient_accumulation_steps 6 \
    --gradient_checkpointing false \
    --per_device_train_batch_size 4 \
    --learning_rate 0.0001 \
    --adam_beta1 0.9 \
    --adam_beta2 0.99 \
    --weight_decay 0.01 \
    --lr_scheduler_type "constant_with_warmup" \
    --warmup_steps 1000 \
    --logging_steps 200 \
    --freeze_text_encoder true \
    --per_device_eval_batch_size 4 \
    --audio_encoder_per_device_batch_size 24 \
    --dtype "float16" \
    --seed 456 \
    --output_dir ${OUTPUT_DIR} \
    --temporary_save_to_disk ${TEMPORY_SAVE_TO_DISK} \
    --save_to_disk ${SAVE_TO_DISK} \
    --dataloader_num_workers 32 \
    --do_eval \
    --evaluation_strategy steps \
    --eval_steps 400 \
    --save_steps 400 \
    --group_by_length true
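Once any of the finetuning scripts above has written a checkpoint to its `OUTPUT_DIR`, it can be loaded like an upstream Parler-TTS model. The following is an untested sketch: the import path assumes the bundled `capspeech.ar.parler_tts` package is on `PYTHONPATH`, it assumes the checkpoint keeps the standard `generate(input_ids=..., prompt_input_ids=...)` interface of Parler-TTS, and the caption/transcript strings are placeholders.

```python
# Hedged inference sketch for a finetuned CapSpeech-AR checkpoint.
import torch
import soundfile as sf
from transformers import AutoTokenizer
from capspeech.ar.parler_tts import ParlerTTSForConditionalGeneration  # assumed import path

device = "cuda" if torch.cuda.is_available() else "cpu"
ckpt = "./output_finetuning_captts/"  # OUTPUT_DIR from one of the scripts above

model = ParlerTTSForConditionalGeneration.from_pretrained(ckpt).to(device)
tokenizer = AutoTokenizer.from_pretrained(ckpt)

caption = "A female speaker with a calm, low-pitched voice."  # style description (placeholder)
text = "Hello, this is a CapSpeech sample."                    # transcript to synthesize (placeholder)

input_ids = tokenizer(caption, return_tensors="pt").input_ids.to(device)
prompt_input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)

with torch.no_grad():
    audio = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)

sf.write("sample.wav", audio.cpu().numpy().squeeze(), model.config.sampling_rate)
```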
capspeech/ar/parler_tts/__init__.py
ADDED
@@ -0,0 +1,25 @@
__version__ = "0.2.2"


from transformers import AutoConfig, AutoModel

from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
from .dac_wrapper import DACConfig, DACModel
from .modeling_parler_tts import (
    ParlerTTSForCausalLM,
    ParlerTTSForConditionalGeneration,
    apply_delay_pattern_mask,
    build_delay_pattern_mask,
)

from .streamer import ParlerTTSStreamer

from importlib.metadata import version
from packaging.version import Version

if Version(version("transformers")) <= Version("4.44.2dev"):
    AutoConfig.register("dac", DACConfig)
else:
    AutoConfig.register("dac_on_the_hub", DACConfig)

AutoModel.register(DACConfig, DACModel)
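The version gate above registers the DAC wrapper under a different `model_type` depending on the installed `transformers` release, presumably because newer releases ship their own `dac` model type and the name would otherwise clash. A small self-contained sketch of the same check (the helper name is mine, not part of the package):

```python
# Hedged sketch of the version check used in the __init__ above.
from importlib.metadata import version
from packaging.version import Version

def dac_model_type() -> str:
    # Older transformers: register as "dac"; newer: "dac_on_the_hub" to avoid a clash.
    return "dac" if Version(version("transformers")) <= Version("4.44.2dev") else "dac_on_the_hub"

print(dac_model_type())
```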
capspeech/ar/parler_tts/configuration_parler_tts.py
ADDED
@@ -0,0 +1,291 @@
# coding=utf-8
# Copyright 2024 and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Parler-TTS model configuration"""

from transformers import AutoConfig, logging
from transformers.configuration_utils import PretrainedConfig

from importlib.metadata import version
from packaging.version import Version

use_dac_on_the_hub = Version(version("transformers")) > Version("4.44.2dev")

logger = logging.get_logger(__name__)

PARLER_TTS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "parler-tts/parler-tts-mini-v1": "https://huggingface.co/parler-tts/parler-tts-mini-v1/resolve/main/config.json",
    # See all ParlerTTS models at https://huggingface.co/models?filter=parler_tts
}


class ParlerTTSDecoderConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of an [`ParlerTTSDecoder`]. It is used to instantiate a
    Parler-TTS decoder according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the Parler-TTS
    [parler-tts/parler-tts-mini-v1](https://huggingface.co/parler-tts/parler-tts-mini-v1) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 2049):
            Vocabulary size of the ParlerTTSDecoder model. Defines the number of different tokens that can be
            represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of decoder layers.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer block.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        num_cross_attention_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention in the cross-attention layers.
            If it is not specified, will default to `num_key_value_heads`.
        ffn_dim (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block.
        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the decoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, text_encoder, and pooler.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        activation_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for activations inside the fully connected layer.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. Typically, set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        initializer_factor (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
            for more details.
        scale_embedding (`bool`, *optional*, defaults to `False`):
            Scale embeddings by dividing by sqrt(hidden_size).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether the model should return the last key/values attentions (not used by all models).
        num_codebooks (`int`, *optional*, defaults to 4):
            The number of parallel codebooks forwarded to the model.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether input and output word embeddings should be tied.
        rope_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to use ROPE or absolute positional embeddings.
        rope_theta (`float`, *optional*, defaults to 100000.0):
            The base period of the RoPE embeddings.
        cross_attention_implementation_strategy (`str`, *optional*):
            If not specified, the cross-attention implementation will be the same as `_attn_implementation`. If `always_eager`, it will always be the eager implementation. If `always_sdpa`, it will always be the sdpa implementation.
        use_fused_lm_heads (`bool`, *optional*, defaults to `False`):
            Whether to fuse audio LM heads instead of applying them sequentially.
        codebook_weights (`List[int]`, *optional*):
            Weights applied to each codebook when computing the loss.
    """

    model_type = "parler_tts_decoder"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=2049,  # vocab size = 2048 (encodec vocab size) + 1 (eos)
        max_position_embeddings=2048,
        num_hidden_layers=24,
        ffn_dim=4096,
        num_attention_heads=16,
        num_key_value_heads=None,
        num_cross_attention_key_value_heads=None,
        layerdrop=0.0,
        use_cache=True,
        activation_function="gelu",
        hidden_size=1024,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        initializer_factor=0.02,
        scale_embedding=False,
        num_codebooks=4,
        pad_token_id=2048,
        bos_token_id=2049,
        eos_token_id=2048,
        tie_word_embeddings=False,
        rope_embeddings=False,
        rope_theta=10_000.0,
        cross_attention_implementation_strategy=None,
        use_fused_lm_heads=False,
        codebook_weights=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.ffn_dim = ffn_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        if num_cross_attention_key_value_heads is None:
            num_cross_attention_key_value_heads = num_key_value_heads
        self.num_cross_attention_key_value_heads = num_cross_attention_key_value_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.initializer_factor = initializer_factor
        self.layerdrop = layerdrop
        self.use_cache = use_cache
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
        self.num_codebooks = num_codebooks
        self.rope_embeddings = rope_embeddings
        self.rope_theta = rope_theta
        self.cross_attention_implementation_strategy = cross_attention_implementation_strategy
        self.use_fused_lm_heads = use_fused_lm_heads
        self.codebook_weights = codebook_weights

        if codebook_weights is not None and len(codebook_weights) != num_codebooks:
            raise ValueError(f"`codebook_weights` has length {len(codebook_weights)} when it should be of length {num_codebooks}.")
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


class ParlerTTSConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ParlerTTSModel`]. It is used to instantiate a
    Parler-TTS model according to the specified arguments, defining the text encoder, audio encoder and Parler-TTS decoder
    configs.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 1024):
            Vocabulary size of the prompt token ids. Defines the number of different tokens that can be
            represented by the `prompt_inputs_ids`.
        prompt_cross_attention (`bool`, *optional*, defaults to `False`):
            Whether to use cross-attention conditioning for the prompt (as well as the description).
        kwargs (*optional*):
            Dictionary of keyword arguments. Notably:

                - **text_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
                  defines the text encoder config.
                - **audio_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
                  defines the audio encoder config.
                - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
                  the decoder config.

    Example:

    ```python
    >>> from transformers import (
    ...     ParlerTTSConfig,
    ...     ParlerTTSDecoderConfig,
    ...     T5Config,
    ...     EncodecConfig,
    ...     ParlerTTSForConditionalGeneration,
    ... )

    >>> # Initializing text encoder, audio encoder, and decoder model configurations
    >>> text_encoder_config = T5Config()
    >>> audio_encoder_config = EncodecConfig()
    >>> decoder_config = ParlerTTSDecoderConfig()

    >>> configuration = ParlerTTSConfig.from_sub_models_config(
    ...     text_encoder_config, audio_encoder_config, decoder_config
    ... )

    >>> # Initializing a ParlerTTSForConditionalGeneration (with random weights) from the parler-tts/parler-tts-mini-v1 style configuration
    >>> model = ParlerTTSForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    >>> config_text_encoder = model.config.text_encoder
    >>> config_audio_encoder = model.config.audio_encoder
    >>> config_decoder = model.config.decoder

    >>> # Saving the model, including its configuration
    >>> model.save_pretrained("parler_tts-model")

    >>> # loading model and config from pretrained folder
    >>> parler_tts_config = ParlerTTSConfig.from_pretrained("parler_tts-model")
    >>> model = ParlerTTSForConditionalGeneration.from_pretrained("parler_tts-model", config=parler_tts_config)
    ```"""

    model_type = "parler_tts"
    is_composition = True

    def __init__(self, vocab_size=1024, prompt_cross_attention=False, **kwargs):
        super().__init__(**kwargs)
        if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
            raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")

        text_encoder_config = kwargs.pop("text_encoder")
        text_encoder_model_type = text_encoder_config.pop("model_type")

        audio_encoder_config = kwargs.pop("audio_encoder")
        audio_encoder_model_type = audio_encoder_config.pop("model_type")

        model_version = kwargs.get("transformers_version", None)
        if model_version is not None and Version(model_version) <= Version("4.44.2dev") and use_dac_on_the_hub and audio_encoder_model_type == "dac":
            # here we have to manually change model type if DAC based on transformers version
            audio_encoder_model_type = "dac_on_the_hub"

        decoder_config = kwargs.pop("decoder")

        self.vocab_size = vocab_size
        self.prompt_cross_attention = prompt_cross_attention
        self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
        self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
        self.decoder = ParlerTTSDecoderConfig(**decoder_config)
        self.is_encoder_decoder = True

    @classmethod
    def from_sub_models_config(
        cls,
        text_encoder_config: PretrainedConfig,
        audio_encoder_config: PretrainedConfig,
        decoder_config: ParlerTTSDecoderConfig,
        **kwargs,
    ):
        r"""
        Instantiate a [`ParlerTTSConfig`] (or a derived class) from text encoder, audio encoder and decoder
        configurations.

        Returns:
            [`ParlerTTSConfig`]: An instance of a configuration object
        """

        return cls(
            text_encoder=text_encoder_config.to_dict(),
            audio_encoder=audio_encoder_config.to_dict(),
            decoder=decoder_config.to_dict(),
            **kwargs,
        )

    @property
    # This is a property because you might want to change the codec model on the fly
    def sampling_rate(self):
        return self.audio_encoder.sampling_rate
capspeech/ar/parler_tts/dac_wrapper/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .configuration_dac import DACConfig
from .modeling_dac import DACModel
capspeech/ar/parler_tts/dac_wrapper/configuration_dac.py
ADDED
@@ -0,0 +1,27 @@
from transformers import PretrainedConfig
from importlib.metadata import version
from packaging.version import Version


class DACConfig(PretrainedConfig):
    model_type = "dac" if Version(version("transformers")) <= Version("4.44.2dev") else "dac_on_the_hub"

    def __init__(
        self,
        num_codebooks: int = 9,
        model_bitrate: int = 8,  # kbps
        codebook_size: int = 1024,
        latent_dim: int = 1024,
        frame_rate: int = 86,
        sampling_rate: int = 44100,
        **kwargs,
    ):
        self.codebook_size = codebook_size
        self.model_bitrate = model_bitrate
        self.latent_dim = latent_dim
        self.num_codebooks = num_codebooks
        self.frame_rate = frame_rate
        self.sampling_rate = sampling_rate

        super().__init__(**kwargs)
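A quick sanity check on these defaults: at 86 frames per second and 9 codebooks, one second of 44.1 kHz audio is represented by roughly 86 * 9 = 774 discrete codec tokens. The sketch below assumes the repository package layout shown in this diff and imports the config module directly so that the `dac` dependency is not needed just to inspect the numbers.

```python
# Hedged sketch: inspect the DAC codec defaults used by the training scripts.
from capspeech.ar.parler_tts.dac_wrapper.configuration_dac import DACConfig  # assumed path

cfg = DACConfig()
tokens_per_second = cfg.frame_rate * cfg.num_codebooks
print(cfg.sampling_rate, cfg.frame_rate, cfg.num_codebooks, tokens_per_second)  # 44100 86 9 774
```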
capspeech/ar/parler_tts/dac_wrapper/modeling_dac.py
ADDED
@@ -0,0 +1,164 @@
import torch
from dac.model import DAC
from torch import nn

from transformers import PreTrainedModel
from transformers.models.encodec.modeling_encodec import EncodecDecoderOutput, EncodecEncoderOutput

from .configuration_dac import DACConfig


# model doesn't support batching yet


class DACModel(PreTrainedModel):
    config_class = DACConfig

    # Set main input to 'input_values' for voice steering
    main_input_name = "input_values"

    def __init__(self, config):
        super().__init__(config)

        self.model = DAC(
            n_codebooks=config.num_codebooks,
            latent_dim=config.latent_dim,
            codebook_size=config.codebook_size,
        )

        self.remove_weight_norm()
        self.apply_weight_norm()

    def encode(
        self, input_values, padding_mask=None, bandwidth=None, return_dict=None, n_quantizers=None, sample_rate=None
    ):
        """
        Encodes the input audio waveform into discrete codes.

        Args:
            input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Float values of the input audio waveform.
            padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Padding mask used to pad the `input_values`.
            bandwidth (`float`, *optional*):
                Not used, kept to have the same interface as HF encodec.
            n_quantizers (`int`, *optional*):
                Number of quantizers to use, by default None.
                If None, all quantizers are used.
            sample_rate (`int`, *optional*):
                Signal sampling_rate.

        Returns:
            A list of frames containing the discrete encoded codes for the input audio waveform, along with rescaling
            factors for each chunk when `normalize` is True. Each frame is a tuple `(codebook, scale)`, with
            `codebook` of shape `[batch_size, num_codebooks, frames]`.
            Scale is not used here.

        """
        _, channels, input_length = input_values.shape

        if channels < 1 or channels > 2:
            raise ValueError(f"Number of audio channels must be 1 or 2, but got {channels}")

        audio_data = self.model.preprocess(input_values, sample_rate)

        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # TODO: for now, no chunk length

        chunk_length = None  # self.config.chunk_length
        if chunk_length is None:
            chunk_length = input_length
            stride = input_length
        else:
            stride = self.config.chunk_stride

        if padding_mask is None:
            padding_mask = torch.ones_like(input_values).bool()

        encoded_frames = []
        scales = []

        step = chunk_length - stride
        if (input_length % stride) - step != 0:
            raise ValueError(
                "The input length is not properly padded for batched chunked decoding. Make sure to pad the input correctly."
            )

        for offset in range(0, input_length - step, stride):
            mask = padding_mask[..., offset : offset + chunk_length].bool()
            frame = audio_data[:, :, offset : offset + chunk_length]

            scale = None

            _, encoded_frame, _, _, _ = self.model.encode(frame, n_quantizers=n_quantizers)
            encoded_frames.append(encoded_frame)
            scales.append(scale)

        encoded_frames = torch.stack(encoded_frames)

        if not return_dict:
            return (encoded_frames, scales)

        return EncodecEncoderOutput(encoded_frames, scales)

    def decode(
        self,
        audio_codes,
        audio_scales,
        padding_mask=None,
        return_dict=None,
    ):
        """
        Decodes the given frames into an output audio waveform.

        Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be
        trimmed.

        Args:
            audio_codes (`torch.FloatTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
                Discrete code embeddings computed using `model.encode`.
            audio_scales (`torch.Tensor` of shape `(batch_size, nb_chunks)`, *optional*):
                Not used, kept to have the same interface as HF encodec.
            padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Padding mask used to pad the `input_values`.
                Not used yet, kept to have the same interface as HF encodec.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        """
        return_dict = return_dict or self.config.return_dict

        # TODO: for now, no chunk length

        if len(audio_codes) != 1:
            raise ValueError(f"Expected one frame, got {len(audio_codes)}")

        audio_values = self.model.quantizer.from_codes(audio_codes.squeeze(0))[0]
        audio_values = self.model.decode(audio_values)
        if not return_dict:
            return (audio_values,)
        return EncodecDecoderOutput(audio_values)

    def forward(self, tensor):
        raise ValueError("`DACModel.forward` not implemented yet")

    def apply_weight_norm(self):
        weight_norm = nn.utils.weight_norm
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm

        def _apply_weight_norm(module):
            if isinstance(module, nn.Conv1d) or isinstance(module, nn.ConvTranspose1d):
                weight_norm(module)

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        def _remove_weight_norm(module):
            if isinstance(module, nn.Conv1d) or isinstance(module, nn.ConvTranspose1d):
                nn.utils.remove_weight_norm(module)

        self.apply(_remove_weight_norm)
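For orientation, here is an untested round-trip sketch through this wrapper. It assumes the `descript-audio-codec` (`dac`) package is installed, uses randomly initialised weights (real usage would load the `parler-tts/dac_44khZ_8kbps` checkpoint referenced by the training scripts), and relies on the `EncodecEncoderOutput`/`EncodecDecoderOutput` field names (`audio_codes`, `audio_scales`, `audio_values`) from `transformers`.

```python
# Hedged sketch: waveform -> codes -> waveform through DACModel (random weights).
import torch
from capspeech.ar.parler_tts.dac_wrapper import DACConfig, DACModel  # assumed import path

config = DACConfig()                 # 9 codebooks, 44.1 kHz, 86 frames/s by default
model = DACModel(config).eval()      # untrained here; load pretrained weights in practice

wav = torch.randn(1, 1, config.sampling_rate)  # 1 second of mono audio
with torch.no_grad():
    enc = model.encode(wav, sample_rate=config.sampling_rate)
    dec = model.decode(enc.audio_codes, audio_scales=enc.audio_scales)

print(enc.audio_codes.shape)   # (1, batch, num_codebooks, frames)
print(dec.audio_values.shape)  # reconstructed waveform (may be slightly longer than the input)
```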
capspeech/ar/parler_tts/logits_processors.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from transformers import LogitsProcessor, LogitsProcessorList
from transformers.pytorch_utils import isin_mps_friendly
import math
import torch


class ParlerTTSLogitsProcessor(LogitsProcessor):
    r"""This processor ensures that the delayed pattern mask constraints are respected.

    <Tip warning={true}>

    This logits processor is exclusively compatible with Parler-TTS.
    See the model documentation for examples.

    </Tip>

    Args:
        eos_token_id (`Union[int, List[int], torch.Tensor]`):
            The id(s) of the *end-of-sequence* token.
        min_eos_p (`float`, *optional*):
            Minimum end of speech threshold.
    """

    def __init__(self, eos_token_id, num_codebooks: int, batch_size: int, device: str = "cpu"):
        if not isinstance(eos_token_id, torch.Tensor):
            if isinstance(eos_token_id, int):
                eos_token_id = [eos_token_id]
            eos_token_id = torch.tensor(eos_token_id, device=device)
        self.eos_token_id = eos_token_id
        self.batch_size = batch_size

        if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
            raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")

        self.num_codebooks = num_codebooks
        self.device = device

        self.codebook_idx = torch.arange(self.batch_size * self.num_codebooks, device=self.device)
        self.first_codebooks_unfinished = torch.arange(batch_size, device=device) * num_codebooks

        max_codebooks = torch.arange(self.batch_size, device=self.device) * self.num_codebooks + self.num_codebooks - 1
        self.max_codebooks = max_codebooks

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        is_eos = isin_mps_friendly(input_ids, self.eos_token_id).sum(1)

        self.first_codebooks_unfinished = torch.where(
            (is_eos[self.first_codebooks_unfinished] > 0) & (self.first_codebooks_unfinished < self.max_codebooks),
            self.first_codebooks_unfinished + 1,
            self.first_codebooks_unfinished,
        )

        # every codebook higher than the first one unfinished will never be eos
        eos_token_mask = self.codebook_idx > self.first_codebooks_unfinished.repeat_interleave(self.num_codebooks)
        scores[eos_token_mask, self.eos_token_id] = -math.inf

        return scores
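A small, self-contained sketch (not from the repository) of how a processor like the one above can be exercised on dummy scores. The EOS id, vocabulary size and codebook count are illustrative assumptions; the rows follow the flattened `(batch × codebook)` layout used during Parler-TTS decoding.

```python
# Hypothetical exercise of ParlerTTSLogitsProcessor; all sizes and token ids are assumptions.
import torch
from transformers import LogitsProcessorList
from capspeech.ar.parler_tts.logits_processors import ParlerTTSLogitsProcessor  # assumed module path

batch_size, num_codebooks, vocab_size, eos_token_id = 1, 9, 1088, 1024

processors = LogitsProcessorList(
    [ParlerTTSLogitsProcessor(eos_token_id, num_codebooks=num_codebooks, batch_size=batch_size)]
)

# One decoding step: each row is one codebook of one batch item.
input_ids = torch.randint(0, 1024, (batch_size * num_codebooks, 12))
scores = torch.randn(batch_size * num_codebooks, vocab_size)

filtered = processors(input_ids, scores)  # EOS is masked out for codebooks that may not finish yet
print(torch.isinf(filtered[:, eos_token_id]))
```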
capspeech/ar/parler_tts/modeling_parler_tts.py
ADDED
The diff for this file is too large to render.
See raw diff
capspeech/ar/parler_tts/streamer.py
ADDED
@@ -0,0 +1,147 @@
from .modeling_parler_tts import ParlerTTSForConditionalGeneration
from transformers.generation.streamers import BaseStreamer
from typing import Optional
import torch
import numpy as np
import math
from queue import Queue


class ParlerTTSStreamer(BaseStreamer):
    def __init__(
        self,
        model: ParlerTTSForConditionalGeneration,
        device: Optional[str] = None,
        play_steps: Optional[int] = 10,
        stride: Optional[int] = None,
        timeout: Optional[float] = None,
    ):
        """
        Streamer that stores playback-ready audio in a queue, to be used by a downstream application as an iterator. This is
        useful for applications that benefit from accessing the generated audio in a non-blocking way (e.g. in an interactive
        Gradio demo).

        Parameters:
            model (`ParlerTTSForConditionalGeneration`):
                The Parler-TTS model used to generate the audio waveform.
            device (`str`, *optional*):
                The torch device on which to run the computation. If `None`, will default to the device of the model.
            play_steps (`int`, *optional*, defaults to 10):
                The number of generation steps with which to return the generated audio array. Using fewer steps will
                mean the first chunk is ready faster, but will require more codec decoding steps overall. This value
                should be tuned to your device and latency requirements.
            stride (`int`, *optional*):
                The window (stride) between adjacent audio samples. Using a stride between adjacent audio samples reduces
                the hard boundary between them, giving smoother playback. If `None`, will default to a value equivalent to
                play_steps // 6 in the audio space.
            timeout (`int`, *optional*):
                The timeout for the audio queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
                in `.generate()`, when it is called in a separate thread.
        """
        self.decoder = model.decoder
        self.audio_encoder = model.audio_encoder
        self.generation_config = model.generation_config
        self.device = device if device is not None else model.device
        self.use_audio_scales = model.use_audio_scales
        self.use_4dim_audio_codes = model.use_4dim_audio_codes
        self.audio_kwargs = {}
        if self.use_audio_scales:
            self.audio_kwargs["audio_scales"] = [None]

        # variables used in the streaming process
        self.play_steps = play_steps
        if stride is not None:
            self.stride = stride
        else:
            hop_length = math.floor(self.audio_encoder.config.sampling_rate / self.audio_encoder.config.frame_rate)
            self.stride = hop_length * (play_steps - self.decoder.num_codebooks) // 6
        self.token_cache = None
        self.to_yield = 0

        # variables used in the thread process
        self.audio_queue = Queue()
        self.stop_signal = None
        self.timeout = timeout

    def apply_delay_pattern_mask(self, input_ids):
        # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler)
        _, delay_pattern_mask = self.decoder.build_delay_pattern_mask(
            input_ids[:, :1],
            bos_token_id=self.generation_config.bos_token_id,
            pad_token_id=self.generation_config.decoder_start_token_id,
            max_length=input_ids.shape[-1],
        )
        # apply the pattern mask to the input ids
        input_ids = self.decoder.apply_delay_pattern_mask(input_ids, delay_pattern_mask)

        # revert the pattern delay mask by filtering the pad token id
        mask = (delay_pattern_mask != self.generation_config.bos_token_id) & (
            delay_pattern_mask != self.generation_config.pad_token_id
        )
        input_ids = input_ids[mask].reshape(1, self.decoder.num_codebooks, -1)

        if self.use_4dim_audio_codes:
            # append the frame dimension back to the audio codes
            input_ids = input_ids[None, ...]

        # send the input_ids to the correct device
        input_ids = input_ids.to(self.audio_encoder.device)

        decode_sequentially = (
            self.generation_config.bos_token_id in input_ids
            or self.generation_config.pad_token_id in input_ids
            or self.generation_config.eos_token_id in input_ids
        )
        if not decode_sequentially:
            sample = self.audio_encoder.decode(
                audio_codes=input_ids,
                **self.audio_kwargs,
            ).audio_values
            output_values = sample if sample.ndim == 3 else sample.unsqueeze(0)
        else:
            sample = input_ids[:, 0] if self.use_4dim_audio_codes else input_ids[0]
            sample_mask = (
                ((sample >= self.audio_encoder.config.codebook_size).sum(dim=(0, 1)) == 0)
                if self.use_4dim_audio_codes
                else ((sample >= self.audio_encoder.config.codebook_size).sum(dim=0) == 0)
            )
            sample = sample[:, :, sample_mask] if self.use_4dim_audio_codes else sample[:, sample_mask]
            sample = self.audio_encoder.decode(audio_codes=sample[None, ...], **self.audio_kwargs).audio_values
            output_values = sample if sample.ndim == 3 else sample.unsqueeze(0)

        audio_values = output_values[0, 0]
        return audio_values.cpu().float().numpy()

    def put(self, value):
        batch_size = value.shape[0] // self.decoder.num_codebooks
        if batch_size > 1:
            raise ValueError("ParlerTTSStreamer only supports batch size 1")

        if self.token_cache is None:
            self.token_cache = value
        else:
            self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1)

        if self.token_cache.shape[-1] % self.play_steps == 0:
            audio_values = self.apply_delay_pattern_mask(self.token_cache)
            self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
            self.to_yield += len(audio_values) - self.to_yield - self.stride

    def end(self):
        """Flushes any remaining cache and appends the stop symbol."""
        if self.token_cache is not None:
            audio_values = self.apply_delay_pattern_mask(self.token_cache)
        else:
            audio_values = np.zeros(self.to_yield)

        self.on_finalized_audio(audio_values[self.to_yield :], stream_end=True)

    def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
        """Put the new audio in the queue. If the stream is ending, also put a stop signal in the queue."""
        self.audio_queue.put(audio, timeout=self.timeout)
        if stream_end:
            self.audio_queue.put(self.stop_signal, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.audio_queue.get(timeout=self.timeout)
        if not isinstance(value, np.ndarray) and value == self.stop_signal:
            raise StopIteration()
        else:
            return value
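A minimal consumption sketch (not part of the repository) for the streamer above, following the usual Parler-TTS pattern of running `generate` in a worker thread and iterating over the streamer in the main thread. The checkpoint id, export paths, prompt/description text and `play_steps` value are assumptions.

```python
# Hypothetical streaming loop around ParlerTTSStreamer; model id and texts are placeholders.
from threading import Thread

import torch
from transformers import AutoTokenizer
from capspeech.ar.parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer  # assumed exports

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1").to(device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")

description = "A calm female voice, recorded in a quiet studio."
prompt = "Hello, this is a streaming test."

desc_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

streamer = ParlerTTSStreamer(model, device=device, play_steps=40)
generation_kwargs = dict(input_ids=desc_ids, prompt_input_ids=prompt_ids, streamer=streamer, do_sample=True)

thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

for chunk in streamer:  # numpy chunks of audio, ready for playback as they arrive
    print(chunk.shape)
thread.join()
```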
capspeech/ar/pretrain.sh
ADDED
@@ -0,0 +1,68 @@
# Please log in to huggingface first

MLS_WAV_DIR='' # downloaded mls wav path
LIBRITTSRMIX_WAV_DIR='' # downloaded librittsrmix wav path
GIGASPEECH_WAV_DIR='' # downloaded gigaspeech wav path
COMMONVOICE_WAV_DIR='' # downloaded commonvoice wav path
EMILIA_WAV_DIR='' # downloaded emilia wav path
OUTPUT_DIR="./output_pretraining/" # output dir, to save checkpoints
TEMPORY_SAVE_TO_DISK="./audio_code_pretraining/" # dac codec saved dir
SAVE_TO_DISK="./dataset_pretraining/" # huggingface metadata saved dir
WANDB_KEY='' # your wandb key for logging

export CUDA_LAUNCH_BLOCKING=1
export TORCH_USE_CUDA_DSA=1

accelerate launch ./training/run_parler_tts_training.py \
    --model_name_or_path "parler-tts/parler-tts-mini-v1" \
    --feature_extractor_name "parler-tts/dac_44khZ_8kbps" \
    --description_tokenizer_name "google/flan-t5-large" \
    --prompt_tokenizer_name "google/flan-t5-large" \
    --report_to "wandb" \
    --wandb_key ${WANDB_KEY} \
    --overwrite_output_dir true \
    --train_dataset_name "OpenSound/CapSpeech" \
    --train_split_name "train_PT" \
    --eval_dataset_name "OpenSound/CapSpeech" \
    --eval_split_name "validation_PT" \
    --mls_dir ${MLS_WAV_DIR} \
    --librittsrmix_dir ${LIBRITTSRMIX_WAV_DIR} \
    --gigaspeech_dir ${GIGASPEECH_WAV_DIR} \
    --commonvoice_dir ${COMMONVOICE_WAV_DIR} \
    --emilia_dir ${EMILIA_WAV_DIR} \
    --max_eval_samples 96 \
    --per_device_eval_batch_size 32 \
    --target_audio_column_name "audio_path" \
    --description_column_name "caption" \
    --source_column_name "source" \
    --prompt_column_name "text" \
    --max_duration_in_seconds 20 \
    --min_duration_in_seconds 3 \
    --max_text_length 600 \
    --preprocessing_num_workers 32 \
    --do_train true \
    --num_train_epochs 10 \
    --gradient_accumulation_steps 6 \
    --gradient_checkpointing false \
    --per_device_train_batch_size 4 \
    --learning_rate 0.001 \
    --adam_beta1 0.9 \
    --adam_beta2 0.99 \
    --weight_decay 0.01 \
    --lr_scheduler_type "constant_with_warmup" \
    --warmup_steps 5000 \
    --logging_steps 200 \
    --freeze_text_encoder false \
    --per_device_eval_batch_size 4 \
    --audio_encoder_per_device_batch_size 24 \
    --dtype "float16" \
    --seed 456 \
    --output_dir ${OUTPUT_DIR} \
    --temporary_save_to_disk ${TEMPORY_SAVE_TO_DISK} \
    --save_to_disk ${SAVE_TO_DISK} \
    --dataloader_num_workers 32 \
    --do_eval \
    --evaluation_strategy steps \
    --eval_steps 5000 \
    --save_steps 5000 \
    --group_by_length true
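The script assumes you are already authenticated with the Hugging Face Hub and wandb. A small Python sketch (not part of the repository) of the two logins it relies on, with placeholder tokens:

```python
# Hypothetical one-time setup before running pretrain.sh; the token strings are placeholders.
from huggingface_hub import login
import wandb

login(token="hf_xxx")               # equivalent to `huggingface-cli login`
wandb.login(key="your-wandb-key")   # the script also passes the key via --wandb_key
```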
capspeech/ar/training/__init__.py
ADDED
File without changes
capspeech/ar/training/arguments.py
ADDED
@@ -0,0 +1,403 @@
from dataclasses import dataclass, field
from typing import Optional, List

from transformers import Seq2SeqTrainingArguments


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    feature_extractor_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained feature extractor name or path if not the same as model_name"}
    )
    description_tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained description tokenizer name or path if not the same as model_name"}
    )
    prompt_tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained prompt tokenizer name or path if not the same as description_tokenizer_name"},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    pad_token_id: int = field(
        default=None,
        metadata={"help": "If specified, change the model pad token id."},
    )
    decoder_start_token_id: int = field(
        default=None,
        metadata={"help": "If specified, change the model decoder start token id."},
    )
    freeze_text_encoder: bool = field(
        default=False,
        metadata={"help": "Whether to freeze the text encoder."},
    )
    do_sample: bool = field(
        default=True,
        metadata={"help": "Whether to do sampling or greedy decoding."},
    )
    temperature: float = field(
        default=1.0,
        metadata={"help": "Temperature if sampling."},
    )
    max_length: int = field(
        default=2580,
        metadata={"help": "Generation max length."},
    )
    bandwidth: float = field(
        default=6,
        metadata={"help": "Audio encoder bandwidth."},
    )
    asr_model_name_or_path: str = field(
        default="distil-whisper/distil-large-v2",
        metadata={
            "help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
        },
    )
    clap_model_name_or_path: str = field(
        default="laion/larger_clap_music_and_speech",
        metadata={
            "help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
        },
    )
    attn_implementation: str = field(
        default="eager",
        metadata={"help": "Attention implementation used. One of `eager`, `sdpa`, `flash_attention_2`"},
    )
    cross_attention_implementation_strategy: str = field(
        default=None,
        metadata={
            "help": "If not specified, the cross-attention implementation will be the same as `_attn_implementation`. If `always_eager`, it will always be the eager implementation. If `always_sdpa`, it will always be the sdpa implementation."
        },
    )
    prompt_padding_side: Optional[str] = field(
        default="left",
        metadata={
            "help": "Prompt tokenizer padding side. Defaults to `left`. If the prompt is pre-pended to the codebooks hidden states, it should be padded on the left."
        },
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.
    """

    train_dataset_name: str = field(
        default=None,
        metadata={
            "help": "The name of the training dataset to use (via the datasets library). Load and combine "
            "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
            "librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
        },
    )
    train_dataset_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The configuration name of the training dataset to use (via the datasets library). Load and combine "
            "multiple datasets by separating dataset configs by a '+' symbol."
        },
    )
    train_split_name: str = field(
        default="train",
        metadata={"help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"},
    )
    train_dataset_samples: str = field(
        default=None,
        metadata={
            "help": "Number of samples in the training data. Load and combine "
            "multiple datasets by separating dataset samples by a '+' symbol."
        },
    )
    train_metadata_dataset_name: str = field(
        default=None,
        metadata={
            "help": "The name of the metadata training dataset to use (via the datasets library). Load and combine "
            "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
            "librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
        },
    )
    eval_dataset_name: str = field(
        default=None,
        metadata={
            "help": "The name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset name if unspecified."
        },
    )
    eval_dataset_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The configuration name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset config name if unspecified."
        },
    )
    eval_split_name: str = field(
        default="test",
        metadata={"help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'test'"},
    )
    eval_metadata_dataset_name: str = field(
        default=None,
        metadata={
            "help": "The name of the metadata evaluation dataset to use (via the datasets library). Load and combine "
            "multiple datasets by separating dataset ids by a '+' symbol."
        },
    )
    target_audio_column_name: str = field(
        default="audio",
        metadata={"help": "The name of the dataset column containing the target audio data. Defaults to 'audio'"},
    )
    description_column_name: str = field(
        default=None,
        metadata={"help": "The name of the dataset column containing the description text data. Defaults to 'None'."},
    )
    prompt_column_name: str = field(
        default=None,
        metadata={"help": "The name of the dataset column containing the prompt text data. Defaults to 'None'."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this value if set."
        },
    )
    max_duration_in_seconds: float = field(
        default=35.0,
        metadata={
            "help": "Filter audio files that are longer than `max_duration_in_seconds` seconds. "
            "Also used to set the maximum audio length if `pad_to_max_length=True`."
        },
    )
    min_duration_in_seconds: float = field(
        default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
    )
    max_text_length: int = field(
        default=500, metadata={"help": "If set, max description lengths in number of characters."}
    )
    max_prompt_token_length: int = field(
        default=None,
        metadata={
            "help": "If set, filter samples with prompts that are longer than `max_prompt_token_length` tokens. "
            "Also used to set the maximum prompt token length if `pad_to_max_length=True`."
        },
    )
    max_description_token_length: int = field(
        default=None,
        metadata={
            "help": "If set, filter samples with descriptions that are longer than `max_description_token_length` tokens. "
            "Also used to set the maximum description token length if `pad_to_max_length=True`."
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "If `True`, pad audio, prompt and description to a maximum length set with respectively "
            "`max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`."
        },
    )
    preprocessing_only: bool = field(
        default=False,
        metadata={
            "help": "Whether to only do data preprocessing and skip training. This is especially useful when data"
            " preprocessing errors out in distributed training due to timeout. In this case, one should run the"
            " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
            " can consequently be loaded in distributed training."
            " In this training script, `save_to_disk` must be set to the path in which the dataset should be saved."
        },
    )
    token: str = field(
        default=None,
        metadata={
            "help": "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
            "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
        },
    )
    use_auth_token: bool = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
            "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
            "execute code present on the Hub on your local machine."
        },
    )
    add_audio_samples_to_wandb: bool = field(
        default=False,
        metadata={"help": "If set and if `wandb` is in args.report_to, will add generated audio samples to wandb logs."},
    )
    id_column_name: str = field(default=None, metadata={"help": "id column name."})
    wandb_project: str = field(
        default="parler-speech",
        metadata={"help": "The name of the wandb project."},
    )
    wandb_run_name: str = field(
        default=None,
        metadata={
            "help": "If specified, the name of the run. If not specified, wandb will give a random name to this run."
        },
    )
    save_to_disk: str = field(
        default=None,
        metadata={
            "help": "If set, will save the dataset to this path if this is an empty folder. If not empty, will load the datasets from it."
        },
    )
    temporary_save_to_disk: str = field(default=None, metadata={"help": "Temporarily save audio labels here."})
    save_codec_steps: Optional[int] = field(
        default=500,
        metadata={"help": "Temporarily save the audio labels every `save_steps`."},
    )
    pad_to_multiple_of: Optional[int] = field(
        default=2,
        metadata={"help": "Pad to multiple of for tokenizers."},
    )
    mls_dir: str = field(
        default=None,
        metadata={"help": "mls audio dir"},
    )
    librittsrmix_dir: str = field(
        default=None,
        metadata={"help": "librittsrmix audio dir"},
    )
    gigaspeech_dir: str = field(
        default=None,
        metadata={"help": "gigaspeech audio dir"},
    )
    commonvoice_dir: str = field(
        default=None,
        metadata={"help": "commonvoice audio dir"},
    )
    emilia_dir: str = field(
        default=None,
        metadata={"help": "emilia audio dir"},
    )
    source_column_name: str = field(
        default="source",
        metadata={"help": "The name of the source column."},
    )
    wandb_key: str = field(
        default=None,
        metadata={"help": "wandb key name"},
    )


@dataclass
class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": "The data type (dtype) in which to run training. One of `float32` (full-precision), "
            "`float16` or `bfloat16` (both half-precision)."
        },
    )
    audio_encoder_per_device_batch_size: int = field(
        default=8,
        metadata={"help": "Specify the batch size of the audio encoding pre-processing steps."},
    )
    eval_dataloader_num_workers: Optional[int] = field(
        default=0,
        metadata={
            "help": "Number of subprocesses to use for evaluation data loading (PyTorch only). 0 means that the data will be loaded in the main process."
        },
    )
    compute_clap_similarity_metric: bool = field(
        default=True,
        metadata={
            "help": "Whether or not to compute the CLAP similarity metric between the description and the generation during evaluation."
        },
    )
    compute_noise_level_metric: bool = field(
        default=True,
        metadata={"help": "Whether or not to compute the squim si-sdr measure of the generations."},
    )
    noise_level_to_compute_clean_wer: float = field(
        default=25,
        metadata={
            "help": "If `compute_noise_level_metric=True`, will compute a 'clean' WER on samples with generated noise higher than `noise_level_to_compute_clean_wer`. "
            "This is a proxy measure to compute WER on clean audios, provided that the model learns to generate clean audios."
        },
    )
    eval_generation_steps: Optional[int] = field(
        default=None,
        metadata={
            "help": "Number of update steps between two generation evaluations. Will default to the same "
            "value as `eval_steps` if not set. Should be an integer and a multiple of `eval_steps`."
        },
    )
    codebook_weights: Optional[List[float]] = field(
        default=None,
        metadata={"help": "Weights applied to each codebook."},
    )
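As the `DataTrainingArguments` docstring notes, these dataclasses are meant to be consumed through `HfArgumentParser`. A minimal parsing sketch (not from the repository); the module path and the flag values are assumptions:

```python
# Hypothetical use of the argument dataclasses above with HfArgumentParser.
from transformers import HfArgumentParser
from capspeech.ar.training.arguments import (  # assumed module path
    DataTrainingArguments,
    ModelArguments,
    ParlerTTSTrainingArguments,
)

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses(
    [
        "--model_name_or_path", "parler-tts/parler-tts-mini-v1",
        "--train_dataset_name", "OpenSound/CapSpeech",
        "--output_dir", "./output_pretraining/",
    ]
)
print(model_args.model_name_or_path, data_args.train_dataset_name, training_args.output_dir)
```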
capspeech/ar/training/arguments_captts.py
ADDED
@@ -0,0 +1,391 @@
from dataclasses import dataclass, field
from typing import Optional, List

from transformers import Seq2SeqTrainingArguments


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    feature_extractor_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained feature extractor name or path if not the same as model_name"}
    )
    description_tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained description tokenizer name or path if not the same as model_name"}
    )
    prompt_tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained prompt tokenizer name or path if not the same as description_tokenizer_name"},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    pad_token_id: int = field(
        default=None,
        metadata={"help": "If specified, change the model pad token id."},
    )
    decoder_start_token_id: int = field(
        default=None,
        metadata={"help": "If specified, change the model decoder start token id."},
    )
    freeze_text_encoder: bool = field(
        default=False,
        metadata={"help": "Whether to freeze the text encoder."},
    )
    do_sample: bool = field(
        default=True,
        metadata={"help": "Whether to do sampling or greedy decoding."},
    )
    temperature: float = field(
        default=1.0,
        metadata={"help": "Temperature if sampling."},
    )
    max_length: int = field(
        default=2580,
        metadata={"help": "Generation max length."},
    )
    bandwidth: float = field(
        default=6,
        metadata={"help": "Audio encoder bandwidth."},
    )
    asr_model_name_or_path: str = field(
        default="distil-whisper/distil-large-v2",
        metadata={
            "help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
        },
    )
    clap_model_name_or_path: str = field(
        default="laion/larger_clap_music_and_speech",
        metadata={
            "help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
        },
    )
    attn_implementation: str = field(
        default="eager",
        metadata={"help": "Attention implementation used. One of `eager`, `sdpa`, `flash_attention_2`"},
    )
    cross_attention_implementation_strategy: str = field(
        default=None,
        metadata={
            "help": "If not specified, the cross-attention implementation will be the same as `_attn_implementation`. If `always_eager`, it will always be the eager implementation. If `always_sdpa`, it will always be the sdpa implementation."
        },
    )
    prompt_padding_side: Optional[str] = field(
        default="left",
        metadata={
            "help": "Prompt tokenizer padding side. Defaults to `left`. If the prompt is pre-pended to the codebooks hidden states, it should be padded on the left."
        },
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.
    """

    train_dataset_name: str = field(
        default=None,
        metadata={
            "help": "The name of the training dataset to use (via the datasets library). Load and combine "
            "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
            "librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
        },
    )
    train_dataset_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The configuration name of the training dataset to use (via the datasets library). Load and combine "
            "multiple datasets by separating dataset configs by a '+' symbol."
        },
    )
    train_split_name: str = field(
        default="train",
        metadata={"help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"},
    )
    train_dataset_samples: str = field(
        default=None,
        metadata={
            "help": "Number of samples in the training data. Load and combine "
            "multiple datasets by separating dataset samples by a '+' symbol."
        },
    )
    train_metadata_dataset_name: str = field(
        default=None,
        metadata={
            "help": "The name of the metadata training dataset to use (via the datasets library). Load and combine "
            "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
            "librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
        },
    )
    eval_dataset_name: str = field(
        default=None,
        metadata={
            "help": "The name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset name if unspecified."
        },
    )
    eval_dataset_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The configuration name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset config name if unspecified."
        },
    )
    eval_split_name: str = field(
        default="test",
        metadata={"help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'test'"},
    )
    eval_metadata_dataset_name: str = field(
        default=None,
        metadata={
            "help": "The name of the metadata evaluation dataset to use (via the datasets library). Load and combine "
            "multiple datasets by separating dataset ids by a '+' symbol."
        },
    )
    target_audio_column_name: str = field(
        default="audio",
        metadata={"help": "The name of the dataset column containing the target audio data. Defaults to 'audio'"},
    )
    description_column_name: str = field(
        default=None,
        metadata={"help": "The name of the dataset column containing the description text data. Defaults to 'None'."},
    )
    prompt_column_name: str = field(
        default=None,
        metadata={"help": "The name of the dataset column containing the prompt text data. Defaults to 'None'."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this value if set."
        },
    )
    max_duration_in_seconds: float = field(
        default=35.0,
        metadata={
            "help": "Filter audio files that are longer than `max_duration_in_seconds` seconds. "
            "Also used to set the maximum audio length if `pad_to_max_length=True`."
        },
    )
    min_duration_in_seconds: float = field(
        default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
    )
    max_text_length: int = field(
        default=500, metadata={"help": "If set, max description lengths in number of characters."}
    )
    max_prompt_token_length: int = field(
        default=None,
        metadata={
            "help": "If set, filter samples with prompts that are longer than `max_prompt_token_length` tokens. "
            "Also used to set the maximum prompt token length if `pad_to_max_length=True`."
        },
    )
    max_description_token_length: int = field(
        default=None,
        metadata={
            "help": "If set, filter samples with descriptions that are longer than `max_description_token_length` tokens. "
            "Also used to set the maximum description token length if `pad_to_max_length=True`."
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "If `True`, pad audio, prompt and description to a maximum length set with respectively "
            "`max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`."
        },
    )
    preprocessing_only: bool = field(
        default=False,
        metadata={
            "help": "Whether to only do data preprocessing and skip training. This is especially useful when data"
            " preprocessing errors out in distributed training due to timeout. In this case, one should run the"
            " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
            " can consequently be loaded in distributed training."
            " In this training script, `save_to_disk` must be set to the path in which the dataset should be saved."
        },
    )
    token: str = field(
        default=None,
        metadata={
            "help": "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
            "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
        },
    )
    use_auth_token: bool = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
            "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
            "execute code present on the Hub on your local machine."
        },
    )
    add_audio_samples_to_wandb: bool = field(
        default=False,
        metadata={"help": "If set and if `wandb` is in args.report_to, will add generated audio samples to wandb logs."},
    )
    id_column_name: str = field(default=None, metadata={"help": "id column name."})
    wandb_project: str = field(
        default="parler-speech",
        metadata={"help": "The name of the wandb project."},
    )
    wandb_run_name: str = field(
        default=None,
        metadata={
            "help": "If specified, the name of the run. If not specified, wandb will give a random name to this run."
        },
    )
    save_to_disk: str = field(
        default=None,
        metadata={
            "help": "If set, will save the dataset to this path if this is an empty folder. If not empty, will load the datasets from it."
        },
    )
    temporary_save_to_disk: str = field(default=None, metadata={"help": "Temporarily save audio labels here."})
    save_codec_steps: Optional[int] = field(
        default=500,
        metadata={"help": "Temporarily save the audio labels every `save_steps`."},
    )
    pad_to_multiple_of: Optional[int] = field(
        default=2,
        metadata={"help": "Pad to multiple of for tokenizers."},
    )
    librittsr_dir: str = field(
        default=None,
        metadata={"help": "librittsr audio dir"},
    )
    other_dir: str = field(
        default=None,
        metadata={"help": "other audio dir"},
    )
    source_column_name: str = field(
        default="source",
        metadata={"help": "The name of the source column."},
    )
    wandb_key: str = field(
        default=None,
        metadata={"help": "wandb key name"},
    )


@dataclass
class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": "The data type (dtype) in which to run training. One of `float32` (full-precision), "
            "`float16` or `bfloat16` (both half-precision)."
        },
    )
    audio_encoder_per_device_batch_size: int = field(
        default=8,
        metadata={"help": "Specify the batch size of the audio encoding pre-processing steps."},
    )
    eval_dataloader_num_workers: Optional[int] = field(
        default=0,
        metadata={
            "help": "Number of subprocesses to use for evaluation data loading (PyTorch only). 0 means that the data will be loaded in the main process."
        },
    )
    compute_clap_similarity_metric: bool = field(
        default=True,
        metadata={
            "help": "Whether or not to compute the CLAP similarity metric between the description and the generation during evaluation."
        },
    )
    compute_noise_level_metric: bool = field(
        default=True,
        metadata={"help": "Whether or not to compute the squim si-sdr measure of the generations."},
    )
    noise_level_to_compute_clean_wer: float = field(
        default=25,
        metadata={
            "help": "If `compute_noise_level_metric=True`, will compute a 'clean' WER on samples with generated noise higher than `noise_level_to_compute_clean_wer`. "
            "This is a proxy measure to compute WER on clean audios, provided that the model learns to generate clean audios."
        },
    )
    eval_generation_steps: Optional[int] = field(
        default=None,
        metadata={
            "help": "Number of update steps between two generation evaluations. Will default to the same "
            "value as `eval_steps` if not set. Should be an integer and a multiple of `eval_steps`."
        },
    )
    codebook_weights: Optional[List[float]] = field(
        default=None,
        metadata={"help": "Weights applied to each codebook."},
    )
capspeech/ar/training/arguments_capttsse.py
ADDED
@@ -0,0 +1,387 @@
from dataclasses import dataclass, field
from typing import Optional, List

from transformers import Seq2SeqTrainingArguments


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    feature_extractor_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained feature extractor name or path if not the same as model_name"}
    )
    description_tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained description tokenizer name or path if not the same as model_name"}
    )
    prompt_tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained prompt tokenizer name or path if not the same as description_tokenizer_name"},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    pad_token_id: int = field(
        default=None,
        metadata={"help": "If specified, change the model pad token id."},
    )
    decoder_start_token_id: int = field(
        default=None,
        metadata={"help": "If specified, change the model decoder start token id."},
    )
    freeze_text_encoder: bool = field(
        default=False,
        metadata={"help": "Whether to freeze the text encoder."},
    )
    do_sample: bool = field(
        default=True,
        metadata={"help": "Whether to do sampling or greedy decoding."},
    )
    temperature: float = field(
        default=1.0,
        metadata={"help": "Temperature if sampling."},
    )
    max_length: int = field(
        default=2580,
        metadata={"help": "Generation max length."},
    )
    bandwidth: float = field(
        default=6,
        metadata={"help": "Audio encoder bandwidth."},
    )
    asr_model_name_or_path: str = field(
        default="distil-whisper/distil-large-v2",
        metadata={
            "help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
        },
    )
    clap_model_name_or_path: str = field(
        default="laion/larger_clap_music_and_speech",
        metadata={
            "help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
        },
    )
    attn_implementation: str = field(
        default="eager",
        metadata={"help": "Attention implementation used. One of `eager`, `sdpa`, `flash_attention_2`"},
    )
    cross_attention_implementation_strategy: str = field(
        default=None,
        metadata={
            "help": "If not specified, the cross-attention implementation will be the same as `_attn_implementation`. If `always_eager`, it will always be the eager implementation. If `always_sdpa`, it will always be the sdpa implementation."
        },
    )
    prompt_padding_side: Optional[str] = field(
        default="left",
        metadata={
            "help": "Prompt tokenizer padding side. Defaults to `left`. If the prompt is pre-pended to the codebooks hidden states, it should be padded on the left."
        },
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.
    """

    train_dataset_name: str = field(
        default=None,
        metadata={
            "help": "The name of the training dataset to use (via the datasets library). Load and combine "
            "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
            "librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
        },
    )
    train_dataset_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The configuration name of the training dataset to use (via the datasets library). Load and combine "
            "multiple datasets by separating dataset configs by a '+' symbol."
        },
    )
    train_split_name: str = field(
        default="train",
        metadata={"help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"},
    )
    train_dataset_samples: str = field(
        default=None,
        metadata={
            "help": "Number of samples in the training data. Load and combine "
            "multiple datasets by separating dataset samples by a '+' symbol."
        },
    )
    train_metadata_dataset_name: str = field(
        default=None,
        metadata={
            "help": "The name of the metadata training dataset to use (via the datasets library). Load and combine "
            "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
            "librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
        },
    )
    eval_dataset_name: str = field(
        default=None,
        metadata={
            "help": "The name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset name if unspecified."
        },
    )
    eval_dataset_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The configuration name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset config name if unspecified."
        },
    )
    eval_split_name: str = field(
        default="test",
        metadata={"help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'test'"},
    )
    eval_metadata_dataset_name: str = field(
        default=None,
        metadata={
            "help": "The name of the metadata evaluation dataset to use (via the datasets library). Load and combine "
            "multiple datasets by separating dataset ids by a '+' symbol."
        },
    )
    target_audio_column_name: str = field(
        default="audio",
        metadata={"help": "The name of the dataset column containing the target audio data. Defaults to 'audio'"},
    )
    description_column_name: str = field(
        default=None,
        metadata={"help": "The name of the dataset column containing the description text data. Defaults to 'None'."},
    )
    prompt_column_name: str = field(
        default=None,
        metadata={"help": "The name of the dataset column containing the prompt text data. Defaults to 'None'."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
|
188 |
+
metadata={"help": "The number of processes to use for the preprocessing."},
|
189 |
+
)
|
190 |
+
max_train_samples: Optional[int] = field(
|
191 |
+
default=None,
|
192 |
+
metadata={
|
193 |
+
"help": (
|
194 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
195 |
+
"value if set."
|
196 |
+
)
|
197 |
+
},
|
198 |
+
)
|
199 |
+
max_eval_samples: Optional[int] = field(
|
200 |
+
default=None,
|
201 |
+
metadata={
|
202 |
+
"help": (
|
203 |
+
"For debugging purposes or quicker training, truncate the number of validation examples to this "
|
204 |
+
"value if set."
|
205 |
+
)
|
206 |
+
},
|
207 |
+
)
|
208 |
+
max_duration_in_seconds: float = field(
|
209 |
+
default=35.0,
|
210 |
+
metadata={
|
211 |
+
"help": (
|
212 |
+
"Filter audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`."
|
213 |
+
"Also, used to set maximum audio length if `pad_to_max_length=True`."
|
214 |
+
)
|
215 |
+
},
|
216 |
+
)
|
217 |
+
min_duration_in_seconds: float = field(
|
218 |
+
default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
|
219 |
+
)
|
220 |
+
max_text_length: int = field(
|
221 |
+
default=500, metadata={"help": "If set, max description lengths in number of characters."}
|
222 |
+
)
|
223 |
+
max_prompt_token_length: int = field(
|
224 |
+
default=None,
|
225 |
+
metadata={
|
226 |
+
"help": (
|
227 |
+
"If set, filter samples with prompts that are longer than `max_prompt_token_length` tokens."
|
228 |
+
"Also, used to set maximum prompt token length if `pad_to_max_length=True`."
|
229 |
+
)
|
230 |
+
},
|
231 |
+
)
|
232 |
+
max_description_token_length: int = field(
|
233 |
+
default=None,
|
234 |
+
metadata={
|
235 |
+
"help": (
|
236 |
+
"If set, filter samples with descriptions that are longer than `max_description_token_length` tokens."
|
237 |
+
"Also, used to set maximum description token length if `pad_to_max_length=True`."
|
238 |
+
)
|
239 |
+
},
|
240 |
+
)
|
241 |
+
pad_to_max_length: bool = field(
|
242 |
+
default=False,
|
243 |
+
metadata={
|
244 |
+
"help": (
|
245 |
+
"If `True`, pad audio, prompt and description to a maximum length set with respectively "
|
246 |
+
"`max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`."
|
247 |
+
)
|
248 |
+
},
|
249 |
+
)
|
250 |
+
preprocessing_only: bool = field(
|
251 |
+
default=False,
|
252 |
+
metadata={
|
253 |
+
"help": (
|
254 |
+
"Whether to only do data preprocessing and skip training. This is especially useful when data"
|
255 |
+
" preprocessing errors out in distributed training due to timeout. In this case, one should run the"
|
256 |
+
" preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
|
257 |
+
" can consequently be loaded in distributed training."
|
258 |
+
" In this training script, `save_to_disk` must be set to the path in which the dataset should be saved. "
|
259 |
+
)
|
260 |
+
},
|
261 |
+
)
|
262 |
+
token: str = field(
|
263 |
+
default=None,
|
264 |
+
metadata={
|
265 |
+
"help": (
|
266 |
+
"The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
|
267 |
+
"generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
|
268 |
+
)
|
269 |
+
},
|
270 |
+
)
|
271 |
+
use_auth_token: bool = field(
|
272 |
+
default=None,
|
273 |
+
metadata={
|
274 |
+
"help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
|
275 |
+
},
|
276 |
+
)
|
277 |
+
trust_remote_code: bool = field(
|
278 |
+
default=False,
|
279 |
+
metadata={
|
280 |
+
"help": (
|
281 |
+
"Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
|
282 |
+
"should only be set to `True` for repositories you trust and in which you have read the code, as it will "
|
283 |
+
"execute code present on the Hub on your local machine."
|
284 |
+
)
|
285 |
+
},
|
286 |
+
)
|
287 |
+
add_audio_samples_to_wandb: bool = field(
|
288 |
+
default=False,
|
289 |
+
metadata={"help": "If set and if `wandb` in args.report_to, will add generated audio samples to wandb logs."},
|
290 |
+
)
|
291 |
+
id_column_name: str = field(default=None, metadata={"help": "id column name."})
|
292 |
+
wandb_project: str = field(
|
293 |
+
default="parler-speech",
|
294 |
+
metadata={"help": "The name of the wandb project."},
|
295 |
+
)
|
296 |
+
wandb_run_name: str = field(
|
297 |
+
default=None,
|
298 |
+
metadata={
|
299 |
+
"help": "If specified, the name of the run. If not specified, wandb will give a random name to this run."
|
300 |
+
},
|
301 |
+
)
|
302 |
+
save_to_disk: str = field(
|
303 |
+
default=None,
|
304 |
+
metadata={
|
305 |
+
"help": "If set, will save the dataset to this path if this is an empyt folder. If not empty, will load the datasets from it."
|
306 |
+
},
|
307 |
+
)
|
308 |
+
temporary_save_to_disk: str = field(default=None, metadata={"help": "Temporarily save audio labels here."})
|
309 |
+
save_codec_steps: Optional[int] = field(
|
310 |
+
default=500,
|
311 |
+
metadata={"help": "Temporarily save the audio labels every `save_steps`."},
|
312 |
+
)
|
313 |
+
pad_to_multiple_of: Optional[int] = field(
|
314 |
+
default=2,
|
315 |
+
metadata={"help": ("Pad to multiple of for tokenizers.")},
|
316 |
+
)
|
317 |
+
librittsrmix_dir: str = field(
|
318 |
+
default=None,
|
319 |
+
metadata={"help": "librittsrmix audio dir"},
|
320 |
+
)
|
321 |
+
source_column_name: str = field(
|
322 |
+
default="source",
|
323 |
+
metadata={"help": "The name of the source column."},
|
324 |
+
)
|
325 |
+
wandb_key: str = field(
|
326 |
+
default=None,
|
327 |
+
metadata={"help": "wandb key name"},
|
328 |
+
)
|
329 |
+
|
330 |
+
|
331 |
+
@dataclass
|
332 |
+
class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
|
333 |
+
dtype: Optional[str] = field(
|
334 |
+
default="float32",
|
335 |
+
metadata={
|
336 |
+
"help": (
|
337 |
+
"The data type (dtype) in which to run training. One of `float32` (full-precision), "
|
338 |
+
"`float16` or `bfloat16` (both half-precision)."
|
339 |
+
)
|
340 |
+
},
|
341 |
+
)
|
342 |
+
audio_encoder_per_device_batch_size: int = field(
|
343 |
+
default=8,
|
344 |
+
metadata={"help": ("Specify the batch size of the audio encoding pre-processing steps.")},
|
345 |
+
)
|
346 |
+
eval_dataloader_num_workers: Optional[int] = field(
|
347 |
+
default=0,
|
348 |
+
metadata={
|
349 |
+
"help": (
|
350 |
+
"Number of subprocesses to use for evaluation data loading (PyTorch only). 0 means that the data will be loaded in the main process."
|
351 |
+
)
|
352 |
+
},
|
353 |
+
)
|
354 |
+
compute_clap_similarity_metric: bool = field(
|
355 |
+
default=True,
|
356 |
+
metadata={
|
357 |
+
"help": (
|
358 |
+
"Whether or not to compute the clap similarity metric between the description and the generation during evalution."
|
359 |
+
)
|
360 |
+
},
|
361 |
+
)
|
362 |
+
compute_noise_level_metric: bool = field(
|
363 |
+
default=True,
|
364 |
+
metadata={"help": ("Whether or not to compute the squim si-sdr measure of the generations.")},
|
365 |
+
)
|
366 |
+
noise_level_to_compute_clean_wer: float = field(
|
367 |
+
default=25,
|
368 |
+
metadata={
|
369 |
+
"help": (
|
370 |
+
"if `compute_noise_level_metric=True`, will compute a 'clean' WER on samples with generated noise higher than `noise_level_to_compute_clean_wer`."
|
371 |
+
"This is a proxy measure to compute WER on clean audios, provided that the model learn to generate clean audios."
|
372 |
+
)
|
373 |
+
},
|
374 |
+
)
|
375 |
+
eval_generation_steps: Optional[int] = field(
|
376 |
+
default=None,
|
377 |
+
metadata={
|
378 |
+
"help": (
|
379 |
+
"Number of update steps between two generation evaluation. Will default to the same"
|
380 |
+
"value as `eval_steps` if not set. Should be an integer and a multiple of `eval_steps`."
|
381 |
+
)
|
382 |
+
},
|
383 |
+
)
|
384 |
+
codebook_weights: Optional[List[float]] = field(
|
385 |
+
default=None,
|
386 |
+
metadata={"help": "Weights applied to each codebook."},
|
387 |
+
)
|
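As the `DataTrainingArguments` docstring notes, these dataclasses are consumed through `HfArgumentParser`, which turns every field above into a command-line flag. A minimal sketch of that pattern, assuming the script is run from `capspeech/ar` so that `training.arguments` is importable; the flag values in the comment are placeholders, not values from this repo:

```python
# Minimal sketch of how the argument dataclasses become CLI flags
# (mirrors the parsing step in the training scripts later in this upload).
from transformers import HfArgumentParser

from training.arguments import ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

# Each dataclass field is now a flag, e.g. (paths and dataset ids are placeholders):
#   python finetune_captts.py --train_dataset_name "dataset_a+dataset_b" \
#       --train_split_name "train+train" --max_duration_in_seconds 35.0 --output_dir ./output
print(model_args.attn_implementation, data_args.train_split_name, training_args.dtype)
```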
capspeech/ar/training/data.py
ADDED
@@ -0,0 +1,277 @@
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Union
import os
import datasets
import numpy as np
import torch
from accelerate import Accelerator
from datasets import Dataset, IterableDataset, concatenate_datasets, interleave_datasets, load_dataset
from tqdm import tqdm
from transformers import AutoFeatureExtractor, AutoTokenizer
import torchaudio
import torchaudio.transforms as T

@dataclass
class DataCollatorEncodecWithPadding:
    """
    Data collator that will dynamically pad the inputs received to the longest sequence in the batch or
    to `max_length` if `max_length` is set and `padding=max_length`.
    """

    feature_extractor: AutoFeatureExtractor
    audio_column_name: str
    mls_dir: Optional[str] = None
    librittsrmix_dir: Optional[str] = None
    gigaspeech_dir: Optional[str] = None
    commonvoice_dir: Optional[str] = None
    emilia_dir: Optional[str] = None
    feature_extractor_input_name: Optional[str] = "input_values"
    max_length: Optional[int] = None
    padding: Optional[str] = "longest"

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        sampling_rate = self.feature_extractor.sampling_rate
        # load audio
        audios = []
        for f in features:
            path = f[self.audio_column_name]
            source = f["source"]
            if source == "libritts-r":
                path = os.path.join(self.librittsrmix_dir, path)
            elif source == "mls":
                path = os.path.join(self.mls_dir, path)
            elif source == "gigaspeech":
                path = os.path.join(self.gigaspeech_dir, path)
            elif source == "commonvoice":
                path = os.path.join(self.commonvoice_dir, path)
            elif source == "emilia":
                path = os.path.join(self.emilia_dir, path)
            else:
                raise ValueError(source)

            if os.path.exists(path):
                waveform, sr = torchaudio.load(path)
                if sr != sampling_rate:
                    resampler = T.Resample(orig_freq=sr, new_freq=sampling_rate)
                    waveform = resampler(waveform)
                if waveform.shape[0] > 1:
                    waveform = waveform.mean(dim=0, keepdim=True)
                audios.append(waveform.squeeze())
            else:
                print(f"Read error: {path}")

        len_audio = [len(audio) for audio in audios]
        if self.max_length is not None:
            audios = [audio[: min(l, self.max_length)] for audio, l in zip(audios, len_audio)]

        # since resampling has already been performed in the 'load_multiple_datasets' function,
        # a fixed sampling_rate(44100hz) is passed to the feature_extractor.
        batch = self.feature_extractor(
            [np.asarray(a, dtype=np.float32) for a in audios], sampling_rate=sampling_rate, return_tensors="pt", padding=self.padding, max_length=self.max_length
        )
        batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
        return batch


@dataclass
class DataCollatorParlerTTSWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        prompt_tokenizer (:class:`~transformers.AutoTokenizer`)
            The prompt_tokenizer used for proccessing the data.
        description_tokenizer (:class:`~transformers.AutoTokenizer`)
            The description_tokenizer used for proccessing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
            sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
            maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
            different lengths).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
    """

    prompt_tokenizer: AutoTokenizer
    description_tokenizer: AutoTokenizer
    padding: Union[bool, str] = "longest"
    pad_to_multiple_of: Optional[int] = None
    prompt_max_length: Optional[int] = None
    description_max_length: Optional[int] = None
    audio_max_length: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods

        labels = [torch.tensor(feature["labels"]).transpose(0, 1) for feature in features]
        # (bsz, seq_len, num_codebooks)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
        if self.audio_max_length is not None and self.padding == "max_length":
            labels = torch.nn.functional.pad(
                labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0)), value=-100
            )

        input_ids = [{"input_ids": feature["input_ids"]} for feature in features]

        input_ids = self.description_tokenizer.pad(
            input_ids,
            return_tensors="pt",
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            max_length=self.description_max_length,
        )

        batch = {"labels": labels, **input_ids}

        prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
        prompt_input_ids = self.prompt_tokenizer.pad(
            prompt_input_ids,
            return_tensors="pt",
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            max_length=self.prompt_max_length,
        )

        batch["prompt_input_ids"] = prompt_input_ids["input_ids"]
        if "attention_mask" in prompt_input_ids:
            batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]

        return batch


def convert_dataset_str_to_list(
    dataset_names,
    splits=None,
    dataset_samples=None,
    default_split="train",
):
    if isinstance(dataset_names, str):
        dataset_names = dataset_names.split("+")
        splits = splits.split("+") if splits is not None else None
        dataset_samples = dataset_samples.split("+") if dataset_samples is not None else None

    if splits is not None and len(splits) != len(dataset_names):
        raise ValueError(
            f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
        )

    if dataset_samples is not None:
        if len(dataset_samples) != len(dataset_names):
            raise ValueError(
                f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and "
                f"{len(dataset_samples)} samples."
            )
        dataset_samples = [float(ds_sample) for ds_sample in dataset_samples]
    else:
        dataset_samples = [None] * len(dataset_names)

    splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]

    dataset_names_dict = []
    for i, ds_name in enumerate(dataset_names):
        dataset_names_dict.append(
            {
                "name": ds_name,
                "split": splits[i],
                "samples": dataset_samples[i],
            }
        )
    return dataset_names_dict


def load_multiple_datasets(
    accelerator: Accelerator,
    dataset_names: Union[List, str],
    splits: Optional[Union[List, str]] = None,
    label_column_names: Optional[List] = None,
    stopping_strategy: Optional[str] = "first_exhausted",
    dataset_samples: Optional[Union[List, np.array]] = None,
    streaming: Optional[bool] = False,
    seed: Optional[int] = None,
    id_column_name: Optional[str] = None,
    columns_to_keep: Optional[Set[str]] = None,
    prompt_column_name: Optional[str] = None,
    sampling_rate: Optional[int] = None,
    audio_column_name: Optional[str] = None,
    logger: Optional[logging.Logger] = None,
    librittsrmix_dir: Optional[Union[List, str]] = None,
    mls_dir: Optional[Union[List, str]] = None,
    gigaspeech_dir: Optional[Union[List, str]] = None,
    commonvoice_dir: Optional[Union[List, str]] = None,
    emilia_dir: Optional[Union[List, str]] = None,
    **kwargs,
) -> Union[Dataset, IterableDataset]:
    dataset_names_dict = convert_dataset_str_to_list(
        dataset_names, splits, label_column_names, dataset_samples
    )

    if dataset_samples is not None:
        dataset_samples = [ds_dict["samples"] for ds_dict in dataset_names_dict]
        probabilities = np.array(dataset_samples) / np.sum(dataset_samples)
    else:
        probabilities = None

    all_datasets = []
    # iterate over the datasets we want to interleave
    for dataset_dict in tqdm(dataset_names_dict, desc="Combining datasets..."):
        with accelerator.local_main_process_first():
            dataset = load_dataset(
                dataset_dict["name"],
                split=dataset_dict["split"],
                streaming=streaming,
                **kwargs,
            )
            dataset_features = dataset.features.keys()

            if columns_to_keep is not None:
                dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))

            def resolve_path(example):
                path = example["audio_path"]
                source = example["source"]

                if source == "libritts-r":
                    full_path = os.path.join(librittsrmix_dir, path)
                elif source == "mls":
                    full_path = os.path.join(mls_dir, path)
                elif source == "gigaspeech":
                    full_path = os.path.join(gigaspeech_dir, path)
                elif source == "commonvoice":
                    full_path = os.path.join(commonvoice_dir, path)
                elif source == "emilia":
                    full_path = os.path.join(emilia_dir, path)
                else:
                    return False  # unknown source

                return os.path.exists(full_path)

            dataset = dataset.filter(resolve_path, num_proc=16)

        all_datasets.append(dataset)

    if len(all_datasets) == 1:
        # we have a single dataset so just return it as is
        return all_datasets[0]

    if streaming:
        interleaved_dataset = interleave_datasets(
            all_datasets,
            stopping_strategy=stopping_strategy,
            probabilities=probabilities,
            seed=seed,
        )
    else:
        with accelerator.local_main_process_first():
            interleaved_dataset = concatenate_datasets(all_datasets)

    return interleaved_dataset
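The '+'-separated dataset strings described in the argument help text are expanded by `convert_dataset_str_to_list` before `load_multiple_datasets` interleaves (streaming) or concatenates (non-streaming) the pieces. A small worked example, using illustrative dataset ids rather than ones shipped with this repo:

```python
# Illustrative call; the dataset ids and splits are placeholders.
pairs = convert_dataset_str_to_list(
    "librispeech_asr+common_voice",  # dataset_names
    "train.clean.100+train",         # splits
    "100+200",                       # dataset_samples, used as sampling weights
)
# pairs == [
#     {"name": "librispeech_asr", "split": "train.clean.100", "samples": 100.0},
#     {"name": "common_voice",    "split": "train",           "samples": 200.0},
# ]
# With streaming=True, load_multiple_datasets would interleave the two with
# probabilities 100/300 and 200/300; otherwise it simply concatenates them.
```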
capspeech/ar/training/data_captts.py
ADDED
@@ -0,0 +1,255 @@
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Union
import os
import datasets
import numpy as np
import torch
from accelerate import Accelerator
from datasets import Dataset, IterableDataset, concatenate_datasets, interleave_datasets, load_dataset
from tqdm import tqdm
from transformers import AutoFeatureExtractor, AutoTokenizer
import torchaudio
import torchaudio.transforms as T

@dataclass
class DataCollatorEncodecWithPadding:
    """
    Data collator that will dynamically pad the inputs received to the longest sequence in the batch or
    to `max_length` if `max_length` is set and `padding=max_length`.
    """

    feature_extractor: AutoFeatureExtractor
    audio_column_name: str
    librittsr_dir: Optional[str] = None
    other_dir: Optional[str] = None
    feature_extractor_input_name: Optional[str] = "input_values"
    max_length: Optional[int] = None
    padding: Optional[str] = "longest"

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        sampling_rate = self.feature_extractor.sampling_rate
        # load audio
        audios = []
        for f in features:
            path = f[self.audio_column_name]
            source = f["source"]
            if source == "libritts-r":
                path = os.path.join(self.librittsr_dir, path)
            else:
                path = os.path.join(self.other_dir, path)

            if os.path.exists(path):
                waveform, sr = torchaudio.load(path)
                if sr != sampling_rate:
                    resampler = T.Resample(orig_freq=sr, new_freq=sampling_rate)
                    waveform = resampler(waveform)
                if waveform.shape[0] > 1:
                    waveform = waveform.mean(dim=0, keepdim=True)
                audios.append(waveform.squeeze())
            else:
                print(f"Read error: {path}")

        len_audio = [len(audio) for audio in audios]
        if self.max_length is not None:
            audios = [audio[: min(l, self.max_length)] for audio, l in zip(audios, len_audio)]

        # since resampling has already been performed in the 'load_multiple_datasets' function,
        # a fixed sampling_rate(44100hz) is passed to the feature_extractor.
        batch = self.feature_extractor(
            [np.asarray(a, dtype=np.float32) for a in audios], sampling_rate=sampling_rate, return_tensors="pt", padding=self.padding, max_length=self.max_length
        )
        batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
        return batch


@dataclass
class DataCollatorParlerTTSWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        prompt_tokenizer (:class:`~transformers.AutoTokenizer`)
            The prompt_tokenizer used for proccessing the data.
        description_tokenizer (:class:`~transformers.AutoTokenizer`)
            The description_tokenizer used for proccessing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
            sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
            maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
            different lengths).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
    """

    prompt_tokenizer: AutoTokenizer
    description_tokenizer: AutoTokenizer
    padding: Union[bool, str] = "longest"
    pad_to_multiple_of: Optional[int] = None
    prompt_max_length: Optional[int] = None
    description_max_length: Optional[int] = None
    audio_max_length: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods

        labels = [torch.tensor(feature["labels"]).transpose(0, 1) for feature in features]
        # (bsz, seq_len, num_codebooks)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
        if self.audio_max_length is not None and self.padding == "max_length":
            labels = torch.nn.functional.pad(
                labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0)), value=-100
            )

        input_ids = [{"input_ids": feature["input_ids"]} for feature in features]

        input_ids = self.description_tokenizer.pad(
            input_ids,
            return_tensors="pt",
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            max_length=self.description_max_length,
        )

        batch = {"labels": labels, **input_ids}

        prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
        prompt_input_ids = self.prompt_tokenizer.pad(
            prompt_input_ids,
            return_tensors="pt",
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            max_length=self.prompt_max_length,
        )

        batch["prompt_input_ids"] = prompt_input_ids["input_ids"]
        if "attention_mask" in prompt_input_ids:
            batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]

        return batch


def convert_dataset_str_to_list(
    dataset_names,
    splits=None,
    dataset_samples=None,
    default_split="train",
):
    if isinstance(dataset_names, str):
        dataset_names = dataset_names.split("+")
        splits = splits.split("+") if splits is not None else None
        dataset_samples = dataset_samples.split("+") if dataset_samples is not None else None

    if splits is not None and len(splits) != len(dataset_names):
        raise ValueError(
            f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
        )

    if dataset_samples is not None:
        if len(dataset_samples) != len(dataset_names):
            raise ValueError(
                f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and "
                f"{len(dataset_samples)} samples."
            )
        dataset_samples = [float(ds_sample) for ds_sample in dataset_samples]
    else:
        dataset_samples = [None] * len(dataset_names)

    splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]

    dataset_names_dict = []
    for i, ds_name in enumerate(dataset_names):
        dataset_names_dict.append(
            {
                "name": ds_name,
                "split": splits[i],
                "samples": dataset_samples[i],
            }
        )
    return dataset_names_dict


def load_multiple_datasets(
    accelerator: Accelerator,
    dataset_names: Union[List, str],
    splits: Optional[Union[List, str]] = None,
    label_column_names: Optional[List] = None,
    stopping_strategy: Optional[str] = "first_exhausted",
    dataset_samples: Optional[Union[List, np.array]] = None,
    streaming: Optional[bool] = False,
    seed: Optional[int] = None,
    id_column_name: Optional[str] = None,
    columns_to_keep: Optional[Set[str]] = None,
    prompt_column_name: Optional[str] = None,
    sampling_rate: Optional[int] = None,
    audio_column_name: Optional[str] = None,
    logger: Optional[logging.Logger] = None,
    librittsr_dir: Optional[Union[List, str]] = None,
    other_dir: Optional[Union[List, str]] = None,
    **kwargs,
) -> Union[Dataset, IterableDataset]:
    dataset_names_dict = convert_dataset_str_to_list(
        dataset_names, splits, label_column_names, dataset_samples
    )

    if dataset_samples is not None:
        dataset_samples = [ds_dict["samples"] for ds_dict in dataset_names_dict]
        probabilities = np.array(dataset_samples) / np.sum(dataset_samples)
    else:
        probabilities = None

    all_datasets = []
    # iterate over the datasets we want to interleave
    for dataset_dict in tqdm(dataset_names_dict, desc="Combining datasets..."):
        with accelerator.local_main_process_first():
            dataset = load_dataset(
                dataset_dict["name"],
                split=dataset_dict["split"],
                streaming=streaming,
                **kwargs,
            )
            dataset_features = dataset.features.keys()

            if columns_to_keep is not None:
                dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))

            def resolve_path(example):
                path = example["audio_path"]
                source = example["source"]

                if source == "libritts-r":
                    full_path = os.path.join(librittsr_dir, path)
                else:
                    full_path = os.path.join(other_dir, path)

                return os.path.exists(full_path)

            dataset = dataset.filter(resolve_path, num_proc=16)

        all_datasets.append(dataset)

    if len(all_datasets) == 1:
        # we have a single dataset so just return it as is
        return all_datasets[0]

    if streaming:
        interleaved_dataset = interleave_datasets(
            all_datasets,
            stopping_strategy=stopping_strategy,
            probabilities=probabilities,
            seed=seed,
        )
    else:
        with accelerator.local_main_process_first():
            interleaved_dataset = concatenate_datasets(all_datasets)

    return interleaved_dataset
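For reference, the `DataCollatorParlerTTSWithPadding` defined above is meant to be handed to a PyTorch `DataLoader` as `collate_fn`. The sketch below feeds it two toy examples so the padded shapes can be inspected; the tokenizer checkpoint is illustrative, not one mandated by this repo, and the training scripts normally load the tokenizers from the model checkpoint instead:

```python
from transformers import AutoTokenizer

# Illustrative tokenizers; the prompt tokenizer is left-padded as the
# prompt_padding_side argument recommends.
prompt_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base", padding_side="left")
description_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

collator = DataCollatorParlerTTSWithPadding(
    prompt_tokenizer=prompt_tokenizer,
    description_tokenizer=description_tokenizer,
    padding="longest",
)

# Two toy examples: "labels" are (num_codebooks, seq_len) codec tokens,
# "input_ids" the tokenized description, "prompt_input_ids" the tokenized transcript.
toy_examples = [
    {"labels": [[1, 2, 3], [4, 5, 6]], "input_ids": [10, 11, 12], "prompt_input_ids": [20, 21]},
    {"labels": [[1, 2], [4, 5]], "input_ids": [10, 11], "prompt_input_ids": [20, 21, 22]},
]
batch = collator(toy_examples)
# batch["labels"] has shape (bsz=2, seq_len=3, num_codebooks=2), padded with -100;
# batch["prompt_input_ids"] is left-padded because of padding_side="left".
```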
capspeech/ar/training/data_capttsse.py
ADDED
@@ -0,0 +1,253 @@
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Union
import os
import datasets
import numpy as np
import torch
from accelerate import Accelerator
from datasets import Dataset, IterableDataset, concatenate_datasets, interleave_datasets, load_dataset
from tqdm import tqdm
from transformers import AutoFeatureExtractor, AutoTokenizer
import torchaudio
import torchaudio.transforms as T

@dataclass
class DataCollatorEncodecWithPadding:
    """
    Data collator that will dynamically pad the inputs received to the longest sequence in the batch or
    to `max_length` if `max_length` is set and `padding=max_length`.
    """

    feature_extractor: AutoFeatureExtractor
    audio_column_name: str
    librittsrmix_dir: Optional[str] = None
    feature_extractor_input_name: Optional[str] = "input_values"
    max_length: Optional[int] = None
    padding: Optional[str] = "longest"

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        sampling_rate = self.feature_extractor.sampling_rate
        # load audio
        audios = []
        for f in features:
            path = f[self.audio_column_name]
            source = f["source"]
            if source == "libritts-r":
                path = os.path.join(self.librittsrmix_dir, path)
            else:
                raise ValueError(source)

            if os.path.exists(path):
                waveform, sr = torchaudio.load(path)
                if sr != sampling_rate:
                    resampler = T.Resample(orig_freq=sr, new_freq=sampling_rate)
                    waveform = resampler(waveform)
                if waveform.shape[0] > 1:
                    waveform = waveform.mean(dim=0, keepdim=True)
                audios.append(waveform.squeeze())
            else:
                print(f"Read error: {path}")

        len_audio = [len(audio) for audio in audios]
        if self.max_length is not None:
            audios = [audio[: min(l, self.max_length)] for audio, l in zip(audios, len_audio)]

        # since resampling has already been performed in the 'load_multiple_datasets' function,
        # a fixed sampling_rate(44100hz) is passed to the feature_extractor.
        batch = self.feature_extractor(
            [np.asarray(a, dtype=np.float32) for a in audios], sampling_rate=sampling_rate, return_tensors="pt", padding=self.padding, max_length=self.max_length
        )
        batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
        return batch


@dataclass
class DataCollatorParlerTTSWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        prompt_tokenizer (:class:`~transformers.AutoTokenizer`)
            The prompt_tokenizer used for proccessing the data.
        description_tokenizer (:class:`~transformers.AutoTokenizer`)
            The description_tokenizer used for proccessing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
            sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
            maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
            different lengths).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
    """

    prompt_tokenizer: AutoTokenizer
    description_tokenizer: AutoTokenizer
    padding: Union[bool, str] = "longest"
    pad_to_multiple_of: Optional[int] = None
    prompt_max_length: Optional[int] = None
    description_max_length: Optional[int] = None
    audio_max_length: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods

        labels = [torch.tensor(feature["labels"]).transpose(0, 1) for feature in features]
        # (bsz, seq_len, num_codebooks)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
        if self.audio_max_length is not None and self.padding == "max_length":
            labels = torch.nn.functional.pad(
                labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0)), value=-100
            )

        input_ids = [{"input_ids": feature["input_ids"]} for feature in features]

        input_ids = self.description_tokenizer.pad(
            input_ids,
            return_tensors="pt",
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            max_length=self.description_max_length,
        )

        batch = {"labels": labels, **input_ids}

        prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
        prompt_input_ids = self.prompt_tokenizer.pad(
            prompt_input_ids,
            return_tensors="pt",
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            max_length=self.prompt_max_length,
        )

        batch["prompt_input_ids"] = prompt_input_ids["input_ids"]
        if "attention_mask" in prompt_input_ids:
            batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]

        return batch


def convert_dataset_str_to_list(
    dataset_names,
    splits=None,
    dataset_samples=None,
    default_split="train",
):
    if isinstance(dataset_names, str):
        dataset_names = dataset_names.split("+")
        splits = splits.split("+") if splits is not None else None
        dataset_samples = dataset_samples.split("+") if dataset_samples is not None else None

    if splits is not None and len(splits) != len(dataset_names):
        raise ValueError(
            f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
        )

    if dataset_samples is not None:
        if len(dataset_samples) != len(dataset_names):
            raise ValueError(
                f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and "
                f"{len(dataset_samples)} samples."
            )
        dataset_samples = [float(ds_sample) for ds_sample in dataset_samples]
    else:
        dataset_samples = [None] * len(dataset_names)

    splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]

    dataset_names_dict = []
    for i, ds_name in enumerate(dataset_names):
        dataset_names_dict.append(
            {
                "name": ds_name,
                "split": splits[i],
                "samples": dataset_samples[i],
            }
        )
    return dataset_names_dict


def load_multiple_datasets(
    accelerator: Accelerator,
    dataset_names: Union[List, str],
    splits: Optional[Union[List, str]] = None,
    label_column_names: Optional[List] = None,
    stopping_strategy: Optional[str] = "first_exhausted",
    dataset_samples: Optional[Union[List, np.array]] = None,
    streaming: Optional[bool] = False,
    seed: Optional[int] = None,
    id_column_name: Optional[str] = None,
    columns_to_keep: Optional[Set[str]] = None,
    prompt_column_name: Optional[str] = None,
    sampling_rate: Optional[int] = None,
    audio_column_name: Optional[str] = None,
    logger: Optional[logging.Logger] = None,
    librittsrmix_dir: Optional[Union[List, str]] = None,
    **kwargs,
) -> Union[Dataset, IterableDataset]:
    dataset_names_dict = convert_dataset_str_to_list(
        dataset_names, splits, label_column_names, dataset_samples
    )

    if dataset_samples is not None:
        dataset_samples = [ds_dict["samples"] for ds_dict in dataset_names_dict]
        probabilities = np.array(dataset_samples) / np.sum(dataset_samples)
    else:
        probabilities = None

    all_datasets = []
    # iterate over the datasets we want to interleave
    for dataset_dict in tqdm(dataset_names_dict, desc="Combining datasets..."):
        with accelerator.local_main_process_first():
            dataset = load_dataset(
                dataset_dict["name"],
                split=dataset_dict["split"],
                streaming=streaming,
                **kwargs,
            )
            dataset_features = dataset.features.keys()

            if columns_to_keep is not None:
                dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))

            def resolve_path(example):
                path = example["audio_path"]
                source = example["source"]

                if source == "libritts-r":
                    full_path = os.path.join(librittsrmix_dir, path)
                else:
                    return False  # unknown source

                return os.path.exists(full_path)

            dataset = dataset.filter(resolve_path, num_proc=16)

        all_datasets.append(dataset)

    if len(all_datasets) == 1:
        # we have a single dataset so just return it as is
        return all_datasets[0]

    if streaming:
        interleaved_dataset = interleave_datasets(
            all_datasets,
            stopping_strategy=stopping_strategy,
            probabilities=probabilities,
            seed=seed,
        )
    else:
        with accelerator.local_main_process_first():
            interleaved_dataset = concatenate_datasets(all_datasets)

    return interleaved_dataset
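All three `DataCollatorEncodecWithPadding` variants above share the same audio-loading path: load with torchaudio, resample to the feature extractor's sampling rate, downmix to mono. The snippet below isolates those steps on a synthetic stereo clip so they can be checked without any dataset on disk; the 44.1 kHz target rate is an assumption taken from the comment in the collators, not a value enforced by the code:

```python
import torch
import torchaudio.transforms as T

# Synthetic stereo clip: 2 channels, 1 second at 24 kHz.
sr, target_sr = 24_000, 44_100
waveform = torch.randn(2, sr)

# Same steps as DataCollatorEncodecWithPadding.__call__:
if sr != target_sr:
    waveform = T.Resample(orig_freq=sr, new_freq=target_sr)(waveform)
if waveform.shape[0] > 1:
    waveform = waveform.mean(dim=0, keepdim=True)  # downmix to mono

audio = waveform.squeeze()
print(audio.shape)  # torch.Size([44100]): mono at the target rate
```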
capspeech/ar/training/finetune_captts.py
ADDED
@@ -0,0 +1,1270 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Train Parler-TTS using 🤗 Accelerate"""

import logging
import os
import re
import sys
import time
import math
import contextlib
from multiprocess import set_start_method
from datetime import timedelta
import inspect
from tqdm import tqdm
from pathlib import Path
import wandb

import torch
from torch.utils.data import DataLoader

import datasets
from datasets import DatasetDict, Dataset, IterableDataset, concatenate_datasets

from huggingface_hub import HfApi

import transformers
from transformers import AutoFeatureExtractor, AutoTokenizer, HfArgumentParser
from transformers.trainer_pt_utils import LengthGroupedSampler
from transformers.optimization import get_scheduler
from transformers.utils import send_example_telemetry

from accelerate import Accelerator, skip_first_batches
from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin, DistributedDataParallelKwargs
from accelerate.utils.memory import release_memory

from parler_tts import (
    ParlerTTSConfig,
    ParlerTTSForConditionalGeneration,
    build_delay_pattern_mask,
)

from training.utils import (
    get_last_checkpoint,
    rotate_checkpoints,
    log_pred,
    log_metric,
    load_all_codec_checkpoints,
    save_codec_checkpoint,
    get_last_codec_checkpoint_step,
)
from training.arguments_captts import ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments
from training.data_captts import load_multiple_datasets, DataCollatorParlerTTSWithPadding, DataCollatorEncodecWithPadding
from training.eval import clap_similarity, wer, si_sdr

logger = logging.getLogger(__name__)


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_parler_tts", model_args, data_args)

    if data_args.wandb_key is not None:
        wandb.login(key=data_args.wandb_key)

    if training_args.dtype == "float16":
        mixed_precision = "fp16"
        torch_dtype = torch.float16
    elif training_args.dtype == "bfloat16":
        mixed_precision = "bf16"
        torch_dtype = torch.bfloat16
    else:
        mixed_precision = "no"
        torch_dtype = torch.float32

    if data_args.pad_to_max_length and (
        data_args.max_duration_in_seconds is None
        or data_args.max_prompt_token_length is None
        or data_args.max_description_token_length is None
    ):
        raise ValueError(
            "`pad_to_max_length` is `True` but one of the following parameters has not been set: `max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`"
        )

    padding = "max_length" if data_args.pad_to_max_length else "longest"

    ####### A. Preparation
    kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=120)), DistributedDataParallelKwargs(find_unused_parameters=False)]

    accelerator = Accelerator(
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        mixed_precision=mixed_precision,
        log_with=training_args.report_to,
        project_dir=training_args.output_dir,
        kwargs_handlers=kwargs_handlers,
    )

    accelerator.init_trackers(
        project_name=data_args.wandb_project,
        config={
            "learning_rate": training_args.learning_rate,
            "model_name_or_path": model_args.model_name_or_path,
            "num_train_epochs": training_args.num_train_epochs,
            "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
            "per_device_train_batch_size": training_args.per_device_train_batch_size,
            "global_batch_size": training_args.per_device_train_batch_size * accelerator.num_processes,
            "mixed_precision": mixed_precision,
            "lr_scheduler_type": training_args.lr_scheduler_type,
            "warmup_steps": training_args.warmup_steps,
            "freeze_text_encoder": model_args.freeze_text_encoder,
            "max_duration_in_seconds": data_args.max_duration_in_seconds,
            "weight_decay": training_args.weight_decay,
            "adam_beta1": training_args.adam_beta1,
            "adam_beta2": training_args.adam_beta2,
|
143 |
+
"temperature": model_args.temperature,
|
144 |
+
},
|
145 |
+
init_kwargs={"wandb": {"name": data_args.wandb_run_name}} if data_args.wandb_run_name else {},
|
146 |
+
)
|
147 |
+
|
148 |
+
# Detect the last checkpoint and, if one exists, resume from it
|
149 |
+
last_checkpoint = None
|
150 |
+
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
|
151 |
+
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
152 |
+
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
153 |
+
raise ValueError(
|
154 |
+
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
155 |
+
"Use --overwrite_output_dir to overcome."
|
156 |
+
)
|
157 |
+
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
158 |
+
logger.info(
|
159 |
+
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
|
160 |
+
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
161 |
+
)
|
162 |
+
|
163 |
+
# Setup logging
|
164 |
+
logging.basicConfig(
|
165 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
166 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
167 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
168 |
+
)
|
169 |
+
logger.setLevel(logging.INFO if accelerator.is_main_process else logging.WARN)
|
170 |
+
|
171 |
+
# Log a small summary on each process
|
172 |
+
logger.warning(
|
173 |
+
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
|
174 |
+
f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
|
175 |
+
)
|
176 |
+
|
177 |
+
# Set the verbosity to info of the Transformers logger (on main process only)
|
178 |
+
if accelerator.is_local_main_process:
|
179 |
+
datasets.utils.logging.set_verbosity_warning()
|
180 |
+
transformers.utils.logging.set_verbosity_info()
|
181 |
+
else:
|
182 |
+
datasets.utils.logging.set_verbosity_error()
|
183 |
+
transformers.utils.logging.set_verbosity_error()
|
184 |
+
|
185 |
+
logger.info("Training/evaluation parameters %s", training_args)
|
186 |
+
|
187 |
+
# Set seed before initializing model.
|
188 |
+
set_seed(training_args.seed)
|
189 |
+
num_workers = data_args.preprocessing_num_workers
|
190 |
+
|
191 |
+
# 1. First, let's instantiate the feature extractor, tokenizers and model
|
192 |
+
# Note for distributed training, the .from_pretrained methods guarantee that only
|
193 |
+
# one local process can concurrently download model & vocab.
|
194 |
+
|
195 |
+
# load feature extractor
|
196 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
197 |
+
model_args.feature_extractor_name or model_args.model_name_or_path,
|
198 |
+
cache_dir=model_args.cache_dir,
|
199 |
+
token=data_args.token,
|
200 |
+
trust_remote_code=data_args.trust_remote_code,
|
201 |
+
)
|
202 |
+
sampling_rate = feature_extractor.sampling_rate
|
203 |
+
|
204 |
+
# load prompt tokenizer
|
205 |
+
prompt_tokenizer = AutoTokenizer.from_pretrained(
|
206 |
+
model_args.prompt_tokenizer_name or model_args.description_tokenizer_name or model_args.model_name_or_path,
|
207 |
+
cache_dir=model_args.cache_dir,
|
208 |
+
token=data_args.token,
|
209 |
+
trust_remote_code=data_args.trust_remote_code,
|
210 |
+
use_fast=model_args.use_fast_tokenizer,
|
211 |
+
padding_side=model_args.prompt_padding_side,
|
212 |
+
)
|
213 |
+
|
214 |
+
# load description tokenizer
|
215 |
+
description_tokenizer = AutoTokenizer.from_pretrained(
|
216 |
+
model_args.description_tokenizer_name or model_args.model_name_or_path,
|
217 |
+
cache_dir=model_args.cache_dir,
|
218 |
+
token=data_args.token,
|
219 |
+
trust_remote_code=data_args.trust_remote_code,
|
220 |
+
use_fast=model_args.use_fast_tokenizer,
|
221 |
+
)
|
222 |
+
|
223 |
+
if model_args.use_fast_tokenizer:
|
224 |
+
logger.warning(
|
225 |
+
"Disabling fast tokenizer warning: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3231-L3235"
|
226 |
+
)
|
227 |
+
prompt_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
|
228 |
+
description_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
|
229 |
+
|
230 |
+
# 2. Now, let's load the dataset
|
231 |
+
|
232 |
+
if data_args.save_to_disk is not None:
|
233 |
+
os.makedirs(data_args.save_to_disk, exist_ok=True)
|
234 |
+
|
235 |
+
# assume that the dataset has been saved to `save_to_disk` if the latter is not empty
|
236 |
+
dataset_was_precomputed = len(os.listdir(data_args.save_to_disk)) > 0
|
237 |
+
if dataset_was_precomputed:
|
238 |
+
with accelerator.local_main_process_first():
|
239 |
+
vectorized_datasets = datasets.load_from_disk(data_args.save_to_disk)
|
240 |
+
else:
|
241 |
+
raw_datasets = DatasetDict()
|
242 |
+
|
243 |
+
columns_to_keep = {
|
244 |
+
"target_audio_column_name": data_args.target_audio_column_name,
|
245 |
+
"prompt_column_name": data_args.prompt_column_name,
|
246 |
+
"source": data_args.source_column_name,
|
247 |
+
}
|
248 |
+
if data_args.description_column_name is not None:
|
249 |
+
columns_to_keep["description_column_name"] = data_args.description_column_name
|
250 |
+
|
251 |
+
if training_args.do_train:
|
252 |
+
raw_datasets["train"] = load_multiple_datasets(
|
253 |
+
accelerator,
|
254 |
+
data_args.train_dataset_name,
|
255 |
+
splits=data_args.train_split_name,
|
256 |
+
dataset_samples=data_args.train_dataset_samples,
|
257 |
+
seed=training_args.seed,
|
258 |
+
cache_dir=model_args.cache_dir,
|
259 |
+
num_proc=data_args.preprocessing_num_workers,
|
260 |
+
id_column_name=data_args.id_column_name,
|
261 |
+
columns_to_keep=columns_to_keep.values(),
|
262 |
+
prompt_column_name=data_args.prompt_column_name,
|
263 |
+
audio_column_name=data_args.target_audio_column_name,
|
264 |
+
sampling_rate=sampling_rate,
|
265 |
+
logger=logger,
|
266 |
+
librittsr_dir=data_args.librittsr_dir,
|
267 |
+
other_dir=data_args.other_dir,
|
268 |
+
# streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
|
269 |
+
)
|
270 |
+
|
271 |
+
for key in columns_to_keep:
|
272 |
+
if columns_to_keep[key] not in raw_datasets["train"].column_names:
|
273 |
+
raise ValueError(
|
274 |
+
f"--{key} '{columns_to_keep[key]}' not found in dataset '{data_args.train_dataset_name}'."
|
275 |
+
f" Make sure to set `--{key}` to the correct audio column - one of"
|
276 |
+
f" {', '.join(raw_datasets['train'].column_names)}."
|
277 |
+
)
|
278 |
+
|
279 |
+
if data_args.max_train_samples is not None:
|
280 |
+
raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
|
281 |
+
|
282 |
+
if training_args.do_eval:
|
283 |
+
raw_datasets["eval"] = load_multiple_datasets(
|
284 |
+
accelerator,
|
285 |
+
data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
|
286 |
+
splits=data_args.eval_split_name,
|
287 |
+
cache_dir=model_args.cache_dir,
|
288 |
+
num_proc=data_args.preprocessing_num_workers,
|
289 |
+
id_column_name=data_args.id_column_name,
|
290 |
+
columns_to_keep=columns_to_keep.values(),
|
291 |
+
prompt_column_name=data_args.prompt_column_name,
|
292 |
+
audio_column_name=data_args.target_audio_column_name,
|
293 |
+
sampling_rate=sampling_rate,
|
294 |
+
logger=logger,
|
295 |
+
librittsr_dir=data_args.librittsr_dir,
|
296 |
+
other_dir=data_args.other_dir,
|
297 |
+
# streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
|
298 |
+
)
|
299 |
+
|
300 |
+
if data_args.max_eval_samples is not None:
|
301 |
+
with accelerator.local_main_process_first():
|
302 |
+
raw_datasets["eval"] = (
|
303 |
+
raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
|
304 |
+
)
|
305 |
+
|
306 |
+
# 3. Next, let's load the config.
|
307 |
+
config = ParlerTTSConfig.from_pretrained(
|
308 |
+
model_args.model_name_or_path,
|
309 |
+
cache_dir=model_args.cache_dir,
|
310 |
+
token=data_args.token,
|
311 |
+
trust_remote_code=data_args.trust_remote_code,
|
312 |
+
)
|
313 |
+
|
314 |
+
if training_args.codebook_weights is not None and len(training_args.codebook_weights) != config.decoder.num_codebooks:
|
315 |
+
raise ValueError(f"`codebook_weights` has length {len(training_args.codebook_weights)} when it should be of length {config.decoder.num_codebooks}.")
|
316 |
+
|
317 |
+
# update pad token id and decoder_start_token_id
|
318 |
+
config.decoder.update(
|
319 |
+
{
|
320 |
+
"cross_attention_implementation_strategy": model_args.cross_attention_implementation_strategy
|
321 |
+
if model_args.cross_attention_implementation_strategy is not None
|
322 |
+
else None,
|
323 |
+
"codebook_weights": training_args.codebook_weights if training_args.codebook_weights is not None else config.decoder.codebook_weights
|
324 |
+
}
|
325 |
+
)
|
326 |
+
config.update(
|
327 |
+
{
|
328 |
+
"pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id,
|
329 |
+
"decoder_start_token_id": model_args.decoder_start_token_id
|
330 |
+
if model_args.decoder_start_token_id is not None
|
331 |
+
else config.decoder_start_token_id,
|
332 |
+
}
|
333 |
+
)
|
334 |
+
|
335 |
+
with open("events.txt", "r") as f:
|
336 |
+
events = [line.strip() for line in f]
|
337 |
+
events = ["<"+event.lower().replace(" ", "_")+">" for event in events]
|
338 |
+
events.append("<B_start>")
|
339 |
+
events.append("<B_end>")
|
340 |
+
events.append("<I_start>")
|
341 |
+
events.append("<I_end>")
|
342 |
+
|
343 |
+
special_tokens = {"additional_special_tokens": events}
|
344 |
+
prompt_tokenizer.add_special_tokens(special_tokens)
|
345 |
+
description_tokenizer.add_special_tokens(special_tokens)
|
346 |
+
padded_vocab_size = ((len(prompt_tokenizer) + 127) // 128) * 128
|
347 |
+
config.vocab_size = padded_vocab_size
|
348 |
+
|
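The lines above register the sound-event tags from `events.txt` (plus the `<B_start>`/`<B_end>`/`<I_start>`/`<I_end>` markers) as special tokens and pad the vocabulary up to the next multiple of 128 before the text-encoder embeddings are resized. A minimal sketch of the same idea follows; the tokenizer name and the two events are assumptions for illustration only, not part of the training script.

```python
# Minimal sketch: event tags become special tokens and the vocab size is
# rounded up to a multiple of 128 ("google/flan-t5-base" and the events
# below are hypothetical).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
events = ["Dog barking", "Door knock"]  # hypothetical events.txt content
special = ["<" + e.lower().replace(" ", "_") + ">" for e in events]
special += ["<B_start>", "<B_end>", "<I_start>", "<I_end>"]
tokenizer.add_special_tokens({"additional_special_tokens": special})

# Round the vocabulary size up to the next multiple of 128, as done above,
# so the resized embedding matrix has a hardware-friendly shape.
padded_vocab_size = ((len(tokenizer) + 127) // 128) * 128
print(len(tokenizer), padded_vocab_size)
```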
349 |
+
# create model
|
350 |
+
model = ParlerTTSForConditionalGeneration.from_pretrained(
|
351 |
+
model_args.model_name_or_path,
|
352 |
+
ignore_mismatched_sizes=True,
|
353 |
+
cache_dir=model_args.cache_dir,
|
354 |
+
config=config,
|
355 |
+
token=data_args.token,
|
356 |
+
trust_remote_code=data_args.trust_remote_code,
|
357 |
+
attn_implementation={"decoder": model_args.attn_implementation, "text_encoder": "eager"},
|
358 |
+
)
|
359 |
+
model.text_encoder.resize_token_embeddings(padded_vocab_size)
|
360 |
+
|
361 |
+
# enable gradient checkpointing if necessary
|
362 |
+
if training_args.gradient_checkpointing:
|
363 |
+
model.gradient_checkpointing_enable()
|
364 |
+
|
365 |
+
# 4. Now we preprocess the datasets including loading the audio, resampling and normalization
|
366 |
+
# Thankfully, `datasets` takes care of automatically loading and resampling the audio,
|
367 |
+
# so that we just need to set the correct target sampling rate and normalize the input
|
368 |
+
# via the `feature_extractor`
|
369 |
+
|
370 |
+
# derive max & min input length for sample rate & max duration
|
371 |
+
sampling_rate = feature_extractor.sampling_rate
|
372 |
+
max_target_length = int(data_args.max_duration_in_seconds * sampling_rate)
|
373 |
+
min_target_length = int(data_args.min_duration_in_seconds * sampling_rate)
|
374 |
+
target_audio_column_name = data_args.target_audio_column_name
|
375 |
+
description_column_name = data_args.description_column_name
|
376 |
+
prompt_column_name = data_args.prompt_column_name
|
377 |
+
feature_extractor_input_name = feature_extractor.model_input_names[0]
|
378 |
+
audio_encoder_pad_token_id = config.decoder.pad_token_id
|
379 |
+
audio_encoder_eos_token_id = config.decoder.eos_token_id
|
380 |
+
audio_encoder_bos_token_id = model.generation_config.decoder_start_token_id
|
381 |
+
max_length = model.generation_config.max_length
|
382 |
+
num_codebooks = model.decoder.config.num_codebooks
|
383 |
+
bandwidth = model_args.bandwidth
|
384 |
+
attn_implementation = model_args.attn_implementation
|
385 |
+
|
386 |
+
# Freeze Encoders
|
387 |
+
model.freeze_encoders(model_args.freeze_text_encoder)
|
388 |
+
|
389 |
+
# Test all gather - used for warmup and avoiding timeout
|
390 |
+
logger.debug(str(accelerator.process_index), main_process_only=False, in_order=True)
|
391 |
+
test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
|
392 |
+
gathered_tensor = accelerator.gather(test_tensor)
|
393 |
+
print("gathered_tensor", gathered_tensor)
|
394 |
+
accelerator.wait_for_everyone()
|
395 |
+
|
396 |
+
if not dataset_was_precomputed:
|
397 |
+
# Filter on text length
|
398 |
+
if description_column_name is not None and data_args.max_text_length is not None:
|
399 |
+
with accelerator.local_main_process_first():
|
400 |
+
# keep only descriptions shorter than max_text_length
|
401 |
+
raw_datasets = raw_datasets.filter(
|
402 |
+
lambda x: len(x) < data_args.max_text_length,
|
403 |
+
num_proc=num_workers,
|
404 |
+
input_columns=[description_column_name],
|
405 |
+
)
|
406 |
+
|
407 |
+
# Preprocessing the dataset.
|
408 |
+
# We need to tokenize the texts.
|
409 |
+
def pass_through_processors(description, prompt):
|
410 |
+
batch = {}
|
411 |
+
|
412 |
+
batch["input_ids"] = description_tokenizer(description.strip())["input_ids"]
|
413 |
+
batch["prompt_input_ids"] = prompt_tokenizer(prompt.strip())["input_ids"]
|
414 |
+
|
415 |
+
return batch
|
416 |
+
|
417 |
+
with accelerator.local_main_process_first():
|
418 |
+
# this is a trick to avoid rewriting the entire audio column, which takes ages
|
419 |
+
vectorized_datasets = raw_datasets.map(
|
420 |
+
pass_through_processors,
|
421 |
+
remove_columns=next(iter(raw_datasets.values())).column_names,
|
422 |
+
input_columns=[description_column_name, prompt_column_name],
|
423 |
+
num_proc=num_workers,
|
424 |
+
desc="preprocess datasets",
|
425 |
+
)
|
426 |
+
|
427 |
+
# We use Accelerate to perform distributed inference
|
428 |
+
# T5 doesn't support fp16
|
429 |
+
autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
|
430 |
+
|
431 |
+
# Now we encode the audio labels with encodec.
|
432 |
+
####### B. Encode audio
|
433 |
+
|
434 |
+
logger.info("*** Encode target audio with encodec ***")
|
435 |
+
|
436 |
+
# no need to prepare audio_decoder because it is used for inference without mixed precision
|
437 |
+
# see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
|
438 |
+
if training_args.torch_compile:
|
439 |
+
audio_decoder = accelerator.prepare_model(model.audio_encoder, evaluation_mode=True)
|
440 |
+
else:
|
441 |
+
audio_decoder = model.audio_encoder
|
442 |
+
|
443 |
+
encoder_data_collator = DataCollatorEncodecWithPadding(
|
444 |
+
feature_extractor,
|
445 |
+
audio_column_name=target_audio_column_name,
|
446 |
+
librittsr_dir=data_args.librittsr_dir,
|
447 |
+
other_dir=data_args.other_dir,
|
448 |
+
feature_extractor_input_name=feature_extractor_input_name,
|
449 |
+
max_length=max_target_length,
|
450 |
+
padding=padding,
|
451 |
+
)
|
452 |
+
encoder_signature = set(inspect.signature(audio_decoder.forward).parameters)
|
453 |
+
|
454 |
+
def apply_audio_decoder(batch):
|
455 |
+
len_audio = batch.pop("len_audio")
|
456 |
+
audio_decoder.to(batch["input_values"].device).eval()
|
457 |
+
if bandwidth is not None:
|
458 |
+
batch["bandwidth"] = bandwidth
|
459 |
+
elif "num_quantizers" in encoder_signature:
|
460 |
+
batch["num_quantizers"] = num_codebooks
|
461 |
+
elif "num_codebooks" in encoder_signature:
|
462 |
+
batch["num_codebooks"] = num_codebooks
|
463 |
+
elif "n_quantizers" in encoder_signature:
|
464 |
+
batch["n_quantizers"] = num_codebooks
|
465 |
+
|
466 |
+
with torch.no_grad():
|
467 |
+
labels = audio_decoder.encode(**batch)["audio_codes"]
|
468 |
+
output = {}
|
469 |
+
output["len_audio"] = len_audio
|
470 |
+
# (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks)
|
471 |
+
output["labels"] = labels.squeeze(0).transpose(1, 2)
|
472 |
+
|
473 |
+
# if `pad_to_max_length`, the maximum corresponding audio length of the current batch is max_duration*sampling_rate
|
474 |
+
max_length = len_audio.max() if padding != "max_length" else max_target_length
|
475 |
+
output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / max_length
|
476 |
+
return output
|
477 |
+
|
478 |
+
# (1, codebooks, seq_len) where seq_len=1
|
479 |
+
bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id
|
480 |
+
|
481 |
+
def postprocess_dataset(labels):
|
482 |
+
# (1, codebooks, seq_len)
|
483 |
+
labels = torch.tensor(labels).unsqueeze(0)
|
484 |
+
# add bos
|
485 |
+
labels = torch.cat([bos_labels, labels], dim=-1)
|
486 |
+
|
487 |
+
labels, delay_pattern_mask = build_delay_pattern_mask(
|
488 |
+
labels,
|
489 |
+
bos_token_id=audio_encoder_bos_token_id,
|
490 |
+
pad_token_id=audio_encoder_eos_token_id,
|
491 |
+
max_length=labels.shape[-1] + num_codebooks,
|
492 |
+
num_codebooks=num_codebooks,
|
493 |
+
)
|
494 |
+
|
495 |
+
# the first ids of the delay pattern mask are precisely the labels; we use the rest of the mask
|
496 |
+
# to take care of EOS
|
497 |
+
# we want labels to look like this:
|
498 |
+
# - [B, a, b, E, E, E, E]
|
499 |
+
# - [B, B, c, d, E, E, E]
|
500 |
+
# - [B, B, B, e, f, E, E]
|
501 |
+
# - [B, B, B, B, g, h, E]
|
502 |
+
labels = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask)
|
503 |
+
|
504 |
+
# the first timestamp is associated with a row full of BOS, let's get rid of it
|
505 |
+
# we also remove the last timestamps (full of PAD)
|
506 |
+
output = {"labels": labels[:, 1:]}
|
507 |
+
return output
|
508 |
+
|
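To make the `[B, a, b, E, ...]` picture above concrete, here is a small, purely illustrative sketch of the delayed codebook layout; it does not call `build_delay_pattern_mask`, and the token ids and sizes are made up.

```python
# Illustration only: codebook k is shifted right by k steps, padded with BOS on
# the left and EOS on the right, mirroring the commented example above.
import torch

BOS, EOS = 0, 9                             # hypothetical token ids
codes = [[1, 2], [3, 4], [5, 6], [7, 8]]    # 4 codebooks, 2 frames each
num_codebooks, seq_len = len(codes), len(codes[0])

delayed = torch.full((num_codebooks, seq_len + num_codebooks), EOS)
for k, row in enumerate(codes):
    delayed[k, :k] = BOS
    delayed[k, k:k + seq_len] = torch.tensor(row)

print(delayed)
# tensor([[1, 2, 9, 9, 9, 9],
#         [0, 3, 4, 9, 9, 9],
#         [0, 0, 5, 6, 9, 9],
#         [0, 0, 0, 7, 8, 9]])
```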
509 |
+
for split in vectorized_datasets:
|
510 |
+
data_loader = DataLoader(
|
511 |
+
raw_datasets[split],
|
512 |
+
batch_size=training_args.audio_encoder_per_device_batch_size,
|
513 |
+
collate_fn=encoder_data_collator,
|
514 |
+
num_workers=training_args.dataloader_num_workers,
|
515 |
+
pin_memory=True,
|
516 |
+
)
|
517 |
+
data_loader = accelerator.prepare(data_loader)
|
518 |
+
total_inference_steps = len(data_loader)
|
519 |
+
|
520 |
+
start_step = get_last_codec_checkpoint_step(os.path.join(data_args.temporary_save_to_disk, split))
|
521 |
+
accelerator.wait_for_everyone()
|
522 |
+
if start_step > 0:
|
523 |
+
logger.info(f"Resuming {split} from step {start_step}")
|
524 |
+
# efficiently skip the first n batches
|
525 |
+
start_step += 1
|
526 |
+
data_loader = skip_first_batches(data_loader, start_step)
|
527 |
+
|
528 |
+
all_generated_labels = []
|
529 |
+
all_lens = []
|
530 |
+
if start_step < total_inference_steps:
|
531 |
+
for i, batch in enumerate(tqdm(data_loader, disable=not accelerator.is_local_main_process)):
|
532 |
+
cur_step = start_step + i
|
533 |
+
generate_labels = apply_audio_decoder(batch)
|
534 |
+
generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0)
|
535 |
+
generate_labels = accelerator.gather_for_metrics(generate_labels)
|
536 |
+
|
537 |
+
if accelerator.is_main_process:
|
538 |
+
lab = generate_labels["labels"].cpu().transpose(1, 2).to(torch.int16)
|
539 |
+
rat = generate_labels["ratio"].cpu().squeeze(1)
|
540 |
+
lens = generate_labels["len_audio"].cpu().squeeze(1)
|
541 |
+
lab = [l[:, : int(ratio * length)] for (l, ratio, length) in zip(lab, rat, lens)]
|
542 |
+
|
543 |
+
all_generated_labels.extend(lab)
|
544 |
+
all_lens.extend(lens)
|
545 |
+
|
546 |
+
if ((cur_step + 1) % data_args.save_codec_steps == 0) or (
|
547 |
+
cur_step == total_inference_steps - 1
|
548 |
+
):
|
549 |
+
tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
|
550 |
+
tmp_labels = tmp_labels.map(
|
551 |
+
postprocess_dataset,
|
552 |
+
num_proc=data_args.preprocessing_num_workers, # this step is resource-consuming if there are many processes.
|
553 |
+
input_columns=["labels"],
|
554 |
+
desc="Postprocessing labeling",
|
555 |
+
)
|
556 |
+
save_codec_checkpoint(
|
557 |
+
os.path.join(data_args.temporary_save_to_disk, split), tmp_labels, cur_step
|
558 |
+
)
|
559 |
+
all_generated_labels = []
|
560 |
+
all_lens = []
|
561 |
+
|
562 |
+
accelerator.wait_for_everyone()
|
563 |
+
|
564 |
+
if accelerator.is_main_process and len(all_generated_labels) > 0:
|
565 |
+
tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
|
566 |
+
tmp_labels = tmp_labels.map(
|
567 |
+
postprocess_dataset,
|
568 |
+
num_proc=data_args.preprocessing_num_workers, # this step is resource-consuming if there are many processes.
|
569 |
+
input_columns=["labels"],
|
570 |
+
desc="Postprocessing labeling",
|
571 |
+
)
|
572 |
+
save_codec_checkpoint(os.path.join(data_args.temporary_save_to_disk, split), tmp_labels, cur_step)
|
573 |
+
all_generated_labels = []
|
574 |
+
all_lens = []
|
575 |
+
accelerator.wait_for_everyone()
|
576 |
+
|
577 |
+
del all_generated_labels
|
578 |
+
accelerator.wait_for_everyone()
|
579 |
+
|
580 |
+
with accelerator.local_main_process_first():
|
581 |
+
tmp_labels = load_all_codec_checkpoints(os.path.join(data_args.temporary_save_to_disk, split)).select(
|
582 |
+
range(len(vectorized_datasets[split]))
|
583 |
+
)
|
584 |
+
logger.info(f"Concatenating {split}: {tmp_labels} with {vectorized_datasets[split]}")
|
585 |
+
vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1)
|
586 |
+
|
587 |
+
accelerator.free_memory()
|
588 |
+
del generate_labels, all_lens
|
589 |
+
|
590 |
+
with accelerator.local_main_process_first():
|
591 |
+
# NOTE: filtering is done at the end because in the `datasets` library, caching audio files is done after most operations
|
592 |
+
# caching audio files is time and disk-space consuming, so we want to avoid it at all costs, especially for large (>1Kh) audio datasets.
|
593 |
+
# That's also why we avoid concatenating the processed datasets (vectorized_datasets) with the audio column present in raw_datasets.
|
594 |
+
|
595 |
+
def is_audio_in_length_range(length):
|
596 |
+
return length > min_target_length and length < max_target_length
|
597 |
+
|
598 |
+
# keep only audio whose length lies between min_target_length and max_target_length
|
599 |
+
vectorized_datasets = vectorized_datasets.filter(
|
600 |
+
is_audio_in_length_range,
|
601 |
+
num_proc=num_workers,
|
602 |
+
input_columns=["target_length"],
|
603 |
+
)
|
604 |
+
|
605 |
+
if description_column_name is not None and data_args.max_description_token_length is not None:
|
606 |
+
with accelerator.local_main_process_first():
|
607 |
+
# keep only descriptions shorter than max_description_token_length (in tokens)
|
608 |
+
vectorized_datasets = vectorized_datasets.filter(
|
609 |
+
lambda x: len(x) < data_args.max_description_token_length,
|
610 |
+
num_proc=num_workers,
|
611 |
+
input_columns=["input_ids"],
|
612 |
+
)
|
613 |
+
|
614 |
+
if data_args.max_prompt_token_length is not None:
|
615 |
+
with accelerator.local_main_process_first():
|
616 |
+
# keep only prompts shorter than max_prompt_token_length (in tokens)
|
617 |
+
vectorized_datasets = vectorized_datasets.filter(
|
618 |
+
lambda x: len(x) < data_args.max_prompt_token_length,
|
619 |
+
num_proc=num_workers,
|
620 |
+
input_columns=["prompt_input_ids"],
|
621 |
+
)
|
622 |
+
|
623 |
+
if data_args.save_to_disk is not None and not dataset_was_precomputed:
|
624 |
+
if accelerator.is_main_process:
|
625 |
+
vectorized_datasets.save_to_disk(
|
626 |
+
data_args.save_to_disk,
|
627 |
+
num_proc=min(data_args.preprocessing_num_workers, len(vectorized_datasets["eval"]) - 1),
|
628 |
+
)
|
629 |
+
accelerator.wait_for_everyone()
|
630 |
+
logger.info(f"Dataset saved at {data_args.save_to_disk}")
|
631 |
+
|
632 |
+
audio_max_length = None
|
633 |
+
if padding == "max_length":
|
634 |
+
audio_max_length = max(vectorized_datasets["train"]["target_length"])
|
635 |
+
with accelerator.local_main_process_first():
|
636 |
+
max_sample = vectorized_datasets["train"].filter(
|
637 |
+
lambda x: x == audio_max_length,
|
638 |
+
num_proc=num_workers,
|
639 |
+
input_columns=["target_length"],
|
640 |
+
)
|
641 |
+
audio_max_length = max([len(l[0]) for l in max_sample["labels"]])
|
642 |
+
|
643 |
+
if description_column_name is not None and data_args.max_description_token_length is not None:
|
644 |
+
with accelerator.local_main_process_first():
|
645 |
+
# keep only descriptions shorter than max_description_token_length (in tokens)
|
646 |
+
vectorized_datasets = vectorized_datasets.filter(
|
647 |
+
lambda x: len(x) < data_args.max_description_token_length,
|
648 |
+
num_proc=num_workers,
|
649 |
+
input_columns=["input_ids"],
|
650 |
+
)
|
651 |
+
|
652 |
+
if data_args.max_prompt_token_length is not None:
|
653 |
+
with accelerator.local_main_process_first():
|
654 |
+
# keep only prompts shorter than max_prompt_token_length (in tokens)
|
655 |
+
vectorized_datasets = vectorized_datasets.filter(
|
656 |
+
lambda x: len(x) < data_args.max_prompt_token_length,
|
657 |
+
num_proc=num_workers,
|
658 |
+
input_columns=["prompt_input_ids"],
|
659 |
+
)
|
660 |
+
|
661 |
+
if training_args.group_by_length:
|
662 |
+
# apply a simple heuristic to take into account audio and text lengths
|
663 |
+
def add_target_lengths(target_length, prompt, description):
|
664 |
+
return {"target_length": target_length + len(prompt) + len(description)}
|
665 |
+
|
666 |
+
with accelerator.local_main_process_first():
|
667 |
+
vectorized_datasets = vectorized_datasets.map(
|
668 |
+
add_target_lengths,
|
669 |
+
num_proc=num_workers,
|
670 |
+
input_columns=["target_length", "prompt_input_ids", "input_ids"],
|
671 |
+
)
|
672 |
+
|
673 |
+
# for large datasets it is advised to run the preprocessing on a
|
674 |
+
# single machine first with ``args.preprocessing_only`` since there will most likely
|
675 |
+
# be a timeout when running the script in distributed mode.
|
676 |
+
# In a second step ``args.preprocessing_only`` can then be set to `False` to load the
|
677 |
+
# cached dataset
|
678 |
+
if data_args.preprocessing_only and data_args.save_to_disk is None:
|
679 |
+
raise ValueError(
|
680 |
+
"`preprocessing_only=True` but `save_to_disk` is not set. The latter should indicate where to save the dataset locally."
|
681 |
+
)
|
682 |
+
elif data_args.preprocessing_only:
|
683 |
+
logger.info(f"Data preprocessing finished. Files saved at {data_args.save_to_disk}")
|
684 |
+
return
|
685 |
+
|
686 |
+
# 6. Next, we can prepare the training.
|
687 |
+
|
688 |
+
# Let's use CLAP similarity and WER metrics as our evaluation metrics,
|
689 |
+
def compute_metrics(
|
690 |
+
audios,
|
691 |
+
descriptions,
|
692 |
+
prompts,
|
693 |
+
device="cpu",
|
694 |
+
compute_clap_similarity_metric=False,
|
695 |
+
compute_noise_level_metric=False,
|
696 |
+
noise_level_to_compute_clean_wer=None,
|
697 |
+
):
|
698 |
+
results = {}
|
699 |
+
input_ids = descriptions
|
700 |
+
texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
|
701 |
+
prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
|
702 |
+
audios = [a.float().cpu().numpy() for a in audios]
|
703 |
+
|
704 |
+
if compute_clap_similarity_metric:
|
705 |
+
clap_score = clap_similarity(
|
706 |
+
model_args.clap_model_name_or_path, texts, audios, device, input_sampling_rate=sampling_rate
|
707 |
+
)
|
708 |
+
results["clap"] = clap_score
|
709 |
+
|
710 |
+
si_sdr_measures = None
|
711 |
+
if compute_noise_level_metric:
|
712 |
+
si_sdr_measures = si_sdr(audios, device, input_sampling_rate=sampling_rate)
|
713 |
+
|
714 |
+
word_error, transcriptions, clean_word_error, noisy_word_error, percent_clean_samples = wer(
|
715 |
+
model_args.asr_model_name_or_path,
|
716 |
+
prompts,
|
717 |
+
audios,
|
718 |
+
device,
|
719 |
+
training_args.per_device_eval_batch_size,
|
720 |
+
sampling_rate,
|
721 |
+
noise_level_to_compute_clean_wer,
|
722 |
+
si_sdr_measures,
|
723 |
+
)
|
724 |
+
results["wer"] = word_error
|
725 |
+
if clean_word_error is not None:
|
726 |
+
results["clean_wer"] = clean_word_error
|
727 |
+
results["noisy_word_error"] = noisy_word_error
|
728 |
+
results["percent_clean_samples"] = percent_clean_samples
|
729 |
+
|
730 |
+
return results, texts, prompts, audios, transcriptions, si_sdr_measures
|
731 |
+
|
732 |
+
# Define Training Schedule
|
733 |
+
# Store some constants
|
734 |
+
per_device_train_batch_size = int(training_args.per_device_train_batch_size)
|
735 |
+
train_batch_size = per_device_train_batch_size * accelerator.num_processes
|
736 |
+
gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
|
737 |
+
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
|
738 |
+
|
739 |
+
if training_args.max_steps < 0:
|
740 |
+
num_epochs = int(training_args.num_train_epochs)
|
741 |
+
steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
|
742 |
+
total_train_steps = steps_per_epoch * num_epochs
|
743 |
+
elif training_args.max_steps > 0:
|
744 |
+
logger.info("max_steps is given, it will override any value given in num_train_epochs")
|
745 |
+
total_train_steps = int(training_args.max_steps)
|
746 |
+
# Setting a very large number of epochs so we go as many times as necessary over the iterator.
|
747 |
+
num_epochs = sys.maxsize
|
748 |
+
steps_per_epoch = total_train_steps
|
749 |
+
|
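As a quick sanity check on the schedule arithmetic above, here is a tiny worked example; every number is hypothetical.

```python
# Hypothetical values, only to make the steps-per-epoch computation concrete.
dataset_size = 10_000
per_device_train_batch_size = 4
num_processes = 2                       # e.g. 2 GPUs
gradient_accumulation_steps = 8
num_train_epochs = 3

train_batch_size = per_device_train_batch_size * num_processes                       # 8
steps_per_epoch = dataset_size // (train_batch_size * gradient_accumulation_steps)   # 156
total_train_steps = steps_per_epoch * num_train_epochs                               # 468
print(steps_per_epoch, total_train_steps)
```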
750 |
+
if training_args.eval_steps is None:
|
751 |
+
logger.info("eval_steps is not set, evaluating at the end of each epoch")
|
752 |
+
eval_steps = steps_per_epoch
|
753 |
+
else:
|
754 |
+
eval_steps = training_args.eval_steps
|
755 |
+
|
756 |
+
if training_args.eval_generation_steps is None:
|
757 |
+
eval_generation_steps = eval_steps
|
758 |
+
else:
|
759 |
+
eval_generation_steps = training_args.eval_generation_steps
|
760 |
+
|
761 |
+
# T5 doesn't support fp16
|
762 |
+
autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
|
763 |
+
|
764 |
+
# Define optimizer, LR scheduler, collator
|
765 |
+
optimizer = torch.optim.AdamW(
|
766 |
+
params=model.parameters(),
|
767 |
+
lr=training_args.learning_rate,
|
768 |
+
betas=(training_args.adam_beta1, training_args.adam_beta2),
|
769 |
+
eps=training_args.adam_epsilon,
|
770 |
+
weight_decay=training_args.weight_decay,
|
771 |
+
)
|
772 |
+
|
773 |
+
# LR scheduler gets stepped by `num_processes` each time -> account for this in warmup / total steps
|
774 |
+
lr_scheduler = get_scheduler(
|
775 |
+
name=training_args.lr_scheduler_type,
|
776 |
+
optimizer=optimizer,
|
777 |
+
num_warmup_steps=training_args.get_warmup_steps(total_train_steps) * accelerator.num_processes,
|
778 |
+
num_training_steps=total_train_steps * accelerator.num_processes,
|
779 |
+
)
|
780 |
+
|
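The scaling above follows from the comment: once prepared, the scheduler is stepped once per process for each optimizer step, so the warmup and total step counts are multiplied by `num_processes`. A small sketch with made-up numbers:

```python
# Hypothetical numbers; the scheduler advances num_processes times per
# optimizer step, so its horizon is stretched accordingly.
warmup_steps, total_train_steps, num_processes = 500, 10_000, 4

scheduler_warmup_steps = warmup_steps * num_processes      # 2_000 scheduler steps ~ 500 optimizer steps
scheduler_total_steps = total_train_steps * num_processes  # 40_000 scheduler steps ~ 10_000 optimizer steps
print(scheduler_warmup_steps, scheduler_total_steps)
```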
781 |
+
# Instantiate custom data collator
|
782 |
+
data_collator = DataCollatorParlerTTSWithPadding(
|
783 |
+
prompt_tokenizer=prompt_tokenizer,
|
784 |
+
description_tokenizer=description_tokenizer,
|
785 |
+
pad_to_multiple_of=data_args.pad_to_multiple_of,
|
786 |
+
padding=padding,
|
787 |
+
prompt_max_length=data_args.max_prompt_token_length,
|
788 |
+
description_max_length=data_args.max_description_token_length,
|
789 |
+
audio_max_length=audio_max_length,
|
790 |
+
)
|
791 |
+
|
792 |
+
# Prepare everything with accelerate
|
793 |
+
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
|
794 |
+
|
795 |
+
num_examples = total_train_steps * train_batch_size * gradient_accumulation_steps
|
796 |
+
logger.info("***** Running training *****")
|
797 |
+
logger.info(f" Num examples = {num_examples}")
|
798 |
+
logger.info(" Instantaneous batch size per device =" f" {per_device_train_batch_size}")
|
799 |
+
logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
|
800 |
+
logger.info(
|
801 |
+
f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
|
802 |
+
)
|
803 |
+
logger.info(f" Total optimization steps = {total_train_steps}")
|
804 |
+
|
805 |
+
# ======================== Training ================================
|
806 |
+
train_time = 0
|
807 |
+
train_start = time.time()
|
808 |
+
steps_trained_progress_bar = tqdm(
|
809 |
+
range(total_train_steps), desc="Train steps ... ", position=0, disable=not accelerator.is_local_main_process
|
810 |
+
)
|
811 |
+
continue_training = True
|
812 |
+
epochs_trained = 0
|
813 |
+
cur_step = 0
|
814 |
+
|
815 |
+
checkpoint = None
|
816 |
+
if training_args.resume_from_checkpoint is not None:
|
817 |
+
checkpoint = training_args.resume_from_checkpoint
|
818 |
+
elif last_checkpoint is not None:
|
819 |
+
checkpoint = last_checkpoint
|
820 |
+
|
821 |
+
if accelerator.is_main_process:
|
822 |
+
if training_args.push_to_hub:
|
823 |
+
api = HfApi(token=training_args.hub_token)
|
824 |
+
|
825 |
+
# Create repo (repo_name from args or inferred)
|
826 |
+
repo_name = training_args.hub_model_id
|
827 |
+
if repo_name is None:
|
828 |
+
repo_name = Path(training_args.output_dir).absolute().name
|
829 |
+
repo_id = api.create_repo(repo_name, exist_ok=True).repo_id
|
830 |
+
|
831 |
+
with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
|
832 |
+
if "wandb" not in gitignore:
|
833 |
+
gitignore.write("wandb\n")
|
834 |
+
elif training_args.output_dir is not None:
|
835 |
+
os.makedirs(training_args.output_dir, exist_ok=True)
|
836 |
+
accelerator.wait_for_everyone()
|
837 |
+
|
838 |
+
# Now save everything to be able to create a single processor later
|
839 |
+
# make sure all processes wait until data is saved
|
840 |
+
# only the main process saves them
|
841 |
+
if accelerator.is_main_process:
|
842 |
+
# save feature extractor, tokenizer and config
|
843 |
+
if (
|
844 |
+
model_args.prompt_tokenizer_name is None
|
845 |
+
and model_args.description_tokenizer_name
|
846 |
+
or (model_args.prompt_tokenizer_name == model_args.description_tokenizer_name)
|
847 |
+
):
|
848 |
+
prompt_tokenizer.save_pretrained(training_args.output_dir)
|
849 |
+
else:
|
850 |
+
logger.warning(
|
851 |
+
f"Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer."
|
852 |
+
)
|
853 |
+
prompt_tokenizer.save_pretrained(training_args.output_dir)
|
854 |
+
|
855 |
+
feature_extractor.save_pretrained(training_args.output_dir)
|
856 |
+
config.save_pretrained(training_args.output_dir)
|
857 |
+
accelerator.wait_for_everyone()
|
858 |
+
|
859 |
+
if checkpoint is not None:
|
860 |
+
accelerator.load_state(checkpoint)
|
861 |
+
# Find num steps and epoch from saved state string pattern
|
862 |
+
pattern = r"checkpoint-(\d+)-epoch-(\d+)"
|
863 |
+
match = re.search(pattern, checkpoint)
|
864 |
+
cur_step = int(match.group(1))
|
865 |
+
epochs_trained = int(match.group(2))
|
866 |
+
|
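A tiny demo of the checkpoint-name pattern parsed above; the directory name is hypothetical.

```python
# Extract the global step and epoch from a checkpoint directory name.
import re

match = re.search(r"checkpoint-(\d+)-epoch-(\d+)", "output/checkpoint-1500-epoch-3")
cur_step, epochs_trained = int(match.group(1)), int(match.group(2))
print(cur_step, epochs_trained)  # 1500 3
```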
867 |
+
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
868 |
+
logger.info(f" Continuing training from epoch {epochs_trained}")
|
869 |
+
logger.info(f" Continuing training from global step {cur_step}")
|
870 |
+
|
871 |
+
steps_trained_progress_bar.update(cur_step)
|
872 |
+
|
873 |
+
for epoch in range(0, epochs_trained):
|
874 |
+
with accelerator.local_main_process_first():
|
875 |
+
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
|
876 |
+
|
877 |
+
if training_args.max_steps < 0:
|
878 |
+
# we know exactly the number of steps per epoch, so we can skip through the required number of batches
|
879 |
+
resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
|
880 |
+
else:
|
881 |
+
# Currently we don't know how many steps we've taken in the current epoch
|
882 |
+
# So we just shuffle the dataset one extra time and start from a fresh epoch
|
883 |
+
# This is "good enough" for our purposes but not fully correct
|
884 |
+
resume_step = None
|
885 |
+
with accelerator.local_main_process_first():
|
886 |
+
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
|
887 |
+
else:
|
888 |
+
resume_step = None
|
889 |
+
|
890 |
+
gen_kwargs = {
|
891 |
+
"do_sample": model_args.do_sample,
|
892 |
+
"temperature": model_args.temperature,
|
893 |
+
"max_length": model_args.max_length,
|
894 |
+
# Because of the delayed pattern mask, generation might stop earlier because of unexpected behaviour
|
895 |
+
# on the first tokens of the codebooks that are delayed.
|
896 |
+
# This fixes the issue.
|
897 |
+
"min_new_tokens": num_codebooks + 1,
|
898 |
+
}
|
899 |
+
|
900 |
+
# Define gradient update step fn
|
901 |
+
def train_step(
|
902 |
+
batch,
|
903 |
+
accelerator,
|
904 |
+
autocast_kwargs,
|
905 |
+
num_items_in_batch,
|
906 |
+
gradient_accumulation_steps,
|
907 |
+
):
|
908 |
+
if mixed_precision == "fp16":
|
909 |
+
# fp16 doesn't work with T5-like models
|
910 |
+
with accelerator.autocast(autocast_handler=autocast_kwargs):
|
911 |
+
if training_args.parallel_mode.value != "distributed":
|
912 |
+
encoder_outputs = model.text_encoder(
|
913 |
+
input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
|
914 |
+
)
|
915 |
+
else:
|
916 |
+
encoder_outputs = model.module.text_encoder(
|
917 |
+
input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
|
918 |
+
)
|
919 |
+
# we optionally project last_hidden_state to avoid recomputing every time
|
920 |
+
encoder_hidden_states = encoder_outputs.last_hidden_state
|
921 |
+
if (
|
922 |
+
config.text_encoder.hidden_size != config.decoder.hidden_size
|
923 |
+
and config.decoder.cross_attention_hidden_size is None
|
924 |
+
):
|
925 |
+
encoder_hidden_states = (
|
926 |
+
model.enc_to_dec_proj(encoder_hidden_states)
|
927 |
+
if training_args.parallel_mode.value != "distributed"
|
928 |
+
else model.module.enc_to_dec_proj(encoder_hidden_states)
|
929 |
+
)
|
930 |
+
|
931 |
+
if batch.get("attention_mask", None) is not None:
|
932 |
+
encoder_hidden_states = encoder_hidden_states * batch.get("attention_mask", None)[..., None]
|
933 |
+
|
934 |
+
encoder_outputs.last_hidden_state = encoder_hidden_states
|
935 |
+
batch["encoder_outputs"] = encoder_outputs
|
936 |
+
|
937 |
+
outputs = model(**batch, loss_reduction="sum")
|
938 |
+
# CE (data) loss
|
939 |
+
ce_loss = (outputs.loss * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch
|
940 |
+
|
941 |
+
metrics = {"loss": ce_loss}
|
942 |
+
|
943 |
+
# per CE loss
|
944 |
+
per_codebook_losses = outputs.per_codebook_losses
|
945 |
+
metrics.update({f"codebook_{i}_loss": ((l * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch) for (i,l) in enumerate(per_codebook_losses)})
|
946 |
+
return ce_loss, metrics
|
947 |
+
|
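A minimal numeric sketch of the loss scaling used in `train_step`, assuming Accelerate divides the backward loss by `gradient_accumulation_steps` and DDP averages gradients across processes; all numbers below are made up.

```python
# Each micro-batch returns a *summed* CE loss; multiplying by
# grad_accum * num_processes and dividing by the global token count makes the
# resulting gradient a per-token mean over the whole global batch.
gradient_accumulation_steps = 2
num_processes = 2
num_items_in_batch = 400                   # non-padding tokens, summed over all micro-batches and processes

summed_losses = [50.0, 45.0, 55.0, 50.0]   # one summed loss per micro-batch (2 per process)
scaled = [l * gradient_accumulation_steps * num_processes / num_items_in_batch for l in summed_losses]

# Effective loss after accumulation (/grad_accum) and gradient averaging (/processes):
effective = sum(scaled) / (gradient_accumulation_steps * num_processes)
print(effective, sum(summed_losses) / num_items_in_batch)  # both 0.5, i.e. a per-token mean
```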
948 |
+
# Define eval fn
|
949 |
+
def eval_step(
|
950 |
+
batch,
|
951 |
+
accelerator,
|
952 |
+
autocast_kwargs,
|
953 |
+
):
|
954 |
+
eval_model = model if not training_args.torch_compile else model._orig_mod
|
955 |
+
|
956 |
+
if mixed_precision == "fp16":
|
957 |
+
# fp16 doesn't work with T5-like models
|
958 |
+
with accelerator.autocast(autocast_handler=autocast_kwargs):
|
959 |
+
if training_args.parallel_mode.value != "distributed":
|
960 |
+
encoder_outputs = model.text_encoder(
|
961 |
+
input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
|
962 |
+
)
|
963 |
+
else:
|
964 |
+
encoder_outputs = model.module.text_encoder(
|
965 |
+
input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
|
966 |
+
)
|
967 |
+
# we optionally project last_hidden_state to avoid recomputing every time
|
968 |
+
encoder_hidden_states = encoder_outputs.last_hidden_state
|
969 |
+
if (
|
970 |
+
config.text_encoder.hidden_size != config.decoder.hidden_size
|
971 |
+
and config.decoder.cross_attention_hidden_size is None
|
972 |
+
):
|
973 |
+
encoder_hidden_states = (
|
974 |
+
model.enc_to_dec_proj(encoder_hidden_states)
|
975 |
+
if training_args.parallel_mode.value != "distributed"
|
976 |
+
else model.module.enc_to_dec_proj(encoder_hidden_states)
|
977 |
+
)
|
978 |
+
|
979 |
+
if batch.get("attention_mask", None) is not None:
|
980 |
+
encoder_hidden_states = encoder_hidden_states * batch.get("attention_mask", None)[..., None]
|
981 |
+
|
982 |
+
encoder_outputs.last_hidden_state = encoder_hidden_states
|
983 |
+
batch["encoder_outputs"] = encoder_outputs
|
984 |
+
|
985 |
+
with torch.no_grad():
|
986 |
+
outputs = eval_model(**batch)
|
987 |
+
# CE (data) loss
|
988 |
+
ce_loss = outputs.loss
|
989 |
+
metrics = {"loss": ce_loss}
|
990 |
+
|
991 |
+
# per CE loss
|
992 |
+
per_codebook_losses = outputs.per_codebook_losses
|
993 |
+
metrics.update({f"codebook_{i}_loss": l for (i,l) in enumerate(per_codebook_losses)})
|
994 |
+
return metrics
|
995 |
+
|
996 |
+
def generate_step(batch, accelerator):
|
997 |
+
batch.pop("decoder_attention_mask", None)
|
998 |
+
eval_model = accelerator.unwrap_model(model, keep_fp32_wrapper=True)
|
999 |
+
if training_args.torch_compile:
|
1000 |
+
# if the model is compiled, we use the original model because compile is not compatible with .generate
|
1001 |
+
eval_model = model._orig_mod
|
1002 |
+
|
1003 |
+
# since we might have loaded the weights in fp32, we have to autocast to ensure FA2 weights are in half-precision.
|
1004 |
+
# with accelerator.autocast(autocast_handler=AutocastKwargs(enabled=(attn_implementation=="flash_attention_2"))):
|
1005 |
+
output_audios = eval_model.generate(**batch, **gen_kwargs)
|
1006 |
+
output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
|
1007 |
+
return output_audios
|
1008 |
+
|
1009 |
+
model.train()
|
1010 |
+
|
1011 |
+
total_batched_samples = resume_step if resume_step is not None else 0
|
1012 |
+
for epoch in range(epochs_trained, num_epochs):
|
1013 |
+
with accelerator.local_main_process_first():
|
1014 |
+
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
|
1015 |
+
sampler = None
|
1016 |
+
if training_args.group_by_length:
|
1017 |
+
sampler = LengthGroupedSampler(train_batch_size, lengths=vectorized_datasets["train"]["target_length"])
|
1018 |
+
train_dataloader = DataLoader(
|
1019 |
+
vectorized_datasets["train"],
|
1020 |
+
collate_fn=data_collator,
|
1021 |
+
batch_size=per_device_train_batch_size,
|
1022 |
+
sampler=sampler,
|
1023 |
+
shuffle=not training_args.group_by_length,
|
1024 |
+
num_workers=training_args.dataloader_num_workers,
|
1025 |
+
pin_memory=training_args.dataloader_pin_memory,
|
1026 |
+
)
|
1027 |
+
train_dataloader = accelerator.prepare(train_dataloader)
|
1028 |
+
if hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDataset):
|
1029 |
+
train_dataloader.dataset.set_epoch(epoch)
|
1030 |
+
|
1031 |
+
if resume_step is not None:
|
1032 |
+
# Skip the first N batches in the dataloader when resuming from a checkpoint
|
1033 |
+
logger.info(f" Skip first {resume_step} batches")
|
1034 |
+
train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
|
1035 |
+
resume_step = None
|
1036 |
+
accelerator.wait_for_everyone()
|
1037 |
+
|
1038 |
+
# We chunk the epoch iterator into groups of `gradient_accumulation_steps` batches
|
1039 |
+
train_iterator = iter(train_dataloader)
|
1040 |
+
num_steps_in_epoch = len(train_dataloader)
|
1041 |
+
remainder = num_steps_in_epoch % gradient_accumulation_steps
|
1042 |
+
remainder = remainder if remainder != 0 else gradient_accumulation_steps
|
1043 |
+
total_updates = math.ceil(num_steps_in_epoch / gradient_accumulation_steps)
|
1044 |
+
|
1045 |
+
update_step = -1
|
1046 |
+
for _ in range(total_updates):
|
1047 |
+
update_step += 1
|
1048 |
+
|
1049 |
+
# preload the total batch per step
|
1050 |
+
batch_samples = []
|
1051 |
+
num_batches_in_step = gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
|
1052 |
+
for _ in range(num_batches_in_step):
|
1053 |
+
batch_samples += [next(train_iterator)]
|
1054 |
+
|
1055 |
+
# get the number of items in the batch, i.e. tokens different from BOS, EOS and -100
|
1056 |
+
num_items_in_batch = sum([(batch["labels"].ne(audio_encoder_bos_token_id) | batch["labels"].ne(-100) | batch["labels"].ne(audio_encoder_eos_token_id)).sum((0,1))[0] for batch in batch_samples])
|
1057 |
+
num_items_in_batch = accelerator.gather(num_items_in_batch).sum().item()
|
1058 |
+
|
1059 |
+
# losses = []
|
1060 |
+
for i,batch in enumerate(batch_samples):
|
1061 |
+
total_batched_samples += 1
|
1062 |
+
ctx = model.no_sync if (i < len(batch_samples) - 1 and accelerator.num_processes > 1) else contextlib.nullcontext
|
1063 |
+
|
1064 |
+
with ctx():
|
1065 |
+
loss, train_metric = train_step(batch, accelerator, autocast_kwargs, num_items_in_batch, gradient_accumulation_steps)
|
1066 |
+
accelerator.backward(loss)
|
1067 |
+
# losses.append(loss.detach())
|
1068 |
+
|
1069 |
+
grad_norm = accelerator.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
|
1070 |
+
optimizer.step()
|
1071 |
+
lr_scheduler.step()
|
1072 |
+
optimizer.zero_grad()
|
1073 |
+
|
1074 |
+
# The accelerator has performed an optimization step behind the scenes
|
1075 |
+
steps_trained_progress_bar.update(1)
|
1076 |
+
cur_step += 1
|
1077 |
+
|
1078 |
+
# losses = accelerator.gather(sum(losses)).sum().item() / (accelerator.num_processes * gradient_accumulation_steps)
|
1079 |
+
|
1080 |
+
if cur_step % training_args.logging_steps == 0:
|
1081 |
+
steps_trained_progress_bar.write(
|
1082 |
+
f"Step... ({cur_step} / {total_train_steps} | Loss:"
|
1083 |
+
f" {train_metric['loss']}, Learning Rate:"
|
1084 |
+
f" {lr_scheduler.get_last_lr()[0]})"
|
1085 |
+
)
|
1086 |
+
train_metric["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
|
1087 |
+
log_metric(
|
1088 |
+
accelerator,
|
1089 |
+
metrics=train_metric,
|
1090 |
+
learning_rate=lr_scheduler.get_last_lr()[0],
|
1091 |
+
train_time=train_time + time.time() - train_start,
|
1092 |
+
step=cur_step,
|
1093 |
+
epoch=epoch,
|
1094 |
+
prefix="train",
|
1095 |
+
)
|
1096 |
+
|
1097 |
+
# save checkpoint and weights after each save_steps and at the end of training
|
1098 |
+
if (cur_step % training_args.save_steps == 0) or cur_step == total_train_steps:
|
1099 |
+
intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
|
1100 |
+
# safe_serialization=False to avoid shared tensors saving issue (TODO(YL): it's a temporary fix)
|
1101 |
+
# https://github.com/huggingface/transformers/issues/27293#issuecomment-1872560074
|
1102 |
+
accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
|
1103 |
+
accelerator.wait_for_everyone()
|
1104 |
+
if accelerator.is_main_process:
|
1105 |
+
rotate_checkpoints(
|
1106 |
+
training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger
|
1107 |
+
)
|
1108 |
+
|
1109 |
+
if cur_step == total_train_steps:
|
1110 |
+
# un-wrap student model for save
|
1111 |
+
unwrapped_model = accelerator.unwrap_model(model)
|
1112 |
+
unwrapped_model.save_pretrained(training_args.output_dir)
|
1113 |
+
|
1114 |
+
if training_args.push_to_hub:
|
1115 |
+
api.upload_folder(
|
1116 |
+
repo_id=repo_id,
|
1117 |
+
folder_path=training_args.output_dir,
|
1118 |
+
commit_message=f"Saving train state of step {cur_step}",
|
1119 |
+
run_as_future=True,
|
1120 |
+
)
|
1121 |
+
accelerator.wait_for_everyone()
|
1122 |
+
|
1123 |
+
if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
|
1124 |
+
train_time += time.time() - train_start
|
1125 |
+
# ======================== Evaluating ==============================
|
1126 |
+
model.eval()
|
1127 |
+
eval_metrics = []
|
1128 |
+
eval_preds = []
|
1129 |
+
eval_descriptions = []
|
1130 |
+
eval_prompts = []
|
1131 |
+
eval_start = time.time()
|
1132 |
+
|
1133 |
+
# release training input batch
|
1134 |
+
batch = release_memory(batch)
|
1135 |
+
|
1136 |
+
validation_dataloader = DataLoader(
|
1137 |
+
vectorized_datasets["eval"],
|
1138 |
+
collate_fn=data_collator,
|
1139 |
+
batch_size=per_device_eval_batch_size,
|
1140 |
+
drop_last=False,
|
1141 |
+
num_workers=training_args.eval_dataloader_num_workers,
|
1142 |
+
pin_memory=training_args.dataloader_pin_memory,
|
1143 |
+
)
|
1144 |
+
validation_dataloader = accelerator.prepare(validation_dataloader)
|
1145 |
+
|
1146 |
+
for batch in tqdm(
|
1147 |
+
validation_dataloader,
|
1148 |
+
desc=f"Evaluating - Inference ...",
|
1149 |
+
position=2,
|
1150 |
+
disable=not accelerator.is_local_main_process,
|
1151 |
+
):
|
1152 |
+
# Model forward
|
1153 |
+
eval_metric = eval_step(batch, accelerator, autocast_kwargs)
|
1154 |
+
eval_metric = accelerator.gather_for_metrics(eval_metric)
|
1155 |
+
eval_metric = {key: val.unsqueeze(0) if val.ndim == 0 else val for (key,val) in eval_metric.items()}
|
1156 |
+
eval_metrics.append(eval_metric)
|
1157 |
+
|
1158 |
+
if training_args.predict_with_generate and (cur_step % eval_generation_steps == 0 or cur_step == total_train_steps):
|
1159 |
+
validation_dataloader = DataLoader(
|
1160 |
+
vectorized_datasets["eval"],
|
1161 |
+
collate_fn=data_collator,
|
1162 |
+
batch_size=per_device_eval_batch_size,
|
1163 |
+
drop_last=False,
|
1164 |
+
num_workers=training_args.eval_dataloader_num_workers,
|
1165 |
+
pin_memory=training_args.dataloader_pin_memory,
|
1166 |
+
)
|
1167 |
+
validation_dataloader = accelerator.prepare(validation_dataloader)
|
1168 |
+
# generation
|
1169 |
+
for batch in tqdm(
|
1170 |
+
validation_dataloader,
|
1171 |
+
desc=f"Evaluating - Generation ...",
|
1172 |
+
position=2,
|
1173 |
+
disable=not accelerator.is_local_main_process,
|
1174 |
+
):
|
1175 |
+
generated_audios = generate_step(batch, accelerator)
|
1176 |
+
# Gather all predictions and targets
|
1177 |
+
generated_audios, input_ids, prompts = accelerator.pad_across_processes(
|
1178 |
+
(generated_audios, batch["input_ids"], batch["prompt_input_ids"]), dim=1, pad_index=0
|
1179 |
+
)
|
1180 |
+
generated_audios, input_ids, prompts = accelerator.gather_for_metrics(
|
1181 |
+
(generated_audios, input_ids, prompts)
|
1182 |
+
)
|
1183 |
+
eval_preds.extend(generated_audios.to("cpu"))
|
1184 |
+
eval_descriptions.extend(input_ids.to("cpu"))
|
1185 |
+
eval_prompts.extend(prompts.to("cpu"))
|
1186 |
+
|
1187 |
+
eval_time = time.time() - eval_start
|
1188 |
+
# normalize eval metrics
|
1189 |
+
eval_metrics = {
|
1190 |
+
key: torch.mean(torch.cat([d[key] for d in eval_metrics])).to("cpu") for key in eval_metrics[0]
|
1191 |
+
}
|
1192 |
+
|
1193 |
+
# compute metrics
|
1194 |
+
metrics_desc = ""
|
1195 |
+
if training_args.predict_with_generate and (cur_step % eval_generation_steps == 0 or cur_step == total_train_steps):
|
1196 |
+
if accelerator.is_local_main_process:
|
1197 |
+
(
|
1198 |
+
metric_values,
|
1199 |
+
pred_descriptions,
|
1200 |
+
pred_prompts,
|
1201 |
+
audios,
|
1202 |
+
transcriptions,
|
1203 |
+
si_sdr_measures,
|
1204 |
+
) = compute_metrics(
|
1205 |
+
eval_preds,
|
1206 |
+
eval_descriptions,
|
1207 |
+
eval_prompts,
|
1208 |
+
accelerator.device,
|
1209 |
+
training_args.compute_clap_similarity_metric,
|
1210 |
+
training_args.compute_noise_level_metric,
|
1211 |
+
training_args.noise_level_to_compute_clean_wer,
|
1212 |
+
)
|
1213 |
+
eval_metrics.update(metric_values)
|
1214 |
+
metrics_desc = " ".join([f"Eval {key}: {value} |" for key, value in metric_values.items()])
|
1215 |
+
if "wandb" in training_args.report_to:
|
1216 |
+
log_pred(
|
1217 |
+
accelerator,
|
1218 |
+
pred_descriptions,
|
1219 |
+
pred_prompts,
|
1220 |
+
transcriptions,
|
1221 |
+
audios,
|
1222 |
+
si_sdr_measures,
|
1223 |
+
sampling_rate=sampling_rate,
|
1224 |
+
step=cur_step,
|
1225 |
+
prefix="eval",
|
1226 |
+
)
|
1227 |
+
accelerator.wait_for_everyone()
|
1228 |
+
|
1229 |
+
# Print metrics and update progress bar
|
1230 |
+
if accelerator.is_local_main_process:
|
1231 |
+
steps_trained_progress_bar.write(
|
1232 |
+
f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |"
|
1233 |
+
f" {metrics_desc})"
|
1234 |
+
)
|
1235 |
+
|
1236 |
+
log_metric(
|
1237 |
+
accelerator,
|
1238 |
+
metrics=eval_metrics,
|
1239 |
+
train_time=eval_time,
|
1240 |
+
step=cur_step,
|
1241 |
+
epoch=epoch,
|
1242 |
+
prefix="eval",
|
1243 |
+
)
|
1244 |
+
|
1245 |
+
# release eval batch and metrics
|
1246 |
+
eval_metrics, eval_preds, eval_descriptions, eval_prompts, batch, eval_metric = release_memory(
|
1247 |
+
eval_metrics, eval_preds, eval_descriptions, eval_prompts, batch, eval_metric
|
1248 |
+
)
|
1249 |
+
if training_args.predict_with_generate and (cur_step % eval_generation_steps == 0 or cur_step == total_train_steps):
|
1250 |
+
generated_audios, input_ids, prompts = release_memory(generated_audios, input_ids, prompts)
|
1251 |
+
|
1252 |
+
# train mode
|
1253 |
+
model.train()
|
1254 |
+
|
1255 |
+
# flush the train metrics
|
1256 |
+
train_start = time.time()
|
1257 |
+
|
1258 |
+
# break condition
|
1259 |
+
if cur_step == total_train_steps:
|
1260 |
+
continue_training = False
|
1261 |
+
break
|
1262 |
+
|
1263 |
+
if not continue_training:
|
1264 |
+
break
|
1265 |
+
|
1266 |
+
accelerator.end_training()
|
1267 |
+
|
1268 |
+
|
1269 |
+
if __name__ == "__main__":
|
1270 |
+
main()
|
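The evaluation loop in `finetune_captts.py` above collapses the per-batch metric dictionaries into one value per key before logging. A minimal standalone sketch of that reduction pattern, using made-up `metric_dicts` values rather than anything produced by the script:

```python
import torch

# each eval step contributes a dict of tensors, e.g. {"loss": tensor([...]), ...}
metric_dicts = [
    {"loss": torch.tensor([0.9, 1.1])},
    {"loss": torch.tensor([1.0])},
]

# same reduction as in the training script: concatenate per key, then average
averaged = {key: torch.mean(torch.cat([d[key] for d in metric_dicts])) for key in metric_dicts[0]}
print(averaged)  # {'loss': tensor(1.)}
```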
capspeech/ar/training/finetune_capttsse.py
ADDED
@@ -0,0 +1,1267 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
""" Train Parler-TTS using 🤗 Accelerate"""
|
18 |
+
|
19 |
+
import logging
|
20 |
+
import os
|
21 |
+
import re
|
22 |
+
import sys
|
23 |
+
import time
|
24 |
+
import math
|
25 |
+
import contextlib
|
26 |
+
from multiprocess import set_start_method
|
27 |
+
from datetime import timedelta
|
28 |
+
import inspect
|
29 |
+
from tqdm import tqdm
|
30 |
+
from pathlib import Path
|
31 |
+
import wandb
|
32 |
+
|
33 |
+
import torch
|
34 |
+
from torch.utils.data import DataLoader
|
35 |
+
|
36 |
+
import datasets
|
37 |
+
from datasets import DatasetDict, Dataset, IterableDataset, concatenate_datasets
|
38 |
+
|
39 |
+
from huggingface_hub import HfApi
|
40 |
+
|
41 |
+
import transformers
|
42 |
+
from transformers import AutoFeatureExtractor, AutoTokenizer, HfArgumentParser
|
43 |
+
from transformers.trainer_pt_utils import LengthGroupedSampler
|
44 |
+
from transformers.optimization import get_scheduler
|
45 |
+
from transformers.utils import send_example_telemetry
|
46 |
+
|
47 |
+
|
48 |
+
from accelerate import Accelerator, skip_first_batches
|
49 |
+
from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin, DistributedDataParallelKwargs
|
50 |
+
from accelerate.utils.memory import release_memory
|
51 |
+
|
52 |
+
from parler_tts import (
|
53 |
+
ParlerTTSConfig,
|
54 |
+
ParlerTTSForConditionalGeneration,
|
55 |
+
build_delay_pattern_mask,
|
56 |
+
)
|
57 |
+
|
58 |
+
from training.utils import (
|
59 |
+
get_last_checkpoint,
|
60 |
+
rotate_checkpoints,
|
61 |
+
log_pred,
|
62 |
+
log_metric,
|
63 |
+
load_all_codec_checkpoints,
|
64 |
+
save_codec_checkpoint,
|
65 |
+
get_last_codec_checkpoint_step,
|
66 |
+
)
|
67 |
+
from training.arguments_capttsse import ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments
|
68 |
+
from training.data_capttsse import load_multiple_datasets, DataCollatorParlerTTSWithPadding, DataCollatorEncodecWithPadding
|
69 |
+
from training.eval import clap_similarity, wer, si_sdr
|
70 |
+
|
71 |
+
logger = logging.getLogger(__name__)
|
72 |
+
|
73 |
+
|
74 |
+
def main():
|
75 |
+
# See all possible arguments in src/transformers/training_args.py
|
76 |
+
# or by passing the --help flag to this script.
|
77 |
+
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
78 |
+
|
79 |
+
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments))
|
80 |
+
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
81 |
+
# If we pass only one argument to the script and it's the path to a json file,
|
82 |
+
# let's parse it to get our arguments.
|
83 |
+
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
84 |
+
else:
|
85 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
86 |
+
|
87 |
+
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
|
88 |
+
# information sent is the one passed as arguments along with your Python/PyTorch versions.
|
89 |
+
send_example_telemetry("run_parler_tts", model_args, data_args)
|
90 |
+
|
91 |
+
if data_args.wandb_key is not None:
|
92 |
+
wandb.login(key=data_args.wandb_key)
|
93 |
+
|
94 |
+
if training_args.dtype == "float16":
|
95 |
+
mixed_precision = "fp16"
|
96 |
+
torch_dtype = torch.float16
|
97 |
+
elif training_args.dtype == "bfloat16":
|
98 |
+
mixed_precision = "bf16"
|
99 |
+
torch_dtype = torch.bfloat16
|
100 |
+
else:
|
101 |
+
mixed_precision = "no"
|
102 |
+
torch_dtype = torch.float32
|
103 |
+
|
104 |
+
if data_args.pad_to_max_length and (
|
105 |
+
data_args.max_duration_in_seconds is None
|
106 |
+
or data_args.max_prompt_token_length is None
|
107 |
+
or data_args.max_description_token_length is None
|
108 |
+
):
|
109 |
+
raise ValueError(
|
110 |
+
"`pad_to_max_length` is `True` but one of the following parameters has not been set: `max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`"
|
111 |
+
)
|
112 |
+
|
113 |
+
padding = "max_length" if data_args.pad_to_max_length else "longest"
|
114 |
+
|
115 |
+
####### A. Preparation
|
116 |
+
kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=120)), DistributedDataParallelKwargs(find_unused_parameters=False)]
|
117 |
+
|
118 |
+
accelerator = Accelerator(
|
119 |
+
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
|
120 |
+
mixed_precision=mixed_precision,
|
121 |
+
log_with=training_args.report_to,
|
122 |
+
project_dir=training_args.output_dir,
|
123 |
+
kwargs_handlers=kwargs_handlers,
|
124 |
+
)
|
125 |
+
|
126 |
+
accelerator.init_trackers(
|
127 |
+
project_name=data_args.wandb_project,
|
128 |
+
config={
|
129 |
+
"learning_rate": training_args.learning_rate,
|
130 |
+
"model_name_or_path": model_args.model_name_or_path,
|
131 |
+
"num_train_epochs": training_args.num_train_epochs,
|
132 |
+
"gradient_accumulation_steps": training_args.gradient_accumulation_steps,
|
133 |
+
"per_device_train_batch_size": training_args.per_device_train_batch_size,
|
134 |
+
"global_batch_size": training_args.per_device_train_batch_size * accelerator.num_processes,
|
135 |
+
"mixed_precision": mixed_precision,
|
136 |
+
"lr_scheduler_type": training_args.lr_scheduler_type,
|
137 |
+
"warmup_steps": training_args.warmup_steps,
|
138 |
+
"freeze_text_encoder": model_args.freeze_text_encoder,
|
139 |
+
"max_duration_in_seconds": data_args.max_duration_in_seconds,
|
140 |
+
"weight_decay": training_args.weight_decay,
|
141 |
+
"adam_beta1": training_args.adam_beta1,
|
142 |
+
"adam_beta2": training_args.adam_beta2,
|
143 |
+
"temperature": model_args.temperature,
|
144 |
+
},
|
145 |
+
init_kwargs={"wandb": {"name": data_args.wandb_run_name}} if data_args.wandb_run_name else {},
|
146 |
+
)
|
147 |
+
|
148 |
+
# Detecting last checkpoint and eventually continue from last checkpoint
|
149 |
+
last_checkpoint = None
|
150 |
+
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
|
151 |
+
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
152 |
+
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
153 |
+
raise ValueError(
|
154 |
+
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
155 |
+
"Use --overwrite_output_dir to overcome."
|
156 |
+
)
|
157 |
+
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
158 |
+
logger.info(
|
159 |
+
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
|
160 |
+
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
161 |
+
)
|
162 |
+
|
163 |
+
# Setup logging
|
164 |
+
logging.basicConfig(
|
165 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
166 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
167 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
168 |
+
)
|
169 |
+
logger.setLevel(logging.INFO if accelerator.is_main_process else logging.WARN)
|
170 |
+
|
171 |
+
# Log a small summary on each process
|
172 |
+
logger.warning(
|
173 |
+
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
|
174 |
+
f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
|
175 |
+
)
|
176 |
+
|
177 |
+
# Set the verbosity to info of the Transformers logger (on main process only)
|
178 |
+
if accelerator.is_local_main_process:
|
179 |
+
datasets.utils.logging.set_verbosity_warning()
|
180 |
+
transformers.utils.logging.set_verbosity_info()
|
181 |
+
else:
|
182 |
+
datasets.utils.logging.set_verbosity_error()
|
183 |
+
transformers.utils.logging.set_verbosity_error()
|
184 |
+
|
185 |
+
logger.info("Training/evaluation parameters %s", training_args)
|
186 |
+
|
187 |
+
# Set seed before initializing model.
|
188 |
+
set_seed(training_args.seed)
|
189 |
+
num_workers = data_args.preprocessing_num_workers
|
190 |
+
|
191 |
+
# 1. First, let's instantiate the feature extractor, tokenizers and model
|
192 |
+
# Note for distributed training, the .from_pretrained methods guarantee that only
|
193 |
+
# one local process can concurrently download model & vocab.
|
194 |
+
|
195 |
+
# load feature extractor
|
196 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
197 |
+
model_args.feature_extractor_name or model_args.model_name_or_path,
|
198 |
+
cache_dir=model_args.cache_dir,
|
199 |
+
token=data_args.token,
|
200 |
+
trust_remote_code=data_args.trust_remote_code,
|
201 |
+
)
|
202 |
+
sampling_rate = feature_extractor.sampling_rate
|
203 |
+
|
204 |
+
# load prompt tokenizer
|
205 |
+
prompt_tokenizer = AutoTokenizer.from_pretrained(
|
206 |
+
model_args.prompt_tokenizer_name or model_args.description_tokenizer_name or model_args.model_name_or_path,
|
207 |
+
cache_dir=model_args.cache_dir,
|
208 |
+
token=data_args.token,
|
209 |
+
trust_remote_code=data_args.trust_remote_code,
|
210 |
+
use_fast=model_args.use_fast_tokenizer,
|
211 |
+
padding_side=model_args.prompt_padding_side,
|
212 |
+
)
|
213 |
+
|
214 |
+
# load description tokenizer
|
215 |
+
description_tokenizer = AutoTokenizer.from_pretrained(
|
216 |
+
model_args.description_tokenizer_name or model_args.model_name_or_path,
|
217 |
+
cache_dir=model_args.cache_dir,
|
218 |
+
token=data_args.token,
|
219 |
+
trust_remote_code=data_args.trust_remote_code,
|
220 |
+
use_fast=model_args.use_fast_tokenizer,
|
221 |
+
)
|
222 |
+
|
223 |
+
if model_args.use_fast_tokenizer:
|
224 |
+
logger.warning(
|
225 |
+
"Disabling fast tokenizer warning: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3231-L3235"
|
226 |
+
)
|
227 |
+
prompt_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
|
228 |
+
description_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
|
229 |
+
|
230 |
+
# 2. Now, let's load the dataset
|
231 |
+
|
232 |
+
if data_args.save_to_disk is not None:
|
233 |
+
os.makedirs(data_args.save_to_disk, exist_ok=True)
|
234 |
+
|
235 |
+
# assume that the dataset has been saved to `save_to_disk` if the latter is not empty
|
236 |
+
dataset_was_precomputed = len(os.listdir(data_args.save_to_disk)) > 0
|
237 |
+
if dataset_was_precomputed:
|
238 |
+
with accelerator.local_main_process_first():
|
239 |
+
vectorized_datasets = datasets.load_from_disk(data_args.save_to_disk)
|
240 |
+
else:
|
241 |
+
raw_datasets = DatasetDict()
|
242 |
+
|
243 |
+
columns_to_keep = {
|
244 |
+
"target_audio_column_name": data_args.target_audio_column_name,
|
245 |
+
"prompt_column_name": data_args.prompt_column_name,
|
246 |
+
"source": data_args.source_column_name,
|
247 |
+
}
|
248 |
+
if data_args.description_column_name is not None:
|
249 |
+
columns_to_keep["description_column_name"] = data_args.description_column_name
|
250 |
+
|
251 |
+
if training_args.do_train:
|
252 |
+
raw_datasets["train"] = load_multiple_datasets(
|
253 |
+
accelerator,
|
254 |
+
data_args.train_dataset_name,
|
255 |
+
splits=data_args.train_split_name,
|
256 |
+
dataset_samples=data_args.train_dataset_samples,
|
257 |
+
seed=training_args.seed,
|
258 |
+
cache_dir=model_args.cache_dir,
|
259 |
+
num_proc=data_args.preprocessing_num_workers,
|
260 |
+
id_column_name=data_args.id_column_name,
|
261 |
+
columns_to_keep=columns_to_keep.values(),
|
262 |
+
prompt_column_name=data_args.prompt_column_name,
|
263 |
+
audio_column_name=data_args.target_audio_column_name,
|
264 |
+
sampling_rate=sampling_rate,
|
265 |
+
logger=logger,
|
266 |
+
librittsrmix_dir=data_args.librittsrmix_dir,
|
267 |
+
# streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
|
268 |
+
)
|
269 |
+
|
270 |
+
for key in columns_to_keep:
|
271 |
+
if columns_to_keep[key] not in raw_datasets["train"].column_names:
|
272 |
+
raise ValueError(
|
273 |
+
f"--{key} '{columns_to_keep[key]}' not found in dataset '{data_args.train_dataset_name}'."
|
274 |
+
f" Make sure to set `--{key}` to the correct audio column - one of"
|
275 |
+
f" {', '.join(raw_datasets['train'].column_names)}."
|
276 |
+
)
|
277 |
+
|
278 |
+
if data_args.max_train_samples is not None:
|
279 |
+
raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
|
280 |
+
|
281 |
+
if training_args.do_eval:
|
282 |
+
raw_datasets["eval"] = load_multiple_datasets(
|
283 |
+
accelerator,
|
284 |
+
data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
|
285 |
+
splits=data_args.eval_split_name,
|
286 |
+
cache_dir=model_args.cache_dir,
|
287 |
+
num_proc=data_args.preprocessing_num_workers,
|
288 |
+
id_column_name=data_args.id_column_name,
|
289 |
+
columns_to_keep=columns_to_keep.values(),
|
290 |
+
prompt_column_name=data_args.prompt_column_name,
|
291 |
+
audio_column_name=data_args.target_audio_column_name,
|
292 |
+
sampling_rate=sampling_rate,
|
293 |
+
logger=logger,
|
294 |
+
librittsrmix_dir=data_args.librittsrmix_dir,
|
295 |
+
# streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
|
296 |
+
)
|
297 |
+
|
298 |
+
if data_args.max_eval_samples is not None:
|
299 |
+
with accelerator.local_main_process_first():
|
300 |
+
raw_datasets["eval"] = (
|
301 |
+
raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
|
302 |
+
)
|
303 |
+
|
304 |
+
# 3. Next, let's load the config.
|
305 |
+
config = ParlerTTSConfig.from_pretrained(
|
306 |
+
model_args.model_name_or_path,
|
307 |
+
cache_dir=model_args.cache_dir,
|
308 |
+
token=data_args.token,
|
309 |
+
trust_remote_code=data_args.trust_remote_code,
|
310 |
+
)
|
311 |
+
|
312 |
+
if training_args.codebook_weights is not None and len(training_args.codebook_weights) != config.decoder.num_codebooks:
|
313 |
+
raise ValueError(f"`codebook_weights` has length {len(training_args.codebook_weights)} when it should be of length {config.decoder.num_codebooks}.")
|
314 |
+
|
315 |
+
# update pad token id and decoder_start_token_id
|
316 |
+
config.decoder.update(
|
317 |
+
{
|
318 |
+
"cross_attention_implementation_strategy": model_args.cross_attention_implementation_strategy
|
319 |
+
if model_args.cross_attention_implementation_strategy is not None
|
320 |
+
else None,
|
321 |
+
"codebook_weights": training_args.codebook_weights if training_args.codebook_weights is not None else config.decoder.codebook_weights
|
322 |
+
}
|
323 |
+
)
|
324 |
+
config.update(
|
325 |
+
{
|
326 |
+
"pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id,
|
327 |
+
"decoder_start_token_id": model_args.decoder_start_token_id
|
328 |
+
if model_args.decoder_start_token_id is not None
|
329 |
+
else config.decoder_start_token_id,
|
330 |
+
}
|
331 |
+
)
|
332 |
+
|
333 |
+
with open("events.txt", "r") as f:
|
334 |
+
events = [line.strip() for line in f]
|
335 |
+
events = ["<"+event.lower().replace(" ", "_")+">" for event in events]
|
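# Illustration (hypothetical entry): a line such as "Dog Bark" in events.txt becomes the special token "<dog_bark>".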
336 |
+
events.append("<B_start>")
|
337 |
+
events.append("<B_end>")
|
338 |
+
events.append("<I_start>")
|
339 |
+
events.append("<I_end>")
|
340 |
+
|
341 |
+
special_tokens = {"additional_special_tokens": events}
|
342 |
+
prompt_tokenizer.add_special_tokens(special_tokens)
|
343 |
+
description_tokenizer.add_special_tokens(special_tokens)
|
344 |
+
padded_vocab_size = ((len(prompt_tokenizer) + 127) // 128) * 128
|
345 |
+
config.vocab_size = padded_vocab_size
|
346 |
+
|
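# Worked example (hypothetical tokenizer size): with 32,903 tokens after adding the event tags,
# ((32903 + 127) // 128) * 128 == 33024, i.e. the embedding table is padded up to the next multiple of 128.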
347 |
+
# create model
|
348 |
+
model = ParlerTTSForConditionalGeneration.from_pretrained(
|
349 |
+
model_args.model_name_or_path,
|
350 |
+
ignore_mismatched_sizes=True,
|
351 |
+
cache_dir=model_args.cache_dir,
|
352 |
+
config=config,
|
353 |
+
token=data_args.token,
|
354 |
+
trust_remote_code=data_args.trust_remote_code,
|
355 |
+
attn_implementation={"decoder": model_args.attn_implementation, "text_encoder": "eager"},
|
356 |
+
)
|
357 |
+
model.text_encoder.resize_token_embeddings(padded_vocab_size)
|
358 |
+
|
359 |
+
# enable gradient checkpointing if necessary
|
360 |
+
if training_args.gradient_checkpointing:
|
361 |
+
model.gradient_checkpointing_enable()
|
362 |
+
|
363 |
+
# 4. Now we preprocess the datasets including loading the audio, resampling and normalization
|
364 |
+
# Thankfully, `datasets` takes care of automatically loading and resampling the audio,
|
365 |
+
# so that we just need to set the correct target sampling rate and normalize the input
|
366 |
+
# via the `feature_extractor`
|
367 |
+
|
368 |
+
# derive max & min input length for sample rate & max duration
|
369 |
+
sampling_rate = feature_extractor.sampling_rate
|
370 |
+
max_target_length = int(data_args.max_duration_in_seconds * sampling_rate)
|
371 |
+
min_target_length = int(data_args.min_duration_in_seconds * sampling_rate)
|
372 |
+
target_audio_column_name = data_args.target_audio_column_name
|
373 |
+
description_column_name = data_args.description_column_name
|
374 |
+
prompt_column_name = data_args.prompt_column_name
|
375 |
+
feature_extractor_input_name = feature_extractor.model_input_names[0]
|
376 |
+
audio_encoder_pad_token_id = config.decoder.pad_token_id
|
377 |
+
audio_encoder_eos_token_id = config.decoder.eos_token_id
|
378 |
+
audio_encoder_bos_token_id = model.generation_config.decoder_start_token_id
|
379 |
+
max_length = model.generation_config.max_length
|
380 |
+
num_codebooks = model.decoder.config.num_codebooks
|
381 |
+
bandwidth = model_args.bandwidth
|
382 |
+
attn_implementation = model_args.attn_implementation
|
383 |
+
|
384 |
+
# Freeze Encoders
|
385 |
+
model.freeze_encoders(model_args.freeze_text_encoder)
|
386 |
+
|
387 |
+
# Test all gather - used for warmup and avoiding timeouts
|
388 |
+
logger.debug(str(accelerator.process_index), main_process_only=False, in_order=True)
|
389 |
+
test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
|
390 |
+
gathered_tensor = accelerator.gather(test_tensor)
|
391 |
+
print("gathered_tensor", gathered_tensor)
|
392 |
+
accelerator.wait_for_everyone()
|
393 |
+
|
394 |
+
if not dataset_was_precomputed:
|
395 |
+
# Filter on text length
|
396 |
+
if description_column_name is not None and data_args.max_text_length is not None:
|
397 |
+
with accelerator.local_main_process_first():
|
398 |
+
# filter description that is shorter than max_text_length
|
399 |
+
raw_datasets = raw_datasets.filter(
|
400 |
+
lambda x: len(x) < data_args.max_text_length,
|
401 |
+
num_proc=num_workers,
|
402 |
+
input_columns=[description_column_name],
|
403 |
+
)
|
404 |
+
|
405 |
+
# Preprocessing the dataset.
|
406 |
+
# We need to tokenize the texts.
|
407 |
+
def pass_through_processors(description, prompt):
|
408 |
+
batch = {}
|
409 |
+
|
410 |
+
batch["input_ids"] = description_tokenizer(description.strip())["input_ids"]
|
411 |
+
batch["prompt_input_ids"] = prompt_tokenizer(prompt.strip())["input_ids"]
|
412 |
+
|
413 |
+
return batch
|
414 |
+
|
415 |
+
with accelerator.local_main_process_first():
|
416 |
+
# this is a trick to avoid rewriting the entire audio column, which takes ages
|
417 |
+
vectorized_datasets = raw_datasets.map(
|
418 |
+
pass_through_processors,
|
419 |
+
remove_columns=next(iter(raw_datasets.values())).column_names,
|
420 |
+
input_columns=[description_column_name, prompt_column_name],
|
421 |
+
num_proc=num_workers,
|
422 |
+
desc="preprocess datasets",
|
423 |
+
)
|
424 |
+
|
425 |
+
# We use Accelerate to perform distributed inference
|
426 |
+
# T5 doesn't support fp16
|
427 |
+
autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
|
428 |
+
|
429 |
+
# Now we encode the audio labels with encodec.
|
430 |
+
####### B. Encode audio
|
431 |
+
|
432 |
+
logger.info("*** Encode target audio with encodec ***")
|
433 |
+
|
434 |
+
# no need to prepare audio_decoder because used for inference without mixed precision
|
435 |
+
# see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
|
436 |
+
if training_args.torch_compile:
|
437 |
+
audio_decoder = accelerator.prepare_model(model.audio_encoder, evaluation_mode=True)
|
438 |
+
else:
|
439 |
+
audio_decoder = model.audio_encoder
|
440 |
+
|
441 |
+
encoder_data_collator = DataCollatorEncodecWithPadding(
|
442 |
+
feature_extractor,
|
443 |
+
audio_column_name=target_audio_column_name,
|
444 |
+
librittsrmix_dir=data_args.librittsrmix_dir,
|
445 |
+
feature_extractor_input_name=feature_extractor_input_name,
|
446 |
+
max_length=max_target_length,
|
447 |
+
padding=padding,
|
448 |
+
)
|
449 |
+
encoder_signature = set(inspect.signature(audio_decoder.forward).parameters)
|
450 |
+
|
451 |
+
def apply_audio_decoder(batch):
|
452 |
+
len_audio = batch.pop("len_audio")
|
453 |
+
audio_decoder.to(batch["input_values"].device).eval()
|
454 |
+
if bandwidth is not None:
|
455 |
+
batch["bandwidth"] = bandwidth
|
456 |
+
elif "num_quantizers" in encoder_signature:
|
457 |
+
batch["num_quantizers"] = num_codebooks
|
458 |
+
elif "num_codebooks" in encoder_signature:
|
459 |
+
batch["num_codebooks"] = num_codebooks
|
460 |
+
elif "n_quantizers" in encoder_signature:
|
461 |
+
batch["n_quantizers"] = num_codebooks
|
462 |
+
|
463 |
+
with torch.no_grad():
|
464 |
+
labels = audio_decoder.encode(**batch)["audio_codes"]
|
465 |
+
output = {}
|
466 |
+
output["len_audio"] = len_audio
|
467 |
+
# (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks)
|
468 |
+
output["labels"] = labels.squeeze(0).transpose(1, 2)
|
469 |
+
|
470 |
+
# if `pad_to_max_length`, the maximum corresponding audio length of the current batch is max_duration*sampling_rate
|
471 |
+
max_length = len_audio.max() if padding != "max_length" else max_target_length
|
472 |
+
output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / max_length
|
473 |
+
return output
|
474 |
+
|
475 |
+
# (1, codebooks, seq_len) where seq_len=1
|
476 |
+
bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id
|
477 |
+
|
478 |
+
def postprocess_dataset(labels):
|
479 |
+
# (1, codebooks, seq_len)
|
480 |
+
labels = torch.tensor(labels).unsqueeze(0)
|
481 |
+
# add bos
|
482 |
+
labels = torch.cat([bos_labels, labels], dim=-1)
|
483 |
+
|
484 |
+
labels, delay_pattern_mask = build_delay_pattern_mask(
|
485 |
+
labels,
|
486 |
+
bos_token_id=audio_encoder_bos_token_id,
|
487 |
+
pad_token_id=audio_encoder_eos_token_id,
|
488 |
+
max_length=labels.shape[-1] + num_codebooks,
|
489 |
+
num_codebooks=num_codebooks,
|
490 |
+
)
|
491 |
+
|
492 |
+
# the first ids of the delay pattern mask are precisely labels, we use the rest of the labels mask
|
493 |
+
# to take care of EOS
|
494 |
+
# we want labels to look like this:
|
495 |
+
# - [B, a, b, E, E, E, E]
|
496 |
+
# - [B, B, c, d, E, E, E]
|
497 |
+
# - [B, B, B, e, f, E, E]
|
498 |
+
# - [B, B, B, B, g, h, E]
|
499 |
+
labels = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask)
|
500 |
+
|
501 |
+
# the first timestamp is associated with a row full of BOS, let's get rid of it
|
502 |
+
# we also remove the last timestamps (full of PAD)
|
503 |
+
output = {"labels": labels[:, 1:]}
|
504 |
+
return output
|
505 |
+
|
506 |
+
for split in vectorized_datasets:
|
507 |
+
data_loader = DataLoader(
|
508 |
+
raw_datasets[split],
|
509 |
+
batch_size=training_args.audio_encoder_per_device_batch_size,
|
510 |
+
collate_fn=encoder_data_collator,
|
511 |
+
num_workers=training_args.dataloader_num_workers,
|
512 |
+
pin_memory=True,
|
513 |
+
)
|
514 |
+
data_loader = accelerator.prepare(data_loader)
|
515 |
+
total_inference_steps = len(data_loader)
|
516 |
+
|
517 |
+
start_step = get_last_codec_checkpoint_step(os.path.join(data_args.temporary_save_to_disk, split))
|
518 |
+
accelerator.wait_for_everyone()
|
519 |
+
if start_step > 0:
|
520 |
+
logger.info(f"Resuming {split} from step {start_step}")
|
521 |
+
# efficiently skip the first n batches
|
522 |
+
start_step += 1
|
523 |
+
data_loader = skip_first_batches(data_loader, start_step)
|
524 |
+
|
525 |
+
all_generated_labels = []
|
526 |
+
all_lens = []
|
527 |
+
if start_step < total_inference_steps:
|
528 |
+
for i, batch in enumerate(tqdm(data_loader, disable=not accelerator.is_local_main_process)):
|
529 |
+
cur_step = start_step + i
|
530 |
+
generate_labels = apply_audio_decoder(batch)
|
531 |
+
generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0)
|
532 |
+
generate_labels = accelerator.gather_for_metrics(generate_labels)
|
533 |
+
|
534 |
+
if accelerator.is_main_process:
|
535 |
+
lab = generate_labels["labels"].cpu().transpose(1, 2).to(torch.int16)
|
536 |
+
rat = generate_labels["ratio"].cpu().squeeze(1)
|
537 |
+
lens = generate_labels["len_audio"].cpu().squeeze(1)
|
538 |
+
lab = [l[:, : int(ratio * length)] for (l, ratio, length) in zip(lab, rat, lens)]
|
539 |
+
|
540 |
+
all_generated_labels.extend(lab)
|
541 |
+
all_lens.extend(lens)
|
542 |
+
|
543 |
+
if ((cur_step + 1) % data_args.save_codec_steps == 0) or (
|
544 |
+
cur_step == total_inference_steps - 1
|
545 |
+
):
|
546 |
+
tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
|
547 |
+
tmp_labels = tmp_labels.map(
|
548 |
+
postprocess_dataset,
|
549 |
+
num_proc=data_args.preprocessing_num_workers, # this step is resource-consuming with many processes.
|
550 |
+
input_columns=["labels"],
|
551 |
+
desc="Postprocessing labeling",
|
552 |
+
)
|
553 |
+
save_codec_checkpoint(
|
554 |
+
os.path.join(data_args.temporary_save_to_disk, split), tmp_labels, cur_step
|
555 |
+
)
|
556 |
+
all_generated_labels = []
|
557 |
+
all_lens = []
|
558 |
+
|
559 |
+
accelerator.wait_for_everyone()
|
560 |
+
|
561 |
+
if accelerator.is_main_process and len(all_generated_labels) > 0:
|
562 |
+
tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
|
563 |
+
tmp_labels = tmp_labels.map(
|
564 |
+
postprocess_dataset,
|
565 |
+
num_proc=data_args.preprocessing_num_workers, # this step is resource-consuming with many processes.
|
566 |
+
input_columns=["labels"],
|
567 |
+
desc="Postprocessing labeling",
|
568 |
+
)
|
569 |
+
save_codec_checkpoint(os.path.join(data_args.temporary_save_to_disk, split), tmp_labels, cur_step)
|
570 |
+
all_generated_labels = []
|
571 |
+
all_lens = []
|
572 |
+
accelerator.wait_for_everyone()
|
573 |
+
|
574 |
+
del all_generated_labels
|
575 |
+
accelerator.wait_for_everyone()
|
576 |
+
|
577 |
+
with accelerator.local_main_process_first():
|
578 |
+
tmp_labels = load_all_codec_checkpoints(os.path.join(data_args.temporary_save_to_disk, split)).select(
|
579 |
+
range(len(vectorized_datasets[split]))
|
580 |
+
)
|
581 |
+
logger.info(f"Concatenating {split}: {tmp_labels} with {vectorized_datasets[split]}")
|
582 |
+
vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1)
|
583 |
+
|
584 |
+
accelerator.free_memory()
|
585 |
+
del generate_labels, all_lens
|
586 |
+
|
587 |
+
with accelerator.local_main_process_first():
|
588 |
+
# NOTE: filtering is done at the end because in the `datasets` library, caching audio files is done after most operations
|
589 |
+
# caching audio files is time and disk-space consuming, so we want to avoid it at all costs, especially for large (>1Kh) audio datasets.
|
590 |
+
# That's also why we avoid concatenating the processed datasets (vectorized_datasets) with the audio column present in raw_datasets.
|
591 |
+
|
592 |
+
def is_audio_in_length_range(length):
|
593 |
+
return length > min_target_length and length < max_target_length
|
594 |
+
|
595 |
+
# filter data that is shorter than min_target_length
|
596 |
+
vectorized_datasets = vectorized_datasets.filter(
|
597 |
+
is_audio_in_length_range,
|
598 |
+
num_proc=num_workers,
|
599 |
+
input_columns=["target_length"],
|
600 |
+
)
|
601 |
+
|
602 |
+
if description_column_name is not None and data_args.max_description_token_length is not None:
|
603 |
+
with accelerator.local_main_process_first():
|
604 |
+
# filter description that is shorter than max_text_length
|
605 |
+
vectorized_datasets = vectorized_datasets.filter(
|
606 |
+
lambda x: len(x) < data_args.max_description_token_length,
|
607 |
+
num_proc=num_workers,
|
608 |
+
input_columns=["input_ids"],
|
609 |
+
)
|
610 |
+
|
611 |
+
if data_args.max_prompt_token_length is not None:
|
612 |
+
with accelerator.local_main_process_first():
|
613 |
+
# filter description that is shorter than max_text_length
|
614 |
+
vectorized_datasets = vectorized_datasets.filter(
|
615 |
+
lambda x: len(x) < data_args.max_prompt_token_length,
|
616 |
+
num_proc=num_workers,
|
617 |
+
input_columns=["prompt_input_ids"],
|
618 |
+
)
|
619 |
+
|
620 |
+
if data_args.save_to_disk is not None and not dataset_was_precomputed:
|
621 |
+
if accelerator.is_main_process:
|
622 |
+
vectorized_datasets.save_to_disk(
|
623 |
+
data_args.save_to_disk,
|
624 |
+
num_proc=min(data_args.preprocessing_num_workers, len(vectorized_datasets["eval"]) - 1),
|
625 |
+
)
|
626 |
+
accelerator.wait_for_everyone()
|
627 |
+
logger.info(f"Dataset saved at {data_args.save_to_disk}")
|
628 |
+
|
629 |
+
audio_max_length = None
|
630 |
+
if padding == "max_length":
|
631 |
+
audio_max_length = max(vectorized_datasets["train"]["target_length"])
|
632 |
+
with accelerator.local_main_process_first():
|
633 |
+
max_sample = vectorized_datasets["train"].filter(
|
634 |
+
lambda x: x == audio_max_length,
|
635 |
+
num_proc=num_workers,
|
636 |
+
input_columns=["target_length"],
|
637 |
+
)
|
638 |
+
audio_max_length = max([len(l[0]) for l in max_sample["labels"]])
|
639 |
+
|
640 |
+
if description_column_name is not None and data_args.max_description_token_length is not None:
|
641 |
+
with accelerator.local_main_process_first():
|
642 |
+
# filter description that is shorter than max_text_length
|
643 |
+
vectorized_datasets = vectorized_datasets.filter(
|
644 |
+
lambda x: len(x) < data_args.max_description_token_length,
|
645 |
+
num_proc=num_workers,
|
646 |
+
input_columns=["input_ids"],
|
647 |
+
)
|
648 |
+
|
649 |
+
if data_args.max_prompt_token_length is not None:
|
650 |
+
with accelerator.local_main_process_first():
|
651 |
+
# filter description that is shorter than max_text_length
|
652 |
+
vectorized_datasets = vectorized_datasets.filter(
|
653 |
+
lambda x: len(x) < data_args.max_prompt_token_length,
|
654 |
+
num_proc=num_workers,
|
655 |
+
input_columns=["prompt_input_ids"],
|
656 |
+
)
|
657 |
+
|
658 |
+
if training_args.group_by_length:
|
659 |
+
# apply a simple heuristic to take into account audio and text lengths
|
660 |
+
def add_target_lengths(target_length, prompt, description):
|
661 |
+
return {"target_length": target_length + len(prompt) + len(description)}
|
662 |
+
|
663 |
+
with accelerator.local_main_process_first():
|
664 |
+
vectorized_datasets = vectorized_datasets.map(
|
665 |
+
add_target_lengths,
|
666 |
+
num_proc=num_workers,
|
667 |
+
input_columns=["target_length", "prompt_input_ids", "input_ids"],
|
668 |
+
)
|
669 |
+
|
670 |
+
# for large datasets it is advised to run the preprocessing on a
|
671 |
+
# single machine first with ``args.preprocessing_only`` since there will most likely
|
672 |
+
# be a timeout when running the script in distributed mode.
|
673 |
+
# In a second step ``args.preprocessing_only`` can then be set to `False` to load the
|
674 |
+
# cached dataset
|
675 |
+
if data_args.preprocessing_only and data_args.save_to_disk is None:
|
676 |
+
raise ValueError(
|
677 |
+
"`preprocessing_only=True` but `save_to_disk` is not set. The latter should indicates where to save the dataset locally."
|
678 |
+
)
|
679 |
+
elif data_args.preprocessing_only:
|
680 |
+
logger.info(f"Data preprocessing finished. Files save at {data_args.save_to_disk}")
|
681 |
+
return
|
682 |
+
|
683 |
+
# 6. Next, we can prepare the training.
|
684 |
+
|
685 |
+
# Let's use word CLAP similary and WER metrics as our evaluation metrics,
|
686 |
+
def compute_metrics(
|
687 |
+
audios,
|
688 |
+
descriptions,
|
689 |
+
prompts,
|
690 |
+
device="cpu",
|
691 |
+
compute_clap_similarity_metric=False,
|
692 |
+
compute_noise_level_metric=False,
|
693 |
+
noise_level_to_compute_clean_wer=None,
|
694 |
+
):
|
695 |
+
results = {}
|
696 |
+
input_ids = descriptions
|
697 |
+
texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
|
698 |
+
prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
|
699 |
+
audios = [a.float().cpu().numpy() for a in audios]
|
700 |
+
|
701 |
+
if compute_clap_similarity_metric:
|
702 |
+
clap_score = clap_similarity(
|
703 |
+
model_args.clap_model_name_or_path, texts, audios, device, input_sampling_rate=sampling_rate
|
704 |
+
)
|
705 |
+
results["clap"] = clap_score
|
706 |
+
|
707 |
+
si_sdr_measures = None
|
708 |
+
if compute_noise_level_metric:
|
709 |
+
si_sdr_measures = si_sdr(audios, device, input_sampling_rate=sampling_rate)
|
710 |
+
|
711 |
+
word_error, transcriptions, clean_word_error, noisy_word_error, percent_clean_samples = wer(
|
712 |
+
model_args.asr_model_name_or_path,
|
713 |
+
prompts,
|
714 |
+
audios,
|
715 |
+
device,
|
716 |
+
training_args.per_device_eval_batch_size,
|
717 |
+
sampling_rate,
|
718 |
+
noise_level_to_compute_clean_wer,
|
719 |
+
si_sdr_measures,
|
720 |
+
)
|
721 |
+
results["wer"] = word_error
|
722 |
+
if clean_word_error is not None:
|
723 |
+
results["clean_wer"] = clean_word_error
|
724 |
+
results["noisy_word_error"] = noisy_word_error
|
725 |
+
results["percent_clean_samples"] = percent_clean_samples
|
726 |
+
|
727 |
+
return results, texts, prompts, audios, transcriptions, si_sdr_measures
|
728 |
+
|
729 |
+
# Define Training Schedule
|
730 |
+
# Store some constants
|
731 |
+
per_device_train_batch_size = int(training_args.per_device_train_batch_size)
|
732 |
+
train_batch_size = per_device_train_batch_size * accelerator.num_processes
|
733 |
+
gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
|
734 |
+
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
|
735 |
+
|
736 |
+
if training_args.max_steps < 0:
|
737 |
+
num_epochs = int(training_args.num_train_epochs)
|
738 |
+
steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
|
739 |
+
total_train_steps = steps_per_epoch * num_epochs
|
740 |
+
elif training_args.max_steps > 0:
|
741 |
+
logger.info("max_steps is given, it will override any value given in num_train_epochs")
|
742 |
+
total_train_steps = int(training_args.max_steps)
|
743 |
+
# Setting a very large number of epochs so we go as many times as necessary over the iterator.
|
744 |
+
num_epochs = sys.maxsize
|
745 |
+
steps_per_epoch = total_train_steps
|
746 |
+
|
747 |
+
if training_args.eval_steps is None:
|
748 |
+
logger.info(f"eval_steps is not set, evaluating at the end of each epoch")
|
749 |
+
eval_steps = steps_per_epoch
|
750 |
+
else:
|
751 |
+
eval_steps = training_args.eval_steps
|
752 |
+
|
753 |
+
if training_args.eval_generation_steps is None:
|
754 |
+
eval_generation_steps = eval_steps
|
755 |
+
else:
|
756 |
+
eval_generation_steps = training_args.eval_generation_steps
|
757 |
+
|
758 |
+
# T5 doesn't support fp16
|
759 |
+
autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
|
760 |
+
|
761 |
+
# Define optimizer, LR scheduler, collator
|
762 |
+
optimizer = torch.optim.AdamW(
|
763 |
+
params=model.parameters(),
|
764 |
+
lr=training_args.learning_rate,
|
765 |
+
betas=(training_args.adam_beta1, training_args.adam_beta2),
|
766 |
+
eps=training_args.adam_epsilon,
|
767 |
+
weight_decay=training_args.weight_decay,
|
768 |
+
)
|
769 |
+
|
770 |
+
# LR scheduler gets stepped by `num_processes` each time -> account for this in warmup / total steps
|
771 |
+
lr_scheduler = get_scheduler(
|
772 |
+
name=training_args.lr_scheduler_type,
|
773 |
+
optimizer=optimizer,
|
774 |
+
num_warmup_steps=training_args.get_warmup_steps(total_train_steps) * accelerator.num_processes,
|
775 |
+
num_training_steps=total_train_steps * accelerator.num_processes,
|
776 |
+
)
|
777 |
+
|
778 |
+
# Instantiate custom data collator
|
779 |
+
data_collator = DataCollatorParlerTTSWithPadding(
|
780 |
+
prompt_tokenizer=prompt_tokenizer,
|
781 |
+
description_tokenizer=description_tokenizer,
|
782 |
+
pad_to_multiple_of=data_args.pad_to_multiple_of,
|
783 |
+
padding=padding,
|
784 |
+
prompt_max_length=data_args.max_prompt_token_length,
|
785 |
+
description_max_length=data_args.max_description_token_length,
|
786 |
+
audio_max_length=audio_max_length,
|
787 |
+
)
|
788 |
+
|
789 |
+
# Prepare everything with accelerate
|
790 |
+
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
|
791 |
+
|
792 |
+
num_examples = total_train_steps * train_batch_size * gradient_accumulation_steps
|
793 |
+
logger.info("***** Running training *****")
|
794 |
+
logger.info(f" Num examples = {num_examples}")
|
795 |
+
logger.info(" Instantaneous batch size per device =" f" {per_device_train_batch_size}")
|
796 |
+
logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
|
797 |
+
logger.info(
|
798 |
+
f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
|
799 |
+
)
|
800 |
+
logger.info(f" Total optimization steps = {total_train_steps}")
|
801 |
+
|
802 |
+
# ======================== Training ================================
|
803 |
+
train_time = 0
|
804 |
+
train_start = time.time()
|
805 |
+
steps_trained_progress_bar = tqdm(
|
806 |
+
range(total_train_steps), desc="Train steps ... ", position=0, disable=not accelerator.is_local_main_process
|
807 |
+
)
|
808 |
+
continue_training = True
|
809 |
+
epochs_trained = 0
|
810 |
+
cur_step = 0
|
811 |
+
|
812 |
+
checkpoint = None
|
813 |
+
if training_args.resume_from_checkpoint is not None:
|
814 |
+
checkpoint = training_args.resume_from_checkpoint
|
815 |
+
elif last_checkpoint is not None:
|
816 |
+
checkpoint = last_checkpoint
|
817 |
+
|
818 |
+
if accelerator.is_main_process:
|
819 |
+
if training_args.push_to_hub:
|
820 |
+
api = HfApi(token=training_args.hub_token)
|
821 |
+
|
822 |
+
# Create repo (repo_name from args or inferred)
|
823 |
+
repo_name = training_args.hub_model_id
|
824 |
+
if repo_name is None:
|
825 |
+
repo_name = Path(training_args.output_dir).absolute().name
|
826 |
+
repo_id = api.create_repo(repo_name, exist_ok=True).repo_id
|
827 |
+
|
828 |
+
with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
|
829 |
+
if "wandb" not in gitignore:
|
830 |
+
gitignore.write("wandb\n")
|
831 |
+
elif training_args.output_dir is not None:
|
832 |
+
os.makedirs(training_args.output_dir, exist_ok=True)
|
833 |
+
accelerator.wait_for_everyone()
|
834 |
+
|
835 |
+
# Now save everything to be able to create a single processor later
|
836 |
+
# make sure all processes wait until data is saved
|
837 |
+
# only the main process saves them
|
838 |
+
if accelerator.is_main_process:
|
839 |
+
# save feature extractor, tokenizer and config
|
840 |
+
if (
|
841 |
+
model_args.prompt_tokenizer_name is None
|
842 |
+
and model_args.description_tokenizer_name
|
843 |
+
or (model_args.prompt_tokenizer_name == model_args.description_tokenizer_name)
|
844 |
+
):
|
845 |
+
prompt_tokenizer.save_pretrained(training_args.output_dir)
|
846 |
+
else:
|
847 |
+
logger.warning(
|
848 |
+
f"Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer."
|
849 |
+
)
|
850 |
+
prompt_tokenizer.save_pretrained(training_args.output_dir)
|
851 |
+
|
852 |
+
feature_extractor.save_pretrained(training_args.output_dir)
|
853 |
+
config.save_pretrained(training_args.output_dir)
|
854 |
+
accelerator.wait_for_everyone()
|
855 |
+
|
856 |
+
if checkpoint is not None:
|
857 |
+
accelerator.load_state(checkpoint)
|
858 |
+
# Find num steps and epoch from saved state string pattern
|
859 |
+
pattern = r"checkpoint-(\d+)-epoch-(\d+)"
|
860 |
+
match = re.search(pattern, checkpoint)
|
861 |
+
cur_step = int(match.group(1))
|
862 |
+
epochs_trained = int(match.group(2))
|
863 |
+
|
864 |
+
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
865 |
+
logger.info(f" Continuing training from epoch {epochs_trained}")
|
866 |
+
logger.info(f" Continuing training from global step {cur_step}")
|
867 |
+
|
868 |
+
steps_trained_progress_bar.update(cur_step)
|
869 |
+
|
870 |
+
for epoch in range(0, epochs_trained):
|
871 |
+
with accelerator.local_main_process_first():
|
872 |
+
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
|
873 |
+
|
874 |
+
if training_args.max_steps < 0:
|
875 |
+
# we know exactly the number of steps per epoch, so can skip through the required number of batches
|
876 |
+
resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
|
877 |
+
else:
|
878 |
+
# Currently we don't know how many steps we've taken in the current epoch
|
879 |
+
# So we just shuffle the dataset one extra time and start from a fresh epoch
|
880 |
+
# This is "good enough" for our purposes but not fully correct
|
881 |
+
resume_step = None
|
882 |
+
with accelerator.local_main_process_first():
|
883 |
+
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
|
884 |
+
else:
|
885 |
+
resume_step = None
|
886 |
+
|
887 |
+
gen_kwargs = {
|
888 |
+
"do_sample": model_args.do_sample,
|
889 |
+
"temperature": model_args.temperature,
|
890 |
+
"max_length": model_args.max_length,
|
891 |
+
# Because of the delayed pattern mask, generation might stop earlier because of unexpected behaviour
|
892 |
+
# on the first tokens of the codebooks that are delayed.
|
893 |
+
# This fixes the issue.
|
894 |
+
"min_new_tokens": num_codebooks + 1,
|
895 |
+
}
|
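# e.g. with 9 codebooks, at least 10 new tokens are generated, so every delayed codebook has emitted
# a real frame before generation is allowed to stop (illustrative note; the codebook count comes from the model config).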
896 |
+
|
897 |
+
# Define gradient update step fn
|
898 |
+
def train_step(
|
899 |
+
batch,
|
900 |
+
accelerator,
|
901 |
+
autocast_kwargs,
|
902 |
+
num_items_in_batch,
|
903 |
+
gradient_accumulation_steps,
|
904 |
+
):
|
905 |
+
if mixed_precision == "fp16":
|
906 |
+
# fp16 doesn't work with T5-like models
|
907 |
+
with accelerator.autocast(autocast_handler=autocast_kwargs):
|
908 |
+
if training_args.parallel_mode.value != "distributed":
|
909 |
+
encoder_outputs = model.text_encoder(
|
910 |
+
input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
|
911 |
+
)
|
912 |
+
else:
|
913 |
+
encoder_outputs = model.module.text_encoder(
|
914 |
+
input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
|
915 |
+
)
|
916 |
+
# we optionally project last_hidden_state to avoid recomputing it every time
|
917 |
+
encoder_hidden_states = encoder_outputs.last_hidden_state
|
918 |
+
if (
|
919 |
+
config.text_encoder.hidden_size != config.decoder.hidden_size
|
920 |
+
and config.decoder.cross_attention_hidden_size is None
|
921 |
+
):
|
922 |
+
encoder_hidden_states = (
|
923 |
+
model.enc_to_dec_proj(encoder_hidden_states)
|
924 |
+
if training_args.parallel_mode.value != "distributed"
|
925 |
+
else model.module.enc_to_dec_proj(encoder_hidden_states)
|
926 |
+
)
|
927 |
+
|
928 |
+
if batch.get("attention_mask", None) is not None:
|
929 |
+
encoder_hidden_states = encoder_hidden_states * batch.get("attention_mask", None)[..., None]
|
930 |
+
|
931 |
+
encoder_outputs.last_hidden_state = encoder_hidden_states
|
932 |
+
batch["encoder_outputs"] = encoder_outputs
|
933 |
+
|
934 |
+
outputs = model(**batch, loss_reduction="sum")
|
935 |
+
# CE (data) loss
|
936 |
+
ce_loss = (outputs.loss * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch
|
937 |
+
|
938 |
+
metrics = {"loss": ce_loss}
|
939 |
+
|
940 |
+
# per CE loss
|
941 |
+
per_codebook_losses = outputs.per_codebook_losses
|
942 |
+
metrics.update({f"codebook_{i}_loss": ((l * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch) for (i,l) in enumerate(per_codebook_losses)})
|
943 |
+
return ce_loss, metrics
|
944 |
+
|
945 |
+
# Define eval fn
|
946 |
+
def eval_step(
|
947 |
+
batch,
|
948 |
+
accelerator,
|
949 |
+
autocast_kwargs,
|
950 |
+
):
|
951 |
+
eval_model = model if not training_args.torch_compile else model._orig_mod
|
952 |
+
|
953 |
+
if mixed_precision == "fp16":
|
954 |
+
# fp16 doesn't work with T5-like models
|
955 |
+
with accelerator.autocast(autocast_handler=autocast_kwargs):
|
956 |
+
if training_args.parallel_mode.value != "distributed":
|
957 |
+
encoder_outputs = model.text_encoder(
|
958 |
+
input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
|
959 |
+
)
|
960 |
+
else:
|
961 |
+
encoder_outputs = model.module.text_encoder(
|
962 |
+
input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
|
963 |
+
)
|
964 |
+
# we optionally project last_hidden_state to avoid recomputing it every time
|
965 |
+
encoder_hidden_states = encoder_outputs.last_hidden_state
|
966 |
+
if (
|
967 |
+
config.text_encoder.hidden_size != config.decoder.hidden_size
|
968 |
+
and config.decoder.cross_attention_hidden_size is None
|
969 |
+
):
|
970 |
+
encoder_hidden_states = (
|
971 |
+
model.enc_to_dec_proj(encoder_hidden_states)
|
972 |
+
if training_args.parallel_mode.value != "distributed"
|
973 |
+
else model.module.enc_to_dec_proj(encoder_hidden_states)
|
974 |
+
)
|
975 |
+
|
976 |
+
if batch.get("attention_mask", None) is not None:
|
977 |
+
encoder_hidden_states = encoder_hidden_states * batch.get("attention_mask", None)[..., None]
|
978 |
+
|
979 |
+
encoder_outputs.last_hidden_state = encoder_hidden_states
|
980 |
+
batch["encoder_outputs"] = encoder_outputs
|
981 |
+
|
982 |
+
with torch.no_grad():
|
983 |
+
outputs = eval_model(**batch)
|
984 |
+
# CE (data) loss
|
985 |
+
ce_loss = outputs.loss
|
986 |
+
metrics = {"loss": ce_loss}
|
987 |
+
|
988 |
+
# per CE loss
|
989 |
+
per_codebook_losses = outputs.per_codebook_losses
|
990 |
+
metrics.update({f"codebook_{i}_loss": l for (i,l) in enumerate(per_codebook_losses)})
|
991 |
+
return metrics
|
992 |
+
|
993 |
+
def generate_step(batch, accelerator):
|
994 |
+
batch.pop("decoder_attention_mask", None)
|
995 |
+
eval_model = accelerator.unwrap_model(model, keep_fp32_wrapper=True)
|
996 |
+
if training_args.torch_compile:
|
997 |
+
# if the model is compiled, we use the original model because compile is not compatible with .generate
|
998 |
+
eval_model = model._orig_mod
|
999 |
+
|
1000 |
+
# since we might have loaded the weights in fp32, we have to autocast to ensure FA2 weights are in half-precision.
|
1001 |
+
# with accelerator.autocast(autocast_handler=AutocastKwargs(enabled=(attn_implementation=="flash_attention_2"))):
|
1002 |
+
output_audios = eval_model.generate(**batch, **gen_kwargs)
|
1003 |
+
output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
|
1004 |
+
return output_audios
|
1005 |
+
|
1006 |
+
model.train()
|
1007 |
+
|
1008 |
+
total_batched_samples = resume_step if resume_step is not None else 0
|
1009 |
+
for epoch in range(epochs_trained, num_epochs):
|
1010 |
+
with accelerator.local_main_process_first():
|
1011 |
+
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
|
1012 |
+
sampler = None
|
1013 |
+
if training_args.group_by_length:
|
1014 |
+
sampler = LengthGroupedSampler(train_batch_size, lengths=vectorized_datasets["train"]["target_length"])
|
1015 |
+
train_dataloader = DataLoader(
|
1016 |
+
vectorized_datasets["train"],
|
1017 |
+
collate_fn=data_collator,
|
1018 |
+
batch_size=per_device_train_batch_size,
|
1019 |
+
sampler=sampler,
|
1020 |
+
shuffle=not training_args.group_by_length,
|
1021 |
+
num_workers=training_args.dataloader_num_workers,
|
1022 |
+
pin_memory=training_args.dataloader_pin_memory,
|
1023 |
+
)
|
1024 |
+
train_dataloader = accelerator.prepare(train_dataloader)
|
1025 |
+
if hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDataset):
|
1026 |
+
train_dataloader.dataset.set_epoch(epoch)
|
1027 |
+
|
1028 |
+
if resume_step is not None:
|
1029 |
+
# Skip the first N batches in the dataloader when resuming from a checkpoint
|
1030 |
+
logger.info(f" Skip first {resume_step} batches")
|
1031 |
+
train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
|
1032 |
+
resume_step = None
|
1033 |
+
accelerator.wait_for_everyone()
|
1034 |
+
|
1035 |
+
# We split the epoch iterator into chunks of `gradient_accumulation_steps` batches
|
1036 |
+
train_iterator = iter(train_dataloader)
|
1037 |
+
num_steps_in_epoch = len(train_dataloader)
|
1038 |
+
remainder = num_steps_in_epoch % gradient_accumulation_steps
|
1039 |
+
remainder = remainder if remainder != 0 else gradient_accumulation_steps
|
1040 |
+
total_updates = math.ceil(num_steps_in_epoch / gradient_accumulation_steps)
|
1041 |
+
|
1042 |
+
update_step = -1
|
1043 |
+
for _ in range(total_updates):
|
1044 |
+
update_step += 1
|
1045 |
+
|
1046 |
+
# preload the total batch per step
|
1047 |
+
batch_samples = []
|
1048 |
+
num_batches_in_step = gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
|
1049 |
+
for _ in range(num_batches_in_step):
|
1050 |
+
batch_samples += [next(train_iterator)]
|
1051 |
+
|
1052 |
+
# get num items in batch - i.e. tokens different from BOS and from -100
|
1053 |
+
num_items_in_batch = sum([(batch["labels"].ne(audio_encoder_bos_token_id) | batch["labels"].ne(-100) | batch["labels"].ne(audio_encoder_eos_token_id)).sum((0,1))[0] for batch in batch_samples])
|
1054 |
+
num_items_in_batch = accelerator.gather(num_items_in_batch).sum().item()
|
1055 |
+
|
1056 |
+
# losses = []
|
1057 |
+
for i,batch in enumerate(batch_samples):
|
1058 |
+
total_batched_samples += 1
|
1059 |
+
ctx = model.no_sync if (i < len(batch_samples) - 1 and accelerator.num_processes > 1) else contextlib.nullcontext
|
1060 |
+
|
1061 |
+
with ctx():
|
1062 |
+
loss, train_metric = train_step(batch, accelerator, autocast_kwargs, num_items_in_batch, gradient_accumulation_steps)
|
1063 |
+
accelerator.backward(loss)
|
1064 |
+
# losses.append(loss.detach())
|
1065 |
+
|
1066 |
+
grad_norm = accelerator.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
|
1067 |
+
optimizer.step()
|
1068 |
+
lr_scheduler.step()
|
1069 |
+
optimizer.zero_grad()
|
1070 |
+
|
1071 |
+
# The accelerator has performed an optimization step behind the scenes
|
1072 |
+
steps_trained_progress_bar.update(1)
|
1073 |
+
cur_step += 1
|
1074 |
+
|
1075 |
+
# losses = accelerator.gather(sum(losses)).sum().item() / (accelerator.num_processes * gradient_accumulation_steps)
|
1076 |
+
|
1077 |
+
if cur_step % training_args.logging_steps == 0:
|
1078 |
+
steps_trained_progress_bar.write(
|
1079 |
+
f"Step... ({cur_step} / {total_train_steps} | Loss:"
|
1080 |
+
f" {train_metric['loss']}, Learning Rate:"
|
1081 |
+
f" {lr_scheduler.get_last_lr()[0]})"
|
1082 |
+
)
|
1083 |
+
train_metric["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
|
1084 |
+
log_metric(
|
1085 |
+
accelerator,
|
1086 |
+
metrics=train_metric,
|
1087 |
+
learning_rate=lr_scheduler.get_last_lr()[0],
|
1088 |
+
train_time=train_time + time.time() - train_start,
|
1089 |
+
step=cur_step,
|
1090 |
+
epoch=epoch,
|
1091 |
+
prefix="train",
|
1092 |
+
)
|
1093 |
+
|
1094 |
+
# save checkpoint and weights after each save_steps and at the end of training
|
1095 |
+
if (cur_step % training_args.save_steps == 0) or cur_step == total_train_steps:
|
1096 |
+
intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
|
1097 |
+
# safe_serialization=False to avoid shared tensors saving issue (TODO(YL): it's a temporary fix)
|
1098 |
+
# https://github.com/huggingface/transformers/issues/27293#issuecomment-1872560074
|
1099 |
+
accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
|
1100 |
+
accelerator.wait_for_everyone()
|
1101 |
+
if accelerator.is_main_process:
|
1102 |
+
rotate_checkpoints(
|
1103 |
+
training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger
|
1104 |
+
)
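# Illustrative sketch (not part of the training script) of the keep-newest-N idea behind
# `rotate_checkpoints` above; the real helper lives in training/utils.py and may differ
# in details. Names prefixed with `_example` are hypothetical.
from pathlib import Path as _Path
def _example_checkpoints_to_delete(output_dir: str, save_total_limit: int):
    ckpts = sorted(
        _Path(output_dir).glob("checkpoint-*-epoch-*"),
        key=lambda p: int(p.name.split("-")[1]),  # sort by training step
    )
    if not save_total_limit or len(ckpts) <= save_total_limit:
        return []
    # oldest checkpoints beyond the limit would be removed by the caller
    return [str(p) for p in ckpts[: len(ckpts) - save_total_limit]]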
|
1105 |
+
|
1106 |
+
if cur_step == total_train_steps:
|
1107 |
+
# unwrap the model for saving
|
1108 |
+
unwrapped_model = accelerator.unwrap_model(model)
|
1109 |
+
unwrapped_model.save_pretrained(training_args.output_dir)
|
1110 |
+
|
1111 |
+
if training_args.push_to_hub:
|
1112 |
+
api.upload_folder(
|
1113 |
+
repo_id=repo_id,
|
1114 |
+
folder_path=training_args.output_dir,
|
1115 |
+
commit_message=f"Saving train state of step {cur_step}",
|
1116 |
+
run_as_future=True,
|
1117 |
+
)
|
1118 |
+
accelerator.wait_for_everyone()
|
1119 |
+
|
1120 |
+
if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
|
1121 |
+
train_time += time.time() - train_start
|
1122 |
+
# ======================== Evaluating ==============================
|
1123 |
+
model.eval()
|
1124 |
+
eval_metrics = []
|
1125 |
+
eval_preds = []
|
1126 |
+
eval_descriptions = []
|
1127 |
+
eval_prompts = []
|
1128 |
+
eval_start = time.time()
|
1129 |
+
|
1130 |
+
# release training input batch
|
1131 |
+
batch = release_memory(batch)
|
1132 |
+
|
1133 |
+
validation_dataloader = DataLoader(
|
1134 |
+
vectorized_datasets["eval"],
|
1135 |
+
collate_fn=data_collator,
|
1136 |
+
batch_size=per_device_eval_batch_size,
|
1137 |
+
drop_last=False,
|
1138 |
+
num_workers=training_args.eval_dataloader_num_workers,
|
1139 |
+
pin_memory=training_args.dataloader_pin_memory,
|
1140 |
+
)
|
1141 |
+
validation_dataloader = accelerator.prepare(validation_dataloader)
|
1142 |
+
|
1143 |
+
for batch in tqdm(
|
1144 |
+
validation_dataloader,
|
1145 |
+
desc=f"Evaluating - Inference ...",
|
1146 |
+
position=2,
|
1147 |
+
disable=not accelerator.is_local_main_process,
|
1148 |
+
):
|
1149 |
+
# Model forward
|
1150 |
+
eval_metric = eval_step(batch, accelerator, autocast_kwargs)
|
1151 |
+
eval_metric = accelerator.gather_for_metrics(eval_metric)
|
1152 |
+
eval_metric = {key: val.unsqueeze(0) if val.ndim == 0 else val for (key,val) in eval_metric.items()}
|
1153 |
+
eval_metrics.append(eval_metric)
|
1154 |
+
|
1155 |
+
if training_args.predict_with_generate and (cur_step % eval_generation_steps == 0 or cur_step == total_train_steps):
|
1156 |
+
validation_dataloader = DataLoader(
|
1157 |
+
vectorized_datasets["eval"],
|
1158 |
+
collate_fn=data_collator,
|
1159 |
+
batch_size=per_device_eval_batch_size,
|
1160 |
+
drop_last=False,
|
1161 |
+
num_workers=training_args.eval_dataloader_num_workers,
|
1162 |
+
pin_memory=training_args.dataloader_pin_memory,
|
1163 |
+
)
|
1164 |
+
validation_dataloader = accelerator.prepare(validation_dataloader)
|
1165 |
+
# generation
|
1166 |
+
for batch in tqdm(
|
1167 |
+
validation_dataloader,
|
1168 |
+
desc=f"Evaluating - Generation ...",
|
1169 |
+
position=2,
|
1170 |
+
disable=not accelerator.is_local_main_process,
|
1171 |
+
):
|
1172 |
+
generated_audios = generate_step(batch, accelerator)
|
1173 |
+
# Gather all predictions and targets
|
1174 |
+
generated_audios, input_ids, prompts = accelerator.pad_across_processes(
|
1175 |
+
(generated_audios, batch["input_ids"], batch["prompt_input_ids"]), dim=1, pad_index=0
|
1176 |
+
)
|
1177 |
+
generated_audios, input_ids, prompts = accelerator.gather_for_metrics(
|
1178 |
+
(generated_audios, input_ids, prompts)
|
1179 |
+
)
|
1180 |
+
eval_preds.extend(generated_audios.to("cpu"))
|
1181 |
+
eval_descriptions.extend(input_ids.to("cpu"))
|
1182 |
+
eval_prompts.extend(prompts.to("cpu"))
|
1183 |
+
|
1184 |
+
eval_time = time.time() - eval_start
|
1185 |
+
# normalize eval metrics
|
1186 |
+
eval_metrics = {
|
1187 |
+
key: torch.mean(torch.cat([d[key] for d in eval_metrics])).to("cpu") for key in eval_metrics[0]
|
1188 |
+
}
|
1189 |
+
|
1190 |
+
# compute metrics
|
1191 |
+
metrics_desc = ""
|
1192 |
+
if training_args.predict_with_generate and (cur_step % eval_generation_steps == 0 or cur_step == total_train_steps):
|
1193 |
+
if accelerator.is_local_main_process:
|
1194 |
+
(
|
1195 |
+
metric_values,
|
1196 |
+
pred_descriptions,
|
1197 |
+
pred_prompts,
|
1198 |
+
audios,
|
1199 |
+
transcriptions,
|
1200 |
+
si_sdr_measures,
|
1201 |
+
) = compute_metrics(
|
1202 |
+
eval_preds,
|
1203 |
+
eval_descriptions,
|
1204 |
+
eval_prompts,
|
1205 |
+
accelerator.device,
|
1206 |
+
training_args.compute_clap_similarity_metric,
|
1207 |
+
training_args.compute_noise_level_metric,
|
1208 |
+
training_args.noise_level_to_compute_clean_wer,
|
1209 |
+
)
|
1210 |
+
eval_metrics.update(metric_values)
|
1211 |
+
metrics_desc = " ".join([f"Eval {key}: {value} |" for key, value in metric_values.items()])
|
1212 |
+
if "wandb" in training_args.report_to:
|
1213 |
+
log_pred(
|
1214 |
+
accelerator,
|
1215 |
+
pred_descriptions,
|
1216 |
+
pred_prompts,
|
1217 |
+
transcriptions,
|
1218 |
+
audios,
|
1219 |
+
si_sdr_measures,
|
1220 |
+
sampling_rate=sampling_rate,
|
1221 |
+
step=cur_step,
|
1222 |
+
prefix="eval",
|
1223 |
+
)
|
1224 |
+
accelerator.wait_for_everyone()
|
1225 |
+
|
1226 |
+
# Print metrics and update progress bar
|
1227 |
+
if accelerator.is_local_main_process:
|
1228 |
+
steps_trained_progress_bar.write(
|
1229 |
+
f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |"
|
1230 |
+
f" {metrics_desc})"
|
1231 |
+
)
|
1232 |
+
|
1233 |
+
log_metric(
|
1234 |
+
accelerator,
|
1235 |
+
metrics=eval_metrics,
|
1236 |
+
train_time=eval_time,
|
1237 |
+
step=cur_step,
|
1238 |
+
epoch=epoch,
|
1239 |
+
prefix="eval",
|
1240 |
+
)
|
1241 |
+
|
1242 |
+
# release eval batch and metrics
|
1243 |
+
eval_metrics, eval_preds, eval_descriptions, eval_prompts, batch, eval_metric = release_memory(
|
1244 |
+
eval_metrics, eval_preds, eval_descriptions, eval_prompts, batch, eval_metric
|
1245 |
+
)
|
1246 |
+
if training_args.predict_with_generate and (cur_step % eval_generation_steps == 0 or cur_step == total_train_steps):
|
1247 |
+
generated_audios, input_ids, prompts = release_memory(generated_audios, input_ids, prompts)
|
1248 |
+
|
1249 |
+
# train mode
|
1250 |
+
model.train()
|
1251 |
+
|
1252 |
+
# flush the train metrics
|
1253 |
+
train_start = time.time()
|
1254 |
+
|
1255 |
+
# break condition
|
1256 |
+
if cur_step == total_train_steps:
|
1257 |
+
continue_training = False
|
1258 |
+
break
|
1259 |
+
|
1260 |
+
if not continue_training:
|
1261 |
+
break
|
1262 |
+
|
1263 |
+
accelerator.end_training()
|
1264 |
+
|
1265 |
+
|
1266 |
+
if __name__ == "__main__":
|
1267 |
+
main()
|
capspeech/ar/training/run_parler_tts_training.py
ADDED
@@ -0,0 +1,1279 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
""" Train Parler-TTS using 🤗 Accelerate"""
|
18 |
+
|
19 |
+
import logging
|
20 |
+
import os
|
21 |
+
import re
|
22 |
+
import sys
|
23 |
+
import time
|
24 |
+
import math
|
25 |
+
import contextlib
|
26 |
+
from multiprocess import set_start_method
|
27 |
+
from datetime import timedelta
|
28 |
+
import inspect
|
29 |
+
from tqdm import tqdm
|
30 |
+
from pathlib import Path
|
31 |
+
import wandb
|
32 |
+
|
33 |
+
import torch
|
34 |
+
from torch.utils.data import DataLoader
|
35 |
+
|
36 |
+
import datasets
|
37 |
+
from datasets import DatasetDict, Dataset, IterableDataset, concatenate_datasets
|
38 |
+
|
39 |
+
from huggingface_hub import HfApi
|
40 |
+
|
41 |
+
import transformers
|
42 |
+
from transformers import AutoFeatureExtractor, AutoTokenizer, HfArgumentParser
|
43 |
+
from transformers.trainer_pt_utils import LengthGroupedSampler
|
44 |
+
from transformers.optimization import get_scheduler
|
45 |
+
from transformers.utils import send_example_telemetry
|
46 |
+
|
47 |
+
|
48 |
+
from accelerate import Accelerator, skip_first_batches
|
49 |
+
from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin, DistributedDataParallelKwargs
|
50 |
+
from accelerate.utils.memory import release_memory
|
51 |
+
|
52 |
+
from parler_tts import (
|
53 |
+
ParlerTTSConfig,
|
54 |
+
ParlerTTSForConditionalGeneration,
|
55 |
+
build_delay_pattern_mask,
|
56 |
+
)
|
57 |
+
|
58 |
+
from training.utils import (
|
59 |
+
get_last_checkpoint,
|
60 |
+
rotate_checkpoints,
|
61 |
+
log_pred,
|
62 |
+
log_metric,
|
63 |
+
load_all_codec_checkpoints,
|
64 |
+
save_codec_checkpoint,
|
65 |
+
get_last_codec_checkpoint_step,
|
66 |
+
)
|
67 |
+
from training.arguments import ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments
|
68 |
+
from training.data import load_multiple_datasets, DataCollatorParlerTTSWithPadding, DataCollatorEncodecWithPadding
|
69 |
+
from training.eval import clap_similarity, wer, si_sdr
|
70 |
+
|
71 |
+
logger = logging.getLogger(__name__)
|
72 |
+
|
73 |
+
|
74 |
+
def main():
|
75 |
+
# See all possible arguments in src/transformers/training_args.py
|
76 |
+
# or by passing the --help flag to this script.
|
77 |
+
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
78 |
+
|
79 |
+
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments))
|
80 |
+
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
81 |
+
# If we pass only one argument to the script and it's the path to a json file,
|
82 |
+
# let's parse it to get our arguments.
|
83 |
+
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
84 |
+
else:
|
85 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
86 |
+
|
87 |
+
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
|
88 |
+
# information sent is the one passed as arguments along with your Python/PyTorch versions.
|
89 |
+
send_example_telemetry("run_parler_tts", model_args, data_args)
|
90 |
+
|
91 |
+
if data_args.wandb_key is not None:
|
92 |
+
wandb.login(key=data_args.wandb_key)
|
93 |
+
|
94 |
+
if training_args.dtype == "float16":
|
95 |
+
mixed_precision = "fp16"
|
96 |
+
torch_dtype = torch.float16
|
97 |
+
elif training_args.dtype == "bfloat16":
|
98 |
+
mixed_precision = "bf16"
|
99 |
+
torch_dtype = torch.bfloat16
|
100 |
+
else:
|
101 |
+
mixed_precision = "no"
|
102 |
+
torch_dtype = torch.float32
|
103 |
+
|
104 |
+
if data_args.pad_to_max_length and (
|
105 |
+
data_args.max_duration_in_seconds is None
|
106 |
+
or data_args.max_prompt_token_length is None
|
107 |
+
or data_args.max_description_token_length is None
|
108 |
+
):
|
109 |
+
raise ValueError(
|
110 |
+
"`pad_to_max_length` is `True` but one of the following parameters has not been set: `max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`"
|
111 |
+
)
|
112 |
+
|
113 |
+
padding = "max_length" if data_args.pad_to_max_length else "longest"
|
114 |
+
|
115 |
+
####### A. Preparation
|
116 |
+
kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=120)), DistributedDataParallelKwargs(find_unused_parameters=False)]
|
117 |
+
|
118 |
+
accelerator = Accelerator(
|
119 |
+
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
|
120 |
+
mixed_precision=mixed_precision,
|
121 |
+
log_with=training_args.report_to,
|
122 |
+
project_dir=training_args.output_dir,
|
123 |
+
kwargs_handlers=kwargs_handlers,
|
124 |
+
)
|
125 |
+
|
126 |
+
accelerator.init_trackers(
|
127 |
+
project_name=data_args.wandb_project,
|
128 |
+
config={
|
129 |
+
"learning_rate": training_args.learning_rate,
|
130 |
+
"model_name_or_path": model_args.model_name_or_path,
|
131 |
+
"num_train_epochs": training_args.num_train_epochs,
|
132 |
+
"gradient_accumulation_steps": training_args.gradient_accumulation_steps,
|
133 |
+
"per_device_train_batch_size": training_args.per_device_train_batch_size,
|
134 |
+
"global_batch_size": training_args.per_device_train_batch_size * accelerator.num_processes,
|
135 |
+
"mixed_precision": mixed_precision,
|
136 |
+
"lr_scheduler_type": training_args.lr_scheduler_type,
|
137 |
+
"warmup_steps": training_args.warmup_steps,
|
138 |
+
"freeze_text_encoder": model_args.freeze_text_encoder,
|
139 |
+
"max_duration_in_seconds": data_args.max_duration_in_seconds,
|
140 |
+
"weight_decay": training_args.weight_decay,
|
141 |
+
"adam_beta1": training_args.adam_beta1,
|
142 |
+
"adam_beta2": training_args.adam_beta2,
|
143 |
+
"temperature": model_args.temperature,
|
144 |
+
},
|
145 |
+
init_kwargs={"wandb": {"name": data_args.wandb_run_name}} if data_args.wandb_run_name else {},
|
146 |
+
)
|
147 |
+
|
148 |
+
# Detecting last checkpoint and eventually continue from last checkpoint
|
149 |
+
last_checkpoint = None
|
150 |
+
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
|
151 |
+
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
152 |
+
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
153 |
+
raise ValueError(
|
154 |
+
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
155 |
+
"Use --overwrite_output_dir to overcome."
|
156 |
+
)
|
157 |
+
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
158 |
+
logger.info(
|
159 |
+
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
|
160 |
+
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
161 |
+
)
|
162 |
+
|
163 |
+
# Setup logging
|
164 |
+
logging.basicConfig(
|
165 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
166 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
167 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
168 |
+
)
|
169 |
+
logger.setLevel(logging.INFO if accelerator.is_main_process else logging.WARN)
|
170 |
+
|
171 |
+
# Log a small summary on each process
|
172 |
+
logger.warning(
|
173 |
+
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
|
174 |
+
f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
|
175 |
+
)
|
176 |
+
|
177 |
+
# Set the verbosity to info of the Transformers logger (on main process only)
|
178 |
+
if accelerator.is_local_main_process:
|
179 |
+
datasets.utils.logging.set_verbosity_warning()
|
180 |
+
transformers.utils.logging.set_verbosity_info()
|
181 |
+
else:
|
182 |
+
datasets.utils.logging.set_verbosity_error()
|
183 |
+
transformers.utils.logging.set_verbosity_error()
|
184 |
+
|
185 |
+
logger.info("Training/evaluation parameters %s", training_args)
|
186 |
+
|
187 |
+
# Set seed before initializing model.
|
188 |
+
set_seed(training_args.seed)
|
189 |
+
num_workers = data_args.preprocessing_num_workers
|
190 |
+
|
191 |
+
# 1. First, let's instantiate the feature extractor, tokenizers and model
|
192 |
+
# Note for distributed training, the .from_pretrained methods guarantee that only
|
193 |
+
# one local process can concurrently download model & vocab.
|
194 |
+
|
195 |
+
# load feature extractor
|
196 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
197 |
+
model_args.feature_extractor_name or model_args.model_name_or_path,
|
198 |
+
cache_dir=model_args.cache_dir,
|
199 |
+
token=data_args.token,
|
200 |
+
trust_remote_code=data_args.trust_remote_code,
|
201 |
+
)
|
202 |
+
sampling_rate = feature_extractor.sampling_rate
|
203 |
+
|
204 |
+
# load prompt tokenizer
|
205 |
+
prompt_tokenizer = AutoTokenizer.from_pretrained(
|
206 |
+
model_args.prompt_tokenizer_name or model_args.description_tokenizer_name or model_args.model_name_or_path,
|
207 |
+
cache_dir=model_args.cache_dir,
|
208 |
+
token=data_args.token,
|
209 |
+
trust_remote_code=data_args.trust_remote_code,
|
210 |
+
use_fast=model_args.use_fast_tokenizer,
|
211 |
+
padding_side=model_args.prompt_padding_side,
|
212 |
+
)
|
213 |
+
|
214 |
+
# load description tokenizer
|
215 |
+
description_tokenizer = AutoTokenizer.from_pretrained(
|
216 |
+
model_args.description_tokenizer_name or model_args.model_name_or_path,
|
217 |
+
cache_dir=model_args.cache_dir,
|
218 |
+
token=data_args.token,
|
219 |
+
trust_remote_code=data_args.trust_remote_code,
|
220 |
+
use_fast=model_args.use_fast_tokenizer,
|
221 |
+
)
|
222 |
+
|
223 |
+
if model_args.use_fast_tokenizer:
|
224 |
+
logger.warning(
|
225 |
+
"Disabling fast tokenizer warning: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3231-L3235"
|
226 |
+
)
|
227 |
+
prompt_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
|
228 |
+
description_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
|
229 |
+
|
230 |
+
# 2. Now, let's load the dataset
|
231 |
+
|
232 |
+
if data_args.save_to_disk is not None:
|
233 |
+
os.makedirs(data_args.save_to_disk, exist_ok=True)
|
234 |
+
|
235 |
+
# assume that the dataset has been saved to `save_to_disk` if the latter is not empty
|
236 |
+
dataset_was_precomputed = len(os.listdir(data_args.save_to_disk)) > 0
|
237 |
+
if dataset_was_precomputed:
|
238 |
+
with accelerator.local_main_process_first():
|
239 |
+
vectorized_datasets = datasets.load_from_disk(data_args.save_to_disk)
|
240 |
+
else:
|
241 |
+
raw_datasets = DatasetDict()
|
242 |
+
|
243 |
+
columns_to_keep = {
|
244 |
+
"target_audio_column_name": data_args.target_audio_column_name,
|
245 |
+
"prompt_column_name": data_args.prompt_column_name,
|
246 |
+
"source": data_args.source_column_name,
|
247 |
+
}
|
248 |
+
if data_args.description_column_name is not None:
|
249 |
+
columns_to_keep["description_column_name"] = data_args.description_column_name
|
250 |
+
|
251 |
+
if training_args.do_train:
|
252 |
+
raw_datasets["train"] = load_multiple_datasets(
|
253 |
+
accelerator,
|
254 |
+
data_args.train_dataset_name,
|
255 |
+
splits=data_args.train_split_name,
|
256 |
+
dataset_samples=data_args.train_dataset_samples,
|
257 |
+
seed=training_args.seed,
|
258 |
+
cache_dir=model_args.cache_dir,
|
259 |
+
num_proc=data_args.preprocessing_num_workers,
|
260 |
+
id_column_name=data_args.id_column_name,
|
261 |
+
columns_to_keep=columns_to_keep.values(),
|
262 |
+
prompt_column_name=data_args.prompt_column_name,
|
263 |
+
audio_column_name=data_args.target_audio_column_name,
|
264 |
+
sampling_rate=sampling_rate,
|
265 |
+
logger=logger,
|
266 |
+
mls_dir=data_args.mls_dir,
|
267 |
+
librittsrmix_dir=data_args.librittsrmix_dir,
|
268 |
+
gigaspeech_dir=data_args.gigaspeech_dir,
|
269 |
+
commonvoice_dir=data_args.commonvoice_dir,
|
270 |
+
emilia_dir=data_args.emilia_dir,
|
271 |
+
# streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
|
272 |
+
)
|
273 |
+
|
274 |
+
for key in columns_to_keep:
|
275 |
+
if columns_to_keep[key] not in raw_datasets["train"].column_names:
|
276 |
+
raise ValueError(
|
277 |
+
f"--{key} '{columns_to_keep[key]}' not found in dataset '{data_args.train_dataset_name}'."
|
278 |
+
f" Make sure to set `--{key}` to the correct audio column - one of"
|
279 |
+
f" {', '.join(raw_datasets['train'].column_names)}."
|
280 |
+
)
|
281 |
+
|
282 |
+
if data_args.max_train_samples is not None:
|
283 |
+
raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
|
284 |
+
|
285 |
+
if training_args.do_eval:
|
286 |
+
raw_datasets["eval"] = load_multiple_datasets(
|
287 |
+
accelerator,
|
288 |
+
data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
|
289 |
+
splits=data_args.eval_split_name,
|
290 |
+
cache_dir=model_args.cache_dir,
|
291 |
+
num_proc=data_args.preprocessing_num_workers,
|
292 |
+
id_column_name=data_args.id_column_name,
|
293 |
+
columns_to_keep=columns_to_keep.values(),
|
294 |
+
prompt_column_name=data_args.prompt_column_name,
|
295 |
+
audio_column_name=data_args.target_audio_column_name,
|
296 |
+
sampling_rate=sampling_rate,
|
297 |
+
logger=logger,
|
298 |
+
mls_dir=data_args.mls_dir,
|
299 |
+
librittsrmix_dir=data_args.librittsrmix_dir,
|
300 |
+
gigaspeech_dir=data_args.gigaspeech_dir,
|
301 |
+
commonvoice_dir=data_args.commonvoice_dir,
|
302 |
+
emilia_dir=data_args.emilia_dir
|
303 |
+
# streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
|
304 |
+
)
|
305 |
+
|
306 |
+
if data_args.max_eval_samples is not None:
|
307 |
+
with accelerator.local_main_process_first():
|
308 |
+
raw_datasets["eval"] = (
|
309 |
+
raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
|
310 |
+
)
|
311 |
+
|
312 |
+
# 3. Next, let's load the config.
|
313 |
+
config = ParlerTTSConfig.from_pretrained(
|
314 |
+
model_args.model_name_or_path,
|
315 |
+
cache_dir=model_args.cache_dir,
|
316 |
+
token=data_args.token,
|
317 |
+
trust_remote_code=data_args.trust_remote_code,
|
318 |
+
)
|
319 |
+
|
320 |
+
if training_args.codebook_weights is not None and len(training_args.codebook_weights) != config.decoder.num_codebooks:
|
321 |
+
raise ValueError(f"`codebook_weights` has length {len(training_args.codebook_weights)} when it should be of length {config.decoder.num_codebooks}.")
|
322 |
+
|
323 |
+
# update pad token id and decoder_start_token_id
|
324 |
+
config.decoder.update(
|
325 |
+
{
|
326 |
+
"cross_attention_implementation_strategy": model_args.cross_attention_implementation_strategy
|
327 |
+
if model_args.cross_attention_implementation_strategy is not None
|
328 |
+
else None,
|
329 |
+
"codebook_weights": training_args.codebook_weights if training_args.codebook_weights is not None else config.decoder.codebook_weights
|
330 |
+
}
|
331 |
+
)
|
332 |
+
config.update(
|
333 |
+
{
|
334 |
+
"pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id,
|
335 |
+
"decoder_start_token_id": model_args.decoder_start_token_id
|
336 |
+
if model_args.decoder_start_token_id is not None
|
337 |
+
else config.decoder_start_token_id,
|
338 |
+
}
|
339 |
+
)
|
340 |
+
|
341 |
+
with open("events.txt", "r") as f:
|
342 |
+
events = [line.strip() for line in f]
|
343 |
+
events = ["<"+event.lower().replace(" ", "_")+">" for event in events]
|
344 |
+
events.append("<B_start>")
|
345 |
+
events.append("<B_end>")
|
346 |
+
events.append("<I_start>")
|
347 |
+
events.append("<I_end>")
|
348 |
+
|
349 |
+
special_tokens = {"additional_special_tokens": events}
|
350 |
+
prompt_tokenizer.add_special_tokens(special_tokens)
|
351 |
+
description_tokenizer.add_special_tokens(special_tokens)
|
352 |
+
padded_vocab_size = ((len(prompt_tokenizer) + 127) // 128) * 128
|
353 |
+
config.vocab_size = padded_vocab_size
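# Illustrative sketch (not part of the training script): rounding the vocabulary up to
# a multiple of 128, as done above, only changes the size of the embedding matrix; the
# extra rows are never produced by the tokenizer, they just keep tensor shapes GPU-friendly.
_example_vocab_sizes = [100, 128, 129]
_example_padded_sizes = [((v + 127) // 128) * 128 for v in _example_vocab_sizes]
# -> [128, 128, 256]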
|
354 |
+
|
355 |
+
# create model
|
356 |
+
model = ParlerTTSForConditionalGeneration.from_pretrained(
|
357 |
+
model_args.model_name_or_path,
|
358 |
+
ignore_mismatched_sizes=True,
|
359 |
+
cache_dir=model_args.cache_dir,
|
360 |
+
config=config,
|
361 |
+
token=data_args.token,
|
362 |
+
trust_remote_code=data_args.trust_remote_code,
|
363 |
+
attn_implementation={"decoder": model_args.attn_implementation, "text_encoder": "eager"},
|
364 |
+
)
|
365 |
+
model.text_encoder.resize_token_embeddings(padded_vocab_size)
|
366 |
+
|
367 |
+
# enable gradient checkpointing if necessary
|
368 |
+
if training_args.gradient_checkpointing:
|
369 |
+
model.gradient_checkpointing_enable()
|
370 |
+
|
371 |
+
# 4. Now we preprocess the datasets including loading the audio, resampling and normalization
|
372 |
+
# Thankfully, `datasets` takes care of automatically loading and resampling the audio,
|
373 |
+
# so that we just need to set the correct target sampling rate and normalize the input
|
374 |
+
# via the `feature_extractor`
|
375 |
+
|
376 |
+
# derive max & min input length for sample rate & max duration
|
377 |
+
sampling_rate = feature_extractor.sampling_rate
|
378 |
+
max_target_length = int(data_args.max_duration_in_seconds * sampling_rate)
|
379 |
+
min_target_length = int(data_args.min_duration_in_seconds * sampling_rate)
|
380 |
+
target_audio_column_name = data_args.target_audio_column_name
|
381 |
+
description_column_name = data_args.description_column_name
|
382 |
+
prompt_column_name = data_args.prompt_column_name
|
383 |
+
feature_extractor_input_name = feature_extractor.model_input_names[0]
|
384 |
+
audio_encoder_pad_token_id = config.decoder.pad_token_id
|
385 |
+
audio_encoder_eos_token_id = config.decoder.eos_token_id
|
386 |
+
audio_encoder_bos_token_id = model.generation_config.decoder_start_token_id
|
387 |
+
max_length = model.generation_config.max_length
|
388 |
+
num_codebooks = model.decoder.config.num_codebooks
|
389 |
+
bandwidth = model_args.bandwidth
|
390 |
+
attn_implementation = model_args.attn_implementation
|
391 |
+
|
392 |
+
# Freeze Encoders
|
393 |
+
model.freeze_encoders(model_args.freeze_text_encoder)
|
394 |
+
|
395 |
+
# Test all gather - used for warmup and avoiding timeout
|
396 |
+
logger.debug(str(accelerator.process_index), main_process_only=False, in_order=True)
|
397 |
+
test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
|
398 |
+
gathered_tensor = accelerator.gather(test_tensor)
|
399 |
+
print("gathered_tensor", gathered_tensor)
|
400 |
+
accelerator.wait_for_everyone()
|
401 |
+
|
402 |
+
if not dataset_was_precomputed:
|
403 |
+
# Filter on text length
|
404 |
+
if description_column_name is not None and data_args.max_text_length is not None:
|
405 |
+
with accelerator.local_main_process_first():
|
406 |
+
# keep only descriptions shorter than max_text_length
|
407 |
+
raw_datasets = raw_datasets.filter(
|
408 |
+
lambda x: len(x) < data_args.max_text_length,
|
409 |
+
num_proc=num_workers,
|
410 |
+
input_columns=[description_column_name],
|
411 |
+
)
|
412 |
+
|
413 |
+
# Preprocessing the dataset.
|
414 |
+
# We need to tokenize the texts.
|
415 |
+
def pass_through_processors(description, prompt):
|
416 |
+
batch = {}
|
417 |
+
|
418 |
+
batch["input_ids"] = description_tokenizer(description.strip())["input_ids"]
|
419 |
+
batch["prompt_input_ids"] = prompt_tokenizer(prompt.strip())["input_ids"]
|
420 |
+
|
421 |
+
return batch
|
422 |
+
|
423 |
+
with accelerator.local_main_process_first():
|
424 |
+
# this is a trick to avoid rewriting the entire audio column, which takes ages
|
425 |
+
vectorized_datasets = raw_datasets.map(
|
426 |
+
pass_through_processors,
|
427 |
+
remove_columns=next(iter(raw_datasets.values())).column_names,
|
428 |
+
input_columns=[description_column_name, prompt_column_name],
|
429 |
+
num_proc=num_workers,
|
430 |
+
desc="preprocess datasets",
|
431 |
+
)
|
432 |
+
|
433 |
+
# We use Accelerate to perform distributed inference
|
434 |
+
# T5 doesn't support fp16
|
435 |
+
autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
|
436 |
+
|
437 |
+
# Now we encode the audio labels with encodec.
|
438 |
+
####### B. Encode audio
|
439 |
+
|
440 |
+
logger.info("*** Encode target audio with encodec ***")
|
441 |
+
|
442 |
+
# no need to prepare audio_decoder because it is only used for inference, without mixed precision
|
443 |
+
# see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
|
444 |
+
if training_args.torch_compile:
|
445 |
+
audio_decoder = accelerator.prepare_model(model.audio_encoder, evaluation_mode=True)
|
446 |
+
else:
|
447 |
+
audio_decoder = model.audio_encoder
|
448 |
+
|
449 |
+
encoder_data_collator = DataCollatorEncodecWithPadding(
|
450 |
+
feature_extractor,
|
451 |
+
audio_column_name=target_audio_column_name,
|
452 |
+
mls_dir=data_args.mls_dir,
|
453 |
+
librittsrmix_dir=data_args.librittsrmix_dir,
|
454 |
+
gigaspeech_dir=data_args.gigaspeech_dir,
|
455 |
+
commonvoice_dir=data_args.commonvoice_dir,
|
456 |
+
emilia_dir=data_args.emilia_dir,
|
457 |
+
feature_extractor_input_name=feature_extractor_input_name,
|
458 |
+
max_length=max_target_length,
|
459 |
+
padding=padding,
|
460 |
+
)
|
461 |
+
encoder_signature = set(inspect.signature(audio_decoder.forward).parameters)
|
462 |
+
|
463 |
+
def apply_audio_decoder(batch):
|
464 |
+
len_audio = batch.pop("len_audio")
|
465 |
+
audio_decoder.to(batch["input_values"].device).eval()
|
466 |
+
if bandwidth is not None:
|
467 |
+
batch["bandwidth"] = bandwidth
|
468 |
+
elif "num_quantizers" in encoder_signature:
|
469 |
+
batch["num_quantizers"] = num_codebooks
|
470 |
+
elif "num_codebooks" in encoder_signature:
|
471 |
+
batch["num_codebooks"] = num_codebooks
|
472 |
+
elif "n_quantizers" in encoder_signature:
|
473 |
+
batch["n_quantizers"] = num_codebooks
|
474 |
+
|
475 |
+
with torch.no_grad():
|
476 |
+
labels = audio_decoder.encode(**batch)["audio_codes"]
|
477 |
+
output = {}
|
478 |
+
output["len_audio"] = len_audio
|
479 |
+
# (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks)
|
480 |
+
output["labels"] = labels.squeeze(0).transpose(1, 2)
|
481 |
+
|
482 |
+
# if `pad_to_max_length`, the maximum corresponding audio length of the current batch is max_duration*sampling_rate
|
483 |
+
max_length = len_audio.max() if padding != "max_length" else max_target_length
|
484 |
+
output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / max_length
|
485 |
+
return output
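# Illustrative sketch (not part of the training script): the `ratio` returned above is
# used later to trim padded codec sequences back to each clip's true length. The numbers
# below are hypothetical.
_example_codec_frames, _example_padded_audio_len = 75, 24000
_example_ratio = _example_codec_frames / _example_padded_audio_len    # frames per audio sample
_example_true_audio_len = 16000                                       # an un-padded clip
_example_kept_frames = int(_example_ratio * _example_true_audio_len)  # -> 50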
|
486 |
+
|
487 |
+
# (1, codebooks, seq_len) where seq_len=1
|
488 |
+
bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id
|
489 |
+
|
490 |
+
def postprocess_dataset(labels):
|
491 |
+
# (1, codebooks, seq_len)
|
492 |
+
labels = torch.tensor(labels).unsqueeze(0)
|
493 |
+
# add bos
|
494 |
+
labels = torch.cat([bos_labels, labels], dim=-1)
|
495 |
+
|
496 |
+
labels, delay_pattern_mask = build_delay_pattern_mask(
|
497 |
+
labels,
|
498 |
+
bos_token_id=audio_encoder_bos_token_id,
|
499 |
+
pad_token_id=audio_encoder_eos_token_id,
|
500 |
+
max_length=labels.shape[-1] + num_codebooks,
|
501 |
+
num_codebooks=num_codebooks,
|
502 |
+
)
|
503 |
+
|
504 |
+
# the first ids of the delay pattern mask are precisely labels, we use the rest of the labels mask
|
505 |
+
# to take care of EOS
|
506 |
+
# we want labels to look like this:
|
507 |
+
# - [B, a, b, E, E, E, E]
|
508 |
+
# - [B, B, c, d, E, E, E]
|
509 |
+
# - [B, B, B, e, f, E, E]
|
510 |
+
# - [B, B, B, B, g, h, E]
|
511 |
+
labels = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask)
|
512 |
+
|
513 |
+
# the first timestamp is associated to a row full of BOS, let's get rid of it
|
514 |
+
# we also remove the last timestamps (full of PAD)
|
515 |
+
output = {"labels": labels[:, 1:]}
|
516 |
+
return output
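# Illustrative toy example (not part of the training script) of the delay pattern sketched
# in the comments above, for num_codebooks=3 and codes a..d (B = BOS, E = EOS): codebook k
# is shifted right by k positions so that, at each timestep, codebook k only depends on
# earlier timesteps of the other codebooks.
_B, _E = "B", "E"
_example_codes = ["a", "b", "c", "d"]
_example_delayed = [[_B] * (k + 1) + _example_codes + [_E] * (2 - k) for k in range(3)]
# -> [['B', 'a', 'b', 'c', 'd', 'E', 'E'],
#     ['B', 'B', 'a', 'b', 'c', 'd', 'E'],
#     ['B', 'B', 'B', 'a', 'b', 'c', 'd']]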
|
517 |
+
|
518 |
+
for split in vectorized_datasets:
|
519 |
+
data_loader = DataLoader(
|
520 |
+
raw_datasets[split],
|
521 |
+
batch_size=training_args.audio_encoder_per_device_batch_size,
|
522 |
+
collate_fn=encoder_data_collator,
|
523 |
+
num_workers=training_args.dataloader_num_workers,
|
524 |
+
pin_memory=True,
|
525 |
+
)
|
526 |
+
data_loader = accelerator.prepare(data_loader)
|
527 |
+
total_inference_steps = len(data_loader)
|
528 |
+
|
529 |
+
start_step = get_last_codec_checkpoint_step(os.path.join(data_args.temporary_save_to_disk, split))
|
530 |
+
accelerator.wait_for_everyone()
|
531 |
+
if start_step > 0:
|
532 |
+
logger.info(f"Resuming {split} from step {start_step}")
|
533 |
+
# efficiently skip the first n batches
|
534 |
+
start_step += 1
|
535 |
+
data_loader = skip_first_batches(data_loader, start_step)
|
536 |
+
|
537 |
+
all_generated_labels = []
|
538 |
+
all_lens = []
|
539 |
+
if start_step < total_inference_steps:
|
540 |
+
for i, batch in enumerate(tqdm(data_loader, disable=not accelerator.is_local_main_process)):
|
541 |
+
cur_step = start_step + i
|
542 |
+
generate_labels = apply_audio_decoder(batch)
|
543 |
+
generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0)
|
544 |
+
generate_labels = accelerator.gather_for_metrics(generate_labels)
|
545 |
+
|
546 |
+
if accelerator.is_main_process:
|
547 |
+
lab = generate_labels["labels"].cpu().transpose(1, 2).to(torch.int16)
|
548 |
+
rat = generate_labels["ratio"].cpu().squeeze(1)
|
549 |
+
lens = generate_labels["len_audio"].cpu().squeeze(1)
|
550 |
+
lab = [l[:, : int(ratio * length)] for (l, ratio, length) in zip(lab, rat, lens)]
|
551 |
+
|
552 |
+
all_generated_labels.extend(lab)
|
553 |
+
all_lens.extend(lens)
|
554 |
+
|
555 |
+
if ((cur_step + 1) % data_args.save_codec_steps == 0) or (
|
556 |
+
cur_step == total_inference_steps - 1
|
557 |
+
):
|
558 |
+
tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
|
559 |
+
tmp_labels = tmp_labels.map(
|
560 |
+
postprocess_dataset,
|
561 |
+
num_proc=data_args.preprocessing_num_workers, # this one is resource consuming if using many processes.
|
562 |
+
input_columns=["labels"],
|
563 |
+
desc="Postprocessing labeling",
|
564 |
+
)
|
565 |
+
save_codec_checkpoint(
|
566 |
+
os.path.join(data_args.temporary_save_to_disk, split), tmp_labels, cur_step
|
567 |
+
)
|
568 |
+
all_generated_labels = []
|
569 |
+
all_lens = []
|
570 |
+
|
571 |
+
accelerator.wait_for_everyone()
|
572 |
+
|
573 |
+
if accelerator.is_main_process and len(all_generated_labels) > 0:
|
574 |
+
tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
|
575 |
+
tmp_labels = tmp_labels.map(
|
576 |
+
postprocess_dataset,
|
577 |
+
num_proc=data_args.preprocessing_num_workers, # this one is resource consuming if using many processes.
|
578 |
+
input_columns=["labels"],
|
579 |
+
desc="Postprocessing labeling",
|
580 |
+
)
|
581 |
+
save_codec_checkpoint(os.path.join(data_args.temporary_save_to_disk, split), tmp_labels, cur_step)
|
582 |
+
all_generated_labels = []
|
583 |
+
all_lens = []
|
584 |
+
accelerator.wait_for_everyone()
|
585 |
+
|
586 |
+
del all_generated_labels
|
587 |
+
accelerator.wait_for_everyone()
|
588 |
+
|
589 |
+
with accelerator.local_main_process_first():
|
590 |
+
tmp_labels = load_all_codec_checkpoints(os.path.join(data_args.temporary_save_to_disk, split)).select(
|
591 |
+
range(len(vectorized_datasets[split]))
|
592 |
+
)
|
593 |
+
logger.info(f"Concatenating {split}: {tmp_labels} with {vectorized_datasets[split]}")
|
594 |
+
vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1)
|
595 |
+
|
596 |
+
accelerator.free_memory()
|
597 |
+
del generate_labels, all_lens
|
598 |
+
|
599 |
+
with accelerator.local_main_process_first():
|
600 |
+
# NOTE: filtering is done at the end because in the `datasets` library, caching audio files is done after most operations
|
601 |
+
# caching audio files is time and disk-space consuming, so we want to avoid it at all costs, especially for large (>1Kh) audio datasets.
|
602 |
+
# That's also why we avoid concatenating the processed datasets (vectorized_datasets) with the audio column present in raw_datasets.
|
603 |
+
|
604 |
+
def is_audio_in_length_range(length):
|
605 |
+
return length > min_target_length and length < max_target_length
|
606 |
+
|
607 |
+
# keep only audio whose length is between min_target_length and max_target_length
|
608 |
+
vectorized_datasets = vectorized_datasets.filter(
|
609 |
+
is_audio_in_length_range,
|
610 |
+
num_proc=num_workers,
|
611 |
+
input_columns=["target_length"],
|
612 |
+
)
|
613 |
+
|
614 |
+
if description_column_name is not None and data_args.max_description_token_length is not None:
|
615 |
+
with accelerator.local_main_process_first():
|
616 |
+
# keep only descriptions shorter than max_description_token_length
|
617 |
+
vectorized_datasets = vectorized_datasets.filter(
|
618 |
+
lambda x: len(x) < data_args.max_description_token_length,
|
619 |
+
num_proc=num_workers,
|
620 |
+
input_columns=["input_ids"],
|
621 |
+
)
|
622 |
+
|
623 |
+
if data_args.max_prompt_token_length is not None:
|
624 |
+
with accelerator.local_main_process_first():
|
625 |
+
# keep only prompts shorter than max_prompt_token_length
|
626 |
+
vectorized_datasets = vectorized_datasets.filter(
|
627 |
+
lambda x: len(x) < data_args.max_prompt_token_length,
|
628 |
+
num_proc=num_workers,
|
629 |
+
input_columns=["prompt_input_ids"],
|
630 |
+
)
|
631 |
+
|
632 |
+
if data_args.save_to_disk is not None and not dataset_was_precomputed:
|
633 |
+
if accelerator.is_main_process:
|
634 |
+
vectorized_datasets.save_to_disk(
|
635 |
+
data_args.save_to_disk,
|
636 |
+
num_proc=min(data_args.preprocessing_num_workers, len(vectorized_datasets["eval"]) - 1),
|
637 |
+
)
|
638 |
+
accelerator.wait_for_everyone()
|
639 |
+
logger.info(f"Dataset saved at {data_args.save_to_disk}")
|
640 |
+
|
641 |
+
audio_max_length = None
|
642 |
+
if padding == "max_length":
|
643 |
+
audio_max_length = max(vectorized_datasets["train"]["target_length"])
|
644 |
+
with accelerator.local_main_process_first():
|
645 |
+
max_sample = vectorized_datasets["train"].filter(
|
646 |
+
lambda x: x == audio_max_length,
|
647 |
+
num_proc=num_workers,
|
648 |
+
input_columns=["target_length"],
|
649 |
+
)
|
650 |
+
audio_max_length = max([len(l[0]) for l in max_sample["labels"]])
|
651 |
+
|
652 |
+
if description_column_name is not None and data_args.max_description_token_length is not None:
|
653 |
+
with accelerator.local_main_process_first():
|
654 |
+
# keep only descriptions shorter than max_description_token_length
|
655 |
+
vectorized_datasets = vectorized_datasets.filter(
|
656 |
+
lambda x: len(x) < data_args.max_description_token_length,
|
657 |
+
num_proc=num_workers,
|
658 |
+
input_columns=["input_ids"],
|
659 |
+
)
|
660 |
+
|
661 |
+
if data_args.max_prompt_token_length is not None:
|
662 |
+
with accelerator.local_main_process_first():
|
663 |
+
# keep only prompts shorter than max_prompt_token_length
|
664 |
+
vectorized_datasets = vectorized_datasets.filter(
|
665 |
+
lambda x: len(x) < data_args.max_prompt_token_length,
|
666 |
+
num_proc=num_workers,
|
667 |
+
input_columns=["prompt_input_ids"],
|
668 |
+
)
|
669 |
+
|
670 |
+
if training_args.group_by_length:
|
671 |
+
# apply a simple heuristic to take into account audio and text lengths
|
672 |
+
def add_target_lengths(target_length, prompt, description):
|
673 |
+
return {"target_length": target_length + len(prompt) + len(description)}
|
674 |
+
|
675 |
+
with accelerator.local_main_process_first():
|
676 |
+
vectorized_datasets = vectorized_datasets.map(
|
677 |
+
add_target_lengths,
|
678 |
+
num_proc=num_workers,
|
679 |
+
input_columns=["target_length", "prompt_input_ids", "input_ids"],
|
680 |
+
)
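# Illustrative sketch (not part of the training script): the combined length computed above
# is only a sort key, so that LengthGroupedSampler tends to put samples of similar total
# size in the same batch and padding is minimised. The rows below are hypothetical.
_example_rows = [
    {"target_length": 480, "prompt_len": 12, "desc_len": 20},
    {"target_length": 1500, "prompt_len": 40, "desc_len": 25},
    {"target_length": 500, "prompt_len": 10, "desc_len": 18},
]
_example_keys = [r["target_length"] + r["prompt_len"] + r["desc_len"] for r in _example_rows]
# -> [512, 1565, 528]; rows 0 and 2 would tend to end up in the same batch.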
|
681 |
+
|
682 |
+
# for large datasets it is advised to run the preprocessing on a
|
683 |
+
# single machine first with ``args.preprocessing_only`` since there will most likely
|
684 |
+
# be a timeout when running the script in distributed mode.
|
685 |
+
# In a second step ``args.preprocessing_only`` can then be set to `False` to load the
|
686 |
+
# cached dataset
|
687 |
+
if data_args.preprocessing_only and data_args.save_to_disk is None:
|
688 |
+
raise ValueError(
|
689 |
+
"`preprocessing_only=True` but `save_to_disk` is not set. The latter should indicates where to save the dataset locally."
|
690 |
+
)
|
691 |
+
elif data_args.preprocessing_only:
|
692 |
+
logger.info(f"Data preprocessing finished. Files save at {data_args.save_to_disk}")
|
693 |
+
return
|
694 |
+
|
695 |
+
# 6. Next, we can prepare the training.
|
696 |
+
|
697 |
+
# Let's use CLAP similarity and WER as our evaluation metrics,
|
698 |
+
def compute_metrics(
|
699 |
+
audios,
|
700 |
+
descriptions,
|
701 |
+
prompts,
|
702 |
+
device="cpu",
|
703 |
+
compute_clap_similarity_metric=False,
|
704 |
+
compute_noise_level_metric=False,
|
705 |
+
noise_level_to_compute_clean_wer=None,
|
706 |
+
):
|
707 |
+
results = {}
|
708 |
+
input_ids = descriptions
|
709 |
+
texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
|
710 |
+
prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
|
711 |
+
audios = [a.float().cpu().numpy() for a in audios]
|
712 |
+
|
713 |
+
if compute_clap_similarity_metric:
|
714 |
+
clap_score = clap_similarity(
|
715 |
+
model_args.clap_model_name_or_path, texts, audios, device, input_sampling_rate=sampling_rate
|
716 |
+
)
|
717 |
+
results["clap"] = clap_score
|
718 |
+
|
719 |
+
si_sdr_measures = None
|
720 |
+
if compute_noise_level_metric:
|
721 |
+
si_sdr_measures = si_sdr(audios, device, input_sampling_rate=sampling_rate)
|
722 |
+
|
723 |
+
word_error, transcriptions, clean_word_error, noisy_word_error, percent_clean_samples = wer(
|
724 |
+
model_args.asr_model_name_or_path,
|
725 |
+
prompts,
|
726 |
+
audios,
|
727 |
+
device,
|
728 |
+
training_args.per_device_eval_batch_size,
|
729 |
+
sampling_rate,
|
730 |
+
noise_level_to_compute_clean_wer,
|
731 |
+
si_sdr_measures,
|
732 |
+
)
|
733 |
+
results["wer"] = word_error
|
734 |
+
if clean_word_error is not None:
|
735 |
+
results["clean_wer"] = clean_word_error
|
736 |
+
results["noisy_word_error"] = noisy_word_error
|
737 |
+
results["percent_clean_samples"] = percent_clean_samples
|
738 |
+
|
739 |
+
return results, texts, prompts, audios, transcriptions, si_sdr_measures
|
740 |
+
|
741 |
+
# Define Training Schedule
|
742 |
+
# Store some constants
|
743 |
+
per_device_train_batch_size = int(training_args.per_device_train_batch_size)
|
744 |
+
train_batch_size = per_device_train_batch_size * accelerator.num_processes
|
745 |
+
gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
|
746 |
+
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
|
747 |
+
|
748 |
+
if training_args.max_steps < 0:
|
749 |
+
num_epochs = int(training_args.num_train_epochs)
|
750 |
+
steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
|
751 |
+
total_train_steps = steps_per_epoch * num_epochs
|
752 |
+
elif training_args.max_steps > 0:
|
753 |
+
logger.info("max_steps is given, it will override any value given in num_train_epochs")
|
754 |
+
total_train_steps = int(training_args.max_steps)
|
755 |
+
# Setting a very large number of epochs so we go as many times as necessary over the iterator.
|
756 |
+
num_epochs = sys.maxsize
|
757 |
+
steps_per_epoch = total_train_steps
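# Illustrative sketch (not part of the training script): with, say, 100_000 training rows,
# per_device_train_batch_size=4, 8 processes and gradient_accumulation_steps=2, each
# optimization step consumes 4 * 8 * 2 = 64 rows, so (hypothetical numbers):
_example_num_rows, _example_global_batch = 100_000, 4 * 8 * 2
_example_steps_per_epoch = _example_num_rows // _example_global_batch   # -> 1562
_example_total_train_steps = _example_steps_per_epoch * 3               # 3 epochs -> 4686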
|
758 |
+
|
759 |
+
if training_args.eval_steps is None:
|
760 |
+
logger.info(f"eval_steps is not set, evaluating at the end of each epoch")
|
761 |
+
eval_steps = steps_per_epoch
|
762 |
+
else:
|
763 |
+
eval_steps = training_args.eval_steps
|
764 |
+
|
765 |
+
if training_args.eval_generation_steps is None:
|
766 |
+
eval_generation_steps = eval_steps
|
767 |
+
else:
|
768 |
+
eval_generation_steps = training_args.eval_generation_steps
|
769 |
+
|
770 |
+
# T5 doesn't support fp16
|
771 |
+
autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
|
772 |
+
|
773 |
+
# Define optimizer, LR scheduler, collator
|
774 |
+
optimizer = torch.optim.AdamW(
|
775 |
+
params=model.parameters(),
|
776 |
+
lr=training_args.learning_rate,
|
777 |
+
betas=(training_args.adam_beta1, training_args.adam_beta2),
|
778 |
+
eps=training_args.adam_epsilon,
|
779 |
+
weight_decay=training_args.weight_decay,
|
780 |
+
)
|
781 |
+
|
782 |
+
# LR scheduler gets stepped by `num_processes` each time -> account for this in warmup / total steps
|
783 |
+
lr_scheduler = get_scheduler(
|
784 |
+
name=training_args.lr_scheduler_type,
|
785 |
+
optimizer=optimizer,
|
786 |
+
num_warmup_steps=training_args.get_warmup_steps(total_train_steps) * accelerator.num_processes,
|
787 |
+
num_training_steps=total_train_steps * accelerator.num_processes,
|
788 |
+
)
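# Illustrative sketch (not part of the training script): as the comment above notes, the
# prepared scheduler is stepped once per process for every optimization step, so the warmup
# and total lengths are multiplied by `num_processes` to keep the schedule aligned with the
# intended number of optimization steps (hypothetical numbers):
_example_num_processes, _example_warmup_steps, _example_train_steps = 8, 500, 40_000
_example_scheduler_warmup = _example_warmup_steps * _example_num_processes  # -> 4_000 scheduler ticks
_example_scheduler_total = _example_train_steps * _example_num_processes    # -> 320_000 scheduler ticks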
|
789 |
+
|
790 |
+
# Instantiate custom data collator
|
791 |
+
data_collator = DataCollatorParlerTTSWithPadding(
|
792 |
+
prompt_tokenizer=prompt_tokenizer,
|
793 |
+
description_tokenizer=description_tokenizer,
|
794 |
+
pad_to_multiple_of=data_args.pad_to_multiple_of,
|
795 |
+
padding=padding,
|
796 |
+
prompt_max_length=data_args.max_prompt_token_length,
|
797 |
+
description_max_length=data_args.max_description_token_length,
|
798 |
+
audio_max_length=audio_max_length,
|
799 |
+
)
|
800 |
+
|
801 |
+
# Prepare everything with accelerate
|
802 |
+
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
|
803 |
+
|
804 |
+
num_examples = total_train_steps * train_batch_size * gradient_accumulation_steps
|
805 |
+
logger.info("***** Running training *****")
|
806 |
+
logger.info(f" Num examples = {num_examples}")
|
807 |
+
logger.info(" Instantaneous batch size per device =" f" {per_device_train_batch_size}")
|
808 |
+
logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
|
809 |
+
logger.info(
|
810 |
+
f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
|
811 |
+
)
|
812 |
+
logger.info(f" Total optimization steps = {total_train_steps}")
|
813 |
+
|
814 |
+
# ======================== Training ================================
|
815 |
+
train_time = 0
|
816 |
+
train_start = time.time()
|
817 |
+
steps_trained_progress_bar = tqdm(
|
818 |
+
range(total_train_steps), desc="Train steps ... ", position=0, disable=not accelerator.is_local_main_process
|
819 |
+
)
|
820 |
+
continue_training = True
|
821 |
+
epochs_trained = 0
|
822 |
+
cur_step = 0
|
823 |
+
|
824 |
+
checkpoint = None
|
825 |
+
if training_args.resume_from_checkpoint is not None:
|
826 |
+
checkpoint = training_args.resume_from_checkpoint
|
827 |
+
elif last_checkpoint is not None:
|
828 |
+
checkpoint = last_checkpoint
|
829 |
+
|
830 |
+
if accelerator.is_main_process:
|
831 |
+
if training_args.push_to_hub:
|
832 |
+
api = HfApi(token=training_args.hub_token)
|
833 |
+
|
834 |
+
# Create repo (repo_name from args or inferred)
|
835 |
+
repo_name = training_args.hub_model_id
|
836 |
+
if repo_name is None:
|
837 |
+
repo_name = Path(training_args.output_dir).absolute().name
|
838 |
+
repo_id = api.create_repo(repo_name, exist_ok=True).repo_id
|
839 |
+
|
840 |
+
with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
|
841 |
+
if "wandb" not in gitignore:
|
842 |
+
gitignore.write("wandb\n")
|
843 |
+
elif training_args.output_dir is not None:
|
844 |
+
os.makedirs(training_args.output_dir, exist_ok=True)
|
845 |
+
accelerator.wait_for_everyone()
|
846 |
+
|
847 |
+
# Now save everything to be able to create a single processor later
|
848 |
+
# make sure all processes wait until data is saved
|
849 |
+
# only the main process saves them
|
850 |
+
if accelerator.is_main_process:
|
851 |
+
# save feature extractor, tokenizer and config
|
852 |
+
if (
|
853 |
+
model_args.prompt_tokenizer_name is None
|
854 |
+
and model_args.description_tokenizer_name
|
855 |
+
or (model_args.prompt_tokenizer_name == model_args.description_tokenizer_name)
|
856 |
+
):
|
857 |
+
prompt_tokenizer.save_pretrained(training_args.output_dir)
|
858 |
+
else:
|
859 |
+
logger.warning(
|
860 |
+
f"Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer."
|
861 |
+
)
|
862 |
+
prompt_tokenizer.save_pretrained(training_args.output_dir)
|
863 |
+
|
864 |
+
feature_extractor.save_pretrained(training_args.output_dir)
|
865 |
+
config.save_pretrained(training_args.output_dir)
|
866 |
+
accelerator.wait_for_everyone()
|
867 |
+
|
868 |
+
if checkpoint is not None:
|
869 |
+
accelerator.load_state(checkpoint)
|
870 |
+
# Find num steps and epoch from saved state string pattern
|
871 |
+
pattern = r"checkpoint-(\d+)-epoch-(\d+)"
|
872 |
+
match = re.search(pattern, checkpoint)
|
873 |
+
cur_step = int(match.group(1))
|
874 |
+
epochs_trained = int(match.group(2))
|
875 |
+
|
876 |
+
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
877 |
+
logger.info(f" Continuing training from epoch {epochs_trained}")
|
878 |
+
logger.info(f" Continuing training from global step {cur_step}")
|
879 |
+
|
880 |
+
steps_trained_progress_bar.update(cur_step)
|
881 |
+
|
882 |
+
for epoch in range(0, epochs_trained):
|
883 |
+
with accelerator.local_main_process_first():
|
884 |
+
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
|
885 |
+
|
886 |
+
if training_args.max_steps < 0:
|
887 |
+
# we know exactly the number of steps per epoch, so can skip through the required number of batches
|
888 |
+
resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
|
889 |
+
else:
|
890 |
+
# Currently we don't know how many steps we've taken in the current epoch
|
891 |
+
# So we just shuffle the dataset one extra time and start from a fresh epoch
|
892 |
+
# This is "good enough" for our purposes but not fully correct
|
893 |
+
resume_step = None
|
894 |
+
with accelerator.local_main_process_first():
|
895 |
+
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
|
896 |
+
else:
|
897 |
+
resume_step = None
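# Illustrative sketch (not part of the training script): resuming from a state named
# "checkpoint-3000-epoch-2" with steps_per_epoch=1250 and gradient_accumulation_steps=2
# means 3000 - 2 * 1250 = 500 optimization steps were already done in the current epoch,
# i.e. 500 * 2 = 1000 raw batches to skip, matching `resume_step` above (hypothetical numbers).
_example_cur_step, _example_epochs_trained = 3000, 2
_example_steps_per_epoch, _example_accum = 1250, 2
_example_resume_step = (_example_cur_step - _example_epochs_trained * _example_steps_per_epoch) * _example_accum
# -> 1000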
|
898 |
+
|
899 |
+
gen_kwargs = {
|
900 |
+
"do_sample": model_args.do_sample,
|
901 |
+
"temperature": model_args.temperature,
|
902 |
+
"max_length": model_args.max_length,
|
903 |
+
# Because of the delayed pattern mask, generation might stop earlier because of unexpected behaviour
|
904 |
+
# on the first tokens of the codebooks that are delayed.
|
905 |
+
# This fixes the issue.
|
906 |
+
"min_new_tokens": num_codebooks + 1,
|
907 |
+
}
|
908 |
+
|
909 |
+
# Define gradient update step fn
|
910 |
+
def train_step(
|
911 |
+
batch,
|
912 |
+
accelerator,
|
913 |
+
autocast_kwargs,
|
914 |
+
num_items_in_batch,
|
915 |
+
gradient_accumulation_steps,
|
916 |
+
):
|
917 |
+
if mixed_precision == "fp16":
|
918 |
+
# fp16 doesn't work with T5-like models
|
919 |
+
with accelerator.autocast(autocast_handler=autocast_kwargs):
|
920 |
+
if training_args.parallel_mode.value != "distributed":
|
921 |
+
encoder_outputs = model.text_encoder(
|
922 |
+
input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
|
923 |
+
)
|
924 |
+
else:
|
925 |
+
encoder_outputs = model.module.text_encoder(
|
926 |
+
input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
|
927 |
+
)
|
928 |
+
# we optionnally project last_hidden_state to avoid recomputing every time
|
929 |
+
encoder_hidden_states = encoder_outputs.last_hidden_state
|
930 |
+
if (
|
931 |
+
config.text_encoder.hidden_size != config.decoder.hidden_size
|
932 |
+
and config.decoder.cross_attention_hidden_size is None
|
933 |
+
):
|
934 |
+
encoder_hidden_states = (
|
935 |
+
model.enc_to_dec_proj(encoder_hidden_states)
|
936 |
+
if training_args.parallel_mode.value != "distributed"
|
937 |
+
else model.module.enc_to_dec_proj(encoder_hidden_states)
|
938 |
+
)
|
939 |
+
|
940 |
+
if batch.get("attention_mask", None) is not None:
|
941 |
+
encoder_hidden_states = encoder_hidden_states * batch.get("attention_mask", None)[..., None]
|
942 |
+
|
943 |
+
encoder_outputs.last_hidden_state = encoder_hidden_states
|
944 |
+
batch["encoder_outputs"] = encoder_outputs
|
945 |
+
|
946 |
+
outputs = model(**batch, loss_reduction="sum")
|
947 |
+
# CE (data) loss
|
948 |
+
ce_loss = (outputs.loss * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch
|
949 |
+
|
950 |
+
metrics = {"loss": ce_loss}
|
951 |
+
|
952 |
+
# per CE loss
|
953 |
+
per_codebook_losses = outputs.per_codebook_losses
|
954 |
+
metrics.update({f"codebook_{i}_loss": ((l * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch) for (i,l) in enumerate(per_codebook_losses)})
|
955 |
+
return ce_loss, metrics
|
956 |
+
|
957 |
+
# Define eval fn
|
958 |
+
def eval_step(
|
959 |
+
batch,
|
960 |
+
accelerator,
|
961 |
+
autocast_kwargs,
|
962 |
+
):
|
963 |
+
eval_model = model if not training_args.torch_compile else model._orig_mod
|
964 |
+
|
965 |
+
if mixed_precision == "fp16":
|
966 |
+
# fp16 doesn't work with T5-like models
|
967 |
+
with accelerator.autocast(autocast_handler=autocast_kwargs):
|
968 |
+
if training_args.parallel_mode.value != "distributed":
|
969 |
+
encoder_outputs = model.text_encoder(
|
970 |
+
input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
|
971 |
+
)
|
972 |
+
else:
|
973 |
+
encoder_outputs = model.module.text_encoder(
|
974 |
+
input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
|
975 |
+
)
|
976 |
+
# we optionnally project last_hidden_state to avoid recomputing every time
|
977 |
+
encoder_hidden_states = encoder_outputs.last_hidden_state
|
978 |
+
if (
|
979 |
+
config.text_encoder.hidden_size != config.decoder.hidden_size
|
980 |
+
and config.decoder.cross_attention_hidden_size is None
|
981 |
+
):
|
982 |
+
encoder_hidden_states = (
|
983 |
+
model.enc_to_dec_proj(encoder_hidden_states)
|
984 |
+
if training_args.parallel_mode.value != "distributed"
|
985 |
+
else model.module.enc_to_dec_proj(encoder_hidden_states)
|
986 |
+
)
|
987 |
+
|
988 |
+
if batch.get("attention_mask", None) is not None:
|
989 |
+
encoder_hidden_states = encoder_hidden_states * batch.get("attention_mask", None)[..., None]
|
990 |
+
|
991 |
+
encoder_outputs.last_hidden_state = encoder_hidden_states
|
992 |
+
batch["encoder_outputs"] = encoder_outputs
|
993 |
+
|
994 |
+
with torch.no_grad():
|
995 |
+
outputs = eval_model(**batch)
|
996 |
+
# CE (data) loss
|
997 |
+
ce_loss = outputs.loss
|
998 |
+
metrics = {"loss": ce_loss}
|
999 |
+
|
1000 |
+
# per CE loss
|
1001 |
+
per_codebook_losses = outputs.per_codebook_losses
|
1002 |
+
metrics.update({f"codebook_{i}_loss": l for (i,l) in enumerate(per_codebook_losses)})
|
1003 |
+
return metrics
|
1004 |
+
|
1005 |
+
def generate_step(batch, accelerator):
|
1006 |
+
batch.pop("decoder_attention_mask", None)
|
1007 |
+
eval_model = accelerator.unwrap_model(model, keep_fp32_wrapper=True)
|
1008 |
+
if training_args.torch_compile:
|
1009 |
+
# if the model is compiled, we use the original model bc compile is not compatible with .generate
|
1010 |
+
eval_model = model._orig_mod
|
1011 |
+
|
1012 |
+
# since we've might have loaded the weights in fp32, we have to autocast to ensure FA2 weights are in half-precision.
|
1013 |
+
# with accelerator.autocast(autocast_handler=AutocastKwargs(enabled=(attn_implementation=="flash_attention_2"))):
|
1014 |
+
output_audios = eval_model.generate(**batch, **gen_kwargs)
|
1015 |
+
output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
|
1016 |
+
return output_audios
|
1017 |
+
|
1018 |
+
model.train()
|
1019 |
+
|
1020 |
+
total_batched_samples = resume_step if resume_step is not None else 0
|
1021 |
+
for epoch in range(epochs_trained, num_epochs):
|
1022 |
+
with accelerator.local_main_process_first():
|
1023 |
+
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
|
1024 |
+
sampler = None
|
1025 |
+
if training_args.group_by_length:
|
1026 |
+
sampler = LengthGroupedSampler(train_batch_size, lengths=vectorized_datasets["train"]["target_length"])
|
1027 |
+
train_dataloader = DataLoader(
|
1028 |
+
vectorized_datasets["train"],
|
1029 |
+
collate_fn=data_collator,
|
1030 |
+
batch_size=per_device_train_batch_size,
|
1031 |
+
sampler=sampler,
|
1032 |
+
shuffle=not training_args.group_by_length,
|
1033 |
+
num_workers=training_args.dataloader_num_workers,
|
1034 |
+
pin_memory=training_args.dataloader_pin_memory,
|
1035 |
+
)
|
1036 |
+
train_dataloader = accelerator.prepare(train_dataloader)
|
1037 |
+
if hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDataset):
|
1038 |
+
train_dataloader.dataset.set_epoch(epoch)
|
1039 |
+
|
1040 |
+
if resume_step is not None:
|
1041 |
+
# Skip the first N batches in the dataloader when resuming from a checkpoint
|
1042 |
+
logger.info(f" Skip first {resume_step} batches")
|
1043 |
+
train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
|
1044 |
+
resume_step = None
|
1045 |
+
accelerator.wait_for_everyone()
|
1046 |
+
|
1047 |
+
# We chunkify the epoch iterator into gradient accumulation steps `n` batches
|
1048 |
+
train_iterator = iter(train_dataloader)
|
1049 |
+
num_steps_in_epoch = len(train_dataloader)
|
1050 |
+
remainder = num_steps_in_epoch % gradient_accumulation_steps
|
1051 |
+
remainder = remainder if remainder != 0 else gradient_accumulation_steps
|
1052 |
+
total_updates = math.ceil(num_steps_in_epoch / gradient_accumulation_steps)
|
1053 |
+
|
1054 |
+
update_step = -1
|
1055 |
+
for _ in range(total_updates):
|
1056 |
+
update_step += 1
|
1057 |
+
|
1058 |
+
# preload the total batch per step
|
1059 |
+
batch_samples = []
|
1060 |
+
num_batches_in_step = gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
|
1061 |
+
for _ in range(num_batches_in_step):
|
1062 |
+
batch_samples += [next(train_iterator)]
|
1063 |
+
|
1064 |
+
# get num items in batch - if different than BOS and than -100
|
1065 |
+
num_items_in_batch = sum([(batch["labels"].ne(audio_encoder_bos_token_id) | batch["labels"].ne(-100) | batch["labels"].ne(audio_encoder_eos_token_id)).sum((0,1))[0] for batch in batch_samples])
|
1066 |
+
num_items_in_batch = accelerator.gather(num_items_in_batch).sum().item()
|
1067 |
+
|
1068 |
+
# losses = []
|
1069 |
+
for i,batch in enumerate(batch_samples):
|
1070 |
+
total_batched_samples += 1
|
1071 |
+
ctx = model.no_sync if (i < len(batch_samples) - 1 and accelerator.num_processes > 1) else contextlib.nullcontext
|
1072 |
+
|
1073 |
+
with ctx():
|
1074 |
+
loss, train_metric = train_step(batch, accelerator, autocast_kwargs, num_items_in_batch, gradient_accumulation_steps)
|
1075 |
+
accelerator.backward(loss)
|
1076 |
+
# losses.append(loss.detach())
|
1077 |
+
|
1078 |
+
grad_norm = accelerator.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
|
1079 |
+
optimizer.step()
|
1080 |
+
lr_scheduler.step()
|
1081 |
+
optimizer.zero_grad()
|
1082 |
+
|
1083 |
+
# The accelerator has performed an optimization step behind the scenes
|
1084 |
+
steps_trained_progress_bar.update(1)
|
1085 |
+
cur_step += 1
|
1086 |
+
|
1087 |
+
# losses = accelerator.gather(sum(losses)).sum().item() / (accelerator.num_processes * gradient_accumulation_steps)
|
1088 |
+
|
1089 |
+
if cur_step % training_args.logging_steps == 0:
|
1090 |
+
steps_trained_progress_bar.write(
|
1091 |
+
f"Step... ({cur_step} / {total_train_steps} | Loss:"
|
1092 |
+
f" {train_metric['loss']}, Learning Rate:"
|
1093 |
+
f" {lr_scheduler.get_last_lr()[0]})"
|
1094 |
+
)
|
1095 |
+
train_metric["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
|
1096 |
+
log_metric(
|
1097 |
+
accelerator,
|
1098 |
+
metrics=train_metric,
|
1099 |
+
learning_rate=lr_scheduler.get_last_lr()[0],
|
1100 |
+
train_time=train_time + time.time() - train_start,
|
1101 |
+
step=cur_step,
|
1102 |
+
epoch=epoch,
|
1103 |
+
prefix="train",
|
1104 |
+
)
|
1105 |
+
|
1106 |
+
# save checkpoint and weights after each save_steps and at the end of training
|
1107 |
+
if (cur_step % training_args.save_steps == 0) or cur_step == total_train_steps:
|
1108 |
+
intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
|
1109 |
+
# safe_serialization=False to avoid shared tensors saving issue (TODO(YL): it's a temporary fix)
|
1110 |
+
# https://github.com/huggingface/transformers/issues/27293#issuecomment-1872560074
|
1111 |
+
accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
|
1112 |
+
accelerator.wait_for_everyone()
|
1113 |
+
if accelerator.is_main_process:
|
1114 |
+
rotate_checkpoints(
|
1115 |
+
training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger
|
1116 |
+
)
|
1117 |
+
|
1118 |
+
if cur_step == total_train_steps:
|
1119 |
+
# un-wrap student model for save
|
1120 |
+
unwrapped_model = accelerator.unwrap_model(model)
|
1121 |
+
unwrapped_model.save_pretrained(training_args.output_dir)
|
1122 |
+
|
1123 |
+
if training_args.push_to_hub:
|
1124 |
+
api.upload_folder(
|
1125 |
+
repo_id=repo_id,
|
1126 |
+
folder_path=training_args.output_dir,
|
1127 |
+
commit_message=f"Saving train state of step {cur_step}",
|
1128 |
+
run_as_future=True,
|
1129 |
+
)
|
1130 |
+
accelerator.wait_for_everyone()
|
1131 |
+
|
1132 |
+
if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
|
1133 |
+
train_time += time.time() - train_start
|
1134 |
+
# ======================== Evaluating ==============================
|
1135 |
+
model.eval()
|
1136 |
+
eval_metrics = []
|
1137 |
+
eval_preds = []
|
1138 |
+
eval_descriptions = []
|
1139 |
+
eval_prompts = []
|
1140 |
+
eval_start = time.time()
|
1141 |
+
|
1142 |
+
# release training input batch
|
1143 |
+
batch = release_memory(batch)
|
1144 |
+
|
1145 |
+
validation_dataloader = DataLoader(
|
1146 |
+
vectorized_datasets["eval"],
|
1147 |
+
collate_fn=data_collator,
|
1148 |
+
batch_size=per_device_eval_batch_size,
|
1149 |
+
drop_last=False,
|
1150 |
+
num_workers=training_args.eval_dataloader_num_workers,
|
1151 |
+
pin_memory=training_args.dataloader_pin_memory,
|
1152 |
+
)
|
1153 |
+
validation_dataloader = accelerator.prepare(validation_dataloader)
|
1154 |
+
|
1155 |
+
for batch in tqdm(
|
1156 |
+
validation_dataloader,
|
1157 |
+
desc=f"Evaluating - Inference ...",
|
1158 |
+
position=2,
|
1159 |
+
disable=not accelerator.is_local_main_process,
|
1160 |
+
):
|
1161 |
+
# Model forward
|
1162 |
+
eval_metric = eval_step(batch, accelerator, autocast_kwargs)
|
1163 |
+
eval_metric = accelerator.gather_for_metrics(eval_metric)
|
1164 |
+
eval_metric = {key: val.unsqueeze(0) if val.ndim == 0 else val for (key,val) in eval_metric.items()}
|
1165 |
+
eval_metrics.append(eval_metric)
|
1166 |
+
|
1167 |
+
if training_args.predict_with_generate and (cur_step % eval_generation_steps == 0 or cur_step == total_train_steps):
|
1168 |
+
validation_dataloader = DataLoader(
|
1169 |
+
vectorized_datasets["eval"],
|
1170 |
+
collate_fn=data_collator,
|
1171 |
+
batch_size=per_device_eval_batch_size,
|
1172 |
+
drop_last=False,
|
1173 |
+
num_workers=training_args.eval_dataloader_num_workers,
|
1174 |
+
pin_memory=training_args.dataloader_pin_memory,
|
1175 |
+
)
|
1176 |
+
validation_dataloader = accelerator.prepare(validation_dataloader)
|
1177 |
+
# generation
|
1178 |
+
for batch in tqdm(
|
1179 |
+
validation_dataloader,
|
1180 |
+
desc=f"Evaluating - Generation ...",
|
1181 |
+
position=2,
|
1182 |
+
disable=not accelerator.is_local_main_process,
|
1183 |
+
):
|
1184 |
+
generated_audios = generate_step(batch, accelerator)
|
1185 |
+
# Gather all predictions and targets
|
1186 |
+
generated_audios, input_ids, prompts = accelerator.pad_across_processes(
|
1187 |
+
(generated_audios, batch["input_ids"], batch["prompt_input_ids"]), dim=1, pad_index=0
|
1188 |
+
)
|
1189 |
+
generated_audios, input_ids, prompts = accelerator.gather_for_metrics(
|
1190 |
+
(generated_audios, input_ids, prompts)
|
1191 |
+
)
|
1192 |
+
eval_preds.extend(generated_audios.to("cpu"))
|
1193 |
+
eval_descriptions.extend(input_ids.to("cpu"))
|
1194 |
+
eval_prompts.extend(prompts.to("cpu"))
|
1195 |
+
|
1196 |
+
eval_time = time.time() - eval_start
|
1197 |
+
# normalize eval metrics
|
1198 |
+
eval_metrics = {
|
1199 |
+
key: torch.mean(torch.cat([d[key] for d in eval_metrics])).to("cpu") for key in eval_metrics[0]
|
1200 |
+
}
|
1201 |
+
|
1202 |
+
# compute metrics
|
1203 |
+
metrics_desc = ""
|
1204 |
+
if training_args.predict_with_generate and (cur_step % eval_generation_steps == 0 or cur_step == total_train_steps):
|
1205 |
+
if accelerator.is_local_main_process:
|
1206 |
+
(
|
1207 |
+
metric_values,
|
1208 |
+
pred_descriptions,
|
1209 |
+
pred_prompts,
|
1210 |
+
audios,
|
1211 |
+
transcriptions,
|
1212 |
+
si_sdr_measures,
|
1213 |
+
) = compute_metrics(
|
1214 |
+
eval_preds,
|
1215 |
+
eval_descriptions,
|
1216 |
+
eval_prompts,
|
1217 |
+
accelerator.device,
|
1218 |
+
training_args.compute_clap_similarity_metric,
|
1219 |
+
training_args.compute_noise_level_metric,
|
1220 |
+
training_args.noise_level_to_compute_clean_wer,
|
1221 |
+
)
|
1222 |
+
eval_metrics.update(metric_values)
|
1223 |
+
metrics_desc = " ".join([f"Eval {key}: {value} |" for key, value in metric_values.items()])
|
1224 |
+
if "wandb" in training_args.report_to:
|
1225 |
+
log_pred(
|
1226 |
+
accelerator,
|
1227 |
+
pred_descriptions,
|
1228 |
+
pred_prompts,
|
1229 |
+
transcriptions,
|
1230 |
+
audios,
|
1231 |
+
si_sdr_measures,
|
1232 |
+
sampling_rate=sampling_rate,
|
1233 |
+
step=cur_step,
|
1234 |
+
prefix="eval",
|
1235 |
+
)
|
1236 |
+
accelerator.wait_for_everyone()
|
1237 |
+
|
1238 |
+
# Print metrics and update progress bar
|
1239 |
+
if accelerator.is_local_main_process:
|
1240 |
+
steps_trained_progress_bar.write(
|
1241 |
+
f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |"
|
1242 |
+
f" {metrics_desc})"
|
1243 |
+
)
|
1244 |
+
|
1245 |
+
log_metric(
|
1246 |
+
accelerator,
|
1247 |
+
metrics=eval_metrics,
|
1248 |
+
train_time=eval_time,
|
1249 |
+
step=cur_step,
|
1250 |
+
epoch=epoch,
|
1251 |
+
prefix="eval",
|
1252 |
+
)
|
1253 |
+
|
1254 |
+
# release eval batch and relax metrics
|
1255 |
+
eval_metrics, eval_preds, eval_descriptions, eval_prompts, batch, eval_metric = release_memory(
|
1256 |
+
eval_metrics, eval_preds, eval_descriptions, eval_prompts, batch, eval_metric
|
1257 |
+
)
|
1258 |
+
if training_args.predict_with_generate and (cur_step % eval_generation_steps == 0 or cur_step == total_train_steps):
|
1259 |
+
generated_audios, input_ids, prompts = release_memory(generated_audios, input_ids, prompts)
|
1260 |
+
|
1261 |
+
# train mode
|
1262 |
+
model.train()
|
1263 |
+
|
1264 |
+
# flush the train metrics
|
1265 |
+
train_start = time.time()
|
1266 |
+
|
1267 |
+
# break condition
|
1268 |
+
if cur_step == total_train_steps:
|
1269 |
+
continue_training = False
|
1270 |
+
break
|
1271 |
+
|
1272 |
+
if not continue_training:
|
1273 |
+
break
|
1274 |
+
|
1275 |
+
accelerator.end_training()
|
1276 |
+
|
1277 |
+
|
1278 |
+
if __name__ == "__main__":
|
1279 |
+
main()
|
capspeech/ar/training/utils.py
ADDED
@@ -0,0 +1,203 @@
import os
import re
import shutil
from dataclasses import field
from pathlib import Path
from typing import Dict, List

import torch
from datasets import concatenate_datasets, load_from_disk
from wandb import Audio


def list_field(default=None, metadata=None):
    return field(default_factory=lambda: default, metadata=metadata)


_RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")
CHECKPOINT_CODEC_PREFIX = "checkpoint"
_RE_CODEC_CHECKPOINT = re.compile(r"^checkpoint-(\d+)$")


def get_last_checkpoint(folder):
    content = os.listdir(folder)
    checkpoints = [
        path
        for path in content
        if _RE_CHECKPOINT.search(path) is not None and os.path.isdir(os.path.join(folder, path))
    ]
    if len(checkpoints) == 0:
        return
    return os.path.join(folder, max(checkpoints, key=lambda x: int(_RE_CHECKPOINT.search(x).groups()[0])))


def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint") -> List[str]:
    """Helper function to sort saved checkpoints from oldest to newest."""
    ordering_and_checkpoint_path = []

    glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]

    for path in glob_checkpoints:
        regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
        if regex_match is not None and regex_match.groups() is not None:
            ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

    checkpoints_sorted = sorted(ordering_and_checkpoint_path)
    checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
    return checkpoints_sorted


def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix="checkpoint", logger=None) -> None:
    """Helper function to delete old checkpoints."""
    if save_total_limit is None or save_total_limit <= 0:
        return
    # Check if we should delete older checkpoint(s)
    checkpoints_sorted = sorted_checkpoints(output_dir=output_dir, checkpoint_prefix=checkpoint_prefix)
    if len(checkpoints_sorted) <= save_total_limit:
        return

    number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
    checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
    for checkpoint in checkpoints_to_be_deleted:
        logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
        shutil.rmtree(checkpoint, ignore_errors=True)


def save_codec_checkpoint(output_dir, dataset, step):
    checkpoint_path = f"{CHECKPOINT_CODEC_PREFIX}-{step}"
    output_path = os.path.join(output_dir, checkpoint_path)
    dataset.save_to_disk(output_path)


def load_codec_checkpoint(checkpoint_path):
    dataset = load_from_disk(checkpoint_path)
    return dataset


def sorted_codec_checkpoints(output_dir=None) -> List[str]:
    """Helper function to sort saved checkpoints from oldest to newest."""
    ordering_and_checkpoint_path = []

    glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{CHECKPOINT_CODEC_PREFIX}-*")]

    for path in glob_checkpoints:
        regex_match = re.match(f".*{CHECKPOINT_CODEC_PREFIX}-([0-9]+)", path)
        if regex_match is not None and regex_match.groups() is not None:
            ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

    checkpoints_sorted = sorted(ordering_and_checkpoint_path)
    checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
    return checkpoints_sorted


def load_all_codec_checkpoints(output_dir=None) -> List[str]:
    """Helper function to load and concat all checkpoints."""
    checkpoints_sorted = sorted_codec_checkpoints(output_dir=output_dir)
    datasets = [load_from_disk(checkpoint) for checkpoint in checkpoints_sorted]
    datasets = concatenate_datasets(datasets, axis=0)
    return datasets


def get_last_codec_checkpoint_step(folder) -> int:
    if not os.path.exists(folder) or not os.path.isdir(folder):
        os.makedirs(folder, exist_ok=True)
        return 0
    content = os.listdir(folder)
    checkpoints = [path for path in content if _RE_CODEC_CHECKPOINT.search(path) is not None]
    if len(checkpoints) == 0:
        return 0
    last_checkpoint = os.path.join(
        folder, max(checkpoints, key=lambda x: int(_RE_CODEC_CHECKPOINT.search(x).groups()[0]))
    )
    # Find num steps saved state string pattern
    pattern = r"checkpoint-(\d+)"
    match = re.search(pattern, last_checkpoint)
    cur_step = int(match.group(1))
    return cur_step


def log_metric(
    accelerator,
    metrics: Dict,
    train_time: float,
    step: int,
    epoch: int,
    learning_rate: float = None,
    prefix: str = "train",
):
    """Helper function to log all training/evaluation metrics with the correct prefixes and styling."""
    log_metrics = {}
    for k, v in metrics.items():
        if "codebook" in k:
            log_metrics[f"codebook_{prefix}/{k}"] = v
        else:
            log_metrics[f"{prefix}/{k}"] = v
    log_metrics[f"{prefix}/time"] = train_time
    log_metrics[f"{prefix}/epoch"] = epoch
    if learning_rate is not None:
        log_metrics[f"{prefix}/learning_rate"] = learning_rate
    accelerator.log(log_metrics, step=step)


def log_pred(
    accelerator,
    pred_descriptions: List[str],
    pred_prompts: List[str],
    transcriptions: List[str],
    audios: List[torch.Tensor],
    si_sdr_measures: List[float],
    sampling_rate: int,
    step: int,
    prefix: str = "eval",
    num_lines: int = 200000,
):
    """Helper function to log target/predicted transcriptions to weights and biases (wandb)."""
    if accelerator.is_main_process:
        wandb_tracker = accelerator.get_tracker("wandb")
        # pretty name for current step: step 50000 -> step 50k
        cur_step_pretty = f"{int(step // 1000)}k" if step > 1000 else step
        prefix_pretty = prefix.replace("/", "-")

        if si_sdr_measures is None:
            # convert str data to a wandb compatible format
            str_data = [
                [pred_descriptions[i], pred_prompts[i], transcriptions[i]] for i in range(len(pred_descriptions))
            ]
            # log as a table with the appropriate headers
            wandb_tracker.log_table(
                table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
                columns=["Target descriptions", "Target prompts", "Predicted transcriptions"],
                data=str_data[:num_lines],
                step=step,
                commit=False,
            )
        else:
            # convert str data to a wandb compatible format
            str_data = [
                [pred_descriptions[i], pred_prompts[i], transcriptions[i], si_sdr_measures[i]]
                for i in range(len(pred_descriptions))
            ]
            # log as a table with the appropriate headers
            wandb_tracker.log_table(
                table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
                columns=["Target descriptions", "Target prompts", "Predicted transcriptions", "Noise estimation"],
                data=str_data[:num_lines],
                step=step,
                commit=False,
            )

        # wandb can only load 100 audios per step
        wandb_tracker.log(
            {
                "Speech samples": [
                    Audio(
                        audio,
                        caption=f"{pred_prompts[i]} --- DESCRIPTION: {pred_descriptions[i]}",
                        sample_rate=sampling_rate,
                    )
                    for (i, audio) in enumerate(audios[: min(len(audios), 100)])
                ]
            },
            step=step,
        )
capspeech/eval/README.md
ADDED
@@ -0,0 +1,42 @@
# CapSpeech Evaluation Tools

## Get Started
Install dependencies:
```bash
conda create -n capeval python=3.9
conda activate capeval
pip install -r requirements.txt
pip install git+https://github.com/sarulab-speech/UTMOSv2.git
```

For ASR, we also need:
```bash
conda install ffmpeg
```

## Evaluate pitch, monotony, speed, age, gender
RUN:
```bash
python base_eval.py
```

## Evaluate UTMOSv2
RUN:
```bash
python mos_eval.py
```

## Evaluate ASR Results
RUN:
```bash
python asr_eval.py
```

## Evaluate emotion, accent
RUN:
```bash
cd src/example/
python categorized_emotion.py
python dialect_world_dialect.py
```
Please refer to [Vox-profile](https://github.com/tiantiaf0627/vox-profile-release.git) for more evaluation tools.
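As a rough illustration of how the ASR evaluation in the README above could be batched over a whole directory of synthesized audio (not part of the repository): the `score_dir` helper, the `outputs/wavs` folder, and the `refs.json` filename-to-transcript mapping below are hypothetical names chosen for the example; the WER computation reuses the same `jiwer` + Whisper normalizer recipe as `asr_eval.py`.

```python
# Minimal sketch: average WER over a folder of wavs, assuming refs.json maps filename -> ground-truth text.
import json
import os

from jiwer import wer as calculate_wer
from whisper.normalizers import EnglishTextNormalizer

from asr_eval import asr  # reuses the Whisper "large-v3-turbo" transcriber defined in asr_eval.py

normalizer = EnglishTextNormalizer()


def score_dir(wav_dir: str, refs_path: str) -> float:
    with open(refs_path) as f:
        refs = json.load(f)  # e.g. {"sample_001.wav": "Hey, how are you doing today?"}
    wers = []
    for name, gt_text in refs.items():
        pred = asr(os.path.join(wav_dir, name))  # already normalized inside asr_eval.asr
        wers.append(calculate_wer(normalizer(gt_text), pred))
    return sum(wers) / len(wers)


if __name__ == "__main__":
    print(score_dir("outputs/wavs", "outputs/refs.json"))
```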
capspeech/eval/__init__.py
ADDED
File without changes
capspeech/eval/age_gender.py
ADDED
@@ -0,0 +1,35 @@
import audeer
import audonnx
import numpy as np

def age_gender_apply(waveform):
    age_labels = ['child', 'teenager', 'young adult', 'middle-aged adult', 'elderly']
    gender_labels = ['female', 'male']
    url = 'https://zenodo.org/record/7761387/files/w2v2-L-robust-6-age-gender.25c844af-1.1.1.zip'
    cache_root = audeer.mkdir('cache')
    model_root = audeer.mkdir('model')
    sampling_rate = 16000
    archive_path = audeer.download_url(url, cache_root, verbose=True)
    audeer.extract_archive(archive_path, model_root)
    model = audonnx.load(model_root)

    result = model(waveform, sampling_rate)
    # Process age
    age_label = result['logits_age'].squeeze() * 100.0
    if age_label <= 12:
        age_label = 'child'
    elif age_label <= 19:
        age_label = 'teenager'
    elif age_label <= 39:
        age_label = 'young adult'
    elif age_label <= 64:
        age_label = 'middle-aged adult'
    else:
        age_label = 'elderly'

    # Process gender
    gender_label = result['logits_gender'].squeeze()
    gender_label = gender_label[:2]  # Remove child
    gender_label = np.argmax(gender_label)

    return age_label, gender_labels[gender_label]
capspeech/eval/asr_eval.py
ADDED
@@ -0,0 +1,24 @@
from jiwer import wer as calculate_wer
from jiwer import cer as calculate_cer
from whisper.normalizers import EnglishTextNormalizer
import whisper
import torch

normalizer = EnglishTextNormalizer()
device = "cuda" if torch.cuda.is_available() else "cpu"
whisper_model = whisper.load_model("large-v3-turbo", device=device)

def asr(wav_path):
    result = whisper_model.transcribe(wav_path)
    pred = result['text'].strip()
    pred = normalizer(pred)
    return pred

if __name__ == '__main__':
    gt_text = "Hey, how are you doing today? I like it."
    wav_path = "your-audio"
    gt_text = normalizer(gt_text.strip())
    pred_asr = asr(wav_path)
    wer = round(calculate_wer(gt_text, pred_asr), 3)
    cer = round(calculate_cer(gt_text, pred_asr), 3)
    print(wer, cer)
capspeech/eval/base_eval.py
ADDED
@@ -0,0 +1,32 @@
from pitch import pitch_apply
from speed import speed_apply
from age_gender import age_gender_apply
import librosa
import json
import bisect

SPEAKER_RATE_BINS = ["very slowly", "slowly", "slightly slowly", "moderate speed", "slightly fast", "fast", "very fast"]
UTTERANCE_LEVEL_STD = ["very monotone", "monotone", "slightly expressive and animated", "expressive and animated", "very expressive and animated"]
SPEAKER_LEVEL_PITCH_BINS = ["very low-pitch", "low-pitch", "slightly low-pitch", "moderate pitch", "slightly high-pitch", "high-pitch", "very high-pitch"]
with open("bin.json") as json_file:
    text_bins_dict = json.load(json_file)

audiopath = "YOUR_AUDIO_PATH"
waveform, _ = librosa.load(audiopath, sr=16000)
age, gender = age_gender_apply(waveform)
pitch_mean, pitch_std = pitch_apply(waveform)
if gender == "male":
    index = bisect.bisect_right(text_bins_dict["pitch_bins_male"], pitch_mean) - 1
    pitch = SPEAKER_LEVEL_PITCH_BINS[index]
else:
    index = bisect.bisect_right(text_bins_dict["pitch_bins_female"], pitch_mean) - 1
    pitch = SPEAKER_LEVEL_PITCH_BINS[index]

index = bisect.bisect_right(text_bins_dict["speech_monotony"], pitch_std) - 1
monotony = UTTERANCE_LEVEL_STD[index]
speech_duration = speed_apply(waveform)

index = bisect.bisect_right(text_bins_dict["speaking_rate"], speech_duration) - 1
speed = SPEAKER_RATE_BINS[index]

print(pitch, monotony, speed, age, gender)
capspeech/eval/bin.json
ADDED
@@ -0,0 +1,10 @@
{
    "speaking_rate": [0.0, 3.8258038258038254, 7.651607651607651, 11.477411477411476, 15.303215303215302, 19.129019129019127, 22.95482295482295, 26.78062678062678],
    "noise": [17.12751579284668, 25.4012325831822, 33.67494937351772, 41.94866616385323, 50.22238295418875, 58.49609974452427, 66.76981653485979, 75.04353332519531],
    "reverberation": [10, 35, 45, 55, 59, 60],
    "speech_monotony": [0.0, 20.37920924595424, 40.75841849190848, 70, 90, 142.6544647216797],
    "pitch_bins_male": [64.6531982421875, 81.66683959960938, 98.68048095703125, 115.69412231445312, 132.707763671875, 149.72140502929688, 166.73504638671875, 183.74868774414062],
    "pitch_bins_female": [120.17855072021484, 141.6242690945264, 163.06998746883795, 184.51570584314953, 205.96142421746106, 227.40714259177264, 248.8528609660842, 270.29857934039575],
    "si-sdr": [-17.804332733154297, -0.40644073486328125, 10, 20, 25, 28, 34.38934326171875],
    "pesq": [1, 1.7, 2.4, 3.1, 3.6, 4, 4.499948978424072]
}
capspeech/eval/pitch.py
ADDED
@@ -0,0 +1,30 @@
import torch
import penn

def pitch_apply(waveform):
    hopsize = .01
    fmin = 30.
    fmax = 1000.
    checkpoint = None
    center = 'half-hop'
    interp_unvoiced_at = .065
    sampling_rate = 16000
    penn_batch_size = 4096
    waveform = torch.Tensor(waveform).unsqueeze(0)
    pitch, periodicity = penn.from_audio(
        waveform.float(),
        sampling_rate,
        hopsize=hopsize,
        fmin=fmin,
        fmax=fmax,
        checkpoint=checkpoint,
        batch_size=penn_batch_size,
        center=center,
        interp_unvoiced_at=interp_unvoiced_at,
        gpu=None
    )

    pitch_mean = pitch.mean().cpu().numpy()
    pitch_std = pitch.std().cpu().numpy()

    return pitch_mean, pitch_std
capspeech/eval/requirements.txt
ADDED
@@ -0,0 +1,16 @@
datasets[audio]
https://github.com/marianne-m/brouhaha-vad/archive/main.zip
penn
g2p
demucs
transformers
bitsandbytes
git+https://github.com/sarulab-speech/UTMOSv2.git
openai-whisper
jiwer
numpy==1.26.4
audeer
audonnx
laion_clap
numpy==1.26.4
onnxruntime
capspeech/eval/speed.py
ADDED
@@ -0,0 +1,29 @@
from pyannote.audio import Model
from pathlib import Path
from brouhaha.pipeline import RegressiveActivityDetectionPipeline
import torch
from huggingface_hub import hf_hub_download
import numpy as np

def speed_apply(waveform):
    ratio = 16000/270
    sampling_rate = 16000
    device = "cpu"
    waveform = torch.Tensor(waveform).unsqueeze(0)
    model = Model.from_pretrained(
        Path(hf_hub_download(repo_id="ylacombe/brouhaha-best", filename="best.ckpt")),
        strict=False,
    )
    model.to(device)

    pipeline = RegressiveActivityDetectionPipeline(segmentation=model, batch_size=1)
    pipeline.to(torch.device(device))

    device = pipeline._models["segmentation"].device

    res = pipeline({"sample_rate": sampling_rate,
                    "waveform": waveform.to(device).float()})

    speech_duration = sum(map(lambda x: x[0].duration, res["annotation"].itertracks()))

    return speech_duration
capspeech/eval/src/__init__.py
ADDED
File without changes
capspeech/eval/src/example/__init__.py
ADDED
File without changes
capspeech/eval/src/example/categorized_emotion.py
ADDED
@@ -0,0 +1,92 @@
import torch
import logging
import sys, os, pdb
import torch.nn.functional as F

from pathlib import Path

sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1])))
sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1]), 'model', 'emotion'))

from wavlm_emotion import WavLMWrapper
from whisper_emotion import WhisperWrapper


# define logging console
logging.basicConfig(
    format='%(asctime)s %(levelname)-3s ==> %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S'
)

os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"


if __name__ == '__main__':

    label_list = [
        'Anger',
        'Contempt',
        'Disgust',
        'Fear',
        'Happiness',
        'Neutral',
        'Sadness',
        'Surprise',
        'Other'
    ]

    # Find device
    device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
    if torch.cuda.is_available(): print('GPU available, use GPU')

    # Define the model
    # Note that the ensemble yields better performance than a single model
    # Define the model wrapper
    model_path = "model"
    wavlm_model = model = WavLMWrapper(
        pretrain_model="wavlm_large",
        finetune_method="finetune",
        output_class_num=9,
        freeze_params=True,
        use_conv_output=True,
        detailed_class_num=17
    ).to(device)

    whisper_model = WhisperWrapper(
        pretrain_model="whisper_large",
        finetune_method="lora",
        lora_rank=16,
        output_class_num=9,
        freeze_params=True,
        use_conv_output=True,
        detailed_class_num=17
    ).to(device)

    whisper_model.load_state_dict(torch.load(os.path.join(model_path, f"whisper_emotion.pt"), weights_only=True), strict=False)
    whisper_model.load_state_dict(torch.load(os.path.join(model_path, f"whisper_emotion_lora.pt")), strict=False)
    wavlm_model.load_state_dict(torch.load(os.path.join(model_path, f"wavlm_emotion.pt"), weights_only=True), strict=False)

    wavlm_model.eval()
    whisper_model.eval()

    # Audio must be 16k Hz
    data = torch.zeros([1, 16000]).to(device)
    whisper_logits, whisper_embedding, _, _, _, _ = whisper_model(
        data, return_feature=True
    )
    wavlm_logits, wavlm_embedding, _, _, _, _ = wavlm_model(
        data, return_feature=True
    )

    ensemble_logits = (whisper_logits + wavlm_logits) / 2
    ensemble_prob = F.softmax(ensemble_logits, dim=1)

    print(ensemble_prob.shape)
    print(whisper_embedding.shape)
    print(wavlm_embedding.shape)
    print(label_list[torch.argmax(ensemble_prob).detach().cpu().item()])
capspeech/eval/src/example/dialect_world_dialect.py
ADDED
@@ -0,0 +1,87 @@
import torch
import sys, os, pdb
import argparse, logging
import torch.nn.functional as F

from pathlib import Path


sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1])))
sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1]), 'model', 'dialect'))

from wavlm_dialect import WavLMWrapper
from whisper_dialect import WhisperWrapper


# define logging console
logging.basicConfig(
    format='%(asctime)s %(levelname)-3s ==> %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S'
)

os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"


if __name__ == '__main__':

    label_list = [
        'East Asia', 'English', 'Germanic', 'Irish',
        'North America', 'Northern Irish', 'Oceania',
        'Other', 'Romance', 'Scottish', 'Semitic', 'Slavic',
        'South African', 'Southeast Asia', 'South Asia', 'Welsh'
    ]

    # Find device
    device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
    if torch.cuda.is_available(): print('GPU available, use GPU')

    # Define the model
    # Note that the ensemble yields better performance than a single model
    model_path = "YOUR_PATH"
    # Define the model wrapper
    wavlm_model = model = WavLMWrapper(
        pretrain_model="wavlm_large",
        finetune_method="lora",
        lora_rank=16,
        output_class_num=16,
        freeze_params=False,
        use_conv_output=True,
        apply_gradient_reversal=False,
        num_dataset=3
    ).to(device)

    whisper_model = WhisperWrapper(
        pretrain_model="whisper_large",
        finetune_method="lora",
        lora_rank=16,
        output_class_num=16,
        freeze_params=False,
        use_conv_output=True,
        apply_gradient_reversal=False,
        num_dataset=11
    ).to(device)

    wavlm_model.load_state_dict(torch.load(os.path.join(model_path, f"wavlm_world_dialect.pt"), weights_only=True), strict=False)
    wavlm_model.load_state_dict(torch.load(os.path.join(model_path, f"wavlm_world_dialect_lora.pt")), strict=False)

    whisper_model.load_state_dict(torch.load(os.path.join(model_path, f"whisper_world_dialect.pt"), weights_only=True), strict=False)
    whisper_model.load_state_dict(torch.load(os.path.join(model_path, f"whisper_world_dialect_lora.pt")), strict=False)

    wavlm_model.eval()
    whisper_model.eval()

    data = torch.zeros([1, 16000]).to(device)
    wavlm_logits, wavlm_embeddings = wavlm_model(data, return_feature=True)
    whisper_logits, whisper_embeddings = whisper_model(data, return_feature=True)

    ensemble_logits = (wavlm_logits + whisper_logits) / 2
    ensemble_prob = F.softmax(ensemble_logits, dim=1)

    pred = label_list[ensemble_prob.argmax(-1)]
    print(pred)
capspeech/eval/src/model/__init__.py
ADDED
File without changes
capspeech/eval/src/model/adapter.py
ADDED
@@ -0,0 +1,73 @@
# --------------------------------------------------------
# References:
# https://github.com/jxhe/unify-parameter-efficient-tuning
# --------------------------------------------------------

import math
import torch
import torch.nn as nn


class Adapter(nn.Module):
    def __init__(
        self,
        config=None,
        d_model=768,
        bottleneck=None,
        dropout=0.0,
        init_option="lora",
        adapter_scalar="1.0",
        adapter_layernorm_option="none"
    ):
        super().__init__()
        self.n_embd = config.d_model if d_model is None else d_model
        self.down_size = config.attn_bn if bottleneck is None else bottleneck

        # _before
        self.adapter_layernorm_option = adapter_layernorm_option

        self.adapter_layer_norm_before = None
        if adapter_layernorm_option == "in" or adapter_layernorm_option == "out":
            self.adapter_layer_norm_before = nn.LayerNorm(self.n_embd)

        if adapter_scalar == "learnable_scalar":
            self.scale = nn.Parameter(torch.ones(1))
        else:
            self.scale = float(adapter_scalar)

        self.down_proj = nn.Linear(self.n_embd, self.down_size)
        self.non_linear_func = nn.ReLU()
        self.up_proj = nn.Linear(self.down_size, self.n_embd)

        self.dropout = dropout
        if init_option == "bert":
            raise NotImplementedError
        elif init_option == "lora":
            with torch.no_grad():
                nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
                nn.init.zeros_(self.up_proj.weight)
                nn.init.zeros_(self.down_proj.bias)
                nn.init.zeros_(self.up_proj.bias)

    def forward(self, x, add_residual=True, residual=None):
        residual = x if residual is None else residual
        if self.adapter_layernorm_option == 'in':
            x = self.adapter_layer_norm_before(x)

        down = self.down_proj(x)

        down = self.non_linear_func(down)
        down = nn.functional.dropout(down, p=self.dropout, training=self.training)
        up = self.up_proj(down)

        up = up * self.scale

        if self.adapter_layernorm_option == 'out':
            up = self.adapter_layer_norm_before(up)

        if add_residual:
            output = up + residual
        else:
            output = up

        return output
capspeech/eval/src/model/dialect/__init__.py
ADDED
File without changes
capspeech/eval/src/model/dialect/wavlm_dialect.py
ADDED
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pdb
|
3 |
+
import copy
|
4 |
+
import torch
|
5 |
+
import argparse
|
6 |
+
import loralib as lora
|
7 |
+
import transformers.models.wavlm.modeling_wavlm as wavlm
|
8 |
+
from speechbrain.nnet.normalization import LayerNorm
|
9 |
+
from speechbrain.lobes.models.huggingface_transformers.huggingface import make_padding_masks
|
10 |
+
|
11 |
+
from torch import nn
|
12 |
+
from torch.nn import functional as F
|
13 |
+
from transformers import Wav2Vec2FeatureExtractor
|
14 |
+
from transformers import WavLMModel
|
15 |
+
|
16 |
+
import sys
|
17 |
+
from pathlib import Path
|
18 |
+
sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1])))
|
19 |
+
from revgrad import RevGrad
|
20 |
+
|
21 |
+
class WavLMEncoderLayer(nn.Module):
|
22 |
+
def __init__(self, layer_idx, config, has_relative_position_bias: bool = True):
|
23 |
+
super().__init__()
|
24 |
+
self.attention = wavlm.WavLMAttention(
|
25 |
+
embed_dim=config.hidden_size,
|
26 |
+
num_heads=config.num_attention_heads,
|
27 |
+
dropout=config.attention_dropout,
|
28 |
+
num_buckets=config.num_buckets,
|
29 |
+
max_distance=config.max_bucket_distance,
|
30 |
+
has_relative_position_bias=has_relative_position_bias,
|
31 |
+
)
|
32 |
+
self.dropout = nn.Dropout(config.hidden_dropout)
|
33 |
+
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
34 |
+
self.feed_forward = wavlm.WavLMFeedForward(config)
|
35 |
+
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
36 |
+
self.config = config
|
37 |
+
|
38 |
+
if layer_idx > config.num_hidden_layers // 2:
|
39 |
+
if self.config.finetune_method == "lora" or self.config.finetune_method == "combined":
|
40 |
+
self.feed_forward.intermediate_dense = lora.Linear(config.hidden_size, config.intermediate_size, r=config.lora_rank)
|
41 |
+
self.feed_forward.output_dense = lora.Linear(config.intermediate_size, config.hidden_size, r=config.lora_rank)
|
42 |
+
|
43 |
+
def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
|
44 |
+
|
45 |
+
attn_residual = hidden_states
|
46 |
+
hidden_states, attn_weights, position_bias = self.attention(
|
47 |
+
hidden_states,
|
48 |
+
attention_mask=attention_mask,
|
49 |
+
position_bias=position_bias,
|
50 |
+
output_attentions=output_attentions,
|
51 |
+
index=index,
|
52 |
+
)
|
53 |
+
hidden_states = self.dropout(hidden_states)
|
54 |
+
hidden_states = attn_residual + hidden_states
|
55 |
+
|
56 |
+
hidden_states = self.layer_norm(hidden_states)
|
57 |
+
hidden_states = hidden_states + self.feed_forward(hidden_states)
|
58 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
59 |
+
outputs = (hidden_states, position_bias)
|
60 |
+
|
61 |
+
if output_attentions:
|
62 |
+
outputs += (attn_weights,)
|
63 |
+
|
64 |
+
return outputs
|
65 |
+
|
66 |
+
|
67 |
+
class WavLMEncoderLayerStableLayerNorm(nn.Module):
|
68 |
+
def __init__(self, layer_idx, config, has_relative_position_bias: bool = True):
|
69 |
+
super().__init__()
|
70 |
+
self.attention = wavlm.WavLMAttention(
|
71 |
+
embed_dim=config.hidden_size,
|
72 |
+
num_heads=config.num_attention_heads,
|
73 |
+
dropout=config.attention_dropout,
|
74 |
+
num_buckets=config.num_buckets,
|
75 |
+
max_distance=config.max_bucket_distance,
|
76 |
+
has_relative_position_bias=has_relative_position_bias,
|
77 |
+
)
|
78 |
+
self.dropout = nn.Dropout(config.hidden_dropout)
|
79 |
+
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
80 |
+
self.feed_forward = wavlm.WavLMFeedForward(config)
|
81 |
+
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
82 |
+
self.config = config
|
83 |
+
|
84 |
+
if layer_idx > config.num_hidden_layers // 2:
|
85 |
+
if self.config.finetune_method == "lora" or self.config.finetune_method == "combined":
|
86 |
+
self.feed_forward.intermediate_dense = lora.Linear(config.hidden_size, config.intermediate_size, r=config.lora_rank)
|
87 |
+
self.feed_forward.output_dense = lora.Linear(config.intermediate_size, config.hidden_size, r=config.lora_rank)
|
88 |
+
|
89 |
+
|
90 |
+
def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False):
|
91 |
+
attn_residual = hidden_states
|
92 |
+
hidden_states = self.layer_norm(hidden_states)
|
93 |
+
hidden_states, attn_weights, position_bias = self.attention(
|
94 |
+
hidden_states,
|
95 |
+
attention_mask=attention_mask,
|
96 |
+
position_bias=position_bias,
|
97 |
+
output_attentions=output_attentions,
|
98 |
+
)
|
99 |
+
hidden_states = self.dropout(hidden_states)
|
100 |
+
hidden_states = attn_residual + hidden_states
|
101 |
+
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
|
102 |
+
|
103 |
+
outputs = (hidden_states, position_bias)
|
104 |
+
|
105 |
+
if output_attentions:
|
106 |
+
outputs += (attn_weights,)
|
107 |
+
|
108 |
+
return outputs
|
109 |
+
|
110 |
+
|
111 |
+
class WavLMWrapper(nn.Module):
|
112 |
+
def __init__(
|
113 |
+
self,
|
114 |
+
pretrain_model="wavlm_large",
|
115 |
+
hidden_dim=256,
|
116 |
+
finetune_method="lora",
|
117 |
+
lora_rank=16,
|
118 |
+
freeze_params=True,
|
119 |
+
output_class_num=4,
|
120 |
+
use_conv_output=True,
|
121 |
+
apply_gradient_reversal=False,
|
122 |
+
num_dataset=4
|
123 |
+
):
|
124 |
+
super(WavLMWrapper, self).__init__()
|
125 |
+
# 1. We Load the model first with weights
|
126 |
+
if pretrain_model == "wavlm":
|
127 |
+
self.backbone_model = WavLMModel.from_pretrained(
|
128 |
+
"microsoft/wavlm-base-plus",
|
129 |
+
output_hidden_states=True,
|
130 |
+
)
|
131 |
+
elif pretrain_model == "wavlm_large":
|
132 |
+
self.processor = Wav2Vec2FeatureExtractor.from_pretrained('microsoft/wavlm-large')
|
133 |
+
self.backbone_model = WavLMModel.from_pretrained(
|
134 |
+
"microsoft/wavlm-large",
|
135 |
+
output_hidden_states=True,
|
136 |
+
)
|
137 |
+
self.pretrain_model = pretrain_model
|
138 |
+
self.finetune_method = finetune_method
|
139 |
+
self.apply_gradient_reversal = apply_gradient_reversal
|
140 |
+
self.use_conv_output = use_conv_output
|
141 |
+
|
142 |
+
state_dict = self.backbone_model.state_dict()
|
143 |
+
# 2. Read the model config
|
144 |
+
self.model_config = self.backbone_model.config
|
145 |
+
self.model_config.finetune_method = finetune_method
|
146 |
+
self.model_config.lora_rank = lora_rank
|
147 |
+
|
148 |
+
# 3. Config encoder layers with adapter or embedding prompt
|
149 |
+
if self.pretrain_model == "wavlm":
|
150 |
+
self.backbone_model.encoder.layers = nn.ModuleList(
|
151 |
+
[WavLMEncoderLayer(i, self.model_config, has_relative_position_bias=(i == 0)) for i in range(self.model_config.num_hidden_layers)]
|
152 |
+
)
|
153 |
+
elif self.pretrain_model == "wavlm_large":
|
154 |
+
self.backbone_model.encoder.layers = nn.ModuleList(
|
155 |
+
[WavLMEncoderLayerStableLayerNorm(i, self.model_config, has_relative_position_bias=(i == 0)) for i in range(self.model_config.num_hidden_layers)]
|
156 |
+
)
|
157 |
+
# 4. Load the weights back
|
158 |
+
msg = self.backbone_model.load_state_dict(state_dict, strict=False)
|
159 |
+
|
160 |
+
# 5. Freeze the weights
|
161 |
+
self.freeze_params = freeze_params
|
162 |
+
if self.freeze_params and self.finetune_method != "lora":
|
163 |
+
for _, p in self.backbone_model.named_parameters(): p.requires_grad = False
|
164 |
+
elif self.freeze_params and self.finetune_method == "lora":
|
165 |
+
for name, p in self.backbone_model.named_parameters():
|
166 |
+
if name in msg.missing_keys: p.requires_grad = True
|
167 |
+
else: p.requires_grad = False
|
168 |
+
else:
|
169 |
+
for _, p in self.backbone_model.named_parameters(): p.requires_grad = True
|
170 |
+
|
171 |
+
# 6. Downstream models
|
172 |
+
self.model_seq = nn.Sequential(
|
173 |
+
nn.Conv1d(self.model_config.hidden_size, hidden_dim, 1, padding=0),
|
174 |
+
nn.ReLU(),
|
175 |
+
nn.Dropout(p=0.1),
|
176 |
+
nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0),
|
177 |
+
nn.ReLU(),
|
178 |
+
nn.Dropout(p=0.1),
|
179 |
+
nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0)
|
180 |
+
)
|
181 |
+
|
182 |
+
if self.use_conv_output:
|
183 |
+
num_layers = self.model_config.num_hidden_layers + 1 # transformer layers + input embeddings
|
184 |
+
self.weights = nn.Parameter(torch.ones(num_layers)/num_layers)
|
185 |
+
else:
|
186 |
+
num_layers = self.model_config.num_hidden_layers
|
187 |
+
self.weights = nn.Parameter(torch.zeros(num_layers))
|
188 |
+
|
189 |
+
if apply_gradient_reversal:
|
190 |
+
self.dataset_layer = nn.Sequential(
|
191 |
+
RevGrad(),
|
192 |
+
nn.Linear(hidden_dim, hidden_dim),
|
193 |
+
nn.ReLU(),
|
194 |
+
nn.Linear(hidden_dim, num_dataset),
|
195 |
+
)
|
196 |
+
|
197 |
+
self.out_layer = nn.Sequential(
|
198 |
+
nn.Linear(hidden_dim, hidden_dim),
|
199 |
+
nn.ReLU(),
|
200 |
+
nn.Linear(hidden_dim, output_class_num),
|
201 |
+
)
|
202 |
+
|
203 |
+
def forward(self, x, length=None, return_feature=False):
|
204 |
+
# 1. feature extraction and projections
|
205 |
+
if self.pretrain_model == "wavlm_large":
|
206 |
+
with torch.no_grad():
|
207 |
+
signal, attention_mask = list(), list()
|
208 |
+
if length is not None: attention_mask = make_padding_masks(x, wav_len=length/length.max()).to(x.device)
|
209 |
+
else: attention_mask = make_padding_masks(x, wav_len=torch.tensor([1]).to(x.device)).to(x.device)
|
210 |
+
|
211 |
+
for idx in range(len(x)):
|
212 |
+
input = self.processor(x[idx], sampling_rate=16_000, return_tensors="pt", padding=True)
|
213 |
+
signal.append(input["input_values"][0].to(x.device))
|
214 |
+
signal = torch.stack(signal)
|
215 |
+
|
216 |
+
# 2. get length and mask
|
217 |
+
if length is not None:
|
218 |
+
length = self.get_feat_extract_output_lengths(length.detach().cpu())
|
219 |
+
length = length.cuda()
|
220 |
+
|
221 |
+
if self.pretrain_model == "wavlm":
|
222 |
+
x = self.backbone_model(
|
223 |
+
x, output_hidden_states=True
|
224 |
+
).hidden_states
|
225 |
+
else:
|
226 |
+
x = self.backbone_model(
|
227 |
+
signal,
|
228 |
+
attention_mask=attention_mask,
|
229 |
+
output_hidden_states=True
|
230 |
+
).hidden_states
|
231 |
+
|
232 |
+
# 4. stacked feature
|
233 |
+
if self.use_conv_output: stacked_feature = torch.stack(x, dim=0)
|
234 |
+
else: stacked_feature = torch.stack(x, dim=0)[1:]
|
235 |
+
|
236 |
+
# 5. Weighted sum
|
237 |
+
_, *origin_shape = stacked_feature.shape
|
238 |
+
# Return transformer enc outputs [num_enc_layers, B, T, D]
|
239 |
+
if self.use_conv_output:
|
240 |
+
stacked_feature = stacked_feature.view(self.backbone_model.config.num_hidden_layers+1, -1)
|
241 |
+
else:
|
242 |
+
stacked_feature = stacked_feature.view(self.backbone_model.config.num_hidden_layers, -1)
|
243 |
+
norm_weights = F.softmax(self.weights, dim=-1)
|
244 |
+
|
245 |
+
# Perform weighted average
|
246 |
+
weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0)
|
247 |
+
features = weighted_feature.view(*origin_shape)
|
248 |
+
|
249 |
+
# 6. Pass the weighted average to point-wise 1D Conv
|
250 |
+
# B x T x D
|
251 |
+
features = features.transpose(1, 2)
|
252 |
+
features = self.model_seq(features)
|
253 |
+
features = features.transpose(1, 2)
|
254 |
+
|
255 |
+
# 7. Pooling
|
256 |
+
if length is not None:
|
257 |
+
mean, std = list(), list()
|
258 |
+
for snt_id in range(features.shape[0]):
|
259 |
+
# Avoiding padded time steps
|
260 |
+
actual_size = length[snt_id]
|
261 |
+
mean.append(torch.mean(features[snt_id, 0:actual_size, ...], dim=0))
|
262 |
+
features = torch.stack(mean)
|
263 |
+
else:
|
264 |
+
features = torch.mean(features, dim=1)
|
265 |
+
|
266 |
+
# 8. Output predictions
|
267 |
+
# B x D
|
268 |
+
predicted = self.out_layer(features)
|
269 |
+
if self.apply_gradient_reversal:
|
270 |
+
dataset_predicted = self.dataset_layer(features)
|
271 |
+
if return_feature: return predicted, dataset_predicted, features
|
272 |
+
return predicted, dataset_predicted
|
273 |
+
if return_feature: return predicted, features
|
274 |
+
return predicted
|
275 |
+
|
276 |
+
# From huggingface
|
277 |
+
def get_feat_extract_output_lengths(self, input_length):
|
278 |
+
"""
|
279 |
+
Computes the output length of the convolutional layers
|
280 |
+
"""
|
281 |
+
def _conv_out_length(input_length, kernel_size, stride):
|
282 |
+
# 1D convolutional layer output length formula taken
|
283 |
+
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
284 |
+
return (input_length - kernel_size) // stride + 1
|
285 |
+
for kernel_size, stride in zip(self.backbone_model.config.conv_kernel, self.backbone_model.config.conv_stride):
|
286 |
+
input_length = _conv_out_length(input_length, kernel_size, stride)
|
287 |
+
return input_length
|
288 |
+
|
289 |
+
def prepare_mask(length, shape, dtype):
|
290 |
+
# Modified from huggingface
|
291 |
+
mask = torch.zeros(
|
292 |
+
shape, dtype=dtype
|
293 |
+
)
|
294 |
+
# these two operations make sure that all values
|
295 |
+
# before the output lengths indices are attended to
|
296 |
+
mask[(torch.arange(mask.shape[0]), length.cpu() - 1)] = 1
|
297 |
+
mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
298 |
+
return mask
|
299 |
+
|
300 |
+
|
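A quick illustration of the convolutional output-length formula in `get_feat_extract_output_lengths` above. This is a minimal sketch, not part of the upload; the kernel/stride values are the standard wav2vec2/WavLM feature-extractor defaults and are assumed here rather than read from the model config.

```python
# Sketch: how waveform sample counts map to WavLM frame counts.
import torch

CONV_KERNEL = (10, 3, 3, 3, 3, 2, 2)   # assumed wav2vec2/WavLM defaults
CONV_STRIDE = (5, 2, 2, 2, 2, 2, 2)

def feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
    # Apply the 1D-conv length formula layer by layer, as in the wrapper above.
    for kernel_size, stride in zip(CONV_KERNEL, CONV_STRIDE):
        input_lengths = (input_lengths - kernel_size) // stride + 1
    return input_lengths

lengths = torch.tensor([16000, 8000])        # 1.0 s and 0.5 s at 16 kHz
print(feat_extract_output_lengths(lengths))  # tensor([49, 24]) -> roughly 50 frames/s
```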
capspeech/eval/src/model/dialect/whisper_dialect.py
ADDED
@@ -0,0 +1,301 @@
1 |
+
import os
|
2 |
+
import pdb
|
3 |
+
import copy
|
4 |
+
import torch
|
5 |
+
import argparse
|
6 |
+
import numpy as np
|
7 |
+
import loralib as lora
|
8 |
+
import transformers.models.whisper.modeling_whisper as whisper
|
9 |
+
|
10 |
+
from torch import nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from transformers.activations import ACT2FN
|
13 |
+
from transformers import WhisperModel, AutoFeatureExtractor
|
14 |
+
|
15 |
+
import sys
|
16 |
+
from pathlib import Path
|
17 |
+
sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1])))
|
18 |
+
from revgrad import RevGrad
|
19 |
+
|
20 |
+
class WhisperEncoderLayer(nn.Module):
|
21 |
+
def __init__(self, config, layer_idx):
|
22 |
+
super().__init__()
|
23 |
+
self.embed_dim = config.d_model
|
24 |
+
self.self_attn = whisper.WhisperAttention(
|
25 |
+
embed_dim=self.embed_dim,
|
26 |
+
num_heads=config.encoder_attention_heads,
|
27 |
+
dropout=config.attention_dropout,
|
28 |
+
)
|
29 |
+
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
30 |
+
self.dropout = config.dropout
|
31 |
+
self.activation_fn = ACT2FN[config.activation_function]
|
32 |
+
self.activation_dropout = config.activation_dropout
|
33 |
+
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
|
34 |
+
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
35 |
+
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
36 |
+
self.config = config
|
37 |
+
|
38 |
+
if layer_idx > config.encoder_layers // 2:
|
39 |
+
if self.config.finetune_method == "lora" or self.config.finetune_method == "combined":
|
40 |
+
self.fc1 = lora.Linear(self.embed_dim, config.encoder_ffn_dim, r=config.lora_rank)
|
41 |
+
self.fc2 = lora.Linear(config.encoder_ffn_dim, self.embed_dim, r=config.lora_rank)
|
42 |
+
|
43 |
+
def forward(
|
44 |
+
self,
|
45 |
+
hidden_states: torch.Tensor,
|
46 |
+
attention_mask: torch.Tensor,
|
47 |
+
layer_head_mask: torch.Tensor,
|
48 |
+
output_attentions: bool = False,
|
49 |
+
):
|
50 |
+
"""
|
51 |
+
Args:
|
52 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
53 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
54 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
55 |
+
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
|
56 |
+
`(encoder_attention_heads,)`.
|
57 |
+
output_attentions (`bool`, *optional*):
|
58 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
59 |
+
returned tensors for more detail.
|
60 |
+
"""
|
61 |
+
residual = hidden_states
|
62 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
63 |
+
hidden_states, attn_weights, _ = self.self_attn(
|
64 |
+
hidden_states=hidden_states,
|
65 |
+
attention_mask=attention_mask,
|
66 |
+
layer_head_mask=layer_head_mask,
|
67 |
+
output_attentions=output_attentions,
|
68 |
+
)
|
69 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
70 |
+
hidden_states = residual + hidden_states
|
71 |
+
residual = hidden_states
|
72 |
+
|
73 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
74 |
+
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
75 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
|
76 |
+
hidden_states = self.fc2(hidden_states)
|
77 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
78 |
+
hidden_states = residual + hidden_states
|
79 |
+
|
80 |
+
if hidden_states.dtype == torch.float16 and (
|
81 |
+
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
|
82 |
+
):
|
83 |
+
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
84 |
+
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
|
85 |
+
outputs = (hidden_states,)
|
86 |
+
|
87 |
+
if output_attentions:
|
88 |
+
outputs += (attn_weights,)
|
89 |
+
|
90 |
+
return outputs
|
91 |
+
|
92 |
+
class WhisperWrapper(nn.Module):
|
93 |
+
def __init__(
|
94 |
+
self,
|
95 |
+
pretrain_model="whisper_large",
|
96 |
+
output_class_num=4,
|
97 |
+
hidden_dim=256,
|
98 |
+
finetune_method="lora",
|
99 |
+
lora_rank=16,
|
100 |
+
freeze_params=True,
|
101 |
+
use_conv_output=True,
|
102 |
+
apply_gradient_reversal=False,
|
103 |
+
num_dataset=4
|
104 |
+
):
|
105 |
+
super(WhisperWrapper, self).__init__()
|
106 |
+
# 1. Load the model with its pretrained weights
|
107 |
+
self.feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny", chunk_length=15)
|
108 |
+
self.pretrain_model = pretrain_model
|
109 |
+
if self.pretrain_model == "whisper_tiny":
|
110 |
+
self.backbone_model = WhisperModel.from_pretrained(
|
111 |
+
"openai/whisper-tiny",
|
112 |
+
output_hidden_states=True,
|
113 |
+
ignore_mismatched_sizes=True,
|
114 |
+
max_source_positions=750,
|
115 |
+
)
|
116 |
+
elif self.pretrain_model == "whisper_base":
|
117 |
+
self.backbone_model = WhisperModel.from_pretrained(
|
118 |
+
"openai/whisper-base",
|
119 |
+
output_hidden_states=True,
|
120 |
+
ignore_mismatched_sizes=True,
|
121 |
+
max_source_positions=750,
|
122 |
+
)
|
123 |
+
elif self.pretrain_model == "whisper_small":
|
124 |
+
self.backbone_model = WhisperModel.from_pretrained(
|
125 |
+
"openai/whisper-small",
|
126 |
+
output_hidden_states=True,
|
127 |
+
max_source_positions=750,
|
128 |
+
ignore_mismatched_sizes=True
|
129 |
+
)
|
130 |
+
elif self.pretrain_model == "whisper_medium":
|
131 |
+
self.backbone_model = WhisperModel.from_pretrained(
|
132 |
+
"openai/whisper-medium",
|
133 |
+
output_hidden_states=True,
|
134 |
+
ignore_mismatched_sizes=True
|
135 |
+
)
|
136 |
+
elif self.pretrain_model == "whisper_large":
|
137 |
+
self.feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large-v3", chunk_length=15)
|
138 |
+
self.backbone_model = WhisperModel.from_pretrained(
|
139 |
+
"openai/whisper-large-v3",
|
140 |
+
output_hidden_states=True,
|
141 |
+
ignore_mismatched_sizes=True,
|
142 |
+
max_source_positions=750,
|
143 |
+
)
|
144 |
+
self.embed_positions = copy.deepcopy(self.backbone_model.encoder.embed_positions.weight)
|
145 |
+
self.embed_positions.requires_grad = False
|
146 |
+
|
147 |
+
state_dict = self.backbone_model.state_dict()
|
148 |
+
# 2. Read the model config
|
149 |
+
self.model_config = self.backbone_model.config
|
150 |
+
self.model_config.finetune_method = finetune_method
|
151 |
+
self.model_config.lora_rank = lora_rank
|
152 |
+
self.finetune_method = finetune_method
|
153 |
+
self.apply_gradient_reversal = apply_gradient_reversal
|
154 |
+
self.use_conv_output = use_conv_output
|
155 |
+
|
156 |
+
if self.finetune_method == "lora":
|
157 |
+
# 3. Config encoder layers with adapter or embedding prompt
|
158 |
+
self.backbone_model.encoder.layers = nn.ModuleList(
|
159 |
+
[WhisperEncoderLayer(self.model_config, layer_idx) for layer_idx in range(self.model_config.encoder_layers)]
|
160 |
+
)
|
161 |
+
# 4. Load the weights back
|
162 |
+
msg = self.backbone_model.load_state_dict(state_dict, strict=False)
|
163 |
+
|
164 |
+
# 5. Freeze the weights
|
165 |
+
self.freeze_params = freeze_params
|
166 |
+
if self.freeze_params and self.finetune_method != "lora":
|
167 |
+
for _, p in self.backbone_model.named_parameters(): p.requires_grad = False
|
168 |
+
elif self.freeze_params and self.finetune_method == "lora":
|
169 |
+
for name, p in self.backbone_model.named_parameters():
|
170 |
+
if name in msg.missing_keys: p.requires_grad = True
|
171 |
+
else: p.requires_grad = False
|
172 |
+
else:
|
173 |
+
for name, p in self.backbone_model.named_parameters():
|
174 |
+
if "decoder" not in name and "conv1" not in name and "conv2" not in name and "embed_positions" not in name: p.requires_grad = True
|
175 |
+
else: p.requires_grad = False
|
176 |
+
|
177 |
+
# 6. Downstream models
|
178 |
+
self.model_seq = nn.Sequential(
|
179 |
+
nn.Conv1d(self.model_config.hidden_size, hidden_dim, 1, padding=0),
|
180 |
+
nn.ReLU(),
|
181 |
+
nn.Dropout(p=0.1),
|
182 |
+
nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0),
|
183 |
+
nn.ReLU(),
|
184 |
+
nn.Dropout(p=0.1),
|
185 |
+
nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0)
|
186 |
+
)
|
187 |
+
|
188 |
+
if use_conv_output:
|
189 |
+
num_layers = self.model_config.num_hidden_layers + 1 # transformer layers + input embeddings
|
190 |
+
self.weights = nn.Parameter(torch.ones(num_layers)/num_layers)
|
191 |
+
else:
|
192 |
+
num_layers = self.model_config.num_hidden_layers
|
193 |
+
self.weights = nn.Parameter(torch.zeros(num_layers))
|
194 |
+
|
195 |
+
if apply_gradient_reversal:
|
196 |
+
self.dataset_layer = nn.Sequential(
|
197 |
+
RevGrad(),
|
198 |
+
nn.Linear(hidden_dim, hidden_dim),
|
199 |
+
nn.ReLU(),
|
200 |
+
nn.Linear(hidden_dim, num_dataset),
|
201 |
+
)
|
202 |
+
self.out_layer = nn.Sequential(
|
203 |
+
nn.Linear(hidden_dim, hidden_dim),
|
204 |
+
nn.ReLU(),
|
205 |
+
nn.Linear(hidden_dim, output_class_num),
|
206 |
+
)
|
207 |
+
|
208 |
+
|
209 |
+
def forward(self, x, length=None, return_feature=False):
|
210 |
+
# 1. feature extraction and projections
|
211 |
+
if length is not None:
|
212 |
+
max_audio_len = 15*16000
|
213 |
+
# Append to list for feature_extractor to work
|
214 |
+
new_x = list()
|
215 |
+
for idx in range(len(length)):
|
216 |
+
new_x.append(x[idx].detach().cpu().numpy())
|
217 |
+
|
218 |
+
# Max length is max audio len in a batch
|
219 |
+
features = self.feature_extractor(
|
220 |
+
new_x,
|
221 |
+
return_tensors="pt",
|
222 |
+
sampling_rate=16000,
|
223 |
+
max_length=max_audio_len
|
224 |
+
)
|
225 |
+
features = features.input_features.cuda()
|
226 |
+
else:
|
227 |
+
max_audio_len = 15*16000
|
228 |
+
features = self.feature_extractor(
|
229 |
+
x[0].detach().cpu(),
|
230 |
+
return_tensors="pt",
|
231 |
+
sampling_rate=16000,
|
232 |
+
max_length=max_audio_len
|
233 |
+
)
|
234 |
+
features = features.input_features.cuda()
|
235 |
+
|
236 |
+
# 2. get length and mask
|
237 |
+
if length is not None:
|
238 |
+
length = self._get_feat_extract_output_lengths(length.detach().cpu())
|
239 |
+
# Replace positional embeddings
|
240 |
+
self.backbone_model.encoder.embed_positions = self.backbone_model.encoder.embed_positions.from_pretrained(self.embed_positions[:750])
|
241 |
+
else:
|
242 |
+
# Replace positional embeddings
|
243 |
+
length = torch.tensor([len(x[0])])
|
244 |
+
length = self._get_feat_extract_output_lengths(length)
|
245 |
+
self.backbone_model.encoder.embed_positions = self.backbone_model.encoder.embed_positions.from_pretrained(self.embed_positions[:750])
|
246 |
+
|
247 |
+
# 3. transformer encoding features
|
248 |
+
# compute reduced attention_mask corresponding to feature vectors
|
249 |
+
features = self.backbone_model.encoder(
|
250 |
+
features, output_hidden_states=True
|
251 |
+
).hidden_states
|
252 |
+
|
253 |
+
features = torch.stack(features, dim=0)[-1]
|
254 |
+
|
255 |
+
# 6. Pass the weighted average to point-wise 1D Conv
|
256 |
+
# B x T x D
|
257 |
+
features = features.transpose(1, 2)
|
258 |
+
features = self.model_seq(features)
|
259 |
+
features = features.transpose(1, 2)
|
260 |
+
|
261 |
+
# 7. Pooling
|
262 |
+
if length is not None:
|
263 |
+
mean, std = list(), list()
|
264 |
+
for snt_id in range(features.shape[0]):
|
265 |
+
# Avoiding padded time steps
|
266 |
+
actual_size = length[snt_id]
|
267 |
+
mean.append(torch.mean(features[snt_id, 0:actual_size, ...], dim=0))
|
268 |
+
features = torch.stack(mean)
|
269 |
+
else:
|
270 |
+
features = torch.mean(features, dim=1)
|
271 |
+
|
272 |
+
# 8. Output predictions
|
273 |
+
# B x D
|
274 |
+
predicted = self.out_layer(features)
|
275 |
+
if self.apply_gradient_reversal:
|
276 |
+
dataset_predicted = self.dataset_layer(features)
|
277 |
+
if return_feature: return predicted, dataset_predicted, features
|
278 |
+
return predicted, dataset_predicted
|
279 |
+
if return_feature: return predicted, features
|
280 |
+
return predicted
|
281 |
+
|
282 |
+
# From huggingface
|
283 |
+
def _get_feat_extract_output_lengths(self, input_lengths):
|
284 |
+
"""
|
285 |
+
Computes the output length of the convolutional layers
|
286 |
+
"""
|
287 |
+
input_lengths = input_lengths // 160
|
288 |
+
input_lengths = (input_lengths - 1) // 2 + 1
|
289 |
+
return input_lengths
|
290 |
+
|
291 |
+
def prepare_mask(length, shape, dtype):
|
292 |
+
# Modified from huggingface
|
293 |
+
mask = torch.zeros(
|
294 |
+
shape, dtype=dtype
|
295 |
+
)
|
296 |
+
# these two operations make sure that all values
|
297 |
+
# before the output lengths indices are attended to
|
298 |
+
mask[(torch.arange(mask.shape[0]), length.cpu() - 1)] = 1
|
299 |
+
mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
300 |
+
return mask
|
301 |
+
|
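A short note on why `WhisperWrapper` above truncates the positional embeddings with `self.embed_positions[:750]` and sets `max_source_positions=750`: with `chunk_length=15`, a 15-second chunk at 16 kHz maps to exactly 750 encoder frames under the arithmetic in `_get_feat_extract_output_lengths`. A minimal sketch of that calculation (illustrative only):

```python
# Sketch: 15 s of 16 kHz audio -> 750 Whisper encoder positions.
def whisper_encoder_frames(num_samples: int) -> int:
    mel_frames = num_samples // 160      # 10 ms hop of the log-mel front end
    return (mel_frames - 1) // 2 + 1     # stride-2 conv in the encoder

print(whisper_encoder_frames(15 * 16000))  # 750
```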
capspeech/eval/src/model/emotion/__init__.py
ADDED
File without changes
|
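Both dialect wrappers above import `RevGrad` from the local `revgrad` module, which is not shown in this excerpt. As a hedged sketch only, a gradient-reversal layer of the kind that module presumably provides can be written as follows; this is an assumption about its behavior, not the repository's implementation.

```python
# Sketch (assumption): generic gradient-reversal layer. Forward is the identity;
# backward multiplies the incoming gradient by -alpha, which makes the dataset
# classification head adversarial with respect to the shared features.
import torch
from torch import nn

class GradientReversalFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse (and scale) the gradient flowing back into the backbone.
        return -ctx.alpha * grad_output, None

class GradientReversal(nn.Module):
    def __init__(self, alpha: float = 1.0):
        super().__init__()
        self.alpha = alpha

    def forward(self, x):
        return GradientReversalFn.apply(x, self.alpha)
```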
capspeech/eval/src/model/emotion/wavlm_emotion.py
ADDED
@@ -0,0 +1,315 @@
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import loralib as lora
|
4 |
+
import transformers.models.wavlm.modeling_wavlm as wavlm
|
5 |
+
from speechbrain.lobes.models.huggingface_transformers.huggingface import make_padding_masks
|
6 |
+
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
from transformers import Wav2Vec2FeatureExtractor
|
10 |
+
from transformers import WavLMModel
|
11 |
+
|
12 |
+
import sys
|
13 |
+
from pathlib import Path
|
14 |
+
sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1])))
|
15 |
+
|
16 |
+
class WavLMEncoderLayer(nn.Module):
|
17 |
+
def __init__(self, layer_idx, config, has_relative_position_bias: bool = True):
|
18 |
+
super().__init__()
|
19 |
+
self.attention = wavlm.WavLMAttention(
|
20 |
+
embed_dim=config.hidden_size,
|
21 |
+
num_heads=config.num_attention_heads,
|
22 |
+
dropout=config.attention_dropout,
|
23 |
+
num_buckets=config.num_buckets,
|
24 |
+
max_distance=config.max_bucket_distance,
|
25 |
+
has_relative_position_bias=has_relative_position_bias,
|
26 |
+
)
|
27 |
+
self.dropout = nn.Dropout(config.hidden_dropout)
|
28 |
+
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
29 |
+
self.feed_forward = wavlm.WavLMFeedForward(config)
|
30 |
+
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
31 |
+
self.config = config
|
32 |
+
|
33 |
+
if layer_idx > config.num_hidden_layers // 2:
|
34 |
+
if self.config.finetune_method == "lora" or self.config.finetune_method == "combined":
|
35 |
+
self.feed_forward.intermediate_dense = lora.Linear(config.hidden_size, config.intermediate_size, r=config.lora_rank)
|
36 |
+
self.feed_forward.output_dense = lora.Linear(config.intermediate_size, config.hidden_size, r=config.lora_rank)
|
37 |
+
|
38 |
+
def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
|
39 |
+
attn_residual = hidden_states
|
40 |
+
hidden_states, attn_weights, position_bias = self.attention(
|
41 |
+
hidden_states,
|
42 |
+
attention_mask=attention_mask,
|
43 |
+
position_bias=position_bias,
|
44 |
+
output_attentions=output_attentions,
|
45 |
+
index=index,
|
46 |
+
)
|
47 |
+
hidden_states = self.dropout(hidden_states)
|
48 |
+
hidden_states = attn_residual + hidden_states
|
49 |
+
|
50 |
+
# Adapter
|
51 |
+
if self.config.finetune_method == "adapter":
|
52 |
+
adapt_h = self.adapter(hidden_states)
|
53 |
+
|
54 |
+
hidden_states = self.layer_norm(hidden_states)
|
55 |
+
hidden_states = hidden_states + self.feed_forward(hidden_states)
|
56 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
57 |
+
outputs = (hidden_states, position_bias)
|
58 |
+
|
59 |
+
if output_attentions:
|
60 |
+
outputs += (attn_weights,)
|
61 |
+
|
62 |
+
return outputs
|
63 |
+
|
64 |
+
|
65 |
+
class WavLMEncoderLayerStableLayerNorm(nn.Module):
|
66 |
+
def __init__(self, layer_idx, config, has_relative_position_bias: bool = True):
|
67 |
+
super().__init__()
|
68 |
+
self.attention = wavlm.WavLMAttention(
|
69 |
+
embed_dim=config.hidden_size,
|
70 |
+
num_heads=config.num_attention_heads,
|
71 |
+
dropout=config.attention_dropout,
|
72 |
+
num_buckets=config.num_buckets,
|
73 |
+
max_distance=config.max_bucket_distance,
|
74 |
+
has_relative_position_bias=has_relative_position_bias,
|
75 |
+
)
|
76 |
+
self.dropout = nn.Dropout(config.hidden_dropout)
|
77 |
+
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
78 |
+
self.feed_forward = wavlm.WavLMFeedForward(config)
|
79 |
+
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
80 |
+
self.config = config
|
81 |
+
|
82 |
+
if layer_idx > config.num_hidden_layers // 2:
|
83 |
+
if self.config.finetune_method == "lora" or self.config.finetune_method == "combined":
|
84 |
+
self.feed_forward.intermediate_dense = lora.Linear(config.hidden_size, config.intermediate_size, r=config.lora_rank)
|
85 |
+
self.feed_forward.output_dense = lora.Linear(config.intermediate_size, config.hidden_size, r=config.lora_rank)
|
86 |
+
|
87 |
+
|
88 |
+
def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False):
|
89 |
+
attn_residual = hidden_states
|
90 |
+
hidden_states = self.layer_norm(hidden_states)
|
91 |
+
hidden_states, attn_weights, position_bias = self.attention(
|
92 |
+
hidden_states,
|
93 |
+
attention_mask=attention_mask,
|
94 |
+
position_bias=position_bias,
|
95 |
+
output_attentions=output_attentions,
|
96 |
+
)
|
97 |
+
hidden_states = self.dropout(hidden_states)
|
98 |
+
hidden_states = attn_residual + hidden_states
|
99 |
+
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
|
100 |
+
|
101 |
+
outputs = (hidden_states, position_bias)
|
102 |
+
|
103 |
+
if output_attentions:
|
104 |
+
outputs += (attn_weights,)
|
105 |
+
|
106 |
+
return outputs
|
107 |
+
|
108 |
+
|
109 |
+
class WavLMWrapper(nn.Module):
|
110 |
+
def __init__(
|
111 |
+
self,
|
112 |
+
pretrain_model="wavlm_large",
|
113 |
+
hidden_dim=256,
|
114 |
+
finetune_method="lora",
|
115 |
+
lora_rank=16,
|
116 |
+
freeze_params=True,
|
117 |
+
output_class_num=4,
|
118 |
+
use_conv_output=True,
|
119 |
+
detailed_class_num=17
|
120 |
+
):
|
121 |
+
super(WavLMWrapper, self).__init__()
|
122 |
+
# 1. We Load the model first with weights
|
123 |
+
self.pretrain_model = pretrain_model
|
124 |
+
self.finetune_method = finetune_method
|
125 |
+
self.freeze_params = freeze_params
|
126 |
+
self.use_conv_output = use_conv_output
|
127 |
+
self.lora_rank = lora_rank
|
128 |
+
if self.pretrain_model == "wavlm":
|
129 |
+
self.backbone_model = WavLMModel.from_pretrained(
|
130 |
+
"microsoft/wavlm-base-plus",
|
131 |
+
output_hidden_states=True,
|
132 |
+
)
|
133 |
+
elif self.pretrain_model == "wavlm_large":
|
134 |
+
self.processor = Wav2Vec2FeatureExtractor.from_pretrained('microsoft/wavlm-large')
|
135 |
+
self.backbone_model = WavLMModel.from_pretrained(
|
136 |
+
"microsoft/wavlm-large",
|
137 |
+
output_hidden_states=True,
|
138 |
+
)
|
139 |
+
state_dict = self.backbone_model.state_dict()
|
140 |
+
# 2. Read the model config
|
141 |
+
self.model_config = self.backbone_model.config
|
142 |
+
self.model_config.finetune_method = self.finetune_method
|
143 |
+
self.model_config.lora_rank = self.lora_rank
|
144 |
+
|
145 |
+
# 3. Config encoder layers with adapter or embedding prompt
|
146 |
+
if self.pretrain_model == "wavlm":
|
147 |
+
self.backbone_model.encoder.layers = nn.ModuleList(
|
148 |
+
[WavLMEncoderLayer(i, self.model_config, has_relative_position_bias=(i == 0)) for i in range(self.model_config.num_hidden_layers)]
|
149 |
+
)
|
150 |
+
elif self.pretrain_model == "wavlm_large":
|
151 |
+
self.backbone_model.encoder.layers = nn.ModuleList(
|
152 |
+
[WavLMEncoderLayerStableLayerNorm(i, self.model_config, has_relative_position_bias=(i == 0)) for i in range(self.model_config.num_hidden_layers)]
|
153 |
+
)
|
154 |
+
# 4. Load the weights back
|
155 |
+
msg = self.backbone_model.load_state_dict(state_dict, strict=False)
|
156 |
+
|
157 |
+
# 5. Freeze the weights
|
158 |
+
self.freeze_params = freeze_params
|
159 |
+
if self.freeze_params and self.finetune_method != "lora":
|
160 |
+
for _, p in self.backbone_model.named_parameters(): p.requires_grad = False
|
161 |
+
elif self.freeze_params and self.finetune_method == "lora":
|
162 |
+
for name, p in self.backbone_model.named_parameters():
|
163 |
+
if name in msg.missing_keys: p.requires_grad = True
|
164 |
+
else: p.requires_grad = False
|
165 |
+
else:
|
166 |
+
for _, p in self.backbone_model.named_parameters(): p.requires_grad = True
|
167 |
+
|
168 |
+
# 6. Downstream models
|
169 |
+
self.model_seq = nn.Sequential(
|
170 |
+
nn.Conv1d(self.model_config.hidden_size, hidden_dim, 1, padding=0),
|
171 |
+
nn.ReLU(),
|
172 |
+
nn.Dropout(p=0.1),
|
173 |
+
nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0),
|
174 |
+
nn.ReLU(),
|
175 |
+
nn.Dropout(p=0.1),
|
176 |
+
nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0)
|
177 |
+
)
|
178 |
+
|
179 |
+
if self.use_conv_output:
|
180 |
+
num_layers = self.model_config.num_hidden_layers + 1 # transformer layers + input embeddings
|
181 |
+
self.weights = nn.Parameter(torch.ones(num_layers)/num_layers)
|
182 |
+
else:
|
183 |
+
num_layers = self.model_config.num_hidden_layers
|
184 |
+
self.weights = nn.Parameter(torch.zeros(num_layers))
|
185 |
+
|
186 |
+
self.emotion_layer = nn.Sequential(
|
187 |
+
nn.Linear(hidden_dim, hidden_dim),
|
188 |
+
nn.ReLU(),
|
189 |
+
nn.Linear(hidden_dim, output_class_num),
|
190 |
+
)
|
191 |
+
|
192 |
+
self.detailed_out_layer = nn.Sequential(
|
193 |
+
nn.Linear(hidden_dim, hidden_dim),
|
194 |
+
nn.ReLU(),
|
195 |
+
nn.Linear(hidden_dim, detailed_class_num),
|
196 |
+
)
|
197 |
+
|
198 |
+
self.arousal_layer = nn.Sequential(
|
199 |
+
nn.Linear(hidden_dim, hidden_dim),
|
200 |
+
nn.ReLU(),
|
201 |
+
nn.Linear(hidden_dim, 1),
|
202 |
+
nn.Sigmoid()
|
203 |
+
)
|
204 |
+
|
205 |
+
self.valence_layer = nn.Sequential(
|
206 |
+
nn.Linear(hidden_dim, hidden_dim),
|
207 |
+
nn.ReLU(),
|
208 |
+
nn.Linear(hidden_dim, 1),
|
209 |
+
nn.Sigmoid()
|
210 |
+
)
|
211 |
+
|
212 |
+
self.dominance_layer = nn.Sequential(
|
213 |
+
nn.Linear(hidden_dim, hidden_dim),
|
214 |
+
nn.ReLU(),
|
215 |
+
nn.Linear(hidden_dim, 1),
|
216 |
+
nn.Sigmoid()
|
217 |
+
)
|
218 |
+
|
219 |
+
def forward(self, x, length=None, return_feature=False):
|
220 |
+
# 1. feature extraction and projections
|
221 |
+
if self.pretrain_model == "wavlm_large":
|
222 |
+
with torch.no_grad():
|
223 |
+
signal, attention_mask = list(), list()
|
224 |
+
if length is not None: attention_mask = make_padding_masks(x, wav_len=length/length.max()).to(x.device)
|
225 |
+
else: attention_mask = make_padding_masks(x, wav_len=torch.tensor([1]).to(x.device)).to(x.device)
|
226 |
+
|
227 |
+
for idx in range(len(x)):
|
228 |
+
input = self.processor(x[idx], sampling_rate=16_000, return_tensors="pt", padding=True)
|
229 |
+
signal.append(input["input_values"][0].to(x.device))
|
230 |
+
signal = torch.stack(signal)
|
231 |
+
|
232 |
+
# 2. get length and mask
|
233 |
+
if length is not None:
|
234 |
+
length = self.get_feat_extract_output_lengths(length.detach().cpu())
|
235 |
+
length = length.cuda()
|
236 |
+
|
237 |
+
if self.pretrain_model == "wavlm":
|
238 |
+
x = self.backbone_model(
|
239 |
+
x, output_hidden_states=True
|
240 |
+
).hidden_states
|
241 |
+
else:
|
242 |
+
x = self.backbone_model(
|
243 |
+
signal,
|
244 |
+
attention_mask=attention_mask,
|
245 |
+
output_hidden_states=True
|
246 |
+
).hidden_states
|
247 |
+
|
248 |
+
# 4. stacked feature
|
249 |
+
if self.use_conv_output: stacked_feature = torch.stack(x, dim=0)
|
250 |
+
else: stacked_feature = torch.stack(x, dim=0)[1:]
|
251 |
+
|
252 |
+
# 5. Weighted sum
|
253 |
+
_, *origin_shape = stacked_feature.shape
|
254 |
+
# Return transformer enc outputs [num_enc_layers, B, T, D]
|
255 |
+
if self.use_conv_output:
|
256 |
+
stacked_feature = stacked_feature.view(self.backbone_model.config.num_hidden_layers+1, -1)
|
257 |
+
else:
|
258 |
+
stacked_feature = stacked_feature.view(self.backbone_model.config.num_hidden_layers, -1)
|
259 |
+
norm_weights = F.softmax(self.weights, dim=-1)
|
260 |
+
|
261 |
+
# Perform weighted average
|
262 |
+
weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0)
|
263 |
+
features = weighted_feature.view(*origin_shape)
|
264 |
+
|
265 |
+
# 6. Pass the weighted average to point-wise 1D Conv
|
266 |
+
# B x T x D
|
267 |
+
features = features.transpose(1, 2)
|
268 |
+
features = self.model_seq(features)
|
269 |
+
features = features.transpose(1, 2)
|
270 |
+
|
271 |
+
# 7. Pooling
|
272 |
+
if length is not None:
|
273 |
+
mean, std = list(), list()
|
274 |
+
for snt_id in range(features.shape[0]):
|
275 |
+
# Avoiding padded time steps
|
276 |
+
actual_size = length[snt_id]
|
277 |
+
mean.append(torch.mean(features[snt_id, 0:actual_size, ...], dim=0))
|
278 |
+
features = torch.stack(mean)
|
279 |
+
else:
|
280 |
+
features = torch.mean(features, dim=1)
|
281 |
+
|
282 |
+
# Output predictions
|
283 |
+
# B x D
|
284 |
+
predicted = self.emotion_layer(features)
|
285 |
+
detailed_predicted = self.detailed_out_layer(features)
|
286 |
+
arousal = self.arousal_layer(features)
|
287 |
+
valence = self.valence_layer(features)
|
288 |
+
dominance = self.dominance_layer(features)
|
289 |
+
if return_feature: return predicted, features, detailed_predicted, arousal, valence, dominance
|
290 |
+
return predicted, detailed_predicted, arousal, valence, dominance
|
291 |
+
|
292 |
+
# From huggingface
|
293 |
+
def get_feat_extract_output_lengths(self, input_length):
|
294 |
+
"""
|
295 |
+
Computes the output length of the convolutional layers
|
296 |
+
"""
|
297 |
+
def _conv_out_length(input_length, kernel_size, stride):
|
298 |
+
# 1D convolutional layer output length formula taken
|
299 |
+
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
300 |
+
return (input_length - kernel_size) // stride + 1
|
301 |
+
for kernel_size, stride in zip(self.backbone_model.config.conv_kernel, self.backbone_model.config.conv_stride):
|
302 |
+
input_length = _conv_out_length(input_length, kernel_size, stride)
|
303 |
+
return input_length
|
304 |
+
|
305 |
+
def prepare_mask(length, shape, dtype):
|
306 |
+
# Modified from huggingface
|
307 |
+
mask = torch.zeros(
|
308 |
+
shape, dtype=dtype
|
309 |
+
)
|
310 |
+
# these two operations make sure that all values
|
311 |
+
# before the output lengths indices are attended to
|
312 |
+
mask[(torch.arange(mask.shape[0]), length.cpu() - 1)] = 1
|
313 |
+
mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
314 |
+
return mask
|
315 |
+
|
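The flip/cumsum/flip construction in `prepare_mask` is repeated at the bottom of each wrapper file above. A small standalone example of what it produces (a boolean mask that is True for every frame up to each sequence's last valid index):

```python
# Sketch: prepare_mask's flip/cumsum/flip trick on a toy batch of lengths [2, 4].
import torch

length = torch.tensor([2, 4])
mask = torch.zeros((2, 5), dtype=torch.float32)
mask[(torch.arange(mask.shape[0]), length - 1)] = 1   # mark the last valid index
mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()   # fill everything before it
print(mask)
# tensor([[ True,  True, False, False, False],
#         [ True,  True,  True,  True, False]])
```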
capspeech/eval/src/model/emotion/wavlm_emotion_dim.py
ADDED
@@ -0,0 +1,318 @@
1 |
+
import os
|
2 |
+
import pdb
|
3 |
+
import torch
|
4 |
+
import argparse
|
5 |
+
import numpy as np
|
6 |
+
import loralib as lora
|
7 |
+
import transformers.models.wav2vec2.modeling_wav2vec2 as w2v2
|
8 |
+
import transformers.models.wavlm.modeling_wavlm as wavlm
|
9 |
+
from speechbrain.lobes.models.huggingface_transformers.huggingface import make_padding_masks
|
10 |
+
|
11 |
+
from torch import nn
|
12 |
+
from torch.nn import functional as F
|
13 |
+
from transformers import Wav2Vec2FeatureExtractor
|
14 |
+
from transformers import WavLMModel
|
15 |
+
|
16 |
+
import sys
|
17 |
+
from pathlib import Path
|
18 |
+
sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1])))
|
19 |
+
|
20 |
+
class WavLMEncoderLayer(nn.Module):
|
21 |
+
def __init__(self, layer_idx, config, has_relative_position_bias: bool = True):
|
22 |
+
super().__init__()
|
23 |
+
self.attention = wavlm.WavLMAttention(
|
24 |
+
embed_dim=config.hidden_size,
|
25 |
+
num_heads=config.num_attention_heads,
|
26 |
+
dropout=config.attention_dropout,
|
27 |
+
num_buckets=config.num_buckets,
|
28 |
+
max_distance=config.max_bucket_distance,
|
29 |
+
has_relative_position_bias=has_relative_position_bias,
|
30 |
+
)
|
31 |
+
self.dropout = nn.Dropout(config.hidden_dropout)
|
32 |
+
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
33 |
+
self.feed_forward = wavlm.WavLMFeedForward(config)
|
34 |
+
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
35 |
+
self.config = config
|
36 |
+
|
37 |
+
if layer_idx > config.num_hidden_layers // 2:
|
38 |
+
if self.config.finetune_method == "lora" or self.config.finetune_method == "combined":
|
39 |
+
self.feed_forward.intermediate_dense = lora.Linear(config.hidden_size, config.intermediate_size, r=config.lora_rank)
|
40 |
+
self.feed_forward.output_dense = lora.Linear(config.intermediate_size, config.hidden_size, r=config.lora_rank)
|
41 |
+
|
42 |
+
def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
|
43 |
+
attn_residual = hidden_states
|
44 |
+
hidden_states, attn_weights, position_bias = self.attention(
|
45 |
+
hidden_states,
|
46 |
+
attention_mask=attention_mask,
|
47 |
+
position_bias=position_bias,
|
48 |
+
output_attentions=output_attentions,
|
49 |
+
index=index,
|
50 |
+
)
|
51 |
+
hidden_states = self.dropout(hidden_states)
|
52 |
+
hidden_states = attn_residual + hidden_states
|
53 |
+
|
54 |
+
# Adapter
|
55 |
+
if self.config.finetune_method == "adapter":
|
56 |
+
adapt_h = self.adapter(hidden_states)
|
57 |
+
|
58 |
+
hidden_states = self.layer_norm(hidden_states)
|
59 |
+
hidden_states = hidden_states + self.feed_forward(hidden_states)
|
60 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
61 |
+
outputs = (hidden_states, position_bias)
|
62 |
+
|
63 |
+
if output_attentions:
|
64 |
+
outputs += (attn_weights,)
|
65 |
+
|
66 |
+
return outputs
|
67 |
+
|
68 |
+
|
69 |
+
class WavLMEncoderLayerStableLayerNorm(nn.Module):
|
70 |
+
def __init__(self, layer_idx, config, has_relative_position_bias: bool = True):
|
71 |
+
super().__init__()
|
72 |
+
self.attention = wavlm.WavLMAttention(
|
73 |
+
embed_dim=config.hidden_size,
|
74 |
+
num_heads=config.num_attention_heads,
|
75 |
+
dropout=config.attention_dropout,
|
76 |
+
num_buckets=config.num_buckets,
|
77 |
+
max_distance=config.max_bucket_distance,
|
78 |
+
has_relative_position_bias=has_relative_position_bias,
|
79 |
+
)
|
80 |
+
self.dropout = nn.Dropout(config.hidden_dropout)
|
81 |
+
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
82 |
+
self.feed_forward = wavlm.WavLMFeedForward(config)
|
83 |
+
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
84 |
+
self.config = config
|
85 |
+
|
86 |
+
if layer_idx > config.num_hidden_layers // 2:
|
87 |
+
if self.config.finetune_method == "lora" or self.config.finetune_method == "combined":
|
88 |
+
self.feed_forward.intermediate_dense = lora.Linear(config.hidden_size, config.intermediate_size, r=config.lora_rank)
|
89 |
+
self.feed_forward.output_dense = lora.Linear(config.intermediate_size, config.hidden_size, r=config.lora_rank)
|
90 |
+
|
91 |
+
|
92 |
+
def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False):
|
93 |
+
attn_residual = hidden_states
|
94 |
+
hidden_states = self.layer_norm(hidden_states)
|
95 |
+
hidden_states, attn_weights, position_bias = self.attention(
|
96 |
+
hidden_states,
|
97 |
+
attention_mask=attention_mask,
|
98 |
+
position_bias=position_bias,
|
99 |
+
output_attentions=output_attentions,
|
100 |
+
)
|
101 |
+
hidden_states = self.dropout(hidden_states)
|
102 |
+
hidden_states = attn_residual + hidden_states
|
103 |
+
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
|
104 |
+
|
105 |
+
outputs = (hidden_states, position_bias)
|
106 |
+
|
107 |
+
if output_attentions:
|
108 |
+
outputs += (attn_weights,)
|
109 |
+
|
110 |
+
return outputs
|
111 |
+
|
112 |
+
|
113 |
+
class WavLMWrapper(nn.Module):
|
114 |
+
def __init__(
|
115 |
+
self,
|
116 |
+
pretrain_model="wavlm_large",
|
117 |
+
hidden_dim=256,
|
118 |
+
finetune_method="lora",
|
119 |
+
lora_rank=16,
|
120 |
+
freeze_params=True,
|
121 |
+
output_class_num=4,
|
122 |
+
use_conv_output=True,
|
123 |
+
detailed_class_num=17,
|
124 |
+
predict_gender=False
|
125 |
+
):
|
126 |
+
super(WavLMWrapper, self).__init__()
|
127 |
+
# 1. Load the model with its pretrained weights
|
128 |
+
self.pretrain_model = pretrain_model
|
129 |
+
self.finetune_method = finetune_method
|
130 |
+
self.freeze_params = freeze_params
|
131 |
+
self.use_conv_output = use_conv_output
|
132 |
+
self.lora_rank = lora_rank
|
133 |
+
self.predict_gender = predict_gender
|
134 |
+
if self.pretrain_model == "wavlm":
|
135 |
+
self.backbone_model = WavLMModel.from_pretrained(
|
136 |
+
"microsoft/wavlm-base-plus",
|
137 |
+
output_hidden_states=True,
|
138 |
+
)
|
139 |
+
elif self.pretrain_model == "wavlm_large":
|
140 |
+
self.processor = Wav2Vec2FeatureExtractor.from_pretrained('microsoft/wavlm-large')
|
141 |
+
self.backbone_model = WavLMModel.from_pretrained(
|
142 |
+
"microsoft/wavlm-large",
|
143 |
+
output_hidden_states=True,
|
144 |
+
)
|
145 |
+
state_dict = self.backbone_model.state_dict()
|
146 |
+
# 2. Read the model config
|
147 |
+
self.model_config = self.backbone_model.config
|
148 |
+
self.model_config.finetune_method = self.finetune_method
|
149 |
+
self.model_config.lora_rank = self.lora_rank
|
150 |
+
|
151 |
+
# 3. Config encoder layers with adapter or embedding prompt
|
152 |
+
if self.pretrain_model == "wavlm":
|
153 |
+
self.backbone_model.encoder.layers = nn.ModuleList(
|
154 |
+
[WavLMEncoderLayer(i, self.model_config, has_relative_position_bias=(i == 0)) for i in range(self.model_config.num_hidden_layers)]
|
155 |
+
)
|
156 |
+
elif self.pretrain_model == "wavlm_large":
|
157 |
+
self.backbone_model.encoder.layers = nn.ModuleList(
|
158 |
+
[WavLMEncoderLayerStableLayerNorm(i, self.model_config, has_relative_position_bias=(i == 0)) for i in range(self.model_config.num_hidden_layers)]
|
159 |
+
)
|
160 |
+
# 4. Load the weights back
|
161 |
+
msg = self.backbone_model.load_state_dict(state_dict, strict=False)
|
162 |
+
|
163 |
+
# 5. Freeze the weights
|
164 |
+
self.freeze_params = freeze_params
|
165 |
+
if self.freeze_params and self.finetune_method != "lora":
|
166 |
+
for _, p in self.backbone_model.named_parameters(): p.requires_grad = False
|
167 |
+
elif self.freeze_params and self.finetune_method == "lora":
|
168 |
+
for name, p in self.backbone_model.named_parameters():
|
169 |
+
if name in msg.missing_keys: p.requires_grad = True
|
170 |
+
else: p.requires_grad = False
|
171 |
+
else:
|
172 |
+
for _, p in self.backbone_model.named_parameters(): p.requires_grad = True
|
173 |
+
|
174 |
+
# 6. Downstream models
|
175 |
+
self.model_seq = nn.Sequential(
|
176 |
+
nn.Conv1d(self.model_config.hidden_size, hidden_dim, 1, padding=0),
|
177 |
+
nn.ReLU(),
|
178 |
+
nn.Dropout(p=0.1),
|
179 |
+
nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0),
|
180 |
+
nn.ReLU(),
|
181 |
+
nn.Dropout(p=0.1),
|
182 |
+
nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0)
|
183 |
+
)
|
184 |
+
|
185 |
+
if self.use_conv_output:
|
186 |
+
num_layers = self.model_config.num_hidden_layers + 1 # transformer layers + input embeddings
|
187 |
+
self.weights = nn.Parameter(torch.ones(num_layers)/num_layers)
|
188 |
+
else:
|
189 |
+
num_layers = self.model_config.num_hidden_layers
|
190 |
+
self.weights = nn.Parameter(torch.zeros(num_layers))
|
191 |
+
|
192 |
+
self.arousal_layer = nn.Sequential(
|
193 |
+
nn.Linear(hidden_dim, hidden_dim),
|
194 |
+
nn.ReLU(),
|
195 |
+
nn.Linear(hidden_dim, 1),
|
196 |
+
nn.Sigmoid()
|
197 |
+
)
|
198 |
+
|
199 |
+
self.valence_layer = nn.Sequential(
|
200 |
+
nn.Linear(hidden_dim, hidden_dim),
|
201 |
+
nn.ReLU(),
|
202 |
+
nn.Linear(hidden_dim, 1),
|
203 |
+
nn.Sigmoid()
|
204 |
+
)
|
205 |
+
|
206 |
+
self.dominance_layer = nn.Sequential(
|
207 |
+
nn.Linear(hidden_dim, hidden_dim),
|
208 |
+
nn.ReLU(),
|
209 |
+
nn.Linear(hidden_dim, 1),
|
210 |
+
nn.Sigmoid()
|
211 |
+
)
|
212 |
+
|
213 |
+
if self.predict_gender:
|
214 |
+
self.gender_layer = nn.Sequential(
|
215 |
+
nn.Linear(hidden_dim, hidden_dim),
|
216 |
+
nn.ReLU(),
|
217 |
+
nn.Linear(hidden_dim, 2)
|
218 |
+
)
|
219 |
+
|
220 |
+
def forward(self, x, length=None, return_feature=False):
|
221 |
+
# 1. feature extraction and projections
|
222 |
+
if self.pretrain_model == "wavlm_large":
|
223 |
+
with torch.no_grad():
|
224 |
+
signal, attention_mask = list(), list()
|
225 |
+
if length is not None: attention_mask = make_padding_masks(x, wav_len=length/length.max()).to(x.device)
|
226 |
+
else: attention_mask = make_padding_masks(x, wav_len=torch.tensor([1]).to(x.device)).to(x.device)
|
227 |
+
|
228 |
+
for idx in range(len(x)):
|
229 |
+
input = self.processor(x[idx], sampling_rate=16_000, return_tensors="pt", padding=True)
|
230 |
+
signal.append(input["input_values"][0].to(x.device))
|
231 |
+
signal = torch.stack(signal)
|
232 |
+
|
233 |
+
# 2. get length and mask
|
234 |
+
if length is not None:
|
235 |
+
length = self.get_feat_extract_output_lengths(length.detach().cpu())
|
236 |
+
length = length.cuda()
|
237 |
+
|
238 |
+
if self.pretrain_model == "wavlm":
|
239 |
+
x = self.backbone_model(
|
240 |
+
x, output_hidden_states=True
|
241 |
+
).hidden_states
|
242 |
+
else:
|
243 |
+
x = self.backbone_model(
|
244 |
+
signal,
|
245 |
+
attention_mask=attention_mask,
|
246 |
+
output_hidden_states=True
|
247 |
+
).hidden_states
|
248 |
+
|
249 |
+
# 4. stacked feature
|
250 |
+
if self.use_conv_output: stacked_feature = torch.stack(x, dim=0)
|
251 |
+
else: stacked_feature = torch.stack(x, dim=0)[1:]
|
252 |
+
|
253 |
+
# 5. Weighted sum
|
254 |
+
_, *origin_shape = stacked_feature.shape
|
255 |
+
# Return transformer enc outputs [num_enc_layers, B, T, D]
|
256 |
+
if self.use_conv_output:
|
257 |
+
stacked_feature = stacked_feature.view(self.backbone_model.config.num_hidden_layers+1, -1)
|
258 |
+
else:
|
259 |
+
stacked_feature = stacked_feature.view(self.backbone_model.config.num_hidden_layers, -1)
|
260 |
+
norm_weights = F.softmax(self.weights, dim=-1)
|
261 |
+
|
262 |
+
# Perform weighted average
|
263 |
+
weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0)
|
264 |
+
features = weighted_feature.view(*origin_shape)
|
265 |
+
|
266 |
+
# 6. Pass the weighted average to point-wise 1D Conv
|
267 |
+
# B x T x D
|
268 |
+
features = features.transpose(1, 2)
|
269 |
+
features = self.model_seq(features)
|
270 |
+
features = features.transpose(1, 2)
|
271 |
+
|
272 |
+
# 7. Pooling
|
273 |
+
if length is not None:
|
274 |
+
mean, std = list(), list()
|
275 |
+
for snt_id in range(features.shape[0]):
|
276 |
+
# Avoiding padded time steps
|
277 |
+
actual_size = length[snt_id]
|
278 |
+
mean.append(torch.mean(features[snt_id, 0:actual_size, ...], dim=0))
|
279 |
+
features = torch.stack(mean)
|
280 |
+
else:
|
281 |
+
features = torch.mean(features, dim=1)
|
282 |
+
|
283 |
+
# 8. Output predictions
|
284 |
+
# B x D
|
285 |
+
arousal = self.arousal_layer(features)
|
286 |
+
valence = self.valence_layer(features)
|
287 |
+
dominance = self.dominance_layer(features)
|
288 |
+
|
289 |
+
if(self.predict_gender):
|
290 |
+
gender_outputs = self.gender_layer(features)
|
291 |
+
return arousal, valence, dominance, gender_outputs
|
292 |
+
|
293 |
+
return arousal, valence, dominance
|
294 |
+
|
295 |
+
# From huggingface
|
296 |
+
def get_feat_extract_output_lengths(self, input_length):
|
297 |
+
"""
|
298 |
+
Computes the output length of the convolutional layers
|
299 |
+
"""
|
300 |
+
def _conv_out_length(input_length, kernel_size, stride):
|
301 |
+
# 1D convolutional layer output length formula taken
|
302 |
+
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
303 |
+
return (input_length - kernel_size) // stride + 1
|
304 |
+
for kernel_size, stride in zip(self.backbone_model.config.conv_kernel, self.backbone_model.config.conv_stride):
|
305 |
+
input_length = _conv_out_length(input_length, kernel_size, stride)
|
306 |
+
return input_length
|
307 |
+
|
308 |
+
def prepare_mask(length, shape, dtype):
|
309 |
+
# Modified from huggingface
|
310 |
+
mask = torch.zeros(
|
311 |
+
shape, dtype=dtype
|
312 |
+
)
|
313 |
+
# these two operations make sure that all values
|
314 |
+
# before the output lengths indices are attended to
|
315 |
+
mask[(torch.arange(mask.shape[0]), length.cpu() - 1)] = 1
|
316 |
+
mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
317 |
+
return mask
|
318 |
+
|
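Finally, the learnable layer-weighted sum in step 5 of each wrapper's forward pass can be isolated as follows. This is an illustrative sketch; the shapes assume a WavLM-base-style 12-layer encoder with the conv/input embedding included (`use_conv_output=True`), not values taken from the upload.

```python
# Sketch: softmax-weighted mixture of all transformer layer outputs.
import torch
import torch.nn.functional as F

num_layers, B, T, D = 13, 2, 50, 768                 # 12 layers + input embeddings
hidden_states = torch.randn(num_layers, B, T, D)      # stand-in for .hidden_states
weights = torch.nn.Parameter(torch.ones(num_layers) / num_layers)

norm_weights = F.softmax(weights, dim=-1)             # mixing weights sum to 1
flat = hidden_states.view(num_layers, -1)             # [num_layers, B*T*D]
mixed = (norm_weights.unsqueeze(-1) * flat).sum(dim=0)
features = mixed.view(B, T, D)                        # [B, T, D], fed to the conv head
print(features.shape)                                 # torch.Size([2, 50, 768])
```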