Batch Normalization, Layer Normalization and Root Mean Square Layer Normalization: A Comprehensive Guide with Python Implementations

Introduction

Stabilizing and accelerating the training of neural networks often depends on the normalization technique employed. While the theory behind normalization appears straightforward, it comes in several practical flavours, each with its own merits and shortcomings.

This post explores three popular normalization techniques:

  1. Batch Normalization (BatchNorm)

  2. Layer Normalization (LayerNorm)

  3. Root Mean Square Layer Normalization (RMSNorm)

We'll cover:

  • The mathematics behind each technique

  • A discussion of computational complexity and memory requirements

  • Python (PyTorch) implementations of each method

  • The pros and cons of each method

Batch Normalization

Overview

Batch Normalization, introduced by Sergey Ioffe and Christian Szegedy [1], normalizes each feature of a layer's output using statistics computed over the mini-batch during training. Put simply, for every feature it uses the mean and variance computed across all instances in the mini-batch.

Equation

The output \(\hat{x}\) is computed as:

\[\hat{x} = \frac{x - \mathbb{E}_{\text{mini-batch}}(x)}{\sqrt{Var_{\text{mini-batch}}(x) + \epsilon}} \cdot \gamma + \beta\]

Here, \(\mathbb{E}_{\text{mini-batch}}(x)\) and \(Var_{\text{mini-batch}}(x)\) are the mean and variance, computed per feature over the mini-batch, and \(\epsilon\) is a small constant for numerical stability. \(\gamma\) and \(\beta\) are learnable per-feature scale and shift parameters, respectively.

Running Statistics

Batch Norm also requires computing and storing running estimates of the per-feature mean and variance. During training, these are updated as an exponential moving average (EMA) with a scalar momentum term \(\alpha\): \(y_{EMA_i} = \alpha y_{EMA_{i-1}} + (1 - \alpha)y_i\), where \(i\) is the current training step and \(y_i\) is the mini-batch statistic at that step. During inference, the stored running statistics are used in place of mini-batch statistics, so each sample is normalized independently of the rest of the batch, which also makes single-sample inference possible.
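The EMA update above can be sketched as follows, assuming the (batch, seq_len, d_model) layout used throughout this post. The helper name update_running_stats, the value \(\alpha = 0.9\) and the synthetic data (true mean 3, true variance 4) are illustrative assumptions, not part of any library.

```python
import torch

def update_running_stats(running_mean, running_var, x, alpha=0.9):
    """Hypothetical helper: one EMA step, new = alpha * old + (1 - alpha) * batch.

    x is assumed to have shape (batch, seq_len, d_model); statistics are per feature.
    """
    batch_var, batch_mean = torch.var_mean(x, dim=(0, 1), correction=0)
    running_mean = alpha * running_mean + (1 - alpha) * batch_mean
    running_var = alpha * running_var + (1 - alpha) * batch_var
    return running_mean, running_var

d_model = 8
running_mean, running_var = torch.zeros(d_model), torch.ones(d_model)

torch.manual_seed(0)
for step in range(200):
    x = 3.0 + 2.0 * torch.randn(16, 10, d_model)  # synthetic data: mean 3, variance 4
    running_mean, running_var = update_running_stats(running_mean, running_var, x)

# Inference: a single sample is normalized with the stored statistics, no batch needed
sample = 3.0 + 2.0 * torch.randn(1, 1, d_model)
x_hat = (sample - running_mean) / torch.sqrt(running_var + 1e-5)
```

After enough steps the running estimates settle close to 3 and 4. Note that PyTorch's built-in BatchNorm layers define momentum the other way round (running = (1 - momentum) * running + momentum * batch, default 0.1), which corresponds to \(\alpha = 0.9\) in the notation used here.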

Properties

Compared with training without normalization, Batch Norm:

  • Was proposed to reduce internal covariate shift (i.e. the change in the distribution of each layer's inputs during training)

  • Speeds up convergence

  • Enables higher learning rates

  • Makes training less sensitive to weight initialization

Python Implementation

import torch
import torch.nn as nn


class BatchNorm(nn.Module):
    def __init__(
        self,
        size: int,
        eps: float = 1e-5,
    ):
        """
        Batch Normalization.
        Assumes the shape of the input x is (batch, seq_len, d_model).
        Note: this version always normalizes with mini-batch statistics;
        it does not track the running statistics used at inference.

        Args:
            size: size of the feature dimension (i.e. d_model)
            eps: For numerical stability. Defaults to 1e-5.
        """
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(size))   # learnable scale, initialized to 1
        self.beta = nn.Parameter(torch.zeros(size))   # learnable shift, initialized to 0

    def forward(self, x):
        # One mean/variance per feature, pooled over the batch and sequence dimensions
        x_var, x_mean = torch.var_mean(x, dim=(0, 1), keepdim=True, correction=0)
        x_std = torch.sqrt(x_var + self.eps)

        x_norm = (x - x_mean) / x_std

        # gamma and beta have shape (d_model,) and broadcast over (batch, seq_len)
        return self.gamma * x_norm + self.beta

Assuming our input \(x\) has the shape (batch, seq_len, d_model), for batch normalization, we normalize across both the batch and sequence length dimensions (0 and 1 respectively), but keep the feature dimension (d_model) intact. This is because BatchNorm aims to stabilize the distribution of each feature over the mini-batch.
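To see which statistics this produces, the sketch below reproduces the normalization step (without \(\gamma\) and \(\beta\)) and compares it with PyTorch's torch.nn.functional.batch_norm. That function treats dimension 1 as the feature (channel) axis, so the (batch, seq_len, d_model) tensor is transposed first; passing None for the running statistics with training=True makes it use mini-batch statistics only. The shapes and random input are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 10, 8)  # (batch, seq_len, d_model)

# Manual BatchNorm: one mean/variance per feature, pooled over batch and seq_len
var, mean = torch.var_mean(x, dim=(0, 1), keepdim=True, correction=0)
manual = (x - mean) / torch.sqrt(var + 1e-5)
print(mean.shape)  # torch.Size([1, 1, 8])

# F.batch_norm expects features in dim 1, so move d_model there and back again
ref = F.batch_norm(x.transpose(1, 2), None, None, training=True, eps=1e-5).transpose(1, 2)
print(torch.allclose(manual, ref, atol=1e-4))  # True
```

The same transpose is needed when applying nn.BatchNorm1d to sequence data laid out as (batch, seq_len, d_model).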

Layer Normalization

Overview

Layer Normalization [2], unlike Batch Norm, normalizes across the features of each individual data point, so its statistics do not depend on the other examples in the batch. This makes it insensitive to variations in batch size.

Equation

The output \(\hat{x}\) is computed similarly to Batch Norm but differs in the axis over which \(\mathbb{E}(x)\) and \(Var(x)\) are computed.

\[\hat{x} = \frac{x - \mathbb{E}_{\text{features}}(x)}{\sqrt{Var_{\text{features}}(x) + \epsilon}} \cdot \gamma + \beta\]

Here, \(\mathbb{E}_{\text{features}}(x)\) and \(Var_{\text{features}}(x)\) are the mean and variance calculated over the feature dimension, separately for each data point.

Properties

  • Independent of batch size; no running statistics are needed, so training and inference use the same computation

  • Works well for sequence models

  • Stabilizes training

  • Accelerates convergence

Python Implementation

import torch
import torch.nn as nn


class LayerNorm(nn.Module):
    def __init__(
        self,
        size: int,
        eps: float = 1e-5,
    ):
        """
        Layer Normalization.
        Assumes the shape of the input x is (batch, seq_len, d_model)

        Args:
            size: size of the feature dimension (i.e. d_model)
            eps: For numerical stability. Defaults to 1e-5.
        """
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(size))   # learnable scale, initialized to 1
        self.beta = nn.Parameter(torch.zeros(size))   # learnable shift, initialized to 0

    def forward(self, x):
        # One mean/variance per token, computed over the d_model features
        x_var, x_mean = torch.var_mean(x, dim=-1, keepdim=True, correction=0)
        x_std = torch.sqrt(x_var + self.eps)

        x_norm = (x - x_mean) / x_std

        # gamma and beta have shape (d_model,) and broadcast over (batch, seq_len)
        return self.gamma * x_norm + self.beta

Assuming our input \(x\) has the shape (batch, seq_len, d_model), Layer Norm normalizes across the feature dimension (d_model) separately for each token, i.e. each position of each sequence in the batch. The rationale is to give the features of a single data point zero mean and unit variance, which makes the output invariant to re-scaling and re-centering of that data point's features.
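A quick way to confirm the axis is to reproduce the normalization step (without \(\gamma\) and \(\beta\)) and compare it with PyTorch's torch.nn.functional.layer_norm, which normalizes over the trailing normalized_shape dimensions. This is a sketch with an arbitrary random input.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 10, 8)  # (batch, seq_len, d_model)

# Manual LayerNorm: one mean/variance per token, computed over the d_model features
var, mean = torch.var_mean(x, dim=-1, keepdim=True, correction=0)
manual = (x - mean) / torch.sqrt(var + 1e-5)
print(mean.shape)  # torch.Size([4, 10, 1])

# Built-in reference without affine parameters (weight and bias default to None)
ref = F.layer_norm(x, normalized_shape=(8,), eps=1e-5)
print(torch.allclose(manual, ref, atol=1e-4))  # True
```

The statistics have shape (batch, seq_len, 1): one mean and one variance per token, rather than one per feature as in BatchNorm.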

RMS Normalization

Overview

RMSNorm [3] is a variant of LayerNorm that 1) re-scales by the root mean square, \(\sqrt{\mathbb{E}(x^2)}\), instead of the standard deviation and 2) drops the re-centering operation (no mean subtraction). The authors hypothesize that the re-centering invariance property of LayerNorm is dispensable, and keep only the re-scaling invariance property in RMSNorm.

Equation

The output \(\hat{x}\) is calculated as:

\[\hat{x} = \frac{x}{\sqrt{\mathbb{E}_{\text{features}}(x^2) + \epsilon}} \cdot \gamma\]
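The equation can be checked on a small vector, which also shows the invariance properties described in the overview. This sketch assumes \(\gamma = 1\) and uses a tiny \(\epsilon\), since the re-scaling invariance is exact only as \(\epsilon \to 0\); the helper name rms_norm is hypothetical.

```python
import torch

def rms_norm(x, eps=1e-8):
    # Hypothetical helper: RMSNorm with gamma = 1, over the last dimension
    return x / torch.sqrt((x ** 2).mean(dim=-1, keepdim=True) + eps)

x = torch.tensor([[1.0, -2.0, 3.0, -4.0]])
rms = torch.sqrt((x ** 2).mean(dim=-1))  # sqrt((1 + 4 + 9 + 16) / 4) = sqrt(7.5), about 2.7386

# Re-scaling invariance: multiplying the input by a positive constant leaves the output unchanged
print(torch.allclose(rms_norm(x), rms_norm(5.0 * x)))  # True

# No re-centering invariance: shifting the input changes the output
print(torch.allclose(rms_norm(x), rms_norm(x + 1.0)))  # False
```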

Properties

  • Computationally simpler, and thus more efficient, than Layer Norm: there is no mean to compute or subtract and no shift parameter \(\beta\)

Python Implementation

import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(
        self,
        size: int,
        eps: float = 1e-5,
    ):
        """
        Root-Mean-Square Layer Normalization.
        Assumes the shape of the input x is (batch, seq_len, d_model)

        Args:
            size: size of the feature dimension (i.e. d_model)
            eps: For numerical stability. Defaults to 1e-5.
        """
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(size))  # learnable scale; RMSNorm has no shift (beta)

    def forward(self, x):
        # RMS over the feature dimension, one value per token; no mean is subtracted.
        # Equivalently (apart from where eps enters), the L2 norm over the last dimension
        # divided by sqrt(d_model): torch.linalg.vector_norm(x, dim=-1, keepdim=True) / x.shape[-1] ** 0.5
        rms = torch.sqrt((x ** 2).mean(dim=-1, keepdim=True) + self.eps)
        x_norm = x / rms

        # gamma has shape (d_model,) and broadcasts over (batch, seq_len)
        return self.gamma * x_norm

Assuming our input \(x\) has the shape (batch, seq_len, d_model), RMS Layer Normalization, like LN, normalizes across the feature dimension (d_model), dividing each token's features by their root mean square. Because no mean is subtracted, it is cheaper to compute than LN, and its output is invariant to re-scaling of the input but not to shifting it.
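Two relationships are worth checking numerically: the RMS equals the L2 norm over the feature dimension divided by \(\sqrt{d_{model}}\), and when each token's features already have zero mean, the variance equals the mean square, so RMSNorm and LayerNorm coincide. This is a sketch without learnable parameters; shapes and data are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 5, 16)  # (batch, seq_len, d_model)
d_model, eps = x.shape[-1], 1e-5

# RMS over the feature dimension, computed two ways (eps omitted for the comparison)
rms_direct = torch.sqrt((x ** 2).mean(dim=-1, keepdim=True))
rms_via_l2 = torch.linalg.vector_norm(x, dim=-1, keepdim=True) / d_model ** 0.5
print(torch.allclose(rms_direct, rms_via_l2))  # True

# Zero-mean features: variance == mean square, so RMSNorm matches LayerNorm
x_centered = x - x.mean(dim=-1, keepdim=True)
rms_out = x_centered / torch.sqrt((x_centered ** 2).mean(dim=-1, keepdim=True) + eps)
ln_out = F.layer_norm(x_centered, (d_model,), eps=eps)
print(torch.allclose(rms_out, ln_out, atol=1e-5))  # True
```

The second check makes the difference between the two methods explicit: LayerNorm is RMSNorm applied after re-centering, which is exactly the step RMSNorm drops.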

Computational Complexity and Memory Requirements

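  • Batch Norm: Must store and update running statistics (a mean and a variance per feature). Because its statistics span the whole mini-batch, data-parallel training with small per-device batches may also need to synchronize them across devices (e.g. PyTorch's SyncBatchNorm), which makes it harder to parallelize.

  • Layer Norm: No running statistics to store, and statistics are computed per example, so no communication across the batch is required.

  • RMSNorm: Cheaper than LayerNorm: it computes a single statistic per example (the mean of squares) instead of a mean and a variance, skips the mean subtraction, and has no shift parameter \(\beta\).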
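The memory side of this comparison can be made concrete by counting learnable parameters and buffers (persistent non-learned state such as running statistics) in PyTorch's built-in nn.BatchNorm1d and nn.LayerNorm. The helper count_params_and_buffers and d_model = 512 are illustrative choices. The RMSNorm implementation above would have d_model parameters (\(\gamma\) only) and no buffers.

```python
import torch.nn as nn

def count_params_and_buffers(module):
    # Learnable parameters (gamma/beta) vs. buffers (e.g. running statistics)
    n_params = sum(p.numel() for p in module.parameters())
    n_buffers = sum(b.numel() for b in module.buffers())
    return n_params, n_buffers

d_model = 512  # illustrative feature size

# BatchNorm1d: weight + bias, plus running_mean, running_var and a scalar num_batches_tracked
print(count_params_and_buffers(nn.BatchNorm1d(d_model)))  # (1024, 1025)

# LayerNorm: weight + bias only, no running statistics
print(count_params_and_buffers(nn.LayerNorm(d_model)))  # (1024, 0)
```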

Use-cases and Recommendations

  • Batch Normalization: This technique is particularly strong in convolutional architectures, where batch sizes are often large enough for the mean and variance estimates to be reliable. However, it is not ideal for models like RNNs and Transformers, where sequence lengths can vary. Its dependence on batch statistics also poses challenges for online learning, where a batch may hold a single example, and can add complexity when parallelizing the model across multiple devices, since the statistics must be synchronized to cover the full batch.

  • Layer Normalization: Highly effective for sequence models such as RNNs and Transformers. It's also a better choice for scenarios with small batch sizes as it computes statistics for each data point independently, negating the need for a large batch to estimate population statistics.

  • RMSNorm: If computational efficiency is your priority, RMSNorm offers a simpler equation that is less computationally intensive than LayerNorm. Experiments across several tasks and architectures show that RMSNorm performs comparably to LayerNorm while reducing running time [3].
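The small-batch and online-learning points can be illustrated by normalizing the same sequence inside two different mini-batches. This sketch uses hypothetical helpers: batch_norm (training-mode statistics, no \(\gamma\) or \(\beta\), matching the BatchNorm implementation above) and layer_norm (a thin wrapper around torch.nn.functional.layer_norm). Shapes and data are arbitrary.

```python
import torch
import torch.nn.functional as F

def batch_norm(x, eps=1e-5):
    # Hypothetical helper: training-mode BatchNorm, statistics pooled over batch and seq_len
    var, mean = torch.var_mean(x, dim=(0, 1), keepdim=True, correction=0)
    return (x - mean) / torch.sqrt(var + eps)

def layer_norm(x, eps=1e-5):
    # Hypothetical helper: LayerNorm over the last (feature) dimension, no affine parameters
    return F.layer_norm(x, (x.shape[-1],), eps=eps)

torch.manual_seed(0)
d_model = 8
sample = torch.randn(1, 4, d_model)                   # one sequence: (1, seq_len, d_model)
others_a = torch.randn(3, 4, d_model)                 # two different sets of batch-mates
others_b = 5.0 * torch.randn(3, 4, d_model) + 2.0

# Output for the same sequence when it is normalized inside two different mini-batches
bn_a = batch_norm(torch.cat([sample, others_a]))[0]
bn_b = batch_norm(torch.cat([sample, others_b]))[0]
ln_a = layer_norm(torch.cat([sample, others_a]))[0]
ln_b = layer_norm(torch.cat([sample, others_b]))[0]

print(torch.allclose(bn_a, bn_b))  # False: BatchNorm output depends on the batch-mates
print(torch.allclose(ln_a, ln_b))  # True: LayerNorm output depends only on the sample itself
```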

Conclusion

Incorporating the right normalization technique can make or break your deep learning model. This post has aimed to provide a theoretical and practical overview of Batch Normalization, Layer Normalization, and RMS Layer Normalization. The Python implementations should help you get a hands-on understanding of how these techniques work at a granular level.

References

[1] Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift, Sergey Ioffe and Christian Szegedy, 2015

[2] Layer Normalization, Jimmy Lei Ba et al., 2016

[3] Root Mean Square Layer Normalization, Biao Zhang and Rico Sennrich, 2019