Ssast
Code for the AAAI 2022 paper "SSAST: Self-Supervised Audio Spectrogram Transformer".
SSAST: Self-Supervised Audio Spectrogram Transformer
- News
- Introduction
- Citing
- Getting Started
- SSAST Model
- Data Preparation
- Self-Supervised Pretraining
- Fine-tuning On Downstream Tasks
- Pretrained Models
- Contact
News
March 2022: We released a new preprint, CMKD: CNN/Transformer-Based Cross-Model Knowledge Distillation for Audio Classification, in which we propose a knowledge distillation method that further improves AST performance without changing its architecture. This method can be applied in the fine-tuning stage of SSAST.
Feb 2022: I will present SSAST at AAAI 2022 at 12:00 PM - 1:45 PM (EST) on February 25th and then 7:45 PM - 9:30 PM (EST) on February 27th.
Introduction
<p align="center"><img src="https://github.com/YuanGongND/ssast/blob/main/figure/10854_ssast.png?raw=true" alt="Illustration of AST." width="800"/></p>
This repository contains the official implementation (in PyTorch) of the Self-Supervised Audio Spectrogram Transformer (SSAST) proposed in the AAAI 2022 paper SSAST: Self-Supervised Audio Spectrogram Transformer (Yuan Gong, Cheng-I Jeff Lai, Yu-An Chung, James Glass; MIT CSAIL). [Slides]
SSAST is the first patch-based joint discriminative and generative self-supervised learning framework, and also the first self-supervised learning framework for AST. SSAST significantly boosts AST performance on all downstream tasks we evaluated, with an average improvement of 60.9%, leading to similar or even better results than a supervised pretrained AST. SSAST can be used as a drop-in replacement for the previous ImageNet (supervised) pretrained AST, with three advantages: 1) no labeled data is used; 2) patch size and shape are flexible, whereas ImageNet pretraining only supports square patches; and 3) performance is better on many tasks, in particular speech tasks.
Citing
Please cite our paper if you find this repository useful.
@inproceedings{gong2022ssast,
title={SSAST: Self-Supervised Audio Spectrogram Transformer},
author={Gong, Yuan and Lai, Cheng-I and Chung, Yu-An and Glass, James},
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
volume={36},
number={10},
pages={10699--10709},
year={2022}
}
@inproceedings{gong21b_interspeech,
author={Yuan Gong and Yu-An Chung and James Glass},
title={{AST: Audio Spectrogram Transformer}},
year=2021,
booktitle={Proc. Interspeech 2021},
pages={571--575},
doi={10.21437/Interspeech.2021-698}
}
Getting Started
Prepare the Environment
Clone or download this repository and set it as the working directory, then create a virtual environment and install the dependencies.
cd ssast/
python3 -m venv venvssast
source venvssast/bin/activate
pip install -r requirements.txt
Where is the code?
The SSAST model and pretraining function code is in src/models/ast_models.py.
The self-supervised pretraining scripts are src/pretrain/{run_mask_{frame,patch}, run_mask_{frame,patch}_tiny}, which call src/run.py, which then calls src/traintest_mask.py, which then calls src/models/ast_models.py.
The fine-tuning scripts are in src/finetune/. For PSLA experiments, these scripts call src/run.py, which then calls src/traintest.py, which then calls src/traintest_mask.py, which then calls src/models/ast_models.py.
The data preparation samples are in src/prep_data.
SSAST Model
The SSAST model script is in src/models/ast_models.py.
ASTModel(label_dim=527,
         fshape=16, tshape=16, fstride=10, tstride=10,
         input_fdim=128, input_tdim=1024, model_size='base',
         pretrain_stage=True, load_pretrained_mdl_path=None)
Parameters:
label_dim: The number of classes. Only needs to be specified in the fine-tuning stage.
fshape: The side length of the patch on the frequency dimension.
tshape: The side length of the patch on the time dimension.
fstride: The stride of patch splitting on the frequency dimension. For 16*16 patches, fstride=16 means no overlap and fstride=10 means an overlap of 6.
tstride: The stride of patch splitting on the time dimension. For 16*16 patches, tstride=16 means no overlap and tstride=10 means an overlap of 6.
input_fdim: The number of frequency bins of the input spectrogram.
input_tdim: The number of time frames of the input spectrogram.
model_size: The model size of AST, should be in [tiny, small, base] (default: base).
pretrain_stage: Set as True in the self-supervised pretraining stage and False in the fine-tuning stage.
load_pretrained_mdl_path: The pretrained model used for fine-tuning. Only needed when pretrain_stage=False as it is for fine-tuning.
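The shape and stride parameters together determine how many patches the model sees. A minimal sketch of that arithmetic follows, assuming the spectrogram is split without padding (the usual strided-window formula); the helper name num_patches is hypothetical and not part of the repository.

```python
def num_patches(input_fdim, input_tdim, fshape, tshape, fstride, tstride):
    """Return (frequency patches, time patches, total) for an unpadded strided split.

    Assumption: no padding is applied before splitting, so each dimension
    yields floor((size - shape) / stride) + 1 patches.
    """
    f_patches = (input_fdim - fshape) // fstride + 1
    t_patches = (input_tdim - tshape) // tstride + 1
    return f_patches, t_patches, f_patches * t_patches


# 16*16 patches, no overlap (the pretraining setting in the example below)
print(num_patches(128, 1024, 16, 16, 16, 16))   # (8, 64, 512)
# 16*16 patches, overlap of 6 in both dimensions
print(num_patches(128, 1024, 16, 16, 10, 10))   # (12, 101, 1212)
# frame-based model: each patch covers all 128 frequency bins and 2 frames
print(num_patches(128, 1024, 128, 2, 128, 2))   # (1, 512, 512)
```

Smaller strides increase the patch count (and therefore the sequence length the Transformer processes), which is why overlap raises the computational cost.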
Methods:
forward(x, task, cluster=True, mask_patch=400)
The entry method of the class, which calls the fine-tuning and pretraining methods. Parameters:
x: the input spectrogram in shape [batch_size, time_frame_num, frequency_bin_num]. Note: the input spectrogram should be normalized with dataset mean and std, see here.
task: the pretraining or fine-tuning task, should be one of [ft_avgtok, ft_cls, pretrain_mpc, pretrain_mpg], see below for details.
cluster: set True if using the cluster patch masking strategy.
mask_patch: the number of patches masked, only needed in the pretraining stage.
finetuningavgtok(x): fine-tune the model by using the average of the outputs of all tokens as the clip representation. Return in shape [batch_size, label_dim].
finetuningcls(x): fine-tune the model by using the output of the cls token as the clip representation. Return in shape [batch_size, label_dim].
mpc(x, mask_patch=mask_patch, cluster=cluster): pretrain the model with mask_patch number of masked patches with the discriminative objective. Return the accuracy and NCE loss of the pretext task.
mpg(x, mask_patch=mask_patch, cluster=cluster): pretrain the model with mask_patch number of masked patches with the generative objective. Return the mean square error of the pretext task.
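The mask_patch and cluster arguments control how many patches are hidden and whether they are hidden in contiguous groups. As a rough illustration only, the two strategies could look like the sketch below. This is not the repository's masking code: the function names, the row-major patch indexing, and the fixed cluster_size are all assumptions made for this sketch.

```python
import random


def random_mask_ids(num_patches, mask_patch, rng=random):
    """Pick mask_patch distinct patch indices uniformly at random."""
    return sorted(rng.sample(range(num_patches), mask_patch))


def cluster_mask_ids(f_patches, t_patches, mask_patch, cluster_size=3, rng=random):
    """Mask square clusters of neighbouring patches until mask_patch indices are chosen.

    Patches are indexed row-major on an f_patches x t_patches grid
    (an assumption of this sketch, not taken from the repository).
    """
    if mask_patch > f_patches * t_patches:
        raise ValueError("mask_patch exceeds the number of patches")
    chosen = set()
    while len(chosen) < mask_patch:
        # pick the top-left corner of a cluster, then mask the square below/right of it
        f0 = rng.randrange(f_patches)
        t0 = rng.randrange(t_patches)
        for df in range(cluster_size):
            for dt in range(cluster_size):
                f, t = f0 + df, t0 + dt
                if f < f_patches and t < t_patches and len(chosen) < mask_patch:
                    chosen.add(f * t_patches + t)
    return sorted(chosen)


# 8 x 64 grid, i.e. 16*16 patches without overlap on a 128 x 1024 spectrogram
print(len(random_mask_ids(8 * 64, 100)))    # 100
print(len(cluster_mask_ids(8, 64, 100)))    # 100
```

Masking contiguous clusters hides larger regions of the spectrogram than scattering the same number of masked patches independently.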
Example:
# pretraining stage
import torch
from models.ast_models import ASTModel  # assumes src/ is the working directory
# suppose you have an unlabeled dataset with avg length of 1024 frames (i.e., 10.24s)
input_tdim = 1024
# create a 16*16 patch based AST model for pretraining.
# note, we don't use patch split overlap in pretraining, so fstride=fshape and tstride=tshape
ast_mdl = ASTModel(
    fshape=16, tshape=16, fstride=16, tstride=16,
    input_fdim=128, input_tdim=input_tdim, model_size='base',
    pretrain_stage=True)
# # alternatively, create a frame based AST model
# ast_mdl = ASTModel(
#     fshape=128, tshape=2, fstride=128, tstride=2,
#     input_fdim=128, input_tdim=input_tdim, model_size='base',
#     pretrain_stage=True)
# do pretraining, see src/traintest_mask.py for our full pretraining code
# input in shape [batch_size, input_tdim, input_fdim]
test_input = torch.zeros([10, input_tdim, 128])
# mask 100 patches for both discriminative and generative loss
acc, nce_loss = ast_mdl(test_input, task='pretrain_mpc', mask_patch=100)
mse_loss = ast_mdl(test_input, task='pretrain_mpg', mask_patch=100)
loss = nce_loss + 10 * mse_loss
# do backpropagation and update the model, etc.
# after pretraining, save the pretrained model.
# the code is designed for a DataParallel model
ast_mdl = torch.nn.DataParallel(ast_mdl)
torch.save(ast_mdl.state_dict(), './test_mdl.pth')
# fine-tuning stage
# now you have a labeled dataset you want to fine-tune AST on
# suppose the avg length is 100 frames (1s) and there are 35 classes
# fshape and tshape must be the same in pretraining and fine-tuning
# but fstride and tstride can differ between pretraining and fine-tuning
# using smaller strides improves the performance but also increases the computational overhead
# set pretrain_stage to False since we are now in the fine-tuning stage
# provide the path of the pretrained model you want to load
input_tdim = 100  # fine-tuning data length can differ from pretraining data length
ast_mdl = ASTModel(label_dim=35,
                   fshape=16, tshape=16, fstride=10, tstride=10,
                   input_fdim=128, input_tdim=input_tdim, model_size='base',
                   pretrain_stage=False, load_pretrained_mdl_path='./test_mdl.pth')
# # alternatively, use a frame based AST model
# ast_mdl = ASTModel(label_dim=35,
#                    fshape=128, tshape=2, fstride=128, tstride=1,
#                    input_fdim=128, input_tdim=input_tdim, model_size='base',
#                    pretrain_stage=False, load_pretrained_mdl_path='./test_mdl.pth')
# do fine-tuning, see src/traintest.py for our fine-tuning code
test_input = torch.zeros([10, input_tdim, 128])
prediction = ast_mdl(test_input, task='ft_avgtok')
# output should be in shape [batch_size, label_dim]
print(prediction.shape)
# calculate the loss, do backpropagation, etc.
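The example wraps the model in torch.nn.DataParallel before saving. DataParallel keeps the wrapped model under its module attribute, so every key in the saved state dict starts with "module.". If you saved a checkpoint without the wrapper, or want to load a wrapped checkpoint into a bare model, the keys have to be renamed. The helpers below are a minimal sketch of that renaming; their names and the example keys are hypothetical, and whether the repository's loader needs this step for your checkpoint is an assumption you should check.

```python
def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading 'module.' removed from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def add_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict where every key starts with 'module.'."""
    return {
        (key if key.startswith(prefix) else prefix + key): value
        for key, value in state_dict.items()
    }


# hypothetical parameter names, used only to show the renaming
wrapped = {"module.encoder.weight": 1, "module.encoder.bias": 2}
bare = strip_module_prefix(wrapped)
print(sorted(bare))                        # ['encoder.bias', 'encoder.weight']
print(add_module_prefix(bare) == wrapped)  # True
```

Both helpers work on any mapping, so they can be applied to the dictionary returned by torch.load before it is passed to load_state_dict.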
Data Preparation
For both pretraining and fine-tuning, our dataloader requires two files:
- A json file containing the path of each audio file and its corresponding label.
  - Self-supervised pretraining does not need any labels, but the current version of dataloader.py needs label information to run, so you need to use a dummy label for pretraining data. Below is an example json file.
{
"data": [
{
"wav": "/data/sls/audioset/data/audio/eval/_/_/--4gqARaEJE_0.000.
