Abstract
Understanding how neural dynamics give rise to behaviour is one of the most fundamental questions in systems neuroscience. To achieve this, a common approach is to record neural populations in behaving animals, and model these data as emanating from a latent dynamical system whose state trajectories can then be related back to behavioural observations via some form of decoding. As recordings are typically performed in localized circuits that form only a part of the wider implicated network, it is important to simultaneously learn the local dynamics and infer any unobserved external input that might drive them. Here, we introduce iLQR-VAE, a control-based approach to variational inference in nonlinear dynamical systems, capable of learning latent dynamics, initial conditions, and ongoing external inputs. As in recent deep learning approaches, our method is based on an input-driven sequential variational autoencoder (VAE). The main novelty lies in the use of the powerful iterative linear quadratic regulator algorithm (iLQR) in the recognition model. Optimization of the standard evidence lower-bound requires differentiating through iLQR solutions, which is made possible by recent advances in differentiable control. Importantly, the recognition model is naturally tied to the generative model, greatly reducing the number of free parameters and ensuring high-quality inference throughout the course of learning. Moreover, iLQR can be used to perform inference flexibly on heterogeneous trials of varying lengths. This makes it possible, for instance, to evaluate the model on a single long trial after training on shorter chunks. We demonstrate the effectiveness of iLQR-VAE on a range of synthetic systems, with autonomous as well as input-driven dynamics. We further apply it to neural and behavioural recordings in non-human primates performing two different reaching tasks, and show that iLQR-VAE yields high-quality kinematic reconstructions from the neural data.
1 Introduction
The mammalian brain is a complex, high-dimensional system, containing billions of neurons whose coordinated dynamics ultimately drive behaviour. Identifying and interpreting these dynamics is the focus of a large body of neuroscience research, which is being facilitated by the advent of new experimental techniques that allow large-scale recordings of neural populations (Jun et al., 2017; Stosiek et al., 2003). A range of methods have been developed for learning dynamics from data (Buesing et al., 2012; Gao et al., 2016; Duncker et al., 2019; Archer et al., 2015; Hernandez et al., 2018; She and Wu, 2020; Kim et al., 2021; Nguyen et al., 2020). These methods all specify a generative model in the form of a flexible latent dynamical system driven by process noise, coupled with an appropriate observation model.
Importantly, neural recordings are typically only made in a small selection of brain regions, leaving many areas unobserved which might provide relevant task-related input to the recorded one(s). Yet, the aforementioned methods perform Bayesian inference of state trajectories directly, and therefore do not support inference of external input (which they effectively treat as process noise and marginalize out). Indeed, simultaneous learning of latent dynamics and inference of unobserved control inputs is a challenging, generally degenerate problem that involves teasing apart momentary variations in the data that can be attributed to the system’s internal transition function from those that need to be explained by forcing inputs. This distinction can be achieved by introducing external control in the form of abrupt changes in the latent state transition function, and inferring these switching events (Ghahramani and Hinton, 2000; Linderman et al., 2017). More recently, Pandarinath et al. (2018) introduced LFADS, a sequential variational autoencoder (VAE) that performs inference at the level of external inputs as well as initial latent states. The inferred inputs were shown to be congruent with task-induced perturbations in various reaching tasks in primates (Pandarinath et al., 2018; Keshtkaran and Pandarinath, 2019). Further related work is discussed in Appendix A.
Here, we introduce iLQR-VAE, a new method for learning input-driven latent dynamics from data. As in LFADS, we use an input-driven sequential VAE to encode observations into a set of initial conditions and external inputs driving an RNN generator. However, while LFADS uses a separate, bidirectional RNN as the encoder, here we substitute the inference network with an optimization-based recognition model that relies on the powerful iterative linear quadratic regulator algorithm (iLQR, Li and Todorov, 2004). iLQR solves an optimization problem that finds a mode of the exact posterior over inputs for the current setting of generative parameters. This ensures that the encoder (mean) remains optimal for every update of the decoder, thus reducing the amortization gap (Cremer et al., 2018). Moreover, having the recognition model be implicitly defined by the generative model stabilizes training, prevents posterior collapse (thus circumventing the need for tricks such as KL warmup), and greatly reduces the number of (hyper-)parameters.
While iLQR-VAE could find applications in many fields as a general approach to learning stochastic nonlinear dynamical systems, here we focus on neuroscience case studies. We first demonstrate in a series of synthetic examples that iLQR-VAE can learn the dynamics of both autonomous and input-driven systems. Next, we show state-of-the art performance on monkey M1 population recordings during two types of reaching tasks (O’Doherty et al., 2018; Churchland et al., 2010). In particular, we show that hand kinematics can be accurately decoded from inferred latent state trajectories, and that the inferred inputs are consistent with recently proposed theories of motor preparation.
2 Method
iLQR-VAE models a set of temporal observations, such as behavioural and/or neural recordings, through a shared input-driven nonlinear latent dynamical system (Figure S1). The input encapsulates process noise (as in traditional latent dynamics models), initial inputs that set the initial condition of the dynamics, and any meaningful task-related control input. In this section, we describe the architecture of the generative model, and the control-based variational inference strategy used for training the model and making predictions. A graphical summary of the model can be found in Appendix B.
2.1 Generative model
We consider the following generative model:

z_{t+1} = f_θ(z_t, u_{t+1})   (1)
o_t ∼ p_θ(o_t | z_t)   (2)

where u_t ∈ ℝ^m, z_t ∈ ℝ^n and o_t ∈ ℝ^{n_o} are the input, latent state and observations at time t, respectively. Here, observations may comprise either neural activity, behavioural variables, or both – the distinction will be made later where relevant. We use the notation θ to denote the set of all parameters of the generative model. We use u_0 to set the initial condition z_1 = f_θ(0, u_0) of the network. This way, the latent state trajectory of the network z(u) = {z_1, …, z_T} is entirely determined by the input sequence u = {u_0, …, u_T} and the state transition function f_θ(·), according to Equation 1. For f_θ(·), we use either standard linear or GRU-like RNN dynamics (see Appendix C for details). For the likelihoods, we use Gaussian or Poisson distributions with means given by linear or nonlinear readouts of the network state of the form

ō_t = h(C z_t + b)   (3)

(Appendix D).
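The mapping from an input sequence to a latent trajectory can be sketched in a few lines. The following is a minimal numpy illustration (our own helper, with a slightly simplified indexing convention): `f` stands for an arbitrary transition function f_θ, the first input sets the initial condition, and every subsequent state follows deterministically from the inputs.

```python
import numpy as np

def rollout(f, u, n):
    """Unroll the latent trajectory z(u) determined by an input sequence u.

    The first input sets the initial condition, z_1 = f(0, u_0); subsequent
    states follow z_{t+1} = f(z_t, u_{t+1}), so z(u) is a deterministic
    function of u given the transition function f."""
    z = np.zeros(n)
    trajectory = []
    for t in range(u.shape[0]):
        z = f(z, u[t])
        trajectory.append(z)
    return np.array(trajectory)

# Example: leaky linear dynamics driven by a constant unit input.
f_lin = lambda z, u: 0.5 * z + u
traj = rollout(f_lin, np.ones((3, 1)), n=1)
```

Any process noise is absorbed into u here, which is precisely why inference over u subsumes both noise and meaningful control inputs.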
We place a Gaussian prior over u_0. We then consider two alternative choices for the prior over u_{t>0}. The first is a Gaussian prior,

u_t ∼ 𝒩(0, S²), with S = diag(s_1, …, s_m).   (4)

In many settings however, we expect inputs to enter the system in a sparse manner. To explicitly model this, we introduce a second prior over u in the form of a heavy-tailed distribution, constructed hierarchically by assuming that the ith input at time t > 0 is

u_{it} = s_i ε_{it} √(ν / ξ_t)   (5)

where s_i > 0 is a scale factor, ε_{it} ∼ 𝒩(0, 1) is independent across i and t, and ξ_t ∼ χ²(ν) is a shared scale factor drawn from a chi-squared distribution with ν degrees of freedom. Thus, inputs are spatially and temporally independent a priori, such that any spatio-temporal structure in the observations will have to be explained by the coupled dynamics of the latent states. Moreover, the heavy-tailed nature of this prior allows for strong inputs when they are needed. Finally, the fact that the scale factor is shared across input dimensions means that inputs are either all weak or potentially all strong at the same time for all input channels, expressing the prior belief that inputs come as shared events.
This hierarchical construction induces a multivariate Student-t prior at each time step:

u_t ∼ Student_ν(0, S), with S = diag(s_1, …, s_m).   (6)

Note that both S and ν are parameters of the generative model, which we will learn.
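The hierarchical construction is straightforward to sample from. The sketch below (our own helper, in numpy) draws inputs with one scale ξ_t per time step shared across all channels, and checks that the resulting marginals are much heavier-tailed than a Gaussian:

```python
import numpy as np

def sample_sparse_prior(T, m, s, nu, rng):
    """Sample inputs u_{1:T} from the hierarchical heavy-tailed prior.

    At each time step, a single scale xi_t ~ chi2(nu) is shared across all
    m input channels, so inputs are either jointly weak or jointly strong.
    Marginally, each u_t follows a multivariate Student-t with scale diag(s)."""
    xi = rng.chisquare(nu, size=T)        # shared scale, one draw per time step
    eps = rng.standard_normal((T, m))     # i.i.d. across channels and time
    return s[None, :] * eps * np.sqrt(nu / xi)[:, None]

rng = np.random.default_rng(0)
u = sample_sparse_prior(T=5000, m=3, s=np.ones(3), nu=3.0, rng=rng)

# Heavy tails: the normalized fourth moment should far exceed the Gaussian
# value of 3 (for a Student-t with nu = 3 it is unbounded in expectation).
kurt = np.mean(u[:, 0] ** 4) / np.mean(u[:, 0] ** 2) ** 2
```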
2.2 iLQR-VAE: a novel control-based variational inference strategy
To train the model, we optimize θ to maximize the log-likelihood of observing a collection of independent observation sequences 𝒪 = {o^(1), …, o^(K)}, or “trials”, given by:

log p_θ(𝒪) = Σ_{k=1}^K log ∫ p_θ(o^(k) | u) p_θ(u) du.   (7)
As the integral is in general intractable, we resort to an amortized variational inference strategy by introducing a recognition model q_ϕ(u | o^(k)) to approximate the posterior p_θ(u | o^(k)). Following standard practice (Kingma and Welling, 2013; Rezende et al., 2014), we thus train the model by maximizing the evidence lower-bound (ELBO):

ℒ(θ, ϕ) = Σ_{k=1}^K 𝔼_{q_ϕ(u | o^(k))}[log p_θ(o^(k) | u)] − KL[q_ϕ(u | o^(k)) ‖ p_θ(u)]   (8)

with respect to both θ and ϕ.
Here, the main novelty is the use of an optimization-based recognition model. We reason that maximizing the exact log posterior, i.e. computing

u⋆(o) = argmax_u [log p_θ(o | u) + log p_θ(u)]   (9)

subject to the generative dynamics of Equations 1 and 2, is a standard nonlinear control problem: −log p_θ(o_t | z_t) acts as a running cost penalizing momentary deviations between desired outputs o_t and the actual outputs caused by a set of controls u, and −log p_θ(u_t) acts as an energetic cost on those controls. Importantly, there exists a general purpose, efficient algorithm to solve such nonlinear control problems: iLQR (Li and Todorov, 2004; Appendix E). We thus propose to use a black-box iLQR solver to parameterize the mean of the recognition density q_ϕ(u | o) for any o, and to model uncertainty separately using a multivariate Gaussian density common to all trials. Therefore, we parametrize the recognition model as follows:

q_ϕ(u | o) = 𝒩(u; iLQRsolve_θ(o), Σ_s ⊗ Σ_t)   (10)

where we use a separable posterior covariance (the Kronecker product of a spatial factor Σ_s and a temporal factor Σ_t).
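To make the control interpretation concrete, the sketch below (our own simplified notation, in numpy) writes the negative log posterior over inputs for a linear-Gaussian instance of the model as a control cost: a running cost on output errors plus a quadratic energetic cost on the controls. In this special case iLQR reduces to plain LQR; the point here is only the structure of the objective.

```python
import numpy as np

def neg_log_posterior(u, A, B, C, o, sigma_o=1.0, sigma_u=1.0):
    """Negative log posterior over inputs (up to an additive constant) for a
    linear-Gaussian model, written as a control cost: a running cost that
    penalizes output errors, plus an energetic cost on the controls u."""
    z = np.zeros(A.shape[0])
    cost = 0.0
    for t in range(o.shape[0]):
        z = A @ z + B @ u[t]                                     # latent rollout
        cost += 0.5 * np.sum((o[t] - C @ z) ** 2) / sigma_o**2   # output cost
        cost += 0.5 * np.sum(u[t] ** 2) / sigma_u**2             # input cost
    return cost

# With zero observations, the zero control sequence is trivially optimal.
A, B, C = 0.9 * np.eye(2), np.eye(2), np.eye(2)
o = np.zeros((10, 2))
zero_cost = neg_log_posterior(np.zeros((10, 2)), A, B, C, o)
one_cost = neg_log_posterior(np.ones((10, 2)), A, B, C, o)
```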
To optimize the ELBO, we estimate the expectation in Equation 8 by drawing samples from qϕ(u | o(k)) and using the reparameterization trick (Kingma et al., 2015) to obtain gradients. A major complication that would normally preclude the use of optimization-based recognition models is the need to differentiate through the mean of the posterior. In this case, this involves differentiating through an entire optimization process. Using automatic differentiation within the iLQR solver is in general impractically expensive memory-wise. However, recent advances in differentiable model predictive control enable implicit differentiation through iLQRsolve with a memory cost that does not depend on the number of iterations (Amos et al., 2018; Blondel et al., 2021; Appendix F).
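Sampling from the Kronecker-structured posterior is cheap thanks to the matrix-normal identity vec(L_s E L_tᵀ) = (L_t ⊗ L_s) vec(E), so the full (mT × mT) covariance never needs to be formed. A sketch of the reparameterized sample (our notation, with L_s and L_t the Cholesky factors of Σ_s and Σ_t; the Kronecker ordering depends on the vec convention):

```python
import numpy as np

def reparam_sample(mu, L_s, L_t, eps):
    """Reparameterized sample whose vec (column-major, over the (m, T) matrix)
    has covariance Sigma_t (x) Sigma_s, where Sigma_s = L_s L_s^T (spatial)
    and Sigma_t = L_t L_t^T (temporal). eps is standard normal noise of the
    same shape as mu, made explicit for the reparameterization trick."""
    return mu + L_s @ eps @ L_t.T

# Check the matrix-normal identity against the explicit Kronecker form.
L_s = np.array([[1.0, 0.0], [0.5, 2.0]])
L_t = np.tril(np.ones((3, 3)))
eps = np.arange(6.0).reshape(2, 3)
u = reparam_sample(np.zeros((2, 3)), L_s, L_t, eps)
explicit = np.kron(L_t, L_s) @ eps.flatten(order="F")
```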
2.3 Complexity and implementation
We optimize the ELBO using Adam (Kingma and Ba, 2014) with a learning rate that decays with the iteration number i. Averaging over data samples can be easily parallelized; we do this here using the MPI library and a local CPU cluster. In each iteration and for each data sample, obtaining the approximate posterior mean through iLQR is the main computational bottleneck, with a complexity of 𝒪(T(n³ + n²n_o)). To help mitigate this cost, we find it useful to re-use the previously inferred control inputs to initialize each iLQRsolve.
3 Experiments and results
3.1 iLQR-VAE enables fast learning of dynamics
Before demonstrating the method on a number of synthetic and real datasets involving ongoing external inputs, we begin with a simpler example meant to illustrate some of iLQR-VAE’s main properties (Figure 1). We generated data from an autonomous (i.e. non-input-driven) linear dynamical system (n = 8 latents, m = 3 input channels) seeded with a random initial condition in each of 56 trials. The state zt was linearly decoded with added Gaussian noise to produce observation data, which we used to train a model in the same class.
At the beginning of learning, iLQR-VAE relies on large ongoing inputs that control the generator into producing outputs very similar to the observations in the data (Figure 1, red box, left), resulting in a rapidly decreasing loss. Subsequently, the amount of input required to fit the observations gradually decreases as the system learns the internal dynamics of the ground truth system. Eventually, the inferred control inputs become confined to the first time bin, i.e. they act as initial conditions for the now autonomous dynamics of the generative model. Thus, iLQR-VAE operates in a regime where the output of the generator explains the data well at all times, and learning consists in making the inputs more parsimonious. We note that this regime is facilitated here by our choice of generator dynamics, which we initialised to be very weak (i.e. with a small spectral radius) and therefore easily controllable.
We contrast this with learning in a modified version of iLQR-VAE where we allowed u0 to vary freely (with a Gaussian prior of adjustable variance) but effectively fixed ut>0 to be 0. In other words, we constrained the dynamics of the generator to remain (near-)autonomous throughout learning (Figure 1, grey box, top). Although this incorporates important information about the ground truth generator (which is itself autonomous), counter-intuitively we found that it impairs learning. At the beginning of training, iLQR is unable to find initial conditions that would explain the data well, resulting in a much higher initial loss. The model then gets stuck in plateaus that are seemingly avoided by the free version of iLQR-VAE (see Figure S2 for independent repeats of this experiment).
On the same toy dataset, we also compared iLQR-VAE to LFADS (Pandarinath et al., 2018), keeping the generative model in the same model class (see Appendix G for details). We found that LFADS learns in a similar manner to iLQR-VAE (Figure 1), also progressively doing away with inputs.
3.2 iLQR-VAE for nonlinear system identification
Next, we illustrate the method on an autonomous nonlinear dynamical system, the chaotic Lorenz attractor (Lorenz, 1963; Appendix H). This is a standard benchmark to evaluate system identification methods on nonlinear dynamics (Nguyen et al., 2020; Hernandez et al., 2018; Champion et al., 2019), and one typically considers the dynamics to be learned if the trained model can recreate the whole attractor structure.
Here, we show that iLQR-VAE can learn these complex nonlinear dynamics. Before training, the inferred inputs are large throughout the trial, and explain the output observations by forcing the internal state of the generator into appropriate trajectories (Figure 2A, top). At the end of learning, the inputs are confined to the first time bin, setting the initial condition of the trajectories which are now driven by the stronger, near-autonomous dynamics of the generator. In Figure 2B we show that, conditioned on an initial bout of test data, the model perfectly predicts the rest of the trajectory. Moreover, starting from a random initial condition, the model can recreate the whole attractor structure (Figure 2C).
To quantitatively assess how well the dynamics have been learned, we computed the k-step coefficient of determination, R²_k, as in Hernandez et al. (2018). This metric evaluates how well the model can predict the true state k steps into the future, starting from any state inferred along a test trajectory (see Appendix H for details). Hernandez et al. (2018) reported R²_{k=10}, but did not show results for larger k. For iLQR-VAE, the forward interpolation remained very high even at k = 50 time steps.
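One plausible implementation of this metric is sketched below (numpy; the exact protocol is specified in Appendix H, and `f` stands for the learned one-step transition applied in the latent space):

```python
import numpy as np

def k_step_r2(f, z, k):
    """k-step coefficient of determination R^2_k: from each state z[t] on a
    test trajectory, apply the learned transition f k times and compare the
    prediction against the true state z[t + k]."""
    preds, targets = [], []
    for t in range(z.shape[0] - k):
        zp = z[t]
        for _ in range(k):
            zp = f(zp)
        preds.append(zp)
        targets.append(z[t + k])
    preds, targets = np.array(preds), np.array(targets)
    ss_res = np.sum((targets - preds) ** 2)
    ss_tot = np.sum((targets - targets.mean(axis=0)) ** 2)
    return 1.0 - ss_res / ss_tot

# Sanity check: a model matching the true dynamics exactly achieves R^2_k = 1.
A = np.array([[0.9, -0.2], [0.2, 0.9]])
z = [np.array([1.0, 0.0])]
for _ in range(30):
    z.append(A @ z[-1])
z = np.array(z)
r2 = k_step_r2(lambda x: A @ x, z, k=10)
```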
3.3 Inferring sparse inputs
To demonstrate iLQR-VAE’s ability to infer unobserved inputs and learn the ground truth dynamics of an input-driven system, we generated synthetic data from a system with n = 3, m = 3 and n_o = 10, which evolves with linear dynamics for T = 1000 time steps (see Appendix I for an example with input-driven nonlinear dynamics). The system was driven by sparse inputs, and the outputs were corrupted with Gaussian noise. Input events were drawn in each time step from a Bernoulli distribution with mean p = 0.03. Whenever an input event occurred, the magnitude of inputs in each channel was drawn from a standard multivariate Gaussian distribution.
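A generation process of this kind can be sketched as follows (our own script, not the authors' exact one; parameter values mirror those quoted in the text):

```python
import numpy as np

def simulate_sparse_lds(A, B, C, T, p, noise_std, rng):
    """Linear latent dynamics driven by sparse inputs: at each time step an
    input event occurs with probability p, and event magnitudes are standard
    Gaussian across channels. Observations are a linear readout of the
    latent state, corrupted by Gaussian noise."""
    m, z = B.shape[1], np.zeros(A.shape[0])
    U, O = [], []
    for _ in range(T):
        u = rng.standard_normal(m) if rng.random() < p else np.zeros(m)
        z = A @ z + B @ u
        U.append(u)
        O.append(C @ z + noise_std * rng.standard_normal(C.shape[0]))
    return np.array(U), np.array(O)

rng = np.random.default_rng(0)
U, O = simulate_sparse_lds(A=0.95 * np.eye(3), B=np.eye(3),
                           C=rng.standard_normal((10, 3)),
                           T=1000, p=0.03, noise_std=0.1, rng=rng)
```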
We fit both iLQR-VAE and LFADS models to these data, choosing the generator to be within the ground-truth model class for both. iLQR-VAE captured most of the variance in the inputs (Figure 3A; R2 = 0.94 ± 0.02; 5 random seeds), and recovered the eigenvalue spectrum of the transition matrix almost perfectly (Figure 3B). LFADS however performed poorly on this example (R2 = 0.05 ± 0.02 for input reconstruction; 3 random seeds), as well as in several other similar comparisons on datasets of different sizes and trial numbers (Appendix J). This is unsurprising, as LFADS assumes a dense (auto-regressive) Gaussian prior over the inputs, which is not overridden by the relatively small amount of data used here. Note however that when applied to a set of 56 trials of 100 time steps driven by Gaussian autoregressive inputs, iLQR-VAE still captured the structure in the inputs more accurately than LFADS did (R2 = 0.81 ± 0.01 vs. 0.29 ± 0.06). We hypothesize that this reflects the difficulty of learning a good recognition model from a small amount of data. We evaluate the effect of the choice of prior more extensively in Appendix K.
3.4 Predicting hand kinematics from primate M1 recordings
3.4.1 Trial-structured maze task
To highlight the utility of iLQR-VAE for analyzing experimental neuroscience data, we next applied it to recordings of monkey motor (M1) and dorsal premotor (PMd) cortices during a delayed reaching task (‘Maze’ dataset of Kaufman et al., 2016; DANDI 000128). This dataset contains 108 different reach configurations over nearly 3000 trials, and has recently been proposed as a neuroscience benchmark for neural data analysis methods (Pei et al., 2021). We compared the performance of iLQR-VAE to several other latent variable models, evaluated on this dataset in Pei et al. (2021).
Consistent with previous findings (Pandarinath et al., 2018), iLQR-VAE inferred inputs that were confined to initial conditions, from which smooth single-trial dynamics evolved near-autonomously (Figure 4A). As a first measure of performance, we evaluated the models on “co-smoothing”, i.e. the ability to predict the activity of held-out neurons conditioned on a set of held-in neurons (see Appendix L for details). Conditioning on 137 held-in neurons (i.e. using 45 held-out neurons), we obtained a co-smoothing of 0.331 ± 0.001 (over 5 random seeds). For comparison, Pei et al. (2021) report 0.187 for GPFA (Yu et al., 2009), 0.225 for SLDS (Linderman et al., 2017), 0.329 for Neural Data Transformers (Ye and Pandarinath, 2021) and 0.346 for AutoLFADS (LFADS with large-scale hyperparameter optimization; Keshtkaran et al., 2021) on the same dataset.
Next, we assessed how well hand velocity could be decoded from neural activity – another metric of interest to neuroscientists. We applied ridge regression to predict the monkey’s hand velocity (with a 100 ms lag) from momentary neuronal firing rates (mean of the posterior predictive distribution) on test data. This reconstruction could be performed with very high accuracy (R2 = 0.896 ± 0.002 over 5 random seeds), compared to 0.640 for GPFA, 0.775 for SLDS, 0.897 for Neural Data Transformers and 0.907 for AutoLFADS (Pei et al., 2021). These experiments place iLQR-VAE on par with state-of-the-art methods, without any extensive hyperparameter optimization.
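A lagged ridge decoder of this kind can be sketched in a few lines (closed-form ridge regression with an intercept; names and the synthetic check are ours, and the lag is expressed in time bins):

```python
import numpy as np

def fit_lagged_ridge(rates, vel, lag_bins, alpha=1.0):
    """Fit a ridge-regression decoder predicting velocity at time t + lag
    from firing rates at time t (closed-form solution, with an intercept)."""
    X = np.hstack([rates[:-lag_bins], np.ones((rates.shape[0] - lag_bins, 1))])
    Y = vel[lag_bins:]
    return np.linalg.solve(X.T @ X + alpha * np.eye(X.shape[1]), X.T @ Y)

# Synthetic check: velocity is an exactly lagged linear readout of the rates,
# so a lightly regularized fit should recover the predictions almost exactly.
rng = np.random.default_rng(0)
rates = rng.standard_normal((200, 5))
M = rng.standard_normal((2, 5))
lag = 4
vel = np.zeros((200, 2))
vel[lag:] = rates[:-lag] @ M.T
W = fit_lagged_ridge(rates, vel, lag, alpha=1e-6)
pred = np.hstack([rates[:-lag], np.ones((200 - lag, 1))]) @ W
```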
3.4.2 Continuous reaching task
While a large number of neuroscience studies perform neural and behavioural recordings during trial-structured tasks, much can be learned by analyzing the dynamics of more naturalistic, less constrained behaviours. iLQR-VAE’s flexible recognition model is well-suited to the analysis of such less structured tasks, as it can easily be trained and tested on trials of heterogeneous lengths. To illustrate this, we applied iLQR-VAE to a self-paced reaching task during which a monkey had to reach to consecutive targets randomly sampled from a 17×8 grid on a screen (O’Doherty et al., 2018; Makin et al., 2018). This dataset consists of both neural recordings from primary motor cortex (M1) together with continuous behavioural recordings in the form of x- and y-velocities of the fingertip.
In this example, we experimented with fitting the spike trains and hand velocities jointly (combining a Poisson likelihood for the 130 neurons and a Gaussian likelihood for the 2 kinematic variables, see Appendix M for further details). We found that it allowed iLQR-VAE to reach a similar kinematic decoding performance as when fitting neural activity alone, but using a smaller network. More generally, we reason that a natural approach to making behavioural predictions from neural data using a probabilistic generator is to fit it to both jointly, and then use the posterior predictive distribution over behavioural variables (conditioning on spike trains only) as a nonlinear decoder. In future work, this could provide more accurate predictions in those motor tasks where linear regression struggles (see e.g. Schroeder et al., 2022).
For our analyses, we used the first ∼22 minutes of a single recording session (indy_20160426), excluded neurons with overall firing rates below 2 Hz, and binned data at 25 ms resolution. Although it is not a formal requirement of our method, we chunked the data into 336 non-overlapping pseudotrials of 4 s each, in order to enable parallelization of the ELBO computation during training. We only trained the model on a random subset of 168 trials.
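Chunking a continuous recording into fixed-length pseudotrials amounts to a simple reshape; a minimal sketch (our helper; the comment values illustrate the 4 s / 25 ms-bin setting used here):

```python
import numpy as np

def chunk_into_pseudotrials(data, trial_len):
    """Split a continuous recording of shape (T, n_channels) into
    non-overlapping pseudo-trials of trial_len bins, dropping any
    leftover bins at the end."""
    n_trials = data.shape[0] // trial_len
    return data[: n_trials * trial_len].reshape(n_trials, trial_len, -1)

data = np.arange(3000.0).reshape(1000, 3)              # toy: 1000 bins x 3 channels
chunks = chunk_into_pseudotrials(data, trial_len=160)  # 160 bins = 4 s at 25 ms
```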
To highlight the flexibility of iLQR as a recognition model, we then evaluated the model by performing inference on the first 9 minutes of the data, as a single long chunk of observations. Note that this is not generally possible in LFADS or any sequential VAE where an encoder RNN has been trained exclusively on trials of the same fixed length. Despite the lack of trial structure, we found that neurons display a stereotyped firing pattern across multiple instances of each reach. This was revealed by binning the angular space into 8 reach directions, temporally segmenting and grouping the inferred firing rates according to the momentary reach direction, and aligning these segments to the time of target onset (Figure 4B). Moreover, hand kinematics could be linearly decoded from the inferred firing rates with high accuracy (Figure 4C; R2 = 0.75 ± 0.01 over 5 random seeds), on par with AutoLFADS (R2 = 0.76; Keshtkaran et al., 2021), and considerably higher than GPFA and related approaches (R2 = 0.6; Jensen et al., 2021).
We next wondered whether we could use iLQR-VAE to address an open question in motor neuroscience, namely the extent to which the peri-movement dynamics of the motor cortex rely on external inputs (possibly from other brain areas). Such inputs could arise during movement preparation, execution, neither, or both. We thus examined the relationship between the inputs inferred by iLQR-VAE and the concurrent evolution of the neuronal firing rates and hand kinematics. Overall, neuronal activity tends to rise rapidly starting 150 ms before movement onset (Figure 4D, red), consistent with the literature (Shenoy et al., 2013; Churchland et al., 2012). Interestingly, we observed that inputs tend to arise much earlier (around the time of target onset), and start decaying well before the mean neural activity has finished rising (Figure 4D top), about 150 ms before the hand started to move (Figure 4D, bottom). While these results must be interpreted cautiously, as inference was performed using information from the whole duration of the trial (i.e. using iLQR as a smoother), they show that the data is best explained by large inputs prior to movement onset, rather than during movement itself. Interestingly, the timing of these inputs is globally consistent with target-induced visual inputs driving preparatory activity in M1, whose dynamics then evolve in a more autonomous manner to drive subsequent motion.
4 Discussion
Limitations and future work
While we have demonstrated that iLQR-VAE performs well on various toy and real datasets, the method has a number of limitations, some of which could be addressed in future work. Firstly, the problem of decoupling ongoing inputs from dynamics is degenerate in general, and there is no guarantee that iLQR-VAE will always successfully identify the ground-truth. This problem will be exacerbated in the low data regime, or if there is a large mismatch between our prior over inputs and the true input distribution. While further generalization tests such as extrapolations can be used to assess post-hoc how well the dynamics have been learned, the lack of identifiability will often make interpretation of the model parameters difficult. Secondly, using iLQR as a way of solving maximum a posteriori inference in state space models comes at a high computational cost, and with the risk that iLQR may converge to a local minimum. We note that both these issues could potentially be tackled at once if process noise in the generator was modelled separately from control inputs, as the MAP estimation problem could then be solved using some of the highly efficient algorithms available in the framework of linearly solvable stochastic control (Todorov, 2009; Dvijotham and Todorov, 2013; Kappen, 2005). Finally, for simplicity we modelled posterior input uncertainty using a common covariance across all data samples. This might be limiting, for example when modelling neural populations that exhibit coordinated global firing fluctuations giving rise to data samples with highly variable information content. A better solution, left to future work, would be to amortize the computation of the posterior uncertainty by reusing some of the computations performed in iLQR.
Conclusion
The rise of new tools and software now makes it possible to record from thousands of neurons while monitoring behaviour in great detail (Jun et al., 2017; Mathis et al., 2018; Musk et al., 2019). These datasets create unique opportunities for understanding the brain dynamics that underlie neural and behavioural observations. However, identifying complex dynamical systems is a hard nonlinear filtering and learning problem that calls for new computational techniques (Kutschireiter et al., 2020). Here, we exploited the duality between control and inference (Toussaint, 2009; Kappen and Ruiz, 2016; Levine, 2018; Appendix N) to bring efficient algorithms for nonlinear control to bear on learning and inference in nonlinear state space models.
The method we proposed uses iLQR, a powerful general purpose nonlinear controller, to perform amortized inference over inputs in an RNN-based generative model. Using an optimization-based recognition model such as iLQR has two advantages. First, it brings important flexibility at test time, enabling predictions on arbitrary, heterogeneous sequences of observations as well as seamless handling of missing data. Second, owing to parameter sharing between the generative and recognition models, the ELBO gap is reduced (Appendix O), making learning more robust (in particular, to initialization) and reducing the number of hyperparameters to tune. With the advent of automatically differentiable optimizers (Blondel et al., 2021), we therefore hope that optimization-based recognition models will open up new avenues for VAEs.
Appendix
A Additional related work
In this section, we first discuss (non-exhaustively) several methods used for identifying dynamical systems from data, before presenting the few approaches we are aware of that explicitly tackle the problem of inferring unobserved control inputs to those systems.
The problem of identifying the dynamics giving rise to a set of observations is one that spans many fields, from climate modelling to neuroscience, and a variety of methods have therefore been developed to tackle it. Most existing approaches assume non-driven dynamics, as this greatly facilitates system identification.
One common modelling paradigm is to assume the data arise from a latent linear dynamical system (LDS), whose parameters can be learned using an Expectation-Maximization (EM) approach (Ghahramani and Hinton, 1996). While linear models are typically very efficient as they allow estimates to be computed in closed form, they severely restrict the range of dynamics that can be approximated. Various extensions have been proposed, such as switching linear dynamical systems (Linderman et al., 2017; Ghahramani and Hinton, 2000), which assume that the data can be modelled using several latent dynamical systems, with a Hidden Markov Model controlling the transitions between those. Alternatively, Costa et al. (2019) propose to use adaptive locally linear dynamics and an iterative procedure to find the most likely switching points. In a similar vein, Hernandez et al. (2018) approximate the dynamics as locally linear; interestingly, the proposed method (VIND) incorporates the generative dynamics in the approximate posterior distribution over latent trajectories given data. This is reminiscent of the approach taken in iLQR-VAE, where the recognition parameters are kept tied to the generative parameters.
Another way to keep the problem tractable while allowing for richer dynamics is to approximate those using a linear combination of nonlinear basis functions. This turns the optimization into the simpler problem of learning the weights of the expansion (with the caveat that one needs to choose the set of basis functions). This is the method used in Brunton et al. (2016b), with an additional constraint that the coefficients are sparse in the space of basis functions, yielding a more interpretable model. This was later extended in Champion et al. (2019) to allow for automatic discovery of a set of coordinates in which the dynamics can be approximated as sparse.
In a similar manner, a popular approach involves modelling the dynamics as linear in the space of observables (which can include linear or nonlinear mappings from the state of the system), as is done in dynamic mode decomposition (Schmid, 2010; Kutz et al., 2016) (see Brunton et al., 2016a for applications to neural data). This approach is closely related to Koopman operator theory, which finds a set of dynamic modes and uses those to approximate the data as a single linear dynamical system.
Finally, the dynamics can be modelled using nonlinear neural networks, and the parameters learned using variational methods (see e.g Nguyen et al., 2020; Hernandez et al., 2018; Koppe et al., 2019).
Most of the aforementioned models can be extended to incorporate known external inputs coming into the system. This is for instance done in dynamic mode decomposition with control inputs (DMDc; Proctor et al., 2016), which can be generalized into Koopman operators with inputs and control (KIC; Proctor et al., 2018).
On the other hand, the range of methods modelling dynamics driven by unobserved inputs (which must thus be inferred) is a lot more limited. Indeed, LFADS (Pandarinath et al., 2018) is the first method we are aware of which explicitly models the set of control inputs driving the system. As described in the main text, LFADS models the dynamics as a (potentially) input-driven nonlinear dynamical system, and learns both the parameters and the inputs. More recently, Fieseler et al. (2020) proposed an extension of DMDc to handle unsupervised learning of unobserved signals as well as estimation of the dynamics. This was then used to successfully model neural recordings made in C. elegans. Crucially however, the dynamics were modelled as linear, thus restricting the range of dynamics that the learnt system could generate. Morrison et al. (2020) modelled the same data using input-driven nonlinear dynamics, but assumed a limited subset of inputs driving transitions at given time points, and thus only learned the magnitude of those inputs and not their timing.
Finally, an approach related to the modelling of unobserved inputs (which give rise to changes that cannot be explained by the dynamics alone) is the explicit modelling of events which lead to discontinuities in the dynamics. This is done in Chen et al. (2020) within the framework of neural ordinary differential equations (Chen et al., 2018). To some extent, one can also view switching dynamical models as inferring unobserved inputs giving rise to state transitions, although those “inputs” are restricted to a discrete set.
B Graphical summary of the model
C Implementation of the dynamics
We considered different functional forms for the discrete-time dynamics of the latent state. In the following, zt and ut denote the latent state and an external input at time t, respectively.
C.1 Linear dynamics
The simplest case considered is that of linear dynamics:
z_{t+1} = A z_t + B u_t. (S1)
One issue with linear dynamics is that they may become unstable, such that repeated application of the operator A leads to divergence of ‖z‖ and of the associated gradients. This can become problematic, especially when modelling long sequences of observations. To circumvent this issue, we used a parameterization of the propagator A that ensured it remained stable at all times. To find a stable linear parameterization, we considered the Lyapunov stability condition (Bhatia and Szegö, 2002). The discrete-time dynamics of Equation S1 are asymptotically stable if and only if A satisfies
Aᵀ P A − P ⪯ −I (S2)
for some positive definite matrix P with eigenvalues ≥ 1. It is easy to verify that the following parameterization of the state matrix A satisfies this criterion:
A = U (D + I)^{-1/2} Q D^{1/2} Uᵀ, (S3)
with U and Q arbitrary unitary matrices, and D an arbitrary non-negative diagonal matrix. Conversely, any stable matrix can be reached by this parameterization. Note that the matrix P that satisfies Equation S2 is then given by P = U (D + I) Uᵀ. Finally, as we are also learning the B and C matrices in Equation S1, we can without loss of generality set U = I.
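As a numerical sanity check (not part of the original derivation), the sketch below constructs A from random unitary matrices U, Q and a non-negative diagonal D, using a parameterization consistent with P = U(D + I)Uᵀ, and verifies both the Lyapunov residual and the spectral radius; all variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 8

# Arbitrary unitary (here: orthogonal) matrices U, Q, and non-negative diagonal D.
U, _ = np.linalg.qr(rng.standard_normal((n, n)))
Q, _ = np.linalg.qr(rng.standard_normal((n, n)))
D = np.diag(rng.uniform(0.0, 5.0, n))
I = np.eye(n)

# One parameterization consistent with P = U (D + I) U^T: it satisfies the
# discrete-time Lyapunov condition A^T P A - P = -I, which is sufficient for
# asymptotic stability of z_{t+1} = A z_t.
A = U @ np.linalg.inv(np.sqrt(D + I)) @ Q @ np.sqrt(D) @ U.T

P = U @ (D + I) @ U.T
lyap_residual = A.T @ P @ A - P                      # should equal -I exactly
spectral_radius = np.max(np.abs(np.linalg.eigvals(A)))
```

Since the Lyapunov residual equals −I with P positive definite, the spectral radius of A is guaranteed to be strictly below 1.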
C.2 GRU dynamics
To fit the monkey reaching data as well as the Lorenz attractor, we chose the dynamical system to be a Minimal Gated Unit (MGU). More specifically, we used the MGU2 variant of the MGU proposed in Heck and Salem (2017):
f_t = σ(W_f z_t) (S4)
z̃_t = g(W_h (f_t ⊙ z_t) + x_t + b_h) (S5)
z_{t+1} = (1 − f_t) ⊙ z_t + f_t ⊙ z̃_t, (S6)
where x_t = B u_t denotes the input entering the dynamical system. Note that the latent state z is often denoted by h in the literature. We found that MGU2 gave better and more stable performance than the MGU. We hypothesize that this is due to the input entering the system through the candidate state update only (as opposed to also entering through the gate), thus making the system more easily controllable. We chose σ(·) to be a sigmoid function, and g(·) to be a soft ReLU-like nonlinearity,
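The following is a minimal sketch of one MGU2-style update, under one plausible reading of the Heck and Salem (2017) variant (gate computed from the state alone, input entering only through the candidate state); the softplus choice for g and all weight names are assumptions for illustration.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def soft_relu(a):
    # A soft ReLU-like nonlinearity (softplus); the paper's exact g may differ.
    return np.log1p(np.exp(a))

def mgu2_step(z, x, Wf, Wh, bh):
    """One MGU2-style update: the gate f depends on the state alone
    (no input term), so the input x = B u drives the dynamics only
    through the candidate state."""
    f = sigmoid(Wf @ z)                      # gate: state-dependent only (MGU2)
    z_cand = soft_relu(Wh @ (f * z) + x + bh)  # candidate state, driven by x
    return (1.0 - f) * z + f * z_cand

rng = np.random.default_rng(1)
n = 5
Wf = 0.5 * rng.standard_normal((n, n))
Wh = 0.5 * rng.standard_normal((n, n))
bh = np.zeros(n)
z = rng.standard_normal(n)
z_next = mgu2_step(z, np.zeros(n), Wf, Wh, bh)
```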
D Likelihood functions
The likelihood of the observations appears both in the ELBO and in the iLQR cost. Minimization of the latter via iLQR requires computing the momentary Jacobians and Hessians of the likelihood function w.r.t. the internal state zt. Although these quantities can be obtained generically via automatic differentiation, iLQR is always faster when they are provided directly (Appendix E), which we did here using the analytical expressions given below.
D.1 Gaussian likelihood
For the Gaussian likelihood, we assume observations o are linearly decoded from latents z and corrupted with Gaussian noise, such that o ∼ 𝒩(Cz + b, Σ), with C the readout matrix, b a vector of biases, and Σ a diagonal matrix of variances. This yields the following log-likelihood function:
log p(o | z) = −½ (o − Cz − b)ᵀ Σ⁻¹ (o − Cz − b) − ½ log |2πΣ|.
The Jacobian of this expression with respect to z is:
∂ log p(o | z) / ∂z = Cᵀ Σ⁻¹ (o − Cz − b).
Finally, the Hessian is given by:
∂² log p(o | z) / ∂z² = −Cᵀ Σ⁻¹ C.
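These standard expressions (Jacobian Cᵀ Σ⁻¹ (o − Cz − b), Hessian −Cᵀ Σ⁻¹ C) can be checked against finite differences; the sketch below does so for a small random instance, with all shapes and names chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(2)
n, no = 4, 6
C = rng.standard_normal((no, n))
b = rng.standard_normal(no)
sig2 = rng.uniform(0.5, 2.0, no)          # diagonal of Sigma
z = rng.standard_normal(n)
o = rng.standard_normal(no)

def loglik(z):
    r = o - C @ z - b
    return -0.5 * np.sum(r**2 / sig2) - 0.5 * np.sum(np.log(2 * np.pi * sig2))

# Analytical Jacobian and Hessian with respect to z
jac = C.T @ ((o - C @ z - b) / sig2)
hess = -C.T @ (C / sig2[:, None])

# Finite-difference checks (exact up to round-off: loglik is quadratic in z)
eps = 1e-3
E = np.eye(n)
jac_fd = np.array([(loglik(z + eps * E[i]) - loglik(z - eps * E[i])) / (2 * eps)
                   for i in range(n)])
hess_fd = np.array([[(loglik(z + eps * (E[i] + E[j])) - loglik(z + eps * E[i])
                      - loglik(z + eps * E[j]) + loglik(z)) / eps**2
                     for j in range(n)] for i in range(n)])
```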
D.2 Poisson likelihood
To model spike trains, we assume that they are generated by a Poisson process with an underlying positive rate function for neuron i given by
λ_i = β_i f(a_i) Δ, with a = Cz + b,
where f(·) is a nonlinear function (chosen to be an exponential when modelling the monkey recordings, and a soft ReLU-like nonlinearity elsewhere), Δ denotes the time bin size, and β_i is a neuron-specific gain parameter. This yields the following log-likelihood:
log p(o | z) = Σ_i [ o_i log λ_i − λ_i ] + const,
where the sum is performed over neurons. Using the shorthand notation h(x) = log f(x) and a_t = Cz_t + b, the Jacobian and Hessian of this expression are given by:
∂ log p / ∂z_t = Cᵀ [ o_t ⊙ h′(a_t) − Δβ ⊙ f′(a_t) ],
∂² log p / ∂z_t² = Cᵀ diag( o_t ⊙ h″(a_t) − Δβ ⊙ f″(a_t) ) C.
E iLQR algorithm
Our recognition model makes use of the iterative Linear Quadratic Regulator algorithm (iLQR; Li and Todorov, 2004; Tassa et al., 2014) to find the mean of the posterior distribution qϕ(u | o). Iterative LQR solves finite-horizon optimal control problems with non-linear dynamics and non-quadratic costs by (i) linearizing the dynamics locally around some initial trajectory, (ii) performing a quadratic approximation to the control cost around that same trajectory, (iii) solving the linear-quadratic problem generated by the local approximation to obtain better control inputs, and (iv) repeating until convergence, each time linearizing around the trajectory induced by the new inputs. Below, we first introduce the linear quadratic regulator (LQR), and detail the approximation used in iLQR to turn any non-linear, non-quadratic problem into one that can be solved with LQR. We also provide pseudo-code for our implementation of iLQR (see Algorithm 1).
The Linear Quadratic Regulator is concerned with finding the set of controls u ∈ ℝm that minimize a quadratic cost function 𝒞(u) under deterministic linear dynamics, given by:
𝒞(u) = Σ_{t=1}^{T} ½ τ_tᵀ C_t τ_t + c_tᵀ τ_t, with τ_t = (z_tᵀ, u_tᵀ)ᵀ, subject to z_{t+1} = A_t z_t + B_t u_t + h_t. (S15)
Here, At ∈ ℝn×n is a (possibly time-dependent) transition matrix, Bt ∈ ℝn×m represents the input channels at time t, and ht is a state and input-independent term. Note that z ∈ ℝn is a deterministic function of the initial condition z0 and the sequence of inputs u. LQR finds the inputs minimizing Equation S15 using a dynamic programming approach, by recursively finding the feedback rule (Kt, kt) which gives the optimal inputs to minimize the cost-to-go at each time t as ut = Ktzt+kt. Details can be found in function Backward in Algorithm 1.
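As an illustration (not from the original text), the sketch below solves a minimal time-invariant special case of this problem — state cost Q, input cost R, no linear cost terms or offsets — via the standard backward Riccati recursion, which plays the role of the Backward function; all names are assumptions.

```python
import numpy as np

rng = np.random.default_rng(3)
n, m, T = 3, 2, 15
A = 0.5 * rng.standard_normal((n, n))
B = rng.standard_normal((n, m))
Qc = np.eye(n)            # state cost
Rc = 0.1 * np.eye(m)      # input cost
z0 = rng.standard_normal(n)

# Backward pass: dynamic programming on the quadratic value function
# V_t(z) = z^T P_t z, for cost sum_t z_t^T Q z_t + u_t^T R u_t (+ terminal).
P = Qc.copy()
Ks = []
for _ in range(T):
    K = -np.linalg.solve(Rc + B.T @ P @ B, B.T @ P @ A)   # feedback rule u_t = K_t z_t
    P = Qc + A.T @ P @ (A + B @ K)
    Ks.append(K)
Ks = Ks[::-1]                                             # reorder to t = 0..T-1

def rollout(us):
    """Total cost of an input sequence under the dynamics z_{t+1} = A z_t + B u_t."""
    z, cost = z0.copy(), 0.0
    for t in range(T):
        cost += z @ Qc @ z + us[t] @ Rc @ us[t]
        z = A @ z + B @ us[t]
    return cost + z @ Qc @ z                              # terminal cost

# Forward pass: apply the optimal feedback controller
us_lqr, z = [], z0.copy()
for t in range(T):
    u = Ks[t] @ z
    us_lqr.append(u)
    z = A @ z + B @ u
cost_lqr = rollout(us_lqr)
```

Because the cost is convex in u, the feedback solution is the global minimum, so any perturbation of the inputs can only increase the rollout cost.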
iLQR is an extension of LQR to general dynamics and cost functions. Specifically, iLQR minimizes
𝒞θ(u) = Σ_{t=1}^{T} c_θ(z_t, u_t, t), subject to z_{t+1} = f_θ(z_t, u_t), (S17)
where θ denotes a set of parameters. At iteration i, iLQR approximates both the dynamics and the cost around the current trajectory τ^i = (z^i, u^i) as:
δz_{t+1} ≈ ∇_z f_θ δz_t + ∇_u f_θ δu_t (S18)
and
c_θ(z_t^i + δz_t, u_t^i + δu_t, t) ≈ c_θ(z_t^i, u_t^i, t) + ∇_z c_θ δz_t + ∇_u c_θ δu_t + ½ (δz_t, δu_t)ᵀ ∇² c_θ (δz_t, δu_t). (S19)
Here, δz and δu refer to perturbations around the current nominal trajectory, and all ∇ operators correspond to partial differentiation evaluated at the current nominal trajectory (ui, zi) and corresponding time t.
The above equations are readily identified as a local LQR problem of the form of Equation S15, which can thus be solved using standard dynamic programming tools. Once δu⋆ minimizing Equation S19 has been computed, the inputs are updated as ui+1 = ui +δu⋆, and the new state trajectory follows from simulating the dynamics forward with these new inputs. After each LQR update, we thus obtain a new trajectory τ i+1 and the process repeats until convergence to some locally optimal trajectory τ⋆.
Implementation details can be found in Algorithm 1. Note that the backward LQR pass involves inversion of the matrix Quu (defined in Algorithm 1, function Backward). Depending on the specific form of the iLQR cost function, this might not always be positive-definite. Therefore, we include an adaptive Levenberg-Marquardt-type regularizer (not described in the pseudo-code), Quu → Quu + λI, to maintain positive definiteness. For large λ, iLQR thus effectively reverts to first-order gradient descent, as opposed to second-order optimization, whenever the locally quadratic approximation is a bad one.
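To make steps (i)–(iv) concrete, here is a toy iLQR loop on a one-dimensional pendulum-like system (not the paper's implementation; the Levenberg-Marquardt regularizer is unnecessary here because the local curvatures stay positive, and a simple backtracking line search replaces the full Algorithm 1).

```python
import numpy as np

dt, T = 0.1, 40
z_target, w, wT, r = np.pi, 1.0, 10.0, 0.1   # target state and cost weights (made up)

def f(z, u):                                  # nonlinear dynamics (pendulum-like)
    return z + dt * (u - np.sin(z))

def cost(zs, us):
    running = 0.5 * w * np.sum((zs[:-1] - z_target)**2) + 0.5 * r * np.sum(us**2)
    return running + 0.5 * wT * (zs[-1] - z_target)**2

def simulate(z0, us):
    zs = [z0]
    for u in us:
        zs.append(f(zs[-1], u))
    return np.array(zs)

z0, us = 0.0, np.zeros(T)
zs = simulate(z0, us)
costs = [cost(zs, us)]

for it in range(30):
    # Backward pass on the local linear-quadratic approximation
    Vz, Vzz = wT * (zs[-1] - z_target), wT
    ks, Ks = np.zeros(T), np.zeros(T)
    for t in reversed(range(T)):
        a, b = 1.0 - dt * np.cos(zs[t]), dt   # local linearization of f
        Qz  = w * (zs[t] - z_target) + a * Vz
        Qu  = r * us[t] + b * Vz
        Qzz = w + a * a * Vzz
        Quu = r + b * b * Vzz
        Quz = a * b * Vzz
        ks[t], Ks[t] = -Qu / Quu, -Quz / Quu
        Vz  = Qz - Quz * Qu / Quu
        Vzz = Qzz - Quz * Quz / Quu
    # Forward pass with backtracking line search
    for alpha in (1.0, 0.5, 0.25, 0.1):
        zs_new, us_new = [z0], np.zeros(T)
        for t in range(T):
            us_new[t] = us[t] + alpha * ks[t] + Ks[t] * (zs_new[t] - zs[t])
            zs_new.append(f(zs_new[t], us_new[t]))
        zs_new = np.array(zs_new)
        if cost(zs_new, us_new) < costs[-1]:
            zs, us = zs_new, us_new
            break
    costs.append(cost(zs, us))
```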
F Differentiating through iLQR
Here we discuss how to efficiently differentiate through the iLQR algorithm. This becomes necessary when one wishes to differentiate through a function involving an iLQRsolve, such as the posterior mean of our recognition model (Equation 11). While a naive but simple strategy would be to unroll the algorithm and gather gradients for every step, this is expensive both computationally and memory-wise. Amos et al. (2018) derived a way to analytically obtain gradients with respect to the parameters of iLQR, at the cost of a single LQR pass. Specifically, differentiating through an iLQRsolve is achieved by running iLQR to convergence, forming a linear-quadratic approximation around the converged trajectory following the steps described in Appendix E, and differentiating through the corresponding LQR problem. Below, we provide an alternative to Amos et al.’s derivation of the gradients of an LQR solution.
F.1 LQR optimality conditions
We now introduce the more compact notation τ_t = (z_tᵀ, u_tᵀ)ᵀ, which will be used in the rest of this section.
As described in Appendix E, the finite-horizon, discrete-time LQR problem involves minimizing a quadratic cost subject to linear dynamics constraints, following the notation from Appendix E. To solve this problem, we write down the Lagrangian, where λ1, λ2, …, λT are adjoint (dual) variables that enforce the dynamics constraints. Differentiating with respect to λt and τt enables us to obtain the set of equations satisfied by λ and τ, also known as the KKT conditions (Kuhn and Tucker, 2014; Karush, 2014; Boyd et al., 2004):
[Algorithm 1: iLQRsolve(𝒞θ(u), u_init), with u ∈ ℝm and 𝒞θ defined in Equation S17. Parameters: θ, γ.]
Rearranging, we can rewrite the KKT conditions in matrix form as:
These optimality conditions are satisfied by the solution of the optimization problem. Equation S28 implies that the solution p⋆ of the LQR problem will satisfy:
Computing this directly quickly becomes infeasible, as the size of the KKT matrix K grows with the time horizon; Equation S28 is therefore typically solved in linear time using a dynamic programming approach, as described in Appendix E.
F.2 Backpropagating through the LQR solver
Differentiating through an LQR solve boils down to differentiating through the backsolve in Equation S29. In the following, we denote the adjoint of a variable x by x̄, i.e. the gradient of the loss with respect to x. From Giles (2008), we know that the adjoint of the backsolve operation is given by:
We note that Equation S30 has the same form as Equation S29, which means we can compute the adjoint of the right-hand side by solving another LQR problem. Having solved for it, the adjoint of the KKT matrix is then obtained as an outer product with the primal solution, to get:
Collecting the gradients with respect to all the LQR parameters, we arrive at
Note that we have symmetrized the adjoint of Ct, which ensures that Ct remains symmetric after each gradient update. The antisymmetric part of Ct does not contribute to the LQR cost.
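The dense analogue of these adjoint rules can be checked numerically. For a generic backsolve p = K⁻¹y and a scalar loss L = cᵀp, the Giles (2008) reverse-mode rules are ȳ = K⁻ᵀp̄ and K̄ = −ȳpᵀ; the sketch below (all names illustrative; the structured LQR case replaces the dense solves with LQR passes) verifies them by finite differences.

```python
import numpy as np

rng = np.random.default_rng(5)
n = 5
K = rng.standard_normal((n, n)) + 3 * np.eye(n)   # well-conditioned system matrix
y = rng.standard_normal(n)
c = rng.standard_normal(n)                        # weights of a scalar loss L = c^T p

# Forward: p = K^{-1} y; for L = c^T p the incoming adjoint is p_bar = c.
p = np.linalg.solve(K, y)
p_bar = c

# Reverse-mode (adjoint) rules for a linear solve (Giles, 2008)
y_bar = np.linalg.solve(K.T, p_bar)               # another solve, with K^T
K_bar = -np.outer(y_bar, p)                       # outer product with the solution

# Finite-difference check of dL/dK and dL/dy
eps = 1e-6
K_fd = np.zeros_like(K)
for i in range(n):
    for j in range(n):
        dK = np.zeros_like(K); dK[i, j] = eps
        K_fd[i, j] = (c @ np.linalg.solve(K + dK, y)
                      - c @ np.linalg.solve(K - dK, y)) / (2 * eps)
y_fd = np.array([(c @ np.linalg.solve(K, y + eps * e)
                  - c @ np.linalg.solve(K, y - eps * e)) / (2 * eps)
                 for e in np.eye(n)])
```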
Finally, one subtlety arises from the fact that Equation S21 and Equation S22 are written as functions of τ in the general LQR setting. In the iLQR case, however, the LQR problem is local at each iteration, and δτ vanishes at convergence. Denoting by i the last iteration before declaring convergence, one can nevertheless write the problem as a function of the variable of interest τ⋆, using the change of variables δτ = τ⋆ − τ^i, subject to the dynamics constraints.
This implies that the values for C and c need to be adjusted accordingly, so as to reflect the change of variables from δτ during the optimization to the fixed point τ⋆ used to compute gradients. Note that at convergence we can use τ^i ≈ τ⋆, giving access to all the necessary variables to compute gradients with respect to θ.
G Details of experiment 1
The data in Section 3.1 was generated from an autonomous linear dynamical system with n = 8, m = 3, and no = 8, where no is the dimension of the observation space. All the models were fit using dynamics within the ground-truth model class, i.e. with linear dynamics, n = 8, m = 3, and no = 8. We optimized the model parameters with Adam, using (manually optimized) learning rates of for the free iLQR-VAE model, for autonomous iLQR-VAE and for LFADS, where k is the iteration number. We used GRU networks with 32 units to parametrize the LFADS encoders (one encoder for the initial condition and one for the inputs). Note that while all methods run in similar wallclock time in this example, this will ultimately be implementation and data-dependent.
In Figure S2, we show additional learning curves for the “forced autonomous” models; these show that, even for different initializations and trajectories through the loss landscape, the model consistently gets stuck in plateaus. This can be contrasted with the free-form iLQR-VAE models.
H Further details of Lorenz attractor
The chaotic Lorenz attractor consists of a three-dimensional state (𝓁1, 𝓁2, 𝓁3) evolving according to
d𝓁1/dt = σ(𝓁2 − 𝓁1), d𝓁2/dt = 𝓁1(ρ − 𝓁3) − 𝓁2, d𝓁3/dt = 𝓁1𝓁2 − β𝓁3. (S40)
For our example, we generated data by integrating Equation S40 over a long time period using a Runge-Kutta solver (RK4), followed by z-scoring and splitting the resulting state trajectory into 112 non-overlapping bouts (Figure 2A). We added Gaussian noise with a standard deviation of 0.1, and trained iLQR-VAE on this dataset (Figure 2B, bottom), fitting the data using GRU dynamics with n = 20 and m = 5.
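The data-generation step above can be sketched as follows, using the standard Lorenz parameters σ = 10, ρ = 28, β = 8/3 (an assumption — the text does not specify them) and an illustrative step size:

```python
import numpy as np

sigma, rho, beta = 10.0, 28.0, 8.0 / 3.0     # standard parameters (assumed)

def lorenz(s):
    l1, l2, l3 = s
    return np.array([sigma * (l2 - l1),
                     l1 * (rho - l3) - l2,
                     l1 * l2 - beta * l3])

def rk4_step(s, dt):
    k1 = lorenz(s)
    k2 = lorenz(s + 0.5 * dt * k1)
    k3 = lorenz(s + 0.5 * dt * k2)
    k4 = lorenz(s + dt * k3)
    return s + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)

dt, n_steps = 0.01, 11200
s = np.array([1.0, 1.0, 1.0])
traj = np.empty((n_steps, 3))
for t in range(n_steps):
    s = rk4_step(s, dt)
    traj[t] = s

traj = (traj - traj.mean(0)) / traj.std(0)   # z-score each state dimension
bouts = traj.reshape(112, 100, 3)            # 112 non-overlapping bouts
```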
The normalized k-step mean-squared error was defined as
NMSE(k) = Σ_t ‖ô_{t+k} − o_{t+k}‖² / Σ_t ‖o_{t+k} − ō‖²,
where ô_{t+k} is the prediction at time t + k, and ō the mean for this trial.
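Under this (reconstructed) definition, the metric can be sketched as below; `preds[t]` is assumed to hold the k-step-ahead prediction of `obs[t + k]`, so a perfect predictor scores 0 and predicting the trial mean scores 1.

```python
import numpy as np

def k_step_nmse(preds, obs, k):
    """Normalized k-step MSE: squared error of k-step-ahead predictions,
    normalized by the variance of the targets around the trial mean."""
    target = obs[k:]
    err = np.sum((preds - target) ** 2)
    base = np.sum((target - obs.mean(axis=0)) ** 2)
    return err / base

rng = np.random.default_rng(6)
obs = rng.standard_normal((50, 3))
perfect = k_step_nmse(obs[5:], obs, 5)                       # ideal predictor
baseline = k_step_nmse(np.tile(obs.mean(0), (45, 1)), obs, 5)  # mean predictor
```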
I Learning input-driven nonlinear dynamics
To bridge the gap between autonomous nonlinear dynamics (see Section 3.2) and real data, we evaluated iLQR-VAE on an input-driven nonlinear system, the Duffing oscillator. We generated Duffing trajectories that included a perturbation of the Duffing state half-way through. We then embedded those into the spiking activity of 200 neurons (see below). We found that iLQR-VAE could successfully learn the dynamics and infer the timing of the perturbations.
The dynamics of the (undamped, unforced) Duffing oscillator are given by
dx/dt = v, dv/dt = −αx − βx³. (S43)
To generate each training sample, we integrated Equation S43 from two different random initial conditions for 100 time steps each, using a Runge-Kutta solver (RK4) with dt = 0.03. Examples of such trajectories are shown in Figure S3A (top); note that each trajectory can be understood as the evolution of the system in state-space for a given energy level of the oscillator. We then concatenated those two trajectories to yield a single trajectory of 200 steps with a perturbation in the middle. We then linearly mapped the low-dimensional oscillator onto a 200-dimensional state, before passing it through the nonlinearity of Equation S7 to obtain a set of firing rates, which then gave rise to observations via a Poisson process (Figure S3B, top).
We generated 112 training and 112 testing trials in this way. We fit these data using iLQR-VAE with n = 20, m = 4, and found that it could successfully infer the latent trajectories (see Figure S3B, middle). Importantly, iLQR-VAE learned to fit most of the trajectories as an autonomously evolving dynamical system, and only used inputs to explain the sudden change in the oscillator’s energy level triggered by the perturbation (see Figure S3B, bottom). This shows that the model can successfully disentangle ongoing dynamics from external inputs, suggesting that it is well-suited for identifying input-driven dynamics in real data.
J Comparison of LFADS and iLQR-VAE on a toy input inference task
We used the LFADS implementation from https://github.com/google-research/computation-thru-dynamics/tree/master/lfads_tutorial, which we modified to include linear dynamics and Gaussian likelihoods. We then evaluated the quality of the input reconstruction by measuring how much input variance was captured by the models. We report this as the R2 from inferred to true inputs.
We used a generative model within the ground truth model class. For each dataset, we performed a hyper-parameter search to choose the best-performing encoder architecture and learning rate for LFADS.
Results of this experiment are summarized in Table S1. iLQR-VAE – which did not require any hyperparameter tuning for these examples – inferred inputs more accurately for all dataset sizes and trial lengths.
Our results suggest that LFADS’ performance improves with larger amounts of data. More surprisingly, LFADS also seems to perform better when the data is split into shorter trials. In particular, we found it difficult to fit LFADS on the single long trial, but the dynamics could be learned more accurately if this data was split into 10 trials of 100 steps. On the other hand, iLQR-VAE inferred inputs more accurately for longer trials. This is what we would expect if the model is well learnt, as longer trials contain more information to fit the inputs accurately.
One important distinction between the two methods, which partly explains LFADS’ lower R2, is the prior assumed over inputs (autoregressive for LFADS, Student for iLQR-VAE). In Figure S4, we show an example LFADS fit on one of the test examples of the S56×100 dataset. In this example, LFADS infers its largest input concurrently with the ground-truth input, but also infers small inputs when there are none in the ground truth. This has a significant impact on the R2 metric. Note however that this is not the only effect at play, as emphasized by the lower performance on the AR dataset. The impact of the choice of prior in iLQR-VAE is discussed further in Appendix K.
K Comparison of the Student and Gaussian priors
In this section, we compare the performance of the Gaussian and Student priors on two toy examples. The first consists of data generated by a linear dynamical system (n = 3, m = 3, no = 10, Gaussian likelihood) driven by autoregressive Gaussian inputs (close to the Gaussian prior). The second one uses the same system, but driven by sparse inputs (closer to the Student prior). We find that in the first example, both priors yield extremely similar results (see Figure S5(A-C)). Indeed, the Student prior learns a very high value of ν ∼ 20, thus becoming nearly Gaussian.
In the sparse input case, however, the Student prior allows the data to be fit considerably better. As we can see in Figure S5(D-F), the Gaussian prior learns a large variance to fit the sparse inputs, leading to higher baseline noise than in the true system.
As shown here, the Student prior offers a more flexible model, as the Gaussian case is recovered for large ν values. Note however that using the Gaussian prior ensures that the input term in the iLQR cost function is always convex in u, which can facilitate the optimization and allow iLQR to converge faster in some cases. Moreover, in the case of autonomous dynamics (e.g. the Lorenz attractor and Maze dataset), both priors will converge to the same solution.
L Further details of single trial analyses
L.1 Benchmarking against existing methods
To allow for direct comparison with benchmarks reported in Pei et al. (2021), we first used data provided by the Neural Latents Benchmark (NLB) challenge, available at https://gui.dandiarchive.org/#/dandiset/000128.
We used 1720 training trials and 510 validation trials, which were drawn randomly for each instantiation of the model to avoid overfitting to test data. The risk of overfitting to the dataset was lowered by the fact that iLQR-VAE requires very little hyperparameter optimization. For this experiment, we fitted iLQR-VAE to the neural activity using a model with MGU dynamics (n = 60), a Student prior over inputs (m = 15), and a Poisson likelihood (no = 182 neurons). We trained models on trials spanning all reach conditions and restricting data to a time window starting 250 ms before and ending 450 ms after movement onset, binned at 5ms. For regression to hand velocity, we introduced a lag of 100ms between neural activity and hand velocity. As the test data used in the NLB challenge is not publicly available, the results we reported were not computed on the exact same data split. However, the model performed highly consistently across random seeds, such that we expect iLQR-VAE’s performance to be directly comparable to the results from Pei et al. (2021). To fit these data, we ran iLQR-VAE on 168 CPUs for ∼ 6h, using a mini-batch size of 168 trials.
The co-smoothing metric used to assess how well the model fits the data is defined as the log-likelihood score
bits/spike = (ℒ(λ) − ℒ(λ̄)) / (n_sp log 2),
where the overall log-likelihood ℒ is the sum of the log-likelihoods evaluated at all time points and for all neurons, λ denotes the vector of inferred time-varying firing rates, λ̄_n is the mean firing rate for neuron n, and n_sp is the total number of spikes.
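A sketch of this bits/spike score as defined for the NLB challenge follows; the log(o!) terms of the Poisson log-likelihood cancel in the difference and are omitted, and the array shapes (trials × time × neurons) are assumptions.

```python
import numpy as np

def bits_per_spike(spikes, rates, mean_rates):
    """Log-likelihood improvement of inferred rates over a flat
    mean-rate prediction, in bits per spike.
    spikes, rates: (trials, time, neurons); mean_rates: (neurons,)."""
    def ll(lam):
        # Poisson log-likelihood up to the log(spikes!) constant,
        # which cancels in the difference below.
        return np.sum(spikes * np.log(lam) - lam)
    n_sp = spikes.sum()
    lam_bar = np.broadcast_to(mean_rates, spikes.shape)
    return (ll(rates) - ll(lam_bar)) / (n_sp * np.log(2))

rng = np.random.default_rng(7)
true_rates = rng.uniform(0.1, 2.0, size=(20, 50, 10))   # per-bin expected counts
spikes = rng.poisson(true_rates)
mean_rates = spikes.mean(axis=(0, 1))
score_true = bits_per_spike(spikes, true_rates, mean_rates)
score_flat = bits_per_spike(spikes, np.broadcast_to(mean_rates, spikes.shape),
                            mean_rates)
```

By construction, predicting the flat mean rates scores exactly 0, while the ground-truth rates score positively.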
L.2 Further analyses
A key feature of monkey M1 motor cortical recordings is the prevalence of rotational dynamics in the data (Churchland et al., 2012). These can be captured using jPCA, a method developed to find the subspace in which the dynamics are most rotational, which was recently generalized by Rutten et al. (2020). Here, we found that we could uncover clean rotational dynamics from the single-trial firing rates, similarly to Pandarinath et al. (2018).
M Further details of the continuous reaching task analysis
M.1 Details of the analyses
For our analyses of the primate data in Section 3.4, we considered the first 22 minutes of the recording session ‘indy_20160426’ from O’Doherty et al. (2018). We binned spikes at 25 ms resolution and considered all neurons with a firing rate of at least 2 Hz. Behavioural data took the form of the velocity of the hand of the monkey in the xy-plane, and were extracted as the first derivative of a cubic spline fitted to the position over time. We z-scored the hand velocity and shifted it by 120 ms, following Jensen et al. (2021).
To fit iLQR-VAE, the resulting dataset was divided into 336 non-overlapping pseudo-trials, of which a random half were used to fit the generative model and the other half were used as a held-out test dataset. We fitted a model with n = 50, m = 10, using the non-linear dynamics described in Equation S4. The latent state was then mapped onto both the kinematics and neural observations. We used a linear readout from latents to 2D kinematic variables, and a linear readout followed by a nonlinearity from latents to the firing rates of 130 neurons.
After fitting the iLQR-VAE to neural activity and behavior jointly, we then proceeded to infer u from neural activity alone. Next, we computed the kinematic reconstruction error on the test dataset as the fraction of variance captured in both x- and y-hand velocities.
Finally, we analyzed the inputs to the model after fitting, in relation to specific events in the task and the behaviour. We defined ‘movement onset’ after each target onset as the time at which the hand speed first exceeded 0.03 m s−1. We aligned the z-scored input u on each trial to target onset and movement onset separately for visualization purposes. We performed a similar analysis for hand speed and mean neural activity, which were also z-scored and aligned to target and movement onset for comparison with the control input.
M.2 Comparison of iLQR-VAE and bGPFA
As a further way of understanding the relative benefits and disadvantages of iLQR-VAE, we compared its performance with bGPFA, a fully Bayesian extension of GPFA (Yu et al., 2009) that enables the use of non-Gaussian likelihoods, scales to very large datasets, and was recently shown to outperform standard GPFA on this same continuous reaching dataset (Jensen et al., 2021). Importantly, bGPFA makes different assumptions from iLQR-VAE, as it places a smooth prior directly on the latents, with no explicit notion of dynamics. We fit both methods using 10 minutes of data (chunked into pseudo-trials for iLQR-VAE, and as a continuous trial for bGPFA). For iLQR-VAE, we then performed inference and retrained the posterior covariance on the first minute of data whilst fixing the generative parameters. We found that while both methods captured similar trends in the firing rates, bGPFA yielded smoother estimates, but iLQR-VAE captured larger modulations (consistent with the higher R2 when regressing from firing rates to hand velocity). Note that the firing rate estimates here are not as smooth as for the Maze dataset (c.f. Figure 4A), because iLQR-VAE was fit using a Gaussian prior over inputs with non-zero variance at all times, effectively implying an autoregressive prior on the latent trajectories and firing rates.
From Figure S7, one can notice that bGPFA struggles to capture larger variations in the firing rates. This suggests oversmoothing, which might explain why the method does not capture hand kinematics as well as iLQR-VAE (R2 = 0.6 for bGPFA vs. 0.76 for iLQR-VAE); this is indeed what we see in Figure S8.
N Link to Kalman filtering
The Linear Quadratic Regulator and the Kalman filter (Kalman, 1964) are algorithms designed for systems with linear dynamics and Gaussian noise. LQR finds the optimal feedback control law to minimize a cost 𝒞 in deterministic systems, while the Kalman filter yields an estimate of the state from observations corrupted with process and observation noise. It is well known that Kalman filtering and LQR are duals of one another, and they can be combined into LQG to yield an optimal control law from noisy observations. Here, we explore another link between LQR and Kalman smoothing, by showing how LQR can be used as a Kalman smoother. Moreover, in order to gain insights into the learning process of iLQR-VAE, we explore different procedures for learning the parameters of a Kalman filter.
Linear quadratic control as filtering
The Kalman smoother assumes dynamics of the form
z_{t+1} = A z_t + B_I w_t, o_t = C z_t + v_t, (S45)
with w ∼ 𝒩(0, I), v ∼ 𝒩(0, Σv), and the initial condition assumed to be generated by a Gaussian distribution with known parameters, z1 ∼ 𝒩(μ, Π).
On the other hand, LQR assumes the following fully deterministic dynamics:
z_{t+1} = A z_t + B u_t, (S47)
with z1 known exactly.
Note that in the iLQR-VAE framework we have thus far only considered cases where no observed external inputs were given. However, these can be straightforwardly included as an additional û term in Equation S47 and Equation S45.
The Kalman smoother’s objective is to minimize the expected mean squared error between the inferred latent state and the true state. As described in Aravkin et al. (2017), with linear dynamics and Gaussian noise, this becomes equivalent to minimizing the following objective w.r.t. z:
½ ‖z1 − μ‖²_{Π⁻¹} + ½ Σ_t ‖z_{t+1} − A z_t‖²_{(B_I B_Iᵀ)⁻¹} + ½ Σ_t ‖o_t − C z_t‖²_{Σv⁻¹}, (S49)
where the first two terms correspond to the prior over the initial condition and the smoothness of the trajectory, and the last term represents the likelihood of the observations. Interestingly, this can be related to the objective we minimize to find the posterior mean in iLQR-VAE (Equation 11):
The right-hand sides of Equation S49 and Equation S50 become identical when Σ0 = Π and Σu matches the process noise covariance. Note that the introduction of the B_I matrix in Equation S45 unties the two formulations slightly, by allowing for further mixing between the input channels that is not accounted for by the prior. In the examples we consider next, we therefore set B_I = I.
The above equations show how LQR can be used to solve the standard Kalman smoothing problem, with the key difference being that the optimization is performed over inputs u = {u0, …, uT−1} rather than directly over latent trajectories z = {z1, …, zT}. This is illustrated in Figure S9(A), where a Rauch-Tung-Striebel (RTS) smoother and LQR were run on the same set of 8-dimensional observations arising from an 8-dimensional linear dynamical system, and inferred the same latent trajectory given the ground-truth parameters. As we only use LQR to parametrize the mean of the posterior distribution, we trained the recognition model for 100 steps to obtain the uncertainty over the latents, which was very similar to the output of the RTS smoother.
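The equivalence between optimizing over latent trajectories and over inputs can be made concrete in a toy check (not from the original text; unit covariances, B_I = I, all names illustrative): the MAP trajectory of a small linear-Gaussian model is computed twice, once by least squares directly over the state trajectory z, and once over (z1, u) with z accumulated through the dynamics, and the two coincide.

```python
import numpy as np

rng = np.random.default_rng(8)
n, no, T = 2, 3, 20
A = np.array([[0.9, 0.2], [-0.2, 0.9]])
C = rng.standard_normal((no, n))
mu = np.zeros(n)
o = rng.standard_normal((T, no))            # arbitrary observations

# --- MAP over the state trajectory z (smoother objective, unit covariances):
# J(z) = 1/2|z_1 - mu|^2 + 1/2 sum|z_{t+1} - A z_t|^2 + 1/2 sum|o_t - C z_t|^2
rows, rhs = [], []
R0 = np.zeros((n, n * T)); R0[:, :n] = np.eye(n)
rows.append(R0); rhs.append(mu)                          # prior on z_1
for t in range(T - 1):                                   # dynamics penalties
    R = np.zeros((n, n * T))
    R[:, (t + 1) * n:(t + 2) * n] = np.eye(n)
    R[:, t * n:(t + 1) * n] = -A
    rows.append(R); rhs.append(np.zeros(n))
for t in range(T):                                       # observation penalties
    R = np.zeros((no, n * T))
    R[:, t * n:(t + 1) * n] = C
    rows.append(R); rhs.append(o[t])
z_map = np.linalg.lstsq(np.vstack(rows), np.concatenate(rhs),
                        rcond=None)[0].reshape(T, n)

# --- Same MAP, but over w = (z_1, u_1..u_{T-1}) with z_{t+1} = A z_t + u_t:
nw = n * T
M = np.zeros((T, n, nw))                                 # M[t] maps w -> z_t
M[0, :, :n] = np.eye(n)
for t in range(T - 1):
    M[t + 1] = A @ M[t]
    M[t + 1][:, n * (t + 1):n * (t + 2)] += np.eye(n)
rows2, rhs2 = [], []
rows2.append(M[0]); rhs2.append(mu)                      # prior on z_1
rows2.append(np.hstack([np.zeros((n * (T - 1), n)),      # prior on inputs u
                        np.eye(n * (T - 1))]))
rhs2.append(np.zeros(n * (T - 1)))
for t in range(T):
    rows2.append(C @ M[t]); rhs2.append(o[t])
w = np.linalg.lstsq(np.vstack(rows2), np.concatenate(rhs2), rcond=None)[0]
z_u = np.array([M[t] @ w for t in range(T)])
```

The two problems are related by an invertible change of variables (u_t = z_{t+1} − A z_t), so their unique minimizers describe the same trajectory.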
Learning a Kalman filter
We then proceeded to learn the parameters of the models using either iLQR-VAE, an Expectation-Maximization (EM) procedure, or direct minimization of the negative log likelihood of the data (Figure S9B-C).
Interestingly, the EM algorithm is closely related to iLQR-VAE, since the E-step finds the latent trajectories minimizing Equation S49, just as iLQR-VAE solves Equation S50 in an inner optimization loop. While there exists an analytical solution for the M-step in the case of the Kalman filter, this does not generalize to nonlinear dynamics and non-Gaussian noise. Therefore, we used a gradient descent procedure for the maximization step.
Both of these were performed using Adam with a learning rate of 0.02, and with initial parameters drawn from the same distributions. We see in Figure S9C that iLQR-VAE reaches a smaller NLL in considerably fewer iterations than gradient descent, which we hypothesize is due to the good preconditioning given by iLQR (discussed in Figure 1). Note however that the cost of one iLQR-VAE iteration is higher than the direct computation of Equation S49.
In this section, we have shown in a simple linear-quadratic example how iLQR-VAE performs filtering by inferring the process noise as inputs. While this is undoubtedly an unconventional approach, it becomes particularly valuable in cases where dynamics are non-linear and the noise non-Gaussian. Indeed, in such cases the problem of learning an estimator for the latent state is a very difficult one, typically solved using methods such as particle filtering or unscented Kalman filters (Doucet and Johansen, 2009; Wan et al., 2001). iLQR-VAE offers another way to solve this problem, with close links to the aforementioned approaches.
O Analysis of the inference gap
In order to evaluate the benefits of defining the recognition model implicitly through the generative parameters, we compared iLQR-VAE to a more standard sequential variational auto-encoder, using a bidirectional recurrent neural network as the recognition model. We generated data from the same system as in Figure 3, in the form of 76 trials of 100 time steps. We used the same generative model in both cases (linear dynamics with n = 3, m = 3, no = 10, Student prior), such that the only difference lay in the choice of recognition model. We compared the ELBO to a more accurate estimate of the log-likelihood, the Importance Weighted Autoencoder (IWAE) bound (Burda et al., 2015), computed as
ℒ_IWAE = E_{u(1), …, u(K) ∼ qϕ(u|o)} [ log (1/K) Σ_{k=1}^{K} pθ(o, u(k)) / qϕ(u(k)|o) ],
where we used Monte-Carlo sampling with K = 1000 samples to evaluate the expectation. This then allowed us to compute the inference gap (Cremer et al., 2018) of both models as ℒ_IWAE − ℒ_ELBO. As shown in Figure S10, iLQR-VAE has a smaller inference gap throughout training, leading to faster and more robust convergence. This confirms the intuition that keeping the recognition and generative models in sync throughout training reduces the inference gap.
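The relationship between the two bounds can be illustrated on a toy one-dimensional Gaussian model (made up for illustration; the deliberately mismatched recognition parameters are assumptions): for any fixed set of samples, the IWAE estimate dominates the ELBO estimate by Jensen's inequality.

```python
import numpy as np

rng = np.random.default_rng(9)
o = 1.3                                   # a single observation
K = 1000                                  # number of importance samples

def log_p_joint(u):                       # p(u) = N(0,1), p(o|u) = N(u,1)
    return -0.5 * (u**2 + (o - u)**2) - np.log(2 * np.pi)

# A (deliberately mismatched) Gaussian recognition model q(u|o)
q_mu, q_sig = 0.5 * o + 0.1, 0.8

u = q_mu + q_sig * rng.standard_normal(K)
log_q = (-0.5 * ((u - q_mu) / q_sig)**2
         - 0.5 * np.log(2 * np.pi) - np.log(q_sig))
log_w = log_p_joint(u) - log_q            # log importance weights

elbo = log_w.mean()                       # Monte-Carlo ELBO estimate
m = log_w.max()                           # stable log-mean-exp = IWAE estimate
iwae = m + np.log(np.mean(np.exp(log_w - m)))
```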
Footnotes
mmcs3{at}cam.ac.uk
c.kao{at}ucl.ac.uk
ktj21{at}cam.ac.uk
g.hennequin{at}eng.cam.ac.uk
↵1 Note that when m < n, u0 can only reach an m-dimensional subspace of initial conditions, which could be limiting. We can circumvent this problem by spreading u0 over multiple surrogate time bins before the start of the trial, i.e. introducing {u−n/m, …, u−2, u−1, u0} together with an appropriate dependence of fθ on t ≤ 0 in Equation 1, such that each of these surrogate inputs targets a different latent subspace with purely integrating (“sticking”) linear dynamics before t = 1.