# DIVA: A Dirichlet Process Mixtures Based Incremental Deep Clustering Algorithm via Variational Auto-Encoder
Official implementation of the paper *DIVA: A Dirichlet Process Based Incremental Deep Clustering Algorithm via Variational Auto-Encoder*.
<p align="center"> A demo video showing DIVA's dynamic adaptation ability in deep clustering. <a href="https://www.youtube.com/watch?v=uHPGAUSSbh8"> <img src="https://github.com/Ghiara/diva/blob/master/pretrained/poster.png" alt="Demo Video" width="100%"> </a> </p>

## Requirements
We use Python 3.7 and pytorch-lightning for training. Before starting training, make sure you have installed the `bnpy` package in your local environment; refer to the bnpy documentation for more details.
- python 3.7
- bnpy 1.7.0
- pytorch-lightning 1.9.4
- numpy, pandas, matplotlib, seaborn, torchvision
## Installation Instructions

```bash
# Install dependencies and package
pip3 install -r requirements.txt
```
## Detailed Code Structure Overview

```
DIVA
|- dataset                    # folder for saving datasets
|  |- reuters10k.py           # dataset instance of reuters10k that follows torchvision formatting
|  |- reuters10k.mat          # original data of reuters10k
|- pretrained                 # folder for saving a pretrained example model on MNIST
|  |- dpmm                    # folder for saving the DPMM cluster module
|  |- diva_vae.ckpt           # checkpoint of the DIVA VAE part trained on MNIST for 100 epochs (ACC 0.91)
|  |- pretrained.ipynb        # example notebook showing how to load the pretrained model
|- diva.py                    # DIVA implementations for image and text; train manager
|- main_mnist.ipynb           # main entry point of DIVA training on MNIST, including evaluation plots
|- main_stl10.ipynb           # main entry point of DIVA training on STL-10
|- main_imagenet50.ipynb      # main entry point of DIVA training on ImageNet-50
|- feature_extraction.ipynb   # script that uses a pretrained ResNet-50 to extract features from STL-10
```
## Dataset Notes
Since training on the raw images of STL-10 and ImageNet-50 is quite difficult, we use a feature extractor to obtain low-dimensional encodings of these datasets. For STL-10 we use the pretrained ResNet-50 provided by torchvision; simply follow the script `feature_extraction.ipynb` to obtain the features used in our study. For ImageNet-50 we use MoCo to extract features; for more details, refer to here and here.
## Load the Pretrained DPMM Clustering Module

```python
import bnpy
import numpy as np
import torch

# load the DPMM module saved by bnpy
dpmm_model = bnpy.ioutil.ModelReader.load_model_at_prefix(
    'path/to/your/bn_model/folder/dpmm', prefix="Best")

# function for getting the cluster parameters (mean and summed covariance per component)
def calc_cluster_component_params(bnp_model):
    comp_mu = [torch.Tensor(bnp_model.obsModel.get_mean_for_comp(i))
               for i in np.arange(0, bnp_model.obsModel.K)]
    comp_var = [torch.Tensor(np.sum(bnp_model.obsModel.get_covar_mat_for_comp(i), axis=0))
                for i in np.arange(0, bnp_model.obsModel.K)]
    return comp_mu, comp_var
```
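As an illustration of how the extracted parameters can be consumed, the sketch below samples one latent per mixture component via the reparameterisation trick. The component means and variances here are made-up tensors standing in for the output of `calc_cluster_component_params`; they are not produced by bnpy.

```python
import torch

# Hypothetical per-component parameters (placeholders for the bnpy-derived values).
comp_mu = [torch.zeros(10), torch.ones(10)]        # component means
comp_var = [0.5 * torch.ones(10), torch.ones(10)]  # component (diagonal) variances

# Draw one latent sample z = mu + sigma * eps for each component.
samples = [mu + var.sqrt() * torch.randn_like(mu)
           for mu, var in zip(comp_mu, comp_var)]
```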
## Citation

If you would like to refer to our work, please use the following BibTeX entry:

```bibtex
@misc{bing2023diva,
      title={DIVA: A Dirichlet Process Based Incremental Deep Clustering Algorithm via Variational Auto-Encoder},
      author={Zhenshan Bing and Yuan Meng and Yuqi Yun and Hang Su and Xiaojie Su and Kai Huang and Alois Knoll},
      year={2023},
      eprint={2305.14067},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```