import json
import logging
import os
import shutil
from itertools import islice
from typing import Iterator, List, Optional, Sequence, Tuple

from langchain.storage import create_kv_docstore
from langchain_core.documents import Document as LCDocument
from langchain_core.stores import BaseStore, ByteStore
from pie_datasets import Dataset, DatasetDict

from .pie_document_store import PieDocumentStore

logger = logging.getLogger(__name__)
class BasicPieDocumentStore(PieDocumentStore):
    """PIE document store that uses a LangChain key-value client to store and retrieve
    documents."""

    def __init__(
        self,
        client: Optional[BaseStore[str, LCDocument]] = None,
        byte_store: Optional[ByteStore] = None,
    ):
        if byte_store is not None:
            # Wrap the raw byte store in a document-aware key-value store.
            client = create_kv_docstore(byte_store)
        elif client is None:
            raise ValueError("You must pass either a `client` or a `byte_store` parameter.")
        self.client = client
    def mget(self, keys: Sequence[str]) -> List[LCDocument]:
        """Get the documents stored under the given keys."""
        return self.client.mget(keys)

    def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
        """Store the given (key, document) pairs."""
        self.client.mset(items)

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the documents stored under the given keys."""
        self.client.mdelete(keys)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Yield all stored keys, optionally restricted to a given prefix."""
        return self.client.yield_keys(prefix=prefix)
    def _save_to_directory(self, path: str, batch_size: Optional[int] = None, **kwargs) -> None:
        all_doc_ids = []
        all_metadata = []
        pie_documents_path = os.path.join(path, "pie_documents")
        if os.path.exists(pie_documents_path):
            # Remove the existing directory so that stale batches from a previous
            # save do not get mixed into the appended output.
            logger.warning(f"Removing existing directory: {pie_documents_path}")
            shutil.rmtree(pie_documents_path)
        os.makedirs(pie_documents_path, exist_ok=True)
        doc_ids_iter = iter(self.client.yield_keys())
        mode = "w"
        # Process the documents in batches to keep memory usage bounded.
        while batch_doc_ids := list(islice(doc_ids_iter, batch_size or 1000)):
            all_doc_ids.extend(batch_doc_ids)
            docs = self.client.mget(batch_doc_ids)
            pie_docs = []
            for doc in docs:
                # Split each wrapped document into its PIE payload and the
                # remaining (non-PIE) metadata.
                pie_doc = doc.metadata[self.METADATA_KEY_PIE_DOCUMENT]
                pie_docs.append(pie_doc)
                all_metadata.append(
                    {k: v for k, v in doc.metadata.items() if k != self.METADATA_KEY_PIE_DOCUMENT}
                )
            pie_dataset = Dataset.from_documents(pie_docs)
            DatasetDict({"train": pie_dataset}).to_json(path=pie_documents_path, mode=mode)
            mode = "a"  # append after the first batch
        if len(all_doc_ids) > 0:
            doc_ids_path = os.path.join(path, "doc_ids.json")
            with open(doc_ids_path, "w") as f:
                json.dump(all_doc_ids, f)
        if len(all_metadata) > 0:
            metadata_path = os.path.join(path, "metadata.json")
            with open(metadata_path, "w") as f:
                json.dump(all_metadata, f)
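
    # On-disk layout written by `_save_to_directory` and read back by
    # `_load_from_directory` (the names are fixed by the code in this class):
    #
    #   <path>/
    #       pie_documents/   # PIE dataset serialized as JSON (single "train" split)
    #       doc_ids.json     # list of stored document ids, in iteration order
    #       metadata.json    # per-document metadata with the PIE payload stripped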
    def _load_from_directory(self, path: str, **kwargs) -> None:
        pie_documents_path = os.path.join(path, "pie_documents")
        if not os.path.exists(pie_documents_path):
            logger.warning(
                f"Directory {pie_documents_path} does not exist, not loading any documents."
            )
            return None
        pie_dataset = DatasetDict.from_json(data_dir=pie_documents_path)
        pie_docs = pie_dataset["train"]
        metadata_path = os.path.join(path, "metadata.json")
        if os.path.exists(metadata_path):
            with open(metadata_path, "r") as f:
                all_metadata = json.load(f)
        else:
            logger.warning(f"File {metadata_path} does not exist, not loading any metadata.")
            all_metadata = [{} for _ in pie_docs]
        # Re-wrap each PIE document together with its stored metadata.
        docs = [
            self.wrap(pie_doc, **metadata) for pie_doc, metadata in zip(pie_docs, all_metadata)
        ]
        doc_ids_path = os.path.join(path, "doc_ids.json")
        if os.path.exists(doc_ids_path):
            with open(doc_ids_path, "r") as f:
                all_doc_ids = json.load(f)
        else:
            logger.warning(
                f"File {doc_ids_path} does not exist, falling back to the documents' own ids."
            )
            all_doc_ids = [doc.id for doc in pie_docs]
        # `mset` expects a sequence, so materialize the zip iterator first.
        self.client.mset(list(zip(all_doc_ids, docs)))
        logger.info(f"Loaded {len(docs)} documents from {path} into docstore")
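
if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the store itself).
    # It assumes langchain_core's InMemoryByteStore as the backing byte store;
    # any other ByteStore implementation can be substituted. Note that documents
    # intended for `_save_to_directory` would normally be created via
    # `self.wrap(...)` so that METADATA_KEY_PIE_DOCUMENT is populated; a plain
    # LCDocument suffices for the key-value round trip shown here.
    from langchain_core.stores import InMemoryByteStore

    store = BasicPieDocumentStore(byte_store=InMemoryByteStore())
    store.mset([("doc-1", LCDocument(page_content="hello world"))])
    print(store.mget(["doc-1"]))  # the stored document, retrieved by key
    print(list(store.yield_keys()))  # -> ["doc-1"]
    store.mdelete(["doc-1"])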