Abstract
Neural population activity is theorized to reflect an underlying dynamical structure. This structure can be accurately captured using state space models with explicit dynamics, such as those based on recurrent neural networks (RNNs). However, using recurrence to explicitly model dynamics necessitates sequential processing of data, slowing real-time applications such as brain-computer interfaces. Here we introduce the Neural Data Transformer (NDT), a non-recurrent alternative. We test the NDT’s ability to capture autonomous dynamical systems by applying it to synthetic datasets with known dynamics and data from monkey motor cortex during a reaching task well-modeled by RNNs. The NDT models these datasets as well as state-of-the-art recurrent models. Further, its non-recurrence enables 3.9ms inference, well within the loop time of real-time applications and more than 6 times faster than recurrent baselines on the monkey reaching dataset. These results suggest that an explicit dynamics model is not necessary to model autonomous neural population dynamics.
1. Introduction
Neural populations are theorized to have an underlying dynamical structure which drives the evolution of population activity over time [21, 27, 33]. This structure can be explicitly modeled using linear [6, 11, 19] or switching linear dynamical systems [17, 24], or nonlinear dynamical systems such as recurrent neural networks (RNNs) [22, 23, 26]. In contrast to traditional analyses that average activity across repeated trials of the same behavior, these models have helped relate neural population activity to behavior in individual trials. In particular, an RNN-based method called latent factor analysis via dynamical systems (LFADS) has been shown to model single trial variability in neural spiking activity far better than traditional baselines like spike smoothing or GPFA [14, 22]. This precise modeling enables accurate prediction of subjects’ behaviors on a moment-by-moment basis and millisecond timescale.
RNNs have also been used to model language, and have been analogously shown to capture linguistic structure in input sentences [31]. However, with the advent of massive language datasets and their costly training implications, the language modeling community has shifted away from recurrent networks and towards the Transformer architecture [32]. A Transformer receives a sequence of word tokens, or inputs, and processes each individual token in parallel. For example, a Transformer can classify the parts of speech of every word in a sentence simultaneously, whereas an RNN must process earlier words before later ones. A Transformer’s parallelism enables it to be trained and operated on sequential data faster than an RNN. Though neuroscience datasets may not yet be large enough to realize much training benefit, reduced inference times could already benefit real-time applications where cycle times are critical, such as brain-computer interfaces or closed-loop neural stimulation.
Here we introduce the Neural Data Transformer (NDT), an architecture for modeling neural population spiking activity. The NDT is based on the BERT encoder [5] with modifications for application to neuroscientific datasets, specifically multi-electrode spiking activity. Modifications are needed as spiking activity has markedly different statistics than both language data and other time series [8, 35] previously modeled by Transformers. Further, neuroscientific datasets are generally much smaller than typical dataset sizes in other machine learning domains, necessitating careful training decisions [9] (see footnote 1).
We test the NDT on synthetic and real datasets to validate its performance. In our synthetic datasets, we generate firing rates using autonomous dynamical systems and sample spikes from the firing rates. We show the NDT can use the sampled spikes to recover the unobserved rates as well as LFADS. Further, when applied to activity recorded from monkey motor cortex, NDT-inferred firing rates enable prediction of simultaneously measured behavioral variables as well as rates from LFADS. We then demonstrate the NDT’s inference efficiency, showing it performs inference in 3.9 ms with minimal dependence on sequence length. On the monkey dataset, this enables inference 6.7× faster than LFADS. We also include an ablative study measuring the contributions of different design choices, and consider the tradeoffs of using an NDT with fewer layers.
Our results provide a proof-of-principle that recurrence is not necessary to accurately infer neural population firing rates on a single-trial basis, and unlock for neuroscience an alternative modeling paradigm that has greatly advanced other fields using machine learning models.
2. The NDT Model
Both the NDT and LFADS transform sequences of binned spiking activity into inferred firing rates (Fig. 1a). In real-time applications, the sequence of spiking activity would come from a rolling window of recent activity that ends with the current timestep. Both models assume a Poisson emission model, meaning inferred rates are compared against the observed spiking activity to compute a Poisson likelihood-based training objective (negative log-likelihood, NLL). When computing rates, a sequential model like the Decoder RNN in LFADS typically maintains an internal state. At each timestep, the state incorporates the next input, and the updated state produces the timestep’s output (Fig. 1b, bottom). In contrast, a Transformer uses a stack of layers that process all inputs together (Fig. 1b, top, depicts one such layer). A Transformer layer comprises several nonlinear blocks, in particular a self-attention block (Fig. 2a) in which a new representation of each input is constructed by incorporating relevant information from every other input. Specifically, a self-attention block creates three representations from each input: a query, key, and value. Each query is paired with every key to compute a dot-product similarity representing how much each input “attends” every other input. Each output is a weighted sum of all values, with weights determined by these attentions. Overall, self-attention enables the exchange of information across timesteps, and thus enables the modeling of temporal dependencies without explicit dynamics. Further details can be found in Sec. 7.2, and in the original Transformer paper [32] or the Annotated Transformer [15].
The body of the NDT architecture is a Transformer encoder with 6 layers in most of our experiments, as in Vaswani et al. [32]. We briefly discuss the option to use fewer layers in Sec. 3.3. Before entering the encoder, each channel of the input can be optionally projected to an n-dimensional embedding, i.e., for data with C channels, the dimension of each input representation is then Cn. To keep dimensionality small, we only consider n∈{1, 2 } in our experiments. We pass the Transformer encoder outputs through a linear layer and exponentiation (thus treating the linear layer outputs as log-firing rates) before calculating the NLL. Instead of the cross-entropy loss used in language modeling, these log-firing rates are passed into a Poisson likelihood loss.
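The readout and loss described above can be illustrated with a minimal NumPy sketch. It assumes encoder outputs of shape (timesteps, hidden size); the function names `readout` and `poisson_nll`, and the choice to keep the constant log(k!) term, are illustrative assumptions and not taken from the authors' code.

```python
import numpy as np
from math import lgamma

def readout(hidden, W, b):
    """Linear layer mapping encoder outputs (T, D) to log-firing rates (T, C)."""
    return hidden @ W + b

def poisson_nll(log_rates, spikes):
    """Mean Poisson negative log-likelihood of spike counts given log-rates.

    Both arrays share one shape, e.g. (timesteps, channels).
    """
    log_rates = np.asarray(log_rates, dtype=float)
    spikes = np.asarray(spikes, dtype=float)
    rates = np.exp(log_rates)  # exponentiation turns log-rates into rates
    # log(k!) does not depend on the model; it is kept so the value is a true NLL.
    log_factorial = np.vectorize(lgamma)(spikes + 1.0)
    return float(np.mean(rates - spikes * log_rates + log_factorial))
```

Working in log space keeps the predicted rates positive without any clipping, which is one plausible reason the log-rate parameterization is easier to train.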
To train the model in an unsupervised manner, we adapt the masked modeling methodology used in BERT (Fig. 2b). In masked modeling, the model is given an input sequence x1 … xT, with a random subset of the T input tokens masked. Subset size is typically a fixed ratio of the full sequence, e.g., 20% of the inputs. The model is then asked to reproduce the original input for that masked subset. To do so, the model must learn how to leverage the context provided by the unmasked timesteps (e.g., if firing rates in the dataset are temporally smooth, high spike counts in unmasked timesteps may imply high spike counts in masked timesteps). Readers familiar with LFADS might note that masked modeling resembles the coordinated dropout method developed to regularize LFADS models [13], only differing in that coordinated dropout masks individual dimensions (channels) of a given input timestep independently and is not constrained to mask entire input timesteps.
We adjust the training procedure as follows:
In BERT, masked inputs are typically replaced with a special “[MASK]” token. Instead of using this special token, which introduces a large distribution shift between training and inference time [5], we use a “zero mask.” That is, we simply zero out the spike inputs of a masked timestep, which was previously demonstrated to be an effective masking strategy for spiking data [13].
We use intensive regularization to stabilize training, which we find especially important when dataset sizes are smaller. Specifically, in the dropout layers (see Sec. 7.2 for locations), dropout ratios are swept over the range [0.2, 0.6].
The importance of the design choices presented here is demonstrated in an ablative study in Sec. 3.4.
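As a rough illustration of the zero-mask training objective described in this section, the sketch below masks a random subset of timesteps by zeroing their spike counts and evaluates the Poisson loss only on those timesteps. The function names, the 20% default ratio (taken from the example ratio mentioned above), and the omission of the constant log(k!) term are assumptions made for illustration.

```python
import numpy as np

def zero_mask(spikes, mask_ratio=0.2, rng=None):
    """Zero out a random subset of timesteps in spikes of shape (T, C).

    Returns the masked copy and a boolean vector marking masked timesteps.
    """
    rng = np.random.default_rng() if rng is None else rng
    n_steps = spikes.shape[0]
    n_masked = max(1, int(round(mask_ratio * n_steps)))
    idx = rng.choice(n_steps, size=n_masked, replace=False)
    is_masked = np.zeros(n_steps, dtype=bool)
    is_masked[idx] = True
    masked = spikes.copy()
    masked[is_masked] = 0  # the "zero mask": whole timesteps are blanked
    return masked, is_masked

def masked_poisson_loss(log_rates, spikes, is_masked):
    """Poisson NLL (up to the constant log k! term) on masked timesteps only."""
    lr = log_rates[is_masked]
    k = spikes[is_masked]
    return float(np.mean(np.exp(lr) - k * lr))
```

In a training loop, the model would receive `masked` as input and the loss would compare its predicted log-rates against the original `spikes`.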
3. Results
We compare the NDT with LFADS on both synthetic autonomous dynamics and M1 reaching activity, optimizing hyperparameters (ranges in Sec. 7.3) as follows:
NDT is optimized using grid search. Using early stopping, we select the checkpoint with least validation NLL as measured without masking.
LFADS is optimized using the AutoLFADS framework [14]. LFADS is known to benefit from Population-Based Training (PBT [10]) over simple grid search. (We find that NDT performs comparably between grid search and PBT.) AutoLFADS PBT is run with exponentially-smoothed validation NLL as the exploitation metric, and so we select the least smoothed validation NLL checkpoint [14].
Each search has 20 models. We run three searches for each experiment (a total of 3*20 = 60 models are trained) and report the mean and 95% CI of the metrics achieved. We select our models according to likelihood, since likelihood does not require knowledge of the underlying system and is measurable in both synthetic and real-world settings. We apply AutoLFADS with fixed settings that were previously shown to work in a variety of applications [14]. Note that our goal is to use AutoLFADS to provide a baseline for comparison, and we do not exhaustively explore its design choices or alternate hyperparameter ranges to achieve a performance ceiling or minimize training/inference times.
3.1. The NDT achieves high-fidelity inference on synthetic autonomous dynamical systems
We first evaluate the NDT on two synthetic datasets where observed activity reflects autonomous dynamics: the Lorenz system and the chaotic RNN (details in Sec. 7.4). The Lorenz system dataset [28, 36] is created by simulating a 3D state evolving according to the Lorenz equations, and projecting it to a specified higher dimensionality to form firing rates for a population of synthetic neurons. These rates are sampled according to a Poisson distribution to generate spikes. Similarly, the chaotic RNN dataset [30] is created by simulating dynamics using a vanilla RNN whose weights are initialized from the normal distribution. This system is motivated by the fact that many neural datasets are well modeled by RNNs (which are themselves nonlinear dynamical systems). The chaotic RNN is more complex than the Lorenz system (as measured by the number of principal components underlying the generating system) and is thus more challenging to model.
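The Lorenz generation procedure can be sketched as follows. This is an illustrative reconstruction, not the authors' generation script: the Euler integrator, the step size, the classic parameter values (σ = 10, ρ = 28, β = 8/3), the exponential link, the random projection, and the baseline rate are all assumptions. Only the shape (50 timesteps, 29 channels) follows Sec. 7.4.

```python
import numpy as np

def simulate_lorenz(n_steps, dt=0.01, state0=(1.0, 1.0, 1.0),
                    sigma=10.0, rho=28.0, beta=8.0 / 3.0):
    """Integrate the Lorenz equations with a simple Euler scheme."""
    states = np.empty((n_steps, 3))
    x, y, z = state0
    for t in range(n_steps):
        states[t] = (x, y, z)
        dx = sigma * (y - x)
        dy = x * (rho - z) - y
        dz = x * y - beta * z
        x, y, z = x + dt * dx, y + dt * dy, z + dt * dz
    return states

def lorenz_spikes(n_steps=50, n_channels=29, seed=0):
    """Project a Lorenz trajectory to channel firing rates and sample spikes."""
    rng = np.random.default_rng(seed)
    states = simulate_lorenz(n_steps)
    # Z-score each latent dimension so the projection is well scaled.
    latents = (states - states.mean(axis=0)) / states.std(axis=0)
    W = rng.normal(scale=0.5, size=(3, n_channels))  # hypothetical projection
    log_rates = latents @ W + np.log(0.5)            # hypothetical baseline rate
    rates = np.exp(log_rates)
    spikes = rng.poisson(rates)                      # Poisson emission
    return rates, spikes
```

Because the ground-truth `rates` are returned alongside the `spikes`, inferred rates can be scored directly against them.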
The synthetic setting allows us to evaluate inferred firing rates by comparing against the ground truth rates that produced the synthetic spikes. In both datasets, NDT and AutoLFADS inferences closely match the ground truth, though NDT rates appear less smooth (Fig. 3a). We quantify model inference quality by measuring the correspondence between inferred and ground truth firing rates using the coefficient of determination (R2; Tab. 1). In both datasets, the gap between the two models is small, indicating the NDT can accurately infer firing rates in autonomous dynamical systems. Importantly, we also find that within an HP search, NDT models with high data likelihoods (as computed on the observed spiking activity) tend to match the underlying systems well (as measured by correspondence with ground truth firing rates, Fig. 3b). This match between likelihoods and firing rate inference does not occur in LFADS models that lack coordinated dropout [13] and provides a key confirmation that the NDT’s masking strategy works as desired. Verifying that likelihood correlates with recovery of underlying structure in synthetic data provides confidence that likelihood can be used to optimize and choose between NDT models in applications to real-world data.
3.2. NDT infers motor cortical firing rates in autonomous settings with high fidelity
To test performance in real-world neural recordings, we apply NDT to the Monkey J Maze dataset [12]. These data were previously used to evaluate LFADS and AutoLFADS [13, 14, 22] and serve as a benchmark for models of autonomous dynamics. In this dataset, spiking activity from 202 neurons in the primary motor and dorsal premotor cortices was recorded as a monkey performed a delayed reaching task with a variety of straight and curved reaches. The reaching dataset consists of 2296 trials across 108 different reach conditions, where a given condition is specified by targets and obstacles present. Each trial has a random delay period that separates target presentation from a “Go” cue that prompts the monkey to begin its reach, which provides a time period for the monkey to plan before executing the reach. Previous analyses of this paradigm demonstrated that neural activity is well modeled as an autonomous dynamical system, where plan activity serves as an initial state that predicts the activity patterns observed during movement execution [4, 22, 27]. We train our models on activity during this autonomous period, spanning 250 ms before movement onset to 450 ms after. We perform most experiments by binning the spike sequences at 10 ms; we find similar results for bin sizes varying from 2 ms to 20 ms (results not shown).
We compute peri-stimulus time histograms (PSTHs) for the models by averaging inferred rates across repeated trials of the same reach condition (Fig. 4a). Both NDT and AutoLFADS exhibit low across-trial variance (as shown by the shaded errorbars), indicating that the models produce consistent inferred rates for different trials of the same condition. We also calculate a spike smoothing baseline by first passing observed spiking activity through a Gaussian kernel with 30 ms standard deviation, and then averaging across trials to form empirical PSTHs (Fig. 4a, bottom). These exhibit larger across-trial variance than the model-inferred firing rates, as spike smoothing produces noisy estimates on single trials [22]. To quantify the quality of the models’ inferred rates, we measure the correspondence between inferred PSTHs and empirical PSTHs. NDT models with greater likelihoods tend to have better R2 (Fig. 4b). The highest-performing models perform on par with AutoLFADS.
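The spike smoothing baseline and the PSTH computation can be sketched as below. The 30 ms kernel width and 10 ms bins follow the text; the kernel truncation at three standard deviations, the zero-padded edges, and the function names are illustrative assumptions.

```python
import numpy as np

def gaussian_smooth(spikes, bin_ms=10.0, sd_ms=30.0):
    """Smooth spikes of shape (trials, time, channels) along the time axis.

    Assumes the trial is at least as long as the kernel (19 bins by default).
    """
    sd_bins = sd_ms / bin_ms
    half = int(np.ceil(3 * sd_bins))  # truncate the kernel at 3 SD
    t = np.arange(-half, half + 1)
    kernel = np.exp(-0.5 * (t / sd_bins) ** 2)
    kernel /= kernel.sum()

    def smooth_1d(x):
        return np.convolve(x, kernel, mode="same")

    return np.apply_along_axis(smooth_1d, 1, spikes.astype(float))

def psth(rates, conditions):
    """Average rates (trials, time, channels) over trials of each condition."""
    conditions = np.asarray(conditions)
    return {c: rates[conditions == c].mean(axis=0)
            for c in np.unique(conditions).tolist()}
```

The same `psth` function applies to model-inferred rates and to smoothed spikes, so the two kinds of PSTH can be compared condition by condition.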
For motor cortical datasets, another method to evaluate the quality of inferred firing rates is through behavioral decoding, i.e., testing how well simultaneously-recorded behavioral variables can be decoded from the models’ inferred rates. We use optimal linear estimation to map firing rates onto hand velocities (details in Sec. 7.5) and find that NDT enables accurate behavioral decoding that matches AutoLFADS (0.918 and 0.915 R2, respectively). These velocity predictions can be integrated to produce predicted reaching trajectories (Fig. 4c). The large number of trials (over 2000) also allows us to evaluate each model’s sensitivity to dataset size by subsampling from the full dataset. NDT comfortably outperforms the spike smoothing baseline, even when scaling to as few as 92 training trials (Fig. 4d). While a 6-layer NDT performs worse than AutoLFADS at 92 trials, we show that a 2-layer NDT closes the gap in Sec. 3.3.1.
3.3. Efficiency Gains from Parallelism
Inference Speeds
Since the NDT models a given input sequence in parallel, we should expect a roughly constant inference speed with respect to input sequence length. The NDT’s non-recurrence enables 3.9 ms inference (Fig. 5, with details in Sec. 7.6), comfortably within the loop time of many real-time applications. In practice, we find the NDT’s inference times increase slightly with increased bin lengths; in contrast, LFADS inference times increase substantially. In the reaching dataset with sequence length 70, this amounts to a 6.7× speedup. For reference, prior work that achieved high-performing online decoding uses windows with 20 bins of 15 ms [11]; even with this reduced bin count our method provides a 4× speedup.
One caveat to this inference efficiency comparison is that we use recurrent architectures in a non-iterative fashion. For real-time applications, recurrent architectures could potentially be adapted to maintain an internal state that is updated each time a new input is received (i.e., once per timestep), as is done in traditional iterative state space models such as Kalman Filters. The resulting inference speed should be comparable with a parallel architecture. However, to the authors’ best knowledge such an iterative approach has not been demonstrated on deep network models of neural data.
3.3.1 Smaller NDTs Improve Training Speed and Data Efficiency
The fixed computational complexity of the NDT’s parallel architecture should grant faster training in addition to inference [32]. The 6-layer NDT used in previous experiments, however, does train for significantly longer than our LFADS model (Tab. 2). We note that training times of different models across an HP search can vary widely, i.e., we see NDT 6-layer times between 3 and 18 hours. However, training times can be reduced substantially by simply using a smaller NDT. We find a 1-layer NDT, with around the same number of parameters as our LFADS model, trains in under 30 minutes (and infers in under 1 ms). Remarkably, this 1-layer NDT achieves 0.89 R2 on kinematics decoding in the Maze dataset (-0.02 R2 against the 6-layer baseline), and a 2-layer NDT matches the 6-layer performance. Note that the shallower NDTs train faster than LFADS again due to parallelism, as parallelism avoids the costly backpropagation through time used to train recurrent networks. In our case, the 6-layer NDT was much larger than the AutoLFADS model; AutoLFADS training times are more appropriately compared with the 2 or 1-layer NDT.
Smaller models may also be more performant in limited data settings. The 2-layer model achieves 0.866 R2 when training on just 92 trials (not shown), outperforming the AutoLFADS model. Regularization is still critical for the smaller 2-layer model: performance drops to 0.4 R2 when the dropout range is confined to [0.0, 0.3] instead of [0.2, 0.6] (not shown). Though non-exhaustive, this result indicates that the gap between AutoLFADS and NDT when limited to 92 trials (Fig. 4d) may be due to 6-layer models being oversized. Extrapolating beyond this dataset, neural datasets, though smaller than in other domains, may be well-modeled by Transformers so long as the models are appropriately scaled.
3.4. Ablative Analysis
We empirically justify three key design choices of the 6-layer NDT by removing them and evaluating the degraded performance on the reaching dataset. Each is critical to achieving high performance (Tab. 3): without these subtle choices, performance is much worse and more variable. For example, models that infer rates instead of log-rates train more slowly and fail to converge to a good solution over a wide set of hyperparameters (Fig. 6). Notably, inferring rates instead of log-rates degrades performance in both 2-layer models and 6-layer models, i.e., constraining models to output in log space is important even with increased capacity. We also experimented with a few training details relevant in other Transformer works, such as variable length mask spans or adding embedding layers, but found their contributions on the reaching dataset to be marginal, on the order of 1-2% R2.
4. Discussion
We have introduced the NDT, a parallel neural network architecture for neural spiking activity, and shown it can be competitive with RNNs in autonomous dynamical settings while achieving substantially faster inference. Further, with careful architecture choices, the NDT could even match RNN performance on datasets with as few as 92 training trials (0.2 Mb). This indicates that Transformers are compatible with dataset sizes that are typically available in systems neuroscience.
The most critical limitation of the NDT, and thus an important avenue for future work, is its inability to model non-autonomous dynamics, i.e. systems with unpredictable external perturbations. This occurs when unmonitored brain areas send signals to the recorded area. For example, unpredictable experiment cues that are first processed in somatosensory or visual areas will propagate and perturb the dynamics of recorded motor areas. LFADS, which infers inputs in such non-autonomous settings, outperforms NDT by over 20% R2 in a preliminary experiment with a synthetic, non-autonomous dataset, the Chaotic RNN with Inputs studied in Sussillo et al. [30]. Incorporating a prior for exogenous inputs may solve this limitation.
Despite this limitation, we put forth the NDT as a forward-looking proposal. We believe the NDT and more generally the Transformer can benefit neuroscience due to the Transformer’s rapid rise in the broader machine learning community. This community will continue to advance Transformer tooling, analysis, and theory. Many of these advances could translate to neuroscientific applications. We provide two example directions:
Story generation requires modeling of both sensible short-term sentence structure and a coherent long-term storyline. While RNNs struggle to learn long-term dependencies, the Transformer’s parallel design makes it less biased with respect to either short or long term dependencies. This enables the Transformer to produce long passages of coherent text [25]. Analogously, a single Transformer model may yield insights around both fast and slow features of neural activity, uncovering hierarchy within the activity that maps naturally to the multi-scale nature of animal behavior [2, 3].
Transformers have been productively used to understand the interaction of data from multiple modalities. For example, vision-language transformers [18] produce language representations that are contextualized by accompanying images. Similar techniques could be applied to build models which incorporate recordings of multiple brain areas, different recording modalities, and behavioral measurements.
However, the major driver of the Transformer’s popularity is its ability to scale to large amounts of training data better than RNNs (i.e. through faster training). As increasing training data generally improves machine learning models across domains, we anticipate that larger datasets from new recording technologies and dataset aggregation will further improve the NDT’s performance and applicability, possibly past recurrent methods. Notably, these large datasets need not be excessively difficult to collect. For example, they could consist of neural activity that is continuously collected without constrained or even measured behavior. In other domains, the largest datasets tend to be similarly unstructured, naturally-occurring data, such as freeform text extracted from the internet. In a large-scale “pretraining” step, networks can learn deep representations of such data in a self-supervised manner, using methods such as those used to train LFADS and NDT. Pretrained representations make subsequent learning for downstream tasks much more data-efficient. The seeming universality of the representations learned in these tasks, for example, has prompted the GLUE language benchmark [34] to assess how well single models perform on 9 different language tasks. An analogous effort in neuroscience may help reveal all the different computational roles of a given neural population, much as prior work has sought to find preferential tuning properties for single neurons.
One promising avenue in the analysis of trained RNNs is the application of techniques from nonlinear dynamical systems theory to interrogate the RNNs’ learned dynamical structure [7, 20, 28, 29]. The Transformer is currently disconnected from these dynamical techniques, as it lacks a recurrent structure to analyze. It would be useful, even beyond the computational neuroscience community, to try to bridge this gap and understand how the Transformer represents dynamical structure.
6. Author Contributions
JY and CP jointly contributed towards conceptualization, writing, and revision. JY was responsible for investigation and software, and CP was responsible for funding acquisition and resources.
7. Methods
7.1. Data Availability
The Lorenz dataset and the generation scripts for the Chaotic RNN dataset are available in the code repo. The Maze dataset will be released upon publication.
7.2. Architectural Details
We summarize the Transformer encoder and self-attention mechanism, though we refer the reader to the resources [15, 32] for more details. Inputs to a Transformer layer pass through a self-attention block, a layer norm block [1], an MLP, and another layer norm.
The self-attention block is the only one that simultaneously transforms multiple inputs into multiple outputs. It comprises three different learned weight matrices, termed the query, key, and value matrices. All inputs are multiplied by these matrices to form three sets of intermediate representations, respectively termed the queries, keys, and values. An output is formed by taking a weighted sum of the values. To be precise, we introduce notation. The entire block transforms a sequence of T inputs x[1:T] into outputs y[1:T]. The intermediate representations are q[1:T], k[1:T], v[1:T]. For example, if the query matrix is denoted Q, then qi = Qxi.
Output $y_i$ is computed as a weighted sum $y_i = \sum_{j=1}^{T} w_{ij} v_j$, where each weight $w_{ij}$ represents the “attention” step i pays to step j. $w_{ij}$ is determined by calculating the dot-product similarity between query i and key j, and then normalizing similarities over all j with the softmax function. Formally:
$$w_{ij} = \frac{\exp(q_i \cdot k_j)}{\sum_{j'=1}^{T} \exp(q_i \cdot k_{j'})}$$
Self-attention lets inputs query for relevant information from other inputs. However, if we directly feed population representations, inputs would be unable to query for information from a particular timestep, i.e. there is no intrinsic ordering of the inputs. This is inappropriate for our context. To allow the NDT to account for input order, we add a learned position embedding (i.e. a unique vector representing the identity of the input timestep) to each input before it is fed into the transformer layers.
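The weighted-sum computation described above can be written as a short single-head NumPy sketch. It follows the text literally (unscaled dot products followed by a softmax over keys); Transformer implementations typically also divide the scores by the square root of the key dimension and use several heads, both of which are omitted here. The function names and the row-per-timestep layout are assumptions.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def add_position(x, pos_emb):
    """Add a (learned) position embedding, one vector per timestep."""
    return x + pos_emb

def self_attention(x, Q, K, V):
    """Single-head self-attention over x of shape (T, D); Q, K, V are (D, D)."""
    q = x @ Q.T                 # row i is q_i = Q x_i
    k = x @ K.T
    v = x @ V.T
    scores = q @ k.T            # scores[i, j] = q_i . k_j
    w = softmax(scores, axis=1) # attention step i pays to every step j
    y = w @ v                   # y_i = sum_j w[i, j] v_j
    return y, w
```

Without `add_position`, permuting the rows of `x` simply permutes the rows of `y`, which is the lack of intrinsic ordering noted above; adding a distinct vector per timestep breaks that symmetry.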
Following self-attention, we have layer normalization and an MLP. Each layer normalization block receives a single population state vector as input and normalizes this input using the mean and variance of its elements. The MLP comprises 2 linear layers joined with a non-linear ReLU activation, and similarly transforms a single input to a single output. Dropout layers are added right before the inputs enter the transformer body (the consecutive transformer layers), right after they exit the body, and right after each linear layer in the MLP of each transformer layer.
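The per-timestep blocks described above can be sketched as follows. This is a simplified illustration: the learned gain and bias of layer normalization are omitted, the inverted-dropout scaling is a common convention assumed here, and the residual connections found in standard Transformer encoders are left out because the text does not describe them. All function names are hypothetical.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Normalize each row (one population state vector) by its own mean and variance."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def mlp(x, W1, b1, W2, b2):
    """Two linear layers joined by a ReLU, applied to each timestep independently."""
    hidden = np.maximum(0.0, x @ W1 + b1)
    return hidden @ W2 + b2

def dropout(x, ratio, rng):
    """Inverted dropout: zero a fraction `ratio` of entries and rescale the rest."""
    keep = rng.random(x.shape) >= ratio
    return x * keep / (1.0 - ratio)
```

Unlike self-attention, each of these functions maps a single input vector to a single output vector, so they can be applied to all timesteps at once without mixing information across time.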
7.3. Hyperparameters
NDT searches sweep over:
Dropout ratio, as described in Sec. 2.
Context span, the number of timesteps forward and backward from which each input aggregates information. Span is swept between 4 and 32 steps in both directions for the synthetic datasets, and between 10 and 50 steps for the reaching dataset.
The ratio of masked tokens that are replaced with a random input instead of a zero mask, and the ratio that are not replaced at all (a methodology from BERT to reduce train-test distribution shift). Zero mask ratio is between 0.5 and 1.0 on synthetic datasets, and between 0.6 and 1.0 on the reaching dataset. Of the remaining masked tokens, between 0.9 and 1.0 are replaced with random inputs on synthetic datasets, and between 0.6 and 1.0 on the reaching dataset.
Length of masked span [16], set between 1 and 5 in the synthetic datasets, and between 1 and 7 in the reaching dataset.
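Two of the hyperparameters listed above lend themselves to a small sketch: the context span, interpreted here as a band-shaped attention mask, and the split of masked tokens into zeroed, randomly replaced, and unchanged groups. Both interpretations, along with the function names and label strings, are assumptions made for illustration and may differ from the released code.

```python
import numpy as np

def context_mask(n_steps, back, fwd):
    """Boolean (T, T) matrix; entry [i, j] is True if step i may attend to step j."""
    i = np.arange(n_steps)[:, None]
    j = np.arange(n_steps)[None, :]
    return (j >= i - back) & (j <= i + fwd)

def assign_mask_actions(n_masked, zero_ratio, random_ratio, rng):
    """Label each masked token as 'zero', 'random', or 'keep'.

    zero_ratio applies to all masked tokens; random_ratio applies to the
    masked tokens that remain after the zero-masked ones are set aside.
    """
    n_zero = int(round(zero_ratio * n_masked))
    n_random = int(round(random_ratio * (n_masked - n_zero)))
    n_keep = n_masked - n_zero - n_random
    labels = np.array(["zero"] * n_zero + ["random"] * n_random + ["keep"] * n_keep)
    return rng.permutation(labels)
```

In an attention layer, positions where `context_mask` is False would have their scores set to negative infinity before the softmax so that they receive zero weight.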
AutoLFADS PBT optimizes over:
Dropout, from 0.0 to 0.6
Coordinated Dropout [13] rate, from 0.01 to 0.7
L2 penalties for the generator from 1e-4 to 1.0
KL penalties for the initial condition from 1e-5 to 1e-3
Both models optimize learning rate, from 1e-5 to 5e-3. The LFADS controller is kept off as we study autonomous settings. Note that although we find these AutoLFADS settings outperform the ranges reported in [14], we only claim they are sufficient and not necessary for achieving reported results. PBT settings such as early stopping metrics and epochs per generation are as in [14]. Other hyperparameters are available in the code.
7.4. Synthetic Dataset
The train-val split is 0.8 and 0.2 for each dataset. The Lorenz dataset has 1560 total trials, 50 timesteps, and 29 channels. These trials comprise 65 conditions (firing rate trajectories) with 24 trials sampled per condition. The RNN dataset is generated with γ = 1.5 and has 1300 total trials (with 100 conditions and 13 trials per condition), 100 timesteps, and 50 channels. R2 is calculated by flattening timesteps and trials, and averaging across input channels, as done in [14].
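The R2 computation described above (flatten trials and timesteps, score each channel, then average) can be sketched as below. It assumes the standard coefficient of determination, 1 minus the ratio of residual to total sum of squares; the function name is hypothetical.

```python
import numpy as np

def channel_averaged_r2(true_rates, pred_rates):
    """R2 between arrays of shape (trials, time, channels), averaged over channels."""
    n_channels = true_rates.shape[-1]
    y = true_rates.reshape(-1, n_channels)    # flatten trials and timesteps
    y_hat = pred_rates.reshape(-1, n_channels)
    ss_res = ((y - y_hat) ** 2).sum(axis=0)
    ss_tot = ((y - y.mean(axis=0)) ** 2).sum(axis=0)
    return float(np.mean(1.0 - ss_res / ss_tot))
```

A perfect prediction scores 1, and predicting each channel's mean rate everywhere scores 0.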
7.5. Kinematic Decoding
We decode 2D hand velocity from inferred rates at single timesteps using ridge regression with α = 0.01. As in [22], we find improved decoding performance by applying a 90 ms lag between neural activity and the corresponding kinematics, i.e., while rates are inferred in a (−250 ms, 450 ms) window around movement onset, kinematics are predicted only in the (−160 ms, 450 ms) window.
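A minimal sketch of lagged ridge decoding is given below. It assumes 10 ms bins, so that the 90 ms lag corresponds to 9 bins, and uses a closed-form ridge solution with an unpenalized intercept; the function names and the closed-form solver are illustrative choices, not a description of the authors' code.

```python
import numpy as np

def lagged_pairs(rates, velocity, lag_bins):
    """Pair rates at bin t with velocity at bin t + lag_bins (single trial)."""
    if lag_bins == 0:
        return rates, velocity
    return rates[:-lag_bins], velocity[lag_bins:]

def fit_ridge(X, Y, alpha=0.01):
    """Closed-form ridge regression; the intercept is not penalized."""
    x_mean, y_mean = X.mean(axis=0), Y.mean(axis=0)
    Xc, Yc = X - x_mean, Y - y_mean
    gram = Xc.T @ Xc + alpha * np.eye(X.shape[1])
    W = np.linalg.solve(gram, Xc.T @ Yc)
    b = y_mean - x_mean @ W
    return W, b

def predict(X, W, b):
    """Decode velocity from rates with the fitted linear map."""
    return X @ W + b
```

With a 70-bin trial and a 9-bin lag, `lagged_pairs` yields 61 rate/velocity pairs per trial, matching the shortened kinematic window.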
7.6. Timing Tests
We report the time of a forward pass through the NDT and LFADS models, i.e. the time it takes to infer rates from spike inputs. 1 posterior sample is used for LFADS. Times are averaged over 1300 trials. Measurements were taken on a machine (on CPU) with 32GB RAM and a 4-core i7-4790K processor running at 4.2 GHz.
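A timing measurement of this kind can be sketched with a small harness that averages wall-clock time per forward pass. The `forward` callable stands in for either model's inference function and is hypothetical; warm-up runs and other benchmarking precautions are omitted for brevity.

```python
import time
import numpy as np

def mean_forward_time_ms(forward, trials):
    """Average wall-clock time, in milliseconds, of forward(x) over all trials."""
    times = []
    for x in trials:
        start = time.perf_counter()
        forward(x)
        times.append(time.perf_counter() - start)
    return 1000.0 * float(np.mean(times))
```

Passing one trial at a time, as here, mirrors a real-time setting in which each new window of spikes is processed as soon as it arrives.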
5. Acknowledgements
We thank Andrew Sedler, Yahia Ali, and Ruyi Marone for their insights and conversations. We also thank Krishna Shenoy, Mark Churchland, Matt Kaufman, and Stephen Ryu for sharing the Monkey J Maze dataset. This work was supported by the Emory Neuromodulation and Technology Innovation Center (ENTICe), NSF NCS 1835364, DARPA PA-18-02-04-INI-FP-021, NIH Eunice Kennedy Shriver NICHD K12HD073945, the Alfred P. Sloan Foundation, and the Simons Foundation as part of the Simons-Emory International Consortium on Motor Control (CP). The views and conclusions contained herein are those of the authors and should not be interpreted as necessarily representing the official policies or endorsements, either expressed or implied, of the U.S. Government, or any sponsor.
Footnotes
joel.ye{at}gatech.edu, chethan{at}gatech.edu
1 Negative results such as “difficult training” are under-reported. Our regularization was inspired by discussion in this Twitter thread.