#!/usr/bin/env python3
"""
Generate audio using JAM model

Reads from filtered test set and generates audio using CFM+DiT model.
"""

import os
import glob
import time
import json
import random
import sys

from huggingface_hub import snapshot_download
import torch
import torchaudio
from omegaconf import OmegaConf
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
import accelerate
import pyloudnorm as pyln
from safetensors.torch import load_file
from muq import MuQMuLan
import numpy as np
from accelerate import Accelerator

from jam.dataset import enhance_webdataset_config, DiffusionWebDataset
from jam.model.vae import StableAudioOpenVAE, DiffRhythmVAE

# DiffRhythm imports for CFM+DiT model
from jam.model import CFM, DiT


def get_negative_style_prompt(device, file_path):
    """Load a negative style embedding from a ``.npy`` file.

    Args:
        device: Device the embedding is moved to.
        file_path: Path to a NumPy file holding the embedding (the original
            note says ``[1, 512]``; the shape is not checked here).

    Returns:
        The embedding as a float16 tensor on ``device``.
    """
    vocal_style = np.load(file_path)
    vocal_style = torch.from_numpy(vocal_style).to(device)  # [1, 512]
    return vocal_style.half()


def normalize_audio(audio, normalize_lufs=True, sample_rate=44100, target_lufs=-14.0):
    """Remove DC offset, peak-normalize, and optionally loudness-normalize audio.

    Args:
        audio: CPU tensor laid out as ``(channels, samples)``; it is transposed
            to ``(samples, channels)`` before being handed to pyloudnorm.
        normalize_lufs: When True, additionally scale the signal to
            ``target_lufs`` integrated loudness.
        sample_rate: Sample rate of ``audio``, used by the loudness meter.
        target_lufs: Target integrated loudness in LUFS.

    Returns:
        The normalized tensor in the same ``(channels, samples)`` layout.
    """
    # Remove the per-channel DC offset, then peak-normalize each channel.
    audio = audio - audio.mean(-1, keepdim=True)
    audio = audio / (audio.abs().max(-1, keepdim=True).values + 1e-8)
    if not normalize_lufs:
        return audio

    meter = pyln.Meter(rate=sample_rate)
    samples_first = audio.transpose(0, 1).numpy()
    loudness = meter.integrated_loudness(samples_first)
    # Silent audio measures -inf LUFS; the resulting gain would be infinite and
    # turn every sample into inf/NaN, so keep the peak-normalized signal.
    if not np.isfinite(loudness):
        return audio
    normalised = pyln.normalize.loudness(samples_first, loudness, target_lufs)
    return torch.from_numpy(normalised).transpose(0, 1)


class FilteredTestSetDataset(Dataset):
    """Dataset that converts entries of a test-set JSON file into model inputs.

    Each entry of the JSON list must provide ``id``, ``lrc_path``,
    ``audio_path`` and ``duration``; ``prompt_path`` is additionally read when
    ``use_prompt_style`` is True. Every sample is routed through
    ``diffusion_dataset.process_sample_safely`` so it matches the format the
    dataset's collate function expects.
    """

    def __init__(self, test_set_path, diffusion_dataset, muq_model, num_samples=None,
                 random_crop_style=False, num_style_secs=30, use_prompt_style=False):
        """
        Args:
            test_set_path: Path to the JSON list of test samples.
            diffusion_dataset: Object providing ``process_sample_safely``.
            muq_model: MuQ-MuLan model used to compute style embeddings.
            num_samples: If given, only the first ``num_samples`` entries are used.
            random_crop_style: Pick a random audio window for the style
                embedding instead of the first ``num_style_secs`` seconds.
            num_style_secs: Length in seconds of the audio window fed to MuQ.
            use_prompt_style: Embed the text prompt instead of the audio.
        """
        with open(test_set_path, 'r') as f:
            self.test_samples = json.load(f)
        if num_samples is not None:
            self.test_samples = self.test_samples[:num_samples]
        self.diffusion_dataset = diffusion_dataset
        self.muq_model = muq_model
        self.random_crop_style = random_crop_style
        self.num_style_secs = num_style_secs
        self.use_prompt_style = use_prompt_style
        if self.use_prompt_style:
            print("Using prompt style instead of audio style.")

    def __len__(self):
        return len(self.test_samples)

    def __getitem__(self, idx):
        """Return the processed sample (with ``test_metadata``) or None if processing failed."""
        test_sample = self.test_samples[idx]
        sample_id = test_sample["id"]

        # Load LRC data; anything without a top-level 'word' key is wrapped
        # into {'word': ...}.
        lrc_path = test_sample["lrc_path"]
        with open(lrc_path, 'r') as f:
            lrc_data = json.load(f)
        if 'word' not in lrc_data:
            lrc_data = {'word': lrc_data}

        # Style embedding comes either from the text prompt or the reference audio.
        audio_path = test_sample["audio_path"]
        if self.use_prompt_style:
            style_embedding = self._prompt_style_embedding(sample_id, test_sample["prompt_path"])
        else:
            style_embedding = self.generate_style_embedding(audio_path)

        duration = test_sample["duration"]

        # Placeholder latent whose only job is to carry the right length.
        # Assuming frame_rate from config (typically 21.5 fps for 44.1kHz).
        frame_rate = 21.5
        num_frames = int(duration * frame_rate)
        fake_latent = torch.randn(128, num_frames)  # 128 is latent dim

        # Tuple layout matching DiffusionWebDataset samples:
        # (id, latent, style embedding, LRC data).
        fake_sample = (sample_id, fake_latent, style_embedding, lrc_data)
        processed_sample = self.diffusion_dataset.process_sample_safely(fake_sample)

        if processed_sample is not None:
            processed_sample['test_metadata'] = {
                'sample_id': sample_id,
                'audio_path': audio_path,
                'lrc_path': lrc_path,
                'duration': duration,
                'num_frames': num_frames,
            }
        return processed_sample

    def _prompt_style_embedding(self, sample_id, prompt_path):
        """Embed the text prompt stored at ``prompt_path`` (truncated to 300 characters)."""
        with open(prompt_path, 'r') as f:
            prompt = f.read()
        if len(prompt) > 300:
            print(f"Sample {sample_id} has prompt length {len(prompt)}")
            prompt = prompt[:300]
        print(prompt)
        # Inference only: do not build an autograd graph (matches the audio path).
        with torch.inference_mode():
            return self.muq_model(texts=[prompt]).squeeze(0)

    def generate_style_embedding(self, audio_path):
        """Compute a MuQ style embedding from a window of the reference audio.

        The audio is resampled to 24 kHz and down-mixed to mono before at most
        ``num_style_secs`` seconds are passed to the model.

        Returns:
            The first row of the model output (batch dimension removed).
        """
        waveform, sample_rate = torchaudio.load(audio_path)

        # Resample to 24kHz if needed (MuQ expects 24kHz).
        if sample_rate != 24000:
            waveform = torchaudio.transforms.Resample(sample_rate, 24000)(waveform)

        # Down-mix to mono, then drop the channel axis: (time,).
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform = waveform.squeeze(0).to(self.muq_model.device)

        target_samples = int(24000 * self.num_style_secs)
        start_idx = 0
        if self.random_crop_style:
            # Clips shorter than the window always start at 0; without the clamp
            # random.randint would raise ValueError on a negative upper bound.
            start_idx = random.randint(0, max(0, waveform.shape[0] - target_samples))

        with torch.inference_mode():
            window = waveform.unsqueeze(0)[..., start_idx:start_idx + target_samples]
            style_embedding = self.muq_model(wavs=window)
        # Keep shape as (embedding_dim,) not scalar.
        return style_embedding[0]


def custom_collate_fn_with_metadata(batch, base_collate_fn):
    """Collate a batch with ``base_collate_fn`` while preserving ``test_metadata``.

    ``None`` samples are dropped. The ``test_metadata`` entries are popped from
    each sample (mutating them) before collation and re-attached afterwards as
    a list in batch order.

    Returns:
        The collated batch, or None if no samples remain or
        ``base_collate_fn`` returns None.
    """
    batch = [item for item in batch if item is not None]
    if not batch:
        return None

    test_metadata = [item.pop('test_metadata') for item in batch]
    collated = base_collate_fn(batch)
    if collated is not None:
        collated['test_metadata'] = test_metadata
    return collated


def load_model(model_config, checkpoint_path, device):
    """Build the CFM+DiT model from config and load weights (follows infer.py pattern).

    Args:
        model_config: Config with ``dit`` and ``cfm`` sections.
        checkpoint_path: Path to a safetensors checkpoint.
        device: Device the model is moved to.

    Returns:
        The model in eval mode.
    """
    dit_config = model_config["dit"].copy()
    # Add text_num_embeds if not specified - should be at least 64 for phoneme tokens.
    if "text_num_embeds" not in dit_config:
        dit_config["text_num_embeds"] = 256  # Default value from DiT

    cfm = CFM(transformer=DiT(**dit_config), **model_config["cfm"])
    cfm = cfm.to(device)

    # strict=False: missing/unexpected keys are ignored silently.
    checkpoint = load_file(checkpoint_path)
    cfm.load_state_dict(checkpoint, strict=False)
    return cfm.eval()


def generate_latent(model, batch, sample_kwargs, negative_style_prompt_path=None, ignore_style=False, device='cuda'):
    """Sample latents from the CFM model for a collated batch (follows infer.py pattern).

    Args:
        model: CFM model exposing ``max_frames`` and ``sample``.
        batch: Collated batch with ``lrc``, ``prompt``, ``start_time``,
            ``duration_abs`` and ``duration_rel`` tensors.
        sample_kwargs: Overrides for the sampling defaults
            (``cfg_strength=4``, ``steps=50``, ``batch_infer_num=1``); may be None.
        negative_style_prompt_path: None for ``public_checkpoints/vocal.npy``,
            the string ``'zeros'`` for an all-zero embedding, or a ``.npy`` path.
        ignore_style: When True, the negative style embedding is also used as
            the positive style prompt.
        device: Device the batch tensors are moved to.

    Returns:
        The first value returned by ``model.sample``.
    """
    with torch.inference_mode():
        batch_size = len(batch["lrc"])
        text = batch["lrc"].to(device)
        style_prompt = batch["prompt"].to(device)
        start_time = batch["start_time"].to(device)
        duration_abs = batch["duration_abs"].to(device)
        duration_rel = batch["duration_rel"].to(device)

        # Zero conditioning latent covering the full window; the whole window
        # is marked as the segment to predict.
        max_frames = model.max_frames
        cond = torch.zeros(batch_size, max_frames, 64).to(text.device)
        pred_frames = [(0, max_frames)]

        default_sample_kwargs = {
            "cfg_strength": 4,
            "steps": 50,
            "batch_infer_num": 1,
        }
        sample_kwargs = {**default_sample_kwargs, **(sample_kwargs or {})}

        if negative_style_prompt_path == 'zeros':
            # NOTE(review): float32 here, while the .npy branch yields float16 —
            # confirm model.sample accepts both dtypes.
            negative_style_prompt = torch.zeros(1, 512).to(text.device)
        else:
            if negative_style_prompt_path is None:
                negative_style_prompt_path = 'public_checkpoints/vocal.npy'
            negative_style_prompt = get_negative_style_prompt(text.device, negative_style_prompt_path)
        negative_style_prompt = negative_style_prompt.repeat(batch_size, 1)

        latents, _ = model.sample(
            cond=cond,
            text=text,
            style_prompt=negative_style_prompt if ignore_style else style_prompt,
            duration_abs=duration_abs,
            duration_rel=duration_rel,
            negative_style_prompt=negative_style_prompt,
            start_time=start_time,
            latent_pred_segments=pred_frames,
            **sample_kwargs,
        )
        return latents


class Jamify:
    """End-to-end JAM generator: lyrics + style reference in, MP3 file out.

    Results and temporary per-request JSON files live in ``outputs/``.
    """

    def __init__(self, device='cuda', config_path='jam_infer.yaml'):
        """Load config, download the checkpoint, and build VAE, CFM, MuQ and dataset.

        Args:
            device: Torch device for all models.
            config_path: OmegaConf YAML with ``model``, ``data`` and
                ``evaluation`` sections.
        """
        os.makedirs('outputs', exist_ok=True)
        self.device = device

        self.config = OmegaConf.load(config_path)
        OmegaConf.resolve(self.config)

        print("Downloading main model checkpoint...")
        model_repo_path = snapshot_download(repo_id="declare-lab/jam-0.5")
        self.config.evaluation.checkpoint_path = os.path.join(model_repo_path, "jam-0_5.safetensors")

        # Load VAE based on configuration.
        vae_type = self.config.evaluation.get('vae_type', 'stable_audio')
        if vae_type == 'diffrhythm':
            vae = DiffRhythmVAE(device=device).to(device)
        else:
            vae = StableAudioOpenVAE().to(device)
        self.vae = vae
        self.vae_type = vae_type

        self.cfm_model = load_model(self.config.model, self.config.evaluation.checkpoint_path, device)
        self.muq_model = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large").to(device).eval()

        dataset_cfg = OmegaConf.merge(self.config.data.train_dataset, self.config.evaluation.dataset)
        enhance_webdataset_config(dataset_cfg)
        # Override multiple_styles to False since we're generating single style embeddings.
        dataset_cfg.multiple_styles = False
        self.base_dataset = DiffusionWebDataset(**dataset_cfg)

    def cleanup_old_files(self, sample_id, max_files=9):
        """Prune old MP3 outputs and delete this request's temporary JSON file.

        Args:
            sample_id: Id whose ``outputs/<sample_id>.json`` file is removed.
            max_files: Number of MP3 files (by sorted name) to keep.
        """
        # Ids from predict() are microsecond timestamps with equal digit
        # counts, so lexicographic name order matches creation order.
        old_mp3_files = sorted(glob.glob("outputs/*.mp3"))
        excess = len(old_mp3_files) - max_files
        for old_file in old_mp3_files[:max(excess, 0)]:
            try:
                os.remove(old_file)
                print(f"Cleaned up old file: {old_file}")
            except OSError:
                pass  # best effort: a concurrent request may have removed it
        try:
            os.unlink(f"outputs/{sample_id}.json")
        except FileNotFoundError:
            pass

    def predict(self, reference_audio_path, lyrics_json_path, style_prompt, duration):
        """Generate a song and save it as ``outputs/<sample_id>.mp3``.

        Args:
            reference_audio_path: Audio used for the MuQ style embedding
                (unless the config enables ``use_prompt_style``).
            lyrics_json_path: JSON file with the lyric (LRC) data.
            style_prompt: Path to a text file holding the style prompt; it is
                opened only when ``use_prompt_style`` is enabled.
            duration: Target duration in seconds; the output is trimmed to it.

        Returns:
            Path of the written MP3 file.

        Raises:
            RuntimeError: If the dataset could not turn the inputs into a batch.
        """
        sample_id = str(int(time.time() * 1000000))  # microsecond timestamp for uniqueness
        test_set_path = f"outputs/{sample_id}.json"
        test_set = [{
            "id": sample_id,
            "audio_path": reference_audio_path,
            "lrc_path": lyrics_json_path,
            "duration": duration,
            "prompt_path": style_prompt,
        }]
        # Close (and flush) the file before the dataset reads it back.
        with open(test_set_path, "w") as f:
            json.dump(test_set, f)

        try:
            batch = self._load_batch(test_set_path)
            if batch is None:
                raise RuntimeError(f"Sample {sample_id} could not be processed into a batch")

            evaluation = self.config.evaluation
            latent = generate_latent(
                self.cfm_model,
                batch,
                evaluation.sample_kwargs,
                evaluation.negative_style_prompt,
                evaluation.ignore_style,
                device=self.device,
            )[0][0]
            original_duration = batch['test_metadata'][0]['duration']

            pred_audio = normalize_audio(self._decode_latent(latent))

            # Trim to the requested duration (slicing is a no-op if shorter).
            sample_rate = 44100
            trim_samples = int(original_duration * sample_rate)
            pred_audio_trimmed = pred_audio[:, :trim_samples]

            output_path = f'outputs/{sample_id}.mp3'
            torchaudio.save(output_path, pred_audio_trimmed, sample_rate, format="mp3")
        finally:
            # Always remove the temporary JSON, even if generation failed.
            self.cleanup_old_files(sample_id)
        return output_path

    def _load_batch(self, test_set_path):
        """Build the single-sample dataset for ``test_set_path``; return its collated batch or None."""
        evaluation = self.config.evaluation
        test_dataset = FilteredTestSetDataset(
            test_set_path=test_set_path,
            diffusion_dataset=self.base_dataset,
            muq_model=self.muq_model,
            num_samples=1,
            random_crop_style=evaluation.random_crop_style,
            num_style_secs=evaluation.num_style_secs,
            use_prompt_style=evaluation.use_prompt_style,
        )
        dataloader = DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=lambda batch: custom_collate_fn_with_metadata(batch, self.base_dataset.custom_collate_fn),
        )
        return next(iter(dataloader))

    def _decode_latent(self, latent):
        """Decode a 2-D latent (transposed and batched first) to a CPU audio tensor."""
        latent_for_vae = latent.transpose(0, 1).unsqueeze(0)
        evaluation = self.config.evaluation
        # Use chunked decoding if configured (only for DiffRhythm VAE).
        use_chunked = evaluation.get('use_chunked_decoding', True)
        # Pure inference: no autograd graph is needed for decoding.
        with torch.no_grad():
            if self.vae_type == 'diffrhythm' and use_chunked:
                decoded = self.vae.decode(
                    latent_for_vae,
                    chunked=True,
                    overlap=evaluation.get('chunked_overlap', 32),
                    chunk_size=evaluation.get('chunked_size', 128),
                )
            else:
                decoded = self.vae.decode(latent_for_vae)
        return decoded.sample.squeeze(0).detach().cpu()