A Gentle Introduction to Multi-Head Latent Attention (MLA)


Not all Transformer models are “large language models”: you can build a very small model with the Transformer architecture. The truly large Transformer models, however, are often impractical to use at home because they are too large to fit on a single computer and too slow to run without a cluster of GPUs.

Multi-Head Latent Attention (MLA), first proposed in DeepSeek-V2, is a new approach to running the attention operation with a lower memory footprint: it changes how the matrix multiplications in attention are carried out. In this post, you will learn how MLA works and how to implement it in PyTorch.

Let’s get started.


Overview

This post is divided into three parts; they are:

  • Low-Rank Approximation of Matrices
  • Multi-head Latent Attention (MLA)
  • PyTorch Implementation

Low-Rank Approximation of Matrices

Multi-Head Attention (MHA) and Grouped-Query Attention (GQA) are the attention mechanisms used in almost all transformer models. Recently, a new attention mechanism called Multi-head Latent Attention (MLA) was proposed in DeepSeek-V2 to further reduce computational cost and speed up inference.

The core idea is to use low-rank approximation to convert a large matrix into two smaller matrices, $M\approx UV$. If matrix $M$ is an $n\times m$ matrix, $U$ would be an $n\times r$ matrix and $V$ would be an $r\times m$ matrix, where $r$ is smaller than both $n$ and $m$. The product $UV$ won’t be identical to $M$, but it can be close enough for practical purposes. One way to decompose $M$ into such factors is to use singular value decomposition (SVD) and keep only the components corresponding to the $r$ largest singular values. Specifically, the SVD of matrix $M$ produces:

$$
M = U \Sigma V^T
$$

where $U$ and $V$ are square orthonormal matrices and $\Sigma$ is a diagonal matrix containing the singular values of $M$. If you zero out the smallest singular values on the diagonal of $\Sigma$, the corresponding columns of $U$ and rows of $V^T$ no longer contribute, leaving the product of an $n\times r$ matrix, an $r\times r$ diagonal matrix, and an $r\times m$ matrix. The result of this multiplication is an approximation of $M$. If the singular values zeroed out from $\Sigma$ are numerically close to zero, this approximation will be quite accurate.
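To make this concrete, here is a small, hypothetical example in PyTorch that builds a nearly rank-8 matrix, truncates its SVD, and measures how close the approximation is (the matrix sizes and the rank are arbitrary choices for illustration):

```python
import torch

torch.manual_seed(0)

# A 64x48 matrix that is approximately rank 8: a rank-8 product plus small noise
M = torch.randn(64, 8) @ torch.randn(8, 48) + 0.01 * torch.randn(64, 48)

# Thin SVD: M = U diag(S) Vh
U, S, Vh = torch.linalg.svd(M, full_matrices=False)

# Keep only the top r singular values and the matching columns/rows
r = 8
M_approx = U[:, :r] @ torch.diag(S[:r]) @ Vh[:r, :]

print(S[:10])                                                  # values drop sharply after the 8th
print(torch.linalg.norm(M - M_approx) / torch.linalg.norm(M))  # small relative error
```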

This concept isn’t new. Low-rank adaptation is a common technique for fine-tuning large transformer models, and it also uses such approximations of projection matrices to augment the model for new functionality.

Multi-head Latent Attention (MLA)

Similar to GQA, which manipulates only the key and value projections, Multi-head Latent Attention (MLA) focuses its savings on the key and value projections by factorizing them into low-rank components (the query projection is factorized in the same way). However, unlike GQA, MLA doesn’t share keys and values across groups of query heads; every head keeps its own keys and values, just as in standard multi-head attention. The original paper describes MLA as operating on a compressed latent representation of the key/value space during inference.

For input sequence $X$, self-attention using MLA computes:

$$
\begin{aligned}
Q &= XW_Q^DW_Q^U = (XW_Q^D)W_Q^U = C_QW_Q^U \\
K &= XW_{KV}^DW_K^U = (XW_{KV}^D)W_K^U = C_{KV}W_K^U \\
V &= XW_{KV}^DW_V^U = (XW_{KV}^D)W_V^U = C_{KV}W_V^U
\end{aligned}
$$

Where:

  • $W_Q^D,W_{KV}^D \in \mathbb{R}^{d\times r}$ are low-rank compression (down-projection) matrices, with a small $r$, to reduce the dimension
  • $W_Q^U,W_K^U,W_V^U \in \mathbb{R}^{r\times(n_h d_h)}$ are decompression (up-projection) matrices, to recover the dimension
  • $n_h$ is the number of attention heads and $d_h$ is the dimension of each head
  • $r$ is the latent dimension, typically $r \ll n_h\cdot d_h$
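To get a feel for what this factorization saves, the sketch below compares the parameter count of a full key projection against a shared compression followed by a decompression. The dimensions are made-up values for illustration, not DeepSeek-V2’s actual configuration, and the saving grows further because $W_{KV}^D$ is shared with the value projection:

```python
import torch.nn as nn

d_model, num_heads, d_head, r = 4096, 32, 128, 512     # hypothetical sizes

# Standard multi-head attention: one full projection for K
full_k = nn.Linear(d_model, num_heads * d_head, bias=False)

# MLA-style factorization: shared compression, then decompression for K
compress_kv = nn.Linear(d_model, r, bias=False)                # W_KV^D
decompress_k = nn.Linear(r, num_heads * d_head, bias=False)    # W_K^U

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(full_k))                               # 16,777,216 parameters
print(count(compress_kv) + count(decompress_k))    # 4,194,304 parameters
```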

You might notice that $K$, for example, is computed as a projection from $X$, but through two matrix multiplications instead of one. This might seem like a waste of computation, but you’ll see why this is actually efficient in the following explanation.

Now consider the standard attention operation:

$$
\begin{aligned}
O_i &= \text{softmax}\big(\frac{Q_iK_i^\top}{\sqrt{d_k}}\big)V_i \\
&= \text{softmax}\big(\frac{(XW_Q^D W_{Q,i}^U)(XW_{KV}^D W_{K,i}^U)^\top}{\sqrt{d_k}}\big)XW_{KV}^D W_{V,i}^U \\
&= \text{softmax}\big(\frac{XW_Q^D W_{Q,i}^U {W_{K,i}^U}^\top {W_{KV}^D}^\top X^\top}{\sqrt{d_k}}\big)XW_{KV}^D W_{V,i}^U \\
&= \text{softmax}\big(\frac{C_Q W_{Q,i}^U {W_{K,i}^U}^\top C_{KV}^\top}{\sqrt{d_k}}\big)C_{KV} W_{V,i}^U
\end{aligned}
$$

This is where MLA’s computational savings come from: instead of factoring the key and value projection matrices independently, both factorizations share the same compression matrix $W_{KV}^D$. Recall that the keys and values are always computed from the same input sequence (even in cross-attention), so a single shared latent $C_{KV}$ serves both the $K$ and $V$ projections.

Another key technique is that the multiple heads of attention are implemented only in the decompression matrices $W_Q^U, W_K^U, W_V^U$. Hence, for a single head, the equations above use the notations $W_{Q,i}^U, W_{K,i}^U, W_{V,i}^U$. In this way, both $C_Q$ and $C_{KV}$ are computed once and shared across all heads.

Furthermore, note the matrix multiplication $W_{Q,i}^U{W_{K,i}^U}^\top$ inside the softmax in the last line above. This is a product of two decompression matrices and is independent of the input $X$. Therefore, it can be pre-computed and cached as $W_{QK,i} = W_{Q,i}^U{W_{K,i}^U}^\top$, saving time during inference.
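The sketch below checks this equivalence for a single head with random weights: the attention scores computed through the expanded $Q_i$ and $K_i$ match the scores computed from the latents using the pre-merged matrix. All dimensions are made up for illustration:

```python
import torch

torch.set_default_dtype(torch.float64)           # avoid visible rounding differences
d_model, r, d_head, seq_len = 256, 64, 32, 10    # hypothetical sizes

X = torch.randn(seq_len, d_model)
W_q_d = torch.randn(d_model, r)                  # W_Q^D
W_kv_d = torch.randn(d_model, r)                 # W_KV^D
W_q_u = torch.randn(r, d_head)                   # W_{Q,i}^U for one head
W_k_u = torch.randn(r, d_head)                   # W_{K,i}^U for one head

C_q, C_kv = X @ W_q_d, X @ W_kv_d                # shared latents

# Path 1: expand to the full Q_i and K_i, then take dot products
scores_1 = (C_q @ W_q_u) @ (C_kv @ W_k_u).T

# Path 2: absorb the two up-projections into one pre-computed matrix
W_qk = W_q_u @ W_k_u.T                           # W_{QK,i}, independent of the input
scores_2 = (C_q @ W_qk) @ C_kv.T

print(torch.allclose(scores_1, scores_2))        # True
```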

By breaking down the projection matrices and using a lower dimension for the latent representation, MLA saves computation and memory usage even though there are more matrices involved in the equation.

PyTorch Implementation

Once you understand MLA’s design, implementing it in PyTorch is straightforward.
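Below is a minimal sketch of such a module, written to match the walkthrough that follows. The class name, the attribute names (such as W_qk and Wv_u), and the latent dimensions are illustrative choices rather than the exact DeepSeek-V2 code, and details like rotary position embeddings and KV caching are omitted:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadLatentAttention(nn.Module):
    """Minimal sketch of multi-head latent attention (MLA)."""

    def __init__(self, d_model, num_heads, q_latent_dim, kv_latent_dim):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.kv_latent_dim = kv_latent_dim

        # Compression (down-projection): X -> C_q and X -> C_kv
        self.Wq_d = nn.Linear(d_model, q_latent_dim, bias=False)
        self.Wkv_d = nn.Linear(d_model, kv_latent_dim, bias=False)

        # Pre-merged per-head matrices W_{QK,i} = W_{Q,i}^U (W_{K,i}^U)^T,
        # mapping the query latent into the key-value latent space
        self.W_qk = nn.Linear(q_latent_dim, num_heads * kv_latent_dim, bias=False)

        # Per-head value decompression W_{V,i}^U and the output projection
        self.Wv_u = nn.Linear(kv_latent_dim, num_heads * self.d_head, bias=False)
        self.Wo = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        # x: (batch_size, seq_len, d_model)
        batch_size, seq_len, _ = x.shape

        # Latent representations, shared by all attention heads
        C_q = self.Wq_d(x)                       # (B, S, q_latent_dim)
        C_kv = self.Wkv_d(x)                     # (B, S, kv_latent_dim)

        # Attention scores per head: (C_q W_qk) C_kv^T, scaled as in the equations
        q = self.W_qk(C_q).view(batch_size, seq_len, self.num_heads, self.kv_latent_dim)
        q = q.transpose(1, 2)                    # (B, H, S, kv_latent_dim)
        scores = q @ C_kv.unsqueeze(1).transpose(-2, -1)    # (B, H, S, S)
        weights = F.softmax(scores / self.d_head ** 0.5, dim=-1)

        # Values decompressed from the shared latent, one slice per head
        v = self.Wv_u(C_kv).view(batch_size, seq_len, self.num_heads, self.d_head)
        v = v.transpose(1, 2)                    # (B, H, S, d_head)

        out = weights @ v                        # (B, H, S, d_head)
        out = out.transpose(1, 2).reshape(batch_size, seq_len, -1)
        return self.Wo(out)                      # (B, S, d_model)
```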

Comparing this code with the equations from the previous section, you can see that $W_{QK,i}$ is defined directly as a component in this module.

The input sequence x to the forward() method has a shape of (batch_size, seq_len, d_model), as does the final output. First, the input x is projected into C_q and C_kv, which are shared by all attention heads. Next, the attention score is computed for each head using two matrix multiplications: you use self.W_qk to multiply C_q, reshape the result into (batch_size, seq_len, num_heads, kv_latent_dim), and then multiply it with C_kv, after appropriate axis transpositions, to get the attention score. Since the product of C_q and self.W_qk is a 4-dimensional tensor and C_kv is a 3-dimensional one, you add a dummy dimension to C_kv in place of the num_heads dimension.

Next, you obtain the attention weight by applying softmax to the attention score. To get the attention output, you multiply the attention weight with V, which is computed from C_kv projected using self.Wv_u. Finally, you concatenate the outputs of all heads and apply the output projection to get the final output.
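A quick smoke test of the module (with made-up sizes) confirms that the output shape matches the input shape:

```python
import torch

mla = MultiHeadLatentAttention(d_model=512, num_heads=8,
                               q_latent_dim=128, kv_latent_dim=64)
x = torch.randn(2, 16, 512)      # (batch_size, seq_len, d_model)
out = mla(x)
print(out.shape)                 # torch.Size([2, 16, 512])
```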

The original MLA paper suggests that it outperforms GQA in both model quality and inference speed. Since the matrices involved are smaller, it is also more memory efficient. Notably, you don’t need to train a model from scratch to use MLA: you can convert a model trained with traditional multi-head attention to MLA by factoring its projection matrices after training.
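As a rough sketch of how such a conversion could work, you can stack the trained key and value projection matrices, factor them jointly with a truncated SVD so they share one compression matrix, and use the factors as MLA weights. The weights below are random placeholders standing in for trained parameters, and how accurate the truncation is depends on how close the trained matrices are to being low-rank:

```python
import torch

d_model, r = 512, 128            # hypothetical model width and latent dimension

# Placeholders standing in for trained MHA key/value projection weights
W_k = torch.randn(d_model, d_model)
W_v = torch.randn(d_model, d_model)

# Stack the K and V projections so they can share one compression matrix
W_kv = torch.cat([W_k, W_v], dim=1)               # (d_model, 2 * d_model)
U, S, Vh = torch.linalg.svd(W_kv, full_matrices=False)

# Truncated factors: compression W_KV^D and decompressions W_K^U, W_V^U
W_kv_d = U[:, :r] * S[:r]                         # (d_model, r)
W_k_u, W_v_u = Vh[:r, :].split(d_model, dim=1)    # each (r, d_model)

# The factored path approximates the original projection
X = torch.randn(4, d_model)
K_full, K_low = X @ W_k, (X @ W_kv_d) @ W_k_u
print(torch.linalg.norm(K_full - K_low) / torch.linalg.norm(K_full))   # relative error
```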

Further Readings

Below are some resources you may find useful:

  • DeepSeek-AI, “DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model,” arXiv:2405.04434, 2024. https://arxiv.org/abs/2405.04434

Summary

In this post, you learned how MLA works and how to implement it in PyTorch. MLA is a new attention mechanism proposed in DeepSeek-V2 that uses low-rank approximation of projection matrices in multi-head attention. This approach can significantly reduce computational cost and memory usage while maintaining model performance.

