# Source: Hugging Face Space by Hacker1337, commit d21dd33 ("fixed 95 percent choosing").
import os
import socket
import pandas as pd
import gradio as gr
# import spaces #[uncomment to use ZeroGPU]
# from diffusers import DiffusionPipeline
import torch
# Use CUDA with half precision when a GPU is present; otherwise CPU with fp32.
device, torch_dtype = (
    ("cuda", torch.float16) if torch.cuda.is_available() else ("cpu", torch.float32)
)
from dataset import category2human, create_prompt
# Hostnames of developer machines that load the model from a local checkpoint
# instead of the remote model path.
LOCAL_COMPUTER_NAMES = ["amir-xps"]


def is_local_machine():
    """Return True when the current hostname is listed in LOCAL_COMPUTER_NAMES.

    The comparison is case-insensitive on both the hostname and the listed names.
    """
    hostname = socket.gethostname().lower()
    return any(hostname == known.lower() for known in LOCAL_COMPUTER_NAMES)
# Developer machines read a checkpoint saved under the local HF cache; elsewhere the
# path is presumably a Hugging Face Hub repo id (TODO confirm the repo exists/public).
model_path = (
    os.path.expanduser("~/.cache/huggingface/checkpoints/distilbert-arxiv2")
    if is_local_machine()
    else "Hacker1337/distilbert-arxiv-checkpoint"
)
from transformers import pipeline

# Text-classification pipeline shared by every request. ``device`` (chosen from
# torch.cuda.is_available()) is passed explicitly; without it the pipeline would
# run on CPU even on a machine with a GPU.
classifier = pipeline(
    "text-classification",
    model=model_path,
    tokenizer=model_path,
    device=device,
)
def _human_label(label):
    """Map a raw classifier label to its human-readable category name.

    Falls back to the raw label when ``category2human`` has no entry for it, so an
    unexpected label degrades gracefully instead of raising ``KeyError``.
    """
    try:
        return category2human[label]
    except KeyError:
        return label


def _select_top_labels(predictions, target_probs_sum=0.95):
    """Return the most probable labels whose normalized mass reaches the target.

    Args:
        predictions: list of ``{"label": ..., "score": ...}`` dicts, as produced by
            the classifier call in ``infer``.
        target_probs_sum: cumulative normalized probability at which to stop
            adding labels.

    Returns:
        Dict mapping human-readable label -> score divided by the total score,
        in descending score order. Mass not covered by the selected labels is
        reported under ``"Other"``; that entry is omitted when the remainder is
        negligible (<= 1e-5). Returns an empty dict when there are no
        predictions or the scores sum to zero.
    """
    total_prob = sum(prediction["score"] for prediction in predictions)
    if total_prob <= 0:
        return {}

    label_dict = {}
    gained_prob = 0.0
    for prediction in sorted(predictions, key=lambda p: p["score"], reverse=True):
        # Stop as soon as the labels chosen so far already cover the target mass.
        if gained_prob / total_prob >= target_probs_sum:
            break
        label_dict[_human_label(prediction["label"])] = prediction["score"] / total_prob
        gained_prob += prediction["score"]

    # Subtract (not add) the tolerance: a remainder below 1e-5 is float noise,
    # not a real "Other" bucket, and must never come out as a negative value.
    remaining_prob = total_prob - gained_prob
    if remaining_prob > 1e-5:
        label_dict["Other"] = remaining_prob / total_prob
    return label_dict


# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    title_prompt,
    summary_prompt,
    progress=gr.Progress(track_tqdm=True),
):
    """Classify a paper from its title and abstract.

    Args:
        title_prompt: paper title entered by the user.
        summary_prompt: paper abstract entered by the user.
        progress: Gradio progress tracker (``track_tqdm=True``); not used directly.

    Returns:
        ``(df, label_dict)``. ``df`` has ``label``/``score`` columns with one row per
        prediction (labels mapped to human-readable names) and feeds the bar plot.
        ``label_dict`` holds the top categories covering 95% of the normalized
        probability, plus ``"Other"`` for any remainder, and feeds ``gr.Label``.
    """
    sample_prompt_full = create_prompt(title_prompt, summary_prompt)
    predictions = classifier(sample_prompt_full, top_k=None)

    # Explicit columns keep the frame well-formed even if predictions is empty.
    df = pd.DataFrame(predictions, columns=["label", "score"])
    df["label"] = df["label"].apply(_human_label)

    return df, _select_top_labels(predictions, target_probs_sum=0.95)
# Sample inputs offered via gr.Examples under the title and abstract text boxes.
examples_titles = [
    "Survey on Semantic Stereo Matching",
]
# Abstract matching the title above; kept as a single triple-quoted literal.
examples_summaries = [
    """Stereo matching is one of the widely used techniques for inferring depth from
stereo images owing to its robustness and speed. It has become one of the major
topics of research since it finds its applications in autonomous driving,
robotic navigation, 3D reconstruction, and many other fields. Finding pixel
correspondences in non-textured, occluded and reflective areas is the major
challenge in stereo matching. Recent developments have shown that semantic cues
from image segmentation can be used to improve the results of stereo matching.
Many deep neural network architectures have been proposed to leverage the
advantages of semantic segmentation in stereo matching. This paper aims to give
a comparison among the state of art networks both in terms of accuracy and in
terms of speed which are of higher importance in real-time applications.""",
]
# Custom CSS passed to gr.Blocks: centers the element with elem_id="col-container"
# and caps its width at 640px.
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
# Gradio UI: title + abstract inputs, a bar plot of all category scores, and a
# label panel with the top categories. Any of the button click or Enter in either
# text box runs ``infer``.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Scientific Paper Classifier")
        gr.Markdown(
            "This space classifies scientific machine learning papers into "
            "categories based on their title and abstract. "
            "The bar plot shows the probability of belonging to each category."
        )
        gr.Markdown(
            "The label panel shows the most probable categories that together "
            "cover at least 95% of the probability; the remainder is grouped "
            'as "Other".'
        )

        title_prompt = gr.Text(
            label="Title Prompt",
            show_label=False,
            max_lines=1,
            placeholder="Enter paper's title",
            container=False,
        )
        summary_prompt = gr.Text(
            label="Summary Prompt",
            show_label=False,
            max_lines=10,
            placeholder="Enter paper's abstract",
            container=False,
        )
        run_button = gr.Button("Run", scale=0, variant="primary")

        result_bar = gr.BarPlot(
            label="Multi class classification",
            show_label=True,
            x="label",
            y="score",
            x_label_angle=30,
        )
        result_label = gr.Label(label="Single class selection")

        gr.Examples(examples=examples_titles, inputs=[title_prompt])
        gr.Examples(examples=examples_summaries, inputs=[summary_prompt])

    gr.on(
        triggers=[run_button.click, title_prompt.submit, summary_prompt.submit],
        fn=infer,
        inputs=[
            title_prompt,
            summary_prompt,
        ],
        outputs=[result_bar, result_label],
    )

if __name__ == "__main__":
    demo.launch()