Spaces:
Running
on
Zero
Running
on
Zero
#!/usr/bin/env python3 | |
""" | |
Generate audio using JAM model | |
Reads from filtered test set and generates audio using CFM+DiT model. | |
""" | |
import os | |
import glob | |
import time | |
import json | |
import random | |
import sys | |
from huggingface_hub import snapshot_download | |
import torch | |
import torchaudio | |
from omegaconf import OmegaConf | |
from torch.utils.data import DataLoader, Dataset | |
from tqdm.auto import tqdm | |
import accelerate | |
import pyloudnorm as pyln | |
from safetensors.torch import load_file | |
from muq import MuQMuLan | |
import numpy as np | |
from accelerate import Accelerator | |
from jam.dataset import enhance_webdataset_config, DiffusionWebDataset | |
from jam.model.vae import StableAudioOpenVAE, DiffRhythmVAE | |
# DiffRhythm imports for CFM+DiT model | |
from jam.model import CFM, DiT | |
def get_negative_style_prompt(device, file_path):
    """Load a precomputed negative style embedding stored as a ``.npy`` file.

    Args:
        device: Torch device (or device string) to place the embedding on.
        file_path: Path to the ``.npy`` file containing the embedding.

    Returns:
        A float16 tensor on ``device`` with the stored array's shape
        (presumably ``[1, 512]`` for the bundled file — TODO confirm).
    """
    style_array = np.load(file_path)
    # Move first, then cast to half precision (same order as before).
    return torch.from_numpy(style_array).to(device).half()
def normalize_audio(audio, normalize_lufs=True, sample_rate=44100, target_lufs=-14.0):
    """Remove DC offset, peak-normalize, and optionally loudness-normalize audio.

    Args:
        audio: Tensor of shape ``(channels, samples)``; reductions run over the
            last axis. Must be a CPU tensor when ``normalize_lufs`` is True
            (it is converted with ``.numpy()``).
        normalize_lufs: If True, also scale to ``target_lufs`` integrated
            loudness with pyloudnorm.
        sample_rate: Sample rate of ``audio`` in Hz, used by the loudness meter.
        target_lufs: Target integrated loudness in LUFS.

    Returns:
        The normalized tensor, shape ``(channels, samples)``. When loudness
        normalization is requested but the measured loudness is not finite
        (e.g. silent audio), the peak-normalized audio is returned unchanged.
    """
    audio = audio - audio.mean(-1, keepdim=True)
    # Per-channel peak normalization; the epsilon avoids dividing by zero on silence.
    audio = audio / (audio.abs().max(-1, keepdim=True).values + 1e-8)
    if not normalize_lufs:
        return audio

    # pyloudnorm works on (samples, channels) arrays.
    samples_first = audio.transpose(0, 1).numpy()
    meter = pyln.Meter(rate=sample_rate)
    loudness = meter.integrated_loudness(samples_first)
    if not np.isfinite(loudness):
        # Fully gated (silent) audio measures -inf LUFS; applying the gain would
        # yield inf/NaN samples, so skip loudness normalization.
        return audio
    normalised = pyln.normalize.loudness(samples_first, loudness, target_lufs)
    return torch.from_numpy(normalised).transpose(0, 1)
class FilteredTestSetDataset(Dataset):
    """Dataset that turns entries of a filtered test-set JSON file into model inputs.

    Each JSON entry must provide ``id``, ``lrc_path``, ``audio_path`` and
    ``duration`` (plus ``prompt_path`` when ``use_prompt_style`` is set). Every
    item is routed through ``diffusion_dataset.process_sample_safely`` so it has
    the same format that dataset produces.
    """

    def __init__(self, test_set_path, diffusion_dataset, muq_model, num_samples=None, random_crop_style=False, num_style_secs=30, use_prompt_style=False):
        """
        Args:
            test_set_path: Path to a JSON file holding a list of sample dicts.
            diffusion_dataset: Object exposing ``process_sample_safely(sample)``.
            muq_model: MuQ-MuLan model used to embed reference audio or text prompts.
            num_samples: If given, keep only the first ``num_samples`` entries.
            random_crop_style: Embed a random window of the reference audio
                instead of its beginning.
            num_style_secs: Length in seconds of the audio window to embed.
            use_prompt_style: Embed the text stored at ``prompt_path`` instead
                of the reference audio.
        """
        with open(test_set_path, 'r') as f:
            self.test_samples = json.load(f)
        if num_samples is not None:
            self.test_samples = self.test_samples[:num_samples]
        self.diffusion_dataset = diffusion_dataset
        self.muq_model = muq_model
        self.random_crop_style = random_crop_style
        self.num_style_secs = num_style_secs
        self.use_prompt_style = use_prompt_style
        if self.use_prompt_style:
            print("Using prompt style instead of audio style.")

    def __len__(self):
        return len(self.test_samples)

    def __getitem__(self, idx):
        """Build one processed sample, or return ``None`` if processing failed.

        A non-None result is the dict from ``process_sample_safely`` plus a
        ``test_metadata`` entry (sample id, paths, duration, latent frame count).
        """
        test_sample = self.test_samples[idx]
        sample_id = test_sample["id"]

        # Load lyrics; wrap them under a 'word' key unless they already have one.
        lrc_path = test_sample["lrc_path"]
        with open(lrc_path, 'r') as f:
            lrc_data = json.load(f)
        if 'word' not in lrc_data:
            lrc_data = {'word': lrc_data}

        audio_path = test_sample["audio_path"]
        if self.use_prompt_style:
            style_embedding = self._embed_prompt(sample_id, test_sample["prompt_path"])
        else:
            style_embedding = self.generate_style_embedding(audio_path)

        # Placeholder latent: random values, sized to the requested duration.
        # NOTE(review): 21.5 frames/s and 128 latent channels are hard-coded —
        # confirm they match the VAE/dataset config.
        duration = test_sample["duration"]
        frame_rate = 21.5
        num_frames = int(duration * frame_rate)
        fake_latent = torch.randn(128, num_frames)

        # Tuple layout expected by DiffusionWebDataset.process_sample_safely.
        fake_sample = (sample_id, fake_latent, style_embedding, lrc_data)
        processed_sample = self.diffusion_dataset.process_sample_safely(fake_sample)

        if processed_sample is not None:
            processed_sample['test_metadata'] = {
                'sample_id': sample_id,
                'audio_path': audio_path,
                'lrc_path': lrc_path,
                'duration': duration,
                'num_frames': num_frames
            }
        return processed_sample

    def _embed_prompt(self, sample_id, prompt_path):
        """Embed the text prompt stored at ``prompt_path``, truncated to 300 characters."""
        with open(prompt_path, 'r') as f:
            prompt = f.read()
        if len(prompt) > 300:
            print(f"Sample {sample_id} has prompt length {len(prompt)}")
            prompt = prompt[:300]
        print(prompt)
        # Inference only: don't build an autograd graph (matches the audio path).
        with torch.inference_mode():
            return self.muq_model(texts=[prompt]).squeeze(0)

    def _select_style_window(self, waveform):
        """Return the ``num_style_secs``-long slice of a 1-D 24 kHz waveform to embed.

        Uses a random window when ``random_crop_style`` is set, otherwise the
        start of the track. Clips shorter than the window are returned whole.
        """
        target_samples = 24000 * self.num_style_secs
        if self.random_crop_style:
            # Clamp so randint stays valid when the clip is shorter than the window.
            start_idx = random.randint(0, max(waveform.shape[0] - target_samples, 0))
        else:
            start_idx = 0
        return waveform[start_idx:start_idx + target_samples]

    def generate_style_embedding(self, audio_path):
        """Embed (a window of) the reference audio with the MuQ model.

        The audio is resampled to 24 kHz and downmixed to mono first. Returns the
        first row of the model output, i.e. the embedding without a batch dim.
        """
        waveform, sample_rate = torchaudio.load(audio_path)
        if sample_rate != 24000:
            waveform = torchaudio.transforms.Resample(sample_rate, 24000)(waveform)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        # (1, time) -> (time,), on the model's device.
        waveform = waveform.squeeze(0).to(self.muq_model.device)

        with torch.inference_mode():
            window = self._select_style_window(waveform)
            # The model is called with a leading batch dimension of 1.
            style_embedding = self.muq_model(wavs=window.unsqueeze(0))
        return style_embedding[0]
def custom_collate_fn_with_metadata(batch, base_collate_fn):
    """Collate ``batch`` with ``base_collate_fn`` while carrying ``test_metadata`` along.

    ``None`` entries (samples the dataset could not process) are dropped. Each
    remaining item's ``test_metadata`` value is popped from the item (mutating
    it) and re-attached to the collated result as a list in batch order.

    Returns:
        The collated dict with a ``test_metadata`` list, or ``None`` when no
        items survive filtering or ``base_collate_fn`` itself returns ``None``.
    """
    kept = list(filter(lambda item: item is not None, batch))
    if len(kept) == 0:
        return None

    metadata = []
    for item in kept:
        metadata.append(item.pop('test_metadata'))

    collated = base_collate_fn(kept)
    if collated is None:
        return None
    collated['test_metadata'] = metadata
    return collated
def load_model(model_config, checkpoint_path, device):
    """Build the JAM CFM+DiT model from config and load weights from a safetensors file.

    Args:
        model_config: Mapping with a ``dit`` section (DiT kwargs) and a ``cfm``
            section (CFM kwargs).
        checkpoint_path: Path to a ``.safetensors`` state dict.
        device: Device the model is moved to.

    Returns:
        The CFM model on ``device``, in eval mode.
    """
    dit_config = model_config["dit"].copy()
    # Fall back to 256 text embeddings when the config does not set it.
    if "text_num_embeds" not in dit_config:
        dit_config["text_num_embeds"] = 256
    cfm = CFM(
        transformer=DiT(**dit_config),
        **model_config["cfm"]
    )
    cfm = cfm.to(device)

    checkpoint = load_file(checkpoint_path)
    # strict=False tolerates key mismatches; report them instead of silently
    # ignoring them, so a wrong or partially-compatible checkpoint is noticed.
    incompatible = cfm.load_state_dict(checkpoint, strict=False)
    if incompatible.missing_keys:
        print(f"Warning: {len(incompatible.missing_keys)} missing keys when loading "
              f"{checkpoint_path}: {incompatible.missing_keys[:10]}")
    if incompatible.unexpected_keys:
        print(f"Warning: {len(incompatible.unexpected_keys)} unexpected keys when loading "
              f"{checkpoint_path}: {incompatible.unexpected_keys[:10]}")
    return cfm.eval()
def generate_latent(model, batch, sample_kwargs, negative_style_prompt_path=None, ignore_style=False, device='cuda'):
    """Sample latents from the CFM model for a collated batch.

    Args:
        model: Unwrapped CFM model exposing ``max_frames`` and ``sample(...)``
            (an accelerate/DDP wrapper is NOT unwrapped here).
        batch: Collated batch with ``lrc``, ``prompt``, ``start_time``,
            ``duration_abs`` and ``duration_rel`` tensors.
        sample_kwargs: Overrides merged over the defaults
            ``cfg_strength=4``, ``steps=50``, ``batch_infer_num=1``.
        negative_style_prompt_path: ``.npy`` file for the negative style
            embedding. ``None`` uses ``public_checkpoints/vocal.npy``;
            ``'zeros'`` uses an all-zero ``(1, 512)`` embedding.
        ignore_style: If True, condition on the negative style prompt instead
            of the batch's style prompt.
        device: Device the batch tensors are moved to.

    Returns:
        The first element of the tuple returned by ``model.sample``.
    """
    with torch.inference_mode():
        text = batch["lrc"].to(device)
        batch_size = len(text)
        style_prompt = batch["prompt"].to(device)
        start_time = batch["start_time"].to(device)
        duration_abs = batch["duration_abs"].to(device)
        duration_rel = batch["duration_rel"].to(device)

        # Zero conditioning latent; the prediction segment spans all max_frames frames.
        max_frames = model.max_frames
        cond = torch.zeros(batch_size, max_frames, 64, device=text.device)
        pred_frames = [(0, max_frames)]

        default_sample_kwargs = {
            "cfg_strength": 4,
            "steps": 50,
            "batch_infer_num": 1
        }
        sample_kwargs = {**default_sample_kwargs, **sample_kwargs}

        if negative_style_prompt_path == 'zeros':
            # NOTE(review): float32 here vs float16 from get_negative_style_prompt —
            # confirm model.sample handles both dtypes.
            negative_style_prompt = torch.zeros(1, 512, device=text.device)
        else:
            if negative_style_prompt_path is None:
                negative_style_prompt_path = 'public_checkpoints/vocal.npy'
            negative_style_prompt = get_negative_style_prompt(text.device, negative_style_prompt_path)
        negative_style_prompt = negative_style_prompt.repeat(batch_size, 1)

        latents, _ = model.sample(
            cond=cond,
            text=text,
            style_prompt=negative_style_prompt if ignore_style else style_prompt,
            duration_abs=duration_abs,
            duration_rel=duration_rel,
            negative_style_prompt=negative_style_prompt,
            start_time=start_time,
            latent_pred_segments=pred_frames,
            **sample_kwargs
        )
    return latents
class Jamify:
    """Loads the JAM CFM model, VAE and MuQ-MuLan once and generates songs on request.

    Generated MP3s and per-request JSON files are written to ``outputs/``.
    """

    def __init__(self):
        """Load config ``jam_infer.yaml``, download the JAM checkpoint and build all models on CUDA."""
        os.makedirs('outputs', exist_ok=True)
        device = 'cuda'
        self.config = OmegaConf.load('jam_infer.yaml')
        OmegaConf.resolve(self.config)

        print("Downloading main model checkpoint...")
        model_repo_path = snapshot_download(repo_id="declare-lab/jam-0.5")
        # Point the evaluation config at the downloaded weights.
        self.config.evaluation.checkpoint_path = os.path.join(model_repo_path, "jam-0_5.safetensors")

        self.vae_type = self.config.evaluation.get('vae_type', 'stable_audio')
        if self.vae_type == 'diffrhythm':
            self.vae = DiffRhythmVAE(device=device).to(device)
        else:
            self.vae = StableAudioOpenVAE().to(device)

        self.cfm_model = load_model(self.config.model, self.config.evaluation.checkpoint_path, device)
        self.muq_model = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large").to(device).eval()

        dataset_cfg = OmegaConf.merge(self.config.data.train_dataset, self.config.evaluation.dataset)
        enhance_webdataset_config(dataset_cfg)
        # One style embedding is generated per sample, so disable multi-style handling.
        dataset_cfg.multiple_styles = False
        self.base_dataset = DiffusionWebDataset(**dataset_cfg)

    def cleanup_old_files(self, sample_id, keep=9):
        """Prune generated MP3s in ``outputs/`` and delete this request's JSON file.

        Keeps the ``keep`` last MP3s in sorted name order (names are microsecond
        timestamps of equal length, so these are the most recent) and removes
        ``outputs/{sample_id}.json`` if it exists.
        """
        old_mp3_files = sorted(glob.glob("outputs/*.mp3"))
        for old_file in old_mp3_files[:max(len(old_mp3_files) - keep, 0)]:
            try:
                os.remove(old_file)
                print(f"Cleaned up old file: {old_file}")
            except OSError:
                pass  # best effort: the file may already be gone
        try:
            os.remove(f"outputs/{sample_id}.json")
        except FileNotFoundError:
            pass

    def _load_batch(self, test_set_path):
        """Build the single-sample collated batch for the test-set JSON at ``test_set_path``."""
        test_dataset = FilteredTestSetDataset(
            test_set_path=test_set_path,
            diffusion_dataset=self.base_dataset,
            muq_model=self.muq_model,
            num_samples=1,
            random_crop_style=self.config.evaluation.random_crop_style,
            num_style_secs=self.config.evaluation.num_style_secs,
            use_prompt_style=self.config.evaluation.use_prompt_style
        )
        dataloader = DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=lambda batch: custom_collate_fn_with_metadata(batch, self.base_dataset.custom_collate_fn)
        )
        batch = next(iter(dataloader))
        if batch is None:
            # The collate function yields None when the sample was dropped
            # (e.g. process_sample_safely returned None).
            raise RuntimeError("Failed to preprocess the lyrics/style inputs for generation.")
        return batch

    def _decode_latent(self, latent):
        """Decode one generated latent to a CPU waveform tensor with the VAE.

        Assumes ``latent`` is 2-D ``(frames, channels)`` — TODO confirm; it is
        transposed and given a batch dim before decoding.
        """
        latent_for_vae = latent.transpose(0, 1).unsqueeze(0)
        # Chunked decoding is only used with the DiffRhythm VAE.
        use_chunked = self.config.evaluation.get('use_chunked_decoding', True)
        if self.vae_type == 'diffrhythm' and use_chunked:
            decoded = self.vae.decode(
                latent_for_vae,
                chunked=True,
                overlap=self.config.evaluation.get('chunked_overlap', 32),
                chunk_size=self.config.evaluation.get('chunked_size', 128)
            )
        else:
            decoded = self.vae.decode(latent_for_vae)
        return decoded.sample.squeeze(0).detach().cpu()

    def predict(self, reference_audio_path, lyrics_json_path, style_prompt, duration):
        """Generate a song and return the path of the written MP3.

        Args:
            reference_audio_path: Reference audio for the style embedding (used
                when the config does not enable prompt style).
            lyrics_json_path: Path to the lyrics JSON.
            style_prompt: Path to a text file with the style prompt (stored as
                ``prompt_path``; used when ``use_prompt_style`` is enabled).
            duration: Target duration in seconds; longer output is trimmed.

        Returns:
            Path of the generated ``outputs/{sample_id}.mp3``.

        Raises:
            RuntimeError: If the dataset could not process the inputs.
        """
        sample_id = str(int(time.time() * 1000000))  # microsecond timestamp for uniqueness
        test_set_path = f"outputs/{sample_id}.json"
        test_set = [{
            "id": sample_id,
            "audio_path": reference_audio_path,
            "lrc_path": lyrics_json_path,
            "duration": duration,
            "prompt_path": style_prompt
        }]
        with open(test_set_path, "w") as f:
            json.dump(test_set, f)

        try:
            batch = self._load_batch(test_set_path)
            latent = generate_latent(
                self.cfm_model,
                batch,
                self.config.evaluation.sample_kwargs,
                self.config.evaluation.negative_style_prompt,
                self.config.evaluation.ignore_style
            )[0][0]
            original_duration = batch['test_metadata'][0]['duration']

            pred_audio = normalize_audio(self._decode_latent(latent))
            sample_rate = 44100
            # Slicing past the end is a no-op, so this only trims overlong output.
            pred_audio_trimmed = pred_audio[:, :int(original_duration * sample_rate)]

            output_path = f'outputs/{sample_id}.mp3'
            torchaudio.save(output_path, pred_audio_trimmed, sample_rate, format="mp3")
        finally:
            # Always remove the temporary test-set JSON (and prune old MP3s),
            # even if generation failed.
            self.cleanup_old_files(sample_id)
        return output_path