Multi GPU training with PyTorch

Training deep learning models involves a huge amount of numerical calculations that can, to a great extent, be performed in parallel. Since GPUs offer far more cores (>10k) than CPUs (<= 64), GPUs outperform CPUs in most deep learning applications many times over.


The next level of performance is reached by scaling the calculations across multiple GPUs. For this reason, AIME servers can be equipped with up to eight high performance GPUs.

To utilize the full power of the AIME machines, it is important to ensure that all installed GPUs participate effectively in the deep learning training.
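
A quick way to verify how many GPUs PyTorch can access is to query the CUDA device count and device names, as in this small sketch using the standard torch.cuda calls:

import torch

# All installed GPUs should show up here before starting a multi GPU training.
num_gpus = torch.cuda.device_count()
print(f'PyTorch sees {num_gpus} GPU(s):')
for i in range(num_gpus):
    print(f'  cuda:{i} -> {torch.cuda.get_device_name(i)}')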

The following article explains how to train a model with the PyTorch framework using multiple GPUs. The first part covers a simple but not optimal approach using PyTorch's DataParallel. The second part explains a more advanced solution with improved performance that uses multiple processes via DistributedDataParallel.

Multi GPU training in a single process (DataParallel)

The easiest way to utilize all installed GPUs with PyTorch is the built-in function DataParallel from the PyTorch module torch.nn.parallel. It works almost the same way as a single GPU training. After your model is initialized, just wrap it with the following line:

model = torch.nn.parallel.DataParallel(model, device_ids=list(range(<num_gpus>)), dim=0)

with <num_gpus> as the number of GPUs to use.

Be aware that the batch size passed to your data loader is the global batch size over all GPUs. So if you want a certain local batch size per GPU, you need to multiply it by the number of GPUs.
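
As a minimal, self-contained sketch of this relation (the local batch size of 64 and the dummy tensor dataset are purely illustrative placeholders):

import torch
from torch.utils.data import DataLoader, TensorDataset

num_gpus = max(1, torch.cuda.device_count())
local_batch_size = 64                            # desired batch size per GPU (illustrative)
global_batch_size = local_batch_size * num_gpus  # batch size handed to the DataLoader

# dummy dataset, only to keep the sketch self-contained
dataset = TensorDataset(torch.randn(1024, 3, 224, 224),
                        torch.randint(0, 1000, (1024,)))

dataloader = DataLoader(
    dataset=dataset,
    batch_size=global_batch_size,  # DataParallel splits each batch across the GPUs along dim 0
    num_workers=4 * num_gpus,
)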

Here is a fully working example of multi GPU training with a resnet50 model from the torchvision library using DataParallel.

#!/usr/bin/env python3

from pathlib import Path
import torch
import torchvision


def load_data(num_gpus):
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # set root to the location of your image dataset (ImageFolder directory structure)
    dataset = torchvision.datasets.ImageFolder(root='<image_destination>', transform=transforms)

    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=64,
        shuffle=False,
        num_workers=4*num_gpus
    )
    return dataloader

def save_model(epoch, model, optimizer):
    """Saves model checkpoint on given epoch with given data name.
    """
    checkpoint_folder = Path.cwd() / 'model_checkpoints'
    if not checkpoint_folder.is_dir():
        checkpoint_folder.mkdir()
    file = checkpoint_folder / f'epoch_{epoch}.pt'
    if not file.is_file():
        file.touch()
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        },
        file
    )
    return True

def load_model(epoch, model, optimizer):
    """Loads model state from file.
    """
    file = Path.cwd() / 'model_checkpoints' / f'epoch_{epoch}.pt'
    checkpoint = torch.load(file)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return model, optimizer

def run_training(num_gpus):

    model = torchvision.models.resnet50(pretrained=False)
    model = model.cuda()
    model = torch.nn.parallel.DataParallel(model, device_ids=list(range(num_gpus)), dim=0)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
    criterion = torch.nn.CrossEntropyLoss()
    criterion.cuda()
    model.train()
    num_epochs = 30
    dataloader = load_data(num_gpus)
    total_steps = len(dataloader)
    for epoch in range(1, num_epochs + 1):
        print(f'\nEpoch {epoch}\n')
        if epoch > 1:
            model, optimizer = load_model(epoch-1, model, optimizer)
        for step, (images, labels) in enumerate(dataloader, 1):
            images, labels = images.cuda(), labels.cuda()
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            if step % 10 == 0:
                print(f'Epoch [{epoch} / {num_epochs}], Step [{step} / {total_steps}], Loss: {loss.item():.4f}')
        save_model(epoch, model, optimizer)

if __name__ == "__main__":
    num_gpus = torch.cuda.device_count()
    print('num_gpus: ', num_gpus)
    run_training(num_gpus)

While DataParallel is quite easy to use and the performance gain over a single GPU is already visible, there are ways to improve performance further: with DataParallel all calculations happen in a single process, so the GPUs are not utilized to their full capacity. For even better performance we need multiple processes, one process per GPU. This can be done with DistributedDataParallel, explained in the following section. The table below compares the training performance of both methods, measured with our benchmark tool https://github.com/aime-team/pytorch-benchmarks. The training performance of DistributedDataParallel is up to 17% better than that of DataParallel.

Number of GPUs        Images/s with DataParallel    Images/s with DistributedDataParallel
1x NVIDIA RTX 3090    473                           -
2x NVIDIA RTX 3090    883                           944
4x NVIDIA RTX 3090    1526                          1788

Multi GPU training with multiple processes (DistributedDataParallel)

The PyTorch built-in function DistributedDataParallel from the module torch.nn.parallel distributes the training over all GPUs with one subprocess per GPU, utilizing each GPU to its full capacity. Compared to DataParallel, some additional steps are necessary. First of all, we need to set the environment variables for the master address and master port with the following lines:

os.environ['MASTER_ADDR'] = 'localhost'  
os.environ['MASTER_PORT'] = '12355'

The master port can be changed to an arbitrary free port number.

Then we need to spawn multiple processes, one for each GPU, with the spawn method from the module torch.multiprocessing:

torch.multiprocessing.spawn(  
    run_training_process_on_given_gpu, 
    args=(args, ), 
    nprocs=<num_gpus>, 
    join=True)

The function run_training_process_on_given_gpu has to contain the entire training code to be run on each GPU, and args is a tuple of its arguments (without the rank). nprocs is the number of processes to be spawned (i.e. the number of GPUs). In the implementation of run_training_process_on_given_gpu, the first positional argument has to be the rank of the process. The spawn method starts nprocs processes and automatically fills in the rank for each of them, starting at 0 and counting up. In each process, a process group now has to be initialized with the method init_process_group from the module torch.distributed:

torch.distributed.init_process_group(backend=<backend>, rank=rank, world_size=<num_gpus>, init_method='env://')  

Common values for <backend> are 'gloo' and 'nccl', where 'nccl' is recommended for multi GPU training. More details about backends for distributed training can be found at https://pytorch.org/docs/stable/distributed.html. The world_size here is equal to the number of GPUs. The initialization method 'env://' pulls all information it needs from the environment, i.e. the MASTER_ADDR and MASTER_PORT variables set above. The rank is the rank of the process, set by the spawn method for each GPU.
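
Putting these pieces together, the skeleton of a spawned worker process could look like the following sketch; the training steps themselves are omitted here and filled in over the next paragraphs, and the structure matches the full code summary at the end of this article:

import os
import torch


def run_training_process_on_given_gpu(rank, num_gpus):
    # rank is filled in automatically by torch.multiprocessing.spawn (0 .. nprocs-1)
    torch.distributed.init_process_group(
        backend='nccl',        # recommended backend for multi GPU training
        rank=rank,             # rank of this process / GPU
        world_size=num_gpus,   # total number of processes
        init_method='env://',  # reads MASTER_ADDR / MASTER_PORT from the environment
    )
    torch.cuda.set_device(rank)
    # ... model, dataloader and training loop for this rank go here ...
    torch.distributed.destroy_process_group()


if __name__ == '__main__':
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    num_gpus = torch.cuda.device_count()
    torch.multiprocessing.spawn(run_training_process_on_given_gpu,
                                args=(num_gpus,), nprocs=num_gpus, join=True)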

The following steps need to be done in each process with the given rank. First, move your model and your initialized loss function (criterion) to the memory of the GPU assigned to the process:

torch.cuda.set_device(rank)
model.cuda(rank)
criterion = torch.nn.CrossEntropyLoss()
criterion.cuda(rank)

Then prepare the model for distributed training by wrapping it with DistributedDataParallel:

model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

The next step is to prepare the dataloader for distributed training. First, define the transformations applied to the input data:

transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
])

Then initialize the dataset, e.g. with the ImageFolder class from the module torchvision.datasets:

dataset = torchvision.datasets.ImageFolder(root=<image_destination>, transform=transforms)

Then we need to initialize the DistributedSampler from the module torch.utils.data.distributed:

sampler = torch.utils.data.distributed.DistributedSampler(dataset)

Now we can initialize the dataloader from the module torch.utils.data with the given dataset and sampler:

dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=<batch_size>, shuffle=False, num_workers=4*<num_gpus>, pin_memory=True, sampler=sampler)

To get the best performance, it is recommended to use four times the number of participating GPUs as the number of workers. Another performance boost can be achieved by setting pin_memory=True in the dataloader in combination with non_blocking=True when moving the data to GPU memory with the cuda() call (see below). In contrast to a single GPU or DataParallel training, the shuffle argument needs to be set to False, since the sampler already takes care of shuffling the data. The batch_size here is the local batch size of each GPU, not the global batch size as in DataParallel. Now the model and the dataloader are ready for the distributed training.
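
One detail worth noting, which is not part of the code shown in this article: for the DistributedSampler to reshuffle the data differently in each epoch, its set_epoch() method should be called at the beginning of every epoch, for example:

# Sketch, reusing the sampler and dataloader initialized above; num_epochs is
# the total number of training epochs. Without the set_epoch() call the
# DistributedSampler yields the same shuffling order in every epoch.
for epoch in range(1, num_epochs + 1):
    sampler.set_epoch(epoch)
    for step, (data, label) in enumerate(dataloader, 1):
        ...  # training step, as shown below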

The last modification for multi GPU training happens in the training loop. In each step, the data needs to be moved to the memory of the GPU with the cuda() call:

for step, (data, label) in enumerate(dataloader):  
    data, label = data.cuda(rank, non_blocking=True), label.cuda(rank, non_blocking=True)

If the argument non_blocking is True, the cuda() call returns immediately without waiting for the transfer to GPU memory to finish, so the process can continue with the next command. pin_memory=True means the data loader places the batches in page-locked (pinned) host memory, from which transfers to the GPU are faster and can run asynchronously. The combination gives another performance boost at the cost of higher memory consumption.

Saving the model

To save the trained model to disk, just add the following command:

if rank == 0:
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        },
        file
    )

with file being a string or PosixPath containing the location of the checkpoint file. It is sufficient to save the checkpoint from one process only (here rank 0), since all processes hold identical copies of the model; saving it from every process would be redundant.

Loading the model

To load a saved checkpoint of your model to all participating processes, use the following commands:

torch.distributed.barrier()  
map_location = {'cuda:0': f'cuda:{rank}'}  
checkpoint = torch.load(file, map_location=map_location)  
model.load_state_dict(checkpoint['model_state_dict'])  
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

with file being a string or PosixPath containing the location of the checkpoint file. The torch.distributed.barrier() call ensures the processes are synchronized, so that no process tries to read the checkpoint before the saving process has finished writing it, and map_location maps the tensors saved from GPU 0 to the GPU of the current process.

Code Summary

As a summary, here is an example of a fully working multi GPU training with a resnet50 model from the torchvision library using DistributedDataParallel:

#!/usr/bin/env python3

import os
from pathlib import Path
import torch
import torchvision


def save_model(epoch, model, optimizer):
    """Saves model checkpoint on given epoch with given data name.
    """
    checkpoint_folder = Path.cwd() / 'model_checkpoints'
    if not checkpoint_folder.is_dir():
        checkpoint_folder.mkdir()
    file = checkpoint_folder / f'epoch_{epoch}.pt'
    if not file.is_file():
        file.touch()
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        },
        file
    )
    return True

def load_model(rank, epoch, model, optimizer):
    """Loads model state from file to the GPU with given rank.
    """
    torch.distributed.barrier()
    map_location = {'cuda:0': f'cuda:{rank}'}
    file = Path.cwd() / 'model_checkpoints' / f'epoch_{epoch}.pt'
    checkpoint = torch.load(file, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return model, optimizer

def load_data(num_gpus):
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # set root to the location of your image dataset (ImageFolder directory structure)
    dataset = torchvision.datasets.ImageFolder(root='<image_destination>', transform=transforms)
    sampler = torch.utils.data.distributed.DistributedSampler(dataset)

    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=64,
        shuffle=False,
        num_workers=4*num_gpus,
        pin_memory=True,
        sampler=sampler
    )
    return dataloader


def run_training_process_on_given_gpu(rank, num_gpus):
    torch.cuda.set_device(rank)
    torch.distributed.init_process_group(backend='nccl', rank=rank,
                    world_size=num_gpus, init_method='env://')
    model = torchvision.models.resnet50(pretrained=False)
    model = model.cuda(rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
    criterion = torch.nn.CrossEntropyLoss()
    criterion.cuda(rank)
    model.train()
    num_epochs = 30
    dataloader = load_data(num_gpus)
    total_steps = len(dataloader)
    for epoch in range(1, num_epochs + 1):
        if rank == 0:
            print(f'\nEpoch {epoch}\n')
        if epoch > 1:
            model, optimizer = load_model(rank, epoch-1, model, optimizer)

        for step, (images, labels) in enumerate(dataloader, 1):
            images, labels = images.cuda(rank, non_blocking=True), labels.cuda(rank, non_blocking=True)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            if step % 10 == 0:
                if rank == 0:
                    print(f'Epoch [{epoch} / {num_epochs}], Step [{step} / {total_steps}], Loss: {loss.item():.4f}')
        if rank == 0:
            save_model(epoch, model, optimizer)
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    num_gpus = torch.cuda.device_count()
    print('num_gpus: ', num_gpus)
    torch.multiprocessing.spawn(run_training_process_on_given_gpu, args=(num_gpus, ), nprocs=num_gpus, join=True)

As shown, to utilize the full capacity of all participating GPUs in PyTorch, it is necessary to use the module DistributedDataParallel. With our benchmark tool https://github.com/aime-team/pytorch-benchmarks, which contains the code for both methods, you can compare the performance gain yourself.
