import gradio as gr
import numpy as np
import torch
# from datasets import load_dataset  # only needed for the commented-out SpeechT5 path below
from transformers import pipeline, VitsModel, VitsTokenizer
# from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor, pipeline

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# load speech translation checkpoint (Whisper's "translate" task maps source speech to English text)
asr_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device)
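# Note (assumption, not in the original): for recordings longer than Whisper's
# native 30 s window, chunked inference could be enabled via the pipeline's
# chunk_length_s argument, e.g. pipeline(..., chunk_length_s=30); the demo
# works without it for short clips.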

# load text-to-speech checkpoint and speaker embeddings
# processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")

# model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
# vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)

# embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
# speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
# using VITS MMS TTS (German) instead of SpeechT5 TTS
model = VitsModel.from_pretrained("facebook/mms-tts-deu").to(device)
tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-deu")
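
# Sanity note: MMS VITS checkpoints expose their output rate as
# model.config.sampling_rate (16 kHz for facebook/mms-tts-deu), which is the
# rate Gradio needs alongside the waveform.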

def translate(audio):
    try:
        # "task": "translate" makes Whisper emit English text; return_timestamps
        # lets the pipeline handle audio longer than 30 seconds
        outputs = asr_pipe(audio, generate_kwargs={"task": "translate", "return_timestamps": True})
        return outputs["text"]
    except Exception as e:
        print(f"Error in translation: {e}")
        return "Error during translation"


def synthesise(text):
    try:
        inputs = tokenizer(text, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)  # keep inputs on the model's device
        with torch.no_grad():
            outputs = model(input_ids)
        # VITS is end-to-end: the forward pass returns the waveform directly,
        # so no separate vocoder (as with SpeechT5) is needed
        speech = outputs["waveform"].cpu()
        return speech.squeeze()
    except Exception as e:
        print(f"Error in synthesis: {e}")
        return None


def speech_to_speech_translation(audio):
    translated_text = translate(audio)
    print("translated text:\t", translated_text)
    if translated_text == "Error during translation":
        return None  # there is a single Audio output; None clears it on translation error

    synthesised_speech = synthesise(translated_text)

    if synthesised_speech is None:
        return None  # likewise clear the Audio output on synthesis error

    # scale the float waveform (roughly in [-1, 1]) to 16-bit PCM for Gradio
    synthesised_speech = (synthesised_speech.numpy() * 32767).astype(np.int16)
    return model.config.sampling_rate, synthesised_speech
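
# Manual check (a sketch; assumes ./example.wav exists, as in the file tab below):
#   rate, wav = speech_to_speech_translation("./example.wav")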


title = "Cascaded STST"
description = """
Demo for cascaded speech-to-speech translation (STST), mapping from source speech in any language to target speech in German. The demo uses OpenAI's [Whisper Base](https://huggingface.co/openai/whisper-base) model for speech translation, and Facebook's
[MMS TTS (German)](https://huggingface.co/facebook/mms-tts-deu) model for text-to-speech:

![Cascaded STST](https://huggingface.co/datasets/huggingface-course/audio-course-images/resolve/main/s2st_cascaded.png "Diagram of cascaded speech to speech translation")
"""

demo = gr.Blocks()

mic_translate = gr.Interface(
    fn=speech_to_speech_translation,
    inputs=gr.Microphone(type="filepath"),
    outputs=gr.Audio(label="Generated Speech", type="numpy"),
    title=title,
    description=description,
)

file_translate = gr.Interface(
    fn=speech_to_speech_translation,
    inputs=gr.Audio(type="filepath"),
    outputs=gr.Audio(label="Generated Speech", type="numpy"),
    examples=[["./example.wav"]],
    title=title,
    description=description,
)

with demo:
    gr.TabbedInterface([mic_translate, file_translate], ["Microphone", "Audio File"])

demo.launch(debug=True, height=600)
# demo.launch(height=600)
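
# Assumed Spaces dependencies (a requirements.txt sketch, not part of this file):
#   gradio
#   torch
#   transformers
#   numpy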