import os
import numpy as np
import torch
import soundfile as sf
import librosa
import gradio as gr
import spaces # For ZeroGPU
from xcodec2.modeling_xcodec2 import XCodec2Model
# ====== Settings ======
BASE_REPO = os.getenv("BASE_REPO", "HKUSTAudio/xcodec2") # Baseline (pretrained)
FT_REPO = os.getenv("FT_REPO", "NandemoGHS/Anime-XCodec2") # Fine-tuned (yours)
TARGET_SR = 16000 # XCodec2 expects 16 kHz
MAX_SECONDS_DEFAULT = 30 # Default max duration (seconds)
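
# To compare a different fine-tune, override the env vars above in the Space settings,
# e.g. FT_REPO=your-org/your-xcodec2-finetune (hypothetical repo id).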
# ====== Globals (lazy CPU load; move to GPU only during inference) ======
_model_base = None
_model_ft = None

def _ensure_models():
    """Load both models onto the CPU once and reuse them across requests."""
    global _model_base, _model_ft
    if _model_base is None:
        _model_base = XCodec2Model.from_pretrained(BASE_REPO).eval().to("cpu")
    if _model_ft is None:
        _model_ft = XCodec2Model.from_pretrained(FT_REPO).eval().to("cpu")

_ensure_models()
def _load_audio(filepath: str, max_seconds: int):
"""
    Load audio (wav/flac/ogg/mp3), downmix to mono, resample to 16 kHz,
    keep at most the first `max_seconds` seconds, and return a torch.Tensor of shape (1, T).
"""
# Try soundfile first, then fall back to librosa
try:
wav, sr = sf.read(filepath, dtype="float32", always_2d=False)
except Exception:
wav, sr = librosa.load(filepath, sr=None, mono=False)
wav = np.asarray(wav, dtype=np.float32)
    # Downmix to mono
    if wav.ndim == 2:
        # soundfile returns (frames, channels); librosa with mono=False returns (channels, frames)
        if wav.shape[1] in (1, 2):  # likely (frames, channels)
            wav = wav.mean(axis=1)
        else:  # likely (channels, frames)
            wav = wav.mean(axis=0)
elif wav.ndim > 2:
wav = np.mean(wav, axis=tuple(range(1, wav.ndim)))
# Resample to 16 kHz
if sr != TARGET_SR:
wav = librosa.resample(wav, orig_sr=sr, target_sr=TARGET_SR)
sr = TARGET_SR
# Length cap
if max_seconds is None or max_seconds <= 0:
max_seconds = MAX_SECONDS_DEFAULT
max_len = int(sr * max_seconds)
if wav.shape[0] > max_len:
wav = wav[:max_len]
    # Peak-normalize only if the signal would otherwise clip
peak = np.max(np.abs(wav))
if peak > 1.0:
wav = wav / (peak + 1e-8)
wav_tensor = torch.from_numpy(wav).float().unsqueeze(0) # (1, T)
return wav_tensor, sr
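
# Worked example of the contract above (illustrative numbers, not executed):
# a 3 s stereo clip at 44.1 kHz comes back as a mono tensor of shape (1, 48000)
# with sr=16000, since 3 s * 16000 Hz = 48000 samples.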
def _codes_to_tensor(codes, device):
"""
Normalize the output of xcodec2.encode_code to a tensor with shape (1, 1, N).
Handles version differences where the return type/shape may vary.
"""
if isinstance(codes, torch.Tensor):
return codes.to(device)
try:
t = torch.as_tensor(codes[0][0], device=device)
return t.unsqueeze(0).unsqueeze(0) if t.ndim == 1 else t
except Exception:
return torch.as_tensor(codes, device=device)
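
# Sketch of the fallback branch above (the nested-list return type is an assumption
# about other xcodec2 versions): a value like [[[t1, ..., tN]]] becomes
#   torch.as_tensor(codes[0][0])   -> shape (N,)
#   .unsqueeze(0).unsqueeze(0)     -> shape (1, 1, N)
# which is the layout decode_code consumes.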
def _reconstruct(model: XCodec2Model, waveform: torch.Tensor, device: str) -> np.ndarray:
"""Encode→decode with XCodec2 to get a reconstructed waveform (np.float32, clipped to [-1, 1])."""
with torch.inference_mode():
wave = waveform.to(device)
codes = model.encode_code(input_waveform=wave)
codes_t = _codes_to_tensor(codes, device=device)
recon = model.decode_code(codes_t) # (1, 1, T')
recon_np = recon.squeeze().detach().cpu().numpy().astype(np.float32)
recon_np = np.clip(recon_np, -1.0, 1.0)
return recon_np
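
# Minimal offline sketch of one round trip (hypothetical usage; assumes a local
# "sample.wav" and is not executed by the app):
#   wav, _ = _load_audio("sample.wav", 10)
#   out = _reconstruct(_model_base, wav, "cpu")  # np.float32 in [-1, 1]
#   sf.write("recon_base.wav", out, TARGET_SR)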
@spaces.GPU(duration=60) # ZeroGPU: reserve GPU only during this function call
def run(audio_path, max_seconds):
if audio_path is None:
raise gr.Error("Please upload an audio file.")
_ensure_models()
waveform, sr = _load_audio(audio_path, max_seconds)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Baseline (pretrained)
base = _model_base.to(device)
recon_base = _reconstruct(base, waveform, device)
# Fine-tuned
ft = _model_ft.to(device)
recon_ft = _reconstruct(ft, waveform, device)
# Gradio Audio expects (sample_rate, np.ndarray)
return (sr, recon_base), (sr, recon_ft)
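
# Note on the return shape above: gr.Audio outputs accept a (sample_rate, np.ndarray)
# tuple, e.g. (16000, np.zeros(16000, dtype=np.float32)) would play one second of silence.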
# ====== UI ======
DESCRIPTION = """
# Anime-XCodec2 / XCodec2 Reconstruction Demo
Compare **Baseline (HKUSTAudio/xcodec2)** and **Fine-tuned (NandemoGHS/Anime-XCodec2)** reconstructions side by side.
- Supported inputs: wav / flac / ogg / mp3
- Input is automatically converted to **16 kHz** (as required by XCodec2).
- ZeroGPU ready. If no GPU is available, it falls back to CPU (slower).
"""
with gr.Blocks(theme=gr.themes.Soft(), css="footer {visibility: hidden}") as demo:
gr.Markdown(DESCRIPTION)
with gr.Row():
with gr.Column(scale=1):
inp = gr.Audio(
sources=["upload"],
type="filepath",
label="Upload an audio file",
waveform_options={"show_controls": True}
)
max_sec = gr.Slider(
3, 60, value=MAX_SECONDS_DEFAULT, step=1,
label="Max length (seconds)",
info="If the input is longer, only the first N seconds will be processed."
)
run_btn = gr.Button("Run", variant="primary")
            gr.Markdown(
                f"**Baseline model**: `{BASE_REPO}`  \n"
                f"**Fine-tuned model**: `{FT_REPO}`  \n"
                f"**Inference device**: auto (GPU on ZeroGPU, otherwise CPU)"
            )
with gr.Column(scale=1):
with gr.Row():
out_base = gr.Audio(
label="Baseline reconstruction (HKUSTAudio/xcodec2)",
show_download_button=True, format="wav"
)
                out_ft = gr.Audio(
                    label="Fine-tuned reconstruction (NandemoGHS/Anime-XCodec2)",
                    show_download_button=True, format="wav"
                )
run_btn.click(run, inputs=[inp, max_sec], outputs=[out_base, out_ft])
# On Spaces the Gradio SDK launches `demo` automatically; the explicit launch below
# also lets this file run locally.
if __name__ == "__main__":
demo.queue(max_size=8).launch()