# Get Transcription, WER and PPM
"""
TODO:
[DONE]: Automatically generate config
"""
import yaml
import argparse
import sys
from pathlib import Path

sys.path.append("./src")
import lightning_module
from UV import plot_UV, get_speech_interval  # Voice/Unvoice utilities (not exercised below)
from transformers import pipeline, Wav2Vec2Processor, Wav2Vec2ForCTC
from rich.progress import track
from rich import print as rprint
import numpy as np
import jiwer
import torch
import torch.nn as nn
import torchaudio
import gradio as gr
# root_path = Path(__file__).parents[1]

class ChangeSampleRate(nn.Module):
    """Resample a waveform to a new sampling rate by linear interpolation."""

    def __init__(self, input_rate: int, output_rate: int):
        super().__init__()
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        # Only accepts 1-channel waveform input
        wav = wav.view(wav.size(0), -1)
        new_length = wav.size(-1) * self.output_rate // self.input_rate
        # Fractional source position of each output sample
        indices = torch.arange(new_length) * (self.input_rate / self.output_rate)
        round_down = wav[:, indices.long()]
        round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
        # Linearly interpolate between the two neighbouring samples
        output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(0) + round_up * indices.fmod(1.0).unsqueeze(0)
        return output
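# Example (sketch; "sample.wav" is a placeholder path):
#   wav, sr = torchaudio.load("sample.wav")
#   wav_16k = ChangeSampleRate(sr, 16_000)(wav)
# This is plain linear interpolation; torchaudio.functional.resample(wav, sr, 16_000)
# gives a comparable result with proper anti-aliasing filtering.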
# UTMOS-strong MOS predictor (loaded here, but not used in the CLI flow below)
model = lightning_module.BaselineLightningModule.load_from_checkpoint(
    "./src/epoch=3-step=7459.ckpt"
).eval()
def calc_wer(audio_path, ref, ASR_pipeline):
    wav, sr = torchaudio.load(audio_path)
    osr = 16_000
    csr = ChangeSampleRate(sr, osr)
    out_wavs = csr(wav)  # resampled copy; the pipeline below reads audio_path directly
    # ASR
    trans = ASR_pipeline(audio_path)["text"]
    # WER (uses the module-level `transformation` defined under __main__)
    wer = jiwer.wer(
        ref,
        trans,
        truth_transform=transformation,
        hypothesis_transform=transformation,
    )
    return trans, wer
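# Worked example (sketch): with ref = "the cat sat" and trans = "the cat sat down",
# jiwer counts one insertion against three reference words, so wer == 1/3.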
if __name__ == "__main__":
    # Argparse
    parser = argparse.ArgumentParser(
        prog="get_ref_PPM",
        description="Generate Phoneme per Minute (and Voice/Unvoice plot)",
        epilog="",
    )
    parser.add_argument(
        "--tag",
        type=str,
        default=None,
        required=False,
        help="ID tag for output *.csv",
    )
    parser.add_argument("--ref_txt", type=str, required=True, help="Reference TXT")
    parser.add_argument(
        "--ref_wavs", type=str, required=True, help="Reference WAVs"
    )
    parser.add_argument(
        "--metadata",
        type=str,
        required=False,
        help="metadata.csv including wav_id and reference",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="whisper-medium-FT",
        choices=["wav2vec+ctc", "whisper-medium-FT", "whisper-large-v2"],
        help="ASR engine for evaluation:\n ver1: wav2vec+ctc \n ver2: whisper-medium (fine-tuned)\n ver3: whisper-large-v2",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Output directory for *.csv",
    )
    parser.add_argument(
        "--to_config",
        choices=["True", "False"],
        default="False",
        help="Generate config from .txt and wavs/*.wav",
    )
    args = parser.parse_args()
    refs = np.loadtxt(args.ref_txt, delimiter="\n", dtype="str")
    refs_ids = [x.split()[0] for x in refs]
    refs_txt = [" ".join(x.split()[1:]) for x in refs]
    ref_wavs = [str(x) for x in sorted(Path(args.ref_wavs).glob("**/*.wav"))]
    if len(refs) != len(ref_wavs):
        print("Error: Text and Wavs don't match")
        sys.exit(1)
    # ASR part
    if args.model == "whisper-medium-FT":
        ASR_pipeline = pipeline(
            "automatic-speech-recognition",
            model="KevinGeng/whipser_medium_en_PAL300_step25",
        )
    elif args.model == "wav2vec+ctc":
        # Falls back to the pipeline's default English wav2vec2 model
        ASR_pipeline = pipeline("automatic-speech-recognition")
    elif args.model == "whisper-large-v2":
        ASR_pipeline = pipeline(
            "automatic-speech-recognition", model="openai/whisper-large-v2"
        )
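    # Example (sketch; "clip.wav" is a placeholder path): the pipeline maps an
    # audio file to a transcript dict:
    #   ASR_pipeline("clip.wav")  ->  {"text": "..."}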
    # WER part
    transformation = jiwer.Compose(
        [
            jiwer.ToLowerCase(),
            jiwer.RemoveWhiteSpace(replace_by_space=True),
            jiwer.RemoveMultipleSpaces(),
            jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
        ]
    )
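    # With this transform, "Hello  WORLD " and "hello world" normalize to the
    # same word list, so WER reflects word errors rather than case or spacing.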
    # WPM part (models loaded for phoneme-rate computation; not exercised below)
    processor = Wav2Vec2Processor.from_pretrained(
        "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
    )
    phoneme_model = Wav2Vec2ForCTC.from_pretrained(
        "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
    )
    # phoneme_model = pipeline(model="facebook/wav2vec2-xlsr-53-espeak-cv-ft")
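    # Sketch of a PPM computation with these models (an assumption, not part of
    # the original flow): decode phonemes with CTC greedy decoding and divide
    # the phoneme count by the audio duration in minutes.
    #   inputs = processor(wav_16k.squeeze(), sampling_rate=16_000, return_tensors="pt")
    #   with torch.no_grad():
    #       logits = phoneme_model(inputs.input_values).logits
    #   phones = processor.batch_decode(torch.argmax(logits, dim=-1))[0].split()
    #   ppm = len(phones) / (wav_16k.size(-1) / 16_000 / 60)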
    description = """
MOS prediction demo using the UTMOS-strong (w/o phoneme encoder) model, \
which is trained on the main-track dataset.
This demo only accepts .wav format, best at a 16 kHz sampling rate.
The paper is available [here](https://arxiv.org/abs/2204.02152).
Adds ASR based on wav2vec-960; currently only English is available.
Adds a WER interface.
"""
    # Gradio widgets carried over from the demo UI (not used in the CLI flow below)
    reference_id = gr.Textbox(
        value="ID", placeholder="Utter ID", label="Reference_ID"
    )
    reference_textbox = gr.Textbox(
        value="", placeholder="Input reference here", label="Reference"
    )
    # Collect per-utterance results as CSV rows
    result = []
    result.append("id,ref,hyp,wer")
    for wav_id, x, y in track(
        zip(refs_ids, ref_wavs, refs_txt),
        total=len(refs_ids),
        description="Loading reference information",
    ):
        trans, wer = calc_wer(x, y, ASR_pipeline=ASR_pipeline)
        record = ",".join([wav_id, str(y), str(trans), str(wer)])
        result.append(record)
    # Output
    if args.tag is None:
        args.tag = Path(args.ref_wavs).stem
    # Make output_dir
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    with open("%s/%s.csv" % (args.output_dir, args.tag), "w") as f:
        print("\n".join(result), file=f)
    # Generating config
    if args.to_config == "True":
        config_dict = {
            "exp_id": args.tag,
            "ref_txt": args.ref_txt,
            "ref_feature": "%s/%s.csv" % (args.output_dir, args.tag),
            "ref_wavs": args.ref_wavs,
            "thre": {
                "minppm": 100,
                "maxppm": 100,
                "WER": 0.1,
                "AUTOMOS": 4.0,
            },
            "auth": {"username": None, "password": None},
        }
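        # The dumped YAML (sketch; PyYAML sorts keys alphabetically) looks like:
        #   auth:
        #     password: null
        #     username: null
        #   exp_id: <tag>
        #   ref_feature: <output_dir>/<tag>.csv
        #   ref_txt: <ref_txt>
        #   ref_wavs: <ref_wavs>
        #   thre:
        #     AUTOMOS: 4.0
        #     WER: 0.1
        #     maxppm: 100
        #     minppm: 100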
        Path("./config").mkdir(exist_ok=True)  # make sure the config directory exists
        with open("./config/%s.yaml" % args.tag, "w") as config_f:
            rprint("Dumping as config ./config/%s.yaml" % args.tag)
            rprint(config_dict)
            yaml.dump(config_dict, stream=config_f)
        rprint("Change parameters in ./config/%s.yaml if necessary" % args.tag)
    print("Reference Dumping Finished")