#!/usr/bin/env python3
"""
Generate audio using JAM model
Reads from filtered test set and generates audio using CFM+DiT model.
"""
import os
import glob
import time
import json
import random
import sys
from huggingface_hub import snapshot_download
import torch
import torchaudio
from omegaconf import OmegaConf
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
import accelerate
import pyloudnorm as pyln
from safetensors.torch import load_file
from muq import MuQMuLan
import numpy as np
from accelerate import Accelerator
from jam.dataset import enhance_webdataset_config, DiffusionWebDataset
from jam.model.vae import StableAudioOpenVAE, DiffRhythmVAE
# DiffRhythm imports for CFM+DiT model
from jam.model import CFM, DiT


def get_negative_style_prompt(device, file_path):
    vocal_style = np.load(file_path)
    vocal_style = torch.from_numpy(vocal_style).to(device)  # [1, 512]
    vocal_style = vocal_style.half()
    return vocal_style


def normalize_audio(audio, normalize_lufs=True):
audio = audio - audio.mean(-1, keepdim=True)
audio = audio / (audio.abs().max(-1, keepdim=True).values + 1e-8)
if normalize_lufs:
meter = pyln.Meter(rate=44100)
target_lufs = -14.0
loudness = meter.integrated_loudness(audio.transpose(0, 1).numpy())
normalised = pyln.normalize.loudness(audio.transpose(0, 1).numpy(), loudness, target_lufs)
normalised = torch.from_numpy(normalised).transpose(0, 1)
else:
normalised = audio
return normalised


class FilteredTestSetDataset(Dataset):
"""Custom dataset for loading from filtered test set JSON"""
def __init__(self, test_set_path, diffusion_dataset, muq_model, num_samples=None, random_crop_style=False, num_style_secs=30, use_prompt_style=False):
with open(test_set_path, 'r') as f:
self.test_samples = json.load(f)
if num_samples is not None:
self.test_samples = self.test_samples[:num_samples]
self.diffusion_dataset = diffusion_dataset
self.muq_model = muq_model
self.random_crop_style = random_crop_style
self.num_style_secs = num_style_secs
self.use_prompt_style = use_prompt_style
if self.use_prompt_style:
print("Using prompt style instead of audio style.")

    def __len__(self):
return len(self.test_samples)

    def __getitem__(self, idx):
test_sample = self.test_samples[idx]
sample_id = test_sample["id"]
# Load LRC data
lrc_path = test_sample["lrc_path"]
with open(lrc_path, 'r') as f:
lrc_data = json.load(f)
        # Wrap bare word-timestamp lists so downstream code always sees {'word': [...]}
        if 'word' not in lrc_data:
            lrc_data = {'word': lrc_data}
# Generate style embedding from original audio on-the-fly
audio_path = test_sample["audio_path"]
if self.use_prompt_style:
            prompt_path = test_sample["prompt_path"]
            with open(prompt_path, 'r') as f:
                prompt = f.read()
            if len(prompt) > 300:
                print(f"Sample {sample_id} has prompt length {len(prompt)}; truncating to 300 characters")
                prompt = prompt[:300]
            print(prompt)
style_embedding = self.muq_model(texts=[prompt]).squeeze(0)
else:
style_embedding = self.generate_style_embedding(audio_path)
duration = test_sample["duration"]
# Create fake latent with correct length
# Assuming frame_rate from config (typically 21.5 fps for 44.1kHz)
frame_rate = 21.5
num_frames = int(duration * frame_rate)
fake_latent = torch.randn(128, num_frames) # 128 is latent dim
# Create sample tuple matching DiffusionWebDataset format
fake_sample = (
sample_id,
fake_latent, # latent with correct duration
style_embedding, # style from actual audio
lrc_data # actual LRC data
)
# Process through DiffusionWebDataset's process_sample_safely
processed_sample = self.diffusion_dataset.process_sample_safely(fake_sample)
# Add metadata
if processed_sample is not None:
processed_sample['test_metadata'] = {
'sample_id': sample_id,
'audio_path': audio_path,
'lrc_path': lrc_path,
'duration': duration,
'num_frames': num_frames
}
return processed_sample

    def generate_style_embedding(self, audio_path):
"""Generate style embedding using MuQ model on the whole music"""
# Load audio
waveform, sample_rate = torchaudio.load(audio_path)
# Resample to 24kHz if needed (MuQ expects 24kHz)
if sample_rate != 24000:
resampler = torchaudio.transforms.Resample(sample_rate, 24000)
waveform = resampler(waveform)
# Convert to mono if stereo
if waveform.shape[0] > 1:
waveform = waveform.mean(dim=0, keepdim=True)
        # Squeeze out the channel dimension so the waveform is 1D: (time,)
        waveform = waveform.squeeze(0)
# Move to same device as model
waveform = waveform.to(self.muq_model.device)
# Generate embedding using MuQ model
with torch.inference_mode():
# MuQ expects batch dimension and 1D audio, returns (batch, embedding_dim)
            if self.random_crop_style:
                # Randomly crop num_style_secs seconds from the waveform
                total_samples = waveform.shape[0]
                target_samples = 24000 * self.num_style_secs  # num_style_secs seconds at 24kHz
                start_idx = random.randint(0, max(0, total_samples - target_samples))
                style_embedding = self.muq_model(wavs=waveform.unsqueeze(0)[..., start_idx:start_idx + target_samples])
else:
style_embedding = self.muq_model(wavs=waveform.unsqueeze(0)[..., :24000 * self.num_style_secs])
# Keep shape as (embedding_dim,) not scalar
return style_embedding[0]


def custom_collate_fn_with_metadata(batch, base_collate_fn):
"""Custom collate function that preserves test_metadata"""
# Filter out None samples
batch = [item for item in batch if item is not None]
if not batch:
return None
# Extract test_metadata before collating
test_metadata = [item.pop('test_metadata') for item in batch]
# Use base collate function for the rest
collated = base_collate_fn(batch)
# Add test_metadata back
if collated is not None:
collated['test_metadata'] = test_metadata
return collated


def load_model(model_config, checkpoint_path, device):
"""
Load JAM CFM model from checkpoint (follows infer.py pattern)
"""
# Build CFM model from config
dit_config = model_config["dit"].copy()
# Add text_num_embeds if not specified - should be at least 64 for phoneme tokens
if "text_num_embeds" not in dit_config:
dit_config["text_num_embeds"] = 256 # Default value from DiT
cfm = CFM(
transformer=DiT(**dit_config),
**model_config["cfm"]
)
cfm = cfm.to(device)
# Load checkpoint - use the path from config
checkpoint = load_file(checkpoint_path)
cfm.load_state_dict(checkpoint, strict=False)
return cfm.eval()


def generate_latent(model, batch, sample_kwargs, negative_style_prompt_path=None, ignore_style=False, device='cuda'):
"""
Generate latent from batch data (follows infer.py pattern)
"""
with torch.inference_mode():
batch_size = len(batch["lrc"])
text = batch["lrc"].to(device)
style_prompt = batch["prompt"].to(device)
start_time = batch["start_time"].to(device)
duration_abs = batch["duration_abs"].to(device)
duration_rel = batch["duration_rel"].to(device)
        # Create a zero conditioning latent spanning the model's full frame budget
        max_frames = model.max_frames
        cond = torch.zeros(batch_size, max_frames, 64).to(text.device)
pred_frames = [(0, max_frames)]
default_sample_kwargs = {
"cfg_strength": 4,
"steps": 50,
"batch_infer_num": 1
}
sample_kwargs = {**default_sample_kwargs, **sample_kwargs}
if negative_style_prompt_path is None:
negative_style_prompt_path = 'public_checkpoints/vocal.npy'
negative_style_prompt = get_negative_style_prompt(text.device, negative_style_prompt_path)
elif negative_style_prompt_path == 'zeros':
negative_style_prompt = torch.zeros(1, 512).to(text.device)
else:
negative_style_prompt = get_negative_style_prompt(text.device, negative_style_prompt_path)
negative_style_prompt = negative_style_prompt.repeat(batch_size, 1)
latents, _ = model.sample(
cond=cond,
text=text,
style_prompt=negative_style_prompt if ignore_style else style_prompt,
duration_abs=duration_abs,
duration_rel=duration_rel,
negative_style_prompt=negative_style_prompt,
start_time=start_time,
latent_pred_segments=pred_frames,
**sample_kwargs
)
return latents


class Jamify:
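    """End-to-end inference wrapper: downloads the JAM checkpoint, loads the VAE, the CFM+DiT
    model and MuQ-MuLan, and exposes predict() to render a song from reference audio,
    word-level lyrics and an optional text style prompt."""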
def __init__(self):
os.makedirs('outputs', exist_ok=True)
device = 'cuda'
config_path = 'jam_infer.yaml'
self.config = OmegaConf.load(config_path)
OmegaConf.resolve(self.config)
        # Download the pretrained JAM checkpoint and point the config at it
        print("Downloading main model checkpoint...")
        model_repo_path = snapshot_download(repo_id="declare-lab/jam-0.5")
        self.config.evaluation.checkpoint_path = os.path.join(model_repo_path, "jam-0_5.safetensors")
# Load VAE based on configuration
vae_type = self.config.evaluation.get('vae_type', 'stable_audio')
if vae_type == 'diffrhythm':
vae = DiffRhythmVAE(device=device).to(device)
else:
vae = StableAudioOpenVAE().to(device)
self.vae = vae
self.vae_type = vae_type
self.cfm_model = load_model(self.config.model, self.config.evaluation.checkpoint_path, device)
self.muq_model = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large").to(device).eval()
dataset_cfg = OmegaConf.merge(self.config.data.train_dataset, self.config.evaluation.dataset)
enhance_webdataset_config(dataset_cfg)
# Override multiple_styles to False since we're generating single style embeddings
dataset_cfg.multiple_styles = False
self.base_dataset = DiffusionWebDataset(**dataset_cfg)

    def cleanup_old_files(self, sample_id):
        # Prune old generated MP3s: once 10 or more exist, keep only the 9 most recent
        old_mp3_files = sorted(glob.glob("outputs/*.mp3"))
        if len(old_mp3_files) >= 10:
            for old_file in old_mp3_files[:-9]:  # keep the newest 9, delete the rest
                try:
                    os.remove(old_file)
                    print(f"Cleaned up old file: {old_file}")
                except OSError:
                    pass
        os.unlink(f"outputs/{sample_id}.json")

    def predict(self, reference_audio_path, lyrics_json_path, style_prompt, duration):
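        """Render a song for the given word-level lyrics JSON and target duration (seconds),
        conditioning on the reference audio's style (or on the text prompt file passed as
        style_prompt when use_prompt_style is enabled); returns the path of the output MP3."""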
sample_id = str(int(time.time() * 1000000)) # microsecond timestamp for uniqueness
test_set = [{
"id": sample_id,
"audio_path": reference_audio_path,
"lrc_path": lyrics_json_path,
"duration": duration,
"prompt_path": style_prompt
}]
        with open(f"outputs/{sample_id}.json", "w") as f:
            json.dump(test_set, f)
# Create filtered test set dataset
test_dataset = FilteredTestSetDataset(
test_set_path=f"outputs/{sample_id}.json",
diffusion_dataset=self.base_dataset,
muq_model=self.muq_model,
num_samples=1,
random_crop_style=self.config.evaluation.random_crop_style,
num_style_secs=self.config.evaluation.num_style_secs,
use_prompt_style=self.config.evaluation.use_prompt_style
)
# Create dataloader with custom collate function
dataloader = DataLoader(
test_dataset,
batch_size=1,
shuffle=False,
collate_fn=lambda batch: custom_collate_fn_with_metadata(batch, self.base_dataset.custom_collate_fn)
)
batch = next(iter(dataloader))
sample_kwargs = self.config.evaluation.sample_kwargs
latent = generate_latent(self.cfm_model, batch, sample_kwargs, self.config.evaluation.negative_style_prompt, self.config.evaluation.ignore_style)[0][0]
test_metadata = batch['test_metadata'][0]
sample_id = test_metadata['sample_id']
original_duration = test_metadata['duration']
# Decode audio
latent_for_vae = latent.transpose(0, 1).unsqueeze(0)
# Use chunked decoding if configured (only for DiffRhythm VAE)
use_chunked = self.config.evaluation.get('use_chunked_decoding', True)
if self.vae_type == 'diffrhythm' and use_chunked:
pred_audio = self.vae.decode(
latent_for_vae,
chunked=True,
overlap=self.config.evaluation.get('chunked_overlap', 32),
chunk_size=self.config.evaluation.get('chunked_size', 128)
).sample.squeeze(0).detach().cpu()
else:
pred_audio = self.vae.decode(latent_for_vae).sample.squeeze(0).detach().cpu()
pred_audio = normalize_audio(pred_audio)
sample_rate = 44100
trim_samples = int(original_duration * sample_rate)
if pred_audio.shape[1] > trim_samples:
pred_audio_trimmed = pred_audio[:, :trim_samples]
else:
pred_audio_trimmed = pred_audio
output_path = f'outputs/{sample_id}.mp3'
torchaudio.save(output_path, pred_audio_trimmed, sample_rate, format="mp3")
self.cleanup_old_files(sample_id)
return output_path
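

if __name__ == "__main__":
    # Minimal usage sketch, assuming hypothetical local inputs: a reference audio file,
    # a word-level lyrics JSON and a text style-prompt file. The paths and the 30-second
    # duration below are illustrative placeholders, not part of the original module.
    # Requires jam_infer.yaml, the downloadable model assets and a CUDA device.
    jam = Jamify()
    output_path = jam.predict(
        reference_audio_path="inputs/reference.mp3",
        lyrics_json_path="inputs/lyrics.json",
        style_prompt="inputs/style_prompt.txt",
        duration=30,
    )
    print(f"Generated audio written to {output_path}")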