# dbbun_eeg_eval_labeled_pca.py
# Color-coded PCA: seizure (red) vs non-seizure (blue)
"""Embed labeled EEG windows with a pretrained encoder and plot a 2-D PCA.

The script loads an encoder from MODEL_DIR, slides a fixed-size window over
every recording found under DATA_DIR, derives a per-window label from the
optional ``labels_sec`` array, collects the encoder embeddings and shows a
scatter plot of their first two principal components colored by label.
"""
from pathlib import Path
import json

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

try:
    import matplotlib.pyplot as plt
except ImportError:
    # Plotting is the very last step; keep embedding extraction (and saving)
    # usable on machines without matplotlib.
    plt = None

# ======================
# CONFIG (your settings)
# ======================
MODEL_DIR = r"C:\DBBun\Code\EEG\pretraining\pretrained_out"
DATA_DIR = r"d:\dbbun-eeg\data\valid"  # NPZ folder
USE_NPZ = True  # NPZ with labels_sec
PREFER_TORCHSCRIPT = False
HOP_SECONDS = 1.5
MAX_FILES = 10
BATCH_SIZE = 64
MAX_BATCHES = 51  # cap for speed; set to None to process all windows
# A window is labeled 1 when the mean of the per-second labels it overlaps is
# strictly greater than this value (a fraction, not a number of seconds).
WINDOW_LABEL_THRESHOLD = 0.5
SAVE_EMBEDDINGS = True
EMB_OUT_PATH = Path(MODEL_DIR) / "demo_embeddings.npy"


# ======================
# Model loading
# ======================
def load_model_def(model_dir: str):
    """Return the parsed contents of ``<model_dir>/model_def.json``."""
    md_path = Path(model_dir) / "model_def.json"
    return json.loads(md_path.read_text(encoding="utf-8"))


class Conv1dEncoder(nn.Module):
    """Stack of strided 1-D conv stages followed by pooling and a projection.

    Attribute names (``conv``, ``pool``, ``proj``) and the layer order inside
    ``conv`` define the state_dict keys, so they must not be changed.
    """

    def __init__(self, in_channels, widths=(32, 64, 128), latent_dim=128, dropout=0.1):
        super().__init__()
        layers, prev = [], in_channels
        for w in widths:
            # Conv1d(in, out, kernel_size=7, stride=2, padding=3)
            layers += [
                nn.Conv1d(prev, w, 7, 2, 3),
                nn.BatchNorm1d(w),
                nn.GELU(),
                nn.Dropout(dropout),
            ]
            prev = w
        self.conv = nn.Sequential(*layers)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.proj = nn.Linear(prev, latent_dim)

    def forward(self, x):
        """Return ``(z, h)``: the projected embedding and the conv feature map."""
        h = self.conv(x)
        g = self.pool(h).squeeze(-1)
        z = self.proj(g)
        return z, h


def load_encoder(model_dir: str, prefer_ts: bool = False):
    """Load the encoder in eval mode on the CPU.

    When ``prefer_ts`` is true and ``encoder_traced.pt`` exists, the
    TorchScript module is loaded; otherwise a ``Conv1dEncoder`` is built from
    ``model_def.json`` and filled from ``encoder_state.pt``.

    Returns:
        ``(encoder, model_def, window_samples, scripted)`` where
        ``window_samples`` is ``int(window_seconds * sample_rate)`` and
        ``scripted`` tells which of the two loading paths was taken.
    """
    md = load_model_def(model_dir)
    traced_path = Path(model_dir) / "encoder_traced.pt"
    if prefer_ts and traced_path.exists():
        print("[Model] TorchScript")
        enc = torch.jit.load(str(traced_path), map_location="cpu")
        scripted = True
    else:
        print("[Model] state_dict")
        enc = Conv1dEncoder(
            md["channels"],
            tuple(md["encoder_channels"]),
            md["latent_dim"],
            md["dropout"],
        )
        enc.load_state_dict(
            torch.load(Path(model_dir) / "encoder_state.pt", map_location="cpu")
        )
        scripted = False
    enc.eval()
    win = int(md["window_seconds"] * md["sample_rate"])
    return enc, md, win, scripted


# ======================
# Dataset with labels
# ======================
class EEGWindowsLabeled(Dataset):
    """Returns (window_tensor, label), label∈{0,1,-1} derived from labels_sec.

    Each file must hold a 2-D array whose second axis is sliced into windows
    (its shape is unpacked as ``C, T``). NPZ files are read from the ``eeg``
    key, falling back to the first stored array; an optional ``labels_sec``
    array is indexed by whole seconds (``sample_index // sr``).

    Args:
        folder: directory searched recursively for ``*.npz`` / ``*.npy``.
        window_len: window length in samples.
        hop: stride between window starts in samples (clamped to >= 1).
        sr: sample rate used to map sample offsets onto ``labels_sec``.
        use_npz: read ``*.npz`` files when true, ``*.npy`` otherwise.
        max_files: keep only the first N files (sorted); falsy keeps all.
        print_summary: print file/window counts after indexing.
        channels: number of channels every window is padded/truncated to.
            Defaults to the largest channel count found in the data; pass the
            encoder's input channel count so windows always fit the model.
    """

    def __init__(self, folder, window_len, hop, sr, use_npz=True, max_files=None,
                 print_summary=True, channels=None):
        self.folder = Path(folder)
        self.window = int(window_len)
        # A hop of 0 would make np.arange fail; never step by less than 1.
        self.hop = max(1, int(hop))
        self.sr = int(sr)
        self.use_npz = use_npz

        patt = "*.npz" if use_npz else "*.npy"
        limit = int(max_files) if max_files else None
        self.files = sorted(self.folder.rglob(patt))[:limit]

        self.index, self.shapes = [], []
        self.labels_present, self.sz_frac = False, []
        for i, f in enumerate(self.files):
            if use_npz:
                # NOTE(security): allow_pickle=True lets a crafted NPZ execute
                # arbitrary code on load; only point DATA_DIR at trusted files.
                with np.load(f, allow_pickle=True) as z:
                    a = z["eeg"] if "eeg" in z.files else z[z.files[0]]
                    if "labels_sec" in z.files:
                        self.labels_present = True
                        labels = np.asarray(z["labels_sec"])
                        if labels.size:  # mean of an empty array is NaN
                            self.sz_frac.append(float(np.mean(labels)))
            else:
                a = np.load(f, mmap_mode="r")
            if a.ndim != 2:
                continue
            C, T = a.shape
            self.shapes.append((int(C), int(T)))
            if T >= self.window:
                starts = np.arange(0, T - self.window + 1, self.hop, dtype=int)
                self.index += [(i, int(s)) for s in starts]

        if channels:
            self.channels = int(channels)
        else:
            self.channels = max((c for c, _ in self.shapes), default=1)

        if print_summary:
            print(f"[Data] Files: {len(self.files)} | Windows: {len(self.index)} | Channels(max): {self.channels}")
            if self.labels_present and self.sz_frac:
                print(f"[Data] labels_sec present — mean seizure_fraction: {np.mean(self.sz_frac):.3f}")

    def __len__(self):
        return len(self.index)

    def _label_from_labels_sec(self, labels_sec, start, win_len):
        """Label a window from the per-second labels it overlaps.

        Returns -1 when the window starts beyond the end of ``labels_sec``,
        1 when the mean of the overlapped entries exceeds
        WINDOW_LABEL_THRESHOLD, and 0 otherwise.
        """
        s0 = start // self.sr
        s1 = min((start + win_len - 1) // self.sr, len(labels_sec) - 1)
        if s0 > s1:
            return -1
        frac = float(np.mean(labels_sec[s0:s1 + 1]))
        return 1 if frac > WINDOW_LABEL_THRESHOLD else 0

    def __getitem__(self, idx):
        """Return the z-scored window ``idx`` and its int64 label tensor."""
        fi, start = self.index[idx]
        f = self.files[fi]
        stop = start + self.window
        label = -1
        if self.use_npz:
            # NOTE(security): see __init__ regarding allow_pickle=True.
            with np.load(f, allow_pickle=True) as z:
                a = z["eeg"] if "eeg" in z.files else z[z.files[0]]
                seg = np.asarray(a[:, start:stop], dtype=np.float32)
                if "labels_sec" in z.files:
                    label = self._label_from_labels_sec(
                        np.asarray(z["labels_sec"]), start, self.window
                    )
        else:
            a = np.load(f, mmap_mode="r")
            seg = np.asarray(a[:, start:stop], dtype=np.float32)

        # Bring every window to the same channel count: zero-pad or truncate.
        C = seg.shape[0]
        if C < self.channels:
            pad = np.zeros((self.channels - C, seg.shape[1]), dtype=np.float32)
            seg = np.concatenate([seg, pad], axis=0)
        elif C > self.channels:
            seg = seg[:self.channels]

        # Per-channel z-score; the epsilon keeps flat (e.g. padded) channels at 0.
        mu = seg.mean(axis=1, keepdims=True)
        sd = seg.std(axis=1, keepdims=True) + 1e-6
        seg = (seg - mu) / sd
        return torch.from_numpy(seg), torch.tensor(label, dtype=torch.int64)


# ======================
# PCA + plots
# ======================
def pca_2d(E: np.ndarray):
    """Project the rows of ``E`` onto their first two principal components.

    The result always has two columns; when fewer than two components exist
    (a single row or a single feature) the missing columns are zeros.
    """
    E0 = E - E.mean(0, keepdims=True)
    _, _, Vt = np.linalg.svd(E0, full_matrices=False)
    Y = E0 @ Vt[:2].T
    if Y.shape[1] < 2:
        Y = np.pad(Y, ((0, 0), (0, 2 - Y.shape[1])))
    return Y


def plot_pca_colored(Y, labels):
    """Scatter-plot ``Y`` with blue=0, red=1 and gray=negative labels.

    Raises:
        RuntimeError: if matplotlib could not be imported.
    """
    if plt is None:
        raise RuntimeError("matplotlib is required for plotting but is not installed.")
    lbl = labels.astype(int)
    has = lbl >= 0
    plt.figure(figsize=(6, 5))
    if np.any(has):
        nz, sz = lbl == 0, lbl == 1
        if np.any(nz):
            plt.scatter(Y[nz, 0], Y[nz, 1], s=6, alpha=0.7, label="non-seizure", c="blue")
        if np.any(sz):
            plt.scatter(Y[sz, 0], Y[sz, 1], s=10, alpha=0.9, label="seizure", c="red")
        if np.any(~has):
            plt.scatter(Y[~has, 0], Y[~has, 1], s=4, alpha=0.3, label="unlabeled", c="gray")
        plt.legend()
    else:
        plt.scatter(Y[:, 0], Y[:, 1], s=6)
    plt.title("Encoder embeddings — PCA (colored by label)")
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    plt.tight_layout()
    plt.show()


# ======================
# Main
# ======================
def collect_embeddings(enc, loader, max_batches=None):
    """Run ``enc`` over ``loader`` and stack embeddings and labels.

    ``enc`` may return either the embedding tensor alone or a tuple/list whose
    first element is the embedding, so both encoder loading paths are handled.

    Args:
        enc: callable applied to each input batch.
        loader: iterable of ``(inputs, labels)`` batches.
        max_batches: stop after this many batches; ``None`` processes all.

    Returns:
        ``(E, L)``: embeddings stacked along axis 0 and int labels.

    Raises:
        ValueError: if no batch was processed.
    """
    all_Z, all_L = [], []
    with torch.no_grad():
        for i, (x, lbl) in enumerate(loader):
            if max_batches is not None and i >= max_batches:
                break
            out = enc(x)
            z = out[0] if isinstance(out, (tuple, list)) else out
            all_Z.append(z.cpu().numpy())
            all_L.append(lbl.cpu().numpy())
    if not all_Z:
        raise ValueError("No batches were produced by the loader.")
    E = np.concatenate(all_Z, axis=0)
    L = np.concatenate(all_L, axis=0).astype(int)
    return E, L


def main():
    """Load the encoder, embed the windows, optionally save, then plot."""
    enc, md, win_samples, _scripted = load_encoder(MODEL_DIR, PREFER_TORCHSCRIPT)
    hop = int(HOP_SECONDS * md["sample_rate"])
    print(f"[Config] Window={win_samples} | Hop={hop} | SR={md['sample_rate']} Hz")

    ds = EEGWindowsLabeled(
        DATA_DIR, win_samples, hop, sr=md["sample_rate"],
        use_npz=USE_NPZ, max_files=MAX_FILES, print_summary=True,
        channels=md["channels"],  # windows must match the encoder's input width
    )
    if len(ds) == 0:
        raise SystemExit("No windows produced — check DATA_DIR / USE_NPZ.")

    # drop_last=False: this is evaluation, so the tail batch must be kept (and
    # a dataset smaller than one batch must still yield data). Inference runs
    # on the CPU, so pinned memory brings nothing.
    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0,
                    pin_memory=False, drop_last=False)

    E, L = collect_embeddings(enc, dl, max_batches=MAX_BATCHES)
    print(f"[Emb] {E.shape[0]} embeddings collected | latent={E.shape[1]}")

    if SAVE_EMBEDDINGS:
        EMB_OUT_PATH.parent.mkdir(parents=True, exist_ok=True)
        np.save(EMB_OUT_PATH, E)
        print(f"[Emb] Saved: {EMB_OUT_PATH}")

    Y = pca_2d(E)
    plot_pca_colored(Y, L)


if __name__ == "__main__":
    main()