PaSST: Efficient Training of Audio Transformers with Patchout

This is the implementation of Efficient Training of Audio Transformers with Patchout.

Patchout significantly reduces the training time and GPU memory requirements to train transformers on audio spectrograms, while improving their performance.

<p align="center"><img src="https://github.com/kkoutini/PaSST/blob/main/.github/speed_mem_map.png?raw=true" width="600"/></p>

Patchout works by dropping out some of the input patches during training, either in an unstructured way (randomly, similar to dropout) or by dropping entire time frames or frequency bins of the extracted patches (similar to SpecAugment). The structured variant corresponds to dropping rows/columns in step 3 of the figure below.

<p align="center"><img src="https://github.com/kkoutini/PaSST/raw/main/.github/passt_diag.png?raw=true" width="600"/></p>
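To make the two variants concrete, here is a minimal NumPy sketch that operates on a grid of already-embedded patches. It is an illustration, not the repo's implementation: the function names, the (batch, F, T, D) layout, and sharing one random choice across the whole batch are assumptions. The parameter names mirror the s_patchout_f, s_patchout_t and u_patchout options in the configuration further down; the 12 x 99 example grid is what the default 128 x 998 input gives with stride 10 and 16 x 16 patches.

```python
import numpy as np


def structured_patchout(patches, s_patchout_f, s_patchout_t, rng):
    """Hypothetical helper: drop whole frequency rows and time columns of a patch grid.

    patches: array of shape (batch, F, T, D), i.e. F frequency patches by
    T time patches, each a D-dimensional embedding.
    Returns a token sequence of shape (batch, (F - s_f) * (T - s_t), D).
    """
    batch, n_f, n_t, dim = patches.shape
    # Choose which rows/columns survive; the same choice applies to the whole batch.
    keep_f = np.sort(rng.choice(n_f, size=n_f - s_patchout_f, replace=False))
    keep_t = np.sort(rng.choice(n_t, size=n_t - s_patchout_t, replace=False))
    kept = patches[:, keep_f][:, :, keep_t]
    # Flatten the reduced grid into the sequence fed to the transformer.
    return kept.reshape(batch, -1, dim)


def unstructured_patchout(patches, u_patchout, rng):
    """Hypothetical helper: drop u_patchout randomly chosen patches from the flattened grid."""
    batch, n_f, n_t, dim = patches.shape
    seq = patches.reshape(batch, n_f * n_t, dim)
    keep = np.sort(rng.choice(n_f * n_t, size=n_f * n_t - u_patchout, replace=False))
    return seq[:, keep]


rng = np.random.default_rng(0)
grid = np.zeros((2, 12, 99, 768))  # 2 clips, 12 x 99 patch grid, 768-dim embeddings
print(structured_patchout(grid, 4, 40, rng).shape)  # (2, 472, 768)
print(unstructured_patchout(grid, 500, rng).shape)  # (2, 688, 768)
```

Structured patchout with these example values shrinks the sequence from 1188 to 472 tokens, which is where the savings in compute and memory come from, since self-attention cost grows quadratically with sequence length.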


Pre-trained models for Inference and embeddings extractions

If you only want to use the embeddings generated by the pre-trained models, want to use your own fine-tuning framework, or only need the models for inference, you can use a stripped-down version of this repo, packaged as hear21passt. The package follows the HEAR 2021 NeurIPS Challenge API and can be installed with:

```shell
pip install hear21passt
```

This repo, in contrast, is a complete framework for training the models on AudioSet and for fine-tuning the pre-trained models on downstream tasks.

Getting the logits from the pretrained models

```python
import torch
from hear21passt.base import get_basic_model

# Get the PaSST model wrapper: a mel-spectrogram front end plus the default pre-trained transformer.
model = get_basic_model(mode="logits")
print(model.mel)  # extracts mel spectrograms from raw waveforms
print(model.net)  # the transformer network

# Example inference (falls back to the CPU if no GPU is available).
device = "cuda" if torch.cuda.is_available() else "cpu"
model.eval()
model = model.to(device)
with torch.no_grad():
    # audio_wave has shape [batch, seconds * 32000]; the sampling rate is 32 kHz.
    # Example: a batch of 3 clips, each 10 seconds long.
    audio_wave = (torch.ones((3, 32000 * 10)) * 0.5).to(device)
    logits = model(audio_wave)
```
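The wrapper expects a fixed-length batch of 32 kHz waveforms. If your clips have different lengths, one simple option is to crop or zero-pad them to a common duration before stacking them. The helper below is a hypothetical NumPy sketch, not part of hear21passt; it assumes the clips are already mono and resampled to 32 kHz.

```python
import numpy as np

SAMPLE_RATE = 32000  # the models expect 32 kHz input


def make_batch(clips, seconds=10, sample_rate=SAMPLE_RATE):
    """Hypothetical helper: crop or zero-pad mono 1-D waveforms to `seconds` and stack them.

    Returns a float32 array of shape [batch, seconds * sample_rate].
    """
    target = int(seconds * sample_rate)
    batch = np.zeros((len(clips), target), dtype=np.float32)
    for i, clip in enumerate(clips):
        clip = np.asarray(clip, dtype=np.float32)[:target]  # crop clips that are too long
        batch[i, : len(clip)] = clip  # shorter clips keep trailing zeros
    return batch


clips = [np.full(32000 * 4, 0.5), np.full(32000 * 12, 0.5)]  # a 4 s and a 12 s clip
audio = make_batch(clips)
print(audio.shape)  # (2, 320000)
```

The result is float32 so that torch.from_numpy(audio) produces a tensor with the same dtype as the model weights; move it to the model's device before calling the model.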

Getting a pre-trained model for fine-tuning

```python
import torch
from hear21passt.base import get_basic_model, get_model_passt

# Get the PaSST model wrapper: a mel-spectrogram front end plus the default pre-trained transformer.
model = get_basic_model(mode="logits")
print(model.mel)  # extracts mel spectrograms from raw waveforms

# Optionally, replace the transformer with one that has the required number of classes, e.g. 50.
model.net = get_model_passt(arch="passt_s_swa_p16_128_ap476", n_classes=50)
print(model.net)  # the transformer network

# model now contains the mel front end and the pre-trained transformer, ready to be fine-tuned.
# It still expects input of shape [batch, seconds * 32000]; the sampling rate is 32 kHz.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.train()
model = model.to(device)
```

Development environment

If you want to use the same environment as in the paper, you can follow the instructions below.

Setting up the development experiments environment

For training models from scratch or fine-tuning using the same setup as in the paper:

  1. If needed, create a new environment with Python 3.8 and activate it:
```shell
conda create -n passt python=3.8
conda activate passt
```
  2. Install the PyTorch build that suits your system. For example:
```shell
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
```
  3. Install the requirements:
```shell
pip install -r requirements.txt
```

Setting up using the exported conda environment

Alternatively, you can create the environment from the exported conda environment file environment.yml.

For this setup, Mamba is recommended since it is faster than conda:

```shell
conda install mamba -n base -c conda-forge
```

Then create the environment from environment.yml:

```shell
mamba env create -f environment.yml
```

This creates an environment named ba3l.

Checking the environment

To check whether your environment matches the one we used in our runs, compare it with the environment.yml and pip_list.txt files, which were exported using:

```shell
conda env export --no-builds | grep -v "prefix" > environment.yml
pip list > pip_list.txt
```
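For a quick comparison of Python packages, a small script can read pip_list.txt and check each pinned version against what is installed. This is a hypothetical sketch, not a tool shipped with the repo; it assumes the standard two-line header of pip list output and compares package names exactly as written in the file.

```python
from importlib import metadata


def parse_pip_list(text):
    """Hypothetical helper: parse the table printed by `pip list` into {package: version}."""
    versions = {}
    for line in text.splitlines()[2:]:  # skip the header row and the dashed separator
        parts = line.split()
        if len(parts) >= 2:  # editable installs add a third "Location" column
            versions[parts[0]] = parts[1]
    return versions


def diff_against_installed(expected):
    """Hypothetical helper: return {package: (expected, installed or None)} for every mismatch."""
    mismatches = {}
    for name, version in expected.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != version:
            mismatches[name] = (version, installed)
    return mismatches
```

Usage: with open("pip_list.txt") as f, call diff_against_installed(parse_pip_list(f.read())) and inspect the returned dictionary; an empty result means every listed package is installed at the listed version.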

Getting started

If you want to use your own setup and only need the models from this repo, you can get the models and train them from scratch or fine-tune them on your own dataset, as explained above in Pre-trained models for Inference and embeddings extractions. The rest of this section explains how to use this repo to train and fine-tune the models. For that, you first need to set up the development environment as explained above.

General information

The repo is built using sacred for experiment management and configuration, pytorch-lightning for training, and wandb for logging.

Each dataset has a main experiment file such as ex_audioset.py and ex_openmic.py and a dataset folder. The experiment file contains the main training and validation logic. The dataset folder contains the dataset specific code needed to download, preprocess and load the dataset for training.

In general, you can probe the experiment file for help, which prints the available commands and basic options:

```shell
python ex_audioset.py help
```

Configuring the experiment

Each experiment has a set of default configuration options, defined in the experiment file, e.g. ex_audioset.py. You can override any of them using the sacred syntax. Use the print_config command to print the configuration values without training a model:

```shell
python ex_audioset.py print_config
```

You can then use the command-line interface to override any of the configuration options by appending with followed by the new values (sacred syntax), e.g.:

```shell
python ex_audioset.py with trainer.precision=16
```

This will train on Audioset using 16-bit precision.
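Conceptually, each path.to.option=value assignment after with replaces one entry in the nested configuration that print_config shows. The sketch below only illustrates that idea in plain Python and is not sacred's parser; it assumes values are read as Python literals, falling back to plain strings, and the default values in the example are hypothetical.

```python
import ast
import copy


def apply_overrides(config, assignments):
    """Illustration only: return a copy of a nested dict with dotted `key=value` overrides applied."""
    config = copy.deepcopy(config)  # leave the defaults untouched
    for assignment in assignments:
        path, raw_value = assignment.split("=", 1)
        try:
            value = ast.literal_eval(raw_value)  # "16" -> 16, "False" -> False, ...
        except (ValueError, SyntaxError):
            value = raw_value  # anything that is not a literal stays a string
        *parents, leaf = path.split(".")
        node = config
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = value
    return config


# Hypothetical defaults, shaped like a small part of the configuration shown below.
defaults = {"trainer": {"precision": 32, "benchmark": True}, "models": {"net": {"s_patchout_t": 40}}}
cfg = apply_overrides(defaults, ["trainer.precision=16", "models.net.s_patchout_t=30"])
print(cfg["trainer"]["precision"], cfg["models"]["net"]["s_patchout_t"])  # 16 30
```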

The overall configuration looks like this:

```
  ...
  seed = 542198583                  # the random seed for this experiment
  slurm_job_id = ''
  speed_test_batch_size = 100
  swa = True
  swa_epoch_start = 50
  swa_freq = 5
  use_mixup = True
  warm_up_len = 5
  weight_decay = 0.0001
  basedataset:
    base_dir = 'audioset_hdf5s/'     # base directory of the dataset, change it or make a link
    eval_hdf5 = 'audioset_hdf5s/mp3/eval_segments_mp3.hdf'
    wavmix = 1
    ....
    roll_conf:
      axis = 1
      shift = None
      shift_range = 50
  datasets:
    test:
      batch_size = 20
      dataset = {CMD!}'/basedataset.get_test_set'
      num_workers = 16
      validate = True
    training:
      batch_size = 12
      dataset = {CMD!}'/basedataset.get_full_training_set'
      num_workers = 16
      sampler = {CMD!}'/basedataset.get_ft_weighted_sampler'
      shuffle = None
      train = True
  models:
    mel:
      freqm = 48
      timem = 192
      hopsize = 320
      htk = False
      n_fft = 1024
      n_mels = 128
      norm = 1
      sr = 32000
      ...
    net:
      arch = 'passt_s_swa_p16_128_ap476'
      fstride = 10
      in_channels = 1
      input_fdim = 128
      input_tdim = 998
      n_classes = 527
      s_patchout_f = 4
      s_patchout_t = 40
      tstride = 10
      u_patchout = 0
      ...
  trainer:
    accelerator = None
    accumulate_grad_batches = 1
    amp_backend = 'native'
    amp_level = 'O2'
    auto_lr_find = False
    auto_scale_batch_size = False
    ...
```
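The swa, swa_epoch_start and swa_freq options enable stochastic weight averaging (SWA). As a rough, generic sketch, assuming that from epoch swa_epoch_start onward the weights are folded into a running average every swa_freq epochs (the exact schedule used by this repo may differ), the averaging step is a cumulative mean over weight snapshots. The function names and the toy "weights" are hypothetical.

```python
import numpy as np


def swa_update(avg_weights, new_weights, n_averaged):
    """Fold one snapshot into the running mean: avg + (new - avg) / (n + 1)."""
    return {k: avg_weights[k] + (new_weights[k] - avg_weights[k]) / (n_averaged + 1)
            for k in avg_weights}


def swa_epochs(max_epochs, swa_epoch_start=50, swa_freq=5):
    """Epochs whose weights enter the average, under the assumed schedule."""
    return [e for e in range(swa_epoch_start, max_epochs)
            if (e - swa_epoch_start) % swa_freq == 0]


# Toy run: a single "layer" whose weight simply equals the epoch number.
avg, n = None, 0
for epoch in swa_epochs(max_epochs=70):
    snapshot = {"w": np.array([float(epoch)])}
    avg = snapshot if avg is None else swa_update(avg, snapshot, n)
    n += 1
print(swa_epochs(max_epochs=70))  # [50, 55, 60, 65]
print(avg["w"])  # [57.5]
```

The running mean avoids storing every snapshot: only the current average and the count of averaged snapshots are kept.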

There are many things that can be updated from the command line. In short:

  • All the configuration options under trainer are passed to the PyTorch Lightning Trainer API. For example, to turn off cuDNN benchmarking, add trainer.benchmark=False to the command line.
  • wandb holds the wandb configuration. For example, to change the wandb project, add wandb.project="test_project" to the command line.
  • models.net holds the options of PaSST (or the chosen network). For example, models.net.u_patchout, models.net.s_patchout_f and models.net.s_patchout_t control unstructured patchout and structured patchout over frequency and time, respectively. input_fdim and input_tdim are the input spectrogram dimensions over frequency and time. models.net.fstride and models.net.tstride are the strides of the input patches over frequency and time; setting them to 16 means the patches do not overlap.
  • models.mel holds the preprocessing options (mel spectrograms). mel.sr is the sampling rate, mel.hopsize is the hop size of the STFT window, and mel.n_mels is the number of mel bins.
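Putting the mel and net options together, a back-of-the-envelope calculation shows how the input size and the strides determine the number of patches, and how much patchout shortens the resulting sequence. This sketch assumes 16 x 16 patches (the p16 in the architecture name) without padding, and roughly sr / hopsize spectrogram frames per second; the exact frame count depends on STFT padding and is fixed by input_tdim in the configuration. The function names are hypothetical.

```python
def n_patches(dim, stride, patch=16):
    """Number of patch positions along one axis (no padding)."""
    return (dim - patch) // stride + 1


def sequence_length(input_fdim=128, input_tdim=998, fstride=10, tstride=10,
                    s_patchout_f=0, s_patchout_t=0, u_patchout=0):
    """Patch tokens left after structured and unstructured patchout."""
    n_f = n_patches(input_fdim, fstride) - s_patchout_f
    n_t = n_patches(input_tdim, tstride) - s_patchout_t
    return n_f * n_t - u_patchout


sr, hopsize, seconds = 32000, 320, 10
print(seconds * sr // hopsize)  # 1000, i.e. 100 frames per second of audio
print(sequence_length())  # 1188 patches (12 x 99) without patchout
print(sequence_length(s_patchout_f=4, s_patchout_t=40))  # 472 patches (8 x 59)
print(sequence_length(fstride=16, tstride=16))  # 496 patches (8 x 62), no overlap
```

With the default settings, structured patchout removes roughly 60% of the patches during training, while non-overlapping patches (stride 16) shorten the sequence without discarding any input.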