"""This recipe to train CLAP.
It supports distillation using tinyCLAP (https://arxiv.org/abs/2311.14517).

Authors
    * Francesco Paissan 2024
"""

import sys

import gradio as gr
import speechbrain as sb
import torch
import torchaudio
import torchaudio.transforms as T
from hyperpyyaml import load_hyperpyyaml
from speechbrain.utils.distributed import run_on_main

torch.backends.cudnn.enabled = False


class CLAPBrain(sb.Brain):
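    """Thin sb.Brain wrapper around the CLAP teacher and, when configured,
    a distilled tinyCLAP student audio encoder."""
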
    def preprocess(self, wavs):
        """Pre-process wavs."""
        x = self.hparams.spectrogram_extractor(wavs)
        x = self.hparams.logmel_extractor(x)

        return x

    def prepare_txt_features(self, text):
        """Prepares text features to input in CLAP text encoder."""
        txt_inp = self.hparams.txt_tokenizer(
            text,
            max_length=self.hparams.text_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(self.device)

        return txt_inp

    def compute_sim(self, audio_embed, caption_embed):
        """Computes CLAP similarity metric."""
        similarity = audio_embed @ caption_embed.t()

        return similarity

    def compute_forward(self, batch, stage):
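        """Computes text, teacher-audio, and (optional) student-audio
        embeddings for a batch of waveforms and captions."""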
        if len(batch) == 2:
            wavs, caption = batch
        else:
            wavs, caption, _, _ = batch

        wavs = wavs.to(self.device).squeeze(1)

        x_sb = self.preprocess(wavs)

        text_inp = self.prepare_txt_features(caption)

        txt_shared, aud_shared = self.hparams.clap(
            x_sb,
            text_inp.input_ids.data,
            text_inp.token_type_ids.data,
            text_inp.attention_mask.data,
        )

        # The distilled student is only present in the tinyCLAP setting;
        # default to None so the return below never hits an unbound name.
        aud_shared_student = None
        if not hasattr(self.modules, "clap"):
            aud_shared_student, _, _ = self.modules.clap_student(x_sb)
            aud_shared_student = aud_shared_student / aud_shared_student.norm(
                dim=1, keepdim=True
            )

        return txt_shared, aud_shared, aud_shared_student


def audio_preprocess(x, sample_rate):
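    """Loads an audio file, resamples it to ``sample_rate``, and downmixes
    it to mono."""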
    tmp, sr = torchaudio.load(x)
    resample = T.Resample(sr, sample_rate)

    tmp = resample(tmp)
    tmp = tmp.sum(0, keepdim=True)  # downmix multi-channel audio to mono

    return tmp


@torch.no_grad()
def inference_wrapper(clap_brain):
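    """Returns a Gradio-compatible closure that scores an audio file against
    a text prompt with the (distilled) CLAP model."""
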
    def f(wav_path, prompt):
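        """Computes the similarity between ``wav_path`` and ``prompt``."""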
        clap_brain.modules.eval()
        tmp = audio_preprocess(wav_path, clap_brain.hparams.sample_rate)

        txt_embed, _, aud_embed_student = clap_brain.compute_forward(
            [tmp, prompt], stage=sb.Stage.TEST
        )
        sim = clap_brain.compute_sim(aud_embed_student, txt_embed)

        return f"tinyCLAP similarity is: {round(sim.item(), 2)}"

    return f


if __name__ == "__main__":
    # This demo uses a fixed hyperparameter file; swap in
    # `sb.parse_arguments(sys.argv[1:])` to configure it from the CLI.
    hparams_file = "hparams/inference.yaml"

    # Load the hyperparameters file (no command-line overrides in this demo)
    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin, {})

    # Tensorboard logging
    if hparams["use_tensorboard"]:
        from speechbrain.utils.train_logger import TensorboardLogger

        hparams["tensorboard_train_logger"] = TensorboardLogger(
            hparams["tensorboard_logs_folder"]
        )

    hparams["clap"].to(hparams["device"])
    hparams["clap"].requires_grad_(False)
    hparams["clap"].eval()

    if hparams["zs_eval"]:
        hparams["class_list"] = datasets["train"].dataset.classes

    if hparams["audioenc_name_student"] is not None:
        if hparams["projection_only"]:
            print("Freezing Base AudioEncoder. Updating only the projection layers.")
            hparams["student_model"].base.requires_grad_(False)

    hparams["spectrogram_extractor"].to(hparams["device"])
    hparams["logmel_extractor"].to(hparams["device"])

    clap_brain = CLAPBrain(
        modules=hparams["modules"],
        hparams=hparams,
        # Keep the Brain on the same device as the teacher and extractors.
        run_opts={"device": hparams["device"]},
    )

    if hparams["pretrained_CLAP"] is not None:
        print("Loading CLAP model...")
        run_on_main(hparams["load_CLAP"].collect_files)
        hparams["load_CLAP"].load_collected()

    inference_api = inference_wrapper(clap_brain)

    examples_list = [
        ["./tunztunz_music.wav", "this is the sound of house music"],
        ["./siren.wav", "this is the sound of sirens wailing"],
        [
            "./whistling_and_chirping.wav",
            "someone is whistling while birds are chirping",
        ],
    ]

    demo = gr.Interface(
        fn=inference_api,
        inputs=[gr.Audio(type="filepath"), gr.Textbox()],
        outputs=["text"],
        examples=examples_list,
    )
    demo.launch()
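
    # Alternatively (sketch; assumes the example assets above exist), skip the
    # UI and call the closure directly:
    #   print(inference_api("./siren.wav", "this is the sound of sirens wailing"))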