#!/usr/bin/env python
import os
import torch
import numpy as np
import soundfile as sf
import gradio as gr
from model import UFormer, UFormerConfig
# ──────────────────────
# 1) Setup & model loading from local checkpoints
# ──────────────────────
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_DIR = "checkpoints"
config = UFormerConfig()
_model_cache = {}
VALID_CKPTS = [
    "acoustic_guitar", "bass", "electric_guitar", "guitars", "keyboards",
    "orchestra", "rhythm_section", "synth", "vocals"
]
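# Each checkpoint is expected as a plain state dict at checkpoints/<name>.pth,
# matching the names listed in VALID_CKPTS (see _get_model below).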
def _get_model(ckpt_name: str):
    if ckpt_name not in VALID_CKPTS:
        raise ValueError(f"Invalid checkpoint {ckpt_name!r}, choose from {VALID_CKPTS}")
    if ckpt_name in _model_cache:
        return _model_cache[ckpt_name]
    ckpt_path = os.path.join(CHECKPOINT_DIR, f"{ckpt_name}.pth")
    model = UFormer(config).to(DEVICE).eval()
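    # NOTE: assuming the .pth files contain only tensors, torch.load(ckpt_path,
    # map_location=DEVICE, weights_only=True) would be a safer call on recent
    # PyTorch versions; the plain form below is kept as-is for compatibility.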
    state = torch.load(ckpt_path, map_location=DEVICE)
    model.load_state_dict(state)
    _model_cache[ckpt_name] = model
    return model
# ──────────────────────
# 2) Overlap-add for long audio
# ──────────────────────
def _overlap_add(model, x: np.ndarray, sr: int, chunk_s: float = 5., hop_s: float = 2.5):
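    # Split the (C, T) signal into chunk_s-second windows advanced by hop_s seconds,
    # run the model on each window, weight every output by a Hann window, then
    # accumulate and divide by the summed window so overlapping regions blend smoothly.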
    C, T = x.shape
    chunk, hop = int(sr * chunk_s), int(sr * hop_s)
    pad = (-(T - chunk) % hop) if T > chunk else 0
    x_pad = np.pad(x, ((0, 0), (0, pad)), mode="reflect")
    win = np.hanning(chunk)[None, :]
    out = np.zeros_like(x_pad)
    norm = np.zeros((1, x_pad.shape[1]))
    n_chunks = 1 + (x_pad.shape[1] - chunk) // hop
    print(f"Processing {n_chunks} chunks of size {chunk} with hop {hop}...")
    for i in range(n_chunks):
        s = i * hop
        seg = x_pad[:, s:s+chunk].astype(np.float32)
        with torch.no_grad():
            y = model(torch.from_numpy(seg[None]).to(DEVICE)).squeeze(0).cpu().numpy()
        out[:, s:s+chunk] += y * win
        norm[:, s:s+chunk] += win
    eps = 1e-8
    return (out / (norm + eps))[:, :T]
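# Hypothetical usage, assuming a (2, T) float array `x` sampled at 44.1 kHz:
#   model = _get_model("vocals")
#   y = _overlap_add(model, x, sr=44100)   # y has the same (2, T) shape as x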
# ──────────────────────
# 3) Restore function for Gradio
# ──────────────────────
def restore_fn(audio_path, checkpoint):
    audio, sr = sf.read(audio_path)
    if audio.ndim == 1:
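        # Duplicate the mono signal into two identical channels; the checkpoints
        # appear to operate on stereo (2, T) input.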
        audio = np.stack([audio, audio], axis=1)
    x = audio.T  # (C, T)
    model = _get_model(checkpoint)
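    # Clips up to 5 seconds are processed in a single pass; longer clips go through
    # windowed overlap-add (the threshold matches the default chunk_s above).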
    if x.shape[1] <= sr * 5:
        seg = x.astype(np.float32)[None]
        with torch.no_grad():
            y = model(torch.from_numpy(seg).to(DEVICE)).squeeze(0).cpu().numpy()
    else:
        y = _overlap_add(model, x, sr)
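    # NOTE: a fixed output filename means concurrent requests overwrite each other;
    # tempfile.NamedTemporaryFile(suffix=".wav", delete=False) would avoid that.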
    tmp = "restored.wav"
    sf.write(tmp, y.T, sr, format="WAV")
    return tmp
# ──────────────────────
# 4) Gradio App
# ──────────────────────
demo = gr.Interface(
    fn=restore_fn,
    inputs=[
        gr.Audio(sources="upload", type="filepath", label="Your Input"),
        gr.Dropdown(VALID_CKPTS, label="Checkpoint")
    ],
    outputs=gr.Audio(type="filepath", label="Restored Output"),
    title="🎵 Music Source Restoration",
    description="Upload a (stereo) audio file and choose an instrument/group checkpoint to restore. Please note that these are baseline models for demonstration purposes only, and most of them do not perform particularly well.",
    allow_flagging="never"
)
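# Launch directly when run as a script; when imported (e.g. by a container entrypoint),
# bind to all interfaces on $PORT (default 7860).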
if __name__ == "__main__":
    demo.launch()
else:
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))