├── requirements.txt ├── Dockerfile ├── train_model.py ├── LICENSE ├── remove_data.py ├── remove_labels.py ├── run_model.py ├── truncate_data.py ├── README.md ├── team_code.py └── helper_code.py /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.24.3 2 | scipy==1.10.1 3 | scikit-learn==1.2.2 4 | joblib==1.2.0 5 | mne==1.4.0 6 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.10.1-buster 2 | 3 | ## DO NOT EDIT these 3 lines. 4 | RUN mkdir /challenge 5 | COPY ./ /challenge 6 | WORKDIR /challenge 7 | 8 | ## Install your dependencies here using apt install, etc. 9 | 10 | ## Include the following line if you have a requirements.txt file. 11 | RUN pip install -r requirements.txt 12 | -------------------------------------------------------------------------------- /train_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Do *not* edit this script. Changes will be discarded so that we can train the models consistently. 4 | 5 | # This file contains functions for training models for the Challenge. You can run it as follows: 6 | # 7 | #   python train_model.py data model 8 | # 9 | # where 'data' is a folder containing the Challenge data and 'model' is a folder for saving your model. 10 | 11 | import sys 12 | from helper_code import is_integer 13 | from team_code import train_challenge_model 14 | 15 | if __name__ == '__main__': 16 | # Parse the arguments. 17 | if not (len(sys.argv) == 3 or len(sys.argv) == 4): 18 | raise Exception('Include the data and model folders as arguments, e.g., python train_model.py data model.') 19 | 20 | # Define the data and model folders. 21 | data_folder = sys.argv[1] 22 | model_folder = sys.argv[2] 23 | 24 | # Change the level of verbosity; helpful for debugging. 25 | if len(sys.argv)==4 and is_integer(sys.argv[3]): 26 | verbose = int(sys.argv[3]) 27 | else: 28 | verbose = 1 29 | 30 | train_challenge_model(data_folder, model_folder, verbose) ### Teams: Implement this function!!! 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2022, 2023 PhysioNet 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /remove_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Load libraries. 4 | import os, sys, shutil, argparse 5 | 6 | # Parse arguments. 7 | def get_parser(): 8 | description = 'Remove data from the dataset.' 9 | parser = argparse.ArgumentParser(description=description) 10 | parser.add_argument('-i', '--input_folder', type=str, required=True) 11 | parser.add_argument('-p', '--patient_ids', nargs='*', type=str, required=False, default=[]) 12 | parser.add_argument('-o', '--output_folder', type=str, required=True) 13 | return parser 14 | 15 | # Find folders with data files. 16 | def find_data_folders(root_folder): 17 | data_folders = list() 18 | for x in sorted(os.listdir(root_folder)): 19 | data_folder = os.path.join(root_folder, x) 20 | if os.path.isdir(data_folder): 21 | data_file = os.path.join(data_folder, x + '.txt') 22 | if os.path.isfile(data_file): 23 | data_folders.append(x) 24 | return sorted(data_folders) 25 | 26 | # Run script. 27 | def run(args): 28 | # Use either the given patient IDs or all of the patient IDs. 29 | if args.patient_ids: 30 | patient_ids = args.patient_ids 31 | else: 32 | patient_ids = find_data_folders(args.input_folder) 33 | 34 | # Iterate over the patient IDs. 35 | for patient_id in patient_ids: 36 | input_path = os.path.join(args.input_folder, patient_id) 37 | output_path = os.path.join(args.output_folder, patient_id) 38 | os.makedirs(output_path, exist_ok=True) 39 | 40 | # Iterate over the files in each folder. 41 | for file_name in sorted(os.listdir(input_path)): 42 | file_root, file_ext = os.path.splitext(file_name) 43 | input_file = os.path.join(input_path, file_name) 44 | output_file = os.path.join(output_path, file_name) 45 | 46 | # If the file is not the binary signal data, then copy it. 47 | if not (file_ext == '.mat'): 48 | shutil.copy2(input_file, output_file) 49 | 50 | if __name__=='__main__': 51 | run(get_parser().parse_args(sys.argv[1:])) 52 | -------------------------------------------------------------------------------- /remove_labels.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Load libraries. 4 | import os, sys, shutil, argparse 5 | 6 | # Parse arguments. 7 | def get_parser(): 8 | description = 'Remove labels from the dataset.' 9 | parser = argparse.ArgumentParser(description=description) 10 | parser.add_argument('-i', '--input_folder', type=str, required=True) 11 | parser.add_argument('-p', '--patient_ids', nargs='*', type=str, required=False, default=[]) 12 | parser.add_argument('-o', '--output_folder', type=str, required=True) 13 | return parser 14 | 15 | # Find folders with data files. 
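# Note: this is a standalone copy of the find_data_folders function from helper_code.py, duplicated here so that
# this script can run without importing helper_code.py.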
16 | def find_data_folders(root_folder): 17 | data_folders = list() 18 | for x in sorted(os.listdir(root_folder)): 19 | data_folder = os.path.join(root_folder, x) 20 | if os.path.isdir(data_folder): 21 | data_file = os.path.join(data_folder, x + '.txt') 22 | if os.path.isfile(data_file): 23 | data_folders.append(x) 24 | return sorted(data_folders) 25 | 26 | # Run script. 27 | def run(args): 28 | # Use either the given patient IDs or all of the patient IDs. 29 | if args.patient_ids: 30 | patient_ids = args.patient_ids 31 | else: 32 | patient_ids = find_data_folders(args.input_folder) 33 | 34 | # Iterate over the patient IDs. 35 | for patient_id in patient_ids: 36 | input_path = os.path.join(args.input_folder, patient_id) 37 | output_path = os.path.join(args.output_folder, patient_id) 38 | os.makedirs(output_path, exist_ok=True) 39 | 40 | # Iterate over the files in each folder. 41 | for file_name in sorted(os.listdir(input_path)): 42 | file_root, file_ext = os.path.splitext(file_name) 43 | input_file = os.path.join(input_path, file_name) 44 | output_file = os.path.join(output_path, file_name) 45 | 46 | # If the file contains the labels, then remove the labels and copy the rest of the file. 47 | if file_ext == '.txt' and file_root == patient_id: 48 | with open(input_file, 'r') as f: 49 | input_lines = f.readlines() 50 | output_lines = [l for l in input_lines if not (l.startswith('Outcome') or l.startswith('CPC'))] 51 | output_string = ''.join(output_lines) 52 | with open(output_file, 'w') as f: 53 | f.write(output_string) 54 | 55 | # Otherwise, copy the file as-is. 56 | else: 57 | shutil.copy2(input_file, output_file) 58 | 59 | if __name__=='__main__': 60 | run(get_parser().parse_args(sys.argv[1:])) 61 | -------------------------------------------------------------------------------- /run_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Do *not* edit this script. Changes will be discarded so that we can run the trained models consistently. 4 | 5 | # This file contains functions for running models for the Challenge. You can run it as follows: 6 | # 7 | #   python run_model.py models data outputs 8 | # 9 | # where 'models' is a folder containing your trained models, 'data' is a folder containing the Challenge data, and 'outputs' is a 10 | # folder for saving your models' outputs. 11 | 12 | import numpy as np, scipy as sp, os, sys 13 | from helper_code import * 14 | from team_code import load_challenge_models, run_challenge_models 15 | 16 | # Run model. 17 | def run_model(model_folder, data_folder, output_folder, allow_failures, verbose): 18 | # Load model(s). 19 | if verbose >= 1: 20 | print('Loading the Challenge models...') 21 | 22 | # You can use this function to perform tasks, such as loading your models, that you only need to perform once. 23 | models = load_challenge_models(model_folder, verbose) ### Teams: Implement this function!!! 24 | 25 | # Find the Challenge data. 26 | if verbose >= 1: 27 | print('Finding the Challenge data...') 28 | 29 | patient_ids = find_data_folders(data_folder) 30 | num_patients = len(patient_ids) 31 | 32 | if num_patients==0: 33 | raise Exception('No data were provided.') 34 | 35 | # Create a folder for the Challenge outputs if it does not already exist. 36 | os.makedirs(output_folder, exist_ok=True) 37 | 38 | # Run the team's model(s) on the Challenge data. 
39 | if verbose >= 1: 40 | print('Running the Challenge models on the Challenge data...') 41 | 42 | # Iterate over the patients. 43 | for i in range(num_patients): 44 | if verbose >= 2: 45 | print(' {}/{}...'.format(i+1, num_patients)) 46 | 47 | patient_id = patient_ids[i] 48 | 49 | # Allow or disallow the model(s) to fail on parts of the data; this can be helpful for debugging. 50 | try: 51 | outcome_binary, outcome_probability, cpc = run_challenge_models(models, data_folder, patient_id, verbose) ### Teams: Implement this function!!! 52 | except: 53 | if allow_failures: 54 | if verbose >= 2: 55 | print('... failed.') 56 | outcome_binary, outcome_probability, cpc = float('nan'), float('nan'), float('nan') 57 | else: 58 | raise 59 | 60 | # Save Challenge outputs. 61 | os.makedirs(os.path.join(output_folder, patient_id), exist_ok=True) 62 | output_file = os.path.join(output_folder, patient_id, patient_id + '.txt') 63 | save_challenge_outputs(output_file, patient_id, outcome_binary, outcome_probability, cpc) 64 | 65 | if verbose >= 1: 66 | print('Done.') 67 | 68 | if __name__ == '__main__': 69 | # Parse the arguments. 70 | if not (len(sys.argv) == 4 or len(sys.argv) == 5): 71 | raise Exception('Include the model, data, and output folders as arguments, e.g., python run_model.py model data outputs.') 72 | 73 | # Define the model, data, and output folders. 74 | model_folder = sys.argv[1] 75 | data_folder = sys.argv[2] 76 | output_folder = sys.argv[3] 77 | 78 | # Allow or disallow the model to fail on parts of the data; helpful for debugging. 79 | allow_failures = False 80 | 81 | # Change the level of verbosity; helpful for debugging. 82 | if len(sys.argv)==5 and is_integer(sys.argv[4]): 83 | verbose = int(sys.argv[4]) 84 | else: 85 | verbose = 1 86 | 87 | run_model(model_folder, data_folder, output_folder, allow_failures, verbose) 88 | -------------------------------------------------------------------------------- /truncate_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Load libraries. 4 | import os, os.path, sys, shutil, argparse 5 | from helper_code import * 6 | 7 | # Parse arguments. 8 | def get_parser(): 9 | description = 'Truncate recordings to the provided time limit (in hours).' 10 | parser = argparse.ArgumentParser(description=description) 11 | parser.add_argument('-i', '--input_folder', type=str, required=True) 12 | parser.add_argument('-p', '--patient_ids', nargs='*', type=str, required=False, default=[]) 13 | parser.add_argument('-t', '--time_limit', type=float, required=True) 14 | parser.add_argument('-o', '--output_folder', type=str, required=True) 15 | return parser 16 | 17 | # Run script. 18 | def run(args): 19 | # Convert hours to seconds. 20 | time_limit = 3600 * args.time_limit 21 | 22 | # Identify the data folders. 23 | if args.patient_ids: 24 | patient_ids = args.patient_ids 25 | else: 26 | patient_ids = find_data_folders(args.input_folder) 27 | 28 | # Iterate over each folder. 29 | for patient_id in patient_ids: 30 | # Set the paths. 31 | input_path = os.path.join(args.input_folder, patient_id) 32 | output_path = os.path.join(args.output_folder, patient_id) 33 | 34 | # Create the output folder. 35 | if os.path.exists(input_path): 36 | os.makedirs(output_path, exist_ok=True) 37 | 38 | # Copy the patient metadata file. 
39 | input_patient_metadata_file = os.path.join(input_path, patient_id + '.txt') 40 | output_patient_metadata_file = os.path.join(output_path, patient_id + '.txt') 41 | shutil.copy(input_patient_metadata_file, output_patient_metadata_file) 42 | 43 | # Copy the WFDB header and signal files for records that end before the end time. 44 | for filename in os.listdir(input_path): 45 | if not filename.startswith('.') and filename.endswith('.hea'): 46 | header_file = filename 47 | input_header_file = os.path.join(input_path, header_file) 48 | output_header_file = os.path.join(output_path, header_file) 49 | 50 | header_text = load_text_file(input_header_file) 51 | start_time = convert_hours_minutes_seconds_to_seconds(*get_start_time(header_text)) 52 | end_time = convert_hours_minutes_seconds_to_seconds(*get_end_time(header_text)) 53 | 54 | # If the end time for a recording is before the time limit, then copy the recording. 55 | if end_time < time_limit: 56 | signal_files = set() 57 | for i, l in enumerate(header_text.split('\n')): 58 | arrs = [arr.strip() for arr in l.split(' ')] 59 | if i > 0 and not l.startswith('#') and len(arrs) > 0 and len(arrs[0]) > 0: 60 | signal_file = arrs[0] 61 | signal_files.add(signal_file) 62 | 63 | signal_files = sorted(signal_files) 64 | input_signal_files = [os.path.join(input_path, signal_file) for signal_file in signal_files] 65 | output_signal_files = [os.path.join(output_path, signal_file) for signal_file in signal_files] 66 | 67 | shutil.copy2(input_header_file, output_header_file) 68 | for input_signal_file, output_signal_file in zip(input_signal_files, output_signal_files): 69 | shutil.copy(input_signal_file, output_signal_file) 70 | 71 | # If the start time is before the time limit and the end time is after the time limit, then truncate the recording. 72 | elif start_time < time_limit and end_time >= time_limit: 73 | record_name = header_text.split(' ')[0] 74 | raise NotImplementedError('Part (but not all) of record {} exceeds the end time.'.format(record_name)) # All of the files in the dataset end on the hour. 75 | 76 | # If the start time is after the time limit, then do not copy or truncate the recording. 77 | elif start_time >= time_limit: 78 | pass 79 | 80 | if __name__=='__main__': 81 | run(get_parser().parse_args(sys.argv[1:])) 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Python example code for the George B. Moody PhysioNet Challenge 2023 2 | 3 | ## What's in this repository? 4 | 5 | This repository contains a simple example that illustrates how to format a Python entry for the George B. Moody PhysioNet Challenge 2023. We recommend that you use this repository as a template for your entry. You can remove some of the code, reuse other code, and add new code to create your entry. You do not need to use the models, features, and/or libraries in this example for your approach. We encourage a diversity of approaches for the Challenge. 6 | 7 | For this example, we implemented a random forest model with several features. This simple example is designed **not** to perform well, so you should **not** use it as a baseline for your model's performance. You can try it by running the following commands on the Challenge training set. These commands should take a few minutes or less to run from start to finish on a recent personal computer. 
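For reference, these are the training and inference commands described in more detail under "How do I run these scripts?" below; `training_data`, `model`, `test_data`, and `test_outputs` are the folder names used throughout this README:

    python train_model.py training_data model
    python run_model.py model test_data test_outputs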
8 | 9 | This code uses four main scripts, described below, to train and run a model for the Challenge. 10 | 11 | ## How do I run these scripts? 12 | 13 | You can install the dependencies for these scripts by creating a Docker image (see below) and running 14 | 15 |     pip install -r requirements.txt 16 | 17 | You can train your model by running 18 | 19 |     python train_model.py training_data model 20 | 21 | where 22 | 23 | - `training_data` (input; required) is a folder with the training data files and 24 | - `model` (output; required) is a folder for saving your model. 25 | 26 | You can run your trained model by running 27 | 28 |     python run_model.py model test_data test_outputs 29 | 30 | where 31 | 32 | - `model` (input; required) is a folder for loading your model, 33 | - `test_data` (input; required) is a folder with the validation or test data files (you can use the training data for debugging and cross-validation, but the validation and test data will not have labels and will have 12, 24, 48, or 72 hours of data), and 34 | - `test_outputs` (output; required) is a folder for saving your model outputs. 35 | 36 | The [Challenge website](https://physionetchallenges.org/2023/#data) provides a training database with a description of the contents and structure of the data files. 37 | 38 | You can evaluate your model by pulling or downloading the [evaluation code](https://github.com/physionetchallenges/evaluation-2023) and running 39 | 40 |     python evaluate_model.py labels outputs scores.csv 41 | 42 | where 43 | 44 | - `labels` is a folder with labels for the data, such as the training database on the PhysioNet webpage, 45 | - `outputs` is a folder containing files with your model's outputs for the data, and 46 | - `scores.csv` (optional) is a collection of scores for your model. 47 | 48 | ## Which scripts can I edit? 49 | 50 | Please edit the following script to add your code: 51 | 52 | * `team_code.py` is a script with functions for training and running your trained model. 53 | 54 | Please do **not** edit the following scripts. We will use the unedited versions of these scripts when running your code: 55 | 56 | * `train_model.py` is a script for training your model. 57 | * `run_model.py` is a script for running your trained model. 58 | * `helper_code.py` is a script with helper functions that we used for our code. You are welcome to use them in your code. 59 | 60 | These scripts must remain in the root path of your repository, but you can put other scripts and other files elsewhere in your repository. 61 | 62 | ## How do I train, save, load, and run my model? 63 | 64 | To train and save your models, please edit the `train_challenge_model` function in the `team_code.py` script. Please do not edit the input or output arguments of the `train_challenge_model` function. 65 | 66 | To load and run your trained model, please edit the `load_challenge_models` and `run_challenge_models` functions in the `team_code.py` script. Please do not edit the input or output arguments of these functions. 67 | 68 | ## How do I run these scripts in Docker? 69 | 70 | Docker and similar platforms allow you to containerize and package your code with specific dependencies so that your code can be reliably run in other computational environments. 71 | 72 | To guarantee that we can run your code, please [install](https://docs.docker.com/get-docker/) Docker, build a Docker image from your code, and run it on the training data. 
To quickly check your code for bugs, you may want to run it on a small subset of the training data. 73 | 74 | If you have trouble running your code, then please try the following steps to run the example code. 75 | 76 | 1. Create a folder `example` in your home directory with several subfolders. 77 | 78 |         user@computer:~$ cd ~/ 79 |         user@computer:~$ mkdir example 80 |         user@computer:~$ cd example 81 |         user@computer:~/example$ mkdir training_data test_data model test_outputs 82 | 83 | 2. Download the training data from the [Challenge website](https://physionetchallenges.org/2023/#data). Put some of the training data in `training_data` and `test_data`. You can use some of the training data to check your code (and you should perform cross-validation on the training data to evaluate your algorithm). 84 | 85 | 3. Download or clone this repository in your terminal. 86 | 87 |         user@computer:~/example$ git clone https://github.com/physionetchallenges/python-example-2023.git 88 | 89 | 4. Build a Docker image and run the example code in your terminal. 90 | 91 |         user@computer:~/example$ ls 92 |         model  python-example-2023  test_data  test_outputs  training_data 93 | 94 |         user@computer:~/example$ cd python-example-2023/ 95 | 96 |         user@computer:~/example/python-example-2023$ docker build -t image . 97 | 98 |         Sending build context to Docker daemon  [...]kB 99 |         [...] 100 |         Successfully tagged image:latest 101 | 102 |         user@computer:~/example/python-example-2023$ docker run -it -v ~/example/model:/challenge/model -v ~/example/test_data:/challenge/test_data -v ~/example/test_outputs:/challenge/test_outputs -v ~/example/training_data:/challenge/training_data image bash 103 | 104 |         root@[...]:/challenge# ls 105 |             Dockerfile             README.md         test_outputs 106 |             evaluate_model.py      requirements.txt  training_data 107 |             helper_code.py         team_code.py      train_model.py 108 |             LICENSE                run_model.py 109 | 110 |         root@[...]:/challenge# python train_model.py training_data model 111 | 112 |         root@[...]:/challenge# python run_model.py model test_data test_outputs 113 | 114 |         root@[...]:/challenge# python evaluate_model.py test_data test_outputs 115 |         [...] 116 | 117 |         root@[...]:/challenge# exit 118 |         Exit 119 | 120 | ## What else do I need? 121 | 122 | This repository does not include code for evaluating your entry. Please see the [evaluation code repository](https://github.com/physionetchallenges/evaluation-2023) for code and instructions for evaluating your entry using the Challenge scoring metric. 123 | 124 | This repository also includes code for preparing the validation and test sets. We will run your trained model on data without labels and with 12, 24, 48, and 72 hours of recording data to evaluate its performance with limited amounts of data. You can use this code to prepare the training data in the same way that we prepare the validation and test sets. 125 | 126 | - `truncate_data.py`: Truncate the EEG recordings. Usage: run `python truncate_data.py -i input_folder -o output_folder -t 12` to truncate the EEG recordings to 12 hours. We will run your trained models on data with 12, 24, 48, and 72 hours of recording data. 127 | - `remove_labels.py`: Remove the labels. Usage: run `python remove_labels.py -i input_folder -o output_folder` to copy the data and metadata (but not the labels) from `input_folder` to `output_folder`. 128 | - `remove_data.py`: Remove the binary signal data, i.e., the EEG recordings. 
Usage: run `python remove_data.py -i input_folder -o output_folder` to copy the labels and metadata (but not the EEG recording data) from `input_folder` to `output_folder`. 129 | 130 | ## How do I learn more? 131 | 132 | Please see the [Challenge website](https://physionetchallenges.org/2023/) for more details. Please post questions and concerns on the [Challenge discussion forum](https://groups.google.com/forum/#!forum/physionet-challenges). 133 | 134 | ## Useful links 135 | 136 | * [Challenge website](https://physionetchallenges.org/2023/) 137 | * [MATLAB example code](https://github.com/physionetchallenges/matlab-example-2023) 138 | * [Evaluation code](https://github.com/physionetchallenges/evaluation-2023) 139 | * [Frequently asked questions (FAQ) for this year's Challenge](https://physionetchallenges.org/2023/faq/) 140 | * [Frequently asked questions (FAQ) about the Challenges in general](https://physionetchallenges.org/faq/) 141 | -------------------------------------------------------------------------------- /team_code.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Edit this script to add your team's code. Some functions are *required*, but you can edit most parts of the required functions, 4 | # change or remove non-required functions, and add your own functions. 5 | 6 | ################################################################################ 7 | # 8 | # Optional libraries, functions, and variables. You can change or remove them. 9 | # 10 | ################################################################################ 11 | 12 | from helper_code import * 13 | import numpy as np, scipy.signal, os, sys # scipy.signal is used by preprocess_data below. 14 | import mne 15 | from sklearn.impute import SimpleImputer 16 | from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor 17 | import joblib 18 | 19 | ################################################################################ 20 | # 21 | # Required functions. Edit these functions to add your code, but do not change the arguments of the functions. 22 | # 23 | ################################################################################ 24 | 25 | # Train your model. 26 | def train_challenge_model(data_folder, model_folder, verbose): 27 | # Find data files. 28 | if verbose >= 1: 29 | print('Finding the Challenge data...') 30 | 31 | patient_ids = find_data_folders(data_folder) 32 | num_patients = len(patient_ids) 33 | 34 | if num_patients==0: 35 | raise FileNotFoundError('No data were provided.') 36 | 37 | # Create a folder for the model if it does not already exist. 38 | os.makedirs(model_folder, exist_ok=True) 39 | 40 | # Extract the features and labels. 41 | if verbose >= 1: 42 | print('Extracting features and labels from the Challenge data...') 43 | 44 | features = list() 45 | outcomes = list() 46 | cpcs = list() 47 | 48 | for i in range(num_patients): 49 | if verbose >= 2: 50 | print('    {}/{}...'.format(i+1, num_patients)) 51 | 52 | current_features = get_features(data_folder, patient_ids[i]) 53 | features.append(current_features) 54 | 55 | # Extract labels. 56 | patient_metadata = load_challenge_data(data_folder, patient_ids[i]) 57 | current_outcome = get_outcome(patient_metadata) 58 | outcomes.append(current_outcome) 59 | current_cpc = get_cpc(patient_metadata) 60 | cpcs.append(current_cpc) 61 | 62 | features = np.vstack(features) 63 | outcomes = np.vstack(outcomes) 64 | cpcs = np.vstack(cpcs) 65 | 66 | # Train the models. 
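# Note: the hyperparameter values defined below (n_estimators=123, max_leaf_nodes=456, random_state=789) are
# arbitrary placeholders for this example; tune them for your own model.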
67 | if verbose >= 1: 68 | print('Training the Challenge model on the Challenge data...') 69 | 70 | # Define parameters for random forest classifier and regressor. 71 | n_estimators = 123 # Number of trees in the forest. 72 | max_leaf_nodes = 456 # Maximum number of leaf nodes in each tree. 73 | random_state = 789 # Random state; set for reproducibility. 74 | 75 | # Impute any missing features; use the mean value by default. 76 | imputer = SimpleImputer().fit(features) 77 | 78 | # Train the models. 79 | features = imputer.transform(features) 80 | outcome_model = RandomForestClassifier( 81 | n_estimators=n_estimators, max_leaf_nodes=max_leaf_nodes, random_state=random_state).fit(features, outcomes.ravel()) 82 | cpc_model = RandomForestRegressor( 83 | n_estimators=n_estimators, max_leaf_nodes=max_leaf_nodes, random_state=random_state).fit(features, cpcs.ravel()) 84 | 85 | # Save the models. 86 | save_challenge_model(model_folder, imputer, outcome_model, cpc_model) 87 | 88 | if verbose >= 1: 89 | print('Done.') 90 | 91 | # Load your trained models. This function is *required*. You should edit this function to add your code, but do *not* change the 92 | # arguments of this function. 93 | def load_challenge_models(model_folder, verbose): 94 | filename = os.path.join(model_folder, 'models.sav') 95 | return joblib.load(filename) 96 | 97 | # Run your trained models. This function is *required*. You should edit this function to add your code, but do *not* change the 98 | # arguments of this function. 99 | def run_challenge_models(models, data_folder, patient_id, verbose): 100 | imputer = models['imputer'] 101 | outcome_model = models['outcome_model'] 102 | cpc_model = models['cpc_model'] 103 | 104 | # Extract features. 105 | features = get_features(data_folder, patient_id) 106 | features = features.reshape(1, -1) 107 | 108 | # Impute missing data. 109 | features = imputer.transform(features) 110 | 111 | # Apply models to features. 112 | outcome = outcome_model.predict(features)[0] 113 | outcome_probability = outcome_model.predict_proba(features)[0, 1] 114 | cpc = cpc_model.predict(features)[0] 115 | 116 | # Ensure that the CPC score is between (or equal to) 1 and 5. 117 | cpc = np.clip(cpc, 1, 5) 118 | 119 | return outcome, outcome_probability, cpc 120 | 121 | ################################################################################ 122 | # 123 | # Optional functions. You can change or remove these functions and/or add new functions. 124 | # 125 | ################################################################################ 126 | 127 | # Save your trained model. 128 | def save_challenge_model(model_folder, imputer, outcome_model, cpc_model): 129 | d = {'imputer': imputer, 'outcome_model': outcome_model, 'cpc_model': cpc_model} 130 | filename = os.path.join(model_folder, 'models.sav') 131 | joblib.dump(d, filename, protocol=0) 132 | 133 | # Preprocess data. 134 | def preprocess_data(data, sampling_frequency, utility_frequency): 135 | # Define the bandpass frequencies. 136 | passband = [0.1, 30.0] 137 | 138 | # Promote the data to double precision because these libraries expect double precision. 139 | data = np.asarray(data, dtype=np.float64) 140 | 141 | # If the utility frequency is between bandpass frequencies, then apply a notch filter. 142 | if utility_frequency is not None and passband[0] <= utility_frequency <= passband[1]: 143 | data = mne.filter.notch_filter(data, sampling_frequency, utility_frequency, n_jobs=4, verbose='error') 144 | 145 | # Apply a bandpass filter. 
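# (mne.filter.filter_data defaults to a zero-phase FIR filter; here it applies the 0.1-30 Hz passband defined above.)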
146 | data = mne.filter.filter_data(data, sampling_frequency, passband[0], passband[1], n_jobs=4, verbose='error') 147 | 148 | # Resample the data. 149 | if sampling_frequency % 2 == 0: 150 | resampling_frequency = 128 151 | else: 152 | resampling_frequency = 125 153 | lcm = np.lcm(int(round(sampling_frequency)), int(round(resampling_frequency))) 154 | up = int(round(lcm / sampling_frequency)) 155 | down = int(round(lcm / resampling_frequency)) 156 | resampling_frequency = sampling_frequency * up / down 157 | data = scipy.signal.resample_poly(data, up, down, axis=1) 158 | 159 | # Scale the data to the interval [-1, 1]. 160 | min_value = np.min(data) 161 | max_value = np.max(data) 162 | if min_value != max_value: 163 | data = 2.0 / (max_value - min_value) * (data - 0.5 * (min_value + max_value)) 164 | else: 165 | data = 0 * data 166 | 167 | return data, resampling_frequency 168 | 169 | # Extract features. 170 | def get_features(data_folder, patient_id): 171 | # Load patient data. 172 | patient_metadata = load_challenge_data(data_folder, patient_id) 173 | recording_ids = find_recording_files(data_folder, patient_id) 174 | num_recordings = len(recording_ids) 175 | 176 | # Extract patient features. 177 | patient_features = get_patient_features(patient_metadata) 178 | 179 | # Extract EEG features. 180 | eeg_channels = ['F3', 'P3', 'F4', 'P4'] 181 | group = 'EEG' 182 | 183 | if num_recordings > 0: 184 | recording_id = recording_ids[-1] 185 | recording_location = os.path.join(data_folder, patient_id, '{}_{}'.format(recording_id, group)) 186 | if os.path.exists(recording_location + '.hea'): 187 | data, channels, sampling_frequency = load_recording_data(recording_location) 188 | utility_frequency = get_utility_frequency(recording_location + '.hea') 189 | 190 | if all(channel in channels for channel in eeg_channels): 191 | data, channels = reduce_channels(data, channels, eeg_channels) 192 | data, sampling_frequency = preprocess_data(data, sampling_frequency, utility_frequency) 193 | data = np.array([data[0, :] - data[1, :], data[2, :] - data[3, :]]) # Convert to bipolar montage: F3-P3 and F4-P4 194 | eeg_features = get_eeg_features(data, sampling_frequency).flatten() 195 | else: 196 | eeg_features = float('nan') * np.ones(8) # 2 bipolar channels * 4 features / channel 197 | else: 198 | eeg_features = float('nan') * np.ones(8) # 2 bipolar channels * 4 features / channel 199 | else: 200 | eeg_features = float('nan') * np.ones(8) # 2 bipolar channels * 4 features / channel 201 | 202 | # Extract ECG features. 203 | ecg_channels = ['ECG', 'ECGL', 'ECGR', 'ECG1', 'ECG2'] 204 | group = 'ECG' 205 | 206 | if num_recordings > 0: 207 | recording_id = recording_ids[0] 208 | recording_location = os.path.join(data_folder, patient_id, '{}_{}'.format(recording_id, group)) 209 | if os.path.exists(recording_location + '.hea'): 210 | data, channels, sampling_frequency = load_recording_data(recording_location) 211 | utility_frequency = get_utility_frequency(recording_location + '.hea') 212 | 213 | data, channels = reduce_channels(data, channels, ecg_channels) 214 | data, sampling_frequency = preprocess_data(data, sampling_frequency, utility_frequency) 215 | features = get_ecg_features(data) 216 | ecg_features = expand_channels(features, channels, ecg_channels).flatten() 217 | else: 218 | ecg_features = float('nan') * np.ones(10) # 5 channels * 2 features / channel 219 | else: 220 | ecg_features = float('nan') * np.ones(10) # 5 channels * 2 features / channel 221 | 222 | # Combine the patient, EEG, and ECG features. 
223 | return np.hstack((patient_features, eeg_features, ecg_features)) 224 | 225 | # Extract patient features from the data. 226 | def get_patient_features(data): 227 | age = get_age(data) 228 | sex = get_sex(data) 229 | rosc = get_rosc(data) 230 | ohca = get_ohca(data) 231 | shockable_rhythm = get_shockable_rhythm(data) 232 | ttm = get_ttm(data) 233 | 234 | # One-hot encode the sex variable. 235 | if sex == 'Female': 236 | female = 1 237 | male   = 0 238 | other  = 0 239 | elif sex == 'Male': 240 | female = 0 241 | male   = 1 242 | other  = 0 243 | else: 244 | female = 0 245 | male   = 0 246 | other  = 1 247 | 248 | features = np.array((age, female, male, other, rosc, ohca, shockable_rhythm, ttm)) 249 | 250 | return features 251 | 252 | # Extract features from the EEG data. 253 | def get_eeg_features(data, sampling_frequency): 254 | num_channels, num_samples = np.shape(data) 255 | 256 | if num_samples > 0: 257 | delta_psd, _ = mne.time_frequency.psd_array_welch(data, sfreq=sampling_frequency, fmin=0.5, fmax=8.0, verbose=False) 258 | theta_psd, _ = mne.time_frequency.psd_array_welch(data, sfreq=sampling_frequency, fmin=4.0, fmax=8.0, verbose=False) 259 | alpha_psd, _ = mne.time_frequency.psd_array_welch(data, sfreq=sampling_frequency, fmin=8.0, fmax=12.0, verbose=False) 260 | beta_psd, _ = mne.time_frequency.psd_array_welch(data, sfreq=sampling_frequency, fmin=12.0, fmax=30.0, verbose=False) 261 | 262 | delta_psd_mean = np.nanmean(delta_psd, axis=1) 263 | theta_psd_mean = np.nanmean(theta_psd, axis=1) 264 | alpha_psd_mean = np.nanmean(alpha_psd, axis=1) 265 | beta_psd_mean = np.nanmean(beta_psd, axis=1) 266 | else: 267 | delta_psd_mean = theta_psd_mean = alpha_psd_mean = beta_psd_mean = float('nan') * np.ones(num_channels) 268 | 269 | features = np.array((delta_psd_mean, theta_psd_mean, alpha_psd_mean, beta_psd_mean)).T 270 | 271 | return features 272 | 273 | # Extract features from the ECG data. 274 | def get_ecg_features(data): 275 | num_channels, num_samples = np.shape(data) 276 | 277 | if num_samples > 1: 278 | mean = np.mean(data, axis=1) 279 | std = np.std(data, axis=1) 280 | elif num_samples == 1: 281 | mean = np.mean(data, axis=1) 282 | std = float('nan') * np.ones(num_channels) # The standard deviation is undefined for a single sample. 283 | else: 284 | mean = float('nan') * np.ones(num_channels) 285 | std = float('nan') * np.ones(num_channels) 286 | 287 | features = np.array((mean, std)).T 288 | 289 | return features 290 | -------------------------------------------------------------------------------- /helper_code.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Do *not* edit this script. 4 | # These are helper functions that you can use with your code. 5 | # Check the example code to see how to import these functions into your code. 6 | 7 | import os, numpy as np, scipy as sp, scipy.io 8 | 9 | ### Challenge data I/O functions 10 | 11 | # Find the folders with data files. 12 | def find_data_folders(root_folder): 13 | data_folders = list() 14 | for x in sorted(os.listdir(root_folder)): 15 | data_folder = os.path.join(root_folder, x) 16 | if os.path.isdir(data_folder): 17 | data_file = os.path.join(data_folder, x + '.txt') 18 | if os.path.isfile(data_file): 19 | data_folders.append(x) 20 | return sorted(data_folders) 21 | 22 | # Load the patient metadata: age, sex, etc. 
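# Based on the accessor functions below, a patient metadata file stores one 'Name: value' pair per line; the available
# variables include 'Patient', 'Hospital', 'Age', 'Sex', 'ROSC', 'OHCA', 'Shockable Rhythm', and 'TTM', plus the
# 'Outcome' and 'CPC' labels in the training data.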
23 | def load_challenge_data(data_folder, patient_id): 24 | patient_metadata_file = os.path.join(data_folder, patient_id, patient_id + '.txt') 25 | patient_metadata = load_text_file(patient_metadata_file) 26 | return patient_metadata 27 | 28 | # Find the record names. 29 | def find_recording_files(data_folder, patient_id): 30 | record_names = set() 31 | patient_folder = os.path.join(data_folder, patient_id) 32 | for file_name in sorted(os.listdir(patient_folder)): 33 | if not file_name.startswith('.') and file_name.endswith('.hea'): 34 | root, ext = os.path.splitext(file_name) 35 | record_name = '_'.join(root.split('_')[:-1]) 36 | record_names.add(record_name) 37 | return sorted(record_names) 38 | 39 | # Load the WFDB data for the Challenge (but not all possible WFDB files). 40 | def load_recording_data(record_name, check_values=False): 41 | # Allow either the record name or the header filename. 42 | root, ext = os.path.splitext(record_name) 43 | if ext=='': 44 | header_file = record_name + '.hea' 45 | else: 46 | header_file = record_name 47 | 48 | # Load the header file. 49 | if not os.path.isfile(header_file): 50 | raise FileNotFoundError('{} recording not found.'.format(record_name)) 51 | 52 | with open(header_file, 'r') as f: 53 | header = [l.strip() for l in f.readlines() if l.strip()] 54 | 55 | # Parse the header file. 56 | record_name = None 57 | num_signals = None 58 | sampling_frequency = None 59 | num_samples = None 60 | signal_files = list() 61 | gains = list() 62 | baselines = list() 63 | adc_zeros = list() 64 | channels = list() 65 | initial_values = list() 66 | checksums = list() 67 | 68 | for i, l in enumerate(header): 69 | arrs = [arr.strip() for arr in l.split(' ')] 70 | # Parse the record line. 71 | if i==0: 72 | record_name = arrs[0] 73 | num_signals = int(arrs[1]) 74 | sampling_frequency = float(arrs[2]) 75 | num_samples = int(arrs[3]) 76 | # Parse the signal specification lines. 77 | elif not l.startswith('#') or len(l.strip()) == 0: 78 | signal_file = arrs[0] 79 | if '(' in arrs[2] and ')' in arrs[2]: 80 | gain = float(arrs[2].split('/')[0].split('(')[0]) 81 | baseline = float(arrs[2].split('/')[0].split('(')[1].split(')')[0]) 82 | else: 83 | gain = float(arrs[2].split('/')[0]) 84 | baseline = 0.0 85 | adc_zero = int(arrs[4]) 86 | initial_value = int(arrs[5]) 87 | checksum = int(arrs[6]) 88 | channel = arrs[8] 89 | signal_files.append(signal_file) 90 | gains.append(gain) 91 | baselines.append(baseline) 92 | adc_zeros.append(adc_zero) 93 | initial_values.append(initial_value) 94 | checksums.append(checksum) 95 | channels.append(channel) 96 | 97 | # Check that the header file only references one signal file. WFDB format allows for multiple signal files, but, for 98 | # simplicity, we have not done that here. 99 | num_signal_files = len(set(signal_files)) 100 | if num_signal_files!=1: 101 | raise NotImplementedError('The header file {}'.format(header_file) \ 102 | + ' references {} signal files; one signal file expected.'.format(num_signal_files)) 103 | 104 | # Load the signal file. 105 | head, tail = os.path.split(header_file) 106 | signal_file = os.path.join(head, list(signal_files)[0]) 107 | data = np.asarray(sp.io.loadmat(signal_file)['val']) 108 | 109 | # Check that the dimensions of the signal data in the signal file is consistent with the dimensions for the signal data given 110 | # in the header file. 
111 | num_channels = len(channels) 112 | if np.shape(data)!=(num_channels, num_samples): 113 | raise ValueError('The header file {}'.format(header_file) \ 114 | + ' is inconsistent with the dimensions of the signal file.') 115 | 116 | # Check that the initial value and checksums in the signal file are consistent with the initial value and checksums in the 117 | # header file. 118 | if check_values: 119 | for i in range(num_channels): 120 | if data[i, 0]!=initial_values[i]: 121 | raise ValueError('The initial value in header file {}'.format(header_file) \ 122 | + ' is inconsistent with the initial value for channel {} in the signal data'.format(channels[i])) 123 | if np.sum(data[i, :], dtype=np.int16)!=checksums[i]: 124 | raise ValueError('The checksum in header file {}'.format(header_file) \ 125 | + ' is inconsistent with the checksum value for channel {} in the signal data'.format(channels[i])) 126 | 127 | # Rescale the signal data using the gains and offsets. 128 | rescaled_data = np.zeros(np.shape(data), dtype=np.float32) 129 | for i in range(num_channels): 130 | rescaled_data[i, :] = (np.asarray(data[i, :], dtype=np.float64) - baselines[i] - adc_zeros[i]) / gains[i] 131 | 132 | return rescaled_data, channels, sampling_frequency 133 | 134 | # Reduce the data to the requested channels. 135 | def reduce_channels(current_data, current_channels, requested_channels): 136 | if current_channels == requested_channels: 137 | reduced_data = current_data 138 | reduced_channels = current_channels 139 | else: 140 | reduced_indices = [current_channels.index(channel) for channel in requested_channels if channel in current_channels] 141 | reduced_channels = [current_channels[i] for i in reduced_indices] 142 | reduced_data = current_data[reduced_indices, :] 143 | return reduced_data, reduced_channels 144 | 145 | # Expand the data to the requested channels, inserting NaNs for missing channels. 146 | def expand_channels(current_data, current_channels, requested_channels): 147 | if current_channels == requested_channels: 148 | expanded_data = current_data 149 | else: 150 | num_current_channels, num_samples = np.shape(current_data) 151 | num_requested_channels = len(requested_channels) 152 | expanded_data = np.zeros((num_requested_channels, num_samples)) 153 | for i, channel in enumerate(requested_channels): 154 | if channel in current_channels: 155 | j = current_channels.index(channel) 156 | expanded_data[i, :] = current_data[j, :] 157 | else: 158 | expanded_data[i, :] = float('nan') 159 | return expanded_data 160 | 161 | ### Helper Challenge data I/O functions 162 | 163 | # Load text file as a string. 164 | def load_text_file(filename): 165 | with open(filename, 'r') as f: 166 | data = f.read() 167 | return data 168 | 169 | # Get a variable from the patient metadata. 170 | def get_variable(text, variable_name, variable_type): 171 | variable = None 172 | for l in text.split('\n'): 173 | if l.startswith(variable_name): 174 | variable = ':'.join(l.split(':')[1:]).strip() 175 | variable = cast_variable(variable, variable_type) 176 | return variable 177 | 178 | # Get the patient ID variable from the patient data. 179 | def get_patient_id(string): 180 | return get_variable(string, 'Patient', str) 181 | 182 | # Get the hospital variable from the patient data. 183 | def get_hospital(string): 184 | return get_variable(string, 'Hospital', str) 185 | 186 | # Get the age variable (in years) from the patient data. 187 | def get_age(string): 188 | return get_variable(string, 'Age', int) 189 | 190 | # Get the sex variable from the patient data. 
191 | def get_sex(string): 192 | return get_variable(string, 'Sex', str) 193 | 194 | # Get the ROSC variable (in minutes) from the patient data. 195 | def get_rosc(string): 196 | return get_variable(string, 'ROSC', int) 197 | 198 | # Get the OHCA variable from the patient data. 199 | def get_ohca(string): 200 | return get_variable(string, 'OHCA', bool) 201 | 202 | # Get the shockable rhythm variable from the patient data. 203 | def get_shockable_rhythm(string): 204 | return get_variable(string, 'Shockable Rhythm', bool) 205 | 206 | # Get the TTM variable (in Celsius) from the patient data. 207 | def get_ttm(string): 208 | return get_variable(string, 'TTM', int) 209 | 210 | # Get the Outcome variable from the patient data. 211 | def get_outcome(string): 212 | variable = get_variable(string, 'Outcome', str) 213 | if variable is None or is_nan(variable): 214 | raise ValueError('No outcome available. Is your code trying to load labels from the hidden data?') 215 | if variable == 'Good': 216 | variable = 0 217 | elif variable == 'Poor': 218 | variable = 1 219 | return variable 220 | 221 | # Get the Outcome probability variable from the patient data. 222 | def get_outcome_probability(string): 223 | variable = sanitize_scalar_value(get_variable(string, 'Outcome Probability', str)) 224 | if variable is None or is_nan(variable): 225 | raise ValueError('No outcome available. Is your code trying to load labels from the hidden data?') 226 | return variable 227 | 228 | # Get the CPC variable from the patient data. 229 | def get_cpc(string): 230 | variable = sanitize_scalar_value(get_variable(string, 'CPC', str)) 231 | if variable is None or is_nan(variable): 232 | raise ValueError('No CPC score available. Is your code trying to load labels from the hidden data?') 233 | return variable 234 | 235 | # Get the utility frequency (in Hertz) from the recording data. 236 | def get_utility_frequency(string): 237 | return get_variable(string, '#Utility frequency', int) 238 | 239 | # Get the start time (in hh:mm:ss format) from the recording data. 240 | def get_start_time(string): 241 | variable = get_variable(string, '#Start time', str) 242 | times = tuple(int(value) for value in variable.split(':')) 243 | return times 244 | 245 | # Get the end time (in hh:mm:ss format) from the recording data. 246 | def get_end_time(string): 247 | variable = get_variable(string, '#End time', str) 248 | times = tuple(int(value) for value in variable.split(':')) 249 | return times 250 | 251 | # Convert seconds to hours, minutes, and seconds. 252 | def convert_seconds_to_hours_minutes_seconds(seconds): 253 | hours = int(seconds/3600) 254 | minutes = int(seconds/60 - 60*hours) 255 | seconds = int(seconds - 3600*hours - 60*minutes) 256 | return hours, minutes, seconds 257 | 258 | # Convert hours, minutes, and seconds to seconds. 259 | def convert_hours_minutes_seconds_to_seconds(hours, minutes, seconds): 260 | return 3600*hours + 60*minutes + seconds 261 | 262 | ### Challenge label and output I/O functions 263 | 264 | # Save the Challenge outputs for one file. 265 | def save_challenge_outputs(filename, patient_id, outcome, outcome_probability, cpc): 266 | # Sanitize values, e.g., in case they are a singleton array. 267 | outcome = sanitize_boolean_value(outcome) 268 | outcome_probability = sanitize_scalar_value(outcome_probability) 269 | cpc = sanitize_scalar_value(cpc) 270 | 271 | # Format Challenge outputs. 
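# For reference, the output file written below has the form:
#   Patient: <patient_id>
#   Outcome: Good or Poor
#   Outcome Probability: <probability between 0 and 1>
#   CPC: <score between 1 and 5>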
272 | patient_string = 'Patient: {}'.format(patient_id) 273 | if outcome == 0: 274 | outcome = 'Good' 275 | elif outcome == 1: 276 | outcome = 'Poor' 277 | outcome_string = 'Outcome: {}'.format(outcome) 278 | outcome_probability_string = 'Outcome Probability: {:.3f}'.format(outcome_probability) 279 | cpc_string = 'CPC: {:.3f}'.format(cast_int_if_int_else_float(cpc)) 280 | output_string = patient_string + '\n' + \ 281 | outcome_string + '\n' + outcome_probability_string + '\n' + cpc_string + '\n' 282 | 283 | # Write the Challenge outputs. 284 | if filename is not None: 285 | with open(filename, 'w') as f: 286 | f.write(output_string) 287 | 288 | return output_string 289 | 290 | ### Other helper functions 291 | 292 | # Check if a variable is a number or represents a number. 293 | def is_number(x): 294 | try: 295 | float(x) 296 | return True 297 | except (ValueError, TypeError): 298 | return False 299 | 300 | # Check if a variable is an integer or represents an integer. 301 | def is_integer(x): 302 | if is_number(x): 303 | return float(x).is_integer() 304 | else: 305 | return False 306 | 307 | # Check if a variable is a boolean or represents a boolean. 308 | def is_boolean(x): 309 | if (is_number(x) and float(x)==0) or (remove_extra_characters(x) in ('False', 'false', 'FALSE', 'F', 'f')): 310 | return True 311 | elif (is_number(x) and float(x)==1) or (remove_extra_characters(x) in ('True', 'true', 'TRUE', 'T', 't')): 312 | return True 313 | else: 314 | return False 315 | 316 | # Check if a variable is a finite number or represents a finite number. 317 | def is_finite_number(x): 318 | if is_number(x): 319 | return np.isfinite(float(x)) 320 | else: 321 | return False 322 | 323 | # Check if a variable is a NaN (not a number) or represents a NaN. 324 | def is_nan(x): 325 | if is_number(x): 326 | return np.isnan(float(x)) 327 | else: 328 | return False 329 | 330 | # Remove any quotes, brackets (for singleton arrays), and/or invisible characters. 331 | def remove_extra_characters(x): 332 | return str(x).replace('"', '').replace("'", "").replace('[', '').replace(']', '').replace(' ', '').strip() 333 | 334 | # Sanitize boolean values. 335 | def sanitize_boolean_value(x): 336 | x = remove_extra_characters(x) 337 | if (is_number(x) and float(x)==0) or (remove_extra_characters(x) in ('False', 'false', 'FALSE', 'F', 'f')): 338 | return 0 339 | elif (is_number(x) and float(x)==1) or (remove_extra_characters(x) in ('True', 'true', 'TRUE', 'T', 't')): 340 | return 1 341 | else: 342 | return float('nan') 343 | 344 | # Sanitize integer values. 345 | def sanitize_integer_value(x): 346 | x = remove_extra_characters(x) 347 | if is_integer(x): 348 | return int(float(x)) 349 | else: 350 | return float('nan') 351 | 352 | # Sanitize scalar values. 353 | def sanitize_scalar_value(x): 354 | x = remove_extra_characters(x) 355 | if is_number(x): 356 | return float(x) 357 | else: 358 | return float('nan') 359 | 360 | # Cast a value to a particular type. 
361 | def cast_variable(variable, variable_type, preserve_nan=True): 362 | if preserve_nan and is_nan(variable): 363 | variable = float('nan') 364 | else: 365 | if variable_type == bool: 366 | variable = sanitize_boolean_value(variable) 367 | elif variable_type == int: 368 | variable = sanitize_integer_value(variable) 369 | elif variable_type == float: 370 | variable = sanitize_scalar_value(variable) 371 | else: 372 | variable = variable_type(variable) 373 | return variable 374 | 375 | # Cast a value to an integer if the value is an integer, a float if the value is a non-integer float, and itself otherwise. 376 | def cast_int_if_int_else_float(x): 377 | if is_integer(x): 378 | return int(float(x)) 379 | elif is_number(x): 380 | return float(x) 381 | else: 382 | return x 383 | --------------------------------------------------------------------------------