# Captured from a Hugging Face Space page (Space status at capture: Sleeping).
# File size: 4,739 bytes.
import os
import socket
import pandas as pd
import gradio as gr
# import spaces #[uncomment to use ZeroGPU]
# from diffusers import DiffusionPipeline
import torch
# Select the compute device and a matching float precision: half precision
# when CUDA is available, full precision on CPU.
# NOTE(review): neither `device` nor `torch_dtype` is referenced elsewhere in
# this file (the pipeline below is built without them) - likely template
# leftovers; confirm before relying on them.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
from dataset import category2human, create_prompt
# Hostnames (matched case-insensitively) of machines that load the model
# from a local checkpoint directory instead of by repository name.
LOCAL_COMPUTER_NAMES = ["amir-xps"]


def is_local_machine():
    """Return True if this host's name is listed in LOCAL_COMPUTER_NAMES.

    Both the current hostname and the listed names are lower-cased before
    comparing, so the check ignores case.
    """
    hostname = socket.gethostname().lower()
    return any(hostname == known.lower() for known in LOCAL_COMPUTER_NAMES)


# On a listed machine use the checkpoint under ~/.cache/huggingface; anywhere
# else use the repository id (presumably resolved from the Hugging Face Hub).
model_path = (
    os.path.expanduser("~/.cache/huggingface/checkpoints/distilbert-arxiv2")
    if is_local_machine()
    else "Hacker1337/distilbert-arxiv-checkpoint"
)
from transformers import pipeline
# Text-classification pipeline; model weights and tokenizer are both loaded
# from `model_path`.
classifier = pipeline(
    task="text-classification",
    tokenizer=model_path,
    model=model_path,
)
# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    title_prompt,
    summary_prompt,
    progress=gr.Progress(track_tqdm=True),
):
    """Classify a paper from its title and abstract.

    Args:
        title_prompt: Paper title typed by the user.
        summary_prompt: Paper abstract typed by the user.
        progress: Gradio progress tracker (mirrors tqdm bars); not used
            directly in this function.

    Returns:
        A tuple ``(df, label_dict)``:

        * ``df`` - a DataFrame built from the classifier output, one row per
          prediction, with ``label`` mapped through ``category2human`` and
          ``score`` left as returned by the classifier (feeds the bar plot).
        * ``label_dict`` - human-readable category -> score normalized by the
          total score, for the highest-scoring categories taken in descending
          order until their cumulative share reaches ``target_probs_sum``
          (0.95). If a non-negligible share is left over, it is reported under
          the key ``"Other"``.
    """
    target_probs_sum = 0.95

    sample_prompt_full = create_prompt(
        title_prompt,
        summary_prompt,
    )
    # top_k=None requests scores for every label, not only the best one.
    predictions = classifier(sample_prompt_full, top_k=None)
    print(predictions)  # debug output to stdout

    df = pd.DataFrame(predictions)
    df["label"] = df["label"].apply(lambda x: category2human[x])

    label_dict = {}
    total_prob = sum(prediction["score"] for prediction in predictions)
    if total_prob <= 0:
        # Nothing to normalize by; avoid a ZeroDivisionError.
        return df, label_dict

    gained_prob = 0.0
    for prediction in sorted(predictions, key=lambda x: x["score"], reverse=True):
        # Stop once the categories already listed cover the target share.
        if gained_prob / total_prob >= target_probs_sum:
            break
        label_dict[category2human[prediction["label"]]] = (
            prediction["score"] / total_prob
        )
        gained_prob += prediction["score"]

    # Report the leftover mass only when it is not negligible (the previous
    # check `gained < total + eps` was always true, so "Other" was always added).
    remaining_share = (total_prob - gained_prob) / total_prob
    if remaining_share > 1e-5:
        label_dict["Other"] = remaining_share
    return df, label_dict
# Example inputs offered under the form. The lists are index-aligned:
# examples_titles[i] presumably titles the abstract in examples_summaries[i].
examples_titles = ["Survey on Semantic Stereo Matching"]

examples_summaries = [
    """Stereo matching is one of the widely used techniques for inferring depth from
stereo images owing to its robustness and speed. It has become one of the major
topics of research since it finds its applications in autonomous driving,
robotic navigation, 3D reconstruction, and many other fields. Finding pixel
correspondences in non-textured, occluded and reflective areas is the major
challenge in stereo matching. Recent developments have shown that semantic cues
from image segmentation can be used to improve the results of stereo matching.
Many deep neural network architectures have been proposed to leverage the
advantages of semantic segmentation in stereo matching. This paper aims to give
a comparison among the state of art networks both in terms of accuracy and in
terms of speed which are of higher importance in real-time applications.""",
]

# Custom CSS passed to gr.Blocks: centers the element with id "col-container"
# and caps its width at 640px.
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
# Gradio UI: title/abstract inputs, a Run button, a per-category bar plot and
# a label panel, all wired to `infer`.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Heading previously still read "Text-to-Image Gradio Template".
        gr.Markdown("# Scientific Paper Classifier")
        gr.Markdown(
            "This space classifies scientific machine learning papers into "
            "categories based on their title and abstract. "
            "The bar plot shows the probability of belonging to each category."
        )
        gr.Markdown(
            "The label panel shows the most probable single-class prediction. "
            "It lists only the top categories that together cover 95% of the "
            'probability mass; the remainder is grouped as "Other".'
        )
        title_prompt = gr.Text(
            label="Title Prompt",
            show_label=False,
            max_lines=1,
            placeholder="Enter paper's title",
            container=False,
        )
        summary_prompt = gr.Text(
            label="Summary Prompt",
            show_label=False,
            max_lines=10,
            placeholder="Enter paper's abstract",
            container=False,
        )
        run_button = gr.Button("Run", scale=0, variant="primary")
        # Receives the DataFrame returned by `infer` (columns label/score).
        result_bar = gr.BarPlot(
            label="Multi class classification",
            show_label=True,
            x="label",
            y="score",
            x_label_angle=30,
        )
        # Receives the label -> probability dict returned by `infer`.
        result_label = gr.Label(label="Single class selection")
        # Pair examples_titles[i] with examples_summaries[i] so one click fills
        # both fields with a matching title and abstract.
        gr.Examples(
            examples=[
                [title, summary]
                for title, summary in zip(examples_titles, examples_summaries)
            ],
            inputs=[title_prompt, summary_prompt],
        )
    # Run inference from the button or by pressing Enter in either field.
    gr.on(
        triggers=[run_button.click, title_prompt.submit, summary_prompt.submit],
        fn=infer,
        inputs=[
            title_prompt,
            summary_prompt,
        ],
        outputs=[result_bar, result_label],
    )

if __name__ == "__main__":
    demo.launch()