File size: 5,606 Bytes
391e575
cb10a62
391e575
 
 
 
 
 
c3aa1f7
1435c91
391e575
 
1435c91
391e575
 
 
 
c89f60b
1435c91
 
 
391e575
 
 
760a845
cb10a62
c89f60b
6b2b9b7
391e575
 
 
 
 
 
 
 
cb10a62
 
c89f60b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391e575
6b2b9b7
c89f60b
0954181
 
 
 
1435c91
a3e61e9
0954181
1435c91
 
 
 
 
 
 
 
 
 
 
 
 
0954181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb10a62
 
 
391e575
cb10a62
 
 
 
391e575
cb10a62
 
 
 
 
391e575
cb10a62
 
1435c91
cb10a62
 
 
391e575
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# !/usr/bin/python
# -*- coding: utf-8 -*-
# @time    : 2021/2/29 21:41
# @author  : Mo
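# @function: Gradio demo for Chinese spelling correction with macro_correct (original note: transformers直接加载bert类模型测试, i.e. test of loading BERT-like models directly with transformers)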


import traceback
import copy
import time
import sys
import os
import re
# Process-wide switches, assigned before the macro_correct imports further down.
# NOTE(review): presumably macro_correct reads these at import time, which
# would explain why they sit between the import groups -- confirm.
# NOTE(review): the meaning of this macro_correct-specific flag is not visible
# in this file -- confirm against the package.
os.environ["MACRO_CORRECT_FLAG_CSC_TOKEN"] = "1"
# "-1" matches no GPU index, so CUDA devices are hidden from this process
# (the demo is meant to run on CPU).
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# NOTE(review): presumably selects the PyTorch backend -- confirm.
os.environ["USE_TORCH"] = "1"

from macro_correct.pytorch_textcorrection.tcTools import preprocess_same_with_training
from macro_correct.pytorch_textcorrection.tcTools import get_errors_for_difflib
from macro_correct.pytorch_textcorrection.tcTools import cut_sent_by_maxlen
from macro_correct.pytorch_textcorrection.tcTools import count_flag_zh
from macro_correct import correct_basic
from macro_correct import correct_long
from macro_correct import correct
import gradio as gr

# pyinstaller -F xxxx.py

# Model identifier for the demo; the commented-out lines are alternatives that
# can be swapped in by hand.
# NOTE(review): nothing in the visible part of this file reads
# `pretrained_model_name_or_path` -- confirm whether macro_correct picks the
# model up some other way or whether this assignment is dead.
# pretrained_model_name_or_path = "shibing624/macbert4csc-base-chinese"
pretrained_model_name_or_path = "Macadam/macbert4mdcspell_v2"
# pretrained_model_name_or_path = "Macropodus/macbert4mdcspell_v1"
# pretrained_model_name_or_path = "Macropodus/macbert4csc_v1"
# pretrained_model_name_or_path = "Macropodus/macbert4csc_v2"
# pretrained_model_name_or_path = "Macropodus/bert4csc_v1"
# device = torch.device("cpu")
# device = torch.device("cuda")


def cut_sent_by_stay_and_maxlen(text, max_len=126, return_length=True):
    """
    Split text into sentences while keeping the original punctuation; a
    sentence that is still longer than max_len is cut into fixed-length pieces.
    Args:
        text: str, sentence of input text;
        max_len: int, max length (in characters) of one returned piece;
        return_length: bool, whether to also return the [start, end] spans
    Returns:
        res: (List<str>, List<[int, int]>) when return_length is True,
             otherwise List<str>; spans are offsets into `text`
    """
    # The same character set is used to split and to re-attach delimiters.
    conn_symbol = "!?。…”;;!?》)\n"
    text_sp = re.split(r"[》)!?。…”;;!?\n]+", text)
    last_index = len(text) - 1
    text_length_s = []
    text_cut = []
    cursor = 0  # offset into `text` of the next unread character
    for segment in text_sp:
        start = cursor
        cursor += len(segment)
        # re.split dropped the delimiter run that followed this segment;
        # walk over it in the original text and glue it back on.
        while cursor <= last_index and text[cursor] in conn_symbol:
            segment += text[cursor]
            cursor += 1
        if not segment:
            continue
        if len(segment) > max_len:
            # Punctuation alone could not make it short enough: force-cut THIS
            # segment only (cutting the whole `text` here would duplicate
            # content once per over-long segment).
            pieces, piece_spans = cut_sent_by_maxlen(
                text=segment, max_len=max_len, return_length=True)
            # NOTE(review): assumes the helper returns [start, end] pairs
            # relative to the string it was given; shift them so they index
            # into the full text like the spans of the short segments do.
            text_length_s.extend(
                [start + span[0], start + span[1]] for span in piece_spans)
            text_cut.extend(pieces)
        else:
            text_length_s.append([start, cursor])
            text_cut.append(segment)
    if return_length:
        return text_cut, text_length_s
    return text_cut


def macro_correct(text):
    """
    Correct a piece of text sentence by sentence and render a plain-text report.
    Args:
        text: str, raw input text;
    Returns:
        res: str, the concatenated corrected sentences, then a separator line
             of 32 "#", then the fields of every non-empty correction record
    """
    print(text)
    sentences, _ = cut_sent_by_stay_and_maxlen(text, return_length=True)
    corrected = ""
    records = []
    for sentence in sentences:
        print(sentence)
        normalized = preprocess_same_with_training(sentence)
        result = correct_long(normalized, num_rethink=1, flag_cut=True, limit_length_char=1)
        print(result)
        # Preprocessing changed the sentence (the original comment labels this
        # case as traditional/simplified conversion): report those edits too.
        if sentence != normalized:
            _, diffs = get_errors_for_difflib(normalized, sentence)
            # Keep only diffs where both sides pass count_flag_zh, and append
            # a trailing 1 to each kept entry.
            extra = [
                diff + [1]
                for diff in diffs
                if count_flag_zh(diff[0]) and count_flag_zh(diff[1])
            ]
            if extra:
                if not result:
                    result = [{"source": sentence, "target": normalized, "errors": extra}]
                else:
                    result[0]["errors"] += extra
        # Errors found by the corrector itself (or the synthesized record above).
        if not result:
            # Nothing to report: keep the sentence as typed and leave an empty
            # placeholder so record numbering still follows sentence order.
            records.append({})
            corrected += sentence
            continue
        records.extend(result)
        corrected += result[0].get("target")
    report = ["\n" + "#" * 32 + "\n"]
    for position, record in enumerate(records, start=1):
        if not record:
            continue
        for key, value in record.items():
            if key == "index":
                # Show the 1-based position in the record list instead of the
                # stored "index" value.
                report.append(f"idx: {str(position)}\n")
            else:
                report.append(f"{str(key).strip()}: {str(value).strip()}\n")
        report.append("\n")
    return corrected + "".join(report)


if __name__ == '__main__':
    # Smoke test: push one sentence through the pipeline before the UI starts.
    print(macro_correct('少先队员因该为老人让坐'))

    # Sample inputs offered in the web UI.
    examples = [
        "机七学习是人工智能领遇最能体现智能的一个分知",
        "我是练习时长两念半的鸽仁练习生蔡徐坤",
        "真麻烦你了。希望你们好好的跳无",
        "他法语说的很好,的语也不错",
        "遇到一位很棒的奴生跟我疗天",
        "我们为这个目标努力不解",
    ]
    interface_options = dict(
        inputs='text',
        outputs='text',
        title="Chinese Spelling Correction Model Macropodus/macbert4csc_v2",
        description="Copy or input error Chinese text. Submit and the machine will correct text.",
        article="Link to <a href='https://github.com/yongzhuo/macro-correct' style='color:blue;' target='_blank\'>Github REPO: macro-correct</a>",
        examples=examples,
    )
    # Text-in / text-out web UI around macro_correct.
    demo = gr.Interface(macro_correct, **interface_options)
    demo.launch()
    # Alternative launch with a fixed host/port:
    # demo.launch(server_name="0.0.0.0", server_port=8066, share=False, debug=True)