Discovering cognitive strategies with tiny recurrent neural networks

All data were analysed using Python 3.9 and PyTorch 1.13.

Tasks and datasets

No statistical methods were used to predetermine sample sizes in this study. All datasets were drawn from previously published studies, and we included all available subjects (with enough trials for modelling) in each task. Allocation to experimental groups was not randomized by us; instead, randomization was previously performed by the original authors. Our study does not include any direct behavioural experimentation. Therefore, blinding was not required.

Reversal learning task

The reversal learning task is a paradigm designed to assess subjects’ ability to adapt their behaviour in response to changing reward contingencies. In each trial, subjects are presented with two actions, A1 and A2, yielding a unit reward with probability \({p}_{1}^{{\rm{reward}}}\) and \({p}_{2}^{{\rm{reward}}}\), respectively. These reward probabilities remain constant for several trials before switching unpredictably and abruptly, without explicit cues. When this occurs, the action associated with the higher reward probability becomes linked to the lower reward probability, and vice versa. Maximizing total reward therefore requires continually tracking which action currently has the higher reward probability. For consistency with the other animal tasks, we assume that actions (A1 and A2) are made at the choice state and that action Ai deterministically leads to state Si, where the reward is delivered.

In the Bartolo dataset10,48,49, 2 male monkeys (Rhesus macaque, Macaca mulatta; age 4.5 years) completed a total of 15,500 trials of the reversal learning task with 2 state-reward types: (1) \({p}_{1}^{{\rm{reward}}}=0.7\) and \({p}_{2}^{{\rm{reward}}}=0.3\); (2) \({p}_{1}^{{\rm{reward}}}=0.3\) and \({p}_{2}^{{\rm{reward}}}=0.7\). Blocks were 80 trials long, and the switch happened at a ‘reversal trial’ between trials 30 and 50. We predicted the behaviour from trials 10 to 70, similar to the original preprocessing procedure10, because the monkeys were inferring the current block type (‘what’ block, choosing from two objects; ‘where’ block, choosing from two locations) in the first few trials.

In the Akam dataset11,50, 10 male mice (C57BL6; aged between 2–3 months) completed a total of 67,009 trials of the reversal learning task with 3 state-reward types: (1) \({p}_{1}^{{\rm{reward}}}=0.75\) and \({p}_{2}^{{\rm{reward}}}=0.25\); (2) \({p}_{1}^{{\rm{reward}}}=0.25\) and \({p}_{2}^{{\rm{reward}}}=0.75\); (3) \({p}_{1}^{\text{reward}}=0.5\) and \({p}_{2}^{{\rm{reward}}}=0.5\) (neutral trials). Block transitions from non-neutral blocks were triggered 10 trials after an exponential moving average (tau = 8 trials) crossed a 75% correct threshold. Block transitions from neutral blocks occurred with a probability of 10% on each trial after the 15th of the block to give an average neutral block length of 25 trials.

Two-stage task

The two-stage task is a paradigm commonly used to distinguish between the influences of model-free and model-based RL on animal behaviour51; a reduced version of the task was later introduced in ref. 34. In each trial, subjects are presented with two actions, A1 and A2, while at the choice state. Action A1 leads with a high probability to state S1 and a low probability to state S2, while action A2 leads with a high probability to state S2 and a low probability to state S1. From second-stage states S1 and S2, the animal can execute an action for a chance of receiving a unit reward. Second-stage states are distinguishable by visual cues and have different probabilities of yielding a unit reward: \({p}_{1}^{{\rm{reward}}}\) for S1 and \({p}_{2}^{{\rm{reward}}}\) for S2. These reward probabilities remain constant for several trials before switching unpredictably and abruptly. When this occurs, the second-stage state associated with the higher reward probability becomes linked to the lower reward probability, and vice versa.

In the Miller dataset12,52, 4 adult male Long-Evans rats (Taconic Biosciences; Hilltop Lab Animals) completed a total of 33,957 trials of the two-stage task with 2 state-reward types: (1) \({p}_{1}^{{\rm{reward}}}=0.8\) and \({p}_{2}^{{\rm{reward}}}=0.2\); (2) \({p}_{1}^{{\rm{reward}}}=0.2\) and \({p}_{2}^{{\rm{reward}}}=0.8\). Block switches occurred with a 2% probability on each trial after a minimum block length of 10 trials.

In the Akam dataset11,50, 10 male mice (C57BL6; aged between 2–3 months) completed a total of 133,974 trials of the two-stage task with 3 state-reward types: (1) \({p}_{1}^{{\rm{reward}}}=0.8\) and \({p}_{2}^{{\rm{reward}}}=0.2\); (2) \({p}_{1}^{{\rm{reward}}}=0.2\) and \({p}_{2}^{{\rm{reward}}}=0.8\); (3) \({p}_{1}^{{\rm{reward}}}=0.4\) and \({p}_{2}^{{\rm{reward}}}=0.4\) (neutral trials). Block transitions occurred 20 trials after an exponential moving average (tau = 8 trials) of the subject’s choices crossed a 75% correct threshold. In neutral blocks, block transitions occurred with 10% probability on each trial after the 40th trial. Transitions from non-neutral blocks occurred with equal probability either to another non-neutral block or to the neutral block. Transitions from neutral blocks occurred with equal probability to one of the non-neutral blocks.

Transition-reversal two-stage task

The transition-reversal two-stage task is a modified version of the original two-stage task, with the introduction of occasional reversals in action-state-transition probabilities11. This modification was proposed to facilitate the dissociation of state prediction and reward prediction in neural activity and to prevent habit-like strategies that may produce model-based control-like behaviour without forward planning. In each trial, subjects are presented with two actions, A1 and A2, at the choice state. One action commonly leads to state S1 and rarely to state S2, while the other action commonly leads to state S2 and rarely to state S1. These action-state-transition probabilities remain constant for several trials before switching unpredictably and abruptly, without explicit cues. In the second-stage states S1 and S2, subjects execute an action for a chance of receiving a unit reward. The second-stage states are visually distinguishable and have different reward probabilities that also switch unpredictably and abruptly, without explicit cues, similar to the other two tasks.

In the Akam dataset11,50, 17 male mice (C57BL6; aged between 2–3 months) completed a total of 230,237 trials of the transition-reversal two-stage task with 2 action-state types: (1) Pr(S1|A1) = Pr(S2|A2) = 0.8 and Pr(S2|A1) = Pr(S1|A2) = 0.2; (2) Pr(S1|A1) = Pr(S2|A2) = 0.2 and Pr(S2|A1) = Pr(S1|A2) = 0.8. There were also 3 state-reward types: (1) \({p}_{1}^{{\rm{reward}}}=0.8\) and \({p}_{2}^{{\rm{reward}}}=0.2\); (2) \({p}_{1}^{{\rm{reward}}}=0.2\) and \({p}_{2}^{{\rm{reward}}}=0.8\); (3) \({p}_{1}^{{\rm{reward}}}=0.4\) and \({p}_{2}^{{\rm{reward}}}=0.4\) (neutral trials). Block transitions occurred 20 trials after an exponential moving average (tau = 8 trials) of the subject’s choices crossed a 75% correct threshold. In neutral blocks, block transitions occurred with 10% probability on each trial after the 40th trial. Transitions from non-neutral blocks occurred with equal probability (25%) either to another non-neutral block via a reversal in the reward or transition probabilities, or to one of the two neutral blocks. Transitions from neutral blocks occurred via a change in the reward probabilities only, to one of the non-neutral blocks with the same transition probabilities.

Three-armed reversal learning task

In the Suthaharan dataset53, 1,010 participants (605 participants from the pandemic group and 405 participants from the replication group) completed a three-armed probabilistic reversal learning task. This task was framed as either a non-social (card deck) or social (partner) domain, each lasting 160 trials divided evenly into 4 blocks. Participants were presented with 3 actions (A1, A2 and A3; 3 decks of cards in the non-social domain frame or 3 avatar partners in the social domain frame), each containing different amounts of winning (+100) and losing (−50) points. The objective was to find the best option and earn as many points as possible, knowing that the best option could change.

The task contingencies started with 90%, 50% and 10% reward probabilities, with the best deck/partner switching after 9 out of 10 consecutive rewards. Unknown to the participants, the underlying contingencies transitioned to 80%, 40%, and 20% reward probabilities at the end of the second block, making it more challenging to distinguish between probabilistic noise and genuine changes in the best option.

Four-armed drifting bandit task

The Bahrami dataset54 includes 975 participants who completed the 4-arm bandit task55. Participants were asked to choose between 4 options on 150 trials. On each trial, they chose an option and were given a reward. The rewards for each option drifted over time in a manner known as a restless bandit, forcing the participants to constantly explore the different options to obtain the maximum reward. The rewards followed one of three predefined drift schedules54.

During preprocessing, we removed 57 participants (5.9%) who missed more than 10% of trials. For model fitting, missing trials from the remaining subjects were excluded from the loss calculation.

Original two-stage task

In the Gillan dataset56,57, the original version of the two-stage task51 was used to assess goal-directed (model-based) and habitual (model-free) learning in individuals with diverse psychiatric symptoms. In total, 1,961 participants (548 from the first experiment and 1,413 from the second experiment) completed the task. In each trial, participants were presented with a choice between two options (A1 or A2). Each option commonly (70%) led to a particular second-stage state (A1 → S1 or A2 → S2). However, on 30% of ‘rare’ trials, choices led to the alternative second-stage state (A1 → S2 or A2 → S1). In the second-stage states, subjects chose between two options (B1/B2 in S1 or C1/C2 in S2), each associated with a distinct probability of being rewarded. The reward probabilities associated with each second-stage option drifted slowly and independently over time, remaining within the range of 0.25 to 0.75. To maximize rewards, subjects had to track which second-stage options were currently best as they changed over time.

For model fitting, missing stages or trials from some participants were excluded from the loss calculation.

Recurrent neural networks

Network architectures

We investigated several architectures, as described below. Our primary goal is to capture the maximum possible behavioural variance with d dynamical variables. While we generally prefer more flexible models due to their reduced bias, such models typically require more data for training, and insufficient data can result in underfitting and poorer performance in comparison to less flexible (simpler) models. Therefore, we aimed to balance data efficiency and model capacity through cross-validation.

After finding the best-performing model class, we investigated which network properties contributed the most to the successfully explained variance. Analogous to ablation studies, our approach consisted of gradually removing components or adding constraints to the architectures, such as eliminating nonlinearity or introducing symmetric weight constraints. If predictive performance is unaffected, the examined components are not essential for the successfully explained variance; if performance is affected, these components contribute to explaining additional behavioural patterns. Following this approach, we can establish connections between architectural components and their corresponding underlying behavioural patterns. The primary objective of this approach is to capture maximum variance with minimal components in the models, resulting in highly interpretable models.

Recurrent layer

The neural network models in this paper used the vanilla GRUs in their hidden layers31. The hidden state ht at the beginning of trial t consists of d elements (dynamical variables). The initial hidden state h1 is set to 0 and ht (t > 1) is updated as follows:

$$\begin{array}{l}{r}_{t}=\sigma ({W}_{ir}{x}_{t-1}+{b}_{ir}+{W}_{hr}{h}_{t-1}+{b}_{hr})\\ {z}_{t}=\sigma ({W}_{iz}{x}_{t-1}+{b}_{iz}+{W}_{hz}{h}_{t-1}+{b}_{hz})\\ {n}_{t}=\tanh ({W}_{in}\,{x}_{t-1}+{b}_{in}+{r}_{t}\odot ({W}_{hn}{h}_{t-1}+{b}_{hn}))\\ {h}_{t}=(1-{z}_{t})\odot {n}_{t}+{z}_{t}\odot {h}_{t-1}\end{array}$$

(1)

where σ is the sigmoid function, ⊙ is the Hadamard (element-wise) product, xt − 1 and ht − 1 are the input and hidden state from the last trial t − 1, and rt, zt and nt are the reset, update and new gates (intermediate variables) at trial t, respectively. The weight matrices W and biases b are trainable parameters. The d-dimensional hidden state of the network, ht, represents a summary of past inputs and is the only information used to generate outputs.

Importantly, the use of GRUs means that the set of d-unit activations fully specifies the network’s internal state, rendering the system Markovian (that is, ht is fully determined by ht − 1 and xt − 1). This is in contrast to alternative RNN architectures such as the long short-term memory58, where the use of a cell state renders the system non-Markovian (that is, the output state ht cannot be fully determined by ht − 1 and xt − 1).
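For concreteness, the update in equation (1) can be written as a minimal sketch in PyTorch (the library used in this study); the function and variable names below (for example, gru_step, d, x_dim) are illustrative rather than the exact implementation, and the weights are randomly initialized for demonstration only.

```python
import torch

def gru_step(x_prev, h_prev, W_i, b_i, W_h, b_h):
    """One GRU update (equation (1)); W_i and W_h stack the r, z, n gates row-wise."""
    gi = W_i @ x_prev + b_i           # input contributions, shape (3d,)
    gh = W_h @ h_prev + b_h           # recurrent contributions, shape (3d,)
    i_r, i_z, i_n = gi.chunk(3)
    h_r, h_z, h_n = gh.chunk(3)
    r = torch.sigmoid(i_r + h_r)      # reset gate
    z = torch.sigmoid(i_z + h_z)      # update gate
    n = torch.tanh(i_n + r * h_n)     # new gate
    return (1 - z) * n + z * h_prev   # next hidden state h_t

d, x_dim = 2, 3                        # two dynamical variables; 3-dimensional input
W_i, b_i = torch.randn(3 * d, x_dim), torch.zeros(3 * d)
W_h, b_h = torch.randn(3 * d, d), torch.zeros(3 * d)
h = torch.zeros(d)                     # h_1 = 0
x = torch.tensor([0.0, 1.0, 1.0])      # previous action, second-stage state, reward
h = gru_step(x, h, W_i, b_i, W_h, b_h)
```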

To accommodate discrete inputs, we also introduce a modified architecture called switching GRU, where recurrent weights and biases are input-dependent, similar to discrete-latent-variable-dependent switching linear dynamical systems59. In this architecture, the hidden state ht (t > 1) is updated as follows:

$$\begin{array}{l}{r}_{t}=\sigma ({b}_{ir}^{({x}_{t-1})}+{W}_{hr}^{({x}_{t-1})}{h}_{t-1}+{b}_{hr}^{({x}_{t-1})})\\ {z}_{t}=\sigma ({b}_{iz}^{({x}_{t-1})}+{W}_{hz}^{({x}_{t-1})}{h}_{t-1}+{b}_{hz}^{({x}_{t-1})})\\ {n}_{t}=\tanh ({b}_{in}^{({x}_{t-1})}+{r}_{t}\odot ({W}_{hn}^{({x}_{t-1})}{h}_{t-1}+{b}_{hn}^{({x}_{t-1})}))\\ {h}_{t}=(1-{z}_{t})\odot {n}_{t}+{z}_{t}\odot {h}_{t-1}\end{array}$$

(2)

where \({W}_{h\cdot }^{({x}_{t-1})}\) and \({b}_{\cdot \cdot }^{({x}_{t-1})}\) are the weight matrices and biases selected by the input xt − 1 (that is, each input xt − 1 induces an independent set of weights Wh and biases b).

For discrete inputs, switching GRUs are a generalization of vanilla GRUs (that is, a vanilla GRU can be viewed as a switching GRU whose recurrent weights do not vary with the input). Generalizations of switching GRUs from discrete to continuous inputs are closely related to multiplicative integration GRUs60.
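A minimal sketch of the switching GRU update in equation (2), assuming the discrete input has already been encoded as an integer index (for example, one of the four (action, reward) combinations in the reversal learning task); the class name and initialization are illustrative, not the exact implementation.

```python
import torch

class SwitchingGRUCell(torch.nn.Module):
    """Switching GRU (equation (2)): each discrete input selects its own
    recurrent weights and biases. Illustrative sketch only."""

    def __init__(self, n_inputs, d):
        super().__init__()
        # one (3d x d) recurrent weight matrix and two 3d-dimensional biases per input
        self.W_h = torch.nn.Parameter(0.1 * torch.randn(n_inputs, 3 * d, d))
        self.b_i = torch.nn.Parameter(torch.zeros(n_inputs, 3 * d))
        self.b_h = torch.nn.Parameter(torch.zeros(n_inputs, 3 * d))

    def forward(self, x_idx, h_prev):
        W, b_i, b_h = self.W_h[x_idx], self.b_i[x_idx], self.b_h[x_idx]
        gh = W @ h_prev + b_h
        i_r, i_z, i_n = b_i.chunk(3)
        h_r, h_z, h_n = gh.chunk(3)
        r = torch.sigmoid(i_r + h_r)
        z = torch.sigmoid(i_z + h_z)
        n = torch.tanh(i_n + r * h_n)
        return (1 - z) * n + z * h_prev

cell = SwitchingGRUCell(n_inputs=4, d=1)   # four (action, reward) combinations
h = torch.zeros(1)
h = cell(x_idx=2, h_prev=h)                # input 2 selects its own weight set
```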

For animal datasets, we found that the switching GRU models performed similarly to the vanilla GRU models for d ≥ 2, but consistently outperformed the vanilla GRU models for d = 1. Therefore, for the results of animal datasets in the main text, we reported the performance of the switching GRU models for d = 1 and the performance of the vanilla GRU models for d ≥ 2. Mathematically, these vanilla GRU models can be directly transformed into corresponding switching GRU models:

$$\begin{array}{l}{b}_{i.}^{({x}_{t-1})}\,\leftarrow \,{W}_{i.}\,{x}_{t-1}+{b}_{i.}\\ {b}_{h.}^{({x}_{t-1})}\,\leftarrow \,{b}_{h.}\\ {W}_{h.}^{({x}_{t-1})}\,\leftarrow \,{W}_{h.}\end{array}$$

(3)

We also propose switching linear neural networks (SLIN), where the hidden state ht (t > 1) is updated as follows:

$${h}_{t}={W}^{({x}_{t-1})}{h}_{t-1}+{b}^{({x}_{t-1})}$$

(4)

where \({W}^{({x}_{t-1})}\) and \({b}^{({x}_{t-1})}\) are the weight matrices and biases selected by the input xt − 1. In some variants, we constrained \({W}^{({x}_{t-1})}\) to be symmetric.
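The SLIN update in equation (4) reduces to a single input-indexed linear map; a brief sketch follows, where W_all and b_all are hypothetical parameter tensors of shapes (n_inputs, d, d) and (n_inputs, d).

```python
import torch

def slin_step(x_idx, h_prev, W_all, b_all, symmetric=False):
    """Switching linear update (equation (4)): h_t = W^(x) h_{t-1} + b^(x)."""
    W = W_all[x_idx]
    if symmetric:                      # optional symmetric-weight constraint
        W = 0.5 * (W + W.T)
    return W @ h_prev + b_all[x_idx]
```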

Input layer

The network’s input xt consists of the previous action at − 1, the previous second-stage state st − 1, and the previous reward rt − 1 (but at = st in the reversal learning task). In the vanilla GRU networks, the input xt is three-dimensional and projects with linear weights to the recurrent layer. In the switching GRU networks, the input xt is used as a selector variable where the network’s recurrent weights and biases depend on the network’s inputs. Thus, switching GRUs trained on the reversal learning task have four sets of recurrent weights and biases corresponding to all combinations of at − 1 and rt − 1, and switching GRUs trained on the two-stage and transition-reversal two-stage tasks have eight sets of recurrent weights and biases corresponding to all combinations of at − 1, st − 1 and rt − 1.

Output layer

The network’s output consists of two units whose activities are linear functions of the hidden state ht. A softmax function (a generalization of the logistic function) is used to convert these activities into a probability distribution (a policy). In the first trial, the network’s output is read out from the initial hidden state h1, which has not yet been updated on the basis of any input. For d-unit networks, the network’s output scores were computed either from a fully connected readout layer (that is, \({s}_{t}^{(i)}={\sum }_{j=1}^{d}{\beta }_{i,j}\cdot {h}_{t}^{(j)}\), i = 1, …, d) or from a diagonal readout layer (that is, \({s}_{t}^{(i)}={\beta }_{i}\cdot {h}_{t}^{(i)}\), i = 1, …, d). The output scores are sent to the softmax layer to produce action probabilities.

Network training

Networks were trained using the Adam optimizer (learning rate of 0.005) on batched training data with a cross-entropy loss, an L1-regularization loss on the recurrent weights (coefficient drawn between 10−5 and 10−1, depending on the experiment), and early stopping (triggered if the validation loss did not improve for 200 iteration steps). All networks were implemented with PyTorch.
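The following sketch illustrates this training setup; it is not the exact training code. The helper model.recurrent_parameters() (returning the recurrent weight tensors to be L1-penalized) and the batch dictionaries are assumptions for illustration.

```python
import torch

def train(model, train_batch, val_batch, l1_coef=1e-3, patience=200):
    """Adam (lr 0.005), cross-entropy on next-action prediction, L1 penalty on
    recurrent weights, and early stopping on the validation loss."""
    opt = torch.optim.Adam(model.parameters(), lr=0.005)
    best_val, best_state, wait = float('inf'), None, 0
    while wait < patience:
        opt.zero_grad()
        logits = model(train_batch['inputs'])                 # per-trial action logits
        loss = torch.nn.functional.cross_entropy(logits, train_batch['actions'])
        loss = loss + l1_coef * sum(w.abs().sum() for w in model.recurrent_parameters())
        loss.backward()
        opt.step()
        with torch.no_grad():
            val_loss = torch.nn.functional.cross_entropy(
                model(val_batch['inputs']), val_batch['actions']).item()
        if val_loss < best_val:                               # track the best checkpoint
            best_val, wait = val_loss, 0
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        else:
            wait += 1
    model.load_state_dict(best_state)
    return model
```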

Classical cognitive models

Models for the reversal learning task

In this task, we implemented one model from the Bayesian inference family and eight models from the model-free family (adopted from refs. 12,34 or constructed from RNN phase portraits).

Bayesian inference strategy (d = 1)

This model (also known as latent-state) assumes the existence of the latent-state h, with h = i representing a higher reward probability following action Ai (state Si). The probability \({\Pr }_{t}(h=1)\), as the dynamical variable, is first updated via Bayesian inference:

$${\overline{\Pr }}_{t}(h=1)=\frac{\Pr ({r}_{t-1}| h=1,{s}_{t-1}){\Pr }_{t-1}(h=1)}{\Pr ({r}_{t-1}| h=1,{s}_{t-1}){\Pr }_{t-1}(h=1)+\Pr ({r}_{t-1}| h=2,{s}_{t-1}){\Pr }_{t-1}(h=2)},$$

(5)

where the left-hand side is the posterior probability (we omit the conditions for simplicity). The agent also incorporates the knowledge that, in each trial, the latent-state h can switch (for example, from h = 1 to h = 2) with a small probability pr. Thus the probability \({\Pr }_{t}(h)\) reads,

$${\Pr }_{t}(h=1)=(1-{p}_{r}){\overline{\Pr }}_{t}(h=1)+{p}_{r}(1-{\overline{\Pr }}_{t}(h=1)).$$

(6)

The action probability is then derived from softmax (βPrt(h = 1), βPrt(h = 2)) with inverse temperature β (β ≥ 0).
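A minimal sketch of one trial of this latent-state model (equations (5) and (6)); the reward probabilities and switch rate below are illustrative values, and in practice the free parameters are fitted to each subject.

```python
import numpy as np

def bayes_step(p_h1, s_prev, r_prev, p_hi=0.8, p_lo=0.2, p_switch=0.05):
    """Update Pr_t(h = 1) given the previous state s_prev (1 or 2) and reward r_prev (0 or 1)."""
    lik = []
    for h in (1, 2):
        p_r1 = p_hi if s_prev == h else p_lo      # Pr(reward = 1 | h, s_prev)
        lik.append(p_r1 if r_prev == 1 else 1 - p_r1)
    post = lik[0] * p_h1 / (lik[0] * p_h1 + lik[1] * (1 - p_h1))   # equation (5)
    return (1 - p_switch) * post + p_switch * (1 - post)           # equation (6)

def action_probs(p_h1, beta=5.0):
    """softmax(beta * Pr(h = 1), beta * Pr(h = 2))."""
    logits = beta * np.array([p_h1, 1 - p_h1])
    e = np.exp(logits - logits.max())
    return e / e.sum()
```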

Model-free strategy (d = 1)

This model hypothesizes that the two action values Qt(Ai) are fully anti-correlated (Qt(A1) = −Qt(A2)) as follows:

$$\begin{array}{l}{Q}_{t}({a}_{t-1})\,=\,{Q}_{t-1}({a}_{t-1})+\alpha ({r}_{t-1}-{Q}_{t-1}({a}_{t-1}))\\ {Q}_{t}({\overline{{a}}}_{t-1})\,=\,{Q}_{t-1}({\overline{{a}}}_{t-1})-\alpha ({r}_{t-1}+{Q}_{t-1}({\overline{{a}}}_{t-1})),\end{array}$$

(7)

where \({\overline{{a}}}_{t-1}\) is the unchosen action, and α is the learning rate (0 ≤ α ≤ 1). We specify the Qt(A1) as the dynamical variable.

Model-free strategy (d = 2)

This model hypothesizes that the two action values Qt(Ai), as two dynamical variables, are updated independently:

$${Q}_{t}({a}_{t-1})={Q}_{t-1}({a}_{t-1})+\alpha ({r}_{t-1}-{Q}_{t-1}({a}_{t-1})).$$

(8)

The unchosen action value \({Q}_{t}({\overline{{a}}}_{t-1})\) is unaffected.

Model-free strategy with value forgetting (d = 2)

The chosen action value is updated as in the previous model. The unchosen action value \({Q}_{t}({\overline{{a}}}_{t-1})\), instead, is gradually forgotten:

$${Q}_{t}({\overline{{a}}}_{t-1})=D{Q}_{t-1}({\overline{{a}}}_{t-1}),$$

(9)

where D is the value forgetting rate (0 ≤ D ≤ 1).

Model-free strategy with value forgetting to mean (d = 2)

This model is the ‘forgetful model-free strategy’ proposed in ref. 61. The chosen action value is updated as in the previous model. The unchosen action value \({Q}_{t}({\overline{{a}}}_{t-1})\), instead, is gradually forgotten to an initial value (\(\widetilde{V}=1/2\)):

$${Q}_{t}({\overline{{a}}}_{t-1})=D{Q}_{t-1}({\overline{{a}}}_{t-1})+(1-D)\widetilde{V},$$

(10)

where D is the value forgetting rate (0 ≤ D ≤ 1).

Model-free strategy with the drift-to-the-other rule (d = 2)

This strategy is constructed from the phase diagram of the two-unit RNN. When there is a reward, the chosen action value is updated as follows,

$${Q}_{t}({a}_{t-1})={D}_{1}{Q}_{t-1}({a}_{t-1})+1,$$

(11)

where D1 is the value drifting rate (0 ≤ D1 ≤ 1). The unchosen action value is slightly decreased:

$${Q}_{t}({\overline{{a}}}_{t-1})={Q}_{t-1}({\overline{{a}}}_{t-1})-b,$$

(12)

where b is the decaying bias (0 ≤ b ≤ 1, usually small). When there is no reward, the unchosen action value is unchanged, and the chosen action value drifts to the other:

$${Q}_{t}({a}_{t-1})={Q}_{t-1}({a}_{t-1})+{\alpha }_{0}({Q}_{t-1}({\overline{{a}}}_{t-1})-{Q}_{t-1}({a}_{t-1})),$$

(13)

where α0 is the drifting rate (0 ≤ α0 ≤ 1).

For all model-free RL models with d = 2, the action probability is determined by softmax (βQt(A1), βQt(A2)).
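For reference, a brief sketch of one trial of the two-variable model-free updates (equations (8)-(10)) and the softmax policy; the parameter values are illustrative, not fitted estimates.

```python
import numpy as np

def mf_update(Q, a, r, alpha=0.3, forget=None, D=0.9, V_init=0.5):
    """Chosen-action update (equation (8)); `forget` selects how the unchosen
    value evolves: None (unchanged), 'zero' (equation (9)) or 'mean' (equation (10))."""
    Q = Q.copy()
    other = 1 - a
    Q[a] += alpha * (r - Q[a])
    if forget == 'zero':
        Q[other] = D * Q[other]
    elif forget == 'mean':
        Q[other] = D * Q[other] + (1 - D) * V_init
    return Q

def policy(Q, beta=3.0):
    """softmax(beta * Q_t(A1), beta * Q_t(A2))."""
    e = np.exp(beta * (Q - Q.max()))
    return e / e.sum()

Q = mf_update(np.zeros(2), a=0, r=1, forget='mean')
print(policy(Q))
```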

Model-free strategy with inertia (d = 2)

The action values are updated as the model-free strategy (d = 1). The action perseveration (inertia) is updated by:

$$\begin{array}{l}{X}_{t}({a}_{t-1})\,=\,{X}_{t-1}({a}_{t-1})+{\alpha }_{{\rm{pers}}}({k}_{{\rm{pers}}}-{X}_{t-1}({a}_{t-1}))\\ {X}_{t}({\overline{{a}}}_{t-1})\,=\,{X}_{t-1}({\overline{{a}}}_{t-1})-{\alpha }_{{\rm{pers}}}({k}_{{\rm{pers}}}+{X}_{t-1}({\overline{{a}}}_{t-1}))\end{array}$$

(14)

where αpers is the perseveration learning rate (0 ≤ αpers ≤ 1), and kpers is the single-trial perseveration term, affecting the balance between action values and action perseverations.

Model-free strategy with inertia (d = 3)

The action values are updated as the model-free strategy (d = 2). The action perseveration (inertia) is updated by the same rule in the model-free strategy with inertia (d = 2).

The action probabilities in all model-free models with inertia are generated via \({\rm{softmax}}\,({\{\beta ({Q}_{t}({A}_{i})+{X}_{t}({A}_{i}))\}}_{i})\). Both the action values and action perseverations are dynamical variables.

Model-free reward-as-cue strategy (d = 8)

This model assumes that the animal considers the combination of the second-stage state st − 1 and the reward rt − 1 from the trial t − 1 as the augmented state \({{\mathcal{S}}}_{t}\) for trial t. The eight dynamical variables are the values for the two actions at the four augmented states. The action values are updated as follows:

$${Q}_{t}({{\mathcal{S}}}_{t-1},{a}_{t-1})={Q}_{t-1}({{\mathcal{S}}}_{t-1},{a}_{t-1})+\alpha ({r}_{t-1}-{Q}_{t-1}({{\mathcal{S}}}_{t-1},{a}_{t-1})).$$

(15)

The action probability at trial t is determined by \({\rm{softmax}}\,(\beta {Q}_{t}({{\mathcal{S}}}_{t},{A}_{1}),\beta {Q}_{t}({{\mathcal{S}}}_{t},{A}_{2}))\).

Models for the two-stage task

We implemented one model from the Bayesian inference family, four models from the model-free family, and four from the model-based family (adopted from refs. 12,34).

Bayesian inference strategy (d = 1)

Same as Bayesian inference strategy (d = 1) in the reversal learning task, except that h = i represents a higher reward probability following state Si (not action Ai).

Model-free strategy (d = 1)

Same as the model-free strategy (d = 1) in the reversal learning task by ignoring the second-stage states st − 1.

Model-free Q(1) strategy (d = 2)

Same as the model-free strategy (d = 2) in the reversal learning task by ignoring the second-stage states st − 1.

Model-free Q(0) strategy (d = 4)

This model first updates the first-stage action values Qt(at − 1) with the second-stage state values Vt − 1(st − 1):

$${Q}_{t}({a}_{t-1})={Q}_{t-1}({a}_{t-1})+\alpha ({V}_{t-1}({s}_{t-1})-{Q}_{t-1}({a}_{t-1})),$$

(16)

while the unchosen action value \({Q}_{t}({\overline{{a}}}_{t-1})\) is unaffected. Then the second-stage state value Vt(st − 1) is updated by the observed reward:

$${V}_{t}({s}_{t-1})={V}_{t-1}({s}_{t-1})+\alpha ({r}_{t-1}-{V}_{t-1}({s}_{t-1})).$$

(17)

The four dynamical variables are the two action values and two state values.

Model-free reward-as-cue strategy (d = 8)

Same as model-free reward-as-cue strategy (d = 8) in the reversal learning task.

Model-based strategy (d = 1)

In this model, the two state values Vt(Si) are fully anti-correlated (Vt(S1) = −Vt(S2)):

$$\begin{array}{l}{V}_{t}({s}_{t-1})\,=\,{V}_{t-1}({s}_{t-1})+\alpha ({r}_{t-1}-{V}_{t-1}({s}_{t-1}))\\ {V}_{t}({\overline{s}}_{t-1})\,=\,{V}_{t-1}({\overline{s}}_{t-1})-\alpha ({r}_{t-1}+{V}_{t-1}({\overline{s}}_{t-1})),\end{array}$$

(18)

where \({\bar{s}}_{t-1}\) is the unvisited state. The dynamical variable is the state value Vt(S1).

Model-based strategy (d = 2)

The visited state value is updated:

$${V}_{t}({s}_{t-1})={V}_{t-1}({s}_{t-1})+\alpha ({r}_{t-1}-{V}_{t-1}({s}_{t-1})).$$

(19)

The unvisited state value is unchanged. The two dynamical variables are the two state values.

Model-based strategy with value forgetting (d = 2)

The visited state value is updated as in the previous model. The unvisited state value is gradually forgotten:

$${V}_{t}({\bar{s}}_{t-1})=D{V}_{t-1}({\bar{s}}_{t-1}),$$

(20)

where D is the value forgetting rate (0 ≤ D ≤ 1).

For all model-based RL models, the action values at the first stage are directly computed using the state-transition model:

$${Q}_{t}^{{\rm{m}}{\rm{b}}}({A}_{i})=\sum _{j}Pr({S}_{j}|{A}_{i}){V}_{t}({S}_{j}),$$

(21)

where Pr(Sj|Ai) is known. The action probability is determined by \(\text{softmax}\,(\beta {Q}_{t}^{{\rm{m}}{\rm{b}}}({A}_{1}),\beta {Q}_{t}^{{\rm{m}}{\rm{b}}}({A}_{2}))\).
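A short sketch of one trial of these model-based updates (equations (19)-(21)); the transition matrix and parameter values are illustrative.

```python
import numpy as np

def mb_update(V, s, r, alpha=0.3, forget_D=None):
    """Visited-state update (equation (19)); optional forgetting of the
    unvisited state (equation (20))."""
    V = V.copy()
    V[s] += alpha * (r - V[s])
    if forget_D is not None:
        V[1 - s] *= forget_D
    return V

def mb_action_values(V, P_trans):
    """Equation (21): Q_mb(A_i) = sum_j Pr(S_j | A_i) V(S_j)."""
    return P_trans @ V

P_trans = np.array([[0.8, 0.2],    # Pr(S1|A1), Pr(S2|A1)
                    [0.2, 0.8]])   # Pr(S1|A2), Pr(S2|A2)
V = mb_update(np.zeros(2), s=0, r=1)
Q_mb = mb_action_values(V, P_trans)
```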

Model-based mixture strategy (d = 2)

This model is a mixture of the model-free strategy (d = 1) and the model-based strategy (d = 1). The net action values are determined by:

$${Q}_{t}^{{\rm{n}}{\rm{e}}{\rm{t}}}({A}_{i})=(1-w){Q}_{t}^{{\rm{m}}{\rm{f}}}({A}_{i})+w{Q}_{t}^{{\rm{m}}{\rm{b}}}({A}_{i}),$$

(22)

where w controls the strength of the model-based component. The action probabilities are generated via \(\text{softmax}\,(\beta {Q}_{t}^{{\rm{net}}}({A}_{1}),\beta {Q}_{t}^{{\rm{net}}}({A}_{2}))\). \({Q}_{t}^{{\rm{mf}}}({A}_{1})\) and Vt(S1) are the dynamical variables.

Models for the transition-reversal two-stage task

For this task, we further include cognitive models proposed in ref. 11. We first describe different model components (ingredients) and corresponding numbers of dynamical variables, and then specify the components employed in each model.

Second-stage state value component

The visited state value is updated:

$${V}_{t}({s}_{t-1})={V}_{t-1}({s}_{t-1})+{\alpha }_{Q}({r}_{t-1}-{V}_{t-1}({s}_{t-1})).$$

(23)

The unvisited state value \({V}_{t}({\bar{s}}_{t-1})\) is either unchanged or gradually forgotten with fQ as the value forgetting rate. This component requires two dynamical variables.

Model-free action value component

The first-stage action values \({Q}_{t}^{{\rm{mf}}}({a}_{t-1})\) are updated by the second-stage state values Vt − 1(st − 1) and the observed reward:

$${Q}_{t}^{{\rm{mf}}}({a}_{t-1})={Q}_{t-1}^{{\rm{mf}}}({a}_{t-1})+\alpha (\lambda {r}_{t-1}+(1-\lambda ){V}_{t-1}({s}_{t-1})-{Q}_{t-1}^{{\rm{mf}}}({a}_{t-1})),$$

(24)

where λ is the eligibility trace. The unchosen action value \({Q}_{t}^{{\rm{mf}}}({\overline{{a}}}_{t-1})\) is unaffected or gradually forgotten with fQ as the value forgetting rate. This component requires two dynamical variables.

Model-based component

The action-state-transition probabilities are updated as:

$$\begin{array}{r}{P}_{t}({s}_{t-1}| {a}_{t-1})={P}_{t-1}({s}_{t-1}| {a}_{t-1})+{\alpha }_{T}(1-{P}_{t-1}({s}_{t-1}| {a}_{t-1}))\\ {P}_{t}({\overline{s}}_{t-1}| {a}_{t-1})={P}_{t-1}({\overline{s}}_{t-1}| {a}_{t-1})+{\alpha }_{T}(0-{P}_{t-1}({\overline{s}}_{t-1}| {a}_{t-1})),\end{array}$$

(25)

where αT is the transition probability learning rate. For the unchosen action, the action-state-transition probabilities are either unchanged or forgotten:

$$\begin{array}{l}{P}_{t}({s}_{t-1}| {\overline{{a}}}_{t-1})={P}_{t-1}({s}_{t-1}| {\overline{{a}}}_{t-1})+{f}_{T}(0.5-{P}_{t-1}({s}_{t-1}| {\overline{{a}}}_{t-1}))\\ {P}_{t}({\overline{s}}_{t-1}| {\overline{{a}}}_{t-1})={P}_{t-1}({\overline{s}}_{t-1}| {\overline{{a}}}_{t-1})+{f}_{T}(0.5-{P}_{t-1}({\overline{s}}_{t-1}| {\overline{{a}}}_{t-1})),\end{array}$$

(26)

where fT is the transition probability forgetting rate.

The model-based action values at the first stage are directly computed using the learned state-transition model:

$${Q}_{t}^{{\rm{mb}}}({A}_{i})=\sum _{j}{P}_{t}({S}_{j}| {A}_{i}){V}_{t}({S}_{j}).$$

(27)

This component requires two dynamical variables (Pt(S1|A1) and Pt(S1|A2)), since other variables can be directly inferred.
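A brief sketch of this component (equations (25)-(27)), keeping only Pr(S1|A1) and Pr(S1|A2) as dynamical variables and using 0-based indices (state 0 for S1, state 1 for S2); the parameter values are illustrative.

```python
import numpy as np

def transition_update(P_s1, a, s, alpha_T=0.3, f_T=0.1):
    """P_s1[i] = Pr(S1 | A_i). Chosen-action update (equation (25)) and
    forgetting of the unchosen action towards 0.5 (equation (26))."""
    P_s1 = P_s1.copy()
    target = 1.0 if s == 0 else 0.0            # did the chosen action lead to S1?
    P_s1[a] += alpha_T * (target - P_s1[a])
    P_s1[1 - a] += f_T * (0.5 - P_s1[1 - a])
    return P_s1

def mb_values(P_s1, V):
    """Equation (27): Q_mb(A_i) = Pr(S1|A_i) V(S1) + Pr(S2|A_i) V(S2)."""
    return P_s1 * V[0] + (1 - P_s1) * V[1]
```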

Motor-level model-free action component

Due to the apparatus design in this task11, it is proposed that the mice consider the motor-level actions \({a}_{t-1}^{{\rm{m}}{\rm{o}}}\), defined as the combination of the last-trial action at − 1 and the second-stage state st−2 before it. The motor-level action values \({Q}_{t}^{{\rm{m}}{\rm{o}}}({a}_{t-1}^{{\rm{m}}{\rm{o}}})\) are updated as:

$${Q}_{t}^{{\rm{m}}{\rm{o}}}({a}_{t-1}^{{\rm{m}}{\rm{o}}})={Q}_{t-1}^{{\rm{m}}{\rm{o}}}({a}_{t-1}^{{\rm{m}}{\rm{o}}})+\alpha (\lambda {r}_{t-1}+(1-\lambda ){V}_{t-1}({s}_{t-2})-{Q}_{t-1}^{{\rm{m}}{\rm{o}}}({a}_{t-1}^{{\rm{m}}{\rm{o}}})),$$

(28)

where λ is the eligibility trace. The unchosen motor-level action value \({Q}_{t}^{{\rm{m}}{\rm{o}}}\) is unaffected or gradually forgotten with fQ as the value forgetting rate. This component requires four dynamical variables (four motor-level actions).

Choice perseveration component

The single-trial perseveration \({\widetilde{X}}_{t-1}^{{\rm{cp}}}\) is set to −0.5 for at − 1 = A1 and 0.5 for at − 1 = A2. The multi-trial perseveration \({X}_{t-1}^{{\rm{cp}}}\) (exponential moving average of choices) is updated as:

$${X}_{t}^{{\rm{cp}}}={X}_{t-1}^{{\rm{cp}}}+{\alpha }_{c}({\widetilde{X}}_{t-1}^{{\rm{cp}}}-{X}_{t-1}^{{\rm{cp}}}),$$

(29)

where αc is the choice perseveration learning rate. In some models, αc is less than 1, so one dynamical variable is required; in other models, αc is fixed to 1, which reduces this component to the single-trial perseveration and requires no dynamical variable.

Motor-level choice perseveration component

The multi-trial motor-level perseveration \({X}_{t-1}^{{\rm{mocp}}}({s}_{t-2})\) is updated as:

$${X}_{t}^{{\rm{mocp}}}({s}_{t-2})={X}_{t-1}^{{\rm{mocp}}}({s}_{t-2})+{\alpha }_{m}({\widetilde{X}}_{t-1}^{{\rm{cp}}}-{X}_{t-1}^{{\rm{mocp}}}({s}_{t-2})),$$

(30)

where αm is the motor-level choice perseveration learning rate. This component requires two dynamical variables.

Action selection component

The net action values are computed as follows:

$${Q}_{t}^{{\rm{net}}}({A}_{i})={G}^{{\rm{mf}}}{Q}_{t}^{{\rm{mf}}}({A}_{i})+{G}^{{\rm{mo}}}{Q}_{t}^{{\rm{mo}}}({A}_{i},{s}_{t-1})+{G}^{{\rm{mb}}}{Q}_{t}^{{\rm{mb}}}({A}_{i})+{X}_{t}({A}_{i}),$$

(31)

where Gmf, Gmo and Gmb are model-free, motor-level model-free and model-based inverse temperatures, respectively, and Xt(Ai) is:

$$\begin{array}{l}{X}_{t}({A}_{1})=0\\ {X}_{t}({A}_{2})={B}_{c}+{B}_{r}{\widetilde{X}}_{t-1}^{{\rm{s}}}+{P}_{c}{X}_{t}^{{\rm{cp}}}+{P}_{m}{X}_{t}^{{\rm{mocp}}}({s}_{t-1}),\end{array}$$

(32)

where Bc (bias), Br (rotation bias), Pc, Pm are weights controlling each component, and \({\widetilde{X}}_{t-1}^{s}\) is −0.5 for st − 1 = S1 and 0.5 for st − 1 = S2.

The action probabilities are generated via \(\text{softmax}\,({Q}_{t}^{{\rm{net}}}({A}_{1}),\) \({Q}_{t}^{{\rm{net}}}({A}_{2}))\).

Model-free strategies

We include five model-free RL models:

  1. the model-free strategy (d = 1) same as the two-stage task;

  2. the model-free Q(1) strategy (d = 2) same as the two-stage task;

  3. state value [2] + model-free action value [2] + bias [0] + rotation bias [0] + single-trial choice perseveration [0];

  4. state value [2] + model-free action value with forgetting [2] + bias [0] + rotation bias [0] + single-trial choice perseveration [0];

  5. state value [2] + model-free action value with forgetting [2] + motor-level model-free action value with forgetting [4] + bias [0] + rotation bias [0] + multi-trial choice perseveration [1] + multi-trial motor-level choice perseveration [2].

Here, we use the format of ‘model component [required number of dynamical variables]’ (more details in ref. 11).

Model-based strategies

We include 12 model-based RL models:

  1. state value [2] + model-based [2] + bias [0] + rotation bias [0] + single-trial choice perseveration [0];

  2. state value [2] + model-free action value [2] + model-based [2] + bias [0] + rotation bias [0] + single-trial choice perseveration [0];

  3. state value [2] + model-based with forgetting [2] + bias [0] + rotation bias [0] + single-trial choice perseveration [0];

  4. state value [2] + model-free action value with forgetting [2] + model-based with forgetting [2] + bias [0] + rotation bias [0] + single-trial choice perseveration [0];

  5. state value [2] + model-free action value with forgetting [2] + model-based [2] + bias [0] + rotation bias [0] + single-trial choice perseveration [0];

  6. state value [2] + model-free action value [2] + model-based [2] + bias [0] + rotation bias [0] + multi-trial choice perseveration [1];

  7. state value [2] + model-free action value with forgetting [2] + model-based with forgetting [2] + bias [0] + rotation bias [0] + multi-trial choice perseveration [1];

  8. state value [2] + model-free action value with forgetting [2] + model-based [2] + bias [0] + rotation bias [0] + multi-trial choice perseveration [1];

  9. state value [2] + model-free action value with forgetting [2] + model-based with forgetting [2] + bias [0] + rotation bias [0] + multi-trial motor-level choice perseveration [2];

  10. state value [2] + model-based with forgetting [2] + bias [0] + rotation bias [0] + multi-trial choice perseveration [1] + multi-trial motor-level choice perseveration [2];

  11. state value [2] + model-free action value with forgetting [2] + model-based with forgetting [2] + bias [0] + rotation bias [0] + multi-trial choice perseveration [1] + multi-trial motor-level choice perseveration [2];

  12. state value [2] + model-free action value with forgetting [2] + model-based with forgetting [2] + motor-level model-free action value with forgetting [4] + bias [0] + rotation bias [0] + multi-trial choice perseveration [1] + multi-trial motor-level choice perseveration [2].

Here, we use the format of ‘model component [required number of dynamical variables]’ (more details in ref. 11).

Models for the three-armed reversal learning task

We implemented four models (n = 3 actions) from the model-free family, one of which is constructed from the strategies discovered by the RNN.

Model-free strategy (d = n)

This model hypothesizes that each action value Qt(Ai), as a dynamical variable, is updated independently. The chosen action value is updated by:

$${Q}_{t}({a}_{t-1})={Q}_{t-1}({a}_{t-1})+\alpha ({r}_{t-1}-{Q}_{t-1}({a}_{t-1})).$$

(33)

The unchosen action values Qt(Aj) (Aj ≠ at − 1) are unaffected.

Model-free strategy with value forgetting (d = n)

The chosen action value is updated as in the previous model. The unchosen action value Qt(Aj) (Aj ≠ at − 1), instead, is gradually forgotten:

$${Q}_{t}({A}_{j})=D{Q}_{t-1}({A}_{j}),$$

(34)

where D is the value forgetting rate (0 ≤ D ≤ 1).

Model-free strategy with value forgetting and action perseveration (d = 2n)

The action values are updated as the model-free strategy with value forgetting. The chosen action perseveration is updated by:

$${X}_{t}({a}_{t-1})={D}_{{\rm{pers}}}{X}_{t-1}({a}_{t-1})+{k}_{{\rm{pers}}},$$

(35)

and the unchosen action perseverations are updated by:

$${X}_{t}({A}_{j})={D}_{{\rm{pers}}}{X}_{t-1}({A}_{j}),$$

(36)

where Dpers is the perseveration forgetting rate (0 ≤ Dpers ≤ 1), and kpers is the single-trial perseveration term, affecting the balance between action values and action perseverations.

Model-free strategy with unchosen value updating and reward utility (d = n)

This model is constructed from the strategy discovered by the RNN (see Supplementary Results 1.4). It assumes that the reward utility U(r) (equivalent to the preference setpoint) is different in four cases (corresponding to four free parameters): no reward for the chosen action (Uc(0)), one reward for the chosen action (Uc(1)), no reward for the unchosen action (Uu(0)), and one reward for the unchosen action (Uu(1)).

The chosen action value is updated by:

$${Q}_{t}({a}_{t-1})={Q}_{t-1}({a}_{t-1})+{\alpha }_{c}({U}_{c}({r}_{t-1})-{Q}_{t-1}({a}_{t-1})).$$

(37)

The unchosen action value Qt(Aj) (Aj ≠ at − 1) is updated by:

$${Q}_{t}({A}_{j})={Q}_{t-1}({A}_{j})+{\alpha }_{u}({U}_{u}({r}_{t-1})-{Q}_{t-1}({A}_{j})).$$

(38)

The action probabilities for these models are generated via \({\rm{softmax}}\,({\{\beta ({Q}_{t}({A}_{i})+{X}_{t}({A}_{i}))\}}_{i})\) (Xt = 0 for models without action perseverations). Both the action values and action perseverations are dynamical variables.
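The unchosen-value-updating rule (equations (37) and (38)) can be summarized in a few lines; the utility and learning-rate values below are illustrative placeholders rather than fitted parameters.

```python
import numpy as np

def utility_update(Q, a, r, alpha_c=0.4, alpha_u=0.1,
                   U_c=(0.0, 1.0), U_u=(0.2, -0.2)):
    """U_c[r] and U_u[r] are the reward utilities for the chosen and unchosen
    actions; r is 0 (no reward) or 1 (reward)."""
    Q = Q.copy()
    for j in range(len(Q)):
        if j == a:
            Q[j] += alpha_c * (U_c[r] - Q[j])   # chosen action, equation (37)
        else:
            Q[j] += alpha_u * (U_u[r] - Q[j])   # unchosen actions, equation (38)
    return Q
```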

Models for the four-armed drifting bandit task

We implemented five models (n = 4 actions) from the model-free family, two of which are constructed from the strategies discovered by the RNN.

Model-free strategy (d = n)

This model is the same as the model-free strategy in the three-armed reversal learning task.

Model-free strategy with value forgetting (d = n)

This model is the same as the model-free strategy with value forgetting in the three-armed reversal learning task.

Model-free strategy with value forgetting and action perseveration (d = 2n)

This model is the same as the model-free strategy with value forgetting and action perseveration in the three-armed reversal learning task.

Model-free strategy with unchosen value updating and reward reference point (d = n)

This model is constructed from the strategy discovered by the RNN (see Supplementary Results 1.5). It assumes that the reward utility U(r) is different for chosen action (Uc(r) = βc(r − Rc)) and for unchosen action (Uu(r) = βu(r − Ru)), where βc and βu are reward sensitivities, and Rc and Ru are reward reference points.

The chosen action value is updated by:

$${Q}_{t}({a}_{t-1})=(1-{\alpha }_{c}){Q}_{t-1}({a}_{t-1})+{U}_{c}({r}_{t-1}),$$

(39)

where 1 − αc is the decay rate for chosen actions. The unchosen action value Qt(Aj) (Aj ≠ at − 1) is updated by:

$${Q}_{t}({A}_{j})=(1-{\alpha }_{u}){Q}_{t-1}({A}_{j})+{U}_{u}({r}_{t-1}),$$

(40)

where 1 − αu is the decay rate for unchosen actions. We additionally fit a reduced model of this strategy where βc = αc and βu = αu (similarly inspired by the RNN’s solution).

The action probabilities for these models are generated via \({\rm{softmax}}\,({\{\beta ({Q}_{t}({A}_{i})+{X}_{t}({A}_{i}))\}}_{i})\) (Xt = 0 for models without action perseverations). Both the action values and action perseverations are dynamical variables.

Models for the original two-stage task

Model-free strategy (d = 3)

This model hypothesizes that the action values for each task state (first-stage state S0, second-stage states S1 and S2) are fully anti-correlated (\({Q}_{t}^{{S}_{0}}({A}_{1})=-{Q}_{t}^{{S}_{0}}({A}_{2})\), \({Q}_{t}^{{S}_{1}}({B}_{1})=-{Q}_{t}^{{S}_{1}}({B}_{2})\), \({Q}_{t}^{{S}_{2}}({C}_{1})=-{Q}_{t}^{{S}_{2}}({C}_{2})\)).

The action values at the chosen second-stage state (for example, assuming B1 or B2 at S1 is chosen) are updated by:

$$\begin{array}{l}{Q}_{t}^{{S}_{1}}({a}_{t-1}^{{S}_{1}})={Q}_{t-1}^{{S}_{1}}({a}_{t-1}^{{S}_{1}})+{\alpha }_{2}({r}_{t-1}-{Q}_{t-1}^{{S}_{1}}({a}_{t-1}^{{S}_{1}}))\\ {Q}_{t}^{{S}_{1}}({\overline{{a}}}_{t-1}^{{S}_{1}})={Q}_{t-1}^{{S}_{1}}({\overline{{a}}}_{t-1}^{{S}_{1}})-{\alpha }_{2}({r}_{t-1}+{Q}_{t-1}^{{S}_{1}}({\overline{{a}}}_{t-1}^{{S}_{1}})),\end{array}$$

(41)

where \({\overline{{a}}}_{t-1}^{{S}_{1}}\) is the unchosen second-stage action at the chosen second-stage state, and α2 is the learning rate for the second-stage states (0 ≤ α2 ≤ 1). The second-stage action probabilities are generated via softmax \(({\beta }_{2}{Q}_{t}^{{S}_{1}}({B}_{1}),{\beta }_{2}{Q}_{t}^{{S}_{1}}({B}_{2}))\).

The action values at the first-stage state (A1 or A2 at S0) are updated by:

$$\begin{array}{c}{Q}_{t}^{{S}_{0},{\rm{m}}{\rm{f}}}({a}_{t-1}^{{S}_{0}})={Q}_{t-1}^{{S}_{0},{\rm{m}}{\rm{f}}}({a}_{t-1}^{{S}_{0}})+{\alpha }_{1}(\lambda {r}_{t-1}+(1-\lambda ){Q}_{t}^{{S}_{1}}({a}_{t-1}^{{S}_{1}})\\ \,\,\,\,\,\,-\,{Q}_{t-1}^{{S}_{0},{\rm{m}}{\rm{f}}}({a}_{t-1}^{{S}_{0}}))\\ {Q}_{t}^{{S}_{0},{\rm{m}}{\rm{f}}}({\overline{{a}}}_{t-1}^{{S}_{0}})={Q}_{t-1}^{{S}_{0},{\rm{m}}{\rm{f}}}({\overline{{a}}}_{t-1}^{{S}_{0}})-{\alpha }_{1}(\lambda {r}_{t-1}+(1-\lambda ){Q}_{t}^{{S}_{1}}({a}_{t-1}^{{S}_{1}})\\ \,\,\,\,\,\,+\,{Q}_{t-1}^{{S}_{0},{\rm{m}}{\rm{f}}}({\overline{{a}}}_{t-1}^{{S}_{0}})),\end{array}$$

(42)

where \({\overline{{a}}}_{t-1}^{{S}_{0}}\) is the unchosen first-stage action, α1 is the learning rate for the first-stage state (0 ≤ α1 ≤ 1), and λ specifies the TD(λ) learning rule. The first-stage action probabilities are generated via softmax \(({\beta }_{1}{Q}_{t}^{{S}_{0},{\rm{m}}{\rm{f}}}({A}_{1}),{\beta }_{1}{Q}_{t}^{{S}_{0},{\rm{m}}{\rm{f}}}({A}_{2}))\).

Here \({Q}_{t}^{{S}_{0},{\rm{m}}{\rm{f}}}({A}_{1})\), \({Q}_{t}^{{S}_{1}}({B}_{1})\), and \({Q}_{t}^{{S}_{2}}({C}_{1})\) are the dynamical variables.

Model-based strategy (d = 2)

The update of action values at the chosen second-stage state is the same as the model-free strategy. The action values at the first-stage state (A1 or A2 at S0) are determined by:

$${Q}_{t}^{{S}_{0},{\rm{m}}{\rm{b}}}({A}_{i})=Pr[{S}_{1}|{A}_{i}]\mathop{\text{max}}\limits_{{B}_{j}}\,{Q}_{t}^{{S}_{1}}({B}_{j})+Pr[{S}_{2}|{A}_{i}]\mathop{\text{max}}\limits_{{C}_{j}}\,{Q}_{t}^{{S}_{2}}({C}_{j}).$$

(43)

The first-stage action probabilities are generated via \({\rm{s}}{\rm{o}}{\rm{f}}{\rm{t}}{\rm{m}}{\rm{a}}{\rm{x}}\,({\beta }_{1}{Q}_{t}^{{S}_{0},{\rm{m}}{\rm{b}}}({A}_{1}),{\beta }_{1}{Q}_{t}^{{S}_{0},{\rm{m}}{\rm{b}}}({A}_{2}))\).

Only \({Q}_{t}^{{S}_{1}}({B}_{1})\) and \({Q}_{t}^{{S}_{2}}({C}_{1})\) are the dynamical variables.
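A compact sketch of the first-stage computation in equation (43), assuming the known common-transition probability of 0.7 from the task description:

```python
import numpy as np

def mb_first_stage_values(Q_S1, Q_S2, p_common=0.7):
    """Q_S1 and Q_S2 are length-2 arrays of second-stage action values;
    returns the model-based values of A1 and A2 (equation (43))."""
    best = np.array([Q_S1.max(), Q_S2.max()])   # best option in each second-stage state
    P = np.array([[p_common, 1 - p_common],     # Pr(S1|A1), Pr(S2|A1)
                  [1 - p_common, p_common]])    # Pr(S1|A2), Pr(S2|A2)
    return P @ best
```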

Model-based mixture strategy (d = 3)

This model considers the mixture of model-free and model-based strategies for the first-stage states. The net action values are determined by:

$${Q}_{t}^{{S}_{0},{\rm{n}}{\rm{e}}{\rm{t}}}({A}_{i})=(1-w){Q}_{t}^{{S}_{0},{\rm{m}}{\rm{f}}}({A}_{i})+w{Q}_{t}^{{S}_{0},{\rm{m}}{\rm{b}}}({A}_{i}),$$

(44)

where w controls the strength of the model-based component. The first-stage action probabilities are generated via \(\text{softmax}\,({\beta }_{1}{Q}_{t}^{{S}_{0},{\rm{n}}{\rm{e}}{\rm{t}}}({A}_{1}),{\beta }_{1}{Q}_{t}^{{S}_{0},{\rm{n}}{\rm{e}}{\rm{t}}}({A}_{2}))\). \({Q}_{t}^{{S}_{0},{\rm{m}}{\rm{f}}}({A}_{1})\), \({Q}_{t}^{{S}_{1}}({B}_{1})\) and \({Q}_{t}^{{S}_{2}}({C}_{1})\) are the dynamical variables.

Model-free strategy (d = 6)

Compared to the model-free strategy (d = 3), only the chosen action values at S0, S1, and S2 are updated. The unchosen values are unchanged. \({Q}_{t}^{{S}_{0},{\rm{m}}{\rm{f}}}({A}_{1})\), \({Q}_{t}^{{S}_{0},{\rm{m}}{\rm{f}}}({A}_{2})\), \({Q}_{t}^{{S}_{1}}({B}_{1})\), \({Q}_{t}^{{S}_{1}}({B}_{2})\), \({Q}_{t}^{{S}_{2}}({C}_{1})\) and \({Q}_{t}^{{S}_{2}}({C}_{2})\) are the dynamical variables.

Model-based strategy (d = 4)

Compared to the model-based strategy (d = 2), only the chosen action values at S1, and S2 are updated. The unchosen values are unchanged. \({Q}_{t}^{{S}_{1}}({B}_{1})\), \({Q}_{t}^{{S}_{1}}({B}_{2})\), \({Q}_{t}^{{S}_{2}}({C}_{1})\) and \({Q}_{t}^{{S}_{2}}({C}_{2})\) are the dynamical variables.

Model-based mixture strategy (d = 6)

Compared to the model-based mixture strategy (d = 3), only the chosen action values at S0, S1 and S2 are updated. The unchosen values are unchanged. \({Q}_{t}^{{S}_{0},{\rm{m}}{\rm{f}}}({A}_{1})\), \({Q}_{t}^{{S}_{0},{\rm{m}}{\rm{f}}}({A}_{2})\), \({Q}_{t}^{{S}_{1}}({B}_{1})\), \({Q}_{t}^{{S}_{1}}({B}_{2})\), \({Q}_{t}^{{S}_{2}}({C}_{1})\) and \({Q}_{t}^{{S}_{2}}({C}_{2})\) are the dynamical variables.

Model-free strategy with reward utility (d = 3)

This model is constructed from the RNN’s strategy. Similar to the model-free strategy (d = 3), it hypothesizes that the action values for each task state (first-stage state S0, second-stage states S1 and S2) are fully anti-correlated (\({Q}_{t}^{{S}_{0}}({A}_{1})=-{Q}_{t}^{{S}_{0}}({A}_{2})\), \({Q}_{t}^{{S}_{1}}({B}_{1})=-{Q}_{t}^{{S}_{1}}({B}_{2})\), \({Q}_{t}^{{S}_{2}}({C}_{1})=-{Q}_{t}^{{S}_{2}}({C}_{2})\)).

It assumes that when receiving one reward, the reward utility (that is, equivalently, the preference setpoint) for the chosen action at the first-stage state S0 is \({U}^{{S}_{0}}(1)=1\), for the chosen action at the chosen second-stage state S1 (or S2) is \({U}^{{S}_{1}}(1)=1\), and for the (motor-level) chosen action at the unchosen second-stage state S2 (or S1) is \({U}^{{S}_{2}}(1)={U}_{{\rm{other}}}\) (for example, B1 at the chosen S1 and C1 at unchosen S2 are the same motor-level action). When receiving no reward, the reward utility for the chosen action at the first-stage state S0 is \({U}^{{S}_{0}}(0)={U}_{1{\rm{st}},{\rm{zero}}}\), for the chosen action at the chosen second-stage state (assuming S1) is \({U}^{{S}_{1}}(0)={U}_{2{\rm{nd}},{\rm{zero}}}\), and for the (motor-level) chosen action at the unchosen second-stage state (assuming S2) is \({U}^{{S}_{2}}(0)=-{U}_{{\rm{other}}}\). The chosen action values at the chosen second-stage state (for example, assuming B1 or B2 at S1) are updated by:

$${Q}_{t}^{{S}_{1}}({a}_{t-1}^{{S}_{1}})={Q}_{t-1}^{{S}_{1}}({a}_{t-1}^{{S}_{1}})+{\alpha }_{2}({U}^{{S}_{1}}({r}_{t-1})-{Q}_{t-1}^{{S}_{1}}({a}_{t-1}^{{S}_{1}})),$$

(45)

where α2 is the learning rate for the second-stage states (0 ≤ α2 ≤ 1). The (motor-level) chosen action values (that is, \({\widetilde{a}}_{t-1}^{{S}_{2}}={C}_{1}\) if \({a}_{t-1}^{{S}_{1}}={B}_{1}\) and, \({\widetilde{a}}_{t-1}^{{S}_{2}}={C}_{2}\) if \({a}_{t-1}^{{S}_{1}}={B}_{2}\)) at the unchosen second-stage state (for example, assuming C1 or C2 at S2) are updated by:

$${Q}_{t}^{{S}_{2}}({\widetilde{a}}_{t-1}^{{S}_{2}})={Q}_{t-1}^{{S}_{2}}({\widetilde{a}}_{t-1}^{{S}_{2}})+{\alpha }_{2}({U}^{{S}_{2}}({r}_{t-1})-{Q}_{t-1}^{{S}_{2}}({\widetilde{a}}_{t-1}^{{S}_{2}})).$$

(46)

The second-stage action probabilities are generated via \(\text{softmax}\,({\beta }_{2}{Q}_{t}^{{S}_{1}}({B}_{1}),{\beta }_{2}{Q}_{t}^{{S}_{1}}({B}_{2}))\).

The action values at the first-stage state (A1 or A2 at S0) are updated by:

$${Q}_{t}^{{S}_{0}}({a}_{t-1}^{{S}_{0}})={Q}_{t-1}^{{S}_{0}}({a}_{t-1}^{{S}_{0}})+{\alpha }_{1}({U}^{{S}_{0}}({r}_{t-1})-{Q}_{t-1}^{{S}_{0}}({a}_{t-1}^{{S}_{0}}))$$

(47)

where α1 is the learning rate for the first-stage state (0 ≤ α1 ≤ 1). The first-stage action probabilities are generated via \(\text{softmax}\,({\beta }_{1}{Q}_{t}^{{S}_{0}}({A}_{1}),{\beta }_{1}{Q}_{t}^{{S}_{0}}({A}_{2}))\).

Here \({Q}_{t}^{{S}_{0}}({A}_{1})\), \({Q}_{t}^{{S}_{1}}({B}_{1})\), and \({Q}_{t}^{{S}_{2}}({C}_{1})\) are the dynamical variables.

Model fitting

Maximum likelihood estimation

The parameters in all models were optimized on the training dataset to maximize the log-likelihood (that is, minimize the negative log-likelihood, or cross-entropy) for the next-action prediction. The loss function is defined as follows:

$$\begin{array}{l}{\mathcal{L}}\,=\,-\log \Pr [\text{action sequences from one subject given}\\ \,\,\text{one model}]\\ \,=\,-\mathop{\sum }\limits_{n=1}^{{N}_{{\rm{session}}}}\mathop{\sum }\limits_{t=1}^{{T}_{n}}\log \Pr [\text{observing}\,{a}_{t}\,\text{given past}\\ \,\,\text{observations and the model}],\end{array}$$

(48)

where Nsession is the number of sessions and Tn is the number of trials in session n.

Nested cross-validation

To avoid overfitting and ensure a fair comparison between models with varying numbers of parameters, we implemented nested cross-validation. For each animal, we first divided sessions into non-overlapping shorter blocks (approximately 150 trials per block) and allocated these blocks into ten folds. In the outer loop, nine folds were designated for training and validation, while the remaining fold was reserved for testing. In the inner loop, eight of the nine folds were assigned for training (optimizing a model’s parameters for a given set of hyperparameters), and the remaining fold of the nine was allocated for validation (selecting the best-performing model across all hyperparameter sets). Notice that this procedure allows different hyperparameters for each test set.

RNNs’ hyperparameters encompassed the L1-regularization coefficient on recurrent weights (drawn from 10−5, 10−4, 10−3, 10−2 or 10−1, depending on the experiments), the number of training epochs (that is, early stopping), and the random seed (three seeds). For cognitive models, the only hyperparameter was the random seed (used for parameter initialization). The inner loop produced nine models, with the best-performing model, based on average performance in the training and validation datasets, being selected and evaluated on the unseen testing fold. The final testing performance was computed as the average across all ten testing folds, weighted by the number of trials per block. This approach ensures that test data is exclusively used for evaluation and is never encountered during training or selection.
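Schematically, the nested cross-validation procedure can be summarized as below; fit(train_blocks, hp) and evaluate(model, blocks) (returning a loss) are placeholders for model training and negative log-likelihood evaluation, and the block bookkeeping is simplified relative to the actual implementation.

```python
import numpy as np

def nested_cv(blocks, fit, evaluate, hyper_grid, n_outer=10):
    """Outer loop over test folds; inner loop selects hyperparameters per
    inner split, then the best inner-loop model (by combined training and
    validation loss) is evaluated on the held-out test fold."""
    folds = np.array_split(np.random.permutation(len(blocks)), n_outer)
    pick = lambda idx: [blocks[i] for i in idx]
    test_losses = []
    for k in range(n_outer):
        inner = [folds[j] for j in range(n_outer) if j != k]
        candidates = []
        for v in range(len(inner)):                              # 9 inner splits
            val_idx = inner[v]
            train_idx = np.concatenate([inner[j] for j in range(len(inner)) if j != v])
            models = [fit(pick(train_idx), hp) for hp in hyper_grid]
            best = min(models, key=lambda m: evaluate(m, pick(val_idx)))
            score = evaluate(best, pick(train_idx)) + evaluate(best, pick(val_idx))
            candidates.append((score, best))
        best_model = min(candidates, key=lambda c: c[0])[1]
        test_losses.append(evaluate(best_model, pick(folds[k])))  # unseen test fold
    return float(np.mean(test_losses))
```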

During RNN training, we employed early stopping if the validation performance failed to improve after 200 training epochs. This method effectively prevents RNN overfitting on the training data. According to this criterion, a more flexible model may demonstrate worse performance than a less flexible one, as the training for the former could be halted early due to insufficient training data. However, it is expected that the more flexible model would continue to improve with additional training data (for example, see Supplementary Fig. 8).

We note that, in the rich-data situation, this training–validation–test split in (nested) cross-validation is better than the typical usage of AIC62, corrected AIC (AICc)63 or BIC64 in cognitive modelling, due to the following reasons65: the (nested) cross-validation provides a direct and unbiased estimate of the expected extra-sample test error, which reflects the generalization performance on new data points with inputs not necessarily appearing in the training dataset; by contrast, AIC, AICc and BIC can only provide asymptotically unbiased estimates of in-sample test error under some conditions (for example, models are linear in their parameters), measuring the generalization performance on new data points with inputs always appearing in the training dataset (the labels could be different from those in the training dataset due to noise). Furthermore, in contrast to regular statistical models, neural networks are singular statistical models with degenerate Fisher information matrices. Consequently, estimating the model complexity (the number of effective parameters, as used in AIC, AICc or BIC) in neural networks requires estimating the real log canonical threshold66, which falls outside the scope of this study.

Estimating the dimensionality of behaviour

For each animal, we observed that the predictive performance of RNN models initially improves and then saturates, or sometimes declines as the number d of dynamical variables increases. To operationally estimate the dimensionality d* of behaviour, we implemented a statistical procedure that satisfies two criteria: (1) the RNN model with d* dynamical variables significantly outperforms all RNN models with d < d* dynamical variables (using a significance level of 0.05 in the t-tests of predictive performance conducted over outer folds); (2) any RNN model with \({d}^{{\prime} }\) (\({d}^{{\prime} } > {d}_{* }\)) dynamical variables does not exhibit significant improvement over all RNN models with \(d < {d}^{{\prime} }\) dynamical variables.
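The criterion can be implemented roughly as follows, where fold_losses is a dictionary mapping each d to an array of per-fold test losses (lower is better) and 'significantly outperforms' means a paired t-test over outer folds with p < 0.05 and a lower mean loss; this is a sketch of the described procedure, not the exact analysis code.

```python
import numpy as np
from scipy.stats import ttest_rel

def estimate_dimensionality(fold_losses, alpha=0.05):
    ds = sorted(fold_losses)

    def beats_all_smaller(d):
        # d significantly outperforms every model with fewer dynamical variables
        return all(
            ttest_rel(fold_losses[d], fold_losses[d2]).pvalue < alpha
            and np.mean(fold_losses[d]) < np.mean(fold_losses[d2])
            for d2 in ds if d2 < d
        )

    candidates = [d for d in ds
                  if beats_all_smaller(d)                                     # criterion (1)
                  and not any(beats_all_smaller(dp) for dp in ds if dp > d)]  # criterion (2)
    return min(candidates) if candidates else None
```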

Our primary objective is to estimate the intrinsic dimensionality (reflecting the latent variables in the data-generating process), not the embedding dimensionality67. However, it is important to consider the practical limitations associated with the estimation procedure. For instance, RNN models may fail to uncover certain latent variables due to factors such as limited training data or variables operating over very long time scales, leading to an underestimation of d*. Additionally, even if all d* latent variables are accurately captured, the RNN models may still require d ≥ d* dynamical variables to effectively and losslessly embed d*-dimensional dynamics, particularly if they exhibit high nonlinearity, potentially resulting in an overestimation of d*. A comprehensive understanding of these factors is crucial for future studies.

Knowledge distillation

We employ the knowledge distillation framework33 to fit models to individual subjects, while simultaneously leveraging group data: first fitting a teacher network to data from multiple subjects, and then fitting a student network to the outputs of the teacher network corresponding to an individual subject.

Teacher network

In the teacher network (TN), each subject is represented by a one-hot vector. This vector projects through a fully connected linear layer into a subject-embedding vector esub, which is provided as an additional input to the RNN. The teacher network uses 20 units in its hidden layer and uses the same output layer and loss (cross-entropy between the next-trial action and the predicted next-trial action probability) as in previous RNN models.

Student network

The student network (SN) has the same architecture as previous tiny RNNs. The only difference is that, during training and validation, the loss is defined as cross-entropy between the next-trial action probability provided by the teacher and the next-trial action probability predicted by the student:

$$\begin{array}{l}{\mathcal{L}}=-\mathop{\sum }\limits_{n=1}^{{N}_{{\rm{session}}}}\mathop{\sum }\limits_{t=1}^{{T}_{n}}\mathop{\sum }\limits_{a=1}^{{N}_{a}}{\text{Pr}}^{{\rm{TN}}}[{a}_{t}=a| \,\text{past observations}]\\ \,\,\times \,\log {\text{Pr}}^{{\rm{SN}}}[{a}_{t}=a| \,\text{past observations}],\end{array}$$

(49)

where Nsession is the number of sessions, Tn is the number of trials in session n, and Na is the number of actions.
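In PyTorch, this distillation loss amounts to a cross-entropy between the teacher's soft targets and the student's predicted probabilities; a minimal sketch, assuming logits and probabilities of shape (n_trials, n_actions):

```python
import torch

def distillation_loss(student_logits, teacher_probs):
    """Equation (49): - sum_t sum_a Pr_TN(a) * log Pr_SN(a)."""
    log_p_student = torch.log_softmax(student_logits, dim=-1)
    return -(teacher_probs * log_p_student).sum()
```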

Training, validation and test data in knowledge distillation for the mouse in the Akam dataset

To study the influence of the number of training trials from one representative mouse on the performance of knowledge distillation, we employed a procedure different from nested cross-validation. This procedure split the data from animal M into two sets. The first set consisted of 25% of the trials and was used as a hold-out M-test dataset. The second set consisted of the remaining 75% of the trials, from which smaller datasets of different sizes were sampled. From each sampled dataset, 90% of the trials were used for training (M-training dataset) and 10% for validation (M-validation dataset). Next, we split the data from all other animals, with 90% of the data used for training (O-training dataset) and 10% for validation (O-validation dataset).

After dividing the datasets as described above, we trained the models. The solo RNNs were trained to predict choices on the M-training dataset and selected on the M-validation dataset. The teacher RNNs were trained to predict choices on the M- and O-training datasets and selected on the M- and O-validation datasets. The number of embedding units in the teacher RNNs was selected based on the M-validation dataset. The student RNNs were trained on the M-training dataset and selected on the M-validation dataset, but with the training target of action probabilities provided by the teacher RNNs. Here the student RNNs and the corresponding teacher RNNs were trained on the same M-training dataset. Finally, all models were evaluated on the unseen M-test data.

When training the student RNNs, owing to symmetry in the task, we augmented the M-training datasets by flipping the actions and the second-stage states, resulting in an augmented dataset four times the size of the original one, similar to ref. 29. One key difference between our augmentation procedure and that of ref. 29 is that those authors augmented the data used to train the group RNNs, so any action bias present in the original dataset (and other related biases) becomes invisible to the RNNs. By contrast, our teacher RNNs are trained only on the original dataset, where any potential action biases can be learned. Even though we augment the training data later for the student networks, the biases learned by the teacher network can still be transferred to the student networks. In addition to direct augmentation, simulating the teacher network is another way to generate pseudo-data; the benefit of such pseudo-data has been discussed in work on model compression68.
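
A sketch of the four-fold symmetry augmentation, assuming NumPy integer arrays with binary (0/1) coding of actions and second-stage states; the function name and data layout are illustrative rather than the original preprocessing code.

```python
import numpy as np

def augment_by_symmetry(actions, states, rewards):
    """Four-fold augmentation of a two-stage-task session by flipping the
    action labels (A1 <-> A2), the second-stage state labels (S1 <-> S2),
    or both. Assumes NumPy integer arrays with 0/1 coding; rewards are
    left unchanged."""
    sessions = []
    for flip_a in (0, 1):
        for flip_s in (0, 1):
            sessions.append((actions ^ flip_a, states ^ flip_s, rewards.copy()))
    return sessions  # original session plus three symmetry-flipped copies
```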

Protocols for training, validating and testing models in human datasets

Interspersed split protocol

In the three human datasets, each subject performs only one block of 100–200 trials. In the standard practice of cognitive modelling, the cognitive models are trained and tested on the same block, leading to potential overfitting and exaggerated performance. Although it is possible to directly segment one block into three sequences for training, validation and testing, this might introduce undesired distributional shifts between the sequences owing to learning effects. To ensure a fair comparison between RNNs and cognitive models, we propose a new interspersed split protocol to define the training, validation and testing trials, similar to the use of the goldfish loss to prevent memorization of training data in language models69. Specifically, we randomly sample without replacement ~75% of the trial indices for training, ~12.5% for validation and ~12.5% for testing (three-armed reversal learning task: 120/20/20 (training/validation/testing); four-armed drifting bandit task: 110/20/20; original two-stage task: 150/25/25). We then feed the whole block of trials as the model's inputs, obtain the output probabilities for each trial, and calculate the training, validation and testing losses separately for each set of trial indices. This protocol guarantees that the three sets of trials come from an identical distribution.
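
A sketch of the interspersed split, assuming NumPy arrays of trial indices; the exact numbers and random-seed handling are illustrative.

```python
import numpy as np

def interspersed_split(n_trials, n_val, n_test, rng):
    """Randomly assign trial indices within a single block to training,
    validation and test sets (for example, 120/20/20 for the three-armed
    reversal learning task). The whole block is always fed to the model;
    only the trials at which the loss is evaluated differ between sets."""
    idx = rng.permutation(n_trials)
    return {"test": idx[:n_test],
            "val": idx[n_test:n_test + n_val],
            "train": idx[n_test + n_val:]}

splits = interspersed_split(160, 20, 20, np.random.default_rng(0))
# per-trial losses are computed over the full block, then averaged within each index set
```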

One possible concern is whether test data leak into the training data in this protocol. For instance, the models are trained on the input sequence ((a1, r1), (a2, r2), (a3, r3)) to predict a4 and later tested on the input sequence ((a1, r1), (a2, r2)) to predict a3. In this scenario, although the models see a3 in the input during training, they never see a3 as an output target. Thus, the models are not trained to learn the input–output mapping from ((a1, r1), (a2, r2)) to a3, which is what is evaluated during testing. We confirmed that this procedure prevents data leakage on artificially generated choices (Supplementary Fig. 40).

Cross-subject split protocol

In addition to the interspersed split protocol, it is possible to train the RNNs on a proportion of subjects and evaluate them on held-out subjects (that is, zero-shot generalization), which we call a cross-subject split protocol. To implement this protocol, we first divided all subjects into six folds for cross-validation. The teacher network was trained and validated using five folds and tested on the remaining fold. For each subject in the test fold, because each subject completed only one task block, student networks were trained on the action-augmented blocks (to predict the teacher's choice probabilities for that subject), validated on the original block (to predict the teacher's choice probabilities for that subject) and tested on the original block (to predict the actual choices of that subject). By design, neither the teacher networks nor the student networks can overfit the subjects' choices in the test data. The cognitive models were trained and validated using five folds and tested on the remaining fold. The results are presented in Supplementary Fig. 41.

Phase portraits

Models with d = 1

Logit

In each trial t, a model predicts the action probabilities Pr(at = A1) and Pr(at = A2). We define the logit L(t) (log odds) at trial t as \(L(t)=\log (\Pr ({a}_{t}={A}_{1})/\Pr ({a}_{t}={A}_{2}))\). When applied to probabilities computed via softmax, the logit yields \(L(t)=\log ({e}^{\beta {o}_{t}^{(1)}}/{e}^{\beta {o}_{t}^{(2)}})=\beta ({o}_{t}^{(1)}-{o}_{t}^{(2)})\), where \({o}_{t}^{(i)}\) is the model’s output for action at = Ai before softmax. Thus, the logit can be viewed as reflecting the preference for action A1 over A2: in RNNs, the logit corresponds to the score difference \({o}_{t}^{(1)}-{o}_{t}^{(2)}\); in model-free and model-based RL models, the logit is proportional to the difference in first-stage action values Qt(A1) − Qt(A2); in Bayesian inference models, the logit is proportional to the difference in latent-state probabilities \({\Pr }_{t}(h=1)-{\Pr }_{t}(h=2)=2{\Pr }_{t}(h=1)-1\).

Logit change

We define the logit change, ΔL(t), in trial t as the difference between L(t + 1) and L(t). In one-dimensional models, ΔL(t) is a function of the input and L(t), forming a vector field.

Stability of fixed points

Here we derive the stability of a fixed point in one-dimensional discrete dynamical systems. The system’s dynamics update according to:

$${L}_{{\rm{next}}}={f}_{I}(L),$$

(50)

where L is the current-trial logit, Lnext is the next-trial logit, and fI is a function determined by the input I (the subscript I is omitted hereafter for simplicity). At a fixed point, denoted by L = L*, we have

$${L}^{* }=f({L}^{* }).$$

(51)

Next, we consider a small perturbation δL around the fixed point:

$$\begin{array}{l}{L}_{{\rm{next}}}\,=\,f({L}^{* }+\delta L)\\ \,\,\approx \,f({L}^{* })+{f}^{{\prime} }({L}^{* })\delta L\\ \,\,=\,{L}^{* }+{f}^{{\prime} }({L}^{* })\delta L.\end{array}$$

(52)

The fixed point is stable only when \(-1 < {f}^{{\prime} }({L}^{* }) < 1\). Because the logit change ΔL is defined as ΔL = g(L) = f(L) − L, we have the stability condition \(-2 < {g}^{{\prime} }({L}^{* }) < 0\).
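
The stability condition can be checked numerically from an estimated logit-change function g; the helper below is an illustrative sketch, not part of the original analysis code.

```python
def fixed_point_stability(g, L_star, eps=1e-4):
    """Numerically check the stability condition -2 < g'(L*) < 0 for a
    one-dimensional logit-change function g (defined for a fixed input)."""
    slope = (g(L_star + eps) - g(L_star - eps)) / (2 * eps)  # central-difference slope
    return -2.0 < slope < 0.0, slope

# example: a prediction-error update g(L) = alpha * (L_target - L) with alpha = 0.3
stable, slope = fixed_point_stability(lambda L: 0.3 * (1.5 - L), L_star=1.5)
# stable == True, slope == -0.3
```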

Effective learning rate and slope

In the one-dimensional RL models with prediction error updates and constant learning rate α, we have

$$g(L)=\alpha ({L}^{* }-L),$$

(53)

where g(L) is the logit change at L. In general, to obtain a generalized form of g(L) = α(L)(L* − L) with a non-constant learning rate, we define the effective learning rate α(L) at L relative to a stable fixed point L* as:

$$\alpha (L)=-\frac{g(L)-g({L}^{* })}{L-{L}^{* }}=-\frac{g(L)}{L-{L}^{* }}.$$

(54)

At L*, α(L*) is the negative slope \(-{g}^{{\prime} }({L}^{* })\) of the tangent at L*. However, for general L ≠ L*, α(L) is the negative slope of the secant connecting (L, g(L)) and (L*, 0), which is different from \(-{g}^{{\prime} }(L)\).

We have

$$\begin{array}{l}{\alpha }^{{\prime} }(L)\delta L\,\approx \,\alpha (L+\delta L)-\alpha (L)\\ \,\,\,=\,\frac{g(L)}{L-{L}^{* }}-\frac{g(L+\delta L)}{L+\delta L-{L}^{* }}\\ \,\,\,\approx \,\frac{-\alpha (L)-{g}^{{\prime} }(L)}{L-{L}^{* }}\delta L.\end{array}$$

(55)

Letting δL go to zero, we have:

$$\alpha (L)=-{g}^{{\prime} }(L)-{\alpha }^{{\prime} }(L)(L-{L}^{* }),$$

(56)

which provides the relationship between the effective learning rate α(L) and the slope of the tangent \({g}^{{\prime} }(L)\).
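
The effective learning rate in equation (54) can be evaluated directly from an estimated logit-change function g; the helper below is an illustrative sketch, with the fallback to the tangent slope at L = L* following the text above.

```python
import numpy as np

def effective_learning_rate(g, L, L_star, eps=1e-4):
    """Effective learning rate alpha(L) = -g(L) / (L - L*) relative to a stable
    fixed point L* (equation (54)), falling back to the tangent slope -g'(L*)
    at L = L*. g should accept NumPy arrays; illustrative only."""
    L = np.asarray(L, dtype=float)
    diff = L - L_star
    safe = np.where(np.isclose(diff, 0.0), 1.0, diff)   # avoid 0/0 at the fixed point
    secant = -g(L) / safe                                # secant slope for L != L*
    tangent = -(g(L_star + eps) - g(L_star - eps)) / (2 * eps)
    return np.where(np.isclose(diff, 0.0), tangent, secant)
```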

Models with d > 1

In models with more dynamical variables, ΔL(t) is no longer a function of only the input and L(t), owing to the added degrees of freedom. In these models, the state space is spanned by a set of dynamical variables, collected in the vector F(t). For example, in the two-dimensional RL models, F(t) is the action-value vector \(F(t)={({Q}_{t}({A}_{1}),{Q}_{t}({A}_{2}))}^{T}\). The vector field ΔF(t) can be defined as \(\Delta F(t)=F(t+1)-F(t)={({Q}_{t+1}({A}_{1})-{Q}_{t}({A}_{1}),{Q}_{t+1}({A}_{2})-{Q}_{t}({A}_{2}))}^{T}\), a function of F(t) and the input in trial t.

Dynamical regression

For one-dimensional models with states characterized by the policy logit L(t), we can approximate the one-step dynamics for a given input with a linear function, that is, ΔL ~ β0 + βLL. The coefficients β0 and βL can be computed via linear regression, which we term 'dynamical regression' given its use in modelling dynamical systems. Here, β0 plays a role similar to the preference setpoint and βL plays a role similar to the learning rate in RL models.
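
A sketch of this regression step for a single input condition, using ordinary least squares in NumPy; variable names are illustrative.

```python
import numpy as np

def dynamical_regression_1d(L, dL):
    """Fit dL ~ beta_0 + beta_L * L by ordinary least squares over trials that
    share the same input condition. L and dL are 1D arrays of logits and logit
    changes; returns (beta_0, beta_L)."""
    X = np.column_stack([np.ones_like(L), L])
    beta, *_ = np.linalg.lstsq(X, dL, rcond=None)
    return beta[0], beta[1]   # beta_0 (setpoint-like), beta_L (learning-rate-like)
```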

For models with more than one dynamical variable, we can use a similar dynamical regression approach to extract a first-order approximation of the model dynamics via linearizations of vector fields. To facilitate interpretation, we consider only d-dimensional RNNs with a d-unit diagonal readout layer (denoted by Li(t) or Pi(t); a non-degenerate case).

For tasks with a single choice state (Supplementary Results 1.4 and 1.5), the diagonal readout layer means that d is equal to the number of actions. Thus, Pi(t) corresponds to the action preference for Ai at trial t (before softmax); a special case is Pi(t) = βVt(Ai) in cognitive models. We use ΔPi(t) = Pi(t + 1) − Pi(t) to denote preference changes between two consecutive trials. For the reversal learning task and the three-armed reversal learning task, we model ΔPi(t) as an (approximately) linear function of P1(t), …, Pd(t) for each (discrete) task input (that is, \(\Delta {P}_{i} \sim {\beta }_{0}^{({P}_{i})}+{\sum }_{j=1}^{d}{\beta }_{{P}_{j}}^{({P}_{i})}{P}_{j}\)). For the four-armed drifting bandit task, we further include the continuous reward r as an independent variable (that is, \(\Delta {P}_{i} \sim {\beta }_{0}^{({P}_{i})}+{\beta }_{R}^{({P}_{i})}r+{\sum }_{j=1}^{d}{\beta }_{{P}_{j}}^{({P}_{i})}{P}_{j}\)).

For the original two-stage task, where there are three choice states (Supplementary Result 1.6), we focus on the three-dimensional model with a diagonal readout layer. Here, L1, L2 and L3 represent the logits for A1/A2 at the first-stage state, logits for B1/B2 at the second-stage state S1 and logits for C1/C2 at the second-stage state S2, respectively. We similarly consider the regression \(\Delta {L}_{i} \sim {\beta }_{0}^{({L}_{i})}+{\sum }_{j=1}^{3}{\beta }_{{L}_{j}}^{({L}_{i})}{L}_{j}\).

Collecting all the \({\beta }_{{L}_{j}}^{({L}_{i})}\) (similarly for \({\beta }_{{P}_{j}}^{({P}_{i})}\)) regression coefficients for a given input condition, we have the input-dependent state-transition matrix A, akin to the Jacobian matrix of nonlinear dynamical systems:

$${\bf{A}}=\left[\begin{array}{cccc}{\beta }_{{L}_{1}}^{({L}_{1})} & {\beta }_{{L}_{2}}^{({L}_{1})} & \cdots & {\beta }_{{L}_{d}}^{({L}_{1})}\\ {\beta }_{{L}_{1}}^{({L}_{2})} & {\beta }_{{L}_{2}}^{({L}_{2})} & \cdots & {\beta }_{{L}_{d}}^{({L}_{2})}\\ \vdots & \vdots & \ddots & \vdots \\ {\beta }_{{L}_{1}}^{({L}_{d})} & {\beta }_{{L}_{2}}^{({L}_{d})} & \cdots & {\beta }_{{L}_{d}}^{({L}_{d})}\end{array}\right]$$

Note that the model-free RL models in these tasks are fully characterized by the collection of all regression coefficients in our dynamical regression.
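
For models with d > 1, the per-output regressions can be stacked into the matrix A as sketched below; the array shapes and names are illustrative assumptions.

```python
import numpy as np

def state_transition_matrix(L, dL):
    """For one input condition, regress each component's change dL_i on all
    current variables L_1..L_d (plus an intercept) and stack the coefficients
    into the d x d input-dependent state-transition matrix A. L and dL have
    shape (n_trials, d)."""
    n_trials, d = L.shape
    X = np.column_stack([np.ones(n_trials), L])       # intercept + L_1..L_d
    A = np.zeros((d, d))
    for i in range(d):
        beta, *_ = np.linalg.lstsq(X, dL[:, i], rcond=None)
        A[i, :] = beta[1:]                            # drop the intercept beta_0
    return A
```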

Symbolic regression

Apart from the two-dimensional vector field analysis, symbolic regression is another method for discovering concise equations that summarize the dynamics learned by RNNs. To accomplish this, we used PySR70 to search for simple symbolic expressions of the updated dynamical variables as functions of the current dynamical variables for each possible input I (for the RNN with d = 2 and a diagonal readout matrix). Ultimately, this process revealed a model-free strategy featuring the drift-to-the-other rule.
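
A sketch of this search with PySR's PySRRegressor; the placeholder data, operator set and hyperparameters are illustrative and not the settings used for the analyses.

```python
import numpy as np
from pysr import PySRRegressor  # PySR (ref. 70)

# Placeholder data: current dynamical variables (n_trials, 2) and the next-trial
# value of one variable, for a single input condition (replace with RNN readouts).
X = np.random.randn(500, 2)
y = 0.7 * X[:, 0] + 0.1                # stand-in target for illustration

model = PySRRegressor(
    niterations=100,                   # search budget (illustrative)
    binary_operators=["+", "-", "*"],
    unary_operators=["tanh"],
    maxsize=12,                        # keep candidate expressions short and interpretable
)
model.fit(X, y)
print(model.sympy())                   # best symbolic expression found
```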

Model validation via behaviour-feature identifier

We proposed a general and scalable approach based on a ‘behaviour-feature identifier’. In contrast to conventional model recovery, this approach provides a model-agnostic form of validation to identify and verify the hallmark of the discovered strategy in the empirical data.

For a given task, we collect the behavioural sequences generated by models that exhibit a specific feature (positive class) and by those that do not (negative class). An RNN identifier is then trained on these sequences to discern their classes. Subsequently, this identifier is applied to the actual behavioural sequences produced by subjects.
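
A sketch of such an identifier as a small recurrent classifier in PyTorch; the GRU backbone, hidden size and class name are illustrative assumptions.

```python
import torch
import torch.nn as nn

class FeatureIdentifier(nn.Module):
    """Sketch of the behaviour-feature identifier: a small recurrent classifier
    trained on trial sequences generated by models with (positive class) or
    without (negative class) the feature of interest, then applied to the
    subjects' actual sequences."""

    def __init__(self, input_dim, hidden_dim=16):
        super().__init__()
        self.rnn = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, 1)

    def forward(self, seq):                    # seq: (batch, trials, input_dim)
        _, h = self.rnn(seq)
        return torch.sigmoid(self.classifier(h[-1]))   # P(sequence from positive class)
```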

We built identifiers to distinguish between the RNN models (positive class) and model-free RL models (negative class) in the reversal learning task, and between the RNN models (positive class) and model-based RL models (negative class) in the two-stage task. The results are presented in Supplementary Fig. 29.

Meta-RL models

We trained meta-RL agents on the two-stage task (common transition: Pr(S1 | A1) = Pr(S2 | A2) = 0.8; rare transition: Pr(S2 | A1) = Pr(S1 | A2) = 0.2; see Supplementary Fig. 27) implemented in NeuroGym (v.0.0.1)71. Each second-stage state leads to a different probability of a unit reward, with the most valuable state switching stochastically (Pr(r = 1 | S1) = 1 − Pr(r = 1 | S2) = 0.8 or 0.2, switching with a probability of 0.025 on each trial). Each trial comprises three periods (discrete time steps): Delay 1, Go and Delay 2. During Delay 1, the agent receives the observation (choice state S0 and a fixation signal) and the reward (1 or 0) from the second-stage state on the previous trial. During Go, the agent receives the observation of the choice state and a go signal. During Delay 2, the agent receives the observation of state S1/S2 and a fixation signal. If the agent fails to select action A1 or A2 during Go, or fails to select action F (Fixate) during the delay periods, a small negative reward (−0.1) is given. The contributions of second-stage states, rewards and actions to the network are thus separated in time.

The agent architecture is a fully connected, gated RNN (long short-term memory58) with 48 units22. The input to the network consists of the current observation (state S0/S1/S2 and a scalar fixation/go signal), a scalar reward signal from the previous time step and a one-hot action vector from the previous time step. The network outputs a scalar baseline (the value function for the current state) serving as the critic and a real-valued action vector (passed through a softmax layer to sample one action from A1/A2/F) serving as the actor. The agents are trained using the Advantage Actor-Critic RL algorithm72 with a policy-gradient loss, a value-estimate loss and entropy regularization. We trained and analysed agents for five seeds. Our agents obtained an average reward of 0.64 per trial (chance level, 0.5), close to optimal performance (0.68 for an oracle agent that knows the correct action).
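
A sketch of the agent's input–output structure in PyTorch; the module and argument names are illustrative, and the Advantage Actor-Critic training loop is omitted.

```python
import torch
import torch.nn as nn

class MetaRLAgent(nn.Module):
    """Sketch of the meta-RL agent: an LSTM with 48 units receiving the current
    observation, the previous scalar reward and the previous one-hot action,
    with a scalar value head (critic) and a 3-way action head (actor: A1, A2, F)."""

    def __init__(self, obs_dim, n_actions=3, hidden_dim=48):
        super().__init__()
        self.lstm = nn.LSTM(obs_dim + 1 + n_actions, hidden_dim, batch_first=True)
        self.actor = nn.Linear(hidden_dim, n_actions)   # action scores -> softmax policy
        self.critic = nn.Linear(hidden_dim, 1)          # state-value baseline

    def forward(self, obs, prev_reward, prev_action, state=None):
        # obs: (batch, steps, obs_dim); prev_reward: (batch, steps, 1);
        # prev_action: (batch, steps, n_actions) one-hot
        x = torch.cat([obs, prev_reward, prev_action], dim=-1)
        h, state = self.lstm(x, state)
        return self.actor(h), self.critic(h), state
```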

Reporting summary

Further information on research design is available in the Nature Portfolio Reporting Summary linked to this article.


