Scalable Spike Source Localization in Extracellular Recordings using Amortized Variational Inference

Determining the positions of neurons in an extracellular recording is useful for investigating functional properties of the underlying neural circuitry. In this work, we present a Bayesian modelling approach for localizing the source of individual spikes on high-density, microelectrode arrays. To allow for scalable inference, we implement our model as a variational autoencoder and perform amortized variational inference. We evaluate our method on both biophysically realistic simulated and real extracellular datasets, demonstrating that it is more accurate than and can improve spike sorting performance over heuristic localization methods such as center of mass.


Introduction
Extracellular recordings, which measure local potential changes due to ionic currents flowing through cell membranes, are an essential source of data in experimental and clinical neuroscience. The most prominent signals in these recordings originate from action potentials (spikes), the all or none events neurons produce in response to inputs and transmit as outputs to other neurons. Traditionally, a small number of electrodes (channels) are used to monitor spiking activity from a few neurons simultaneously. Recent progress in microfabrication now allows for extracellular recordings from thousands of neurons using microelectrode arrays (MEAs), which have thousands of closely spaced electrodes [13,2,14,1,36,56,32,25,12]. These recordings provide insights that cannot be obtained by pooling multiple single-electrode recordings [27]. This is a significant development as it enables systematic investigations of large circuits of neurons to better understand their function and structure, as well as how they are affected by injury, disease, and pharmacological interventions [20].
On dense MEAs, each recording channel may record spikes from multiple, nearby neurons, while each neuron may leave an extracellular footprint on multiple channels. Inferring the spiking activity of individual neurons, a task called spike sorting, is therefore a challenging blind source separation problem, complicated by the large volume of recorded data [46]. Despite the challenges presented by spike sorting large-scale recordings, its importance cannot be overstated as it has been shown that isolating the activity of individual neurons is essential to understanding brain function [35]. Recent efforts have concentrated on providing scalable spike sorting algorithms for large scale MEAs and already several methods can be used for recordings taken from hundreds to thousands of channels [42,31,10,55,22,26]. However, scalability, and in particular automation, of spike sorting pipelines remains challenging [8].
One strategy for spike sorting on dense MEAs is to spatially localize detected spikes before clustering. In theory, spikes from the same neuron should be localized to the same region of the recording area (near the cell body of the firing neuron), providing discriminatory, low-dimensional features for each spike that can be utilized with efficient density-based clustering algorithms to sort large data sets with tens of millions of detected spikes [22,26]. These location estimates, while useful for spike sorting, can also be exploited in downstream analyses, for instance to register recorded neurons with anatomical information or to identify the same units from trial to trial [9,22,41].
Despite the potential benefits of localization, preexisting methods have a number of limitations. First, most methods are designed for low-channel count recording devices, making them difficult to use with dense MEAs [9,52,3,30,29,34,33,51]. Second, current methods for dense MEAs utilize cleaned extracellular action potentials (through spike-triggered averaging), disallowing their use before spike sorting [48,6]. Third, all current model-based methods, to our knowledge, are non-Bayesian, relying primarily on numerical optimization methods to infer the underlying parameters. Given these current limitations, the only localization methods used consistently before spike sorting are simple heuristics such as a center of mass calculation [38,44,22,26].
In this paper, we present a scalable Bayesian modelling approach for spike localization on dense MEAs (less than ∼50µm between channels) that can be performed before spike sorting. Our method consists of a generative model, a data augmentation scheme, and an amortized variational inference method implemented with a variational autoencoder (VAE) [11,28,47]. Amortized variational inference has been used in neuroscience for applications such as predicting action potentials from calcium imaging data [53] and recovering latent dynamics from single-trial neural spiking data [43]; to our knowledge, however, it has not been applied to extracellular recordings.
After training, our method allows for localization of one million spikes (from high-density MEAs) in approximately 37 seconds on a TITAN X GPU, enabling real-time analysis of massive extracellular datasets. To evaluate our method, we use biophysically realistic simulated data, demonstrating that our localization performance is significantly better than the center of mass baseline and can lead to higher-accuracy spike sorting results across multiple probe geometries and noise levels. We also show that our trained VAE can generalize to recordings on which it was not trained. To demonstrate the applicability of our method to real data, we assess our method qualitatively on real extracellular datasets from a Neuropixels [25] probe and from a BioCam4096 recording platform.
To clarify, our contribution is not a full spike sorting solution. Although we envision that our method can be used to improve spike sorting algorithms that currently rely on center of mass location estimates, interfacing with and evaluating these algorithms was beyond the scope of our paper.

Spike localization
We start by introducing the relevant notation. Let $s := \{s_n\}_{n=1}^{N}$ be the set of $N$ spikes that are detected in an extracellular recording. For each spike, we define the source location of the spike to be $p_{s_n} := (x_{s_n}, y_{s_n}, z_{s_n}) \in \mathbb{R}^3$. We further denote $p_c := \{p_{c_m}\}_{m=1}^{M}$ to be the positions of all $M$ channels on the MEA. All positions are relative to the origin, which we set to be the center of the MEA. When a spike occurs during an extracellular recording, we assume there is a stereotypical spatiotemporal pattern that is recorded on the channels in the array, i.e., the recorded extracellular waveforms. The recorded extracellular waveform of a spike $s_n$ on a channel $c_m$ can then be defined as $w_{n,m} := \{r_{n,m}^{(t)}\}_{t=0}^{T}$, where $r_{n,m}^{(t)} \in \mathbb{R}$ and $t = 0, \ldots, T$. Localizing a spike can now be defined as follows: localizing a spike $s_n$ is equivalent to solving for the corresponding point source location $p_{s_n}$ given the observed waveforms $w_n$ and the channel positions $p_c$.
Figure 1: (a) For each spike $s_n$, there are four latent variables that give rise to the observed waveforms $w_n$. These are the initial amplitude $a_n$ and the $x$, $y$, $z$ values of the source location $p_{s_n}$. Each spike is localized independently as indicated by the plate diagram. (b) Our model is implemented as a variational autoencoder with a dense encoder and an exponential function with a Gaussian observation model. The latent space consists of the three location variables $x_n$, $y_n$, $z_n$. The initial amplitude $a_n$ is inferred using a maximum likelihood estimate to improve stability.
We perform source localization independently for each spike. Crucially, we make the assumption that the point source location $p_{s_n}$ of a spike is approximately the location of the firing neuron's soma. This is a useful assumption as it allows us to localize the underlying neurons' positions without spike sorting. We discuss limitations of this modelling assumption in the Discussion section.

Center of mass
Many modern spike sorting algorithms localize spikes on MEAs using the center of mass or barycenter method [44,22,26]. We summarize the traditional steps for localizing a spike $s_n$ using this method. First, let us define $\alpha_{n,m} := \min_t w_{n,m}$ as the negative amplitude peak of the waveform $w_{n,m}$ generated by $s_n$ and recorded on channel $c_m$. We consider the negative peak amplitude as a matter of convention since spikes are defined as inward currents. Then, let $\alpha_n := (\alpha_{n,m})_{m=1}^{M}$ be the vector of all amplitudes generated by $s_n$ and recorded by all $M$ channels on the MEA.
To find the center of mass of a spike $s_n$, the first step is to determine the central channel for the calculation. This central channel is set to be the channel which records the minimum amplitude for the spike,
$$c_{m_{\min}} := c_{\operatorname{argmin}_m \alpha_{n,m}}.$$
The second and final step is to take the $L$ closest channels to $c_{m_{\min}}$ and compute
$$\hat{x}_{s_n} = \frac{\sum_{l=1}^{L+1} x_{c_l} \, |\alpha_{n,l}|}{\sum_{l=1}^{L+1} |\alpha_{n,l}|},$$
where the sums run over the central channel and its $L$ closest channels; the estimate $\hat{y}_{s_n}$ is computed analogously from the $y$ coordinates of the channels. The center of mass method is inexpensive to compute and has been shown to give informative location estimates for spikes in both real and synthetic data [44,37,22,26]. Center of mass, however, suffers from two main drawbacks. First, since the estimate is a weighted average of the chosen channels' positions, it must lie inside their convex hull, negatively impacting location estimates for neurons outside of the MEA. Second, center of mass is biased towards the chosen central channel, potentially leading to artificial separation of location estimates for spikes from the same neuron [44].
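The two steps of the center of mass calculation can be sketched in a few lines of NumPy. This is a minimal illustration rather than the implementation used by any particular spike sorter; the function name `center_of_mass`, its arguments, and the toy 3x3 grid are assumptions made for the example.

```python
import numpy as np

def center_of_mass(channel_xy, amplitudes, n_neighbors):
    """Amplitude-weighted mean of channel positions for one spike.

    channel_xy  : (M, 2) array of channel positions
    amplitudes  : (M,) array of negative peak amplitudes alpha_{n,m}
    n_neighbors : L, the number of channels closest to the central channel
    """
    channel_xy = np.asarray(channel_xy, dtype=float)
    amplitudes = np.asarray(amplitudes, dtype=float)
    # Step 1: the central channel records the minimum (most negative) amplitude.
    central = int(np.argmin(amplitudes))
    # Step 2: keep the central channel plus its L nearest channels (L + 1 in total).
    dists = np.linalg.norm(channel_xy - channel_xy[central], axis=1)
    chosen = np.argsort(dists, kind="stable")[: n_neighbors + 1]
    weights = np.abs(amplitudes[chosen])
    return weights @ channel_xy[chosen] / weights.sum()

# Toy example: 3x3 grid with 15 um pitch and a spike that peaks on the middle channel.
xs, ys = np.meshgrid(np.arange(3) * 15.0, np.arange(3) * 15.0)
grid = np.column_stack([xs.ravel(), ys.ravel()])
amps = -np.array([10, 20, 10, 20, 80, 20, 10, 20, 10], dtype=float)
estimate = center_of_mass(grid, amps, n_neighbors=8)
```

Because the weights are non-negative and sum to one after normalization, the estimate can never leave the convex hull of the chosen channels, which is the first drawback noted above.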

Method
In this section, we introduce our scalable, model-based approach to spike localization. We describe the generative model, the data augmentation procedure, and the inference methods.

Model
Our model uses the recorded amplitudes on each channel to determine the most likely source location of $s_n$. We assume that the peak signal from a spike decays exponentially with the distance from the source, $r$: $a \exp(br)$, where $a, b \in \mathbb{R}$ and $r \in \mathbb{R}^{+}$. This assumption is well-motivated by experimentally recorded extracellular potential decay in both salamander and mouse retina [50,22], as well as cat cortex [16]. It has also been corroborated using realistic biophysical simulations [18].
We utilize this exponential assumption to infer the source location of a spike $s_n$, since localization is then equivalent to solving for the unknown parameters of $s_n$, $\theta_{s_n} := \{a_n, b_n, x_{s_n}, y_{s_n}, z_{s_n}\}$, given the observed amplitudes $\alpha_n$. To allow for localization without knowing the identity of the firing neuron, we assume that each spike has individual exponential decay parameters, $a_n$ and $b_n$, and an individual source location, $p_{s_n}$. We find, however, that fixing $b_n$ for all spikes to a constant that is equal to an empirical estimate from the literature (decay length of ∼28µm) works best across multiple probe geometries and noise levels, so we did not infer the value of $b_n$ in our final method. We will refer to the fixed decay rate as $b$ and exclude it from the unknown parameters moving forward.
The generative process of our exponential model is as follows:
$$a_n \sim \mathcal{N}(\mu_{a_n}, \sigma_a), \quad x_{s_n} \sim \mathcal{N}(\mu_{x_{s_n}}, \sigma_x), \quad y_{s_n} \sim \mathcal{N}(\mu_{y_{s_n}}, \sigma_y), \quad z_{s_n} \sim \mathcal{N}(\mu_{z_{s_n}}, \sigma_z),$$
$$\hat{r}_n = \lVert (x_{s_n}, y_{s_n}, z_{s_n}) - p_c \rVert_2, \qquad \alpha_n \sim \mathcal{N}\left(a_n \exp(b \hat{r}_n), I\right),$$
where $\hat{r}_n$ is the vector of Euclidean distances from the source location to each channel. In our observation model, the amplitudes are drawn from an isotropic Gaussian distribution with a variance of one. We chose this Gaussian observation model for computational simplicity and because a fixed variance is convenient when using VAEs: learning the variance of the observation model in a VAE can be numerically unstable [49]. We discuss the limitations of our modeling assumptions in Section 5 and propose several extensions for future work.
For our prior distributions, we were careful to set sensible parameter values. We found that inference, especially for a spike detected near the edge of the MEA, is sensitive to the mean of the prior distribution of $a_n$; therefore, we set $\mu_{a_n} = \lambda \alpha_{n,m_{\min}}$, where $\alpha_{n,m_{\min}}$ is the minimum (most negative) amplitude peak of $s_n$. We choose this heuristic because the absolute value of $\alpha_{n,m_{\min}}$ will always be smaller than the absolute value of the amplitude of the spike at the source location, due to potential decay. Therefore, scaling $\alpha_{n,m_{\min}}$ by $\lambda$ gives a sensible value for $\mu_{a_n}$. We empirically choose $\lambda = 2$ for the final method after performing a grid search over $\lambda = \{1, 2, 3\}$. The parameter $\sigma_a$ does not have a large effect on the inferred location, so we set it to be approximately the standard deviation of the $\alpha_{n,m_{\min}}$ (50). The location prior means, $\mu_{x_{s_n}}, \mu_{y_{s_n}}, \mu_{z_{s_n}}$, are set to the location of the minimum amplitude channel, $p_{c_{m_{\min}}}$, for the given spike. The location prior standard deviations, $\sigma_x, \sigma_y, \sigma_z$, are set to large constant values to flatten out the distributions since we do not want the location estimate to be overly biased towards $p_{c_{m_{\min}}}$.
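To make the generative process concrete, the following sketch draws one synthetic spike from the exponential-decay model with NumPy. It is an illustration under stated assumptions, not the authors' code: the function name `sample_amplitudes`, the toy 3x3 channel layout, the peak amplitude of -50, and the conversion of the ∼28µm decay length into $b = -1/28$ are all choices made for the example.

```python
import numpy as np

def sample_amplitudes(rng, channel_pos, mu_a, sigma_a, mu_loc, sigma_loc, b):
    """Draw one spike: amplitude prior, location prior, then noisy channel amplitudes."""
    a = rng.normal(mu_a, sigma_a)                     # initial amplitude a_n
    loc = rng.normal(mu_loc, sigma_loc)               # source location (x, y, z)
    r = np.linalg.norm(channel_pos - loc, axis=1)     # distance from source to each channel
    mean = a * np.exp(b * r)                          # exponential decay of the peak amplitude
    alpha = rng.normal(mean, 1.0)                     # unit-variance Gaussian observation model
    return a, loc, alpha

rng = np.random.default_rng(0)
xs, ys = np.meshgrid(np.arange(3) * 15.0, np.arange(3) * 15.0)
channels = np.column_stack([xs.ravel(), ys.ravel(), np.zeros(9)])  # MEA placed in the z = 0 plane
peak_amplitude = -50.0                                # hypothetical minimum peak amplitude
a, loc, alpha = sample_amplitudes(
    rng, channels,
    mu_a=2.0 * peak_amplitude,    # lambda = 2 times the peak amplitude
    sigma_a=50.0,
    mu_loc=channels[4],           # prior centered on the (assumed) minimum-amplitude channel
    sigma_loc=np.sqrt(80.0),      # assumed wide location prior
    b=-1.0 / 28.0,                # assumed decay rate for a ~28 um decay length
)
```

Sampling in this direction is only a sanity check on the model; localization runs the process in reverse, inferring `loc` from `alpha`.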

Data Augmentation
For localization to work well, the input channels should be centered around the channel with the peak amplitude, which is difficult for spikes near the edges of the MEA (edge spikes). To address this issue, we employ a two-step data augmentation. First, inputs for edge spikes are padded such that the channel with the largest amplitude is at the center of the inputs. Second, all channels are augmented with an indicator variable that allows the inference network to distinguish real channels from padded ones. To be more specific, we introduce virtual channels outside of the MEA which have the same layout as the real, recording channels (see Appendix C). We refer to a virtual channel as an "unobserved" channel, $c_{m_u}$, and to a real channel on the MEA as an "observed" channel, $c_{m_o}$. We define the amplitude on an unobserved channel, $\alpha_{n,m_u}$, to be zero since unobserved channels do not actually record any signals. We let the amplitude for an observed channel, $\alpha_{n,m_o}$, be equal to $\min_t w_{n,m_o}$, as before.
Before defining the augmented dataset, we must first introduce an indicator function, $1_o : \alpha \to \{0, 1\}$:
$$1_o(\alpha) = \begin{cases} 1, & \text{if } \alpha \text{ is from an observed channel,} \\ 0, & \text{if } \alpha \text{ is from an unobserved channel,} \end{cases}$$
where $\alpha$ is an amplitude from any channel, observed or unobserved.
To construct the augmented dataset for a spike $s_n$, we take the set of $L$ channels that lie within a bounding box of width $W$ centered on the observed channel with the minimum recorded amplitude, $c_{m_{o_{\min}}}$. We define our newly augmented observed data for $s_n$ as
$$\beta_n := \{(\alpha_{n,l}, 1_o(\alpha_{n,l}))\}_{l=1}^{L}.$$
So, for a single spike, we construct an $L \times 2$ array that contains the amplitudes from $L$ channels and indicators of whether each amplitude came from an observed or an unobserved channel.
Since the prior location for each spike is at the center of the subset of channels used for the observed data, the data augmentation puts the prior for edge spikes closer to the edge and is, therefore, more informative for localizing spikes near or off the edge of the array. Also, since edge spikes are typically seen on fewer channels, the data augmentation serves to ignore channels that are far from the spike, which would otherwise be used if the augmentation were not employed.
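The augmentation can be sketched for a square grid as follows. This is an illustrative reconstruction, not the authors' code: the function name `augment_spike` is hypothetical, and the bounding box is parameterized by its half-width, an assumption chosen so that a value of 20 on a 15 µm grid yields the 4-9 observed channels quoted in the Evaluation section.

```python
import numpy as np

def augment_spike(channel_xy, amplitudes, pitch, half_width):
    """Return grid sites around the peak channel and the L x 2 array (amplitude, indicator)."""
    channel_xy = np.asarray(channel_xy, dtype=float)
    amplitudes = np.asarray(amplitudes, dtype=float)
    # Center the bounding box on the observed channel with the minimum amplitude.
    center = channel_xy[np.argmin(amplitudes)]
    k = int(np.floor(half_width / pitch))
    steps = np.arange(-k, k + 1) * pitch
    dx, dy = np.meshgrid(steps, steps)
    sites = center + np.column_stack([dx.ravel(), dy.ravel()])
    beta = np.zeros((len(sites), 2))  # virtual channels keep amplitude 0 and indicator 0
    for i, site in enumerate(sites):
        match = np.flatnonzero(np.all(np.isclose(channel_xy, site), axis=1))
        if match.size:  # the site coincides with a real (observed) channel
            beta[i] = (amplitudes[match[0]], 1.0)
    return sites, beta

# 10x10 square MEA with 15 um pitch; the spike peaks on the corner channel at (0, 0).
pitch = 15.0
xs, ys = np.meshgrid(np.arange(10) * pitch, np.arange(10) * pitch)
mea_xy = np.column_stack([xs.ravel(), ys.ravel()])
amps = -5.0 * np.ones(len(mea_xy))
amps[0] = -60.0
sites, beta = augment_spike(mea_xy, amps, pitch=pitch, half_width=20.0)
```

For this corner spike, five of the nine sites fall outside the array and are filled with zeros, so the peak channel still sits at the center of the input.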

Inference
Now that we have defined the generative process and data augmentation procedure, we would like to compute the posterior distribution for the unknown parameters of a spike $s_n$,
$$p(a_n, x_{s_n}, y_{s_n}, z_{s_n} \mid \beta_n) \qquad (3)$$
given the augmented dataset, $\beta_n$. To infer the posterior distribution for each spike, we utilize two methods of Bayesian inference: MCMC sampling and amortized variational inference.

MCMC sampling
We use MCMC to assess the validity and applicability of our model to extracellular data. We implement our model in Turing [15], a probabilistic modeling language in Julia. We run Hamiltonian Monte Carlo (HMC) [39] for 10,000 iterations with a step size of 0.01 and a step number of 10. We use the posterior means of the location distributions as the estimated location. Despite the ease of use of probabilistic programming and the asymptotically guaranteed inference quality of MCMC methods, the scalability of MCMC methods to large-scale datasets is limited. This leads us to implement our model as a VAE and to perform amortized variational inference for our final method.

Amortized variational inference
To speed up inference of the spike parameters, we construct a VAE and use amortized variational inference to estimate posterior distributions for each spike. In variational inference, instead of sampling from the intractable target posterior distribution, we construct a variational distribution that is tractable and minimize the Kullback-Leibler (KL) divergence between the variational posterior and the true posterior. Minimizing the KL divergence is equivalent to maximizing the evidence lower bound (ELBO) for the log marginal likelihood of the data. In VAEs, the parameters of the variational posterior are not optimized directly, but are, instead, computed by an inference network. When training the VAE, we found that inference of the initial amplitude $a_n$ is quite sensitive, especially for a spike detected near the edge of the MEA. To improve stability, we use a maximum likelihood estimate for the mean of the initial amplitude, $\mu_{a_n}$, with a fixed variance $\sigma_a$. Therefore, we only define our variational posterior for the source location, $x_n, y_n, z_n$.
We define our variational posterior for $x_n, y_n, z_n$ as a multivariate Normal with diagonal covariance where the mean and diagonal of the covariance matrix are computed by an inference network,
$$q_\Phi(x_n, y_n, z_n) = \mathcal{N}\left(\boldsymbol{\mu}_{\phi_1}(f_{\phi_0}(\upsilon_n)), \boldsymbol{\sigma}^2_{\phi_2}(f_{\phi_0}(\upsilon_n))\right).$$
The inference network is implemented as a feed-forward, deep neural network parameterized by $\Phi = \{\phi_0, \phi_1, \phi_2\}$. As one can see, the variational parameters are a function of the input $\upsilon_n$.
When using an inference network, the input can be any part of the dataset, so for our method we use $\upsilon_n$ as the input for each spike $s_n$, which is defined as follows:
$$\upsilon_n := \{(w_{n,l}, 1_o(\alpha_{n,l}))\}_{l=1}^{L},$$
where $w_{n,l}$ is the waveform detected on the $l$th channel (defined in Section 2.1). Similar to our previous augmentation, the waveform for an unobserved channel is set to be all zeros. We choose to input the waveforms rather than the amplitudes because, empirically, it encourages the inferred location estimates for spikes from the same neuron to be better localized to the same region of the MEA. For both the real and simulated datasets, we used ∼2 ms of readings for each waveform.
The decoder for our method reconstructs the amplitudes from the observed data rather than the waveforms. Since we assume an exponential decay for the amplitudes, the decoder is a simple Gaussian likelihood function: given the Euclidean distance vector $\hat{r}_n$, computed from samples of the variational posterior, the decoder reconstructs the mean value of the observed amplitudes with a fixed variance. The decoder is parameterized by the exponential parameters of the given spike $s_n$, so it reconstructs the amplitudes of the augmented data, $\beta_n^{(0)}$, with the following expression:
$$\hat{\beta}_n^{(0)} := a_n \exp(b \hat{r}_n) \times \beta_n^{(1)},$$
where $\hat{\beta}_n^{(0)}$ is the vector of reconstructed observed amplitudes. By multiplying the reconstructed amplitude vector by $\beta_n^{(1)}$, which consists of either zeros or ones (see Eq. 5), the unobserved channels will be reconstructed with amplitudes of zero and the observed channels will be reconstructed with the exponential function. For our VAE, instead of estimating the distribution of $a_n$, we directly optimize $a_n$ when maximizing the lower bound. We set the initial value of $a_n$ to the mean of the prior. Thus, $a_n$ can be read as a parameter of the decoder.
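Given our inference network and decoder, the ELBO we maximize for each spike $s_n$ is given by
$$\log p(\beta_n; a_n) \geq \mathbb{E}_{q_\Phi(x_n, y_n, z_n)}\left[\log \mathcal{N}\left(\beta_n^{(0)} \mid a_n \exp(b \hat{r}_n) \times \beta_n^{(1)}, I\right)\right] - \mathrm{KL}\left[q_\Phi(x_n, y_n, z_n) \,\|\, p_{x_n} p_{y_n} p_{z_n}\right],$$
where KL is the KL divergence. The location priors, $p_{x_n}, p_{y_n}, p_{z_n}$, are normally distributed as described in Section 3.1, with means of zero (the position of the maximum amplitude channel in the observed data) and variances of 80 (an arbitrarily high value). For more information about the architecture and training, see Appendix F.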
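The objective above can be estimated for a single spike with a few lines of NumPy. The sketch below is an illustration under stated assumptions rather than the training code: the variational mean and variance are passed in directly (in the VAE they would come from the inference network), and the names `kl_diag_gaussians` and `elbo_estimate`, the toy channel layout, and the value $b = -1/28$ are hypothetical choices for the example.

```python
import numpy as np

def kl_diag_gaussians(mu_q, var_q, mu_p, var_p):
    """KL[N(mu_q, var_q) || N(mu_p, var_p)], summed over independent dimensions."""
    return 0.5 * np.sum(np.log(var_p / var_q) + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)

def elbo_estimate(rng, mu_q, var_q, a, b, amps, observed, sites, prior_var=80.0, n_samples=64):
    """Monte Carlo ELBO for one spike with a unit-variance Gaussian likelihood."""
    eps = rng.standard_normal((n_samples, 3))
    locs = mu_q + np.sqrt(var_q) * eps                                 # reparameterized samples
    r = np.linalg.norm(sites[None, :, :] - locs[:, None, :], axis=2)   # (samples, channels)
    recon = a * np.exp(b * r) * observed                               # masked exponential decoder
    log_lik = -0.5 * np.sum((amps - recon) ** 2 + np.log(2.0 * np.pi), axis=1)
    kl = kl_diag_gaussians(mu_q, var_q, np.zeros(3), np.full(3, prior_var))
    return log_lik.mean() - kl

# Toy spike: 3x3 observed channels centered on the origin, noiseless amplitudes.
rng = np.random.default_rng(0)
xs, ys = np.meshgrid(np.array([-15.0, 0.0, 15.0]), np.array([-15.0, 0.0, 15.0]))
sites = np.column_stack([xs.ravel(), ys.ravel(), np.zeros(9)])
observed = np.ones(9)
a, b = -100.0, -1.0 / 28.0            # assumed amplitude and decay rate
true_loc = np.array([5.0, -3.0, 10.0])
amps = a * np.exp(b * np.linalg.norm(sites - true_loc, axis=1))
good = elbo_estimate(rng, true_loc, np.ones(3), a, b, amps, observed, sites)
bad = elbo_estimate(rng, np.array([40.0, 40.0, 10.0]), np.ones(3), a, b, amps, observed, sites)
```

A posterior centered on the true source gives a much higher bound than one centered far away, which is the signal the inference network is trained on.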

Stabilized Location Estimation
In this model, the channel on which the input is centered can bias the estimate of the spike location, in particular when amplitudes are small. To reduce this bias, we can create multiple inputs for the same spike where each input is centered on a different channel. During inference, we can average the inferred locations for each of these inputs, thus lowering the central channel bias. To this end, we introduce a hyperparameter, amplitude jitter, where for each spike $s_n$, we create multiple inputs centered on channels with peak amplitudes within a small voltage of the maximum amplitude, $\alpha_{n,m}$. We use two values for the amplitude jitter in our experiments: 0µV and 10µV. When amplitude jitter is set to 0µV, no averaging is performed; when amplitude jitter is set to 10µV, all channels that have peak amplitudes within 10µV of $\alpha_{n,m}$ are used as inputs to the VAE and averaged during inference.
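The averaging step can be sketched as below. The helper names `jitter_centers` and `stabilized_location` are hypothetical, and `localize` is a stand-in callable for the trained VAE applied to an input centered on a given channel; here it is replaced by a lookup table purely for illustration.

```python
import numpy as np

def jitter_centers(amplitudes, jitter_uv):
    """Indices of channels whose peak amplitude is within jitter_uv of the largest peak."""
    amplitudes = np.asarray(amplitudes, dtype=float)
    peak = amplitudes.min()  # peaks are negative, so the largest peak is the minimum value
    return np.flatnonzero(amplitudes <= peak + jitter_uv)

def stabilized_location(amplitudes, jitter_uv, localize):
    """Average the locations inferred from inputs centered on each candidate channel."""
    centers = jitter_centers(amplitudes, jitter_uv)
    estimates = np.array([localize(int(c)) for c in centers])
    return estimates.mean(axis=0)

# Three channels; the first two have peaks within 10 uV of each other.
amps = np.array([-80.0, -75.0, -40.0])
fake_vae_output = {0: [0.0, 0.0], 1: [15.0, 0.0], 2: [30.0, 0.0]}  # stand-in for VAE estimates
localize = lambda c: np.array(fake_vae_output[c])
no_jitter = stabilized_location(amps, 0.0, localize)
with_jitter = stabilized_location(amps, 10.0, localize)
```

With a jitter of 0µV only the peak channel is used, so the result reduces to a single estimate; with 10µV the two near-peak channels are averaged.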

Datasets
We simulate biophysically realistic ground-truth extracellular recordings to test our model against a variety of real-life complexities. The simulations are generated using the MEArec [4] package, which includes 13 layer 5 juvenile rat somatosensory cortex neuron models from the neocortical microcircuit collaboration portal [45]. We simulate three recordings with increasing noise levels (ranging from 10µV to 30µV) for two probe geometries: a 10x10 channel square MEA with a 15 µm inter-channel distance and 64 channels from a Neuropixels probe (∼25-40 µm inter-channel distance). Our simulations contain 40 excitatory cells and 10 inhibitory cells with random morphological subtypes, randomly distributed and rotated in 3D space around the probe (with a 20 µm minimum distance between somas). Each dataset has about 20,000 spikes in total (60 second duration). For more details on the simulation and noise model, see Appendix G.
For the real datasets, we use public data from a Neuropixels probe [32] and from a mouse retina recorded with the BioCam4096 platform [24]. The two datasets have 6 million and 2.2 million spikes, respectively. Spike detection and sorting (with our location estimates) are done using the HerdingSpikes2 software [22].
Table 1: Results for the 2D location estimates. These results are for three simulated, square MEA datasets with noise levels ranging from 10µV-30µV. For the VAE methods in the first column, the amount of amplitude jitter used is displayed to the right (amplitude jitter is described in 3.3.3).

Evaluation
Before evaluating the localization methods, we must detect the spikes from each neuron in the simulated recordings. To avoid biasing our results by our choice of detection algorithm, we assume perfect detection, extracting waveforms from channels near each spiking neuron. Once the waveforms are extracted from the recordings, we perform the data augmentation. For the square MEA we use W = 20, 40, which gives L = 4-9, 9-25 real channels in the observed data, respectively. For the simulated Neuropixels, we use W = 35, 45, which gives L = 3-6, 8-14 real channels in the observed data, respectively. Once we have the augmented dataset, we generate location estimates for all the datasets using each localization method. For straightforward comparison with center of mass, we only evaluate the 2D location estimates (in the plane of the recording device).
In the first evaluation, we assess the accuracy of each method by computing the Euclidean distance between the estimated spike locations and the associated firing neurons. We report the mean and standard deviation of the localization error for all spikes in each recording.
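As a minimal sketch of this first evaluation, the localization error can be computed as follows; the function name `localization_error` and the array layout (one row per spike, with an integer ground-truth neuron index per spike) are assumptions for the example.

```python
import numpy as np

def localization_error(estimates_xy, soma_xy, neuron_ids):
    """Mean and standard deviation of the distance from each spike estimate to its neuron's soma."""
    estimates_xy = np.asarray(estimates_xy, dtype=float)
    soma_xy = np.asarray(soma_xy, dtype=float)
    # Look up the true soma position for every spike, then take per-spike Euclidean distances.
    errors = np.linalg.norm(estimates_xy - soma_xy[np.asarray(neuron_ids)], axis=1)
    return errors.mean(), errors.std()

# Two spikes from the same neuron, whose soma sits at the origin.
mean_err, std_err = localization_error([[3.0, 4.0], [0.0, 0.0]], [[0.0, 0.0]], [0, 0])
```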
In the second evaluation, we cluster the location estimates of each method using Gaussian mixture models (GMMs). The GMMs are fit with spherical covariances and with the number of mixture components ranging from 45 to 75 (with a step size of 5). We report the true positive rate and accuracy for each number of mixture components when matched back to ground truth. To be clear, our use of GMMs is not a proposed spike sorting method for real data (the number of clusters is never known a priori), but rather a systematic way to evaluate whether our location estimates are more discriminable features than those of center of mass.
In the third evaluation, we again use GMMs to cluster the location estimates, this time combined with two principal components from each spike. We report the true positive rate and accuracy for each number of mixture components as before. Explicitly combining location estimates and principal components to create a new, low-dimensional feature set is introduced in Hilgen (2017). In that work, the principal components are whitened and then scaled with a hyperparameter, α. To remove any bias from choosing an α value in our evaluation, we conduct a grid search over α = {4, 6, 8, 10} and report the best metric scores for each method.
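One way to build such a combined feature set is sketched below with plain NumPy. It is an assumption-laden illustration, not the procedure of the cited work: the function name `location_pca_features` is hypothetical, each spike is represented by a single waveform (for example, the one on its peak channel), and "scaled with α" is interpreted as multiplying the whitened components by α.

```python
import numpy as np

def location_pca_features(locations_xy, waveforms, alpha, n_components=2):
    """Concatenate 2D locations with whitened principal components scaled by alpha."""
    waveforms = np.asarray(waveforms, dtype=float)
    centered = waveforms - waveforms.mean(axis=0)
    # PCA via SVD: the columns of u are the principal-component scores up to scaling.
    u, _, _ = np.linalg.svd(centered, full_matrices=False)
    # Whitening: rescale each retained component to unit sample variance.
    whitened = u[:, :n_components] * np.sqrt(len(waveforms) - 1)
    return np.hstack([np.asarray(locations_xy, dtype=float), alpha * whitened])

rng = np.random.default_rng(0)
locations = rng.uniform(0.0, 135.0, size=(200, 2))   # toy 2D location estimates
waves = rng.standard_normal((200, 30))               # toy waveforms, 30 samples each
features = location_pca_features(locations, waves, alpha=6.0)
```

The resulting four-dimensional features can then be passed to any clustering routine; larger values of α give the waveform components more weight relative to the locations.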
In the fourth evaluation, we assess the generalization performance of the method by training a VAE on an extracellular dataset and then inferring the spike locations in another dataset where the neuron locations are different, but all other aspects are kept the same (10µV noise level, square MEA). The localization and sorting performance is then compared to that of a VAE trained directly on the second dataset and to center of mass. Taken together, the first evaluation demonstrates how useful each method is purely as a localization tool, the second evaluation demonstrates how useful the location estimates are for spike sorting immediately after localizing, the third evaluation demonstrates how much the performance can improve given extra waveform information, and the fourth evaluation demonstrates how our method can be used across similar datasets without retraining. For all of our sorting analysis, we use SpikeInterface version 0.9.1 [5].

Results
Table 1 reports the localization accuracy of the different localization methods for the square MEA with three different noise levels. Our model-based methods far outperform center of mass with any number of observed channels. As expected, introducing amplitude jitter helps lower the mean and standard deviation of the localization error. Using a small width of 20µm when constructing the augmented data (4-9 observed channels) gives the highest performance for the square MEA.
The location estimates for the square MEA are visualized in Figure 2. Recording channels are plotted as grey squares and the true soma locations are plotted as black stars. The estimated individual spike locations are colored according to their associated firing neuron identity. As can be seen in the plot, center of mass suffers both from artificial splitting of location estimates and poor performance on neurons outside the array, two areas in which the model-based approaches excel. The MCMC and VAE methods have very similar location estimates, highlighting the success of our variational inference in approximating the true posterior. See Appendix A for a location estimate plot when the VAE is trained and tested on simulated Neuropixels recordings.
In Figure 3, spike sorting performance on the square MEA is visualized for all localization methods (with and without waveform information). Here, we only show the sorting results for center of mass on 25 observed channels, where it performs at its best. Overall, the model-based approaches have significantly higher precision, recall, and accuracy than center of mass across all noise levels and all different numbers of mixtures. This illustrates how model-based location estimates provide a much more discriminatory feature set than the location estimates from the center of mass approaches. We also find that the addition of waveform information (in the form of principal components) improves spike sorting performance for all localization methods. See Appendix A for a spike sorting performance plot when the VAE is trained and tested on simulated Neuropixels recordings.
As shown in Appendix D, when our method is trained on one simulated recording, it can generalize well to another simulated recording with different neuron locations. The localization accuracy and sorting performance are only slightly lower than the VAE that is trained directly on the new recording. Our method also still outperforms center of mass on the new dataset even without training on it.
Figure 4: Estimated spike locations for two real recordings. A, Analysis of a one hour recording from an awake, head-fixed mouse with a Neuropixels probe. Spikes were detected using the HS2 package [22], their locations estimated using the VAE model, and clustered with mean shift, together with the first two principal components obtained from the waveforms. Shown are a large section of the probe, a magnification and corresponding spike waveforms from the clustered units. B, The same analysis performed on a recording from a mouse retina with a BioCam array from ref [24].
Figure 4 shows our localization method as applied to two real, large-scale extracellular datasets. In these plots, we color the location estimates based on their unit identity after spike sorting with HerdingSpikes2. These extracellular recordings do not have ground truth information as current, ground-truth recordings are limited to a few labeled neurons [57,19,21,40,55]. Therefore, to demonstrate that the units we find likely correspond to individual neurons, we visualize waveforms from a local grouping of sorted units on the Neuropixels probe. This analysis illustrates that our method can already be applied to large-scale, real extracellular recordings.
In Appendix E, we demonstrate that the inference time for the VAE is much faster than that of MCMC, highlighting the excellent scalability of our method. The inference speed of the VAE allows for localization of one million spikes in approximately 37 seconds on a TITAN X GPU, enabling real-time analysis of large-scale extracellular datasets.

Discussion
Here, we introduce a Bayesian approach to spike localization using amortized variational inference. Our method significantly improves localization accuracy and spike sorting performance over the preexisting baseline while remaining scalable to the large volumes of data generated by MEAs. Scalability is particularly relevant for recordings from thousands of channels, where a single experiment may yield on the order of 100 million spikes.
We validate the accuracy of our model assumptions and inference scheme using biophysically realistic ground truth simulated recordings that capture much of the variability seen in real recordings. Despite the realism of our simulated recordings, there are some factors that we did not account for, including: bursting cells with event amplitude fluctuations, electrode drift, and realistic intrinsic variability of recorded spike waveforms. As these factors are difficult to model, future analysis of real recordings or advances in modeling software will help to understand possible limitations of the method.
Along with limitations of the simulated data, there are also limitations of our model. Although we assume a monopole current-source, every part of the neuronal membrane can produce action potentials [7]. This means that a more complicated model, such as a dipole current [51], line current-source [51], or modified ball-and-stick [48], might be a better fit to the data. Since these models have only ever been used after spike sorting, however, the extent to which they can improve localization performance before spike sorting is unclear and is something we would like to explore in future work. Also, our model utilizes a Gaussian observation model for the spike amplitudes. In real recordings, the true noise distribution is often non-Gaussian and is better approximated by pink noise models (1/f noise) [54]. We plan to explore more realistic observation models in future work.
Since our method is Bayesian, we hope to better utilize the uncertainty of the location estimates in future work. Also, as our inference network is fully differentiable, we imagine that our method can be used as a submodule in a more complex, end-to-end method. Other work indicates there is scope for constructing more complicated models to perform event detection and classification [31], and to distinguish between different morphological neuron types based on their activity footprint on the array [6]. Our work is thus a first step towards using amortized variational inference methods for the unsupervised analysis of complex electrophysiological recordings.
Table 2: Results for the 2D location estimates. These results are for three simulated, Neuropixels datasets with noise levels ranging from 10µV-30µV. For the VAE methods in the first column, the amount of amplitude jitter used is displayed to the right (amplitude jitter is described in 3.3.2).

B Effect of Noise on VAE
Figure 7: Effect of noise on location inference for the VAE on the Neuropixels probe. We vary the noise level of the recording across 10µV, 20µV, and 30µV. Increasing the noise increases both the number of outliers in the location estimates and their spread.
Figure 9: The simulated recording set-up and example data. A, Example electrical traces from the MEA with recorded action potentials (spikes, negative deflections). B, The 2D layout of the simulated recording. Recording channels are indicated in grey, and the true locations of the simulated neurons in red. The traces in part A are taken from the first column of the array. Note each spike is visible in multiple channels, with a characteristic spatial decay. C, Illustration of the data augmentation procedure in cases where the spikes are detected on channels near the array boundary. A set of virtual channels is introduced, which are incapable of recording any signal, but would report non-zero amplitudes if they were present on the MEA.
Table 3: Location results for the generalization performance of a VAE trained on one 10µV, square MEA dataset and tested on another 10µV, square MEA dataset. We compare the results of this VAE to another VAE that is trained directly on the second dataset to quantify the drop in performance when generalizing between datasets. We also compare to the center of mass baselines.

Method  Observed

We compare the sorting performance of the VAE and COM localization methods, with and without principal components, across all noise levels. For the VAE, we include the results with 0µV and 10µV amplitude jitter and with different numbers of observed channels (4-9 and 9-25). For COM, we plot the highest sorting performance (25 observed channels). The test dataset has 50 neurons.

F Architecture and Training Details
We set the inference network to be 2 layers deep with ReLU nonlinearities. The hidden unit sizes in the inference network are set to [500, 250]. We include batch normalization layers throughout the encoder to improve training and generalization.
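As a minimal sketch of the architecture described above, the following NumPy code runs one forward pass through a two-layer encoder with hidden sizes [500, 250], batch normalization, and ReLU. The input size (25 channel amplitudes), the 3-dimensional Gaussian output (mean and log-variance), the ordering of batch normalization before the ReLU, the weight initialization, and all function names are assumptions made for illustration; none of them is specified in the text.

```python
import numpy as np

HIDDEN_SIZES = [500, 250]  # hidden unit sizes stated in the text


def init_encoder(n_inputs, n_latent, rng):
    """Create random weights for the hidden layers and the two output heads."""
    sizes = [n_inputs] + HIDDEN_SIZES
    hidden = [
        (rng.normal(0.0, np.sqrt(2.0 / m), size=(m, n)), np.zeros(n))
        for m, n in zip(sizes[:-1], sizes[1:])
    ]
    heads = {
        name: (rng.normal(0.0, 0.01, size=(sizes[-1], n_latent)), np.zeros(n_latent))
        for name in ("mean", "log_var")
    }
    return hidden, heads


def batchnorm(x, eps=1e-5):
    """Normalize each feature over the batch (batch statistics, no learned scale or shift)."""
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)


def encode(amplitudes, hidden, heads):
    """Map a batch of per-channel amplitudes to the mean and log-variance of the approximate posterior."""
    h = amplitudes
    for weights, bias in hidden:
        h = np.maximum(batchnorm(h @ weights + bias), 0.0)  # linear -> batchnorm -> ReLU
    mean = h @ heads["mean"][0] + heads["mean"][1]
    log_var = h @ heads["log_var"][0] + heads["log_var"][1]
    return mean, log_var


rng = np.random.default_rng(0)
hidden, heads = init_encoder(n_inputs=25, n_latent=3, rng=rng)  # assumed sizes
batch = rng.normal(size=(64, 25))  # stand-in for 64 spikes observed on 25 channels
mean, log_var = encode(batch, hidden, heads)
```

Because the encoder is a single feed-forward pass, inferring the location of a new spike after training costs only a few matrix multiplications, which is what makes the amortized approach scale to large recordings.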
We train the VAE with three different learning rates, {0.0003, 0.001, 0.003}, and choose the one with the highest performance, although this parameter did not have a large effect on the results.
To ensure convergence on the simulated data, we train the network for 400 epochs on the entire dataset. For the real datasets, we train the network on a subset of the detected spikes (∼100,000 spikes) and then infer the locations of the remaining spikes.

G Simulated Data
To generate the extracellular recordings, we simulate multi-compartment neuron models using NEURON [23] and use the transmembrane currents to compute extracellular action potentials (EAPs) with LFPy [17]. The EAPs are then combined with randomly generated spike trains to generate recordings. Finally, noise is added and the entire recording is filtered using a 3rd-order Butterworth band-pass filter (0.3-6 kHz).
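The final filtering step can be sketched with SciPy as follows. This is an illustration under stated assumptions rather than the pipeline used for the datasets: the sampling rate of 32 kHz, the zero-phase (forward-backward) application, and the toy input signal are all assumed, and the function name `bandpass` is hypothetical. Note also that SciPy's `butter` with `order=3` and `btype="bandpass"` produces a 6th-order filter overall, so the exact correspondence to the "3rd order" filter in the text is uncertain.

```python
import numpy as np
from scipy.signal import butter, sosfiltfilt

FS_HZ = 32_000.0            # assumed sampling rate; not stated in the text
BAND_HZ = (300.0, 6000.0)   # 0.3-6 kHz pass band from the text


def bandpass(recording, fs_hz=FS_HZ, band_hz=BAND_HZ, order=3):
    """Band-pass filter each channel (rows = channels, columns = samples)."""
    sos = butter(order, band_hz, btype="bandpass", fs=fs_hz, output="sos")
    return sosfiltfilt(sos, recording, axis=-1)  # zero-phase filtering (assumption)


t = np.arange(int(FS_HZ)) / FS_HZ                    # one second of samples
slow_drift = 50.0 * np.sin(2 * np.pi * 5.0 * t)      # 5 Hz component, outside the band
spike_band = 20.0 * np.sin(2 * np.pi * 1000.0 * t)   # 1 kHz component, inside the band
recording = np.vstack([slow_drift + spike_band, spike_band])
filtered = bandpass(recording)
```

In this toy example the 5 Hz drift on the first channel is removed while the 1 kHz component passes nearly unchanged, so the two filtered channels end up almost identical.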
For the noise model, we simulate templates for 300 neurons that are far away from the recording area. These small action potentials make up the background noise of the recording, with noise levels ranging from 10µV to 30µV standard deviation across the simulated datasets. We choose this noise model because it best captures the frequency content and challenges of background noise in real extracellular recordings.
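To make the idea of spike-based background noise concrete, the sketch below sums randomly timed copies of 300 small templates on one channel and rescales the result to a target standard deviation. Everything specific here is an assumption for illustration: the toy biphasic template shape, the firing rate, the trace length, the Bernoulli spike trains, and the rescaling step itself (the text states the resulting noise levels but not how they were set). The function name `background_noise` is hypothetical.

```python
import numpy as np


def background_noise(templates, n_samples, rate_per_sample, target_std_uv, rng):
    """Sum randomly timed copies of small templates, then rescale to a target std (in µV)."""
    noise = np.zeros(n_samples)
    for template in templates:
        spikes = rng.random(n_samples) < rate_per_sample  # Bernoulli spike train
        noise += np.convolve(spikes.astype(float), template, mode="same")
    noise -= noise.mean()
    return noise * (target_std_uv / noise.std())


rng = np.random.default_rng(0)
# Toy stand-ins for the 300 distant-neuron templates: short biphasic waveforms
# with random small amplitudes (the real templates come from the simulation).
t = np.linspace(-1.0, 1.0, 64)
shape = -t * np.exp(-8.0 * t**2)
templates = [rng.uniform(1.0, 5.0) * shape for _ in range(300)]

noise_by_level = {
    level: background_noise(templates, 32_000, 0.001, level, rng)
    for level in (10.0, 20.0, 30.0)
}
```

Unlike white Gaussian noise, a trace built this way inherits the temporal structure of the spike waveforms, which is the property the noise model above is chosen for.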
For each of the three recordings on one probe geometry, we fix the neuron locations to assess the effect of noise on the location estimates for each neuron.

H MCMC Turing Code
Below is the probabilistic program and inference code for the MCMC version of our method in Turing [15].