diff --git a/src/sonicverse/configs/tasks.json b/src/sonicverse/configs/tasks.json new file mode 100644 index 0000000000000000000000000000000000000000..84b6d6f4fd725fcf51cfd4c6b52beb347fb33a4b --- /dev/null +++ b/src/sonicverse/configs/tasks.json @@ -0,0 +1,208 @@ +{ + "backbone": { + "num_layers": 5, + "input_channels": 25, + "output_channels": 25, + "output_size": 4096, + "hidden_size": 4096, + "requires_grad": true + }, + "task_heads": { + "lmm_projector": { + "num_layers": 3, + "output_size": 4096, + "hidden_size": 4096, + "input_size": 768, + "input_channels": 13, + "width": 40, + "weight": 1.0, + "model_type": "mlp", + "requires_grad": true, + "use_aggregator": true, + "use_time_average": true, + "use_sigmoid": false, + "use_transpose": false, + "use_backbone_output": false + }, + "instrument_detection": { + "model_type": "mlp", + "use_aggregator": true, + "use_time_average": true, + "use_sigmoid": true, + "use_transpose": false, + "num_layers": 2, + "input_size": 508, + "output_size": 40, + "hidden_size": 4096, + "width": 1, + "weight": 0.1, + "requires_grad": true, + "num_conv_layers": 4, + "output_channel": 1 + }, + "mood_detection": { + "model_type": "mlp", + "use_aggregator": true, + "use_time_average": true, + "use_sigmoid": true, + "use_transpose": false, + "num_layers": 2, + "input_size": 508, + "output_size": 56, + "hidden_size": 4096, + "width": 1, + "weight": 0.1, + "requires_grad": true, + "num_conv_layers": 4, + "output_channel": 1 + }, + "genre_detection": { + "model_type": "mlp", + "use_aggregator": true, + "use_time_average": true, + "use_sigmoid": true, + "use_transpose": false, + "num_layers": 2, + "input_size": 508, + "output_size": 87, + "hidden_size": 4096, + "width": 1, + "weight": 0.1, + "requires_grad": true, + "num_conv_layers": 4, + "output_channel": 1 + }, + "key_detection": { + "model_type": "mlp", + "use_aggregator": true, + "use_time_average": true, + "use_sigmoid": true, + "use_transpose": false, + "num_layers": 2, + "input_size": 508, + "output_size": 24, + "hidden_size": 4096, + "width": 1, + "weight": 0.1, + "requires_grad": true, + "num_conv_layers": 4, + "output_channel": 1 + }, + "vocals_detection": { + "model_type": "mlp", + "use_aggregator": true, + "use_time_average": true, + "use_sigmoid": true, + "use_transpose": false, + "num_layers": 2, + "input_size": 508, + "output_size": 3, + "hidden_size": 4096, + "width": 1, + "weight": 0.1, + "requires_grad": true, + "num_conv_layers": 4, + "output_channel": 1 + } + }, + "task_projectors": { + "instrument_detection": { + "model_type": "mlp", + "num_layers": 3, + "input_channels": 0, + "input_size": 40, + "output_size": 4096, + "hidden_size": 4096, + "width": 4, + "use_aggregator": false, + "use_time_average": false, + "use_sigmoid": false, + "use_transpose": false, + "requires_grad": true + }, + "mood_detection": { + "model_type": "mlp", + "num_layers": 3, + "input_channels": 0, + "input_size": 56, + "output_size": 4096, + "hidden_size": 4096, + "width": 4, + "use_aggregator": false, + "use_time_average": false, + "use_sigmoid": false, + "use_transpose": false, + "requires_grad": true + }, + "genre_detection": { + "model_type": "mlp", + "num_layers": 3, + "input_channels": 0, + "input_size": 87, + "output_size": 4096, + "hidden_size": 4096, + "width": 4, + "use_aggregator": false, + "use_time_average": false, + "use_sigmoid": false, + "use_transpose": false, + "requires_grad": true + }, + "key_detection": { + "model_type": "mlp", + "num_layers": 3, + "input_channels": 0, + "input_size": 24, + 
"output_size": 4096, + "hidden_size": 4096, + "width": 4, + "use_aggregator": false, + "use_time_average": false, + "use_sigmoid": false, + "use_transpose": false, + "requires_grad": true + }, + "vocals_detection": { + "model_type": "mlp", + "num_layers": 3, + "input_channels": 0, + "input_size": 3, + "output_size": 4096, + "hidden_size": 4096, + "width": 4, + "use_aggregator": false, + "use_time_average": false, + "use_sigmoid": false, + "use_transpose": false, + "requires_grad": true + }, + "chords_detection": { + "model_type": "mlp", + "num_layers": 3, + "input_channels": 0, + "input_size": 216, + "output_size": 4096, + "hidden_size": 4096, + "width": 4, + "use_aggregator": false, + "use_time_average": false, + "use_sigmoid": false, + "use_transpose": false, + "requires_grad": true + }, + "beats_detection": { + "model_type": "mlp_conv_agg", + "num_layers": 3, + "input_channels": 2, + "input_size": 500, + "output_size": 4096, + "hidden_size": 4096, + "width": 4, + "use_aggregator": true, + "use_time_average": false, + "use_sigmoid": false, + "use_transpose": true, + "requires_grad": true + } + } +} diff --git a/src/sonicverse/configs/tasks_baseline.json b/src/sonicverse/configs/tasks_baseline.json new file mode 100644 index 0000000000000000000000000000000000000000..1c21d86c7610a5f4d24b892c416ffbc8b2a0a8d6 --- /dev/null +++ b/src/sonicverse/configs/tasks_baseline.json @@ -0,0 +1,20 @@ +{ + "task_heads": { + "lmm_projector": { + "num_layers": 3, + "output_size": 4096, + "hidden_size": 4096, + "input_size": 768, + "input_channels": 13, + "width": 60, + "weight": 1.0, + "model_type": "mlp", + "requires_grad": true, + "use_aggregator": true, + "use_time_average": true, + "use_sigmoid": false, + "use_transpose": false + } + }, + "task_projectors": {} +} diff --git a/src/sonicverse/configs/tasks_ft.json b/src/sonicverse/configs/tasks_ft.json new file mode 100644 index 0000000000000000000000000000000000000000..6359451331e62756ff884e412944a559500b8505 --- /dev/null +++ b/src/sonicverse/configs/tasks_ft.json @@ -0,0 +1,208 @@ +{ + "backbone": { + "num_layers": 5, + "input_channels": 25, + "output_channels": 25, + "output_size": 4096, + "hidden_size": 4096, + "requires_grad": false + }, + "task_heads": { + "lmm_projector": { + "num_layers": 3, + "output_size": 4096, + "hidden_size": 4096, + "input_size": 768, + "input_channels": 13, + "width": 40, + "weight": 1.0, + "model_type": "mlp", + "requires_grad": true, + "use_aggregator": true, + "use_time_average": true, + "use_sigmoid": false, + "use_transpose": false, + "use_backbone_output": false + }, + "instrument_detection": { + "model_type": "mlp", + "use_aggregator": true, + "use_time_average": true, + "use_sigmoid": true, + "use_transpose": false, + "num_layers": 2, + "input_size": 508, + "output_size": 40, + "hidden_size": 4096, + "width": 1, + "weight": 0.0, + "requires_grad": false, + "num_conv_layers": 4, + "output_channel": 1 + }, + "mood_detection": { + "model_type": "mlp", + "use_aggregator": true, + "use_time_average": true, + "use_sigmoid": true, + "use_transpose": false, + "num_layers": 2, + "input_size": 508, + "output_size": 56, + "hidden_size": 4096, + "width": 1, + "weight": 0.0, + "requires_grad": false, + "num_conv_layers": 4, + "output_channel": 1 + }, + "genre_detection": { + "model_type": "mlp", + "use_aggregator": true, + "use_time_average": true, + "use_sigmoid": true, + "use_transpose": false, + "num_layers": 2, + "input_size": 508, + "output_size": 87, + "hidden_size": 4096, + "width": 1, + "weight": 0.0, + "requires_grad": 
false, + "num_conv_layers": 4, + "output_channel": 1 + }, + "key_detection": { + "model_type": "mlp", + "use_aggregator": true, + "use_time_average": true, + "use_sigmoid": true, + "use_transpose": false, + "num_layers": 2, + "input_size": 508, + "output_size": 24, + "hidden_size": 4096, + "width": 1, + "weight": 0.0, + "requires_grad": false, + "num_conv_layers": 4, + "output_channel": 1 + }, + "vocals_detection": { + "model_type": "mlp", + "use_aggregator": true, + "use_time_average": true, + "use_sigmoid": true, + "use_transpose": false, + "num_layers": 2, + "input_size": 508, + "output_size": 3, + "hidden_size": 4096, + "width": 1, + "weight": 0.0, + "requires_grad": false, + "num_conv_layers": 4, + "output_channel": 1 + } + }, + "task_projectors": { + "instrument_detection": { + "model_type": "mlp", + "num_layers": 3, + "input_channels": 0, + "input_size": 40, + "output_size": 4096, + "hidden_size": 4096, + "width": 4, + "use_aggregator": false, + "use_time_average": false, + "use_sigmoid": false, + "use_transpose": false, + "requires_grad": true + }, + "mood_detection": { + "model_type": "mlp", + "num_layers": 3, + "input_channels": 0, + "input_size": 56, + "output_size": 4096, + "hidden_size": 4096, + "width": 4, + "use_aggregator": false, + "use_time_average": false, + "use_sigmoid": false, + "use_transpose": false, + "requires_grad": true + }, + "genre_detection": { + "model_type": "mlp", + "num_layers": 3, + "input_channels": 0, + "input_size": 87, + "output_size": 4096, + "hidden_size": 4096, + "width": 4, + "use_aggregator": false, + "use_time_average": false, + "use_sigmoid": false, + "use_transpose": false, + "requires_grad": true + }, + "key_detection": { + "model_type": "mlp", + "num_layers": 3, + "input_channels": 0, + "input_size": 24, + "output_size": 4096, + "hidden_size": 4096, + "width": 4, + "use_aggregator": false, + "use_time_average": false, + "use_sigmoid": false, + "use_transpose": false, + "requires_grad": true + }, + "vocals_detection": { + "model_type": "mlp", + "num_layers": 3, + "input_channels": 0, + "input_size": 3, + "output_size": 4096, + "hidden_size": 4096, + "width": 4, + "use_aggregator": false, + "use_time_average": false, + "use_sigmoid": false, + "use_transpose": false, + "requires_grad": true + }, + "chords_detection": { + "model_type": "mlp", + "num_layers": 3, + "input_channels": 0, + "input_size": 216, + "output_size": 4096, + "hidden_size": 4096, + "width": 4, + "use_aggregator": false, + "use_time_average": false, + "use_sigmoid": false, + "use_transpose": false, + "requires_grad": true + }, + "beats_detection": { + "model_type": "mlp_conv_agg", + "num_layers": 3, + "input_channels": 2, + "input_size": 500, + "output_size": 4096, + "hidden_size": 4096, + "width": 4, + "use_aggregator": true, + "use_time_average": false, + "use_sigmoid": false, + "use_transpose": true, + "requires_grad": true + } + } +} diff --git a/src/sonicverse/configs/tasks_pt_weight.json b/src/sonicverse/configs/tasks_pt_weight.json new file mode 100644 index 0000000000000000000000000000000000000000..5df31b4a22eddfb40310227d93a31285486b66f1 --- /dev/null +++ b/src/sonicverse/configs/tasks_pt_weight.json @@ -0,0 +1,10 @@ +{ + "pretrained_paths": [ + { + "path": "/experiments/music_extraction/mlp_shared_multi_task_trial_002/train_002_epoch_45_step_40.pth", + "components": ["backbone", "instrument_detection", "genre_detection", "mood_detection", "key_detection", "vocals_detection"], + "use_prefix": true, + "prefix": "audio_mert_lmm_projector" + } + ] +} diff --git 
a/src/sonicverse/configs/zero2.json b/src/sonicverse/configs/zero2.json new file mode 100644 index 0000000000000000000000000000000000000000..c95ebefe07b7d8d9fd0936a014679d07102cc270 --- /dev/null +++ b/src/sonicverse/configs/zero2.json @@ -0,0 +1,23 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + "train_micro_batch_size_per_gpu": "auto", + "train_batch_size": "auto", + "gradient_accumulation_steps": "auto", + "zero_optimization": { + "stage": 2, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto" + } +} \ No newline at end of file diff --git a/src/sonicverse/configs/zero3.json b/src/sonicverse/configs/zero3.json new file mode 100644 index 0000000000000000000000000000000000000000..6917317af62da757ca759a92b326ddfa65b203cc --- /dev/null +++ b/src/sonicverse/configs/zero3.json @@ -0,0 +1,28 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + "train_micro_batch_size_per_gpu": "auto", + "train_batch_size": "auto", + "gradient_accumulation_steps": "auto", + "zero_optimization": { + "stage": 3, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + } +} \ No newline at end of file diff --git a/src/sonicverse/configs/zero3_offload.json b/src/sonicverse/configs/zero3_offload.json new file mode 100644 index 0000000000000000000000000000000000000000..74ab0134e0eacb48fa64f9d34d73708571331687 --- /dev/null +++ b/src/sonicverse/configs/zero3_offload.json @@ -0,0 +1,56 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto" + } + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": "auto", + "warmup_max_lr": "auto", + "warmup_num_steps": "auto" + } + }, + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "cpu", + "pin_memory": true + }, + "offload_param": { + "device": "cpu", + "pin_memory": true + }, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "gather_16bit_weights_on_model_save": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "steps_per_print": 1e5, + "wall_clock_breakdown": false + } \ No newline at end of file diff --git a/src/sonicverse/multi_token.egg-info/PKG-INFO b/src/sonicverse/multi_token.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..38ae45b8de0c8699e8a4b7b05956ca44dae62373 --- /dev/null +++ b/src/sonicverse/multi_token.egg-info/PKG-INFO @@ -0,0 +1,6 @@ +Metadata-Version: 2.1 +Name: 
multi-token +Version: 0.0.4 +Home-page: https://github.com/sshh12/multi_token +Author: Shrivu Shankar +License: Apache License 2.0 diff --git a/src/sonicverse/multi_token.egg-info/SOURCES.txt b/src/sonicverse/multi_token.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..2294aae83797dcafadbf5b6d906074799c192973 --- /dev/null +++ b/src/sonicverse/multi_token.egg-info/SOURCES.txt @@ -0,0 +1,6 @@ +setup.py +multi_token.egg-info/PKG-INFO +multi_token.egg-info/SOURCES.txt +multi_token.egg-info/dependency_links.txt +multi_token.egg-info/requires.txt +multi_token.egg-info/top_level.txt \ No newline at end of file diff --git a/src/sonicverse/multi_token.egg-info/dependency_links.txt b/src/sonicverse/multi_token.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/src/sonicverse/multi_token.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/src/sonicverse/multi_token.egg-info/requires.txt b/src/sonicverse/multi_token.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..f129b5fe1018aed6ca7016b870189e6323615197 --- /dev/null +++ b/src/sonicverse/multi_token.egg-info/requires.txt @@ -0,0 +1,8 @@ +transformers>=4.34.0 +accelerate>=0.21.0 +scipy>=1.11.3 +bitsandbytes>=0.41.0 +datasets>=2.14.5 +sentencepiece>=0.1.99 +peft>=0.4.0 +deepspeed==0.9.5 diff --git a/src/sonicverse/multi_token.egg-info/top_level.txt b/src/sonicverse/multi_token.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/src/sonicverse/multi_token.egg-info/top_level.txt @@ -0,0 +1 @@ + diff --git a/src/sonicverse/multi_token/constants.py b/src/sonicverse/multi_token/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..8c5675e1550da65643875ee694f43f03fe4b26bb --- /dev/null +++ b/src/sonicverse/multi_token/constants.py @@ -0,0 +1,4 @@ +IGNORE_INDEX = -100 + +ROLE_ASSISTANT = "assistant" +ROLE_USER = "user" diff --git a/src/sonicverse/multi_token/data_tools.py b/src/sonicverse/multi_token/data_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..d3abdaeb2306ee2dd979267290624279ba488f77 --- /dev/null +++ b/src/sonicverse/multi_token/data_tools.py @@ -0,0 +1,336 @@ +from typing import Dict, List, Any, Union, Optional +from collections import Counter +from functools import cache +import contextlib +import tempfile +import shutil +import random +import subprocess +import json +import re +import io +import os + +import torch +import requests +import transformers +import numpy as np +from datasets import load_dataset, Dataset +from PIL import Image + +from multi_token.constants import IGNORE_INDEX + + +def encode_chat( + item: Dict, + tokenizer: transformers.PreTrainedTokenizer, + modalities: List["Modality"], +) -> Dict: + messages = list(item["messages"]) + chat_as_string = tokenizer.apply_chat_template(messages, tokenize=False) + + token_to_modality = {m.token: m for m in modalities} + modality_token_counts = Counter() + instruct_pattern = r"(\[INST\][\s\S]*?\[\/INST\])" + pattern = "(" + "|".join(re.escape(m.token) for m in modalities) + ")" + + chat_part = re.split(instruct_pattern, chat_as_string) + input_ids = [] + labels = [] + for part in chat_part: + if "[INST]" in part: + is_instruction = True + else: + is_instruction = False + for subpart in re.split(pattern, part): + if not subpart: + continue 
+ if subpart in token_to_modality: + assert ( + is_instruction + ), "There should be no modality tokens outside of instructions" + m = token_to_modality[subpart] + modality_token_counts[m.name] += 1 + input_ids.extend([m.token_idx] * m.token_width) + labels.extend([IGNORE_INDEX] * m.token_width) + elif is_instruction: + part_ids = tokenizer(subpart, add_special_tokens=False).input_ids + input_ids.extend(part_ids) + labels.extend([IGNORE_INDEX] * len(part_ids)) + else: + part_ids = tokenizer(subpart, add_special_tokens=False).input_ids + input_ids.extend(part_ids) + labels.extend(part_ids) + + input_ids = torch.tensor(input_ids, dtype=torch.long) + labels = torch.tensor(labels, dtype=torch.long) + + data_dict = dict( + input_ids=input_ids, + labels=labels, + ) + for m in modalities: + data_dict[m.name] = m.preprocess_rows([item])[0] + return data_dict + +def encode_chat_multitask( + item: Dict, + tokenizer: transformers.PreTrainedTokenizer, + modalities: List["Modality"], +) -> Dict: + messages = list(item["messages"]) + chat_as_string = tokenizer.apply_chat_template(messages, tokenize=False) + + token_to_modality = {m.token: m for m in modalities} + modality_token_counts = Counter() + instruct_pattern = r"(\[INST\][\s\S]*?\[\/INST\])" + pattern = "(" + "|".join(re.escape(m.token) for m in modalities) + ")" + + chat_part = re.split(instruct_pattern, chat_as_string) + input_ids = [] + labels = [] + labels.append([]) + for part in chat_part: + if "[INST]" in part: + is_instruction = True + else: + is_instruction = False + for subpart in re.split(pattern, part): + if not subpart: + continue + if subpart in token_to_modality: + assert ( + is_instruction + ), "There should be no modality tokens outside of instructions" + m = token_to_modality[subpart] + modality_token_counts[m.name] += 1 + input_ids.extend([m.token_idx] * m.token_width) + labels[0].extend([IGNORE_INDEX] * m.token_width) + elif is_instruction: + part_ids = tokenizer(subpart, add_special_tokens=False).input_ids + input_ids.extend(part_ids) + labels[0].extend([IGNORE_INDEX] * len(part_ids)) + else: + part_ids = tokenizer(subpart, add_special_tokens=False).input_ids + input_ids.extend(part_ids) + labels[0].extend(part_ids) + + input_ids = torch.tensor(input_ids, dtype=torch.long) + labels[0] = torch.tensor(labels[0], dtype=torch.long) + + task_list = [] + for m in modalities: + task_list += m.tasks["task_heads"].keys() + # labels[task_specs["task_id"]] = load_tensor(item[task_name][0]) + + for task_name in task_list: + if task_name != "lmm_projector": + labels.append(load_tensor(item[task_name][0])) + + # labels = torch.tensor(labels, dtype=torch.long) + + data_dict = dict( + input_ids=input_ids, + labels=labels, + ) + for m in modalities: + data_dict[m.name] = m.preprocess_rows([item])[0] + return data_dict + +def load_tensor(path: str) -> np.ndarray: + return torch.tensor(np.load(path)) + + +def load_image(value: Any) -> Image.Image: + img = None + if isinstance(value, str): + if value.startswith("http://") or value.startswith("https://"): + response = requests.get(value) + img = Image.open(io.BytesIO(response.content)) + elif os.path.exists(value): + img = Image.open(value) + elif isinstance(value, Image.Image): + img = value + if img is None: + raise ValueError(f"Could not load image from {value}") + img = img.convert("RGB") + return img + + +@contextlib.contextmanager +def with_local_files(fn_or_urls: List[Any]): + local_fns = [] + fps = [] + for fn_or_url in fn_or_urls: + if isinstance(fn_or_url, Image.Image): + fp = 
tempfile.NamedTemporaryFile(suffix=".png", mode="wb") + fn_or_url.convert("RGB").save(fp) + fps.append(fp) + local_fns.append(fp.name) + elif fn_or_url.startswith("http://") or fn_or_url.startswith("https://"): + suffix = os.path.splitext(fn_or_url)[-1] + with requests.get(fn_or_url, stream=True) as r: + fp = tempfile.NamedTemporaryFile(suffix=suffix, mode="wb") + shutil.copyfileobj(r.raw, fp) + fps.append(fp) + local_fns.append(fp.name) + else: + local_fns.append(fn_or_url) + try: + yield local_fns + finally: + for fp in fps: + fp.close() + + +@cache +def _get_dataset(dataset_args: str) -> Dataset: + return load_dataset(**json.loads(dataset_args)) + + +def get_dataset_cached(dataset_args: Dict) -> Dataset: + return _get_dataset(json.dumps(dataset_args)) + + +def load_audio_signal(input_: Union[Dict, str]) -> Dict: + from audiotools import AudioSignal + + if isinstance(input_, dict) and "array" in input_: + array = input_["array"] + elif isinstance(input_, dict) and "dataset_args" in input_: + item = get_dataset_cached(input_["dataset_args"])[input_["idx"]] + array = item["audio"]["array"] + elif isinstance(input_, dict) and "path" in input_: + with with_local_files([input_["path"]]) as local_fns: + array = AudioSignal(local_fns[0]) + elif isinstance(input_, str): + with with_local_files([input_]) as local_fns: + array = AudioSignal(local_fns[0]) + else: + raise ValueError(f"Could not load audio from {input_}") + + return {"array": list(array)} + + +def load_audio(input_: Union[Dict, str], target_sampling_rate: int = None) -> Dict: + import soundfile as sf + import librosa + + if isinstance(input_, dict) and "array" in input_ and "sampling_rate" in input_: + array = input_["array"] + sampling_rate = input_["sampling_rate"] + elif isinstance(input_, dict) and "dataset_args" in input_: + item = get_dataset_cached(input_["dataset_args"])[input_["idx"]] + array = item["audio"]["array"] + sampling_rate = item["audio"]["sampling_rate"] + elif isinstance(input_, dict) and "path" in input_: + with with_local_files([input_["path"]]) as local_fns: + array, sampling_rate = sf.read(local_fns[0]) + elif isinstance(input_, str): + with with_local_files([input_]) as local_fns: + array, sampling_rate = sf.read(local_fns[0]) + else: + raise ValueError(f"Could not load audio from {input_}") + + if array.ndim == 2: + array = array.mean(axis=1) + + if target_sampling_rate is not None and sampling_rate != target_sampling_rate: + array = librosa.resample( + array, orig_sr=sampling_rate, target_sr=target_sampling_rate + ) + sampling_rate = target_sampling_rate + + return {"array": list(array), "sampling_rate": sampling_rate} + + +def _download_yt_video(url: str) -> str: + from pytube import YouTube + + youtube = YouTube(url) + video = youtube.streams.first() + + fn = "".join(random.choices("abcdefghijklmnopqrstuvwxyz", k=10)) + file_path = video.download(output_path=tempfile.gettempdir(), filename=fn) + + return file_path + + +def _read_video_pyav(container, indices): + frames = [] + container.seek(0) + start_index = indices[0] + end_index = indices[-1] + for i, frame in enumerate(container.decode(video=0)): + if i > end_index: + break + if i >= start_index and i in indices: + frames.append(frame) + return np.stack([x.to_ndarray(format="rgb24") for x in frames]) + + +def _sample_frame_indices(clip_len, frame_sample_rate, seg_len): + converted_len = int(clip_len * frame_sample_rate) + end_idx = np.random.randint(converted_len, seg_len) + start_idx = end_idx - converted_len + indices = np.linspace(start_idx, 
end_idx, num=clip_len) + indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64) + return indices + + +def load_video( + input_: str, + frames: int = 8, + frame_sample_rate: int = 1, + start_time: Optional[int] = None, + end_time: Optional[int] = None, +) -> np.ndarray: + import av + + delete_file = False + + if isinstance(input_, dict) and "youtube.com" and input_.get("url", ""): + file_path = _download_yt_video(input_["url"]) + delete_file = True + # start_time = input_.get("start_time", None) + # end_time = input_.get("end_time", None) + elif isinstance(input_, str) and "youtube.com" in input_: + file_path = _download_yt_video(input_) + delete_file = True + elif isinstance(input_, str): + file_path = input_ + else: + raise ValueError(f"Could not load video from {input_}") + + if start_time is not None or end_time is not None: + start_time = start_time if start_time is not None else 0 + end_time = end_time if end_time is not None else "end" + trim_file_path = f"{file_path.rsplit('.', 1)[0]}_trim.mp4" + subprocess.run( + [ + "ffmpeg", + "-i", + file_path, + "-ss", + str(start_time), + "-to", + str(end_time), + "-c", + "copy", + trim_file_path, + ] + ) + file_path = trim_file_path + + container = av.open(file_path) + indices = _sample_frame_indices( + clip_len=frames, + frame_sample_rate=frame_sample_rate, + seg_len=container.streams.video[0].frames, + ) + video = _read_video_pyav(container, indices) + + if delete_file: + os.remove(file_path) + + return video diff --git a/src/sonicverse/multi_token/inference.py b/src/sonicverse/multi_token/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..fe00ec579140da2ecf3c453ac8de8e0d5c3c10a7 --- /dev/null +++ b/src/sonicverse/multi_token/inference.py @@ -0,0 +1,83 @@ +from typing import Type, List, Optional +import logging + +from transformers import AutoTokenizer, AutoConfig, BitsAndBytesConfig +from huggingface_hub import hf_hub_download +from peft import PeftModel +import torch +import os + +from multi_token.model_utils import fix_tokenizer, MultiTaskType +from multi_token.modalities.base_modality import Modality +from multi_token.language_models.mistral import MistralForCausalLM +from multi_token.language_models import LANGUAGE_MODEL_NAME_TO_CLASS +from multi_token.modalities import MODALITY_BUILDERS + + +def load_trained_lora_model( + model_name_or_path: str, + model_lora_path: str, + model_cls: Optional[Type] = None, + modalities: Optional[List[Modality]] = None, + load_bits: int = 16, + device_map: str = "auto", + use_multi_task: int = MultiTaskType.NO_MULTI_TASK, + tasks_config: str = None +): + load_kwargs = {"device_map": device_map} + + if load_bits == 8: + load_kwargs["load_in_8bit"] = True + elif load_bits == 4: + load_kwargs["load_in_4bit"] = True + load_kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + ) + elif load_bits == 16: + load_kwargs["torch_dtype"] = torch.float16 + else: + raise ValueError(f"Invalid load_bits: {load_bits}") + + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False) + fix_tokenizer(tokenizer) + + cfg = AutoConfig.from_pretrained(model_lora_path) + if model_cls is None: + model_cls = LANGUAGE_MODEL_NAME_TO_CLASS[cfg.model_cls] + if modalities is None: + if use_multi_task: + modalities = MODALITY_BUILDERS[cfg.modality_builder](use_multi_task = use_multi_task, tasks_config = tasks_config) + else: + modalities = 
MODALITY_BUILDERS[cfg.modality_builder]() + + logging.info(f"Loading base model from {model_name_or_path} as {load_bits} bits") + model = model_cls.from_pretrained( + model_name_or_path, low_cpu_mem_usage=True, config=cfg, **load_kwargs + ) + model.modalities = modalities + + logging.info(f"Loading projector weights for {[m.name for m in modalities]}") + if os.path.exists(os.path.join(model_lora_path, "non_lora_trainables.bin")): + non_lora_trainables = torch.load( + os.path.join(model_lora_path, "non_lora_trainables.bin"), map_location="cuda" + ) + else: + local_fn = hf_hub_download( + repo_id=model_lora_path, + filename="non_lora_trainables.bin", + repo_type="model", + ) + non_lora_trainables = torch.load(local_fn, map_location="cuda") + model.get_model().initialize_pretrained_modules(modalities, non_lora_trainables) + + logging.info(f"Loading and merging LoRA weights from {model_lora_path}") + model = PeftModel.from_pretrained(model, model_lora_path) + if load_bits == 16: + # TODO: Figure out why this fails for other bit sizes + model = model.merge_and_unload() + model.eval() + + return model, tokenizer diff --git a/src/sonicverse/multi_token/language_models/__init__.py b/src/sonicverse/multi_token/language_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..abaf0e65b65c17d18229fc7b0605b6e6b65d5644 --- /dev/null +++ b/src/sonicverse/multi_token/language_models/__init__.py @@ -0,0 +1,7 @@ +from multi_token.language_models.mistral import ( + MistralLMMForCausalLM, +) + +LANGUAGE_MODEL_CLASSES = [MistralLMMForCausalLM] + +LANGUAGE_MODEL_NAME_TO_CLASS = {cls.__name__: cls for cls in LANGUAGE_MODEL_CLASSES} diff --git a/src/sonicverse/multi_token/language_models/base_model.py b/src/sonicverse/multi_token/language_models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..1bf6f6597d7224e9aafdb9e65bdd37f80155e5cb --- /dev/null +++ b/src/sonicverse/multi_token/language_models/base_model.py @@ -0,0 +1,181 @@ +from typing import List, Dict +from abc import ABC, abstractmethod + +from torch.nn.functional import conv1d +import torch +import logging + +from multi_token.modalities.base_modality import Modality +from multi_token.model_utils import MultiTaskType + +from torchviz import make_dot + +class LMMMetaModel: + def __init__(self, config): + super(LMMMetaModel, self).__init__(config) + + def _load_projector_weights(self, weights: Dict): + weights = { + (k[23:] if k.startswith("base_model.model.model.") else k): v + for k, v in weights.items() + } + logging.info(f"Loading pretrained weights: {list(weights.keys())}") + load_result = self.load_state_dict(weights, strict=False) + assert ( + len(load_result.unexpected_keys) == 0 + ), "Unexpected weights, is this the right model?" 
+ + def initialize_pretrained_modules(self, modalities: List[Modality], weights: Dict): + for m in modalities: + # projector = m.build_projector(self.config.hidden_size) + # setattr(self, m.name + "_lmm_projector", projector) + projector = m.build_projector(self.config.hidden_size) + if m.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK: + for task_name in m.tasks["task_heads"].keys(): + task_model = projector[task_name] + setattr(self, m.name + "_" + task_name, task_model) + else: + setattr(self, m.name + "_lmm_projector", projector) + + self._load_projector_weights(weights) + + def initialize_modules(self, modalities: List[Modality], weights: Dict): + names = [m.name for m in modalities] + + self.config.modalities = names + + for m in modalities: + # projector = m.build_projector(self.config.hidden_size) + # setattr(self, m.name + "_lmm_projector", projector) + projector = m.build_projector(self.config.hidden_size) + if m.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK: + for task_name in m.tasks["task_heads"].keys(): + task_model = projector[task_name] + setattr(self, m.name + "_" + task_name, task_model) + else: + setattr(self, m.name + "_lmm_projector", projector) + + self._load_projector_weights(weights) + + +class LMMMetaForCausalLM(ABC): + @abstractmethod + def get_model(self) -> "LMMMetaForCausalLM": + pass + + def prepare_inputs_labels_for_multimodal( + self, input_ids, attention_mask, past_key_values, labels, **kwargs + ): + model = self.get_model() + + batch_size, seq_len = input_ids.shape + + # batch_size x seq_len x embedding_hidden_size + inputs_embeds = torch.zeros( + (batch_size, seq_len, self.config.hidden_size), + dtype=self.dtype, + device=self.device, + ) + + # modality x batch_size x instance_idx x modality_token_width x embedding_hidden_size + projected_tensors = [] + # assuming that if caching is enabled, we'll never have past_key_values AND need to encode the instruction modality values + task_vals = {} + + #print("here past_key_values", past_key_values) + #past_key_values == None + if past_key_values is None: + for m in self.modalities: + m_vals = m.forward(kwargs.get(m.name)) + mp_vals = [] + if m.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK: + proj = {} + for task_name in m.tasks["task_heads"].keys(): + proj[task_name] = getattr(model, m.name + "_" + task_name) + else: + proj = getattr(model, m.name + "_lmm_projector") + + # project each batch into language model token space + for m_val in m_vals: + if m.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK: + for task_name in m.tasks["task_heads"].keys(): + if task_name == "lmm_projector": + mp_vals.append(proj[task_name](m_val)) + # make_dot(mp_vals[-1], params=dict(list(model.named_parameters()))).render(task_name, format="png") + else: + if task_name not in task_vals: + task_vals[task_name] = [proj[task_name](m_val)] + else: + task_vals[task_name].append(proj[task_name](m_val)) + # make_dot(task_vals[task_name], params=dict(list(model.named_parameters()))).render(task_name, format="png") + + elif m.use_multi_task == MultiTaskType.PROJECTED_MULTI_TASK: + task_outputs = proj(m_val) + mp_vals.append(task_outputs.pop("projectors")) + for task_name in task_outputs.keys(): + if not task_name in task_vals: + task_vals[task_name] = [task_outputs[task_name]] + else: + task_vals[task_name].append(task_outputs[task_name]) + else: + mp_vals.append(proj(m_val)) + + assert all( + mp_val.shape[1:] == (m.token_width, self.config.hidden_size) + for mp_val in mp_vals + ), ( + "Modality tensors have incorrect shape, check 
your projector implementation " + + str([mp_val.shape[1:] for mp_val in mp_vals]) + + " vs expected " + + str((m.token_width, self.config.hidden_size)) + ) + projected_tensors.append(mp_vals) + + indices = None + for i, input_ids_sample in enumerate(input_ids): + is_text_mask = input_ids_sample >= 0 + + # fill in all the LLM-based text embeddings + inputs_embeds[i, is_text_mask] = model.embed_tokens( + input_ids_sample[is_text_mask] + ) + + # skip if all tokens are text tokens + if is_text_mask.sum() == seq_len: + continue + assert ( + past_key_values is None + ), "We shouldn't have cached keys if this is the first instruction pass" + + #past_key_values = None + + for mi, m in enumerate(self.modalities): + # locate the group of tokens for this modality + m_mask = (input_ids_sample == m.token_idx).float() + m_kernel = torch.tensor( + [-1] * m.token_width, dtype=m_mask.dtype, device=m_mask.device + ) + m_conv = conv1d( + m_mask.unsqueeze(0).unsqueeze(0), + m_kernel.unsqueeze(0).unsqueeze(0), + ) + + # where do we see `token_width`-tokens in a row? + indices = (m_conv[0, 0] == -m.token_width).nonzero(as_tuple=True)[0] + + # fill these embeddings with the projected modality tensor + last_covered_idx = -1 + k = 0 + for possible_token_idx in indices: + if possible_token_idx <= last_covered_idx: + # make sure we don't overwrite an instance we've already covered + # handles bug caused by back-to-back tokens + continue + batch_modality_tensor = projected_tensors[mi][i][k] + inputs_embeds[ + i, possible_token_idx : possible_token_idx + m.token_width + ] = batch_modality_tensor + last_covered_idx = possible_token_idx + m.token_width - 1 + k += 1 + + return None, attention_mask, past_key_values, inputs_embeds, labels, task_vals diff --git a/src/sonicverse/multi_token/language_models/mistral.py b/src/sonicverse/multi_token/language_models/mistral.py new file mode 100644 index 0000000000000000000000000000000000000000..9eb1f00696f4554859311bda4ffd7fb04726ab6a --- /dev/null +++ b/src/sonicverse/multi_token/language_models/mistral.py @@ -0,0 +1,235 @@ + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss + +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + MistralConfig, + MistralModel, + MistralForCausalLM, +) + +from transformers.modeling_outputs import CausalLMOutputWithPast + +from multi_token.language_models.base_model import ( + LMMMetaModel, + LMMMetaForCausalLM, +) + + +class MistralLMMConfig(MistralConfig): + model_type = "mistral-lmm" + + +class MistralLMMModel(LMMMetaModel, MistralModel): + config_class = MistralLMMConfig + + def __init__(self, config: MistralLMMConfig): + super(MistralLMMModel, self).__init__(config) + + +class MistralLMMForCausalLM(MistralForCausalLM, LMMMetaForCausalLM): + config_class = MistralLMMConfig + + def __init__(self, config): + super(MistralForCausalLM, self).__init__(config) + self.model = MistralLMMModel(config) + + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.modalities = None + + # Initialize weights and apply final processing + self.post_init() + + def get_model(self) -> "MistralLMMForCausalLM": + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: 
Optional[List] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs + ) -> Union[Tuple, CausalLMOutputWithPast]: + #print("Past keys ",past_key_values) + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if labels != None: + labels_inp = labels[0] + else: + labels_inp = labels + ( + input_ids, + attention_mask, + past_key_values, + inputs_embeds, + lmm_labels, + task_values + ) = self.prepare_inputs_labels_for_multimodal( + input_ids, attention_mask, past_key_values, labels_inp, **kwargs + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + # print("Labels 1 size ", len(labels[1])) + # print("labels 1 element size ", len(labels[1][0])) + # print("labels 1 element 1 task size ", labels[1][0][0].shape) + # print("labels 1 element 2 task size ", labels[1][0][1].shape) + # print("labels 1 element 3 task size ", labels[1][0][2].shape) + # print("task vals size ", len(task_values)) + # for task in task_values.keys(): + # print(" task ", task, len(task_values[task])) + # print(" task element", task, task_values[task][0].shape) + + + if labels != None: + task_pairs = {} + task_list = list(task_values.keys()) + for task_id in range(len(task_list)): + _task_labels = [] + _task_outputs = [] + + _task = task_list[task_id] + for inst in range(len(task_values[_task])): + # print("task output shape ", _task, task_values[_task][inst].shape) + _task_outputs.append(task_values[_task][inst].unsqueeze(0)) + _task_labels.append(torch.stack([labels[1][inst][task_id]])) + + task_pairs[_task] = [_task_labels, _task_outputs] + # print("TASK ", _task) + # print(" LABELS LEN ", len(task_pairs[_task][0])) + # print(" LABELS ELEM shape ", task_pairs[_task][0][0].shape) + # print(" VALUES LEN ", len(task_pairs[_task][1])) + # print(" VALUES ELEM shape ", task_pairs[_task][1][0].shape) + + loss = None + if lmm_labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = lmm_labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + # print("loss ", loss) + + + if labels != None: + task_loss = {} + for task in task_list: + preds = torch.cat(task_pairs[task][1], dim=0) + labs = torch.cat(task_pairs[task][0], dim=0) + preds_flat = preds.view(-1, preds.size(-1)) # Reshape to (batch_size * sequence_length, num_classes) + labs_flat = labs.view(-1) # Reshape to (batch_size * sequence_length) + + #print("task ", task) + #print("preds shape ", 
preds.shape) + #print("labs shape ", labs.shape) + if task == "lmm_projector": + task_loss[task] = CrossEntropyLoss()(preds,labs) + else: + task_loss[task] = nn.BCEWithLogitsLoss()(preds, labs) + # print("task losses ", task_loss) + + total_loss = None + if labels != None: + total_task_loss = None + for task in task_list: + if self.modalities[0].tasks["task_heads"][task]["weight"] != 0.0: + if total_task_loss != None: + total_task_loss += self.modalities[0].tasks["task_heads"][task]["weight"]*task_loss[task] + else: + total_task_loss = self.modalities[0].tasks["task_heads"][task]["weight"]*task_loss[task] + + if total_task_loss != None: + total_loss = self.modalities[0].tasks["task_heads"]["lmm_projector"]["weight"]*loss + total_task_loss + else: + total_loss = loss + + if not return_dict: + output = (logits,) + outputs[1:] + return (total_loss,) + output if total_loss is not None else output + + return CausalLMOutputWithPast( + loss=total_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + modality_inputs=None, + **kwargs + ): + #print("hoooo", past_key_values) + + #past_key_values = None + if past_key_values: + input_ids = input_ids[:, -1:] + + if inputs_embeds is not None: + raise ValueError("inputs_embeds not supported") + + model_inputs = { + "input_ids": input_ids, + "position_ids": None, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + **(modality_inputs or {}), + } + + return model_inputs + + +AutoConfig.register("mistral-lmm", MistralLMMConfig) +AutoModelForCausalLM.register(MistralLMMConfig, MistralLMMForCausalLM) diff --git a/src/sonicverse/multi_token/modalities/__init__.py b/src/sonicverse/multi_token/modalities/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e54dffcbd71b97a867051213c879cc4cb13f71b5 --- /dev/null +++ b/src/sonicverse/multi_token/modalities/__init__.py @@ -0,0 +1,31 @@ +from multi_token.model_utils import MultiTaskType +from multi_token.modalities.vision_clip import ( + CLIPVisionModality, + OUTPUT_LAYER as CLIP_POOL_LAYER, +) +from multi_token.modalities.imagebind import ImageBindModality +from multi_token.modalities.document_gte import DocumentGTEModality +from multi_token.modalities.audio_whisper import WhisperAudioModality +from multi_token.modalities.audio_clap import CLAPAudioModality +from multi_token.modalities.video_xclip import XCLIPVideoModality +from multi_token.modalities.audio_descript import DescriptAudioModality +from multi_token.modalities.audio_mert import MERTAudioModality + +MODALITY_BUILDERS = { + "vision_clip": lambda: [CLIPVisionModality()], + "vision_clip_pool": lambda: [ + CLIPVisionModality(feature_layer=CLIP_POOL_LAYER, num_tokens_output=10) + ], + "audio_whisper": lambda: [ + WhisperAudioModality( + num_tokens_output=10, model_name_or_path="openai/whisper-small" + ) + ], + "audio_mert": lambda use_multi_task=MultiTaskType.NO_MULTI_TASK, tasks_config=None :[MERTAudioModality(use_multi_task=use_multi_task, tasks_config=tasks_config, num_tokens_output=60, hidden_dim=32, num_conv_layers = 3, num_mlp_layers = 2)], + "audio_clap": lambda use_multi_task=MultiTaskType.NO_MULTI_TASK, tasks_config=None :[CLAPAudioModality(use_multi_task=use_multi_task, tasks_config=tasks_config, num_tokens_output=20)], + "audio_descript": lambda 
use_multi_task=MultiTaskType.NO_MULTI_TASK, tasks_config=None : [DescriptAudioModality(use_multi_task=use_multi_task, tasks_config=tasks_config, num_projector_conv_layers=1, num_projector_mlp_layers=1, num_tokens_output=60, codebooks=96)], + "video_xclip": lambda: [XCLIPVideoModality(num_tokens_output=10)], + "imagebind": lambda: [ImageBindModality()], + "document_gte": lambda: [DocumentGTEModality()], + "document_gte_x16": lambda: [DocumentGTEModality(num_tokens_output=32)], +} diff --git a/src/sonicverse/multi_token/modalities/audio_clap.py b/src/sonicverse/multi_token/modalities/audio_clap.py new file mode 100644 index 0000000000000000000000000000000000000000..10515c6e2039ac49a8f89f25ea6050e83dcc4824 --- /dev/null +++ b/src/sonicverse/multi_token/modalities/audio_clap.py @@ -0,0 +1,142 @@ +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +from transformers import ClapModel, ClapProcessor + +from multi_token.model_utils import MultiTaskType +from multi_token.data_tools import load_audio +from multi_token.modalities.base_modality import Modality +from multi_token.modalities.projectors import ( + build_mlp_vector_projector, build_mt_vector_projector, MultiTaskModel +) + +import json + +OUTPUT_EMB_SIZE = 512 + + +class CLAPAudioModule(nn.Module): + def __init__(self, model_name_or_path: str): + super().__init__() + self.model_name_or_path = model_name_or_path + self.model = None + self.processor = None + + self.load_model() + + def load_model(self): + self.model = ClapModel.from_pretrained(self.model_name_or_path) + self.processor = ClapProcessor.from_pretrained(self.model_name_or_path) + self.model.requires_grad_(False) + + @torch.no_grad() + def forward(self, audios) -> torch.Tensor: + embs = [] + for audio_features in audios: + features = self.model.get_audio_features( + input_features=audio_features["input_features"].to(torch.float32), + is_longer=audio_features["is_longer"], + ) + embs.append(features) + embs = torch.stack(embs) + return embs.view(-1, 1, OUTPUT_EMB_SIZE) + + @property + def dtype(self): + return self.model.dtype + + @property + def device(self): + return self.model.device + + +class CLAPAudioModality(Modality): + def __init__( + self, + model_name_or_path: str = "laion/clap-htsat-fused", + num_projector_layers: int = 2, + num_tokens_output: int = 10, + use_multi_task: int = MultiTaskType.NO_MULTI_TASK, + tasks_config: str = None + ): + self.model_name_or_path = model_name_or_path + self.module = CLAPAudioModule(model_name_or_path=self.model_name_or_path) + self.num_projector_layers = num_projector_layers + self.num_tokens_output = num_tokens_output + self.dtype = torch.float32 + self.use_multi_task = use_multi_task + self.tasks = None + if self.use_multi_task != MultiTaskType.NO_MULTI_TASK: + with open(tasks_config, 'r') as f: + self.tasks = json.load(f) + + print("Tasks :", self.tasks) + + def build_projector(self, lm_hidden_size: int) -> nn.Module: + if self.use_multi_task == MultiTaskType.PROJECTED_MULTI_TASK: + return MultiTaskModel(OUTPUT_EMB_SIZE, self.tasks) + elif self.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK: + return build_mt_vector_projector( + # return build_mlp_vector_projector( + input_hidden_size=OUTPUT_EMB_SIZE, + lm_hidden_size=lm_hidden_size, + # num_layers=self.num_projector_layers, + # num_tokens=self.num_tokens_output, + # ) + tasks = self.tasks + ) + # )["llm_projector"] + else: + return build_mlp_vector_projector( + input_hidden_size=OUTPUT_EMB_SIZE, + lm_hidden_size=lm_hidden_size, + 
num_layers=self.num_projector_layers, + num_tokens=self.num_tokens_output, + ) + + @property + def name(self) -> str: + return "audio_clap" + + @property + def token(self) -> str: + return "" + + @property + def data_key(self) -> str: + return "sounds" + + @property + def token_width(self) -> int: + return self.num_tokens_output + + def to(self, dtype: torch.dtype, device: torch.device) -> "CLAPAudioModality": + self.dtype = dtype + self.module.to(device=device) + return self + + def preprocess_rows(self, rows: List[Dict]) -> List[Optional[Dict]]: + row_values = [] + for row in rows: + audios = [] + for audio_dict in row[self.data_key]: + audio_dict = load_audio( + audio_dict, + target_sampling_rate=self.module.processor.feature_extractor.sampling_rate, + ) + audio_processed = self.module.processor( + audios=audio_dict["array"], + return_tensors="pt", + sampling_rate=audio_dict["sampling_rate"], + ) + audios.append(audio_processed) + row_values.append(audios) + return row_values + + @torch.no_grad() + def forward(self, encoded_values: List[torch.Tensor]) -> List[torch.Tensor]: + audio_features = [] + for audio_batch in encoded_values: + audio_features.append(self.module.forward(audio_batch).to(dtype=self.dtype)) + return audio_features diff --git a/src/sonicverse/multi_token/modalities/audio_descript.py b/src/sonicverse/multi_token/modalities/audio_descript.py new file mode 100644 index 0000000000000000000000000000000000000000..6ffe2734e250c11f1423e01c2c82cc8ee2d599be --- /dev/null +++ b/src/sonicverse/multi_token/modalities/audio_descript.py @@ -0,0 +1,169 @@ +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +import dac +from audiotools import AudioSignal + +from multi_token.model_utils import MultiTaskType +from multi_token.data_tools import load_audio_signal +from multi_token.modalities.base_modality import Modality +from multi_token.modalities.projectors import ( + build_mlp_vector_projector, build_attentive_cnn_projector, build_cnn_mlp_projector, MultiTaskModel +) + +import json + +OUTPUT_FRAMES_SIZE = 512 +# OUTPUT_EMB_SIZE = 2048 +OUTPUT_EMB_CHANNELS = 96 + +class DescriptAudioModule(nn.Module): + def __init__(self, model_name_or_path: str, codebooks = 4): + super().__init__() + self.model_name_or_path = model_name_or_path + self.model = None + self.processor = None + self.codebooks = codebooks + + self.load_model() + + def load_model(self): + # self.model = ClapModel.from_pretrained(self.model_name_or_path) + self.model = dac.DAC.load(self.model_name_or_path) + + def forward(self, audios) -> torch.Tensor: + embs = [] + for audio_features in audios: + # print("Audio features sample rate ", audio_features[0].sample_rate) + x = self.model.preprocess(audio_features[0].audio_data, audio_features[0].sample_rate) + z, codes, latents, _, _ = self.model.encode(x) + + # print("latents og shape ", latents.shape) + # If the tensor is larger than desired_shape, crop it + if latents.shape[2] > OUTPUT_FRAMES_SIZE: + latents = latents[:, :, :OUTPUT_FRAMES_SIZE] + # If the tensor is smaller than desired_shape, pad it + elif latents.shape[2] < OUTPUT_FRAMES_SIZE: + pad_width = (0, OUTPUT_FRAMES_SIZE - latents.shape[2]) + latents = torch.nn.functional.pad(latents, pad_width) + # print("Codes new shape ", codes_new.shape) + + # print("latents int shape ", latents.shape) + + latents = latents[0][:self.codebooks] + + # print("latents final shape ", latents.shape) + + embs.append(latents) + + embs = torch.stack(embs) + + # output_embs = embs.view(-1, 1, 
OUTPUT_FRAMES_SIZE*self.codebooks) + # print("embs post view shape ", output_embs.shape) + + return embs + + @property + def dtype(self): + return self.model.dtype + + @property + def device(self): + return self.model.device + + +class DescriptAudioModality(Modality): + def __init__( + self, + model_name_or_path: str = dac.utils.download(model_type="16khz"), + num_projector_conv_layers: int = 2, + num_projector_mlp_layers: int = 2, + num_tokens_output: int = 10, + codebooks: int = 96, + use_multi_task: MultiTaskType = MultiTaskType.NO_MULTI_TASK, + tasks_config: str = None + ): + self.model_name_or_path = model_name_or_path + self.module = DescriptAudioModule(model_name_or_path=self.model_name_or_path, codebooks=codebooks) + self.num_projector_conv_layers = num_projector_conv_layers + self.num_projector_mlp_layers = num_projector_mlp_layers + self.num_tokens_output = num_tokens_output + self.dtype = torch.float32 + self.codebooks = codebooks + self.use_multi_task = use_multi_task + self.tasks = None + if self.use_multi_task != MultiTaskType.NO_MULTI_TASK: + with open(tasks_config, 'r') as f: + self.tasks = json.load(f) + + print("Tasks :", self.tasks) + + + def build_projector(self, lm_hidden_size: int) -> nn.Module: + if self.use_multi_task == MultiTaskType.PROJECTED_MULTI_TASK: + projector = MultiTaskModel(OUTPUT_EMB_CHANNELS, 1, True, -1, False, self.tasks) + print("projector ", projector) + return projector + elif self.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK: + return build_mt_vector_projector( + # return build_mlp_vector_projector( + input_hidden_size=OUTPUT_EMB_CHANNELS, + lm_hidden_size=lm_hidden_size, + # num_layers=self.num_projector_layers, + # num_tokens=self.num_tokens_output, + # ) + tasks = self.tasks + ) + # )["llm_projector"] + else: + return build_multi_layer_cnn_mlp_projector( + input_channels = OUTPUT_EMB_CHANNELS, + input_size = OUTPUT_EMB_SIZE, + num_feature_layers= OUTPUT_FEATURE_LAYERS, + lm_hidden_size = lm_hidden_size, + num_tokens = self.num_tokens_output, + hidden_dim = self.hidden_dim, + num_conv_layers = self.num_conv_layers, + num_mlp_layers = self.num_mlp_layers + ) + + @property + def name(self) -> str: + return "audio_descript" + + @property + def token(self) -> str: + return "" + + @property + def data_key(self) -> str: + return "sounds" + + @property + def token_width(self) -> int: + return self.num_tokens_output + + def to(self, dtype: torch.dtype, device: torch.device) -> "DescriptAudioModality": + self.dtype = dtype + self.module.to(device=device) + return self + + def preprocess_rows(self, rows: List[Dict]) -> List[Optional[Dict]]: + row_values = [] + for row in rows: + audios = [] + for audio_dict in row[self.data_key]: + audio_dict = load_audio_signal( + audio_dict + ) + audios.append(audio_dict["array"]) + row_values.append(audios) + return row_values + + @torch.no_grad() + def forward(self, encoded_values: List[torch.Tensor]) -> List[torch.Tensor]: + audio_features = [] + for audio_batch in encoded_values: + audio_features.append(self.module.forward(audio_batch).to(dtype=self.dtype)) + return audio_features diff --git a/src/sonicverse/multi_token/modalities/audio_descript_bu.py b/src/sonicverse/multi_token/modalities/audio_descript_bu.py new file mode 100644 index 0000000000000000000000000000000000000000..782e189f2b0721cee7d2b88c3e7a956e658135cf --- /dev/null +++ b/src/sonicverse/multi_token/modalities/audio_descript_bu.py @@ -0,0 +1,133 @@ +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +import dac +from 
audiotools import AudioSignal + + +from multi_token.data_tools import load_audio_signal +from multi_token.modalities.base_modality import Modality +from multi_token.modalities.projectors import ( + build_mlp_vector_projector, build_attentive_cnn_projector, build_cnn_mlp_projector +) + +OUTPUT_FRAMES_SIZE = 512 +# OUTPUT_EMB_SIZE = 2048 + +class DescriptAudioModule(nn.Module): + def __init__(self, model_name_or_path: str, codebooks = 4): + super().__init__() + self.model_name_or_path = model_name_or_path + self.model = None + self.processor = None + self.codebooks = codebooks + + self.load_model() + + def load_model(self): + # self.model = ClapModel.from_pretrained(self.model_name_or_path) + self.model = dac.DAC.load(self.model_name_or_path) + + def forward(self, audios) -> torch.Tensor: + embs = [] + for audio_features in audios: + x = self.model.preprocess(audio_features[0].audio_data, audio_features[0].sample_rate) + z, codes, latents, _, _ = self.model.encode(x) + + # If the tensor is larger than desired_shape, crop it + if codes.shape[2] > OUTPUT_FRAMES_SIZE: + codes = codes[:, :, :OUTPUT_FRAMES_SIZE] + # If the tensor is smaller than desired_shape, pad it + elif codes.shape[2] < OUTPUT_FRAMES_SIZE: + pad_width = (0, OUTPUT_FRAMES_SIZE - codes.shape[2]) + codes = torch.nn.functional.pad(codes, pad_width) + # print("Codes new shape ", codes_new.shape) + + codes_of_interest = codes[0][:self.codebooks] + + embs.append(codes_of_interest) + + embs = torch.stack(embs) + + # output_embs = embs.view(-1, 1, OUTPUT_FRAMES_SIZE*self.codebooks) + # print("embs post view shape ", output_embs.shape) + + return embs + + @property + def dtype(self): + return self.model.dtype + + @property + def device(self): + return self.model.device + + +class DescriptAudioModality(Modality): + def __init__( + self, + model_name_or_path: str = dac.utils.download(model_type="16khz"), + num_projector_conv_layers: int = 2, + num_projector_mlp_layers: int = 2, + num_tokens_output: int = 10, + codebooks: int = 4 + ): + self.model_name_or_path = model_name_or_path + self.module = DescriptAudioModule(model_name_or_path=self.model_name_or_path, codebooks=codebooks) + self.num_projector_conv_layers = num_projector_conv_layers + self.num_projector_mlp_layers = num_projector_mlp_layers + self.num_tokens_output = num_tokens_output + self.dtype = torch.float32 + self.codebooks = codebooks + + def build_projector(self, lm_hidden_size: int) -> nn.Module: + return build_cnn_mlp_projector( + input_channels=self.codebooks, + input_size=OUTPUT_FRAMES_SIZE, + lm_hidden_size=lm_hidden_size, + num_tokens=self.num_tokens_output, + hidden_dim=64, + num_conv_layers=self.num_projector_conv_layers, + num_mlp_layers=self.num_projector_mlp_layers + ) + + @property + def name(self) -> str: + return "audio_descript" + + @property + def token(self) -> str: + return "" + + @property + def data_key(self) -> str: + return "sounds" + + @property + def token_width(self) -> int: + return self.num_tokens_output + + def to(self, dtype: torch.dtype, device: torch.device) -> "DescriptAudioModality": + self.dtype = dtype + self.module.to(device=device) + return self + + def preprocess_rows(self, rows: List[Dict]) -> List[Optional[Dict]]: + row_values = [] + for row in rows: + audios = [] + for audio_dict in row[self.data_key]: + audio_dict = load_audio_signal( + audio_dict + ) + audios.append(audio_dict["array"]) + row_values.append(audios) + return row_values + + @torch.no_grad() + def forward(self, encoded_values: List[torch.Tensor]) -> List[torch.Tensor]: 
+ audio_features = [] + for audio_batch in encoded_values: + audio_features.append(self.module.forward(audio_batch).to(dtype=self.dtype)) + return audio_features diff --git a/src/sonicverse/multi_token/modalities/audio_mert.py b/src/sonicverse/multi_token/modalities/audio_mert.py new file mode 100644 index 0000000000000000000000000000000000000000..b9368f0beefeb580060962ca9ed66cc29147b8ad --- /dev/null +++ b/src/sonicverse/multi_token/modalities/audio_mert.py @@ -0,0 +1,162 @@ +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +from transformers import Wav2Vec2FeatureExtractor, AutoModel + +from multi_token.model_utils import MultiTaskType +from multi_token.data_tools import load_audio +from multi_token.modalities.base_modality import Modality +from multi_token.modalities.projectors import ( + build_mlp_vector_projector, build_mt_vector_projector, build_multi_layer_cnn_mlp_projector, MultiTaskModel +) +from multi_token.modalities.multi_task_projector_shared import MultiTaskSharedModel + +import json + +OUTPUT_EMB_CHANNELS = 768 #1024 +OUTPUT_EMB_SIZE = 760 +OUTPUT_FEATURE_LAYERS = 13 #25 + +cache_dir="/home/ubuntu/.cache/" + +class MERTAudioModule(nn.Module): + def __init__(self, model_name_or_path: str): + super().__init__() + self.model_name_or_path = model_name_or_path + self.model = None + self.processor = None + + self.load_model() + + def load_model(self): + self.model = AutoModel.from_pretrained(self.model_name_or_path, trust_remote_code=True, cache_dir=cache_dir) + self.processor = Wav2Vec2FeatureExtractor.from_pretrained(self.model_name_or_path,trust_remote_code=True, cache_dir=cache_dir) + self.model.requires_grad_(False) + + @torch.no_grad() + def forward(self, audios) -> torch.Tensor: + embs = [] + for audio_features in audios: + outputs = self.model(**audio_features.to(torch.float32), output_hidden_states=True) + features = torch.stack(outputs.hidden_states).squeeze() + embs.append(features) + embs = torch.stack(embs) + embs = embs.squeeze() + padding_needed = OUTPUT_EMB_SIZE - embs.shape[1] + embs = torch.nn.functional.pad(embs, (0, 0, 0, padding_needed, 0, 0)) + return embs + + @property + def dtype(self): + return self.model.dtype + + @property + def device(self): + return self.model.device + + +class MERTAudioModality(Modality): + def __init__( + self, + model_name_or_path: str = "m-a-p/MERT-v1-95M", + num_tokens_output: int = 10, + hidden_dim: int = 32, + num_conv_layers: int = 5, + num_mlp_layers: int = 5, + use_multi_task: MultiTaskType = MultiTaskType.NO_MULTI_TASK, + tasks_config: str = None + ): + self.model_name_or_path = model_name_or_path + self.module = MERTAudioModule(model_name_or_path=self.model_name_or_path) + self.num_tokens_output = num_tokens_output + self.hidden_dim = hidden_dim + self.num_conv_layers = num_conv_layers + self.num_mlp_layers = num_mlp_layers + self.dtype = torch.float32 + self.use_multi_task = use_multi_task + self.tasks = None + if self.use_multi_task != MultiTaskType.NO_MULTI_TASK: + with open(tasks_config, 'r') as f: + self.tasks = json.load(f) + + print("Tasks :", self.tasks) + + # all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze() + # print(all_layer_hidden_states.shape) # [25 layer, Time steps, 1024 feature_dim] + # time_reduced_hidden_states = all_layer_hidden_states.mean(-2) + # print(time_reduced_hidden_states.shape) # [25, 1024] + + def build_projector(self, lm_hidden_size: int) -> nn.Module: + if self.use_multi_task == MultiTaskType.PROJECTED_MULTI_TASK: + projector = 
MultiTaskSharedModel(self.tasks) + print("projector ", projector) + return projector + elif self.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK: + return build_mt_vector_projector( + # return build_mlp_vector_projector( + input_hidden_size=OUTPUT_EMB_SIZE, + lm_hidden_size=lm_hidden_size, + # num_layers=self.num_projector_layers, + # num_tokens=self.num_tokens_output, + # ) + tasks = self.tasks + ) + # )["llm_projector"] + else: + return build_multi_layer_cnn_mlp_projector( + input_channels = OUTPUT_EMB_CHANNELS, + input_size = OUTPUT_EMB_SIZE, + num_feature_layers= OUTPUT_FEATURE_LAYERS, + lm_hidden_size = lm_hidden_size, + num_tokens = self.num_tokens_output, + hidden_dim = self.hidden_dim, + num_conv_layers = self.num_conv_layers, + num_mlp_layers = self.num_mlp_layers + ) + + @property + def name(self) -> str: + return "audio_mert" + + @property + def token(self) -> str: + return "" + + @property + def data_key(self) -> str: + return "sounds" + + @property + def token_width(self) -> int: + return self.num_tokens_output + + def to(self, dtype: torch.dtype, device: torch.device) -> "MERTAudioModality": + self.dtype = dtype + self.module.to(device=device) + return self + + def preprocess_rows(self, rows: List[Dict]) -> List[Optional[Dict]]: + row_values = [] + for row in rows: + audios = [] + for audio_dict in row[self.data_key]: + audio_dict = load_audio( + audio_dict, + target_sampling_rate=self.module.processor.sampling_rate, + ) + audio_processed = self.module.processor( + audio_dict["array"], + return_tensors="pt", + sampling_rate=audio_dict["sampling_rate"], + ) + audios.append(audio_processed) + row_values.append(audios) + return row_values + + @torch.no_grad() + def forward(self, encoded_values: List[torch.Tensor]) -> List[torch.Tensor]: + audio_features = [] + for audio_batch in encoded_values: + audio_features.append(self.module.forward(audio_batch).to(dtype=self.dtype)) + return audio_features diff --git a/src/sonicverse/multi_token/modalities/audio_mert_bu.py b/src/sonicverse/multi_token/modalities/audio_mert_bu.py new file mode 100644 index 0000000000000000000000000000000000000000..f28cae6f78a2db82418818ec153a026527e8ceb8 --- /dev/null +++ b/src/sonicverse/multi_token/modalities/audio_mert_bu.py @@ -0,0 +1,159 @@ +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +from transformers import Wav2Vec2FeatureExtractor, AutoModel + +from multi_token.model_utils import MultiTaskType +from multi_token.data_tools import load_audio +from multi_token.modalities.base_modality import Modality +from multi_token.modalities.projectors import ( + build_mlp_vector_projector, build_mt_vector_projector, build_multi_layer_cnn_mlp_projector, MultiTaskModel +) + +import json + +OUTPUT_EMB_CHANNELS = 1024 +OUTPUT_EMB_SIZE = 760 +OUTPUT_FEATURE_LAYERS = 25 + +class MERTAudioModule(nn.Module): + def __init__(self, model_name_or_path: str): + super().__init__() + self.model_name_or_path = model_name_or_path + self.model = None + self.processor = None + + self.load_model() + + def load_model(self): + self.model = AutoModel.from_pretrained(self.model_name_or_path, trust_remote_code=True) + self.processor = Wav2Vec2FeatureExtractor.from_pretrained(self.model_name_or_path,trust_remote_code=True) + self.model.requires_grad_(False) + + @torch.no_grad() + def forward(self, audios) -> torch.Tensor: + embs = [] + for audio_features in audios: + outputs = self.model(**audio_features.to(torch.float32), output_hidden_states=True) + features = 
torch.stack(outputs.hidden_states).squeeze() + embs.append(features) + embs = torch.stack(embs) + embs = embs.squeeze() + padding_needed = OUTPUT_EMB_SIZE - embs.shape[1] + embs = torch.nn.functional.pad(embs, (0, 0, 0, padding_needed, 0, 0)) + return embs + + @property + def dtype(self): + return self.model.dtype + + @property + def device(self): + return self.model.device + + +class MERTAudioModality(Modality): + def __init__( + self, + model_name_or_path: str = "m-a-p/MERT-v1-330M", + num_tokens_output: int = 10, + hidden_dim: int = 32, + num_conv_layers: int = 5, + num_mlp_layers: int = 5, + use_multi_task: MultiTaskType = MultiTaskType.NO_MULTI_TASK, + tasks_config: str = None + ): + self.model_name_or_path = model_name_or_path + self.module = MERTAudioModule(model_name_or_path=self.model_name_or_path) + self.num_tokens_output = num_tokens_output + self.hidden_dim = hidden_dim + self.num_conv_layers = num_conv_layers + self.num_mlp_layers = num_mlp_layers + self.dtype = torch.float32 + self.use_multi_task = use_multi_task + self.tasks = None + if self.use_multi_task != MultiTaskType.NO_MULTI_TASK: + with open(tasks_config, 'r') as f: + self.tasks = json.load(f) + + print("Tasks :", self.tasks) + + # all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze() + # print(all_layer_hidden_states.shape) # [25 layer, Time steps, 1024 feature_dim] + # time_reduced_hidden_states = all_layer_hidden_states.mean(-2) + # print(time_reduced_hidden_states.shape) # [25, 1024] + + def build_projector(self, lm_hidden_size: int) -> nn.Module: + if self.use_multi_task == MultiTaskType.PROJECTED_MULTI_TASK: + projector = MultiTaskModel(OUTPUT_EMB_CHANNELS, OUTPUT_FEATURE_LAYERS, True, self.tasks) + print("projector ", projector) + return projector + elif self.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK: + return build_mt_vector_projector( + # return build_mlp_vector_projector( + input_hidden_size=OUTPUT_EMB_SIZE, + lm_hidden_size=lm_hidden_size, + # num_layers=self.num_projector_layers, + # num_tokens=self.num_tokens_output, + # ) + tasks = self.tasks + ) + # )["llm_projector"] + else: + return build_multi_layer_cnn_mlp_projector( + input_channels = OUTPUT_EMB_CHANNELS, + input_size = OUTPUT_EMB_SIZE, + num_feature_layers= OUTPUT_FEATURE_LAYERS, + lm_hidden_size = lm_hidden_size, + num_tokens = self.num_tokens_output, + hidden_dim = self.hidden_dim, + num_conv_layers = self.num_conv_layers, + num_mlp_layers = self.num_mlp_layers + ) + + @property + def name(self) -> str: + return "audio_mert" + + @property + def token(self) -> str: + return "" + + @property + def data_key(self) -> str: + return "sounds" + + @property + def token_width(self) -> int: + return self.num_tokens_output + + def to(self, dtype: torch.dtype, device: torch.device) -> "MERTAudioModality": + self.dtype = dtype + self.module.to(device=device) + return self + + def preprocess_rows(self, rows: List[Dict]) -> List[Optional[Dict]]: + row_values = [] + for row in rows: + audios = [] + for audio_dict in row[self.data_key]: + audio_dict = load_audio( + audio_dict, + target_sampling_rate=self.module.processor.sampling_rate, + ) + audio_processed = self.module.processor( + audio_dict["array"], + return_tensors="pt", + sampling_rate=audio_dict["sampling_rate"], + ) + audios.append(audio_processed) + row_values.append(audios) + return row_values + + @torch.no_grad() + def forward(self, encoded_values: List[torch.Tensor]) -> List[torch.Tensor]: + audio_features = [] + for audio_batch in encoded_values: + 
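+ # The frozen MERT encoder returns hidden states for every layer of the batch;
+ # MERTAudioModule.forward stacks them and zero-pads the time axis up to
+ # OUTPUT_EMB_SIZE before the result is cast to the projector dtype below.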
audio_features.append(self.module.forward(audio_batch).to(dtype=self.dtype)) + return audio_features diff --git a/src/sonicverse/multi_token/modalities/audio_whisper.py b/src/sonicverse/multi_token/modalities/audio_whisper.py new file mode 100644 index 0000000000000000000000000000000000000000..6bdeaa1ea60b69f79065ba328e62524a2f4777d4 --- /dev/null +++ b/src/sonicverse/multi_token/modalities/audio_whisper.py @@ -0,0 +1,120 @@ +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +from transformers import AutoFeatureExtractor, WhisperModel + +from multi_token.data_tools import load_audio +from multi_token.modalities.base_modality import Modality +from multi_token.modalities.projectors import ( + build_mlp_vector_projector, +) + + +OUTPUT_EMB_SIZE = 768 + + +class WhisperAudioModule(nn.Module): + def __init__(self, model_name_or_path: str): + super().__init__() + self.model_name_or_path = model_name_or_path + self.model = None + self.feature_extractor = None + + self.load_model() + + def load_model(self): + self.model = WhisperModel.from_pretrained(self.model_name_or_path) + self.feature_extractor = AutoFeatureExtractor.from_pretrained( + self.model_name_or_path + ) + self.model.requires_grad_(False) + + @torch.no_grad() + def forward(self, audios) -> torch.Tensor: + hidden_states = [] + for i in range(audios.shape[0]): + decoder_input_ids = ( + torch.tensor([[1]]) * self.model.config.decoder_start_token_id + ) + last_hidden_state = self.model( + audios[i].to(device=self.device, dtype=self.dtype), + decoder_input_ids=decoder_input_ids.to(device=self.device), + ).last_hidden_state + hidden_states.append(last_hidden_state) + last_hidden_state = torch.stack(hidden_states) + return last_hidden_state.view(-1, 1, OUTPUT_EMB_SIZE) + + @property + def dtype(self): + return self.model.dtype + + @property + def device(self): + return self.model.device + + +class WhisperAudioModality(Modality): + def __init__( + self, + model_name_or_path: str = "openai/whisper-small", + num_projector_layers: int = 2, + num_tokens_output: int = 10, + ): + self.model_name_or_path = model_name_or_path + self.module = WhisperAudioModule(model_name_or_path=self.model_name_or_path) + self.num_projector_layers = num_projector_layers + self.num_tokens_output = num_tokens_output + + def build_projector(self, lm_hidden_size: int) -> nn.Module: + return build_mlp_vector_projector( + input_hidden_size=OUTPUT_EMB_SIZE, + lm_hidden_size=lm_hidden_size, + num_layers=self.num_projector_layers, + num_tokens=self.num_tokens_output, + ) + + @property + def name(self) -> str: + return "audio_whisper" + + @property + def token(self) -> str: + return "" + + @property + def data_key(self) -> str: + return "speech_audios" + + @property + def token_width(self) -> int: + return self.num_tokens_output + + def to(self, dtype: torch.dtype, device: torch.device) -> "WhisperAudioModality": + self.module.to(dtype=dtype, device=device) + return self + + def preprocess_rows(self, rows: List[Dict]) -> List[Optional[torch.Tensor]]: + row_values = [] + for row in rows: + audios = [] + for audio_dict in row[self.data_key]: + audio_dict = load_audio( + audio_dict, + target_sampling_rate=self.module.feature_extractor.sampling_rate, + ) + audio_processed = self.module.feature_extractor( + audio_dict["array"], + return_tensors="pt", + sampling_rate=audio_dict["sampling_rate"], + ).input_features + audios.append(audio_processed) + row_values.append(torch.stack(audios) if len(audios) > 0 else None) + return row_values + + 
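+ # A minimal usage sketch, assuming `audio_dict` is a mapping that
+ # multi_token.data_tools.load_audio can decode:
+ #   modality = WhisperAudioModality(num_tokens_output=10)
+ #   batches = modality.preprocess_rows([{"speech_audios": [audio_dict]}])
+ #   speech_embs = modality.forward(batches)  # one (num_audios, 1, 768) tensor per row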
@torch.no_grad() + def forward(self, encoded_values: List[torch.Tensor]) -> List[torch.Tensor]: + audio_features = [] + for audio_batch in encoded_values: + audio_features.append(self.module.forward(audio_batch)) + return audio_features diff --git a/src/sonicverse/multi_token/modalities/base_modality.py b/src/sonicverse/multi_token/modalities/base_modality.py new file mode 100644 index 0000000000000000000000000000000000000000..3a37d7a6ee1adc5de2e728e952eea991faf9eb13 --- /dev/null +++ b/src/sonicverse/multi_token/modalities/base_modality.py @@ -0,0 +1,48 @@ +from typing import Dict, List, Optional, Any +from abc import ABC, abstractmethod +from functools import cached_property + +import torch.nn as nn +import torch + + +class Modality(ABC): + @abstractmethod + def build_projector(self, lm_hidden_size: int) -> nn.Module: + pass + + @property + @abstractmethod + def name(self) -> str: + pass + + @property + @abstractmethod + def token(self) -> str: + pass + + @property + @abstractmethod + def data_key(self) -> str: + pass + + @property + @abstractmethod + def token_width(self) -> int: + pass + + @cached_property + def token_idx(self) -> int: + hash_ = sum(ord(c) ** i for i, c in enumerate(self.token)) + return -abs(hash_ % 10_000) + + @abstractmethod + def preprocess_rows(self, rows: List[Dict]) -> List[Optional[Any]]: + pass + + @abstractmethod + def forward(self, encoded_values: List[Any]) -> List[torch.Tensor]: + pass + + def to(self, dtype: torch.dtype, device: torch.device) -> "Modality": + return self diff --git a/src/sonicverse/multi_token/modalities/bu__init__.py b/src/sonicverse/multi_token/modalities/bu__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6b90ace97caedbd25ef199311481a500d4e1e36 --- /dev/null +++ b/src/sonicverse/multi_token/modalities/bu__init__.py @@ -0,0 +1,31 @@ +from multi_token.model_utils import MultiTaskType +from multi_token.modalities.vision_clip import ( + CLIPVisionModality, + OUTPUT_LAYER as CLIP_POOL_LAYER, +) +from multi_token.modalities.imagebind import ImageBindModality +from multi_token.modalities.document_gte import DocumentGTEModality +from multi_token.modalities.audio_whisper import WhisperAudioModality +from multi_token.modalities.audio_clap import CLAPAudioModality +from multi_token.modalities.video_xclip import XCLIPVideoModality +from multi_token.modalities.audio_descript import DescriptAudioModality +from multi_token.modalities.audio_mert import MERTAudioModality + +MODALITY_BUILDERS = { + "vision_clip": lambda: [CLIPVisionModality()], + "vision_clip_pool": lambda: [ + CLIPVisionModality(feature_layer=CLIP_POOL_LAYER, num_tokens_output=10) + ], + "audio_whisper": lambda: [ + WhisperAudioModality( + num_tokens_output=10, model_name_or_path="openai/whisper-small" + ) + ], + "audio_mert": lambda use_multi_task=MultiTaskType.NO_MULTI_TASK, tasks_config=None :[MERTAudioModality(use_multi_task=use_multi_task, tasks_config=tasks_config, num_tokens_output=60, hidden_dim=32, num_conv_layers = 3, num_mlp_layers = 2)], + "audio_clap": lambda use_multi_task=MultiTaskType.NO_MULTI_TASK, tasks_config=None :[CLAPAudioModality(use_multi_task=use_multi_task, tasks_config=tasks_config, num_tokens_output=20)], + "audio_descript": lambda: [DescriptAudioModality(num_projector_conv_layers=1, num_projector_mlp_layers=1, num_tokens_output=5, codebooks=12)], + "video_xclip": lambda: [XCLIPVideoModality(num_tokens_output=10)], + "imagebind": lambda: [ImageBindModality()], + "document_gte": lambda: [DocumentGTEModality()], + 
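+ # The audio_mert and audio_clap builders forward the use_multi_task and
+ # tasks_config arguments to their modalities; the remaining builders are
+ # zero-argument factories.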
"document_gte_x16": lambda: [DocumentGTEModality(num_tokens_output=32)], +} diff --git a/src/sonicverse/multi_token/modalities/document_gte.py b/src/sonicverse/multi_token/modalities/document_gte.py new file mode 100644 index 0000000000000000000000000000000000000000..d4e3ec195353bfc1e65c55d936cee56b22207955 --- /dev/null +++ b/src/sonicverse/multi_token/modalities/document_gte.py @@ -0,0 +1,144 @@ +from typing import Dict, List + +import torch +import torch.nn as nn +import os +from functools import cache +from transformers import AutoTokenizer, AutoModel + +from multi_token.modalities.base_modality import Modality +from multi_token.modalities.projectors import build_mlp_vector_projector + +GTE_EMBEDDING_SIZE = 1024 +GTE_CONTEXT_WINDOW = 512 +GTE_DEFAULT_MODEL = "thenlper/gte-large" +DOCUMENT_GTE_FORCE_CPU = "DOCUMENT_GTE_FORCE_CPU" + + +def average_pool( + last_hidden_states: torch.Tensor, attention_mask: torch.Tensor +) -> torch.Tensor: + last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) + return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] + + +@cache +def _get_tokenizer(model_name_or_path: str = GTE_DEFAULT_MODEL): + return AutoTokenizer.from_pretrained(model_name_or_path) + + +def split_text_into_documents(text: str) -> List[str]: + from nltk.tokenize import sent_tokenize + + tokenizer = _get_tokenizer(GTE_DEFAULT_MODEL) + + sentences = sent_tokenize(text) + documents = [[]] + + for sentence in sentences: + sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False) + if len(documents[-1]) + len(sentence_tokens) > GTE_CONTEXT_WINDOW: + documents.append([]) + documents[-1].extend(sentence_tokens) + + return [tokenizer.decode(doc) for doc in documents] + + +class DocumentGTEModule(nn.Module): + def __init__(self, model_name_or_path: str): + super().__init__() + self.feature_layer = -2 + self.model_name_or_path = model_name_or_path + + self.model = AutoModel.from_pretrained("thenlper/gte-large") + self.model.requires_grad_(False) + + @torch.no_grad() + def forward(self, batch_dict) -> torch.Tensor: + outputs = self.model(**batch_dict) + embeddings = average_pool( + outputs.last_hidden_state, batch_dict["attention_mask"] + ) + return embeddings + + @property + def embedding_size(self): + return GTE_EMBEDDING_SIZE + + +class DocumentGTEModality(Modality): + def __init__( + self, + model_name_or_path: str = GTE_DEFAULT_MODEL, + num_projector_layers: int = 2, + num_tokens_output: int = 4, + ): + self.model_name_or_path = model_name_or_path + self.module = DocumentGTEModule(model_name_or_path=self.model_name_or_path) + self.tokenizer = _get_tokenizer(model_name_or_path) + self.num_projector_layers = num_projector_layers + self.num_tokens_output = num_tokens_output + self.dtype = torch.float32 + self.device = "cpu" + self.document_gte_device = "cpu" + + def build_projector(self, lm_hidden_size: int) -> nn.Module: + return build_mlp_vector_projector( + input_hidden_size=self.module.embedding_size, + lm_hidden_size=lm_hidden_size, + num_layers=self.num_projector_layers, + num_tokens=self.num_tokens_output, + ) + + @property + def name(self) -> str: + return "document_gte" + + @property + def token(self) -> str: + return "" + + @property + def data_key(self) -> str: + return "documents" + + @property + def token_width(self) -> int: + return self.num_tokens_output + + def to(self, dtype: torch.dtype, device: torch.device) -> "DocumentGTEModality": + self.dtype = dtype + self.device = device + if DOCUMENT_GTE_FORCE_CPU not in 
os.environ: + # running out of VRAM on 24GB GPU + self.document_gte_device = device + self.module.to(device=self.document_gte_device) + return self + + def preprocess_rows(self, rows: List[Dict]) -> List[Dict]: + row_values = [] + for row in rows: + documents = [] + for doc in row[self.data_key]: + documents.append(doc) + documents_tokenized = self.tokenizer( + documents, + max_length=GTE_CONTEXT_WINDOW, + padding=True, + truncation=True, + return_tensors="pt", + ) + row_values.append(documents_tokenized) + return row_values + + @torch.no_grad() + def forward(self, encoded_values: List[Dict]) -> List[torch.Tensor]: + outputs = [] + for val in encoded_values: + outputs.append( + self.module.forward(val.to(device=self.document_gte_device)) + .to(device=self.device, dtype=self.dtype) + .view(-1, 1, self.module.embedding_size) + ) + # batch_size x num_items x 1 x embedding_size + return outputs diff --git a/src/sonicverse/multi_token/modalities/imagebind.py b/src/sonicverse/multi_token/modalities/imagebind.py new file mode 100644 index 0000000000000000000000000000000000000000..87c41e75f85ea1c4366d08b0461f08c53cfe893c --- /dev/null +++ b/src/sonicverse/multi_token/modalities/imagebind.py @@ -0,0 +1,153 @@ +from typing import Dict, List +import os + +import torch +import torch.nn as nn + +from multi_token.modalities.base_modality import Modality +from multi_token.modalities.projectors import build_mlp_vector_projector +from multi_token.data_tools import with_local_files + +IMAGE_BIND_FORCE_CPU = "IMAGE_BIND_FORCE_CPU" +IMAGE_BIND_EMBEDDING_SIZE = 1024 + + +class ImageBindModule(nn.Module): + def __init__(self): + super().__init__() + from imagebind.models import imagebind_model + from imagebind import data + + data.BPE_PATH = os.path.join( + os.path.dirname(data.__file__), "..", "bpe", "bpe_simple_vocab_16e6.txt.gz" + ) + self.model = imagebind_model.imagebind_huge(pretrained=True) + self.model.eval() + self.model.requires_grad_(False) + + @torch.no_grad() + def forward(self, items: Dict) -> torch.Tensor: + forward_outs = self.model(items) + return forward_outs + + @property + def embedding_size(self): + return IMAGE_BIND_EMBEDDING_SIZE + + +class ImageBindModality(Modality): + def __init__( + self, + num_projector_layers: int = 2, + num_tokens: int = 4, + preprocess_device: str = "cpu", + ): + self.module = ImageBindModule() + self.dtype = torch.float32 + self.device = "cpu" # used for outputs + self.imagebind_device = "cpu" # used for imagebind model itself + self.preprocess_device = preprocess_device # used for preprocessing + self.num_projector_layers = num_projector_layers + self.num_tokens = num_tokens + + def build_projector(self, lm_hidden_size: int) -> nn.Module: + return build_mlp_vector_projector( + self.module.embedding_size, + lm_hidden_size, + num_layers=self.num_projector_layers, + num_tokens=self.num_tokens, + ) + + @property + def name(self) -> str: + return "imagebind" + + @property + def token(self) -> str: + return "" + + @property + def data_key(self) -> str: + return "imagebinds" + + @property + def token_width(self) -> int: + return self.num_tokens + + def to(self, dtype: torch.dtype, device: torch.device) -> "ImageBindModality": + # we ignore dtype and sometimes device as well + self.device = device + self.dtype = dtype + if IMAGE_BIND_FORCE_CPU not in os.environ: + # running out of VRAM on 24GB GPU + self.module.to(device=device) + self.imagebind_device = device + return self + + def preprocess_rows(self, rows: List[Dict]) -> List[List[Dict]]: + from 
imagebind.models.imagebind_model import ModalityType + from imagebind import data + + row_values = [] + for row in rows: + items = [] + with with_local_files(row[self.data_key]) as item_paths: + for item_path in item_paths: + ib_modality = filename_to_imagebind_modality(item_path) + if ib_modality == ModalityType.TEXT: + items.append( + { + ModalityType.TEXT: data.load_and_transform_text( + [item_path], self.preprocess_device + ) + } + ) + elif ib_modality == ModalityType.VISION: + items.append( + { + ModalityType.VISION: data.load_and_transform_vision_data( + [item_path], self.preprocess_device + ) + } + ) + elif ib_modality == ModalityType.AUDIO: + items.append( + { + ModalityType.AUDIO: data.load_and_transform_audio_data( + [item_path], self.preprocess_device + ) + } + ) + else: + raise ValueError(f"Unknown modality type: {ib_modality}") + row_values.append(items) + return row_values + + @torch.no_grad() + def forward(self, encoded_values: List[List[Dict]]) -> List[torch.Tensor]: + item_features = [] + for item_batch in encoded_values: + item_batch_emb = [] + for item in item_batch: + item = { + k: v.to(device=self.imagebind_device, dtype=torch.float32) + for k, v in item.items() + } + item_batch_emb.extend(list(self.module.forward(item).values())) + item_features.append( + torch.stack(item_batch_emb).to(device=self.device, dtype=self.dtype) + ) + # batch_size x num_items x 1 x embedding_size + return item_features + + +def filename_to_imagebind_modality(fn: str) -> str: + from imagebind.models.imagebind_model import ModalityType + + _, ext = os.path.splitext(fn) + if ext in {".wav"}: + return ModalityType.AUDIO + elif ext in {".jpg", ".png", ".jpeg"}: + return ModalityType.VISION + else: + return ModalityType.TEXT diff --git a/src/sonicverse/multi_token/modalities/multi_task_projector_shared.py b/src/sonicverse/multi_token/modalities/multi_task_projector_shared.py new file mode 100644 index 0000000000000000000000000000000000000000..ec64de5dd403be038763ad01f3fc61b929639dbb --- /dev/null +++ b/src/sonicverse/multi_token/modalities/multi_task_projector_shared.py @@ -0,0 +1,321 @@ +import torch +import torch.nn as nn +from torch.autograd import Variable +import torch.nn.functional as F +from typing import Dict +import numpy as np + +class CNN(nn.Module): + def __init__(self, input_channels = 25, num_class=15): + super(CNN, self).__init__() + self.aggregator = nn.Parameter(torch.randn((input_channels, 1,1), dtype=torch.float)) + self.input_channels = input_channels + + # init bn + self.bn_init = nn.BatchNorm2d(1) + + # layer 1 + self.conv_1 = nn.Conv2d(1, 64, 3, padding=1) + self.bn_1 = nn.BatchNorm2d(64) + self.mp_1 = nn.MaxPool2d((2, 4)) + + # layer 2 + self.conv_2 = nn.Conv2d(64, 128, 3, padding=1) + self.bn_2 = nn.BatchNorm2d(128) + self.mp_2 = nn.MaxPool2d((2, 4)) + + # layer 3 + self.conv_3 = nn.Conv2d(128, 128, 3, padding=1) + self.bn_3 = nn.BatchNorm2d(128) + self.mp_3 = nn.MaxPool2d((2, 4)) + + # layer 4 + self.conv_4 = nn.Conv2d(128, 128, 3, padding=1) + self.bn_4 = nn.BatchNorm2d(128) + self.mp_4 = nn.MaxPool2d((3, 5)) + + # layer 5 + self.conv_5 = nn.Conv2d(128, 64, 3, padding=1) + self.bn_5 = nn.BatchNorm2d(64) + self.mp_5 = nn.MaxPool2d((3, 3)) + + # classifier + self.dense = nn.Linear(640, num_class) + self.dropout = nn.Dropout(0.5) + + def forward(self, x): + aggregator_weights = F.softmax(self.aggregator) + # aggregator_weights = aggregator_weights.view(self.input_channels, 1) + # print("0 x shape : ") + x = (x * aggregator_weights).sum(dim=0) + + # print("aggregator_output 
shape ", x.shape) + + x = x.unsqueeze(0).unsqueeze(0) + + # print("1 x shape ", x.shape) + # init bn + x = self.bn_init(x) + # print("2 x shape ", x.shape) + + # layer 1 + x = self.mp_1(nn.ELU()(self.bn_1(self.conv_1(x)))) + # print("3 x shape ", x.shape) + + # layer 2 + x = self.mp_2(nn.ELU()(self.bn_2(self.conv_2(x)))) + # print("4 x shape ", x.shape) + + # layer 3 + x = self.mp_3(nn.ELU()(self.bn_3(self.conv_3(x)))) + # print("5 x shape ", x.shape) + + # layer 4 + x = self.mp_4(nn.ELU()(self.bn_4(self.conv_4(x)))) + # print("6 x shape ", x.shape) + + # layer 5 + x = self.mp_5(nn.ELU()(self.bn_5(self.conv_5(x)))) + # print("7 x shape ", x.shape) + + # classifier + x = x.view(x.size(0), -1) + # print("8 x shape ", x.shape) + x = self.dropout(x) + # print("9 x shape ", x.shape) + logit = nn.Sigmoid()(self.dense(x)) + # print("logit shape ", logit.shape) + + return logit + + +class MLP(nn.Module): + def __init__(self, input_channels=25, num_class=15): + super(MLP, self).__init__() + self.aggregator = nn.Parameter(torch.randn((input_channels, 1,1), dtype=torch.float)) + self.input_channels = input_channels + + self.hidden_layer_1 = nn.Linear(768, 512) + self.output = nn.Linear(512, num_class) + self.dropout = nn.Dropout(p=0.2) + self.loss = self.get_loss() # can return a dict of losses + + def forward(self, x): + """ + x: (B, L, T, H) + T=#chunks, can be 1 or several chunks + """ + + weights = F.softmax(self.aggregator, dim=1) + x = (x * weights).sum(dim=1) + + x = x.mean(-2) + + x = self.hidden_layer_1(x) + x = F.relu(x) + x = self.dropout(x) + + return self.output(x) + + def get_loss(self): + return nn.BCEWithLogitsLoss() + +class MLPBackbone(nn.Module): + def __init__(self, input_features=768, hidden_dim=512): + super(MLPBackbone, self).__init__() + + self.hidden_layer_1 = nn.Linear(input_features, hidden_dim) + self.dropout = nn.Dropout(p=0.2) + self.loss = self.get_loss() # can return a dict of losses + + def forward(self, x): + """ + x: (B, L, T, H) + T=#chunks, can be 1 or several chunks + """ + + x = self.hidden_layer_1(x) + x = F.relu(x) + x = self.dropout(x) + + return x + + def get_loss(self): + return nn.BCEWithLogitsLoss() + +class MLPShared(nn.Module): + def __init__(self, input_channels=25, num_class=15): + super(MLPShared, self).__init__() + self.aggregator = nn.Parameter(torch.randn((input_channels, 1,1), dtype=torch.float)) + self.input_channels = input_channels + + self.hidden_layer_1 = nn.Linear(512, 256) + self.output = nn.Linear(256, num_class) + self.dropout = nn.Dropout(p=0.2) + self.loss = self.get_loss() # can return a dict of losses + + def forward(self, x): + """ + x: (B, L, T, H) + T=#chunks, can be 1 or several chunks + """ + + weights = F.softmax(self.aggregator, dim=1) + x = (x * weights).sum(dim=1) + + x = x.mean(-2) + + x = self.hidden_layer_1(x) + x = F.relu(x) + x = self.dropout(x) + + return self.output(x) + + def get_loss(self): + return nn.BCEWithLogitsLoss() + +class MLPAggTaskHead(nn.Module): + def __init__(self, input_channels: int, input_size: int, output_size: int, use_aggregator: bool, use_time_average: bool, use_sigmoid: bool, use_transpose: bool, num_layers: int, hidden_dim: int, width: int): + super(MLPAggTaskHead, self).__init__() + if use_aggregator: + self.aggregator = nn.Parameter(torch.randn((input_channels), dtype=torch.float)) + self.use_aggregator = use_aggregator + self.use_time_average = use_time_average + self.use_transpose = use_transpose + self.use_sigmoid = use_sigmoid + self.input_channels = input_channels + self.output_size = 
output_size + self.width = width + + if self.width > 1: + self.layers = nn.ModuleList() + for i in range(self.width): + mlp_layers = [nn.GELU()] + mlp_layers += self._create_mlp_layers(input_size, output_size, num_layers, hidden_dim) + if self.use_sigmoid: mlp_layers += [nn.Sigmoid()] + self.layers.append(nn.Sequential(*mlp_layers)) + else: + mlp_layers = [nn.GELU()] + mlp_layers += self._create_mlp_layers(input_size, output_size, num_layers, hidden_dim) + if self.use_sigmoid: mlp_layers += [nn.Sigmoid()] + self.layers = nn.Sequential(*mlp_layers) + + def _create_mlp_layers(self, input_size, output_size, num_layers, hidden_dim): + if num_layers >=2: + layers = [nn.Linear(input_size, hidden_dim)] + layers.append(nn.GELU()) + if num_layers > 2: + for _ in range(1, num_layers - 2): + layers += [ + nn.Linear(hidden_dim, hidden_dim), + nn.GELU() + ] + layers.append(nn.Linear(hidden_dim, output_size)) + else: + layers = [nn.Linear(input_size, output_size)] + return layers + + + def forward(self, x): + if self.use_transpose: + x = x.transpose(1, 0) + if self.use_time_average: + x = x.mean(-2) + if self.use_aggregator: + aggregator_weights = F.softmax(self.aggregator) + aggregator_weights = aggregator_weights.view(self.input_channels, 1) + aggregator_output = (x * aggregator_weights).sum(dim=0) + aggregator_output = aggregator_output.unsqueeze(dim=0) + # print("Agg output ", aggregator_output.shape) + else: + aggregator_output = x + + if self.width > 1: + if (self.input_channels < 1): + return torch.cat([layer(aggregator_output.unsqueeze(dim=0)) for layer in self.layers], dim=-2) + else: + return torch.cat([layer(aggregator_output.unsqueeze(dim=0)).squeeze(dim=0) for layer in self.layers], dim=-2) + else: + if (self.input_channels < 1): + return self.layers(aggregator_output.unsqueeze(dim=0)) + else: + return self.layers(aggregator_output.unsqueeze(dim=0)).squeeze() + + +class MultiTaskModel(nn.Module): + def __init__(self, tasks: Dict): + super(MultiTaskModel, self).__init__() + self.tasks = tasks + for task_name, task_head in self.tasks["task_heads"].items(): + setattr(self, task_name, MLP(13, task_head["output_size"])) + if task_name in self.tasks["task_projectors"].keys(): + task_projector = tasks["task_projectors"][task_name] + setattr(self, task_name + "_projector", MLPAggTaskHead(task_projector["input_channels"], task_projector["input_size"], task_projector["output_size"], task_projector["use_aggregator"], task_projector["use_time_average"], task_projector["use_sigmoid"], task_projector["use_transpose"], task_projector["num_layers"], task_projector["hidden_size"], task_projector["width"])) + + def forward(self, x): + task_head_outputs = {} + task_projector_outputs = [] + + backbone_output = x + + for task_name in self.tasks["task_heads"]: + if task_name != "lmm_projector": + task_head_outputs[task_name] = getattr(self, task_name)(backbone_output) + if task_name in self.tasks["task_projectors"].keys(): + task_projector_outputs.append(getattr(self, task_name + "_projector")(task_head_outputs[task_name])) + else: + task_projector_outputs.append(getattr(self, task_name)(backbone_output)) + + if len(task_projector_outputs) > 0: + task_projector_outputs_unsqueezed = [task_projector_output.unsqueeze(0) for task_projector_output in task_projector_outputs] + task_head_outputs["projectors"] = torch.cat(task_projector_outputs_unsqueezed, dim=-2) + + return task_head_outputs + +class MultiTaskSharedModel(nn.Module): + def __init__(self, tasks: Dict): + super(MultiTaskSharedModel, self).__init__() + 
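+ # One module is registered per task: the lmm_projector head is an
+ # MLPAggTaskHead sized from the tasks config, every other task head is an
+ # MLPShared classifier, and tasks listed under task_projectors additionally
+ # get an MLPAggTaskHead that maps the head output to the projector output
+ # size from the config. A shared MLPBackbone(768, 512) is created only when
+ # the config defines a "backbone" entry.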
self.tasks = tasks + self.use_backbone = False + if "backbone" in self.tasks.keys(): + self.use_backbone = True + if self.use_backbone: self.backbone = MLPBackbone(768, 512) + for task_name, task_head in self.tasks["task_heads"].items(): + if task_name != "lmm_projector": + setattr(self, task_name, MLPShared(13, task_head["output_size"])) + else: + setattr(self, task_name, MLPAggTaskHead(task_head["input_channels"], task_head["input_size"], task_head["output_size"], task_head["use_aggregator"], task_head["use_time_average"], task_head["use_sigmoid"], task_head["use_transpose"], task_head["num_layers"], task_head["hidden_size"], task_head["width"])) + if task_name in self.tasks["task_projectors"].keys(): + task_projector = tasks["task_projectors"][task_name] + setattr(self, task_name + "_projector", MLPAggTaskHead(task_projector["input_channels"], task_projector["input_size"], task_projector["output_size"], task_projector["use_aggregator"], task_projector["use_time_average"], task_projector["use_sigmoid"], task_projector["use_transpose"], task_projector["num_layers"], task_projector["hidden_size"], task_projector["width"])) + + def forward(self, x): + task_head_outputs = {} + task_projector_outputs = [] + + if self.use_backbone: + backbone_output = self.backbone(x) + else: + backbone_output = x + + #print("Output shape ", backbone_output.shape) + for task_name in self.tasks["task_heads"]: + #print("task namee ", task_name) + if task_name != "lmm_projector": + task_head_outputs[task_name] = getattr(self, task_name)(backbone_output) + if task_name in self.tasks["task_projectors"].keys(): + task_projector_outputs.append(getattr(self, task_name + "_projector")(task_head_outputs[task_name])) + else: + llm_input = x + if self.tasks["task_heads"][task_name]["use_backbone_output"]: + llm_input = backbone_output + task_projector_outputs.append(getattr(self, task_name)(llm_input)) + + if len(task_projector_outputs) > 0: + task_projector_outputs_unsqueezed = [task_projector_output.unsqueeze(0) for task_projector_output in task_projector_outputs] + task_head_outputs["projectors"] = torch.cat(task_projector_outputs_unsqueezed, dim=-2) + + return task_head_outputs + + + diff --git a/src/sonicverse/multi_token/modalities/projectors.py b/src/sonicverse/multi_token/modalities/projectors.py new file mode 100644 index 0000000000000000000000000000000000000000..8e38327d2c18c8542c2580187b3467bbc7d5c598 --- /dev/null +++ b/src/sonicverse/multi_token/modalities/projectors.py @@ -0,0 +1,416 @@ +import torch.nn as nn +import torch +from typing import Dict +import numpy as np + +import torch.nn.functional as F + +def build_patch_mlp_projector( + input_hidden_size: int, lm_hidden_size: int, num_layers: int +) -> nn.Module: + modules = [nn.Linear(input_hidden_size, lm_hidden_size)] + for _ in range(1, num_layers): + modules.append(nn.GELU()) + modules.append(nn.Linear(lm_hidden_size, lm_hidden_size)) + return nn.Sequential(*modules) + + +class _MLPVectorProjector(nn.Module): + def __init__( + self, input_hidden_size: int, lm_hidden_size: int, num_layers: int, width: int + ): + super(_MLPVectorProjector, self).__init__() + self.mlps = nn.ModuleList() + for _ in range(width): + mlp = [nn.Linear(input_hidden_size, lm_hidden_size)] + for _ in range(1, num_layers): + mlp.append(nn.GELU()) + mlp.append(nn.Linear(lm_hidden_size, lm_hidden_size)) + self.mlps.append(nn.Sequential(*mlp)) + + def forward(self, x): + output = torch.cat([mlp(x) for mlp in self.mlps], dim=-2) + return output + +def build_mlp_vector_projector( + 
input_hidden_size: int, lm_hidden_size: int, num_layers: int, num_tokens: int +): + return _MLPVectorProjector( + input_hidden_size, lm_hidden_size, num_layers, num_tokens + ) + +class MLPBackbone(nn.Module): + def __init__(self, input_size: int, output_size: int, num_layers: int, hidden_dim: int): + super(MLPBackbone, self).__init__() + self.output_size = output_size + mlp_layers = self._create_mlp_layers(input_size, output_size, num_layers, hidden_dim) + self.layers = nn.Sequential(*mlp_layers) + + def _create_conv_layers(self, input_channels, num_conv_layers, hidden_dim): + layers = [] + for _ in range(num_conv_layers): + layers += [ + nn.Conv1d(input_channels, hidden_dim, kernel_size=3, padding=1), + nn.GELU(), + nn.MaxPool1d(kernel_size=2, stride=2) + ] + input_channels = hidden_dim + return layers + + def _create_mlp_layers(self, input_size, output_size, num_layers, hidden_dim): + if num_layers >=2: + layers = [nn.Linear(input_size, hidden_dim)] + layers.append(nn.GELU()) + if num_layers > 2: + for _ in range(1, num_layers - 2): + layers += [ + nn.Linear(hidden_dim, hidden_dim), + nn.GELU() + ] + layers.append(nn.Linear(hidden_dim, output_size)) + else: + layers = [nn.Linear(input_size, output_size)] + return layers + + def forward(self, x): + return self.layers(x) + +class MLPTaskHead(nn.Module): + def __init__(self, backbone: nn.Module, input_size: int, output_size: int, num_layers: int, hidden_dim: int, width: int = 1): + super(MLPTaskHead, self).__init__() + self.backbone = backbone + self.width = width + if width > 1: + self.layers = nn.ModuleList() + for i in range(width): + mlp_layers = [nn.GELU()] + mlp_layers += self._create_mlp_layers(input_size, output_size, num_layers, hidden_dim) + self.layers.append(nn.Sequential(*mlp_layers)) + else: + mlp_layers = [nn.GELU()] + mlp_layers += self._create_mlp_layers(input_size, output_size, num_layers, hidden_dim) + self.layers = nn.Sequential(*mlp_layers) + + def _create_mlp_layers(self, input_size, output_size, num_layers, hidden_dim): + if num_layers >=2: + layers = [nn.Linear(input_size, hidden_dim)] + layers.append(nn.GELU()) + if num_layers > 2: + for _ in range(1, num_layers - 2): + layers += [ + nn.Linear(hidden_dim, hidden_dim), + nn.GELU() + ] + layers.append(nn.Linear(hidden_dim, output_size)) + else: + layers = [nn.Linear(input_size, output_size)] + return layers + + def _create_conv_layers(self, input_channels, num_conv_layers, hidden_dim): + layers = [] + for _ in range(num_conv_layers): + layers += [ + nn.Conv2d(in_channels = input_channels, out_channels = hidden_dim, kernel_size=(3,3), stride=1, padding=1), + nn.GELU(), + nn.MaxPool1d(kernel_size=2, stride=2) + ] + input_channels = hidden_dim + return layers + + def forward(self, x): + output = self.backbone.forward(x) + if self.width > 1: + return torch.cat([layer(output) for layer in self.layers], dim=-2) + else: + return self.layers(output) + +class MLPTaskModule(nn.Module): + def __init__(self, input_size: int, output_size: int, num_layers: int, hidden_dim: int, width: int = 1): + super(MLPTaskModule, self).__init__() + self.width = width + if width > 1: + self.layers = nn.ModuleList() + for i in range(width): + mlp_layers = [nn.GELU()] + mlp_layers += self._create_mlp_layers(input_size, output_size, num_layers, hidden_dim) + self.layers.append(nn.Sequential(*mlp_layers)) + else: + mlp_layers = [nn.GELU()] + mlp_layers += self._create_mlp_layers(input_size, output_size, num_layers, hidden_dim) + self.layers = nn.Sequential(*mlp_layers) + + def 
_create_mlp_layers(self, input_size, output_size, num_layers, hidden_dim): + if num_layers >=2: + layers = [nn.Linear(input_size, hidden_dim)] + layers.append(nn.GELU()) + if num_layers > 2: + for _ in range(1, num_layers - 2): + layers += [ + nn.Linear(hidden_dim, hidden_dim), + nn.GELU() + ] + layers.append(nn.Linear(hidden_dim, output_size)) + else: + layers = [nn.Linear(input_size, output_size)] + return layers + + def _create_conv_layers(self, input_channels, num_conv_layers, hidden_dim): + layers = [] + for _ in range(num_conv_layers): + layers += [ + nn.Conv2d(in_channels = input_channels, out_channels = hidden_dim, kernel_size=(3,3), stride=1, padding=1), + nn.GELU(), + nn.MaxPool1d(kernel_size=2, stride=2) + ] + input_channels = hidden_dim + return layers + + def forward(self, x): + if self.width > 1: + return torch.cat([layer(x) for layer in self.layers], dim=-2) + else: + return self.layers(x) + + +class MultiTaskModel(nn.Module): + def __init__(self, input_hidden_size: int, input_channels: int, time_average: bool, time_dimension: int, use_aggregator: bool, tasks: Dict): + super(MultiTaskModel, self).__init__() + self.tasks = tasks + self.time_average = time_average + self.time_dimension = time_dimension + self.use_aggregator = use_aggregator + if self.use_aggregator: + if (time_average): + self.aggregator = nn.Parameter(torch.randn((input_channels, 1), dtype = torch.float)) + else: + self.aggregator = nn.Parameter(torch.randn((input_channels, 1, 1), dtype = torch.float)) + + self.backbone = MLPBackbone(input_hidden_size, self.tasks["backbone"]["output_size"], self.tasks["backbone"]["num_layers"], self.tasks["backbone"]["hidden_size"]) + for task_name, task_head in self.tasks["task_heads"].items(): + setattr(self, task_name, MLPTaskModule(self.tasks["backbone"]["output_size"], task_head["output_size"], task_head["num_layers"], task_head["hidden_size"], task_head["width"])) + if task_name in self.tasks["task_projectors"].keys(): + task_projector = tasks["task_projectors"][task_name] + setattr(self, task_name + "_projector", MLPTaskModule(task_head["output_size"], task_projector["output_size"], task_projector["num_layers"], task_projector["hidden_size"], task_projector["width"])) + + def forward(self, x): + task_head_outputs = {} + task_projector_outputs = [] + + if self.time_average: + x = x.mean(self.time_dimension) + if self.use_aggregator: + aggregator_weights = F.softmax(self.aggregator, dim=0) + aggregator_output = (x * aggregator_weights).sum(dim=0) + aggregator_output = aggregator_output.unsqueeze(0) + else: + aggregator_output = x + + backbone_output = self.backbone(aggregator_output) + + for task_name in self.tasks["task_heads"]: + if task_name != "lmm_projector": + task_head_output = getattr(self, task_name)(backbone_output) + min_val = torch.min(task_head_output) + max_val = torch.max(task_head_output) + + normalized_task_head_output = (task_head_output - min_val) / (max_val - min_val) + task_head_outputs[task_name] = normalized_task_head_output + if task_name in self.tasks["task_projectors"].keys(): + task_projector_outputs.append(getattr(self, task_name + "_projector")(task_head_outputs[task_name])) + else: + task_projector_outputs.append(getattr(self, task_name)(backbone_output)) + + task_projector_outputs_unsqueezed = [task_projector_output.unsqueeze(0) for task_projector_output in task_projector_outputs] + if len(task_projector_outputs_unsqueezed) > 0: + task_head_outputs["projectors"] = torch.cat(task_projector_outputs_unsqueezed, dim=-2) + + return 
task_head_outputs + + +def build_mt_vector_projector( + input_hidden_size: int, lm_hidden_size: int, tasks: Dict +): + projector = nn.ModuleDict() + projector["backbone"] = MLPBackbone(input_hidden_size, tasks["backbone"]["output_size"], tasks["backbone"]["num_layers"], tasks["backbone"]["hidden_size"]) + for task_name, task_head in tasks["task_heads"].items(): + projector[task_name] = MLPTaskHead(projector["backbone"], task_head["hidden_size"], task_head["output_size"], task_head["num_layers"], task_head["hidden_size"], task_head["width"]) + + return projector + +class Attention(nn.Module): + def __init__(self, input_dim, hidden_dim): + super(Attention, self).__init__() + self.linear_in = nn.Linear(input_dim, hidden_dim) + self.linear_out = nn.Linear(hidden_dim, 1) + + def forward(self, x): + # Input shape: (batch_size, seq_len, input_dim) + energy = torch.tanh(self.linear_in(x)) + attention_scores = torch.softmax(self.linear_out(energy), dim=1) + context_vector = torch.sum(attention_scores * x, dim=1) + return context_vector + +class _CNNAttentionTokenizer(nn.Module): + def __init__(self, input_channels, output_size, width, hidden_dim, num_conv_layers): + super(_CNNAttentionTokenizer, self).__init__() + self.width = width + self.cnns = nn.ModuleList() + self.attentions = nn.ModuleList() + for _ in range(width): + cnn = self._create_conv_layers(input_channels, num_conv_layers) + self.cnns.append(cnn) + attention = [Attention(hidden_dim, 125)] + linear_input_size = hidden_dim + attention.append(nn.Linear(linear_input_size, output_size)) + self.attentions.append(nn.Sequential(*attention)) + + + def _create_conv_layers(self, input_channels, num_conv_layers): + layers = [] + in_channels = input_channels + for _ in range(num_conv_layers): + layers += [ + nn.Conv1d(in_channels, 64, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool1d(kernel_size=2, stride=2) + ] + in_channels = 64 + return nn.Sequential(*layers) + + def forward(self, x): + outputs = [] + for token in range(self.width): + # Input shape: (batch_size, input_channels, sequence_length) + token_output = self.cnns[token](x) # Apply convolutional layers + token_output = token_output.permute(0, 2, 1) # Reshape for attention mechanism (batch_size, sequence_length, input_dim + token_output = self.attentions[token](token_output) # Apply attention mechanism + outputs.append(token_output) + output = torch.cat(outputs, dim=-2) + output = torch.stack([output]) + return output + +def build_attentive_cnn_projector( + input_channels: int, lm_hidden_size: int, num_tokens: int, hidden_dim: int, num_layers: int + ): + return _CNNAttentionTokenizer(input_channels, lm_hidden_size, num_tokens, hidden_dim, num_layers) + +class _CNNMLPProjector(nn.Module): + def __init__(self, input_channels, input_size, output_size = 4096, width = 5, hidden_dim = 64, num_conv_layers = 1, num_mlp_layers = 2): + super(_CNNMLPProjector, self).__init__() + self.width = width + self.cnnmlps = nn.ModuleList() + for _ in range(self.width): + cnnmlp = self._create_conv_layers(input_channels, num_conv_layers, hidden_dim) + cnnmlp.append(nn.Flatten()) + cnn_output_size = hidden_dim*((input_size + 2*1 - 3*num_conv_layers) // (2**num_conv_layers) + 1) + cnnmlp.append(nn.Linear(cnn_output_size, output_size)) + cnnmlp.append(nn.GELU()) + cnnmlp += self._create_mlp_layers(output_size, output_size, num_mlp_layers, output_size) + self.cnnmlps.append(nn.Sequential(*cnnmlp)) + + def _create_conv_layers(self, input_channels, num_conv_layers, hidden_dim): + layers = [] + for _ in 
range(num_conv_layers): + layers += [ + nn.Conv1d(input_channels, hidden_dim, kernel_size=3, padding=1), + nn.GELU(), + nn.MaxPool1d(kernel_size=2, stride=2) + ] + input_channels = hidden_dim + return layers + + def _create_mlp_layers(self, input_size, output_size, num_layers, hidden_dim): + if num_layers >=2: + layers = [nn.Linear(input_size, hidden_dim)] + layers.append(nn.GELU()) + if num_layers > 2: + for _ in range(1, num_layers - 2): + layers += [ + nn.Linear(hidden_dim, hidden_dim), + nn.GELU() + ] + layers.append(nn.Linear(hidden_dim, output_size)) + else: + layers = [nn.Linear(input_size, output_size)] + return layers + + def forward(self, x): + return torch.stack([torch.cat([cnnmlp(x) for cnnmlp in self.cnnmlps], dim=-2)]) + +def build_cnn_mlp_projector( + input_channels: int, input_size: int, lm_hidden_size: int, num_tokens: int, hidden_dim: int, num_conv_layers: int, num_mlp_layers: int + ): + return _CNNMLPProjector(input_channels, input_size, lm_hidden_size, num_tokens, hidden_dim, num_conv_layers, num_mlp_layers) + +class _MultiLayeredCNNMLPProjector(nn.Module): + def __init__(self, input_channels, input_size, num_feature_layers, output_size = 4096, width = 5, hidden_dim = 64, num_conv_layers = 1, num_mlp_layers = 2): + super(_MultiLayeredCNNMLPProjector, self).__init__() + self.width = width + self.num_feature_layers = num_feature_layers + self.cnnmlps = nn.ModuleList() + for _ in range(self.width*self.num_feature_layers): + cnnmlp = self._create_conv_layers(input_channels, num_conv_layers, hidden_dim) + cnnmlp += [nn.GELU()] + cnnmlp += self._create_mlp_layers(input_size, output_size, num_mlp_layers, output_size) + self.cnnmlps.append(nn.Sequential(*cnnmlp)) + + def _create_conv_layers(self, input_channels, num_conv_layers, hidden_size): + layers = [] + + if input_channels >= hidden_size: + hidden_dim = int(input_channels/2) + else: + hidden_dim = hidden_size + + layers += [nn.Conv1d(in_channels=input_channels, out_channels=hidden_dim, kernel_size=3, stride=1, padding=1), nn.GELU()] + if num_conv_layers > 2: + for _ in range(num_conv_layers - 2): + if hidden_dim/2 >= hidden_size: + output_dim = int(hidden_dim/2) + else: + output_dim = hidden_size + layers += [ + nn.Conv1d(in_channels=hidden_dim, out_channels=output_dim, kernel_size=3, stride=1, padding=1), + nn.GELU(), + ] + hidden_dim = output_dim + layers += [nn.Conv1d(in_channels=hidden_dim, out_channels=1, kernel_size=3, stride=1, padding=1)] + return layers + + def _create_mlp_layers(self, input_size, output_size, num_layers, hidden_dim): + if num_layers >=2: + layers = [nn.Linear(input_size, hidden_dim)] + layers.append(nn.GELU()) + if num_layers > 2: + for _ in range(1, num_layers - 2): + layers += [ + nn.Linear(hidden_dim, hidden_dim), + nn.GELU() + ] + layers.append(nn.Linear(hidden_dim, output_size)) + else: + layers = [nn.Linear(input_size, output_size)] + return layers + + def forward(self, x): + print("X SHAPE ", x.shape) + inp_feature_layers = [] + for feature_id in range(self.num_feature_layers): + in_feat_layer = x[feature_id].unsqueeze(0).permute(0,2,1) + inp_feature_layers.append(in_feat_layer) + + outputs = [] + for layer_count in range(self.width*self.num_feature_layers): + feature_id = int(layer_count/self.width) + outputs+=[self.cnnmlps[layer_count](inp_feature_layers[feature_id])] + + return torch.cat(outputs, dim=-2) + + +def build_multi_layer_cnn_mlp_projector( + input_channels: int, input_size: int, num_feature_layers: int, lm_hidden_size: int, num_tokens: int, hidden_dim: int, num_conv_layers: 
int, num_mlp_layers: int + ): + assert(num_tokens % num_feature_layers == 0) + width = int(num_tokens/num_feature_layers) + return _MultiLayeredCNNMLPProjector(input_channels, input_size, num_feature_layers, lm_hidden_size, width, hidden_dim, num_conv_layers, num_mlp_layers) + diff --git a/src/sonicverse/multi_token/modalities/video_xclip.py b/src/sonicverse/multi_token/modalities/video_xclip.py new file mode 100644 index 0000000000000000000000000000000000000000..d875a9972f884937326300fae327ada62fcf5300 --- /dev/null +++ b/src/sonicverse/multi_token/modalities/video_xclip.py @@ -0,0 +1,113 @@ +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +from transformers import AutoProcessor, AutoModel + +from multi_token.data_tools import load_video +from multi_token.modalities.base_modality import Modality +from multi_token.modalities.projectors import ( + build_mlp_vector_projector, +) + + +OUTPUT_EMB_SIZE = 512 + + +class XCLIPVideoModule(nn.Module): + def __init__(self, model_name_or_path: str): + super().__init__() + self.model_name_or_path = model_name_or_path + self.model = None + self.processor = None + + self.load_model() + + def load_model(self): + self.model = AutoModel.from_pretrained(self.model_name_or_path) + self.processor = AutoProcessor.from_pretrained(self.model_name_or_path) + self.model.requires_grad_(False) + + @torch.no_grad() + def forward(self, video_inputs) -> torch.Tensor: + with torch.no_grad(): + outputs = self.model(**(video_inputs.to(device=self.device))) + + emb = outputs.video_embeds.to(device=self.device, dtype=self.dtype).view( + -1, 1, OUTPUT_EMB_SIZE + ) + return emb + + @property + def dtype(self): + return self.model.dtype + + @property + def device(self): + return self.model.device + + +class XCLIPVideoModality(Modality): + def __init__( + self, + model_name_or_path: str = "microsoft/xclip-base-patch32", + num_projector_layers: int = 2, + num_tokens_output: int = 10, + ): + self.model_name_or_path = model_name_or_path + self.module = XCLIPVideoModule(model_name_or_path=self.model_name_or_path) + self.num_projector_layers = num_projector_layers + self.num_tokens_output = num_tokens_output + + def build_projector(self, lm_hidden_size: int) -> nn.Module: + return build_mlp_vector_projector( + input_hidden_size=OUTPUT_EMB_SIZE, + lm_hidden_size=lm_hidden_size, + num_layers=self.num_projector_layers, + num_tokens=self.num_tokens_output, + ) + + @property + def name(self) -> str: + return "video_xclip" + + @property + def token(self) -> str: + return "