MATCHA
Communication-efficient decentralized SGD (PyTorch)
MATCHA: Communication-Efficient Decentralized SGD
Code to reproduce the experiments reported in this paper:
Jianyu Wang, Anit Kumar Sahu, Zhouyi Yang, Gauri Joshi, Soummya Kar, "MATCHA: Speeding Up Decentralized SGD via Matching Decomposition Sampling," arXiv preprint, 2019.
A short version appeared at FL-NeurIPS'19 and received the Distinguished Student Paper Award.
This repo contains implementations of MATCHA and D-PSGD for arbitrary node topologies. You can also use it to develop other decentralized training methods. Please cite this paper if you use this code in your research or projects.
Dependencies and Setup
The code runs on Python 3.5 with PyTorch 1.0.0 and torchvision 0.2.1. Peer-to-peer communication among workers is implemented with the mpi4py sendrecv function.
Training examples
Here is an example of how to use MATCHA to train a neural network.
import util
from graph_manager import FixedProcessor, MatchaProcessor
from communicator import decenCommunicator, ChocoCommunicator, centralizedCommunicator
# Define the base node topology by giving the graph ID
# There are six pre-defined graphs in util.py
base_graph = util.select_graph(args.graphid)
# Preprocess the base topology: 1) decompose it into matchings;
# 2) get activation probabilities for matchings;
# 3) compute the mixing weight;
# 4) generate activation flags for each iteration
# All of this information is stored in GP
GP = MatchaProcessor(base_graph,
                     commBudget=args.budget,
                     rank=rank,
                     size=size,
                     iterations=args.epoch * num_batches,
                     issubgraph=True)
# Define the communicator
communicator = decenCommunicator(rank, size, GP)
# Start training
for batch_id, (data, label) in enumerate(data_loader):
    # same as serial training
    output = model(data)  # forward
    loss = criterion(output, label)
    loss.backward()  # backward
    optimizer.step()  # gradient step
    optimizer.zero_grad()
    # additional line to average local models at workers
    communicator.communicate(model)
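The preprocessing steps listed in the comments of the example (decomposing the topology into matchings and generating per-iteration activation flags) can be illustrated with a small standalone sketch. This is not the repository's MatchaProcessor: greedy_matchings and sample_activation_flags are hypothetical helpers, the decomposition is a simple greedy heuristic, and a single uniform activation probability is assumed in place of the optimized per-matching probabilities.

```python
import random


def greedy_matchings(edges):
    """Split an undirected edge list into disjoint matchings (greedy heuristic).

    Hypothetical helper for illustration; within each matching, every node
    appears in at most one edge.
    """
    remaining = list(edges)
    matchings = []
    while remaining:
        used_nodes = set()
        matching, leftover = [], []
        for u, v in remaining:
            if u in used_nodes or v in used_nodes:
                leftover.append((u, v))  # conflicts with this matching, retry later
            else:
                matching.append((u, v))
                used_nodes.update((u, v))
        matchings.append(matching)
        remaining = leftover
    return matchings


def sample_activation_flags(num_matchings, prob, iterations, seed=0):
    """Draw one 0/1 flag per matching per iteration.

    Assumption: every matching shares the same activation probability.
    """
    rng = random.Random(seed)
    return [[1 if rng.random() < prob else 0 for _ in range(num_matchings)]
            for _ in range(iterations)]


# Hypothetical 4-node ring topology
ring = [(0, 1), (1, 2), (2, 3), (3, 0)]
matchings = greedy_matchings(ring)
flags = sample_activation_flags(len(matchings), prob=0.5, iterations=5)
```

In this sketch, a lower prob means fewer matchings are active per iteration, which is the intuition behind trading communication for computation with a budget.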
To use D-PSGD, change MatchaProcessor to FixedProcessor. Similarly, to use ChocoSGD, change MatchaProcessor to FixedProcessor and decenCommunicator to ChocoCommunicator. To run fully synchronous SGD, use centralizedCommunicator; in that case there is no need to define a graph processor.
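To make the contrast between decentralized and fully synchronous averaging concrete, here is a single-process sketch that uses one scalar per worker instead of a model. It is an illustration under stated assumptions, not the repository's communicators: mix_over_matchings and global_average are hypothetical names, alpha is an assumed mixing weight, and the actual code exchanges model parameters between MPI workers.

```python
def mix_over_matchings(values, matchings, flags, alpha):
    """One consensus step over the active matchings.

    Each active edge (u, v) moves both endpoints toward each other by
    alpha times their difference, computed from the pre-step values.
    """
    new_values = list(values)
    for matching, active in zip(matchings, flags):
        if not active:
            continue  # this matching is skipped in the current iteration
        for u, v in matching:
            diff = values[v] - values[u]
            new_values[u] += alpha * diff
            new_values[v] -= alpha * diff
    return new_values


def global_average(values):
    """Fully synchronous averaging: every worker ends up with the mean."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


# Hypothetical local values on 4 workers and two matchings of a ring
values = [1.0, 2.0, 3.0, 4.0]
matchings = [[(0, 1), (2, 3)], [(1, 2), (3, 0)]]

# Only the first matching is active in this iteration
mixed = mix_over_matchings(values, matchings, flags=[1, 0], alpha=0.5)
synced = global_average(values)
```

Both steps preserve the sum of the workers' values; the decentralized step only brings neighbors closer together, while the synchronous step reaches exact agreement at the cost of global communication.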
In addition, before training starts, we need to initialize MPI processes on each worker machine as follows:
from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
The script can be run using the following command:
mpirun --hostfile c8 -np 8 python train_mpi.py
Citation
@article{wang2019matcha,
  title={{MATCHA}: Speeding Up Decentralized {SGD} via Matching Decomposition Sampling},
  author={Wang, Jianyu and Sahu, Anit Kumar and Yang, Zhouyi and Joshi, Gauri and Kar, Soummya},
  journal={arXiv preprint arXiv:1905.09435},
  year={2019}
}
