import json
import logging
import math
import os
import pdb
import random
import re
import sys
import time
import traceback
from collections import defaultdict
from typing import Dict, List, Optional, Sequence

import numpy as np
import torch
import transformers
from transformers.trainer_pt_utils import LabelSmoother

from .dataset_base import BaseDataset

IGNORE_TOKEN_ID = LabelSmoother.ignore_index


class DeepSeekDataset(BaseDataset):
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(
            *args,
            **kwargs,
        )

        self.default_system_message = "You are a helpful AI assistant."
        self.default_system_message = None

        self.ret = defaultdict(dict)
        self.is_cat = True

        if self.cross_dataset_joint:
            for i in range(2):
                self.maybe_init_ret(f"default_{i}")

    def maybe_init_ret(self, source, force=False):
        if source not in self.ret or force:
            self.ret[source] = {}
            self.ret[source]["tokens"] = []
            self.ret[source]["labels"] = []
            self.ret[source]["actual_seq_len"] = []
            if self.create_position_ids:
                self.ret[source]["position_ids"] = []
            if self.create_attention_mask:
                self.ret[source]["attention_mask"] = []
            if self.create_attention_mask_2d:
                self.ret[source]["attention_mask_2d"] = torch.tril(
                    torch.ones(
                        (1, self.max_padding_length, self.max_padding_length), dtype=torch.bool
                    )
                )
        return len(self.ret[source]["tokens"]) == 0

    def get_max_min_ret_length(self):
        max_ret_lengh = 0
        min_ret_lengh = self.max_padding_length + 1
        max_ret_key = None
        min_ret_key = None
        for k, v in self.ret.items():
            cur_length = len(v["tokens"])
            if cur_length > max_ret_lengh:
                max_ret_lengh = cur_length
                max_ret_key = k
            if cur_length < min_ret_lengh:
                min_ret_lengh = cur_length
                min_ret_key = k
        return max_ret_lengh, max_ret_key, min_ret_lengh, min_ret_key

    def add_ret(self, ret, source):
        cur_length = len(ret["input_ids"])
        cur_image_length = len(ret["images"])
        cur_audio_length = len(ret["audios"])

        all_length = len(self.ret[source]["tokens"])
        if "images" in self.ret[source]:
            all_image_length = len(self.ret[source]["images"])
        else:
            all_image_length = 0

        if cur_image_length > 0:
            if all_image_length > 0:
                self.ret[source]["images"] = torch.cat(
                    [self.ret[source]["images"], ret["images"]], dim=0
                )
                ret["image_indices"][1, :, :] += all_length
                self.ret[source]["image_indices"] = torch.cat(
                    [self.ret[source]["image_indices"], ret["image_indices"]], dim=1
                )
            else:
                self.ret[source]["images"] = ret["images"]
                self.ret[source]["image_indices"] = ret["image_indices"]

        if "audios" in self.ret[source]:
            all_audio_length = len(self.ret[source]["audios"])
        else:
            all_audio_length = 0

        if cur_audio_length > 0:
            if all_audio_length > 0:
                # self.ret[source]["audios"] = torch.cat(
                #     [self.ret[source]["audios"], ret["audios"]], dim=0
                # )
                # ret["audio_indices"][1, :, :] += all_length
                # self.ret[source]["audio_indices"] = torch.cat(
                #     [self.ret[source]["audio_indices"], ret["audio_indices"]], dim=1
                # )
                self.ret[source]["audios"].extend(ret["audios"])
                for audio_indice in ret["audio_indices"]:
                    audio_indice[1, :, :] += all_length
                self.ret[source]["audio_indices"].extend(ret["audio_indices"])
            else:
                self.ret[source]["audios"] = ret["audios"]
                self.ret[source]["audio_indices"] = ret["audio_indices"]
        # print(self.ret[source]["audios"])

        if self.create_attention_mask:
            self.ret[source]["attention_mask"] += ret["attention_mask"]

        if self.create_attention_mask_2d:
            self.ret[source]["attention_mask_2d"][:, all_length:, :all_length] = 0

        if self.create_position_ids:
            self.ret[source]["position_ids"] += list(range(cur_length))

        self.ret[source]["tokens"] += ret["input_ids"]
        self.ret[source]["labels"] += ret["labels"]
        self.ret[source]["actual_seq_len"] += [all_length + cur_length]
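
    # process_ret finalizes one packed sequence before it is returned from
    # __getitem__: it rebuilds position_ids and the 2D causal mask when they are
    # not reset per packed sample, applies the optional token/label shift, pads
    # or truncates everything to max_padding_length, and converts the Python
    # lists to torch tensors.
    # Example with hypothetical lengths: packing a 10-token and then a 5-token
    # sample via add_ret leaves actual_seq_len == [10, 15]; downstream attention
    # code can use these boundaries to keep attention from crossing packed
    # samples.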
    def process_ret(self, to_ret):
        if "tokens" in to_ret and len(to_ret["tokens"]) > 0:
            pass
        else:
            return to_ret

        if self.create_position_ids:
            if self.reset_position_ids:
                pass
            else:
                to_ret["position_ids"] = list(range(len(to_ret["tokens"])))

        if self.create_attention_mask_2d:
            if self.reset_attention_mask:
                pass
            else:
                to_ret["attention_mask_2d"] = torch.tril(
                    torch.ones(
                        (1, self.max_padding_length, self.max_padding_length), dtype=torch.bool
                    )
                )

        if self.shift_token:
            to_ret["tokens"] = to_ret["tokens"][:-1]
            to_ret["labels"] = to_ret["labels"][1:]
            to_ret["actual_seq_len"][-1] -= 1
            if self.create_position_ids:
                to_ret["position_ids"] = to_ret["position_ids"][:-1]
            if self.create_attention_mask:
                to_ret["attention_mask"] = to_ret["attention_mask"][:-1]
            if self.create_attention_mask_2d:
                to_ret["attention_mask_2d"][:, :, -1] = 0
                to_ret["attention_mask_2d"][:, -1, :] = 0

        assert len(to_ret["tokens"]) == len(
            to_ret["labels"]
        ), f"{len(to_ret['tokens'])} {len(to_ret['labels'])}"

        if not self.variable_length and self.max_padding_length > len(to_ret["tokens"]):
            to_ret["tokens"] += [self.tokenizer.pad_token_id] * (
                self.max_padding_length - len(to_ret["tokens"])
            )
            to_ret["labels"] += [IGNORE_TOKEN_ID] * (
                self.max_padding_length - len(to_ret["labels"])
            )
            to_ret["actual_seq_len"][-1] = self.max_padding_length
            if self.create_position_ids:
                # to_ret["position_ids"] += to_ret["position_ids"][-1:] * (
                #     self.max_padding_length - len(to_ret["position_ids"])
                # )
                to_ret["position_ids"] += list(
                    range(to_ret["position_ids"][-1] + 1, self.max_padding_length)
                )
            if self.create_attention_mask:
                to_ret["attention_mask"] += [0] * (
                    self.max_padding_length - len(to_ret["attention_mask"])
                )

        to_ret["tokens"] = to_ret["tokens"][: self.max_padding_length]
        to_ret["labels"] = to_ret["labels"][: self.max_padding_length]
        to_ret["actual_seq_len"][-1] = self.max_padding_length
        if self.create_position_ids:
            to_ret["position_ids"] = to_ret["position_ids"][: self.max_padding_length]
        if self.create_attention_mask:
            to_ret["attention_mask"] = to_ret["attention_mask"][: self.max_padding_length]

        to_ret["tokens"] = torch.tensor(to_ret["tokens"], dtype=torch.int64)
        to_ret["labels"] = torch.tensor(to_ret["labels"], dtype=torch.int64)
        to_ret["actual_seq_len"] = torch.tensor(to_ret["actual_seq_len"], dtype=torch.int64)
        if self.create_position_ids:
            to_ret["position_ids"] = torch.tensor(to_ret["position_ids"], dtype=torch.int64)
        if self.create_attention_mask:
            to_ret["attention_mask"] = torch.tensor(to_ret["attention_mask"], dtype=torch.int64)

        if self.create_attention_mask_2d:
            attention_mask_2d = to_ret.pop("attention_mask_2d")
            attention_mask_2d = attention_mask_2d.masked_fill(
                (to_ret["attention_mask"] < 0.5).view(1, 1, self.max_padding_length), value=0
            )
            attention_mask_2d = attention_mask_2d < 0.5
            to_ret["attention_mask"] = attention_mask_2d

        if self.create_loss_mask:
            loss_mask = torch.where(to_ret["labels"] == IGNORE_TOKEN_ID, 0, 1)
            to_ret["loss_mask"] = loss_mask.to(torch.float32)

        if not self.reset_position_ids and not self.reset_attention_mask:
            to_ret.pop("actual_seq_len")

        to_ret["input_ids"] = to_ret["tokens"]
        # print("to_ret[tokens]", to_ret["tokens"])
        # print("to_ret[labels]", to_ret["labels"])
        return to_ret

    def is_skip(self):
        if self.processed_samples < self.skip_samples:
            if self.processed_samples % 1e3 == 0:
                print(
                    f"processed_samples {self.processed_samples} skip_samples {self.skip_samples}"
                )
            return True
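
    # show_statistic prints packing progress every log_interval samples; the
    # interval shrinks for very long contexts (>= 2**17 and >= 2**20 tokens)
    # so that long-sequence runs still log at a useful frequency.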
    def show_statistic(self):
        log_interval = 10000
        if self.max_padding_length >= 2**17:
            log_interval = 500
        if self.max_padding_length >= 2**20:
            log_interval = 100
        if self.unjoint_samples % log_interval == 0:
            print(
                f"processed_samples {self.processed_samples} unjoint_samples {self.unjoint_samples} joint_samples {self.joint_samples} {[len(v['tokens']) for _, v in self.ret.items()]}",
                flush=True,
            )
        return False

    def __getitem__(self, index):
        self.processor["audio"].load_model()
        while True:
            # if True:
            try:
                self.processed_samples += 1
                if self.is_skip():
                    return {}

                sample = self.raw_data[index]

                if self.cross_dataset_joint:
                    is_empty = False
                    (
                        max_ret_lengh,
                        max_ret_key,
                        min_ret_lengh,
                        min_ret_key,
                    ) = self.get_max_min_ret_length()
                else:
                    source = sample["source"]
                    is_empty = self.maybe_init_ret(source)
                    max_ret_lengh = min_ret_lengh = len(self.ret[source]["tokens"])
                    max_ret_key = min_ret_key = source

                is_begin = is_empty or self.reset_position_ids or self.reset_attention_mask

                ret = preprocess(
                    sample,
                    self.tokenizer,
                    self.image_token_length,
                    default_system_message=self.default_system_message,
                    processor=self.processor,
                    is_begin=is_begin,
                    max_num_frame=self.max_num_frame,
                    max_fps=self.max_fps,
                )
                if ret is None:
                    return {}

                cur_length = len(ret["input_ids"])
                if cur_length > self.max_padding_length:
                    return {}

                self.unjoint_samples += 1

                if not self.dataset_joint:
                    to_ret = self.ret.pop(max_ret_key)
                    self.maybe_init_ret(max_ret_key, force=True)
                    self.add_ret(ret, max_ret_key)
                elif min_ret_lengh + cur_length > self.max_padding_length:
                    to_ret = self.ret.pop(max_ret_key)
                    self.joint_samples += 1
                    self.maybe_init_ret(max_ret_key, force=True)
                    self.add_ret(ret, max_ret_key)
                else:
                    to_ret = {}
                    self.add_ret(ret, min_ret_key)

                to_ret = self.process_ret(to_ret)

                self.show_statistic()

                return to_ret

            except Exception as error:
                try:
                    with open(os.path.join(self.output_dir, "data_error.log"), "a") as f:
                        print("-" * 100, file=f)
                        print(traceback.format_exc(), file=f)
                        print(self.raw_data[index], file=f)
                except Exception as error:
                    print(error)
                return {}
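
# preprocess converts a single raw sample into the dict consumed by
# DeepSeekDataset.add_ret above: "input_ids" and "labels", plus any "images",
# "audios" and their index tensors. It uses the DeepSeek chat markers
# (<|begin▁of▁sentence|>, <|end▁of▁sentence|>, <|User|>, <|Assistant|>) and the
# image/video/patch/audio special tokens from ..constants.
# Sketch of the sample layout it reads (field values here are hypothetical):
#   {
#       "source": "some_dataset",
#       "conversations": [  # or "messages"
#           {"role": "user", "content": "Convert the text to speech.\nHello!"},
#           {"role": "assistant", "content": "..."},
#       ],
#   }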
def preprocess(
    sample,
    tokenizer: transformers.PreTrainedTokenizer,
    image_token_length: int,
    default_system_message: str = "You are a helpful assistant.",
    processor=None,
    is_begin: bool = True,
    max_num_frame: int = 8,
    max_fps: int = 1,
) -> Dict:
    # <|im_start|>system
    # You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
    # <|im_start|>user
    # Hello, how are you?<|im_end|>
    # <|im_start|>assistant
    # I'm doing great. How can I help you today?<|im_end|>
    # <|im_start|>user
    # I'd like to show off how chat templating works!<|im_end|>

    from ..constants import (
        IMG_START_TOKEN,
        IMG_END_TOKEN,
        IMG_CONTEXT_TOKEN,
        VID_START_TOKEN,
        VID_END_TOKEN,
        VID_CONTEXT_TOKEN,
        PATCH_START_TOKEN,
        PATCH_END_TOKEN,
        PATCH_CONTEXT_TOKEN,
        AUD_START_TOKEN,
        AUD_END_TOKEN,
        IMG_TAG_TOKEN,
        VID_TAG_TOKEN,
        AUD_TAG_TOKEN,
        AUD_CONTEXT_TOKEN,
    )

    human_roles = ["user", "human"]
    gpt_roles = ["assistant", "gpt"]
    system_roles = ["system"]

    IMG_CONTEXT_ID = tokenizer(IMG_CONTEXT_TOKEN, add_special_tokens=False).input_ids
    IMG_START_ID = tokenizer(IMG_START_TOKEN, add_special_tokens=False).input_ids
    IMG_END_ID = tokenizer(IMG_END_TOKEN, add_special_tokens=False).input_ids

    VID_CONTEXT_ID = tokenizer(VID_CONTEXT_TOKEN, add_special_tokens=False).input_ids
    VID_START_ID = tokenizer(VID_START_TOKEN, add_special_tokens=False).input_ids
    VID_END_ID = tokenizer(VID_END_TOKEN, add_special_tokens=False).input_ids

    PATCH_CONTEXT_ID = tokenizer(PATCH_CONTEXT_TOKEN, add_special_tokens=False).input_ids
    PATCH_START_ID = tokenizer(PATCH_START_TOKEN, add_special_tokens=False).input_ids
    PATCH_END_ID = tokenizer(PATCH_END_TOKEN, add_special_tokens=False).input_ids

    AUD_CONTEXT_ID = tokenizer(AUD_CONTEXT_TOKEN, add_special_tokens=False).input_ids
    AUD_START_ID = tokenizer(AUD_START_TOKEN, add_special_tokens=False).input_ids
    AUD_END_ID = tokenizer(AUD_END_TOKEN, add_special_tokens=False).input_ids

    IMG_TAG_ID = tokenizer(IMG_TAG_TOKEN, add_special_tokens=False).input_ids
    VID_TAG_ID = tokenizer(VID_TAG_TOKEN, add_special_tokens=False).input_ids
    AUD_TAG_ID = tokenizer(AUD_TAG_TOKEN, add_special_tokens=False).input_ids

    assert len(IMG_CONTEXT_ID) == 1
    assert len(IMG_START_ID) == 1
    assert len(IMG_END_ID) == 1

    assert len(VID_CONTEXT_ID) == 1
    assert len(VID_START_ID) == 1
    assert len(VID_END_ID) == 1

    assert len(PATCH_CONTEXT_ID) == 1
    assert len(PATCH_START_ID) == 1
    assert len(PATCH_END_ID) == 1

    IMG_CONTEXT_ID = IMG_CONTEXT_ID[0]
    IMG_START_ID = IMG_START_ID[0]
    IMG_END_ID = IMG_END_ID[0]

    VID_CONTEXT_ID = VID_CONTEXT_ID[0]
    VID_START_ID = VID_START_ID[0]
    VID_END_ID = VID_END_ID[0]

    PATCH_CONTEXT_ID = PATCH_CONTEXT_ID[0]
    PATCH_START_ID = PATCH_START_ID[0]
    PATCH_END_ID = PATCH_END_ID[0]

    AUD_CONTEXT_ID = AUD_CONTEXT_ID[0]
    AUD_START_ID = AUD_START_ID[0]
    AUD_END_ID = AUD_END_ID[0]

    IMG_TAG_ID = IMG_TAG_ID[0]
    VID_TAG_ID = VID_TAG_ID[0]
    AUD_TAG_ID = AUD_TAG_ID[0]

    BOS_ID = tokenizer.bos_token_id
    EOS_ID = tokenizer.eos_token_id

    IM_START = "<|begin▁of▁sentence|>"
    IM_END = "<|end▁of▁sentence|>"
    USER = "<|User|>"
    ASSISTANT = "<|Assistant|>"

    nl_tokens = tokenizer("\n", add_special_tokens=False).input_ids
    IM_START_IDS = tokenizer(IM_START, add_special_tokens=False).input_ids
    IM_END_IDS = tokenizer(IM_END, add_special_tokens=False).input_ids
    USER_IDS = tokenizer(USER, add_special_tokens=False).input_ids
    ASSISTANT_IDS = tokenizer(ASSISTANT, add_special_tokens=False).input_ids

    assert len(USER_IDS) == 1, USER_IDS
    assert len(ASSISTANT_IDS) == 1, ASSISTANT_IDS
    assert len(IM_END_IDS) == 1, IM_END_IDS
    assert len(IM_START_IDS) == 1, IM_START_IDS

    input_ids, targets = [], []
    images = []
    image_indices = []
    audios = []
    audio_indices = []

    messages = []
    if "conversations" in sample:
        messages = sample["conversations"]
    if len(messages) == 0 and "messages" in sample:
        messages = sample["messages"]
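
    # For TTS-style samples the loop below looks for the user instruction
    # "Convert the text to speech.", extracts the remaining text, and prepends
    # it to the following assistant reply content; the speech-to-text
    # counterpart is kept commented out.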
role == "user": if "Convert the text to speech." in content: add_text = content.replace("Convert the text to speech.\n", "") add_text = add_text.strip() # if "Convert the speech to text." in content: # add_audio = sample["audios"][-1] if role == "assistant" and add_text is not None: sentence["content"] = add_text + content # if role == "assistant" and add_audio is not None: # sentence["content"] = content + "\n