
\section*{A.1.7.1.1. Codebook Learning}

The quantize operation transforms the $L$ latents into $L$ discrete tokens. Since the VQ-VAE was initially proposed (67), numerous approaches and tricks have been developed to address issues with poor codebook utilization and unstable training. We chose to learn the codebook as an exponential moving average of encoder outputs (67-69). To improve codebook utilization, unused codes are re-initialized to encoder outputs.
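As a concrete illustration, below is a minimal PyTorch sketch of an EMA codebook update with dead-code re-initialization. The function name, tensor layout, decay rate, and re-initialization threshold are illustrative assumptions about the general technique, not the actual ESM3 implementation.

import torch

def ema_codebook_update(codebook, ema_counts, ema_sums, encoder_out, codes,
                        decay=0.99, eps=1e-5, reinit_threshold=1.0):
    """One EMA update of a VQ codebook (illustrative sketch, not the ESM3 code).

    codebook:    (K, d) current code vectors
    ema_counts:  (K,)   EMA of per-code usage counts
    ema_sums:    (K, d) EMA of summed encoder outputs assigned to each code
    encoder_out: (N, d) encoder outputs in this batch
    codes:       (N,)   index of the nearest code for each encoder output
    """
    K, d = codebook.shape
    one_hot = torch.nn.functional.one_hot(codes, K).to(encoder_out.dtype)  # (N, K)

    # EMA of usage counts and of the sum of assigned encoder outputs
    ema_counts.mul_(decay).add_(one_hot.sum(0), alpha=1 - decay)
    ema_sums.mul_(decay).add_(one_hot.t() @ encoder_out, alpha=1 - decay)

    # Each code becomes the smoothed mean of the encoder outputs assigned to it
    codebook.copy_(ema_sums / (ema_counts.unsqueeze(1) + eps))

    # Re-initialize rarely used codes to random encoder outputs from the batch
    dead = ema_counts < reinit_threshold
    if dead.any():
        idx = torch.randint(0, encoder_out.shape[0], (int(dead.sum()),))
        codebook[dead] = encoder_out[idx]
        ema_counts[dead] = 1.0
        ema_sums[dead] = encoder_out[idx]
    return codebook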

\section*{A.1.7.1.2. Parallel Encoding}

To improve training and inference efficiency, we encode all local structure graphs within a protein in parallel. In practice, this means that given a batch of $B$ proteins with average sequence length $L$, the inputs to the structure encoder have shape $B L \times 16 \times d$.
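A minimal sketch of the bookkeeping this implies, assuming hypothetical tensor names and a fixed 16-residue neighborhood per local structure graph:

import torch

B, L, d = 4, 128, 256          # proteins per batch, residues per protein, feature dim
neighborhood = 16              # residues per local structure graph (from the text above)

# Hypothetical per-residue local-neighborhood features: one 16-residue graph per residue
local_feats = torch.randn(B, L, neighborhood, d)

# Fold proteins and residues together so every local graph is encoded independently
enc_in = local_feats.reshape(B * L, neighborhood, d)    # (B*L, 16, d), as in the text

# ... run the structure encoder on enc_in, quantize, then unfold back ...
tokens_flat = torch.randint(0, 4096, (B * L,))          # stand-in for quantized codes
tokens = tokens_flat.reshape(B, L)                      # one structure token per residue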

\section*{A.1.7.2. DECODER}

While the encoder independently processes all local structures in parallel, the decoder $f_{\text{dec}}$ attends over the entire set of $L$ tokens to reconstruct the full structure. It is composed of a stack of bidirectional Transformer blocks with regular self-attention.

As discussed in Appendix A.1.7.3, the VQ-VAE is trained in two stages. In the first stage, a smaller decoder trunk consisting of 8 Transformer blocks with width 1024, rotary positional embeddings, and MLPs is trained to only predict backbone coordinates. In the second stage, the decoder weights are re-initialized and the network size is expanded to 30 layers, each with an embedding dimension of 1280 ($\sim$600M parameters), to predict all-atom coordinates.

The exact steps to convert structure tokens back to 3D all-atom coordinates using the decoder are provided in Algorithm 8 and detailed as follows:

  1. Transformer: We embed the structure tokens and pass them through a stack of Transformer blocks $f_{\text{dec}}$ (regular self-attention + MLP sublayers, no geometric attention).

  2. Projection Head: We use a projection head to regress three 3-D vectors per residue: a translation vector $\vec{t}$, and two vectors, $-\vec{x}$ and $\vec{y}$, that define the $N-C_{\alpha}-C$ plane per residue after it has been rotated into position. This head also predicts the unnormalized sine and cosine components of up to 7 sidechain torsion angles.

  3. Calculate $T$: We use gram_schmidt to convert $\vec{t}$, $\vec{x}$, and $\vec{y}$ into frames $T \in SE(3)^{L}$.

  4. Calculate $T_{\text{local}}$: We normalize the sine and cosine components and convert them to frames $T_{\text{local}} \in SE(3)^{L \times 7}$ corresponding to rotations around the previous element on the sidechain.

  5. Compose Frames: We compose each element of $T_{\text{local}}$ with its predecessors on a tree rooted at $T$ to form $T_{\text{global}} \in SE(3)^{L \times 14}$, corresponding to the transformations needed for each heavy atom per residue in atom14 representation.

  6. Apply Frames: We then apply the frames to the reference coordinates $\vec{X}_{\text{ref}} \in \mathbb{R}^{L \times 14 \times 3}$, rotating and translating each residue into its final position.

Algorithm 8 structure_decode
Input: $z \in \{0..4099\}^{L \times 16}$
    1: $z = \operatorname{embed}(z)$ $\quad\triangleright\ \mathbb{R}^{L \times d}$
    2: $z = f_{\text{dec}}(z)$ $\quad\triangleright\ \mathbb{R}^{L \times d}$
    3: $\vec{t}, \vec{x}, \vec{y}, \overline{\sin\theta}, \overline{\cos\theta} = \operatorname{proj}(z)$ $\quad\triangleright\ (\mathbb{R}^{L \times 3})_{\times 3}, (\mathbb{R}^{L \times 7})_{\times 2}$
    4: $T = \text{gram\_schmidt}(\vec{t}, -\vec{x}, \vec{y})$ $\quad\triangleright\ SE(3)^{L}$
    5: $\sin\theta = \overline{\sin\theta} \,/\, \sqrt{\overline{\sin\theta}^{2} + \overline{\cos\theta}^{2}}$ $\quad\triangleright\ [-1,1]^{L \times 7}$
    6: $\cos\theta = \overline{\cos\theta} \,/\, \sqrt{\overline{\sin\theta}^{2} + \overline{\cos\theta}^{2}}$ $\quad\triangleright\ [-1,1]^{L \times 7}$
    7: $T_{\text{local}} = \text{rot\_frames}(\sin\theta, \cos\theta)$ $\quad\triangleright\ SE(3)^{L \times 7}$
    8: $T_{\text{global}} = \operatorname{compose}(T_{\text{local}}, T)$ $\quad\triangleright\ SE(3)^{L \times 14}$
    9: $\vec{X} = T_{\text{global}}(\vec{X}_{\text{ref}})$ $\quad\triangleright\ \mathbb{R}^{L \times 14 \times 3}$

Please explain this in detail to a beginning grad student in machine learning.

Response:

Introduction

In this explanation, we'll break down the steps involved in the decoder of a Vector Quantized Variational Autoencoder (VQ-VAE) as applied to protein structures. We'll focus on understanding Algorithm 8 (structure_decode), which details how to convert discrete latent tokens back into 3D all-atom coordinates of a protein. We'll assume you're familiar with basic machine learning concepts but new to this specific application.

Background

Proteins are composed of amino acids that fold into complex three-dimensional structures critical for their function. Modeling protein structures computationally is essential in bioinformatics and drug discovery.

A VQ-VAE is a type of autoencoder that learns discrete latent representations (codes) of input data. Instead of mapping inputs to a continuous latent space as in traditional autoencoders, VQ-VAE maps inputs to the nearest embedding vectors in a codebook of discrete codes. This is particularly useful in modeling data where discrete representations are meaningful.

Overview of the Process

  1. Encoding: The protein structure is broken down into local structures and encoded into discrete tokens using the VQ-VAE encoder.

  2. Quantization: The continuous latent representations are quantized to the nearest code in the codebook, resulting in discrete tokens.

  3. Decoding: The decoder takes these tokens and reconstructs the full 3D protein structure.

Focus on the Decoder (structure_decode)

The decoder's role is to reconstruct the protein's 3D structure from the discrete tokens. Let's go through Algorithm 8 step by step.

Algorithm 8: structure_decode

Inputs: the discrete structure tokens $z \in \{0..4099\}^{L \times 16}$ produced by the encoder and quantizer, as given in the input line of Algorithm 8.

Step-by-Step Explanation

Step 1: Embedding the Tokens

1: z = embed(z)  # Results in z ∈ ℝ^{L × d}

Step 2: Passing through the Transformer Decoder

2: z = f_dec(z)  # z still ∈ ℝ^{L × d}
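Steps 1 and 2 amount to an embedding lookup followed by a stack of bidirectional Transformer blocks. Below is a minimal PyTorch sketch; the class and its hyperparameters are illustrative, and it uses PyTorch's stock encoder layer, so it omits the rotary positional embeddings mentioned in the paper.

import torch
import torch.nn as nn

class StructureTokenDecoderTrunk(nn.Module):
    """Illustrative stand-in for embed + f_dec (not the actual ESM3 module)."""

    def __init__(self, vocab_size=4100, d_model=1280, n_layers=30, n_heads=20):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)               # step 1
        layer = nn.TransformerEncoderLayer(d_model, n_heads,
                                           dim_feedforward=4 * d_model,
                                           batch_first=True, norm_first=True)
        self.trunk = nn.TransformerEncoder(layer, n_layers)          # step 2

    def forward(self, tokens):                 # tokens: (B, L) integer structure tokens
        z = self.embed(tokens)                 # (B, L, d_model)
        return self.trunk(z)                   # (B, L, d_model), bidirectional attention

decoder = StructureTokenDecoderTrunk(n_layers=2)      # small, for a quick smoke test
z = decoder(torch.randint(0, 4100, (1, 64)))          # -> torch.Size([1, 64, 1280])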

Step 3: Projection to Geometric Parameters

3: t_vec, x_vec, y_vec, sin_theta_bar, cos_theta_bar = proj(z)
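A hypothetical sketch of such a projection head, assuming a single linear layer and the output sizes quoted above (three 3-vectors plus 7 raw sine/cosine pairs per residue):

import torch
import torch.nn as nn

class GeometryProjectionHead(nn.Module):
    """Maps per-residue features to t, x, y vectors and raw sidechain sin/cos (sketch)."""

    def __init__(self, d_model=1280, n_torsions=7):
        super().__init__()
        # 3 vectors of dim 3 (t, x, y) plus 2 * n_torsions raw sin/cos components
        self.proj = nn.Linear(d_model, 3 * 3 + 2 * n_torsions)
        self.n_torsions = n_torsions

    def forward(self, z):                          # z: (B, L, d_model)
        out = self.proj(z)                         # (B, L, 9 + 14)
        t, x, y = out[..., 0:3], out[..., 3:6], out[..., 6:9]
        sin_raw = out[..., 9:9 + self.n_torsions]
        cos_raw = out[..., 9 + self.n_torsions:]
        return t, x, y, sin_raw, cos_raw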

Step 4: Constructing the Reference Frames

4: T = gram_schmidt(t_vec, -x_vec, y_vec)  # T ∈ SE(3)^{L}
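Here is a sketch of the standard Gram-Schmidt frame construction, assuming frames are represented as a rotation matrix plus a translation; the exact axis ordering and sign conventions of the real gram_schmidt may differ.

import torch

def gram_schmidt(t, v1, v2, eps=1e-8):
    """Build per-residue rigid frames (R, t) from two direction vectors (sketch).

    t, v1, v2: (..., 3). Returns R: (..., 3, 3) and t unchanged.
    Axis conventions are illustrative; the real implementation may order them differently.
    """
    e1 = v1 / (v1.norm(dim=-1, keepdim=True) + eps)
    u2 = v2 - (e1 * v2).sum(-1, keepdim=True) * e1     # remove the component along e1
    e2 = u2 / (u2.norm(dim=-1, keepdim=True) + eps)
    e3 = torch.cross(e1, e2, dim=-1)                   # completes a right-handed basis
    R = torch.stack([e1, e2, e3], dim=-1)              # columns are the frame axes
    return R, t

# Usage with the projection-head outputs (note the -x, matching line 4 of Algorithm 8):
# R, t = gram_schmidt(t_vec, -x_vec, y_vec)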

Step 5: Normalizing the Angles

5: sin_theta = sin_theta_bar / sqrt(sin_theta_bar^2 + cos_theta_bar^2)  # sin_theta ∈ [-1, 1]^{L × 7}

Step 6: Calculating the Cosine of the Angles

6: cos_theta = cos_theta_bar / sqrt(sin_theta_bar^2 + cos_theta_bar^2)  # cos_theta ∈ [-1, 1]^{L × 7}
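Steps 5 and 6 simply project each raw (sin, cos) pair onto the unit circle, so that sin²θ + cos²θ = 1 without ever computing θ itself. A tiny sketch (the epsilon for numerical stability is an assumption):

import torch

def normalize_sin_cos(sin_raw, cos_raw, eps=1e-8):
    """Project unnormalized (sin, cos) predictions onto the unit circle (sketch)."""
    norm = torch.sqrt(sin_raw ** 2 + cos_raw ** 2 + eps)   # (..., 7)
    return sin_raw / norm, cos_raw / norm                  # each in [-1, 1]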

Step 7: Calculating Local Transformation Frames

7: T_local = rot_frames(sin_theta, cos_theta)  # T_local ∈ SE(3)^{L × 7}
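Below is a sketch of rot_frames under the assumption (borrowed from AlphaFold-style code, not confirmed by the paper) that each torsion acts as a rotation about the local x-axis.

import torch

def rot_frames(sin_t, cos_t):
    """(sin, cos) of shape (..., 7) -> rotation matrices of shape (..., 7, 3, 3) (sketch).

    Each torsion is treated as a rotation about the local x-axis (an assumed convention).
    """
    zeros = torch.zeros_like(sin_t)
    ones = torch.ones_like(sin_t)
    rows = [
        torch.stack([ones,  zeros,  zeros], dim=-1),
        torch.stack([zeros, cos_t, -sin_t], dim=-1),
        torch.stack([zeros, sin_t,  cos_t], dim=-1),
    ]
    return torch.stack(rows, dim=-2)      # (..., 7, 3, 3)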

Step 8: Composing Global Transformation Frames

8: T_global = compose(T_local, T)  # T_global ∈ SE(3)^{L × 14}
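A simplified sketch of frame composition follows, treating the side-chain frames as a single parent chain rather than the full residue-specific atom14 tree used in practice; the point is that each atom inherits every rotation between it and the backbone frame.

import torch

def compose_rigid(R_parent, t_parent, R_child, t_child):
    """Compose two rigid transforms: parent followed by child (sketch)."""
    R = R_parent @ R_child
    t = (R_parent @ t_child.unsqueeze(-1)).squeeze(-1) + t_parent
    return R, t

def compose_chain(R_backbone, t_backbone, R_local, t_local):
    """Accumulate local sidechain frames along a simple parent chain (sketch).

    R_backbone: (L, 3, 3), t_backbone: (L, 3)
    R_local:    (L, 7, 3, 3), t_local: (L, 7, 3)
    Returns global frames of shape (L, 7, 3, 3) and (L, 7, 3). The real atom14
    mapping uses a residue-specific tree and expands to 14 frames per residue.
    """
    R_glob, t_glob = [], []
    R_prev, t_prev = R_backbone, t_backbone
    for k in range(R_local.shape[1]):
        R_prev, t_prev = compose_rigid(R_prev, t_prev, R_local[:, k], t_local[:, k])
        R_glob.append(R_prev)
        t_glob.append(t_prev)
    return torch.stack(R_glob, dim=1), torch.stack(t_glob, dim=1)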

Step 9: Applying Transformations to Reference Coordinates

9: X = T_global(X_ref)  # X ∈ ℝ^{L × 14 × 3}
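Applying a rigid frame is one rotation plus one translation per atom. A sketch, assuming the per-atom global frames and the residue-specific reference coordinates have already been gathered into atom14 layout (the real lookup tables are omitted):

import torch

def apply_frames(R_global, t_global, x_ref):
    """x_final[l, a] = R_global[l, a] @ x_ref[l, a] + t_global[l, a] (sketch).

    R_global: (L, 14, 3, 3), t_global: (L, 14, 3), x_ref: (L, 14, 3)
    Returns final coordinates of shape (L, 14, 3).
    """
    return (R_global @ x_ref.unsqueeze(-1)).squeeze(-1) + t_global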

Summary of the Steps

  1. Embedding Tokens: Discrete tokens are converted into continuous vectors for processing by the neural network.
  2. Transformer Processing: The embeddings are passed through Transformer layers to capture relationships between residues.
  3. Projection to Geometric Parameters: Extract geometric information (positions, orientations, angles) necessary for reconstruction.
  4. Frame Construction: Build initial coordinate frames for each residue using orthonormal vectors.
  5. Angle Normalization: Ensure predicted angles are valid by normalizing sine and cosine components.
  6. Local Rotations: Calculate rotations for side-chain atoms based on normalized angles.
  7. Frame Composition: Combine local rotations with initial frames to get transformations for all atoms.
  8. Apply Transformations: Use the transformation matrices to get the final 3D coordinates of the protein.

Conclusion

The decoder in this VQ-VAE model translates discrete tokens back into a full 3D protein structure by predicting geometric parameters and applying a series of transformations. This process leverages both advanced neural network architectures (Transformers) and geometric principles (rotations and coordinate frames) to accurately reconstruct complex molecular structures.

By understanding each step in structure_decode, a beginning graduate student can appreciate how machine learning models can integrate deep learning with domain-specific knowledge (such as protein chemistry) to solve complex problems.


Feel free to ask if you have any questions or need further clarification on any of the steps or concepts!