File size: 3,329 Bytes
fd2651a
9fecce5
 
29bc91e
d342d8b
fd2651a
fc93bc3
 
 
 
 
 
 
9fecce5
fc93bc3
9fecce5
fd2651a
d342d8b
600f2a3
 
 
 
 
 
 
 
29bc91e
 
 
600f2a3
 
 
 
8ead0a1
fc93bc3
 
8ead0a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e81c28c
 
9fecce5
 
 
e81c28c
9fecce5
e81c28c
 
fc93bc3
9fecce5
cde52cf
80cd182
 
9fecce5
 
80cd182
 
9fecce5
29bc91e
80cd182
fc93bc3
fd2651a
cde52cf
 
fd2651a
80cd182
cde52cf
80cd182
31763a0
80cd182
fd2651a
3c6c945
 
8afd5a6
5d74f96
12c4ba7
3c6c945
 
b2c4d0a
 
3c6c945
 
 
 
 
07aab95
3c6c945
 
de035a6
29bc91e
600f2a3
 
ed8c92d
600f2a3
 
 
 
 
 
 
 
 
29bc91e
fd2651a
fc93bc3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import gradio as gr
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from qwen_vl_utils import process_vision_info
from threading import Thread
import spaces

# Hugging Face Hub repo id of the checkpoint to serve.
file_path = "csfufu/Revisual-R1-final"

# Processor bundles the tokenizer/chat template with image preprocessing.
# min/max_pixels bound how far each input image is rescaled
# (presumably expressed in multiples of the 28x28 vision patch area — TODO confirm).
processor = AutoProcessor.from_pretrained(
    file_path,
    min_pixels=256*28*28,
    max_pixels=1280*28*28
)
# Load weights with auto-selected dtype and automatic device placement
# (spreads across available GPUs / falls back to CPU).
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    file_path, torch_dtype="auto", device_map="auto"
)

@spaces.GPU
def respond(
    input_dict,
    chat_history,
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a model reply for one multimodal chat turn.

    Args:
        input_dict: Gradio MultimodalTextbox value, ``{"text": str, "files": [path, ...]}``.
        chat_history: prior turns in Gradio "messages" format; ``content`` is either
            a string (text turn) or a sequence of image file paths.
        system_message: system prompt prepended to the conversation.
        max_tokens: maximum number of new tokens to generate.
        temperature: sampling temperature.
        top_p: nucleus-sampling probability cutoff.

    Yields:
        The reply text accumulated so far, once per streamed chunk.
    """
    text = input_dict["text"]
    files = input_dict["files"]

    messages = [{
        "role": "system",
        "content": system_message
    }]

    # Re-encode the Gradio history into the Qwen chat-template message schema.
    for message in chat_history:
        if isinstance(message["content"], str):
            messages.append({
                "role": message["role"],
                "content": [
                    {"type": "text", "text": message["content"]},
                ],
            })
        else:
            # Non-string content: a sequence of image paths from earlier turns.
            messages.append({
                "role": message["role"],
                "content": [
                    {"type": "image", "image": image}
                    for image in message["content"]
                ],
            })

    # Current user turn: the typed text plus any freshly attached images.
    messages.append(
        {
            "role": "user",
            "content": [
                {"type": "text", "text": text},
                *[{"type": "image", "image": image} for image in files],
            ],
        }
    )

    image_inputs, video_inputs = process_vision_info(messages)
    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[prompt],
        images=image_inputs,
        videos=video_inputs,
        return_tensors="pt",
        padding=True,
    ).to(model.device)

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        # do_sample=True is required: temperature/top_p are ignored under
        # greedy decoding, so the UI sliders would otherwise have no effect.
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )

    # Run generation on a worker thread so we can consume the streamer here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer

    # Make sure the worker has fully finished before the generator returns.
    thread.join()

# Chat pane in "messages" format; raw HTML passes through unsanitized so the
# model's markup is rendered as-is.
_chatbot_pane = gr.Chatbot(
    type='messages',
    # allow_tags=['think'],
    sanitize_html=False,
    scale=1,
)

# Generation controls rendered beneath the chat box; order matches the extra
# positional parameters of `respond`.
_extra_controls = [
    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
    gr.Slider(minimum=1, maximum=8192, value=512, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(
        minimum=0.1,
        maximum=1.0,
        value=0.95,
        step=0.05,
        label="Top-p (nucleus sampling)",
    ),
]

demo = gr.ChatInterface(
    fn=respond,
    title='Revisual-R1',
    type='messages',
    multimodal=True,
    chatbot=_chatbot_pane,
    additional_inputs=_extra_controls,
    examples=[[{
            "text": "Solve this question.",
            "files": [ "example.png" ]
        }
    ]],
    cache_examples=False,
)

demo.launch(debug=True)