# Author: Sarthak
# chore: moved model2vec as an internal package
# commit: 473c3a0
from __future__ import annotations
import json
import logging
import re
from importlib import import_module
from importlib.metadata import metadata
from typing import TYPE_CHECKING, Any, Protocol, cast
import safetensors
from joblib import Parallel
from tokenizers import Tokenizer
from tqdm import tqdm
if TYPE_CHECKING:
from collections.abc import Iterator
from pathlib import Path
import numpy as np
# Module-level logger named after this module, per the stdlib logging convention.
logger = logging.getLogger(__name__)
class ProgressParallel(Parallel):
    """A drop-in replacement for joblib.Parallel that shows a tqdm progress bar."""

    def __init__(self, use_tqdm: bool = True, total: int | None = None, *args: Any, **kwargs: Any) -> None:
        """
        Initialize the ProgressParallel object.

        :param use_tqdm: Whether to show the progress bar.
        :param total: Total number of tasks (batches) you expect to process. If None,
            it updates the total dynamically to the number of dispatched tasks.
        :param *args: Additional arguments to pass to `Parallel.__init__`.
        :param **kwargs: Additional keyword arguments to pass to `Parallel.__init__`.
        """
        self._use_tqdm = use_tqdm
        self._total = total
        super().__init__(*args, **kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run the parallel job inside a tqdm context so progress can be displayed."""
        # Binding the bar via `as self._pbar` stores it on the instance, so the
        # `print_progress` hook (invoked by joblib during execution) can update it.
        # (The original had a redundant `self._pbar = self._pbar` no-op here.)
        with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar:
            return super().__call__(*args, **kwargs)

    def print_progress(self) -> None:
        """Hook called by joblib as tasks complete. We update the tqdm bar here."""
        if self._total is None:
            # If no fixed total was given, we dynamically set the total
            # to the number of tasks dispatched so far.
            self._pbar.total = self.n_dispatched_tasks
        # Move the bar to the number of completed tasks
        self._pbar.n = self.n_completed_tasks
        self._pbar.refresh()
class SafeOpenProtocol(Protocol):
    """Structural type for the handle returned by `safetensors.safe_open`.

    Declared here because safetensors does not export a typed interface for it.
    """

    def get_tensor(self, key: str) -> np.ndarray:
        """Return the tensor stored under `key`."""
        ...  # pragma: no cover
# Maps PyPI distribution names to their importable module names.
_MODULE_MAP = (("scikit-learn", "sklearn"),)
# Matches version-constraint operators so they can be stripped from requirement strings.
_DIVIDERS = re.compile(r"[=<>!]+")


def get_package_extras(package: str, extra: str) -> Iterator[str]:
    """Yield the requirement names that `package` declares for the given extra.

    :param package: The installed distribution whose metadata is inspected.
    :param extra: The extra (optional dependency group) to filter on.
    :return: An iterator over bare requirement names belonging to `extra`.
    """
    try:
        message = metadata(package)
    except Exception as e:
        # For local packages without metadata, yield nothing.
        # This allows the package to work without installed metadata.
        # (In a generator a bare `return` is the correct way to stop; the
        # original `return iter([])` discarded its value silently.)
        logger.debug(f"Could not retrieve metadata for package '{package}': {e}")
        return

    for requirement in metadata_requires(message):
        # A requirement looks like: `name (>=1.0) ; extra == "group"`.
        name, *markers = requirement.split(";", maxsplit=1)
        if not markers:
            continue
        # Extract and clean the quoted extra name from the environment marker.
        found_extra = markers[0].split("==")[-1].strip(" \"'")
        if found_extra == extra:
            # Strip any version constraints, keeping only the bare name.
            prefix, *_ = _DIVIDERS.split(name)
            yield prefix.strip()


def metadata_requires(message: Any) -> list[str]:
    """Return the Requires-Dist entries of a metadata message, or [] if absent."""
    return message.get_all("Requires-Dist") or []
def importable(module: str, extra: str) -> None:
    """Check if a module is importable, raising an informative error if not.

    :param module: Distribution or module name to check; distribution names are
        translated to their importable module names via `_MODULE_MAP`.
    :param extra: The model2vec extra that provides the module, used in the message.
    :raises ImportError: If the module cannot be imported.
    """
    module = dict(_MODULE_MAP).get(module, module)
    try:
        import_module(module)
    except ImportError as e:
        msg = f"`{module}`, is required. Please reinstall model2vec with the `{extra}` extra. `pip install model2vec[{extra}]`"
        # Chain the original failure so the real cause stays visible in tracebacks
        # (the original raise dropped it).
        raise ImportError(msg) from e
def setup_logging() -> None:
    """Configure root logging at INFO level with a rich console handler."""
    # Imported lazily so rich is only required when logging is actually set up.
    from rich.logging import RichHandler

    handler = RichHandler(rich_tracebacks=True)
    logging.basicConfig(
        level="INFO",
        format="%(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[handler],
    )
def load_local_model(folder: Path) -> tuple[np.ndarray, Tokenizer, dict[str, str]]:
    """Load embeddings, tokenizer and config from a local model folder.

    :param folder: Directory containing `model.safetensors`, `tokenizer.json`
        and optionally `config.json`.
    :return: A tuple of (embeddings array, tokenizer, config dict). The config
        is empty if `config.json` is absent.
    """
    embeddings_path = folder / "model.safetensors"
    tokenizer_path = folder / "tokenizer.json"
    config_path = folder / "config.json"

    opened_tensor_file = cast("SafeOpenProtocol", safetensors.safe_open(embeddings_path, framework="numpy"))
    embeddings = opened_tensor_file.get_tensor("embeddings")

    # Read the config with an explicit encoding and a deterministic close;
    # the original `json.load(open(...))` leaked the file handle.
    config = json.loads(config_path.read_text(encoding="utf-8")) if config_path.exists() else {}

    tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path))

    # A mismatch indicates a corrupted or mismatched model/tokenizer pair;
    # warn rather than fail so callers can decide how to proceed.
    if len(tokenizer.get_vocab()) != len(embeddings):
        logger.warning(
            f"Number of tokens does not match number of embeddings: `{len(tokenizer.get_vocab())}` vs `{len(embeddings)}`"
        )
    return embeddings, tokenizer, config