Aggressive Quantization: How to run MNIST on a 4-bit Neural Net using PyTorch

Karanbir Chahal
3 min read · Oct 10, 2019


In the previous blog post we went through the process of quantizing our neural net and were able to run inference in 8 bits with next to no loss in accuracy. I bet the question “can we do better?” has crossed your mind. It crossed mine too, so I tried to find out how far we can push the quantization of our network.

As always, this blog post comes with the code linked in a Google Colab notebook, and I encourage everyone to look through it, since it is a very simple example of how to quantize a vanilla neural net. Best of all, it works. For the skeptic who doesn’t believe me (quite understandably), I have plotted the weights of the neural network in a Matplotlib histogram, so you can confirm that the quantization is being done properly.

Now on to the method.

We start with the same technique as in the earlier blog post, post-training quantization, and simply reduce the number of bits from 8 to 4. The accuracy plummets from 99% to 11%, which on a 10-class problem like MNIST is barely better than random guessing. Something is very clearly wrong. To remedy this we use a technique called quantization-aware training, which is exactly what it sounds like: we quantize the weights and activations of the neural network during training and let the network learn the new range of the weights. Once the network is aware of the quantization, it can adjust its weights to minimize the loss.

One might wonder what backpropagation through a quantized function looks like. Well, it’s as easy as pie. Rounding has a zero gradient almost everywhere, so instead of differentiating it we propagate the gradients straight through the function without changing them in any way. This is called a “straight-through” estimator. The code for it is given below:
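A minimal sketch of a straight-through estimator in PyTorch follows. The class name `FakeQuantOp` is the one used later in this post, but the body here is an assumption written for illustration: it takes the quantization range from the tensor's own min and max, and the `num_bits` argument and the `1e-8` guard are illustrative choices rather than the notebook's exact code.

```python
import torch


class FakeQuantOp(torch.autograd.Function):
    """Quantize-then-dequantize with a straight-through gradient (illustrative sketch)."""

    @staticmethod
    def forward(ctx, x, num_bits=4):
        # Map x onto 2**num_bits integer levels, then straight back to floats.
        qmin, qmax = 0.0, 2.0 ** num_bits - 1.0
        x_min, x_max = x.min(), x.max()
        scale = (x_max - x_min).clamp(min=1e-8) / (qmax - qmin)
        zero_point = (qmin - x_min / scale).round().clamp(qmin, qmax)
        q = (x / scale + zero_point).round().clamp(qmin, qmax)
        return scale * (q - zero_point)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: hand the incoming gradient on unchanged.
        # num_bits is not a tensor, so its gradient slot is None.
        return grad_output, None


# The forward values are quantized, yet the gradient is that of the identity.
w = torch.linspace(-1.0, 1.0, steps=50, requires_grad=True)
w_q = FakeQuantOp.apply(w, 4)
w_q.sum().backward()
print(torch.unique(w_q.detach()).numel(), "distinct values after 4-bit fake quantization")
print("gradient is all ones:", torch.equal(w.grad, torch.ones_like(w)))
```

Because `backward` ignores the rounding entirely, the optimizer keeps updating the underlying float weights, while the loss is always computed on their quantized values.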

The quantization and dequantization functions themselves can be found in the Colab notebook linked above.
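For readers who want the idea without opening the notebook, here is a hedged sketch of affine quantization and dequantization with a scale factor and zero point. The function names `calc_scale_zero_point`, `quantize_tensor` and `dequantize_tensor` are assumptions made for this example and may not match the notebook. It also compares the worst-case rounding error at 8 and 4 bits on random data.

```python
import torch


def calc_scale_zero_point(min_val, max_val, num_bits):
    """Scale and zero point that map [min_val, max_val] onto [0, 2**num_bits - 1]."""
    qmax = 2.0 ** num_bits - 1.0
    scale = max((max_val - min_val) / qmax, 1e-8)  # guard against a zero range
    zero_point = int(round(min(max(-min_val / scale, 0.0), qmax)))
    return scale, zero_point


def quantize_tensor(x, num_bits=4):
    """Return integer codes plus the scale and zero point needed to decode them."""
    scale, zero_point = calc_scale_zero_point(x.min().item(), x.max().item(), num_bits)
    q = torch.round(x / scale + zero_point).clamp(0, 2 ** num_bits - 1)
    return q.to(torch.uint8), scale, zero_point


def dequantize_tensor(q, scale, zero_point):
    """Map integer codes back to approximate float values."""
    return scale * (q.float() - zero_point)


torch.manual_seed(0)
x = torch.randn(1000)  # random stand-in for a weight tensor
errors = {}
for bits in (8, 4):
    q, scale, zero_point = quantize_tensor(x, num_bits=bits)
    errors[bits] = (dequantize_tensor(q, scale, zero_point) - x).abs().max().item()
    print(f"{bits} bits: {2 ** bits} levels, max abs error {errors[bits]:.4f}")
```

With only 16 levels to cover the same range, each 4-bit step is much coarser than an 8-bit step, which is one way to see why plain post-training quantization struggles here.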

Now we need to look at what the forward pass for this training looks like. Before I post the code for the train function, I would like to point out the problem with quantizing activations. Unlike our neural net weights, activations change over time as the weights change. That makes it difficult to get the min and max stats of each activation, which we need to calculate the scale factor and zero point for quantization. We could keep a running mean of the minimum and maximum activation in a layer, but empirically that does not perform well. At the start of training the weights change quickly, so the activation distributions change quickly too, and the stats calculated for the scale factor and zero point soon become outdated.

To combat this we use an exponential moving average (EMA) to calculate the min and max stats of an activation layer. The intuition behind the EMA is that it changes slowly and carries information over long periods of time. Recent jitter that occurs while the weights are moving around is damped rather than followed, and the whole history of an activation layer contributes to the stats. It has been observed that this stabilizes training when we perform quantization during training.
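The EMA bookkeeping described above can be sketched as follows. The helper name `update_ema_stats`, the layout of the `stats` dictionary, and the decay of 0.9 are all assumptions made for this example; the post does not state the values the notebook uses.

```python
import torch


def update_ema_stats(x, stats, key, decay=0.9):
    """Update exponential moving averages of an activation's min and max."""
    cur_min, cur_max = x.min().item(), x.max().item()
    if key not in stats:
        # First batch: initialise the averages with the observed values.
        stats[key] = {"ema_min": cur_min, "ema_max": cur_max}
    else:
        s = stats[key]
        s["ema_min"] = decay * s["ema_min"] + (1.0 - decay) * cur_min
        s["ema_max"] = decay * s["ema_max"] + (1.0 - decay) * cur_max
    return stats


stats = {}
# Five well-behaved batches with activations in [0, 1] ...
for _ in range(5):
    update_ema_stats(torch.tensor([0.0, 0.5, 1.0]), stats, "fc1")
# ... followed by one batch containing a large outlier.
update_ema_stats(torch.tensor([0.0, 0.5, 100.0]), stats, "fc1")
print(stats["fc1"])
```

In this toy run the tracked maximum moves only a tenth of the way toward the outlier, so a single unusual batch cannot blow up the scale factor.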

At the start of training we switch off activation quantization and only quantize the weights, through our FakeQuantOp, which uses a straight-through estimator for backpropagation. Meanwhile, the stats for the quantization of activations are tracked with exponential moving averages over 5 epochs of training. By the 5th epoch the activation stats are good enough to use, so we switch activation quantization on and train for a few more epochs.

The training loss goes up when we introduce activation quantization, which is perfectly natural. The network is then able to adjust its weights and return to a high accuracy: we get back our ~98% accuracy on MNIST. The forward pass of the quantization-aware training is given below:
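What follows is a hedged sketch of such a forward pass, not the notebook's code. It assumes a hypothetical two-layer `TinyNet` instead of the post's actual network, a hypothetical `qat_forward` function with an `act_quant` switch, an EMA decay of 0.9, and random tensors standing in for MNIST batches. Weights are always fake-quantized, activation stats are tracked from the first step, and activation quantization is switched on later (after one epoch in this short demo, where the post waits until the 5th).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FakeQuantOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, num_bits, min_val, max_val):
        qmax = 2.0 ** num_bits - 1.0
        # Fall back to the tensor's own range when no tracked stats are passed.
        min_val = x.min().item() if min_val is None else min_val
        max_val = x.max().item() if max_val is None else max_val
        scale = max((max_val - min_val) / qmax, 1e-8)
        zero_point = round(min(max(-min_val / scale, 0.0), qmax))
        q = torch.round(x / scale + zero_point).clamp(0.0, qmax)
        return scale * (q - zero_point)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: only the tensor input gets a gradient.
        return grad_output, None, None, None


class TinyNet(nn.Module):  # hypothetical stand-in for the post's network
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 64)
        self.fc2 = nn.Linear(64, 10)


def update_ema(stats, key, x, decay=0.9):
    cur = {"min": x.min().item(), "max": x.max().item()}
    old = stats.get(key, cur)
    stats[key] = {k: decay * old[k] + (1.0 - decay) * cur[k] for k in cur}


def qat_forward(model, x, stats, num_bits=4, act_quant=False):
    x = x.view(x.size(0), -1)
    # Weights are always fake-quantized; biases stay in float in this sketch.
    w1 = FakeQuantOp.apply(model.fc1.weight, num_bits, None, None)
    x = F.relu(F.linear(x, w1, model.fc1.bias))
    # Activation stats are tracked even while activation quantization is off.
    update_ema(stats, "fc1", x.detach())
    if act_quant:
        x = FakeQuantOp.apply(x, num_bits, stats["fc1"]["min"], stats["fc1"]["max"])
    w2 = FakeQuantOp.apply(model.fc2.weight, num_bits, None, None)
    return F.linear(x, w2, model.fc2.bias)


torch.manual_seed(0)
model = TinyNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
stats = {}
for epoch in range(2):
    act_quant = epoch >= 1  # the post switches this on at the 5th epoch
    images = torch.rand(8, 1, 28, 28)     # random stand-in for an MNIST batch
    labels = torch.randint(0, 10, (8,))
    logits = qat_forward(model, images, stats, act_quant=act_quant)
    loss = F.cross_entropy(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"epoch {epoch}: act_quant={act_quant}, loss={loss.item():.4f}")
```

Passing the EMA range into `FakeQuantOp` for activations, rather than each batch's own min and max, is what keeps the activation scale stable from one batch to the next.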

Please let me know if you have any questions, and again, look at the code to see a fully working simple example!
