├── .gitignore ├── LICENSE ├── Notice.txt ├── README.md ├── docs ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── INDIVIDUAL_CONTRIBUTOR_LICENSE.md ├── Keras-LICENSE.txt └── RETAIN-LICENSE.txt ├── process_mimic_modified.py ├── requirements.txt ├── retain_evaluation.py ├── retain_interpretations.py └── retain_train.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2021 Optum 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Notice.txt: -------------------------------------------------------------------------------- 1 | RETAIN Keras 2 | Copyright (c) Optum 2018 3 | All rights reserved. 4 | 5 | Portions Copyright (c) 2016, mp2893 6 | All rights reserved. See RETAIN-LICENSE.txt 7 | 8 | Portions Copyright (c) 2015 - 2018, François Chollet. 9 | All rights reserved. 
See Keras-LICENSE.txt 10 | 11 | 12 | Project Description: 13 | ==================== 14 | Retain-Keras is a Keras reimplementation of the RETAIN neural network introduced by Edward Choi (1). 15 | 16 | This implementation adds several features and provides additional scripts to evaluate and visualize predictions for each patient. 17 | 18 | Authors: 19 | Tim Rosenflanz (@tRosenflanz), Copyright Optum 2018 20 | Ryan Caldwell (@rcaldwe4), Copyright Optum 2018 21 | 22 | 23 | References: 24 | 1. Edward Choi, Mohammad Taha Bahadori, Joshua A. Kulas, Andy Schuetz, Walter F. Stewart, Jimeng Sun, 2016, RETAIN: An interpretable predictive model for healthcare using reverse time attention mechanism, In Proc. of Neural Information Processing Systems (NIPS) 2016, pp.3504-3512. GitHub repo: https://github.com/mp2893/retain 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RETAIN-Keras: Keras reimplementation of RETAIN 2 | 3 | [RETAIN is a neural network architecture originally introduced by Edward Choi](https://arxiv.org/abs/1608.05745) that enables the creation of highly interpretable Recurrent Neural Network models for patient diagnosis without any loss in model performance. This repository holds the [keras](https://www.tensorflow.org/api_docs/python/tf/keras) reimplementation of RETAIN (originally in Theano) that allows for flexible modifications to the original code, introduces multiple new features, and increases the speed of training. RETAIN has been shown to be highly effective for creating predictive models for a multitude of conditions, and we are excited to share this implementation with the broader healthcare data science community. 4 | 5 | ### Improvements and Extra Features 6 | 7 | - Simple Keras code with Tensorflow backend 8 | - Ability to use extra numeric inputs of fixed size that can hold numeric information about a patient's visit such as age, quantity of drug prescribed, or blood pressure 9 | - Improved embedding logic that avoids using large dense inputs 10 | - Ability to evaluate models during training 11 | - Ability to train models with only positive embedding contributions, which improves performance 12 | - Extra script to evaluate the model and output several helper graphics 13 | 14 | ### Installing RETAIN-Keras and Building the Environment 15 | 16 | To run the scripts in this repository, create a Python 3.7.9 virtual environment and install the dependencies in `requirements.txt`. We recommend using [Anaconda](https://www.anaconda.com/products/individual) to create your environment with the following commands: 17 | 18 | ``` 19 | git clone https://github.com/Optum/retain-keras.git 20 | conda create --name=retain python=3.7.9 21 | conda activate retain 22 | pip install -r requirements.txt 23 | ``` 24 | 25 | ### Running the code 26 | 27 | - **Training**: `python3 retain_train.py --num_codes=x --additional arguments` 28 | - **Evaluating**: `python3 retain_evaluation.py --additional arguments` 29 | - **Interpretation**: `python3 retain_interpretations.py` 30 |
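For example, a full training run might look like the following (these flag values are purely illustrative — set `--num_codes` to the number of medical codes in your own data, point the paths at your own pickles, and create the output directory first):

```
mkdir -p Model
python3 retain_train.py \
    --num_codes=942 \
    --emb_size=200 --recurrent_size=200 \
    --epochs=10 --batch_size=32 \
    --dropout_input=0.2 --dropout_context=0.2 \
    --path_data_train=data/data_train.pkl --path_target_train=data/target_train.pkl \
    --path_data_test=data/data_test.pkl --path_target_test=data/target_test.pkl \
    --directory=Model
```

The individual arguments are documented below.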
31 | ### Training Arguments 32 | 33 | The `retain_train.py` script will train the RETAIN model and evaluate/save it after each epoch. The script has multiple arguments to customize the training and model: 34 | 35 | - `--num_codes`: Integer Number of medical codes in the data set (required) 36 | - `--numeric_size`: Integer Size of numeric inputs, 0 if none. Default: 0 37 | - `--use_time`: Enables the extra time input for each visit. Default: off 38 | - `--emb_size`: Integer Size of the embedding layer. Default: 200 39 | - `--epochs`: Integer Number of epochs for training. Default: 1 40 | - `--n_steps`: Integer Maximum number of visits after which the data is truncated. This feature helps conserve GPU RAM (only the most recent n_steps will be used). Default: 300 41 | - `--recurrent_size`: Integer Size of the recurrent layers. Default: 200 42 | - `--path_data_train`: String Path to train data. Default: 'data/data_train.pkl' 43 | - `--path_data_test`: String Path to test/validation data. Default: 'data/data_test.pkl' 44 | - `--path_target_train`: String Path to train target. Default: 'data/target_train.pkl' 45 | - `--path_target_test`: String Path to test/validation target. Default: 'data/target_test.pkl' 46 | - `--batch_size`: Integer Batch size for training. Default: 32 47 | - `--dropout_input`: Float Dropout rate for embedding of codes and numeric features (0 to 1). Default: 0.0 48 | - `--dropout_context`: Float Dropout rate for context vector (0 to 1). Default: 0.0 49 | - `--l2`: Float L2 regularization value for layers. Default: 0.0 50 | - `--directory`: String Directory to save the model and the log file to. Default: 'Model' (the directory needs to exist, otherwise an error will be thrown) 51 | - `--allow_negative`: Allows negative weights for embeddings/attentions (the original RETAIN implementation allows them, but forcing non-negative weights has been shown to perform better on a range of tasks). Default: off 52 | 53 | ### Evaluation Arguments 54 | 55 | The `retain_evaluation.py` script will evaluate the specified RETAIN model and create some sample graphs. Arguments include: 56 | 57 | - `--path_model`: Path to the model to evaluate. Default: 'Model/weights.01.hdf5' 58 | - `--path_data`: Path to evaluation data. Default: 'data/data_test.pkl' 59 | - `--path_target`: Path to evaluation target. Default: 'data/target_test.pkl' 60 | - `--omit_graphs`: Does not output graphs if argument is present. Default: (Graphs are output) 61 | - `--n_steps`: Integer Maximum number of visits after which the data is truncated. This feature helps conserve GPU RAM (only the most recent n_steps will be used). Default: 300 62 | - `--batch_size`: Batch size for prediction (higher values are generally faster). Default: 32 63 | 64 | ### Interpretation Arguments 65 | 66 | The `retain_interpretations.py` script will compute probabilities for all patients and then allow the user to select patients by ID to see specific risk scores and interpret visits (displayed as pandas dataframes). It is highly recommended to extract this script to a notebook to enable more dynamic interaction. Arguments include: 67 | 68 | - `--path_model`: Path to the model to evaluate. Default: 'Model/weights.01.hdf5' 69 | - `--path_data`: Path to evaluation data. Default: 'data/data_test.pkl' 70 | - `--path_dictionary`: Path to the dictionary that maps each code index to its alphanumeric value. If numeric inputs are used, they should have indices num_codes+1 through num_codes+numeric_size; the num_codes index is reserved for padding. 71 | - `--batch_size`: Batch size for prediction (higher values are generally faster). Default: 32 72 | 73 | ### Data and Target Format 74 | 75 | By default the data has to be saved as a pickled pandas dataframe with the following format: 76 | 77 | - Each row is 1 patient. 78 | - Rows are sorted by the number of visits a person has. People with the fewest visits should be at the beginning of the dataframe and people with the most visits at the end. 79 | - Column 'codes' is a list of lists where each sublist contains the codes for an individual visit. Lists have to be ordered by their order of events (from old to new). 80 | - Column 'numerics' is a list of lists where each sublist contains numeric values for an individual visit. Lists have to be ordered by their order of events (from old to new). Lists have to have a static size of `numeric_size` indicating the number of different numeric features for each visit. Numeric information can include things like the patient's age, blood pressure, BMI, length of the visit, or cost charged (or all at the same time!). This column is not used if `numeric_size` is 0. 81 | - Column 'to_event' is a list of values indicating when the respective visit happened. Values have to be ordered from oldest to newest. This column is not used if `use_time` is not specified. 82 | 83 | By default the target has to be saved as a pickled pandas dataframe with the following format: 84 | 85 | - Each row is 1 patient corresponding to the same patient in the data file 86 | - Column 'target' is the patient's class (either 0 or 1) 87 |
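To make this layout concrete, here is a minimal, hypothetical sketch (two toy patients, made-up code indices, a single numeric feature) that builds and pickles a matching data/target pair; it assumes the `data/` directory already exists:

```
import pandas as pd

# Two toy patients, already sorted by number of visits (fewest first).
# 'codes': one list per patient; each sublist holds the integer code indices of one visit (oldest to newest).
# 'numerics': one fixed-size list of numeric features per visit (numeric_size=1 here).
# 'to_event': one value per visit indicating when it happened, ordered oldest to newest.
data = pd.DataFrame(
    {
        "codes": [[[1, 4]], [[2], [3, 5, 1]]],
        "numerics": [[[63.0]], [[71.0], [72.0]]],
        "to_event": [[120], [400, 30]],
    }
)

# One 0/1 label per patient, in the same row order as the data dataframe.
target = pd.DataFrame({"target": [0, 1]})

data.to_pickle("data/data_train.pkl")
target.to_pickle("data/target_train.pkl")
```

A file built this way would be trained with `--num_codes` set to the number of distinct codes and `--numeric_size=1`, adding `--use_time` if the 'to_event' column should be used.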
88 | ### Sample Data Generation Using MIMIC-III 89 | 90 | You can quickly test this reimplementation by creating a sample dataset from [MIMIC-III](https://physionet.org/content/mimiciii/1.4/) data using the `process_mimic_modified.py` script. You will need to request access to [MIMIC-III](https://mimic.physionet.org/gettingstarted/access/), a de-identified database containing 11 years of patient clinical care data, to be able to run this script. If you do not wish to request access to the full data, you can freely download the [MIMIC-III](https://physionet.org/content/mimiciii-demo/1.4/) sample demo data and use it for exploratory benchmarks. The `process_mimic_modified.py` script heavily borrows from the original [process_mimic.py](https://github.com/mp2893/retain/blob/master/process_mimic.py) script created by Edward Choi but is modified to output data in the format specified above. It outputs the necessary files to a user-specified directory and splits them into train and test by a user-specified ratio. 91 | 92 | Example: 93 | 94 | Run from the MIMIC-III directory. This will split the data with 70% going to training and 30% to test: 95 | 96 | `python process_mimic_modified.py ADMISSIONS.csv DIAGNOSES_ICD.csv PATIENTS.csv data .7` 97 | 98 | ### License 99 | 100 | Please review the [license](LICENSE), [notice](Notice.txt) and other [documents](docs/) before using the code in this repository or making a contribution to it. 101 | 102 | ### Contributing 103 | 104 | To contribute features, bug fixes, tests, examples, or documentation, please submit a pull request with a description of your proposed changes or additions. 105 | 106 | Please include a brief description of your pull request when submitting code and ensure that your code follows the [PEP 8](https://www.python.org/dev/peps/pep-0008/) style guide. To do this, run `pip install black` and `black retain-keras` to reformat files within your copy of the code using the [black code formatter](https://github.com/psf/black). The black code formatter is a PEP 8 compliant, opinionated formatter that reformats entire files in place. You can also use the [autopep8 code formatter](https://packagecontrol.io/packages/AutoPEP8) within your IDE to ensure PEP 8 compliance. 107 | 108 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 109 | 110 | 111 | ### References 112 | 113 | 1.
Edward Choi, Mohammad Taha Bahadori, Joshua A. Kulas, Andy Schuetz, Walter F. Stewart, Jimeng Sun, 2016, RETAIN: An interpretable predictive model for healthcare using reverse time attention mechanism, In Proc. of Neural Information Processing Systems (NIPS) 2016, pp.3504-3512. https://github.com/mp2893/retain 114 | 115 | 2. Goldberger AL, Amaral LAN, Glass L, Hausdorff JM, Ivanov PCh, Mark RG, Mietus JE, Moody GB, Peng C-K, Stanley HE. PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals. Circulation 101(23):e215-e220 [Circulation Electronic Pages; http://circ.ahajournals.org/content/101/23/e215.full]; 2000 (June 13). 116 | -------------------------------------------------------------------------------- /docs/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, gender identity and expression, level of experience, 9 | nationality, personal appearance, race, religion, or sexual identity and 10 | orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project email 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 
54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at [opensource@optum.com][email]. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at [http://contributor-covenant.org/version/1/4][version] 72 | 73 | [homepage]: http://contributor-covenant.org 74 | [version]: http://contributor-covenant.org/version/1/4/ 75 | [email]: mailto:opensource@optum.com 76 | -------------------------------------------------------------------------------- /docs/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guidelines 2 | 3 | Please note that this project is released with a 4 | [Contributor Code of Conduct](CODE_OF_CONDUCT.md). By participating in this 5 | project you agree to abide by its terms. Please also review our [Contributor License Agreement ("CLA")](INDIVIDUAL_CONTRIBUTOR_LICENSE.md) prior to submitting changes to the project. You will need to attest to this agreement following the instructions in the [Paperwork for Pull Requests](#paperwork-for-pull-requests) section below. 6 | 7 | --- 8 | 9 | # How to Contribute 10 | 11 | Now that we have the disclaimer out of the way, let's get into how you can be a 12 | part of our project. There are many different ways to contribute. 13 | 14 | ## Issues 15 | 16 | We track our work using Issues in GitHub. Feel free to open up your own issue 17 | to point out areas for improvement or to suggest your own new experiment. If you 18 | are comfortable with signing the waiver linked above and contributing code or 19 | documentation, grab your own issue and start working. 20 | 21 | ## Coding Standards 22 | 23 | We have some general guidelines towards contributing to this project. 24 | 25 | ### Languages 26 | 27 | *Python* 28 | 29 | The source code for this project is written in Python. You are welcome to add versions of files for other languages, however the core code will remain in Python. 30 | 31 | ### Keras Backends 32 | 33 | *Tensorflow* 34 | 35 | By default we assume that this reimplementation will be run using Tensorflow backend. As Keras grows its support for other backends, we will welcome changes that will make these scripts backend independent. 36 | 37 | ## Pull Requests 38 | 39 | If you've gotten as far as reading this section, then thank you for your 40 | suggestions. 41 | 42 | ### Paperwork for Pull Requests 43 | 44 | * Please read this guide and make sure you agree with our [Contributor License Agreement ("CLA")](INDIVIDUAL_CONTRIBUTOR_LICENSE.md). 45 | * Make sure git knows your name and email address: 46 | ``` 47 | $ git config user.name "J. Random User" 48 | $ git config user.email "j.random.user@example.com" 49 | ``` 50 | >The name and email address must be valid as we cannot accept anonymous contributions. 51 | * Write good commit messages. 
52 | > Concise commit messages that describe your changes help us better understand your contributions. 53 | * The first time you open a pull request in this repository, you will see a comment on your PR with a link that will allow you to sign our Contributor License Agreement (CLA) if necessary. 54 | > The link will take you to a page that allows you to view our CLA. You will need to click the `Sign in with GitHub to agree button` and authorize the cla-assistant application to access the email addresses associated with your GitHub account. Agreeing to the CLA is also considered to be an attestation that you either wrote or have the rights to contribute the code. All committers to the PR branch will be required to sign the CLA, but you will only need to sign once. This CLA applies to all repositories in the Optum org. 55 | 56 | ### General Guidelines 57 | 58 | Ensure your pull request (PR) adheres to the following guidelines: 59 | 60 | * Try to make the name concise and descriptive. 61 | * Give a good description of the change being made. Since this is very 62 | subjective, see the [Updating Your Pull Request (PR)](#updating-your-pull-request-pr) 63 | section below for further details. 64 | * Every pull request should be associated with one or more issues. If no issue 65 | exists yet, please create your own. 66 | * Make sure that all applicable issues are mentioned somewhere in the PR description. This 67 | can be done by typing # to bring up a list of issues. 68 | 69 | ### Updating Your Pull Request (PR) 70 | 71 | A lot of times, making a PR adhere to the standards above can be difficult. 72 | If the maintainers notice anything that we'd like changed, we'll ask you to 73 | edit your PR before we merge it. This applies to both the content documented 74 | in the PR and the changed contained within the branch being merged. There's no 75 | need to open a new PR. Just edit the existing one. 76 | 77 | [email]: mailto:opensource@optum.com 78 | -------------------------------------------------------------------------------- /docs/INDIVIDUAL_CONTRIBUTOR_LICENSE.md: -------------------------------------------------------------------------------- 1 | # Individual Contributor License Agreement ("Agreement") V2.0 2 | 3 | Thank you for your interest in this Optum project (the "PROJECT"). In order to clarify the intellectual property license granted with Contributions from any person or entity, the PROJECT must have a Contributor License Agreement ("CLA") on file that has been signed by each Contributor, indicating agreement to the license terms below. This license is for your protection as a Contributor as well as the protection of the PROJECT and its users; it does not change your rights to use your own Contributions for any other purpose. 4 | 5 | You accept and agree to the following terms and conditions for Your present and future Contributions submitted to the PROJECT. In return, the PROJECT shall not use Your Contributions in a way that is inconsistent with stated project goals in effect at the time of the Contribution. Except for the license granted herein to the PROJECT and recipients of software distributed by the PROJECT, You reserve all right, title, and interest in and to Your Contributions. 6 | 1. Definitions. 7 | 8 | "You" (or "Your") shall mean the copyright owner or legal entity authorized by the copyright owner that is making this Agreement with the PROJECT. 
For legal entities, the entity making a Contribution and all other entities that control, are controlled by, or are under common control with that entity are considered to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 9 | 10 | "Contribution" shall mean any original work of authorship, including any modifications or additions to an existing work, that is intentionally submitted by You to the PROJECT for inclusion in, or documentation of, any of the products owned or managed by the PROJECT (the "Work"). For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the PROJECT or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the PROJECT for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by You as "Not a Contribution." 11 | 12 | 2. Grant of Copyright License. 13 | 14 | Subject to the terms and conditions of this Agreement, You hereby grant to the PROJECT and to recipients of software distributed by the PROJECT a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense, and distribute Your Contributions and such derivative works. 15 | 16 | 3. Grant of Patent License. 17 | 18 | Subject to the terms and conditions of this Agreement, You hereby grant to the PROJECT and to recipients of software distributed by the PROJECT a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by You that are necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Work to which such Contribution(s) was submitted. If any entity institutes patent litigation against You or any other entity (including a cross-claim or counterclaim in a lawsuit) alleging that your Contribution, or the Work to which you have contributed, constitutes direct or contributory patent infringement, then any patent licenses granted to that entity under this Agreement for that Contribution or Work shall terminate as of the date such litigation is filed. 19 | 20 | 4. Representations. 21 | 22 | (a) You represent that you are legally entitled to grant the above license. If your employer(s) has rights to intellectual property that you create that includes your Contributions, you represent that you have received permission to make Contributions on behalf of that employer, that your employer has waived such rights for your Contributions to the PROJECT, or that your employer has executed a separate Corporate CLA with the PROJECT. 23 | 24 | (b) You represent that each of Your Contributions is Your original creation (see section 6 for submissions on behalf of others). You represent that Your Contribution submissions include complete details of any third-party license or other restriction (including, but not limited to, related patents and trademarks) of which you are personally aware and which are associated with any part of Your Contributions. 25 | 26 | 5. You are not expected to provide support for Your Contributions, except to the extent You desire to provide support. You may provide support for free, for a fee, or not at all. Unless required by applicable law or agreed to in writing, You provide Your Contributions on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. 27 | 28 | 6. Should You wish to submit work that is not Your original creation, You may submit it to the PROJECT separately from any Contribution, identifying the complete details of its source and of any license or other restriction (including, but not limited to, related patents, trademarks, and license agreements) of which you are personally aware, and conspicuously marking the work as "Submitted on behalf of a third-party: [named here]". 29 | 30 | 7. You agree to notify the PROJECT of any facts or circumstances of which you become aware that would make these representations inaccurate in any respect. 31 | -------------------------------------------------------------------------------- /docs/Keras-LICENSE.txt: -------------------------------------------------------------------------------- 1 | COPYRIGHT 2 | 3 | All contributions by François Chollet: 4 | Copyright (c) 2015 - 2018, François Chollet. 5 | All rights reserved. 6 | 7 | All contributions by Google: 8 | Copyright (c) 2015 - 2018, Google, Inc. 9 | All rights reserved. 10 | 11 | All contributions by Microsoft: 12 | Copyright (c) 2017 - 2018, Microsoft, Inc. 13 | All rights reserved. 14 | 15 | All other contributions: 16 | Copyright (c) 2015 - 2018, the respective contributors. 17 | All rights reserved. 
18 | 19 | Each contributor holds copyright over their respective contributions. 20 | The project versioning (Git) records all such contribution source information. 21 | 22 | LICENSE 23 | 24 | The MIT License (MIT) 25 | 26 | Permission is hereby granted, free of charge, to any person obtaining a copy 27 | of this software and associated documentation files (the "Software"), to deal 28 | in the Software without restriction, including without limitation the rights 29 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 30 | copies of the Software, and to permit persons to whom the Software is 31 | furnished to do so, subject to the following conditions: 32 | 33 | The above copyright notice and this permission notice shall be included in all 34 | copies or substantial portions of the Software. 35 | 36 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 37 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 38 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 39 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 40 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 41 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 42 | SOFTWARE. 43 | 44 | -------------------------------------------------------------------------------- /docs/RETAIN-LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016, mp2893 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | * Neither the name of RETAIN nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /process_mimic_modified.py: -------------------------------------------------------------------------------- 1 | # This script processes MIMIC-III dataset and builds longitudinal records for patients with 2+ encounter. 
2 | # The outputs are pickled pandas dataframes and code dictionaries suitable for training RETAIN-Keras 3 | # Originally Written by Edward Choi (mp2893@gatech.edu) https://github.com/mp2893/retain 4 | # Modified by Timothy Rosenflanz (timothy.rosenflanz@optum.com) to work with RETAIN-Keras 5 | # Usage: Put this script in the folder where the MIMIC-III CSV files are located, then execute the command below. 6 | # python process_mimic_modified.py ADMISSIONS.csv DIAGNOSES_ICD.csv PATIENTS.csv <output directory> <train proportion, e.g. .7> 7 | 8 | # Output files 9 | # data_train.pkl: Pickled dataframe used for training containing the codes and to_event sequences as specified in the README 10 | # data_test.pkl: Pickled dataframe used for testing containing the codes and to_event sequences as specified in the README 11 | # data_train_3digit.pkl: Pickled dataframe used for training containing the 3 digit codes and to_event sequences as specified in the README 12 | # data_test_3digit.pkl: Pickled dataframe used for testing containing the 3 digit codes and to_event sequences as specified in the README 13 | # target_train.pkl: Pickled dataframe containing target labels for training as specified in the README 14 | # target_test.pkl: Pickled dataframe containing target labels for testing as specified in the README 15 | # dictionary.pkl: Python dictionary that maps integer code indices to their string diagnosis codes. 16 | # dictionary_3digit.pkl: Python dictionary that maps integer code indices to their string 3 digit diagnosis codes. 17 | 18 | # Imports 19 | import os 20 | import sys 21 | import pickle 22 | import regex as re 23 | import numpy as np 24 | import pandas as pd 25 | from datetime import datetime 26 | from sklearn.model_selection import train_test_split 27 | 28 | 29 | def convert_to_icd9(dx_str): 30 | """ 31 | Maps an ICD diagnosis code to ICD9 32 | """ 33 | 34 | if dx_str.startswith("E"): 35 | if len(dx_str) > 4: 36 | return dx_str[:4] + "." + dx_str[4:] 37 | else: 38 | return dx_str 39 | else: 40 | if len(dx_str) > 3: 41 | return dx_str[:3] + "."
+ dx_str[3:] 42 | else: 43 | return dx_str 44 | 45 | 46 | def convert_to_3digit_icd9(dx_str): 47 | """ 48 | Roll up a diagnosis code to 3 digits 49 | """ 50 | 51 | if dx_str.startswith("E"): 52 | if len(dx_str) > 4: 53 | return dx_str[:4] 54 | else: 55 | return dx_str 56 | else: 57 | if len(dx_str) > 3: 58 | return dx_str[:3] 59 | else: 60 | return dx_str 61 | 62 | 63 | if __name__ == "__main__": 64 | 65 | admission_file = sys.argv[1] 66 | diagnosis_file = sys.argv[2] 67 | patients_file = sys.argv[3] 68 | out_directory = sys.argv[4] 69 | train_proportion = float(sys.argv[5]) 70 | 71 | # Read mortality data 72 | print("Collecting mortality information...") 73 | pid_dod_map = {} 74 | infd = open(patients_file, "r") 75 | infd.readline() 76 | for line in infd: 77 | tokens = line.strip().split(",") 78 | pid = int(tokens[1]) 79 | dod_hosp = tokens[5] 80 | if len(dod_hosp) > 0: 81 | pid_dod_map[pid] = 1 82 | else: 83 | pid_dod_map[pid] = 0 84 | infd.close() 85 | 86 | # Read and create admission records 87 | print("Building pid-admission mapping, admission-date mapping...") 88 | pid_adm_map = {} 89 | adm_date_map = {} 90 | infd = open(admission_file, "r") 91 | infd.readline() 92 | for line in infd: 93 | tokens = line.strip().split(",") 94 | pid = int(tokens[1]) 95 | adm_id = int(tokens[2]) 96 | adm_time = datetime.strptime(tokens[3], "%Y-%m-%d %H:%M:%S") 97 | adm_date_map[adm_id] = adm_time 98 | if pid in pid_adm_map: 99 | pid_adm_map[pid].append(adm_id) 100 | else: 101 | pid_adm_map[pid] = [adm_id] 102 | infd.close() 103 | 104 | # Create admission dx code mapping 105 | print("Building admission-dxList mapping...") 106 | adm_dx_map = {} 107 | adm_dx_map_3digit = {} 108 | infd = open(diagnosis_file, "r") 109 | infd.readline() 110 | for line in infd: 111 | tokens = re.sub('"|\s|\n','',line).split(',') 112 | adm_id = int(tokens[2]) 113 | dx_str = "D_" + convert_to_icd9(tokens[4][1:-1]) 114 | dx_str_3digit = "D_" + convert_to_3digit_icd9(tokens[4][1:-1]) 115 | if adm_id in adm_dx_map: 116 | adm_dx_map[adm_id].append(dx_str) 117 | else: 118 | adm_dx_map[adm_id] = [dx_str] 119 | if adm_id in adm_dx_map_3digit: 120 | adm_dx_map_3digit[adm_id].append(dx_str_3digit) 121 | else: 122 | adm_dx_map_3digit[adm_id] = [dx_str_3digit] 123 | infd.close() 124 | 125 | # Create ordered visit mapping 126 | print("Building pid-sortedVisits mapping...") 127 | pid_seq_map = {} 128 | pid_seq_map_3digit = {} 129 | for pid, adm_id_list in pid_adm_map.items(): 130 | if len(adm_id_list) < 2: 131 | continue 132 | sorted_list = sorted( 133 | [(adm_date_map[adm_id], adm_dx_map[adm_id]) for adm_id in adm_id_list] 134 | ) 135 | pid_seq_map[pid] = sorted_list 136 | sorted_list_3digit = sorted( 137 | [ 138 | (adm_date_map[adm_id], adm_dx_map_3digit[adm_id]) 139 | for adm_id in adm_id_list 140 | ] 141 | ) 142 | pid_seq_map_3digit[pid] = sorted_list_3digit 143 | 144 | # Create sequences of IDs, dates, labels, and code sequences 145 | print("Building pids, dates, mortality_labels, strSeqs...") 146 | pids = [] 147 | dates = [] 148 | seqs = [] 149 | morts = [] 150 | for pid, visits in pid_seq_map.items(): 151 | pids.append(pid) 152 | morts.append(pid_dod_map[pid]) 153 | seq = [] 154 | date = [] 155 | for visit in visits: 156 | date.append(visit[0]) 157 | seq.append(visit[1]) 158 | dates.append(date) 159 | seqs.append(seq) 160 | 161 | # Create 3 digit ICD sequences 162 | print("Building pids, dates, strSeqs for 3digit ICD9 code...") 163 | seqs_3digit = [] 164 | for pid, visits in pid_seq_map_3digit.items(): 165 | seq = [] 166 | for visit 
in visits: 167 | seq.append(visit[1]) 168 | seqs_3digit.append(seq) 169 | 170 | # Collect code types 171 | print("Converting strSeqs to intSeqs, and making types...") 172 | types = {} 173 | new_seqs = [] 174 | for patient in seqs: 175 | new_patient = [] 176 | for visit in patient: 177 | new_visit = [] 178 | for code in visit: 179 | if code in types: 180 | new_visit.append(types[code]) 181 | else: 182 | types[code] = len(types) 183 | new_visit.append(types[code]) 184 | new_patient.append(new_visit) 185 | new_seqs.append(new_patient) 186 | 187 | # Map code strings to integers 188 | print("Converting strSeqs to intSeqs, and making types for 3digit ICD9 code...") 189 | types_3digit = {} 190 | new_seqs_3digit = [] 191 | for patient in seqs_3digit: 192 | new_patient = [] 193 | for visit in patient: 194 | new_visit = [] 195 | for code in set(visit): 196 | if code in types_3digit: 197 | new_visit.append(types_3digit[code]) 198 | else: 199 | types_3digit[code] = len(types_3digit) 200 | new_visit.append(types_3digit[code]) 201 | new_patient.append(new_visit) 202 | new_seqs_3digit.append(new_patient) 203 | 204 | # Compute time to today as to_event column 205 | print("Making additional modifications to the data...") 206 | today = datetime.strptime("2025-01-01", "%Y-%m-%d") 207 | to_event = [[(today - date).days for date in patient] for patient in dates] 208 | 209 | # Compute time of the day when the person was admitted as the numeric column of size 1 210 | numerics = [ 211 | [[date.hour * 60 + date.minute - 720] for date in patient] for patient in dates 212 | ] 213 | 214 | # Add this feature to dictionary but leave 1 index empty for PADDING 215 | types["Time of visit"] = len(types) + 1 216 | types_3digit["Time of visit"] = len(types_3digit) + 1 217 | 218 | # Compute sorting indicies 219 | sort_indicies = np.argsort(list(map(len, to_event))) 220 | 221 | # Create the dataframes of data and sort them according to number of visits per patient 222 | print("Building sorted dataframes...") 223 | all_data = ( 224 | pd.DataFrame( 225 | data={"codes": new_seqs, "to_event": to_event, "numerics": numerics}, 226 | columns=["codes", "to_event", "numerics"], 227 | ) 228 | .iloc[sort_indicies] 229 | .reset_index() 230 | ) 231 | all_data_3digit = ( 232 | pd.DataFrame( 233 | data={"codes": new_seqs_3digit, "to_event": to_event, "numerics": numerics}, 234 | columns=["codes", "to_event", "numerics"], 235 | ) 236 | .iloc[sort_indicies] 237 | .reset_index() 238 | ) 239 | all_targets = ( 240 | pd.DataFrame(data={"target": morts}, columns=["target"]) 241 | .iloc[sort_indicies] 242 | .reset_index() 243 | ) 244 | 245 | # Create train test split 246 | print("Creating train/test splits...") 247 | data_train, data_test = train_test_split( 248 | all_data, train_size=train_proportion, random_state=12345 249 | ) 250 | data_train_3digit, data_test_3digit = train_test_split( 251 | all_data_3digit, train_size=train_proportion, random_state=12345 252 | ) 253 | target_train, target_test = train_test_split( 254 | all_targets, train_size=train_proportion, random_state=12345 255 | ) 256 | 257 | # Create reverse dictionary in index:code format 258 | types = dict((v, k) for k, v in types.items()) 259 | types_3digit = dict((v, k) for k, v in types_3digit.items()) 260 | 261 | # Write out the data 262 | print("Saving data...") 263 | if not os.path.exists(out_directory): 264 | os.makedirs(out_directory) 265 | data_train.sort_index().to_pickle(out_directory + "/data_train.pkl") 266 | data_test.sort_index().to_pickle(out_directory + 
"/data_test.pkl") 267 | data_train_3digit.sort_index().to_pickle(out_directory + "/data_train_3digit.pkl") 268 | data_test_3digit.sort_index().to_pickle(out_directory + "/data_test_3digit.pkl") 269 | target_train.sort_index().to_pickle(out_directory + "/target_train.pkl") 270 | target_test.sort_index().to_pickle(out_directory + "/target_test.pkl") 271 | pickle.dump(types, open(out_directory + "/dictionary.pkl", "wb"), -1) 272 | pickle.dump(types_3digit, open(out_directory + "/dictionary_3digit.pkl", "wb"), -1) 273 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn==1.0.2 2 | tensorflow==2.8.0 3 | pandas==1.3.5 4 | regex==2022.1.18 5 | -------------------------------------------------------------------------------- /retain_evaluation.py: -------------------------------------------------------------------------------- 1 | """RETAIN Model Evaluation""" 2 | import argparse 3 | import numpy as np 4 | import pandas as pd 5 | from sklearn.metrics import ( 6 | roc_auc_score, 7 | average_precision_score, 8 | precision_recall_curve, 9 | roc_curve, 10 | ) 11 | from sklearn.calibration import calibration_curve 12 | import matplotlib.pyplot as plt 13 | import tensorflow as tf 14 | import tensorflow.keras.backend as K 15 | from tensorflow.keras.models import load_model 16 | from tensorflow.keras.preprocessing import sequence 17 | from tensorflow.keras.constraints import Constraint 18 | from tensorflow.keras.utils import Sequence 19 | 20 | 21 | def import_model(path): 22 | """Import model from training phase 23 | 24 | :param str path: path to HDF5 file 25 | :return: Keras model 26 | :rtype: :class:`tensorflow.keras.Model` 27 | """ 28 | 29 | K.clear_session() 30 | config = tf.compat.v1.ConfigProto( 31 | allow_soft_placement=True, log_device_placement=False 32 | ) 33 | config.gpu_options.allow_growth = True 34 | tfsess = tf.compat.v1.Session(config=config) 35 | tf.compat.v1.keras.backend.set_session(tfsess) 36 | model = load_model( 37 | path, 38 | custom_objects={ 39 | "FreezePadding": FreezePadding, 40 | "FreezePadding_Non_Negative": FreezePadding_Non_Negative, 41 | }, 42 | ) 43 | 44 | return model 45 | 46 | 47 | def get_model_parameters(model): 48 | """Get model parameters of interest 49 | 50 | :param model: Keras model 51 | :type model: :class:`tensorflow.keras.Model` 52 | :return: parameters of model 53 | :rtype: :class:`ModelParameters` 54 | """ 55 | 56 | class ModelParameters: 57 | """Helper class to store model parametesrs in the same format as ARGS""" 58 | 59 | def __init__(self): 60 | self.num_codes = None 61 | self.numeric_size = None 62 | self.use_time = None 63 | 64 | params = ModelParameters() 65 | names = [layer.name for layer in model.layers] 66 | params.num_codes = model.get_layer(name="embedding").input_dim - 1 67 | if "numeric_input" in names: 68 | params.numeric_size = model.get_layer(name="numeric_input").input_shape[2] 69 | else: 70 | params.numeric_size = 0 71 | if "time_input" in names: 72 | params.use_time = True 73 | else: 74 | params.use_time = False 75 | return params 76 | 77 | 78 | class FreezePadding_Non_Negative(Constraint): 79 | """Freezes the last weight to be near 0 and prevents non-negative embeddings 80 | 81 | :param Constraint: Keras sequence constraint 82 | :type Constraint: :class:`tensorflow.keras.constraints.Constraint` 83 | :return: padded tensorflow tensor 84 | :rtype: :class:`tensorflow.Tensor` 85 | """ 86 | 87 | def 
__call__(self, w): 88 | other_weights = K.cast(K.greater_equal(w, 0)[:-1], K.floatx()) 89 | last_weight = K.cast( 90 | K.equal(K.reshape(w[-1, :], (1, K.shape(w)[1])), 0.0), K.floatx() 91 | ) 92 | appended = K.concatenate([other_weights, last_weight], axis=0) 93 | w *= appended 94 | return w 95 | 96 | 97 | class FreezePadding(Constraint): 98 | """Freezes the last weight to be near 0. 99 | 100 | :param Constraint: Keras sequence constraint 101 | :type Constraint: :class:`tensorflow.keras.constraints.Constraint` 102 | :return: padded tensorflow tensor 103 | :rtype: :class:`tensorflow.Tensor` 104 | """ 105 | 106 | def __call__(self, w): 107 | other_weights = K.cast(K.ones(K.shape(w))[:-1], K.floatx()) 108 | last_weight = K.cast( 109 | K.equal(K.reshape(w[-1, :], (1, K.shape(w)[1])), 0.0), K.floatx() 110 | ) 111 | appended = K.concatenate([other_weights, last_weight], axis=0) 112 | w *= appended 113 | return w 114 | 115 | 116 | def precision_recall(y_true, y_prob, graph): 117 | """ 118 | Get precision recall statistics 119 | 120 | :param y_true: NumPy array of true target values 121 | :type y_true: :class:`numpy.array` 122 | :param y_prob: NumPy array of predicted target values 123 | :type y_prob: :class:`numpy.array` 124 | :param graph: Option to plot + save precision-recall curve 125 | :type graph: bool 126 | """ 127 | 128 | average_precision = average_precision_score(y_true, y_prob) 129 | if graph: 130 | precision, recall, _ = precision_recall_curve(y_true, y_prob) 131 | plt.style.use("ggplot") 132 | plt.clf() 133 | plt.plot( 134 | recall, 135 | precision, 136 | label="Precision-Recall Curve (Area = %0.3f)" % average_precision, 137 | ) 138 | plt.xlabel("Recall: P(predicted+|true+)") 139 | plt.ylabel("Precision: P(true+|predicted+)") 140 | plt.ylim([0.0, 1.05]) 141 | plt.xlim([0.0, 1.0]) 142 | plt.legend(loc="lower left") 143 | print("Precision-Recall Curve saved to pr.png") 144 | plt.savefig("pr.png") 145 | else: 146 | print("Average Precision %0.3f" % average_precision) 147 | 148 | 149 | def probability_calibration(y_true, y_prob, graph): 150 | """ 151 | Get probability calibration 152 | 153 | :param y_true: NumPy array of true target values 154 | :type y_true: :class:`numpy.array` 155 | :param y_prob: NumPy array of predicted target values 156 | :type y_prob: :class:`numpy.array` 157 | :param graph: Option to plot + save probability calibration curves 158 | :type graph: bool 159 | """ 160 | 161 | if graph: 162 | fig_index = 1 163 | name = "My pred" 164 | n_bins = 20 165 | fig = plt.figure(fig_index, figsize=(10, 10)) 166 | ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2) 167 | ax2 = plt.subplot2grid((3, 1), (2, 0)) 168 | 169 | ax1.plot([0, 1], [0, 1], "k:", label="Perfectly calibrated") 170 | 171 | fraction_of_positives, mean_predicted_value = calibration_curve( 172 | y_true, y_prob, n_bins=n_bins, normalize=True 173 | ) 174 | 175 | ax1.plot(mean_predicted_value, fraction_of_positives, label=name) 176 | 177 | ax2.hist(y_prob, range=(0, 1), bins=n_bins, label=name, histtype="step", lw=2) 178 | 179 | ax1.set_ylabel("Fraction of Positives") 180 | ax1.set_ylim([-0.05, 1.05]) 181 | ax1.legend(loc="lower right") 182 | ax1.set_title("Calibration Plots (Reliability Curve)") 183 | 184 | ax2.set_xlabel("Mean predicted value") 185 | ax2.set_ylabel("Count") 186 | ax2.legend(loc="upper center", ncol=2) 187 | print("Probability Calibration Curves saved to calibration.png") 188 | plt.tight_layout() 189 | plt.savefig("calibration.png") 190 | 191 | 192 | def lift(y_true, y_prob, graph): 193 | """ 194 | 
Get lift chart 195 | 196 | :param y_true: NumPy array of true target values 197 | :type y_true: :class:`numpy.array` 198 | :param y_prob: NumPy array of predicted target values 199 | :type y_prob: :class:`numpy.array` 200 | :param graph: Option to plot + save lift chart 201 | :type graph: bool 202 | """ 203 | 204 | prevalence = sum(y_true) / len(y_true) 205 | average_lift = average_precision_score(y_true, y_prob) / prevalence 206 | if graph: 207 | precision, recall, _ = precision_recall_curve(y_true, y_prob) 208 | lift_values = precision / prevalence 209 | plt.style.use("ggplot") 210 | plt.clf() 211 | plt.plot( 212 | recall, 213 | lift_values, 214 | label="Lift-Recall Curve (Area = %0.3f)" % average_lift, 215 | ) 216 | plt.xlabel("Recall: P(predicted+|true+)") 217 | plt.ylabel("Lift") 218 | plt.xlim([0.0, 1.0]) 219 | plt.legend(loc="lower left") 220 | print("Lift-Recall Curve saved to lift.png") 221 | plt.savefig("lift") 222 | else: 223 | print("Average Lift %0.3f" % average_lift) 224 | 225 | 226 | def roc(y_true, y_prob, graph): 227 | """ 228 | Get ROC statistics 229 | 230 | :param y_true: NumPy array of true target values 231 | :type y_true: :class:`numpy.array` 232 | :param y_prob: NumPy array of predicted target values 233 | :type y_prob: :class:`numpy.array` 234 | :param graph: Option to plot + save ROC curves 235 | :type graph: bool 236 | """ 237 | 238 | roc_auc = roc_auc_score(y_true, y_prob) 239 | if graph: 240 | fpr, tpr, _ = roc_curve(y_true, y_prob) 241 | plt.plot( 242 | fpr, 243 | tpr, 244 | color="darkorange", 245 | lw=2, 246 | label="ROC curve (Area = %0.3f)" % roc_auc, 247 | ) 248 | plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--") 249 | plt.xlim([0.0, 1.0]) 250 | plt.ylim([0.0, 1.05]) 251 | plt.xlabel("False Positive Rate (1 - Specifity)") 252 | plt.ylabel("True Positive Rate (Sensitivity)") 253 | plt.title("Receiver Operating Characteristic") 254 | plt.legend(loc="lower right") 255 | print("ROC Curve saved to roc.png") 256 | plt.savefig("roc.png") 257 | else: 258 | print("ROC-AUC %0.3f" % roc_auc) 259 | 260 | 261 | class SequenceBuilder(Sequence): 262 | """Class to properly construct data to sequences 263 | 264 | :param Sequence: Customized Sequence class for generating batches of data 265 | :type Sequence: :class:`tensorflow.keras.utils.data_utils.Sequence` 266 | """ 267 | 268 | def __init__(self, data, model_parameters, ARGS): 269 | # Receive all appropriate data 270 | self.codes = data[0] 271 | index = 1 272 | if model_parameters.numeric_size: 273 | self.numeric = data[index] 274 | index += 1 275 | 276 | if model_parameters.use_time: 277 | self.time = data[index] 278 | 279 | self.num_codes = model_parameters.num_codes 280 | self.batch_size = ARGS.batch_size 281 | self.numeric_size = model_parameters.numeric_size 282 | self.use_time = model_parameters.use_time 283 | self.n_steps = ARGS.n_steps 284 | 285 | def __len__(self): 286 | """Compute number of batches. 
287 | Add extra batch if the data doesn't exactly divide into batches 288 | """ 289 | if len(self.codes) % self.batch_size == 0: 290 | return len(self.codes) // self.batch_size 291 | return len(self.codes) // self.batch_size + 1 292 | 293 | def __getitem__(self, idx): 294 | """Get batch of specific index""" 295 | 296 | def pad_data(data, length_visits, length_codes, pad_value=0): 297 | """Pad data to desired number of visits and codes inside each visit""" 298 | zeros = np.full((len(data), length_visits, length_codes), pad_value) 299 | for steps, mat in zip(data, zeros): 300 | if steps != [[-1]]: 301 | for step, mhot in zip(steps, mat[-len(steps) :]): 302 | # Populate the data into the appropriate visit 303 | mhot[: len(step)] = step 304 | 305 | return zeros 306 | 307 | # Compute reusable batch slice 308 | batch_slice = slice(idx * self.batch_size, (idx + 1) * self.batch_size) 309 | x_codes = self.codes[batch_slice] 310 | # Max number of visits and codes inside the visit for this batch 311 | pad_length_visits = min(max(map(len, x_codes)), self.n_steps) 312 | pad_length_codes = max(map(lambda x: max(map(len, x)), x_codes)) 313 | # Number of elements in a batch (useful in case of partial batches) 314 | length_batch = len(x_codes) 315 | # Pad data 316 | x_codes = pad_data(x_codes, pad_length_visits, pad_length_codes, self.num_codes) 317 | outputs = [x_codes] 318 | # Add numeric data if necessary 319 | if self.numeric_size: 320 | x_numeric = self.numeric[batch_slice] 321 | x_numeric = pad_data(x_numeric, pad_length_visits, self.numeric_size, -99.0) 322 | outputs.append(x_numeric) 323 | # Add time data if necessary 324 | if self.use_time: 325 | x_time = sequence.pad_sequences( 326 | self.time[batch_slice], 327 | dtype=np.float32, 328 | maxlen=pad_length_visits, 329 | value=+99, 330 | ).reshape(length_batch, pad_length_visits, 1) 331 | outputs.append(x_time) 332 | 333 | return outputs 334 | 335 | 336 | def read_data(model_parameters, ARGS): 337 | """Read test data used for scoring 338 | 339 | :param model_parameters: parameters of model 340 | :type model_parameters: str 341 | :param ARGS: Arguments object containing user-specified parameters 342 | :type ARGS: :class:`argparse.Namespace` 343 | :return: tuple for data and classifier arrays 344 | :rtype: tuple( list[class:`numpy.ndarray`] , :class:`numpy.ndarray`) 345 | """ 346 | 347 | data = pd.read_pickle(ARGS.path_data) 348 | y = pd.read_pickle(ARGS.path_target)["target"].values 349 | data_output = [data["codes"].values] 350 | 351 | if model_parameters.numeric_size: 352 | data_output.append(data["numerics"].values) 353 | if model_parameters.use_time: 354 | data_output.append(data["to_event"].values) 355 | return (data_output, y) 356 | 357 | 358 | def get_predictions(model, data, model_parameters, ARGS): 359 | """Get Model Predictions 360 | 361 | :param model: trained Keras model 362 | :type model: :class:`tensorflow.keras.Model` 363 | :param data: array(s) for features (e.g. 
['to_event_ordered','code_ordered','numeric_ordered']) 364 | :type data: list[class:`numpy.ndarray`] 365 | :param str model_parameters: parameters of model 366 | :param ARGS: Arguments object containing user-specified parameters 367 | :type ARGS: :class:`argparse.Namespace` 368 | :return: 1-d array of scores for being in positive class 369 | :rtype: :class:`numpy.ndarray` 370 | """ 371 | 372 | test_generator = SequenceBuilder(data, model_parameters, ARGS) 373 | preds = model.predict_generator( 374 | generator=test_generator, 375 | max_queue_size=15, 376 | use_multiprocessing=True, 377 | verbose=1, 378 | workers=3, 379 | ) 380 | return preds 381 | 382 | 383 | def main(ARGS): 384 | """Main Body of the code""" 385 | print("Loading Model and Extracting Parameters") 386 | model = import_model(ARGS.path_model) 387 | model_parameters = get_model_parameters(model) 388 | print("Reading Data") 389 | data, y = read_data(model_parameters, ARGS) 390 | print("Predicting the probabilities") 391 | probabilities = get_predictions(model, data, model_parameters, ARGS) 392 | print("Evaluating") 393 | roc(y, probabilities[:, 0, -1], ARGS.omit_graphs) 394 | precision_recall(y, probabilities[:, 0, -1], ARGS.omit_graphs) 395 | lift(y, probabilities[:, 0, -1], ARGS.omit_graphs) 396 | probability_calibration(y, probabilities[:, 0, -1], ARGS.omit_graphs) 397 | 398 | 399 | def parse_arguments(parser): 400 | """Read user arguments""" 401 | parser.add_argument( 402 | "--path_model", 403 | type=str, 404 | default="Model/weights.01.hdf5", 405 | help="Path to the model to evaluate", 406 | ) 407 | parser.add_argument( 408 | "--path_data", 409 | type=str, 410 | default="data/data_test.pkl", 411 | help="Path to evaluation data", 412 | ) 413 | parser.add_argument( 414 | "--path_target", 415 | type=str, 416 | default="data/target_test.pkl", 417 | help="Path to evaluation target", 418 | ) 419 | parser.add_argument( 420 | "--omit_graphs", 421 | action="store_false", 422 | help="Does not output graphs if argument is present", 423 | ) 424 | parser.add_argument( 425 | "--n_steps", 426 | type=int, 427 | default=300, 428 | help="Maximum number of visits after which the data is truncated", 429 | ) 430 | parser.add_argument( 431 | "--batch_size", 432 | type=int, 433 | default=32, 434 | help="Batch size for prediction (higher values are generally faster)", 435 | ) 436 | args = parser.parse_args() 437 | 438 | return args 439 | 440 | 441 | if __name__ == "__main__": 442 | 443 | PARSER = argparse.ArgumentParser( 444 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 445 | ) 446 | ARGS = parse_arguments(PARSER) 447 | main(ARGS) 448 | -------------------------------------------------------------------------------- /retain_interpretations.py: -------------------------------------------------------------------------------- 1 | """This function will load the given data and continuosly interpet selected patients""" 2 | import argparse 3 | import pickle as pickle 4 | import numpy as np 5 | import pandas as pd 6 | import tensorflow as tf 7 | import tensorflow.keras.backend as K 8 | from tensorflow.keras.models import load_model, Model 9 | from tensorflow.keras.preprocessing import sequence 10 | from tensorflow.keras.constraints import Constraint 11 | from tensorflow.keras.utils import Sequence 12 | 13 | 14 | def import_model(path): 15 | """Import model from given path and assign it to appropriate devices""" 16 | K.clear_session() 17 | config = tf.compat.v1.ConfigProto( 18 | allow_soft_placement=True, log_device_placement=False 19 | ) 20 | 
config.gpu_options.allow_growth = True 21 | tfsess = tf.compat.v1.Session(config=config) 22 | tf.compat.v1.keras.backend.set_session(tfsess) 23 | model = load_model( 24 | path, 25 | custom_objects={ 26 | "FreezePadding": FreezePadding, 27 | "FreezePadding_Non_Negative": FreezePadding_Non_Negative, 28 | }, 29 | ) 30 | model_with_attention = Model( 31 | model.inputs, 32 | model.outputs 33 | + [ 34 | model.get_layer(name="softmax_1").output, 35 | model.get_layer(name="beta_dense_0").output, 36 | ], 37 | ) 38 | return model, model_with_attention 39 | 40 | 41 | def get_model_parameters(model): 42 | """Extract model arguments that were used during training""" 43 | 44 | class ModelParameters: 45 | """Helper class to store model parametesrs in the same format as ARGS""" 46 | 47 | def __init__(self): 48 | self.num_codes = None 49 | self.numeric_size = None 50 | self.use_time = None 51 | self.emb_weights = None 52 | self.output_weights = None 53 | self.bias = None 54 | 55 | params = ModelParameters() 56 | names = [layer.name for layer in model.layers] 57 | params.num_codes = model.get_layer(name="embedding").input_dim - 1 58 | params.emb_weights = model.get_layer(name="embedding").get_weights()[0] 59 | params.output_weights, params.bias = model.get_layer( 60 | name="time_distributed_out" 61 | ).get_weights() 62 | print("Model bias: {}".format(params.bias)) 63 | if "numeric_input" in names: 64 | params.numeric_size = model.get_layer(name="numeric_input").input_shape[2] 65 | # Add artificial embeddings for each numeric feature and extend the embedding weights 66 | # Numeric embeddings is just 1 for 1 dimension of the embedding which corresponds to taking value as is 67 | numeric_embeddings = np.zeros( 68 | (params.numeric_size, params.emb_weights.shape[1] + params.numeric_size) 69 | ) 70 | for i in range(params.numeric_size): 71 | numeric_embeddings[i, params.emb_weights.shape[1] + i] = 1 72 | # Extended embedding is original embedding extended to larger output size and numerics embeddings added 73 | params.emb_weights = np.append( 74 | params.emb_weights, 75 | np.zeros((params.num_codes + 1, params.numeric_size)), 76 | axis=1, 77 | ) 78 | params.emb_weights = np.append(params.emb_weights, numeric_embeddings, axis=0) 79 | else: 80 | params.numeric_size = 0 81 | if "time_input" in names: 82 | params.use_time = True 83 | else: 84 | params.use_time = False 85 | return params 86 | 87 | 88 | class FreezePadding_Non_Negative(Constraint): 89 | """Freezes the last weight to be near 0 and prevents non-negative embeddings 90 | 91 | :param Constraint: Keras sequence constraint 92 | :type Constraint: :class:`tensorflow.keras.constraints.Constraint` 93 | :return: padded tensorflow tensor 94 | :rtype: :class:`tensorflow.Tensor` 95 | """ 96 | 97 | def __call__(self, w): 98 | other_weights = K.cast(K.greater_equal(w, 0)[:-1], K.floatx()) 99 | last_weight = K.cast( 100 | K.equal(K.reshape(w[-1, :], (1, K.shape(w)[1])), 0.0), K.floatx() 101 | ) 102 | appended = K.concatenate([other_weights, last_weight], axis=0) 103 | w *= appended 104 | return w 105 | 106 | 107 | class FreezePadding(Constraint): 108 | """Freezes the last weight to be near 0. 
109 | 110 | :param Constraint: Keras sequence constraint 111 | :type Constraint: :class:`tensorflow.keras.constraints.Constraint` 112 | :return: padded tensorflow tensor 113 | :rtype: :class:`tensorflow.Tensor` 114 | """ 115 | 116 | def __call__(self, w): 117 | other_weights = K.cast(K.ones(K.shape(w))[:-1], K.floatx()) 118 | last_weight = K.cast( 119 | K.equal(K.reshape(w[-1, :], (1, K.shape(w)[1])), 0.0), K.floatx() 120 | ) 121 | appended = K.concatenate([other_weights, last_weight], axis=0) 122 | w *= appended 123 | return w 124 | 125 | 126 | class SequenceBuilder(Sequence): 127 | """Class to properly construct data to sequences 128 | 129 | :param Sequence: Customized Sequence class for generating batches of data 130 | :type Sequence: :class:`tensorflow.keras.utils.Sequence` 131 | """ 132 | 133 | def __init__(self, data, model_parameters, ARGS): 134 | # Receive all appropriate data 135 | self.codes = data[0] 136 | index = 1 137 | if model_parameters.numeric_size: 138 | self.numeric = data[index] 139 | index += 1 140 | 141 | if model_parameters.use_time: 142 | self.time = data[index] 143 | 144 | self.num_codes = model_parameters.num_codes 145 | self.batch_size = ARGS.batch_size 146 | self.numeric_size = model_parameters.numeric_size 147 | self.use_time = model_parameters.use_time 148 | 149 | def __len__(self): 150 | """Compute number of batches. 151 | Add extra batch if the data doesn't exactly divide into batches 152 | """ 153 | if len(self.codes) % self.batch_size == 0: 154 | return len(self.codes) // self.batch_size 155 | return len(self.codes) // self.batch_size + 1 156 | 157 | def __getitem__(self, idx): 158 | """Get batch of specific index""" 159 | 160 | def pad_data(data, length_visits, length_codes, pad_value=0): 161 | """Pad data to desired number of visits and codes inside each visit""" 162 | zeros = np.full((len(data), length_visits, length_codes), pad_value) 163 | for steps, mat in zip(data, zeros): 164 | if steps != [[-1]]: 165 | for step, mhot in zip(steps, mat[-len(steps) :]): 166 | # Populate the data into the appropriate visit 167 | mhot[: len(step)] = step 168 | 169 | return zeros 170 | 171 | # Compute reusable batch slice 172 | batch_slice = slice(idx * self.batch_size, (idx + 1) * self.batch_size) 173 | x_codes = self.codes[batch_slice] 174 | # Max number of visits and codes inside the visit for this batch 175 | pad_length_visits = max(map(len, x_codes)) 176 | pad_length_codes = max(map(lambda x: max(map(len, x)), x_codes)) 177 | # Number of elements in a batch (useful in case of partial batches) 178 | length_batch = len(x_codes) 179 | # Pad data 180 | x_codes = pad_data(x_codes, pad_length_visits, pad_length_codes, self.num_codes) 181 | outputs = [x_codes] 182 | # Add numeric data if necessary 183 | if self.numeric_size: 184 | x_numeric = self.numeric[batch_slice] 185 | x_numeric = pad_data(x_numeric, pad_length_visits, self.numeric_size, -99.0) 186 | outputs.append(x_numeric) 187 | # Add time data if necessary 188 | if self.use_time: 189 | x_time = sequence.pad_sequences( 190 | self.time[batch_slice], 191 | dtype=np.float32, 192 | maxlen=pad_length_visits, 193 | value=+99, 194 | ).reshape(length_batch, pad_length_visits, 1) 195 | outputs.append(x_time) 196 | 197 | return outputs 198 | 199 | 200 | def read_data(model_parameters, path_data, path_dictionary): 201 | """Read test data used for scoring 202 | 203 | :param model_parameters: parameters of model 204 | :type model_parameters: str 205 | :param str path_data: path to test data 206 | :param str path_dictionary: 
path to code idx dictionary 207 | :return: tuple for data and classifier arrays 208 | :rtype: tuple( list[class:`numpy.ndarray`] , :class:`numpy.ndarray`) 209 | """ 210 | 211 | data = pd.read_pickle(path_data) 212 | data_output = [data["codes"].values] 213 | 214 | if model_parameters.numeric_size: 215 | data_output.append(data["numerics"].values) 216 | if model_parameters.use_time: 217 | data_output.append(data["to_event"].values) 218 | 219 | with open(path_dictionary, "rb") as f: 220 | dictionary = pickle.load(f) 221 | 222 | dictionary[model_parameters.num_codes] = "PADDING" 223 | return data_output, dictionary 224 | 225 | 226 | def get_importances(alphas, betas, patient_data, model_parameters, dictionary): 227 | """Construct dataframes that interprets each visit of the given patient""" 228 | 229 | importances = [] 230 | codes = patient_data[0][0] 231 | index = 1 232 | if model_parameters.numeric_size: 233 | numerics = patient_data[index][0] 234 | index += 1 235 | 236 | if model_parameters.use_time: 237 | time = patient_data[index][0].reshape((len(codes),)) 238 | else: 239 | time = np.arange(len(codes)) 240 | for i in range(len(patient_data[0][0])): 241 | visit_codes = codes[i] 242 | visit_beta = betas[i] 243 | visit_alpha = alphas[i][0] 244 | relevant_indices = np.append( 245 | visit_codes, 246 | range( 247 | model_parameters.num_codes + 1, 248 | model_parameters.num_codes + 1 + model_parameters.numeric_size, 249 | ), 250 | ).astype(np.int32) 251 | values = np.full(fill_value="Diagnosed", shape=(len(visit_codes),)) 252 | if model_parameters.numeric_size: 253 | visit_numerics = numerics[i] 254 | values = np.append(values, visit_numerics) 255 | values_mask = np.array( 256 | [1.0 if value == "Diagnosed" else value for value in values], 257 | dtype=np.float32, 258 | ) 259 | beta_scaled = visit_beta * model_parameters.emb_weights[relevant_indices] 260 | output_scaled = np.dot(beta_scaled, model_parameters.output_weights) 261 | alpha_scaled = values_mask * visit_alpha * output_scaled 262 | df_visit = pd.DataFrame( 263 | { 264 | "status": values, 265 | "feature": [dictionary[index] for index in relevant_indices], 266 | "importance_feature": alpha_scaled[:, 0], 267 | "importance_visit": visit_alpha, 268 | "to_event": time[i], 269 | }, 270 | columns=[ 271 | "status", 272 | "feature", 273 | "importance_feature", 274 | "importance_visit", 275 | "to_event", 276 | ], 277 | ) 278 | df_visit = df_visit[df_visit["feature"] != "PADDING"] 279 | df_visit.sort_values(["importance_feature"], ascending=False, inplace=True) 280 | importances.append(df_visit) 281 | 282 | return importances 283 | 284 | 285 | def get_predictions(model, data, model_parameters, ARGS): 286 | """Construct dataframes that interpret each visit of the given patient""" 287 | 288 | test_generator = SequenceBuilder(data, model_parameters, ARGS) 289 | preds = model.predict_generator( 290 | generator=test_generator, 291 | max_queue_size=15, 292 | use_multiprocessing=True, 293 | verbose=1, 294 | workers=3, 295 | ) 296 | return preds 297 | 298 | 299 | def main(ARGS): 300 | """Main Body of the code""" 301 | print("Loading Model and Extracting Parameters") 302 | model, model_with_attention = import_model(ARGS.path_model) 303 | model_parameters = get_model_parameters(model) 304 | print("Reading Data") 305 | data, dictionary = read_data(model_parameters, ARGS.path_data, ARGS.path_dictionary) 306 | probabilities = get_predictions(model, data, model_parameters, ARGS) 307 | ARGS.batch_size = 1 308 | data_generator = SequenceBuilder(data, 
model_parameters, ARGS) 309 | while 1: 310 | patient_id = int(input("Input Patient Order Number: ")) 311 | if patient_id > len(data[0]) - 1: 312 | print("Invalid ID, there are only {} patients".format(len(data[0]))) 313 | elif patient_id < 0: 314 | print("Only Positive IDs are accepted") 315 | else: 316 | print("Patients probability: {}".format(probabilities[patient_id, 0, 0])) 317 | proceed = str(input("Output predictions? (y/n): ")) 318 | if proceed == "y": 319 | patient_data = data_generator.__getitem__(patient_id) 320 | proba, alphas, betas = model_with_attention.predict_on_batch( 321 | patient_data 322 | ) 323 | visits = get_importances( 324 | alphas[0], betas[0], patient_data, model_parameters, dictionary 325 | ) 326 | for visit in visits: 327 | print(visit) 328 | 329 | 330 | def parse_arguments(parser): 331 | """Read user arguments""" 332 | parser.add_argument( 333 | "--path_model", 334 | type=str, 335 | default="Model/weights.01.hdf5", 336 | help="Path to the model to evaluate", 337 | ) 338 | parser.add_argument( 339 | "--path_data", 340 | type=str, 341 | default="data/data_test.pkl", 342 | help="Path to evaluation data", 343 | ) 344 | parser.add_argument( 345 | "--path_dictionary", 346 | type=str, 347 | default="data/dictionary.pkl", 348 | help="Path to codes dictionary", 349 | ) 350 | parser.add_argument( 351 | "--batch_size", 352 | type=int, 353 | default=32, 354 | help="Batch size for initial probability predictions", 355 | ) 356 | # parser.add_argument('--id', type=int, default=0, 357 | # help='Id of the patient being interpreted') 358 | args = parser.parse_args() 359 | 360 | return args 361 | 362 | 363 | if __name__ == "__main__": 364 | 365 | PARSER = argparse.ArgumentParser( 366 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 367 | ) 368 | ARGS = parse_arguments(PARSER) 369 | main(ARGS) 370 | -------------------------------------------------------------------------------- /retain_train.py: -------------------------------------------------------------------------------- 1 | """Implementation of RETAIN Keras from Edward Choi""" 2 | import os 3 | import argparse 4 | import numpy as np 5 | import pandas as pd 6 | import tensorflow as tf 7 | import tensorflow.keras.layers as L 8 | from tensorflow.keras import backend as K 9 | from tensorflow.keras.models import Model 10 | from tensorflow.keras.callbacks import ModelCheckpoint, Callback 11 | from tensorflow.keras.preprocessing import sequence 12 | from tensorflow.keras.utils import Sequence 13 | from tensorflow.keras.regularizers import l2 14 | from tensorflow.keras.constraints import non_neg, Constraint 15 | from sklearn.metrics import ( 16 | roc_auc_score, 17 | average_precision_score, 18 | precision_recall_curve, 19 | ) 20 | 21 | 22 | class SequenceBuilder(Sequence): 23 | """ 24 | Class to properly construct data into sequences prior to training. 25 | 26 | :param Sequence: Customized Sequence class for generating batches of data 27 | :type Sequence: :class:`tensorflow.keras.utils.Sequence` 28 | :returns: Padded, dense data used for Sequence construction (codes,visits,numerics) 29 | :rtype: :class:`ndarray` 30 | """ 31 | 32 | def __init__(self, data, target, batch_size, ARGS, target_out=True): 33 | """ 34 | Instantiates the code. 
35 | 36 | :param data: Training data sequences (codes, visits, numerics) 37 | :type data: list[:class:`ndarray`] 38 | :param target: List of target values 39 | :type target: :class:`numpy.ndarray` 40 | :param batch_size: Number of samples in each batch 41 | :type batch_size: int 42 | :param ARGS: Arguments object containing user-specified parameters 43 | :type ARGS: :class:`argparse.Namespace` 44 | :param target_out: If `True` (default), then return the target values 45 | :type target_out: bool 46 | :returns: data sequences (codes, visits, numerics) 47 | :rtype: list[:class:`ndarray`] 48 | """ 49 | 50 | # Receive all appropriate data 51 | self.codes = data[0] 52 | index = 1 53 | if ARGS.numeric_size: 54 | self.numeric = data[index] 55 | index += 1 56 | 57 | if ARGS.use_time: 58 | self.time = data[index] 59 | 60 | self.num_codes = ARGS.num_codes 61 | self.target = target 62 | self.batch_size = batch_size 63 | self.target_out = target_out 64 | self.numeric_size = ARGS.numeric_size 65 | self.use_time = ARGS.use_time 66 | self.n_steps = ARGS.n_steps 67 | # self.balance = (1-(float(sum(target))/len(target)))/(float(sum(target))/len(target)) 68 | 69 | def __len__(self): 70 | """ 71 | Compute number of batches. 72 | Add extra batch if the data doesn't exactly divide into batches 73 | 74 | :return: Number of batches per epoch 75 | :rtype: int 76 | """ 77 | 78 | if len(self.codes) % self.batch_size == 0: 79 | return len(self.codes) // self.batch_size 80 | return len(self.codes) // self.batch_size + 1 81 | 82 | def __getitem__(self, idx): 83 | """ 84 | Get batch of specific index. 85 | 86 | :param idx: The index number for the batch to return 87 | :type idx: int 88 | :return: Padded data sequences (codes, visits, numerics) 89 | :rtype: list[:class:`ndarray`] 90 | """ 91 | 92 | def pad_data(data, length_visits, length_codes, pad_value=0): 93 | """ 94 | Pad numpy array to shift sparse matrix to dense matrix 95 | 96 | :param data: Training data sequences (codes, visits, numerics) 97 | :type data: list[:class:`ndarray`] 98 | :param int length_visits: max visit count in batch 99 | :param int length_codes: max codes length in batch 100 | :param pad_value: numeric value to represent padding, defaults to 0 101 | :type pad_value: int, optional 102 | :return: 'dense' array with padding for codes and visits 103 | :rtype: :class:`numpy.ndarray` 104 | """ 105 | 106 | zeros = np.full((len(data), length_visits, length_codes), pad_value) 107 | for steps, mat in zip(data, zeros): 108 | if steps != [[-1]]: 109 | for step, mhot in zip(steps, mat[-len(steps) :]): 110 | # Populate the data into the appropriate visit 111 | mhot[: len(step)] = step 112 | 113 | return zeros 114 | 115 | # Compute reusable batch slice 116 | batch_slice = slice(idx * self.batch_size, (idx + 1) * self.batch_size) 117 | x_codes = self.codes[batch_slice] 118 | # Max number of visits and codes inside the visit for this batch 119 | pad_length_visits = min(max(map(len, x_codes)), self.n_steps) 120 | pad_length_codes = max(map(lambda x: max(map(len, x)), x_codes)) 121 | # Number of elements in a batch (useful in case of partial batches) 122 | length_batch = len(x_codes) 123 | # Pad data 124 | x_codes = pad_data(x_codes, pad_length_visits, pad_length_codes, self.num_codes) 125 | outputs = [x_codes] 126 | # Add numeric data if necessary 127 | if self.numeric_size: 128 | x_numeric = self.numeric[batch_slice] 129 | x_numeric = pad_data(x_numeric, pad_length_visits, self.numeric_size, -99.0) 130 | outputs.append(x_numeric) 131 | # Add time data if 
necessary
132 |         if self.use_time:
133 |             x_time = sequence.pad_sequences(
134 |                 self.time[batch_slice],
135 |                 dtype=np.float32,
136 |                 maxlen=pad_length_visits,
137 |                 value=+99,
138 |             ).reshape(length_batch, pad_length_visits, 1)
139 |             outputs.append(x_time)
140 | 
141 |         # Add target if necessary (training vs validation)
142 |         if self.target_out:
143 |             target = self.target[batch_slice].reshape(length_batch, 1, 1)
144 |             # sample_weights = (target*(self.balance-1)+1).reshape(length_batch, 1)
145 |             # In our experiments sample weights provided worse results
146 |             return (outputs, target)
147 | 
148 |         return outputs
149 | 
150 | 
151 | class FreezePadding_Non_Negative(Constraint):
152 |     """
153 |     Freezes the last (padding) weight to be near 0 - negative weights are not permitted.
154 | 
155 |     :param Constraint: Keras sequence constraint
156 |     :type Constraint: :class:`tensorflow.keras.constraints.Constraint`
157 |     :return: padded tensor or variable
158 |     :rtype: :class:`tensorflow.Tensor`
159 |     """
160 | 
161 |     def __call__(self, w):
162 |         other_weights = K.cast(K.greater_equal(w, 0)[:-1], K.floatx())
163 |         last_weight = K.cast(
164 |             K.equal(K.reshape(w[-1, :], (1, K.shape(w)[1])), 0.0), K.floatx()
165 |         )
166 |         appended = K.concatenate([other_weights, last_weight], axis=0)
167 |         w *= appended
168 |         return w
169 | 
170 | 
171 | class FreezePadding(Constraint):
172 |     """
173 |     Freezes the last (padding) weight to be near 0 - negative weights are permitted.
174 | 
175 |     :param Constraint: Keras sequence constraint
176 |     :type Constraint: :class:`tensorflow.keras.constraints.Constraint`
177 |     :return: padded tensor or variable
178 |     :rtype: :class:`tensorflow.Tensor`
179 |     """
180 | 
181 |     def __call__(self, w):
182 |         other_weights = K.cast(K.ones(K.shape(w))[:-1], K.floatx())
183 |         last_weight = K.cast(
184 |             K.equal(K.reshape(w[-1, :], (1, K.shape(w)[1])), 0.0), K.floatx()
185 |         )
186 |         appended = K.concatenate([other_weights, last_weight], axis=0)
187 |         w *= appended
188 |         return w
189 | 
190 | 
191 | def read_data(ARGS):
192 |     """Read the data from the provided paths and collect it into lists"""
193 | 
194 |     data_train_df = pd.read_pickle(ARGS.path_data_train)
195 |     data_test_df = pd.read_pickle(ARGS.path_data_test)
196 |     y_train = pd.read_pickle(ARGS.path_target_train)["target"].values
197 |     y_test = pd.read_pickle(ARGS.path_target_test)["target"].values
198 |     data_output_train = [data_train_df["codes"].values]
199 |     data_output_test = [data_test_df["codes"].values]
200 | 
201 |     if ARGS.numeric_size:
202 |         data_output_train.append(data_train_df["numerics"].values)
203 |         data_output_test.append(data_test_df["numerics"].values)
204 |     if ARGS.use_time:
205 |         data_output_train.append(data_train_df["to_event"].values)
206 |         data_output_test.append(data_test_df["to_event"].values)
207 |     return (data_output_train, y_train, data_output_test, y_test)
208 | 
209 | 
210 | def model_create(ARGS):
211 |     """
212 |     Create the TensorFlow DAG for the model and compile it at the end
213 |     (training itself happens in train_model).
214 | 
215 |     :param ARGS: Arguments object containing user-specified parameters
216 |     :type ARGS: :class:`argparse.Namespace`
217 |     :return: compiled Keras model
218 |     :rtype: :class:`tensorflow.keras.Model`
219 |     """
220 | 
221 |     def retain(ARGS):
222 |         """
223 |         Helper function to create the DAG of Keras layers via the functional API.
224 |         The layer design mirrors the RETAIN architecture.
225 | :param ARGS: Arguments object containing user-specified parameters 226 | :type ARGS: :class:`argparse.Namespace` 227 | :return: Keras model 228 | :rtype: :class:`tensorflow.keras.Model` 229 | """ 230 | 231 | # Define the constant for model saving 232 | reshape_size = ARGS.emb_size + ARGS.numeric_size 233 | if ARGS.allow_negative: 234 | embeddings_constraint = FreezePadding() 235 | beta_activation = "tanh" 236 | output_constraint = None 237 | else: 238 | embeddings_constraint = FreezePadding_Non_Negative() 239 | beta_activation = "sigmoid" 240 | output_constraint = non_neg() 241 | 242 | def reshape(data): 243 | """Reshape the context vectors to 3D vector""" 244 | return K.reshape(x=data, shape=(K.shape(data)[0], 1, reshape_size)) 245 | 246 | # Code Input 247 | codes = L.Input((None, None), name="codes_input") 248 | inputs_list = [codes] 249 | # Calculate embedding for each code and sum them to a visit level 250 | codes_embs_total = L.Embedding( 251 | ARGS.num_codes + 1, ARGS.emb_size, name="embedding" 252 | )(codes) 253 | codes_embs = L.Lambda(lambda x: K.sum(x, axis=2))(codes_embs_total) 254 | # Numeric input if needed 255 | if ARGS.numeric_size: 256 | numerics = L.Input((None, ARGS.numeric_size), name="numeric_input") 257 | inputs_list.append(numerics) 258 | full_embs = L.concatenate([codes_embs, numerics], name="catInp") 259 | else: 260 | full_embs = codes_embs 261 | 262 | # Apply dropout on inputs 263 | full_embs = L.Dropout(ARGS.dropout_input)(full_embs) 264 | 265 | # Time input if needed 266 | if ARGS.use_time: 267 | time = L.Input((None, 1), name="time_input") 268 | inputs_list.append(time) 269 | time_embs = L.concatenate([full_embs, time], name="catInp2") 270 | else: 271 | time_embs = full_embs 272 | 273 | # Setup Layers 274 | # This implementation uses Bidirectional LSTM instead of reverse order 275 | # (see https://github.com/mp2893/retain/issues/3 for more details) 276 | 277 | alpha = L.Bidirectional( 278 | L.LSTM(ARGS.recurrent_size, return_sequences=True, implementation=2), 279 | name="alpha", 280 | ) 281 | beta = L.Bidirectional( 282 | L.LSTM(ARGS.recurrent_size, return_sequences=True, implementation=2), 283 | name="beta", 284 | ) 285 | 286 | alpha_dense = L.Dense(1, kernel_regularizer=l2(ARGS.l2)) 287 | beta_dense = L.Dense( 288 | ARGS.emb_size + ARGS.numeric_size, 289 | activation=beta_activation, 290 | kernel_regularizer=l2(ARGS.l2), 291 | ) 292 | 293 | # Compute alpha, visit attention 294 | alpha_out = alpha(time_embs) 295 | alpha_out = L.TimeDistributed(alpha_dense, name="alpha_dense_0")(alpha_out) 296 | alpha_out = L.Softmax(name="softmax_1", axis=1)(alpha_out) 297 | # Compute beta, codes attention 298 | beta_out = beta(time_embs) 299 | beta_out = L.TimeDistributed(beta_dense, name="beta_dense_0")(beta_out) 300 | # Compute context vector based on attentions and embeddings 301 | c_t = L.Multiply()([alpha_out, beta_out, full_embs]) 302 | c_t = L.Lambda(lambda x: K.sum(x, axis=1))(c_t) 303 | # Reshape to 3d vector for consistency between Many to Many and Many to One implementations 304 | contexts = L.Lambda(reshape)(c_t) 305 | 306 | # Make a prediction 307 | contexts = L.Dropout(ARGS.dropout_context)(contexts) 308 | output_layer = L.Dense( 309 | 1, 310 | activation="sigmoid", 311 | name="dOut", 312 | kernel_regularizer=l2(ARGS.l2), 313 | kernel_constraint=output_constraint, 314 | ) 315 | 316 | # TimeDistributed is used for consistency 317 | # between Many to Many and Many to One implementations 318 | output = L.TimeDistributed(output_layer, 
name="time_distributed_out")(contexts) 319 | # Define the model with appropriate inputs 320 | model = Model(inputs=inputs_list, outputs=[output]) 321 | 322 | return model 323 | 324 | # Set Tensorflow to grow GPU memory consumption instead of grabbing all of it at once 325 | K.clear_session() 326 | config = tf.compat.v1.ConfigProto( 327 | allow_soft_placement=True, log_device_placement=False 328 | ) 329 | config.gpu_options.allow_growth = True 330 | tfsess = tf.compat.v1.Session(config=config) 331 | tf.compat.v1.keras.backend.set_session(tfsess) 332 | model_final = retain(ARGS) 333 | 334 | # Compile the model - adamax has produced best results in our experiments 335 | model_final.compile( 336 | optimizer="adamax", 337 | loss="binary_crossentropy", 338 | metrics=["accuracy"], 339 | sample_weight_mode="temporal", 340 | ) 341 | 342 | return model_final 343 | 344 | 345 | def create_callbacks(model, data, ARGS): 346 | """At the end of each epoch, determine various callback statistics (e.g. ROC-AUC) 347 | 348 | :param model: Keras model 349 | :type model: :class:`tensorflow.keras.Model` 350 | :param data: Validation data - data sequences (codes, visits, numeric values) and classifier. 351 | :type data: tuple( list( :class:`ndarray`), :class:`ndarray`) 352 | :param ARGS: Arguments object containing user-specified parameters 353 | :type ARGS: :class:`argparse.Namespace` 354 | :return: various callback objects - naming convention for saved HDF5 files, custom logging class, \ 355 | reduced learning rate 356 | :rtype: tuple(:class:`tensorflow.keras.callbacks.ModelCheckpoint`, :class:`LogEval`, \ 357 | :class:`tensorflow.keras.callbacks.ReduceLROnPlateau`) 358 | """ 359 | 360 | class LogEval(Callback): 361 | """Logging Callback""" 362 | 363 | def __init__(self, filepath, model, data, ARGS, interval=1): 364 | """Constructor for logging class 365 | 366 | :param str filepath: path for log file & Keras HDF5 files 367 | :param model: model from training used for end-of-epoch analytics 368 | :type model: :class:`keras.engine.training.Model` 369 | :param data: Validation data used for end-of-epoch analytics \ 370 | (e.g. data sequences (codes, visits, numerics) and classifier) 371 | :type data: tuple(list[:class:`ndarray`],:class:`ndarray`) 372 | :param ARGS: Arguments object containing user-specified parameters 373 | :type ARGS: :class:`argparse.Namespace` 374 | :param interval: Interval for logging (e.g. 
every epoch), defaults to 1 375 | :type interval: int, optional 376 | """ 377 | 378 | super(Callback, self).__init__() 379 | self.filepath = filepath 380 | self.interval = interval 381 | self.data_test, self.y_test = data 382 | self.generator = SequenceBuilder( 383 | data=self.data_test, 384 | target=self.y_test, 385 | batch_size=ARGS.batch_size, 386 | ARGS=ARGS, 387 | target_out=False, 388 | ) 389 | self.model = model 390 | 391 | def on_epoch_end(self, epoch, logs={}): 392 | 393 | # Compute ROC-AUC and average precision the validation data every interval epochs 394 | if epoch % self.interval == 0: 395 | 396 | # Generate predictions 397 | preds = [] 398 | for x in self.generator: 399 | batch_pred = self.model.predict_on_batch( 400 | x=x, 401 | ) 402 | preds.append(batch_pred.flatten()) 403 | y_pred = np.concatenate(preds, axis=0) 404 | 405 | # Compute performance 406 | score_roc = roc_auc_score(self.y_test, y_pred) 407 | score_pr = average_precision_score(self.y_test, y_pred) 408 | 409 | # Create log file if it doesn't exist, otherwise write to it 410 | if os.path.exists(self.filepath): 411 | append_write = "a" 412 | else: 413 | append_write = "w" 414 | with open(self.filepath, append_write) as file_output: 415 | file_output.write( 416 | "\nEpoch: {:d}- ROC-AUC: {:.6f} ; PR-AUC: {:.6f}".format( 417 | epoch, score_roc, score_pr 418 | ) 419 | ) 420 | 421 | # Print performance 422 | print( 423 | "\nEpoch: {:d} - ROC-AUC: {:.6f} PR-AUC: {:.6f}".format( 424 | epoch, score_roc, score_pr 425 | ) 426 | ) 427 | 428 | # Create callbacks 429 | if not os.path.exists(ARGS.directory): 430 | os.makedirs(ARGS.directory) 431 | checkpoint = ModelCheckpoint(filepath=ARGS.directory + "/weights.{epoch:02d}.hdf5") 432 | log = LogEval(ARGS.directory + "/log.txt", model, data, ARGS) 433 | return (checkpoint, log) 434 | 435 | 436 | def train_model(model, data_train, y_train, data_test, y_test, ARGS): 437 | """ 438 | Class to hold callback artifacts, Sequence builder of training data, model training 439 | generator 440 | 441 | :param model: Keras model 442 | :type model: :class:`tensorflow.keras.Model` 443 | :param data_train: List with sub-arrays for medical codes, visits, and demographics 444 | :type data_train: list(:class:`numpy.ndarray`) 445 | :param y_train: Array with classifiers for training set 446 | :type y_train: :class:`numpy.ndarray` 447 | :param data_test: List with sub-arrays for medical codes, visits, and demographics 448 | :type data_test: list(:class:`numpy.ndarray`) 449 | :param y_test: Array with classifiers for test set 450 | :type y_test: :class:`numpy.ndarray` 451 | :param ARGS: Arguments object containing user-specified parameters 452 | :type ARGS: :class:`argparse.Namespace` 453 | """ 454 | 455 | checkpoint, log = create_callbacks(model, (data_test, y_test), ARGS) 456 | train_generator = SequenceBuilder( 457 | data=data_train, target=y_train, batch_size=ARGS.batch_size, ARGS=ARGS 458 | ) 459 | model.fit( 460 | x=train_generator, 461 | epochs=ARGS.epochs, 462 | max_queue_size=15, 463 | use_multiprocessing=True, 464 | callbacks=[checkpoint, log], 465 | verbose=1, 466 | workers=3, 467 | initial_epoch=0, 468 | ) 469 | 470 | 471 | def main(ARGS): 472 | """Main function""" 473 | print("Reading Data...") 474 | data_train, y_train, data_test, y_test = read_data(ARGS) 475 | 476 | print("Creating Model...") 477 | model = model_create(ARGS) 478 | 479 | print("Training Model...") 480 | train_model( 481 | model=model, 482 | data_train=data_train, 483 | y_train=y_train, 484 | data_test=data_test, 485 | 
y_test=y_test, 486 | ARGS=ARGS, 487 | ) 488 | 489 | 490 | def parse_arguments(parser): 491 | """Read user arguments""" 492 | parser.add_argument( 493 | "--num_codes", type=int, required=True, help="Number of medical codes" 494 | ) 495 | parser.add_argument( 496 | "--numeric_size", type=int, default=0, help="Size of numeric inputs, 0 if none" 497 | ) 498 | parser.add_argument( 499 | "--use_time", 500 | action="store_true", 501 | help="If argument is present the time input will be used", 502 | ) 503 | parser.add_argument( 504 | "--emb_size", type=int, default=200, help="Size of the embedding layer" 505 | ) 506 | parser.add_argument("--epochs", type=int, default=1, help="Number of epochs") 507 | parser.add_argument( 508 | "--n_steps", 509 | type=int, 510 | default=300, 511 | help="Maximum number of visits after which the data is truncated", 512 | ) 513 | parser.add_argument( 514 | "--recurrent_size", type=int, default=200, help="Size of the recurrent layers" 515 | ) 516 | parser.add_argument( 517 | "--path_data_train", 518 | type=str, 519 | default="data/data_train.pkl", 520 | help="Path to train data", 521 | ) 522 | parser.add_argument( 523 | "--path_data_test", 524 | type=str, 525 | default="data/data_test.pkl", 526 | help="Path to test data", 527 | ) 528 | parser.add_argument( 529 | "--path_target_train", 530 | type=str, 531 | default="data/target_train.pkl", 532 | help="Path to train target", 533 | ) 534 | parser.add_argument( 535 | "--path_target_test", 536 | type=str, 537 | default="data/target_test.pkl", 538 | help="Path to test target", 539 | ) 540 | parser.add_argument("--batch_size", type=int, default=32, help="Batch Size") 541 | parser.add_argument( 542 | "--dropout_input", type=float, default=0.0, help="Dropout rate for embedding" 543 | ) 544 | parser.add_argument( 545 | "--dropout_context", 546 | type=float, 547 | default=0.0, 548 | help="Dropout rate for context vector", 549 | ) 550 | parser.add_argument( 551 | "--l2", type=float, default=0.0, help="L2 regularitzation value" 552 | ) 553 | parser.add_argument( 554 | "--directory", 555 | type=str, 556 | default="Model", 557 | help="Directory to save the model and the log file to", 558 | ) 559 | parser.add_argument( 560 | "--allow_negative", 561 | action="store_true", 562 | help="If argument is present the negative weights for embeddings/attentions\ 563 | will be allowed (original RETAIN implementaiton)", 564 | ) 565 | args = parser.parse_args() 566 | 567 | return args 568 | 569 | 570 | if __name__ == "__main__": 571 | 572 | PARSER = argparse.ArgumentParser( 573 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 574 | ) 575 | ARGS = parse_arguments(PARSER) 576 | main(ARGS) 577 | --------------------------------------------------------------------------------
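Taken together, the three scripts form a train, evaluate, interpret pipeline on the pickles that process_mimic_modified.py writes. An illustrative sequence of invocations, assuming the processed files sit in a data/ directory (the argparse defaults shown above) and using 942 as a placeholder for the real vocabulary size:

python retain_train.py --num_codes=942 --epochs=10 --directory=Model
python retain_evaluation.py --path_model=Model/weights.10.hdf5 --omit_graphs
python retain_interpretations.py --path_model=Model/weights.10.hdf5 --path_dictionary=data/dictionary.pkl

ModelCheckpoint writes one weights.NN.hdf5 file per epoch into --directory, so the checkpoint name above corresponds to the last of the ten epochs. Passing --omit_graphs makes the evaluation script print ROC-AUC, average precision, and lift to stdout instead of saving the plot files. If --numeric_size or --use_time were used during training, the evaluation and interpretation scripts recover those settings from the saved model via get_model_parameters, so they do not have to be repeated.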
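The value passed to --num_codes must match the vocabulary built during preprocessing. A hypothetical helper, not part of the repository and assuming dictionary.pkl (written at the end of process_mimic_modified.py) holds one entry per distinct medical code:

import pickle

# dictionary.pkl is produced by process_mimic_modified.py next to the data pickles
with open("data/dictionary.pkl", "rb") as f:
    code_dictionary = pickle.load(f)

print("use --num_codes={}".format(len(code_dictionary)))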
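The nested pad_data helper used by every SequenceBuilder turns the ragged per-patient code lists into a dense array: visits are right-aligned along the time axis (padding visits come first) and the codes inside each visit are left-aligned, with the padding index used as the fill value. A standalone toy run, with the helper copied from the scripts above and made-up sizes, makes the layout visible:

import numpy as np

def pad_data(data, length_visits, length_codes, pad_value=0):
    """Pad data to desired number of visits and codes inside each visit"""
    zeros = np.full((len(data), length_visits, length_codes), pad_value)
    for steps, mat in zip(data, zeros):
        if steps != [[-1]]:
            for step, mhot in zip(steps, mat[-len(steps):]):
                mhot[:len(step)] = step
    return zeros

num_codes = 7                    # toy vocabulary size; also serves as the padding index
batch = [
    [[1, 2], [3]],               # patient with two visits
    [[4, 5, 6]],                 # patient with one visit
]
print(pad_data(batch, length_visits=2, length_codes=3, pad_value=num_codes))
# visits are pushed to the end of the time axis and short visits are
# topped up with the padding index 7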
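The FreezePadding and FreezePadding_Non_Negative classes, which must also be passed as custom_objects when a saved model is reloaded, are weight constraints that drive the last row (reserved for the padding code) back to zero after every update; the non-negative variant additionally zeroes any weight that has drifted negative. A small NumPy sketch of the masking their __call__ performs, with illustrative numbers outside of Keras:

import numpy as np

w = np.array([[ 0.5, -0.2],
              [-0.3,  0.7],
              [ 0.0,  0.4]])        # last row is reserved for the padding code

# FreezePadding: real rows pass through untouched, the padding row is zeroed
# whenever the optimiser has perturbed it away from zero.
mask_real = np.ones_like(w[:-1])
mask_pad = (w[-1:] == 0.0).astype(w.dtype)
print(w * np.concatenate([mask_real, mask_pad], axis=0))

# FreezePadding_Non_Negative: same treatment of the padding row, but negative
# weights in the real rows are zeroed as well.
mask_nonneg = (w[:-1] >= 0.0).astype(w.dtype)
print(w * np.concatenate([mask_nonneg, mask_pad], axis=0))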
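get_importances in retain_interpretations.py recombines the two attention outputs with the learned weights: for a code recorded at visit t, its contribution score is the visit attention alpha_t times the dot product of (beta_t elementwise-scaled embedding row of the code) with the output-layer weights, mirroring how the context vector feeds the final dense layer. A minimal NumPy sketch with made-up shapes and random stand-in values (the model's shared output bias is left out, as in the script):

import numpy as np

rng = np.random.default_rng(0)
emb_size = 4
W_emb = rng.normal(size=(10, emb_size))   # stand-in for the embedding matrix
w_out = rng.normal(size=(emb_size, 1))    # stand-in for the output-layer kernel

alpha_t = 0.3                             # visit-level attention weight
beta_t = rng.random(emb_size)             # feature-level attention for visit t
visit_codes = np.array([2, 5])            # codes recorded at visit t

beta_scaled = beta_t * W_emb[visit_codes]            # (codes in visit, emb_size)
importance = alpha_t * beta_scaled.dot(w_out)[:, 0]  # one score per code
print(importance)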
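The metric helpers in retain_evaluation.py are plain functions, so they can be smoke-tested on synthetic labels and scores without a trained model; note that matplotlib is imported at module level even though it is not pinned in requirements.txt, so it has to be installed for the import to succeed. A minimal sketch, assuming the repository directory is the working directory:

import numpy as np
from retain_evaluation import roc, precision_recall, lift

rng = np.random.default_rng(1)
y_true = rng.integers(0, 2, size=1000)
y_prob = 0.6 * y_true + 0.4 * rng.random(1000)   # scores loosely tied to the labels

roc(y_true, y_prob, graph=False)                 # prints ROC-AUC
precision_recall(y_true, y_prob, graph=False)    # prints average precision
lift(y_true, y_prob, graph=False)                # prints average lift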