Abstract
The predictive nature of the hippocampus is thought to be useful for memory-guided cognitive behaviors. Inspired by the reinforcement learning literature, this notion has been formalized as a predictive map called the successor representation (SR). The SR captures a number of observations about hippocampal activity. However, the algorithm does not provide a neural mechanism for how such representations arise. Here, we show the dynamics of a recurrent neural network naturally calculate the SR when the synaptic weights match the transition probability matrix. Interestingly, the predictive horizon can be flexibly modulated simply by changing the network gain. We derive simple, biologically plausible learning rules to learn the SR in a recurrent network. We test our model with realistic inputs and match hippocampal data recorded during random foraging. Taken together, our results suggest that the SR is more accessible in neural circuits than previously thought and can support a broad range of cognitive functions.
1. Introduction
To learn from the past, plan for the future, and form an understanding of our world, we require memories of personal experiences. These types of memories depend on the hippocampus for formation and recall [1, 2, 3], but an algorithmic and mechanistic understanding of memory formation and retrieval in this region remains elusive. The need to support planning and inference suggests that one of the key features of memory is the ability to predict possible outcomes [4, 5, 6, 7]. Consistent with this hypothesis, experimental work has shown that, across species and tasks, hippocampal activity is predictive of the future experience of an animal [8, 9, 10, 11, 12, 13, 14, 15]. Furthermore, theoretical work has found that models endowed with predictive objectives tend to resemble hippocampal activity [16, 17, 18, 19, 20, 21, 6, 22]. Thus, it is clear that predictive representations are an important aspect of hippocampal memory.
Inspired by work in the reinforcement learning (RL) field, these observations have been formalized by describing hippocampal activity as a predictive map under the successor representation (SR) algorithm [23, 24, 18]. Under this framework, an animal’s experience in the world is represented as a trajectory through some defined state space, and hippocampal activity predicts the future experience of an animal by integrating over the likely states that an animal will visit given its current state. This algorithm further explains how, in addition to episodic memory, the hippocampus may support relational reasoning and decision making [21, 25], consistent with differences in hippocampal representations in different tasks [26, 27]. The SR framework captures many experimental observations of neural activity, leading to a proposed computational function for the hippocampus [18].
While the SR algorithm convincingly argues for a computational function of the hippocampus, it is unclear what biological mechanisms might compute the SR in a neural circuit. Thus, several relevant questions remain that are difficult to probe with the current algorithm. What kind of neural architecture should one expect in a region that can support this computation? Are there distinct forms of plasticity and neuromodulation needed in this system? What is the structure of hippocampal inputs to be expected? A biologically plausible model can explore these questions and provide insight into both mechanism and function [28, 29, 30].
In other systems, it has been possible to derive biological mechanisms with the goal of achieving a particular network function or property [31, 32, 33, 34, 35, 36, 37, 38]. Key to many of these models is the constraint that learning rules at any given neuron can only use information local to that neuron. A promising direction towards such a neural model of the SR is to use the dynamics of a recurrent network to perform SR computations [39, 40]. However, this idea has not been tied to neural learning rules that support its operation and allow for testing of specific hypotheses.
Here, we show that an RNN with local learning rules and an adaptive learning rate exactly calculates the SR at steady state. We test our model with realistic inputs and make comparisons to neural data. In addition, we compare our results to the standard SR algorithm with respect to the speed of learning and the learned representations in cases where multiple solutions exist. Our work provides a mechanistic account for an algorithm that has been frequently connected to the hippocampus, but could only be interpreted at an algorithmic level. This network-level perspective allows us to make specific predictions about hippocampal mechanisms and activity.
2. Results
2.1. The successor representation
The SR algorithm described in Stachenfeld et al. [18] first discretizes the environment explored by an animal (whether a physical or abstract space) into a set of n states that the animal transitions through over time (Figure 1a). The animal’s behavior can then be thought of as a Markov chain with a corresponding n × n transition probability matrix T (Figure 1b). T gives the probability that the animal transitions to a state s′ from the state s in one time step: Tji = P(s′ = i | s = j). The SR matrix is defined as
$$M = \sum_{t=0}^{\infty} \gamma^t T^t = (I - \gamma T)^{-1} \tag{1}$$
Here, γ ∈ (0,1) is a temporal discount factor. Mji can be seen as a measure of the occupancy of state i over time if the animal starts at state j, with γ controlling how much to discount time steps in the future (Figure 1c). The SR of state j is the jth row of M and represents the states that an animal is likely to transition to from state j. Stachenfeld et al. [18] demonstrate that, if one assumes each state drives a single neuron, the SR of j resembles the population activity of hippocampal neurons when the animal is at state j (Figure 1d). They also show that the ith column of M resembles the place field (activity as a function of state) of a hippocampal neuron representing state i (Figure 1e). In addition, the ith column of M shows which states are likely to lead to state i.
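As a small numerical illustration, the sketch below builds a transition matrix for a biased walk on a circular track and computes the SR, assuming the standard definition of the SR as the discounted sum of powers of T, whose closed form is (I − γT)⁻¹. The helper names, the track size, and the forward bias of 0.7 are illustrative choices, not values taken from the text.

```python
import numpy as np

def circular_track_T(n_states, p_forward=0.7):
    """Row-stochastic T for a biased walk on a ring: T[j, i] = P(s' = i | s = j)."""
    T = np.zeros((n_states, n_states))
    for j in range(n_states):
        T[j, (j + 1) % n_states] += p_forward
        T[j, (j - 1) % n_states] += 1.0 - p_forward
    return T

def successor_representation(T, gamma):
    """Closed form of the discounted sum of powers of T: (I - gamma * T)^(-1)."""
    return np.linalg.inv(np.eye(T.shape[0]) - gamma * T)

T = circular_track_T(n_states=8)   # illustrative environment (assumption)
gamma = 0.5
M = successor_representation(T, gamma)

# Truncated discounted sum, for comparison with the closed form.
M_series = sum(gamma**t * np.linalg.matrix_power(T, t) for t in range(100))
print("max difference between closed form and truncated sum:",
      np.max(np.abs(M - M_series)))
```

Each row of M sums to 1/(1 − γ) because every power of a row-stochastic T is itself row-stochastic, which gives a quick sanity check on any computed SR.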
2.2. Recurrent neural network computes SR at steady state
We begin by drawing connections between the SR algorithm [18] and an analogous neural network architecture. The input to the network encodes the current state of the animal and is represented by a layer of input neurons (Figure 1fg). These neurons feed into the rest of the network that computes the SR (Figure 1fg). The SR is then read out by a layer of output neurons so that downstream systems receive a prediction of the upcoming states (Figure 1fg). We will first model the inputs ϕ as one-hot encodings of the current state of the animal (Figure 1fg). That is, each input neuron represents a unique state and is connected one-to-one to a hidden neuron.
We first consider an architecture in which a recurrent neural network (RNN) is used to compute the SR (Figure 1f). Let us assume that the T matrix is encoded in the synaptic weights of the RNN. In this case, the steady state activity of the network in response to input ϕ retrieves a row of the SR matrix, M⊤ϕ (Figure 1f, Supplementary Notes 1). Intuitively, this is because each recurrent iteration of the RNN progresses the prediction by one transition. In other words, the tth recurrent iteration raises T to the tth power as in equation 1. To formally derive this result, we start by defining the dynamics of our RNN with classical rate network equations [41]. At time t, the firing rate x(t) of N neurons given each neuron’s input ϕ(t) follows the discrete-time dynamics (assuming a step size Δt = 1)
$$x(t + 1) = \gamma J f(x(t)) + \phi(t) \tag{2}$$
Here, γ scales the recurrent activity and is a constant factor for all neurons. The synaptic weight matrix is defined such that Jij is the synaptic weight from neuron j to neuron i. Notably, this notation is transposed from what is used in RL literature, where conventions have the first index as the starting state. Generally, f is some nonlinear function in equation 2. For now, we will consider f to be the identity function, rendering this equation linear. Under this assumption, we can solve for the steady state activity xSS as
$$x_{SS} = (I - \gamma J)^{-1} \phi \tag{3}$$
Equivalence between equation 1 and equation 3 is clearly reached when J = T⊤ [40, 39]. Thus, if the network can learn T in its synaptic weight matrix, it will exactly compute the SR.
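The equivalence can be checked numerically. The sketch below assumes linear dynamics of the form x(t + 1) = γJx(t) + ϕ with a constant one-hot input, and takes the SR to be the discounted sum of powers of T; the random transition matrix and the function name `steady_state` are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 6
T = rng.random((n, n))
T /= T.sum(axis=1, keepdims=True)   # rows sum to 1: T[j, i] = P(s' = i | s = j)

gamma = 0.5
J = T.T                              # J[i, j]: weight from neuron j to neuron i
M = np.linalg.inv(np.eye(n) - gamma * T)

def steady_state(J, phi, gamma, n_iter=200):
    """Iterate the assumed linear dynamics x <- gamma * J x + phi until convergence."""
    x = np.zeros(len(phi))
    for _ in range(n_iter):
        x = gamma * (J @ x) + phi
    return x

j = 2                                # current state of the animal (arbitrary)
phi = np.eye(n)[j]                   # one-hot input
x_ss = steady_state(J, phi, gamma)

print("max deviation from row j of M:", np.max(np.abs(x_ss - M[j])))
```

With J = T⊤, the fixed point is (I − γT⊤)⁻¹ϕ = M⊤ϕ, which for a one-hot ϕ at state j is the jth row of M.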
A benefit of this scheme is that γ is not encoded in the synaptic weights. Thus, γ can be a flexibly modulated gain factor (see, for example, Sompolinsky et al. [42]) allowing the system to retrieve successor representations of varying predictive strengths. We will refer to the γ used during learning of the SR as the baseline γ, or γB.
We next consider what is needed in a learning rule such that J approximates T⊤. In order to learn a transition probability matrix, a learning rule must associate states that occur sequentially and normalize the synaptic weights into a valid probability distribution. We derive a learning rule that addresses both requirements (Figure 1h, Supplementary Notes 2),
$$\Delta J_{ij} = \eta \left[ x_i(t)\, x_j(t-1) - x_j(t-1) \sum_k J_{ik}\, x_k(t-1) \right] \tag{4}$$
where η is the learning rate. The first term in equation 4 is a temporally asymmetric potentiation term that counts states that occur in sequence. This is similar to spike-timing dependent plasticity, or STDP [43, 8, 44]. The second term in equation 4 normalizes the synapses into a valid transition probability matrix, such that each column of J = T⊤ sums to 1.
Crucially, this update rule (equation 4) uses information local to each neuron (Figure 1h). We show that, in the asymptotic limit, the update rule extracts information about the inputs ϕ and learns T exactly despite having access only to neural activity x (Supplementary Notes 3). We will refer to an RNN using equation 4 as the RNN-Successor, or RNN-S. Combined with recurrent dynamics (equation 3), RNN-S computes the SR exactly (Figure 1h).
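A minimal sketch of this update on a biased circular-track walk is given below. It uses a rule with a pre→post potentiation term x_i(t)x_j(t − 1) and a normalization term x_j(t − 1)Σ_k J_ik x_k(t − 1). As a simplifying assumption, the activity x is taken equal to the one-hot input (that is, the recurrent gain is set to zero during learning), so the example shows only how the two terms drive J toward T⊤ and does not reproduce the full recurrent model. All function names and parameter values are illustrative.

```python
import numpy as np

def ring_transition_matrix(n_states, p_forward):
    """T[j, i] = P(s' = i | s = j) for a biased walk on a ring."""
    T = np.zeros((n_states, n_states))
    for j in range(n_states):
        T[j, (j + 1) % n_states] += p_forward
        T[j, (j - 1) % n_states] += 1.0 - p_forward
    return T

def biased_ring_walk(n_states, n_steps, p_forward, rng):
    """Sequence of visited states, starting at state 0."""
    steps = np.where(rng.random(n_steps) < p_forward, 1, -1)
    return np.concatenate([[0], np.cumsum(steps)]) % n_states

def rnn_s_update(J, x_prev, x_curr, eta):
    """dJ[i, j] = eta * (x_curr[i] * x_prev[j] - x_prev[j] * sum_k J[i, k] * x_prev[k])."""
    potentiation = np.outer(x_curr, x_prev)
    normalization = np.outer(J @ x_prev, x_prev)
    return J + eta * (potentiation - normalization)

rng = np.random.default_rng(0)
n_states, p_forward, eta = 6, 0.7, 0.005      # illustrative values (assumption)
states = biased_ring_walk(n_states, 60000, p_forward, rng)
inputs = np.eye(n_states)                      # one-hot activity, x = phi (assumption)

J = np.zeros((n_states, n_states))
for s_prev, s_curr in zip(states[:-1], states[1:]):
    J = rnn_s_update(J, inputs[s_prev], inputs[s_curr], eta)

T = ring_transition_matrix(n_states, p_forward)
max_error = np.max(np.abs(J - T.T))
print("column sums of J:", J.sum(axis=0))
print("max |J - T^T|:", max_error)
```

With a static learning rate, each column of J behaves as an exponentially weighted running average of the observed next states, so it fluctuates around the corresponding row of T rather than settling on it exactly.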
As an alternative to the RNN-S model, we consider the conditions necessary for a feedforward neural network to compute the SR. Under this architecture, the M matrix must be encoded in the weights from the input neurons to the hidden layer neurons (Figure 1g). This can be achieved by updating the synaptic weights with a temporal difference (TD) learning rule, the standard update used to learn the SR in the usual algorithm. Although the TD update learns the SR, it requires information about multiple input layer neurons to make updates for the synapse from input neuron j to output neuron i (Figure 1i). Thus, it is useful to explore other possible mechanisms that are simpler to compute locally. We refer to the model described in Figure 1ih as the feedforward-TD (FF-TD) model.
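For comparison, the standard tabular TD update for the SR can be sketched as follows. This is the textbook form of the update with one-hot states, not necessarily the exact implementation used for the FF-TD model; the walk, learning rate, and function names are illustrative assumptions.

```python
import numpy as np

def ring_transition_matrix(n_states, p_forward):
    """T[j, i] = P(s' = i | s = j) for a biased walk on a ring."""
    T = np.zeros((n_states, n_states))
    for j in range(n_states):
        T[j, (j + 1) % n_states] += p_forward
        T[j, (j - 1) % n_states] += 1.0 - p_forward
    return T

def biased_ring_walk(n_states, n_steps, p_forward, rng):
    steps = np.where(rng.random(n_steps) < p_forward, 1, -1)
    return np.concatenate([[0], np.cumsum(steps)]) % n_states

def td_learn_sr(states, n_states, gamma, eta):
    """Tabular TD(0) estimate of the SR matrix from a state sequence."""
    M_hat = np.zeros((n_states, n_states))
    one_hot = np.eye(n_states)
    for s, s_next in zip(states[:-1], states[1:]):
        # TD error: observed occupancy plus discounted prediction from the
        # next state, minus the current prediction for state s.
        td_error = one_hot[s] + gamma * M_hat[s_next] - M_hat[s]
        M_hat[s] += eta * td_error
    return M_hat

rng = np.random.default_rng(0)
n_states, p_forward, gamma = 6, 0.7, 0.5       # illustrative values (assumption)
states = biased_ring_walk(n_states, 50000, p_forward, rng)
M_hat = td_learn_sr(states, n_states, gamma, eta=0.01)

T = ring_transition_matrix(n_states, p_forward)
M_true = np.linalg.inv(np.eye(n_states) - gamma * T)
td_max_error = np.max(np.abs(M_hat - M_true))
print("max |M_hat - M|:", td_max_error)
```

Note that the update to row s depends on the prediction stored for the next state s′, which is the non-locality referred to above: the change at one synapse requires information carried by other input neurons.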
2.3. Evaluating SR learning by biologically plausible learning rules
To evaluate the effectiveness of the RNN-S learning rule, we tested its accuracy in learning the SR matrix for random walks. Specifically, we simulated random walks with different transition biases in a 1D circular track environment (Figure 2a). The RNN-S can learn the SR for these random walks (Figure 2b).
Because equivalence is only reached in the asymptotic limit of learning (i.e., ΔJ → 0), our RNN-S model learns the SR slowly. In contrast, animals are thought to be able to learn the structure of an environment quickly [45], and neural representations in an environment can also develop quickly [46, 47, 48]. To remedy this, we introduce a dynamic learning rate that allows for faster normalization of the synaptic weight matrix, similar to the formula for calculating a moving average (Supplementary Notes 4). For each neuron, suppose that a trace n of its recent activity is maintained with some time constant λ ∈ (0,1],
$$n(t) = \sum_{t' < t} \lambda^{t - t'} x(t').$$
If the learning rate of the outgoing synapses from each neuron j is inversely proportional to nj, the update equation quickly normalizes the synapses to maintain a valid transition probability matrix (Supplementary Notes 4). We refer to this as an adaptive learning rate and contrast it with the previous static learning rate. We consider the setting where λ = 1, so the learning rate monotonically decreases over time (Figure 2c). In general, however, the learning rate could increase or decrease over time if λ < 1 (Figure 2c), and n could be reset, allowing for rapid learning. Our learning rule with the adaptive learning rate is the same as in equation 4, with the exception that the learning rate η is replaced by a term inversely proportional to nj for synapses J∗j. This learning rule still relies only on information local to the neuron as in Figure 1i.
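The effect of the adaptive learning rate can be sketched under two assumptions that are not taken from the text: the trace is an exponentially weighted sum of past activity, updated as n ← λ(n + x(t − 1)), and the activity x equals the one-hot input. With λ = 1 the trace is a visit count, and the rule reduces to an exact running average of observed transitions.

```python
import numpy as np

def biased_ring_walk(n_states, n_steps, p_forward, rng):
    steps = np.where(rng.random(n_steps) < p_forward, 1, -1)
    return np.concatenate([[0], np.cumsum(steps)]) % n_states

def learn_adaptive(states, n_states, lam=1.0):
    """Two-term RNN-S style update with a per-presynaptic-neuron learning rate 1 / n_j."""
    J = np.zeros((n_states, n_states))
    n_trace = np.zeros(n_states)
    inputs = np.eye(n_states)                  # one-hot activity (assumption)
    for s_prev, s_curr in zip(states[:-1], states[1:]):
        x_prev, x_curr = inputs[s_prev], inputs[s_curr]
        n_trace = lam * (n_trace + x_prev)     # assumed form of the activity trace
        # Learning rate for outgoing synapses J[:, j]; zero where the trace is empty.
        eta = np.divide(1.0, n_trace, out=np.zeros(n_states), where=n_trace > 0)
        # (potentiation - normalization), scaled column-wise by eta_j.
        J += np.outer(x_curr - J @ x_prev, x_prev) * eta[np.newaxis, :]
    return J

rng = np.random.default_rng(0)
n_states = 6
states = biased_ring_walk(n_states, 5000, 0.7, rng)
J_adaptive = learn_adaptive(states, n_states, lam=1.0)

# Empirical transition frequencies, indexed as [next state, previous state].
counts = np.zeros((n_states, n_states))
np.add.at(counts, (states[1:], states[:-1]), 1.0)
empirical = counts / counts.sum(axis=0, keepdims=True)
print("max |J - empirical T^T|:", np.max(np.abs(J_adaptive - empirical)))
```

In this setting each column of J is a valid probability distribution immediately after the first visit to the corresponding state, which is one way to see why normalization is faster than with a static learning rate.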
The RNN-S with an adaptive learning rate normalizes the synapses more quickly than a network with a static learning rate (Figure 2d, Figure S2a) and learns T faster (Figure 2e, Figure S2b). The RNN-S with a static learning rate exhibits more of a tradeoff between normalizing synapses quickly (Figure 2d, Figure S2a) and learning M accurately (Figure 2e, Figure S2b). However, both versions of the RNN-S estimate M more quickly than the FF-TD model (Figure 2f, Figure S2c).
Place fields can form quickly, but over time the place fields may skew if transition statistics are consistently biased [18, 46, 47, 48]. The adaptive learning rate recapitulates both of these effects, which are thought to be caused by fast and slow learning processes, respectively. A low learning rate can capture the biasing of place fields, which develops over many repeated experiences. This is seen in the RNN-S with a static learning rate (Figure 2g). However, a high learning rate is needed for hippocampal place cells to develop sizeable place fields in one-shot. Both these effects of slow and fast learning can be seen in the neural activity of an example RNN-S neuron with an adaptive learning rate (Figure 2h). After the first lap, a sizeable field is induced in a one-shot manner, centered at the cell’s preferred location. In subsequent laps, the place field slowly distorts to reflect the bias of the transition statistics (Figure 2h). The model is able to capture these learning effects because the adaptive learning rate transitions between high and low learning rates, unlike the static version (Figure 2i).
Thus far, we have assumed that the RNN-S learning rule uses pre→post activity over two neighboring time steps (equation 4). A more realistic framing is that a convolution with a plasticity kernel determines the weight change at any synapse. We tested how this affects our model and what range of plasticity kernels best supports the estimation of the SR. We do this by replacing the pre→post potentiation term in equation 4 with a convolution:
In the above equation, the full kernel K is split into a pre→post kernel (K+) and a post→pre kernel (K−). K+ and K− are parameterized as independent exponential functions, A exp(−t/τ).
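One way such a kernel-weighted potentiation term could be computed is sketched below. The exact expression is an assumption: each side of the kernel is an exponential A·exp(−s/τ) over integer lags s ≥ 1, with time measured in simulation steps instead of milliseconds, and the function names are hypothetical.

```python
import numpy as np

def exp_kernel(amplitude, tau, max_lag):
    """Exponential kernel A * exp(-s / tau) evaluated at lags s = 1..max_lag."""
    lags = np.arange(1, max_lag + 1)
    return amplitude * np.exp(-lags / tau)

def kernel_potentiation(X, t, k_plus, k_minus):
    """Kernel-weighted potentiation at time t from activity history X (time x neurons).

    J[i, j] is the weight from presynaptic neuron j to postsynaptic neuron i.
    pre->post: presynaptic x_j(t - s) followed by postsynaptic x_i(t), weight k_plus[s - 1]
    post->pre: postsynaptic x_i(t - s) followed by presynaptic x_j(t), weight k_minus[s - 1]
    """
    n = X.shape[1]
    dJ = np.zeros((n, n))
    for s in range(1, len(k_plus) + 1):
        if t - s < 0:
            break
        dJ += k_plus[s - 1] * np.outer(X[t], X[t - s])
        dJ += k_minus[s - 1] * np.outer(X[t - s], X[t])
    return dJ

# Toy activity: the animal visits states 0, 1, 2 in order (one-hot activity).
X = np.eye(3)
k_plus = exp_kernel(amplitude=1.0, tau=1.0, max_lag=2)    # STDP-like pre->post side
k_minus = exp_kernel(amplitude=0.0, tau=1.0, max_lag=2)   # no post->pre potentiation
dJ = kernel_potentiation(X, t=2, k_plus=k_plus, k_minus=k_minus)
print(dJ)
```

A negative amplitude on either side would model depression in place of potentiation, which is how the sign of each side could be varied in a grid search.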
To systematically explore the space of plasticity kernels that can be used to learn the SR, we performed a grid search over the sign and the time constants of the pre→post and post→pre sides of the plasticity kernels. Plasticity kernels that are STDP-like are more effective than others, although plasticity kernels with slight post→pre potentiation work as well (Figure 2j). The network is sensitive to the time constant and tends to find solutions for time constants around a few hundred milliseconds (Figure 2jk). Our robustness analysis indicates the timescale of a plasticity rule in such a circuit may be longer than expected by standard STDP, but within the timescale of changes in behavioral states. We note that this also contrasts with behavioral timescale plasticity [48], which integrates over a window that is several seconds long. Finally, we see that even plasticity kernels with slightly different time constants may give a result that is SR-like, even if they do not estimate the SR exactly (Figure 2j).
2.4. RNN-S can compute the SR with arbitrary γR under a stable regime of γB
We next investigate how robust the RNN-S model is to the value of γ. Typically, for purposes of fitting neural data or for RL simulations, γ will take on values as high as 0.9 [18, 49]. However, previous work that used RNN models reported that recurrent dynamics become unstable if the gain γ exceeds a critical value [42, 45]. This could be problematic as we show analytically that the RNN-S update rule is effective only when the network dynamics are stable and do not have non-normal amplification (Supplementary Notes 2). If these conditions are not satisfied during learning, the update rule no longer optimizes for fitting the SR and the learned weight matrix will be incorrect.
We first test how the value of γB, the gain of the network during learning, affects the RNN-S dynamics. The dynamics become unstable when γB exceeds 0.6 (Figure S3a-e). Specifically, the eigenvalues of the synaptic weight matrix exceed the critical threshold for stability when γB > 0.6 (Figure 3a, “Linear”). As expected from our analytical results, the stability of the network is tied to the network’s ability to estimate M. RNN-S cannot estimate M well when γB > 0.6 (Figure 3b, “Linear”). We explored two strategies to enable RNN-S to learn at high γ.
One way to tame this instability is to add a saturating nonlinearity into the dynamics of the network. Instead of assuming the network dynamics are fully linear (f is the identity function in equation 2), we add a hyperbolic tangent into the dynamics equation. This extends the stable regime of the network: the eigenvalues do not exceed the critical threshold until γB > 0.8 (Figure 3a). Similar to the linear case, the network with nonlinear dynamics fits M well until the critical threshold for stability (Figure 3b). These differences are clear visually as well. While the linear network does not estimate M well for γB = 0.8 (Figure 3b), the estimate of the nonlinear network (Figure 3c) is a closer match to the true M (Figure 3d). However, there is a tradeoff between the stabilizing effect of the nonlinearity and the potential loss of accuracy in calculating M with a nonlinearity (Figure S3h).
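The stabilizing effect of a saturating nonlinearity can be illustrated with a toy example. The sketch assumes dynamics of the form x(t + 1) = γJf(x(t)) + ϕ and uses a deliberately overscaled, hypothetical weight matrix so that the linear network is unstable; it is not a reproduction of the learned weights in the model.

```python
import numpy as np

def ring_transition_matrix(n_states, p_forward):
    T = np.zeros((n_states, n_states))
    for j in range(n_states):
        T[j, (j + 1) % n_states] += p_forward
        T[j, (j - 1) % n_states] += 1.0 - p_forward
    return T

def run_dynamics(J, phi, gamma, f, n_iter=100):
    """Iterate the assumed dynamics x <- gamma * J f(x) + phi."""
    x = np.zeros(len(phi))
    for _ in range(n_iter):
        x = gamma * (J @ f(x)) + phi
    return x

n = 6
T = ring_transition_matrix(n, p_forward=0.7)
J = 2.0 * T.T            # hypothetical overscaled weights: spectral radius 2
gamma = 0.8
phi = np.eye(n)[0]

# Linear dynamics are stable only if the spectral radius of gamma * J is below 1.
spectral_radius = np.max(np.abs(np.linalg.eigvals(gamma * J)))

x_linear = run_dynamics(J, phi, gamma, f=lambda x: x)
x_tanh = run_dynamics(J, phi, gamma, f=np.tanh)

print("spectral radius of gamma * J:", spectral_radius)
print("linear activity norm:", np.linalg.norm(x_linear))
print("tanh activity norm:", np.linalg.norm(x_tanh))
```

Because tanh is bounded by 1 in magnitude, the recurrent drive to each neuron cannot exceed γ times the summed absolute weights onto it, so activity stays finite even when the linearized network would diverge. The cost is that the fixed point no longer equals the linear solution, in line with the tradeoff noted above.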
We explore an alternative strategy for computing M with arbitrarily high γ in the range 0 ≤ γ < 1. We have thus far pushed the limits of the model in learning the SR for different γB. However, an advantage of our recurrent architecture is that γ is a global gain modulated independently of the synaptic weights. Thus, an alternative strategy for computing M with high γ is to consider two distinct modes that the network can operate under. First, there is a learning phase in which the plasticity mechanism actively learns the structure of the environment and the model is in a stable regime (i.e., γB is small). Separately, there is a retrieval phase during which the gain γR of the network can be flexibly modulated. By changing the gain, the network can compute the SR with arbitrary prediction horizons, without any changes to the synaptic weights. We show the effectiveness of separate network phases by simulating a 1D walk where the learning phase uses a small γB (Figure 3e). Halfway through the walk, the animal enters a retrieval mode and accurately computes the SR with higher γR (Figure 3e).
Under this scheme, the model can compute the SR for any γ < 1 (Figures S3f-h). The separation of learning and retrieval phases stabilizes neural dynamics and allows flexible tuning of predictive power depending on task context.
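The retrieval phase can be sketched as follows, assuming the weights have already converged to J = T⊤ and the steady state of the linear dynamics is solved directly. The function name `retrieve_sr` and the example gains are illustrative.

```python
import numpy as np

def ring_transition_matrix(n_states, p_forward):
    T = np.zeros((n_states, n_states))
    for j in range(n_states):
        T[j, (j + 1) % n_states] += p_forward
        T[j, (j - 1) % n_states] += 1.0 - p_forward
    return T

def retrieve_sr(J, phi, gamma_r):
    """Steady state of x = gamma_r * J x + phi, solved as a linear system."""
    return np.linalg.solve(np.eye(J.shape[0]) - gamma_r * J, phi)

n = 10
T = ring_transition_matrix(n, p_forward=0.7)
J = T.T                       # weights fixed after learning (assumption: J = T^T)
phi = np.eye(n)[0]            # animal at state 0

x_low = retrieve_sr(J, phi, gamma_r=0.4)    # short predictive horizon
x_high = retrieve_sr(J, phi, gamma_r=0.9)   # long predictive horizon

# Fraction of the retrieved representation assigned to the current state.
print("gamma_R = 0.4:", x_low[0] / x_low.sum())
print("gamma_R = 0.9:", x_high[0] / x_high.sum())
```

The same weight matrix yields both representations; only the gain changes. A larger γR spreads the representation over states further along the expected trajectory.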
2.5. RNN-S can be generalized to more complex inputs with successor features
We wondered how RNN-S performs given more biologically realistic inputs. We have so far assumed that an external process has discretized the environment into uncorrelated states so that each possible state is represented by a unique input neuron. In other words, the inputs ϕ are one-hot vectors. However, inputs into the hippocampus are expected to be continuous and heterogeneous, with states encoded by overlapping sets of neurons [50]. When inputs are not one-hot, there is not always a canonical ground-truth T matrix to fit and the predictive representations are referred to as successor features [49, 51]. In this setting, the performance of a model estimating successor features is evaluated by the temporal difference (TD) loss function.
Using the RNN-S model and update rule (equation 4), we explore more realistic inputs ϕ and refer to ϕ as “input features” for consistency with the successor feature literature. We vary the sparsity and spatial correlation of the input features (Figure 4a). As before (Figure 3h), the network will operate in separate learning and retrieval modes, where γB is below the critical value for stability. Under these conditions, the update rule will learn
$$J = R_{\phi\phi}(-1)\, R_{\phi\phi}(0)^{-1}$$
at steady state, where Rϕϕ(τ) is the correlation matrix of ϕ with time lag τ (Supplementary Notes 3). Thus, the RNN-S update rule has the effect of normalizing the input feature via a decorrelative factor (Rϕϕ(0)⁻¹) and mapping the normalized input to the feature expected at the next time step in a STDP-like manner (Rϕϕ(−1)). This interpretation generalizes the result that J = T⊤ in the one-hot encoding case (Supplementary Notes 3).
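The one-hot special case of this steady-state solution can be checked numerically. The sketch assumes the lagged correlation is defined as Rϕϕ(τ) = ⟨ϕ(t)ϕ(t + τ)⊤⟩, so that Rϕϕ(−1) pairs the current feature with the preceding one, and estimates both matrices as sample averages over a simulated walk. Function names and walk parameters are illustrative.

```python
import numpy as np

def ring_transition_matrix(n_states, p_forward):
    T = np.zeros((n_states, n_states))
    for j in range(n_states):
        T[j, (j + 1) % n_states] += p_forward
        T[j, (j - 1) % n_states] += 1.0 - p_forward
    return T

def biased_ring_walk(n_states, n_steps, p_forward, rng):
    steps = np.where(rng.random(n_steps) < p_forward, 1, -1)
    return np.concatenate([[0], np.cumsum(steps)]) % n_states

def lagged_correlations(Phi):
    """Phi has shape (time, features).

    Returns sample estimates of R(-1) = <phi(t) phi(t-1)^T> and
    R(0) = <phi(t-1) phi(t-1)^T>, both averaged over the same time points.
    """
    curr, prev = Phi[1:], Phi[:-1]
    R_m1 = curr.T @ prev / len(prev)
    R_0 = prev.T @ prev / len(prev)
    return R_m1, R_0

rng = np.random.default_rng(0)
n_states, p_forward = 6, 0.7
states = biased_ring_walk(n_states, 20000, p_forward, rng)
Phi = np.eye(n_states)[states]            # one-hot features along the trajectory

R_m1, R_0 = lagged_correlations(Phi)
J_est = R_m1 @ np.linalg.inv(R_0)

T = ring_transition_matrix(n_states, p_forward)
print("max |J_est - T^T|:", np.max(np.abs(J_est - T.T)))
```

For one-hot features, R(0) is diagonal with the state occupancies on the diagonal, so multiplying by its inverse simply divides each column of transition counts by the number of visits to the preceding state. For overlapping features, R(0) has off-diagonal entries and its inverse mixes synapses, which is the decorrelative effect described above.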
We wanted to further explore the function of the normalization term. In the one-hot case, it operates over each synapse independently and makes a probability distribution. With more realistic inputs, it operates over a set of synapses and has a decorrelative effect. We first ask how the decorrelative term changes over learning of realistic inputs. We compare the mean value of the STDP term of the update (xi(t)xj(t – 1)) to the normalization term of the update (xj(t – 1) ∑k Jik xk(t – 1)) during a sample walk (Figure 4b). The RNN-S learning rule has stronger potentiating effects in the beginning of the walk. As the model learns more of the environment and converges on the correct transition structure, the strength of the normalization term balances out the potentiation term. It may be that the normalization term is particularly important in maintaining this balance as inputs become more densely encoded. We test this hypothesis by using a normalization term that operates on each synapse independently (similar to Oja’s Rule, [52], Supplementary Notes 5). We see that the equilibrium between potentiating and depressing effects is not achieved by this type of independent normalization (Figure 4b, Supplementary Notes 6).
We wondered whether the decorrelative normalization term is necessary for the RNN-S to develop accurate representations. By replacing the decorrelative term with an independent normalization, features from non-adjacent states begin to be associated together and the model activity becomes spatially non-specific over time (Figure 4c, top). In contrast, using the decorrelative term, the RNN-S population activity is more localized (Figure 4c, bottom).
Interestingly, we noticed an additional feature of place maps as we transitioned from one-hot feature encodings to more complex feature encodings. We compared the representations learned by the RNN-S in a circular track walk with one-hot features versus more densely encoded features. For both input distributions, the RNN-S displayed the same skewing in place fields seen in Figure 2 (Figure S4). However, the place field peaks of the RNN-S model additionally shifted backwards in space for the more complex feature encodings (Figure 4d). This was not seen for the one-hot encodings (Figure 4d). The shifting in the RNN-S model is consistent with the observations made in Mehta et al. [17] and demonstrates the utility of considering more complex input conditions. A similar observation was made in Stachenfeld et al. [18] with noisy state inputs. In both cases, field shifts could be caused by neurons receiving external inputs at more than one state, particularly at states leading up to its original field location.
2.6. RNN-S estimates successor features even with naturalistic trajectories
We ask whether RNN-S can accurately estimate successor features, particularly under conditions of natural behavior. Specifically, we used the dataset from Payne et al. [11, 53], gathered from foraging Tufted Titmice in a 2D arena (Figure 5a). We discretize the arena into a set of states and encode each state as a randomly drawn feature ϕ. Using position-tracking data from Payne et al. [11, 53], we simulate the behavioral trajectory of the animal as transitions through the discrete state space. The inputs into the successor feature model are the features associated with the states in the behavioral trajectory.
We first wanted to test whether the RNN-S model was robust across a range of different types of input features. We calculate the TD loss of the model as a function of the spatial correlation across inputs ϕ (Figure 5b). We find that the model performs well across a range of inputs but loss is higher when inputs are spatially uncorrelated. This is consistent with the observation that behavioral transitions are spatially local, such that correlations across spatially adjacent features aid in the predictive power of the model. We next examine the model performance as a function of the sparsity of inputs ϕ (Figure 5c). We find the model also performs well across a range of feature sparsity, with lowest loss when features are sparse.
To understand the interacting effects of spatial correlation and feature sparsity in more detail, we performed a parameter sweep over both of these parameters (Figure 5d, Figure S5a-e). We generated random patterns according to the desired sparsity and smoothness with a spatial filter to generate correlations. This means that the entire parameter space is not covered in our sweep (e.g., the top-left area with high correlation and high sparsity is not explored). Note that since we generate ϕ by randomly drawing patterns, the special case of one-hot encoding is also not included in the parameter sweep (one-hot encoding is already explored in Figure 2). The RNN-S seems to perform well across a wide range, with highest loss in regions of low spatial correlation and low sparsity.
We want to compare the TD loss of RNN-S to that of a non-biological model designed to minimize TD loss. We repeat the same parameter sweep over input features with the FF-TD model (Figure 5e, Figure S5f). The FF-TD model performs similarly to the RNN-S model, with lower TD loss in regions with low sparsity or higher correlation. We also tested how the performance of both models is affected by the strength of γR (Figure 5f). Both models show a similar increase in TD loss as γR increases, although the RNN-S has a slightly lower TD loss at high γ than the FF-TD model. Unlike in the one-hot case, there is no ground-truth T matrix for non-one-hot inputs, so representations generated by RNN-S and FF-TD may look different, even at the same TD loss. Therefore, to compare the two models, it is important to compare representations to neural data.
2.7. RNN-S fits neural data in a random foraging task
Finally, we tested whether the neural representations learned by the models with behavioral trajectories from Figure 5 match hippocampal firing patterns. We performed new analysis on neural data from Payne et al. [11, 53] to establish a dataset for comparison. The neural data from Payne et al. [11] was collected from electrophysiological recordings in titmouse hippocampus during freely foraging behavior (Figure 6a). Payne et al. discovered the presence of place cells in this area. We analyzed statistics of place cells recorded in the anterior region of the hippocampus, where homology with rodent dorsal hippocampus is hypothesized [54]. We calculated the distribution of place field size measured relative to the arena size (Figure 6b), as well as the distribution of the number of place fields per place cell (Figure 6c). Interestingly, with similar analysis methods, Henriksen et al. [55] see similar statistics in the proximal region of dorsal CA1 in rats, indicating that our analyses could be applicable across organisms.
In order to test how spatial representations in the RNN-S are impacted by input features, we performed parameter sweeps over input statistics. As in [11], we define place cells in the model as cells with at least one statistically significant place field under permutation tests. Under most of the parameter range, all RNN-S neurons would be identified as place cells (Figure 6d). However, under conditions of high spatial correlation and low sparsity, a portion of neurons (12%) do not have any fields in the environment. These cells are excluded from further analysis. We measured how the size of place fields varies across the parameter range (Figure 6e). The size of the fields increases as a function of the spatial correlation of the inputs, but is relatively insensitive to sparsity. This effect can be explained as the spatial correlation of the inputs introducing an additional spatial spread in the neural activity. Similarly, we measured how the number of place fields per cell varies across the parameter range (Figure 6f). The number of fields is maximal for conditions in which input features are densely encoded and spatial correlation is low. These are conditions in which each neuron receives inputs from multiple, spatially distant states.
Finally, we wanted to identify regions of parameter space that were similar to the data of Payne et al. [11, 53]. We measured the KL divergence between our model’s place field statistics (Figure 6de) and the statistics measured in Payne et al. [11] (Figure 6bc). We combined the KL divergence of both these distributions to find the parameter range in which the RNN-S best fits neural data (Figure 6g). This optimal parameter range occurs when inputs have a spatial correlation of σ ≈ 8.75 cm and sparsity ≈ 0.15. We can visually confirm that the model fits the data well by plotting the place fields of RNN-S neurons (Figure 6h).
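The comparison between model and data distributions can be sketched as a KL divergence between binned histograms. Several details here are assumptions and not taken from the text: the direction of the divergence (data relative to model), the small smoothing constant, the bin edges, and combining the two divergences by summation. The samples are synthetic placeholders generated for illustration only and are not measurements from Payne et al. or outputs of the model.

```python
import numpy as np

def kl_divergence(p_samples, q_samples, bins, eps=1e-9):
    """KL(P || Q) between histograms of two sample sets over shared bins."""
    p_hist, _ = np.histogram(p_samples, bins=bins)
    q_hist, _ = np.histogram(q_samples, bins=bins)
    p = (p_hist + eps) / np.sum(p_hist + eps)   # smoothing avoids log(0)
    q = (q_hist + eps) / np.sum(q_hist + eps)
    return float(np.sum(p * np.log(p / q)))

rng = np.random.default_rng(0)

# Synthetic placeholder samples (NOT data from the paper or the model).
data_field_sizes = rng.gamma(2.0, 0.05, 500)     # field size relative to arena size
model_field_sizes = rng.gamma(2.5, 0.05, 500)
data_field_counts = rng.poisson(1.0, 500) + 1    # number of fields per place cell
model_field_counts = rng.poisson(1.5, 500) + 1

size_bins = np.linspace(0.0, 1.0, 21)
count_bins = np.arange(0.5, 10.5)

kl_size = kl_divergence(data_field_sizes, model_field_sizes, size_bins)
kl_count = kl_divergence(data_field_counts, model_field_counts, count_bins)
combined = kl_size + kl_count                    # assumed way of combining the two
print("KL (field size):", kl_size)
print("KL (fields per cell):", kl_count)
print("combined:", combined)
```

Evaluating `combined` at each point of a sparsity and spatial-correlation grid would give a map analogous to the one used to locate the best-fitting parameter range.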
We wondered whether the predictive gain (γR) of the representations affects the ability of the RNN-S to fit data. The KL divergence changes only slightly as a function of γR. Mainly, the KL-divergence of the place field size increases as γR increases (Figure 6i), but little effect is seen in the distribution of the number of place fields per neuron (Figure 6j).
We next tested whether the neural data was better fit by representations generated by RNN-S or the FF-TD model. Across all parameters of the input features, despite having similar TD loss (Figure 5de), the FF-TD model has much higher divergence from neural data (Figure 6gi, Figure S6).
Overall, our RNN-S model seems to strike a balance between performance in estimating successor features, similarity to data, and biological plausibility. Furthermore, our analyses provide a prediction of the input structure into the hippocampus that is otherwise not evident in an algorithmic description or in a model that only considers one-hot feature encodings.
3. Discussion
Hippocampal memory is thought to support a wide range of cognitive processes, especially those that involve forming associations or making predictions. However, the neural mechanisms that underlie these computations in the hippocampus are not fully understood. A promising biological substrate is the recurrent architecture of the CA3 region of the hippocampus and the plasticity rules observed. Here, we showed how a recurrent network with local learning rules can implement the successor representation, a predictive algorithm that captures many observations of hippocampal activity. We used our neural circuit model to make specific predictions of biological processes in this region.
A key component of our plasticity rule is a decorrelative term that depresses synapses based on coincident activity. Such anti-Hebbian or inhibitory effects are hypothesized to be broadly useful for learning, especially in unsupervised learning with overlapping input features [56, 57, 58]. Consistent with this hypothesis, anti-Hebbian learning has been implicated in circuits that perform a wide range of computations, from distinguishing patterns [37], to familiarity detection [38], to learning birdsong syllables [59]. This inhibitory learning may be useful because it decorrelates redundant information, allowing for greater specificity and capacity in a network [57, 37]. Our results provide further support of these hypotheses and predict that anti-Hebbian learning is fundamental to a predictive neural circuit.
We derive an adaptive learning rate that allows our model to quickly learn a probability distribution, and generally adds flexibility to the learning process. The adaptive learning rate changes such that neurons that are more recently active have a slower learning rate. This is consistent with experimental findings of metaplasticity at synapses [60, 61, 62], and theoretical proposals that metaplasticity tracks the uncertainty of information [36]. In RNN-S, the adaptive learning rate improves the speed of learning and better recapitulates hippocampal data. Our adaptive learning rate also has interesting implications for flexible learning. Memory systems must be able to quickly learn new associations throughout their lifetime without catastrophe. Our learning rate is parameterized by a forgetting term λ that controls the timescale in which environmental statistics are expected to be stationary. Although we fixed λ = 1 in our simulations, there are computational benefits in considering cases where λ < 1. This parameter provides a natural way for a memory system to forget gradually over time and prioritize recent experiences, in line with other theoretical studies that have also suggested that learning and forgetting on multiple timescales allow for more flexible behavior [63, 64].
We tested the sensitivity of our network to various parameters and found a broad range of valid solutions. Prior work has sought to understand how an emergent property of a network could be generated by multiple unique solutions [65, 66, 67, 68]. It has been suggested that redundancy in solution space makes systems more robust, accounting for margins of error in the natural world [69, 70]. In a similar vein, our parameter sweep over plasticity kernels revealed that a sizeable variety of kernels give solutions that resemble the SR. Although our model was initially sensitive to the value of γ, we found that adding biological components, such as nonlinear dynamics and separate network modes, broadened the solution space of the network.
Several useful features arise from the fact that RNN-S learns the transition matrix T directly, while separating out the prediction timescale, γ, as a global gain factor. It is important for animals to engage in different horizons of prediction depending on task or memory demands [71, 72]. In RNN-S, changing the prediction time horizon is as simple as increasing or decreasing the global gain of the network. Mechanistically, this could be accomplished by a neuromodulatory gain factor that boosts γ, perhaps by increasing the excitability of all neurons [73, 74]. In RNN-S, it was useful to have low network gain during learning (γB), while allowing higher gain during retrieval to make longer timescale predictions (γR). This could be accomplished by a neuromodulatory factor that switches the network into a learning regime [75, 76], for example acetylcholine, which reduces the gain of recurrent connections and increases learning rates [77, 78]. The idea that the hippocampus might compute the SR with flexible γ could help reconcile recent results that hippocampal activity does not always match high-γ SR [79, 80]. Finally, estimating T directly provides RNN-S with a means to sample likely future trajectories, or distributions of trajectories, which is computationally useful for many memory-guided cognitive tasks beyond reinforcement learning, including reasoning and inference [81]. We also found that the recurrent network fit hippocampal data better than a feedforward network. An interesting direction for further work involves untangling which brain areas and cognitive functions can be explained by deep (feedforward) neural networks [82], and which rely on recurrent architectures, or even richer combinations of generative structures [83]. Recurrent networks, such as RNN-S, support generative sequential sampling, reminiscent of hippocampal replay, which has been proposed as a substrate for planning, imagination, and structural inference [84, 85, 86, 87, 88].
Other recent theoretical works have also sought to find biological mechanisms to learn successor representations, albeit with different approaches [89, 90, 91, 92, 93]. The model from George et al. [93] focuses on a feedforward architecture, using STDP and theta phase precession to learn the SR. It is important to note that these mechanisms are not mutually exclusive with RNN-S. Taken together with our work, these models suggest that there are multiple ways to learn the SR in a biological circuit and that these representations may be more accessible to neural circuits than previously thought.
4. Methods
4.1. Code availability
Code is posted on GitHub: https://github.com/chingf/sr-project
4.2. Random walk simulations
We simulated random walks in 1D (circular track) and 2D (square) arenas. In 1D simulations, we varied the probability of staying in the current state and transitioning forwards or backwards to test different types of biases on top of a purely random walk. In 2D simulations, the probabilities of each possible action were equal. In our simulations, one timestep corresponds to 1/3 second and spatial bins are assumed to be 5 cm apart. This speed of movement (15 cm/sec) was chosen to be consistent with previous experiments. In theory, one can imagine different choices of timestep size to access different time horizons of prediction; that is, the choice of timestep interacts with the choice of γ in determining the prediction horizon.
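As an illustration of the 1D case, the following minimal sketch builds a transition matrix for a circular track with stay/forward/backward biases and samples a walk from it. It is not the code from our repository; the function names, the seed handling, and the example probabilities are hypothetical choices made only for this sketch.

```python
import numpy as np

def ring_transition_matrix(n_states, p_stay, p_forward):
    """Row-stochastic T with T[j, i] = P(s' = i | s = j) on a circular track."""
    p_backward = 1.0 - p_stay - p_forward
    if min(p_stay, p_forward, p_backward) < 0:
        raise ValueError("probabilities must be non-negative and sum to 1")
    T = np.zeros((n_states, n_states))
    for j in range(n_states):
        T[j, j] += p_stay
        T[j, (j + 1) % n_states] += p_forward   # step forwards
        T[j, (j - 1) % n_states] += p_backward  # step backwards
    return T

def simulate_walk(T, n_steps, start=0, seed=0):
    """Sample a state sequence of length n_steps from the Markov chain T."""
    rng = np.random.default_rng(seed)
    states = np.empty(n_steps, dtype=int)
    states[0] = start
    for t in range(1, n_steps):
        states[t] = rng.choice(T.shape[0], p=T[states[t - 1]])
    return states

# Example: a forward-biased walk on a 10-state circular track.
T = ring_transition_matrix(n_states=10, p_stay=0.2, p_forward=0.6)
walk = simulate_walk(T, n_steps=1000)
```

Setting `p_stay = 0` and `p_forward = 0.5` gives an unbiased walk with no dwell time; other settings add the kinds of bias described above.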
4.3. RNN-S model
This section provides details and pseudocode of the RNN-S simulation. Below are explanations of the most relevant variables:
The RNN-S algorithm is as follows:
4.4. RNN-S with plasticity kernels
We introduce additional kernel-related variables to the RNN-S model above that are optimized by an evolutionary algorithm (see following methods subsection for more details):
We also define the variable tk = 20, which is the length of the temporal support for the plasticity kernel. The value of tk was chosen such that $e^{-t_k/\tau}$ was negligibly small for the range of τ we were interested in. The update algorithm is the same as in Algorithm 1, except lines 15-16 are replaced with the following:
4.5. Metalearning of RNN parameters
We use the covariance matrix adaptation evolution strategy (CMA-ES) to learn the parameters of the RNN-S plasticity rule. The training data provided are walks simulated from a random distribution of 1D walks. Walks varied in the number of states, the transition statistics, and the number of timesteps simulated. The loss function was the mean-squared error (MSE) between the RNN J matrix and the ideal estimated T matrix at the end of the walk.
4.6. RNN-S with truncated recurrent steps and nonlinearity
For the RNN-S model with tmax recurrent steps, lines 10 and 13 in algorithm 1 are replaced with the truncated sum $x = \sum_{t=0}^{t_{max}} \gamma^t J^t \phi$.
For RNN-S with nonlinear dynamics, there is no closed form solution. So, we select a value for tmax and replace lines 10 and 13 in algorithm 1 with an iterative update for tmax steps: Δx = −x + γtanh(Jx′) + ϕ. We choose tmax such that .
4.7. RNN-S with successor features
We use a tanh nonlinearity as in Methods 4.6 and, for simplicity, set γB = 0.
4.8. RNN-S with independent normalization
As in algorithm 1, but with the following in place of line 16
4.9. FF-TD Model
In all simulations of the FF-TD model, we use the temporal difference update. We perform a small grid search over the learning rate η to minimize error (for SR, this is the MSE between the true M and estimated M; for successor features, this is the temporal difference error). In the one-hot SR case, the temporal difference update given an observed transition from state s to state s′ is: for all synapses j → i. Given arbitrarily structured inputs (as in the successor feature case), the temporal difference update is: or, equivalently,
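For orientation, the textbook tabular TD(0) update for the SR can be sketched as follows. This is the standard form from the RL literature, written with the RL row convention (row s of M holds the expected discounted future occupancies starting from s); the indexing used for synapses j → i in the FF-TD model may be the transpose of this, and the function name is hypothetical.

```python
import numpy as np

def td_sr_update(M, s, s_next, gamma, eta):
    """One tabular TD(0) update of the SR estimate M after observing s -> s_next.

    Standard textbook form (an assumption about indexing, not the paper's equation):
        M[s, :] <- M[s, :] + eta * (onehot(s) + gamma * M[s_next, :] - M[s, :])
    """
    onehot = np.zeros(M.shape[0])
    onehot[s] = 1.0
    td_error = onehot + gamma * M[s_next] - M[s]
    M_new = M.copy()          # leave the input estimate untouched
    M_new[s] += eta * td_error
    return M_new

# Example: a single update starting from an all-zero estimate.
M0 = np.zeros((3, 3))
M1 = td_sr_update(M0, s=0, s_next=1, gamma=0.5, eta=0.1)
```

Only the row of the departed state changes on each observed transition, which is why this update needs many visits to propagate long-horizon predictions.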
4.10. Generation of feature encodings for successor feature models
For a walk with n states, we created n-dimensional feature vectors for each state. We choose an initial sparsity probability p and create feature vectors as random binary vectors with probability p of being “on”. The feature vectors were then blurred by a 2D Gaussian filter with variance σ with 1 standard deviation of support. The blurred features were then min-subtracted and max-normalized. The sparsity of each feature vector was calculated as the L1 norm divided by n. The sparsity s of the dataset then was the median of all the sparsity values computed from the feature vectors. To vary the spatial correlation of the dataset we need only vary σ. To vary the sparsity s of the dataset we need to vary p, then measure the final s after blurring with σ. Note that, at large σ, the lowest sparsity values in our parameter sweep were not possible to achieve.
4.11. Measuring TD loss for successor feature models
We use the standard TD loss function (equation S7). To measure TD loss, at the end of the walk we take a random sample of observed transition pairs (ϕ, ϕ′). We use these transitions as the dataset to evaluate the loss function.
4.12. Analysis of place field statistics
We use the open source dataset from Payne et al. [11, 53]. We select for excitatory cells in the anterior tip of the hippocampus. We then select for place cells using standard measures (significantly place-modulated and stable over the course of the experiment).
We determined place field boundaries with a permutation test as in Payne et al. [11]. We then calculated the number of fields per neuron and the field size as in Henriksen et al. [55]. The same analyses were conducted for simulated neural data from the RNN-S and FF-TD models.
4.13. Behavioral simulation of Payne et al.
We use behavioral tracking data from Payne et al. [11]. For each simulation, we randomly select an experiment and randomly sample a 28 minute window from that experiment. If the arena coverage is less than 85% during the window, we redo the sampling until the coverage requirement is satisfied. We then downsample the behavioral data so that the frame rate is the same as our simulation (3 FPS). Then, we divide the arena into a 14 × 14 grid. We discretize the continuous X/Y location data into these states. This sequence of states makes up the behavioral transitions that the model simulates.
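The preprocessing steps above can be sketched as follows. This is an illustrative sketch rather than the repository code: the function names, the arena bounds arguments, and the `source_fps` parameter are hypothetical, since the native frame rate and arena coordinates of the tracking data are not specified here.

```python
import numpy as np

def downsample(xy, source_fps, target_fps=3):
    """Keep every k-th frame so the frame rate is approximately target_fps."""
    step = max(1, int(round(source_fps / target_fps)))
    return np.asarray(xy, dtype=float)[::step]

def positions_to_states(xy, arena_min, arena_max, n_bins=14):
    """Map continuous (x, y) positions to state indices on an n_bins x n_bins grid."""
    xy = np.asarray(xy, dtype=float)
    arena_min = np.asarray(arena_min, dtype=float)
    arena_max = np.asarray(arena_max, dtype=float)
    frac = (xy - arena_min) / (arena_max - arena_min)
    # Clip so that points on the far wall fall into the last bin.
    bins = np.clip((frac * n_bins).astype(int), 0, n_bins - 1)
    return bins[:, 0] * n_bins + bins[:, 1]

def coverage(states, n_bins=14):
    """Fraction of grid states visited at least once."""
    return np.unique(states).size / (n_bins * n_bins)

# Example with three positions in a unit-square arena.
xy = np.array([[0.0, 0.0], [1.0, 1.0], [0.5, 0.5]])
states = positions_to_states(xy, arena_min=[0, 0], arena_max=[1, 1])
```

A sampled window would be accepted only if `coverage(states)` is at least 0.85, and resampled otherwise.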
4.14. Place field plots
From the models, we get the activity of each model neuron over time. We make firing field plots with the same smoothing parameters as Payne et al. [11].
Citation diversity statement
Systemic discriminatory practices have been identified in neuroscience citations, and a ‘citation diversity statement’ has been proposed as an intervention [94, 95]. There is evidence that quantifying discriminatory practices can lead to systemic improvements in academic settings [96]. Many forms of discrimination could lead to a paper being under-cited, for example authors being less widely known or less respected due to discrimination related to gender, race, sexuality, disability status, or socioeconomic background. We manually estimated the number of male and female first and last authors that we cited, acknowledging that this quantification ignores many known forms of discrimination, and fails to account for nonbinary/intersex/trans folks. In our citations, first-last author pairs were 64% male-male, 21% female-male, 6% male-female, and 9% female-female, somewhat similar to base rates in our field (biaswatchneuro.com). To familiarize ourselves with the literature, we used databases intended to counteract discrimination (blackinneuro.com, anneslist.net, connectedpapers.com). The process of making this statement improved our paper, and encouraged us to adopt less biased practices in selecting what papers to read and cite in the future. We were somewhat surprised and disappointed at how low the number of female authors was, despite being a female-female team ourselves. Citation practices alone are not enough to correct the power imbalances endemic in academic practice [97]; this requires corrections to how concrete power and resources are distributed.
Supplementary Notes
The successor representation is defined as
$$M = \sum_{t=0}^{\infty} \gamma^t T^t = (I - \gamma T)^{-1} \tag{S1}$$
where T is the transition probability matrix such that $T_{ji} = P(s' = i \mid s = j)$ for current state s and future state s′.
Supplementary Notes 1. Finding the conditions to retrieve M from RNN steady-state activity
For an RNN with connectivity J, activity x, input ϕ, and gain γ ∈ [0,1), the (linear) discrete-time dynamics equation is [41]
$$\Delta x = -x + \gamma J x + \phi.$$
Furthermore, the steady state solution can be found by setting Δx = 0:
$$x = (I - \gamma J)^{-1} \phi.$$
Assume that $J = T^{\top}$ as a result of the network using some STDP-like learning rule where pre-post connections are potentiated. The transposition is due to notational differences from the RL literature, where the ijth index typically concerns the direction from state i to state j. This is a result of differences in RL and RNN conventions in which inputs are left-multiplied and right-multiplied, respectively. Let γ be a neuromodulatory factor that is applied over the whole network (and, thus, does not need to be encoded in the synaptic weights). Then, the equivalence to equation S1 becomes clear and our steady state solution can be written as:
$$x = (I - \gamma T^{\top})^{-1} \phi = M^{\top} \phi.$$
This is consistent with the successor representation framework shown in Stachenfeld et al. [18], where the columns of the M matrix represent the firing fields of a neuron, and the rows of the M matrix represent the network response to some input.
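A quick numerical check of this claim can be sketched as follows, assuming the linear dynamics take the form Δx = −x + γJx + ϕ (the linear counterpart of the nonlinear update in Methods 4.6) and that J has already converged to the transpose of T. The function name and the small 5-state example are hypothetical.

```python
import numpy as np

def run_linear_rnn(J, phi, gamma, n_steps=200):
    """Iterate the assumed linear dynamics  x <- x + (-x + gamma * J x + phi)  from x = 0."""
    x = np.zeros_like(phi, dtype=float)
    for _ in range(n_steps):
        x = x + (-x + gamma * J @ x + phi)
    return x

# Forward-biased walk on a 5-state ring: T[j, i] = P(s' = i | s = j).
n = 5
T = np.zeros((n, n))
for j in range(n):
    T[j, j] = 0.3
    T[j, (j + 1) % n] = 0.7

gamma = 0.8
M = np.linalg.inv(np.eye(n) - gamma * T)  # successor representation
J = T.T                                   # recurrent weights assumed equal to T transposed

phi = np.zeros(n)
phi[2] = 1.0                              # one-hot input for state 2
x = run_linear_rnn(J, phi, gamma)         # steady-state network activity
```

With a one-hot input for state 2, the steady-state activity `x` matches row 2 of `M`, that is, the network response to that input. Rerunning with a different `gamma` changes the prediction horizon without changing `J`.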
Supplementary Notes 2. Deriving the RNN-S learning rule from TD Error and showing the learning rule is valid under a stability condition
Transitions between states (s,s′) are observed as features (ϕ(s),ϕ(s′)) where ϕ is some function. For notational simplicity, we will write these observed feature transitions as (ϕ, ϕ′). A dataset comprises these observed feature transitions over a behavioral trajectory. Successor features are typically learned by some function approximator ψ(ϕ; θ) that is parameterized by θ and takes in the inputs ϕ. The SF approximator, ψ, is learned by minimizing the temporal difference (TD) loss function [98]:
$$L(\theta) = \mathbb{E}_{(\phi, \phi')}\left[ \left\lVert \phi + \gamma \psi^{\pi}(\phi'; \theta) - \psi(\phi; \theta) \right\rVert^2 \right] \tag{S7}$$
for the current policy π. Here, the TD target is ϕ + γψπ(ϕ′; θ). Analogous to the model-free setting where the value function V is being learned, ϕ is in place of the reward r. Following these definitions, we can view the RNN-S as the function approximator ψ:
$$\psi(\phi; \theta) = (I - \gamma J)^{-1} \phi.$$
For a single transition (ϕ, ϕ′) we can write out the loss as follows:
For each observed transition, we would like to update ψ such that the loss L is minimized. Thus, we take the gradient of this temporal difference loss function with respect to our parameter θ = J:
We can make the TD approximation $\psi^{\pi}(\phi'; \theta) \approx \psi(\phi'; \theta) = (I - \gamma J)^{-1}\phi'$ [98]:
While −∇JL(θ) gives the direction of steepest descent in the loss, we will consider a linear transformation of the gradient that allows for a simpler update rule. This simpler update rule will be more amenable to a biologically plausible learning rule. We define this modified gradient as D = ∇JL(θ)M where M = (I – γJ)⊤. We must first understand the condition for D to be in a direction of descent:
This expression is satisfied if M + M⊤ is positive definite (its eigenvalues are positive). Thus, we find that our modified gradient points towards a descent direction if the eigenvalues of M + M⊤ are positive. Interestingly, this condition is equivalent to stating that the recurrent network dynamics are stable and do not exhibit non-normal amplification [99, 100, 101]. In other words, as long as the network dynamics are in a stable regime and do not have non-normal amplification, our modified gradient reduces the temporal difference loss. Otherwise, the gradient will not point towards a descent direction.
We will use the modified gradient −D = (x′ – Jx)x⊤ as our synaptic weight update rule. Our theoretical analysis explains much of the results seen in the main text. As the gain parameter γB is increased, the network is closer to the edge of stability (the eigenvalues of M are close to positive values, Figure 3a). Stability itself is not enough to guarantee that our update rule is valid. We need the additional constraint that non-normal amplification should not be present (eigenvalues of M + M⊤ are positive). In practice, however, this does not seem to be a mode that affects our network. That is, the γB value for which the error in the network increases coincides with the γB value for which the network is no longer stable (Figure 3b). Our theoretical analysis also shows that the gain γB can always be decreased such that the eigenvalues of M + M⊤ are positive and our update rule is valid (Figure 3e). At the most extreme, one can set γB = 0 during learning to maintain stability (as we do in Figure 4 and onwards).
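The validity condition can be checked numerically for any given weight matrix and gain. The sketch below tests whether M + M⊤ is positive definite with M = (I – γJ)⊤; the function name and the two example matrices are hypothetical and chosen only to illustrate the stable case and the non-normal case.

```python
import numpy as np

def descent_condition_holds(J, gamma):
    """True if M + M^T is positive definite, where M = (I - gamma * J)^T."""
    M = (np.eye(J.shape[0]) - gamma * J).T
    eigvals = np.linalg.eigvalsh(M + M.T)  # symmetric matrix, so eigvalsh applies
    return bool(eigvals.min() > 0)

# Case 1: J is the transpose of a ring random walk (a normal matrix).
n = 6
T = np.zeros((n, n))
for j in range(n):
    T[j, (j + 1) % n] = 0.5
    T[j, (j - 1) % n] = 0.5
ok_ring = descent_condition_holds(T.T, gamma=0.9)

# Case 2: a strongly non-normal J. gamma * J has all eigenvalues at 0 (stable),
# yet the symmetric part of M is indefinite, so the condition fails.
J_nonnormal = np.array([[0.0, 10.0],
                        [0.0, 0.0]])
ok_nonnormal = descent_condition_holds(J_nonnormal, gamma=0.5)
```

The second case illustrates why stability alone is not sufficient: the dynamics are stable, but transient (non-normal) amplification violates the condition. Lowering `gamma` far enough restores it.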
Supplementary Notes 3. Proving the RNN-S update rule calculated on firing rates (x) depends only on feedforward inputs (ϕ) at steady state
We will show that our update rule, which uses x (neural activity), converges on a solution that depends only on ϕ (the feedforward inputs). We will also show that in the one-hot case, we learn the SR exactly.
As a reminder, our learning rule for each j → i synapse is:
$$\Delta J_{ij} = \eta \left( x_i' x_j - \sum_k J_{ik} x_k x_j \right) \tag{S19}$$
We can solve for the steady state solution of equation S19 (set ΔJ = 0). Let $A = (I - \gamma J)^{-1}$ for notational convenience, and recall that in steady state x = Aϕ. Let 〈x〉 denote the average of x over time.
Note that, since $A = (I - \gamma J)^{-1}$, $\gamma J A = A - I$.
Thus,
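Therefore,
$$J = R_{\phi\phi}(1)\, R_{\phi\phi}(0)^{-1},$$
where $R_{\phi\phi}(\tau) = \langle \phi(t)\, \phi(t-\tau)^{\top} \rangle$ is the autocorrelation matrix for some time lag τ. Therefore, the RNN-S weight matrix J at steady state is only dependent on the inputs into the RNN over time.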
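In the case where ϕ is one-hot, we compute the SR exactly. This is because the steady state solution at each j → i synapse simplifies into the following expression:
$$J_{ij} = \frac{\sum_t \phi_i(t)\, \phi_j(t-1)}{\sum_t \phi_j(t-1)} = P(s' = i \mid s = j) \tag{S30}$$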
This is the definition of the transition probability matrix and we see that J = T⊤. Note that the solution for Jij in equation S30 is undefined if state j is never visited. We assume each relevant state is visited at least once here.
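The one-hot result can be verified on a short state sequence. The sketch below assumes the steady-state weights take the lagged-correlation form J = ⟨ϕ(t)ϕ(t−1)⊤⟩⟨ϕ(t−1)ϕ(t−1)⊤⟩⁻¹, which is our reading of the derivation above; the function name and the example sequence are hypothetical.

```python
import numpy as np

def steady_state_J(states, n_states):
    """Steady-state weights for one-hot inputs, from lagged input correlations.

    Assumes every state appears at least once as the source of a transition,
    otherwise the lag-0 correlation matrix is singular.
    """
    phi = np.eye(n_states)[np.asarray(states)]  # (time, n_states) one-hot inputs
    R1 = phi[1:].T @ phi[:-1]   # R1[i, j] = number of observed j -> i transitions
    R0 = phi[:-1].T @ phi[:-1]  # diagonal: number of times j was a source state
    return R1 @ np.linalg.inv(R0)

states = [0, 1, 2, 0, 1, 0, 2, 1, 0]
J = steady_state_J(states, n_states=3)
```

Each column j of `J` is the empirical distribution over next states given state j, so `J` is the transpose of the empirical transition matrix, and each column sums to 1.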
Supplementary Notes 4. Deriving the adaptive learning rate update rule
This section explains how the adaptive learning rate is derived. The logic will be similar to calculating a weighted running average. Let dij (t) be a binary function that is 1 if the transition from timestep t – 1 to timestep t is state j to state i. Otherwise, it is 0. Assume ϕ is one-hot encoded. Notice that in the one-hot case, the RNN-S update rule (equation 4) simplifies to:
$$\Delta J_{ij}(t) = \eta \left( d_{ij}(t) - \phi_j(t-1)\, J_{ij}(t-1) \right) \tag{S31}$$
What η should be used so J approaches T⊤ as quickly as possible? During learning, the empirical transition matrix, T(t), changes at each timestep t, based on transitions the animal has experienced. Define the total number of times that state ϕj happened prior to time t as $n_j(t) = \sum_{t' < t} \phi_j(t')$, and define the running count of transitions from state j to state i as $c_{ij}(t) = \sum_{t' \le t} d_{ij}(t')$. We want J(t) = T⊤(t), which necessitates
$$J_{ij}(t) = \frac{c_{ij}(t)}{n_j(t)}.$$
Note that nj (t) = nj (t – 1) + ϕj (t – 1), and cij (t) = cij (t – 1) + dij (t), which gives us
$$J_{ij}(t) = \frac{c_{ij}(t-1) + d_{ij}(t)}{n_j(t)} = \frac{n_j(t-1)\, J_{ij}(t-1) + d_{ij}(t)}{n_j(t)} = J_{ij}(t-1) + \frac{1}{n_j(t)} \left( d_{ij}(t) - \phi_j(t-1)\, J_{ij}(t-1) \right).$$
Therefore, comparing with equation S31, we can see that a learning rate $\eta = 1 / n_j(t)$ will let J = T⊤ as quickly as possible. We have defined n in terms of the inputs ϕ for this derivation, but in practice the adaptive learning rate as a function of x works well with the RNN-S update rule (which is also a function of x). Thus, we use the adaptive learning rate defined over x in our combined learning rule for increased biological plausibility.
In its current form, the update equation assumes transitions across all history of inputs are integrated. In reality, there is likely some kind of memory decay. This can be implemented with a decay term λ ∈ (0,1]:
λ determines the recency bias over the observed transitions that make up the T estimate. The addition of λ has the added benefit that it naturally provides a mechanism for learning rates to modulate over time. If λ = 1, the learning rate can only monotonically decrease. If λ < 1, the learning rate can become strong again over time if a state has not been visited in a while. This provides a mechanism for fast learning of new associations, which is useful for a variety of effects, including remapping.
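The running-average logic of this section can be sketched as an online update for one-hot inputs. This is an illustration under stated assumptions, not the repository implementation: the function name is hypothetical, the learning rate is taken to be the inverse of a per-state visit count, and the decay is applied as n ← λn before each increment, which is one plausible form of the decay term.

```python
import numpy as np

def learn_J_adaptive(states, n_states, lam=1.0):
    """Online estimate of J with a per-presynaptic-state adaptive learning rate.

    One-hot case: for an observed transition prev -> curr, only column `prev`
    of J is updated, with learning rate 1 / n[prev].
    """
    J = np.zeros((n_states, n_states))
    n = np.zeros(n_states)                 # (decayed) visit counts
    for prev, curr in zip(states[:-1], states[1:]):
        n = lam * n                        # assumed form of the forgetting term
        n[prev] += 1.0
        eta = 1.0 / n[prev]                # adaptive learning rate
        target = np.zeros(n_states)
        target[curr] = 1.0                 # one-hot encoding of the next state
        J[:, prev] += eta * (target - J[:, prev])
    return J

states = [0, 1, 2, 0, 1, 0, 2, 1, 0]
J_full = learn_J_adaptive(states, n_states=3, lam=1.0)    # integrates all history
J_recent = learn_J_adaptive(states, n_states=3, lam=0.9)  # weights recent transitions more
```

With `lam = 1.0`, `J_full` equals the transpose of the empirical transition matrix after every step, so no learning-rate tuning is needed. With `lam < 1`, the effective count of a state shrinks while it goes unvisited, so its learning rate recovers and new transitions are absorbed quickly.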
Supplementary Notes 5. Endotaxis model and the successor representation
The learning rule and architecture of our model are similar to a hypothesized “endotaxis” model [45]. In the endotaxis model, neurons fire most strongly near a reward, allowing the animal to navigate up a gradient of neural activity akin to navigating up an odor gradient. The endotaxis model discovers the structure of an environment and can solve many tasks such as spatial navigation and abstract puzzles. We were interested in similarities between RNN-S and the learning rules for endotaxis, in support of the idea that SR-like representations may be used by the brain for a broad range of intelligent behaviors. Here, we outline similarities and differences between the two model architectures.
The endotaxis paper [45] uses Oja’s rule in an RNN with place-like inputs. The SR can also be learned with an Oja-like learning rule. Oja’s rule is typically written as [52]:
$$\Delta J_{ij} = \eta \left( x_i x_j - J_{ij}\, x_i^2 \right)$$
If we assume that there is a temporal asymmetry to the potentiation term (e.g., potentiation is more STDP-like than Hebbian), then we have
$$\Delta J_{ij} = \eta \left( x_i(t)\, x_j(t-1) - J_{ij}\, x_i(t)^2 \right)$$
We then solve for the steady state solution of this equation, when ΔJij = 0:
$$J_{ij} = \frac{\langle x_i(t)\, x_j(t-1) \rangle}{\langle x_i(t)^2 \rangle}$$
where 〈·〉 indicates the time-average of some term. Assume that the plasticity rule does not use x exactly, but instead uses ϕ directly. Given that inputs are one-hot encodings of the animal’s state at some time t, the expression becomes
$$J_{ij} = \frac{\langle \phi_i(t)\, \phi_j(t-1) \rangle}{\langle \phi_i(t) \rangle}$$
If we assume T is symmetric, J = T⊤. Alternatively, if we use pre-synaptic normalization as opposed to the standard post-synaptic normalization of Oja’s rule (i.e., index j instead of i in the denominator), we also have J = T⊤. Thus, the steady state activity of an RNN with this learning rule retrieves the SR, as shown in Supplementary Notes 1.
Supplementary Notes 6. Independent normalization and successor features
If we assume the same Oja-like rule as in Supplementary Notes 5, we can also arrive at a similar interpretation in the successor feature case as in equation 7. By solving for the steady state solution without any assumptions about the inputs ϕ, we get the following equation: where diag is a function that retains only the diagonal of the matrix. This expression provides a useful way to contrast the learning rule used in RNN-S with an Oja-like alternative. While RNN-S normalizes by the full autocorrelation matrix, an Oja-like rule only normalizes by the diagonal of the matrix. This is the basis of our independent normalization model in Figure 4bc.
Supplementary Figures
Acknowledgements
This work was supported through NSF NeuroNex Award DBI-1707398, the Gatsby Charitable Foundation, the New York Stem Cell Foundation (Robertson Neuroscience Investigator Award), National Institutes of Health (NIH Director’s New Innovator Award (DP2-AG071918)), and the Arnold and Mabel Beckman Foundation (Beckman Young Investigator Award). CF received support from the NSF Graduate Research Fellowship Program. ELM received support from the Simons Society of Fellows. We thank Jack Lindsey and Tom George for comments on the manuscript, as well as Stefano Fusi, William de Cothi, Kimberly Stachenfeld, and Caswell Barry for helpful discussions.
References
- [1].
- [2].
- [3].
- [4].
- [5].
- [6].
- [7].
- [8].
- [9].
- [10].
- [11].
- [12].
- [13].
- [14].
- [15].
- [16].
- [17].
- [18].
- [19].
- [20].
- [21].
- [22].
- [23].
- [24].
- [25].
- [26].
- [27].
- [28].
- [29].
- [30].
- [31].
- [32].
- [33].
- [34].
- [35].
- [36].
- [37].
- [38].
- [39].
- [40].
- [41].
- [42].
- [43].
- [44].
- [45].
- [46].
- [47].
- [48].
- [49].
- [50].
- [51].
- [52].
- [53].
- [54].
- [55].
- [56].
- [57].
- [58].
- [59].
- [60].
- [61].
- [62].
- [63].
- [64].
- [65].
- [66].
- [67].
- [68].
- [69].
- [70].
- [71].
- [72].
- [73].
- [74].
- [75].
- [76].
- [77].
- [78].
- [79].
- [80].
- [81].
- [82].
- [83].
- [84].
- [85].
- [86].
- [87].
- [88].
- [89].
- [90].
- [91].
- [92].
- [93].
- [94].
- [95].
- [96].
- [97].
- [98].
- [99].
- [100].
- [101].