annabeth97c commited on
Commit
12f2e48
·
verified ·
1 Parent(s): ecc971f

Initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. src/sonicverse/configs/tasks.json +208 -0
  2. src/sonicverse/configs/tasks_baseline.json +20 -0
  3. src/sonicverse/configs/tasks_ft.json +208 -0
  4. src/sonicverse/configs/tasks_pt_weight.json +10 -0
  5. src/sonicverse/configs/zero2.json +23 -0
  6. src/sonicverse/configs/zero3.json +28 -0
  7. src/sonicverse/configs/zero3_offload.json +56 -0
  8. src/sonicverse/multi_token.egg-info/PKG-INFO +6 -0
  9. src/sonicverse/multi_token.egg-info/SOURCES.txt +6 -0
  10. src/sonicverse/multi_token.egg-info/dependency_links.txt +1 -0
  11. src/sonicverse/multi_token.egg-info/requires.txt +8 -0
  12. src/sonicverse/multi_token.egg-info/top_level.txt +1 -0
  13. src/sonicverse/multi_token/constants.py +4 -0
  14. src/sonicverse/multi_token/data_tools.py +336 -0
  15. src/sonicverse/multi_token/inference.py +83 -0
  16. src/sonicverse/multi_token/language_models/__init__.py +7 -0
  17. src/sonicverse/multi_token/language_models/base_model.py +181 -0
  18. src/sonicverse/multi_token/language_models/mistral.py +235 -0
  19. src/sonicverse/multi_token/modalities/__init__.py +31 -0
  20. src/sonicverse/multi_token/modalities/audio_clap.py +142 -0
  21. src/sonicverse/multi_token/modalities/audio_descript.py +169 -0
  22. src/sonicverse/multi_token/modalities/audio_descript_bu.py +133 -0
  23. src/sonicverse/multi_token/modalities/audio_mert.py +162 -0
  24. src/sonicverse/multi_token/modalities/audio_mert_bu.py +159 -0
  25. src/sonicverse/multi_token/modalities/audio_whisper.py +120 -0
  26. src/sonicverse/multi_token/modalities/base_modality.py +48 -0
  27. src/sonicverse/multi_token/modalities/bu__init__.py +31 -0
  28. src/sonicverse/multi_token/modalities/document_gte.py +144 -0
  29. src/sonicverse/multi_token/modalities/imagebind.py +153 -0
  30. src/sonicverse/multi_token/modalities/multi_task_projector_shared.py +321 -0
  31. src/sonicverse/multi_token/modalities/projectors.py +416 -0
  32. src/sonicverse/multi_token/modalities/video_xclip.py +113 -0
  33. src/sonicverse/multi_token/modalities/vision_clip.py +178 -0
  34. src/sonicverse/multi_token/model_utils.py +112 -0
  35. src/sonicverse/multi_token/training.py +344 -0
  36. src/sonicverse/multi_token/training_data.py +133 -0
  37. src/sonicverse/requirements.txt +8 -0
  38. src/sonicverse/scripts/audio_setup.sh +3 -0
  39. src/sonicverse/scripts/clap_gpt_build_finetune_dataset.py +155 -0
  40. src/sonicverse/scripts/clap_gpt_build_pretrain_dataset.py +142 -0
  41. src/sonicverse/scripts/document_build_finetune_dataset.py +162 -0
  42. src/sonicverse/scripts/document_build_pretrain_dataset.py +89 -0
  43. src/sonicverse/scripts/document_setup.sh +5 -0
  44. src/sonicverse/scripts/evaluate_model.py +112 -0
  45. src/sonicverse/scripts/evaluate_model_latest.py +127 -0
  46. src/sonicverse/scripts/evaluate_model_mullama.py +168 -0
  47. src/sonicverse/scripts/evaluate_model_mullama_musiccaps.py +143 -0
  48. src/sonicverse/scripts/evaluate_model_mullama_musiccaps_fixed_prompt.py +138 -0
  49. src/sonicverse/scripts/evaluate_mullama.py +115 -0
  50. src/sonicverse/scripts/evaluate_temp.py +122 -0
src/sonicverse/configs/tasks.json ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "backbone": {
3
+ "num_layers": 5,
4
+ "input_channels": 25,
5
+ "output_channels": 25,
6
+ "output_size": 4096,
7
+ "hidden_size": 4096,
8
+ "requires_grad": true
9
+ },
10
+ "task_heads": {
11
+ "lmm_projector": {
12
+ "num_layers": 3,
13
+ "output_size": 4096,
14
+ "hidden_size": 4096,
15
+ "input_size": 768,
16
+ "input_channels": 13,
17
+ "width": 40,
18
+ "weight": 1.0,
19
+ "model_type": "mlp",
20
+ "requires_grad": true,
21
+ "use_aggregator": true,
22
+ "use_time_average": true,
23
+ "use_sigmoid": false,
24
+ "use_transpose": false,
25
+ "use_backbone_output": false
26
+ },
27
+ "instrument_detection": {
28
+ "model_type": "mlp",
29
+ "use_aggregator": true,
30
+ "use_time_average": true,
31
+ "use_sigmoid": true,
32
+ "use_transpose": false,
33
+ "num_layers": 2,
34
+ "input_size": 508,
35
+ "output_size": 40,
36
+ "hidden_size": 4096,
37
+ "width": 1,
38
+ "weight": 0.1,
39
+ "requires_grad": true,
40
+ "num_conv_layers": 4,
41
+ "output_channel": 1
42
+ },
43
+ "mood_detection": {
44
+ "model_type": "mlp",
45
+ "use_aggregator": true,
46
+ "use_time_average": true,
47
+ "use_sigmoid": true,
48
+ "use_transpose": false,
49
+ "num_layers": 2,
50
+ "input_size": 508,
51
+ "output_size": 56,
52
+ "hidden_size": 4096,
53
+ "width": 1,
54
+ "weight": 0.1,
55
+ "requires_grad": true,
56
+ "num_conv_layers": 4,
57
+ "output_channel": 1
58
+ },
59
+ "genre_detection": {
60
+ "model_type": "mlp",
61
+ "use_aggregator": true,
62
+ "use_time_average": true,
63
+ "use_sigmoid": true,
64
+ "use_transpose": false,
65
+ "num_layers": 2,
66
+ "input_size": 508,
67
+ "output_size": 87,
68
+ "hidden_size": 4096,
69
+ "width": 1,
70
+ "weight": 0.1,
71
+ "requires_grad": true,
72
+ "num_conv_layers": 4,
73
+ "output_channel": 1
74
+ },
75
+ "key_detection": {
76
+ "model_type": "mlp",
77
+ "use_aggregator": true,
78
+ "use_time_average": true,
79
+ "use_sigmoid": true,
80
+ "use_transpose": false,
81
+ "num_layers": 2,
82
+ "input_size": 508,
83
+ "output_size": 24,
84
+ "hidden_size": 4096,
85
+ "width": 1,
86
+ "weight": 0.1,
87
+ "requires_grad": true,
88
+ "num_conv_layers": 4,
89
+ "output_channel": 1
90
+ },
91
+ "vocals_detection": {
92
+ "model_type": "mlp",
93
+ "use_aggregator": true,
94
+ "use_time_average": true,
95
+ "use_sigmoid": true,
96
+ "use_transpose": false,
97
+ "num_layers": 2,
98
+ "input_size": 508,
99
+ "output_size": 3,
100
+ "hidden_size": 4096,
101
+ "width": 1,
102
+ "weight": 0.1,
103
+ "requires_grad": true,
104
+ "num_conv_layers": 4,
105
+ "output_channel": 1
106
+ }
107
+ },
108
+ "task_projectors": {
109
+ "instrument_detection": {
110
+ "model_type": "mlp",
111
+ "num_layers": 3,
112
+ "input_channels": 0,
113
+ "input_size": 40,
114
+ "output_size": 4096,
115
+ "hidden_size": 4096,
116
+ "width": 4,
117
+ "use_aggregator": false,
118
+ "use_time_average": false,
119
+ "use_sigmoid": false,
120
+ "use_transpose": false,
121
+ "requires_grad": true
122
+ },
123
+ "mood_detection": {
124
+ "model_type": "mlp",
125
+ "num_layers": 3,
126
+ "input_channels": 0,
127
+ "input_size": 56,
128
+ "output_size": 4096,
129
+ "hidden_size": 4096,
130
+ "width": 4,
131
+ "use_aggregator": false,
132
+ "use_time_average": false,
133
+ "use_sigmoid": false,
134
+ "use_transpose": false,
135
+ "requires_grad": true
136
+ },
137
+ "genre_detection": {
138
+ "model_type": "mlp",
139
+ "num_layers": 3,
140
+ "input_channels": 0,
141
+ "input_size": 87,
142
+ "output_size": 4096,
143
+ "hidden_size": 4096,
144
+ "width": 4,
145
+ "use_aggregator": false,
146
+ "use_time_average": false,
147
+ "use_sigmoid": false,
148
+ "use_transpose": false,
149
+ "requires_grad": true
150
+ },
151
+ "key_detection": {
152
+ "model_type": "mlp",
153
+ "num_layers": 3,
154
+ "input_channels": 0,
155
+ "input_size": 24,
156
+ "output_size": 4096,
157
+ "hidden_size": 4096,
158
+ "width": 4,
159
+ "use_aggregator": false,
160
+ "use_time_average": false,
161
+ "use_sigmoid": false,
162
+ "use_transpose": false,
163
+ "requires_grad": true
164
+ },
165
+ "vocals_detection": {
166
+ "model_type": "mlp",
167
+ "num_layers": 3,
168
+ "input_channels": 0,
169
+ "input_size": 3,
170
+ "output_size": 4096,
171
+ "hidden_size": 4096,
172
+ "width": 4,
173
+ "use_aggregator": false,
174
+ "use_time_average": false,
175
+ "use_sigmoid": false,
176
+ "use_transpose": false,
177
+ "requires_grad": true
178
+ },
179
+ "chords_detection": {
180
+ "model_type": "mlp",
181
+ "num_layers": 3,
182
+ "input_channels": 0,
183
+ "input_size": 216,
184
+ "output_size": 4096,
185
+ "hidden_size": 4096,
186
+ "width": 4,
187
+ "use_aggregator": false,
188
+ "use_time_average": false,
189
+ "use_sigmoid": false,
190
+ "use_transpose": false,
191
+ "requires_grad": true
192
+ },
193
+ "beats_detection": {
194
+ "model_type": "mlp_conv_agg",
195
+ "num_layers": 3,
196
+ "input_channels": 2,
197
+ "input_size": 500,
198
+ "output_size": 4096,
199
+ "hidden_size": 4096,
200
+ "width": 4,
201
+ "use_aggregator": true,
202
+ "use_time_average": false,
203
+ "use_sigmoid": false,
204
+ "use_transpose": true,
205
+ "requires_grad": true
206
+ }
207
+ }
208
+ }
src/sonicverse/configs/tasks_baseline.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "task_heads": {
3
+ "lmm_projector": {
4
+ "num_layers": 3,
5
+ "output_size": 4096,
6
+ "hidden_size": 4096,
7
+ "input_size": 768,
8
+ "input_channels": 13,
9
+ "width": 60,
10
+ "weight": 1.0,
11
+ "model_type": "mlp",
12
+ "requires_grad": true,
13
+ "use_aggregator": true,
14
+ "use_time_average": true,
15
+ "use_sigmoid": false,
16
+ "use_transpose": false
17
+ }
18
+ },
19
+ "task_projectors": {}
20
+ }
src/sonicverse/configs/tasks_ft.json ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "backbone": {
3
+ "num_layers": 5,
4
+ "input_channels": 25,
5
+ "output_channels": 25,
6
+ "output_size": 4096,
7
+ "hidden_size": 4096,
8
+ "requires_grad": false
9
+ },
10
+ "task_heads": {
11
+ "lmm_projector": {
12
+ "num_layers": 3,
13
+ "output_size": 4096,
14
+ "hidden_size": 4096,
15
+ "input_size": 768,
16
+ "input_channels": 13,
17
+ "width": 40,
18
+ "weight": 1.0,
19
+ "model_type": "mlp",
20
+ "requires_grad": true,
21
+ "use_aggregator": true,
22
+ "use_time_average": true,
23
+ "use_sigmoid": false,
24
+ "use_transpose": false,
25
+ "use_backbone_output": false
26
+ },
27
+ "instrument_detection": {
28
+ "model_type": "mlp",
29
+ "use_aggregator": true,
30
+ "use_time_average": true,
31
+ "use_sigmoid": true,
32
+ "use_transpose": false,
33
+ "num_layers": 2,
34
+ "input_size": 508,
35
+ "output_size": 40,
36
+ "hidden_size": 4096,
37
+ "width": 1,
38
+ "weight": 0.0,
39
+ "requires_grad": false,
40
+ "num_conv_layers": 4,
41
+ "output_channel": 1
42
+ },
43
+ "mood_detection": {
44
+ "model_type": "mlp",
45
+ "use_aggregator": true,
46
+ "use_time_average": true,
47
+ "use_sigmoid": true,
48
+ "use_transpose": false,
49
+ "num_layers": 2,
50
+ "input_size": 508,
51
+ "output_size": 56,
52
+ "hidden_size": 4096,
53
+ "width": 1,
54
+ "weight": 0.0,
55
+ "requires_grad": false,
56
+ "num_conv_layers": 4,
57
+ "output_channel": 1
58
+ },
59
+ "genre_detection": {
60
+ "model_type": "mlp",
61
+ "use_aggregator": true,
62
+ "use_time_average": true,
63
+ "use_sigmoid": true,
64
+ "use_transpose": false,
65
+ "num_layers": 2,
66
+ "input_size": 508,
67
+ "output_size": 87,
68
+ "hidden_size": 4096,
69
+ "width": 1,
70
+ "weight": 0.0,
71
+ "requires_grad": false,
72
+ "num_conv_layers": 4,
73
+ "output_channel": 1
74
+ },
75
+ "key_detection": {
76
+ "model_type": "mlp",
77
+ "use_aggregator": true,
78
+ "use_time_average": true,
79
+ "use_sigmoid": true,
80
+ "use_transpose": false,
81
+ "num_layers": 2,
82
+ "input_size": 508,
83
+ "output_size": 24,
84
+ "hidden_size": 4096,
85
+ "width": 1,
86
+ "weight": 0.0,
87
+ "requires_grad": false,
88
+ "num_conv_layers": 4,
89
+ "output_channel": 1
90
+ },
91
+ "vocals_detection": {
92
+ "model_type": "mlp",
93
+ "use_aggregator": true,
94
+ "use_time_average": true,
95
+ "use_sigmoid": true,
96
+ "use_transpose": false,
97
+ "num_layers": 2,
98
+ "input_size": 508,
99
+ "output_size": 3,
100
+ "hidden_size": 4096,
101
+ "width": 1,
102
+ "weight": 0.0,
103
+ "requires_grad": false,
104
+ "num_conv_layers": 4,
105
+ "output_channel": 1
106
+ }
107
+ },
108
+ "task_projectors": {
109
+ "instrument_detection": {
110
+ "model_type": "mlp",
111
+ "num_layers": 3,
112
+ "input_channels": 0,
113
+ "input_size": 40,
114
+ "output_size": 4096,
115
+ "hidden_size": 4096,
116
+ "width": 4,
117
+ "use_aggregator": false,
118
+ "use_time_average": false,
119
+ "use_sigmoid": false,
120
+ "use_transpose": false,
121
+ "requires_grad": true
122
+ },
123
+ "mood_detection": {
124
+ "model_type": "mlp",
125
+ "num_layers": 3,
126
+ "input_channels": 0,
127
+ "input_size": 56,
128
+ "output_size": 4096,
129
+ "hidden_size": 4096,
130
+ "width": 4,
131
+ "use_aggregator": false,
132
+ "use_time_average": false,
133
+ "use_sigmoid": false,
134
+ "use_transpose": false,
135
+ "requires_grad": true
136
+ },
137
+ "genre_detection": {
138
+ "model_type": "mlp",
139
+ "num_layers": 3,
140
+ "input_channels": 0,
141
+ "input_size": 87,
142
+ "output_size": 4096,
143
+ "hidden_size": 4096,
144
+ "width": 4,
145
+ "use_aggregator": false,
146
+ "use_time_average": false,
147
+ "use_sigmoid": false,
148
+ "use_transpose": false,
149
+ "requires_grad": true
150
+ },
151
+ "key_detection": {
152
+ "model_type": "mlp",
153
+ "num_layers": 3,
154
+ "input_channels": 0,
155
+ "input_size": 24,
156
+ "output_size": 4096,
157
+ "hidden_size": 4096,
158
+ "width": 4,
159
+ "use_aggregator": false,
160
+ "use_time_average": false,
161
+ "use_sigmoid": false,
162
+ "use_transpose": false,
163
+ "requires_grad": true
164
+ },
165
+ "vocals_detection": {
166
+ "model_type": "mlp",
167
+ "num_layers": 3,
168
+ "input_channels": 0,
169
+ "input_size": 3,
170
+ "output_size": 4096,
171
+ "hidden_size": 4096,
172
+ "width": 4,
173
+ "use_aggregator": false,
174
+ "use_time_average": false,
175
+ "use_sigmoid": false,
176
+ "use_transpose": false,
177
+ "requires_grad": true
178
+ },
179
+ "chords_detection": {
180
+ "model_type": "mlp",
181
+ "num_layers": 3,
182
+ "input_channels": 0,
183
+ "input_size": 216,
184
+ "output_size": 4096,
185
+ "hidden_size": 4096,
186
+ "width": 4,
187
+ "use_aggregator": false,
188
+ "use_time_average": false,
189
+ "use_sigmoid": false,
190
+ "use_transpose": false,
191
+ "requires_grad": true
192
+ },
193
+ "beats_detection": {
194
+ "model_type": "mlp_conv_agg",
195
+ "num_layers": 3,
196
+ "input_channels": 2,
197
+ "input_size": 500,
198
+ "output_size": 4096,
199
+ "hidden_size": 4096,
200
+ "width": 4,
201
+ "use_aggregator": true,
202
+ "use_time_average": false,
203
+ "use_sigmoid": false,
204
+ "use_transpose": true,
205
+ "requires_grad": true
206
+ }
207
+ }
208
+ }
src/sonicverse/configs/tasks_pt_weight.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "pretrained_paths": [
3
+ {
4
+ "path": "/experiments/music_extraction/mlp_shared_multi_task_trial_002/train_002_epoch_45_step_40.pth",
5
+ "components": ["backbone", "instrument_detection", "genre_detection", "mood_detection", "key_detection", "vocals_detection"],
6
+ "use_prefix": true,
7
+ "prefix": "audio_mert_lmm_projector"
8
+ }
9
+ ]
10
+ }
src/sonicverse/configs/zero2.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "train_micro_batch_size_per_gpu": "auto",
14
+ "train_batch_size": "auto",
15
+ "gradient_accumulation_steps": "auto",
16
+ "zero_optimization": {
17
+ "stage": 2,
18
+ "overlap_comm": true,
19
+ "contiguous_gradients": true,
20
+ "sub_group_size": 1e9,
21
+ "reduce_bucket_size": "auto"
22
+ }
23
+ }
src/sonicverse/configs/zero3.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "train_micro_batch_size_per_gpu": "auto",
14
+ "train_batch_size": "auto",
15
+ "gradient_accumulation_steps": "auto",
16
+ "zero_optimization": {
17
+ "stage": 3,
18
+ "overlap_comm": true,
19
+ "contiguous_gradients": true,
20
+ "sub_group_size": 1e9,
21
+ "reduce_bucket_size": "auto",
22
+ "stage3_prefetch_bucket_size": "auto",
23
+ "stage3_param_persistence_threshold": "auto",
24
+ "stage3_max_live_parameters": 1e9,
25
+ "stage3_max_reuse_distance": 1e9,
26
+ "stage3_gather_16bit_weights_on_model_save": true
27
+ }
28
+ }
src/sonicverse/configs/zero3_offload.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "optimizer": {
14
+ "type": "AdamW",
15
+ "params": {
16
+ "lr": "auto",
17
+ "betas": "auto",
18
+ "eps": "auto",
19
+ "weight_decay": "auto"
20
+ }
21
+ },
22
+ "scheduler": {
23
+ "type": "WarmupLR",
24
+ "params": {
25
+ "warmup_min_lr": "auto",
26
+ "warmup_max_lr": "auto",
27
+ "warmup_num_steps": "auto"
28
+ }
29
+ },
30
+ "zero_optimization": {
31
+ "stage": 3,
32
+ "offload_optimizer": {
33
+ "device": "cpu",
34
+ "pin_memory": true
35
+ },
36
+ "offload_param": {
37
+ "device": "cpu",
38
+ "pin_memory": true
39
+ },
40
+ "overlap_comm": true,
41
+ "contiguous_gradients": true,
42
+ "sub_group_size": 1e9,
43
+ "reduce_bucket_size": "auto",
44
+ "stage3_prefetch_bucket_size": "auto",
45
+ "stage3_param_persistence_threshold": "auto",
46
+ "stage3_max_live_parameters": 1e9,
47
+ "stage3_max_reuse_distance": 1e9,
48
+ "gather_16bit_weights_on_model_save": true
49
+ },
50
+ "gradient_accumulation_steps": "auto",
51
+ "gradient_clipping": "auto",
52
+ "train_batch_size": "auto",
53
+ "train_micro_batch_size_per_gpu": "auto",
54
+ "steps_per_print": 1e5,
55
+ "wall_clock_breakdown": false
56
+ }
src/sonicverse/multi_token.egg-info/PKG-INFO ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: multi-token
3
+ Version: 0.0.4
4
+ Home-page: https://github.com/sshh12/multi_token
5
+ Author: Shrivu Shankar
6
+ License: Apache License 2.0
src/sonicverse/multi_token.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ setup.py
2
+ multi_token.egg-info/PKG-INFO
3
+ multi_token.egg-info/SOURCES.txt
4
+ multi_token.egg-info/dependency_links.txt
5
+ multi_token.egg-info/requires.txt
6
+ multi_token.egg-info/top_level.txt
src/sonicverse/multi_token.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
src/sonicverse/multi_token.egg-info/requires.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ transformers>=4.34.0
2
+ accelerate>=0.21.0
3
+ scipy>=1.11.3
4
+ bitsandbytes>=0.41.0
5
+ datasets>=2.14.5
6
+ sentencepiece>=0.1.99
7
+ peft>=0.4.0
8
+ deepspeed==0.9.5
src/sonicverse/multi_token.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
src/sonicverse/multi_token/constants.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ IGNORE_INDEX = -100
2
+
3
+ ROLE_ASSISTANT = "assistant"
4
+ ROLE_USER = "user"
src/sonicverse/multi_token/data_tools.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any, Union, Optional
2
+ from collections import Counter
3
+ from functools import cache
4
+ import contextlib
5
+ import tempfile
6
+ import shutil
7
+ import random
8
+ import subprocess
9
+ import json
10
+ import re
11
+ import io
12
+ import os
13
+
14
+ import torch
15
+ import requests
16
+ import transformers
17
+ import numpy as np
18
+ from datasets import load_dataset, Dataset
19
+ from PIL import Image
20
+
21
+ from multi_token.constants import IGNORE_INDEX
22
+
23
+
24
+ def encode_chat(
25
+ item: Dict,
26
+ tokenizer: transformers.PreTrainedTokenizer,
27
+ modalities: List["Modality"],
28
+ ) -> Dict:
29
+ messages = list(item["messages"])
30
+ chat_as_string = tokenizer.apply_chat_template(messages, tokenize=False)
31
+
32
+ token_to_modality = {m.token: m for m in modalities}
33
+ modality_token_counts = Counter()
34
+ instruct_pattern = r"(\[INST\][\s\S]*?\[\/INST\])"
35
+ pattern = "(" + "|".join(re.escape(m.token) for m in modalities) + ")"
36
+
37
+ chat_part = re.split(instruct_pattern, chat_as_string)
38
+ input_ids = []
39
+ labels = []
40
+ for part in chat_part:
41
+ if "[INST]" in part:
42
+ is_instruction = True
43
+ else:
44
+ is_instruction = False
45
+ for subpart in re.split(pattern, part):
46
+ if not subpart:
47
+ continue
48
+ if subpart in token_to_modality:
49
+ assert (
50
+ is_instruction
51
+ ), "There should be no modality tokens outside of instructions"
52
+ m = token_to_modality[subpart]
53
+ modality_token_counts[m.name] += 1
54
+ input_ids.extend([m.token_idx] * m.token_width)
55
+ labels.extend([IGNORE_INDEX] * m.token_width)
56
+ elif is_instruction:
57
+ part_ids = tokenizer(subpart, add_special_tokens=False).input_ids
58
+ input_ids.extend(part_ids)
59
+ labels.extend([IGNORE_INDEX] * len(part_ids))
60
+ else:
61
+ part_ids = tokenizer(subpart, add_special_tokens=False).input_ids
62
+ input_ids.extend(part_ids)
63
+ labels.extend(part_ids)
64
+
65
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
66
+ labels = torch.tensor(labels, dtype=torch.long)
67
+
68
+ data_dict = dict(
69
+ input_ids=input_ids,
70
+ labels=labels,
71
+ )
72
+ for m in modalities:
73
+ data_dict[m.name] = m.preprocess_rows([item])[0]
74
+ return data_dict
75
+
76
+ def encode_chat_multitask(
77
+ item: Dict,
78
+ tokenizer: transformers.PreTrainedTokenizer,
79
+ modalities: List["Modality"],
80
+ ) -> Dict:
81
+ messages = list(item["messages"])
82
+ chat_as_string = tokenizer.apply_chat_template(messages, tokenize=False)
83
+
84
+ token_to_modality = {m.token: m for m in modalities}
85
+ modality_token_counts = Counter()
86
+ instruct_pattern = r"(\[INST\][\s\S]*?\[\/INST\])"
87
+ pattern = "(" + "|".join(re.escape(m.token) for m in modalities) + ")"
88
+
89
+ chat_part = re.split(instruct_pattern, chat_as_string)
90
+ input_ids = []
91
+ labels = []
92
+ labels.append([])
93
+ for part in chat_part:
94
+ if "[INST]" in part:
95
+ is_instruction = True
96
+ else:
97
+ is_instruction = False
98
+ for subpart in re.split(pattern, part):
99
+ if not subpart:
100
+ continue
101
+ if subpart in token_to_modality:
102
+ assert (
103
+ is_instruction
104
+ ), "There should be no modality tokens outside of instructions"
105
+ m = token_to_modality[subpart]
106
+ modality_token_counts[m.name] += 1
107
+ input_ids.extend([m.token_idx] * m.token_width)
108
+ labels[0].extend([IGNORE_INDEX] * m.token_width)
109
+ elif is_instruction:
110
+ part_ids = tokenizer(subpart, add_special_tokens=False).input_ids
111
+ input_ids.extend(part_ids)
112
+ labels[0].extend([IGNORE_INDEX] * len(part_ids))
113
+ else:
114
+ part_ids = tokenizer(subpart, add_special_tokens=False).input_ids
115
+ input_ids.extend(part_ids)
116
+ labels[0].extend(part_ids)
117
+
118
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
119
+ labels[0] = torch.tensor(labels[0], dtype=torch.long)
120
+
121
+ task_list = []
122
+ for m in modalities:
123
+ task_list += m.tasks["task_heads"].keys()
124
+ # labels[task_specs["task_id"]] = load_tensor(item[task_name][0])
125
+
126
+ for task_name in task_list:
127
+ if task_name != "lmm_projector":
128
+ labels.append(load_tensor(item[task_name][0]))
129
+
130
+ # labels = torch.tensor(labels, dtype=torch.long)
131
+
132
+ data_dict = dict(
133
+ input_ids=input_ids,
134
+ labels=labels,
135
+ )
136
+ for m in modalities:
137
+ data_dict[m.name] = m.preprocess_rows([item])[0]
138
+ return data_dict
139
+
140
+ def load_tensor(path: str) -> np.ndarray:
141
+ return torch.tensor(np.load(path))
142
+
143
+
144
+ def load_image(value: Any) -> Image.Image:
145
+ img = None
146
+ if isinstance(value, str):
147
+ if value.startswith("http://") or value.startswith("https://"):
148
+ response = requests.get(value)
149
+ img = Image.open(io.BytesIO(response.content))
150
+ elif os.path.exists(value):
151
+ img = Image.open(value)
152
+ elif isinstance(value, Image.Image):
153
+ img = value
154
+ if img is None:
155
+ raise ValueError(f"Could not load image from {value}")
156
+ img = img.convert("RGB")
157
+ return img
158
+
159
+
160
+ @contextlib.contextmanager
161
+ def with_local_files(fn_or_urls: List[Any]):
162
+ local_fns = []
163
+ fps = []
164
+ for fn_or_url in fn_or_urls:
165
+ if isinstance(fn_or_url, Image.Image):
166
+ fp = tempfile.NamedTemporaryFile(suffix=".png", mode="wb")
167
+ fn_or_url.convert("RGB").save(fp)
168
+ fps.append(fp)
169
+ local_fns.append(fp.name)
170
+ elif fn_or_url.startswith("http://") or fn_or_url.startswith("https://"):
171
+ suffix = os.path.splitext(fn_or_url)[-1]
172
+ with requests.get(fn_or_url, stream=True) as r:
173
+ fp = tempfile.NamedTemporaryFile(suffix=suffix, mode="wb")
174
+ shutil.copyfileobj(r.raw, fp)
175
+ fps.append(fp)
176
+ local_fns.append(fp.name)
177
+ else:
178
+ local_fns.append(fn_or_url)
179
+ try:
180
+ yield local_fns
181
+ finally:
182
+ for fp in fps:
183
+ fp.close()
184
+
185
+
186
+ @cache
187
+ def _get_dataset(dataset_args: str) -> Dataset:
188
+ return load_dataset(**json.loads(dataset_args))
189
+
190
+
191
+ def get_dataset_cached(dataset_args: Dict) -> Dataset:
192
+ return _get_dataset(json.dumps(dataset_args))
193
+
194
+
195
+ def load_audio_signal(input_: Union[Dict, str]) -> Dict:
196
+ from audiotools import AudioSignal
197
+
198
+ if isinstance(input_, dict) and "array" in input_:
199
+ array = input_["array"]
200
+ elif isinstance(input_, dict) and "dataset_args" in input_:
201
+ item = get_dataset_cached(input_["dataset_args"])[input_["idx"]]
202
+ array = item["audio"]["array"]
203
+ elif isinstance(input_, dict) and "path" in input_:
204
+ with with_local_files([input_["path"]]) as local_fns:
205
+ array = AudioSignal(local_fns[0])
206
+ elif isinstance(input_, str):
207
+ with with_local_files([input_]) as local_fns:
208
+ array = AudioSignal(local_fns[0])
209
+ else:
210
+ raise ValueError(f"Could not load audio from {input_}")
211
+
212
+ return {"array": list(array)}
213
+
214
+
215
+ def load_audio(input_: Union[Dict, str], target_sampling_rate: int = None) -> Dict:
216
+ import soundfile as sf
217
+ import librosa
218
+
219
+ if isinstance(input_, dict) and "array" in input_ and "sampling_rate" in input_:
220
+ array = input_["array"]
221
+ sampling_rate = input_["sampling_rate"]
222
+ elif isinstance(input_, dict) and "dataset_args" in input_:
223
+ item = get_dataset_cached(input_["dataset_args"])[input_["idx"]]
224
+ array = item["audio"]["array"]
225
+ sampling_rate = item["audio"]["sampling_rate"]
226
+ elif isinstance(input_, dict) and "path" in input_:
227
+ with with_local_files([input_["path"]]) as local_fns:
228
+ array, sampling_rate = sf.read(local_fns[0])
229
+ elif isinstance(input_, str):
230
+ with with_local_files([input_]) as local_fns:
231
+ array, sampling_rate = sf.read(local_fns[0])
232
+ else:
233
+ raise ValueError(f"Could not load audio from {input_}")
234
+
235
+ if array.ndim == 2:
236
+ array = array.mean(axis=1)
237
+
238
+ if target_sampling_rate is not None and sampling_rate != target_sampling_rate:
239
+ array = librosa.resample(
240
+ array, orig_sr=sampling_rate, target_sr=target_sampling_rate
241
+ )
242
+ sampling_rate = target_sampling_rate
243
+
244
+ return {"array": list(array), "sampling_rate": sampling_rate}
245
+
246
+
247
+ def _download_yt_video(url: str) -> str:
248
+ from pytube import YouTube
249
+
250
+ youtube = YouTube(url)
251
+ video = youtube.streams.first()
252
+
253
+ fn = "".join(random.choices("abcdefghijklmnopqrstuvwxyz", k=10))
254
+ file_path = video.download(output_path=tempfile.gettempdir(), filename=fn)
255
+
256
+ return file_path
257
+
258
+
259
+ def _read_video_pyav(container, indices):
260
+ frames = []
261
+ container.seek(0)
262
+ start_index = indices[0]
263
+ end_index = indices[-1]
264
+ for i, frame in enumerate(container.decode(video=0)):
265
+ if i > end_index:
266
+ break
267
+ if i >= start_index and i in indices:
268
+ frames.append(frame)
269
+ return np.stack([x.to_ndarray(format="rgb24") for x in frames])
270
+
271
+
272
+ def _sample_frame_indices(clip_len, frame_sample_rate, seg_len):
273
+ converted_len = int(clip_len * frame_sample_rate)
274
+ end_idx = np.random.randint(converted_len, seg_len)
275
+ start_idx = end_idx - converted_len
276
+ indices = np.linspace(start_idx, end_idx, num=clip_len)
277
+ indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
278
+ return indices
279
+
280
+
281
+ def load_video(
282
+ input_: str,
283
+ frames: int = 8,
284
+ frame_sample_rate: int = 1,
285
+ start_time: Optional[int] = None,
286
+ end_time: Optional[int] = None,
287
+ ) -> np.ndarray:
288
+ import av
289
+
290
+ delete_file = False
291
+
292
+ if isinstance(input_, dict) and "youtube.com" and input_.get("url", ""):
293
+ file_path = _download_yt_video(input_["url"])
294
+ delete_file = True
295
+ # start_time = input_.get("start_time", None)
296
+ # end_time = input_.get("end_time", None)
297
+ elif isinstance(input_, str) and "youtube.com" in input_:
298
+ file_path = _download_yt_video(input_)
299
+ delete_file = True
300
+ elif isinstance(input_, str):
301
+ file_path = input_
302
+ else:
303
+ raise ValueError(f"Could not load video from {input_}")
304
+
305
+ if start_time is not None or end_time is not None:
306
+ start_time = start_time if start_time is not None else 0
307
+ end_time = end_time if end_time is not None else "end"
308
+ trim_file_path = f"{file_path.rsplit('.', 1)[0]}_trim.mp4"
309
+ subprocess.run(
310
+ [
311
+ "ffmpeg",
312
+ "-i",
313
+ file_path,
314
+ "-ss",
315
+ str(start_time),
316
+ "-to",
317
+ str(end_time),
318
+ "-c",
319
+ "copy",
320
+ trim_file_path,
321
+ ]
322
+ )
323
+ file_path = trim_file_path
324
+
325
+ container = av.open(file_path)
326
+ indices = _sample_frame_indices(
327
+ clip_len=frames,
328
+ frame_sample_rate=frame_sample_rate,
329
+ seg_len=container.streams.video[0].frames,
330
+ )
331
+ video = _read_video_pyav(container, indices)
332
+
333
+ if delete_file:
334
+ os.remove(file_path)
335
+
336
+ return video
src/sonicverse/multi_token/inference.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Type, List, Optional
2
+ import logging
3
+
4
+ from transformers import AutoTokenizer, AutoConfig, BitsAndBytesConfig
5
+ from huggingface_hub import hf_hub_download
6
+ from peft import PeftModel
7
+ import torch
8
+ import os
9
+
10
+ from multi_token.model_utils import fix_tokenizer, MultiTaskType
11
+ from multi_token.modalities.base_modality import Modality
12
+ from multi_token.language_models.mistral import MistralForCausalLM
13
+ from multi_token.language_models import LANGUAGE_MODEL_NAME_TO_CLASS
14
+ from multi_token.modalities import MODALITY_BUILDERS
15
+
16
+
17
+ def load_trained_lora_model(
18
+ model_name_or_path: str,
19
+ model_lora_path: str,
20
+ model_cls: Optional[Type] = None,
21
+ modalities: Optional[List[Modality]] = None,
22
+ load_bits: int = 16,
23
+ device_map: str = "auto",
24
+ use_multi_task: int = MultiTaskType.NO_MULTI_TASK,
25
+ tasks_config: str = None
26
+ ):
27
+ load_kwargs = {"device_map": device_map}
28
+
29
+ if load_bits == 8:
30
+ load_kwargs["load_in_8bit"] = True
31
+ elif load_bits == 4:
32
+ load_kwargs["load_in_4bit"] = True
33
+ load_kwargs["quantization_config"] = BitsAndBytesConfig(
34
+ load_in_4bit=True,
35
+ bnb_4bit_compute_dtype=torch.float16,
36
+ bnb_4bit_use_double_quant=True,
37
+ bnb_4bit_quant_type="nf4",
38
+ )
39
+ elif load_bits == 16:
40
+ load_kwargs["torch_dtype"] = torch.float16
41
+ else:
42
+ raise ValueError(f"Invalid load_bits: {load_bits}")
43
+
44
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
45
+ fix_tokenizer(tokenizer)
46
+
47
+ cfg = AutoConfig.from_pretrained(model_lora_path)
48
+ if model_cls is None:
49
+ model_cls = LANGUAGE_MODEL_NAME_TO_CLASS[cfg.model_cls]
50
+ if modalities is None:
51
+ if use_multi_task:
52
+ modalities = MODALITY_BUILDERS[cfg.modality_builder](use_multi_task = use_multi_task, tasks_config = tasks_config)
53
+ else:
54
+ modalities = MODALITY_BUILDERS[cfg.modality_builder]()
55
+
56
+ logging.info(f"Loading base model from {model_name_or_path} as {load_bits} bits")
57
+ model = model_cls.from_pretrained(
58
+ model_name_or_path, low_cpu_mem_usage=True, config=cfg, **load_kwargs
59
+ )
60
+ model.modalities = modalities
61
+
62
+ logging.info(f"Loading projector weights for {[m.name for m in modalities]}")
63
+ if os.path.exists(os.path.join(model_lora_path, "non_lora_trainables.bin")):
64
+ non_lora_trainables = torch.load(
65
+ os.path.join(model_lora_path, "non_lora_trainables.bin"), map_location="cuda"
66
+ )
67
+ else:
68
+ local_fn = hf_hub_download(
69
+ repo_id=model_lora_path,
70
+ filename="non_lora_trainables.bin",
71
+ repo_type="model",
72
+ )
73
+ non_lora_trainables = torch.load(local_fn, map_location="cuda")
74
+ model.get_model().initialize_pretrained_modules(modalities, non_lora_trainables)
75
+
76
+ logging.info(f"Loading and merging LoRA weights from {model_lora_path}")
77
+ model = PeftModel.from_pretrained(model, model_lora_path)
78
+ if load_bits == 16:
79
+ # TODO: Figure out why this fails for other bit sizes
80
+ model = model.merge_and_unload()
81
+ model.eval()
82
+
83
+ return model, tokenizer
src/sonicverse/multi_token/language_models/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from multi_token.language_models.mistral import (
2
+ MistralLMMForCausalLM,
3
+ )
4
+
5
+ LANGUAGE_MODEL_CLASSES = [MistralLMMForCausalLM]
6
+
7
+ LANGUAGE_MODEL_NAME_TO_CLASS = {cls.__name__: cls for cls in LANGUAGE_MODEL_CLASSES}
src/sonicverse/multi_token/language_models/base_model.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict
2
+ from abc import ABC, abstractmethod
3
+
4
+ from torch.nn.functional import conv1d
5
+ import torch
6
+ import logging
7
+
8
+ from multi_token.modalities.base_modality import Modality
9
+ from multi_token.model_utils import MultiTaskType
10
+
11
+ from torchviz import make_dot
12
+
13
+ class LMMMetaModel:
14
+ def __init__(self, config):
15
+ super(LMMMetaModel, self).__init__(config)
16
+
17
+ def _load_projector_weights(self, weights: Dict):
18
+ weights = {
19
+ (k[23:] if k.startswith("base_model.model.model.") else k): v
20
+ for k, v in weights.items()
21
+ }
22
+ logging.info(f"Loading pretrained weights: {list(weights.keys())}")
23
+ load_result = self.load_state_dict(weights, strict=False)
24
+ assert (
25
+ len(load_result.unexpected_keys) == 0
26
+ ), "Unexpected weights, is this the right model?"
27
+
28
+ def initialize_pretrained_modules(self, modalities: List[Modality], weights: Dict):
29
+ for m in modalities:
30
+ # projector = m.build_projector(self.config.hidden_size)
31
+ # setattr(self, m.name + "_lmm_projector", projector)
32
+ projector = m.build_projector(self.config.hidden_size)
33
+ if m.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
34
+ for task_name in m.tasks["task_heads"].keys():
35
+ task_model = projector[task_name]
36
+ setattr(self, m.name + "_" + task_name, task_model)
37
+ else:
38
+ setattr(self, m.name + "_lmm_projector", projector)
39
+
40
+ self._load_projector_weights(weights)
41
+
42
+ def initialize_modules(self, modalities: List[Modality], weights: Dict):
43
+ names = [m.name for m in modalities]
44
+
45
+ self.config.modalities = names
46
+
47
+ for m in modalities:
48
+ # projector = m.build_projector(self.config.hidden_size)
49
+ # setattr(self, m.name + "_lmm_projector", projector)
50
+ projector = m.build_projector(self.config.hidden_size)
51
+ if m.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
52
+ for task_name in m.tasks["task_heads"].keys():
53
+ task_model = projector[task_name]
54
+ setattr(self, m.name + "_" + task_name, task_model)
55
+ else:
56
+ setattr(self, m.name + "_lmm_projector", projector)
57
+
58
+ self._load_projector_weights(weights)
59
+
60
+
61
+ class LMMMetaForCausalLM(ABC):
62
+ @abstractmethod
63
+ def get_model(self) -> "LMMMetaForCausalLM":
64
+ pass
65
+
66
+ def prepare_inputs_labels_for_multimodal(
67
+ self, input_ids, attention_mask, past_key_values, labels, **kwargs
68
+ ):
69
+ model = self.get_model()
70
+
71
+ batch_size, seq_len = input_ids.shape
72
+
73
+ # batch_size x seq_len x embedding_hidden_size
74
+ inputs_embeds = torch.zeros(
75
+ (batch_size, seq_len, self.config.hidden_size),
76
+ dtype=self.dtype,
77
+ device=self.device,
78
+ )
79
+
80
+ # modality x batch_size x instance_idx x modality_token_width x embedding_hidden_size
81
+ projected_tensors = []
82
+ # assuming that if caching is enabled, we'll never have past_key_values AND need to encode the instruction modality values
83
+ task_vals = {}
84
+
85
+ #print("here past_key_values", past_key_values)
86
+ #past_key_values == None
87
+ if past_key_values is None:
88
+ for m in self.modalities:
89
+ m_vals = m.forward(kwargs.get(m.name))
90
+ mp_vals = []
91
+ if m.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
92
+ proj = {}
93
+ for task_name in m.tasks["task_heads"].keys():
94
+ proj[task_name] = getattr(model, m.name + "_" + task_name)
95
+ else:
96
+ proj = getattr(model, m.name + "_lmm_projector")
97
+
98
+ # project each batch into language model token space
99
+ for m_val in m_vals:
100
+ if m.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
101
+ for task_name in m.tasks["task_heads"].keys():
102
+ if task_name == "lmm_projector":
103
+ mp_vals.append(proj[task_name](m_val))
104
+ # make_dot(mp_vals[-1], params=dict(list(model.named_parameters()))).render(task_name, format="png")
105
+ else:
106
+ if task_name not in task_vals:
107
+ task_vals[task_name] = [proj[task_name](m_val)]
108
+ else:
109
+ task_vals[task_name].append(proj[task_name](m_val))
110
+ # make_dot(task_vals[task_name], params=dict(list(model.named_parameters()))).render(task_name, format="png")
111
+
112
+ elif m.use_multi_task == MultiTaskType.PROJECTED_MULTI_TASK:
113
+ task_outputs = proj(m_val)
114
+ mp_vals.append(task_outputs.pop("projectors"))
115
+ for task_name in task_outputs.keys():
116
+ if not task_name in task_vals:
117
+ task_vals[task_name] = [task_outputs[task_name]]
118
+ else:
119
+ task_vals[task_name].append(task_outputs[task_name])
120
+ else:
121
+ mp_vals.append(proj(m_val))
122
+
123
+ assert all(
124
+ mp_val.shape[1:] == (m.token_width, self.config.hidden_size)
125
+ for mp_val in mp_vals
126
+ ), (
127
+ "Modality tensors have incorrect shape, check your projector implementation "
128
+ + str([mp_val.shape[1:] for mp_val in mp_vals])
129
+ + " vs expected "
130
+ + str((m.token_width, self.config.hidden_size))
131
+ )
132
+ projected_tensors.append(mp_vals)
133
+
134
+ indices = None
135
+ for i, input_ids_sample in enumerate(input_ids):
136
+ is_text_mask = input_ids_sample >= 0
137
+
138
+ # fill in all the LLM-based text embeddings
139
+ inputs_embeds[i, is_text_mask] = model.embed_tokens(
140
+ input_ids_sample[is_text_mask]
141
+ )
142
+
143
+ # skip if all tokens are text tokens
144
+ if is_text_mask.sum() == seq_len:
145
+ continue
146
+ assert (
147
+ past_key_values is None
148
+ ), "We shouldn't have cached keys if this is the first instruction pass"
149
+
150
+ #past_key_values = None
151
+
152
+ for mi, m in enumerate(self.modalities):
153
+ # locate the group of tokens for this modality
154
+ m_mask = (input_ids_sample == m.token_idx).float()
155
+ m_kernel = torch.tensor(
156
+ [-1] * m.token_width, dtype=m_mask.dtype, device=m_mask.device
157
+ )
158
+ m_conv = conv1d(
159
+ m_mask.unsqueeze(0).unsqueeze(0),
160
+ m_kernel.unsqueeze(0).unsqueeze(0),
161
+ )
162
+
163
+ # where do we see `token_width`-tokens in a row?
164
+ indices = (m_conv[0, 0] == -m.token_width).nonzero(as_tuple=True)[0]
165
+
166
+ # fill these embeddings with the projected modality tensor
167
+ last_covered_idx = -1
168
+ k = 0
169
+ for possible_token_idx in indices:
170
+ if possible_token_idx <= last_covered_idx:
171
+ # make sure we don't overwrite an instance we've already covered
172
+ # handles bug caused by back-to-back tokens
173
+ continue
174
+ batch_modality_tensor = projected_tensors[mi][i][k]
175
+ inputs_embeds[
176
+ i, possible_token_idx : possible_token_idx + m.token_width
177
+ ] = batch_modality_tensor
178
+ last_covered_idx = possible_token_idx + m.token_width - 1
179
+ k += 1
180
+
181
+ return None, attention_mask, past_key_values, inputs_embeds, labels, task_vals
src/sonicverse/multi_token/language_models/mistral.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.nn import CrossEntropyLoss
7
+
8
+ from transformers import (
9
+ AutoConfig,
10
+ AutoModelForCausalLM,
11
+ MistralConfig,
12
+ MistralModel,
13
+ MistralForCausalLM,
14
+ )
15
+
16
+ from transformers.modeling_outputs import CausalLMOutputWithPast
17
+
18
+ from multi_token.language_models.base_model import (
19
+ LMMMetaModel,
20
+ LMMMetaForCausalLM,
21
+ )
22
+
23
+
24
+ class MistralLMMConfig(MistralConfig):
25
+ model_type = "mistral-lmm"
26
+
27
+
28
+ class MistralLMMModel(LMMMetaModel, MistralModel):
29
+ config_class = MistralLMMConfig
30
+
31
+ def __init__(self, config: MistralLMMConfig):
32
+ super(MistralLMMModel, self).__init__(config)
33
+
34
+
35
+ class MistralLMMForCausalLM(MistralForCausalLM, LMMMetaForCausalLM):
36
+ config_class = MistralLMMConfig
37
+
38
+ def __init__(self, config):
39
+ super(MistralForCausalLM, self).__init__(config)
40
+ self.model = MistralLMMModel(config)
41
+
42
+ self.vocab_size = config.vocab_size
43
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
44
+ self.modalities = None
45
+
46
+ # Initialize weights and apply final processing
47
+ self.post_init()
48
+
49
+ def get_model(self) -> "MistralLMMForCausalLM":
50
+ return self.model
51
+
52
+ def forward(
53
+ self,
54
+ input_ids: torch.LongTensor = None,
55
+ attention_mask: Optional[torch.Tensor] = None,
56
+ position_ids: Optional[torch.LongTensor] = None,
57
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
58
+ inputs_embeds: Optional[torch.FloatTensor] = None,
59
+ labels: Optional[List] = None,
60
+ use_cache: Optional[bool] = None,
61
+ output_attentions: Optional[bool] = None,
62
+ output_hidden_states: Optional[bool] = None,
63
+ return_dict: Optional[bool] = None,
64
+ **kwargs
65
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
66
+ #print("Past keys ",past_key_values)
67
+ output_attentions = (
68
+ output_attentions
69
+ if output_attentions is not None
70
+ else self.config.output_attentions
71
+ )
72
+ output_hidden_states = (
73
+ output_hidden_states
74
+ if output_hidden_states is not None
75
+ else self.config.output_hidden_states
76
+ )
77
+ return_dict = (
78
+ return_dict if return_dict is not None else self.config.use_return_dict
79
+ )
80
+
81
+ if labels != None:
82
+ labels_inp = labels[0]
83
+ else:
84
+ labels_inp = labels
85
+ (
86
+ input_ids,
87
+ attention_mask,
88
+ past_key_values,
89
+ inputs_embeds,
90
+ lmm_labels,
91
+ task_values
92
+ ) = self.prepare_inputs_labels_for_multimodal(
93
+ input_ids, attention_mask, past_key_values, labels_inp, **kwargs
94
+ )
95
+
96
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
97
+ outputs = self.model(
98
+ input_ids=input_ids,
99
+ attention_mask=attention_mask,
100
+ position_ids=position_ids,
101
+ past_key_values=past_key_values,
102
+ inputs_embeds=inputs_embeds,
103
+ use_cache=use_cache,
104
+ output_attentions=output_attentions,
105
+ output_hidden_states=output_hidden_states,
106
+ return_dict=return_dict,
107
+ )
108
+
109
+ hidden_states = outputs[0]
110
+ logits = self.lm_head(hidden_states)
111
+ logits = logits.float()
112
+
113
+ # print("Labels 1 size ", len(labels[1]))
114
+ # print("labels 1 element size ", len(labels[1][0]))
115
+ # print("labels 1 element 1 task size ", labels[1][0][0].shape)
116
+ # print("labels 1 element 2 task size ", labels[1][0][1].shape)
117
+ # print("labels 1 element 3 task size ", labels[1][0][2].shape)
118
+ # print("task vals size ", len(task_values))
119
+ # for task in task_values.keys():
120
+ # print(" task ", task, len(task_values[task]))
121
+ # print(" task element", task, task_values[task][0].shape)
122
+
123
+
124
+ if labels != None:
125
+ task_pairs = {}
126
+ task_list = list(task_values.keys())
127
+ for task_id in range(len(task_list)):
128
+ _task_labels = []
129
+ _task_outputs = []
130
+
131
+ _task = task_list[task_id]
132
+ for inst in range(len(task_values[_task])):
133
+ # print("task output shape ", _task, task_values[_task][inst].shape)
134
+ _task_outputs.append(task_values[_task][inst].unsqueeze(0))
135
+ _task_labels.append(torch.stack([labels[1][inst][task_id]]))
136
+
137
+ task_pairs[_task] = [_task_labels, _task_outputs]
138
+ # print("TASK ", _task)
139
+ # print(" LABELS LEN ", len(task_pairs[_task][0]))
140
+ # print(" LABELS ELEM shape ", task_pairs[_task][0][0].shape)
141
+ # print(" VALUES LEN ", len(task_pairs[_task][1]))
142
+ # print(" VALUES ELEM shape ", task_pairs[_task][1][0].shape)
143
+
144
+ loss = None
145
+ if lmm_labels is not None:
146
+ # Shift so that tokens < n predict n
147
+ shift_logits = logits[..., :-1, :].contiguous()
148
+ shift_labels = lmm_labels[..., 1:].contiguous()
149
+ # Flatten the tokens
150
+ loss_fct = CrossEntropyLoss()
151
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
152
+ shift_labels = shift_labels.view(-1)
153
+ # Enable model parallelism
154
+ shift_labels = shift_labels.to(shift_logits.device)
155
+ loss = loss_fct(shift_logits, shift_labels)
156
+
157
+ # print("loss ", loss)
158
+
159
+
160
+ if labels != None:
161
+ task_loss = {}
162
+ for task in task_list:
163
+ preds = torch.cat(task_pairs[task][1], dim=0)
164
+ labs = torch.cat(task_pairs[task][0], dim=0)
165
+ preds_flat = preds.view(-1, preds.size(-1)) # Reshape to (batch_size * sequence_length, num_classes)
166
+ labs_flat = labs.view(-1) # Reshape to (batch_size * sequence_length)
167
+
168
+ #print("task ", task)
169
+ #print("preds shape ", preds.shape)
170
+ #print("labs shape ", labs.shape)
171
+ if task == "lmm_projector":
172
+ task_loss[task] = CrossEntropyLoss()(preds,labs)
173
+ else:
174
+ task_loss[task] = nn.BCEWithLogitsLoss()(preds, labs)
175
+ # print("task losses ", task_loss)
176
+
177
+ total_loss = None
178
+ if labels != None:
179
+ total_task_loss = None
180
+ for task in task_list:
181
+ if self.modalities[0].tasks["task_heads"][task]["weight"] != 0.0:
182
+ if total_task_loss != None:
183
+ total_task_loss += self.modalities[0].tasks["task_heads"][task]["weight"]*task_loss[task]
184
+ else:
185
+ total_task_loss = self.modalities[0].tasks["task_heads"][task]["weight"]*task_loss[task]
186
+
187
+ if total_task_loss != None:
188
+ total_loss = self.modalities[0].tasks["task_heads"]["lmm_projector"]["weight"]*loss + total_task_loss
189
+ else:
190
+ total_loss = loss
191
+
192
+ if not return_dict:
193
+ output = (logits,) + outputs[1:]
194
+ return (total_loss,) + output if total_loss is not None else output
195
+
196
+ return CausalLMOutputWithPast(
197
+ loss=total_loss,
198
+ logits=logits,
199
+ past_key_values=outputs.past_key_values,
200
+ hidden_states=outputs.hidden_states,
201
+ attentions=outputs.attentions,
202
+ )
203
+
204
+ def prepare_inputs_for_generation(
205
+ self,
206
+ input_ids,
207
+ past_key_values=None,
208
+ attention_mask=None,
209
+ inputs_embeds=None,
210
+ modality_inputs=None,
211
+ **kwargs
212
+ ):
213
+ #print("hoooo", past_key_values)
214
+
215
+ #past_key_values = None
216
+ if past_key_values:
217
+ input_ids = input_ids[:, -1:]
218
+
219
+ if inputs_embeds is not None:
220
+ raise ValueError("inputs_embeds not supported")
221
+
222
+ model_inputs = {
223
+ "input_ids": input_ids,
224
+ "position_ids": None,
225
+ "past_key_values": past_key_values,
226
+ "use_cache": kwargs.get("use_cache"),
227
+ "attention_mask": attention_mask,
228
+ **(modality_inputs or {}),
229
+ }
230
+
231
+ return model_inputs
232
+
233
+
234
+ AutoConfig.register("mistral-lmm", MistralLMMConfig)
235
+ AutoModelForCausalLM.register(MistralLMMConfig, MistralLMMForCausalLM)
src/sonicverse/multi_token/modalities/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from multi_token.model_utils import MultiTaskType
2
+ from multi_token.modalities.vision_clip import (
3
+ CLIPVisionModality,
4
+ OUTPUT_LAYER as CLIP_POOL_LAYER,
5
+ )
6
+ from multi_token.modalities.imagebind import ImageBindModality
7
+ from multi_token.modalities.document_gte import DocumentGTEModality
8
+ from multi_token.modalities.audio_whisper import WhisperAudioModality
9
+ from multi_token.modalities.audio_clap import CLAPAudioModality
10
+ from multi_token.modalities.video_xclip import XCLIPVideoModality
11
+ from multi_token.modalities.audio_descript import DescriptAudioModality
12
+ from multi_token.modalities.audio_mert import MERTAudioModality
13
+
14
+ MODALITY_BUILDERS = {
15
+ "vision_clip": lambda: [CLIPVisionModality()],
16
+ "vision_clip_pool": lambda: [
17
+ CLIPVisionModality(feature_layer=CLIP_POOL_LAYER, num_tokens_output=10)
18
+ ],
19
+ "audio_whisper": lambda: [
20
+ WhisperAudioModality(
21
+ num_tokens_output=10, model_name_or_path="openai/whisper-small"
22
+ )
23
+ ],
24
+ "audio_mert": lambda use_multi_task=MultiTaskType.NO_MULTI_TASK, tasks_config=None :[MERTAudioModality(use_multi_task=use_multi_task, tasks_config=tasks_config, num_tokens_output=60, hidden_dim=32, num_conv_layers = 3, num_mlp_layers = 2)],
25
+ "audio_clap": lambda use_multi_task=MultiTaskType.NO_MULTI_TASK, tasks_config=None :[CLAPAudioModality(use_multi_task=use_multi_task, tasks_config=tasks_config, num_tokens_output=20)],
26
+ "audio_descript": lambda use_multi_task=MultiTaskType.NO_MULTI_TASK, tasks_config=None : [DescriptAudioModality(use_multi_task=use_multi_task, tasks_config=tasks_config, num_projector_conv_layers=1, num_projector_mlp_layers=1, num_tokens_output=60, codebooks=96)],
27
+ "video_xclip": lambda: [XCLIPVideoModality(num_tokens_output=10)],
28
+ "imagebind": lambda: [ImageBindModality()],
29
+ "document_gte": lambda: [DocumentGTEModality()],
30
+ "document_gte_x16": lambda: [DocumentGTEModality(num_tokens_output=32)],
31
+ }
src/sonicverse/multi_token/modalities/audio_clap.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import ClapModel, ClapProcessor
6
+
7
+ from multi_token.model_utils import MultiTaskType
8
+ from multi_token.data_tools import load_audio
9
+ from multi_token.modalities.base_modality import Modality
10
+ from multi_token.modalities.projectors import (
11
+ build_mlp_vector_projector, build_mt_vector_projector, MultiTaskModel
12
+ )
13
+
14
+ import json
15
+
16
+ OUTPUT_EMB_SIZE = 512
17
+
18
+
19
+ class CLAPAudioModule(nn.Module):
20
+ def __init__(self, model_name_or_path: str):
21
+ super().__init__()
22
+ self.model_name_or_path = model_name_or_path
23
+ self.model = None
24
+ self.processor = None
25
+
26
+ self.load_model()
27
+
28
+ def load_model(self):
29
+ self.model = ClapModel.from_pretrained(self.model_name_or_path)
30
+ self.processor = ClapProcessor.from_pretrained(self.model_name_or_path)
31
+ self.model.requires_grad_(False)
32
+
33
+ @torch.no_grad()
34
+ def forward(self, audios) -> torch.Tensor:
35
+ embs = []
36
+ for audio_features in audios:
37
+ features = self.model.get_audio_features(
38
+ input_features=audio_features["input_features"].to(torch.float32),
39
+ is_longer=audio_features["is_longer"],
40
+ )
41
+ embs.append(features)
42
+ embs = torch.stack(embs)
43
+ return embs.view(-1, 1, OUTPUT_EMB_SIZE)
44
+
45
+ @property
46
+ def dtype(self):
47
+ return self.model.dtype
48
+
49
+ @property
50
+ def device(self):
51
+ return self.model.device
52
+
53
+
54
+ class CLAPAudioModality(Modality):
55
+ def __init__(
56
+ self,
57
+ model_name_or_path: str = "laion/clap-htsat-fused",
58
+ num_projector_layers: int = 2,
59
+ num_tokens_output: int = 10,
60
+ use_multi_task: int = MultiTaskType.NO_MULTI_TASK,
61
+ tasks_config: str = None
62
+ ):
63
+ self.model_name_or_path = model_name_or_path
64
+ self.module = CLAPAudioModule(model_name_or_path=self.model_name_or_path)
65
+ self.num_projector_layers = num_projector_layers
66
+ self.num_tokens_output = num_tokens_output
67
+ self.dtype = torch.float32
68
+ self.use_multi_task = use_multi_task
69
+ self.tasks = None
70
+ if self.use_multi_task != MultiTaskType.NO_MULTI_TASK:
71
+ with open(tasks_config, 'r') as f:
72
+ self.tasks = json.load(f)
73
+
74
+ print("Tasks :", self.tasks)
75
+
76
+ def build_projector(self, lm_hidden_size: int) -> nn.Module:
77
+ if self.use_multi_task == MultiTaskType.PROJECTED_MULTI_TASK:
78
+ return MultiTaskModel(OUTPUT_EMB_SIZE, self.tasks)
79
+ elif self.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
80
+ return build_mt_vector_projector(
81
+ # return build_mlp_vector_projector(
82
+ input_hidden_size=OUTPUT_EMB_SIZE,
83
+ lm_hidden_size=lm_hidden_size,
84
+ # num_layers=self.num_projector_layers,
85
+ # num_tokens=self.num_tokens_output,
86
+ # )
87
+ tasks = self.tasks
88
+ )
89
+ # )["llm_projector"]
90
+ else:
91
+ return build_mlp_vector_projector(
92
+ input_hidden_size=OUTPUT_EMB_SIZE,
93
+ lm_hidden_size=lm_hidden_size,
94
+ num_layers=self.num_projector_layers,
95
+ num_tokens=self.num_tokens_output,
96
+ )
97
+
98
+ @property
99
+ def name(self) -> str:
100
+ return "audio_clap"
101
+
102
+ @property
103
+ def token(self) -> str:
104
+ return "<sound>"
105
+
106
+ @property
107
+ def data_key(self) -> str:
108
+ return "sounds"
109
+
110
+ @property
111
+ def token_width(self) -> int:
112
+ return self.num_tokens_output
113
+
114
+ def to(self, dtype: torch.dtype, device: torch.device) -> "CLAPAudioModality":
115
+ self.dtype = dtype
116
+ self.module.to(device=device)
117
+ return self
118
+
119
+ def preprocess_rows(self, rows: List[Dict]) -> List[Optional[Dict]]:
120
+ row_values = []
121
+ for row in rows:
122
+ audios = []
123
+ for audio_dict in row[self.data_key]:
124
+ audio_dict = load_audio(
125
+ audio_dict,
126
+ target_sampling_rate=self.module.processor.feature_extractor.sampling_rate,
127
+ )
128
+ audio_processed = self.module.processor(
129
+ audios=audio_dict["array"],
130
+ return_tensors="pt",
131
+ sampling_rate=audio_dict["sampling_rate"],
132
+ )
133
+ audios.append(audio_processed)
134
+ row_values.append(audios)
135
+ return row_values
136
+
137
+ @torch.no_grad()
138
+ def forward(self, encoded_values: List[torch.Tensor]) -> List[torch.Tensor]:
139
+ audio_features = []
140
+ for audio_batch in encoded_values:
141
+ audio_features.append(self.module.forward(audio_batch).to(dtype=self.dtype))
142
+ return audio_features
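A minimal usage sketch for the modality above (not part of the commit): the lm_hidden_size of 4096 is an assumed LLM hidden size, and the random tensor stands in for a real CLAP clip embedding.

# Hedged sketch: turn one CLAP clip embedding into <sound> pseudo-token embeddings.
import torch
from multi_token.modalities.audio_clap import CLAPAudioModality, OUTPUT_EMB_SIZE

modality = CLAPAudioModality(num_tokens_output=20)          # downloads laion/clap-htsat-fused
projector = modality.build_projector(lm_hidden_size=4096)   # assumed LLM hidden size

clip_embedding = torch.randn(1, 1, OUTPUT_EMB_SIZE)         # stand-in for CLAPAudioModule output
pseudo_tokens = projector(clip_embedding)                   # shape (1, 20, 4096)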
src/sonicverse/multi_token/modalities/audio_descript.py ADDED
@@ -0,0 +1,169 @@
1
+ from typing import Dict, List, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import dac
6
+ from audiotools import AudioSignal
7
+
8
+ from multi_token.model_utils import MultiTaskType
9
+ from multi_token.data_tools import load_audio_signal
10
+ from multi_token.modalities.base_modality import Modality
11
+ from multi_token.modalities.projectors import (
12
+ build_mlp_vector_projector, build_mt_vector_projector, build_attentive_cnn_projector, build_cnn_mlp_projector, build_multi_layer_cnn_mlp_projector, MultiTaskModel
13
+ )
14
+
15
+ import json
16
+
17
+ OUTPUT_FRAMES_SIZE = 512
18
+ # OUTPUT_EMB_SIZE = 2048
19
+ OUTPUT_EMB_CHANNELS = 96
20
+
21
+ class DescriptAudioModule(nn.Module):
22
+ def __init__(self, model_name_or_path: str, codebooks = 4):
23
+ super().__init__()
24
+ self.model_name_or_path = model_name_or_path
25
+ self.model = None
26
+ self.processor = None
27
+ self.codebooks = codebooks
28
+
29
+ self.load_model()
30
+
31
+ def load_model(self):
32
+ # self.model = ClapModel.from_pretrained(self.model_name_or_path)
33
+ self.model = dac.DAC.load(self.model_name_or_path)
34
+
35
+ def forward(self, audios) -> torch.Tensor:
36
+ embs = []
37
+ for audio_features in audios:
38
+ # print("Audio features sample rate ", audio_features[0].sample_rate)
39
+ x = self.model.preprocess(audio_features[0].audio_data, audio_features[0].sample_rate)
40
+ z, codes, latents, _, _ = self.model.encode(x)
41
+
42
+ # print("latents og shape ", latents.shape)
43
+ # If the tensor is larger than desired_shape, crop it
44
+ if latents.shape[2] > OUTPUT_FRAMES_SIZE:
45
+ latents = latents[:, :, :OUTPUT_FRAMES_SIZE]
46
+ # If the tensor is smaller than desired_shape, pad it
47
+ elif latents.shape[2] < OUTPUT_FRAMES_SIZE:
48
+ pad_width = (0, OUTPUT_FRAMES_SIZE - latents.shape[2])
49
+ latents = torch.nn.functional.pad(latents, pad_width)
50
+ # print("Codes new shape ", codes_new.shape)
51
+
52
+ # print("latents int shape ", latents.shape)
53
+
54
+ latents = latents[0][:self.codebooks]
55
+
56
+ # print("latents final shape ", latents.shape)
57
+
58
+ embs.append(latents)
59
+
60
+ embs = torch.stack(embs)
61
+
62
+ # output_embs = embs.view(-1, 1, OUTPUT_FRAMES_SIZE*self.codebooks)
63
+ # print("embs post view shape ", output_embs.shape)
64
+
65
+ return embs
66
+
67
+ @property
68
+ def dtype(self):
69
+ return self.model.dtype
70
+
71
+ @property
72
+ def device(self):
73
+ return self.model.device
74
+
75
+
76
+ class DescriptAudioModality(Modality):
77
+ def __init__(
78
+ self,
79
+ model_name_or_path: str = dac.utils.download(model_type="16khz"),
80
+ num_projector_conv_layers: int = 2,
81
+ num_projector_mlp_layers: int = 2,
82
+ num_tokens_output: int = 10,
83
+ codebooks: int = 96,
84
+ use_multi_task: MultiTaskType = MultiTaskType.NO_MULTI_TASK,
85
+ tasks_config: str = None
86
+ ):
87
+ self.model_name_or_path = model_name_or_path
88
+ self.module = DescriptAudioModule(model_name_or_path=self.model_name_or_path, codebooks=codebooks)
89
+ self.num_projector_conv_layers = num_projector_conv_layers
90
+ self.num_projector_mlp_layers = num_projector_mlp_layers
91
+ self.num_tokens_output = num_tokens_output
92
+ self.dtype = torch.float32
93
+ self.codebooks = codebooks
94
+ self.use_multi_task = use_multi_task
95
+ self.tasks = None
96
+ if self.use_multi_task != MultiTaskType.NO_MULTI_TASK:
97
+ with open(tasks_config, 'r') as f:
98
+ self.tasks = json.load(f)
99
+
100
+ print("Tasks :", self.tasks)
101
+
102
+
103
+ def build_projector(self, lm_hidden_size: int) -> nn.Module:
104
+ if self.use_multi_task == MultiTaskType.PROJECTED_MULTI_TASK:
105
+ projector = MultiTaskModel(OUTPUT_EMB_CHANNELS, 1, True, -1, False, self.tasks)
106
+ print("projector ", projector)
107
+ return projector
108
+ elif self.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
109
+ return build_mt_vector_projector(
110
+ # return build_mlp_vector_projector(
111
+ input_hidden_size=OUTPUT_EMB_CHANNELS,
112
+ lm_hidden_size=lm_hidden_size,
113
+ # num_layers=self.num_projector_layers,
114
+ # num_tokens=self.num_tokens_output,
115
+ # )
116
+ tasks = self.tasks
117
+ )
118
+ # )["llm_projector"]
119
+ else:
120
+ return build_multi_layer_cnn_mlp_projector(
121
+ input_channels = OUTPUT_EMB_CHANNELS,
122
+ input_size = OUTPUT_EMB_SIZE,
123
+ num_feature_layers= OUTPUT_FEATURE_LAYERS,
124
+ lm_hidden_size = lm_hidden_size,
125
+ num_tokens = self.num_tokens_output,
126
+ hidden_dim = self.hidden_dim,
127
+ num_conv_layers = self.num_conv_layers,
128
+ num_mlp_layers = self.num_mlp_layers
129
+ )
130
+
131
+ @property
132
+ def name(self) -> str:
133
+ return "audio_descript"
134
+
135
+ @property
136
+ def token(self) -> str:
137
+ return "<sound>"
138
+
139
+ @property
140
+ def data_key(self) -> str:
141
+ return "sounds"
142
+
143
+ @property
144
+ def token_width(self) -> int:
145
+ return self.num_tokens_output
146
+
147
+ def to(self, dtype: torch.dtype, device: torch.device) -> "DescriptAudioModality":
148
+ self.dtype = dtype
149
+ self.module.to(device=device)
150
+ return self
151
+
152
+ def preprocess_rows(self, rows: List[Dict]) -> List[Optional[Dict]]:
153
+ row_values = []
154
+ for row in rows:
155
+ audios = []
156
+ for audio_dict in row[self.data_key]:
157
+ audio_dict = load_audio_signal(
158
+ audio_dict
159
+ )
160
+ audios.append(audio_dict["array"])
161
+ row_values.append(audios)
162
+ return row_values
163
+
164
+ @torch.no_grad()
165
+ def forward(self, encoded_values: List[torch.Tensor]) -> List[torch.Tensor]:
166
+ audio_features = []
167
+ for audio_batch in encoded_values:
168
+ audio_features.append(self.module.forward(audio_batch).to(dtype=self.dtype))
169
+ return audio_features
src/sonicverse/multi_token/modalities/audio_descript_bu.py ADDED
@@ -0,0 +1,133 @@
1
+ from typing import Dict, List, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import dac
6
+ from audiotools import AudioSignal
7
+
8
+
9
+ from multi_token.data_tools import load_audio_signal
10
+ from multi_token.modalities.base_modality import Modality
11
+ from multi_token.modalities.projectors import (
12
+ build_mlp_vector_projector, build_attentive_cnn_projector, build_cnn_mlp_projector
13
+ )
14
+
15
+ OUTPUT_FRAMES_SIZE = 512
16
+ # OUTPUT_EMB_SIZE = 2048
17
+
18
+ class DescriptAudioModule(nn.Module):
19
+ def __init__(self, model_name_or_path: str, codebooks = 4):
20
+ super().__init__()
21
+ self.model_name_or_path = model_name_or_path
22
+ self.model = None
23
+ self.processor = None
24
+ self.codebooks = codebooks
25
+
26
+ self.load_model()
27
+
28
+ def load_model(self):
29
+ # self.model = ClapModel.from_pretrained(self.model_name_or_path)
30
+ self.model = dac.DAC.load(self.model_name_or_path)
31
+
32
+ def forward(self, audios) -> torch.Tensor:
33
+ embs = []
34
+ for audio_features in audios:
35
+ x = self.model.preprocess(audio_features[0].audio_data, audio_features[0].sample_rate)
36
+ z, codes, latents, _, _ = self.model.encode(x)
37
+
38
+ # If the tensor is larger than desired_shape, crop it
39
+ if codes.shape[2] > OUTPUT_FRAMES_SIZE:
40
+ codes = codes[:, :, :OUTPUT_FRAMES_SIZE]
41
+ # If the tensor is smaller than desired_shape, pad it
42
+ elif codes.shape[2] < OUTPUT_FRAMES_SIZE:
43
+ pad_width = (0, OUTPUT_FRAMES_SIZE - codes.shape[2])
44
+ codes = torch.nn.functional.pad(codes, pad_width)
45
+ # print("Codes new shape ", codes_new.shape)
46
+
47
+ codes_of_interest = codes[0][:self.codebooks]
48
+
49
+ embs.append(codes_of_interest)
50
+
51
+ embs = torch.stack(embs)
52
+
53
+ # output_embs = embs.view(-1, 1, OUTPUT_FRAMES_SIZE*self.codebooks)
54
+ # print("embs post view shape ", output_embs.shape)
55
+
56
+ return embs
57
+
58
+ @property
59
+ def dtype(self):
60
+ return self.model.dtype
61
+
62
+ @property
63
+ def device(self):
64
+ return self.model.device
65
+
66
+
67
+ class DescriptAudioModality(Modality):
68
+ def __init__(
69
+ self,
70
+ model_name_or_path: str = dac.utils.download(model_type="16khz"),
71
+ num_projector_conv_layers: int = 2,
72
+ num_projector_mlp_layers: int = 2,
73
+ num_tokens_output: int = 10,
74
+ codebooks: int = 4
75
+ ):
76
+ self.model_name_or_path = model_name_or_path
77
+ self.module = DescriptAudioModule(model_name_or_path=self.model_name_or_path, codebooks=codebooks)
78
+ self.num_projector_conv_layers = num_projector_conv_layers
79
+ self.num_projector_mlp_layers = num_projector_mlp_layers
80
+ self.num_tokens_output = num_tokens_output
81
+ self.dtype = torch.float32
82
+ self.codebooks = codebooks
83
+
84
+ def build_projector(self, lm_hidden_size: int) -> nn.Module:
85
+ return build_cnn_mlp_projector(
86
+ input_channels=self.codebooks,
87
+ input_size=OUTPUT_FRAMES_SIZE,
88
+ lm_hidden_size=lm_hidden_size,
89
+ num_tokens=self.num_tokens_output,
90
+ hidden_dim=64,
91
+ num_conv_layers=self.num_projector_conv_layers,
92
+ num_mlp_layers=self.num_projector_mlp_layers
93
+ )
94
+
95
+ @property
96
+ def name(self) -> str:
97
+ return "audio_descript"
98
+
99
+ @property
100
+ def token(self) -> str:
101
+ return "<sound>"
102
+
103
+ @property
104
+ def data_key(self) -> str:
105
+ return "sounds"
106
+
107
+ @property
108
+ def token_width(self) -> int:
109
+ return self.num_tokens_output
110
+
111
+ def to(self, dtype: torch.dtype, device: torch.device) -> "DescriptAudioModality":
112
+ self.dtype = dtype
113
+ self.module.to(device=device)
114
+ return self
115
+
116
+ def preprocess_rows(self, rows: List[Dict]) -> List[Optional[Dict]]:
117
+ row_values = []
118
+ for row in rows:
119
+ audios = []
120
+ for audio_dict in row[self.data_key]:
121
+ audio_dict = load_audio_signal(
122
+ audio_dict
123
+ )
124
+ audios.append(audio_dict["array"])
125
+ row_values.append(audios)
126
+ return row_values
127
+
128
+ @torch.no_grad()
129
+ def forward(self, encoded_values: List[torch.Tensor]) -> List[torch.Tensor]:
130
+ audio_features = []
131
+ for audio_batch in encoded_values:
132
+ audio_features.append(self.module.forward(audio_batch).to(dtype=self.dtype))
133
+ return audio_features
src/sonicverse/multi_token/modalities/audio_mert.py ADDED
@@ -0,0 +1,162 @@
1
+ from typing import Dict, List, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import Wav2Vec2FeatureExtractor, AutoModel
6
+
7
+ from multi_token.model_utils import MultiTaskType
8
+ from multi_token.data_tools import load_audio
9
+ from multi_token.modalities.base_modality import Modality
10
+ from multi_token.modalities.projectors import (
11
+ build_mlp_vector_projector, build_mt_vector_projector, build_multi_layer_cnn_mlp_projector, MultiTaskModel
12
+ )
13
+ from multi_token.modalities.multi_task_projector_shared import MultiTaskSharedModel
14
+
15
+ import json
16
+
17
+ OUTPUT_EMB_CHANNELS = 768 #1024
18
+ OUTPUT_EMB_SIZE = 760
19
+ OUTPUT_FEATURE_LAYERS = 13 #25
20
+
21
+ cache_dir="/home/ubuntu/.cache/"
22
+
23
+ class MERTAudioModule(nn.Module):
24
+ def __init__(self, model_name_or_path: str):
25
+ super().__init__()
26
+ self.model_name_or_path = model_name_or_path
27
+ self.model = None
28
+ self.processor = None
29
+
30
+ self.load_model()
31
+
32
+ def load_model(self):
33
+ self.model = AutoModel.from_pretrained(self.model_name_or_path, trust_remote_code=True, cache_dir=cache_dir)
34
+ self.processor = Wav2Vec2FeatureExtractor.from_pretrained(self.model_name_or_path, trust_remote_code=True, cache_dir=cache_dir)
35
+ self.model.requires_grad_(False)
36
+
37
+ @torch.no_grad()
38
+ def forward(self, audios) -> torch.Tensor:
39
+ embs = []
40
+ for audio_features in audios:
41
+ outputs = self.model(**audio_features.to(torch.float32), output_hidden_states=True)
42
+ features = torch.stack(outputs.hidden_states).squeeze()
43
+ embs.append(features)
44
+ embs = torch.stack(embs)
45
+ embs = embs.squeeze()
46
+ padding_needed = OUTPUT_EMB_SIZE - embs.shape[1]
47
+ embs = torch.nn.functional.pad(embs, (0, 0, 0, padding_needed, 0, 0))
48
+ return embs
49
+
50
+ @property
51
+ def dtype(self):
52
+ return self.model.dtype
53
+
54
+ @property
55
+ def device(self):
56
+ return self.model.device
57
+
58
+
59
+ class MERTAudioModality(Modality):
60
+ def __init__(
61
+ self,
62
+ model_name_or_path: str = "m-a-p/MERT-v1-95M",
63
+ num_tokens_output: int = 10,
64
+ hidden_dim: int = 32,
65
+ num_conv_layers: int = 5,
66
+ num_mlp_layers: int = 5,
67
+ use_multi_task: MultiTaskType = MultiTaskType.NO_MULTI_TASK,
68
+ tasks_config: str = None
69
+ ):
70
+ self.model_name_or_path = model_name_or_path
71
+ self.module = MERTAudioModule(model_name_or_path=self.model_name_or_path)
72
+ self.num_tokens_output = num_tokens_output
73
+ self.hidden_dim = hidden_dim
74
+ self.num_conv_layers = num_conv_layers
75
+ self.num_mlp_layers = num_mlp_layers
76
+ self.dtype = torch.float32
77
+ self.use_multi_task = use_multi_task
78
+ self.tasks = None
79
+ if self.use_multi_task != MultiTaskType.NO_MULTI_TASK:
80
+ with open(tasks_config, 'r') as f:
81
+ self.tasks = json.load(f)
82
+
83
+ print("Tasks :", self.tasks)
84
+
85
+ # all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
86
+ # print(all_layer_hidden_states.shape) # [25 layer, Time steps, 1024 feature_dim]
87
+ # time_reduced_hidden_states = all_layer_hidden_states.mean(-2)
88
+ # print(time_reduced_hidden_states.shape) # [25, 1024]
89
+
90
+ def build_projector(self, lm_hidden_size: int) -> nn.Module:
91
+ if self.use_multi_task == MultiTaskType.PROJECTED_MULTI_TASK:
92
+ projector = MultiTaskSharedModel(self.tasks)
93
+ print("projector ", projector)
94
+ return projector
95
+ elif self.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
96
+ return build_mt_vector_projector(
97
+ # return build_mlp_vector_projector(
98
+ input_hidden_size=OUTPUT_EMB_SIZE,
99
+ lm_hidden_size=lm_hidden_size,
100
+ # num_layers=self.num_projector_layers,
101
+ # num_tokens=self.num_tokens_output,
102
+ # )
103
+ tasks = self.tasks
104
+ )
105
+ # )["llm_projector"]
106
+ else:
107
+ return build_multi_layer_cnn_mlp_projector(
108
+ input_channels = OUTPUT_EMB_CHANNELS,
109
+ input_size = OUTPUT_EMB_SIZE,
110
+ num_feature_layers= OUTPUT_FEATURE_LAYERS,
111
+ lm_hidden_size = lm_hidden_size,
112
+ num_tokens = self.num_tokens_output,
113
+ hidden_dim = self.hidden_dim,
114
+ num_conv_layers = self.num_conv_layers,
115
+ num_mlp_layers = self.num_mlp_layers
116
+ )
117
+
118
+ @property
119
+ def name(self) -> str:
120
+ return "audio_mert"
121
+
122
+ @property
123
+ def token(self) -> str:
124
+ return "<sound>"
125
+
126
+ @property
127
+ def data_key(self) -> str:
128
+ return "sounds"
129
+
130
+ @property
131
+ def token_width(self) -> int:
132
+ return self.num_tokens_output
133
+
134
+ def to(self, dtype: torch.dtype, device: torch.device) -> "MERTAudioModality":
135
+ self.dtype = dtype
136
+ self.module.to(device=device)
137
+ return self
138
+
139
+ def preprocess_rows(self, rows: List[Dict]) -> List[Optional[Dict]]:
140
+ row_values = []
141
+ for row in rows:
142
+ audios = []
143
+ for audio_dict in row[self.data_key]:
144
+ audio_dict = load_audio(
145
+ audio_dict,
146
+ target_sampling_rate=self.module.processor.sampling_rate,
147
+ )
148
+ audio_processed = self.module.processor(
149
+ audio_dict["array"],
150
+ return_tensors="pt",
151
+ sampling_rate=audio_dict["sampling_rate"],
152
+ )
153
+ audios.append(audio_processed)
154
+ row_values.append(audios)
155
+ return row_values
156
+
157
+ @torch.no_grad()
158
+ def forward(self, encoded_values: List[torch.Tensor]) -> List[torch.Tensor]:
159
+ audio_features = []
160
+ for audio_batch in encoded_values:
161
+ audio_features.append(self.module.forward(audio_batch).to(dtype=self.dtype))
162
+ return audio_features
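For orientation, a small sketch (not part of the commit) of the shape contract MERTAudioModule.forward enforces: MERT-v1-95M yields 13 hidden-state layers (12 transformer layers plus the input embedding), and clips are zero-padded along time to OUTPUT_EMB_SIZE frames. The 374 time steps below are illustrative.

# Hedged sketch: MERT hidden states are padded along time to OUTPUT_EMB_SIZE frames.
import torch
import torch.nn.functional as F

layers, time_steps, feat_dim = 13, 374, 768    # (OUTPUT_FEATURE_LAYERS, clip frames, OUTPUT_EMB_CHANNELS)
hidden_states = torch.randn(layers, time_steps, feat_dim)
padded = F.pad(hidden_states, (0, 0, 0, 760 - time_steps, 0, 0))
print(padded.shape)                            # torch.Size([13, 760, 768])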
src/sonicverse/multi_token/modalities/audio_mert_bu.py ADDED
@@ -0,0 +1,159 @@
1
+ from typing import Dict, List, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import Wav2Vec2FeatureExtractor, AutoModel
6
+
7
+ from multi_token.model_utils import MultiTaskType
8
+ from multi_token.data_tools import load_audio
9
+ from multi_token.modalities.base_modality import Modality
10
+ from multi_token.modalities.projectors import (
11
+ build_mlp_vector_projector, build_mt_vector_projector, build_multi_layer_cnn_mlp_projector, MultiTaskModel
12
+ )
13
+
14
+ import json
15
+
16
+ OUTPUT_EMB_CHANNELS = 1024
17
+ OUTPUT_EMB_SIZE = 760
18
+ OUTPUT_FEATURE_LAYERS = 25
19
+
20
+ class MERTAudioModule(nn.Module):
21
+ def __init__(self, model_name_or_path: str):
22
+ super().__init__()
23
+ self.model_name_or_path = model_name_or_path
24
+ self.model = None
25
+ self.processor = None
26
+
27
+ self.load_model()
28
+
29
+ def load_model(self):
30
+ self.model = AutoModel.from_pretrained(self.model_name_or_path, trust_remote_code=True)
31
+ self.processor = Wav2Vec2FeatureExtractor.from_pretrained(self.model_name_or_path,trust_remote_code=True)
32
+ self.model.requires_grad_(False)
33
+
34
+ @torch.no_grad()
35
+ def forward(self, audios) -> torch.Tensor:
36
+ embs = []
37
+ for audio_features in audios:
38
+ outputs = self.model(**audio_features.to(torch.float32), output_hidden_states=True)
39
+ features = torch.stack(outputs.hidden_states).squeeze()
40
+ embs.append(features)
41
+ embs = torch.stack(embs)
42
+ embs = embs.squeeze()
43
+ padding_needed = OUTPUT_EMB_SIZE - embs.shape[1]
44
+ embs = torch.nn.functional.pad(embs, (0, 0, 0, padding_needed, 0, 0))
45
+ return embs
46
+
47
+ @property
48
+ def dtype(self):
49
+ return self.model.dtype
50
+
51
+ @property
52
+ def device(self):
53
+ return self.model.device
54
+
55
+
56
+ class MERTAudioModality(Modality):
57
+ def __init__(
58
+ self,
59
+ model_name_or_path: str = "m-a-p/MERT-v1-330M",
60
+ num_tokens_output: int = 10,
61
+ hidden_dim: int = 32,
62
+ num_conv_layers: int = 5,
63
+ num_mlp_layers: int = 5,
64
+ use_multi_task: MultiTaskType = MultiTaskType.NO_MULTI_TASK,
65
+ tasks_config: str = None
66
+ ):
67
+ self.model_name_or_path = model_name_or_path
68
+ self.module = MERTAudioModule(model_name_or_path=self.model_name_or_path)
69
+ self.num_tokens_output = num_tokens_output
70
+ self.hidden_dim = hidden_dim
71
+ self.num_conv_layers = num_conv_layers
72
+ self.num_mlp_layers = num_mlp_layers
73
+ self.dtype = torch.float32
74
+ self.use_multi_task = use_multi_task
75
+ self.tasks = None
76
+ if self.use_multi_task != MultiTaskType.NO_MULTI_TASK:
77
+ with open(tasks_config, 'r') as f:
78
+ self.tasks = json.load(f)
79
+
80
+ print("Tasks :", self.tasks)
81
+
82
+ # all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
83
+ # print(all_layer_hidden_states.shape) # [25 layer, Time steps, 1024 feature_dim]
84
+ # time_reduced_hidden_states = all_layer_hidden_states.mean(-2)
85
+ # print(time_reduced_hidden_states.shape) # [25, 1024]
86
+
87
+ def build_projector(self, lm_hidden_size: int) -> nn.Module:
88
+ if self.use_multi_task == MultiTaskType.PROJECTED_MULTI_TASK:
89
+ projector = MultiTaskModel(OUTPUT_EMB_CHANNELS, OUTPUT_FEATURE_LAYERS, True, -1, True, self.tasks)  # assumed time_dimension/use_aggregator values to satisfy the 6-argument signature in projectors.py
90
+ print("projector ", projector)
91
+ return projector
92
+ elif self.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
93
+ return build_mt_vector_projector(
94
+ # return build_mlp_vector_projector(
95
+ input_hidden_size=OUTPUT_EMB_SIZE,
96
+ lm_hidden_size=lm_hidden_size,
97
+ # num_layers=self.num_projector_layers,
98
+ # num_tokens=self.num_tokens_output,
99
+ # )
100
+ tasks = self.tasks
101
+ )
102
+ # )["llm_projector"]
103
+ else:
104
+ return build_multi_layer_cnn_mlp_projector(
105
+ input_channels = OUTPUT_EMB_CHANNELS,
106
+ input_size = OUTPUT_EMB_SIZE,
107
+ num_feature_layers= OUTPUT_FEATURE_LAYERS,
108
+ lm_hidden_size = lm_hidden_size,
109
+ num_tokens = self.num_tokens_output,
110
+ hidden_dim = self.hidden_dim,
111
+ num_conv_layers = self.num_conv_layers,
112
+ num_mlp_layers = self.num_mlp_layers
113
+ )
114
+
115
+ @property
116
+ def name(self) -> str:
117
+ return "audio_mert"
118
+
119
+ @property
120
+ def token(self) -> str:
121
+ return "<sound>"
122
+
123
+ @property
124
+ def data_key(self) -> str:
125
+ return "sounds"
126
+
127
+ @property
128
+ def token_width(self) -> int:
129
+ return self.num_tokens_output
130
+
131
+ def to(self, dtype: torch.dtype, device: torch.device) -> "MERTAudioModality":
132
+ self.dtype = dtype
133
+ self.module.to(device=device)
134
+ return self
135
+
136
+ def preprocess_rows(self, rows: List[Dict]) -> List[Optional[Dict]]:
137
+ row_values = []
138
+ for row in rows:
139
+ audios = []
140
+ for audio_dict in row[self.data_key]:
141
+ audio_dict = load_audio(
142
+ audio_dict,
143
+ target_sampling_rate=self.module.processor.sampling_rate,
144
+ )
145
+ audio_processed = self.module.processor(
146
+ audio_dict["array"],
147
+ return_tensors="pt",
148
+ sampling_rate=audio_dict["sampling_rate"],
149
+ )
150
+ audios.append(audio_processed)
151
+ row_values.append(audios)
152
+ return row_values
153
+
154
+ @torch.no_grad()
155
+ def forward(self, encoded_values: List[torch.Tensor]) -> List[torch.Tensor]:
156
+ audio_features = []
157
+ for audio_batch in encoded_values:
158
+ audio_features.append(self.module.forward(audio_batch).to(dtype=self.dtype))
159
+ return audio_features
src/sonicverse/multi_token/modalities/audio_whisper.py ADDED
@@ -0,0 +1,120 @@
1
+ from typing import Dict, List, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import AutoFeatureExtractor, WhisperModel
6
+
7
+ from multi_token.data_tools import load_audio
8
+ from multi_token.modalities.base_modality import Modality
9
+ from multi_token.modalities.projectors import (
10
+ build_mlp_vector_projector,
11
+ )
12
+
13
+
14
+ OUTPUT_EMB_SIZE = 768
15
+
16
+
17
+ class WhisperAudioModule(nn.Module):
18
+ def __init__(self, model_name_or_path: str):
19
+ super().__init__()
20
+ self.model_name_or_path = model_name_or_path
21
+ self.model = None
22
+ self.feature_extractor = None
23
+
24
+ self.load_model()
25
+
26
+ def load_model(self):
27
+ self.model = WhisperModel.from_pretrained(self.model_name_or_path)
28
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained(
29
+ self.model_name_or_path
30
+ )
31
+ self.model.requires_grad_(False)
32
+
33
+ @torch.no_grad()
34
+ def forward(self, audios) -> torch.Tensor:
35
+ hidden_states = []
36
+ for i in range(audios.shape[0]):
37
+ decoder_input_ids = (
38
+ torch.tensor([[1]]) * self.model.config.decoder_start_token_id
39
+ )
40
+ last_hidden_state = self.model(
41
+ audios[i].to(device=self.device, dtype=self.dtype),
42
+ decoder_input_ids=decoder_input_ids.to(device=self.device),
43
+ ).last_hidden_state
44
+ hidden_states.append(last_hidden_state)
45
+ last_hidden_state = torch.stack(hidden_states)
46
+ return last_hidden_state.view(-1, 1, OUTPUT_EMB_SIZE)
47
+
48
+ @property
49
+ def dtype(self):
50
+ return self.model.dtype
51
+
52
+ @property
53
+ def device(self):
54
+ return self.model.device
55
+
56
+
57
+ class WhisperAudioModality(Modality):
58
+ def __init__(
59
+ self,
60
+ model_name_or_path: str = "openai/whisper-small",
61
+ num_projector_layers: int = 2,
62
+ num_tokens_output: int = 10,
63
+ ):
64
+ self.model_name_or_path = model_name_or_path
65
+ self.module = WhisperAudioModule(model_name_or_path=self.model_name_or_path)
66
+ self.num_projector_layers = num_projector_layers
67
+ self.num_tokens_output = num_tokens_output
68
+
69
+ def build_projector(self, lm_hidden_size: int) -> nn.Module:
70
+ return build_mlp_vector_projector(
71
+ input_hidden_size=OUTPUT_EMB_SIZE,
72
+ lm_hidden_size=lm_hidden_size,
73
+ num_layers=self.num_projector_layers,
74
+ num_tokens=self.num_tokens_output,
75
+ )
76
+
77
+ @property
78
+ def name(self) -> str:
79
+ return "audio_whisper"
80
+
81
+ @property
82
+ def token(self) -> str:
83
+ return "<speech>"
84
+
85
+ @property
86
+ def data_key(self) -> str:
87
+ return "speech_audios"
88
+
89
+ @property
90
+ def token_width(self) -> int:
91
+ return self.num_tokens_output
92
+
93
+ def to(self, dtype: torch.dtype, device: torch.device) -> "WhisperAudioModality":
94
+ self.module.to(dtype=dtype, device=device)
95
+ return self
96
+
97
+ def preprocess_rows(self, rows: List[Dict]) -> List[Optional[torch.Tensor]]:
98
+ row_values = []
99
+ for row in rows:
100
+ audios = []
101
+ for audio_dict in row[self.data_key]:
102
+ audio_dict = load_audio(
103
+ audio_dict,
104
+ target_sampling_rate=self.module.feature_extractor.sampling_rate,
105
+ )
106
+ audio_processed = self.module.feature_extractor(
107
+ audio_dict["array"],
108
+ return_tensors="pt",
109
+ sampling_rate=audio_dict["sampling_rate"],
110
+ ).input_features
111
+ audios.append(audio_processed)
112
+ row_values.append(torch.stack(audios) if len(audios) > 0 else None)
113
+ return row_values
114
+
115
+ @torch.no_grad()
116
+ def forward(self, encoded_values: List[torch.Tensor]) -> List[torch.Tensor]:
117
+ audio_features = []
118
+ for audio_batch in encoded_values:
119
+ audio_features.append(self.module.forward(audio_batch))
120
+ return audio_features
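A brief sketch (not part of the commit) of how WhisperAudioModule pools a clip: the model is run with a single decoder-start token, so each clip collapses to one 768-dimensional vector before projection. The one-second silent waveform is illustrative and the example assumes openai/whisper-small is downloadable.

# Hedged sketch: one clip in, one pooled (1, 1, 768) embedding out.
import numpy as np
from multi_token.modalities.audio_whisper import WhisperAudioModule

module = WhisperAudioModule("openai/whisper-small")
feats = module.feature_extractor(
    np.zeros(16000, dtype=np.float32), return_tensors="pt", sampling_rate=16000
).input_features                                # log-mel features, shape (1, 80, 3000)
embedding = module.forward(feats.unsqueeze(0))  # add the stack dimension forward iterates over
print(embedding.shape)                          # torch.Size([1, 1, 768])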
src/sonicverse/multi_token/modalities/base_modality.py ADDED
@@ -0,0 +1,48 @@
1
+ from typing import Dict, List, Optional, Any
2
+ from abc import ABC, abstractmethod
3
+ from functools import cached_property
4
+
5
+ import torch.nn as nn
6
+ import torch
7
+
8
+
9
+ class Modality(ABC):
10
+ @abstractmethod
11
+ def build_projector(self, lm_hidden_size: int) -> nn.Module:
12
+ pass
13
+
14
+ @property
15
+ @abstractmethod
16
+ def name(self) -> str:
17
+ pass
18
+
19
+ @property
20
+ @abstractmethod
21
+ def token(self) -> str:
22
+ pass
23
+
24
+ @property
25
+ @abstractmethod
26
+ def data_key(self) -> str:
27
+ pass
28
+
29
+ @property
30
+ @abstractmethod
31
+ def token_width(self) -> int:
32
+ pass
33
+
34
+ @cached_property
35
+ def token_idx(self) -> int:
36
+ hash_ = sum(ord(c) ** i for i, c in enumerate(self.token))
37
+ return -abs(hash_ % 10_000)
38
+
39
+ @abstractmethod
40
+ def preprocess_rows(self, rows: List[Dict]) -> List[Optional[Any]]:
41
+ pass
42
+
43
+ @abstractmethod
44
+ def forward(self, encoded_values: List[Any]) -> List[torch.Tensor]:
45
+ pass
46
+
47
+ def to(self, dtype: torch.dtype, device: torch.device) -> "Modality":
48
+ return self
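For reference, a skeletal subclass (not part of the commit) showing the minimum surface a new modality must implement; the 512-dimensional embedding, the <toy> token, and the random preprocessing are placeholders.

# Hedged sketch of the abstract interface defined above.
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn

from multi_token.modalities.base_modality import Modality


class ToyModality(Modality):
    def build_projector(self, lm_hidden_size: int) -> nn.Module:
        return nn.Linear(512, lm_hidden_size)             # placeholder projector

    @property
    def name(self) -> str:
        return "toy"

    @property
    def token(self) -> str:
        return "<toy>"

    @property
    def data_key(self) -> str:
        return "toys"

    @property
    def token_width(self) -> int:
        return 1

    def preprocess_rows(self, rows: List[Dict]) -> List[Optional[Any]]:
        return [torch.randn(len(row["toys"]), 512) for row in rows]

    def forward(self, encoded_values: List[Any]) -> List[torch.Tensor]:
        return [v.unsqueeze(1) for v in encoded_values]   # (num_items, 1, 512) per row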
src/sonicverse/multi_token/modalities/bu__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ from multi_token.model_utils import MultiTaskType
2
+ from multi_token.modalities.vision_clip import (
3
+ CLIPVisionModality,
4
+ OUTPUT_LAYER as CLIP_POOL_LAYER,
5
+ )
6
+ from multi_token.modalities.imagebind import ImageBindModality
7
+ from multi_token.modalities.document_gte import DocumentGTEModality
8
+ from multi_token.modalities.audio_whisper import WhisperAudioModality
9
+ from multi_token.modalities.audio_clap import CLAPAudioModality
10
+ from multi_token.modalities.video_xclip import XCLIPVideoModality
11
+ from multi_token.modalities.audio_descript import DescriptAudioModality
12
+ from multi_token.modalities.audio_mert import MERTAudioModality
13
+
14
+ MODALITY_BUILDERS = {
15
+ "vision_clip": lambda: [CLIPVisionModality()],
16
+ "vision_clip_pool": lambda: [
17
+ CLIPVisionModality(feature_layer=CLIP_POOL_LAYER, num_tokens_output=10)
18
+ ],
19
+ "audio_whisper": lambda: [
20
+ WhisperAudioModality(
21
+ num_tokens_output=10, model_name_or_path="openai/whisper-small"
22
+ )
23
+ ],
24
+ "audio_mert": lambda use_multi_task=MultiTaskType.NO_MULTI_TASK, tasks_config=None :[MERTAudioModality(use_multi_task=use_multi_task, tasks_config=tasks_config, num_tokens_output=60, hidden_dim=32, num_conv_layers = 3, num_mlp_layers = 2)],
25
+ "audio_clap": lambda use_multi_task=MultiTaskType.NO_MULTI_TASK, tasks_config=None :[CLAPAudioModality(use_multi_task=use_multi_task, tasks_config=tasks_config, num_tokens_output=20)],
26
+ "audio_descript": lambda: [DescriptAudioModality(num_projector_conv_layers=1, num_projector_mlp_layers=1, num_tokens_output=5, codebooks=12)],
27
+ "video_xclip": lambda: [XCLIPVideoModality(num_tokens_output=10)],
28
+ "imagebind": lambda: [ImageBindModality()],
29
+ "document_gte": lambda: [DocumentGTEModality()],
30
+ "document_gte_x16": lambda: [DocumentGTEModality(num_tokens_output=32)],
31
+ }
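A short sketch (not part of the commit) of how these builders are looked up, assuming the package __init__ exposes the same MODALITY_BUILDERS mapping as this backup copy; the tasks_config path is hypothetical.

# Hedged sketch: resolve a builder by name, then instantiate its modalities.
from multi_token.model_utils import MultiTaskType
from multi_token.modalities import MODALITY_BUILDERS

modalities = MODALITY_BUILDERS["audio_mert"](
    use_multi_task=MultiTaskType.PROJECTED_MULTI_TASK,
    tasks_config="configs/tasks.json",        # hypothetical path to a tasks file
)
for modality in modalities:
    print(modality.name, modality.token, modality.token_width)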
src/sonicverse/multi_token/modalities/document_gte.py ADDED
@@ -0,0 +1,144 @@
1
+ from typing import Dict, List
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import os
6
+ from functools import cache
7
+ from transformers import AutoTokenizer, AutoModel
8
+
9
+ from multi_token.modalities.base_modality import Modality
10
+ from multi_token.modalities.projectors import build_mlp_vector_projector
11
+
12
+ GTE_EMBEDDING_SIZE = 1024
13
+ GTE_CONTEXT_WINDOW = 512
14
+ GTE_DEFAULT_MODEL = "thenlper/gte-large"
15
+ DOCUMENT_GTE_FORCE_CPU = "DOCUMENT_GTE_FORCE_CPU"
16
+
17
+
18
+ def average_pool(
19
+ last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
20
+ ) -> torch.Tensor:
21
+ last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
22
+ return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
23
+
24
+
25
+ @cache
26
+ def _get_tokenizer(model_name_or_path: str = GTE_DEFAULT_MODEL):
27
+ return AutoTokenizer.from_pretrained(model_name_or_path)
28
+
29
+
30
+ def split_text_into_documents(text: str) -> List[str]:
31
+ from nltk.tokenize import sent_tokenize
32
+
33
+ tokenizer = _get_tokenizer(GTE_DEFAULT_MODEL)
34
+
35
+ sentences = sent_tokenize(text)
36
+ documents = [[]]
37
+
38
+ for sentence in sentences:
39
+ sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
40
+ if len(documents[-1]) + len(sentence_tokens) > GTE_CONTEXT_WINDOW:
41
+ documents.append([])
42
+ documents[-1].extend(sentence_tokens)
43
+
44
+ return [tokenizer.decode(doc) for doc in documents]
45
+
46
+
47
+ class DocumentGTEModule(nn.Module):
48
+ def __init__(self, model_name_or_path: str):
49
+ super().__init__()
50
+ self.feature_layer = -2
51
+ self.model_name_or_path = model_name_or_path
52
+
53
+ self.model = AutoModel.from_pretrained("thenlper/gte-large")
54
+ self.model.requires_grad_(False)
55
+
56
+ @torch.no_grad()
57
+ def forward(self, batch_dict) -> torch.Tensor:
58
+ outputs = self.model(**batch_dict)
59
+ embeddings = average_pool(
60
+ outputs.last_hidden_state, batch_dict["attention_mask"]
61
+ )
62
+ return embeddings
63
+
64
+ @property
65
+ def embedding_size(self):
66
+ return GTE_EMBEDDING_SIZE
67
+
68
+
69
+ class DocumentGTEModality(Modality):
70
+ def __init__(
71
+ self,
72
+ model_name_or_path: str = GTE_DEFAULT_MODEL,
73
+ num_projector_layers: int = 2,
74
+ num_tokens_output: int = 4,
75
+ ):
76
+ self.model_name_or_path = model_name_or_path
77
+ self.module = DocumentGTEModule(model_name_or_path=self.model_name_or_path)
78
+ self.tokenizer = _get_tokenizer(model_name_or_path)
79
+ self.num_projector_layers = num_projector_layers
80
+ self.num_tokens_output = num_tokens_output
81
+ self.dtype = torch.float32
82
+ self.device = "cpu"
83
+ self.document_gte_device = "cpu"
84
+
85
+ def build_projector(self, lm_hidden_size: int) -> nn.Module:
86
+ return build_mlp_vector_projector(
87
+ input_hidden_size=self.module.embedding_size,
88
+ lm_hidden_size=lm_hidden_size,
89
+ num_layers=self.num_projector_layers,
90
+ num_tokens=self.num_tokens_output,
91
+ )
92
+
93
+ @property
94
+ def name(self) -> str:
95
+ return "document_gte"
96
+
97
+ @property
98
+ def token(self) -> str:
99
+ return "<document>"
100
+
101
+ @property
102
+ def data_key(self) -> str:
103
+ return "documents"
104
+
105
+ @property
106
+ def token_width(self) -> int:
107
+ return self.num_tokens_output
108
+
109
+ def to(self, dtype: torch.dtype, device: torch.device) -> "DocumentGTEModality":
110
+ self.dtype = dtype
111
+ self.device = device
112
+ if DOCUMENT_GTE_FORCE_CPU not in os.environ:
113
+ # running out of VRAM on 24GB GPU
114
+ self.document_gte_device = device
115
+ self.module.to(device=self.document_gte_device)
116
+ return self
117
+
118
+ def preprocess_rows(self, rows: List[Dict]) -> List[Dict]:
119
+ row_values = []
120
+ for row in rows:
121
+ documents = []
122
+ for doc in row[self.data_key]:
123
+ documents.append(doc)
124
+ documents_tokenized = self.tokenizer(
125
+ documents,
126
+ max_length=GTE_CONTEXT_WINDOW,
127
+ padding=True,
128
+ truncation=True,
129
+ return_tensors="pt",
130
+ )
131
+ row_values.append(documents_tokenized)
132
+ return row_values
133
+
134
+ @torch.no_grad()
135
+ def forward(self, encoded_values: List[Dict]) -> List[torch.Tensor]:
136
+ outputs = []
137
+ for val in encoded_values:
138
+ outputs.append(
139
+ self.module.forward(val.to(device=self.document_gte_device))
140
+ .to(device=self.device, dtype=self.dtype)
141
+ .view(-1, 1, self.module.embedding_size)
142
+ )
143
+ # batch_size x num_items x 1 x embedding_size
144
+ return outputs
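A small sketch (not part of the commit) of the chunking helper defined above; it assumes the NLTK punkt data is available and thenlper/gte-large can be downloaded, and the sample text is made up.

# Hedged sketch: split long text into GTE-sized chunks, then embed them as one batch.
import nltk

nltk.download("punkt")                        # sentence tokenizer used by sent_tokenize

from multi_token.modalities.document_gte import (
    DocumentGTEModality,
    split_text_into_documents,
)

text = "A made-up report sentence. " * 400    # long enough to need several chunks
chunks = split_text_into_documents(text)      # each chunk fits GTE's 512-token window
modality = DocumentGTEModality()              # loads thenlper/gte-large
encoded = modality.preprocess_rows([{"documents": chunks}])
embeddings = modality.forward(encoded)        # one (num_chunks, 1, 1024) tensor per row
print(len(chunks), embeddings[0].shape)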
src/sonicverse/multi_token/modalities/imagebind.py ADDED
@@ -0,0 +1,153 @@
1
+ from typing import Dict, List
2
+ import os
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from multi_token.modalities.base_modality import Modality
8
+ from multi_token.modalities.projectors import build_mlp_vector_projector
9
+ from multi_token.data_tools import with_local_files
10
+
11
+ IMAGE_BIND_FORCE_CPU = "IMAGE_BIND_FORCE_CPU"
12
+ IMAGE_BIND_EMBEDDING_SIZE = 1024
13
+
14
+
15
+ class ImageBindModule(nn.Module):
16
+ def __init__(self):
17
+ super().__init__()
18
+ from imagebind.models import imagebind_model
19
+ from imagebind import data
20
+
21
+ data.BPE_PATH = os.path.join(
22
+ os.path.dirname(data.__file__), "..", "bpe", "bpe_simple_vocab_16e6.txt.gz"
23
+ )
24
+ self.model = imagebind_model.imagebind_huge(pretrained=True)
25
+ self.model.eval()
26
+ self.model.requires_grad_(False)
27
+
28
+ @torch.no_grad()
29
+ def forward(self, items: Dict) -> torch.Tensor:
30
+ forward_outs = self.model(items)
31
+ return forward_outs
32
+
33
+ @property
34
+ def embedding_size(self):
35
+ return IMAGE_BIND_EMBEDDING_SIZE
36
+
37
+
38
+ class ImageBindModality(Modality):
39
+ def __init__(
40
+ self,
41
+ num_projector_layers: int = 2,
42
+ num_tokens: int = 4,
43
+ preprocess_device: str = "cpu",
44
+ ):
45
+ self.module = ImageBindModule()
46
+ self.dtype = torch.float32
47
+ self.device = "cpu" # used for outputs
48
+ self.imagebind_device = "cpu" # used for imagebind model itself
49
+ self.preprocess_device = preprocess_device # used for preprocessing
50
+ self.num_projector_layers = num_projector_layers
51
+ self.num_tokens = num_tokens
52
+
53
+ def build_projector(self, lm_hidden_size: int) -> nn.Module:
54
+ return build_mlp_vector_projector(
55
+ self.module.embedding_size,
56
+ lm_hidden_size,
57
+ num_layers=self.num_projector_layers,
58
+ num_tokens=self.num_tokens,
59
+ )
60
+
61
+ @property
62
+ def name(self) -> str:
63
+ return "imagebind"
64
+
65
+ @property
66
+ def token(self) -> str:
67
+ return "<imagebind>"
68
+
69
+ @property
70
+ def data_key(self) -> str:
71
+ return "imagebinds"
72
+
73
+ @property
74
+ def token_width(self) -> int:
75
+ return self.num_tokens
76
+
77
+ def to(self, dtype: torch.dtype, device: torch.device) -> "ImageBindModality":
78
+ # we ignore dtype and sometimes device as well
79
+ self.device = device
80
+ self.dtype = dtype
81
+ if IMAGE_BIND_FORCE_CPU not in os.environ:
82
+ # running out of VRAM on 24GB GPU
83
+ self.module.to(device=device)
84
+ self.imagebind_device = device
85
+ return self
86
+
87
+ def preprocess_rows(self, rows: List[Dict]) -> List[List[Dict]]:
88
+ from imagebind.models.imagebind_model import ModalityType
89
+ from imagebind import data
90
+
91
+ row_values = []
92
+ for row in rows:
93
+ items = []
94
+ with with_local_files(row[self.data_key]) as item_paths:
95
+ for item_path in item_paths:
96
+ ib_modality = filename_to_imagebind_modality(item_path)
97
+ if ib_modality == ModalityType.TEXT:
98
+ items.append(
99
+ {
100
+ ModalityType.TEXT: data.load_and_transform_text(
101
+ [item_path], self.preprocess_device
102
+ )
103
+ }
104
+ )
105
+ elif ib_modality == ModalityType.VISION:
106
+ items.append(
107
+ {
108
+ ModalityType.VISION: data.load_and_transform_vision_data(
109
+ [item_path], self.preprocess_device
110
+ )
111
+ }
112
+ )
113
+ elif ib_modality == ModalityType.AUDIO:
114
+ items.append(
115
+ {
116
+ ModalityType.AUDIO: data.load_and_transform_audio_data(
117
+ [item_path], self.preprocess_device
118
+ )
119
+ }
120
+ )
121
+ else:
122
+ raise ValueError(f"Unknown modality type: {ib_modality}")
123
+ row_values.append(items)
124
+ return row_values
125
+
126
+ @torch.no_grad()
127
+ def forward(self, encoded_values: List[List[Dict]]) -> List[torch.Tensor]:
128
+ item_features = []
129
+ for item_batch in encoded_values:
130
+ item_batch_emb = []
131
+ for item in item_batch:
132
+ item = {
133
+ k: v.to(device=self.imagebind_device, dtype=torch.float32)
134
+ for k, v in item.items()
135
+ }
136
+ item_batch_emb.extend(list(self.module.forward(item).values()))
137
+ item_features.append(
138
+ torch.stack(item_batch_emb).to(device=self.device, dtype=self.dtype)
139
+ )
140
+ # batch_size x num_items x 1 x embedding_size
141
+ return item_features
142
+
143
+
144
+ def filename_to_imagebind_modality(fn: str) -> str:
145
+ from imagebind.models.imagebind_model import ModalityType
146
+
147
+ _, ext = os.path.splitext(fn)
148
+ if ext in {".wav"}:
149
+ return ModalityType.AUDIO
150
+ elif ext in {".jpg", ".png", ".jpeg"}:
151
+ return ModalityType.VISION
152
+ else:
153
+ return ModalityType.TEXT
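A tiny sketch (not part of the commit) of the extension-based routing helper at the end of this file; it assumes the imagebind package is installed and the filenames are invented.

# Hedged sketch: ImageBind modality type is chosen purely from the file extension.
from multi_token.modalities.imagebind import filename_to_imagebind_modality

for path in ["clip.wav", "cover.png", "notes.txt"]:
    print(path, "->", filename_to_imagebind_modality(path))   # audio, vision, text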
src/sonicverse/multi_token/modalities/multi_task_projector_shared.py ADDED
@@ -0,0 +1,321 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.autograd import Variable
4
+ import torch.nn.functional as F
5
+ from typing import Dict
6
+ import numpy as np
7
+
8
+ class CNN(nn.Module):
9
+ def __init__(self, input_channels = 25, num_class=15):
10
+ super(CNN, self).__init__()
11
+ self.aggregator = nn.Parameter(torch.randn((input_channels, 1,1), dtype=torch.float))
12
+ self.input_channels = input_channels
13
+
14
+ # init bn
15
+ self.bn_init = nn.BatchNorm2d(1)
16
+
17
+ # layer 1
18
+ self.conv_1 = nn.Conv2d(1, 64, 3, padding=1)
19
+ self.bn_1 = nn.BatchNorm2d(64)
20
+ self.mp_1 = nn.MaxPool2d((2, 4))
21
+
22
+ # layer 2
23
+ self.conv_2 = nn.Conv2d(64, 128, 3, padding=1)
24
+ self.bn_2 = nn.BatchNorm2d(128)
25
+ self.mp_2 = nn.MaxPool2d((2, 4))
26
+
27
+ # layer 3
28
+ self.conv_3 = nn.Conv2d(128, 128, 3, padding=1)
29
+ self.bn_3 = nn.BatchNorm2d(128)
30
+ self.mp_3 = nn.MaxPool2d((2, 4))
31
+
32
+ # layer 4
33
+ self.conv_4 = nn.Conv2d(128, 128, 3, padding=1)
34
+ self.bn_4 = nn.BatchNorm2d(128)
35
+ self.mp_4 = nn.MaxPool2d((3, 5))
36
+
37
+ # layer 5
38
+ self.conv_5 = nn.Conv2d(128, 64, 3, padding=1)
39
+ self.bn_5 = nn.BatchNorm2d(64)
40
+ self.mp_5 = nn.MaxPool2d((3, 3))
41
+
42
+ # classifier
43
+ self.dense = nn.Linear(640, num_class)
44
+ self.dropout = nn.Dropout(0.5)
45
+
46
+ def forward(self, x):
47
+ aggregator_weights = F.softmax(self.aggregator)
48
+ # aggregator_weights = aggregator_weights.view(self.input_channels, 1)
49
+ # print("0 x shape : ")
50
+ x = (x * aggregator_weights).sum(dim=0)
51
+
52
+ # print("aggregator_output shape ", x.shape)
53
+
54
+ x = x.unsqueeze(0).unsqueeze(0)
55
+
56
+ # print("1 x shape ", x.shape)
57
+ # init bn
58
+ x = self.bn_init(x)
59
+ # print("2 x shape ", x.shape)
60
+
61
+ # layer 1
62
+ x = self.mp_1(nn.ELU()(self.bn_1(self.conv_1(x))))
63
+ # print("3 x shape ", x.shape)
64
+
65
+ # layer 2
66
+ x = self.mp_2(nn.ELU()(self.bn_2(self.conv_2(x))))
67
+ # print("4 x shape ", x.shape)
68
+
69
+ # layer 3
70
+ x = self.mp_3(nn.ELU()(self.bn_3(self.conv_3(x))))
71
+ # print("5 x shape ", x.shape)
72
+
73
+ # layer 4
74
+ x = self.mp_4(nn.ELU()(self.bn_4(self.conv_4(x))))
75
+ # print("6 x shape ", x.shape)
76
+
77
+ # layer 5
78
+ x = self.mp_5(nn.ELU()(self.bn_5(self.conv_5(x))))
79
+ # print("7 x shape ", x.shape)
80
+
81
+ # classifier
82
+ x = x.view(x.size(0), -1)
83
+ # print("8 x shape ", x.shape)
84
+ x = self.dropout(x)
85
+ # print("9 x shape ", x.shape)
86
+ logit = nn.Sigmoid()(self.dense(x))
87
+ # print("logit shape ", logit.shape)
88
+
89
+ return logit
90
+
91
+
92
+ class MLP(nn.Module):
93
+ def __init__(self, input_channels=25, num_class=15):
94
+ super(MLP, self).__init__()
95
+ self.aggregator = nn.Parameter(torch.randn((input_channels, 1,1), dtype=torch.float))
96
+ self.input_channels = input_channels
97
+
98
+ self.hidden_layer_1 = nn.Linear(768, 512)
99
+ self.output = nn.Linear(512, num_class)
100
+ self.dropout = nn.Dropout(p=0.2)
101
+ self.loss = self.get_loss() # can return a dict of losses
102
+
103
+ def forward(self, x):
104
+ """
105
+ x: (B, L, T, H)
106
+ T=#chunks, can be 1 or several chunks
107
+ """
108
+
109
+ weights = F.softmax(self.aggregator, dim=1)
110
+ x = (x * weights).sum(dim=1)
111
+
112
+ x = x.mean(-2)
113
+
114
+ x = self.hidden_layer_1(x)
115
+ x = F.relu(x)
116
+ x = self.dropout(x)
117
+
118
+ return self.output(x)
119
+
120
+ def get_loss(self):
121
+ return nn.BCEWithLogitsLoss()
122
+
123
+ class MLPBackbone(nn.Module):
124
+ def __init__(self, input_features=768, hidden_dim=512):
125
+ super(MLPBackbone, self).__init__()
126
+
127
+ self.hidden_layer_1 = nn.Linear(input_features, hidden_dim)
128
+ self.dropout = nn.Dropout(p=0.2)
129
+ self.loss = self.get_loss() # can return a dict of losses
130
+
131
+ def forward(self, x):
132
+ """
133
+ x: (B, L, T, H)
134
+ T=#chunks, can be 1 or several chunks
135
+ """
136
+
137
+ x = self.hidden_layer_1(x)
138
+ x = F.relu(x)
139
+ x = self.dropout(x)
140
+
141
+ return x
142
+
143
+ def get_loss(self):
144
+ return nn.BCEWithLogitsLoss()
145
+
146
+ class MLPShared(nn.Module):
147
+ def __init__(self, input_channels=25, num_class=15):
148
+ super(MLPShared, self).__init__()
149
+ self.aggregator = nn.Parameter(torch.randn((input_channels, 1,1), dtype=torch.float))
150
+ self.input_channels = input_channels
151
+
152
+ self.hidden_layer_1 = nn.Linear(512, 256)
153
+ self.output = nn.Linear(256, num_class)
154
+ self.dropout = nn.Dropout(p=0.2)
155
+ self.loss = self.get_loss() # can return a dict of losses
156
+
157
+ def forward(self, x):
158
+ """
159
+ x: (B, L, T, H)
160
+ T=#chunks, can be 1 or several chunks
161
+ """
162
+
163
+ weights = F.softmax(self.aggregator, dim=1)
164
+ x = (x * weights).sum(dim=1)
165
+
166
+ x = x.mean(-2)
167
+
168
+ x = self.hidden_layer_1(x)
169
+ x = F.relu(x)
170
+ x = self.dropout(x)
171
+
172
+ return self.output(x)
173
+
174
+ def get_loss(self):
175
+ return nn.BCEWithLogitsLoss()
176
+
177
+ class MLPAggTaskHead(nn.Module):
178
+ def __init__(self, input_channels: int, input_size: int, output_size: int, use_aggregator: bool, use_time_average: bool, use_sigmoid: bool, use_transpose: bool, num_layers: int, hidden_dim: int, width: int):
179
+ super(MLPAggTaskHead, self).__init__()
180
+ if use_aggregator:
181
+ self.aggregator = nn.Parameter(torch.randn((input_channels), dtype=torch.float))
182
+ self.use_aggregator = use_aggregator
183
+ self.use_time_average = use_time_average
184
+ self.use_transpose = use_transpose
185
+ self.use_sigmoid = use_sigmoid
186
+ self.input_channels = input_channels
187
+ self.output_size = output_size
188
+ self.width = width
189
+
190
+ if self.width > 1:
191
+ self.layers = nn.ModuleList()
192
+ for i in range(self.width):
193
+ mlp_layers = [nn.GELU()]
194
+ mlp_layers += self._create_mlp_layers(input_size, output_size, num_layers, hidden_dim)
195
+ if self.use_sigmoid: mlp_layers += [nn.Sigmoid()]
196
+ self.layers.append(nn.Sequential(*mlp_layers))
197
+ else:
198
+ mlp_layers = [nn.GELU()]
199
+ mlp_layers += self._create_mlp_layers(input_size, output_size, num_layers, hidden_dim)
200
+ if self.use_sigmoid: mlp_layers += [nn.Sigmoid()]
201
+ self.layers = nn.Sequential(*mlp_layers)
202
+
203
+ def _create_mlp_layers(self, input_size, output_size, num_layers, hidden_dim):
204
+ if num_layers >=2:
205
+ layers = [nn.Linear(input_size, hidden_dim)]
206
+ layers.append(nn.GELU())
207
+ if num_layers > 2:
208
+ for _ in range(1, num_layers - 2):
209
+ layers += [
210
+ nn.Linear(hidden_dim, hidden_dim),
211
+ nn.GELU()
212
+ ]
213
+ layers.append(nn.Linear(hidden_dim, output_size))
214
+ else:
215
+ layers = [nn.Linear(input_size, output_size)]
216
+ return layers
217
+
218
+
219
+ def forward(self, x):
220
+ if self.use_transpose:
221
+ x = x.transpose(1, 0)
222
+ if self.use_time_average:
223
+ x = x.mean(-2)
224
+ if self.use_aggregator:
225
+ aggregator_weights = F.softmax(self.aggregator)
226
+ aggregator_weights = aggregator_weights.view(self.input_channels, 1)
227
+ aggregator_output = (x * aggregator_weights).sum(dim=0)
228
+ aggregator_output = aggregator_output.unsqueeze(dim=0)
229
+ # print("Agg output ", aggregator_output.shape)
230
+ else:
231
+ aggregator_output = x
232
+
233
+ if self.width > 1:
234
+ if (self.input_channels < 1):
235
+ return torch.cat([layer(aggregator_output.unsqueeze(dim=0)) for layer in self.layers], dim=-2)
236
+ else:
237
+ return torch.cat([layer(aggregator_output.unsqueeze(dim=0)).squeeze(dim=0) for layer in self.layers], dim=-2)
238
+ else:
239
+ if (self.input_channels < 1):
240
+ return self.layers(aggregator_output.unsqueeze(dim=0))
241
+ else:
242
+ return self.layers(aggregator_output.unsqueeze(dim=0)).squeeze()
243
+
244
+
245
+ class MultiTaskModel(nn.Module):
246
+ def __init__(self, tasks: Dict):
247
+ super(MultiTaskModel, self).__init__()
248
+ self.tasks = tasks
249
+ for task_name, task_head in self.tasks["task_heads"].items():
250
+ setattr(self, task_name, MLP(13, task_head["output_size"]))
251
+ if task_name in self.tasks["task_projectors"].keys():
252
+ task_projector = tasks["task_projectors"][task_name]
253
+ setattr(self, task_name + "_projector", MLPAggTaskHead(task_projector["input_channels"], task_projector["input_size"], task_projector["output_size"], task_projector["use_aggregator"], task_projector["use_time_average"], task_projector["use_sigmoid"], task_projector["use_transpose"], task_projector["num_layers"], task_projector["hidden_size"], task_projector["width"]))
254
+
255
+ def forward(self, x):
256
+ task_head_outputs = {}
257
+ task_projector_outputs = []
258
+
259
+ backbone_output = x
260
+
261
+ for task_name in self.tasks["task_heads"]:
262
+ if task_name != "lmm_projector":
263
+ task_head_outputs[task_name] = getattr(self, task_name)(backbone_output)
264
+ if task_name in self.tasks["task_projectors"].keys():
265
+ task_projector_outputs.append(getattr(self, task_name + "_projector")(task_head_outputs[task_name]))
266
+ else:
267
+ task_projector_outputs.append(getattr(self, task_name)(backbone_output))
268
+
269
+ if len(task_projector_outputs) > 0:
270
+ task_projector_outputs_unsqueezed = [task_projector_output.unsqueeze(0) for task_projector_output in task_projector_outputs]
271
+ task_head_outputs["projectors"] = torch.cat(task_projector_outputs_unsqueezed, dim=-2)
272
+
273
+ return task_head_outputs
274
+
275
+ class MultiTaskSharedModel(nn.Module):
276
+ def __init__(self, tasks: Dict):
277
+ super(MultiTaskSharedModel, self).__init__()
278
+ self.tasks = tasks
279
+ self.use_backbone = False
280
+ if "backbone" in self.tasks.keys():
281
+ self.use_backbone = True
282
+ if self.use_backbone: self.backbone = MLPBackbone(768, 512)
283
+ for task_name, task_head in self.tasks["task_heads"].items():
284
+ if task_name != "lmm_projector":
285
+ setattr(self, task_name, MLPShared(13, task_head["output_size"]))
286
+ else:
287
+ setattr(self, task_name, MLPAggTaskHead(task_head["input_channels"], task_head["input_size"], task_head["output_size"], task_head["use_aggregator"], task_head["use_time_average"], task_head["use_sigmoid"], task_head["use_transpose"], task_head["num_layers"], task_head["hidden_size"], task_head["width"]))
288
+ if task_name in self.tasks["task_projectors"].keys():
289
+ task_projector = tasks["task_projectors"][task_name]
290
+ setattr(self, task_name + "_projector", MLPAggTaskHead(task_projector["input_channels"], task_projector["input_size"], task_projector["output_size"], task_projector["use_aggregator"], task_projector["use_time_average"], task_projector["use_sigmoid"], task_projector["use_transpose"], task_projector["num_layers"], task_projector["hidden_size"], task_projector["width"]))
291
+
292
+ def forward(self, x):
293
+ task_head_outputs = {}
294
+ task_projector_outputs = []
295
+
296
+ if self.use_backbone:
297
+ backbone_output = self.backbone(x)
298
+ else:
299
+ backbone_output = x
300
+
301
+ #print("Output shape ", backbone_output.shape)
302
+ for task_name in self.tasks["task_heads"]:
303
+ #print("task namee ", task_name)
304
+ if task_name != "lmm_projector":
305
+ task_head_outputs[task_name] = getattr(self, task_name)(backbone_output)
306
+ if task_name in self.tasks["task_projectors"].keys():
307
+ task_projector_outputs.append(getattr(self, task_name + "_projector")(task_head_outputs[task_name]))
308
+ else:
309
+ llm_input = x
310
+ if self.tasks["task_heads"][task_name]["use_backbone_output"]:
311
+ llm_input = backbone_output
312
+ task_projector_outputs.append(getattr(self, task_name)(llm_input))
313
+
314
+ if len(task_projector_outputs) > 0:
315
+ task_projector_outputs_unsqueezed = [task_projector_output.unsqueeze(0) for task_projector_output in task_projector_outputs]
316
+ task_head_outputs["projectors"] = torch.cat(task_projector_outputs_unsqueezed, dim=-2)
317
+
318
+ return task_head_outputs
319
+
320
+
321
+
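For orientation, a hedged sketch (not part of the commit) of MLPAggTaskHead on its own: the values mirror how a 13-layer, 512-dimensional backbone output could be mapped to 60 LLM pseudo-tokens, but the exact numbers in the shipped configs/tasks.json may differ.

# Hedged sketch: MLPAggTaskHead maps stacked layer features to LLM-token-shaped outputs.
import torch
from multi_token.modalities.multi_task_projector_shared import MLPAggTaskHead

head = MLPAggTaskHead(
    input_channels=13, input_size=512, output_size=4096,
    use_aggregator=True, use_time_average=True, use_sigmoid=False,
    use_transpose=False, num_layers=2, hidden_dim=256, width=60,
)
features = torch.randn(13, 760, 512)   # (layers, time, hidden), e.g. a shared-backbone output
tokens = head(features)
print(tokens.shape)                    # torch.Size([60, 4096])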
src/sonicverse/multi_token/modalities/projectors.py ADDED
@@ -0,0 +1,416 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ from typing import Dict
4
+ import numpy as np
5
+
6
+ import torch.nn.functional as F
7
+
8
+ def build_patch_mlp_projector(
9
+ input_hidden_size: int, lm_hidden_size: int, num_layers: int
10
+ ) -> nn.Module:
11
+ modules = [nn.Linear(input_hidden_size, lm_hidden_size)]
12
+ for _ in range(1, num_layers):
13
+ modules.append(nn.GELU())
14
+ modules.append(nn.Linear(lm_hidden_size, lm_hidden_size))
15
+ return nn.Sequential(*modules)
16
+
17
+
18
+ class _MLPVectorProjector(nn.Module):
19
+ def __init__(
20
+ self, input_hidden_size: int, lm_hidden_size: int, num_layers: int, width: int
21
+ ):
22
+ super(_MLPVectorProjector, self).__init__()
23
+ self.mlps = nn.ModuleList()
24
+ for _ in range(width):
25
+ mlp = [nn.Linear(input_hidden_size, lm_hidden_size)]
26
+ for _ in range(1, num_layers):
27
+ mlp.append(nn.GELU())
28
+ mlp.append(nn.Linear(lm_hidden_size, lm_hidden_size))
29
+ self.mlps.append(nn.Sequential(*mlp))
30
+
31
+ def forward(self, x):
32
+ output = torch.cat([mlp(x) for mlp in self.mlps], dim=-2)
33
+ return output
34
+
35
+ def build_mlp_vector_projector(
36
+ input_hidden_size: int, lm_hidden_size: int, num_layers: int, num_tokens: int
37
+ ):
38
+ return _MLPVectorProjector(
39
+ input_hidden_size, lm_hidden_size, num_layers, num_tokens
40
+ )
41
+
42
+ class MLPBackbone(nn.Module):
43
+ def __init__(self, input_size: int, output_size: int, num_layers: int, hidden_dim: int):
44
+ super(MLPBackbone, self).__init__()
45
+ self.output_size = output_size
46
+ mlp_layers = self._create_mlp_layers(input_size, output_size, num_layers, hidden_dim)
47
+ self.layers = nn.Sequential(*mlp_layers)
48
+
49
+ def _create_conv_layers(self, input_channels, num_conv_layers, hidden_dim):
50
+ layers = []
51
+ for _ in range(num_conv_layers):
52
+ layers += [
53
+ nn.Conv1d(input_channels, hidden_dim, kernel_size=3, padding=1),
54
+ nn.GELU(),
55
+ nn.MaxPool1d(kernel_size=2, stride=2)
56
+ ]
57
+ input_channels = hidden_dim
58
+ return layers
59
+
60
+ def _create_mlp_layers(self, input_size, output_size, num_layers, hidden_dim):
61
+ if num_layers >=2:
62
+ layers = [nn.Linear(input_size, hidden_dim)]
63
+ layers.append(nn.GELU())
64
+ if num_layers > 2:
65
+ for _ in range(1, num_layers - 2):
66
+ layers += [
67
+ nn.Linear(hidden_dim, hidden_dim),
68
+ nn.GELU()
69
+ ]
70
+ layers.append(nn.Linear(hidden_dim, output_size))
71
+ else:
72
+ layers = [nn.Linear(input_size, output_size)]
73
+ return layers
74
+
75
+ def forward(self, x):
76
+ return self.layers(x)
77
+
78
+ class MLPTaskHead(nn.Module):
79
+ def __init__(self, backbone: nn.Module, input_size: int, output_size: int, num_layers: int, hidden_dim: int, width: int = 1):
80
+ super(MLPTaskHead, self).__init__()
81
+ self.backbone = backbone
82
+ self.width = width
83
+ if width > 1:
84
+ self.layers = nn.ModuleList()
85
+ for i in range(width):
86
+ mlp_layers = [nn.GELU()]
87
+ mlp_layers += self._create_mlp_layers(input_size, output_size, num_layers, hidden_dim)
88
+ self.layers.append(nn.Sequential(*mlp_layers))
89
+ else:
90
+ mlp_layers = [nn.GELU()]
91
+ mlp_layers += self._create_mlp_layers(input_size, output_size, num_layers, hidden_dim)
92
+ self.layers = nn.Sequential(*mlp_layers)
93
+
94
+ def _create_mlp_layers(self, input_size, output_size, num_layers, hidden_dim):
95
+ if num_layers >=2:
96
+ layers = [nn.Linear(input_size, hidden_dim)]
97
+ layers.append(nn.GELU())
98
+ if num_layers > 2:
99
+ for _ in range(1, num_layers - 2):
100
+ layers += [
101
+ nn.Linear(hidden_dim, hidden_dim),
102
+ nn.GELU()
103
+ ]
104
+ layers.append(nn.Linear(hidden_dim, output_size))
105
+ else:
106
+ layers = [nn.Linear(input_size, output_size)]
107
+ return layers
108
+
109
+ def _create_conv_layers(self, input_channels, num_conv_layers, hidden_dim):
110
+ layers = []
111
+ for _ in range(num_conv_layers):
112
+ layers += [
113
+ nn.Conv2d(in_channels = input_channels, out_channels = hidden_dim, kernel_size=(3,3), stride=1, padding=1),
114
+ nn.GELU(),
115
+ nn.MaxPool1d(kernel_size=2, stride=2)
116
+ ]
117
+ input_channels = hidden_dim
118
+ return layers
119
+
120
+ def forward(self, x):
121
+ output = self.backbone.forward(x)
122
+ if self.width > 1:
123
+ return torch.cat([layer(output) for layer in self.layers], dim=-2)
124
+ else:
125
+ return self.layers(output)
126
+
127
+ class MLPTaskModule(nn.Module):
128
+ def __init__(self, input_size: int, output_size: int, num_layers: int, hidden_dim: int, width: int = 1):
129
+ super(MLPTaskModule, self).__init__()
130
+ self.width = width
131
+ if width > 1:
132
+ self.layers = nn.ModuleList()
133
+ for i in range(width):
134
+ mlp_layers = [nn.GELU()]
135
+ mlp_layers += self._create_mlp_layers(input_size, output_size, num_layers, hidden_dim)
136
+ self.layers.append(nn.Sequential(*mlp_layers))
137
+ else:
138
+ mlp_layers = [nn.GELU()]
139
+ mlp_layers += self._create_mlp_layers(input_size, output_size, num_layers, hidden_dim)
140
+ self.layers = nn.Sequential(*mlp_layers)
141
+
142
+ def _create_mlp_layers(self, input_size, output_size, num_layers, hidden_dim):
143
+ if num_layers >=2:
144
+ layers = [nn.Linear(input_size, hidden_dim)]
145
+ layers.append(nn.GELU())
146
+ if num_layers > 2:
147
+ for _ in range(1, num_layers - 2):
148
+ layers += [
149
+ nn.Linear(hidden_dim, hidden_dim),
150
+ nn.GELU()
151
+ ]
152
+ layers.append(nn.Linear(hidden_dim, output_size))
153
+ else:
154
+ layers = [nn.Linear(input_size, output_size)]
155
+ return layers
156
+
157
+ def _create_conv_layers(self, input_channels, num_conv_layers, hidden_dim):
158
+ layers = []
159
+ for _ in range(num_conv_layers):
160
+ layers += [
161
+ nn.Conv2d(in_channels = input_channels, out_channels = hidden_dim, kernel_size=(3,3), stride=1, padding=1),
162
+ nn.GELU(),
163
+ nn.MaxPool1d(kernel_size=2, stride=2)
164
+ ]
165
+ input_channels = hidden_dim
166
+ return layers
167
+
168
+ def forward(self, x):
169
+ if self.width > 1:
170
+ return torch.cat([layer(x) for layer in self.layers], dim=-2)
171
+ else:
172
+ return self.layers(x)
173
+
174
+
175
+ class MultiTaskModel(nn.Module):
176
+ def __init__(self, input_hidden_size: int, input_channels: int, time_average: bool, time_dimension: int, use_aggregator: bool, tasks: Dict):
177
+ super(MultiTaskModel, self).__init__()
178
+ self.tasks = tasks
179
+ self.time_average = time_average
180
+ self.time_dimension = time_dimension
181
+ self.use_aggregator = use_aggregator
182
+ if self.use_aggregator:
183
+ if (time_average):
184
+ self.aggregator = nn.Parameter(torch.randn((input_channels, 1), dtype = torch.float))
185
+ else:
186
+ self.aggregator = nn.Parameter(torch.randn((input_channels, 1, 1), dtype = torch.float))
187
+
188
+ self.backbone = MLPBackbone(input_hidden_size, self.tasks["backbone"]["output_size"], self.tasks["backbone"]["num_layers"], self.tasks["backbone"]["hidden_size"])
189
+ for task_name, task_head in self.tasks["task_heads"].items():
190
+ setattr(self, task_name, MLPTaskModule(self.tasks["backbone"]["output_size"], task_head["output_size"], task_head["num_layers"], task_head["hidden_size"], task_head["width"]))
191
+ if task_name in self.tasks["task_projectors"].keys():
192
+ task_projector = tasks["task_projectors"][task_name]
193
+ setattr(self, task_name + "_projector", MLPTaskModule(task_head["output_size"], task_projector["output_size"], task_projector["num_layers"], task_projector["hidden_size"], task_projector["width"]))
194
+
195
+ def forward(self, x):
196
+ task_head_outputs = {}
197
+ task_projector_outputs = []
198
+
199
+ if self.time_average:
200
+ x = x.mean(self.time_dimension)
201
+ if self.use_aggregator:
202
+ aggregator_weights = F.softmax(self.aggregator, dim=0)
203
+ aggregator_output = (x * aggregator_weights).sum(dim=0)
204
+ aggregator_output = aggregator_output.unsqueeze(0)
205
+ else:
206
+ aggregator_output = x
207
+
208
+ backbone_output = self.backbone(aggregator_output)
209
+
210
+ for task_name in self.tasks["task_heads"]:
211
+ if task_name != "lmm_projector":
212
+ task_head_output = getattr(self, task_name)(backbone_output)
213
+ min_val = torch.min(task_head_output)
214
+ max_val = torch.max(task_head_output)
215
+
216
+ normalized_task_head_output = (task_head_output - min_val) / (max_val - min_val)
217
+ task_head_outputs[task_name] = normalized_task_head_output
218
+ if task_name in self.tasks["task_projectors"].keys():
219
+ task_projector_outputs.append(getattr(self, task_name + "_projector")(task_head_outputs[task_name]))
220
+ else:
221
+ task_projector_outputs.append(getattr(self, task_name)(backbone_output))
222
+
223
+ task_projector_outputs_unsqueezed = [task_projector_output.unsqueeze(0) for task_projector_output in task_projector_outputs]
224
+ if len(task_projector_outputs_unsqueezed) > 0:
225
+ task_head_outputs["projectors"] = torch.cat(task_projector_outputs_unsqueezed, dim=-2)
226
+
227
+ return task_head_outputs
228
+
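+ # The dict returned by MultiTaskModel.forward maps each task name to its
+ # min-max normalized head output; the "lmm_projector" head skips normalization,
+ # and all projector outputs are stacked under the extra "projectors" key.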
229
+
230
+ def build_mt_vector_projector(
231
+ input_hidden_size: int, lm_hidden_size: int, tasks: Dict
232
+ ):
233
+ projector = nn.ModuleDict()
234
+ projector["backbone"] = MLPBackbone(input_hidden_size, tasks["backbone"]["output_size"], tasks["backbone"]["num_layers"], tasks["backbone"]["hidden_size"])
235
+ for task_name, task_head in tasks["task_heads"].items():
236
+ projector[task_name] = MLPTaskHead(projector["backbone"], task_head["hidden_size"], task_head["output_size"], task_head["num_layers"], task_head["hidden_size"], task_head["width"])
237
+
238
+ return projector
239
+
240
+ class Attention(nn.Module):
241
+ def __init__(self, input_dim, hidden_dim):
242
+ super(Attention, self).__init__()
243
+ self.linear_in = nn.Linear(input_dim, hidden_dim)
244
+ self.linear_out = nn.Linear(hidden_dim, 1)
245
+
246
+ def forward(self, x):
247
+ # Input shape: (batch_size, seq_len, input_dim)
248
+ energy = torch.tanh(self.linear_in(x))
249
+ attention_scores = torch.softmax(self.linear_out(energy), dim=1)
250
+ context_vector = torch.sum(attention_scores * x, dim=1)
251
+ return context_vector
252
+
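+ # Illustrative check (not part of the original file): attention pooling collapses
+ # the sequence axis, e.g. Attention(64, 16)(torch.randn(2, 10, 64)) -> (2, 64).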
253
+ class _CNNAttentionTokenizer(nn.Module):
254
+ def __init__(self, input_channels, output_size, width, hidden_dim, num_conv_layers):
255
+ super(_CNNAttentionTokenizer, self).__init__()
256
+ self.width = width
257
+ self.cnns = nn.ModuleList()
258
+ self.attentions = nn.ModuleList()
259
+ for _ in range(width):
260
+ cnn = self._create_conv_layers(input_channels, num_conv_layers)
261
+ self.cnns.append(cnn)
262
+ attention = [Attention(hidden_dim, 125)]
263
+ linear_input_size = hidden_dim
264
+ attention.append(nn.Linear(linear_input_size, output_size))
265
+ self.attentions.append(nn.Sequential(*attention))
266
+
267
+
268
+ def _create_conv_layers(self, input_channels, num_conv_layers):
269
+ layers = []
270
+ in_channels = input_channels
271
+ for _ in range(num_conv_layers):
272
+ layers += [
273
+ nn.Conv1d(in_channels, 64, kernel_size=3, padding=1),
274
+ nn.ReLU(),
275
+ nn.MaxPool1d(kernel_size=2, stride=2)
276
+ ]
277
+ in_channels = 64
278
+ return nn.Sequential(*layers)
279
+
280
+ def forward(self, x):
281
+ outputs = []
282
+ for token in range(self.width):
283
+ # Input shape: (batch_size, input_channels, sequence_length)
284
+ token_output = self.cnns[token](x) # Apply convolutional layers
285
+ token_output = token_output.permute(0, 2, 1) # Reshape for attention mechanism (batch_size, sequence_length, input_dim
286
+ token_output = self.attentions[token](token_output) # Apply attention mechanism
287
+ outputs.append(token_output)
288
+ output = torch.cat(outputs, dim=-2)
289
+ output = torch.stack([output])
290
+ return output
291
+
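+ # Note: _create_conv_layers always emits 64 channels, so the Attention input_dim
+ # only lines up when hidden_dim is 64; the pooled vector is then projected to
+ # output_size for each of the `width` token branches.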
292
+ def build_attentive_cnn_projector(
293
+ input_channels: int, lm_hidden_size: int, num_tokens: int, hidden_dim: int, num_layers: int
294
+ ):
295
+ return _CNNAttentionTokenizer(input_channels, lm_hidden_size, num_tokens, hidden_dim, num_layers)
296
+
297
+ class _CNNMLPProjector(nn.Module):
298
+ def __init__(self, input_channels, input_size, output_size = 4096, width = 5, hidden_dim = 64, num_conv_layers = 1, num_mlp_layers = 2):
299
+ super(_CNNMLPProjector, self).__init__()
300
+ self.width = width
301
+ self.cnnmlps = nn.ModuleList()
302
+ for _ in range(self.width):
303
+ cnnmlp = self._create_conv_layers(input_channels, num_conv_layers, hidden_dim)
304
+ cnnmlp.append(nn.Flatten())
305
+ cnn_output_size = hidden_dim*((input_size + 2*1 - 3*num_conv_layers) // (2**num_conv_layers) + 1)
306
+ cnnmlp.append(nn.Linear(cnn_output_size, output_size))
307
+ cnnmlp.append(nn.GELU())
308
+ cnnmlp += self._create_mlp_layers(output_size, output_size, num_mlp_layers, output_size)
309
+ self.cnnmlps.append(nn.Sequential(*cnnmlp))
310
+
311
+ def _create_conv_layers(self, input_channels, num_conv_layers, hidden_dim):
312
+ layers = []
313
+ for _ in range(num_conv_layers):
314
+ layers += [
315
+ nn.Conv1d(input_channels, hidden_dim, kernel_size=3, padding=1),
316
+ nn.GELU(),
317
+ nn.MaxPool1d(kernel_size=2, stride=2)
318
+ ]
319
+ input_channels = hidden_dim
320
+ return layers
321
+
322
+ def _create_mlp_layers(self, input_size, output_size, num_layers, hidden_dim):
323
+ if num_layers >=2:
324
+ layers = [nn.Linear(input_size, hidden_dim)]
325
+ layers.append(nn.GELU())
326
+ if num_layers > 2:
327
+ for _ in range(1, num_layers - 2):
328
+ layers += [
329
+ nn.Linear(hidden_dim, hidden_dim),
330
+ nn.GELU()
331
+ ]
332
+ layers.append(nn.Linear(hidden_dim, output_size))
333
+ else:
334
+ layers = [nn.Linear(input_size, output_size)]
335
+ return layers
336
+
337
+ def forward(self, x):
338
+ return torch.stack([torch.cat([cnnmlp(x) for cnnmlp in self.cnnmlps], dim=-2)])
339
+
340
+ def build_cnn_mlp_projector(
341
+ input_channels: int, input_size: int, lm_hidden_size: int, num_tokens: int, hidden_dim: int, num_conv_layers: int, num_mlp_layers: int
342
+ ):
343
+ return _CNNMLPProjector(input_channels, input_size, lm_hidden_size, num_tokens, hidden_dim, num_conv_layers, num_mlp_layers)
344
+
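+ # Illustrative sizing (values assumed, not from the original file): for an input
+ # of shape (1, 768, 250), build_cnn_mlp_projector(768, 250, 4096, 5, 64, 1, 2)
+ # flattens the pooled conv features to 64*125 and returns (1, 5, 4096) tokens.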
345
+ class _MultiLayeredCNNMLPProjector(nn.Module):
346
+ def __init__(self, input_channels, input_size, num_feature_layers, output_size = 4096, width = 5, hidden_dim = 64, num_conv_layers = 1, num_mlp_layers = 2):
347
+ super(_MultiLayeredCNNMLPProjector, self).__init__()
348
+ self.width = width
349
+ self.num_feature_layers = num_feature_layers
350
+ self.cnnmlps = nn.ModuleList()
351
+ for _ in range(self.width*self.num_feature_layers):
352
+ cnnmlp = self._create_conv_layers(input_channels, num_conv_layers, hidden_dim)
353
+ cnnmlp += [nn.GELU()]
354
+ cnnmlp += self._create_mlp_layers(input_size, output_size, num_mlp_layers, output_size)
355
+ self.cnnmlps.append(nn.Sequential(*cnnmlp))
356
+
357
+ def _create_conv_layers(self, input_channels, num_conv_layers, hidden_size):
358
+ layers = []
359
+
360
+ if input_channels >= hidden_size:
361
+ hidden_dim = int(input_channels/2)
362
+ else:
363
+ hidden_dim = hidden_size
364
+
365
+ layers += [nn.Conv1d(in_channels=input_channels, out_channels=hidden_dim, kernel_size=3, stride=1, padding=1), nn.GELU()]
366
+ if num_conv_layers > 2:
367
+ for _ in range(num_conv_layers - 2):
368
+ if hidden_dim/2 >= hidden_size:
369
+ output_dim = int(hidden_dim/2)
370
+ else:
371
+ output_dim = hidden_size
372
+ layers += [
373
+ nn.Conv1d(in_channels=hidden_dim, out_channels=output_dim, kernel_size=3, stride=1, padding=1),
374
+ nn.GELU(),
375
+ ]
376
+ hidden_dim = output_dim
377
+ layers += [nn.Conv1d(in_channels=hidden_dim, out_channels=1, kernel_size=3, stride=1, padding=1)]
378
+ return layers
379
+
380
+ def _create_mlp_layers(self, input_size, output_size, num_layers, hidden_dim):
381
+ if num_layers >=2:
382
+ layers = [nn.Linear(input_size, hidden_dim)]
383
+ layers.append(nn.GELU())
384
+ if num_layers > 2:
385
+ for _ in range(1, num_layers - 2):
386
+ layers += [
387
+ nn.Linear(hidden_dim, hidden_dim),
388
+ nn.GELU()
389
+ ]
390
+ layers.append(nn.Linear(hidden_dim, output_size))
391
+ else:
392
+ layers = [nn.Linear(input_size, output_size)]
393
+ return layers
394
+
395
+ def forward(self, x):
396
+ print("X SHAPE ", x.shape)
397
+ inp_feature_layers = []
398
+ for feature_id in range(self.num_feature_layers):
399
+ in_feat_layer = x[feature_id].unsqueeze(0).permute(0,2,1)
400
+ inp_feature_layers.append(in_feat_layer)
401
+
402
+ outputs = []
403
+ for layer_count in range(self.width*self.num_feature_layers):
404
+ feature_id = int(layer_count/self.width)
405
+ outputs+=[self.cnnmlps[layer_count](inp_feature_layers[feature_id])]
406
+
407
+ return torch.cat(outputs, dim=-2)
408
+
409
+
410
+ def build_multi_layer_cnn_mlp_projector(
411
+ input_channels: int, input_size: int, num_feature_layers: int, lm_hidden_size: int, num_tokens: int, hidden_dim: int, num_conv_layers: int, num_mlp_layers: int
412
+ ):
413
+ assert(num_tokens % num_feature_layers == 0)
414
+ width = int(num_tokens/num_feature_layers)
415
+ return _MultiLayeredCNNMLPProjector(input_channels, input_size, num_feature_layers, lm_hidden_size, width, hidden_dim, num_conv_layers, num_mlp_layers)
416
+
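+ # Illustrative usage (values assumed): with num_feature_layers=5 encoder layers
+ # and num_tokens=10, the assert passes and each layer drives width=2 CNN+MLP
+ # branches; inputs are expected as (num_feature_layers, time, channels).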
src/sonicverse/multi_token/modalities/video_xclip.py ADDED
@@ -0,0 +1,113 @@
1
+ from typing import Dict, List, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import AutoProcessor, AutoModel
6
+
7
+ from multi_token.data_tools import load_video
8
+ from multi_token.modalities.base_modality import Modality
9
+ from multi_token.modalities.projectors import (
10
+ build_mlp_vector_projector,
11
+ )
12
+
13
+
14
+ OUTPUT_EMB_SIZE = 512
15
+
16
+
17
+ class XCLIPVideoModule(nn.Module):
18
+ def __init__(self, model_name_or_path: str):
19
+ super().__init__()
20
+ self.model_name_or_path = model_name_or_path
21
+ self.model = None
22
+ self.processor = None
23
+
24
+ self.load_model()
25
+
26
+ def load_model(self):
27
+ self.model = AutoModel.from_pretrained(self.model_name_or_path)
28
+ self.processor = AutoProcessor.from_pretrained(self.model_name_or_path)
29
+ self.model.requires_grad_(False)
30
+
31
+ @torch.no_grad()
32
+ def forward(self, video_inputs) -> torch.Tensor:
33
+ with torch.no_grad():
34
+ outputs = self.model(**(video_inputs.to(device=self.device)))
35
+
36
+ emb = outputs.video_embeds.to(device=self.device, dtype=self.dtype).view(
37
+ -1, 1, OUTPUT_EMB_SIZE
38
+ )
39
+ return emb
40
+
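+ # Each call returns pooled embeddings of shape (num_videos, 1, OUTPUT_EMB_SIZE),
+ # i.e. one 512-dim vector per video clip for the downstream projector.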
41
+ @property
42
+ def dtype(self):
43
+ return self.model.dtype
44
+
45
+ @property
46
+ def device(self):
47
+ return self.model.device
48
+
49
+
50
+ class XCLIPVideoModality(Modality):
51
+ def __init__(
52
+ self,
53
+ model_name_or_path: str = "microsoft/xclip-base-patch32",
54
+ num_projector_layers: int = 2,
55
+ num_tokens_output: int = 10,
56
+ ):
57
+ self.model_name_or_path = model_name_or_path
58
+ self.module = XCLIPVideoModule(model_name_or_path=self.model_name_or_path)
59
+ self.num_projector_layers = num_projector_layers
60
+ self.num_tokens_output = num_tokens_output
61
+
62
+ def build_projector(self, lm_hidden_size: int) -> nn.Module:
63
+ return build_mlp_vector_projector(
64
+ input_hidden_size=OUTPUT_EMB_SIZE,
65
+ lm_hidden_size=lm_hidden_size,
66
+ num_layers=self.num_projector_layers,
67
+ num_tokens=self.num_tokens_output,
68
+ )
69
+
70
+ @property
71
+ def name(self) -> str:
72
+ return "video_xclip"
73
+
74
+ @property
75
+ def token(self) -> str:
76
+ return "<video>"
77
+
78
+ @property
79
+ def data_key(self) -> str:
80
+ return "videos"
81
+
82
+ @property
83
+ def token_width(self) -> int:
84
+ return self.num_tokens_output
85
+
86
+ def to(self, dtype: torch.dtype, device: torch.device) -> "XCLIPVideoModality":
87
+ self.module.to(dtype=dtype, device=device)
88
+ return self
89
+
90
+ def preprocess_rows(self, rows: List[Dict]) -> List[Optional[Dict]]:
91
+ row_values = []
92
+ for row in rows:
93
+ video_arrays = [
94
+ load_video(
95
+ video_info,
96
+ )
97
+ for video_info in row[self.data_key]
98
+ ]
99
+ videos_enc = self.module.processor(
100
+ videos=[list(video) for video in video_arrays],
101
+ text=["IGNORE"],
102
+ return_tensors="pt",
103
+ padding=True,
104
+ )
105
+ row_values.append(videos_enc)
106
+ return row_values
107
+
108
+ @torch.no_grad()
109
+ def forward(self, encoded_values: List[torch.Tensor]) -> List[torch.Tensor]:
110
+ video_features = []
111
+ for video_batch in encoded_values:
112
+ video_features.append(self.module.forward(video_batch))
113
+ return video_features
src/sonicverse/multi_token/modalities/vision_clip.py ADDED
@@ -0,0 +1,178 @@
1
+ from typing import Dict, List, Tuple, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import CLIPVisionModel, CLIPImageProcessor
6
+ from PIL import Image
7
+
8
+ from multi_token.modalities.base_modality import Modality
9
+ from multi_token.modalities.projectors import (
10
+ build_patch_mlp_projector,
11
+ build_mlp_vector_projector,
12
+ )
13
+ from multi_token.data_tools import load_image
14
+
15
+ PATCH_LAYER = -2
16
+ OUTPUT_LAYER = -1
17
+ OUTPUT_EMB_SIZE = 1024
18
+
19
+
20
+ class CLIPVisionModule(nn.Module):
21
+ def __init__(self, model_name_or_path: str, feature_layer: int = PATCH_LAYER):
22
+ super().__init__()
23
+ self.feature_layer = feature_layer
24
+ self.model_name_or_path = model_name_or_path
25
+ self.image_processor = None
26
+ self.image_model = None
27
+
28
+ self.load_model()
29
+
30
+ def load_model(self):
31
+ self.image_processor = CLIPImageProcessor.from_pretrained(
32
+ self.model_name_or_path
33
+ )
34
+ self.image_model = CLIPVisionModel.from_pretrained(self.model_name_or_path)
35
+ self.image_model.requires_grad_(False)
36
+
37
+ @torch.no_grad()
38
+ def forward(self, images) -> torch.Tensor:
39
+ if self.feature_layer == PATCH_LAYER:
40
+ image_forward_outs = self.image_model(
41
+ images.to(device=self.device, dtype=self.dtype),
42
+ output_hidden_states=True,
43
+ )
44
+ image_features = image_forward_outs.hidden_states[self.feature_layer]
45
+ image_features = image_features[:, 1:].to(images.dtype)
46
+ else:
47
+ image_forward_outs = self.image_model(
48
+ images.to(device=self.device, dtype=self.dtype),
49
+ )
50
+ image_features = image_forward_outs.pooler_output.to(images.dtype).view(
51
+ -1, 1, OUTPUT_EMB_SIZE
52
+ )
53
+ return image_features
54
+
55
+ @property
56
+ def dtype(self):
57
+ return self.image_model.dtype
58
+
59
+ @property
60
+ def device(self):
61
+ return self.image_model.device
62
+
63
+ @property
64
+ def config(self):
65
+ return self.image_model.config
66
+
67
+ @property
68
+ def hidden_size(self):
69
+ return self.config.hidden_size
70
+
71
+ @property
72
+ def num_patches(self):
73
+ return (self.config.image_size // self.config.patch_size) ** 2
74
+
75
+
76
+ def _expand2square(pil_img: Image, background_color: Tuple) -> Image:
77
+ width, height = pil_img.size
78
+ if width == height:
79
+ return pil_img
80
+ elif width > height:
81
+ result = Image.new(pil_img.mode, (width, width), background_color)
82
+ result.paste(pil_img, (0, (width - height) // 2))
83
+ return result
84
+ else:
85
+ result = Image.new(pil_img.mode, (height, height), background_color)
86
+ result.paste(pil_img, ((height - width) // 2, 0))
87
+ return result
88
+
89
+
90
+ class CLIPVisionModality(Modality):
91
+ def __init__(
92
+ self,
93
+ model_name_or_path: str = "openai/clip-vit-large-patch14-336",
94
+ pad_non_square_images: bool = False,
95
+ num_projector_layers: int = 2,
96
+ feature_layer: int = PATCH_LAYER,
97
+ num_tokens_output: Optional[int] = None,
98
+ ):
99
+ if feature_layer not in [PATCH_LAYER, OUTPUT_LAYER]:
100
+ raise ValueError(
101
+ f"feature_layer must be one of {PATCH_LAYER} or {OUTPUT_LAYER}"
102
+ )
103
+ if (feature_layer == PATCH_LAYER) != (num_tokens_output is None):
104
+ raise ValueError(
105
+ "num_tokens_output must be None if feature_layer is PATCH_LAYER"
106
+ )
107
+ self.model_name_or_path = model_name_or_path
108
+ self.module = CLIPVisionModule(
109
+ model_name_or_path=self.model_name_or_path, feature_layer=feature_layer
110
+ )
111
+ self.pad_non_square_images = pad_non_square_images
112
+ self.num_projector_layers = num_projector_layers
113
+ self.num_tokens_output = num_tokens_output
114
+
115
+ def build_projector(self, lm_hidden_size: int) -> nn.Module:
116
+ if self.module.feature_layer == PATCH_LAYER:
117
+ return build_patch_mlp_projector(
118
+ self.module.hidden_size,
119
+ lm_hidden_size,
120
+ num_layers=self.num_projector_layers,
121
+ )
122
+ else:
123
+ return build_mlp_vector_projector(
124
+ input_hidden_size=OUTPUT_EMB_SIZE,
125
+ lm_hidden_size=lm_hidden_size,
126
+ num_layers=self.num_projector_layers,
127
+ num_tokens=self.num_tokens_output,
128
+ )
129
+
130
+ @property
131
+ def name(self) -> str:
132
+ return "vision_clip"
133
+
134
+ @property
135
+ def token(self) -> str:
136
+ return "<image>"
137
+
138
+ @property
139
+ def data_key(self) -> str:
140
+ return "images"
141
+
142
+ @property
143
+ def token_width(self) -> int:
144
+ if self.module.feature_layer == PATCH_LAYER:
145
+ return self.module.num_patches
146
+ else:
147
+ return self.num_tokens_output
148
+
149
+ def to(self, dtype: torch.dtype, device: torch.device) -> "CLIPVisionModality":
150
+ self.module.to(dtype=dtype, device=device)
151
+ return self
152
+
153
+ def preprocess_rows(self, rows: List[Dict]) -> List[Optional[torch.Tensor]]:
154
+ row_values = []
155
+ for row in rows:
156
+ images = []
157
+ for image_fn in row[self.data_key]:
158
+ image_obj = load_image(image_fn)
159
+ if self.pad_non_square_images:
160
+ image_obj = _expand2square(
161
+ image_obj,
162
+ tuple(
163
+ int(x * 255) for x in self.module.image_processor.image_mean
164
+ ),
165
+ )
166
+ image = self.module.image_processor.preprocess(
167
+ image_obj, return_tensors="pt"
168
+ )["pixel_values"][0]
169
+ images.append(image)
170
+ row_values.append(torch.stack(images) if len(images) > 0 else None)
171
+ return row_values
172
+
173
+ @torch.no_grad()
174
+ def forward(self, encoded_values: List[torch.Tensor]) -> List[torch.Tensor]:
175
+ image_features = []
176
+ for image_batch in encoded_values:
177
+ image_features.append(self.module.forward(image_batch))
178
+ return image_features
src/sonicverse/multi_token/model_utils.py ADDED
@@ -0,0 +1,112 @@
1
+ from typing import List, Dict
2
+ import logging
3
+ import torch
4
+
5
+ from enum import Enum
6
+
7
+ class MultiTaskType(Enum):
8
+ NO_MULTI_TASK = 0
9
+ SIMPLE_MULTI_TASK = 1
10
+ PROJECTED_MULTI_TASK = 2
11
+
12
+ def _find_all_linear_names(model) -> List[str]:
13
+ cls = torch.nn.Linear
14
+ lora_module_names = set()
15
+ for name, module in model.named_modules():
16
+ if isinstance(module, cls):
17
+ names = name.split(".")
18
+ lora_module_names.add(names[0] if len(names) == 1 else names[-1])
19
+
20
+ if "lm_head" in lora_module_names:
21
+ lora_module_names.remove("lm_head")
22
+ return list(lora_module_names)
23
+
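+ # lm_head is excluded so LoRA adapters only target the transformer's internal
+ # nn.Linear layers, not the output vocabulary projection.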
24
+
25
+ def maybe_zero_3(param, ignore_status=False, name=None):
26
+ from deepspeed import zero
27
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
28
+
29
+ if hasattr(param, "ds_id"):
30
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
31
+ if not ignore_status:
32
+ logging.warning(
33
+ f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}"
34
+ )
35
+ with zero.GatheredParameters([param]):
36
+ param = param.data.detach().cpu().clone()
37
+ else:
38
+ param = param.detach().cpu().clone()
39
+ return param
40
+
41
+
42
+ def get_peft_state(named_params, bias) -> Dict:
43
+ if bias == "none":
44
+ to_return = {k: t for k, t in named_params if "lora_" in k}
45
+ elif bias == "all":
46
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
47
+ elif bias == "lora_only":
48
+ to_return = {}
49
+ maybe_lora_bias = {}
50
+ lora_bias_names = set()
51
+ for k, t in named_params:
52
+ if "lora_" in k:
53
+ to_return[k] = t
54
+ bias_name = k.split("lora_")[0] + "bias"
55
+ lora_bias_names.add(bias_name)
56
+ elif "bias" in k:
57
+ maybe_lora_bias[k] = t
58
+ for k, t in maybe_lora_bias.items():
59
+ if k in lora_bias_names:
60
+ to_return[k] = t
61
+ else:
62
+ raise NotImplementedError()
63
+ to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
64
+ return to_return
65
+
66
+
67
+ def get_peft_state_non_lora(named_params, task_names) -> Dict:
68
+ to_return = {}
69
+ for k, t in named_params:
70
+ if "lora_" not in k:
71
+ task_name_in_k = False
72
+ for task_name in task_names:
73
+ if task_name in k:
74
+ task_name_in_k = True
75
+ if t.requires_grad or task_name_in_k:
76
+ to_return[k] = t
77
+ to_return = {
78
+ k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()
79
+ }
80
+ return to_return
81
+
82
+
83
+ def make_model_lora(model, training_args: "TrainingArguments"):
84
+ from peft import LoraConfig, get_peft_model
85
+
86
+ lora_config = LoraConfig(
87
+ r=training_args.lora_r,
88
+ lora_alpha=training_args.lora_alpha,
89
+ target_modules=_find_all_linear_names(model),
90
+ lora_dropout=training_args.lora_dropout,
91
+ bias=training_args.lora_bias,
92
+ task_type="CAUSAL_LM",
93
+ )
94
+ if training_args.bits == 16:
95
+ if training_args.bf16:
96
+ model.to(torch.bfloat16)
97
+ if training_args.fp16:
98
+ model.to(torch.float16)
99
+
100
+ model = get_peft_model(model, lora_config)
101
+ return model
102
+
103
+
104
+ def fix_tokenizer(tokenizer):
105
+ if tokenizer.pad_token is None:
106
+ tokenizer.pad_token = tokenizer.unk_token
107
+ if tokenizer.mask_token is None:
108
+ tokenizer.mask_token = tokenizer.unk_token
109
+ if tokenizer.cls_token is None:
110
+ tokenizer.cls_token = tokenizer.unk_token
111
+ if tokenizer.sep_token is None:
112
+ tokenizer.sep_token = tokenizer.unk_token
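+ # LLaMA/Mistral-style tokenizers typically ship without pad/mask/cls/sep tokens;
+ # aliasing them to unk_token keeps padding and collation from raising errors.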
src/sonicverse/multi_token/training.py ADDED
@@ -0,0 +1,344 @@
1
+ from typing import Optional, List
2
+ from dataclasses import field, dataclass
3
+ import logging
4
+ import subprocess
5
+ import pathlib
6
+ import torch
7
+ import shutil
8
+ import glob
9
+ import os
10
+ import json
11
+
12
+ import transformers
13
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
14
+ from transformers import Trainer
15
+
16
+ from multi_token.training_data import (
17
+ DataArguments,
18
+ LMMDataset,
19
+ DataCollatorForSupervisedLMMDataset,
20
+ )
21
+ from multi_token.model_utils import (
22
+ make_model_lora,
23
+ get_peft_state,
24
+ get_peft_state_non_lora,
25
+ fix_tokenizer,
26
+ MultiTaskType
27
+ )
28
+ from multi_token.modalities.base_modality import Modality
29
+
30
+
31
+ README_TEMPLATE = """
32
+ ---
33
+ license: apache-2.0
34
+ base_model: {base_model}
35
+ dataset: {dataset}
36
+ tags:
37
+ - finetuned
38
+ - multimodal
39
+ inference: false
40
+ ---
41
+
42
+ These are weights for a version of `{base_model}` finetuned for multimodal applications.
43
+
44
+ ### Modalities
45
+
46
+ {modalities}
47
+
48
+ ### Usage
49
+
50
+ GitHub: https://github.com/sshh12/multi_token (includes training scripts and basic inference server)
51
+
52
+ ### Dataset
53
+
54
+ {dataset} ({num_examples} examples)
55
+
56
+ ```
57
+ {dataset_example}
58
+ ```
59
+
60
+ ### Training Device(s)
61
+
62
+ ```
63
+ {training_devices_dump}
64
+ ```
65
+
66
+
67
+ ### Model
68
+
69
+ ```
70
+ {repr_model}
71
+ ```
72
+
73
+ """
74
+
75
+
76
+ @dataclass
77
+ class TrainingArguments(transformers.TrainingArguments):
78
+ cache_dir: Optional[str] = field(default=None)
79
+ remove_unused_columns: bool = field(default=False)
80
+ optim: str = field(default="adamw_torch")
81
+ model_max_length: int = field(
82
+ default=512,
83
+ metadata={
84
+ "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
85
+ },
86
+ )
87
+ double_quant: bool = field(
88
+ default=True,
89
+ metadata={
90
+ "help": "Compress the quantization statistics through double quantization."
91
+ },
92
+ )
93
+ quant_type: str = field(
94
+ default="nf4",
95
+ metadata={
96
+ "help": "Quantization data type to use. Should be one of `fp4` or `nf4`."
97
+ },
98
+ )
99
+ pretrain_projectors: bool = field(default=False)
100
+ pretrained_projectors_path: Optional[str] = field(default=None)
101
+ pretrained_projectors_config: Optional[str] = field(default=None)
102
+ bits: int = field(default=16, metadata={"help": "How many bits to use."})
103
+ lora_enable: bool = False
104
+ lora_r: int = 64
105
+ lora_alpha: int = 16
106
+ lora_dropout: float = 0.05
107
+ lora_weight_path: str = ""
108
+ lora_bias: str = "none"
109
+
110
+
111
+ @dataclass
112
+ class ModelArguments:
113
+ model_name_or_path: str = field(default="mistralai/Mistral-7B-Instruct-v0.1")
114
+ model_cls: str = field(default="MistralLMMForCausalLM")
115
+ modality_builder: str = field(default="vision_clip")
116
+ use_multi_task: int = field(default=MultiTaskType.PROJECTED_MULTI_TASK)
117
+ tasks_config: str = field(default="src/sonicverse/configs/tasks.json")
118
+ model_lora_path: Optional[str] = field(default="annabeth97c/sonicverse")
119
+
120
+
121
+ class LMMTrainer(Trainer):
122
+ def _save_checkpoint(self, model, trial, metrics=None):
123
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
124
+
125
+ run_dir = self._get_output_dir(trial=trial)
126
+ output_dir = os.path.join(run_dir, checkpoint_folder)
127
+ self._save_extras(output_dir)
128
+
129
+ super(LMMTrainer, self)._save_checkpoint(model, trial, metrics)
130
+
131
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
132
+ self._save_extras(output_dir)
133
+ super(LMMTrainer, self)._save(output_dir, state_dict)
134
+ for unused_dir in glob.iglob(os.path.join(output_dir, "global_step*")):
135
+ shutil.rmtree(unused_dir)
136
+
137
+ def _save_extras(self, output_dir: Optional[str] = None):
138
+ self.model.config.save_pretrained(output_dir)
139
+
140
+ task_names = []
141
+ for m in self.model.modalities:
142
+ task_names += m.tasks["task_heads"].keys()
143
+
144
+ non_lora_state_dict = get_peft_state_non_lora(self.model.named_parameters(), task_names)
145
+ torch.save(
146
+ non_lora_state_dict,
147
+ os.path.join(output_dir, "non_lora_trainables.bin"),
148
+ )
149
+
150
+
151
+ def _get_training_devices_dump() -> str:
152
+ out = subprocess.check_output(
153
+ ["nvidia-smi", "--query-gpu=gpu_name,gpu_bus_id,vbios_version", "--format=csv"]
154
+ )
155
+ return out.decode("utf-8").strip()
156
+
157
+
158
+ def train_for_modalities(
159
+ model_cls,
160
+ training_args: TrainingArguments,
161
+ model_args: ModelArguments,
162
+ train_data_args: DataArguments,
163
+ evaluation_data_args: DataArguments,
164
+ modalities: List[Modality],
165
+ ):
166
+ for m in modalities:
167
+ m.to(
168
+ dtype=torch.bfloat16 if training_args.bf16 else torch.float16,
169
+ device=training_args.device,
170
+ )
171
+
172
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
173
+ model_args.model_name_or_path,
174
+ cache_dir=training_args.cache_dir,
175
+ model_max_length=training_args.model_max_length,
176
+ padding_side="right",
177
+ use_fast=False,
178
+ )
179
+ fix_tokenizer(tokenizer)
180
+
181
+ train_dataset = LMMDataset(train_data_args, tokenizer, modalities)
182
+ evaluation_dataset = LMMDataset(evaluation_data_args, tokenizer, modalities)
183
+ collator = DataCollatorForSupervisedLMMDataset(tokenizer, modalities)
184
+
185
+ model = model_cls.from_pretrained(
186
+ model_args.model_name_or_path,
187
+ cache_dir=training_args.cache_dir,
188
+ )
189
+ model.to(
190
+ dtype=torch.bfloat16 if training_args.bf16 else torch.float16,
191
+ device=training_args.device,
192
+ )
193
+ model.modalities = modalities
194
+ model.config.use_cache = False
195
+ model.config.model_cls = model_cls.__name__
196
+ model.config.modality_builder = model_args.modality_builder
197
+
198
+ if training_args.gradient_checkpointing:
199
+ if hasattr(model, "enable_input_require_grads"):
200
+ model.enable_input_require_grads()
201
+ else:
202
+
203
+ def make_inputs_require_grad(module, input, output):
204
+ output.requires_grad_(True)
205
+
206
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
207
+
208
+ if model_args.model_lora_path:
209
+ raise ValueError(
210
+ "LoRA path not supported for training -- set the output path to an existing model to resume training"
211
+ )
212
+
213
+ if training_args.lora_enable:
214
+ logging.info("Adding LoRA adapters...")
215
+ model = make_model_lora(model, training_args)
216
+
217
+ if training_args.pretrained_projectors_path:
218
+ projector_weights_og = torch.load(
219
+ training_args.pretrained_projectors_path, map_location="cpu"
220
+ )
221
+ if model_args.use_multi_task==MultiTaskType.SIMPLE_MULTI_TASK:
222
+ projector_weights = {}
223
+ for k, v in projector_weights_og.items():
224
+ for m in modalities:
225
+ for task_name in m.tasks["task_heads"].keys():
226
+ if task_name in k:
227
+ projector_weights[k] = v
228
+ else:
229
+ projector_weights = {
230
+ k: v for k, v in projector_weights_og.items() if "_lmm_projector" in k
231
+ }
232
+
233
+ elif training_args.pretrained_projectors_config:
234
+ with open(training_args.pretrained_projectors_config, "r") as f:
235
+ pretrained_weights_config = json.load(f)
236
+
237
+ projector_weights = {}
238
+
239
+ for pretrained_path_info in pretrained_weights_config["pretrained_paths"]:
240
+ pretrained_path = pretrained_path_info["path"]
241
+ components = pretrained_path_info["components"]
242
+ use_prefix = pretrained_path_info["use_prefix"]
243
+ prefix = pretrained_path_info["prefix"]
244
+
245
+ pretrained_weights = torch.load(pretrained_path, map_location="cpu")
246
+
247
+ for k, v in pretrained_weights.items():
248
+ if any(component in k for component in components):
249
+ weight_key = k
250
+ if use_prefix:
251
+ weight_key = prefix + "." + k
252
+ projector_weights[weight_key] = v
253
+
254
+ else:
255
+ projector_weights = {}
256
+
257
+ model.get_model().initialize_modules(modalities, projector_weights)
258
+
259
+ task_names = []
260
+ tasks = {}
261
+ for m in model.modalities:
262
+ if m.use_multi_task != MultiTaskType.NO_MULTI_TASK:
263
+ tasks = m.tasks
264
+ task_names += m.tasks["task_heads"].keys()
265
+
266
+ if training_args.pretrain_projectors:
267
+ model.requires_grad_(False)
268
+ for m in modalities:
269
+ if m.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
270
+ for task_name in m.tasks["task_heads"].keys():
271
+ task_model = getattr(model.get_model(), m.name + "_" + task_name)
272
+ for p in task_model.parameters():
273
+ p.requires_grad = True
274
+ elif m.use_multi_task == MultiTaskType.PROJECTED_MULTI_TASK:
275
+ proj = getattr(model.get_model(), m.name + "_lmm_projector")
276
+
277
+ if "backbone" in m.tasks.keys():
278
+ backbone = getattr(proj, "backbone")
279
+ for backbone_param in backbone.parameters():
280
+ backbone_param.requires_grad = tasks["backbone"]["requires_grad"]
281
+
282
+ for task in task_names:
283
+ task_head = getattr(proj, task)
284
+ for task_head_param in task_head.parameters():
285
+ task_head_param.requires_grad = tasks["task_heads"][task]["requires_grad"]
286
+ if task in tasks["task_projectors"]:
287
+ task_projector = getattr(proj, task + "_projector")
288
+ for task_projector_param in task_projector.parameters():
289
+ task_projector_param.requires_grad = tasks["task_projectors"][task]["requires_grad"]
290
+
291
+ else:
292
+ proj = getattr(model.get_model(), m.name + "_lmm_projector")
293
+ for p in proj.parameters():
294
+ p.requires_grad = True
295
+
296
+ os.makedirs(training_args.output_dir, exist_ok=True)
297
+ with open(
298
+ os.path.join(training_args.output_dir, "model_named_parameters.txt"), "w"
299
+ ) as f:
300
+ for name, param in model.named_parameters():
301
+ f.write(f"{name} {param.shape} {param.requires_grad}\n")
302
+
303
+ with open(os.path.join(training_args.output_dir, "README.md"), "w") as f:
304
+ modalities_text = [
305
+ f"* {m.__class__.__name__} (use `{m.token}` in text and provide `{m.data_key}`, encoded as {m.token_width} tokens)"
306
+ for m in modalities
307
+ ]
308
+ readme_text = README_TEMPLATE.format(
309
+ base_model=model_args.model_name_or_path,
310
+ dataset=train_data_args.dataset_path,
311
+ dataset_example=repr(train_dataset.get_example()),
312
+ num_examples=len(train_dataset),
313
+ modalities="\n".join(modalities_text),
314
+ training_devices_dump=_get_training_devices_dump(),
315
+ repr_model=f"{model_cls.__name__}.model =\n\n{repr(model)}",
316
+ )
317
+ f.write(readme_text)
318
+
319
+ trainer = LMMTrainer(
320
+ model=model,
321
+ tokenizer=tokenizer,
322
+ args=training_args,
323
+ data_collator=collator,
324
+ train_dataset=train_dataset,
325
+ eval_dataset=evaluation_dataset,
326
+ )
327
+
328
+ if list(pathlib.Path(training_args.output_dir).glob(f"{PREFIX_CHECKPOINT_DIR}-*")):
329
+ trainer.train(resume_from_checkpoint=True)
330
+ else:
331
+ trainer.train()
332
+
333
+ trainer.save_state()
334
+
335
+ model.config.use_cache = True
336
+ model.config.save_pretrained(training_args.output_dir)
337
+ state_dict = get_peft_state(model.named_parameters(), training_args.lora_bias)
338
+ model.save_pretrained(training_args.output_dir, state_dict=state_dict)
339
+
340
+ non_lora_state_dict = get_peft_state_non_lora(model.named_parameters(), task_names)
341
+ torch.save(
342
+ non_lora_state_dict,
343
+ os.path.join(training_args.output_dir, "non_lora_trainables.bin"),
344
+ )
src/sonicverse/multi_token/training_data.py ADDED
@@ -0,0 +1,133 @@
1
+ from typing import List, Dict, Sequence
2
+ from dataclasses import dataclass, field
3
+ import logging
4
+ import os
5
+
6
+ from torch.utils.data import Dataset
7
+ from datasets import load_from_disk, load_dataset, Dataset as HFDataset
8
+ import transformers
9
+ import torch
10
+
11
+ from multi_token.modalities.base_modality import Modality
12
+ from multi_token.constants import IGNORE_INDEX
13
+ from multi_token.data_tools import encode_chat, encode_chat_multitask
14
+ from multi_token.model_utils import MultiTaskType
15
+
16
+
17
+ @dataclass
18
+ class DataArguments:
19
+ dataset_path: str = field(
20
+ default=None, metadata={"help": "Path to the training data."}
21
+ )
22
+
23
+ @dataclass
24
+ class TrainDataArguments:
25
+ train_dataset_path: str = field(
26
+ default=None, metadata={"help": "Path to the training data."}
27
+ )
28
+
29
+ @dataclass
30
+ class EvaluationDataArguments:
31
+ evaluation_dataset_path: str = field(
32
+ default=None, metadata={"help": "Path to the evaluation data."}
33
+ )
34
+
35
+
36
+ def _resolve_dataset(path: str) -> HFDataset:
37
+ if os.path.exists(path):
38
+ return load_from_disk(path)
39
+ else:
40
+ return load_dataset(path, split="train", data_files="*.arrow")
41
+
42
+
43
+ class LMMDataset(Dataset):
44
+ def __init__(
45
+ self,
46
+ data_args: DataArguments,
47
+ tokenizer: transformers.PreTrainedTokenizer,
48
+ modalities: List[Modality],
49
+ ):
50
+ super(LMMDataset, self).__init__()
51
+ self.dataset = _resolve_dataset(data_args.dataset_path)
52
+ self.tokenizer = tokenizer
53
+ self.modalities = modalities
54
+
55
+ def __len__(self):
56
+ return len(self.dataset)
57
+
58
+ def get_example(self) -> Dict:
59
+ return self.dataset[0]
60
+
61
+ def __getitem__(self, i) -> Dict:
62
+ try:
63
+ item = self.dataset[i]
64
+ use_multi_task = MultiTaskType.NO_MULTI_TASK
65
+ for m in self.modalities:
66
+ if m.use_multi_task != MultiTaskType.NO_MULTI_TASK:
67
+ use_multi_task = m.use_multi_task
68
+ break
69
+ if use_multi_task != MultiTaskType.NO_MULTI_TASK:
70
+ return encode_chat_multitask(item, self.tokenizer, self.modalities)
71
+ else:
72
+ return encode_chat(item, self.tokenizer, self.modalities)
73
+ except Exception as e:
74
+ new_i = i + 1
75
+ if new_i >= len(self):
76
+ new_i = 0
77
+ logging.error(f"Error encoding chat: {e} index={i} trying index={new_i}")
78
+ return self.__getitem__(new_i)
79
+
80
+
81
+ @dataclass
82
+ class DataCollatorForSupervisedLMMDataset:
83
+ def __init__(
84
+ self,
85
+ tokenizer: transformers.PreTrainedTokenizer,
86
+ modalities: List[Modality],
87
+ ):
88
+ self.tokenizer = tokenizer
89
+ self.modalities = modalities
90
+
91
+ self.use_multi_task = MultiTaskType.NO_MULTI_TASK
92
+ for modality in self.modalities:
93
+ if modality.use_multi_task != MultiTaskType.NO_MULTI_TASK:
94
+ self.use_multi_task = modality.use_multi_task
95
+ break
96
+
97
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, List]:
98
+ input_ids = []
99
+ lmm_labels = []
100
+ task_labels = []
101
+ for instance in instances:
102
+ input_ids.append(instance["input_ids"])
103
+ if self.use_multi_task == MultiTaskType.NO_MULTI_TASK:
104
+ lmm_labels.append(instance["labels"])
105
+ else:
106
+ lmm_labels.append(instance["labels"][0])
107
+ inst_task_labels = []
108
+ for label_id in range(1, len(instance["labels"])):
109
+ inst_task_labels.append(instance["labels"][label_id])
110
+ task_labels.append(inst_task_labels)
111
+
112
+ input_ids = torch.nn.utils.rnn.pad_sequence(
113
+ input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
114
+ )
115
+ # print("Lmm labels 1 type :", type(lmm_labels))
116
+ lmm_labels = torch.nn.utils.rnn.pad_sequence(
117
+ lmm_labels, batch_first=True, padding_value=IGNORE_INDEX
118
+ )
119
+ # print("Lmm labels 2 type :", type(lmm_labels))
120
+
121
+ input_ids = input_ids[:, : self.tokenizer.model_max_length]
122
+ lmm_labels = lmm_labels[:, : self.tokenizer.model_max_length]
123
+ output_labels = [lmm_labels, task_labels]
124
+ batch = dict(
125
+ input_ids=input_ids,
126
+ labels=output_labels,
127
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
128
+ )
129
+
130
+ for m in self.modalities:
131
+ batch[m.name] = [instance[m.name] for instance in instances]
132
+
133
+ return batch
src/sonicverse/requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ transformers>=4.34.0
2
+ accelerate>=0.21.0
3
+ scipy>=1.11.3
4
+ bitsandbytes>=0.41.0
5
+ datasets>=2.14.5
6
+ sentencepiece>=0.1.99
7
+ peft>=0.4.0
8
+ deepspeed==0.9.5
src/sonicverse/scripts/audio_setup.sh ADDED
@@ -0,0 +1,3 @@
1
+ #!/bin/bash
2
+
3
+ pip install librosa soundfile
src/sonicverse/scripts/clap_gpt_build_finetune_dataset.py ADDED
@@ -0,0 +1,155 @@
1
+ from typing import List
2
+ import argparse
3
+ import json
4
+ import os
5
+ import random
6
+ import openai
7
+
8
+ from datasets import Dataset, load_dataset
9
+
10
+ from multi_token.constants import ROLE_ASSISTANT, ROLE_USER
11
+
12
+ PROMPT = """
13
+ You are helping train a sound assistant that can take audio inputs and output text.
14
+
15
+ You can hear an audio file with the following metadata tags:
16
+ {captions}
17
+
18
+ {question}
19
+
20
+ Include the question and answer.
21
+ """
22
+
23
+ QUESTIONS = [
24
+ "Ask a question about the content of the audio.",
25
+ "Ask a complex question about the content of the audio.",
26
+ "Ask a complex question that is relevant to the content of the audio, for example, asking about background knowledge of the things mentioned. Do not ask about uncertain details.",
27
+ "Ask a complex question that is relevant to the content of the audio, for example, asking about the events referred to in the audio. Do not ask about uncertain details.",
28
+ "Ask about your thoughts on the audio.",
29
+ "Ask about what occurs in the audio.",
30
+ "Ask a question on a topic that related to the audio.",
31
+ "Ask a question that classifies the audio in some way.",
32
+ "Ask a question that can only be answered by listening to the audio.",
33
+ ]
34
+
35
+
36
+ OPENAI_TOOLS = [
37
+ {
38
+ "type": "function",
39
+ "function": {
40
+ "name": "create_chat",
41
+ "description": "Create a training example",
42
+ "parameters": {
43
+ "type": "object",
44
+ "properties": {
45
+ "question": {
46
+ "type": "string",
47
+ "description": "The question, must be provided",
48
+ },
49
+ "answer": {
50
+ "type": "string",
51
+ "description": "The answer to the question, must be provided",
52
+ },
53
+ },
54
+ "required": ["question", "answer"],
55
+ },
56
+ },
57
+ }
58
+ ]
59
+
60
+
61
+ def _build_convo(row) -> List:
62
+ client = openai.Client()
63
+
64
+ captions = [row["metadataTags"]]
65
+ paths = [row["url"]]
66
+
67
+ captions_text = "\n".join([f"{cap}" for i, cap in enumerate(captions)])
68
+ prompt = PROMPT.format(
69
+ captions=captions_text, question=random.choice(QUESTIONS)
70
+ ).strip()
71
+
72
+ completion = client.chat.completions.create(
73
+ model="gpt-3.5-turbo-1106",
74
+ messages=[{"role": "system", "content": prompt}],
75
+ tools=OPENAI_TOOLS,
76
+ tool_choice={"type": "function", "function": {"name": "create_chat"}},
77
+ )
78
+ resp = json.loads(completion.choices[0].message.tool_calls[0].function.arguments)
79
+ if "answer" not in resp:
80
+ print(resp)
81
+ q = resp["question"]
82
+ a = resp["answer"]
83
+
84
+ if random.choice([True, False]):
85
+ q = "<sound>" * len(captions) + " " + q
86
+ else:
87
+ q = q + " " + "<sound>" * len(captions)
88
+
89
+ example = {
90
+ "sounds": paths,
91
+ "messages": [
92
+ {
93
+ "role": ROLE_USER,
94
+ "content": q,
95
+ },
96
+ {
97
+ "role": ROLE_ASSISTANT,
98
+ "content": a,
99
+ },
100
+ ],
101
+ }
102
+ return example
103
+
104
+
105
+ def main(args):
106
+ data = load_dataset("Chr0my/Epidemic_sounds", split="train")
107
+ data_idxs = list(range(len(data)))
108
+
109
+ os.makedirs(args.cache_folder, exist_ok=True)
110
+
111
+ def gen(seeds):
112
+ r = random.Random(seeds[0] + 3)
113
+ cache = open(
114
+ os.path.join(args.cache_folder, f"gpt-cache.{seeds[0]}.jsonl"), "a"
115
+ )
116
+ i = 0
117
+ while i < len(seeds):
118
+ selected_idxs = r.sample(data_idxs, k=1)[0]
119
+ selected_example = data[selected_idxs]
120
+ try:
121
+ example = _build_convo(selected_example)
122
+ cache.write(json.dumps(example) + "\n")
123
+ yield example
124
+ i += 1
125
+ except Exception as e:
126
+ print(e)
127
+ continue
128
+ cache.close()
129
+
130
+ ds = Dataset.from_generator(
131
+ gen,
132
+ num_proc=args.num_proc,
133
+ gen_kwargs={"seeds": list(range(args.num_examples))},
134
+ )
135
+ ds.save_to_disk(args.output_folder)
136
+
137
+
138
+ if __name__ == "__main__":
139
+ parser = argparse.ArgumentParser()
140
+ parser.add_argument(
141
+ "-o",
142
+ "--output_folder",
143
+ type=str,
144
+ default="/data/clap-gpt-finetune",
145
+ )
146
+ parser.add_argument(
147
+ "-c",
148
+ "--cache_folder",
149
+ type=str,
150
+ default="/data/clap-gpt-finetune-cache",
151
+ )
152
+ parser.add_argument("-n", "--num_examples", type=int, default=100_000)
153
+ parser.add_argument("-p", "--num_proc", type=int, default=10)
154
+ args = parser.parse_args()
155
+ main(args)
src/sonicverse/scripts/clap_gpt_build_pretrain_dataset.py ADDED
@@ -0,0 +1,142 @@
1
+ from typing import List
2
+ import argparse
3
+ import json
4
+ import os
5
+ import random
6
+ import openai
7
+
8
+ from datasets import Dataset, load_dataset
9
+
10
+ from multi_token.constants import ROLE_ASSISTANT, ROLE_USER
11
+
12
+ PROMPT = """
13
+ You are helping write captions for audio clips.
14
+
15
+ Here are the tags for the audio clip you are captioning:
16
+ {captions}
17
+
18
+ Write a brief caption for the audio clip.
19
+ """
20
+
21
+ PRETRAIN_PHRASES = [
22
+ "What is happening in <sound>?",
23
+ "Describe the sound. <sound>",
24
+ "<sound> Provide a description of the audio.",
25
+ "Can you interpret <sound>?",
26
+ "Please explain what's happening in <sound>",
27
+ "What does <sound> represent?",
28
+ "Could you describe <sound> for me?",
29
+ "What's the content of <sound>?",
30
+ "Can you depict <sound>?",
31
+ "What is <sound>?",
32
+ "In the audo clip, <sound>, what is happening?",
33
+ "Provide a description of the sound. <sound>",
34
+ "Provide a caption for the sound. <sound>",
35
+ ]
36
+
37
+ OPENAI_TOOLS = [
38
+ {
39
+ "type": "function",
40
+ "function": {
41
+ "name": "write_caption",
42
+ "description": "Write a caption for an audio clip",
43
+ "parameters": {
44
+ "type": "object",
45
+ "properties": {
46
+ "caption": {
47
+ "type": "string",
48
+ },
49
+ },
50
+ "required": ["caption"],
51
+ },
52
+ },
53
+ }
54
+ ]
55
+
56
+
57
+ def _build_convo(row) -> List:
58
+ client = openai.Client()
59
+
60
+ captions = [row["metadataTags"]]
61
+ sounds = [row["url"]]
62
+
63
+ captions_text = "\n".join([f'Tags: "{cap}"' for i, cap in enumerate(captions)])
64
+ prompt = PROMPT.format(captions=captions_text).strip()
65
+
66
+ completion = client.chat.completions.create(
67
+ model="gpt-3.5-turbo-1106",
68
+ messages=[{"role": "system", "content": prompt}],
69
+ tools=OPENAI_TOOLS,
70
+ tool_choice={"type": "function", "function": {"name": "write_caption"}},
71
+ )
72
+ resp = json.loads(completion.choices[0].message.tool_calls[0].function.arguments)
73
+ caption = resp["caption"]
74
+
75
+ q = random.choice(PRETRAIN_PHRASES)
76
+
77
+ example = {
78
+ "sounds": sounds,
79
+ "messages": [
80
+ {
81
+ "role": ROLE_USER,
82
+ "content": q,
83
+ },
84
+ {
85
+ "role": ROLE_ASSISTANT,
86
+ "content": caption,
87
+ },
88
+ ],
89
+ }
90
+ return example
91
+
92
+
93
+ def main(args):
94
+ data = load_dataset("Chr0my/Epidemic_sounds", split="train")
95
+
96
+ os.makedirs(args.cache_folder, exist_ok=True)
97
+
98
+ def gen(seeds):
99
+ cache = open(
100
+ os.path.join(args.cache_folder, f"gpt-cache.{seeds[0]}.jsonl"), "a"
101
+ )
102
+ for s in seeds:
103
+ selected_row = data[s]
104
+ try:
105
+ example = _build_convo(selected_row)
106
+ cache.write(json.dumps(example) + "\n")
107
+ yield example
108
+ except Exception as e:
109
+ print(e)
110
+ continue
111
+
112
+ cache.close()
113
+
114
+ idxs = list(range(len(data)))
115
+ random.shuffle(idxs)
116
+
117
+ ds = Dataset.from_generator(
118
+ gen,
119
+ num_proc=args.num_proc,
120
+ gen_kwargs={"seeds": idxs},
121
+ )
122
+ ds.save_to_disk(args.output_folder)
123
+
124
+
125
+ if __name__ == "__main__":
126
+ parser = argparse.ArgumentParser()
127
+ parser.add_argument(
128
+ "-o",
129
+ "--output_folder",
130
+ type=str,
131
+ default="/data/clap-gpt-pretrain",
132
+ )
133
+ parser.add_argument(
134
+ "-c",
135
+ "--cache_folder",
136
+ type=str,
137
+ default="/data/clap-gpt-pretrain-cache",
138
+ )
139
+ parser.add_argument("-n", "--num_examples", type=int, default=500_000)
140
+ parser.add_argument("-p", "--num_proc", type=int, default=10)
141
+ args = parser.parse_args()
142
+ main(args)
src/sonicverse/scripts/document_build_finetune_dataset.py ADDED
@@ -0,0 +1,162 @@
1
+ from typing import List
2
+ import argparse
3
+ import re
4
+ import glob
5
+ import json
6
+
7
+ from datasets import load_dataset
8
+ from datasets import Dataset
9
+
10
+ from multi_token.constants import ROLE_ASSISTANT, ROLE_USER
11
+ from multi_token.modalities.document_gte import (
12
+ split_text_into_documents,
13
+ )
14
+
15
+ TEMP_TOKEN = "<<<TEMP-TOKEN>>>"
16
+
17
+ # regex, doc, prompt
18
+ LONG_ALPACA_REGEXES = [
19
+ (
20
+ r"Below is a paper. Memorize the paper and answer my question after the paper.\n The paper begins. \n ([\s\S]+) \n Now the paper ends. \n([\s\S]+)",
21
+ lambda m: m.group(1),
22
+ lambda m: f"Read the paper {TEMP_TOKEN}. {m.group(2)}",
23
+ ),
24
+ (
25
+ r"Below is a paper. Memorize the material and answer my question after the paper.\n([\s\S]+)\n Now the material ends. ([\s\S]+)",
26
+ lambda m: m.group(1),
27
+ lambda m: f"Read the paper {TEMP_TOKEN}. {m.group(2)}",
28
+ ),
29
+ (
30
+ r"There are two papers. Memorize them and answer my question after the paper.\n The first paper begins. \n ([\s\S]+) Now the second paper ends.([\s\S]+)",
31
+ lambda m: m.group(1),
32
+ lambda m: f"Read the papers {TEMP_TOKEN}. {m.group(2)}",
33
+ ),
34
+ (
35
+ r"Below is some paragraphs in the book, ([\s\S]+?). Memorize the content and answer my question after the book.\n([\s\S]+) \n Now the material ends.([\s\S]+)",
36
+ lambda m: m.group(2),
37
+ lambda m: f"Read the book {m.group(1)} {TEMP_TOKEN}. {m.group(3)}",
38
+ ),
39
+ ]
40
+
41
+ # regex, doc, prompt, answer
42
+ LONG_DATA_REGEXES = [
43
+ (
44
+ r"Write a high-quality answer for the given question using only the provided search results \(some of which might be irrelevant\).([\s\S]+)Question: ([\s\S]+)Answer: ([\s\S]+)\nLong Answer: ([\s\S]+)",
45
+ lambda m: m.group(1).strip(),
46
+ lambda m: f"Write a high-quality answer for the given question using only the provided search results {TEMP_TOKEN}. {m.group(2).strip()}",
47
+ lambda m: m.group(4).strip(),
48
+ ),
49
+ (
50
+ r"([\s\S]+)\nQ: ([\s\S]+)\nA: ([\s\S]+)",
51
+ lambda m: m.group(1).strip(),
52
+ lambda m: f"Read the following book {TEMP_TOKEN}. {m.group(2).strip()}",
53
+ lambda m: m.group(3).strip(),
54
+ ),
55
+ ]
56
+
57
+
58
+ def _write_long_alpaca_convo(row, max_document_chunks) -> List:
59
+ doc_text = None
60
+ prompt = None
61
+ for regex, get_doc, get_prompt in LONG_ALPACA_REGEXES:
62
+ match = re.match(regex, row["instruction"])
63
+ if match:
64
+ doc_text = get_doc(match)
65
+ prompt = get_prompt(match).replace("Question: ", "")
66
+ break
67
+
68
+ if doc_text is None and row["input"]:
69
+ doc_text = row["input"]
70
+ prompt = row["instruction"] + f" {TEMP_TOKEN}"
71
+
72
+ if doc_text is None:
73
+ raise ValueError("No document found")
74
+
75
+ docs = split_text_into_documents(doc_text)
76
+ if len(docs) > max_document_chunks:
77
+ raise ValueError("Document too long")
78
+ example = {
79
+ "id": "longalpaca-" + str(hash(row["instruction"])),
80
+ "documents": docs,
81
+ }
82
+ example["messages"] = [
83
+ {
84
+ "role": ROLE_USER,
85
+ "content": prompt.replace(TEMP_TOKEN, "<document>" * len(docs)),
86
+ },
87
+ {
88
+ "role": ROLE_ASSISTANT,
89
+ "content": row["output"].replace("Answer: ", ""),
90
+ },
91
+ ]
92
+ return example
93
+
94
+
95
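+ # Same conversion for LongCollections JSONL rows, where the regexes also recover the gold answer.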
+ def _write_long_data_collections_convo(row, max_document_chunks) -> dict:
96
+ doc_text = None
97
+ prompt = None
98
+ answer = None
99
+ for regex, get_doc, get_prompt, get_answer in LONG_DATA_REGEXES:
100
+ match = re.match(regex, row["text"])
101
+ if match:
102
+ doc_text = get_doc(match)
103
+ prompt = get_prompt(match)
104
+ answer = get_answer(match).replace(" .", ".")
105
+ break
106
+
107
+ if not doc_text or not prompt or not answer:
108
+ raise ValueError("No document found")
109
+
110
+ docs = split_text_into_documents(doc_text)
111
+ if len(docs) > max_document_chunks:
112
+ raise ValueError("Document too long")
113
+ example = {
114
+ "id": "longdatacollection-" + str(hash(row["text"])),
115
+ "documents": docs,
116
+ }
117
+ example["messages"] = [
118
+ {
119
+ "role": ROLE_USER,
120
+ "content": prompt.replace(TEMP_TOKEN, "<document>" * len(docs)),
121
+ },
122
+ {
123
+ "role": ROLE_ASSISTANT,
124
+ "content": answer,
125
+ },
126
+ ]
127
+ return example
128
+
129
+
130
+ def main(args):
131
+ long_alpaca = load_dataset(args.long_alpaca_path, "train")["train"]
132
+
133
+ def gen():
134
+ for row in long_alpaca:
135
+ try:
136
+ yield _write_long_alpaca_convo(row, args.max_document_chunks)
137
+ except ValueError:
138
+ continue
139
+ for long_collection_fn in glob.iglob(args.long_collections_glob):
140
+ with open(long_collection_fn) as f:
141
+ for line in f:
142
+ row = json.loads(line)
143
+ try:
144
+ yield _write_long_data_collections_convo(
145
+ row, args.max_document_chunks
146
+ )
147
+ except ValueError:
148
+ continue
149
+
150
+ ds = Dataset.from_generator(gen)
151
+ ds = ds.shuffle(seed=42)
152
+ ds.save_to_disk(args.output_folder)
153
+
154
+
155
+ if __name__ == "__main__":
156
+ parser = argparse.ArgumentParser()
157
+ parser.add_argument("--long_alpaca_path", type=str, default="Yukang/LongAlpaca-12k")
158
+ parser.add_argument("--long_collections_glob", type=str)
159
+ parser.add_argument("-o", "--output_folder", type=str)
160
+ parser.add_argument("-c", "--max_document_chunks", type=int, default=256)
161
+ args = parser.parse_args()
162
+ main(args)
src/sonicverse/scripts/document_build_pretrain_dataset.py ADDED
@@ -0,0 +1,89 @@
1
+ from typing import List
2
+ import random
3
+ import argparse
4
+
5
+ from datasets import load_dataset
6
+ from datasets import Dataset
7
+
8
+ from multi_token.constants import ROLE_ASSISTANT, ROLE_USER
9
+ from multi_token.modalities.document_gte import (
10
+ split_text_into_documents,
11
+ )
12
+
13
+ TEMP_TOKEN = "<<<TEMP-TOKEN>>>"
14
+
15
+ PRETRAIN_PHRASES = [
16
+ f"Repeat the content of the document {TEMP_TOKEN}",
17
+ f"Transcribe {TEMP_TOKEN}",
18
+ f"Provide a verbatim transcription of {TEMP_TOKEN}",
19
+ f"Write down exactly what is in {TEMP_TOKEN}",
20
+ f"Copy the text from {TEMP_TOKEN}",
21
+ f"Duplicate the content of {TEMP_TOKEN}",
22
+ f"Reproduce the text in {TEMP_TOKEN}",
23
+ f"Render the exact text from {TEMP_TOKEN}",
24
+ f"Echo the content of {TEMP_TOKEN}",
25
+ f"Mirror the text in {TEMP_TOKEN}",
26
+ f"Reflect the content of {TEMP_TOKEN}",
27
+ f"Transcribe the exact words from {TEMP_TOKEN}",
28
+ f"Write out the exact content of {TEMP_TOKEN}",
29
+ f"Provide a direct transcription of {TEMP_TOKEN}",
30
+ f"Give a word-for-word account of {TEMP_TOKEN}",
31
+ f"Reiterate the exact text of {TEMP_TOKEN}",
32
+ f"Replicate the content of {TEMP_TOKEN}",
33
+ f"Reprint the text from {TEMP_TOKEN}",
34
+ f"Rewrite the exact words from {TEMP_TOKEN}",
35
+ ]
36
+
37
+
38
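+ # Build a transcription-style pretraining example: chunk the article text, reject rows longer than
+ # max_document_chunks, and pair a randomly chosen transcribe-style prompt with the full text as the target.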
+ def _write_convo(row, max_document_chunks) -> dict:
39
+ docs = split_text_into_documents(row["text"])
40
+ if len(docs) > max_document_chunks:
41
+ raise ValueError("Document too long")
42
+ example = {
43
+ "id": str(row["title"]),
44
+ "documents": docs,
45
+ }
46
+ phrase = random.choice(PRETRAIN_PHRASES)
47
+ example["messages"] = [
48
+ {
49
+ "role": ROLE_USER,
50
+ "content": phrase.replace(TEMP_TOKEN, "<document>" * len(docs)),
51
+ },
52
+ {
53
+ "role": ROLE_ASSISTANT,
54
+ "content": row["text"],
55
+ },
56
+ ]
57
+ return example
58
+
59
+
60
+ def main(args):
61
+ wiki_data = load_dataset("graelo/wikipedia", "20230601.en")["train"]
62
+
63
+ idxs = list(range(len(wiki_data)))
64
+ random.shuffle(idxs)
65
+
66
+ def gen():
67
+ i = 0
68
+ for idx in idxs:
69
+ row = wiki_data[idx]
70
+ try:
71
+ yield _write_convo(row, args.max_document_chunks)
72
+ except ValueError:
73
+ pass
74
+ else:
75
+ i += 1
76
+ if i >= args.max_examples:
77
+ break
78
+
79
+ ds = Dataset.from_generator(gen)
80
+ ds.save_to_disk(args.output_folder)
81
+
82
+
83
+ if __name__ == "__main__":
84
+ parser = argparse.ArgumentParser()
85
+ parser.add_argument("-o", "--output_folder", type=str)
86
+ parser.add_argument("-n", "--max_examples", type=int, default=1_000_000)
87
+ parser.add_argument("-c", "--max_document_chunks", type=int, default=4)
88
+ args = parser.parse_args()
89
+ main(args)
src/sonicverse/scripts/document_setup.sh ADDED
@@ -0,0 +1,5 @@
1
+ #!/bin/bash
2
+
3
+ pip install nltk
4
+
5
+ python -c "import nltk; nltk.download('punkt')"
src/sonicverse/scripts/evaluate_model.py ADDED
@@ -0,0 +1,112 @@
1
+ from dataclasses import dataclass, field
2
+ import logging
3
+
4
+ from flask import Flask, request, jsonify
5
+ import transformers
6
+ import torch
7
+
8
+ from datasets import load_from_disk
9
+
10
+ from multi_token.model_utils import MultiTaskType
11
+ from multi_token.training import (
12
+ ModelArguments,
13
+ )
14
+ from multi_token.inference import load_trained_lora_model
15
+ from multi_token.data_tools import encode_chat
16
+
17
+ import evaluate
18
+
19
+ import random
20
+
21
+ PRETRAIN_PHRASES = [
22
+ "What is happening in the given music <sound>?",
23
+ "Describe the sound. <sound>",
24
+ "Describe the music. <sound>",
25
+ "<sound> Provide a description of the music.",
26
+ "<sound> Provide a description of the sound.",
27
+ "Can you interpret <sound>?",
28
+ "Please explain what's happening in <sound>",
29
+ "What does <sound> represent?",
30
+ "Could you describe <sound> for me?",
31
+ "What's the content of <sound>?",
32
+ "Can you depict <sound>?",
33
+ "What is <sound>?",
34
+ "In the music clip, <sound>, what is happening?",
35
+ "Provide a description of the music. <sound>",
36
+ "Provide a description of the sound. <sound>",
37
+ "Provide a caption for the sound. <sound>",
38
+ "Provide a caption for the music. <sound>",
39
+ ]
40
+
41
+
42
+ @dataclass
43
+ class ServeArguments(ModelArguments):
44
+ port: int = field(default=8080)
45
+ host: str = field(default="0.0.0.0")
46
+ load_bits: int = field(default=16)
47
+ max_new_tokens: int = field(default=128)
48
+ temperature: float = field(default=0.01)
49
+
50
+
51
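+ # Run one chat-formatted example through the model and decode only the newly generated tokens.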
+ def generate(input_json):
52
+ encoded_dict = encode_chat(input_json, tokenizer, model.modalities)
53
+
54
+ with torch.inference_mode():
55
+ output_ids = model.generate(
56
+ input_ids=encoded_dict["input_ids"].unsqueeze(0).to(model.device),
57
+ max_new_tokens=serve_args.max_new_tokens,
58
+ use_cache=True,
59
+ do_sample=True,
60
+ temperature=serve_args.temperature,
61
+ modality_inputs={
62
+ m.name: [encoded_dict[m.name]] for m in model.modalities
63
+ },
64
+ )
65
+
66
+ outputs = tokenizer.decode(
67
+ output_ids[0, encoded_dict["input_ids"].shape[0] :],
68
+ skip_special_tokens=True,
69
+ ).strip()
70
+
71
+ return {"output": outputs}
72
+
73
+
74
+ if __name__ == "__main__":
75
+ logging.getLogger().setLevel(logging.INFO)
76
+
77
+ parser = transformers.HfArgumentParser((ServeArguments,))
78
+
79
+ serve_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
80
+
81
+ dataset_path = "/data/musicbench_multitoken_official_split/val"
82
+
83
+ ds = load_from_disk(dataset_path)
84
+
85
+ model, tokenizer = load_trained_lora_model(
86
+ model_name_or_path=serve_args.model_name_or_path,
87
+ model_lora_path=serve_args.model_lora_path,
88
+ load_bits=serve_args.load_bits,
89
+ use_multi_task=MultiTaskType(serve_args.use_multi_task),
90
+ tasks_config=serve_args.tasks_config
91
+ )
92
+
93
+ predictions = []
94
+ references = []
95
+ content_phrase = random.choice(PRETRAIN_PHRASES)
96
+ for data_point_id in range(100):
97
+ data_point = ds[data_point_id]
98
+ # print("datapoint", data_point)
99
+ input_json={"messages": [{"role": "user", "content": content_phrase}], "sounds": data_point["sounds"]}
100
+ output_json = generate(input_json)
101
+
102
+ print("Prediction ",output_json["output"])
103
+ print("Reference ", data_point["messages"][1]["content"])
104
+ print()
105
+ print()
106
+ predictions.append(output_json["output"])
107
+ references.append(data_point["messages"][1]["content"])
108
+
109
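+ # Corpus-level sacreBLEU between generated captions and references (one reference per prediction).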
+ sacrebleu = evaluate.load("sacrebleu")
110
+ sacrebleu_results=sacrebleu.compute(predictions=predictions, references=references)
111
+
112
+ print(sacrebleu_results["score"])
src/sonicverse/scripts/evaluate_model_latest.py ADDED
@@ -0,0 +1,127 @@
1
+
2
+ from dataclasses import dataclass, field
3
+ import logging
4
+ from flask import Flask, request, jsonify
5
+ import transformers
6
+ import torch
7
+ from datasets import load_from_disk
8
+ from multi_token.model_utils import MultiTaskType
9
+ from multi_token.training import ModelArguments
10
+ from multi_token.inference import load_trained_lora_model
11
+ from multi_token.data_tools import encode_chat
12
+ import evaluate
13
+ import random
14
+ import bert_score
15
+ from tqdm import tqdm
16
+
17
+ PRETRAIN_PHRASES_OLD = [
18
+ "Describe the audio in detail"
19
+ ]
20
+
21
+ PRETRAIN_PHRASES = [
22
+ "What is happening in the given music <sound>?",
23
+ "Describe the sound. <sound>",
24
+ "Describe the music. <sound>",
25
+ "<sound> Provide a description of the music.",
26
+ "<sound> Provide a description of the sound.",
27
+ "Can you interpret <sound>?",
28
+ "Please explain what's happening in <sound>",
29
+ "What does <sound> represent?",
30
+ "Could you describe <sound> for me?",
31
+ "What's the content of <sound>?",
32
+ "Can you depict <sound>?",
33
+ "What is <sound>?",
34
+ "In the music clip, <sound>, what is happening?",
35
+ "Provide a description of the music. <sound>",
36
+ "Provide a description of the sound. <sound>",
37
+ "Provide a caption for the sound. <sound>",
38
+ "Provide a caption for the music. <sound>",
39
+ ]
40
+
41
+ random.seed(1234)
42
+
43
+ @dataclass
44
+ class ServeArguments(ModelArguments):
45
+ port: int = field(default=8080)
46
+ host: str = field(default="0.0.0.0")
47
+ load_bits: int = field(default=16)
48
+ max_new_tokens: int = field(default=128)
49
+ temperature: float = field(default=0.01)
50
+
51
+ def generate(input_json):
52
+ encoded_dict = encode_chat(input_json, tokenizer, model.modalities)
53
+ with torch.inference_mode():
54
+ output_ids = model.generate(
55
+ input_ids=encoded_dict["input_ids"].unsqueeze(0).to(model.device),
56
+ max_new_tokens=serve_args.max_new_tokens,
57
+ use_cache=True,
58
+ do_sample=True,
59
+ temperature=serve_args.temperature,
60
+ modality_inputs={
61
+ m.name: [encoded_dict[m.name]] for m in model.modalities
62
+ },
63
+ )
64
+ outputs = tokenizer.decode(
65
+ output_ids[0, encoded_dict["input_ids"].shape[0]:],
66
+ skip_special_tokens=True,
67
+ ).strip()
68
+ return {"output": outputs}
69
+
70
+ if __name__ == "__main__":
71
+ logging.getLogger().setLevel(logging.INFO)
72
+ parser = transformers.HfArgumentParser((ServeArguments,))
73
+ serve_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
74
+ dataset_path = "/data/musicbench_multitoken_official_split/val"
75
+ ds = load_from_disk(dataset_path)
76
+ shuffled_ds = ds.shuffle(seed=1234)
77
+ model, tokenizer = load_trained_lora_model(
78
+ model_name_or_path=serve_args.model_name_or_path,
79
+ model_lora_path=serve_args.model_lora_path,
80
+ load_bits=serve_args.load_bits,
81
+ use_multi_task=MultiTaskType(serve_args.use_multi_task),
82
+ tasks_config=serve_args.tasks_config
83
+ )
84
+
85
+ predictions = []
86
+ references = []
87
+ content_phrase = random.choice(PRETRAIN_PHRASES)
88
+ # for data_point_id in range(len(ds)):
89
+ for data_point_id in tqdm(range(10)):
90
+ data_point = shuffled_ds[data_point_id]
91
+ input_json = {"messages": [{"role": "user", "content": content_phrase}], "sounds": data_point["sounds"]}
92
+ output_json = generate(input_json)
93
+ print("Prediction ", output_json["output"])
94
+ print("Reference ", data_point["messages"][1]["content"])
95
+ print()
96
+ print()
97
+ predictions.append(output_json["output"])
98
+ references.append(data_point["messages"][1]["content"])
99
+
100
+ # Load evaluation metrics
101
+ bleu = evaluate.load("bleu")
102
+ meteor = evaluate.load("meteor")
103
+ rouge = evaluate.load("rouge")
104
+
105
+ # Compute BLEU scores
106
+ bleu_results = bleu.compute(predictions=predictions, references=references, max_order=4)
107
+ print(bleu_results)
108
+ #bleu_score = sum(bleu_results[f"bleu{i}"] for i in range(1, 5)) / 4
109
+
110
+ # Compute METEOR score
111
+ meteor_results = meteor.compute(predictions=predictions, references=references)
112
+ meteor_score = meteor_results["meteor"]
113
+
114
+ # Compute ROUGE-L score
115
+ rouge_results = rouge.compute(predictions=predictions, references=references, rouge_types=["rougeL"])
116
+ rouge_l_score = rouge_results["rougeL"].mid.fmeasure
117
+ #print(rouge_results)
118
+
119
+ # Compute BERT-Score
120
+ P, R, F1 = bert_score.score(predictions, references, lang="en", rescale_with_baseline=True)
121
+ bert_score_f1 = F1.mean().item()
122
+
123
+ # Print results
124
+ #print(f"BLEU Score: {bleu_score}")
125
+ print(f"METEOR Score: {meteor_score}")
126
+ #print(f"ROUGE-L Score: {rouge_l_score}")
127
+ print(f"BERT-Score F1: {bert_score_f1}")
src/sonicverse/scripts/evaluate_model_mullama.py ADDED
@@ -0,0 +1,168 @@
1
+
2
+ from dataclasses import dataclass, field
3
+ import logging
4
+ from flask import Flask, request, jsonify
5
+ import transformers
6
+ import torch
7
+ from datasets import load_from_disk
8
+ from multi_token.model_utils import MultiTaskType
9
+ from multi_token.training import ModelArguments
10
+ from multi_token.inference import load_trained_lora_model
11
+ from multi_token.data_tools import encode_chat
12
+ import evaluate
13
+ import random
14
+ import bert_score
15
+ from tqdm import tqdm
16
+
17
+ from rouge_score import rouge_scorer
18
+ from nltk.translate.bleu_score import sentence_bleu
19
+ from nltk.translate.meteor_score import meteor_score as meteor_scorer
20
+ from nltk.tokenize import wordpunct_tokenize
21
+ import json
22
+ from bert_score import score
23
+ from tqdm.auto import tqdm
24
+
25
+ import yaml
26
+
27
+ scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
28
+
29
+
30
+ PRETRAIN_PHRASES_OLD = [
31
+ "Describe the audio in detail"
32
+ ]
33
+
34
+ PRETRAIN_PHRASES = [
35
+ "What is happening in the given music <sound>?",
36
+ "Describe the sound. <sound>",
37
+ "Describe the music. <sound>",
38
+ "<sound> Provide a description of the music.",
39
+ "<sound> Provide a description of the sound.",
40
+ "Can you interpret <sound>?",
41
+ "Please explain what's happening in <sound>",
42
+ "What does <sound> represent?",
43
+ "Could you describe <sound> for me?",
44
+ "What's the content of <sound>?",
45
+ "Can you depict <sound>?",
46
+ "What is <sound>?",
47
+ "In the music clip, <sound>, what is happening?",
48
+ "Provide a description of the music. <sound>",
49
+ "Provide a description of the sound. <sound>",
50
+ "Provide a caption for the sound. <sound>",
51
+ "Provide a caption for the music. <sound>",
52
+ ]
53
+
54
+ random.seed(1234)
55
+
56
+ @dataclass
57
+ class ServeArguments(ModelArguments):
58
+ port: int = field(default=8080)
59
+ host: str = field(default="0.0.0.0")
60
+ load_bits: int = field(default=16)
61
+ max_new_tokens: int = field(default=128)
62
+ temperature: float = field(default=0.01)
63
+
64
+ def generate(input_json):
65
+ encoded_dict = encode_chat(input_json, tokenizer, model.modalities)
66
+ with torch.inference_mode():
67
+ output_ids = model.generate(
68
+ input_ids=encoded_dict["input_ids"].unsqueeze(0).to(model.device),
69
+ max_new_tokens=serve_args.max_new_tokens,
70
+ use_cache=True,
71
+ do_sample=True,
72
+ temperature=serve_args.temperature,
73
+ modality_inputs={
74
+ m.name: [encoded_dict[m.name]] for m in model.modalities
75
+ },
76
+ )
77
+ outputs = tokenizer.decode(
78
+ output_ids[0, encoded_dict["input_ids"].shape[0]:],
79
+ skip_special_tokens=True,
80
+ ).strip()
81
+ return {"output": outputs}
82
+
83
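+ # Caption metrics averaged over all pairs: ROUGE-L recall, sentence-level BLEU and BLEU-4, METEOR,
+ # and BERTScore recall. Note that this helper shadows the imported `evaluate` package (unused below).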
+ def evaluate(candidates, mult_reference):
84
+ rouge_score, bleu_score, bleu4_score, meteor_score = 0, 0, 0, 0
85
+ for ref, cand in tqdm(zip(mult_reference, candidates), total=len(mult_reference)):
86
+ rouge_score += scorer.score(ref, cand)['rougeL'].recall
87
+ cand_split = wordpunct_tokenize(cand)
88
+ ref_split = wordpunct_tokenize(ref)
89
+ bleu4_score += sentence_bleu([ref], cand, weights=(0.0, 0.0, 0.0, 1.0))
90
+ bleu_score += sentence_bleu([ref], cand)
91
+ meteor_score += meteor_scorer([ref_split], cand_split)
92
+ rouge_score, bleu_score, bleu4_score, meteor_score = rouge_score / (len(candidates)), bleu_score / (len(candidates)), bleu4_score / (len(candidates)), meteor_score / (len(candidates))
93
+ P, R, F1 = score(candidates, mult_reference, lang="en", verbose=True)
94
+ bert_score = R.mean().item()
95
+ #print(f"Model: {model_name}")
96
+ print(f"BLEU Score: {bleu_score}")
97
+ print(f"BLEU-4 Score: {bleu4_score}")
98
+ print(f"METEOR Score: {meteor_score}")
99
+ print(f"ROUGE Score: {rouge_score}")
100
+ print(f"BERT Score: {bert_score}")
101
+
102
+ if __name__ == "__main__":
103
+ logging.getLogger().setLevel(logging.INFO)
104
+ parser = transformers.HfArgumentParser((ServeArguments,))
105
+ serve_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
106
+ dataset_path = "/data/musicbench_multitoken_official_split/val"
107
+ ds = load_from_disk(dataset_path)
108
+ shuffled_ds = ds.shuffle(seed=1234)
109
+ model, tokenizer = load_trained_lora_model(
110
+ model_name_or_path=serve_args.model_name_or_path,
111
+ model_lora_path=serve_args.model_lora_path,
112
+ load_bits=serve_args.load_bits,
113
+ use_multi_task=MultiTaskType(serve_args.use_multi_task),
114
+ tasks_config=serve_args.tasks_config
115
+ )
116
+
117
+ predictions = []
118
+ references = []
119
+ content_phrase = random.choice(PRETRAIN_PHRASES)
120
+ # for data_point_id in range(len(ds)):
121
+ print("len(ds)", len(ds))
122
+ for data_point_id in tqdm(range(100)):
123
+ # for data_point_id in tqdm(range(6831)):
124
+ data_point = shuffled_ds[data_point_id]
125
+ input_json = {"messages": [{"role": "user", "content": content_phrase}], "sounds": data_point["sounds"]}
126
+ output_json = generate(input_json)
127
+ # print("Prediction ", output_json["output"])
128
+ # print("Reference ", data_point["messages"][1]["content"])
129
+ # print()
130
+ # print()
131
+ predictions.append(output_json["output"])
132
+ references.append(data_point["messages"][1]["content"])
133
+
134
+ pairs = {"predictions": predictions, "references": references}
135
+
136
+ evaluate(predictions, references)
137
+
138
+ # with open('/experiments/captioning/mert_tasks_separate_backbone_train_001_ft/checkpoint_1985_test/val_2.yaml', 'w') as file:
139
+ # yaml.dump(pairs, file, default_flow_style=False)
140
+
141
+ # Load evaluation metrics
142
+ # bleu = evaluate.load("bleu")
143
+ # meteor = evaluate.load("meteor")
144
+ # rouge = evaluate.load("rouge")
145
+
146
+ # Compute BLEU scores
147
+ # bleu_results = bleu.compute(predictions=predictions, references=references, max_order=4)
148
+ # print(bleu_results)
149
+ #bleu_score = sum(bleu_results[f"bleu{i}"] for i in range(1, 5)) / 4
150
+
151
+ # Compute METEOR score
152
+ # meteor_results = meteor.compute(predictions=predictions, references=references)
153
+ # meteor_score = meteor_results["meteor"]
154
+
155
+ # Compute ROUGE-L score
156
+ # rouge_results = rouge.compute(predictions=predictions, references=references, rouge_types=["rougeL"])
157
+ # rouge_l_score = rouge_results["rougeL"].mid.fmeasure
158
+ # print(rouge_results)
159
+
160
+ # Compute BERT-Score
161
+ # P, R, F1 = bert_score.score(predictions, references, lang="en", rescale_with_baseline=True)
162
+ # bert_score_f1 = F1.mean().item()
163
+
164
+ # Print results
165
+ #print(f"BLEU Score: {bleu_score}")
166
+ # print(f"METEOR Score: {meteor_score}")
167
+ # print(f"ROUGE-L Score: {rouge_l_score}")
168
+ # print(f"BERT-Score F1: {bert_score_f1}")
src/sonicverse/scripts/evaluate_model_mullama_musiccaps.py ADDED
@@ -0,0 +1,143 @@
1
+
2
+ from dataclasses import dataclass, field
3
+ import logging
4
+ from flask import Flask, request, jsonify
5
+ import transformers
6
+ import torch
7
+ from datasets import load_from_disk
8
+ from multi_token.model_utils import MultiTaskType
9
+ from multi_token.training import ModelArguments
10
+ from multi_token.inference import load_trained_lora_model
11
+ from multi_token.data_tools import encode_chat
12
+ import evaluate
13
+ import random
14
+ import bert_score
15
+ from tqdm import tqdm
16
+
17
+ from rouge_score import rouge_scorer
18
+ from nltk.translate.bleu_score import sentence_bleu
19
+ from nltk.translate.meteor_score import meteor_score as meteor_scorer
20
+ from nltk.tokenize import wordpunct_tokenize
21
+ import json
22
+ from bert_score import score
23
+ from tqdm.auto import tqdm
24
+
25
+ import yaml
26
+
27
+ scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
28
+
29
+
30
+ PRETRAIN_PHRASES_OLD = [
31
+ "Describe the audio in detail"
32
+ ]
33
+
34
+ PRETRAIN_PHRASES = [
35
+ "What is happening in the given music <sound>?",
36
+ "Describe the sound. <sound>",
37
+ "Describe the music. <sound>",
38
+ "<sound> Provide a description of the music.",
39
+ "<sound> Provide a description of the sound.",
40
+ "Can you interpret <sound>?",
41
+ "Please explain what's happening in <sound>",
42
+ "What does <sound> represent?",
43
+ "Could you describe <sound> for me?",
44
+ "What's the content of <sound>?",
45
+ "Can you depict <sound>?",
46
+ "What is <sound>?",
47
+ "In the music clip, <sound>, what is happening?",
48
+ "Provide a description of the music. <sound>",
49
+ "Provide a description of the sound. <sound>",
50
+ "Provide a caption for the sound. <sound>",
51
+ "Provide a caption for the music. <sound>",
52
+ ]
53
+
54
+ random.seed(1234)
55
+
56
+ @dataclass
57
+ class ServeArguments(ModelArguments):
58
+ port: int = field(default=8080)
59
+ host: str = field(default="0.0.0.0")
60
+ load_bits: int = field(default=16)
61
+ max_new_tokens: int = field(default=128)
62
+ temperature: float = field(default=0.01)
63
+
64
+ def generate(input_json):
65
+ encoded_dict = encode_chat(input_json, tokenizer, model.modalities)
66
+ with torch.inference_mode():
67
+ output_ids = model.generate(
68
+ input_ids=encoded_dict["input_ids"].unsqueeze(0).to(model.device),
69
+ max_new_tokens=serve_args.max_new_tokens,
70
+ use_cache=True,
71
+ do_sample=True,
72
+ temperature=serve_args.temperature,
73
+ modality_inputs={
74
+ m.name: [encoded_dict[m.name]] for m in model.modalities
75
+ },
76
+ )
77
+ outputs = tokenizer.decode(
78
+ output_ids[0, encoded_dict["input_ids"].shape[0]:],
79
+ skip_special_tokens=True,
80
+ ).strip()
81
+ return {"output": outputs}
82
+
83
+ def evaluate(candidates, mult_reference):
84
+ rouge_score, bleu_score, bleu4_score, meteor_score = 0, 0, 0, 0
85
+ for ref, cand in tqdm(zip(mult_reference, candidates), total=len(mult_reference)):
86
+ rouge_score += scorer.score(ref, cand)['rougeL'].recall
87
+ cand_split = wordpunct_tokenize(cand)
88
+ ref_split = wordpunct_tokenize(ref)
89
+ bleu4_score += sentence_bleu([ref], cand, weights=(0.0, 0.0, 0.0, 1.0))
90
+ bleu_score += sentence_bleu([ref], cand)
91
+ meteor_score += meteor_scorer([ref_split], cand_split)
92
+ rouge_score, bleu_score, bleu4_score, meteor_score = rouge_score / (len(candidates)), bleu_score / (len(candidates)), bleu4_score / (len(candidates)), meteor_score / (len(candidates))
93
+ P, R, F1 = score(candidates, mult_reference, lang="en", verbose=True)
94
+ bert_score = R.mean().item()
95
+ #print(f"Model: {model_name}")
96
+ print(f"BLEU Score: {bleu_score}")
97
+ print(f"BLEU-4 Score: {bleu4_score}")
98
+ print(f"METEOR Score: {meteor_score}")
99
+ print(f"ROUGE Score: {rouge_score}")
100
+ print(f"BERT Score: {bert_score}")
101
+
102
+ if __name__ == "__main__":
103
+ logging.getLogger().setLevel(logging.INFO)
104
+ parser = transformers.HfArgumentParser((ServeArguments,))
105
+ serve_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
106
+ # dataset_path = "/data/musiccaps/musiccaps_val"
107
+ dataset_path = "/data/musicbench_multitoken_official_split/val/"
108
+ ds = load_from_disk(dataset_path)
109
+ shuffled_ds = ds.shuffle(seed=1234)
110
+ model, tokenizer = load_trained_lora_model(
111
+ model_name_or_path=serve_args.model_name_or_path,
112
+ model_lora_path=serve_args.model_lora_path,
113
+ load_bits=serve_args.load_bits,
114
+ use_multi_task=MultiTaskType(serve_args.use_multi_task),
115
+ tasks_config=serve_args.tasks_config
116
+ )
117
+
118
+ predictions = []
119
+ references = []
120
+ content_phrase = random.choice(PRETRAIN_PHRASES)
121
+ # for data_point_id in range(len(ds)):
122
+ print("len(ds)", len(ds))
123
+ #for data_point in tqdm(ds):
124
+ for data_point_id in tqdm(range(100)):
125
+ #print("DATA POINT ", data_point)
126
+ data_point = ds[data_point_id]
127
+ print("DATA POINT ", data_point)
128
+ input_json = {"messages": [{"role": "user", "content": content_phrase}], "sounds": data_point["sounds"]}
129
+ output_json = generate(input_json)
130
+ #print("Prediction ", output_json["output"])
131
+ #print("Reference ", data_point["caption"])
132
+ #print()
133
+ #print()
134
+ predictions.append(output_json["output"])
135
+ references.append(data_point["messages"][1]["content"])
136
+
137
+ pairs = {"predictions": predictions, "references": references}
138
+
139
+ evaluate(predictions, references)
140
+
141
+ with open('test/musicbench_eval.yaml', 'w') as file:
142
+ yaml.dump(pairs, file, default_flow_style=False)
143
+
src/sonicverse/scripts/evaluate_model_mullama_musiccaps_fixed_prompt.py ADDED
@@ -0,0 +1,138 @@
1
+ from dataclasses import dataclass, field
2
+ import logging
3
+ from flask import Flask, request, jsonify
4
+ import transformers
5
+ import torch
6
+ from datasets import load_from_disk
7
+ from multi_token.model_utils import MultiTaskType
8
+ from multi_token.training import ModelArguments
9
+ from multi_token.inference import load_trained_lora_model
10
+ from multi_token.data_tools import encode_chat
11
+ import evaluate
12
+ import random
13
+ import bert_score
14
+ from tqdm import tqdm
15
+
16
+ from rouge_score import rouge_scorer
17
+ from nltk.translate.bleu_score import sentence_bleu
18
+ from nltk.translate.meteor_score import meteor_score as meteor_scorer
19
+ from nltk.tokenize import wordpunct_tokenize
20
+ import json
21
+ from bert_score import score
22
+ from tqdm.auto import tqdm
23
+
24
+ import yaml
25
+
26
+ scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
27
+
28
+
29
+ PRETRAIN_PHRASES_OLD = [
30
+ "Describe the audio in detail"
31
+ ]
32
+
33
+ PRETRAIN_PHRASES = [
34
+ # "What is happening in the given music <sound>?",
35
+ # "Describe the sound. <sound>",
36
+ # "Describe the music. <sound>",
37
+ # "<sound> Provide a description of the music.",
38
+ # "<sound> Provide a description of the sound.",
39
+ # "Can you interpret <sound>?",
40
+ # "Please explain what's happening in <sound>",
41
+ # "What does <sound> represent?",
42
+ # "Could you describe <sound> for me?",
43
+ # "What's the content of <sound>?",
44
+ # "Can you depict <sound>?",
45
+ # "What is <sound>?",
46
+ # "In the music clip, <sound>, what is happening?",
47
+ # "Provide a description of the music. <sound>",
48
+ # "Provide a description of the sound. <sound>",
49
+ # "Provide a caption for the sound. <sound>",
50
+ "Provide a caption for the music. <sound>",
51
+ ]
52
+
53
+ random.seed(1234)
54
+
55
+ @dataclass
56
+ class ServeArguments(ModelArguments):
57
+ port: int = field(default=8080)
58
+ host: str = field(default="0.0.0.0")
59
+ load_bits: int = field(default=16)
60
+ max_new_tokens: int = field(default=128)
61
+ temperature: float = field(default=0.01)
62
+
63
+ def generate(input_json):
64
+ encoded_dict = encode_chat(input_json, tokenizer, model.modalities)
65
+ with torch.inference_mode():
66
+ output_ids = model.generate(
67
+ input_ids=encoded_dict["input_ids"].unsqueeze(0).to(model.device),
68
+ max_new_tokens=serve_args.max_new_tokens,
69
+ use_cache=True,
70
+ do_sample=True,
71
+ temperature=serve_args.temperature,
72
+ modality_inputs={
73
+ m.name: [encoded_dict[m.name]] for m in model.modalities
74
+ },
75
+ )
76
+ outputs = tokenizer.decode(
77
+ output_ids[0, encoded_dict["input_ids"].shape[0]:],
78
+ skip_special_tokens=True,
79
+ ).strip()
80
+ return {"output": outputs}
81
+
82
+ def evaluate(candidates, mult_reference):
83
+ rouge_score, bleu_score, bleu4_score, meteor_score = 0, 0, 0, 0
84
+ for ref, cand in tqdm(zip(mult_reference, candidates), total=len(mult_reference)):
85
+ rouge_score += scorer.score(ref, cand)['rougeL'].recall
86
+ cand_split = wordpunct_tokenize(cand)
87
+ ref_split = wordpunct_tokenize(ref)
88
+ bleu4_score += sentence_bleu([ref], cand, weights=(0.0, 0.0, 0.0, 1.0))
89
+ bleu_score += sentence_bleu([ref], cand)
90
+ meteor_score += meteor_scorer([ref_split], cand_split)
91
+ rouge_score, bleu_score, bleu4_score, meteor_score = rouge_score / (len(candidates)), bleu_score / (len(candidates)), bleu4_score / (len(candidates)), meteor_score / (len(candidates))
92
+ P, R, F1 = score(candidates, mult_reference, lang="en", verbose=True)
93
+ bert_score = R.mean().item()
94
+ #print(f"Model: {model_name}")
95
+ print(f"BLEU Score: {bleu_score}")
96
+ print(f"BLEU-4 Score: {bleu4_score}")
97
+ print(f"METEOR Score: {meteor_score}")
98
+ print(f"ROUGE Score: {rouge_score}")
99
+ print(f"BERT Score: {bert_score}")
100
+
101
+ if __name__ == "__main__":
102
+ logging.getLogger().setLevel(logging.INFO)
103
+ parser = transformers.HfArgumentParser((ServeArguments,))
104
+ serve_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
105
+ dataset_path = "/data/musiccaps/musiccaps_val"
106
+ ds = load_from_disk(dataset_path)
107
+ shuffled_ds = ds.shuffle(seed=1234)
108
+ model, tokenizer = load_trained_lora_model(
109
+ model_name_or_path=serve_args.model_name_or_path,
110
+ model_lora_path=serve_args.model_lora_path,
111
+ load_bits=serve_args.load_bits,
112
+ use_multi_task=MultiTaskType(serve_args.use_multi_task),
113
+ tasks_config=serve_args.tasks_config
114
+ )
115
+
116
+ predictions = []
117
+ references = []
118
+ content_phrase = random.choice(PRETRAIN_PHRASES)
119
+ # for data_point_id in range(len(ds)):
120
+ print("len(ds)", len(ds))
121
+ for data_point in tqdm(ds):
122
+ print(data_point["audio"])
123
+ # data_point = ds[data_point_id]
124
+ input_json = {"messages": [{"role": "user", "content": content_phrase}], "sounds": [data_point["audio"]]}
125
+ output_json = generate(input_json)
126
+ print("Prediction ", output_json["output"])
127
+ print("Reference ", data_point["caption"])
128
+ print()
129
+ print()
130
+ predictions.append(output_json["output"])
131
+ references.append(data_point["caption"])
132
+
133
+ pairs = {"predictions": predictions, "references": references}
134
+
135
+ evaluate(predictions, references)
136
+
137
+ with open('/experiments/captioning/mert_tasks_separate_backbone_train_001_ft/checkpoint_1985_test/musiccaps_val_fixed_prompt.yaml', 'w') as file:
138
+ yaml.dump(pairs, file, default_flow_style=False)
src/sonicverse/scripts/evaluate_mullama.py ADDED
@@ -0,0 +1,115 @@
1
+ from dataclasses import dataclass, field
2
+ import logging
3
+ from flask import Flask, request, jsonify
4
+ import transformers
5
+ import torch
6
+ from datasets import load_from_disk
7
+ from multi_token.model_utils import MultiTaskType
8
+ from multi_token.training import ModelArguments
9
+ from multi_token.inference import load_trained_lora_model
10
+ from multi_token.data_tools import encode_chat
11
+ import evaluate
12
+ import random
13
+ import bert_score
14
+
15
+ PRETRAIN_PHRASES = [
16
+ "What is happening in the given music <sound>?",
17
+ "Describe the sound. <sound>",
18
+ "Describe the music. <sound>",
19
+ "<sound> Provide a description of the music.",
20
+ "<sound> Provide a description of the sound.",
21
+ "Can you interpret <sound>?",
22
+ "Please explain what's happening in <sound>",
23
+ "What does <sound> represent?",
24
+ "Could you describe <sound> for me?",
25
+ "What's the content of <sound>?",
26
+ "Can you depict <sound>?",
27
+ "What is <sound>?",
28
+ "In the music clip, <sound>, what is happening?",
29
+ "Provide a description of the music. <sound>",
30
+ "Provide a description of the sound. <sound>",
31
+ "Provide a caption for the sound. <sound>",
32
+ "Provide a caption for the music. <sound>",
33
+ ]
34
+
35
+ @dataclass
36
+ class ServeArguments(ModelArguments):
37
+ port: int = field(default=8080)
38
+ host: str = field(default="0.0.0.0")
39
+ load_bits: int = field(default=16)
40
+ max_new_tokens: int = field(default=128)
41
+ temperature: float = field(default=0.01)
42
+
43
+ def generate(input_json):
44
+ encoded_dict = encode_chat(input_json, tokenizer, model.modalities)
45
+ with torch.inference_mode():
46
+ output_ids = model.generate(
47
+ input_ids=encoded_dict["input_ids"].unsqueeze(0).to(model.device),
48
+ max_new_tokens=serve_args.max_new_tokens,
49
+ use_cache=True,
50
+ do_sample=True,
51
+ temperature=serve_args.temperature,
52
+ modality_inputs={
53
+ m.name: [encoded_dict[m.name]] for m in model.modalities
54
+ },
55
+ )
56
+ outputs = tokenizer.decode(
57
+ output_ids[0, encoded_dict["input_ids"].shape[0]:],
58
+ skip_special_tokens=True,
59
+ ).strip()
60
+ return {"output": outputs}
61
+
62
+ if __name__ == "__main__":
63
+ logging.getLogger().setLevel(logging.INFO)
64
+ parser = transformers.HfArgumentParser((ServeArguments,))
65
+ serve_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
66
+ dataset_path = "/data/musicbench_multitoken_official_split/val"
67
+ ds = load_from_disk(dataset_path)
68
+
69
+ # Load MU-LLaMA model and tokenizer
70
+ model_name_or_path = "mu-llama/MU-LLaMA"
71
+ model = transformers.LlamaForCausalLM.from_pretrained(model_name_or_path)
72
+ tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name_or_path)
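+ # NOTE: generate() above reads model.modalities and passes modality_inputs, which a plain
+ # LlamaForCausalLM does not provide; this assumes a multi_token-style wrapped model.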
73
+
74
+ predictions = []
75
+ references = []
76
+ content_phrase = random.choice(PRETRAIN_PHRASES)
77
+ for data_point_id in range(100):
78
+ data_point = ds[data_point_id]
79
+ input_json = {"messages": [{"role": "user", "content": content_phrase}], "sounds": data_point["sounds"]}
80
+ output_json = generate(input_json)
81
+ print("Prediction ", output_json["output"])
82
+ print("Reference ", data_point["messages"][1]["content"])
83
+ print()
84
+ print()
85
+ predictions.append(output_json["output"])
86
+ references.append(data_point["messages"][1]["content"])
87
+
88
+ # Load evaluation metrics
89
+ bleu = evaluate.load("bleu")
90
+ meteor = evaluate.load("meteor")
91
+ rouge = evaluate.load("rouge")
92
+
93
+ # Compute BLEU scores
94
+ bleu_results = bleu.compute(predictions=predictions, references=references, max_order=4)
95
+ # bleu_score = sum(bleu_results[f"bleu{i}"] for i in range(1, 5)) / 4
96
+ print(bleu_results)
97
+
98
+ # Compute METEOR score
99
+ meteor_results = meteor.compute(predictions=predictions, references=references)
100
+ meteor_score = meteor_results["meteor"]
101
+
102
+ # Compute ROUGE-L score
103
+ rouge_results = rouge.compute(predictions=predictions, references=references, rouge_types=["rougeL"])
104
+ #rouge_l_score = rouge_results["rougeL"].mid.fmeasure
105
+ print(rouge_results)
106
+
107
+ # Compute BERT-Score
108
+ P, R, F1 = bert_score.score(predictions, references, lang="en", rescale_with_baseline=True)
109
+ bert_score_f1 = F1.mean().item()
110
+
111
+ # Print results
112
+ # print(f"BLEU Score: {bleu_score}")
113
+ print(f"METEOR Score: {meteor_score}")
114
+ # print(f"ROUGE-L Score: {rouge_l_score}")
115
+ print(f"BERT-Score F1: {bert_score_f1}")
src/sonicverse/scripts/evaluate_temp.py ADDED
@@ -0,0 +1,122 @@
1
+ from dataclasses import dataclass, field
2
+ import logging
3
+ from flask import Flask, request, jsonify
4
+ import transformers
5
+ import torch
6
+ from datasets import load_from_disk
7
+ from multi_token.model_utils import MultiTaskType
8
+ from multi_token.training import ModelArguments
9
+ from multi_token.inference import load_trained_lora_model
10
+ from multi_token.data_tools import encode_chat
11
+ import evaluate
12
+ import random
13
+ import bert_score
14
+ import os
15
+
16
+ os.environ['HF_EVALUATE_OFFLINE'] = '1'
17
+
18
+ PRETRAIN_PHRASES = ["Describe the audio in detail <sound>"]
19
+
20
+ PRETRAIN_PHRASES_old = [
21
+ "What is happening in the given music <sound>?",
22
+ "Describe the sound. <sound>",
23
+ "Describe the music. <sound>",
24
+ "<sound> Provide a description of the music.",
25
+ "<sound> Provide a description of the sound.",
26
+ "Can you interpret <sound>?",
27
+ "Please explain what's happening in <sound>",
28
+ "What does <sound> represent?",
29
+ "Could you describe <sound> for me?",
30
+ "What's the content of <sound>?",
31
+ "Can you depict <sound>?",
32
+ "What is <sound>?",
33
+ "In the music clip, <sound>, what is happening?",
34
+ "Provide a description of the music. <sound>",
35
+ "Provide a description of the sound. <sound>",
36
+ "Provide a caption for the sound. <sound>",
37
+ "Provide a caption for the music. <sound>",
38
+ ]
39
+
40
+ @dataclass
41
+ class ServeArguments(ModelArguments):
42
+ port: int = field(default=8080)
43
+ host: str = field(default="0.0.0.0")
44
+ load_bits: int = field(default=16)
45
+ max_new_tokens: int = field(default=128)
46
+ temperature: float = field(default=0.01)
47
+
48
+ def generate(input_json):
49
+ encoded_dict = encode_chat(input_json, tokenizer, model.modalities)
50
+ with torch.inference_mode():
51
+ output_ids = model.generate(
52
+ input_ids=encoded_dict["input_ids"].unsqueeze(0).to(model.device),
53
+ max_new_tokens=serve_args.max_new_tokens,
54
+ use_cache=True,
55
+ do_sample=True,
56
+ temperature=serve_args.temperature,
57
+ modality_inputs={
58
+ m.name: [encoded_dict[m.name]] for m in model.modalities
59
+ },
60
+ )
61
+ outputs = tokenizer.decode(
62
+ output_ids[0, encoded_dict["input_ids"].shape[0]:],
63
+ skip_special_tokens=True,
64
+ ).strip()
65
+ return {"output": outputs}
66
+
67
+ if __name__ == "__main__":
68
+ logging.getLogger().setLevel(logging.INFO)
69
+ parser = transformers.HfArgumentParser((ServeArguments,))
70
+ serve_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
71
+ dataset_path = "/data/musicbench_multitoken_official_split/val"
72
+ ds = load_from_disk(dataset_path)
73
+ model, tokenizer = load_trained_lora_model(
74
+ model_name_or_path=serve_args.model_name_or_path,
75
+ model_lora_path=serve_args.model_lora_path,
76
+ load_bits=serve_args.load_bits,
77
+ use_multi_task=MultiTaskType(serve_args.use_multi_task),
78
+ tasks_config=serve_args.tasks_config
79
+ )
80
+
81
+ predictions = []
82
+ references = []
83
+ content_phrase = random.choice(PRETRAIN_PHRASES)
84
+ for data_point_id in range(10):
85
+ data_point = ds[data_point_id]
86
+ input_json = {"messages": [{"role": "user", "content": content_phrase}], "sounds": data_point["sounds"]}
87
+ output_json = generate(input_json)
88
+ print("Prediction ", output_json["output"])
89
+ print("Reference ", data_point["messages"][1]["content"])
90
+ print()
91
+ print()
92
+ predictions.append(output_json["output"])
93
+ references.append(data_point["messages"][1]["content"])
94
+
95
+ # Load evaluation metrics
96
+ bleu = evaluate.load("bleu")
97
+ meteor = evaluate.load("meteor")
98
+ rouge = evaluate.load("rouge")
99
+
100
+ # Compute BLEU scores
101
+ bleu_results = bleu.compute(predictions=predictions, references=references, max_order=4)
102
+ print(bleu_results)
103
+ #bleu_score = sum(bleu_results[f"bleu{i}"] for i in range(1, 5)) / 4
104
+
105
+ # Compute METEOR score
106
+ meteor_results = meteor.compute(predictions=predictions, references=references)
107
+ meteor_score = meteor_results["meteor"]
108
+
109
+ # Compute ROUGE-L score
110
+ rouge_results = rouge.compute(predictions=predictions, references=references, rouge_types=["rougeL"])
111
+ # rouge_l_score = rouge_results["rougeL"].mid.fmeasure
112
+ print(rouge_results)
113
+
114
+ # Compute BERT-Score
115
+ P, R, F1 = bert_score.score(predictions, references, lang="en", rescale_with_baseline=True)
116
+ bert_score_f1 = F1.mean().item()
117
+
118
+ # Print results
119
+ #print(f"BLEU Score: {bleu_score}")
120
+ print(f"METEOR Score: {meteor_score}")
121
+ # print(f"ROUGE-L Score: {rouge_l_score}")
122
+ print(f"BERT-Score F1: {bert_score_f1}")