khang119966 commited on
Commit
f10b987
·
verified ·
1 Parent(s): 3fcdc10

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -75
app.py CHANGED
@@ -6,12 +6,11 @@ import os
6
  import tempfile
7
  from PIL import Image
8
 
9
- # --- Tải Model Tokenizer (Chỉ một lần khi khởi động) ---
10
- # Di chuyển việc tải model ra ngoài để tránh tải lại mỗi lần gọi hàm
11
  print("Loading model and tokenizer...")
12
  model_name = "deepseek-ai/DeepSeek-OCR"
13
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
14
- # Tải model lên CPU trước, sau đó chuyển sang GPU trong hàm xử lý
15
  model = AutoModel.from_pretrained(
16
  model_name,
17
  _attn_implementation="flash_attention_2",
@@ -19,51 +18,50 @@ model = AutoModel.from_pretrained(
19
  use_safetensors=True,
20
  )
21
  model = model.eval()
22
- print("Model loaded successfully.")
23
 
24
 
25
- # --- Hàm xử chính ---
26
  @spaces.GPU
27
  def process_ocr_task(image, model_size, task_type, ref_text):
28
  """
29
- Xử hình ảnh với DeepSeek-OCR cho tất cả các tác vụ.
30
  Args:
31
- image: Đối tượng PIL Image
32
- model_size: Cấu hình kích thước model
33
- task_type: Loại tác vụ OCR
34
- ref_text: Văn bản tham chiếu cho tác vụ 'Locate'
35
  """
36
  if image is None:
37
  return "Please upload an image first.", None
38
 
39
- # Chuyển model sang GPU định dạng bfloat16 để tối ưu hiệu suất
40
- print("Moving model to GPU...")
41
  model_gpu = model.cuda().to(torch.bfloat16)
42
- print("Model on GPU.")
43
 
44
- # Tạo thư mục tạm thời để lưu trữ đầu ra
45
  with tempfile.TemporaryDirectory() as output_path:
46
- # --- Xây dựng prompt dựa trên loại tác vụ ---
47
- if task_type == "Free OCR":
48
  prompt = "<image>\nFree OCR."
49
- elif task_type == "Convert to Markdown":
50
  prompt = "<image>\n<|grounding|>Convert the document to markdown."
51
- elif task_type == "Parse Figure":
52
  prompt = "<image>\nParse the figure."
53
- elif task_type == "Locate Object by Reference":
54
  if not ref_text or ref_text.strip() == "":
55
- raise gr.Error("For 'Locate' task, please provide the reference text to find.")
56
- # Sử dụng f-string để chèn văn bản tham chiếu vào prompt
57
  prompt = f"<image>\nLocate <|ref|>{ref_text.strip()}<|/ref|> in the image."
58
  else:
59
- # Mặc định Free OCR
60
- prompt = "<image>\nFree OCR."
61
 
62
- # Lưu ảnh được tải lên vào thư mục tạm
63
  temp_image_path = os.path.join(output_path, "temp_image.png")
64
  image.save(temp_image_path)
65
 
66
- # Cấu hình các tham số kích thước model
67
  size_configs = {
68
  "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
69
  "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
@@ -73,8 +71,8 @@ def process_ocr_task(image, model_size, task_type, ref_text):
73
  }
74
  config = size_configs.get(model_size, size_configs["Gundam (Recommended)"])
75
 
76
- print(f"Running inference with prompt: {prompt}")
77
- # --- Chạy inference ---
78
  text_result = model_gpu.infer(
79
  tokenizer,
80
  prompt=prompt,
@@ -83,120 +81,119 @@ def process_ocr_task(image, model_size, task_type, ref_text):
83
  base_size=config["base_size"],
84
  image_size=config["image_size"],
85
  crop_mode=config["crop_mode"],
86
- save_results=True, # Quan trọng: phải lưu kết quả để lấy ảnh output
87
  test_compress=True,
88
  eval_mode=True,
89
  )
90
 
91
- print(f"====\nText Result: {text_result}\n====")
92
 
93
- # --- Xử output (văn bản hình ảnh) ---
94
  image_result_path = None
95
- # Tác vụ 'Locate' 'Markdown' thường tạo ra ảnh kết quả có chữ 'grounding'
96
- if task_type in ["Locate Object by Reference", "Convert to Markdown", "Parse Figure"]:
97
- # Tìm file ảnh kết quả trong thư mục output
98
  for filename in os.listdir(output_path):
99
  if "grounding" in filename or "result" in filename:
100
  image_result_path = os.path.join(output_path, filename)
101
  break
102
 
103
- # Nếu tìm thấy ảnh, tải nó, nếu không trả về None
104
  result_image_pil = Image.open(image_result_path) if image_result_path else None
105
 
106
  return text_result, result_image_pil
107
 
108
 
109
- # --- Xây dựng giao diện Gradio ---
110
- with gr.Blocks(title="DeepSeek-OCR", theme=gr.themes.Soft()) as demo:
111
  gr.Markdown(
112
  """
113
- # Demo toàn diện DeepSeek-OCR
114
- Tải lên một hình ảnh để thử nghiệm các khả năng nhận dạng và hiểu tài liệu của DeepSeek-OCR.
115
 
116
- **Hướng dẫn:**
117
- 1. Tải lên một hình ảnh.
118
- 2. Chọn **Model Size** phù hợp (Gundam được khuyến nghị cho tài liệu).
119
- 3. Chọn **Task Type**:
120
- - **Free OCR**: Trích xuất văn bản thô.
121
- - **Convert to Markdown**: Chuyển đổi tài liệu (giữ cấu trúc) sang định dạng Markdown.
122
- - **Parse Figure**: Phân tích trích xuất dữ liệu từ biểu đồ, hình vẽ.
123
- - **Locate Object by Reference**: Tìm một đối tượng hoặc văn bản cụ thể trong ảnh. **Bạn cần nhập nội dung cần tìm vào ô "Reference Text" bên dưới.**
124
  """
125
  )
126
 
127
  with gr.Row():
128
  with gr.Column(scale=1):
129
- image_input = gr.Image(type="pil", label="Tải ảnh lên", sources=["upload", "clipboard"])
130
 
131
  model_size = gr.Dropdown(
132
  choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
133
  value="Gundam (Recommended)",
134
- label="Model Size",
135
  )
136
 
137
  task_type = gr.Dropdown(
138
- choices=["Free OCR", "Convert to Markdown", "Parse Figure", "Locate Object by Reference"],
139
- value="Convert to Markdown",
140
- label="Task Type",
141
  )
142
 
143
- # Ô nhập văn bản tham chiếu, ban đầu bị ẩn
144
  ref_text_input = gr.Textbox(
145
- label="Reference Text (cho tác vụ Locate)",
146
- placeholder=" dụ: the teacher, 11-2=, a red car...",
147
- visible=False, # Ban đầu ẩn đi
148
  )
149
 
150
- submit_btn = gr.Button("Xử ", variant="primary")
151
 
152
  with gr.Column(scale=2):
153
- output_text = gr.Textbox(label="Kết quả văn bản", lines=15, show_copy_button=True)
154
- output_image = gr.Image(label="Kết quả hình ảnh (nếu )", type="pil")
155
 
156
- # --- Logic tương tác cho giao diện ---
157
  def toggle_ref_text_visibility(task):
158
- # Nếu người dùng chọn 'Locate', hiển thị ô nhập văn bản
159
- if task == "Locate Object by Reference":
160
  return gr.Textbox(visible=True)
161
  else:
162
  return gr.Textbox(visible=False)
163
 
164
- # Khi dropdown 'task_type' thay đổi, gọi hàm để cập nhật trạng thái hiển thị của ô ref_text_input
165
  task_type.change(
166
  fn=toggle_ref_text_visibility,
167
  inputs=task_type,
168
  outputs=ref_text_input,
169
  )
170
 
171
- # Khi nhấn nút submit
172
  submit_btn.click(
173
  fn=process_ocr_task,
174
  inputs=[image_input, model_size, task_type, ref_text_input],
175
  outputs=[output_text, output_image],
176
  )
177
 
178
- # --- Các dụ minh họa ---
179
  gr.Examples(
180
  examples=[
181
- ["./examples/doc_markdown.png", "Gundam (Recommended)", "Convert to Markdown", ""],
182
- ["./examples/chart.png", "Gundam (Recommended)", "Parse Figure", ""],
183
- ["./examples/teacher.png", "Base", "Locate Object by Reference", "the teacher"],
184
- ["./examples/math_locate.png", "Small", "Locate Object by Reference", "11-2="],
185
- ["./examples/receipt.jpg", "Base", "Free OCR", ""],
186
  ],
187
  inputs=[image_input, model_size, task_type, ref_text_input],
188
  outputs=[output_text, output_image],
189
  fn=process_ocr_task,
190
- cache_examples=False, # Tắt cache để đảm bảo chạy lại mỗi lần click
191
  )
192
 
193
- # --- Khởi chạy ứng dụng ---
194
  if __name__ == "__main__":
195
- # Tạo thư mục examples tải ảnh dụ (nếu chưa có)
196
  if not os.path.exists("examples"):
197
  os.makedirs("examples")
198
- # Bạn cần tự tải các file ảnh dụ vào thư mục "examples"
199
- # dụ: doc_markdown.png, chart.png, teacher.png, math_locate.png, receipt.jpg
200
 
201
  demo.queue(max_size=20)
202
- demo.launch(share=True) # share=True để tạo link public
 
6
  import tempfile
7
  from PIL import Image
8
 
9
+ # --- 1. Load Model and Tokenizer (Done only once at startup) ---
 
10
  print("Loading model and tokenizer...")
11
  model_name = "deepseek-ai/DeepSeek-OCR"
12
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
13
+ # Load the model to CPU first; it will be moved to GPU during processing
14
  model = AutoModel.from_pretrained(
15
  model_name,
16
  _attn_implementation="flash_attention_2",
 
18
  use_safetensors=True,
19
  )
20
  model = model.eval()
21
+ print("Model loaded successfully.")
22
 
23
 
24
+ # --- 2. Main Processing Function ---
25
  @spaces.GPU
26
  def process_ocr_task(image, model_size, task_type, ref_text):
27
  """
28
+ Processes an image with DeepSeek-OCR for all supported tasks.
29
  Args:
30
+ image (PIL.Image): The input image.
31
+ model_size (str): The model size configuration.
32
+ task_type (str): The type of OCR task to perform.
33
+ ref_text (str): The reference text for the 'Locate' task.
34
  """
35
  if image is None:
36
  return "Please upload an image first.", None
37
 
38
+ # Move the model to GPU and use bfloat16 for better performance
39
+ print("🚀 Moving model to GPU...")
40
  model_gpu = model.cuda().to(torch.bfloat16)
41
+ print("Model is on GPU.")
42
 
43
+ # Create a temporary directory to store files
44
  with tempfile.TemporaryDirectory() as output_path:
45
+ # --- Build the prompt based on the selected task type ---
46
+ if task_type == "📝 Free OCR":
47
  prompt = "<image>\nFree OCR."
48
+ elif task_type == "📄 Convert to Markdown":
49
  prompt = "<image>\n<|grounding|>Convert the document to markdown."
50
+ elif task_type == "📈 Parse Figure":
51
  prompt = "<image>\nParse the figure."
52
+ elif task_type == "🔍 Locate Object by Reference":
53
  if not ref_text or ref_text.strip() == "":
54
+ raise gr.Error("For the 'Locate' task, you must provide the reference text to find!")
55
+ # Use an f-string to embed the user's reference text into the prompt
56
  prompt = f"<image>\nLocate <|ref|>{ref_text.strip()}<|/ref|> in the image."
57
  else:
58
+ prompt = "<image>\nFree OCR." # Default fallback
 
59
 
60
+ # Save the uploaded image to the temporary path
61
  temp_image_path = os.path.join(output_path, "temp_image.png")
62
  image.save(temp_image_path)
63
 
64
+ # Configure model size parameters
65
  size_configs = {
66
  "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
67
  "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
 
71
  }
72
  config = size_configs.get(model_size, size_configs["Gundam (Recommended)"])
73
 
74
+ print(f"🏃 Running inference with prompt: {prompt}")
75
+ # --- Run the model's inference method ---
76
  text_result = model_gpu.infer(
77
  tokenizer,
78
  prompt=prompt,
 
81
  base_size=config["base_size"],
82
  image_size=config["image_size"],
83
  crop_mode=config["crop_mode"],
84
+ save_results=True, # Important: Must be True to get the output image
85
  test_compress=True,
86
  eval_mode=True,
87
  )
88
 
89
+ print(f"====\n📄 Text Result: {text_result}\n====")
90
 
91
+ # --- Handle the output (both text and image) ---
92
  image_result_path = None
93
+ # Tasks that generate a visual output usually create a 'grounding' or 'result' image
94
+ if task_type in ["🔍 Locate Object by Reference", "📄 Convert to Markdown", "📈 Parse Figure"]:
95
+ # Find the result image in the output directory
96
  for filename in os.listdir(output_path):
97
  if "grounding" in filename or "result" in filename:
98
  image_result_path = os.path.join(output_path, filename)
99
  break
100
 
101
+ # If an image was found, open it with PIL; otherwise, return None
102
  result_image_pil = Image.open(image_result_path) if image_result_path else None
103
 
104
  return text_result, result_image_pil
105
 
106
 
107
+ # --- 3. Build the Gradio Interface ---
108
+ with gr.Blocks(title="🐳DeepSeek-OCR🐳", theme=gr.themes.Soft()) as demo:
109
  gr.Markdown(
110
  """
111
+ # 🐳 Full Demo of DeepSeek-OCR 🐳
112
+ Upload an image to explore the document recognition and understanding capabilities of DeepSeek-OCR.
113
 
114
+ **💡 How to use:**
115
+ 1. **Upload an image** using the upload box.
116
+ 2. Select a **Model Size**. `Gundam` is recommended for most documents for a good balance of speed and accuracy.
117
+ 3. Choose a **Task Type**:
118
+ - **📝 Free OCR**: Extracts raw text from the image. Best for simple text extraction.
119
+ - **📄 Convert to Markdown**: Converts the entire document into Markdown format, preserving structure like headers, lists, and tables.
120
+ - **📈 Parse Figure**: Analyzes and extracts structured data from charts, graphs, and geometric figures.
121
+ - **🔍 Locate Object by Reference**: Finds a specific object or piece of text in the image. You **must** type what you're looking for into the **"Reference Text"** box that appears.
122
  """
123
  )
124
 
125
  with gr.Row():
126
  with gr.Column(scale=1):
127
+ image_input = gr.Image(type="pil", label="🖼️ Upload Image", sources=["upload", "clipboard"])
128
 
129
  model_size = gr.Dropdown(
130
  choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
131
  value="Gundam (Recommended)",
132
+ label="⚙️ Model Size",
133
  )
134
 
135
  task_type = gr.Dropdown(
136
+ choices=["📝 Free OCR", "📄 Convert to Markdown", "📈 Parse Figure", "🔍 Locate Object by Reference"],
137
+ value="📄 Convert to Markdown",
138
+ label="🚀 Task Type",
139
  )
140
 
 
141
  ref_text_input = gr.Textbox(
142
+ label="📝 Reference Text (for Locate task)",
143
+ placeholder="e.g., the teacher, 11-2=, a red car...",
144
+ visible=False, # Initially hidden
145
  )
146
 
147
+ submit_btn = gr.Button("Process Image", variant="primary")
148
 
149
  with gr.Column(scale=2):
150
+ output_text = gr.Textbox(label="📄 Text Result", lines=15, show_copy_button=True)
151
+ output_image = gr.Image(label="🖼️ Image Result (if any)", type="pil")
152
 
153
+ # --- UI Interaction Logic ---
154
  def toggle_ref_text_visibility(task):
155
+ # If the user selects the 'Locate' task, make the reference textbox visible
156
+ if task == "🔍 Locate Object by Reference":
157
  return gr.Textbox(visible=True)
158
  else:
159
  return gr.Textbox(visible=False)
160
 
161
+ # When the 'task_type' dropdown changes, call the function to update the visibility
162
  task_type.change(
163
  fn=toggle_ref_text_visibility,
164
  inputs=task_type,
165
  outputs=ref_text_input,
166
  )
167
 
168
+ # Define what happens when the submit button is clicked
169
  submit_btn.click(
170
  fn=process_ocr_task,
171
  inputs=[image_input, model_size, task_type, ref_text_input],
172
  outputs=[output_text, output_image],
173
  )
174
 
175
+ # --- Example Images and Tasks ---
176
  gr.Examples(
177
  examples=[
178
+ ["./examples/doc_markdown.png", "Gundam (Recommended)", "📄 Convert to Markdown", ""],
179
+ ["./examples/chart.png", "Gundam (Recommended)", "📈 Parse Figure", ""],
180
+ ["./examples/teacher.png", "Base", "🔍 Locate Object by Reference", "the teacher"],
181
+ ["./examples/math_locate.png", "Small", "🔍 Locate Object by Reference", "11-2="],
182
+ ["./examples/receipt.jpg", "Base", "📝 Free OCR", ""],
183
  ],
184
  inputs=[image_input, model_size, task_type, ref_text_input],
185
  outputs=[output_text, output_image],
186
  fn=process_ocr_task,
187
+ cache_examples=False, # Disable caching to ensure examples run every time
188
  )
189
 
190
+ # --- 4. Launch the App ---
191
  if __name__ == "__main__":
192
+ # Create an 'examples' directory if it doesn't exist
193
  if not os.path.exists("examples"):
194
  os.makedirs("examples")
195
+ # Please manually download the example images into the "examples" folder.
196
+ # e.g., doc_markdown.png, chart.png, teacher.png, math_locate.png, receipt.jpg
197
 
198
  demo.queue(max_size=20)
199
+ demo.launch(share=True) # Set share=True to create a public link