| | import multiprocessing as mp |
| | import pathlib |
| | from typing import Any |
| |
|
| | import datasets |
| | from PIL import Image |
| | import torch |
| | from torch.utils.data import DataLoader |
| | from torch.utils.data import Dataset |
| | from torchvision import transforms |
| |
|
| | from src import config |
| | from src import tokenizer as tk |
| |
|
| |
|
class CaptionDatset(Dataset):
    """Map-style dataset pairing on-disk images with their captions.

    Each record of the wrapped dataset must provide a ``"url"`` field, whose
    last ``/``-separated segment is the image's file name inside ``img_path``,
    and a ``"short_caption"`` field.

    The class name keeps its historical spelling so existing imports work.
    """

    def __init__(self, dataset: datasets.Dataset, img_path: pathlib.Path) -> None:
        """Store the records and the directory that holds the image files.

        Args:
            dataset: Indexable records with ``"url"`` and ``"short_caption"``.
            img_path: Directory containing the image files.
        """
        self.dataset = dataset
        self.img_path = img_path

    def __len__(self) -> int:
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Load record ``idx``'s image as RGB and return it with its caption.

        Returns:
            ``{"image": <RGB PIL image>, "caption": <record's short_caption>}``.

        Raises:
            FileNotFoundError: if the image file is not present in ``img_path``.
        """
        item = self.dataset[idx]
        # Images are stored under the final path segment of their source URL.
        file_name = item["url"].rsplit("/", 1)[-1]
        # ``convert`` loads the pixels and returns an independent copy, so the
        # source file can be closed right away. Without the context manager the
        # handle lingers until garbage collection (and Pillow keeps multi-frame
        # files open after loading), which accumulates in DataLoader workers.
        with Image.open(self.img_path / file_name) as img:
            image = img.convert("RGB")
        return {"image": image, "caption": item["short_caption"]}
| |
|
| |
|
class CollateFn:
    """Collate a list of ``{"image", "caption"}`` samples into one batch dict."""

    def __init__(self, tokenizer: tk.Tokenizer, transform: transforms.Compose):
        """Remember the caption tokenizer and the per-image transform.

        Args:
            tokenizer: Called once per batch with the list of captions; its
                result is unpacked (``**``) into the batch dict, so it must be
                a mapping.
            transform: Applied to every image; its outputs are passed to
                ``torch.stack``, so they must be same-shaped tensors.
        """
        self.tokenizer = tokenizer
        self.transform = transform

    def __call__(self, batch: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
        """Assemble one batch.

        Returns:
            ``{"images": <transformed images stacked on a new leading dim>,
            **<tokenizer output>}``. A tokenizer key named ``"images"`` would
            take precedence.
        """
        images = torch.stack(tuple(self.transform(sample["image"]) for sample in batch))
        text_features = self.tokenizer([sample["caption"] for sample in batch])
        return {"images": images, **text_features}
| |
|
| |
|
def _get_dataloaders(
    train_ds: Dataset,
    valid_ds: Dataset,
    training_config: config.TrainerConfig,
    collate_fn: CollateFn,
) -> tuple[DataLoader, DataLoader]:
    """Wrap the training and validation datasets in DataLoaders.

    Both loaders share ``training_config.batch_size``, ``pin_memory=True``,
    the given ``collate_fn``, and ``mp.cpu_count() // 3`` worker processes
    (0 on machines with fewer than 3 logical CPUs, i.e. loading happens in the
    main process). Training batches are shuffled and a trailing partial batch
    is dropped. Validation keeps dataset order and every sample.

    Returns:
        ``(train_loader, valid_loader)``.
    """
    shared_kwargs = dict(
        batch_size=training_config.batch_size,
        pin_memory=True,
        num_workers=mp.cpu_count() // 3,
        collate_fn=collate_fn,
    )

    def make_loader(dataset: Dataset, training: bool) -> DataLoader:
        # shuffle and drop_last are both on for training, both off for validation.
        return DataLoader(dataset, shuffle=training, drop_last=training, **shared_kwargs)

    return make_loader(train_ds, True), make_loader(valid_ds, False)
| |
|
| |
|
def get_dataset(
    transform: transforms.Compose,
    tokenizer: tk.Tokenizer,
    hyper_parameters: config.TrainerConfig,
    *,
    seed: int = 42,
    test_size: float = 0.1,
) -> tuple[DataLoader, DataLoader]:
    """Load the caption dataset, split it, and return ``(train_loader, valid_loader)``.

    Despite its name, this returns DataLoaders rather than datasets. The
    ``"train"`` split named by ``hyper_parameters._data_config.dataset`` is
    loaded with ``datasets.load_dataset`` and divided into training and
    validation parts. Images are read from ``config.IMAGE_DOWNLOAD_PATH``.

    Args:
        transform: Per-image transform, applied inside the collate function.
        tokenizer: Caption tokenizer, applied inside the collate function.
        hyper_parameters: Trainer config. It names the dataset to load and is
            passed on for loader settings such as the batch size.
        seed: Seed for the train/validation split. The default reproduces the
            split this function always used.
        test_size: Fraction of the records held out for validation.

    Returns:
        ``(train_loader, valid_loader)`` as built by ``_get_dataloaders``.
    """
    # NOTE(review): reads the private ``_data_config`` attribute of the config.
    dataset: datasets.Dataset = datasets.load_dataset(
        hyper_parameters._data_config.dataset, split="train"
    )
    splits = dataset.train_test_split(seed=seed, test_size=test_size)
    train_ds = CaptionDatset(splits["train"], config.IMAGE_DOWNLOAD_PATH)
    valid_ds = CaptionDatset(splits["test"], config.IMAGE_DOWNLOAD_PATH)
    collate_fn = CollateFn(tokenizer, transform)

    return _get_dataloaders(
        train_ds=train_ds,
        valid_ds=valid_ds,
        training_config=hyper_parameters,
        collate_fn=collate_fn,
    )
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: build both loaders, print the tensor shapes of one training
    # batch, then iterate the full training loader once behind a progress bar.
    import os

    from tqdm.auto import tqdm

    # Set before the tokenizer is constructed. Presumably this silences the
    # Hugging Face ``tokenizers`` fork/parallelism warning once DataLoader
    # workers start (confirm against tk.Tokenizer).
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    hyper_parameters = config.TrainerConfig()
    transform = transforms.Compose(
        [
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
        ]
    )
    model_cfg = hyper_parameters._model_config
    tokenizer = tk.Tokenizer(model_cfg.text_model, model_cfg.max_len)
    train_dl, valid_dl = get_dataset(transform, tokenizer, hyper_parameters)

    first_batch = next(iter(train_dl))
    print({name: tensor.shape for name, tensor in first_batch.items()})

    for _ in tqdm(train_dl):
        pass
| |
|