Llama2 From Scratch with PyTorch Lightning

In our previous blog post, we built the Llama LLM with PyTorch Lightning, with Weights & Biases for experiment tracking and Hydra for configuration management.

Now, we turn our attention to Llama 2, the successor to Llama. Let's look at the differences:

  • Dataset: Llama2 benefits from a 40% increase in training data.

  • Context Length: Trained with a 4096 token context length, up from 2048.

  • Attention Mechanism: An architectural evolution from Multi-head Attention (MHA) to Grouped-Query Attention (GQA).

Apart from the switch to GQA, the architecture remains untouched. Thus, much of our Llama codebase carries over unchanged; only the attention block needs to be rewritten.
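
Before diving into the code, here is a minimal, implementation-agnostic sketch of the idea behind GQA: the query heads are split into groups, and all query heads within a group attend using the same key/value head. The dimensions below are arbitrary, and the snippet only illustrates the head bookkeeping, not the block we build next.

import torch

batch, seq_len, n_heads, n_groups, head_dim = 2, 16, 8, 4, 64
group_size = n_heads // n_groups   # query heads sharing one key/value head

q = torch.randn(batch, n_heads, seq_len, head_dim)   # one query head per attention head
k = torch.randn(batch, n_groups, seq_len, head_dim)  # only n_groups key heads

# Conceptually, GQA is MHA where each key/value head is reused by `group_size` query heads
# (the exact head-to-group assignment is an implementation detail):
k_expanded = k.repeat_interleave(group_size, dim=1)  # (batch, n_heads, seq_len, head_dim)
scores = q @ k_expanded.transpose(-1, -2) / head_dim ** 0.5
print(scores.shape)                                  # torch.Size([2, 8, 16, 16])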

Understanding the Grouped-Query Attention Block

Let's now see what the new Grouped-Query Attention block (GQALlamaBlock) looks like and break down the code:

class GQALlamaBlock(nn.Module):
    def __init__(
        self,
        embedding_size: int,
        context_len: int,
        causal_attention: bool,
        n_heads: int,
        n_groups: int,
        swiglu_d_multiplier: float
        ):
        super().__init__()
        self.embedding_size = embedding_size
        self.causal_attention = causal_attention
        self.n_heads = n_heads
        self.n_groups = n_groups 
        assert self.n_heads % self.n_groups == 0, f"Number of heads ({self.n_heads}) must be divisible by the number of groups ({self.n_groups})"
        self.group_size = self.n_heads // self.n_groups
        assert self.embedding_size % self.n_heads == 0, f"Embedding size ({self.embedding_size}) must be divisible by the number of heads ({self.n_heads})"
        self.head_dim = self.embedding_size // self.n_heads

        self.R = get_rotary_matrix(context_len=context_len, embedding_dim=self.head_dim)   
        self.rms = RMSnorm(size=embedding_size)
        self.ff_q = nn.Linear(embedding_size, embedding_size, bias=False)
        kv_embedding_size = self.head_dim * self.n_groups
        self.ff_k = nn.Linear(embedding_size, kv_embedding_size, bias=False)
        self.ff_v = nn.Linear(embedding_size, kv_embedding_size, bias=False)

        # In Llama paper swiglu_d_multiplier = 2/3 * 4 
        swiglu_size = int(swiglu_d_multiplier * embedding_size) 
        self.fc1 = nn.Linear(embedding_size, swiglu_size) 
        self.activation = SwiGLU(size=swiglu_size)
        self.fc2 = nn.Linear(swiglu_size, embedding_size)

    def forward(self, x):
        input_shape = x.shape
        q_resize = (x.shape[0], x.shape[1], self.n_heads, self.head_dim)
        kv_resize = (x.shape[0], x.shape[1], self.n_groups, self.head_dim)
        x_res = x
        x = self.rms(x) # pre-normalization
        query = self.ff_q(x).reshape(q_resize)
        key = self.ff_k(x).reshape(kv_resize)
        value = self.ff_v(x).reshape(kv_resize)

        # Apply the rotary matrix to query and key, separately for each head/group
        R_matrix = self.R[:input_shape[1], :, :].to(query.device) 
        query_rot = torch.einsum('bhld,ldd->bhld', query.permute(0,2,1,3), R_matrix)
        key_rot = torch.einsum('bgdl,ldd->bgdl', key.permute(0,2,3,1), R_matrix)

        query_rot = query_rot.reshape(input_shape[0], self.group_size, self.n_groups, input_shape[1], self.head_dim)
        score = torch.einsum('bsgld, bgdp->bsglp', query_rot, key_rot)
        if self.causal_attention:
            score += causal_mask(size=score.shape, device=score.device)
        score = score / torch.sqrt(torch.tensor(self.head_dim)) 
        attention = torch.softmax(score, dim=-1) 
        x = torch.einsum('bsgpl,bgld->bsgpd', attention, value.permute(0,2,1,3))
        x = x.reshape(input_shape[0], self.group_size*self.n_groups, input_shape[1], self.head_dim)
        x = x.permute(0, 2, 1, 3).reshape(input_shape)        
        x += x_res
        x_res = x
        x = self.rms(x)
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x + x_res
  • Initialization: The constructor sets up the block, ensuring the number of heads is divisible by the number of groups and the embedding size by the number of heads.

  • Matrices and Normalization: get_rotary_matrix generates a rotary position embedding matrix, while RMSnorm is used for layer normalization, both carried over from Llama.

  • Group-aware Feedforward (New!): ff_q, ff_k, and ff_v are linear layers for transforming inputs into queries, keys, and values, respectively. These are reshaped according to the number of heads and groups. The output size of ff_q is \(d \cdot h\), while ff_k and ff_v output \(d \cdot g\), where \(d\) is the dimensionality of each head/group, \(h\) is the number of heads, and \(g\) is the number of groups (see the shape sketch after this list).

  • SwiGLU Activation: A SwiGLU-based feedforward network with a custom size determined by swiglu_d_multiplier is employed.
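
To make these projection sizes concrete, here is a small sketch mirroring the constructor above (the numbers are arbitrary, not our training configuration):

import torch.nn as nn

embedding_size, n_heads, n_groups = 1024, 8, 4
head_dim = embedding_size // n_heads            # d = 128
kv_embedding_size = head_dim * n_groups         # d * g = 512

ff_q = nn.Linear(embedding_size, embedding_size, bias=False)     # out_features = d * h = 1024
ff_k = nn.Linear(embedding_size, kv_embedding_size, bias=False)  # out_features = d * g = 512
ff_v = nn.Linear(embedding_size, kv_embedding_size, bias=False)  # out_features = d * g = 512

print(ff_q.weight.shape, ff_k.weight.shape)  # torch.Size([1024, 1024]) torch.Size([512, 1024])

Only ff_k and ff_v shrink; ff_q keeps one full-size projection per attention head.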

The Forward Pass:

During the forward pass, the method reshapes queries, keys, and values, and applies the rotary embeddings separately to each head. This rotation aligns the attention mechanism with the relative position information.
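
As a refresher on what that rotation looks like (get_rotary_matrix was covered in the previous post; the function below is only a toy stand-in, not the repo's implementation), the rotary matrix for each position is a block-diagonal stack of 2x2 rotations, and applying it per head is a position-wise matrix-vector product:

import math
import torch

def toy_rotary_matrix(context_len: int, embedding_dim: int) -> torch.Tensor:
    """Toy stand-in for get_rotary_matrix: (context_len, d, d) block-diagonal 2x2 rotations."""
    R = torch.zeros(context_len, embedding_dim, embedding_dim)
    for pos in range(context_len):
        for i in range(embedding_dim // 2):
            theta = pos * 10000.0 ** (-2.0 * i / embedding_dim)
            c, s = math.cos(theta), math.sin(theta)
            R[pos, 2 * i, 2 * i], R[pos, 2 * i, 2 * i + 1] = c, -s
            R[pos, 2 * i + 1, 2 * i], R[pos, 2 * i + 1, 2 * i + 1] = s, c
    return R

R = toy_rotary_matrix(context_len=16, embedding_dim=8)
q = torch.randn(2, 4, 16, 8)                    # (batch, heads, seq_len, head_dim)
q_rot = torch.einsum('lde,bhle->bhld', R, q)    # rotate each head vector by R[position]
print(q_rot.shape)                              # torch.Size([2, 4, 16, 8])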

Again, most components are the same as we saw in Llama's MHALlamaBlock. Let's focus on the main differences, all done with the magic of reshape and torch.einsum!

  1. Calculating Attention Scores:

     query_rot = query_rot.reshape(
                 input_shape[0], 
                 self.group_size, 
                 self.n_groups, 
                 input_shape[1], 
                 self.head_dim
                 )
     score = torch.einsum('bsgld, bgdp->bsglp', query_rot, key_rot)
    

    The first line reshapes the rotated query to prepare it for attention score calculation.

    • input_shape[0] is the batch size

    • self.group_size is the size of each group within the heads. This is derived by dividing the total number of heads by the number of groups (n_heads / n_groups).

    • self.n_groups is the number of groups that the heads are divided into.

    • input_shape[1] is the sequence length.

    • self.head_dim is the dimensionality of each head.

This reshaping step is crucial because it aligns the data into a structure that reflects the grouping of attention heads. Each head within a group will contribute to a portion of the attention calculation, and this structure facilitates that process.
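
A quick way to see which query heads end up sharing a key/value group under this reshape is to track the head indices themselves (numbers are arbitrary):

import torch

n_heads, n_groups = 8, 4
group_size = n_heads // n_groups       # 2

print(torch.arange(n_heads).reshape(group_size, n_groups))
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]])
# Reading down a column: group g serves heads {g, g + n_groups, ...},
# i.e. this reshape produces a strided rather than contiguous grouping.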

The second line calculates the attention scores using torch.einsum. The notation 'bsgld, bgdp->bsglp' describes how the tensors are combined (a quick numeric check of this contraction follows the breakdown below).

  • bsgld corresponds to the reshaped rotated queries tensor:

    • b for batch size

    • s for the size of each group

    • g for the number of groups

    • l for the sequence length

    • d for the dimension of each head

  • bgdp corresponds to the rotated keys tensor:

    • b for batch size

    • g for the number of groups

    • d for the dimension of each head

    • p for the sequence length, which is the same as l but is labelled differently to indicate a different role in this operation (here, p is the 'target' position in the sequence that each 'source' position l is attending to).

When torch.einsum processes this operation, it does the following:

  • It aligns the rotated queries and keys based on the batch and group dimensions (b and g).

  • It then computes dot products between each query and key vector across the dimension d for each position in the sequence (l attending to p).

  • The result is a raw attention score tensor: bsglp.
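
To convince ourselves that this contraction is simply a batched matrix multiplication per (batch, sub-group, group), here is a quick check with random tensors (the shapes are arbitrary):

import torch

b, s, g, l, d = 2, 2, 4, 16, 32          # batch, group_size, n_groups, seq_len, head_dim
query_rot = torch.randn(b, s, g, l, d)
key_rot = torch.randn(b, g, d, l)        # keys already permuted to (b, g, d, p)

score = torch.einsum('bsgld,bgdp->bsglp', query_rot, key_rot)

# Equivalent batched matmul: broadcast the keys over the sub-group axis s.
score_ref = query_rot @ key_rot.unsqueeze(1)   # (b,s,g,l,d) @ (b,1,g,d,p) -> (b,s,g,l,p)

assert torch.allclose(score, score_ref, atol=1e-5)
print(score.shape)                       # torch.Size([2, 2, 4, 16, 16])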

  2. Applying Attention to Values:

     x = torch.einsum('bsgpl,bgld->bsgpd', attention, value.permute(0,2,1,3))
    
After applying the causal mask, the normalization, and the softmax function to the scores, this torch.einsum operation applies the calculated attention to the values.

Using a similar notation to the attention score calculation, bsgpl represents the attention weights. Note that the letters l and p swap roles here relative to the score einsum; what matters is that the contraction runs over the key/value positions, labelled l in this operation.

  • bgld represents the value vectors, permuted to align with the attention weights:

    • b for batch size

    • g for the number of groups

    • l for the sequence length

    • d for the dimension of each head

  • -> bsgpd indicates the resulting tensor's dimensions:

    • b for batch size

    • s for the size of each group

    • g for the number of groups

    • p for the sequence length (the query position whose output is being computed)

    • d for the dimension of each head

After applying the attention, the tensor must be reshaped back to the original input dimensions to maintain the consistency of the model's layers.

x = x.reshape(
    input_shape[0], 
    self.group_size*self.n_groups, 
    input_shape[1], 
    self.head_dim
    )

The tensor is reshaped to collapse the sub-group and group dimensions (s * g) back into a single dimension representing all heads. input_shape[0] is the batch size, and input_shape[1] is the sequence length, both derived from the original input tensor. self.head_dim is the dimension of each head.

The next line:

x = x.permute(0, 2, 1, 3).reshape(input_shape)

First, x.permute(0, 2, 1, 3) reorders the dimensions of x to match the order expected by the next layers in the model. .reshape(input_shape) then reshapes the tensor back to the original input shape, ensuring that the output of the attention block can be seamlessly integrated into the subsequent layers of the model.
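
Putting the pieces together, here is a shape-only dry run of the tensor plumbing in the forward pass (random tensors stand in for the linear projections, and the rotary embedding and causal mask are omitted), confirming that the block maps (batch, seq_len, embedding) back to the same shape:

import torch

b, l, e, n_heads, n_groups = 2, 16, 1024, 8, 4
head_dim = e // n_heads
group_size = n_heads // n_groups

x = torch.randn(b, l, e)
query = torch.randn(b, l, e).reshape(b, l, n_heads, head_dim)
key = torch.randn(b, l, n_groups * head_dim).reshape(b, l, n_groups, head_dim)
value = torch.randn(b, l, n_groups * head_dim).reshape(b, l, n_groups, head_dim)

query = query.permute(0, 2, 1, 3).reshape(b, group_size, n_groups, l, head_dim)
key = key.permute(0, 2, 3, 1)                             # (b, g, d, l)

score = torch.einsum('bsgld,bgdp->bsglp', query, key)
attention = torch.softmax(score / head_dim ** 0.5, dim=-1)

out = torch.einsum('bsgpl,bgld->bsgpd', attention, value.permute(0, 2, 1, 3))
out = out.reshape(b, group_size * n_groups, l, head_dim)  # collapse (s, g) back into heads
out = out.permute(0, 2, 1, 3).reshape(x.shape)            # back to (b, l, e)
assert out.shape == x.shape
print(out.shape)                                          # torch.Size([2, 16, 1024])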

Generation Examples

Right, now that we have our Llama2 model, let's use it for token generation! For that, let's compare some examples of token generation using the different sampling methods described in the Llama post, i.e. greedy, random sampling, top-k sampling, top-p sampling, and their variants including temperature scaling.
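
The sampling strategies themselves were described in the Llama post; as a reminder, here is a generic sketch of greedy, temperature-scaled, top-k, and top-p sampling from a vector of logits (illustrative only, not the exact functions used in the repo):

import torch

def sample_next_token(logits: torch.Tensor, strategy: str = "greedy",
                      temperature: float = 1.0, k: int = 50, p: float = 0.9) -> int:
    """Pick the next token id from a 1D logits vector. Illustrative sketch only."""
    logits = logits / temperature                            # temperature scaling
    if strategy == "greedy":
        return int(torch.argmax(logits))
    probs = torch.softmax(logits, dim=-1)
    if strategy == "topk":
        values, indices = torch.topk(probs, k)               # keep the k most likely tokens
        return int(indices[torch.multinomial(values / values.sum(), 1)])
    if strategy == "topp":
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        cutoff = int(torch.searchsorted(cumulative, torch.tensor(p))) + 1  # smallest nucleus covering p
        nucleus = sorted_probs[:cutoff] / sorted_probs[:cutoff].sum()
        return int(sorted_idx[torch.multinomial(nucleus, 1)])
    return int(torch.multinomial(probs, 1))                  # plain random sampling

next_id = sample_next_token(torch.randn(32_000), strategy="topp", temperature=0.8)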

We can train our Llama2 model with python baby_llama/run.py model=llama2, which produces the following results:

Full Text, greedy: 
[BOS]KING RICHARD III:
O Ratcliff, I have dream'd a fearful dream!
What thinkest thou, will our friends prove all true?[EOS]

Full Text, rnd_sampling: 
[BOS]AUTOLYCUS:
TheWhen ravel spacious! therich of mine hath
 chairs not appointed me in the blended of my post,
There and garland to the wronging testy of the din
of myHourly.[EOS]

Full Text, rnd_sampling_t: 
[BOS]SICINIUS:
Come, what talk you
Of Marcius?[EOS]

Full Text, topk_sampling: 
[BOS]CORIOLANUS:
It is a purposed thing, and grows by plot,
To curb the will of the nobility:
Suffer't, and live with such as cannot rule
Nor ever will be ruled.[EOS]

Full Text, topk_sampling_t: 
[BOS]KING EDWARD IV:
So, master mayor: these gates must not be shut
But in the night or in the time of war.
What! fear not, man, but yield me up the keys;
ForSpit must be fear'd and thee,
And all those friends that deign to follow me.[EOS]

Full Text, topp_sampling: 
[BOS]ROMEO:
Or I shall.[EOS]

Full Text, topp_sampling_t: 
[BOS]BUCKINGHAM:
And, in good time, here comes the noble duke.[EOS]

Similar to the previous Llama model, we only train for 10 epochs, using a small network (8 layers, hidden dimension 1024, context length 256) and a training batch size of 8. Additionally, here we use 4 groups within GQA. You can check the wandb run to see all the configurations and the generation examples produced during training for this experiment.

For this small example, we can't see a notable improvement in memory usage. Still, the model has slightly fewer parameters (233M for Llama vs. 224M for Llama2), and the figure below shows that Llama2 (dandy-lake-72, purple) uses (slightly) less memory than Llama (hopeful-surf-62, brown).
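
As a rough sanity check of that gap, the savings come entirely from the smaller ff_k and ff_v projections. Assuming 8 attention heads (so a head dimension of 128; the head count is not restated above, so treat this as an assumption):

embedding_size, n_heads, n_groups, n_layers = 1024, 8, 4, 8   # n_heads is assumed here
head_dim = embedding_size // n_heads                          # 128
kv_out_mha = n_heads * head_dim                               # 1024
kv_out_gqa = n_groups * head_dim                              # 512

# Keys and values both shrink in every layer:
saved = 2 * embedding_size * (kv_out_mha - kv_out_gqa) * n_layers
print(f"{saved / 1e6:.1f}M parameters saved")                 # ~8.4M

This is in the same ballpark as the ~9M difference between the two (rounded) parameter counts above.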

It should be noted that these savings grow with model size, and that GQA's main advantage is speeding up inference by shrinking the key/value projections and cache.