
Algorithm 6 details the Geometric Self-Attention layer. It can be implemented efficiently using ideas similar to FlashAttention (33). It is used twice in our system: in the VQ-VAE encoder for structure tokens (Appendix A.1.7.1) and in the first layer of ESM3.

Unlike regular self-attention, which operates only on per-residue embeddings, Geometric Attention incorporates the per-residue frames $T$ to integrate geometric information in a rotation and translation invariant way. The process of forming the attention matrix $A$ is as follows:

  1. QKV Projections: Two sets of keys and queries $(Q_r, K_r)$ and $(Q_d, K_d)$, along with values $V$, all with shapes $\in \mathbb{R}^{L \times h \times 3}$, are linearly projected from the layer input $X$. $L$ is the sequence length and $h$ is the number of heads.

  2. Convert QKV to global frame: Each of the queries, keys, and values is initially assumed to be in the local frame of its corresponding residue.

(a) Convert to Global Rotational Frame: We convert each of the vectors in $Q_r, K_r, V$ from their local frame (where the $xy$ plane is the $N-C_{\alpha}-C$ plane for each residue) to a global rotational frame (where the $xy$ plane is aligned for all residues) by applying $\mathbf{R}_i$ (Algorithm 6, lines 3, 4).

(b) Convert to Global Distance Frame: We convert each of the vectors in $Q_d, K_d$ from their local frame to a global frame by applying $T_i$ (Algorithm 6, lines 5, 6).

  3. Directional Attention: The pairwise, per-head $h$ rotational similarity $R$ between queries $i$ and keys $j$ is calculated using the dot product $[R]_{i,j,h} = \frac{1}{\sqrt{3}} [q_r]_{i,h,:} \cdot [k_r]_{j,h,:}$. This is equivalent to the cosine distance between projected points.

  4. Distance Attention: The pairwise, per-head $h$ distance similarity $D$ between queries $i$ and keys $j$ is computed using the $L_2$ norm of the difference $[D]_{i,j,h} = \frac{1}{\sqrt{3}} \left\|[q_d]_{i,h,:} - [k_d]_{j,h,:}\right\|_2$.

  5. Scale Factor: $R$ and $D$ are scaled per-head with learned scalars $[\bar{w}_r]_h$ and $[\bar{w}_d]_h$, respectively, where $\bar{w}_r, \bar{w}_d \in \mathbb{R}^{h}$. We use the softplus function to transform the weights into $[0, \infty)^{h}$. This scaling allows certain heads to specialize in attending via distance or via directional attention (see the sketch after this list).
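
As a concrete reading of steps 3-5, here is a minimal PyTorch sketch (function and variable names are ours, not from the ESM3 codebase), assuming the queries and keys have already been projected and converted to the global frame as in steps 1-2:

```python
# Minimal sketch of steps 3-5 (directional attention, distance attention,
# scale factor). Names are illustrative, not from the ESM3 source. Assumes
# q_r, k_r, q_d, k_d are already projected and in the global frame, with
# shape (L, h, 3); w_r, w_d are learned per-head scalars of shape (h,).
import torch
import torch.nn.functional as F

def attention_logits(q_r, k_r, q_d, k_d, w_r, w_d):
    # Directional similarity: per-head dot product between query i and key j.
    R = torch.einsum("ihc,jhc->ijh", q_r, k_r) / (3 ** 0.5)        # (L, L, h)
    # Distance similarity: per-head L2 distance between query i and key j.
    D = (q_d[:, None] - k_d[None, :]).norm(dim=-1) / (3 ** 0.5)    # (L, L, h)
    # Softplus keeps the per-head weights non-negative; the distance term is
    # subtracted so that residues that are close in space get larger logits.
    return F.softplus(w_r) * R - F.softplus(w_d) * D               # (L, L, h)
```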

Algorithm 6 geometric_mha
Input: $X \in \mathbb{R}^{L \times d}, T \in SE(3)^{L}$
    1: $Q_{r}, K_{r}, Q_{d}, K_{d}, V = \operatorname{Linear}(X) \quad \triangleright (\mathbb{R}^{L \times h \times 3})_{\times 5}$
    2: $(\mathbf{R}_{i}, \mathbf{t}_{i}) = T_{i} \quad \triangleright (SO(3)^{L}, \mathbb{R}^{L \times 3})$
    3: $[Q_{r}]_{i,h,:} = \mathbf{R}_{i}([Q_{r}]_{i,h,:}) \quad \triangleright \mathbb{R}^{L \times h \times 3}$
    4: $[K_{r}]_{i,h,:} = \mathbf{R}_{i}([K_{r}]_{i,h,:}) \quad \triangleright \mathbb{R}^{L \times h \times 3}$
    5: $[Q_{d}]_{i,h,:} = T_{i}([Q_{d}]_{i,h,:}) \quad \triangleright \mathbb{R}^{L \times h \times 3}$
    6: $[K_{d}]_{i,h,:} = T_{i}([K_{d}]_{i,h,:}) \quad \triangleright \mathbb{R}^{L \times h \times 3}$
    7: $[R]_{i,j,h} = \frac{1}{\sqrt{3}} [q_{r}]_{i,h,:} \cdot [k_{r}]_{j,h,:} \quad \triangleright \mathbb{R}^{L \times L \times h}$
    8: $[D]_{i,j,h} = \frac{1}{\sqrt{3}} \left\|[q_{d}]_{i,h,:} - [k_{d}]_{j,h,:}\right\|_{2} \quad \triangleright \mathbb{R}^{L \times L \times h}$
    9: $A = \operatorname{softplus}(\bar{w}_{r}) R - \operatorname{softplus}(\bar{w}_{d}) D \quad \triangleright \mathbb{R}^{L \times L \times h}$
    10: $A = \operatorname{softmax}_{j}(A)$
    11: $[V]_{i,h,:} = \mathbf{R}_{i}([V]_{i,h,:})$
    12: $O = A \cdot V \quad \triangleright \mathbb{R}^{L \times h \times 3}$
    13: $[O]_{i,h,:} = \mathbf{R}_{i}^{-1}([O]_{i,h,:})$
    14: $X = X + \operatorname{Linear}(O) \quad \triangleright \mathbb{R}^{L \times d}$
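
The following is a minimal PyTorch sketch of Algorithm 6 as written above. The class name, argument names, and the choice to pass frames as rotation matrices $\mathbf{R} \in \mathbb{R}^{L \times 3 \times 3}$ plus translations $\mathbf{t} \in \mathbb{R}^{L \times 3}$ are our own assumptions; this is not the official ESM3 implementation, which the excerpt notes can be fused FlashAttention-style for efficiency.

```python
# A minimal sketch of Algorithm 6 (not the official ESM3 code). Frames are
# passed as rotation matrices R (L, 3, 3) and translations t (L, 3).
import torch
import torch.nn as nn
import torch.nn.functional as F

class GeometricMHA(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.h = n_heads
        # Line 1: five (L, h, 3) tensors projected from X in one linear map.
        self.qkv = nn.Linear(d_model, 5 * n_heads * 3)
        self.out = nn.Linear(n_heads * 3, d_model)
        # Per-head scale factors, passed through softplus (line 9).
        self.w_r = nn.Parameter(torch.zeros(n_heads))
        self.w_d = nn.Parameter(torch.zeros(n_heads))

    def forward(self, X: torch.Tensor, R: torch.Tensor, t: torch.Tensor):
        L = X.shape[0]
        q_r, k_r, q_d, k_d, v = self.qkv(X).view(L, 5, self.h, 3).unbind(dim=1)

        def rot(x):  # apply each residue's rotation R_i to its h vectors
            return torch.einsum("iab,ihb->iha", R, x)

        # Lines 3-4: rotate the directional queries/keys into the global frame.
        q_r, k_r = rot(q_r), rot(k_r)
        # Lines 5-6: apply the full transform T_i = (R_i, t_i) to the distance
        # queries/keys so that their differences become global displacements.
        q_d, k_d = rot(q_d) + t[:, None], rot(k_d) + t[:, None]

        # Lines 7-9: directional and distance logits, combined per head.
        Rsim = torch.einsum("ihc,jhc->ijh", q_r, k_r) / (3 ** 0.5)
        Dsim = (q_d[:, None] - k_d[None, :]).norm(dim=-1) / (3 ** 0.5)
        A = F.softplus(self.w_r) * Rsim - F.softplus(self.w_d) * Dsim

        # Line 10: normalize over the key index j.
        A = A.softmax(dim=1)
        # Lines 11-13: rotate values to the global frame, mix, rotate back.
        v = rot(v)
        O = torch.einsum("ijh,jhc->ihc", A, v)
        O = torch.einsum("iba,ihb->iha", R, O)  # R_i^{-1} = R_i^T for rotations
        # Line 14: residual update back in model dimension d.
        return X + self.out(O.reshape(L, -1))
```

Note that the softmax runs over the key index $j$ (line 10), and that subtracting the distance term means heads with a large $\operatorname{softplus}(\bar{w}_d)$ concentrate their attention on spatially nearby residues.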

Please explain this in detail to a beginning grad student in machine learning.

Response: Certainly! Let's dive into the Geometric Self-Attention layer described in Algorithm 6. We'll break it down step by step, ensuring that all the concepts are clear.

Background:

Self-Attention in Transformers:

In Transformer models, self-attention allows each position in the input sequence to attend to all other positions, capturing dependencies regardless of their distance in the sequence. It operates on per-token embeddings, i.e., vector representations of the inputs (words, amino-acid residues, etc.).
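
For contrast, here is a minimal sketch of ordinary scaled dot-product self-attention (a toy single-head version with our own variable names, not any particular library's API), which sees only the token embeddings and nothing about geometry:

```python
# Toy single-head scaled dot-product self-attention, for contrast with the
# geometric variant: only the embeddings X are used, no frames or coordinates.
import torch

def vanilla_self_attention(X, Wq, Wk, Wv):
    # X: (L, d); Wq, Wk, Wv: (d, d_k) projection matrices.
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    A = (Q @ K.T) / (K.shape[-1] ** 0.5)   # (L, L) similarity logits
    A = A.softmax(dim=-1)                  # normalize over keys
    return A @ V                           # (L, d_k) attended values
```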

Limitations in Geometric Contexts:

When dealing with data that has geometric properties (like protein structures), regular self-attention doesn't consider the spatial relationships between elements. Proteins are not just sequences; they have 3D structures that are crucial for their function.

Introducing Geometric Self-Attention:

Geometric Self-Attention extends regular self-attention by incorporating geometric information, specifically rotations and translations, in a way that is invariant to these transformations. This is essential for modeling structures where orientation and position matter but should not affect the outcome due to their arbitrary nature in 3D space.
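
This invariance can be checked numerically: applying a single global rotation and translation to every residue frame should leave the layer's output unchanged. The sketch below uses the hypothetical GeometricMHA module from the sketch after Algorithm 6 above.

```python
# A small numerical check of the invariance claim, using the GeometricMHA
# sketch defined earlier (hypothetical names, not the ESM3 source).
import torch

torch.manual_seed(0)
L, d, h = 8, 16, 4
layer = GeometricMHA(d, h)

X = torch.randn(L, d)
# Random per-residue rotations (QR factor with determinant forced to +1).
Q, _ = torch.linalg.qr(torch.randn(L, 3, 3))
R = Q * torch.sign(torch.linalg.det(Q))[:, None, None]
t = torch.randn(L, 3)

# One global rigid transform applied to every frame.
Qg, _ = torch.linalg.qr(torch.randn(3, 3))
Rg = Qg * torch.sign(torch.linalg.det(Qg))
tg = torch.randn(3)
R2 = Rg @ R            # R_i -> R_g R_i
t2 = t @ Rg.T + tg     # t_i -> R_g t_i + t_g

out1 = layer(X, R, t)
out2 = layer(X, R2, t2)
print(torch.allclose(out1, out2, atol=1e-5))  # expected: True
```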


Understanding the Steps:

1. QKV Projections: a single linear layer produces five tensors $Q_r, K_r, Q_d, K_d, V$, each of shape $(L, h, 3)$, from the input $X$.

2. Converting QKV to Global Frame: the per-residue local vectors are mapped into a shared global frame, using $\mathbf{R}_i$ (rotation only) for $Q_r, K_r$ and the full transform $T_i$ (rotation plus translation) for $Q_d, K_d$.

3. Directional Attention: a per-head dot-product similarity $R$ between the rotated queries and keys.

4. Distance Attention: a per-head $L_2$ distance $D$ between the transformed queries and keys.

5. Scale Factor: learned per-head weights, passed through softplus, balance the directional and distance terms before the softmax.


Understanding Algorithm 6 Line by Line:

Let's go through Algorithm 6 step by step.

Input: $X \in \mathbb{R}^{L \times d}$ (per-residue embeddings) and frames $T \in SE(3)^{L}$ (one rigid transform per residue).

Line 1: Compute Projections. A single linear layer produces $Q_r, K_r, Q_d, K_d, V$, each of shape $(L, h, 3)$.

Line 2: Extract Transformations. Each frame $T_i$ is split into its rotation $\mathbf{R}_i$ and translation $\mathbf{t}_i$.

Lines 3-4: Convert $Q_r, K_r$ to the Global Rotational Frame. Applying $\mathbf{R}_i$ aligns the directional vectors of all residues so their dot products are comparable.

Lines 5-6: Convert $Q_d, K_d$ to the Global Distance Frame. Applying the full transform $T_i$ (rotation and translation) places these vectors as points in a common 3D coordinate system.

Line 7: Compute Rotational Similarity $R$. A scaled dot product between $[q_r]_i$ and $[k_r]_j$ for every residue pair and every head.

Line 8: Compute Distance Similarity $D$. The scaled $L_2$ distance between $[q_d]_i$ and $[k_d]_j$.

Line 9: Compute Scaled Attention Scores $A$. $A = \operatorname{softplus}(\bar{w}_r) R - \operatorname{softplus}(\bar{w}_d) D$; the distance term is subtracted so that spatially close residues receive larger scores.

Line 10: Apply Softmax. The scores are normalized over the key index $j$, so each residue's attention weights sum to one.

Line 11: Rotate Values $V$. The values are brought into the global frame with $\mathbf{R}_i$.

Line 12: Compute Output $O$. The attention weights mix the value vectors, $O = A \cdot V$.

Line 13: Convert $O$ Back to the Local Frame. Applying $\mathbf{R}_i^{-1}$ returns each output vector to its residue's local frame.

Line 14: Update $X$ with the Output. A final linear layer maps $O$ back to dimension $d$, and the result is added to $X$ as a residual update.


Key Takeaways:

- Geometric attention scores each residue pair with two complementary terms: a directional (dot-product) similarity $R$ and a spatial distance $D$.
- All comparisons happen in a shared global frame, but inputs and outputs stay in per-residue local frames, so the layer is invariant to global rotations and translations of the structure.
- Learned per-head softplus weights let individual heads specialize in direction-based or distance-based attention.

Why Is This Important?

In domains like protein modeling, capturing the intricate 3D relationships between amino acids is crucial. Proteins function based on their shapes and how different parts of the protein interact in space.

By incorporating geometric information directly into the attention mechanism, the model can:

- attend to residues based on their relative orientation (directional attention);
- attend to residues based on how close they are in 3D space (distance attention);
- do both in a way that is unchanged if the entire structure is rotated or translated.

Final Notes:

This layer is used in two places in the ESM3 system: in the VQ-VAE encoder that produces structure tokens, and in the first transformer block of ESM3 itself. As the excerpt notes, it can also be implemented efficiently using ideas similar to FlashAttention.


I hope this detailed explanation clarifies how the Geometric Self-Attention layer works and how it integrates geometric information into the attention mechanism. Let me know if you have any questions or need further clarification on any part!