Torchfl
A Python library for rapid prototyping, experimenting, and logging of federated learning using state-of-the-art models and datasets. Built using PyTorch and PyTorch Lightning.
Table of Contents
- Key Features
- Installation
- Examples and Usage
- Available Models
- Available Datasets
- Contributing
- Citation
Key Features
- Python 3.6+ support. Built using torch-1.10.1, torchvision-0.11.2, and pytorch-lightning-1.5.7.
- Customizable implementations of state-of-the-art deep learning models that can be trained in federated or non-federated settings.
- Supports fine-tuning of pre-trained deep learning models, allowing faster training through transfer learning.
- PyTorch Lightning DataModule wrappers for the most commonly used datasets, reducing the boilerplate code needed before experiments.
- Datamodules and models are built bottom-up, which provides clean abstractions while still allowing customization.
- Provides implementations of federated learning (FL) samplers, aggregators, and wrappers to prototype FL experiments on the go.
- Backwards compatible with PyTorch Lightning's LightningDataModule, LightningModule, loggers, and DevOps tools.
- More details about the examples and usage can be found below.
- For more usage documentation, visit https://torchfl.readthedocs.io/.
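To give a rough idea of what the FL samplers and aggregators mentioned above do, here is a minimal, framework-free sketch of one common pairing: uniformly sampling a fraction of agents for each global round, then aggregating their parameters with federated averaging (FedAvg), weighting each agent by its number of local samples. This is an illustration under stated assumptions, not torchfl's implementation: the function names sample_agents and fedavg, and the representation of model parameters as plain dicts of float lists, are hypothetical.

```python
import random
from typing import Dict, List, Sequence

# Hypothetical parameter format (not torchfl's): parameter name -> flat list of floats.
Params = Dict[str, List[float]]


def sample_agents(agent_ids: Sequence[int], sampling_ratio: float, seed: int = 0) -> List[int]:
    """Uniformly sample a fraction of the agents for one round (always at least one)."""
    k = max(1, int(len(agent_ids) * sampling_ratio))
    rng = random.Random(seed)
    return sorted(rng.sample(list(agent_ids), k))


def fedavg(agent_params: List[Params], num_samples: List[int]) -> Params:
    """FedAvg: average each parameter, weighted by every agent's local sample count."""
    total = sum(num_samples)
    averaged: Params = {}
    for name in agent_params[0]:
        length = len(agent_params[0][name])
        averaged[name] = [
            sum(params[name][i] * n for params, n in zip(agent_params, num_samples)) / total
            for i in range(length)
        ]
    return averaged


# Usage: 10 agents with a sampling ratio of 0.5 -> 5 agents take part in the round.
print(len(sample_agents(range(10), sampling_ratio=0.5)))  # 5
# Two agents holding 100 and 300 samples: the second agent counts three times as much.
print(fedavg([{"w": [1.0, 2.0]}, {"w": [3.0, 4.0]}], num_samples=[100, 300]))  # {'w': [2.5, 3.5]}
```

Weighting by sample count means agents with larger shards pull the global model further; with equal shard sizes this reduces to a plain mean.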
Installation
Stable Release
As of now, torchfl is available on PyPI and can be installed using the following command in your terminal:
```console
$ pip install torchfl
```
This is the preferred method to install torchfl, since it installs the most recent stable release.
If you don't have pip installed, this Python installation guide can walk you through the process.
Examples and Usage
Although torchfl is primarily built for quick prototyping of federated learning experiments, its models, datasets, and abstractions can also speed up non-federated experiments. This section walks through examples in both settings.
Non-Federated Learning
At a high level, the following steps train a model in a non-federated setting. This example uses the MNIST split of the EMNIST datamodule and a densenet121 model.
-
Import the relevant modules.
```python
import os

import pytorch_lightning as pl
from pytorch_lightning.callbacks import (
    DeviceStatsMonitor,
    LearningRateMonitor,
    ModelCheckpoint,
    ModelSummary,
    ProgressBar,
    # ... add any other callbacks you need
)
from pytorch_lightning.loggers import TensorBoardLogger

from torchfl.datamodules.emnist import EMNISTDataModule
from torchfl.models.wrapper.emnist import MNISTEMNIST
```
For more details, view the full list of PyTorch Lightning callbacks and loggers on the official website.
-
Set up the PyTorch Lightning trainer.
```python
# experiment_name and checkpoint_save_path are user-defined strings
trainer = pl.Trainer(
    # ... other Trainer arguments
    logger=[
        TensorBoardLogger(
            name=experiment_name,
            save_dir=os.path.join(checkpoint_save_path, experiment_name),
        )
    ],
    callbacks=[
        # keep the weights of the checkpoint with the best validation accuracy
        ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
        LearningRateMonitor("epoch"),
        DeviceStatsMonitor(),
        ModelSummary(),
        ProgressBar(),
    ],
    # ...
)
```
More details about the PyTorch Lightning Trainer API can be found on their official website.
-
Prepare the dataset using the wrappers provided by torchfl.datamodules.
```python
datamodule = EMNISTDataModule(dataset_name="mnist")
datamodule.prepare_data()  # download the data if needed
datamodule.setup()  # build the train/val/test splits
```
-
Initialize the model using the wrappers provided by torchfl.models.wrappers.
```python
# check if the model can be loaded from a given checkpoint
if checkpoint_load_path and os.path.isfile(checkpoint_load_path):
    # load_from_checkpoint is a LightningModule classmethod that returns
    # a new model restored from the checkpoint
    model = MNISTEMNIST(
        "densenet121", "adam", {"lr": 0.001}
    ).load_from_checkpoint(checkpoint_load_path)
else:
    pl.seed_everything(42)  # make the fresh training run reproducible
    model = MNISTEMNIST("densenet121", "adam", {"lr": 0.001})
    trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())
```
-
Collect the results.
```python
# dataloaders= replaces the test_dataloaders= argument deprecated in PyTorch Lightning
val_result = trainer.test(
    model, dataloaders=datamodule.val_dataloader(), verbose=True
)
test_result = trainer.test(
    model, dataloaders=datamodule.test_dataloader(), verbose=True
)
```
-
The files for the experiment (model checkpoints and logger metadata) are stored under the default_root_dir argument given to the PyTorch Lightning Trainer object in Step 2. This experiment uses the TensorBoard logger. To view the logs (and related plots and metrics), find the TensorBoard log files, which the TensorBoardLogger writes under its save_dir. Upload the files to TensorBoard.dev following the instructions here. Once the log files are uploaded, a unique URL for your experiment is generated, which can be shared with ease. An example can be found here.
-
Note that torchfl is compatible with all the loggers supported by PyTorch Lightning. More information about the PyTorch Lightning loggers can be found here.
For full non-federated learning example scripts, check examples/trainers.
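Finding the TensorBoard log files by hand can be tedious once there are several runs and versions. As a convenience, a stdlib-only helper along these lines could collect them. It assumes only TensorBoard's standard event-file naming (files whose names start with events.out.tfevents); the function name find_tensorboard_event_files is hypothetical and not part of torchfl or PyTorch Lightning.

```python
import os
from typing import List


def find_tensorboard_event_files(log_root: str) -> List[str]:
    """Recursively collect TensorBoard event files under log_root.

    Hypothetical helper (not part of torchfl). TensorBoard event files have names
    starting with "events.out.tfevents"; each run or version directory usually
    holds one or more of them. A missing log_root simply yields an empty list.
    """
    matches = []
    for dirpath, _dirnames, filenames in os.walk(log_root):
        for filename in filenames:
            if filename.startswith("events.out.tfevents"):
                matches.append(os.path.join(dirpath, filename))
    return sorted(matches)
```

For the Step 2 setup, log_root would be os.path.join(checkpoint_save_path, experiment_name), the save_dir passed to TensorBoardLogger. The same directory can also be passed to tensorboard --logdir to browse the logs locally before uploading them.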
Federated Learning
At a high level, the following steps set up and run a federated learning experiment.
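The IID and non-IID shard distributions mentioned in the first step below can be illustrated with plain Python over sample indices. This is a hypothetical sketch, not torchfl's implementation: the function names are made up, and the non-IID variant uses one common label-skew scheme (sort the samples by label, cut them into equal shards, and give each agent a few shards), which may differ from what torchfl's non-IID dataloaders actually do.

```python
import random
from typing import Dict, List, Sequence


def iid_shards(num_samples: int, num_agents: int, seed: int = 42) -> Dict[int, List[int]]:
    """IID split: shuffle all sample indices and deal out equal-sized shards.

    Any remainder samples (num_samples % num_agents) are dropped.
    """
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    shard_size = num_samples // num_agents
    return {
        agent_id: indices[agent_id * shard_size:(agent_id + 1) * shard_size]
        for agent_id in range(num_agents)
    }


def non_iid_shards(
    labels: Sequence[int], num_agents: int, shards_per_agent: int = 2, seed: int = 42
) -> Dict[int, List[int]]:
    """Label-skewed split: sort indices by label, cut them into
    num_agents * shards_per_agent contiguous shards, and hand each agent
    shards_per_agent randomly chosen shards, so it only sees a few labels."""
    sorted_indices = sorted(range(len(labels)), key=lambda i: labels[i])
    num_shards = num_agents * shards_per_agent
    shard_size = len(labels) // num_shards
    shards = [sorted_indices[s * shard_size:(s + 1) * shard_size] for s in range(num_shards)]
    order = list(range(num_shards))
    random.Random(seed).shuffle(order)
    return {
        agent_id: [
            index
            for shard_id in order[agent_id * shards_per_agent:(agent_id + 1) * shards_per_agent]
            for index in shards[shard_id]
        ]
        for agent_id in range(num_agents)
    }


# Usage: 60,000 samples split IID across 10 agents.
shards = iid_shards(num_samples=60_000, num_agents=10)
print(len(shards), len(shards[0]))  # 10 6000
```

In the IID case every agent sees roughly the global label distribution; in the label-skewed case each agent's local data covers only a couple of classes, which is what makes non-IID federated training harder.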
- Pick a dataset and use the datamodules to create federated data shards with an IID or non-IID distribution.
```python
# fl_params is the FLParams object described in a later step
def get_datamodule() -> EMNISTDataModule:
    datamodule: EMNISTDataModule = EMNISTDataModule(
        dataset_name=SUPPORTED_DATASETS_TYPE.MNIST, train_batch_size=10
    )
    datamodule.prepare_data()
    datamodule.setup()
    return datamodule


# split the training data into one IID shard (a DataLoader) per agent
agent_data_shard_map = get_datamodule().federated_iid_dataloader(
    num_workers=fl_params.num_agents,
    workers_batch_size=fl_params.local_train_batch_size,
)
```
- Use the TorchFL
agents module and the models module to initialize the global model and the agents, and to distribute the models to the agents.
```python
from typing import Dict, List

from torch.utils.data import DataLoader


def initialize_agents(
    fl_params: FLParams, agent_data_shard_map: Dict[int, DataLoader]
) -> List[V1Agent]:
    """Initialize one agent per data shard, each with its own local model."""
    agents = []
    for agent_id in range(fl_params.num_agents):
        agent = V1Agent(
            id=agent_id,
            model=MNISTEMNIST(
                model_name=EMNIST_MODELS_ENUM.MOBILENETV3SMALL,
                optimizer_name=OPTIMIZERS_TYPE.ADAM,
                optimizer_hparams={"lr": 0.001},
                model_hparams={"pre_trained": True, "feature_extract": True},
                fl_hparams=fl_params,
            ),
            data_shard=agent_data_shard_map[agent_id],
        )
        agents.append(agent)
    return agents


# the global model uses the same architecture and hyperparameters as the agents
global_model = MNISTEMNIST(
    model_name=EMNIST_MODELS_ENUM.MOBILENETV3SMALL,
    optimizer_name=OPTIMIZERS_TYPE.ADAM,
    optimizer_hparams={"lr": 0.001},
    model_hparams={"pre_trained": True, "feature_extract": True},
    fl_hparams=fl_params,
)
all_agents = initialize_agents(fl_params, agent_data_shard_map)
```
- Initialize an
FLParams object with the desired FL hyperparameters and pass it to the Entrypoint object, which abstracts away the training.
```python
fl_params = FLParams(
    experiment_name="iid_mnist_fedavg_10_agents_5_sampled_50_epochs_mobilenetv3small_latest",
    num_agents=10,
    global_epochs=10,
    local_epochs=2,
    sampling_ratio=0.5,
)
```
entrypoint = Entrypoint( global_model=global_model, global_da
