bumchik2 committed on
Commit
c7f1481
·
1 Parent(s): fae2fa4

update app

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -4,11 +4,11 @@ import torch
4
  from transformers import AutoModelForSequenceClassification
5
  import pandas as pd
6
  from typing import Dict
7
- from transformers import DistilBertTokenizer
8
  from typing import List
9
 
10
 
11
- USED_MODEL = "distilbert-base-cased"
12
 
13
  @st.cache_resource # кэширование
14
  def load_model():
@@ -24,7 +24,7 @@ def load_model():
24
  index_to_category = {value: key for key, value in category_to_index.items()}
25
 
26
  model = AutoModelForSequenceClassification.from_pretrained(
27
- "bumchik2/train-distilbert-base-cased-tags-classification",
28
  problem_type="multi_label_classification",
29
  num_labels=len(category_to_index),
30
  id2label=index_to_category,
@@ -38,7 +38,7 @@ model = load_model()
38
 
39
  @st.cache_resource()
40
  def get_tokenizer():
41
- return DistilBertTokenizer.from_pretrained(USED_MODEL)
42
 
43
 
44
  def tokenize_function(text):
 
4
  from transformers import AutoModelForSequenceClassification
5
  import pandas as pd
6
  from typing import Dict
7
+ from transformers import RobertaTokenizer
8
  from typing import List
9
 
10
 
11
+ USED_MODEL = "distilroberta-base"
12
 
13
  @st.cache_resource # кэширование
14
  def load_model():
 
24
  index_to_category = {value: key for key, value in category_to_index.items()}
25
 
26
  model = AutoModelForSequenceClassification.from_pretrained(
27
+ f"bumchik2/train-{USED_MODEL}-tags-classification",
28
  problem_type="multi_label_classification",
29
  num_labels=len(category_to_index),
30
  id2label=index_to_category,
 
38
 
39
  @st.cache_resource()
40
  def get_tokenizer():
41
+ return RobertaTokenizer.from_pretrained(USED_MODEL)
42
 
43
 
44
  def tokenize_function(text):