Spaces:
Running
Running
#!/usr/bin/env python3 | |
""" | |
Model Setup Script for Enhanced RAG Demo | |
This script handles automatic downloading and setup of required models | |
for deployment environments like HuggingFace Spaces where models may not | |
be pre-installed. | |
Usage: | |
python scripts/setup_models.py | |
Environment Variables: | |
SKIP_MODEL_DOWNLOAD: Set to '1' to skip model downloads | |
SPACY_MODEL: Override default spaCy model (default: en_core_web_sm) | |
""" | |
import os | |
import sys | |
import logging | |
import subprocess | |
import time | |
from pathlib import Path | |
from typing import List, Dict, Any, Optional | |
# Root logging setup for this script: records at INFO and above, each line
# prefixed with a timestamp and the level name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-level logger used by the helper functions in this script.
logger = logging.getLogger(__name__)
def check_spacy_model(model_name: str = "en_core_web_sm") -> bool:
    """
    Report whether a spaCy pipeline can be loaded.

    The check performs a real ``spacy.load`` call. A missing spaCy install, a
    missing model and any other load failure all yield ``False`` with a
    warning logged. Nothing is raised to the caller.

    Args:
        model_name: Name of the spaCy model to try loading.

    Returns:
        True if ``spacy.load(model_name)`` succeeded, False otherwise.
    """
    try:
        import spacy
        spacy.load(model_name)
    except OSError:
        # Treated as "model not present" (spaCy's load failure path).
        logger.warning(f"β spaCy model '{model_name}' not found")
        return False
    except ImportError:
        logger.warning("β spaCy not installed")
        return False
    except Exception as exc:
        logger.warning(f"β Error checking spaCy model: {exc}")
        return False
    else:
        logger.info(f"β spaCy model '{model_name}' is available")
        return True
def download_spacy_model(model_name: str = "en_core_web_sm", timeout: int = 300) -> bool:
    """
    Download a spaCy model by running ``python -m spacy download`` in a subprocess.

    The current interpreter (``sys.executable``) runs the command, so the
    model is installed into the same environment as this script.

    Args:
        model_name: Name of the spaCy model to download.
        timeout: Maximum number of seconds to wait for the download command.

    Returns:
        True if the command exited with status 0. False if it exited non-zero,
        timed out or could not be launched. Errors are logged, never raised.
    """
    try:
        logger.info(f"π₯ Downloading spaCy model '{model_name}'...")
        result = subprocess.run(
            [sys.executable, "-m", "spacy", "download", model_name],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        logger.error(f"β spaCy model download timed out after {timeout} seconds")
        return False
    except Exception as e:
        # Best-effort: any launch failure is reported, not propagated.
        logger.error(f"β Error downloading spaCy model: {e}")
        return False

    if result.returncode == 0:
        logger.info(f"β Successfully downloaded spaCy model '{model_name}'")
        return True

    # spaCy's CLI may print its own failure message (e.g. no compatible
    # package found) to stdout rather than stderr. Fall back to stdout so
    # the log is never empty.
    details = (result.stderr or "").strip() or (result.stdout or "").strip()
    logger.error(
        f"β Failed to download spaCy model (exit code {result.returncode}): {details}"
    )
    return False
def setup_cache_directories() -> None:
    """
    Ensure the model cache directories exist.

    Each location comes from its environment variable when set, otherwise
    from a default under ``/tmp/.cache``. Missing parent directories are
    created too, and existing directories are left alone. Failures are
    logged as warnings and never raised. No permissions are adjusted beyond
    what ``os.makedirs`` applies by default.
    """
    env_defaults = (
        ('TRANSFORMERS_CACHE', '/tmp/.cache/huggingface/transformers'),
        ('HF_HOME', '/tmp/.cache/huggingface'),
        ('SENTENCE_TRANSFORMERS_HOME', '/tmp/.cache/sentence-transformers'),
    )
    for env_var, fallback in env_defaults:
        target = os.environ.get(env_var, fallback)
        try:
            os.makedirs(target, exist_ok=True)
        except Exception as err:
            logger.warning(f"β οΈ Could not create cache directory {target}: {err}")
        else:
            logger.info(f"π Created cache directory: {target}")
def validate_python_packages() -> Dict[str, bool]:
    """
    Check which of the packages this demo relies on can be imported.

    Every package is actually imported, so a package counts as available only
    if its import succeeds. Two failure cases are reported as unavailable and
    logged at error level:

    - the package is missing (``ImportError``);
    - the package is installed but raises something else while importing
      (for example a broken compiled extension).

    Neither case stops the remaining packages from being checked.

    Returns:
        Dictionary mapping each package's display name to True if it imported
        successfully, False otherwise.
    """
    import importlib

    # Display name -> importable module name.
    required_packages = {
        'rank_bm25': 'rank_bm25',
        'pdfplumber': 'pdfplumber',
        'sentence_transformers': 'sentence_transformers',
        'transformers': 'transformers',
        'spacy': 'spacy',
        'huggingface_hub': 'huggingface_hub',
        'faiss': 'faiss',
        'accelerate': 'accelerate',  # Optional but recommended
    }
    status = {}
    for display_name, import_name in required_packages.items():
        try:
            importlib.import_module(import_name)
        except ImportError:
            status[display_name] = False
            logger.error(f"β {display_name} is not installed")
        except Exception as e:
            # Installed but broken (e.g. binary incompatibility with a
            # dependency). Mark as unavailable instead of crashing the script.
            status[display_name] = False
            logger.error(f"β {display_name} is installed but failed to import: {e}")
        else:
            status[display_name] = True
            logger.info(f"β {display_name} is available")
    return status
def main() -> int:
    """
    Run the model setup sequence.

    - If ``SKIP_MODEL_DOWNLOAD`` is '1', 'true' or 'yes' (case-insensitive),
      return success immediately without checking anything else.
    - Check that the Python packages import. A missing critical package
      marks the run as failed, but the remaining steps still run.
    - Create the model cache directories.
    - Ensure the spaCy model (``SPACY_MODEL``, default ``en_core_web_sm``)
      is available, downloading it if needed. A failure here is logged but
      does not fail the run.
    - Smoke-test loading a small sentence-transformers model into the
      ``SENTENCE_TRANSFORMERS_HOME`` cache. A failure here is logged but does
      not fail the run.

    Returns:
        Exit code: 0 for success, 1 if any critical package is missing.
    """
    logger.info("π Starting Enhanced RAG Demo model setup...")

    # Check if model download should be skipped
    skip_download = os.environ.get('SKIP_MODEL_DOWNLOAD', '').lower() in ('1', 'true', 'yes')
    if skip_download:
        logger.info("βοΈ Skipping model downloads (SKIP_MODEL_DOWNLOAD set)")
        return 0

    success = True

    # 1. Validate Python packages
    logger.info("π¦ Validating Python packages...")
    package_status = validate_python_packages()
    critical_packages = ['rank_bm25', 'pdfplumber', 'sentence_transformers', 'transformers', 'spacy']
    missing_critical = [pkg for pkg in critical_packages if not package_status.get(pkg, False)]
    if missing_critical:
        logger.error(f"β Critical packages missing: {', '.join(missing_critical)}")
        logger.error("Please install missing packages with: pip install -r requirements.txt")
        success = False

    # 2. Setup cache directories
    logger.info("π Setting up cache directories...")
    setup_cache_directories()

    # 3. Handle spaCy model. An empty SPACY_MODEL falls back to the default
    # instead of passing "" to spaCy.
    spacy_model = os.environ.get('SPACY_MODEL') or 'en_core_web_sm'
    logger.info(f"π€ Checking spaCy model: {spacy_model}")
    if package_status.get('spacy', False):
        if not check_spacy_model(spacy_model):
            logger.info(f"π₯ Attempting to download spaCy model '{spacy_model}'...")
            if not download_spacy_model(spacy_model):
                logger.error(f"β Failed to download spaCy model '{spacy_model}'")
                logger.warning("β οΈ Entity extraction features may be limited")
                # Don't fail completely - this is non-critical for basic functionality
    else:
        logger.warning("β οΈ spaCy not available - entity extraction will be disabled")

    # 4. Test model loading (basic validation)
    if package_status.get('sentence_transformers', False):
        # Use the same location setup_cache_directories() prepares: honour
        # SENTENCE_TRANSFORMERS_HOME with the same default, not a fixed path.
        st_cache_dir = os.environ.get(
            'SENTENCE_TRANSFORMERS_HOME', '/tmp/.cache/sentence-transformers'
        )
        try:
            logger.info("π§ͺ Testing sentence-transformers model loading...")
            from sentence_transformers import SentenceTransformer
            # Try to load a small model for validation
            model = SentenceTransformer(
                'sentence-transformers/all-MiniLM-L6-v2', cache_folder=st_cache_dir
            )
            logger.info("β sentence-transformers model loading successful")
            del model  # Free memory
        except Exception as e:
            logger.warning(f"β οΈ sentence-transformers model loading failed: {e}")

    if success:
        logger.info("π Model setup completed successfully!")
        return 0
    logger.error("π₯ Model setup encountered errors")
    return 1
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())