File size: 12,490 Bytes
5e02fce ea94406 5e02fce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 |
"""
Code to define CLAP-related networks.
Parts of this code are adapted from https://github.com/zhepeiw/clap_curation
Credits:
* Francesco Paissan 2024
"""
import numpy as np
import torch
import torch.nn.functional as F
from micromind.networks import PhiNet
from speechbrain.utils.fetching import fetch
from torch import nn
from torchinfo import summary
from transformers import AutoModel, BatchEncoding
def get_model_from_str(s, vs=("alpha", "beta", "t0", "N")):
    """Parse PhiNet hyper-parameters out of an underscore-separated name.

    The name is split on ``"_"``; for every key in ``vs`` the token that
    directly follows the first occurrence of that key is converted to
    ``float`` (e.g. ``"phinet_alpha_1.50_beta_0.75_t0_4_N_3"``).

    Arguments
    ---------
    s : str
        Model name containing ``<key>_<value>`` pairs joined by underscores.
    vs : iterable of str
        Keys to extract. Must include ``"t0"`` and ``"N"``, which are renamed
        in the result.

    Returns
    -------
    dict
        Mapping from key to ``float`` value, with ``"t0"`` renamed to
        ``"t_zero"`` and ``"N"`` renamed to ``"num_layers"``.

    Raises
    ------
    ValueError
        If a requested key is not a token of ``s`` or its value is not a
        valid float.
    IndexError
        If a requested key is the last token of ``s`` (no value follows).
    KeyError
        If ``vs`` does not contain ``"t0"`` or ``"N"``.
    """
    # Split once; every lookup below works on the same token list.
    tokens = s.split("_")

    def get_var(key):
        try:
            position = tokens.index(key)
        except ValueError:
            raise ValueError(
                f"Key {key!r} not found in model name {s!r}"
            ) from None
        return tokens[position + 1]

    ret = {k: float(get_var(k)) for k in vs}

    # Rename to the keyword names used when the dict is splatted into PhiNet.
    ret["t_zero"] = ret.pop("t0")
    ret["num_layers"] = ret.pop("N")
    return ret
def get_audio_encoder(name: str):
    """Resolve an audio-encoder name to an encoder.

    Note the asymmetry in what is returned: for ``"Cnn14"`` the *class* is
    returned (the caller instantiates it), whereas for any name containing
    ``"phinet"`` an already-built PhiNet *instance* is returned.

    Arguments
    ---------
    name : str
        Either ``"Cnn14"`` or a string containing ``"phinet"`` from which
        ``get_model_from_str`` extracts the PhiNet hyper-parameters.

    Returns
    -------
    type or nn.Module
        The ``Cnn14`` class, or a PhiNet instance built with
        ``input_shape=(1, 640, 64)`` and ``compatibility=True``.

    Raises
    ------
    ValueError
        If ``name`` matches neither supported encoder.
    """
    if name == "Cnn14":
        return Cnn14

    if "phinet" in name:
        phinet_conf = get_model_from_str(name)
        return PhiNet(input_shape=(1, 640, 64), compatibility=True, **phinet_conf)

    # ValueError is a subclass of Exception, so existing handlers still work.
    raise ValueError(
        "The audio encoder name {} is incorrect or not supported".format(name)
    )
class Projection(nn.Module):
    """Projection head: linear map, GELU-gated residual branch, layer norm.

    Arguments
    ---------
    d_in : int
        Size of the last dimension of the input.
    d_out : int
        Size of the last dimension of the output.
    p : float
        Dropout probability applied to the residual branch.
    """

    def __init__(self, d_in: int, d_out: int, p: float = 0.5) -> None:
        super().__init__()
        # Attribute names and creation order are part of the checkpoint layout.
        self.linear1 = nn.Linear(d_in, d_out, bias=False)
        self.linear2 = nn.Linear(d_out, d_out, bias=False)
        self.layer_norm = nn.LayerNorm(d_out)
        self.drop = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and return the layer-normalised result."""
        shortcut = self.linear1(x)
        branch = self.linear2(F.gelu(shortcut))
        return self.layer_norm(shortcut + self.drop(branch))
class PhiNet(PhiNet):
    """PhiNet backbone wrapped to return the same output dictionary as Cnn14.

    This subclass deliberately reuses (and shadows) the imported ``PhiNet``
    name. It adds an input batch-norm over the last input axis and, when
    ``embedding_dim`` is given, a 1x1 strided convolution that maps the
    backbone output to ``embedding_dim`` channels.

    Arguments
    ---------
    embedding_dim : int or None
        Number of output channels of the 1x1 projection. When ``None`` no
        projection is created and ``forward`` pools the backbone output
        directly.
    *args, **kwargs
        Forwarded unchanged to the base PhiNet constructor.
    """

    def __init__(self, embedding_dim=2048, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # 64 features: applied in forward() with axes 1 and 3 swapped, so the
        # last input axis must have size 64.
        self.bn0 = nn.BatchNorm2d(64)

        if embedding_dim is not None:
            # Output-channel count of the second-to-last sub-layer in the last
            # backbone block (presumably its final convolution; verify against
            # the micromind PhiNet implementation).
            in_channels_next = self._layers[-1]._layers[-2].weight.shape[0]
            self.pn_block = nn.Conv2d(
                in_channels_next,
                embedding_dim,
                kernel_size=1,
                stride=2,
            )
        else:
            # Explicit marker so forward() can skip the projection instead of
            # failing with AttributeError.
            self.pn_block = None

    def forward(self, x):
        """Run the backbone and return pooled and unpooled features.

        Arguments
        ---------
        x : torch.Tensor
            A 3-D input gets a singleton axis inserted at position 1; the last
            axis must have size 64 to match ``bn0``.

        Returns
        -------
        dict
            ``"embedding"``: tuple ``(pooled, backbone_map)`` where ``pooled``
            is the (optionally projected) backbone output averaged over its
            last two axes and ``backbone_map`` is the raw backbone output.
            ``"clipwise_output"``: the same ``pooled`` tensor.
        """
        if x.dim() == 3:
            x = x[:, None]

        # Swap axes 1 and 3 so BatchNorm2d normalises over the last axis.
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        x = super().forward(x)
        embedding = x

        if self.pn_block is not None:
            x = self.pn_block(x)

        x = x.mean((-1, -2))
        return {"embedding": (x, embedding), "clipwise_output": x}
class ConvBlock(nn.Module):
    """Two 3x3 convolutions, each followed by batch-norm and ReLU, then pooling.

    Arguments
    ---------
    in_channels : int
        Channels of the input feature map.
    out_channels : int
        Channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Both convolutions share the same geometry; padding 1 with a 3x3
        # kernel and stride 1 keeps the spatial size unchanged.
        conv_geometry = dict(
            kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        )
        self.conv1 = nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels, **conv_geometry
        )
        self.conv2 = nn.Conv2d(
            in_channels=out_channels, out_channels=out_channels, **conv_geometry
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, input, pool_size=(2, 2), pool_type="avg"):
        """Apply both conv stages and pool the result.

        Arguments
        ---------
        input : torch.Tensor
            Feature map fed to the first convolution.
        pool_size : tuple of int
            Kernel size of the pooling window.
        pool_type : str
            ``"avg"``, ``"max"`` or ``"avg+max"`` (sum of both poolings).

        Raises
        ------
        Exception
            If ``pool_type`` is none of the supported values. The check runs
            after the convolutions, as in the original flow.
        """
        feats = F.relu_(self.bn1(self.conv1(input)))
        feats = F.relu_(self.bn2(self.conv2(feats)))

        if pool_type == "max":
            return F.max_pool2d(feats, kernel_size=pool_size)
        if pool_type == "avg":
            return F.avg_pool2d(feats, kernel_size=pool_size)
        if pool_type == "avg+max":
            averaged = F.avg_pool2d(feats, kernel_size=pool_size)
            peaks = F.max_pool2d(feats, kernel_size=pool_size)
            return averaged + peaks

        raise Exception("Incorrect argument!")
class ConvBlock5x5(nn.Module):
    """Single 5x5 convolution with batch-norm and ReLU, followed by pooling.

    Arguments
    ---------
    in_channels : int
        Channels of the input feature map.
    out_channels : int
        Channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Padding 2 with a 5x5 kernel and stride 1 keeps the spatial size.
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(5, 5),
            stride=(1, 1),
            padding=(2, 2),
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)

    def forward(self, input, pool_size=(2, 2), pool_type="avg"):
        """Apply the conv stage and pool the result.

        Arguments
        ---------
        input : torch.Tensor
            Feature map fed to the convolution.
        pool_size : tuple of int
            Kernel size of the pooling window.
        pool_type : str
            ``"avg"``, ``"max"`` or ``"avg+max"`` (sum of both poolings).

        Raises
        ------
        Exception
            If ``pool_type`` is none of the supported values. The check runs
            after the convolution, as in the original flow.
        """
        feats = F.relu_(self.bn1(self.conv1(input)))

        if pool_type == "max":
            return F.max_pool2d(feats, kernel_size=pool_size)
        if pool_type == "avg":
            return F.avg_pool2d(feats, kernel_size=pool_size)
        if pool_type == "avg+max":
            averaged = F.avg_pool2d(feats, kernel_size=pool_size)
            peaks = F.max_pool2d(feats, kernel_size=pool_size)
            return averaged + peaks

        raise Exception("Incorrect argument!")
class AttBlock(nn.Module):
    """Attention pooling over the last axis of a ``(batch, n_in, n_time)`` input.

    Two parallel 1x1 ``Conv1d`` layers produce attention logits and per-step
    outputs; the outputs are summed over time, weighted by the softmax of the
    (clamped) logits.

    Arguments
    ---------
    n_in : int
        Input channels.
    n_out : int
        Output channels.
    activation : str
        ``"linear"`` (identity) or ``"sigmoid"``, applied to the per-step
        outputs.
    temperature : float
        Stored on the instance; not used by ``forward`` in this class.
    """

    def __init__(self, n_in, n_out, activation="linear", temperature=1.0):
        super(AttBlock, self).__init__()

        self.activation = activation
        self.temperature = temperature

        self.att = nn.Conv1d(
            in_channels=n_in,
            out_channels=n_out,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )
        self.cla = nn.Conv1d(
            in_channels=n_in,
            out_channels=n_out,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )
        # Registered but not applied in forward(); kept so the module's
        # state-dict keys stay unchanged.
        self.bn_att = nn.BatchNorm1d(n_out)

    def forward(self, x):
        """Pool ``x`` over time.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of shape ``(n_samples, n_in, n_time)``.

        Returns
        -------
        tuple
            ``(pooled, norm_att, cla)``: the attention-weighted sum over the
            time axis, the attention weights, and the per-step outputs.
        """
        # Clamp the logits before the softmax to keep the exponentials bounded.
        norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
        cla = self.nonlinear_transform(self.cla(x))
        x = torch.sum(norm_att * cla, dim=2)
        return x, norm_att, cla

    def nonlinear_transform(self, x):
        """Apply the configured activation to ``x``.

        Raises
        ------
        ValueError
            If ``self.activation`` is neither ``"linear"`` nor ``"sigmoid"``.
            Raising here avoids silently returning ``None``.
        """
        if self.activation == "linear":
            return x
        if self.activation == "sigmoid":
            return torch.sigmoid(x)
        raise ValueError(
            f"Unsupported activation {self.activation!r}; "
            "expected 'linear' or 'sigmoid'"
        )
class Cnn14(nn.Module):
    """Six-stage ``ConvBlock`` network with an embedding and a classifier head.

    Arguments
    ---------
    classes_num : int
        Output size of the classification layer.
    out_emb : int
        Size of the embedding produced by ``fc1`` (2048 for the best Cnn14).
    """

    def __init__(
        self,
        classes_num,
        out_emb,
    ):
        super().__init__()

        self.bn0 = nn.BatchNorm2d(64)

        # conv_block1 .. conv_block6, each one widening the channel count.
        widths = (1, 64, 128, 256, 512, 1024, 2048)
        for idx, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(
                self,
                f"conv_block{idx}",
                ConvBlock(in_channels=c_in, out_channels=c_out),
            )

        # out_emb is 2048 for best Cnn14
        self.fc1 = nn.Linear(2048, out_emb, bias=True)
        self.fc_audioset = nn.Linear(out_emb, classes_num, bias=True)

    def forward(self, x, mixup_lambda=None):
        """Compute clip-level outputs and intermediate feature maps.

        Arguments
        ---------
        x : torch.Tensor
            A 3-D input gets a singleton axis inserted at position 1, giving
            (batch_size, 1, time_steps, mel_bins) per the original comment.
            The last axis must have size 64 to match ``bn0``.
        mixup_lambda : optional
            Accepted for interface compatibility; not used here.

        Returns
        -------
        dict
            ``"clipwise_output"``: sigmoid of the classifier logits.
            ``"embedding"``: tuple of the (dropout-regularised) embedding
            followed by the outputs of conv blocks 6, 5, 4 and 3, in that
            order.
        """
        feats = x.unsqueeze(1) if x.dim() == 3 else x

        # Swap axes 1 and 3 so BatchNorm2d normalises over the last axis.
        feats = self.bn0(feats.transpose(1, 3)).transpose(1, 3)

        stages = (
            (self.conv_block1, (2, 2)),
            (self.conv_block2, (2, 2)),
            (self.conv_block3, (2, 2)),
            (self.conv_block4, (2, 2)),
            (self.conv_block5, (2, 2)),
            (self.conv_block6, (1, 1)),
        )
        taps = []
        for block, pool in stages:
            stage_out = block(feats, pool_size=pool, pool_type="avg")
            taps.append(stage_out)
            feats = F.dropout(stage_out, p=0.2, training=self.training)

        # Collapse the last axis, then combine max and mean over axis 2.
        pooled = torch.mean(feats, dim=3)
        peak = torch.max(pooled, dim=2)[0]
        average = torch.mean(pooled, dim=2)
        clip_vec = F.dropout(peak + average, p=0.5, training=self.training)

        hidden = F.relu_(self.fc1(clip_vec))
        embedding = F.dropout(hidden, p=0.5, training=self.training)
        # The classifier reads the pre-dropout activations.
        clipwise_output = torch.sigmoid(self.fc_audioset(hidden))

        return {
            "clipwise_output": clipwise_output,
            "embedding": (embedding, taps[5], taps[4], taps[3], taps[2]),
        }
class AudioEncoder(nn.Module):
    """Audio backbone followed by a ``Projection`` head.

    Arguments
    ---------
    audioenc_name : str
        Encoder name resolved through ``get_audio_encoder``.
    d_in : int
        Embedding size produced by the backbone (input of the projection).
    d_out : int
        Size of the projected embedding.
    classes_num : int
        Number of classes; only passed on when the backbone is not a PhiNet.
    """

    def __init__(
        self,
        audioenc_name: str,
        d_in: int,
        d_out: int,
        classes_num: int,
    ) -> None:
        super().__init__()

        encoder = get_audio_encoder(audioenc_name)
        if "phinet" in audioenc_name:
            # get_audio_encoder already returned a built PhiNet instance.
            self.base = encoder
        else:
            # Otherwise a class came back and still has to be instantiated.
            self.base = encoder(
                classes_num,
                d_in,
            )

        self.projection = Projection(d_in, d_out)

    def forward(self, x):
        """Encode ``x``.

        Returns
        -------
        tuple
            ``(projected, extra_features, classification_output)`` where
            ``projected`` is the projection of the first ``"embedding"``
            entry, ``extra_features`` holds the remaining ``"embedding"``
            entries, and ``classification_output`` is the backbone's
            ``"clipwise_output"``.
        """
        outputs = self.base(x)
        features = outputs["embedding"]
        classification_output = outputs["clipwise_output"]

        projected = self.projection(features[0])
        return projected, features[1:], classification_output
class TextEncoder(nn.Module):
    """Pretrained transformer text model followed by a ``Projection`` head.

    Arguments
    ---------
    d_out : int
        Size of the projected embedding.
    text_model : str
        Identifier passed to ``AutoModel.from_pretrained``.
    transformer_embed_dim : int
        Hidden size of the text model (input of the projection).
    """

    def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None:
        super().__init__()
        self.base = AutoModel.from_pretrained(text_model)
        self.projection = Projection(transformer_embed_dim, d_out)

    def forward(self, x):
        """Project the first-token representation of the tokenised input ``x``.

        Side effect: a detached copy of the full first model output is stored
        in ``self.hidden_state`` after the projection has been computed.
        """
        token_states = self.base(**x)[0]

        # Index 0 along axis 1 is the CLS token output.
        cls_state = token_states[:, 0, :]
        projected_vec = self.projection(cls_state)

        self.hidden_state = token_states.detach()
        return projected_vec
class CLAP(nn.Module):
    """Audio/text dual encoder producing L2-normalised embeddings.

    Arguments
    ---------
    audioenc_name : str
        Audio encoder name, forwarded to ``AudioEncoder``.
    classes_num : int
        Number of classes of the audio encoder's classification head.
    out_emb : int
        Embedding size produced by the audio backbone.
    text_model : str
        Identifier of the pretrained text model, forwarded to ``TextEncoder``.
    transformer_embed_dim : int
        Hidden size of the text model.
    d_proj : int
        Size of the shared projection space.
    pretrained_weights : bool
        When True, fetch ``CLAP_weights`` and load it into the model. When
        False, the model keeps its freshly-initialised parameters (the text
        backbone is still loaded by ``TextEncoder``).
    CLAP_weights : str
        Location of the checkpoint as ``<source>/<filename>``. Required when
        ``pretrained_weights`` is True.
    audioenc_name_student, out_emb_student
        Accepted for interface compatibility; not used in this class.

    Raises
    ------
    ValueError
        If ``pretrained_weights`` is True but ``CLAP_weights`` is None.
    """

    def __init__(
        self,
        # audio
        audioenc_name: str,
        classes_num: int,
        out_emb: int,
        # text
        text_model: str,
        transformer_embed_dim: int,
        # common
        d_proj: int,
        pretrained_weights: bool = True,
        CLAP_weights: str = None,
        # audio student
        audioenc_name_student=None,
        out_emb_student=None,
    ):
        super().__init__()

        # Fail early with a clear message, before any model is built.
        if pretrained_weights and CLAP_weights is None:
            raise ValueError(
                "pretrained_weights=True requires CLAP_weights to point to a "
                "checkpoint; pass pretrained_weights=False to skip loading."
            )

        ckpt_path = None
        if pretrained_weights:
            ckpt_path = "CLAP_weights.pth"
            # Everything before the last "/" is the source, the rest the file.
            source, _, filename = CLAP_weights.rpartition("/")
            print(
                "Fetching CLAP weights. The checkpoint is a ~2GB, so be "
                "patient. The process will start right after."
            )
            fetch(
                filename,
                source,
                savedir=".",
                save_filename=ckpt_path,
            )

        self.audio_encoder = AudioEncoder(
            audioenc_name,
            out_emb,
            d_proj,
            classes_num,
        )

        self.caption_encoder = TextEncoder(d_proj, text_model, transformer_embed_dim)

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        # Only load when a checkpoint was actually fetched; previously this
        # ran unconditionally and crashed on ckpt_path=None.
        if ckpt_path is not None:
            # NOTE(security): torch.load unpickles the file; only use
            # checkpoints from a trusted source.
            state_dict = torch.load(ckpt_path, map_location="cpu")["model"]
            self.load_state_dict(self.clean_state_dict(state_dict))
            print("Loaded pretrained CLAP checkpoint.")

    @staticmethod
    def clean_state_dict(state_dict):
        """Remove pre-processing keys from the state-dict.

        Every key containing ``"spectrogram"`` or ``"mel"`` is dropped. The
        dictionary is modified in place and also returned.
        """
        keys_to_remove = [
            k for k in state_dict if "spectrogram" in k or "mel" in k
        ]
        for k in keys_to_remove:
            state_dict.pop(k, None)
        return state_dict

    def forward(self, audio, input_ids, token_type_ids, attention_mask, single=None):
        """Embed audio and/or text.

        Arguments
        ---------
        audio : torch.Tensor
            Input of the audio encoder; ignored when ``single == "txt"``.
        input_ids, token_type_ids, attention_mask
            Tokenised text fields; ignored when ``single == "aud"``.
        single : str or None
            ``"txt"`` to compute only the text embedding, ``"aud"`` to compute
            only the audio embedding, anything else to compute both.

        Returns
        -------
        tuple
            ``(caption_embed, audio_embed)``; each is normalised to unit L2
            norm along dim 1, or ``None`` when that branch was skipped.
        """
        audio_embed = None
        caption_embed = None

        if single != "txt":
            audio_embed, _, _ = self.audio_encoder(audio)
            audio_embed = audio_embed / audio_embed.norm(dim=1, keepdim=True)

        if single != "aud":
            text = BatchEncoding(
                {
                    "input_ids": input_ids,
                    "token_type_ids": token_type_ids,
                    "attention_mask": attention_mask,
                }
            )
            caption_embed = self.caption_encoder(text)
            caption_embed = caption_embed / caption_embed.norm(dim=1, keepdim=True)

        return caption_embed, audio_embed
|