vidhanm committed
Commit fb82462 · 1 Parent(s): 224188b

updated generate text for image

Files changed (1):
  app.py +27 -11
app.py CHANGED
@@ -70,43 +70,59 @@ def prepare_inputs(text_list, image_input, image_processor_instance, tokenizer_instance, device_to_use):
     attention_mask = processed_text.attention_mask.to(device_to_use)
     return {"pixel_values": processed_image, "input_ids": input_ids, "attention_mask": attention_mask}
 
-def generate_text_for_image(image_input, prompt_input):
+from typing import Optional
+from PIL import Image as PILImage  # Add at the top of your app.py
+
+# ... (other imports and model loading) ...
+
+def generate_text_for_image(image_input: Optional[PILImage.Image], prompt_input: Optional[str]) -> str:
     if model is None or image_processor is None or tokenizer is None:
         return "Error: Model or processor components not loaded correctly. Check logs."
     if image_input is None: return "Please upload an image."
     if not prompt_input: return "Please provide a prompt."
 
     try:
-        if not isinstance(image_input, Image.Image):
-            pil_image = Image.fromarray(image_input)
-        else:
-            pil_image = image_input
-        if pil_image.mode != "RGB": pil_image = pil_image.convert("RGB")
+        current_pil_image = image_input
+        if not isinstance(current_pil_image, PILImage.Image):
+            current_pil_image = PILImage.fromarray(current_pil_image)
+        if current_pil_image.mode != "RGB":
+            current_pil_image = current_pil_image.convert("RGB")
 
         inputs = prepare_inputs(
-            text_list=[prompt_input], image_input=pil_image,
+            text_list=[prompt_input], image_input=current_pil_image,
             image_processor_instance=image_processor, tokenizer_instance=tokenizer, device_to_use=device
         )
 
+        print(f"Debug: Passing to model.generate: pixel_values_shape={inputs['pixel_values'].shape}, input_ids_shape={inputs['input_ids'].shape}, attention_mask_shape={inputs['attention_mask'].shape}")
+
+        # Call model.generate with positional arguments matching nanoVLM's VisionLanguageModel.generate
         generated_ids = model.generate(
-            pixel_values=inputs['pixel_values'], input_ids=inputs['input_ids'],
-            attention_mask=inputs['attention_mask'], max_new_tokens=150, num_beams=3,
-            no_repeat_ngram_size=2, early_stopping=True, pad_token_id=tokenizer.pad_token_id
+            inputs['pixel_values'],    # pixel_values
+            inputs['input_ids'],       # prompt_token_ids
+            inputs['attention_mask'],  # attention_mask
+            150                        # max_new_tokens (as a positional argument)
+            # You can add temperature=..., top_k=... here if desired, as they are keyword args in nanoVLM's generate
         )
 
         generated_text_list = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
         generated_text = generated_text_list[0] if generated_text_list else ""
 
+        # Clean up prompt if it's echoed (optional, depends on model behavior)
         if prompt_input and generated_text.startswith(prompt_input):
             cleaned_text = generated_text[len(prompt_input):].lstrip(" ,.:")
         else:
             cleaned_text = generated_text
+
         return cleaned_text.strip()
+
     except Exception as e:
         print(f"Error during generation: {e}")
-        import traceback; traceback.print_exc()
+        import traceback
+        traceback.print_exc()
         return f"An error occurred during text generation: {str(e)}"
 
+# ... (rest of app.py)
+
 description = "Interactive demo for lusxvr/nanoVLM-222M."
 # example_image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # Not used for now
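The heart of this commit is the switch from Hugging Face-style keyword arguments (num_beams, early_stopping, pad_token_id) to positional arguments, presumably because nanoVLM's VisionLanguageModel.generate is a custom method rather than transformers' GenerationMixin.generate, so those keywords do not exist on it. A minimal stub of the signature the new call site assumes; parameter names and defaults here are reconstructions for illustration, not the actual nanoVLM source:

import torch

class VisionLanguageModel(torch.nn.Module):
    @torch.no_grad()
    def generate(self, image, prompt_token_ids, attention_mask,
                 max_new_tokens=50, temperature=1.0, top_k=None):
        # Autoregressive decoding conditioned on the image embedding; returns a
        # LongTensor of token ids that tokenizer.batch_decode can consume.
        ...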
 
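Only the last two lines of prepare_inputs are visible as context in this hunk. For orientation, a plausible reconstruction of the whole helper, inferred from those context lines and the keyword arguments at the call site; the processor and tokenizer calls are assumptions, not the file's actual code:

def prepare_inputs(text_list, image_input, image_processor_instance,
                   tokenizer_instance, device_to_use):
    # Assumed: the image processor is a torchvision-style transform mapping a
    # PIL image to a (C, H, W) tensor; unsqueeze adds the batch dimension.
    processed_image = image_processor_instance(image_input).unsqueeze(0).to(device_to_use)
    # Tokenize the prompt batch with padding so ids and mask line up.
    processed_text = tokenizer_instance(text_list, return_tensors="pt", padding=True)
    input_ids = processed_text.input_ids.to(device_to_use)
    attention_mask = processed_text.attention_mask.to(device_to_use)
    return {"pixel_values": processed_image, "input_ids": input_ids, "attention_mask": attention_mask}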
 
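The hunk stops just before the UI definition (description = ...), and the Gradio wiring itself is not part of this diff. A minimal sketch of how the updated handler might be hooked up, with component choices that are assumptions rather than the app's actual layout:

import gradio as gr

demo = gr.Interface(
    fn=generate_text_for_image,
    inputs=[
        # type="pil" hands the handler a PIL image, so the fromarray() branch
        # only fires for numpy arrays from other component configurations.
        gr.Image(type="pil", label="Image"),
        gr.Textbox(label="Prompt"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    description=description,
)

if __name__ == "__main__":
    demo.launch()

With type="pil", the isinstance check inside generate_text_for_image is just a safety net. A quick local test is generate_text_for_image(PILImage.open("example.jpg"), "Describe this image:"), where example.jpg stands in for any test image.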