""" LLaVA model implementation. """ import torch from transformers import AutoProcessor, AutoModelForCausalLM from PIL import Image from ..configs.settings import MODEL_NAME, MODEL_REVISION, DEVICE from ..utils.logging import get_logger logger = get_logger(__name__) class LLaVAModel: """LLaVA model wrapper class.""" def __init__(self): """Initialize the LLaVA model and processor.""" try: logger.info(f"Initializing LLaVA model from {MODEL_NAME}") logger.info(f"Using device: {DEVICE}") # Initialize processor self.processor = AutoProcessor.from_pretrained( MODEL_NAME, revision=MODEL_REVISION, trust_remote_code=True ) # Set model dtype based on device model_dtype = torch.float32 if DEVICE == "cpu" else torch.float16 # Initialize model with appropriate settings self.model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, revision=MODEL_REVISION, torch_dtype=model_dtype, device_map="auto" if DEVICE == "cuda" else None, trust_remote_code=True, low_cpu_mem_usage=True ) # Move model to device if not using device_map if DEVICE == "cpu": self.model = self.model.to(DEVICE) logger.info("Model initialization complete") except Exception as e: logger.error(f"Error initializing model: {str(e)}") raise def generate_response( self, image: Image.Image, prompt: str, max_new_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.9 ) -> str: """ Generate a response for the given image and prompt. Args: image: Input image as PIL Image prompt: Text prompt for the model max_new_tokens: Maximum number of tokens to generate temperature: Sampling temperature top_p: Top-p sampling parameter Returns: str: Generated response """ try: # Prepare inputs inputs = self.processor( prompt, image, return_tensors="pt" ).to(DEVICE) # Generate response with torch.no_grad(): outputs = self.model.generate( **inputs, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, do_sample=True ) # Decode and return response response = self.processor.decode( outputs[0], skip_special_tokens=True ) logger.debug(f"Generated response: {response[:100]}...") return response except Exception as e: logger.error(f"Error generating response: {str(e)}") raise def __call__(self, *args, **kwargs): """Convenience method to call generate_response.""" return self.generate_response(*args, **kwargs)