vidhanm committed on
Commit f9326ef · 1 Parent(s): 200357b

app.py as per generate.py

Files changed (1)
  1. app.py +107 -187
app.py CHANGED
@@ -1,242 +1,162 @@
  import sys
  import os
- from typing import Optional # For type hinting
- from PIL import Image as PILImage # Use an alias to avoid conflict with gr.Image

  # Add the cloned nanoVLM directory to Python's system path
  NANOVLM_REPO_PATH = "/app/nanoVLM"
  if NANOVLM_REPO_PATH not in sys.path:
  print(f"DEBUG: Adding {NANOVLM_REPO_PATH} to sys.path")
  sys.path.insert(0, NANOVLM_REPO_PATH)
- else:
- print(f"DEBUG: {NANOVLM_REPO_PATH} already in sys.path")

  import gradio as gr
  import torch
- from transformers import CLIPImageProcessor, GPT2TokenizerFast

- # Import the custom VisionLanguageModel class
- VisionLanguageModel = None # Initialize to None
  try:
- print("DEBUG: Attempting to import VisionLanguageModel from models.vision_language_model")
  from models.vision_language_model import VisionLanguageModel
- print("DEBUG: Successfully imported VisionLanguageModel from nanoVLM clone.")
  except ImportError as e:
- print(f"CRITICAL ERROR: Error importing VisionLanguageModel from nanoVLM clone: {e}.")
- print("DEBUG: Please ensure /app/nanoVLM/models/vision_language_model.py exists and is correct.")
- # No need to exit here, the checks later will handle it.
- except Exception as e:
- print(f"CRITICAL ERROR: An unexpected error occurred during VisionLanguageModel import: {e}")
-

- # Determine the device to use
- device_choice = os.environ.get("DEVICE", "auto")
- if device_choice == "auto":
- device = "cuda" if torch.cuda.is_available() else "cpu"
- else:
- device = device_choice
  print(f"DEBUG: Using device: {device}")

- # --- Configuration for model components ---
- model_id_for_weights = "lusxvr/nanoVLM-222M"
- image_processor_id = "openai/clip-vit-base-patch32"
- tokenizer_id = "gpt2" # Using canonical gpt2 tokenizer
-
- print(f"DEBUG: Configuration - model_id_for_weights: {model_id_for_weights}")
- print(f"DEBUG: Configuration - image_processor_id: {image_processor_id}")
- print(f"DEBUG: Configuration - tokenizer_id: {tokenizer_id}")

- image_processor = None
- tokenizer = None
  model = None

- # --- Load Processor and Model ---
- if VisionLanguageModel is not None: # Only proceed if custom model class was imported
  try:
- print(f"DEBUG: Attempting to load CLIPImageProcessor from: {image_processor_id}")
- image_processor = CLIPImageProcessor.from_pretrained(image_processor_id)
- print(f"DEBUG: CLIPImageProcessor loaded: {type(image_processor)}")
-
- print(f"DEBUG: Attempting to load GPT2TokenizerFast from: {tokenizer_id}")
- tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_id)
- if tokenizer.pad_token is None:
- tokenizer.pad_token = tokenizer.eos_token
- print(f"DEBUG: Set tokenizer pad_token to eos_token (ID: {tokenizer.eos_token_id})")
- print(f"DEBUG: GPT2TokenizerFast loaded: {type(tokenizer)}, vocab_size: {tokenizer.vocab_size}")
-
- print(f"DEBUG: Attempting to load model weights from {model_id_for_weights} using VisionLanguageModel.from_pretrained")
- # Note: The custom VisionLanguageModel.from_pretrained in nanoVLM does not take trust_remote_code
- model = VisionLanguageModel.from_pretrained(model_id_for_weights).to(device)
- print(f"DEBUG: Model loaded successfully: {type(model)}")
  model.eval()
- print("DEBUG: Model set to evaluation mode (model.eval())")
-
- # Optional: Print model's state_dict keys (can be very long)
- # print("DEBUG: Model state_dict keys (first 10):", list(model.state_dict().keys())[:10])
- # print(f"DEBUG: Is model on device '{device}'? {next(model.parameters()).device}")

  except Exception as e:
- print(f"CRITICAL ERROR: Error loading model or processor components: {e}")
  import traceback
  traceback.print_exc()
- # Reset to ensure generate_text_for_image knows they failed
- image_processor = None
- tokenizer = None
- model = None
  else:
- print("CRITICAL ERROR: Custom VisionLanguageModel class not imported. Cannot load model.")
-
-
- # --- Input Preparation Function ---
- def prepare_inputs(text_list, image_input, image_processor_instance, tokenizer_instance, device_to_use):
- print(f"DEBUG (prepare_inputs): Received text_list: {text_list}")
- if image_processor_instance is None or tokenizer_instance is None:
- print("ERROR (prepare_inputs): Image processor or tokenizer not initialized.")
- raise ValueError("Image processor or tokenizer not initialized.")
-
- # Process image
- print(f"DEBUG (prepare_inputs): Processing image with {type(image_processor_instance)}")
- processed_image_output = image_processor_instance(images=image_input, return_tensors="pt")
- pixel_values = processed_image_output.pixel_values.to(device_to_use)
- print(f"DEBUG (prepare_inputs): pixel_values shape: {pixel_values.shape}, dtype: {pixel_values.dtype}")
-
- # Process text
- print(f"DEBUG (prepare_inputs): Processing text with {type(tokenizer_instance)}")
- # Using model_max_length from tokenizer, with a fallback.
- max_len = getattr(tokenizer_instance, 'model_max_length', 512)
- print(f"DEBUG (prepare_inputs): Tokenizer max_length: {max_len}")
108
- processed_text_output = tokenizer_instance(
- text=text_list, return_tensors="pt", padding=True, truncation=True, max_length=max_len
- )
- input_ids = processed_text_output.input_ids.to(device_to_use)
- attention_mask = processed_text_output.attention_mask.to(device_to_use)
- print(f"DEBUG (prepare_inputs): input_ids shape: {input_ids.shape}, dtype: {input_ids.dtype}, values: {input_ids}")
- print(f"DEBUG (prepare_inputs): attention_mask shape: {attention_mask.shape}, dtype: {attention_mask.dtype}, values: {attention_mask}")
-
- return {"pixel_values": pixel_values, "input_ids": input_ids, "attention_mask": attention_mask}


  # --- Text Generation Function ---
  def generate_text_for_image(image_input_pil: Optional[PILImage.Image], prompt_input_str: Optional[str]) -> str:
  print(f"DEBUG (generate_text_for_image): Received prompt: '{prompt_input_str}'")
- if model is None or image_processor is None or tokenizer is None:
- print("ERROR (generate_text_for_image): Model or processor components not loaded.")
- return "Error: Model or processor components not loaded correctly. Check application logs."
-
- if image_input_pil is None:
- print("WARN (generate_text_for_image): No image uploaded.")
- return "Please upload an image."
- if not prompt_input_str:
- print("WARN (generate_text_for_image): No prompt provided.")
- return "Please provide a prompt (e.g., 'a photo of a')."

  try:
- print("DEBUG (generate_text_for_image): Preparing image...")
- current_pil_image = image_input_pil # Gradio provides PIL if type="pil"
  if not isinstance(current_pil_image, PILImage.Image):
- print(f"WARN (generate_text_for_image): Input image not PIL, type: {type(current_pil_image)}. Converting.")
- current_pil_image = PILImage.fromarray(current_pil_image) # Fallback if not PIL
  if current_pil_image.mode != "RGB":
- print(f"DEBUG (generate_text_for_image): Converting image from mode {current_pil_image.mode} to RGB.")
  current_pil_image = current_pil_image.convert("RGB")
- print(f"DEBUG (generate_text_for_image): Image size: {current_pil_image.size}, mode: {current_pil_image.mode}")
-
- print("DEBUG (generate_text_for_image): Preparing inputs for the model...")
- inputs_dict = prepare_inputs(
- text_list=[prompt_input_str], image_input=current_pil_image,
- image_processor_instance=image_processor, tokenizer_instance=tokenizer, device_to_use=device
- )

- print(f"DEBUG (generate_text_for_image): Calling model.generate with input_ids (shape {inputs_dict['input_ids'].shape}), pixel_values (shape {inputs_dict['pixel_values'].shape}), attention_mask (shape {inputs_dict['attention_mask'].shape})")
-
- # Match the signature: def generate(self, input_ids, image, attention_mask=None, max_new_tokens=...)
  generated_ids_tensor = model.generate(
- inputs_dict['input_ids'], # 1st argument: input_ids (text prompt)
- inputs_dict['pixel_values'], # 2nd argument: image (pixel values)
- inputs_dict['attention_mask'], # 3rd argument: attention_mask (for text)
- max_new_tokens=30, # Using a smaller value for quicker debugging
- temperature=0.8, # Slightly higher temperature to encourage diversity
- top_k=50, # As per nanoVLM signature default
- top_p=0.9, # As per nanoVLM signature default
- greedy=False # As per nanoVLM signature default
  )
-
- print(f"DEBUG (generate_text_for_image): Raw generated_ids tensor: {generated_ids_tensor}")
-
- # Decode the generated tokens
- print("DEBUG (generate_text_for_image): Decoding generated tokens...")
- generated_text_list_decoded = tokenizer.batch_decode(generated_ids_tensor, skip_special_tokens=True)
- print(f"DEBUG (generate_text_for_image): Decoded text list (before join/cleanup): {generated_text_list_decoded}")
- generated_text_str = generated_text_list_decoded[0] if generated_text_list_decoded else ""

- # Optional: Clean up prompt if it's echoed by the model
  cleaned_text_str = generated_text_str
  if prompt_input_str and generated_text_str.startswith(prompt_input_str):
- print("DEBUG (generate_text_for_image): Prompt found at the beginning of generation, removing it.")
  cleaned_text_str = generated_text_str[len(prompt_input_str):].lstrip(" ,.:")
-
- print(f"DEBUG (generate_text_for_image): Final cleaned text to be returned: '{cleaned_text_str}'")
  return cleaned_text_str.strip()

  except Exception as e:
- print(f"CRITICAL ERROR (generate_text_for_image): An error occurred during generation: {e}")
  import traceback
- traceback.print_exc() # Print full traceback to logs
- return f"An error occurred during text generation: {str(e)}. Check application logs."
-

- # --- Gradio Interface Definition ---
  description_md = """
- ## Interactive nanoVLM-222M Demo
- Upload an image and provide a text prompt (e.g., "What is in this image?", "Describe the animal in detail.").
- The model will attempt to generate a textual response based on the visual content and your query.
- This Space uses the `lusxvr/nanoVLM-222M` model with code from the original `huggingface/nanoVLM` repository.
  """
- # example_image_url = "http://images.cocodataset.org/val2017/000000039769.jpg" # Not used currently
-
- print("DEBUG: Defining Gradio interface...")
  iface = None
- try:
- iface = gr.Interface(
- fn=generate_text_for_image,
- inputs=[
- gr.Image(type="pil", label="Upload Image"), # type="pil" ensures PIL.Image object
- gr.Textbox(label="Your Prompt / Question", info="e.g., 'a photo of a', 'Describe this scene.'")
- ],
- outputs=gr.Textbox(label="Generated Text", show_copy_button=True),
- title="nanoVLM-222M Interactive Demo",
- description=description_md,
- # examples=[ # Examples commented out to simplify Gradio setup
- # [example_image_url, "a photo of a"],
- # [example_image_url, "Describe the image in detail."],
- # ],
- # cache_examples=False, # Explicitly False, or remove argument
- allow_flagging="never" # Keep flagging disabled
- )
- print("DEBUG: Gradio interface defined successfully.")
- except Exception as e:
- print(f"CRITICAL ERROR: Error defining Gradio interface: {e}")
- import traceback
- traceback.print_exc()
-
-
- # --- Launch Gradio App ---
  if __name__ == "__main__":
- print("DEBUG: Entered __main__ block.")
- if VisionLanguageModel is None:
- print("CRITICAL ERROR: VisionLanguageModel class was not imported. Cannot proceed.")
- elif model is None or image_processor is None or tokenizer is None:
- print("CRITICAL ERROR: Model, image_processor, or tokenizer failed to load. Gradio app might not be fully functional.")
-
- if iface is not None:
- print("DEBUG: Attempting to launch Gradio interface...")
- try:
- iface.launch(server_name="0.0.0.0", server_port=7860) # Standard for Spaces
- print("DEBUG: Gradio launch command issued.") # This might not be reached if launch blocks or errors immediately
- except Exception as e:
- print(f"CRITICAL ERROR: Error launching Gradio interface: {e}")
- import traceback
- traceback.print_exc()
  else:
- print("CRITICAL ERROR: Gradio interface (iface) is None. Cannot launch.")
-
 
  import sys
  import os
+ from typing import Optional
+ from PIL import Image as PILImage

  # Add the cloned nanoVLM directory to Python's system path
  NANOVLM_REPO_PATH = "/app/nanoVLM"
  if NANOVLM_REPO_PATH not in sys.path:
  print(f"DEBUG: Adding {NANOVLM_REPO_PATH} to sys.path")
  sys.path.insert(0, NANOVLM_REPO_PATH)

  import gradio as gr
  import torch
+ from transformers import AutoProcessor # Using AutoProcessor as in generate.py

+ VisionLanguageModel = None
  try:
+ print("DEBUG: Attempting to import VisionLanguageModel")
  from models.vision_language_model import VisionLanguageModel
+ print("DEBUG: Successfully imported VisionLanguageModel.")
  except ImportError as e:
+ print(f"CRITICAL ERROR: Importing VisionLanguageModel: {e}")

+ # --- Device Setup ---
+ device = "cuda" if torch.cuda.is_available() else "cpu"
  print(f"DEBUG: Using device: {device}")

+ # --- Configuration ---
+ # This will be used for both model and processor, as in generate.py
+ model_repo_id = "lusxvr/nanoVLM-222M"
+ print(f"DEBUG: Model Repository ID for model and processor: {model_repo_id}")

+ # --- Initialize ---
+ processor = None
  model = None

+ if VisionLanguageModel: # Only proceed if custom model class was imported
  try:
+ # Load processor using AutoProcessor, like in generate.py
+ print(f"DEBUG: Loading processor using AutoProcessor.from_pretrained('{model_repo_id}')")
+ # Using trust_remote_code=True here as a precaution,
+ # though ideally not needed if processor_config.json is complete.
+ processor = AutoProcessor.from_pretrained(model_repo_id, trust_remote_code=True)
+ print(f"DEBUG: AutoProcessor loaded: {type(processor)}")
+
+ # Ensure tokenizer has pad_token set if it's GPT-2 based
+ if hasattr(processor, 'tokenizer') and processor.tokenizer is not None:
+ if getattr(processor.tokenizer, 'pad_token', None) is None: # Check if pad_token attribute exists and is None
+ processor.tokenizer.pad_token = processor.tokenizer.eos_token
+ print(f"DEBUG: Set processor.tokenizer.pad_token to eos_token (ID: {processor.tokenizer.eos_token_id})")
+ else:
+ print("DEBUG: Processor does not have a 'tokenizer' attribute or it is None.")
+
+
+ # Load model, like in generate.py
+ print(f"DEBUG: Loading model VisionLanguageModel.from_pretrained('{model_repo_id}')")
+ model = VisionLanguageModel.from_pretrained(model_repo_id).to(device)
+ print(f"DEBUG: VisionLanguageModel loaded: {type(model)}")
  model.eval()
+ print("DEBUG: Model set to eval() mode.")

  except Exception as e:
+ print(f"CRITICAL ERROR loading model or processor with AutoProcessor: {e}")
  import traceback
  traceback.print_exc()
+ processor = None; model = None
  else:
+ print("CRITICAL ERROR: VisionLanguageModel class not imported. Cannot load model.")


  # --- Text Generation Function ---
  def generate_text_for_image(image_input_pil: Optional[PILImage.Image], prompt_input_str: Optional[str]) -> str:
  print(f"DEBUG (generate_text_for_image): Received prompt: '{prompt_input_str}'")
+ if model is None or processor is None:
+ return "Error: Model or processor not loaded. Check logs."
+ if image_input_pil is None: return "Please upload an image."
+ if not prompt_input_str: return "Please provide a prompt."

  try:
+ current_pil_image = image_input_pil
  if not isinstance(current_pil_image, PILImage.Image):
+ current_pil_image = PILImage.fromarray(current_pil_image)
  if current_pil_image.mode != "RGB":
  current_pil_image = current_pil_image.convert("RGB")
+ print(f"DEBUG: Image prepped - size: {current_pil_image.size}, mode: {current_pil_image.mode}")
+
+ # Prepare inputs using the AutoProcessor, as in generate.py
+ print("DEBUG: Processing inputs with AutoProcessor...")
+ inputs = processor(
+ text=[prompt_input_str], images=current_pil_image, return_tensors="pt"
+ ).to(device)
+ print(f"DEBUG: Inputs from AutoProcessor - keys: {inputs.keys()}")
+ print(f"DEBUG: input_ids shape: {inputs['input_ids'].shape}, values: {inputs['input_ids']}")
+ print(f"DEBUG: pixel_values shape: {inputs['pixel_values'].shape}")

+ # Ensure attention_mask is present, default to ones if not (though AutoProcessor should provide it)
+ attention_mask = inputs.get('attention_mask')
+ if attention_mask is None:
+ print("WARN: attention_mask not found in processor output, creating a default one of all 1s.")
+ attention_mask = torch.ones_like(inputs['input_ids']).to(device)
+ print(f"DEBUG: attention_mask shape: {attention_mask.shape}")
+
+
+ print("DEBUG: Calling model.generate (aligning with nanoVLM's generate.py)...")
+ # Signature for nanoVLM's generate: (self, input_ids, image, attention_mask, max_new_tokens, ...)
+ # `image` parameter in generate() corresponds to `pixel_values` from processor output
  generated_ids_tensor = model.generate(
+ inputs['input_ids'], # 1st argument to model.generate: input_ids (text prompt)
+ inputs['pixel_values'], # 2nd argument to model.generate: image (pixel values)
+ attention_mask, # 3rd argument to model.generate: attention_mask
+ max_new_tokens=30, # Corresponds to 4th argument in model.generate
+ temperature=0.7, # Match generate.py default or your choice
+ top_k=50, # Match generate.py default or your choice
+ greedy=False # Match generate.py default or your choice
+ # top_p is also an option from generate.py's model.generate
  )
+ print(f"DEBUG: Raw generated_ids: {generated_ids_tensor}")
+
+ generated_text_list = processor.batch_decode(generated_ids_tensor, skip_special_tokens=True)
+ print(f"DEBUG: Decoded text list: {generated_text_list}")
+ generated_text_str = generated_text_list[0] if generated_text_list else ""

  cleaned_text_str = generated_text_str
  if prompt_input_str and generated_text_str.startswith(prompt_input_str):
  cleaned_text_str = generated_text_str[len(prompt_input_str):].lstrip(" ,.:")
+ print(f"DEBUG: Final cleaned text: '{cleaned_text_str}'")
  return cleaned_text_str.strip()

  except Exception as e:
+ print(f"CRITICAL ERROR during generation: {e}")
  import traceback
+ traceback.print_exc()
+ return f"Error during generation: {str(e)}"

+ # --- Gradio Interface ---
  description_md = """
+ ## Interactive nanoVLM-222M Demo (Mirroring generate.py)
+ Trying to replicate the working `generate.py` script from `huggingface/nanoVLM`.
+ Using AutoProcessor for inputs.
  """

  iface = None
+ if processor and model:
+ try:
+ iface = gr.Interface(
+ fn=generate_text_for_image,
+ inputs=[gr.Image(type="pil", label="Upload Image"), gr.Textbox(label="Your Prompt")],
+ outputs=gr.Textbox(label="Generated Text", show_copy_button=True),
+ title="nanoVLM-222M Demo (generate.py Alignment)",
+ description=description_md,
+ allow_flagging="never"
+ )
+ print("DEBUG: Gradio interface defined.")
+ except Exception as e:
+ print(f"CRITICAL ERROR defining Gradio interface: {e}")
+ import traceback; traceback.print_exc()
+

  if __name__ == "__main__":
+ if iface:
+ print("DEBUG: Launching Gradio...")
+ iface.launch(server_name="0.0.0.0", server_port=7860)
  else:
+ print("CRITICAL ERROR: Gradio interface not defined or model/processor failed to load. Cannot launch.")