Batch Normalization, Layer Normalization and Root Mean Square Layer Normalization: A Comprehensive Guide with Python Implementations
Introduction
Stabilizing and accelerating the training of neural networks often hinges on the normalization technique employed. While the theory behind normalization appears straightforward, in practice it comes in several flavours, each with its own merits and shortcomings.
This post explores three popular normalization techniques:
Batch Normalization (BatchNorm)
Layer Normalization (LayerNorm)
Root Mean Square Layer Normalization (RMSNorm)
We'll cover:
The mathematics behind each technique
A discussion on computational complexity
The pros and cons of each method
Batch Normalization
Overview
Batch Normalization, introduced by Sergey Ioffe and Christian Szegedy [1], normalizes each feature of a layer's output over a given mini-batch during training. Put simply, every feature is standardized with the statistics (mean and variance) computed across all instances in the mini-batch.
Equation
The output \(\hat{x}\) is computed as:
\[\hat{x} = \frac{x - \mathbb{E}_{\text{mini-batch}}(x)}{\sqrt{Var_{\text{mini-batch}}(x) + \epsilon}} \cdot \gamma + \beta\]
Here, \(\mathbb{E}_{\text{mini-batch}}(x)\) and \(Var_{\text{mini-batch}}(x)\) are the mean and variance, computed per feature over the mini-batch, and \(\epsilon\) is a small constant for numerical stability. \(\gamma\) and \(\beta\) are scaling and shifting learnable parameters, respectively.
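As a small worked example of the equation above, the sketch below normalizes a tiny mini-batch by hand in plain Python. The helper name `batch_norm_2d` and the toy values are assumptions made for illustration, and \(\gamma = 1\), \(\beta = 0\) are fixed scalars rather than learned parameters.

```python
import math


def batch_norm_2d(batch, eps=1e-5, gamma=1.0, beta=0.0):
    """Normalize each feature (column) of a list of samples (rows).

    Hypothetical helper for illustration: gamma and beta are plain scalars
    here instead of learnable per-feature parameters.
    """
    n = len(batch)
    columns = list(zip(*batch))  # one tuple per feature
    means = [sum(col) / n for col in columns]
    # Biased variance (divide by n), as in the mini-batch formula
    variances = [
        sum((v - m) ** 2 for v in col) / n for col, m in zip(columns, means)
    ]
    return [
        [
            gamma * (v - m) / math.sqrt(var + eps) + beta
            for v, m, var in zip(row, means, variances)
        ]
        for row in batch
    ]


# Two samples, two features: feature 0 has mean 2 and variance 1,
# feature 1 has mean 20 and variance 100.
batch = [[1.0, 10.0], [3.0, 30.0]]
normalized = batch_norm_2d(batch)
print(normalized)  # each column is close to [-1, 1] despite the different scales
```

Both features end up on the same scale even though the second one is ten times larger in the input, which is the per-feature behaviour the equation describes.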
Running Statistics
Batch Norm also requires computing and storing running statistics for both the mean and the variance. During training, these are maintained as an exponential moving average (EMA), updated with a scalar momentum term \(\alpha\): \(y_{EMA_i} = \alpha y_{EMA_{i-1}} + (1 - \alpha)y_i\), where \(i\) is the current training step and \(y_i\) is the statistic computed on the current mini-batch. During inference, the stored running statistics replace the mini-batch statistics, so even a single sample can be normalized.
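The EMA update can be traced by hand for a single feature. The sketch below is only an illustration: the function name `ema_update`, the initial value of 0 and the per-batch means are assumptions. Note that PyTorch's built-in batch norm layers define `momentum` as the weight on the new value, which corresponds to \(1 - \alpha\) in this post's notation.

```python
def ema_update(running, current, alpha=0.9):
    """One EMA step: alpha * previous running value + (1 - alpha) * new value."""
    return alpha * running + (1 - alpha) * current


# Hypothetical per-batch means of a single feature over three training steps
batch_means = [1.0, 1.0, 2.0]

running_mean = 0.0  # assumed initial value
history = []
for batch_mean in batch_means:
    running_mean = ema_update(running_mean, batch_mean)
    history.append(running_mean)

print(history)  # roughly [0.1, 0.19, 0.371]
```

With \(\alpha = 0.9\) the running value moves slowly towards the batch statistics, so it takes many steps before it reflects the data distribution.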
Properties
Compared to no normalization, Batch Norm:
Aims to reduce internal covariate shift (i.e. the change in the distribution of each layer's inputs during training), which is the motivation given in [1]
Speeds up convergence
Enables higher learning rates
Is less sensitive to initialization
Python Implementation
import torch
from torch import nn


class BatchNorm(nn.Module):
    def __init__(
        self,
        size: int,
        eps: float = 1e-5,
    ):
        """
        Batch Normalization (training-mode statistics only; running
        statistics are not tracked in this minimal version).
        Assumes the shape of the input x is (batch, seq_len, d_model)

        Args:
            size: size of the feature dimension (i.e. d_model)
            eps: For numerical stability. Defaults to 1e-5.
        """
        super().__init__()
        self.eps = eps
        # gamma starts at 1 and beta at 0, so the layer initially returns x_norm
        self.gamma = nn.Parameter(torch.ones(size))
        self.beta = nn.Parameter(torch.zeros(size))

    def forward(self, x):
        # Per-feature statistics over the batch and sequence dimensions (biased variance)
        x_var, x_mean = torch.var_mean(x, dim=[0, 1], keepdim=True, correction=0)
        x_std = torch.sqrt(x_var + self.eps)
        x_norm = (x - x_mean) / x_std
        # gamma and beta have shape (d_model,) and broadcast over (batch, seq_len, d_model)
        return self.gamma * x_norm + self.beta
Assuming our input \(x\) has the shape (batch, seq_len, d_model), for batch normalization we normalize across both the batch and sequence length dimensions (0 and 1 respectively), but keep the feature dimension (d_model) intact. This is because BatchNorm aims to stabilize the distribution of each feature over the mini-batch.
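The running statistics described earlier can be added with a pair of buffers. The following is a minimal sketch under stated assumptions, not a drop-in replacement for PyTorch's own layers: the class name `BatchNormWithRunningStats` and the `alpha` argument are hypothetical, the running variance uses the biased estimate for simplicity, and the EMA follows this post's convention (\(\alpha\) weights the old value).

```python
import torch
from torch import nn
import torch.nn.functional as F


class BatchNormWithRunningStats(nn.Module):
    """Sketch of batch norm for inputs of shape (batch, seq_len, d_model).

    The class name and the `alpha` argument are illustrative assumptions.
    """

    def __init__(self, size: int, eps: float = 1e-5, alpha: float = 0.9):
        super().__init__()
        self.eps = eps
        self.alpha = alpha
        self.gamma = nn.Parameter(torch.ones(size))
        self.beta = nn.Parameter(torch.zeros(size))
        # Buffers are saved with the module but are not updated by the optimizer
        self.register_buffer("running_mean", torch.zeros(size))
        self.register_buffer("running_var", torch.ones(size))

    def forward(self, x):
        if self.training:
            # Per-feature statistics of the current mini-batch
            var, mean = torch.var_mean(x, dim=[0, 1], correction=0)
            with torch.no_grad():
                # EMA update: alpha * old + (1 - alpha) * new
                self.running_mean.mul_(self.alpha).add_((1 - self.alpha) * mean)
                self.running_var.mul_(self.alpha).add_((1 - self.alpha) * var)
        else:
            # Inference: use the stored statistics instead of batch statistics
            mean, var = self.running_mean, self.running_var
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_norm + self.beta


torch.manual_seed(0)
layer = BatchNormWithRunningStats(size=8)
x = torch.randn(4, 5, 8)

layer.train()
y_train = layer(x)
# Cross-check against PyTorch's functional batch norm, which expects (N, C, L)
reference = F.batch_norm(x.transpose(1, 2), None, None, training=True, eps=1e-5)
matches_reference = torch.allclose(y_train, reference.transpose(1, 2), atol=1e-5)

layer.eval()
y_single = layer(x[:1])  # a single sample can be normalized in eval mode
```

Switching between `layer.train()` and `layer.eval()` is what selects mini-batch statistics or stored statistics, mirroring the training and inference behaviour described in the Running Statistics section.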
Layer Normalization
Overview
Layer Normalization [2], unlike Batch Norm, normalizes across the features of each individual data point, so its statistics do not depend on the other samples in the batch or on the batch size.
Equation
The output \(\hat{x}\) is computed similarly to Batch Norm but differs in the axis over which \(\mathbb{E}(x)\) and \(Var(x)\) are computed.
\[\hat{x} = \frac{x - \mathbb{E}_{\text{features}}(x)}{\sqrt{Var_{\text{features}}(x) + \epsilon}} \cdot \gamma + \beta\]
Here, \(\mathbb{E}_{\text{features}}(x)\) and \(Var_{\text{features}}(x)\) are the mean and variance calculated over the feature dimension.
Properties
Less sensitive to batch size than Batch Norm
Works well for sequence models
Stabilizes training
Accelerates convergence
Python Implementation
import torch
from torch import nn


class LayerNorm(nn.Module):
    def __init__(
        self,
        size: int,
        eps: float = 1e-5,
    ):
        """
        Layer Normalization.
        Assumes the shape of the input x is (batch, seq_len, d_model)

        Args:
            size: size of the feature dimension (i.e. d_model)
            eps: For numerical stability. Defaults to 1e-5.
        """
        super().__init__()
        self.eps = eps
        # gamma starts at 1 and beta at 0, so the layer initially returns x_norm
        self.gamma = nn.Parameter(torch.ones(size))
        self.beta = nn.Parameter(torch.zeros(size))

    def forward(self, x):
        # Per-token statistics over the feature dimension (biased variance)
        x_var, x_mean = torch.var_mean(x, dim=-1, keepdim=True, correction=0)
        x_std = torch.sqrt(x_var + self.eps)
        x_norm = (x - x_mean) / x_std
        # gamma and beta have shape (d_model,) and broadcast over (batch, seq_len, d_model)
        return self.gamma * x_norm + self.beta
Assuming our input \(x\) has the shape (batch, seq_len, d_model), Layer Norm normalizes across the feature dimension (d_model) separately for each position of each sequence in the batch. The rationale is to give the features of a single data point zero mean and unit variance, making the model less sensitive to the scale of the input features.
RMS Normalization
Overview
RMSNorm [3] is a variant of LayerNorm that 1) uses the root mean square, \(\sqrt{\mathbb{E}(x^2)}\), instead of the standard deviation for re-scaling and 2) drops the re-centering operation. The authors hypothesize that the re-centering invariance of LayerNorm is dispensable, and keep only the re-scaling invariance in RMSNorm.
Equation
The output \(\hat{x}\) is calculated as:
\[\hat{x} = \frac{x}{\sqrt{\mathbb{E}_{\text{features}}(x^2) + \epsilon}} \cdot \gamma\]
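To make the denominator concrete, the sketch below computes the root mean square of a toy vector in plain Python, both directly and via the L2 norm divided by \(\sqrt{d}\). The helper name `rms` and the vector are assumptions for illustration, and \(\epsilon\) and \(\gamma\) are left out for clarity.

```python
import math


def rms(values):
    """Root mean square: square root of the mean of the squared values."""
    return math.sqrt(sum(v * v for v in values) / len(values))


x = [3.0, 4.0]

rms_direct = rms(x)  # sqrt((9 + 16) / 2) = sqrt(12.5)
l2_norm = math.sqrt(sum(v * v for v in x))  # sqrt(9 + 16) = 5
rms_from_norm = l2_norm / math.sqrt(len(x))  # 5 / sqrt(2), the same value

x_norm = [v / rms_direct for v in x]  # eps omitted for clarity
print(rms_direct, rms_from_norm, x_norm)
```

After division, the normalized vector itself has a root mean square of 1, while its mean is left untouched because nothing is subtracted.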
Properties
Computationally simpler and thus more efficient than Layer Norm
Python Implementation
import torch
from torch import nn


class RMSNorm(nn.Module):
    def __init__(
        self,
        size: int,
        eps: float = 1e-5,
    ):
        """
        Root-Mean-Square Layer Normalization.
        Assumes the shape of the input x is (batch, seq_len, d_model)

        Args:
            size: size of the feature dimension (i.e. d_model)
            eps: For numerical stability. Defaults to 1e-5.
        """
        super().__init__()
        self.eps = eps
        # Only a scale parameter: RMSNorm has no shift (beta)
        self.gamma = nn.Parameter(torch.ones(size))

    def forward(self, x):
        # Root mean square over the feature dimension. Without eps, this equals
        # the L2 (Frobenius) norm over the last dim divided by sqrt(d_model).
        rms = torch.sqrt((x ** 2).mean(dim=-1, keepdim=True) + self.eps)
        x_norm = x / rms
        # gamma has shape (d_model,) and broadcasts over (batch, seq_len, d_model)
        return self.gamma * x_norm
Assuming our input \(x\) has the shape (batch, seq_len, d_model), for RMS Layer Normalization, like LN, we normalize across the feature dimension (d_model). We use the root mean square of the feature values for each data point in the sequence. This method is computationally efficient and can be more robust to outliers.
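The two invariance properties mentioned in the overview can be checked numerically. In this sketch the functions `rms_norm` and `layer_norm` are standalone helpers with \(\gamma = 1\) and \(\beta = 0\) (an assumption for illustration), and the scale factor 10 and shift 5 are arbitrary.

```python
import torch


def rms_norm(x, eps=1e-5):
    """RMSNorm over the last dimension with gamma = 1 (assumed)."""
    return x / torch.sqrt((x ** 2).mean(dim=-1, keepdim=True) + eps)


def layer_norm(x, eps=1e-5):
    """LayerNorm over the last dimension with gamma = 1, beta = 0 (assumed)."""
    var, mean = torch.var_mean(x, dim=-1, keepdim=True, correction=0)
    return (x - mean) / torch.sqrt(var + eps)


torch.manual_seed(0)
x = torch.randn(2, 3, 16)

# Re-scaling the input leaves both outputs almost unchanged (eps causes a tiny difference)
rms_scale_invariant = torch.allclose(rms_norm(10 * x), rms_norm(x), atol=1e-3)
ln_scale_invariant = torch.allclose(layer_norm(10 * x), layer_norm(x), atol=1e-3)

# Shifting the input is absorbed by LayerNorm's re-centering, but not by RMSNorm
ln_shift_invariant = torch.allclose(layer_norm(x + 5), layer_norm(x), atol=1e-3)
rms_shift_invariant = torch.allclose(rms_norm(x + 5), rms_norm(x), atol=1e-3)
```

Only `rms_shift_invariant` comes out false: RMSNorm keeps the re-scaling invariance and gives up the re-centering invariance, which is the trade-off the authors argue is acceptable.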
Computational Complexity and Memory Requirements
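Batch Norm: Requires storing running statistics, and its mini-batch statistics couple the samples in a batch, which makes it harder to parallelize across devices.
Layer Norm: Needs no running statistics; the mean and variance are computed independently for each data point.
RMSNorm: Even cheaper than LayerNorm because it skips re-centering, so no mean is computed or subtracted.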
Use-cases and Recommendations
Batch Normalization: This technique is particularly strong in convolutional architectures where batch sizes are often large enough for the mean and variance estimates to be reliable. However, it's not ideal for models like RNNs and Transformers where sequence lengths can vary. Its reliance on running statistics also poses challenges for online learning scenarios and can add complexity when attempting to parallelize the model across multiple devices.
Layer Normalization: Highly effective for sequence models such as RNNs and Transformers. It's also a better choice for scenarios with small batch sizes as it computes statistics for each data point independently, negating the need for a large batch to estimate population statistics.
RMSNorm: If computational efficiency is your priority, RMSNorm offers a simpler equation that's less computationally intensive than LayerNorm. Experiments on several NLP tasks show that RMSNorm is comparable to LayerNorm in quality while reducing running time [3].
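The batch-dependence argument behind these recommendations can be illustrated by normalizing the same data point alongside two different batch mates. This is a sketch with hypothetical helper functions and random toy data, using training-mode batch statistics and no learnable parameters.

```python
import torch


def batch_norm(x, eps=1e-5):
    """Training-mode batch norm: per-feature statistics over batch and sequence."""
    var, mean = torch.var_mean(x, dim=[0, 1], keepdim=True, correction=0)
    return (x - mean) / torch.sqrt(var + eps)


def layer_norm(x, eps=1e-5):
    """Layer norm: per-token statistics over the feature dimension."""
    var, mean = torch.var_mean(x, dim=-1, keepdim=True, correction=0)
    return (x - mean) / torch.sqrt(var + eps)


torch.manual_seed(0)
sample = torch.randn(1, 4, 8)          # the data point we track
mate_a = torch.randn(1, 4, 8)          # one possible batch mate
mate_b = 3 * torch.randn(1, 4, 8) + 2  # a batch mate with a different distribution

batch_a = torch.cat([sample, mate_a], dim=0)
batch_b = torch.cat([sample, mate_b], dim=0)

# Output for the tracked sample (index 0) under each batch
ln_same = torch.allclose(layer_norm(batch_a)[0], layer_norm(batch_b)[0])
bn_same = torch.allclose(batch_norm(batch_a)[0], batch_norm(batch_b)[0])
print(ln_same, bn_same)
```

Layer Norm returns the same output for the tracked sample in both batches, whereas Batch Norm's output changes with its batch mates, which is why small or unrepresentative batches hurt Batch Norm in particular.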
Conclusion
Incorporating the right normalization technique can make or break your deep learning model. This post has aimed to provide a theoretical and practical overview of Batch Normalization, Layer Normalization, and RMS Layer Normalization. The Python implementations should help you get a hands-on understanding of how these techniques work at a granular level.
References
[1] Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift, Sergey Ioffe and Christian Szegedy, 2015
[2] Layer Normalization, Jimmy Lei Ba et al., 2016
[3] Root Mean Square Layer Normalization, Biao Zhang and Rico Sennrich, 2019