import sys
import os

# Add the cloned nanoVLM directory to Python's system path
NANOVLM_REPO_PATH = "/app/nanoVLM"
if NANOVLM_REPO_PATH not in sys.path:
    sys.path.insert(0, NANOVLM_REPO_PATH)

import gradio as gr
from PIL import Image
import torch

# Import specific processor components
from transformers import CLIPImageProcessor, GPT2TokenizerFast

# Import the custom VisionLanguageModel class
try:
    from models.vision_language_model import VisionLanguageModel
    print("Successfully imported VisionLanguageModel from nanoVLM clone.")
except ImportError as e:
    print(f"Error importing VisionLanguageModel from nanoVLM clone: {e}.")
    VisionLanguageModel = None

# Determine the device to use
device_choice = os.environ.get("DEVICE", "auto")
if device_choice == "auto":
    device = "cuda" if torch.cuda.is_available() else "cpu"
else:
    device = device_choice
print(f"Using device: {device}")

# --- Configuration for model components ---
# The main model ID for weights and overall config
model_id_for_weights = "lusxvr/nanoVLM-222M"
# The ID for the vision backbone's image processor configuration
image_processor_id = "openai/clip-vit-base-patch32"
# The ID for the tokenizer (can be the main model ID if it provides specific tokenizer files)
tokenizer_id = "lusxvr/nanoVLM-222M"  # Or directly "gpt2" if preferred, but the model ID is usually safer

image_processor = None
tokenizer = None
model = None

if VisionLanguageModel:
    try:
        print(f"Attempting to load CLIPImageProcessor from: {image_processor_id}")
        image_processor = CLIPImageProcessor.from_pretrained(image_processor_id, trust_remote_code=True)
        print("CLIPImageProcessor loaded.")

        print(f"Attempting to load GPT2TokenizerFast from: {tokenizer_id}")
        tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_id, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            print("Set tokenizer pad_token to eos_token.")
        print("GPT2TokenizerFast loaded.")

        print(f"Attempting to load model weights from {model_id_for_weights} using VisionLanguageModel.from_pretrained")
        model = VisionLanguageModel.from_pretrained(
            model_id_for_weights,
            trust_remote_code=True
        ).to(device)
        print("Model loaded successfully.")
        model.eval()
    except Exception as e:
        print(f"Error loading model or processor components: {e}")
        import traceback
        traceback.print_exc()  # Print full traceback
        image_processor = None
        tokenizer = None
        model = None
else:
    print("Custom VisionLanguageModel class not imported, cannot load model.")


def prepare_inputs(text_list, image_input, image_processor_instance, tokenizer_instance, device_to_use):
    if image_processor_instance is None or tokenizer_instance is None:
        raise ValueError("Image processor or tokenizer not initialized.")

    processed_image = image_processor_instance(images=image_input, return_tensors="pt").pixel_values.to(device_to_use)

    processed_text = tokenizer_instance(
        text=text_list,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=tokenizer_instance.model_max_length
    )
    input_ids = processed_text.input_ids.to(device_to_use)
    attention_mask = processed_text.attention_mask.to(device_to_use)

    return {"pixel_values": processed_image, "input_ids": input_ids, "attention_mask": attention_mask}


def generate_text_for_image(image_input, prompt_input):
    if model is None or image_processor is None or tokenizer is None:
        return "Error: Model or processor components not loaded correctly. Check logs."
    if image_input is None:
        return "Please upload an image."
    if not prompt_input:
        return "Please provide a prompt."
    try:
        # Normalize the Gradio input to an RGB PIL image
        if not isinstance(image_input, Image.Image):
            pil_image = Image.fromarray(image_input)
        else:
            pil_image = image_input
        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")

        inputs = prepare_inputs(
            text_list=[prompt_input],
            image_input=pil_image,
            image_processor_instance=image_processor,
            tokenizer_instance=tokenizer,
            device_to_use=device
        )

        generated_ids = model.generate(
            pixel_values=inputs['pixel_values'],
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_new_tokens=150,
            num_beams=3,
            no_repeat_ngram_size=2,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id
        )

        generated_text_list = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        generated_text = generated_text_list[0] if generated_text_list else ""

        # If the model echoes the prompt back, strip it from the start of the output
        if prompt_input and generated_text.startswith(prompt_input):
            cleaned_text = generated_text[len(prompt_input):].lstrip(" ,.:")
        else:
            cleaned_text = generated_text
        return cleaned_text.strip()

    except Exception as e:
        print(f"Error during generation: {e}")
        import traceback
        traceback.print_exc()
        return f"An error occurred during text generation: {str(e)}"


description = "Interactive demo for lusxvr/nanoVLM-222M."
example_image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# gradio_cache_dir = os.environ.get("GRADIO_TEMP_DIR", "/tmp/gradio_tmp")  # Not used for now

iface = gr.Interface(
    fn=generate_text_for_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Your Prompt/Question")
    ],
    outputs=gr.Textbox(label="Generated Text", show_copy_button=True),
    title="Interactive nanoVLM-222M Demo",
    description=description,
    examples=[
        [example_image_url, "a photo of a"],
        [example_image_url, "Describe the image in detail."],
    ],
    cache_examples=True,  # This might cause issues if the Gradio version is old. Remove if needed.
    # examples_cache_folder=gradio_cache_dir,  # Removed due to potential Gradio version issue
    allow_flagging="never"
)

if __name__ == "__main__":
    if model is None or image_processor is None or tokenizer is None:
        print("CRITICAL: Model or processor components failed to load.")
    else:
        print("Launching Gradio interface...")
        iface.launch(server_name="0.0.0.0", server_port=7860)
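
# --- Optional: querying the running demo programmatically ---
# A minimal sketch (kept as comments so it does not affect the app itself) of how the
# launched interface could be called from another process using the `gradio_client`
# package. The localhost URL, the use of handle_file(), and the default "/predict"
# endpoint name are assumptions that depend on the installed Gradio / gradio_client
# versions; adjust for your setup.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://localhost:7860")
#   result = client.predict(
#       handle_file("http://images.cocodataset.org/val2017/000000039769.jpg"),  # image input
#       "a photo of a",                                                          # prompt input
#       api_name="/predict",
#   )
#   print(result)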