update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
d868d2e
verified
import logging | |
from typing import Optional, Union | |
from pytorch_ie import PieDataModule | |
from pytorch_ie.core.taskmodule import IterableTaskEncodingDataset, TaskEncodingDataset | |
from torch.utils.data import DataLoader, Sampler | |
from .components.sampler import ImbalancedDatasetSampler | |
logger = logging.getLogger(__name__) | |
class PieDataModuleWithSampler(PieDataModule):
    """A ``PieDataModule`` whose training dataloader can use a custom sampler.

    Args:
        train_sampler: Name of the sampler to attach to the training
            dataloader. ``None`` (the default) attaches no sampler. The only
            name recognized by ``get_train_sampler`` is
            ``"imbalanced_dataset"``.
        dont_shuffle_train: When ``True``, the training dataloader is created
            with ``shuffle=False``.
        **kwargs: Forwarded unchanged to ``PieDataModule.__init__``.
    """

    def __init__(
        self,
        train_sampler: Optional[str] = None,
        dont_shuffle_train: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.dont_shuffle_train = dont_shuffle_train
        self.train_sampler_name = train_sampler

    def get_train_sampler(
        self,
        dataset: Union[TaskEncodingDataset, IterableTaskEncodingDataset],
    ) -> Optional[Sampler]:
        """Create the sampler selected by ``train_sampler_name`` for ``dataset``.

        Returns:
            ``None`` if no sampler name was configured, otherwise an
            ``ImbalancedDatasetSampler`` built over ``dataset``.

        Raises:
            ValueError: If the configured sampler name is not recognized.
        """
        if self.train_sampler_name is None:
            return None

        if self.train_sampler_name == "imbalanced_dataset":

            def first_target_of_each(encodings):
                # For now, this works only with targets that have a single
                # entry: the first entry of each targets sequence is the label.
                return [encoding.targets[0] for encoding in encodings]

            return ImbalancedDatasetSampler(dataset, callback_get_label=first_target_of_each)

        raise ValueError(f"unknown sampler name: {self.train_sampler_name}")

    def train_dataloader(self) -> DataLoader:
        """Build the ``DataLoader`` for the training split.

        Shuffling is turned off when any of the following holds:
        ``dont_shuffle_train`` was set, the split is an
        ``IterableTaskEncodingDataset`` (streamed data), or a sampler is in
        use. A warning is logged whenever shuffling is turned off.
        """
        train_data = self.data_split(self.train_split)
        train_sampler = self.get_train_sampler(dataset=train_data)

        # Shuffle only if nothing forbids it: no explicit opt-out, not a
        # streamed dataset, and no sampler in use.
        do_shuffle = (
            not self.dont_shuffle_train
            and not isinstance(train_data, IterableTaskEncodingDataset)
            and train_sampler is None
        )
        if not do_shuffle:
            logger.warning("not shuffling train dataloader")

        return DataLoader(
            dataset=train_data,
            sampler=train_sampler,
            collate_fn=self.taskmodule.collate,
            shuffle=do_shuffle,
            **self.dataloader_kwargs,
        )