File size: 4,326 Bytes
fed1edc
17cb7d3
d087544
17cb7d3
 
 
 
 
 
 
 
f4ab270
7089c67
19a569a
f4ab270
d087544
 
 
 
 
 
 
f4ab270
17cb7d3
 
 
4e86063
ef90c80
4e86063
ae07559
4e86063
adf249b
4e86063
 
4474937
 
 
 
 
 
 
 
4e86063
 
 
 
35c31a2
 
e573002
4e86063
 
 
 
 
8afa9a4
4db4bee
cf9fc36
fed1edc
 
8afa9a4
 
4e86063
4474937
 
 
 
17cb7d3
 
d087544
17cb7d3
 
 
4e86063
4db4bee
19a569a
17cb7d3
d087544
 
17cb7d3
 
7089c67
 
d087544
 
325f853
17cb7d3
 
 
 
325f853
 
17cb7d3
 
 
 
325f853
 
 
 
 
 
17cb7d3
 
 
 
7089c67
d087544
 
cb05e89
 
 
17cb7d3
 
d087544
 
17cb7d3
 
 
d087544
 
17cb7d3
 
 
 
fed1edc
7089c67
fed1edc
17cb7d3
 
 
2cdfcf0
 
b60fd95
2cdfcf0
8cd000d
2cdfcf0
609fa27
17cb7d3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""Røst speech-to-text demo."""

import logging
import os
import warnings
from urllib.parse import quote

import gradio as gr
import numpy as np
import samplerate
import torch
import torch_audiomentations as ta
from dotenv import load_dotenv
from punctfix import PunctFixer
from transformers import pipeline

# Log lines look like "<timestamp> ⋅ <logger name> ⋅ <message>".
_LOG_FORMAT = "%(asctime)s ⋅ %(name)s ⋅ %(message)s"
_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT)
logger = logging.getLogger("roest-asr-demo")

# Presumably pulls settings such as HUGGINGFACE_HUB_TOKEN (read further down) from a
# local .env file -- confirm against the deployment setup.
load_dotenv()

# Silence FutureWarning noise.
warnings.filterwarnings(action="ignore", category=FutureWarning)


# Hugging Face Hub identifier of the speech recognition model used by the demo.
MODEL_ID = "CoRal-project/roest-wav2vec2-315m-v2"

# Heading shown at the top of the demo page.
TITLE = "Røst Speech-to-text Demo"

# Subject and body of the pre-filled "contact us" e-mail, already encoded for use in
# the query part of a mailto: URI. Per RFC 6068 a space must be "%20" (a literal "+"
# is NOT decoded as a space by mail clients), line breaks must be "%0D%0A", and
# non-ASCII or reserved characters must be percent-encoded as UTF-8 octets.
# `quote(..., safe="")` takes care of all of that.
EMAIL_SUBJECT = quote("Røst tale-til-tekst demo", safe="")
_EMAIL_BODY_TEXT = """
Hej,

Jeg har lige prøvet jeres Røst tale-til-tekst demo, og jeg er imponeret!

Jeg kunne godt tænke mig at høre mere om jeres talegenkendelsesløsninger.

Min use case er [indsæt use case her].

Venlig hilsen,
[dit navn]
""".strip()
EMAIL_BODY = quote(_EMAIL_BODY_TEXT.replace("\n", "\r\n"), safe="")

# Inline SVG that is embedded in DESCRIPTION to depict the upload button.
ICON = """
<svg xmlns="http://www.w3.org/2000/svg" width="25px" height="25px" viewBox="0 0 24 24"
    fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round"
    stroke-linejoin="round" style="display: inline;">
  <path d="M21 15v4a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2v-4"/>
  <polyline points="17 8 12 3 7 8"/>
  <line x1="12" y1="3" x2="12" y2="15"/>
</svg>
"""

# Markdown shown under the title: usage instructions plus a mailto: link that opens a
# pre-filled e-mail built from EMAIL_SUBJECT and EMAIL_BODY.
DESCRIPTION = f"""
This is a demo of the Danish speech recognition model
[{MODEL_ID}](https://huggingface.co/{MODEL_ID}).

Press "Record" to record your
own voice. When you're done you can press "Stop" to stop recording and "Submit" to
send the audio to the model for transcription. You can also upload an audio file by
pressing the {ICON} button.

_If you like what you see and are interested in integrating speech-to-text solutions
into your products, feel free to
[contact us](mailto:alexandra@alexandra.dk?subject={EMAIL_SUBJECT}&body={EMAIL_BODY})._
"""

# Both models are created once at import time and reused by every request.
logger.info("Loading the ASR model...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_hub_token = os.getenv("HUGGINGFACE_HUB_TOKEN")  # None when the variable is unset
transcriber = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_ID,
    device=device,
    token=_hub_token,
)

logger.info("Loading the punctuation fixer model...")
transcription_fixer = PunctFixer(language="da", device=device)

# p=1.0: apply peak normalisation to every clip rather than a random subset.
normaliser = ta.PeakNormalization(p=1.0)

logger.info("Models loaded, ready to transcribe audio.")

def transcribe_audio(sampling_rate_and_audio: tuple[int, np.ndarray] | None) -> str:
    """Transcribe the audio.

    Args:
        sampling_rate_and_audio:
            A tuple with the sampling rate and the audio, or None if no audio was
            provided.

    Returns:
        The transcription.
    """
    if sampling_rate_and_audio is None:
        return (
            "No audio was provided. Please record or upload an audio clip, and try "
            "again."
        )

    sampling_rate, audio = sampling_rate_and_audio
    if audio.ndim > 1:
        audio = np.mean(audio, axis=1)
    audio = samplerate.resample(audio, 16_000 / sampling_rate, "sinc_best")
    audio = normaliser(torch.tensor(audio).unsqueeze(0).unsqueeze(0)).squeeze().numpy()

    logger.info(f"Transcribing audio clip of {len(audio) / 16_000:.2f} seconds...")
    transcription = transcriber(
        inputs=audio, generate_kwargs=dict(language="danish", task="transcribe")
    )
    if not isinstance(transcription, dict):
        return ""

    logger.info(f"Raw transcription is {transcription['text']!r}. Cleaning it up...")
    cleaned_transcription = transcription_fixer.punctuate(
        text=transcription["text"]
    )

    logger.info(f"Final transcription: {cleaned_transcription!r}")
    return cleaned_transcription

# Audio clips offered as examples in the interface.
_example_clips = [
    "https://filedn.com/lRBwPhPxgV74tO0rDoe8SpH/audio-examples/bornholmsk.wav",
    "https://filedn.com/lRBwPhPxgV74tO0rDoe8SpH/audio-examples/soenderjysk.wav",
    "https://filedn.com/lRBwPhPxgV74tO0rDoe8SpH/audio-examples/nordjysk.wav",
    "https://filedn.com/lRBwPhPxgV74tO0rDoe8SpH/audio-examples/accent.wav",
]

# Input widget accepting either a microphone recording or an uploaded file.
_audio_input = gr.Audio(show_label=False, sources=["microphone", "upload"])

# Wire the transcription function to the audio input and a text box output.
demo = gr.Interface(
    fn=transcribe_audio,
    inputs=_audio_input,
    outputs="textbox",
    title=TITLE,
    description=DESCRIPTION,
    examples=_example_clips,
    cache_examples=False,
)

demo.launch()