import logging

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPModel, CLIPProcessor

logging.basicConfig(level=logging.INFO)


class LLaVAPhiModel:
    def __init__(self, model_id="sagar007/Lava_phi"):
        self.device = "cuda"
        self.model_id = model_id
        logging.info("Initializing LLaVA-Phi model...")

        # Lightweight components are loaded eagerly.
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

        self.history = []
        # Heavy models are loaded lazily on the first GPU call (see ensure_models_loaded).
        self.model = None
        self.clip = None
        self.projection = None

    @spaces.GPU
    def ensure_models_loaded(self):
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available. This model requires a GPU.")

        if self.model is None:
            from transformers import BitsAndBytesConfig

            # 8-bit weight-only quantization; compute-dtype and double-quant options
            # only apply to 4-bit (bnb_4bit_*) loading.
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.float16,  # keep non-quantized modules in fp16
                trust_remote_code=True,
            )
            self.model.config.pad_token_id = self.tokenizer.eos_token_id
            logging.info("Successfully loaded main model on GPU")

        if self.clip is None:
            self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
            logging.info("Successfully loaded CLIP model")
            # Project CLIP image embeddings into the language model's hidden space.
            embed_dim = self.model.config.hidden_size
            clip_dim = self.clip.config.projection_dim
            self.projection = torch.nn.Linear(clip_dim, embed_dim).to(self.device)

    # Rest of the class (process_image, generate_response, etc.) remains unchanged
    # ... (omitted for brevity)
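
    # ------------------------------------------------------------------
    # Illustrative sketches only (assumptions, not the original methods).
    # create_demo() below calls generate_response, clear_history and
    # update_generation_params, which are among the omitted methods, so
    # minimal stand-ins are sketched here. The prompt template is an
    # assumption and the multimodal CLIP-projection path is not reproduced.
    # ------------------------------------------------------------------
    def update_generation_params(self, temperature, top_p, top_k, repetition_penalty):
        # Store the sampling parameters used by generate_response.
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.repetition_penalty = repetition_penalty
        logging.info(
            f"Generation params updated: temp={temperature}, top_p={top_p}, "
            f"top_k={top_k}, rep_penalty={repetition_penalty}"
        )

    def clear_history(self):
        # Reset the running conversation history.
        self.history = []

    @spaces.GPU
    def generate_response(self, message, image=None):
        # Text-only fallback sketch: the image path (CLIP features + projection)
        # belongs to the omitted original code and is not re-invented here.
        self.ensure_models_loaded()
        prompt = f"human: {message}\ngpt:"  # assumed prompt template
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=True,
                temperature=getattr(self, "temperature", 0.3),
                top_p=getattr(self, "top_p", 0.92),
                top_k=getattr(self, "top_k", 50),
                repetition_penalty=getattr(self, "repetition_penalty", 1.2),
                pad_token_id=self.tokenizer.pad_token_id,
            )
        # Decode only the newly generated tokens, not the prompt.
        response = self.tokenizer.decode(
            output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
        ).strip()
        self.history.append((message, response))
        return response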
""" ) chatbot = gr.Chatbot(height=400) with gr.Row(): with gr.Column(scale=0.7): msg = gr.Textbox( show_label=False, placeholder="Enter text and/or upload an image", container=False ) with gr.Column(scale=0.15, min_width=0): clear = gr.Button("Clear") with gr.Column(scale=0.15, min_width=0): submit = gr.Button("Submit", variant="primary") image = gr.Image(type="pil", label="Upload Image (Optional)") with gr.Accordion("Advanced Settings", open=False): gr.Markdown("Adjust these parameters to control hallucination tendency") temp_slider = gr.Slider(0.1, 1.0, value=0.3, step=0.1, label="Temperature (lower = more factual)") top_p_slider = gr.Slider(0.5, 1.0, value=0.92, step=0.01, label="Top-p (nucleus sampling)") top_k_slider = gr.Slider(10, 100, value=50, step=5, label="Top-k") rep_penalty_slider = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="Repetition Penalty") update_params = gr.Button("Update Parameters") def respond(message, chat_history, image): if not message and image is None: return chat_history response = model.generate_response(message, image) chat_history.append((message, response)) return "", chat_history def clear_chat(): model.clear_history() return None, None def update_params_fn(temp, top_p, top_k, rep_penalty): return model.update_generation_params(temp, top_p, top_k, rep_penalty) submit.click( respond, [msg, chatbot, image], [msg, chatbot], ) clear.click( clear_chat, None, [chatbot, image], ) msg.submit( respond, [msg, chatbot, image], [msg, chatbot], ) update_params.click( update_params_fn, [temp_slider, top_p_slider, top_k_slider, rep_penalty_slider], None ) return demo except Exception as e: logging.error(f"Error creating demo: {str(e)}") raise if __name__ == "__main__": demo = create_demo() demo.launch(server_name="0.0.0.0", server_port=7860, share=True)