Upload 2 files
Browse files- DBbun_EEG_Encoder_Eval_Demo_v1.py +262 -0
- demo_embeddings.npy +3 -0
DBbun_EEG_Encoder_Eval_Demo_v1.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# dbbun_eeg_eval.py
|
2 |
+
# DBbun EEG — pretrained encoder evaluation & demo
|
3 |
+
# Run this in your "eeg" conda env (or other env with numpy, torch, matplotlib)
|
4 |
+
|
5 |
+
from pathlib import Path
|
6 |
+
import json
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from torch.utils.data import Dataset, DataLoader
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
|
13 |
+
# ======================
# CONFIG — EDIT THESE
# ======================
# Use local model folder (recommended for Spyder)
MODEL_DIR = r"C:\DBbun\Code\EEG\pretraining\pretrained_out"  # has encoder_state.pt + model_def.json
# If you prefer TorchScript instead of state_dict, set this True and make sure encoder_traced.pt exists.
PREFER_TORCHSCRIPT = False

# Your validation data directory — searched recursively for *.npz or *.npy files,
# depending on USE_NPZ below.
DATA_DIR = r"d:\dbbun-eeg\data\valid"
USE_NPZ = True  # True: read *.npz files; False: read *.npy files

# Windowing
HOP_SECONDS = 1.5  # converted to samples with the model's sample_rate; larger hop => fewer windows (faster)
MAX_FILES = 10  # limit for quick demo (set None to use all)

# Embedding export
SAVE_EMBEDDINGS = True
EMB_OUT_PATH = Path(MODEL_DIR) / "demo_embeddings.npy"  # saved as (n_windows, latent_dim)

# Linear probe (optional; toy labels)
RUN_LINEAR_PROBE = True
|
35 |
+
|
36 |
+
# ======================
|
37 |
+
# Utilities
|
38 |
+
# ======================
|
39 |
+
|
40 |
+
def load_model_def(model_dir: str):
    """Load the model definition stored next to the encoder weights.

    Args:
        model_dir: Folder that contains ``model_def.json``.

    Returns:
        The parsed JSON content. The rest of this script indexes it like a
        dict (``md["channels"]``, ``md["sample_rate"]``, ...).

    Raises:
        FileNotFoundError: If ``model_def.json`` does not exist in
            ``model_dir``.
    """
    md_path = Path(model_dir) / "model_def.json"
    if not md_path.exists():
        raise FileNotFoundError(f"model_def.json not found at {md_path}")
    # JSON is UTF-8 by specification. Without an explicit encoding, Windows
    # decodes with the locale code page and garbles any non-ASCII content.
    return json.loads(md_path.read_text(encoding="utf-8"))
|
45 |
+
|
46 |
+
class Conv1dEncoder(nn.Module):
    """Stack of strided 1-D convolutions followed by a pooled linear projection.

    For every entry in ``widths`` one stage is appended:
    ``Conv1d(kernel_size=7, stride=2, padding=3) -> BatchNorm1d -> GELU -> Dropout``.
    The resulting feature map is averaged over its last axis and projected to
    ``latent_dim`` features.

    Args:
        in_channels: Number of input channels of the first convolution.
        widths: Output channel count of each convolution stage, in order.
        latent_dim: Size of the projected embedding.
        dropout: Dropout probability used after every stage.
    """

    def __init__(self, in_channels, widths=(32, 64, 128), latent_dim=128, dropout=0.1):
        super().__init__()
        # channel_plan[k] -> channel_plan[k + 1] describes stage k.
        channel_plan = [in_channels, *widths]
        stages = []
        for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:]):
            stages.extend((
                nn.Conv1d(c_in, c_out, kernel_size=7, stride=2, padding=3),
                nn.BatchNorm1d(c_out),
                nn.GELU(),
                nn.Dropout(dropout),
            ))
        # Attribute names (conv / pool / proj) determine the state_dict keys.
        self.conv = nn.Sequential(*stages)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.proj = nn.Linear(channel_plan[-1], latent_dim)

    def forward(self, x):
        """Encode ``x`` and return ``(z, h)``.

        ``h`` is the convolutional feature map (B, W, L') and ``z`` the
        projected embedding (B, latent).
        """
        feature_map = self.conv(x)                    # (B, W, L')
        pooled = self.pool(feature_map).squeeze(-1)   # (B, W)
        return self.proj(pooled), feature_map         # (B, latent), (B, W, L')
|
67 |
+
|
68 |
+
def load_encoder(model_dir: str, prefer_ts: bool = False):
    """Load the pretrained encoder from ``model_dir`` onto the CPU.

    The TorchScript file ``encoder_traced.pt`` is used only when ``prefer_ts``
    is true and the file exists. Otherwise a ``Conv1dEncoder`` is built from
    ``model_def.json`` and filled from ``encoder_state.pt``.

    Args:
        model_dir: Folder holding ``model_def.json`` and the encoder weights.
        prefer_ts: Prefer the TorchScript artifact when it is available.

    Returns:
        Tuple ``(enc, md, window_samples, scripted)``: the encoder in eval
        mode, the parsed model definition, the window length in samples
        (``window_seconds * sample_rate`` truncated to int) and whether the
        TorchScript artifact was loaded.
    """
    md = load_model_def(model_dir)
    root = Path(model_dir)
    traced_path = root / "encoder_traced.pt"

    scripted = bool(prefer_ts and traced_path.exists())
    if scripted:
        print("[Model] Loading TorchScript encoder_traced.pt")
        # Assumed to return (z, h) like Conv1dEncoder.forward — not verified here.
        enc = torch.jit.load(str(traced_path), map_location="cpu")
    else:
        print("[Model] Loading state_dict encoder_state.pt")
        enc = Conv1dEncoder(
            in_channels=md["channels"],
            widths=tuple(md["encoder_channels"]),
            latent_dim=md["latent_dim"],
            dropout=md["dropout"],
        )
        weights = torch.load(root / "encoder_state.pt", map_location="cpu")
        enc.load_state_dict(weights)

    enc.eval()
    window_samples = int(md["window_seconds"] * md["sample_rate"])
    return enc, md, window_samples, scripted
|
88 |
+
|
89 |
+
# Dataset that supports .npy (memmap) or .npz (loads "eeg" key if present)
|
90 |
+
class EEGWindows(Dataset):
    """Sliding-window dataset over 2-D EEG arrays stored as ``.npy`` or ``.npz``.

    Files are discovered recursively under ``folder``. Each file whose array
    is 2-D (read as ``(channels, time)``) contributes windows of ``window_len``
    samples taken every ``hop`` samples. Files with any other rank are
    ignored. Items are float32 tensors of shape ``(self.channels, window_len)``
    that are z-scored per channel within the window.

    Args:
        folder: Root directory to search.
        window_len: Window length in samples (must be >= 1).
        hop: Step between consecutive window starts in samples (must be >= 1).
        use_npz: Read ``*.npz`` files (array under key ``"eeg"``, else the
            first array in the archive) instead of ``*.npy`` files.
        max_files: Keep only the first N files (sorted); falsy means all.
        print_summary: Print file / window statistics while indexing.

    Attributes set by the constructor:
        files: Sorted list of candidate file paths.
        index: ``(file_index, start_sample)`` per window.
        shapes: ``(channels, samples)`` of every accepted file.
        channels: Largest channel count among accepted files (1 if none).
        labels_per_sec_exist: True if an accepted ``.npz`` has ``labels_sec``.
        seizure_fraction_estimates: Mean of ``labels_sec`` per accepted file.

    Raises:
        ValueError: If ``window_len`` or ``hop`` is smaller than one sample.
    """

    def __init__(self, folder, window_len, hop, use_npz=False, max_files=None, print_summary=True):
        self.folder = Path(folder)
        self.use_npz = use_npz
        self.window = int(window_len)
        self.hop = int(hop)
        # A zero hop used to surface as a ZeroDivisionError deep inside the
        # indexing loop; fail early with a message that names the argument.
        if self.window < 1:
            raise ValueError(f"window_len must be >= 1 sample, got {window_len!r}")
        if self.hop < 1:
            raise ValueError(f"hop must be >= 1 sample, got {hop!r}")

        # gather files
        pattern = "*.npz" if use_npz else "*.npy"
        self.files = sorted(self.folder.rglob(pattern))
        if max_files:
            self.files = self.files[:int(max_files)]
        if print_summary:
            print(f"[Data] Found {len(self.files)} files under {self.folder}")

        # build index
        self.index = []
        self.shapes = []
        self.labels_per_sec_exist = False
        self.seizure_fraction_estimates = []

        for i, f in enumerate(self.files):
            shape, has_labels, seizure_fraction = self._inspect(f)
            if len(shape) != 2:
                continue
            C, T = int(shape[0]), int(shape[1])
            self.shapes.append((C, T))

            # Label statistics are recorded only for files that are actually
            # used, so the printed summary describes the indexed data.
            if has_labels:
                self.labels_per_sec_exist = True
                if seizure_fraction is not None:
                    self.seizure_fraction_estimates.append(seizure_fraction)

            if T >= self.window:
                self.index.extend((i, s) for s in range(0, T - self.window + 1, self.hop))

        self.channels = max((c for c, _ in self.shapes), default=1)

        if print_summary:
            total_windows = len(self.index)
            print(f"[Data] Channels(max): {self.channels} | Windows: {total_windows}")
            if self.labels_per_sec_exist and self.seizure_fraction_estimates:
                print(f"[Data] labels_sec present. Mean seizure_fraction across loaded files: "
                      f"{np.mean(self.seizure_fraction_estimates):.3f}")

    @staticmethod
    def _eeg_key(npz):
        """Return the archive key holding the signal: ``"eeg"`` or the first array."""
        return "eeg" if "eeg" in npz.files else list(npz.files)[0]

    def _inspect(self, path):
        """Return ``(shape, has_labels, seizure_fraction)`` for one file.

        ``seizure_fraction`` is None when there is no ``labels_sec`` array or
        it is empty. Only the shape is read; no dtype conversion is done.
        """
        if not self.use_npz:
            return np.load(path, mmap_mode='r').shape, False, None

        # NOTE(review): allow_pickle=True can execute arbitrary code when a
        # file comes from an untrusted source; kept because existing archives
        # may rely on it. Prefer allow_pickle=False if the data allows it.
        with np.load(path, allow_pickle=True) as z:
            shape = z[self._eeg_key(z)].shape
            if "labels_sec" not in z.files:
                return shape, False, None
            lbl = np.array(z["labels_sec"]).astype(np.uint8)
        fraction = float(lbl.mean()) if lbl.size else None
        return shape, True, fraction

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        """Return window ``idx`` as a float32 tensor of shape (channels, window)."""
        fi, start = self.index[idx]
        f = self.files[fi]
        stop = start + self.window
        if self.use_npz:
            # NOTE(review): see _inspect for the allow_pickle caveat.
            with np.load(f, allow_pickle=True) as z:
                seg = np.asarray(z[self._eeg_key(z)][:, start:stop], dtype=np.float32)
        else:
            a = np.load(f, mmap_mode='r')  # memory-mapped: only the slice is read
            seg = np.asarray(a[:, start:stop], dtype=np.float32)

        # pad/crop channels to common number
        C = seg.shape[0]
        if C < self.channels:
            pad = np.zeros((self.channels - C, seg.shape[1]), dtype=np.float32)
            seg = np.concatenate([seg, pad], axis=0)
        elif C > self.channels:
            seg = seg[:self.channels]

        # per-window z-score normalization (epsilon keeps flat channels finite)
        mu = seg.mean(axis=1, keepdims=True)
        sd = seg.std(axis=1, keepdims=True) + 1e-6
        seg = (seg - mu) / sd

        return torch.from_numpy(seg)  # (C, L)
|
174 |
+
|
175 |
+
def pca_2d_numpy(E: np.ndarray):
    """Return 2D PCA projection using NumPy SVD.

    The rows of ``E`` are mean-centred and projected onto the first two right
    singular vectors. The sign of each component is whatever the SVD returns.

    Args:
        E: Array of shape ``(n_samples, n_features)``.

    Returns:
        Array of shape ``(n_samples, 2)``. When the data supports fewer than
        two components (fewer than two rows or features, or no data at all)
        the missing columns are filled with zeros, so callers can always
        index ``[:, 0]`` and ``[:, 1]``.
    """
    E = np.asarray(E)
    if E.size == 0:
        # Nothing to decompose; avoid NaN warnings / SVD errors on empty input.
        out_dtype = E.dtype if np.issubdtype(E.dtype, np.floating) else np.float64
        return np.zeros((E.shape[0], 2), dtype=out_dtype)

    E0 = E - E.mean(0, keepdims=True)
    # Only the right singular vectors are needed for the projection.
    _, _, Vt = np.linalg.svd(E0, full_matrices=False)
    Y = E0 @ Vt[:2].T
    missing = 2 - Y.shape[1]
    if missing > 0:
        Y = np.pad(Y, ((0, 0), (0, missing)))
    return Y
|
181 |
+
|
182 |
+
def run_linear_probe(E: np.ndarray, epochs=5, lr=1e-3):
    """Fit a tiny linear head on toy labels and print loss/accuracy per epoch.

    The labels are synthetic: a window is class 1 when its first principal
    component lies above the mean of that component. Replace them with real
    labels if you have any. Training is full-batch, one optimizer step per
    epoch; the printed accuracy is computed from the logits produced before
    that epoch's step.

    Args:
        E: Embedding matrix of shape ``(n_windows, latent_dim)``.
        epochs: Number of full-batch optimization steps.
        lr: AdamW learning rate.

    Returns:
        The 2-D PCA projection of ``E`` that the toy labels were derived from.
    """
    Y = pca_2d_numpy(E)
    pc1 = Y[:, 0]
    toy_labels = (pc1 > pc1.mean()).astype(np.int64)

    Z = torch.from_numpy(E).float()
    y = torch.from_numpy(toy_labels)

    head = nn.Linear(E.shape[1], 2)
    opt = torch.optim.AdamW(head.parameters(), lr=lr)
    lossf = nn.CrossEntropyLoss()

    for ep in range(1, epochs + 1):
        opt.zero_grad(set_to_none=True)
        logits = head(Z)
        loss = lossf(logits, y)
        loss.backward()
        opt.step()

        hits = logits.detach().argmax(dim=1).eq(y)
        acc = hits.float().mean().item()
        print(f"[Probe] Epoch {ep}/{epochs} - loss: {loss.item():.4f} | acc: {acc:.3f}")
    return Y
|
201 |
+
|
202 |
+
# ======================
|
203 |
+
# Main
|
204 |
+
# ======================
|
205 |
+
if __name__ == "__main__":
    # Kept as a flat script (no main() wrapper) so the variables stay visible
    # in an interactive console such as Spyder after the run.

    # Only relevant if the encoder is moved to a CUDA device; harmless on CPU.
    torch.backends.cudnn.benchmark = True
    # Older PyTorch releases do not have this function; the setting is optional.
    if hasattr(torch, "set_float32_matmul_precision"):
        torch.set_float32_matmul_precision("medium")

    enc, md, WIN_SAMPLES, scripted = load_encoder(MODEL_DIR, PREFER_TORCHSCRIPT)
    HOP = int(HOP_SECONDS * md["sample_rate"])
    if HOP < 1:
        # A zero hop cannot advance the window and would fail inside indexing.
        raise SystemExit("HOP_SECONDS is too small: hop must be at least one sample.")
    print(f"[Config] Window = {WIN_SAMPLES} samples | Hop = {HOP} | Sample rate = {md['sample_rate']} Hz")

    ds = EEGWindows(DATA_DIR, WIN_SAMPLES, HOP, use_npz=USE_NPZ, max_files=MAX_FILES, print_summary=True)
    if len(ds) == 0:
        raise SystemExit("No windows produced — check DATA_DIR / USE_NPZ / window settings.")

    # The dataset pads/crops to the largest channel count it found; the first
    # convolution needs exactly the channel count the model was defined with.
    model_channels = md.get("channels")
    if model_channels is not None and ds.channels != model_channels:
        raise SystemExit(
            f"Channel mismatch: data has {ds.channels} channels but the model expects {model_channels}."
        )

    # DataLoader: 0 workers on Windows avoids fork issues in Spyder.
    # - drop_last=False: with drop_last=True a dataset smaller than one batch
    #   produced no batches at all (np.concatenate then failed) and the tail
    #   windows were never embedded. The encoder is in eval mode, so a short
    #   final batch is fine.
    # - pin_memory=False: everything stays on the CPU in this script.
    # - shuffle=True: with the batch cap below this samples windows across
    #   all files; the saved rows are therefore in random order.
    dl = DataLoader(ds, batch_size=64, shuffle=True, num_workers=0, pin_memory=False, drop_last=False)

    # ---- Extract embeddings ----
    all_Z = []
    enc.eval()
    with torch.no_grad():
        for i, x in enumerate(dl):
            # x: (B, C, L) on CPU; encoder is on CPU by default in this script
            z, _ = enc(x)  # eager and TorchScript encoders are both expected to return (z, h)
            all_Z.append(z.cpu().numpy())
            if i >= 50:  # limit passes for speed; raise/remove for full run
                break

    E = np.concatenate(all_Z, axis=0)  # (n_windows, latent_dim)
    print(f"[Emb] Collected embeddings: {E.shape}")

    if SAVE_EMBEDDINGS:
        EMB_OUT_PATH.parent.mkdir(parents=True, exist_ok=True)
        np.save(EMB_OUT_PATH, E)
        print(f"[Emb] Saved to: {EMB_OUT_PATH}")

    # ---- PCA scatter ----
    Y = pca_2d_numpy(E)
    plt.figure(figsize=(5, 5))
    plt.scatter(Y[:, 0], Y[:, 1], s=6)
    plt.title("Encoder embeddings — PCA (first 2 components)")
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    plt.tight_layout()
    plt.show()

    # ---- Optional: toy linear probe ----
    if RUN_LINEAR_PROBE:
        _ = run_linear_probe(E, epochs=5, lr=1e-3)

    # ---- If .npz labels exist, print seizure_fraction summary ----
    if ds.labels_per_sec_exist and len(ds.seizure_fraction_estimates) > 0:
        print(f"[Meta] Mean seizure_fraction (from labels_sec): "
              f"{np.mean(ds.seizure_fraction_estimates):.3f} "
              f"(over {len(ds.seizure_fraction_estimates)} files)")
    else:
        print("[Meta] No labels_sec found in files (expected for .npy datasets).")
|
demo_embeddings.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c987bfaa11cf2366ae5b3aafbca0675acb361fe480d9a37094554df25b22f584
|
3 |
+
size 1671296
|