File size: 1,932 Bytes
7c34c28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from multi_token.model_utils import MultiTaskType
from multi_token.modalities.vision_clip import (
    CLIPVisionModality,
    OUTPUT_LAYER as CLIP_POOL_LAYER,
)
from multi_token.modalities.imagebind import ImageBindModality
from multi_token.modalities.document_gte import DocumentGTEModality
from multi_token.modalities.audio_whisper import WhisperAudioModality
from multi_token.modalities.audio_clap import CLAPAudioModality
from multi_token.modalities.video_xclip import XCLIPVideoModality
from multi_token.modalities.audio_descript import DescriptAudioModality
from multi_token.modalities.audio_mert import MERTAudioModality

# Registry of modality factories, keyed by modality name.
#
# Each value is a callable that returns a one-element list containing a newly
# constructed modality instance.
#
# Most factories take no arguments. The "audio_mert", "audio_clap" and
# "audio_descript" factories are different: they accept `use_multi_task`
# (default `MultiTaskType.NO_MULTI_TASK`) and `tasks_config` (default `None`)
# and pass both straight through to the modality constructor. Those defaults
# are lambda default arguments, so they are evaluated once, when this dict is
# built at import time.
MODALITY_BUILDERS = {
    # --- vision ------------------------------------------------------------
    "vision_clip": lambda: [CLIPVisionModality()],
    # CLIP variant whose features are read from vision_clip.OUTPUT_LAYER
    # (imported here as CLIP_POOL_LAYER), emitting 10 output tokens.
    "vision_clip_pool": lambda: [
        CLIPVisionModality(
            num_tokens_output=10,
            feature_layer=CLIP_POOL_LAYER,
        )
    ],
    # --- audio -------------------------------------------------------------
    "audio_whisper": lambda: [
        WhisperAudioModality(
            model_name_or_path="openai/whisper-small",
            num_tokens_output=10,
        )
    ],
    "audio_mert": (
        lambda use_multi_task=MultiTaskType.NO_MULTI_TASK, tasks_config=None: [
            MERTAudioModality(
                num_tokens_output=60,
                hidden_dim=32,
                num_conv_layers=3,
                num_mlp_layers=2,
                use_multi_task=use_multi_task,
                tasks_config=tasks_config,
            )
        ]
    ),
    "audio_clap": (
        lambda use_multi_task=MultiTaskType.NO_MULTI_TASK, tasks_config=None: [
            CLAPAudioModality(
                num_tokens_output=20,
                use_multi_task=use_multi_task,
                tasks_config=tasks_config,
            )
        ]
    ),
    "audio_descript": (
        lambda use_multi_task=MultiTaskType.NO_MULTI_TASK, tasks_config=None: [
            DescriptAudioModality(
                num_tokens_output=60,
                codebooks=96,
                num_projector_conv_layers=1,
                num_projector_mlp_layers=1,
                use_multi_task=use_multi_task,
                tasks_config=tasks_config,
            )
        ]
    ),
    # --- video -------------------------------------------------------------
    "video_xclip": lambda: [XCLIPVideoModality(num_tokens_output=10)],
    # --- multi-modal embedding ---------------------------------------------
    "imagebind": lambda: [ImageBindModality()],
    # --- documents ---------------------------------------------------------
    "document_gte": lambda: [DocumentGTEModality()],
    # NOTE(review): the key says "x16" but 32 output tokens are requested --
    # confirm which one is intended.
    "document_gte_x16": lambda: [DocumentGTEModality(num_tokens_output=32)],
}