Signal-JEPA

Self-supervised pre-trained weights for the Signal-JEPA foundation model from Guetschel et al. (2024), packaged for use with braindecode. See the full API reference in the docs: braindecode.models.SignalJEPA.

The model was pre-trained on the Lee2019 dataset (62 EEG channels in the 10-10 layout, sampled at 128 Hz). The repo ships the weights together with a config.json so they can be loaded in one line with YourModelClass.from_pretrained(repo_id, ...).
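
Since the pre-training data were sampled at 128 Hz, a window length in seconds maps to a sample count (the n_times argument) by simple multiplication. The helper below is a minimal sketch of that arithmetic; window_to_n_times is a hypothetical name, not part of braindecode, and it assumes your recording has been resampled to the pre-training rate.

```python
PRETRAIN_SFREQ_HZ = 128  # sampling rate stated for the pre-training data


def window_to_n_times(duration_s: float, sfreq_hz: float = PRETRAIN_SFREQ_HZ) -> int:
    """Convert a window length in seconds to a sample count (n_times).

    Hypothetical helper for illustration; not a braindecode function.
    """
    if duration_s <= 0 or sfreq_hz <= 0:
        raise ValueError("duration_s and sfreq_hz must be positive")
    return int(round(duration_s * sfreq_hz))


n_times = window_to_n_times(2.0)  # 2 s at 128 Hz
print(n_times)  # 256
```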

Available checkpoints

Two variants are published:

| repo ID | channel embedding included | when to use |
|---|---|---|
| braindecode/signal-jepa | ✓ 62-row _ChannelEmbedding aligned with the pre-training layout | your recording channels are a subset (by name, case-insensitive) of the 62 pre-training channels, and you want to reuse the learned spatial embeddings |
| braindecode/signal-jepa_without-chans | ✗ only the SSL backbone (feature encoder + transformer) | your channels are not a subset of the pre-training set, or you prefer to train channel embeddings from scratch |

If you are unsure, start with braindecode/signal-jepa_without-chans: it always works, regardless of your electrode layout.
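
The selection rule can be sketched as a case-insensitive subset test on channel names. This is only an illustration: choose_checkpoint is a hypothetical helper, and the short channel list is a placeholder, not the actual 62-channel pre-training layout.

```python
FULL_REPO = "braindecode/signal-jepa"
BACKBONE_ONLY_REPO = "braindecode/signal-jepa_without-chans"


def choose_checkpoint(recording_chs, pretrain_chs):
    """Pick a repo ID from a case-insensitive subset test on channel names.

    Hypothetical helper for illustration; not a braindecode function.
    """
    pretrain = {name.lower() for name in pretrain_chs}
    is_subset = all(name.lower() in pretrain for name in recording_chs)
    return FULL_REPO if is_subset else BACKBONE_ONLY_REPO


# Placeholder names for illustration only, NOT the real 62-channel list.
example_pretrain = ["Fp1", "Fp2", "Cz", "Pz", "Oz"]

print(choose_checkpoint(["cz", "PZ"], example_pretrain))  # braindecode/signal-jepa
print(choose_checkpoint(["Cz", "T9"], example_pretrain))  # braindecode/signal-jepa_without-chans
```

Even when the subset test passes, the backbone-only checkpoint remains a valid choice if you would rather train the channel embeddings from scratch.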

Quick start

Base model (pre-training architecture)

The base model outputs contextual features, not class predictions. Use it for downstream feature extraction or further SSL.

from braindecode.models import SignalJEPA

# With the pre-trained channel embeddings (recording channels ⊂ pre-train set):
model = SignalJEPA.from_pretrained("braindecode/signal-jepa")

# Or: with your own channels, kept aligned to the pre-training embedding table
model = SignalJEPA.from_pretrained(
    "braindecode/signal-jepa",
    chs_info=raw.info["chs"],           # subset of the 62 pre-training channels
    channel_embedding="pretrain_aligned",
)

# Or: without pre-trained channel embeddings (any electrode layout):
model = SignalJEPA.from_pretrained(
    "braindecode/signal-jepa_without-chans",
    chs_info=raw.info["chs"],
    strict=False,  # the channel-embedding weight is intentionally missing
)

Downstream architectures

Three classification architectures are introduced in the paper:

  • a) Contextual: uses the full transformer encoder
  • b) Post-local: discards the transformer; spatial convolution after local features
  • c) Pre-local: discards the transformer; spatial convolution before local features

All three add a freshly-initialized classification head on top of the SSL backbone. The head is not part of the checkpoint and will be trained from scratch during fine-tuning; pass strict=False so from_pretrained does not complain about those missing keys.

from braindecode.models import (
    SignalJEPA_Contextual,
    SignalJEPA_PreLocal,
    SignalJEPA_PostLocal,
)

# a) Contextual: keeps the transformer
model = SignalJEPA_Contextual.from_pretrained(
    "braindecode/signal-jepa",          # or "braindecode/signal-jepa_without-chans"
    n_times=256,                         # e.g. 2 s at 128 Hz
    n_outputs=4,
    strict=False,                        # ignore un-trained classification head
)

# b) Post-local: transformer discarded
model = SignalJEPA_PostLocal.from_pretrained(
    "braindecode/signal-jepa_without-chans",
    n_chans=19,
    n_times=256,
    n_outputs=4,
    strict=False,
)

# c) Pre-local: transformer discarded
model = SignalJEPA_PreLocal.from_pretrained(
    "braindecode/signal-jepa_without-chans",
    n_chans=19,
    n_times=256,
    n_outputs=4,
    strict=False,
)
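
To see why strict=False is needed, it helps to compare the parameter names a model expects with those stored in a checkpoint. PyTorch's load_state_dict raises on any mismatch when strict=True and tolerates it when strict=False. The sketch below reproduces that comparison in plain Python, assuming from_pretrained behaves the same way; diff_state_keys and all key names are hypothetical and do not reflect the real parameter names.

```python
def diff_state_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, as a strict load would compare them.

    Hypothetical helper for illustration; not a braindecode or PyTorch function.
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)     # in model, absent from checkpoint
    unexpected = sorted(checkpoint_keys - model_keys)  # in checkpoint, absent from model
    return missing, unexpected


# Hypothetical parameter names, for illustration only.
model_keys = ["feature_encoder.weight", "transformer.weight", "final_layer.weight"]
checkpoint_keys = ["feature_encoder.weight", "transformer.weight"]

missing, unexpected = diff_state_keys(model_keys, checkpoint_keys)
print(missing)     # ['final_layer.weight']
print(unexpected)  # []
```

A non-empty missing list that contains only the classification head is the expected outcome here; anything else missing would suggest a mismatch between the architecture and the checkpoint.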

See the braindecode tutorial Fine-tuning a Foundation Model (Signal-JEPA) for a complete example including layer freezing and training with braindecode's skorch-based EEGClassifier.

Channel embedding modes

SignalJEPA and SignalJEPA_Contextual accept a channel_embedding kwarg:

  • "scratch" (default): the _ChannelEmbedding table has one row per user channel, initialized from chs_info. Compatible with the without-chans checkpoint.
  • "pretrain_aligned": the table has 62 rows in the pre-training order, and the forward pass indexes into the rows that match your chs_info (matched by channel name, case-insensitive). Compatible with the full checkpoint.

from_pretrained picks the right mode automatically based on the checkpoint's config.json; override with the channel_embedding= kwarg if needed.
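
The row lookup behind the "pretrain_aligned" mode can be sketched as a case-insensitive name-to-index mapping. This is an illustration of the idea rather than braindecode's implementation: pretrain_aligned_indices is a hypothetical helper, and the channel list is a placeholder, not the real 62-channel layout.

```python
def pretrain_aligned_indices(user_chs, pretrain_chs):
    """Map each user channel to its row in the pre-training-ordered table.

    Hypothetical helper for illustration; not a braindecode function.
    """
    row_of = {name.lower(): i for i, name in enumerate(pretrain_chs)}
    missing = [name for name in user_chs if name.lower() not in row_of]
    if missing:
        raise ValueError(f"channels not in the pre-training set: {missing}")
    return [row_of[name.lower()] for name in user_chs]


# Placeholder names for illustration only, NOT the real 62-channel list.
example_pretrain = ["Fp1", "Fp2", "Cz", "Pz", "Oz"]

print(pretrain_aligned_indices(["pz", "CZ"], example_pretrain))  # [3, 2]
```

The returned indices follow the order of your own channels, so the selected embedding rows line up with the channel axis of your recording.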

Citation

@article{guetschel2024sjepa,
  title   = {S-JEPA: towards seamless cross-dataset transfer
             through dynamic spatial attention},
  author  = {Guetschel, Pierre and Moreau, Thomas and Tangermann, Michael},
  journal = {arXiv preprint arXiv:2403.11772},
  year    = {2024},
}