Spaces:
Running
Running
File size: 5,606 Bytes
391e575 cb10a62 391e575 c3aa1f7 1435c91 391e575 1435c91 391e575 c89f60b 1435c91 391e575 760a845 cb10a62 c89f60b 6b2b9b7 391e575 cb10a62 c89f60b 391e575 6b2b9b7 c89f60b 0954181 1435c91 a3e61e9 0954181 1435c91 0954181 cb10a62 391e575 cb10a62 391e575 cb10a62 391e575 cb10a62 1435c91 cb10a62 391e575 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
# !/usr/bin/python
# -*- coding: utf-8 -*-
# @time : 2021/2/29 21:41
# @author : Mo
# @function: transformers直接加载bert类模型测试
import traceback
import copy
import time
import sys
import os
import re
# Environment switches. They are set before the macro_correct imports below,
# presumably because the library reads them at import time (TODO confirm).
# NOTE(review): macro_correct-specific flag; its effect is not visible in this file.
os.environ["MACRO_CORRECT_FLAG_CSC_TOKEN"] = "1"
# "-1" matches no CUDA device, hiding all GPUs from this process (CPU-only inference).
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Presumably asks transformers to use its PyTorch backend (TODO confirm).
os.environ["USE_TORCH"] = "1"
from macro_correct.pytorch_textcorrection.tcTools import preprocess_same_with_training
from macro_correct.pytorch_textcorrection.tcTools import get_errors_for_difflib
from macro_correct.pytorch_textcorrection.tcTools import cut_sent_by_maxlen
from macro_correct.pytorch_textcorrection.tcTools import count_flag_zh
from macro_correct import correct_basic
from macro_correct import correct_long
from macro_correct import correct
import gradio as gr
# Packaging hint: build a single-file executable with PyInstaller.
# pyinstaller -F xxxx.py
# Candidate checkpoints; only the uncommented one is assigned.
# NOTE(review): this name is not referenced anywhere else in this file, and the UI title
# in the __main__ block names a different model (Macropodus/macbert4csc_v2);
# confirm which checkpoint macro_correct actually loads.
# pretrained_model_name_or_path = "shibing624/macbert4csc-base-chinese"
pretrained_model_name_or_path = "Macadam/macbert4mdcspell_v2"
# pretrained_model_name_or_path = "Macropodus/macbert4mdcspell_v1"
# pretrained_model_name_or_path = "Macropodus/macbert4csc_v1"
# pretrained_model_name_or_path = "Macropodus/macbert4csc_v2"
# pretrained_model_name_or_path = "Macropodus/bert4csc_v1"
# Device selection examples (torch is not imported in this file).
# device = torch.device("cpu")
# device = torch.device("cuda")
def cut_sent_by_stay_and_maxlen(text, max_len=126, return_length=True):
    """
    Split text into sentences, keeping each sentence's original trailing punctuation.

    Any sentence that is still longer than ``max_len`` after punctuation splitting
    is force-cut into shorter pieces with ``cut_sent_by_maxlen``.

    Args:
        text: str, input text to split.
        max_len: int, maximum piece length (the max length used for training texts).
        return_length: bool, whether to also return the span of each piece.
    Returns:
        (text_cut, text_length_s) if return_length is True, otherwise text_cut.
        text_cut is a list of str pieces in input order; text_length_s is a list
        of [start, end] character offsets of each piece within ``text``.
    """
    # Split on runs of sentence-ending symbols; the consumed symbols are
    # re-attached to the preceding piece by the scan below.
    text_sp = re.split(r"[》)!?。…”;;!?\n]+", text)
    conn_symbol = "!?。…”;;!?》)\n"
    text_length_s = []
    text_cut = []
    len_text = len(text)
    len_global = 0
    for text_sp_i in text_sp:
        text_cut_idx = text_sp_i
        len_global_before = len_global  # ints are immutable; no copy needed
        len_global += len(text_sp_i)
        # Re-attach the punctuation run that re.split removed after this piece.
        while len_global < len_text and text[len_global] in conn_symbol:
            text_cut_idx += text[len_global]
            len_global += 1
        if not text_cut_idx:
            # Empty piece (e.g. the split remainder after trailing punctuation).
            continue
        if len(text_cut_idx) > max_len:
            # Punctuation could not split this piece finely enough: force-cut the
            # piece itself. The original code cut the WHOLE text here, duplicating it.
            text_cut_i, text_length_s_i = cut_sent_by_maxlen(
                text=text_cut_idx, max_len=max_len, return_length=True)
            text_cut.extend(text_cut_i)
            # Assumes cut_sent_by_maxlen returns [start, end] pairs relative to its
            # own input (same shape as the spans built below) -- TODO confirm.
            # Shift them so they are offsets into the full `text`.
            text_length_s.extend(
                [start + len_global_before, end + len_global_before]
                for start, end in text_length_s_i
            )
        else:
            text_length_s.append([len_global_before, len_global])
            text_cut.append(text_cut_idx)
    if return_length:
        return text_cut, text_length_s
    return text_cut
def macro_correct(text):
    """
    Correct Chinese spelling in ``text`` and render a plain-text report.

    The input is split into sentences. Each sentence is preprocessed and passed
    to ``correct_long``. Chinese-character differences introduced by the
    preprocessing step itself (the original note calls these traditional/simplified
    changes) are merged into the reported errors.

    Args:
        text: str, raw user input.
    Returns:
        str: the corrected text, a line of 32 '#', then one block of
        ``key: value`` lines per non-empty per-sentence result.
    """
    print(text)
    sentences, _ = cut_sent_by_stay_and_maxlen(text, return_length=True)
    corrected = ""
    results = []
    for sentence in sentences:
        print(sentence)
        processed = preprocess_same_with_training(sentence)
        csc = correct_long(processed, num_rethink=1, flag_cut=True, limit_length_char=1)
        print(csc)
        # Preprocessing changed the sentence: report the Chinese-character
        # substitutions it made as additional errors (tagged with a trailing 1).
        if sentence != processed:
            _, diff_errors = get_errors_for_difflib(processed, sentence)
            extra = [
                err + [1]
                for err in diff_errors
                if count_flag_zh(err[0]) and count_flag_zh(err[1])
            ]
            if extra and csc:
                csc[0]["errors"] += extra
            elif extra:
                csc = [{"source": sentence, "target": processed, "errors": extra}]
        # Record the result; an empty dict keeps a placeholder slot for
        # sentences without a correction result.
        if not csc:
            results.append({})
            corrected += sentence
        else:
            results.extend(csc)
            corrected += csc[0].get("target")
    report = ["\n" + "#" * 32 + "\n"]
    for number, item in enumerate(results, start=1):
        if not item:
            continue
        for key, value in item.items():
            if key == "index":
                report.append(f"idx: {number}\n")
            else:
                report.append(f"{str(key).strip()}: {str(value).strip()}\n")
        report.append("\n")
    return corrected + "".join(report)
if __name__ == '__main__':
    # Run one sample through the pipeline at startup and print its report.
    print(macro_correct('少先队员因该为老人让坐'))
    # Sample erroneous sentences offered as clickable examples in the web UI.
    examples = [
        "机七学习是人工智能领遇最能体现智能的一个分知",
        "我是练习时长两念半的鸽仁练习生蔡徐坤",
        "真麻烦你了。希望你们好好的跳无",
        "他法语说的很好,的语也不错",
        "遇到一位很棒的奴生跟我疗天",
        "我们为这个目标努力不解",
    ]
    # Text-in / text-out demo: the input box feeds macro_correct and its report
    # is shown as plain text.
    # NOTE(review): the title names Macropodus/macbert4csc_v2, while
    # pretrained_model_name_or_path is set to Macadam/macbert4mdcspell_v2; confirm.
    demo = gr.Interface(
        macro_correct,
        inputs='text',
        outputs='text',
        title="Chinese Spelling Correction Model Macropodus/macbert4csc_v2",
        description="Copy or input error Chinese text. Submit and the machine will correct text.",
        article="Link to <a href='https://github.com/yongzhuo/macro-correct' style='color:blue;' target='_blank\'>Github REPO: macro-correct</a>",
        examples=examples,
    )
    # Default local launch. Alternative for serving on the LAN:
    # demo.launch(server_name="0.0.0.0", server_port=8066, share=False, debug=True)
    demo.launch()
|