# Updated from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
# (commit d868d2e, verified)
import json | |
import logging | |
import os | |
import shutil | |
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple | |
from datasets import Dataset as HFDataset | |
from langchain_core.documents import Document as LCDocument | |
from pie_datasets import Dataset, DatasetDict, concatenate_datasets | |
from pytorch_ie.documents import TextBasedDocument | |
from .pie_document_store import PieDocumentStore | |
logger = logging.getLogger(__name__)


class DatasetsPieDocumentStore(PieDocumentStore):
    """PIE Document store that uses Huggingface Datasets as the backend.

    All documents are appended to a single PIE ``Dataset`` (``self._data``).
    ``self._keys`` maps every store key to its row index in that dataset and
    ``self._metadata`` holds the non-empty metadata per key. Deleting or
    overwriting a key only updates these mappings; the no longer referenced
    rows stay in ``self._data`` until ``_purge_invalid_entries`` removes them
    (which happens when the store is saved).
    """

    def __init__(self) -> None:
        self._data: Optional[Dataset] = None
        # keys map to indices (rows) in self._data
        self._keys: Dict[str, int] = {}
        # only non-empty metadata is kept here
        self._metadata: Dict[str, Any] = {}

    def __len__(self) -> int:
        """Return the number of keys currently in the store."""
        return len(self._keys)

    def _get_pie_docs_by_indices(self, indices: Iterable[int]) -> Sequence[TextBasedDocument]:
        """Select the rows at ``indices`` (in the given order) from the backing dataset.

        Returns an empty list if no data has been added yet.
        """
        if self._data is None:
            return []
        return self._data.apply_hf_func(func=HFDataset.select, indices=indices)

    def _append_to_data(
        self,
        dataset: Dataset,
        keys: Sequence[str],
        metadatas: Sequence[Dict[str, Any]],
    ) -> None:
        """Append ``dataset`` to the backing data and register ``keys``/``metadatas`` for its rows.

        Existing keys are overwritten: they are re-pointed to the new rows (the
        old rows become orphans) and their previous metadata is replaced or
        removed. If a key occurs multiple times, its last occurrence wins.
        """
        if self._data is None:
            idx_start = 0
            self._data = dataset
        else:
            idx_start = len(self._data)
            self._data = concatenate_datasets([self._data, dataset], clear_metadata=False)
        self._keys.update(dict(zip(keys, range(idx_start, len(self._data)))))
        for key, metadata in zip(keys, metadatas):
            if metadata:
                self._metadata[key] = metadata
            else:
                # do not keep metadata of a previously stored (now overwritten) document
                self._metadata.pop(key, None)

    def mget(self, keys: Sequence[str]) -> List[LCDocument]:
        """Return the (wrapped) documents stored under ``keys``.

        Keys that are not in the store are skipped, so the result may be shorter
        than ``keys``. NOTE(review): if PieDocumentStore follows LangChain's
        ``BaseStore.mget`` contract, missing keys should yield ``None`` instead
        of being dropped — confirm which behavior callers rely on.
        """
        if self._data is None or len(keys) == 0:
            return []
        keys_in_data = [key for key in keys if key in self._keys]
        indices = [self._keys[key] for key in keys_in_data]
        dataset = self._get_pie_docs_by_indices(indices)
        metadatas = [self._metadata.get(key, {}) for key in keys_in_data]
        return [self.wrap(pie_doc, **metadata) for pie_doc, metadata in zip(dataset, metadatas)]

    def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
        """Store the given ``(key, document)`` pairs, overwriting existing keys."""
        if len(items) == 0:
            return
        keys, new_docs = zip(*items)
        pie_docs, metadatas = zip(*[self.unwrap_with_metadata(doc) for doc in new_docs])
        if self._data is None:
            dataset = Dataset.from_documents(pie_docs)
        else:
            # we pass the features to the new dataset to mitigate issues caused by
            # slightly different inferred features
            dataset = Dataset.from_documents(pie_docs, features=self._data.features)
        self._append_to_data(dataset, keys=keys, metadatas=metadatas)

    def add_pie_dataset(
        self,
        dataset: Dataset,
        keys: Optional[List[str]] = None,
        metadatas: Optional[List[Dict[str, Any]]] = None,
    ) -> None:
        """Add all documents of a PIE dataset to the store.

        Args:
            dataset: the documents to add.
            keys: one unique, non-None key per document; defaults to the document ids.
            metadatas: one metadata dict per document; defaults to empty dicts.

        Raises:
            ValueError: if keys are None or not unique, or if the lengths of keys,
                dataset and metadatas differ.
        """
        if len(dataset) == 0:
            return
        if keys is None:
            keys = [doc.id for doc in dataset]
        # check for None first: several None keys would otherwise be reported as duplicates
        if None in keys:
            raise ValueError("Keys must not be None.")
        if len(keys) != len(set(keys)):
            raise ValueError("Keys must be unique.")
        if metadatas is None:
            metadatas = [{} for _ in range(len(dataset))]
        if len(keys) != len(dataset) or len(keys) != len(metadatas):
            raise ValueError("Keys, dataset and metadatas must have the same length.")
        self._append_to_data(dataset, keys=keys, metadatas=metadatas)

    def mdelete(self, keys: Sequence[str]) -> None:
        """Remove ``keys`` (and their metadata); unknown keys are ignored.

        The underlying rows are only dropped from ``self._data`` on the next purge.
        """
        for key in keys:
            if self._keys.pop(key, None) is not None:
                self._metadata.pop(key, None)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Return a lazy iterator over the stored keys, optionally restricted to ``prefix``."""
        return (key for key in self._keys if prefix is None or key.startswith(prefix))

    def _purge_invalid_entries(self) -> None:
        """Drop rows of the backing dataset that are no longer referenced by any key.

        Afterwards, row ``i`` of ``self._data`` belongs to the ``i``-th key of
        ``self._keys``.
        """
        # Row indices are never reused, so equal counts mean every row is referenced.
        if self._data is None or len(self._keys) == len(self._data):
            return
        self._data = self._get_pie_docs_by_indices(list(self._keys.values()))
        # The rows were compacted (in key order), so the key -> row mapping must be
        # rebuilt; otherwise the keys would still point at the old row positions.
        self._keys = {key: new_idx for new_idx, key in enumerate(self._keys)}

    def _save_to_directory(self, path: str, batch_size: Optional[int] = None, **kwargs) -> None:
        """Write documents, document ids and metadata to ``path``.

        Layout: ``pie_documents/`` (the "train" split as JSON), ``doc_ids.json``
        and ``metadata.json`` (both lists aligned with the saved rows). An existing
        ``pie_documents`` directory is removed first. ``batch_size`` and ``kwargs``
        are accepted but not used here.
        """
        self._purge_invalid_entries()
        if len(self) == 0:
            logger.warning("No documents to save.")
            return

        # after purging, the dataset rows are in the same order as the keys
        all_doc_ids = list(self._keys)
        all_metadatas: List[Dict[str, Any]] = [self._metadata.get(key, {}) for key in all_doc_ids]

        pie_documents_path = os.path.join(path, "pie_documents")
        if os.path.exists(pie_documents_path):
            # remove existing directory
            logger.warning(f"Removing existing directory: {pie_documents_path}")
            shutil.rmtree(pie_documents_path)
        os.makedirs(pie_documents_path, exist_ok=True)
        DatasetDict({"train": self._data}).to_json(pie_documents_path, mode="w")

        doc_ids_path = os.path.join(path, "doc_ids.json")
        with open(doc_ids_path, "w", encoding="utf-8") as f:
            json.dump(all_doc_ids, f)

        metadata_path = os.path.join(path, "metadata.json")
        with open(metadata_path, "w", encoding="utf-8") as f:
            json.dump(all_metadatas, f)

    def _load_from_directory(self, path: str, **kwargs) -> None:
        """Load documents previously written by ``_save_to_directory`` and add them to the store.

        Missing ``doc_ids.json`` / ``metadata.json`` fall back to document ids /
        empty metadata (with a warning); a missing ``pie_documents`` directory
        loads nothing.
        """
        doc_ids_path = os.path.join(path, "doc_ids.json")
        if os.path.exists(doc_ids_path):
            with open(doc_ids_path, "r", encoding="utf-8") as f:
                all_doc_ids = json.load(f)
        else:
            logger.warning(f"File {doc_ids_path} does not exist, don't load any document ids.")
            all_doc_ids = None

        metadata_path = os.path.join(path, "metadata.json")
        if os.path.exists(metadata_path):
            with open(metadata_path, "r", encoding="utf-8") as f:
                all_metadata = json.load(f)
        else:
            logger.warning(f"File {metadata_path} does not exist, don't load any metadata.")
            all_metadata = None

        pie_documents_path = os.path.join(path, "pie_documents")
        if not os.path.exists(pie_documents_path):
            logger.warning(
                f"Directory {pie_documents_path} does not exist, don't load any documents."
            )
            return None

        # If we have a dataset already loaded, we use its features to load the new dataset
        # This is to mitigate issues caused by slightly different inferred features.
        features = self._data.features if self._data is not None else None
        pie_dataset = DatasetDict.from_json(data_dir=pie_documents_path, features=features)
        pie_docs = pie_dataset["train"]
        self.add_pie_dataset(pie_docs, keys=all_doc_ids, metadatas=all_metadata)
        logger.info(f"Loaded {len(pie_docs)} documents from {path} into docstore")