# PhysioWave: A Multi-Scale Wavelet-Transformer for Physiological Signal Representation
<div align="center">

Official PyTorch implementation of PhysioWave, accepted at NeurIPS 2025.

A novel wavelet-based architecture for physiological signal processing that leverages adaptive multi-scale decomposition and frequency-guided masking to advance self-supervised learning.

</div>

## 🌟 Key Features
**✨ Learnable Wavelet Decomposition**
- Adaptive multi-resolution analysis
- Soft gating mechanism for optimal wavelet selection

**📊 Frequency-Guided Masking**
- Novel masking strategy that prioritizes high-energy components
- Superior to random masking for signal representation

**🔗 Cross-Scale Feature Fusion**
- Attention-based fusion across decomposition levels
- Hierarchical feature integration

**🧠 Multi-Modal Support**
- Unified framework for ECG and EMG signals
- Extensible to other physiological signals

**📈 Large-Scale Pretraining**
- Models pretrained on 182 GB of ECG and 823 GB of EMG data

## 🏗️ Model Architecture
<div align="center"> <img src="fig/model.png" alt="PhysioWave Architecture" width="90%"> </div>

### Pipeline Overview
The PhysioWave pretraining pipeline consists of five key stages:
1. **Wavelet Initialization**: Standard wavelet functions (e.g., `db6`, `sym4`) initialize learnable low-pass and high-pass filters
2. **Multi-Scale Decomposition**: Adaptive wavelet decomposition produces multi-scale frequency-band representations
3. **Patch Embedding**: Decomposed features are split into spatio-temporal patches with FFT-based importance scoring
4. **Masked Encoding**: High-scoring patches are masked and processed through Transformer layers with rotary position embeddings
5. **Reconstruction**: A lightweight decoder reconstructs the masked patches for self-supervised learning
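The first two stages can be sketched in PyTorch as a pair of learnable depthwise filters applied recursively, with a soft gate over decomposition levels. This is a minimal illustration under stated assumptions, not the repository's implementation: the class name, the sharing of one filter pair across levels, and the placement of the gate are all simplifications.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LearnableWaveletDecomposition(nn.Module):
    """Sketch: learnable low/high-pass filters applied recursively with a soft gate."""

    def __init__(self, in_channels: int, kernel_size: int = 16, levels: int = 3):
        super().__init__()
        self.levels = levels
        # One depthwise filter pair shared across levels (a simplification)
        self.low_pass = nn.Parameter(torch.randn(in_channels, 1, kernel_size) * 0.1)
        self.high_pass = nn.Parameter(torch.randn(in_channels, 1, kernel_size) * 0.1)
        # Soft gate weighting the detail band of each decomposition level
        self.gate = nn.Parameter(torch.zeros(levels))

    def forward(self, sig):  # sig: (N, C, T)
        details = []
        weights = torch.softmax(self.gate, dim=0)
        for lvl in range(self.levels):
            pad = self.low_pass.shape[-1] // 2
            # Stride-2 depthwise convolutions: one approximation and one detail band
            lo = F.conv1d(sig, self.low_pass, padding=pad, groups=sig.shape[1], stride=2)
            hi = F.conv1d(sig, self.high_pass, padding=pad, groups=sig.shape[1], stride=2)
            details.append(weights[lvl] * hi)
            sig = lo  # recurse on the approximation band
        return sig, details  # coarse approximation + gated detail bands

x = torch.randn(2, 12, 2048)  # a batch of 12-lead ECG windows
approx, details = LearnableWaveletDecomposition(12)(x)
print(approx.shape)  # time length roughly halves at each level
```

Each level halves the temporal resolution, so the stacked detail bands form the multi-scale frequency representation that later stages patchify.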
### Core Components
| Component | Description |
|-----------|-------------|
| 🌊 Learnable Wavelet Decomposition | Adaptively selects optimal wavelet bases for input signals |
| 📐 Multi-Scale Feature Reconstruction | Hierarchical decomposition with soft gating between scales |
| 🎯 Frequency-Guided Masking | Identifies and masks high-energy patches for self-supervised learning |
| 🔄 Transformer Encoder/Decoder | Processes masked patches with rotary position embeddings |
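One plausible reading of frequency-guided masking, mirroring the `--mask_ratio` and `--importance_ratio` pretraining flags, is sketched below. This is an assumption about the mechanism, not the repository's code: a fraction of the masked patches is chosen by spectral energy and the remainder uniformly at random.

```python
import torch

def frequency_guided_mask(patches, mask_ratio=0.7, importance_ratio=0.7):
    """Sketch of frequency-guided masking over patches of shape (N, P, D)."""
    N, P, D = patches.shape
    # FFT-based importance: total spectral energy per patch
    energy = torch.fft.rfft(patches, dim=-1).abs().pow(2).sum(dim=-1)  # (N, P)

    n_mask = int(P * mask_ratio)             # patches to mask in total
    n_imp = int(n_mask * importance_ratio)   # chosen by energy, rest random

    mask = torch.zeros(N, P, dtype=torch.bool)
    for i in range(N):
        # 1) mask the highest-energy patches first
        imp_idx = energy[i].topk(n_imp).indices
        mask[i, imp_idx] = True
        # 2) fill the remainder with randomly chosen unmasked patches
        remaining = (~mask[i]).nonzero(as_tuple=True)[0]
        rand_idx = remaining[torch.randperm(len(remaining))[: n_mask - n_imp]]
        mask[i, rand_idx] = True
    return mask  # True = masked / to be reconstructed

m = frequency_guided_mask(torch.randn(2, 64, 32))
print(m.sum(dim=1))  # 44 masked patches per sample (70% of 64)
```

The random remainder keeps the pretext task from degenerating into always hiding the same high-energy regions.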
## 📊 Performance Highlights

### Benchmark Results
<div align="center">

| Task | Dataset | Metric | Performance |
|------|---------|--------|-------------|
| ECG Arrhythmia | PTB-XL | Accuracy | 73.1% |
| ECG Multi-Label | CPSC 2018 | F1-Micro | 77.1% |
| ECG Multi-Label | Shaoxing | F1-Micro | 94.6% |
| EMG Gesture | EPN-612 | Accuracy | 94.5% |
</div>

### Multi-Label Classification Detailed Metrics
<details>
<summary><b>CPSC 2018 Dataset (9-Class Multi-Label)</b></summary>

<div align="center">

| Metric | Micro-Average | Macro-Average |
|--------|---------------|---------------|
| Precision | 0.7389 | 0.6173 |
| Recall | 0.8059 | 0.6883 |
| F1-Score | 0.7709 | 0.6500 |
| AUROC | 0.9584 | 0.9280 |
</div>

**Dataset Details:**
- 9 official diagnostic classes (SNR, AF, IAVB, LBBB, RBBB, PAC, PVC, STD, STE)
- 12-lead ECG signals at 500 Hz
- Record-level split to prevent data leakage
</details>

<details>
<summary><b>Chapman-Shaoxing Dataset (4-Class Multi-Label)</b></summary>

<div align="center">

| Metric | Micro-Average | Macro-Average |
|--------|---------------|---------------|
| Precision | 0.9389 | 0.9361 |
| Recall | 0.9536 | 0.9470 |
| F1-Score | 0.9462 | 0.9413 |
| AUROC | 0.9949 | 0.9930 |

</div>

**Dataset Details:**
- 4 merged diagnostic classes (SB, AFIB, GSVT, SR)
- 12-lead ECG signals at 500 Hz
- Balanced multi-label distribution

</details>
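The tables above report both micro- and macro-averages; the difference is whether label decisions are pooled globally or averaged per class. A small NumPy illustration on toy data (not values from the paper):

```python
import numpy as np

def f1_scores(y_true, y_pred):
    """Micro- and macro-averaged F1 for multi-hot labels (illustrative helper)."""
    tp = (y_true * y_pred).sum(axis=0).astype(float)        # per-class true positives
    fp = ((1 - y_true) * y_pred).sum(axis=0).astype(float)  # per-class false positives
    fn = (y_true * (1 - y_pred)).sum(axis=0).astype(float)  # per-class false negatives
    # Micro: pool all decisions across classes, then compute one F1
    micro = 2 * tp.sum() / (2 * tp.sum() + fp.sum() + fn.sum())
    # Macro: compute F1 per class, then take the unweighted mean
    macro = np.mean(2 * tp / (2 * tp + fp + fn))
    return micro, macro

# Toy multi-label example: 4 samples, 3 classes
y_true = np.array([[1, 0, 0], [1, 1, 0], [0, 1, 1], [1, 0, 0]])
y_pred = np.array([[1, 0, 0], [1, 0, 0], [0, 1, 1], [0, 0, 1]])

micro, macro = f1_scores(y_true, y_pred)
print(f"micro={micro:.4f}  macro={macro:.4f}")  # micro=0.7273  macro=0.7111
```

Macro-averaging weights rare classes equally with common ones, which is why it trails micro-averaging on the imbalanced CPSC 2018 labels above.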
## 💾 Pretrained Models
<div align="center">

**📥 Download Pretrained Models**

</div>

| Model | Parameters | Training Data | Description |
|-------|------------|---------------|-------------|
| `ecg.pth` | 14M | 182 GB ECG | ECG pretrained model |
| `emg.pth` | 5M | 823 GB EMG | EMG pretrained model |
**Usage:**

```python
import torch

# Load pretrained weights (map to CPU if no GPU is available)
checkpoint = torch.load('ecg.pth', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
```
## 🚀 Quick Start

### Prerequisites
```bash
# Clone repository
git clone https://github.com/ForeverBlue816/PhysioWave.git
cd PhysioWave

# Create conda environment
conda create -n physiowave python=3.11
conda activate physiowave

# Install PyTorch (CUDA 12.1 build)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# Install remaining requirements
pip install -r requirements.txt
```
## 📦 Data Preparation
<details>
<summary><b>Dataset Download Links</b></summary>

**ECG Datasets**
- PTB-XL Database - 21,837 clinical ECG records
- MIMIC-IV-ECG - 800K+ ECG recordings
- PhysioNet Challenge 2021 - Multi-database ECG
- CPSC 2018 - Arrhythmia classification challenge
- Chapman-Shaoxing - Large-scale 12-lead ECG
**EMG Datasets**
- EPN-612 Dataset - sEMG hand-gesture recordings from 612 users
- NinaPro Database DB6 - HD-sEMG recordings
**HDF5 Structure**
```python
# Single-label classification
{
    'data': (N, C, T),   # Signal data: float32
    'label': (N,)        # Labels: int64
}

# Multi-label classification
{
    'data': (N, C, T),   # Signal data: float32
    'label': (N, K)      # Multi-hot labels: float32
}
```
**Dimensions:**
- `N` = number of samples
- `C` = number of channels
- `T` = time points
- `K` = number of classes (multi-label only)
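Files in this layout can be written with `h5py`. The sketch below covers the multi-label case; the file name and array contents are arbitrary examples, only the dataset keys (`data`, `label`) and dtypes follow the layout above.

```python
import h5py
import numpy as np

# Example dimensions: samples, channels, time points, classes
N, C, T, K = 100, 12, 2048, 9

data = np.random.randn(N, C, T).astype(np.float32)          # (N, C, T) float32
label = (np.random.rand(N, K) > 0.8).astype(np.float32)     # (N, K) multi-hot

with h5py.File('toy_multilabel.h5', 'w') as f:
    f.create_dataset('data', data=data, compression='gzip')
    f.create_dataset('label', data=label)

# Verify the stored layout
with h5py.File('toy_multilabel.h5', 'r') as f:
    shape, dtype = f['data'].shape, f['data'].dtype
print(shape, dtype)  # (100, 12, 2048) float32
```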
**Signal Specifications**
| Signal | Channels | Length | Sampling Rate | Normalization |
|--------|----------|--------|---------------|---------------|
| ECG | 12 | 2048 | 500 Hz | MinMax [-1, 1] or Z-score |
| EMG | 8 | 1024 | 200-2000 Hz | Max-abs or Z-score |
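The two normalization options in the table can be applied per channel as below. These helpers are illustrative sketches, not the repository's preprocessing code.

```python
import numpy as np

def minmax_norm(x, eps=1e-8):
    """Scale each channel of a (C, T) array to [-1, 1], as listed for ECG."""
    lo = x.min(axis=-1, keepdims=True)
    hi = x.max(axis=-1, keepdims=True)
    return 2 * (x - lo) / (hi - lo + eps) - 1

def zscore_norm(x, eps=1e-8):
    """Per-channel zero mean / unit variance (the Z-score option)."""
    return (x - x.mean(axis=-1, keepdims=True)) / (x.std(axis=-1, keepdims=True) + eps)

ecg = np.random.randn(12, 2048).astype(np.float32)  # one 12-lead window
normed = minmax_norm(ecg)
z = zscore_norm(ecg)
print(normed.min(), normed.max())  # bounded by [-1, 1]
```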
</details>

## 🔄 Preprocessing Examples
<details>
<summary><b>ECG Preprocessing (PTB-XL - Single-Label)</b></summary>

```bash
# Download PTB-XL dataset
wget -r -N -c -np https://physionet.org/files/ptb-xl/1.0.3/

# Preprocess for single-label classification
python ECG/ptbxl_finetune.py
```
**Output files:**
- `train.h5` - Training data with shape `(N, 12, 2048)`
- `val.h5` - Validation data
- `test.h5` - Test data
**Label format:** `(N,)` with 5 superclasses (NORM, MI, STTC, CD, HYP)

</details>

<details>
<summary><b>ECG Preprocessing (CPSC 2018 - Multi-Label)</b></summary>

```bash
# Preprocess CPSC 2018 dataset
python ECG/cpsc_multilabel.py
```
**Output files:**
- `cpsc_9class_train.h5` - Training data
- `cpsc_9class_val.h5` - Validation data
- `cpsc_9class_test.h5` - Test data
- `cpsc_9class_info.json` - Dataset metadata
- `label_map.json` - Class mappings
- `record_splits.json` - Record-level split info
**Label format:** `(N, 9)` with 9 official CPSC classes

</details>

<details>
<summary><b>ECG Preprocessing (Chapman-Shaoxing - Multi-Label)</b></summary>

```bash
# Preprocess Chapman-Shaoxing dataset
python ECG/shaoxing_multilabel.py
```
**Output files:**
- `train.h5` - Training data
- `val.h5` - Validation data
- `test.h5` - Test data
- `dataset_info.json` - Metadata
- `record_splits.json` - Split information
**Label format:** `(N, 4)` with 4 merged classes (SB, AFIB, GSVT, SR)

</details>

<details>
<summary><b>EMG Preprocessing (EPN-612)</b></summary>

```bash
# Download from Zenodo and preprocess
python EMG/epn_finetune.py
```
**Output files:**
- `epn612_train_set.h5` - Training set with shape `(N, 8, 1024)`
- `epn612_val_set.h5` - Validation set
- `epn612_test_set.h5` - Test set
**Label format:** `(N,)` with 6 gesture classes

</details>
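Once the `.h5` files are produced, they can be consumed with a small HDF5-backed dataset. This is an illustrative sketch, not the repository's loader; the demo file written below is hypothetical and merely mimics the EMG layout (8 channels, 1024 points).

```python
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset

class H5SignalDataset(Dataset):
    """Minimal loader for the HDF5 layout described above (illustrative)."""

    def __init__(self, path):
        self.path = path
        with h5py.File(path, 'r') as f:
            self.length = f['data'].shape[0]
        self._file = None  # open lazily so DataLoader workers get their own handle

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if self._file is None:
            self._file = h5py.File(self.path, 'r')
        x = torch.from_numpy(self._file['data'][idx])   # (C, T) float32
        y = torch.as_tensor(self._file['label'][idx])   # scalar or multi-hot
        return x, y

# Tiny self-contained demo file in the EMG layout
with h5py.File('demo_emg.h5', 'w') as f:
    f.create_dataset('data', data=np.zeros((4, 8, 1024), dtype=np.float32))
    f.create_dataset('label', data=np.zeros(4, dtype=np.int64))

ds = H5SignalDataset('demo_emg.h5')
x, y = ds[0]
print(len(ds), x.shape)  # 4 torch.Size([8, 1024])
```

With a real file, the dataset drops straight into `torch.utils.data.DataLoader` for the fine-tuning scripts below.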
## 🎯 Training

### Pretraining
<details>
<summary><b>ECG Pretraining</b></summary>

```bash
# Edit ECG/pretrain_ecg.sh to set data paths
bash ECG/pretrain_ecg.sh
```

**Key parameters:**

```bash
--mask_ratio 0.7                      # Mask 70% of patches
--masking_strategy frequency_guided   # Use frequency-guided masking
--importance_ratio 0.7                # Balance importance vs. randomness
--epochs 100                          # Pretraining epochs
```
</details>
<details>
<summary><b>EMG Pretraining</b></summary>
```bash
# Edit EMG/pretrain_emg.sh to set data paths
bash EMG/pretrain_emg.sh
```

**Key parameters:**

```bash
--mask_ratio 0.6        # Mask 60% of patches
--in_channels 8         # 8-channel EMG
--wave_kernel_size 16   # Smaller kernel for EMG
```
</details>
### Fine-tuning

#### Single-Label Classification
<details>
<summary><b>Standard Fine-tuning (ECG/EMG)</b></summary>

```bash
# ECG fine-tuning (PTB-XL)
bash ECG/finetune_ecg.sh

# EMG fine-tuning (EPN-612)
bash EMG/finetune_emg.sh
```

**Example command:**
```bash
torchrun --nproc_per_node=4 finetune.py \
    --train_file path/to/train.h5 \
    --val_file path/to/val.h5 \
    --test_file path/to/test.h5 \
    --pretrained_path pretrained/ecg.pth \
    -
```