import numpy as np
from numpy.typing import NDArray
from typing import Optional, List
from abc import ABC, abstractmethod
from .utils import split_text


class BaseContextPartitioner(ABC):
    """
    A base class for partitioning a context into sources.

    Attributes:
        context (str):
            The context to partition.

    Methods:
        num_sources(self) -> int:
            Property. The number of sources within the context.
        split_context(self) -> None:
            Split the context into sources.
        get_source(self, index: int) -> str:
            Get a representation of the source corresponding to a given index.
        get_context(self, mask: Optional[NDArray] = None) -> str:
            Get a version of the context ablated according to the given mask.
        sources(self) -> List[str]:
            Property. A list of all sources within the context.
    """

    def __init__(self, context: str) -> None:
        self.context = context

    @property
    @abstractmethod
    def num_sources(self) -> int:
        """The number of sources."""

    @abstractmethod
    def split_context(self) -> None:
        """Split the context into sources."""

    @abstractmethod
    def get_source(self, index: int) -> str:
        """Get a representation of the source corresponding to a given index."""

    @abstractmethod
    def get_context(self, mask: Optional[NDArray] = None) -> str:
        """Get a version of the context ablated according to the given mask."""

    @property
    def sources(self) -> List[str]:
        """A list of all sources."""
        return [self.get_source(i) for i in range(self.num_sources)]


class SimpleContextPartitioner(BaseContextPartitioner):
    """
    A simple context partitioner that splits the context into sources based
    on a separator.

    The split is delegated to `split_text` and performed lazily: the first
    access to `parts` or `separators` triggers `split_context`, whose results
    are kept in `self._cache`.

    Attributes:
        context (str):
            The context to partition.
        source_type (str):
            Forwarded to `split_text` as its second argument.
    """

    def __init__(self, context: str, source_type: str = "sentence") -> None:
        super().__init__(context)
        self.source_type = source_type
        # Filled by split_context() with the keys "parts" and "separators".
        self._cache = {}

    def split_context(self) -> None:
        """Split text into parts and cache the parts and separators."""
        # split_text returns three values; the third is not used here.
        parts, separators, _ = split_text(self.context, self.source_type)
        self._cache["parts"] = parts
        self._cache["separators"] = separators

    @property
    def parts(self):
        """The cached parts, splitting the context first if needed."""
        if self._cache.get("parts") is None:
            self.split_context()
        return self._cache["parts"]

    @property
    def separators(self):
        """The cached separators, splitting the context first if needed."""
        if self._cache.get("separators") is None:
            self.split_context()
        return self._cache["separators"]

    @property
    def num_sources(self) -> int:
        """The number of parts the context was split into."""
        return len(self.parts)

    def get_source(self, index: int) -> str:
        """Return the part at position `index`."""
        return self.parts[index]

    def get_context(self, mask: Optional[NDArray] = None) -> str:
        """
        Get a version of the context ablated according to the given mask.

        Arguments:
            mask (Optional[NDArray]):
                Selects which sources to keep, using NumPy indexing over the
                source positions (typically a boolean array with one entry
                per source). If None, every source is kept.

        Returns:
            str:
                The kept parts concatenated in order. Each kept part except
                the first is preceded by the separator stored at its own
                position; the first kept part's separator is omitted.

        Raises:
            IndexError:
                Raised by NumPy if `mask` cannot index an array of length
                `num_sources` (e.g. a boolean mask of the wrong length).
        """
        parts = self.parts
        separators = self.separators

        if mask is None:
            kept = range(len(parts))
        else:
            # Apply the mask to the positions rather than to the strings.
            # Wrapping the strings in a NumPy array would build a fixed-width
            # unicode array on every call, which is sized by the longest
            # element and drops trailing NUL characters.
            kept = np.arange(len(parts))[mask].tolist()

        pieces = []
        for position, index in enumerate(kept):
            if position > 0:
                pieces.append(separators[index])
            pieces.append(parts[index])
        return "".join(pieces)