vidhanm commited on
Commit
671ce94
·
1 Parent(s): de8d25e

updated parameter of generate.py in app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -25
app.py CHANGED
@@ -24,23 +24,22 @@ print(f"DEBUG: Using model repo ID: {MODEL_REPO_ID}")
24
 
25
  # In app.py
26
 
 
 
27
  def call_generate_script(image_path: str, prompt_text: str) -> str:
28
  print(f"\n--- DEBUG (call_generate_script) ---")
29
  print(f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}")
30
  print(f"Calling with image_path='{image_path}', prompt='{prompt_text}'")
31
 
32
- # Arguments for nanoVLM's generate.py, VERIFIED against its source code
33
  cmd_args = [
34
  "python", "-u", GENERATE_SCRIPT_PATH,
35
  "--hf_model", MODEL_REPO_ID,
36
- "--image_path", image_path, # VERIFIED: script expects --image_path
37
  "--prompt", prompt_text,
38
- "--num_samples", "1", # VERIFIED: script expects --num_samples
39
- "--max_new_tokens", "30", # This was correct
40
- "--device", "cpu" # VERIFIED: script expects --device
41
- # Optional args for generate.py that you can add if needed:
42
- # "--temperature", "0.7",
43
- # "--top_k", "200" # Default is 200 in script
44
  ]
45
 
46
  print(f"Executing command: {' '.join(cmd_args)}")
@@ -71,33 +70,37 @@ def call_generate_script(image_path: str, prompt_text: str) -> str:
71
 
72
  if process.returncode != 0:
73
  error_message = f"Error: Generation script failed (code {process.returncode})."
74
- if "unrecognized arguments" in stderr:
75
  error_message += " Argument mismatch with script."
 
 
76
  print(error_message)
77
- return error_message + f" STDERR Snippet: {stderr[:300]}" # Show more stderr
78
 
79
- # --- Parse the output from nanoVLM's generate.py ---
80
- # The original nanoVLM generate.py prints:
81
- # > Sample 1: <generated text>
82
  output_lines = stdout.splitlines()
83
  generated_text = "[No parsable output from generate.py]"
84
 
85
  found_output_line = False
86
  for line_idx, line in enumerate(output_lines):
87
  stripped_line = line.strip()
88
- # The actual generate.py from nanoVLM prints "> Sample 1:"
89
- prefix_to_remove = None
90
- if stripped_line.startswith("> Sample 1:"):
91
- prefix_to_remove = "> Sample 1:"
92
-
93
- if prefix_to_remove:
94
- generated_text = stripped_line.replace(prefix_to_remove, "", 1).strip()
95
- found_output_line = True
96
- print(f"Parsed generated text: '{generated_text}'")
97
- break
98
-
 
 
99
  if not found_output_line:
100
- print(f"Could not find '> Sample 1:' line in generate.py output. Raw STDOUT was:\n{stdout}")
101
  if stdout:
102
  generated_text = f"[Parsing failed] STDOUT: {stdout[:500]}"
103
  else:
 
24
 
25
  # In app.py
26
 
27
+ # In app.py
28
+
29
  def call_generate_script(image_path: str, prompt_text: str) -> str:
30
  print(f"\n--- DEBUG (call_generate_script) ---")
31
  print(f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}")
32
  print(f"Calling with image_path='{image_path}', prompt='{prompt_text}'")
33
 
34
+ # Arguments for the provided nanoVLM's generate.py
35
  cmd_args = [
36
  "python", "-u", GENERATE_SCRIPT_PATH,
37
  "--hf_model", MODEL_REPO_ID,
38
+ "--image", image_path, # VERIFIED: script uses --image
39
  "--prompt", prompt_text,
40
+ "--generations", "1", # VERIFIED: script uses --generations
41
+ "--max_new_tokens", "30" # This was correct
42
+ # No --device argument, as it's not in the provided generate.py
 
 
 
43
  ]
44
 
45
  print(f"Executing command: {' '.join(cmd_args)}")
 
70
 
71
  if process.returncode != 0:
72
  error_message = f"Error: Generation script failed (code {process.returncode})."
73
+ if "unrecognized arguments" in stderr: # This shouldn't happen now
74
  error_message += " Argument mismatch with script."
75
+ elif "syntax error" in stderr.lower():
76
+ error_message += " Syntax error in script."
77
  print(error_message)
78
+ return error_message + f" STDERR Snippet: {stderr[:300]}"
79
 
80
+ # --- Parse the output from the provided nanoVLM's generate.py ---
81
+ # The script prints:
82
+ # >> Generation {i+1}: {out}
83
  output_lines = stdout.splitlines()
84
  generated_text = "[No parsable output from generate.py]"
85
 
86
  found_output_line = False
87
  for line_idx, line in enumerate(output_lines):
88
  stripped_line = line.strip()
89
+ # Looking for the specific output format " >> Generation X: text"
90
+ if stripped_line.startswith(">> Generation 1:"): # Assuming we only care about the first generation
91
+ # Extract text after ">> Generation 1: " (note the space after colon)
92
+ try:
93
+ generated_text = stripped_line.split(">> Generation 1:", 1)[1].strip()
94
+ found_output_line = True
95
+ print(f"Parsed generated text: '{generated_text}'")
96
+ break
97
+ except IndexError:
98
+ print(f"Could not split line for '>> Generation 1:': '{stripped_line}'")
99
+ generated_text = f"[Parsing failed] Malformed 'Generation 1' line: {stripped_line}"
100
+ break
101
+
102
  if not found_output_line:
103
+ print(f"Could not find '>> Generation 1:' line in generate.py output. Raw STDOUT was:\n{stdout}")
104
  if stdout:
105
  generated_text = f"[Parsing failed] STDOUT: {stdout[:500]}"
106
  else: