kartoun committed on
Commit 4cfc688 · verified · 1 Parent(s): ccb9e1e

Upload 2 files

DBbun_EEG_Encoder_Eval_Demo_v1.py ADDED
@@ -0,0 +1,262 @@
+ # dbbun_eeg_eval.py
+ # DBbun EEG — pretrained encoder evaluation & demo
+ # Run this in your "eeg" conda env (or other env with numpy, torch, matplotlib)
+
+ from pathlib import Path
+ import json
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import Dataset, DataLoader
+ import matplotlib.pyplot as plt
+
+ # ======================
+ # CONFIG — EDIT THESE
+ # ======================
+ # Use local model folder (recommended for Spyder)
+ MODEL_DIR = r"C:\DBbun\Code\EEG\pretraining\pretrained_out"  # has encoder_state.pt + model_def.json
+ # If you prefer TorchScript instead of state_dict, set this True and make sure encoder_traced.pt exists.
+ PREFER_TORCHSCRIPT = False
+
+ # Your validation data directory — can contain .npy OR .npz files (recurses)
+ DATA_DIR = r"d:\dbbun-eeg\data\valid"  # e.g., r"d:\dbbun-eeg\data\val" if NPZ
+ USE_NPZ = True  # set True if your files are .npz
+
+ # Windowing
+ HOP_SECONDS = 1.5  # larger hop => fewer windows (faster)
+ MAX_FILES = 10  # limit for quick demo (set None to use all)
+
+ # Embedding export
+ SAVE_EMBEDDINGS = True
+ EMB_OUT_PATH = Path(MODEL_DIR) / "demo_embeddings.npy"  # saved as (n_windows, latent_dim)
+
+ # Linear probe (optional; toy labels)
+ RUN_LINEAR_PROBE = True
+
+ # ======================
+ # Utilities
+ # ======================
+
+ def load_model_def(model_dir: str):
+     md_path = Path(model_dir) / "model_def.json"
+     if not md_path.exists():
+         raise FileNotFoundError(f"model_def.json not found at {md_path}")
+     return json.loads(md_path.read_text())
+
+ class Conv1dEncoder(nn.Module):
+     def __init__(self, in_channels, widths=(32, 64, 128), latent_dim=128, dropout=0.1):
+         super().__init__()
+         layers, prev = [], in_channels
+         for w in widths:
+             layers += [
+                 nn.Conv1d(prev, w, kernel_size=7, padding=3, stride=2),
+                 nn.BatchNorm1d(w),
+                 nn.GELU(),
+                 nn.Dropout(dropout),
+             ]
+             prev = w
+         self.conv = nn.Sequential(*layers)
+         self.pool = nn.AdaptiveAvgPool1d(1)
+         self.proj = nn.Linear(prev, latent_dim)
+
+     def forward(self, x):
+         h = self.conv(x)              # (B, W, L')
+         g = self.pool(h).squeeze(-1)  # (B, W)
+         z = self.proj(g)              # (B, latent)
+         return z, h
+
+ def load_encoder(model_dir: str, prefer_ts: bool = False):
+     md = load_model_def(model_dir)
+     if prefer_ts and (Path(model_dir) / "encoder_traced.pt").exists():
+         print("[Model] Loading TorchScript encoder_traced.pt")
+         enc = torch.jit.load(str(Path(model_dir) / "encoder_traced.pt"), map_location="cpu")
+         # TorchScript returns the scripted forward; assume it returns (z, h) as in training
+         scripted = True
+     else:
+         print("[Model] Loading state_dict encoder_state.pt")
+         enc = Conv1dEncoder(
+             in_channels=md["channels"],
+             widths=tuple(md["encoder_channels"]),
+             latent_dim=md["latent_dim"],
+             dropout=md["dropout"]
+         )
+         enc.load_state_dict(torch.load(Path(model_dir) / "encoder_state.pt", map_location="cpu"))
+         scripted = False
+     enc.eval()
+     window_samples = int(md["window_seconds"] * md["sample_rate"])
+     return enc, md, window_samples, scripted
+
+ # Dataset that supports .npy (memmap) or .npz (loads "eeg" key if present)
+ class EEGWindows(Dataset):
+     def __init__(self, folder, window_len, hop, use_npz=False, max_files=None, print_summary=True):
+         self.folder = Path(folder)
+         self.use_npz = use_npz
+         self.window = int(window_len)
+         self.hop = int(hop)
+         # gather files
+         pattern = "*.npz" if use_npz else "*.npy"
+         self.files = sorted(self.folder.rglob(pattern))
+         if max_files:
+             self.files = self.files[:int(max_files)]
+         if print_summary:
+             print(f"[Data] Found {len(self.files)} files under {self.folder}")
+         # build index
+         self.index = []
+         self.shapes = []
+         self.labels_per_sec_exist = False
+         self.seizure_fraction_estimates = []
+
+         for i, f in enumerate(self.files):
+             if use_npz:
+                 with np.load(f, allow_pickle=True) as z:
+                     if "eeg" in z.files:
+                         a = np.array(z["eeg"], dtype=np.float32)
+                     else:
+                         # fallback to first array in the container
+                         a = np.array(z[list(z.files)[0]], dtype=np.float32)
+                     # Try to detect labels
+                     if "labels_sec" in z.files:
+                         self.labels_per_sec_exist = True
+                         lbl = np.array(z["labels_sec"]).astype(np.uint8)
+                         self.seizure_fraction_estimates.append(float(lbl.mean()))
+             else:
+                 a = np.load(f, mmap_mode='r')  # (C, T)
+
+             if a.ndim != 2:
+                 continue
+             C, T = int(a.shape[0]), int(a.shape[1])
+             self.shapes.append((C, T))
+
+             if T >= self.window:
+                 starts = np.arange(0, T - self.window + 1, self.hop, dtype=int)
+                 self.index += [(i, int(s)) for s in starts]
+
+         self.channels = max((c for c, _ in self.shapes), default=1)
+
+         if print_summary:
+             total_windows = len(self.index)
+             print(f"[Data] Channels(max): {self.channels} | Windows: {total_windows}")
+             if self.labels_per_sec_exist and self.seizure_fraction_estimates:
+                 print(f"[Data] labels_sec present. Mean seizure_fraction across loaded files: "
+                       f"{np.mean(self.seizure_fraction_estimates):.3f}")
+
+     def __len__(self):
+         return len(self.index)
+
+     def __getitem__(self, idx):
+         fi, start = self.index[idx]
+         f = self.files[fi]
+         if self.use_npz:
+             with np.load(f, allow_pickle=True) as z:
+                 if "eeg" in z.files:
+                     a = z["eeg"]
+                 else:
+                     a = z[list(z.files)[0]]
+                 seg = np.asarray(a[:, start:start + self.window], dtype=np.float32)
+         else:
+             a = np.load(f, mmap_mode='r')
+             seg = np.asarray(a[:, start:start + self.window], dtype=np.float32)
+
+         # pad/crop channels to common number
+         C = seg.shape[0]
+         if C < self.channels:
+             pad = np.zeros((self.channels - C, seg.shape[1]), dtype=np.float32)
+             seg = np.concatenate([seg, pad], axis=0)
+         elif C > self.channels:
+             seg = seg[:self.channels]
+
+         # per-window z-score normalization
+         mu = seg.mean(axis=1, keepdims=True)
+         sd = seg.std(axis=1, keepdims=True) + 1e-6
+         seg = (seg - mu) / sd
+
+         return torch.from_numpy(seg)  # (C, L)
+
+ def pca_2d_numpy(E: np.ndarray):
+     """Return a 2D PCA projection computed with NumPy SVD."""
+     E0 = E - E.mean(0, keepdims=True)
+     U, S, Vt = np.linalg.svd(E0, full_matrices=False)
+     Y = E0 @ Vt[:2].T
+     return Y
+
+ def run_linear_probe(E: np.ndarray, epochs=5, lr=1e-3):
+     """Tiny demo head on toy labels derived from a PC1 threshold; replace with real labels if you have them."""
+     Y = pca_2d_numpy(E)
+     labels = (Y[:, 0] > Y[:, 0].mean()).astype(np.int64)
+     Z = torch.from_numpy(E).float()
+     y = torch.from_numpy(labels)
+     head = nn.Linear(E.shape[1], 2)
+     opt = torch.optim.AdamW(head.parameters(), lr=lr)
+     lossf = nn.CrossEntropyLoss()
+     for ep in range(1, epochs + 1):
+         opt.zero_grad(set_to_none=True)
+         logits = head(Z)
+         loss = lossf(logits, y)
+         loss.backward()
+         opt.step()
+         with torch.no_grad():
+             acc = (logits.argmax(1) == y).float().mean().item()
+         print(f"[Probe] Epoch {ep}/{epochs} - loss: {loss.item():.4f} | acc: {acc:.3f}")
+     return Y
+
+ # ======================
+ # Main
+ # ======================
+ if __name__ == "__main__":
+     torch.backends.cudnn.benchmark = True
+     try:
+         torch.set_float32_matmul_precision("medium")
+     except Exception:
+         pass
+
+     enc, md, WIN_SAMPLES, scripted = load_encoder(MODEL_DIR, PREFER_TORCHSCRIPT)
+     HOP = int(HOP_SECONDS * md["sample_rate"])
+     print(f"[Config] Window = {WIN_SAMPLES} samples | Hop = {HOP} | Sample rate = {md['sample_rate']} Hz")
+
+     ds = EEGWindows(DATA_DIR, WIN_SAMPLES, HOP, use_npz=USE_NPZ, max_files=MAX_FILES, print_summary=True)
+     if len(ds) == 0:
+         raise SystemExit("No windows produced — check DATA_DIR / USE_NPZ / window settings.")
+
+     # DataLoader: 0 workers on Windows avoids fork issues in Spyder
+     dl = DataLoader(ds, batch_size=64, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
+
+     # ---- Extract embeddings ----
+     all_Z = []
+     enc.eval()
+     with torch.no_grad():
+         for i, x in enumerate(dl):
+             # x: (B, C, L) on CPU; encoder is on CPU by default in this script
+             z, _ = enc(x)  # both the state_dict and TorchScript encoders return (z, h)
+             all_Z.append(z.cpu().numpy())
+             if i >= 50:  # limit passes for speed; raise/remove for full run
+                 break
+
+     E = np.concatenate(all_Z, axis=0)  # (n_windows, latent_dim)
+     print(f"[Emb] Collected embeddings: {E.shape}")
+
+     if SAVE_EMBEDDINGS:
+         EMB_OUT_PATH.parent.mkdir(parents=True, exist_ok=True)
+         np.save(EMB_OUT_PATH, E)
+         print(f"[Emb] Saved to: {EMB_OUT_PATH}")
+
+     # ---- PCA scatter ----
+     Y = pca_2d_numpy(E)
+     plt.figure(figsize=(5, 5))
+     plt.scatter(Y[:, 0], Y[:, 1], s=6)
+     plt.title("Encoder embeddings — PCA (first 2 components)")
+     plt.xlabel("PC1")
+     plt.ylabel("PC2")
+     plt.tight_layout()
+     plt.show()
+
+     # ---- Optional: toy linear probe ----
+     if RUN_LINEAR_PROBE:
+         _ = run_linear_probe(E, epochs=5, lr=1e-3)
+
+     # ---- If .npz labels exist, print seizure_fraction summary ----
+     if ds.labels_per_sec_exist and len(ds.seizure_fraction_estimates) > 0:
+         print(f"[Meta] Mean seizure_fraction (from labels_sec): "
+               f"{np.mean(ds.seizure_fraction_estimates):.3f} "
+               f"(over {len(ds.seizure_fraction_estimates)} files)")
+     else:
+         print("[Meta] No labels_sec found in files (expected for .npy datasets).")
demo_embeddings.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c987bfaa11cf2366ae5b3aafbca0675acb361fe480d9a37094554df25b22f584
+ size 1671296
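
For a quick sanity check, a minimal sketch (not part of the commit) of how the uploaded demo_embeddings.npy could be inspected after pulling it through Git LFS; the file name of the helper script is hypothetical, and it assumes the array was written by DBbun_EEG_Encoder_Eval_Demo_v1.py with shape (n_windows, latent_dim) and reuses the same NumPy-SVD PCA as that script.

# inspect_demo_embeddings.py  (illustrative sketch; name and usage are assumptions)
import numpy as np
import matplotlib.pyplot as plt

E = np.load("demo_embeddings.npy")  # expected shape: (n_windows, latent_dim)
print("embeddings:", E.shape, E.dtype)

# Same NumPy-SVD PCA as pca_2d_numpy() in the demo script
E0 = E - E.mean(0, keepdims=True)
_, _, Vt = np.linalg.svd(E0, full_matrices=False)
Y = E0 @ Vt[:2].T

plt.figure(figsize=(5, 5))
plt.scatter(Y[:, 0], Y[:, 1], s=6)
plt.title("demo_embeddings.npy (PCA, first 2 components)")
plt.xlabel("PC1")
plt.ylabel("PC2")
plt.tight_layout()
plt.show()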