This note discusses how to build parallelism support for model training.

There are many approaches to distributed training. They largely fall into two categories: synchronous training and asynchronous training.

Synchronous Training

In synchronous training, all trainer nodes proceed in lockstep: every node finishes its step and exchanges data with the others before any node moves on to the next step.

Data Parallelism

Data parallelism is the simplest way to do distributed model training. Each trainer holds a full replica of the model. The input batch is sharded along the batch dimension, each trainer runs the forward and backward passes on its own shard, and an all-reduce then averages the gradients across trainers so that every replica applies the same weight update. The following diagram shows the idea.
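Here is a minimal sketch of one data-parallel step, assuming a toy scalar linear model with mean-squared-error loss and equal-sized shards. The names (`grad_mse`, `shard`, `all_reduce_mean`, `data_parallel_step`) are hypothetical, and the all-reduce is simulated inside one process. It shows why averaging the per-shard gradients gives the same update as training on the whole batch on a single node.

```python
from typing import List, Tuple

Sample = Tuple[float, float]  # (x, y) pair for the toy model y ≈ w * x


def grad_mse(w: float, batch: List[Sample]) -> float:
    """Gradient of mean((w*x - y)**2) over the batch w.r.t. the scalar weight w."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def shard(batch: List[Sample], num_workers: int) -> List[List[Sample]]:
    """Split the batch along the batch dimension into equal contiguous shards."""
    assert len(batch) % num_workers == 0, "equal shard sizes keep the average exact"
    size = len(batch) // num_workers
    return [batch[i * size:(i + 1) * size] for i in range(num_workers)]


def all_reduce_mean(values: List[float]) -> List[float]:
    """Simulated all-reduce: every worker ends up with the mean of all workers' values."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


def data_parallel_step(w: float, batch: List[Sample], num_workers: int, lr: float) -> List[float]:
    """One synchronous step; returns the updated weight held by each replica."""
    local_grads = [grad_mse(w, s) for s in shard(batch, num_workers)]  # each worker, own shard
    synced = all_reduce_mean(local_grads)                                # identical on every worker
    return [w - lr * g for g in synced]                                  # every replica updates the same way


if __name__ == "__main__":
    batch = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9), (4.0, 8.2)]
    print(data_parallel_step(0.5, batch, num_workers=2, lr=0.01))
```

Every replica ends the step with the same weight, and that weight matches a single-node step on the full batch up to floating-point rounding.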

To implement it, we deploy the same program onto every node, and each node runs it on its own shard of the data. This programming model is called SPMD (Single Program, Multiple Data). The model code itself is unchanged; only the training loop changes, gaining a few extra steps such as an all-reduce that averages the gradients before the weight update.
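The SPMD structure can be sketched with threads standing in for nodes. Every thread runs the same `train` function, and only `rank` differs, which selects the data shard. `AllReduce`, `train`, and `launch` are hypothetical names, and `AllReduce` is a toy in-process stand-in built on `threading.Barrier`; a real deployment would use a collective-communication library instead.

```python
import threading
from typing import List, Tuple


class AllReduce:
    """Toy in-process all-reduce (mean) for `world_size` threads, built on a barrier."""

    def __init__(self, world_size: int):
        self.world_size = world_size
        self._slots = [0.0] * world_size
        self._barrier = threading.Barrier(world_size)

    def mean(self, rank: int, value: float) -> float:
        self._slots[rank] = value
        self._barrier.wait()  # wait until every rank has written its value
        result = sum(self._slots) / self.world_size
        self._barrier.wait()  # wait until every rank has read before slots are reused
        return result


def train(rank: int, world_size: int, data: List[Tuple[float, float]],
          comm: AllReduce, results: List[float], steps: int = 50, lr: float = 0.05) -> None:
    """The single program: identical on every rank except for which shard it reads."""
    my_shard = data[rank::world_size]
    w = 0.0  # same initial weight on every rank
    for _ in range(steps):
        local = sum(2.0 * (w * x - y) * x for x, y in my_shard) / len(my_shard)
        g = comm.mean(rank, local)  # the extra step data parallelism adds
        w -= lr * g                 # identical update on every rank
    results[rank] = w


def launch(world_size: int = 4) -> List[float]:
    data = [(x / 10, 2.0 * x / 10 + 0.1) for x in range(1, 9)]
    comm = AllReduce(world_size)
    results = [0.0] * world_size
    threads = [threading.Thread(target=train, args=(r, world_size, data, comm, results))
               for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


if __name__ == "__main__":
    print(launch())
```

The loop is an ordinary gradient-descent loop plus one call to `comm.mean`, and all ranks finish with the same weight.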

Model Parallelism

Data parallelism works in most cases. However, because every trainer holds a full copy of the model, it breaks down when the model is too large to fit in one device's memory. In that case the model itself has to be split across multiple devices, which is called model parallelism.
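Here is a back-of-the-envelope check of whether a model fits. It assumes fp32 weights trained with Adam: 4 bytes each for the weight, its gradient, and Adam's two moment estimates, so 16 bytes per parameter. It ignores activations, which add more. The function names, the 7B-parameter model, and the 80 GB device are illustrative assumptions, not numbers from the text.

```python
def training_memory_gb(num_params: float, bytes_per_param: int = 16) -> float:
    """Rough lower bound on training memory (weights + grads + Adam states), in GB.

    Assumes fp32 + Adam: 4 B weight + 4 B gradient + 8 B for the two moments.
    Activation memory is NOT included.
    """
    return num_params * bytes_per_param / 1e9


def fits(num_params: float, device_memory_gb: float) -> bool:
    """Whether the parameter-related state alone fits on one device."""
    return training_memory_gb(num_params) <= device_memory_gb


if __name__ == "__main__":
    print(training_memory_gb(7e9), fits(7e9, 80.0))
```

Under these assumptions a 7B-parameter model needs about 112 GB before activations, so it cannot be trained by plain data parallelism on a single 80 GB device.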

There are two main approaches to model parallelism.

Tensor Parallelism

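Tensor parallelism is a very elegant solution. It shards individual tensors (weights and activations) across devices, and in its compiler-driven form it doesn't require you to change the code of your model: all you need to do is define the sharding strategy of your tensors.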
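To make "sharding a tensor" concrete, here is a minimal pure-Python sketch of one common strategy: splitting a weight matrix by columns across devices. The function names are hypothetical, and the "devices" are just list entries; a real system would place each shard on its own accelerator and gather the outputs with a collective.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain dense matmul: (m x k) @ (k x n) -> (m x n)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def split_columns(w: Matrix, n: int) -> List[Matrix]:
    """Shard W along its output (column) dimension into n equal slices."""
    cols = len(w[0])
    assert cols % n == 0, "columns must divide evenly across devices"
    size = cols // n
    return [[row[d * size:(d + 1) * size] for row in w] for d in range(n)]


def concat_columns(blocks: List[Matrix]) -> Matrix:
    """Stitch per-device output slices back together along the column dimension."""
    return [[v for blk in blocks for v in blk[i]] for i in range(len(blocks[0]))]


def column_parallel_matmul(x: Matrix, w: Matrix, num_devices: int) -> Matrix:
    shards = split_columns(w, num_devices)     # device d stores only its slice of W's columns
    outs = [matmul(x, w_d) for w_d in shards]  # X is replicated; no communication needed here
    return concat_columns(outs)                # an all-gather in a real system


if __name__ == "__main__":
    x = [[1, 2], [3, 4]]
    w = [[1, 2, 3, 4], [5, 6, 7, 8]]
    print(column_parallel_matmul(x, w, 2) == matmul(x, w))
```

Each device only ever stores a quarter or half of `W`, yet the stitched result equals the unsharded product.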

This approach is described in GSPMD, a paper from Google Research. You annotate how a few tensors (for example, the inputs and the weights) are sharded, and the compiler runs your model under that assumption: it propagates the sharding layout through the rest of the computation graph and inserts the communication that the layouts require.
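A toy illustration of what "propagating the layout" means for a single matmul `Y[m, n] = X[m, k] @ W[k, n]`. It assumes a spec is a tuple naming the device-mesh axis that shards each dimension, with `None` meaning replicated. The function `propagate_matmul` and the axis names `"data"` and `"model"` are hypothetical; a real compiler handles every op type and can reshard instead of raising.

```python
from typing import Optional, Tuple

Spec = Tuple[Optional[str], Optional[str]]  # mesh axis per tensor dim; None = replicated


def propagate_matmul(x_spec: Spec, w_spec: Spec) -> Tuple[Spec, Optional[str]]:
    """Toy sharding rule for Y[m, n] = X[m, k] @ W[k, n].

    Returns the output spec and the mesh axis (if any) that needs an all-reduce:
    when the contracted dimension k is sharded, each device holds only a partial sum.
    """
    (m_axis, xk_axis), (wk_axis, n_axis) = x_spec, w_spec
    if xk_axis != wk_axis:
        raise ValueError("k is sharded differently on X and W; a real compiler would reshard")
    if m_axis is not None and m_axis == n_axis:
        raise ValueError("one mesh axis cannot shard two output dimensions")
    return (m_axis, n_axis), xk_axis


if __name__ == "__main__":
    print(propagate_matmul(("data", None), (None, None)))    # batch-sharded input
    print(propagate_matmul((None, None), (None, "model")))   # column-sharded weight
    print(propagate_matmul((None, "model"), ("model", None)))  # row-sharded weight
```

The third case is where communication appears: sharding the contracted dimension leaves partial sums on each device, so the compiler must insert an all-reduce over that mesh axis.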

Data parallelism can be viewed as a special case of tensor parallelism: the inputs and activations are sharded on the batch dimension, while the weights are replicated.
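The special case can be checked directly on a single linear layer. This pure-Python sketch assumes a loss whose gradient with respect to the layer output is `dY`, and uses hypothetical helper names. Sharding `X` by rows needs no communication in the forward pass. The weight gradient `X^T @ dY`, however, contracts over the batch dimension, so each device holds only a partial sum, and summing those partials is exactly the gradient all-reduce of data parallelism. Whether the all-reduce sums or averages depends on how the loss is normalized.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def split_rows(a: Matrix, n: int) -> List[Matrix]:
    """Shard along the batch (row) dimension into n equal pieces."""
    assert len(a) % n == 0, "batch must divide evenly across devices"
    size = len(a) // n
    return [a[d * size:(d + 1) * size] for d in range(n)]


def all_reduce_sum(parts: List[Matrix]) -> Matrix:
    """Simulated all-reduce: element-wise sum of every device's partial result."""
    return [[sum(p[i][j] for p in parts) for j in range(len(parts[0][0]))]
            for i in range(len(parts[0]))]


def batch_sharded_forward(x: Matrix, w: Matrix, n: int) -> Matrix:
    # W is replicated; each device multiplies only its rows of X. No communication.
    outs = [matmul(x_d, w) for x_d in split_rows(x, n)]
    return [row for out in outs for row in out]  # concatenated here only to compare


def batch_sharded_weight_grad(x: Matrix, dy: Matrix, n: int) -> Matrix:
    # dL/dW = X^T @ dY sums over the batch dimension, so each device's result is
    # only a partial sum; the all-reduce combines them into the full gradient.
    partials = [matmul(transpose(x_d), dy_d)
                for x_d, dy_d in zip(split_rows(x, n), split_rows(dy, n))]
    return all_reduce_sum(partials)
```

Running both functions on small integer matrices reproduces `X @ W` and `X^T @ dY` exactly.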

Model Pipeline

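Another way to do model parallelism is pipelining: split the model graph into stages (for example, groups of consecutive layers) and place each stage on a different node. Activations flow forward from stage to stage, and gradients flow backward. To keep stages from sitting idle, the batch is usually split into micro-batches so that different stages can work on different micro-batches at the same time. This requires model developers to change the model code to partition it into stages, which is more complicated from the model engineer's perspective.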
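Here is a sketch of how micro-batches flow through pipeline stages. It assumes a GPipe-style fill-and-drain schedule for the forward pass only, leaving out the backward pass and other schedules such as 1F1B. The function names are hypothetical. The idle slots at the start and end of the schedule are the pipeline "bubble".

```python
from typing import List, Tuple


def gpipe_forward_schedule(num_stages: int, num_micro: int) -> List[List[Tuple[int, int]]]:
    """For each time step, the (stage, micro_batch) pairs that run concurrently.

    Stage s can process micro-batch m only after stage s-1 has finished it,
    so in a fill-and-drain schedule stage s handles micro-batch m at step s + m.
    """
    steps = num_stages + num_micro - 1
    return [[(s, t - s) for s in range(num_stages) if 0 <= t - s < num_micro]
            for t in range(steps)]


def bubble_fraction(num_stages: int, num_micro: int) -> float:
    """Fraction of (stage, time-step) slots left idle during the forward pass."""
    total = num_stages * (num_stages + num_micro - 1)
    busy = num_stages * num_micro
    return 1 - busy / total


if __name__ == "__main__":
    for t, work in enumerate(gpipe_forward_schedule(num_stages=3, num_micro=4)):
        print(f"step {t}: {work}")
    print(f"bubble: {bubble_fraction(3, 4):.2f}")
```

The idle share works out to (S - 1) / (M + S - 1) for S stages and M micro-batches. With 4 stages and 8 micro-batches it is 3/11, about 0.27, so splitting the batch into more micro-batches shrinks the bubble.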

Asynchronous Training

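Asynchronous training is another way to train a model. In asynchronous training, workers do not wait for each other: each worker trains at its own pace without synchronizing with the other workers, so its updates may be computed from slightly stale weights.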

Parameter Server

Parameter Server is an asynchronous training strategy. It was first introduced by this paper. The idea is to place the model variables on dedicated parameter servers. Each worker pulls the current variables, runs the forward and backward passes on its own data, and pushes the resulting gradients (or weight updates) back to the parameter servers, which apply them to the stored variables.
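Here is a toy, single-process sketch of the pull/compute/push loop. It assumes a one-parameter linear model fitted to y = 3x and uses threads to stand in for worker nodes. `ParameterServer`, `worker`, and `train_async` are hypothetical names, and a single lock-protected float plays the role of the server-side variables. Workers never wait for each other, so a worker may push a gradient computed from a weight that another worker has already updated.

```python
import random
import threading
from typing import List, Tuple


class ParameterServer:
    """Holds the shared (toy, scalar) weight; workers pull it and push gradients."""

    def __init__(self, w0: float, lr: float):
        self._w = w0
        self._lr = lr
        self._lock = threading.Lock()

    def pull(self) -> float:
        with self._lock:
            return self._w

    def push(self, grad: float) -> None:
        # Apply each gradient as soon as it arrives; no waiting for other workers.
        with self._lock:
            self._w -= self._lr * grad


def worker(ps: ParameterServer, shard: List[Tuple[float, float]], steps: int, seed: int) -> None:
    rng = random.Random(seed)
    for _ in range(steps):
        w = ps.pull()                 # may already be stale by the time we push
        x, y = rng.choice(shard)
        grad = 2.0 * (w * x - y) * x  # d/dw of (w*x - y)**2
        ps.push(grad)


def train_async(num_workers: int = 4, steps: int = 300) -> float:
    # Toy data for y = 3x, split into one shard per worker.
    data = [(x / 10, 3.0 * x / 10) for x in range(1, 13)]
    shards = [data[i::num_workers] for i in range(num_workers)]
    ps = ParameterServer(w0=0.0, lr=0.05)
    threads = [threading.Thread(target=worker, args=(ps, shards[i], steps, i))
               for i in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return ps.pull()


if __name__ == "__main__":
    print(f"learned w = {train_async():.4f}")
```

Unlike the synchronous sketch, there is no barrier here. The interleaving of pulls and pushes varies from run to run, but on this easy problem the weight still converges to 3.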