# Wmmd

Official PyTorch implementation of the JASA paper "Word-Level Maximum Mean Discrepancy Regularization for Word Embedding".
**Authors:** Youqian Gao, Ben Dai
## Introduction
Word-level maximum mean discrepancy (WMMD) is a novel regularization framework that recognizes and accounts for the "word-level distribution discrepancy": a common phenomenon across NLP datasets in which word distributions differ noticeably between labels. The regularizer serves a specific purpose: to enhance and preserve these distribution discrepancies in the learned word embedding vectors, and thus to prevent overfitting.
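For intuition, here is a minimal sketch of the MMD building block that such a penalty rests on, assuming an RBF kernel and placeholder word groupings. The variable names, kernel choice, and hyperparameters below are illustrative, not the repository's API; see the code and the paper for the actual WMMD construction.

```python
import torch

def rbf_kernel(x, y, sigma=1.0):
    # Pairwise RBF (Gaussian) kernel matrix between two sets of row vectors.
    sq_dists = torch.cdist(x, y) ** 2
    return torch.exp(-sq_dists / (2 * sigma ** 2))

def mmd2(x, y, sigma=1.0):
    # Biased (V-statistic) estimate of the squared maximum mean discrepancy.
    return (rbf_kernel(x, x, sigma).mean()
            + rbf_kernel(y, y, sigma).mean()
            - 2.0 * rbf_kernel(x, y, sigma).mean())

# Toy setup: an embedding table and two hypothetical groups of word ids,
# one per label (the paper derives the word groups from the word-level
# label distributions; the ids and sizes here are placeholders).
embedding = torch.nn.Embedding(num_embeddings=5000, embedding_dim=128)
label0_word_ids = torch.tensor([1, 5, 42, 77])
label1_word_ids = torch.tensor([7, 99, 123, 456])

# WMMD rewards separation between the two groups, so the term enters the
# training objective with a negative sign: loss = task_loss - lam * mmd2(...)
lam = 0.1
penalty = -lam * mmd2(embedding(label0_word_ids), embedding(label1_word_ids))
```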
We visualize the embeddings of a CNN model trained on the CE-T1 dataset under different regularizers. The WMMD regularizer separates the word vectors of the two labels and enforces the word-level discrepancy between the two groups.
<p align="center"> <img src="figures/embed-visual.png" width="70%" /> </p>

Furthermore, WMMD regularization offers a flexible framework for incorporating prior information and high-order distribution discrepancies of words. For more information, please refer to the paper.
## Motivation
The following word clouds show significant differences in the most common words between the "sports" and "business" categories of the BBC News dataset. Given this word-level distribution discrepancy, a good word embedding should produce numerical word representations that preserve it.
<p align="center"> <img src="figures/bbc-word-dist-diff.jpg" width="70%" /> </p>

## Setup
```bash
pip install -r requirements.txt
```
## Running the WMMD code
```bash
python main.py dataset=<dataset> model=<model> regularizer=<regularizer>
```
For more details, run `python main.py --help`. Currently, the learning framework supports the following configurations (an example invocation follows the list):
- `dataset`: `bbc-news`, `ce-t1`, `sim1`, `sim2`, `sim3` (see the paper for details on the three simulation examples)
- `model`: `cnn`, `bilstm`, `gru`, `mlp`, `logistic`
- `regularizer`: `wmmd`, `swmmd` (structured WMMD), `biwmmd` (bigram WMMD), `dropout`, `l1`, `none`
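For example, to train a CNN on the BBC News dataset with the WMMD regularizer (one illustrative combination of the options above):

```bash
python main.py dataset=bbc-news model=cnn regularizer=wmmd
```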
## Citation
If you use this code, please cite the following paper:
```bibtex
@article{gao2025word,
  title={Word-Level Maximum Mean Discrepancy Regularization for Word Embedding},
  author={Gao, Youqian and Dai, Ben},
  journal={Journal of the American Statistical Association},
  number={just-accepted},
  pages={1--22},
  year={2025},
  publisher={Taylor \& Francis}
}
```
