Training biologically plausible recurrent neural networks on cognitive tasks with long-term dependencies

Training recurrent neural networks (RNNs) has become a go-to approach for generating and evaluating mechanistic neural hypotheses for cognition. The ease and efficiency of training RNNs with backpropagation through time, together with the availability of robustly supported deep learning libraries, have made RNN modeling more approachable and accessible to neuroscience. Yet a major technical hindrance remains. Cognitive processes such as working memory and decision making involve neural population dynamics over long periods of time within a behavioral trial and across trials. It is difficult to train RNNs to accomplish tasks where neural representations and dynamics have long temporal dependencies without gating mechanisms such as LSTMs or GRUs, which currently lack experimental support and prohibit direct comparison between RNNs and biological neural circuits. We tackled this problem based on the idea of specialized skip connections through time to support the emergence of task-relevant dynamics, and subsequently reinstituted biological plausibility by reverting to the original architecture. We show that this approach enables RNNs to successfully learn cognitive tasks that prove impractical, if not impossible, to learn using conventional methods. Over the numerous tasks considered here, we achieve fewer training steps and shorter wall-clock times, particularly in tasks that require learning long-term dependencies via temporal integration over long timescales or maintaining a memory of past events in hidden states. Our methods expand the range of experimental tasks that biologically plausible RNN models can learn, thereby supporting the development of theory for the emergent neural mechanisms of computations involving long-term dependencies.

For example, value-based decision making tasks akin to bandit problems are often used to study the neural basis of economic choices. The values of the choices in these tasks can vary over time (slowly or abruptly depending on the task design), and efficient, adaptive responses to these changes require tracking the value of the options and making decisions contingent upon the current value estimates [1,2]. Sometimes, the volatility in the value of the options changes over time and must be tracked in order to adapt appropriately [3]. Strategic or social decision making tasks may invoke game-theoretic strategies wherein subjects must assess their opponents' play to maximize their own payoff [4,5].
Rule-based tasks form another important class that requires adaptive behavior. In these tasks, the outcome of a subject's response on a trial depends on a hidden rule. The rule changes across blocks of consecutive trials in an uncued manner, but the set of possible rules is small, fixed and known to the subject. Well-trained subjects detect rule switches and infer the new rule rapidly by efficiently integrating outcomes over a few exploratory trials [6][7][8][9]. An extension to this is the study of rule-based behavior in situations where a rule does not come from a known set and must be discovered. Poor approximations of the true rule governing the task structure may make it difficult to learn, place unnecessary demands on cognitive resources and adversely affect performance. Yet, human subjects quickly discover complex task structures with little instruction via effective learning strategies [10][11][12]. Moreover, well-learned structures promote generalization, thereby speeding up subsequent learning [13].
In all these cases, subjects must draw rapid, accurate inferences based on the feedback they receive in order to adapt to environmental changes, thereby maximizing positive outcomes (total reward, payoff, etc.). One theory holds that this may be achieved via neural dynamics, wherein neural populations track changing environmental contingencies and appropriately alter their computations and decision representations to alter behavior [14,15]. But exactly how a neural population may achieve this is not known, thus revealing an exciting frontier for the study of the neural basis of adaptive behavior. Yet the absence of effective methods to train biological RNNs to compute over long timescales (i.e., spanning several trials) poses a fundamental challenge. This challenge extends to tasks that do not require integrating information across trials, but involve long trials instead. This includes spatial navigation, evidence accumulation and decision making tasks wherein the subject must move through a physical space to solve a problem, thus increasing the trial duration and, in some cases, the computational demand of the task.

B Biological plausibility of gating mechanisms
Gating-based architectures such as long short-term memory (LSTM) and gated recurrent unit (GRU) were introduced to overcome challenges pertaining to training vanilla RNNs with gradient descent. They endow individual network units with adaptive multiplicative gating applied to their inputs, memory computation and outputs (only for LSTMs). This has two important computational benefits. First, the multiplicative gating, particularly in the context of memory computations, strongly curtails gradient stability challenges [16]. Second, they inherently support adaptive memory timescales and computation at the single-unit level. These enhancements make them substantially more computationally powerful than vanilla RNNs, thus inviting their adoption in cognitive models [17,18] and in phenomenological models of neural computation [19]. However, key biological constraints preclude them from being considered biologically plausible.
• Gating. Gating of the inputs and outputs of neurons is thought to be mediated by distinct inhibitory interneuron types via their nonlinear interactions at the dendrites and soma, respectively, of excitatory pyramidal neurons [14]. While gating is widely believed to enhance the adaptive properties of neural computation, available evidence and existing models do not support the hypothesis that this gating is multiplicative.
• Training. A major motivation for LSTM and GRU architectures is that they are easier to train than vanilla RNNs due to their improved gradient stability. In contrast, RNNs endowed with interneurons and biologically plausible gating functions are more difficult to train than vanilla RNNs [20].
• Dynamics. A constraint for biologically plausible RNN models is that they should produce continuous-time dynamics. LSTM and GRU architectures are discrete-time systems by construction. Continuous-time variants of these architectures have been proposed in the neuroscience literature to address precisely this issue [21]. Moreover, LSTM units typically respond with abrupt activity transitions between consecutive time steps [17], whereas neurons exhibit smooth temporal responses.
• Memory. LSTM and GRU memory mechanisms appear too powerful relative to known brain mechanisms of short-term memory function. The forget (update) gate of LSTMs (GRUs) endows individual units with arbitrarily small or large memory timescales. In contrast, neuronal membrane time constants are typically quite small and of limited range. While a variety of glutamatergic signaling mechanisms are believed to enhance neural integration timescales [22], these remain limited. Moreover, receptor expression levels vary drastically by brain region [23], obviating a general-purpose mechanism to support arbitrary timescales at the single-neuron level. Continuous-time neural population dynamics offer a complementary mechanism for arbitrary (and adaptive) timescales in the brain [24]. There is little evidence for LSTM/GRU-like single-unit timescales in the brain.
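To make the contrast concrete, the following minimal sketch (our own illustrative code, not drawn from the accompanying repository; all function and variable names are ours) juxtaposes one Euler step of a continuous-time vanilla RNN, in which a single fixed time constant governs every unit, with one GRU step, in which learned multiplicative gates give each unit an input-dependent effective timescale:

```python
import numpy as np

def leaky_rnn_step(x, u, W, W_in, b, alpha):
    """One Euler step of a continuous-time 'vanilla' RNN.

    alpha = dt / tau couples every unit to a single, fixed
    membrane time constant -- there is no per-unit gating.
    """
    r = np.tanh(x)
    return (1 - alpha) * x + alpha * (W @ r + W_in @ u + b)

def gru_step(h, u, Wz, Uz, Wr, Ur, Wh, Uh):
    """One (discrete-time) GRU step.

    The update gate z acts as a learned, input-dependent
    'time constant' per unit: z near 0 holds memory indefinitely,
    z near 1 overwrites it -- the multiplicative mechanism that
    lacks direct experimental support.
    """
    sig = lambda a: 1.0 / (1.0 + np.exp(-a))
    z = sig(Wz @ u + Uz @ h)                  # update gate
    r = sig(Wr @ u + Ur @ h)                  # reset gate
    h_cand = np.tanh(Wh @ u + Uh @ (r * h))   # candidate state
    return (1 - z) * h + z * h_cand
```

Note how the GRU's per-unit, per-timestep gates z and r have no counterpart in the leaky update, where alpha is shared and fixed.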

C Review of RNN parameter values in literature
Model parameters of RNNs in the literature (Figure S1A) are sometimes restricted to undesirable values due to competing constraints from the perspectives of biological plausibility, numerical accuracy and gradient stability. We review the values of some of these parameters from previously published models. In particular, only continuous-time RNNs trained by backpropagation through time are considered [25-29, 21, 30-47, 20, 48-52] (more details in Table S1). The value of the neuronal or leak time constant is often set to 100 ms, which reflects the slow dynamics of NMDA receptors in the brain, although smaller values are also plausible depending on the proportion of NMDA receptors in the brain area being modeled (Figure S1C). For high numerical accuracy, the discretization time step should be small relative to this time constant. The most common values of the ratio between the time step size and the time constant are 0.1 and 0.2 (Figure S1D). However, these values are still quite large and can produce inaccurate approximations of the underlying system. On the other hand, decreasing the step size increases the model's memory footprint and can potentially introduce vanishing or exploding gradients when training an RNN with backpropagation through time. Indeed, despite the differences in their function, the models we surveyed seldom use a large number of time steps (Figure S1B). Taken together, the technical challenges of training RNNs with backpropagation through time either constrain the duration of the task a network can be trained on or hinder the model's numerical accuracy.
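The accuracy cost of a large step-size ratio can be illustrated with a leaky integrator dx/dt = -x/τ, whose exact solution is known. The sketch below (our own illustration; the specific parameter values are chosen for demonstration, not taken from any surveyed model) compares forward-Euler simulations at several dt/τ ratios against the analytic decay:

```python
import numpy as np

def euler_decay(x0, tau, dt, T):
    """Simulate dx/dt = -x/tau with forward Euler at step dt for duration T."""
    n_steps = int(round(T / dt))
    x = x0
    for _ in range(n_steps):
        x = x + (dt / tau) * (-x)   # Euler update: x <- x + dt * f(x)
    return x

tau, x0, T = 100.0, 1.0, 500.0       # ms; exact solution: x0 * exp(-T/tau)
exact = x0 * np.exp(-T / tau)
for ratio in (0.01, 0.1, 0.2):       # 0.1 and 0.2 are the most common in the survey
    approx = euler_decay(x0, tau, dt=ratio * tau, T=T)
    rel_err = abs(approx - exact) / exact
    print(f"dt/tau = {ratio:4}: relative error {rel_err:.1%}")
```

Running this shows the relative error growing sharply with the step-size ratio, which is why the common 0.1 and 0.2 ratios can meaningfully distort the underlying continuous-time system.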

D Task structures and descriptions
The 16 tasks can be categorized according to their task structures (A, B or C).
Task structure A. 7 tasks share this task structure, which consists of 3 task epochs: fixation, stimulus and response.
• Go (go): The fixation cue is set to 1 during the first two epochs. During the stimulus epoch, a stimulus is presented in a randomly chosen modality at a randomly chosen direction. The response, made in task epoch 3, should match the direction of the stimulus.
• Anti-response (anti): This is the same as go, but the response should be in the opposite direction of (i.e. π radians from) the stimulus.
• Reaction time go (rtgo): The fixation cue is set to 1 throughout the trial. During epoch 3, a stimulus is presented in a randomly chosen modality at a randomly chosen direction. The response, made in the same epoch, should match the direction of the stimulus.
• Reaction time anti-response (rtanti): Same as rtgo, but the response should be in the opposite direction of (i.e. π radians from) the stimulus.
• Decision making (dm): The fixation cue is set to 1 during the first two epochs. Two stimuli are simultaneously presented in a single modality during task epoch 2. One stimulus is randomly generated with a random magnitude and direction. The other stimulus is also generated with a random magnitude, but its direction is drawn uniformly at random between π/2 and 3π/2 radians away from the first stimulus. The response, made in task epoch 3, should match the direction of the stimulus with the greater magnitude.
• Decision making with distractors (dmd): This is the same as dm, except 2 stimuli are presented in each modality. All 4 stimuli have different magnitudes, but the directions of the two stimuli in modality A are the same as the directions of the two stimuli in modality B. The RNN must ignore modality A and respond in the direction corresponding to the stronger stimulus in modality B.
• Multi-sensory decision making (msdm): This is the same as dmd, except the RNN must respond in the direction in which the combined stimulus strength from both modalities is stronger. That is, given stimuli in both modalities at angles p and q, if the sum of the magnitudes of the stimuli in both modalities at angle p is larger than the sum of the magnitudes of the stimuli in both modalities at angle q, the target response is angle p. Otherwise, it is angle q.
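The msdm target rule above can be captured in a few lines. This sketch is our own helper, not taken from the paper's task-generation code; the function name and argument layout are illustrative:

```python
import numpy as np

def msdm_target(dirs, mags_A, mags_B):
    """Target direction for the multi-sensory decision-making (msdm) task.

    dirs   : the two candidate directions (p, q), in radians
    mags_A : magnitudes of the two stimuli in modality A, ordered as dirs
    mags_B : magnitudes of the two stimuli in modality B, ordered as dirs
    The response should point toward whichever direction has the
    larger combined (A + B) magnitude.
    """
    combined = np.asarray(mags_A) + np.asarray(mags_B)
    return dirs[int(np.argmax(combined))]
```

For example, with candidate directions (0.5, 2.0), modality-A magnitudes (0.3, 0.6) and modality-B magnitudes (0.7, 0.2), the combined strengths are (1.0, 0.8), so the target is 0.5; dropping the modality-A term instead recovers the dmd rule.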
Task structure B. These tasks are similar to tasks with structure A, except they include a delay period between stimulus presentation and response. 2 tasks share this task structure, which consists of 4 task epochs: fixation, stimulus, delay and response.
• Delayed go (dgo): The fixation cue is set to 1 during the first three epochs. A stimulus is presented in a randomly chosen modality with a random direction during task epoch 2. The response, made during task epoch 4, i.e. after the delay period, should match the direction of the stimulus.
• Delayed anti-response (danti): This is the same as dgo, but the response should be in the opposite direction of (i.e. π radians from) the stimulus.
Task structure C. These tasks are similar to tasks with structure B, except they include an additional stimulus which is presented after the delay period. The target response is based on the relationship between the two stimuli presented before and after the delay period. 7 tasks share this task structure, which consists of 5 task epochs: fixation, first stimulus, delay, second stimulus and response.
• Delayed decision making (ddm): The fixation cue is set to 1 during the first four epochs. A stimulus is presented during task epoch 2 in a randomly chosen modality with a random magnitude and direction. After the delay period (task epoch 3), a second stimulus is generated during task epoch 4 in the same modality as the first stimulus and with a random magnitude, but in a direction drawn uniformly at random between π/2 and 3π/2 radians away from the first stimulus. The response, made during task epoch 5, should match the direction of the stimulus with the greater magnitude.
• Delayed decision making with distractors (ddmd): This is the same as ddm, except 2 stimuli are presented in each modality. All 4 stimuli have different magnitudes, but the directions of the two stimuli in modality A are the same as the directions of the two stimuli in modality B. The RNN must ignore modality A and respond in the direction corresponding to the stronger stimulus in modality B.
• Multi-sensory delayed decision making (msddm): This is the same as ddmd, except the RNN must respond in the direction in which the combined stimulus strength from both modalities is stronger.
• Delayed match-to-sample (dms): The fixation cue is set to 1 during the first four epochs. During task epoch 2, a stimulus is presented in a randomly chosen modality with a fixed magnitude and a random direction. After a delay period (task epoch 3), a second stimulus is presented in a randomly chosen modality during task epoch 4 with the same fixed magnitude. In half the trials, the direction of this second stimulus is within ±π/18 radians of the first stimulus. This corresponds to a match. In the other half of the trials, the direction of the second stimulus is drawn uniformly between π/2 and 3π/2 radians away from the first stimulus. This corresponds to a non-match. On match trials, the RNN must respond during task epoch 5 in the direction of the second stimulus. On non-match trials, it must continue maintaining fixation during task epoch 5.
• Delayed non-match-to-sample (dnms): This is the same as dms, except the RNN must maintain fixation during task epoch 5 on match trials, and respond in the direction of the second stimulus on non-match trials.
• Delayed match-to-category (dmc): The fixation cue is set to 1 during the first four epochs. During task epoch 2, a stimulus is presented in a randomly chosen modality with a fixed magnitude. In half the trials, the direction of the first stimulus is drawn uniformly between 0 and π/2 radians. In the other half of the trials, the direction is drawn uniformly between π and 3π/2 radians. After a delay period (task epoch 3), a second stimulus is presented in a randomly chosen modality during task epoch 4 with the same fixed magnitude. The direction of the second stimulus is either in the same quadrant as the first stimulus, corresponding to a match, or in the opposite quadrant, corresponding to a non-match. On match trials, the RNN must respond during task epoch 5 in the direction of the second stimulus. On non-match trials, it must instead continue maintaining fixation during task epoch 5.
• Delayed non-match-to-category (dnmc): This is the same as dmc, except the RNN must maintain fixation during task epoch 5 on match trials, and respond in the direction of the second stimulus on non-match trials.
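The dms match criterion above (and, with the match and non-match responses swapped, dnms) reduces to a circular-distance test against the ±π/18 tolerance. The following is a minimal sketch with hypothetical helper names of our own, not the paper's task code:

```python
import numpy as np

def is_dms_match(theta1, theta2, tol=np.pi / 18):
    """Match test for the delayed match-to-sample (dms) task.

    A trial is a 'match' when the second stimulus direction lies
    within +/- pi/18 radians of the first, measured on the circle.
    """
    diff = np.angle(np.exp(1j * (theta2 - theta1)))  # wrap difference to (-pi, pi]
    return abs(diff) <= tol

def dms_target(theta1, theta2, fixation=None):
    """On match trials respond at theta2; on non-match trials keep fixating."""
    return theta2 if is_dms_match(theta1, theta2) else fixation
```

The complex-exponential trick wraps the angular difference onto the circle, so directions such as 0.05 and 2π - 0.05 radians are correctly treated as a match.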
Two-choice rule reversal task. The trial structure consists of 4 task epochs:

E Dynamical stability of proposed methods
We assume that the true dynamics of the network are stable. That is, the network is a stable system when simulated using Euler's method at the base time discretization ∆t with neuron time constant α∆t. In our simulations, α = 20. We observe empirically that as we increase the discretization step size for CD, the network diverges further from its true dynamics and generally tends to become unstable. In particular, we note that for our networks, the threshold for instability lies approximately within a small range around the neuron time constant. To model this, we first consider a linear approximation of the network dynamics: Here, we assume (from the above empirical observation) that the dominant eigenvalue of W*_θ depends on θ, where θ∆t is the time discretization used to simulate the network. In this framework, the eigenvalue exceeds 1 when θ reaches some threshold θ_thres. We also define W*_1 for the case when θ = 1, which is a stable system by definition. Simulating this linearized network for θ time steps yields: Applying the same approximation to the update equation for CD, we find a similar expression: We seek to understand why SCTT and DASC are not subject to the same threshold when increasing the temporal length of their skip connections. In fact, our simulations suggest that these skip connections can extend considerably beyond the CD threshold, up to several multiples of the time constant. This is a major advantage for SCTT and DASC when it comes to facilitating gradient stability over a large number of time steps. To gain an intuition for why this is the case, we apply the same approximation to the update equation for DASC: If we define W*_eff to be the effective linearized weight matrix: we find that it is of the same form as the linear solutions described in (3) and (5), implying that the stability of DASC depends on the eigenvalues of W*_eff. While W*_θ could have eigenvalues greater than 1, we know that W*_1 gives rise to a stable system (eigenvalues less than 1). This suggests that the overall weighted sum of the two weight matrices need not lead to an unstable system. From this analysis, we conclude that it is possible to implement skip connections that span longer than the limit faced by CD.
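For reference, one plausible form of the linearized updates discussed above, using only the symbols defined in the text; the precise expressions, including the β-weighting inside W*_eff, are our reconstruction under stated assumptions and may differ from the original:

```latex
% Linearization of the base network and its $\theta$-step iterate:
x_{t+1} \approx W^{*}_{1}\, x_{t}
\;\;\Longrightarrow\;\;
x_{t+\theta} \approx \bigl(W^{*}_{1}\bigr)^{\theta} x_{t}.

% CD simulated directly at the coarser step $\theta\Delta t$,
% unstable once the dominant eigenvalue of $W^{*}_{\theta}$ exceeds 1
% (i.e. for $\theta \ge \theta_{\mathrm{thres}}$):
x_{t+\theta} \approx W^{*}_{\theta}\, x_{t}.

% DASC mixes the stable base dynamics with the skip update,
% assuming a skip ratio $\beta \in [0, 1]$:
x_{t+\theta} \approx
  \underbrace{\Bigl[\beta \bigl(W^{*}_{1}\bigr)^{\theta}
  + (1-\beta)\, W^{*}_{\theta}\Bigr]}_{W^{*}_{\mathrm{eff}}}\; x_{t}.
```

Because W*_eff interpolates between the stable iterated base matrix and the potentially unstable W*_θ, its spectrum can remain inside the unit circle even when W*_θ alone would not, consistent with the argument in the text.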

F Full training results
The number of training steps taken by all methods across all hyperparameter configurations can be found in Figure S2. For CD, the three bars represent training efficiencies at step counts of 1, 5 and 10, respectively (left to right). For SCTT and DASC, the exact hyperparameter configuration of each bar (left to right) in the figure is given by: where the subscript represents the value of θ for each method (to be precise, it represents θ_0 for CD).

G Hyperparameter exploration
We find that the optimal configuration for DASC strongly depends on the specifics of the task (Figure S3A). For tasks that do not require maintenance of variables in neural activity-based memory (i.e. over long durations), a short skip length supports modified dynamics that are faithful to the true dynamics, while still providing the advantages of skip connections. In contrast, longer skip connections alleviate the gradient stability problems that naturally arise in tasks that do require memory, thus improving training efficiency on those tasks. In addition, we find that the training efficiency of models with SCTT depends on the annealing schedule for β (Figure S3B). Models train better with SCTT when training starts at high values of β, so that the modified dynamics better approximate the true dynamics. They also train better with annealing schedules that change β gradually, avoiding the abrupt changes in network dynamics during training that would otherwise increase the number of steps needed to complete training. This sensitivity is not observed in CD and DASC, which are less affected by their annealing schedules or, in the case of DASC, by the initial value of β.
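A schedule consistent with these observations could be sketched as follows. The linear form and the start/end values are our assumptions for illustration; the paper's exact schedules may differ:

```python
def beta_schedule(step, n_steps, beta_start=1.0, beta_end=0.0):
    """Linearly anneal the skip ratio beta over training.

    Starting near beta = 1 keeps the modified dynamics close to the
    true dynamics early in training; the slow, gradual decrease avoids
    abrupt changes in network dynamics mid-training.
    """
    frac = min(step / max(n_steps, 1), 1.0)   # clamp to [0, 1]
    return beta_start + frac * (beta_end - beta_start)
```

For example, `beta_schedule(0, 100)` returns the initial value 1.0 and `beta_schedule(100, 100)` returns 0.0; stretching `n_steps` yields the more gradual schedules that trained better in our SCTT experiments.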

H Additional training details
All models were trained on high-performance computing clusters consisting of NVIDIA A100 80GB GPUs. Computing resources are summarized in Table S3, showing that we trained over 90,000 models over more than 1,000 GPU hours (approx. 1,120 in total). The code used to train all models can be found at: https://github.com/wmws2/temporalskip
*Although there are 100 models per configuration for RNNs trained to perform the rule reversal task, only 10 models could be run in parallel at a time due to memory constraints. As such, GPU hours consumed were multiplied by a factor of 10 when converting to total GPU hours.

Figure S1: Biologically plausible RNNs in literature. A. Overview of how RNN models are used in neuroscience. B. Number of time steps used in the respective tasks that the RNNs were trained to perform. C. Neuron time constant values in the RNNs. D. Step sizes as a fraction of their respective time constants.

Figure S2: Training efficiencies of RNNs trained on all 16 standard tasks across all methods in all hyperparameter configurations. The control model is shown on the left (black). Red, blue and green bars indicate training efficiencies of CD, SCTT and DASC respectively. See section F for the exact configurations for each bar.

Figure S3: Analysis of factors affecting training efficiency of each algorithm. A. Training steps pooled over all 16 tasks against the skip length used by SCTT (green) and DASC (blue). A skip length of 20 corresponds to a time interval of one time constant. Plots are differentiated between tasks that require long-term dependencies and tasks that do not. B. Effects of the initial skip ratio (left) and step count (right) on the training efficiency of each respective method. Wilcoxon signed-rank test confidence levels: * p < 5 × 10⁻⁵, ** p < 5 × 10⁻¹⁰.

Table S1: Model parameters of biological RNNs in literature.
*Note that if a given reference contains several models, then the parameters of the most relevant model are reported here (if equally relevant, then one is chosen at random).

Table S3: Computing resources. A. Total number of models