File size: 4,295 Bytes
7e18220
 
 
 
 
 
c7f1481
7e18220
 
 
c7f1481
7e18220
 
 
 
 
fae2fa4
 
7e18220
fae2fa4
 
 
 
 
7e18220
fae2fa4
c7f1481
7e18220
fae2fa4
 
 
c9974be
fae2fa4
 
7e18220
 
 
 
 
 
c7f1481
7e18220
 
 
 
 
 
 
 
 
 
 
fae2fa4
 
7e18220
fae2fa4
 
 
 
 
7e18220
b1eb861
fae2fa4
7e18220
fae2fa4
 
7e18220
fae2fa4
 
7e18220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97c6f33
 
7e18220
e225b80
2fab8a5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import streamlit as st
from transformers import pipeline
import torch
from transformers import AutoModelForSequenceClassification
import pandas as pd
from typing import Dict
from transformers import RobertaTokenizer
from typing import List


# Name of the base checkpoint: it is interpolated into the fine-tuned classifier's
# repository id and is also the name the tokenizer is loaded from.
USED_MODEL = "distilroberta-base"

@st.cache_resource  # caching
def load_model():
    """Load the fine-tuned multi-label tag classifier in evaluation mode.

    The label mappings are derived from the ``category`` column of
    ``arxiv_topics.csv``: every distinct category gets the next free integer
    index, in order of first appearance in the file.

    Returns:
        The model returned by ``AutoModelForSequenceClassification.from_pretrained``
        after ``eval()`` has been called on it.
    """
    # Reading the CSV locally is very fast, so it is not cached on its own
    # (although adding that would probably not be hard).
    topics = pd.read_csv('arxiv_topics.csv')

    label2id = {}
    for _, record in topics.iterrows():
        # setdefault leaves already-seen categories untouched; for a new one
        # the current dict size is exactly the next unused index.
        label2id.setdefault(record['category'], len(label2id))
    id2label = {idx: name for name, idx in label2id.items()}

    classifier = AutoModelForSequenceClassification.from_pretrained(
        f"bumchik2/train-{USED_MODEL}-tags-classification",
        problem_type="multi_label_classification",
        num_labels=len(label2id),
        id2label=id2label,
        label2id=label2id,
    )
    classifier.eval()
    return classifier

# Module-level classifier instance used by the UI code below; load_model is
# wrapped in st.cache_resource, so this call goes through Streamlit's cache.
model = load_model()


@st.cache_resource()
def get_tokenizer():
    """Return the ``RobertaTokenizer`` loaded for the ``USED_MODEL`` checkpoint."""
    tokenizer = RobertaTokenizer.from_pretrained(USED_MODEL)
    return tokenizer


def tokenize_function(text):
    """Tokenize ``text`` with the shared tokenizer.

    The text is truncated and padded with ``padding="max_length"``; the
    tokenizer's output is returned unchanged.
    """
    encode = get_tokenizer()
    encoded = encode(text, padding="max_length", truncation=True)
    return encoded


@torch.no_grad()  # called form: also works on PyTorch versions that reject the bare decorator
def get_category_probs_dict(model, title: str, summary: str) -> Dict[str, float]:
    """Score every arXiv category for an article and return normalized scores.

    The title and summary are joined as ``"<title> $ <summary>"``, tokenized,
    and passed through ``model``. The per-label sigmoid outputs are divided by
    their sum, so the returned values add up to 1.

    Args:
        model: Sequence-classification model exposing ``.device`` and whose
            output has a ``.logits`` attribute with one logit per category.
        title: Article title.
        summary: Article summary; a falsy value is treated as an empty string.

    Returns:
        Mapping from each distinct category in ``arxiv_topics.csv`` to its
        normalized score.

    Raises:
        IndexError: If the model produces fewer logits than there are
            distinct categories in the CSV.
    """
    # Reading the CSV locally is very fast, so it is not cached
    # (although adding that would probably not be hard).
    arxiv_topics_df = pd.read_csv('arxiv_topics.csv')
    # Distinct categories in order of first appearance: the position in this
    # list is the label index (same numbering scheme as used when loading the model).
    categories = list(dict.fromkeys(arxiv_topics_df['category']))

    text = f'{title} $ {summary or ""}'
    model_inputs = {
        name: torch.tensor(values).to(model.device).unsqueeze(0)  # add batch dim
        for name, values in tokenize_function(text).items()
    }
    category_logits = model(**model_inputs).logits

    # squeeze(0) drops only the batch dimension; a bare squeeze() would also
    # drop the label dimension when there is a single label.
    category_probs = torch.sigmoid(category_logits.squeeze(0).cpu()).numpy()
    category_probs /= category_probs.sum()

    return {
        category: float(category_probs[index])
        for index, category in enumerate(categories)
    }


def get_most_probable_keys(probs_dict: Dict[str, float], target_probability: float, print_probabilities: bool) -> List[str]:
    """Return the highest-probability keys whose cumulative sum exceeds a target.

    Keys are taken in descending order of probability (ties broken by
    descending key) and collected until the running sum becomes strictly
    greater than ``target_probability`` or the keys run out.

    Args:
        probs_dict: Mapping from key to probability.
        target_probability: Cumulative probability to exceed.
        print_probabilities: If true, each entry is formatted as
            ``"<key> (<probability>)"``; otherwise it is the bare key.

    Returns:
        The selected entries, most probable first. An empty ``probs_dict``
        yields an empty list.
    """
    ranked = sorted(((prob, key) for key, prob in probs_dict.items()), reverse=True)

    answer = []
    accumulated = 0.0
    for prob, key in ranked:
        # Stop as soon as the collected mass is strictly above the target.
        if accumulated > target_probability:
            break
        accumulated += prob
        answer.append(f'{key} ({prob})' if print_probabilities else key)
    return answer


# --- Streamlit UI: collect the article text and display options, then show predictions ---
title = st.text_input("Article title", value="Enter title here...")
summary = st.text_input("Article summary", value="Enter summary here...")

need_to_print_probabilities = st.radio("Need to print probabilities: ", ('Yes', 'No'), index=0)
st.session_state['need_to_print_probabilities'] = need_to_print_probabilities

target_probability = st.slider("Select minimum probability sum", 0.0, 1.0, step=0.01, value=0.95)
# Store the slider's numeric value (not the literal string 'target_probability').
st.session_state['target_probability'] = target_probability


# NOTE(review): both text inputs default to non-empty text, so this branch also
# runs on the default strings before the user has typed anything.
if title or summary:
    category_probs_dict = get_category_probs_dict(model=model, title=title, summary=summary or '')
    result = get_most_probable_keys(
        probs_dict=category_probs_dict,
        target_probability=target_probability,
        print_probabilities=need_to_print_probabilities == 'Yes',
    )
    # Two trailing spaces before the newline separate the entries in st.write's output.
    result_str = "  \n ".join(result)
    st.write(result_str)