\section*{A.1.7.1.1. Codebook Learning}
The quantize operation transforms the $L$ latents into $L$ discrete tokens. Since the VQ-VAE was initially proposed (67), numerous approaches and tricks have been developed to address issues with poor codebook utilization and unstable training. We chose to learn the codebook as an exponential moving average of encoder outputs (67-69). To improve codebook utilization, unused codes are re-initialized to encoder outputs.
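To make this concrete, the sketch below shows one common way to implement an EMA codebook update with dead-code re-initialization in PyTorch. It only covers the codebook update (not the straight-through quantization or the reconstruction loss), and the tensor shapes, threshold, and function names are illustrative assumptions rather than the implementation used here.

```python
import torch

def ema_codebook_update(codebook, ema_counts, ema_sums, encoder_out, codes,
                        decay=0.99, eps=1e-5, reinit_threshold=1.0):
    """One EMA update of a VQ codebook (illustrative sketch, not the paper's code).

    codebook:    (K, d) current code vectors
    ema_counts:  (K,)   running count of assignments per code
    ema_sums:    (K, d) running sum of encoder outputs assigned to each code
    encoder_out: (N, d) flattened encoder outputs in this batch
    codes:       (N,)   index of the nearest code for each encoder output
    """
    K, d = codebook.shape
    one_hot = torch.nn.functional.one_hot(codes, K).to(encoder_out.dtype)  # (N, K)

    # Exponential moving averages of assignment counts and assigned vectors.
    ema_counts.mul_(decay).add_(one_hot.sum(0), alpha=1 - decay)
    ema_sums.mul_(decay).add_(one_hot.t() @ encoder_out, alpha=1 - decay)

    # Laplace-smoothed normalization, then update the codebook in place.
    n = ema_counts.sum()
    counts = (ema_counts + eps) / (n + K * eps) * n
    codebook.copy_(ema_sums / counts.unsqueeze(1))

    # Re-initialize rarely used ("dead") codes to random encoder outputs
    # so that codebook utilization stays high.
    dead = ema_counts < reinit_threshold
    if dead.any():
        rand_idx = torch.randint(0, encoder_out.shape[0], (int(dead.sum()),))
        codebook[dead] = encoder_out[rand_idx]
    return codebook
```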
\section*{A.1.7.1.2. Parallel Encoding}
To improve training and inference efficiency, we encode all local structure graphs within a protein in parallel. In practice, this means that given a batch of $B$ proteins with average sequence length $L$, the inputs to the structure encoder have shape $BL \times 16 \times d$.
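For intuition, the reshaping can be sketched as follows; the concrete sizes and variable names are illustrative.

```python
import torch

B, L, d = 8, 256, 128                      # proteins per batch, residues per protein, feature dim (illustrative)
local_graphs = torch.randn(B, L, 16, d)    # a 16-residue local neighborhood per residue

# Fold the batch and sequence dimensions together so every local structure graph
# is encoded independently and in parallel: (B, L, 16, d) -> (B*L, 16, d).
encoder_input = local_graphs.reshape(B * L, 16, d)
print(encoder_input.shape)                 # torch.Size([2048, 16, 128])
```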
\section*{A.1.7.2. DECODER}
While the encoder independently processes all local structures in parallel, the decoder $f_{\text{dec}}$ attends over the entire set of $L$ tokens to reconstruct the full structure. It is composed of a stack of bidirectional Transformer blocks with regular self-attention.
As discussed in Appendix A.1.7.3, the VQ-VAE is trained in two stages. In the first stage, a smaller decoder trunk consisting of 8 Transformer blocks with width 1024, rotary positional embeddings, and MLPs is trained to predict only backbone coordinates. In the second stage, the decoder weights are re-initialized and the network size is expanded to 30 layers, each with an embedding dimension of 1280 ($\sim 600\mathrm{M}$ parameters), to predict all-atom coordinates.
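A minimal sketch of such a trunk using stock PyTorch layers is shown below. It follows the second-stage depth and width quoted above, but omits rotary positional embeddings, and the class name, head count, and MLP ratio are illustrative assumptions rather than the paper's implementation.

```python
import torch.nn as nn

class DecoderTrunk(nn.Module):
    """Bidirectional (non-causal) Transformer stack: standard self-attention + MLP
    sublayers. Depth and width follow the second-stage description; everything
    else here (head count, MLP ratio, class name) is an illustrative assumption."""
    def __init__(self, d_model=1280, n_layers=30, n_heads=20):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
            batch_first=True, norm_first=True)
        self.blocks = nn.TransformerEncoder(layer, num_layers=n_layers)

    def forward(self, x):      # x: (B, L, d_model) embedded structure tokens
        return self.blocks(x)  # (B, L, d_model)
```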
The exact steps to convert structure tokens back to 3D all-atom coordinates using the decoder are provided in Algorithm 8 and detailed as follows:
Transformer: We embed the structure tokens and pass them through a stack of Transformer blocks $f_{\text{dec}}$ (regular self-attention + MLP sublayers, no geometric attention).
Projection Head: We use a projection head to regress three 3-D vectors per residue: a translation vector $\vec{t}$, and two vectors $-\vec{x}$ and $\vec{y}$ that define the $N-C_{\alpha}-C$ plane per residue after it has been rotated into position. This head also predicts the unnormalized sine and cosine components of up to 7 sidechain torsion angles.
Calculate $T$: We use gram_schmidt to convert $\vec{t}$, $\vec{x}$, and $\vec{y}$ into frames $T \in SE(3)^{L}$.
Calculate $T_{\text{local}}$: We normalize the sine and cosine components and convert them to frames $T_{\text{local}} \in SE(3)^{L \times 7}$ corresponding to rotations around the previous element on the sidechain.
Compose Frames: We compose each element of $T_{\text{local}}$ with its predecessors on a tree rooted at $T$ to form $T_{\text{global}} \in SE(3)^{L \times 14}$, corresponding to the transformations needed for each heavy atom per residue in atom14 representation.
Apply Frames: We then apply the frames to the reference coordinates $\vec{X}_{\text{ref}} \in \mathbb{R}^{L \times 14 \times 3}$ to rotate and translate each residue into its final position.
Algorithm 8 structure_decode
Input: $z \in \{0..4099\}^{L \times 16}$
1: $z = \operatorname{embed}(z)$ $\quad\triangleright \mathbb{R}^{L \times d}$
2: $z = f_{\text{dec}}(z)$ $\quad\triangleright \mathbb{R}^{L \times d}$
3: $\vec{t}, \vec{x}, \vec{y}, \overline{\sin \theta}, \overline{\cos \theta} = \operatorname{proj}(z)$ $\quad\triangleright (\mathbb{R}^{L \times 3})_{\times 3}, (\mathbb{R}^{L \times 7})_{\times 2}$
4: $T = \operatorname{gram\_schmidt}(\vec{t}, -\vec{x}, \vec{y})$ $\quad\triangleright SE(3)^{L}$
5: $\sin \theta = \frac{\overline{\sin \theta}}{\sqrt{\overline{\sin \theta}^{2} + \overline{\cos \theta}^{2}}}$ $\quad\triangleright [-1,1]^{L \times 7}$
6: $\cos \theta = \frac{\overline{\cos \theta}}{\sqrt{\overline{\sin \theta}^{2} + \overline{\cos \theta}^{2}}}$ $\quad\triangleright [-1,1]^{L \times 7}$
7: $T_{\text{local}} = \operatorname{rot\_frames}(\sin \theta, \cos \theta)$ $\quad\triangleright SE(3)^{L \times 7}$
8: $T_{\text{global}} = \operatorname{compose}(T_{\text{local}}, T)$ $\quad\triangleright SE(3)^{L \times 14}$
9: $\vec{X} = T_{\text{global}}(\vec{X}_{\text{ref}})$ $\quad\triangleright \mathbb{R}^{L \times 14 \times 3}$
Please explain this in detail to a beginning grad student in machine learning.
Response:

Introduction

In this explanation, we'll break down the steps involved in the decoder of a Vector Quantized Variational Autoencoder (VQ-VAE) as applied to protein structures. We'll focus on understanding Algorithm 8 (structure_decode), which details how to convert discrete latent tokens back into 3D all-atom coordinates of a protein. We'll assume you're familiar with basic machine learning concepts but new to this specific application.
Background
Proteins are composed of amino acids that fold into complex three-dimensional structures critical for their function. Modeling protein structures computationally is essential in bioinformatics and drug discovery.
A VQ-VAE is a type of autoencoder that learns discrete latent representations (codes) of input data. Instead of mapping inputs to a continuous latent space as in traditional autoencoders, VQ-VAE maps inputs to the nearest embedding vectors in a codebook of discrete codes. This is particularly useful in modeling data where discrete representations are meaningful.
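To make the quantization step concrete, here is a generic nearest-code lookup in PyTorch. It is a sketch of standard VQ-VAE quantization, not necessarily the exact code used in the paper.

```python
import torch

def quantize(latents, codebook):
    """latents: (N, d) encoder outputs; codebook: (K, d) code vectors.
    Returns the nearest-code index for each latent and the quantized vectors."""
    dists = torch.cdist(latents, codebook)   # (N, K) pairwise Euclidean distances
    codes = dists.argmin(dim=-1)             # (N,)  index of the closest code
    quantized = codebook[codes]              # (N, d) latents snapped to codebook entries
    return codes, quantized
```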
Overview of the Process
Encoding: The protein structure is broken down into local structures and encoded into discrete tokens using the VQ-VAE encoder.
Quantization: The continuous latent representations are quantized to the nearest code in the codebook, resulting in discrete tokens.
Decoding: The decoder takes these tokens and reconstructs the full 3D protein structure.
Focus on the Decoder (structure_decode)
The decoder's role is to reconstruct the protein's 3D structure from the discrete tokens. Let's go through Algorithm 8 step by step.
Algorithm 8: structure_decode
Inputs: z ∈ {0..4099}^{L × 16} — the discrete structure tokens produced by the encoder for a protein of length L.
Step-by-Step Explanation
Step 1: Embedding the Tokens
1: z = embed(z) # Results in z ∈ ℝ^{L × d}
Step 2: Passing through the Transformer Decoder
2: z = f_dec(z) # z still ∈ ℝ^{L × d}
Step 3: Projection to Geometric Parameters
3: t_vec, x_vec, y_vec, sin_theta_bar, cos_theta_bar = proj(z)
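As a rough picture of what proj could look like, here is a sketch of a projection head in PyTorch. Only the output shapes follow the paper; the layer sizes, class name, and layout are assumptions for illustration.

```python
import torch.nn as nn

class ProjectionHead(nn.Module):
    """Illustrative sketch: regress three 3-D vectors (t, x, y) and
    7 unnormalized (sin, cos) torsion pairs per residue."""
    def __init__(self, d_model=1280):
        super().__init__()
        self.geom = nn.Linear(d_model, 3 * 3)      # t, x, y vectors
        self.torsions = nn.Linear(d_model, 7 * 2)  # unnormalized sin/cos per angle

    def forward(self, z):                          # z: (L, d_model)
        t, x, y = self.geom(z).chunk(3, dim=-1)    # each (L, 3)
        sin_raw, cos_raw = self.torsions(z).reshape(-1, 7, 2).unbind(-1)  # each (L, 7)
        return t, x, y, sin_raw, cos_raw
```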
Step 4: Constructing the Reference Frames
4: T = gram_schmidt(t_vec, -x_vec, y_vec) # T ∈ SE(3)^{L}
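A minimal sketch of how such a Gram-Schmidt frame construction can be implemented is below. The orthonormalization order and the sign/column conventions are assumptions; the paper's gram_schmidt may differ in those details.

```python
import torch

def gram_schmidt_frames(t, x, y, eps=1e-8):
    """Build SE(3) frames from a translation t and two direction vectors x, y.
    t, x, y: (L, 3). Returns rotations R: (L, 3, 3) and translations t: (L, 3)."""
    e1 = x / (x.norm(dim=-1, keepdim=True) + eps)
    y_perp = y - (e1 * y).sum(-1, keepdim=True) * e1        # remove the component along e1
    e2 = y_perp / (y_perp.norm(dim=-1, keepdim=True) + eps)
    e3 = torch.cross(e1, e2, dim=-1)                        # completes a right-handed basis
    R = torch.stack([e1, e2, e3], dim=-1)                   # (L, 3, 3), basis vectors as columns
    return R, t
```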
Step 5: Normalizing the Angles
5: sin_theta = sin_theta_bar / sqrt(sin_theta_bar^2 + cos_theta_bar^2) # sin_theta ∈ [-1, 1]^{L × 7}
Step 6: Normalizing the Cosine of the Angles
6: cos_theta = cos_theta_bar / sqrt(sin_theta_bar^2 + cos_theta_bar^2) # cos_theta ∈ [-1, 1]^{L × 7}
Step 7: Calculating Local Transformation Frames
7: T_local = rot_frames(sin_theta, cos_theta) # T_local ∈ SE(3)^{L × 7}
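One way to realize rot_frames is to build, for each torsion angle, a rotation about a fixed axis of the parent frame; rotating about the local x-axis is a common convention in all-atom reconstruction pipelines. The axis choice and layout below are assumptions for illustration.

```python
import torch

def rot_frames(sin_t, cos_t):
    """sin_t, cos_t: (L, 7) normalized torsion components.
    Returns (L, 7, 3, 3) rotation matrices about the local x-axis (assumed convention)."""
    L, K = sin_t.shape
    R = torch.zeros(L, K, 3, 3)
    R[..., 0, 0] = 1.0
    R[..., 1, 1], R[..., 1, 2] = cos_t, -sin_t
    R[..., 2, 1], R[..., 2, 2] = sin_t, cos_t
    return R
```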
Step 8: Composing Global Transformation Frames
8: T_global = compose(T_local, T) # T_global ∈ SE(3)^{L × 14}
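Composing two rigid transforms means applying one after the other; along the sidechain tree, each local frame is composed with its already-composed parent, starting from the backbone frame at the root. A minimal sketch of the composition rule (with rotation/translation pairs; the tree bookkeeping in the comment uses hypothetical names):

```python
import torch

def compose(R_a, t_a, R_b, t_b):
    """Compose rigid transforms (R_a, t_a) ∘ (R_b, t_b): apply b first, then a.
    R_*: (..., 3, 3), t_*: (..., 3)."""
    R = R_a @ R_b
    t = (R_a @ t_b.unsqueeze(-1)).squeeze(-1) + t_a
    return R, t

# Accumulating along a (hypothetical) parent tree, root = backbone frame T:
# for i in topological_order:
#     R_global[i], t_global[i] = compose(R_global[parent[i]], t_global[parent[i]],
#                                        R_local[i], t_local[i])
```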
Step 9: Applying Transformations to Reference Coordinates
9: X = T_global(X_ref) # X ∈ ℝ^{L × 14 × 3}
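Applying a frame to a point is just "rotate, then translate." A minimal sketch, assuming the global frames are stored as one rotation/translation pair per atom14 slot:

```python
import torch

def apply_frames(R, t, x_ref):
    """R: (L, 14, 3, 3), t: (L, 14, 3), x_ref: (L, 14, 3) idealized reference
    coordinates. Returns the final atom positions, shape (L, 14, 3)."""
    return (R @ x_ref.unsqueeze(-1)).squeeze(-1) + t
```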
Summary of the Steps

The decoder embeds the discrete tokens, contextualizes them with a Transformer, and projects the result to geometric quantities: a backbone frame per residue (via Gram-Schmidt) and up to 7 sidechain torsion angles (via normalized sine/cosine pairs). The torsions are converted to local frames, composed along the sidechain tree rooted at the backbone frame, and finally applied to idealized reference coordinates to place the 14 heavy-atom slots of each residue.

Key Concepts Explained

SE(3) frames: rigid-body transformations (a rotation plus a translation) in 3D; each residue backbone and each sidechain element gets one.
Torsion angles: rotations around chemical bonds that determine sidechain conformation; predicting (sin, cos) pairs and normalizing them avoids angle-wraparound issues.
atom14 representation: a compact per-residue layout with 14 slots for heavy atoms, so every residue type fits a fixed-size tensor.
Conclusion
The decoder in this VQ-VAE model translates discrete tokens back into a full 3D protein structure by predicting geometric parameters and applying a series of transformations. This process leverages both advanced neural network architectures (Transformers) and geometric principles (rotations and coordinate frames) to accurately reconstruct complex molecular structures.
By understanding each step in structure_decode, a beginning graduate student can appreciate how machine learning models can integrate deep learning with domain-specific knowledge (such as protein chemistry) to solve complex problems.
Feel free to ask if you have any questions or need further clarification on any of the steps or concepts!