Algorithm 6 details the Geometric Self-Attention layer. It can be efficiently implemented using ideas similar to FlashAttention (33). It is used twice in our system: in the VQ-VAE encoder for structure tokens (Appendix A.1.7.1), and in the first layer of ESM3.
Unlike regular self-attention, which operates only on per-residue embeddings, Geometric Attention incorporates the per-residue frames $T$ to integrate geometric information in a rotation and translation invariant way. The process of forming the attention matrix $A$ is as follows:
QKV Projections: Two sets of queries and keys, $(Q_{r}, K_{r})$ and $(Q_{d}, K_{d})$, along with values $V$, all of shape $\mathbb{R}^{L \times h \times 3}$, are linearly projected from the layer input $X$. $L$ is the sequence length, $h$ is the number of heads.
Convert QKV to global frame: Each of the queries, keys, and values is initially assumed to be in the local frame of its corresponding residue.
(a) Convert to Global Rotational Frame: We convert each of the vectors in $Q_{r}, K_{r}, V$ from their local frame (where the $xy$ plane is the $N-C_{\alpha}-C$ plane for each residue) to a global rotational frame (where the $xy$ plane is aligned for all residues) by applying $\mathbf{R}_{i}$ (Algorithm 6, lines 3, 4).
(b) Convert to Global Distance Frame: We convert each of the vectors in $Q_{d}, K_{d}$ from their local frame to a global frame by applying $T_{i}$ (Algorithm 6, lines 5, 6).
Directional Attention: The pairwise, per-head $h$ rotational similarity $R$ between queries $i$ and keys $j$ is calculated using the dot product $[R]_{i, j, h}=\frac{1}{\sqrt{3}}\,[q_{r}]_{i, h,:} \cdot [k_{r}]_{j, h,:}$. This is equivalent to the cosine distance between projected points.
Distance Attention: The pairwise, per-head $h$ distance similarity $D$ between queries $i$ and keys $j$ is computed using the $L_{2}$ norm of the difference $[D]_{i, j, h}=\frac{1}{\sqrt{3}}\left\|[q_{d}]_{i, h,:}-[k_{d}]_{j, h,:}\right\|_{2}$.
Scale Factor: $R$ and $D$ are scaled per-head with learned scalars $[\bar{w}_{r}]_{h}$ and $[\bar{w}_{d}]_{h}$, respectively, where $\bar{w}_{r}, \bar{w}_{d} \in \mathbb{R}^{h}$. We use the softplus function to transform weights into $[0, \infty)^{h}$. This scaling allows certain heads to specialize in attending via distance or directional attention.
Algorithm 6 geometric_mha
Input: $X \in \mathbb{R}^{L \times d}, T \in SE(3)^{L}$
1: $Q_{r}, K_{r}, Q_{d}, K_{d}, V=\operatorname{Linear}(X) \quad \triangleright\left(\mathbb{R}^{L \times h \times 3}\right)_{\times 5}$
2: $\left(\mathbf{R}_{i}, \mathbf{t}_{i}\right)=T_{i} \quad \triangleright\left(SO(3)^{L}, \mathbb{R}^{L \times 3}\right)$
3: $\left[Q_{r}\right]_{i, h,:}=\mathbf{R}_{i}\left(\left[Q_{r}\right]_{i, h,:}\right) \quad \triangleright \mathbb{R}^{L \times h \times 3}$
4: $\left[K_{r}\right]_{i, h,:}=\mathbf{R}_{i}\left(\left[K_{r}\right]_{i, h,:}\right) \quad \triangleright \mathbb{R}^{L \times h \times 3}$
5: $\left[Q_{d}\right]_{i, h,:}=T_{i}\left(\left[Q_{d}\right]_{i, h,:}\right) \quad \triangleright \mathbb{R}^{L \times h \times 3}$
6: $\left[K_{d}\right]_{i, h,:}=T_{i}\left(\left[K_{d}\right]_{i, h,:}\right) \quad \triangleright \mathbb{R}^{L \times h \times 3}$
7: $[R]_{i, j, h}=\frac{1}{\sqrt{3}}\left[q_{r}\right]_{i, h,:} \cdot\left[k_{r}\right]_{j, h,:} \quad \triangleright \mathbb{R}^{L \times L \times h}$
8: $[D]_{i, j, h}=\frac{1}{\sqrt{3}}\left\|\left[q_{d}\right]_{i, h,:}-\left[k_{d}\right]_{j, h,:}\right\|_{2} \quad \triangleright \mathbb{R}^{L \times L \times h}$
9: $A=\operatorname{softplus}\left(\bar{w}_{r}\right) R-\operatorname{softplus}\left(\bar{w}_{d}\right) D \quad \triangleright \mathbb{R}^{L \times L \times h}$
10: $A=\operatorname{softmax}_{j}(A)$
11: $[V]_{i, h,:}=\mathbf{R}_{i}\left([V]_{i, h,:}\right)$
12: $O=A \cdot V \quad \triangleright \mathbb{R}^{L \times h \times 3}$
13: $[O]_{i, h,:}=\mathbf{R}_{i}^{-1}\left([O]_{i, h,:}\right)$
14: $X=X+\operatorname{Linear}(O) \quad \triangleright \mathbb{R}^{L \times d}$
Please explain this in detail to a beginning grad student in machine learning.
Response: Certainly! Let's dive into the Geometric Self-Attention layer described in Algorithm 6. We'll break it down step by step, ensuring that all the concepts are clear.
Self-Attention in Transformers:
In Transformer models, self-attention allows each position in the input sequence to attend to all other positions, capturing dependencies regardless of their distance in the sequence. It operates primarily on per-residue embeddings, which are vector representations of tokens (e.g., words, amino acids).
Limitations in Geometric Contexts:
When dealing with data that has geometric properties (like protein structures), regular self-attention doesn't consider the spatial relationships between elements. Proteins are not just sequences; they have 3D structures that are crucial for their function.
Geometric Self-Attention extends regular self-attention by incorporating geometric information, specifically rotations and translations, in a way that is invariant to these transformations. This is essential for modeling structures where orientation and position matter but should not affect the outcome due to their arbitrary nature in 3D space.
From the input embeddings $X \in \mathbb{R}^{L \times d}$, we compute:
Two sets of queries and keys: $(Q_r, K_r)$ and $(Q_d, K_d)$, each of shape $\mathbb{R}^{L \times h \times 3}$.
Values: $V$, also of shape $\mathbb{R}^{L \times h \times 3}$.
Explanation:
$L$: Sequence length (number of residues).
$h$: Number of attention heads.
The projections map the input embeddings into lower-dimensional spaces suitable for capturing geometric relationships.
Local Frames: Initially, the projected vectors ($Q_r$, $K_r$, $Q_d$, $K_d$, $V$) are in the local frame of each residue. This means their coordinate axes are defined relative to each residue's orientation.
Global Frames: We need to bring these vectors into a common frame of reference to compute meaningful relationships between different residues.
(a) Rotational Frame Conversion: apply each residue's rotation $\mathbf{R}_i$ to $Q_r$ and $K_r$ (and later to $V$), so that all residues share a common orientation.
(b) Distance Frame Conversion: apply the full transformation $T_i$ (rotation and translation) to $Q_d$ and $K_d$, placing them at their residues' positions in a shared global frame.
Why Two Conversions?
Directional Information ($Q_r$, $K_r$, $V$): Focused on the orientation of residues.
Positional Information ($Q_d$, $K_d$): Focused on the actual positions for distance computations.
Compute the rotational similarity $R$ between queries and keys:
$$[R]_{i, j, h} = \frac{1}{\sqrt{3}} \left( [Q_r]_{i, h, :} \cdot [K_r]_{j, h, :} \right)$$
Explanation:
The dot product measures the cosine similarity between the directional vectors.
Dividing by $\sqrt{3}$ normalizes the similarity scores.
Compute the distance similarity $D$ between queries and keys:
$$[D]_{i, j, h} = \frac{1}{\sqrt{3}} \left\| [Q_d]_{i, h, :} - [K_d]_{j, h, :} \right\|_2$$
Explanation:
The $L_2$ norm computes the Euclidean distance between the positional vectors.
Again, dividing by $\sqrt{3}$ normalizes the distances.
Apply learned scalar weights to $R$ and $D$:
$$A = \operatorname{softplus}(\bar{w}_r)\, R - \operatorname{softplus}(\bar{w}_d)\, D$$
Explanation:
$\bar{w}_r, \bar{w}_d \in \mathbb{R}^h$ are parameters learned during training, one for each head.
The softplus function ensures the weights are positive: $\operatorname{softplus}(x) = \ln(1 + e^x)$.
This scaling allows each attention head to specialize: some heads can weight directional agreement more heavily, while others emphasize spatial proximity.
Combining $R$ and $D$:
By weighting and subtracting $D$ from $R$, we create a combined attention score that considers both how residues are oriented relative to each other and how far apart they are.
Let's go through Algorithm 6 step by step.
Compute projections for queries, keys, and values:
$$Q_r, K_r, Q_d, K_d, V = \operatorname{Linear}(X)$$
Explanation:
A linear layer projects $X$ into multiple components.
Each component has shape $\mathbb{R}^{L \times h \times 3}$.
This operation is similar to how standard attention computes Q, K, V but adapted for geometric data.
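To make this concrete, here is a minimal PyTorch sketch of the projection step. The layer name `proj` and the sizes `L`, `d`, `h` are illustrative choices, not the values used in ESM3:

```python
import torch
import torch.nn as nn

L, d, h = 128, 512, 16                      # illustrative sizes, not ESM3's actual values

X = torch.randn(L, d)                       # per-residue embeddings
proj = nn.Linear(d, 5 * h * 3, bias=False)  # one linear layer producing all five tensors

# Split into Q_r, K_r, Q_d, K_d, V, each of shape (L, h, 3).
Q_r, K_r, Q_d, K_d, V = proj(X).view(L, 5, h, 3).unbind(dim=1)
print(Q_r.shape)  # torch.Size([128, 16, 3])
```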
For each residue $i$:
$$(\mathbf{R}_i, \mathbf{t}_i) = T_i$$
Explanation:
$\mathbf{R}_i \in SO(3)$: Rotation matrix.
$\mathbf{t}_i \in \mathbb{R}^{3}$: Translation vector.
$SO(3)$ is the group of 3D rotations.
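As a small sketch, suppose each frame $T_i$ is stored as a $4 \times 4$ homogeneous transform (this storage format is an assumption for illustration; the actual data structure may differ). Splitting it into $(\mathbf{R}_i, \mathbf{t}_i)$ is then just slicing:

```python
import torch

L = 128
# Hypothetical storage: per-residue frames as 4x4 homogeneous matrices, T ∈ SE(3)^L.
# Identity matrices serve as stand-ins for real backbone frames here.
T = torch.eye(4).expand(L, 4, 4).clone()

R = T[:, :3, :3]  # rotation matrices R_i, shape (L, 3, 3)
t = T[:, :3, 3]   # translation vectors t_i, shape (L, 3)
```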
Apply rotation $\mathbf{R}_i$ to each vector:
$$[Q_r]_{i, h, :} = \mathbf{R}_i \left( [Q_r]_{i, h, :} \right)$$
$$[K_r]_{i, h, :} = \mathbf{R}_i \left( [K_r]_{i, h, :} \right)$$
Explanation:
Rotates the local vectors into the global rotational frame.
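In code this is a batched matrix-vector product over residues; the sketch below uses identity rotations and random vectors purely as stand-ins:

```python
import torch

L, h = 128, 16
R = torch.eye(3).expand(L, 3, 3)   # per-residue rotations (identity as a stand-in)
Q_r = torch.randn(L, h, 3)
K_r = torch.randn(L, h, 3)

# Apply residue i's rotation to all of its heads: [Q_r]_{i,h,:} = R_i [Q_r]_{i,h,:}
Q_r = torch.einsum("ixy,ihy->ihx", R, Q_r)
K_r = torch.einsum("ixy,ihy->ihx", R, K_r)
```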
Apply full transformation $T_i$ (rotation and translation):
$$[Q_d]_{i, h, :} = T_i \left( [Q_d]_{i, h, :} \right)$$
$$[K_d]_{i, h, :} = T_i \left( [K_d]_{i, h, :} \right)$$
Explanation:
Now, $Q_d$ and $K_d$ are in a global frame where both orientation and position are aligned.
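The corresponding sketch for the distance branch adds the translation after the rotation, i.e. $T_i(x) = \mathbf{R}_i x + \mathbf{t}_i$ (shapes and tensors are again illustrative):

```python
import torch

L, h = 128, 16
R = torch.eye(3).expand(L, 3, 3)   # per-residue rotations
t = torch.randn(L, 3)              # per-residue translations
Q_d = torch.randn(L, h, 3)
K_d = torch.randn(L, h, 3)

# T_i(x) = R_i x + t_i: rotate into the global frame, then shift to the residue's position.
Q_d = torch.einsum("ixy,ihy->ihx", R, Q_d) + t[:, None, :]
K_d = torch.einsum("ixy,ihy->ihx", R, K_d) + t[:, None, :]
```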
Compute $R$:
$$[R]_{i, j, h} = \frac{1}{\sqrt{3}}\, [Q_r]_{i, h, :} \cdot [K_r]_{j, h, :}$$
Explanation:
Dot product between the rotated queries and keys.
Measures how similarly oriented the residues are.
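In practice this is a single `einsum`; the sketch below uses random tensors for illustration:

```python
import torch

L, h = 128, 16
Q_r = torch.randn(L, h, 3)  # already in the global rotational frame
K_r = torch.randn(L, h, 3)

# [R]_{i,j,h} = (1/sqrt(3)) * <q_r[i,h], k_r[j,h]>
R_sim = torch.einsum("ihc,jhc->ijh", Q_r, K_r) / 3 ** 0.5
print(R_sim.shape)  # torch.Size([128, 128, 16])
```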
Compute $D$:
$$[D]_{i, j, h} = \frac{1}{\sqrt{3}} \left\| [Q_d]_{i, h, :} - [K_d]_{j, h, :} \right\|_2$$
Explanation:
Euclidean distance between the transformed queries and keys.
Measures how close the residues are in space.
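A broadcasting-based sketch of the pairwise distances (a memory-efficient implementation in the spirit of FlashAttention would avoid materializing the full $L \times L \times h$ tensor, but this naive version shows the math):

```python
import torch

L, h = 128, 16
Q_d = torch.randn(L, h, 3)  # already rotated and translated into the global frame
K_d = torch.randn(L, h, 3)

# [D]_{i,j,h} = (1/sqrt(3)) * ||q_d[i,h] - k_d[j,h]||_2
D_sim = (Q_d[:, None] - K_d[None, :]).norm(dim=-1) / 3 ** 0.5
print(D_sim.shape)  # torch.Size([128, 128, 16])
```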
Apply learned weights and combine $R$ and $D$:
$$A = \operatorname{softplus}(\bar{w}_r)\, R - \operatorname{softplus}(\bar{w}_d)\, D$$
Explanation:
Learned weights $\bar{w}_r, \bar{w}_d$ modulate the contribution of $R$ and $D$.
$A \in \mathbb{R}^{L \times L \times h}$: Attention scores before softmax.
Normalize attention scores:
$$A = \operatorname{softmax}_j(A)$$
Explanation:
The softmax is computed over the index $j$, ensuring that for each query $i$ and head $h$, the attention scores over all keys $j$ sum to 1.
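A sketch of these two steps together, with `w_r` and `w_d` standing in for the learned per-head parameters:

```python
import torch
import torch.nn.functional as F

L, h = 128, 16
R_sim = torch.randn(L, L, h)              # directional similarities from the previous step
D_sim = torch.rand(L, L, h) * 10.0        # pairwise distances (non-negative)
w_r = torch.zeros(h, requires_grad=True)  # learned per-head weights (stand-ins)
w_d = torch.zeros(h, requires_grad=True)

# Positive per-head scales: reward directional agreement, penalize distance.
A = F.softplus(w_r) * R_sim - F.softplus(w_d) * D_sim  # (L, L, h)

# Normalize over the key index j (dim=1), separately for each query i and head h.
A = A.softmax(dim=1)
```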
Apply rotation to $V$:
$$[V]_{i, h, :} = \mathbf{R}_i \left( [V]_{i, h, :} \right)$$
Explanation:
Brings $V$ into the global rotational frame, just like $Q_r$ and $K_r$.
Perform attention-weighted sum:
$$O = A \cdot V$$
Explanation:
$O \in \mathbb{R}^{L \times h \times 3}$.
Combines the values $V$ weighted by the attention scores $A$.
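A sketch of rotating the values and contracting over the key index $j$ (tensors are illustrative stand-ins):

```python
import torch

L, h = 128, 16
R = torch.eye(3).expand(L, 3, 3)                # per-residue rotations
V = torch.randn(L, h, 3)
A = torch.softmax(torch.randn(L, L, h), dim=1)  # attention weights over j

# Bring V into the global rotational frame, then take the attention-weighted sum over j.
V = torch.einsum("ixy,ihy->ihx", R, V)
O = torch.einsum("ijh,jhc->ihc", A, V)          # (L, h, 3)
```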
Apply inverse rotation to $O$:
$$[O]_{i, h, :} = \mathbf{R}_i^{-1} \left( [O]_{i, h, :} \right)$$
Explanation:
Transforms $O$ back into each residue's local frame.
This is necessary because the subsequent layers or steps may expect data in the local frame.
Integrate $O$ back into the model:
$$X = X + \operatorname{Linear}(O)$$
Explanation:
Applies a linear transformation to $O$ to match the dimensionality of $X$.
The result is added back to $X$ as a residual connection, which helps with training deep networks by mitigating issues like vanishing gradients.
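A sketch of the final two steps; `out_proj` is a hypothetical name for the output projection, and using $\mathbf{R}_i^\top$ for $\mathbf{R}_i^{-1}$ relies on each $\mathbf{R}_i$ being a rotation matrix:

```python
import torch
import torch.nn as nn

L, d, h = 128, 512, 16
R = torch.eye(3).expand(L, 3, 3)            # per-residue rotations
X = torch.randn(L, d)                       # layer input (for the residual connection)
O = torch.randn(L, h, 3)                    # attention output from the previous step
out_proj = nn.Linear(h * 3, d, bias=False)  # maps flattened heads back to model dim

# Undo each residue's rotation (R_i^{-1} = R_i^T), then project and add the residual.
O = torch.einsum("iyx,ihy->ihx", R, O)      # applies R_i^T to each head's vector
X = X + out_proj(O.reshape(L, h * 3))
```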
Geometric Self-Attention allows the model to consider both directional and distance-related geometric relationships between residues.
By transforming local vectors into a global frame, we can compute meaningful similarities and distances.
Learned weights enable the model to balance between focusing on orientation or distance, allowing different heads to specialize.
Residual connections and transformations back to the local frame ensure that the geometric information is seamlessly integrated into the model without disrupting the data's inherent structure.
In domains like protein modeling, capturing the intricate 3D relationships between amino acids is crucial. Proteins function based on their shapes and how different parts of the protein interact in space.
By incorporating geometric information directly into the attention mechanism, the model can reason about both the relative orientation and the spatial proximity of residues, while remaining invariant to the arbitrary global rotation and translation of the structure.
Softplus Function: Smooth version of the ReLU function that maps real numbers to positive real numbers, ensuring that weights like $\bar{w}_r$ and $\bar{w}_d$ are positive.
Normalization Factors ($\sqrt{3}$): Used to scale the dot products and distances to have consistent magnitudes, which helps with training stability.
Attention Heads: Having multiple heads allows the model to capture different types of relationships simultaneously. Some heads might focus on local interactions, while others capture global structure.
I hope this detailed explanation clarifies how the Geometric Self-Attention layer works and how it integrates geometric information into the attention mechanism. Let me know if you have any questions or need further clarification on any part!