from src.utils.typing_utils import *
import torch
from .objaverse_part import ObjaversePartDataset, BatchedObjaversePartDataset

# Copied from https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/loader.py
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """DataLoader that keeps its worker processes alive across epochs.

    Wrapping the (batch) sampler in ``_RepeatSampler`` lets one worker pool
    serve every epoch instead of being torn down and re-created each time.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Unset DataLoader's name-mangled __initialized flag so the sampler
        # can be swapped after the parent constructor has validated it.
        self._DataLoader__initialized = False
        if self.batch_sampler is None:
            self.sampler = _RepeatSampler(self.sampler)
        else:
            self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.sampler) if self.batch_sampler is None else len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler(object):
    """Sampler that repeats forever.

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler
        # Expose the wrapped BatchSampler's batch_size / drop_last so code
        # that inspects the batch sampler still sees them.
        if isinstance(self.sampler, torch.utils.data.sampler.BatchSampler):
            self.batch_size = self.sampler.batch_size
            self.drop_last = self.sampler.drop_last

    def __len__(self):
        return len(self.sampler)

    def __iter__(self):
        while True:
            yield from iter(self.sampler)


def yield_forever(iterator: Iterator[Any]):
    """Endlessly re-iterate over `iterator` (an infinite stream when given a re-iterable such as a DataLoader)."""
    while True:
        for x in iterator:
            yield x
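

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): how the pieces in this module
# might be wired together for a step-based training loop. The no-argument
# ObjaversePartDataset() construction below is an assumption for the sketch,
# not the dataset's actual signature.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dataset = ObjaversePartDataset()  # assumed constructor; replace with real arguments
    loader = MultiEpochsDataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
    )
    # yield_forever turns the re-iterable loader into an endless batch stream,
    # convenient when training is driven by a step count rather than epochs.
    batches = yield_forever(loader)
    first_batch = next(batches)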