"""
Audio Processor - Handles audio blending and processing
"""
import logging
from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import torch
import torchaudio
from scipy import signal
from scipy.ndimage import median_filter
logger = logging.getLogger(__name__)
class AudioProcessor:
"""Handles audio processing, blending, and effects."""
def __init__(self, config: dict):
"""
Initialize audio processor.
Args:
            config: Configuration dictionary; reads "sample_rate" (default 44100)
                and "output_dir" (default "outputs")
"""
self.config = config
self.sample_rate = config.get("sample_rate", 44100)
def blend_clip(
self,
new_clip_path: str,
previous_clip: Optional[np.ndarray],
lead_in: float = 2.0,
lead_out: float = 2.0
) -> str:
"""
Blend new clip with previous clip using crossfades.
Args:
new_clip_path: Path to new audio clip
previous_clip: Previous clip as numpy array
lead_in: Lead-in duration in seconds for blending
lead_out: Lead-out duration in seconds for blending
Returns:
            Path to the blended clip, or the original path if blending is skipped or fails
"""
try:
            # If there is no previous clip, there is nothing to blend with
            if previous_clip is None:
                return new_clip_path
            # Load new clip and resample to the processor's sample rate if needed
            new_audio, sr = torchaudio.load(new_clip_path)
            if sr != self.sample_rate:
                resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
                new_audio = resampler(new_audio)
            new_np = new_audio.numpy()
# Calculate blend samples
lead_in_samples = int(lead_in * self.sample_rate)
lead_out_samples = int(lead_out * self.sample_rate)
# Ensure clips are compatible shape
if previous_clip.shape[0] != new_np.shape[0]:
# Match channels
if previous_clip.shape[0] == 1 and new_np.shape[0] == 2:
previous_clip = np.repeat(previous_clip, 2, axis=0)
elif previous_clip.shape[0] == 2 and new_np.shape[0] == 1:
new_np = np.repeat(new_np, 2, axis=0)
# Blend lead-in with previous clip's lead-out
if previous_clip.shape[1] >= lead_out_samples and new_np.shape[1] >= lead_in_samples:
# Extract regions to blend
prev_tail = previous_clip[:, -lead_out_samples:]
new_head = new_np[:, :lead_in_samples]
                # Create an equal-power crossfade: with g_out = cos(t) and g_in = sin(t),
                # g_out**2 + g_in**2 == 1, so the summed power stays constant across the blend
                fade_out = np.cos(np.linspace(0, np.pi / 2, lead_out_samples))
                fade_in = np.sin(np.linspace(0, np.pi / 2, lead_in_samples))
                # If lead-in and lead-out differ, blend over the shorter of the two
                if lead_in_samples != lead_out_samples:
                    blend_length = min(lead_in_samples, lead_out_samples)
                    prev_tail = prev_tail[:, -blend_length:]
                    new_head = new_head[:, :blend_length]
                    # Recompute the curves over the shorter length so they still run
                    # all the way from 1 -> 0 and 0 -> 1 (no level jump at the seam)
                    fade_out = np.cos(np.linspace(0, np.pi / 2, blend_length))
                    fade_in = np.sin(np.linspace(0, np.pi / 2, blend_length))
# Apply crossfade
blended_region = (prev_tail * fade_out + new_head * fade_in)
# Reconstruct clip with blended region
result = new_np.copy()
result[:, :blended_region.shape[1]] = blended_region
else:
# Not enough audio to blend, return as-is
result = new_np
# Apply gentle compression to avoid clipping
result = self._apply_compression(result)
# Save blended clip
output_dir = Path(self.config.get("output_dir", "outputs"))
            output_dir.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = output_dir / f"blended_{timestamp}.wav"
result_tensor = torch.from_numpy(result).float()
torchaudio.save(
str(output_path),
result_tensor,
self.sample_rate,
encoding="PCM_S",
bits_per_sample=16
)
logger.info(f"✅ Blended clip saved: {output_path}")
return str(output_path)
except Exception as e:
logger.error(f"Blending failed: {e}")
# Return original if blending fails
return new_clip_path
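    # Note: blend_clip() writes only the new clip with its head crossfaded against
    # the previous clip's tail; the body of the previous clip itself is not
    # included in the output file.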
def crossfade(
self,
audio1: np.ndarray,
audio2: np.ndarray,
fade_duration: float = 2.0
) -> np.ndarray:
"""
Create crossfade between two audio segments.
Args:
audio1: First audio segment
audio2: Second audio segment
fade_duration: Duration of crossfade in seconds
Returns:
Crossfaded audio
"""
        fade_samples = int(fade_duration * self.sample_rate)
        # Never fade over more samples than either segment actually has
        fade_samples = min(fade_samples, audio1.shape[1], audio2.shape[1])
# Ensure same number of channels
if audio1.shape[0] != audio2.shape[0]:
target_channels = max(audio1.shape[0], audio2.shape[0])
if audio1.shape[0] < target_channels:
audio1 = np.repeat(audio1, target_channels // audio1.shape[0], axis=0)
if audio2.shape[0] < target_channels:
audio2 = np.repeat(audio2, target_channels // audio2.shape[0], axis=0)
# Extract fade regions
fade_out_region = audio1[:, -fade_samples:]
fade_in_region = audio2[:, :fade_samples]
        # Create equal-power crossfade curves (cos/sin, so their power sums to 1)
        fade_out_curve = np.cos(np.linspace(0, np.pi / 2, fade_samples))
        fade_in_curve = np.sin(np.linspace(0, np.pi / 2, fade_samples))
# Apply fades
faded = fade_out_region * fade_out_curve + fade_in_region * fade_in_curve
# Concatenate: audio1 (except fade region) + faded + audio2 (except fade region)
result = np.concatenate([
audio1[:, :-fade_samples],
faded,
audio2[:, fade_samples:]
], axis=1)
return result
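    # Note on crossfade(): the returned signal has
    # audio1.shape[1] + audio2.shape[1] - fade_samples samples; e.g. two 10 s
    # clips at 44.1 kHz with fade_duration=2.0 yield an 18 s result.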
def _apply_compression(self, audio: np.ndarray, threshold: float = 0.8) -> np.ndarray:
"""
Apply gentle compression to prevent clipping.
Args:
audio: Input audio
threshold: Compression threshold
Returns:
Compressed audio
"""
# Soft clipping using tanh
peak = np.abs(audio).max()
if peak > threshold:
            # Soft-clip with tanh scaled by the threshold: roughly unity gain for
            # quiet material, smoothly saturating toward +/- threshold for loud peaks
            compressed = np.tanh(audio / threshold) * threshold
return compressed
return audio
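    # Rough behaviour of _apply_compression() with the default threshold of 0.8
    # (only applied when the clip's peak exceeds the threshold):
    #   sample 0.2 -> tanh(0.25) * 0.8 ~= 0.196  (quiet material nearly untouched)
    #   sample 1.0 -> tanh(1.25) * 0.8 ~= 0.679
    #   sample 2.0 -> tanh(2.5)  * 0.8 ~= 0.789  (output never exceeds the threshold)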
def normalize_audio(self, audio: np.ndarray, target_db: float = -3.0) -> np.ndarray:
"""
Normalize audio to target dB level.
Args:
audio: Input audio
target_db: Target level in dB
Returns:
Normalized audio
"""
# Calculate current peak in dB
peak = np.abs(audio).max()
if peak == 0:
return audio
current_db = 20 * np.log10(peak)
# Calculate gain needed
gain_db = target_db - current_db
gain_linear = 10 ** (gain_db / 20)
# Apply gain
normalized = audio * gain_linear
# Ensure no clipping
normalized = np.clip(normalized, -1.0, 1.0)
return normalized
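    # Worked example for normalize_audio(): a clip peaking at 0.5 sits at
    # 20 * log10(0.5) ~= -6.02 dBFS; reaching the -3.0 dB target needs
    # gain_db ~= 3.02 dB, i.e. a linear gain of 10 ** (3.02 / 20) ~= 1.42.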
def remove_clicks_pops(self, audio: np.ndarray) -> np.ndarray:
"""
Remove clicks and pops from audio.
Args:
audio: Input audio
Returns:
Cleaned audio
"""
# Apply median filter to remove impulse noise
cleaned = np.zeros_like(audio)
for ch in range(audio.shape[0]):
cleaned[ch] = median_filter(audio[ch], size=3)
return cleaned
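    # remove_clicks_pops() uses a 3-sample median window, which at the default
    # 44.1 kHz sample rate spans roughly 68 microseconds per channel, so only
    # very short impulse noise is affected.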
def apply_fade(
self,
audio: np.ndarray,
fade_in: float = 0.0,
fade_out: float = 0.0
) -> np.ndarray:
"""
Apply fade in/out to audio.
Args:
audio: Input audio
fade_in: Fade in duration in seconds
fade_out: Fade out duration in seconds
Returns:
Faded audio
"""
result = audio.copy()
# Fade in
if fade_in > 0:
fade_in_samples = int(fade_in * self.sample_rate)
fade_in_samples = min(fade_in_samples, audio.shape[1])
fade_curve = np.linspace(0, 1, fade_in_samples) ** 2
result[:, :fade_in_samples] *= fade_curve
# Fade out
        if fade_out > 0:
            fade_out_samples = int(fade_out * self.sample_rate)
            fade_out_samples = min(fade_out_samples, audio.shape[1])
            # Guard against a zero-length fade: result[:, -0:] would select the whole array
            if fade_out_samples > 0:
                fade_curve = np.linspace(1, 0, fade_out_samples) ** 2
                result[:, -fade_out_samples:] *= fade_curve
return result
def resample_audio(
self,
audio: np.ndarray,
orig_sr: int,
target_sr: int
) -> np.ndarray:
"""
Resample audio to target sample rate.
Args:
audio: Input audio
orig_sr: Original sample rate
target_sr: Target sample rate
Returns:
Resampled audio
"""
if orig_sr == target_sr:
return audio
# Use scipy's resample for high-quality resampling
num_samples = int(audio.shape[1] * target_sr / orig_sr)
resampled = signal.resample(audio, num_samples, axis=1)
return resampled
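    # Example for resample_audio(): 1 second of audio at 44100 Hz resampled to
    # 48000 Hz yields int(44100 * 48000 / 44100) = 48000 output samples.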
def match_loudness(
self,
audio1: np.ndarray,
audio2: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
"""
Match loudness between two audio segments.
Args:
audio1: First audio segment
audio2: Second audio segment
Returns:
            Tuple of (audio1 scaled to match audio2's RMS loudness, audio2 unchanged)
"""
# Calculate RMS for each
rms1 = np.sqrt(np.mean(audio1 ** 2))
rms2 = np.sqrt(np.mean(audio2 ** 2))
        # Avoid dividing by zero when either segment is silent
        if rms1 == 0 or rms2 == 0:
            return audio1, audio2
# Calculate gain to match audio1 to audio2
gain = rms2 / rms1
# Apply gain to audio1
matched_audio1 = audio1 * gain
# Prevent clipping
peak = np.abs(matched_audio1).max()
if peak > 1.0:
matched_audio1 = matched_audio1 / peak
return matched_audio1, audio2
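if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): generates a short
    # test tone, writes it to disk, and runs it through AudioProcessor. The config
    # keys below are the ones this class actually reads; the paths, durations, and
    # tone parameters are purely illustrative.
    logging.basicConfig(level=logging.INFO)
    sample_rate = 44100
    processor = AudioProcessor({"sample_rate": sample_rate, "output_dir": "outputs"})
    # 4 seconds of a 440 Hz sine as a stand-in for a generated clip
    t = np.linspace(0, 4.0, int(4.0 * sample_rate), endpoint=False)
    tone = (0.5 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)
    clip = np.stack([tone, tone])  # stereo, shape (channels, samples)
    Path("outputs").mkdir(parents=True, exist_ok=True)
    clip_path = "outputs/test_tone.wav"
    torchaudio.save(clip_path, torch.from_numpy(clip), sample_rate)
    # Blend the new clip against a copy of itself acting as the "previous" clip
    blended_path = processor.blend_clip(clip_path, previous_clip=clip, lead_in=1.0, lead_out=1.0)
    print(f"Blended clip written to: {blended_path}")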