├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── authorship
│   ├── Authorship and contribution for IBL behavior paper.ipynb
│   ├── IBL behavior platform paper - contribution statement.ipynb
│   ├── Screen Shot 2020-04-23 at 10.53.17.png
│   ├── Screen Shot 2020-04-23 at 11.02.47.png
│   ├── Screen Shot 2020-04-23 at 11.07.48.png
│   ├── contributions.pdf
│   ├── contributions.png
│   ├── contributions_matrixform.csv
│   ├── contributions_text.txt
│   └── csv
│       ├── Behavior paper (2019) - Contribution statement - Written contributions.csv
│       ├── Behavior paper contribution statement - v2 (2020) - full_descriptions.csv
│       └── clustered_contributions.png
├── create_csv_data_files.py
├── dj_env.yml
├── figure1c_number_of_mice.py
├── figure1def_training.py
├── figure2af_learning_curves_all_parameters.py
├── figure2d_training_probability.py
├── figure2g_time_to_trained.py
├── figure3ab_psychfuncs.py
├── figure3cde_variability_over_labs_basic_&_suppfig3-2.py
├── figure3f_classifier_lab_membership_basic.py
├── figure3f_plot_classifier_basic.py
├── figure4a_block_probabilities.py
├── figure4de_psychfuncs_biased.py
├── figure4fghi_variability_over_labs_full.py
├── figure4i_classifier_lab_membership_full.py
├── figure4i_plot_classifier_full.py
├── figure5_GLM_modelfit.py
├── figure5_GLM_plot.py
├── figure5_GLM_simulate.py
├── paper_behavior_functions.py
├── requirements.txt
├── supp_days_between_trainingstatus.py
├── supp_figure2_performance_trials.py
├── supp_nmice_overtime.py
├── supp_queries.py
├── suppfig_3-4a-f.py
├── suppfig_choicevariability_withinacrosslabs.py
├── suppfig_classifier_lab_membership_first_biased.py
├── suppfig_classifier_lab_membership_perf.py
├── suppfig_end_session_histogram.py
├── suppfig_history_bycontrast.py
├── suppfig_history_strategy.py
├── suppfig_plot_classifier_first_biased.py
├── suppfig_plot_classifier_perf.py
├── suppfig_variability_over_labs_first_biased.py
├── suppfig_variability_over_time.py
└── text_trained1a_to_1b_sessions.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
.idea
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# Exported figures
/exported_figs
/data

# macOS stuff
.DS_store
dj_local_conf.json

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 The International Brain Laboratory

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# paper-behavior
This repository contains code to reproduce all figures of the behavior paper by the International Brain Laboratory. If reusing any part of this code, please cite the [bioRxiv paper](https://www.biorxiv.org/content/10.1101/2020.01.17.909838v2) in which these figures appear.

### Installation
These instructions require anaconda (https://www.anaconda.com/distribution/#download-section) for Python 3 and git (https://git-scm.com/book/en/v2/Getting-Started-Installing-Git).

In an Anaconda prompt window:
1. Clone or download the paper-behavior repository
2. Install the other dependencies by running `pip install -r requirements.txt`
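
The repository also ships a conda environment specification, `dj_env.yml`. If you prefer a dedicated environment, creating it with `conda env create -f dj_env.yml` and activating it via `conda activate dj_env` before the `pip install` step should work; this is a suggestion based on the file, not a step from the original instructions.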

To call the functions in this repo, either run python from within the `paper-behavior` folder or
add the folder to your python path:
```python
import sys
sys.path.extend([r'path/to/paper-behavior'])
```

### How to run the code
All the scripts start with the name of the figure they produce. The figure panels will appear in the `exported_figs` subfolder. When running the scripts for the first time, the required data will be downloaded to `./data`.

NB: Our DataJoint servers were retired in December 2023, so some scripts no longer execute; the main figure scripts, however, should still work.

### Questions?
If you have any problems running this code, please open an issue in the [iblenv repository](https://github.com/int-brain-lab/iblenv/issues), where we support users of the IBL software tools.

You can read more about the [IBL dataset types](https://docs.google.com/spreadsheets/d/1ieLXRPLLSgUKcLvFkrqizfZl5HjdfE6bQ2KLBCRmjQo/edit#gid=1097679410) and the [additional computations on the behavioral data](https://data.internationalbrainlab.org), such as training status and psychometric functions.

### Known issues
The data used in this paper have a number of issues. The authors are confident that these issues
do not affect the results of the paper, but users should nevertheless be aware of these
shortcomings.

1. NaN values may be found throughout the data. These resulted from failures either to produce
an event (e.g. the go cue tone was not played) or to record an event (e.g. the stimulus was
presented but the photodiode failed to detect it).
2. Some events violated the task structure outlined in the paper. For example, during some sessions
the go cue tone occurred much later than stimulus onset. Although this conceivably
affected the reaction time on those trials, it did not occur frequently enough to
significantly affect the median reaction times.
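
As a minimal guard against the NaN values from issue 1 when working with the downloaded data, one might drop the affected trials; this sketch assumes the `./data` folder created on first run and an `rt` reaction-time column, as in `Fig3.csv`:

```python
import pandas as pd

# Load one of the pre-generated data files (downloaded to ./data on first run)
behav = pd.read_csv('data/Fig3.csv')

# Drop trials whose reaction time was never recorded (see issue 1 above)
behav = behav.dropna(subset=['rt'])
```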

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/int-brain-lab/paper-behavior/5066ab721053cfaf70a82e997a5d601273e95594/__init__.py
--------------------------------------------------------------------------------
/authorship/Screen Shot 2020-04-23 at 10.53.17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/int-brain-lab/paper-behavior/5066ab721053cfaf70a82e997a5d601273e95594/authorship/Screen Shot 2020-04-23 at 10.53.17.png
--------------------------------------------------------------------------------
/authorship/Screen Shot 2020-04-23 at 11.02.47.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/int-brain-lab/paper-behavior/5066ab721053cfaf70a82e997a5d601273e95594/authorship/Screen Shot 2020-04-23 at 11.02.47.png
--------------------------------------------------------------------------------
/authorship/Screen Shot 2020-04-23 at 11.07.48.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/int-brain-lab/paper-behavior/5066ab721053cfaf70a82e997a5d601273e95594/authorship/Screen Shot 2020-04-23 at 11.07.48.png
--------------------------------------------------------------------------------
/authorship/contributions.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/int-brain-lab/paper-behavior/5066ab721053cfaf70a82e997a5d601273e95594/authorship/contributions.pdf
--------------------------------------------------------------------------------
/authorship/contributions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/int-brain-lab/paper-behavior/5066ab721053cfaf70a82e997a5d601273e95594/authorship/contributions.png
--------------------------------------------------------------------------------
/authorship/contributions_matrixform.csv:
--------------------------------------------------------------------------------
category_order,credit_task,"Aguillon Rodriguez, Valeria","Angelaki, Dora E.","Bayer, Hannah M.","Bonacchi, Niccolò","Carandini, Matteo","Cazettes, Fanny","Chapuis, Gaelle A.","Churchland, Anne K.","Dan, Yang","DeWitt, Eric E.","Faulkner, Mayo","Forrest, Hamish","Haetzel, Laura M.","Hausser, Michael","Hofer, Sonja B.","Hu, Fei","Khanal, Anup","Krasniak, Christopher S.","Laranjeira, Inês C.","Mainen, Zachary F.","Meijer, Guido T.","Miska, Nathaniel J.","Mrsic-Flogel, Thomas D.","Murakami, Masayoshi","Noel, Jean Paul","Pan-Vazquez, Alejandro","Rossant, Cyrille","Sanders, Joshua I.","Socha, Karolina Z.","Terry, Rebecca","Urai, Anne E.","Vergara, Hernando M.","Wells, Miles J.","Wilson, Christian J.","Witten, Ilana B.","Wool, Lauren E.","Zador, Anthony M."
2 | 0,CONCEPTUALIZATION: defined composition and scope of the paper,0,0,0,0,3,0,2,0,0,1,0,0,0,0,0,0,0,0,0,0,2,0,0,0,0,2,0,0,0,0,2,0,2,0,0,0,0 3 | 1,"METHODOLOGY: built, designed and tested rig assembly",2,0,0,3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,0,0,1,0,0,0,2,1,0,1,0,1,0,0,3,0 4 | 1,METHODOLOGY: designed and delivered rig components,0,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,0,0,0,0,0,0,0,0,0 5 | 1,METHODOLOGY: developed final behavioral task,0,0,0,2,0,1,0,0,0,2,0,0,0,0,0,0,0,0,0,0,2,0,0,0,0,0,0,0,0,0,2,0,2,0,0,2,0 6 | 1,"METHODOLOGY: developed protocols for surgery, husbandry and animal training",2,0,0,1,0,1,3,0,0,1,0,0,0,0,0,0,0,2,2,0,0,0,0,0,0,1,0,0,2,0,2,0,0,0,0,1,0 7 | 1,METHODOLOGY: piloted candidate behavioral tasks,0,0,0,2,0,3,0,0,0,2,0,0,0,0,0,0,0,2,2,0,0,0,0,1,0,0,0,0,0,0,2,0,3,0,0,0,0 8 | 1,METHODOLOGY: standardized licenses and experimental protocols across institutions,0,0,0,2,0,0,3,0,0,0,0,0,0,0,0,2,0,0,0,0,2,0,0,0,2,2,0,0,0,0,0,0,0,0,0,0,0 9 | 2,SOFTWARE: developed data acquisition software and infrastructure,0,0,0,3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 10 | 3,VALIDATION: maintained and validated analysis code,0,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,2,0,0,0,0,1,0,0,0,0,2,0,2,0,0,0,0 11 | 4,FORMAL ANALYSIS: analyzed data,0,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,1,3,0,0,0,1,3,0,0,0,0,3,0,2,0,0,0,0 12 | 5,"INVESTIGATION: built and maintained rigs, performed surgeries, collected behavioral data",2,0,0,1,0,2,0,0,0,0,2,1,2,0,0,2,1,2,2,0,2,2,0,0,2,2,0,0,2,1,2,1,0,1,0,0,0 13 | 6,RESOURCES: hosted the research,0,2,0,0,2,0,0,2,2,0,0,0,0,2,2,0,0,0,0,2,0,0,2,0,0,0,0,0,0,0,0,0,0,0,2,0,2 14 | 7,DATA CURATION: curated data and metadata,0,0,0,2,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,0,3,0,0,0,0 15 | 8,WRITING - ORIGINAL DRAFT: wrote and curated the appendix protocols,0,0,1,1,0,0,3,0,0,0,2,0,0,0,0,0,0,1,1,0,2,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0 16 | 8,WRITING - ORIGINAL DRAFT: wrote the first version of the paper,0,0,0,1,2,1,0,0,0,0,0,0,0,0,0,0,0,2,0,0,0,0,0,0,3,0,0,0,0,0,0,0,0,0,0,3,0 17 | 8,WRITING - ORIGINAL DRAFT: wrote the second version of the paper,0,0,0,0,3,0,2,0,0,0,0,0,0,0,0,0,0,0,0,0,2,0,0,0,0,2,0,0,0,0,2,0,2,0,0,0,0 18 | 9,WRITING - REVIEW AND EDITING: edited the paper,0,0,3,1,2,0,1,0,0,1,0,0,0,1,0,0,0,1,0,2,0,1,0,0,1,2,0,0,1,0,0,0,0,0,0,2,0 19 | 9,WRITING - REVIEW AND EDITING: revised the paper in response to peer review,0,0,0,0,2,0,2,0,0,0,0,0,0,0,0,0,0,0,0,0,2,0,0,0,0,2,0,0,0,0,2,0,2,0,0,0,0 20 | 10,VISUALIZATION: created data visualizations,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,0,1,0,0,0,0 21 | 10,VISUALIZATION: designed and created figures,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,3,0,0,0,0,3,0,0,0,0,3,0,1,0,0,0,0 22 | 11,SUPERVISION: managed and coordinated team,0,0,1,1,3,1,3,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,2 23 | 11,SUPERVISION: supervised local laboratory research,0,2,0,0,2,0,0,2,2,0,0,0,0,2,2,0,0,0,0,2,0,0,2,0,0,0,0,0,0,0,0,0,0,0,2,0,2 24 | 12,PROJECT ADMINISTRATION: managed and coordinated research outputs,0,0,0,1,3,0,3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0 25 | 13,FUNDING ACQUISITION: acquired funding,0,1,0,1,2,0,1,3,1,1,0,0,0,3,1,1,0,0,0,3,0,0,1,0,0,0,0,0,0,0,0,0,0,0,1,1,2 26 | -------------------------------------------------------------------------------- /authorship/contributions_text.txt: -------------------------------------------------------------------------------- 1 | Valeria Aguillon Rodriguez: METHODOLOGY: built, designed and tested rig 
assembly (equal); METHODOLOGY: developed protocols for surgery, husbandry and animal training (equal); INVESTIGATION: built and maintained rigs, performed surgeries, collected behavioral data (equal) 2 | Dora E. Angelaki: RESOURCES: hosted the research (equal); SUPERVISION: supervised local laboratory research (equal); FUNDING ACQUISITION: acquired funding (support) 3 | Hannah M. Bayer: WRITING - ORIGINAL DRAFT: wrote and curated the appendix protocols (support); WRITING - REVIEW AND EDITING: edited the paper (lead); SUPERVISION: managed and coordinated team (support) 4 | Niccolò Bonacchi: METHODOLOGY: built, designed and tested rig assembly (lead); METHODOLOGY: designed and delivered rig components (support); METHODOLOGY: piloted candidate behavioral tasks (equal); METHODOLOGY: developed final behavioral task (equal); METHODOLOGY: developed protocols for surgery, husbandry and animal training (support); METHODOLOGY: standardized licenses and experimental protocols across institutions (equal); SOFTWARE: developed data acquisition software and infrastructure (lead); VALIDATION: maintained and validated analysis code (support); FORMAL ANALYSIS: analyzed data (support); INVESTIGATION: built and maintained rigs, performed surgeries, collected behavioral data (support); DATA CURATION: curated data and metadata (equal); WRITING - ORIGINAL DRAFT: wrote the first version of the paper (support); WRITING - ORIGINAL DRAFT: wrote and curated the appendix protocols (support); WRITING - REVIEW AND EDITING: edited the paper (support); SUPERVISION: managed and coordinated team (support); PROJECT ADMINISTRATION: managed and coordinated research outputs (support); FUNDING ACQUISITION: acquired funding (support) 5 | Matteo Carandini: CONCEPTUALIZATION: defined composition and scope of the paper (lead); RESOURCES: hosted the research (equal); WRITING - ORIGINAL DRAFT: wrote the first version of the paper (equal); WRITING - ORIGINAL DRAFT: wrote the second version of the paper (lead); WRITING - REVIEW AND EDITING: edited the paper (equal); WRITING - REVIEW AND EDITING: revised the paper in response to peer review (equal); SUPERVISION: supervised local laboratory research (equal); SUPERVISION: managed and coordinated team (lead); PROJECT ADMINISTRATION: managed and coordinated research outputs (lead); FUNDING ACQUISITION: acquired funding (equal) 6 | Fanny Cazettes: METHODOLOGY: piloted candidate behavioral tasks (lead); METHODOLOGY: developed final behavioral task (support); METHODOLOGY: developed protocols for surgery, husbandry and animal training (support); INVESTIGATION: built and maintained rigs, performed surgeries, collected behavioral data (equal); WRITING - ORIGINAL DRAFT: wrote the first version of the paper (support); SUPERVISION: managed and coordinated team (support) 7 | Gaelle A. 
Chapuis: CONCEPTUALIZATION: defined composition and scope of the paper (equal); METHODOLOGY: developed protocols for surgery, husbandry and animal training (lead); METHODOLOGY: designed and delivered rig components (support); METHODOLOGY: standardized licenses and experimental protocols across institutions (lead); VALIDATION: maintained and validated analysis code (support); FORMAL ANALYSIS: analyzed data (support); DATA CURATION: curated data and metadata (support); WRITING - ORIGINAL DRAFT: wrote the second version of the paper (equal); WRITING - ORIGINAL DRAFT: wrote and curated the appendix protocols (lead); WRITING - REVIEW AND EDITING: edited the paper (support); WRITING - REVIEW AND EDITING: revised the paper in response to peer review (equal); VISUALIZATION: designed and created figures (support); SUPERVISION: managed and coordinated team (lead); PROJECT ADMINISTRATION: managed and coordinated research outputs (lead); FUNDING ACQUISITION: acquired funding (support)
Anne K. Churchland: RESOURCES: hosted the research (equal); SUPERVISION: supervised local laboratory research (equal); FUNDING ACQUISITION: acquired funding (lead)
Yang Dan: RESOURCES: hosted the research (equal); SUPERVISION: supervised local laboratory research (equal); FUNDING ACQUISITION: acquired funding (support)
Eric E. DeWitt: CONCEPTUALIZATION: defined composition and scope of the paper (support); METHODOLOGY: developed final behavioral task (equal); METHODOLOGY: piloted candidate behavioral tasks (equal); METHODOLOGY: developed protocols for surgery, husbandry and animal training (support); WRITING - REVIEW AND EDITING: edited the paper (support); SUPERVISION: managed and coordinated team (support); FUNDING ACQUISITION: acquired funding (support)
Mayo Faulkner: INVESTIGATION: built and maintained rigs, performed surgeries, collected behavioral data (equal); WRITING - ORIGINAL DRAFT: wrote and curated the appendix protocols (equal)
Hamish Forrest: INVESTIGATION: built and maintained rigs, performed surgeries, collected behavioral data (support)
Laura M. Haetzel: INVESTIGATION: built and maintained rigs, performed surgeries, collected behavioral data (equal)
Michael Hausser: RESOURCES: hosted the research (equal); WRITING - REVIEW AND EDITING: edited the paper (support); SUPERVISION: supervised local laboratory research (equal); FUNDING ACQUISITION: acquired funding (lead)
Sonja B. Hofer: RESOURCES: hosted the research (equal); SUPERVISION: supervised local laboratory research (equal); FUNDING ACQUISITION: acquired funding (support)
Fei Hu: METHODOLOGY: standardized licenses and experimental protocols across institutions (equal); INVESTIGATION: built and maintained rigs, performed surgeries, collected behavioral data (equal); FUNDING ACQUISITION: acquired funding (support)
Anup Khanal: INVESTIGATION: built and maintained rigs, performed surgeries, collected behavioral data (support)
Christopher S. Krasniak: METHODOLOGY: piloted candidate behavioral tasks (equal); METHODOLOGY: developed protocols for surgery, husbandry and animal training (equal); INVESTIGATION: built and maintained rigs, performed surgeries, collected behavioral data (equal); WRITING - ORIGINAL DRAFT: wrote the first version of the paper (equal); WRITING - ORIGINAL DRAFT: wrote and curated the appendix protocols (support); WRITING - REVIEW AND EDITING: edited the paper (support)
Inês C.
Laranjeira: METHODOLOGY: piloted candidate behavioral tasks (equal); METHODOLOGY: developed protocols for surgery, husbandry and animal training (equal); INVESTIGATION: built and maintained rigs, performed surgeries, collected behavioral data (equal); WRITING - ORIGINAL DRAFT: wrote and curated the appendix protocols (support) 20 | Zachary F. Mainen: FORMAL ANALYSIS: analyzed data (support); RESOURCES: hosted the research (equal); WRITING - REVIEW AND EDITING: edited the paper (equal); SUPERVISION: supervised local laboratory research (equal); FUNDING ACQUISITION: acquired funding (lead) 21 | Guido T. Meijer: CONCEPTUALIZATION: defined composition and scope of the paper (equal); METHODOLOGY: developed final behavioral task (equal); METHODOLOGY: built, designed and tested rig assembly (lead); METHODOLOGY: standardized licenses and experimental protocols across institutions (equal); VALIDATION: maintained and validated analysis code (equal); FORMAL ANALYSIS: analyzed data (lead); INVESTIGATION: built and maintained rigs, performed surgeries, collected behavioral data (equal); WRITING - ORIGINAL DRAFT: wrote the second version of the paper (equal); WRITING - ORIGINAL DRAFT: wrote and curated the appendix protocols (equal); WRITING - REVIEW AND EDITING: revised the paper in response to peer review (equal); VISUALIZATION: designed and created figures (lead) 22 | Nathaniel J. Miska: INVESTIGATION: built and maintained rigs, performed surgeries, collected behavioral data (equal); WRITING - REVIEW AND EDITING: edited the paper (support) 23 | Thomas D. Mrsic-Flogel: RESOURCES: hosted the research (equal); SUPERVISION: supervised local laboratory research (equal); FUNDING ACQUISITION: acquired funding (support) 24 | Masayoshi Murakami: METHODOLOGY: built, designed and tested rig assembly (support); METHODOLOGY: piloted candidate behavioral tasks (support) 25 | Jean Paul Noel: METHODOLOGY: standardized licenses and experimental protocols across institutions (equal); FORMAL ANALYSIS: analyzed data (support); INVESTIGATION: built and maintained rigs, performed surgeries, collected behavioral data (equal); WRITING - ORIGINAL DRAFT: wrote the first version of the paper (lead); WRITING - REVIEW AND EDITING: edited the paper (support) 26 | Alejandro Pan-Vazquez: CONCEPTUALIZATION: defined composition and scope of the paper (equal); METHODOLOGY: standardized licenses and experimental protocols across institutions (equal); METHODOLOGY: developed protocols for surgery, husbandry and animal training (support); VALIDATION: maintained and validated analysis code (support); FORMAL ANALYSIS: analyzed data (lead); INVESTIGATION: built and maintained rigs, performed surgeries, collected behavioral data (equal); WRITING - ORIGINAL DRAFT: wrote the second version of the paper (equal); WRITING - REVIEW AND EDITING: edited the paper (equal); WRITING - REVIEW AND EDITING: revised the paper in response to peer review (equal); VISUALIZATION: designed and created figures (lead) 27 | Cyrille Rossant: DATA CURATION: curated data and metadata (support) 28 | Joshua I. Sanders: METHODOLOGY: designed and delivered rig components (lead); METHODOLOGY: built, designed and tested rig assembly (equal) 29 | Karolina Z. 
Socha: METHODOLOGY: developed protocols for surgery, husbandry and animal training (equal); METHODOLOGY: built, designed and tested rig assembly (support); INVESTIGATION: built and maintained rigs, performed surgeries, collected behavioral data (equal); WRITING - REVIEW AND EDITING: edited the paper (support) 30 | Rebecca Terry: INVESTIGATION: built and maintained rigs, performed surgeries, collected behavioral data (support) 31 | Anne E. Urai: CONCEPTUALIZATION: defined composition and scope of the paper (equal); METHODOLOGY: built, designed and tested rig assembly (support); METHODOLOGY: piloted candidate behavioral tasks (equal); METHODOLOGY: developed final behavioral task (equal); METHODOLOGY: developed protocols for surgery, husbandry and animal training (equal); VALIDATION: maintained and validated analysis code (equal); FORMAL ANALYSIS: analyzed data (lead); INVESTIGATION: built and maintained rigs, performed surgeries, collected behavioral data (equal); DATA CURATION: curated data and metadata (support); WRITING - ORIGINAL DRAFT: wrote the second version of the paper (equal); WRITING - ORIGINAL DRAFT: wrote and curated the appendix protocols (support); WRITING - REVIEW AND EDITING: revised the paper in response to peer review (equal); VISUALIZATION: designed and created figures (lead); VISUALIZATION: created data visualizations (lead); SUPERVISION: managed and coordinated team (support); PROJECT ADMINISTRATION: managed and coordinated research outputs (support) 32 | Hernando M. Vergara: INVESTIGATION: built and maintained rigs, performed surgeries, collected behavioral data (support); WRITING - ORIGINAL DRAFT: wrote and curated the appendix protocols (support) 33 | Miles J. Wells: CONCEPTUALIZATION: defined composition and scope of the paper (equal); METHODOLOGY: built, designed and tested rig assembly (support); METHODOLOGY: piloted candidate behavioral tasks (lead); METHODOLOGY: developed final behavioral task (equal); VALIDATION: maintained and validated analysis code (equal); FORMAL ANALYSIS: analyzed data (equal); DATA CURATION: curated data and metadata (lead); WRITING - ORIGINAL DRAFT: wrote the second version of the paper (equal); WRITING - REVIEW AND EDITING: revised the paper in response to peer review (equal); VISUALIZATION: designed and created figures (support); VISUALIZATION: created data visualizations (support) 34 | Christian J. Wilson: INVESTIGATION: built and maintained rigs, performed surgeries, collected behavioral data (support) 35 | Ilana B. Witten: RESOURCES: hosted the research (equal); SUPERVISION: supervised local laboratory research (equal); FUNDING ACQUISITION: acquired funding (support) 36 | Lauren E. Wool: METHODOLOGY: built, designed and tested rig assembly (lead); METHODOLOGY: developed final behavioral task (equal); METHODOLOGY: developed protocols for surgery, husbandry and animal training (support); WRITING - ORIGINAL DRAFT: wrote the first version of the paper (lead); WRITING - REVIEW AND EDITING: edited the paper (equal); FUNDING ACQUISITION: acquired funding (support) 37 | Anthony M. 
Zador: RESOURCES: hosted the research (equal); SUPERVISION: supervised local laboratory research (equal); SUPERVISION: managed and coordinated team (equal); FUNDING ACQUISITION: acquired funding (equal)

--------------------------------------------------------------------------------
/authorship/csv/clustered_contributions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/int-brain-lab/paper-behavior/5066ab721053cfaf70a82e997a5d601273e95594/authorship/csv/clustered_contributions.png
--------------------------------------------------------------------------------
/create_csv_data_files.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Queries to save static csv/pickle data files for each figure

Guido Meijer
Jul 1, 2020
"""

from os import mkdir
from os.path import join, isdir
import pandas as pd
from paper_behavior_functions import (query_subjects, query_sessions_around_criterion,
                                      institution_map, CUTOFF_DATE, dj2pandas, datapath,
                                      query_session_around_performance)
from ibl_pipeline.analyses import behavior as behavioral_analyses
from ibl_pipeline import reference, subject, behavior, acquisition

# Get map of lab number to institute
institution_map, _ = institution_map()

# create data directory if it doesn't exist yet
root = datapath()
if not isdir(root):
    mkdir(root)

# Create list of subjects used
subjects = query_subjects(as_dataframe=True)
subjects.to_csv(join(root, 'subjects.csv'))


# %%=============================== #
# FIGURE 2
# ================================= #
print('Starting figure 2.')
# Figure 2af
use_subjects = query_subjects()
b = (behavioral_analyses.BehavioralSummaryByDate * use_subjects * behavioral_analyses.BehavioralSummaryByDate.PsychResults)
behav = b.fetch(order_by='institution_short, subject_nickname, training_day',
                format='frame').reset_index()
behav['institution_code'] = behav.institution_short.map(institution_map)
# Save to pickle (this dataframe contains array-valued columns)
behav.to_pickle(join(root, 'Fig2af.pkl'))

# Figure 2d
all_mice = (subject.Subject * subject.SubjectLab * reference.Lab
            * subject.SubjectProject() & 'subject_project = "ibl_neuropixel_brainwide_01"')
mice_started_training = (all_mice & (acquisition.Session() & 'task_protocol LIKE "%training%"'))
still_training = all_mice.aggr(behavioral_analyses.SessionTrainingStatus,
                               session_start_time='max(session_start_time)') \
                         * behavioral_analyses.SessionTrainingStatus - subject.Death \
                         & 'training_status = "in_training"' \
                         & 'session_start_time > "%s"' % CUTOFF_DATE
use_subjects = mice_started_training - still_training

# Get training status and training time in number of sessions and trials
ses = (
    (use_subjects
     * behavioral_analyses.SessionTrainingStatus
     * behavioral_analyses.PsychResults)
    .proj('subject_nickname', 'training_status', 'n_trials_stim', 'institution_short')
    .fetch(format='frame')
    .reset_index()
)
ses['n_trials'] = [sum(i) for i in ses['n_trials_stim']]
ses = ses.drop('n_trials_stim', axis=1)

# Save to csv
ses.to_csv(join(root, 'Fig2d.csv'))

# Figure 2 - supplement 1

# Query list of subjects to use
use_subjects =
query_subjects() 76 | 77 | b = (behavioral_analyses.BehavioralSummaryByDate * use_subjects) 78 | behav = b.fetch(order_by='institution_short, subject_nickname, training_day', 79 | format='frame').reset_index() 80 | behav['institution_code'] = behav.institution_short.map(institution_map) 81 | 82 | # Save to csv 83 | behav.to_csv(join(root, 'suppFig2_1.csv')) 84 | 85 | # %%=============================== # 86 | # FIGURE 3 87 | # ================================= # 88 | print('Starting figure 3..') 89 | 90 | # query sessions 91 | use_sessions, _ = query_sessions_around_criterion(criterion='trained', 92 | days_from_criterion=[2, 0], 93 | as_dataframe=False, 94 | force_cutoff=True) 95 | use_sessions = use_sessions & 'task_protocol LIKE "%training%"' # only get training sessions 96 | 97 | # list of dicts - see https://int-brain-lab.slack.com/archives/CB13FQFK4/p1607369435116300 for explanation 98 | sess = use_sessions.proj('task_protocol').fetch(format='frame').reset_index().to_dict('records') 99 | 100 | # query all trials for these sessions, it's split in two because otherwise the query would become 101 | # too big to handle in one go 102 | b = (behavior.TrialSet.Trial & sess) \ 103 | * subject.Subject * subject.SubjectLab * reference.Lab * acquisition.Session 104 | 105 | # reduce the size of the fetch 106 | b2 = b.proj('institution_short', 'subject_nickname', 107 | 'trial_stim_contrast_left', 'trial_stim_contrast_right', 'trial_response_choice', 108 | 'trial_stim_prob_left', 'trial_feedback_type', 'trial_response_time', 109 | 'trial_stim_on_time', 'session_end_time', 'task_protocol', 'time_zone') 110 | 111 | # construct pandas dataframe 112 | bdat = b2.fetch(order_by='institution_short, subject_nickname, session_start_time, trial_id', 113 | format='frame').reset_index() 114 | behav = dj2pandas(bdat) 115 | behav['institution_code'] = behav.institution_short.map(institution_map) 116 | 117 | # save to disk 118 | behav.to_csv(join(root, 'Fig3.csv')) 119 | 120 | # %%=============================== # 121 | # FIGURE 4 122 | # ================================= # 123 | print('Starting figure 4..') 124 | 125 | # query sessions 126 | use_sessions, _ = query_sessions_around_criterion(criterion='ephys', 127 | days_from_criterion=[2, 0], 128 | force_cutoff=True) 129 | use_sessions = use_sessions & 'task_protocol LIKE "%biased%"' # only get biased sessions 130 | # list of dicts - see https://int-brain-lab.slack.com/archives/CB13FQFK4/p1607369435116300 for explanation 131 | sess = use_sessions.proj('task_protocol').fetch(format='frame').reset_index().to_dict('records') 132 | 133 | # restrict by list of dicts with uuids for these sessions 134 | b = (behavior.TrialSet.Trial & sess) \ 135 | * subject.Subject * subject.SubjectLab * reference.Lab * acquisition.Session 136 | 137 | # reduce the size of the fetch 138 | b2 = b.proj('institution_short', 'subject_nickname', 'task_protocol', 'session_uuid', 139 | 'trial_stim_contrast_left', 'trial_stim_contrast_right', 'trial_response_choice', 140 | 'trial_stim_prob_left', 'trial_feedback_type', 141 | 'trial_response_time', 'trial_stim_on_time', 'time_zone') 142 | 143 | # construct pandas dataframe 144 | bdat = b2.fetch(order_by='institution_short, subject_nickname, session_start_time, trial_id', 145 | format='frame').reset_index() 146 | behav = dj2pandas(bdat) 147 | behav['institution_code'] = behav.institution_short.map(institution_map) 148 | 149 | # save to disk 150 | behav.to_csv(join(root, 'Fig4.csv')) 151 | 152 | # %%=============================== # 153 | # FIGURE 
5 154 | # ================================= # 155 | print('Starting figure 5..') 156 | 157 | # Query sessions biased data 158 | use_sessions, _ = query_sessions_around_criterion(criterion='biased', 159 | days_from_criterion=[2, 3], 160 | as_dataframe=False, 161 | force_cutoff=True) 162 | sess = use_sessions.proj('task_protocol').fetch(format='frame').reset_index().to_dict('records') 163 | 164 | # restrict by list of dicts with uuids for these sessions 165 | b = (behavior.TrialSet.Trial & use_sessions) \ 166 | * acquisition.Session * subject.Subject * subject.SubjectLab * reference.Lab 167 | 168 | # reduce the size of the fetch 169 | b2 = b.proj('institution_short', 'subject_nickname', 'task_protocol', 170 | 'trial_stim_contrast_left', 'trial_stim_contrast_right', 171 | 'trial_response_choice', 'trial_stim_prob_left', 'trial_feedback_type') 172 | bdat = b2.fetch(order_by='institution_short, subject_nickname, session_start_time, trial_id', 173 | format='frame').reset_index() 174 | behav = dj2pandas(bdat) 175 | behav['institution_code'] = behav.institution_short.map(institution_map) 176 | 177 | # save to disk 178 | behav.to_csv(join(root, 'Fig5.csv')) 179 | 180 | # %%=============================== # 181 | # FIGURE 3 - SUPPLEMENT 2 182 | # ================================= # 183 | print('Starting figure 3 - supplement 2..') 184 | 185 | # Query sessions biased data 186 | use_sessions, _ = query_sessions_around_criterion( 187 | criterion='biased', 188 | days_from_criterion=[-1, 3], 189 | force_cutoff=True) 190 | use_sessions = use_sessions & 'task_protocol LIKE "%biased%"' # only get biased sessions 191 | sess = use_sessions.proj('task_protocol').fetch(format='frame').reset_index().to_dict('records') 192 | 193 | # restrict by list of dicts with uuids for these sessions 194 | b = (behavior.TrialSet.Trial & sess) \ 195 | * acquisition.Session * subject.Subject * subject.SubjectLab * reference.Lab 196 | 197 | # reduce the size of the fetch 198 | b2 = b.proj('institution_short', 'subject_nickname', 'task_protocol', 199 | 'trial_stim_contrast_left', 'trial_stim_contrast_right', 200 | 'trial_response_choice', 'task_protocol', 'trial_stim_prob_left', 201 | 'trial_feedback_type') 202 | bdat = b2.fetch(order_by='institution_short, subject_nickname, session_start_time, trial_id', 203 | format='frame').reset_index() 204 | behav = dj2pandas(bdat) 205 | behav['institution_code'] = behav.institution_short.map(institution_map) 206 | 207 | # save to disk 208 | behav.to_csv(join(root, 'Fig3-supp2.csv')) 209 | behav = query_session_around_performance(perform_thres=0.8) 210 | behav.to_pickle(join(root, 'suppfig_3-4af.pkl')) 211 | -------------------------------------------------------------------------------- /dj_env.yml: -------------------------------------------------------------------------------- 1 | name: dj_env 2 | channels: 3 | - defaults 4 | - conda-forge 5 | - conda-canary 6 | dependencies: 7 | - python >=3.6 8 | - scipy >=1.3.0 9 | - numpy ==1.16.4 10 | - matplotlib >=3.0.3 11 | - seaborn >=0.9.0 12 | - pandas >=0.24.2 13 | - flake8 >=3.7.8 14 | - requests >=2.22.0 15 | - jupyter >=1.0 16 | - jupyterlab >=1.0 17 | - scikit-learn 18 | - tensorflow 19 | - keras 20 | - plotly 21 | - pip >=19.0 22 | - pip: 23 | - colorlog >=4.0.2 24 | - dataclasses >=0.6 25 | - globus-sdk >=1.8.0 26 | - ibl-pipeline 27 | - datajoint 28 | - scikit_posthocs 29 | - pycircstat 30 | - nose 31 | - scikit-learn 32 | 33 | -------------------------------------------------------------------------------- /figure1c_number_of_mice.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Query the number of mice at different timepoints of the pipeline. 3 | 4 | @author: Anne Urai, Guido Meijer, Miles Wells, 16 Jan 2020 5 | Updated 22 April 2020, Anne Urai 6 | """ 7 | 8 | from paper_behavior_functions import query_subjects, CUTOFF_DATE, QUERY 9 | from ibl_pipeline import subject, acquisition, reference 10 | from ibl_pipeline.analyses import behavior as behavior_analysis 11 | 12 | assert QUERY, 'This script requires a DataJoint instance, which was removed in Dec 2023.' 13 | 14 | # ========================= 15 | # 1. Query all mice on brainwide map project which began training before the paper's cutoff date 16 | # ========================= 17 | 18 | all_mice = (subject.Subject * subject.SubjectLab * reference.Lab * subject.SubjectProject() 19 | & 'subject_project = "ibl_neuropixel_brainwide_01"').aggr( 20 | acquisition.Session, first_session='min(date(session_start_time))') 21 | 22 | # Filter mice that started training after the paper's cutoff date 23 | all_mice = all_mice.aggr(acquisition.Session, first_session='min(date(session_start_time))') 24 | all_mice = (all_mice & 'first_session < "%s"' % CUTOFF_DATE) 25 | 26 | print('1. Total # of mice in brainwide project: %d' % len(all_mice)) 27 | 28 | # ================================================== 29 | # Exclude mice that are still in training at the date of cut-off, meaning they have not yet 30 | # reached any learned criteria 31 | # ================================================== 32 | 33 | all_mice = query_subjects(criterion=None) # Mice that started the training task protocol 34 | still_training = all_mice * subject.Subject.aggr(behavior_analysis.SessionTrainingStatus, 35 | session_start_time='max(session_start_time)') \ 36 | * behavior_analysis.SessionTrainingStatus - subject.Death \ 37 | & 'training_status = "in_training"' & 'session_start_time > "%s"' % CUTOFF_DATE 38 | # print(pd.DataFrame(still_training)) 39 | 40 | # ================================================== 41 | # Get mice that started training 42 | # ================================================== 43 | 44 | mice_started_training = (all_mice & (acquisition.Session() & 'task_protocol LIKE "%training%"')) 45 | print('2. Number of mice that went into training: %d' % len(mice_started_training)) 46 | print('3. Number of mice that are still in training (exclude from 1 and 2): %d' % len(still_training)) 47 | 48 | # ================================================== 49 | # Mice that reached trained 50 | # ================================================== 51 | 52 | trained = query_subjects(criterion='trained') 53 | print('4. Number of mice that reached trained: %d' % len(trained)) 54 | print('5. Number of mice that reached ready4ephys: %d' % len(query_subjects(criterion='ephys'))) 55 | 56 | # ================================================== 57 | # Trained mice yet to meet final criterion at the cut off date. 
58 | # These mice did not quite reach ready4ephysrig by the cut-off date, but were likely to 59 | # ================================================== 60 | 61 | # mice that reached trained but not ready4ephys, didn't die before the cut-off, and had fewer 62 | # than 40 sessions (no session marked as 'untrainable') 63 | session_training_status = acquisition.Session * behavior_analysis.SessionTrainingStatus() 64 | trained_not_ready = (trained.aggr(session_training_status, 65 | unfinished='SUM(training_status="ready4ephys" OR ' 66 | 'training_status="untrainable" OR ' 67 | 'training_status="unbiasable") = 0') 68 | .aggr(subject.Death, 'unfinished', 69 | alive='death_date IS NULL OR death_date > "%s"' % CUTOFF_DATE, 70 | keep_all_rows=True)) 71 | 72 | print('6. Number of mice that remain in training at the time of writing: %d' % 73 | len(trained_not_ready & 'alive = True AND unfinished = True')) 74 | -------------------------------------------------------------------------------- /figure1def_training.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training progression for an example mouse 3 | 4 | @author: Anne Urai, Gaelle Chapuis, Miles Wells 5 | 21 April 2020 6 | """ 7 | import os 8 | import copy 9 | 10 | import numpy as np 11 | import pandas as pd 12 | import matplotlib.pyplot as plt 13 | import matplotlib.dates as mdates 14 | from mpl_toolkits.axes_grid1.inset_locator import inset_axes 15 | import seaborn as sns 16 | import datajoint as dj 17 | 18 | from paper_behavior_functions import seaborn_style, figpath, \ 19 | FIGURE_HEIGHT, FIGURE_WIDTH, EXAMPLE_MOUSE, dj2pandas, plot_psychometric, QUERY 20 | 21 | assert QUERY, 'This script requires a DataJoint instance, which was removed in Dec 2023.' 22 | 23 | # import wrappers etc 24 | from ibl_pipeline import subject, behavior, acquisition 25 | from ibl_pipeline.analyses import behavior as behavioral_analyses 26 | endcriteria = dj.create_virtual_module( 27 | 'SessionEndCriteriaImplemented', 'group_shared_end_criteria') 28 | 29 | 30 | def plot_contrast_heatmap(mouse, lab, ax, xlims): 31 | """ 32 | This function is copied from 33 | IBL-pipeline/prelim_analyses/behavioral_snapshots/behavior_plots.py 34 | """ 35 | cmap = copy.copy(plt.get_cmap('vlag')) 36 | cmap.set_bad(color="w") # remove rectangles without data, should be white 37 | 38 | session_date, signed_contrasts, prob_choose_right, prob_left_block = ( 39 | behavioral_analyses.BehavioralSummaryByDate.PsychResults * subject.Subject * 40 | subject.SubjectLab & 'subject_nickname="%s"' % mouse & 'lab_name="%s"' % lab).proj( 41 | 'signed_contrasts', 'prob_choose_right', 'session_date', 'prob_left_block').fetch( 42 | 'session_date', 'signed_contrasts', 'prob_choose_right', 'prob_left_block') 43 | if not len(session_date): 44 | return 45 | 46 | signed_contrasts = signed_contrasts * 100 47 | 48 | # reshape this to a heatmap format 49 | prob_left_block2 = signed_contrasts.copy() 50 | for i, date in enumerate(session_date): 51 | session_date[i] = np.repeat(date, len(signed_contrasts[i])) 52 | prob_left_block2[i] = np.repeat(prob_left_block[i], len(signed_contrasts[i])) 53 | 54 | result = pd.DataFrame({'session_date': np.concatenate(session_date), 55 | 'signed_contrasts': np.concatenate(signed_contrasts), 56 | 'prob_choose_right': np.concatenate(prob_choose_right), 57 | 'prob_left_block': np.concatenate(prob_left_block2)}) 58 | 59 | # only use the unbiased block for now 60 | result = result[result.prob_left_block == 0] 61 | result = 
result.round({'signed_contrasts': 2}) 62 | pp2 = result.pivot("signed_contrasts", "session_date", "prob_choose_right").sort_values( 63 | by='signed_contrasts', ascending=False) 64 | pp2 = pp2.reindex(sorted(result.signed_contrasts.unique())) 65 | 66 | # evenly spaced date axis 67 | x = pd.date_range(xlims[0], xlims[1]).to_pydatetime() 68 | pp2 = pp2.reindex(columns=x) 69 | pp2 = pp2.iloc[::-1] # reverse, red on top 70 | 71 | # inset axes for colorbar, to the right of plot 72 | axins1 = inset_axes(ax, width="5%", height="90%", loc='right', 73 | bbox_to_anchor=(0.15, 0., 1, 1), 74 | bbox_transform=ax.transAxes, borderpad=0,) 75 | 76 | # now heatmap 77 | sns.heatmap(pp2, linewidths=0, ax=ax, vmin=0, vmax=1, cmap=cmap, cbar=True, 78 | cbar_ax=axins1, cbar_kws={'label': 'Choose right (%)', 'shrink': 0.8, 'ticks': []}) 79 | ax.set(ylabel="Contrast (%)", xlabel='') 80 | # deal with date axis and make nice looking 81 | ax.xaxis_date() 82 | ax.xaxis.set_major_locator(mdates.WeekdayLocator(byweekday=mdates.MONDAY)) 83 | ax.xaxis.set_major_formatter(mdates.DateFormatter('%b-%d')) 84 | for item in ax.get_xticklabels(): 85 | item.set_rotation(60) 86 | 87 | # ================================= # 88 | # INITIALIZE A FEW THINGS 89 | # ================================= # 90 | 91 | seaborn_style() # noqa 92 | figpath = figpath() # noqa 93 | plt.close('all') 94 | # FIGURE_WIDTH = 6 # make narrower 95 | 96 | # ================================= # 97 | # Get lab name of example mouse 98 | # ================================= # 99 | 100 | lab = (subject.SubjectLab * subject.Subject & 'subject_nickname = "%s"' % EXAMPLE_MOUSE) \ 101 | .fetch1('lab_name') 102 | days = [2, 7, 10, 14] 103 | # days = [2, 7, 10, 13] # request gaelle 104 | 105 | # ================================================== 106 | # CONTRAST HEATMAP 107 | # ================================= # 108 | 109 | plt.close('all') 110 | fig, ax = plt.subplots(1, 2, figsize=(FIGURE_WIDTH / 2, FIGURE_HEIGHT)) 111 | ax[1].axis('off') 112 | xlims = [pd.Timestamp('2019-08-04T00'), pd.Timestamp('2019-08-31T00')] 113 | plot_contrast_heatmap(EXAMPLE_MOUSE, lab, ax[0], xlims) 114 | ax[0].set(ylabel='Contrast (%)', xlabel='Training day', 115 | xticks=[d + 1.5 for d in [2,8,11,17]], xticklabels=days, 116 | yticklabels=['100', '50', '25', '12.5', '6.25', '0', 117 | '-6.25', '-12.5', '-25', '-50', '-100']) 118 | for item in ax[0].get_xticklabels(): 119 | item.set_rotation(-0) 120 | plt.tight_layout() 121 | fig.savefig(os.path.join(figpath, "figure1_example_contrastheatmap.pdf")) 122 | fig.savefig(os.path.join( 123 | figpath, "figure1_example_contrastheatmap.png"), dpi=600) 124 | 125 | # ================================================================== # 126 | # PSYCHOMETRIC AND CHRONOMETRIC FUNCTIONS FOR EXAMPLE 3 DAYS 127 | # ================================================================== # 128 | 129 | # make these a bit more narrow 130 | b = ((subject.Subject & 'subject_nickname = "%s"' % EXAMPLE_MOUSE) 131 | * (subject.SubjectLab & 'lab_name="%s"' % lab) 132 | * behavioral_analyses.BehavioralSummaryByDate) 133 | behav = b.fetch(format='frame').reset_index() 134 | behav['training_day'] = behav.training_day - \ 135 | behav.training_day.min() + 1 # start at session 1 136 | 137 | for didx, day in enumerate(days): 138 | 139 | # get data for today 140 | print(day) 141 | thisdate = behav[behav.training_day == 142 | day]['session_date'].dt.strftime('%Y-%m-%d').item() 143 | b = (subject.Subject & 'subject_nickname = "%s"' % EXAMPLE_MOUSE) \ 144 | * (subject.SubjectLab & 
'lab_name="%s"' % lab) \ 145 | * (acquisition.Session.proj(session_date='date(session_start_time)') & 146 | 'session_date = "%s"' % thisdate) \ 147 | * behavior.TrialSet.Trial() \ 148 | * endcriteria.SessionEndCriteriaImplemented() 149 | behavtmp = dj2pandas(b.fetch(format='frame').reset_index()) 150 | behavtmp['trial_start_time'] = behavtmp.trial_start_time / 60 # in minutes 151 | 152 | # unclear how this can be empty - but if it happens, skip 153 | if behavtmp.empty: 154 | continue 155 | 156 | # PSYCHOMETRIC FUNCTIONS 157 | fig, ax = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/5, FIGURE_HEIGHT*0.9)) 158 | plot_psychometric(behavtmp.signed_contrast, 159 | behavtmp.choice_right, 160 | behavtmp.trial_id, 161 | ax=ax, color='k') 162 | ax.set(xlabel="\u0394 Contrast (%)") 163 | 164 | if didx == 0: 165 | ax.set(ylabel="Rightward choices (%)") 166 | else: 167 | ax.set(ylabel=" ", yticklabels=[]) 168 | 169 | # ax.set(title='Training day %d' % (day)) 170 | sns.despine(trim=True) 171 | plt.tight_layout() 172 | fig.savefig(os.path.join( 173 | figpath, "figure1_example_psychfunc_day%d.pdf" % (day))) 174 | fig.savefig(os.path.join( 175 | figpath, "figure1_example_psychfunc_day%d.png" % (day)), dpi=600) 176 | 177 | # ================================================================== # 178 | # WITHIN-TRIAL DISENGAGEMENT CRITERIA 179 | # ================================================================== # 180 | 181 | plt.close('all') 182 | fig, ax = plt.subplots(2, 1, sharex=True, 183 | figsize=(FIGURE_WIDTH/5, FIGURE_HEIGHT*1.5)) 184 | 185 | # running median overlaid 186 | sns.lineplot(x='trial_start_time', y='rt', color='black', ci=None, 187 | data=behavtmp[['trial_start_time', 'rt']].rolling(20).median(), ax=ax[0]) 188 | ax[0].set(xlabel="", ylabel="RT (s)", ylim=[0.1, 20]) 189 | ax[0].set_yscale("log") 190 | 191 | # fix xlims 192 | if didx == 0: 193 | xlim = [0, 60] 194 | elif didx == 1: 195 | xlim = [0, 80] 196 | elif didx == 2: 197 | xlim = [0, 45] 198 | elif didx == 3: 199 | xlim = [0, 60] 200 | 201 | ax[0].set(yticks=[0.1, 1, 10, 20], 202 | yticklabels=['0.1', '1', '10', ''], xlim=xlim) 203 | 204 | if didx == 0: 205 | ax[0].set(ylabel="Trial duration (s)") 206 | else: 207 | ax[0].set(ylabel=" ", yticklabels=[]) 208 | 209 | # right y-axis with sliding performance 210 | # from : 211 | # https://stackoverflow.com/questions/36988123/pandas-groupby-and-rolling-apply-ignoring-nans 212 | 213 | g1 = behavtmp[['trial_start_time', 'correct_easy']].copy() 214 | g1['correct_easy'] = g1.correct_easy * 100 215 | g2 = g1.fillna(0).copy() 216 | s = g2.rolling(50).sum() / g1.rolling(50).count() # the actual computation 217 | 218 | sns.lineplot(x='trial_start_time', y='correct_easy', color='black', ci=None, 219 | data=s, ax=ax[1]) 220 | 221 | if day == min(days): 222 | ax[1].set(ylabel="Performance (%)\non easy trials") 223 | else: 224 | ax[1].set(ylabel=" ", yticklabels=[]) 225 | 226 | ax[1].set(xlabel='Time (min)', ylim=[25, 110], yticks=[25, 50, 75, 100], 227 | xlim=ax[0].get_xlim(), xticks=[0, 20, 40, 60, 80]) 228 | 229 | # INDICATE THE REASON AND TRIAL AT WHICH SESSION SHOULD HAVE ENDED 230 | idx = behavtmp.trial_id == behavtmp.end_status_index.unique()[0] 231 | end_x = behavtmp.loc[idx, 'trial_start_time'].values.item() 232 | ax[0].axvline(x=end_x, color='darkgrey', linestyle=':') 233 | ax[1].axvline(x=end_x, color='darkgrey', linestyle=':') 234 | # ax2.annotate(behavtmp.end_status.unique()[0], xy=(end_x, 100), xytext=(end_x, 105), 235 | # arrowprops={'arrowstyle': "->", 'connectionstyle': "arc3"}) 236 | 
print(behavtmp.end_status.unique()[0]) 237 | 238 | ax[0].set(title='Day %d: %d trials' % (day, behavtmp.shape[0])) 239 | sns.despine(trim=True) 240 | plt.tight_layout(h_pad=-0.05) 241 | fig.savefig(os.path.join( 242 | figpath, "figure1_example_disengagement_day%d.pdf" % day)) 243 | fig.savefig(os.path.join( 244 | figpath, "figure1_example_disengagement_day%d.png" % day), dpi=600) 245 | 246 | print(didx) 247 | print(thisdate) 248 | -------------------------------------------------------------------------------- /figure2d_training_probability.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Quantify the variability of the time to trained over labs. 5 | 6 | @author: Guido Meijer, Miles Wells 7 | 16 Jan 2020 8 | """ 9 | from os.path import join 10 | 11 | import pandas as pd 12 | import matplotlib.pyplot as plt 13 | import matplotlib.ticker as ticker 14 | import numpy as np 15 | import seaborn as sns 16 | 17 | from ibl_pipeline import subject 18 | from ibl_pipeline.analyses import behavior as behavior_analysis 19 | from paper_behavior_functions import (seaborn_style, institution_map, query_subjects, 20 | group_colors, figpath, load_csv, CUTOFF_DATE, 21 | FIGURE_HEIGHT, FIGURE_WIDTH, QUERY) 22 | from lifelines import KaplanMeierFitter 23 | 24 | # Settings 25 | fig_path = figpath() 26 | seaborn_style() 27 | 28 | if QUERY is True: 29 | mice_started_training = query_subjects(criterion=None) 30 | still_training = (mice_started_training.aggr(behavior_analysis.SessionTrainingStatus, 31 | session_start_time='max(session_start_time)') 32 | * behavior_analysis.SessionTrainingStatus - subject.Death 33 | & 'training_status = "in_training"' 34 | & 'session_start_time > "%s"' % CUTOFF_DATE) 35 | use_subjects = mice_started_training - still_training 36 | 37 | # Get training status and training time in number of sessions and trials 38 | ses = ((use_subjects * behavior_analysis.SessionTrainingStatus * behavior_analysis.PsychResults) 39 | .proj('subject_nickname', 'training_status', 'n_trials_stim', 'institution_short') 40 | .fetch(format='frame').reset_index()) 41 | ses['n_trials'] = [sum(i) for i in ses['n_trials_stim']] 42 | ses = ses.drop('n_trials_stim', axis=1).dropna() 43 | ses = ses.sort_values(['subject_nickname','session_start_time']) 44 | else: 45 | # Load in sessions from csv file 46 | ses = load_csv('Fig2d.csv').dropna() 47 | 48 | # Select mice that started training before cut off date 49 | ses = ses.groupby('subject_uuid').filter( 50 | lambda s : s['session_start_time'].min() < CUTOFF_DATE) 51 | 52 | # Construct dataframe from query 53 | training_time = pd.DataFrame() 54 | for i, nickname in enumerate(ses['subject_nickname'].unique()): 55 | training_time.loc[i, 'nickname'] = nickname 56 | training_time.loc[i, 'lab'] = ses.loc[ses['subject_nickname'] == nickname, 57 | 'institution_short'].values[0] 58 | training_time.loc[i, 'sessions'] = sum((ses['subject_nickname'] == nickname) 59 | & ((ses['training_status'] == 'in_training') 60 | | (ses['training_status'] == 'untrainable'))) 61 | training_time.loc[i, 'trials'] = ses.loc[((ses['subject_nickname'] == nickname) 62 | & (ses['training_status'] == 'in_training')), 63 | 'n_trials'].sum() 64 | training_time.loc[i, 'status'] = ses.loc[ses['subject_nickname'] == nickname, 65 | 'training_status'].values[-1] 66 | training_time.loc[i, 'date'] = ses.loc[ses['subject_nickname'] == nickname, 67 | 'session_start_time'].values[-1] 68 | 69 | # Transform training 
status into boolean 70 | training_time['trained'] = np.nan 71 | training_time.loc[((training_time['status'] == 'untrainable') 72 | | (training_time['status'] == 'in_training')), 'trained'] = 0 73 | training_time.loc[((training_time['status'] != 'untrainable') 74 | & (training_time['status'] != 'in_training')), 'trained'] = 1 75 | 76 | # Add lab number 77 | training_time['lab_number'] = training_time.lab.map(institution_map()[0]) 78 | training_time = training_time.sort_values('lab_number') 79 | 80 | # %% PLOT 81 | 82 | # Set figure style and color palette 83 | use_palette = [[0.6, 0.6, 0.6]] * len(np.unique(training_time['lab'])) 84 | use_palette = use_palette + [[1, 1, 0.2]] 85 | lab_colors = group_colors() 86 | ylim = [-0.02, 1.02] 87 | 88 | # Plot hazard rate survival analysis 89 | f, (ax1) = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/5, FIGURE_HEIGHT)) 90 | 91 | kmf = KaplanMeierFitter() 92 | for i, lab in enumerate(np.unique(training_time['lab_number'])): 93 | kmf.fit(training_time.loc[training_time['lab_number'] == lab, 'sessions'].values, 94 | event_observed=training_time.loc[training_time['lab_number'] == lab, 'trained']) 95 | ax1.step(kmf.cumulative_density_.index.values, kmf.cumulative_density_.values, 96 | color=lab_colors[i]) 97 | kmf.fit(training_time['sessions'].values, event_observed=training_time['trained']) 98 | ax1.step(kmf.cumulative_density_.index.values, kmf.cumulative_density_.values, color='black') 99 | ax1.set(ylabel='Reached proficiency', xlabel='Training day', 100 | xlim=[0, 60], ylim=ylim) 101 | ax1.set_title('All labs: %d mice' % training_time['nickname'].nunique()) 102 | 103 | # kmf.fit(training_time['sessions'].values, event_observed=training_time['trained']) 104 | # kmf.plot_cumulative_density(ax=ax2) 105 | # ax2.set(ylabel='Cumulative probability of\nreaching trained criterion', xlabel='Training day', 106 | # title='All labs', xlim=[0, 60], ylim=[0, 1.02]) 107 | # ax2.get_legend().set_visible(False) 108 | 109 | sns.despine(trim=True, offset=5) 110 | plt.tight_layout() 111 | seaborn_style() 112 | plt.savefig(join(fig_path, 'figure2d_probability_trained.pdf')) 113 | plt.savefig(join(fig_path, 'figure2d_probability_trained.png'), dpi=300) 114 | 115 | # Plot the same figure as a function of trial number 116 | f, (ax1) = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/3, FIGURE_HEIGHT)) 117 | 118 | kmf = KaplanMeierFitter() 119 | for i, lab in enumerate(np.unique(training_time['lab_number'])): 120 | kmf.fit(training_time.loc[training_time['lab_number'] == lab, 'trials'].values, 121 | event_observed=training_time.loc[training_time['lab_number'] == lab, 'trained']) 122 | ax1.step(kmf.cumulative_density_.index.values, kmf.cumulative_density_.values, 123 | color=lab_colors[i]) 124 | kmf.fit(training_time['trials'].values, event_observed=training_time['trained']) 125 | ax1.step(kmf.cumulative_density_.index.values, kmf.cumulative_density_.values, color='black') 126 | ax1.set(ylabel='Reached proficiency', xlabel='Trial', 127 | xlim=[0, 40e3], ylim=ylim) 128 | format_fcn = ticker.FuncFormatter(lambda x, pos: '{:,.0f}'.format(x / 1e3) + 'K') 129 | ax1.xaxis.set_major_formatter(format_fcn) 130 | ax1.set_title('All labs: %d mice' % training_time['nickname'].nunique()) 131 | 132 | sns.despine(trim=True, offset=5) 133 | plt.tight_layout() 134 | seaborn_style() 135 | plt.savefig(join(fig_path, 'figure2d_probability_trained_trials.pdf')) 136 | plt.savefig(join(fig_path, 'figure2d_probability_trained_trials.png'), dpi=300) 137 | 
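# Note on the survival analysis above: `event_observed` flags the mice that
# reached proficiency, while mice marked 'untrainable' or still 'in_training'
# enter as right-censored observations, so the Kaplan-Meier cumulative density
# estimates the fraction of mice trained by each day without discarding slow
# learners. A minimal sketch on made-up data (hypothetical durations in
# training days; the last mouse is censored):
#
#     from lifelines import KaplanMeierFitter
#     kmf = KaplanMeierFitter()
#     kmf.fit(durations=[10, 14, 21, 60], event_observed=[1, 1, 1, 0])
#     print(kmf.cumulative_density_)  # P(reached proficiency) by training day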
-------------------------------------------------------------------------------- /figure2g_time_to_trained.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Quantify the variability of the time to trained over labs. 5 | 6 | @author: Guido Meijer, Miles Wells 7 | 16 Jan 2020 8 | """ 9 | from os.path import join 10 | 11 | import pandas as pd 12 | import matplotlib.pyplot as plt 13 | import matplotlib.ticker as ticker 14 | import numpy as np 15 | import seaborn as sns 16 | from scipy import stats 17 | import scikit_posthocs as sp 18 | from paper_behavior_functions import (query_subjects, seaborn_style, institution_map, 19 | group_colors, figpath, load_csv, datapath, 20 | EXAMPLE_MOUSE, FIGURE_HEIGHT, FIGURE_WIDTH, QUERY) 21 | from ibl_pipeline.analyses import behavior as behavior_analysis 22 | 23 | # Settings 24 | fig_path = figpath() 25 | seaborn_style() 26 | institution_map, col_names = institution_map() 27 | 28 | if QUERY is True: 29 | # Query sessions 30 | use_subjects = query_subjects() 31 | ses = (behavior_analysis.BehavioralSummaryByDate * use_subjects) 32 | ses = (ses & 'session_date <= date_trained').fetch(format='frame').reset_index() 33 | 34 | # Construct dataframe 35 | training_time = pd.DataFrame(columns=['sessions'], data=ses.groupby('subject_nickname').size()) 36 | ses['n_trials_date'] = ses['n_trials_date'].astype(int) 37 | training_time['trials'] = ses.groupby('subject_nickname').sum()['n_trials_date'] 38 | training_time['lab'] = ses.groupby('subject_nickname')['institution_short'].apply(list).str[0] 39 | 40 | # Change lab name into lab number 41 | training_time['lab_number'] = training_time.lab.map(institution_map) 42 | training_time = training_time.sort_values('lab_number') 43 | training_time = training_time.reset_index() 44 | 45 | else: 46 | data = load_csv('Fig2af.pkl').dropna() 47 | use_subjects = data['subject_nickname'].unique() # For counting the number of subjects 48 | training_time = pd.DataFrame() 49 | for i, subject in enumerate(use_subjects): 50 | training_time = training_time.append(pd.DataFrame(index=[training_time.shape[0] + 1], 51 | data={ 52 | 'subject_nickname': subject, 53 | 'lab': data.loc[data['subject_nickname'] == subject, 'institution_short'].unique(), 54 | 'sessions': data.loc[((data['subject_nickname'] == subject) 55 | & (data['session_date'] < data['date_trained']))].shape[0], 56 | 'trials': data.loc[((data['subject_nickname'] == subject) 57 | & (data['session_date'] < data['date_trained'])), 58 | 'n_trials_date'].sum()})) 59 | training_time['lab_number'] = training_time.lab.map(institution_map) 60 | training_time = training_time.sort_values('lab_number').reset_index(drop=True) 61 | 62 | # Number of sessions to trained for example mouse 63 | example_training_time = \ 64 | training_time.reset_index()[training_time.reset_index()[ 65 | 'subject_nickname'].str.match(EXAMPLE_MOUSE)]['sessions'] 66 | # example_training_time = training_time.ix[EXAMPLE_MOUSE]['sessions'] 67 | 68 | # statistics 69 | # Test normality 70 | _, normal = stats.normaltest(training_time['sessions']) 71 | if normal < 0.05: 72 | kruskal = stats.kruskal(*[group['sessions'].values 73 | for name, group in training_time.groupby('lab')]) 74 | if kruskal[1] < 0.05: # Proceed to posthocs 75 | posthoc = sp.posthoc_dunn(training_time, val_col='sessions', 76 | group_col='lab_number') 77 | else: 78 | anova = stats.f_oneway(*[group['sessions'].values 79 | for name, group in 
training_time.groupby('lab')])
80 | if anova[1] < 0.05:
81 | posthoc = sp.posthoc_tukey(training_time, val_col='sessions',
82 | group_col='lab_number')
83 | 
84 | 
85 | # %% PLOT
86 | 
87 | # Set figure style and color palette
88 | use_palette = [[0.6, 0.6, 0.6]] * len(np.unique(training_time['lab']))
89 | use_palette = use_palette + [[1, 1, 0.2]]
90 | lab_colors = group_colors()
91 | 
92 | # Add all mice to dataframe separately for plotting
93 | training_time_no_all = training_time.copy()
94 | training_time_no_all.loc[training_time_no_all.shape[0] + 1, 'lab_number'] = 'All'
95 | training_time_all = training_time.copy()
96 | training_time_all['lab_number'] = 'All'
97 | training_time_all = training_time.append(training_time_all)
98 | 
99 | # Print the number of mice per lab
100 | print(training_time_all.reset_index().groupby(['lab_number'])['subject_nickname'].nunique())
101 | 
102 | f, (ax1) = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/3, FIGURE_HEIGHT))
103 | sns.set_palette(lab_colors)
104 | sns.swarmplot(y='sessions', x='lab_number', hue='lab_number', data=training_time_no_all,
105 | palette=lab_colors, ax=ax1, marker='.')
106 | axbox = sns.boxplot(y='sessions', x='lab_number', data=training_time_all,
107 | color='white', showfliers=False, ax=ax1)
108 | axbox.artists[-1].set_edgecolor('black')
109 | for j in range(5 * (len(axbox.artists) - 1), 5 * len(axbox.artists)):
110 | axbox.lines[j].set_color('black')
111 | ax1.set(ylabel='Days to trained', xlabel='', ylim=[0, 60])
112 | ax1.get_legend().set_visible(False)
113 | # [tick.set_color(lab_colors[i]) for i, tick in enumerate(ax1.get_xticklabels())]
114 | plt.setp(ax1.xaxis.get_majorticklabels(), rotation=40)
115 | sns.despine(trim=True)
116 | plt.tight_layout()
117 | plt.savefig(join(fig_path, 'figure2g_time_to_trained.pdf'))
118 | plt.savefig(join(fig_path, 'figure2g_time_to_trained.png'), dpi=300)
119 | 
120 | 
121 | # SAME FOR TRIALS TO TRAINED
122 | f, (ax1) = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/5, FIGURE_HEIGHT))
123 | sns.set_palette(lab_colors)
124 | sns.swarmplot(y='trials', x='lab_number', hue='lab_number', data=training_time_no_all,
125 | palette=lab_colors, ax=ax1, marker='.')
126 | axbox = sns.boxplot(y='trials', x='lab_number', data=training_time_all,
127 | color='white', showfliers=False, ax=ax1)
128 | axbox.artists[-1].set_edgecolor('black')
129 | for j in range(5 * (len(axbox.artists) - 1), 5 * len(axbox.artists)):
130 | axbox.lines[j].set_color('black')
131 | ax1.set(ylabel='Trials to trained', xlabel='')
132 | ax1.get_legend().set_visible(False)
133 | # [tick.set_color(lab_colors[i]) for i, tick in enumerate(ax1.get_xticklabels())]
134 | plt.setp(ax1.xaxis.get_majorticklabels(), rotation=40)
135 | format_fcn = ticker.FuncFormatter(lambda x, pos: '{:,.0f}'.format(x / 1e3) + 'K')
136 | ax1.yaxis.set_major_formatter(format_fcn)
137 | sns.despine(trim=True)
138 | plt.tight_layout()
139 | plt.savefig(join(fig_path, 'suppfig_trials_to_trained.pdf'))
140 | plt.savefig(join(fig_path, 'suppfig_trials_to_trained.png'), dpi=300)
141 | 
142 | 
143 | # sns.swarmplot(y='trials', x='lab_number', hue='lab_number', data=training_time_no_all,
144 | # palette=lab_colors, ax=ax2)
145 | # axbox = sns.boxplot(y='trials', x='lab_number', data=training_time_all,
146 | # color='white', showfliers=False, ax=ax2)
147 | # axbox.artists[-1].set_edgecolor('black')
148 | # for j in range(5 * (len(axbox.artists) - 1), 5 * len(axbox.artists)):
149 | # axbox.lines[j].set_color('black')
150 | # ax2.set(ylabel='Trials to trained', xlabel='', ylim=[0, 50000])
151 | # ax2.get_legend().set_visible(False)
152 | # # [tick.set_color(lab_colors[i]) for i, tick in enumerate(ax1.get_xticklabels())]
153 | # plt.setp(ax2.xaxis.get_majorticklabels(), rotation=40)
154 | 
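# A hedged aside, not part of the original script: the normality-gated cascade
# above (stats.normaltest, then Kruskal-Wallis with Dunn post-hocs, otherwise
# one-way ANOVA with Tukey post-hocs) recurs in several figure scripts. A
# minimal refactoring sketch of the same idea; the helper name and return
# shape are illustrative:
#
# def compare_labs(df, val_col, group_col='lab_number', alpha=0.05):
#     groups = [g[val_col].values for _, g in df.groupby(group_col)]
#     if stats.normaltest(df[val_col])[1] < alpha:  # normality rejected
#         p_omnibus = stats.kruskal(*groups)[1]
#         posthoc = sp.posthoc_dunn(df, val_col=val_col, group_col=group_col)
#     else:
#         p_omnibus = stats.f_oneway(*groups)[1]
#         posthoc = sp.posthoc_tukey(df, val_col=val_col, group_col=group_col)
#     return p_omnibus, posthoc if p_omnibus < alpha else None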
155 | # Get stats in text
156 | # Interquartile range per lab
157 | iqtr = training_time.groupby(['lab'])[
158 | 'sessions'].quantile(0.75) - training_time.groupby(['lab'])[
159 | 'sessions'].quantile(0.25)
160 | 
161 | # Training time as a whole
162 | m_train = training_time['sessions'].mean()
163 | s_train = training_time['sessions'].std()
164 | slowest = training_time['sessions'].max()
165 | fastest = training_time['sessions'].min()
166 | 
167 | # Print information used in the paper
168 | print('For mice that learned the task, the average training took %.1f ± %.1f days (s.d., '
169 | 'n = %d), similar to the %d days of the example mouse from Lab 1 (Figure 2a, black). The '
170 | 'fastest learner met training criteria in %d days, the slowest %d days'
171 | % (m_train, s_train, len(use_subjects), example_training_time, fastest, slowest))
172 | 
173 | # Training time in trials
174 | m_train = training_time['trials'].mean() / 1000
175 | s_train = training_time['trials'].std() / 1000
176 | slowest = training_time['trials'].max() / 1000
177 | fastest = training_time['trials'].min() / 1000
178 | example_training_trials = training_time.reset_index()[training_time.reset_index()['subject_nickname'].str.match(EXAMPLE_MOUSE)]['trials'] / 1000
179 | print('In trials, the average training took %.1fK ± %.1fK trials (s.d., '
180 | 'n = %d), similar to the %dK trials of the example mouse from Lab 1 (Figure 2a, black). The '
181 | 'fastest learner met training criteria in %dK trials, the slowest %dK trials.'
182 | % (m_train, s_train, len(use_subjects), example_training_trials, fastest, slowest))
183 | 
-------------------------------------------------------------------------------- /figure3ab_psychfuncs.py: --------------------------------------------------------------------------------
1 | """
2 | Psychometric functions of training mice, within and across labs
3 | 
4 | @author: Anne Urai
5 | 15 January 2020
6 | """
7 | import seaborn as sns
8 | import os
9 | from os.path import join
10 | import pandas as pd
11 | import matplotlib.pyplot as plt
12 | from paper_behavior_functions import (figpath, seaborn_style, group_colors, load_csv,
13 | query_sessions_around_criterion, institution_map,
14 | FIGURE_HEIGHT, FIGURE_WIDTH, QUERY, EXAMPLE_MOUSE,
15 | plot_psychometric, dj2pandas, plot_chronometric)
16 | # import wrappers etc
17 | from ibl_pipeline import reference, subject, behavior
18 | 
19 | # Initialize
20 | seaborn_style()
21 | figpath = figpath()
22 | pal = group_colors()
23 | institution_map, col_names = institution_map()
24 | col_names = col_names[:-1]
25 | 
26 | # %%=============================== #
27 | # GET DATA FROM TRAINED ANIMALS
28 | # ================================= #
29 | 
30 | if QUERY is True:
31 | # query sessions
32 | use_sessions, use_days = query_sessions_around_criterion(criterion='trained',
33 | days_from_criterion=[2, 0],
34 | as_dataframe=False,
35 | force_cutoff=True)
36 | 
37 | # list of dicts - see https://int-brain-lab.slack.com/archives/CB13FQFK4/p1607369435116300 for explanation
38 | sess = use_sessions.proj('task_protocol').fetch(format='frame').reset_index().to_dict('records')
39 | 
40 | # Trial data to fetch
41 | trial_fields = ('trial_stim_contrast_left',
42 | 'trial_stim_contrast_right',
43 | 'trial_response_time',
44 | 'trial_stim_prob_left',
45 | 'trial_feedback_type',
46 | 'trial_stim_on_time',
47 | 'trial_response_choice')
48 | 
49 | # Query trial data for sessions and subject name and lab info
50 | trials = 
(behavior.TrialSet.Trial & sess).proj(*trial_fields) 51 | 52 | # also get info about each subject 53 | subject_info = subject.Subject.proj('subject_nickname') * \ 54 | (subject.SubjectLab * reference.Lab).proj('institution_short') 55 | 56 | # Fetch, join and sort data as a pandas DataFrame 57 | behav = dj2pandas(trials.fetch(format='frame') 58 | .join(subject_info.fetch(format='frame')) 59 | .sort_values(by=['institution_short', 'subject_nickname', 60 | 'session_start_time', 'trial_id']) 61 | .reset_index()) 62 | behav['institution_code'] = behav.institution_short.map(institution_map) 63 | else: 64 | behav = load_csv('Fig3.csv') 65 | 66 | # print some output 67 | print(behav.sample(n=10)) 68 | 69 | # %%=============================== # 70 | # PSYCHOMETRIC FUNCTIONS 71 | # ================================= # 72 | 73 | # how many mice are there for each lab? 74 | N = behav.groupby(['institution_code'])['subject_nickname'].nunique().to_dict() 75 | behav['n_mice'] = behav.institution_code.map(N) 76 | behav['institution_name'] = behav.institution_code + '\n ' + behav.n_mice.apply(str) + ' mice' 77 | 78 | # plot one curve for each animal, one panel per lab 79 | plt.close('all') 80 | fig = sns.FacetGrid(behav, 81 | col="institution_code", col_wrap=7, col_order=col_names, 82 | sharex=True, sharey=True, hue="subject_uuid", 83 | height=FIGURE_HEIGHT, aspect=(FIGURE_WIDTH/7)/FIGURE_HEIGHT) 84 | fig.map(plot_psychometric, "signed_contrast", "choice_right", 85 | "subject_nickname", color='gray', alpha=0.7) 86 | fig.set_titles("{col_name}") 87 | 88 | # overlay the example mouse 89 | tmpdat = behav[behav['subject_nickname'].str.contains(EXAMPLE_MOUSE)] 90 | plot_psychometric(tmpdat.signed_contrast, tmpdat.choice_right, tmpdat.subject_nickname, 91 | color='black', ax=fig.axes[0], legend=False) 92 | 93 | # add lab means on top 94 | for axidx, ax in enumerate(fig.axes.flat): 95 | tmp_behav = behav.loc[behav.institution_name == behav.institution_name.unique()[axidx], :] 96 | plot_psychometric(tmp_behav.signed_contrast, tmp_behav.choice_right, 97 | tmp_behav.institution_name, ax=ax, legend=False, color=pal[axidx], linewidth=2) 98 | ax.set_title(sorted(behav.institution_name.unique())[axidx], 99 | color=pal[axidx]) 100 | 101 | fig.despine(trim=True) 102 | fig.set_axis_labels("\u0394 Contrast (%)", 'Rightward choices (%)') 103 | plt.tight_layout(w_pad=1) 104 | fig.savefig(os.path.join(figpath, "figure3a_psychfuncs.pdf")) 105 | fig.savefig(os.path.join(figpath, "figure3a_psychfuncs.png"), dpi=300) 106 | print('done') 107 | 108 | # %% 109 | 110 | # Plot all labs 111 | fig, ax1 = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/5, FIGURE_HEIGHT)) 112 | for i, inst in enumerate(behav.institution_code.unique()): 113 | tmp_behav = behav[behav['institution_code'].str.contains(inst)] 114 | plot_psychometric(tmp_behav.signed_contrast, tmp_behav.choice_right, 115 | tmp_behav.subject_nickname, ax=ax1, legend=False, color=pal[i]) 116 | #ax1.set_title('All labs', color='k', fontweight='bold') 117 | ax1.set_title('All labs: %d mice' % behav['subject_nickname'].nunique()) 118 | ax1.set(xlabel='\u0394 Contrast (%)', ylabel='Rightward choices (%)') 119 | sns.despine(trim=True) 120 | plt.tight_layout() 121 | fig.savefig(os.path.join(figpath, "figure3b_psychfuncs_all_labs.pdf")) 122 | fig.savefig(os.path.join(figpath, "figure3b_psychfuncs_all_labs.png"), dpi=300) 123 | 124 | # ================================= # 125 | # single summary panel 126 | # ================================= # 127 | 128 | # Plot all labs 129 | fig, ax1 = 
plt.subplots(1, 2, figsize=(8, 4)) 130 | plot_psychometric(behav.signed_contrast, behav.choice_right, 131 | behav.subject_nickname, ax=ax1[0], legend=False, color='k') 132 | ax1[0].set_title('Psychometric function', color='k', fontweight='bold') 133 | ax1[0].set(xlabel='\u0394 Contrast (%)', ylabel='Rightward choice (%)') 134 | 135 | plot_chronometric(behav.signed_contrast, behav.rt, 136 | behav.subject_nickname, ax=ax1[1], legend=False, color='k') 137 | ax1[1].set_title('Chronometric function', color='k', fontweight='bold') 138 | ax1[1].set(xlabel='\u0394 Contrast (%)', ylabel='Trial duration (s)', ylim=[0, 1.4]) 139 | sns.despine(trim=True) 140 | plt.tight_layout() 141 | fig.savefig(os.path.join(figpath, "summary_psych_chron.pdf")) 142 | plt.show() 143 | -------------------------------------------------------------------------------- /figure3cde_variability_over_labs_basic_&_suppfig3-2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Quantify the variability of behavioral metrics within and between labs of mouse behavior. 5 | This script doesn't perform any analysis but plots summary statistics over labs. 6 | 7 | Guido Meijer 8 | 16 Jan 2020 9 | """ 10 | 11 | import pandas as pd 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | from scipy import stats 15 | from os.path import join 16 | import seaborn as sns 17 | from paper_behavior_functions import (query_sessions_around_criterion, seaborn_style, 18 | institution_map, group_colors, figpath, load_csv, 19 | FIGURE_WIDTH, FIGURE_HEIGHT, QUERY, 20 | dj2pandas, fit_psychfunc, num_star) 21 | from ibl_pipeline import behavior, subject, reference 22 | import scikit_posthocs as sp 23 | from statsmodels.stats.multitest import multipletests 24 | 25 | # Initialize 26 | seaborn_style() 27 | figpath = figpath() 28 | pal = group_colors() 29 | institution_map, col_names = institution_map() 30 | col_names = col_names[:-1] 31 | 32 | if QUERY is True: 33 | use_sessions, _ = query_sessions_around_criterion(criterion='trained', 34 | days_from_criterion=[2, 0]) 35 | session_keys = (use_sessions & 'task_protocol LIKE "%training%"').fetch('KEY') 36 | ses = ((use_sessions & 'task_protocol LIKE "%training%"') 37 | * subject.Subject * subject.SubjectLab * reference.Lab 38 | * (behavior.TrialSet.Trial & session_keys)) 39 | ses = ses.proj('institution_short', 'subject_nickname', 'task_protocol', 'session_uuid', 40 | 'trial_stim_contrast_left', 'trial_stim_contrast_right', 41 | 'trial_response_choice', 'task_protocol', 'trial_stim_prob_left', 42 | 'trial_feedback_type', 'trial_response_time', 'trial_stim_on_time', 43 | 'session_end_time').fetch( 44 | order_by='institution_short, subject_nickname,session_start_time, trial_id', 45 | format='frame').reset_index() 46 | behav = dj2pandas(ses) 47 | behav['institution_code'] = behav.institution_short.map(institution_map) 48 | else: 49 | behav = load_csv('Fig3.csv', parse_dates=['session_start_time', 'session_end_time']) 50 | 51 | # Create dataframe with behavioral metrics of all mice 52 | learned = pd.DataFrame(columns=['mouse', 'lab', 'perf_easy', 'n_trials', 53 | 'threshold', 'bias', 'reaction_time', 54 | 'lapse_low', 'lapse_high', 'trials_per_minute']) 55 | 56 | for i, nickname in enumerate(behav['subject_nickname'].unique()): 57 | if np.mod(i+1, 10) == 0: 58 | print('Processing data of subject %d of %d' % (i+1, 59 | len(behav['subject_nickname'].unique()))) 60 | 61 | # Get the trials of the sessions 
around criterion for this subject 62 | trials = behav[behav['subject_nickname'] == nickname] 63 | trials = trials.reset_index() 64 | 65 | # Fit a psychometric function to these trials and get fit results 66 | fit_result = fit_psychfunc(trials) 67 | 68 | # Get RT, performance and number of trials 69 | reaction_time = trials['rt'].median()*1000 70 | perf_easy = trials['correct_easy'].mean()*100 71 | ntrials_perday = trials.groupby('session_uuid').count()['trial_id'].mean() 72 | 73 | # average trials/minute to normalise by session length 74 | trials['session_length'] = (trials.session_end_time - trials.session_start_time).astype('timedelta64[m]') 75 | total_session_length = trials.groupby('session_uuid')['session_length'].mean().sum() 76 | total_n_trials = trials['trial_id'].count() 77 | 78 | # Add results to dataframe 79 | learned.loc[i, 'mouse'] = nickname 80 | learned.loc[i, 'lab'] = trials['institution_short'][0] 81 | learned.loc[i, 'perf_easy'] = perf_easy 82 | learned.loc[i, 'n_trials'] = ntrials_perday 83 | learned.loc[i, 'reaction_time'] = reaction_time 84 | learned.loc[i, 'trials_per_minute'] = total_n_trials / total_session_length 85 | learned.loc[i, 'threshold'] = fit_result.loc[0, 'threshold'] 86 | learned.loc[i, 'bias'] = fit_result.loc[0, 'bias'] 87 | learned.loc[i, 'lapse_low'] = fit_result.loc[0, 'lapselow'] 88 | learned.loc[i, 'lapse_high'] = fit_result.loc[0, 'lapsehigh'] 89 | 90 | # Drop mice with faulty RT 91 | learned = learned[learned['reaction_time'].notnull()] 92 | 93 | # Change lab name into lab number 94 | learned['lab_number'] = learned.lab.map(institution_map) 95 | learned = learned.sort_values('lab_number') 96 | 97 | # Convert to float 98 | float_fields = ['perf_easy', 'reaction_time', 'threshold', 99 | 'n_trials', 'bias', 'lapse_low', 'lapse_high', 'trials_per_minute'] 100 | learned[float_fields] = learned[float_fields].astype(float) 101 | 102 | # %% Stats 103 | stats_tests = pd.DataFrame(columns=['variable', 'test_type', 'p_value']) 104 | posthoc_tests = {} 105 | 106 | for i, var in enumerate(['perf_easy', 'reaction_time', 'n_trials', 'threshold', 'bias', 'trials_per_minute']): 107 | _, normal = stats.normaltest(learned[var]) 108 | 109 | if normal < 0.05: 110 | test_type = 'kruskal' 111 | test = stats.kruskal(*[group[var].values 112 | for name, group in learned.groupby('lab_number')]) 113 | if test[1] < 0.05: # Proceed to posthocs 114 | posthoc = sp.posthoc_dunn(learned, val_col=var, group_col='lab_number') 115 | else: 116 | posthoc = np.nan 117 | else: 118 | test_type = 'anova' 119 | test = stats.f_oneway(*[group[var].values 120 | for name, group in learned.groupby('lab_number')]) 121 | if test[1] < 0.05: 122 | posthoc = sp.posthoc_tukey(learned, val_col=var, group_col='lab_number') 123 | else: 124 | posthoc = np.nan 125 | 126 | # Test for difference in variance 127 | _, p_var = stats.levene(*[group[var].values for name, group in learned.groupby('lab_number')]) 128 | 129 | posthoc_tests['posthoc_'+str(var)] = posthoc 130 | stats_tests.loc[i, 'variable'] = var 131 | stats_tests.loc[i, 'test_type'] = test_type 132 | stats_tests.loc[i, 'p_value'] = test[1] 133 | stats_tests.loc[i, 'p_value_variance'] = p_var 134 | 135 | # Correct for multiple tests 136 | stats_tests['p_value'] = multipletests(stats_tests['p_value'], method='fdr_bh')[1] 137 | stats_tests['p_value_variance'] = multipletests(stats_tests['p_value_variance'], 138 | method='fdr_bh')[1] 139 | 140 | if (stats.normaltest(learned['n_trials'])[1] < 0.05 or 141 | stats.normaltest(learned['reaction_time'])[1] 
< 0.05): 142 | test_type = 'spearman'
143 | correlation_coef, correlation_p = stats.spearmanr(learned['reaction_time'],
144 | learned['n_trials'])
145 | if (stats.normaltest(learned['n_trials'])[1] > 0.05 and
146 | stats.normaltest(learned['reaction_time'])[1] > 0.05):
147 | test_type = 'pearson'
148 | correlation_coef, correlation_p = stats.pearsonr(learned['reaction_time'],
149 | learned['n_trials'])
150 | 
151 | # Add all mice to dataframe separately for plotting
152 | learned_no_all = learned.copy()
153 | learned_no_all.loc[learned_no_all.shape[0] + 1, 'lab_number'] = 'All'
154 | learned_2 = learned.copy()
155 | learned_2['lab_number'] = 'All'
156 | learned_2 = learned.append(learned_2)
157 | 
158 | # %%
159 | seaborn_style()
160 | lab_colors = group_colors()
161 | sns.set_palette(lab_colors)
162 | 
163 | # %%
164 | vars = ['n_trials', 'perf_easy', 'threshold', 'bias', 'reaction_time', 'trials_per_minute']
165 | ylabels = ['Number of trials', 'Performance (%)\non easy trials',
166 | 'Contrast threshold (%)', 'Bias (%)', 'Trial duration (ms)', 'Trials / minute']
167 | ylims = [[0, 1500], [70, 100], [0, 25], [-25, 25], [0, 1100], [0, 25]]
168 | for v, ylab, ylim in zip(vars, ylabels, ylims):
169 | 
170 | f, ax = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/5, FIGURE_HEIGHT))
171 | sns.swarmplot(y=v, x='lab_number', data=learned_no_all, hue='lab_number',
172 | palette=lab_colors, ax=ax, marker='.')
173 | axbox = sns.boxplot(y=v, x='lab_number', data=learned_2, color='white',
174 | showfliers=False, ax=ax)
175 | ax.set(ylabel=ylab, ylim=ylim, xlabel='')
176 | # [tick.set_color(lab_colors[i]) for i, tick in enumerate(ax5.get_xticklabels()[:-1])]
177 | plt.setp(ax.xaxis.get_majorticklabels(), rotation=60)
178 | axbox.artists[-1].set_edgecolor('black')
179 | for j in range(5 * (len(axbox.artists) - 1), 5 * len(axbox.artists)):
180 | axbox.lines[j].set_color('black')
181 | ax.get_legend().set_visible(False)
182 | 
183 | # statistical annotation
184 | pvalue = stats_tests.loc[stats_tests['variable'] == v, 'p_value']
185 | if pvalue.to_numpy()[0] < 0.05:
186 | ax.annotate(num_star(pvalue.to_numpy()[0]),
187 | xy=[0.1, 0.8], xycoords='axes fraction', fontsize=5)
188 | 
189 | sns.despine(trim=True)
190 | plt.tight_layout()
191 | plt.savefig(join(figpath, 'figure3_metrics_%s.pdf' % v))
192 | plt.savefig(join(figpath, 'figure3_metrics_%s.png' % v), dpi=300)
193 | 
194 | # %%
195 | # Get stats for text
196 | perf_mean = learned['perf_easy'].mean()
197 | perf_std = learned['perf_easy'].std()
198 | thres_mean = learned['threshold'].mean()
199 | thres_std = learned['threshold'].std()
200 | rt_median = learned['reaction_time'].median()
201 | rt_std = learned['reaction_time'].std()
202 | trials_mean = learned['n_trials'].mean()
203 | trials_std = learned['n_trials'].std()
204 | 
-------------------------------------------------------------------------------- /figure3f_classifier_lab_membership_basic.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Decode in which lab a mouse was trained based on its behavioral metrics during the three sessions
5 | of the basic task variant in which the mouse was determined to be trained.
6 | 
7 | As a positive control, the time zone in which the mouse was trained is included in the dataset
8 | since the timezone provides geographical information. Decoding is performed using leave-one-out
9 | cross-validation. To control for the imbalance in the dataset (some labs have more mice than
10 | others) a fixed number of mice is randomly sub-sampled from each lab. This random sampling is
11 | repeated for a large number of repetitions. A shuffled null distribution is obtained by shuffling
12 | the lab labels and decoding again for each iteration.
13 | 
14 | --------------
15 | Parameters
16 | DECODER: Which decoder to use: 'bayes', 'forest', or 'regression'
17 | N_MICE: How many mice per lab to randomly sub-sample
18 | (must be no larger than the number of mice in the smallest lab)
19 | ITERATIONS: Number of times to randomly sub-sample
20 | METRICS: List of strings indicating which behavioral metrics to include
21 | during decoding of lab membership
22 | METRICS_CONTROL: List of strings indicating which metrics to use for the positive control
23 | 
24 | Guido Meijer
25 | September 3, 2020
26 | """
27 | 
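# Hedged sketch of one decoding iteration as described above; all names are
# illustrative and the actual implementation continues below (the `average`
# choice shown for the F1 score is an assumption, not a confirmed detail):
#
# sub = metrics.groupby('lab').sample(n=N_MICE)   # balance mice over labs
# X, y = sub[METRICS].values, sub['lab'].values
# y_pred = np.empty(len(y), dtype=object)
# for train, test in LeaveOneOut().split(X):      # hold out one mouse at a time
#     y_pred[test] = clf.fit(X[train], y[train]).predict(X[test])
# score = f1_score(y, y_pred, average='micro')
# null_score = ...                                # same, after shuffling y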
28 | import pandas as pd
29 | import numpy as np
30 | from os.path import join
31 | from paper_behavior_functions import (query_sessions_around_criterion, institution_map, QUERY,
32 | dj2pandas, fit_psychfunc, datapath, load_csv)
33 | from ibl_pipeline import subject, reference
34 | from ibl_pipeline import behavior
35 | from sklearn.ensemble import RandomForestClassifier
36 | from sklearn.naive_bayes import GaussianNB
37 | from sklearn.linear_model import LogisticRegression
38 | from sklearn.model_selection import LeaveOneOut
39 | from sklearn.metrics import f1_score, confusion_matrix
40 | 
41 | # Parameters
42 | DECODER = 'forest' # bayes, forest or regression
43 | N_MICE = 8 # how many mice per lab to sub-sample
44 | ITERATIONS = 2000 # how often to decode
45 | METRICS = ['perf_easy', 'threshold', 'bias']
46 | METRICS_CONTROL = ['perf_easy', 'threshold', 'bias', 'time_zone']
47 | 
48 | 
49 | # Decoding function with leave-one-out cross-validation
50 | def decoding(data, labels, clf):
51 | kf = LeaveOneOut()
52 | y_pred = np.empty(len(labels), dtype='<U5')
-------------------------------------------------------------------------------- /figure3f_plot_classifier_basic.py: --------------------------------------------------------------------------------
33 | if decoding_result['original'].mean() > significance:
34 | print('\n%s classifier performed above chance' % DECODER)
35 | print('Chance level: %.2f (F1 score)' % chance_level)
36 | else:
37 | print('\n%s classifier did not perform above chance' % DECODER)
38 | print('Chance level: %.2f (F1 score)' % chance_level)
39 | print('F1 score: %.2f ± %.3f' % (decoding_result['original'].mean(),
40 | decoding_result['original'].std()))
41 | 
42 | # %%
43 | 
44 | # Plot main Figure 3
45 | if DECODER == 'bayes':
46 | f, ax1 = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/5, FIGURE_HEIGHT))
47 | sns.violinplot(data=pd.concat([decoding_result['control'],
48 | decoding_result['original_shuffled'],
49 | decoding_result['original']], axis=1),
50 | palette=colors, ax=ax1)
51 | ax1.plot([-1, 3.5], [chance_level, chance_level], '--', color='k', zorder=-10)
52 | ax1.set(ylabel='Decoding accuracy', xlim=[-0.6, 2.6], ylim=[-0.1, 0.62])
53 | ax1.set_xticklabels(['Positive\ncontrol', 'Shuffle', 'Mouse\nbehavior'],
54 | rotation=90, ha='center')
55 | plt.tight_layout()
56 | sns.despine(trim=True)
57 | 
58 | plt.savefig(join(FIG_PATH, 'figure3f_decoding.pdf'))
59 | plt.savefig(join(FIG_PATH, 'figure3f_decoding.png'), dpi=300)
60 | plt.close(f)
61 | 
62 | # Plot supplementary Figure 3
63 | f, ax1 = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/5, FIGURE_HEIGHT))
64 | sns.violinplot(data=pd.concat([decoding_result['control'],
65 | decoding_result['original_shuffled'],
66 | decoding_result['original']], axis=1),
67 | palette=colors, ax=ax1)
68 | ax1.plot([-1, 3.5], [chance_level, chance_level], '--', color='k', zorder=-10)
69 | ax1.set(ylabel='Decoding 
accuracy', xlim=[-0.8, 2.6], ylim=[-0.1, 0.62]) 70 | ax1.set_xticklabels(['Positive\ncontrol', 'Shuffle', 'Mouse\nbehavior'], 71 | rotation=90, ha='center') 72 | plt.tight_layout() 73 | sns.despine(trim=True) 74 | 75 | plt.savefig(join(FIG_PATH, 'suppfig3_decoding_%s.pdf' % DECODER)) 76 | plt.savefig(join(FIG_PATH, 'suppfig3_decoding_%s.png' % DECODER), dpi=300) 77 | plt.close(f) 78 | 79 | # %% 80 | f, ax1 = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/4, FIGURE_HEIGHT)) 81 | n_labs = decoding_result['confusion_matrix'][0].shape[0] 82 | # sns.heatmap(data=decoding_result['confusion_matrix'].mean(), vmin=0, vmax=0.6) 83 | sns.heatmap(data=decoding_result['confusion_matrix'].mean(), 84 | vmin=0, vmax=0.4) 85 | ax1.plot([0, 7], [0, 7], '--w') 86 | ax1.set(xticklabels=np.arange(1, n_labs + 1), yticklabels=np.arange(1, n_labs + 1), 87 | ylim=[0, n_labs], xlim=[0, n_labs], 88 | title='', ylabel=' ', xlabel='Predicted lab') 89 | if DECODER == 'bayes': 90 | ax1.set(ylabel='Actual lab') 91 | plt.setp(ax1.xaxis.get_majorticklabels(), rotation=40) 92 | plt.setp(ax1.yaxis.get_majorticklabels(), rotation=40) 93 | plt.gca().invert_yaxis() 94 | plt.tight_layout() 95 | 96 | plt.savefig(join(FIG_PATH, 'suppfig3_confusion_matrix_%s.pdf' % DECODER)) 97 | plt.savefig(join(FIG_PATH, 'suppfig3_confusion_matrix_%s.png' % DECODER), dpi=300) 98 | plt.close(f) 99 | 100 | f, ax1 = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/4, FIGURE_HEIGHT)) 101 | # sns.heatmap(data=decoding_result['control_cm'].mean(), vmin=0, vmax=1) 102 | sns.heatmap(data=decoding_result['control_cm'].mean(), 103 | vmin=0, vmax=0.4) 104 | ax1.plot([0, 7], [0, 7], '--w') 105 | ax1.set(xticklabels=np.arange(1, n_labs + 1), yticklabels=np.arange(1, n_labs + 1), 106 | title='', ylabel=' ', xlabel='Predicted lab', 107 | ylim=[0, n_labs], xlim=[0, n_labs]) 108 | 109 | if DECODER == 'bayes': 110 | ax1.set(ylabel='Actual lab') 111 | 112 | plt.setp(ax1.xaxis.get_majorticklabels(), rotation=40) 113 | plt.setp(ax1.yaxis.get_majorticklabels(), rotation=40) 114 | plt.gca().invert_yaxis() 115 | plt.tight_layout() 116 | plt.savefig(join(FIG_PATH, 'suppfig3_control_confusion_matrix_%s.pdf' % DECODER)) 117 | plt.savefig(join(FIG_PATH, 'suppfig3_control_confusion_matrix_%s.png' % DECODER), dpi=300) 118 | plt.close(f) 119 | -------------------------------------------------------------------------------- /figure4a_block_probabilities.py: -------------------------------------------------------------------------------- 1 | """ 2 | Block structure in the biased task for an example session 3 | 4 | @author: Anne Urai 5 | 15 January 2020 6 | """ 7 | 8 | from ibl_pipeline import subject, acquisition, behavior 9 | import pandas as pd 10 | import os 11 | import seaborn as sns 12 | import numpy as np 13 | from paper_behavior_functions import (seaborn_style, figpath, EXAMPLE_MOUSE, 14 | FIGURE_HEIGHT, FIGURE_WIDTH, dj2pandas) 15 | import matplotlib.pyplot as plt 16 | import matplotlib.patches as patches 17 | 18 | # INITIALIZE A FEW THINGS 19 | seaborn_style() 20 | figpath = figpath() 21 | cmap = sns.diverging_palette(20, 220, n=3, center="dark") 22 | cmap_dic = {20: cmap[0], 50: cmap[1], 80: cmap[2]} 23 | 24 | # ================================= # 25 | # SCHEMATIC OF THE BLOCKS 26 | # ================================= # 27 | 28 | behav = pd.DataFrame({'probability_left': [50, 50, 20, 20, 80, 80], 29 | 'stimulus_side': [-1, 1, -1, 1, -1, 1], 30 | 'prob': [50, 50, 20, 80, 80, 20]}) 31 | 32 | fig = sns.FacetGrid(behav, 33 | col="probability_left", hue="probability_left", col_wrap=3, 34 
| col_order=[50, 20, 80], palette=cmap, sharex=True, sharey=True, 35 | aspect=0.6, height=2.2) 36 | # fig.map(sns.distplot, "stimulus_side", kde=False, norm_hist=True, bins=2, hist_kws={'rwidth':1}) 37 | fig.map(sns.barplot, "stimulus_side", "prob") 38 | fig.set(xticks=[-0, 1], xlim=[-0.5, 1.5], 39 | ylim=[0, 100], yticks=[0, 50, 100], yticklabels=[]) 40 | for ax, title in zip(fig.axes.flat, ['50/50', '80/20', '20/80']): 41 | ax.set_title(title) 42 | ax.set_xticklabels(['Left', 'Right'], rotation=45) 43 | fig.set_axis_labels('', 'Probability (%)') 44 | fig.savefig(os.path.join( 45 | figpath, "figure4_panel_block_distribution.png"), dpi=600) 46 | fig.savefig(os.path.join(figpath, "figure4_panel_block_distribution.pdf")) 47 | plt.close('all') 48 | 49 | # ================================= # 50 | # EXAMPLE SESSION TIMECOURSE 51 | # ================================= # 52 | 53 | b = ((subject.Subject & 'subject_nickname="%s"' % EXAMPLE_MOUSE) 54 | * behavior.TrialSet.Trial 55 | * (acquisition.Session & 'task_protocol LIKE "%biased%"' 56 | & 'session_start_time BETWEEN "2019-08-30" and "2019-08-31"')) 57 | bdat = b.fetch(order_by='session_start_time, trial_id', 58 | format='frame').reset_index() 59 | behav = dj2pandas(bdat) 60 | assert not behav.empty 61 | 62 | # if 100 in df.signed_contrast.values and not 50 in df.signed_contrast.values: 63 | behav['signed_contrast'] = behav['signed_contrast'].replace(-100, -35) 64 | behav['signed_contrast'] = behav['signed_contrast'].replace(100, 35) 65 | 66 | # %% 67 | for dayidx, behavtmp in behav.groupby(['session_start_time']): 68 | 69 | # 1. patches to show the blocks 70 | fig, axes = plt.subplots(ncols=1, nrows=1, figsize=(FIGURE_WIDTH/3.2, FIGURE_HEIGHT*0.9)) 71 | xmax = min([behavtmp.trial_id.max() + 5, 500]) 72 | 73 | # Loop over data points; create box from errors at each point 74 | behavtmp['blocks'] = (behavtmp["probabilityLeft"].ne( 75 | behavtmp["probabilityLeft"].shift()).cumsum()) 76 | 77 | for idx, blocknum in behavtmp.groupby('blocks'): 78 | left = blocknum.trial_id.min() 79 | width = blocknum.trial_id.max() - blocknum.trial_id.min() 80 | axes.add_patch(patches.Rectangle((left, 0), width, 100, 81 | fc=cmap_dic[blocknum.probabilityLeft.unique()[ 82 | 0]], 83 | ec='none', alpha=0.2)) 84 | 85 | # 86 | # 2. actual block probabilities as grey line 87 | behavtmp['stim_sign'] = 100 * \ 88 | ((np.sign(behavtmp.signed_contrast) / 2) + 0.5) 89 | # sns.scatterplot(x='trial_id', y='stim_sign', data=behav, color='grey', 90 | # marker='o', ax=axes, legend=False, alpha=0.5, 91 | # ec='none', linewidth=0, zorder=2) 92 | sns.lineplot(x='trial_id', y='stim_sign', color='black', ci=None, 93 | data=behavtmp[['trial_id', 'stim_sign']].rolling(10).mean(), ax=axes) 94 | # %% 95 | 96 | # 3. 
ANIMAL CHOICES, rolling window 97 | #rightax = axes.twinx() 98 | behavtmp['choice_right'] = behavtmp.choice_right * 100 99 | sns.lineplot(x='trial_id', y='choice_right', color='firebrick', ci=None, 100 | data=behavtmp[['trial_id', 'choice_right']].rolling(10).mean(), ax=axes, 101 | linestyle=':') 102 | # rightax.set(xlim=[-5, xmax], xlabel='Trial number', 103 | # ylabel='Rightwards choices (%)', ylim=[-1, 101]) 104 | # rightax.yaxis.label.set_color("firebrick") 105 | # rightax.tick_params(axis='y', colors='firebrick') 106 | # axes.set_yticks([0, 50, 100]) 107 | # rightax.set_yticks([0, 50, 100]) 108 | # axes.set_title(' \n ') 109 | 110 | axes.set(xlim=[-5, xmax], xlabel='Trial number', 111 | ylabel=' ', ylim=[-1, 101]) 112 | axes.yaxis.label.set_color("black") 113 | axes.tick_params(axis='y', colors='black') 114 | plt.tight_layout() 115 | fig.savefig(os.path.join( 116 | figpath, "figure4_panel_session_course_%s.png" % dayidx.date()), dpi=600) 117 | fig.savefig(os.path.join( 118 | figpath, "figure4_panel_session_course_%s.pdf" % dayidx.date())) 119 | plt.close('all') 120 | -------------------------------------------------------------------------------- /figure4de_psychfuncs_biased.py: -------------------------------------------------------------------------------- 1 | """ 2 | Psychometric function and choice shifts in the biased task 3 | 4 | @author: Anne Urai 5 | 15 January 2020 6 | """ 7 | 8 | import pandas as pd 9 | import numpy as np 10 | import os 11 | from os.path import join 12 | import matplotlib.pyplot as plt 13 | import seaborn as sns 14 | import statsmodels.api as sm 15 | from statsmodels.formula.api import ols 16 | from paper_behavior_functions import (seaborn_style, figpath, group_colors, institution_map, 17 | query_sessions_around_criterion, EXAMPLE_MOUSE, 18 | FIGURE_HEIGHT, FIGURE_WIDTH, QUERY, load_csv, 19 | dj2pandas, plot_psychometric, fit_psychfunc, plot_chronometric, 20 | break_xaxis) 21 | # import wrappers etc 22 | from ibl_pipeline import reference, subject, behavior 23 | from ibl_pipeline.utils import psychofit as psy 24 | 25 | # Initialize 26 | seaborn_style() 27 | figpath = figpath() 28 | pal = group_colors() 29 | institution_map, col_names = institution_map() 30 | col_names = col_names[:-1] 31 | 32 | # colors for biased blocks 33 | cmap = sns.color_palette([[0.8984375, 0.37890625, 0.00390625], 34 | [0.3, 0.3, 0.3], [0.3671875, 0.234375, 0.59765625]]) 35 | sns.set_palette(cmap) 36 | 37 | # ================================= # 38 | # GET DATA FROM TRAINED ANIMALS 39 | # ================================= # 40 | 41 | if QUERY is True: 42 | # query sessions 43 | use_sessions, _ = query_sessions_around_criterion(criterion='ephys', 44 | days_from_criterion=[2, 0], 45 | force_cutoff=True) 46 | use_sessions = use_sessions & 'task_protocol LIKE "%biased%"' # only get biased sessions 47 | 48 | # restrict by list of dicts with uuids for these sessions 49 | b = (use_sessions * subject.Subject * subject.SubjectLab * reference.Lab 50 | * behavior.TrialSet.Trial) 51 | 52 | # reduce the size of the fetch 53 | b2 = b.proj('institution_short', 'subject_nickname', 'task_protocol', 'session_uuid', 54 | 'trial_stim_contrast_left', 'trial_stim_contrast_right', 'trial_response_choice', 55 | 'task_protocol', 'trial_stim_prob_left', 'trial_feedback_type', 56 | 'trial_response_time', 'trial_stim_on_time') 57 | 58 | # construct pandas dataframe 59 | bdat = b2.fetch(order_by='institution_short, subject_nickname, session_start_time, trial_id', 60 | format='frame').reset_index() 61 | behav = 
dj2pandas(bdat) 62 | behav['institution_code'] = behav.institution_short.map(institution_map) 63 | else: 64 | behav = load_csv('Fig4.csv') 65 | 66 | # how many mice are there for each lab? 67 | N = behav.groupby(['institution_code'])['subject_nickname'].nunique().to_dict() 68 | behav['n_mice'] = behav.institution_code.map(N) 69 | behav['institution_name'] = behav.institution_code + '\n' + behav.n_mice.apply(str) + ' mice' 70 | 71 | # ================================= # 72 | # PSYCHOMETRIC FUNCTIONS 73 | # FOR OUR EXAMPLE ANIMAL 74 | # ================================= # 75 | 76 | fig = sns.FacetGrid(behav[behav['subject_nickname'] == EXAMPLE_MOUSE], 77 | hue="probabilityLeft", palette=cmap, 78 | sharex=True, sharey=True, 79 | height=FIGURE_HEIGHT, aspect=(FIGURE_WIDTH/4)/FIGURE_HEIGHT) 80 | fig.map(plot_psychometric, "signed_contrast", "choice_right", "session_uuid") 81 | fig.set_axis_labels('\u0394 Contrast (%)', 'Rightward choices (%)') 82 | fig.ax.annotate('20:80', xy=(-5, 0.6), xytext=(-25, 0.8), color=cmap[0], fontsize=7) 83 | fig.ax.annotate('80:20', xy=(5, 0.4), xytext=(13, 0.18), color=cmap[2], fontsize=7) 84 | fig.despine(trim=True) 85 | fig.axes[0][0].set_title('Example mouse', fontweight='bold', color='k') 86 | fig.savefig(os.path.join(figpath, "figure4b_psychfuncs_biased_example.pdf")) 87 | fig.savefig(os.path.join( 88 | figpath, "figure4b_psychfuncs_biased_example.png"), dpi=600) 89 | plt.close('all') 90 | 91 | # ================================= # 92 | # PSYCHOMETRIC FUNCTIONS 93 | # one for all labs combined 94 | # ================================= # 95 | 96 | fig = sns.FacetGrid(behav, 97 | hue="probabilityLeft", palette=cmap, 98 | sharex=True, sharey=True, 99 | height=FIGURE_HEIGHT, aspect=(FIGURE_WIDTH/4)/FIGURE_HEIGHT) 100 | fig.map(plot_psychometric, "signed_contrast", 101 | "choice_right", "subject_nickname") 102 | fig.set_axis_labels('\u0394 Contrast (%)', '') 103 | fig.ax.annotate('20:80', xy=(-5, 0.6), xytext=(-25, 0.8), color=cmap[0], fontsize=7) 104 | fig.ax.annotate('80:20', xy=(5, 0.4), xytext=(13, 0.18), color=cmap[2], fontsize=7) 105 | fig.despine(trim=True) 106 | fig.axes[0][0].set_title('All mice: n = %d' % behav.subject_nickname.nunique(), 107 | fontweight='bold', color='k') 108 | fig.axes[0][0].set(yticklabels=[]) 109 | fig.savefig(os.path.join(figpath, "figure4b_psychfuncs_biased.pdf")) 110 | fig.savefig(os.path.join( 111 | figpath, "figure4b_psychfuncs_biased.png"), dpi=600) 112 | plt.close('all') 113 | 114 | # ================================================================== # 115 | # DIFFERENCE BETWEEN TWO PSYCHOMETRIC FUNCTIONS 116 | # FOR EACH ANIMAL + for each lab (in 'lab color') 117 | # ================================================================== # 118 | 119 | print('fitting psychometric functions...') 120 | pars = behav.groupby(['institution_code', 'subject_nickname', 121 | 'probabilityLeft']).apply(fit_psychfunc).reset_index() 122 | # now read these out at the presented levels of signed contrast 123 | behav2 = pd.DataFrame([]) 124 | xvec = behav.signed_contrast.unique() 125 | for index, group in pars.groupby(['institution_code', 'subject_nickname', 126 | 'probabilityLeft']): 127 | # expand 128 | yvec = psy.erf_psycho_2gammas([group.bias.item(), 129 | group.threshold.item(), 130 | group.lapselow.item(), 131 | group.lapsehigh.item()], xvec) 132 | group2 = group.loc[group.index.repeat( 133 | len(yvec))].reset_index(drop=True).copy() 134 | group2['signed_contrast'] = xvec 135 | group2['choice'] = 100 * yvec 136 | 137 | # add this 138 | behav2 
= behav2.append(group2)
139 | 
140 | # now subtract these to compute a bias shift
141 | behav3 = pd.pivot_table(behav2, values='choice',
142 | index=['institution_code', 'subject_nickname',
143 | 'signed_contrast'],
144 | columns=['probabilityLeft']).reset_index()
145 | behav3['biasshift'] = behav3[20] - behav3[80]
146 | 
147 | # %% PLOT
148 | 
149 | # plot one curve for each animal, one panel per lab
150 | plt.close('all')
151 | fig = sns.FacetGrid(behav3,
152 | col="institution_code", col_wrap=7, col_order=col_names,
153 | sharex=True, sharey=True, hue="subject_nickname",
154 | height=FIGURE_HEIGHT, aspect=(FIGURE_WIDTH/7)/FIGURE_HEIGHT)
155 | fig.map(plot_chronometric, "signed_contrast", "biasshift",
156 | "subject_nickname", color='gray', alpha=0.7)
157 | 
158 | # overlay the example mouse
159 | tmpdat = behav3[behav3['subject_nickname'].str.contains(EXAMPLE_MOUSE)]
160 | plot_chronometric(tmpdat.signed_contrast, tmpdat.biasshift, tmpdat.subject_nickname,
161 | color='black', ax=fig.axes[0], legend=False)
162 | fig.set_titles("{col_name}")
163 | 
164 | fig.despine(trim=True)
165 | ymin = fig.axes[0].get_ylim()[0]-0.2
166 | fig.map(break_xaxis, y=ymin)
167 | 
168 | # add lab means on top
169 | for axidx, ax in enumerate(fig.axes.flat):
170 | tmp_behav = behav3.loc[behav3.institution_code == behav3.institution_code.unique()[axidx], :]
171 | plot_chronometric(tmp_behav.signed_contrast, tmp_behav.biasshift,
172 | tmp_behav.institution_code, ax=ax, legend=False,
173 | color=pal[axidx], linewidth=2)
174 | ax.set_title(sorted(behav.institution_name.unique())[axidx],
175 | color=pal[axidx])
176 | 
177 | fig.set_axis_labels('\u0394 Contrast (%)', '\u0394 Rightward choices (%)')
178 | plt.tight_layout(w_pad=0)
179 | fig.savefig(os.path.join(figpath, "figure4e_biasshift.pdf"))
180 | fig.savefig(os.path.join(figpath, "figure4e_biasshift.png"), dpi=300)
181 | plt.close('all')
182 | 
183 | 
184 | # %% PLOT
185 | 
186 | fig, ax1 = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/4, FIGURE_HEIGHT))
187 | for i, inst in enumerate(behav.institution_code.unique()):
188 | tmp_behav = behav3[behav3['institution_code'].str.contains(inst)]
189 | plot_chronometric(tmp_behav.signed_contrast, tmp_behav.biasshift,
190 | tmp_behav.subject_nickname, ax=ax1, legend=False, color=pal[i])
191 | # ax1.set_title('All labs', color='k', fontweight='bold')
192 | ax1.set(xlabel='\u0394 Contrast (%)', ylabel='\u0394 Rightward choices (%)',
193 | yticks=[0, 10, 20, 30, 40])
194 | sns.despine(trim=True)
195 | plt.tight_layout()
196 | ymin = ax1.get_ylim()[0]-0.15
197 | break_xaxis(y=ymin)
198 | fig.savefig(os.path.join(figpath, "figure4d_biasshift_all_labs.pdf"))
199 | fig.savefig(os.path.join(figpath, "figure4d_biasshift_all_labs.png"), dpi=300)
200 | 
201 | # ================================================================== #
202 | # Plot behavioral metrics per lab
203 | # ================================================================== #
204 | 
205 | bias = behav3.loc[behav3.signed_contrast == 0, :]
206 | 
207 | # stats on bias shift between laboratories:
208 | sm_lm = ols('biasshift ~ C(institution_code)', data=bias).fit()
209 | table = sm.stats.anova_lm(sm_lm, typ=2) # Type 2 ANOVA DataFrame
210 | print(table)
211 | 
212 | # Add all mice to dataframe separately for plotting
213 | bias_all = bias.copy()
214 | 
215 | print('average bias shift across all mice: ')
216 | print(bias_all['biasshift'].mean())
217 | bias_all['institution_code'] = 'All'
218 | 
219 | bias_all = bias.append(bias_all)
220 | 
221 | # Set color palette
222 | use_palette = 
[[0.6, 0.6, 0.6]] * len(np.unique(bias['institution_code'])) 223 | use_palette = use_palette + [[1, 1, 0.2]] 224 | sns.set_palette(use_palette) 225 | 226 | # plot 227 | f, ax1 = plt.subplots(1, 1, figsize=(3, 3.5)) 228 | sns.set_palette(use_palette) 229 | 230 | sns.boxplot(y='biasshift', x='institution_code', data=bias_all, ax=ax1) 231 | ax1.set(ylabel='\u0394 Rightward choices (%)\n at 0% contrast', 232 | ylim=[0, 51], xlabel='') 233 | [tick.set_color(pal[i]) for i, tick in enumerate(ax1.get_xticklabels()[:-1])] 234 | plt.setp(ax1.xaxis.get_majorticklabels(), rotation=40) 235 | plt.tight_layout(pad=2) 236 | seaborn_style() 237 | 238 | # plt.savefig(os.path.join(figpath, 'figure4e_bias_per_lab.pdf')) 239 | # plt.savefig(os.path.join(figpath, 'figure4e_bias_per_lab.png'), dpi=300) 240 | -------------------------------------------------------------------------------- /figure4fghi_variability_over_labs_full.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plotting of behavioral metrics during the full task (biased blocks) per lab 3 | 4 | Guido Meijer 5 | 6 May 2020 6 | """ 7 | 8 | import seaborn as sns 9 | import numpy as np 10 | from os.path import join 11 | import matplotlib.pyplot as plt 12 | from scipy import stats 13 | import scikit_posthocs as sp 14 | from paper_behavior_functions import (figpath, seaborn_style, group_colors, institution_map, 15 | FIGURE_WIDTH, FIGURE_HEIGHT, QUERY, 16 | fit_psychfunc, dj2pandas, load_csv) 17 | import pandas as pd 18 | from statsmodels.stats.multitest import multipletests 19 | 20 | # Initialize 21 | seaborn_style() 22 | figpath = figpath() 23 | pal = group_colors() 24 | institution_map, col_names = institution_map() 25 | col_names = col_names[:-1] 26 | 27 | # %% Process data 28 | 29 | if QUERY is True: 30 | # query sessions 31 | from paper_behavior_functions import query_sessions_around_criterion 32 | from ibl_pipeline import reference, subject, behavior 33 | use_sessions, _ = query_sessions_around_criterion(criterion='ephys', 34 | days_from_criterion=[2, 0], 35 | force_cutoff=True) 36 | session_keys = (use_sessions & 'task_protocol LIKE "%biased%"').fetch('KEY') 37 | ses = ((use_sessions & 'task_protocol LIKE "%biased%"') 38 | * subject.Subject * subject.SubjectLab * reference.Lab 39 | * (behavior.TrialSet.Trial & session_keys)) 40 | ses = ses.proj('institution_short', 'subject_nickname', 'task_protocol', 'session_uuid', 41 | 'trial_stim_contrast_left', 'trial_stim_contrast_right', 42 | 'trial_response_choice', 'task_protocol', 'trial_stim_prob_left', 43 | 'trial_feedback_type', 'trial_response_time', 'trial_stim_on_time', 44 | 'session_end_time').fetch( 45 | order_by='institution_short, subject_nickname,session_start_time, trial_id', 46 | format='frame').reset_index() 47 | behav = dj2pandas(ses) 48 | behav['institution_code'] = behav.institution_short.map(institution_map) 49 | else: 50 | behav = load_csv('Fig4.csv') 51 | 52 | biased_fits = pd.DataFrame() 53 | for i, nickname in enumerate(behav['subject_nickname'].unique()): 54 | if np.mod(i+1, 10) == 0: 55 | print('Processing data of subject %d of %d' % (i+1, 56 | len(behav['subject_nickname'].unique()))) 57 | 58 | # Get lab 59 | lab = behav.loc[behav['subject_nickname'] == nickname, 'institution_code'].unique()[0] 60 | 61 | # Fit psychometric curve 62 | left_fit = fit_psychfunc(behav[(behav['subject_nickname'] == nickname) 63 | & (behav['probabilityLeft'] == 80)]) 64 | right_fit = fit_psychfunc(behav[(behav['subject_nickname'] == nickname) 65 | & 
(behav['probabilityLeft'] == 20)])
66 | fits = pd.DataFrame(data={'threshold_l': left_fit['threshold'],
67 | 'threshold_r': right_fit['threshold'],
68 | 'bias_l': left_fit['bias'],
69 | 'bias_r': right_fit['bias'],
70 | 'lapselow_l': left_fit['lapselow'],
71 | 'lapselow_r': right_fit['lapselow'],
72 | 'lapsehigh_l': left_fit['lapsehigh'],
73 | 'lapsehigh_r': right_fit['lapsehigh'],
74 | 'nickname': nickname, 'lab': lab})
75 | biased_fits = biased_fits.append(fits, sort=False)
76 | 
77 | # %% Statistics
78 | stats_tests = pd.DataFrame(columns=['variable', 'test_type', 'p_value'])
79 | posthoc_tests = {}
80 | 
81 | for i, var in enumerate(['threshold_l', 'threshold_r', 'lapselow_l', 'lapselow_r', 'lapsehigh_l',
82 | 'lapsehigh_r', 'bias_l', 'bias_r']):
83 | _, normal = stats.normaltest(biased_fits[var])
84 | 
85 | if normal < 0.05:
86 | test_type = 'kruskal'
87 | test = stats.kruskal(*[group[var].values
88 | for name, group in biased_fits.groupby('lab')])
89 | if test[1] < 0.05: # Proceed to posthocs
90 | posthoc = sp.posthoc_dunn(biased_fits, val_col=var, group_col='lab')
91 | else:
92 | posthoc = np.nan
93 | else:
94 | test_type = 'anova'
95 | test = stats.f_oneway(*[group[var].values
96 | for name, group in biased_fits.groupby('lab')])
97 | if test[1] < 0.05:
98 | posthoc = sp.posthoc_tukey(biased_fits, val_col=var, group_col='lab')
99 | else:
100 | posthoc = np.nan
101 | 
102 | posthoc_tests['posthoc_'+str(var)] = posthoc
103 | stats_tests.loc[i, 'variable'] = var
104 | stats_tests.loc[i, 'test_type'] = test_type
105 | stats_tests.loc[i, 'p_value'] = test[1]
106 | 
107 | # Correct for multiple tests
108 | stats_tests['p_value'] = multipletests(stats_tests['p_value'], method='fdr_bh')[1]
109 | 
110 | # Test between left/right blocks
111 | for i, var in enumerate(['threshold', 'lapselow', 'lapsehigh', 'bias']):
112 | stats_tests.loc[stats_tests.shape[0] + 1, 'variable'] = '%s_blocks' % var
113 | stats_tests.loc[stats_tests.shape[0], 'test_type'] = 'wilcoxon'
114 | _, stats_tests.loc[stats_tests.shape[0], 'p_value'] = stats.wilcoxon(
115 | biased_fits['%s_l' % var], biased_fits['%s_r' % var])
116 | print(stats_tests) # Print the results
117 | 
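# Hedged aside on the parameters compared below: fit_psychfunc fits a
# four-parameter erf psychometric (via ibl_pipeline.utils.psychofit), roughly
#
#   P(right | contrast x) = lapselow
#       + (1 - lapselow - lapsehigh) * (erf((x - bias) / threshold) + 1) / 2
#
# so 'bias' shifts the curve laterally, 'threshold' scales its slope, and the
# two lapse terms set the asymptotes. The exact library parameterisation may
# differ; this is only a reading aid for the panels that follow.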
118 | # %% Plot metrics
119 | f, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(FIGURE_WIDTH*0.8, FIGURE_HEIGHT))
120 | lab_colors = group_colors()
121 | 
122 | ax1.plot([10, 20], [10, 20], linestyle='dashed', color=[0.6, 0.6, 0.6])
123 | for i, lab in enumerate(biased_fits['lab'].unique()):
124 | ax1.errorbar(biased_fits.loc[biased_fits['lab'] == lab, 'threshold_l'].mean(),
125 | biased_fits.loc[biased_fits['lab'] == lab, 'threshold_r'].mean(),
126 | xerr=biased_fits.loc[biased_fits['lab'] == lab, 'threshold_l'].sem(),
127 | yerr=biased_fits.loc[biased_fits['lab'] == lab, 'threshold_r'].sem(),
128 | fmt='.', color=lab_colors[i])
129 | ax1.set(xlabel='80:20 block', ylabel='20:80 block', title='Threshold',
130 | yticks=ax1.get_xticks(), ylim=ax1.get_xlim())
131 | 
132 | ax2.plot([0, 0.1], [0, 0.1], linestyle='dashed', color=[0.6, 0.6, 0.6])
133 | for i, lab in enumerate(biased_fits['lab'].unique()):
134 | ax2.errorbar(biased_fits.loc[biased_fits['lab'] == lab, 'lapselow_l'].mean(),
135 | biased_fits.loc[biased_fits['lab'] == lab, 'lapselow_r'].mean(),
136 | xerr=biased_fits.loc[biased_fits['lab'] == lab, 'lapselow_l'].sem(),
137 | yerr=biased_fits.loc[biased_fits['lab'] == lab, 'lapselow_r'].sem(),
138 | fmt='.', color=lab_colors[i])
139 | ax2.set(xlabel='80:20 block', ylabel='', title='Lapse left',
140 | yticks=ax2.get_xticks(), ylim=ax2.get_xlim())
141 | 
142 | ax3.plot([0, 0.1], [0, 0.1], linestyle='dashed', color=[0.6, 0.6, 0.6])
143 | for i, lab in enumerate(biased_fits['lab'].unique()):
144 | ax3.errorbar(biased_fits.loc[biased_fits['lab'] == lab, 'lapsehigh_l'].mean(),
145 | biased_fits.loc[biased_fits['lab'] == lab, 'lapsehigh_r'].mean(),
146 | xerr=biased_fits.loc[biased_fits['lab'] == lab, 'lapsehigh_l'].sem(),
147 | yerr=biased_fits.loc[biased_fits['lab'] == lab, 'lapsehigh_r'].sem(),
148 | fmt='.', color=lab_colors[i])
149 | ax3.set(xlabel='80:20 block', ylabel='', title='Lapse right',
150 | yticks=ax3.get_xticks(), ylim=ax3.get_xlim())
151 | 
152 | ax4.plot([-10, 10], [-10, 10], linestyle='dashed', color=[0.6, 0.6, 0.6])
153 | for i, lab in enumerate(biased_fits['lab'].unique()):
154 | ax4.errorbar(biased_fits.loc[biased_fits['lab'] == lab, 'bias_l'].mean(),
155 | biased_fits.loc[biased_fits['lab'] == lab, 'bias_r'].mean(),
156 | xerr=biased_fits.loc[biased_fits['lab'] == lab, 'bias_l'].sem(),
157 | yerr=biased_fits.loc[biased_fits['lab'] == lab, 'bias_r'].sem(),
158 | fmt='.', color=lab_colors[i])
159 | ax4.set(xlabel='80:20 block', ylabel='', title='Bias',
160 | yticks=ax4.get_xticks(), ylim=ax4.get_xlim())
161 | 
162 | plt.tight_layout(w_pad=-0.1)
163 | sns.despine(trim=True)
164 | plt.savefig(join(figpath, 'figure4f-i_metrics_per_lab_full.pdf'))
165 | plt.savefig(join(figpath, 'figure4f-i_metrics_per_lab_full.png'), dpi=300)
166 | 
-------------------------------------------------------------------------------- /figure4i_classifier_lab_membership_full.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Decode in which lab a mouse was trained based on its behavioral metrics during the three sessions
5 | of the full task variant in which the mouse was determined to be ready for ephys.
6 | 
7 | As a positive control, the time zone in which the mouse was trained is included in the dataset
8 | since the timezone provides geographical information. Decoding is performed using leave-one-out
9 | cross-validation. To control for the imbalance in the dataset (some labs have more mice than
10 | others) a fixed number of mice is randomly sub-sampled from each lab. This random sampling is
11 | repeated for a large number of repetitions. A shuffled null distribution is obtained by shuffling
12 | the lab labels and decoding again for each iteration.
13 | 
14 | --------------
15 | Parameters
16 | DECODER: Which decoder to use: 'bayes', 'forest', or 'regression'
17 | N_MICE: How many mice per lab to randomly sub-sample
18 | (must be no larger than the number of mice in the smallest lab)
19 | ITERATIONS: Number of times to randomly sub-sample
20 | METRICS: List of strings indicating which behavioral metrics to include
21 | during decoding of lab membership
22 | METRICS_CONTROL: List of strings indicating which metrics to use for the positive control
23 | 
24 | Guido Meijer
25 | September 3, 2020
26 | """
27 | 
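# Hedged sketch of how the shuffled null distribution is used downstream to
# call significance (the percentile rule here is an assumption, not
# necessarily this repository's exact choice):
#
# chance_level = decoding_result['original_shuffled'].mean()
# significance = np.percentile(decoding_result['original_shuffled'], 95)
# above_chance = decoding_result['original'].mean() > significance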
28 | import numpy as np
29 | from os.path import join
30 | from paper_behavior_functions import \
31 | institution_map, QUERY, fit_psychfunc, dj2pandas, load_csv, datapath
32 | import pandas as pd
33 | from sklearn.ensemble import RandomForestClassifier
34 | from sklearn.naive_bayes import GaussianNB
35 | from sklearn.linear_model import LogisticRegression
36 | from sklearn.model_selection import LeaveOneOut
37 | from sklearn.metrics import f1_score, confusion_matrix
38 | 
39 | # Settings
40 | DECODER = 'bayes' # bayes, forest or regression
41 | N_MICE = 8 # how many mice per lab to randomly sub-sample
42 | ITERATIONS = 2000 # how often to decode with random sub-samples
43 | METRICS = ['threshold_l', 'threshold_r', 'bias_l', 'bias_r', 'lapselow_l', 'lapselow_r',
44 | 'lapsehigh_l', 'lapsehigh_r']
45 | METRICS_CONTROL = ['threshold_l', 'threshold_r', 'bias_l', 'bias_r', 'lapselow_l', 'lapselow_r',
46 | 'lapsehigh_l', 'lapsehigh_r', 'time_zone']
47 | 
48 | 
49 | # Decoding function with leave-one-out cross-validation
50 | def decoding(data, labels, clf):
51 | kf = LeaveOneOut()
52 | y_pred = np.empty(len(labels), dtype='<U5')
-------------------------------------------------------------------------------- /figure4i_plot_classifier_full.py: --------------------------------------------------------------------------------
31 | if decoding_result['original'].mean() > significance:
32 | print('Above chance classification performance!')
33 | else:
34 | print('Classification performance not significantly above chance')
35 | 
36 | # %%
37 | 
38 | f, ax1 = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/5, FIGURE_HEIGHT))
39 | sns.violinplot(data=pd.concat([decoding_result['control'],
40 | decoding_result['original_shuffled'],
41 | decoding_result['original']], axis=1),
42 | palette=colors, ax=ax1)
43 | ax1.plot([-1, 3.5], [chance_level, chance_level], '--', color='k', zorder=-10)
44 | ax1.set(ylabel='Decoding accuracy', xlim=[-0.8, 2.4], ylim=[-0.1, 0.62])
45 | ax1.set_xticklabels(['Positive\ncontrol', 'Shuffle', 'Mouse\nbehavior'],
46 | rotation=90, ha='center')
47 | plt.tight_layout()
48 | sns.despine(trim=True)
49 | 
50 | plt.savefig(join(FIG_PATH, 'figure4i_decoding.pdf'))
51 | plt.savefig(join(FIG_PATH, 'figure4i_decoding.png'), dpi=300)
-------------------------------------------------------------------------------- /figure5_GLM_modelfit.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on 2020-07-20
5 | @author: Anne Urai
6 | """
7 | from pathlib import Path
8 | 
9 | import pandas as pd
10 | import numpy as np
11 | from paper_behavior_functions import (query_sessions_around_criterion,
12 | institution_map, dj2pandas, load_csv, datapath, QUERY)
13 | from ibl_pipeline import behavior, subject, reference
14 | from tqdm.auto import tqdm
15 | from sklearn.model_selection import KFold
16 | 
17 | # for modelling
18 | import patsy # to build design matrix
19 | import statsmodels.api as sm
20 | 
21 | # progress bar
22 | tqdm.pandas(desc="model fitting")
23 | 
24 | # map lab names onto lab numbers (from paper_behavior_functions)
25 | institution_map, col_names = institution_map()
26 | 
27 | # 
========================================== # 28 | #%% 1. LOAD DATA 29 | # ========================================== # 30 | 31 | # Query sessions: before and after full task was first introduced 32 | if QUERY is True: 33 | use_sessions, _ = query_sessions_around_criterion(criterion='biased', 34 | days_from_criterion=[2, 3], 35 | as_dataframe=False, 36 | force_cutoff=True) 37 | 38 | trial_fields = ('trial_stim_contrast_left', 'trial_stim_contrast_right', 39 | 'trial_response_time', 'trial_stim_prob_left', 40 | 'trial_feedback_type', 'trial_stim_on_time', 'trial_response_choice') 41 | 42 | # query trial data for sessions and subject name and lab info 43 | trials = use_sessions.proj('task_protocol') * behavior.TrialSet.Trial.proj(*trial_fields) 44 | subject_info = subject.Subject.proj('subject_nickname') * \ 45 | (subject.SubjectLab * reference.Lab).proj('institution_short') 46 | 47 | # Fetch, join and sort data as a pandas DataFrame 48 | behav = dj2pandas(trials.fetch(format='frame') 49 | .join(subject_info.fetch(format='frame')) 50 | .sort_values(by=['institution_short', 'subject_nickname', 51 | 'session_start_time', 'trial_id']) 52 | .reset_index()) 53 | behav['institution_code'] = behav.institution_short.map(institution_map) 54 | # split the two types of task protocols (remove the pybpod version number) 55 | behav['task'] = behav['task_protocol'].str[14:20].copy() 56 | 57 | # RECODE SOME THINGS JUST FOR PATSY 58 | behav['contrast'] = np.abs(behav.signed_contrast) 59 | behav['stimulus_side'] = np.sign(behav.signed_contrast) 60 | behav['block_id'] = behav['probabilityLeft'].map({80:-1, 50:0, 20:1}) 61 | 62 | else: # load from disk 63 | behav = load_csv('Fig5.csv') 64 | 65 | # ========================================== # 66 | #%% 2. DEFINE THE GLM 67 | # ========================================== # 68 | 69 | 70 | # DEFINE THE MODEL 71 | def fit_glm(behav, prior_blocks=False, folds=5): 72 | 73 | # drop trials with contrast-level 50, only rarely present (should not be its own regressor) 74 | behav = behav[np.abs(behav.signed_contrast) != 50] 75 | 76 | # use patsy to easily build design matrix 77 | if not prior_blocks: 78 | endog, exog = patsy.dmatrices('choice ~ 1 + stimulus_side:C(contrast, Treatment)' 79 | '+ previous_choice:C(previous_outcome)', 80 | data=behav.dropna(subset=['trial_feedback_type', 'choice', 81 | 'previous_choice', 'previous_outcome']).reset_index(), 82 | return_type='dataframe') 83 | else: 84 | endog, exog = patsy.dmatrices('choice ~ 1 + stimulus_side:C(contrast, Treatment)' 85 | '+ previous_choice:C(previous_outcome) ' 86 | '+ block_id', 87 | data=behav.dropna(subset=['trial_feedback_type', 'choice', 88 | 'previous_choice', 'previous_outcome', 'block_id']).reset_index(), 89 | return_type='dataframe') 90 | 91 | # remove the one column (with 0 contrast) that has no variance 92 | if 'stimulus_side:C(contrast, Treatment)[0.0]' in exog.columns: 93 | exog.drop(columns=['stimulus_side:C(contrast, Treatment)[0.0]'], inplace=True) 94 | 95 | # recode choices for logistic regression 96 | endog['choice'] = endog['choice'].map({-1:0, 1:1}) 97 | 98 | # rename columns 99 | exog.rename(columns={'Intercept': 'bias', 100 | 'stimulus_side:C(contrast, Treatment)[6.25]': '6.25', 101 | 'stimulus_side:C(contrast, Treatment)[12.5]': '12.5', 102 | 'stimulus_side:C(contrast, Treatment)[25.0]': '25', 103 | 'stimulus_side:C(contrast, Treatment)[50.0]': '50', 104 | 'stimulus_side:C(contrast, Treatment)[100.0]': '100', 105 | 'previous_choice:C(previous_outcome)[-1.0]': 'unrewarded', 106 | 
'previous_choice:C(previous_outcome)[1.0]': 'rewarded'}, 107 | inplace=True) 108 | 109 | # NOW FIT THIS WITH STATSMODELS - ignore NaN choices 110 | logit_model = sm.Logit(endog, exog) 111 | res = logit_model.fit_regularized(disp=False) # run silently 112 | 113 | # what do we want to keep? 114 | params = pd.DataFrame(res.params).T 115 | params['pseudo_rsq'] = res.prsquared # https://www.statsmodels.org/stable/generated/statsmodels.discrete.discrete_model.LogitResults.prsquared.html?highlight=pseudo 116 | params['condition_number'] = np.linalg.cond(exog) 117 | 118 | # ===================================== # 119 | # ADD MODEL ACCURACY - cross-validate 120 | 121 | kf = KFold(n_splits=folds, shuffle=True) 122 | acc = np.array([]) 123 | for train, test in kf.split(endog): 124 | X_train, X_test, y_train, y_test = exog.loc[train], exog.loc[test], \ 125 | endog.loc[train], endog.loc[test] 126 | # fit again 127 | logit_model = sm.Logit(y_train, X_train) 128 | res = logit_model.fit_regularized(disp=False) # run silently 129 | 130 | # compute the accuracy on held-out data [from Luigi]: 131 | # suppose you are predicting Pr(Left), let's call it p, 132 | # the % match is p if the actual choice is left, or 1-p if the actual choice is right 133 | # if you were to simulate it, in the end you would get these numbers 134 | y_test['pred'] = res.predict(X_test) 135 | y_test.loc[y_test['choice'] == 0, 'pred'] = 1 - y_test.loc[y_test['choice'] == 0, 'pred'] 136 | acc = np.append(acc, y_test['pred'].mean()) 137 | 138 | # average prediction accuracy over the K folds 139 | params['accuracy'] = np.mean(acc) 140 | 141 | return params # wide df 142 | 143 | 144 | # ========================================== # 145 | #%% 3. FIT FOR EACH MOUSE 146 | # ========================================== # 147 | 148 | print('fitting GLM to BASIC task...') 149 | params_basic = behav.loc[behav.task == 'traini', :].groupby( 150 | ['institution_code', 'subject_nickname']).progress_apply(fit_glm, 151 | prior_blocks=False).reset_index() 152 | print('The mean condition number for the basic model is', params_basic['condition_number'].mean()) 153 | 154 | print('fitting GLM to FULL task...') 155 | params_full = behav.loc[behav.task == 'biased', :].groupby( 156 | ['institution_code', 'subject_nickname']).progress_apply(fit_glm, 157 | prior_blocks=True).reset_index() 158 | print('The mean condition number for the full model is', params_full['condition_number'].mean()) 159 | 160 | # ========================================== # 161 | # SAVE FOR NEXT TIME 162 | # ========================================== # 163 | 164 | data_path = Path(datapath(), 'model_results') 165 | params_basic.to_csv(data_path / 'params_basic.csv') 166 | params_full.to_csv(data_path / 'params_full.csv') 167 | -------------------------------------------------------------------------------- /figure5_GLM_plot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on 2020-07-20 5 | @author: Anne Urai 6 | """ 7 | from pathlib import Path 8 | 9 | import pandas as pd 10 | import matplotlib.pyplot as plt 11 | import seaborn as sns 12 | from scipy import stats 13 | 14 | from paper_behavior_functions import (seaborn_style, institution_map, 15 | group_colors, figpath, load_csv, 16 | FIGURE_WIDTH, FIGURE_HEIGHT, num_star) 17 | 18 | 19 | # Load some things from paper_behavior_functions 20 | figpath = Path(figpath()) 21 | seaborn_style() 22 | institution_map, col_names = 
institution_map() 23 | pal = group_colors() 24 | cmap = sns.diverging_palette(20, 220, n=3, center="dark") 25 | 26 | # ========================================== # 27 | #%% 1. GET GLM FITS FOR ALL MICE 28 | # ========================================== # 29 | 30 | print('loading model from disk...') 31 | params_basic = load_csv('model_results', 'params_basic.csv') 32 | params_full = load_csv('model_results', 'params_full.csv') 33 | combined = params_basic.merge(params_full, on=['institution_code', 'subject_nickname']) 34 | 35 | # ========================================== # 36 | # PRINT SUMMARY AND STATS 37 | # ========================================== # 38 | 39 | vars = ['6.25', '12.5', '25', '100', 'rewarded','unrewarded', 'bias'] 40 | for v in vars: 41 | print('basic task, %s: mean %.2f, %f : %f'%(v, params_basic[v].mean(), 42 | params_basic[v].min(), 43 | params_basic[v].max())) 44 | 45 | print('full task, %s: mean %.2f, %f : %f'%(v, params_full[v].mean(), 46 | params_full[v].min(), 47 | params_full[v].max())) 48 | 49 | # DO STATS BETWEEN THE TWO TASK TYPES 50 | test = stats.ttest_rel(combined[v + '_y'], 51 | combined[v + '_x'], 52 | axis=0, nan_policy='omit') 53 | print(test) 54 | 55 | # just show the average block bias in the full task 56 | print('full task, block_id: mean %.2f, %f: %f'%(params_full['block_id'].mean(), 57 | params_full['block_id'].min(), 58 | params_full['block_id'].max())) 59 | 60 | # ========================================== # 61 | #%% 2. PLOT WEIGHTS ACROSS MICE AND LABS 62 | # ========================================== # 63 | 64 | # reshape the data and average across labs for easy plotting 65 | basic_summ_visual = pd.melt(params_basic, 66 | id_vars=['institution_code', 'subject_nickname'], 67 | value_vars=['6.25', '12.5', '25', '100']).groupby(['subject_nickname', 68 | 'institution_code', 'variable']).mean().reset_index() 69 | 70 | basic_summ_bias = pd.melt(params_basic, 71 | id_vars=['institution_code', 'subject_nickname'], 72 | value_vars=['unrewarded', 'rewarded', 'bias']).groupby(['subject_nickname', 73 | 'institution_code', 'variable']).mean().reset_index() 74 | # WEIGHTS IN THE BASIC TASK 75 | plt.close('all') 76 | fig, ax = plt.subplots(1, 2, figsize=(FIGURE_WIDTH/3, FIGURE_HEIGHT)) 77 | sns.pointplot(data = basic_summ_visual, 78 | hue = 'institution_code', x = 'variable', y= 'value', 79 | order=['6.25', '12.5', '25', '100'], 80 | palette = pal, marker='.', ax=ax[0], zorder=0, edgecolors='white', 81 | join = False, dodge = 0.6, ci = 95, errwidth=1) 82 | plt.setp(ax[0].collections, sizes=[3]) 83 | ax[0].plot(basic_summ_visual.groupby(['variable'])['value'].mean()[['6.25', '12.5', '25', '100']], 84 | color='black', linewidth=0, marker='_', markersize=13, zorder=100) 85 | ax[0].get_legend().set_visible(False) 86 | ax[0].set(xlabel=' ', ylabel='Weight', ylim=[0,5.5]) 87 | 88 | 89 | sns.pointplot(data = basic_summ_bias, 90 | hue = 'institution_code', x = 'variable', y= 'value', 91 | order=['rewarded', 'unrewarded', 'bias'], 92 | palette = pal, marker='.', ax=ax[1], zorder=0, edgecolors='white', 93 | join = False, dodge = 0.6, ci = 95, errwidth=1) 94 | plt.setp(ax[1].collections, sizes=[3]) 95 | ax[1].plot(basic_summ_bias.groupby(['variable'])['value'].mean()[['rewarded', 'unrewarded', 'bias']], 96 | color='black', linewidth=0, marker='_', markersize=13, zorder=100) 97 | ax[1].get_legend().set_visible(False) 98 | ax[1].set(xlabel='', ylabel='', ylim=[-1,1.2], yticks=[-1, -0.5, 0, 0.5, 1], 99 | xticks=[0,1,2,3], xlim=[-0.5, 3.5]) 100 | 
ax[1].axhline(color='darkgray', linestyle=':') 101 | ax[1].set_xticklabels([], ha='right', rotation=15) 102 | sns.despine(trim=True) 103 | plt.tight_layout(w_pad=-0.1) 104 | fig.savefig(figpath / 'figure5c_basic_weights.pdf') 105 | 106 | # ========================= # 107 | # SAME BUT FOR FULL TASK 108 | # ========================= # 109 | 110 | # reshape the data and average across labs for easy plotting 111 | full_summ_visual = pd.melt(params_full, 112 | id_vars=['institution_code', 'subject_nickname'], 113 | value_vars=['6.25', '12.5', '25', '100']).groupby(['institution_code', 114 | 'subject_nickname', 'variable']).mean().reset_index() 115 | full_summ_bias = pd.melt(params_full, 116 | id_vars=['institution_code', 'subject_nickname'], 117 | value_vars=['unrewarded', 'rewarded', 118 | 'bias', 'block_id']).groupby(['institution_code', 119 | 'subject_nickname', 'variable']).mean().reset_index() 120 | # WEIGHTS IN THE FULL TASK 121 | plt.close('all') 122 | fig, ax = plt.subplots(1, 2, figsize=(FIGURE_WIDTH/3, FIGURE_HEIGHT)) 123 | sns.pointplot(data = full_summ_visual, 124 | order=['6.25', '12.5', '25', '100'], 125 | hue = 'institution_code', x = 'variable', y= 'value', 126 | palette = pal, marker='.', ax=ax[0], zorder=0, edgecolor='white', 127 | join = False, dodge = 0.6, ci = 95, errwidth=1) 128 | plt.setp(ax[0].collections, sizes=[3]) 129 | ax[0].plot(full_summ_visual.groupby(['variable'])['value'].mean()[['6.25', '12.5', '25', '100']], 130 | color='black', linewidth=0, marker='_', markersize=13, zorder=100) 131 | ax[0].get_legend().set_visible(False) 132 | ax[0].set(xlabel=' ', ylabel='Weight', ylim=[0,5.5]) 133 | 134 | sns.pointplot(data = full_summ_bias, 135 | hue = 'institution_code', x = 'variable', y= 'value', 136 | order=['rewarded', 'unrewarded', 'bias', 'block_id'], 137 | palette = pal, marker='.', ax=ax[1], zorder=0, edgecolor='white', 138 | join = False, dodge = 0.6, ci = 95, errwidth=1) 139 | plt.setp(ax[1].collections, sizes=[3]) 140 | ax[1].plot(full_summ_bias.groupby(['variable'])['value'].mean()[['rewarded', 'unrewarded', 'bias', 'block_id']], 141 | color='black', linewidth=0, marker='_', markersize=13, zorder=100) 142 | ax[1].axhline(color='darkgray', linestyle=':') 143 | ax[1].get_legend().set_visible(False) 144 | ax[1].set(xlabel='', ylabel='', ylim=[-1,1.2], yticks=[-1,-0.5, 0, 0.5, 1]) 145 | ax[1].set_xticklabels([], ha='right', rotation=20) 146 | 147 | sns.despine(trim=True) 148 | plt.tight_layout(w_pad=-0.1) 149 | fig.savefig(figpath / 'figure5c_full_weights.pdf') 150 | 151 | # ========================================== # 152 | #%% SUPPLEMENTARY FIGURE: 153 | # EACH PARAMETER ACROSS LABS 154 | # ========================================== # 155 | 156 | # add the data for all labs combined 157 | params_basic_all = params_basic.copy() 158 | params_basic_all['institution_code'] = 'All' 159 | params_basic_all = params_basic.append(params_basic_all) 160 | 161 | # add the data for all labs combined 162 | params_full_all = params_full.copy() 163 | params_full_all['institution_code'] = 'All' 164 | params_full_all = params_full.append(params_full_all) 165 | 166 | # which variables to plot? 
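# (aside: the *_all frames above duplicate every mouse under a pseudo-lab 'All',
#  so the per-lab plots below gain an extra pooled column; the idiom, sketched on
#  a hypothetical frame df, is simply:
#      df_all = df.copy()
#      df_all['institution_code'] = 'All'
#      df = df.append(df_all)  # DataFrame.append, available in the pandas pinned here
#  )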
167 | vars = ['6.25', '12.5', '25', '100', 'unrewarded', 'rewarded', 'bias', 'block_id', 'pseudo_rsq', 'accuracy'] 168 | ylabels =['Contrast: 6.25', 'Contrast: 12.5', 'Contrast: 25', ' Contrast: 100', 169 | 'Past choice: unrewarded', 'Past choice: rewarded', 'Bias: constant', 170 | 'Bias: block prior', 'Pseudo-R$^2$', 'Model accuracy (5-fold c.v.)'] 171 | ylims = [[0, 6.5], [0, 6.5], [0, 6.5], [0, 6.5], [-1, 1.5], [-1, 1.5], 172 | [-2, 2], [-0.5, 1], [0, 1], [0.5, 1.02]] 173 | 174 | plt.close('all') 175 | for params, modelname in zip([[params_basic, params_basic_all], 176 | [params_full, params_full_all]], ['basic', 'full']): 177 | for v, ylab, ylim in zip(vars, ylabels, ylims): 178 | 179 | if v in params[0].columns: # skip bias for the basic task 180 | 181 | print(modelname) 182 | print(v) 183 | f, ax = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/5, FIGURE_HEIGHT)) 184 | sns.swarmplot(y=v, x='institution_code', data=params[0], hue='institution_code', 185 | palette=pal, ax=ax, marker='.') 186 | axbox = sns.boxplot(y=v, x='institution_code', data=params[1], color='white', 187 | showfliers=False, ax=ax) 188 | ax.set(ylabel=ylab, xlabel='', ylim=ylim) 189 | # [tick.set_color(lab_colors[i]) for i, tick in enumerate(ax5.get_xticklabels()[:-1])] 190 | plt.setp(ax.xaxis.get_majorticklabels(), rotation=60) 191 | axbox.artists[-1].set_edgecolor('black') 192 | for j in range(5 * (len(axbox.artists) - 1), 5 * len(axbox.artists)): 193 | axbox.lines[j].set_color('black') 194 | ax.get_legend().set_visible(False) 195 | 196 | # DO STATISTICS 197 | _, normal = stats.normaltest(params[0][v], nan_policy='omit') 198 | 199 | if normal < 0.05: 200 | test_type = 'kruskal' 201 | test = stats.kruskal(*[group[v].values 202 | for name, group in params[0].groupby('institution_code')], 203 | nan_policy='omit') 204 | else: 205 | test_type = 'anova' 206 | test = stats.f_oneway(*[group[v].values 207 | for name, group in params[0].groupby('institution_code')]) 208 | 209 | # statistical annotation 210 | pvalue = test[1] 211 | if pvalue < 0.05: 212 | ax.annotate(num_star(pvalue), 213 | xy=[0.1, 0.8], xycoords='axes fraction', fontsize=5) 214 | 215 | sns.despine(trim=True) 216 | plt.tight_layout() 217 | plt.savefig(figpath / f'suppfig_model_{modelname}_metrics_{v}.pdf') 218 | -------------------------------------------------------------------------------- /figure5_GLM_simulate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on 2020-07-20 5 | @author: Anne Urai 6 | """ 7 | 8 | import os 9 | import pandas as pd 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import seaborn as sns 13 | import patsy # to build design matrix 14 | import statsmodels.api as sm 15 | from ibl_pipeline import behavior, subject, reference 16 | from paper_behavior_functions import (query_sessions_around_criterion, 17 | seaborn_style, institution_map, 18 | group_colors, figpath, EXAMPLE_MOUSE, 19 | FIGURE_WIDTH, FIGURE_HEIGHT, load_csv, 20 | dj2pandas, plot_psychometric, QUERY) 21 | 22 | # Load some things from paper_behavior_functions 23 | figpath = figpath() 24 | seaborn_style() 25 | institution_map, col_names = institution_map() 26 | pal = group_colors() 27 | #cmap = sns.diverging_palette(20, 220, n=3, center="dark") 28 | cmap = sns.color_palette([[0.8984375,0.37890625,0.00390625], 29 | [0.3, 0.3, 0.3], [0.3671875,0.234375,0.59765625]]) 30 | 31 | # ========================================== # 32 | #%% 1. 
LOAD DATA - just from example mouse 33 | # ========================================== # 34 | 35 | if QUERY: 36 | # Query sessions: before and after full task was first introduced 37 | use_sessions, _ = query_sessions_around_criterion(criterion='biased', 38 | days_from_criterion=[2, 3], 39 | as_dataframe=False, 40 | force_cutoff=True) 41 | use_sessions = (subject.Subject & 'subject_nickname = "%s"' % EXAMPLE_MOUSE) * use_sessions 42 | 43 | trial_fields = ('trial_stim_contrast_left', 'trial_stim_contrast_right', 44 | 'trial_response_time', 'trial_stim_prob_left', 45 | 'trial_feedback_type', 'trial_stim_on_time', 'trial_response_choice') 46 | 47 | # query trial data for sessions and subject name and lab info 48 | trials = use_sessions.proj('task_protocol') * behavior.TrialSet.Trial.proj(*trial_fields) 49 | 50 | # only grab the example mouse 51 | subject_info = (subject.Subject) * \ 52 | (subject.SubjectLab * reference.Lab).proj('institution_short') 53 | 54 | # Fetch, join and sort data as a pandas DataFrame 55 | behav = dj2pandas(trials.fetch(format='frame') 56 | .join(subject_info.fetch(format='frame')) 57 | .sort_values(by=['institution_short', 'subject_nickname', 58 | 'session_start_time', 'trial_id']) 59 | .reset_index()) 60 | # split the two types of task protocols (remove the pybpod version number) 61 | behav['task'] = behav['task_protocol'].str[14:20].copy() 62 | 63 | # RECODE SOME THINGS JUST FOR PATSY 64 | behav['contrast'] = np.abs(behav.signed_contrast) 65 | behav['stimulus_side'] = np.sign(behav.signed_contrast) 66 | behav['block_id'] = behav['probabilityLeft'].map({80:-1, 50:0, 20:1}) 67 | else: 68 | behav = load_csv('Fig5_simulate.pkl') 69 | 70 | # ========================================== # 71 | #%% 2. DEFINE THE GLM 72 | # ========================================== # 73 | 74 | 75 | # DEFINE THE MODEL 76 | def fit_glm(behav, prior_blocks=False, n_sim=10000): 77 | 78 | # drop trials with contrast-level 50, only rarely present (should not be its own regressor) 79 | behav = behav[np.abs(behav.signed_contrast) != 50] 80 | 81 | # use patsy to easily build design matrix 82 | if not prior_blocks: 83 | behav = behav.dropna(subset=['trial_feedback_type', 'choice', 84 | 'previous_choice', 'previous_outcome']).reset_index() 85 | endog, exog = patsy.dmatrices('choice ~ 1 + stimulus_side:C(contrast, Treatment)' 86 | '+ previous_choice:C(previous_outcome)', 87 | data=behav, return_type='dataframe') 88 | else: 89 | behav = behav.dropna(subset=['trial_feedback_type', 'choice', 90 | 'previous_choice', 'previous_outcome', 'block_id']).reset_index() 91 | endog, exog = patsy.dmatrices('choice ~ 1 + stimulus_side:C(contrast, Treatment)' 92 | '+ previous_choice:C(previous_outcome) ' 93 | '+ block_id', 94 | data=behav, return_type='dataframe') 95 | 96 | # remove the one column (with 0 contrast) that has no variance 97 | if 'stimulus_side:C(contrast, Treatment)[0.0]' in exog.columns: 98 | exog.drop(columns=['stimulus_side:C(contrast, Treatment)[0.0]'], inplace=True) 99 | 100 | # recode choices for logistic regression 101 | endog['choice'] = endog['choice'].map({-1:0, 1:1}) 102 | 103 | # rename columns 104 | exog.rename(columns={'Intercept': 'bias', 105 | 'stimulus_side:C(contrast, Treatment)[6.25]': '6.25', 106 | 'stimulus_side:C(contrast, Treatment)[12.5]': '12.5', 107 | 'stimulus_side:C(contrast, Treatment)[25.0]': '25', 108 | 'stimulus_side:C(contrast, Treatment)[50.0]': '50', 109 | 'stimulus_side:C(contrast, Treatment)[100.0]': '100', 110 | 'previous_choice:C(previous_outcome)[-1.0]': 
'unrewarded', 111 | 'previous_choice:C(previous_outcome)[1.0]': 'rewarded'}, 112 | inplace=True) 113 | 114 | # NOW FIT THIS WITH STATSMODELS - ignore NaN choices 115 | logit_model = sm.Logit(endog, exog) 116 | res = logit_model.fit_regularized(disp=False) # run silently 117 | 118 | # what do we want to keep? 119 | params = pd.DataFrame(res.params).T 120 | 121 | # USE INVERSE HESSIAN TO CONSTRUCT MULTIVARIATE GAUSSIAN 122 | cov = -np.linalg.inv(logit_model.hessian(res.params)) 123 | samples = np.random.multivariate_normal(res.params, cov, n_sim) 124 | 125 | # sanity check: the mean of those samples should not be too different from the params 126 | assert np.allclose(params, np.mean(samples, axis=0), atol=0.1) 127 | 128 | # NOW SIMULATE THE MODEL X TIMES 129 | simulated_choices = [] 130 | for n in range(n_sim): 131 | # plug sampled parameters into the model - predict sequence of choices 132 | z = np.dot(exog, samples[n]) 133 | 134 | # then compute the mean choice fractions at each contrast, save and append 135 | behav['simulated_choice'] = 1 / (1 + np.exp(-z)) 136 | if not prior_blocks: 137 | simulated_choices.append(behav.groupby(['signed_contrast'])['simulated_choice'].mean().values) 138 | else: # split by probabilityLeft block 139 | gr = behav.groupby(['probabilityLeft', 'signed_contrast'])['simulated_choice'].mean().reset_index() 140 | simulated_choices.append([gr.loc[gr.probabilityLeft == 20, 'simulated_choice'].values, 141 | gr.loc[gr.probabilityLeft == 50, 'simulated_choice'].values, 142 | gr.loc[gr.probabilityLeft == 80, 'simulated_choice'].values]) 143 | 144 | return params, simulated_choices # wide df 145 | 146 | # ========================================== # 147 | #%% 3. FIT 148 | # ========================================== # 149 | 150 | print('fitting GLM to BASIC task...') 151 | params_basic, simulation_basic = fit_glm(behav.loc[behav.task == 'traini', :], 152 | prior_blocks=False) 153 | 154 | print('fitting GLM to FULL task...') 155 | params_full, simulation_full = fit_glm(behav.loc[behav.task == 'biased', :], 156 | prior_blocks=True) 157 | 158 | # ========================================== # 159 | #%% 4. 
PLOT PSYCHOMETRIC FUNCTIONS 160 | # ========================================== # 161 | 162 | # for plotting, replace 100 with -35 163 | behav['signed_contrast'] = behav['signed_contrast'].replace(-100, -35) 164 | behav['signed_contrast'] = behav['signed_contrast'].replace(100, 35) 165 | 166 | # BASIC TASK 167 | plt.close('all') 168 | # prep the figure with psychometric layout 169 | fig = sns.FacetGrid(behav.loc[behav.task == 'traini', :], 170 | sharex=True, sharey=True, 171 | height=FIGURE_HEIGHT, aspect=(FIGURE_WIDTH/4)/FIGURE_HEIGHT) 172 | fig.map(plot_psychometric, "signed_contrast", "choice_right", "subject_nickname", 173 | color='k', linewidth=0) # this will be empty, hack 174 | # now plot the datapoints, no errorbars 175 | sns.lineplot(data=behav.loc[behav.task == 'traini', :], 176 | x='signed_contrast', y='choice2', marker='o', err_style='bars', 177 | color='k', linewidth=0, ci=95, ax=fig.ax)# overlay the simulated 178 | # confidence intervals from the model - shaded regions 179 | fig.ax.fill_between(sorted(behav.signed_contrast.unique()), 180 | np.quantile(np.array(simulation_basic), q=0.025, axis=0), 181 | np.quantile(np.array(simulation_basic), q=0.975, axis=0), 182 | alpha=0.5, facecolor='k') 183 | fig.set_axis_labels(' ', 'Rightward choices (%)') 184 | fig.despine(trim=True) 185 | fig.savefig(os.path.join(figpath, "figure5b_basic_psychfunc.pdf")) 186 | 187 | # FULL TASK 188 | plt.close('all') 189 | fig = sns.FacetGrid(behav.loc[behav.task == 'biased', :], 190 | hue="probabilityLeft", palette=cmap, 191 | sharex=True, sharey=True, 192 | height=FIGURE_HEIGHT, aspect=(FIGURE_WIDTH/4)/FIGURE_HEIGHT) 193 | fig.map(plot_psychometric, "signed_contrast", "choice_right", "subject_nickname", linewidth=0) # just for axis layout, 194 | # hack 195 | # now plot the datapoints, no errorbars 196 | sns.lineplot(data=behav.loc[behav.task == 'biased', :], 197 | x='signed_contrast', y='choice2', marker='o', err_style='bars', 198 | hue='probabilityLeft', palette=cmap, linewidth=0, ci=95, ax=fig.ax, legend=None)# overlay the simulated 199 | # confidence intervals from the model - shaded regions 200 | for cidx, c in enumerate(cmap): 201 | simulation_full_perblock = [sim[cidx] for sim in simulation_full] # grab what we need, not super elegant 202 | fig.ax.fill_between(sorted(behav.signed_contrast.unique()), 203 | np.quantile(np.array(simulation_full_perblock), q=0.025, axis=0), 204 | np.quantile(np.array(simulation_full_perblock), q=0.975, axis=0), 205 | alpha=0.5, facecolor=cmap[cidx]) 206 | 207 | fig.ax.annotate('20:80', xy=(-5, 0.6), xytext=(-25, 0.8), color=cmap[0], fontsize=7) 208 | fig.ax.annotate('80:20', xy=(5, 0.4), xytext=(13, 0.18), color=cmap[2], fontsize=7) 209 | 210 | fig.set_axis_labels('\u0394 Contrast (%)', 'Rightward choices (%)') 211 | fig.despine(trim=True) 212 | fig.savefig(os.path.join(figpath, "figure5b_full_psychfunc.pdf")) 213 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ibllib>=1.5.13 2 | ibl-pipeline-light>=0.0.0 3 | statsmodels>=0.10.1 4 | datajoint>=0.12 5 | requests>=2.22.0 6 | scikit_posthocs>=0.6.1 7 | patsy>=0.5.1 8 | seaborn<0.12 9 | tqdm>=4.32.1 10 | nose>=1.3.7 11 | numpy>=1.16.4 12 | pandas>=0.24.2 13 | matplotlib>=3.0.3 14 | lifelines>=0.25.0 15 | pycircstat>=0.0.2 16 | scikit-learn>=0.21.3 17 | scipy>=1.3.0 18 | -------------------------------------------------------------------------------- 
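An aside on the parameter sampling used in figure5_GLM_simulate.py above: drawing
parameter vectors from a multivariate Gaussian centred on the fitted weights with
covariance -inv(Hessian) is the standard asymptotic approximation for a
maximum-likelihood fit, since the negative inverse Hessian is the inverse observed
Fisher information. A minimal self-contained check of that identity (toy data;
every name below is hypothetical and not from this repository):

    import numpy as np
    import statsmodels.api as sm

    rng = np.random.default_rng(1)
    x = sm.add_constant(rng.normal(size=(500, 1)))      # design: intercept + one regressor
    y = rng.binomial(1, 1 / (1 + np.exp(-(x @ np.array([0.2, 1.0])))))
    model = sm.Logit(y, x)
    res = model.fit(disp=0)
    cov = -np.linalg.inv(model.hessian(res.params))     # inverse observed information
    # statsmodels' own covariance should agree up to numerical precision
    assert np.allclose(cov, res.cov_params(), atol=1e-6)
    draws = rng.multivariate_normal(res.params, cov, size=1000)

Each draw pushed through the logistic link yields one simulated psychometric curve;
the 2.5%-97.5% quantiles over draws are what the shaded bands in figure 5b show.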
/supp_days_between_trainingstatus.py:
--------------------------------------------------------------------------------
1 | """
2 | The International Brain Laboratory
3 | Anne Urai, CSHL, 2020-09-07
4 | 
5 | Starting from reaching 1a/1b, show distributions of days to next training stages
6 | 
7 | """
8 | from ibl_pipeline import subject, reference
9 | from ibl_pipeline.analyses import behavior as behavior_analysis
10 | 
11 | from paper_behavior_functions import QUERY
12 | 
13 | assert QUERY, 'This script requires a DataJoint instance, which was removed in Dec 2023.'
14 | 
15 | # Query all subjects with project ibl_neuropixel_brainwide_01 and get the date at which
16 | # they reached a given training status
17 | all_subjects = (subject.Subject * subject.SubjectLab * reference.Lab * subject.SubjectProject
18 |                 & 'subject_project = "ibl_neuropixel_brainwide_01"')
19 | summ_by_date = all_subjects * behavior_analysis.BehavioralSummaryByDate
20 | training_status_by_day = summ_by_date.aggr(behavior_analysis.SessionTrainingStatus(),
21 |                                            daily_status='(training_status)')
22 | df = (training_status_by_day
23 |       .fetch(format='frame')
24 |       .reset_index()
25 |       .sort_values(by=['lab_name', 'session_date']))
26 | print(df.daily_status.unique())
27 | 
--------------------------------------------------------------------------------
/supp_figure2_performance_trials.py:
--------------------------------------------------------------------------------
1 | """
2 | Learning curves for all labs
3 | 
4 | @author: Anne Urai, Miles Wells
5 | 15 January 2020
6 | """
7 | import os
8 | 
9 | import pandas as pd
10 | import numpy as np
11 | from scipy.signal import medfilt
12 | import seaborn as sns
13 | import matplotlib.pyplot as plt
14 | import matplotlib.ticker as ticker
15 | 
16 | from paper_behavior_functions import (query_subjects, figpath, load_csv, group_colors,
17 |                                       institution_map, seaborn_style, EXAMPLE_MOUSE,
18 |                                       FIGURE_HEIGHT, FIGURE_WIDTH, QUERY)
19 | from ibl_pipeline.analyses import behavior as behavioral_analyses
20 | 
21 | # INITIALIZE A FEW THINGS
22 | seaborn_style()
23 | figpath = figpath()
24 | pal = group_colors()
25 | institution_map, col_names = institution_map()
26 | col_names = col_names[:-1]
27 | 
28 | # %% ============================== #
29 | # GET DATA FROM TRAINED ANIMALS
30 | # ================================= #
31 | 
32 | if QUERY is True:
33 |     use_subjects = query_subjects()
34 |     b = (behavioral_analyses.BehavioralSummaryByDate * use_subjects)
35 |     behav = b.fetch(order_by='institution_short, subject_nickname, training_day',
36 |                     format='frame').reset_index()
37 |     behav['institution_code'] = behav.institution_short.map(institution_map)
38 | else:
39 |     behav = load_csv('Fig2af.pkl')
40 | 
41 | # exclude sessions with fewer than 100 trials
42 | behav = behav[behav['n_trials_date'] > 100]
43 | 
44 | # smooth each mouse's performance with a 3-day median filter
45 | for i, nickname in enumerate(behav['subject_nickname'].unique()):
46 |     perf = behav.loc[behav['subject_nickname'] == nickname, 'performance_easy'].values
47 |     # perf_conv = np.convolve(perf, np.ones((3,))/3, mode='valid')  # moving-average variant, unused
48 |     # perf_conv = np.append(perf_conv, [np.nan, np.nan])
49 |     perf_conv = medfilt(perf, kernel_size=3)
50 |     behav.loc[behav['subject_nickname'] == nickname, 'performance_easy'] = perf_conv
51 | 
52 | # how many mice are there for each lab?
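# (a sketch of what the next lines build, with made-up counts: N == {'Lab 1': 12,
#  'Lab 2': 9, ...}; mapped back onto the frame, it lets each panel title below
#  read e.g. "Lab 1\n 12 mice")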
53 | N = behav.groupby(['institution_code'])['subject_nickname'].nunique().to_dict() 54 | behav['n_mice'] = behav.institution_code.map(N) 55 | behav['institution_name'] = behav.institution_code + '\n ' + behav.n_mice.apply(str) + ' mice' 56 | 57 | # make sure each mouse starts at 0 58 | for index, group in behav.groupby(['lab_name', 'subject_nickname']): 59 | behav.loc[group.index, 'training_day'] = group['training_day'] - group['training_day'].min() 60 | 61 | # create another column only after the mouse is trained 62 | behav2 = pd.DataFrame([]) 63 | for index, group in behav.groupby(['institution_code', 'subject_nickname']): 64 | group['performance_easy_trained'] = group.performance_easy 65 | group.loc[pd.to_datetime(group['session_date']) < pd.to_datetime(group['date_trained']), 66 | 'performance_easy_trained'] = np.nan 67 | # add this 68 | behav2 = behav2.append(group) 69 | 70 | behav = behav2 71 | behav['performance_easy'] = behav.performance_easy * 100 72 | behav['performance_easy_trained'] = behav.performance_easy_trained * 100 73 | 74 | # Create column for cumulative trials per mouse 75 | behav.n_trials_date = behav.n_trials_date.astype(int) 76 | behav['cum_trials'] = ( 77 | (behav 78 | .groupby(by=['subject_uuid']) 79 | .cumsum() 80 | .n_trials_date) 81 | ) 82 | 83 | 84 | # %% ============================== # 85 | # LEARNING CURVES 86 | # ================================= # 87 | ############### 88 | # plot one curve for each animal, one panel per lab 89 | fig = sns.FacetGrid(behav, 90 | col="institution_code", col_wrap=7, col_order=col_names, 91 | sharex=True, sharey=True, hue="subject_uuid", xlim=[-1, 3e4], 92 | height=FIGURE_HEIGHT, aspect=(FIGURE_WIDTH / 7) / FIGURE_HEIGHT) 93 | fig.map(sns.lineplot, "cum_trials", 94 | "performance_easy", color='grey', alpha=0.3) 95 | fig.map(sns.lineplot, "cum_trials", 96 | "performance_easy_trained", color='black', alpha=0.3) 97 | fig.set_titles("{col_name}") 98 | format_fcn = ticker.FuncFormatter(lambda x, pos: '{:,.0f}'.format(x / 1e3) + 'K') 99 | 100 | # overlay the example mouse 101 | sns.lineplot(ax=fig.axes[0], x='cum_trials', y='performance_easy', color='black', 102 | data=behav[behav['subject_nickname'].str.contains(EXAMPLE_MOUSE)], legend=False) 103 | 104 | for axidx, ax in enumerate(fig.axes.flat): 105 | # add the lab mean to each panel 106 | d = (behav.loc[behav.institution_name == behav.institution_name.unique()[axidx], :])\ 107 | .groupby('training_day').mean() # Binning by day 108 | sns.lineplot(data=d, x='cum_trials', y='performance_easy', 109 | color=pal[axidx], ci=None, ax=ax, legend=False, linewidth=2) 110 | ax.set_title(behav.institution_name.unique()[ 111 | axidx], color=pal[axidx], fontweight='bold') 112 | fig.set(xticks=[0, 10000, 20000, 30000]) 113 | ax.xaxis.set_major_formatter(format_fcn) 114 | 115 | fig.set_axis_labels('Trial', 'Performance (%)\n on easy trials') 116 | fig.despine(trim=True) 117 | plt.tight_layout(w_pad=-2.2) 118 | fig.savefig(os.path.join(figpath, "figure2a_learningcurves_trials.pdf")) 119 | fig.savefig(os.path.join(figpath, "figure2a_learningcurves_trials.png"), dpi=300) 120 | 121 | # Plot all labs 122 | d = behav.groupby(['institution_code', 'training_day']).mean() # Binned by day 123 | 124 | fig, ax1 = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/3, FIGURE_HEIGHT)) 125 | sns.lineplot(x='cum_trials', y='performance_easy', hue='institution_code', palette=pal, 126 | ax=ax1, legend=False, data=d, ci=None) 127 | ax1.set_title('All labs: %d mice' % behav['subject_nickname'].nunique()) 128 | 
ax1.set(xlabel='Trial',
129 |         ylabel='Performance (%)\non easy trials', xlim=[-1, 30000], ylim=[15, 100])
130 | ax1.xaxis.set_major_formatter(format_fcn)
131 | 
132 | sns.despine(trim=True)
133 | plt.tight_layout()
134 | fig.savefig(os.path.join(figpath, "figure2b_learningcurves_trials_all_labs.pdf"))
135 | fig.savefig(os.path.join(
136 |     figpath, "figure2b_learningcurves_trials_all_labs.png"), dpi=300)
137 | 
138 | # ================================= #
139 | # print some stats
140 | # ================================= #
141 | behav_summary_std = behav.groupby(['training_day'])[
142 |     ['performance_easy', 'cum_trials']].std().reset_index()
143 | behav_summary = behav.groupby(['training_day'])[
144 |     ['performance_easy', 'cum_trials']].mean().reset_index()
145 | print('number of trials to reach 80% accuracy on easy trials: ')
146 | print(behav_summary.loc[behav_summary.performance_easy >
147 |       80, 'cum_trials'].round().min())
148 | 
149 | plt.show()
150 | 
--------------------------------------------------------------------------------
/supp_nmice_overtime.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Query the number of mice at different timepoints of the pipeline
4 | @author 22 April 2020, Anne Urai
5 | """
6 | from ibl_pipeline import subject, reference, behavior
7 | 
8 | from paper_behavior_functions import QUERY
9 | 
10 | assert QUERY, 'This script requires a DataJoint instance, which was removed in Dec 2023.'
11 | 
12 | # dates = ['2019-01-01', '2019-03-01', '2020-01-01', '2020-04-01']  # earlier checkpoint set, superseded below
13 | dates = ['2019-01-01', '2019-05-01', '2019-11-01', '2020-04-01']
14 | 
15 | print('All mice in database:')
16 | for d in dates:
17 | 
18 |     # which mice were in the database by then?
19 |     subj_query = subject.Subject * subject.SubjectLab * reference.Lab * \
20 |                  behavior.TrialSet & 'session_start_time < "%s"'%d
21 |     subj_df = subj_query.fetch(format='frame').reset_index()
22 | 
23 |     print('%s, %d mice, %d labs, %d choices'%(d, subj_df.subject_uuid.nunique(),
24 |                                               subj_df.lab_name.nunique(),
25 |                                               subj_df.n_trials.sum()))
26 | 
27 | print('Brainwide map project mice:')
28 | for d in dates:
29 | 
30 |     # which mice were in the database by then?
31 |     subj_query = subject.Subject * subject.SubjectLab * reference.Lab \
32 |         * (subject.SubjectProject & 'subject_project = "ibl_neuropixel_brainwide_01"') \
33 |         * (behavior.TrialSet & 'session_start_time < "%s"'%d)
34 |     subj_df = subj_query.fetch(format='frame').reset_index()
35 | 
36 |     print('%s, %d mice, %d labs, %d choices'%(d, subj_df.subject_uuid.nunique(),
37 |                                               subj_df.lab_name.nunique(),
38 |                                               subj_df.n_trials.sum()))
39 | 
--------------------------------------------------------------------------------
/suppfig_3-4a-f.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Quantify the variability of behavioral metrics within and between labs of mouse behavior.
5 | This script computes behavioral metrics per mouse, tests whether they differ across labs, and plots summary statistics over labs.
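Per metric, the lab comparison below follows one pattern (a sketch mirroring the
code in this script, where `values` and `groups` are hypothetical placeholders
for one metric's pooled values and its per-lab groups):

    _, p_norm = stats.normaltest(values)
    if p_norm < 0.05:                      # not normally distributed
        stat, p = stats.kruskal(*groups)   # nonparametric; posthoc: Dunn
    else:
        stat, p = stats.f_oneway(*groups)  # one-way ANOVA; posthoc: Tukey
    # a Levene test additionally compares variances across labs, and all
    # p-values are corrected for multiple comparisons (Benjamini-Hochberg)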
6 | 
7 | Alejandro Pan
8 | 06 Jan 2020
9 | """
10 | 
11 | import pandas as pd
12 | import matplotlib.pyplot as plt
13 | import numpy as np
14 | from scipy import stats
15 | from os.path import join
16 | import seaborn as sns
17 | from paper_behavior_functions import (seaborn_style, QUERY, load_csv,
18 |                                       institution_map, group_colors, figpath,
19 |                                       FIGURE_WIDTH, FIGURE_HEIGHT,
20 |                                       fit_psychfunc, num_star,
21 |                                       query_session_around_performance)
22 | import scikit_posthocs as sp
23 | from statsmodels.stats.multitest import multipletests
24 | 
25 | 
26 | seaborn_style()
27 | figpath = figpath()
28 | pal = group_colors()
29 | institution_map, col_names = institution_map()
30 | col_names = col_names[:-1]
31 | 
32 | if QUERY:
33 |     behav = query_session_around_performance(perform_thres=0.8)
34 | else:
35 |     behav = load_csv('suppfig_3-4af.pkl')
36 |     behav['institution_code'] = behav.lab_name.map(institution_map)
37 | 
38 | # Create dataframe with behavioral metrics of all mice
39 | learned = pd.DataFrame(columns=['mouse', 'lab', 'perf_easy', 'n_trials',
40 |                                 'threshold', 'bias', 'reaction_time',
41 |                                 'lapse_low', 'lapse_high', 'trials_per_minute'])
42 | 
43 | for i, nickname in enumerate(behav['subject_nickname'].unique()):
44 |     if np.mod(i+1, 10) == 0:
45 |         print('Processing data of subject %d of %d' % (i+1,
46 |                                                        len(behav['subject_nickname'].unique())))
47 | 
48 |     # Get the trials of the sessions around criterion for this subject (first
49 |     # 90% + next session)
50 |     trials = behav[behav['subject_nickname'] == nickname].reset_index()
51 |     # Exclude sessions with four or fewer contrast levels
52 |     trials['contrast_set'] = trials.session_start_time.map(
53 |         trials.groupby(['session_start_time'])['signed_contrast'].unique())
54 |     trials = trials.loc[trials['contrast_set'].str.len()>4]
55 |     if len(trials['session_start_time'].unique())<3:
56 |         continue
57 |     # Fit a psychometric function to these trials and get fit results
58 |     fit_result = fit_psychfunc(trials)
59 | 
60 |     # Get RT, performance and number of trials
61 |     reaction_time = trials['rt'].median()*1000
62 |     perf_easy = trials['correct_easy'].mean()*100
63 |     ntrials_perday = trials.groupby('session_uuid').count()['trial_id'].mean()
64 | 
65 | 
66 |     # average trials/minute to normalise by session length
67 |     trials['session_length'] = (trials.session_end_time - trials.session_start_time).astype('timedelta64[m]')
68 |     total_session_length = trials.groupby('session_uuid')['session_length'].mean().sum()
69 |     total_n_trials = trials['trial_id'].count()
70 | 
71 |     # Add results to dataframe
72 |     learned.loc[i, 'mouse'] = nickname
73 |     learned.loc[i, 'lab'] = trials['institution_short'].iloc[0]
74 |     learned.loc[i, 'perf_easy'] = perf_easy
75 |     learned.loc[i, 'n_trials'] = ntrials_perday
76 |     learned.loc[i, 'reaction_time'] = reaction_time
77 |     learned.loc[i, 'trials_per_minute'] = total_n_trials / total_session_length
78 |     learned.loc[i, 'threshold'] = fit_result.loc[0, 'threshold']
79 |     learned.loc[i, 'bias'] = fit_result.loc[0, 'bias']
80 |     learned.loc[i, 'lapse_low'] = fit_result.loc[0, 'lapselow']
81 |     learned.loc[i, 'lapse_high'] = fit_result.loc[0, 'lapsehigh']
82 | 
83 | # Drop mice with faulty RT
84 | learned = learned[learned['reaction_time'].notnull()]
85 | 
86 | # Change lab name into lab number
87 | learned['lab_number'] = learned.lab.map(institution_map)
88 | learned = learned.sort_values('lab_number')
89 | 
90 | # Convert to float
91 | float_fields = ['perf_easy', 'reaction_time', 'threshold',
92 |                 'n_trials', 'bias', 'lapse_low', 'lapse_high', 'trials_per_minute']
93 | learned[float_fields] = learned[float_fields].astype(float)
94 | 
95 | # %% Stats
96 | stats_tests = pd.DataFrame(columns=['variable', 'test_type', 'p_value'])
97 | posthoc_tests = {}
98 | 
99 | for i, var in enumerate(['perf_easy', 'reaction_time', 'n_trials', 'threshold', 'bias', 'trials_per_minute']):
100 |     _, normal = stats.normaltest(learned[var])
101 | 
102 |     if normal < 0.05:  # normaltest p < 0.05: data are not normally distributed
103 |         test_type = 'kruskal'
104 |         test = stats.kruskal(*[group[var].values
105 |                                for name, group in learned.groupby('lab_number')])
106 |         if test[1] < 0.05:  # Proceed to posthocs
107 |             posthoc = sp.posthoc_dunn(learned, val_col=var, group_col='lab_number')
108 |         else:
109 |             posthoc = np.nan
110 |     else:
111 |         test_type = 'anova'
112 |         test = stats.f_oneway(*[group[var].values
113 |                                 for name, group in learned.groupby('lab_number')])
114 |         if test[1] < 0.05:
115 |             posthoc = sp.posthoc_tukey(learned, val_col=var, group_col='lab_number')
116 |         else:
117 |             posthoc = np.nan
118 | 
119 |     # Test for difference in variance
120 |     _, p_var = stats.levene(*[group[var].values for name, group in learned.groupby('lab_number')])
121 | 
122 |     posthoc_tests['posthoc_'+str(var)] = posthoc
123 |     stats_tests.loc[i, 'variable'] = var
124 |     stats_tests.loc[i, 'test_type'] = test_type
125 |     stats_tests.loc[i, 'p_value'] = test[1]
126 |     stats_tests.loc[i, 'p_value_variance'] = p_var
127 | 
128 | # Correct for multiple tests
129 | stats_tests['p_value'] = multipletests(stats_tests['p_value'], method='fdr_bh')[1]
130 | stats_tests['p_value_variance'] = multipletests(stats_tests['p_value_variance'],
131 |                                                 method='fdr_bh')[1]
132 | 
133 | if (stats.normaltest(learned['n_trials'])[1] < 0.05 or
134 |         stats.normaltest(learned['reaction_time'])[1] < 0.05):
135 |     test_type = 'spearman'
136 |     correlation_coef, correlation_p = stats.spearmanr(learned['reaction_time'],
137 |                                                       learned['n_trials'])
138 | if (stats.normaltest(learned['n_trials'])[1] > 0.05 and
139 |         stats.normaltest(learned['reaction_time'])[1] > 0.05):
140 |     test_type = 'pearson'
141 |     correlation_coef, correlation_p = stats.pearsonr(learned['reaction_time'],
142 |                                                      learned['n_trials'])
143 | 
144 | # Add all mice to dataframe separately for plotting
145 | learned_no_all = learned.copy()
146 | #learned_no_all.loc[learned_no_all.shape[0] + 1, 'lab_number'] = 'All'
147 | learned_2 = learned.copy()
148 | learned_2['lab_number'] = 'All'
149 | learned_2 = learned.append(learned_2)
150 | 
151 | # %%
152 | seaborn_style()
153 | lab_colors = group_colors()
154 | sns.set_palette(lab_colors)
155 | 
156 | # %%
157 | vars = ['n_trials', 'perf_easy', 'threshold', 'bias', 'reaction_time', 'trials_per_minute']
158 | ylabels =['Number of trials', 'Performance (%)\non easy trials',
159 |           'Contrast threshold (%)', 'Bias (%)', 'Trial duration (ms)', 'Trials / minute']
160 | ylims = [[0, 2000],[70, 100], [0, 50], [-30, 30], [0, 2000], [0, 30]]
161 | criteria = [[0, 0],[80, 100], [0, 20], [-10, 10], [0, 0], [0, 0]]
162 | order_x = ['Lab 1','Lab 2','Lab 3','Lab 4','Lab 5','Lab 6','Lab 7','All']
163 | for v, ylab, ylim, crit in zip(vars, ylabels, ylims, criteria):
164 | 
165 |     f, ax = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/5, FIGURE_HEIGHT))
166 |     sns.swarmplot(y=v, x='lab_number', data=learned_no_all, hue='lab_number',
167 |                   palette=lab_colors, ax=ax, marker='.', order=order_x)
168 |     axbox = sns.boxplot(y=v, x='lab_number', data=learned_2, color='white',
169 |                         showfliers=False, ax=ax, order=order_x)
170 |     ax.set(ylabel=ylab, ylim=ylim, xlabel='')
171 |     ax.axhspan(crit[0], crit[1], facecolor='0.2', alpha=0.2)
172 |     # [tick.set_color(lab_colors[i]) for i, tick in enumerate(ax5.get_xticklabels()[:-1])]
173 |     plt.setp(ax.xaxis.get_majorticklabels(), rotation=60)
174 |     axbox.artists[-1].set_edgecolor('black')
175 |     for j in range(5 * (len(axbox.artists) - 1), 5 * len(axbox.artists)):
176 |         axbox.lines[j].set_color('black')
177 |     ax.get_legend().set_visible(False)
178 | 
179 |     # statistical annotation
180 |     pvalue = stats_tests.loc[stats_tests['variable'] == v, 'p_value']
181 |     if pvalue.to_numpy()[0] < 0.05:
182 |         ax.annotate(num_star(pvalue.to_numpy()[0]),
183 |                     xy=[0.1, 0.8], xycoords='axes fraction', fontsize=5)
184 | 
185 |     sns.despine(trim=True)
186 |     plt.tight_layout()
187 |     plt.savefig(join(figpath, 'supplementaryfigure3_metrics_%s.pdf'%v))
188 |     plt.savefig(join(figpath, 'supplementaryfigure3_metrics_%s.png'%v), dpi=300)
189 | 
190 | # %%
191 | # Get stats for text
192 | perf_mean = learned['perf_easy'].mean()
193 | perf_std = learned['perf_easy'].std()
194 | thres_mean = learned['threshold'].mean()
195 | thres_std = learned['threshold'].std()
196 | rt_median = learned['reaction_time'].median()
197 | rt_std = learned['reaction_time'].std()
198 | trials_mean = learned['n_trials'].mean()
199 | trials_std = learned['n_trials'].std()
200 | 
--------------------------------------------------------------------------------
/suppfig_classifier_lab_membership_first_biased.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Decode in which lab a mouse was trained based on its behavioral metrics during the first three
5 | biased sessions, regardless of when the mice reached proficiency.
6 | 
7 | As a positive control, the time zone in which the mouse was trained is included in the dataset
8 | since the timezone provides geographical information. Decoding is performed using leave-one-out
9 | cross-validation. To control for the imbalance in the dataset (some labs have more mice than
10 | others) a fixed number of mice is randomly sub-sampled from each lab. This random sampling is
11 | repeated for a large number of repetitions. A shuffled null-distribution is obtained by shuffling
12 | the lab labels and decoding again for each iteration.
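For intuition, the balanced sub-sampling and the shuffled null have roughly this
shape (a sketch using the ITERATIONS and N_MICE settings below; `metrics` (a
mice x metrics array), `labs` and `decode` are hypothetical placeholders, not
names from this repository):

    rng = np.random.default_rng(0)
    scores, null_scores = [], []
    for _ in range(ITERATIONS):
        # draw N_MICE mice from every lab so all labs are equally represented
        idx = np.concatenate([rng.choice(np.where(labs == lab)[0], N_MICE, replace=False)
                              for lab in np.unique(labs)])
        scores.append(decode(metrics[idx], labs[idx]))
        # shuffled null: identical data, lab labels permuted
        null_scores.append(decode(metrics[idx], rng.permutation(labs[idx])))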
13 | 14 | -------------- 15 | Parameters 16 | DECODER: Which decoder to use: 'bayes', 'forest', or 'regression' 17 | N_MICE: How many mice per lab to randomly sub-sample 18 | (must be lower than the lab with the least mice) 19 | ITERATIONS: Number of times to randomly sub-sample 20 | METRICS: List of strings indicating which behavioral metrics to include 21 | during decoding of lab membership 22 | METRICS_CONTROL: List of strings indicating which metrics to use for the positive control 23 | 24 | Guido Meijer 25 | September 3, 2020 26 | """ 27 | 28 | import numpy as np 29 | from os.path import join 30 | from paper_behavior_functions import \ 31 | institution_map, QUERY, fit_psychfunc, dj2pandas, load_csv, datapath 32 | import pandas as pd 33 | from sklearn.ensemble import RandomForestClassifier 34 | from sklearn.naive_bayes import GaussianNB 35 | from sklearn.linear_model import LogisticRegression 36 | from sklearn.model_selection import LeaveOneOut 37 | from sklearn.metrics import f1_score, confusion_matrix 38 | 39 | # Settings 40 | DECODER = 'bayes' # forest, bayes or regression 41 | N_MICE = 8 # how many mice per lab to sub-sample 42 | ITERATIONS = 2000 # how often to decode 43 | METRICS = ['perf_easy', 'threshold_l', 'threshold_r', 'threshold_n', 'bias_l', 'bias_r', 'bias_n'] 44 | METRICS_CONTROL = ['perf_easy', 'threshold_l', 'threshold_r', 'threshold_n', 45 | 'bias_l', 'bias_r', 'bias_n', 'time_zone'] 46 | 47 | 48 | # Decoding function with n-fold cross validation 49 | def decoding(data, labels, clf): 50 | kf = LeaveOneOut() 51 | y_pred = np.empty(len(labels), dtype='= n_days) == criterion) 53 | for n_days in range(max_n_days+1)] 54 | for criterion in np.sort(df['end_status_id'].unique())]) 55 | 56 | if normalize: 57 | counts = np.stack([n / sum(n) for n in counts.T]).T 58 | # counts = np.stack([n / sum(n) for n in counts]) 59 | 60 | bar_l = range(1, counts.shape[1]+1) 61 | # bottom = np.zeros_like(bar_l).astype('float') 62 | bottom = np.vstack((np.zeros((1, counts.shape[1])), np.cumsum(counts, axis=0)[:-1, :])) 63 | 64 | fig, ax = plt.subplots(1, 1, figsize=(FIGURE_WIDTH / 2, FIGURE_HEIGHT)) 65 | for i in range(counts.shape[0]): 66 | ax.bar(bar_l, counts[i, :], bottom=bottom[i, :], width=1, label=list(ids.keys())[i], 67 | color=colors[i]) 68 | 69 | ax.set_xticks([1] + [i * 7 for i in range(1, round(max_n_days+7/7))]) 70 | ax.set_xticks([0, 10, 20, 30, 40]) 71 | 72 | ax.set_xlim([0, counts.shape[1]+.5]) 73 | ax.set_xlabel('Session #') 74 | ax.set_ylabel('Proportion') 75 | ax.legend(loc='upper right') 76 | plt.tight_layout() 77 | sns.despine(trim=False) 78 | plt.gcf().savefig(os.path.join(save_path, "suppfig_end_status_histogram_normalized.png"), dpi=300) 79 | plt.gcf().savefig(os.path.join(save_path, "suppfig_end_status_histogram_normalized.pdf")) 80 | -------------------------------------------------------------------------------- /suppfig_plot_classifier_first_biased.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Plot the results from the classification of lab by loading in the .pkl files generated by 5 | figure3f_decoding_lab_membership_basic and figure3f_decoding_lab_membership_full 6 | 7 | Guido Meijer 8 | 18 Jun 2020 9 | """ 10 | 11 | import pandas as pd 12 | import numpy as np 13 | import seaborn as sns 14 | from os.path import join 15 | import matplotlib.pyplot as plt 16 | from paper_behavior_functions import seaborn_style, figpath, load_csv, FIGURE_WIDTH, FIGURE_HEIGHT 17 
|
18 | # Settings
19 | FIG_PATH = figpath()
20 | colors = [[1, 1, 1], [1, 1, 1], [0.6, 0.6, 0.6]]
21 | seaborn_style()
22 | 
23 | # Load in results from the decoding step (pickle file)
24 | decoding_result = load_csv('classification_results', 'classification_results_full_bayes.pkl')
25 | 
26 | # Calculate if decoder performs above chance
27 | chance_level = decoding_result['original_shuffled'].mean()
28 | significance = np.percentile(decoding_result['original'], 2.5)
29 | sig_control = np.percentile(decoding_result['control'], 0.001)
30 | if chance_level > significance:
31 |     print('Classification performance not significantly above chance')
32 | else:
33 |     print('Above chance classification performance!')
34 | 
35 | # %%
36 | 
37 | f, ax1 = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/5, FIGURE_HEIGHT))
38 | sns.violinplot(data=pd.concat([decoding_result['control'],
39 |                                decoding_result['original_shuffled'],
40 |                                decoding_result['original']], axis=1),
41 |                palette=colors, ax=ax1)
42 | ax1.plot([-1, 3.5], [chance_level, chance_level], '--', color='k', zorder=-10)
43 | ax1.set(ylabel='Decoding accuracy', xlim=[-0.8, 2.4], ylim=[-0.1, 0.62])
44 | ax1.set_xticklabels(['Positive\ncontrol', 'Shuffle', 'Mouse\nbehavior'],
45 |                     rotation=90, ha='center')
46 | plt.tight_layout()
47 | sns.despine(trim=True)
48 | 
49 | plt.savefig(join(FIG_PATH, 'suppfig_decoding_first_biased.pdf'))
50 | plt.savefig(join(FIG_PATH, 'suppfig_decoding_first_biased.png'), dpi=300)
51 | 
52 | 
53 | # %%
54 | f, ax1 = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/4, FIGURE_HEIGHT))
55 | n_labs = decoding_result['confusion_matrix'][0].shape[0]
56 | sns.heatmap(data=decoding_result['confusion_matrix'].mean(), vmin=0, vmax=0.4)
57 | ax1.plot([0, 7], [0, 7], '--w')
58 | ax1.set(xticklabels=np.arange(1, n_labs + 1), yticklabels=np.arange(1, n_labs + 1),
59 |         ylim=[0, n_labs], xlim=[0, n_labs],
60 |         title='', ylabel='Actual lab', xlabel='Predicted lab')
61 | plt.setp(ax1.xaxis.get_majorticklabels(), rotation=40)
62 | plt.setp(ax1.yaxis.get_majorticklabels(), rotation=40)
63 | plt.gca().invert_yaxis()
64 | plt.tight_layout()
65 | 
66 | plt.savefig(join(FIG_PATH, 'suppfig_confusion_matrix_first_biased.pdf'))
67 | plt.savefig(join(FIG_PATH, 'suppfig_confusion_matrix_first_biased.png'), dpi=300)
68 | 
--------------------------------------------------------------------------------
/suppfig_plot_classifier_perf.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Plot the results from the classification of lab by loading in the .pkl files generated by
5 | figure3f_decoding_lab_membership_basic and figure3f_decoding_lab_membership_full
6 | 
7 | Guido Meijer
8 | 18 Jun 2020
9 | """
10 | 
11 | import pandas as pd
12 | import numpy as np
13 | import seaborn as sns
14 | from os.path import join
15 | import matplotlib.pyplot as plt
16 | from paper_behavior_functions import seaborn_style, figpath, datapath, FIGURE_WIDTH, FIGURE_HEIGHT
17 | 
18 | # Settings
19 | DECODER = 'bayes'
20 | FIG_PATH = figpath()
21 | colors = [[1, 1, 1], [1, 1, 1], [0.6, 0.6, 0.6]]
22 | seaborn_style()
23 | 
24 | # Load in results from the decoding step (pickle file)
25 | decoding_result = pd.read_pickle(join(datapath(),
26 |                                       'classification_results_perf_%s.pkl' % DECODER))
27 | 
28 | # Calculate if decoder performs above chance
29 | chance_level = decoding_result['original_shuffled'].mean()
30 | significance = np.percentile(decoding_result['original'], 2.5)
31 | sig_control = np.percentile(decoding_result['control'], 0.001)
32 | if chance_level > significance:
33 |     print('\n%s classifier did not perform above chance' % DECODER)
34 |     print('Chance level: %.2f (F1 score)' % chance_level)
35 | else:
36 |     print('\n%s classifier performed above chance' % DECODER)
37 |     print('Chance level: %.2f (F1 score)' % chance_level)
38 |     print('F1 score: %.2f ± %.3f' % (decoding_result['original'].mean(),
39 |                                      decoding_result['original'].std()))
40 | 
41 | # %%
42 | 
43 | # Plot main Figure 3
44 | f, ax1 = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/5, FIGURE_HEIGHT))
45 | sns.violinplot(data=pd.concat([decoding_result['control'],
46 |                                decoding_result['original_shuffled'],
47 |                                decoding_result['original']], axis=1),
48 |                palette=colors, ax=ax1)
49 | ax1.plot([-1, 3.5], [chance_level, chance_level], '--', color='k', zorder=-10)
50 | ax1.set(ylabel='Decoding accuracy', xlim=[-0.6, 2.6], ylim=[-0.1, 0.62])
51 | ax1.set_xticklabels(['Positive\ncontrol', 'Shuffle', 'Mouse\nbehavior'],
52 |                     rotation=90, ha='center')
53 | plt.tight_layout()
54 | sns.despine(trim=True)
55 | 
56 | plt.savefig(join(FIG_PATH, 'suppfig3_decoding_perf.pdf'))
57 | plt.savefig(join(FIG_PATH, 'suppfig3_decoding_perf.png'), dpi=300)
58 | plt.close(f)
59 | 
60 | # %%
61 | f, ax1 = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/4, FIGURE_HEIGHT))
62 | n_labs = decoding_result['confusion_matrix'][0].shape[0]
63 | sns.heatmap(data=decoding_result['confusion_matrix'].mean(), vmin=0, vmax=0.4)
64 | ax1.plot([0, 7], [0, 7], '--w')
65 | ax1.set(xticklabels=np.arange(1, n_labs + 1), yticklabels=np.arange(1, n_labs + 1),
66 |         ylim=[0, n_labs], xlim=[0, n_labs],
67 |         title='', ylabel=' ', xlabel='Predicted lab')
68 | 
69 | ax1.set(ylabel='Actual lab')
70 | plt.setp(ax1.xaxis.get_majorticklabels(), rotation=40)
71 | plt.setp(ax1.yaxis.get_majorticklabels(), rotation=40)
72 | plt.gca().invert_yaxis()
73 | plt.tight_layout()
74 | 
75 | plt.savefig(join(FIG_PATH, 'suppfig3_confusion_matrix_perf.pdf'))
76 | plt.savefig(join(FIG_PATH, 'suppfig3_confusion_matrix_perf.png'), dpi=300)
77 | plt.close(f)
78 | 
--------------------------------------------------------------------------------
/suppfig_variability_over_labs_first_biased.py:
--------------------------------------------------------------------------------
1 | """
2 | Plotting of behavioral metrics during the full task (biased blocks) per lab
3 | 
4 | Guido Meijer
5 | 6 May 2020
6 | """
7 | 
8 | import seaborn as sns
9 | import numpy as np
10 | from os.path import join
11 | import matplotlib.pyplot as plt
12 | from scipy import stats
13 | import scikit_posthocs as sp
14 | from paper_behavior_functions import (figpath, seaborn_style, group_colors, institution_map,
15 |                                       load_csv, FIGURE_WIDTH, FIGURE_HEIGHT, QUERY, fit_psychfunc,
16 |                                       dj2pandas)
17 | import pandas as pd
18 | from statsmodels.stats.multitest import multipletests
19 | 
20 | # Initialize
21 | seaborn_style()
22 | figpath = figpath()
23 | pal = group_colors()
24 | institution_map, col_names = institution_map()
25 | col_names = col_names[:-1]
26 | 
27 | # %% Process data
28 | 
29 | if QUERY is True:
30 |     # query sessions
31 |     from paper_behavior_functions import query_sessions_around_criterion
32 |     from ibl_pipeline import reference, subject, behavior
33 |     use_sessions, _ = query_sessions_around_criterion(criterion='biased',
34 |                                                       days_from_criterion=[-1, 3])
35 |     use_sessions = use_sessions & 'task_protocol LIKE "%biased%"'  # only get biased sessions
36 |     b = (use_sessions * subject.Subject * subject.SubjectLab * reference.Lab
37 |          * behavior.TrialSet.Trial)
38 |     b2 =
b.proj('institution_short', 'subject_nickname', 'task_protocol', 'session_uuid', 39 | 'trial_stim_contrast_left', 'trial_stim_contrast_right', 'trial_response_choice', 40 | 'task_protocol', 'trial_stim_prob_left', 'trial_feedback_type', 41 | 'trial_response_time', 'trial_stim_on_time') 42 | bdat = b2.fetch(order_by='institution_short, subject_nickname, session_start_time, trial_id', 43 | format='frame').reset_index() 44 | behav = dj2pandas(bdat) 45 | behav['institution_code'] = behav.institution_short.map(institution_map) 46 | else: 47 | behav = load_csv('Fig4.csv') 48 | 49 | biased_fits = pd.DataFrame() 50 | for i, nickname in enumerate(behav['subject_nickname'].unique()): 51 | if np.mod(i+1, 10) == 0: 52 | print('Processing data of subject %d of %d' % (i+1, 53 | len(behav['subject_nickname'].unique()))) 54 | 55 | # Get lab and subject uuid 56 | lab = behav.loc[behav['subject_nickname'] == nickname, 'institution_code'].unique()[0] 57 | uuid = behav.loc[behav['subject_nickname'] == nickname, 'subject_uuid'].unique()[0] 58 | 59 | # Fit psychometric curve 60 | left_fit = fit_psychfunc(behav[(behav['subject_nickname'] == nickname) 61 | & (behav['probabilityLeft'] == 80)]) 62 | right_fit = fit_psychfunc(behav[(behav['subject_nickname'] == nickname) 63 | & (behav['probabilityLeft'] == 20)]) 64 | neutral_fit = fit_psychfunc(behav[(behav['subject_nickname'] == nickname) 65 | & (behav['probabilityLeft'] == 50)]) 66 | perf_easy = (behav.loc[behav['subject_nickname'] == nickname, 'correct_easy'].mean()) * 100 67 | 68 | fits = pd.DataFrame(data={'perf_easy': perf_easy, 69 | 'threshold_l': left_fit['threshold'], 70 | 'threshold_r': right_fit['threshold'], 71 | 'threshold_n': neutral_fit['threshold'], 72 | 'bias_l': left_fit['bias'], 73 | 'bias_r': right_fit['bias'], 74 | 'bias_n': neutral_fit['bias'], 75 | 'nickname': nickname, 'lab': lab, 'subject_uuid': uuid}) 76 | biased_fits = biased_fits.append(fits, sort=False) 77 | 78 | 79 | # %% Statistics 80 | 81 | stats_tests = pd.DataFrame(columns=['variable', 'test_type', 'p_value']) 82 | posthoc_tests = {} 83 | 84 | for i, var in enumerate(['perf_easy', 'threshold_l', 'threshold_r', 'threshold_n', 85 | 'bias_l', 'bias_r', 'bias_n']): 86 | 87 | # Remove any animals with NaNs 88 | test_fits = biased_fits[biased_fits[var].notnull()] 89 | 90 | # Test for normality 91 | _, normal = stats.normaltest(test_fits[var]) 92 | 93 | if normal < 0.05: 94 | test_type = 'kruskal' 95 | test = stats.kruskal(*[group[var].values 96 | for name, group in test_fits.groupby('lab')]) 97 | if test[1] < 0.05: # Proceed to posthocs 98 | posthoc = sp.posthoc_dunn(test_fits, val_col=var, group_col='lab') 99 | else: 100 | posthoc = np.nan 101 | else: 102 | test_type = 'anova' 103 | test = stats.f_oneway(*[group[var].values 104 | for name, group in test_fits.groupby('lab')]) 105 | if test[1] < 0.05: 106 | posthoc = sp.posthoc_tukey(test_fits, val_col=var, group_col='lab') 107 | else: 108 | posthoc = np.nan 109 | 110 | posthoc_tests['posthoc_'+str(var)] = posthoc 111 | stats_tests.loc[i, 'variable'] = var 112 | stats_tests.loc[i, 'test_type'] = test_type 113 | stats_tests.loc[i, 'p_value'] = test[1] 114 | 115 | # Correct for multiple tests 116 | stats_tests['p_value'] = multipletests(stats_tests['p_value'])[1] 117 | 118 | # %% Prepare for plotting 119 | 120 | # Sort by lab number 121 | biased_fits = biased_fits.sort_values('lab') 122 | 123 | # Convert to float 124 | biased_fits[['perf_easy', 'bias_l', 'bias_r', 'bias_n', 125 | 'threshold_l', 'threshold_r', 'threshold_n']] = biased_fits[ 
78 | 
79 | # %% Statistics
80 | 
81 | stats_tests = pd.DataFrame(columns=['variable', 'test_type', 'p_value'])
82 | posthoc_tests = {}
83 | 
84 | for i, var in enumerate(['perf_easy', 'threshold_l', 'threshold_r', 'threshold_n',
85 |                          'bias_l', 'bias_r', 'bias_n']):
86 | 
87 |     # Remove any animals with NaNs
88 |     test_fits = biased_fits[biased_fits[var].notnull()]
89 | 
90 |     # Test for normality
91 |     _, p_normal = stats.normaltest(test_fits[var])
92 | 
93 |     if p_normal < 0.05:  # not normally distributed: use a non-parametric test
94 |         test_type = 'kruskal'
95 |         test = stats.kruskal(*[group[var].values
96 |                                for name, group in test_fits.groupby('lab')])
97 |         if test[1] < 0.05:  # Proceed to posthocs
98 |             posthoc = sp.posthoc_dunn(test_fits, val_col=var, group_col='lab')
99 |         else:
100 |             posthoc = np.nan
101 |     else:
102 |         test_type = 'anova'
103 |         test = stats.f_oneway(*[group[var].values
104 |                                 for name, group in test_fits.groupby('lab')])
105 |         if test[1] < 0.05:
106 |             posthoc = sp.posthoc_tukey(test_fits, val_col=var, group_col='lab')
107 |         else:
108 |             posthoc = np.nan
109 | 
110 |     posthoc_tests['posthoc_'+str(var)] = posthoc
111 |     stats_tests.loc[i, 'variable'] = var
112 |     stats_tests.loc[i, 'test_type'] = test_type
113 |     stats_tests.loc[i, 'p_value'] = test[1]
114 | 
115 | # Correct for multiple tests (Holm-Sidak, the multipletests default)
116 | stats_tests['p_value'] = multipletests(stats_tests['p_value'])[1]
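# --- Illustrative sketch, not part of the original script --------------------
# The Statistics block above picks its test from a normality check (ANOVA +
# Tukey when stats.normaltest passes, Kruskal-Wallis + Dunn otherwise) and
# then Holm-Sidak-corrects the per-metric p-values. The same decision logic
# on synthetic data, as a self-contained toy:
import numpy as np
import pandas as pd
import scikit_posthocs as sp
from scipy import stats
from statsmodels.stats.multitest import multipletests

rng = np.random.RandomState(0)
toy = pd.DataFrame({'value': rng.normal(size=90),
                    'lab': np.repeat(['A', 'B', 'C'], 30)})

_, p_normal = stats.normaltest(toy['value'])
groups = [g['value'].values for _, g in toy.groupby('lab')]
if p_normal < 0.05:                                   # non-parametric branch
    _, p = stats.kruskal(*groups)
    posthoc = sp.posthoc_dunn(toy, val_col='value', group_col='lab')
else:                                                 # parametric branch
    _, p = stats.f_oneway(*groups)
    posthoc = sp.posthoc_tukey(toy, val_col='value', group_col='lab')

# The script above only runs the posthoc when p < 0.05; corrected p-values
# come from multipletests, whose default method is Holm-Sidak ('hs')
p_corrected = multipletests([p])[1]
# ------------------------------------------------------------------------------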
117 | 
118 | # %% Prepare for plotting
119 | 
120 | # Sort by lab number
121 | biased_fits = biased_fits.sort_values('lab')
122 | 
123 | # Convert to float
124 | biased_fits[['perf_easy', 'bias_l', 'bias_r', 'bias_n',
125 |              'threshold_l', 'threshold_r', 'threshold_n']] = biased_fits[
126 |     ['perf_easy', 'bias_l', 'bias_r', 'bias_n', 'threshold_l',
127 |      'threshold_r', 'threshold_n']].astype(float)
128 | 
129 | # Add all mice to the dataframe separately for plotting
130 | learned_no_all = biased_fits.copy()
131 | #learned_no_all.loc[learned_no_all.shape[0] + 1, 'lab'] = 'All'
132 | learned_2 = biased_fits.copy()
133 | learned_2['lab'] = 'All'
134 | learned_2 = pd.concat([biased_fits, learned_2])  # per-lab rows plus an 'All' copy
135 | 
136 | # %%
137 | # Plot behavioral metrics per lab
138 | lab_colors = group_colors()
139 | sns.set_palette(lab_colors)
140 | seaborn_style()
141 | 
142 | vars = ['perf_easy',
143 |         'bias_n', 'bias_l', 'bias_r',
144 |         'threshold_n', 'threshold_l', 'threshold_r']
145 | ylabels = ['Performance (%)\non easy trials',
146 |            'Bias (%)\n50:50 blocks',
147 |            'Bias (%)\n20:80 blocks',
148 |            'Bias (%)\n80:20 blocks',
149 |            'Contrast threshold (%)\n50:50 blocks',
150 |            'Contrast threshold (%)\n20:80 blocks',
151 |            'Contrast threshold (%)\n80:20 blocks']
152 | ylims = [[70, 101], [-30, 30], [-30, 30], [-30, 30],
153 |          [0, 45], [0, 45], [0, 45]]
154 | 
155 | plt.close('all')
156 | for v, ylab, ylim in zip(vars, ylabels, ylims):
157 | 
158 |     f, ax = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/4.5, FIGURE_HEIGHT))
159 | 
160 |     sns.swarmplot(y=v, x='lab', data=learned_no_all, hue='lab',
161 |                   palette=lab_colors, ax=ax, marker='.')
162 |     axbox = sns.boxplot(y=v, x='lab', data=learned_2, color='white',
163 |                         showfliers=False, ax=ax)
164 |     ax.set(ylabel=ylab, ylim=ylim, xlabel='')
165 |     # [tick.set_color(lab_colors[i]) for i, tick in enumerate(ax1.get_xticklabels()[:-1])]
166 |     plt.setp(ax.xaxis.get_majorticklabels(), rotation=60)
167 |     axbox.artists[-1].set_edgecolor('black')  # outline the 'All' box in black
168 |     for j in range(5 * (len(axbox.artists) - 1), 5 * len(axbox.artists)):
169 |         axbox.lines[j].set_color('black')  # whiskers, caps and median of the 'All' box
170 |     ax.get_legend().set_visible(False)
171 | 
172 |     plt.tight_layout()
173 |     sns.despine(trim=True)
174 |     plt.savefig(join(figpath, 'suppfig_metrics_per_lab_first_biased_%s.pdf' % v))
175 |     plt.savefig(join(figpath, 'suppfig_metrics_per_lab_first_biased_%s.png' % v), dpi=300)
176 | 
--------------------------------------------------------------------------------
/suppfig_variability_over_time.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Quantify the within-lab and between-lab variability of behavioral metrics in mouse behavior.
5 | This script performs no new analysis; it only plots summary statistics across labs.
6 | 
7 | Guido Meijer, Miles Wells
8 | 16 Jan 2020
9 | """
10 | 
11 | import pandas as pd
12 | import matplotlib.pyplot as plt
13 | import matplotlib.ticker as ticker
14 | import numpy as np
15 | import seaborn as sns
16 | from os.path import join
17 | from paper_behavior_functions import (seaborn_style, institution_map, group_colors, figpath,
18 |                                       query_subjects, FIGURE_WIDTH, FIGURE_HEIGHT, QUERY, load_csv)
19 | from ibl_pipeline.analyses import behavior as behavior_analysis
20 | 
21 | # Settings
22 | fig_path = figpath()
23 | bin_centers = np.arange(3, 40, 3)
24 | bin_size = 5
25 | seaborn_style()
26 | 
27 | # Load in data
28 | if QUERY:
29 |     use_subjects = query_subjects()
30 |     behav = (use_subjects * behavior_analysis.BehavioralSummaryByDate).fetch(format='frame')
31 |     behav['lab'] = behav['institution_short']
32 |     behav['lab_number'] = behav.lab.map(institution_map()[0])
33 | else:
34 |     behav = load_csv('suppfig_variability.pkl.bz2')
35 | 
36 | # Get variability over days
37 | mean_days = pd.DataFrame(columns=bin_centers, index=np.unique(behav['lab_number']))
38 | std_days = pd.DataFrame(columns=bin_centers, index=np.unique(behav['lab_number']))
39 | for i, day in enumerate(bin_centers):
40 |     this_behav = behav[(behav['training_day'] > day - np.floor(bin_size / 2))
41 |                        & (behav['training_day'] < day + np.floor(bin_size / 2))]
42 |     mean_days[day] = this_behav.groupby('lab_number').mean()['performance_easy']
43 |     std_days[day] = this_behav.groupby('lab_number').std()['performance_easy']
44 | 
45 | # Plot output
46 | 
47 | colors = group_colors()
48 | f, (ax1, ax2) = plt.subplots(1, 2, figsize=(FIGURE_WIDTH*0.7, FIGURE_HEIGHT))
49 | for i, lab in enumerate(std_days.index.values):
50 |     ax1.plot(std_days.loc[lab], color=colors[i], lw=2, label='Lab %s' % (i + 1))
51 | #ax1.legend(frameon=False, loc='lower center', ncol=3, bbox_to_anchor=(0.5, 1))
52 | ax1.set(xlabel='Training days', ylabel='Variability (std)', title='Within labs')
53 | ax1.set(xlim=[0, 40])
54 | ax2.plot(mean_days.std(), lw=2)
55 | ax2.set(xlabel='Training days', ylabel='Variability (std)', title='Between labs')
56 | ax2.set(xlim=[0, 40])
57 | 
58 | sns.despine(trim=True)
59 | plt.tight_layout()
60 | plt.savefig(join(fig_path, 'suppfig4_variability_over_time.pdf'))
61 | plt.savefig(join(fig_path, 'suppfig4_variability_over_time.png'), dpi=300)
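# --- Illustrative sketch, not part of the original script --------------------
# The two panels above summarise variability in complementary ways: 'within
# labs' is the std of performance across the mice inside each lab (one line
# per lab), 'between labs' is the std of the per-lab means (a single line).
# The same two quantities on a toy frame shaped like `behav`:
import numpy as np
import pandas as pd

rng = np.random.RandomState(1)
toy = pd.DataFrame({'lab_number': np.repeat(['Lab 1', 'Lab 2', 'Lab 3'], 20),
                    'performance_easy': rng.normal(0.8, 0.05, size=60)})

within_lab = toy.groupby('lab_number')['performance_easy'].std()          # one value per lab
between_lab = toy.groupby('lab_number')['performance_easy'].mean().std()  # one scalar
# ------------------------------------------------------------------------------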
62 | 
63 | ### The same but for trials ###
64 | 
65 | # Settings
66 | bin_size = 1000
67 | bin_centers = np.arange(1000, 30001, bin_size)
68 | 
69 | # Create column for cumulative trials per mouse
70 | behav.n_trials_date = behav.n_trials_date.astype(int)
71 | behav['cum_trials'] = (
72 |     (behav
73 |      .groupby(by=['subject_uuid'])
74 |      .cumsum()
75 |      .n_trials_date)
76 | )
77 | 
78 | # Get variability over trials
79 | mean_trials = pd.DataFrame(columns=bin_centers, index=np.unique(behav['lab_number']))
80 | std_trials = pd.DataFrame(columns=bin_centers, index=np.unique(behav['lab_number']))
81 | for i, tt in enumerate(bin_centers):
82 |     this_behav = behav[(behav['cum_trials'] > tt - np.floor(bin_size / 2))
83 |                        & (behav['cum_trials'] < tt + np.floor(bin_size / 2))]
84 |     mean_trials[tt] = this_behav.groupby('lab_number').mean()['performance_easy']
85 |     std_trials[tt] = this_behav.groupby('lab_number').std()['performance_easy']
86 | 
87 | # Plot output
88 | 
89 | xlim = [0, 30000]
90 | f, (ax1, ax2) = plt.subplots(1, 2, figsize=(FIGURE_WIDTH * 0.7, FIGURE_HEIGHT))
91 | for i, lab in enumerate(std_trials.index.values):
92 |     ax1.plot(std_trials.loc[lab], color=colors[i], lw=2, label='Lab %s' % (i + 1))
93 | ax1.set(xlabel='Trials', ylabel='Variability (std)', title='Within labs')
94 | ax1.set(xlim=xlim)
95 | ax2.plot(mean_trials.std(), lw=2)
96 | ax2.set(xlabel='Trials', ylabel='Variability (std)', title='Between labs')
97 | ax2.set(xlim=xlim)
98 | 
99 | sns.despine(trim=True, offset=5)
100 | format_fcn = ticker.FuncFormatter(lambda x, pos: '{:,.0f}'.format(x / 1e3) + 'K')  # e.g. 10000 -> '10K'
101 | [x.xaxis.set_major_formatter(format_fcn) for x in (ax1, ax2)]
102 | plt.tight_layout()
103 | plt.savefig(join(fig_path, 'suppfig4_variability_over_trials.pdf'))
104 | plt.savefig(join(fig_path, 'suppfig4_variability_over_trials.png'), dpi=300)
105 | 
--------------------------------------------------------------------------------
/text_trained1a_to_1b_sessions.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Time from trained_1a to trained_1b
5 | 
6 | @author: Gaelle Chapuis
7 | Jan 2021
8 | """
9 | from os.path import join
10 | 
11 | import pandas as pd
12 | import numpy as np
13 | from datetime import datetime
14 | 
15 | from paper_behavior_functions import (query_subjects, datapath, QUERY)
16 | from ibl_pipeline.analyses import behavior as behavior_analysis
17 | 
18 | # Date at which trained_1b was implemented in the DJ pipeline
19 | DATE_IMPL = datetime.strptime('12-09-2019', '%d-%m-%Y').date()
20 | 
21 | # Query data
22 | if QUERY is True:
23 |     # Query sessions
24 |     use_subjects = query_subjects()
25 |     ses = ((use_subjects * behavior_analysis.SessionTrainingStatus * behavior_analysis.PsychResults
26 |             & 'training_status = "trained_1a" OR training_status = "trained_1b"')
27 |            .proj('subject_nickname', 'n_trials_stim', 'institution_short', 'training_status')
28 |            .fetch(format='frame')
29 |            .reset_index())
30 |     ses['n_trials'] = [sum(i) for i in ses['n_trials_stim']]
31 | else:
32 |     ses = pd.read_csv(join(datapath(), 'Fig2c.csv'))
33 |     use_subjects = ses['subject_uuid'].unique()  # For counting the number of subjects
34 | 
35 | ses = ses.sort_values(by=['subject_uuid', 'session_start_time'])
36 | uni_sub = np.unique(ses['subject_uuid'])
37 | 
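# --- Illustrative sketch, not part of the original script --------------------
# The loop below compares session *dates* rather than full timestamps by
# casting to day precision, and only keeps subjects whose first trained_1b
# session falls after DATE_IMPL. The casting trick in isolation:
import numpy as np

t1 = np.datetime64('2020-01-05T09:30')
t2 = np.datetime64('2020-01-05T18:00')
same_day = t1.astype('datetime64[D]') == t2.astype('datetime64[D]')  # True
# ------------------------------------------------------------------------------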
38 | training_time = pd.DataFrame(columns=['sessions'])
39 | # Loop over subjects
40 | for subj in uni_sub:
41 | 
42 | 
43 |     # Construct dataframe
44 |     df = ses.loc[ses['subject_uuid'] == subj]
45 |     if len(np.unique(df['training_status'])) == 2:  # subject reached both 1a and 1b
46 | 
47 |         # Check that the session start date is different for when reaching 1a/1b
48 |         df = df.sort_values(by=['session_start_time'])  # Ensure data is sorted by date
49 | 
50 |         # Find index of relevant session
51 |         indx_a = np.where(df['training_status'] == 'trained_1a')[0]
52 |         n_row_a = indx_a[-1]  # last session with trained_1a
53 |         indx_b = np.where(df['training_status'] == 'trained_1b')[0]
54 |         n_row_b = indx_b[0]  # first session with trained_1b
55 |         if n_row_a+1 != n_row_b:
56 |             print('ERROR: trained_1a and trained_1b sessions of %s are not consecutive' % subj)
57 |         # Get and compare dates
58 |         date_a = df.iloc[[n_row_a]]['session_start_time'].values
59 |         date_a = date_a.astype('datetime64[D]')
60 |         date_b = df.iloc[[n_row_b]]['session_start_time'].values
61 |         date_b = date_b.astype('datetime64[D]')
62 |         if date_a != date_b and date_b > DATE_IMPL:
63 |             # Print for debugging purposes
64 |             # print(f'trained_1b: {date_b}, subject uuid: {subj}')
65 |             # Aggregate and append the N sessions done under trained_1a
66 |             training_time_ab = pd.DataFrame(columns=['sessions'],
67 |                                             data=df.groupby(['training_status']).size())
68 |             training_time = pd.concat([training_time, training_time_ab.loc[['trained_1a']]])
69 | 
70 | # Training time as a whole (N sessions in trained_1a before reaching trained_1b)
71 | m_train = training_time['sessions'].mean()
72 | s_train = training_time['sessions'].std()
73 | slowest = training_time['sessions'].max()
74 | fastest = training_time['sessions'].min()
75 | 
76 | n_mice = len(training_time)
77 | print(f'Using impl. date {DATE_IMPL}: {n_mice} mice, n sessions from 1a to 1b: {round(m_train, 2)} ± {round(s_train, 2)}')
78 | 
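# --- Illustrative sketch, not part of the original script --------------------
# A compact toy version of the counting rule implemented above: for a subject
# with both statuses and consecutive sessions, the reported number is simply
# the size of the 'trained_1a' group. Column names mirror the query output.
import pandas as pd

toy = pd.DataFrame({
    'subject_uuid': ['mouse-1'] * 5,
    'session_start_time': pd.date_range('2020-01-01', periods=5),
    'training_status': ['trained_1a'] * 3 + ['trained_1b'] * 2,
})
n_1a = toy.groupby('training_status').size().loc['trained_1a']  # -> 3
# ------------------------------------------------------------------------------
--------------------------------------------------------------------------------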