Distillation of MSA Embeddings to Folded Protein Structures with Graph Transformers

Determining the structure of proteins has been a long-standing goal in biology. Language models have been recently deployed to capture the evolutionary semantics of protein sequences. Enriched with multiple sequence alignments (MSA), these models can encode protein tertiary structure. In this work, we introduce an attention-based graph architecture that exploits MSA Transformer embeddings to directly produce three-dimensional folded structures from protein sequences. We envision that this pipeline will provide a basis for efficient, end-to-end protein structure prediction.

Enriching a query sequence with its evolutionary counterparts, in the form of a multiple sequence alignment (MSA), further strengthens the predictive power of transformer-based protein language models, as demonstrated by state-of-the-art contact prediction results [Rao et al., 2021]. In this study, we leverage MSA Transformer embeddings within a geometric deep learning architecture to directly map protein sequences to folded, three-dimensional structures. In contrast to existing architectures, we directly estimate point coordinates in a learned, canonical pose, which removes the dependency on classical methods for resolving distance maps and enables gradient passing for downstream tasks, such as side-chain prediction and protein refinement. Overall, our results provide a bridge to a complete, end-to-end folding pipeline.
We treat the protein folding problem as a graph optimization problem. First, we harvest information-dense embeddings produced by the MSA Transformer [Rao et al., 2021] and use these embeddings to produce initial node and edge hidden representations in a complete graph. To process and structure geometric information, we employ the attention-based architecture of the Graph Transformer, as proposed by Shi et al. [2021]. Final node representations are then projected into Cartesian coordinates through a learnable transformation, and the resulting induced distance maps are compared to ground truth to define the loss for training.
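To make this harvesting step concrete, the snippet below sketches how query-row embeddings and pairwise features could be obtained from the publicly released MSA Transformer via the fair-esm package. The checkpoint name, output keys, and the use of symmetrized row attentions as pairwise features are assumptions of this sketch, not a specification of our exact preprocessing.

```python
# Sketch: harvesting per-residue and pairwise features from the MSA Transformer.
# Assumes the fair-esm package; checkpoint and output keys may differ by version.
import torch
import esm

model, alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
model.eval()
batch_converter = alphabet.get_batch_converter()

# msa: list of (label, aligned_sequence) pairs, query first, all of length N
msa = [("query", "MKTAYIAKQR"), ("hit_1", "MKSAYIARQR"), ("hit_2", "MKTAFIAKQK")]
_, _, tokens = batch_converter([msa])          # (1, S, N+1); a BOS token is prepended

with torch.no_grad():
    out = model(tokens, repr_layers=[12], need_head_weights=True)

# Per-residue embeddings of the query row (drop the BOS position): (N, 768)
x = out["representations"][12][0, 0, 1:, :]

# Pairwise features from symmetrized row attentions: (N, N, layers * heads)
row_attn = out["row_attentions"][0, :, :, 1:, 1:]              # (layers, heads, N, N)
z = (row_attn + row_attn.transpose(-1, -2)).flatten(0, 1).permute(1, 2, 0)
```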

The MSA Transformer is an unsupervised protein language model that produces information-rich residue embeddings [Rao et al., 2021]. In contrast to other protein language models, it operates on two-dimensional inputs consisting of a length-$N$ query sequence along with its MSA sequences. It utilizes an Axial Transformer scheme [Ho et al., 2019], interleaving attention over the rows and columns of the alignment.

We treat a protein as an attributed complete graph. Let $H_V$ and $H_E$ be the dimensionalities of node and edge representations, respectively. These attributes are extracted from the MSA Transformer embeddings through standard deep neural networks:

$$ h_i^{0} = W^{V}_{D_V}\,\sigma\!\big(\cdots\,\sigma\!\big(W^{V}_{1}\, x_i\big)\big), \qquad e_{ij}^{0} = W^{E}_{D_E}\,\sigma\!\big(\cdots\,\sigma\!\big(W^{E}_{1}\, z_{ij}\big)\big), $$

where $x_i$ and $z_{ij}$ denote the MSA Transformer's individual and pairwise embeddings for the query sequence, $\sigma$ is a ReLU nonlinearity, and $D_V$ and $D_E$ are the depths of the node and edge information extractors, respectively. $W$ denotes dense learnable parameters; here and in the following equations we omit bias terms.
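As an illustration of these extractors, a minimal sketch follows; the input dimensionalities (768 per-residue, 144 pairwise) and the chosen $H_V$, $H_E$, $D_V$, $D_E$ are placeholder values rather than the hyperparameters used in our experiments.

```python
import torch
import torch.nn as nn

def mlp(d_in: int, d_out: int, depth: int) -> nn.Sequential:
    """Stack of dense layers with ReLU nonlinearities between them (biases omitted, as in the text)."""
    dims = [d_in] + [d_out] * depth
    layers = []
    for i in range(depth):
        layers.append(nn.Linear(dims[i], dims[i + 1], bias=False))
        if i < depth - 1:
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)

# Illustrative sizes: 768-dim MSA Transformer residue embeddings, 144-dim pairwise features
H_V, H_E, D_V, D_E = 128, 64, 3, 3
node_extractor = mlp(768, H_V, D_V)   # x_i  -> h_i^0
edge_extractor = mlp(144, H_E, D_E)   # z_ij -> e_ij^0

x = torch.randn(10, 768)        # per-residue embeddings for an N = 10 protein
z = torch.randn(10, 10, 144)    # pairwise embeddings on the complete graph
h0 = node_extractor(x)          # (N, H_V) initial node states
e0 = edge_extractor(z)          # (N, N, H_E) initial edge states
```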

Figure 1: An overview of a sequence-to-structure pipeline utilizing the MSA Transformer and a Graph Transformer. (a) We first augment a length-$N$ protein sequence with $S$ sequences from its MSA. The MSA Transformer operates over this token matrix to produce enriched individual and pairwise embeddings. We retain the embeddings corresponding to the original query sequence. (b) Deep neural networks extract relevant features and structure latent states for a downstream Graph Transformer. Individual and pairwise embeddings are assigned to nodes and edges, respectively. (c) A Graph Transformer operates on node representations through an attention-based mechanism that considers pairwise edge attributes. The final node encodings are projected directly to $\mathbb{R}^3$, and the induced distogram is computed for the loss.

The Graph Transformer was introduced by Shi et al. [2021] to incorporate edge features directly into graph attention. This is achieved by summing transformations of the edge attributes into the keys and values of the attention mechanism. We approach protein folding with a variation of this architecture. Consider the layer-$l$ node hidden states $\{h_i^l\}$ and the learned edge latent states $\{e_{ij}\}$. With $C$ attention heads, a layer update can be written as

$$ h_i^{l+1} = \bigoplus_{c=1}^{C} \Big[ \sum_{j} \alpha_{ij}^{c} \big( W_v^{c}\, h_j^{l} + W_e^{c}\, e_{ij} \big) \Big], $$

where $\oplus$ denotes concatenation over heads and the $W$ are learnable projection matrices. The attention scores are normalized according to graph attention:

$$ \alpha_{ij}^{c} = \operatorname{softmax}_j \left( \frac{ \big( W_q^{c}\, h_i^{l} \big)^{\!\top} \big( W_k^{c}\, h_j^{l} + W_e^{c}\, e_{ij} \big) }{ \sqrt{d} } \right). $$

To hold computational costs roughly constant, we let the dimensionality of each head be $d = H_V / C$.
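A single layer of this update could be sketched as follows in a dense form over the complete protein graph; residual connections, normalization, and the gating used by Shi et al. [2021] are omitted, so this is an illustrative reduction rather than our exact implementation.

```python
import math
import torch
import torch.nn as nn

class GraphTransformerLayer(nn.Module):
    """One multi-head attention layer in which edge features are added to keys and values."""

    def __init__(self, h_v: int, h_e: int, n_heads: int):
        super().__init__()
        assert h_v % n_heads == 0
        self.n_heads, self.d = n_heads, h_v // n_heads   # d = H_V / C per head
        self.W_q = nn.Linear(h_v, h_v, bias=False)
        self.W_k = nn.Linear(h_v, h_v, bias=False)
        self.W_v = nn.Linear(h_v, h_v, bias=False)
        self.W_e = nn.Linear(h_e, h_v, bias=False)

    def forward(self, h: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        # h: (N, H_V) node states, e: (N, N, H_E) edge states on the complete graph
        N, C, d = h.size(0), self.n_heads, self.d
        q = self.W_q(h).view(N, C, d)                    # queries  (N, C, d)
        k = self.W_k(h).view(N, C, d)                    # keys     (N, C, d)
        v = self.W_v(h).view(N, C, d)                    # values   (N, C, d)
        e_proj = self.W_e(e).view(N, N, C, d)            # edge terms (N, N, C, d)

        # alpha_ij^c  proportional to  q_i^c . (k_j^c + W_e e_ij^c) / sqrt(d)
        scores = torch.einsum("icd,ijcd->ijc", q, k.unsqueeze(0) + e_proj) / math.sqrt(d)
        alpha = scores.softmax(dim=1)                    # normalize over neighbors j

        # h_i^{l+1} = concat_c  sum_j alpha_ij^c (v_j^c + W_e e_ij^c)
        out = torch.einsum("ijc,ijcd->icd", alpha, v.unsqueeze(0) + e_proj)
        return out.reshape(N, C * d)
```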
We train a predictor to recover the coordinates of each residue in a learned canonical pose,

$$ X_i = W_x\, h_i^{L}, $$

where $X_i \in \mathbb{R}^3$ and $h_i^{L}$ is the final node representation. To train our network, we use a distogram-based loss function on the resulting distance map. Let $D_{ij} = \lVert X_i - X_j \rVert_2$ be the induced Euclidean distance between the Cartesian projections of nodes $i$ and $j$, and let $\bar{D}_{ij}$ be the ground-truth distance. Our loss is based on the $L_1$-norm of the difference between those values:

$$ \mathcal{L} = \sum_{i,j} \big\lvert D_{ij} - \bar{D}_{ij} \big\rvert. $$

We benchmark our predicted distances against trRosetta, whose distance predictions are ultimately resolved into structures by Rosetta [Rohl et al., 2004]. For each distance, we consider trRosetta's best prediction as its expected value or its maximum likelihood estimate. We use the dRMSD (distogram RMSD) between predicted distances and ground truth as our evaluation metric. To make a direct comparison, we only consider distances that lie within trRosetta's binning range (2-20 Å).
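A minimal sketch of the coordinate head, the induced distance map, and the $L_1$ distogram loss, together with a dRMSD evaluation restricted to the 2-20 Å range, might look as follows; the mean-based reductions and the masking convention are choices made for this illustration.

```python
import torch
import torch.nn as nn

class CoordinateHead(nn.Module):
    """Project final node representations to Cartesian coordinates in a learned canonical pose."""

    def __init__(self, h_v: int):
        super().__init__()
        self.W_x = nn.Linear(h_v, 3, bias=False)

    def forward(self, h_final: torch.Tensor) -> torch.Tensor:
        return self.W_x(h_final)                        # (N, 3)

def distogram_l1_loss(coords: torch.Tensor, d_true: torch.Tensor) -> torch.Tensor:
    """L1 penalty between the induced distances D_ij = ||X_i - X_j||_2 and ground truth (mean over pairs)."""
    d_pred = torch.cdist(coords, coords)                # (N, N)
    return (d_pred - d_true).abs().mean()

def drmsd(coords: torch.Tensor, d_true: torch.Tensor, d_min: float = 2.0, d_max: float = 20.0) -> torch.Tensor:
    """dRMSD restricted to ground-truth distances inside trRosetta's binning range."""
    d_pred = torch.cdist(coords, coords)
    mask = (d_true >= d_min) & (d_true <= d_max)
    return ((d_pred[mask] - d_true[mask]) ** 2).mean().sqrt()
```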

89
Our results demonstrate that the Graph Transformer model, despite its modest size, is competitive with trRosetta's estimates (Figure 3, Table 1). It is worth noting that our architecture resolves backbone structure as its main output and uniquely and deterministically produces distances, whereas trRosetta operates within a probabilistic domain that does not require three-dimensional resolution. These early results thus suggest potential for improved predictive capability with larger model capacity and downstream protein refinement.
Figure 3: Qualitative assessment of model predictions for CASP13 free-modeling targets. Note that our model is able to capture long-range interactions, whereas trRosetta is, by construction, limited to short-range dependencies. We highlight T0950 and T0963D2 as examples of challenging reconstructions for our network.