import torch
import torch.nn.functional as F
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import BertTokenizer
import decord
import numpy as np
from tqdm import tqdm

FRAMES = 50
H, W = 128, 128
BATCH_SIZE = 8
TEXT_MAX_LEN = 3000  # note: bert-base-uncased's encoder accepts at most 512 positions

dataset = load_dataset("gaussalgo/webvid-10m", split="train")  # 10M samples
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


class VideoDataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
        self.decord_ctx = decord.cpu(0)  # CPU decoding

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        vr = decord.VideoReader(item["video_path"], ctx=self.decord_ctx)

        # Sample FRAMES evenly spaced frames across the whole clip.
        frame_indices = np.linspace(0, len(vr) - 1, FRAMES, dtype=int)
        video = vr.get_batch(frame_indices).asnumpy()  # (FRAMES, H0, W0, 3) at native resolution
        video = torch.from_numpy(video).permute(3, 0, 1, 2).float()  # (3, FRAMES, H0, W0)

        # Bilinear interpolation on a 4D tensor resizes the last two dims, so each
        # frame of each channel is resized to (H, W).
        video = F.interpolate(video, size=(H, W), mode="bilinear", align_corners=False)
        video = (video / 255.0) * 2 - 1  # scale pixel values to [-1, 1]

        text = tokenizer(
            item["caption"],
            padding="max_length",
            truncation=True,
            max_length=TEXT_MAX_LEN,
            return_tensors="pt",
        ).input_ids.squeeze(0)  # (TEXT_MAX_LEN,)

        return {"video": video, "text": text}


# DataLoader
dataset = VideoDataset(dataset)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
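# A minimal smoke test (an added sketch, not part of the original pipeline):
# pull a single batch and confirm the shapes downstream code would receive.
# Expected, given the constants above:
#   video -> (BATCH_SIZE, 3, FRAMES, H, W), float values in [-1, 1]
#   text  -> (BATCH_SIZE, TEXT_MAX_LEN), int64 token ids
if __name__ == "__main__":
    batch = next(iter(dataloader))
    print(batch["video"].shape, batch["video"].min().item(), batch["video"].max().item())
    print(batch["text"].shape, batch["text"].dtype)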