import json
import os
import sys
import warnings
from pathlib import Path

import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
from tqdm import tqdm

from higgs_audio_tokenizer import HiggsAudioTokenizer
|
|
| warnings.filterwarnings('ignore') |
|
|
|
|
def remove_weight_norms_from_model(model):
    """Strip weight normalization from every submodule of *model*.

    Walks all modules and calls ``remove_weight_norm`` on each; modules
    that never had weight norm applied raise ``ValueError``, which is
    expected and skipped. The model is modified in place.

    Args:
        model: any ``nn.Module``, possibly containing weight-normed layers.

    Returns:
        The same model instance, with weight-norm hooks removed.
    """
    for module in model.modules():
        try:
            remove_weight_norm(module)
        except ValueError:
            # This module has no weight norm applied -- nothing to remove.
            continue
    return model
|
|
|
|
class EncodedResult:
    """Minimal result wrapper mirroring the HF-style encoder output."""

    def __init__(self, audio_codes):
        # Discrete codes produced by the quantizer.
        self.audio_codes = audio_codes
|
|
|
|
def encode_batch(model, x_batch):
    """Encode a padded audio batch into discrete codes.

    Args:
        model: HiggsAudioTokenizer instance in eval mode.
        x_batch: waveform tensor, assumed (batch, 1, samples) -- TODO confirm.

    Returns:
        EncodedResult whose ``audio_codes`` holds the quantizer codes.
    """
    # Semantic branch: regression target from the semantic teacher
    # (detached -- we never backprop through it here).
    semantic_feats = model.get_regress_target(x_batch).detach()
    e_semantic = model.encoder_semantic(semantic_feats.transpose(1, 2))

    # Acoustic branch.
    e_acoustic = model.encoder(x_batch)

    # If the two branches disagree on frame count, re-encode the audio
    # with symmetric zero padding so the acoustic stream lines up.
    if e_acoustic.shape[2] != e_semantic.shape[2]:
        pad = 160 * model.semantic_downsample_factor
        mono = x_batch[:, 0, :]
        e_acoustic = model.encoder(F.pad(mono, (pad, pad)).unsqueeze(1))

    # Trim both streams to the shorter one before fusing on the channel dim.
    frames = min(e_acoustic.shape[2], e_semantic.shape[2])
    fused = torch.cat(
        [e_acoustic[:, :, :frames], e_semantic[:, :, :frames]], dim=1
    )
    fused = model.fc_prior(fused.transpose(1, 2))

    if model.quantizer_type == "RVQ":
        # RVQ path expects (B, D, T); permute afterwards so batch comes
        # first -- presumably codes arrive as (n_q, B, T); verify if changed.
        _, codes, _, _ = model.quantizer(
            fused.transpose(1, 2), model.frame_rate, None
        )
        codes = codes.permute(1, 0, 2)
    else:
        _, codes = model.quantizer(fused)
        codes = codes.permute(0, 2, 1)

    return EncodedResult(audio_codes=codes)
|
|
|
|
def fix_all_inference_issues(model):
    """Force *model* into a clean inference state.

    Puts every submodule in eval mode, disables gradient checkpointing on
    the wrapped semantic model (if present), and deactivates all dropout
    layers. The model is modified in place.

    Args:
        model: the tokenizer model; must have at least one parameter so
            its device can be inferred.

    Returns:
        The same model instance, ready for no-grad inference.
    """
    device = next(model.parameters()).device

    # model.eval() already recurses, but some custom modules cache the
    # `training` flag, so force it explicitly on every submodule too.
    model.eval()
    with torch.no_grad():
        for module in model.modules():
            module.eval()
            module.training = False

    if hasattr(model, 'semantic_model'):
        print("Fixing semantic model...")

        model.semantic_model = model.semantic_model.to(device)
        model.semantic_model.eval()

        def disable_gradient_checkpointing(module):
            # Gradient checkpointing is training-only; disable it
            # everywhere we can reach so inference is not slowed/broken.
            if hasattr(module, 'gradient_checkpointing'):
                module.gradient_checkpointing = False
            if hasattr(module, 'gradient_checkpointing_disable'):
                try:
                    module.gradient_checkpointing_disable()
                except Exception:
                    # Best-effort: some modules expose the method but do
                    # not support disabling in their current state.
                    pass
            for child in module.children():
                disable_gradient_checkpointing(child)

        disable_gradient_checkpointing(model.semantic_model)

        # Transformers-style encoders also keep a per-layer flag.
        if hasattr(model.semantic_model, 'encoder'):
            model.semantic_model.encoder.gradient_checkpointing = False
            if hasattr(model.semantic_model.encoder, 'layers'):
                for layer in model.semantic_model.encoder.layers:
                    if hasattr(layer, 'gradient_checkpointing'):
                        layer.gradient_checkpointing = False

    def set_dropout_eval(module):
        # Belt-and-braces: make absolutely sure dropout is a no-op.
        if isinstance(module, nn.Dropout):
            module.eval()
            module.training = False
        for child in module.children():
            set_dropout_eval(child)

    set_dropout_eval(model)

    # Free any cached GPU memory left over from model construction.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return model
|
|
|
|
def inference_pipeline(checkpoint_path, config_path, device='cuda'):
    """Build a HiggsAudioTokenizer from a JSON config, load trained
    weights, and prepare it for inference.

    Args:
        checkpoint_path: path to a .pth checkpoint; either a raw state
            dict or a dict holding one under 'model_state_dict'.
        config_path: path to the JSON file with model hyperparameters.
        device: torch device string the model is placed on.

    Returns:
        The loaded model, switched into inference mode.
    """
    print("Loading config...")
    with open(config_path, 'r') as f:
        config = json.load(f)

    # Instantiate the tokenizer from config values; codebook_dim is the
    # only optional key.
    print("Creating model...")
    model = HiggsAudioTokenizer(
        n_filters=config['n_filters'],
        D=config['D'],
        target_bandwidths=config['target_bandwidths'],
        ratios=config['ratios'],
        sample_rate=config['sample_rate'],
        bins=config['bins'],
        n_q=config['n_q'],
        codebook_dim=config.get('codebook_dim', None),
        semantic_techer=config['semantic_techer'],
        device=device
    ).to(device)

    print("Loading checkpoint...")
    # NOTE(review): weights_only=False unpickles arbitrary objects --
    # only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Accept both trainer-style ({'model_state_dict': ...}) and raw
    # state-dict checkpoints.
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    # Strip the 'module.' prefix left by DataParallel/DDP wrapping.
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('module.'):
            new_state_dict[k[7:]] = v
        else:
            new_state_dict[k] = v

    # strict=False tolerates missing/unexpected keys (e.g. the frozen
    # semantic teacher's weights).
    model.load_state_dict(new_state_dict, strict=False)

    print("Fixing inference issues...")
    model = fix_all_inference_issues(model)

    return model
|
|
|
|
warnings.filterwarnings("ignore")


# --- Script configuration -------------------------------------------------
# Output directory for per-file code tensors, batch size for encoding,
# target sample rate, and the source HF dataset location.
OUTPUT_DIR = "/home/ubuntu/data_boson_44.1khz"
BATCH_SIZE = 32
SAMPLE_RATE = 44100
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DATASET_PATH = "/home/ubuntu/ttsar/Layla/src_bpe_2/Qanary_data"


print(f"Using device: {DEVICE}")


# Run from the codec directory so local imports / relative paths resolve.
os.chdir("/home/ubuntu/ttsar/boson_audio_codec/audio_processing")


from datasets import load_from_disk


print(f"Loading dataset from: {DATASET_PATH}")
ds = load_from_disk(DATASET_PATH)
print(f"Dataset info: {ds}")


# Drop columns we don't need; only remove ones that actually exist so
# re-runs against differently-shaped datasets don't crash.
columns_to_remove = ['spk', 'duration', 'codes', 'input_ids', 'attention_mask']
existing_columns = [col for col in columns_to_remove if col in ds.column_names]
if existing_columns:
    ds = ds.remove_columns(existing_columns)


# Work with a pandas frame from here on (simpler filtering below).
df = ds.to_pandas()
print(f"Loaded {len(df)} files from dataset")


os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Output directory '{OUTPUT_DIR}' is ready.")


print("Checking for already processed files...")
|
def get_output_path(audio_path):
    """Return the path under OUTPUT_DIR where codes for *audio_path* are saved."""
    stem = Path(audio_path).stem
    return os.path.join(OUTPUT_DIR, stem + ".pt")
|
|
|
|
# Skip files whose output .pt already exists (resume support).
original_count = len(df)
df['output_exists'] = df['filename'].apply(lambda x: os.path.exists(get_output_path(x)))
df_filtered = df[~df['output_exists']].copy()
skipped_count = original_count - len(df_filtered)


print(f"Found {skipped_count} already processed files. Skipping them.")
print(f"Processing {len(df_filtered)} remaining files.")


if len(df_filtered) == 0:
    print("All files have already been processed!")
    exit()


print("Loading Higgs Audio Tokenizer model...")
from transformers import HubertModel
from higgs_audio_tokenizer import HiggsAudioTokenizer


# Trained checkpoint + matching config used to rebuild the tokenizer.
checkpoint_path = '/home/ubuntu/ttsar/boson_audio_codec/audio_processing/outputs_CQT/checkpoints/step_99000.pth'
config_path = '/home/ubuntu/ttsar/boson_audio_codec/audio_processing/config copy.json'
device = 'cuda'


# Build the model, put it in eval mode, and strip weight norm for
# faster, deterministic inference.
model = inference_pipeline(checkpoint_path, config_path, device)
_ = model.eval()
model = remove_weight_norms_from_model(model)
print(f"Model loaded on {DEVICE}")


# Samples-per-code-frame; used below to trim padded frames per file.
hop_length = model.hop_length
print(f"Encoder hop length: {hop_length}")
|
|
print(f"\nStarting batch processing with batch size {BATCH_SIZE}...")


filenames = df_filtered['filename'].tolist()
total_processed = 0
total_errors = 0


# Encode all remaining files in mini-batches under no_grad -- we only
# need the discrete codes, never gradients.
with torch.no_grad():
    for batch_start in tqdm(range(0, len(filenames), BATCH_SIZE), desc="Processing batches"):
        batch_end = min(batch_start + BATCH_SIZE, len(filenames))
        batch_filenames = filenames[batch_start:batch_end]

        batch_audio = []
        batch_lengths = []
        batch_outputs = []

        # Load and resample each file; re-check the output path in case
        # it appeared on disk since the initial scan.
        for filename in batch_filenames:
            output_path = get_output_path(filename)

            if os.path.exists(output_path):
                continue

            try:
                wav, _ = librosa.load(filename, sr=SAMPLE_RATE)
                wav_tensor = torch.from_numpy(wav).float()

                batch_audio.append(wav_tensor)
                batch_lengths.append(len(wav))
                batch_outputs.append(output_path)

            except Exception as e:
                # Best-effort: report which file failed and keep going.
                # (Previously printed a literal "(unknown)" instead of
                # the filename.)
                print(f"\nError loading {filename}: {e}")
                total_errors += 1
                continue

        if not batch_audio:
            continue

        # Right-pad every clip with zeros to the longest clip in the
        # batch so they stack into a single (B, 1, T) tensor.
        max_len = max(len(x) for x in batch_audio)
        padded_batch = []

        for audio in batch_audio:
            pad_len = max_len - len(audio)
            if pad_len > 0:
                audio = F.pad(audio, (0, pad_len), mode='constant', value=0)
            padded_batch.append(audio)

        batch_tensor = torch.stack(padded_batch, dim=0)
        batch_tensor = batch_tensor.unsqueeze(1)
        batch_tensor = batch_tensor.to(DEVICE)

        try:
            encoded = encode_batch(model, batch_tensor)
            codes = encoded.audio_codes

            for idx, (output_path, orig_len) in enumerate(zip(batch_outputs, batch_lengths)):
                # Trim frames that only encode zero padding: the file
                # contributes ceil(orig_len / hop_length) real frames.
                true_code_len = int(np.ceil(orig_len / hop_length))

                item_codes = codes[idx, :, :true_code_len].cpu()

                torch.save(item_codes, output_path)
                total_processed += 1

        except Exception as e:
            # A failed batch counts every pending file in it as an error.
            print(f"\nError encoding batch: {e}")
            total_errors += len(batch_outputs)


print("\n" + "="*50)
print("PROCESSING COMPLETE!")
print("="*50)
print(f"Successfully processed: {total_processed} files")
print(f"Previously processed: {skipped_count} files")
print(f"Errors encountered: {total_errors} files")
print(f"Output directory: {OUTPUT_DIR}")


final_count = len(list(Path(OUTPUT_DIR).glob("*.pt")))
print(f"Total .pt files in output: {final_count}")