File size: 4,250 Bytes
1b6bcbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# -*- coding: utf-8 -*-

import ast
import os

# Job parameters are supplied by the launching process through environment
# variables; unset variables come back as None.
inp_text = os.environ.get("inp_text")
inp_wav_dir = os.environ.get("inp_wav_dir")
exp_name = os.environ.get("exp_name")
i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
# Set here, ahead of the torch import further down. os.environ only accepts
# str values, so assigning None (when _CUDA_VISIBLE_DEVICES is unset) would
# raise TypeError; in that case leave CUDA_VISIBLE_DEVICES as it is.
_cuda_visible_devices = os.environ.get("_CUDA_VISIBLE_DEVICES")
if _cuda_visible_devices is not None:
    os.environ["CUDA_VISIBLE_DEVICES"] = _cuda_visible_devices
opt_dir = os.environ.get("opt_dir")
bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
# is_half is given as a Python literal ("True"/"False"). ast.literal_eval
# parses such literals like eval did, but cannot execute arbitrary code
# smuggled in through the environment.
is_half = ast.literal_eval(os.environ.get("is_half", "True"))
import sys,numpy as np,traceback,pdb
import os.path
from glob import glob
from tqdm import tqdm
from text.cleaner import clean_text
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np

# inp_text=sys.argv[1]
# inp_wav_dir=sys.argv[2]
# exp_name=sys.argv[3]
# i_part=sys.argv[4]
# all_parts=sys.argv[5]
# os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6]#i_gpu
# opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
# bert_pretrained_dir="/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large"

from time import time as ttime
import shutil
def my_save(fea, path):
    """Save ``fea`` to ``path`` by writing a temporary file and moving it.

    Workaround carried over from the original note "torch.save doesn't
    support chinese path": torch.save writes to a temporary file in the same
    directory whose basename is built only from the current timestamp and the
    shard id ``i_part``, and shutil.move then renames it to ``path``. Only the
    basename is replaced; a non-ASCII directory part is still handed to
    torch.save.

    Args:
        fea: object to serialize with torch.save.
        path: final destination path.
    """
    target_dir = os.path.dirname(path)
    # os.path.join instead of "%s/%s": for a bare filename (dirname == "") the
    # old format produced "/<tmp>.pth", i.e. a file in the filesystem root.
    # Timestamp + i_part presumably keeps parallel shards from colliding.
    tmp_path = os.path.join(target_dir, "%s%s.pth" % (ttime(), i_part))
    torch.save(fea, tmp_path)
    shutil.move(tmp_path, path)

# Per-shard output manifest. When it already exists, all of the work below
# (model loading, cleaning, feature extraction) is skipped.
txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
if not os.path.exists(txt_path):
    # BERT feature files are written into <opt_dir>/3-bert.
    bert_dir = "%s/3-bert" % (opt_dir)
    os.makedirs(opt_dir, exist_ok=True)
    os.makedirs(bert_dir, exist_ok=True)
    device = "cuda:0"
    tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
    bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
    # Cast to fp16 only when is_half compares equal to True (kept as
    # `== True` so non-bool literals behave exactly as before), then move
    # the model to the GPU either way.
    if is_half == True:
        bert_model = bert_model.half()
    bert_model = bert_model.to(device)
    def get_bert_feature(text, word2ph):
        """Return phone-level BERT features for ``text``.

        Runs ``bert_model`` on ``text``, takes hidden_states[-3] (third-to-last
        entry), drops the first and last token positions, and repeats row i
        ``word2ph[i]`` times so that the result has one column per phone.

        Args:
            text: normalized text with exactly one ``word2ph`` entry per
                character.
            word2ph: number of phones for each character of ``text``.

        Returns:
            CPU tensor of shape (hidden_size, sum(word2ph)).

        Raises:
            ValueError: if ``len(word2ph) != len(text)``.
        """
        # Validate before the forward pass. The former assert was stripped
        # under `python -O` and only fired after inference had already run.
        if len(word2ph) != len(text):
            raise ValueError(
                "word2ph has %d entries but text has %d characters"
                % (len(word2ph), len(text))
            )
        with torch.no_grad():
            inputs = tokenizer(text, return_tensors="pt")
            for key in inputs:
                inputs[key] = inputs[key].to(device)
            res = bert_model(**inputs, output_hidden_states=True)
            # hidden_states[-3]: same values as the old cat over the
            # one-element slice [-3:-2]. [0] takes the single batch item;
            # [1:-1] presumably drops the [CLS]/[SEP] tokens.
            res = res["hidden_states"][-3][0].cpu()[1:-1]

        # NOTE(review): assumes one token per character so that row i lines up
        # with text[i] — holds for CJK text; confirm for other input.
        phone_level_feature = [
            res[i].repeat(count, 1) for i, count in enumerate(word2ph)
        ]
        return torch.cat(phone_level_feature, dim=0).T
    def process(data, res):
        """Clean each ``(name, text, lan)`` item and append the result to ``res``.

        Each text is run through ``clean_text``. For "zh" items whose
        ``<bert_dir>/<name>.pt`` file does not exist yet, BERT features are
        computed and saved there. Each successful item appends
        ``[name, space-joined phones, word2ph, norm_text]`` to ``res``. A
        failing item is printed (name, text, traceback) and skipped, so one
        bad line does not abort the shard.

        Args:
            data: iterable of ``[name, text, lan]`` triples; ``name`` is
                reduced to its basename.
            res: list that results are appended to (mutated in place).
        """
        for name, text, lan in data:
            try:
                name = os.path.basename(name)
                phones, word2ph, norm_text = clean_text(
                    text.replace("%", '-').replace('¥', ','), lan
                )
                path_bert = "%s/%s.pt" % (bert_dir, name)
                # Language check first so non-zh items skip the filesystem call.
                if lan == "zh" and not os.path.exists(path_bert):
                    bert_feature = get_bert_feature(norm_text, word2ph)
                    # Explicit check (an assert is stripped under -O): the
                    # feature must have exactly one column per phone.
                    if bert_feature.shape[-1] != len(phones):
                        raise ValueError(
                            "BERT feature length %d != phone count %d"
                            % (bert_feature.shape[-1], len(phones))
                        )
                    my_save(bert_feature, path_bert)
                res.append([name, " ".join(phones), word2ph, norm_text])
            # Exception rather than a bare except, so Ctrl-C / SystemExit
            # still stop the job instead of being logged as a bad item.
            except Exception:
                print(name, text, traceback.format_exc())

    # Read the "wav_name|speaker|language|text" list and keep only this
    # worker's shard: every all_parts-th line starting at offset i_part.
    todo = []
    res = []
    with open(inp_text, "r", encoding="utf8") as f:
        lines = f.read().strip("\n").split("\n")

    # Map v1 language tags to v2 tags. Only "ZH" is remapped; any other tag
    # is passed through unchanged.
    language_v1_to_language_v2 = {
        "ZH": "zh",
    }
    for line in lines[int(i_part)::int(all_parts)]:
        try:
            wav_name, spk_name, language, text = line.split("|")
        # A malformed line (wrong field count) is the only failure possible
        # here: report it and continue. A bare except would also swallow
        # KeyboardInterrupt.
        except ValueError:
            print(line, traceback.format_exc())
            continue
        todo.append([wav_name, text, language_v1_to_language_v2.get(language, language)])

    process(todo, res)
    # One tab-separated record per utterance: name, phones, word2ph (written
    # as its Python repr), normalized text.
    opt = [
        "%s\t%s\t%s\t%s" % (name, phones, word2ph, norm_text)
        for name, phones, word2ph, norm_text in res
    ]
    with open(txt_path, "w", encoding="utf8") as f:
        f.write("\n".join(opt) + "\n")