update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
d868d2e
verified
from typing import Dict, Iterable, Optional, Sequence, Type, TypeVar | |
from pie_datasets import Dataset, DatasetDict, IterableDataset | |
from pytorch_ie.core import Document | |
from src.serializer.interface import DocumentSerializer | |
from src.utils.logging_utils import get_pylogger | |
# Module-level logger for this serializer module.
log = get_pylogger(__name__)
# Type variable for concrete document classes; bound to pytorch_ie's Document.
D = TypeVar("D", bound=Document)
def as_json_lines(file_name: str) -> bool:
    """Decide from the file extension whether a file holds JSON Lines or plain JSON.

    The extension check is case-insensitive.

    Args:
        file_name: Name or path of the file.

    Returns:
        True for a ``.jsonl`` extension, False for a ``.json`` extension.

    Raises:
        ValueError: If the name ends with neither ``.jsonl`` nor ``.json``.
            ``ValueError`` subclasses ``Exception``, so existing
            ``except Exception`` handlers still catch it.
    """
    lowered = file_name.lower()
    if lowered.endswith(".jsonl"):
        return True
    if lowered.endswith(".json"):
        return False
    raise ValueError(f"unknown file extension: {file_name}")
class JsonSerializer(DocumentSerializer):
    """Write documents to, and read them from, JSON files via pie_datasets.

    Keyword arguments given to the constructor are stored as defaults and merged
    into every ``read_with_defaults`` / ``write_with_defaults`` / ``__call__``
    invocation; arguments passed explicitly to those calls take precedence.
    """

    def __init__(self, **kwargs):
        # Defaults merged into read/write calls.
        # NOTE(review): the same dict feeds both ``read`` and ``write``, whose
        # signatures differ (e.g. only ``read`` accepts ``document_type``), so a
        # default accepted by only one of them makes the other raise TypeError.
        self.default_kwargs = kwargs

    @classmethod
    def write(
        cls,
        documents: Iterable[Document],
        path: str,
        split: str = "train",
        append: bool = False,
    ) -> Dict[str, str]:
        """Write ``documents`` as split ``split`` under ``path``.

        Args:
            documents: Documents to write. A pie_datasets ``Dataset`` or
                ``IterableDataset`` is used as is; any other ``Sequence`` is
                converted with ``Dataset.from_documents``, and any other iterable
                with ``IterableDataset.from_documents``.
            path: Output location, forwarded to ``DatasetDict.to_json``.
            split: Name of the split to write.
            append: If True, write in append mode ("a"); otherwise overwrite ("w").

        Returns:
            A dict with the ``path`` and ``split`` that were written.
        """
        if not isinstance(documents, (Dataset, IterableDataset)):
            if isinstance(documents, Sequence):
                documents = Dataset.from_documents(documents)
            else:
                # Not indexable: wrap it as an iterable (streaming) dataset.
                documents = IterableDataset.from_documents(documents)
        dataset_dict = DatasetDict({split: documents})
        dataset_dict.to_json(path=path, mode="a" if append else "w")
        return {"path": path, "split": split}

    @classmethod
    def read(
        cls,
        path: str,
        document_type: Optional[Type[D]] = None,
        split: Optional[str] = None,
    ) -> Dataset[Document]:
        """Load documents previously written under ``path``.

        Args:
            path: Directory, forwarded as ``data_dir`` to ``DatasetDict.from_json``.
            document_type: Document class to load into; forwarded to
                ``DatasetDict.from_json``.
            split: Split to return. If None, the loaded data must contain exactly
                one split, which is then returned.

        Returns:
            The requested split, or the only split when ``split`` is None.

        Raises:
            ValueError: If ``split`` is None and the loaded data contains no split
                or more than one split.
        """
        dataset_dict = DatasetDict.from_json(
            data_dir=path, document_type=document_type, split=split
        )
        if split is not None:
            return dataset_dict[split]
        split_names = list(dataset_dict.keys())
        if len(split_names) == 1:
            return dataset_dict[split_names[0]]
        if not split_names:
            # Previously this case fell through to the "multiple splits" message.
            raise ValueError(f"no splits found in dataset_dict loaded from: {path}")
        raise ValueError(f"multiple splits found in dataset_dict: {split_names}")

    def read_with_defaults(self, **kwargs) -> Sequence[D]:
        """Call ``read`` with the constructor defaults overridden by ``kwargs``."""
        all_kwargs = {**self.default_kwargs, **kwargs}
        return self.read(**all_kwargs)

    def write_with_defaults(self, **kwargs) -> Dict[str, str]:
        """Call ``write`` with the constructor defaults overridden by ``kwargs``."""
        all_kwargs = {**self.default_kwargs, **kwargs}
        return self.write(**all_kwargs)

    def __call__(
        self, documents: Iterable[Document], append: bool = False, **kwargs
    ) -> Dict[str, str]:
        """Write ``documents`` using the constructor defaults; see ``write``."""
        return self.write_with_defaults(documents=documents, append=append, **kwargs)