MMEngine
OpenMMLab Foundational Library for Training Deep Learning Models
Introduction | Installation | Get Started | 📘Documentation | 🤔Reporting Issues
</div> <div align="center">English | 简体中文
</div>

What's New
v0.10.6 was released on 2025-01-13.
Highlights:
- Support custom `artifact_location` in MLflowVisBackend #1505
- Enable `exclude_frozen_parameters` for `DeepSpeedEngine._zero3_consolidated_16bit_state_dict` #1517
Read Changelog for more details.
Introduction
MMEngine is a foundational library for training deep learning models based on PyTorch. It serves as the training engine of all OpenMMLab codebases, which support hundreds of algorithms in various research areas. MMEngine is also generic enough to be used in non-OpenMMLab projects. Its highlights are as follows:
- Integrates mainstream large-scale model training frameworks
- Supports a variety of training strategies
- Provides a user-friendly configuration system
  - Pure Python-style configuration files, easy to navigate
  - Plain-text-style configuration files, supporting JSON and YAML
- Covers mainstream training monitoring platforms
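As a rough illustration of the plain-text option, the sketch below writes a small JSON configuration and reads it back using only the standard library. The key names (`model`, `train_dataloader`, `batch_size`) and the file name are hypothetical and not a schema required by MMEngine. With MMEngine installed, such a file would typically be loaded through `mmengine.config.Config.fromfile` rather than `json.loads`.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical settings; the key names are illustrative only.
settings = {
    "model": {"type": "MMResNet50"},
    "train_dataloader": {"batch_size": 32, "shuffle": True},
}


def write_and_load(cfg: dict) -> dict:
    """Round-trip a configuration through a JSON file on disk."""
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "resnet50_cifar10.json"  # hypothetical file name
        path.write_text(json.dumps(cfg, indent=2))
        return json.loads(path.read_text())


loaded = write_and_load(settings)
print(loaded["train_dataloader"]["batch_size"])
```

Keeping settings in a file like this separates experiment parameters from training code, so the same script can be reused across runs.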
Installation
<details>
<summary>Supported PyTorch Versions</summary>

| MMEngine          | PyTorch      | Python        |
| ----------------- | ------------ | ------------- |
| main              | >=1.6, <=2.1 | >=3.8, <=3.11 |
| >=0.9.0, <=0.10.4 | >=1.6, <=2.1 | >=3.8, <=3.11 |

</details>

Before installing MMEngine, please ensure that PyTorch has been successfully installed following the official guide.
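The Python range listed in the compatibility table can also be checked from a script before installing. This is a minimal sketch using only the standard library. The helper name `python_supported` is hypothetical, and the bounds are copied from the table rather than queried from MMEngine.

```python
import sys

# Bounds copied from the compatibility table: Python >=3.8, <=3.11.
MIN_PYTHON = (3, 8)
MAX_PYTHON = (3, 11)


def python_supported(version=None):
    """Return True if the (major, minor) version lies within the listed range."""
    major, minor = (version or sys.version_info)[:2]
    return MIN_PYTHON <= (major, minor) <= MAX_PYTHON


# Check the interpreter that is running this script.
print(python_supported())
```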
Install MMEngine
pip install -U openmim
mim install mmengine
Verify the installation
python -c 'from mmengine.utils.dl_utils import collect_env;print(collect_env())'
Get Started
Taking the training of a ResNet-50 model on the CIFAR-10 dataset as an example, we will use MMEngine to build a complete, configurable training and validation process in fewer than 80 lines of code.
<details>
<summary>Build Models</summary>

First, we need to define a model that 1) inherits from `BaseModel` and 2) accepts an additional `mode` argument in its `forward` method, alongside the arguments that come from the dataset.

- During training, the value of `mode` is "loss", and the `forward` method should return a `dict` containing the key "loss".
- During validation, the value of `mode` is "predict", and the `forward` method should return results containing both predictions and labels.

import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel


class MMResNet50(BaseModel):

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels
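
The `mode` contract can be tried out without PyTorch. The sketch below is a pure-Python stand-in, not MMEngine code: `logits` plays the role of the network output, `cross_entropy` is a hand-written per-sample version averaged over the batch, and the explicit `ValueError` for unknown modes is an addition of this sketch that the model above does not have.

```python
import math


def cross_entropy(logits, label):
    """Cross-entropy of one sample: -log(softmax(logits)[label])."""
    peak = max(logits)  # subtract the max for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return log_sum_exp - logits[label]


def forward(logits, labels, mode):
    """Stand-in for a model's forward method that dispatches on `mode`."""
    if mode == "loss":
        losses = [cross_entropy(row, y) for row, y in zip(logits, labels)]
        return {"loss": sum(losses) / len(losses)}  # mean over the batch
    if mode == "predict":
        return logits, labels
    raise ValueError(f"unsupported mode: {mode!r}")


batch_logits = [[2.0, 0.5, 0.1], [0.2, 0.2, 3.0]]
batch_labels = [0, 2]
print(forward(batch_logits, batch_labels, "loss"))
print(forward(batch_logits, batch_labels, "predict"))
```

Failing loudly on an unexpected mode avoids silently returning `None`, which can otherwise surface later as a confusing error.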
</details>
<details>
<summary>Build Datasets</summary>
Next, we need to create Datasets and DataLoaders for training and validation. In this case, we simply use built-in datasets supported in TorchVision.
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Per-channel mean and standard deviation used to normalize the images
norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])

# Training set: random crop and horizontal flip for augmentation
train_dataloader = DataLoader(batch_size=32,
                              shuffle=True,
                              dataset=torchvision.datasets.CIFAR10(
                                  'data/cifar10',
                                  train=True,
                                  download=True,
                                  transform=transforms.Compose([
                                      transforms.RandomCrop(32, padding=4),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize(**norm_cfg)
                                  ])))

# Validation set: no augmentation, only tensor conversion and normalization
val_dataloader = DataLoader(batch_size=32,
                            shuffle=False,
                            dataset=torchvision.datasets.CIFAR10(
                                'data/cifar10',
                                train=False,
                                download=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize(**norm_cfg)
                                ])))
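
To make the effect of `norm_cfg` concrete, the sketch below applies the same per-channel arithmetic that `transforms.Normalize` performs, `(value - mean) / std`, to a single RGB pixel in plain Python. The helper `normalize_pixel` is hypothetical and exists only for illustration. It assumes channel values already scaled to [0, 1], as `transforms.ToTensor` produces.

```python
# Per-channel statistics copied from norm_cfg above.
MEAN = [0.491, 0.482, 0.447]
STD = [0.202, 0.199, 0.201]


def normalize_pixel(rgb):
    """Apply (value - mean) / std to one RGB pixel with channel values in [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, MEAN, STD)]


# A pixel equal to the channel means maps to zero in every channel.
print(normalize_pixel(MEAN))
# A pure white pixel ends up a few standard deviations above zero.
print(normalize_pixel([1.0, 1.0, 1.0]))
```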
</details>
<details>
<summary>Build Metrics</summary>
To validate and test the model, we need to define a metric called `Accuracy`. It must inherit from `BaseMetric` and implement the `process` and `compute_metrics` methods.
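
The split between the two methods can be sketched without MMEngine: one method runs once per batch and stores a small summary, and the other runs once at the end and reduces those summaries to the final value. `ToyAccuracy` below is a hypothetical stand-in, not `BaseMetric`. Its `process` signature is simplified to take predicted class indices and labels directly, and the `correct` key is an assumption of this sketch.

```python
class ToyAccuracy:
    """Stand-in that mimics the process / compute_metrics split."""

    def __init__(self):
        self.results = []  # one entry per processed batch

    def process(self, predictions, labels):
        """Called once per batch: store only what the final metric needs."""
        correct = sum(int(p == y) for p, y in zip(predictions, labels))
        self.results.append({"batch_size": len(labels), "correct": correct})

    def compute_metrics(self, results):
        """Called once after all batches: reduce the stored entries to a dict."""
        total = sum(r["batch_size"] for r in results)
        hits = sum(r["correct"] for r in results)
        return {"accuracy": 100 * hits / total}


metric = ToyAccuracy()
metric.process([1, 0, 2], [1, 0, 0])  # 2 of 3 correct
metric.process([2, 2], [2, 2])        # 2 of 2 correct
print(metric.compute_metrics(metric.results))
```

Storing counts per batch instead of raw predictions keeps memory use small and makes the final reduction a simple sum.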
from mmengine.evaluator import BaseMetric


class Accuracy(BaseMetric):

    def process(self, data_batch, data_samples):
        score, gt = data_samples
        # Save the results of a batch to `self.results`
        self.results.append({
            'batch_size'
