import os

import numpy as np
import torch
import soundfile as sf
import librosa
import gradio as gr
import spaces  # For ZeroGPU
from xcodec2.modeling_xcodec2 import XCodec2Model

# ====== Settings ======
BASE_REPO = os.getenv("BASE_REPO", "HKUSTAudio/xcodec2")    # Baseline (pretrained)
FT_REPO = os.getenv("FT_REPO", "NandemoGHS/Anime-XCodec2")  # Fine-tuned (yours)
TARGET_SR = 16000         # XCodec2 expects 16 kHz
MAX_SECONDS_DEFAULT = 30  # Default max duration (seconds)

# ====== Globals (loaded to CPU once at startup; moved to GPU only during inference) ======
_model_base = None
_model_ft = None


def _ensure_models():
    """Load both models to CPU once, and reuse across requests."""
    global _model_base, _model_ft
    if _model_base is None:
        _model_base = XCodec2Model.from_pretrained(BASE_REPO).eval().to("cpu")
    if _model_ft is None:
        _model_ft = XCodec2Model.from_pretrained(FT_REPO).eval().to("cpu")


_ensure_models()


def _load_audio(filepath: str, max_seconds: int):
    """
    Load audio (wav/flac/ogg/mp3), convert to mono, resample to 16 kHz,
    trim to the given max length (from the beginning), and return a
    (1, T) torch.Tensor together with the sample rate.
    """
    # Try soundfile first, then fall back to librosa
    try:
        wav, sr = sf.read(filepath, dtype="float32", always_2d=False)
    except Exception:
        wav, sr = librosa.load(filepath, sr=None, mono=False)
    wav = np.asarray(wav, dtype=np.float32)

    # Mono: soundfile returns (frames, channels), librosa returns (channels, frames).
    # Treat the shorter axis as the channel axis so >2-channel files are averaged
    # correctly as well.
    if wav.ndim == 2:
        wav = wav.mean(axis=1) if wav.shape[0] >= wav.shape[1] else wav.mean(axis=0)
    elif wav.ndim > 2:
        wav = np.mean(wav, axis=tuple(range(1, wav.ndim)))

    # Resample to 16 kHz
    if sr != TARGET_SR:
        wav = librosa.resample(wav, orig_sr=sr, target_sr=TARGET_SR)
        sr = TARGET_SR

    # Length cap
    if max_seconds is None or max_seconds <= 0:
        max_seconds = MAX_SECONDS_DEFAULT
    max_len = int(sr * max_seconds)
    if wav.shape[0] > max_len:
        wav = wav[:max_len]

    # Light safety normalization
    peak = np.max(np.abs(wav))
    if peak > 1.0:
        wav = wav / (peak + 1e-8)

    wav_tensor = torch.from_numpy(wav).float().unsqueeze(0)  # (1, T)
    return wav_tensor, sr
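
# Usage sketch for the helper above (illustrative only; "sample.wav" is a
# hypothetical local file, not shipped with this Space):
#
#   waveform, sr = _load_audio("sample.wav", max_seconds=10)
#   # waveform: float32 tensor of shape (1, T); sr == 16000 after resampling
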
""" if isinstance(codes, torch.Tensor): return codes.to(device) try: t = torch.as_tensor(codes[0][0], device=device) return t.unsqueeze(0).unsqueeze(0) if t.ndim == 1 else t except Exception: return torch.as_tensor(codes, device=device) def _reconstruct(model: XCodec2Model, waveform: torch.Tensor, device: str) -> np.ndarray: """Encode→decode with XCodec2 to get a reconstructed waveform (np.float32, clipped to [-1, 1]).""" with torch.inference_mode(): wave = waveform.to(device) codes = model.encode_code(input_waveform=wave) codes_t = _codes_to_tensor(codes, device=device) recon = model.decode_code(codes_t) # (1, 1, T') recon_np = recon.squeeze().detach().cpu().numpy().astype(np.float32) recon_np = np.clip(recon_np, -1.0, 1.0) return recon_np @spaces.GPU(duration=60) # ZeroGPU: reserve GPU only during this function call def run(audio_path, max_seconds): if audio_path is None: raise gr.Error("Please upload an audio file.") _ensure_models() waveform, sr = _load_audio(audio_path, max_seconds) device = "cuda" if torch.cuda.is_available() else "cpu" # Baseline (pretrained) base = _model_base.to(device) recon_base = _reconstruct(base, waveform, device) # Fine-tuned ft = _model_ft.to(device) recon_ft = _reconstruct(ft, waveform, device) # Gradio Audio expects (sample_rate, np.ndarray) return (sr, recon_base), (sr, recon_ft) # ====== UI ====== DESCRIPTION = """ # Anime‑XCodec2 / XCodec2 Reconstruction Demo Compare **Baseline (HKUSTAudio/xcodec2)** and **Fine‑tuned (NandemoGHS/Anime‑XCodec2)** reconstructions side by side. - Supported inputs: wav / flac / ogg / mp3 - Input is automatically converted to **16 kHz** (as required by XCodec2). - ZeroGPU ready. If no GPU is available, it falls back to CPU (slower). """ with gr.Blocks(theme=gr.themes.Soft(), css="footer {visibility: hidden}") as demo: gr.Markdown(DESCRIPTION) with gr.Row(): with gr.Column(scale=1): inp = gr.Audio( sources=["upload"], type="filepath", label="Upload an audio file", waveform_options={"show_controls": True} ) max_sec = gr.Slider( 3, 60, value=MAX_SECONDS_DEFAULT, step=1, label="Max length (seconds)", info="If the input is longer, only the first N seconds will be processed." ) run_btn = gr.Button("Run", variant="primary") gr.Markdown( f"**Baseline model**: `{BASE_REPO}` \n" f"**Fine‑tuned model**: `{FT_REPO}` \n" f"**Inference device**: auto (GPU on ZeroGPU)" ) with gr.Column(scale=1): with gr.Row(): out_base = gr.Audio( label="Baseline reconstruction (HKUSTAudio/xcodec2)", show_download_button=True, format="wav" ) out_ft = gr.Audio( label="Fine‑tuned reconstruction (NandemoGHS/Anime‑XCodec2)", show_download_button=True, format="wav" ) run_btn.click(run, inputs=[inp, max_sec], outputs=[out_base, out_ft]) # In Spaces, explicit launch is optional if __name__ == "__main__": demo.queue(max_size=8).launch()