Now, we turn our attention to Llama 2, the successor to Llama. Let's look at the differences:

- **Dataset**: Llama2 benefits from a 40% increase in training data.
- **Context Length**: Trained with a 4096-token context length, up from 2048.
- **Attention Mechanism**: An architectural evolution from Multi-Head Attention (MHA) to Grouped-Query Attention (GQA).

Apart from the switch to GQA, the architecture remains untouched. Thus, much of our Llama codebase remains applicable; only the attention block changes.

Let's now see what the new Grouped-Query Attention block (`GQALlamaBlock`) looks like and break down the code:

```python
class GQALlamaBlock(nn.Module):
    def __init__(
        self,
        embedding_size: int,
        context_len: int,
        causal_attention: bool,
        n_heads: int,
        n_groups: int,
        swiglu_d_multiplier: float
    ):
        super().__init__()
        self.embedding_size = embedding_size
        self.causal_attention = causal_attention
        self.n_heads = n_heads
        self.n_groups = n_groups
        assert self.n_heads % self.n_groups == 0, f"Number of heads ({self.n_heads}) must be divisible by the number of groups ({self.n_groups})"
        self.group_size = self.n_heads // self.n_groups
        assert self.embedding_size % self.n_heads == 0, f"Embedding size ({self.embedding_size}) must be divisible by the number of heads ({self.n_heads})"
        self.head_dim = self.embedding_size // self.n_heads
        self.R = get_rotary_matrix(context_len=context_len, embedding_dim=self.head_dim)
        self.rms = RMSnorm(size=embedding_size)
        self.ff_q = nn.Linear(embedding_size, embedding_size, bias=False)
        kv_embedding_size = self.head_dim * self.n_groups
        self.ff_k = nn.Linear(embedding_size, kv_embedding_size, bias=False)
        self.ff_v = nn.Linear(embedding_size, kv_embedding_size, bias=False)
        # In the Llama paper swiglu_d_multiplier = 2/3 * 4
        swiglu_size = int(swiglu_d_multiplier * embedding_size)
        self.fc1 = nn.Linear(embedding_size, swiglu_size)
        self.activation = SwiGLU(size=swiglu_size)
        self.fc2 = nn.Linear(swiglu_size, embedding_size)

    def forward(self, x):
        input_shape = x.shape
        q_resize = (x.shape[0], x.shape[1], self.n_heads, self.head_dim)
        kv_resize = (x.shape[0], x.shape[1], self.n_groups, self.head_dim)
        x_res = x
        x = self.rms(x)  # pre-normalization
        query = self.ff_q(x).reshape(q_resize)
        key = self.ff_k(x).reshape(kv_resize)
        value = self.ff_v(x).reshape(kv_resize)
        # Apply rotation to query and key, separately for each head
        R_matrix = self.R[:input_shape[1], :, :].to(query.device)
        query_rot = torch.einsum('bhld,ldd->bhld', query.permute(0,2,1,3), R_matrix)
        key_rot = torch.einsum('bgdl,ldd->bgdl', key.permute(0,2,3,1), R_matrix)
        query_rot = query_rot.reshape(input_shape[0], self.group_size, self.n_groups, input_shape[1], self.head_dim)
        score = torch.einsum('bsgld, bgdp->bsglp', query_rot, key_rot)
        if self.causal_attention:
            score += causal_mask(size=score.shape, device=score.device)
        score = score / torch.sqrt(torch.tensor(self.head_dim))
        attention = torch.softmax(score, dim=-1)
        x = torch.einsum('bsgpl,bgld->bsgpd', attention, value.permute(0,2,1,3))
        x = x.reshape(input_shape[0], self.group_size*self.n_groups, input_shape[1], self.head_dim)
        x = x.permute(0, 2, 1, 3).reshape(input_shape)
        x += x_res
        x_res = x
        x = self.rms(x)
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x + x_res
```

- **Initialization**: The constructor sets up the block, ensuring the number of heads is divisible by the number of groups and the embedding size by the number of heads.
- **Matrices and Normalization**: `get_rotary_matrix` generates a rotary position embedding matrix, while `RMSnorm` is used for layer normalization, both carried over from Llama.
- **Group-aware Feedforward (New!)**: `ff_q`, `ff_k`, and `ff_v` are linear layers for transforming inputs into queries, keys, and values, respectively. These are reshaped according to the number of heads and groups. The output size of `ff_q` is \(d \cdot h\), and `ff_k` and `ff_v` are \(d \cdot g\), where \(d\) is the dimensionality of each head/group, \(h\) is the number of heads, and \(g\) is the number of groups.
- **SwiGLU Activation**: A SwiGLU-based feedforward network with a custom size determined by `swiglu_d_multiplier` is employed.
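To make the head/group bookkeeping concrete, here is a quick sketch in plain Python (the numbers are illustrative assumptions, not the post's configuration):

```python
# Illustrative numbers (assumptions, not the post's configuration)
embedding_size = 1024
n_heads = 8
n_groups = 4
head_dim = embedding_size // n_heads   # d = 128

# ff_q projects to d * h, i.e. the full embedding size
q_out = head_dim * n_heads             # 1024
# ff_k and ff_v project only to d * g under GQA
kv_out = head_dim * n_groups           # 512

print(q_out, kv_out)  # 1024 512
```

With 4 groups over 8 heads, the key and value projections are half the size of the query projection, which is exactly where GQA saves parameters and (at inference) KV-cache memory.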

**The Forward Pass:**

During the forward pass, the method reshapes queries, keys, and values, and applies the rotary embeddings separately to each head. This rotation aligns the attention mechanism with the relative position information.

Again, most components are the same as we saw in `MHALlamaBlock` in Llama. Let's focus on the main differences, all done with the magic of `reshape` and `torch.einsum`!

**Calculating Attention Scores**:

```python
query_rot = query_rot.reshape(
    input_shape[0], self.group_size, self.n_groups, input_shape[1], self.head_dim
)
score = torch.einsum('bsgld, bgdp->bsglp', query_rot, key_rot)
```

The first line reshapes the rotated query to prepare it for attention score calculation.

- `input_shape[0]` is the batch size.
- `self.group_size` is the size of each group within the heads, derived by dividing the total number of heads by the number of groups (`n_heads / n_groups`).
- `self.n_groups` is the number of groups that the heads are divided into.
- `input_shape[1]` is the sequence length.
- `self.head_dim` is the dimensionality of each head.

This reshaping step is crucial because it aligns the data into a structure that reflects the grouping of attention heads. Each head within a group will contribute to a portion of the attention calculation, and this structure facilitates that process.

The second line calculates the attention scores using `torch.einsum`. The notation `'bsgld, bgdp->bsglp'` describes how the tensors are combined.

`bsgld` corresponds to the reshaped rotated queries tensor:

- `b` for batch size
- `s` for the size of each group
- `g` for the number of groups
- `l` for the sequence length
- `d` for the dimension of each head

`bgdp` corresponds to the rotated keys tensor:

- `b` for batch size
- `g` for the number of groups
- `d` for the dimension of each head
- `p` for the sequence length, which is the same as `l` but is labelled differently to indicate a different role in this operation (here, `p` is the 'target' position in the sequence that each 'source' position `l` is attending to)

When `torch.einsum` processes this operation, it does the following:

- It aligns the rotated queries and keys based on the batch and group dimensions (`b` and `g`).
- It then computes dot products between each query and key vector across the dimension `d` for each position in the sequence (`l` attending to `p`).
- The result is a raw attention score tensor: `bsglp`.
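If you want to sanity-check this contraction yourself, NumPy's `einsum` accepts the same subscripts (the toy sizes below are assumptions for illustration):

```python
import numpy as np

# Toy sizes (assumptions): batch=2, group_size=2, n_groups=3, seq_len=5, head_dim=4
b, s, g, l, d = 2, 2, 3, 5, 4
query_rot = np.random.randn(b, s, g, l, d)
key_rot = np.random.randn(b, g, d, l)  # keys permuted to (b, g, d, l)

# Same subscripts as in the block: every head in a group ('s')
# shares that group's single key tensor
score = np.einsum('bsgld,bgdp->bsglp', query_rot, key_rot)
print(score.shape)  # (2, 2, 3, 5, 5)
```

For a fixed batch, sub-group, and group index, the contraction is just a query-key matrix multiplication, which is why the resulting `l x p` slice is the familiar attention score matrix.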

**Applying Attention to Values**:

```python
x = torch.einsum('bsgpl,bgld->bsgpd', attention, value.permute(0,2,1,3))
```

After applying the causal mask, the normalization and the softmax function to the scores, this `torch.einsum` operation applies the calculated attention to the values.

Using the same notation as the attention score calculation, `bsgpl` represents the attention weights.

`bgld` represents the value vectors, permuted to align with the attention weights:

- `b` for batch size
- `g` for the number of groups
- `l` for the sequence length
- `d` for the dimension of each head

`-> bsgpd` indicates the resulting tensor's dimensions:

- `b` for batch size
- `s` for the size of each group
- `g` for the number of groups
- `p` for the sequence length (the position that received the attention)
- `d` for the dimension of each head

After applying the attention, the tensor must be reshaped back to the original input dimensions to maintain the consistency of the model's layers.

```python
x = x.reshape(
    input_shape[0], self.group_size*self.n_groups, input_shape[1], self.head_dim
)
```

The tensor is reshaped to collapse the sub-group and group dimensions (`s * g`) back into a single dimension representing all heads. `input_shape[0]` is the batch size, and `input_shape[1]` is the sequence length, both derived from the original input tensor. `self.head_dim` is the dimension of each head.

The next line:

```python
x = x.permute(0, 2, 1, 3).reshape(input_shape)
```

first reorders the dimensions of `x` to match the expected order for the next layers in the model (`x.permute(0, 2, 1, 3)`); `.reshape(input_shape)` then reshapes the tensor to the original input shape, ensuring that the output of the attention block can be seamlessly integrated into the subsequent layers of the model.

Right, now that we have our Llama2 model, let's use it for token generation! For that, let's compare some examples of token generation using the different sampling methods described in the Llama post, i.e. greedy, random sampling, top-k sampling, top-p sampling, and their variants including temperature scaling.
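As a refresher before comparing outputs, the sampling strategies can be sketched in a few lines of NumPy. This is a standalone illustration, not the project's actual `generate` code; the function name and default values are assumptions:

```python
import numpy as np

def sample(logits, method="greedy", k=5, p=0.9, temperature=1.0, rng=None):
    """Pick one token id from a vector of logits using a given strategy."""
    rng = rng or np.random.default_rng(0)
    logits = np.asarray(logits, dtype=np.float64) / temperature  # temperature scaling
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    if method == "greedy":
        return int(np.argmax(probs))
    if method == "topk":
        # Keep only the k most likely tokens, then renormalize
        cutoff = np.sort(probs)[-min(k, len(probs))]
        probs = np.where(probs >= cutoff, probs, 0.0)
    elif method == "topp":
        # Keep the smallest set of tokens whose cumulative probability >= p
        order = np.argsort(probs)[::-1]
        cum = np.cumsum(probs[order])
        keep = order[: np.searchsorted(cum, p) + 1]
        mask = np.zeros_like(probs)
        mask[keep] = probs[keep]
        probs = mask
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))

logits = [2.0, 1.0, 0.5, -1.0]
print(sample(logits, "greedy"))  # 0
```

Lowering the temperature sharpens the distribution toward the greedy choice; the `_t` variants in the outputs below combine a sampling method with such scaling.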

We can train our Llama2 model using `python baby_llama/run.py model=llama2`, which produces the following results:

```
Full Text, greedy: [BOS]KING RICHARD III:O Ratcliff, I have dream'd a fearful dream!What thinkest thou, will our friends prove all true?[EOS]
Full Text, rnd_sampling: [BOS]AUTOLYCUS:TheWhen ravel spacious! therich of mine hath chairs not appointed me in the blended of my post,There and garland to the wronging testy of the dinof myHourly.[EOS]
Full Text, rnd_sampling_t: [BOS]SICINIUS:Come, what talk youOf Marcius?[EOS]
Full Text, topk_sampling: [BOS]CORIOLANUS:It is a purposed thing, and grows by plot,To curb the will of the nobility:Suffer't, and live with such as cannot ruleNor ever will be ruled.[EOS]
Full Text, topk_sampling_t: [BOS]KING EDWARD IV:So, master mayor: these gates must not be shutBut in the night or in the time of war.What! fear not, man, but yield me up the keys;ForSpit must be fear'd and thee,And all those friends that deign to follow me.[EOS]
Full Text, topp_sampling: [BOS]ROMEO:Or I shall.[EOS]
Full Text, topp_sampling_t: [BOS]BUCKINGHAM:And, in good time, here comes the noble duke.[EOS]
```

Similar to the previous Llama model, we are only training for 10 epochs, using a small network (8 layers), hidden dimension (1024), context length (256) and training batch size (8). Additionally, here we are using 4 groups within GQA. You can check the wandb run to see all the configurations and generation examples during training for this experiment.

For this small example, we can't see a notable improvement in memory usage. Still, the number of parameters is slightly smaller (233M for Llama vs 224M for Llama2), and the figure below shows that Llama2 (dandy-lake-72, purple) uses (slightly) less memory than Llama (hopeful-surf-62, brown).

It should be noted that these improvements will scale with the model and that GQA's main advantage is to speed up inference.
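The parameter gap itself is easy to sanity-check: shrinking `ff_k` and `ff_v` from `(hidden, hidden)` to `(hidden, head_dim * n_groups)` accounts for roughly the whole difference. Assuming 8 attention heads (an assumption; the run config has the exact value), a back-of-the-envelope count:

```python
# Hidden size, layers, and groups are from the post; n_heads = 8 is an assumption
hidden, n_blocks, n_heads, n_groups = 1024, 8, 8, 4
head_dim = hidden // n_heads

# Per block: MHA's ff_k and ff_v are (hidden x hidden),
# GQA's are (hidden x head_dim * n_groups)
mha_kv = 2 * hidden * hidden
gqa_kv = 2 * hidden * head_dim * n_groups
savings = n_blocks * (mha_kv - gqa_kv)
print(f"{savings/1e6:.1f}M fewer parameters")  # prints "8.4M fewer parameters"
```

Roughly 8.4M fewer parameters, which is consistent with the observed 233M vs 224M gap.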

Our project is comprehensive and, among other things, includes constructing our attention mechanism that incorporates the three key components specified in the original Llama paper:

- RMSNorm for pre-normalization
- RoPE (Rotary Positional Embedding)
- SwiGLU activation function

To help visualize the architecture, here's a diagram illustrating a single block of our model:

First things first: let's set up our development environment to ensure that everything runs smoothly. For this project, we'll be using Python 3.10 and manage our dependencies using Poetry. Here's how you can set it up:

```shell
# Create a new Conda environment named 'llama'
conda create -n llama python=3.10
# Activate the Conda environment
conda activate llama
# Install Poetry for dependency management
pip install poetry
# Install project dependencies
poetry install
```

With the environment set up, you're now ready to dive into the intricacies of building Baby Llama from scratch.

Given the domain-specific language characteristics of our dataset, we opted for training a custom Byte-Pair Encoding (BPE) tokenizer. This allows for more accurate and efficient tokenization specific to our corpus.

Our code snippet for training the tokenizer involves several components:

- Initialization of a BPE tokenizer.
- Setting pre-tokenizers and decoders to ByteLevel.
- Configuration of special tokens and post-processors.
- Training the tokenizer on a specific dataset specified in `cfg.path`.

```python
# Initialize the BPE tokenizer
tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
```

Here, we initialize a BPE tokenizer. We specify the unknown token as `[UNK]`, which is what the tokenizer will use for any character sequences it hasn't seen before.

```python
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tokenizer.decoder = decoders.ByteLevel()
```

These lines set the pre-tokenizer and decoder to use Byte-Level tokenization, a foundational part of BPE. This allows the BPE tokenizer to use bytes as the base vocabulary, providing an initial vocabulary size of 256.

Here, `add_prefix_space=False` indicates that no space will be prefixed to each word at the beginning of a sentence.

```python
# Define the trainer and special tokens
trainer = trainers.BpeTrainer(special_tokens=["[UNK]", "[PAD]", "[BOS]", "[EOS]"])
```

Here, we specify the training settings and declare special tokens that have specific roles during both training and inference. During training, BPE identifies the most frequently occurring pairs of consecutive symbols and merges them to create new tokens. These merged tokens are added as new entries to the vocabulary, effectively expanding it beyond the initial 256 byte-level symbols.
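To build intuition for what the trainer does, here is a minimal, pure-Python sketch of a single merge step over a toy word-count table (real byte-level BPE operates on byte symbols over a full corpus; all names and counts below are illustrative):

```python
from collections import Counter

def most_frequent_pair(words):
    """Count adjacent symbol pairs, weighted by word frequency."""
    pairs = Counter()
    for symbols, freq in words.items():
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += freq
    return pairs.most_common(1)[0][0]

def merge_pair(words, pair):
    """Replace every occurrence of `pair` with a single merged symbol."""
    merged = {}
    for symbols, freq in words.items():
        out, i = [], 0
        while i < len(symbols):
            if symbols[i:i+2] == pair:
                out.append(symbols[i] + symbols[i+1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged[tuple(out)] = freq
    return merged

# Toy word counts, each word as a tuple of symbols
words = {("l", "o", "w"): 5, ("l", "o"): 3, ("n", "e", "w"): 3}
pair = most_frequent_pair(words)  # ('l', 'o') occurs 8 times
words = merge_pair(words, pair)
print(pair, words)  # ('l', 'o') {('lo', 'w'): 5, ('lo',): 3, ('n', 'e', 'w'): 3}
```

The real trainer repeats this merge step until the target vocabulary size is reached, recording each merge so it can be replayed at encoding time.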

```python
# Add post-processor for special tokens
tokenizer.post_processor = processors.TemplateProcessing(
    single="[BOS] $A [EOS]",
    special_tokens=[("[BOS]", 2), ("[EOS]", 3)],
)
```

Post-processing is configured to automatically add `[BOS]` and `[EOS]` tokens at the beginning and end of each sequence (represented as `$A`), respectively. The numbers `2` and `3` specify the indices of `[BOS]` and `[EOS]` based on their order in the special tokens list, so they must match.

```python
# Train the tokenizer on the dataset
tokenizer.train([cfg.path], trainer)
```

Training is triggered using the `.train()` method, and it's here that all the previously set configurations come into play. The tokenizer is trained on the data specified in `cfg.path`.

```python
# Save the pretrained tokenizer
pretrained_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer)
pretrained_tokenizer.save_pretrained(cfg.tokenizer_path)
```

Finally, we save the trained tokenizer using the Transformers library's `PreTrainedTokenizerFast` class. Upon running the `pretrained_tokenizer.save_pretrained(cfg.tokenizer_path)` line, three files will be created within the folder specified by `cfg.tokenizer_path`. These files contain the necessary configurations to reload the tokenizer for future use.

To illustrate the tokenizer's functionality, let's encode and decode a sample sentence:

```python
encodings = tokenizer.encode("CORIOLANUS: \n It is apart \n That I shall blush in acting, and might well \n Be taken from the people.")
decodings = tokenizer.decode(encodings.ids)
print(f"Token Ids: {encodings.ids}")
print(f"Encoded Tokens : {encodings.tokens}")
print(f"Decoded Tokens: {decodings}")
```

This produces the following output:

```
Token Ids: [2, 725, 12, 68, 67, 5327, 137, 6799, 68, 67, 9936, 104, 227, 4150, 120, 9025, 8, 109, 771, 371, 68, 67, 4391, 3236, 289, 80, 1005, 10, 3]
Encoded Tokens : ['[BOS]', 'CORIOLANUS', ':', '', '', 'It', 'is', 'apart', '', '', 'That', 'I', 'shall', 'blush', 'in', 'acting', ',', 'and', 'might', 'well', '', '', 'Be', 'taken', 'from', 'the', 'people', '.', '[EOS]']
Decoded Tokens: CORIOLANUS: It is apart That I shall blush in acting, and might well Be taken from the people.
```

Here, the example output includes the following encoded tokens: `['[BOS]', 'CORIOLANUS', ':', '', '', 'It', 'is', 'apart', ...]`. You'll notice the special character `Ġ` in the encoded tokens. This character signifies a space before a word within a sentence and is a product of the ByteLevel pre-tokenization. In ByteLevel tokenization, spaces are also encoded into specific byte tokens, and `Ġ` is how the model represents a space when followed by a word within the context of a sentence.

This example demonstrates the tokenizer's ability to encode and decode text accurately, preserving the original sentence structure and adding special tokens at the beginning and end of the sequence.

To execute this tokenizer training script, simply run:

```shell
python run_tokenizer.py
```

Because we're using Hydra for configuration management, modifying aspects like the dataset path or where to save the tokenizer is straightforward. All these settings are located in the `cfg` object and are sourced from a YAML configuration file.

Let's now focus on the data preparation and loading.

```python
tokenizer_name = "gpt2" if cfg.dataset.tokenizer_path is None else cfg.dataset.tokenizer_path
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
```

Here, `tokenizer_name` is set to either "gpt2" or a path to your custom tokenizer, saved in `cfg.dataset.tokenizer_path`. This allows you to switch between a custom and a pre-trained tokenizer effortlessly. For our experiments, `cfg.dataset.tokenizer_path` is the path to the folder we created in the previous "Tokenizer Training" step. `AutoTokenizer.from_pretrained` is then used to load the tokenizer.

```python
dataset = getfromtext(
    data_path=Path(cfg.dataset.path),
    tokenizer=tokenizer,
    tokenizer_args=dict(
        return_tensors="pt",
        add_special_tokens=True,
        truncation=True,
        padding="max_length",
        max_length=cfg.model.context_len+1
    )
)
```

`getfromtext` is a custom function that transforms the raw text (from `cfg.dataset.path`) into a `CLMDataset` object, which is compatible with PyTorch's `DataLoader`.

```python
def getfromtext(
    data_path: Path,
    tokenizer: AutoTokenizer,
    tokenizer_args: dict
) -> CLMDataset:
    data = data_path.read_text().split("\n\n")
    data = [i for i in data if i]
    return CLMDataset(data=data, tokenizer=tokenizer, tokenizer_args=tokenizer_args)
```

The `CLMDataset` class inherits from PyTorch's `Dataset` class and takes care of tokenization and formatting of your text data, making it compatible with PyTorch's `DataLoader` and ready for training.

Let's check the code for the two main parts of `CLMDataset`: 1) the `__getitem__` method and 2) how the arguments of the tokenizer are used. The `__getitem__` method is designed to work with PyTorch's `DataLoader`. It returns a tuple consisting of input IDs, target IDs (the next token ID for each input ID), and the attention mask.

```python
def __getitem__(self, idx: int) -> Tuple[int, int, int]:
    return self.tokens["input_ids"][idx, :-1], self.tokens["input_ids"][idx, 1:], self.tokens["attention_mask"][idx, :-1]
```

This slicing technique creates input and target sequences by shifting by one token, a common practice in next-token prediction.

The tokenizer, with its arguments, is simply called within the class as:

```python
class CLMDataset(Dataset):
    def __init__(
        self,
        data: Path,
        tokenizer: AutoTokenizer,
        tokenizer_args: dict,
    ):
        self.data = data
        self.tokens = tokenizer(self.data, **tokenizer_args)
        ...
```

The tokenizer arguments are passed down from `getfromtext` to the `CLMDataset`. In our experiments, we use `return_tensors="pt"` to return PyTorch tensors, `add_special_tokens=True` to include special tokens in the tokenized output, `truncation=True` for handling sequences longer than the model's maximum input length, `padding="max_length"` to pad shorter sequences to the max length (in the batch), and `max_length=cfg.model.context_len+1` to set the maximum sequence length (the "+1" accounts for label-shifting during training).
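The label-shifting that motivates the "+1" is easy to see with plain Python lists (the token ids below are toy values, not real vocabulary entries):

```python
# Tokenizing to context_len + 1 ids lets us carve out aligned
# input/target pairs of length context_len each.
context_len = 4
ids = [2, 17, 43, 9, 3]   # context_len + 1 toy token ids ([BOS] ... [EOS])

inputs  = ids[:-1]         # [2, 17, 43, 9]
targets = ids[1:]          # [17, 43, 9, 3] -> the next token for each input
print(list(zip(inputs, targets)))  # [(2, 17), (17, 43), (43, 9), (9, 3)]
```

Each (input, target) pair is one next-token prediction example, which is exactly what `__getitem__` produces per sequence.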

Having prepared our data and made it compatible with PyTorch's `DataLoader`, the next step is to manage this data efficiently for the different stages of model training, validation, and testing. This is where `CLMDataModule` comes into play. `CLMDataModule` is a class that inherits from PyTorch Lightning's `LightningDataModule` and takes care of data loading and preparation. Here's how we use it:

```python
datamodule = CLMDataModule(
    data=dataset,
    train_ratio=cfg.dataset.train_ratio,
    val_ratio=cfg.dataset.val_ratio,
    test_ratio=cfg.dataset.test_ratio,
    train_batchsize=cfg.trainer.train_batchsize,
    val_test_batchsize=cfg.trainer.val_test_batchsize,
    num_workers=cfg.trainer.num_workers
)
```

The `CLMDataModule` class provides standard methods like `train_dataloader`, `val_dataloader`, and `test_dataloader` to return PyTorch `DataLoader` objects for each phase. These methods are quite standard, utilizing the batch sizes and number of workers specified during initialization. These loaders will use the `CLMDataset` object you provided and its `__getitem__` method to fetch batches of data. `CLMDataModule` also has a `setup` method that splits the dataset into training, validation, and test sets based on the provided ratios. It takes a `stage` argument to determine which splits to prepare, allowing us to use different data stages without reloading the entire dataset:

```python
def setup(self, stage):
    train, val, test = random_split(
        dataset=self.data,
        lengths=[self.train_ratio, self.val_ratio, self.test_ratio]
    )
    if stage == "fit":
        self.train, self.val = train, val
    if stage == "test":
        self.test = test
```

Let's build an intuition for the three main Llama components and implement them!

First, we initialize the Llama architecture using the following code snippet:

```python
transformer = Llama(
    vocab_size=dataset.get_vocab_size(),
    hidden_size=cfg.model.hidden_size,
    context_len=cfg.model.context_len,
    causal_attention=True,
    n_heads=cfg.model.n_heads,
    n_blocks=cfg.model.n_blocks
)
```

where:

- `vocab_size`: size of the vocabulary, taken from the dataset you're working with.
- `hidden_size`: size of the hidden layer, specified in your hydra configuration.
- `context_len`: length of the context window for attention, also from your hydra configuration.
- `causal_attention`: boolean flag to indicate if the model should use causal (unidirectional) attention.
- `n_heads`: number of attention heads, specified in your hydra configuration.
- `n_blocks`: number of transformer blocks (layers), also specified in your hydra configuration.

```python
class Llama(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        context_len: int,
        causal_attention: bool,
        n_heads: int,
        n_blocks: int
    ):
        super().__init__()
        self.context_len = context_len
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.attention_block = nn.ModuleList([MHALlamaBlock(hidden_size, context_len, causal_attention, n_heads) for _ in range(n_blocks)])
        self.unembedding = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        for single_block in self.attention_block:
            x = single_block(x)
        x = self.unembedding(x)
        return x
```

The `Llama` class is defined as a subclass of PyTorch's `nn.Module`. Inside its `__init__` method:

- `self.embedding`: embedding layer that converts token IDs to vectors.
- `self.attention_block`: list of attention blocks, each handling multi-head self-attention and feed-forward operations.
- `self.unembedding`: linear layer that maps the output back to vocabulary space.

In the `forward` method, the input sequence `x` goes through the embedding layer, the list of attention blocks, and finally the unembedding layer, before it is returned as output.

This completes the architecture of our Llama model.

Let's now delve into the three main components of Llama and implement them!

RMSNorm is used to normalize the input of each attention block. The inspiration for including pre-normalization comes from GPT-3, which showed that it improves training stability compared to output normalization.

RMSNorm is computationally simpler and more efficient than LayerNorm because it re-scales by the root mean square alone and skips the mean re-centring step.

Here's a simplified RMSNorm code snippet to give you an idea:

```python
class RMSnorm(nn.Module):
    def __init__(
        self,
        size: int,
        eps: float = 1e-5,
    ):
        super(RMSnorm, self).__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(size), requires_grad=True)

    def forward(self, x):
        rms = torch.sqrt((x ** 2).mean(dim=-1, keepdim=True) + self.eps)
        x_norm = x / rms
        return self.gamma.unsqueeze(0).unsqueeze(1) * x_norm
```

For more mathematical and implementation details about RMSNorm and its differences with Batch Normalization and Layer Normalization, refer to our dedicated blog post.
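One property worth verifying is the one named above: RMSNorm is invariant to re-scaling its input but, unlike LayerNorm, not to shifting it. A standalone NumPy check (with `gamma` fixed to 1 for simplicity):

```python
import numpy as np

def rms_norm(x, eps=1e-5):
    # gamma omitted (i.e. fixed to 1) for this illustration
    return x / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)

x = np.array([1.0, -2.0, 3.0, 0.5])
# Re-scaling invariance: scaling the input barely changes the output
print(np.allclose(rms_norm(x), rms_norm(10 * x), atol=1e-4))   # True
# No re-centring invariance: shifting the input changes the output
print(np.allclose(rms_norm(x), rms_norm(x + 5.0), atol=1e-4))  # False
```

The tiny `eps` is the only reason the first comparison is not exact; it prevents division by zero for all-zero inputs.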

RoPE is based on rotating queries and keys in the attention mechanism, with a unique rotation at each position. This segment of code focuses on applying the rotation in a single attention block (the full code to the attention block is down below):

```python
R_matrix = self.R[:resize[1], :, :].to(query.device)
query_rot = torch.einsum('bhld,ldd->bhld', query.permute(0,2,1,3), R_matrix)
key_rot = torch.einsum('bhdl,ldd->bhdl', key.permute(0,2,3,1), R_matrix)
```

`self.R` is a pre-computed rotary matrix for positional encoding; `resize[1]` is the sequence length and is used to slice the rotary matrix to match the sequence length of the queries and keys. The dimensions of query and key are ordered as [Batch size, Sequence length, Number of Heads, Hidden Dimension]. We permute these to rearrange the dimensions in a way that facilitates the subsequent operations. Specifically, we bring the sequence length (`l`) and dimension (`d`) next to each other for the rotation operation. Let's now try to understand the `torch.einsum` operation! Here, the expression `bhld,ldd->bhld` indicates the following:

- `bhld`: represents batch size (`b`), number of heads (`h`), sequence length (`l`), and hidden dimension (`d`) - of each head - for the query.
- `ldd`: stands for sequence length (`l`) and hidden dimension (`d`), twice to align with the square `R_matrix`.
- `->bhld`: tells us that the output should maintain the original dimensions of batch size, number of heads, sequence length, and dimension. In this case, the `torch.einsum` function takes each slice along the `l` and `d` dimensions from `query`, multiplies it with the `R_matrix`, and sums along those dimensions. Because the output subscripts (`bhld`) are the same as the input, there is no reduction in dimensions, meaning we get an output of the same shape as the `query`, but now each query vector has been rotated based on its position in the sequence.
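The property that makes these rotations useful: the dot product between a query rotated at position m and a key rotated at position n depends only on the offset m - n. A 2-D NumPy sketch (toy vectors and a toy angle, not the model's actual rotary matrix):

```python
import numpy as np

def rot(theta):
    """2-D rotation matrix for angle theta."""
    return np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta),  np.cos(theta)]])

q = np.array([1.0, 2.0])
k = np.array([0.5, -1.0])
theta = 0.1  # per-position angle (toy value)

# Score for query at position 7 attending to key at position 3...
score_a = (rot(7 * theta) @ q) @ (rot(3 * theta) @ k)
# ...equals the score for any other pair with the same offset (9 - 5 = 7 - 3)
score_b = (rot(9 * theta) @ q) @ (rot(5 * theta) @ k)
print(np.isclose(score_a, score_b))  # True
```

This is why RoPE injects relative position information into attention: `rot(m*theta).T @ rot(n*theta) = rot((n-m)*theta)`, so only the offset survives in the score.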

For a deeper dive into RoPE, its mathematical formulation, and its practical implementation in PyTorch, check out our blog post.

SwiGLU is a combination of the Swish activation function and the GLU (Gated Linear Unit):

$$\mathrm{SwiGLU}(A, B) = A \otimes \mathrm{Swish}(B) = A \otimes \big(B \cdot \sigma(\beta B)\big)$$

where \(A\) and \(B\) are two linear transformations, \(\mathrm{Swish}(x) = x \cdot \sigma(\beta x)\), and \(\sigma\) is the sigmoid function. Here's the essential code snippet for SwiGLU:

```python
class SwiGLU(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.linearA = nn.Linear(size, size)
        self.linearB = nn.Linear(size, size)
        self.beta = nn.Parameter(torch.randn(1), requires_grad=True)

    def forward(self, x):
        swish = self.linearB(x) * torch.sigmoid(self.beta * self.linearB(x))
        return swish * self.linearA(x)
```

Following the original Llama paper, for our experiments, we set `size` to \(\frac{2}{3}4d\), where \(d\) is the hidden size (or dimension) of our Llama model. This can be easily changed using the `model.swiglu_d_moltiplier` argument of the hydra config.

Now, let's put everything together to see all the code for a single Llama multi-head attention block:

```python
def causal_mask(size, device):
    x = torch.full(size, float("-inf"))
    return torch.triu(x, diagonal=1).to(device=device)

class MHALlamaBlock(nn.Module):
    def __init__(
        self,
        embedding_size: int,
        context_len: int,
        causal_attention: bool,
        n_heads: int,
        swiglu_d_moltiplier: float
    ):
        super().__init__()
        self.embedding_size = embedding_size
        self.causal_attention = causal_attention
        self.n_heads = n_heads
        assert self.embedding_size % self.n_heads == 0, f"Embedding size ({self.embedding_size}) must be divisible by the number of heads ({self.n_heads})"
        self.head_dim = self.embedding_size // self.n_heads
        self.R = get_rotary_matrix(context_len=context_len, embedding_dim=self.head_dim)
        self.rms = RMSnorm(size=embedding_size)
        self.ff_k = nn.Linear(embedding_size, embedding_size, bias=False)
        self.ff_q = nn.Linear(embedding_size, embedding_size, bias=False)
        self.ff_v = nn.Linear(embedding_size, embedding_size, bias=False)
        # In the Llama paper swiglu_d_moltiplier = 2/3 * 4
        swiglu_size = int(swiglu_d_moltiplier * embedding_size)
        self.fc1 = nn.Linear(embedding_size, swiglu_size)
        self.activation = SwiGLU(size=swiglu_size)
        self.fc2 = nn.Linear(swiglu_size, embedding_size)

    def forward(self, x):
        input_shape = x.shape
        resize = (x.shape[0], x.shape[1], self.n_heads, self.head_dim)
        x_res = x
        x = self.rms(x)  # pre-normalization
        query = self.ff_q(x).reshape(resize)
        key = self.ff_k(x).reshape(resize)
        value = self.ff_v(x).reshape(resize)
        # Apply rotation to query and key, separately for each head
        R_matrix = self.R[:resize[1], :, :].to(query.device)
        query_rot = torch.einsum('bhld,ldd->bhld', query.permute(0,2,1,3), R_matrix)
        key_rot = torch.einsum('bhdl,ldd->bhdl', key.permute(0,2,3,1), R_matrix)
        score = query_rot @ key_rot
        if self.causal_attention:
            score += causal_mask(size=score.shape, device=score.device)
        score = score / torch.sqrt(torch.tensor(self.head_dim))
        attention = torch.softmax(score, dim=-1)
        x = attention @ value.permute(0,2,1,3)
        x = x.permute(0, 2, 1, 3).reshape(input_shape)
        x += x_res
        x_res = x
        x = self.rms(x)
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x + x_res
```

This reflects the architecture in the diagram included at the beginning of this post.

Let's now take advantage of `LightningModule` to easily define the training, validation and test loops, the optimizer and the learning rate scheduler, as well as the prediction (we will call it `generation`).

`SimpleModule` is our customized class that inherits from `LightningModule`. The `SimpleModule` class starts by taking in two main components: the model architecture (here our Llama architecture defined above) and the tokenizer (again, defined above). Here's how you would instantiate `SimpleModule`:

```python
model = SimpleModule(
    transformer,
    tokenizer=tokenizer
)
```

And how `SimpleModule` is initialized:

```python
class SimpleModule(pl.LightningModule):
    def __init__(
        self,
        model: nn.Module,
        tokenizer: AutoTokenizer,
    ):
        super().__init__()
        self.model = model
        self.loss = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
        self.tokenizer = tokenizer
        self.logger_table_data = []
```

The `tokenizer` is used to specify the pad token to ignore when calculating the loss; if not specified, `CrossEntropyLoss` defaults to ignoring the index -100.

`self.logger_table_data` is a list we will use to log some examples to wandb at the end of each validation - we will see how to do it later in our post.

Our `forward` method is straightforward, calling the `forward` method of our `self.model`:

```python
def forward(self, x):
    return self.model(x)
```

The `training_step`, `validation_step` and `test_step` are also standard; these methods handle what happens during each training, validation and test step. We will include here only the code for `training_step`, as they all call `_get_preds_loss` to get the loss of the current batch and log it:

```python
def _get_preds_loss(self, batch):
    x, y, _ = batch
    y_hat = self.model(x)
    loss = self.loss(y_hat.view(-1, y_hat.shape[-1]), y.view(-1))
    return y_hat, loss

def training_step(self, batch, batch_idx):
    _, loss = self._get_preds_loss(batch)
    self.log('train_loss', loss)
    return loss
```

Remember, the `__getitem__` method in `CLMDataset` returns input tokens, target tokens (input tokens shifted by one position), and attention masks, which are unpacked here using `x, y, _ = batch`. Also, as always, tensor reshaping is crucial for calculating the loss properly!

Now let's see how we can generate some examples (using `generate` - we will include the code for it in a bit!) and log them at the end of each validation step using the `on_validation_end` method:

```python
def on_validation_end(self) -> None:
    _, output_decoded = self.generate(context_len=self.model.context_len, max_output_token=50)
    print(f"Full Text: \n{output_decoded}")
    current_epoch = len(self.logger_table_data) - 1
    self.logger_table_data.append([current_epoch, output_decoded])
    self.logger.log_table(
        key="Example Text Generation",
        columns=["Epoch", "Text"],
        data=self.logger_table_data,
    )
    return super().on_validation_end()
```

`LightningModule` also allows us to easily configure the optimizer by overriding the `configure_optimizers` method in our custom `SimpleModule`:

```python
def configure_optimizers(self):
    max_step = self.trainer.max_epochs * len(self.trainer.datamodule.train_dataloader())
    optimizer = torch.optim.AdamW(self.parameters(), lr=3e-4, weight_decay=0.1, betas=(0.9, 0.95))
    scheduler = {
        'scheduler': OneCycleLR(
            optimizer,
            max_lr=3e-4,
            total_steps=max_step,
            pct_start=0.03,
            anneal_strategy='cos',
        ),
        'interval': 'step',
        'frequency': 1
    }
    return {'optimizer': optimizer, 'lr_scheduler': scheduler}
```

This method returns a dictionary containing the optimizer and the learning rate scheduler to be used by the PyTorch Lightning `Trainer` - which we'll define in a second! The optimizer is `AdamW` (very straightforward to use!) and the learning rate scheduler sets the learning rate of each parameter group according to the 1cycle learning rate policy (`OneCycleLR`). Let's see all the components:

- `max_lr=3e-4`: sets the maximum learning rate.
- `total_steps=max_step`: aligns the total number of steps with the calculated `max_step`, the maximum number of epochs multiplied by the number of batches in our training set.
- `pct_start=0.03`: specifies that 3% of the total steps will be used for the warm-up phase.
- `anneal_strategy='cos'`: uses cosine annealing for the learning rate schedule.
- `interval`: specifies the scheduler should update at every step; as an alternative we could update it at every epoch.
- `frequency`: sets the update frequency to 1, meaning the scheduler updates every time it's called.
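The shape of the resulting schedule can be sketched in pure Python. This is an illustrative approximation of the one-cycle policy (the `initial_div` and `final_div` factors are assumptions, not `OneCycleLR`'s exact defaults): the learning rate climbs during the 3% warm-up, peaks at `max_lr`, then decays along a cosine curve.

```python
import math

def one_cycle_lr(step, total_steps, max_lr, pct_start=0.03, initial_div=25.0, final_div=1e4):
    """Sketch of a one-cycle schedule: short warm-up to max_lr, then cosine decay.
    The div factors are illustrative, not OneCycleLR's exact defaults."""
    warmup_steps = int(total_steps * pct_start)
    if step < warmup_steps:
        start_lr = max_lr / initial_div
        return start_lr + (max_lr - start_lr) * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    min_lr = max_lr / final_div
    return min_lr + (max_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

total_steps = 1000
lrs = [one_cycle_lr(s, total_steps, max_lr=3e-4) for s in range(total_steps)]
peak_step = max(range(total_steps), key=lambda s: lrs[s])  # peak sits right after the 3% warm-up
```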

Since our `SimpleModule` inherits from `LightningModule`, it has several built-in attributes and methods, among which `self.logger` (used in our `on_validation_end`) and `self.trainer` (used in `configure_optimizers`). When we create our `Trainer` object (later in our post) and define our custom attributes `logger` and `trainer`, PyTorch Lightning will internally set both `self.logger` and `self.trainer` within our `LightningModule` (`SimpleModule`) - one more reason to use Lightning!

One of the exciting parts of `SimpleModule` is its token generation capabilities. Whether you want to use greedy decoding, random sampling, top-k, or top-p sampling, it has you covered.

The `_single_generate` method in `SimpleModule` generates a single token based on various strategies. You can control the behaviour using the arguments `temperature`, `top_k`, `top_p`, and `greedy`.

```python
def _single_generate(self, idx, context_len, temperature, top_k, top_p, greedy):
    logits = self(idx[:, -context_len:])[:, -1, :]
    logits = logits / temperature
    if greedy:
        return torch.argmax(logits, dim=1).reshape(-1, 1)
    # Initialize mask with ones
    mask = torch.ones_like(logits).bool()
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=1), dim=1)
        sorted_mask = cumulative_probs > top_p
        # Ensure at least the most probable is included if sorted_mask contains all True
        if sorted_mask.all():
            sorted_mask[..., :1] = 0
        to_scatter = sorted_mask.type_as(logits) * float('-inf')
        to_scatter[sorted_mask == 0] = logits.gather(1, sorted_indices)[sorted_mask == 0]
        logits.scatter_(1, sorted_indices, to_scatter)
    elif top_k > 0:
        top_k = min(top_k, logits.shape[1])
        values, _ = torch.topk(logits, top_k)
        # smallest allowed value
        kth_values = values[..., -1]
        logits = torch.where(logits < kth_values.unsqueeze(-1),
                             torch.tensor(float('-inf')).type_as(logits),
                             logits)
    probs = torch.softmax(logits, dim=1)
    m = Categorical(probs)
    idx_next_token = m.sample()
    return idx_next_token.reshape(-1, 1)
```

Let's check how to use the different strategies with `_single_generate`:

- **Greedy Decoding**: chooses the most likely next token at each time step. Set `greedy=True`.
- **Random Sampling**: samples from the distribution of the next tokens. Set `greedy=False` and both `top_k=0` and `top_p=0`.
- **Top-k Sampling**: samples from the top k most likely next tokens. Set `top_k` to a value greater than 0 and `top_p=0`.
- **Top-p Sampling**: samples from the smallest set of tokens whose cumulative probability exceeds `p`. Set `top_p` to a value between 0 (excluded) and 1.
- **Temperature**: controls the randomness. Higher values make the output more random, while lower values concentrate probability mass on the already-likely tokens. Adjust `temperature` to control the randomness.
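The four strategies can be sketched in pure Python on a plain list of logits. This is an illustrative toy (the `sample_next` helper and the numbers are made up), not the `_single_generate` implementation above, but it applies the same filtering rules.

```python
import math
import random

def sample_next(logits, temperature=1.0, top_k=0, top_p=0.0, greedy=False):
    """Toy re-implementation of the decoding strategies on a plain list of logits."""
    logits = [l / temperature for l in logits]
    if greedy:
        return max(range(len(logits)), key=lambda i: logits[i])
    exps = [math.exp(l) for l in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    order = sorted(range(len(logits)), key=lambda i: probs[i], reverse=True)
    allowed = set(order)                       # random sampling: every token allowed
    if top_p > 0.0:
        allowed, cum = set(), 0.0
        for i in order:                        # smallest prefix whose cumulative mass exceeds top_p
            allowed.add(i)
            cum += probs[i]
            if cum > top_p:
                break
    elif top_k > 0:
        allowed = set(order[:top_k])           # only the k most likely tokens
    mass = sum(probs[i] for i in allowed)      # renormalize over the allowed tokens and sample
    r, cum = random.random() * mass, 0.0
    for i in order:
        if i in allowed:
            cum += probs[i]
            if r <= cum:
                return i
    return order[0]

logits = [2.0, 1.0, 0.5, -1.0]
greedy_choice = sample_next(logits, greedy=True)   # always the highest-logit token (index 0)
top1_choice = sample_next(logits, top_k=1)         # top-1 sampling collapses to greedy
```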

What if we want to generate more than a single token? For that, we can use the `generate` method in `SimpleModule`. This method generates multiple tokens by calling `_single_generate` for each token and then uses the tokenizer to decode the generated token IDs.

```python
def generate(self, context_len, max_output_token, temperature=1, top_k=0, top_p=0.9, greedy=False):
    idx = torch.tensor([self.tokenizer.bos_token_id]).unsqueeze(0).to(self.device)
    for _ in range(max_output_token):
        next_token = self._single_generate(idx, context_len, temperature, top_k, top_p, greedy)
        idx = torch.cat([idx, next_token], dim=1)
        if next_token.item() == self.tokenizer.eos_token_id:
            break
    decoded = self.tokenizer.decode(idx[0], skip_special_tokens=False)
    return idx, decoded
```

Let's now explore the `ModelTrainer` class, a wrapper that configures and runs training using PyTorch Lightning. This class not only handles the model training but also integrates seamlessly with Weights and Biases (Wandb) for experiment tracking.

First, here is the code to initialize the `ModelTrainer` and use it to train our `model`:

```python
lr_monitor_callback = LearningRateMonitor(logging_interval='step')
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_last=False,
    filename='{epoch}-{val_loss:.2f}',
    auto_insert_metric_name=False
)
modeltrainer = ModelTrainer(
    wandb_project_name=cfg.wandb_project_name,
    wandb_entity_name=cfg.wandb_entity_name,
    wandb_disable_log=cfg.wandb_disable_log,
    model=model,
    datamodule=datamodule,
    max_epochs=cfg.trainer.max_epochs,
    check_val_every_n_epoch=cfg.trainer.check_val_every_n_epoch,
    callbacks=[lr_monitor_callback, checkpoint_callback]
)
trainer = modeltrainer.train()
modeltrainer.wandb_logger.experiment.config.update(OmegaConf.to_container(cfg))
```

Once again, we take advantage of `Lightning`, here to: 1) automatically monitor and log the learning rate during training (`LearningRateMonitor`) and 2) save the model periodically by monitoring the validation loss (`ModelCheckpoint`). We do that using two of `Lightning`'s built-in callbacks. We can think of callbacks as planned function calls at specific points that allow you to inject custom behaviour into the training loop without having to modify the core training logic.

This is as easy as passing the callbacks to the `Trainer` (inside our custom `ModelTrainer`) as an array of Callback class instances.

Our `ModelTrainer` is a straightforward class that looks like this:

```python
class ModelTrainer:
    def __init__(
        self,
        wandb_project_name,
        wandb_entity_name,
        wandb_disable_log,
        model,
        datamodule,
        max_epochs,
        check_val_every_n_epoch,
        callbacks
    ):
        self.wandb_project_name = wandb_project_name
        self.wandb_entity_name = wandb_entity_name
        self.wandb_disable_log = wandb_disable_log
        self.model = model
        self.datamodule = datamodule
        self.max_epochs = max_epochs
        self.check_val_every_n_epoch = check_val_every_n_epoch
        self.callbacks = callbacks
        self.wandb_logger = self._wandb_init()
        self.wandb_logger.watch(self.model)

    def _wandb_init(self):
        return WandbLogger(
            project=self.wandb_project_name,
            entity=self.wandb_entity_name,
            offline=self.wandb_disable_log
        )

    def wandb_close(self):
        self.wandb_logger.experiment.unwatch(self.model)
        self.wandb_logger.experiment.finish()

    def train(self):
        trainer = Trainer(
            max_epochs=self.max_epochs,
            callbacks=self.callbacks,
            logger=self.wandb_logger,
            check_val_every_n_epoch=self.check_val_every_n_epoch,
            gradient_clip_val=1.0,
            gradient_clip_algorithm="norm",
            num_sanity_val_steps=None
        )
        trainer.fit(model=self.model, datamodule=self.datamodule)
        return trainer
```

The `Trainer` is a standard `Lightning` trainer; it is worth noting that here we use gradient clipping to avoid exploding gradients. `gradient_clip_val=1.0` sets the maximum allowable value for the gradients during backpropagation, and `gradient_clip_algorithm="norm"` specifies that the L2 norm is used for the clipping.
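Clipping by norm is simple to sketch on a flat list of gradients (a toy illustration, not Lightning's internals): if the global L2 norm exceeds the threshold, every gradient is scaled down by the same factor, preserving the gradient's direction.

```python
import math

def clip_grad_norm(grads, max_norm):
    """If the global L2 norm exceeds max_norm, rescale all gradients by the same factor."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return grads
    scale = max_norm / total_norm
    return [g * scale for g in grads]

grads = [3.0, 4.0]                               # global L2 norm = 5.0
clipped = clip_grad_norm(grads, max_norm=1.0)    # scaled so the norm becomes 1.0
clipped_norm = math.sqrt(sum(g * g for g in clipped))
```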

The `logger=self.wandb_logger` part integrates Wandb for logging and experiment tracking, where `self.wandb_logger` is a `WandbLogger`, a specialized logger provided by PyTorch Lightning to interface seamlessly with Wandb. This logger makes it easy to log all sorts of training metadata directly to the Wandb interface, where you can visualize it in real time.

In the code snippet above (how to initialize and call `ModelTrainer`) we used this logger to update the Wandb experiment configuration:

`modeltrainer.wandb_logger.experiment.config.update(OmegaConf.to_container(cfg))`

Here, the experiment's configuration is handled using Hydra.

Right, now that we have trained our Llama model, let's use it for token generation! For that, let's compare some examples of token generation using the different sampling methods described above:

```python
generation_config = {
    "greedy": {"temperature": 1, "top_k": 0, "top_p": 0.0, "greedy": True},
    "rnd_sampling": {"temperature": 1, "top_k": 0, "top_p": 0.0, "greedy": False},
    "rnd_sampling_t": {"temperature": 0.7, "top_k": 0, "top_p": 0.0, "greedy": False},
    "topk_sampling": {"temperature": 1, "top_k": 40, "top_p": 0.0, "greedy": False},
    "topk_sampling_t": {"temperature": 0.7, "top_k": 40, "top_p": 0.0, "greedy": False},
    "topp_sampling": {"temperature": 1, "top_k": 0, "top_p": 0.9, "greedy": False},
    "topp_sampling_t": {"temperature": 0.7, "top_k": 0, "top_p": 0.9, "greedy": False},
}

for conf_k, conf_v in generation_config.items():
    _, outputs_decoded = model.generate(context_len=cfg.model.context_len, max_output_token=300, **conf_v)
    print(f"\nFull Text, {conf_k}: \n{outputs_decoded}")
```

This produces the following results:

```
Full Text, greedy: [BOS]KING RICHARD III:And be a happy mother by the deed.[EOS]
Full Text, rnd_sampling: [BOS]CATESBY:Madam, his majesty doth call for you,And for your grace; and you, my noble lords.[EOS]
Full Text, rnd_sampling_t: [BOS]DUKE VINCENTIO:Good morning to you, fair and gracious daughter.[EOS]
Full Text, topk_sampling: [BOS]LUCIO:I believe thee; for I think thou never wast wheregrace was said.[EOS]
Full Text, topk_sampling_t: [BOS]First Servingman:But when goes this forward?[EOS]
Full Text, topp_sampling: [BOS]KATHARINA: Buckingham, I say, sir, that I do love.[EOS]
Full Text, topp_sampling_t: [BOS]PETRUCHIO:I see you do not mean to part with her,Or else you like not of my company.[EOS]
```

Remember, `[BOS]` and `[EOS]` are the special tokens we defined to mark the beginning and end of a sentence.

The results are not perfect, but we think they look very promising since we are only training for 10 epochs, using a small network (8 layers), hidden dimension (1024), context length (256) and training batch size (8). You can check the wandb run to see all the configurations and generation examples during training for this experiment.

The whole code for training our Llama model and generating some examples can easily be run with:

`python baby_llama/run.py`

If you want to run experiments using different configurations (e.g. number of epochs, hidden dimension, etc.), you can easily do it using Hydra! By running `python baby_llama/run.py -h` you can see which arguments you can change for your experiment:

```
== Configuration groups ==
Compose your configuration from those groups (group=option)

dataset: tinyshakespeare
model: llama
trainer: standard

== Config ==
Override anything in the config (foo.bar=value)

dataset:
  name: tinyshakespeare
  path: /home/sara/github_code/BabyLlama/data/tinyshakespeare.txt
  tokenizer_path: /home/sara/github_code/BabyLlama/data/tokenizer/
  train_ratio: 0.8
  val_ratio: 0.2
  test_ratio: 0.0
model:
  context_len: 256
  hidden_size: 1024
  n_heads: 8
  n_blocks: 8
  swiglu_d_moltiplier: 2.67
trainer:
  max_epochs: 10
  check_val_every_n_epoch: 1
  num_workers: 4
  train_batchsize: 8
  val_test_batchsize: 8
wandb_project_name: baby_llama
wandb_entity_name: sara
wandb_disable_log: false
```

Stabilizing and accelerating the training of neural networks often hinge on the normalization techniques employed. While the theory behind normalization appears straightforward, its practical applications come in various flavours, each with unique merits and shortcomings.

This post will explore three popular types of normalizations:

Batch Normalization (BatchNorm)

Layer Normalization (LayerNorm)

Root Mean Square Layer Normalization (RMSNorm)

We'll cover:

The mathematics behind each technique

A discussion on computational complexity

The pros and cons of each method

Batch Normalization, introduced by Sergey Ioffe and Christian Szegedy [1], aims to normalize the outputs of a layer across each feature dimension for a given mini-batch during training. To put it simply, it uses the statistics (mean and variance) computed across all instances in the mini-batch.

The output \(\hat{x}\) is computed as:

\[\hat{x} = \frac{x - \mathbb{E}_{\text{mini-batch}}(x)}{\sqrt{Var_{\text{mini-batch}}(x) + \epsilon}} \cdot \gamma + \beta\]

Here, \(\mathbb{E}_{\text{mini-batch}}(x)\) and \(Var_{\text{mini-batch}}(x)\) are the mean and variance, computed per feature over the mini-batch, and \(\epsilon\) is a small constant for numerical stability. \(\gamma\) and \(\beta\) are scaling and shifting learnable parameters, respectively.

**Running Statistics**

Batch Norm also demands the calculation and storage of running statistics for both the mean and variance. During training, these are calculated as the exponential moving average (EMA), updated using a scalar momentum term \(\alpha\), such that \(y_{EMA_i} = \alpha y_{EMA_{i-1}} + (1 - \alpha)y_i\) where \(i\) is the current training step. During inference, the stored running statistics are used to normalise the single sample.
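The EMA update is a one-liner; here is a toy illustration (the momentum value and the stream of batch means are made up) showing the running mean creeping toward the true statistic over successive mini-batches.

```python
def ema_update(running, value, momentum=0.9):
    """The exponential moving average update used for BatchNorm's running statistics."""
    return momentum * running + (1 - momentum) * value

running_mean = 0.0
for batch_mean in [1.0, 1.0, 1.0, 1.0]:   # four mini-batches whose per-batch mean is 1.0
    running_mean = ema_update(running_mean, batch_mean)
# running_mean creeps toward 1.0: after 4 updates it equals 1 - 0.9**4 = 0.3439
```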

Compared to no-normalization, Batch Norm:

Reduces internal covariate shift (i.e. reduces the change in the distributions of layers' input)

Speeds up convergence

Enables higher learning rates

Makes the network less sensitive to initialization

**Python Implementation**

```python
class BatchNorm(nn.Module):
    def __init__(
        self,
        size: int,
        eps: float = 1e-5,
    ):
        """
        Batch Normalization.
        Assumes the shape of the input x is (batch, seq_len, d_model)

        Args:
            size: shape of the feature dimension (i.e. d_model)
            eps: For numerical stability. Defaults to 1e-5.
        """
        super(BatchNorm, self).__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(size), requires_grad=True)
        # beta is the shift parameter, so it starts at zero
        self.beta = nn.Parameter(torch.zeros(size), requires_grad=True)

    def forward(self, x):
        x_var, x_mean = torch.var_mean(x, dim=[0, 1], keepdim=True, correction=0)
        x_std = torch.sqrt(x_var + self.eps)
        x_norm = (x - x_mean) / x_std
        return self.gamma.unsqueeze(0).unsqueeze(1) * x_norm + self.beta.unsqueeze(0).unsqueeze(1)
```

Assuming our input \(x\) has the shape (`batch`, `seq_len`, `d_model`), for batch normalization we normalize across both the batch and sequence length dimensions (`0` and `1` respectively), but keep the feature dimension (`d_model`) intact. This is because BatchNorm aims to stabilize the distribution of each feature over the mini-batch.
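A quick pure-Python check (a random toy tensor built from nested lists, no PyTorch) confirms which dimensions are pooled: after normalizing with per-feature statistics computed over dims 0 and 1, each feature channel has numerically zero mean.

```python
import random

random.seed(0)
batch, seq_len, d_model = 4, 8, 3
x = [[[random.gauss(5.0, 2.0) for _ in range(d_model)]
      for _ in range(seq_len)] for _ in range(batch)]

# BatchNorm statistics: one mean/std per feature, pooled over batch AND seq_len (dims 0 and 1)
n = batch * seq_len
means = [sum(x[b][t][f] for b in range(batch) for t in range(seq_len)) / n
         for f in range(d_model)]
stds = [(sum((x[b][t][f] - means[f]) ** 2 for b in range(batch) for t in range(seq_len)) / n) ** 0.5
        for f in range(d_model)]

x_norm = [[[(x[b][t][f] - means[f]) / stds[f] for f in range(d_model)]
           for t in range(seq_len)] for b in range(batch)]

# after normalization, each feature channel has (numerically) zero mean over the pooled dims
channel_mean = sum(x_norm[b][t][0] for b in range(batch) for t in range(seq_len)) / n
```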

Layer Normalization [2], unlike Batch Norm, normalizes the features for each individual data point in a batch, making it less susceptible to variations in batch size.

The output \(\hat{x}\) is computed similarly to Batch Norm but differs in the axis over which \(\mathbb{E}(x)\) and \(Var(x)\) are computed.

\[\hat{x} = \frac{x - \mathbb{E}_{\text{features}}(x)}{\sqrt{Var_{\text{features}}(x) + \epsilon}} \cdot \gamma + \beta\]

Here, \(\mathbb{E}_{\text{features}}(x)\) and \(Var_{\text{features}}(x)\) are the mean and variance calculated over the feature dimension.

Less sensitive to batch size than Batch Norm

Works well for sequence models

Stabilizes training

Accelerates convergence

**Python Implementation**

```python
class LayerNorm(nn.Module):
    def __init__(
        self,
        size: int,
        eps: float = 1e-5,
    ):
        """
        Layer Normalization.
        Assumes the shape of the input x is (batch, seq_len, d_model)

        Args:
            size: shape of the feature dimension (i.e. d_model)
            eps: For numerical stability. Defaults to 1e-5.
        """
        super(LayerNorm, self).__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(size), requires_grad=True)
        # beta is the shift parameter, so it starts at zero
        self.beta = nn.Parameter(torch.zeros(size), requires_grad=True)

    def forward(self, x):
        x_var, x_mean = torch.var_mean(x, dim=-1, keepdim=True, correction=0)
        x_std = torch.sqrt(x_var + self.eps)
        x_norm = (x - x_mean) / x_std
        return self.gamma.unsqueeze(0).unsqueeze(1) * x_norm + self.beta.unsqueeze(0).unsqueeze(1)
```

Assuming our input \(x\) has the shape (`batch`, `seq_len`, `d_model`), Layer Norm normalizes across the feature dimension (`d_model`) for each sequence in the batch. The rationale is to normalize all features for a single data point to have zero mean and unit variance, making the model less sensitive to the scale of input features.

RMSNorm [3] is a variant of LayerNorm that 1) uses the root mean square, \(\sqrt{\mathbb{E}(x^2)}\), instead of the standard deviation for re-scaling and 2) does not use the re-centering operation. The authors hypothesize that the re-centering invariance property in LayerNorm is dispensable and keep only the re-scaling invariance property in RMSNorm.

The output \(\hat{x}\) is calculated as:

\[\hat{x} = \frac{x}{ \sqrt{\mathbb{E}_{\text{feature}}(x^2) + \epsilon}} \cdot \gamma\]

- Computationally simpler and thus more efficient than Layer Norm

**Python Implementation**

```python
class RMSNorm(nn.Module):
    def __init__(
        self,
        size: int,
        eps: float = 1e-5,
    ):
        """
        Root-Mean-Square Layer Normalization.
        Assumes the shape of the input x is (batch, seq_len, d_model)

        Args:
            size: shape of the feature dimension (i.e. d_model)
            eps: For numerical stability. Defaults to 1e-5.
        """
        super(RMSNorm, self).__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(size), requires_grad=True)

    def forward(self, x):
        # as an alternative, the Frobenius norm can also be used to compute rms
        rms = torch.sqrt((x ** 2).mean(dim=-1, keepdim=True) + self.eps)
        x_norm = x / rms
        return self.gamma.unsqueeze(0).unsqueeze(1) * x_norm
```

Assuming our input \(x\) has the shape (`batch`, `seq_len`, `d_model`), RMS Layer Normalization, like Layer Norm, normalizes across the feature dimension (`d_model`), using the root mean square of the feature values for each data point in the sequence. This method is computationally efficient and can be more robust to outliers.

- **Batch Norm**: Requires storage of running statistics, making it harder to parallelize.
- **Layer Norm**: Less computationally intensive as no running statistics are needed.
- **RMSNorm**: Even less computationally intensive than LayerNorm due to no re-centering.

- **Batch Normalization**: This technique is particularly strong in convolutional architectures where batch sizes are often large enough for the mean and variance estimates to be reliable. However, it's not ideal for models like RNNs and Transformers where sequence lengths can vary. Its reliance on running statistics also poses challenges for online learning scenarios and can add complexity when attempting to parallelize the model across multiple devices.
- **Layer Normalization**: Highly effective for sequence models such as RNNs and Transformers. It's also a better choice for scenarios with small batch sizes as it computes statistics for each data point independently, negating the need for a large batch to estimate population statistics.
- **RMSNorm**: If computational efficiency is your priority, RMSNorm offers a simpler equation that's less computationally intensive than LayerNorm. Experiments on several NLP tasks show that RMSNorm is comparable to LayerNorm in quality, but accelerates the running speed [3].

Incorporating the right normalization technique can make or break your deep learning model. This post has aimed to provide a theoretical and practical overview of Batch Normalization, Layer Normalization, and RMS Layer Normalization. The Python implementations should help you get a hands-on understanding of how these techniques work at a granular level.

[1] Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift, Sergey Ioffe and Christian Szegedy, 2015

[2] Layer normalization, Jimmy Lei Ba et al., 2016

[3] Root Mean Square Layer Normalization, Biao Zhang and Rico Sennrich, 2019

The transformer architecture has seen a meteoric rise in its applications across various domains of machine learning. However, the architecture lacks an inherent understanding of the order or sequence of tokens. This necessitates some form of positional encoding, such as Rotary Positional Embedding (RoPE) [1]. This blog post delves into the mathematical formulation of RoPE and its practical implementation in PyTorch.

Transformers employ self-attention or cross-attention mechanisms that are agnostic to the order of tokens. This means the model perceives the input tokens as a set rather than a sequence. It thereby loses crucial information about the relationships between tokens based on their positions in the sequence. To mitigate this, positional encodings are utilized to embed information about the token positions directly into the model.

RoPE is unique in that it encodes the absolute position \(m\) of tokens through a rotation matrix \(R_{\Theta, m}\), and also incorporates the explicit relative position dependency in the self-attention formulation. The idea is to embed the position of a token in a sequence by rotating queries and keys, with a different rotation at each position.

By rotating each query/key dependent on where they are in the sequence, we increasingly reduce the dot product (due to encouraging increasing levels of misalignment) depending on how far tokens are away from each other.

For a 2D query vector, \(q_m = \begin{pmatrix} q_m^{(1)} \\ q_m^{(2)} \end{pmatrix}\), at a single position \(m\), the rotation matrix \(R_{\Theta, m}\) is a 2x2 matrix formulated as:

\[R_{\Theta,m} = \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix}\]

where \(\theta \in \mathbb{R}\) is a preset non-zero constant.

We can use this matrix to rotate a vector \(q\), to obtain the new rotated vector \(\hat{q}\):

$$\hat{q} = \left(\begin{array}{cc} \cos m \theta & -\sin m \theta \\ \sin m \theta & \cos m \theta \end{array}\right) \left(\begin{array}{l} q_m^{(1)} \\ q_m^{(2)} \end{array}\right)$$

Here is a visualization of the rotation applied to the 2D vector \(q\). The initial vector is displayed in black and the red vectors are the rotated versions of the initial vector.
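The rotation can be reproduced in a few lines of plain Python (toy vector and angle, independent of the model code). Note that the rotation changes only the vector's direction; its length is preserved, which is what makes the rotated dot product a pure function of angles.

```python
import math

def rotate(q, m, theta=1.0):
    """Rotate a 2D vector q by the angle m * theta -- the 2D RoPE rotation above."""
    c, s = math.cos(m * theta), math.sin(m * theta)
    return (c * q[0] - s * q[1], s * q[0] + c * q[1])

q = (1.0, 0.5)
q_rot = rotate(q, m=3)
norm_before = math.hypot(*q)
norm_after = math.hypot(*q_rot)   # unchanged: rotations preserve length
```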

To generalize to any even dimension \(d \ge 2\), we divide the \(d\)-dimensional space into \( \frac{d}{2}\) sub-spaces and apply rotations individually.

Let's see a simple example with \(d=4\); the rotary matrix is defined as:

\[R^d_{\Theta,m} = \begin{pmatrix} \cos m\theta_1 & -\sin m\theta_1 & 0 & 0 \\ \sin m\theta_1 & \cos m\theta_1 & 0 & 0 \\ 0 & 0 & \cos m\theta_2 & -\sin m\theta_2 \\ 0 & 0 & \sin m\theta_2 & \cos m\theta_2 \end{pmatrix}\]

where \(\Theta = \{\theta_i = 10000^{-2(i-1)/4}, i \in [1, 2]\}\) are the pre-defined parameters and \(i\) is the index of the \(\frac{d}{2}\) sub-spaces, corresponding to \(\theta_1 = 1\) and \(\theta_2 = 10000^{-1/2}\).

The general form of the parameters is:

\(\Theta = \{\theta_i = 10000^{-2(i -1) /d}, i \in [1, \dots, d/2]\}\)
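Plugging a hypothetical \(d = 8\) into this formula shows the geometric decay of the rotation frequencies across the sub-spaces:

```python
d = 8
thetas = [10000 ** (-2 * (i - 1) / d) for i in range(1, d // 2 + 1)]
# thetas ≈ [1.0, 0.1, 0.01, 0.001]: the first sub-space rotates fastest,
# later sub-spaces rotate ever more slowly as i grows
```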

In transformers, the rotation is applied to both queries and keys before the dot-product score is computed in the attention mechanism. For query \(q_m\) in position \(m\) and key \(k_n\) in position \(n\), the formula becomes:

\[q_m^T k_n = (R_{\theta,m}^d q_m )^T (R_{\theta,n}^d k_n ) = q_m^T R_{\theta,n-m}^d k_n,\]

where \(R_{\theta,n-m} = (R_{\theta,m}^d)^T R_{\theta,n}^d \) .
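This relative-position property is easy to verify numerically in 2D with made-up query/key vectors: every pair of positions with the same offset \(n - m\) yields exactly the same score.

```python
import math

def rotate(v, angle):
    """Apply a 2D rotation by the given angle."""
    c, s = math.cos(angle), math.sin(angle)
    return (c * v[0] - s * v[1], s * v[0] + c * v[1])

def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]

q, k, theta = (0.3, -1.2), (0.8, 0.5), 0.7

# All three (m, n) pairs share the same offset n - m = 3, so the scores coincide
scores = [dot(rotate(q, m * theta), rotate(k, n * theta))
          for m, n in [(1, 4), (11, 14), (101, 104)]]
```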

```python
def get_rotary_matrix(context_len: int, embedding_dim: int) -> torch.Tensor:
    """
    Generate the Rotary Matrix for RoPE

    Args:
        context_len (int): context len
        embedding_dim (int): embedding dim

    Returns:
        torch.Tensor: the rotary matrix of dimension context_len x embedding_dim x embedding_dim
    """
    R = torch.zeros((context_len, embedding_dim, embedding_dim), requires_grad=False)
    positions = torch.arange(1, context_len+1).unsqueeze(1)
    # Create matrix theta (shape: context_len x embedding_dim // 2)
    slice_i = torch.arange(0, embedding_dim // 2)
    theta = 10000. ** (-2.0 * (slice_i.float()) / embedding_dim)
    m_theta = positions * theta
    # Create sin and cos values
    cos_values = torch.cos(m_theta)
    sin_values = torch.sin(m_theta)
    # Populate the rotary matrix R using 2D slicing
    R[:, 2*slice_i, 2*slice_i] = cos_values
    R[:, 2*slice_i, 2*slice_i+1] = -sin_values
    R[:, 2*slice_i+1, 2*slice_i] = sin_values
    R[:, 2*slice_i+1, 2*slice_i+1] = cos_values
    return R

# ... Init parameters and random input
batch_size = 16
context_len = 128
embedding_dim = 1024
ff_q = nn.Linear(embedding_dim, embedding_dim, bias=False)
ff_k = nn.Linear(embedding_dim, embedding_dim, bias=False)
x = torch.randn((batch_size, context_len, embedding_dim))

# ... In the attention computation
queries = ff_q(x)
keys = ff_k(x)
R_matrix = get_rotary_matrix(context_len, embedding_dim)
queries_rot = (queries.transpose(0, 1) @ R_matrix).transpose(0, 1)
keys_rot = (keys.transpose(0, 1) @ R_matrix).transpose(0, 1)

# ... Compute the score in the attention mechanism using the rotated queries and keys
```

- **Initialization**: We initialize a 3D tensor `R` to store the rotary matrices for each position \(m\): `R = torch.zeros((context_len, embedding_dim, embedding_dim), requires_grad=False)`
- **Positional and Dimensional Slicing**: We create a tensor of positions and another tensor `slice_i` for indexing into our \(d\)-dimensional space: `positions = torch.arange(1, context_len+1).unsqueeze(1)`, `slice_i = torch.arange(0, embedding_dim // 2)`
- **Theta Calculation**: A decay factor \(\theta\) is calculated, which scales down the effect of positions for higher dimensions: `theta = 10000. ** (-2.0 * (slice_i.float()) / embedding_dim)`, `m_theta = positions * theta`
- **Sin and Cos Values**: We calculate the sine and cosine values for each \(m\theta\): `cos_values = torch.cos(m_theta)`, `sin_values = torch.sin(m_theta)`
- **Populating `R`**: We use slicing to populate the rotary matrix efficiently: `R[:, 2*slice_i, 2*slice_i] = cos_values`, `R[:, 2*slice_i, 2*slice_i+1] = -sin_values`, `R[:, 2*slice_i+1, 2*slice_i] = sin_values`, `R[:, 2*slice_i+1, 2*slice_i+1] = cos_values`
- **Rotate the queries and keys**: Apply the rotary matrix to the queries and keys inside the attention computation: `queries_rot = (queries.transpose(0,1) @ R_matrix).transpose(0,1)`, `keys_rot = (keys.transpose(0,1) @ R_matrix).transpose(0,1)`

[1] Roformer: Enhanced transformer with rotary position embedding, Jianlin Su et al., 2021.
