from math import gcd
from typing import Optional, Union

import joblib
import numpy as np
import torch
from scipy import signal


def load_scaler_joblib(path: str) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Load a fitted scaler (e.g. ecg_scaler.pkl) and return its center and scale.

    The object stored at ``path`` must expose ``mean_`` and ``scale_``
    attributes (the sklearn ``StandardScaler`` convention). Either attribute
    may be ``None`` (``with_mean=False`` / ``with_std=False``); in that case
    the identity value is used (center 0, scale 1).

    Args:
        path: Path to the joblib file.

    Returns:
        center: float32 tensor of per-feature centers.
        scale: float32 tensor of per-feature scales, clamped to >= 1e-8 so it
            is always safe to divide by.

    Raises:
        AttributeError: if the loaded object lacks ``mean_`` / ``scale_``
            (or ``n_features_in_`` when both of those are ``None``).
    """
    # SECURITY: joblib.load unpickles arbitrary objects and can execute code.
    # Only load scaler files from trusted sources.
    sc = joblib.load(path)
    mean, std = sc.mean_, sc.scale_
    if mean is None or std is None:
        # Work out the feature count from whichever attribute is present.
        known = mean if mean is not None else std
        n_features = len(known) if known is not None else sc.n_features_in_
        if mean is None:
            mean = np.zeros(n_features)
        if std is None:
            std = np.ones(n_features)
    # np.array (not asarray) copies, so the returned tensors never alias the
    # unpickled object's buffers.
    center = torch.from_numpy(np.array(mean, dtype=np.float32))
    scale = torch.from_numpy(np.array(std, dtype=np.float32)).clamp_min(1e-8)
    return center, scale


class ECGTransform:
    """
    Unified ECG preprocessing: optional filtering + resampling, then scaling.

    Usage:
        transform = ECGTransform(center, scale, src_fs=512, target_fs=100)
        ecg_out = transform(ecg_in)

    Note: ``self.scale`` holds the per-lead scale *tensor*. The scaling step
    is exposed as :meth:`normalize`. A method named ``scale`` could never be
    reached on an instance, because the attribute assigned in ``__init__``
    shadows it.
    """

    def __init__(
        self,
        center: Union[np.ndarray, torch.Tensor],
        scale: Union[np.ndarray, torch.Tensor],
        src_fs: int = 100,  # default assumes the input ECG is already at 100 Hz
        target_fs: int = 100,
        band: Optional[tuple[float, float]] = (0.5, 40.0),
        bp_order: int = 4,
        axis: int = -1,
    ) -> None:
        # Per-lead offset subtracted during normalization.
        self.center = torch.as_tensor(center, dtype=torch.float32)
        # Per-lead divisor; clamped so normalization never divides by zero.
        self.scale = torch.as_tensor(scale, dtype=torch.float32).clamp_min(1e-8)
        self.src_fs = src_fs
        self.target_fs = target_fs
        # (lowcut, highcut) in Hz, applied before resampling; None disables it.
        self.band = band
        self.bp_order = bp_order
        # Sample axis along which filtering and resampling operate.
        self.axis = axis

    def downsample(self, x: np.ndarray) -> np.ndarray:
        """
        Optionally band-pass filter ``x``, then resample it to ``target_fs``.

        Despite the name, this also handles upsampling (``target_fs > src_fs``).

        Args:
            x: array-like signal sampled at ``src_fs`` along ``self.axis``.

        Returns:
            np.ndarray resampled to ``target_fs`` along ``self.axis``.

        Raises:
            ValueError: if, after clamping ``highcut`` below the Nyquist
                limits, the band is empty (``lowcut >= highcut``).
        """
        x = np.asarray(x)
        if self.band is not None:
            lowcut, highcut = self.band
            # Keep the upper edge below both Nyquist limits. The target limit
            # suppresses aliasing when decimating. The source limit keeps the
            # filter design valid when upsampling; otherwise highcut / nyq
            # could be >= 1 and signal.butter would raise.
            highcut = min(highcut, 0.45 * min(self.src_fs, self.target_fs))
            if lowcut >= highcut:
                raise ValueError(
                    f"Empty filter band: lowcut={lowcut} Hz >= highcut={highcut} Hz "
                    f"(src_fs={self.src_fs}, target_fs={self.target_fs})"
                )
            nyq = self.src_fs / 2.0
            if lowcut <= 0:
                sos = signal.butter(self.bp_order, highcut / nyq, btype="low", output="sos")
            else:
                sos = signal.butter(
                    self.bp_order, (lowcut / nyq, highcut / nyq), btype="band", output="sos"
                )
            # Forward-backward filtering: zero phase shift.
            x = signal.sosfiltfilt(sos, x, axis=self.axis)

        # Rational resampling by the reduced up/down factors.
        g = gcd(self.src_fs, self.target_fs)
        up = self.target_fs // g
        down = self.src_fs // g
        return signal.resample_poly(
            x, up, down, axis=self.axis, window=("kaiser", 5.0), padtype="median"
        )

    def normalize(self, ecg: torch.Tensor) -> torch.Tensor:
        """
        Apply ``(ecg - center) / scale`` per lead.

        ``center`` and ``scale`` are broadcast as ``(leads, 1)``, i.e. leads on
        the second-to-last axis and time on the last. They are moved to
        ``ecg``'s device so non-CPU inputs work.

        Args:
            ecg: tensor shaped (..., leads, time).

        Returns:
            float32 tensor with the same shape as ``ecg``.
        """
        ecg = ecg.to(torch.float32)
        center = self.center.to(ecg.device)[:, None]
        scale = self.scale.to(ecg.device)[:, None]
        return (ecg - center) / scale

    def __call__(self, x: np.ndarray) -> torch.Tensor:
        """
        Resample (if needed) and scale ECG data.

        Args:
            x: np.ndarray (or array-like / tensor), shape (leads, time).

        Returns:
            float32 torch.Tensor, shape (leads, time') where time' reflects
            any resampling.
        """
        if self.src_fs != self.target_fs:
            x = self.downsample(x)
        if not isinstance(x, torch.Tensor):
            # ascontiguousarray accepts lists and fixes negative strides, both
            # of which torch.from_numpy rejects.
            x = torch.from_numpy(np.ascontiguousarray(x))
        return self.normalize(x)