Giter VIP home page Giter VIP logo

apex's Introduction

Introduction

This repository holds NVIDIA-maintained utilities to streamline mixed precision and distributed training in Pytorch. Some of the code here will be included in upstream Pytorch eventually. The intent of Apex is to make up-to-date utilities available to users as quickly as possible.

Full API Documentation: https://nvidia.github.io/apex

Contents

1. Amp: Automatic Mixed Precision

apex.amp is a tool to enable mixed precision training by changing only 3 lines of your script. Users can easily experiment with different pure and mixed precision training modes by supplying different flags to amp.initialize.

Webinar introducing Amp (The flag cast_batchnorm has been renamed to keep_batchnorm_fp32).

API Documentation

Comprehensive Imagenet example

DCGAN example coming soon...

Moving to the new Amp API (for users of the deprecated "Amp" and "FP16_Optimizer" APIs)

2. Distributed Training

apex.parallel.DistributedDataParallel is a module wrapper, similar to torch.nn.parallel.DistributedDataParallel. It enables convenient multiprocess distributed training, optimized for NVIDIA's NCCL communication library.

API Documentation

Python Source

Example/Walkthrough

The Imagenet example shows use of apex.parallel.DistributedDataParallel along with apex.amp.

Synchronized Batch Normalization

apex.parallel.SyncBatchNorm extends torch.nn.modules.batchnorm._BatchNorm to support synchronized BN. It allreduces stats across processes during multiprocess (DistributedDataParallel) training. Synchronous BN has been used in cases where only a small local minibatch can fit on each GPU. Allreduced stats increase the effective batch size for the BN layer to the global batch size across all processes (which, technically, is the correct formulation). Synchronous BN has been observed to improve converged accuracy in some of our research models.

Checkpointing

To properly save and load your amp training, we introduce the amp.state_dict(), which contains all loss_scalers and their corresponding unskipped steps, as well as amp.load_state_dict() to restore these attributes.

In order to get bitwise accuracy, we recommend the following workflow:

# Initialization
opt_level = 'O1'
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)

# Train your model
...
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
...

# Save checkpoint
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...

# Restore
model = ...
optimizer = ...
checkpoint = torch.load('amp_checkpoint.pt')

model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])

# Continue training
...

Note that we recommend restoring the model using the same opt_level. Also note that we recommend calling the load_state_dict methods after amp.initialize.

Installation

Containers

NVIDIA PyTorch Containers are available on NGC: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch. The containers come with all the custom extensions available at the moment.

See the NGC documentation for details such as:

  • how to pull a container
  • how to run a pulled container
  • release notes

From Source

To install Apex from source, we recommend using the nightly Pytorch obtainable from https://github.com/pytorch/pytorch.

The latest stable release obtainable from https://pytorch.org should also work.

Linux

For performance and full functionality, we recommend installing Apex with CUDA and C++ extensions via

git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

Apex also supports a Python-only build via

pip install -v --disable-pip-version-check --no-cache-dir ./

A Python-only build omits:

  • Fused kernels required to use apex.optimizers.FusedAdam.
  • Fused kernels required to use apex.normalization.FusedLayerNorm and apex.normalization.FusedRMSNorm.
  • Fused kernels that improve the performance and numerical stability of apex.parallel.SyncBatchNorm.
  • Fused kernels that improve the performance of apex.parallel.DistributedDataParallel and apex.amp. DistributedDataParallel, amp, and SyncBatchNorm will still be usable, but they may be slower.

Pyprof support has been moved to its own dedicated repository. Pyprof is deprecated in Apex and the pyprof directory will be removed by the end of June 2022.

[Experimental] Windows

pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . may work if you were able to build Pytorch from source on your system. A Python-only build via pip install -v --no-cache-dir . is more likely to work.
If you installed Pytorch in a Conda environment, make sure to install Apex in that same environment.

apex's People

Contributors

a-maci avatar alpha0422 avatar carlc-nv avatar cbcase avatar chochowski avatar crcrpar avatar csarofeen avatar definitelynotmcarilli avatar ekrimer avatar eqy avatar fdecayed avatar henrymai avatar jjsjann123 avatar kevinstephano avatar kexinyu avatar mcarilli avatar mkolod avatar nanz-nv avatar ngimel avatar ptrblck avatar raulpuric avatar romerojosh avatar rxy1212 avatar schetlur avatar seryilmaz avatar slayton58 avatar syed-ahmed avatar thorjohnsen avatar xwang233 avatar yjk21 avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.