"""Gradio chat demo for a LLaVA-Phi vision-language model on ZeroGPU."""

import logging

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import AutoModel, AutoModelForCausalLM, AutoProcessor, AutoTokenizer

# Setup logging
logging.basicConfig(level=logging.INFO)

# Sampling settings used when the user supplied an image with the message.
_IMAGE_GENERATION_KWARGS = {
    "max_new_tokens": 256,
    "min_length": 20,
    "temperature": 0.7,
    "do_sample": True,
    "top_p": 0.9,
    "top_k": 40,
    "repetition_penalty": 1.5,
    "no_repeat_ngram_size": 3,
    "use_cache": True,
}

# Sampling settings used for text-only messages.
_TEXT_GENERATION_KWARGS = {
    "max_new_tokens": 150,
    "min_length": 20,
    "temperature": 0.6,
    "do_sample": True,
    "top_p": 0.85,
    "top_k": 30,
    "repetition_penalty": 1.8,
    "no_repeat_ngram_size": 4,
    "use_cache": True,
}


class LLaVAPhiModel:
    """Wrap the tokenizer, image processor, language model and CLIP encoder.

    The tokenizer and image processor are created eagerly in ``__init__``;
    the language model and CLIP model are created lazily by
    ``ensure_models_loaded``. The instance also keeps a list of
    ``(message, response)`` turns in ``self.history``.
    """

    def __init__(self, model_id="sagar007/Lava_phi"):
        """Load the tokenizer and image processor and reset the history.

        Args:
            model_id: Hub id passed to ``AutoTokenizer.from_pretrained`` and,
                later, to ``AutoModelForCausalLM.from_pretrained``.
        """
        self.device = "cuda"  # Always use cuda with ZeroGPU
        self.model_id = model_id

        logging.info("Initializing LLaVA-Phi model...")

        # Initialize tokenizer (can be done outside GPU context)
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        if self.tokenizer.pad_token is None:
            # Fall back to EOS so that padding=True works during tokenization.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Initialize processor (can be done outside GPU context)
        # NOTE(review): CLIP ViT-B/32 is normally published under the
        # "openai/" namespace; confirm this repo id resolves.
        self.processor = AutoProcessor.from_pretrained("microsoft/clip-vit-base-patch32")

        # Store conversation history
        # NOTE(review): this list lives on the single instance built in
        # create_demo, so it appears to be shared by every user of the app.
        self.history = []

        # Lazy loading of models - will be initialized in GPU context
        self.model = None
        self.clip = None

    @spaces.GPU
    def ensure_models_loaded(self):
        """Ensure models are loaded in GPU context.

        Loads the 4-bit quantized language model and the CLIP model the first
        time it is called; attributes that are already set are left alone.
        """
        if self.model is None:
            # Load main model
            from transformers import BitsAndBytesConfig

            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
            )
            self.model.config.pad_token_id = self.tokenizer.eos_token_id

        if self.clip is None:
            # Load CLIP model
            self.clip = AutoModel.from_pretrained(
                "microsoft/clip-vit-base-patch32"
            ).to(self.device)

    @spaces.GPU
    def process_image(self, image):
        """Process image through CLIP.

        Args:
            image: A file path (``str``), a NumPy array, or any other object
                the image processor accepts (for example a PIL image).

        Returns:
            The result of ``self.clip.get_image_features`` for the image.

        Raises:
            Exception: Any error raised while loading or encoding the image
                is logged and re-raised.
        """
        try:
            # Ensure models are loaded
            self.ensure_models_loaded()

            # Convert image to correct format
            if isinstance(image, str):
                image = Image.open(image)
            elif isinstance(image, np.ndarray):
                # The original referenced an un-imported ``numpy`` name here,
                # which raised NameError for every non-string input.
                image = Image.fromarray(image)

            with torch.no_grad():
                image_inputs = self.processor(images=image, return_tensors="pt")
                image_features = self.clip.get_image_features(
                    pixel_values=image_inputs.pixel_values.to(self.device)
                )
            return image_features
        except Exception as e:
            logging.error(f"Error processing image: {str(e)}")
            raise

    def _build_prompt(self, message, with_image_slot):
        """Return the last three history turns followed by the new turn."""
        if with_image_slot:
            # NOTE(review): the placeholder between "human: " and the newline
            # is an empty string in this source; an image token may have been
            # lost when the file was extracted - confirm against the model's
            # expected prompt format.
            prompt = f"human: \n{message}\ngpt:"
        else:
            prompt = f"human: {message}\ngpt:"

        context = "".join(
            f"human: {turn[0]}\ngpt: {turn[1]}\n" for turn in self.history[-3:]
        )
        return context + prompt

    def _generate(self, full_prompt, generation_kwargs, image_features=None):
        """Tokenize ``full_prompt``, run ``generate`` and decode the output."""
        inputs = self.tokenizer(
            full_prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        if image_features is not None:
            inputs["image_features"] = image_features

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                **generation_kwargs,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    @staticmethod
    def _extract_reply(decoded):
        """Keep only the text after the last "gpt:" and before any "human:"."""
        reply = decoded.split("gpt:")[-1]
        reply = reply.split("human:")[0]
        # The original also ran ``replace("", "")`` here, which is a no-op;
        # only the trailing strip had an effect.
        return reply.strip()

    @spaces.GPU(duration=120)  # Set longer duration for generation
    def generate_response(self, message, image=None):
        """Generate a reply to ``message`` and record the turn in the history.

        Args:
            message: The user's text.
            image: Optional image passed to ``process_image``. If processing
                fails, a note describing the failure is prepended to
                ``message`` and generation continues without image features.

        Returns:
            The extracted reply text, or ``"Error: <details>"`` if anything
            in the generation path raises.
        """
        try:
            # Ensure models are loaded
            self.ensure_models_loaded()

            if image is not None:
                try:
                    image_features = self.process_image(image)
                except Exception as e:
                    logging.error(f"Failed to process image: {str(e)}")
                    image_features = None
                    message = (
                        "Note: Failed to process image. Continuing with text only. "
                        f"Error: {str(e)}\n{message}"
                    )
                full_prompt = self._build_prompt(message, with_image_slot=True)
                decoded = self._generate(
                    full_prompt, _IMAGE_GENERATION_KWARGS, image_features
                )
            else:
                full_prompt = self._build_prompt(message, with_image_slot=False)
                decoded = self._generate(full_prompt, _TEXT_GENERATION_KWARGS)

            response = self._extract_reply(decoded)
            self.history.append((message, response))
            return response
        except Exception as e:
            logging.error(f"Error generating response: {str(e)}")
            logging.error("Full traceback:", exc_info=True)
            return f"Error: {str(e)}"

    def clear_history(self):
        """Drop all stored conversation turns and return ``None``."""
        self.history = []
        return None


def create_demo():
    """Build the Gradio Blocks UI around a fresh ``LLaVAPhiModel``.

    Returns:
        The ``gr.Blocks`` demo, not yet launched.

    Raises:
        Exception: Any error raised while building the model or UI is logged
            and re-raised.
    """
    try:
        model = LLaVAPhiModel()

        with gr.Blocks(css="footer {visibility: hidden}") as demo:
            gr.Markdown(
                """
                # LLaVA-Phi Demo (ZeroGPU)
                Chat with a vision-language model that can understand both text and images.
                """
            )

            chatbot = gr.Chatbot(height=400)
            with gr.Row():
                with gr.Column(scale=0.7):
                    msg = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and/or upload an image",
                        container=False,
                    )
                with gr.Column(scale=0.15, min_width=0):
                    clear = gr.Button("Clear")
                with gr.Column(scale=0.15, min_width=0):
                    submit = gr.Button("Submit", variant="primary")

            image = gr.Image(type="pil", label="Upload Image (Optional)")

            def respond(message, chat_history, image):
                """Append one (message, response) turn; return (textbox, history)."""
                chat_history = chat_history or []
                if not message and image is None:
                    # Two outputs are wired to this handler, so two values
                    # must be returned even when there is nothing to do.
                    return message, chat_history

                response = model.generate_response(message, image)
                chat_history.append((message, response))
                return "", chat_history

            def clear_chat():
                """Reset the model history and clear the chatbot and image."""
                model.clear_history()
                return None, None

            submit.click(
                respond,
                [msg, chatbot, image],
                [msg, chatbot],
            )

            clear.click(
                clear_chat,
                None,
                [chatbot, image],
            )

            msg.submit(
                respond,
                [msg, chatbot, image],
                [msg, chatbot],
            )

        return demo
    except Exception as e:
        logging.error(f"Error creating demo: {str(e)}")
        raise


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )