Initial commit
This view is limited to 50 files because it contains too many changes.
- src/sonicverse/configs/tasks.json +208 -0
- src/sonicverse/configs/tasks_baseline.json +20 -0
- src/sonicverse/configs/tasks_ft.json +208 -0
- src/sonicverse/configs/tasks_pt_weight.json +10 -0
- src/sonicverse/configs/zero2.json +23 -0
- src/sonicverse/configs/zero3.json +28 -0
- src/sonicverse/configs/zero3_offload.json +56 -0
- src/sonicverse/multi_token.egg-info/PKG-INFO +6 -0
- src/sonicverse/multi_token.egg-info/SOURCES.txt +6 -0
- src/sonicverse/multi_token.egg-info/dependency_links.txt +1 -0
- src/sonicverse/multi_token.egg-info/requires.txt +8 -0
- src/sonicverse/multi_token.egg-info/top_level.txt +1 -0
- src/sonicverse/multi_token/constants.py +4 -0
- src/sonicverse/multi_token/data_tools.py +336 -0
- src/sonicverse/multi_token/inference.py +83 -0
- src/sonicverse/multi_token/language_models/__init__.py +7 -0
- src/sonicverse/multi_token/language_models/base_model.py +181 -0
- src/sonicverse/multi_token/language_models/mistral.py +235 -0
- src/sonicverse/multi_token/modalities/__init__.py +31 -0
- src/sonicverse/multi_token/modalities/audio_clap.py +142 -0
- src/sonicverse/multi_token/modalities/audio_descript.py +169 -0
- src/sonicverse/multi_token/modalities/audio_descript_bu.py +133 -0
- src/sonicverse/multi_token/modalities/audio_mert.py +162 -0
- src/sonicverse/multi_token/modalities/audio_mert_bu.py +159 -0
- src/sonicverse/multi_token/modalities/audio_whisper.py +120 -0
- src/sonicverse/multi_token/modalities/base_modality.py +48 -0
- src/sonicverse/multi_token/modalities/bu__init__.py +31 -0
- src/sonicverse/multi_token/modalities/document_gte.py +144 -0
- src/sonicverse/multi_token/modalities/imagebind.py +153 -0
- src/sonicverse/multi_token/modalities/multi_task_projector_shared.py +321 -0
- src/sonicverse/multi_token/modalities/projectors.py +416 -0
- src/sonicverse/multi_token/modalities/video_xclip.py +113 -0
- src/sonicverse/multi_token/modalities/vision_clip.py +178 -0
- src/sonicverse/multi_token/model_utils.py +112 -0
- src/sonicverse/multi_token/training.py +344 -0
- src/sonicverse/multi_token/training_data.py +133 -0
- src/sonicverse/requirements.txt +8 -0
- src/sonicverse/scripts/audio_setup.sh +3 -0
- src/sonicverse/scripts/clap_gpt_build_finetune_dataset.py +155 -0
- src/sonicverse/scripts/clap_gpt_build_pretrain_dataset.py +142 -0
- src/sonicverse/scripts/document_build_finetune_dataset.py +162 -0
- src/sonicverse/scripts/document_build_pretrain_dataset.py +89 -0
- src/sonicverse/scripts/document_setup.sh +5 -0
- src/sonicverse/scripts/evaluate_model.py +112 -0
- src/sonicverse/scripts/evaluate_model_latest.py +127 -0
- src/sonicverse/scripts/evaluate_model_mullama.py +168 -0
- src/sonicverse/scripts/evaluate_model_mullama_musiccaps.py +143 -0
- src/sonicverse/scripts/evaluate_model_mullama_musiccaps_fixed_prompt.py +138 -0
- src/sonicverse/scripts/evaluate_mullama.py +115 -0
- src/sonicverse/scripts/evaluate_temp.py +122 -0
src/sonicverse/configs/tasks.json
ADDED
@@ -0,0 +1,208 @@
{
    "backbone": {
        "num_layers": 5, "input_channels": 25, "output_channels": 25,
        "output_size": 4096, "hidden_size": 4096, "requires_grad": true
    },
    "task_heads": {
        "lmm_projector": {
            "num_layers": 3, "output_size": 4096, "hidden_size": 4096, "input_size": 768,
            "input_channels": 13, "width": 40, "weight": 1.0, "model_type": "mlp",
            "requires_grad": true, "use_aggregator": true, "use_time_average": true,
            "use_sigmoid": false, "use_transpose": false, "use_backbone_output": false
        },
        "instrument_detection": {
            "model_type": "mlp", "use_aggregator": true, "use_time_average": true,
            "use_sigmoid": true, "use_transpose": false, "num_layers": 2, "input_size": 508,
            "output_size": 40, "hidden_size": 4096, "width": 1, "weight": 0.1,
            "requires_grad": true, "num_conv_layers": 4, "output_channel": 1
        },
        "mood_detection": {
            "model_type": "mlp", "use_aggregator": true, "use_time_average": true,
            "use_sigmoid": true, "use_transpose": false, "num_layers": 2, "input_size": 508,
            "output_size": 56, "hidden_size": 4096, "width": 1, "weight": 0.1,
            "requires_grad": true, "num_conv_layers": 4, "output_channel": 1
        },
        "genre_detection": {
            "model_type": "mlp", "use_aggregator": true, "use_time_average": true,
            "use_sigmoid": true, "use_transpose": false, "num_layers": 2, "input_size": 508,
            "output_size": 87, "hidden_size": 4096, "width": 1, "weight": 0.1,
            "requires_grad": true, "num_conv_layers": 4, "output_channel": 1
        },
        "key_detection": {
            "model_type": "mlp", "use_aggregator": true, "use_time_average": true,
            "use_sigmoid": true, "use_transpose": false, "num_layers": 2, "input_size": 508,
            "output_size": 24, "hidden_size": 4096, "width": 1, "weight": 0.1,
            "requires_grad": true, "num_conv_layers": 4, "output_channel": 1
        },
        "vocals_detection": {
            "model_type": "mlp", "use_aggregator": true, "use_time_average": true,
            "use_sigmoid": true, "use_transpose": false, "num_layers": 2, "input_size": 508,
            "output_size": 3, "hidden_size": 4096, "width": 1, "weight": 0.1,
            "requires_grad": true, "num_conv_layers": 4, "output_channel": 1
        }
    },
    "task_projectors": {
        "instrument_detection": {
            "model_type": "mlp", "num_layers": 3, "input_channels": 0, "input_size": 40,
            "output_size": 4096, "hidden_size": 4096, "width": 4, "use_aggregator": false,
            "use_time_average": false, "use_sigmoid": false, "use_transpose": false,
            "requires_grad": true
        },
        "mood_detection": {
            "model_type": "mlp", "num_layers": 3, "input_channels": 0, "input_size": 56,
            "output_size": 4096, "hidden_size": 4096, "width": 4, "use_aggregator": false,
            "use_time_average": false, "use_sigmoid": false, "use_transpose": false,
            "requires_grad": true
        },
        "genre_detection": {
            "model_type": "mlp", "num_layers": 3, "input_channels": 0, "input_size": 87,
            "output_size": 4096, "hidden_size": 4096, "width": 4, "use_aggregator": false,
            "use_time_average": false, "use_sigmoid": false, "use_transpose": false,
            "requires_grad": true
        },
        "key_detection": {
            "model_type": "mlp", "num_layers": 3, "input_channels": 0, "input_size": 24,
            "output_size": 4096, "hidden_size": 4096, "width": 4, "use_aggregator": false,
            "use_time_average": false, "use_sigmoid": false, "use_transpose": false,
            "requires_grad": true
        },
        "vocals_detection": {
            "model_type": "mlp", "num_layers": 3, "input_channels": 0, "input_size": 3,
            "output_size": 4096, "hidden_size": 4096, "width": 4, "use_aggregator": false,
            "use_time_average": false, "use_sigmoid": false, "use_transpose": false,
            "requires_grad": true
        },
        "chords_detection": {
            "model_type": "mlp", "num_layers": 3, "input_channels": 0, "input_size": 216,
            "output_size": 4096, "hidden_size": 4096, "width": 4, "use_aggregator": false,
            "use_time_average": false, "use_sigmoid": false, "use_transpose": false,
            "requires_grad": true
        },
        "beats_detection": {
            "model_type": "mlp_conv_agg", "num_layers": 3, "input_channels": 2, "input_size": 500,
            "output_size": 4096, "hidden_size": 4096, "width": 4, "use_aggregator": true,
            "use_time_average": false, "use_sigmoid": false, "use_transpose": true,
            "requires_grad": true
        }
    }
}

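For orientation (not part of the committed files), here is a minimal Python sketch of reading this task configuration; the path is simply where this commit places the file, and the printed fields are the ones the training code later in this commit consumes as per-task output sizes and loss weights.

import json

with open("src/sonicverse/configs/tasks.json") as f:
    tasks = json.load(f)

# Each task head declares an output size and a loss weight; MistralLMMForCausalLM.forward
# (added further down in this commit) scales each task loss by head["weight"].
for name, head in tasks["task_heads"].items():
    print(name, head["output_size"], head["weight"])
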
src/sonicverse/configs/tasks_baseline.json
ADDED
@@ -0,0 +1,20 @@
{
    "task_heads": {
        "lmm_projector": {
            "num_layers": 3,
            "output_size": 4096,
            "hidden_size": 4096,
            "input_size": 768,
            "input_channels": 13,
            "width": 60,
            "weight": 1.0,
            "model_type": "mlp",
            "requires_grad": true,
            "use_aggregator": true,
            "use_time_average": true,
            "use_sigmoid": false,
            "use_transpose": false
        }
    },
    "task_projectors": {}
}

src/sonicverse/configs/tasks_ft.json
ADDED
@@ -0,0 +1,208 @@
{
    "backbone": {
        "num_layers": 5, "input_channels": 25, "output_channels": 25,
        "output_size": 4096, "hidden_size": 4096, "requires_grad": false
    },
    "task_heads": {
        "lmm_projector": {
            "num_layers": 3, "output_size": 4096, "hidden_size": 4096, "input_size": 768,
            "input_channels": 13, "width": 40, "weight": 1.0, "model_type": "mlp",
            "requires_grad": true, "use_aggregator": true, "use_time_average": true,
            "use_sigmoid": false, "use_transpose": false, "use_backbone_output": false
        },
        "instrument_detection": {
            "model_type": "mlp", "use_aggregator": true, "use_time_average": true,
            "use_sigmoid": true, "use_transpose": false, "num_layers": 2, "input_size": 508,
            "output_size": 40, "hidden_size": 4096, "width": 1, "weight": 0.0,
            "requires_grad": false, "num_conv_layers": 4, "output_channel": 1
        },
        "mood_detection": {
            "model_type": "mlp", "use_aggregator": true, "use_time_average": true,
            "use_sigmoid": true, "use_transpose": false, "num_layers": 2, "input_size": 508,
            "output_size": 56, "hidden_size": 4096, "width": 1, "weight": 0.0,
            "requires_grad": false, "num_conv_layers": 4, "output_channel": 1
        },
        "genre_detection": {
            "model_type": "mlp", "use_aggregator": true, "use_time_average": true,
            "use_sigmoid": true, "use_transpose": false, "num_layers": 2, "input_size": 508,
            "output_size": 87, "hidden_size": 4096, "width": 1, "weight": 0.0,
            "requires_grad": false, "num_conv_layers": 4, "output_channel": 1
        },
        "key_detection": {
            "model_type": "mlp", "use_aggregator": true, "use_time_average": true,
            "use_sigmoid": true, "use_transpose": false, "num_layers": 2, "input_size": 508,
            "output_size": 24, "hidden_size": 4096, "width": 1, "weight": 0.0,
            "requires_grad": false, "num_conv_layers": 4, "output_channel": 1
        },
        "vocals_detection": {
            "model_type": "mlp", "use_aggregator": true, "use_time_average": true,
            "use_sigmoid": true, "use_transpose": false, "num_layers": 2, "input_size": 508,
            "output_size": 3, "hidden_size": 4096, "width": 1, "weight": 0.0,
            "requires_grad": false, "num_conv_layers": 4, "output_channel": 1
        }
    },
    "task_projectors": {
        "instrument_detection": {
            "model_type": "mlp", "num_layers": 3, "input_channels": 0, "input_size": 40,
            "output_size": 4096, "hidden_size": 4096, "width": 4, "use_aggregator": false,
            "use_time_average": false, "use_sigmoid": false, "use_transpose": false,
            "requires_grad": true
        },
        "mood_detection": {
            "model_type": "mlp", "num_layers": 3, "input_channels": 0, "input_size": 56,
            "output_size": 4096, "hidden_size": 4096, "width": 4, "use_aggregator": false,
            "use_time_average": false, "use_sigmoid": false, "use_transpose": false,
            "requires_grad": true
        },
        "genre_detection": {
            "model_type": "mlp", "num_layers": 3, "input_channels": 0, "input_size": 87,
            "output_size": 4096, "hidden_size": 4096, "width": 4, "use_aggregator": false,
            "use_time_average": false, "use_sigmoid": false, "use_transpose": false,
            "requires_grad": true
        },
        "key_detection": {
            "model_type": "mlp", "num_layers": 3, "input_channels": 0, "input_size": 24,
            "output_size": 4096, "hidden_size": 4096, "width": 4, "use_aggregator": false,
            "use_time_average": false, "use_sigmoid": false, "use_transpose": false,
            "requires_grad": true
        },
        "vocals_detection": {
            "model_type": "mlp", "num_layers": 3, "input_channels": 0, "input_size": 3,
            "output_size": 4096, "hidden_size": 4096, "width": 4, "use_aggregator": false,
            "use_time_average": false, "use_sigmoid": false, "use_transpose": false,
            "requires_grad": true
        },
        "chords_detection": {
            "model_type": "mlp", "num_layers": 3, "input_channels": 0, "input_size": 216,
            "output_size": 4096, "hidden_size": 4096, "width": 4, "use_aggregator": false,
            "use_time_average": false, "use_sigmoid": false, "use_transpose": false,
            "requires_grad": true
        },
        "beats_detection": {
            "model_type": "mlp_conv_agg", "num_layers": 3, "input_channels": 2, "input_size": 500,
            "output_size": 4096, "hidden_size": 4096, "width": 4, "use_aggregator": true,
            "use_time_average": false, "use_sigmoid": false, "use_transpose": true,
            "requires_grad": true
        }
    }
}

src/sonicverse/configs/tasks_pt_weight.json
ADDED
@@ -0,0 +1,10 @@
{
    "pretrained_paths": [
        {
            "path": "/experiments/music_extraction/mlp_shared_multi_task_trial_002/train_002_epoch_45_step_40.pth",
            "components": ["backbone", "instrument_detection", "genre_detection", "mood_detection", "key_detection", "vocals_detection"],
            "use_prefix": true,
            "prefix": "audio_mert_lmm_projector"
        }
    ]
}

src/sonicverse/configs/zero2.json
ADDED
@@ -0,0 +1,23 @@
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "train_micro_batch_size_per_gpu": "auto",
    "train_batch_size": "auto",
    "gradient_accumulation_steps": "auto",
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto"
    }
}

src/sonicverse/configs/zero3.json
ADDED
@@ -0,0 +1,28 @@
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "train_micro_batch_size_per_gpu": "auto",
    "train_batch_size": "auto",
    "gradient_accumulation_steps": "auto",
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    }
}

src/sonicverse/configs/zero3_offload.json
ADDED
@@ -0,0 +1,56 @@
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "gather_16bit_weights_on_model_save": true
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "steps_per_print": 1e5,
    "wall_clock_breakdown": false
}

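The three zero*.json files are standard DeepSpeed ZeRO-2/ZeRO-3 configurations in which the "auto" placeholders are resolved by the Hugging Face Trainer integration from its own arguments. A minimal sketch of wiring one in (the argument values below are placeholders, not settings taken from this commit):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./output",                          # placeholder
    per_device_train_batch_size=4,                  # fills train_micro_batch_size_per_gpu: "auto"
    gradient_accumulation_steps=8,                  # fills gradient_accumulation_steps: "auto"
    bf16=True,                                      # fills bf16.enabled: "auto"
    deepspeed="src/sonicverse/configs/zero2.json",  # pick zero2/zero3/zero3_offload as needed
)
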
src/sonicverse/multi_token.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,6 @@
Metadata-Version: 2.1
Name: multi-token
Version: 0.0.4
Home-page: https://github.com/sshh12/multi_token
Author: Shrivu Shankar
License: Apache License 2.0

src/sonicverse/multi_token.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,6 @@
setup.py
multi_token.egg-info/PKG-INFO
multi_token.egg-info/SOURCES.txt
multi_token.egg-info/dependency_links.txt
multi_token.egg-info/requires.txt
multi_token.egg-info/top_level.txt

src/sonicverse/multi_token.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
src/sonicverse/multi_token.egg-info/requires.txt
ADDED
@@ -0,0 +1,8 @@
transformers>=4.34.0
accelerate>=0.21.0
scipy>=1.11.3
bitsandbytes>=0.41.0
datasets>=2.14.5
sentencepiece>=0.1.99
peft>=0.4.0
deepspeed==0.9.5

src/sonicverse/multi_token.egg-info/top_level.txt
ADDED
@@ -0,0 +1 @@
src/sonicverse/multi_token/constants.py
ADDED
@@ -0,0 +1,4 @@
IGNORE_INDEX = -100

ROLE_ASSISTANT = "assistant"
ROLE_USER = "user"

src/sonicverse/multi_token/data_tools.py
ADDED
@@ -0,0 +1,336 @@
from typing import Dict, List, Any, Union, Optional
from collections import Counter
from functools import cache
import contextlib
import tempfile
import shutil
import random
import subprocess
import json
import re
import io
import os

import torch
import requests
import transformers
import numpy as np
from datasets import load_dataset, Dataset
from PIL import Image

from multi_token.constants import IGNORE_INDEX


def encode_chat(
    item: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    modalities: List["Modality"],
) -> Dict:
    messages = list(item["messages"])
    chat_as_string = tokenizer.apply_chat_template(messages, tokenize=False)

    token_to_modality = {m.token: m for m in modalities}
    modality_token_counts = Counter()
    instruct_pattern = r"(\[INST\][\s\S]*?\[\/INST\])"
    pattern = "(" + "|".join(re.escape(m.token) for m in modalities) + ")"

    chat_part = re.split(instruct_pattern, chat_as_string)
    input_ids = []
    labels = []
    for part in chat_part:
        if "[INST]" in part:
            is_instruction = True
        else:
            is_instruction = False
        for subpart in re.split(pattern, part):
            if not subpart:
                continue
            if subpart in token_to_modality:
                assert (
                    is_instruction
                ), "There should be no modality tokens outside of instructions"
                m = token_to_modality[subpart]
                modality_token_counts[m.name] += 1
                input_ids.extend([m.token_idx] * m.token_width)
                labels.extend([IGNORE_INDEX] * m.token_width)
            elif is_instruction:
                part_ids = tokenizer(subpart, add_special_tokens=False).input_ids
                input_ids.extend(part_ids)
                labels.extend([IGNORE_INDEX] * len(part_ids))
            else:
                part_ids = tokenizer(subpart, add_special_tokens=False).input_ids
                input_ids.extend(part_ids)
                labels.extend(part_ids)

    input_ids = torch.tensor(input_ids, dtype=torch.long)
    labels = torch.tensor(labels, dtype=torch.long)

    data_dict = dict(
        input_ids=input_ids,
        labels=labels,
    )
    for m in modalities:
        data_dict[m.name] = m.preprocess_rows([item])[0]
    return data_dict


def encode_chat_multitask(
    item: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    modalities: List["Modality"],
) -> Dict:
    messages = list(item["messages"])
    chat_as_string = tokenizer.apply_chat_template(messages, tokenize=False)

    token_to_modality = {m.token: m for m in modalities}
    modality_token_counts = Counter()
    instruct_pattern = r"(\[INST\][\s\S]*?\[\/INST\])"
    pattern = "(" + "|".join(re.escape(m.token) for m in modalities) + ")"

    chat_part = re.split(instruct_pattern, chat_as_string)
    input_ids = []
    labels = []
    labels.append([])
    for part in chat_part:
        if "[INST]" in part:
            is_instruction = True
        else:
            is_instruction = False
        for subpart in re.split(pattern, part):
            if not subpart:
                continue
            if subpart in token_to_modality:
                assert (
                    is_instruction
                ), "There should be no modality tokens outside of instructions"
                m = token_to_modality[subpart]
                modality_token_counts[m.name] += 1
                input_ids.extend([m.token_idx] * m.token_width)
                labels[0].extend([IGNORE_INDEX] * m.token_width)
            elif is_instruction:
                part_ids = tokenizer(subpart, add_special_tokens=False).input_ids
                input_ids.extend(part_ids)
                labels[0].extend([IGNORE_INDEX] * len(part_ids))
            else:
                part_ids = tokenizer(subpart, add_special_tokens=False).input_ids
                input_ids.extend(part_ids)
                labels[0].extend(part_ids)

    input_ids = torch.tensor(input_ids, dtype=torch.long)
    labels[0] = torch.tensor(labels[0], dtype=torch.long)

    task_list = []
    for m in modalities:
        task_list += m.tasks["task_heads"].keys()
        # labels[task_specs["task_id"]] = load_tensor(item[task_name][0])

    for task_name in task_list:
        if task_name != "lmm_projector":
            labels.append(load_tensor(item[task_name][0]))

    # labels = torch.tensor(labels, dtype=torch.long)

    data_dict = dict(
        input_ids=input_ids,
        labels=labels,
    )
    for m in modalities:
        data_dict[m.name] = m.preprocess_rows([item])[0]
    return data_dict


def load_tensor(path: str) -> np.ndarray:
    return torch.tensor(np.load(path))


def load_image(value: Any) -> Image.Image:
    img = None
    if isinstance(value, str):
        if value.startswith("http://") or value.startswith("https://"):
            response = requests.get(value)
            img = Image.open(io.BytesIO(response.content))
        elif os.path.exists(value):
            img = Image.open(value)
    elif isinstance(value, Image.Image):
        img = value
    if img is None:
        raise ValueError(f"Could not load image from {value}")
    img = img.convert("RGB")
    return img


@contextlib.contextmanager
def with_local_files(fn_or_urls: List[Any]):
    local_fns = []
    fps = []
    for fn_or_url in fn_or_urls:
        if isinstance(fn_or_url, Image.Image):
            fp = tempfile.NamedTemporaryFile(suffix=".png", mode="wb")
            fn_or_url.convert("RGB").save(fp)
            fps.append(fp)
            local_fns.append(fp.name)
        elif fn_or_url.startswith("http://") or fn_or_url.startswith("https://"):
            suffix = os.path.splitext(fn_or_url)[-1]
            with requests.get(fn_or_url, stream=True) as r:
                fp = tempfile.NamedTemporaryFile(suffix=suffix, mode="wb")
                shutil.copyfileobj(r.raw, fp)
                fps.append(fp)
                local_fns.append(fp.name)
        else:
            local_fns.append(fn_or_url)
    try:
        yield local_fns
    finally:
        for fp in fps:
            fp.close()


@cache
def _get_dataset(dataset_args: str) -> Dataset:
    return load_dataset(**json.loads(dataset_args))


def get_dataset_cached(dataset_args: Dict) -> Dataset:
    return _get_dataset(json.dumps(dataset_args))


def load_audio_signal(input_: Union[Dict, str]) -> Dict:
    from audiotools import AudioSignal

    if isinstance(input_, dict) and "array" in input_:
        array = input_["array"]
    elif isinstance(input_, dict) and "dataset_args" in input_:
        item = get_dataset_cached(input_["dataset_args"])[input_["idx"]]
        array = item["audio"]["array"]
    elif isinstance(input_, dict) and "path" in input_:
        with with_local_files([input_["path"]]) as local_fns:
            array = AudioSignal(local_fns[0])
    elif isinstance(input_, str):
        with with_local_files([input_]) as local_fns:
            array = AudioSignal(local_fns[0])
    else:
        raise ValueError(f"Could not load audio from {input_}")

    return {"array": list(array)}


def load_audio(input_: Union[Dict, str], target_sampling_rate: int = None) -> Dict:
    import soundfile as sf
    import librosa

    if isinstance(input_, dict) and "array" in input_ and "sampling_rate" in input_:
        array = input_["array"]
        sampling_rate = input_["sampling_rate"]
    elif isinstance(input_, dict) and "dataset_args" in input_:
        item = get_dataset_cached(input_["dataset_args"])[input_["idx"]]
        array = item["audio"]["array"]
        sampling_rate = item["audio"]["sampling_rate"]
    elif isinstance(input_, dict) and "path" in input_:
        with with_local_files([input_["path"]]) as local_fns:
            array, sampling_rate = sf.read(local_fns[0])
    elif isinstance(input_, str):
        with with_local_files([input_]) as local_fns:
            array, sampling_rate = sf.read(local_fns[0])
    else:
        raise ValueError(f"Could not load audio from {input_}")

    if array.ndim == 2:
        array = array.mean(axis=1)

    if target_sampling_rate is not None and sampling_rate != target_sampling_rate:
        array = librosa.resample(
            array, orig_sr=sampling_rate, target_sr=target_sampling_rate
        )
        sampling_rate = target_sampling_rate

    return {"array": list(array), "sampling_rate": sampling_rate}


def _download_yt_video(url: str) -> str:
    from pytube import YouTube

    youtube = YouTube(url)
    video = youtube.streams.first()

    fn = "".join(random.choices("abcdefghijklmnopqrstuvwxyz", k=10))
    file_path = video.download(output_path=tempfile.gettempdir(), filename=fn)

    return file_path


def _read_video_pyav(container, indices):
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])


def _sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    converted_len = int(clip_len * frame_sample_rate)
    end_idx = np.random.randint(converted_len, seg_len)
    start_idx = end_idx - converted_len
    indices = np.linspace(start_idx, end_idx, num=clip_len)
    indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
    return indices


def load_video(
    input_: str,
    frames: int = 8,
    frame_sample_rate: int = 1,
    start_time: Optional[int] = None,
    end_time: Optional[int] = None,
) -> np.ndarray:
    import av

    delete_file = False

    if isinstance(input_, dict) and "youtube.com" and input_.get("url", ""):
        file_path = _download_yt_video(input_["url"])
        delete_file = True
        # start_time = input_.get("start_time", None)
        # end_time = input_.get("end_time", None)
    elif isinstance(input_, str) and "youtube.com" in input_:
        file_path = _download_yt_video(input_)
        delete_file = True
    elif isinstance(input_, str):
        file_path = input_
    else:
        raise ValueError(f"Could not load video from {input_}")

    if start_time is not None or end_time is not None:
        start_time = start_time if start_time is not None else 0
        end_time = end_time if end_time is not None else "end"
        trim_file_path = f"{file_path.rsplit('.', 1)[0]}_trim.mp4"
        subprocess.run(
            [
                "ffmpeg",
                "-i",
                file_path,
                "-ss",
                str(start_time),
                "-to",
                str(end_time),
                "-c",
                "copy",
                trim_file_path,
            ]
        )
        file_path = trim_file_path

    container = av.open(file_path)
    indices = _sample_frame_indices(
        clip_len=frames,
        frame_sample_rate=frame_sample_rate,
        seg_len=container.streams.video[0].frames,
    )
    video = _read_video_pyav(container, indices)

    if delete_file:
        os.remove(file_path)

    return video

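As an illustration (assumed usage, not code from this commit), load_audio above accepts a local path, URL, or dataset reference, downmixes stereo to mono, and optionally resamples; the file name and target rate below are placeholders.

from multi_token.data_tools import load_audio

audio = load_audio("song.wav", target_sampling_rate=24000)   # placeholder inputs
print(len(audio["array"]), audio["sampling_rate"])           # number of mono samples, 24000
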
src/sonicverse/multi_token/inference.py
ADDED
@@ -0,0 +1,83 @@
from typing import Type, List, Optional
import logging

from transformers import AutoTokenizer, AutoConfig, BitsAndBytesConfig
from huggingface_hub import hf_hub_download
from peft import PeftModel
import torch
import os

from multi_token.model_utils import fix_tokenizer, MultiTaskType
from multi_token.modalities.base_modality import Modality
from multi_token.language_models.mistral import MistralForCausalLM
from multi_token.language_models import LANGUAGE_MODEL_NAME_TO_CLASS
from multi_token.modalities import MODALITY_BUILDERS


def load_trained_lora_model(
    model_name_or_path: str,
    model_lora_path: str,
    model_cls: Optional[Type] = None,
    modalities: Optional[List[Modality]] = None,
    load_bits: int = 16,
    device_map: str = "auto",
    use_multi_task: int = MultiTaskType.NO_MULTI_TASK,
    tasks_config: str = None,
):
    load_kwargs = {"device_map": device_map}

    if load_bits == 8:
        load_kwargs["load_in_8bit"] = True
    elif load_bits == 4:
        load_kwargs["load_in_4bit"] = True
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    elif load_bits == 16:
        load_kwargs["torch_dtype"] = torch.float16
    else:
        raise ValueError(f"Invalid load_bits: {load_bits}")

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
    fix_tokenizer(tokenizer)

    cfg = AutoConfig.from_pretrained(model_lora_path)
    if model_cls is None:
        model_cls = LANGUAGE_MODEL_NAME_TO_CLASS[cfg.model_cls]
    if modalities is None:
        if use_multi_task:
            modalities = MODALITY_BUILDERS[cfg.modality_builder](
                use_multi_task=use_multi_task, tasks_config=tasks_config
            )
        else:
            modalities = MODALITY_BUILDERS[cfg.modality_builder]()

    logging.info(f"Loading base model from {model_name_or_path} as {load_bits} bits")
    model = model_cls.from_pretrained(
        model_name_or_path, low_cpu_mem_usage=True, config=cfg, **load_kwargs
    )
    model.modalities = modalities

    logging.info(f"Loading projector weights for {[m.name for m in modalities]}")
    if os.path.exists(os.path.join(model_lora_path, "non_lora_trainables.bin")):
        non_lora_trainables = torch.load(
            os.path.join(model_lora_path, "non_lora_trainables.bin"), map_location="cuda"
        )
    else:
        local_fn = hf_hub_download(
            repo_id=model_lora_path,
            filename="non_lora_trainables.bin",
            repo_type="model",
        )
        non_lora_trainables = torch.load(local_fn, map_location="cuda")
    model.get_model().initialize_pretrained_modules(modalities, non_lora_trainables)

    logging.info(f"Loading and merging LoRA weights from {model_lora_path}")
    model = PeftModel.from_pretrained(model, model_lora_path)
    if load_bits == 16:
        # TODO: Figure out why this fails for other bit sizes
        model = model.merge_and_unload()
    model.eval()

    return model, tokenizer

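Assumed usage of the helper above (both paths are placeholders; the base checkpoint only needs to match the one the LoRA adapter was trained against):

from multi_token.inference import load_trained_lora_model

model, tokenizer = load_trained_lora_model(
    model_name_or_path="mistralai/Mistral-7B-Instruct-v0.1",   # placeholder base model
    model_lora_path="/path/to/sonicverse_lora_checkpoint",     # placeholder adapter directory
    load_bits=16,
)
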
src/sonicverse/multi_token/language_models/__init__.py
ADDED
@@ -0,0 +1,7 @@
from multi_token.language_models.mistral import (
    MistralLMMForCausalLM,
)

LANGUAGE_MODEL_CLASSES = [MistralLMMForCausalLM]

LANGUAGE_MODEL_NAME_TO_CLASS = {cls.__name__: cls for cls in LANGUAGE_MODEL_CLASSES}

src/sonicverse/multi_token/language_models/base_model.py
ADDED
@@ -0,0 +1,181 @@
from typing import List, Dict
from abc import ABC, abstractmethod

from torch.nn.functional import conv1d
import torch
import logging

from multi_token.modalities.base_modality import Modality
from multi_token.model_utils import MultiTaskType

from torchviz import make_dot


class LMMMetaModel:
    def __init__(self, config):
        super(LMMMetaModel, self).__init__(config)

    def _load_projector_weights(self, weights: Dict):
        weights = {
            (k[23:] if k.startswith("base_model.model.model.") else k): v
            for k, v in weights.items()
        }
        logging.info(f"Loading pretrained weights: {list(weights.keys())}")
        load_result = self.load_state_dict(weights, strict=False)
        assert (
            len(load_result.unexpected_keys) == 0
        ), "Unexpected weights, is this the right model?"

    def initialize_pretrained_modules(self, modalities: List[Modality], weights: Dict):
        for m in modalities:
            # projector = m.build_projector(self.config.hidden_size)
            # setattr(self, m.name + "_lmm_projector", projector)
            projector = m.build_projector(self.config.hidden_size)
            if m.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
                for task_name in m.tasks["task_heads"].keys():
                    task_model = projector[task_name]
                    setattr(self, m.name + "_" + task_name, task_model)
            else:
                setattr(self, m.name + "_lmm_projector", projector)

        self._load_projector_weights(weights)

    def initialize_modules(self, modalities: List[Modality], weights: Dict):
        names = [m.name for m in modalities]

        self.config.modalities = names

        for m in modalities:
            # projector = m.build_projector(self.config.hidden_size)
            # setattr(self, m.name + "_lmm_projector", projector)
            projector = m.build_projector(self.config.hidden_size)
            if m.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
                for task_name in m.tasks["task_heads"].keys():
                    task_model = projector[task_name]
                    setattr(self, m.name + "_" + task_name, task_model)
            else:
                setattr(self, m.name + "_lmm_projector", projector)

        self._load_projector_weights(weights)


class LMMMetaForCausalLM(ABC):
    @abstractmethod
    def get_model(self) -> "LMMMetaForCausalLM":
        pass

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, attention_mask, past_key_values, labels, **kwargs
    ):
        model = self.get_model()

        batch_size, seq_len = input_ids.shape

        # batch_size x seq_len x embedding_hidden_size
        inputs_embeds = torch.zeros(
            (batch_size, seq_len, self.config.hidden_size),
            dtype=self.dtype,
            device=self.device,
        )

        # modality x batch_size x instance_idx x modality_token_width x embedding_hidden_size
        projected_tensors = []
        # assuming that if caching is enabled, we'll never have past_key_values AND need to encode the instruction modality values
        task_vals = {}

        # print("here past_key_values", past_key_values)
        # past_key_values == None
        if past_key_values is None:
            for m in self.modalities:
                m_vals = m.forward(kwargs.get(m.name))
                mp_vals = []
                if m.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
                    proj = {}
                    for task_name in m.tasks["task_heads"].keys():
                        proj[task_name] = getattr(model, m.name + "_" + task_name)
                else:
                    proj = getattr(model, m.name + "_lmm_projector")

                # project each batch into language model token space
                for m_val in m_vals:
                    if m.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
                        for task_name in m.tasks["task_heads"].keys():
                            if task_name == "lmm_projector":
                                mp_vals.append(proj[task_name](m_val))
                                # make_dot(mp_vals[-1], params=dict(list(model.named_parameters()))).render(task_name, format="png")
                            else:
                                if task_name not in task_vals:
                                    task_vals[task_name] = [proj[task_name](m_val)]
                                else:
                                    task_vals[task_name].append(proj[task_name](m_val))
                                # make_dot(task_vals[task_name], params=dict(list(model.named_parameters()))).render(task_name, format="png")

                    elif m.use_multi_task == MultiTaskType.PROJECTED_MULTI_TASK:
                        task_outputs = proj(m_val)
                        mp_vals.append(task_outputs.pop("projectors"))
                        for task_name in task_outputs.keys():
                            if not task_name in task_vals:
                                task_vals[task_name] = [task_outputs[task_name]]
                            else:
                                task_vals[task_name].append(task_outputs[task_name])
                    else:
                        mp_vals.append(proj(m_val))

                assert all(
                    mp_val.shape[1:] == (m.token_width, self.config.hidden_size)
                    for mp_val in mp_vals
                ), (
                    "Modality tensors have incorrect shape, check your projector implementation "
                    + str([mp_val.shape[1:] for mp_val in mp_vals])
                    + " vs expected "
                    + str((m.token_width, self.config.hidden_size))
                )
                projected_tensors.append(mp_vals)

        indices = None
        for i, input_ids_sample in enumerate(input_ids):
            is_text_mask = input_ids_sample >= 0

            # fill in all the LLM-based text embeddings
            inputs_embeds[i, is_text_mask] = model.embed_tokens(
                input_ids_sample[is_text_mask]
            )

            # skip if all tokens are text tokens
            if is_text_mask.sum() == seq_len:
                continue
            assert (
                past_key_values is None
            ), "We shouldn't have cached keys if this is the first instruction pass"

            # past_key_values = None

            for mi, m in enumerate(self.modalities):
                # locate the group of tokens for this modality
                m_mask = (input_ids_sample == m.token_idx).float()
                m_kernel = torch.tensor(
                    [-1] * m.token_width, dtype=m_mask.dtype, device=m_mask.device
                )
                m_conv = conv1d(
                    m_mask.unsqueeze(0).unsqueeze(0),
                    m_kernel.unsqueeze(0).unsqueeze(0),
                )

                # where do we see `token_width`-tokens in a row?
                indices = (m_conv[0, 0] == -m.token_width).nonzero(as_tuple=True)[0]

                # fill these embeddings with the projected modality tensor
                last_covered_idx = -1
                k = 0
                for possible_token_idx in indices:
                    if possible_token_idx <= last_covered_idx:
                        # make sure we don't overwrite an instance we've already covered
                        # handles bug caused by back-to-back tokens
                        continue
                    batch_modality_tensor = projected_tensors[mi][i][k]
                    inputs_embeds[
                        i, possible_token_idx : possible_token_idx + m.token_width
                    ] = batch_modality_tensor
                    last_covered_idx = possible_token_idx + m.token_width - 1
                    k += 1

        return None, attention_mask, past_key_values, inputs_embeds, labels, task_vals

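A self-contained toy check (not from this commit) of the conv1d trick prepare_inputs_labels_for_multimodal uses to find where a run of token_width modality tokens begins; the token index and width here are arbitrary example values.

import torch
from torch.nn.functional import conv1d

token_idx, token_width = 32000, 3          # arbitrary placeholder values
sample = torch.tensor([1, 5, 32000, 32000, 32000, 9, 2])
mask = (sample == token_idx).float()
kernel = torch.tensor([-1.0] * token_width)
out = conv1d(mask.unsqueeze(0).unsqueeze(0), kernel.unsqueeze(0).unsqueeze(0))
starts = (out[0, 0] == -token_width).nonzero(as_tuple=True)[0]
print(starts)  # tensor([2]): the run of modality tokens begins at position 2
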
src/sonicverse/multi_token/language_models/mistral.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from typing import List, Optional, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch.nn import CrossEntropyLoss
|
7 |
+
|
8 |
+
from transformers import (
|
9 |
+
AutoConfig,
|
10 |
+
AutoModelForCausalLM,
|
11 |
+
MistralConfig,
|
12 |
+
MistralModel,
|
13 |
+
MistralForCausalLM,
|
14 |
+
)
|
15 |
+
|
16 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
17 |
+
|
18 |
+
from multi_token.language_models.base_model import (
|
19 |
+
LMMMetaModel,
|
20 |
+
LMMMetaForCausalLM,
|
21 |
+
)
|
22 |
+
|
23 |
+
|
24 |
+
class MistralLMMConfig(MistralConfig):
|
25 |
+
model_type = "mistral-lmm"
|
26 |
+
|
27 |
+
|
28 |
+
class MistralLMMModel(LMMMetaModel, MistralModel):
|
29 |
+
config_class = MistralLMMConfig
|
30 |
+
|
31 |
+
def __init__(self, config: MistralLMMConfig):
|
32 |
+
super(MistralLMMModel, self).__init__(config)
|
33 |
+
|
34 |
+
|
35 |
+
class MistralLMMForCausalLM(MistralForCausalLM, LMMMetaForCausalLM):
|
36 |
+
config_class = MistralLMMConfig
|
37 |
+
|
38 |
+
def __init__(self, config):
|
39 |
+
super(MistralForCausalLM, self).__init__(config)
|
40 |
+
self.model = MistralLMMModel(config)
|
41 |
+
|
42 |
+
self.vocab_size = config.vocab_size
|
43 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
44 |
+
self.modalities = None
|
45 |
+
|
46 |
+
# Initialize weights and apply final processing
|
47 |
+
self.post_init()
|
48 |
+
|
49 |
+
def get_model(self) -> "MistralLMMForCausalLM":
|
50 |
+
return self.model
|
51 |
+
|
52 |
+
def forward(
|
53 |
+
self,
|
54 |
+
input_ids: torch.LongTensor = None,
|
55 |
+
attention_mask: Optional[torch.Tensor] = None,
|
56 |
+
position_ids: Optional[torch.LongTensor] = None,
|
57 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
58 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
59 |
+
labels: Optional[List] = None,
|
60 |
+
use_cache: Optional[bool] = None,
|
61 |
+
output_attentions: Optional[bool] = None,
|
62 |
+
output_hidden_states: Optional[bool] = None,
|
63 |
+
return_dict: Optional[bool] = None,
|
64 |
+
**kwargs
|
65 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
66 |
+
#print("Past keys ",past_key_values)
|
67 |
+
output_attentions = (
|
68 |
+
output_attentions
|
69 |
+
if output_attentions is not None
|
70 |
+
else self.config.output_attentions
|
71 |
+
)
|
72 |
+
output_hidden_states = (
|
73 |
+
output_hidden_states
|
74 |
+
if output_hidden_states is not None
|
75 |
+
else self.config.output_hidden_states
|
76 |
+
)
|
77 |
+
return_dict = (
|
78 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
79 |
+
)
|
80 |
+
|
81 |
+
if labels != None:
|
82 |
+
labels_inp = labels[0]
|
83 |
+
else:
|
84 |
+
labels_inp = labels
|
85 |
+
(
|
86 |
+
input_ids,
|
87 |
+
attention_mask,
|
88 |
+
past_key_values,
|
89 |
+
inputs_embeds,
|
90 |
+
lmm_labels,
|
91 |
+
task_values
|
92 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
93 |
+
input_ids, attention_mask, past_key_values, labels_inp, **kwargs
|
94 |
+
)
|
95 |
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        # Pair each auxiliary task head's predictions with the matching labels.
        if labels is not None:
            task_pairs = {}
            task_list = list(task_values.keys())
            for task_id in range(len(task_list)):
                _task_labels = []
                _task_outputs = []

                _task = task_list[task_id]
                for inst in range(len(task_values[_task])):
                    _task_outputs.append(task_values[_task][inst].unsqueeze(0))
                    _task_labels.append(torch.stack([labels[1][inst][task_id]]))

                task_pairs[_task] = [_task_labels, _task_outputs]

        # Standard next-token language-modeling loss.
        loss = None
        if lmm_labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = lmm_labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        # Per-task auxiliary losses: cross-entropy for the projector task,
        # multi-label BCE for the remaining classification heads.
        if labels is not None:
            task_loss = {}
            for task in task_list:
                preds = torch.cat(task_pairs[task][1], dim=0)
                labs = torch.cat(task_pairs[task][0], dim=0)
                if task == "lmm_projector":
                    task_loss[task] = CrossEntropyLoss()(preds, labs)
                else:
                    task_loss[task] = nn.BCEWithLogitsLoss()(preds, labs)

        # Combine the LM loss with the weighted sum of non-zero-weight task losses.
        total_loss = None
        if labels is not None:
            total_task_loss = None
            for task in task_list:
                weight = self.modalities[0].tasks["task_heads"][task]["weight"]
                if weight != 0.0:
                    if total_task_loss is not None:
                        total_task_loss += weight * task_loss[task]
                    else:
                        total_task_loss = weight * task_loss[task]

            if total_task_loss is not None:
                lm_weight = self.modalities[0].tasks["task_heads"]["lmm_projector"]["weight"]
                total_loss = lm_weight * loss + total_task_loss
            else:
                total_loss = loss

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (total_loss,) + output if total_loss is not None else output

        return CausalLMOutputWithPast(
            loss=total_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        modality_inputs=None,
        **kwargs,
    ):
        # Once a KV cache exists, only the most recent token needs to be fed in.
        if past_key_values:
            input_ids = input_ids[:, -1:]

        if inputs_embeds is not None:
            raise ValueError("inputs_embeds not supported")

        model_inputs = {
            "input_ids": input_ids,
            "position_ids": None,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
            **(modality_inputs or {}),
        }

        return model_inputs


AutoConfig.register("mistral-lmm", MistralLMMConfig)
AutoModelForCausalLM.register(MistralLMMConfig, MistralLMMForCausalLM)
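The loss combination above reduces to a weighted sum gated by the `weight` field of each entry in the tasks config. A minimal standalone sketch of that idea, with hypothetical task names and weights (not the repository's actual config values):

import torch

# Hypothetical per-task weights; in the model they come from tasks["task_heads"].
task_weights = {"lmm_projector": 1.0, "genre": 0.25, "instruments": 0.25}

def combine_losses(lm_loss: torch.Tensor, task_losses: dict) -> torch.Tensor:
    """Weighted sum of the LM loss and any non-zero-weight auxiliary losses."""
    total = task_weights["lmm_projector"] * lm_loss
    for name, value in task_losses.items():
        w = task_weights.get(name, 0.0)
        if w != 0.0:
            total = total + w * value
    return total

# Example: dummy scalars standing in for CrossEntropyLoss / BCEWithLogitsLoss outputs.
lm_loss = torch.tensor(2.3)
aux = {"genre": torch.tensor(0.7), "instruments": torch.tensor(0.4)}
print(combine_losses(lm_loss, aux))  # 2.3 + 0.25*0.7 + 0.25*0.4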
src/sonicverse/multi_token/modalities/__init__.py
ADDED
@@ -0,0 +1,31 @@
from multi_token.model_utils import MultiTaskType
from multi_token.modalities.vision_clip import (
    CLIPVisionModality,
    OUTPUT_LAYER as CLIP_POOL_LAYER,
)
from multi_token.modalities.imagebind import ImageBindModality
from multi_token.modalities.document_gte import DocumentGTEModality
from multi_token.modalities.audio_whisper import WhisperAudioModality
from multi_token.modalities.audio_clap import CLAPAudioModality
from multi_token.modalities.video_xclip import XCLIPVideoModality
from multi_token.modalities.audio_descript import DescriptAudioModality
from multi_token.modalities.audio_mert import MERTAudioModality

MODALITY_BUILDERS = {
    "vision_clip": lambda: [CLIPVisionModality()],
    "vision_clip_pool": lambda: [
        CLIPVisionModality(feature_layer=CLIP_POOL_LAYER, num_tokens_output=10)
    ],
    "audio_whisper": lambda: [
        WhisperAudioModality(
            num_tokens_output=10, model_name_or_path="openai/whisper-small"
        )
    ],
    "audio_mert": lambda use_multi_task=MultiTaskType.NO_MULTI_TASK, tasks_config=None: [
        MERTAudioModality(
            use_multi_task=use_multi_task,
            tasks_config=tasks_config,
            num_tokens_output=60,
            hidden_dim=32,
            num_conv_layers=3,
            num_mlp_layers=2,
        )
    ],
    "audio_clap": lambda use_multi_task=MultiTaskType.NO_MULTI_TASK, tasks_config=None: [
        CLAPAudioModality(
            use_multi_task=use_multi_task,
            tasks_config=tasks_config,
            num_tokens_output=20,
        )
    ],
    "audio_descript": lambda use_multi_task=MultiTaskType.NO_MULTI_TASK, tasks_config=None: [
        DescriptAudioModality(
            use_multi_task=use_multi_task,
            tasks_config=tasks_config,
            num_projector_conv_layers=1,
            num_projector_mlp_layers=1,
            num_tokens_output=60,
            codebooks=96,
        )
    ],
    "video_xclip": lambda: [XCLIPVideoModality(num_tokens_output=10)],
    "imagebind": lambda: [ImageBindModality()],
    "document_gte": lambda: [DocumentGTEModality()],
    "document_gte_x16": lambda: [DocumentGTEModality(num_tokens_output=32)],
}
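Each `MODALITY_BUILDERS` entry is a small factory returning a list of modality objects, so calling code stays agnostic of the concrete class. A hedged usage sketch (the builder key is real; the config path is illustrative):

from multi_token.model_utils import MultiTaskType
from multi_token.modalities import MODALITY_BUILDERS

# Builders that support multi-task heads accept the two extra keyword arguments;
# the simpler ones are plain zero-argument lambdas.
modalities = MODALITY_BUILDERS["audio_mert"](
    use_multi_task=MultiTaskType.PROJECTED_MULTI_TASK,
    tasks_config="configs/tasks.json",  # illustrative path
)
for m in modalities:
    print(m.name, m.token, m.token_width)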
src/sonicverse/multi_token/modalities/audio_clap.py
ADDED
@@ -0,0 +1,142 @@
from typing import Dict, List, Optional

import torch
import torch.nn as nn
from transformers import ClapModel, ClapProcessor

from multi_token.model_utils import MultiTaskType
from multi_token.data_tools import load_audio
from multi_token.modalities.base_modality import Modality
from multi_token.modalities.projectors import (
    build_mlp_vector_projector,
    build_mt_vector_projector,
    MultiTaskModel,
)

import json

OUTPUT_EMB_SIZE = 512


class CLAPAudioModule(nn.Module):
    def __init__(self, model_name_or_path: str):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.model = None
        self.processor = None

        self.load_model()

    def load_model(self):
        self.model = ClapModel.from_pretrained(self.model_name_or_path)
        self.processor = ClapProcessor.from_pretrained(self.model_name_or_path)
        self.model.requires_grad_(False)

    @torch.no_grad()
    def forward(self, audios) -> torch.Tensor:
        embs = []
        for audio_features in audios:
            features = self.model.get_audio_features(
                input_features=audio_features["input_features"].to(torch.float32),
                is_longer=audio_features["is_longer"],
            )
            embs.append(features)
        embs = torch.stack(embs)
        return embs.view(-1, 1, OUTPUT_EMB_SIZE)

    @property
    def dtype(self):
        return self.model.dtype

    @property
    def device(self):
        return self.model.device


class CLAPAudioModality(Modality):
    def __init__(
        self,
        model_name_or_path: str = "laion/clap-htsat-fused",
        num_projector_layers: int = 2,
        num_tokens_output: int = 10,
        use_multi_task: int = MultiTaskType.NO_MULTI_TASK,
        tasks_config: str = None,
    ):
        self.model_name_or_path = model_name_or_path
        self.module = CLAPAudioModule(model_name_or_path=self.model_name_or_path)
        self.num_projector_layers = num_projector_layers
        self.num_tokens_output = num_tokens_output
        self.dtype = torch.float32
        self.use_multi_task = use_multi_task
        self.tasks = None
        if self.use_multi_task != MultiTaskType.NO_MULTI_TASK:
            with open(tasks_config, "r") as f:
                self.tasks = json.load(f)
            print("Tasks :", self.tasks)

    def build_projector(self, lm_hidden_size: int) -> nn.Module:
        if self.use_multi_task == MultiTaskType.PROJECTED_MULTI_TASK:
            return MultiTaskModel(OUTPUT_EMB_SIZE, self.tasks)
        elif self.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
            return build_mt_vector_projector(
                input_hidden_size=OUTPUT_EMB_SIZE,
                lm_hidden_size=lm_hidden_size,
                tasks=self.tasks,
            )
        else:
            return build_mlp_vector_projector(
                input_hidden_size=OUTPUT_EMB_SIZE,
                lm_hidden_size=lm_hidden_size,
                num_layers=self.num_projector_layers,
                num_tokens=self.num_tokens_output,
            )

    @property
    def name(self) -> str:
        return "audio_clap"

    @property
    def token(self) -> str:
        return "<sound>"

    @property
    def data_key(self) -> str:
        return "sounds"

    @property
    def token_width(self) -> int:
        return self.num_tokens_output

    def to(self, dtype: torch.dtype, device: torch.device) -> "CLAPAudioModality":
        self.dtype = dtype
        self.module.to(device=device)
        return self

    def preprocess_rows(self, rows: List[Dict]) -> List[Optional[Dict]]:
        row_values = []
        for row in rows:
            audios = []
            for audio_dict in row[self.data_key]:
                audio_dict = load_audio(
                    audio_dict,
                    target_sampling_rate=self.module.processor.feature_extractor.sampling_rate,
                )
                audio_processed = self.module.processor(
                    audios=audio_dict["array"],
                    return_tensors="pt",
                    sampling_rate=audio_dict["sampling_rate"],
                )
                audios.append(audio_processed)
            row_values.append(audios)
        return row_values

    @torch.no_grad()
    def forward(self, encoded_values: List[torch.Tensor]) -> List[torch.Tensor]:
        audio_features = []
        for audio_batch in encoded_values:
            audio_features.append(self.module.forward(audio_batch).to(dtype=self.dtype))
        return audio_features
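For context, a short sketch of what `CLAPAudioModule.forward` expects and returns, assuming the `laion/clap-htsat-fused` checkpoint and a dummy waveform (48 kHz is the CLAP feature extractor's sampling rate):

import numpy as np
import torch
from transformers import ClapModel, ClapProcessor

model = ClapModel.from_pretrained("laion/clap-htsat-fused")
processor = ClapProcessor.from_pretrained("laion/clap-htsat-fused")

waveform = np.zeros(48_000, dtype=np.float32)  # one second of silence
inputs = processor(audios=waveform, return_tensors="pt", sampling_rate=48_000)

with torch.no_grad():
    emb = model.get_audio_features(
        input_features=inputs["input_features"],
        is_longer=inputs["is_longer"],
    )
print(emb.shape)  # (1, 512); the module reshapes batches to (-1, 1, OUTPUT_EMB_SIZE)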
src/sonicverse/multi_token/modalities/audio_descript.py
ADDED
@@ -0,0 +1,169 @@
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import dac
from audiotools import AudioSignal

from multi_token.model_utils import MultiTaskType
from multi_token.data_tools import load_audio_signal
from multi_token.modalities.base_modality import Modality
from multi_token.modalities.projectors import (
    build_mlp_vector_projector,
    build_mt_vector_projector,
    build_attentive_cnn_projector,
    build_cnn_mlp_projector,
    MultiTaskModel,
)

import json

OUTPUT_FRAMES_SIZE = 512
OUTPUT_EMB_CHANNELS = 96


class DescriptAudioModule(nn.Module):
    def __init__(self, model_name_or_path: str, codebooks=4):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.model = None
        self.processor = None
        self.codebooks = codebooks

        self.load_model()

    def load_model(self):
        self.model = dac.DAC.load(self.model_name_or_path)

    def forward(self, audios) -> torch.Tensor:
        embs = []
        for audio_features in audios:
            x = self.model.preprocess(
                audio_features[0].audio_data, audio_features[0].sample_rate
            )
            z, codes, latents, _, _ = self.model.encode(x)

            # Crop or right-pad the latents along the time axis to a fixed frame count.
            if latents.shape[2] > OUTPUT_FRAMES_SIZE:
                latents = latents[:, :, :OUTPUT_FRAMES_SIZE]
            elif latents.shape[2] < OUTPUT_FRAMES_SIZE:
                pad_width = (0, OUTPUT_FRAMES_SIZE - latents.shape[2])
                latents = torch.nn.functional.pad(latents, pad_width)

            latents = latents[0][: self.codebooks]
            embs.append(latents)

        embs = torch.stack(embs)
        return embs

    @property
    def dtype(self):
        return self.model.dtype

    @property
    def device(self):
        return self.model.device


class DescriptAudioModality(Modality):
    def __init__(
        self,
        model_name_or_path: str = dac.utils.download(model_type="16khz"),
        num_projector_conv_layers: int = 2,
        num_projector_mlp_layers: int = 2,
        num_tokens_output: int = 10,
        codebooks: int = 96,
        use_multi_task: MultiTaskType = MultiTaskType.NO_MULTI_TASK,
        tasks_config: str = None,
    ):
        self.model_name_or_path = model_name_or_path
        self.module = DescriptAudioModule(
            model_name_or_path=self.model_name_or_path, codebooks=codebooks
        )
        self.num_projector_conv_layers = num_projector_conv_layers
        self.num_projector_mlp_layers = num_projector_mlp_layers
        self.num_tokens_output = num_tokens_output
        self.dtype = torch.float32
        self.codebooks = codebooks
        self.use_multi_task = use_multi_task
        self.tasks = None
        if self.use_multi_task != MultiTaskType.NO_MULTI_TASK:
            with open(tasks_config, "r") as f:
                self.tasks = json.load(f)
            print("Tasks :", self.tasks)

    def build_projector(self, lm_hidden_size: int) -> nn.Module:
        if self.use_multi_task == MultiTaskType.PROJECTED_MULTI_TASK:
            projector = MultiTaskModel(OUTPUT_EMB_CHANNELS, 1, True, -1, False, self.tasks)
            print("projector ", projector)
            return projector
        elif self.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
            return build_mt_vector_projector(
                input_hidden_size=OUTPUT_EMB_CHANNELS,
                lm_hidden_size=lm_hidden_size,
                tasks=self.tasks,
            )
        else:
            # Fallback mirrors audio_descript_bu.py; the original referenced
            # attributes that are not defined on this class.
            return build_cnn_mlp_projector(
                input_channels=self.codebooks,
                input_size=OUTPUT_FRAMES_SIZE,
                lm_hidden_size=lm_hidden_size,
                num_tokens=self.num_tokens_output,
                hidden_dim=64,
                num_conv_layers=self.num_projector_conv_layers,
                num_mlp_layers=self.num_projector_mlp_layers,
            )

    @property
    def name(self) -> str:
        return "audio_descript"

    @property
    def token(self) -> str:
        return "<sound>"

    @property
    def data_key(self) -> str:
        return "sounds"

    @property
    def token_width(self) -> int:
        return self.num_tokens_output

    def to(self, dtype: torch.dtype, device: torch.device) -> "DescriptAudioModality":
        self.dtype = dtype
        self.module.to(device=device)
        return self

    def preprocess_rows(self, rows: List[Dict]) -> List[Optional[Dict]]:
        row_values = []
        for row in rows:
            audios = []
            for audio_dict in row[self.data_key]:
                audio_dict = load_audio_signal(audio_dict)
                audios.append(audio_dict["array"])
            row_values.append(audios)
        return row_values

    @torch.no_grad()
    def forward(self, encoded_values: List[torch.Tensor]) -> List[torch.Tensor]:
        audio_features = []
        for audio_batch in encoded_values:
            audio_features.append(self.module.forward(audio_batch).to(dtype=self.dtype))
        return audio_features
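The crop-or-pad step in `DescriptAudioModule.forward` is a fixed-length normalization along the time axis. A standalone sketch of that behavior with made-up tensor shapes:

import torch
import torch.nn.functional as F

OUTPUT_FRAMES_SIZE = 512

def fit_frames(latents: torch.Tensor) -> torch.Tensor:
    """Crop or right-pad the last (time) dimension to OUTPUT_FRAMES_SIZE."""
    if latents.shape[2] > OUTPUT_FRAMES_SIZE:
        return latents[:, :, :OUTPUT_FRAMES_SIZE]
    if latents.shape[2] < OUTPUT_FRAMES_SIZE:
        return F.pad(latents, (0, OUTPUT_FRAMES_SIZE - latents.shape[2]))
    return latents

print(fit_frames(torch.zeros(1, 96, 300)).shape)  # torch.Size([1, 96, 512])
print(fit_frames(torch.zeros(1, 96, 700)).shape)  # torch.Size([1, 96, 512])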
src/sonicverse/multi_token/modalities/audio_descript_bu.py
ADDED
@@ -0,0 +1,133 @@
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import dac
from audiotools import AudioSignal

from multi_token.data_tools import load_audio_signal
from multi_token.modalities.base_modality import Modality
from multi_token.modalities.projectors import (
    build_mlp_vector_projector,
    build_attentive_cnn_projector,
    build_cnn_mlp_projector,
)

OUTPUT_FRAMES_SIZE = 512


class DescriptAudioModule(nn.Module):
    def __init__(self, model_name_or_path: str, codebooks=4):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.model = None
        self.processor = None
        self.codebooks = codebooks

        self.load_model()

    def load_model(self):
        self.model = dac.DAC.load(self.model_name_or_path)

    def forward(self, audios) -> torch.Tensor:
        embs = []
        for audio_features in audios:
            x = self.model.preprocess(
                audio_features[0].audio_data, audio_features[0].sample_rate
            )
            z, codes, latents, _, _ = self.model.encode(x)

            # Crop or right-pad the codes along the time axis to a fixed frame count.
            if codes.shape[2] > OUTPUT_FRAMES_SIZE:
                codes = codes[:, :, :OUTPUT_FRAMES_SIZE]
            elif codes.shape[2] < OUTPUT_FRAMES_SIZE:
                pad_width = (0, OUTPUT_FRAMES_SIZE - codes.shape[2])
                codes = torch.nn.functional.pad(codes, pad_width)

            codes_of_interest = codes[0][: self.codebooks]
            embs.append(codes_of_interest)

        embs = torch.stack(embs)
        return embs

    @property
    def dtype(self):
        return self.model.dtype

    @property
    def device(self):
        return self.model.device


class DescriptAudioModality(Modality):
    def __init__(
        self,
        model_name_or_path: str = dac.utils.download(model_type="16khz"),
        num_projector_conv_layers: int = 2,
        num_projector_mlp_layers: int = 2,
        num_tokens_output: int = 10,
        codebooks: int = 4,
    ):
        self.model_name_or_path = model_name_or_path
        self.module = DescriptAudioModule(
            model_name_or_path=self.model_name_or_path, codebooks=codebooks
        )
        self.num_projector_conv_layers = num_projector_conv_layers
        self.num_projector_mlp_layers = num_projector_mlp_layers
        self.num_tokens_output = num_tokens_output
        self.dtype = torch.float32
        self.codebooks = codebooks

    def build_projector(self, lm_hidden_size: int) -> nn.Module:
        return build_cnn_mlp_projector(
            input_channels=self.codebooks,
            input_size=OUTPUT_FRAMES_SIZE,
            lm_hidden_size=lm_hidden_size,
            num_tokens=self.num_tokens_output,
            hidden_dim=64,
            num_conv_layers=self.num_projector_conv_layers,
            num_mlp_layers=self.num_projector_mlp_layers,
        )

    @property
    def name(self) -> str:
        return "audio_descript"

    @property
    def token(self) -> str:
        return "<sound>"

    @property
    def data_key(self) -> str:
        return "sounds"

    @property
    def token_width(self) -> int:
        return self.num_tokens_output

    def to(self, dtype: torch.dtype, device: torch.device) -> "DescriptAudioModality":
        self.dtype = dtype
        self.module.to(device=device)
        return self

    def preprocess_rows(self, rows: List[Dict]) -> List[Optional[Dict]]:
        row_values = []
        for row in rows:
            audios = []
            for audio_dict in row[self.data_key]:
                audio_dict = load_audio_signal(audio_dict)
                audios.append(audio_dict["array"])
            row_values.append(audios)
        return row_values

    @torch.no_grad()
    def forward(self, encoded_values: List[torch.Tensor]) -> List[torch.Tensor]:
        audio_features = []
        for audio_batch in encoded_values:
            audio_features.append(self.module.forward(audio_batch).to(dtype=self.dtype))
        return audio_features
src/sonicverse/multi_token/modalities/audio_mert.py
ADDED
@@ -0,0 +1,162 @@
from typing import Dict, List, Optional

import torch
import torch.nn as nn
from transformers import Wav2Vec2FeatureExtractor, AutoModel

from multi_token.model_utils import MultiTaskType
from multi_token.data_tools import load_audio
from multi_token.modalities.base_modality import Modality
from multi_token.modalities.projectors import (
    build_mlp_vector_projector,
    build_mt_vector_projector,
    build_multi_layer_cnn_mlp_projector,
    MultiTaskModel,
)
from multi_token.modalities.multi_task_projector_shared import MultiTaskSharedModel

import json

OUTPUT_EMB_CHANNELS = 768  # 1024 for MERT-v1-330M
OUTPUT_EMB_SIZE = 760
OUTPUT_FEATURE_LAYERS = 13  # 25 for MERT-v1-330M

cache_dir = "/home/ubuntu/.cache/"


class MERTAudioModule(nn.Module):
    def __init__(self, model_name_or_path: str):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.model = None
        self.processor = None

        self.load_model()

    def load_model(self):
        self.model = AutoModel.from_pretrained(
            self.model_name_or_path, trust_remote_code=True, cache_dir=cache_dir
        )
        self.processor = Wav2Vec2FeatureExtractor.from_pretrained(
            self.model_name_or_path, trust_remote_code=True, cache_dir=cache_dir
        )
        self.model.requires_grad_(False)

    @torch.no_grad()
    def forward(self, audios) -> torch.Tensor:
        embs = []
        for audio_features in audios:
            outputs = self.model(
                **audio_features.to(torch.float32), output_hidden_states=True
            )
            # Stack all hidden-state layers: [num_layers, time_steps, feature_dim]
            features = torch.stack(outputs.hidden_states).squeeze()
            embs.append(features)
        embs = torch.stack(embs)
        embs = embs.squeeze()
        # Pad the time dimension to a fixed length so projector inputs are uniform.
        padding_needed = OUTPUT_EMB_SIZE - embs.shape[1]
        embs = torch.nn.functional.pad(embs, (0, 0, 0, padding_needed, 0, 0))
        return embs

    @property
    def dtype(self):
        return self.model.dtype

    @property
    def device(self):
        return self.model.device


class MERTAudioModality(Modality):
    def __init__(
        self,
        model_name_or_path: str = "m-a-p/MERT-v1-95M",
        num_tokens_output: int = 10,
        hidden_dim: int = 32,
        num_conv_layers: int = 5,
        num_mlp_layers: int = 5,
        use_multi_task: MultiTaskType = MultiTaskType.NO_MULTI_TASK,
        tasks_config: str = None,
    ):
        self.model_name_or_path = model_name_or_path
        self.module = MERTAudioModule(model_name_or_path=self.model_name_or_path)
        self.num_tokens_output = num_tokens_output
        self.hidden_dim = hidden_dim
        self.num_conv_layers = num_conv_layers
        self.num_mlp_layers = num_mlp_layers
        self.dtype = torch.float32
        self.use_multi_task = use_multi_task
        self.tasks = None
        if self.use_multi_task != MultiTaskType.NO_MULTI_TASK:
            with open(tasks_config, "r") as f:
                self.tasks = json.load(f)
            print("Tasks :", self.tasks)

    def build_projector(self, lm_hidden_size: int) -> nn.Module:
        if self.use_multi_task == MultiTaskType.PROJECTED_MULTI_TASK:
            projector = MultiTaskSharedModel(self.tasks)
            print("projector ", projector)
            return projector
        elif self.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
            return build_mt_vector_projector(
                input_hidden_size=OUTPUT_EMB_SIZE,
                lm_hidden_size=lm_hidden_size,
                tasks=self.tasks,
            )
        else:
            return build_multi_layer_cnn_mlp_projector(
                input_channels=OUTPUT_EMB_CHANNELS,
                input_size=OUTPUT_EMB_SIZE,
                num_feature_layers=OUTPUT_FEATURE_LAYERS,
                lm_hidden_size=lm_hidden_size,
                num_tokens=self.num_tokens_output,
                hidden_dim=self.hidden_dim,
                num_conv_layers=self.num_conv_layers,
                num_mlp_layers=self.num_mlp_layers,
            )

    @property
    def name(self) -> str:
        return "audio_mert"

    @property
    def token(self) -> str:
        return "<sound>"

    @property
    def data_key(self) -> str:
        return "sounds"

    @property
    def token_width(self) -> int:
        return self.num_tokens_output

    def to(self, dtype: torch.dtype, device: torch.device) -> "MERTAudioModality":
        self.dtype = dtype
        self.module.to(device=device)
        return self

    def preprocess_rows(self, rows: List[Dict]) -> List[Optional[Dict]]:
        row_values = []
        for row in rows:
            audios = []
            for audio_dict in row[self.data_key]:
                audio_dict = load_audio(
                    audio_dict,
                    target_sampling_rate=self.module.processor.sampling_rate,
                )
                audio_processed = self.module.processor(
                    audio_dict["array"],
                    return_tensors="pt",
                    sampling_rate=audio_dict["sampling_rate"],
                )
                audios.append(audio_processed)
            row_values.append(audios)
        return row_values

    @torch.no_grad()
    def forward(self, encoded_values: List[torch.Tensor]) -> List[torch.Tensor]:
        audio_features = []
        for audio_batch in encoded_values:
            audio_features.append(self.module.forward(audio_batch).to(dtype=self.dtype))
        return audio_features
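The MERT module stacks every encoder layer's hidden states and pads the time axis so the projector always sees the same shape. A hedged sketch of that shape bookkeeping, with random tensors standing in for the encoder outputs of a single clip:

import torch
import torch.nn.functional as F

OUTPUT_EMB_SIZE = 760   # target length of the time axis
NUM_LAYERS = 13         # hidden-state layers for MERT-v1-95M
CHANNELS = 768          # feature dim for MERT-v1-95M

# Pretend these are outputs.hidden_states from the MERT encoder for one clip.
hidden_states = [torch.randn(1, 740, CHANNELS) for _ in range(NUM_LAYERS)]

features = torch.stack(hidden_states).squeeze()        # [13, 740, 768]
padding_needed = OUTPUT_EMB_SIZE - features.shape[1]
features = F.pad(features, (0, 0, 0, padding_needed))  # pad time axis -> [13, 760, 768]
print(features.shape)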
src/sonicverse/multi_token/modalities/audio_mert_bu.py
ADDED
@@ -0,0 +1,159 @@
from typing import Dict, List, Optional

import torch
import torch.nn as nn
from transformers import Wav2Vec2FeatureExtractor, AutoModel

from multi_token.model_utils import MultiTaskType
from multi_token.data_tools import load_audio
from multi_token.modalities.base_modality import Modality
from multi_token.modalities.projectors import (
    build_mlp_vector_projector,
    build_mt_vector_projector,
    build_multi_layer_cnn_mlp_projector,
    MultiTaskModel,
)

import json

OUTPUT_EMB_CHANNELS = 1024
OUTPUT_EMB_SIZE = 760
OUTPUT_FEATURE_LAYERS = 25


class MERTAudioModule(nn.Module):
    def __init__(self, model_name_or_path: str):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.model = None
        self.processor = None

        self.load_model()

    def load_model(self):
        self.model = AutoModel.from_pretrained(
            self.model_name_or_path, trust_remote_code=True
        )
        self.processor = Wav2Vec2FeatureExtractor.from_pretrained(
            self.model_name_or_path, trust_remote_code=True
        )
        self.model.requires_grad_(False)

    @torch.no_grad()
    def forward(self, audios) -> torch.Tensor:
        embs = []
        for audio_features in audios:
            outputs = self.model(
                **audio_features.to(torch.float32), output_hidden_states=True
            )
            features = torch.stack(outputs.hidden_states).squeeze()
            embs.append(features)
        embs = torch.stack(embs)
        embs = embs.squeeze()
        padding_needed = OUTPUT_EMB_SIZE - embs.shape[1]
        embs = torch.nn.functional.pad(embs, (0, 0, 0, padding_needed, 0, 0))
        return embs

    @property
    def dtype(self):
        return self.model.dtype

    @property
    def device(self):
        return self.model.device


class MERTAudioModality(Modality):
    def __init__(
        self,
        model_name_or_path: str = "m-a-p/MERT-v1-330M",
        num_tokens_output: int = 10,
        hidden_dim: int = 32,
        num_conv_layers: int = 5,
        num_mlp_layers: int = 5,
        use_multi_task: MultiTaskType = MultiTaskType.NO_MULTI_TASK,
        tasks_config: str = None,
    ):
        self.model_name_or_path = model_name_or_path
        self.module = MERTAudioModule(model_name_or_path=self.model_name_or_path)
        self.num_tokens_output = num_tokens_output
        self.hidden_dim = hidden_dim
        self.num_conv_layers = num_conv_layers
        self.num_mlp_layers = num_mlp_layers
        self.dtype = torch.float32
        self.use_multi_task = use_multi_task
        self.tasks = None
        if self.use_multi_task != MultiTaskType.NO_MULTI_TASK:
            with open(tasks_config, "r") as f:
                self.tasks = json.load(f)
            print("Tasks :", self.tasks)

    def build_projector(self, lm_hidden_size: int) -> nn.Module:
        if self.use_multi_task == MultiTaskType.PROJECTED_MULTI_TASK:
            projector = MultiTaskModel(OUTPUT_EMB_CHANNELS, OUTPUT_FEATURE_LAYERS, True, self.tasks)
            print("projector ", projector)
            return projector
        elif self.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
            return build_mt_vector_projector(
                input_hidden_size=OUTPUT_EMB_SIZE,
                lm_hidden_size=lm_hidden_size,
                tasks=self.tasks,
            )
        else:
            return build_multi_layer_cnn_mlp_projector(
                input_channels=OUTPUT_EMB_CHANNELS,
                input_size=OUTPUT_EMB_SIZE,
                num_feature_layers=OUTPUT_FEATURE_LAYERS,
                lm_hidden_size=lm_hidden_size,
                num_tokens=self.num_tokens_output,
                hidden_dim=self.hidden_dim,
                num_conv_layers=self.num_conv_layers,
                num_mlp_layers=self.num_mlp_layers,
            )

    @property
    def name(self) -> str:
        return "audio_mert"

    @property
    def token(self) -> str:
        return "<sound>"

    @property
    def data_key(self) -> str:
        return "sounds"

    @property
    def token_width(self) -> int:
        return self.num_tokens_output

    def to(self, dtype: torch.dtype, device: torch.device) -> "MERTAudioModality":
        self.dtype = dtype
        self.module.to(device=device)
        return self

    def preprocess_rows(self, rows: List[Dict]) -> List[Optional[Dict]]:
        row_values = []
        for row in rows:
            audios = []
            for audio_dict in row[self.data_key]:
                audio_dict = load_audio(
                    audio_dict,
                    target_sampling_rate=self.module.processor.sampling_rate,
                )
                audio_processed = self.module.processor(
                    audio_dict["array"],
                    return_tensors="pt",
                    sampling_rate=audio_dict["sampling_rate"],
                )
                audios.append(audio_processed)
            row_values.append(audios)
        return row_values

    @torch.no_grad()
    def forward(self, encoded_values: List[torch.Tensor]) -> List[torch.Tensor]:
        audio_features = []
        for audio_batch in encoded_values:
            audio_features.append(self.module.forward(audio_batch).to(dtype=self.dtype))
        return audio_features
src/sonicverse/multi_token/modalities/audio_whisper.py
ADDED
@@ -0,0 +1,120 @@
from typing import Dict, List, Optional

import torch
import torch.nn as nn
from transformers import AutoFeatureExtractor, WhisperModel

from multi_token.data_tools import load_audio
from multi_token.modalities.base_modality import Modality
from multi_token.modalities.projectors import (
    build_mlp_vector_projector,
)


OUTPUT_EMB_SIZE = 768


class WhisperAudioModule(nn.Module):
    def __init__(self, model_name_or_path: str):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.model = None
        self.feature_extractor = None

        self.load_model()

    def load_model(self):
        self.model = WhisperModel.from_pretrained(self.model_name_or_path)
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
            self.model_name_or_path
        )
        self.model.requires_grad_(False)

    @torch.no_grad()
    def forward(self, audios) -> torch.Tensor:
        hidden_states = []
        for i in range(audios.shape[0]):
            decoder_input_ids = (
                torch.tensor([[1]]) * self.model.config.decoder_start_token_id
            )
            last_hidden_state = self.model(
                audios[i].to(device=self.device, dtype=self.dtype),
                decoder_input_ids=decoder_input_ids.to(device=self.device),
            ).last_hidden_state
            hidden_states.append(last_hidden_state)
        last_hidden_state = torch.stack(hidden_states)
        return last_hidden_state.view(-1, 1, OUTPUT_EMB_SIZE)

    @property
    def dtype(self):
        return self.model.dtype

    @property
    def device(self):
        return self.model.device


class WhisperAudioModality(Modality):
    def __init__(
        self,
        model_name_or_path: str = "openai/whisper-small",
        num_projector_layers: int = 2,
        num_tokens_output: int = 10,
    ):
        self.model_name_or_path = model_name_or_path
        self.module = WhisperAudioModule(model_name_or_path=self.model_name_or_path)
        self.num_projector_layers = num_projector_layers
        self.num_tokens_output = num_tokens_output

    def build_projector(self, lm_hidden_size: int) -> nn.Module:
        return build_mlp_vector_projector(
            input_hidden_size=OUTPUT_EMB_SIZE,
            lm_hidden_size=lm_hidden_size,
            num_layers=self.num_projector_layers,
            num_tokens=self.num_tokens_output,
        )

    @property
    def name(self) -> str:
        return "audio_whisper"

    @property
    def token(self) -> str:
        return "<speech>"

    @property
    def data_key(self) -> str:
        return "speech_audios"

    @property
    def token_width(self) -> int:
        return self.num_tokens_output

    def to(self, dtype: torch.dtype, device: torch.device) -> "WhisperAudioModality":
        self.module.to(dtype=dtype, device=device)
        return self

    def preprocess_rows(self, rows: List[Dict]) -> List[Optional[torch.Tensor]]:
        row_values = []
        for row in rows:
            audios = []
            for audio_dict in row[self.data_key]:
                audio_dict = load_audio(
                    audio_dict,
                    target_sampling_rate=self.module.feature_extractor.sampling_rate,
                )
                audio_processed = self.module.feature_extractor(
                    audio_dict["array"],
                    return_tensors="pt",
                    sampling_rate=audio_dict["sampling_rate"],
                ).input_features
                audios.append(audio_processed)
            row_values.append(torch.stack(audios) if len(audios) > 0 else None)
        return row_values

    @torch.no_grad()
    def forward(self, encoded_values: List[torch.Tensor]) -> List[torch.Tensor]:
        audio_features = []
        for audio_batch in encoded_values:
            audio_features.append(self.module.forward(audio_batch))
        return audio_features
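Because only a compact speech representation is wanted, `WhisperAudioModule.forward` runs the full `WhisperModel` with a single `decoder_start_token_id` as decoder input and keeps the resulting hidden state. A minimal sketch of that trick on dummy log-mel features (shapes assume `openai/whisper-small`):

import torch
from transformers import WhisperModel

model = WhisperModel.from_pretrained("openai/whisper-small")

# Dummy log-mel features: (batch, n_mels, frames) = (1, 80, 3000).
input_features = torch.zeros(1, 80, 3000)
decoder_input_ids = torch.tensor([[1]]) * model.config.decoder_start_token_id

with torch.no_grad():
    out = model(input_features, decoder_input_ids=decoder_input_ids)
# One decoder step attending over the encoder output -> (1, 1, 768),
# which the module reshapes to (-1, 1, OUTPUT_EMB_SIZE).
print(out.last_hidden_state.shape)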
src/sonicverse/multi_token/modalities/base_modality.py
ADDED
@@ -0,0 +1,48 @@
from typing import Dict, List, Optional, Any
from abc import ABC, abstractmethod
from functools import cached_property

import torch.nn as nn
import torch


class Modality(ABC):
    @abstractmethod
    def build_projector(self, lm_hidden_size: int) -> nn.Module:
        pass

    @property
    @abstractmethod
    def name(self) -> str:
        pass

    @property
    @abstractmethod
    def token(self) -> str:
        pass

    @property
    @abstractmethod
    def data_key(self) -> str:
        pass

    @property
    @abstractmethod
    def token_width(self) -> int:
        pass

    @cached_property
    def token_idx(self) -> int:
        hash_ = sum(ord(c) ** i for i, c in enumerate(self.token))
        return -abs(hash_ % 10_000)

    @abstractmethod
    def preprocess_rows(self, rows: List[Dict]) -> List[Optional[Any]]:
        pass

    @abstractmethod
    def forward(self, encoded_values: List[Any]) -> List[torch.Tensor]:
        pass

    def to(self, dtype: torch.dtype, device: torch.device) -> "Modality":
        return self
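The `Modality` ABC is the contract every encoder wrapper above implements: a projector factory, a few identifying properties, row preprocessing, and a batched `forward`. A minimal hypothetical subclass (names and sizes invented) showing the required surface:

from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn

from multi_token.modalities.base_modality import Modality


class DummyModality(Modality):
    """Toy modality that embeds each item as a fixed-size zero vector."""

    def build_projector(self, lm_hidden_size: int) -> nn.Module:
        return nn.Linear(16, lm_hidden_size)

    @property
    def name(self) -> str:
        return "dummy"

    @property
    def token(self) -> str:
        return "<dummy>"

    @property
    def data_key(self) -> str:
        return "dummies"

    @property
    def token_width(self) -> int:
        return 1

    def preprocess_rows(self, rows: List[Dict]) -> List[Optional[Any]]:
        return [row.get(self.data_key) for row in rows]

    def forward(self, encoded_values: List[Any]) -> List[torch.Tensor]:
        return [torch.zeros(len(v), 1, 16) for v in encoded_values]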
src/sonicverse/multi_token/modalities/bu__init__.py
ADDED
@@ -0,0 +1,31 @@
from multi_token.model_utils import MultiTaskType
from multi_token.modalities.vision_clip import (
    CLIPVisionModality,
    OUTPUT_LAYER as CLIP_POOL_LAYER,
)
from multi_token.modalities.imagebind import ImageBindModality
from multi_token.modalities.document_gte import DocumentGTEModality
from multi_token.modalities.audio_whisper import WhisperAudioModality
from multi_token.modalities.audio_clap import CLAPAudioModality
from multi_token.modalities.video_xclip import XCLIPVideoModality
from multi_token.modalities.audio_descript import DescriptAudioModality
from multi_token.modalities.audio_mert import MERTAudioModality

MODALITY_BUILDERS = {
    "vision_clip": lambda: [CLIPVisionModality()],
    "vision_clip_pool": lambda: [
        CLIPVisionModality(feature_layer=CLIP_POOL_LAYER, num_tokens_output=10)
    ],
    "audio_whisper": lambda: [
        WhisperAudioModality(
            num_tokens_output=10, model_name_or_path="openai/whisper-small"
        )
    ],
    "audio_mert": lambda use_multi_task=MultiTaskType.NO_MULTI_TASK, tasks_config=None: [
        MERTAudioModality(
            use_multi_task=use_multi_task,
            tasks_config=tasks_config,
            num_tokens_output=60,
            hidden_dim=32,
            num_conv_layers=3,
            num_mlp_layers=2,
        )
    ],
    "audio_clap": lambda use_multi_task=MultiTaskType.NO_MULTI_TASK, tasks_config=None: [
        CLAPAudioModality(
            use_multi_task=use_multi_task,
            tasks_config=tasks_config,
            num_tokens_output=20,
        )
    ],
    "audio_descript": lambda: [
        DescriptAudioModality(
            num_projector_conv_layers=1,
            num_projector_mlp_layers=1,
            num_tokens_output=5,
            codebooks=12,
        )
    ],
    "video_xclip": lambda: [XCLIPVideoModality(num_tokens_output=10)],
    "imagebind": lambda: [ImageBindModality()],
    "document_gte": lambda: [DocumentGTEModality()],
    "document_gte_x16": lambda: [DocumentGTEModality(num_tokens_output=32)],
}
src/sonicverse/multi_token/modalities/document_gte.py
ADDED
@@ -0,0 +1,144 @@
from typing import Dict, List

import torch
import torch.nn as nn
import os
from functools import cache
from transformers import AutoTokenizer, AutoModel

from multi_token.modalities.base_modality import Modality
from multi_token.modalities.projectors import build_mlp_vector_projector

GTE_EMBEDDING_SIZE = 1024
GTE_CONTEXT_WINDOW = 512
GTE_DEFAULT_MODEL = "thenlper/gte-large"
DOCUMENT_GTE_FORCE_CPU = "DOCUMENT_GTE_FORCE_CPU"


def average_pool(
    last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


@cache
def _get_tokenizer(model_name_or_path: str = GTE_DEFAULT_MODEL):
    return AutoTokenizer.from_pretrained(model_name_or_path)


def split_text_into_documents(text: str) -> List[str]:
    from nltk.tokenize import sent_tokenize

    tokenizer = _get_tokenizer(GTE_DEFAULT_MODEL)

    sentences = sent_tokenize(text)
    documents = [[]]

    for sentence in sentences:
        sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
        if len(documents[-1]) + len(sentence_tokens) > GTE_CONTEXT_WINDOW:
            documents.append([])
        documents[-1].extend(sentence_tokens)

    return [tokenizer.decode(doc) for doc in documents]


class DocumentGTEModule(nn.Module):
    def __init__(self, model_name_or_path: str):
        super().__init__()
        self.feature_layer = -2
        self.model_name_or_path = model_name_or_path

        self.model = AutoModel.from_pretrained("thenlper/gte-large")
        self.model.requires_grad_(False)

    @torch.no_grad()
    def forward(self, batch_dict) -> torch.Tensor:
        outputs = self.model(**batch_dict)
        embeddings = average_pool(
            outputs.last_hidden_state, batch_dict["attention_mask"]
        )
        return embeddings

    @property
    def embedding_size(self):
        return GTE_EMBEDDING_SIZE


class DocumentGTEModality(Modality):
    def __init__(
        self,
        model_name_or_path: str = GTE_DEFAULT_MODEL,
        num_projector_layers: int = 2,
        num_tokens_output: int = 4,
    ):
        self.model_name_or_path = model_name_or_path
        self.module = DocumentGTEModule(model_name_or_path=self.model_name_or_path)
        self.tokenizer = _get_tokenizer(model_name_or_path)
        self.num_projector_layers = num_projector_layers
        self.num_tokens_output = num_tokens_output
        self.dtype = torch.float32
        self.device = "cpu"
        self.document_gte_device = "cpu"

    def build_projector(self, lm_hidden_size: int) -> nn.Module:
        return build_mlp_vector_projector(
            input_hidden_size=self.module.embedding_size,
            lm_hidden_size=lm_hidden_size,
            num_layers=self.num_projector_layers,
            num_tokens=self.num_tokens_output,
        )

    @property
    def name(self) -> str:
        return "document_gte"

    @property
    def token(self) -> str:
        return "<document>"

    @property
    def data_key(self) -> str:
        return "documents"

    @property
    def token_width(self) -> int:
        return self.num_tokens_output

    def to(self, dtype: torch.dtype, device: torch.device) -> "DocumentGTEModality":
        self.dtype = dtype
        self.device = device
        if DOCUMENT_GTE_FORCE_CPU not in os.environ:
            # running out of VRAM on 24GB GPU
            self.document_gte_device = device
        self.module.to(device=self.document_gte_device)
        return self

    def preprocess_rows(self, rows: List[Dict]) -> List[Dict]:
        row_values = []
        for row in rows:
            documents = []
            for doc in row[self.data_key]:
                documents.append(doc)
            documents_tokenized = self.tokenizer(
                documents,
                max_length=GTE_CONTEXT_WINDOW,
                padding=True,
                truncation=True,
                return_tensors="pt",
            )
            row_values.append(documents_tokenized)
        return row_values

    @torch.no_grad()
    def forward(self, encoded_values: List[Dict]) -> List[torch.Tensor]:
        outputs = []
        for val in encoded_values:
            outputs.append(
                self.module.forward(val.to(device=self.document_gte_device))
                .to(device=self.device, dtype=self.dtype)
                .view(-1, 1, self.module.embedding_size)
            )
        # batch_size x num_items x 1 x embedding_size
        return outputs
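`split_text_into_documents` greedily packs whole sentences into chunks that stay within the 512-token GTE context window. A small usage sketch (requires nltk's "punkt" data; the sample text is made up):

import nltk

from multi_token.modalities.document_gte import split_text_into_documents

nltk.download("punkt")  # sentence tokenizer used internally by sent_tokenize

text = "First sentence of the document. Second sentence of the document. " * 200
chunks = split_text_into_documents(text)
print(len(chunks))                          # > 1 once the 512-token window is exceeded
print(all(isinstance(c, str) for c in chunks))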
src/sonicverse/multi_token/modalities/imagebind.py
ADDED
@@ -0,0 +1,153 @@
from typing import Dict, List
import os

import torch
import torch.nn as nn

from multi_token.modalities.base_modality import Modality
from multi_token.modalities.projectors import build_mlp_vector_projector
from multi_token.data_tools import with_local_files

IMAGE_BIND_FORCE_CPU = "IMAGE_BIND_FORCE_CPU"
IMAGE_BIND_EMBEDDING_SIZE = 1024


class ImageBindModule(nn.Module):
    def __init__(self):
        super().__init__()
        from imagebind.models import imagebind_model
        from imagebind import data

        data.BPE_PATH = os.path.join(
            os.path.dirname(data.__file__), "..", "bpe", "bpe_simple_vocab_16e6.txt.gz"
        )
        self.model = imagebind_model.imagebind_huge(pretrained=True)
        self.model.eval()
        self.model.requires_grad_(False)

    @torch.no_grad()
    def forward(self, items: Dict) -> torch.Tensor:
        forward_outs = self.model(items)
        return forward_outs

    @property
    def embedding_size(self):
        return IMAGE_BIND_EMBEDDING_SIZE


class ImageBindModality(Modality):
    def __init__(
        self,
        num_projector_layers: int = 2,
        num_tokens: int = 4,
        preprocess_device: str = "cpu",
    ):
        self.module = ImageBindModule()
        self.dtype = torch.float32
        self.device = "cpu"  # used for outputs
        self.imagebind_device = "cpu"  # used for imagebind model itself
        self.preprocess_device = preprocess_device  # used for preprocessing
        self.num_projector_layers = num_projector_layers
        self.num_tokens = num_tokens

    def build_projector(self, lm_hidden_size: int) -> nn.Module:
        return build_mlp_vector_projector(
            self.module.embedding_size,
            lm_hidden_size,
            num_layers=self.num_projector_layers,
            num_tokens=self.num_tokens,
        )

    @property
    def name(self) -> str:
        return "imagebind"

    @property
    def token(self) -> str:
        return "<imagebind>"

    @property
    def data_key(self) -> str:
        return "imagebinds"

    @property
    def token_width(self) -> int:
        return self.num_tokens

    def to(self, dtype: torch.dtype, device: torch.device) -> "ImageBindModality":
        # we ignore dtype and sometimes device as well
        self.device = device
        self.dtype = dtype
        if IMAGE_BIND_FORCE_CPU not in os.environ:
            # running out of VRAM on 24GB GPU
            self.module.to(device=device)
            self.imagebind_device = device
        return self

    def preprocess_rows(self, rows: List[Dict]) -> List[List[Dict]]:
        from imagebind.models.imagebind_model import ModalityType
        from imagebind import data

        row_values = []
        for row in rows:
            items = []
            with with_local_files(row[self.data_key]) as item_paths:
                for item_path in item_paths:
                    ib_modality = filename_to_imagebind_modality(item_path)
                    if ib_modality == ModalityType.TEXT:
                        items.append(
                            {
                                ModalityType.TEXT: data.load_and_transform_text(
                                    [item_path], self.preprocess_device
                                )
                            }
                        )
                    elif ib_modality == ModalityType.VISION:
                        items.append(
                            {
                                ModalityType.VISION: data.load_and_transform_vision_data(
                                    [item_path], self.preprocess_device
                                )
                            }
                        )
                    elif ib_modality == ModalityType.AUDIO:
                        items.append(
                            {
                                ModalityType.AUDIO: data.load_and_transform_audio_data(
                                    [item_path], self.preprocess_device
                                )
                            }
                        )
                    else:
                        raise ValueError(f"Unknown modality type: {ib_modality}")
            row_values.append(items)
        return row_values

    @torch.no_grad()
    def forward(self, encoded_values: List[List[Dict]]) -> List[torch.Tensor]:
        item_features = []
        for item_batch in encoded_values:
            item_batch_emb = []
            for item in item_batch:
                item = {
                    k: v.to(device=self.imagebind_device, dtype=torch.float32)
                    for k, v in item.items()
                }
                item_batch_emb.extend(list(self.module.forward(item).values()))
            item_features.append(
                torch.stack(item_batch_emb).to(device=self.device, dtype=self.dtype)
            )
        # batch_size x num_items x 1 x embedding_size
        return item_features


def filename_to_imagebind_modality(fn: str) -> str:
    from imagebind.models.imagebind_model import ModalityType

    _, ext = os.path.splitext(fn)
    if ext in {".wav"}:
        return ModalityType.AUDIO
    elif ext in {".jpg", ".png", ".jpeg"}:
        return ModalityType.VISION
    else:
        return ModalityType.TEXT
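Routing in `preprocess_rows` is driven purely by file extension via `filename_to_imagebind_modality`; anything that is not a known audio or image extension falls back to text. For example:

from multi_token.modalities.imagebind import filename_to_imagebind_modality

print(filename_to_imagebind_modality("clip.wav"))    # ModalityType.AUDIO
print(filename_to_imagebind_modality("cover.png"))   # ModalityType.VISION
print(filename_to_imagebind_modality("notes.txt"))   # ModalityType.TEXT (fallback)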
src/sonicverse/multi_token/modalities/multi_task_projector_shared.py
ADDED
@@ -0,0 +1,321 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.autograd import Variable
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from typing import Dict
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
class CNN(nn.Module):
|
9 |
+
def __init__(self, input_channels = 25, num_class=15):
|
10 |
+
super(CNN, self).__init__()
|
11 |
+
self.aggregator = nn.Parameter(torch.randn((input_channels, 1,1), dtype=torch.float))
|
12 |
+
self.input_channels = input_channels
|
13 |
+
|
14 |
+
# init bn
|
15 |
+
self.bn_init = nn.BatchNorm2d(1)
|
16 |
+
|
17 |
+
# layer 1
|
18 |
+
self.conv_1 = nn.Conv2d(1, 64, 3, padding=1)
|
19 |
+
self.bn_1 = nn.BatchNorm2d(64)
|
20 |
+
self.mp_1 = nn.MaxPool2d((2, 4))
|
21 |
+
|
22 |
+
# layer 2
|
23 |
+
self.conv_2 = nn.Conv2d(64, 128, 3, padding=1)
|
24 |
+
self.bn_2 = nn.BatchNorm2d(128)
|
25 |
+
self.mp_2 = nn.MaxPool2d((2, 4))
|
26 |
+
|
27 |
+
# layer 3
|
28 |
+
self.conv_3 = nn.Conv2d(128, 128, 3, padding=1)
|
29 |
+
self.bn_3 = nn.BatchNorm2d(128)
|
30 |
+
self.mp_3 = nn.MaxPool2d((2, 4))
|
31 |
+
|
32 |
+
# layer 4
|
33 |
+
self.conv_4 = nn.Conv2d(128, 128, 3, padding=1)
|
34 |
+
self.bn_4 = nn.BatchNorm2d(128)
|
35 |
+
self.mp_4 = nn.MaxPool2d((3, 5))
|
36 |
+
|
37 |
+
# layer 5
|
38 |
+
self.conv_5 = nn.Conv2d(128, 64, 3, padding=1)
|
39 |
+
self.bn_5 = nn.BatchNorm2d(64)
|
40 |
+
self.mp_5 = nn.MaxPool2d((3, 3))
|
41 |
+
|
42 |
+
# classifier
|
43 |
+
self.dense = nn.Linear(640, num_class)
|
44 |
+
self.dropout = nn.Dropout(0.5)
|
45 |
+
|
46 |
+
def forward(self, x):
|
47 |
+
aggregator_weights = F.softmax(self.aggregator)
|
48 |
+
# aggregator_weights = aggregator_weights.view(self.input_channels, 1)
|
49 |
+
# print("0 x shape : ")
|
50 |
+
x = (x * aggregator_weights).sum(dim=0)
|
51 |
+
|
52 |
+
# print("aggregator_output shape ", x.shape)
|
53 |
+
|
54 |
+
x = x.unsqueeze(0).unsqueeze(0)
|
55 |
+
|
56 |
+
# print("1 x shape ", x.shape)
|
57 |
+
# init bn
|
58 |
+
x = self.bn_init(x)
|
59 |
+
# print("2 x shape ", x.shape)
|
60 |
+
|
61 |
+
# layer 1
|
62 |
+
x = self.mp_1(nn.ELU()(self.bn_1(self.conv_1(x))))
|
63 |
+
# print("3 x shape ", x.shape)
|
64 |
+
|
65 |
+
# layer 2
|
66 |
+
x = self.mp_2(nn.ELU()(self.bn_2(self.conv_2(x))))
|
67 |
+
# print("4 x shape ", x.shape)
|
68 |
+
|
69 |
+
# layer 3
|
70 |
+
x = self.mp_3(nn.ELU()(self.bn_3(self.conv_3(x))))
|
71 |
+
# print("5 x shape ", x.shape)
|
72 |
+
|
73 |
+
# layer 4
|
74 |
+
x = self.mp_4(nn.ELU()(self.bn_4(self.conv_4(x))))
|
75 |
+
# print("6 x shape ", x.shape)
|
76 |
+
|
77 |
+
# layer 5
|
78 |
+
x = self.mp_5(nn.ELU()(self.bn_5(self.conv_5(x))))
|
79 |
+
# print("7 x shape ", x.shape)
|
80 |
+
|
81 |
+
# classifier
|
82 |
+
x = x.view(x.size(0), -1)
|
83 |
+
# print("8 x shape ", x.shape)
|
84 |
+
x = self.dropout(x)
|
85 |
+
# print("9 x shape ", x.shape)
|
86 |
+
logit = nn.Sigmoid()(self.dense(x))
|
87 |
+
# print("logit shape ", logit.shape)
|
88 |
+
|
89 |
+
return logit
|
90 |
+
|
91 |
+
|
92 |
+
class MLP(nn.Module):
|
93 |
+
def __init__(self, input_channels=25, num_class=15):
|
94 |
+
super(MLP, self).__init__()
|
95 |
+
self.aggregator = nn.Parameter(torch.randn((input_channels, 1,1), dtype=torch.float))
|
96 |
+
self.input_channels = input_channels
|
97 |
+
|
98 |
+
self.hidden_layer_1 = nn.Linear(768, 512)
|
99 |
+
self.output = nn.Linear(512, num_class)
|
100 |
+
self.dropout = nn.Dropout(p=0.2)
|
101 |
+
self.loss = self.get_loss() # can return a dict of losses
|
102 |
+
|
103 |
+
def forward(self, x):
|
104 |
+
"""
|
105 |
+
x: (B, L, T, H)
|
106 |
+
T=#chunks, can be 1 or several chunks
|
107 |
+
"""
|
108 |
+
|
109 |
+
weights = F.softmax(self.aggregator, dim=1)
|
110 |
+
x = (x * weights).sum(dim=1)
|
111 |
+
|
112 |
+
x = x.mean(-2)
|
113 |
+
|
114 |
+
x = self.hidden_layer_1(x)
|
115 |
+
x = F.relu(x)
|
116 |
+
x = self.dropout(x)
|
117 |
+
|
118 |
+
return self.output(x)
|
119 |
+
|
120 |
+
def get_loss(self):
|
121 |
+
return nn.BCEWithLogitsLoss()
|
122 |
+
|
123 |
+
class MLPBackbone(nn.Module):
|
124 |
+
def __init__(self, input_features=768, hidden_dim=512):
|
125 |
+
super(MLPBackbone, self).__init__()
|
126 |
+
|
127 |
+
self.hidden_layer_1 = nn.Linear(input_features, hidden_dim)
|
128 |
+
self.dropout = nn.Dropout(p=0.2)
|
129 |
+
self.loss = self.get_loss() # can return a dict of losses
|
130 |
+
|
131 |
+
def forward(self, x):
|
132 |
+
"""
|
133 |
+
x: (B, L, T, H)
|
134 |
+
T=#chunks, can be 1 or several chunks
|
135 |
+
"""
|
136 |
+
|
137 |
+
x = self.hidden_layer_1(x)
|
138 |
+
x = F.relu(x)
|
139 |
+
x = self.dropout(x)
|
140 |
+
|
141 |
+
return x
|
142 |
+
|
143 |
+
def get_loss(self):
|
144 |
+
return nn.BCEWithLogitsLoss()
|
145 |
+
|
146 |
+
class MLPShared(nn.Module):
|
147 |
+
def __init__(self, input_channels=25, num_class=15):
|
148 |
+
super(MLPShared, self).__init__()
|
149 |
+
self.aggregator = nn.Parameter(torch.randn((input_channels, 1,1), dtype=torch.float))
|
150 |
+
self.input_channels = input_channels
|
151 |
+
|
152 |
+
self.hidden_layer_1 = nn.Linear(512, 256)
|
153 |
+
self.output = nn.Linear(256, num_class)
|
154 |
+
self.dropout = nn.Dropout(p=0.2)
|
155 |
+
self.loss = self.get_loss() # can return a dict of losses
|
156 |
+
|
157 |
+
def forward(self, x):
|
158 |
+
"""
|
159 |
+
x: (B, L, T, H)
|
160 |
+
T=#chunks, can be 1 or several chunks
|
161 |
+
"""
|
162 |
+
|
163 |
+
weights = F.softmax(self.aggregator, dim=1)
|
164 |
+
x = (x * weights).sum(dim=1)
|
165 |
+
|
166 |
+
x = x.mean(-2)
|
167 |
+
|
168 |
+
x = self.hidden_layer_1(x)
|
169 |
+
x = F.relu(x)
|
170 |
+
x = self.dropout(x)
|
171 |
+
|
172 |
+
return self.output(x)
|
173 |
+
|
174 |
+
def get_loss(self):
|
175 |
+
return nn.BCEWithLogitsLoss()
|
176 |
+
|
177 |
+
class MLPAggTaskHead(nn.Module):
|
178 |
+
def __init__(self, input_channels: int, input_size: int, output_size: int, use_aggregator: bool, use_time_average: bool, use_sigmoid: bool, use_transpose: bool, num_layers: int, hidden_dim: int, width: int):
|
179 |
+
super(MLPAggTaskHead, self).__init__()
|
180 |
+
if use_aggregator:
|
181 |
+
self.aggregator = nn.Parameter(torch.randn((input_channels), dtype=torch.float))
|
182 |
+
self.use_aggregator = use_aggregator
|
183 |
+
self.use_time_average = use_time_average
|
184 |
+
self.use_transpose = use_transpose
|
185 |
+
self.use_sigmoid = use_sigmoid
|
186 |
+
self.input_channels = input_channels
|
187 |
+
self.output_size = output_size
|
188 |
+
self.width = width
|
189 |
+
|
190 |
+
if self.width > 1:
|
191 |
+
self.layers = nn.ModuleList()
|
192 |
+
for i in range(self.width):
|
193 |
+
mlp_layers = [nn.GELU()]
|
194 |
+
mlp_layers += self._create_mlp_layers(input_size, output_size, num_layers, hidden_dim)
|
195 |
+
if self.use_sigmoid: mlp_layers += [nn.Sigmoid()]
|
196 |
+
self.layers.append(nn.Sequential(*mlp_layers))
|
197 |
+
else:
|
198 |
+
mlp_layers = [nn.GELU()]
|
199 |
+
mlp_layers += self._create_mlp_layers(input_size, output_size, num_layers, hidden_dim)
|
200 |
+
if self.use_sigmoid: mlp_layers += [nn.Sigmoid()]
|
201 |
+
self.layers = nn.Sequential(*mlp_layers)
|
202 |
+
|
203 |
+
def _create_mlp_layers(self, input_size, output_size, num_layers, hidden_dim):
|
204 |
+
if num_layers >=2:
|
205 |
+
layers = [nn.Linear(input_size, hidden_dim)]
|
206 |
+
layers.append(nn.GELU())
|
207 |
+
if num_layers > 2:
|
208 |
+
for _ in range(1, num_layers - 2):
|
209 |
+
layers += [
|
210 |
+
nn.Linear(hidden_dim, hidden_dim),
|
211 |
+
nn.GELU()
|
212 |
+
]
|
213 |
+
layers.append(nn.Linear(hidden_dim, output_size))
|
214 |
+
else:
|
215 |
+
layers = [nn.Linear(input_size, output_size)]
|
216 |
+
return layers
|
217 |
+
|
218 |
+
|
219 |
+
def forward(self, x):
|
220 |
+
if self.use_transpose:
|
221 |
+
x = x.transpose(1, 0)
|
222 |
+
if self.use_time_average:
|
223 |
+
x = x.mean(-2)
|
224 |
+
if self.use_aggregator:
|
225 |
+
aggregator_weights = F.softmax(self.aggregator)
|
226 |
+
aggregator_weights = aggregator_weights.view(self.input_channels, 1)
|
227 |
+
aggregator_output = (x * aggregator_weights).sum(dim=0)
|
228 |
+
aggregator_output = aggregator_output.unsqueeze(dim=0)
|
229 |
+
# print("Agg output ", aggregator_output.shape)
|
230 |
+
else:
|
231 |
+
aggregator_output = x
|
232 |
+
|
233 |
+
if self.width > 1:
|
234 |
+
if (self.input_channels < 1):
|
235 |
+
return torch.cat([layer(aggregator_output.unsqueeze(dim=0)) for layer in self.layers], dim=-2)
|
236 |
+
else:
|
237 |
+
return torch.cat([layer(aggregator_output.unsqueeze(dim=0)).squeeze(dim=0) for layer in self.layers], dim=-2)
|
238 |
+
else:
|
239 |
+
if (self.input_channels < 1):
|
240 |
+
return self.layers(aggregator_output.unsqueeze(dim=0))
|
241 |
+
else:
|
242 |
+
return self.layers(aggregator_output.unsqueeze(dim=0)).squeeze()
|
243 |
+
|
244 |
+
|
245 |
+
class MultiTaskModel(nn.Module):
|
246 |
+
def __init__(self, tasks: Dict):
|
247 |
+
super(MultiTaskModel, self).__init__()
|
248 |
+
self.tasks = tasks
|
249 |
+
for task_name, task_head in self.tasks["task_heads"].items():
|
250 |
+
setattr(self, task_name, MLP(13, task_head["output_size"]))
|
251 |
+
if task_name in self.tasks["task_projectors"].keys():
|
252 |
+
task_projector = tasks["task_projectors"][task_name]
|
253 |
+
setattr(self, task_name + "_projector", MLPAggTaskHead(task_projector["input_channels"], task_projector["input_size"], task_projector["output_size"], task_projector["use_aggregator"], task_projector["use_time_average"], task_projector["use_sigmoid"], task_projector["use_transpose"], task_projector["num_layers"], task_projector["hidden_size"], task_projector["width"]))
|
254 |
+
|
255 |
+
def forward(self, x):
|
256 |
+
task_head_outputs = {}
|
257 |
+
task_projector_outputs = []
|
258 |
+
|
259 |
+
backbone_output = x
|
260 |
+
|
261 |
+
for task_name in self.tasks["task_heads"]:
|
262 |
+
if task_name != "lmm_projector":
|
263 |
+
task_head_outputs[task_name] = getattr(self, task_name)(backbone_output)
|
264 |
+
if task_name in self.tasks["task_projectors"].keys():
|
265 |
+
task_projector_outputs.append(getattr(self, task_name + "_projector")(task_head_outputs[task_name]))
|
266 |
+
else:
|
267 |
+
task_projector_outputs.append(getattr(self, task_name)(backbone_output))
|
268 |
+
|
269 |
+
if len(task_projector_outputs) > 0:
|
270 |
+
task_projector_outputs_unsqueezed = [task_projector_output.unsqueeze(0) for task_projector_output in task_projector_outputs]
|
271 |
+
task_head_outputs["projectors"] = torch.cat(task_projector_outputs_unsqueezed, dim=-2)
|
272 |
+
|
273 |
+
return task_head_outputs
|
274 |
+
|
275 |
+
class MultiTaskSharedModel(nn.Module):
|
276 |
+
def __init__(self, tasks: Dict):
|
277 |
+
super(MultiTaskSharedModel, self).__init__()
|
278 |
+
self.tasks = tasks
|
279 |
+
self.use_backbone = False
|
280 |
+
if "backbone" in self.tasks.keys():
|
281 |
+
self.use_backbone = True
|
282 |
+
if self.use_backbone: self.backbone = MLPBackbone(768, 512)
|
283 |
+
for task_name, task_head in self.tasks["task_heads"].items():
|
284 |
+
if task_name != "lmm_projector":
|
285 |
+
setattr(self, task_name, MLPShared(13, task_head["output_size"]))
|
286 |
+
else:
|
287 |
+
setattr(self, task_name, MLPAggTaskHead(task_head["input_channels"], task_head["input_size"], task_head["output_size"], task_head["use_aggregator"], task_head["use_time_average"], task_head["use_sigmoid"], task_head["use_transpose"], task_head["num_layers"], task_head["hidden_size"], task_head["width"]))
|
288 |
+
if task_name in self.tasks["task_projectors"].keys():
|
289 |
+
task_projector = tasks["task_projectors"][task_name]
|
290 |
+
setattr(self, task_name + "_projector", MLPAggTaskHead(task_projector["input_channels"], task_projector["input_size"], task_projector["output_size"], task_projector["use_aggregator"], task_projector["use_time_average"], task_projector["use_sigmoid"], task_projector["use_transpose"], task_projector["num_layers"], task_projector["hidden_size"], task_projector["width"]))
|
291 |
+
|
292 |
+
def forward(self, x):
|
293 |
+
task_head_outputs = {}
|
294 |
+
task_projector_outputs = []
|
295 |
+
|
296 |
+
if self.use_backbone:
|
297 |
+
backbone_output = self.backbone(x)
|
298 |
+
else:
|
299 |
+
backbone_output = x
|
300 |
+
|
301 |
+
#print("Output shape ", backbone_output.shape)
|
302 |
+
for task_name in self.tasks["task_heads"]:
|
303 |
+
#print("task namee ", task_name)
|
304 |
+
if task_name != "lmm_projector":
|
305 |
+
task_head_outputs[task_name] = getattr(self, task_name)(backbone_output)
|
306 |
+
if task_name in self.tasks["task_projectors"].keys():
|
307 |
+
task_projector_outputs.append(getattr(self, task_name + "_projector")(task_head_outputs[task_name]))
|
308 |
+
else:
|
309 |
+
llm_input = x
|
310 |
+
if self.tasks["task_heads"][task_name]["use_backbone_output"]:
|
311 |
+
llm_input = backbone_output
|
312 |
+
task_projector_outputs.append(getattr(self, task_name)(llm_input))
|
313 |
+
|
314 |
+
if len(task_projector_outputs) > 0:
|
315 |
+
task_projector_outputs_unsqueezed = [task_projector_output.unsqueeze(0) for task_projector_output in task_projector_outputs]
|
316 |
+
task_head_outputs["projectors"] = torch.cat(task_projector_outputs_unsqueezed, dim=-2)
|
317 |
+
|
318 |
+
return task_head_outputs
|
319 |
+
|
320 |
+
|
321 |
+
|
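Usage note (not part of the commit): a minimal sketch of wiring MultiTaskSharedModel, assuming MERT-style features (13 layers, 768-dim) and a hand-written tasks dict; the shipped configs/tasks.json is the real source of these values, so every number below is an assumption for illustration only.

import torch
from multi_token.modalities.multi_task_projector_shared import MultiTaskSharedModel

tasks = {
    "backbone": {},  # presence of this key alone enables the shared MLPBackbone(768, 512)
    "task_heads": {
        "genre": {"output_size": 15},  # hypothetical auxiliary tagging head
        "lmm_projector": {
            "input_channels": 0, "input_size": 512, "output_size": 4096,
            "use_aggregator": False, "use_time_average": False, "use_sigmoid": False,
            "use_transpose": False, "num_layers": 2, "hidden_size": 1024, "width": 1,
            "use_backbone_output": True,
        },
    },
    "task_projectors": {},
}
model = MultiTaskSharedModel(tasks)
outputs = model(torch.randn(13, 4, 768))  # (feature layers, chunks, hidden)
print(list(outputs.keys()))  # ["genre", "projectors"]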
src/sonicverse/multi_token/modalities/projectors.py
ADDED
@@ -0,0 +1,416 @@
import torch.nn as nn
import torch
from typing import Dict
import numpy as np

import torch.nn.functional as F

def build_patch_mlp_projector(
    input_hidden_size: int, lm_hidden_size: int, num_layers: int
) -> nn.Module:
    modules = [nn.Linear(input_hidden_size, lm_hidden_size)]
    for _ in range(1, num_layers):
        modules.append(nn.GELU())
        modules.append(nn.Linear(lm_hidden_size, lm_hidden_size))
    return nn.Sequential(*modules)


class _MLPVectorProjector(nn.Module):
    def __init__(
        self, input_hidden_size: int, lm_hidden_size: int, num_layers: int, width: int
    ):
        super(_MLPVectorProjector, self).__init__()
        self.mlps = nn.ModuleList()
        for _ in range(width):
            mlp = [nn.Linear(input_hidden_size, lm_hidden_size)]
            for _ in range(1, num_layers):
                mlp.append(nn.GELU())
                mlp.append(nn.Linear(lm_hidden_size, lm_hidden_size))
            self.mlps.append(nn.Sequential(*mlp))

    def forward(self, x):
        output = torch.cat([mlp(x) for mlp in self.mlps], dim=-2)
        return output

def build_mlp_vector_projector(
    input_hidden_size: int, lm_hidden_size: int, num_layers: int, num_tokens: int
):
    return _MLPVectorProjector(
        input_hidden_size, lm_hidden_size, num_layers, num_tokens
    )

class MLPBackbone(nn.Module):
    def __init__(self, input_size: int, output_size: int, num_layers: int, hidden_dim: int):
        super(MLPBackbone, self).__init__()
        self.output_size = output_size
        mlp_layers = self._create_mlp_layers(input_size, output_size, num_layers, hidden_dim)
        self.layers = nn.Sequential(*mlp_layers)

    def _create_conv_layers(self, input_channels, num_conv_layers, hidden_dim):
        layers = []
        for _ in range(num_conv_layers):
            layers += [
                nn.Conv1d(input_channels, hidden_dim, kernel_size=3, padding=1),
                nn.GELU(),
                nn.MaxPool1d(kernel_size=2, stride=2)
            ]
            input_channels = hidden_dim
        return layers

    def _create_mlp_layers(self, input_size, output_size, num_layers, hidden_dim):
        if num_layers >= 2:
            layers = [nn.Linear(input_size, hidden_dim)]
            layers.append(nn.GELU())
            if num_layers > 2:
                for _ in range(1, num_layers - 2):
                    layers += [
                        nn.Linear(hidden_dim, hidden_dim),
                        nn.GELU()
                    ]
            layers.append(nn.Linear(hidden_dim, output_size))
        else:
            layers = [nn.Linear(input_size, output_size)]
        return layers

    def forward(self, x):
        return self.layers(x)

class MLPTaskHead(nn.Module):
    def __init__(self, backbone: nn.Module, input_size: int, output_size: int, num_layers: int, hidden_dim: int, width: int = 1):
        super(MLPTaskHead, self).__init__()
        self.backbone = backbone
        self.width = width
        if width > 1:
            self.layers = nn.ModuleList()
            for i in range(width):
                mlp_layers = [nn.GELU()]
                mlp_layers += self._create_mlp_layers(input_size, output_size, num_layers, hidden_dim)
                self.layers.append(nn.Sequential(*mlp_layers))
        else:
            mlp_layers = [nn.GELU()]
            mlp_layers += self._create_mlp_layers(input_size, output_size, num_layers, hidden_dim)
            self.layers = nn.Sequential(*mlp_layers)

    def _create_mlp_layers(self, input_size, output_size, num_layers, hidden_dim):
        if num_layers >= 2:
            layers = [nn.Linear(input_size, hidden_dim)]
            layers.append(nn.GELU())
            if num_layers > 2:
                for _ in range(1, num_layers - 2):
                    layers += [
                        nn.Linear(hidden_dim, hidden_dim),
                        nn.GELU()
                    ]
            layers.append(nn.Linear(hidden_dim, output_size))
        else:
            layers = [nn.Linear(input_size, output_size)]
        return layers

    def _create_conv_layers(self, input_channels, num_conv_layers, hidden_dim):
        layers = []
        for _ in range(num_conv_layers):
            layers += [
                nn.Conv2d(in_channels = input_channels, out_channels = hidden_dim, kernel_size=(3,3), stride=1, padding=1),
                nn.GELU(),
                nn.MaxPool1d(kernel_size=2, stride=2)
            ]
            input_channels = hidden_dim
        return layers

    def forward(self, x):
        output = self.backbone.forward(x)
        if self.width > 1:
            return torch.cat([layer(output) for layer in self.layers], dim=-2)
        else:
            return self.layers(output)

class MLPTaskModule(nn.Module):
    def __init__(self, input_size: int, output_size: int, num_layers: int, hidden_dim: int, width: int = 1):
        super(MLPTaskModule, self).__init__()
        self.width = width
        if width > 1:
            self.layers = nn.ModuleList()
            for i in range(width):
                mlp_layers = [nn.GELU()]
                mlp_layers += self._create_mlp_layers(input_size, output_size, num_layers, hidden_dim)
                self.layers.append(nn.Sequential(*mlp_layers))
        else:
            mlp_layers = [nn.GELU()]
            mlp_layers += self._create_mlp_layers(input_size, output_size, num_layers, hidden_dim)
            self.layers = nn.Sequential(*mlp_layers)

    def _create_mlp_layers(self, input_size, output_size, num_layers, hidden_dim):
        if num_layers >= 2:
            layers = [nn.Linear(input_size, hidden_dim)]
            layers.append(nn.GELU())
            if num_layers > 2:
                for _ in range(1, num_layers - 2):
                    layers += [
                        nn.Linear(hidden_dim, hidden_dim),
                        nn.GELU()
                    ]
            layers.append(nn.Linear(hidden_dim, output_size))
        else:
            layers = [nn.Linear(input_size, output_size)]
        return layers

    def _create_conv_layers(self, input_channels, num_conv_layers, hidden_dim):
        layers = []
        for _ in range(num_conv_layers):
            layers += [
                nn.Conv2d(in_channels = input_channels, out_channels = hidden_dim, kernel_size=(3,3), stride=1, padding=1),
                nn.GELU(),
                nn.MaxPool1d(kernel_size=2, stride=2)
            ]
            input_channels = hidden_dim
        return layers

    def forward(self, x):
        if self.width > 1:
            return torch.cat([layer(x) for layer in self.layers], dim=-2)
        else:
            return self.layers(x)


class MultiTaskModel(nn.Module):
    def __init__(self, input_hidden_size: int, input_channels: int, time_average: bool, time_dimension: int, use_aggregator: bool, tasks: Dict):
        super(MultiTaskModel, self).__init__()
        self.tasks = tasks
        self.time_average = time_average
        self.time_dimension = time_dimension
        self.use_aggregator = use_aggregator
        if self.use_aggregator:
            if (time_average):
                self.aggregator = nn.Parameter(torch.randn((input_channels, 1), dtype = torch.float))
            else:
                self.aggregator = nn.Parameter(torch.randn((input_channels, 1, 1), dtype = torch.float))

        self.backbone = MLPBackbone(input_hidden_size, self.tasks["backbone"]["output_size"], self.tasks["backbone"]["num_layers"], self.tasks["backbone"]["hidden_size"])
        for task_name, task_head in self.tasks["task_heads"].items():
            setattr(self, task_name, MLPTaskModule(self.tasks["backbone"]["output_size"], task_head["output_size"], task_head["num_layers"], task_head["hidden_size"], task_head["width"]))
            if task_name in self.tasks["task_projectors"].keys():
                task_projector = tasks["task_projectors"][task_name]
                setattr(self, task_name + "_projector", MLPTaskModule(task_head["output_size"], task_projector["output_size"], task_projector["num_layers"], task_projector["hidden_size"], task_projector["width"]))

    def forward(self, x):
        task_head_outputs = {}
        task_projector_outputs = []

        if self.time_average:
            x = x.mean(self.time_dimension)
        if self.use_aggregator:
            aggregator_weights = F.softmax(self.aggregator, dim=0)
            aggregator_output = (x * aggregator_weights).sum(dim=0)
            aggregator_output = aggregator_output.unsqueeze(0)
        else:
            aggregator_output = x

        backbone_output = self.backbone(aggregator_output)

        for task_name in self.tasks["task_heads"]:
            if task_name != "lmm_projector":
                task_head_output = getattr(self, task_name)(backbone_output)
                min_val = torch.min(task_head_output)
                max_val = torch.max(task_head_output)

                normalized_task_head_output = (task_head_output - min_val) / (max_val - min_val)
                task_head_outputs[task_name] = normalized_task_head_output
                if task_name in self.tasks["task_projectors"].keys():
                    task_projector_outputs.append(getattr(self, task_name + "_projector")(task_head_outputs[task_name]))
            else:
                task_projector_outputs.append(getattr(self, task_name)(backbone_output))

        task_projector_outputs_unsqueezed = [task_projector_output.unsqueeze(0) for task_projector_output in task_projector_outputs]
        if len(task_projector_outputs_unsqueezed) > 0:
            task_head_outputs["projectors"] = torch.cat(task_projector_outputs_unsqueezed, dim=-2)

        return task_head_outputs


def build_mt_vector_projector(
    input_hidden_size: int, lm_hidden_size: int, tasks: Dict
):
    projector = nn.ModuleDict()
    projector["backbone"] = MLPBackbone(input_hidden_size, tasks["backbone"]["output_size"], tasks["backbone"]["num_layers"], tasks["backbone"]["hidden_size"])
    for task_name, task_head in tasks["task_heads"].items():
        projector[task_name] = MLPTaskHead(projector["backbone"], task_head["hidden_size"], task_head["output_size"], task_head["num_layers"], task_head["hidden_size"], task_head["width"])

    return projector

class Attention(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(Attention, self).__init__()
        self.linear_in = nn.Linear(input_dim, hidden_dim)
        self.linear_out = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        # Input shape: (batch_size, seq_len, input_dim)
        energy = torch.tanh(self.linear_in(x))
        attention_scores = torch.softmax(self.linear_out(energy), dim=1)
        context_vector = torch.sum(attention_scores * x, dim=1)
        return context_vector

class _CNNAttentionTokenizer(nn.Module):
    def __init__(self, input_channels, output_size, width, hidden_dim, num_conv_layers):
        super(_CNNAttentionTokenizer, self).__init__()
        self.width = width
        self.cnns = nn.ModuleList()
        self.attentions = nn.ModuleList()
        for _ in range(width):
            cnn = self._create_conv_layers(input_channels, num_conv_layers)
            self.cnns.append(cnn)
            attention = [Attention(hidden_dim, 125)]
            linear_input_size = hidden_dim
            attention.append(nn.Linear(linear_input_size, output_size))
            self.attentions.append(nn.Sequential(*attention))


    def _create_conv_layers(self, input_channels, num_conv_layers):
        layers = []
        in_channels = input_channels
        for _ in range(num_conv_layers):
            layers += [
                nn.Conv1d(in_channels, 64, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=2, stride=2)
            ]
            in_channels = 64
        return nn.Sequential(*layers)

    def forward(self, x):
        outputs = []
        for token in range(self.width):
            # Input shape: (batch_size, input_channels, sequence_length)
            token_output = self.cnns[token](x)  # Apply convolutional layers
            token_output = token_output.permute(0, 2, 1)  # Reshape for attention mechanism (batch_size, sequence_length, input_dim
            token_output = self.attentions[token](token_output)  # Apply attention mechanism
            outputs.append(token_output)
        output = torch.cat(outputs, dim=-2)
        output = torch.stack([output])
        return output

def build_attentive_cnn_projector(
    input_channels: int, lm_hidden_size: int, num_tokens: int, hidden_dim: int, num_layers: int
):
    return _CNNAttentionTokenizer(input_channels, lm_hidden_size, num_tokens, hidden_dim, num_layers)

class _CNNMLPProjector(nn.Module):
    def __init__(self, input_channels, input_size, output_size = 4096, width = 5, hidden_dim = 64, num_conv_layers = 1, num_mlp_layers = 2):
        super(_CNNMLPProjector, self).__init__()
        self.width = width
        self.cnnmlps = nn.ModuleList()
        for _ in range(self.width):
            cnnmlp = self._create_conv_layers(input_channels, num_conv_layers, hidden_dim)
            cnnmlp.append(nn.Flatten())
            cnn_output_size = hidden_dim*((input_size + 2*1 - 3*num_conv_layers) // (2**num_conv_layers) + 1)
            cnnmlp.append(nn.Linear(cnn_output_size, output_size))
            cnnmlp.append(nn.GELU())
            cnnmlp += self._create_mlp_layers(output_size, output_size, num_mlp_layers, output_size)
            self.cnnmlps.append(nn.Sequential(*cnnmlp))

    def _create_conv_layers(self, input_channels, num_conv_layers, hidden_dim):
        layers = []
        for _ in range(num_conv_layers):
            layers += [
                nn.Conv1d(input_channels, hidden_dim, kernel_size=3, padding=1),
                nn.GELU(),
                nn.MaxPool1d(kernel_size=2, stride=2)
            ]
            input_channels = hidden_dim
        return layers

    def _create_mlp_layers(self, input_size, output_size, num_layers, hidden_dim):
        if num_layers >= 2:
            layers = [nn.Linear(input_size, hidden_dim)]
            layers.append(nn.GELU())
            if num_layers > 2:
                for _ in range(1, num_layers - 2):
                    layers += [
                        nn.Linear(hidden_dim, hidden_dim),
                        nn.GELU()
                    ]
            layers.append(nn.Linear(hidden_dim, output_size))
        else:
            layers = [nn.Linear(input_size, output_size)]
        return layers

    def forward(self, x):
        return torch.stack([torch.cat([cnnmlp(x) for cnnmlp in self.cnnmlps], dim=-2)])

def build_cnn_mlp_projector(
    input_channels: int, input_size: int, lm_hidden_size: int, num_tokens: int, hidden_dim: int, num_conv_layers: int, num_mlp_layers: int
):
    return _CNNMLPProjector(input_channels, input_size, lm_hidden_size, num_tokens, hidden_dim, num_conv_layers, num_mlp_layers)

class _MultiLayeredCNNMLPProjector(nn.Module):
    def __init__(self, input_channels, input_size, num_feature_layers, output_size = 4096, width = 5, hidden_dim = 64, num_conv_layers = 1, num_mlp_layers = 2):
        super(_MultiLayeredCNNMLPProjector, self).__init__()
        self.width = width
        self.num_feature_layers = num_feature_layers
        self.cnnmlps = nn.ModuleList()
        for _ in range(self.width*self.num_feature_layers):
            cnnmlp = self._create_conv_layers(input_channels, num_conv_layers, hidden_dim)
            cnnmlp += [nn.GELU()]
            cnnmlp += self._create_mlp_layers(input_size, output_size, num_mlp_layers, output_size)
            self.cnnmlps.append(nn.Sequential(*cnnmlp))

    def _create_conv_layers(self, input_channels, num_conv_layers, hidden_size):
        layers = []

        if input_channels >= hidden_size:
            hidden_dim = int(input_channels/2)
        else:
            hidden_dim = hidden_size

        layers += [nn.Conv1d(in_channels=input_channels, out_channels=hidden_dim, kernel_size=3, stride=1, padding=1), nn.GELU()]
        if num_conv_layers > 2:
            for _ in range(num_conv_layers - 2):
                if hidden_dim/2 >= hidden_size:
                    output_dim = int(hidden_dim/2)
                else:
                    output_dim = hidden_size
                layers += [
                    nn.Conv1d(in_channels=hidden_dim, out_channels=output_dim, kernel_size=3, stride=1, padding=1),
                    nn.GELU(),
                ]
                hidden_dim = output_dim
        layers += [nn.Conv1d(in_channels=hidden_dim, out_channels=1, kernel_size=3, stride=1, padding=1)]
        return layers

    def _create_mlp_layers(self, input_size, output_size, num_layers, hidden_dim):
        if num_layers >= 2:
            layers = [nn.Linear(input_size, hidden_dim)]
            layers.append(nn.GELU())
            if num_layers > 2:
                for _ in range(1, num_layers - 2):
                    layers += [
                        nn.Linear(hidden_dim, hidden_dim),
                        nn.GELU()
                    ]
            layers.append(nn.Linear(hidden_dim, output_size))
        else:
            layers = [nn.Linear(input_size, output_size)]
        return layers

    def forward(self, x):
        print("X SHAPE ", x.shape)
        inp_feature_layers = []
        for feature_id in range(self.num_feature_layers):
            in_feat_layer = x[feature_id].unsqueeze(0).permute(0,2,1)
            inp_feature_layers.append(in_feat_layer)

        outputs = []
        for layer_count in range(self.width*self.num_feature_layers):
            feature_id = int(layer_count/self.width)
            outputs += [self.cnnmlps[layer_count](inp_feature_layers[feature_id])]

        return torch.cat(outputs, dim=-2)


def build_multi_layer_cnn_mlp_projector(
    input_channels: int, input_size: int, num_feature_layers: int, lm_hidden_size: int, num_tokens: int, hidden_dim: int, num_conv_layers: int, num_mlp_layers: int
):
    assert(num_tokens % num_feature_layers == 0)
    width = int(num_tokens/num_feature_layers)
    return _MultiLayeredCNNMLPProjector(input_channels, input_size, num_feature_layers, lm_hidden_size, width, hidden_dim, num_conv_layers, num_mlp_layers)
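Usage note (not part of the commit): a minimal sketch of build_mlp_vector_projector, which is the projector the CLIP/X-CLIP modalities below use; the 512/4096 sizes are assumed example values (e.g. an encoder embedding into a Mistral-sized hidden state).

import torch
from multi_token.modalities.projectors import build_mlp_vector_projector

proj = build_mlp_vector_projector(
    input_hidden_size=512, lm_hidden_size=4096, num_layers=2, num_tokens=4
)
emb = torch.randn(1, 1, 512)   # batch x 1 x encoder_dim
tokens = proj(emb)             # each of the 4 width-MLPs emits one vector, concatenated on dim=-2
print(tokens.shape)            # torch.Size([1, 4, 4096])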
src/sonicverse/multi_token/modalities/video_xclip.py
ADDED
@@ -0,0 +1,113 @@
from typing import Dict, List, Optional

import torch
import torch.nn as nn
from transformers import AutoProcessor, AutoModel

from multi_token.data_tools import load_video
from multi_token.modalities.base_modality import Modality
from multi_token.modalities.projectors import (
    build_mlp_vector_projector,
)


OUTPUT_EMB_SIZE = 512


class XCLIPVideoModule(nn.Module):
    def __init__(self, model_name_or_path: str):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.model = None
        self.processor = None

        self.load_model()

    def load_model(self):
        self.model = AutoModel.from_pretrained(self.model_name_or_path)
        self.processor = AutoProcessor.from_pretrained(self.model_name_or_path)
        self.model.requires_grad_(False)

    @torch.no_grad()
    def forward(self, video_inputs) -> torch.Tensor:
        with torch.no_grad():
            outputs = self.model(**(video_inputs.to(device=self.device)))

        emb = outputs.video_embeds.to(device=self.device, dtype=self.dtype).view(
            -1, 1, OUTPUT_EMB_SIZE
        )
        return emb

    @property
    def dtype(self):
        return self.model.dtype

    @property
    def device(self):
        return self.model.device


class XCLIPVideoModality(Modality):
    def __init__(
        self,
        model_name_or_path: str = "microsoft/xclip-base-patch32",
        num_projector_layers: int = 2,
        num_tokens_output: int = 10,
    ):
        self.model_name_or_path = model_name_or_path
        self.module = XCLIPVideoModule(model_name_or_path=self.model_name_or_path)
        self.num_projector_layers = num_projector_layers
        self.num_tokens_output = num_tokens_output

    def build_projector(self, lm_hidden_size: int) -> nn.Module:
        return build_mlp_vector_projector(
            input_hidden_size=OUTPUT_EMB_SIZE,
            lm_hidden_size=lm_hidden_size,
            num_layers=self.num_projector_layers,
            num_tokens=self.num_tokens_output,
        )

    @property
    def name(self) -> str:
        return "video_xclip"

    @property
    def token(self) -> str:
        return "<video>"

    @property
    def data_key(self) -> str:
        return "videos"

    @property
    def token_width(self) -> int:
        return self.num_tokens_output

    def to(self, dtype: torch.dtype, device: torch.device) -> "XCLIPVideoModality":
        self.module.to(dtype=dtype, device=device)
        return self

    def preprocess_rows(self, rows: List[Dict]) -> List[Optional[Dict]]:
        row_values = []
        for row in rows:
            video_arrays = [
                load_video(
                    video_info,
                )
                for video_info in row[self.data_key]
            ]
            videos_enc = self.module.processor(
                videos=[list(video) for video in video_arrays],
                text=["IGNORE"],
                return_tensors="pt",
                padding=True,
            )
            row_values.append(videos_enc)
        return row_values

    @torch.no_grad()
    def forward(self, encoded_values: List[torch.Tensor]) -> List[torch.Tensor]:
        video_features = []
        for video_batch in encoded_values:
            video_features.append(self.module.forward(video_batch))
        return video_features
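Usage note (not part of the commit): a hedged sketch of the modality's two-step pipeline (preprocess_rows, then forward); the video path is hypothetical and assumes load_video accepts whatever the dataset rows store, and the xclip checkpoint will be downloaded on first use.

from multi_token.modalities.video_xclip import XCLIPVideoModality

modality = XCLIPVideoModality(num_tokens_output=10)
rows = [{"videos": ["clip.mp4"]}]          # hypothetical dataset row
encoded = modality.preprocess_rows(rows)   # one processor output per row
feats = modality.forward(encoded)          # list of tensors shaped (num_videos, 1, 512)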
src/sonicverse/multi_token/modalities/vision_clip.py
ADDED
@@ -0,0 +1,178 @@
from typing import Dict, List, Tuple, Optional

import torch
import torch.nn as nn
from transformers import CLIPVisionModel, CLIPImageProcessor
from PIL import Image

from multi_token.modalities.base_modality import Modality
from multi_token.modalities.projectors import (
    build_patch_mlp_projector,
    build_mlp_vector_projector,
)
from multi_token.data_tools import load_image

PATCH_LAYER = -2
OUTPUT_LAYER = -1
OUTPUT_EMB_SIZE = 1024


class CLIPVisionModule(nn.Module):
    def __init__(self, model_name_or_path: str, feature_layer: int = PATCH_LAYER):
        super().__init__()
        self.feature_layer = feature_layer
        self.model_name_or_path = model_name_or_path
        self.image_processor = None
        self.image_model = None

        self.load_model()

    def load_model(self):
        self.image_processor = CLIPImageProcessor.from_pretrained(
            self.model_name_or_path
        )
        self.image_model = CLIPVisionModel.from_pretrained(self.model_name_or_path)
        self.image_model.requires_grad_(False)

    @torch.no_grad()
    def forward(self, images) -> torch.Tensor:
        if self.feature_layer == PATCH_LAYER:
            image_forward_outs = self.image_model(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True,
            )
            image_features = image_forward_outs.hidden_states[self.feature_layer]
            image_features = image_features[:, 1:].to(images.dtype)
        else:
            image_forward_outs = self.image_model(
                images.to(device=self.device, dtype=self.dtype),
            )
            image_features = image_forward_outs.pooler_output.to(images.dtype).view(
                -1, 1, OUTPUT_EMB_SIZE
            )
        return image_features

    @property
    def dtype(self):
        return self.image_model.dtype

    @property
    def device(self):
        return self.image_model.device

    @property
    def config(self):
        return self.image_model.config

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2


def _expand2square(pil_img: Image, background_color: Tuple) -> Image:
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


class CLIPVisionModality(Modality):
    def __init__(
        self,
        model_name_or_path: str = "openai/clip-vit-large-patch14-336",
        pad_non_square_images: bool = False,
        num_projector_layers: int = 2,
        feature_layer: int = PATCH_LAYER,
        num_tokens_output: Optional[int] = None,
    ):
        if feature_layer not in [PATCH_LAYER, OUTPUT_LAYER]:
            raise ValueError(
                f"feature_layer must be one of {PATCH_LAYER} or {OUTPUT_LAYER}"
            )
        if (feature_layer == PATCH_LAYER) != (num_tokens_output is None):
            raise ValueError(
                "num_tokens_output must be None if feature_layer is PATCH_LAYER"
            )
        self.model_name_or_path = model_name_or_path
        self.module = CLIPVisionModule(
            model_name_or_path=self.model_name_or_path, feature_layer=feature_layer
        )
        self.pad_non_square_images = pad_non_square_images
        self.num_projector_layers = num_projector_layers
        self.num_tokens_output = num_tokens_output

    def build_projector(self, lm_hidden_size: int) -> nn.Module:
        if self.module.feature_layer == PATCH_LAYER:
            return build_patch_mlp_projector(
                self.module.hidden_size,
                lm_hidden_size,
                num_layers=self.num_projector_layers,
            )
        else:
            return build_mlp_vector_projector(
                input_hidden_size=OUTPUT_EMB_SIZE,
                lm_hidden_size=lm_hidden_size,
                num_layers=self.num_projector_layers,
                num_tokens=self.num_tokens_output,
            )

    @property
    def name(self) -> str:
        return "vision_clip"

    @property
    def token(self) -> str:
        return "<image>"

    @property
    def data_key(self) -> str:
        return "images"

    @property
    def token_width(self) -> int:
        if self.module.feature_layer == PATCH_LAYER:
            return self.module.num_patches
        else:
            return self.num_tokens_output

    def to(self, dtype: torch.dtype, device: torch.device) -> "CLIPVisionModality":
        self.module.to(dtype=dtype, device=device)
        return self

    def preprocess_rows(self, rows: List[Dict]) -> List[Optional[torch.Tensor]]:
        row_values = []
        for row in rows:
            images = []
            for image_fn in row[self.data_key]:
                image_obj = load_image(image_fn)
                if self.pad_non_square_images:
                    image_obj = _expand2square(
                        image_obj,
                        tuple(
                            int(x * 255) for x in self.module.image_processor.image_mean
                        ),
                    )
                image = self.module.image_processor.preprocess(
                    image_obj, return_tensors="pt"
                )["pixel_values"][0]
                images.append(image)
            row_values.append(torch.stack(images) if len(images) > 0 else None)
        return row_values

    @torch.no_grad()
    def forward(self, encoded_values: List[torch.Tensor]) -> List[torch.Tensor]:
        image_features = []
        for image_batch in encoded_values:
            image_features.append(self.module.forward(image_batch))
        return image_features
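Usage note (not part of the commit): a minimal standalone check of _expand2square, which pad_non_square_images uses to letterbox images with the processor's mean color before CLIP preprocessing; only PIL is needed, and the colors are arbitrary example values.

from PIL import Image
from multi_token.modalities.vision_clip import _expand2square

img = Image.new("RGB", (200, 100), (255, 0, 0))       # wide, non-square image
square = _expand2square(img, (128, 128, 128))          # pad top/bottom with gray
print(square.size)                                     # (200, 200)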
src/sonicverse/multi_token/model_utils.py
ADDED
@@ -0,0 +1,112 @@
from typing import List, Dict
import logging
import torch

from enum import Enum

class MultiTaskType(Enum):
    NO_MULTI_TASK = 0
    SIMPLE_MULTI_TASK = 1
    PROJECTED_MULTI_TASK = 2

def _find_all_linear_names(model) -> List[str]:
    cls = torch.nn.Linear
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split(".")
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if "lm_head" in lora_module_names:
        lora_module_names.remove("lm_head")
    return list(lora_module_names)


def maybe_zero_3(param, ignore_status=False, name=None):
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                logging.warning(
                    f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}"
                )
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


def get_peft_state(named_params, bias) -> Dict:
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        for k, t in maybe_lora_bias:
            if bias_name in lora_bias_names:
                to_return[bias_name] = t
    else:
        raise NotImplementedError()
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return


def get_peft_state_non_lora(named_params, task_names) -> Dict:
    to_return = {}
    for k, t in named_params:
        if "lora_" not in k:
            task_name_in_k = False
            for task_name in task_names:
                if task_name in k:
                    task_name_in_k = True
            if t.requires_grad or task_name_in_k:
                to_return[k] = t
    to_return = {
        k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()
    }
    return to_return


def make_model_lora(model, training_args: "TrainingArguments"):
    from peft import LoraConfig, get_peft_model

    lora_config = LoraConfig(
        r=training_args.lora_r,
        lora_alpha=training_args.lora_alpha,
        target_modules=_find_all_linear_names(model),
        lora_dropout=training_args.lora_dropout,
        bias=training_args.lora_bias,
        task_type="CAUSAL_LM",
    )
    if training_args.bits == 16:
        if training_args.bf16:
            model.to(torch.bfloat16)
        if training_args.fp16:
            model.to(torch.float16)

    model = get_peft_model(model, lora_config)
    return model


def fix_tokenizer(tokenizer):
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.unk_token
    if tokenizer.mask_token is None:
        tokenizer.mask_token = tokenizer.unk_token
    if tokenizer.cls_token is None:
        tokenizer.cls_token = tokenizer.unk_token
    if tokenizer.sep_token is None:
        tokenizer.sep_token = tokenizer.unk_token
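Usage note (not part of the commit): a small sketch of what _find_all_linear_names returns, since its output is what make_model_lora passes as LoRA target_modules; the toy module names below are hypothetical.

import torch.nn as nn
from multi_token.model_utils import _find_all_linear_names

toy = nn.ModuleDict({
    "q_proj": nn.Linear(8, 8),
    "v_proj": nn.Linear(8, 8),
    "lm_head": nn.Linear(8, 8),   # excluded by design
})
print(sorted(_find_all_linear_names(toy)))  # ['q_proj', 'v_proj']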
src/sonicverse/multi_token/training.py
ADDED
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, List
|
2 |
+
from dataclasses import field, dataclass
|
3 |
+
import logging
|
4 |
+
import subprocess
|
5 |
+
import pathlib
|
6 |
+
import torch
|
7 |
+
import shutil
|
8 |
+
import glob
|
9 |
+
import os
|
10 |
+
import json
|
11 |
+
|
12 |
+
import transformers
|
13 |
+
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
14 |
+
from transformers import Trainer
|
15 |
+
|
16 |
+
from multi_token.training_data import (
|
17 |
+
DataArguments,
|
18 |
+
LMMDataset,
|
19 |
+
DataCollatorForSupervisedLMMDataset,
|
20 |
+
)
|
21 |
+
from multi_token.model_utils import (
|
22 |
+
make_model_lora,
|
23 |
+
get_peft_state,
|
24 |
+
get_peft_state_non_lora,
|
25 |
+
fix_tokenizer,
|
26 |
+
MultiTaskType
|
27 |
+
)
|
28 |
+
from multi_token.modalities.base_modality import Modality
|
29 |
+
|
30 |
+
|
31 |
+
README_TEMPLATE = """
|
32 |
+
---
|
33 |
+
license: apache-2.0
|
34 |
+
base_model: {base_model}
|
35 |
+
dataset: {dataset}
|
36 |
+
tags:
|
37 |
+
- finetuned
|
38 |
+
- multimodal
|
39 |
+
inference: false
|
40 |
+
---
|
41 |
+
|
42 |
+
These are weights for a version of `{base_model}` finetuned for multimodal applications.
|
43 |
+
|
44 |
+
### Modalities
|
45 |
+
|
46 |
+
{modalities}
|
47 |
+
|
48 |
+
### Usage
|
49 |
+
|
50 |
+
GitHub: https://github.com/sshh12/multi_token (includes training scripts and basic inference server)

### Dataset

{dataset} ({num_examples} examples)

```
{dataset_example}
```

### Training Device(s)

```
{training_devices_dump}
```


### Model

```
{repr_model}
```

"""


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    remove_unused_columns: bool = field(default=False)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    double_quant: bool = field(
        default=True,
        metadata={
            "help": "Compress the quantization statistics through double quantization."
        },
    )
    quant_type: str = field(
        default="nf4",
        metadata={
            "help": "Quantization data type to use. Should be one of `fp4` or `nf4`."
        },
    )
    pretrain_projectors: bool = field(default=False)
    pretrained_projectors_path: Optional[str] = field(default=None)
    pretrained_projectors_config: Optional[str] = field(default=None)
    bits: int = field(default=16, metadata={"help": "How many bits to use."})
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"


@dataclass
class ModelArguments:
    model_name_or_path: str = field(default="mistralai/Mistral-7B-Instruct-v0.1")
    model_cls: str = field(default="MistralLMMForCausalLM")
    modality_builder: str = field(default="vision_clip")
    use_multi_task: int = field(default=MultiTaskType.PROJECTED_MULTI_TASK)
    tasks_config: str = field(default="src/sonicverse/configs/tasks.json")
    model_lora_path: Optional[str] = field(default="annabeth97c/sonicverse")


class LMMTrainer(Trainer):
    def _save_checkpoint(self, model, trial, metrics=None):
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)
        self._save_extras(output_dir)

        super(LMMTrainer, self)._save_checkpoint(model, trial, metrics)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        self._save_extras(output_dir)
        super(LMMTrainer, self)._save(output_dir, state_dict)
        for unused_dir in glob.iglob(os.path.join(output_dir, "global_step*")):
            shutil.rmtree(unused_dir)

    def _save_extras(self, output_dir: Optional[str] = None):
        self.model.config.save_pretrained(output_dir)

        task_names = []
        for m in self.model.modalities:
            task_names += m.tasks["task_heads"].keys()

        non_lora_state_dict = get_peft_state_non_lora(self.model.named_parameters(), task_names)
        torch.save(
            non_lora_state_dict,
            os.path.join(output_dir, "non_lora_trainables.bin"),
        )


def _get_training_devices_dump() -> str:
    out = subprocess.check_output(
        ["nvidia-smi", "--query-gpu=gpu_name,gpu_bus_id,vbios_version", "--format=csv"]
    )
    return out.decode("utf-8").strip()


def train_for_modalities(
    model_cls,
    training_args: TrainingArguments,
    model_args: ModelArguments,
    train_data_args: DataArguments,
    evaluation_data_args: DataArguments,
    modalities: List[Modality],
):
    for m in modalities:
        m.to(
            dtype=torch.bfloat16 if training_args.bf16 else torch.float16,
            device=training_args.device,
        )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    fix_tokenizer(tokenizer)

    train_dataset = LMMDataset(train_data_args, tokenizer, modalities)
    evaluation_dataset = LMMDataset(evaluation_data_args, tokenizer, modalities)
    collator = DataCollatorForSupervisedLMMDataset(tokenizer, modalities)

    model = model_cls.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )
    model.to(
        dtype=torch.bfloat16 if training_args.bf16 else torch.float16,
        device=training_args.device,
    )
    model.modalities = modalities
    model.config.use_cache = False
    model.config.model_cls = model_cls.__name__
    model.config.modality_builder = model_args.modality_builder

    if training_args.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    if model_args.model_lora_path:
        raise ValueError(
            "LoRA path not supported for training -- set the output path to an existing model to resume training"
        )

    if training_args.lora_enable:
        logging.info("Adding LoRA adapters...")
        model = make_model_lora(model, training_args)

    if training_args.pretrained_projectors_path:
        projector_weights_og = torch.load(
            training_args.pretrained_projectors_path, map_location="cpu"
        )
        if model_args.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
            projector_weights = {}
            for k, v in projector_weights_og.items():
                for m in modalities:
                    for task_name in m.tasks["task_heads"].keys():
                        if task_name in k:
                            projector_weights[k] = v
        else:
            projector_weights = {
                k: v for k, v in projector_weights_og.items() if "_lmm_projector" in k
            }

    elif training_args.pretrained_projectors_config:
        with open(training_args.pretrained_projectors_config, "r") as f:
            pretrained_weights_config = json.load(f)

        projector_weights = {}

        for pretrained_path_info in pretrained_weights_config["pretrained_paths"]:
            pretrained_path = pretrained_path_info["path"]
            components = pretrained_path_info["components"]
            use_prefix = pretrained_path_info["use_prefix"]
            prefix = pretrained_path_info["prefix"]

            pretrained_weights = torch.load(pretrained_path, map_location="cpu")

            for k, v in pretrained_weights.items():
                if any(component in k for component in components):
                    weight_key = k
                    if use_prefix:
                        weight_key = prefix + "." + k
                    projector_weights[weight_key] = v

    else:
        projector_weights = {}

    model.get_model().initialize_modules(modalities, projector_weights)

    task_names = []
    tasks = {}
    for m in model.modalities:
        if m.use_multi_task != MultiTaskType.NO_MULTI_TASK:
            tasks = m.tasks
            task_names += m.tasks["task_heads"].keys()

    if training_args.pretrain_projectors:
        model.requires_grad_(False)
        for m in modalities:
            if m.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
                for task_name in m.tasks["task_heads"].keys():
                    task_model = getattr(model.get_model(), m.name + "_" + task_name)
                    for p in task_model.parameters():
                        p.requires_grad = True
            elif m.use_multi_task == MultiTaskType.PROJECTED_MULTI_TASK:
                proj = getattr(model.get_model(), m.name + "_lmm_projector")

                if "backbone" in m.tasks.keys():
                    backbone = getattr(proj, "backbone")
                    for backbone_param in backbone.parameters():
                        backbone_param.requires_grad = tasks["backbone"]["requires_grad"]

                for task in task_names:
                    task_head = getattr(proj, task)
                    for task_head_param in task_head.parameters():
                        task_head_param.requires_grad = tasks["task_heads"][task]["requires_grad"]
                    if task in tasks["task_projectors"]:
                        task_projector = getattr(proj, task + "_projector")
                        for task_projector_param in task_projector.parameters():
                            task_projector_param.requires_grad = tasks["task_projectors"][task]["requires_grad"]

            else:
                proj = getattr(model.get_model(), m.name + "_lmm_projector")
                for p in proj.parameters():
                    p.requires_grad = True

    os.makedirs(training_args.output_dir, exist_ok=True)
    with open(
        os.path.join(training_args.output_dir, "model_named_parameters.txt"), "w"
    ) as f:
        for name, param in model.named_parameters():
            f.write(f"{name} {param.shape} {param.requires_grad}\n")

    with open(os.path.join(training_args.output_dir, "README.md"), "w") as f:
        modalities_text = [
            f"* {m.__class__.__name__} (use `{m.token}` in text and provide `{m.data_key}`, encoded as {m.token_width} tokens)"
            for m in modalities
        ]
        readme_text = README_TEMPLATE.format(
            base_model=model_args.model_name_or_path,
            dataset=train_data_args.dataset_path,
            dataset_example=repr(train_dataset.get_example()),
            num_examples=len(train_dataset),
            modalities="\n".join(modalities_text),
            training_devices_dump=_get_training_devices_dump(),
            repr_model=f"{model_cls.__name__}.model =\n\n{repr(model)}",
        )
        f.write(readme_text)

    trainer = LMMTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        data_collator=collator,
        train_dataset=train_dataset,
        eval_dataset=evaluation_dataset,
    )

    if list(pathlib.Path(training_args.output_dir).glob(f"{PREFIX_CHECKPOINT_DIR}-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

    trainer.save_state()

    model.config.use_cache = True
    model.config.save_pretrained(training_args.output_dir)
    state_dict = get_peft_state(model.named_parameters(), training_args.lora_bias)
    model.save_pretrained(training_args.output_dir, state_dict=state_dict)

    non_lora_state_dict = get_peft_state_non_lora(model.named_parameters(), task_names)
    torch.save(
        non_lora_state_dict,
        os.path.join(training_args.output_dir, "non_lora_trainables.bin"),
    )
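
Note: a minimal sketch (not part of the diff) of how these dataclasses are typically consumed. It only parses CLI flags with `transformers.HfArgumentParser` (the same pattern the evaluation scripts below use); resolving `model_args.model_cls` and `model_args.modality_builder` into an actual model class and modality list is left to the repo's own builders, so the final `train_for_modalities(...)` call is indicated in a comment rather than invoked.

```
# Hedged sketch, assuming the sonicverse package is importable.
import transformers

from multi_token.training import TrainingArguments, ModelArguments, train_for_modalities
from multi_token.training_data import DataArguments

if __name__ == "__main__":
    parser = transformers.HfArgumentParser(
        (TrainingArguments, ModelArguments, DataArguments)
    )
    training_args, model_args, train_data_args = parser.parse_args_into_dataclasses()

    # A second DataArguments instance would point at the validation split.
    # With a model class and modality list resolved from model_args, training
    # would then be launched as:
    # train_for_modalities(model_cls, training_args, model_args,
    #                      train_data_args, evaluation_data_args, modalities)
    print(training_args.output_dir, model_args.model_name_or_path, train_data_args.dataset_path)
```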
src/sonicverse/multi_token/training_data.py
ADDED
@@ -0,0 +1,133 @@
from typing import List, Dict, Sequence
from dataclasses import dataclass, field
import logging
import os

from torch.utils.data import Dataset
from datasets import load_from_disk, load_dataset, Dataset as HFDataset
import transformers
import torch

from multi_token.modalities.base_modality import Modality
from multi_token.constants import IGNORE_INDEX
from multi_token.data_tools import encode_chat, encode_chat_multitask
from multi_token.model_utils import MultiTaskType


@dataclass
class DataArguments:
    dataset_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )

@dataclass
class TrainDataArguments:
    train_dataset_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )

@dataclass
class EvaluationDataArguments:
    evaluation_dataset_path: str = field(
        default=None, metadata={"help": "Path to the evaluation data."}
    )


def _resolve_dataset(path: str) -> HFDataset:
    if os.path.exists(path):
        return load_from_disk(path)
    else:
        return load_dataset(path, split="train", data_files="*.arrow")


class LMMDataset(Dataset):
    def __init__(
        self,
        data_args: DataArguments,
        tokenizer: transformers.PreTrainedTokenizer,
        modalities: List[Modality],
    ):
        super(LMMDataset, self).__init__()
        self.dataset = _resolve_dataset(data_args.dataset_path)
        self.tokenizer = tokenizer
        self.modalities = modalities

    def __len__(self):
        return len(self.dataset)

    def get_example(self) -> Dict:
        return self.dataset[0]

    def __getitem__(self, i) -> Dict:
        try:
            item = self.dataset[i]
            use_multi_task = MultiTaskType.NO_MULTI_TASK
            for m in self.modalities:
                if m.use_multi_task != MultiTaskType.NO_MULTI_TASK:
                    use_multi_task = m.use_multi_task
                    break
            if use_multi_task != MultiTaskType.NO_MULTI_TASK:
                return encode_chat_multitask(item, self.tokenizer, self.modalities)
            else:
                return encode_chat(item, self.tokenizer, self.modalities)
        except Exception as e:
            new_i = i + 1
            if new_i >= len(self):
                new_i = 0
            logging.error(f"Error encoding chat: {e} index={i} trying index={new_i}")
            return self.__getitem__(new_i)


@dataclass
class DataCollatorForSupervisedLMMDataset:
    def __init__(
        self,
        tokenizer: transformers.PreTrainedTokenizer,
        modalities: List[Modality],
    ):
        self.tokenizer = tokenizer
        self.modalities = modalities

        self.use_multi_task = MultiTaskType.NO_MULTI_TASK
        for modality in self.modalities:
            if modality.use_multi_task != MultiTaskType.NO_MULTI_TASK:
                self.use_multi_task = modality.use_multi_task
                break

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, List]:
        input_ids = []
        lmm_labels = []
        task_labels = []
        for instance in instances:
            input_ids.append(instance["input_ids"])
            if self.use_multi_task == MultiTaskType.NO_MULTI_TASK:
                lmm_labels.append(instance["labels"])
            else:
                lmm_labels.append(instance["labels"][0])
                inst_task_labels = []
                for label_id in range(1, len(instance["labels"])):
                    inst_task_labels.append(instance["labels"][label_id])
                task_labels.append(inst_task_labels)

        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        # print("Lmm labels 1 type :", type(lmm_labels))
        lmm_labels = torch.nn.utils.rnn.pad_sequence(
            lmm_labels, batch_first=True, padding_value=IGNORE_INDEX
        )
        # print("Lmm labels 2 type :", type(lmm_labels))

        input_ids = input_ids[:, : self.tokenizer.model_max_length]
        lmm_labels = lmm_labels[:, : self.tokenizer.model_max_length]
        output_labels = [lmm_labels, task_labels]
        batch = dict(
            input_ids=input_ids,
            labels=output_labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        for m in self.modalities:
            batch[m.name] = [instance[m.name] for instance in instances]

        return batch
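
Note: for reference, a raw record in the datasets that `LMMDataset` loads has the same shape as the examples emitted by the build scripts below (a list of audio paths plus a chat with `<sound>` placeholders); `encode_chat` / `encode_chat_multitask` then turn it into `input_ids`, labels and per-modality inputs for the collator. A hedged illustration (paths and caption are invented):

```
example_record = {
    "sounds": ["/data/audio/clip_0001.wav"],  # illustrative path only
    "messages": [
        {"role": "user", "content": "Describe the music. <sound>"},
        {"role": "assistant", "content": "An upbeat acoustic guitar loop with light percussion."},
    ],
}
```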
src/sonicverse/requirements.txt
ADDED
@@ -0,0 +1,8 @@
transformers>=4.34.0
accelerate>=0.21.0
scipy>=1.11.3
bitsandbytes>=0.41.0
datasets>=2.14.5
sentencepiece>=0.1.99
peft>=0.4.0
deepspeed==0.9.5
src/sonicverse/scripts/audio_setup.sh
ADDED
@@ -0,0 +1,3 @@
#!/bin/bash

pip install librosa soundfile
src/sonicverse/scripts/clap_gpt_build_finetune_dataset.py
ADDED
@@ -0,0 +1,155 @@
from typing import List
import argparse
import json
import os
import random
import openai

from datasets import Dataset, load_dataset

from multi_token.constants import ROLE_ASSISTANT, ROLE_USER

PROMPT = """
You are helping train a sound assistant that can take audio inputs and output text.

You can hear an audio file with the following metadata tags:
{captions}

{question}

Include the question and answer.
"""

QUESTIONS = [
    "Ask a question about the content of the audio.",
    "Ask a complex question about the content of the audio.",
    "Ask a complex question that is relevant to the content of the audio, for example, asking about background knowledge of the things mentioned. Do not ask about uncertain details.",
    "Ask a complex question that is relevant to the content of the audio, for example, asking about the events referred to in the audio. Do not ask about uncertain details.",
    "Ask about your thoughts on the audio.",
    "Ask about what occurs in the audio.",
    "Ask a question on a topic that relates to the audio.",
    "Ask a question that classifies the audio in some way.",
    "Ask a question that can only be answered by listening to the audio.",
]


OPENAI_TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "create_chat",
            "description": "Create a training example",
            "parameters": {
                "type": "object",
                "properties": {
                    "question": {
                        "type": "string",
                        "description": "The question, must be provided",
                    },
                    "answer": {
                        "type": "string",
                        "description": "The answer to the question, must be provided",
                    },
                },
                "required": ["question", "answer"],
            },
        },
    }
]


def _build_convo(row) -> List:
    client = openai.Client()

    captions = [row["metadataTags"]]
    paths = [row["url"]]

    captions_text = "\n".join([f"{cap}" for i, cap in enumerate(captions)])
    prompt = PROMPT.format(
        captions=captions_text, question=random.choice(QUESTIONS)
    ).strip()

    completion = client.chat.completions.create(
        model="gpt-3.5-turbo-1106",
        messages=[{"role": "system", "content": prompt}],
        tools=OPENAI_TOOLS,
        tool_choice={"type": "function", "function": {"name": "create_chat"}},
    )
    resp = json.loads(completion.choices[0].message.tool_calls[0].function.arguments)
    if "answer" not in resp:
        print(resp)
    q = resp["question"]
    a = resp["answer"]

    if random.choice([True, False]):
        q = "<sound>" * len(captions) + " " + q
    else:
        q = q + " " + "<sound>" * len(captions)

    example = {
        "sounds": paths,
        "messages": [
            {
                "role": ROLE_USER,
                "content": q,
            },
            {
                "role": ROLE_ASSISTANT,
                "content": a,
            },
        ],
    }
    return example


def main(args):
    data = load_dataset("Chr0my/Epidemic_sounds", split="train")
    data_idxs = list(range(len(data)))

    os.makedirs(args.cache_folder, exist_ok=True)

    def gen(seeds):
        r = random.Random(seeds[0] + 3)
        cache = open(
            os.path.join(args.cache_folder, f"gpt-cache.{seeds[0]}.jsonl"), "a"
        )
        i = 0
        while i < len(seeds):
            selected_idxs = r.sample(data_idxs, k=1)[0]
            selected_example = data[selected_idxs]
            try:
                example = _build_convo(selected_example)
                cache.write(json.dumps(example) + "\n")
                yield example
                i += 1
            except Exception as e:
                print(e)
                continue
        cache.close()

    ds = Dataset.from_generator(
        gen,
        num_proc=args.num_proc,
        gen_kwargs={"seeds": list(range(args.num_examples))},
    )
    ds.save_to_disk(args.output_folder)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-o",
        "--output_folder",
        type=str,
        default="/data/clap-gpt-finetune",
    )
    parser.add_argument(
        "-c",
        "--cache_folder",
        type=str,
        default="/data/clap-gpt-finetune-cache",
    )
    parser.add_argument("-n", "--num_examples", type=int, default=100_000)
    parser.add_argument("-p", "--num_proc", type=int, default=10)
    args = parser.parse_args()
    main(args)
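
Note: because each worker appends every generated example to a `gpt-cache.{seed}.jsonl` file before yielding it, a partially completed run can be recovered without re-calling the OpenAI API. A hedged sketch, assuming the script's default cache folder (the output path is illustrative):

```
from datasets import load_dataset

# Gather the per-worker JSONL caches back into a Dataset.
cached = load_dataset(
    "json",
    data_files="/data/clap-gpt-finetune-cache/gpt-cache.*.jsonl",
    split="train",
)
cached.save_to_disk("/data/clap-gpt-finetune-recovered")  # illustrative output path
```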
src/sonicverse/scripts/clap_gpt_build_pretrain_dataset.py
ADDED
@@ -0,0 +1,142 @@
from typing import List
import argparse
import json
import os
import random
import openai

from datasets import Dataset, load_dataset

from multi_token.constants import ROLE_ASSISTANT, ROLE_USER

PROMPT = """
You are helping write captions for audio clips.

Here are the tags for the audio clip you are captioning:
{captions}

Write a brief caption for the audio clip.
"""

PRETRAIN_PHRASES = [
    "What is happening in <sound>?",
    "Describe the sound. <sound>",
    "<sound> Provide a description of the audio.",
    "Can you interpret <sound>?",
    "Please explain what's happening in <sound>",
    "What does <sound> represent?",
    "Could you describe <sound> for me?",
    "What's the content of <sound>?",
    "Can you depict <sound>?",
    "What is <sound>?",
    "In the audio clip, <sound>, what is happening?",
    "Provide a description of the sound. <sound>",
    "Provide a caption for the sound. <sound>",
]

OPENAI_TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "write_caption",
            "description": "Write a caption for an audio clip",
            "parameters": {
                "type": "object",
                "properties": {
                    "caption": {
                        "type": "string",
                    },
                },
                "required": ["caption"],
            },
        },
    }
]


def _build_convo(row) -> List:
    client = openai.Client()

    captions = [row["metadataTags"]]
    sounds = [row["url"]]

    captions_text = "\n".join([f'Tags: "{cap}"' for i, cap in enumerate(captions)])
    prompt = PROMPT.format(captions=captions_text).strip()

    completion = client.chat.completions.create(
        model="gpt-3.5-turbo-1106",
        messages=[{"role": "system", "content": prompt}],
        tools=OPENAI_TOOLS,
        tool_choice={"type": "function", "function": {"name": "write_caption"}},
    )
    resp = json.loads(completion.choices[0].message.tool_calls[0].function.arguments)
    caption = resp["caption"]

    q = random.choice(PRETRAIN_PHRASES)

    example = {
        "sounds": sounds,
        "messages": [
            {
                "role": ROLE_USER,
                "content": q,
            },
            {
                "role": ROLE_ASSISTANT,
                "content": caption,
            },
        ],
    }
    return example


def main(args):
    data = load_dataset("Chr0my/Epidemic_sounds", split="train")

    os.makedirs(args.cache_folder, exist_ok=True)

    def gen(seeds):
        cache = open(
            os.path.join(args.cache_folder, f"gpt-cache.{seeds[0]}.jsonl"), "a"
        )
        for s in seeds:
            selected_row = data[s]
            try:
                example = _build_convo(selected_row)
                cache.write(json.dumps(example) + "\n")
                yield example
            except Exception as e:
                print(e)
                continue

        cache.close()

    idxs = list(range(len(data)))
    random.shuffle(idxs)

    ds = Dataset.from_generator(
        gen,
        num_proc=args.num_proc,
        gen_kwargs={"seeds": idxs},
    )
    ds.save_to_disk(args.output_folder)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-o",
        "--output_folder",
        type=str,
        default="/data/clap-gpt-pretrain",
    )
    parser.add_argument(
        "-c",
        "--cache_folder",
        type=str,
        default="/data/clap-gpt-pretrain-cache",
    )
    parser.add_argument("-n", "--num_examples", type=int, default=500_000)
    parser.add_argument("-p", "--num_proc", type=int, default=10)
    args = parser.parse_args()
    main(args)
src/sonicverse/scripts/document_build_finetune_dataset.py
ADDED
@@ -0,0 +1,162 @@
from typing import List
import argparse
import re
import glob
import json

from datasets import load_dataset
from datasets import Dataset

from multi_token.constants import ROLE_ASSISTANT, ROLE_USER
from multi_token.modalities.document_gte import (
    split_text_into_documents,
)

TEMP_TOKEN = "<<<TEMP-TOKEN>>>"

# regex, doc, prompt
LONG_ALPACA_REGEXES = [
    (
        r"Below is a paper. Memorize the paper and answer my question after the paper.\n The paper begins. \n ([\s\S]+) \n Now the paper ends. \n([\s\S]+)",
        lambda m: m.group(1),
        lambda m: f"Read the paper {TEMP_TOKEN}. {m.group(2)}",
    ),
    (
        r"Below is a paper. Memorize the material and answer my question after the paper.\n([\s\S]+)\n Now the material ends. ([\s\S]+)",
        lambda m: m.group(1),
        lambda m: f"Read the paper {TEMP_TOKEN}. {m.group(2)}",
    ),
    (
        r"There are two papers. Memorize them and answer my question after the paper.\n The first paper begins. \n ([\s\S]+) Now the second paper ends.([\s\S]+)",
        lambda m: m.group(1),
        lambda m: f"Read the papers {TEMP_TOKEN}. {m.group(2)}",
    ),
    (
        r"Below is some paragraphs in the book, ([\s\S]+?). Memorize the content and answer my question after the book.\n([\s\S]+) \n Now the material ends.([\s\S]+)",
        lambda m: m.group(2),
        lambda m: f"Read the book {m.group(1)} {TEMP_TOKEN}. {m.group(3)}",
    ),
]

# regex, doc, prompt, answer
LONG_DATA_REGEXES = [
    (
        r"Write a high-quality answer for the given question using only the provided search results \(some of which might be irrelevant\).([\s\S]+)Question: ([\s\S]+)Answer: ([\s\S]+)\nLong Answer: ([\s\S]+)",
        lambda m: m.group(1).strip(),
        lambda m: f"Write a high-quality answer for the given question using only the provided search results {TEMP_TOKEN}. {m.group(2).strip()}",
        lambda m: m.group(4).strip(),
    ),
    (
        r"([\s\S]+)\nQ: ([\s\S]+)\nA: ([\s\S]+)",
        lambda m: m.group(1).strip(),
        lambda m: f"Read the following book {TEMP_TOKEN}. {m.group(2).strip()}",
        lambda m: m.group(3).strip(),
    ),
]


def _write_long_alpaca_convo(row, max_document_chunks) -> List:
    doc_text = None
    prompt = None
    for regex, get_doc, get_prompt in LONG_ALPACA_REGEXES:
        match = re.match(regex, row["instruction"])
        if match:
            doc_text = get_doc(match)
            prompt = get_prompt(match).replace("Question: ", "")
            break

    if doc_text is None and row["input"]:
        doc_text = row["input"]
        prompt = row["instruction"] + f" {TEMP_TOKEN}"

    if doc_text is None:
        raise ValueError("No document found")

    docs = split_text_into_documents(doc_text)
    if len(docs) > max_document_chunks:
        raise ValueError("Document too long")
    example = {
        "id": "longalpaca-" + str(hash(row["instruction"])),
        "documents": docs,
    }
    example["messages"] = [
        {
            "role": ROLE_USER,
            "content": prompt.replace(TEMP_TOKEN, "<document>" * len(docs)),
        },
        {
            "role": ROLE_ASSISTANT,
            "content": row["output"].replace("Answer: ", ""),
        },
    ]
    return example


def _write_long_data_collections_convo(row, max_document_chunks) -> List:
    doc_text = None
    prompt = None
    answer = None
    for regex, get_doc, get_prompt, get_answer in LONG_DATA_REGEXES:
        match = re.match(regex, row["text"])
        if match:
            doc_text = get_doc(match)
            prompt = get_prompt(match)
            answer = get_answer(match).replace(" .", ".")
            break

    if not doc_text or not prompt or not answer:
        raise ValueError("No document found")

    docs = split_text_into_documents(doc_text)
    if len(docs) > max_document_chunks:
        raise ValueError("Document too long")
    example = {
        "id": "longdatacollection-" + str(hash(row["text"])),
        "documents": docs,
    }
    example["messages"] = [
        {
            "role": ROLE_USER,
            "content": prompt.replace(TEMP_TOKEN, "<document>" * len(docs)),
        },
        {
            "role": ROLE_ASSISTANT,
            "content": answer,
        },
    ]
    return example


def main(args):
    long_alpaca = load_dataset(args.long_alpaca_path, "train")["train"]

    def gen():
        for row in long_alpaca:
            try:
                yield _write_long_alpaca_convo(row, args.max_document_chunks)
            except ValueError:
                continue
        for long_collection_fn in glob.iglob(args.long_collections_glob):
            with open(long_collection_fn) as f:
                for line in f:
                    row = json.loads(line)
                    try:
                        yield _write_long_data_collections_convo(
                            row, args.max_document_chunks
                        )
                    except ValueError:
                        continue

    ds = Dataset.from_generator(gen)
    ds = ds.shuffle(seed=42)
    ds.save_to_disk(args.output_folder)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--long_alpaca_path", type=str, default="Yukang/LongAlpaca-12k")
    parser.add_argument("--long_collections_glob", type=str)
    parser.add_argument("-o", "--output_folder", type=str)
    parser.add_argument("-c", "--max_document_chunks", type=int, default=256)
    args = parser.parse_args()
    main(args)
src/sonicverse/scripts/document_build_pretrain_dataset.py
ADDED
@@ -0,0 +1,89 @@
from typing import List
import random
import argparse

from datasets import load_dataset
from datasets import Dataset

from multi_token.constants import ROLE_ASSISTANT, ROLE_USER
from multi_token.modalities.document_gte import (
    split_text_into_documents,
)

TEMP_TOKEN = "<<<TEMP-TOKEN>>>"

PRETRAIN_PHRASES = [
    f"Repeat the content of the document {TEMP_TOKEN}",
    f"Transcribe {TEMP_TOKEN}",
    f"Provide a verbatim transcription of {TEMP_TOKEN}",
    f"Write down exactly what is in {TEMP_TOKEN}",
    f"Copy the text from {TEMP_TOKEN}",
    f"Duplicate the content of {TEMP_TOKEN}",
    f"Reproduce the text in {TEMP_TOKEN}",
    f"Render the exact text from {TEMP_TOKEN}",
    f"Echo the content of {TEMP_TOKEN}",
    f"Mirror the text in {TEMP_TOKEN}",
    f"Reflect the content of {TEMP_TOKEN}",
    f"Transcribe the exact words from {TEMP_TOKEN}",
    f"Write out the exact content of {TEMP_TOKEN}",
    f"Provide a direct transcription of {TEMP_TOKEN}",
    f"Give a word-for-word account of {TEMP_TOKEN}",
    f"Reiterate the exact text of {TEMP_TOKEN}",
    f"Replicate the content of {TEMP_TOKEN}",
    f"Reprint the text from {TEMP_TOKEN}",
    f"Rewrite the exact words from {TEMP_TOKEN}",
]


def _write_convo(row, max_document_chunks) -> List:
    docs = split_text_into_documents(row["text"])
    if len(docs) > max_document_chunks:
        raise ValueError("Document too long")
    example = {
        "id": str(row["title"]),
        "documents": docs,
    }
    phrase = random.choice(PRETRAIN_PHRASES)
    example["messages"] = [
        {
            "role": ROLE_USER,
            "content": phrase.replace(TEMP_TOKEN, "<document>" * len(docs)),
        },
        {
            "role": ROLE_ASSISTANT,
            "content": row["text"],
        },
    ]
    return example


def main(args):
    wiki_data = load_dataset("graelo/wikipedia", "20230601.en")["train"]

    idxs = list(range(len(wiki_data)))
    random.shuffle(idxs)

    def gen():
        i = 0
        for idx in idxs:
            row = wiki_data[idx]
            try:
                yield _write_convo(row, args.max_document_chunks)
            except ValueError:
                pass
            else:
                i += 1
                if i >= args.max_examples:
                    break

    ds = Dataset.from_generator(gen)
    ds.save_to_disk(args.output_folder)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-o", "--output_folder", type=str)
    parser.add_argument("-n", "--max_examples", type=int, default=1_000_000)
    parser.add_argument("-c", "--max_document_chunks", type=int, default=4)
    args = parser.parse_args()
    main(args)
src/sonicverse/scripts/document_setup.sh
ADDED
@@ -0,0 +1,5 @@
#!/bin/bash

pip install nltk

python -c "import nltk; nltk.download('punkt')"
src/sonicverse/scripts/evaluate_model.py
ADDED
@@ -0,0 +1,112 @@
from dataclasses import dataclass, field
import logging

from flask import Flask, request, jsonify
import transformers
import torch

from datasets import load_from_disk

from multi_token.model_utils import MultiTaskType
from multi_token.training import (
    ModelArguments,
)
from multi_token.inference import load_trained_lora_model
from multi_token.data_tools import encode_chat

import evaluate

import random

PRETRAIN_PHRASES = [
    "What is happening in the given music <sound>?",
    "Describe the sound. <sound>",
    "Describe the music. <sound>",
    "<sound> Provide a description of the music.",
    "<sound> Provide a description of the sound.",
    "Can you interpret <sound>?",
    "Please explain what's happening in <sound>",
    "What does <sound> represent?",
    "Could you describe <sound> for me?",
    "What's the content of <sound>?",
    "Can you depict <sound>?",
    "What is <sound>?",
    "In the music clip, <sound>, what is happening?",
    "Provide a description of the music. <sound>",
    "Provide a description of the sound. <sound>",
    "Provide a caption for the sound. <sound>",
    "Provide a caption for the music. <sound>",
]


@dataclass
class ServeArguments(ModelArguments):
    port: int = field(default=8080)
    host: str = field(default="0.0.0.0")
    load_bits: int = field(default=16)
    max_new_tokens: int = field(default=128)
    temperature: float = field(default=0.01)


def generate(input_json):
    encoded_dict = encode_chat(input_json, tokenizer, model.modalities)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids=encoded_dict["input_ids"].unsqueeze(0).to(model.device),
            max_new_tokens=serve_args.max_new_tokens,
            use_cache=True,
            do_sample=True,
            temperature=serve_args.temperature,
            modality_inputs={
                m.name: [encoded_dict[m.name]] for m in model.modalities
            },
        )

    outputs = tokenizer.decode(
        output_ids[0, encoded_dict["input_ids"].shape[0] :],
        skip_special_tokens=True,
    ).strip()

    return {"output": outputs}


if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)

    parser = transformers.HfArgumentParser((ServeArguments,))

    serve_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)

    dataset_path = "/data/musicbench_multitoken_official_split/val"

    ds = load_from_disk(dataset_path)

    model, tokenizer = load_trained_lora_model(
        model_name_or_path=serve_args.model_name_or_path,
        model_lora_path=serve_args.model_lora_path,
        load_bits=serve_args.load_bits,
        use_multi_task=MultiTaskType(serve_args.use_multi_task),
        tasks_config=serve_args.tasks_config,
    )

    predictions = []
    references = []
    content_phrase = random.choice(PRETRAIN_PHRASES)
    for data_point_id in range(100):
        data_point = ds[data_point_id]
        # print("datapoint", data_point)
        input_json = {"messages": [{"role": "user", "content": content_phrase}], "sounds": data_point["sounds"]}
        output_json = generate(input_json)

        print("Prediction ", output_json["output"])
        print("Reference ", data_point["messages"][1]["content"])
        print()
        print()
        predictions.append(output_json["output"])
        references.append(data_point["messages"][1]["content"])

    sacrebleu = evaluate.load("sacrebleu")
    sacrebleu_results = sacrebleu.compute(predictions=predictions, references=references)

    print(sacrebleu_results["score"])
src/sonicverse/scripts/evaluate_model_latest.py
ADDED
@@ -0,0 +1,127 @@
from dataclasses import dataclass, field
import logging
from flask import Flask, request, jsonify
import transformers
import torch
from datasets import load_from_disk
from multi_token.model_utils import MultiTaskType
from multi_token.training import ModelArguments
from multi_token.inference import load_trained_lora_model
from multi_token.data_tools import encode_chat
import evaluate
import random
import bert_score
from tqdm import tqdm

PRETRAIN_PHRASES_OLD = [
    "Describe the audio in detail"
]

PRETRAIN_PHRASES = [
    "What is happening in the given music <sound>?",
    "Describe the sound. <sound>",
    "Describe the music. <sound>",
    "<sound> Provide a description of the music.",
    "<sound> Provide a description of the sound.",
    "Can you interpret <sound>?",
    "Please explain what's happening in <sound>",
    "What does <sound> represent?",
    "Could you describe <sound> for me?",
    "What's the content of <sound>?",
    "Can you depict <sound>?",
    "What is <sound>?",
    "In the music clip, <sound>, what is happening?",
    "Provide a description of the music. <sound>",
    "Provide a description of the sound. <sound>",
    "Provide a caption for the sound. <sound>",
    "Provide a caption for the music. <sound>",
]

random.seed(1234)

@dataclass
class ServeArguments(ModelArguments):
    port: int = field(default=8080)
    host: str = field(default="0.0.0.0")
    load_bits: int = field(default=16)
    max_new_tokens: int = field(default=128)
    temperature: float = field(default=0.01)

def generate(input_json):
    encoded_dict = encode_chat(input_json, tokenizer, model.modalities)
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids=encoded_dict["input_ids"].unsqueeze(0).to(model.device),
            max_new_tokens=serve_args.max_new_tokens,
            use_cache=True,
            do_sample=True,
            temperature=serve_args.temperature,
            modality_inputs={
                m.name: [encoded_dict[m.name]] for m in model.modalities
            },
        )
    outputs = tokenizer.decode(
        output_ids[0, encoded_dict["input_ids"].shape[0]:],
        skip_special_tokens=True,
    ).strip()
    return {"output": outputs}

if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    parser = transformers.HfArgumentParser((ServeArguments,))
    serve_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    dataset_path = "/data/musicbench_multitoken_official_split/val"
    ds = load_from_disk(dataset_path)
    shuffled_ds = ds.shuffle(seed=1234)
    model, tokenizer = load_trained_lora_model(
        model_name_or_path=serve_args.model_name_or_path,
        model_lora_path=serve_args.model_lora_path,
        load_bits=serve_args.load_bits,
        use_multi_task=MultiTaskType(serve_args.use_multi_task),
        tasks_config=serve_args.tasks_config
    )

    predictions = []
    references = []
    content_phrase = random.choice(PRETRAIN_PHRASES)
    # for data_point_id in range(len(ds)):
    for data_point_id in tqdm(range(10)):
        data_point = shuffled_ds[data_point_id]
        input_json = {"messages": [{"role": "user", "content": content_phrase}], "sounds": data_point["sounds"]}
        output_json = generate(input_json)
        print("Prediction ", output_json["output"])
        print("Reference ", data_point["messages"][1]["content"])
        print()
        print()
        predictions.append(output_json["output"])
        references.append(data_point["messages"][1]["content"])

    # Load evaluation metrics
    bleu = evaluate.load("bleu")
    meteor = evaluate.load("meteor")
    rouge = evaluate.load("rouge")

    # Compute BLEU scores
    bleu_results = bleu.compute(predictions=predictions, references=references, max_order=4)
    print(bleu_results)
    # bleu_score = sum(bleu_results[f"bleu{i}"] for i in range(1, 5)) / 4

    # Compute METEOR score
    meteor_results = meteor.compute(predictions=predictions, references=references)
    meteor_score = meteor_results["meteor"]

    # Compute ROUGE-L score
    rouge_results = rouge.compute(predictions=predictions, references=references, rouge_types=["rougeL"])
    rouge_l_score = rouge_results["rougeL"].mid.fmeasure
    # print(rouge_results)

    # Compute BERT-Score
    P, R, F1 = bert_score.score(predictions, references, lang="en", rescale_with_baseline=True)
    bert_score_f1 = F1.mean().item()

    # Print results
    # print(f"BLEU Score: {bleu_score}")
    print(f"METEOR Score: {meteor_score}")
    # print(f"ROUGE-L Score: {rouge_l_score}")
    print(f"BERT-Score F1: {bert_score_f1}")
src/sonicverse/scripts/evaluate_model_mullama.py
ADDED
@@ -0,0 +1,168 @@
from dataclasses import dataclass, field
import logging
from flask import Flask, request, jsonify
import transformers
import torch
from datasets import load_from_disk
from multi_token.model_utils import MultiTaskType
from multi_token.training import ModelArguments
from multi_token.inference import load_trained_lora_model
from multi_token.data_tools import encode_chat
import evaluate
import random
import bert_score
from tqdm import tqdm

from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.meteor_score import meteor_score as meteor_scorer
from nltk.tokenize import wordpunct_tokenize
import json
from bert_score import score
from tqdm.auto import tqdm

import yaml

scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)


PRETRAIN_PHRASES_OLD = [
    "Describe the audio in detail"
]

PRETRAIN_PHRASES = [
    "What is happening in the given music <sound>?",
    "Describe the sound. <sound>",
    "Describe the music. <sound>",
    "<sound> Provide a description of the music.",
    "<sound> Provide a description of the sound.",
    "Can you interpret <sound>?",
    "Please explain what's happening in <sound>",
    "What does <sound> represent?",
    "Could you describe <sound> for me?",
    "What's the content of <sound>?",
    "Can you depict <sound>?",
    "What is <sound>?",
    "In the music clip, <sound>, what is happening?",
    "Provide a description of the music. <sound>",
    "Provide a description of the sound. <sound>",
    "Provide a caption for the sound. <sound>",
    "Provide a caption for the music. <sound>",
]

random.seed(1234)

@dataclass
class ServeArguments(ModelArguments):
    port: int = field(default=8080)
    host: str = field(default="0.0.0.0")
    load_bits: int = field(default=16)
    max_new_tokens: int = field(default=128)
    temperature: float = field(default=0.01)

def generate(input_json):
    encoded_dict = encode_chat(input_json, tokenizer, model.modalities)
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids=encoded_dict["input_ids"].unsqueeze(0).to(model.device),
            max_new_tokens=serve_args.max_new_tokens,
            use_cache=True,
            do_sample=True,
            temperature=serve_args.temperature,
            modality_inputs={
                m.name: [encoded_dict[m.name]] for m in model.modalities
            },
        )
    outputs = tokenizer.decode(
        output_ids[0, encoded_dict["input_ids"].shape[0]:],
        skip_special_tokens=True,
    ).strip()
    return {"output": outputs}

# Note: this function shadows the imported `evaluate` module; the module is only
# used by the commented-out metric code at the bottom of this script.
def evaluate(candidates, mult_reference):
    rouge_score, bleu_score, bleu4_score, meteor_score = 0, 0, 0, 0
    for ref, cand in tqdm(zip(mult_reference, candidates), total=len(mult_reference)):
        rouge_score += scorer.score(ref, cand)['rougeL'].recall
        cand_split = wordpunct_tokenize(cand)
        ref_split = wordpunct_tokenize(ref)
        bleu4_score += sentence_bleu([ref], cand, weights=(0.0, 0.0, 0.0, 1.0))
        bleu_score += sentence_bleu([ref], cand)
        meteor_score += meteor_scorer([ref_split], cand_split)
    rouge_score, bleu_score, bleu4_score, meteor_score = rouge_score / (len(candidates)), bleu_score / (len(candidates)), bleu4_score / (len(candidates)), meteor_score / (len(candidates))
    P, R, F1 = score(candidates, mult_reference, lang="en", verbose=True)
    bert_score = R.mean().item()
    # print(f"Model: {model_name}")
    print(f"BLEU Score: {bleu_score}")
    print(f"BLEU-4 Score: {bleu4_score}")
    print(f"METEOR Score: {meteor_score}")
    print(f"ROUGE Score: {rouge_score}")
    print(f"BERT Score: {bert_score}")

if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    parser = transformers.HfArgumentParser((ServeArguments,))
    serve_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    dataset_path = "/data/musicbench_multitoken_official_split/val"
    ds = load_from_disk(dataset_path)
    shuffled_ds = ds.shuffle(seed=1234)
    model, tokenizer = load_trained_lora_model(
        model_name_or_path=serve_args.model_name_or_path,
        model_lora_path=serve_args.model_lora_path,
        load_bits=serve_args.load_bits,
        use_multi_task=MultiTaskType(serve_args.use_multi_task),
        tasks_config=serve_args.tasks_config
    )

    predictions = []
    references = []
    content_phrase = random.choice(PRETRAIN_PHRASES)
    # for data_point_id in range(len(ds)):
    print("len(ds)", len(ds))
    for data_point_id in tqdm(range(100)):
        # for data_point_id in tqdm(range(6831)):
        data_point = shuffled_ds[data_point_id]
        input_json = {"messages": [{"role": "user", "content": content_phrase}], "sounds": data_point["sounds"]}
        output_json = generate(input_json)
        # print("Prediction ", output_json["output"])
        # print("Reference ", data_point["messages"][1]["content"])
        # print()
        # print()
        predictions.append(output_json["output"])
        references.append(data_point["messages"][1]["content"])

    pairs = {"predictions": predictions, "references": references}

    evaluate(predictions, references)

    # with open('/experiments/captioning/mert_tasks_separate_backbone_train_001_ft/checkpoint_1985_test/val_2.yaml', 'w') as file:
    #     yaml.dump(pairs, file, default_flow_style=False)

    # Load evaluation metrics
    # bleu = evaluate.load("bleu")
    # meteor = evaluate.load("meteor")
    # rouge = evaluate.load("rouge")

    # Compute BLEU scores
    # bleu_results = bleu.compute(predictions=predictions, references=references, max_order=4)
    # print(bleu_results)
    # bleu_score = sum(bleu_results[f"bleu{i}"] for i in range(1, 5)) / 4

    # Compute METEOR score
    # meteor_results = meteor.compute(predictions=predictions, references=references)
    # meteor_score = meteor_results["meteor"]

    # Compute ROUGE-L score
    # rouge_results = rouge.compute(predictions=predictions, references=references, rouge_types=["rougeL"])
    # rouge_l_score = rouge_results["rougeL"].mid.fmeasure
    # print(rouge_results)

    # Compute BERT-Score
    # P, R, F1 = bert_score.score(predictions, references, lang="en", rescale_with_baseline=True)
    # bert_score_f1 = F1.mean().item()

    # Print results
    # print(f"BLEU Score: {bleu_score}")
    # print(f"METEOR Score: {meteor_score}")
    # print(f"ROUGE-L Score: {rouge_l_score}")
    # print(f"BERT-Score F1: {bert_score_f1}")
src/sonicverse/scripts/evaluate_model_mullama_musiccaps.py
ADDED
@@ -0,0 +1,143 @@
+
+from dataclasses import dataclass, field
+import logging
+from flask import Flask, request, jsonify
+import transformers
+import torch
+from datasets import load_from_disk
+from multi_token.model_utils import MultiTaskType
+from multi_token.training import ModelArguments
+from multi_token.inference import load_trained_lora_model
+from multi_token.data_tools import encode_chat
+import evaluate
+import random
+import bert_score
+from tqdm import tqdm
+
+from rouge_score import rouge_scorer
+from nltk.translate.bleu_score import sentence_bleu
+from nltk.translate.meteor_score import meteor_score as meteor_scorer
+from nltk.tokenize import wordpunct_tokenize
+import json
+from bert_score import score
+from tqdm.auto import tqdm
+
+import yaml
+
+scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
+
+
+PRETRAIN_PHRASES_OLD = [
+    "Describe the audio in detail"
+]
+
+PRETRAIN_PHRASES = [
+    "What is happening in the given music <sound>?",
+    "Describe the sound. <sound>",
+    "Describe the music. <sound>",
+    "<sound> Provide a description of the music.",
+    "<sound> Provide a description of the sound.",
+    "Can you interpret <sound>?",
+    "Please explain what's happening in <sound>",
+    "What does <sound> represent?",
+    "Could you describe <sound> for me?",
+    "What's the content of <sound>?",
+    "Can you depict <sound>?",
+    "What is <sound>?",
+    "In the music clip, <sound>, what is happening?",
+    "Provide a description of the music. <sound>",
+    "Provide a description of the sound. <sound>",
+    "Provide a caption for the sound. <sound>",
+    "Provide a caption for the music. <sound>",
+]
+
+random.seed(1234)
+
+@dataclass
+class ServeArguments(ModelArguments):
+    port: int = field(default=8080)
+    host: str = field(default="0.0.0.0")
+    load_bits: int = field(default=16)
+    max_new_tokens: int = field(default=128)
+    temperature: float = field(default=0.01)
+
+def generate(input_json):
+    encoded_dict = encode_chat(input_json, tokenizer, model.modalities)
+    with torch.inference_mode():
+        output_ids = model.generate(
+            input_ids=encoded_dict["input_ids"].unsqueeze(0).to(model.device),
+            max_new_tokens=serve_args.max_new_tokens,
+            use_cache=True,
+            do_sample=True,
+            temperature=serve_args.temperature,
+            modality_inputs={
+                m.name: [encoded_dict[m.name]] for m in model.modalities
+            },
+        )
+    outputs = tokenizer.decode(
+        output_ids[0, encoded_dict["input_ids"].shape[0]:],
+        skip_special_tokens=True,
+    ).strip()
+    return {"output": outputs}
+
+def evaluate(candidates, mult_reference):
+    rouge_score, bleu_score, bleu4_score, meteor_score = 0, 0, 0, 0
+    for ref, cand in tqdm(zip(mult_reference, candidates), total=len(mult_reference)):
+        rouge_score += scorer.score(ref, cand)['rougeL'].recall
+        cand_split = wordpunct_tokenize(cand)
+        ref_split = wordpunct_tokenize(ref)
+        bleu4_score += sentence_bleu([ref], cand, weights=(0.0, 0.0, 0.0, 1.0))
+        bleu_score += sentence_bleu([ref], cand)
+        meteor_score += meteor_scorer([ref_split], cand_split)
+    rouge_score, bleu_score, bleu4_score, meteor_score = rouge_score / (len(candidates)), bleu_score / (len(candidates)), bleu4_score / (len(candidates)), meteor_score / (len(candidates))
+    P, R, F1 = score(candidates, mult_reference, lang="en", verbose=True)
+    bert_score = R.mean().item()
+    #print(f"Model: {model_name}")
+    print(f"BLEU Score: {bleu_score}")
+    print(f"BLEU-4 Score: {bleu4_score}")
+    print(f"METEOR Score: {meteor_score}")
+    print(f"ROUGE Score: {rouge_score}")
+    print(f"BERT Score: {bert_score}")
+
+if __name__ == "__main__":
+    logging.getLogger().setLevel(logging.INFO)
+    parser = transformers.HfArgumentParser((ServeArguments,))
+    serve_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
+    # dataset_path = "/data/musiccaps/musiccaps_val"
+    dataset_path = "/data/musicbench_multitoken_official_split/val/"
+    ds = load_from_disk(dataset_path)
+    shuffled_ds = ds.shuffle(seed=1234)
+    model, tokenizer = load_trained_lora_model(
+        model_name_or_path=serve_args.model_name_or_path,
+        model_lora_path=serve_args.model_lora_path,
+        load_bits=serve_args.load_bits,
+        use_multi_task=MultiTaskType(serve_args.use_multi_task),
+        tasks_config=serve_args.tasks_config
+    )
+
+    predictions = []
+    references = []
+    content_phrase = random.choice(PRETRAIN_PHRASES)
+    # for data_point_id in range(len(ds)):
+    print("len(ds)", len(ds))
+    #for data_point in tqdm(ds):
+    for data_point_id in tqdm(range(100)):
+        #print("DATA POINT ", data_point)
+        data_point = ds[data_point_id]
+        print("DATA POINT ", data_point)
+        input_json = {"messages": [{"role": "user", "content": content_phrase}], "sounds": data_point["sounds"]}
+        output_json = generate(input_json)
+        #print("Prediction ", output_json["output"])
+        #print("Reference ", data_point["caption"])
+        #print()
+        #print()
+        predictions.append(output_json["output"])
+        references.append(data_point["messages"][1]["content"])
+
+    pairs = {"predictions": predictions, "references": references}
+
+    evaluate(predictions, references)
+
+    with open('test/musicbench_eval.yaml', 'w') as file:
+        yaml.dump(pairs, file, default_flow_style=False)
+
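Note: the evaluate() helper above averages sentence-level ROUGE-L recall, BLEU, BLEU-4, and METEOR over the prediction/reference pairs and reports the mean BERTScore recall for the corpus. A minimal standalone sketch of that per-pair averaging is shown below (example only, not part of the commit; the two caption strings are placeholders, and NLTK's wordnet data must be installed for METEOR).

# Standalone sketch of the per-pair metric averaging done in evaluate() above.
# Requires: pip install rouge-score nltk, plus nltk.download("wordnet").
from rouge_score import rouge_scorer
from nltk.translate.meteor_score import meteor_score
from nltk.tokenize import wordpunct_tokenize

predictions = ["a calm piano melody with soft strings"]
references = ["a gentle piano piece accompanied by strings"]

scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
rouge_l = sum(scorer.score(ref, cand)["rougeL"].recall
              for ref, cand in zip(references, predictions)) / len(predictions)
meteor = sum(meteor_score([wordpunct_tokenize(ref)], wordpunct_tokenize(cand))
             for ref, cand in zip(references, predictions)) / len(predictions)
print(f"ROUGE-L recall: {rouge_l:.3f}  METEOR: {meteor:.3f}")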
src/sonicverse/scripts/evaluate_model_mullama_musiccaps_fixed_prompt.py
ADDED
@@ -0,0 +1,138 @@
+from dataclasses import dataclass, field
+import logging
+from flask import Flask, request, jsonify
+import transformers
+import torch
+from datasets import load_from_disk
+from multi_token.model_utils import MultiTaskType
+from multi_token.training import ModelArguments
+from multi_token.inference import load_trained_lora_model
+from multi_token.data_tools import encode_chat
+import evaluate
+import random
+import bert_score
+from tqdm import tqdm
+
+from rouge_score import rouge_scorer
+from nltk.translate.bleu_score import sentence_bleu
+from nltk.translate.meteor_score import meteor_score as meteor_scorer
+from nltk.tokenize import wordpunct_tokenize
+import json
+from bert_score import score
+from tqdm.auto import tqdm
+
+import yaml
+
+scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
+
+
+PRETRAIN_PHRASES_OLD = [
+    "Describe the audio in detail"
+]
+
+PRETRAIN_PHRASES = [
+    # "What is happening in the given music <sound>?",
+    # "Describe the sound. <sound>",
+    # "Describe the music. <sound>",
+    # "<sound> Provide a description of the music.",
+    # "<sound> Provide a description of the sound.",
+    # "Can you interpret <sound>?",
+    # "Please explain what's happening in <sound>",
+    # "What does <sound> represent?",
+    # "Could you describe <sound> for me?",
+    # "What's the content of <sound>?",
+    # "Can you depict <sound>?",
+    # "What is <sound>?",
+    # "In the music clip, <sound>, what is happening?",
+    # "Provide a description of the music. <sound>",
+    # "Provide a description of the sound. <sound>",
+    # "Provide a caption for the sound. <sound>",
+    "Provide a caption for the music. <sound>",
+]
+
+random.seed(1234)
+
+@dataclass
+class ServeArguments(ModelArguments):
+    port: int = field(default=8080)
+    host: str = field(default="0.0.0.0")
+    load_bits: int = field(default=16)
+    max_new_tokens: int = field(default=128)
+    temperature: float = field(default=0.01)
+
+def generate(input_json):
+    encoded_dict = encode_chat(input_json, tokenizer, model.modalities)
+    with torch.inference_mode():
+        output_ids = model.generate(
+            input_ids=encoded_dict["input_ids"].unsqueeze(0).to(model.device),
+            max_new_tokens=serve_args.max_new_tokens,
+            use_cache=True,
+            do_sample=True,
+            temperature=serve_args.temperature,
+            modality_inputs={
+                m.name: [encoded_dict[m.name]] for m in model.modalities
+            },
+        )
+    outputs = tokenizer.decode(
+        output_ids[0, encoded_dict["input_ids"].shape[0]:],
+        skip_special_tokens=True,
+    ).strip()
+    return {"output": outputs}
+
+def evaluate(candidates, mult_reference):
+    rouge_score, bleu_score, bleu4_score, meteor_score = 0, 0, 0, 0
+    for ref, cand in tqdm(zip(mult_reference, candidates), total=len(mult_reference)):
+        rouge_score += scorer.score(ref, cand)['rougeL'].recall
+        cand_split = wordpunct_tokenize(cand)
+        ref_split = wordpunct_tokenize(ref)
+        bleu4_score += sentence_bleu([ref], cand, weights=(0.0, 0.0, 0.0, 1.0))
+        bleu_score += sentence_bleu([ref], cand)
+        meteor_score += meteor_scorer([ref_split], cand_split)
+    rouge_score, bleu_score, bleu4_score, meteor_score = rouge_score / (len(candidates)), bleu_score / (len(candidates)), bleu4_score / (len(candidates)), meteor_score / (len(candidates))
+    P, R, F1 = score(candidates, mult_reference, lang="en", verbose=True)
+    bert_score = R.mean().item()
+    #print(f"Model: {model_name}")
+    print(f"BLEU Score: {bleu_score}")
+    print(f"BLEU-4 Score: {bleu4_score}")
+    print(f"METEOR Score: {meteor_score}")
+    print(f"ROUGE Score: {rouge_score}")
+    print(f"BERT Score: {bert_score}")
+
+if __name__ == "__main__":
+    logging.getLogger().setLevel(logging.INFO)
+    parser = transformers.HfArgumentParser((ServeArguments,))
+    serve_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
+    dataset_path = "/data/musiccaps/musiccaps_val"
+    ds = load_from_disk(dataset_path)
+    shuffled_ds = ds.shuffle(seed=1234)
+    model, tokenizer = load_trained_lora_model(
+        model_name_or_path=serve_args.model_name_or_path,
+        model_lora_path=serve_args.model_lora_path,
+        load_bits=serve_args.load_bits,
+        use_multi_task=MultiTaskType(serve_args.use_multi_task),
+        tasks_config=serve_args.tasks_config
+    )
+
+    predictions = []
+    references = []
+    content_phrase = random.choice(PRETRAIN_PHRASES)
+    # for data_point_id in range(len(ds)):
+    print("len(ds)", len(ds))
+    for data_point in tqdm(ds):
+        print(data_point["audio"])
+        # data_point = ds[data_point_id]
+        input_json = {"messages": [{"role": "user", "content": content_phrase}], "sounds": [data_point["audio"]]}
+        output_json = generate(input_json)
+        print("Prediction ", output_json["output"])
+        print("Reference ", data_point["caption"])
+        print()
+        print()
+        predictions.append(output_json["output"])
+        references.append(data_point["caption"])
+
+    pairs = {"predictions": predictions, "references": references}
+
+    evaluate(predictions, references)
+
+    with open('/experiments/captioning/mert_tasks_separate_backbone_train_001_ft/checkpoint_1985_test/musiccaps_val_fixed_prompt.yaml', 'w') as file:
+        yaml.dump(pairs, file, default_flow_style=False)
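Note: this fixed-prompt variant differs from the script above only in that it pins the prompt to a single caption request and takes one audio path per example from data_point["audio"]. For reference, a sketch of the payload that generate() consumes in this script (example only; the .wav path is a placeholder, not a file from the dataset).

# Shape of the request passed to generate() in this script: a single-turn chat
# whose prompt contains the <sound> placeholder, plus a list with one audio path.
input_json = {
    "messages": [{"role": "user", "content": "Provide a caption for the music. <sound>"}],
    "sounds": ["/data/musiccaps/audio/example_clip.wav"],  # placeholder path
}
# output_json = generate(input_json)  # -> {"output": "<generated caption>"}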
src/sonicverse/scripts/evaluate_mullama.py
ADDED
@@ -0,0 +1,115 @@
+from dataclasses import dataclass, field
+import logging
+from flask import Flask, request, jsonify
+import transformers
+import torch
+from datasets import load_from_disk
+from multi_token.model_utils import MultiTaskType
+from multi_token.training import ModelArguments
+from multi_token.inference import load_trained_lora_model
+from multi_token.data_tools import encode_chat
+import evaluate
+import random
+import bert_score
+
+PRETRAIN_PHRASES = [
+    "What is happening in the given music <sound>?",
+    "Describe the sound. <sound>",
+    "Describe the music. <sound>",
+    "<sound> Provide a description of the music.",
+    "<sound> Provide a description of the sound.",
+    "Can you interpret <sound>?",
+    "Please explain what's happening in <sound>",
+    "What does <sound> represent?",
+    "Could you describe <sound> for me?",
+    "What's the content of <sound>?",
+    "Can you depict <sound>?",
+    "What is <sound>?",
+    "In the music clip, <sound>, what is happening?",
+    "Provide a description of the music. <sound>",
+    "Provide a description of the sound. <sound>",
+    "Provide a caption for the sound. <sound>",
+    "Provide a caption for the music. <sound>",
+]
+
+@dataclass
+class ServeArguments(ModelArguments):
+    port: int = field(default=8080)
+    host: str = field(default="0.0.0.0")
+    load_bits: int = field(default=16)
+    max_new_tokens: int = field(default=128)
+    temperature: float = field(default=0.01)
+
+def generate(input_json):
+    encoded_dict = encode_chat(input_json, tokenizer, model.modalities)
+    with torch.inference_mode():
+        output_ids = model.generate(
+            input_ids=encoded_dict["input_ids"].unsqueeze(0).to(model.device),
+            max_new_tokens=serve_args.max_new_tokens,
+            use_cache=True,
+            do_sample=True,
+            temperature=serve_args.temperature,
+            modality_inputs={
+                m.name: [encoded_dict[m.name]] for m in model.modalities
+            },
+        )
+    outputs = tokenizer.decode(
+        output_ids[0, encoded_dict["input_ids"].shape[0]:],
+        skip_special_tokens=True,
+    ).strip()
+    return {"output": outputs}
+
+if __name__ == "__main__":
+    logging.getLogger().setLevel(logging.INFO)
+    parser = transformers.HfArgumentParser((ServeArguments,))
+    serve_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
+    dataset_path = "/data/musicbench_multitoken_official_split/val"
+    ds = load_from_disk(dataset_path)
+
+    # Load MU-LLaMA model and tokenizer
+    model_name_or_path = "mu-llama/MU-LLaMA"
+    model = transformers.LlamaForCausalLM.from_pretrained(model_name_or_path)
+    tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name_or_path)
+
+    predictions = []
+    references = []
+    content_phrase = random.choice(PRETRAIN_PHRASES)
+    for data_point_id in range(100):
+        data_point = ds[data_point_id]
+        input_json = {"messages": [{"role": "user", "content": content_phrase}], "sounds": data_point["sounds"]}
+        output_json = generate(input_json)
+        print("Prediction ", output_json["output"])
+        print("Reference ", data_point["messages"][1]["content"])
+        print()
+        print()
+        predictions.append(output_json["output"])
+        references.append(data_point["messages"][1]["content"])
+
+    # Load evaluation metrics
+    bleu = evaluate.load("bleu")
+    meteor = evaluate.load("meteor")
+    rouge = evaluate.load("rouge")
+
+    # Compute BLEU scores
+    bleu_results = bleu.compute(predictions=predictions, references=references, max_order=4)
+    # bleu_score = sum(bleu_results[f"bleu{i}"] for i in range(1, 5)) / 4
+    print(bleu_results)
+
+    # Compute METEOR score
+    meteor_results = meteor.compute(predictions=predictions, references=references)
+    meteor_score = meteor_results["meteor"]
+
+    # Compute ROUGE-L score
+    rouge_results = rouge.compute(predictions=predictions, references=references, rouge_types=["rougeL"])
+    #rouge_l_score = rouge_results["rougeL"].mid.fmeasure
+    print(rouge_results)
+
+    # Compute BERT-Score
+    P, R, F1 = bert_score.score(predictions, references, lang="en", rescale_with_baseline=True)
+    bert_score_f1 = F1.mean().item()
+
+    # Print results
+    # print(f"BLEU Score: {bleu_score}")
+    print(f"METEOR Score: {meteor_score}")
+    # print(f"ROUGE-L Score: {rouge_l_score}")
+    print(f"BERT-Score F1: {bert_score_f1}")
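Note: unlike the two scripts above, this one scores captions with the Hugging Face evaluate library rather than the hand-rolled evaluate() helper. A toy run of those corpus-level metrics is sketched below (example only; the caption strings are made up, and newer versions of the library return plain floats for ROUGE, which is presumably why the .mid.fmeasure access is commented out above).

# Toy run of the corpus metrics loaded above; caption strings are placeholders.
import evaluate

preds = ["a fast techno beat with a deep bass", "a slow acoustic guitar ballad"]
refs = ["an energetic techno track with heavy bass", "a slow ballad played on acoustic guitar"]

bleu = evaluate.load("bleu")
rouge = evaluate.load("rouge")
print(bleu.compute(predictions=preds, references=refs, max_order=4))              # dict with a "bleu" key
print(rouge.compute(predictions=preds, references=refs, rouge_types=["rougeL"]))  # {"rougeL": <float>} in recent versions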
src/sonicverse/scripts/evaluate_temp.py
ADDED
@@ -0,0 +1,122 @@
+from dataclasses import dataclass, field
+import logging
+from flask import Flask, request, jsonify
+import transformers
+import torch
+from datasets import load_from_disk
+from multi_token.model_utils import MultiTaskType
+from multi_token.training import ModelArguments
+from multi_token.inference import load_trained_lora_model
+from multi_token.data_tools import encode_chat
+import evaluate
+import random
+import bert_score
+import os
+
+os.environ['HF_EVALUATE_OFFLINE'] = '1'
+
+PRETRAIN_PHRASES = ["Describe the audio in detail <sound>"]
+
+PRETRAIN_PHRASES_old = [
+    "What is happening in the given music <sound>?",
+    "Describe the sound. <sound>",
+    "Describe the music. <sound>",
+    "<sound> Provide a description of the music.",
+    "<sound> Provide a description of the sound.",
+    "Can you interpret <sound>?",
+    "Please explain what's happening in <sound>",
+    "What does <sound> represent?",
+    "Could you describe <sound> for me?",
+    "What's the content of <sound>?",
+    "Can you depict <sound>?",
+    "What is <sound>?",
+    "In the music clip, <sound>, what is happening?",
+    "Provide a description of the music. <sound>",
+    "Provide a description of the sound. <sound>",
+    "Provide a caption for the sound. <sound>",
+    "Provide a caption for the music. <sound>",
+]
+
+@dataclass
+class ServeArguments(ModelArguments):
+    port: int = field(default=8080)
+    host: str = field(default="0.0.0.0")
+    load_bits: int = field(default=16)
+    max_new_tokens: int = field(default=128)
+    temperature: float = field(default=0.01)
+
+def generate(input_json):
+    encoded_dict = encode_chat(input_json, tokenizer, model.modalities)
+    with torch.inference_mode():
+        output_ids = model.generate(
+            input_ids=encoded_dict["input_ids"].unsqueeze(0).to(model.device),
+            max_new_tokens=serve_args.max_new_tokens,
+            use_cache=True,
+            do_sample=True,
+            temperature=serve_args.temperature,
+            modality_inputs={
+                m.name: [encoded_dict[m.name]] for m in model.modalities
+            },
+        )
+    outputs = tokenizer.decode(
+        output_ids[0, encoded_dict["input_ids"].shape[0]:],
+        skip_special_tokens=True,
+    ).strip()
+    return {"output": outputs}
+
+if __name__ == "__main__":
+    logging.getLogger().setLevel(logging.INFO)
+    parser = transformers.HfArgumentParser((ServeArguments,))
+    serve_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
+    dataset_path = "/data/musicbench_multitoken_official_split/val"
+    ds = load_from_disk(dataset_path)
+    model, tokenizer = load_trained_lora_model(
+        model_name_or_path=serve_args.model_name_or_path,
+        model_lora_path=serve_args.model_lora_path,
+        load_bits=serve_args.load_bits,
+        use_multi_task=MultiTaskType(serve_args.use_multi_task),
+        tasks_config=serve_args.tasks_config
+    )
+
+    predictions = []
+    references = []
+    content_phrase = random.choice(PRETRAIN_PHRASES)
+    for data_point_id in range(10):
+        data_point = ds[data_point_id]
+        input_json = {"messages": [{"role": "user", "content": content_phrase}], "sounds": data_point["sounds"]}
+        output_json = generate(input_json)
+        print("Prediction ", output_json["output"])
+        print("Reference ", data_point["messages"][1]["content"])
+        print()
+        print()
+        predictions.append(output_json["output"])
+        references.append(data_point["messages"][1]["content"])
+
+    # Load evaluation metrics
+    bleu = evaluate.load("bleu")
+    meteor = evaluate.load("meteor")
+    rouge = evaluate.load("rouge")
+
+    # Compute BLEU scores
+    bleu_results = bleu.compute(predictions=predictions, references=references, max_order=4)
+    print(bleu_results)
+    #bleu_score = sum(bleu_results[f"bleu{i}"] for i in range(1, 5)) / 4
+
+    # Compute METEOR score
+    meteor_results = meteor.compute(predictions=predictions, references=references)
+    meteor_score = meteor_results["meteor"]
+
+    # Compute ROUGE-L score
+    rouge_results = rouge.compute(predictions=predictions, references=references, rouge_types=["rougeL"])
+    # rouge_l_score = rouge_results["rougeL"].mid.fmeasure
+    print(rouge_results)
+
+    # Compute BERT-Score
+    P, R, F1 = bert_score.score(predictions, references, lang="en", rescale_with_baseline=True)
+    bert_score_f1 = F1.mean().item()
+
+    # Print results
+    #print(f"BLEU Score: {bleu_score}")
+    print(f"METEOR Score: {meteor_score}")
+    # print(f"ROUGE-L Score: {rouge_l_score}")
+    print(f"BERT-Score F1: {bert_score_f1}")
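Note: all of these scripts reduce BERTScore to a single number by averaging one of the tensors returned by bert_score.score. A minimal sketch of that call in isolation (example only; the strings are placeholders, and rescale_with_baseline rescales scores against a per-language baseline so they spread over a more interpretable range).

# Minimal BERTScore sketch matching the call used above; toy strings, not dataset samples.
import bert_score

preds = ["a calm piano melody"]
refs = ["a gentle piano piece"]
P, R, F1 = bert_score.score(preds, refs, lang="en", rescale_with_baseline=True)
print(f"BERT-Score F1: {F1.mean().item():.3f}")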