File size: 4,739 Bytes
0c21c10
7c398ad
 
39c0bc7
 
 
0c21c10
39c0bc7
 
 
 
 
 
 
 
 
2904d0e
7c398ad
0c21c10
7c398ad
2904d0e
7c398ad
 
 
 
 
 
 
 
d21dd33
7c398ad
 
 
 
 
 
 
 
 
2904d0e
 
39c0bc7
 
 
7c398ad
 
39c0bc7
 
7c398ad
 
 
 
 
 
 
 
 
 
 
 
 
d21dd33
7c398ad
d21dd33
7c398ad
 
 
 
d21dd33
7c398ad
 
39c0bc7
 
0c21c10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39c0bc7
 
 
 
 
 
 
 
 
 
 
 
7c398ad
 
 
 
 
 
 
39c0bc7
0c21c10
 
 
 
 
 
 
 
 
 
 
 
 
 
39c0bc7
0c21c10
39c0bc7
7c398ad
 
 
 
 
 
 
 
 
39c0bc7
0c21c10
 
 
 
39c0bc7
0c21c10
39c0bc7
 
0c21c10
 
39c0bc7
7c398ad
39c0bc7
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import os
import socket
import pandas as pd
import gradio as gr

# import spaces #[uncomment to use ZeroGPU]
# from diffusers import DiffusionPipeline
import torch

# Pick the compute device and the matching tensor precision in one step:
# half precision when CUDA is available, full precision on CPU.
device, torch_dtype = (
    ("cuda", torch.float16)
    if torch.cuda.is_available()
    else ("cpu", torch.float32)
)


from dataset import category2human, create_prompt

# Hostnames treated as "local" development machines (compared case-insensitively).
LOCAL_COMPUTER_NAMES = ["amir-xps"]


def is_local_machine():
    """Return True when this host's name is listed in LOCAL_COMPUTER_NAMES.

    The comparison ignores letter case on both sides.
    """
    current_host = socket.gethostname().lower()
    return any(current_host == known.lower() for known in LOCAL_COMPUTER_NAMES)


# On a known local machine use the checkpoint directory under the user's home;
# everywhere else use the "Hacker1337/distilbert-arxiv-checkpoint" identifier.
model_path = (
    os.path.expanduser("~/.cache/huggingface/checkpoints/distilbert-arxiv2")
    if is_local_machine()
    else "Hacker1337/distilbert-arxiv-checkpoint"
)

from transformers import pipeline

# Text-classification pipeline shared by every request. It is created once at
# import time; both the model and the tokenizer are loaded from `model_path`.
classifier = pipeline(
    "text-classification",
    model=model_path,
    tokenizer=model_path,
)


# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    title_prompt,
    summary_prompt,
    progress=gr.Progress(track_tqdm=True),
):
    """Classify a paper and build the data for both output widgets.

    Args:
        title_prompt: Paper title entered by the user.
        summary_prompt: Paper abstract entered by the user.
        progress: Gradio progress tracker supplied through the default
            argument; it is not referenced in the body.

    Returns:
        A ``(df, label_dict)`` tuple:

        * ``df`` -- DataFrame built from the classifier output, with the
          ``label`` column mapped to human-readable names via
          ``category2human``.
        * ``label_dict`` -- mapping of human-readable label to normalized
          score for the highest-scoring categories that together cover
          ``target_probs_sum`` of the total score. An ``"Other"`` entry holds
          the remaining share, and is only present when that share is
          non-negligible.
    """
    target_probs_sum = 0.95
    # Shares below this are treated as floating-point noise, not a real
    # "Other" bucket.
    tolerance = 1e-5

    sample_prompt_full = create_prompt(
        title_prompt,
        summary_prompt,
    )
    # top_k=None asks the pipeline for the scores of all labels.
    predictions = classifier(sample_prompt_full, top_k=None)
    print(predictions)

    df = pd.DataFrame(predictions)
    df["label"] = df["label"].apply(lambda x: category2human[x])

    total_prob = sum(prediction["score"] for prediction in predictions)
    label_dict = {}
    gained_prob = 0
    for prediction in sorted(predictions, key=lambda x: x["score"], reverse=True):
        # Scores are visited in descending order, so once the target share is
        # covered nothing further can be added.
        if gained_prob / total_prob >= target_probs_sum:
            break
        label_dict[category2human[prediction["label"]]] = (
            prediction["score"] / total_prob
        )
        gained_prob += prediction["score"]

    # Report the leftover mass only when there actually is some; the previous
    # check (`gained < total + eps`) was always true and produced an "Other"
    # entry even when every category had already been listed.
    remaining_prob = total_prob - gained_prob
    if remaining_prob > tolerance:
        label_dict["Other"] = remaining_prob / total_prob
    return df, label_dict


# Example inputs offered in the UI. Titles and abstracts are kept in separate
# lists because each list is passed to its own gr.Examples component.
examples_titles = [
    "Survey on Semantic Stereo Matching",
]
examples_summaries = [
    """Stereo matching is one of the widely used techniques for inferring depth from
stereo images owing to its robustness and speed. It has become one of the major
topics of research since it finds its applications in autonomous driving,
robotic navigation, 3D reconstruction, and many other fields. Finding pixel
correspondences in non-textured, occluded and reflective areas is the major
challenge in stereo matching. Recent developments have shown that semantic cues
from image segmentation can be used to improve the results of stereo matching.
Many deep neural network architectures have been proposed to leverage the
advantages of semantic segmentation in stereo matching. This paper aims to give
a comparison among the state of art networks both in terms of accuracy and in
terms of speed which are of higher importance in real-time applications.""",
]


# Custom stylesheet passed to gr.Blocks: centers the element with id
# "col-container" and caps its width at 640px.
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

# UI layout: one centered column holding the description, the two text inputs,
# the run button, both result widgets and the clickable examples.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Heading describes this app (it previously carried the unrelated
        # template title).
        gr.Markdown("# Scientific Paper Category Classifier")
        gr.Markdown(
            # Trailing space keeps the two concatenated sentences apart.
            "This space classifies scientific machine learning papers into categories based on their title and abstract. "
            "Bar plot shows probabilities of belonging to each category."
        )
        gr.Markdown(
            # The backslash is kept so the Markdown source is unchanged; it is
            # written as `\\%` because a bare `\%` is an invalid escape sequence.
            "Second thing predicts most probable single class classification. It shows only first 95\\% of categories."
        )

        title_prompt = gr.Text(
            label="Title Prompt",
            show_label=False,
            max_lines=1,
            placeholder="Enter paper's title",
            container=False,
        )
        summary_prompt = gr.Text(
            label="Summary Prompt",
            show_label=False,
            max_lines=10,
            placeholder="Enter paper's abstract",
            container=False,
        )

        run_button = gr.Button("Run", scale=0, variant="primary")

        # Receives the DataFrame returned by `infer` ("label" / "score" columns).
        result_bar = gr.BarPlot(
            label="Multi class classification",
            show_label=True,
            x="label",
            y="score",
            x_label_angle=30,
        )

        # Receives the label -> normalized score mapping returned by `infer`.
        result_label = gr.Label(label="Single class selection")

        gr.Examples(examples=examples_titles, inputs=[title_prompt])
        gr.Examples(examples=examples_summaries, inputs=[summary_prompt])
    # Run inference on a button click or when either text box is submitted.
    gr.on(
        triggers=[run_button.click, title_prompt.submit, summary_prompt.submit],
        fn=infer,
        inputs=[
            title_prompt,
            summary_prompt,
        ],
        outputs=[result_bar, result_label],
    )

# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()