bumchik2 commited on
Commit
fae2fa4
·
1 Parent(s): e225b80

using new model now

Browse files
Files changed (1) hide show
  1. app.py +27 -19
app.py CHANGED
@@ -14,18 +14,24 @@ USED_MODEL = "distilbert-base-cased"
14
  def load_model():
15
  # csv локально прочитать очень быстро, так что его не кешируем, хотя это не сложно было бы добавить наверное
16
  arxiv_topics_df = pd.read_csv('arxiv_topics.csv')
17
- tag_to_index = {}
 
18
  for i, row in arxiv_topics_df.iterrows():
19
- tag_to_index[row['tag']] = i
20
- index_to_tag = {value: key for key, value in tag_to_index.items()}
21
-
22
- return AutoModelForSequenceClassification.from_pretrained(
23
- "bumchik2/train_distilbert-base-cased-tags-classification-simple",
 
 
 
24
  problem_type="multi_label_classification",
25
- num_labels=len(tag_to_index),
26
- id2label=index_to_tag,
27
- label2id=tag_to_index
28
  )
 
 
29
 
30
  model = load_model()
31
 
@@ -44,21 +50,23 @@ def tokenize_function(text):
44
  def get_category_probs_dict(model, title: str, summary: str) -> Dict[str, float]:
45
  # csv локально прочитать очень быстро, так что его не кешируем, хотя это не сложно было бы добавить наверное
46
  arxiv_topics_df = pd.read_csv('arxiv_topics.csv')
47
- tag_to_index = {}
48
- tag_to_category = {}
49
  for i, row in arxiv_topics_df.iterrows():
50
- tag_to_category[row['tag']] = row['category']
51
- tag_to_index[row['tag']] = i
52
- index_to_tag = {value: key for key, value in tag_to_index.items()}
 
 
53
 
54
  text = f'{title} $ {summary}'
55
- tags_logits = model(**{key: torch.tensor(value).to(model.device).unsqueeze(0) for key, value in tokenize_function(text).items()}).logits
56
  sigmoid = torch.nn.Sigmoid()
57
- tags_probs = sigmoid(tags_logits.squeeze().cpu()).numpy()
58
- tags_probs /= tags_probs.sum()
59
  category_probs_dict = {category: 0.0 for category in set(arxiv_topics_df['category'])}
60
- for index in range(len(index_to_tag)):
61
- category_probs_dict[tag_to_category[index_to_tag[index]]] += float(tags_probs[index])
62
  return category_probs_dict
63
 
64
 
 
14
  def load_model():
15
  # csv локально прочитать очень быстро, так что его не кешируем, хотя это не сложно было бы добавить наверное
16
  arxiv_topics_df = pd.read_csv('arxiv_topics.csv')
17
+ category_to_index = {}
18
+ current_index = 0
19
  for i, row in arxiv_topics_df.iterrows():
20
+ category = row['category']
21
+ if category not in category_to_index:
22
+ category_to_index[category] = current_index
23
+ current_index += 1
24
+ index_to_category = {value: key for key, value in category_to_index.items()}
25
+
26
+ model = AutoModelForSequenceClassification.from_pretrained(
27
+ "bumchik2/train-distilbert-base-cased-tags-classification",
28
  problem_type="multi_label_classification",
29
+ num_labels=len(category_to_index),
30
+ id2label=index_to_category,
31
+ label2id=category_to_index
32
  )
33
+ model.eval()
34
+ return model
35
 
36
  model = load_model()
37
 
 
50
  def get_category_probs_dict(model, title: str, summary: str) -> Dict[str, float]:
51
  # csv локально прочитать очень быстро, так что его не кешируем, хотя это не сложно было бы добавить наверное
52
  arxiv_topics_df = pd.read_csv('arxiv_topics.csv')
53
+ category_to_index = {}
54
+ current_index = 0
55
  for i, row in arxiv_topics_df.iterrows():
56
+ category = row['category']
57
+ if category not in category_to_index:
58
+ category_to_index[category] = current_index
59
+ current_index += 1
60
+ index_to_category = {value: key for key, value in category_to_index.items()}
61
 
62
  text = f'{title} $ {summary}'
63
+ category_logits = model(**{key: torch.tensor(value).to(model.device).unsqueeze(0) for key, value in tokenize_function(text).items()}).logits
64
  sigmoid = torch.nn.Sigmoid()
65
+ category_probs = sigmoid(category_logits.squeeze().cpu()).numpy()
66
+ category_probs /= category_probs.sum()
67
  category_probs_dict = {category: 0.0 for category in set(arxiv_topics_df['category'])}
68
+ for index in range(len(index_to_category)):
69
+ category_probs_dict[index_to_category[index]] += float(category_probs[index])
70
  return category_probs_dict
71
 
72