import argparse
import os
import sys

import h5py
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm

from diffusion import create_diffusion
from models import DiT

# The BigVGAN repo is vendored under ./tools; put it on the import path
# before importing it.
sys.path.append('./tools/bigvgan_v2_22khz_80band_256x')
from bigvgan import BigVGAN

device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
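
# BigVGAN (22.05 kHz, 80 mel bands, 256x upsampling) is the vocoder that
# turns mel spectrograms back into waveforms. It is instantiated once at
# module level and shared by the ground-truth and generated branches.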
class MelToAudio_bigvgan(nn.Module):
    """Thin wrapper around a pretrained BigVGAN vocoder."""

    def __init__(self):
        super().__init__()
        self.vocoder = BigVGAN.from_pretrained(
            '/home/zheqid/workspace/music_dit/bigvgan_v2_22khz_80band_256x',
            use_cuda_kernel=False,
        )
        # Weight norm is only needed during training; strip it for inference.
        self.vocoder.remove_weight_norm()

    def __call__(self, z):
        return self.mel_to_audio(z)

    def mel_to_audio(self, x):
        with torch.no_grad():
            self.vocoder.eval()
            y = self.vocoder(x)  # (batch, 1, samples)
            y = y.squeeze(0)
        return y


vocoder = MelToAudio_bigvgan().to(device)
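
# The DiT hyperparameters below must mirror the configuration the checkpoint
# was trained with, or load_state_dict will fail with shape mismatches.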
def load_trained_model(checkpoint_path):
    """Build the DiT and restore its weights from a training checkpoint."""
    model = DiT(
        input_size=(80, 800),  # 80 mel bins x 800 frames
        patch_size=8,
        in_channels=1,
        hidden_size=384,
        depth=12,
        num_heads=6,
    )
    model.to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model
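
# The H5 file is expected to hold one group per clip, each containing a
# 'meta' dataset (the conditioning latent) and a 'mel' dataset (the log-mel).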
def load_all_meta_and_mel_from_h5(h5_file):
    """Yield (key, meta latent, mel) for every clip in the H5 file."""
    with h5py.File(h5_file, 'r') as f:
        for key in f.keys():
            meta_latent = torch.FloatTensor(f[key]['meta'][:]).to(device)
            mel = torch.FloatTensor(f[key]['mel'][:]).to(device)
            yield key, meta_latent, mel
def extract_random_mel_segment(mel, segment_length=800):
    """Crop a random fixed-length window from the mel, zero-padding if short."""
    total_length = mel.shape[2]
    if total_length > segment_length:
        start = np.random.randint(0, total_length - segment_length)
        mel_segment = mel[:, :, start:start + segment_length]
    else:
        padding = segment_length - total_length
        mel_segment = F.pad(mel, (0, padding), mode='constant', value=0)

    # Map log-mel values from roughly [-10, 10] into [0, 1], the scale the
    # DiT operates in; save_audio inverts this before vocoding.
    mel_segment = (mel_segment + 10) / 20
    return mel_segment
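
# With timestep_respacing="" (set in main), the schedule is not respaced,
# so p_sample_loop denoises through the full set of training timesteps.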
def infer_and_generate_audio(model, diffusion, meta_latent):
    """Sample a normalized mel from noise, conditioned on the meta latent."""
    latent_size = (80, 800)
    z = torch.randn(1, 1, latent_size[0], latent_size[1], device=device)
    model_kwargs = dict(y=meta_latent)

    with torch.no_grad():
        samples = diffusion.p_sample_loop(
            model.forward, z.shape, z,
            clip_denoised=False, model_kwargs=model_kwargs,
            progress=True, device=device,
        )
    return samples
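
# For each clip two wavs are written: a vocoded crop of the reference mel
# (ground truth) and a mel sampled by the DiT (generation), so the two
# output directories line up one-to-one by key.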
def save_audio(mel, vocoder, output_path, sample_rate=22050):
    """Denormalize a mel, vocode it with BigVGAN, and write a wav file."""
    with torch.no_grad():
        # Accept (1, 1, n_mels, frames) from the sampler or (1, n_mels, frames)
        # from the H5 crops, reducing either to (n_mels, frames).
        if mel.dim() == 4 and mel.shape[1] == 1:
            mel = mel[0, 0, :, :]
        elif mel.dim() == 3 and mel.shape[0] == 1:
            mel = mel[0]
        else:
            raise ValueError(f"Unexpected mel shape: {mel.shape}")

        mel = mel.unsqueeze(0)
        # Invert the (mel + 10) / 20 normalization before vocoding.
        wav = vocoder(mel * 20 - 10).cpu().numpy()

    sf.write(output_path, wav[0], samplerate=sample_rate)
    print(f"Saved audio to: {output_path}")
def main():
    parser = argparse.ArgumentParser(description='Generate audio using DiT and BigVGAN')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
    parser.add_argument('--h5_file', type=str, required=True, help='Path to input H5 file')
    parser.add_argument('--output_gt_dir', type=str, required=True, help='Directory to save ground truth audio')
    parser.add_argument('--output_gen_dir', type=str, required=True, help='Directory to save generated audio')
    parser.add_argument('--segment_length', type=int, default=800, help='Segment length for mel slices (default: 800)')
    parser.add_argument('--sample_rate', type=int, default=22050, help='Sample rate for output audio (default: 22050)')
    args = parser.parse_args()

    os.makedirs(args.output_gt_dir, exist_ok=True)
    os.makedirs(args.output_gen_dir, exist_ok=True)

    model = load_trained_model(args.checkpoint)
    diffusion = create_diffusion(timestep_respacing="")

    for key, meta_latent, mel in tqdm(load_all_meta_and_mel_from_h5(args.h5_file)):
        # Ground truth: vocode a random crop of the reference mel.
        mel_segment = extract_random_mel_segment(mel, segment_length=args.segment_length)
        ground_truth_wav_path = os.path.join(args.output_gt_dir, f"{key}.wav")
        save_audio(mel_segment, vocoder, ground_truth_wav_path, sample_rate=args.sample_rate)

        # Generation: sample a mel conditioned on the clip's meta latent.
        generated_mel = infer_and_generate_audio(model, diffusion, meta_latent)
        output_wav_path = os.path.join(args.output_gen_dir, f"{key}.wav")
        save_audio(generated_mel, vocoder, output_wav_path, sample_rate=args.sample_rate)


if __name__ == "__main__":
    main()

'''
Example:
python sample.py --checkpoint ./gtzan-ck/model_epoch_20000.pt \
    --h5_file ./dataset/gtzan_test.h5 \
    --output_gt_dir ./sample/gt \
    --output_gen_dir ./sample/gn \
    --segment_length 800 \
    --sample_rate 22050
'''