Skada
Domain adaptation toolbox compatible with scikit-learn and pytorch
SKADA - Domain Adaptation with scikit-learn and PyTorch
SKADA is a library for domain adaptation (DA) with a scikit-learn and PyTorch/skorch compatible API with the following features:
- DA estimators and transformers with a scikit-learn compatible API (fit, transform, predict).
- PyTorch/skorch API for deep learning DA algorithms.
- Classifier/Regressor and data Adapter DA algorithms compatible with scikit-learn pipelines.
- Compatible with scikit-learn validation loops (cross_val_score, GridSearchCV, etc.).
Citation: If you use this library in your research, please cite the following reference:
Gnassounou T., Kachaiev O., Flamary R., Collas A., Lalou Y., de Mathelin A., Gramfort A., Bueno R., Michel F., Mellot A., Loison V., Odonnat A., Moreau T. (2024). SKADA : Scikit Adaptation (version 0.3.0). URL: https://scikit-adaptation.github.io/
or in BibTeX format:
```bibtex
@misc{gnassounou2024skada,
  author = {Gnassounou, Théo and Kachaiev, Oleksii and Flamary, Rémi and Collas, Antoine and Lalou, Yanis and de Mathelin, Antoine and Gramfort, Alexandre and Bueno, Ruben and Michel, Florent and Mellot, Apolline and Loison, Virginie and Odonnat, Ambroise and Moreau, Thomas},
  month = {7},
  title = {SKADA : Scikit Adaptation},
  url = {https://scikit-adaptation.github.io/},
  year = {2024}
}
```
Implemented algorithms
The following algorithms are currently implemented.
Domain adaptation algorithms
- Sample reweighting methods (Gaussian [1], Discriminant [2], KLIEPReweight [3], DensityRatio [4], TarS [21], KMMReweight [23])
- Sample mapping methods (CORAL [5], Optimal Transport DA OTDA [6], LinearMonge [7], LS-ConS [21])
- Subspace methods (SubspaceAlignment [8], TCA [9], Transfer Subspace Learning [27])
- Other methods (JDOT [10], DASVM [11], OT Label Propagation [28])
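As an illustration of the subspace family, the Subspace Alignment idea of Fernando et al. [8] fits a d-dimensional PCA basis in each domain and learns the linear map M = Ps^T Pt that aligns the source basis with the target one. The NumPy sketch below follows that recipe on centred data. The function names are hypothetical, it is not skada's SubspaceAlignment implementation, and details such as feature standardization or how d is chosen are left out.

```python
import numpy as np

def pca_basis(X, d):
    """Top-d principal directions of X, returned as the columns of a (D, d) matrix."""
    Xc = X - X.mean(axis=0)
    # Right singular vectors of the centred data are the principal axes
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Vt[:d].T

def subspace_alignment(Xs, Xt, d):
    """Project both domains to d dimensions, aligning the source basis to the target one."""
    Ps = pca_basis(Xs, d)                     # source basis, (D, d)
    Pt = pca_basis(Xt, d)                     # target basis, (D, d)
    M = Ps.T @ Pt                             # alignment matrix, (d, d)
    Zs = (Xs - Xs.mean(axis=0)) @ Ps @ M      # source in the aligned coordinates
    Zt = (Xt - Xt.mean(axis=0)) @ Pt          # target in its own PCA coordinates
    return Zs, Zt, M

rng = np.random.default_rng(0)
Xs = rng.normal(size=(200, 5))
Xt = rng.normal(size=(150, 5)) @ rng.normal(size=(5, 5))
Zs, Zt, M = subspace_alignment(Xs, Xt, d=2)
print(Zs.shape, Zt.shape, M.shape)  # (200, 2) (150, 2) (2, 2)
```

A classifier would then be trained on Zs with the source labels and applied to Zt.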
Any method that can be cast as an adaptation of the input data can be used in one of two ways:
- as a full Classifier/Regressor estimator,
- or as an Adapter (a scikit-learn transformer) that can be used in a DA pipeline with make_da_pipeline.
Refer to the examples below and visit the gallery for more details.
Deep learning domain adaptation algorithms
- Deep Correlation alignment (DeepCORAL [12])
- Deep joint distribution optimal transport (DeepJDOT [13])
- Divergence minimization (MMD/DAN [14])
- Adversarial/discriminator based DA (DANN [15], CDAN [16])
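Divergence-based methods such as DAN [14] add a term to the training loss that penalizes the discrepancy between source and target feature distributions, measured by the maximum mean discrepancy (MMD). The NumPy sketch below computes the biased empirical estimate of the squared MMD with a single Gaussian kernel on fixed feature arrays. DAN itself uses a multi-kernel variant on network activations inside PyTorch; the single kernel and the bandwidth sigma here are simplifying assumptions, and the helper names are hypothetical.

```python
import numpy as np

def gaussian_kernel(A, B, sigma=1.0):
    """k(a, b) = exp(-||a - b||^2 / (2 sigma^2)) for every pair of rows of A and B."""
    sq = (A ** 2).sum(axis=1)[:, None] + (B ** 2).sum(axis=1)[None, :] - 2.0 * A @ B.T
    sq = np.maximum(sq, 0.0)  # guard against tiny negative round-off
    return np.exp(-sq / (2.0 * sigma ** 2))

def mmd2_biased(Zs, Zt, sigma=1.0):
    """Biased (V-statistic) estimate of the squared MMD between two samples."""
    k_ss = gaussian_kernel(Zs, Zs, sigma).mean()
    k_tt = gaussian_kernel(Zt, Zt, sigma).mean()
    k_st = gaussian_kernel(Zs, Zt, sigma).mean()
    return float(k_ss + k_tt - 2.0 * k_st)

rng = np.random.default_rng(0)
Zs = rng.normal(size=(200, 2))
Zt_same = rng.normal(size=(200, 2))               # same distribution as Zs
Zt_shifted = rng.normal(loc=2.0, size=(200, 2))   # shifted distribution
print(mmd2_biased(Zs, Zt_same), mmd2_biased(Zs, Zt_shifted))
```

In a deep DA model this quantity would be computed on mini-batch features and added, with a trade-off weight, to the source classification loss.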
DA metrics
- Importance Weighted [17]
- Prediction entropy [18]
- Soft neighborhood density [19]
- Deep Embedded Validation (DEV) [20]
- Circular Validation [11]
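Two of the criteria above can be written in a few lines in their textbook form: the prediction entropy [18], averaged over predicted class probabilities on target data, and a self-normalized importance-weighted accuracy [17] on labelled source validation data. The sketch below assumes the importance weights w(x) ≈ p_target(x) / p_source(x) are already available. The function names are hypothetical and this is not how skada's scorers are implemented; note that scikit-learn scorers follow a "greater is better" convention, so an entropy would have to be negated to serve as a score.

```python
import numpy as np

def prediction_entropy(proba, eps=1e-12):
    """Mean Shannon entropy (in nats) of predicted class probabilities.

    Lower values mean more confident predictions on the target data.
    """
    p = np.clip(proba, eps, 1.0)
    return float(-(p * np.log(p)).sum(axis=1).mean())

def importance_weighted_accuracy(y_true, y_pred, weights):
    """Self-normalized importance-weighted accuracy on source validation data.

    weights[i] is assumed to approximate p_target(x_i) / p_source(x_i).
    """
    correct = (np.asarray(y_true) == np.asarray(y_pred)).astype(float)
    w = np.asarray(weights, dtype=float)
    return float(np.sum(w * correct) / np.sum(w))

uniform = np.full((4, 2), 0.5)                      # maximally uncertain
confident = np.array([[0.99, 0.01], [0.02, 0.98]])  # nearly one-hot
print(prediction_entropy(uniform), prediction_entropy(confident))

y_true = np.array([0, 1, 1, 0])
y_pred = np.array([0, 1, 0, 0])
weights = np.array([1.0, 1.0, 2.0, 0.0])
print(importance_weighted_accuracy(y_true, y_pred, weights))  # 0.5
```

The unweighted accuracy of this toy example is 0.75; the weighting lowers it because the misclassified sample is the one deemed most target-like.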
Installation
The library is not yet available on PyPI. You can install it from the source code.
```shell
pip install git+https://github.com/scikit-adaptation/skada
```
Short examples
We provide here a few examples to illustrate the use of the library. For more details, please refer to this example, the quick start guide and the gallery.
First, the DA data in the SKADA API is stored in the following format:
X, y, sample_domain
where X is the input data, y is the target labels and sample_domain is the
domain labels (positive for source and negative for target domains). We provide
below an example of how to fit a DA estimator:
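```python
from skada import CORAL

# X, y, sample_domain: data in the format described above
# Xt: target-domain samples to predict on
da = CORAL()
da.fit(X, y, sample_domain=sample_domain)  # sample_domain passed by name
ypred = da.predict(Xt)  # predict on test data
```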
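To see what the CORAL mapping does numerically, the NumPy sketch below builds a toy dataset in the X, y, sample_domain format, recovers the two domains from the sign of sample_domain, and applies the correlation alignment of Sun et al. [5]: whiten the source features with the regularized source covariance, then re-color them with the regularized target covariance. The domain ids 1 and -2, the -1 placeholder for unknown target labels, and the helper names are illustrative choices; the reg term mirrors the identity added in the original paper and is not necessarily how the reg parameter of skada's CORAL is defined.

```python
import numpy as np

def sqrtm_psd(C, power=0.5):
    """C ** power for a symmetric positive definite matrix, via eigendecomposition."""
    vals, vecs = np.linalg.eigh(C)
    return (vecs * vals ** power) @ vecs.T

def coral_map(Xs, Xt, reg=1.0):
    """Map source samples so that their covariance matches the target covariance."""
    d = Xs.shape[1]
    Cs = np.cov(Xs, rowvar=False) + reg * np.eye(d)
    Ct = np.cov(Xt, rowvar=False) + reg * np.eye(d)
    # Whitening with Cs^{-1/2}, then re-coloring with Ct^{1/2}
    return Xs @ sqrtm_psd(Cs, -0.5) @ sqrtm_psd(Ct, 0.5)

# Toy data in the (X, y, sample_domain) format
rng = np.random.default_rng(0)
Xs = rng.normal(size=(200, 3))
Xt = rng.normal(size=(150, 3)) @ np.diag([3.0, 1.0, 0.5])
ys = (Xs[:, 0] > 0).astype(int)
X = np.vstack([Xs, Xt])
y = np.concatenate([ys, np.full(len(Xt), -1)])  # -1: unknown target label
sample_domain = np.concatenate([np.full(len(Xs), 1), np.full(len(Xt), -2)])

# Positive ids are source domains, negative ids are target domains
is_source = sample_domain > 0
Xs_mapped = coral_map(X[is_source], X[~is_source], reg=0.0)
```

With reg=0 the covariance of the mapped source samples equals the empirical target covariance exactly; a positive reg gives up that exact match in exchange for better-conditioned matrices.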
One can also use Adapter classes to create a full pipeline with DA:
```python
from skada import CORALAdapter, make_da_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

pipe = make_da_pipeline(StandardScaler(), CORALAdapter(), LogisticRegression())
pipe.fit(X, y, sample_domain=sample_domain)  # sample_domain passed by name
```
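The Adapter pattern separates the domain-aware step from an ordinary estimator. The plain-NumPy sketch below mimics that split with two hypothetical classes: DomainMeanCenteringAdapter, which subtracts each domain's mean (a deliberately crude adaptation that only corrects a translation shift and assumes similar class proportions across domains), and a domain-agnostic nearest-centroid classifier. The chaining in fit_predict_da (adapter fitted on all domains, classifier fitted on source rows only) is an assumption made for illustration, not a description of how make_da_pipeline routes data internally.

```python
import numpy as np

class DomainMeanCenteringAdapter:
    """Hypothetical adapter: subtracts each domain's own mean from its samples."""

    def fit(self, X, sample_domain):
        # Unsupervised: uses the samples of every domain, no labels needed
        self.means_ = {dom: X[sample_domain == dom].mean(axis=0)
                       for dom in np.unique(sample_domain)}
        return self

    def transform(self, X, sample_domain):
        Xa = np.array(X, dtype=float)  # copy, so the input is left untouched
        for dom, mean in self.means_.items():
            Xa[sample_domain == dom] -= mean
        return Xa

class NearestCentroidClassifier:
    """Plain classifier with no notion of domains."""

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        self.centroids_ = np.stack([X[y == c].mean(axis=0) for c in self.classes_])
        return self

    def predict(self, X):
        dist = ((X[:, None, :] - self.centroids_[None, :, :]) ** 2).sum(axis=-1)
        return self.classes_[dist.argmin(axis=1)]

def fit_predict_da(X, y, sample_domain, X_new, domain_new):
    """Hypothetical chaining: adapt all domains, train on source rows only."""
    adapter = DomainMeanCenteringAdapter().fit(X, sample_domain)
    Xa = adapter.transform(X, sample_domain)
    src = sample_domain > 0
    clf = NearestCentroidClassifier().fit(Xa[src], y[src])
    return clf.predict(adapter.transform(X_new, domain_new))

def two_blobs(shift, rng, n=100):
    """Two Gaussian classes centred at (-1, 0) and (1, 0), translated by shift."""
    labels = np.repeat([0, 1], n)
    centers = np.array([[-1.0, 0.0], [1.0, 0.0]])
    return centers[labels] + np.asarray(shift) + 0.3 * rng.normal(size=(2 * n, 2)), labels

rng = np.random.default_rng(0)
Xs, ys = two_blobs([0.0, 0.0], rng)
Xt, yt = two_blobs([3.0, 3.0], rng)  # yt is only used to score the sketch
X = np.vstack([Xs, Xt])
y = np.concatenate([ys, np.full(len(yt), -1)])  # target labels hidden
sample_domain = np.concatenate([np.full(len(Xs), 1), np.full(len(Xt), -2)])

y_adapted = fit_predict_da(X, y, sample_domain, Xt, np.full(len(Xt), -2))
y_naive = NearestCentroidClassifier().fit(Xs, ys).predict(Xt)
acc_adapted = (y_adapted == yt).mean()
acc_naive = (y_naive == yt).mean()
print(acc_adapted, acc_naive)
```

Without adaptation, every shifted target sample lands closest to the same source centroid, so roughly half of them are misclassified.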
Please note that for Adapter classes that implement sample reweighting, the
subsequent classifier/regressor must request sample_weight as an input. This is
done with scikit-learn's set_fit_request method. For instance, with LogisticRegression, you
would use LogisticRegression().set_fit_request(sample_weight=True):
```python
from sklearn.linear_model import LogisticRegression
from skada import GaussianReweightAdapter, make_da_pipeline

pipe = make_da_pipeline(GaussianReweightAdapter(),
                        LogisticRegression().set_fit_request(sample_weight=True))
```
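To show the kind of quantity a reweighting adapter hands to the classifier, the NumPy sketch below computes importance weights w(x) = p_target(x) / p_source(x) for the source samples by fitting one Gaussian per domain, in the spirit of the Gaussian reweighting listed above [1]. The helper names, the covariance regularization reg and the normalization of the weights to mean 1 are assumptions of this sketch, not a description of GaussianReweightAdapter's internals.

```python
import numpy as np

def gaussian_logpdf(X, mean, cov):
    """Log-density of a multivariate normal evaluated at each row of X."""
    d = X.shape[1]
    diff = X - mean
    maha = np.sum(diff * np.linalg.solve(cov, diff.T).T, axis=1)  # Mahalanobis terms
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (maha + logdet + d * np.log(2.0 * np.pi))

def gaussian_importance_weights(Xs, Xt, reg=1e-6):
    """w(x) ~ p_target(x) / p_source(x) for each source sample, normalized to mean 1."""
    d = Xs.shape[1]
    Cs = np.cov(Xs, rowvar=False) + reg * np.eye(d)
    Ct = np.cov(Xt, rowvar=False) + reg * np.eye(d)
    log_w = (gaussian_logpdf(Xs, Xt.mean(axis=0), Ct)
             - gaussian_logpdf(Xs, Xs.mean(axis=0), Cs))
    w = np.exp(log_w - log_w.max())  # shift before exp to avoid overflow
    return w / w.mean()

rng = np.random.default_rng(0)
Xs = rng.normal(0.0, 1.0, size=(500, 2))
Xt = rng.normal(1.0, 1.0, size=(300, 2))
w = gaussian_importance_weights(Xs, Xt)
```

Source samples lying closer to the target mean receive larger weights. In a pipeline, such weights are forwarded as sample_weight to the classifier's fit, which is why the set_fit_request(sample_weight=True) call is required.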
Finally, SKADA can be used for cross-validation score estimation and hyperparameter selection:
```python
from sklearn.model_selection import cross_val_score, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from skada import CORALAdapter, make_da_pipeline
from skada.model_selection import SourceTargetShuffleSplit
from skada.metrics import PredictionEntropyScorer

# make pipeline
pipe = make_da_pipeline(StandardScaler(), CORALAdapter(), LogisticRegression())

# split and score
cv = SourceTargetShuffleSplit()
scorer = PredictionEntropyScorer()

# cross val score
scores = cross_val_score(pipe, X, y, params={'sample_domain': sample_domain},
                         cv=cv, scoring=scorer)

# grid search
param_grid = {'coraladapter__reg': [0.1, 0.5, 0.9]}
grid_search = GridSearchCV(estimator=pipe,
                           param_grid=param_grid,
                           cv=cv, scoring=scorer)
grid_search.fit(X, y, sample_domain=sample_domain)
```
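The cv object decides which rows of each domain are used for fitting and which for scoring. As a rough illustration only, the hypothetical helper below draws a random test fraction independently from every domain, so each split keeps source and target samples on both sides; it is not the algorithm implemented by SourceTargetShuffleSplit. scikit-learn also accepts an iterable of (train, test) index arrays as cv, so a list of such splits could be passed in its place.

```python
import numpy as np

def per_domain_shuffle_split(sample_domain, n_splits=3, test_size=0.3, seed=0):
    """Yield (train_idx, test_idx) with a random test fraction drawn from each domain."""
    rng = np.random.default_rng(seed)
    sample_domain = np.asarray(sample_domain)
    for _ in range(n_splits):
        train_parts, test_parts = [], []
        for dom in np.unique(sample_domain):
            idx = rng.permutation(np.flatnonzero(sample_domain == dom))
            n_test = int(round(test_size * len(idx)))
            test_parts.append(idx[:n_test])
            train_parts.append(idx[n_test:])
        yield np.concatenate(train_parts), np.concatenate(test_parts)

sample_domain = np.concatenate([np.full(100, 1), np.full(50, -2)])
splits = list(per_domain_shuffle_split(sample_domain))
for train_idx, test_idx in splits:
    print(len(train_idx), len(test_idx))  # 105 45
```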
Acknowledgements
This toolbox has been created and is maintained by the SKADA team that includes the following members:
- Théo Gnassounou
- Oleksii Kachaiev
- Rémi Flamary
- Antoine Collas
- Yanis Lalou
- Antoine de Mathelin
- Ruben Bueno
SKADA has benefited from funding or manpower provided by the following partners:
- ANR
- Hi!PARIS
- ELIAS European project
License
The library is distributed under the 3-Clause BSD license.
References
[1] Shimodaira Hidetoshi. "Improving predictive inference under covariate shift by weighting the log-likelihood function." Journal of statistical planning and inference 90, no. 2 (2000): 227-244.
[2] Sugiyama Masashi, Taiji Suzuki, and Takafumi Kanamori. "Density-ratio matching under the Bregman divergence: a unified framework of density-ratio estimation." Annals of the Institute of Statistical Mathematics 64 (2012): 1009-1044.
[3] Sugiyama Masashi, Taiji Suzuki, Shinichi Nakajima, Hisashi Kashima, Paul Von Bünau, and Motoaki Kawanabe. "Direct importance estimation for covariate shift adaptation." Annals of the Institute of Statistical Mathematics 60 (2008): 699-746.
[4] Sugiyama Masashi, and Klaus-Robert Müller. "Input-dependent estimation of generalization error under covariate shift." (2005): 249-279.
[5] Sun Baochen, Jiashi Feng, and Kate Saenko. "Correlation alignment for unsupervised domain adaptation." Domain adaptation in computer vision applications (2017): 153-171.
[6] Courty Nicolas, Flamary Rémi, Tuia Devis, and Alain Rakotomamonjy. "Optimal transport for domain adaptation." IEEE Trans. Pattern Anal. Mach. Intell 1, no. 1-40 (2016): 2.
[7] Flamary, R., Lounici, K., & Ferrari, A. (2019). Concentration bounds for linear monge mapping estimation and optimal transport domain adaptation. arXiv preprint arXiv:1905.10155.
[8] Fernando, B., Habrard, A., Sebban, M., & Tuytelaars, T. (2013). Unsupervised visual domain adaptation using subspace alignment. In Proceedings of the IEEE international conference on computer vision (pp. 2960-2967).
[9] Pan, S. J., Tsang, I. W., Kwok, J. T., & Yang, Q. (2010). Domain adaptation via transfer component analysis. IEEE transactions on neural networks, 22(2), 199-210.
[10] Courty, N., Flamary, R., Habrard, A., & Rakotomamonjy, A. (2017). Joint distribution optimal transportation for domain adaptation. Adv