from __future__ import annotations

import json
import logging
import re
from importlib import import_module
from importlib.metadata import metadata
from typing import TYPE_CHECKING, Any, Protocol, cast

import safetensors
from joblib import Parallel
from tokenizers import Tokenizer
from tqdm import tqdm

if TYPE_CHECKING:
    from collections.abc import Iterator
    from pathlib import Path

    import numpy as np

logger = logging.getLogger(__name__)


class ProgressParallel(Parallel):
    """A drop-in replacement for joblib.Parallel that shows a tqdm progress bar."""

    def __init__(self, use_tqdm: bool = True, total: int | None = None, *args: Any, **kwargs: Any) -> None:
        """
        Initialize the ProgressParallel object.

        :param use_tqdm: Whether to show the progress bar.
        :param total: Total number of tasks (batches) you expect to process. If None,
            it updates the total dynamically to the number of dispatched tasks.
        :param *args: Additional arguments to pass to `Parallel.__init__`.
        :param **kwargs: Additional keyword arguments to pass to `Parallel.__init__`.
        """
        self._use_tqdm = use_tqdm
        self._total = total
        super().__init__(*args, **kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Create a tqdm context around the parallel call."""
        with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar:
            return super().__call__(*args, **kwargs)

    def print_progress(self) -> None:
        """Hook called by joblib as tasks complete. We update the tqdm bar here."""
        if self._total is None:
            # No fixed total was given, so grow the bar to match dispatched tasks.
            self._pbar.total = self.n_dispatched_tasks
        self._pbar.n = self.n_completed_tasks
        self._pbar.refresh()
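
# Usage sketch (illustrative, not part of the original module): `delayed` is
# joblib's standard wrapper and the squaring workload is a toy stand-in.
#
#     from joblib import delayed
#
#     squares = ProgressParallel(n_jobs=2, total=100)(
#         delayed(pow)(i, 2) for i in range(100)
#     )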


class SafeOpenProtocol(Protocol):
    """Protocol describing the object returned by safetensors.safe_open."""

    def get_tensor(self, key: str) -> np.ndarray:
        """Get a tensor by name."""
        ...


# Map PyPI distribution names to their importable module names.
_MODULE_MAP = (("scikit-learn", "sklearn"),)
# Matches version-specifier operators such as `==`, `>=`, and `!=`.
_DIVIDERS = re.compile(r"[=<>!]+")


def get_package_extras(package: str, extra: str) -> Iterator[str]:
    """Yield the names of the dependencies pulled in by an extra of the package."""
    try:
        message = metadata(package)
    except Exception as e:
        # Metadata may be unavailable, e.g. when the package is not installed.
        logger.debug(f"Could not retrieve metadata for package '{package}': {e}")
        return

    all_packages = message.get_all("Requires-Dist") or []
    for requirement in all_packages:
        name, *rest = requirement.split(";", maxsplit=1)
        if rest:
            # Parse the environment marker, e.g. `extra == "distill"`.
            found_extra = rest[0].split("==")[-1].strip(" \"'")
            if found_extra == extra:
                # Keep only the package name, stripping any version specifier.
                prefix, *_ = _DIVIDERS.split(name)
                yield prefix.strip()
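
# Example (hypothetical metadata): given a Requires-Dist entry such as
# 'torch>=2.0; extra == "distill"', get_package_extras("model2vec", "distill")
# would yield "torch", with the version specifier stripped by _DIVIDERS.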


def importable(module: str, extra: str) -> None:
    """Check that a module is importable, raising an informative error if it is not."""
    module = dict(_MODULE_MAP).get(module, module)
    try:
        import_module(module)
    except ImportError:
        msg = f"`{module}` is required. Please reinstall model2vec with the `{extra}` extra: `pip install model2vec[{extra}]`"
        raise ImportError(msg)
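
# Usage sketch (the "distill" extra name is an assumption about this project's
# packaging): check that every dependency declared by an extra can be imported.
#
#     for dependency in get_package_extras("model2vec", "distill"):
#         importable(dependency, "distill")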


def setup_logging() -> None:
    """Simple logging setup."""
    from rich.logging import RichHandler

    logging.basicConfig(
        level="INFO",
        format="%(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[RichHandler(rich_tracebacks=True)],
    )


def load_local_model(folder: Path) -> tuple[np.ndarray, Tokenizer, dict[str, str]]:
    """Load a local model from a folder containing model.safetensors, tokenizer.json, and config.json."""
    embeddings_path = folder / "model.safetensors"
    tokenizer_path = folder / "tokenizer.json"
    config_path = folder / "config.json"

    opened_tensor_file = cast("SafeOpenProtocol", safetensors.safe_open(embeddings_path, framework="numpy"))
    embeddings = opened_tensor_file.get_tensor("embeddings")

    if config_path.exists():
        # Use a context manager so the file handle is closed deterministically.
        with open(config_path) as config_file:
            config = json.load(config_file)
    else:
        config = {}

    tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path))

    if len(tokenizer.get_vocab()) != len(embeddings):
        logger.warning(
            f"Number of tokens does not match number of embeddings: `{len(tokenizer.get_vocab())}` vs `{len(embeddings)}`"
        )

    return embeddings, tokenizer, config
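
# Usage sketch (the model folder path is hypothetical):
#
#     from pathlib import Path
#
#     embeddings, tokenizer, config = load_local_model(Path("path/to/model"))
#     token_ids = tokenizer.encode("hello world").ids
#     vectors = embeddings[token_ids]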