| """ | |
| TODO: | |
| + [x] Load Configuration | |
| + [ ] Multi ASR Engine | |
| + [ ] Batch / Real Time support | |
| """ | |
| from pathlib import Path | |
| from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC, AutoProcessor | |
| from datasets import load_dataset | |
| from datasets import Dataset, Audio | |
| import pdb | |
| import string | |
| # local import | |
| import sys | |
| sys.path.append("src") | |
# token_model = AutoModelForCTC.from_pretrained(
#     "facebook/wav2vec2-base-960h"
# )
# ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")

audio_path = "/Users/kevingeng/Laronix/Laronix_PAL_ASR_Offline_Plot/data/samples/3_Healthy1.wav"
audio_dir = "/Users/kevingeng/Laronix/laronix_automos/data/Patient_sil_trim_16k_normed_5_snr_40/"
# tgt_audio_dir = "/Users/kevingeng/Laronix/Dataset/Pneumatic/automos"

# src_audio_list = sorted(Path(src_audio_dir).glob("**/*.wav"))
# src_audio_list = [str(x) for x in src_audio_list]
# src_audio_dict = {"audio": src_audio_list}
# src_dataset = Dataset.from_dict(src_audio_dict).cast_column("audio", Audio())

# tgt_audio_list = sorted(Path(tgt_audio_dir).glob("**/*.wav"))
# tgt_audio_list = [str(x) for x in tgt_audio_list]
# tgt_audio_dict = {"audio": tgt_audio_list}
# tgt_dataset = Dataset.from_dict(tgt_audio_dict).cast_column("audio", Audio())

# Get Transcription, WER and PPM
| """ | |
| TODO: | |
| [DONE]: Automatic generating Config | |
| """ | |
| import yaml | |
| import argparse | |
| import sys | |
| from pathlib import Path | |
| sys.path.append("./src") | |
| import lightning_module | |
| from UV import plot_UV, get_speech_interval | |
| from transformers import pipeline | |
| from rich.progress import track | |
| from rich import print as rprint | |
| import numpy as np | |
| import jiwer | |
| import pdb | |
| import torch.nn as nn | |
| import torch | |
| import torchaudio | |
| import gradio as gr | |
| from sys import flags | |
| from random import sample | |
| from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC | |
# root_path = Path(__file__).parents[1]

class ChangeSampleRate(nn.Module):
    """Linear-interpolation resampler. Only accepts 1-channel waveform input."""

    def __init__(self, input_rate: int, output_rate: int):
        super().__init__()
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        # Only accepts 1-channel waveform input
        wav = wav.view(wav.size(0), -1)
        new_length = wav.size(-1) * self.output_rate // self.input_rate
        # Fractional source index for every output sample
        indices = torch.arange(new_length) * (
            self.input_rate / self.output_rate
        )
        round_down = wav[:, indices.long()]
        round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
        # Linear interpolation between the two neighbouring source samples
        output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(
            0
        ) + round_up * indices.fmod(1.0).unsqueeze(0)
        return output
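
# Example usage (illustrative; the path below is hypothetical):
# wav, sr = torchaudio.load("data/samples/example.wav")  # wav: (1, T)
# csr = ChangeSampleRate(input_rate=sr, output_rate=16_000)
# wav_16k = csr(wav)  # shape: (1, T * 16000 // sr)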
model = lightning_module.BaselineLightningModule.load_from_checkpoint(
    "./src/epoch=3-step=7459.ckpt"
).eval()
def calc_wer(audio_path, ref):
    """Transcribe `audio_path` and compute WER against the reference `ref`.

    Relies on module-level `p` (ASR pipeline) and `transformation`
    (jiwer.Compose), defined in the commented-out __main__ block below.
    """
    # Resample to 16 kHz. Note: the pipeline below reads the file itself,
    # so the resampled waveform is currently unused.
    wav, sr = torchaudio.load(audio_path)
    osr = 16_000
    csr = ChangeSampleRate(sr, osr)
    out_wavs = csr(wav)
    # ASR
    trans = p(audio_path)["text"]
    # WER
    wer = jiwer.wer(
        ref,
        trans,
        truth_transform=transformation,
        hypothesis_transform=transformation,
    )
    return trans, wer
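
# Example call (illustrative; requires `p` and `transformation` from the
# commented-out __main__ block below):
# trans, wer = calc_wer("data/samples/3_Healthy1.wav", "SOME REFERENCE TEXT")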
# if __name__ == "__main__":
#     # Argparse
#     parser = argparse.ArgumentParser(
#         prog="get_ref_PPM",
#         description="Generate Phoneme per Minute (and Voice/Unvoice plot)",
#         epilog="",
#     )
#     parser.add_argument(
#         "--tag",
#         type=str,
#         default=None,
#         required=False,
#         help="ID tag for output *.csv",
#     )
#     parser.add_argument("--ref_txt", type=str, required=True, help="Reference TXT")
#     parser.add_argument(
#         "--ref_wavs", type=str, required=True, help="Reference WAVs"
#     )
#     parser.add_argument(
#         "--output_dir",
#         type=str,
#         required=True,
#         help="Output Directory for *.csv",
#     )
#     parser.add_argument(
#         "--to_config",
#         choices=["True", "False"],
#         default="False",
#         help="Generating Config from .txt and wavs/*wav",
#     )
#     args = parser.parse_args()
#     refs = np.loadtxt(args.ref_txt, delimiter="\n", dtype="str")
#     refs_ids = [x.split()[0] for x in refs]
#     refs_txt = [" ".join(x.split()[1:]) for x in refs]
#     ref_wavs = [str(x) for x in sorted(Path(args.ref_wavs).glob("**/*.wav"))]
#     # pdb.set_trace()
#     if len(refs) != len(ref_wavs):
#         print("Error: Text and Wavs don't match")
#         exit()
#     # ASR part
#     p = pipeline("automatic-speech-recognition")
#     # WER part
#     transformation = jiwer.Compose(
#         [
#             jiwer.ToLowerCase(),
#             jiwer.RemoveWhiteSpace(replace_by_space=True),
#             jiwer.RemoveMultipleSpaces(),
#             jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
#         ]
#     )
#     # WPM part
#     processor = Wav2Vec2Processor.from_pretrained(
#         "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
#     )
#     phoneme_model = Wav2Vec2ForCTC.from_pretrained(
#         "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
#     )
#     # phoneme_model = pipeline(model="facebook/wav2vec2-xlsr-53-espeak-cv-ft")
#     description = """
#     MOS prediction demo using UTMOS-strong w/o phoneme encoder model, \
#     which is trained on the main track dataset.
#     This demo only accepts .wav format. Best at 16 kHz sampling rate.
#     Paper is available [here](https://arxiv.org/abs/2204.02152)
#     Add ASR based on wav2vec-960, currently only English available.
#     Add WER interface.
#     """
#     referance_id = gr.Textbox(
#         value="ID", placeholder="Utter ID", label="Reference_ID"
#     )
#     referance_textbox = gr.Textbox(
#         value="", placeholder="Input reference here", label="Reference"
#     )
#     # Set up interface
#     result = []
#     result.append("id, trans, wer")
#     for id, x, y in track(
#         zip(refs_ids, ref_wavs, refs_txt),
#         total=len(refs_ids),
#         description="Loading references information",
#     ):
#         trans, wer = calc_wer(x, y)
#         record = ",".join(
#             [
#                 id,
#                 str(trans),
#                 str(wer),
#             ]
#         )
#         result.append(record)
#     # Output
#     if args.tag is None:
#         args.tag = Path(args.ref_wavs).stem
#     # Make output_dir
#     # pdb.set_trace()
#     Path.mkdir(Path(args.output_dir), exist_ok=True)
#     # pdb.set_trace()
#     with open("%s/%s.csv" % (args.output_dir, args.tag), "w") as f:
#         print("\n".join(result), file=f)
#     # Generating config
#     if args.to_config == "True":
#         config_dict = {
#             "exp_id": args.tag,
#             "ref_txt": args.ref_txt,
#             "ref_feature": "%s/%s.csv" % (args.output_dir, args.tag),
#             "ref_wavs": args.ref_wavs,
#             "thre": {
#                 "minppm": 100,
#                 "maxppm": 100,
#                 "WER": 0.1,
#                 "AUTOMOS": 4.0,
#             },
#             "auth": {"username": None, "password": None},
#         }
#         with open("./config/%s.yaml" % args.tag, "w") as config_f:
#             rprint("Dumping as config ./config/%s.yaml" % args.tag)
#             rprint(config_dict)
#             yaml.dump(config_dict, stream=config_f)
#             rprint("Change parameter ./config/%s.yaml if necessary" % args.tag)
#     print("Reference Dumping Finished")
def dataclean(example):
    # Uppercase the transcription and strip punctuation
    return {
        "transcription": example["transcription"]
        .upper()
        .translate(str.maketrans("", "", string.punctuation))
    }
# processor = AutoFeatureExtractor.from_pretrained(
#     "facebook/wav2vec2-base-960h"
# )
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base")
def prepare_dataset(batch):
    audio = batch["audio"]
    # The processor returns "input_values" (audio features) and,
    # because `text=` is passed, "labels" (tokenized transcription)
    batch = processor(audio["array"], sampling_rate=audio["sampling_rate"], text=batch["transcription"])
    batch["input_length"] = len(batch["input_values"][0])
    return batch
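
# Sketch of the expected output for one example (illustrative values):
# {"input_values": [array([...])], "labels": [5, 12, ...], "input_length": 48000}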
src_dataset = load_dataset("audiofolder", data_dir=audio_dir, split="train")
src_dataset = src_dataset.map(dataclean)

# train_dev / test
ds = src_dataset.train_test_split(test_size=0.1)
train_dev = ds["train"]
# train / dev
train_dev = train_dev.train_test_split(test_size=int(len(src_dataset) * 0.1))
# train / dev / test
train = train_dev["train"]
test = ds["test"]
dev = train_dev["test"]
# pdb.set_trace()
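
# Sanity check (illustrative): with a 10% test split and a dev split sized at
# 10% of the whole corpus, sizes should be roughly 80/10/10 percent.
# print(len(train), len(dev), len(test))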
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)
    # Restore padding so the tokenizer can decode the references
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
    pred_str = processor.batch_decode(pred_ids)
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
    wer = jiwer.wer(label_str, pred_str)
    return {"wer": wer}
# pdb.set_trace()
# TOKENLIZER("data/samples/5_Laronix1.wav")
# pdb.set_trace()

# tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
encoded_train = train.map(prepare_dataset, num_proc=4)
encoded_dev = dev.map(prepare_dataset, num_proc=4)
from transformers import AutoModelForCTC, TrainingArguments, Trainer

# Note: this rebinds `model`; the MOS checkpoint loaded above is no longer referenced
model = AutoModelForCTC.from_pretrained(
    "facebook/wav2vec2-base",
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
)
# pdb.set_trace()
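
# The Trainer below needs a collator that pads audio inputs and labels
# separately. This is a minimal sketch following the common wav2vec2 CTC
# fine-tuning pattern; it assumes prepare_dataset stored "input_values" and
# "labels" as above. Adjust if the processor's pad API differs.
from dataclasses import dataclass
from typing import Dict, List, Union

@dataclass
class DataCollatorCTCWithPadding:
    processor: AutoProcessor
    padding: Union[bool, str] = "longest"

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Pad audio inputs and tokenized labels separately, since they
        # have different lengths and need different padding methods
        input_features = [{"input_values": feature["input_values"][0]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        batch = self.processor.pad(input_features, padding=self.padding, return_tensors="pt")
        labels_batch = self.processor.pad(labels=label_features, padding=self.padding, return_tensors="pt")
        # Mask label padding with -100 so it is ignored by the CTC loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        batch["labels"] = labels
        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, padding="longest")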
training_args = TrainingArguments(
    output_dir="my_awesome_asr_mind_model",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=2000,
    gradient_checkpointing=True,
    fp16=True,
    group_by_length=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,
)
# pdb.set_trace()
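
# Note: effective train batch size is 8 x 2 gradient-accumulation steps = 16
# per device. fp16=True assumes a CUDA device, and push_to_hub=True assumes a
# configured Hugging Face token.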
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_train,
    eval_dataset=encoded_dev,
    tokenizer=processor.feature_extractor,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
# pdb.set_trace()
trainer.train()
# x = tokenizer(test['transcription'][0])