Normalizer-Free Nets (NFNets): DeepMind Releases A New State-Of-The-Art Image Classification Model

Written by Mostafa Ibrahim

“Our smaller models match the test accuracy of an EfficientNet-B7 on ImageNet while being up to 8.7× faster to train, and our largest models attain a new state-of-the-art top-1 accuracy of 86.5%.” (From the research paper High-Performance Large-Scale Image Recognition Without Normalization [1].)

One of the most annoying things about training a model is the time it takes and the amount of memory needed to fit the data and the model. Since image classification is one of the most common machine learning tasks, DeepMind released a new model that matches state-of-the-art (SOTA) performance while training significantly faster and relying on fewer optimization techniques (most notably, no batch normalization).

In their work, the authors examine current SOTA models such as EfficientNets and ResNets. In their analysis, they pin down some of the optimization techniques that use a lot of memory and compute without contributing much to performance, and they show that these networks can achieve the same performance without them.

Although the proposed model might be the most interesting bit, I still find the analysis of previous work very interesting. This is where most of the learning happens: we start to understand what could have been done better and why the newly proposed method is an improvement over the old one.

Prerequisite: Batch Normalization

The paper starts off with an analysis of batch normalization. Why? Because although it has shown great results and is used heavily in many SOTA models, it has several disadvantages, outlined in the paper [1]:

  1. It is computationally expensive and adds memory overhead

  2. It behaves differently during training and inference, which introduces hidden hyper-parameters that need further tuning

  3. It is a common source of implementation errors, especially in distributed training

  4. It performs poorly with small batch sizes, which are often used when training larger models
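To make these disadvantages concrete, here is a minimal sketch of what batch normalization computes for a single feature in training mode, written in plain Python. The function name `batch_norm` and its defaults (`gamma=1.0`, `beta=0.0`, `eps=1e-5`) are illustrative assumptions and do not come from any library's API. Real implementations also keep running averages for use at inference time, which is the source of the train/inference discrepancy; that part is not shown.

```python
import math


def batch_norm(batch: list[float], gamma: float = 1.0, beta: float = 0.0,
               eps: float = 1e-5) -> list[float]:
    """Training-mode batch normalization for one feature across a batch.

    Every value is normalized with the mean and variance of the whole batch,
    then scaled by gamma and shifted by beta (both learnable in a real network).
    """
    mean = sum(batch) / len(batch)
    var = sum((x - mean) ** 2 for x in batch) / len(batch)
    return [gamma * (x - mean) / math.sqrt(var + eps) + beta for x in batch]


if __name__ == "__main__":
    # The same input (2.0) is mapped to a different output depending on its batch-mates.
    print(batch_norm([2.0, 0.0])[0])  # approximately 1.0
    print(batch_norm([2.0, 4.0])[0])  # approximately -1.0
    # With a batch of one, the batch statistics carry no information: output is beta.
    print(batch_norm([2.0]))          # [0.0]
```

Every output depends on the batch mean and variance. As a result, an example's normalized value changes with its batch-mates, and small batches give noisy (or, for a batch of one, useless) statistics.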

But before removing batch normalization, we have to understand what benefits it brings to the models, because we want to find a smarter way to keep those benefits with fewer drawbacks. Those benefits are [1]:

  1. It downscales the residual branches in deep ResNets. ResNets are among the most widely used image classification networks and can extend to hundreds or even thousands of layers. Batch normalization reduces the scale of the “hidden activations” on the residual branches, which would otherwise grow with depth and make the gradients unstable (the exploding gradient problem).

  2. It eliminates mean-shift for popular activation functions such as ReLU and GELU. These activations are not symmetric around zero, so their outputs have a non-zero mean, and in deep networks this shift compounds with depth. This can cause the network to predict the same label for all samples in certain situations (such as at initialization), hurting its performance. Batch normalization solves this mean-shift problem.
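Both benefits can be illustrated with a toy sketch in plain Python. This is an illustration under simplifying assumptions, not code from the paper.

- `relu_mean` estimates the average ReLU output for zero-mean Gaussian inputs. The result is about 1/√(2π) ≈ 0.399 rather than 0, which is the mean-shift.
- `residual_variance` is a hypothetical helper. It assumes every residual branch adds an independent, unit-variance signal, so the activation variance grows linearly with depth unless each branch is scaled down. Its `branch_scale` parameter is a crude stand-in for the downscaling that BN provides.

```python
import math
import random


def relu_mean(num_samples: int = 100_000, seed: int = 0) -> float:
    """Monte Carlo estimate of E[ReLU(z)] for z ~ N(0, 1)."""
    rng = random.Random(seed)
    return sum(max(0.0, rng.gauss(0.0, 1.0)) for _ in range(num_samples)) / num_samples


def residual_variance(depth: int, branch_scale: float = 1.0) -> float:
    """Variance after `depth` blocks of x <- x + branch_scale * f(x) (toy model).

    Assumes x starts with unit variance and each branch f(x) outputs an
    independent, zero-mean, unit-variance signal, so the variances simply add.
    """
    variance = 1.0
    for _ in range(depth):
        variance += branch_scale ** 2
    return variance


if __name__ == "__main__":
    # Mean-shift: ReLU of a zero-mean input has a positive mean (~0.399), not 0.
    print(f"{relu_mean():.3f} vs 1/sqrt(2*pi) = {1 / math.sqrt(2 * math.pi):.3f}")
    # Unscaled residual branches: variance grows linearly with depth.
    print(residual_variance(100))                    # 101.0
    # Downscaled branches keep the signal far better behaved.
    print(residual_variance(100, branch_scale=0.2))  # approximately 5.0
```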

There are other benefits too (the paper also notes that BN allows training with larger learning rates and larger batch sizes), but you get the gist: it's mainly about regularization and smoothing the training process.

NFNets: Normalizer-Free Networks


Although previous papers have attempted to remove batch normalization (BN), their results didn’t match SOTA performance or training latency, and they tended to fail at large batch sizes. This is the main selling point of this paper: the authors succeed in removing BN without hurting performance, while improving training latency by a large margin.

To do that, they propose a gradient clipping technique called Adaptive Gradient Clipping (AGC) [1]. Essentially, gradient clipping stabilizes model training [1] by not allowing the gradient to grow beyond a certain threshold. This allows larger learning rates, and thus faster convergence, without the exploding gradient problem.
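As a reference point, standard gradient norm clipping can be sketched as follows. If the gradient norm ‖G‖ exceeds a fixed threshold λ, the gradient is rescaled to λ·G/‖G‖, which keeps its direction but caps its size. The name `clip_by_norm_threshold` is hypothetical, and this sketch treats all gradients as one flat list. Frameworks ship their own versions; PyTorch, for example, provides `torch.nn.utils.clip_grad_norm_`.

```python
import math


def clip_by_norm_threshold(grads: list[float], max_norm: float) -> list[float]:
    """Classic gradient norm clipping with a fixed threshold (lambda = max_norm).

    If ||G|| > max_norm, return max_norm * G / ||G||: same direction, capped size.
    """
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    scale = max_norm / norm
    return [g * scale for g in grads]


if __name__ == "__main__":
    print(clip_by_norm_threshold([3.0, 4.0], max_norm=1.0))  # approximately [0.6, 0.8]
    print(clip_by_norm_threshold([0.3, 0.4], max_norm=1.0))  # unchanged: [0.3, 0.4]
```

The weak spot is `max_norm`: it is an absolute number that has to be tuned by hand.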

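However, the main issue is setting the threshold hyper-parameter, which is a difficult, manual task: the right value changes with the model depth, the batch size, and the learning rate. AGC does not get rid of the clipping hyper-parameter entirely. Its main benefit is that it replaces the fixed threshold with a relative one that adapts to the size of the weights being updated. To do this, we have to examine the gradient norms and the parameter norms.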

Although I am quite interested in the mathematics behind every machine learning model, I understand that a lot of ML enthusiasts don’t enjoy reading long equations, so I will explain AGC from an intuitive perspective rather than a mathematically rigorous one.

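A norm is simply a measure of the magnitude of a vector (for a weight matrix, the paper uses the Frobenius norm: the square root of the sum of its squared entries). AGC is built on the premise that: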

the unit-wise ratio of the norm of the gradients to the norm of the weights of a layer provides a simple measure of how much a single gradient descent step will change the original weights.

Source: arXiv [1]

But why is that premise valid? Let’s back up a little. A very large gradient will make learning unstable, and in that case the ratio of the gradient norm to the weight norm will also be very high.

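This is because, for plain gradient descent without momentum, the update to the weights W of a layer is ΔW = −h·G, where h is the learning rate and G is the gradient. The relative change in the weights after one step is therefore:

$$\frac{\|\Delta W\|_F}{\|W\|_F} = h \, \frac{\|G\|_F}{\|W\|_F}$$

In words: the learning rate × the ratio between the gradient norm and the weight norm (which is our premise).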
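Here is a quick numerical check of this relationship, as a sketch in plain Python with made-up 2×2 matrices; `frobenius_norm` and `relative_update` are hypothetical helpers. One plain gradient descent step ΔW = −h·G changes the weights by a relative amount equal to h·‖G‖_F/‖W‖_F.

```python
import math


def frobenius_norm(matrix: list[list[float]]) -> float:
    """Square root of the sum of squared entries of a matrix."""
    return math.sqrt(sum(x * x for row in matrix for x in row))


def relative_update(weights: list[list[float]], grads: list[list[float]], lr: float) -> float:
    """||delta W||_F / ||W||_F for one plain gradient descent step delta W = -lr * G."""
    delta = [[-lr * g for g in row] for row in grads]
    return frobenius_norm(delta) / frobenius_norm(weights)


if __name__ == "__main__":
    W = [[1.0, 2.0], [2.0, 4.0]]  # ||W||_F = 5
    G = [[0.0, 3.0], [4.0, 0.0]]  # ||G||_F = 5
    lr = 0.1
    print(relative_update(W, G, lr))                   # approximately 0.1
    print(lr * frobenius_norm(G) / frobenius_norm(W))  # approximately 0.1
```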

So the ratio proposed by that premise is a good indicator of whether we should clip the gradient or not. There is also one more tweak. Through experiments, they found that unit-wise ratios of gradient norms to weight norms work better than layer-wise ratios. Rather than computing one ratio for a layer’s whole weight matrix, they compute a separate ratio for each unit (each row of the weight matrix; for convolutions, each output channel’s filter) and clip each unit’s gradient separately.
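Putting the pieces together, here is a minimal sketch of unit-wise AGC for a single fully connected layer in plain Python. It follows the rule described in the paper: each unit’s gradient G_i is rescaled to λ·max(‖W_i‖, ε)·G_i/‖G_i‖ whenever ‖G_i‖/max(‖W_i‖, ε) > λ. Here ε stops zero-initialized weights from having their gradients clipped to zero. The sketch rests on these assumptions:

- The weight matrix is stored as (output units × inputs), so each row is one unit.
- The function names are hypothetical.
- `clipping=0.01` and `eps=1e-3` are illustrative defaults rather than tuned values.

```python
import math

Matrix = list[list[float]]


def unitwise_norms(matrix: Matrix) -> list[float]:
    """Norm of each row; here a row holds the incoming weights of one output unit."""
    return [math.sqrt(sum(x * x for x in row)) for row in matrix]


def adaptive_grad_clip(weights: Matrix, grads: Matrix,
                       clipping: float = 0.01, eps: float = 1e-3) -> Matrix:
    """Unit-wise Adaptive Gradient Clipping (AGC) for one layer's gradient.

    For each unit i, if ||G_i|| / max(||W_i||, eps) > clipping, rescale G_i to
    clipping * max(||W_i||, eps) * G_i / ||G_i||; otherwise keep it unchanged.
    """
    clipped = []
    for w_norm, g_norm, g_row in zip(unitwise_norms(weights), unitwise_norms(grads), grads):
        max_norm = clipping * max(w_norm, eps)  # eps: zero weights still get a usable bound
        if g_norm > max_norm:
            clipped.append([g * (max_norm / g_norm) for g in g_row])
        else:
            clipped.append(list(g_row))
    return clipped


if __name__ == "__main__":
    W = [[3.0, 4.0], [3.0, 4.0], [0.0, 0.0]]      # unit weight norms: 5, 5, 0
    G = [[0.003, 0.004], [3.0, 4.0], [0.3, 0.4]]  # unit gradient norms: 0.005, 5, 0.5
    for row in adaptive_grad_clip(W, G):
        print(row)
    # Unit 0: ratio 0.001 <= 0.01, gradient kept as is.
    # Unit 1: ratio 1.0 > 0.01, gradient rescaled to norm 0.01 * 5 = 0.05.
    # Unit 2: zero weights, so the bound uses eps: norm 0.01 * 1e-3 = 1e-5.
```

The design point is that the threshold is no longer an absolute number but a fraction of each unit’s weight norm. Units with large weights therefore tolerate proportionally larger gradients.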

In addition to AGC, they also use dropout to replace the regularization effect that batch normalization was providing.
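For readers who have not seen it, dropout is simple to sketch. This plain-Python version uses the common "inverted dropout" formulation. During training, each activation is zeroed with probability `rate` and the survivors are scaled by 1/(1 − rate), so nothing needs to change at inference time. The function name and signature are hypothetical and not taken from the NFNets code.

```python
import random


def dropout(activations: list[float], rate: float, rng: random.Random,
            training: bool = True) -> list[float]:
    """Inverted dropout: during training, zero each activation with probability
    `rate` and scale the survivors by 1 / (1 - rate); at inference, do nothing."""
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")
    if not training or rate == 0.0:
        return list(activations)
    keep_prob = 1.0 - rate
    return [a / keep_prob if rng.random() < keep_prob else 0.0 for a in activations]


if __name__ == "__main__":
    rng = random.Random(0)
    out = dropout([1.0] * 10_000, rate=0.5, rng=rng)
    # Each survivor becomes 2.0, so the average stays close to the original 1.0.
    print(sum(out) / len(out))
    print(dropout([1.0, 2.0], rate=0.5, rng=rng, training=False))  # [1.0, 2.0]
```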

They also used an optimization technique called Sharpness-Aware Minimization (SAM) [1].

Motivated by the connection between the geometry of the loss landscape and generalization — including a generalization bound that we prove here — we introduce a novel, effective procedure for instead simultaneously minimizing loss value and loss sharpness. In particular, our procedure, Sharpness-Aware Minimization (SAM), seeks parameters that lie in neighborhoods having uniformly low loss; this formulation results in a min-max optimization problem on which gradient descent can be performed efficiently. We present empirical results showing that SAM improves model generalization across a variety of benchmark datasets (e.g., CIFAR-{10, 100}, ImageNet, finetuning tasks) and models, yielding novel state-of-the-art performance for several.

Source: the SAM paper on arXiv (Sharpness-Aware Minimization for Efficiently Improving Generalization)
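The min-max procedure in that abstract boils down to two gradient computations per step. First, take an ascent step ε̂ = ρ·g/‖g‖ towards the locally worst-case weights. Then compute the gradient at that point and apply it to the original weights.

Below is a minimal sketch of the basic SAM update on top of plain gradient descent, in pure Python on a toy quadratic loss. The loss, `sam_step`, `rho=0.05`, and the learning rate are all assumptions for illustration. This is not the modified, cheaper SAM variant used for NFNets.

```python
import math
from typing import Callable

Vector = list[float]


def sam_step(w: Vector, grad_fn: Callable[[Vector], Vector],
             lr: float, rho: float = 0.05) -> Vector:
    """One basic Sharpness-Aware Minimization step on top of plain gradient descent."""
    g = grad_fn(w)
    g_norm = math.sqrt(sum(x * x for x in g)) + 1e-12  # guard against a zero gradient
    # 1. Ascent: approximate worst-case perturbation within radius rho.
    e_hat = [rho * x / g_norm for x in g]
    # 2. Gradient at the perturbed point w + e_hat ...
    g_sharp = grad_fn([wi + ei for wi, ei in zip(w, e_hat)])
    # 3. ... applied to the ORIGINAL weights w.
    return [wi - lr * gi for wi, gi in zip(w, g_sharp)]


def toy_loss(w: Vector) -> float:
    """L(w) = 0.5 * (10 * w0^2 + w1^2): sharp along w0, flat along w1."""
    return 0.5 * (10.0 * w[0] ** 2 + w[1] ** 2)


def toy_grad(w: Vector) -> Vector:
    """Gradient of toy_loss."""
    return [10.0 * w[0], w[1]]


if __name__ == "__main__":
    w = [1.0, 1.0]
    print(toy_loss(w))  # 5.5
    for _ in range(50):
        w = sam_step(w, toy_grad, lr=0.05)
    print(toy_loss(w))  # close to zero
```

Each SAM step needs two gradient evaluations instead of one, which is why a cheaper variant matters when training very large models.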

The idea of loss sharpness is quite interesting, and I might explore it in another article; I’ll skip it here for the sake of brevity. One final point to note, though: they make a small modification to SAM [1] that reduces its computational cost by 20–40%, and they only employ it on their two largest model variants. It’s always great to see additions being made to such techniques instead of just using them out of the box. I think this shows that they analyzed it carefully before using it (and were thus able to optimize it a bit).

Final thoughts and takeaways

Who would have thought that replacing a technique as ubiquitous as batch normalization could help deliver up to 8.7× faster training than EfficientNet-B7 at the same accuracy? I think this sends a message: be a bit more skeptical about popular optimization techniques that are used everywhere. In all fairness, I have been a victim of this crime before: I used to just put every popular optimization technique into my machine learning projects without fully examining its pros and cons. I guess this is one of the main benefits of reading ML papers: the analysis of previous SOTAs!



[1] Andrew Brock, Soham De, Samuel L. Smith, and Karen Simonyan. High-Performance Large-Scale Image Recognition Without Normalization. 2021.