
\section*{A.1.7.3.2. Stage 2.}

In the second stage of VQ-VAE training, the encoder and codebook are frozen and a new, deeper decoder is trained. This second stage of training has multiple purposes. First, a larger decoder improves reconstruction quality. Second, augmented structure tokens from ESM3 are added to enable learning of the pAE and pLDDT heads. Third, we add sequence conditioning and train with all-atom geometric losses so that all-atom protein structures can be decoded. Fourth, we extend the context length of the decoder so that it can decode multimers and larger single-chain proteins.

Training data for stage 2 consists of predicted structures from AFDB and ESMAtlas, as well as single-chain, multimer, and antibody-antigen complexes from the PDB. Sequence conditioning is added to the decoder via learned embeddings that are summed with the structure-token embeddings at the input to the decoder stack.
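As a concrete illustration of the sequence conditioning described above, here is a minimal sketch of summing learned embeddings at the decoder input. The module name, vocabulary sizes, and model width are illustrative assumptions, not values from the paper.

```python
import torch
import torch.nn as nn

class DecoderInputEmbedding(nn.Module):
    """Sketch: sequence-conditioned decoder input, formed by summing learned
    amino-acid embeddings with structure-token embeddings per residue.
    Vocabulary sizes and width below are assumptions, not from the paper."""

    def __init__(self, n_structure_tokens=4096, n_sequence_tokens=32, d_model=1024):
        super().__init__()
        self.structure_embed = nn.Embedding(n_structure_tokens, d_model)
        self.sequence_embed = nn.Embedding(n_sequence_tokens, d_model)

    def forward(self, structure_tokens, sequence_tokens):
        # Both inputs are (batch, L) integer tensors aligned per residue.
        return self.structure_embed(structure_tokens) + self.sequence_embed(sequence_tokens)
```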

The structure token decoder was trained in three stages: 2A, 2B, and 2C, detailed in Table S2. The purpose of stage 2A is to efficiently learn decoding of all-atom structures. Training efficiency is improved by keeping a short context length and omitting the pAE and pLDDT losses, which are both memory-intensive and can compete with strong reconstruction quality. In stage 2B, we add the pAE and pLDDT losses. These structure confidence heads cannot be well calibrated unless structure tokens are augmented such that ESM3-predicted structure tokens are within the training distribution. To this end, for stages 2B and 2C we replace ground-truth structure tokens with ESM3-predicted structure tokens $50\%$ of the time. In stage 2C, we extend the context length to 2048 and upsample experimental structures relative to predicted structures.
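To make the 50% token-replacement augmentation concrete, the sketch below shows one way to implement it. The excerpt does not say whether the replacement is decided per example or per token; this sketch assumes a per-example coin flip, and the function name is hypothetical.

```python
import torch

def augment_structure_tokens(gt_tokens, esm3_tokens, p_replace=0.5):
    """Replace ground-truth structure tokens with ESM3-predicted tokens with
    probability p_replace. Assumes a per-example decision, which the excerpt
    does not specify. Both inputs are (batch, L) integer tensors.
    """
    coin = torch.rand(gt_tokens.shape[0], device=gt_tokens.device) < p_replace
    return torch.where(coin[:, None], esm3_tokens, gt_tokens)
```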

  1. All-atom Distance Loss: We generalize the Backbone Distance Loss to all atoms by computing a pairwise $L_{2}$ distance matrix for all 14 atoms in the atom14 representation of each residue. This results in $D_{\text{pred}}, D \in \mathbb{R}^{14L \times 14L}$. The rest of the computation follows as before: $\left(D_{\text{pred}}-D\right)^{2}$, clamping to $(5 \AA)^{2}$, and taking the mean, while masking invalid pairs (where any atom14 positions are "empty").

  2. All-atom Direction Loss: We extend the Backbone Direction Loss to all heavy atoms:

(a) Compute a pairwise distance matrix per residue from the 3D coordinates of each atom in the atom14 representation, resulting in $\mathbb{R}^{L \times 14 \times 14}$.

(b) Mark atoms less than $2 \AA$ apart (excluding self) as covalent bonds.

(c) Filter to keep atoms with at least 2 covalent bonds, keeping only the first 2 bonds per atom, with ordering determined by the atom14 representation.

(d) For each selected atom, compute a normal (z-axis) vector to the plane spanned by its two covalent bonds; together with the two bond vectors, this yields three vectors per selected atom.

(e) Randomly subsample to 10,000 vectors per protein if the number exceeds 10,000, ensuring the same vectors are sampled in both predicted and ground truth structures.

(f) Compute all-to-all pairwise dot products, forming $D_{\text{pred}}, D \in \mathbb{R}^{n \times n}$. Compute $\left(D_{\text{pred}}-D\right)^{2}$, clamp the max to 20, and take the mean.

  3. pLDDT Head: Uses a Regression Head with 50 output classes (each covering 0.02 units of pLDDT on a 0-to-1 scale, i.e., 2 points on the usual 0-100 scale). Predicted structures are compared to ground truth to calculate per-residue pLDDT values, which are supervised with a cross-entropy loss.
  4. pAE Head: Uses a Pairwise Projection Head to produce 64 logits per residue pair ($\in \mathbb{R}^{L \times L \times 64}$), converted to probabilities $p$ via softmax. Each probability corresponds to a bin representing $0.5 \AA$ of positional error, with bin centers $[0.25, 0.75, \ldots, 31.25, 31.75]$.

Please explain this in detail to a beginning grad student in machine learning.

Response: Introduction to Stage 2 of VQ-VAE Training for Protein Structures

Overview

In this explanation, we'll delve into the second stage of training a Vector Quantized Variational AutoEncoder (VQ-VAE) designed for protein structures. By the end, you'll understand:

  1. The objectives of Stage 2 training.
  2. The enhancements made to the decoder.
  3. Introduction of new losses and why they're important.
  4. Details about the pLDDT and pAE heads.

Let's start by setting the context.


Background Concepts

VQ-VAE in Protein Modeling

Why VQ-VAE for Proteins?


Goals of Stage 2 Training

In Stage 2, we aim to enhance the VQ-VAE to better handle complex protein structures and improve prediction confidence. The main goals are:

  1. Improve Reconstruction Quality: By using a larger decoder, we can capture more complex patterns in protein structures.
  2. Enable Confidence Predictions: Incorporate pLDDT (Predicted Local Distance Difference Test) and pAE (Predicted Aligned Error) heads to estimate the confidence in our structure predictions.
  3. All-Atom Modeling: Move from modeling just the backbone atoms to all atoms in the protein, providing a more detailed representation.
  4. Handle Larger Proteins and Complexes: Extend the context length to model larger proteins, including multimers (complexes of multiple protein chains).

Training Data for Stage 2

Stage 2 trains on predicted structures from AFDB and ESMAtlas, together with single-chain, multimer, and antibody-antigen complexes from the PDB. In the final sub-stage (2C), experimental structures are upsampled relative to predicted ones.


Enhancements in Stage 2

1. Freezing the Encoder and Codebook

The encoder and the quantized codebook from Stage 1 are frozen, so the discrete structure-token vocabulary stays fixed while only the decoder is retrained.

2. Larger Decoder

A new, deeper decoder replaces the Stage 1 decoder, which improves reconstruction quality.

3. Structure Confidence Heads

pLDDT and pAE heads are added so the decoder can estimate how reliable its own reconstructions are.

4. All-Atom Modeling

Sequence conditioning plus all-atom geometric losses allow the decoder to output all-atom structures rather than backbone-only ones.

5. Longer Context Length

The decoder's context is extended (to 2048 tokens in Stage 2C) so it can decode multimers and larger single-chain proteins.


Training Stages Within Stage 2

The training is split into three sub-stages: 2A, 2B, and 2C.

Stage 2A

Focuses on efficiently learning all-atom decoding: the context length is kept short and the pAE and pLDDT losses are omitted, since they are memory-intensive and can compete with strong reconstruction quality.

Stage 2B

Adds the pAE and pLDDT losses. To keep these confidence heads well calibrated, ground-truth structure tokens are replaced with ESM3-predicted tokens 50% of the time (this augmentation continues in Stage 2C).

Stage 2C

Extends the context length to 2048 tokens and upsamples experimental structures relative to predicted structures.


Loss Functions and Model Heads

Now, let's dive into the specific losses and model components introduced in Stage 2.

1. All-Atom Distance Loss
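The excerpt above (item 1 of the loss list) defines this loss as the squared difference between predicted and ground-truth pairwise distance matrices over all atom14 positions, clamped at (5 Å)^2 and averaged over valid atom pairs. A rough PyTorch sketch under those assumptions follows; the tensor shapes and masking convention are my own, not the authors' implementation.

```python
import torch

def all_atom_distance_loss(pred_xyz, true_xyz, atom_mask, clamp=25.0):
    """Sketch of the all-atom distance loss.

    pred_xyz, true_xyz: (L, 14, 3) atom14 coordinates.
    atom_mask:          (L, 14) bool, False for empty atom14 slots.
    clamp:              (5 A)^2 = 25 A^2, as in the excerpt.
    """
    pred = pred_xyz.reshape(-1, 3)              # (14L, 3)
    true = true_xyz.reshape(-1, 3)
    mask = atom_mask.reshape(-1)                # (14L,)

    d_pred = torch.cdist(pred, pred)            # (14L, 14L) pairwise L2 distances
    d_true = torch.cdist(true, true)

    sq_err = ((d_pred - d_true) ** 2).clamp(max=clamp)

    pair_mask = (mask[:, None] & mask[None, :]).float()  # both atoms must exist
    return (sq_err * pair_mask).sum() / pair_mask.sum().clamp(min=1.0)
```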

2. All-Atom Direction Loss

  1. Pairwise Distances per Residue: Compute a 14 x 14 distance matrix within each residue from its atom14 coordinates, giving a tensor of shape L x 14 x 14.
  2. Identify Covalent Bonds: Atoms less than 2 Å apart (excluding an atom paired with itself) are marked as covalently bonded.
  3. Filter Atoms: Keep only atoms with at least 2 covalent bonds, and keep only the first 2 bonds per atom, ordered by the atom14 representation.
  4. Compute Normal Vectors: For each kept atom, compute the normal (z-axis) of the plane spanned by its two bond vectors; together with the two bond vectors, this gives three vectors per atom.
  5. Subsampling: If a protein yields more than 10,000 vectors, randomly subsample to 10,000, using the same subset for the predicted and ground-truth structures.
  6. Compute Pairwise Dot Products: Form the all-to-all dot-product matrices D_pred and D, each of shape n x n.
  7. Compute Loss: Take the squared difference (D_pred - D)^2, clamp it at 20, and take the mean.
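Steps 1-7 above can be sketched in PyTorch as follows. This is a rough, unoptimized illustration: bond selection is computed from the ground-truth structure and reused for the prediction, and the vectors are normalized to unit length; both of those choices are assumptions on my part rather than details stated in the excerpt.

```python
import torch
import torch.nn.functional as F

def _select_bonded_atoms(true_xyz, atom_mask, bond_cutoff=2.0):
    """Steps 1-3: within-residue distances, mark bonds (< 2 A, excluding self),
    keep atoms with >= 2 bonds and their first two partners (atom14 order)."""
    L = true_xyz.shape[0]
    dists = torch.cdist(true_xyz, true_xyz)                       # (L, 14, 14)
    valid = atom_mask[:, :, None] & atom_mask[:, None, :]
    bonded = (dists < bond_cutoff) & valid
    bonded &= ~torch.eye(14, dtype=torch.bool).expand(L, 14, 14)  # exclude self

    res, atom, part0, part1 = [], [], [], []
    for r in range(L):
        for a in range(14):
            partners = torch.nonzero(bonded[r, a]).flatten()
            if partners.numel() >= 2:
                res.append(r); atom.append(a)
                part0.append(int(partners[0])); part1.append(int(partners[1]))
    as_idx = lambda v: torch.tensor(v, dtype=torch.long)
    return as_idx(res), as_idx(atom), as_idx(part0), as_idx(part1)

def _three_vectors(xyz, res, atom, p0, p1):
    """Step 4: two bond vectors plus the normal of their plane, per atom."""
    origin = xyz[res, atom]
    b0 = xyz[res, p0] - origin
    b1 = xyz[res, p1] - origin
    normal = torch.cross(b0, b1, dim=-1)
    return F.normalize(torch.cat([b0, b1, normal], dim=0), dim=-1)  # (3n, 3)

def all_atom_direction_loss(pred_xyz, true_xyz, atom_mask,
                            max_vectors=10_000, clamp=20.0):
    """Steps 5-7: subsample to <= 10,000 vectors (same subset for prediction
    and target), all-to-all dot products, squared error clamped at 20, mean."""
    res, atom, p0, p1 = _select_bonded_atoms(true_xyz, atom_mask)
    v_pred = _three_vectors(pred_xyz, res, atom, p0, p1)
    v_true = _three_vectors(true_xyz, res, atom, p0, p1)

    if v_true.shape[0] > max_vectors:
        keep = torch.randperm(v_true.shape[0])[:max_vectors]
        v_pred, v_true = v_pred[keep], v_true[keep]

    d_pred = v_pred @ v_pred.T                                    # (n, n)
    d_true = v_true @ v_true.T
    return ((d_pred - d_true) ** 2).clamp(max=clamp).mean()
```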

3. pLDDT Head (Predicted Local Distance Difference Test)
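Item 3 in the excerpt describes this as regression-as-classification: 50 bins, supervised with cross-entropy against per-residue lDDT computed by comparing the decoded structure to ground truth. The sketch below assumes pLDDT lives on a 0-to-1 scale (so each bin spans 0.02) and uses a plain linear projection in place of the paper's Regression Head; both are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

N_BINS = 50  # each bin covers 0.02 of pLDDT on an assumed 0-to-1 scale

class PLDDTHead(nn.Module):
    """Sketch of a per-residue pLDDT head: project residue features to 50 logits."""
    def __init__(self, d_model=1024):
        super().__init__()
        self.proj = nn.Linear(d_model, N_BINS)

    def forward(self, residue_repr):            # (batch, L, d_model)
        return self.proj(residue_repr)          # (batch, L, 50) logits

def plddt_loss(logits, per_residue_lddt):
    """Cross-entropy against binned lDDT targets in [0, 1]."""
    target_bin = (per_residue_lddt * N_BINS).long().clamp(max=N_BINS - 1)
    return F.cross_entropy(logits.reshape(-1, N_BINS), target_bin.reshape(-1))

def plddt_from_logits(logits):
    """Scalar pLDDT per residue as the expected bin center."""
    centers = (torch.arange(N_BINS, dtype=torch.float32) + 0.5) / N_BINS
    return logits.softmax(dim=-1) @ centers     # (batch, L)
```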

4. pAE Head (Predicted Aligned Error)
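Item 4 in the excerpt specifies 64 logits per residue pair, softmaxed into probabilities over 0.5 Å error bins with centers 0.25, 0.75, ..., 31.75. The sketch below shows that binning plus an expected-error readout; the pairwise featurization (an outer sum of two linear projections) is an assumption standing in for the paper's Pairwise Projection Head.

```python
import torch
import torch.nn as nn

N_PAE_BINS = 64
BIN_WIDTH = 0.5  # Angstroms per bin, as in the excerpt
BIN_CENTERS = torch.arange(N_PAE_BINS) * BIN_WIDTH + 0.25  # 0.25, 0.75, ..., 31.75

class PairwiseProjectionHeadSketch(nn.Module):
    """Sketch of a pAE head: per-residue features are combined into pairwise
    features and projected to 64 logits per pair. The outer-sum featurization
    is illustrative, not the authors' architecture."""
    def __init__(self, d_model=1024, d_pair=128):
        super().__init__()
        self.left = nn.Linear(d_model, d_pair)
        self.right = nn.Linear(d_model, d_pair)
        self.out = nn.Linear(d_pair, N_PAE_BINS)

    def forward(self, residue_repr):              # (batch, L, d_model)
        pair = self.left(residue_repr)[:, :, None, :] + self.right(residue_repr)[:, None, :, :]
        return self.out(pair)                     # (batch, L, L, 64) logits

def expected_pae(logits):
    """Expected positional error per residue pair, in Angstroms."""
    probs = logits.softmax(dim=-1)                # (batch, L, L, 64)
    return probs @ BIN_CENTERS.to(probs.dtype)    # (batch, L, L)
```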


Why These Losses and Heads?


Handling ESM3-Predicted Structure Tokens

At inference time, the decoder receives structure tokens predicted by ESM3 rather than tokens encoded from ground-truth structures. To keep the pLDDT and pAE heads well calibrated for that setting, stages 2B and 2C replace ground-truth structure tokens with ESM3-predicted tokens 50% of the time (see the augmentation sketch earlier), so the confidence heads learn on inputs that look like real predictions.


Extending Context Length

In Stage 2C the decoder's context length is extended to 2048 tokens, which lets it decode multimers and larger single-chain proteins. During this stage, experimental structures are also upsampled relative to predicted structures.


Summary

In Stage 2 of training the VQ-VAE for protein structures, we make significant enhancements to the decoder to improve its ability to reconstruct accurate and detailed protein structures. By introducing new losses that focus on all-atom distances and orientations, as well as confidence prediction heads (pLDDT and pAE), we enable the model to produce not just accurate structures but also estimates of how reliable those structures are.

Data augmentation with ESM3-predicted tokens helps the model generalize to the types of errors present in real-world predictions, leading to better calibration of the confidence measures. Extending the context length and focusing on experimental structures ensure that the model is capable of handling the complexity of large proteins and protein complexes.


Key Takeaways for a Machine Learning Graduate Student


Further Reading


I hope this detailed explanation clarifies the concepts and methodologies employed in Stage 2 of training the VQ-VAE for protein structures. If you have any questions or need further elaboration on any of the points, feel free to ask!