sathish2352 commited on
Commit
6c35334
·
verified ·
1 Parent(s): a1cddb1

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +2 -8
models.py CHANGED
@@ -4,21 +4,15 @@ import os
4
 
5
  def load_model():
6
  model_path = "sathish2352/email-classifier-model"
7
-
8
- # Set HF_HOME to use a writable cache dir
9
- os.environ["HF_HOME"] = "/tmp/huggingface"
10
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
11
- os.makedirs("/tmp/huggingface/transformers", exist_ok=True)
12
-
13
  tokenizer = AutoTokenizer.from_pretrained(model_path)
14
  model = AutoModelForSequenceClassification.from_pretrained(model_path)
15
-
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
  model.to(device)
18
  model.eval()
19
-
20
  return tokenizer, model, device
21
 
 
22
  def classify_email(text, tokenizer, model, device):
23
  inputs = tokenizer(text, return_tensors="pt", max_length=256, padding="max_length", truncation=True)
24
  inputs = {k: v.to(device) for k, v in inputs.items()}
 
4
 
5
  def load_model():
6
  model_path = "sathish2352/email-classifier-model"
 
 
 
 
 
 
7
  tokenizer = AutoTokenizer.from_pretrained(model_path)
8
  model = AutoModelForSequenceClassification.from_pretrained(model_path)
9
+
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
  model.to(device)
12
  model.eval()
 
13
  return tokenizer, model, device
14
 
15
+
16
  def classify_email(text, tokenizer, model, device):
17
  inputs = tokenizer(text, return_tensors="pt", max_length=256, padding="max_length", truncation=True)
18
  inputs = {k: v.to(device) for k, v in inputs.items()}