Spaces:
Sleeping
Sleeping
import os | |
import socket | |
import pandas as pd | |
import gradio as gr | |
# import spaces #[uncomment to use ZeroGPU] | |
# from diffusers import DiffusionPipeline | |
import torch | |
# Choose the compute device together with a matching float precision:
# half precision on CUDA, full precision on CPU.
# NOTE(review): neither `device` nor `torch_dtype` is passed to the
# classifier pipeline in this file; they look like leftovers from the Space
# template (cf. the commented-out diffusers/ZeroGPU imports) — confirm before
# relying on them.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32
from dataset import category2human, create_prompt | |
# Hostnames of developer workstations that keep a local model checkpoint.
LOCAL_COMPUTER_NAMES = ["amir-xps"]


def is_local_machine():
    """Return True if this process runs on one of ``LOCAL_COMPUTER_NAMES``.

    Matching is case-insensitive: both the current hostname and every entry
    of ``LOCAL_COMPUTER_NAMES`` are lower-cased before being compared.
    """
    hostname = socket.gethostname().lower()
    return any(hostname == known.lower() for known in LOCAL_COMPUTER_NAMES)
# On a developer workstation use the checkpoint already on disk; anywhere
# else load "Hacker1337/distilbert-arxiv-checkpoint" (presumably a Hugging
# Face Hub repo id).
model_path = (
    os.path.expanduser("~/.cache/huggingface/checkpoints/distilbert-arxiv2")
    if is_local_machine()
    else "Hacker1337/distilbert-arxiv-checkpoint"
)

from transformers import pipeline

# Text-classification pipeline; the same path provides both the model
# weights and the tokenizer.
classifier = pipeline(
    task="text-classification",
    model=model_path,
    tokenizer=model_path,
)
def _top_label_probabilities(predictions, label_names, target_probs_sum=0.95):
    """Collapse raw classifier scores into the top labels plus an "Other" bucket.

    Scores are normalised by their total, so the returned values sum to 1
    even when the raw scores do not. Labels are taken in descending score
    order and added while the normalised probability accumulated *so far* is
    below ``target_probs_sum``; the label that crosses the threshold is
    therefore still included.

    Args:
        predictions: Iterable of ``{"label": ..., "score": float}`` dicts.
        label_names: Mapping from raw label to display name.
        target_probs_sum: Cumulative (normalised) probability to cover before
            grouping the remaining labels under ``"Other"``.

    Returns:
        Dict of display name -> normalised probability. ``"Other"`` is only
        present when a non-negligible amount of probability was left out.
        Empty when there are no predictions or the scores sum to zero.

    Raises:
        KeyError: if a raw label is missing from ``label_names``.
    """
    predictions = list(predictions)
    total_prob = sum(prediction["score"] for prediction in predictions)
    if total_prob <= 0:
        # Nothing to normalise by; avoid a ZeroDivisionError.
        return {}

    label_dict = {}
    gained_prob = 0.0
    for prediction in sorted(predictions, key=lambda p: p["score"], reverse=True):
        if gained_prob / total_prob >= target_probs_sum:
            # gained_prob only grows, so no later label can qualify either.
            break
        label_dict[label_names[prediction["label"]]] = prediction["score"] / total_prob
        gained_prob += prediction["score"]

    # Only report "Other" for a real leftover; the tolerance absorbs float
    # rounding so a fully covered distribution gets no zero/negative entry.
    leftover = (total_prob - gained_prob) / total_prob
    if leftover > 1e-5:
        label_dict["Other"] = leftover
    return label_dict


# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    title_prompt,
    summary_prompt,
    progress=gr.Progress(track_tqdm=True),
):
    """Classify a paper from its title and abstract.

    Args:
        title_prompt: Paper title entered by the user.
        summary_prompt: Paper abstract entered by the user.
        progress: Gradio progress tracker. It is not used in the body; with
            ``track_tqdm=True`` Gradio can mirror tqdm bars created during the
            call in the UI.

    Returns:
        A ``(df, label_dict)`` tuple:

        * ``df`` -- DataFrame built from the raw predictions, with ``label``
          mapped to display names via ``category2human`` and the raw
          ``score``; feeds the bar plot.
        * ``label_dict`` -- display name -> normalised probability for the
          most likely categories until they cover 95%, plus ``"Other"`` for
          any remaining probability; feeds the label panel.
    """
    sample_prompt_full = create_prompt(
        title_prompt,
        summary_prompt,
    )
    # top_k=None asks the pipeline for a score for every label, not just the best.
    predictions = classifier(sample_prompt_full, top_k=None)
    # Debug output of the raw pipeline scores.
    print(predictions)

    df = pd.DataFrame(predictions)
    df["label"] = df["label"].apply(lambda raw_label: category2human[raw_label])

    label_dict = _top_label_probabilities(
        predictions, category2human, target_probs_sum=0.95
    )
    return df, label_dict
# Example paper title and abstract offered in the UI's example tables.
examples_titles = ["Survey on Semantic Stereo Matching"]

examples_summaries = [
    (
        "Stereo matching is one of the widely used techniques for inferring depth from\n"
        "stereo images owing to its robustness and speed. It has become one of the major\n"
        "topics of research since it finds its applications in autonomous driving,\n"
        "robotic navigation, 3D reconstruction, and many other fields. Finding pixel\n"
        "correspondences in non-textured, occluded and reflective areas is the major\n"
        "challenge in stereo matching. Recent developments have shown that semantic cues\n"
        "from image segmentation can be used to improve the results of stereo matching.\n"
        "Many deep neural network architectures have been proposed to leverage the\n"
        "advantages of semantic segmentation in stereo matching. This paper aims to give\n"
        "a comparison among the state of art networks both in terms of accuracy and in\n"
        "terms of speed which are of higher importance in real-time applications."
    ),
]
# Centre the main column and cap its width; the selector matches the column
# created with elem_id="col-container" in the layout.
css = (
    "\n"
    "#col-container {\n"
    "    margin: 0 auto;\n"
    "    max-width: 640px;\n"
    "}\n"
)
# Gradio UI: title/abstract inputs, a bar plot of every category's score and
# a label panel with the most likely categories.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Scientific Paper Classifier")
        gr.Markdown(
            "This space classifies scientific machine learning papers into "
            "categories based on their title and abstract. "
            "The bar plot shows the probability of belonging to each category."
        )
        gr.Markdown(
            "The label panel shows the most probable category. It lists only "
            "the top categories that together cover 95% of the probability; "
            'the remainder is grouped as "Other".'
        )

        title_prompt = gr.Text(
            label="Title Prompt",
            show_label=False,
            max_lines=1,
            placeholder="Enter paper's title",
            container=False,
        )
        summary_prompt = gr.Text(
            label="Summary Prompt",
            show_label=False,
            max_lines=10,
            placeholder="Enter paper's abstract",
            container=False,
        )
        run_button = gr.Button("Run", scale=0, variant="primary")

        result_bar = gr.BarPlot(
            label="Multi class classification",
            show_label=True,
            x="label",
            y="score",
            x_label_angle=30,
        )
        result_label = gr.Label(label="Single class selection")

        # Pair each example title with its abstract so one click fills both
        # text boxes.
        gr.Examples(
            examples=[
                [title, summary]
                for title, summary in zip(examples_titles, examples_summaries)
            ],
            inputs=[title_prompt, summary_prompt],
        )

    # The button and Enter in either text box run the same inference.
    gr.on(
        triggers=[run_button.click, title_prompt.submit, summary_prompt.submit],
        fn=infer,
        inputs=[
            title_prompt,
            summary_prompt,
        ],
        outputs=[result_bar, result_label],
    )
def _main():
    """Start the Gradio server for ``demo``."""
    demo.launch()


# Launch only when run as a script, not when imported.
if __name__ == "__main__":
    _main()