ScientificArgumentRecommender/src/datamodules/datamodule_with_sampler.py
import logging
from typing import Optional, Union

from pytorch_ie import PieDataModule
from pytorch_ie.core.taskmodule import IterableTaskEncodingDataset, TaskEncodingDataset
from torch.utils.data import DataLoader, Sampler

from .components.sampler import ImbalancedDatasetSampler

logger = logging.getLogger(__name__)

class PieDataModuleWithSampler(PieDataModule):
    """A PieDataModule that supports an optional sampler for the train dataloader."""

    def __init__(
        self,
        train_sampler: Optional[str] = None,
        dont_shuffle_train: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.train_sampler_name = train_sampler
        self.dont_shuffle_train = dont_shuffle_train

    def get_train_sampler(
        self,
        dataset: Union[TaskEncodingDataset, IterableTaskEncodingDataset],
    ) -> Optional[Sampler]:
        if self.train_sampler_name is None:
            return None
        elif self.train_sampler_name == "imbalanced_dataset":
            # for now, this works only with task encodings whose targets have a single entry
            return ImbalancedDatasetSampler(
                dataset, callback_get_label=lambda ds: [x.targets[0] for x in ds]
            )
        else:
            raise ValueError(f"unknown sampler name: {self.train_sampler_name}")

    def train_dataloader(self) -> DataLoader:
        ds = self.data_split(self.train_split)
        sampler = self.get_train_sampler(dataset=ds)
        # don't shuffle if dont_shuffle_train is set explicitly, if the dataset is
        # streamed (IterableTaskEncodingDataset), or if a sampler is used
        shuffle = not (
            self.dont_shuffle_train
            or isinstance(ds, IterableTaskEncodingDataset)
            or sampler is not None
        )
        if not shuffle:
            logger.warning("not shuffling train dataloader")
        return DataLoader(
            dataset=ds,
            sampler=sampler,
            collate_fn=self.taskmodule.collate,
            shuffle=shuffle,
            **self.dataloader_kwargs,
        )