SPDTransNet
Repository for the SPDTransNet model, a Transformer-based architecture to analyze sequences of SPD matrices without loss of their Riemannian structure, and said model's SP-MHA (or Structure-Preserving Multihead Attention) component.

This repository is the official implementation of the SPDTransNet model, as presented in the paper Structure-Preserving Transformers for Sequences of SPD Matrices by Mathieu Seraphim, Alexis Lechervy, Florian Yger, Luc Brun and Olivier Etard - Conference EUSIPCO 2024, France (available on arXiv).
A previous, non-structure-preserving version of this method was presented in the paper Temporal Sequences of EEG Covariance Matrices for Automated Sleep Stage Scoring with Attention Mechanisms by Mathieu Seraphim, Paul Dequidt, Alexis Lechervy, Florian Yger, Luc Brun and Olivier Etard - Conference CAIP 2023, Cyprus (official Version of Record here - Accepted Manuscript here).
Other references can be found in this repository's main paper.
Important: due to space constraints, some specifics were left out of the paper. To selectively read the relevant sections, please prioritize sections marked with an asterisk, as well as the similarly marked subsections within them.
A spreadsheet with our model's results (as summarized in the paper) can be found here.
For additional information, feel free to contact us.
<h2 style="text-align: center;">Brief overview</h2>The SPDTransNet model (see figure above) is designed to adapt Transformers to the analysis of timeseries of Symmetric Positive Definite (SPD) matrices. More specifically, it takes as input a sequence of elements, and outputs a classification for the central element of the sequence only (the surrounding elements providing contextual information). Each element is described as a (potentially multichannel) sequence of SPD matrices (as seen in the figure), bijectively tokenized to be processed by the model. A first intra-element step extracts element-wise features, and a second inter-element step compares said features. Both steps are based on Transformer encoders, utilizing our SP-MHA self-attention mechanism described below.
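The bijective tokenization mentioned above can be sketched as follows. This is a hypothetical illustration, not the repository's exact code: it assumes a log-Euclidean mapping (matrix logarithm via eigendecomposition) followed by a weighted half-vectorization, a standard way to flatten an SPD matrix into a 1D token without structural loss; the $\sqrt{2}$ weighting on off-diagonal entries makes the map an isometry for the Frobenius norm.

```python
import numpy as np

def spd_to_token(spd: np.ndarray) -> np.ndarray:
    """Bijectively map an SPD matrix to a 1D token (hypothetical sketch).

    Steps: matrix logarithm via eigendecomposition, then weighted
    upper-triangular half-vectorization. An n x n matrix yields a
    token of dimension n * (n + 1) / 2.
    """
    w, v = np.linalg.eigh(spd)           # eigenvalues of an SPD matrix are > 0
    log_spd = (v * np.log(w)) @ v.T      # matrix logarithm: V diag(log w) V^T
    n = log_spd.shape[0]
    iu = np.triu_indices(n, k=1)         # strict upper triangle
    off_diag = np.sqrt(2.0) * log_spd[iu]
    # Diagonal kept as-is; off-diagonal weighted so the map is an isometry.
    return np.concatenate([np.diag(log_spd), off_diag])
```

With this convention, a $26 \times 26$ matrix yields a token of dimension $26 \times 27 / 2 = 351$, matching the token dimensions reported in the results table below.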
<h3 id="SPMHA" style="text-align: center;">Structure-Preserving Multihead Attention (SP-MHA)</h3>The SP-MHA block generates $h$ attention maps from projected tokens in the same way as the original Linear MHA, but it then combines them and applies them to the unaltered tokens (i.e. 1D vectors) in the V(alue) tensor.
The tokens in V only undergo linear combinations weighted by the attention maps, without linear mappings or concatenations. Hence, this self-attention mechanism does not alter the geometric properties of the inputs, as long as:
- Said inputs may be described in vector form without loss of structural information,
- Linear combinations between these tokens do not cause a loss of structural information.
Both of these are true for our tokenized SPD matrices (cf. paper).
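A minimal PyTorch sketch of this idea follows. It is an illustration, not the repository's implementation: standard scaled dot-product maps stand in for the paper's Linear-MHA-style map computation, and the learned per-head combination (`combine`) is an assumed design. The key property it demonstrates is that V receives no projection and no per-head concatenation, so each output token is only a linear combination of the raw input tokens.

```python
import torch
import torch.nn as nn

class SPMHA(nn.Module):
    """Structure-Preserving Multihead Attention (hypothetical sketch).

    h attention maps are computed from projected queries and keys,
    combined into a single map, then applied to the UNALTERED value
    tokens: no value projection, no head concatenation.
    """
    def __init__(self, d_model: int, n_heads: int, d_head: int):
        super().__init__()
        self.n_heads, self.d_head = n_heads, d_head
        self.q_proj = nn.Linear(d_model, n_heads * d_head)
        self.k_proj = nn.Linear(d_model, n_heads * d_head)
        # Learned combination of the h attention maps into one map.
        self.combine = nn.Linear(n_heads, 1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, _ = x.shape
        q = self.q_proj(x).view(B, N, self.n_heads, self.d_head).transpose(1, 2)
        k = self.k_proj(x).view(B, N, self.n_heads, self.d_head).transpose(1, 2)
        # h attention maps, shape (B, h, N, N)
        maps = torch.softmax(q @ k.transpose(-2, -1) / self.d_head ** 0.5, dim=-1)
        # Combine the h maps into one, shape (B, N, N)
        combined = self.combine(maps.permute(0, 2, 3, 1)).squeeze(-1)
        # Apply to the unaltered tokens: pure linear combinations of inputs.
        return combined @ x
```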
<h3 style="text-align: center;">Application to EEG Sleep Staging</h3>EEG sleep staging refers to the subdivision of a set of concurrent 1D electrophysiological signals (including EEG) into fixed-length windows called "epochs", which are manually labeled ("scored") with a given sleep stage.
Theoretical justifications for this approach can be found in the paper.
In this repository, we use the data from the Montreal Archive of Sleep Studies' SS3 dataset (MASS-SS3), comprising 62 healthy subjects and scored using the AASM scoring rules, with 30s epochs and 5 sleep stages: Awake, REM sleep, and the N1, N2 and N3 non-REM sleep stages.
Each epoch is subdivided into 30 1s segments, to capture relevant events of around 1 to 2 seconds in duration. As these events also exhibit distinctive frequential characteristics, we filter the signals to isolate frequency bands.
Our default preprocessing strategy uses 8 EEG signals and 6 distinct frequency bands - giving us, for each epoch, a timeseries of $S$ = 30 SPD matrices of size 9 $\times$ 9 over $C$ = 7 channels after preprocessing (cf. figure).
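The segment-wise covariance extraction can be sketched as follows. This is a simplified, hypothetical illustration of one channel (one frequency band): it splits a 30s multichannel epoch into 1s segments and computes one sample covariance matrix per segment, with a small ridge to keep matrices strictly positive definite. The repository's actual pipeline (filtering, signal augmentation yielding the $9 \times 9$ size, whitening) is not reproduced here.

```python
import numpy as np

def epoch_to_cov_series(epoch: np.ndarray, n_segments: int = 30) -> np.ndarray:
    """Hypothetical sketch: one covariance matrix per 1 s segment.

    epoch: (n_signals, n_samples) array of band-filtered signals
           for a single 30 s epoch; n_samples must divide evenly
           into n_segments.
    Returns: (n_segments, n_signals, n_signals) SPD matrix series.
    """
    n_signals, n_samples = epoch.shape
    segments = epoch.reshape(n_signals, n_segments, n_samples // n_segments)
    covs = np.empty((n_segments, n_signals, n_signals))
    for s in range(n_segments):
        x = segments[:, s, :]
        x = x - x.mean(axis=1, keepdims=True)          # center each signal
        covs[s] = x @ x.T / x.shape[1]                  # sample covariance
        covs[s] += 1e-6 * np.eye(n_signals)             # ridge: ensure SPD
    return covs
```

Repeating this per frequency band yields the multichannel timeseries of SPD matrices described above.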
<h3 id="results" style="text-align: center;">Our results</h3>

| # | Model | MF1 | Macro Acc. | N1 F1 | Valid. metric | Token dim. $d(m)$ | # Feat. Tokens $t$ |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| 1 | SPDTransNet, $L=13$ | 81.06 $\pm$ 3.49 | 84.87 $\pm$ 2.47 | 60.39 $\pm$ 6.77 | MF1 | 351 ($m = 26$) | 7 |
| 2 | SPDTransNet, $L=21$ | 81.24 $\pm$ 3.29 | 84.40 $\pm$ 2.61 | 60.50 $\pm$ 6.18 | MF1 | 351 ($m = 26$) | 10 |
| 3 | SPDTransNet, $L=29$ | 80.83 $\pm$ 3.40 | 84.29 $\pm$ 2.65 | 60.35 $\pm$ 6.01 | N1 F1 | 351 ($m = 26$) | 5 |
| 4 | Classic MHA | 80.82 $\pm$ 3.40 | 84.60 $\pm$ 2.95 | 60.16 $\pm$ 7.20 | MF1 | 351 ($m = 26$) | 10 |
| 5 | DeepSleepNet | 78.14 $\pm$ 4.12 | 80.05 $\pm$ 3.47 | 53.52 $\pm$ 8.24 | N/A | N/A | N/A |
| 6 | IITNet | 78.48 $\pm$ 3.15 | 81.88 $\pm$ 2.89 | 56.01 $\pm$ 6.54 | N/A | N/A | N/A |
| 7 | GraphSleepNet | 75.58 $\pm$ 3.75 | 79.75 $\pm$ 3.41 | 50.80 $\pm$ 8.06 | N/A | N/A | N/A |
| 8 | Dequidt et al. | 81.04 $\pm$ 3.26 | 82.59 $\pm$ 3.45 | 58.42 $\pm$ 6.09 | N/A | N/A | N/A |
| 9 | Seraphim et al. | 79.78 $\pm$ 4.56 | 81.76 $\pm$ 4.61 | 58.43 $\pm$ 6.41 | MF1 | Concatenation | 1 |
The reported results for models from the literature differ from those published by the original authors, as we re-trained their models using our methodology. A more detailed account of our results can be found here. More information can be found in the paper.
<h3 id="caveats" style="text-align: center;">A few caveats</h3>In this repository, the term "epoch" may be used in two contexts:
- In reference to the model, it corresponds to a training cycle (standard Deep Learning nomenclature).
- In reference to the data, it corresponds to a period of 30s in the input signal, and the features describing it (standard sleep medicine nomenclature).
This implementation has been built with modularity in mind. However, the approach taken is fairly nonstandard:
- Folder names start with an underscore to keep them in sorted order - this goes against Python coding conventions (which treat such names as "protected"), but doesn't impact computations,
- The way classes are instantiated from YAML files (the "Wrapper" classes) is a complete workaround to how the utilized functions were supposed to work.
Consequently, the code remains modular and adaptable, but might be challenging to reuse in another context.
EDIT / IMPORTANT: I recently got contacted by someone who spent way too much time trying to reuse my code in another context. I may have been too subtle in the above sentence, so let me be clear: the structure of this project is an overengineered piece of garbage. It works well enough as a standalone project, but making it compatible with any decent framework would be a nightmare. I beg of you, if you want to reuse my code, please contact me by e-mail using the link at the top of this page. You'll save yourself a world of trouble.
<h2 style="text-align: center;">Instructions</h2>The following instructions should be sufficient to reproduce the results presented in the paper. To expand upon them, see the further documentation linked below.
Please note that the MASS dataset
