RingL committed on
Commit 6048ff0 · 1 Parent(s): 0adea00
Files changed (2):
  1. app.py +155 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,155 @@
+ import gradio as gr
+ import spaces
+ from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor
+ from qwen_vl_utils import process_vision_info
+ import torch
+ from PIL import Image
+ from datetime import datetime
+ import numpy as np
+ import os
+
+
+ DESCRIPTION = """
+ # VisQA Demo
+ """
+
+ model_id = "Qwen/Qwen2-VL-7B-Instruct"
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
+     model_id,
+     device_map="auto",
+     torch_dtype=torch.bfloat16,
+ )
+ adapter_path = "sergiopaniego/qwen2-7b-instruct-trl-sft-ChartQA"
+ model.load_adapter(adapter_path)
+ processor = Qwen2VLProcessor.from_pretrained(model_id)
+
+
+ def array_to_image_path(image_array):
+     if image_array is None:
+         raise ValueError("No image provided. Please upload an image before submitting.")
+     # Convert the numpy array coming from gr.Image into a PIL Image
+     img = Image.fromarray(np.uint8(image_array))
+
+     # Generate a unique filename using a timestamp
+     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+     filename = f"image_{timestamp}.png"
+
+     # Save the image and return its absolute path
+     img.save(filename)
+     full_path = os.path.abspath(filename)
+
+     return full_path
+
+
+ @spaces.GPU
+ def run_example(image, text_input=None):
+     image_path = array_to_image_path(image)
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {
+                     "type": "image",
+                     "image": image_path,
+                 },
+                 {
+                     "type": "text",
+                     "text": text_input,
+                 },
+             ],
+         }
+     ]
+
+     # Preparation for inference
+     text = processor.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+     image_inputs, video_inputs = process_vision_info(messages)
+     inputs = processor(
+         text=[text],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt",
+     )
+     inputs = inputs.to("cuda")
+
+     # Inference: generate, then strip the prompt tokens from the output
+     generated_ids = model.generate(**inputs, max_new_tokens=1024)
+     generated_ids_trimmed = [
+         out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+     ]
+     output_text = processor.batch_decode(
+         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+     )
+
+     return output_text[0]
+
+
+ css = """
+ #output {
+     height: 500px;
+     overflow: auto;
+     border: 1px solid #ccc;
+ }
+ """
+
+ with gr.Blocks(css=css) as demo:
+     gr.Markdown(DESCRIPTION)
+     with gr.Tab(label="Qwen2-VL-7B-trl-sft-ChartQA Input"):
+         with gr.Row():
+             with gr.Column():
+                 input_img = gr.Image(label="Input Picture")
+                 text_input = gr.Textbox(label="Question")
+                 submit_btn = gr.Button(value="Submit")
+             with gr.Column():
+                 output_text = gr.Textbox(label="Output Text")
+
+         submit_btn.click(run_example, [input_img, text_input], [output_text])
+
+ demo.queue(api_open=False)
+ demo.launch(debug=True)
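
Because run_example blocks until model.generate finishes, long answers only appear at the very end. A streamed variant can surface tokens as they are produced. The sketch below is not part of the commit: it reuses the commit's model, processor, and array_to_image_path, relies on transformers' TextIteratorStreamer, and assumes the @spaces.GPU decorator accepts generator functions (Gradio renders whatever a generator yields into the output textbox incrementally).

    from threading import Thread
    from transformers import TextIteratorStreamer

    @spaces.GPU
    def run_example_streaming(image, text_input=None):
        # Same preprocessing as run_example, up to and including the processor call.
        image_path = array_to_image_path(image)
        messages = [{"role": "user", "content": [
            {"type": "image", "image": image_path},
            {"type": "text", "text": text_input},
        ]}]
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text], images=image_inputs, videos=video_inputs,
            padding=True, return_tensors="pt",
        ).to("cuda")

        # The streamer yields decoded text chunks as tokens are generated;
        # generation runs on a background thread so we can iterate here.
        streamer = TextIteratorStreamer(
            processor.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        thread = Thread(
            target=model.generate,
            kwargs=dict(**inputs, streamer=streamer, max_new_tokens=1024),
        )
        thread.start()

        partial = ""
        for chunk in streamer:
            partial += chunk
            yield partial  # Gradio updates the Textbox on every yield

Wiring it in would be the same one-liner as in the commit: submit_btn.click(run_example_streaming, [input_img, text_input], [output_text]). A separate, smaller nit: array_to_image_path writes timestamped PNGs into the working directory, so files accumulate over the Space's lifetime; tempfile.NamedTemporaryFile would be the usual way to avoid that.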
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ numpy==1.24.4
+ Pillow==10.3.0
+ Requests==2.31.0
+ torch
+ torchvision
+ git+https://github.com/huggingface/transformers.git
+ accelerate
+ qwen-vl-utils
+ peft
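
For a local run, pip install -r requirements.txt pulls everything in; the peft pin is what backs the model.load_adapter call in app.py. As a hedged alternative, the same adapter can be attached explicitly through peft and, assuming it is a LoRA-style adapter (merge_and_unload only applies to merge-able adapter types), folded into the base weights so inference pays no adapter overhead:

    import torch
    from peft import PeftModel
    from transformers import Qwen2VLForConditionalGeneration

    # Load the base model exactly as app.py does.
    base = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

    # Wrap the base model with the fine-tuned adapter from the Hub.
    peft_model = PeftModel.from_pretrained(
        base, "sergiopaniego/qwen2-7b-instruct-trl-sft-ChartQA"
    )

    # Optional: merge the adapter weights into the base model
    # (assumes a LoRA-style adapter; returns a plain transformers model).
    model = peft_model.merge_and_unload()

The merged model drops in wherever app.py uses model, with identical outputs up to numerical precision.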