How to Quantize an MNIST network to 8 bits in PyTorch from scratch (No retraining required)

Karanbir Chahal
Aug 4, 2019

Update: The blog post on quantization-aware training is now online and linked here. With it, we can train and quantize our model to run in 4 bits!

Hello! I want to share my journey of running inference of a neural network using fixed point (8-bit) arithmetic. As of today, PyTorch only allows 32-bit or 16-bit floating point training and inference. If you want to compress a PyTorch neural network using quantization today, you need to export it to ONNX, convert it to Caffe, and run the Glow quantization compiler over the computational graph, which finally yields a quantized network.

Before delving into how to quantize a network, let's look at why we need to. The simple answer is faster inference: floating point arithmetic usually takes longer to compute than fixed point (integer) arithmetic. An added advantage is space savings: a 32-bit floating point network is 4 times larger than an 8-bit quantized one. This is particularly relevant for edge devices (mobile phones, IoT devices), where low storage and compute demands are critical for a production-ready solution.
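The 4x figure follows directly from element sizes: a float32 value takes 4 bytes and a uint8 value takes 1. A quick check, using a hypothetical 500x800 weight matrix and ignoring the few extra bytes needed to store each tensor's scale and zero point:

```python
import torch

# Hypothetical weight matrix, e.g. for a 500x800 fully connected layer.
w_float = torch.randn(500, 800)                     # float32: 4 bytes per element
w_uint8 = torch.zeros(500, 800, dtype=torch.uint8)  # uint8: 1 byte per element

float_bytes = w_float.element_size() * w_float.nelement()
uint8_bytes = w_uint8.element_size() * w_uint8.nelement()
print(float_bytes, uint8_bytes, float_bytes / uint8_bytes)  # 1600000 400000 4.0
```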

Before moving forward, for those who just want to look at the code, here's a working Colab notebook to run and verify this quantized network. This example implements quantization from scratch in vanilla PyTorch (no external libraries or frameworks).

Now that we have justified the need to quantize, let's look at how to quantize a simple MNIST model. We'll use a simple architecture for solving MNIST that has 2 conv layers and 2 fully connected layers.
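The post's model code isn't reproduced here. As a minimal sketch, assuming the layer sizes of PyTorch's standard MNIST example (20 and 50 conv channels, 5x5 kernels, a 500-unit hidden layer), such a network might look like this. The exact sizes are an assumption; the original model may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Two conv layers followed by two fully connected layers (assumed sizes)."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5, stride=1)   # 28x28 -> 24x24
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5, stride=1)  # 12x12 -> 8x8
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)  # -> 20 x 12 x 12
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)  # -> 50 x 4 x 4
        x = F.relu(self.fc1(x.flatten(1)))              # -> 500
        return F.log_softmax(self.fc2(x), dim=1)        # log-probabilities over 10 digits


model = Net()
out = model(torch.randn(2, 1, 28, 28))  # out has shape (2, 10)
```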

Let's train this network using the simple training script given below:
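A minimal sketch of such a training script. It assumes SGD with lr=0.01 and momentum=0.5, a batch size of 64, and torchvision's MNIST dataset with the usual normalization constants; train() and main() are hypothetical stand-ins rather than the post's exact code. The block includes a Net with assumed layer sizes (2 conv + 2 fully connected, borrowed from PyTorch's MNIST example) so that it runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


class Net(nn.Module):
    # Assumed layer sizes borrowed from PyTorch's MNIST example.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
        x = F.relu(self.fc1(x.flatten(1)))
        return F.log_softmax(self.fc2(x), dim=1)


def train(model, device, train_loader, optimizer, epoch, log_interval=100):
    """Run one epoch of training and return the last batch's loss."""
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target)  # the model outputs log-probabilities
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print(f"Epoch {epoch} batch {batch_idx}: loss {loss.item():.4f}")
    return loss.item()


def main(epochs=10, batch_size=64, lr=0.01, momentum=0.5):
    """Train the assumed Net on MNIST (downloads the dataset to ./data)."""
    from torchvision import datasets, transforms  # only needed when loading MNIST

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),  # MNIST mean and std
    ])
    train_loader = DataLoader(
        datasets.MNIST("./data", train=True, download=True, transform=transform),
        batch_size=batch_size, shuffle=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Net().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch)
    return model
```

train() takes the loader and optimizer as arguments so that the same loop works on any dataset; main() only wires it up to MNIST.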

Now we can train this network with a simple `model = main()` call. Once the model has been trained for 10 epochs, let's test it with the following test function.
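A minimal sketch of such a test function, modeled on the evaluation loop in PyTorch's MNIST example. It is named evaluate() here (a hypothetical name). It sums the negative log-likelihood loss, counts correct argmax predictions, and returns the accuracy as a fraction.

```python
import torch
import torch.nn.functional as F


def evaluate(model, device, test_loader):
    """Evaluate a model that outputs log-probabilities; return accuracy in [0, 1]."""
    model.eval()
    test_loss, correct = 0.0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction="sum").item()
            correct += (output.argmax(dim=1) == target).sum().item()
    n = len(test_loader.dataset)
    print(f"Test set: average loss {test_loss / n:.4f}, "
          f"accuracy {correct}/{n} ({100.0 * correct / n:.2f}%)")
    return correct / n
```

Usage: `accuracy = evaluate(model, device, test_loader)`, where `test_loader` wraps the MNIST test split (10,000 images).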

Testing the model gives an accuracy of 99% (~9,900 of 10,000 test images correctly classified). Now let's look at quantizing it with a technique called post-training quantization.

The gist is that we convert the activations and the weights of the neural network to unsigned 8-bit integers (range 0 to 255). We then perform all arithmetic in fixed point and hope that there is no significant degradation in accuracy.

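To quantize and dequantize a tensor, we use the following formulas:

$$x_{\text{float}} = \text{scale} \cdot (x_{\text{quant}} - \text{zero\_point})$$

Hence,

$$x_{\text{quant}} = \operatorname{clamp}\left(\operatorname{round}\left(\frac{x_{\text{float}}}{\text{scale}}\right) + \text{zero\_point},\; q_{\min},\; q_{\max}\right)$$

The scale is given by

$$\text{scale} = \frac{\text{max\_val} - \text{min\_val}}{q_{\max} - q_{\min}}$$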

In these formulas, max_val and min_val are the maximum and minimum values of the tensor x, and q_min and q_max are the bounds of an unsigned 8-bit integer (0 and 255, respectively). The scale is the step size between adjacent quantized values, and the zero point is the integer offset that shifts [min_val, max_val] onto [q_min, q_max]. The dequantization and quantization functions given below make it clearer how a floating point tensor is converted to an 8-bit tensor and vice versa.
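A minimal sketch of those two functions, plus a helper that derives the scale and zero point from a tensor's range. The names (QTensor, calc_scale_zero_point, quantize_tensor, dequantize_tensor) are illustrative. The zero-point rule, zero_point = round(q_min - min_val / scale) clamped to [q_min, q_max], is an assumption that follows the usual asymmetric scheme (it maps min_val to q_min) and is not necessarily the post's exact code.

```python
from collections import namedtuple

import torch

# A quantized tensor: uint8 values plus the parameters needed to map them back.
QTensor = namedtuple("QTensor", ["tensor", "scale", "zero_point"])


def calc_scale_zero_point(min_val, max_val, num_bits=8):
    """Scale and integer zero point mapping [min_val, max_val] onto [0, 2**num_bits - 1].

    Assumes max_val > min_val and num_bits <= 8 (results are stored as uint8).
    """
    qmin, qmax = 0.0, 2.0 ** num_bits - 1.0
    scale = (max_val - min_val) / (qmax - qmin)
    # Choose the zero point so that min_val maps to qmin, then keep it in range.
    zero_point = qmin - min_val / scale
    zero_point = int(round(min(max(zero_point, qmin), qmax)))
    return scale, zero_point


def quantize_tensor(x, num_bits=8, min_val=None, max_val=None):
    """Quantize a float tensor to uint8, using its own range unless one is given."""
    if min_val is None or max_val is None:
        min_val, max_val = x.min().item(), x.max().item()
    qmax = 2.0 ** num_bits - 1.0
    scale, zero_point = calc_scale_zero_point(min_val, max_val, num_bits)
    q_x = (x / scale + zero_point).round().clamp(0.0, qmax)
    return QTensor(tensor=q_x.to(torch.uint8), scale=scale, zero_point=zero_point)


def dequantize_tensor(q_x):
    """Map a QTensor back to floating point: scale * (q - zero_point)."""
    return q_x.scale * (q_x.tensor.float() - q_x.zero_point)


x = torch.tensor([-1.0, 0.0, 0.5, 1.0])
q = quantize_tensor(x)
print(q.tensor, q.scale, q.zero_point)
print(dequantize_tensor(q))  # approximately recovers x
```

For a range that contains 0, real 0.0 maps exactly to zero_point, and each reconstructed value is within about scale / 2 of the original.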

It is important to note that the scale is a floating point number, whereas the zero point is an (8-bit) integer. However, modern implementations avoid multiplying by a floating point scale: they approximate the multiplier with a fixed-point integer multiplier followed by a bit shift, an approximation that has been shown to have a negligible effect on the accuracy of the network.
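To make the bit trick concrete: in the scheme described by Jacob et al. (listed below), the combined multiplier M = scale_x * scale_w / scale_out, which is typically between 0 and 1, is written as M = M0 * 2^(-n). Here M0 lies in [0.5, 1) and is stored as a 32-bit fixed-point integer, so rescaling becomes an integer multiply plus a rounding right shift. The following is a minimal pure-Python sketch of that idea. The function names are hypothetical. Real kernels such as gemmlowp's also handle saturation and int32 overflow, which this sketch ignores.

```python
import math


def quantize_multiplier(m):
    """Split a real multiplier 0 < m < 1 into a Q31 integer m0 and a right shift n.

    m is approximately m0 * 2**-(31 + n), with 2**30 <= m0 < 2**31.
    """
    if not 0.0 < m < 1.0:
        raise ValueError("expected 0 < m < 1")
    mantissa, exponent = math.frexp(m)  # m == mantissa * 2**exponent, 0.5 <= mantissa < 1
    m0 = round(mantissa * (1 << 31))
    if m0 == 1 << 31:  # rounding pushed the mantissa up to 1.0
        m0 //= 2
        exponent += 1
    return m0, -exponent


def apply_multiplier(acc, m0, shift):
    """Approximate round(acc * m) with one integer multiply and a rounding right shift."""
    total_shift = 31 + shift
    # Adding half the divisor before the (flooring) shift rounds to nearest.
    return (acc * m0 + (1 << (total_shift - 1))) >> total_shift


m = 0.0123  # illustrative value for scale_x * scale_w / scale_out
m0, shift = quantize_multiplier(m)
for acc in (-54321, 0, 999, 123456):
    print(acc, apply_multiplier(acc, m0, shift), acc * m)
```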

Now that we have these functions ready, we can go on to quantising our weights and activations by modifying the forward pass of our MNIST network. The modified forward pass looks something like this.
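The heart of such a forward pass is running a conv or linear layer on an 8-bit input. The following is a minimal sketch of one way to do that. It follows the integer scheme of Jacob et al. rather than reproducing the post's own quantizeLayer(). Weights are quantized to uint8, and the bias is quantized to a wide integer with scale scale_x * scale_w. The zero points are subtracted, the products are accumulated, and the result is rescaled into the output's 8-bit range, with ReLU fused in as a clamp at the output zero point. The name quantized_layer_forward and its parameters are hypothetical, and the layer is assumed to have a bias. So that F.conv2d and F.linear can be reused, the integers are held in float64 tensors, which represent every product and partial sum here exactly. The final rescale uses a float multiply instead of the fixed-point bit trick mentioned above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def calc_scale_zero_point(min_val, max_val, num_bits=8):
    """Scale and integer zero point mapping [min_val, max_val] onto [0, 2**num_bits - 1]."""
    qmin, qmax = 0.0, 2.0 ** num_bits - 1.0
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = int(round(min(max(qmin - min_val / scale, qmin), qmax)))
    return scale, zero_point


def quantize(x, scale, zero_point):
    """Quantize a float tensor to uint8 with a given scale and zero point."""
    return (x / scale + zero_point).round().clamp(0, 255).to(torch.uint8)


def quantized_layer_forward(q_x, scale_x, zp_x, layer, out_min, out_max, relu=False):
    """Run an nn.Conv2d or nn.Linear on a uint8 input; return (q_out, scale_out, zp_out).

    out_min / out_max are the expected range of the layer's float output,
    e.g. taken from calibration statistics.
    """
    w = layer.weight.detach()
    scale_w, zp_w = calc_scale_zero_point(w.min().item(), w.max().item())
    q_w = quantize(w, scale_w, zp_w)
    # Bias as a wide integer with scale scale_x * scale_w and zero point 0.
    q_b = torch.round(layer.bias.detach().double() / (scale_x * scale_w))

    # Subtract the zero points; float64 holds these integer values exactly.
    x_int = q_x.double() - zp_x
    w_int = q_w.double() - zp_w
    if isinstance(layer, nn.Conv2d):
        acc = F.conv2d(x_int, w_int, q_b, layer.stride, layer.padding,
                       layer.dilation, layer.groups)
    else:
        acc = F.linear(x_int, w_int, q_b)

    # Rescale the integer accumulator into the output's 8-bit range.
    scale_out, zp_out = calc_scale_zero_point(out_min, out_max)
    q_out = torch.round(acc * (scale_x * scale_w / scale_out)) + zp_out
    # With relu=True, clamping at zp_out (the code for real 0.0) acts as a fused ReLU.
    q_out = q_out.clamp(zp_out if relu else 0, 255).to(torch.uint8)
    return q_out, scale_out, zp_out
```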

Here, we quantize the activations before feeding them into the convolutional layer conv1, and we use a function called quantizeLayer() that takes in a conv or linear layer along with the activations and the scale and zero point of the quantized activation. quantizeLayer() then performs the layer's forward pass fully quantized. Please look at the code above if anything is unclear. You might wonder what the quantize_tensor_act() function does: it simply quantizes the activation x using the min and max values that x usually takes, which are estimated by running 1,000 examples through the network and averaging the results. It uses these statistics to calculate the scale and hence the zero point, both of which are needed to quantize a tensor. Now, let's put it all together, run the network with this new quantForward method, and check the final accuracy.
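The statistics-gathering step can be sketched with forward hooks. We run the float model over some calibration batches, record the min and max of the input and of each layer's output for every batch, and average them. gather_stats and its layer names are assumptions that match a Net with assumed layer sizes (2 conv + 2 fully connected, borrowed from PyTorch's MNIST example), defined here so the block runs on its own. These hooks record pre-ReLU outputs. The post's quantize_tensor_act() may collect its statistics differently, for example after the ReLUs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    # Assumed layer sizes borrowed from PyTorch's MNIST example.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
        x = F.relu(self.fc1(x.flatten(1)))
        return F.log_softmax(self.fc2(x), dim=1)


def gather_stats(model, batches, layer_names=("conv1", "conv2", "fc1", "fc2")):
    """Average per-batch min/max of the model input and of each named layer's output."""
    totals = {name: [0.0, 0.0] for name in ("input",) + tuple(layer_names)}

    def make_hook(name):
        def hook(module, inputs, output):
            totals[name][0] += output.min().item()
            totals[name][1] += output.max().item()
        return hook

    handles = [getattr(model, name).register_forward_hook(make_hook(name))
               for name in layer_names]
    model.eval()
    with torch.no_grad():
        for x in batches:
            totals["input"][0] += x.min().item()
            totals["input"][1] += x.max().item()
            model(x)
    for handle in handles:
        handle.remove()  # leave the model unchanged after calibration
    n = len(batches)
    return {name: {"min": lo / n, "max": hi / n} for name, (lo, hi) in totals.items()}


def calc_scale_zero_point(min_val, max_val, num_bits=8):
    qmin, qmax = 0.0, 2.0 ** num_bits - 1.0
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = int(round(min(max(qmin - min_val / scale, qmin), qmax)))
    return scale, zero_point


# Calibrate on random stand-in batches (1,000 samples); real code would use MNIST images.
torch.manual_seed(0)
model = Net()
stats = gather_stats(model, [torch.randn(100, 1, 28, 28) for _ in range(10)])
qparams = {name: calc_scale_zero_point(s["min"], s["max"]) for name, s in stats.items()}
print(qparams)
```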

It's still 99%! Of course, this is just a toy example, and I have skipped most of the quantization theory, but this is the basic gist of how quantization is performed in neural networks. It isn't voodoo magic, just linear algebra and a few clever tricks to get around the PyTorch layers.

I hope this was a fun ride for you guys! Please check out this working Colab notebook to run and verify this quantized network.

If you want to learn more about quantization, I have linked the resources I learnt from below. They were invaluable.

  1. Benoit Jacob et al.'s quantization paper (Google)
  2. Raghuraman Krishnamoorthi's paper on quantization (Google; he's now at Facebook)
  3. Distiller documentation on quantization
  4. gemmlowp's quantization tutorial
