# dbbun_eeg_eval.py
# DBbun EEG — pretrained encoder evaluation & demo
# Run this in your "eeg" conda env (or other env with numpy, torch, matplotlib)

from pathlib import Path
import json

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# NOTE: matplotlib is imported inside the ``__main__`` section at the bottom.
# Only the demo plot needs it, so the utilities in this module can be imported
# in environments without matplotlib or a display.

# ======================
# CONFIG — EDIT THESE
# ======================
# Use local model folder (recommended for Spyder)
MODEL_DIR = r"C:\DBbun\Code\EEG\pretraining\pretrained_out"  # has encoder_state.pt + model_def.json
# If you prefer TorchScript instead of state_dict, set this True and make sure encoder_traced.pt exists.
PREFER_TORCHSCRIPT = False

# Your validation data directory — can contain .npy OR .npz files (recurses)
DATA_DIR = r"d:\dbbun-eeg\data\valid"  # e.g., r"d:\dbbun-eeg\data\val" if NPZ
USE_NPZ = True  # set True if your files are .npz

# Windowing
HOP_SECONDS = 1.5  # larger hop => fewer windows (faster)
MAX_FILES = 10  # limit for quick demo (set None to use all)

# Embedding export
SAVE_EMBEDDINGS = True
EMB_OUT_PATH = Path(MODEL_DIR) / "demo_embeddings.npy"  # saved as (n_windows, latent_dim)

# Linear probe (optional; toy labels)
RUN_LINEAR_PROBE = True


# ======================
# Utilities
# ======================
def load_model_def(model_dir: str):
    """Load and return the parsed contents of ``model_def.json`` in *model_dir*.

    Raises:
        FileNotFoundError: if ``model_def.json`` does not exist in *model_dir*.
    """
    md_path = Path(model_dir) / "model_def.json"
    if not md_path.exists():
        raise FileNotFoundError(f"model_def.json not found at {md_path}")
    return json.loads(md_path.read_text(encoding="utf-8"))


class Conv1dEncoder(nn.Module):
    """Stack of strided 1-D conv blocks followed by global average pooling.

    Each entry of *widths* adds one block: Conv1d(kernel 7, stride 2,
    padding 3) -> BatchNorm1d -> GELU -> Dropout. The final feature map is
    averaged over the time axis and projected linearly to *latent_dim*.
    """

    def __init__(self, in_channels, widths=(32, 64, 128), latent_dim=128, dropout=0.1):
        super().__init__()
        layers, prev = [], in_channels
        for w in widths:
            layers += [
                nn.Conv1d(prev, w, kernel_size=7, padding=3, stride=2),
                nn.BatchNorm1d(w),
                nn.GELU(),
                nn.Dropout(dropout),
            ]
            prev = w
        self.conv = nn.Sequential(*layers)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.proj = nn.Linear(prev, latent_dim)

    def forward(self, x):
        """Return ``(z, h)``: the pooled + projected embedding and the conv feature map."""
        h = self.conv(x)  # (B, W, L')
        g = self.pool(h).squeeze(-1)  # (B, W)
        z = self.proj(g)  # (B, latent)
        return z, h


def load_encoder(model_dir: str, prefer_ts: bool = False):
    """Load the pretrained encoder described by ``model_def.json`` in *model_dir*.

    When *prefer_ts* is true and ``encoder_traced.pt`` exists, the TorchScript
    module is loaded; otherwise a ``Conv1dEncoder`` is built from the model
    definition and ``encoder_state.pt`` is loaded into it. Everything is mapped
    to CPU and the encoder is put in eval mode.

    Returns:
        ``(enc, md, window_samples, scripted)`` where *md* is the parsed model
        definition, *window_samples* is
        ``int(md["window_seconds"] * md["sample_rate"])`` and *scripted* tells
        whether the TorchScript path was taken.
    """
    md = load_model_def(model_dir)
    traced_path = Path(model_dir) / "encoder_traced.pt"
    if prefer_ts and traced_path.exists():
        print("[Model] Loading TorchScript encoder_traced.pt")
        enc = torch.jit.load(str(traced_path), map_location="cpu")
        # TorchScript returns the scripted forward; assume it returns (z, h) as in training
        scripted = True
    else:
        print("[Model] Loading state_dict encoder_state.pt")
        enc = Conv1dEncoder(
            in_channels=md["channels"],
            widths=tuple(md["encoder_channels"]),
            latent_dim=md["latent_dim"],
            dropout=md["dropout"],
        )
        enc.load_state_dict(torch.load(Path(model_dir) / "encoder_state.pt", map_location="cpu"))
        scripted = False
    enc.eval()
    window_samples = int(md["window_seconds"] * md["sample_rate"])
    return enc, md, window_samples, scripted


# Dataset that supports .npy (memmap) or .npz (loads "eeg" key if present)
class EEGWindows(Dataset):
    """Sliding-window dataset over 2-D arrays stored as .npy or .npz files.

    Files are discovered recursively under *folder*. Each array is indexed as
    ``a[channel, time]``; arrays that are not 2-D are skipped. Every window is
    padded/cropped to the largest channel count seen across the indexed files
    and z-scored per channel.

    Args:
        folder: root directory searched recursively for data files.
        window_len: window length in samples (must be >= 1).
        hop: step between consecutive window starts in samples (must be >= 1).
        use_npz: search for ``*.npz`` instead of ``*.npy``.
        max_files: keep only the first *max_files* files (sorted by path);
            ``None`` or ``0`` keeps all of them.
        print_summary: print file/window counts while indexing.

    Raises:
        ValueError: if *window_len* or *hop* is smaller than 1.
    """

    def __init__(self, folder, window_len, hop, use_npz=False, max_files=None, print_summary=True):
        self.folder = Path(folder)
        self.use_npz = use_npz
        self.window = int(window_len)
        self.hop = int(hop)
        # A zero hop (e.g. HOP_SECONDS * sample_rate < 1) would otherwise fail
        # obscurely when the window start positions are generated.
        if self.window < 1 or self.hop < 1:
            raise ValueError(
                f"window_len and hop must be >= 1 sample (got window={self.window}, hop={self.hop})"
            )

        # gather files
        pattern = "*.npz" if use_npz else "*.npy"
        self.files = sorted(self.folder.rglob(pattern))
        if max_files:
            self.files = self.files[:int(max_files)]

        if print_summary:
            print(f"[Data] Found {len(self.files)} files under {self.folder}")

        # build index of (file index, window start) pairs
        self.index = []
        self.shapes = []
        self.labels_per_sec_exist = False
        self.seizure_fraction_estimates = []

        for i, f in enumerate(self.files):
            if use_npz:
                # SECURITY NOTE: allow_pickle=True can execute arbitrary code when
                # loading object arrays; only point DATA_DIR at trusted files.
                with np.load(f, allow_pickle=True) as z:
                    if not z.files:
                        continue  # empty archive: nothing to index
                    a = np.array(z[self._eeg_key(z)], dtype=np.float32)
                    # Try to detect labels
                    if "labels_sec" in z.files:
                        self.labels_per_sec_exist = True
                        lbl = np.array(z["labels_sec"]).astype(np.uint8)
                        self.seizure_fraction_estimates.append(float(lbl.mean()))
            else:
                a = np.load(f, mmap_mode="r")  # (C, T)

            if a.ndim != 2:
                continue
            C, T = int(a.shape[0]), int(a.shape[1])
            self.shapes.append((C, T))
            if T >= self.window:
                self.index.extend((i, s) for s in range(0, T - self.window + 1, self.hop))

        self.channels = max((c for c, _ in self.shapes), default=1)

        if print_summary:
            total_windows = len(self.index)
            print(f"[Data] Channels(max): {self.channels} | Windows: {total_windows}")
            if self.labels_per_sec_exist and self.seizure_fraction_estimates:
                print(f"[Data] labels_sec present. Mean seizure_fraction across loaded files: "
                      f"{np.mean(self.seizure_fraction_estimates):.3f}")

    @staticmethod
    def _eeg_key(npz):
        """Return the archive key holding the signal: "eeg" if present, else the first key."""
        return "eeg" if "eeg" in npz.files else npz.files[0]

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        """Return window *idx* as a float32 tensor of shape (channels, window)."""
        fi, start = self.index[idx]
        f = self.files[fi]
        stop = start + self.window
        if self.use_npz:
            # The archive is reopened on every access (see SECURITY NOTE in __init__).
            with np.load(f, allow_pickle=True) as z:
                a = z[self._eeg_key(z)]
                seg = np.asarray(a[:, start:stop], dtype=np.float32)
        else:
            a = np.load(f, mmap_mode="r")
            seg = np.asarray(a[:, start:stop], dtype=np.float32)

        # pad/crop channels to common number
        C = seg.shape[0]
        if C < self.channels:
            pad = np.zeros((self.channels - C, seg.shape[1]), dtype=np.float32)
            seg = np.concatenate([seg, pad], axis=0)
        elif C > self.channels:
            seg = seg[:self.channels]

        # per-window z-score normalization (epsilon keeps constant channels finite)
        mu = seg.mean(axis=1, keepdims=True)
        sd = seg.std(axis=1, keepdims=True) + 1e-6
        seg = (seg - mu) / sd
        return torch.from_numpy(seg)  # (C, L)


def pca_2d_numpy(E: np.ndarray):
    """Return a 2D PCA projection of the rows of *E* using NumPy SVD.

    The result always has two columns: when the SVD yields fewer than two
    components (a single feature column or a single row), the missing
    coordinates are zero so callers can index ``Y[:, 0]`` and ``Y[:, 1]``.
    """
    E0 = E - E.mean(0, keepdims=True)
    _, _, Vt = np.linalg.svd(E0, full_matrices=False)
    Y = E0 @ Vt[:2].T
    missing = 2 - Y.shape[1]
    if missing > 0:
        Y = np.pad(Y, ((0, 0), (0, missing)))
    return Y


def run_linear_probe(E: np.ndarray, epochs=5, lr=1e-3):
    """Tiny demo head on toy labels derived from PC1 threshold; replace with real labels if you have them.

    Fits an ``nn.Linear(E.shape[1], 2)`` head with AdamW, taking one full-batch
    step per epoch and printing loss/accuracy each time.

    Returns:
        The 2D PCA projection of *E* that the toy labels were derived from.
    """
    Y = pca_2d_numpy(E)
    labels = (Y[:, 0] > Y[:, 0].mean()).astype(np.int64)

    Z = torch.from_numpy(E).float()
    y = torch.from_numpy(labels)

    head = nn.Linear(E.shape[1], 2)
    opt = torch.optim.AdamW(head.parameters(), lr=lr)
    lossf = nn.CrossEntropyLoss()

    for ep in range(1, epochs + 1):
        opt.zero_grad(set_to_none=True)
        logits = head(Z)
        loss = lossf(logits, y)
        loss.backward()
        opt.step()
        with torch.no_grad():
            # accuracy of the logits computed before this epoch's optimizer step
            acc = (logits.argmax(1) == y).float().mean().item()
        print(f"[Probe] Epoch {ep}/{epochs} - loss: {loss.item():.4f} | acc: {acc:.3f}")
    return Y


# ======================
# Main
# ======================
if __name__ == "__main__":
    import matplotlib.pyplot as plt  # only the demo plot needs matplotlib

    torch.backends.cudnn.benchmark = True
    try:
        torch.set_float32_matmul_precision("medium")
    except AttributeError:
        # Best effort: older torch releases do not provide this function.
        pass

    enc, md, WIN_SAMPLES, scripted = load_encoder(MODEL_DIR, PREFER_TORCHSCRIPT)
    HOP = int(HOP_SECONDS * md["sample_rate"])

    print(f"[Config] Window = {WIN_SAMPLES} samples | Hop = {HOP} | Sample rate = {md['sample_rate']} Hz")
    ds = EEGWindows(DATA_DIR, WIN_SAMPLES, HOP, use_npz=USE_NPZ, max_files=MAX_FILES, print_summary=True)
    if len(ds) == 0:
        raise SystemExit("No windows produced — check DATA_DIR / USE_NPZ / window settings.")
    if ds.channels != md["channels"]:
        print(f"[Warn] Data has {ds.channels} channels but model_def.json says {md['channels']}; "
              f"the encoder may reject these inputs.")

    # DataLoader: 0 workers on Windows avoids fork issues in Spyder.
    # drop_last=False so datasets smaller than one batch still yield embeddings;
    # pin_memory is off because the encoder runs on CPU in this script.
    dl = DataLoader(ds, batch_size=64, shuffle=True, num_workers=0, pin_memory=False, drop_last=False)

    # ---- Extract embeddings ----
    all_Z = []
    enc.eval()
    with torch.no_grad():
        for i, x in enumerate(dl):
            # x: (B, C, L) on CPU; encoder is on CPU by default in this script
            z, _ = enc(x)  # state_dict and TorchScript encoders both return (z, h)
            all_Z.append(z.cpu().numpy())
            if i >= 50:  # limit passes for speed; raise/remove for full run
                break
    if not all_Z:
        raise SystemExit("No batches produced — check DATA_DIR / USE_NPZ / window settings.")
    # Rows follow the shuffled loader order, not file order.
    E = np.concatenate(all_Z, axis=0)  # (n_windows, latent_dim)
    print(f"[Emb] Collected embeddings: {E.shape}")

    if SAVE_EMBEDDINGS:
        EMB_OUT_PATH.parent.mkdir(parents=True, exist_ok=True)
        np.save(EMB_OUT_PATH, E)
        print(f"[Emb] Saved to: {EMB_OUT_PATH}")

    # ---- PCA scatter ----
    Y = pca_2d_numpy(E)
    plt.figure(figsize=(5, 5))
    plt.scatter(Y[:, 0], Y[:, 1], s=6)
    plt.title("Encoder embeddings — PCA (first 2 components)")
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    plt.tight_layout()
    plt.show()

    # ---- Optional: toy linear probe ----
    if RUN_LINEAR_PROBE:
        _ = run_linear_probe(E, epochs=5, lr=1e-3)

    # ---- If .npz labels exist, print seizure_fraction summary ----
    if ds.labels_per_sec_exist and len(ds.seizure_fraction_estimates) > 0:
        print(f"[Meta] Mean seizure_fraction (from labels_sec): "
              f"{np.mean(ds.seizure_fraction_estimates):.3f} "
              f"(over {len(ds.seizure_fraction_estimates)} files)")
    else:
        print("[Meta] No labels_sec found in files (expected for .npy datasets).")