├── .github
│   └── ISSUE_TEMPLATE
│       └── bug.md
├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── neural
│   ├── __init__.py
│   ├── __main__.py
│   ├── dataset.py
│   ├── extraction.py
│   ├── linear
│   │   ├── __init__.py
│   │   ├── __main__.py
│   │   ├── arx.py
│   │   ├── lin_model_template.py
│   │   ├── receptive_field.py
│   │   └── stats.py
│   ├── model.py
│   ├── train.py
│   ├── utils.py
│   ├── utils_mous.py
│   └── visuals.py
├── requirements.txt
└── setup.cfg
/.github/ISSUE_TEMPLATE/bug.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🐛 Bug Report 3 | about: Submit a bug report to help us improve 4 | labels: 'bug' 5 | --- 6 | 7 | ## 🐛 Bug Report 8 | 9 | (A clear and concise description of what the bug is.) 10 | 11 | ## To Reproduce 12 | 13 | (Write your steps here:) 14 | 15 | 1. Step 1... 16 | 1. Step 2... 17 | 1. Step 3... 18 | 19 | ## Expected Behavior 20 | 21 | (Write what you thought would happen.) 22 | 23 | ## Actual Behavior 24 | 25 | (Write what happened. Add screenshots, if applicable.) 26 | 27 | ## Your Environment 28 | 29 | 30 | 31 | - Python and PyTorch version: 32 | - Operating system and version (desktop or mobile): 33 | - Hardware (GPU or CPU, amount of RAM, etc.): 34 | 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pkl 2 | *.png 3 | *.hdf5 4 | dump/ 5 | __pycache__ 6 | .DS_Store 7 | .vscode/ 8 | cache 9 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior.
39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to DeepMEG-Encoding 2 | 3 | ## Pull Requests 4 | 5 | In order to accept your pull request, we need you to submit a CLA. You only need 6 | to do this once to work on any of Facebook's open source projects. 7 | 8 | Complete your CLA here: 9 | 10 | DeepMEG-Encoding is the implementation of a research paper. 11 | Therefore, we do not plan on accepting many pull requests for new features. 12 | We certainly welcome them for bug fixes. 13 | 14 | 15 | ## Issues 16 | 17 | We use GitHub issues to track public bugs. Please ensure your description is 18 | clear and includes sufficient instructions to reproduce the issue. 19 | Please first check existing issues as well as the README for existing solutions. 20 | 21 | 22 | ## License 23 | By contributing to this repository, you agree that your contributions will be licensed 24 | under the LICENSE file in the root directory of this source tree. 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice.
Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 
69 | 70 | Section 1 -- Definitions. 71 | 72 | a. Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. 
Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 196 | 197 | b. Other rights. 198 | 199 | 1. Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. 
Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 
291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. 
To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. No term or condition of this Public License will be waived and no 372 | failure to comply consented to unless expressly agreed to by the 373 | Licensor. 374 | 375 | d. Nothing in this Public License constitutes or may be interpreted 376 | as a limitation upon, or waiver of, any privileges and immunities 377 | that apply to the Licensor or You, including from the legal 378 | processes of any jurisdiction or authority. 379 | 380 | ======================================================================= 381 | 382 | Creative Commons is not a party to its public 383 | licenses. Notwithstanding, Creative Commons may elect to apply one of 384 | its public licenses to material it publishes and in those instances 385 | will be considered the “Licensor.” The text of the Creative Commons 386 | public licenses is dedicated to the public domain under the CC0 Public 387 | Domain Dedication. Except for the limited purpose of indicating that 388 | material is shared under a Creative Commons public license or as 389 | otherwise permitted by the Creative Commons policies published at 390 | creativecommons.org/policies, Creative Commons does not authorize the 391 | use of the trademark "Creative Commons" or any other trademark or logo 392 | of Creative Commons without its prior written consent including, 393 | without limitation, in connection with any unauthorized modifications 394 | to any of its public licenses or any other arrangements, 395 | understandings, or agreements concerning use of licensed material. For 396 | the avoidance of doubt, this paragraph does not form part of the 397 | public licenses. 398 | 399 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepMEG-Encoding project 2 | 3 | This project investigates models that forecast MEG data from past MEG activity and external stimuli. 4 | These models range from linear to nonlinear, with or without access to the initial brain state, and culminate in our Deep Recurrent Encoder (DRE) architecture. 5 | The DRE outperforms current methods and trains across subjects simultaneously. 6 | An ablation study yields insight into the modules that best explain its predictive performance. 7 | A simple feature importance analysis helps interpret what the deep architecture learns. 8 | 9 | Predictive Performance | Feature Importance 10 | :-------------------------:|:-------------------------: 11 | ![](https://user-images.githubusercontent.com/37180957/109517969-1c22be00-7aaa-11eb-9511-7301c27bf0ac.png) | ![](https://user-images.githubusercontent.com/37180957/109518451-89ceea00-7aaa-11eb-8124-cbaeec97d29c.png) 12 | 13 | 14 | ## General information 15 | 16 | You will need Python >= 3.7 to use this code.
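Optionally, work inside an isolated environment first. A minimal sketch, assuming a Unix-like shell (the environment name `env` is an arbitrary choice, not part of the project):
```
python3 -m venv env
source env/bin/activate
```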
17 | 18 | Install Python package requirements with: 19 | ``` 20 | pip install -r requirements.txt 21 | ``` 22 | 23 | Find help with: 24 | ``` 25 | python3 -m neural --help 26 | ``` 27 | 28 | ## Data extraction 29 | 30 | First, download the MOUS dataset from https://data.donders.ru.nl/collections/di/dccn/DSC_3011020.09_236?3 31 | To do so, you may need to register via your ORCID account. 32 | 33 | To extract MEG and stimuli from the MOUS dataset: 34 | 35 | ``` 36 | python3 -m neural.extraction --data /path/to/dataset --out /path/to/extraction 37 | ``` 38 | 39 | **Notes:** 40 | - This script will create the folder `/path/to/extraction/full`. 41 | - This step takes a few hours and requires at least 90 GB of disk space in the extraction folder. 42 | - This step does not need a GPU to run. 43 | - Setting `--n-subjects 1` performs the extraction for one subject only. This is useful 44 | if you have limited disk space and want to test quickly. 45 | - Setting `--use-pca` extracts 40 components from the MEG, and the extraction folder becomes `/path/to/extraction/40`. This requires less disk space (about 15 GB). 46 | 47 | Then, proceed to train encoding models. 48 | 49 | ## Train the Deep Recurrent Encoder (DRE) 50 | 51 | To train the DRE: 52 | 53 | ``` 54 | python3 -m neural --data /path/to/extraction --out /path/to/metrics 55 | ``` 56 | 57 | Note that the training script appends the `full` or PCA-dimension subfolder to the `--data` path automatically, based on the `--pca` flag. The ablations of the DRE (respectively "NO-CONV", "PCA", "NO-SUB", "NO-INIT") in the paper were trained using: 58 | ``` 59 | python3 -m neural --data /path/to/extraction --out /path/to/metrics --epochs 40 --conv-layers=0 60 | python3 -m neural --data /path/to/extraction --out /path/to/metrics --pca 40 61 | python3 -m neural --data /path/to/extraction --out /path/to/metrics --subject-dim=0 62 | python3 -m neural --data /path/to/extraction --out /path/to/metrics --meg-init 0 63 | ``` 64 | 65 | ## Train the Linear Encoders (TRF, RTRF) 66 | 67 | ``` 68 | python3 -m neural.linear --data /path/to/extraction --with-forcing --with-init --shuffle --out /path/to/metrics 69 | ``` 70 | 71 | ## License 72 | 73 | This repository is released under the CC-BY-NC 4.0 license, as found in the [LICENSE](LICENSE) file. 74 | 75 | -------------------------------------------------------------------------------- /neural/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /neural/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
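'''Command-line entry point for training the Deep Recurrent Encoder (DRE).

Overview of the pipeline implemented below: parse the arguments, load the
extracted per-subject datasets, train a MegPredictor across all subjects while
checkpointing to <out>/saved.th, keep the weights with the best validation
loss, evaluate Pearson correlations on the test set after inverting the
per-subject scaling and PCA, and estimate feature importance by shuffling
each input feature and measuring the relative loss increase.
'''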
6 | 7 | import argparse 8 | import json 9 | import random 10 | import shutil 11 | from dataclasses import dataclass, field 12 | from functools import partial 13 | from pathlib import Path 14 | 15 | import numpy as np 16 | import torch as th 17 | from torch import nn 18 | from torch.utils.data import ConcatDataset 19 | import matplotlib.pyplot as plt 20 | import pandas as pd 21 | 22 | from .dataset import load_torch_megs 23 | from .model import MegPredictor 24 | from .train import train_eval_model 25 | from .utils import get_metrics, inverse 26 | from .visuals import report_correl 27 | 28 | 29 | def get_parser(): 30 | parser = argparse.ArgumentParser("neural", description="Train MEG predictor using forcings") 31 | parser.add_argument( 32 | "-o", "--out", type=Path, default=Path("dump"), 33 | help="Folder where checkpoints and metrics are saved.") 34 | parser.add_argument( 35 | "-R", "--restart", action='store_true', help='Restart training, ignoring previous run') 36 | parser.add_argument("--seed", type=int, default=42, help="Random seed") 37 | 38 | # Dataset related 39 | parser.add_argument( 40 | "-d", "--data", type=Path, 41 | required=True, 42 | help="Path to the data extracted.") 43 | parser.add_argument("-s", "--subjects", type=int, default=68, 44 | help="Maximum number of subjects.") 45 | parser.add_argument("--pca", type=int, help="Use PCA version of the data. " 46 | "Should be the dimension of the PCA used.") 47 | parser.add_argument("-x", "--exclude", action="append", default=[], help="Exclude features") 48 | parser.add_argument("-i", "--include", action="append", default=[], help="Include features") 49 | 50 | # Optimization parameters 51 | parser.add_argument("-e", "--epochs", type=int, default=60, 52 | help="Number of epochs to train for.") 53 | parser.add_argument("-b", "--batch-size", type=int, default=32) 54 | parser.add_argument("--lr", type=float, default=1e-4) 55 | parser.add_argument("--l1", action="store_true", help="Use L1 loss instead of MSE") 56 | 57 | # Parameters to the model 58 | parser.add_argument("--conv-layers", type=int, default=2, 59 | help="Number of convolution layers in the encoder/decoder.") 60 | parser.add_argument("--lstm-layers", type=int, default=2, 61 | help="Number of LSTM layers.") 62 | parser.add_argument("--conv-channels", type=int, default=512, 63 | help="Output channels for convolutions.") 64 | parser.add_argument("--lstm-hidden", type=int, default=512, 65 | help="Hidden dimension of the LSTM.") 66 | parser.add_argument("--subject-dim", type=int, default=16, 67 | help="Dimension of the subject embedding.") 68 | 69 | # Other parameters 70 | parser.add_argument("--meg-init", type=int, default=40, 71 | help="Number of MEG time steps to provide as basal state.") 72 | parser.add_argument("--no-forcings", action="store_false", dest="forcings", default=True, 73 | help="Remove all forcings from the input.") 74 | parser.add_argument("--save-meg", action="store_true", 75 | help="Save full MEG output for each subject.") 76 | 77 | return parser 78 | 79 | 80 | def make_repo_from_parser(args, parser): 81 | args.out.mkdir(exist_ok=True) 82 | 83 | parts = [] 84 | name_args = dict(args.__dict__) 85 | ignore = ["restart", "data", "out"] 86 | for key in ignore: 87 | name_args.pop(key, None) 88 | for name, value in name_args.items(): 89 | if value != parser.get_default(name): 90 | if isinstance(value, Path): 91 | value = value.name 92 | elif isinstance(value, list): 93 | value = ",".join(map(str, value)) 94 | parts.append(f"{name}={value}") 95 | if parts: 96 | name 
= " ".join(parts) 97 | else: 98 | name = "default" 99 | print(f"Experiment {name}") 100 | 101 | if args.pca is not None: 102 | suffix = f"{args.pca}" 103 | else: 104 | suffix = "full" 105 | 106 | # args.data = args.data.with_name(args.data.name + suffix) 107 | args.data = args.data / Path(suffix) 108 | print("Using dataset", args.data) 109 | 110 | out = args.out / name 111 | if args.restart and out.exists(): 112 | shutil.rmtree(out) 113 | out.mkdir(exist_ok=True) 114 | 115 | return out 116 | 117 | 118 | @dataclass 119 | class SavedState: 120 | metrics: list = field(default_factory=list) 121 | state: dict = None 122 | best_state: dict = None 123 | 124 | 125 | def main(): 126 | 127 | # Make repository 128 | parser = get_parser() 129 | args = parser.parse_args() 130 | out = make_repo_from_parser(args, parser) 131 | 132 | # Set seed and device 133 | device = "cuda" if th.cuda.is_available() else "cpu" 134 | print(f"Using device: {device}") 135 | th.manual_seed(args.seed) 136 | random.seed(args.seed) 137 | 138 | # Load data 139 | meg_sets = load_torch_megs(args.data, args.subjects, exclude=args.exclude, include=args.include) 140 | train_set = ConcatDataset(meg_sets.train_sets) 141 | valid_set = ConcatDataset(meg_sets.valid_sets) 142 | test_set = ConcatDataset(meg_sets.test_sets) 143 | 144 | # Instantiate model 145 | model = MegPredictor( 146 | meg_dim=meg_sets.meg_dim, 147 | forcing_dims=meg_sets.forcing_dims if args.forcings else {}, 148 | meg_init=args.meg_init, 149 | subject_dim=args.subject_dim, 150 | conv_layers=args.conv_layers, 151 | conv_channels=args.conv_channels, 152 | lstm_layers=args.lstm_layers, 153 | lstm_hidden=args.lstm_hidden).to(device) 154 | 155 | # Instantiate optimization 156 | optimizer = th.optim.Adam(model.parameters(), lr=args.lr) 157 | criterion = nn.L1Loss() if args.l1 else nn.MSELoss() 158 | train_eval = partial( 159 | train_eval_model, 160 | model=model, 161 | optimizer=optimizer, 162 | device=device, 163 | criterion=criterion, 164 | batch_size=args.batch_size) 165 | 166 | try: 167 | saved = th.load(out / "saved.th") 168 | except IOError: 169 | saved = SavedState() 170 | else: 171 | model.load_state_dict(saved.state) 172 | 173 | best_loss = float("inf") 174 | for epoch, metric in enumerate(saved.metrics): 175 | print(f"Epoch {epoch:04d}: " 176 | f"train={metric['train']:.4f} test={metric['valid']:.6f} best={metric['best']:.6f}") 177 | best_loss = metric['best'] 178 | 179 | # Train and Evaluate (valid set) the model 180 | # from where you left off 181 | # select best model over the epochs on valid set` 182 | print("Training and Validation...") 183 | for epoch in range(len(saved.metrics), args.epochs): 184 | train_loss, _ = train_eval(train_set) 185 | with th.no_grad(): 186 | valid_loss, evals = train_eval(valid_set, train=False, save=True) 187 | best_loss = min(valid_loss, best_loss) 188 | saved.metrics.append({ 189 | "train": train_loss, 190 | "valid": valid_loss, 191 | "best": best_loss, 192 | }) 193 | print(f"Epoch {epoch:04d}: " 194 | f"train={train_loss:.4f} valid={valid_loss:.6f} best={best_loss:.6f}") 195 | if valid_loss == best_loss: 196 | saved.best_state = { 197 | key: value.to("cpu").clone() 198 | for key, value in model.state_dict().items() 199 | } 200 | th.save(model, out / "model.th") 201 | json.dump(saved.metrics, open(out / "metrics.json", "w"), indent=2) 202 | saved.state = {key: value.to("cpu") for key, value in model.state_dict().items()} 203 | th.save(saved, out / "saved.th") 204 | 205 | # Save train-valid curve 206 | tmp = pd.read_json(out / 
"metrics.json") 207 | fig, ax = plt.subplots() 208 | ax.plot(tmp['train'], color='red', label='train') 209 | ax.plot(tmp['valid'], color='green', label='valid') 210 | ax.set_ylabel('Train Loss') 211 | ax.legend() 212 | fig.savefig(out / "train_valid_curve.png") 213 | 214 | # Load best model (on the valid set) 215 | json.dump(saved.metrics, open(out / "metrics.json", "w"), indent=2) 216 | model.load_state_dict(saved.best_state) 217 | 218 | # Evaluate (test set) the model 219 | with th.no_grad(): 220 | print("Evaluating model on test set...") 221 | 222 | # Reference evaluation 223 | ref_loss, ref_evals = train_eval(test_set, train=False, save=True) 224 | print("Ref loss", ref_loss) 225 | 226 | # Trim true and predicted meg 227 | # to a common time length (in case of an excess index) 228 | min_length = ref_evals.lengths.min().item() 229 | megs = ref_evals.megs[:, :, :min_length] 230 | ref_predictions = ref_evals.predictions[:, :, :min_length] 231 | 232 | # Reformat true and predicted meg: back to [N, T, C] 233 | megs = megs.transpose(1, 2) 234 | ref_predictions = ref_predictions.transpose(1, 2) 235 | 236 | # Loop over subjects 237 | ordered_subjects = ref_evals.subjects.unique().sort()[0] 238 | scores = list() 239 | 240 | for sub in ordered_subjects: 241 | sub_sel = (ref_evals.subjects == sub).nonzero().flatten() 242 | 243 | Y_true, Y_pred = megs[sub_sel], ref_predictions[sub_sel] 244 | 245 | # Load the necessary to reverse PCA 246 | pca_mat = meg_sets.pca_mats[sub] 247 | mean = meg_sets.means[sub] 248 | scaler = meg_sets.meg_scalers[sub] 249 | 250 | # Reverse PCA 251 | Y_true = inverse(mean, scaler, pca_mat, Y_true.double().numpy()) 252 | Y_pred = inverse(mean, scaler, pca_mat, Y_pred.double().numpy()) 253 | 254 | if args.save_meg: 255 | # Save prediction sample from all subjects [N, T, C] 256 | print(f"Saving MEG pred for sub {sub}...") 257 | th.save({"meg_pred_epoch": Y_pred[0], 258 | "meg_true_epoch": Y_true[0], 259 | "meg_pred_evoked": Y_pred.mean(0), 260 | "meg_true_evoked": Y_true.mean(0)}, 261 | out / f"meg_prediction_subject_{sub}.th") 262 | 263 | # Correlation metrics between true and predicted meg, shape [N, T, C] 264 | score = get_metrics(Y_true, Y_pred) 265 | 266 | scores.append(score) 267 | 268 | scores = np.stack(scores) # [S, T, C] 269 | print("Average prediction score (Pearson R): ", scores[:, 60:].mean()) 270 | 271 | # Save results 272 | th.save({"scores": scores}, 273 | out / "reference_metrics.th") 274 | 275 | # Shuffled-feature evaluations 276 | for name in list(meg_sets.forcing_dims) + ["meg", "subject"]: 277 | if name in ["word_onsets", "first_mask"]: 278 | continue 279 | test_loss, evals = train_eval(test_set, train=False, save=True, permut_feature=name) 280 | 281 | delta = (test_loss - ref_loss) / ref_loss 282 | print("Shuffled", name, "relative loss increase", 100 * delta) 283 | predictions = evals.predictions[:, :, :min_length].transpose(1, 2) 284 | assert (evals.megs == ref_evals.megs).all() 285 | report_correl( 286 | megs, 287 | predictions, 288 | out / f"feature_importance_{name}_correl_all.png", 289 | ref=ref_predictions, 290 | start=60) 291 | 292 | # Loop over subjects 293 | scores = list() 294 | 295 | for sub in ordered_subjects: 296 | sub_sel = (ref_evals.subjects == sub).nonzero().flatten() 297 | 298 | Y_true, Y_pred = megs[sub_sel], predictions[sub_sel] 299 | 300 | # Load the necessary to reverse PCA 301 | pca_mat = meg_sets.pca_mats[sub] 302 | mean = meg_sets.means[sub] 303 | scaler = meg_sets.meg_scalers[sub] 304 | 305 | # Reverse PCA 306 | Y_true = 
inverse(mean, scaler, pca_mat, Y_true.numpy()) 307 | Y_pred = inverse(mean, scaler, pca_mat, Y_pred.numpy()) 308 | 309 | # Correlation metrics 310 | score = get_metrics(Y_true, Y_pred) 311 | 312 | scores.append(score) 313 | 314 | scores = np.stack(scores) 315 | 316 | th.save({"scores": scores}, 317 | out / f"shuffled_{name}_metrics.th") 318 | 319 | 320 | if __name__ == "__main__": 321 | main() 322 | -------------------------------------------------------------------------------- /neural/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | '''Gathers all files extracted from the MOUS dataset into a usable, 8 | self-contained MegDatasets structure. 9 | It comprises: 10 | -- torch datasets (train, valid, test) 11 | -- useful information to map from PCA space (default: 40) to sensor space (default: 273) 12 | (e.g. scaling, meg mean, pca matrix) 13 | ''' 14 | 15 | from collections import defaultdict, namedtuple 16 | from pathlib import Path 17 | 18 | import numpy as np 19 | import torch as th 20 | import tqdm 21 | from sklearn.preprocessing import RobustScaler 22 | from torch.utils.data import Dataset 23 | 24 | 25 | class MegSubject(Dataset): 26 | def __init__(self, meg, forcings, length, subject_id): 27 | """ 28 | Torch Dataset class for storing the stimuli and MEG response 29 | of a single subject. 30 | 31 | Inputs: 32 | - meg: shape [N, C, T] with N trials, C channels, T time steps 33 | - forcings: dict of stimulus features (e.g. word frequency); 34 | keys such as word_onsets, word_lengths, word_freqs, first_mask, last_word, first_word 35 | values: each of shape [N, 1, T] 36 | - length: shape [N], number of valid time steps for each trial 37 | - subject_id: int 38 | """ 39 | self.meg = meg 40 | self.forcings = forcings 41 | self.length = length 42 | self.subject_id = subject_id 43 | 44 | def __len__(self): 45 | """Returns the total number of samples N""" 46 | return self.meg.size(0) 47 | 48 | def __getitem__(self, idx): 49 | """Returns one sample of data""" 50 | return (self.meg[idx], {k: v[idx] 51 | for k, v in self.forcings.items()}, self.length[idx], 52 | self.subject_id) 53 | 54 | 55 | # Higher level object to store datasets with their information (vs. torch ConcatDataset) 56 | MegDatasets = namedtuple("MegDataset", 57 | "train_sets valid_sets test_sets meg_dim forcing_dims meg_scalers " 58 | "pca_mats means") 59 | 60 | 61 | def _narrow(tensor, indexes): 62 | return tensor.gather(2, indexes.expand(-1, tensor.size(1), -1)) 63 | 64 | 65 | def _prepare_forcing(forcing): 66 | '''Reformats a forcing feature extracted from the MOUS dataset. 67 | Inputs: 68 | - forcing: [N, T, 1] forcing feature (e.g.
word frequency) 69 | Usually a value from the forcings dict, whose keys are: 70 | word_onsets, word_lengths, word_freqs, word_n_before, word_n_after, first_mask 71 | ''' 72 | forcing = th.from_numpy(forcing).float() 73 | 74 | # Reformat the forcing feature as [N, 1, T] 75 | if forcing.dim() == 2: 76 | forcing = forcing.unsqueeze(1) 77 | else: 78 | forcing = forcing.permute(0, 2, 1) 79 | 80 | # Normalize 81 | forcing_normalized = (forcing - forcing.mean()) / forcing.std() 82 | 83 | return forcing_normalized 84 | 85 | 86 | def load_torch_megs(path, n_subjects_max=None, subject=None, init=60, exclude=[], include=[]): 87 | 88 | # Create a dict mapping each subject to the paths of its extracted files 89 | path = Path(path) 90 | subjects = defaultdict(dict) 91 | for child in path.iterdir(): 92 | if child.suffix == ".pth": # e.g. meg_1076_4_visual.pth 93 | kind, sub, *_ = child.stem.split("_") 94 | subjects[sub][kind] = child 95 | 96 | # Select subjects of interest 97 | to_load = list(subjects.keys()) 98 | to_load.sort() 99 | if subject is not None: 100 | to_load = [to_load[subject]] 101 | if n_subjects_max: 102 | to_load = to_load[:n_subjects_max] 103 | 104 | train_sets = [] 105 | valid_sets = [] 106 | test_sets = [] 107 | meg_scalers = [] 108 | means = [] 109 | pca_mats = [] 110 | 111 | iterator = tqdm.tqdm(to_load, leave=False, ncols=120, desc="Loading data...") 112 | 113 | # Loop over subjects of interest 114 | subjs = [] 115 | for index, subject in enumerate(iterator): 116 | # Load meg and forcing extraction files 117 | megdata = th.load(subjects[subject]["meg"]) 118 | forcings = th.load(subjects[subject]["forcing"]) 119 | subjs.append(megdata.get("subject", subject)) # fall back to the filename-derived ID if the "subject" key is missing 120 | 121 | after = forcings.pop("word_n_after", None) 122 | before = forcings.pop("word_n_before", None) 123 | 124 | # Define new stimulus features (last_word, first_word) 125 | # from old stimulus features (word_n_after, word_n_before) 126 | # assuming they are inclusive 127 | if after is not None: 128 | forcings["last_word"] = (after == 1).astype(np.float32) 129 | if before is not None: 130 | forcings["first_word"] = (before == 1).astype(np.float32) 131 | if "is_stop" in forcings: 132 | last_word = "is_stop" 133 | else: 134 | last_word = "last_word" 135 | # Create mask to select the first stimulus only in the 2.5s epoch 136 | if "first_mask" not in forcings: 137 | stim = forcings["stimulus"] 138 | first = 0 * stim 139 | for row in range(len(stim)): 140 | low = 60 141 | start = low + stim[row, low:].nonzero()[0][0] 142 | end = (stim[row, start:] == 0).nonzero()[0] 143 | if len(end): 144 | end = end[0] 145 | first[row, start:start + end] = 1 146 | else: 147 | # print(subject, row, stim[row]) 148 | first[row, start:] = 1 149 | forcings["first_mask"] = first 150 | 151 | # Include or exclude stimulus features based on their name (key of forcing dict) 152 | for name in exclude: 153 | if name not in forcings: 154 | raise ValueError(f"{name} is not a valid feature name.") 155 | for name in include: 156 | if name not in forcings: 157 | raise ValueError(f"{name} is not a valid feature name.") 158 | if include: 159 | feats = list(include) 160 | else: 161 | feats = list(forcings.keys()) 162 | for name in exclude: 163 | feats.remove(name) 164 | 165 | forcings = { 166 | # normalizes each forcing and permutes it to [N, 1, T] 167 | name: _prepare_forcing(forcing) 168 | for name, forcing in forcings.items() if name in feats 169 | } 170 | forcing_dims = {} 171 | for key, value in forcings.items(): 172 | forcing_dims[key] = value.size(1) #
expected: 1 173 | 174 | meg = megdata["meg"] 175 | pca_mats.append(megdata["pca_mat"]) 176 | if "meg_last_idx" in megdata: 177 | last_index = th.from_numpy(megdata["meg_last_idx"]) 178 | else: 179 | last_index = th.full((meg.shape[0], ), meg.shape[1], dtype=th.long) 180 | 181 | # Scale (robust) meg data 182 | # TODO: separate scaling for each set (train, valid, test)? 183 | meg_scaler = RobustScaler() 184 | meg_scalers.append(meg_scaler) 185 | meg = meg_scaler.fit_transform(meg.reshape(-1, meg.shape[-1])).reshape(*meg.shape) 186 | meg = th.from_numpy(meg) 187 | 188 | # Remove trials where an amplitude is too high (e.g. 16) after scaling 189 | max_amplitude = meg.abs().max(dim=1)[0].max(dim=1)[0] 190 | mask = max_amplitude <= 16 191 | # print(mask.float().mean(), mask.shape) 192 | meg = meg[mask] 193 | forcings = {key: value[mask] for key, value in forcings.items()} 194 | last_index = last_index[mask] 195 | 196 | # Center meg data 197 | mean = meg.mean(dim=0, keepdim=True).mean(dim=1, keepdim=True) 198 | means.append(mean) 199 | meg = meg - mean 200 | 201 | # Change meg format: [N, T, C] -> [N, C, T] 202 | meg = meg.permute(0, 2, 1) 203 | meg_dim = meg.size(1) # expected: C 204 | 205 | n_trials = meg.shape[0] 206 | 207 | # Separate trials into train / valid / test: 208 | # search for an end of sentence to do the cuts 209 | train, valid, test = 0.7, 0.1, 0.2 210 | 211 | for trial in range(int(train * n_trials), 212 | int((train + valid) * n_trials)): 213 | if forcings[last_word][trial, 0, 60] > 0: # end of sentence 214 | break 215 | idx_train = list(range(trial + 1)) 216 | 217 | for trial in range(int((train + valid) * n_trials), n_trials): 218 | if forcings[last_word][trial, 0, 60] > 0: 219 | break 220 | idx_valid = list(range(idx_train[-1] + 1, trial + 1)) 221 | idx_test = list(range(idx_valid[-1] + 1, n_trials)) 222 | 223 | # Instantiate train/valid/test epoched datasets 224 | dataset_train = MegSubject( 225 | meg=meg[idx_train], 226 | forcings={k: v[idx_train] 227 | for k, v in forcings.items()}, 228 | length=1 + last_index[idx_train], 229 | subject_id=index) 230 | 231 | dataset_valid = MegSubject( 232 | meg=meg[idx_valid], 233 | forcings={k: v[idx_valid] 234 | for k, v in forcings.items()}, 235 | length=1 + last_index[idx_valid], 236 | subject_id=index) 237 | 238 | dataset_test = MegSubject( 239 | meg=meg[idx_test], 240 | forcings={k: v[idx_test] 241 | for k, v in forcings.items()}, 242 | length=1 + last_index[idx_test], 243 | subject_id=index) 244 | 245 | train_sets.append(dataset_train) 246 | valid_sets.append(dataset_valid) 247 | test_sets.append(dataset_test) 248 | 249 | print("subjects: ", subjs) 250 | print("Overall train size: ", sum(tr.meg.shape[0] for tr in train_sets)) 251 | print("Overall valid size: ", sum(tr.meg.shape[0] for tr in valid_sets)) 252 | print("Overall test size: ", sum(tr.meg.shape[0] for tr in test_sets)) 253 | 254 | return MegDatasets( 255 | train_sets=train_sets, 256 | valid_sets=valid_sets, 257 | test_sets=test_sets, 258 | meg_scalers=meg_scalers, 259 | means=means, 260 | pca_mats=pca_mats, 261 | meg_dim=meg_dim, 262 | forcing_dims=forcing_dims) 263 | -------------------------------------------------------------------------------- /neural/extraction.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
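'''Extraction script for the MOUS dataset.

For each subject, the code below reads the continuous CTF recording,
band-pass filters it (1-30 Hz), segments it into word-locked epochs, and
saves two files in the output directory:
- meg_<subject>_<log_id>_<task>.pth: epoched MEG (optionally PCA-reduced),
  the PCA matrix, epoch info, and word times;
- forcing_<subject>_<log_id>_<task>.pth: stimulus features (word onsets,
  lengths, frequencies, sentence positions, and a first-word mask).
'''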
6 | 7 | # import libraries 8 | 9 | from concurrent.futures import ProcessPoolExecutor 10 | import argparse 11 | import os 12 | import traceback 13 | 14 | from sklearn.decomposition import PCA 15 | import numpy as np 16 | import mne 17 | import torch as th 18 | 19 | from .utils_mous import ( 20 | create_directory, get_word_length, get_word_freq, 21 | read_log, get_log_times, _add_stim_id, setup_logfiles, setup_stimuli) 22 | 23 | mne.set_log_level(False) 24 | 25 | 26 | def get_parser(): 27 | '''Builds the parser for command-line arguments.''' 28 | parser = argparse.ArgumentParser("extraction", 29 | description="Extract meg and forcing from MOUS Dataset") 30 | parser.add_argument("--data", type=str, help="Path to MOUS dataset") 31 | parser.add_argument("--out", type=str, help="Path where to save output") 32 | parser.add_argument("--use-pca", action="store_true", default=False) 33 | parser.add_argument("--pca-dim", type=int, default=40) 34 | parser.add_argument("--n-subjects", type=int, default=-1, 35 | help="Maximum number of subjects to extract from.") 36 | parser.add_argument("--workers", type=int, default=10, 37 | help="Number of parallel workers.") 38 | return parser 39 | 40 | 41 | def make_repo_from_parser(args, parser): 42 | '''Creates an output directory whose name is based on the parser arguments.''' 43 | output_repo = args.out 44 | if args.use_pca: 45 | output_repo += f"/{args.pca_dim}" 46 | else: 47 | output_repo += "/full" 48 | create_directory(output_repo) 49 | 50 | return output_repo 51 | 52 | 53 | # collect arguments and create the output directory 54 | parser = get_parser() 55 | args = parser.parse_args() 56 | output_directory = make_repo_from_parser(args, parser) 57 | print("output directory: ", output_directory) 58 | data_path = args.data 59 | 60 | # create legible CSV tables from MOUS data for meg and stimuli 61 | cache = "./cache" 62 | create_directory(cache) 63 | log_files = setup_logfiles(data_path, cache) 64 | stimuli = setup_stimuli(data_path, cache) 65 | 66 | # select files for given task 67 | log_files = log_files.query('task=="visual"').reset_index(drop=True) 68 | 69 | 70 | def extract_subject(subject): 71 | '''Extracts MEG and forcing for a subject in the MOUS Dataset. 72 | 73 | Input: 74 | subject (int): index of the subject's row in the log_files table 75 | ''' 76 | log_file = log_files.iloc[subject] 77 | try: 78 | # generic output filename 79 | output_fname = [ 80 | "%s", 81 | str(log_file["subject"]), 82 | str(log_file["log_id"]), log_file["task"] 83 | ] 84 | output_fname = "_".join(output_fname) + ".pth" 85 | 86 | ########################## 87 | # LOAD MEG AND LOG 88 | ########################## 89 | 90 | # get meg and log filenames 91 | raw_fname = os.path.join(data_path, log_file['meg_file']) 92 | log_fname = os.path.join(data_path, log_file['log_file']) 93 | 94 | # read meg (continuous) 95 | raw = mne.io.read_raw_ctf(raw_fname, preload=True) 96 | raw.filter(1., 30.)
# Slow 97 | 98 | # preprocess annotations and add task information 99 | log = read_log(log_fname, stimuli) 100 | log = _add_stim_id(log, verbose=False, stimuli=stimuli) # get words 101 | 102 | # adding n words before and after in sentence 103 | log_words = log.query('condition=="word"') 104 | words_idx = log.query('condition=="word"').index 105 | 106 | sentence_lengths = np.bincount(log_words.sequence_pos.values.astype(int)) 107 | 108 | n_words_before = np.concatenate([np.arange(length) + 1 109 | for length in sentence_lengths]).flatten() 110 | n_words_before = n_words_before.astype(int) 111 | 112 | n_words_after = np.concatenate([np.ones(length) * length 113 | for length in sentence_lengths]).flatten() \ 114 | - n_words_before 115 | n_words_after = n_words_after.astype(int) 116 | 117 | log.loc[words_idx, "n_words_before"] = n_words_before 118 | log.loc[words_idx, "n_words_after"] = n_words_after 119 | 120 | # find events 121 | events = mne.find_events(raw, min_duration=.010) 122 | # link meg and annotations 123 | log = get_log_times(log, events, raw.info['sfreq']) 124 | 125 | ########################## 126 | # EXTRACT MEG 127 | ########################## 128 | 129 | # select desired event 130 | log_events = log.query('condition=="word"') 131 | 132 | # format events for mne 133 | log_events_formatted = np.c_[log_events.meg_sample, 134 | np.ones((len(log_events), 2), int)] 135 | _, idx = np.unique(log_events_formatted[:, 0], return_index=True) 136 | 137 | # segment meg into word-locked epochs 138 | picks = mne.pick_types(raw.info, 139 | meg=True, 140 | eeg=False, 141 | stim=False, 142 | eog=False, 143 | ecg=False) 144 | decim = 10 145 | tmin, tmax = -.500, 2 146 | 147 | epochs = mne.Epochs( 148 | raw, 149 | events=log_events_formatted, 150 | metadata=log_events, 151 | tmin=tmin, 152 | tmax=tmax, 153 | decim=decim, 154 | preload=True, 155 | picks=picks, 156 | ) 157 | 158 | # throw away compensation channels 159 | bads = [epochs.ch_names[i] for i in range(28)] # hardcoded 160 | raw = raw.pick_types(meg=True, exclude=bads) 161 | epochs = epochs.pick_types(meg=True, exclude=bads) 162 | 163 | # get evoked meg 164 | evoked = epochs.average(method='mean') 165 | 166 | # get pca on evoked 167 | evoked_temp = evoked.apply_baseline().data.T * 1e12 # scaled 168 | duration_for_pca = int((np.abs(tmin) + 1) * epochs.info["sfreq"]) 169 | evoked_temp = evoked_temp[:duration_for_pca] # cropped 170 | if args.use_pca: 171 | pca = PCA(args.pca_dim).fit(evoked_temp) 172 | pca_mat = pca.components_ 173 | else: 174 | pca_mat = np.eye(evoked_temp.shape[1], dtype=np.float32) 175 | 176 | ########################## 177 | # SAVE MEG 178 | ########################## 179 | 180 | # collect 181 | meg = epochs.get_data() 182 | meg_evoked = evoked.apply_baseline().data[None, :, :] 183 | 184 | # useful for sentences of different lengths 185 | meg_last_idx = (np.abs(tmin) + tmax) * epochs.info["sfreq"] * np.ones(len(epochs)) 186 | meg_last_idx = meg_last_idx.astype(int) 187 | 188 | # reformat 189 | meg = np.swapaxes(meg, 1, 2) 190 | meg_evoked = np.swapaxes(meg_evoked, 1, 2) 191 | 192 | meg_pca = meg @ pca_mat.T 193 | times = np.array(epochs.metadata["time"], dtype=np.float32) 194 | 195 | # save 196 | output_dict = dict( 197 | zip(["meg", "meg_last_idx", "pca_mat", "epochs_info", "times", "subject"], 198 | [meg_pca.astype(np.float32), 199 | meg_last_idx, pca_mat, epochs.info, times, log_files.subject[subject] 200 | ])) 201 | output_path = os.path.join(output_directory, output_fname % "meg") 202 | print("output path: ", 
output_path)
203 | th.save(output_dict, output_path)
204 | 
205 | ##########################
206 | # LOAD FORCING
207 | ##########################
208 | 
209 | n_epochs, n_channels, n_times = epochs.get_data().shape
210 | forcing_word = np.zeros((n_epochs, 6, n_times), dtype=np.float32)
211 | 
212 | for epo_idx in range(n_epochs):
213 | 
214 | # continuous time interval
215 | on = epochs.metadata.iloc[epo_idx].time
216 | start, end = on - np.abs(tmin), on + tmax
217 | 
218 | # corresponding words
219 | cond = (start < epochs.metadata.time) & (epochs.metadata.time < end)
220 | 
221 | words = epochs.metadata[cond].word.values.flatten().tolist()
222 | 
223 | # recentering the time interval around the main onset
224 | onsets = epochs.metadata[cond].time.values - on + np.abs(tmin)
225 | 
226 | # converting the time interval from seconds to time samples
227 | onsets = (onsets * epochs.info["sfreq"]).astype(int)
228 | 
229 | # recovering word durations, then offsets
230 | durations = epochs.metadata[cond].Duration.values.astype(float) * 1e-4 # unit: second
231 | durations = (durations * epochs.info["sfreq"]).astype(int) # unit: time sample
232 | offsets = onsets + durations
233 | 
234 | # getting features
235 | word_lengths = get_word_length(words)
236 | word_freqs = get_word_freq(words)
237 | word_n_before = epochs.metadata[cond].n_words_before.values.flatten().tolist()
238 | # add 1 to distinguish the last word of a sentence from the absence of forcing
239 | word_n_after = (epochs.metadata[cond].n_words_after.values.flatten() + 1).tolist()
240 | 
241 | # place a square wave over each word presence interval
242 | for idx, (onset, offset) in enumerate(zip(onsets, offsets)):
243 | forcing_word[epo_idx, 0, onset:offset] = 1.
244 | forcing_word[epo_idx, 1, onset:offset] = word_lengths[idx]
245 | forcing_word[epo_idx, 2, onset:offset] = word_freqs[idx]
246 | forcing_word[epo_idx, 3, onset:offset] = word_n_before[idx]
247 | forcing_word[epo_idx, 4, onset:offset] = word_n_after[idx]
248 | if idx == 0:
249 | # mask for the first forcing, used to shuffle features
250 | forcing_word[epo_idx, 5, onset:offset] = 1.
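# As an illustration of the resulting layout (hypothetical numbers): a
# 7-letter word with zipf frequency 4.2, shown 3rd in a 9-word sentence,
# fills its [onset:offset] span with
#   channel 0 (word presence)  = 1.0
#   channel 1 (word length)    = 7
#   channel 2 (word frequency) = 4.2
#   channel 3 (n words before) = 3
#   channel 4 (n words after)  = (9 - 3) + 1 = 7  (shifted by +1, see above)
#   channel 5 (first mask)     = 1.0 only if it opens the epoch window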
251 | 
252 | # save forcing
253 | forcing_names = ["word_onsets", "word_lengths", "word_freqs",
254 | "word_n_before", "word_n_after", "first_mask"]
255 | 
256 | forcing = [forcing_word[:, 0, :][:, None, :],
257 | forcing_word[:, 1, :][:, None, :],
258 | forcing_word[:, 2, :][:, None, :],
259 | forcing_word[:, 3, :][:, None, :],
260 | forcing_word[:, 4, :][:, None, :],
261 | forcing_word[:, 5, :][:, None, :],
262 | ]
263 | 
264 | # reformat
265 | forcing = [np.swapaxes(f, 1, 2) for f in forcing]
266 | 
267 | # save
268 | output_dict = dict(zip(forcing_names, forcing))
269 | output_path = os.path.join(output_directory,
270 | output_fname % "forcing")
271 | th.save(output_dict, output_path)
272 | 
273 | except Exception as e:
274 | print(f"Error {e} with subject {subject} {log_file}")
275 | traceback.print_exc()
276 | return
277 | else:
278 | print("SUBJECT", subject, "done")
279 | 
280 | 
281 | # loop over subjects
282 | if args.n_subjects == -1:
283 | n_subjects = len(log_files)
284 | else:
285 | n_subjects = args.n_subjects
286 | if args.workers == 1:
287 | for subject in range(n_subjects):
288 | extract_subject(subject)
289 | else:
290 | with ProcessPoolExecutor(args.workers) as pool:
291 | for subject in range(n_subjects):
292 | pool.submit(extract_subject, subject)
293 | 
-------------------------------------------------------------------------------- /neural/linear/__init__.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
-------------------------------------------------------------------------------- /neural/linear/__main__.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | # import libraries
8 | import argparse
9 | from concurrent.futures import ProcessPoolExecutor
10 | from pathlib import Path
11 | 
12 | import numpy as np
13 | import torch
14 | import tqdm
15 | 
16 | 
17 | 
18 | from ..dataset import load_torch_megs
19 | from ..utils import get_metrics, inverse
20 | from ..visuals import plt # noqa
21 | from .arx import ARX
22 | from .receptive_field import RField
23 | from .stats import report_correl
24 | 
25 | 
26 | def get_parser():
27 | parser = argparse.ArgumentParser("lin", description="Train lin predictors using forcings")
28 | parser.add_argument(
29 | "-d", "--data", type=Path,
30 | required=True,
31 | help="Path to the data extracted.")
32 | parser.add_argument("--n-subjects", type=int, default=68, help="Max number of subjects")
33 | parser.add_argument("--out", type=Path, default=Path("dump"))
34 | parser.add_argument("--with-forcing", action="store_true", default=False)
35 | parser.add_argument("--with-init", action="store_true", default=False)
36 | parser.add_argument("--shuffle", action="store_true", default=False)
37 | parser.add_argument("--pca", type=int, help="Use PCA version of the data. "
38 | "Should be the dimension of the PCA used.")
39 | parser.add_argument("--n-workers", type=int, default=20, help="Workers for parallelization.")
40 | return parser
41 | 
42 | 
43 | def make_repo_from_parser(args):
44 | '''Creates the output directories, named from the parser
45 | arguments, and returns their paths.'''
46 | 
47 | parts = ["linear"]
48 | name_args = dict(args.__dict__)
49 | ignore = ["data", "out"]
50 | for key in ignore:
51 | name_args.pop(key, None)
52 | for name, value in name_args.items():
53 | parts.append(f"{name}={value}")
54 | name = " ".join(parts)
55 | print(f"Experiment {name}")
56 | 
57 | # data path
58 | if args.pca is not None:
59 | suffix = f"{args.pca}"
60 | else:
61 | suffix = "full"
62 | args.data = args.data / Path(suffix)
63 | print("Using dataset", args.data)
64 | 
65 | # separate output directories for the two linear models
66 | name_reg = name + " reg"
67 | name_autoreg = name + " autoreg"
68 | 
69 | out_reg = args.out / name_reg
70 | out_autoreg = args.out / name_autoreg
71 | 
72 | out_reg.mkdir(parents=True, exist_ok=True)
73 | out_autoreg.mkdir(parents=True, exist_ok=True)
74 | return out_reg, out_autoreg
75 | 
76 | 
77 | def permute_forcing(first_mask, forcing, permutation, init=60):
78 | initial = forcing[:, :, init:init + 1]
79 | mask = first_mask > 0
80 | 
81 | return torch.where(mask, initial[permutation], forcing)
82 | 
83 | 
84 | def shuffle_forcings(forcings, name):
85 | forcing = forcings[name]
86 | first_mask = forcings["first_mask"]
87 | permutation = torch.randperm(forcing.size(0))
88 | forcings[name] = permute_forcing(first_mask, forcing, permutation)
89 | 
90 | 
91 | def eval_lin_models(subject,
92 | data_path,
93 | results_path_reg,
94 | results_path_autoreg,
95 | n_init=40,
96 | tune_models=True,
97 | with_init=True,
98 | with_forcing=True,
99 | shuffle=False):
100 | # Load dataset
101 | data = load_torch_megs(data_path, subject=subject)
102 | 
103 | # Load what is needed to reverse the PCA
104 | pca_mat = data.pca_mats[0]
105 | mean = data.means[0]
106 | scaler = data.meg_scalers[0]
107 | 
108 | # Get train / valid / test sets, tensor shape [N, C, T]
109 | meg_train = data.train_sets[0].meg.numpy()
110 | meg_valid = data.valid_sets[0].meg.numpy()
111 | meg_test = data.test_sets[0].meg.numpy()
112 | 
113 | forcing_keys = data.train_sets[0].forcings.keys()
114 | forcing_train = np.concatenate(list([data.train_sets[0].forcings[k]
115 | for k in forcing_keys]), axis=1)
116 | forcing_valid = np.concatenate(list([data.valid_sets[0].forcings[k]
117 | for k in forcing_keys]), axis=1)
118 | forcing_test = np.concatenate(list([data.test_sets[0].forcings[k]
119 | for k in forcing_keys]), axis=1)
120 | 
121 | if not with_forcing:
122 | forcing_train = np.zeros_like(forcing_train)
123 | forcing_valid = np.zeros_like(forcing_valid)
124 | forcing_test = np.zeros_like(forcing_test)
125 | 
126 | # Reformat [N, T, C]
127 | [meg_train, meg_valid, meg_test,
128 | forcing_train, forcing_valid, forcing_test] = [
129 | np.swapaxes(elem, 1, 2) for elem in [meg_train, meg_valid, meg_test,
130 | forcing_train, forcing_valid, forcing_test]
131 | ]
132 | 
133 | ######################
134 | # LIN REG
135 | ######################
136 | 
137 | # Instantiate
138 | rfield = RField(lag_u=260, penal_weight=1.8)
139 | 
140 | # Tune hyperparameter on valid set
141 | alpha_scores = list()
142 | alphas = np.logspace(-3, 3, 5)
143 | 
144 | if tune_models:
145 | 
146 | for alpha in alphas:
147 | rfield.model.estimator = alpha
148 | rfield.fit(forcing_train, meg_train)
149 | meg_pred = rfield.predict(forcing_valid)
150 
| meg_true = meg_valid 151 | # computing metrics 152 | alpha_score = get_metrics(meg_true, meg_pred) 153 | alpha_scores.append(alpha_score.mean()) 154 | # plt.plot(np.log10(alphas), alpha_scores) 155 | # plt.ylabel('r') 156 | # plt.show() 157 | # plt.close() 158 | 159 | alpha = alphas[np.argmax(alpha_scores)] 160 | rfield.model.estimator = alpha 161 | 162 | # Retrain on train + valid set, save model 163 | rfield.fit(forcing_train, meg_train) 164 | torch.save(rfield, results_path_reg / f"model_trf_subject_{subject}.th") 165 | 166 | # Predict on test set 167 | meg_pred = rfield.predict(forcing_test) 168 | meg_true = meg_test 169 | 170 | # Reverse PCA 171 | meg_pred = inverse(mean, scaler, pca_mat, meg_pred) 172 | meg_true = inverse(mean, scaler, pca_mat, meg_true) 173 | 174 | # Save plot 175 | report_correl(meg_true, meg_pred, results_path_reg / "reg.png", 0) 176 | 177 | # Save prediction sample from all subjects 178 | torch.save({"meg_pred_epoch": meg_pred[0], 179 | "meg_true_epoch": meg_true[0], 180 | "meg_pred_evoked": meg_pred.mean(0), 181 | "meg_true_evoked": meg_true.mean(0)}, 182 | results_path_reg / f"meg_prediction_subject_{subject}.th") 183 | 184 | # Compute metric (Pearson R) 185 | score_linreg = get_metrics(meg_true, meg_pred) 186 | 187 | # Permutation Feature Importance 188 | shuffled = {} 189 | to_shuffle = ["word_lengths", "word_freqs"] if shuffle else [] 190 | for name in to_shuffle: 191 | 192 | # Permute forcing (via original torch forcing) 193 | forcing_test_torch = data.test_sets[0].forcings 194 | shuffle_forcings(forcing_test_torch, name) 195 | forcing_test_shuffle = np.concatenate(list([forcing_test_torch[k] 196 | for k in forcing_keys]), axis=1) 197 | forcing_test_shuffle = np.swapaxes(forcing_test_shuffle, 1, 2) 198 | if not with_forcing: 199 | forcing_test_shuffle = np.zeros_like(forcing_test_shuffle) 200 | 201 | # Predict on test set 202 | meg_pred = rfield.predict(forcing_test_shuffle) 203 | meg_true = meg_test 204 | 205 | # Reverse pca 206 | meg_pred = inverse(mean, scaler, pca_mat, meg_pred) 207 | meg_true = inverse(mean, scaler, pca_mat, meg_true) 208 | 209 | # Compute metric (Pearson R) 210 | score_tmp = get_metrics(meg_true, meg_pred) 211 | shuffled[name] = score_tmp 212 | 213 | ###################### 214 | # LIN AUTOREG 215 | ###################### 216 | 217 | # Instantiate 218 | ridge = ARX(lag_u=n_init, lag_y=n_init, solver="ridge", penal_weight=1.8, scaling=False) 219 | 220 | # Tune hyperparameter on valid set 221 | alpha_scores = list() 222 | 223 | if tune_models: 224 | 225 | for alpha in alphas: 226 | ridge.penal_weight = alpha 227 | ridge.fit(forcing_train, meg_train) 228 | meg_init = np.zeros_like(meg_valid) 229 | if with_init: 230 | meg_init[:, :n_init, :] = meg_valid[:, :n_init, :] 231 | meg_pred = ridge.predict( 232 | forcing_valid, meg_init, start=n_init, eval="unrolled") 233 | meg_true = meg_valid 234 | # computing metrics 235 | alpha_score = get_metrics(meg_true, meg_pred) 236 | alpha_scores.append(alpha_score.mean()) 237 | 238 | # plt.plot(np.log10(alphas), alpha_scores) 239 | # plt.ylabel('r') 240 | # plt.show() 241 | # plt.close() 242 | 243 | alpha = alphas[np.argmax(alpha_scores)] 244 | ridge.penal_weight = alpha 245 | 246 | # Retrain on train + valid set, save model 247 | ridge.fit(forcing_train, meg_train) 248 | torch.save(ridge, results_path_autoreg / f"model_rtrf_subject_{subject}.th") 249 | 250 | # Predict on test set 251 | meg_init = np.zeros_like(meg_test) 252 | if with_init: 253 | meg_init[:, :n_init, :] = meg_test[:, :n_init, :] 254 | 
meg_pred = ridge.predict(forcing_test, meg_init, start=n_init, eval="unrolled")
255 | meg_true = meg_test
256 | 
257 | # Reverse PCA
258 | meg_pred = inverse(mean, scaler, pca_mat, meg_pred)
259 | meg_true = inverse(mean, scaler, pca_mat, meg_true)
260 | 
261 | # Save plot
262 | report_correl(meg_true, meg_pred, results_path_autoreg / "autoreg.png", n_init)
263 | 
264 | # Save prediction samples for this subject
265 | torch.save({"meg_pred_epoch": meg_pred[0],
266 | "meg_true_epoch": meg_true[0],
267 | "meg_pred_evoked": meg_pred.mean(0),
268 | "meg_true_evoked": meg_true.mean(0)},
269 | results_path_autoreg / f"meg_prediction_subject_{subject}.th")
270 | 
271 | # Compute metric (Pearson R)
272 | score_linautoreg = get_metrics(meg_true, meg_pred)
273 | 
274 | # TODO: add Permutation Feature Importance for the linear autoreg model
275 | 
276 | 
277 | return score_linreg, score_linautoreg, shuffled
278 | 
279 | 
280 | def main():
281 | 
282 | # Make output directories
283 | parser = get_parser()
284 | args = parser.parse_args()
285 | out_reg, out_autoreg = make_repo_from_parser(args)
286 | 
287 | # Prepare model labels
288 | if (not args.with_init) and (args.with_forcing):
289 | label_add = "(no init)"
290 | elif (args.with_init) and (not args.with_forcing):
291 | label_add = "(no forcing)"
292 | elif (args.with_init) and (args.with_forcing):
293 | label_add = ""
294 | elif (not args.with_init) and (not args.with_forcing):
295 | label_add = "(no init, no forcing)"
296 | 
297 | # Initialize result dicts
298 | reg_results = {"label": "lin reg " + label_add,
299 | "scores": []}
300 | shuffled_results = {"word_freqs": [],
301 | "word_lengths": []}
302 | autoreg_results = {"label": "lin autoreg " + label_add,
303 | "scores": []}
304 | 
305 | # Loop over subjects (in parallel)
306 | with ProcessPoolExecutor(args.n_workers) as pool:
307 | 
308 | pendings = []
309 | for sub in range(args.n_subjects):
310 | pendings.append(
311 | pool.submit(
312 | eval_lin_models,
313 | sub,
314 | args.data,
315 | out_reg,
316 | out_autoreg,
317 | with_forcing=args.with_forcing,
318 | with_init=args.with_init,
319 | shuffle=args.shuffle))
320 | 
321 | for pending in tqdm.tqdm(pendings):
322 | (score_linreg, score_linautoreg, shuffled) = pending.result()
323 | 
324 | # stack results in lists
325 | reg_results["scores"].append(score_linreg)
326 | autoreg_results["scores"].append(score_linautoreg)
327 | for key, score_shuffled in shuffled.items():
328 | shuffled_results[key].append(score_shuffled)
329 | 
330 | # Making numpy arrays from lists
331 | reg_results["scores"] = np.array(reg_results["scores"])
332 | autoreg_results["scores"] = np.array(autoreg_results["scores"])
333 | for key in shuffled_results.keys():
334 | shuffled_results[key] = np.array(shuffled_results[key])
335 | 
336 | # # Converting to torch arrays
337 | # reg_results["scores"] = torch.from_numpy(reg_results["scores"])
338 | # autoreg_results["scores"] = torch.from_numpy(autoreg_results["scores"])
339 | # for key in shuffled_results.keys():
340 | #     shuffled_results[key] = torch.from_numpy(shuffled_results[key])
341 | 
342 | # Save
343 | torch.save(reg_results, out_reg / "reference_metrics.th")
344 | torch.save(autoreg_results, out_autoreg / "reference_metrics.th")
345 | 
346 | if args.shuffle:
347 | for key, value in shuffled_results.items():
348 | torch.save({'scores': value, 'label': 'lin reg ' + label_add},
349 | out_reg / f"shuffled_{key}_metrics.th")
350 | 
351 | 
352 | if __name__ == "__main__":
353 | main()
354 | 
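# A hypothetical invocation from the repository root (paths and sizes are
# examples, not prescriptions):
#
#   python -m neural.linear --data /path/to/extracted --pca 40 \
#       --with-forcing --with-init --n-subjects 4 --n-workers 2
#
# This fits the ridge TRF and the autoregressive (ARX) variant for every
# subject and writes reference_metrics.th in the output directories
# created by make_repo_from_parser (under dump/ by default).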
-------------------------------------------------------------------------------- /neural/linear/arx.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | # import libraries
8 | import numpy as np
9 | import statsmodels.api as sm
10 | from sklearn.linear_model import Ridge
11 | from sklearn.preprocessing import StandardScaler
12 | 
13 | from .lin_model_template import lin_model
14 | from .stats import quick_svd
15 | 
16 | 
17 | class ARX(lin_model):
18 | def __init__(self, lag_u, lag_y, solver="ridge", penal_weight=1., scaling=False, log=False):
19 | 
20 | # model choice
21 | self.lag_u = lag_u
22 | self.lag_y = lag_y
23 | self.maxlag = np.max([self.lag_u, self.lag_y])
24 | self.scaling = scaling
25 | self.penal_weight = penal_weight
26 | 
27 | # model architecture
28 | 
29 | # learned feats
30 | self.weights = np.array([])
31 | self.weights_u = np.array([])
32 | self.weights_y = np.array([])
33 | self.A = np.array([])
34 | 
35 | # data properties
36 | self.n_channels_y = 0
37 | self.n_channels_u = 0
38 | self.n_feats_x = 0
39 | self.n_feats_v = 0
40 | self.scaler_target = StandardScaler()
41 | self.scaler_y = StandardScaler()
42 | self.scaler_u = StandardScaler()
43 | 
44 | self.solver = solver
45 | self.model = 0
46 | 
47 | self.regressors_names = list()
48 | self.residuals = 0
49 | 
50 | self.log = log
51 | 
52 | def fit(self, U, Y):
53 | """
54 | Input:
55 | -----
56 | 
57 | U : numpy array (n_samples x n_times x n_channels_u)
58 | Y : numpy array (n_samples x n_times x n_channels_y)
59 | """
60 | 
61 | if self.log:
62 | print("----------------------------------------- \n")
63 | print("\n TRAIN \n")
64 | print("----------------------------------------- \n")
65 | 
66 | target, regression_mat = self.formulate_regression(U, Y)
67 | 
68 | # plug-in model to solve the regression
69 | if self.log:
70 | print("Solving the Least-Squares...\n")
71 | if self.solver == "ridge":
72 | 
73 | self.model = Ridge(alpha=self.penal_weight)
74 | self.model.fit(regression_mat, target)
75 | weights = self.model.coef_.T
76 | residuals = target - regression_mat @ weights
77 | if self.log:
78 | print("shape of weights: ", weights.T.shape, "\n")
79 | 
80 | if self.solver == "feasible": # using statsmodels
81 | 
82 | self.model = sm.GLSAR(target, regression_mat)
83 | result = self.model.fit()
84 | weights = result.params
85 | 
86 | if len(weights.shape) == 1: # fixing dimensionality bug
87 | weights = weights.reshape(weights.shape[-1], -1)
88 | 
89 | residuals = target - regression_mat @ weights
90 | if self.log:
91 | print("shape of weights: ", weights.T.shape, "\n")
92 | 
93 | if self.solver == "dmd": # pseudo-invert the transposed system
94 | 
95 | # formulation
96 | # (n_feats_x + n_feats_u, n_common_times * n_samples)
97 | snapshots = regression_mat.T
98 | snapshots_next = target.T
99 | if self.log:
100 | print("shape of snapshot matrix: ", snapshots.shape)
101 | 
102 | # compute surrogate right-side svd
103 | UU, DD, VVt = quick_svd(snapshots, rank_perc=self.penal_weight)
104 | 
105 | # invert the system and obtain weights
106 | # left-side svd (optional? not done here)
107 | if self.log:
108 | print("\n Inverting the system... \n")
109 | 
110 | weights = snapshots_next @ VVt.T @ np.diag(1. 
/ DD) @ UU.T
111 | weights = weights.T
112 | residuals = target - regression_mat @ weights
113 | if self.log:
114 | print("shape of weights: ", weights.T.shape)
115 | 
116 | # record learnt coefficients
117 | self.residuals = residuals
118 | self.weights = weights.T
119 | if len(self.weights.shape) == 0: # debug
120 | self.weights = self.weights[None, :]
121 | 
122 | # reformat weights to [(n_output_channels, n_lags, n_input_channels)
123 | # for u, for y]
124 | self.weights_y = self.weights[:, :self.lag_y * self.n_channels_y].reshape(
125 | self.n_channels_y, self.lag_y, self.n_channels_y)
126 | self.weights_u = self.weights[:, self.lag_y * self.n_channels_y:].reshape(
127 | self.n_channels_y, self.lag_u, self.n_channels_u)
128 | 
129 | # form recurrence matrix from weights
130 | if self.lag_y != 0:
131 | self.A = np.concatenate([
132 | self.weights_y.reshape(self.n_channels_y, -1),
133 | np.eye(N=self.n_feats_x - self.n_channels_y, M=self.n_feats_x, k=self.n_channels_y)
134 | ],
135 | axis=0)
136 | else:
137 | self.A = np.zeros((self.n_feats_x, self.n_feats_x))
-------------------------------------------------------------------------------- /neural/linear/lin_model_template.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | # import libraries
8 | import copy
9 | 
10 | import numpy as np
11 | from sklearn.preprocessing import StandardScaler
12 | 
13 | from ..visuals import plot_eigvalues, plt
14 | from .stats import statespace_transform
15 | 
16 | # Mother class
17 | 
18 | 
19 | class lin_model:
20 | def __init__(self, lag_y, lag_u, solver="ridge", penal_weight=1., scaling=False, log=False):
21 | 
22 | # model choice
23 | self.lag_u = lag_u
24 | self.lag_y = lag_y
25 | self.maxlag = np.max([self.lag_u, self.lag_y])
26 | self.scaling = scaling
27 | self.penal_weight = penal_weight
28 | 
29 | # model architecture
30 | 
31 | # learned feats
32 | self.weights = np.array([])
33 | self.weights_u = np.array([])
34 | self.weights_y = np.array([])
35 | self.A = np.array([])
36 | 
37 | # data properties
38 | self.n_channels_y = 0
39 | self.n_channels_u = 0
40 | self.n_feats_x = 0
41 | self.n_feats_v = 0
42 | self.scaler_target = StandardScaler()
43 | self.scaler_y = StandardScaler()
44 | self.scaler_u = StandardScaler()
45 | 
46 | self.solver = solver
47 | self.model = 0
48 | 
49 | self.regressors_names = list()
50 | self.residuals = 0
51 | 
52 | self.log = log
53 | 
54 | def formulate_regression(self, U, Y):
55 | """
56 | Input:
57 | -----
58 | 
59 | U : numpy array (n_samples x n_times x n_channels_u)
60 | Y : numpy array (n_samples x n_times x n_channels_y)
61 | """
62 | 
63 | if self.log:
64 | print("----------------------------------------- \n")
65 | print("\n FORMULATING REGRESSION... 
\n") 66 | print("----------------------------------------- \n") 67 | 68 | # initializing variables 69 | n_samples, n_times, self.n_channels_y = Y.shape 70 | self.n_channels_u = U.shape[2] 71 | 72 | self.n_feats_x = self.n_channels_y * self.lag_y 73 | self.n_feats_v = self.n_channels_u * self.lag_u 74 | 75 | # scaling input timeseries 76 | if self.scaling: 77 | Y = self.scaler_y.fit_transform(Y.reshape(-1, Y.shape[-1])) 78 | U = self.scaler_u.fit_transform(U.reshape(-1, U.shape[-1])) 79 | Y = Y.reshape(n_samples, n_times, self.n_channels_y) 80 | U = U.reshape(n_samples, n_times, self.n_channels_u) 81 | 82 | # # zero padding for initialization? 83 | # U_ini = np.zeros((n_samples, self.lag_u, self.n_channels_u)) 84 | # Y_ini = np.zeros((n_samples, self.lag_y, self.n_channels_y)) 85 | # U = np.concatenate([U_ini, U], axis=1) 86 | # Y = np.concatenate([Y_ini, Y], axis=1) 87 | 88 | # convert: canonical space timeseries (Y, U) -> state space timeseries (X, V) 89 | X = statespace_transform(Y, self.lag_y) 90 | V = statespace_transform(U, self.lag_u) 91 | 92 | # take common time-length 93 | # Y and U having different lags and nb channels, their statespace timeseries have different lengths 94 | n_common_times = np.min([X.shape[1], V.shape[1]]) 95 | X = X[:, -n_common_times:, :] 96 | V = V[:, -n_common_times:, :] 97 | _, _, n_feats_x = X.shape # (n_samples, n_common_times, n_feats_x) 98 | _, _, n_feats_v = V.shape 99 | 100 | # formulate the problem as a regression 101 | regression_mat = np.concatenate( 102 | [ 103 | X[:, :-1, :].reshape((n_common_times - 1) * n_samples, -1).T, # () 104 | V[:, 1:, :].reshape((n_common_times - 1) * n_samples, -1).T 105 | ], 106 | axis=0).T # (n_common_times * n_samples, n_feats_x + n_feats_u) 107 | 108 | target = Y[:, -n_common_times + 1:, :].reshape(-1, self.n_channels_y) 109 | 110 | if self.log: 111 | print("shape of target: ", target.shape, "\n") 112 | print("shape of regression matrix: ", regression_mat.shape, "\n") 113 | 114 | return target, regression_mat 115 | 116 | def plot_weights(self, summarize=True, names_u=[]): 117 | 118 | if len(names_u) == 0: 119 | names_u = ["U Channel " + str(channel_u) for channel_u in range(self.n_channels_u)] 120 | 121 | if not summarize: 122 | 123 | # plot forcing weights 124 | fig, axes = plt.subplots(self.n_channels_u, 1, sharex=True) 125 | 126 | for channel_u in range(self.n_channels_u): 127 | 128 | axes[channel_u].set_title(names_u[channel_u] + " Weights") 129 | axes[channel_u].plot(self.weights_u[:, :, channel_u].T) 130 | 131 | plt.xlabel("Lags") 132 | plt.tight_layout() 133 | plt.show() 134 | plt.close() 135 | 136 | # plot recurrence weights 137 | # be careful: display optimized for even nb of MEG Princ Comps 138 | fig, axes = plt.subplots(self.n_channels_y // 2, 2, sharex=True) 139 | 140 | for channel_y in range(self.n_channels_y // 2): 141 | axes[channel_y, 0].set_title("Y Channel " + str(channel_y) + " Weights") 142 | axes[channel_y, 0].plot(self.weights_y[:, :, channel_y].T) 143 | 144 | for channel_y in range(self.n_channels_y // 2, self.n_channels_y): 145 | axes[channel_y - (self.n_channels_y // 2), 1].set_title("Y Channel " + 146 | str(channel_y) + " Weights") 147 | axes[channel_y - (self.n_channels_y // 2), 1].plot( 148 | self.weights_y[:, :, channel_y].T) 149 | 150 | plt.xlabel("Lags") 151 | plt.tight_layout() 152 | plt.tight_layout() 153 | plt.show() 154 | plt.close() 155 | 156 | if summarize: 157 | 158 | fig, axes = plt.subplots(2, 1, figsize=(8.15, 3.53)) 159 | 160 | # forcing weights 161 | for channel_u in 
range(self.n_channels_u):
162 | axes[0].fill_between(
163 | range(self.lag_u),
164 | np.mean(self.weights_u[:, :, channel_u]**2,
165 | axis=0), # mean over output channels
166 | label=names_u[channel_u],
167 | alpha=0.25)
168 | axes[0].set_title("Forcing Weights over Lags")
169 | axes[0].legend()
170 | 
171 | # recurrence weights
172 | for channel_y in range(self.n_channels_y):
173 | axes[1].fill_between(
174 | range(self.lag_y),
175 | np.mean(self.weights_y[:, :, channel_y]**2,
176 | axis=0), # mean over output channels
177 | label="channel " + str(channel_y),
178 | alpha=0.25)
179 | axes[1].set_title("Recurrence Weights over Lags")
180 | 
181 | plt.tight_layout()
182 | plt.show()
183 | plt.close()
184 | 
185 | def plot_recurrence_eigvalues(self):
186 | 
187 | plot_eigvalues(self.A)
188 | 
189 | def predict(self, U=None, Y=None, start=0, eval="unrolled"):
190 | 
191 | if self.log:
192 | print("\n-----------------------------------------")
193 | print("\n PREDICT \n")
194 | print("-----------------------------------------\n")
195 | 
196 | # instantiate U and Y (a missing U must be handled before unpacking its shape)
197 | if U is None:
198 | U = np.zeros(Y.shape[:2] + (self.n_channels_u,))
199 | n_epochs, n_times, self.n_channels_u = U.shape
200 | if Y is None:
201 | Y = np.zeros((n_epochs, n_times, self.n_channels_y))
202 | 
203 | 
204 | # augment U and Y_pred with generic initializations for prediction
205 | U_ini = np.zeros((n_epochs, self.lag_u, self.n_channels_u))
206 | Y_ini = np.zeros((n_epochs, self.lag_y, self.n_channels_y))
207 | 
208 | U_augmented = np.concatenate([U_ini, U], axis=1)
209 | Y_augmented = np.concatenate([Y_ini, Y], axis=1)
210 | 
211 | if self.log:
212 | print(U_augmented.shape)
213 | print(Y_augmented.shape)
214 | 
215 | # convert: canonical space timeseries -> state space timeseries
216 | V = statespace_transform(U_augmented, self.lag_u)
217 | X = statespace_transform(Y_augmented, self.lag_y)
218 | _, _, n_feats_v = V.shape
219 | _, _, n_feats_x = X.shape
220 | 
221 | # make sure n_times + 1
222 | V = V[:, -(n_times + 1):, :]
223 | X = X[:, -(n_times + 1):, :]
224 | 
225 | # standard scale
226 | # V_augmented = self.scaler_v.transform(V_augmented.reshape(-1, n_feats_v))
227 | # V_augmented = V_augmented.reshape(n_epochs, n_times + 1, n_feats_v)
228 | 
229 | if self.log:
230 | print("\n Constructing predicted trajectories... 
\n")
231 | 
232 | # initialize pred
233 | Y_pred = list()
234 | 
235 | if start > 0:
236 | for idx in range(start):
237 | Y_pred.append(Y[:, idx, :])
238 | 
239 | if eval == "onestep":
240 | X_true = copy.deepcopy(X)
241 | # one-step evaluation conditions each step on the true state
242 | # X = np.zeros(n_epochs, n_times, self.weights.shape[0])
243 | 
244 | 
245 | 
246 | for t in range(start, n_times):
247 | 
248 | # if self.scaling:
249 | # V_contrib = np.array([(self.scaler_v.transform(V_augmented[epoch, t+1, :][None, :])).flatten()
250 | #                       for epoch in range(n_epochs)])
251 | # X_contrib = np.array([(self.scaler_x.transform(X_pred_augmented[epoch, t, :][None, :])).flatten()
252 | #                       for epoch in range(n_epochs)])
253 | V_contrib = V[:, t + 1, :]
254 | 
255 | if eval == "unrolled":
256 | X_contrib = X[:, t, :]
257 | elif eval == "onestep":
258 | X_contrib = X_true[:, t, :]
259 | 
260 | # currstate = np.concatenate([X_pred_augmented[:, t, :],
261 | #                             V_augmented[:, t+1, :]],
262 | #                             axis=1)
263 | 
264 | currstate = np.concatenate([X_contrib, V_contrib], axis=1)
265 | if self.log:
266 | print("dim of currstate: ", currstate.shape)
267 | # n_feats = currstate.shape[-1]
268 | 
269 | A_reduced = self.weights # (n_channels_y, n_feats)
270 | 
271 | pred = currstate @ A_reduced.T
272 | if self.log:
273 | print("dim of pred: ", pred.shape)
274 | # if self.scaling:
275 | # pred = np.array([(self.scaler_target.inverse_transform(pred[epoch][None, :])).flatten()
276 | #                  for epoch in range(n_epochs)])
277 | # shift the state: prepend the new prediction, drop the oldest lag block
278 | obj2 = X[:, t, :-self.n_channels_y]
279 | if self.log:
280 | print("dim of obj2: ", obj2.shape)
281 | print("dim of X[:, t+1, :]: ", X[:, t + 1, :].shape)
282 | 
283 | 
284 | 
285 | 
286 | 
287 | 
288 | 
289 | X[:, t + 1, :] = np.concatenate(
290 | [
291 | pred, # n_epochs, n_channels_y
292 | obj2
293 | ],
294 | axis=1) # concatenate over state features
295 | 
296 | # Y_pred.append(pred)
297 | 
298 | # Y_pred = np.array(Y_pred)
299 | 
300 | # X_pred = self.scaler_x.inverse_transform(X_pred)
301 | Y_pred = X[:, 1:, :self.n_channels_y]
302 | 
303 | # return np.swapaxes(Y_pred, 0, 1)
304 | return Y_pred
-------------------------------------------------------------------------------- /neural/linear/receptive_field.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
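# RField below is a thin wrapper around mne.decoding.ReceptiveField (a
# temporal receptive field estimated with ridge regression). A minimal
# usage sketch, with hypothetical shapes following this repository's
# (n_samples, n_times, n_channels) convention:
#
#   import numpy as np
#   rf = RField(lag_u=260, penal_weight=1.8)
#   U = np.random.randn(8, 300, 6)    # forcing: 8 epochs, 300 samples, 6 features
#   Y = np.random.randn(8, 300, 40)   # meg: 40 (PCA) channels
#   rf.fit(U, Y)
#   Y_hat = rf.predict(U)             # same shape as Y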
6 | 7 | # import libraries 8 | import numpy as np 9 | from mne.decoding import ReceptiveField 10 | 11 | from ..visuals import plt 12 | 13 | 14 | class RField: 15 | def __init__(self, lag_u, penal_weight=1e3): 16 | self.lag_u = lag_u 17 | self.penal_weight = penal_weight 18 | self.model = ReceptiveField( 19 | tmin=0., tmax=lag_u, sfreq=1., estimator=self.penal_weight) 20 | self.n_channels_u = 0 21 | 22 | def fit(self, U, Y): 23 | 24 | self.n_channels_u = U.shape[2] 25 | # swap 2 first axes for MNE: 26 | # (n_samples, n_times, n_channels) -> (n_times, n_samples, n_channels) 27 | self.model.fit(np.swapaxes(U, 0, 1), np.swapaxes(Y, 0, 1)) 28 | 29 | def plot_weights(self, summarize=True, names_u=[]): 30 | 31 | if len(names_u) == 0: 32 | names_u = [ 33 | "U Channel " + str(channel_u) 34 | for channel_u in range(self.n_channels_u) 35 | ] 36 | 37 | if not summarize: 38 | 39 | # plot forcing weights 40 | fig, axes = plt.subplots(self.n_channels_u, 1, sharex=True) 41 | 42 | for increment, channel_u in enumerate( 43 | list(range(self.n_channels_u))): 44 | 45 | axes[channel_u].set_title(names_u[channel_u] + " Weights") 46 | weights = self.model.coef_[:, channel_u, :].T 47 | axes[channel_u].plot(weights) 48 | 49 | plt.xlabel("Lags") 50 | plt.tight_layout() 51 | plt.show() 52 | plt.close() 53 | 54 | if summarize: 55 | 56 | fig, axes = plt.subplots(2, 1, figsize=(10, 5)) 57 | 58 | # forcing weights 59 | for increment, channel_u in enumerate( 60 | list(range(self.n_channels_u))): 61 | weights = self.model.coef_[:, channel_u, :].T 62 | axes[0].fill_between( 63 | range(weights.shape[0]), 64 | np.sum(weights**2, axis=1), 65 | label=names_u[channel_u], 66 | alpha=0.25) 67 | axes[0].set_title("Forcing Weights over Lags") 68 | axes[0].legend() 69 | 70 | plt.tight_layout() 71 | plt.show() 72 | plt.close() 73 | 74 | def predict(self, U, U_ini=np.array([]), Y_ini=np.array([])): 75 | 76 | return np.swapaxes(self.model.predict(np.swapaxes(U, 0, 1)), 0, 1) 77 | -------------------------------------------------------------------------------- /neural/linear/stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # import libraries 8 | 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import pandas as pd 12 | import scipy 13 | import scipy.linalg 14 | import seaborn as sns 15 | import torch 16 | 17 | 18 | def statespace_transform(Y, lag): 19 | """ 20 | Converts canonical-space timeseries into state-space timeseries 21 | by squeezing multichannel past, up to the chosen lag order, 22 | into a single lag vector. 23 | 24 | The advantage of using the state-space, is being able to write any 25 | finite-dimensional Dynamical System as order 1. 26 | 27 | https://en.wikipedia.org/wiki/Vector_autoregression#Writing_VAR(p)_as_VAR(1) 28 | 29 | Input: 30 | ----- 31 | 32 | Y : numpy array, (n_samples, n_times, n_channels) 33 | Observed timeseries. 34 | 35 | lag: float 36 | Chosen autoregression order. 37 | 38 | Output: 39 | ------ 40 | 41 | Y_statespace : numpy array, (n_samples, n_times_statespace, n_feats_statespace) 42 | Statespace timeseries, where: 43 | n_times_statespace = n_times - lag 44 | n_feats_statespace = n_channels * lag 45 | 46 | """ 47 | 48 | # print("\n Converting to Statespace... 
\n") 49 | 50 | # Initialize 51 | n_samples, n_times, _ = Y.shape 52 | Y_statespace = [] 53 | 54 | for sample in range(n_samples): 55 | 56 | # We want most recent on top = at the beginning in statespace vector 57 | Y_statespace.append( 58 | np.array([ 59 | Y[sample, np.arange(t, t + lag)[::-1], :].flatten() 60 | for t in range(n_times - lag + 1) 61 | ])) 62 | 63 | return np.array(Y_statespace) 64 | 65 | 66 | def statespace_inversetransform(Y_statespace, n_channels, lag): 67 | """ 68 | Input: 69 | ----- 70 | 71 | Y_statespace : numpy array, (n_samples, n_times_statespace, n_feats_statespace) 72 | Statespace timeseries, where: 73 | n_times_statespace = n_times - lag 74 | n_feats_statespace = n_channels * lag 75 | 76 | n_channels: int 77 | Nb of original channels. 78 | 79 | lag: float 80 | Chosen autoregression order. 81 | 82 | 83 | Output: 84 | ----- 85 | 86 | Y : numpy array, (n_samples, n_times, n_channels) 87 | Observed timeseries. 88 | """ 89 | 90 | # print("\n From Statespace to Original space... \n") 91 | 92 | n_samples = Y_statespace.shape[0] 93 | Y = [] 94 | 95 | for sample in range(n_samples): 96 | 97 | Y.append( 98 | np.concatenate([ 99 | Y_statespace[sample, 0, :].reshape(lag, n_channels)[::-1], 100 | Y_statespace[sample, 1:, :n_channels] 101 | ], 102 | axis=0)) 103 | 104 | return np.array(Y) 105 | 106 | 107 | # format for LSTM: (B, T, C) -> (T, B, C) in pytorch 108 | def LSTM_format_transform(data, device): 109 | """data is a numpy array of shape (B, T, C). 110 | We want to output a torch tensor of shape (T, B, C).""" 111 | return torch.from_numpy(data).float().transpose(0, 1) 112 | 113 | 114 | def LSTM_format_inversetransform(data): 115 | """data is a numpy array of shape (T, B, C). 116 | We want to output a torch tensor of shape (B, T, C).""" 117 | return data.transpose(0, 1).cpu().detach().numpy() 118 | 119 | 120 | def svd_fat(X, rank_perc=1.): 121 | 122 | "Quick truncated svd for horizontal matrix." 123 | 124 | sym = X @ X.T 125 | 126 | eigvectors, eigvalues, _ = scipy.linalg.svd(sym) 127 | eigvalues = np.real(eigvalues) 128 | 129 | # trim it 130 | picks = np.arange(int(rank_perc * eigvalues.size)) 131 | # picks = np.where(eigvalues > truncate)[0] 132 | U = eigvectors[:, picks] 133 | D = np.sqrt(eigvalues[picks]) 134 | temp = np.diag(1. / D) @ U.T 135 | Vt = temp @ X 136 | 137 | return U, D, Vt 138 | 139 | 140 | def quick_svd(X, rank_perc=1.): 141 | """Quick, approximate, truncated svd for matrix depending on its 142 | horizontal/vertical ratio.""" 143 | 144 | n_lines, n_cols = X.shape 145 | 146 | # fat matrix 147 | if n_cols > 2 * n_lines: 148 | # print("\n Computing SVD using fat matrix option... \n") 149 | U, D, Vt = svd_fat(X, rank_perc) 150 | return U, D, Vt 151 | 152 | # tall matrix 153 | if n_lines > 2 * n_cols: 154 | # print("\n Computing SVD using tall matrix option... \n") 155 | U, D, Vt = svd_fat(X.T, rank_perc) 156 | return Vt.T, D, U.T 157 | 158 | # squarish matrix 159 | else: 160 | # print("\n Computing SVD directly using scipy function... 
\n")
161 | U, D, Vt = scipy.linalg.svd(X)
162 | 
163 | # truncate
164 | D = np.real(D)
165 | picks = np.arange(int(rank_perc * D.size))
166 | # picks = np.where(D > truncate)[0]
167 | U = U[:, picks]
168 | D = D[picks]
169 | Vt = Vt[picks]
170 | 
171 | return U, D, Vt
172 | 
173 | 
174 | def R_score(Y_true, Y_pred, avg_out="times"):
175 | """
176 | Y_true: ndarray (n_epochs, n_times, n_channels_y)
177 | Y_pred: ndarray (n_epochs, n_times, n_channels_y)
178 | """
179 | 
180 | Y_true = torch.from_numpy(Y_true)
181 | Y_pred = torch.from_numpy(Y_pred)
182 | 
183 | if avg_out == "epochs":
184 | dim = 0
185 | 
186 | elif avg_out == "times":
187 | dim = 1
188 | 
189 | cov = (Y_true * Y_pred).mean(dim)
190 | na, nb = [(i**2).mean(dim)**0.5 for i in [Y_true, Y_pred]]
191 | norms = na * nb
192 | R_matrix = cov / norms
193 | return R_matrix.mean(1).cpu().numpy() # avg over channels
194 | 
195 | 
196 | def R_score_v2(Y_true, Y_pred, score="r", avg_out="times", start=0):
197 | """
198 | Y_true: numpy or torch (B, T, C)
199 | Y_pred: numpy or torch (B, T, C)
200 | """
201 | 
202 | if isinstance(Y_true, np.ndarray):
203 | Y_true = torch.from_numpy(Y_true)
204 | if isinstance(Y_pred, np.ndarray):
205 | Y_pred = torch.from_numpy(Y_pred)
206 | 
207 | if avg_out == "epochs":
208 | dim = 0
209 | 
210 | elif avg_out == "times":
211 | dim = 1
212 | Y_pred, Y_true = Y_pred[:, start:, :], Y_true[:, start:, :]
213 | 
214 | if score == "r":
215 | cov = (Y_true * Y_pred).mean(dim)
216 | na, nb = [(i**2).mean(dim)**0.5 for i in [Y_true, Y_pred]]
217 | norms = na * nb
218 | R_matrix = cov / norms
219 | 
220 | if score == "relativemse":
221 | Y_err = Y_pred - Y_true
222 | R_matrix = (Y_err**2).mean(dim) / (Y_true**2).mean(dim) # despite the name, this holds a relative MSE
223 | 
224 | if type(R_matrix) is not np.ndarray:
225 | R_matrix = R_matrix.cpu().numpy()
226 | 
227 | return R_matrix
228 | 
229 | 
230 | def report_correl(Y_true, Y_pred, path, start):
231 | """
232 | Y_true: ndarray (n_epochs, n_times, n_channels_y)
233 | Y_pred: ndarray (n_epochs, n_times, n_channels_y)
234 | """
235 | 
236 | r_dynamic_epochs = R_score_v2(Y_true, Y_pred, avg_out="epochs")
237 | 
238 | r_average_epochs = R_score_v2(Y_true, Y_true.mean(0, keepdims=True), avg_out="epochs")
239 | 
240 | mse_dynamic_epochs = R_score_v2(Y_true, Y_pred, score="relativemse", avg_out="epochs")
241 | 
242 | mse_average_epochs = R_score_v2(
243 | Y_true, Y_true.mean(0, keepdims=True), score="relativemse", avg_out="epochs")
244 | # r_scalar = R_score(Y_true,
245 | #                    Y_pred,
246 | #                    avg_out="times").mean()
247 | 
248 | r_average_times = R_score_v2(Y_true[:, start:, :], Y_pred[:, start:, :], avg_out="times")
249 | 
250 | r_average_times_evoked = R_score_v2(
251 | Y_true[:, start:, :].mean(0, keepdims=True),
252 | Y_pred[:, start:, :].mean(0, keepdims=True),
253 | avg_out="times")
254 | 
255 | mse_average_times = R_score_v2(
256 | Y_true[:, start:, :], Y_pred[:, start:, :], score="relativemse", avg_out="times")
257 | 
258 | mse_average_times_evoked = R_score_v2(
259 | Y_true[:, start:, :].mean(0, keepdims=True),
260 | Y_pred[:, start:, :].mean(0, keepdims=True),
261 | score="relativemse",
262 | avg_out="times")
263 | 
264 | # print(r_scalar)
265 | 
266 | fig, axes = plt.subplots(2, 4, figsize=(15, 5))
267 | 
268 | # Mean response
269 | axes[0, 0].plot(Y_pred.mean(0))
270 | axes[0, 0].set_title("Predicted Response (Evoked)")
271 | axes[0, 0].axvline(x=start, ls="--")
272 | axes[0, 0].text(x=start, y=0, s="init")
273 | 
274 | axes[1, 0].plot(Y_true.mean(0))
275 | axes[1, 0].set_title("True Response (Evoked)")
276 | 
277 | # Response to one stimulus
278 | axes[0, 1].plot(Y_pred[0])
279 | axes[0, 1].set_title("Predicted Response (Epoch 0)")
280 | axes[0, 1].axvline(start, ls="--")
281 | axes[0, 1].text(x=start, y=0, s="init")
282 | 
283 | axes[1, 1].plot(Y_true[0])
284 | axes[1, 1].set_title("True Response (Epoch 0)")
285 | 
286 | # Dynamic Correlation score
287 | axes[0, 2].plot(r_dynamic_epochs.mean(-1), label="epoch-wise correlation")
288 | # axes[0, 2].plot(r_dynamic_evoked, label="evoked-wise correlation")
289 | axes[0, 2].plot(r_average_epochs.mean(-1), label="baseline correlation")
290 | axes[0, 2].legend()
291 | axes[0, 2].set_title("Correlation along time")
292 | axes[0, 2].set_ylim(0, 1)
293 | axes[0, 2].locator_params(axis='x', nbins=20)
294 | axes[0, 2].locator_params(axis='y', nbins=10)
295 | axes[0, 2].grid()
296 | axes[0, 2].axvline(start, ls="--")
297 | axes[0, 2].text(x=start, y=0, s="init")
298 | 
299 | # Scalar Correlation score
300 | # axes[1, 2].bar([0, 1, 2], [0, r_scalar, 0])
301 | # axes[1, 2].set_title("Correlation Score")
302 | 
303 | # Distributional Correlation score
304 | # epoched
305 | scores = r_average_times.T.flatten()
306 | pca_labels = np.concatenate(
307 | [[idx] * r_average_times.shape[0] for idx in range(r_average_times.shape[1])])
308 | df = pd.DataFrame({"scores": scores, "pca_labels": pca_labels})
309 | sns.boxplot(x="pca_labels", y="scores", data=df, ax=axes[1, 2])
310 | axes[1, 2].set_title("Overall Correlation")
311 | # evoked
312 | scores = r_average_times_evoked.mean(0)
313 | pca_labels = np.arange(r_average_times.shape[-1])
314 | axes[1, 2].plot(pca_labels, scores, label="corr of the trial-mean")
315 | axes[1, 2].legend(bbox_to_anchor=(0, 1), loc='upper left', ncol=1)
316 | 
317 | # Dynamic MSE score
318 | axes[0, 3].plot(mse_dynamic_epochs.mean(-1), label="epoch-wise mse")
319 | # axes[0, 2].plot(r_dynamic_evoked, label="evoked-wise correlation")
320 | axes[0, 3].plot(mse_average_epochs.mean(-1), label="baseline mse")
321 | axes[0, 3].legend()
322 | axes[0, 3].set_title("Relative MSE along time")
323 | axes[0, 3].set_ylim(0, 1)
324 | axes[0, 3].locator_params(axis='x', nbins=20)
325 | axes[0, 3].locator_params(axis='y', nbins=10)
326 | axes[0, 3].grid()
327 | axes[0, 3].axvline(start, ls="--")
328 | axes[0, 3].text(x=start, y=0, s="init")
329 | 
330 | # Distributional MSE score
331 | # epoched
332 | scores = mse_average_times.T.flatten()
333 | pca_labels = np.concatenate(
334 | [[idx] * mse_average_times.shape[0] for idx in range(mse_average_times.shape[1])])
335 | df = pd.DataFrame({"scores": scores, "pca_labels": pca_labels})
336 | sns.boxplot(x="pca_labels", y="scores", data=df, ax=axes[1, 3])
337 | axes[1, 3].set_title("Overall Relative MSE")
338 | # evoked
339 | scores = mse_average_times_evoked.mean(0)
340 | pca_labels = np.arange(mse_average_times.shape[-1])
341 | axes[1, 3].plot(pca_labels, scores, label="rel. 
MSE of the trial-mean")
342 | axes[1, 3].legend(bbox_to_anchor=(0, 1), loc='upper left', ncol=1)
343 | 
344 | plt.tight_layout()
345 | plt.savefig(path)
346 | plt.close()
347 | 
348 | 
349 | def report_correl_all(Y_trues, Y_preds, U_trues, path):
350 | """
351 | Y_trues: list of ndarray (n_epochs, n_times, n_channels_y)
352 | Y_preds: list of ndarray (n_epochs, n_times, n_channels_y)
353 | U_trues: list of ndarray (n_epochs, n_times, n_channels_u)
354 | """
355 | 
356 | rs_dynamic_epochs = list()
357 | rs_dynamic_evoked = list()
358 | rs_scalar = list()
359 | 
360 | for subject, example in enumerate(zip(Y_trues, Y_preds, U_trues)):
361 | 
362 | # unpack
363 | Y_true, Y_pred, U_true = example
364 | 
365 | # calculate metrics for a given subject
366 | r_dynamic_epochs = R_score(Y_true, Y_pred, avg_out="epochs")
367 | 
368 | r_dynamic_evoked = R_score(
369 | Y_true.mean(0, keepdims=True), Y_pred.mean(0, keepdims=True), avg_out="epochs")
370 | 
371 | r_scalar = R_score(Y_true, Y_pred, avg_out="times").mean()
372 | 
373 | # record these metrics
374 | rs_dynamic_epochs.append(r_dynamic_epochs)
375 | rs_dynamic_evoked.append(r_dynamic_evoked)
376 | rs_scalar.append(r_scalar)
377 | 
378 | r_dynamic_epochs_mean = np.mean(rs_dynamic_epochs, axis=0)
379 | r_dynamic_epochs_std = np.std(rs_dynamic_epochs, axis=0)
380 | 
381 | r_dynamic_evoked_mean = np.mean(rs_dynamic_evoked, axis=0)
382 | r_dynamic_evoked_std = np.std(rs_dynamic_evoked, axis=0)
383 | 
384 | r_scalar_mean = np.mean(rs_scalar)
385 | r_scalar_std = np.std(rs_scalar)
386 | 
387 | fig, axes = plt.subplots(3, 3, figsize=(15, 5))
388 | 
389 | # Mean response for a subject
390 | axes[0, 0].plot(Y_preds[0].mean(0))
391 | axes[0, 0].set_title("A Predicted Response (Evoked)")
392 | 
393 | axes[1, 0].plot(Y_trues[0].mean(0))
394 | axes[1, 0].set_title("A True Response (Evoked)")
395 | 
396 | axes[2, 0].plot(U_trues[0].mean(0)[:, 0], label="word presence")
397 | axes[2, 0].set_title("Stimulus (only onset shown here)")
398 | axes[2, 0].legend()
399 | 
400 | # Response to one stimulus for a subject
401 | axes[0, 1].plot(Y_preds[0][0])
402 | axes[0, 1].set_title("A Predicted Response (Epoch 0)")
403 | 
404 | axes[1, 1].plot(Y_trues[0][0])
405 | axes[1, 1].set_title("A True Response (Epoch 0)")
406 | 
407 | axes[2, 1].plot(U_trues[0][0][:, 0], label="word presence")
408 | axes[2, 1].plot(U_trues[0][0][:, 1], label="word length")
409 | axes[2, 1].plot(U_trues[0][0][:, 2], label="word frequency")
410 | axes[2, 1].set_title("A Stimulus (all features)")
411 | axes[2, 1].legend(loc="upper right")
412 | 
413 | # Dynamic Correlation score
414 | axes[0, 2].plot(r_dynamic_epochs_mean, label="epoch-wise correlation", color="#B03A2E")
415 | axes[0, 2].fill_between(
416 | range(r_dynamic_epochs_mean.size),
417 | r_dynamic_epochs_mean - r_dynamic_epochs_std,
418 | r_dynamic_epochs_mean + r_dynamic_epochs_std,
419 | color="#F1948A",
420 | alpha=0.5)
421 | 
422 | axes[0, 2].plot(r_dynamic_evoked_mean, label="evoked-wise correlation", color="#2874A6")
423 | axes[0, 2].fill_between(
424 | range(r_dynamic_evoked_mean.size),
425 | r_dynamic_evoked_mean - r_dynamic_evoked_std,
426 | r_dynamic_evoked_mean + r_dynamic_evoked_std,
427 | color="#85C1E9",
428 | alpha=0.5)
429 | 
430 | axes[0, 2].legend()
431 | axes[0, 2].set_title("Correlation along time")
432 | 
433 | # Scalar Correlation score
434 | axes[1, 2].bar([0, 1, 2], [0, r_scalar_mean, 0], yerr=[0, r_scalar_std, 0])
435 | axes[1, 2].set_title("Correlation Score")
436 | 
437 | plt.tight_layout()
438 | plt.savefig(path)
439 | plt.close()
440 | 
441 
| 442 | def report_correl_across_models(Y_trues, Y_preds, names, path): 443 | """ 444 | Y_trues: list of ndarray (n_epochs, n_times, n_channels_y) 445 | Y_preds: list of ndarray (n_epochs, n_times, n_channels_y) 446 | """ 447 | 448 | rs_dynamic_epochs = list() 449 | rs_scalar = list() 450 | 451 | for Y_true, Y_pred in zip(Y_trues, Y_preds): 452 | 453 | # calculate metrics for a given subject 454 | r_dynamic_epochs = R_score(Y_true, Y_pred, avg_out="epochs") 455 | 456 | r_scalar = R_score(Y_true, Y_pred, avg_out="times").mean() 457 | 458 | # record these metrics 459 | rs_dynamic_epochs.append(r_dynamic_epochs) 460 | rs_scalar.append(r_scalar) 461 | 462 | # correlation (trial-wise) 463 | fig, axes = plt.subplots(2, 1) 464 | 465 | for idx, name in enumerate(names): 466 | axes[0].plot(rs_dynamic_epochs[idx], label=name) 467 | 468 | axes[0].set_xlabel("Time steps") 469 | axes[0].set_title("Correlation (trial-wise)") 470 | axes[0].legend() 471 | 472 | # correlation (time-wise) 473 | axes[1].barh(range(len(names)), rs_scalar) 474 | axes[1].set_yticks(np.arange(len(names))) 475 | axes[1].set_yticklabels(names) 476 | axes[1].set_title("Correlation (time-wise)") 477 | 478 | plt.tight_layout() 479 | plt.savefig(path) 480 | plt.close() 481 | -------------------------------------------------------------------------------- /neural/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import math 7 | 8 | import torch as th 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | from .utils import center_trim 13 | 14 | 15 | class MegPredictor(nn.Module): 16 | def __init__(self, 17 | meg_dim, 18 | forcing_dims, 19 | meg_init=40, 20 | n_subjects=100, 21 | max_length=301, 22 | subject_dim=16, 23 | conv_layers=2, 24 | kernel=4, 25 | stride=2, 26 | conv_channels=256, 27 | lstm_hidden=256, 28 | lstm_layers=2): 29 | super().__init__() 30 | self.forcing_dims = dict(forcing_dims) 31 | self.meg_init = meg_init 32 | 33 | in_channels = meg_dim + 1 + subject_dim + sum(forcing_dims.values()) 34 | 35 | if subject_dim: 36 | self.subject_embedding = nn.Embedding(n_subjects, subject_dim) 37 | else: 38 | self.subject_embedding = None 39 | 40 | channels = conv_channels 41 | encoder = [] 42 | for _ in range(conv_layers): 43 | encoder += [ 44 | nn.Conv1d(in_channels, channels, kernel, stride, padding=kernel // 2), 45 | nn.ReLU(), 46 | ] 47 | in_channels = channels 48 | self.encoder = nn.Sequential(*encoder) 49 | if lstm_layers: 50 | self.lstm = nn.LSTM( 51 | input_size=in_channels, 52 | hidden_size=lstm_hidden, 53 | num_layers=lstm_layers) 54 | in_channels = lstm_hidden 55 | else: 56 | self.lstm = None 57 | self.conv_layers = conv_layers 58 | self.stride = stride 59 | self.kernel = kernel 60 | if conv_layers == 0: 61 | self.decoder = nn.Conv1d(in_channels, meg_dim, 1) 62 | else: 63 | decoder = [] 64 | for index in range(conv_layers): 65 | if index == conv_layers - 1: 66 | channels = meg_dim 67 | decoder += [ 68 | nn.ConvTranspose1d(in_channels, channels, kernel, stride, padding=kernel // 2), 69 | ] 70 | if index < conv_layers - 1: 71 | decoder += [nn.ReLU()] 72 | in_channels = channels 73 | self.decoder = nn.Sequential(*decoder) 74 | 75 | def get_meg_mask(self, meg, forcings): 76 | batch, _, time = meg.size() 77 | mask = th.zeros(batch, 1, time, 
device=meg.device)
78 | mask[:, :, :self.meg_init] = 1.
79 | return mask
80 | 
81 | def valid_length(self, length):
82 | for _ in range(self.conv_layers):
83 | length = math.ceil(length / self.stride) + 1
84 | for _ in range(self.conv_layers):
85 | length = (length - 1) * self.stride
86 | return int(length)
87 | 
88 | def pad(self, x):
89 | length = x.size(-1)
90 | valid_length = self.valid_length(length)
91 | delta = valid_length - length
92 | return F.pad(x, (delta // 2, delta - delta // 2))
93 | 
94 | def forward(self, meg, forcings, subject_id):
95 | forcings = dict(forcings)
96 | batch, _, length = meg.size()
97 | inputs = []
98 | 
99 | mask = self.get_meg_mask(meg, forcings)
100 | meg = meg * mask
101 | inputs += [meg, mask]
102 | 
103 | if self.subject_embedding is not None:
104 | subject = self.subject_embedding(subject_id)
105 | inputs.append(subject.view(batch, -1, 1).expand(-1, -1, length))
106 | 
107 | if self.forcing_dims:
108 | _, forcings = zip(*sorted([(k, v)
109 | for k, v in forcings.items() if k in self.forcing_dims]))
110 | else:
111 | forcings = {}
112 | 
113 | inputs.extend(forcings)
114 | 
115 | x = th.cat(inputs, dim=1)
116 | x = self.pad(x)
117 | x = self.encoder(x)
118 | if self.lstm is not None:
119 | x = x.permute(2, 0, 1)
120 | x, _ = self.lstm(x)
121 | x = x.permute(1, 2, 0)
122 | out = self.decoder(x)
123 | return center_trim(out, length)
-------------------------------------------------------------------------------- /neural/train.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | '''Trains a model with a train_eval_model function.'''
8 | 
9 | from collections import namedtuple
10 | 
11 | import torch as th
12 | from torch import nn
13 | from torch.utils import data
14 | from tqdm import tqdm
15 | 
16 | SavedEval = namedtuple("SavedEval", "megs forcings predictions lengths subjects")
17 | 
18 | 
19 | def permute_forcing(first_mask, forcing, permutation, init=60):
20 | initial = forcing[:, :, init:init + 1]
21 | mask = first_mask > 0
22 | return th.where(mask, initial[permutation], forcing)
23 | 
24 | 
25 | def train_eval_model(dataset,
26 | model,
27 | optimizer=None,
28 | progress=True,
29 | train=True,
30 | save=False,
31 | device="cpu",
32 | batch_size=128,
33 | permut_feature=None,
34 | criterion=nn.MSELoss()):
35 | '''Train and Eval function.
36 | 
37 | Inputs:
38 | ...
39 | - save: if True, the second output is a namedtuple with
40 | megs, [N, C, T]
41 | forcings, dict of values [N, 1, T]
42 | predictions, [N, C, T] # if regularization, T can change
43 | lengths, [N]
44 | subjects, [N]
45 | '''
46 | 
47 | dataloaded = data.DataLoader(dataset, batch_size=batch_size, shuffle=train)
48 | if train:
49 | desc = "train set"
50 | model.train()
51 | else:
52 | desc = "test set"
53 | model.eval()
54 | 
55 | running_loss = 0
56 | 
57 | dl_iter = iter(dataloaded)
58 | if progress:
59 | dl_iter = tqdm(dataloaded, leave=False, ncols=120, total=len(dataloaded), desc=desc)
60 | 
61 | if save:
62 | saved = SavedEval([], [], [], [], [])
63 | 
64 | batch_idx = 0
65 | 
66 | for batch_idx, batch in enumerate(dl_iter):
67 | 
68 | # Unpack batch and load onto device (e.g. 
gpu) 69 | meg, forcings, length, subject_id = batch # [B, C, T], dict of values [B, C, 1], [B], [B] 70 | meg = meg.to(device) 71 | forcings = {k: v.to(device) for k, v in forcings.items()} 72 | subject_id = subject_id.to(device) 73 | true_subject_id = subject_id 74 | length = length.to(device) 75 | 76 | n_batches, channels, n_times = meg.size() 77 | 78 | meg_true = meg 79 | 80 | # Permute an input feature (to measure its importance at test time) 81 | if permut_feature is not None: 82 | permutation = th.randperm(n_batches, device=device) 83 | if permut_feature == "meg": 84 | permutation = permutation.view(-1, 1, 1).expand(-1, meg.size(1), meg.size(-1)) 85 | meg = th.gather(meg, 0, permutation) 86 | elif permut_feature == "subject": 87 | subject_id = th.gather(subject_id, 0, permutation) 88 | else: 89 | forcing = forcings[permut_feature] 90 | forcings[permut_feature] = permute_forcing(forcings["first_mask"], forcing, 91 | permutation) 92 | saved_forcings = forcings 93 | 94 | # Predict, evaluate loss, backprop 95 | meg_pred = model(meg, forcings, subject_id) 96 | loss_train = criterion(meg_pred, meg_true) 97 | loss = criterion(meg_pred[..., model.meg_init:], meg_true[..., model.meg_init:]) 98 | running_loss += loss.item() 99 | 100 | if train: 101 | loss_train.backward() 102 | optimizer.step() 103 | optimizer.zero_grad() 104 | 105 | if save: 106 | # all quantities (meg, forcings, length, subject_id) saved in their original state, 107 | # except forcing which is saved in its permuted state 108 | saved.megs.append(meg_true.cpu()) 109 | saved.forcings.append({k: v.cpu() for k, v in saved_forcings.items()}) 110 | saved.predictions.append(meg_pred.detach().cpu()) 111 | saved.lengths.append(length.cpu()) 112 | saved.subjects.append(true_subject_id.cpu()) 113 | if progress: 114 | dl_iter.set_postfix(loss=running_loss / (batch_idx + 1)) 115 | 116 | n_batches = batch_idx + 1 # idx starts at 0 117 | running_loss /= n_batches # average over batches 118 | 119 | if save: 120 | saved = SavedEval( 121 | megs=th.cat(saved.megs), 122 | forcings={k: th.cat([v[k] for v in saved.forcings]) 123 | for k in forcings}, 124 | predictions=th.cat(saved.predictions), 125 | lengths=th.cat(saved.lengths), 126 | subjects=th.cat(saved.subjects)) 127 | else: 128 | saved = None 129 | 130 | return running_loss, saved 131 | -------------------------------------------------------------------------------- /neural/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import os.path 7 | import shutil 8 | 9 | import numpy as np 10 | 11 | 12 | def create_directory(path, overwrite=False): 13 | 14 | # if it is not there, create it 15 | if not os.path.exists(path): 16 | os.makedirs(path) 17 | 18 | # if it is there and overwrite, remove then recreate 19 | if os.path.exists(path) and overwrite: 20 | shutil.rmtree(path) 21 | os.makedirs(path) 22 | 23 | return 0 24 | 25 | 26 | def get_metrics(Y_true, Y_pred): 27 | '''Computes the correlation of two [B, T, C] tensors 28 | over the first dimension. 29 | In this case, yields epoch-wise correlation per 30 | time step and channel. 
/neural/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import os.path
7 | import shutil
8 | 
9 | import numpy as np
10 | 
11 | 
12 | def create_directory(path, overwrite=False):
13 | 
14 |     # if it is not there, create it
15 |     if not os.path.exists(path):
16 |         os.makedirs(path)
17 | 
18 |     # if it is there and overwrite, remove then recreate
19 |     if os.path.exists(path) and overwrite:
20 |         shutil.rmtree(path)
21 |         os.makedirs(path)
22 | 
23 |     return 0
24 | 
25 | 
26 | def get_metrics(Y_true, Y_pred):
27 |     '''Computes the correlation of two [N, T, C] tensors
28 |     over the first dimension.
29 |     In this case, yields epoch-wise correlation per
30 |     time step and channel.
31 | 
32 |     Inputs:
33 |     - Y_true: torch tensor [N, T, C], truth
34 |     - Y_pred: torch tensor [N, T, C], prediction
35 | 
36 |     Output:
37 |     - R_matrix: torch tensor [T, C] of Pearson R scores
38 |     '''
39 |     dim = 0  # avg-out epochs
40 | 
41 |     Y_true = Y_true - Y_true.mean(axis=dim, keepdims=True)
42 |     Y_pred = Y_pred - Y_pred.mean(axis=dim, keepdims=True)
43 |     cov = (Y_true * Y_pred).mean(dim)
44 |     na, nb = [(i**2).mean(dim)**0.5 for i in [Y_true, Y_pred]]
45 |     norms = na * nb
46 |     R_matrix = cov / norms  # shape (T, C)
47 | 
48 |     return R_matrix
49 | 
50 | 
51 | def center_trim(tensor, reference):
52 |     if hasattr(reference, "size"):
53 |         reference = reference.size(-1)
54 |     delta = tensor.size(-1) - reference
55 |     if delta < 0:
56 |         raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.")
57 |     if delta:
58 |         tensor = tensor[..., delta // 2:-(delta - delta // 2)]
59 |     return tensor
60 | 
61 | 
62 | def inverse(mean, scaler, pca, Y):
63 |     pca_channels, full_channels = pca.shape
64 |     trials, time, channels = Y.shape
65 | 
66 |     Y = Y + mean.numpy()
67 |     Y = Y.reshape(-1, channels)
68 |     Y = scaler.inverse_transform(Y)
69 |     return np.reshape(Y @ pca, (trials, time, full_channels))
--------------------------------------------------------------------------------
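As a quick sanity check of `get_metrics`: a perfect prediction should score R = 1 for every time step and channel. A minimal sketch, assuming the repo root is on `PYTHONPATH`; the shapes are arbitrary:

```python
import torch as th
from neural.utils import get_metrics

Y_true = th.randn(32, 100, 8)            # [N epochs, T times, C channels]
R = get_metrics(Y_true, Y_true.clone())  # compare the truth with itself
print(R.shape, R.min().item())           # torch.Size([100, 8]) ~1.0
```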
/neural/utils_mous.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import os
7 | import os.path as op
8 | import glob
9 | import shutil
10 | 
11 | import numpy as np
12 | import pandas as pd
13 | from wordfreq import zipf_frequency as word_frequency
14 | from Levenshtein import editops
15 | 
16 | 
17 | def create_directory(path, overwrite=False):
18 |     '''Create (or check for, or overwrite) directory.
19 |     Input:
20 |         path (str): path of the directory to be created
21 |         overwrite (bool): default False
22 |     '''
23 |     # if it is not there, create it
24 |     if not os.path.exists(path):
25 |         os.makedirs(path)
26 | 
27 |     # if it is there and overwrite, remove then recreate
28 |     if os.path.exists(path) and overwrite:
29 |         shutil.rmtree(path)
30 |         os.makedirs(path)
31 | 
32 | 
33 | def get_word_length(words):
34 |     '''Compute the length of words.
35 |     Input:
36 |         words (array of str): words
37 |     Output:
38 |         word_lengths (array of int): word lengths
39 |     '''
40 |     return np.array([len(word) for word in words])
41 | 
42 | 
43 | def get_word_freq(words, language="nl"):
44 |     '''Compute the frequency of words in a language.
45 |     Input:
46 |         words (array of str): words
47 |         language (str): default 'nl' for Dutch
48 | 
49 |     Output:
50 |         word_freqs (array of float): Zipf word frequencies
51 |     '''
52 |     return np.array([word_frequency(word, language)
53 |                      for word in words])
54 | 
55 | 
56 | def get_word_n_remaining(words):
57 |     '''Compute the number of remaining words in a sentence array.
58 |     Input:
59 |         words (array of str): words
60 |     Output:
61 |         words_remaining (array of int): number of remaining words
62 |     '''
63 |     return np.array([len(words[idx:])
64 |                      for idx in range(len(words))])
65 | 
66 | 
67 | def setup_logfiles(data_path, cache):
68 |     '''Creates a legible csv table (subject id, meg file, etc.)
69 |     to access files in the MOUS database.
70 |     Input:
71 |         data_path (str): path to the MOUS database
72 |     '''
73 |     fname_logfiles = os.path.join(cache, 'log_files.csv')
74 | 
75 |     if not op.isfile(fname_logfiles):
76 |         tasks = dict(visual='Vis', auditory='Aud')
77 |         log_files = list()
78 | 
79 |         for task in ('visual', 'auditory'):
80 | 
81 |             log_path = os.path.join(data_path, 'sourcedata', 'meg_task')
82 |             files = glob.glob(os.path.join(log_path, "*-MEG-MOUS-%s*.log") % tasks[task])
83 | 
84 |             for file in np.sort(files):
85 |                 subject = file.split('-')[0].split('/')[-1]
86 |                 log_files.append(
87 |                     dict(subject=int(subject[1:]),
88 |                          task=task,
89 |                          log_id=int(file.split('-')[1]),
90 |                          log_file=op.join('sourcedata', 'meg_task',
91 |                                           file.split('/')[-1]),
92 |                          meg_file=op.join(
93 |                              'sub-' + subject, 'meg',
94 |                              'sub-%s_task-%s_meg.ds' % (subject, task))))
95 |         log_files = pd.DataFrame(log_files)
96 |         # Remove corrupted logs
97 |         log_files = log_files.loc[(log_files.subject != 1006)
98 |                                   & (log_files.subject != 1017)]
99 |         log_files.to_csv(fname_logfiles)
100 |     return pd.read_csv(fname_logfiles)
101 | 
102 | 
103 | def setup_stimuli(data_path, cache):
104 |     '''Creates a legible csv table for the word stimuli used
105 |     in the MOUS database.
106 |     Input:
107 |         data_path (str): path to the MOUS database
108 |     '''
109 |     fname_stimuli = os.path.join(cache, 'stimuli.csv')
110 |     if not op.isfile(fname_stimuli):
111 | 
112 |         source = op.join(data_path, 'stimuli', 'stimuli.txt')
113 | 
114 |         with open(source, 'r') as f:
115 |             stimuli = f.read()
116 |         while '  ' in stimuli:
117 |             stimuli = stimuli.replace('  ', ' ')
118 | 
119 |         # clean up
120 |         stimuli = stimuli.split('\n')[:-1]
121 |         stim_id = [int(s.split(' ')[0]) for s in stimuli]
122 |         sequences = [' '.join(s.split(' ')[1:]) for s in stimuli]
123 |         stimuli = pd.DataFrame([
124 |             dict(index=idx, sequence=seq)
125 |             for idx, seq in zip(stim_id, sequences)
126 |         ])
127 |         stimuli.to_csv(fname_stimuli, index=False)
128 |     return pd.read_csv(fname_stimuli, index_col='index')
129 | 
130 | 
131 | def _parse_log(log_fname):
132 |     '''Auxiliary function to preprocess the log_files table.
133 |     TODO
134 |     '''
135 |     with open(log_fname, 'r') as f:
136 |         text = f.read()
137 | 
138 |     # Fixes broken inputs
139 |     text = text.replace('.\n', '.')
140 | 
141 |     # file is made of two blocks
142 |     block1, block2 = text.split('\n\n\n')
143 | 
144 |     # read first header
145 |     header1 = block1.replace(' ', '_').split('\n')[3].split('\t')
146 |     header1[6] = 'time_uncertainty'
147 |     header1[8] = 'duration_uncertainty'
148 | 
149 |     # read first data
150 |     df1 = pd.DataFrame([s.split('\t') for s in block1.split('\n')][5:],
151 |                        columns=header1)
152 |     # the two dataframes are only synced on certain rows
153 |     common_samples = ('Picture', 'Sound', 'Nothing')
154 |     sel = df1['Event_Type'].apply(lambda x: x in common_samples)
155 |     index = df1.loc[sel].index
156 | 
157 |     # read second header
158 |     header2 = block2.replace(' ', '_').split('\n')[0].split('\t')
159 |     header2[7] = 'time_uncertainty'
160 |     header2[9] = 'duration_uncertainty'
161 | 
162 |     # read second data
163 |     df2 = pd.DataFrame([s.split('\t') for s in block2.split('\n')[2:-1]],
164 |                        columns=header2,
165 |                        index=index)
166 | 
167 |     # remove duplicates
168 |     duplicates = np.intersect1d(df1.keys(), df2.keys())
169 |     for key in duplicates:
170 |         assert (df1.loc[index, key] == df2[key].fillna('')).all()
171 |         df2.pop(key)
172 | 
173 |     log = pd.concat((df1, df2), axis=1)
174 |     return log
175 | 
176 | 
177 | def _clean_log(log):
178 |     '''Auxiliary function to preprocess the log_files table.
179 | Cleans and reformats the log_files table. 180 | Input: 181 | log (pandas table) 182 | Output: 183 | log_cleaned (pandas table) 184 | ''' 185 | # Relabel condition: only applies to sample where condition changes 186 | translate = dict( 187 | ZINNEN='sentence', 188 | WOORDEN='word_list', 189 | FIX='fix', 190 | QUESTION='question', 191 | Response='response', 192 | ISI='isi', 193 | blank='blank', 194 | ) 195 | for key, value in translate.items(): 196 | sel = log.Code.astype(str).str.contains(key) 197 | log.loc[sel, 'condition'] = value 198 | log.loc[log.Code == '', 'condition'] = 'blank' 199 | 200 | # Annotate sequence idx and extend context to all trials 201 | start = 0 202 | block = 0 203 | context = 'init' 204 | log['new_context'] = False 205 | query = 'condition in ("word_list", "sentence")' 206 | for idx, row in log.query(query).iterrows(): 207 | log.loc[start:idx, 'context'] = context 208 | log.loc[start:idx, 'block'] = block 209 | log.loc[idx, 'new_context'] = True 210 | context = row.condition 211 | block += 1 212 | start = idx 213 | log.loc[start:, 'context'] = context 214 | log.loc[start:, 'block'] = block 215 | 216 | # Format time 217 | log['time'] = 0 218 | idx = log.Time.str.isnumeric() == True # noqa 219 | log.loc[idx, 'time'] = log.loc[idx, 'Time'].astype(float) / 1e4 220 | 221 | # Extract individual word 222 | log.loc[log.condition.isna(), 'condition'] = 'word' 223 | idx = log.condition == 'word' 224 | words = log.Code.str.strip('0123456789 ') 225 | log.loc[idx, 'word'] = words.loc[idx] 226 | sel = log.query('word=="" and condition=="word"').index 227 | log.loc[sel, 'word'] = np.nan 228 | log.loc[log.word.isna() & (log.condition == "word"), 'condition'] = 'blank' 229 | return log 230 | 231 | 232 | def _add_stim_id(log, verbose, stimuli): 233 | '''Auxiliary function to preprocess the log_files table. 234 | Matches stimulus information between the log_files and stimuli tables. 
235 |     Input:
236 |         log (pandas table)
237 |         verbose (bool): if True, prints commentary
238 |         stimuli (pandas table)
239 |     Output:
240 |         log_completed (pandas table)
241 |     '''
242 |     # Find beginning of each sequence (word list or sentence)
243 |     start = 0
244 |     sequence_pos = -1
245 |     for idx, row in log.query('condition == "fix"').iterrows():
246 |         if sequence_pos >= 0:
247 |             log.loc[start:idx, 'sequence_pos'] = sequence_pos
248 |         sequence_pos += 1
249 |         start = idx
250 |     log.loc[start:, 'sequence_pos'] = sequence_pos
251 | 
252 |     # Find corresponding stimulus id
253 |     stim_id = 0
254 |     lower30 = lambda s: s[:30].lower()  # noqa
255 |     stimuli['first_30_chars'] = stimuli.sequence.apply(lower30)
256 |     sel = slice(0, 0)
257 |     for pos, row in log.groupby('sequence_pos'):
258 |         if pos == -1:
259 |             continue
260 | 
261 |         # select words in this sequence
262 |         sel = row.condition == "word"
263 |         if not sum(sel):
264 |             continue
265 | 
266 |         # match with stimuli
267 |         first_30_chars = ' '.join(row.loc[sel, 'word'])[:30].lower()  # noqa
268 |         stim_id = stimuli.query('first_30_chars == @first_30_chars').index
269 |         assert len(stim_id) == 1
270 |         stim_id = stim_id[0]
271 | 
272 |         n_words = len(stimuli.loc[stim_id, 'sequence'].split(' '))
273 |         if (n_words != sum(sel)) and verbose:
274 |             print('mismatch of %i words in %s (stim %i)' %
275 |                   (n_words - sum(sel), pos, stim_id))
276 |             print('stim: %s' % stimuli.loc[stim_id, 'sequence'])
277 |             print('log: %s' % ' '.join(row.loc[sel, 'word']))
278 | 
279 |         # Update
280 |         log.loc[row.index, 'stim_id'] = stim_id
281 |     return log
282 | 
283 | 
284 | def read_log(log_fname, stimuli, task='auto', verbose=False):
285 |     '''Reads and preprocesses (using auxiliary functions) a log_files table.
286 |     Matches stimulus information between the log_files and stimuli tables.
287 |     Input:
288 |         log_fname (str): path to the log_files table
289 |         stimuli (pandas table)
290 |         task (str): if 'auto', deduce whether 'visual' or 'auditory', else specify
291 |         verbose (bool): default False, does not print commentary
292 |     Output:
293 |         log (pandas table): preprocessed log_files table
294 |     '''
295 |     log = _parse_log(log_fname)
296 |     log = _clean_log(log)
297 |     if task == 'auto':
298 |         task = 'visual' if log_fname[-7:] == 'Vis.log' else 'auditory'
299 |     if task == 'visual':
300 |         log = _add_stim_id(log, verbose=verbose, stimuli=stimuli)
301 |     return log
302 | 
303 | 
304 | def get_log_times(log, events, sfreq):
305 |     '''
306 |     Adds time of occurrence of events to the log_files table.
307 | Input: 308 | log (pandas table): log_files table 309 | events (array): cf https://mne.tools/dev/auto_tutorials/raw/plot_20_event_arrays.html 310 | sfreq (int): sampling frequency 311 | Output: 312 | log (pandas table): log_files table enriched with time of events information 313 | ''' 314 | sel = np.sort(np.r_[np.where(events[:, 2] == 20)[0], # fixation 315 | np.where(events[:, 2] == 10)[0] # context 316 | ]) 317 | common_megs = events[sel] 318 | common_logs = log.query('(new_context == True) or condition=="fix"') 319 | 320 | last_log = common_logs.time[0] 321 | last_meg = common_megs[0, 0] 322 | last_idx = 0 323 | assert len(common_megs) == len(common_logs) 324 | for common_meg, (idx, common_log) in zip(common_megs, 325 | common_logs.iterrows()): 326 | 327 | if common_meg[2] == 20: 328 | assert common_log.condition == 'fix' 329 | else: 330 | assert common_log.condition in ('sentence', 'word_list') 331 | 332 | log.loc[idx, 'meg_time'] = common_meg[0] / sfreq 333 | 334 | sel = slice(last_idx + 1, idx) 335 | times = log.loc[sel, 'time'] - last_log + last_meg / sfreq 336 | assert np.all(np.isfinite(times)) 337 | log.loc[sel, 'meg_time'] = times 338 | 339 | last_log = common_log.time 340 | last_meg = common_meg[0] 341 | last_idx = idx 342 | 343 | assert np.isfinite(last_log) * np.isfinite(last_meg) 344 | 345 | # last block 346 | sel = slice(last_idx + 1, None) 347 | times = log.loc[sel, 'time'] - last_log + last_meg / sfreq 348 | log.loc[sel, 'meg_time'] = times 349 | log['meg_sample'] = np.array(log.meg_time.values * sfreq, int) 350 | return log 351 | 352 | 353 | def match_list(A, B, on_replace='delete'): 354 | """Match two lists of different sizes and return corresponding indices 355 | Parameters 356 | ---------- 357 | A: list | array, shape (n,) 358 | The values of the first list 359 | B: list | array: shape (m, ) 360 | The values of the second list 361 | Returns 362 | ------- 363 | A_idx : array 364 | The indices of the A list that match those of the B 365 | B_idx : array 366 | The indices of the B list that match those of the A 367 | """ 368 | unique = np.unique(np.r_[A, B]) 369 | label_encoder = dict((k, v) for v, k in enumerate(unique)) 370 | 371 | def int_to_unicode(array): 372 | return ''.join([str(chr(label_encoder[ii])) for ii in array]) 373 | 374 | changes = editops(int_to_unicode(A), int_to_unicode(B)) 375 | B_sel = np.arange(len(B)).astype(float) 376 | A_sel = np.arange(len(A)).astype(float) 377 | for type, val_a, val_b in changes: 378 | if type == 'insert': 379 | B_sel[val_b] = np.nan 380 | elif type == 'delete': 381 | A_sel[val_a] = np.nan 382 | elif on_replace == 'delete': 383 | # print('delete replace') 384 | A_sel[val_a] = np.nan 385 | B_sel[val_b] = np.nan 386 | elif on_replace == 'keep': 387 | # print('keep replace') 388 | pass 389 | else: 390 | raise NotImplementedError 391 | B_sel = B_sel[np.where(~np.isnan(B_sel))] 392 | A_sel = A_sel[np.where(~np.isnan(A_sel))] 393 | assert len(B_sel) == len(A_sel) 394 | return A_sel.astype(int), B_sel.astype(int) 395 | -------------------------------------------------------------------------------- /neural/visuals.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | '''Largely unused. 8 | Functions for visualization. 
'''
9 | 
10 | import os
11 | 
12 | import matplotlib
13 | import mne
14 | import numpy as np
15 | import pandas as pd
16 | import scipy
17 | import seaborn as sns
18 | 
19 | 
20 | # Ugly hack because my editor reorders imports automatically
21 | def init():
22 |     matplotlib.use('Agg')
23 |     from matplotlib import pyplot as plt
24 |     return plt
25 | 
26 | 
27 | def R_score_v2(Y_true, Y_pred, score="r", avg_out="times"):
28 |     """
29 |     Y_true: torch tensor (n_epochs, n_times, n_channels_y)
30 |     Y_pred: torch tensor (n_epochs, n_times, n_channels_y)
31 |     """
32 | 
33 |     if avg_out == "epochs":
34 |         dim = 0
35 | 
36 |     elif avg_out == "times":
37 |         dim = 1
38 | 
39 |     if score == "r":
40 |         Y_true = Y_true - Y_true.mean(dim, keepdim=True)
41 |         Y_pred = Y_pred - Y_pred.mean(dim, keepdim=True)
42 |         cov = (Y_true * Y_pred).mean(dim)
43 |         na, nb = [(i**2).mean(dim)**0.5 for i in [Y_true, Y_pred]]
44 |         norms = na * nb
45 |         R_matrix = cov / norms
46 | 
47 |     if score == "relativemse":
48 |         Y_err = Y_pred - Y_true
49 |         R_matrix = (Y_err**2).mean(dim) / (Y_true**2).mean(dim)  # rename this score matrix!!!
50 | 
51 |     return R_matrix.cpu().numpy()
52 | 
53 | 
54 | def report_correl(Y_true, Y_pred, path, start, ref=None):
55 |     """
56 |     Y_true: torch tensor (n_epochs, n_times, n_channels_y)
57 |     Y_pred: torch tensor (n_epochs, n_times, n_channels_y)
58 |     """
59 | 
60 |     r_dynamic_epochs = R_score_v2(Y_true, Y_pred, avg_out="epochs")
61 | 
62 |     if ref is not None:
63 |         r_average_epochs = R_score_v2(Y_true, ref, avg_out="epochs")
64 |         ratio = (r_dynamic_epochs / r_average_epochs).mean(-1)
65 |     else:
66 |         r_average_epochs = R_score_v2(Y_true, Y_true.mean(0, keepdims=True), avg_out="epochs")
67 | 
68 |     mse_dynamic_epochs = R_score_v2(Y_true, Y_pred, score="relativemse", avg_out="epochs")
69 | 
70 |     mse_average_epochs = R_score_v2(
71 |         Y_true, Y_true.mean(0, keepdims=True), score="relativemse", avg_out="epochs")
72 |     # r_scalar = R_score(Y_true,
73 |     #                    Y_pred,
74 |     #                    avg_out="times").mean()
75 | 
76 |     r_average_times = R_score_v2(Y_true[:, start:, :], Y_pred[:, start:, :], avg_out="times")
77 | 
78 |     r_average_times_evoked = R_score_v2(
79 |         Y_true[:, start:, :].mean(0, keepdims=True),
80 |         Y_pred[:, start:, :].mean(0, keepdims=True),
81 |         avg_out="times")
82 | 
83 |     mse_average_times = R_score_v2(
84 |         Y_true[:, start:, :], Y_pred[:, start:, :], score="relativemse", avg_out="times")
85 | 
86 |     mse_average_times_evoked = R_score_v2(
87 |         Y_true[:, start:, :].mean(0, keepdims=True),
88 |         Y_pred[:, start:, :].mean(0, keepdims=True),
89 |         score="relativemse",
90 |         avg_out="times")
91 | 
92 |     # print(r_scalar)
93 | 
94 |     fig, axes = plt.subplots(2, 4, figsize=(15, 5))
95 | 
96 |     # Mean response
97 |     axes[0, 0].plot(Y_pred.mean(0))
98 |     axes[0, 0].set_title("Predicted Response (Evoked)")
99 |     axes[0, 0].axvline(x=start, ls="--")
100 |     axes[0, 0].text(x=start, y=0, s="init")
101 | 
102 |     axes[1, 0].plot(Y_true.mean(0))
103 |     axes[1, 0].set_title("True Response (Evoked)")
104 | 
105 |     # Response to one stimulus
106 |     axes[0, 1].plot(Y_pred[0])
107 |     axes[0, 1].set_title("Predicted Response (Epoch 0)")
108 |     axes[0, 1].axvline(start, ls="--")
109 |     axes[0, 1].text(x=start, y=0, s="init")
110 | 
111 |     axes[1, 1].plot(Y_true[0])
112 |     axes[1, 1].set_title("True Response (Epoch 0)")
113 | 
114 |     # Dynamic Correlation score
115 |     if ref is not None:
116 |         # axes[0, 2].plot(
117 |         #     -np.log10(1e-8 + np.clip(1 - ratio, 0, 1)),
118 |         #     label="log10 1 - ratio of correl")
119 |         axes[0, 2].plot(ratio)
120 |     else:
121 |         axes[0, 2].plot(r_dynamic_epochs.mean(-1), label="epoch-wise 
correlation") 122 | axes[0, 2].plot(r_average_epochs.mean(-1), label="baseline correlation") 123 | axes[0, 2].set_ylim(0, 1) 124 | axes[0, 2].legend() 125 | axes[0, 2].set_title("Correlation along time") 126 | axes[0, 2].locator_params(axis='x', nbins=20) 127 | axes[0, 2].locator_params(axis='y', nbins=10) 128 | axes[0, 2].grid() 129 | axes[0, 2].axvline(start, ls="--") 130 | axes[0, 2].text(x=start, y=0, s="init") 131 | 132 | # Scalar Correlation score 133 | # axes[1, 2].bar([0, 1, 2], [0, r_scalar, 0]) 134 | # axes[1, 2].set_title("Correlation Score") 135 | 136 | # Distributional Correlation score 137 | # epoched 138 | scores = r_average_times.T.flatten() 139 | pca_labels = np.concatenate( 140 | [[idx] * r_average_times.shape[0] for idx in range(r_average_times.shape[1])]) 141 | df = pd.DataFrame({"scores": scores, "pca_labels": pca_labels}) 142 | sns.boxplot(x="pca_labels", y="scores", data=df, ax=axes[1, 2]) 143 | axes[1, 2].set_title("Overall Correlation") 144 | # evoked 145 | scores = r_average_times_evoked.mean(0) 146 | pca_labels = np.arange(r_average_times.shape[-1]) 147 | axes[1, 2].plot(pca_labels, scores, label="corr of the trial-mean") 148 | axes[1, 2].legend(bbox_to_anchor=(0, 1), loc='upper left', ncol=1) 149 | 150 | # Dynamic MSE score 151 | axes[0, 3].plot(mse_dynamic_epochs.mean(-1), label="epoch-wise mse") 152 | # axes[0, 2].plot(r_dynamic_evoked, label="evoked-wise correlation") 153 | axes[0, 3].plot(mse_average_epochs.mean(-1), label="baseline mse") 154 | axes[0, 3].legend() 155 | axes[0, 3].set_title("Relative MSE along time") 156 | axes[0, 3].set_ylim(0, 1) 157 | axes[0, 3].locator_params(axis='x', nbins=20) 158 | axes[0, 3].locator_params(axis='y', nbins=10) 159 | axes[0, 3].grid() 160 | axes[0, 3].axvline(start, ls="--") 161 | axes[0, 3].text(x=start, y=0, s="init") 162 | 163 | # Distributional MSE score 164 | # epoched 165 | scores = mse_average_times.T.flatten() 166 | pca_labels = np.concatenate( 167 | [[idx] * mse_average_times.shape[0] for idx in range(mse_average_times.shape[1])]) 168 | df = pd.DataFrame({"scores": scores, "pca_labels": pca_labels}) 169 | sns.boxplot(x="pca_labels", y="scores", data=df, ax=axes[1, 3]) 170 | axes[1, 3].set_title("Overall Relative MSE") 171 | # evoked 172 | scores = mse_average_times_evoked.mean(0) 173 | pca_labels = np.arange(mse_average_times.shape[-1]) 174 | axes[1, 3].plot(pca_labels, scores, label="rel. 
MSE of the trial-mean") 175 | axes[1, 3].legend(bbox_to_anchor=(0, 1), loc='upper left', ncol=1) 176 | 177 | plt.tight_layout() 178 | plt.savefig(path) 179 | plt.close() 180 | 181 | 182 | def make_train_test_curve(train_losses, test_losses, path, show=False, save=True): 183 | plt.plot(train_losses, label="train") 184 | plt.plot(test_losses, label="test") 185 | plt.gca().locator_params(axis='x', nbins=20) 186 | plt.gca().locator_params(axis='y', nbins=10) 187 | plt.gca().grid() 188 | plt.xlabel("Nb epochs") 189 | plt.ylabel("MSE") 190 | plt.title("Losses") 191 | plt.legend() 192 | if show: 193 | plt.show() 194 | if save: 195 | plt.savefig(path) 196 | plt.close() 197 | 198 | 199 | def plot_eigvalues(A, add_to_title=""): 200 | 201 | # get eigenvalues 202 | eigs = scipy.linalg.eigvals(A) 203 | 204 | # plotting the eigenvalues 205 | plt.plot(eigs.real, eigs.imag, 'o') 206 | plt.plot(eigs.real, eigs.imag, 'rx') 207 | 208 | # plot unit circle 209 | thetas = np.linspace(0, 2 * np.pi) 210 | plt.plot(np.cos(thetas), np.sin(thetas), ls='--', c='gray') 211 | plt.xlabel("Re") 212 | plt.ylabel("Im") 213 | plt.xlim(-1.1, 1.1) 214 | plt.ylim(-1.1, 1.1) 215 | plt.axis('equal') 216 | 217 | # add stability thresholds 218 | plt.axvline(-1, linestyle="--") 219 | plt.axvline(1, linestyle="--") 220 | plt.text(1, 5, "stability threshold", rotation=90, verticalalignment='center') 221 | plt.text(-1, 5, "stability threshold", rotation=90, verticalalignment='center') 222 | 223 | plt.title(add_to_title + "\n" + "Recurrence Matrix Eigenvalues") 224 | plt.show() 225 | plt.close() 226 | 227 | 228 | def plot_score_per_time(scores_per_time, labels, sfreq, ref=None, path=None, title=None): 229 | """ 230 | input: 231 | -- scores_per_time : list of score_per_time arrays of shape (S, T, C) 232 | -- labels: list of labels of same len as scores 233 | -- ref: reference score_per_time, as in an upper bound 234 | """ 235 | 236 | n_models = len(scores_per_time) 237 | n_subjects, n_times, n_channels = scores_per_time[0].shape 238 | 239 | # convert times to ms 240 | times = (np.arange(n_times) / sfreq) * 1000 241 | 242 | for idx in range(n_models): 243 | 244 | # current elements 245 | score_per_time = scores_per_time[idx].mean(-1) 246 | if ref is not None: 247 | score_per_time = ref.mean(-1) - score_per_time 248 | label = labels[idx] 249 | 250 | # take mean and SEM error over subjects 251 | score_per_time_mean = score_per_time.mean(0) 252 | score_per_time_error = score_per_time.std(0) / np.sqrt(n_subjects) 253 | 254 | # plot 255 | plt.plot(times, score_per_time_mean, label=label) 256 | plt.fill_between( 257 | times, 258 | y1=score_per_time_mean + score_per_time_error, 259 | y2=score_per_time_mean - score_per_time_error, 260 | alpha=0.5) 261 | 262 | plt.legend() 263 | plt.tight_layout() 264 | if title is not None: 265 | plt.title(title) 266 | else: 267 | plt.title("Sample-wise Correlation between predicted and truth") 268 | plt.xlabel("Time (ms)") 269 | 270 | if path is not None: 271 | plt.savefig(os.path.join(path, "scores_per_time.png")) 272 | 273 | 274 | def plot_score_per_time_topo(scores_per_time, 275 | labels, 276 | info, 277 | sfreq=120, 278 | ref=None, 279 | path=None, 280 | title=None): 281 | """ 282 | input: 283 | -- scores_per_time : list of score_per_time arrays of shape (S, T, C) 284 | -- labels: list of labels of same len as scores 285 | -- ref: reference score_per_time, as in an upper bound 286 | """ 287 | 288 | n_models = len(scores_per_time) 289 | n_subjects, n_times, n_channels = scores_per_time[0].shape 290 | 291 | 
fig, axes = plt.subplots(n_models, 2, squeeze=False)  # squeeze=False keeps axes 2-D when n_models == 1
292 | 
293 |     for idx in range(n_models):
294 | 
295 |         # current elements
296 |         score_per_time = scores_per_time[idx]
297 |         if ref is not None:
298 |             score_per_time = ref - score_per_time
299 |         label = labels[idx]
300 | 
301 |         # take mean over subjects
302 |         evo_data = score_per_time.mean(0).T  # (C, T)
303 |         evo = mne.EvokedArray(evo_data, info=info, tmin=-.500)
304 | 
305 |         # plot time course
306 |         mne.viz.plot_evoked(
307 |             evo,
308 |             spatial_colors=True,
309 |             scalings=dict(mag=1.),
310 |             show=False,
311 |             axes=axes[idx, 0],
312 |             titles='')
313 |         # ax[0].set_ylim(-.01, .11)
314 |         axes[idx, 0].set_xlabel('time')
315 |         axes[idx, 0].set_ylabel('$\\Delta{}r$')
316 |         axes[idx, 0].set_title('Feature %s' % label)
317 | 
318 |         # plot topo
319 |         vmax = evo_data.mean(1).max()
320 |         im, _ = mne.viz.plot_topomap(
321 |             evo_data.mean(1),
322 |             evo.info,
323 |             cmap='RdBu_r',
324 |             vmin=-vmax,
325 |             vmax=vmax,
326 |             show=False,
327 |             axes=axes[idx, 1])
328 | 
329 |         plt.colorbar(im, ax=axes[idx, 1])
330 | 
331 |     plt.tight_layout()
332 | 
333 |     if path is not None:
334 |         plt.savefig(path / "sensors_feature_importance.pdf")
335 | 
336 | 
337 | def plot_score(scores, labels, ref=None, path=None, title=None):
338 |     """
339 |     input:
340 |     -- scores : list of scores arrays of shape (S,)
341 |     -- labels: list of labels of same len as scores
342 |     """
343 | 
344 |     n_models = len(scores)
345 | 
346 |     bps = []  # list of boxplots
347 | 
348 |     for idx in range(n_models):
349 | 
350 |         # current elements
351 |         score = scores[idx]
352 |         if ref is not None:
353 |             score = ref - score
354 | 
355 |         bp = plt.boxplot(score, positions=[idx])
356 |         bps.append(bp)
357 | 
358 |     plt.xlim(-0.5, n_models + 1)
359 |     plt.legend([bp["boxes"][0] for bp in bps], [label for label in labels], loc='upper right')
360 |     if title is not None:
361 |         plt.title(title)
362 |     else:
363 |         plt.title("Temporal Correlation between predicted and truth")
364 |     plt.tight_layout()
365 | 
366 |     if path is not None:
367 |         plt.savefig(os.path.join(path, "scores.png"))
368 | 
369 | 
370 | plt = init()
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | torch==1.7.1
3 | torchvision==0.8.2
4 | matplotlib==3.3.2
5 | pandas==1.1.3
6 | tqdm==4.46.0
7 | scikit-learn
8 | mne==0.20
9 | wordfreq==2.3.2
10 | python-Levenshtein==0.12.0
11 | scipy==1.5.4
12 | seaborn==0.11.0
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [pep8]
2 | max-line-length = 100
3 | 
4 | [flake8]
5 | max-line-length = 100
6 | 
7 | [yapf]
8 | column_limit = 100
--------------------------------------------------------------------------------
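To close, a small illustration of `neural.utils_mous.match_list`, which aligns two token sequences of different lengths through Levenshtein edit operations and returns the indices at which they agree; insertions, deletions and (by default) substitutions are dropped. The Dutch words here are invented for the example:

```python
from neural.utils_mous import match_list

A = ['de', 'kat', 'slaapt', 'nu']
B = ['de', 'kat', 'die', 'slaapt']
A_idx, B_idx = match_list(A, B)
print(A_idx, B_idx)  # [0 1 2] [0 1 3] -- 'nu' and 'die' are left out
```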