import glob
import logging
import os
import re
import sys

import debugpy
import numpy as np
import torch
import torchaudio


def count_params_by_module(model_name, model):
    logging.info(f"Counting num_parameters of {model_name}:")
    param_stats = {}
    total_params = 0  # Count total parameters
    total_requires_grad_params = 0  # Count parameters with requires_grad=True
    total_no_grad_params = 0  # Count parameters with requires_grad=False
    for name, param in model.named_parameters():
        module_name = name.split('.')[0]
        if module_name not in param_stats:
            param_stats[module_name] = {'total': 0, 'requires_grad': 0, 'no_grad': 0}
        param_num = param.numel()
        param_stats[module_name]['total'] += param_num
        total_params += param_num
        if param.requires_grad:
            param_stats[module_name]['requires_grad'] += param_num
            total_requires_grad_params += param_num
        else:
            param_stats[module_name]['no_grad'] += param_num
            total_no_grad_params += param_num

    # Calculate maximum width for each column
    max_module_name_length = max(len(module) for module in param_stats)
    max_param_length = max(len(f"{stats['total'] / 1e6:.2f}M") for stats in param_stats.values())

    # Output parameter statistics for each module
    for module, stats in param_stats.items():
        logging.info(f"\t{module:<{max_module_name_length}}: "
                     f"Total: {stats['total'] / 1e6:<{max_param_length}.2f}M, "
                     f"Requires Grad: {stats['requires_grad'] / 1e6:<{max_param_length}.2f}M, "
                     f"No Grad: {stats['no_grad'] / 1e6:<{max_param_length}.2f}M")

    # Output total parameter statistics
    logging.info(f"\tTotal parameters: {total_params / 1e6:.2f}M parameters")
    logging.info(f"\tRequires Grad parameters: {total_requires_grad_params / 1e6:.2f}M parameters")
    logging.info(f"\tNo Grad parameters: {total_no_grad_params / 1e6:.2f}M parameters")
    logging.info("################################################################")


def load_and_resample_audio(audio_path, target_sample_rate):
    wav, raw_sample_rate = torchaudio.load(audio_path)  # (channels, time); mono input expected, squeezed below
    if raw_sample_rate != target_sample_rate:
        wav = torchaudio.functional.resample(wav, raw_sample_rate, target_sample_rate)
    return wav.squeeze()
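
# Example (illustrative; file name and 16 kHz target rate are placeholders):
#
#     wav = load_and_resample_audio("example.wav", target_sample_rate=16000)
#     # mono input -> 1-D tensor of shape (time,)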


def set_logging():
    rank = os.environ.get("RANK", "0")
    logging.basicConfig(
        level=logging.INFO,
        stream=sys.stdout,
        format=f"%(asctime)s [RANK {rank}] (%(module)s:%(lineno)d) %(levelname)s : %(message)s",
    )
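
# Example (illustrative): call once at process start so every log line carries
# the rank read from the RANK environment variable.
#
#     set_logging()
#     logging.info("logging configured")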


def waiting_for_debug(ip, port):
    rank = os.environ.get("RANK", "0")
    debugpy.listen((ip, port))  # Replace localhost with the cluster node's IP when debugging remotely
    logging.info(f"[rank = {rank}] Waiting for debugger attach...")
    debugpy.wait_for_client()
    logging.info(f"[rank = {rank}] Debugger attached")


def load_audio(audio_path, target_sample_rate):
    # Load audio file, wav shape: (channels, time)
    wav, raw_sample_rate = torchaudio.load(audio_path)
    # If multi-channel, convert to mono by averaging across channels
    if wav.shape[0] > 1:
        wav = torch.mean(wav, dim=0, keepdim=True)  # Average across channels, keep channel dim
    # Resample if necessary
    if raw_sample_rate != target_sample_rate:
        wav = torchaudio.functional.resample(wav, raw_sample_rate, target_sample_rate)
    # Convert to numpy, add channel dimension, then back to tensor with desired shape
    wav = np.expand_dims(wav.squeeze(0).numpy(), axis=1)  # Shape: (time, 1)
    wav = torch.tensor(wav).reshape(1, 1, -1)  # Shape: (1, 1, time)
    return wav
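
# Example (illustrative; path and rate are placeholders): any input becomes a
# mono tensor of shape (1, 1, time), i.e. batch- and channel-first.
#
#     wav = load_audio("speech.flac", target_sample_rate=24000)
#     assert wav.dim() == 3 and wav.shape[:2] == (1, 1)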


def save_audio(audio_outpath, audio_out, sample_rate):
    torchaudio.save(
        audio_outpath,
        audio_out,
        sample_rate=sample_rate,
        encoding='PCM_S',
        bits_per_sample=16
    )
    logging.info(f"Successfully saved audio at {audio_outpath}")


def find_audio_files(input_dir):
    audio_extensions = ['*.flac', '*.mp3', '*.wav']
    audios_input = []
    for ext in audio_extensions:
        audios_input.extend(glob.glob(os.path.join(input_dir, '**', ext), recursive=True))
    logging.info(f"Found {len(audios_input)} audio files in {input_dir}")
    return sorted(audios_input)
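
# Example (illustrative; the directory path is a placeholder): collect all
# .flac/.mp3/.wav files under a directory tree, sorted for reproducible order.
#
#     for path in find_audio_files("data/audio"):
#         wav = load_audio(path, target_sample_rate=16000)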


def normalize_text(text):
    # Remove all punctuation (including English and Chinese punctuation)
    text = re.sub(r'[^\w\s\u4e00-\u9fff]', '', text)
    # Convert to lowercase (effective for English, no effect on Chinese)
    text = text.lower()
    # Remove extra spaces
    text = ' '.join(text.split())
    return text
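
# Example (illustrative): punctuation is stripped, English is lowercased, and
# whitespace is collapsed; Chinese characters pass through unchanged.
#
#     normalize_text("Hello,  World!")  # -> "hello world"
#     normalize_text("你好，世界！")      # -> "你好世界"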