## Abstract

**Motivation** Modality matching in single-cell omics data analysis—i.e., matching cells across data sets collected using different types of genomic assays—has become an important problem, because unifying perspectives across different technologies holds the promise of yielding biological and clinical discoveries. However, single-cell dataset sizes can now reach hundreds of thousands to millions of cells, a scale that remains out of reach for most multi-modal computational methods.

**Results** We propose LSMMD-MA, a large-scale Python implementation of the MMD-MA method for multimodal data integration. In LSMMD-MA we reformulate the MMD-MA optimization problem using linear algebra and solve it with KeOps, a CUDA framework for symbolic matrix computation in Python. We show that LSMMD-MA scales to a million cells in each modality, two orders of magnitude greater than existing implementations.

**Availability** LSMMD-MA is freely available at https://github.com/google-research/large_scale_mmdma

**Contact** lpapaxanthos{at}google.com

## 1 Introduction and background

Modality matching in single-cell genomics data analysis can enhance our understanding of the relationships between cellular modalities and help us resolve cell states. In this problem, single-cell measurements collected using two or more different types of assays are projected into a shared space or are otherwise matched across modalities, with the goal of achieving insights into the joint multi-modal dataset. Most existing multimodal models rely on learning cell representations in each modality in a joint low-dimensional space [1, 2, 3, 4, 5, 6, 7, 8]. MMD-MA [9, 10] is one such method that has shown promising results on datasets containing several thousand cells in each modality. However, thanks to new single-cell technologies, the size of single-cell datasets has increased significantly in the past two years, now reaching several hundred thousand to millions of cells [11, 12, 6]. These datasets cannot be analyzed by current implementations of MMD-MA due to memory issues.

More precisely, MMD-MA [9] is a multimodal approach that maps each cell in each modality to a shared, low-dimensional representation space. The linear mappings from the input spaces to the representation space are learned by minimizing an objective function composed of several terms: a) a *matching* term based on the squared maximum mean discrepancy (MMD) with a Gaussian radial basis function (RBF) kernel to ensure that the different modalities overlap in the representation space, b) two *non-collapsing* penalties to prevent trivial solutions, and c) two *distortion* penalties to ensure that as much information from the input data as possible is captured in the shared representation. More details about the method are provided in Appendix A. However, current implementations of MMD-MA [9, 10] scale quadratically in memory and runtime as a function of the number of cells, which is prohibitive for datasets with more than a few thousand samples (see Table A1).

To increase the scalability of MMD-MA, we introduce LSMMD-MA, a reformulation and PyTorch implementation of MMD-MA that overcomes the memory explosion issue. To achieve this, we 1) reformulate MMD-MA’s optimization problem in the primal, which is beneficial when the number of cells is larger than the number of informative features, and 2) implement the MMD matching term with the CUDA-based KeOps library for symbolic matrices [13], tailored to handle matrices that do not fit in RAM or GPU memory. The resulting algorithm scales only linearly in memory with the number of cells and can handle up to a million cells in each modality.

## 2 Methods

### Reformulating MMD-MA in the primal

Let *X* ∈ ℝ^{*n*_{x} × *p*_{x}} (resp. *Y* ∈ ℝ^{*n*_{y} × *p*_{y}}) be the data matrix of the first (resp. second) modality, where *n*_{x} (resp. *n*_{y}) is the number of cells in the first (resp. second) modality and *p*_{x} (resp. *p*_{y}) is the number of features in the first (resp. second) modality. The goal of MMD-MA is to learn two mappings from the input spaces ℝ^{*p*_{x}} and ℝ^{*p*_{y}} to a shared representation space ℝ^{d}. We focus specifically on linear mappings, as in the original publications [9, 10], where mappings are parameterized with dual variables *α*_{x} ∈ ℝ^{*n*_{x} × *d*} and *α*_{y} ∈ ℝ^{*n*_{y} × *d*} such that the embedding of the first (resp. second) modality is *XX*^{⊤}*α*_{x} (resp. *YY*^{⊤}*α*_{y}). Instead, we equivalently parameterize the mappings by primal variables *W*_{x} ∈ ℝ^{*p*_{x} × *d*} and *W*_{y} ∈ ℝ^{*p*_{y} × *d*}, such that the embedding of the first (resp. second) modality is *XW*_{x} (resp. *YW*_{y}). We can then rewrite the MMD-MA optimization problem in the primal:
min_{*W*_{x}, *W*_{y}} MMD²(*XW*_{x}, *YW*_{y}) + *λ*_{1}(pen(*W*_{x}) + pen(*W*_{y})) + *λ*_{2}(dis(*X*, *W*_{x}) + dis(*Y*, *W*_{y})),  (1)

with pen(*W*) = ||*W*^{⊤}*W* − **I**_{d}||_{2} and dis(*X*, *W*) = ||*XX*^{⊤} − *XWW*^{⊤}*X*^{⊤}||_{2}, where *λ*_{1} and *λ*_{2} are hyperparameters. Under the assumption that *n* ≫ *p* ≫ *d* for each modality, an efficient implementation of the primal loss (1) scales better than an implementation of the dual loss, as shown in Table A1. The primal loss does not require the computation and storage of the linear kernel matrices *XX*^{⊤} and *YY*^{⊤}, which are O(*n*^{2}) in time and memory, and the penalty and distortion terms are no longer O(*n*^{2}) in runtime. However, computing the MMD term remains O(*n*^{2}) in runtime and memory if we implement it naively.
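To see why the primal distortion term avoids the O(*n*^{2}) cost, note that the squared Frobenius norm can be expanded in terms of the *p* × *p* matrix *C* = *X*^{⊤}*X*, so the *n* × *n* Gram matrix is never formed. A minimal NumPy sketch of this identity (the function names here are ours, for illustration only, not from the LSMMD-MA package):

```python
import numpy as np

def distortion_sq_naive(X, W):
    # O(n^2) memory: explicitly forms the n x n linear kernel matrix.
    K = X @ X.T
    E = X @ W                      # n x d embedding
    return np.linalg.norm(K - E @ E.T, "fro") ** 2

def distortion_sq_primal(X, W):
    # O(n p^2) time, O(p^2) extra memory: expand ||XX^T - XWW^TX^T||_F^2
    # into terms that only involve the p x p covariance C = X^T X.
    C = X.T @ X                    # p x p
    WtCW = W.T @ C @ W             # d x d
    return (np.linalg.norm(C, "fro") ** 2
            - 2.0 * np.trace(W.T @ C @ C @ W)
            + np.linalg.norm(WtCW, "fro") ** 2)
```

Both functions compute the same quantity, but the second one only touches matrices of size *p* × *p* or smaller once *C* is accumulated, which is the regime *n* ≫ *p* exploited by the primal formulation.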

### Using KeOps

To overcome the O(*n*^{2}) memory burden of computing the MMD term, we implement it using the CUDA-based Map-Reduce scheme of KeOps [13]. This allows us to compute the MMD term without instantiating the *n* × *n* Gaussian RBF kernel in memory, using symbolic matrix computation with O(*n*) memory complexity. KeOps therefore optimizes (1) with a linear memory complexity and also improves runtime by a significant multiplicative factor when the number of samples is larger than 1000.
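The Map-Reduce idea can be illustrated on CPU with plain NumPy: the Gaussian kernel sums are accumulated block by block, so only one block of pairwise distances is ever materialized. This is a schematic sketch of the principle only; the actual package performs this reduction on GPU with KeOps symbolic LazyTensors:

```python
import numpy as np

def mmd_sq_rbf_chunked(X, Y, sigma=1.0, block=256):
    """Squared MMD with a Gaussian RBF kernel, accumulated block by block
    so the full n x n kernel matrix is never stored (O(block * n) memory).
    Illustrates the Map-Reduce scheme that KeOps runs on GPU."""
    def mean_kernel(A, B):
        total = 0.0
        for i in range(0, A.shape[0], block):
            a = A[i:i + block]
            # pairwise squared distances between one block of A and all of B
            d2 = ((a[:, None, :] - B[None, :, :]) ** 2).sum(-1)
            total += np.exp(-d2 / (2.0 * sigma ** 2)).sum()
        return total / (A.shape[0] * B.shape[0])
    return mean_kernel(X, X) + mean_kernel(Y, Y) - 2.0 * mean_kernel(X, Y)
```

With KeOps, the inner loop is replaced by a symbolic reduction that is compiled to a single CUDA kernel, which is where the additional runtime gains reported above come from.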

## 3 Implementation

We make `LS-MMDMA` available together with three variants: primal formulation without KeOps, dual formulation with KeOps, and dual formulation without KeOps (an efficient implementation of the original algorithm). The code is implemented in PyTorch [14] and can run on CPU or GPU. The package is open source with an Apache license, available at github.com/google-research/large_scale_mmdma. It is referenced on PyPI and can be installed with the command: `pip install lsmmdma`. Details about I/O, command line instructions and tutorials are given in the `Readme.md` and in the `examples` folder.

## 4 Results and conclusion

We tested the implementation of LSMMD-MA (primal formulation with KeOps) against three comparison partners (primal formulation without KeOps, dual formulation with KeOps, and dual formulation without KeOps) and against the two original implementations [9, 10], which use the dual formulation in TensorFlow [15] and PyTorch [14], respectively. Additionally, [10] proposes to use the linear-time approximation of MMD for large numbers of samples (*>* 5000) (see Lemma 14 in [16]). We ran all algorithms on simulated datasets of different sizes where the latent space is shaped as a branch (see lsmmdma/data/data_pipeline.py on GitHub and Appendix D for more details). All algorithms were run for 500 epochs with a low-dimensional representation of dimension *d* = 10. A V100 GPU (16GB) was used for the experiments.

We observe that LSMMD-MA, using the primal formulation and KeOps, scales to one million cells in each modality, whereas the original implementations run out of memory for more than 14,000 cells (Figure 1). We also notice that our dual implementations are faster than the original implementations irrespective of the number of samples. As a proof-of-principle, we also ran LSMMD-MA on a real-world CITE-seq dataset containing 90,261 human bone marrow mononuclear cells, with 13,953 gene IDs for the gene expression modality and 134 proteins for the protein marker modality [17]. This dataset required 68.4 minutes to run 1000 epochs, which would have been infeasible with previous versions of MMD-MA.

These results suggest that an optimised implementation, exploiting the primal formulation and taking advantage of the KeOps library, is key to building a multimodal model that scales to the size of current single-cell datasets.

## Appendix

### A Description of `MMD-MA` with the dual formulation

`MMD-MA` [9] is a method for analyzing multimodal data that relies on mapping the observed cell samples to embeddings, using functions belonging to a reproducing kernel Hilbert space (RKHS). The authors build linear kernels for both domains and learn the coefficients of the embedding functions in the RKHS dual representation. In practice, the `MMD-MA` loss is composed of a) the squared MMD, a *matching* term, b) two *non-collapsing* penalties and c) two *distortion* penalties, one for each modality.

Let *X* ∈ ℝ^{*n*_{x} × *p*_{x}} be the data matrix of the first modality and *Y* ∈ ℝ^{*n*_{y} × *p*_{y}} the one of the second modality. *n*_{x} (resp. *n*_{y}) is the number of cells in the first (resp. second) modality and *p*_{x} (resp. *p*_{y}) is the number of features in the first (resp. second) modality. We denote by *K*_{x} = *XX*^{⊤} and *K*_{y} = *YY*^{⊤} the linear kernel matrices corresponding to both datasets. *α*_{x} ∈ ℝ^{*n*_{x} × *d*} and *α*_{y} ∈ ℝ^{*n*_{y} × *d*} are the learned coefficients of the RKHS functions used to build the embeddings *K*_{x}*α*_{x} and *K*_{y}*α*_{y}, of dimension *d*. The loss of `MMD-MA` can be written as follows:
min_{*α*_{x}, *α*_{y}} MMD²(*K*_{x}*α*_{x}, *K*_{y}*α*_{y}) + *λ*_{1}(pen(*K*_{x}, *α*_{x}) + pen(*K*_{y}, *α*_{y})) + *λ*_{2}(dis(*K*_{x}, *α*_{x}) + dis(*K*_{y}, *α*_{y})),

where *λ*_{1} and *λ*_{2} are two hyperparameters weighing the non-collapsing penalty (pen(*K*, *α*) = ||*α*^{⊤}*Kα* − **I**_{d}||_{2}) and the distortion penalty, which ensures that the geometries in the low- and high-dimensional representations are comparable (dis(*K*, *α*) = ||*K* − *Kαα*^{⊤}*K*||_{2}). The RBF kernel is used to calculate MMD, introducing non-linearity in the matching term. However, the linear and Gaussian kernel matrices require memory and runtime that scale quadratically as a function of the number of cells, which is prohibitive for large datasets.
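The equivalence between the dual and primal parameterizations used in Section 2 follows from *K*_{x}*α*_{x} = *XX*^{⊤}*α*_{x} = *X*(*X*^{⊤}*α*_{x}) = *XW*_{x} with *W*_{x} = *X*^{⊤}*α*_{x}. A short NumPy check of this identity (variable names are ours, for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
n, p, d = 40, 6, 3
X = rng.standard_normal((n, p))
alpha = rng.standard_normal((n, d))    # dual variables, n x d

K = X @ X.T                             # n x n linear kernel matrix
W = X.T @ alpha                         # equivalent primal variables, p x d

dual_embedding = K @ alpha              # n x d, requires the n x n kernel
primal_embedding = X @ W                # n x d, never forms the n x n kernel
assert np.allclose(dual_embedding, primal_embedding)
```

The primal route computes the same embedding with O(*npd*) operations instead of the O(*n*^{2}*p* + *n*^{2}*d*) needed to build and apply the kernel matrix.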

### B Runtime and memory requirements

### C Benchmark

### D Datasets

The synthetic datasets used for the runtime experiments have a varying number of samples from 10^{3} to 10^{6} and a varying number of features from 10 to 10^{4}. To generate them, we used the `generate_data` function in lsmmdma/data/data_pipeline.py, with the `random_seed` argument fixed to 4 and the `simulation` argument set to ‘branch’. The simulation process works as follows. A manifold (of shape ‘branch’) is generated in two dimensions. The resulting set of points is standardised. The points are then mapped to `p_feature` dimensions using a (2 x `p_feature`) mapping, sampled from a standard Gaussian distribution, resulting in a (`n_sample` x `p_feature`) matrix. Gaussian noise is then added to each element of the matrix. This simulation was first described in [9], and the latent space shaped as a branch aims to mimic, for example, a developmental process.
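The simulation steps above can be sketched schematically in NumPy. This is not the package's `generate_data` code; the branch shape and noise level below are our own illustrative choices:

```python
import numpy as np

def simulate_branch(n_sample, p_feature, noise=0.1, seed=4):
    """Schematic version of the simulation described above: a 2-D
    branch-shaped manifold, standardised, mapped to p_feature dimensions
    by a random Gaussian linear map, with additive Gaussian noise."""
    rng = np.random.default_rng(seed)
    # 2-D branch-shaped manifold: a main curve plus a diverging branch
    t = rng.uniform(0.0, 1.0, size=n_sample)
    branch = rng.integers(0, 2, size=n_sample)
    latent = np.stack([t, np.where(branch == 1, t ** 2, -t ** 2)], axis=1)
    # standardise the 2-D points
    latent = (latent - latent.mean(0)) / latent.std(0)
    # random (2 x p_feature) Gaussian mapping to the observed feature space
    M = rng.standard_normal((2, p_feature))
    data = latent @ M + noise * rng.standard_normal((n_sample, p_feature))
    return data, latent

data, latent = simulate_branch(n_sample=500, p_feature=100)
```

The resulting (`n_sample` x `p_feature`) matrix plays the role of one modality; the two modalities share the same 2-D latent points but use different random mappings and noise draws.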

The real-world dataset on which the last experiment was run comes from [17], where datasets were made publicly available for the NeurIPS Competition on Multimodal Single-Cell Data Integration.