WaveTrainerFit
Official implementation of "Wave-Trainer-Fit: Neural Vocoder with Trainable Prior and Fixed-Point Iteration towards High-Quality Speech Generation from SSL features", accepted to ICASSP 2026.
Wave-Trainer-Fit | Neural vocoder from SSL features (Accepted by ICASSP 2026)
[audio-samples] [hugging-face] [arXiv]
Abstract:
We propose WaveTrainerFit, a neural vocoder that generates high-quality waveforms from data-driven features such as SSL features. WaveTrainerFit builds upon the WaveFit vocoder, which integrates a diffusion model and a generative adversarial network. The proposed method adds the following key improvements:
1. By introducing trainable priors, the inference process starts from noise close to the target speech instead of Gaussian noise.
2. Reference-aware gain adjustment is performed by imposing constraints on the trainable prior to match the speech energy.

These improvements are expected to reduce the complexity of waveform modeling from data-driven features, enabling high-quality waveform generation with fewer inference steps. Through experiments, we show that WaveTrainerFit generates highly natural waveforms with improved speaker similarity from data-driven features while requiring fewer iterations than WaveFit. Moreover, we show that the proposed method works robustly regardless of the depth at which SSL features are extracted.
<img src=./assets/concept.png width=100%>
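The abstract mentions gain adjustment that matches the speech energy. As a rough illustration of energy-matching gain adjustment in general, here is a minimal sketch, assuming "energy" means the sum of squared samples. `energy` and `gain_adjust` are hypothetical helpers, not functions from this repository, and the sketch does not reproduce how WaveTrainerFit actually constrains its trainable prior.

```python
import math


def energy(signal):
    """Sum of squared samples (one possible definition of signal energy)."""
    return sum(x * x for x in signal)


def gain_adjust(signal, target_energy, eps=1e-12):
    """Rescale `signal` so that its energy matches `target_energy`.

    Hypothetical helper for illustration only; `eps` avoids division by zero
    for silent input (a silent signal stays silent).
    """
    current = energy(signal)
    scale = math.sqrt(target_energy / (current + eps))
    return [scale * x for x in signal]


noisy_estimate = [0.5, -1.0, 0.25, 2.0]  # energy 5.3125
reference_energy = 1.0
adjusted = gain_adjust(noisy_estimate, reference_energy)
print(round(energy(adjusted), 6))  # energy now matches the reference: 1.0
```

Because only a single scalar gain is applied, the shape of the waveform (and the sign of every sample) is preserved; only its overall level changes.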
Installation
Run the following commands to set up the environment.
To install uv, refer to "Installing uv" in the official uv documentation.
git clone https://github.com/line/WaveTrainerFit
cd WaveTrainerFit
uv sync
Quick start
To try WaveTrainerFit right away, use one of the pre-trained models.
The following code loads a pre-trained vocoder and generates a waveform from WavLM features:
import torchaudio
import torch
from wavetrainerfit import load_pretrained_vocoder
from transformers import WavLMModel, AutoFeatureExtractor
ssl_preprocessor = AutoFeatureExtractor.from_pretrained('microsoft/wavlm-large')
ssl_model: WavLMModel = WavLMModel.from_pretrained('microsoft/wavlm-large')
layer = 2
ssl_vocoder, cfg = load_pretrained_vocoder(f'wavlm{layer}_wavetrainerfit5')
waveform, sr = torchaudio.load('./assets/ljspeech-samples/LJ037-0171.wav')
if sr != 16000:
    waveform = torchaudio.transforms.Resample(
        orig_freq=sr,
        new_freq=16000
    )(waveform)
inputs = ssl_preprocessor(
    waveform[0].numpy(),
    sampling_rate=16000,
    return_tensors="pt"
)
with torch.no_grad():
    # Take the hidden states of the SSL layer that matches the model tag
    features = ssl_model(**inputs, output_hidden_states=True)
    features = features.hidden_states[layer]  # (Batch, Timeframe, Featuredim)
    generated_waveform = ssl_vocoder.pred(
        conditional_feature=features,  # (Batch, Timeframe, Featuredim)
        T_=5  # number of iterations
    )
# The pre-trained models output 24 kHz audio
torchaudio.save(
    './assets/ljspeech-samples/LJ037-0171-reconstructed.wav',
    generated_waveform[-1][:, 0].cpu(), 24000
)
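The pre-trained models take 16 kHz input and produce 24 kHz output. The sketch below estimates how many feature frames WavLM (and XLS-R, which uses the same convolutional front end) produces for a given input length, and how many output samples that would correspond to. It assumes the kernel sizes and strides of the default Hugging Face WavLM / wav2vec 2.0 configs, a 20 ms hop (320 samples at 16 kHz), and a vocoder that emits exactly 480 samples at 24 kHz per frame. `num_ssl_frames` and `expected_output_samples` are hypothetical helpers, the vocoder's actual output length may differ, and the calculation does not apply to the Whisper-based models.

```python
# Default kernel sizes and strides of the convolutional feature encoder in
# WavLM / wav2vec 2.0-style models (assumption: Hugging Face config defaults).
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)


def num_ssl_frames(num_samples):
    """Number of feature frames produced from `num_samples` 16 kHz samples."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        # Output length of an unpadded 1-D convolution
        length = (length - kernel) // stride + 1
    return length


def expected_output_samples(num_frames, input_sr=16000, output_sr=24000, hop=320):
    """Assumed output length: each frame (`hop` samples at `input_sr`) maps to
    hop * output_sr / input_sr samples at `output_sr` (480 for 16 kHz -> 24 kHz)."""
    return num_frames * hop * output_sr // input_sr


frames = num_ssl_frames(3 * 16000)  # 3 seconds of 16 kHz audio
print(frames, expected_output_samples(frames))  # 149 71520
```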
Pretrained models
We have released our pre-trained models on Hugging Face. The provided models are listed below.

| Model tag | Conditional features | Layer | Input sampling rate [Hz] | Output sampling rate [Hz] | Training data | Training iterations | Inference iterations |
|:------|:---------:|:---:|:---:|:---:|:-----:|:------------:|:------------:|
| wavlm2_wavetrainerfit5 | WavLM-large | 2 | 16000 | 24000 | LibriTTS-R (train-clean-360) | 400k | 5 |
| wavlm2_wavefit5 | WavLM-large | 2 | 16000 | 24000 | LibriTTS-R (train-clean-360) | 400k | 5 |
| wavlm8_wavetrainerfit5 | WavLM-large | 8 | 16000 | 24000 | LibriTTS-R (train-clean-360) | 400k | 5 |
| wavlm8_wavefit5 | WavLM-large | 8 | 16000 | 24000 | LibriTTS-R (train-clean-360) | 400k | 5 |
| wavlm24_wavetrainerfit5 | WavLM-large | 24 | 16000 | 24000 | LibriTTS-R (train-clean-360) | 400k | 5 |
| wavlm24_wavefit5 | WavLM-large | 24 | 16000 | 24000 | LibriTTS-R (train-clean-360) | 400k | 5 |
| xlsr8_wavetrainerfit5 | XLS-R-300m | 8 | 16000 | 24000 | LibriTTS-R (train-clean-360) | 400k | 5 |
| xlsr8_wavefit5 | XLS-R-300m | 8 | 16000 | 24000 | LibriTTS-R (train-clean-360) | 400k | 5 |
| whisper8_wavetrainerfit5 | ※ Whisper-medium | 8 | 16000 | 24000 | LibriTTS-R (train-clean-360) | 400k | 5 |
| whisper8_wavefit5 | ※ Whisper-medium | 8 | 16000 | 24000 | LibriTTS-R (train-clean-360) | 400k | 5 |
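The model tags in the table follow a regular naming pattern. In the Quick start, the same `layer` value selects both the model tag and `hidden_states[layer]`, and `T_` matches the number of inference iterations. Below is a minimal sketch of a hypothetical helper (`parse_model_tag` is not part of the `wavetrainerfit` package) that recovers these values from a tag, assuming every tag has the form `<feature><layer>_<method><iterations>`:

```python
import re

# Hypothetical pattern inferred from the tags listed in the table,
# e.g. "wavlm8_wavetrainerfit5" or "xlsr8_wavefit5".
TAG_PATTERN = re.compile(r"^(wavlm|xlsr|whisper)(\d+)_(wavetrainerfit|wavefit)(\d+)$")


def parse_model_tag(tag):
    """Split a model tag into (feature, layer, method, inference iterations)."""
    match = TAG_PATTERN.match(tag)
    if match is None:
        raise ValueError(f"unrecognized model tag: {tag!r}")
    feature, layer, method, iters = match.groups()
    return feature, int(layer), method, int(iters)


feature, layer, method, iters = parse_model_tag("wavlm8_wavetrainerfit5")
# Use `layer` for hidden_states[layer] and `iters` for T_ when calling pred()
print(feature, layer, method, iters)  # wavlm 8 wavetrainerfit 5
```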
※ In our tests, we found that amplitude decay occurs with Whisper features beyond about 2.0 seconds of input. During evaluation, we therefore split each input into 2.0-second segments, extracted features from each segment with the Whisper encoder, concatenated the features, and resynthesized the waveform. If you use these models in your application, your upstream feature extraction must follow the same procedure.
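The segment-then-concatenate flow described above can be sketched as follows. `split_into_segments`, `extract_in_segments`, and `dummy_extract` are hypothetical names, and `extract_fn` stands in for the real Whisper-encoder feature extraction. Note that Hugging Face's Whisper feature extractor pads inputs to 30 seconds by default, so a real `extract_fn` would also need to trim the encoder output to the frames belonging to the segment (roughly 50 frames per second); how the authors handled this is not specified here.

```python
def split_into_segments(samples, sample_rate=16000, segment_seconds=2.0):
    """Split a 1-D sequence of samples into consecutive segments of at most
    `segment_seconds` (the last segment may be shorter)."""
    seg_len = int(round(segment_seconds * sample_rate))
    return [samples[i:i + seg_len] for i in range(0, len(samples), seg_len)]


def extract_in_segments(samples, extract_fn, sample_rate=16000, segment_seconds=2.0):
    """Run `extract_fn` on each segment and concatenate the resulting frames.

    `extract_fn` is a placeholder for the upstream Whisper-encoder feature
    extraction; it must return a list of frames for exactly that segment.
    """
    frames = []
    for segment in split_into_segments(samples, sample_rate, segment_seconds):
        frames.extend(extract_fn(segment))
    return frames


def dummy_extract(segment):
    """Toy extractor for illustration: one 'frame' per 320 samples (50 frames/s)."""
    return [sum(segment[i:i + 320]) for i in range(0, len(segment), 320)]


audio = [0.0] * (5 * 16000)  # 5 seconds of silence at 16 kHz
segments = split_into_segments(audio)
features = extract_in_segments(audio, dummy_extract)
print([len(s) for s in segments], len(features))  # [32000, 32000, 16000] 250
```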
Training
To train a model, start from the _template folder in the egs directory, which contains most of the information needed for training.
Licenses
This project is licensed under the Apache License 2.0 for all modifications and additions made by LY Corporation. The original codebase by Yukara Ikemiya is licensed under the MIT License. Please see the LICENSE, LICENSE-MIT, and NOTICE files for full details.
The pre-trained models published on Hugging Face are licensed under either CC-BY-4.0 or CC-BY-SA-3.0. For details, please refer to our Hugging Face repository.
Citation
If you publish a paper using this repository, please cite the following paper🙏
@inproceedings{wavetrainerfit,
  author={Hien Ohnaka and Yuma Shirahata and Masaya Kawamura},
  title={Wave-Trainer-Fit: Neural Vocoder with Trainable Prior and Fixed-Point Iteration towards High-Quality Speech Generation from SSL features},
  booktitle={Accepted to IEEE ICASSP},
  year={2026}
}
Acknowledgements
Our work would not have been possible without Mr. Ikemiya's code. We would like to take this opportunity to express our gratitude to Mr. Ikemiya.
- ref: wavefit-pytorch
