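"""Dataset and collator utilities for supervised multimodal (LMM) training.

Provides dataclass argument containers, an HF-datasets-backed ``LMMDataset``
that encodes chat-formatted examples, and a collator that pads batches and
optionally carries per-task labels for multi-task training.
"""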
from typing import Any, Dict, List, Optional, Sequence
from dataclasses import dataclass, field
import logging
import os

from torch.utils.data import Dataset
from datasets import load_from_disk, load_dataset, Dataset as HFDataset
import transformers
import torch

from multi_token.modalities.base_modality import Modality
from multi_token.constants import IGNORE_INDEX
from multi_token.data_tools import encode_chat, encode_chat_multitask
from multi_token.model_utils import MultiTaskType


@dataclass
class DataArguments:
    dataset_path: Optional[str] = field(
        default=None, metadata={"help": "Path to the training data."}
    )


@dataclass
class TrainDataArguments:
    train_dataset_path: Optional[str] = field(
        default=None, metadata={"help": "Path to the training data."}
    )


@dataclass
class EvaluationDataArguments:
    evaluation_dataset_path: Optional[str] = field(
        default=None, metadata={"help": "Path to the evaluation data."}
    )


def _resolve_dataset(path: str) -> HFDataset:
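    """Load a dataset saved with ``save_to_disk`` from a local path, or fall
    back to loading the ``*.arrow`` files of the ``train`` split by dataset id.
    """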
    if os.path.exists(path):
        return load_from_disk(path)
    else:
        return load_dataset(path, split="train", data_files="*.arrow")


class LMMDataset(Dataset):
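    """Map-style dataset that encodes chat-formatted examples for LMM training."""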
    def __init__(
        self,
        data_args: DataArguments,
        tokenizer: transformers.PreTrainedTokenizer,
        modalities: List[Modality],
    ):
        super().__init__()
        self.dataset = _resolve_dataset(data_args.dataset_path)
        self.tokenizer = tokenizer
        self.modalities = modalities

    def __len__(self):
        return len(self.dataset)

    def get_example(self) -> Dict:
        return self.dataset[0]

    def __getitem__(self, i) -> Dict:
        try:
            item = self.dataset[i]
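            # Use multi-task chat encoding when any modality opts into it.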
            use_multi_task = MultiTaskType.NO_MULTI_TASK
            for m in self.modalities:
                if m.use_multi_task != MultiTaskType.NO_MULTI_TASK:
                    use_multi_task = m.use_multi_task
                    break
            if use_multi_task != MultiTaskType.NO_MULTI_TASK:
                return encode_chat_multitask(item, self.tokenizer, self.modalities)
            else:
                return encode_chat(item, self.tokenizer, self.modalities)
        except Exception as e:
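            # Skip malformed examples: log the error and retry the next index,
            # wrapping around to 0 at the end of the dataset.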
            new_i = i + 1
            if new_i >= len(self):
                new_i = 0
            logging.error(f"Error encoding chat: {e} index={i} trying index={new_i}")
            return self.__getitem__(new_i)


class DataCollatorForSupervisedLMMDataset:
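    """Collates encoded instances into padded, truncated training batches."""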
    def __init__(
        self,
        tokenizer: transformers.PreTrainedTokenizer,
        modalities: List[Modality],
    ):
        self.tokenizer = tokenizer
        self.modalities = modalities

        # Adopt the first non-default multi-task mode declared by a modality.
        self.use_multi_task = MultiTaskType.NO_MULTI_TASK
        for modality in self.modalities:
            if modality.use_multi_task != MultiTaskType.NO_MULTI_TASK:
                self.use_multi_task = modality.use_multi_task
                break

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, Any]:
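        """Pad ``input_ids`` and LM labels, build the attention mask, and
        gather per-modality inputs into the batch dict."""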
        input_ids = []
        lmm_labels = []
        task_labels = []
        for instance in instances:
            input_ids.append(instance["input_ids"])
            if self.use_multi_task == MultiTaskType.NO_MULTI_TASK:
                lmm_labels.append(instance["labels"])
            else:
                # labels[0] holds the LM labels; the remainder are task labels.
                lmm_labels.append(instance["labels"][0])
                task_labels.append(list(instance["labels"][1:]))

        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        lmm_labels = torch.nn.utils.rnn.pad_sequence(
            lmm_labels, batch_first=True, padding_value=IGNORE_INDEX
        )

        # Truncate to the tokenizer's maximum context length.
        input_ids = input_ids[:, : self.tokenizer.model_max_length]
        lmm_labels = lmm_labels[:, : self.tokenizer.model_max_length]
        # Downstream code expects labels as [lm_labels, task_labels];
        # task_labels stays empty when multi-task training is disabled.
        output_labels = [lmm_labels, task_labels]
        batch = dict(
            input_ids=input_ids,
            labels=output_labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        # Attach each modality's inputs as a plain list, one entry per instance.
        for m in self.modalities:
            batch[m.name] = [instance[m.name] for instance in instances]

        return batch
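

# Minimal usage sketch (``MyModality``, the model name, and the dataset path
# below are illustrative placeholders, not part of this module):
#
#   tokenizer = transformers.AutoTokenizer.from_pretrained("model-name")
#   modalities = [MyModality()]
#   dataset = LMMDataset(
#       DataArguments(dataset_path="/path/to/data"), tokenizer, modalities
#   )
#   collator = DataCollatorForSupervisedLMMDataset(tokenizer, modalities)
#   batch = collator([dataset[i] for i in range(4)])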