\section*{A.1.7.3. TRAINING}
When using a VQ-VAE to learn discrete representations that maximize reconstruction quality, it is common to train the autoencoder in two stages (70). In the first stage, the encoder and codebook are learned with a relatively small and efficient decoder. In the second stage, the encoder and codebook are frozen and a larger or otherwise more computationally expensive decoder is trained to maximize reconstruction quality. We follow this two-stage training approach for the structure tokenizer.
\section*{A.1.7.3.1. Stage 1.}
The VQ-VAE is trained for $90 \mathrm{k}$ steps on a dataset of single chain proteins from the PDB, AFDB, and ESMAtlas. We use the AdamW optimizer (Loshchilov et al., 2017) with the learning rate annealed from $4\mathrm{e}{-4}$ according to a cosine decay schedule. Proteins are cropped to a maximum sequence length of 512. Five losses are used to supervise this stage of training. The geometric distance and geometric direction losses are responsible for supervising reconstruction of high-quality backbone structures.
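As a concrete illustration, the optimization setup described above might look as follows in PyTorch. This is only a sketch: the stand-in model, the decay endpoint of the cosine schedule, and all names are assumptions, not the training code used here.

```python
import math
import torch

TOTAL_STEPS = 90_000
PEAK_LR = 4e-4

# Stand-in module; the real model is the structure tokenizer VQ-VAE.
model = torch.nn.Linear(128, 128)
optimizer = torch.optim.AdamW(model.parameters(), lr=PEAK_LR)

# Cosine decay of the learning rate from PEAK_LR toward zero over training
# (the final learning rate is an assumption).
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: 0.5 * (1.0 + math.cos(math.pi * min(step, TOTAL_STEPS) / TOTAL_STEPS)),
)

for step in range(3):  # illustrative loop; real training runs for TOTAL_STEPS
    optimizer.zero_grad()
    loss = model(torch.randn(4, 128)).pow(2).mean()
    loss.backward()
    optimizer.step()
    scheduler.step()
```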
Additionally, a distogram and binned direction classification loss are used to bootstrap structure prediction but are ultimately immaterial to reconstruction. We have found that these structure prediction losses, formulated as classification tasks, improve convergence early in training. To produce these pairwise logits, we use a pairwise_proj_head that takes $x \in \mathbb{R}^{L \times d}$ and returns logits $z \in \mathbb{R}^{L \times L \times d^{\prime}}$. It works as follows:
Algorithm 9 pairwise_proj_head
Input: $x \in \mathbb{R}^{L \times d}$
1: $q, k = \operatorname{proj}(x), \operatorname{proj}(x)$
2: $\operatorname{prod}_{i,j,:}, \operatorname{diff}_{i,j,:} = q_{j,:} \odot k_{i,:},\; q_{j,:} - k_{i,:}$
3: $z = \text{regression\_head}([\,\text{prod} \mid \text{diff}\,]) \quad \triangleright \mathbb{R}^{L \times L \times d^{\prime}}$
4: return $z$
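A minimal PyTorch sketch of Algorithm 9 follows. It assumes separate learned linear projections for $q$ and $k$ and a single linear layer as the regression head; dimensions and module names are illustrative, not the authors' implementation.

```python
import torch
import torch.nn as nn


class PairwiseProjHead(nn.Module):
    """Sketch of pairwise_proj_head: per-position features -> pairwise logits."""

    def __init__(self, d: int, d_out: int):
        super().__init__()
        self.proj_q = nn.Linear(d, d)  # separate q/k projections (an assumption)
        self.proj_k = nn.Linear(d, d)
        self.regression_head = nn.Linear(2 * d, d_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (L, d)
        q, k = self.proj_q(x), self.proj_k(x)
        # prod[i, j, :] = q[j] * k[i];  diff[i, j, :] = q[j] - k[i]
        prod = q[None, :, :] * k[:, None, :]
        diff = q[None, :, :] - k[:, None, :]
        return self.regression_head(torch.cat([prod, diff], dim=-1))  # (L, L, d_out)


# Usage example with dummy inputs.
head = PairwiseProjHead(d=64, d_out=16)
z = head(torch.randn(128, 64))  # (128, 128, 16)
```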
Finally, an inverse folding token prediction loss (i.e., a cross-entropy loss between the predicted sequence and the ground truth sequence) is used as an auxiliary loss to encourage the learned representations to contain information pertinent to sequence-related tasks.
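For concreteness, the inverse folding loss reduces to a standard per-residue cross-entropy. In the sketch below the head that produces `seq_logits` from the decoder states, and the use of 20 canonical amino acid classes, are assumptions.

```python
import torch
import torch.nn.functional as F

L, n_amino_acids = 256, 20
seq_logits = torch.randn(L, n_amino_acids)        # per-residue logits from the decoder states
true_seq = torch.randint(0, n_amino_acids, (L,))  # ground-truth residue indices
inverse_folding_loss = F.cross_entropy(seq_logits, true_seq)
```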
The five losses are covered in detail as follows:
Algorithm 10 backbone_distance_loss
Input: $\hat{X} \in \mathbb{R}^{L \times 3 \times 3}, X \in \mathbb{R}^{L \times 3 \times 3}$
1: $\hat{Z}, Z = \operatorname{flatten}(\hat{X}), \operatorname{flatten}(X) \quad \triangleright \mathbb{R}^{3L \times 3}, \mathbb{R}^{3L \times 3}$
2: $[D_{\text{pred}}]_{i,j} = \|[\hat{Z}]_{i,:} - [\hat{Z}]_{j,:}\|_2^2 \quad \triangleright \mathbb{R}^{3L \times 3L}$
3: $[D]_{i,j} = \|[Z]_{i,:} - [Z]_{j,:}\|_2^2 \quad \triangleright \mathbb{R}^{3L \times 3L}$
4: $E = (D_{\text{pred}} - D)^2$
5: $E = \min(E, 25)$
6: $l = \operatorname{mean}_{i,j}(E) \quad \triangleright \mathbb{R}$
7: return $l$
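A PyTorch sketch of Algorithm 10, assuming $\hat{X}, X$ are $(L, 3, 3)$ tensors of backbone $(N, C_{\alpha}, C)$ coordinates; function and variable names are illustrative.

```python
import torch


def backbone_distance_loss(X_pred: torch.Tensor, X_true: torch.Tensor) -> torch.Tensor:
    """X_pred, X_true: (L, 3, 3) backbone coordinates -> scalar loss."""
    Z_pred = X_pred.reshape(-1, 3)               # (3L, 3)
    Z_true = X_true.reshape(-1, 3)               # (3L, 3)
    # Pairwise squared Euclidean distances, (3L, 3L).
    D_pred = torch.cdist(Z_pred, Z_pred).pow(2)
    D_true = torch.cdist(Z_true, Z_true).pow(2)
    E = (D_pred - D_true).pow(2)
    E = torch.clamp(E, max=25.0)                 # clamp the per-pair error at 25
    return E.mean()
```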
The backbone direction loss compares six vectors computed from both the predicted and ground truth coordinates for each residue:

(a) $N \rightarrow C_{\alpha}$
(b) $C_{\alpha} \rightarrow C$
(c) $C \rightarrow N_{\text{next}}$
(d) $\mathbf{n}_{C_{\alpha}} = -\left(N \rightarrow C_{\alpha}\right) \times \left(C_{\alpha} \rightarrow C\right)$
(e) $\mathbf{n}_{N} = \left(C_{\text{prev}} \rightarrow N\right) \times \left(N \rightarrow C_{\alpha}\right)$
(f) $\mathbf{n}_{C} = \left(C_{\alpha} \rightarrow C\right) \times \left(C \rightarrow N_{\text{next}}\right)$

Compute the pairwise dot products, forming $D_{\text{pred}}, D \in \mathbb{R}^{6L \times 6L}$. Compute $\left(D_{\text{pred}} - D\right)^2$, clamp the maximum error to 20, and take the mean.
In algorithm form (with compute_vectors computing the six vectors described above):
Algorithm 11 backbone_direction_loss
Input: $\hat{X} \in \mathbb{R}^{L \times 3 \times 3}, X \in \mathbb{R}^{L \times 3 \times 3}$
1: $\hat{V} = \text{compute\_vectors}(\hat{X}) \quad \triangleright \mathbb{R}^{6L \times 3}$
2: $V = \text{compute\_vectors}(X) \quad \triangleright \mathbb{R}^{6L \times 3}$
3: $[D_{\text{pred}}]_{i,j} = [\hat{V}]_{i,:} \cdot [\hat{V}]_{j,:} \quad \triangleright \mathbb{R}^{6L \times 6L}$
4: $[D]_{i,j} = [V]_{i,:} \cdot [V]_{j,:} \quad \triangleright \mathbb{R}^{6L \times 6L}$
5: $E = (D_{\text{pred}} - D)^2$
6: $E = \min(E, 20)$
7: $l = \operatorname{mean}_{i,j}(E) \quad \triangleright \mathbb{R}$
8: return $l$
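A sketch of compute_vectors and Algorithm 11 in PyTorch. The handling of the chain termini (where $C_{\text{prev}}$ and $N_{\text{next}}$ do not exist) uses a naive wrap-around here; that boundary treatment is an assumption, not necessarily what is done in the actual implementation.

```python
import torch
import torch.nn.functional as F


def compute_vectors(X: torch.Tensor) -> torch.Tensor:
    """X: (L, 3, 3) coordinates ordered as (N, CA, C). Returns (6L, 3) unit vectors."""
    N, CA, C = X[:, 0], X[:, 1], X[:, 2]
    N_next = torch.roll(N, shifts=-1, dims=0)   # wrap-around boundary (assumption)
    C_prev = torch.roll(C, shifts=1, dims=0)
    v1 = CA - N                                  # N -> CA
    v2 = C - CA                                  # CA -> C
    v3 = N_next - C                              # C -> N_next
    n_ca = -torch.cross(v1, v2, dim=-1)          # n_CA = -(N->CA) x (CA->C)
    n_n = torch.cross(N - C_prev, v1, dim=-1)    # n_N = (C_prev->N) x (N->CA)
    n_c = torch.cross(v2, v3, dim=-1)            # n_C = (CA->C) x (C->N_next)
    V = torch.stack([v1, v2, v3, n_ca, n_n, n_c], dim=1).reshape(-1, 3)  # (6L, 3)
    return F.normalize(V, dim=-1)


def backbone_direction_loss(X_pred: torch.Tensor, X_true: torch.Tensor) -> torch.Tensor:
    V_pred, V_true = compute_vectors(X_pred), compute_vectors(X_true)
    D_pred = V_pred @ V_pred.T                   # pairwise dot products, (6L, 6L)
    D_true = V_true @ V_true.T
    E = torch.clamp((D_pred - D_true).pow(2), max=20.0)
    return E.mean()
```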
The binned direction classification loss provides a coarse-grained measure of orientation similarity to stabilize early training and is computed as follows:

(a) Unit vectors: Compute three vectors per residue from ground truth coordinates: $C_{\alpha} \rightarrow C$, $C_{\alpha} \rightarrow N$, and $\mathbf{n}_{C_{\alpha}} = \left(C_{\alpha} \rightarrow C\right) \times \left(C_{\alpha} \rightarrow N\right)$, and normalize them to unit length.
(b) Dot Products: Compute pairwise dot products between each pair of vectors for all residues, forming $D \in [-1, 1]^{L \times L \times 6}$. Bin the dot products into 16 evenly spaced bins in $[-1, 1]$, forming classification labels $y \in \{0..15\}^{L \times L \times 6}$.
(c) Pairwise Logits: Pass the final layer representations of the decoder $h \in \mathbb{R}^{L \times d}$ through a pairwise_proj_head to obtain logits $z \in \mathbb{R}^{L \times L \times 6 \times 16}$.
(d) Cross Entropy: Calculate the cross-entropy loss using the labels $y$ from the ground truth structure and the logits $z$, and average over all $L \times L \times 6$ values.
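The binning and cross-entropy steps might be implemented as in the sketch below. The construction of the dot-product tensor $D$ and of the logits $z$ (via pairwise_proj_head) is assumed to happen elsewhere; this only illustrates the label binning and loss.

```python
import torch
import torch.nn.functional as F


def binned_direction_loss(D: torch.Tensor, z: torch.Tensor, n_bins: int = 16) -> torch.Tensor:
    """D: (L, L, 6) dot products in [-1, 1]; z: (L, L, 6, n_bins) logits."""
    # 15 internal boundaries give 16 evenly spaced bins over [-1, 1].
    edges = torch.linspace(-1.0, 1.0, n_bins + 1)[1:-1]
    y = torch.bucketize(D, edges)                # labels in {0..15}, shape (L, L, 6)
    return F.cross_entropy(z.reshape(-1, n_bins), y.reshape(-1))


# Usage with dummy tensors.
L = 64
loss = binned_direction_loss(torch.rand(L, L, 6) * 2 - 1, torch.randn(L, L, 6, 16))
```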
The distogram loss is computed as follows:

(a) Virtual $C_{\beta}$: Compute the location of $C_{\beta}$ for each residue from the ground truth backbone:

i. Obtain the three vectors $N \rightarrow C_{\alpha}$, $C_{\alpha} \rightarrow C$, and $\mathbf{n} = \left(N \rightarrow C_{\alpha}\right) \times \left(C_{\alpha} \rightarrow C\right)$.
ii. Define the following scalars:

$$
\begin{aligned}
a &= -0.58273431 \\
b &= 0.56802827 \\
c &= -0.54067466
\end{aligned}
$$

iii. Compute the location of $C_{\beta}$ using the formula:

$$
C_{\beta} = a\,\mathbf{n} + b\left(N \rightarrow C_{\alpha}\right) + c\left(C_{\alpha} \rightarrow C\right) + C_{\alpha}
$$

(b) Pairwise $C_{\beta}$ distances: Compute an $L \times L$ pairwise distance matrix of the $C_{\beta}$, and bin them into one of 64 bins, with lower bounds $\left[0, 2.3125^{2}, (2.3125+0.3075)^{2}, \ldots, 21.6875^{2}\right]$, forming the labels $y \in \{0..63\}^{L \times L}$.
(c) Pairwise logits: Pass the final layer representations of the decoder $h \in \mathbb{R}^{L \times d}$ through a pairwise_proj_head to obtain the logits $z \in \mathbb{R}^{L \times L \times 64}$.
(d) Cross Entropy: Calculate the cross-entropy using the labels $y$ computed from the ground truth structure and the logits $z$, then average over all $L \times L$ values.
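A sketch of the distogram labels and loss. It assumes the pairwise $C_{\beta}$ distances are compared in squared form (consistent with the squared bin lower bounds), interpolates the 63 nonzero lower bounds evenly between the listed endpoints $2.3125^2$ and $21.6875^2$, and takes the logits $z$ from pairwise_proj_head as given; all of these details are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

A, B, C_SCALAR = -0.58273431, 0.56802827, -0.54067466


def distogram_loss(X_true: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """X_true: (L, 3, 3) ground-truth (N, CA, C) coordinates; z: (L, L, 64) logits."""
    N, CA, C = X_true[:, 0], X_true[:, 1], X_true[:, 2]
    n = torch.cross(CA - N, C - CA, dim=-1)                      # (N->CA) x (CA->C)
    CB = A * n + B * (CA - N) + C_SCALAR * (C - CA) + CA         # virtual C-beta, (L, 3)
    d2 = torch.cdist(CB, CB).pow(2)                              # squared distances, (L, L)
    # 64 lower bounds: 0 plus 63 evenly spaced values in [2.3125^2, 21.6875^2] (assumed spacing).
    lower = torch.cat([torch.zeros(1), torch.linspace(2.3125, 21.6875, 63) ** 2])
    y = torch.bucketize(d2, lower, right=True) - 1               # labels in {0..63}
    return F.cross_entropy(z.reshape(-1, 64), y.reshape(-1))
```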
Please explain this in detail to a beginning grad student in machine learning.
Response:

Introduction
In this explanation, we'll delve into the training process described in section A.1.7.3.1, focusing on Stage 1. We'll break down the objectives, methods, and algorithms used to train a Vector Quantized Variational Autoencoder (VQ-VAE) for modeling protein structures. This explanation is tailored for a beginning graduate student in machine learning.
1. Background: Understanding VQ-VAE and Its Application to Proteins
1.1 What is a VQ-VAE?
Variational Autoencoder (VAE): A type of neural network that learns to encode input data into a latent space and then decode it back to reconstruct the original input. It's used for unsupervised learning of data distributions.
Vector Quantized VAE (VQ-VAE): An extension of VAE where the latent space is discrete rather than continuous. Instead of continuous latent variables, the encoder outputs indices into a finite set of latent embeddings (the codebook). This makes VQ-VAE suitable for tasks that require discrete outputs, like generating sequences or tokens.
1.2 Why Use VQ-VAE for Protein Structures?
Proteins are complex molecules with sequences of amino acids that fold into three-dimensional structures. Modeling these structures requires capturing both the sequential and spatial information.
Learning Discrete Representations: By using VQ-VAE, we can encode protein structures into discrete tokens. This is useful for tasks like protein design, where generating new sequences with desired structural properties is important.
Efficient Reconstruction: VQ-VAE allows for efficient encoding and decoding, making it feasible to handle large datasets like protein databases.
2. Training the VQ-VAE in Two Stages
2.1 Why Two Stages?
Training the encoder and a large decoder together can be challenging due to the complexity of protein structures. A two-stage training process helps to: (1) learn the encoder and codebook efficiently alongside a small, cheap decoder, and (2) then spend compute on a larger, more expressive decoder without changing the learned discrete representations.
This approach ensures that the encoder learns robust representations before training a powerful decoder to exploit them.
3. Stage 1 Training Details
3.1 Overview
In Stage 1, we train the VQ-VAE to encode protein backbones into discrete tokens and to reconstruct high-quality backbone structures from those tokens, using a relatively small decoder.
3.2 Dataset and Training Configuration
The model is trained for 90k steps on single-chain proteins from the PDB, AFDB, and ESMAtlas, cropped to a maximum sequence length of 512, using the AdamW optimizer with the learning rate annealed from 4e-4 under a cosine decay schedule.
3.3 Loss Functions
Five loss functions are used to supervise the training: the backbone distance loss, the backbone direction loss, the binned direction classification loss, the distogram loss, and the inverse folding loss.
These losses are designed to capture various aspects of protein structure and to stabilize and accelerate training.
4. Detailed Explanation of Loss Functions
4.1 Backbone Distance Loss
Purpose: Ensure that the predicted protein backbone closely matches the true backbone in terms of distances between atoms.
Backbone Atoms: Each residue in a protein backbone has three key atoms: the backbone nitrogen $N$, the alpha carbon $C_{\alpha}$, and the carbonyl carbon $C$.
For $L$ residues, there are $3L$ backbone atoms.
Steps: Flatten the predicted and ground-truth coordinates to $3L \times 3$, compute all pairwise squared distances within each, take the squared error between the two distance matrices, clamp each entry at 25, and average (Algorithm 10).
Intuition: By minimizing the difference in distances between all pairs of backbone atoms, we encourage the predicted structure to closely match the true structure in terms of overall shape and size.
4.2 Backbone Direction Loss
Purpose: Ensure that the predicted structures have correct local orientations, capturing the geometric relationships between atoms.
Vectors Computed Per Residue: Six vectors are computed from the backbone atoms of each residue: the three bond vectors $N \rightarrow C_{\alpha}$, $C_{\alpha} \rightarrow C$, and $C \rightarrow N_{\text{next}}$, plus the three normal vectors $\mathbf{n}_{C_{\alpha}}$, $\mathbf{n}_{N}$, and $\mathbf{n}_{C}$ obtained from their cross products, all normalized to unit length.
Steps: Stack the $6L$ unit vectors for the predicted and ground-truth structures, compute all pairwise dot products within each, take the squared error between the two $6L \times 6L$ matrices, clamp each entry at 20, and average (Algorithm 11).
Intuition: By focusing on the dot products of unit vectors, we are ensuring that the relative orientations of backbone atoms in the predicted structure match those of the true structure, which is crucial for accurate folding.
4.3 Binned Direction Classification Loss
Purpose: Provide a coarse-grained measure of orientation similarity to stabilize early training, especially when predicted coordinates are not yet accurate.
Steps: Compute three unit vectors per residue from the ground truth coordinates, take all pairwise dot products and bin them into 16 evenly spaced bins in $[-1, 1]$ to form labels, pass the decoder's final-layer representations through pairwise_proj_head to obtain logits $z \in \mathbb{R}^{L \times L \times 6 \times 16}$, and compute the cross-entropy between logits and labels.
Intuition: This loss helps the model learn general orientation patterns early in training, even when exact coordinates are not accurate. It's a form of coarse alignment that guides the model toward the correct geometries.
4.4 Distogram Loss
Purpose: Encourage the model to learn accurate inter-residue distances by predicting distances between residues in a discretized (binned) fashion.
Steps: Compute a virtual $C_{\beta}$ position for each residue from the ground truth backbone, bin the pairwise $C_{\beta}$ distances into 64 bins to form labels, pass the decoder's final-layer representations through pairwise_proj_head to obtain logits $z \in \mathbb{R}^{L \times L \times 64}$, and compute the cross-entropy between logits and labels.
Intuition: By predicting discretized distances between residues, the model learns to approximate the overall spatial arrangement of the protein, which is essential for accurate structure prediction.
4.5 Inverse Folding Loss
Purpose: Encourage the model to capture sequence information from structural representations, promoting a richer embedding that relates structure and sequence.
Steps: Predict per-residue amino acid logits from the decoder's final-layer representations and compute the cross-entropy between the predicted logits and the ground-truth sequence.
Intuition: Predicting the amino acid sequence from structural representations ensures that the model's embeddings contain information relevant to both structure and sequence, which is beneficial for tasks like protein design.
5. Additional Components
5.1 pairwise_proj_head Function
This function is used to generate pairwise logits for classification tasks (like in the Binned Direction Classification Loss and Distogram Loss).
Steps: Project the input representations $x$ into $q$ and $k$, form the elementwise product $q_{j,:} \odot k_{i,:}$ and difference $q_{j,:} - k_{i,:}$ for every pair of positions, concatenate them, and pass the result through a regression head to produce pairwise logits (Algorithm 9).
Purpose: By combining the product and difference of projected representations, the function captures interactions between positions $i$ and $j$, which is useful for predicting pairwise relationships.
5.2 compute_vectors Function
This function computes vectors necessary for the Backbone Direction Loss and other geometric computations.
Steps: For each residue, compute the bond vectors $N \rightarrow C_{\alpha}$, $C_{\alpha} \rightarrow C$, and $C \rightarrow N_{\text{next}}$ together with the three normal vectors formed by their cross products, normalize all of them to unit length, and stack them into a $6L \times 3$ matrix.
6. Summary and Intuition
6.1 Importance of Multi-Loss Training
Using multiple loss functions allows the model to learn different aspects of protein structures: the backbone distance loss captures global geometry, the backbone direction loss captures local orientations, the binned classification losses (direction and distogram) provide coarse-grained supervision that is stable early in training, and the inverse folding loss ties structure to sequence.
6.2 How These Losses Work Together
The geometric losses directly supervise reconstruction quality, the binned classification losses give coarse but stable gradients while predicted coordinates are still poor, and the inverse folding loss enriches the learned tokens with sequence information.
6.3 Overall Training Strategy
Stage 1 trains the encoder, codebook, and a small decoder with all five losses; Stage 2 then freezes the encoder and codebook and trains a larger decoder to maximize reconstruction quality.
Conclusion
By understanding the components and objectives of each loss function used in Stage 1 training, we gain insight into how the VQ-VAE is trained to model protein structures effectively. The combination of detailed geometric losses and auxiliary classification tasks provides a robust framework for learning high-quality discrete representations, which are crucial for downstream applications like protein design and structural analysis.
Note to the Reader
As a beginning graduate student in machine learning, it's essential to grasp how different loss functions can be strategically employed to guide a model's learning process. In complex tasks like protein structure prediction, combining various losses allows the model to capture different levels of information, from local geometries to global patterns, ultimately leading to better performance and more meaningful representations.