Dynamical inference of RL parameters

Code accompanying the project "Harnessing the Flexibility of Neural Networks to Predict Dynamical Theoretical Parameters of Human Choice Behavior" (preprint: https://www.biorxiv.org/content/10.1101/2023.04.21.537666v1).

In this project we developed a new framework, termed t-RNN (theoretical-RNN), in which an RNN is trained to infer the time-varying RL parameters of a learning agent performing a two-armed bandit task.
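For intuition, the following is a minimal PyTorch sketch of the general idea: a recurrent network reads an agent's trial-by-trial behavior (action and reward) and emits per-trial estimates of theoretical parameters such as the learning rate and inverse temperature. The class name, layer sizes, input encoding, and output heads here are illustrative assumptions, not the architecture defined in the training notebooks.

```python
import torch
import torch.nn as nn

class TinyTRNN(nn.Module):
    """Illustrative t-RNN-style model: behavior sequence in, per-trial RL parameters out."""

    def __init__(self, n_actions=2, hidden_size=32):
        super().__init__()
        # Per-trial input: one-hot action concatenated with the scalar reward.
        self.gru = nn.GRU(input_size=n_actions + 1, hidden_size=hidden_size, batch_first=True)
        # Two output heads: learning rate in (0, 1) and inverse temperature > 0.
        self.alpha_head = nn.Linear(hidden_size, 1)
        self.beta_head = nn.Linear(hidden_size, 1)

    def forward(self, actions_onehot, rewards):
        # actions_onehot: (batch, trials, n_actions); rewards: (batch, trials, 1)
        x = torch.cat([actions_onehot, rewards], dim=-1)
        h, _ = self.gru(x)
        alpha = torch.sigmoid(self.alpha_head(h))          # (batch, trials, 1)
        beta = nn.functional.softplus(self.beta_head(h))   # (batch, trials, 1)
        return alpha, beta

# Example: 8 dummy sessions of 100 trials on a two-armed bandit.
actions = nn.functional.one_hot(torch.randint(0, 2, (8, 100)), num_classes=2).float()
rewards = torch.bernoulli(torch.full((8, 100, 1), 0.5))
alpha_t, beta_t = TinyTRNN()(actions, rewards)
print(alpha_t.shape, beta_t.shape)  # torch.Size([8, 100, 1]) torch.Size([8, 100, 1])
```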

Experiments with simulated behavior data

  • To generate a synthetic training set, run artificial/code/create_train_data.ipynb.
  • To train a t-RNN model on the synthetic training set, run artificial/trnn_training.ipynb.
  • Code for the two baseline methods can be found at artificial/q_fit.py (stationary Q-learning maximum-likelihood estimation) and artificial/code/bayesian_fit.py (Bayesian particle filtering).
  • To plot and inspect the results, run artificial/plots.ipynb.
  • artificial/code/checkpoint/checkpoint_trnn_5.pth is a state_dict of the trained model weights used for the analysis (a minimal loading sketch follows this list).
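For reference, loading the saved weights follows the standard PyTorch state_dict pattern. The sketch below assumes the model class defined in artificial/trnn_training.ipynb; TRNN and its constructor arguments are placeholders.

```python
import torch

# Placeholder: re-define or import the model class from artificial/trnn_training.ipynb.
# model = TRNN(...)

state_dict = torch.load("artificial/code/checkpoint/checkpoint_trnn_5.pth", map_location="cpu")
# model.load_state_dict(state_dict)
# model.eval()  # switch to inference mode before running the analysis
```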

Experiments with human dataset (psychiatric individuals)

The human behavioral dataset can be found at:

Code:

  • To generate a synthetic training set, run dezfouli/code/create_train_data.ipynb.
  • To train a t-RNN model on the synthetic training set, run dezfouli/code/trnn_training.ipynb.
  • Code for the stationary Q-learning with perseveration model can be found at dezfouli/code/qp_fit.py (a generic sketch of such a model follows this list).
  • Code for the data-driven RNN can be found at dezfouli/code/drnn.ipynb.
  • To plot and inspect the results, run dezfouli/code/plots.ipynb.
  • dezfouli/code/checkpoint_trnn.pth is a state_dict of the trained model weights used for the analysis.
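For readers unfamiliar with this baseline, below is a generic NumPy sketch of a stationary Q-learning model with a perseveration (choice-stickiness) term and its negative log-likelihood; the exact parameterization in dezfouli/code/qp_fit.py may differ.

```python
import numpy as np

def qp_negloglik(params, actions, rewards, n_actions=2):
    """Negative log-likelihood of a stationary Q-learning + perseveration model.

    params: (alpha, beta, kappa) = learning rate, inverse temperature, stickiness.
    actions: int array of chosen arms; rewards: float array of received outcomes.
    """
    alpha, beta, kappa = params
    q = np.zeros(n_actions)          # action values
    prev = np.zeros(n_actions)       # one-hot of the previous choice
    nll = 0.0
    for a, r in zip(actions, rewards):
        logits = beta * q + kappa * prev
        # Log-softmax choice probabilities (numerically stable).
        logp = logits - logits.max() - np.log(np.sum(np.exp(logits - logits.max())))
        nll -= logp[a]
        q[a] += alpha * (r - q[a])   # delta-rule value update
        prev = np.zeros(n_actions)
        prev[a] = 1.0
    return nll
```

Such a model can be fit per subject by minimizing qp_negloglik with, e.g., scipy.optimize.minimize under suitable bounds on (alpha, beta, kappa).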

Experiments with human dataset (exploration behavior)

The human behavioral dataset can be found at:

Code:

  • To generate a synthetic training set, run gershman/code/create_train_data.ipynb.
  • To train a t-RNN model on the synthetic training set, run gershman/code/trnn_training.ipynb.
  • Code for the stationary hybrid exploration model can be found at gershman/code/hybrid_fit.py (a generic sketch of such a choice rule follows this list).
  • Code for the data-driven RNN can be found at gershman/code/drnn.ipynb.
  • To plot and inspect the results, run gershman/code/plots.ipynb.
  • gershman/code/checkpoint_trnn.pth is a state_dict of the trained model weights used for the analysis.
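As rough orientation, the sketch below implements one common form of a hybrid exploration choice rule for a two-armed bandit: a probit model combining the value difference, the relative uncertainty (directed exploration), and the value difference scaled by total uncertainty (random exploration). This is a generic formulation under the stated assumptions, not necessarily the exact model implemented in gershman/code/hybrid_fit.py.

```python
import numpy as np
from scipy.stats import norm

def hybrid_choice_prob(mu, sigma, w_v, w_ru, w_vtu):
    """Probability of choosing arm 1 over arm 2 under a hybrid exploration rule.

    mu, sigma: posterior means and standard deviations of the two arms' rewards.
    w_v, w_ru, w_vtu: weights on value difference, relative uncertainty,
    and value difference scaled by total uncertainty.
    """
    v = mu[0] - mu[1]                              # estimated value difference
    ru = sigma[0] - sigma[1]                       # relative uncertainty (directed exploration)
    tu = np.sqrt(sigma[0] ** 2 + sigma[1] ** 2)    # total uncertainty (random exploration)
    return norm.cdf(w_v * v + w_ru * ru + w_vtu * v / tu)

# Example: arm 1 has a slightly higher estimated mean but is more uncertain.
print(hybrid_choice_prob(mu=[0.6, 0.5], sigma=[0.3, 0.1], w_v=3.0, w_ru=1.0, w_vtu=1.0))
```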