# TatTwamAI/models/__init__.py
"""
Models module for Personal Coach CrewAI Application
Handles all AI model loading and management
"""
import gc
import os
from typing import TYPE_CHECKING, Any, Dict, Optional

import torch
# Version info
__version__ = "1.0.0"
# Lazy imports
if TYPE_CHECKING:
from .mistral_model import MistralModel, MistralConfig, MistralPromptFormatter
from .tiny_gpt2_model import TinyGPT2Model
# Public API
__all__ = [
    # Main model classes
    "MistralModel",
    "MistralConfig",
    "MistralPromptFormatter",
    "TinyGPT2Model",
    # Model management
    "load_model",
    "get_model_info",
    "clear_model_cache",
    # Constants
    "AVAILABLE_MODELS",
    "MODEL_REQUIREMENTS",
    "DEFAULT_MODEL_CONFIG"
]
# Available models
AVAILABLE_MODELS = {
    "mistral-7b-instruct": {
        "model_id": "mistralai/Mistral-7B-Instruct-v0.1",
        "type": "instruction-following",
        "size": "7B",
        "context_length": 32768,
        "languages": ["multilingual"]
    },
    "mistral-7b": {
        "model_id": "mistralai/Mistral-7B-v0.1",
        "type": "base",
        "size": "7B",
        "context_length": 32768,
        "languages": ["multilingual"]
    },
    "tiny-gpt2": {
        "model_id": "sshleifer/tiny-gpt2",
        "type": "tiny",
        "size": "small",
        "context_length": 256,
        "languages": ["en"]
    }
}
# Model requirements
MODEL_REQUIREMENTS = {
    "mistral-7b-instruct": {
        "ram": "16GB",
        "vram": "8GB (GPU) or 16GB (CPU)",
        "disk": "15GB",
        "compute": "GPU recommended"
    },
    "mistral-7b": {
        # Same architecture and parameter count as the instruct variant
        "ram": "16GB",
        "vram": "8GB (GPU) or 16GB (CPU)",
        "disk": "15GB",
        "compute": "GPU recommended"
    },
    "tiny-gpt2": {
        "ram": "≤1GB",
        "vram": "CPU only",
        "disk": "<1GB",
        "compute": "CPU"
    }
}
# Default configuration: CPU / float32
DEFAULT_MODEL_CONFIG = {
    "max_length": 256,
    "temperature": 0.7,
    "top_p": 0.95,
    "top_k": 50,
    "do_sample": True,
    "num_return_sequences": 1,
    "device": "cpu",
    "torch_dtype": torch.float32,
    "load_in_8bit": False,
    "cache_dir": ".cache/models"
}
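
# Example override (hypothetical caller-side sketch; any key passed in the
# config dict shadows the matching DEFAULT_MODEL_CONFIG entry in load_model):
#   model = load_model("mistral-7b-instruct",
#                      {"device": "cuda", "torch_dtype": torch.float16})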
# Model instance cache
_model_cache: Dict[str, Any] = {}

def load_model(model_name: str = "tiny-gpt2", config: Optional[Dict[str, Any]] = None):
    """
    Load a model with caching support.

    Args:
        model_name: Name of the model to load
        config: Optional configuration override

    Returns:
        Model instance
    """
    # Check cache first (keys are "<model_name>_<config>")
    cache_key = f"{model_name}_{str(config)}"
    if cache_key in _model_cache:
        return _model_cache[cache_key]

    # Import here to avoid circular imports
    if model_name == "tiny-gpt2":
        from .tiny_gpt2_model import TinyGPT2Model
        # TinyGPT2Model takes no configuration; any override is ignored for now
        model = TinyGPT2Model()
    elif model_name in ("mistral-7b-instruct", "mistral-7b"):
        from .mistral_model import MistralModel, MistralConfig

        model_info = AVAILABLE_MODELS[model_name]
        model_config = DEFAULT_MODEL_CONFIG.copy()
        if config:
            model_config.update(config)
        mistral_config = MistralConfig(
            model_id=model_info["model_id"],
            **model_config
        )
        model = MistralModel(mistral_config)
    else:
        raise ValueError(f"Unknown model: {model_name}")

    # Cache it
    _model_cache[cache_key] = model
    return model
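
# Example (hypothetical usage; a second call with the same arguments builds the
# same cache key and returns the cached instance instead of reloading):
#   model = load_model("tiny-gpt2")
#   same = load_model("tiny-gpt2")
#   assert model is same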

def get_model_info(model_name: str) -> Optional[Dict[str, Any]]:
    """
    Get information about a model.

    Args:
        model_name: Name of the model

    Returns:
        Model information dictionary or None
    """
    info = AVAILABLE_MODELS.get(model_name)
    if info:
        info = info.copy()  # avoid mutating the global dict
        # Add requirements
        info["requirements"] = MODEL_REQUIREMENTS.get(model_name, {})
        # Add loading status; match "<model_name>_" so that "mistral-7b"
        # does not also match "mistral-7b-instruct" cache keys
        info["is_loaded"] = any(k.startswith(f"{model_name}_") for k in _model_cache)
    return info
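
# Example (shape only; values come from the tables above, "is_loaded" from the cache):
#   get_model_info("tiny-gpt2")
#   -> {"model_id": "sshleifer/tiny-gpt2", ..., "requirements": {...}, "is_loaded": False}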

def clear_model_cache(model_name: Optional[str] = None):
    """
    Clear model cache to free memory.

    Args:
        model_name: Specific model to clear, or None for all
    """
    if model_name:
        # Clear a specific model; match "<model_name>_" so that "mistral-7b"
        # does not also clear "mistral-7b-instruct" entries
        keys_to_remove = [k for k in _model_cache if k.startswith(f"{model_name}_")]
        for key in keys_to_remove:
            del _model_cache[key]
    else:
        # Clear all
        _model_cache.clear()

    # Force garbage collection
    gc.collect()

    # Clear GPU cache if using CUDA
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Utility functions
def estimate_memory_usage(model_name: str) -> Dict[str, Any]:
    """
    Estimate memory usage for a model.

    Args:
        model_name: Name of the model

    Returns:
        Memory estimation dictionary
    """
    model_info = AVAILABLE_MODELS.get(model_name)
    if not model_info:
        return {}

    size = model_info.get("size", "7B")
    if size.endswith("B"):
        params_b = float(size.rstrip("B"))  # parameter count in billions, e.g. "7B" -> 7.0
    elif size == "small":
        params_b = 0.02  # rough placeholder for a tiny model
    else:
        params_b = 0.1  # catch-all

    # One billion parameters take ~4GB at fp32, ~2GB at fp16, ~1GB at int8
    return {
        "params_billions": params_b,
        "fp32_memory_gb": params_b * 4,  # 4 bytes per parameter
        "fp16_memory_gb": params_b * 2,  # 2 bytes per parameter
        "int8_memory_gb": params_b,      # 1 byte per parameter
        "recommended_ram_gb": params_b * 2.5,
        "recommended_vram_gb": params_b * 1.5
    }
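
# Worked example for the 7B entries: 7.0 billion parameters gives
#   fp32: 7 * 4 = 28.0 GB, fp16: 14.0 GB, int8: 7.0 GB,
#   recommended RAM 17.5 GB, recommended VRAM 10.5 GB
#   estimate_memory_usage("mistral-7b-instruct")["fp16_memory_gb"]  # 14.0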

def get_device_info() -> Dict[str, Any]:
    """Get information about available compute devices."""
    cuda = torch.cuda.is_available()
    info = {
        "cuda_available": cuda,
        "device_count": torch.cuda.device_count() if cuda else 0,
        "current_device": torch.cuda.current_device() if cuda else None,
        "device_name": torch.cuda.get_device_name() if cuda else "CPU"
    }
    if cuda:
        info["gpu_memory"] = {
            "allocated": torch.cuda.memory_allocated() / 1024**3,  # GB
            "reserved": torch.cuda.memory_reserved() / 1024**3,  # GB
            "total": torch.cuda.get_device_properties(0).total_memory / 1024**3  # GB
        }
    return info

# Module initialization
if os.getenv("DEBUG_MODE", "false").lower() == "true":
    print(f"Models module v{__version__} initialized")
    device_info = get_device_info()
    print(f"Device: {device_info['device_name']}")
    if device_info["cuda_available"]:
        print(f"GPU Memory: {device_info['gpu_memory']['total']:.1f}GB")