Learning a deep language model for microbiomes: the power of large-scale unlabeled microbiome data


For example, susceptibility to infection with Campylobacter jejuni was shown to depend on the species composition of the microbiota [15].

Transformers, a powerful and flexible machine learning architecture originally developed for NLP [16], provide a potential solution to the above issues. Past work [17-21] has applied transformers to biological data. However, such work has focused on learning a sequence encoder for representing DNA [21] or, more commonly, protein amino acid sequences [17-20] (e.g., each token might represent a k-mer in such a sequence). In contrast, we focus on representing entire microbial communities and their interactions, using each token to represent a single microbe in such a community.
We present the first use of transformers to learn representations of microbiomes at the taxa level by adapting "self-supervised" pre-training techniques from NLP, allowing the model to learn from vast amounts of unlabeled 16S microbiome data and reducing the required amount of expensive labeled data. The pre-trained models can be viewed as a form of "language model" for microbiome data, capturing the inherent compositional rules of microbial communities, which we can easily adapt to downstream prediction tasks with a smaller amount of labeled "fine-tuning" data. Our pipeline first applies a pre-processing step (Fig. 1B) to transform each microbiome sample into 'text-like' inputs.
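To make the input format concrete, the following is a minimal sketch of such a pre-processing step, assuming a rank-ordering tokenization (most abundant taxa first), a fixed maximum length, and a [CLS] token; none of these details are specified in this excerpt, and they stand in for the paper's actual recipe.

```python
# Minimal sketch: turn a relative-abundance vector into a "sentence" of taxa
# tokens. The rank-ordering, truncation length, and [CLS] token are
# illustrative assumptions, not the paper's exact recipe.
import numpy as np

def sample_to_tokens(abundances, taxa_vocab, max_len=512):
    """abundances: 1-D array of relative abundances, aligned with taxa_vocab."""
    order = np.argsort(abundances)[::-1]                 # most abundant first
    present = [i for i in order if abundances[i] > 0][: max_len - 1]
    return ["[CLS]"] + [taxa_vocab[i] for i in present]

taxa_vocab = ["Bacteroides", "Prevotella", "Faecalibacterium", "Roseburia"]
sample = np.array([0.50, 0.0, 0.35, 0.15])
print(sample_to_tokens(sample, taxa_vocab))
# ['[CLS]', 'Bacteroides', 'Faecalibacterium', 'Roseburia']
```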

A critical challenge in applying complex deep learning models like transformers is the lack of large amounts of labeled training data. This can be addressed, however, using a technique referred to as self-supervised pre-training [24], which leverages readily available unlabeled data. In this work, we follow this approach; our training process is described in Fig. 2.

We begin with a randomly initialized transformer and first train a task-agnostic transformer using unlabeled data via self-supervised pre-training. Specifically, we use ELECTRA (Efficiently Learning an Encoder that Classifies Token Replacements Accurately) [25] to pre-train the encoder layers of the transformer model. We chose ELECTRA because it reaches comparable performance to other popular pre-training approaches (BERT [26] and its various flavors) while being computationally efficient.

Halfvarson (HV). This dataset comes from an IBD study performed in [28]. We used the curated dataset produced in [13], which contains 564 microbiome samples, 510 of them IBD positive.

HMP2. This dataset comes from an IBD study performed as part of phase 2 of the Human Microbiome Project [29]. Again, we used the curated dataset produced in [13].

Table 2 shows the performance of all methods on the three tasks. We see that for all three tasks, the transformer-produced representation achieved substantially improved performance on both AUROC and AUPR. This confirms that our approach learns a more informative representation of microbiome samples.

However, when the fine-tuned models are applied to the independent Halfvarson and HMP2 datasets, using a held-out AGP validation set for stopping is observed to lead to poor and highly unstable results (shown by "Transformer (original)" in Table 3).

We address this problem by introducing a simple ensemble strategy. During fine-tuning, we train an ensemble of k classifiers using different random initializations of the classification head. Similar to standard practice when applying transformers to language [26], we found that each individual classifier only needs to be fine-tuned for a single epoch, i.e., one pass over all of the training data, and that training for more epochs often leads to overfitting. In our experiments, we used ensemble size k = 10.
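A minimal sketch of this ensemble fine-tuning, under the assumption (one reading of Fig. 2) that the pre-trained encoder is frozen and only the k randomly initialized CLS heads are trained; `encoder`, `train_loader`, and `d_model` are placeholders:

```python
# Sketch of the ensemble strategy: k classification heads, each randomly
# initialized and fine-tuned for a single epoch, with predicted probabilities
# averaged at test time.
import torch
import torch.nn as nn

def finetune_ensemble(encoder, train_loader, d_model, k=10, lr=1e-4):
    for p in encoder.parameters():
        p.requires_grad = False                  # assumed: encoder stays frozen
    heads = []
    for _ in range(k):
        head = nn.Linear(d_model, 2)             # fresh random initialization
        opt = torch.optim.Adam(head.parameters(), lr=lr)
        for x, y in train_loader:                 # a single epoch only; more
            opt.zero_grad()                       # epochs tended to overfit
            loss = nn.functional.cross_entropy(head(encoder(x)), y)
            loss.backward()
            opt.step()
        heads.append(head)
    return heads

def predict(encoder, heads, x):
    with torch.no_grad():
        probs = torch.stack([h(encoder(x)).softmax(-1) for h in heads])
    return probs.mean(0)                          # average ensemble probabilities
```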

We compare our ensemble performance with the baselines described above, and additionally strengthen the Weighted baseline of [13] by using an ensembled MLP classifier and reporting the best testing performance achieved by the Weighted baseline method during training. The baselines from [30] use random forest as the classifier and do not have a similar free parameter regarding their stopping condition.

We report the performance of all methods averaged across five random runs with different initializations in Table 3.

Fig 1. Workflow of using a transformer model for generating sample embeddings/classifications and context-sensitive taxa embeddings. The inputs (A), which are samples represented as relative abundance vectors, first go through the preprocessing step (B) to generate text-like inputs (C) for the transformer model (D). The transformer model generates a sample embedding (h_cls) that goes through a sample classification layer (E) to produce task-specific sample-level predictions (F). The transformer model also generates a context-sensitive embedding (G) for each taxon in the sample. The same taxon appearing in different samples can have different embeddings because of contextual differences.

Fig 2. Training of the transformer model. Unlabeled microbiome data (A) is fed into a randomly initialized transformer (B) as input to the self-supervised pre-training process (C), which produces a pre-trained transformer that generates token-level classifications (D). We replace the token-level classification head with a randomly initialized CLS classification head (E), and use labeled microbiome data (F) to fine-tune the CLS classification head (G), producing the fine-tuned transformer (H).

The ELECTRA pre-training approach has two steps. The first step trains a generator model by randomly masking out 15% of the taxa in microbiome samples and training the generator to predict the missing taxa based on the remainder of the sample. In the second step, we use the trained generator to produce perturbed samples, in which the masked taxa are replaced by the generator's predictions, and train a discriminator to distinguish the replaced taxa from the original ones (Fig. 3).

Fig 3. ELECTRA pre-training diagram. A generator is trained to predict the masked taxa in a sample. A discriminator is trained to differentiate taxa filled in by the generator from the original taxa in the sample. Both use the same transformer architecture and have token-level classification heads. The generator's token-level classification head predicts the taxa ID, whereas the discriminator's token-level classification head labels each input taxon as "Real" or "Modified".
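The two-step objective can be sketched as follows; the 15% masking rate follows the text, while the toy Embedding+Linear models stand in for the actual transformer generator and discriminator, and the argmax fill-in glosses over the sampling and gradient-detachment details of full ELECTRA:

```python
# Illustrative ELECTRA-style objective on taxa token sequences: mask 15% of
# taxa, train a generator to fill them in, and train a discriminator to flag
# which positions were replaced ("Real" vs "Modified").
import torch
import torch.nn as nn

vocab_size, d = 1000, 64
gen = nn.Sequential(nn.Embedding(vocab_size, d), nn.Linear(d, vocab_size))
disc = nn.Sequential(nn.Embedding(vocab_size, d), nn.Linear(d, 2))

def electra_step(tokens, mask_id=0, mask_rate=0.15):
    mask = torch.rand(tokens.shape) < mask_rate            # choose 15% of taxa
    masked = tokens.masked_fill(mask, mask_id)
    logits = gen(masked)                                    # step 1: generator
    gen_loss = nn.functional.cross_entropy(
        logits[mask], tokens[mask])                         # predict missing taxa
    filled = torch.where(mask, logits.argmax(-1), tokens)   # perturbed sample
    replaced = (filled != tokens).long()                    # 1 = "Modified"
    disc_logits = disc(filled)                              # step 2: discriminator
    disc_loss = nn.functional.cross_entropy(
        disc_logits.view(-1, 2), replaced.view(-1))
    return gen_loss, disc_loss

tokens = torch.randint(1, vocab_size, (8, 128))             # batch of taxa IDs
g_loss, d_loss = electra_step(tokens)
```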

In this section, we empirically compare transformer-produced sample representations against a variety of baseline methods. Our baselines include Weighted, a simple non-contextualized abundance-weighted averaging of the GloVe embeddings from [13], two classic dimension-reduction-based methods, and two deep learning based methods introduced by [30], each of which performs dimension reduction using the sample taxonomic abundance profiles as input features:

• PCA: Principal Component Analysis, configured to retain at least 99% of the variance (a minimal scikit-learn sketch follows below).
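The variance criterion for the PCA baseline can be expressed directly in scikit-learn; the abundance matrix here is a randomly generated placeholder:

```python
# PCA baseline: keep enough principal components to retain >= 99% of the
# variance in the taxonomic abundance profiles. X is a placeholder
# (samples x taxa) relative-abundance matrix.
import numpy as np
from sklearn.decomposition import PCA

X = np.random.dirichlet(np.ones(200), size=500)   # fake 500 samples, 200 taxa
pca = PCA(n_components=0.99)                      # float in (0,1) = variance kept
X_reduced = pca.fit_transform(X)
print(X_reduced.shape, pca.explained_variance_ratio_.sum())
```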

As mentioned previously, our method applies a standard multilayer perceptron (MLP) classifier to the transformer-produced sample representations for classification. To allow Weighted to act as a more consistent comparison with our model, we replaced the random forest classifier used in prior work with the same MLP classifier. We evaluate our method and the baseline methods using the AGP dataset on three microbiome classification tasks.

For each task, we perform evaluation using 5-fold cross-validation. For each cross-validation run, 80% of the available labeled data was used for training, and the remaining 20% was split between a validation and a test set. We use the validation set for hyperparameter tuning and for choosing the stopping point when training the MLP classifier, and use the test set to evaluate the resulting classifiers. We consider two evaluation criteria: the Area Under the ROC Curve (AUROC) and the Area Under the Precision-Recall curve (AUPR). We select these two metrics because they allow us to rigorously compare the discriminative capabilities of our models and baselines on unbalanced classes, without having to specify a particular threshold for what we consider a "positive" or "negative" classification.
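In scikit-learn terms, the protocol looks roughly like the sketch below; the embeddings, labels, and MLP hyperparameters are placeholders rather than the tuned values from the paper, and for brevity the held-out fold is scored directly instead of being split into validation and test halves:

```python
# Sketch of the 5-fold evaluation protocol with AUROC and AUPR.
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import roc_auc_score, average_precision_score

X = np.random.randn(600, 128)                # placeholder sample embeddings
y = np.random.randint(0, 2, 600)             # placeholder binary labels

aurocs, auprs = [], []
for tr, te in StratifiedKFold(5, shuffle=True, random_state=0).split(X, y):
    clf = MLPClassifier(hidden_layer_sizes=(64,), max_iter=300)
    clf.fit(X[tr], y[tr])
    p = clf.predict_proba(X[te])[:, 1]
    aurocs.append(roc_auc_score(y[te], p))
    auprs.append(average_precision_score(y[te], p))
print(f"AUROC {np.mean(aurocs):.3f}  AUPR {np.mean(auprs):.3f}")
```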

Therefore it is important to test how well our transformer-based prediction models generalize to independent datasets that come from different population/sample distributions. To test this, we applied our transformer model trained for the IBD prediction task on the AGP data to the Halfvarson and HMP2 datasets from independent studies.

An issue that arises when performing such cross-study tests is the need to decide on a stopping point during fine-tuning to pick the best model to use on the test data. In the previous single-study experiments, using a held-out validation set for this purpose proved to be an effective strategy. However, this strategy breaks down here because of the substantial distributional shift between the AGP data used for training/validation and the independent test sets of Halfvarson and HMP2.

The results show that our method consistently achieves better performance on the Halfvarson dataset than all baselines, and comparable performance on the HMP2 dataset to the best-performing Weighted baseline model selected using testing data. Although CAE Best and CAE Match achieve slightly higher HMP2 performance, this comes at the cost of an enormous deficit on Halfvarson. These results illustrate our approach's ability to consistently generalize well to out-of-distribution settings.

3.3 Context-sensitive taxa embeddings capture biologically meaningful information

In this section we take a closer look at the pre-trained language model to interpret the learned context-sensitive representations of microbial species. We hypothesize that the superior predictive performance of our model arises because our pre-trained language model transforms the input taxa embeddings into more biologically meaningful, context-sensitive representations.

Fig 4. t-SNE visualization of (a) the original taxa vocabulary embeddings and (b) the contextualized taxa embeddings. Both are colored by phylum. See Figure 9 for the embedding spaces colored by phylum, class, order, and family.

Fig 5. Mapping between the original vocabulary and contextualized embedding spaces. Panel a) shows how the contextualized embeddings can extract "threads" of a single phylum from the vocabulary embedding space, and map those taxa to tight clusters in the contextualized embeddings. Panel b) shows that the mapping to the contextual embedding space is able to more cleanly separate taxa by phylum. Panel c) contrasts with panel b) and shows that taxa which are very tightly clustered in the vocabulary embeddings may not map to meaningful clusters or phylum-level separation in the contextualized embedding space. Panel d) shows cluster purity versus K for K-means clustering in the vocabulary and contextualized embedding spaces, showing that the tighter clustering of the contextualized embedding space isn't simply an artifact of the t-SNE dimension reduction.
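Panel d)'s purity metric can be computed as in the sketch below, assuming the usual definition of cluster purity (the fraction of points whose phylum matches their cluster's majority phylum); the embeddings and phylum labels are random placeholders:

```python
# Sketch of the cluster-purity comparison in Fig. 5d: run K-means in each
# embedding space and score purity against phylum labels.
import numpy as np
from sklearn.cluster import KMeans

def purity(embeddings, phyla, k):
    labels = KMeans(n_clusters=k, n_init=10, random_state=0).fit_predict(embeddings)
    total = 0
    for c in range(k):
        members = phyla[labels == c]
        if len(members):
            total += np.bincount(members).max()   # size of majority phylum
    return total / len(phyla)

rng = np.random.default_rng(0)
vocab_emb = rng.normal(size=(5000, 100))          # placeholder embeddings
ctx_emb = rng.normal(size=(5000, 100))
phyla = rng.integers(0, 10, 5000)                 # placeholder phylum IDs
for k in (10, 20, 40):
    print(k, purity(vocab_emb, phyla, k), purity(ctx_emb, phyla, k))
```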
Figure 6 shows heatmaps for both sets of correlations. We can see that, although both embeddings show clear correlations with some metabolic pathways, the contextualized embedding dimensions capture stronger correlations, signified by the darker blue and red colors in the heatmap. We further compare the distributions of the statistically significant correlation magnitudes from both embeddings in Figure 7, which shows that the normalized histograms of the contextualized embedding dimensions are shifted to the right compared to those of the GloVe embedding dimensions.
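A sketch of this correlation analysis, assuming Spearman correlation and a 0.05 significance threshold (the excerpt states neither choice), with placeholder embedding and pathway matrices:

```python
# Sketch: correlate each embedding dimension with each metabolic pathway's
# abundance across samples, keeping only statistically significant entries.
import numpy as np
from scipy.stats import spearmanr

rng = np.random.default_rng(0)
emb = rng.normal(size=(300, 64))          # placeholder (samples x dims)
pathways = rng.normal(size=(300, 20))     # placeholder (samples x pathways)

corr = np.zeros((64, 20))
for i in range(64):
    for j in range(20):
        rho, p = spearmanr(emb[:, i], pathways[:, j])
        corr[i, j] = rho if p < 0.05 else 0.0     # assumed alpha = 0.05
# `corr` is what the heatmap (Fig. 6) would render; the nonzero magnitudes
# feed the histograms in Fig. 7.
```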

3.4 Understanding taxa importance for IBD prediction

In this part, we focus on the fine-tuned IBD ensemble prediction model to understand which taxa play critical roles in our model's IBD prediction by studying their attribution. We first consider the 5000 most frequent taxa shown in Figure 4 and compute for each taxon its average attribution toward the model's IBD prediction using the AGP IBD data, as described in Sec. 2.3.
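The paper's attribution method is defined in its Sec. 2.3, which is not reproduced here; as an illustration only, a common gradient-times-input attribution over the taxa embedding layer would look like the following, where `model.embed` and `model.classify` are an assumed decomposition of the fine-tuned model:

```python
# Illustrative taxon-level attribution via gradient x input on the embedding
# layer. The paper's exact attribution method may differ.
import torch

def taxa_attribution(model, token_ids, target_class=1):
    """token_ids: (1, seq_len) tensor of taxa IDs for one sample."""
    emb = model.embed(token_ids)              # (1, seq_len, d)
    emb.retain_grad()                         # keep grad on this non-leaf tensor
    logits = model.classify(emb)              # (1, num_classes)
    logits[0, target_class].backward()
    # gradient x input, summed over the embedding dim -> one score per taxon
    return (emb.grad * emb).sum(-1).squeeze(0)

# Averaging these per-taxon scores over all AGP IBD samples yields the
# average attributions discussed in this section.
```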

Fig 7. Distribution of the magnitude of statistically significant correlations between embedding dimensions and metabolic pathways, for both contextualized embeddings and the prior GloVe embeddings.

Fig 8. t-SNE visualization of the contextualized embeddings colored by attribution to IBD. Taxa associated with IBD are shown in a lighter color (yellow) and taxa associated with the no-disease state are in dark purple.

Fig 9. Vocabulary and contextualized embedding spaces colored by different levels of the phylogenetic hierarchy: phylum, class, order, and family.

"infrequent" (0-2) classes. Out of 6549 AGP examples containing vegetable frequency metadata, 5654 were labeled positive. Table 1 provides the summary statistics for the three classification tasks.

Table 1. Three classification tasks derived from the AGP data and metadata.
• CAE: the convolutional autoencoder from [30], configured to match the parameter count of our own model (7.07M) as closely as possible.

For the baselines from [30], we adapt that work's random forest classification layer (and the range of hyperparameters to consider).

Table 2. Average performance of all methods on the three AGP classification tasks.

One of the largest challenges in working with microbiome data is that there is large variance in the distributions and characteristics of data from study to study.

Table 3. Average performance (standard deviation) on independent IBD datasets. Weighted's standard deviation is close to zero and is therefore omitted.

Table 4. Top 10 taxa associated with negative (non-disease) IBD classification, ordered by attribution strength.

Table 5. Top 10 taxa associated with positive IBD classification, ordered by attribution strength.

We compared the top 10 ASV attributions to IBD and to the healthy cohort (20 ASVs total) found with our model against 284 marker taxa identified in the data repository for the human gut microbiota (DRHM) [38] across three projects (NCBI PRJEB7949 (95 entries), NCBI PRJNA368966 (32 entries), NCBI PRJNA385949 (157 entries)) comparing IBD and healthy controls (query request: gmrepo.humangut.info/phenotypes/comparisons/D006262/D015212). Due to the differences in technologies between the datasets, we compare the markers across studies at the genus level. In our study, seven ASVs were not resolved beyond the family level and are therefore excluded from this analysis. Further, two of our ASVs belonged to sub-clades of a genus; we considered them as belonging to the genus of the clade: specifically Prevotella 9 (considered Prevotella in this analysis) and Ruminococcus 1 (considered Ruminococcus in this analysis). Out of our 13 ASVs, four belong to the genera Prevotella, Paraprevotella, and Lachnoclostridium, which were also found to be consistently associated with the healthy cohort in DRHM.