Abstract
The dominant view in neuroscience is that changes in synaptic weights underlie learning. It is unclear, however, how the brain is able to determine which synapses should change, and by how much. This uncertainty stands in sharp contrast to deep learning, where changes in weights are explicitly engineered to optimize performance. However, the main tool for doing that, backpropagation, is not biologically plausible, and networks trained with this rule tend to forget old tasks when learning new ones. Here we introduce the Dendritic Gated Network (DGN), a variant of the Gated Linear Network [1, 2], which offers a biologically plausible alternative to backpropagation. DGNs combine dendritic “gating” (whereby interneurons target dendrites to shape neuronal response) with local learning rules to yield provably efficient performance. They are significantly more data efficient than conventional artificial networks and are highly resistant to forgetting, and we show that they perform well on a variety of tasks, in some cases better than backpropagation. The DGN bears similarities to the cerebellum, where there is evidence for shaping of Purkinje cell responses by interneurons. It also makes several experimental predictions, one of which we validate with in vivo cerebellar imaging of mice performing a motor task.
1 Introduction
A hallmark of intelligent systems is their ability to learn. Humans, for instance, are capable of amazing feats – language acquisition and abstract reasoning being the most notable – and even fruit flies can learn simple reward associations [3, 4]. It is widely believed that learning is implemented via synaptic plasticity. But which synapses should change in response to, say, the appearance of a reward, and by how much? This is especially hard to answer in humans, who have about $10^{14}$ synapses, but it is hard even in fruit flies, which have about $10^8$ – corresponding to 100 million adjustable parameters.
One answer to this question is known: introduce a loss function (a function that measures some aspect of performance, with higher performance corresponding to lower loss), compute the gradient of the loss with respect to the weights (find the direction in weight space that yields the largest improvement in performance), and change the weights in that direction. If the weight changes are not too large, this will, on average, reduce the loss, and so improve overall performance.
This approach has been amazingly successful in artificial neural networks, and has in fact driven the deep learning revolution [5]. However, the algorithm for computing the gradient in deep networks is not directly applicable to biological systems as first pointed out by [6, 7] (see also recent reviews, [8–10]). First, to implement backpropagation [11–13], referred to simply as backprop, neurons would need to know their outgoing weights. Second, backprop requires two stages: a forward pass (for computation) and a backward pass (for learning). Moreover, in the backward pass an error signal must propagate from higher to lower areas, layer by layer (Fig. 1A), and information from the forward pass must remain in the neurons. However, biological neurons do not know their outgoing weights, and there is no evidence for a complicated, time-separated backward pass.
Figure 1. Comparison of multi-layer perceptrons (MLPs) and Dendritic Gated Networks (DGNs). In all panels the blue filled circles at the bottom correspond to the input. A. MLP. Blue arrows show feedforward computations; red arrows show the error propagating back down. B. DGN. As with MLPs, information propagates up, as again shown by the blue arrows. However, rather than the error propagating down, each layer receives the target output, which it uses for learning. C. A single postsynaptic neuron in layer k of a DGN, along with several presynaptic neurons in layer k−1. Each branch gets input from all the presynaptic neurons (although this is not necessary), and those branches are gated on and off by inhibitory interneurons which receive external input. The white interneuron is active, so its corresponding branch is gated off, as indicated by the light gray branches; the gray neurons are not active, so their branches are gated on.
Backprop also leads to another problem, at least in standard deep learning setups: it adapts to the data it has seen most recently, so when learning a new task it forgets old ones [14]. This is known as catastrophic forgetting, and it prevents networks trained with backprop from displaying the lifelong learning that comes so easily to essentially all organisms [15, 16].
Driven in part by the biological implausibility of backprop, there have been several proposals for architectures and learning rules that might be relevant to the brain. These include feedback alignment [17, 18], creative use of dendrites [19, 20], multiplexing [21], and methods in which the error signal is fed directly to each layer rather than propagating backwards from the output layer [22–29]. A particularly promising method that falls into the latter category is embodied in Gated Linear Networks [1, 2]. These networks, which were motivated by a machine learning rather than a neuroscience perspective, have obtained state-of-the-art results in regression and denoising [30], contextual bandit optimization [31], and transfer learning [32].
In Gated Linear Networks, the goal of every neuron, irrespective of its layer, is to predict the target output based on the input from the layer directly below it. This is very different from backprop, in which neurons in intermediate layers extract features that make it easier for subsequent layers to predict the target (compare Figs. 1A and B). Gated Linear Networks are thus particularly suitable for biologically plausible learning: every neuron is essentially part of a shallow network, with no hidden layers, for which the delta rule [33] – a rule that depends only on presynaptic and postsynaptic activity, the latter relative to the target activity – is sufficient to learn.
To implement these local learning rules, the target activity is sent to every neuron, in every layer of the network (Fig. 1B, red arrows). This is typical of a large class of learning rules [22, 23, 25–29]. Completely atypical, though, is the role of the external input. It’s used for gating the weights: each neuron has a bank of weights at its disposal, and the external input determines which one is used. For example, a neuron might use one set of weights when the visual input contains motion cues predominantly to the right; another set of weights when it contains motion cues predominantly to the left; and yet another when there are no motion cues at all.
Endowing each neuron with a library of weights is, of course, highly inconsistent with what we see in the brain. So instead we gate dendritic branches on and off, using inhibitory neurons, in an input-dependent manner (Fig. 1C). We thus refer to these networks as Dendritic Gated Networks (DGNs). Dendritic gating allows DGNs to represent essentially arbitrary nonlinear functions. Moreover, gating makes DGNs especially resistant to forgetting. In particular, when data comes in sufficiently separate “tasks”, they can learn new ones without forgetting the old. Finally, the loss is a convex function of the weights for each unit (see Supplementary Methods), as it is in Gated Linear Networks [1]. Convexity is an extremely useful feature, as it enables DGNs, like the Gated Linear Networks on which they are based, to learn (provably) efficiently.
Below we describe multi-layer Dendritic Gated Networks in detail – both the architecture and the learning rule. We then train them on three tasks: one on which networks trained with backprop typically exhibit catastrophic forgetting, and two relevant to the cerebellum. Finally, we show experimentally that in the cerebellum gating remains relatively stable over time – a key prediction of our model. We map the proposed learning rule and the associated architecture to the cerebellum because (1) the climbing fibers provide a feedback signal; (2) its input-output function is relatively linear [34–36]; and (3) molecular layer interneurons could act as gates [37–46].
2 Results
2.1 Dendritic Gated Networks
Dendritic Gated Networks, like conventional deep networks, are made up of multiple layers, with the input to each layer consisting of a linear combination of the activity in the previous layer. Unlike conventional deep networks, though, the weights are controlled by external input, via gating functions, denoted g(x) (Fig. 1B); those functions are implemented via dendritic branches (Fig. 1C).
This results in the following network equations. The activity (i.e., the instantaneous firing rate) of the $i$th neuron in layer $k$, denoted $r_{k,i}$, is
$$r_{k,i} = \phi(h_{k,i}),$$
with the synaptic drive, $h_{k,i}$, given in terms of $r_{k-1,j}$ as
$$h_{k,i} = \sum_{b=1}^{B_{k,i}} g_{k,i,b}(x) \sum_{j=0}^{n_{k-1}} w_{k,i,b,j}\, r_{k-1,j}. \tag{1}$$
Here $\phi(\cdot)$ is the activation function (either identity or sigmoid), $x$ is the input to the network (an $n$-dimensional vector, $x = (x_1, x_2, \ldots, x_n)$), $r_{k,i}$ is the activity of the $i$th neuron in layer $k$ (with $r_{k,0}$ set to 1 to allow for a bias term), $B_{k,i}$ is the number of branches of neuron $i$ in layer $k$, $w_{k,i,b,j}$ is the weight from neuron $j$ in layer $k-1$ to the $b$th branch of neuron $i$ in layer $k$, $n_k$ is the number of neurons in layer $k$, and $g_{k,i,b}(x)$ is the binary gating variable; it is either 1 (in which case the $b$th branch of the $i$th neuron is gated on) or 0 (in which case it is gated off). There are $K$ layers, so $k$ runs from 1 to $K$. The input to the bottom layer is $x$. The mapping from the input, $x$, to the gating variable, $g_{k,i,b}(x)$, is not learned; instead, it is pre-specified, and does not change with time. In all of our simulations we use random half-space gating [1]; that is,
$$g_{k,i,b}(x) = \Theta\!\left(v_{k,i,b} \cdot x - \theta_{k,i,b}\right), \tag{2}$$
where $\Theta(\cdot)$ is the Heaviside step function, the gating vectors $v_{k,i,b}$ and biases $\theta_{k,i,b}$ are sampled randomly and kept fixed throughout learning (see Methods), and "$\cdot$" is the standard dot product.
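To make this concrete, the forward pass of Eqs. (1) and (2) can be written in a few lines. The sketch below is illustrative only: the array shapes and names are ours, and the reference implementation linked under Code availability differs in detail.

```python
import numpy as np

def halfspace_gates(x, V, theta):
    # Eq. (2): branch b is gated on iff v_b . x exceeds its fixed threshold.
    # V: (branches, dim_x); theta: (branches,). Returns a 0/1 vector of gates.
    return (V @ x > theta).astype(float)

def dgn_layer(r_prev, x, W, V, theta, phi=lambda h: h):
    # Eq. (1) for one layer. r_prev: (n_prev,) activities of layer k-1,
    # with r_prev[0] == 1 to implement the bias term.
    # W: (neurons, branches, n_prev); V: (neurons, branches, dim_x);
    # theta: (neurons, branches). phi is the activation (identity or sigmoid).
    h = np.array([halfspace_gates(x, V[i], theta[i]) @ (W[i] @ r_prev)
                  for i in range(W.shape[0])])
    return phi(h)
```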
In Dendritic Gated Networks, the goal of each neuron is to predict the target, denoted $r^*$. To do that, the weights, $w_{k,i,b,j}$, are modified to reduce the loss, $l_k(r^*, r_{k,i})$. For weight updates we use gradient descent,
$$\Delta w_{k,i,b,j} = -\eta\, \frac{\partial l_k(r^*, r_{k,i})}{\partial w_{k,i,b,j}}, \tag{3}$$
where $\eta > 0$ is the learning rate, and the updates are performed after each sample. The precise form of the loss can influence both the speed of learning and the asymptotic performance, but conceptually we should just think of it as some distance between $r^*$ and $r_{k,i}$. In most of the simulations, we assume $\phi$ is the identity ($r_{k,i} = h_{k,i}$) and we use the quadratic loss
$$l_k(r^*, r_{k,i}) = \frac{1}{2}\left(r^* - r_{k,i}\right)^2, \tag{4}$$
so the update rule becomes
$$\Delta w_{k,i,b,j} = \eta\, g_{k,i,b}(x) \left(r^* - r_{k,i}\right) r_{k-1,j}. \tag{5}$$
This has the form of a gated version of the delta rule [33]. See Methods, Sec. 4.1 for the alternative formulation, which we use for classification problems.
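A corresponding sketch of the gated delta rule, Eq. (5), for a single neuron, under the same naming assumptions as the snippet above:

```python
import numpy as np

def delta_rule_update(W_i, gates, r_prev, r_out, r_star, eta):
    # Eq. (5): gated delta rule for one neuron with identity activation and
    # quadratic loss. Branches that are gated off (gates == 0) are untouched.
    # W_i: (branches, n_prev); gates: (branches,); r_prev: (n_prev,).
    return W_i + eta * (r_star - r_out) * np.outer(gates, r_prev)
```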
2.2 Simulations
Equations (1) and (2) for the network dynamics and Eq. (3) for learning constitute a complete description of our model. For a given problem, we just need to choose a target input-output relationship (a mapping from $x$ to $r^*$) and specify the loss function, $l_k(r^*, r_{k,i})$. Here we consider three tasks. The first (catastrophic forgetting) task is classification, for which we use a sigmoid activation for $\phi$ and log loss (Methods, Sec. 4.1); the remaining two (cerebellar) tasks are regression, for which we use the identity activation and quadratic loss (Sec. 2.1).
DGNs can mitigate catastrophic forgetting
Animals are able to acquire new skills throughout life, seemingly without compromising their ability to solve previously learned tasks [15, 16]. Standard networks do not share this ability: when trained on two tasks in a row, they tend to forget the first one (see Fig. 2, second row). This phenomenon, known as “catastrophic forgetting”, is an old problem [47–49], and many algorithms have been developed to address it. These algorithms typically fall into two categories. The first involves replaying previously seen tasks during training [49–51]. The second involves explicitly maintaining additional sets of model parameters related to previously learned tasks. Examples include freezing a subset of weights [52, 53], dynamically adjusting learning rates [54], and augmenting the loss with regularization terms with respect to past parameters [55–57]. A limitation of these approaches (aside from additional algorithmic and computational complexity) is that they require task boundaries to be provided or accurately inferred.
Figure 2. Comparison of the DGN to a standard multi-layer perceptron (MLP) trained with backprop. Each point on the square has to be classified as "red" ("Class 1") or "blue" ("Class 0"). We consider a scenario common in the real world (but difficult for standard networks): the data comes in two separate tasks, as shown in the first row. We trained a 2-layer MLP (second row) and a 2-layer DGN (third row) on the two tasks. The output of the network is the probability of each class, as indicated by color; the percentages report the accuracy on each of the tasks. The MLP uses ReLU activation functions, so each neuron has an effective gating; the boundaries of those gates are shown in gray. The boundaries move with learning, and are plotted at the end of training on each of the tasks (white lines). The boundaries of the DGN do not move, so we plot them only in the first column. After training on Task A, most of the boundaries in the MLP are aligned at −45 degrees, parallel to the decision boundaries, which allows the network to perfectly separate the two classes. In the DGN, the boundaries do not change, but the network also perfectly separates the two classes. However, after training on Task B, the DGN retains high performance on Task A (91%), while the MLP's performance drops to 66%. That's because many of the MLP's boundaries rotated to the orthogonal direction (45 degrees). For the DGN, on the other hand, changes to the network were much more local, allowing it to retain the memory of the old task (see samples from Task A overlaid on all panels) while accommodating the new one. The MLP has 50 neurons in the hidden layer; the DGN has 5 neurons with 10 dendritic branches each in the hidden layer.
Unlike contemporary neural networks, the DGN architecture and learning rule are naturally robust to catastrophic forgetting, without any modifications or knowledge of task boundaries. In Fig. 2 we illustrate the mechanism behind this robustness on a single example, and show how it differs from a standard multi-layer perceptron. To demonstrate this in a more challenging setting, we train a DGN on the pixel-permuted MNIST continual learning benchmark [55, 58]. In this benchmark, the network has to learn random permutations of the input pixels, with the random permutation changing every 60,000 trials (see Methods for additional details). We compare the DGN to a multi-layer perceptron (MLP) with and without elastic weight consolidation (EWC) [55], following the original papers [2, 55]. EWC is a highly effective method explicitly designed to prevent catastrophic forgetting by storing parameters of previously seen tasks. However, it has a much more complicated architecture, and it must be supplied with task boundaries, so it receives more information than the DGN.
Because MNIST has 10 digits, we train 10 different DGNs. The alternative would be a single DGN in which each unit has a 10-dimensional output corresponding to the class probabilities. However, that setting is less biologically plausible, so we did not use it. Each of the 10 networks contains 3 layers, with 100, 20, and 1 neurons per layer, and has 10 dendritic branches per neuron. The targets are categorical (1 if the digit is present, 0 if it is not), so we use Bernoulli log loss rather than quadratic loss (see Methods, Sec. 4.1). We use 1000, 200, and 10 neurons per layer for the MLP (so that the number of neurons matches the number used for the DGNs), with cross entropy loss, both with and without elastic weight consolidation, and we optimize the learning rates separately for each network.
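At test time, the 10 per-digit networks are combined by taking the most probable class; a minimal sketch, assuming each network is wrapped as a callable that returns the output of its single final neuron:

```python
import numpy as np

def predict_digit(digit_networks, x):
    # digit_networks: 10 callables, one per digit, each returning the output
    # of its final neuron, i.e. an estimate of P(digit | x).
    probs = [net(x) for net in digit_networks]
    return int(np.argmax(probs))  # during testing, pick the most probable class
```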
Figure 3 shows the learning and retention performance of the DGN, with the MLP and EWC networks included primarily as benchmarks (recall that neither is biologically plausible). In Fig. 3A we plot performance on each task for the three networks; as can be seen, performance is virtually identical. In Fig. 3B we investigate resistance to forgetting, by plotting the performance on the first task as the nine subsequent tasks are learned. The EWC network retains its original performance almost perfectly, the MLP forgets rapidly, and the DGN is in-between. It is not surprising that the EWC does well, as it was tailored to this task, and in particular it was explicitly given task boundaries. Somewhat more surprising is the performance of the DGN, which had none of these advantages but still forgets much more slowly than the MLP. The DGN also learns new tasks more rapidly than either the EWC or MLP networks (Supplementary Figure S1), possibly because of its convex loss function.
Figure 3. Learning and retention on the permuted MNIST task. The tasks are learned sequentially in a continual learning setup. A. Performance (on test data) for each of the 10 tasks, where a "task" corresponds to a random permutation of the pixels. B. Performance on the first task after each of nine new tasks is learned. As discussed, the MLP is especially bad at this task. The EWC is much better, to a large extent because it was provided with extra information: the task boundaries. Even though the DGN was not given that information, it forgets a factor of two more slowly than the MLP. Error bars in both plots denote 95% confidence intervals over 20 random seeds.
Mapping DGNs to the Cerebellum
For the next two simulations we consider computations that can be mapped onto cerebellar circuitry. We focus on the cerebellum for several reasons: it is highly experimentally accessible; its architecture is well characterized; there is a clear feedback signal (the climbing fiber) to the Purkinje cells (the cerebellar neurons principally involved in learning); its input-output function is relatively linear [34–36]; and molecular layer interneurons play a major role in shaping Purkinje cell responses [37–43, 45], and can influence climbing fiber-mediated dendritic calcium signals in Purkinje cells [44, 46].
Both classic and more modern theoretical studies of the cerebellum have focused on the cerebellar cortex, modelling it as a one-layer feedforward network [59–63]. In this view, the parallel fibers project to Purkinje cells, and their synaptic weights are adjusted under the feedback signal from the climbing fibers. This picture, however, is an over-simplification, as Purkinje cells do not directly influence downstream structures. Instead, they project to the cerebellar nucleus neurons, which constitute the ultimate output of the cerebellum (see Fig. 4). The fact that Purkinje cells form a hidden layer, combined with the observed plasticity in the Purkinje cell to cerebellar nucleus synapses [64–68], means that most learning rules tailored to one-layer networks, including the delta rule, cannot be used to train the network.
Figure 4. The cerebellum as a two-layer DGN. Contextual information from the mossy fiber (MF)/granule cell (GC) pathway is conveyed as input to the network via parallel fibers (PFs) that form synapses onto both the dendritic branches of Purkinje cells and molecular layer interneurons (MLIs). The inhibitory MLIs act as input-dependent gates of Purkinje cell dendritic branches. Purkinje cells converge onto the cerebellar nuclear neurons (CbNs), which constitute the output of the cerebellar network. The climbing fibers (CFs, red) originating in the inferior olive (IO) convey the feedback signal that is used to tune both the Purkinje cells (depending on which inputs are gated on or off) and the CbNs. Excitatory and inhibitory connections are depicted as round- and T-ends, respectively. Dashed lines represent connections not included in the model.
We propose instead that the cerebellum acts as a two-layer DGN, with Purkinje cells as the first, hidden layer and the cerebellar nucleus as the second, output layer (Fig. 4). Parallel fibers provide the input both to the Purkinje cells and to the gates, represented by molecular layer interneurons, that control learning in individual Purkinje cell dendrites. For the second layer of the DGN, we use a non-gated linear neuron rather than a gated neuron. This is because the unique biophysical features of cerebellar nuclear neurons allow them to integrate input linearly [69]. Note that we can keep the DGN formulation given in Eq. (1); in the second layer we just use one branch ($B_{2,i} = 1$) which is always gated on ($g_{2,i,1}(x) = 1$). Finally, the climbing fibers provide the feedback signal to Purkinje cells and cerebellar nuclear neurons. In our formulation, climbing fiber feedback signals the target, allowing each neuron to compute its own local error by comparing the target to its output ($r_{k,i}$). This formulation is a departure from the strict error-coding role that is traditionally attributed to climbing fibers, but is consistent with a growing body of evidence that climbing fibers signal a variety of sensorimotor and cognitive predictions [70].
DGNs can learn inverse kinematics
The cerebellum is thought to implement inverse motor control models [71, 72]. We therefore applied our proposed DGN network to the SARCOS benchmark [73], which is an inverse kinematics dataset collected using a 7 degree-of-freedom robot arm (Fig. 5). The goal is to learn an inverse model, and predict 7 joint torques given the joint positions, velocities, and accelerations for each of the 7 joints (corresponding to a 21 dimensional input).
Figure 5. SARCOS experiment. DGNs can solve a challenging motor control task: predicting torques from proprioceptive inputs. The data comes from a SARCOS dexterous robotic arm [73], pictured on the left. The inputs are the position, velocity, and acceleration of the joints (21-dimensional variables); the targets are the desired torques (7-dimensional). Example targets (normalized to keep the training data between 0 and 1) are shown as dots; the lines are the output of our network. Performance is very good; only rarely is there a visible difference between the dots and the lines.
The target output, r*, is the desired torque, given the 21-dimensional input. There are seven joints, so we train seven different networks, each with its own target output. We use DGN networks with 20 Purkinje cells, each having 5000 branches, and minimize the quadratic loss (4).
In Fig. 5 we plot the target torques for each joint (dots) along with the predictions of the DGN (lines; chosen for ease of comparison as there is no data between the points). The lines follow the points very closely, even when there are large fluctuations, indicating that the DGN is faithfully predicting torques. The performance of our network (mean squared error on test data in the original torque units) exceeds that of most machine learning algorithms (Supplementary Table S1) while using fewer (or an equal number of) samples to learn. This illustrates the power of DGNs; we now turn to a cerebellar task much more typical of computational and experimental neuroscience.
Vestibulo-ocular reflex, and adaptation to gain changes
When an animal moves its head, to maintain a stable image on the retina it moves its eyes in the opposite direction. This is known as the vestibulo-ocular reflex (VOR), and a key feature of it is that it is plastic: animals can adapt quickly when the relationship between the head movement and visual feedback is changed, as occurs as animals grow or are given corrective lenses. VOR gain adaptation relies critically on the cerebellum, and has been used to study cerebellar motor learning for decades [74–78].
We thus applied our DGN network to model learning of VOR gain changes. The gain, denoted G, is the ratio of the desired eye velocity to the head velocity (multiplied by −1 because the eyes and head move in opposite directions, to keep with the convention that the gain is reported as a positive number). When the gain is (artificially) changed, at first animals move their eyes at the wrong speed, but after about 15 minutes they learn to compensate [76, 77].
We trained our network on a head velocity signal of the form
$$s(t) = \sin(\omega_1 t)\,\sin(\omega_2 t), \tag{6}$$
with $\omega_1 = 13.333$ rad/s and $\omega_2 = 20.733$ rad/s (corresponding to 2.12 and 3.30 Hz, respectively).
This was chosen to mimic, approximately, the irregular head velocities encountered in natural viewing conditions. Following Clopath et al. [79], we assumed that the Purkinje cells receive delayed versions of this signal. The $i$th input signal, $x_i(t)$, which arrives via the parallel fibers, is modelled as
$$x_i(t) = s(t - \tau_i), \tag{7}$$
with delays, $\tau_i$, spanning the range 50–300 ms. The cerebellum needs to compute the scaled version of the eye velocity: $r^*(t) = G\,s(t)$ (as mentioned above, the actual eye movement is $-r^*(t)$, but we follow the standard convention). Learning was online, and we updated the weights every 500 ms, to approximately match the climbing fiber firing rate [80].
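For illustration, the training signal can be generated as follows. This is a sketch: the functional form of Eq. (6), the 500 ms update interval, and the 50–300 ms delay range follow the description above, while the remaining details (e.g., the evenly spaced delay grid) are our assumptions.

```python
import numpy as np

w1, w2 = 13.333, 20.733                  # rad/s (2.12 Hz and 3.30 Hz)

def head_velocity(t):
    # Eq. (6): a product of incommensurate sinusoids gives an irregular,
    # quasi-periodic signal resembling natural head movements.
    return np.sin(w1 * t) * np.sin(w2 * t)

dt = 0.5                                  # s; weights are updated every 500 ms
t = np.arange(0.0, 30.0 * 60.0, dt)       # 30 minutes per gain condition
taus = np.linspace(0.05, 0.30, 100)       # parallel-fiber delays: 50-300 ms

X = head_velocity(t[:, None] - taus[None, :])  # Eq. (7): x_i(t) = s(t - tau_i)
G = 1.3                                   # current gain
r_star = G * head_velocity(t)             # target output at each update step
```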
The DGN contained 20 Purkinje cells, with 10 branches each. As a baseline, we trained an MLP with the same number of weights (resulting in 200 hidden neurons). We used quadratic loss for both the DGN and the MLP and, as in [79], we assumed n = 100 parallel fibers and a single output. Each branch received input from all 100 parallel fibers. Gating (Eq. (2)) was controlled by $x_i(t)$ (given in Eq. (7)), reflecting the parallel fiber influence on molecular layer interneurons (Fig. 4); see Methods for details. Given the timescale of the signal (2-3 Hz), any individual branch was gated on for about 500 ms at a time. The networks were pre-trained on a gain, G, of 1. We implemented four jump changes: first to 0.7, then back to 1.0, then to 1.3, and, finally, back to 1.0; in all cases, for 30 minutes (Fig. 6A).
Figure 6. VOR adaptation task. We trained the networks on gain G = 1, then changed the gain every 30 minutes. Results are shown for the Dendritic Gated Network (DGN) and a multi-layer perceptron (MLP). A. Dashed lines are true gain versus time; blue and purple lines are gains computed by the DGN and MLP, respectively. For both networks, gains were inferred almost perfectly after 15-20 minutes. B. Performance, measured as mean squared error between the true angular velocity, $G s(t)$ (Eq. (6)), and the angular velocity inferred by the networks. Same color code as panel A. C. Comparison of target angular velocity versus time (black) to that predicted by the DGN (blue). (A plot for the MLP is similar.) Before the gain change, the two are almost identical; immediately after the gain change, the network uses the previous gain. D. Top: parallel fiber weights for the DGN versus delay, $\tau_i$ (Eq. (7)). Each panel shows 10 branches; 5 Purkinje cells are shown (chosen randomly out of 20). The weights vary smoothly with delay. Bottom: the same weight profile for the MLP, except that dendritic branches are replaced by whole neurons (all 100 parallel fibers). For the MLP, weights with similar delays are effectively uncorrelated.
Performance for the DGN and the MLP was comparable and, after suitably adjusting the learning rates, both networks were able to learn in 15-20 minutes (Fig. 6A, B). Figure 6C shows the target and predicted angular velocities immediately before and after each gain change. Not surprisingly, immediately after a gain change, the network produces output with the old gain.
Although both the DGN and the MLP solve this task, their internal mechanisms are remarkably different. Figure 6D shows the connection strengths between parallel fibers (xi(t), Eq. (7)) and Purkinje cells, after learning, as a function of the delay, τi. Every branch of the DGN (top panel; blue) develops a smooth connectivity pattern: parallel fibers that have similar delays have similar strengths. The smoothness of weights versus delay constitutes a strong prediction of our model.
2.3 Testing predictions of the DGN in behaving animals
A critical feature of our model is that the gates (in the case of the cerebellum, the molecular layer interneurons) should remain stable over learning, or at least be more stable than other parts of the circuit, such as climbing fiber inputs to Purkinje cells. To test this, we performed simultaneous two-photon calcium imaging of molecular layer interneurons (MLIs) and Purkinje cell dendrites (Fig. 7A,B) in awake behaving head-fixed mice. Imaging occurred while the mouse was learning to make an association between a tone cue and a reward (Fig. 7C; note the absence of licks between the cue and the reward in the first 18 trials, before learning). To assess the stability of responses across learning, we computed a trial-wise population vector response to reward delivery (vector of mean response in 1 second following reward delivery for each neuron; Fig. 7D). We compared the stability of these population response vectors in MLIs and Purkinje cell dendrites (reflecting climbing fiber input) over the course of the first 125 cue-reward pairing trials (Fig. 7E). The MLI population vector response was significantly more stable across these learning trials than the corresponding population response vector in Purkinje cell dendrites (Fig. 7F), consistent with the tendency for DGN gates to remain stable while other elements evolve with learning.
Figure 7. Testing experimental predictions of the DGN during learning of a cue-reward association. A. Simultaneous multi-plane two-photon imaging of molecular layer interneurons (red hues) and Purkinje cell dendrites (blue hues) expressing GCaMP7f. Images were acquired across 5 planes at an effective rate of 9.7 Hz. B. Example traces of simultaneously recorded MLIs (red, top) and Purkinje cell dendrites (blue, bottom). C. Licking responses of mice during the initial 125 trials of cue-reward pairing, showing licking on individual trials (top) and mean lick probability (bottom). D. Trial-wise reward delivery responses in MLIs (left, n = 15) and Purkinje cell dendrites (right, n = 67), calculated as the mean response in a 1 s window after reward delivery. E. Similarity matrix of population vector responses for MLIs (left) and Purkinje cell dendrites (right). F. Mean pairwise correlations of population vector responses in MLIs and Purkinje cell dendrites. MLI responses exhibit greater trial-by-trial consistency. Data are shown as mean ± S.E.M.
3 Discussion
A critical open question in neuroscience is: what learning rules ensure that synaptic strengths are updated in a way that improves performance? Answering this is difficult in large part because of the way we think about computation, which is that networks map input to output in stages, with the input gradually transformed, until eventually, in the output layer, the relevant features are easy to extract. There is certainly some evidence for this. It is, for example, much harder to extract which face a person is looking at from activity in visual area V1 than in fusiform face area [81, 82]. While this strategy for computing is reasonable, it has a downside: the relationship between activity in intermediate layers and activity in the output layer is highly nontrivial, which makes it especially hard for the brain to determine how weights in intermediate layers should change.
Here we propose that the brain might take a different approach, one based on Dendritic Gated Networks, or DGNs, which is a variant of the Gated Linear Network [1, 2]. With this architecture, each neuron is active for a relatively small region of the input space; for the rest, it is gated “off”. Each neuron receives its input from the layer below, as in conventional networks, but its goal is not to transform that input; instead, its goal is to predict the output of the whole network. That makes the role of every neuron transparent (all neurons in all layers are doing the same thing), which makes learning simple – all that is required is a delta rule.
The ease of learning makes DGNs strong candidates for biological networks. In addition, we showed they are compatible with the architecture and function of the cerebellum, and that they perform well on three nontrivial tasks. Finally, we supplied preliminary experimental support for gating, which in the cerebellum we hypothesize is done by the molecular layer interneurons.
DGNs make three strong predictions for the cerebellum. First, the activity of the molecular layer interneurons should depend solely on parallel fiber input and should not change with learning – or change very slowly relative to the timescale over which Purkinje cells learn, the latter measured in single trials [83]. This prediction is consistent with our in vivo imaging experiments. Second, dendritic branches should be in one of two states, determined by molecular layer interneuron activity: either a branch receives very little MLI input, so that it can transmit information from parallel fibers to Purkinje cells, or it receives very large MLI input, so that it cannot transmit information. Testing this second prediction is challenging, but could be addressed using a combination of cellular resolution all-optical stimulation and voltage imaging, a technical feat that may soon be within reach [84, 85]. Third, for parallel fibers carrying delayed information about head velocity, the parallel fiber to Purkinje cell weights should be a smooth function of the delay (Fig. 6D, top).
In summary, Dendritic Gated Networks are strong candidates for biological networks – and not just in the cerebellum; they could be used anywhere there is approximately feedforward structure. They come with two desirable features: biologically plausible learning, and rapid, data-efficient learning. And they imply a novel role for inhibitory neurons, which is that they are used for gating dendritic branches on and off. Importantly, they make strong, experimentally testable, predictions, so we will soon know whether they are actually used in the brain.
4 Methods
4.1 Model
The network we use in our model is described in Eqs. (1) and (2), and the learning rule is given in Eq. (3). In particular, Eq. (5) is used in all our simulations except for MNIST, where the output is categorical. In that case, we bound neural activities so they can represent probabilities. We use a standard sigmoid function, $\sigma(z) = e^z/(1 + e^z)$, albeit modified slightly,
$$\phi(z) = \mathrm{clip}_{\epsilon}^{1-\epsilon}\!\left(\sigma(z)\right), \tag{8}$$
where $\mathrm{clip}_a^b(y) \equiv \min(\max(y, a), b)$ clips values between $a$ and $b$ (so the derivative of the right hand side is zero if $\sigma(z)$ is smaller than $\epsilon$ or larger than $1 - \epsilon$). Clipping is used for bounding the loss as well as the gradients; this helps with numerical stability, and also enables a worst-case regret analysis [1, 2]. We set $\epsilon$ to 0.01, so neural activity lies between 0.01 and 0.99.
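In code, Eq. (8) is simply the following (a sketch, with $\epsilon = 0.01$ as in our simulations):

```python
import numpy as np

def phi_clipped(z, eps=0.01):
    # Eq. (8): sigmoid with output clipped to [eps, 1 - eps]; the clipping
    # bounds both the loss and the gradients.
    return np.clip(1.0 / (1.0 + np.exp(-z)), eps, 1.0 - eps)
```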
The loss of neuron $i$ in layer $k$ in this case is given by
$$l_k(r^*, r_{k,i}) = -r^* \log r_{k,i} - (1 - r^*) \log(1 - r_{k,i}). \tag{9}$$
Consequently, the update rule for the weights, Eq. (3), is (after a small amount of algebra)
$$\Delta w_{k,i,b,j} = \eta\, g_{k,i,b}(x) \left(r^* - r_{k,i}\right) r_{k-1,j}\, \mathbf{1}\!\left(\epsilon \le \sigma(h_{k,i}) \le 1 - \epsilon\right), \tag{10}$$
where $\mathbf{1}(\cdot)$ is 1 when its argument is true and 0 otherwise. The fact that the learning is zero when $r_{k,i}$ is outside the range $[\epsilon, 1 - \epsilon]$ follows because $d\phi(z)/dz = 0$ when $\sigma(z)$ is outside this range (see Eq. (8)). This ensures that learning saturates when weights become too large (either positive or negative). However, this can cause problems if the output is very wrong: when $r^* = 1$ and $r_{k,i} < \epsilon$ or $r^* = 0$ and $r_{k,i} > 1 - \epsilon$. To address this, we allow learning in this regime. We can do this compactly by changing the learning rule to
$$\Delta w_{k,i,b,j} = \eta\, g_{k,i,b}(x) \left(r^* - r_{k,i}\right) r_{k-1,j}\, \mathbf{1}\!\left(|r^* - r_{k,i}| \ge \epsilon\right). \tag{11}$$
Essentially, this rule says: stop learning when $r_{k,i}$ is within $\epsilon$ of $r^*$. See [86] for a complementary view of how categorical problems might be solved by gated neurons in the brain.
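A sketch of this final classification update, following Eq. (11) and the naming conventions of the earlier snippets:

```python
import numpy as np

def delta_rule_update_classification(W_i, gates, r_prev, r_out, r_star,
                                     eta, eps=0.01):
    # Eq. (11): as Eq. (5), but learning stops only once the output is within
    # eps of the target, so updates continue even when the output is
    # confidently wrong.
    if abs(r_star - r_out) < eps:
        return W_i
    return W_i + eta * (r_star - r_out) * np.outer(gates, r_prev)
```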
For a compact summary of the equations (given as pseudocode), see Supplementary Algorithms S1 and S2.
4.2 Simulations
Simulations were written using JAX [87], the DeepMind JAX Ecosystem [88], and Colab [89].
Catastrophic Forgetting
We adopt the pixel-permuted MNIST benchmark [55, 58], which is a sequence of MNIST digit classification tasks with different pixel permutations. Each task consists of 60,000 training images and 10,000 test images; all images are deskewed. Models are trained sequentially across 10 tasks, performing a single pass over each. We provide the implementation details of the baselines below, and display the parameters swept during grid search in Supplementary Table S2.
DGN. We use networks composed of 100 and 20 units in the hidden layers and a single linear neuron for the output layer. Each neuron in the hidden layers has 10 dendritic branches. The output of the network is determined by the last neuron. MNIST has 10 classes, each corresponding to a digit. Therefore, we utilize 10 DGN networks, each encoding the probability of a distinct class. Each of these networks is updated during training using a learning rate of $10^{-2}$. During testing, the class with the maximum probability is chosen. Images are scaled and shifted so that the input range is $[-1, 1]$. The gating vectors, $v_{k,i,b}$, are chosen randomly on the unit sphere, which can be achieved by sampling from an isotropic normal distribution and then dividing by the L2 norm. The biases, $\theta_{k,i,b}$, are drawn independently from a centred normal distribution with standard deviation 0.05.
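The gating parameters can be sampled as follows (a sketch of the MNIST setting just described; function and variable names are ours):

```python
import numpy as np

def sample_gating(n_branches, dim_x, bias_std=0.05, seed=0):
    rng = np.random.default_rng(seed)
    # Gating vectors: isotropic normal samples projected onto the unit sphere.
    V = rng.standard_normal((n_branches, dim_x))
    V /= np.linalg.norm(V, axis=1, keepdims=True)
    # Biases: centred normal with standard deviation 0.05 (MNIST setting).
    theta = bias_std * rng.standard_normal(n_branches)
    return V, theta
```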
MLP and EWC. We use a ReLU network with 1000 and 200 neurons in the hidden layers and 10 linear output units with cross entropy loss. In this setting, the MLP and EWC have the same number of neurons as the DGN but fewer plastic weights in total. We use the Adam optimization method [90] with a learning rate of $10^{-4}$ (see Supplementary Table S2 for details of the hyperparameter optimization), in conjunction with dropout. We use mini-batches of 20 data points. For EWC, we draw 100 samples for computing the Fisher matrix diagonals and set the regularization constant to $10^3$.
Inverse Kinematics
Each DGN network has 20 Purkinje cells with 5000 branches each. We use a quadratic loss (4) with a learning rate $\eta = 10^{-5}$ for 2000 epochs (2000 passes over the dataset). The inputs are centered at 0 and scaled to unit variance per dimension; the targets are scaled so that they lie between 0 and 1. The reported MSEs are computed on the test set based on inverse transformed predictions (thus undoing the target scaling). The gating parameters are chosen in the same way as for the MNIST simulations (see above).
We discovered that the training set of the SARCOS dataset (downloaded from http://www.gaussianprocess.org/gpml/data/ on 15 December 2020) includes test instances. To the best of our knowledge, other recent studies using the SARCOS dataset [91, 92] reported results with this train/test setting. This means that the reported errors are measures of capacity rather than generalization. We compare the performance of DGN against the best known SARCOS results in Supplementary Table S1 using the existing train/test split. If we exclude the test instances from the train set, we get an MSE for the DGN of 0.84 using the same network setting and parameters.
VOR
The gating parameters, $v_{k,i,b}$ and $\theta_{k,i,b}$ (Eq. (2)), were drawn independently from the standard normal distribution. The learning rate was $\eta = 10^{-5}$ for the DGN and $\eta = 0.02$ for the MLP.
4.3 Animal experiments
Animal housing and surgery
All animal procedures were approved by the local Animal Welfare and Ethical Review Board and performed under license from the UK Home Office in accordance with the Animals (Scientific Procedures) Act 1986 and generally followed procedures described previously [93]. Briefly, we used PV-Cre mice (B6;129P2-Pvalbtm1(cre)Arbr/J) [94] crossed with C57/BL6 wild type mice. Mice were group housed before and after surgery and maintained on a 12:12 day-night cycle. Surgical procedures were identical to those described in [93], except that we injected Cre-dependent GCaMP7f (pGP-AAV-CAG-FLEX-jGCaMP7f-WPRE [serotype 1]; [95]) diluted from its stock titer at 1:25. After mice had recovered from surgery, they were placed under water restriction for at least 5 days during which time they were acclimated to the recording setup and expression-checked. All mice were maintained at 80-85 percent of their initial weight over the course of imaging. Trained mice typically received all their water for the day from reward during the behavioral task, while naïve mice were supplemented to 1 g water per day with Hydrogel.
Cue-reward association training
Mice were trained on a conditioning protocol in which an auditory cue (4 kHz, 100 ms duration) was paired with a reward delivered 500 ms after cue onset, similar to the conditioning paradigm described in [93]. Responses of MLIs and PC dendrites to reward delivery were recorded and analyzed during the first 125 trials after initial cue-reward pairing to assess response consistency across the initial learning phase of this association.
Two-photon calcium imaging, data acquisition, and processing
Imaging experiments were performed using a 16x/0.8 NA objective (Nikon) mounted on a Sutter MOM microscope equipped with the Resonant Scan Box module. A Ti:Sapphire laser tuned to 930 nm (Mai Tai, Spectra Physics) was raster scanned using a resonant scanning galvanometer (8 kHz, Cambridge Technologies) and images were collected at 512×256 pixel resolution over fields of view of 450×225 µm per plane. Volumetric imaging across 5 planes spaced by 10 µm (depth ranging 25-65 µm below the pial surface) was performed using a P-726 PIFOC High-Load Objective Scanner (Physik Instruments) at an effective volume rate of 9.7 Hz. The microscope was controlled using ScanImage (Version 2015, Vidrio Technologies) and tilted to 10 degrees such that the objective was orthogonal to the surface of the brain and coverglass. ROIs corresponding to single MLIs and PC dendrites were extracted using a combination of Suite2p software [96] for initial source extraction and custom-written software to merge PC dendritic ROIs across recording planes, which exhibited highly correlated calcium signals. Calcium signals corresponding to individual MLI somata and PC dendrites, which were easily distinguishable based on their shape, were computed as (F − F0)/F0, where F is the signal measured at each point in time and F0 is the 8th percentile within a 200 second rolling window surrounding each data time point. A neuropil correction coefficient of 0.5 (50 percent of the neuropil signal output from Suite2p) was applied to all ROIs. A range of baseline durations and neuropil correction coefficients were tested, and varying these parameters did not alter the main findings. Fluorescence changes for each neuron were then z-scored over time to facilitate comparisons between individual neurons with different baseline expression levels. Behavioural events and imaging synchronization signals were acquired using PackIO (see [93] for a detailed description) and aligned offline using custom written scripts.
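A sketch of the baseline ($F_0$) computation described above; the exact windowing and edge handling used in the original analysis code are assumptions here:

```python
import numpy as np

def delta_f_over_f(F, fs=9.7, window_s=200.0, pct=8):
    # F0: the 8th percentile within a ~200 s window surrounding each sample.
    half = int(window_s * fs / 2)
    F0 = np.array([np.percentile(F[max(0, i - half): i + half + 1], pct)
                   for i in range(len(F))])
    return (F - F0) / F0
```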
Code availability
We provide pseudo code in Supplementary Algorithms S1 and S2. A simple python implementation can be accessed via https://github.com/deepmind/deepmind-research/blob/master/gated_linear_networks/colabs/dendritic_gated_network.ipynb.
Data availability
The data that support the findings of this study are available from the corresponding authors upon reasonable request. Additional analysis made use of standard publicly available benchmarks including MNIST [97] and SARCOS (http://www.gaussianprocess.org/gpml/data/).
Author contributions
ES and AGB developed the computational model with advice from JV, CC, PEL, DB, and MHut. AGB, ES, and SK performed simulation experiments and analysis with advice from CC and PEL. DK and MBea acquired and analyzed neuronal data with advice from MHäu. PEL, AGB, ES, and DK wrote the paper with help from all other authors. AGB and ES managed the project with support from MBot, CC, JV, and DB.
Competing Interests
The authors declare no competing interests.
Supplementary Methods
Convexity
If we ignore clipping, which has no effect on the convexity proof, the structure of the loss as a function of the weight vector $w$ is as follows: $\ell(r^*, r)$ with $r = \phi(h)$ and $h = c \cdot w$. Concretely, for neuron $i$ in layer $k$, we have $r = r_{k,i}$ and $h = h_{k,i} \in \mathbb{R}$, with $w \equiv (w_{k,i,b,j})$ and $c \equiv (g_{k,i,b}(x)\, r_{k-1,j})$, where "$\cdot$" denotes a sum over $j$ and $b$. If $\ell(r^*, \phi(h))$ is convex in $h$, then $\ell$ is also convex in $w$, since $h$ is a linear function of $w$ (e.g., [98], Sec. 3.2.2). For the quadratic loss (4) and $\phi$ the identity, $\ell(r^*, h) = (r^* - h)^2/2$ is obviously convex in $h$, hence in $w$. For the log loss (9) and $\phi(h) = \sigma(h) = 1/(1 + e^{-h})$, it is easy to show that $\partial^2 \ell(r^*, \sigma(h))/\partial h^2 = \sigma(h)(1 - \sigma(h)) > 0$; hence, again, $\ell$ is convex in $h$ and therefore also in $w$.
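For completeness, the second derivative follows directly from $\sigma'(h) = \sigma(h)(1 - \sigma(h))$:

```latex
\frac{\partial \ell}{\partial h}
  = \frac{\partial}{\partial h}
    \left[ -r^* \log \sigma(h) - (1 - r^*) \log\bigl(1 - \sigma(h)\bigr) \right]
  = \sigma(h) - r^*,
\qquad
\frac{\partial^2 \ell}{\partial h^2} = \sigma(h)\bigl(1 - \sigma(h)\bigr) > 0.
```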
Inverse Kinematics
In Table S1 we compare the mean squared error (MSE) obtained by the DGN against baselines obtained from [30, 91, 92]. Note that, as mentioned in Methods, we (like others) used a train/test split in which the training set contains the test examples.
Catastrophic Forgetting (permuted MNIST)
Hyperparameter selection
We select the hyperparameters for the three methods utilizing a grid search. The swept and the chosen parameters are displayed in Table S2.
Learning curves
In Fig. S1 we display the test performance on previously learned tasks (columns) as a function of training across multiple tasks. To reduce clutter, a subset of the tasks (1, 2, 4, and 8, out of 10) is shown. The top left plot (train and test on task 1) shows that the DGN learns the first task much faster than all other methods. The plots to the right of it show retention on task 1 while the network is sequentially trained on subsequent tasks. The MLP's performance drops drastically after learning a few new tasks, while the DGN and EWC show little forgetting. This is a remarkable feat for the DGN, which has no access to task boundaries and no explicit memory of previously learned tasks; EWC, on the other hand, has both. Looking at the four diagonal plots, we see that the DGN learns new tasks faster than all other methods, although the difference gets smaller as more tasks are learned.
Table S2. Parameters swept during grid search. The best parameters (shown in bold) are the ones that maximize the average test accuracy over 20 random seeds.
The final accuracies across the diagonal correspond to the left panel of Figure 3 whereas the final accuracies across the first row correspond to the right panel.
Figure S1. Retention results for permuted MNIST. Models are trained sequentially on 8 tasks (rows) and evaluated on all previously encountered tasks (columns). For example, the top row indicates performance on task 1 after being trained sequentially on tasks 1, 2, 4 and 8. Each model trains for one epoch per task. Error bars, indicated by the thickness of the lines, denote 95% confidence levels over 20 random seeds.
Pseudocode
Here $\Theta(\cdot)$ is the Heaviside step function ($\Theta(z) = 1$ for $z > 0$, $\Theta(z) = 0$ otherwise), $\mathrm{clip}_a^b(y) \equiv \min(\max(y, a), b)$ clips values between $a$ and $b$, $\sigma(\cdot)$ is the sigmoid function, $\sigma(z) = e^z/(1 + e^z)$, and $\sigma^{-1}(\cdot)$, its inverse, is given by $\sigma^{-1}(y) = \log(y/(1 - y))$.
Acknowledgements
We thank Timothy Lillicrap, Gregory Wayne, and Eszter Vértes for valuable feedback. Michael Häusser is supported by the Wellcome Trust and the European Research Council. Peter Latham is supported by the Gatsby Charitable Foundation and the Wellcome Trust.
References