mihalykiss commited on
Commit
e99c594
·
verified ·
1 Parent(s): 7ea8ec5

Multiple pace and \n charachters fix

Browse files
Files changed (1) hide show
  1. app.py +17 -14
app.py CHANGED
@@ -2,6 +2,10 @@ import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
  import re
 
 
 
 
5
  model1_path = "modernbert.bin"
6
  model2_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12"
7
  model3_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"
@@ -35,22 +39,21 @@ label_mapping = {
35
  39: 'text-davinci-002', 40: 'text-davinci-003'
36
  }
37
 
38
- def clean_text(text):
39
-
40
- text = text.replace("\r\n", "\n").replace("\r", "\n")
41
-
42
-
43
- text = re.sub(r"\n\s*\n+", "\n\n", text)
44
-
45
- text = re.sub(r"[ \t]+", " ", text)
46
 
47
- text = re.sub(r"(\w+)-\n(\w+)", r"\1\2", text)
48
 
49
- text = re.sub(r"(?<!\n)\n(?!\n)", " ", text)
 
50
 
51
- text = text.strip()
52
-
53
- return text
 
 
 
54
 
55
  def classify_text(text):
56
  cleaned_text = clean_text(text)
@@ -60,7 +63,7 @@ def classify_text(text):
60
  )
61
  return result_message
62
 
63
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
64
 
65
  with torch.no_grad():
66
  logits_1 = model_1(**inputs).logits
 
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
  import re
5
+ from tokenizers import normalizers
6
+ from tokenizers.normalizers import Sequence, Replace, Strip
7
+ from tokenizers import Regex
8
+
9
  model1_path = "modernbert.bin"
10
  model2_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12"
11
  model3_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"
 
39
  39: 'text-davinci-002', 40: 'text-davinci-003'
40
  }
41
 
42
+ def clean_text(text: str) -> str:
43
+ text = re.sub(r'\s{2,}', ' ', text)
44
+ text = re.sub(r'\s+([,.;:?!])', r'\1', text)
45
+ return text
 
 
 
 
46
 
 
47
 
48
+ newline_to_space = Replace(Regex(r'\s*\n\s*'), " ")
49
+ join_hyphen_break = Replace(Regex(r'(\w+)[--]\s*\n\s*(\w+)'), r"\1\2")
50
 
51
+ tokenizer.backend_tokenizer.normalizer = Sequence([
52
+ tokenizer.backend_tokenizer.normalizer,
53
+ join_hyphen_break,
54
+ newline_to_space,
55
+ Strip()
56
+ ])
57
 
58
  def classify_text(text):
59
  cleaned_text = clean_text(text)
 
63
  )
64
  return result_message
65
 
66
+ inputs = tokenizer(cleaned_text, return_tensors="pt", truncation=True, padding=True).to(device)
67
 
68
  with torch.no_grad():
69
  logits_1 = model_1(**inputs).logits