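"""Dataset and collator utilities for supervised multi-modal (LMM) training.

Wraps a Hugging Face dataset of chat examples (`LMMDataset`), encodes each
example with either the standard or multi-task chat encoder, and collates
batches by padding token ids and labels and passing modality values through.
"""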

from typing import List, Dict, Sequence
from dataclasses import dataclass, field
import logging
import os

from torch.utils.data import Dataset
from datasets import load_from_disk, load_dataset, Dataset as HFDataset
import transformers
import torch

from multi_token.modalities.base_modality import Modality
from multi_token.constants import IGNORE_INDEX
from multi_token.data_tools import encode_chat, encode_chat_multitask
from multi_token.model_utils import MultiTaskType


@dataclass
class DataArguments:
    dataset_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )


@dataclass
class TrainDataArguments:
    train_dataset_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )


@dataclass
class EvaluationDataArguments:
    evaluation_dataset_path: str = field(
        default=None, metadata={"help": "Path to the evaluation data."}
    )


def _resolve_dataset(path: str) -> HFDataset:
    # Load a dataset saved with `datasets.save_to_disk` if the path exists
    # locally; otherwise treat it as a Hugging Face Hub repository of
    # .arrow files and load them as the "train" split.
    if os.path.exists(path):
        return load_from_disk(path)
    else:
        return load_dataset(path, split="train", data_files="*.arrow")


class LMMDataset(Dataset):
    def __init__(
        self,
        data_args: DataArguments,
        tokenizer: transformers.PreTrainedTokenizer,
        modalities: List[Modality],
    ):
        super().__init__()
        self.dataset = _resolve_dataset(data_args.dataset_path)
        self.tokenizer = tokenizer
        self.modalities = modalities

    def __len__(self):
        return len(self.dataset)

    def get_example(self) -> Dict:
        return self.dataset[0]

    def __getitem__(self, i) -> Dict:
        try:
            item = self.dataset[i]
            # If any modality is configured for multi-task training, use the
            # multi-task chat encoding; otherwise fall back to the standard one.
            use_multi_task = MultiTaskType.NO_MULTI_TASK
            for m in self.modalities:
                if m.use_multi_task != MultiTaskType.NO_MULTI_TASK:
                    use_multi_task = m.use_multi_task
                    break
            if use_multi_task != MultiTaskType.NO_MULTI_TASK:
                return encode_chat_multitask(item, self.tokenizer, self.modalities)
            else:
                return encode_chat(item, self.tokenizer, self.modalities)
        except Exception as e:
            # Skip examples that fail to encode by retrying the next index,
            # wrapping around to the start of the dataset.
            new_i = i + 1
            if new_i >= len(self):
                new_i = 0
            logging.error(f"Error encoding chat: {e} index={i} trying index={new_i}")
            return self.__getitem__(new_i)


@dataclass
class DataCollatorForSupervisedLMMDataset:
    def __init__(
        self,
        tokenizer: transformers.PreTrainedTokenizer,
        modalities: List[Modality],
    ):
        self.tokenizer = tokenizer
        self.modalities = modalities
        # Mirror the dataset's logic: the first modality that enables
        # multi-task training determines how labels are unpacked in __call__.
        self.use_multi_task = MultiTaskType.NO_MULTI_TASK
        for modality in self.modalities:
            if modality.use_multi_task != MultiTaskType.NO_MULTI_TASK:
                self.use_multi_task = modality.use_multi_task
                break

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, List]:
        input_ids = []
        lmm_labels = []
        task_labels = []
        for instance in instances:
            input_ids.append(instance["input_ids"])
            if self.use_multi_task == MultiTaskType.NO_MULTI_TASK:
                lmm_labels.append(instance["labels"])
            else:
                # In multi-task mode each instance carries a list of label
                # tensors: index 0 holds the language-modeling labels and the
                # remaining entries hold the per-task labels.
                lmm_labels.append(instance["labels"][0])
                inst_task_labels = []
                for label_id in range(1, len(instance["labels"])):
                    inst_task_labels.append(instance["labels"][label_id])
                task_labels.append(inst_task_labels)

        # Pad to the longest sequence in the batch, then truncate to the
        # tokenizer's maximum context length.
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        lmm_labels = torch.nn.utils.rnn.pad_sequence(
            lmm_labels, batch_first=True, padding_value=IGNORE_INDEX
        )
        input_ids = input_ids[:, : self.tokenizer.model_max_length]
        lmm_labels = lmm_labels[:, : self.tokenizer.model_max_length]
        output_labels = [lmm_labels, task_labels]

        batch = dict(
            input_ids=input_ids,
            labels=output_labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )
        # Pass each modality's raw values through unchanged; the model's
        # modality encoders consume them downstream.
        for m in self.modalities:
            batch[m.name] = [instance[m.name] for instance in instances]
        return batch
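

# Usage sketch: a minimal example of wiring the pieces above together. The
# checkpoint name, dataset path, and modality list are placeholders that
# depend on how multi_token is configured in the surrounding training code,
# and the tokenizer must define a pad token for the collator's padding step.
#
#   from torch.utils.data import DataLoader
#
#   tokenizer = transformers.AutoTokenizer.from_pretrained("<base-checkpoint>")
#   modalities = [...]  # configured multi_token Modality instances
#   dataset = LMMDataset(
#       DataArguments(dataset_path="<local-path-or-hub-dataset>"), tokenizer, modalities
#   )
#   collator = DataCollatorForSupervisedLMMDataset(tokenizer, modalities)
#   loader = DataLoader(dataset, batch_size=4, collate_fn=collator)
#   batch = next(iter(loader))
#   # batch holds input_ids, labels ([lmm_labels, task_labels]),
#   # attention_mask, plus one entry per modality name.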