├── src ├── data_preprocessing │ ├── features_preprocessing │ │ ├── __init__.py │ │ ├── .DS_Store │ │ ├── stepIV_GP_prep_part_II.py │ │ ├── archive │ │ │ ├── GP_prep_part_II.py │ │ │ └── GP_prep.py │ │ ├── stepII_split_sets_n_normalise.py │ │ └── stepIII_GP_prep.py │ ├── extract_MIMIC_data │ │ ├── extract_labels │ │ │ ├── __init__.py │ │ │ ├── .DS_Store │ │ │ ├── SQL-SI │ │ │ │ ├── .DS_Store │ │ │ │ ├── SI.sql │ │ │ │ └── abx_micro_poe.sql │ │ │ ├── SQL-SOFA │ │ │ │ ├── .DS_Store │ │ │ │ ├── coagulation │ │ │ │ │ ├── .DS_Store │ │ │ │ │ └── labsperhour.sql │ │ │ │ ├── renal │ │ │ │ │ ├── runninguo24h.sql │ │ │ │ │ ├── labsperhour.sql │ │ │ │ │ └── uoperhour.sql │ │ │ │ ├── SOFA_flag.sql │ │ │ │ ├── cardiovascular │ │ │ │ │ ├── echo.sql │ │ │ │ │ ├── vitalsperhour.sql │ │ │ │ │ └── cardio_SOFA.sql │ │ │ │ ├── hourly_table.sql │ │ │ │ ├── respiration │ │ │ │ │ ├── resp_SOFA.sql │ │ │ │ │ ├── bloodgasfirstday.sql │ │ │ │ │ ├── ventdurations.sql │ │ │ │ │ ├── bloodgasfirstdayarterial.sql │ │ │ │ │ └── ventsettings.sql │ │ │ │ ├── liver │ │ │ │ │ └── labsperhour.sql │ │ │ │ ├── central_nervous_system │ │ │ │ │ └── gcsperhour.sql │ │ │ │ └── SOFA.sql │ │ │ └── sofa_delta.sql │ │ ├── .DS_Store │ │ └── extract_features │ │ │ ├── hourly-cohort.sql │ │ │ ├── static-query.sql │ │ │ ├── sepsis3_cohort_mr.sql │ │ │ ├── match_controls.py │ │ │ ├── make_data.py │ │ │ ├── extract-55h-of-hourly-case-vital-series_ex1c.sql │ │ │ └── extract-55h-of-hourly-control-vital-series_ex1c.sql │ ├── .DS_Store │ └── main.py ├── trainers │ ├── .DS_Store │ ├── losses.py │ └── trainer.py ├── utils │ └── debug.py ├── loss_n_eval │ ├── losses.py │ └── aucs.py ├── models │ ├── GP_utils.py │ ├── attTCN_alpha.py │ ├── TCN.py │ ├── GP_logreg.py │ ├── attTCN_beta.py │ ├── GP_attTCN.py │ ├── attTCN.py │ └── GP_attTCN_ablations.py ├── mains │ ├── ablation_0.py │ ├── ablation_alpha.py │ ├── main.py │ └── ablation_beta.py └── data_loader │ └── utils.py ├── requirements.txt ├── LICENSE ├── .gitignore └── README.md /src/data_preprocessing/features_preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/trainers/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mmr12/MGP-AttTCN/HEAD/src/trainers/.DS_Store -------------------------------------------------------------------------------- /src/data_preprocessing/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mmr12/MGP-AttTCN/HEAD/src/data_preprocessing/.DS_Store -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mmr12/MGP-AttTCN/HEAD/src/data_preprocessing/extract_MIMIC_data/.DS_Store -------------------------------------------------------------------------------- /src/data_preprocessing/features_preprocessing/.DS_Store: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mmr12/MGP-AttTCN/HEAD/src/data_preprocessing/features_preprocessing/.DS_Store -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas==1.0.1 2 | numpy==1.18.1 3 | SQLAlchemy==1.3.13 4 | scipy==1.4.1 5 | psycopg2==2.8.6 6 | sacred==0.8.2 7 | scikit_learn==0.24.1 8 | tensorflow==2.4.1 9 | -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mmr12/MGP-AttTCN/HEAD/src/data_preprocessing/extract_MIMIC_data/extract_labels/.DS_Store -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SI/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mmr12/MGP-AttTCN/HEAD/src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SI/.DS_Store -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SOFA/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mmr12/MGP-AttTCN/HEAD/src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SOFA/.DS_Store -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SOFA/coagulation/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mmr12/MGP-AttTCN/HEAD/src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SOFA/coagulation/.DS_Store -------------------------------------------------------------------------------- /src/utils/debug.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import sys 3 | import time 4 | 5 | 6 | def print_time(): 7 | print(datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H:%M:%S')) 8 | sys.stdout.flush() 9 | 10 | 11 | def flush_print(string): 12 | print(string) 13 | sys.stdout.flush() 14 | 15 | 16 | def t_print(string): 17 | T = datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H:%M:%S') 18 | print(T, " -- ", string) 19 | sys.stdout.flush() 20 | -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SOFA/renal/runninguo24h.sql: -------------------------------------------------------------------------------- 1 | DROP MATERIALIZED VIEW IF EXISTS SOFA_runninguo24h CASCADE; 2 | create materialized view SOFA_runninguo24h as 3 | SELECT 4 | uo_1.hadm_id, uo_1.HLOS 5 | 6 | , SUM(uo_4.UrineOutput) as running_uo_24h 7 | FROM SOFA_uoperhour uo_1 8 | JOIN SOFA_uoperhour uo_4 ON 9 | uo_1.hadm_id = uo_4.hadm_id and 10 | uo_4.HLOS between uo_1.HLOS -24 and uo_1.HLOS 11 | 12 | where uo_4.ICULOS >= 0 and uo_1.ICULOS >= 24 13 | 14 | group by uo_1.hadm_id, uo_1.HLOS 15 | order by uo_1.hadm_id, uo_1.HLOS 16 | -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SI/SI.sql: 
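The label-extraction step that follows is plain PostgreSQL (SI.sql below derives the suspected-infection window from the abx_micro_poe table). As a minimal sketch of how one of these scripts could be run by hand against a local MIMIC-III instance with the SQLAlchemy/psycopg2 versions pinned above — the connection string, schema name and credentials are illustrative assumptions, not values taken from this repository:

# Hedged sketch: execute a label-extraction script against a local MIMIC-III Postgres database.
# DSN, schema and credentials are placeholders; SI.sql assumes abx_micro_poe.sql has already been run.
from pathlib import Path
from sqlalchemy import create_engine, text
from src.utils.debug import t_print  # timestamped print helper shown above

engine = create_engine("postgresql+psycopg2://mimicuser:secret@localhost:5432/mimic")
sql_file = Path("src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SI/SI.sql")
with engine.begin() as conn:
    conn.execute(text("SET search_path TO mimiciii"))  # assumed schema holding the MIMIC-III tables
    conn.execute(text(sql_file.read_text()))           # creates the SI_flag table
t_print("finished " + sql_file.name)

In the repository itself this is presumably orchestrated by src/data_preprocessing/main.py (the README's STEP II); the snippet only illustrates the stand-alone pattern.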
-------------------------------------------------------------------------------- 1 | DROP TABLE IF EXISTS SI_flag CASCADE; 2 | CREATE TABLE SI_flag as 3 | with abx as 4 | ( 5 | select hadm_id 6 | , suspected_infection_time 7 | , ROW_NUMBER() OVER 8 | ( 9 | PARTITION BY hadm_id 10 | ORDER BY suspected_infection_time 11 | ) as rn 12 | from abx_micro_poe 13 | ) 14 | select 15 | hadm_id 16 | , suspected_infection_time 17 | , suspected_infection_time - interval '48' hour as si_start 18 | , suspected_infection_time + interval '24' hour as si_end 19 | from abx 20 | where abx.rn = 1 21 | order by hadm_id -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SOFA/SOFA_flag.sql: -------------------------------------------------------------------------------- 1 | DROP MATERIALIZED VIEW IF EXISTS SOFAflag CASCADE; 2 | CREATE materialized VIEW SOFAflag AS 3 | -- create flag 4 | with SOFA_flag as ( 5 | select t0.hadm_id 6 | , t0.hlos 7 | , case 8 | when tn2.SOFA - t0.SOFA >=2 then 1 9 | else 0 end as SOFAflag 10 | from SOFAperhour tn2 11 | join SOFAperhour t0 12 | on tn2.hadm_id = t0.hadm_id 13 | and tn2.hlos = t0.hlos - 1 14 | where t0.hlos >= 0 15 | ) 16 | -- adding admissions info, calculating hour of onset 17 | select S.* 18 | , ha.subject_id 19 | , ha.admittime 20 | , ha.admittime + S.hlos * interval '1 hour' as SOFAtime 21 | from SOFA_flag S 22 | left join admissions ha 23 | on S.hadm_id = ha.hadm_id -------------------------------------------------------------------------------- /src/trainers/losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | def GP_loss(model, inputs, labels, weighted_loss=None): 5 | out = model(inputs) 6 | temp_labels = tf.reshape(labels, (-1, 1)) 7 | y_star = tf.concat([temp_labels, tf.ones_like(temp_labels) - temp_labels], axis=1, name="y_star") 8 | if weighted_loss is not None: 9 | float_y_star = tf.cast(y_star, dtype=tf.float32) 10 | weights = float_y_star[:, 0] / weighted_loss + float_y_star[:, 1] 11 | else: 12 | weights = tf.convert_to_tensor([1], dtype=tf.float32) 13 | return tf.compat.v1.losses.softmax_cross_entropy(y_star, out, weights=weights) 14 | 15 | 16 | def grad(model, inputs, targets, ratio_weights=None, multi_class=True, GP=False, weighted_loss=None): 17 | if GP == True: 18 | with tf.GradientTape() as tape: 19 | loss_value = GP_loss(model, inputs, targets, weighted_loss) 20 | else: 21 | with tf.GradientTape() as tape: 22 | loss_value = loss(model, inputs, targets, ratio_weights, multi_class=multi_class) 23 | return loss_value, tape.gradient(loss_value, model.trainable_variables) 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Margherita 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of
the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/loss_n_eval/losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | def GP_loss(model, inputs, labels, weighted_loss=None): 5 | out = model(inputs) 6 | temp_labels = tf.reshape(labels, (-1, 1)) 7 | y_star = tf.concat([temp_labels, tf.ones_like(temp_labels) - temp_labels], axis=1, name="y_star") 8 | if weighted_loss is not None: 9 | float_y_star = tf.cast(y_star, dtype=tf.float32) 10 | weights = float_y_star[:, 0] / weighted_loss + float_y_star[:, 1] 11 | else: 12 | weights = tf.convert_to_tensor([1], dtype=tf.float32) 13 | return tf.compat.v1.losses.softmax_cross_entropy(y_star, out, weights=weights) 14 | 15 | 16 | def grad(model, inputs, targets, ratio_weights=None, multi_class=True, GP=False, weighted_loss=None): 17 | if GP == True: 18 | with tf.GradientTape() as tape: 19 | loss_value = GP_loss(model, inputs, targets, weighted_loss) 20 | else: 21 | with tf.GradientTape() as tape: 22 | loss_value = loss(model, inputs, targets, ratio_weights, multi_class=multi_class) 23 | return loss_value, tape.gradient(loss_value, model.trainable_variables) 24 | -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SOFA/cardiovascular/echo.sql: -------------------------------------------------------------------------------- 1 | -- This code extracts structured data from echocardiographies 2 | -- You can join it to the text notes using ROW_ID 3 | -- Just note that ROW_ID will differ across versions of MIMIC-III. 4 | 5 | DROP MATERIALIZED VIEW IF EXISTS ECHODATA CASCADE; 6 | CREATE MATERIALIZED VIEW ECHODATA AS 7 | select ROW_ID 8 | , subject_id, hadm_id 9 | , chartdate 10 | 11 | -- charttime is always null for echoes.. 
12 | -- however, the time is available in the echo text, e.g.: 13 | -- , substring(ne.text, 'Date/Time: [\[\]0-9*-]+ at ([0-9:]+)') as TIMESTAMP 14 | -- we can therefore impute it and re-create charttime 15 | , cast(to_timestamp( (to_char( chartdate, 'DD-MM-YYYY' ) || substring(ne.text, 'Date/Time: [\[\]0-9*-]+ at ([0-9:]+)')), 16 | 'DD-MM-YYYYHH24:MI') as timestamp without time zone) 17 | as charttime 18 | 19 | , case 20 | when substring(ne.text, 'Weight \(lb\): (.*?)\n') like '%*%' 21 | then null 22 | else cast(substring(ne.text, 'Weight \(lb\): (.*?)\n') as numeric) 23 | end as Weight 24 | 25 | from noteevents ne 26 | where category = 'Echo' 27 | -------------------------------------------------------------------------------- /src/models/GP_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | import sys 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | cwd = os.path.dirname(os.path.abspath(__file__)) 9 | head = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir)) 10 | sys.path.append(head) 11 | 12 | 13 | # reformatted from M Moor and J Futoma 14 | 15 | def kroneker_matrix(M, idx1, idx2=None): #y 16 | if idx2 is None: 17 | idx2 = idx1 18 | grid = tf.meshgrid(idx1, idx2) 19 | idx = tf.stack((grid[0], grid[1]), -1) 20 | return tf.gather_nd(M, idx) 21 | 22 | 23 | def OU_kernel(length, x1, x2): 24 | x1 = tf.reshape(x1, [-1, 1]) # colvec 25 | x2 = tf.reshape(x2, [1, -1]) # rowvec 26 | K = tf.exp(-tf.abs(x1 - x2) / length) 27 | return K 28 | 29 | 30 | def K_vitals_initialiser(shape, partition_info=None, dtype=None): 31 | # initialise lengths to be 0.01 for vitals and 5 for blood tests 32 | output = np.ones(shape[0]) 33 | output[7:] = 0.01 34 | return tf.linalg.diag(tf.convert_to_tensor(output, dtype=tf.float32)) 35 | 36 | 37 | def K_labs_initialiser(shape, partition_info=None, dtype=None): 38 | # initialise lengths to be 0.01 for vitals and 5 for blood tests 39 | output = np.ones(shape[0]) 40 | output[:7] = 0.01 41 | return tf.linalg.diag(tf.convert_to_tensor(output, dtype=tf.float32)) 42 | -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/sofa_delta.sql: -------------------------------------------------------------------------------- 1 | DROP TABLE IF EXISTS sofa_delta CASCADE; 2 | CREATE TABLE sofa_delta AS 3 | with case_cohort as ( 4 | select 5 | so.hadm_id 6 | , so.sofa 7 | , so.sofaresp 8 | , so.sofacoag 9 | , so.sofaliv 10 | , so.sofacardio 11 | , so.sofagcs 12 | , so.sofaren 13 | , so.sepsis_time as sepsis_onset 14 | , so.delta_score 15 | , so.sofaresp_delta 16 | , so.sofacoag_delta 17 | , so.sofaliv_delta 18 | , so.sofacardio_delta 19 | , so.sofagcs_delta 20 | , so.sofaren_delta 21 | , ie.icustay_id 22 | , ie.intime 23 | , ie.outtime 24 | , (date_part('year', age(sepsis_time, intime))*365 * 24 25 | + date_part('month', age(sepsis_time, intime))*365/12 * 24 26 | + date_part('day', age(sepsis_time, intime))* 24 27 | + date_part('hour', age(sepsis_time, intime)) 28 | + round(date_part('minute', age(sepsis_time, intime))/60)) as h_from_intime 29 | from sepsis_onset so 30 | left join icustays ie 31 | on so.hadm_id = ie.hadm_id 32 | where sepsis_time between intime and outtime 33 | 34 | ) 35 | select C.* 36 | , case when sepsis_onset is null then 0 else 1 end as septic 37 | , case when h_from_intime < 7 then 1 else 0 end as excluded 38 | , case when dg.icd9_code in 
('78552','99591','99592') then 1 else 0 end as ICD_positive 39 | from case_cohort C 40 | left join icustays ie 41 | on C.icustay_id = ie.icustay_id 42 | left join diagnoses_icd dg 43 | on C.hadm_id = dg.hadm_id -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SOFA/hourly_table.sql: -------------------------------------------------------------------------------- 1 | -- This query generates a row for every hour the patient is in the ICU. 2 | -- The hours are based on clock-hours (i.e. 02:00, 03:00). 3 | -- The hour clock starts 24 hours before the first heart rate measurement. 4 | -- Note that the time of the first heart rate measurement is ceilinged to the hour. 5 | 6 | -- this query extracts the cohort and every possible hour they were in the ICU 7 | -- this table can be to other tables on ICUSTAY_ID and (ENDTIME - 1 hour,ENDTIME] 8 | DROP MATERIALIZED VIEW IF EXISTS hadms_hours CASCADE; 9 | CREATE MATERIALIZED VIEW hadms_hours as 10 | -- get first/last measurement time 11 | with all_hours as 12 | ( 13 | select 14 | ha.hadm_id 15 | 16 | -- ceiling the intime to the nearest hour by adding 59 minutes then truncating 17 | , date_trunc('hour', ha.admittime + interval '59' minute) as endtime 18 | 19 | -- create integers for each charttime in hours from admission 20 | -- so 0 is admission time, 1 is one hour after admission, etc, up to ICU disch 21 | , generate_series 22 | ( 23 | -- allow up to 24 hours before ICU admission (to grab labs before admit) 24 | -24, 25 | ceil(extract(EPOCH from ha.dischtime - ha.admittime)/60.0/60.0)::INTEGER 26 | ) as hr 27 | 28 | from admissions ha 29 | ) 30 | SELECT 31 | ah.hadm_id 32 | , ah.hr 33 | -- add the hr series 34 | -- endtime now indexes the end time of every hour for each patient 35 | , ah.endtime + ah.hr*interval '1' hour as endtime 36 | from all_hours ah 37 | order by ah.hadm_id, ah.hr; -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SOFA/cardiovascular/vitalsperhour.sql: -------------------------------------------------------------------------------- 1 | -- This query pivots the vital signs for the first 24 hours of a patient's stay 2 | -- Vital signs include heart rate, blood pressure, respiration rate, and temperature 3 | 4 | DROP MATERIALIZED VIEW IF EXISTS cardio_vitalsperhour CASCADE; 5 | create materialized view cardio_vitalsperhour as 6 | SELECT pvt.subject_id, pvt.hadm_id, pvt.HLOS 7 | 8 | -- Easier names 9 | , min(case when VitalID = 4 then valuenum else null end) as MinBP 10 | 11 | FROM ( 12 | select ha.subject_id, ha.hadm_id 13 | , valuenum 14 | , case 15 | when itemid in (456,52,6702,443,220052,220181,225312) and valuenum > 0 and valuenum < 300 then 4 -- MeanBP 16 | else null end as VitalID 17 | -- convert F to C 18 | , (date_part('year', age(ce.charttime, ha.admittime))*365 * 24 19 | + date_part('month', age(ce.charttime, ha.admittime))*365/12 * 24 20 | + date_part('day', age(ce.charttime, ha.admittime))* 24 21 | + date_part('hour', age(ce.charttime, ha.admittime)) 22 | + round(date_part('minute', age(ce.charttime, ha.admittime))/60)) as HLOS 23 | from admissions ha 24 | left join chartevents ce 25 | on ha.subject_id = ce.subject_id and ha.hadm_id = ce.hadm_id 26 | AND ce.charttime BETWEEN (ha.admittime - interval '1' day) AND ha.dischtime 27 | -- exclude rows marked as error 28 | and ce.error IS DISTINCT FROM 1 29 | where ce.itemid in 30 | ( 31 | -- 
MEAN ARTERIAL PRESSURE 32 | 456, --"NBP Mean" 33 | 52, --"Arterial BP Mean" 34 | 6702, -- Arterial BP Mean #2 35 | 443, -- Manual BP Mean(calc) 36 | 220052, --"Arterial Blood Pressure mean" 37 | 220181, --"Non Invasive Blood Pressure mean" 38 | 225312 --"ART BP mean" 39 | ) 40 | ) pvt 41 | group by pvt.subject_id, pvt.hadm_id, pvt.HLOS 42 | order by pvt.subject_id, pvt.hadm_id, pvt.HLOS; -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SOFA/respiration/resp_SOFA.sql: -------------------------------------------------------------------------------- 1 | DROP MATERIALIZED VIEW IF EXISTS SOFA_PaO2FiO2 CASCADE; 2 | CREATE materialized VIEW SOFA_PaO2FiO2 AS 3 | 4 | -- adding hadm_id to ventduations 5 | with vd as ( 6 | select vd.* , ie.hadm_id 7 | from ventdurations vd 8 | left join icustays ie 9 | on vd.icustay_id = ie.icustay_id 10 | ) 11 | -- combining ventilation and respiration data 12 | , pafi1 as ( 13 | select bg.hadm_id 14 | , bg.charttime 15 | , PaO2FiO2 16 | , case when sum(vd.icustay_id) is not null 17 | then 1 18 | when sum(vd.icustay_id) =0 then -1 19 | else 0 end as IsVent 20 | 21 | from resp_bloodgasfirstdayarterial bg 22 | left join vd 23 | on bg.hadm_id = vd.hadm_id 24 | and bg.charttime >= vd.starttime 25 | and bg.charttime <= vd.endtime 26 | group by bg.hadm_id, bg.charttime, PaO2FiO2 27 | order by bg.hadm_id, bg.charttime 28 | ) 29 | 30 | -- because pafi has an interaction between vent/PaO2:FiO2, we need two columns for the score 31 | -- it can happen that the lowest unventilated PaO2/FiO2 is 68, but the lowest ventilated PaO2/FiO2 is 120 32 | -- in this case, the SOFA score is 3, *not* 4. 33 | select pf.hadm_id 34 | -- , charttime 35 | , min( case when IsVent = 0 36 | then PaO2FiO2 37 | else null end) as PaO2FiO2_novent_min 38 | , min( case when IsVent = 1 39 | then PaO2FiO2 40 | else null end) as PaO2FiO2_vent_min 41 | , (date_part('year', age(pf.charttime, ha.admittime))*365 * 24 42 | + date_part('month', age(pf.charttime, ha.admittime))*365/12 * 24 43 | + date_part('day', age(pf.charttime, ha.admittime))* 24 44 | + date_part('hour', age(pf.charttime, ha.admittime)) 45 | + round(date_part('minute', age(pf.charttime, ha.admittime))/60)) as HLOS 46 | 47 | from pafi1 pf 48 | left join admissions ha 49 | on ha.hadm_id = pf.hadm_id 50 | group by pf.hadm_id, HLOS 51 | -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SOFA/renal/labsperhour.sql: -------------------------------------------------------------------------------- 1 | -- This query pivots lab values taken in the first 24 hours of a patient's stay 2 | 3 | -- Have already confirmed that the unit of measurement is always the same: null or the correct unit 4 | 5 | DROP MATERIALIZED VIEW IF EXISTS ren_labsperhour CASCADE; 6 | CREATE materialized VIEW ren_labsperhour AS 7 | SELECT 8 | pvt.subject_id, pvt.hadm_id, pvt.HLOS 9 | , max(CASE WHEN label = 'CREATININE' THEN valuenum ELSE null END) as CREATININE 10 | 11 | FROM 12 | ( -- begin query that extracts the data 13 | SELECT ha.subject_id, ha.hadm_id 14 | -- here we assign labels to ITEMIDs 15 | -- this also fuses together multiple ITEMIDs containing the same data 16 | , CASE 17 | WHEN itemid = 50912 THEN 'CREATININE' 18 | ELSE null 19 | END AS label 20 | , -- add in some sanity checks on the values 21 | -- the where clause below requires all valuenum to be > 0, so these are only upper limit 
checks 22 | CASE 23 | WHEN itemid = 50912 and valuenum > 150 THEN null -- mg/dL 'CREATININE' 24 | ELSE le.valuenum 25 | END AS valuenum 26 | 27 | , (date_part('year', age(le.charttime, ha.admittime))*365 * 24 28 | + date_part('month', age(le.charttime, ha.admittime))*365/12 * 24 29 | + date_part('day', age(le.charttime, ha.admittime))* 24 30 | + date_part('hour', age(le.charttime, ha.admittime)) 31 | + round(date_part('minute', age(le.charttime, ha.admittime))/60)) as HLOS 32 | 33 | FROM admissions ha 34 | 35 | LEFT JOIN labevents le 36 | ON le.hadm_id = ha.hadm_id 37 | AND le.charttime BETWEEN (ha.admittime - interval '1' day) AND ha.dischtime 38 | AND le.ITEMID in 39 | ( 40 | -- comment is: LABEL | CATEGORY | FLUID | NUMBER OF ROWS IN LABEVENTS 41 | 50912 -- CREATININE | CHEMISTRY | BLOOD | 797476 42 | ) 43 | AND valuenum IS NOT null AND valuenum > 0 -- lab values cannot be 0 and cannot be negative 44 | ) pvt 45 | GROUP BY pvt.subject_id, pvt.hadm_id, pvt.HLOS 46 | ORDER BY pvt.subject_id, pvt.hadm_id, pvt.HLOS; -------------------------------------------------------------------------------- /src/data_preprocessing/features_preprocessing/stepIV_GP_prep_part_II.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | import sys 5 | 6 | import pandas as pd 7 | 8 | # appending head path 9 | cwd = os.path.dirname(os.path.abspath(__file__)) 10 | head = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir)) 11 | sys.path.append(head) 12 | 13 | 14 | class GPPreprocessingSecondRound: 15 | def __init__(self, split, n_features=44): 16 | self.n_features = n_features 17 | self.cwd = os.path.dirname(os.path.abspath(__file__)) 18 | self.head = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir, os.pardir)) 19 | print('working out of the assumption that head is ', self.head) 20 | self.path = os.path.join(self.head, "data", split) 21 | 22 | def load_files(self): 23 | 24 | file_path = os.path.join(self.path, "GP_prep.pkl") 25 | stat_file_path = os.path.join(self.path, "full_static.csv") 26 | with open(file_path, "rb") as f: 27 | self.data = pickle.load(f) 28 | self.static_data = pd.read_csv(stat_file_path) 29 | 30 | def discard_useless_files(self): 31 | # keep: Y, T, ind_T, ind_K, num_obs, X, num_X 32 | # atm values, times, ind_lvs, ind_times, 33 | # labels, num_rnn_grid_times, rnn_grid_times, 34 | # num_obs_times, num_obs_values, onset_hour, ids 35 | Y, T, ind_Y, _, labels, len_X, X, len_T, _, onset_hour, ids = self.data 36 | self.data = [Y, T, ind_Y, len_T, X, len_X, labels, ids, onset_hour] 37 | 38 | def join_files(self): 39 | self.static_data.set_index("icustay_id", inplace=True) 40 | self.static_data = self.static_data.loc[self.data[-2]] 41 | self.static_data.drop(columns="Unnamed: 0", inplace=True) 42 | self.static_data = self.static_data.to_numpy() 43 | self.data.append(self.static_data) 44 | 45 | def save(self): 46 | file_path = os.path.join(self.path, "GP_prep_v2.pkl") 47 | with open(file_path, "wb") as f: 48 | pickle.dump(self.data, f) 49 | -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SOFA/coagulation/labsperhour.sql: -------------------------------------------------------------------------------- 1 | -- This query pivots lab values taken in the first 24 hours of a patient's stay 2 | 3 | -- Have already confirmed that the unit of measurement is always the same: null or the correct unit 4 | 5 | DROP 
MATERIALIZED VIEW IF EXISTS coag_labsperhour CASCADE; 6 | CREATE materialized VIEW coag_labsperhour AS 7 | SELECT 8 | pvt.subject_id, pvt.hadm_id, pvt.HLOS 9 | , min(CASE WHEN label = 'PLATELET' THEN valuenum ELSE null END) as PLATELET 10 | 11 | FROM 12 | ( -- begin query that extracts the data 13 | SELECT ha.subject_id, ha.hadm_id 14 | -- here we assign labels to ITEMIDs 15 | -- this also fuses together multiple ITEMIDs containing the same data 16 | , CASE 17 | WHEN itemid = 51265 THEN 'PLATELET' 18 | ELSE null 19 | END AS label 20 | , -- add in some sanity checks on the values 21 | -- the where clause below requires all valuenum to be > 0, so these are only upper limit checks 22 | CASE 23 | WHEN itemid = 51265 and valuenum > 10000 THEN null -- K/uL 'PLATELET' 24 | ELSE le.valuenum 25 | END AS valuenum 26 | 27 | , (date_part('year', age(le.charttime, ha.admittime))*365 * 24 28 | + date_part('month', age(le.charttime, ha.admittime))*365/12 * 24 29 | + date_part('day', age(le.charttime, ha.admittime))* 24 30 | + date_part('hour', age(le.charttime, ha.admittime)) 31 | + round(date_part('minute', age(le.charttime, ha.admittime))/60)) as HLOS 32 | 33 | FROM admissions ha 34 | 35 | LEFT JOIN labevents le 36 | ON le.hadm_id = ha.hadm_id 37 | AND le.charttime BETWEEN (ha.admittime - interval '1' day) AND ha.dischtime 38 | AND le.ITEMID in 39 | ( 40 | -- comment is: LABEL | CATEGORY | FLUID | NUMBER OF ROWS IN LABEVENTS 41 | 51265 -- PLATELET COUNT | HEMATOLOGY | BLOOD | 778444 42 | ) 43 | AND valuenum IS NOT null AND valuenum > 0 -- lab values cannot be 0 and cannot be negative 44 | ) pvt 45 | GROUP BY pvt.subject_id, pvt.hadm_id, pvt.HLOS 46 | ORDER BY pvt.subject_id, pvt.hadm_id, pvt.HLOS; -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SOFA/liver/labsperhour.sql: -------------------------------------------------------------------------------- 1 | -- This query pivots lab values taken in the first 24 hours of a patient's stay 2 | 3 | -- Have already confirmed that the unit of measurement is always the same: null or the correct unit 4 | 5 | DROP MATERIALIZED VIEW IF EXISTS liv_labsperhour CASCADE; 6 | CREATE materialized VIEW liv_labsperhour AS 7 | SELECT 8 | pvt.subject_id, pvt.hadm_id, pvt.HLOS 9 | , max(CASE WHEN label = 'BILIRUBIN' THEN valuenum ELSE null END) as BILIRUBIN 10 | 11 | FROM 12 | ( -- begin query that extracts the data 13 | SELECT ha.subject_id, ha.hadm_id 14 | -- here we assign labels to ITEMIDs 15 | -- this also fuses together multiple ITEMIDs containing the same data 16 | , CASE 17 | WHEN itemid = 50885 THEN 'BILIRUBIN' 18 | ELSE null 19 | END AS label 20 | , -- add in some sanity checks on the values 21 | -- the where clause below requires all valuenum to be > 0, so these are only upper limit checks 22 | CASE 23 | WHEN itemid = 50885 and valuenum > 150 THEN null -- mg/dL 'BILIRUBIN' 24 | ELSE le.valuenum 25 | END AS valuenum 26 | 27 | , (date_part('year', age(le.charttime, ha.admittime))*365 * 24 28 | + date_part('month', age(le.charttime, ha.admittime))*365/12 * 24 29 | + date_part('day', age(le.charttime, ha.admittime))* 24 30 | + date_part('hour', age(le.charttime, ha.admittime)) 31 | + round(date_part('minute', age(le.charttime, ha.admittime))/60)) as HLOS 32 | 33 | FROM admissions ha 34 | 35 | LEFT JOIN labevents le 36 | ON le.hadm_id = ha.hadm_id 37 | AND le.charttime BETWEEN (ha.admittime - interval '1' day) AND ha.dischtime 38 | AND le.ITEMID in 39 | ( 40 | -- comment is: 
LABEL | CATEGORY | FLUID | NUMBER OF ROWS IN LABEVENTS 41 | 50885 -- BILIRUBIN, TOTAL | CHEMISTRY | BLOOD | 238277 42 | ) 43 | AND valuenum IS NOT null AND valuenum > 0 -- lab values cannot be 0 and cannot be negative 44 | ) pvt 45 | GROUP BY pvt.subject_id, pvt.hadm_id, pvt.HLOS 46 | ORDER BY pvt.subject_id, pvt.hadm_id, pvt.HLOS; -------------------------------------------------------------------------------- /src/data_preprocessing/features_preprocessing/archive/GP_prep_part_II.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | import sys 5 | 6 | import pandas as pd 7 | 8 | # appending head path 9 | cwd = os.path.dirname(os.path.abspath(__file__)) 10 | head = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir, os.pardir)) 11 | sys.path.append(head) 12 | 13 | 14 | class GPPreprocessingSecondRound: 15 | def __init__(self, split, n_features=44): 16 | self.path = os.path.join(head, 'data', split) 17 | self.n_features = n_features 18 | 19 | def load_files(self, features=None): 20 | if features is None: 21 | file_path = os.path.join(self.path, 'GP_prep.pkl') 22 | stat_file_path = os.path.join(self.path, 'full_static.csv') 23 | else: 24 | file_path = os.path.join(self.path, "/GP_prep_{}.pkl".format(features)) 25 | stat_file_path = os.path.join(self.path, "/full_static_{}.csv".format(features)) 26 | with open(file_path, "rb") as f: 27 | self.data = pickle.load(f) 28 | self.static_data = pd.read_csv(stat_file_path) 29 | 30 | def discard_useless_files(self): 31 | # keep: Y, T, ind_T, ind_K, num_obs, X, num_X 32 | # atm values, times, ind_lvs, ind_times, 33 | # labels, num_rnn_grid_times, rnn_grid_times, 34 | # num_obs_times, num_obs_values, onset_hour, ids 35 | Y, T, ind_Y, _, labels, len_X, X, len_T, _, onset_hour, ids = self.data 36 | self.data = [Y, T, ind_Y, len_T, X, len_X, labels, ids, onset_hour] 37 | 38 | def join_files(self): 39 | self.static_data.set_index("icustay_id", inplace=True) 40 | self.static_data = self.static_data.loc[self.data[-2]] 41 | self.static_data.drop(columns="Unnamed: 0", inplace=True) 42 | self.static_data = self.static_data.to_numpy() 43 | self.data.append(self.static_data) 44 | 45 | def save(self, features=None): 46 | if features is None: 47 | file_path = os.path.join(self.path, 'GP_prep_v2.pkl') 48 | else: 49 | file_path = os.path.join(self.path, "/GP_prep_{}_v2.pkl".format(features)) 50 | 51 | with open(file_path, "wb") as f: 52 | pickle.dump(self.data, f) 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Ignore Disk Image 132 | *.DS_Store 133 | -------------------------------------------------------------------------------- /src/loss_n_eval/aucs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | from scipy import interp 5 | from sklearn.metrics import precision_recall_curve, roc_curve, auc 6 | 7 | # appending head path 8 | cwd = os.path.dirname(os.path.abspath(__file__)) 9 | head = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir)) 10 | sys.path.append(head) 11 | from src.utils.debug import t_print 12 | 13 | 14 | def evals(y_true, y_proba, classes, cv=False, singles=True, overall=False): 15 | roc_auc = [] 16 | pr_auc = [] 17 | roc_comps = [] 18 | pr_comps = [] 19 | if singles: 20 | # calculate ROC and PR for each horizon 21 | cl_print = [] 22 | for i in range(7): 23 | idx = classes == i 24 | if np.sum(idx) != 0: 25 | cl_print.append(np.sum(idx)) 26 | y_star = y_true[idx] 27 | y = y_proba[idx, 0] 28 | roc_horizon, pr_horizon, roc_comp, pr_comp = one_eval(y_star, y, cv=cv) 29 | if cv: 30 | roc_comps.append(np.asarray(roc_comp)) 31 | pr_comps.append(np.asarray(pr_comp)) 32 | roc_auc.append(roc_horizon) 33 | pr_auc.append(pr_horizon) 34 | else: 35 | cl_print.append(0) 36 | #t_print("warning: no class {}".format(i)) 37 | roc_auc.append(0) 38 | pr_auc.append(0) 39 | #print('classes', cl_print, np.sum(cl_print), flush=True) 40 | if overall: 41 | # calculate ROC and PR over all horizons 42 | roc_horizon, pr_horizon, roc_comp, pr_comp = one_eval(y_star=y_true, y=y_proba[:, 0], cv=cv) 43 | if cv: 44 | roc_comps.append(np.asarray(roc_comp)) 45 | pr_comps.append(np.asarray(pr_comp)) 46 | 
roc_auc.append(roc_horizon) 47 | pr_auc.append(pr_horizon) 48 | return roc_auc, pr_auc, roc_comps, pr_comps 49 | 50 | 51 | def one_eval(y_star, y, cv=False): 52 | linear_space = np.linspace(0, 1, 100) 53 | 54 | fpr, tpr, _ = roc_curve(y_true=y_star, y_score=y) 55 | roc_auc = auc(fpr, tpr) 56 | 57 | pre, rec, _ = precision_recall_curve(y_true=y_star, probas_pred=y) 58 | recall = rec[np.argsort(rec)] 59 | precision = pre[np.argsort(rec)] 60 | pr_auc = auc(recall, precision) 61 | 62 | adj_fpr, adj_tpr = linear_space, interp(linear_space, fpr, tpr) 63 | adj_tpr[0] = 0.0 64 | adj_rec, adj_pre = linear_space, interp(linear_space, recall, precision) 65 | if cv: 66 | return roc_auc, pr_auc, [adj_fpr, adj_tpr], [adj_rec, adj_pre] 67 | else: 68 | return roc_auc, pr_auc, None, None 69 | -------------------------------------------------------------------------------- /src/models/attTCN_alpha.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import tensorflow as tf 4 | from tensorflow import keras 5 | import numpy as np 6 | 7 | # appending head path 8 | cwd = os.path.dirname(os.path.abspath(__file__)) 9 | head = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir, os.pardir)) 10 | sys.path.append(head) 11 | 12 | from src.models.TCN import make_model 13 | 14 | 15 | class AttTCN_alpha: 16 | def __init__(self, 17 | time_window, 18 | n_channels, 19 | num_layers, 20 | DO, 21 | L2reg, 22 | kernel_size=2, 23 | stride=1): 24 | self.alphaTCN = make_model(time_window=time_window, 25 | no_channels=n_channels, 26 | L2reg=L2reg, 27 | DO=DO, 28 | num_layers=num_layers, 29 | kernel_size=kernel_size, 30 | stride=stride, 31 | add_classification_layer=False) 32 | 33 | self.alpha_layer = keras.layers.Dense(2, input_shape=[n_channels], name="alpha_weights") 34 | 35 | self.trainable_variables = self.alphaTCN.trainable_variables + \ 36 | self.alpha_layer.trainable_variables 37 | self.num_layers = num_layers 38 | 39 | def __call__(self, inputs): 40 | # Note that the activation on alpha and the output are only valid if for a model trained on the last timestep 41 | self.alpha = keras.activations.softmax(self.alpha_layer(self.alphaTCN(inputs)), -2) 42 | _ = self.get_weights() 43 | # end shape = batch x time x features x outcome_classes 44 | end_shape = inputs.shape + [2] 45 | expanded_alpha = tf.broadcast_to(tf.expand_dims(self.alpha, -2), end_shape) 46 | expanded_inputs = tf.broadcast_to(tf.expand_dims(inputs, -1), end_shape) 47 | 48 | return tf.reduce_sum(expanded_alpha * expanded_inputs, [1, 2]) 49 | 50 | def get_weights(self): 51 | self.trainable_variables = self.alphaTCN.trainable_variables + \ 52 | self.alpha_layer.trainable_variables 53 | return self.trainable_variables 54 | 55 | def set_weights(self, weights): 56 | if not isinstance(weights[0], np.ndarray): 57 | weights = [weights[i].numpy() for i in range(len(weights))] 58 | start = 0 59 | end = self.num_layers * 4 60 | self.alphaTCN.set_weights(weights[start: end]) 61 | start = end 62 | end += 2 63 | self.alpha_layer.set_weights(weights[start: end]) 64 | -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_features/hourly-cohort.sql: -------------------------------------------------------------------------------- 1 | -- create new hourly-case-cohort: 2 | drop table if exists cases_hourly_ex1c CASCADE; 3 | create table cases_hourly_ex1c( 4 | icustay_id int, 5 | intime timestamp, 6 | outtime timestamp, 7 | length_of_stay 
double precision, 8 | delta_score int, 9 | sepsis_onset timestamp, 10 | sepsis_onset_day double precision, 11 | sepsis_onset_hour double precision 12 | ); 13 | 14 | 15 | -- define cases and controls: join sofa-delta table with exclusion criteria! 16 | insert into cases_hourly_ex1c 17 | select 18 | sd.icustay_id 19 | , s3c.intime 20 | , s3c.outtime 21 | , extract(EPOCH from s3c.outtime - s3c.intime) 22 | / 60.0 / 60.0 as length_of_stay 23 | , sd.delta_score 24 | , sd.sepsis_onset 25 | , extract(EPOCH from sd.sepsis_onset - s3c.intime) 26 | / 60.0 / 60.0 / 24.0 as sepsis_onset_day 27 | , extract(EPOCH from sd.sepsis_onset - s3c.intime) 28 | / 60.0 / 60.0 as sepsis_onset_hour 29 | from sofa_delta sd 30 | inner join sepsis3_cohort_mr s3c 31 | on sd.icustay_id = s3c.icustay_id 32 | inner join admissions adm 33 | on s3c.hadm_id = adm.hadm_id 34 | where s3c.excluded = 0 35 | and extract(EPOCH from sd.sepsis_onset - s3c.intime) 36 | / 60.0 / 60.0 > 0.5 37 | 38 | group by sd.icustay_id, s3c.intime, s3c.outtime, length_of_stay, sd.delta_score, sd.sepsis_onset, 39 | sepsis_onset_day, sepsis_onset_hour 40 | order by sd.icustay_id 41 | ; 42 | 43 | 44 | --new control cohort (without corrected icd criteria!) 45 | drop table if exists controls_hourly CASCADE; 46 | create table controls_hourly( 47 | icustay_id int, 48 | hadm_id int, 49 | intime timestamp, 50 | outtime timestamp, 51 | length_of_stay double precision, 52 | delta_score int, 53 | sepsis_onset timestamp 54 | ); 55 | 56 | insert into controls_hourly 57 | select 58 | s3c.icustay_id 59 | , s3c.hadm_id 60 | , s3c.intime 61 | , s3c.outtime 62 | , extract(EPOCH from s3c.outtime - s3c.intime) 63 | / 60.0 / 60.0 as length_of_stay 64 | , sd.delta_score 65 | , sd.sepsis_onset 66 | from sepsis3_cohort_mr s3c 67 | left join sofa_delta sd 68 | on s3c.icustay_id = sd.icustay_id 69 | inner join admissions adm 70 | on s3c.hadm_id = adm.hadm_id 71 | -- NEW: to remove icd9 sepsis from controls! 72 | 73 | where 74 | s3c.hadm_id not in ( 75 | select distinct(dg.hadm_id) 76 | from diagnoses_icd dg 77 | where dg.icd9_code in ('78552','99591','99592')) 78 | and s3c.excluded = 0 79 | and sd.sepsis_onset is null 80 | and extract(EPOCH from s3c.outtime - s3c.intime) 81 | / 60.0 / 60.0 > 0.5 82 | 83 | group by s3c.icustay_id, s3c.hadm_id, s3c.intime, s3c.outtime, length_of_stay, sd.delta_score, sd.sepsis_onset 84 | order by s3c.hadm_id, s3c.icustay_id 85 | ; 86 | -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_features/static-query.sql: -------------------------------------------------------------------------------- 1 | 2 | -- ------------------------------------------------------------------ 3 | -- Title: Query for static variables 4 | -- Description: This query extracts static variables for all patients (age, gender, ethnicity, ..) around admission (such that it can be used by our model!) 
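Before moving on to the static variables, the two cohort tables defined above in hourly-cohort.sql (cases_hourly_ex1c and controls_hourly) can be spot-checked with a few lines of pandas. This is only a sketch: the connection string is a placeholder and it assumes the cohort tables sit on a schema in the search path.

# Hedged sketch: inspect the case/control cohorts built by hourly-cohort.sql.
import pandas as pd
from sqlalchemy import create_engine

engine = create_engine("postgresql+psycopg2://mimicuser:secret@localhost:5432/mimic")  # placeholder DSN
cases = pd.read_sql("select * from cases_hourly_ex1c", engine)
controls = pd.read_sql("select * from controls_hourly", engine)
print(len(cases), "septic cases /", len(controls), "controls")
print(cases["sepsis_onset_hour"].describe())  # onset is > 0.5 h after ICU intime by construction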
5 | -- Comments: static information that was not known at ICU entry is treated as SPOILER --> it could still be useful for the demographic description in the paper 6 | -- >> MODIFIED VERSION 7 | -- SOURCE: https://github.com/MIT-LCP/mimic-code/blob/master/concepts/demographics/icustay-detail.sql 8 | -- AUTHOR (of this version): Michael Moor, October 2018 9 | 10 | -- ------------------------------------------------------------------ 11 | 12 | -- This query extracts useful demographic/administrative information for patient ICU stays 13 | DROP MATERIALIZED VIEW IF EXISTS icustay_static CASCADE; 14 | CREATE MATERIALIZED VIEW icustay_static as 15 | 16 | SELECT ie.subject_id, ie.hadm_id, ie.icustay_id 17 | 18 | -- patient level factors 19 | , pat.gender 20 | -- SPOILER --, pat.dod --(don't use date of death, as we only want to use info that is known at admission!) 21 | 22 | --, round(cast(adm_w.Weight_Admit as numeric), 2) as Weight_Admit 23 | --, adm_w.Weight_Admit 24 | --, adm_h2.Height 25 | 26 | -- hospital level factors 27 | , adm.admittime 28 | -- SPOILER --, adm.dischtime 29 | -- SPOILER --, ROUND( (CAST(EXTRACT(epoch FROM adm.dischtime - adm.admittime)/(60*60*24) AS numeric)), 4) AS los_hospital 30 | , ROUND( (CAST(EXTRACT(epoch FROM adm.admittime - pat.dob)/(60*60*24*365.242) AS numeric)), 4) AS admission_age 31 | , adm.ethnicity, adm.admission_type 32 | , adm.admission_location 33 | , ie.first_careunit 34 | 35 | -- SPOILER --, adm.hospital_expire_flag 36 | , DENSE_RANK() OVER (PARTITION BY adm.subject_id ORDER BY adm.admittime) AS hospstay_seq -- >>>>>> IS THIS SPOILED? 37 | , CASE 38 | WHEN DENSE_RANK() OVER (PARTITION BY adm.subject_id ORDER BY adm.admittime) = 1 THEN 1 39 | ELSE 0 END AS first_hosp_stay 40 | 41 | -- icu level factors 42 | , ie.intime 43 | , ie.outtime 44 | -- SPOILER --, ROUND( (CAST(EXTRACT(epoch FROM ie.outtime - ie.intime)/(60*60*24) AS numeric)), 4) AS los_icu 45 | , DENSE_RANK() OVER (PARTITION BY ie.hadm_id ORDER BY ie.intime) AS icustay_seq 46 | 47 | -- first ICU stay *for the current hospitalization* 48 | , CASE 49 | WHEN DENSE_RANK() OVER (PARTITION BY ie.hadm_id ORDER BY ie.intime) = 1 THEN 1 50 | ELSE 0 END AS first_icu_stay 51 | 52 | FROM icustays ie 53 | INNER JOIN admissions adm 54 | ON ie.hadm_id = adm.hadm_id 55 | -- for admission weight: 56 | --left join adm_w 57 | -- on ie.icustay_id = adm_w.icustay_id 58 | --left join adm_h2 59 | -- on ie.icustay_id = adm_h2.icustay_id 60 | 61 | INNER JOIN patients pat 62 | ON ie.subject_id = pat.subject_id 63 | WHERE adm.has_chartevents_data = 1 64 | ORDER BY ie.subject_id, adm.admittime, ie.intime; 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | MGP-AttTCN: An Interpretable Machine Learning Model for the Prediction of Sepsis 2 | ============================== 3 | 4 | Data 5 | ------------ 6 | The dataset used is the MIMIC-III dataset, found at https://mimic.physionet.org. 7 | 8 | Use 9 | ------------ 10 | 11 | STEP I: install dependencies 12 | `pip install -r requirements.txt` 13 | 14 | STEP II: data extraction & preprocessing 15 | `python src/data_preprocessing/main.py [-h] -u SQLUSER -pw SQLPASS -ht HOST -db DBNAME -r SCHEMA_READ_NAME [-w SCHEMA_WRITE_NAME]` 16 | 17 | STEP III: run the model 18 | 19 | Project Organization 20 | ------------ 21 | 22 | ├── LICENSE 23 | ├── README.md <- The top-level README for developers using this project.
24 | ├── data 25 | │   ├── external <- Data from third party sources. 26 | │   ├── interim <- Intermediate data that has been transformed. 27 | │   ├── processed <- The final, canonical data sets for modeling. 28 | │   ├── raw <- The original, immutable data dump. 29 | │   ├── train <- The training data used for ... training. 30 | │   ├── val <- The validation data used for ... validating (and hyperparameter selection). 31 | │   └── test <- The test data used for reporting. 32 | │ 33 | ├── src <- Source code for use in this project. 34 | │   ├── __init__.py <- Makes src a Python module 35 | │ │ 36 | │   ├── mains <- Runs the full pipeline 37 | │   │   └── GP_TCN_stat_main.py <- in use for MGP-TCN; MGP-AttTCN 38 | │   │ 39 | │   ├── data_loader <- Loads the data into main 40 | │   │   └── raw_irreg_loader.py <- in use for MGP-TCN; MGP-AttTCN 41 | │ │ 42 | │   ├── models <- Models to load into main 43 | │ │ ├── GP_TCN_Moor.py <- re-implementation of Moor et. al. (MGP-TCN) 44 | │ │ └── GP_attTCN.py <- thesis model: MGP + attention based TCN (MGP-AttTCN) 45 | │ │ 46 | │   ├── trainer <- Trains the data 47 | │   │   └── GP_trainer_with_stat.py <- in use for MGP-TCN; MGP-AttTCN 48 | │ │ 49 | │   ├── loss_n_eval <- Files to calculate loss, gradients and AUROC, AUPR 50 | │   │   └── ... 51 | │ │ 52 | │   ├── visualization <- Scripts to create exploratory and results oriented visualizations 53 | │ │ 54 | │   ├── data_preprocessing <- Scripts to download or generate data 55 | │   │ 56 | │   └── features_preprocessing <- Scripts to turn raw data into features for modeling 57 | │ 58 | └── requirements.txt <- The requirements file for reproducing the analysis environment, e.g. 59 | generated with `pip freeze > requirements.txt` 60 | 61 | 62 | Credits 63 | ------------ 64 | Credits to M. Moor for sharing his code from https://arxiv.org/abs/1902.01659 65 | -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SOFA/renal/uoperhour.sql: -------------------------------------------------------------------------------- 1 | -- ------------------------------------------------------------------ 2 | -- Purpose: Create a view of the urine output for each ICUSTAY_ID over the first 24 hours. 
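Picking up the README's Use section above: a concrete STEP II call could look like the line below, where the user, password, host, database and schema names are placeholder assumptions only (-r presumably points at the schema holding the raw MIMIC-III tables, -w at the schema the new tables and views are written to).

`python src/data_preprocessing/main.py -u mimicuser -pw secret -ht localhost -db mimic -r mimiciii -w mimic_sepsis`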
3 | -- ------------------------------------------------------------------ 4 | 5 | DROP MATERIALIZED VIEW IF EXISTS SOFA_uoperhour CASCADE; 6 | create materialized view SOFA_uoperhour as 7 | 8 | select 9 | -- patient identifiers 10 | ha.subject_id, ha.hadm_id, ie.icustay_id 11 | 12 | -- volumes associated with urine output ITEMIDs 13 | , sum( 14 | -- we consider input of GU irrigant as a negative volume 15 | case 16 | when oe.itemid = 227488 and oe.value > 0 then -1*oe.value 17 | else oe.value 18 | end) as UrineOutput 19 | , (date_part('year', age(oe.charttime, ha.admittime))*365 * 24 20 | + date_part('month', age(oe.charttime, ha.admittime))*365/12 * 24 21 | + date_part('day', age(oe.charttime, ha.admittime))* 24 22 | + date_part('hour', age(oe.charttime, ha.admittime)) 23 | + round(date_part('minute', age(oe.charttime, ha.admittime))/60)) as HLOS 24 | 25 | , (date_part('year', age(oe.charttime, ie.intime))*365 * 24 26 | + date_part('month', age(oe.charttime, ie.intime))*365/12 * 24 27 | + date_part('day', age(oe.charttime, ie.intime))* 24 28 | + date_part('hour', age(oe.charttime, ie.intime)) 29 | + round(date_part('minute', age(oe.charttime, ie.intime))/60)) as ICULOS 30 | 31 | from admissions ha 32 | -- Join to the outputevents table to get urine output 33 | left join outputevents oe 34 | -- join on all patient identifiers 35 | on ha.subject_id = oe.subject_id and ha.hadm_id = oe.hadm_id 36 | left join icustays ie 37 | on ie.icustay_id = oe.icustay_id 38 | -- and ensure the data occurs during the first day 39 | and oe.charttime between (ha.admittime - interval '1' day) AND ha.dischtime 40 | where itemid in 41 | ( 42 | -- these are the most frequently occurring urine output observations in CareVue 43 | 40055, -- "Urine Out Foley" 44 | 43175, -- "Urine ." 
45 | 40069, -- "Urine Out Void" 46 | 40094, -- "Urine Out Condom Cath" 47 | 40715, -- "Urine Out Suprapubic" 48 | 40473, -- "Urine Out IleoConduit" 49 | 40085, -- "Urine Out Incontinent" 50 | 40057, -- "Urine Out Rt Nephrostomy" 51 | 40056, -- "Urine Out Lt Nephrostomy" 52 | 40405, -- "Urine Out Other" 53 | 40428, -- "Urine Out Straight Cath" 54 | 40086,-- Urine Out Incontinent 55 | 40096, -- "Urine Out Ureteral Stent #1" 56 | 40651, -- "Urine Out Ureteral Stent #2" 57 | 58 | -- these are the most frequently occurring urine output observations in MetaVision 59 | 226559, -- "Foley" 60 | 226560, -- "Void" 61 | 226561, -- "Condom Cath" 62 | 226584, -- "Ileoconduit" 63 | 226563, -- "Suprapubic" 64 | 226564, -- "R Nephrostomy" 65 | 226565, -- "L Nephrostomy" 66 | 226567, -- Straight Cath 67 | 226557, -- R Ureteral Stent 68 | 226558, -- L Ureteral Stent 69 | 227488, -- GU Irrigant Volume In 70 | 227489 -- GU Irrigant/Urine Volume Out 71 | ) 72 | group by ha.subject_id, ha.hadm_id, ie.icustay_id, HLOS, ICULOS 73 | order by ha.subject_id, ha.hadm_id, ie.icustay_id, HLOS, ICULOS -------------------------------------------------------------------------------- /src/models/TCN.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import tensorflow as tf 5 | from tensorflow import keras 6 | 7 | cwd = os.path.dirname(os.path.abspath(__file__)) 8 | head = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir)) 9 | sys.path.append(head) 10 | 11 | 12 | def make_model(time_window, no_channels, L2reg, DO, num_layers, kernel_size=2, stride=1, add_classification_layer=True, 13 | filters_per_layer=None): 14 | no_initial_channels = no_channels 15 | if filters_per_layer is None: 16 | no_channels = no_initial_channels 17 | else: 18 | no_channels = filters_per_layer 19 | # residual block 20 | layers = [keras.layers.Conv1D(filters=no_channels, kernel_size=kernel_size, strides=1, padding='causal', 21 | dilation_rate=1, activation=tf.nn.relu, 22 | input_shape=(time_window, no_initial_channels), 23 | kernel_regularizer=keras.regularizers.l2(L2reg[0]), 24 | name="conv00"), 25 | keras.layers.Dropout(DO[0], 26 | name="DropOut00"), 27 | keras.layers.Conv1D(filters=no_channels, kernel_size=kernel_size, strides=1, padding='causal', 28 | dilation_rate=1, activation=tf.nn.relu, 29 | kernel_regularizer=keras.regularizers.l2(L2reg[0]), 30 | name="conv01"), 31 | keras.layers.Dropout(DO[0], 32 | name="DropOut01")] 33 | for i in range(1, num_layers): 34 | layers += [keras.layers.Conv1D(filters=no_channels, kernel_size=kernel_size, strides=1, padding='causal', 35 | dilation_rate=2 ** i, activation=tf.nn.relu, 36 | kernel_regularizer=keras.regularizers.l2(L2reg[i]), 37 | name="conv{}0".format(i)), 38 | keras.layers.Dropout(DO[i], 39 | name="DropOut{}0".format(i)), 40 | keras.layers.Conv1D(filters=no_channels, kernel_size=kernel_size, strides=1, padding='causal', 41 | dilation_rate=2 ** i, activation=tf.nn.relu, 42 | kernel_regularizer=keras.regularizers.l2(L2reg[i]), 43 | name="conv{}1".format(i)), 44 | keras.layers.Dropout(DO[i], 45 | name="DropOut{}1".format(i))] 46 | if add_classification_layer: 47 | layers.append(LastDimDenseLayer(no_channels, 2)) 48 | model = keras.Sequential(layers) 49 | return model 50 | 51 | 52 | class LastDimDenseLayer(tf.keras.layers.Layer): 53 | def __init__(self, no_channels, num_outputs): 54 | super(LastDimDenseLayer, self).__init__() 55 | self.kernel = self.add_variable("LastTimestepDense", 56 | shape=[no_channels, num_outputs]) 57 | 
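    # Usage sketch for make_model above (illustrative values, not taken from the mains/ scripts):
    #   tcn = make_model(time_window=48, no_channels=44, L2reg=[1e-3] * 4,
    #                    DO=[0.1] * 4, num_layers=4)
    #   logits = tcn(tf.zeros((8, 48, 44)))  # -> shape (8, 2) via the LastDimDenseLayer head
    # With add_classification_layer=False the same convolutional stack is reused as the feature
    # trunk of the attention models in attTCN_alpha.py and attTCN_beta.py.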
58 | def build(self, input_shape): 59 | pass 60 | 61 | def call(self, input): 62 | return tf.matmul(input[:, -1, :], self.kernel) 63 | -------------------------------------------------------------------------------- /src/models/GP_logreg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import tensorflow as tf 4 | import numpy as np 5 | # appending head path 6 | cwd = os.path.dirname(os.path.abspath(__file__)) 7 | head = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir, os.pardir)) 8 | sys.path.append(head) 9 | from src.models.GP import MultiKernelMGPLayer 10 | from src.models.attTCN import AttTCN 11 | 12 | 13 | class GPLogReg: 14 | def __init__(self, 15 | time_window, 16 | n_mc_samples, 17 | n_features, 18 | n_stat_features, 19 | log_noise_mean=-2, 20 | log_noise_std=0.1, 21 | method_name='chol', 22 | add_diag=0.001, 23 | L2reg=None, 24 | save_path=head, 25 | ): 26 | # a few variables to be used later 27 | self.tw = time_window 28 | self.non_s_feat = n_features 29 | self.s_feat = n_stat_features 30 | self.samp = n_mc_samples 31 | 32 | # the model 33 | self.GP = MultiKernelMGPLayer(time_window=time_window, 34 | n_mc_samples=n_mc_samples, 35 | n_features=n_features, 36 | log_noise_mean=log_noise_mean, 37 | log_noise_std=log_noise_std, 38 | method_name=method_name, 39 | add_diag=add_diag, 40 | save_path=save_path) 41 | 42 | self.LogReg = tf.keras.Sequential( 43 | [tf.keras.layers.Dense(2, 44 | input_shape=(time_window * n_features + n_stat_features,), 45 | kernel_regularizer=tf.keras.regularizers.L2(L2reg[0]), 46 | bias_regularizer=tf.keras.regularizers.L2(L2reg[0]),)]) 47 | 48 | self.trainable_variables = self.GP.trainable_variables + \ 49 | self.LogReg.trainable_variables 50 | self.n_GP_var = len(self.GP.trainable_variables) 51 | 52 | def __call__(self, inputs): 53 | self.GP_out = self.GP(inputs[:-1]) 54 | # GP out: batch x MC samples x tw x features 55 | self.GP_out = tf.reshape(self.GP_out, (-1, self.samp, self.tw * self.non_s_feat)) 56 | stat_input = tf.expand_dims(inputs[-1], axis=1) 57 | stat_input = tf.broadcast_to(stat_input, [stat_input.shape[0], self.samp, stat_input.shape[-1]]) 58 | self.LR_input = tf.concat([self.GP_out, stat_input], axis=-1) 59 | self.LR_input = tf.reshape(self.LR_input, (-1, self.tw * self.non_s_feat + self.s_feat)) 60 | 61 | return self.LogReg(self.LR_input) 62 | 63 | def get_weights(self): 64 | self.trainable_variables = self.GP.trainable_variables + \ 65 | self.LogReg.trainable_variables 66 | return self.trainable_variables 67 | 68 | def set_weights(self, weights): 69 | if not isinstance(weights[0], np.ndarray): 70 | weights = [weights[i].numpy() for i in range(len(weights))] 71 | self.GP.set_weights(weights[:self.n_GP_var]) 72 | self.LogReg.set_weights(weights[self.n_GP_var:]) 73 | -------------------------------------------------------------------------------- /src/models/attTCN_beta.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import tensorflow as tf 4 | from tensorflow import keras 5 | import numpy as np 6 | 7 | # appending head path 8 | cwd = os.path.dirname(os.path.abspath(__file__)) 9 | head = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir, os.pardir)) 10 | sys.path.append(head) 11 | 12 | from src.models.TCN import make_model 13 | 14 | 15 | class AttTCN_beta: 16 | def __init__(self, 17 | time_window, 18 | n_channels, 19 | num_layers, 20 | DO, 21 | L2reg, 22 | kernel_size=2, 23 | stride=1, 24 | 
sigmoid_beta=False): 25 | self.betaTCN = make_model(time_window=time_window, 26 | no_channels=n_channels, 27 | L2reg=L2reg, 28 | DO=DO, 29 | num_layers=num_layers, 30 | kernel_size=kernel_size, 31 | stride=stride, 32 | add_classification_layer=False) 33 | 34 | self.beta_layer_pos = keras.layers.Dense(n_channels, input_shape=[n_channels], name="beta_pos_weights") 35 | self.beta_layer_neg = keras.layers.Dense(n_channels, input_shape=[n_channels], name="beta_neg_weights") 36 | 37 | self.trainable_variables =self.betaTCN.trainable_variables + \ 38 | self.beta_layer_pos.trainable_variables + \ 39 | self.beta_layer_neg.trainable_variables 40 | self.num_layers = num_layers 41 | self.sigmoid_beta = sigmoid_beta 42 | 43 | def __call__(self, inputs): 44 | # Note that the activation on alpha and the output are only valid if for a model trained on the last timestep 45 | if self.sigmoid_beta: 46 | beta_pos = tf.expand_dims(keras.activations.sigmoid(self.beta_layer_pos(self.betaTCN(inputs))), -1) 47 | beta_neg = tf.expand_dims(keras.activations.sigmoid(self.beta_layer_neg(self.betaTCN(inputs))), -1) 48 | else: 49 | beta_pos = tf.expand_dims(self.beta_layer_pos(self.betaTCN(inputs)), -1) 50 | beta_neg = tf.expand_dims(self.beta_layer_neg(self.betaTCN(inputs)), -1) 51 | _ = self.get_weights() 52 | self.beta = tf.concat([beta_pos, beta_neg], -1) 53 | expanded_inputs = tf.broadcast_to(tf.expand_dims(inputs, -1), list(self.beta.shape)) 54 | 55 | return tf.reduce_sum(self.beta * expanded_inputs, [1, 2]) 56 | 57 | def get_weights(self): 58 | self.trainable_variables = self.betaTCN.trainable_variables + \ 59 | self.beta_layer_pos.trainable_variables + \ 60 | self.beta_layer_neg.trainable_variables 61 | return self.trainable_variables 62 | 63 | def set_weights(self, weights): 64 | if not isinstance(weights[0], np.ndarray): 65 | weights = [weights[i].numpy() for i in range(len(weights))] 66 | start = 0 67 | end = self.num_layers * 4 68 | self.betaTCN.set_weights(weights[start: end]) 69 | start = end 70 | end += 2 71 | self.beta_layer_pos.set_weights(weights[start: end]) 72 | start = end 73 | end += 2 74 | self.beta_layer_neg.set_weights(weights[start: end]) 75 | -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_features/sepsis3_cohort_mr.sql: -------------------------------------------------------------------------------- 1 | drop table if exists sepsis3_cohort_mr cascade; 2 | create table sepsis3_cohort_mr as 3 | with serv as 4 | 5 | ( 6 | select hadm_id, curr_service 7 | , ROW_NUMBER() over (partition by hadm_id order by transfertime) as rn 8 | from services 9 | ) 10 | 11 | , t1 as 12 | ( 13 | select ie.icustay_id, ie.hadm_id 14 | , ie.intime, ie.outtime 15 | , round((cast(adm.admittime as date) - cast(pat.dob as date)) / 365.242, 4) as age 16 | , pat.gender 17 | , adm.ethnicity 18 | , ie.dbsource 19 | -- used to get first ICUSTAY_ID 20 | , ROW_NUMBER() over (partition by ie.subject_id order by intime) as rn 21 | 22 | -- exclusions 23 | , s.curr_service as first_service 24 | , adm.HAS_CHARTEVENTS_DATA 25 | 26 | -- suspicion of infection using POE 27 | , case when spoe.suspected_infection_time is not null then 1 else 0 end 28 | as suspected_of_infection_poe 29 | , spoe.suspected_infection_time as suspected_infection_time_poe 30 | , extract(EPOCH from ie.intime - spoe.suspected_infection_time) 31 | / 60.0 / 60.0 / 24.0 as suspected_infection_time_poe_days 32 | -- , spoe.specimen as specimen_poe 33 | -- , spoe.positiveculture as 
positiveculture_poe 34 | -- , spoe.antibiotic_time as antibiotic_time_poe 35 | 36 | from icustays ie 37 | inner join admissions adm 38 | on ie.hadm_id = adm.hadm_id 39 | inner join patients pat 40 | on ie.subject_id = pat.subject_id 41 | left join serv s 42 | on ie.hadm_id = s.hadm_id 43 | and s.rn = 1 44 | left join SI_flag spoe -- 'hadm_id', 'suspected_infection_time', 'si_start', 'si_end' 45 | on ie.hadm_id = spoe.hadm_id 46 | ) 47 | select 48 | t1.hadm_id, t1.icustay_id 49 | , t1.intime, t1.outtime 50 | 51 | -- set de-identified ages to median of 91.4 52 | , case when age > 89 then 91.4 else age end as age 53 | , gender 54 | , ethnicity 55 | , first_service 56 | , dbsource 57 | 58 | -- suspicion using POE 59 | , suspected_of_infection_poe 60 | , suspected_infection_time_poe 61 | , suspected_infection_time_poe_days 62 | -- , specimen_poe 63 | -- , positiveculture_poe 64 | -- , antibiotic_time_poe 65 | 66 | -- exclusions 67 | , case when t1.rn = 1 then 0 else 1 end as exclusion_secondarystay_INACTIVE 68 | , case when t1.age <= 14 then 1 else 0 end as exclusion_nonadult -- CHANGED FROM ORIGINAL! <=16 69 | , case when t1.dbsource != 'metavision' then 1 else 0 end as exclusion_carevue 70 | , case when t1.suspected_infection_time_poe is not null -- CHANGED FROM ORIGINAL! 71 | and t1.suspected_infection_time_poe < t1.intime then 1 72 | else 0 end as exclusion_suspicion_before_intime_INACTIVE 73 | , case when t1.suspected_infection_time_poe is not null -- CHANGED FROM ORIGINAL! 74 | and t1.suspected_infection_time_poe > t1.intime + interval '4' hour then 1 75 | else null end as exclusion_suspicion_after_intime_plus_4_INACTIVE 76 | , case when t1.HAS_CHARTEVENTS_DATA = 0 then 1 77 | when t1.intime is null then 1 78 | when t1.outtime is null then 1 79 | else 0 end as exclusion_bad_data 80 | -- the above flags are used to summarize patients excluded 81 | -- below flag is used to actually exclude patients in future queries 82 | , case when 83 | t1.age <= 14 -- CHANGED FROM ORIGINAL! 
<=16 84 | or t1.HAS_CHARTEVENTS_DATA = 0 85 | or t1.intime is null 86 | or t1.outtime is null 87 | or t1.dbsource != 'metavision' 88 | then 1 89 | else 0 end as excluded 90 | 91 | from t1 92 | order by t1.icustay_id; -------------------------------------------------------------------------------- /src/models/GP_attTCN.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import tensorflow as tf 4 | import numpy as np 5 | # appending head path 6 | cwd = os.path.dirname(os.path.abspath(__file__)) 7 | head = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir, os.pardir)) 8 | sys.path.append(head) 9 | from src.models.GP import MultiKernelMGPLayer 10 | from src.models.attTCN import AttTCN 11 | 12 | 13 | class GPattTCN: 14 | def __init__(self, 15 | time_window, 16 | n_mc_samples, 17 | n_features, 18 | n_stat_features, 19 | kernel='OU', 20 | len_mode='avg', 21 | len_trainable=True, 22 | log_noise_mean=-2, 23 | log_noise_std=0.1, 24 | log_length_mean=1, 25 | log_length_std=0.1, 26 | method_name='chol', 27 | add_diag=0.001, 28 | L2reg=None, 29 | DO=None, 30 | save_path=head, 31 | num_layers=4, 32 | kernel_size=2, 33 | stride=1, 34 | sigmoid_beta=False, 35 | moor_data=False 36 | ): 37 | # a few variables to be used later 38 | self.tw = time_window 39 | self.s_feat = n_stat_features 40 | self.samp = n_mc_samples 41 | 42 | # the model 43 | self.GP = MultiKernelMGPLayer(time_window=time_window, 44 | n_mc_samples=n_mc_samples, 45 | n_features=n_features, 46 | log_noise_mean=log_noise_mean, 47 | log_noise_std=log_noise_std, 48 | method_name=method_name, 49 | add_diag=add_diag, 50 | save_path=save_path) 51 | 52 | self.attTCN = AttTCN(time_window, 53 | n_features + n_stat_features, 54 | num_layers, 55 | DO, 56 | L2reg, 57 | kernel_size=kernel_size, 58 | stride=stride, 59 | sigmoid_beta=sigmoid_beta 60 | ) 61 | 62 | self.trainable_variables = self.GP.trainable_variables + \ 63 | self.attTCN.trainable_variables 64 | self.n_GP_var = len(self.GP.trainable_variables) 65 | 66 | def __call__(self, inputs): 67 | self.GP_out = self.GP(inputs[:-1]) 68 | stat_matching_shape = \ 69 | tf.concat([ # step III: concatenate all patients i 70 | # step II: tile patient info for each MC sample 71 | tf.tile( 72 | 73 | # step I: for pat i, repeat feat data for each time step 74 | tf.concat([tf.reshape(inputs[-1][i], [1, 1, self.s_feat]) for _ in range(self.tw)], axis=1) 75 | , [self.samp, 1, 1]) 76 | for i in range(inputs[-1].shape[0])], axis=0) 77 | 78 | self.TCN_input = tf.concat([self.GP_out, stat_matching_shape], -1) 79 | 80 | return self.attTCN(self.TCN_input) 81 | 82 | def get_weights(self): 83 | self.trainable_variables = self.GP.trainable_variables + \ 84 | self.attTCN.trainable_variables 85 | return self.trainable_variables 86 | 87 | def set_weights(self, weights): 88 | if not isinstance(weights[0], np.ndarray): 89 | weights = [weights[i].numpy() for i in range(len(weights))] 90 | self.GP.set_weights(weights[:self.n_GP_var]) 91 | self.attTCN.set_weights(weights[self.n_GP_var:]) 92 | -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SOFA/respiration/bloodgasfirstday.sql: -------------------------------------------------------------------------------- 1 | 2 | -- The aim of this query is to pivot entries related to blood gases and 3 | -- chemistry values which were found in LABEVENTS 4 | 5 | -- things to check: 6 | -- when a mixed venous/arterial blood sample are 
taken at the same time, is the store time different? 7 | 8 | DROP MATERIALIZED VIEW IF EXISTS resp_bloodgasfirstday CASCADE; 9 | create materialized view resp_bloodgasfirstday as 10 | with pvt as 11 | ( -- begin query that extracts the data 12 | select ha.subject_id, ha.hadm_id 13 | -- here we assign labels to ITEMIDs 14 | -- this also fuses together multiple ITEMIDs containing the same data 15 | , case 16 | when itemid = 50800 then 'SPECIMEN' -- KEEP 17 | when itemid = 50801 then 'AADO2' -- KEEP 18 | when itemid = 50803 then 'BICARBONATE' -- KEEP 19 | when itemid = 50804 then 'TOTALCO2' -- KEEP 20 | when itemid = 50811 then 'HEMOGLOBIN' -- KEEP 21 | when itemid = 50813 then 'LACTATE' -- KEEP 22 | when itemid = 50815 then 'O2FLOW' -- KEEP 23 | when itemid = 50816 then 'FIO2' -- KEEP 24 | when itemid = 50817 then 'SO2' -- OXYGENSATURATION -- KEEP 25 | when itemid = 50818 then 'PCO2' -- KEEP 26 | when itemid = 50820 then 'PH' -- KEEP 27 | when itemid = 50821 then 'PO2' -- KEEP 28 | else null 29 | end as label 30 | , charttime 31 | , value 32 | -- add in some sanity checks on the values 33 | , case 34 | when valuenum <= 0 then null 35 | when itemid = 50810 and valuenum > 100 then null -- hematocrit 36 | -- ensure FiO2 is a valid number between 21-100 37 | -- mistakes are rare (<100 obs out of ~100,000) 38 | -- there are 862 obs of valuenum == 20 - some people round down! 39 | -- rather than risk imputing garbage data for FiO2, we simply NULL invalid values 40 | when itemid = 50816 and valuenum < 20 then null 41 | when itemid = 50816 and valuenum > 100 then null 42 | when itemid = 50817 and valuenum > 100 then null -- O2 sat 43 | when itemid = 50815 and valuenum > 70 then null -- O2 flow 44 | when itemid = 50821 and valuenum > 800 then null -- PO2 45 | -- conservative upper limit 46 | else valuenum 47 | end as valuenum 48 | 49 | from admissions ha 50 | left join labevents le 51 | on le.subject_id = ha.subject_id and le.hadm_id = ha.hadm_id 52 | and le.charttime between ha.admittime - interval '1' day and ha.dischtime -- MR add 53 | and le.ITEMID in 54 | -- blood gases 55 | ( 56 | 50800, 50801, 50803, 50804, 50811, 50813, 50815, 50816, 50817, 50818 57 | , 50820, 50821 58 | ) 59 | ) 60 | select pvt.SUBJECT_ID, pvt.HADM_ID, pvt.CHARTTIME 61 | -- SPECIMEN 62 | , max(case when label = 'SPECIMEN' then value else null end) as SPECIMEN 63 | -- SPECIMEN PROB 64 | , max(case when label = 'AADO2' then valuenum else null end) as AADO2 -- KEEP 65 | , max(case when label = 'BICARBONATE' then valuenum else null end) as BICARBONATE -- KEEP 66 | , max(case when label = 'TOTALCO2' then valuenum else null end) as TOTALCO2 -- KEEP 67 | , max(case when label = 'HEMOGLOBIN' then valuenum else null end) as HEMOGLOBIN -- KEEP 68 | , max(case when label = 'LACTATE' then valuenum else null end) as LACTATE -- KEEP 69 | , max(case when label = 'O2FLOW' then valuenum else null end) as O2FLOW -- KEEP 70 | , max(case when label = 'SO2' then valuenum else null end) as SO2 -- OXYGENSATURATION -- KEEP 71 | , max(case when label = 'PCO2' then valuenum else null end) as PCO2 -- KEEP 72 | , max(case when label = 'PH' then valuenum else null end) as PH -- KEEP 73 | -- RESPIRATION VALS 74 | , max(case when label = 'PO2' then valuenum else null end) as PO2 -- KEEP 75 | , max(case when label = 'FIO2' then valuenum else null end) as FIO2 -- KEEP 76 | 77 | from pvt 78 | group by pvt.subject_id, pvt.hadm_id, pvt.CHARTTIME 79 | order by pvt.subject_id, pvt.hadm_id, pvt.CHARTTIME; 
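-- A minimal follow-up sketch (not part of the original query above): once the view has
-- been materialised it can be spot-checked per admission, for example by counting how
-- many blood-gas rows carry both a PO2 and an FiO2 value, since a PaO2/FiO2 ratio can
-- only be derived downstream where both are present.
select hadm_id
     , count(*)    as n_blood_gas_rows
     , count(po2)  as n_with_po2
     , count(fio2) as n_with_fio2
from resp_bloodgasfirstday
group by hadm_id
order by n_blood_gas_rows desc
limit 10;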
-------------------------------------------------------------------------------- /src/models/attTCN.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import tensorflow as tf 4 | from tensorflow import keras 5 | import numpy as np 6 | 7 | # appending head path 8 | cwd = os.path.dirname(os.path.abspath(__file__)) 9 | head = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir, os.pardir)) 10 | sys.path.append(head) 11 | 12 | from src.models.TCN import make_model 13 | 14 | 15 | class AttTCN: 16 | def __init__(self, 17 | time_window, 18 | n_channels, 19 | num_layers, 20 | DO, 21 | L2reg, 22 | kernel_size=2, 23 | stride=1, 24 | sigmoid_beta=False): 25 | self.alphaTCN = make_model(time_window=time_window, 26 | no_channels=n_channels, 27 | L2reg=L2reg, 28 | DO=DO, 29 | num_layers=num_layers, 30 | kernel_size=kernel_size, 31 | stride=stride, 32 | add_classification_layer=False) 33 | 34 | self.alpha_layer = keras.layers.Dense(2, input_shape=[n_channels], name="alpha_weights") 35 | 36 | self.betaTCN = make_model(time_window=time_window, 37 | no_channels=n_channels, 38 | L2reg=L2reg, 39 | DO=DO, 40 | num_layers=num_layers, 41 | kernel_size=kernel_size, 42 | stride=stride, 43 | add_classification_layer=False) 44 | 45 | self.beta_layer_pos = keras.layers.Dense(n_channels, input_shape=[n_channels], name="beta_pos_weights") 46 | self.beta_layer_neg = keras.layers.Dense(n_channels, input_shape=[n_channels], name="beta_neg_weights") 47 | 48 | self.trainable_variables = self.alphaTCN.trainable_variables + \ 49 | self.alpha_layer.trainable_variables + \ 50 | self.betaTCN.trainable_variables + \ 51 | self.beta_layer_pos.trainable_variables + \ 52 | self.beta_layer_neg.trainable_variables 53 | self.num_layers = num_layers 54 | self.sigmoid_beta = sigmoid_beta 55 | 56 | def __call__(self, inputs): 57 | # Note that the activation on alpha and the output are only valid if for a model trained on the last timestep 58 | self.alpha = keras.activations.softmax(self.alpha_layer(self.alphaTCN(inputs)), -2) 59 | if self.sigmoid_beta: 60 | beta_pos = tf.expand_dims(keras.activations.sigmoid(self.beta_layer_pos(self.betaTCN(inputs))), -1) 61 | beta_neg = tf.expand_dims(keras.activations.sigmoid(self.beta_layer_neg(self.betaTCN(inputs))), -1) 62 | else: 63 | beta_pos = tf.expand_dims(self.beta_layer_pos(self.betaTCN(inputs)), -1) 64 | beta_neg = tf.expand_dims(self.beta_layer_neg(self.betaTCN(inputs)), -1) 65 | _ = self.get_weights() 66 | self.beta = tf.concat([beta_pos, beta_neg], -1) 67 | 68 | expanded_alpha = tf.broadcast_to(tf.expand_dims(self.alpha, -2), list(self.beta.shape)) 69 | expanded_inputs = tf.broadcast_to(tf.expand_dims(inputs, -1), list(self.beta.shape)) 70 | 71 | return tf.reduce_sum(tf.reduce_sum(expanded_alpha * self.beta * expanded_inputs, -2), -2) 72 | 73 | def get_weights(self): 74 | self.trainable_variables = self.alphaTCN.trainable_variables + \ 75 | self.alpha_layer.trainable_variables + \ 76 | self.betaTCN.trainable_variables + \ 77 | self.beta_layer_pos.trainable_variables + \ 78 | self.beta_layer_neg.trainable_variables 79 | return self.trainable_variables 80 | 81 | def set_weights(self, weights): 82 | if not isinstance(weights[0], np.ndarray): 83 | weights = [weights[i].numpy() for i in range(len(weights))] 84 | start = 0 85 | end = self.num_layers * 4 86 | self.alphaTCN.set_weights(weights[start: end]) 87 | start = end 88 | end += 2 89 | self.alpha_layer.set_weights(weights[start: end]) 90 | start = end 91 | end += 
self.num_layers * 4 92 | self.betaTCN.set_weights(weights[start: end]) 93 | start = end 94 | end += 2 95 | self.beta_layer_pos.set_weights(weights[start: end]) 96 | start = end 97 | end += 2 98 | self.beta_layer_neg.set_weights(weights[start: end]) 99 | -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SOFA/respiration/ventdurations.sql: -------------------------------------------------------------------------------- 1 | -- This query extracts the duration of mechanical ventilation 2 | -- The main goal of the query is to aggregate sequential ventilator settings 3 | -- into single mechanical ventilation "events". The start and end time of these 4 | -- events can then be used for various purposes: calculating the total duration 5 | -- of mechanical ventilation, cross-checking values (e.g. PaO2:FiO2 on vent), etc 6 | 7 | -- The query's logic is roughly: 8 | -- 1) The presence of a mechanical ventilation setting starts a new ventilation event 9 | -- 2) Any instance of a setting in the next 8 hours continues the event 10 | -- 3) Certain elements end the current ventilation event 11 | -- a) documented extubation ends the current ventilation 12 | -- b) initiation of non-invasive vent and/or oxygen ends the current vent 13 | -- The ventilation events are numbered consecutively by the `num` column. 14 | 15 | -- PART II 16 | 17 | 18 | --DROP MATERIALIZED VIEW IF EXISTS VENTDURATIONS CASCADE; 19 | DROP MATERIALIZED VIEW IF EXISTS VENTDURATIONS CASCADE; 20 | create MATERIALIZED VIEW ventdurations as 21 | with vd0 as 22 | ( 23 | select 24 | icustay_id 25 | -- this carries over the previous charttime which had a mechanical ventilation event 26 | , case 27 | when MechVent=1 then 28 | LAG(CHARTTIME, 1) OVER (partition by icustay_id, MechVent order by charttime) 29 | else null 30 | end as charttime_lag 31 | , charttime 32 | , MechVent 33 | , OxygenTherapy 34 | , Extubated 35 | , SelfExtubated 36 | from ventsettings 37 | ) 38 | , vd1 as 39 | ( 40 | select 41 | icustay_id 42 | , charttime_lag 43 | , charttime 44 | , MechVent 45 | , OxygenTherapy 46 | , Extubated 47 | , SelfExtubated 48 | 49 | -- if this is a mechanical ventilation event, we calculate the time since the last event 50 | , case 51 | -- if the current observation indicates mechanical ventilation is present 52 | -- calculate the time since the last vent event 53 | when MechVent=1 then 54 | CHARTTIME - charttime_lag 55 | else null 56 | end as ventduration 57 | 58 | , LAG(Extubated,1) 59 | OVER 60 | ( 61 | partition by icustay_id, case when MechVent=1 or Extubated=1 then 1 else 0 end 62 | order by charttime 63 | ) as ExtubatedLag 64 | 65 | -- now we determine if the current mech vent event is a "new", i.e. 
they've just been intubated 66 | , case 67 | -- if there is an extubation flag, we mark any subsequent ventilation as a new ventilation event 68 | --when Extubated = 1 then 0 -- extubation is *not* a new ventilation event, the *subsequent* row is 69 | when 70 | LAG(Extubated,1) 71 | OVER 72 | ( 73 | partition by icustay_id, case when MechVent=1 or Extubated=1 then 1 else 0 end 74 | order by charttime 75 | ) 76 | = 1 then 1 77 | -- if patient has initiated oxygen therapy, and is not currently vented, start a newvent 78 | when MechVent = 0 and OxygenTherapy = 1 then 1 79 | -- if there is less than 8 hours between vent settings, we do not treat this as a new ventilation event 80 | when (CHARTTIME - charttime_lag) > interval '8' hour 81 | then 1 82 | else 0 83 | end as newvent 84 | -- use the staging table with only vent settings from chart events 85 | FROM vd0 ventsettings 86 | ) 87 | , vd2 as 88 | ( 89 | select vd1.* 90 | -- create a cumulative sum of the instances of new ventilation 91 | -- this results in a monotonic integer assigned to each instance of ventilation 92 | , case when MechVent=1 or Extubated = 1 then 93 | SUM( newvent ) 94 | OVER ( partition by icustay_id order by charttime ) 95 | else null end 96 | as ventnum 97 | --- now we convert CHARTTIME of ventilator settings into durations 98 | from vd1 99 | ) 100 | -- create the durations for each mechanical ventilation instance 101 | select icustay_id 102 | -- regenerate ventnum so it's sequential 103 | , ROW_NUMBER() over (partition by icustay_id order by ventnum) as ventnum 104 | , min(charttime) as starttime 105 | , max(charttime) as endtime 106 | , extract(epoch from max(charttime)-min(charttime))/60/60 AS duration_hours 107 | from vd2 108 | group by icustay_id, ventnum 109 | having min(charttime) != max(charttime) 110 | -- patient had to be mechanically ventilated at least once 111 | -- i.e. 
max(mechvent) should be 1 112 | -- this excludes a frequent situation of NIV/oxygen before intub 113 | -- in these cases, ventnum=0 and max(mechvent)=0, so they are ignored 114 | and max(mechvent) = 1 115 | order by icustay_id, ventnum; 116 | -------------------------------------------------------------------------------- /src/data_preprocessing/features_preprocessing/stepII_split_sets_n_normalise.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | import sys 5 | import pandas as pd 6 | 7 | cwd = os.path.dirname(os.path.abspath(__file__)) 8 | head = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir, os.pardir)) 9 | sys.path.append(head) 10 | 11 | """ 12 | Note: this file assumes that you have already generated all files in 13 | make_hourly_data/static_data and inclusioncrit.sql 14 | TODO: change file source for the files above 15 | 16 | """ 17 | 18 | 19 | class MakeSetsAndNormalise: 20 | 21 | def __init__(self, final_data_var_path, final_data_stat_path, 22 | split_file='sets_split.pkl', new_split_file='sets_split_split_2.pkl'): 23 | self.stat_data_path = final_data_stat_path 24 | self.var_data_path = final_data_var_path 25 | self.split_file = split_file 26 | self.new_split_file = new_split_file 27 | self.cwd = os.path.dirname(os.path.abspath(__file__)) 28 | self.head = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir, os.pardir)) 29 | print('working out of the assumption that head is ', self.head) 30 | 31 | def load_data(self): 32 | self.stat_data = pd.read_csv(self.stat_data_path) 33 | if "Unnamed: 0" in self.stat_data.columns: 34 | self.stat_data.drop(columns="Unnamed: 0") 35 | self.var_data = pd.read_csv(self.var_data_path) 36 | if "Unnamed: 0" in self.var_data.columns: 37 | self.var_data.drop(columns="Unnamed: 0") 38 | 39 | def split(self): 40 | file = os.path.join(self.head, 'data', self.split_file) 41 | all_icus = self.stat_data.icustay_id.unique().tolist() 42 | random.shuffle(all_icus) 43 | no_icus = len(all_icus) 44 | self.sets = { 45 | "train": all_icus[:int(no_icus * 0.8)], 46 | "val": all_icus[int(no_icus * 0.8): int(no_icus * 0.9)], 47 | "test": all_icus[int(no_icus * 0.9):] 48 | } 49 | f = open(file, "wb") 50 | pickle.dump(self.sets, f) 51 | f.close() 52 | 53 | def load_split(self): 54 | file = os.path.join(self.head, 'data', self.split_file) 55 | with open(file, "rb") as f: 56 | self.sets = pickle.load(f) 57 | 58 | def new_splits(self): 59 | all_non_test_icus = self.sets['train'] + self.sets['val'] 60 | random.shuffle(all_non_test_icus) 61 | self.new_sets = { 62 | "train": all_non_test_icus[:len(self.sets['train'])], 63 | "val": all_non_test_icus[len(self.sets['train']): len(self.sets['train']) + len(self.sets['val'])], 64 | "test": self.sets['test'] 65 | } 66 | file = os.path.join(self.head, 'data', self.new_split_file) 67 | f = open(file, "wb") 68 | pickle.dump(self.new_sets, f) 69 | f.close() 70 | self.sets = self.new_sets 71 | 72 | 73 | def normalise(self, file_name=None): 74 | # first re-order columns if needed 75 | cols = list(self.var_data.columns) 76 | start_col = ['label', 'icustay_id', 'chart_time', 'subject_id', 'sepsis_target'] 77 | if cols[:5] != start_col: 78 | for col in start_col: 79 | cols.remove(col) 80 | cols = start_col + cols 81 | self.var_data = self.var_data[cols] 82 | mean = self.var_data.loc[(self.var_data.icustay_id.isin(self.sets["train"])), cols[5:]].mean(axis=0) 83 | std = (self.var_data.loc[(self.var_data.icustay_id.isin(self.sets["train"])), 
cols[5:]] - mean).std(axis=0) 84 | self.var_data[cols[5:]] = (self.var_data[cols[5:]] - mean) / std 85 | if file_name is not None: 86 | mean = self.stat_data.loc[(self.stat_data.icustay_id.isin(self.sets["train"])), "admission_age"].mean() 87 | std = (self.stat_data.loc[(self.stat_data.icustay_id.isin(self.sets["train"])), "admission_age"] 88 | - mean).std() 89 | self.stat_data.admission_age = (self.stat_data.admission_age - mean) / std 90 | 91 | 92 | def save(self, file_name=None): 93 | path = os.path.join(self.head, 'data') 94 | sets_names = ["train", "val", "test"] 95 | if file_name is None: 96 | full_static = "full_static.csv" 97 | full_labvitals = "full_labvitals.csv" 98 | else: 99 | full_static = "full_static_{}.csv".format(file_name) 100 | full_labvitals = "full_labvitals_{}.csv".format(file_name) 101 | for set in sets_names: 102 | if not os.path.exists(path + set): 103 | os.makedirs(path + set) 104 | self.stat_data[self.stat_data.icustay_id.isin(self.sets[set])].to_csv(os.path.join(path, set, full_static)) 105 | self.var_data[self.var_data.icustay_id.isin(self.sets[set])].to_csv(os.path.join(path, set, full_labvitals)) 106 | 107 | 108 | -------------------------------------------------------------------------------- /src/models/GP_attTCN_ablations.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import tensorflow as tf 4 | import numpy as np 5 | # appending head path 6 | cwd = os.path.dirname(os.path.abspath(__file__)) 7 | head = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir, os.pardir)) 8 | sys.path.append(head) 9 | from src.models.GP_attTCN import GPattTCN 10 | from src.models.attTCN_alpha import AttTCN_alpha 11 | from src.models.attTCN_beta import AttTCN_beta 12 | 13 | class GPattTCN_alpha(GPattTCN): 14 | def __init__(self, 15 | time_window, 16 | n_mc_samples, 17 | n_features, 18 | n_stat_features, 19 | kernel='OU', 20 | len_mode='avg', 21 | len_trainable=True, 22 | log_noise_mean=-2, 23 | log_noise_std=0.1, 24 | log_length_mean=1, 25 | log_length_std=0.1, 26 | method_name='chol', 27 | add_diag=0.001, 28 | L2reg=None, 29 | DO=None, 30 | save_path=head, 31 | num_layers=4, 32 | kernel_size=2, 33 | stride=1, 34 | sigmoid_beta=False, 35 | moor_data=False 36 | ): 37 | super().__init__(time_window, 38 | n_mc_samples, 39 | n_features, 40 | n_stat_features, 41 | kernel, 42 | len_mode, 43 | len_trainable, 44 | log_noise_mean, 45 | log_noise_std, 46 | log_length_mean, 47 | log_length_std, 48 | method_name, 49 | add_diag, 50 | L2reg, 51 | DO, 52 | save_path, 53 | num_layers, 54 | kernel_size, 55 | stride, 56 | sigmoid_beta, 57 | moor_data) 58 | self.attTCN = AttTCN_alpha(time_window, 59 | n_features + n_stat_features, 60 | num_layers, 61 | DO, 62 | L2reg, 63 | kernel_size=kernel_size, 64 | stride=stride, 65 | ) 66 | 67 | self.trainable_variables = self.GP.trainable_variables + \ 68 | self.attTCN.trainable_variables 69 | 70 | 71 | class GPattTCN_beta(GPattTCN): 72 | def __init__(self, 73 | time_window, 74 | n_mc_samples, 75 | n_features, 76 | n_stat_features, 77 | kernel='OU', 78 | len_mode='avg', 79 | len_trainable=True, 80 | log_noise_mean=-2, 81 | log_noise_std=0.1, 82 | log_length_mean=1, 83 | log_length_std=0.1, 84 | method_name='chol', 85 | add_diag=0.001, 86 | L2reg=None, 87 | DO=None, 88 | save_path=head, 89 | num_layers=4, 90 | kernel_size=2, 91 | stride=1, 92 | sigmoid_beta=False, 93 | moor_data=False 94 | ): 95 | super().__init__(time_window, 96 | n_mc_samples, 97 | n_features, 98 | n_stat_features, 99 
| kernel, 100 | len_mode, 101 | len_trainable, 102 | log_noise_mean, 103 | log_noise_std, 104 | log_length_mean, 105 | log_length_std, 106 | method_name, 107 | add_diag, 108 | L2reg, 109 | DO, 110 | save_path, 111 | num_layers, 112 | kernel_size, 113 | stride, 114 | sigmoid_beta, 115 | moor_data) 116 | self.attTCN = AttTCN_beta(time_window, 117 | n_features + n_stat_features, 118 | num_layers, 119 | DO, 120 | L2reg, 121 | kernel_size=kernel_size, 122 | stride=stride, 123 | sigmoid_beta=sigmoid_beta 124 | ) 125 | 126 | self.trainable_variables = self.GP.trainable_variables + \ 127 | self.attTCN.trainable_variables -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SOFA/cardiovascular/cardio_SOFA.sql: -------------------------------------------------------------------------------- 1 | DROP MATERIALIZED VIEW IF EXISTS SOFA_cardio CASCADE; 2 | CREATE materialized VIEW SOFA_cardio AS 3 | 4 | -- calculate weight 5 | with wt as ( 6 | select ha.hadm_id 7 | -- average weight 8 | , avg( 9 | case 10 | -- kg 11 | when itemid in (762, 763, 3723, 3580, 226512) 12 | then valuenum 13 | -- lbs 14 | when itemid in (3581) 15 | then valuenum * 0.45359237 16 | -- oz 17 | when itemid IN (3582) 18 | then valuenum * 0.0283495231 19 | else null 20 | end) as weight 21 | 22 | from admissions ha 23 | left join chartevents c 24 | on ha.hadm_id = c.hadm_id 25 | where valuenum is not null 26 | and valuenum != 0 27 | and itemid in 28 | ( 29 | -- cv 30 | 762, 763, 3723, 3580, -- Weight Kg 31 | 3581, -- Weight lb 32 | 3582, -- Weight oz 33 | -- mv 34 | 226512 -- Weight Kg 35 | ) 36 | -- note that the timeframe below assumes weight does not change as a funrion of time 37 | -- verified from the fact that mv only has an admission weight and no dynamic weighting 38 | and charttime between ha.admittime - interval '2' month and ha.dischtime 39 | -- some rows are marked as error, let's inglore them 40 | and c.error is distinct from 1 41 | group by ha.hadm_id 42 | ) 43 | -- calculate weight indirectly through echography weight 44 | , echo2 as ( 45 | select ha.hadm_id, avg(weight * 0.45359237) as weight 46 | from admissions ha 47 | left join echodata echo 48 | on ha.hadm_id = echo.hadm_id 49 | and echo.charttime > ha.admittime - interval '7' day 50 | and echo.charttime < ha.dischtime 51 | group by ha.hadm_id 52 | ) 53 | -- calculate rates for carevue 54 | , vaso_cv as ( 55 | select ha.hadm_id 56 | , max( case 57 | when itemid = 30047 then rate / coalesce(wt.weight, ec.weight) 58 | when itemid = 30120 then rate 59 | else null end ) as rate_norepinephrine 60 | , max( case 61 | when itemid = 30044 then rate / coalesce(wt.weight, ec.weight) 62 | when itemid in (30119, 30309) then rate 63 | else null end ) as rate_epinephrine 64 | , max( case when itemid in (30043, 30307) then rate end) as rate_dopamine 65 | , max( case when itemid in (30042, 30306) then rate end) as rate_dobutamine 66 | , (date_part('year', age(cv.charttime, ha.admittime ))*365 * 24 67 | + date_part('month', age(cv.charttime, ha.admittime ))*365/12 * 24 68 | + date_part('day', age(cv.charttime, ha.admittime ))* 24 69 | + date_part('hour', age(cv.charttime, ha.admittime )) 70 | + round(date_part('minute', age(cv.charttime, ha.admittime ))/60)) as HLOS 71 | 72 | from admissions ha 73 | inner join inputevents_cv cv 74 | on ha.hadm_id = cv.hadm_id 75 | and cv.charttime between ha.admittime - interval '1' day and ha.dischtime 76 | left join wt 77 | on ha.hadm_id = wt.hadm_id 78 | 
left join echo2 ec 79 | on ha.hadm_id = ec.hadm_id 80 | where itemid in (30047,30120,30044,30119,30309,30043,30307,30042,30306) 81 | and rate is not null 82 | group by ha.hadm_id, HLOS 83 | ) 84 | -- calculate rates for metavision 85 | , vaso_mv as ( 86 | select ha.hadm_id 87 | , max(case when itemid = 221906 then rate end) as rate_norepinephrine 88 | , max(case when itemid = 221289 then rate end) as rate_epinephrine 89 | , max(case when itemid = 221662 then rate end) as rate_dopamine 90 | , max(case when itemid = 221653 then rate end) as rate_dobutamine 91 | , (date_part('year', age(mv.starttime, ha.admittime))*365 * 24 92 | + date_part('month', age(mv.starttime, ha.admittime))*365/12 * 24 93 | + date_part('day', age(mv.starttime, ha.admittime))* 24 94 | + date_part('hour', age(mv.starttime, ha.admittime)) 95 | + round(date_part('minute', age(mv.starttime, ha.admittime))/60)) as HLOS 96 | from admissions ha 97 | inner join inputevents_mv mv 98 | on ha.hadm_id = mv.hadm_id 99 | and mv.starttime between ha.admittime - interval '1' day and ha.dischtime 100 | where itemid in (221906,221289,221662,221653) 101 | -- 'Rewritten' orders are not delivered to the patient 102 | and statusdescription != 'Rewritten' 103 | group by ha.hadm_id , HLOS 104 | ) 105 | -- join everything 106 | select 107 | ha.hadm_id, vaso.hlos 108 | , coalesce(cv.rate_norepinephrine, mv.rate_norepinephrine) as rate_norepinephrine 109 | , coalesce(cv.rate_epinephrine, mv.rate_epinephrine) as rate_epinephrine 110 | , coalesce(cv.rate_dopamine, mv.rate_dopamine) as rate_dopamine 111 | , coalesce(cv.rate_dobutamine, mv.rate_dobutamine) as rate_dobutamine 112 | from admissions ha 113 | left join ( 114 | select hadm_id, hlos from vaso_cv union 115 | select hadm_id, hlos from vaso_mv 116 | ) as vaso 117 | on ha.hadm_id = vaso.hadm_id 118 | left join vaso_cv cv 119 | on vaso.hadm_id = cv.hadm_id 120 | and vaso.hlos = cv.hlos 121 | left join vaso_mv mv 122 | on vaso.hadm_id = mv.hadm_id 123 | and vaso.hlos = mv.hlos 124 | -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SOFA/central_nervous_system/gcsperhour.sql: -------------------------------------------------------------------------------- 1 | -- ITEMIDs used: 2 | 3 | -- CAREVUE 4 | -- 723 as GCSVerbal 5 | -- 454 as GCSMotor 6 | -- 184 as GCSEyes 7 | 8 | -- METAVISION 9 | -- 223900 GCS - Verbal Response 10 | -- 223901 GCS - Motor Response 11 | -- 220739 GCS - Eye Opening 12 | 13 | -- The code combines the ITEMIDs into the carevue itemids, then pivots those 14 | -- So 223900 is changed to 723, then the ITEMID 723 is pivoted to form GCSVerbal 15 | 16 | -- Note: 17 | -- The GCS for sedated patients is defaulted to 15 in this code. 18 | -- This is in line with how the data is meant to be collected. 19 | -- e.g., from the SAPS II publication: 20 | -- For sedated patients, the Glasgow Coma Score before sedation was used. 21 | -- This was ascertained either from interviewing the physician who ordered the sedation, 22 | -- or by reviewing the patient's medical record. 
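-- In short, the view defined below pivots the CareVue and Metavision GCS components per
-- charttime, codes an endotracheal/intubated verbal response as 0 and defaults the
-- resulting GCS to 15, imputes missing components from the previous measurement within
-- 6 hours (or the normal values 6/5/4), and finally keeps the minimum GCS per hour of
-- stay (HLOS, hours since hospital admission).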
23 | 24 | DROP MATERIALIZED VIEW IF EXISTS SOFA_gcsperhour CASCADE; 25 | create materialized view SOFA_gcsperhour as 26 | with base as 27 | ( 28 | SELECT pvt.hadm_id 29 | , pvt.charttime 30 | 31 | -- Easier names - note we coalesced Metavision and CareVue IDs below 32 | , max(case when pvt.itemid = 454 then pvt.valuenum else null end) as GCSMotor 33 | , max(case when pvt.itemid = 723 then pvt.valuenum else null end) as GCSVerbal 34 | , max(case when pvt.itemid = 184 then pvt.valuenum else null end) as GCSEyes 35 | 36 | -- If verbal was set to 0 in the below select, then this is an intubated patient 37 | , case 38 | when max(case when pvt.itemid = 723 then pvt.valuenum else null end) = 0 39 | then 1 40 | else 0 41 | end as EndoTrachFlag 42 | 43 | , ROW_NUMBER () 44 | OVER (PARTITION BY pvt.hadm_id ORDER BY pvt.charttime ASC) as rn 45 | 46 | FROM ( 47 | select l.hadm_id 48 | -- merge the ITEMIDs so that the pivot applies to both metavision/carevue data 49 | , case 50 | when l.ITEMID in (723,223900) then 723 51 | when l.ITEMID in (454,223901) then 454 52 | when l.ITEMID in (184,220739) then 184 53 | else l.ITEMID end 54 | as ITEMID 55 | 56 | -- convert the data into a number, reserving a value of 0 for ET/Trach 57 | , case 58 | -- endotrach/vent is assigned a value of 0, later parsed specially 59 | when l.ITEMID = 723 and l.VALUE = '1.0 ET/Trach' then 0 -- carevue 60 | when l.ITEMID = 223900 and l.VALUE = 'No Response-ETT' then 0 -- metavision 61 | 62 | else VALUENUM 63 | end 64 | as VALUENUM 65 | , l.CHARTTIME 66 | 67 | from CHARTEVENTS l 68 | 69 | -- get intime for charttime subselection 70 | inner join admissions ha 71 | on l.hadm_id = ha.hadm_id 72 | 73 | -- Isolate the desired GCS variables 74 | where l.ITEMID in 75 | ( 76 | -- 198 -- GCS 77 | -- GCS components, CareVue 78 | 184, 454, 723 79 | -- GCS components, Metavision 80 | , 223900, 223901, 220739 81 | ) 82 | -- Only get data for the first 24 hours 83 | and l.charttime between (ha.admittime - interval '1' day) AND ha.dischtime 84 | -- exclude rows marked as error 85 | and l.error IS DISTINCT FROM 1 86 | ) pvt 87 | group by pvt.hadm_id, pvt.charttime 88 | ) 89 | , gcs as ( 90 | select b.* 91 | , b2.GCSVerbal as GCSVerbalPrev 92 | , b2.GCSMotor as GCSMotorPrev 93 | , b2.GCSEyes as GCSEyesPrev 94 | -- Calculate GCS, factoring in special case when they are intubated and prev vals 95 | -- note that the coalesce are used to implement the following if: 96 | -- if current value exists, use it 97 | -- if previous value exists, use it 98 | -- otherwise, default to normal 99 | , case 100 | -- replace GCS during sedation with 15 101 | when b.GCSVerbal = 0 102 | then 15 103 | when b.GCSVerbal is null and b2.GCSVerbal = 0 104 | then 15 105 | -- if previously they were intub, but they aren't now, do not use previous GCS values 106 | when b2.GCSVerbal = 0 107 | then 108 | coalesce(b.GCSMotor,6) 109 | + coalesce(b.GCSVerbal,5) 110 | + coalesce(b.GCSEyes,4) 111 | -- otherwise, add up score normally, imputing previous value if none available at current time 112 | else 113 | coalesce(b.GCSMotor,coalesce(b2.GCSMotor,6)) 114 | + coalesce(b.GCSVerbal,coalesce(b2.GCSVerbal,5)) 115 | + coalesce(b.GCSEyes,coalesce(b2.GCSEyes,4)) 116 | end as GCS 117 | 118 | from base b 119 | -- join to itself within 6 hours to get previous value 120 | left join base b2 121 | on b.hadm_id = b2.hadm_id and b.rn = b2.rn+1 and b2.charttime > b.charttime - interval '6' hour 122 | 123 | ) 124 | select ha.hadm_id 125 | -- The minimum GCS is determined by the above row partition, we 
only join if IsMinGCS=1 126 | , min(GCS) as GCS 127 | , (date_part('year', age(gs.charttime, ha.admittime))*365 * 24 128 | + date_part('month', age(gs.charttime, ha.admittime))*365/12 * 24 129 | + date_part('day', age(gs.charttime, ha.admittime))* 24 130 | + date_part('hour', age(gs.charttime, ha.admittime)) 131 | + round(date_part('minute', age(gs.charttime, ha.admittime))/60)) as HLOS 132 | from admissions ha 133 | left join gcs gs 134 | on ha.hadm_id = gs.hadm_id 135 | group by ha.hadm_id, HLOS 136 | ORDER BY ha.hadm_id, HLOS; -------------------------------------------------------------------------------- /src/mains/ablation_0.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from datetime import datetime 4 | import numpy as np 5 | import tensorflow as tf 6 | import pickle 7 | cwd = os.path.dirname(os.path.abspath(__file__)) 8 | head = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir)) 9 | sys.path.append(head) 10 | from src.utils.debug import t_print 11 | from src.data_loader.loader import DataGenerator 12 | from src.models.GP_logreg import GPLogReg 13 | from src.trainers.trainer import Trainer 14 | 15 | 16 | def main( 17 | # data 18 | max_no_dtpts, 19 | min_no_dtpts, 20 | time_window, 21 | n_features, 22 | n_stat_features, 23 | features, 24 | late_patients_only, 25 | horizon0, 26 | # model 27 | model_choice, 28 | # MGP 29 | no_mc_samples, 30 | kernel_choice, 31 | L2reg, 32 | # training 33 | learning_rate, 34 | batch_size, 35 | num_epochs,): 36 | # generate save path 37 | logdir = os.path.join("logs/ablation_0", datetime.now().strftime("%Y%m%d-%H%M%S")) 38 | if not os.path.isdir(logdir): 39 | os.mkdir(logdir) 40 | Dict = { 41 | # data 42 | "max_no_dtpts":max_no_dtpts, 43 | "min_no_dtpts":min_no_dtpts, 44 | "time_window": time_window, 45 | "n_features": n_features, 46 | "n_stat_features": n_stat_features, 47 | "features": features, 48 | "late_patients_only": late_patients_only, 49 | "horizon0": horizon0, 50 | # model 51 | "model_choice": model_choice, 52 | # MGP 53 | "no_mc_samples": no_mc_samples, 54 | "kernel_choice": kernel_choice, 55 | "L2reg": L2reg, 56 | # training 57 | "learning_rate": learning_rate, 58 | "batch_size": batch_size, 59 | "num_epochs": num_epochs, 60 | } 61 | with open(os.path.join(logdir, 'hyperparam.pkl'), "wb") as f: 62 | pickle.dump(Dict, f) 63 | 64 | summary_writers = {'train': tf.summary.create_file_writer(os.path.join(logdir, 'train')), 65 | 'val': tf.summary.create_file_writer(os.path.join(logdir, 'val'))} 66 | # Load data 67 | data = DataGenerator(no_mc_samples=no_mc_samples, 68 | max_no_dtpts=max_no_dtpts, 69 | min_no_dtpts=min_no_dtpts, 70 | batch_size=batch_size, 71 | fast_load=False, 72 | to_save=True, 73 | debug=True, 74 | fixed_idx_per_class=False, 75 | features=features) 76 | 77 | t_print("main - generate model and optimiser") 78 | optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) 79 | global_step = tf.Variable(0) 80 | 81 | # Load model 82 | model = GPLogReg(time_window, 83 | no_mc_samples, 84 | n_features, 85 | n_stat_features, 86 | L2reg=L2reg) 87 | 88 | # Initialise trainer 89 | trainer = Trainer(model=model, 90 | data=data, 91 | num_epochs=num_epochs, 92 | batch_size=batch_size, 93 | optimizer=optimizer, 94 | global_step=global_step, 95 | summary_writers=summary_writers, 96 | log_path=logdir, 97 | train_only=False, 98 | notebook_friendly=False, 99 | eval_every=20, 100 | late_patients_only=late_patients_only, 101 | horizon0=horizon0,) 102 | 103 | # train 
model 104 | trainer.run() 105 | 106 | 107 | if __name__=="__main__": 108 | tf.random.set_seed(1237) 109 | np.random.seed(1237) 110 | # data 111 | max_no_dtpts = 250 # chopping 4.6% of data at 250 112 | min_no_dtpts = 40 # helping with covariance singularity 113 | time_window = 25 # fixed 114 | n_features = 24 # old data: 44 115 | n_stat_features = 8 # old data: 35 116 | features = 'mr_features_mm_labels' 117 | n_features= 17 118 | n_stat_features= 8 119 | features = None 120 | late_patients_only = False 121 | horizon0 = False 122 | 123 | # model 124 | model_choice = 'Att' # ['Att', 'Moor'] 125 | 126 | # MGP 127 | no_mc_samples = 10 128 | kernel_choice = 'OU' 129 | 130 | # training 131 | learning_rate = 0.0005 132 | batch_size = 128 133 | num_epochs = 100 134 | 135 | learning_rate = np.random.uniform(10e-6, high=10e-4, size=None) 136 | no_mc_samples = np.random.randint(8, high=20, size=None, dtype='l') 137 | L2reg = [10**float(np.random.randint(-5, high=8, size=None, dtype='l'))] * 5 138 | 139 | 140 | main( 141 | # data 142 | max_no_dtpts, 143 | min_no_dtpts, 144 | time_window, 145 | n_features, 146 | n_stat_features, 147 | features, 148 | late_patients_only, 149 | horizon0, 150 | # model 151 | model_choice, 152 | # MGP 153 | no_mc_samples, 154 | kernel_choice, 155 | L2reg, 156 | # training 157 | learning_rate, 158 | batch_size, 159 | num_epochs,) 160 | -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SOFA/respiration/bloodgasfirstdayarterial.sql: -------------------------------------------------------------------------------- 1 | DROP MATERIALIZED VIEW IF EXISTS resp_bloodgasfirstdayarterial CASCADE; 2 | CREATE MATERIALIZED VIEW resp_bloodgasfirstdayarterial AS 3 | -- export from chartevents SoP2 data 4 | with stg_spo2 as 5 | ( 6 | select SUBJECT_ID, HADM_ID, CHARTTIME 7 | -- max here is just used to group SpO2 by charttime 8 | , max(case when valuenum <= 0 or valuenum > 100 then null else valuenum end) as SpO2 9 | from CHARTEVENTS 10 | -- o2 sat 11 | where ITEMID in 12 | ( 13 | 646 -- SpO2 14 | , 220277 -- O2 saturation pulseoxymetry 15 | ) 16 | group by SUBJECT_ID, HADM_ID, CHARTTIME 17 | ) 18 | -- export from chartevents FiO2 data 19 | , stg_fio2 as 20 | ( 21 | select SUBJECT_ID, HADM_ID, CHARTTIME 22 | -- pre-process the FiO2s to ensure they are between 21-100% 23 | , max( 24 | case 25 | when itemid = 223835 26 | then case 27 | when valuenum > 0 and valuenum <= 1 28 | then valuenum * 100 29 | -- improperly input data - looks like O2 flow in litres 30 | when valuenum > 1 and valuenum < 21 31 | then null 32 | when valuenum >= 21 and valuenum <= 100 33 | then valuenum 34 | else null end -- unphysiological 35 | when itemid in (3420, 3422) 36 | -- all these values are well formatted 37 | then valuenum 38 | when itemid = 190 and valuenum > 0.20 and valuenum < 1 39 | -- well formatted but not in % 40 | then valuenum * 100 41 | else null end 42 | ) as fio2_chartevents -- keep 43 | from CHARTEVENTS 44 | where ITEMID in 45 | ( 46 | 3420 -- FiO2 47 | , 190 -- FiO2 set 48 | , 223835 -- Inspired O2 Fraction (FiO2) 49 | , 3422 -- FiO2 [measured] 50 | ) 51 | -- exclude rows marked as error 52 | and error IS DISTINCT FROM 1 53 | group by SUBJECT_ID, HADM_ID, CHARTTIME 54 | ) 55 | -- extract first time SpO2 is recorded / sampled 56 | , stg2 as 57 | ( 58 | select bg.* 59 | , ROW_NUMBER() OVER (partition by bg.hadm_id, bg.charttime order by s1.charttime DESC) as lastRowSpO2 -- keep 60 | , s1.spo2 -- keep 61 | 
from resp_bloodgasfirstday bg 62 | left join stg_spo2 s1 63 | -- same patient 64 | on bg.hadm_id = s1.hadm_id 65 | -- spo2 occurred at most 2 hours before this blood gas 66 | and s1.charttime between bg.charttime - interval '2' hour and bg.charttime 67 | where bg.po2 is not null 68 | ) 69 | -- extract first time FiO2 is recorded / sampled + specimen prediction (?) 70 | , stg3 as 71 | ( 72 | select bg.subject_id, bg.hadm_id, bg.charttime 73 | , bg.SPECIMEN 74 | , bg.PO2 75 | , ROW_NUMBER() OVER (partition by bg.hadm_id, bg.charttime order by greatest(bg2.charttime, s2.charttime) DESC) as lastRowFiO2 -- KEEP 76 | , case 77 | when coalesce(bg2.charttime, s2.charttime) is null then null 78 | when bg2.charttime is null then s2.fio2_chartevents 79 | when s2.charttime is null then bg2.FIO2 80 | when bg2.charttime >= s2.charttime then coalesce(bg2.FIO2, s2.fio2_chartevents) 81 | else coalesce(s2.fio2_chartevents, bg2.FIO2) end 82 | as FIO2_val 83 | 84 | -- create our specimen prediction 85 | -- data conditioned on this for some reason 86 | , 1/(1+exp(-(-0.02544 87 | + 0.04598 * bg.po2 88 | + coalesce(-0.15356 * bg.spo2 , -0.15356 * 97.49420 + 0.13429) 89 | + coalesce( 0.00621 * fio2_chartevents , 0.00621 * 51.49550 + -0.24958) 90 | + coalesce( 0.10559 * bg.hemoglobin , 0.10559 * 10.32307 + 0.05954) 91 | + coalesce( 0.13251 * bg.so2 , 0.13251 * 93.66539 + -0.23172) 92 | + coalesce(-0.01511 * bg.pco2 , -0.01511 * 42.08866 + -0.01630) 93 | + coalesce( 0.01480 * bg.fio2 , 0.01480 * 63.97836 + -0.31142) 94 | + coalesce(-0.00200 * bg.aado2 , -0.00200 * 442.21186 + -0.01328) 95 | + coalesce(-0.03220 * bg.bicarbonate , -0.03220 * 22.96894 + -0.06535) 96 | + coalesce( 0.05384 * bg.totalco2 , 0.05384 * 24.72632 + -0.01405) 97 | + coalesce( 0.08202 * bg.lactate , 0.08202 * 3.06436 + 0.06038) 98 | + coalesce( 0.10956 * bg.ph , 0.10956 * 7.36233 + -0.00617) 99 | + coalesce( 0.00848 * bg.o2flow , 0.00848 * 7.59362 + -0.35803) 100 | ))) as SPECIMEN_PROB -- keep 101 | from stg2 bg 102 | left join stg_fio2 s2 103 | -- same patient 104 | on bg.hadm_id = s2.hadm_id 105 | -- fio2 occurred at most 4 hours before this blood gas 106 | and s2.charttime between bg.charttime - interval '4' hour and bg.charttime 107 | left join stg2 bg2 108 | -- same patient 109 | on bg.hadm_id = bg2.hadm_id 110 | -- fio2 occurred at most 4 hours before this blood gas 111 | and bg2.charttime between bg.charttime - interval '4' hour and bg.charttime 112 | where bg.lastRowSpO2 = 1 -- only the row with the most recent SpO2 (if no SpO2 found lastRowSpO2 = 1) 113 | ) 114 | -- calculate PaO2FiO2 115 | select subject_id, hadm_id, charttime 116 | , case 117 | when PO2 is not null and FIO2_val is not null 118 | -- multiply by 100 because FiO2 is in a % but should be a fraction 119 | then 100*PO2/FIO2_val 120 | else null 121 | end as PaO2FiO2 122 | from stg3 123 | where lastRowFiO2 = 1 -- only the most recent FiO2 124 | -- restrict it to *only* arterial samples 125 | and (SPECIMEN = 'ART' or SPECIMEN_PROB > 0.75) 126 | order by hadm_id, charttime; -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_features/match_controls.py: -------------------------------------------------------------------------------- 1 | ''' 2 | --------------------------------- 3 | Author: Michael Moor, 09.10.2018 4 | --------------------------------- 5 | 6 | Summary: Script to match controls to cases based on icustay_ids and case/control ratio. 
7 | Input: it takes a 8 | - cases csv, that links case icustay_id to sepsis_onset_hour (relative time after icu-intime in hours) 9 | - control csv listing control icustay_ids and corresponding icu-intime 10 | Output: 11 | - matched_controls.csv that list controls the following way: 12 | icustay_id, control_onset_time, control_onset_hour, matched_case_icustay_id 13 | Detailed Description: 14 | 1. Load Input files 15 | 2. Determine Control vs Case ratio p (e.g. 10/1) 16 | 3. Loop: 17 | For each case: 18 | randomly select (without repe) p controls as matched_controls 19 | For each selected control: 20 | append to result df: icustay_id, control_onset_time, control_onset_hour, matched_case_icustay_id 21 | (icustay_id of this control, the cases sepsis_onset_hour as control_onset_hour and the absolute time as control_onset_time, and the matched_case_icustay_id) 22 | 4. return result df as output 23 | ''' 24 | 25 | import os 26 | import time 27 | from datetime import datetime 28 | import numpy as np 29 | import pandas as pd 30 | cwd = os.path.dirname(os.path.abspath(__file__)) 31 | head = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir, os.pardir, os.pardir)) 32 | 33 | 34 | def match_controls(): 35 | 36 | np.random.seed(42) 37 | result = pd.DataFrame() 38 | 39 | # -------------------- 40 | # 1. Load Input files 41 | # -------------------- 42 | 43 | casepath = os.path.join(head, 'data', 'interim', 'q13_cases_hourly_ex1c.csv') 44 | controlpath = os.path.join(head, 'data', 'interim', 'q13_controls_hourly.csv') 45 | outpath = os.path.join(head, 'data', 'interim', 'q13_matched_controls.csv') 46 | 47 | start = time.time() # Get current time 48 | 49 | cases = pd.read_csv(casepath) # read file to pd.dataframe 50 | unnamed = "Unnamed: 0" 51 | if unnamed in cases.columns: 52 | cases = cases.drop(columns="Unnamed: 0").drop_duplicates() 53 | else: 54 | cases = cases.drop_duplicates() 55 | 56 | controls = pd.read_csv(controlpath) # read file to pd.dataframe, it has many duplicates (why?), remove them 57 | if unnamed in controls.columns: 58 | controls = controls.drop(columns="Unnamed: 0").drop_duplicates() # drop duplicate rows 59 | else: 60 | controls = controls.drop_duplicates() # drop duplicate rows 61 | controls = controls.reset_index(drop=True) # resetting row index for aesthetic reasons 62 | 63 | controls['intime'] = controls['intime'].apply( 64 | lambda x: datetime.strptime(x, "%Y-%m-%d %H:%M:%S")) # convert time string to datetime object 65 | 66 | case_ids = cases.sort_values(by="sepsis_onset_hour")['icustay_id'].unique() # get unique ids 67 | control_ids = controls['icustay_id'].unique() 68 | 69 | # -------------------------------- 70 | # 2. Determine Control/Case Ratio 71 | # -------------------------------- 72 | 73 | ratio = len(control_ids) / float(len(case_ids)) 74 | rf = int(np.floor(ratio)) # rf is the ratio floored, to receive the largest viable integer ratio 75 | 76 | # --------------------------------------------- 77 | # 3. 
For each case match 'ratio-many' controls 78 | # --------------------------------------------- 79 | 80 | controls_s = controls.sort_values( 81 | by="length_of_stay") # Shuffle controls dataframe rows, for random control selection 82 | 83 | for i, case_id in enumerate(case_ids): 84 | matched_controls = controls_s[ 85 | int(i * ratio):int( 86 | ratio * (i + 1))] # select the next batch of controls to match to current case 87 | matched_controls = matched_controls.drop(columns=['delta_score', 'sepsis_onset']) # drop unnecessary cols 88 | 89 | onset_hour = float( 90 | cases[cases['icustay_id'] == case_id]['sepsis_onset_hour']) # get float of current case onset hour 91 | 92 | matched_controls[ 93 | 'control_onset_hour'] = onset_hour # use sepsis_onset_hour of current case as control_onset_hour 94 | matched_controls['control_onset_time'] = matched_controls['intime'] + pd.Timedelta( 95 | hours=onset_hour) # compute control_onset time w.r.t. control icu-intime 96 | matched_controls[ 97 | 'matched_case_icustay_id'] = case_id # so that each matched control can be mapped back to its matched case 98 | 99 | result = result.append(matched_controls, ignore_index=True) 100 | 101 | # --------------------------------------------------------------------- 102 | # 4. Return matched controls: here, write to csv (as next step in sql) 103 | # --------------------------------------------------------------------- 104 | 105 | result.to_csv(outpath, sep=',', index=False) # write to csv, but without row indices 106 | print('Matching Controls RUNTIME: {} seconds'.format(time.time() - start)) 107 | 108 | print('Number of Cases: {}'.format(len(case_ids))) 109 | print('Number of Controls: {}'.format(len(control_ids))) 110 | print('Number of Matched Controls: {}'.format(len(result))) 111 | print('Matching Ratio: {}, floored: {}'.format(ratio, rf)) 112 | 113 | -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SI/abx_micro_poe.sql: -------------------------------------------------------------------------------- 1 | /* 2 | - MODIFIED VERSION 3 | - SOURCE: https://github.com/alistairewj/sepsis3-mimic/blob/master/query/tbls/abx-micro-prescription.sql 4 | - DOWNLOADED on 8th February 2018 5 | */ 6 | 7 | -- only works for metavision as carevue does not accurately document antibiotics 8 | DROP TABLE IF EXISTS abx_micro_poe CASCADE; 9 | CREATE TABLE abx_micro_poe as 10 | -- mv tells us how many antibiotics were prescribed and when 11 | with mv as 12 | ( 13 | select hadm_id 14 | , count(mv.drug) as no_antibiotic 15 | , startdate as antibiotic_time 16 | from prescriptions mv 17 | inner join abx_poe_list ab 18 | on mv.drug = ab.drug 19 | group by hadm_id, antibiotic_time 20 | ) 21 | -- me tells us when cultures were taken 22 | , me as 23 | ( 24 | select hadm_id 25 | , chartdate, charttime 26 | , spec_type_desc 27 | -- , max(case when org_name is not null and org_name != '' then 1 else 0 end) as PositiveCulture 28 | from microbiologyevents 29 | group by hadm_id, chartdate, charttime, spec_type_desc 30 | ) 31 | -- ab_fnl checks whether a culture was taken either 72h prior or 24h after administration of antibiotics 32 | -- conditions on there being more than 1 antibiotic administered 33 | -- (see: Sepsis 3 Seymour paper attachment) 34 | , ab_fnl as 35 | ( 36 | select 37 | mv.hadm_id 38 | -- , mv.no_antibiotic 39 | , mv.antibiotic_time 40 | , coalesce(me72.charttime,me72.chartdate) as last72_charttime 41 | , coalesce(me24.charttime,me24.chartdate) 
as next24_charttime 42 | , case when me72.charttime is null then 'date' else 'time' end as last72 43 | , case when me24.charttime is null then 'date' else 'time' end as next24 44 | 45 | --, me72.positiveculture as last72_positiveculture 46 | --, me72.spec_type_desc as last72_specimen 47 | --, me24.positiveculture as next24_positiveculture 48 | --, me24.spec_type_desc as next24_specimen 49 | from mv 50 | -- blood culture in last 72 hours 51 | left join me me72 52 | on mv.hadm_id = me72.hadm_id 53 | and mv.antibiotic_time is not null 54 | and 55 | ( 56 | -- if charttime is available, use it 57 | ( 58 | mv.antibiotic_time >= me72.charttime 59 | and mv.antibiotic_time <= me72.charttime + interval '72' hour 60 | ) 61 | OR 62 | ( 63 | -- if charttime is not available, use chartdate 64 | me72.charttime is null 65 | and mv.antibiotic_time >= me72.chartdate 66 | and mv.antibiotic_time <= me72.chartdate + interval '3' day 67 | ) 68 | ) 69 | -- blood culture in subsequent 24 hours 70 | left join me me24 71 | on mv.hadm_id = me24.hadm_id 72 | and mv.antibiotic_time is not null 73 | -- and me24.charttime is not null -- this probably takes away quite a few options 74 | and 75 | ( 76 | -- if charttime is available, use it 77 | ( 78 | mv.antibiotic_time >= me24.charttime - interval '24' hour 79 | and mv.antibiotic_time <= me24.charttime 80 | ) 81 | OR 82 | ( 83 | -- if charttime is not available, use chartdate 84 | me24.charttime is null 85 | and mv.antibiotic_time >= me24.chartdate - interval '1' day 86 | and mv.antibiotic_time <= me24.chartdate 87 | ) 88 | ) 89 | -- added the 19.09.05 - apparently this happens sometimes 90 | where coalesce(me72.charttime,me72.chartdate,me24.charttime,me24.chartdate) is not null 91 | 92 | -- where no_antibiotic > 1 -- see: https://github.com/alistairewj/sepsis3-mimic/issues/12 93 | ) 94 | , abx_micro_poe_temp as ( 95 | select 96 | hadm_id 97 | -- , antibiotic_name 98 | , antibiotic_time 99 | , last72_charttime 100 | , next24_charttime 101 | 102 | -- suspected_infection flag: redundant with suspected_infection_time 103 | /* 104 | , case 105 | when coalesce(last72_charttime,next24_charttime) is null 106 | then 0 107 | else 1 end as suspected_infection 108 | */ 109 | -- time of suspected infection: either the culture time (if before antibiotic), or the antibiotic time 110 | , case 111 | when coalesce(last72_charttime, next24_charttime) is null 112 | then null 113 | else least(coalesce(last72_charttime, next24_charttime), antibiotic_time) 114 | end as suspected_infection_time 115 | -- to calculate time of SI, we don't care about which specimen was cultured or whether it was a positive culture 116 | /* 117 | -- the specimen that was cultured 118 | , case 119 | when last72_charttime is not null 120 | then last72_specimen 121 | when next24_charttime is not null 122 | then next24_specimen 123 | else null 124 | end as specimen 125 | 126 | -- whether the cultured specimen ended up being positive or not 127 | , case 128 | when last72_charttime is not null 129 | then last72_positiveculture 130 | when next24_charttime is not null 131 | then next24_positiveculture 132 | else null 133 | end as positiveculture 134 | */ 135 | from ab_fnl 136 | ) 137 | select 138 | a.hadm_id, 139 | ad.admittime, 140 | a.antibiotic_time, 141 | extract(EPOCH from a.antibiotic_time - ad.admittime) 142 | / 60.0 / 60.0 as abx_h, 143 | a.last72_charttime, 144 | extract(EPOCH from a.last72_charttime - ad.admittime) 145 | / 60.0 / 60.0 as l72_h, 146 | a.next24_charttime, 147 | extract(EPOCH from 
a.next24_charttime - ad.admittime) 148 | / 60.0 / 60.0 as n24_h, 149 | a.suspected_infection_time, 150 | extract(EPOCH from a.suspected_infection_time - ad.admittime) 151 | / 60.0 / 60.0 as si_h 152 | from abx_micro_poe_temp a 153 | left join admissions ad 154 | on a.hadm_id=ad.hadm_id 155 | 156 | -------------------------------------------------------------------------------- /src/data_preprocessing/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pandas as pd 4 | import argparse 5 | 6 | cwd = os.path.dirname(os.path.abspath(__file__)) 7 | head = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir)) 8 | sys.path.append(head) 9 | 10 | from src.data_preprocessing.extract_MIMIC_data.extract_labels.make_labels import make_labels 11 | from src.data_preprocessing.extract_MIMIC_data.extract_features.make_data import MakeData 12 | from src.data_preprocessing.extract_MIMIC_data.extract_features.match_controls import match_controls 13 | from src.data_preprocessing.features_preprocessing.stepI_data_prep import DataPreprocessing 14 | from src.data_preprocessing.features_preprocessing.stepII_split_sets_n_normalise import MakeSetsAndNormalise 15 | from src.data_preprocessing.features_preprocessing.stepIII_GP_prep import CompactTransform 16 | from src.data_preprocessing.features_preprocessing.stepIV_GP_prep_part_II import GPPreprocessingSecondRound 17 | 18 | def make_dirs(): 19 | data_path = os.path.join(head, 'data') 20 | if not os.path.isdir(data_path): 21 | os.mkdir(data_path) 22 | interim_path = os.path.join(data_path, 'interim') 23 | if not os.path.isdir(interim_path): 24 | os.mkdir(interim_path) 25 | processed_path = os.path.join(data_path, 'processed') 26 | if not os.path.isdir(processed_path): 27 | os.mkdir(processed_path) 28 | 29 | 30 | def main(args): 31 | # create directories 32 | make_dirs() 33 | 34 | # generate sepsis labels 35 | labels = make_labels(args.connect_key, args.path) 36 | labels.generate_SI_data() 37 | labels.generate_SOFA_data() 38 | labels.generate_all_sepsis_onset() 39 | labels.filter_first_sepsis_onset() 40 | labels.save_to_postgres() 41 | labels.generate_sofa_delta_table() 42 | 43 | # generate data to feed in model 44 | data = MakeData(args.connect_key) 45 | data.step1_cohort() 46 | match_controls() 47 | data.step3_match_controls_to_sql() 48 | data.step4_extract_data() 49 | 50 | # merge extracted data, normalise TS length, run basic tests 51 | interim_path = os.path.join(head, 'data', 'interim') 52 | files = ["static_variables.csv", 53 | "static_variables_cases.csv", 54 | "static_variables_controls.csv", 55 | "vital_variables_cases.csv", 56 | "vital_variables_controls.csv", 57 | "lab_variables_cases.csv", 58 | "lab_variables_controls.csv",] 59 | 60 | cas_f = os.path.join(interim_path, files[1]) 61 | cos_f = os.path.join(interim_path, files[2]) 62 | cav_f = os.path.join(interim_path, files[3]) 63 | cov_f = os.path.join(interim_path, files[4]) 64 | cal_f = os.path.join(interim_path, files[5]) 65 | col_f = os.path.join(interim_path, files[6]) 66 | 67 | print("Initialising", flush=True) 68 | first_processing = DataPreprocessing(cas_f, cos_f, cav_f, cov_f, cal_f, col_f,) 69 | print("load_static", flush=True) 70 | first_processing.load_static() 71 | print("load_labs", flush=True) 72 | first_processing.load_labs() 73 | print("load_vitals", flush=True) 74 | first_processing.load_vitals() 75 | print("dropping unnamed columns", flush=True) 76 | first_processing.drop_all_unnamed() 77 | print("get onset 
4 all", flush=True) 78 | # TODO this breaks 79 | first_processing.get_onset_hour() 80 | print("merge l & v", flush=True) 81 | first_processing.merge_labs_vitals() 82 | print("filter", flush=True) 83 | first_processing.filter_time_window() 84 | print("merge ca & co", flush=True) 85 | first_processing.merge_case_control() 86 | print("check ts lengths", flush=True) 87 | first_processing.ts_length_checks() 88 | print("static_prep", flush=True) 89 | first_processing.static_prep() 90 | 91 | # normalise, separate sets 92 | final_data_var_path = os.path.join(head, 'data', 'processed', 'full_labvitals_horizon_0_last.csv') 93 | final_data_stat_path = os.path.join(head, 'data', 'processed', 'full_static.csv') 94 | sets_n_norm = MakeSetsAndNormalise(final_data_var_path, final_data_stat_path) 95 | sets_n_norm.load_data() 96 | try: 97 | sets_n_norm.load_split() 98 | except FileNotFoundError: 99 | sets_n_norm.split() 100 | sets_n_norm.normalise() 101 | sets_n_norm.save() 102 | 103 | # GP models features 104 | for outpath in ['train', 'val', 'test']: 105 | file = os.path.join(head, 'data', outpath, 'full_labvitals.csv') 106 | data = pd.read_csv(file) 107 | onsetfile = os.path.join(head, 'data', 'processed', 'onset_hours.csv') 108 | onset = pd.read_csv(onsetfile) 109 | GP_prepI = CompactTransform(data, onset, outpath) 110 | GP_prepI.calculation() 111 | GP_prepI.save() 112 | GP_prepII = GPPreprocessingSecondRound(outpath) 113 | GP_prepII.load_files() 114 | GP_prepII.discard_useless_files() 115 | GP_prepII.join_files() 116 | GP_prepII.save() 117 | 118 | 119 | def parse_arg(): 120 | parser = argparse.ArgumentParser() 121 | parser.add_argument("-c", "--connect_key", 122 | default="dbname=mimic user=postgres password=postgres host=localhost options=--search_path=mimiciii", 123 | help="key to enter the DB, eg: 'dbname=mimic user=postgres password=postgres options=--search_path=mimiciii'") 124 | parser.add_argument("-p", "--path", default="/cluster/home/mrosnat/MGP-AttTCN", 125 | help="path to data folder - where you would like to save your data") 126 | 127 | 128 | return parser.parse_args() 129 | 130 | 131 | if __name__ == '__main__': 132 | args = parse_arg() 133 | main(args) 134 | -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SOFA/SOFA.sql: -------------------------------------------------------------------------------- 1 | DROP MATERIALIZED VIEW IF EXISTS SOFAperhour CASCADE; 2 | CREATE materialized VIEW SOFAperhour AS 3 | -- get all the data in one place 4 | with scorecomp as ( 5 | select 6 | -- General info 7 | ha.hadm_id 8 | , u.HLOS as HLOS 9 | 10 | -- Respiration 11 | , pf.PaO2FiO2_novent_min 12 | , pf.PaO2FiO2_vent_min 13 | 14 | -- Coagulation 15 | , cl.Platelet 16 | 17 | -- Liver 18 | , ll.Bilirubin 19 | 20 | -- Cardiovascular 21 | , c.rate_norepinephrine 22 | , c.rate_epinephrine 23 | , c.rate_dopamine 24 | , c.rate_dobutamine 25 | , cv.MinBP 26 | 27 | -- Central nervous system 28 | , gcs.GCS 29 | 30 | -- Renal 31 | , uo.running_uo_24h as UrineOutput24h 32 | , rl.Creatinine 33 | 34 | from admissions ha 35 | left join ( 36 | select hadm_id, hlos from SOFA_PaO2FiO2 union 37 | select hadm_id, hlos from coag_labsperhour union 38 | select hadm_id, hlos from liv_labsperhour union 39 | select hadm_id, hlos from SOFA_cardio union 40 | select hadm_id, hlos from cardio_vitalsperhour union 41 | select hadm_id, hlos from SOFA_gcsperhour union 42 | select hadm_id, hlos from SOFA_runninguo24h union 43 | select hadm_id, 
hlos from ren_labsperhour 44 | ) as u 45 | on ha.hadm_id = u.hadm_id 46 | left join SOFA_PaO2FiO2 pf 47 | on u.hadm_id = pf.hadm_id 48 | and u.hlos = pf.hlos 49 | left join coag_labsperhour cl 50 | on u.hadm_id = cl.hadm_id 51 | and u.hlos = cl.hlos 52 | left join liv_labsperhour ll 53 | on u.hadm_id = ll.hadm_id 54 | and u.hlos = ll.hlos 55 | left join SOFA_cardio c 56 | on u.hadm_id = c.hadm_id 57 | and u.hlos = c.hlos 58 | left join cardio_vitalsperhour cv 59 | on u.hadm_id = cv.hadm_id 60 | and u.hlos = cv.hlos 61 | left join SOFA_gcsperhour gcs 62 | on u.hadm_id = gcs.hadm_id 63 | and u.hlos = gcs.hlos 64 | left join SOFA_runninguo24h uo 65 | on u.hadm_id = uo.hadm_id 66 | and u.hlos = uo.hlos 67 | left join ren_labsperhour rl 68 | on u.hadm_id = rl.hadm_id 69 | and u.hlos = rl.hlos 70 | ) 71 | -- calculating all the variables 72 | , SOFA as ( 73 | select hadm_id, HLOS 74 | -- Respiration 75 | , case 76 | when PaO2FiO2_vent_min < 100 then 4 77 | when PaO2FiO2_vent_min < 200 then 3 78 | when coalesce(PaO2FiO2_novent_min, PaO2FiO2_vent_min) < 300 then 2 79 | when coalesce(PaO2FiO2_novent_min, PaO2FiO2_vent_min) < 400 then 1 80 | when coalesce(PaO2FiO2_vent_min, PaO2FiO2_novent_min) is null then null 81 | else 0 end as respiration 82 | -- Coagulation 83 | , case 84 | when Platelet <20 then 4 85 | when Platelet <50 then 3 86 | when Platelet < 100 then 2 87 | when Platelet < 150 then 1 88 | when Platelet is null then null 89 | else 0 end as coagulation 90 | -- Liver 91 | , case 92 | when Bilirubin >= 12.0 then 4 93 | when Bilirubin >= 6.0 then 3 94 | when Bilirubin >= 2.0 then 2 95 | when Bilirubin >= 1.2 then 1 96 | when Bilirubin is null then null 97 | else 0 end as liver 98 | -- Cardiovascular 99 | , case 100 | when rate_dopamine > 15 101 | or rate_epinephrine > 0.1 102 | or rate_norepinephrine > 0.1 103 | then 4 104 | when rate_dopamine > 5 105 | or rate_epinephrine <= 0.1 106 | or rate_norepinephrine <= 0.1 107 | then 3 108 | when rate_dopamine > 0 109 | or rate_dobutamine >0 110 | then 2 111 | when MinBP < 70 then 1 112 | when coalesce(MinBP, 113 | rate_dopamine, 114 | rate_dobutamine, 115 | rate_epinephrine, 116 | rate_norepinephrine) 117 | is null then null 118 | else 0 end as Cardiovascular 119 | -- Neurological failure (GCS) 120 | , case 121 | when (GCS >= 13 and GCS <= 14) then 1 122 | when (GCS >= 10 and GCS <= 12) then 2 123 | when (GCS >= 6 and GCS <= 9) then 3 124 | when GCS < 6 then 4 125 | else 0 end as CNS 126 | -- Renal failure -- TODO: this becomes wrong once you look over the past 24h, given urine is already over 24h :( 127 | , case 128 | when (Creatinine >= 5) then 4 129 | -- when (UrineOutput24h < 200 and HLOS > 24) then 4 130 | when (Creatinine >= 3.5 and Creatinine < 5.0) then 3 131 | -- when (UrineOutput24h < 500 and HLOS > 24) then 3 132 | when (Creatinine >= 2.0 and Creatinine < 3.5) then 2 133 | when (Creatinine >= 1.2 and Creatinine < 2.0) then 1 134 | when coalesce(UrineOutput24h, Creatinine) is null then null 135 | else 0 end as renal_labs 136 | , case 137 | when (UrineOutput24h < 200 and HLOS > 24) then 4 138 | when (UrineOutput24h < 500 and HLOS > 24) then 3 139 | else 0 end as renal_uo 140 | 141 | from scorecomp 142 | ) 143 | -- making an hourly table 144 | , SOFA_per_hour as ( 145 | select 146 | hah.hadm_id 147 | , hah.hr as hlos 148 | , hah.endtime 149 | , SOFA.respiration as SOFAresp 150 | , SOFA.coagulation as SOFAcoag 151 | , SOFA.liver as SOFAliv 152 | , SOFA.Cardiovascular as SOFAcardio 153 | , SOFA.CNS as SOFAgcs 154 | , SOFA.renal_labs as SOFAren 
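        -- Worked example (hypothetical values): an hour with Platelet = 45, Bilirubin = 2.5,
        -- GCS = 11 and Creatinine = 1.5 scores coagulation = 3, liver = 2, CNS = 2 and
        -- renal_labs = 1 in the SOFA CTE above; respiration and cardiovascular stay NULL when
        -- their inputs are missing and are only coalesced to 0 in the final SELECT, so the
        -- component sum for that hour is 3 + 2 + 2 + 1 = 8 (before the 24-hour look-back
        -- maximum taken below). Note that renal_uo is not carried through this CTE; it is
        -- re-joined in the final SELECT and combined with renal_labs via GREATEST.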
155 | from hadms_hours hah 156 | left join SOFA 157 | on SOFA.hadm_id = hah.hadm_id 158 | and SOFA.HLOS = hah.hr 159 | order by hadm_id, hr 160 | ) 161 | -- maximum value for each component of SOFA over the past 24 hours 162 | , SOFA_per_hour_looking_back as ( 163 | SELECT 164 | sofa1.hadm_id 165 | , sofa1.HLOS 166 | , max(sofa2.SOFAresp) as SOFAresp 167 | , max(sofa2.SOFAcoag) as SOFAcoag 168 | , max(sofa2.SOFAliv) as SOFAliv 169 | , max(sofa2.SOFAcardio) as SOFAcardio 170 | , max(sofa2.SOFAgcs) as SOFAgcs 171 | , max(sofa2.SOFAren) as SOFAren 172 | 173 | FROM SOFA_per_hour sofa1 174 | JOIN SOFA_per_hour sofa2 ON 175 | sofa1.hadm_id = sofa2.hadm_id and 176 | sofa2.HLOS between sofa1.HLOS -24 and sofa1.HLOS 177 | group by sofa1.hadm_id, sofa1.HLOS 178 | order by sofa1.hadm_id, sofa1.HLOS 179 | ) 180 | SELECT s1.hadm_id 181 | , s1.hlos 182 | , SOFAresp 183 | , SOFAcoag 184 | , SOFAliv 185 | , SOFAcardio 186 | , SOFAgcs 187 | , GREATEST(SOFAren, s2.renal_uo) as SOFAren 188 | , coalesce(SOFAresp, 0) 189 | + coalesce(SOFAcoag, 0) 190 | + coalesce(SOFAliv, 0) 191 | + coalesce(SOFAcardio, 0) 192 | + coalesce(SOFAgcs, 0) 193 | + coalesce(GREATEST(SOFAren, s2.renal_uo), 0) as SOFA 194 | from SOFA_per_hour_looking_back s1 195 | left join SOFA s2 196 | on s1.hadm_id = s2.hadm_id 197 | and s1.HLOS = s2.hlos 198 | -------------------------------------------------------------------------------- /src/mains/ablation_alpha.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from datetime import datetime 4 | import numpy as np 5 | import tensorflow as tf 6 | import pickle 7 | cwd = os.path.dirname(os.path.abspath(__file__)) 8 | head = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir)) 9 | sys.path.append(head) 10 | from src.utils.debug import t_print 11 | from src.data_loader.loader import DataGenerator 12 | from src.models.GP_attTCN_ablations import GPattTCN_alpha 13 | from src.trainers.trainer import Trainer 14 | 15 | 16 | def main( 17 | # data 18 | max_no_dtpts, 19 | min_no_dtpts, 20 | time_window, 21 | n_features, 22 | n_stat_features, 23 | features, 24 | late_patients_only, 25 | horizon0, 26 | # model 27 | model_choice, 28 | # MGP 29 | no_mc_samples, 30 | kernel_choice, 31 | # TCN 32 | num_layers, 33 | kernel_size, 34 | stride, 35 | DO, 36 | L2reg, 37 | sigmoid_beta, 38 | # training 39 | learning_rate, 40 | batch_size, 41 | num_epochs,): 42 | # generate save path 43 | logdir = os.path.join("logs/abl_alpha", datetime.now().strftime("%Y%m%d-%H%M%S")) 44 | if not os.path.isdir(logdir): 45 | os.mkdir(logdir) 46 | Dict = { 47 | # data 48 | "max_no_dtpts": max_no_dtpts, 49 | "min_no_dtpts": min_no_dtpts, 50 | "time_window": time_window, 51 | "n_features": n_features, 52 | "n_stat_features": n_stat_features, 53 | "features": features, 54 | "late_patients_only": late_patients_only, 55 | "horizon0": horizon0, 56 | # model 57 | "model_choice": model_choice, 58 | # MGP 59 | "no_mc_samples": no_mc_samples, 60 | "kernel_choice": kernel_choice, 61 | # TCN 62 | "num_layers": num_layers, 63 | "kernel_size": kernel_size, 64 | "stride": stride, 65 | "DO": DO, 66 | "L2reg": L2reg, 67 | "sigmoid_beta": sigmoid_beta, 68 | # training 69 | "learning_rate": learning_rate, 70 | "batch_size": batch_size, 71 | "num_epochs": num_epochs, 72 | } 73 | with open(os.path.join(logdir, 'hyperparam.pkl'), "wb") as f: 74 | pickle.dump(Dict, f) 75 | 76 | summary_writers = {'train': tf.summary.create_file_writer(os.path.join(logdir, 'train')), 77 | 'val': 
tf.summary.create_file_writer(os.path.join(logdir, 'val')), 78 | } 79 | t_print("nu_layers: {}\tlr: {}\tMC samples :{}\tDO :{}\tL2 :{}\t kernel:{}".format(num_layers, learning_rate, no_mc_samples, DO[0], L2reg[0], kernel_size)) 80 | # Load data 81 | data = DataGenerator(no_mc_samples=no_mc_samples, 82 | max_no_dtpts=max_no_dtpts, 83 | min_no_dtpts=min_no_dtpts, 84 | batch_size=batch_size, 85 | fast_load=False, 86 | to_save=True, 87 | debug=True, 88 | fixed_idx_per_class=False, 89 | features=features) 90 | 91 | t_print("main - generate model and optimiser") 92 | optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) 93 | global_step = tf.Variable(0) 94 | 95 | # Load model 96 | model = GPattTCN_alpha(time_window, 97 | no_mc_samples, 98 | n_features, 99 | n_stat_features, 100 | kernel=kernel_choice, 101 | L2reg=L2reg, 102 | DO=DO, 103 | num_layers=num_layers, 104 | kernel_size=kernel_size, 105 | stride=stride, 106 | sigmoid_beta=sigmoid_beta) 107 | 108 | # Initialise trainer 109 | trainer = Trainer(model=model, 110 | data=data, 111 | num_epochs=num_epochs, 112 | batch_size=batch_size, 113 | optimizer=optimizer, 114 | global_step=global_step, 115 | summary_writers=summary_writers, 116 | log_path=logdir, 117 | train_only=False, 118 | notebook_friendly=False, 119 | eval_every=20, 120 | late_patients_only=late_patients_only, 121 | horizon0=horizon0,) 122 | 123 | # train model 124 | trainer.run() 125 | 126 | if __name__=="__main__": 127 | tf.random.set_seed(1237) 128 | np.random.seed(1237) 129 | # data 130 | max_no_dtpts = 250 # chopping 4.6% of data at 250 131 | min_no_dtpts = 40 # helping with covariance singularity 132 | time_window = 25 # fixed 133 | n_features = 24 # old data: 44 134 | n_stat_features = 8 # old data: 35 135 | features = 'mr_features_mm_labels' 136 | n_features= 17 137 | n_stat_features= 8 138 | features = None 139 | late_patients_only = False 140 | horizon0 = False 141 | 142 | # model 143 | model_choice = 'Att' # ['Att', 'Moor'] 144 | 145 | # MGP 146 | kernel_choice = 'OU' 147 | 148 | # TCN 149 | stride = 1 150 | DO = [0.01] * 10 151 | sigmoid_beta = True 152 | 153 | # training 154 | batch_size = 128 155 | num_epochs = 100 156 | 157 | 158 | num_layers = np.random.randint(2, high=8, size=None, dtype='l') 159 | learning_rate = np.random.uniform(10e-6, high=10e-4, size=None) 160 | no_mc_samples = np.random.randint(8, high=20, size=None, dtype='l') 161 | #DO = [np.random.uniform(0, high=0.99, size=None) for _ in range(num_layers)] 162 | L2reg = [10**float(np.random.randint(-5, high=8, size=None, dtype='l'))] * num_layers 163 | kernel_size = (np.random.randint(2, high=6, size=None, dtype='l'),) 164 | 165 | main( 166 | # data 167 | max_no_dtpts, 168 | min_no_dtpts, 169 | time_window, 170 | n_features, 171 | n_stat_features, 172 | features, 173 | late_patients_only, 174 | horizon0, 175 | # model 176 | model_choice, 177 | # MGP 178 | no_mc_samples, 179 | kernel_choice, 180 | # TCN 181 | num_layers, 182 | kernel_size, 183 | stride, 184 | DO, 185 | L2reg, 186 | sigmoid_beta, 187 | # training 188 | learning_rate, 189 | batch_size, 190 | num_epochs,) 191 | -------------------------------------------------------------------------------- /src/mains/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from datetime import datetime 4 | import numpy as np 5 | import tensorflow as tf 6 | import pickle 7 | cwd = os.path.dirname(os.path.abspath(__file__)) 8 | head = os.path.abspath(os.path.join(cwd, os.pardir, 
os.pardir)) 9 | sys.path.append(head) 10 | from src.utils.debug import t_print 11 | from src.data_loader.loader import DataGenerator 12 | from src.models.GP_attTCN import GPattTCN 13 | from src.trainers.trainer import Trainer 14 | 15 | 16 | def main( 17 | # data 18 | max_no_dtpts, 19 | min_no_dtpts, 20 | time_window, 21 | n_features, 22 | n_stat_features, 23 | features, 24 | late_patients_only, 25 | horizon0, 26 | # model 27 | model_choice, 28 | # MGP 29 | no_mc_samples, 30 | kernel_choice, 31 | # TCN 32 | num_layers, 33 | kernel_size, 34 | stride, 35 | DO, 36 | L2reg, 37 | sigmoid_beta, 38 | # training 39 | learning_rate, 40 | batch_size, 41 | num_epochs,): 42 | # generate save path 43 | logdir = os.path.join( "logs/", datetime.now().strftime("%Y%m%d-%H%M%S")) 44 | if not os.path.isdir(logdir): 45 | os.mkdir(logdir) 46 | Dict = { 47 | # data 48 | "max_no_dtpts":max_no_dtpts, 49 | "min_no_dtpts":min_no_dtpts, 50 | "time_window": time_window, 51 | "n_features": n_features, 52 | "n_stat_features": n_stat_features, 53 | "features": features, 54 | "late_patients_only": late_patients_only, 55 | "horizon0": horizon0, 56 | # model 57 | "model_choice": model_choice, 58 | # MGP 59 | "no_mc_samples": no_mc_samples, 60 | "kernel_choice": kernel_choice, 61 | # TCN 62 | "num_layers": num_layers, 63 | "kernel_size": kernel_size, 64 | "stride": stride, 65 | "DO": DO, 66 | "L2reg": L2reg, 67 | "sigmoid_beta": sigmoid_beta, 68 | # training 69 | "learning_rate": learning_rate, 70 | "batch_size": batch_size, 71 | "num_epochs": num_epochs, 72 | } 73 | with open(os.path.join(logdir, 'hyperparam.pkl'), "wb") as f: 74 | pickle.dump(Dict, f) 75 | 76 | summary_writers = {'train': tf.summary.create_file_writer(os.path.join(logdir, 'train')), 77 | 'val': tf.summary.create_file_writer(os.path.join(logdir, 'val')), 78 | } 79 | t_print("nu_layers: {}\tlr: {}\tMC samples :{}\tDO :{}\tL2 :{}\t kernel:{}".format(num_layers, learning_rate, no_mc_samples, DO[0], L2reg[0], kernel_size)) 80 | # Load data 81 | data = DataGenerator(no_mc_samples=no_mc_samples, 82 | max_no_dtpts=max_no_dtpts, 83 | min_no_dtpts=min_no_dtpts, 84 | batch_size=batch_size, 85 | fast_load=False, 86 | to_save=True, 87 | debug=True, 88 | fixed_idx_per_class=False, 89 | features=features) 90 | 91 | t_print("main - generate model and optimiser") 92 | optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) 93 | global_step = tf.Variable(0) 94 | 95 | # Load model 96 | model = GPattTCN(time_window, 97 | no_mc_samples, 98 | n_features, 99 | n_stat_features, 100 | kernel=kernel_choice, 101 | L2reg=L2reg, 102 | DO=DO, 103 | num_layers=num_layers, 104 | kernel_size=kernel_size, 105 | stride=stride, 106 | sigmoid_beta=sigmoid_beta) 107 | 108 | # Initialise trainer 109 | trainer = Trainer(model=model, 110 | data=data, 111 | num_epochs=num_epochs, 112 | batch_size=batch_size, 113 | optimizer=optimizer, 114 | global_step=global_step, 115 | summary_writers=summary_writers, 116 | log_path=logdir, 117 | train_only=False, 118 | notebook_friendly=False, 119 | eval_every=20, 120 | late_patients_only=late_patients_only, 121 | horizon0=horizon0,) 122 | 123 | # train model 124 | trainer.run() 125 | 126 | 127 | if __name__=="__main__": 128 | tf.random.set_seed(1237) 129 | np.random.seed(1237) 130 | # data 131 | max_no_dtpts = 250 # chopping 4.6% of data at 250 132 | min_no_dtpts = 40 # helping with covariance singularity 133 | time_window = 25 # fixed 134 | n_features = 24 # old data: 44 135 | n_stat_features = 8 # old data: 35 136 | features = 
'mr_features_mm_labels' 137 | n_features= 17 138 | n_stat_features= 8 139 | features = None 140 | late_patients_only = False 141 | horizon0 = False 142 | 143 | # model 144 | model_choice = 'Att' # ['Att', 'Moor'] 145 | 146 | # MGP 147 | no_mc_samples = 10 148 | kernel_choice = 'OU' 149 | 150 | # TCN 151 | num_layers = 4 152 | kernel_size = 3 153 | stride = 1 154 | DO = [0.01] * 10 155 | L2reg = [0.000001] * 10 156 | sigmoid_beta = True 157 | 158 | # training 159 | learning_rate = 0.0005 160 | batch_size = 128 161 | num_epochs = 100 162 | 163 | 164 | num_layers = np.random.randint(2, high=8, size=None, dtype='l') 165 | learning_rate = np.random.uniform(10e-6, high=10e-4, size=None) 166 | no_mc_samples = np.random.randint(8, high=20, size=None, dtype='l') 167 | #DO = [np.random.uniform(0, high=0.99, size=None) for _ in range(num_layers)] 168 | L2reg = [10**float(np.random.randint(-5, high=8, size=None, dtype='l'))] * num_layers 169 | load_path = head + "/not_a_path" 170 | kernel_size = (np.random.randint(2, high=6, size=None, dtype='l'),) 171 | 172 | main( 173 | # data 174 | max_no_dtpts, 175 | min_no_dtpts, 176 | time_window, 177 | n_features, 178 | n_stat_features, 179 | features, 180 | late_patients_only, 181 | horizon0, 182 | # model 183 | model_choice, 184 | # MGP 185 | no_mc_samples, 186 | kernel_choice, 187 | # TCN 188 | num_layers, 189 | kernel_size, 190 | stride, 191 | DO, 192 | L2reg, 193 | sigmoid_beta, 194 | # training 195 | learning_rate, 196 | batch_size, 197 | num_epochs) 198 | -------------------------------------------------------------------------------- /src/mains/ablation_beta.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from datetime import datetime 4 | import numpy as np 5 | import tensorflow as tf 6 | import pickle 7 | cwd = os.path.dirname(os.path.abspath(__file__)) 8 | head = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir)) 9 | sys.path.append(head) 10 | from src.utils.debug import t_print 11 | from src.data_loader.loader import DataGenerator 12 | from src.models.GP_attTCN_ablations import GPattTCN_beta 13 | from src.trainers.trainer import Trainer 14 | 15 | 16 | def main( 17 | # data 18 | max_no_dtpts, 19 | min_no_dtpts, 20 | time_window, 21 | n_features, 22 | n_stat_features, 23 | features, 24 | late_patients_only, 25 | horizon0, 26 | # model 27 | model_choice, 28 | # MGP 29 | no_mc_samples, 30 | kernel_choice, 31 | # TCN 32 | num_layers, 33 | kernel_size, 34 | stride, 35 | DO, 36 | L2reg, 37 | sigmoid_beta, 38 | # training 39 | learning_rate, 40 | batch_size, 41 | num_epochs,): 42 | # generate save path 43 | logdir = os.path.join("logs/abl_beta", datetime.now().strftime("%Y%m%d-%H%M%S")) 44 | if not os.path.isdir(logdir): 45 | os.mkdir(logdir) 46 | Dict = { 47 | # data 48 | "max_no_dtpts": max_no_dtpts, 49 | "min_no_dtpts": min_no_dtpts, 50 | "time_window": time_window, 51 | "n_features": n_features, 52 | "n_stat_features": n_stat_features, 53 | "features": features, 54 | "late_patients_only": late_patients_only, 55 | "horizon0": horizon0, 56 | # model 57 | "model_choice": model_choice, 58 | # MGP 59 | "no_mc_samples": no_mc_samples, 60 | "kernel_choice": kernel_choice, 61 | # TCN 62 | "num_layers": num_layers, 63 | "kernel_size": kernel_size, 64 | "stride": stride, 65 | "DO": DO, 66 | "L2reg": L2reg, 67 | "sigmoid_beta": sigmoid_beta, 68 | # training 69 | "learning_rate": learning_rate, 70 | "batch_size": batch_size, 71 | "num_epochs": num_epochs, 72 | } 73 | with 
open(os.path.join(logdir, 'hyperparam.pkl'), "wb") as f: 74 | pickle.dump(Dict, f) 75 | 76 | summary_writers = {'train': tf.summary.create_file_writer(os.path.join(logdir, 'train')), 77 | 'val': tf.summary.create_file_writer(os.path.join(logdir, 'val')), 78 | } 79 | t_print("nu_layers: {}\tlr: {}\tMC samples :{}\tDO :{}\tL2 :{}\t kernel:{}".format(num_layers, learning_rate, no_mc_samples, DO[0], L2reg[0], kernel_size)) 80 | # Load data 81 | data = DataGenerator(no_mc_samples=no_mc_samples, 82 | max_no_dtpts=max_no_dtpts, 83 | min_no_dtpts=min_no_dtpts, 84 | batch_size=batch_size, 85 | fast_load=False, 86 | to_save=True, 87 | debug=True, 88 | fixed_idx_per_class=False, 89 | features=features) 90 | 91 | t_print("main - generate model and optimiser") 92 | optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) 93 | global_step = tf.Variable(0) 94 | 95 | # Load model 96 | model = GPattTCN_beta(time_window, 97 | no_mc_samples, 98 | n_features, 99 | n_stat_features, 100 | kernel=kernel_choice, 101 | L2reg=L2reg, 102 | DO=DO, 103 | num_layers=num_layers, 104 | kernel_size=kernel_size, 105 | stride=stride, 106 | sigmoid_beta=sigmoid_beta) 107 | 108 | # Initialise trainer 109 | trainer = Trainer(model=model, 110 | data=data, 111 | num_epochs=num_epochs, 112 | batch_size=batch_size, 113 | optimizer=optimizer, 114 | global_step=global_step, 115 | summary_writers=summary_writers, 116 | log_path=logdir, 117 | train_only=False, 118 | notebook_friendly=False, 119 | eval_every=20, 120 | late_patients_only=late_patients_only, 121 | horizon0=horizon0, ) 122 | 123 | # train model 124 | trainer.run() 125 | 126 | 127 | if __name__=="__main__": 128 | tf.random.set_seed(1237) 129 | np.random.seed(1237) 130 | # data 131 | max_no_dtpts = 250 # chopping 4.6% of data at 250 132 | min_no_dtpts = 40 # helping with covariance singularity 133 | time_window = 25 # fixed 134 | n_features = 24 # old data: 44 135 | n_stat_features = 8 # old data: 35 136 | features = 'mr_features_mm_labels' 137 | n_features= 17 138 | n_stat_features= 8 139 | features = None 140 | late_patients_only = False 141 | horizon0 = False 142 | 143 | # model 144 | model_choice = 'Att' # ['Att', 'Moor'] 145 | 146 | # MGP 147 | no_mc_samples = 10 148 | kernel_choice = 'OU' 149 | 150 | # TCN 151 | num_layers = 4 152 | kernel_size = 3 153 | stride = 1 154 | DO = [0.01] * 10 155 | L2reg = [0.000001] * 10 156 | sigmoid_beta = True 157 | 158 | # training 159 | learning_rate = 0.0005 160 | batch_size = 128 161 | num_epochs = 100 162 | 163 | 164 | num_layers = np.random.randint(2, high=8, size=None, dtype='l') 165 | learning_rate = np.random.uniform(10e-6, high=10e-4, size=None) 166 | no_mc_samples = np.random.randint(8, high=20, size=None, dtype='l') 167 | #DO = [np.random.uniform(0, high=0.99, size=None) for _ in range(num_layers)] 168 | L2reg = [10**float(np.random.randint(-5, high=8, size=None, dtype='l'))] * num_layers 169 | load_path = head + "/not_a_path" 170 | kernel_size = (np.random.randint(2, high=6, size=None, dtype='l'),) 171 | 172 | main( 173 | # data 174 | max_no_dtpts, 175 | min_no_dtpts, 176 | time_window, 177 | n_features, 178 | n_stat_features, 179 | features, 180 | late_patients_only, 181 | horizon0, 182 | # model 183 | model_choice, 184 | # MGP 185 | no_mc_samples, 186 | kernel_choice, 187 | # TCN 188 | num_layers, 189 | kernel_size, 190 | stride, 191 | DO, 192 | L2reg, 193 | sigmoid_beta, 194 | # training 195 | learning_rate, 196 | batch_size, 197 | num_epochs,) 198 | 
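The three training entry points above (src/mains/main.py, ablation_alpha.py and ablation_beta.py) share the same random hyperparameter draw in their __main__ blocks. The sketch below isolates just that draw so the effective search ranges are easy to read off; the helper name sample_hyperparams and the standalone-script form are hypothetical and not part of the repository, but the distributions mirror the np.random calls above (note that 10e-6 == 1e-5 and 10e-4 == 1e-3, so the learning rate is uniform on [1e-5, 1e-3)).

import numpy as np

def sample_hyperparams(seed=1237):
    # re-draws the random search used by the mains, assuming the same NumPy seed (1237)
    np.random.seed(seed)
    num_layers = np.random.randint(2, high=8)             # 2..7 TCN layers
    learning_rate = np.random.uniform(10e-6, high=10e-4)  # uniform on [1e-5, 1e-3)
    no_mc_samples = np.random.randint(8, high=20)         # 8..19 MGP Monte Carlo samples
    # one L2 weight drawn from {1e-5, 1e-4, ..., 1e7}, repeated for every layer
    L2reg = [10 ** float(np.random.randint(-5, high=8))] * num_layers
    kernel_size = (np.random.randint(2, high=6),)         # TCN kernel width 2..5
    return {"num_layers": int(num_layers),
            "learning_rate": learning_rate,
            "no_mc_samples": int(no_mc_samples),
            "L2reg": L2reg,
            "kernel_size": kernel_size}

if __name__ == "__main__":
    print(sample_hyperparams())
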
-------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_labels/SQL-SOFA/respiration/ventsettings.sql: -------------------------------------------------------------------------------- 1 | -- This query extracts the duration of mechanical ventilation 2 | -- The main goal of the query is to aggregate sequential ventilator settings 3 | -- into single mechanical ventilation "events". The start and end time of these 4 | -- events can then be used for various purposes: calculating the total duration 5 | -- of mechanical ventilation, cross-checking values (e.g. PaO2:FiO2 on vent), etc 6 | 7 | -- The query's logic is roughly: 8 | -- 1) The presence of a mechanical ventilation setting starts a new ventilation event 9 | -- 2) Any instance of a setting in the next 8 hours continues the event 10 | -- 3) Certain elements end the current ventilation event 11 | -- a) documented extubation ends the current ventilation 12 | -- b) initiation of non-invasive vent and/or oxygen ends the current vent 13 | -- The ventilation events are numbered consecutively by the `num` column. 14 | 15 | 16 | -- PART I 17 | 18 | -- First, create a temporary table to store relevant data from CHARTEVENTS. 19 | DROP MATERIALIZED VIEW IF EXISTS ventsettings CASCADE; 20 | CREATE MATERIALIZED VIEW ventsettings AS 21 | select 22 | icustay_id, charttime 23 | -- case statement determining whether it is an instance of mech vent 24 | , max( 25 | case 26 | when itemid is null or value is null then 0 -- can't have null values 27 | when itemid = 720 and value != 'Other/Remarks' THEN 1 -- VentTypeRecorded 28 | when itemid = 223848 and value != 'Other' THEN 1 29 | when itemid = 223849 then 1 -- ventilator mode 30 | when itemid = 467 and value = 'Ventilator' THEN 1 -- O2 delivery device == ventilator 31 | when itemid in 32 | ( 33 | 445, 448, 449, 450, 1340, 1486, 1600, 224687 -- minute volume 34 | , 639, 654, 681, 682, 683, 684,224685,224684,224686 -- tidal volume 35 | , 218,436,535,444,459,224697,224695,224696,224746,224747 -- High/Low/Peak/Mean/Neg insp force ("RespPressure") 36 | , 221,1,1211,1655,2000,226873,224738,224419,224750,227187 -- Insp pressure 37 | , 543 -- PlateauPressure 38 | , 5865,5866,224707,224709,224705,224706 -- APRV pressure 39 | , 60,437,505,506,686,220339,224700 -- PEEP 40 | , 3459 -- high pressure relief 41 | , 501,502,503,224702 -- PCV 42 | , 223,667,668,669,670,671,672 -- TCPCV 43 | , 224701 -- PSVlevel 44 | ) 45 | THEN 1 46 | else 0 47 | end 48 | ) as MechVent 49 | , max( 50 | case 51 | -- initiation of oxygen therapy indicates the ventilation has ended 52 | when itemid = 226732 and value in 53 | ( 54 | 'Nasal cannula', -- 153714 observations 55 | 'Face tent', -- 24601 observations 56 | 'Aerosol-cool', -- 24560 observations 57 | 'Trach mask ', -- 16435 observations 58 | 'High flow neb', -- 10785 observations 59 | 'Non-rebreather', -- 5182 observations 60 | 'Venti mask ', -- 1947 observations 61 | 'Medium conc mask ', -- 1888 observations 62 | 'T-piece', -- 1135 observations 63 | 'High flow nasal cannula', -- 925 observations 64 | 'Ultrasonic neb', -- 9 observations 65 | 'Vapomist' -- 3 observations 66 | ) then 1 67 | when itemid = 467 and value in 68 | ( 69 | 'Cannula', -- 278252 observations 70 | 'Nasal Cannula', -- 248299 observations 71 | -- 'None', -- 95498 observations 72 | 'Face Tent', -- 35766 observations 73 | 'Aerosol-Cool', -- 33919 observations 74 | 'Trach Mask', -- 32655 observations 75 | 'Hi Flow Neb', -- 14070 observations 76 | 
'Non-Rebreather', -- 10856 observations 77 | 'Venti Mask', -- 4279 observations 78 | 'Medium Conc Mask', -- 2114 observations 79 | 'Vapotherm', -- 1655 observations 80 | 'T-Piece', -- 779 observations 81 | 'Hood', -- 670 observations 82 | 'Hut', -- 150 observations 83 | 'TranstrachealCat', -- 78 observations 84 | 'Heated Neb', -- 37 observations 85 | 'Ultrasonic Neb' -- 2 observations 86 | ) then 1 87 | else 0 88 | end 89 | ) as OxygenTherapy 90 | , max( 91 | case when itemid is null or value is null then 0 92 | -- extubated indicates ventilation event has ended 93 | when itemid = 640 and value = 'Extubated' then 1 94 | when itemid = 640 and value = 'Self Extubation' then 1 95 | else 0 96 | end 97 | ) 98 | as Extubated 99 | , max( 100 | case when itemid is null or value is null then 0 101 | when itemid = 640 and value = 'Self Extubation' then 1 102 | else 0 103 | end 104 | ) 105 | as SelfExtubated 106 | from chartevents ce 107 | where ce.value is not null 108 | -- exclude rows marked as error 109 | and ce.error IS DISTINCT FROM 1 110 | and itemid in 111 | ( 112 | -- the below are settings used to indicate ventilation 113 | 720, 223849 -- vent mode 114 | , 223848 -- vent type 115 | , 445, 448, 449, 450, 1340, 1486, 1600, 224687 -- minute volume 116 | , 639, 654, 681, 682, 683, 684,224685,224684,224686 -- tidal volume 117 | , 218,436,535,444,224697,224695,224696,224746,224747 -- High/Low/Peak/Mean ("RespPressure") 118 | , 221,1,1211,1655,2000,226873,224738,224419,224750,227187 -- Insp pressure 119 | , 543 -- PlateauPressure 120 | , 5865,5866,224707,224709,224705,224706 -- APRV pressure 121 | , 60,437,505,506,686,220339,224700 -- PEEP 122 | , 3459 -- high pressure relief 123 | , 501,502,503,224702 -- PCV 124 | , 223,667,668,669,670,671,672 -- TCPCV 125 | , 224701 -- PSVlevel 126 | 127 | -- the below are settings used to indicate extubation 128 | , 640 -- extubated 129 | 130 | -- the below indicate oxygen/NIV, i.e. 
the end of a mechanical vent event 131 | , 468 -- O2 Delivery Device#2 132 | , 469 -- O2 Delivery Mode 133 | , 470 -- O2 Flow (lpm) 134 | , 471 -- O2 Flow (lpm) #2 135 | , 227287 -- O2 Flow (additional cannula) 136 | , 226732 -- O2 Delivery Device(s) 137 | , 223834 -- O2 Flow 138 | 139 | -- used in both oxygen + vent calculation 140 | , 467 -- O2 Delivery Device 141 | ) 142 | group by icustay_id, charttime 143 | -- 144 | -- 145 | -- 146 | UNION 147 | -- add in the extubation flags from procedureevents_mv 148 | -- note that we only need the start time for the extubation 149 | -- (extubation is always charted as ending 1 minute after it started) 150 | select 151 | icustay_id, starttime as charttime 152 | , 0 as MechVent 153 | , 0 as OxygenTherapy 154 | , 1 as Extubated 155 | , case when itemid = 225468 then 1 else 0 end as SelfExtubated 156 | from procedureevents_mv 157 | where itemid in 158 | ( 159 | 227194 -- "Extubation" 160 | , 225468 -- "Unplanned Extubation (patient-initiated)" 161 | , 225477 -- "Unplanned Extubation (non-patient initiated)" 162 | ); 163 | 164 | -------------------------------------------------------------------------------- /src/data_loader/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def reduce_data(X, n_max=None, n_min=None): 5 | # data = [Y, T, ind_Y, len_T, X, len_X, labels, ids, static, onset_h] 6 | # data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 7 | n_removed = 0 8 | X = [np.asarray(X[i]) for i in range(len(X))] 9 | if n_max is not None: 10 | idx_to_reduce = np.where(X[3] >= n_max) # 7: num_obs_times 11 | for pat_no in idx_to_reduce[0]: 12 | # Y, T, ind_K_D, ind_T 13 | for data_no in range(3): 14 | X[data_no][pat_no] = X[data_no][pat_no][-n_max:] 15 | # num_obs 16 | X[3][pat_no] = n_max 17 | # X 18 | X[4][pat_no] = np.arange(start=int(np.min(X[1][pat_no])), 19 | stop=int(np.max(X[1][pat_no])) + 1) 20 | # num distinct X 21 | X[5][pat_no] = len(X[4][pat_no]) 22 | 23 | if n_min is not None: 24 | idx_to_keep = X[3] > n_min # 7: num_obs_times 25 | for i in range(len(X)): 26 | X[i] = np.array(X[i])[idx_to_keep] 27 | n_removed += np.sum(1 - idx_to_keep) 28 | return X, n_removed 29 | 30 | 31 | def new_indices(data): 32 | Y, T, ind_Y, len_T, X, len_X, labels, static, classes, ids = data 33 | ind_T = np.zeros_like(ind_Y) 34 | ind_K_D = np.zeros_like(ind_Y) 35 | for id in range(len(ids)): 36 | ind_Ti = np.zeros_like(ind_Y[id]) 37 | ind_Yi = np.asarray(ind_Y[id])[:len_T[id]] 38 | counter = 0 39 | for feat in range(45): 40 | new_items = np.where(ind_Yi == feat)[0] 41 | ind_Ti[counter: counter + len(new_items)] = new_items 42 | counter += len(new_items) 43 | ind_T[id, :len(ind_Ti)] = ind_Ti 44 | ind_K_Di = np.sort(ind_Yi) 45 | ind_K_D[id, :len(ind_K_Di)] = ind_K_Di 46 | 47 | # data = [Y, T, ind_K_D, ind_T, len_T, X, len_X, labels, static, classes, ids, ind_Y] 48 | # data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] 49 | data = [Y, T, ind_K_D, ind_T, len_T, X, len_X, labels, static, classes, ids, ind_Y] 50 | 51 | # shuffle 52 | idx = np.arange(len(classes)) 53 | np.random.shuffle(idx) 54 | return [np.asarray(data[i])[idx] for i in range(len(data))] 55 | 56 | 57 | def pad_raw_data(data): 58 | # data = [Y, T, ind_Y, len_T, X, len_X, labels, ids, static] 59 | # data = [0, 1, 2, 3, 4, 5, 6, 7, 8] 60 | dataset_size = len(data[-2]) 61 | max_num_obs = np.max(list(data[3])) 62 | max_num_X = np.max(list(data[5])) 63 | 64 | Y_padded = np.zeros((dataset_size, max_num_obs)) 65 | T_padded = Y_padded.copy() 66 | 
ind_Y_padded = Y_padded.copy() 67 | X_padded = np.zeros((dataset_size, max_num_X)) 68 | 69 | for i in range(dataset_size): 70 | Y_padded[i, :data[3][i]] = data[0][i] 71 | T_padded[i, :data[3][i]] = data[1][i] 72 | ind_Y_padded[i, :data[3][i]] = data[2][i] 73 | X_padded[i, :data[5][i]] = data[4][i] 74 | 75 | data[0] = Y_padded 76 | data[1] = T_padded 77 | data[2] = ind_Y_padded 78 | data[4] = X_padded 79 | return data 80 | 81 | 82 | def remove_column(data, name): 83 | try: 84 | columns = data.columns 85 | except AttributeError: 86 | return data 87 | if name in columns: 88 | data.drop(columns=name, inplace=True) 89 | return data 90 | 91 | 92 | def all_horizons(data): 93 | # data = [Y, T, ind_Y, len_T, X, len_X, labels, ids, onsets, static] 94 | # data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 95 | Y_all = data[0].copy() 96 | T_all = data[1].copy() 97 | ind_Y_all = data[2].copy() 98 | num_distinct_Y_all = data[3].copy() 99 | X_all = data[4].copy() 100 | num_distinct_X_all = data[5].copy() 101 | labels_all = data[6].copy() 102 | icustay_id_all = data[7].copy() 103 | static_all = data[9].copy() 104 | 105 | 106 | classes_all = np.zeros_like(labels_all) 107 | abs_max_T = np.max(T_all, axis=1) 108 | 109 | # debug 110 | keep_all = [np.arange(len(labels_all))] 111 | 112 | for horizon in range(1, 7): 113 | # get new max times 114 | max_T = data[8] - horizon 115 | Filter = data[1] > np.broadcast_to(max_T[:, np.newaxis], data[1].shape) 116 | data[0][Filter] = 0 117 | data[1][Filter] = 0 118 | data[2][Filter] = 0 119 | data[3] = data[3] - np.sum(Filter, axis=1) 120 | 121 | # reduce num outputs 122 | data[5] = np.ceil(np.max(data[1], axis=1) - data[1][:, 0]).astype(np.int32) 123 | for i in range(len(data[5])): 124 | data[4][i, data[5][i]:] = 0 125 | 126 | # drop empty TS 127 | kept = np.arange(data[4].shape[0]) 128 | to_keep = (data[3] > 0) 129 | data = [np.asarray(data[i])[to_keep] for i in range(len(data))] 130 | 131 | kept = kept[to_keep] 132 | abs_max_T = abs_max_T[to_keep] 133 | classes = np.ones_like(data[7]) * horizon 134 | 135 | # append 136 | Y_all = np.concatenate((Y_all, data[0]), axis=0) 137 | T_all = np.concatenate((T_all, data[1]), axis=0) 138 | ind_Y_all = np.concatenate((ind_Y_all, data[2]), axis=0) 139 | num_distinct_Y_all = np.concatenate((num_distinct_Y_all, data[3]), axis=0) 140 | X_all = np.concatenate((X_all, data[4]), axis=0) 141 | num_distinct_X_all = np.concatenate((num_distinct_X_all, data[5]), axis=0) 142 | labels_all = np.concatenate((labels_all, data[6]), axis=0) 143 | icustay_id_all = np.concatenate((icustay_id_all, data[7]), axis=0) 144 | static_all = np.concatenate((static_all, data[9]), axis=0) 145 | classes_all = np.concatenate((classes_all, classes), axis=0) 146 | keep_all.append(kept) 147 | # data = [Y, T, ind_Y, len_T, X, len_X, labels, static, classes, ids] 148 | # data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 149 | data = [Y_all, T_all, ind_Y_all, num_distinct_Y_all, X_all, num_distinct_X_all, labels_all, 150 | static_all, classes_all, icustay_id_all] 151 | 152 | return data 153 | 154 | 155 | def separating_and_resampling(data): 156 | """ 157 | resamples case patients (or controls) to have a balanced dataset 158 | then separate the two to sample exactly 50 - 50 in a batch 159 | :param data: 160 | :return: 161 | """ 162 | # separating 163 | labels = data[7] 164 | 165 | idx_control = labels == 0 166 | idx_case = labels == 1 167 | 168 | # resampling 169 | if np.sum(idx_control) > np.sum(idx_case): 170 | idx_case = np.random.choice(np.where(idx_case)[0], np.sum(idx_control), 
replace=True, p=None) 171 | elif np.sum(idx_case) > np.sum(idx_control): 172 | idx_control = np.random.choice(np.where(idx_control)[0], np.sum(idx_case), replace=True, p=None) 173 | 174 | # separating 175 | control_data = [data[i][idx_control] for i in range(len(data))] 176 | case_data = [data[i][idx_case] for i in range(len(data))] 177 | return case_data, control_data 178 | -------------------------------------------------------------------------------- /src/data_preprocessing/features_preprocessing/stepIII_GP_prep.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | import sys 5 | import numpy as np 6 | import pandas as pd 7 | from tqdm import tqdm 8 | 9 | # appending head path 10 | cwd = os.path.dirname(os.path.abspath(__file__)) 11 | head = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir)) 12 | sys.path.append(head) 13 | from src.utils.debug import * 14 | 15 | 16 | class CompactTransform: 17 | def __init__(self, data, onset_h, outpath): 18 | self.data = data 19 | self.onset_h = onset_h 20 | self.outpath = outpath 21 | self.cwd = os.path.dirname(os.path.abspath(__file__)) 22 | self.head = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir, os.pardir)) 23 | print('working out of the assumption that head is ', self.head) 24 | 25 | def calculation(self): 26 | """ 27 | ids - ids // 28 | values - array of list of observations 29 | ind_lvs - array of list of variable observed 30 | times - array of list of observed times 31 | num_rnn_grid_times - num of grid times 32 | rnn_grid_times - hourly range to celing of end_time (grid times) 33 | labels - label 34 | ind_times - copy of times (legacy) 35 | num_obs_times - # of times with observations (diff from # obs vals as multiple vals can be obs at same t) 36 | num_obs_values - duh 37 | onset_hour - duh 38 | """ 39 | # reformatting self.data so that all numeric values are in one column and a new column indicates the var_id 40 | variables = ['sysbp', 'diabp', 'meanbp', 'resprate', 'heartrate', 'spo2_pulsoxy', 41 | 'tempc', 'bicarbonate', 'creatinine', 'chloride', 'glucose', 42 | 'hematocrit', 'hemoglobin', 'lactate', 'platelet', 'potassium', 'ptt', 43 | 'inr', 'pt', 'sodium', 'bun', 'wbc', 'magnesium', 'ph_bloodgas'] 44 | 45 | # initialise 46 | var_id = len(variables) 47 | var = variables[-1] 48 | self.data["var_id"] = var_id - 1 49 | moddata = self.data.loc[~self.data[var].isna(), ["icustay_id", var, "chart_time", "var_id", "label"]] 50 | moddata.rename(columns={var: "value"}, inplace=True) 51 | # loop 52 | for var_id, var in enumerate(variables[:-1]): 53 | self.data["var_id"] = var_id 54 | temp = self.data.loc[~self.data[var].isna(), ["icustay_id", var, "chart_time", "var_id", "label"]] 55 | temp.rename(columns={var: "value"}, inplace=True) 56 | moddata = moddata.append(temp, sort=False) 57 | self.data = moddata.sort_values(["icustay_id", "chart_time"], inplace=False) 58 | 59 | temp = self.data.groupby("icustay_id", as_index=False).max() 60 | # ids - ids 61 | ids = temp.icustay_id.to_numpy() 62 | end_time = temp.chart_time.to_numpy() 63 | # labels - label 64 | labels = temp.label.to_numpy() 65 | # num_rnn_grid_times - num of grid times 66 | num_rnn_grid_times = np.round(end_time + 1).astype(int) 67 | 68 | # values - array of list of observations 69 | values = [] 70 | # times - array of list of observed times 71 | times = [] 72 | # ind_lvs - array of list of variable observed 73 | ind_lvs = [] 74 | # rnn_grid_times - hourly range to ceiling of end_time 
(grid times) 75 | rnn_grid_times = [] 76 | for i, x in tqdm(enumerate(ids)): 77 | if i % 300 == 0: t_print("id iteration {}".format(i)) 78 | values.append(self.data.loc[self.data.icustay_id == x, "value"].tolist()) 79 | times.append(self.data.loc[self.data.icustay_id == x, "chart_time"].tolist()) 80 | ind_lvs.append(self.data.loc[self.data.icustay_id == x, "var_id"].tolist()) 81 | rnn_grid_times.append(np.arange(num_rnn_grid_times[i])) 82 | values = np.array(values) 83 | times = np.array(times) 84 | ind_lvs = np.array(ind_lvs) 85 | rnn_grid_times = np.array(rnn_grid_times) 86 | 87 | # ind_times - copy of times (legacy) 88 | ind_times = times 89 | # num_obs_values - duh 90 | num_obs_values = self.data.groupby("icustay_id").value.count().to_numpy() 91 | # num_obs_times - # of times with observations 92 | # (diff from # obs vals as multiple vals can be obs at same t) 93 | num_obs_times = self.data[["icustay_id", "chart_time"]].groupby("icustay_id").chart_time.count().to_numpy() 94 | onset_hour = self.onset_h.loc[self.onset_h.icustay_id.isin(ids)]. \ 95 | sort_values(by="icustay_id").onset_hour.to_numpy() 96 | self.result = [values, times, ind_lvs, ind_times, 97 | labels, num_rnn_grid_times, rnn_grid_times, 98 | num_obs_times, num_obs_values, onset_hour, ids] 99 | 100 | def save(self, features=None): 101 | if features is None: 102 | path = os.path.join(self.head, 'data', self.outpath, "GP_prep.pkl") 103 | else: 104 | path = os.path.join(self.head, 'data', self.outpath, "GP_prep_{}.pkl".format(features)) 105 | with open(path, "wb") as f: 106 | pickle.dump(self.result, f) 107 | 108 | 109 | def main(args): 110 | outpath = args.out_path 111 | cwd = os.path.dirname(os.path.abspath(__file__)) 112 | path = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir)) + "/data/" 113 | file = path + outpath + "/full_labvitals.csv" 114 | data = pd.read_csv(file) 115 | onsetfile = path + "processed/onset_hours.csv" 116 | onset = pd.read_csv(onsetfile) 117 | ct = CompactTransform(data, onset, outpath) 118 | ct.calculation() 119 | ct.save() 120 | 121 | 122 | def modular_main(outpath, onset_file_name, mr_features, extension_name): 123 | cwd = os.path.dirname(os.path.abspath(__file__)) 124 | path = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir)) + "/data/" 125 | if extension_name is not None: 126 | file = path + outpath + "/full_labvitals_{}.csv".format(extension_name) 127 | else: 128 | file = path + outpath + "/full_labvitals.csv" 129 | data = pd.read_csv(file) 130 | onsetfile = path + "processed/" + onset_file_name 131 | onset = pd.read_csv(onsetfile) 132 | ct = CompactTransform(data, onset, outpath) 133 | if mr_features: 134 | ct.calculation(features='mr_features') 135 | else: 136 | ct.calculation() 137 | ct.save(features=extension_name) 138 | 139 | 140 | def parse_arg(): 141 | parser = argparse.ArgumentParser() 142 | parser.add_argument("-o", "--out_path", 143 | choices=['test', 'val', 'train'], 144 | help="where to save the output files. 
Choose ['test','val','train']") 145 | return parser.parse_args() 146 | 147 | 148 | if __name__ == '__main__': 149 | args = parse_arg() 150 | main(args) 151 | -------------------------------------------------------------------------------- /src/data_preprocessing/features_preprocessing/archive/GP_prep.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | import sys 5 | 6 | import numpy as np 7 | import pandas as pd 8 | 9 | # appending head path 10 | cwd = os.path.dirname(os.path.abspath(__file__)) 11 | head = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir, os.pardir)) 12 | sys.path.append(head) 13 | from src.utils.debug import * 14 | 15 | 16 | class CompactTransform: 17 | def __init__(self, data, onset_h, outpath): 18 | self.data = data 19 | self.onset_h = onset_h 20 | self.outpath = outpath 21 | 22 | def calculation(self, features=None): 23 | """ 24 | ids - ids // 25 | values - array of list of observations 26 | ind_lvs - array of list of variable observed 27 | times - array of list of observed times 28 | num_rnn_grid_times - num of grid times 29 | rnn_grid_times - hourly range to celing of end_time (grid times) 30 | labels - label 31 | ind_times - copy of times (legacy) 32 | num_obs_times - # of times with observations (diff from # obs vals as multiple vals can be obs at same t) 33 | num_obs_values - duh 34 | onset_hour - duh 35 | """ 36 | # reformatting self.data so that all numeric values are in one column and a new column indicates the var_id 37 | if features is None: 38 | variables = ['sysbp', 'diabp', 'meanbp', 'resprate', 'heartrate', 39 | 'spo2_pulsoxy', 'tempc', 'cardiacoutput', 'tvset', 'tvobserved', 40 | 'tvspontaneous', 'peakinsppressure', 'totalpeeplevel', 'o2flow', 'fio2', 41 | 'albumin', 'bands', 'bicarbonate', 'bilirubin', 'creatinine', 42 | 'chloride', 'glucose', 'hematocrit', 'hemoglobin', 'lactate', 43 | 'platelet', 'potassium', 'ptt', 'inr', 'pt', 'sodium', 'bun', 'wbc', 44 | 'creatinekinase', 'ck_mb', 'fibrinogen', 'ldh', 'magnesium', 45 | 'calcium_free', 'po2_bloodgas', 'ph_bloodgas', 'pco2_bloodgas', 46 | 'so2_bloodgas', 'troponin_t'] 47 | 48 | elif features == 'mr_features': 49 | variables = ['sysbp', 'diabp', 'meanbp', 'resprate', 'heartrate', 'spo2_pulsoxy', 50 | 'tempc', 'bicarbonate', 'creatinine', 'chloride', 'glucose', 51 | 'hematocrit', 'hemoglobin', 'lactate', 'platelet', 'potassium', 'ptt', 52 | 'inr', 'pt', 'sodium', 'bun', 'wbc', 'magnesium', 'ph_bloodgas'] 53 | else: 54 | return 1 55 | 56 | # initialise 57 | var_id = len(variables) 58 | var = variables[-1] 59 | self.data["var_id"] = var_id - 1 60 | moddata = self.data.loc[~self.data[var].isna(), ["icustay_id", var, "chart_time", "var_id", "label"]] 61 | moddata.rename(columns={var: "value"}, inplace=True) 62 | # loop 63 | for var_id, var in enumerate(variables[:-1]): 64 | self.data["var_id"] = var_id 65 | temp = self.data.loc[~self.data[var].isna(), ["icustay_id", var, "chart_time", "var_id", "label"]] 66 | temp.rename(columns={var: "value"}, inplace=True) 67 | moddata = moddata.append(temp, sort=False) 68 | self.data = moddata.sort_values(["icustay_id", "chart_time"], inplace=False) 69 | 70 | temp = self.data.groupby("icustay_id", as_index=False).max() 71 | # ids - ids 72 | ids = temp.icustay_id.to_numpy() 73 | end_time = temp.chart_time.to_numpy() 74 | # labels - label 75 | labels = temp.label.to_numpy() 76 | # num_rnn_grid_times - num of grid times 77 | num_rnn_grid_times = np.round(end_time + 1).astype(int) 
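        # Hypothetical example: a stay whose last observation is at chart_time 9.4 h gives
        # end_time = 9.4, so num_rnn_grid_times = round(9.4 + 1) = 10 and the corresponding
        # rnn_grid_times entry built below is np.arange(10), i.e. grid hours 0..9.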
78 | 79 | # values - array of list of observations 80 | values = [] 81 | # times - array of list of observed times 82 | times = [] 83 | # ind_lvs - array of list of variable observed 84 | ind_lvs = [] 85 | # rnn_grid_times - hourly range to ceiling of end_time (grid times) 86 | rnn_grid_times = [] 87 | for i, x in enumerate(ids): 88 | if i % 300 == 0: t_print("id iteration {}".format(i)) 89 | values.append(self.data.loc[self.data.icustay_id == x, "value"].tolist()) 90 | times.append(self.data.loc[self.data.icustay_id == x, "chart_time"].tolist()) 91 | ind_lvs.append(self.data.loc[self.data.icustay_id == x, "var_id"].tolist()) 92 | rnn_grid_times.append(np.arange(num_rnn_grid_times[i])) 93 | values = np.array(values) 94 | times = np.array(times) 95 | ind_lvs = np.array(ind_lvs) 96 | rnn_grid_times = np.array(rnn_grid_times) 97 | 98 | # ind_times - copy of times (legacy) 99 | ind_times = times 100 | # num_obs_values - duh 101 | num_obs_values = self.data.groupby("icustay_id").value.count().to_numpy() 102 | # num_obs_times - # of times with observations 103 | # (diff from # obs vals as multiple vals can be obs at same t) 104 | num_obs_times = self.data[["icustay_id", "chart_time"]].groupby("icustay_id").chart_time.count().to_numpy() 105 | onset_hour = self.onset_h.loc[self.onset_h.icustay_id.isin(ids)]. \ 106 | sort_values(by="icustay_id").onset_hour.to_numpy() 107 | self.result = [values, times, ind_lvs, ind_times, 108 | labels, num_rnn_grid_times, rnn_grid_times, 109 | num_obs_times, num_obs_values, onset_hour, ids] 110 | 111 | def save(self, features=None): 112 | if features is None: 113 | path = os.path.join(head, 'data', self.outpath, 'GP_prep.pkl') 114 | else: 115 | path = os.path.join(head, 'data', self.outpath, "/GP_prep_{}.pkl".format(features)) 116 | with open(path, "wb") as f: 117 | pickle.dump(self.result, f) 118 | 119 | 120 | def main(args): 121 | outpath = args.out_path 122 | cwd = os.path.dirname(os.path.abspath(__file__)) 123 | path = os.path.join(head, 'data') 124 | file = path + outpath + "/full_labvitals.csv" 125 | data = pd.read_csv(file) 126 | onsetfile = path + "processed/onset_hours.csv" 127 | onset = pd.read_csv(onsetfile) 128 | ct = CompactTransform(data, onset, outpath) 129 | ct.calculation() 130 | ct.save() 131 | 132 | 133 | def modular_main(outpath, onset_file_name, mr_features, extension_name): 134 | cwd = os.path.dirname(os.path.abspath(__file__)) 135 | path = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir)) + "/data/" 136 | if extension_name is not None: 137 | file = path + outpath + "/full_labvitals_{}.csv".format(extension_name) 138 | else: 139 | file = path + outpath + "/full_labvitals.csv" 140 | data = pd.read_csv(file) 141 | onsetfile = path + "processed/" + onset_file_name 142 | onset = pd.read_csv(onsetfile) 143 | ct = CompactTransform(data, onset, outpath) 144 | if mr_features: 145 | ct.calculation(features='mr_features') 146 | else: 147 | ct.calculation() 148 | ct.save(features=extension_name) 149 | 150 | 151 | def parse_arg(): 152 | parser = argparse.ArgumentParser() 153 | parser.add_argument("-o", "--out_path", 154 | choices=['test', 'val', 'train'], 155 | help="where to save the output files. 
Choose ['test','val','train']") 156 | return parser.parse_args() 157 | 158 | 159 | if __name__ == '__main__': 160 | args = parse_arg() 161 | main(args) 162 | -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_features/make_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import pandas as pd 5 | import psycopg2 6 | from sqlalchemy import create_engine 7 | from sqlalchemy.types import Integer, DateTime, Numeric 8 | 9 | cwd = os.path.dirname(os.path.abspath(__file__)) 10 | head = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir, os.pardir, os.pardir)) 11 | sys.path.append(head) 12 | from src.utils.debug import * 13 | 14 | 15 | 16 | class MakeData: 17 | 18 | def __init__(self, connect_key="dbname=mimic user=postgres host=localhost password=postgres options=--search_path=mimiciii",): 19 | """ 20 | Initialise function 21 | :param sqluser: user name 22 | :param schema_write_name: schema with write access 23 | :param schema_read_name: schema where mimic is saved 24 | """ 25 | # specify user/password/where the database is 26 | self.connect_key = connect_key 27 | self.cwd = cwd 28 | self.dbname = connect_key.rsplit('dbname=')[1].rsplit(' ')[0] 29 | self.user = connect_key.rsplit('user=')[1].rsplit(' ')[0] 30 | self.password = connect_key.rsplit('password=')[1].rsplit(' ')[0] 31 | try: 32 | self.host = connect_key.rsplit('host=')[1].rsplit(' ')[0] 33 | except: 34 | self.host = 'localhost' 35 | try: 36 | self.port = connect_key.rsplit('port=')[1].rsplit(' ')[0] 37 | except: 38 | self.port = str(5432) 39 | print('working out of the assumption that cwd is ', self.cwd) 40 | 41 | self.engine = create_engine('postgresql+psycopg2://{0}:{1}@{2}:{3}/{4}'.format(self.user, 42 | self.password, 43 | self.host, 44 | self.port, 45 | self.dbname)) 46 | 47 | def create_table(self, sqlfile): 48 | conn = psycopg2.connect(self.connect_key) 49 | cur = conn.cursor() 50 | file = self.cwd + sqlfile 51 | with open(file, 'r') as openfile: 52 | query = openfile.read() 53 | openfile.close() 54 | cur.execute(query) 55 | conn.commit() 56 | conn.close() 57 | 58 | def build_df(self, q_text): 59 | conn = psycopg2.connect(self.connect_key) 60 | query = q_text 61 | return pd.read_sql_query(query, conn) 62 | 63 | def step1_cohort(self): 64 | t_print("welcome to step1") 65 | pre_file = "/sepsis3_cohort_mr.sql" 66 | file = "/hourly-cohort.sql" 67 | t_print("creating hourly cohort ...") 68 | start = time.time() 69 | self.create_table(pre_file) 70 | t_print(".. done 1/2 time: {}".format(time.time() - start)) 71 | self.create_table(file) 72 | t_print(".. done 2/2 time: {}".format(time.time() - start)) 73 | path = os.path.join(head, 'data', 'interim') 74 | file1 = "q13_cases_hourly_ex1c.csv" 75 | file2 = "q13_controls_hourly.csv" 76 | t_print("saving results ...") 77 | start = time.time() 78 | self.build_df("SELECT * FROM cases_hourly_ex1c").to_csv(os.path.join(path, file1)) 79 | t_print("time : {}".format(time.time() - start)) 80 | self.build_df("SELECT * FROM controls_hourly").to_csv(os.path.join(path, file2)) 81 | t_print(".. done! 
time: {}".format(time.time() - start)) 82 | 83 | def step3_match_controls_to_sql(self): 84 | path = os.path.join(head, 'data', 'interim') 85 | file = "q13_matched_controls.csv" 86 | t_print("reading csv..") 87 | mc = pd.read_csv(os.path.join(path, file)) 88 | t_print("read") 89 | print_time() 90 | types = {"icustay_id": Integer(), 91 | "hadm_id": Integer(), 92 | "intime": DateTime(), 93 | "outtime": DateTime(), 94 | "length_of_stay": Numeric(), 95 | "control_onset_hour": Numeric(), 96 | "control_onset_time": DateTime(), 97 | "matched_case_icustay_id": Integer() 98 | } 99 | t_print("saving to SQL...") 100 | # somehow we cannot overwrite tables directly with "to_sql" so let's do that before 101 | conn = psycopg2.connect(self.connect_key) 102 | cur = conn.cursor() 103 | cur.execute("drop table IF EXISTS matched_controls_hourly cascade") 104 | conn.commit() 105 | mc[mc.columns].to_sql("matched_controls_hourly", 106 | self.engine, 107 | if_exists='append', 108 | schema="mimiciii", 109 | dtype=types) 110 | t_print("saved") 111 | 112 | def step4_extract_data(self): 113 | # read all SQL files 114 | files = ["/extract-55h-of-hourly-case-vital-series_ex1c.sql", 115 | "/extract-55h-of-hourly-control-vital-series_ex1c.sql", 116 | "/extract-55h-of-hourly-case-lab-series_ex1c.sql", 117 | "/extract-55h-of-hourly-control-lab-series_ex1c.sql", 118 | "/static-query.sql"] 119 | for file in files: 120 | print_time() 121 | t_print(file) 122 | self.create_table(file) 123 | path = os.path.join(head, 'data', 'interim') 124 | 125 | # save static files 126 | queries = ["select * from icustay_static", 127 | "select * from icustay_static st inner join cases_hourly_ex1c ch on st.icustay_id=ch.icustay_id", 128 | "select * from icustay_static st inner join matched_controls_hourly ch on st.icustay_id=ch.icustay_id",] 129 | files = ["static_variables.csv", 130 | "static_variables_cases.csv", 131 | "static_variables_controls.csv", ] 132 | for q, f in zip(queries, files): 133 | print_time() 134 | t_print(f) 135 | self.build_df(q).to_csv(os.path.join(path, f)) 136 | 137 | # save time series files 138 | queries = ["""select 139 | icustay_id 140 | , subject_id 141 | , chart_time 142 | , sepsis_target 143 | , sysbp 144 | , diabp 145 | , meanbp 146 | , resprate 147 | , heartrate 148 | , spo2_pulsoxy 149 | , tempc 150 | from case_55h_hourly_vitals_ex1c cv order by cv.icustay_id, cv.chart_time""", 151 | """select 152 | icustay_id 153 | , subject_id 154 | , chart_time 155 | , pseudo_target 156 | , sysbp 157 | , diabp 158 | , meanbp 159 | , resprate 160 | , heartrate 161 | , spo2_pulsoxy 162 | , tempc 163 | from control_55h_hourly_vitals_ex1c cv order by cv.icustay_id, cv.chart_time""", 164 | """select 165 | icustay_id 166 | , subject_id 167 | , chart_time 168 | , sepsis_target 169 | , bicarbonate 170 | , creatinine 171 | , chloride 172 | , glucose 173 | , hematocrit 174 | , hemoglobin 175 | , lactate 176 | , platelet 177 | , potassium 178 | , ptt 179 | , inr 180 | , pt 181 | , sodium 182 | , bun 183 | , wbc 184 | , magnesium 185 | , ph_bloodgas 186 | from case_55h_hourly_labs_ex1c cl order by cl.icustay_id, cl.chart_time""", 187 | """select 188 | icustay_id 189 | , subject_id 190 | , chart_time 191 | , pseudo_target 192 | , bicarbonate 193 | , creatinine 194 | , chloride 195 | , glucose 196 | , hematocrit 197 | , hemoglobin 198 | , lactate 199 | , platelet 200 | , potassium 201 | , ptt 202 | , inr 203 | , pt 204 | , sodium 205 | , bun 206 | , wbc 207 | , magnesium 208 | , ph_bloodgas 209 | from control_55h_hourly_labs_ex1c cl order 
by cl.icustay_id, cl.chart_time""" 210 | ] 211 | 212 | files = ["vital_variables_cases.csv", 213 | "vital_variables_controls.csv", 214 | "lab_variables_cases.csv", 215 | "lab_variables_controls.csv",] 216 | 217 | # then do data extraction: group together all readings per timestamp 218 | for q, f in zip(queries, files): 219 | print_time() 220 | t_print(f) 221 | temp = self.build_df(q) 222 | temp.groupby(["icustay_id", "chart_time"], as_index=False).mean().to_csv(os.path.join(path, f)) 223 | 224 | -------------------------------------------------------------------------------- /src/trainers/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import sys 4 | import numpy as np 5 | import tensorflow as tf 6 | from tqdm import tqdm 7 | # appending head path 8 | cwd = os.path.dirname(os.path.abspath(__file__)) 9 | head = os.path.abspath(os.path.join(cwd, os.pardir, os.pardir)) 10 | sys.path.append(head) 11 | from src.loss_n_eval.aucs import evals as uni_evals 12 | from src.loss_n_eval.losses import grad, GP_loss 13 | from src.utils.debug import t_print 14 | 15 | class Trainer: 16 | def __init__(self, 17 | model, 18 | data, 19 | num_epochs, 20 | batch_size, 21 | optimizer, 22 | global_step, 23 | summary_writers, 24 | log_path, 25 | train_only=False, 26 | notebook_friendly=False, 27 | eval_every=20, 28 | late_patients_only=False, 29 | horizon0=False, 30 | lab_vitals_only=False, 31 | weighted_loss=None, 32 | ): 33 | 34 | self.model = model 35 | self.data = data 36 | self.num_epochs = num_epochs 37 | self.batch_size = batch_size 38 | self.optimizer = optimizer 39 | self.global_step = global_step 40 | self.notebook_friendly = notebook_friendly 41 | self.summary_writers = summary_writers 42 | self.log_path = log_path 43 | self.train_only = train_only 44 | self.eval_every = eval_every 45 | self.late_patients_only = late_patients_only 46 | self.horizon0 = horizon0 47 | self.lab_vitals_only = lab_vitals_only 48 | self.weighted_loss = weighted_loss 49 | 50 | # Initialise progress trackers - epoch 51 | self.train_loss_results = [] 52 | self._roc = [] 53 | self._pr = [] 54 | 55 | # Initialise progress trackers - batch 56 | self.train_loss_results_batch = [] 57 | self._roc_batch = [] 58 | self._pr_batch = [] 59 | 60 | if self.late_patients_only: 61 | # 'int' truncates, hence int + 1 finds the ceiling 62 | self.no_batches = int(len(self.data.late_case_patients) * 6 / self.batch_size) + 1 63 | else: 64 | self.no_batches = int(len(self.data.train_case_idx) / self.batch_size) + 1 65 | self.no_dev_batches = int(len(self.data.val_data[-1]) / self.batch_size) + 1 66 | 67 | 68 | def run(self): 69 | for epoch in range(self.num_epochs): 70 | t_print("Start of epoch {}".format(epoch)) 71 | # shuffle data 72 | np.random.shuffle(self.data.train_case_idx) 73 | np.random.shuffle(self.data.train_control_idx) 74 | self.data.apply_reshuffle() 75 | 76 | for batch in tqdm(range(self.no_batches)): 77 | # Load data 78 | # batch_data = Y, T, ind_features, num_distinct_Y, X, num_distinct_X, static, labels, classes 79 | batch_data = next(self.data.next_batch(self.batch_size, batch, late=self.late_patients_only, 80 | horizon0=self.horizon0)) 81 | # batch_data[8] is static 82 | if self.lab_vitals_only: 83 | inputs = batch_data[:7] 84 | else: 85 | inputs = batch_data[:8] 86 | y = batch_data[8] 87 | classes = batch_data[9] 88 | if len(y) > 0: 89 | 90 | # Evaluate loss and gradient 91 | loss_value, grads = grad(self.model, inputs, y, GP=True, 
weighted_loss=self.weighted_loss) 92 | # Apply gradient 93 | self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables), self.global_step) 94 | self.global_step.assign_add(1) 95 | 96 | # Track progress - loss 97 | self.train_loss_results_batch.append(loss_value.numpy()) 98 | 99 | # Track progress - metrics 100 | y_hat = tf.nn.softmax(self.model(inputs)) 101 | roc_auc, pr_auc, _, _ = uni_evals(y.numpy(), y_hat.numpy(), classes, overall=True) 102 | self._roc_batch.append(roc_auc) 103 | self._pr_batch.append(pr_auc) 104 | 105 | # write into tensorboard 106 | step = (epoch * self.no_batches + batch) * self.no_dev_batches 107 | with self.summary_writers['train'].as_default(): 108 | tf.summary.scalar('loss', loss_value.numpy(), step=step) 109 | for i in range(8): 110 | if roc_auc[i] != 0: tf.summary.scalar("roc_{}".format(i), roc_auc[i], step=step) 111 | if pr_auc[i] != 0: tf.summary.scalar("pr_{}".format(i), pr_auc[i], step=step) 112 | 113 | if batch % self.eval_every == 0: 114 | t_print("Epoch {:03d} -- Batch {:03d}: Loss: {:.3f}\tROC o/a:{:.3f}\tPR o/a:{:.3f}".format( 115 | epoch, batch, loss_value.numpy(), roc_auc[7], pr_auc[7])) 116 | if not self.train_only: 117 | # iterate over all horizons 118 | for horizon in range(7): 119 | self.dev_eval_per_horizon(horizon, step) 120 | 121 | # end of batch loop 122 | self.train_loss_results.append(np.mean(self.train_loss_results_batch)) 123 | self._roc.append(np.mean(np.asarray(self._roc_batch), axis=0)) 124 | self._pr.append(np.mean(np.asarray(self._pr_batch), axis=0)) 125 | t_print("End of epoch {:03d}: Loss: {:.3f}\tROC o/a:{:.3f}\tPR o/a:{:.3f}".format( 126 | epoch, self.train_loss_results[-1], self._roc[-1][7], self._pr[-1][7])) 127 | 128 | if not self.train_only: 129 | # save all outputs 130 | all_dev_y = [] 131 | all_dev_y_hat = [] 132 | classes = [] 133 | for dev_batch in range(self.no_dev_batches): 134 | step = (self.num_epochs * self.no_batches) * self.no_dev_batches 135 | y_true, y_hat, _class = self.dev_eval(dev_batch, step) 136 | all_dev_y.append(y_true) 137 | all_dev_y_hat.append(y_hat) 138 | classes.append(_class) 139 | if not self.notebook_friendly: 140 | _to_save = {"epoch": epoch, 141 | "y_true": all_dev_y, 142 | "y_hat": all_dev_y_hat, 143 | "classes": classes, 144 | "weights": self.model.get_weights()} 145 | with open(os.path.join(self.log_path, 'epoch_{}_out.pkl'.format(epoch)), "wb") as f: 146 | pickle.dump(_to_save, f) 147 | 148 | def dev_eval_per_horizon(self, horizon, step): 149 | batch_data = next(self.data.next_batch_dev_small(horizon)) 150 | _, loss_dev, roc_auc, pr_auc = self.step(batch_data) 151 | with self.summary_writers['val'].as_default(): 152 | tf.summary.scalar("loss_dev", loss_dev.numpy(), step=step + horizon) 153 | tf.summary.scalar("roc_{}_dev".format(horizon), roc_auc[horizon], step=step + horizon) 154 | tf.summary.scalar("pr_{}_dev".format(horizon), pr_auc[horizon], step=step + horizon) 155 | # print 156 | t_print("DEV hz {} Loss: {:.3f}\tROC o/a:{:.3f}\tPR o/a:{:.3f}".format(horizon, 157 | loss_dev, roc_auc[7], pr_auc[7])) 158 | 159 | def dev_eval(self, dev_batch, step): 160 | batch_data = next(self.data.next_batch_dev_all(self.batch_size, dev_batch)) 161 | dev_y_hat, loss_dev, roc_auc, pr_auc = self.step(batch_data) 162 | # write into sacred observer 163 | with self.summary_writers['val'].as_default(): 164 | tf.summary.scalar("loss_dev", loss_dev.numpy(), step=step + dev_batch) 165 | for i in range(7): 166 | if roc_auc[i] != 0: tf.summary.scalar("roc_{}_dev".format(i), roc_auc[i], step=step + 
dev_batch) 167 | if pr_auc[i] != 0: tf.summary.scalar("pr_{}_dev".format(i), pr_auc[i], step=step + dev_batch) 168 | # print 169 | t_print("DEV Loss: {:.3f}\tROC o/a:{:.3f}\tPR o/a:{:.3f}".format(loss_dev, roc_auc[7], pr_auc[7])) 170 | # return y_true, y_hat, class 171 | return np.array(batch_data[8]), dev_y_hat, np.array(batch_data[9]) 172 | 173 | def step(self, batch_data): 174 | # batch_data[7] is static, [8] labels, [9] classes 175 | if self.lab_vitals_only: 176 | inputs = batch_data[:7] 177 | else: 178 | inputs = batch_data[:8] 179 | y = batch_data[8] 180 | classes = batch_data[9] 181 | if len(y) > 0: 182 | # Track progress - dev loss 183 | loss_dev = GP_loss(self.model, inputs, y) 184 | # Track progress - dev metrics 185 | dev_y_hat = tf.nn.softmax(self.model(inputs)) 186 | roc_auc, pr_auc, _, _ = uni_evals(y.numpy(), dev_y_hat.numpy(), classes, overall=True) 187 | return dev_y_hat, loss_dev, roc_auc, pr_auc 188 | else: return None, None, None, None 189 | 190 | -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_features/extract-55h-of-hourly-case-vital-series_ex1c.sql: -------------------------------------------------------------------------------- 1 | /* 2 | Extract 55 hours of vital time series of sepsis case icustays (48 hours before onset and 7 hours after onset). 3 | --------------------------------------------------------------------------------------------------------------------- 4 | - MODIFIED VERSION 5 | - SOURCE: https://github.com/MIT-LCP/mimic-code/blob/7ff270c7079a42621f6e011de6ce4ddc0f7fd45c/concepts/firstday/vitals-first-day.sql 6 | - AUTHOR (of this version): Michael Moor, October 2018 7 | - HINT: to add/remove vitals, uncomment/comment both the 'case' statements (e.g. l.35) and the corresponding itemids below (e.g. l.112,113) 8 | - CAVE: For our purposes we did not use CareVue data. Therefore their IDs were removed from the 'case' statement in l.66! 9 | However, you can find them commented out in the 'where' clause at the end of the script.
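- SANITY CHECK (illustrative only, not part of the original script): once the view below has been
  materialised, it can be spot-checked with a query along these lines, using the column names
  defined here and later consumed by make_data.py step4_extract_data:
      select icustay_id, chart_time, sepsis_target, heartrate, sysbp, meanbp, resprate, tempc
      from case_55h_hourly_vitals_ex1c
      order by icustay_id, chart_time
      limit 20;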
10 | --------------------------------------------------------------------------------------------------------------------- 11 | */ 12 | 13 | --extract-time-series.sql 14 | -- extract time series, TEMPLATE/inspiration: vitals-first-day.sql 15 | -- info: to choose only 1 vital: comment out both the case statement (l.14) of the other variables and the corresponding itemids below (l.71-127) 16 | 17 | 18 | DROP MATERIALIZED VIEW IF EXISTS case_55h_hourly_vitals_ex1c CASCADE; 19 | create materialized view case_55h_hourly_vitals_ex1c as 20 | SELECT pvt.icustay_id, pvt.subject_id -- removed , pvt.hadm_id, 21 | , pvt.chart_time 22 | , case 23 | when pvt.chart_time < pvt.sepsis_onset then 0 24 | when pvt.chart_time between pvt.sepsis_onset and (pvt.sepsis_onset+interval '5' hour ) then 1 25 | else 2 end as sepsis_target 26 | --, case 27 | -- when pvt.sepsis_onset > (pvt.intime + interval '150' hour) then 1 28 | -- else 0 end as late_onset_after_150h 29 | 30 | 31 | -- Easier names 32 | 33 | , case when VitalID = 2 then valuenum else null end as SysBP 34 | , case when VitalID = 3 then valuenum else null end as DiaBP 35 | , case when VitalID = 4 then valuenum else null end as MeanBP 36 | , case when VitalID = 5 then valuenum else null end as RespRate 37 | , case when VitalID = 1 then valuenum else null end as HeartRate 38 | , case when VitalID = 7 then valuenum else null end as SpO2_pulsoxy 39 | --, case when VitalID = 8 then valuenum else null end as Glucose 40 | , case when VitalID = 6 then valuenum else null end as TempC 41 | , case when VitalID = 10 then valuenum else null end as CardiacOutput 42 | , case when VitalID = 11 then valuenum else null end as SV 43 | , case when VitalID = 12 then valuenum else null end as SVI 44 | , case when VitalID = 13 then valuenum else null end as SVV 45 | , case when VitalID = 14 then valuenum else null end as TFC 46 | , case when VitalID = 15 then valuenum else null end as TPR 47 | , case when VitalID = 16 then valuenum else null end as TVset 48 | , case when VitalID = 17 then valuenum else null end as TVobserved 49 | , case when VitalID = 18 then valuenum else null end as TVspontaneous 50 | , case when VitalID = 19 then valuenum else null end as Flowrate 51 | , case when VitalID = 20 then valuenum else null end as PeakInspPressure 52 | , case when VitalID = 21 then valuenum else null end as TotalPEEPLevel 53 | , case when VitalID = 22 then valuenum else null end as VitalCapacity 54 | , case when VitalID = 23 then valuenum else null end as O2Flow 55 | , case when VitalID = 24 then valuenum else null end as FiO2 56 | , case when VitalID = 25 then valuenum else null end as CRP 57 | 58 | FROM ( 59 | select ch.icustay_id, ie.subject_id -- removed: ie.subject_id, ie.hadm_id, 60 | , case 61 | when itemid in (220045) and valuenum > 0 and valuenum < 300 then 1 -- HeartRate 62 | when itemid in (220179,220050,225309) and valuenum > 0 and valuenum < 400 then 2 -- SysBP 63 | when itemid in (8368,8440,8441,8555,220180,220051,225310) and valuenum > 0 and valuenum < 300 then 3 -- DiasBP 64 | when itemid in (220052,220181,225312) and valuenum > 0 and valuenum < 300 then 4 -- MeanBP 65 | when itemid in (220210,224690) and valuenum > 0 and valuenum < 70 then 5 -- RespRate 66 | when itemid in (223761,678) and valuenum > 70 and valuenum < 120 then 6 -- TempF, converted to degC in valuenum call 67 | when itemid in (223762,676) and valuenum > 10 and valuenum < 50 then 6 -- TempC 68 | when itemid in (646,220277) and valuenum > 0 and valuenum <= 100 then 7 -- SpO2 69 | when itemid in 
(807,811,1529,3745,3744,225664,220621,226537) and valuenum > 0 then 8 -- Glucose 70 | --when itemid in (227428) and valuenum >= 0 then 9 -- SOFA score (computed in-ICU!) 71 | when itemid in (228369, 224842, 220088, 227543) and valuenum > 0 then 10 -- Cardiac Output (l/min), ids: NICOM, hemodynamics, hemodynamics Thermodilution, CO arterial (hemodynamics) 72 | when itemid in (228374, 227547) and valuenum > 0 then 11 -- SV: Stroke Volume (ml/beat) (id: NICOM, SV Arterial (Hemodynamics)) 73 | when itemid in (228375) and valuenum >= 0 then 12 -- SVI-NICOM: Strove Volume Index (%) (id: NICOM) 74 | when itemid in (228376, 227546) then 13 -- SVV : Strove Volume Variation (no unit) (id: NICOM, SVV arterial (hemodynamics)) 75 | when itemid in (228380) then 14 -- Thoracic Fluid Content (TFC) (no unit) (NICOM) 76 | when itemid in (228381) then 15 -- Total Peripheral Resistance (TPR) (dynes*sec/cm5) (NICOM) 77 | when itemid in (224684) and valuenum > 0 then 16 -- Tidal Volume (set) (ml) 78 | when itemid in (224685) and valuenum > 0 then 17 -- Tidal Volume (observed) (ml) 79 | when itemid in (224686) and valuenum > 0 then 18 -- Tidal Volume (spontaneous) (ml) 80 | when itemid in (224691) and valuenum >= 0 then 19 -- Flow rate (L/min) (respiratory) 81 | when itemid in (224695) and valuenum >= 0 then 20 -- Peak Insp. Pressure metavision chartevents Respiratory (cmH2O) 82 | when itemid in (224700) and valuenum >= 0 then 21 -- Total PEEP Level metavision chartevents Respiratory (cmH2O) 83 | when itemid in (220218) and valuenum > 0 then 22 -- Vital Capacity VC metavision chartevents Respiratory (Liters) 84 | when itemid in (223834, 227582) and valuenum >= 0 then 23 -- O2 Flow O2 Flow metavision chartevents Respiratory L/min (ids: respiratory, BIBAP) 85 | when itemid in (223835) and valuenum >= 0 then 24 -- Inspired O2 Fraction FiO2 metavision chartevents Respiratory (No unit) 86 | when itemid in (227444) and valuenum >= 0 then 25 -- CRP (no values in labevents! 
thus we use chartevents CRP) 87 | 88 | else null end as VitalID 89 | -- convert F to C 90 | , case when itemid in (223761,678) then (valuenum-32)/1.8 else valuenum end as valuenum 91 | , ce.charttime as chart_time 92 | , ch.sepsis_onset 93 | , s3c.intime 94 | 95 | from cases_hourly_ex1c ch -- was icustays ie (changed it below as well) 96 | left join icustays ie 97 | on ch.icustay_id = ie.icustay_id 98 | left join sepsis3_cohort_mr s3c 99 | on ch.icustay_id = s3c.icustay_id 100 | left join chartevents ce 101 | on ch.icustay_id = ce.icustay_id -- removed: ie.subject_id = ce.subject_id and ie.hadm_id = ce.hadm_id and 102 | and ce.charttime between (ch.sepsis_onset-interval '48' hour ) and (ch.sepsis_onset+interval '7' hour ) 103 | 104 | -- exclude rows marked as error 105 | where ce.error=0 and 106 | ce.itemid in -- and sepsis_case = 1 107 | ( 108 | ---- HEART RATE 109 | --211, --"Heart Rate" 110 | 220045, --"Heart Rate" 111 | 112 | -- Systolic/diastolic 113 | 114 | -- 51, -- Arterial BP [Systolic] 115 | -- 442, -- Manual BP [Systolic] 116 | -- 455, -- NBP [Systolic] 117 | -- 6701, -- Arterial BP #2 [Systolic] 118 | 220179, -- Non Invasive Blood Pressure systolic 119 | 220050, -- Arterial Blood Pressure systolic 120 | 225309, -- ART BP systolic 121 | 122 | -- 8368, -- Arterial BP [Diastolic] 123 | -- 8440, -- Manual BP [Diastolic] 124 | -- 8441, -- NBP [Diastolic] 125 | -- 8555, -- Arterial BP #2 [Diastolic] 126 | 220180, -- Non Invasive Blood Pressure diastolic 127 | 220051, -- Arterial Blood Pressure diastolic 128 | 225310, -- ART BP diastolic 129 | 130 | 131 | -- -- MEAN ARTERIAL PRESSURE 132 | -- 456, --"NBP Mean" 133 | -- 52, --"Arterial BP Mean" 134 | -- 6702, -- Arterial BP Mean #2 135 | -- 443, -- Manual BP Mean(calc) 136 | 220052, --"Arterial Blood Pressure mean" 137 | 220181, --"Non Invasive Blood Pressure mean" 138 | 225312, --"ART BP mean" 139 | -- 224322, -- I-ABP mean 140 | 141 | -- RESPIRATORY RATE 142 | -- 618,-- Respiratory Rate 143 | -- 615,-- Resp Rate (Total) 144 | 220210,-- Respiratory Rate 145 | 224690, --, -- Respiratory Rate (Total) 146 | 147 | 148 | -- SPO2, peripheral 149 | 220277, 150 | 151 | -- GLUCOSE, both lab and fingerstick 152 | -- 807,-- Fingerstick Glucose 153 | -- 811,-- Glucose (70-105) 154 | -- 1529,-- Glucose 155 | -- 3745,-- BloodGlucose 156 | -- 3744,-- Blood Glucose 157 | 225664,-- Glucose finger stick 158 | 220621,-- Glucose (serum) 159 | 226537,-- Glucose (whole blood) 160 | 161 | -- -- TEMPERATURE 162 | 223762, -- "Temperature Celsius" 163 | -- 676, -- "Temperature C" 164 | 223761, -- "Temperature Fahrenheit" 165 | -- 678 -- "Temperature F" 166 | 167 | -- --SOFA SCORE (in icu) 168 | -- 227428 169 | 170 | -- Cardiac Output 171 | 228369, -- NICOM, 172 | 224842, -- hemodynamics, 173 | 220088, -- hemodynamics Thermodilution, 174 | 227543, -- CO arterial (hemodynamics) 175 | 176 | -- Stroke Volume 177 | 228374, -- SV NICOM 178 | 227547, -- SV Arterial (Hemodynamics) 179 | 180 | -- Stroke Volume Index 181 | 228375, -- SVI-NICOM: Strove Volume Index (%) NICOM 182 | 183 | -- Stroke Volume Variation 184 | 228376, -- SVV: NICOM 185 | 227546, -- SVV: arterial (hemodynamics)) 186 | 187 | -- Thoracic Fluid Content 188 | 228380, -- TFC (no unit) (NICOM) 189 | 190 | -- Total Peripheral Resistance 191 | 228381, -- TPR (dynes*sec/cm5) (NICOM) 192 | 193 | -- Tidal Volume set 194 | 224684, -- (ml) 195 | 196 | -- Tidal Volume (observed) 197 | 224685, -- (ml) 198 | 199 | -- Tidal Volume (spontaneous) 200 | 224686, -- (ml) 201 | 202 | -- Flow rate (respiratory) 203 | 
224691, -- (L/min) (respiratory) 204 | 205 | -- Peak Insp. Pressure 206 | 224695, -- metavision chartevents Respiratory (cmH2O) 207 | 208 | -- Total PEEP Level 209 | 224700, -- Total PEEP Level metavision chartevents Respiratory (cmH2O) 210 | 211 | -- Vital Capacity VC 212 | 220218, -- Vital Capacity VC metavision chartevents Respiratory (Liters) 213 | 214 | -- O2 Flow 215 | 223834, -- O2 Flow respiratory 216 | 227582, -- O2 Flow BIBAP 217 | -- Inspired O2 Fraction (FiO2) 218 | 223835, -- Inspired O2 Fraction FiO2 metavision chartevents Respiratory (No unit) 219 | 220 | 227444 -- C Reactive Protein (CRP) mg/L metavision chartevents Labs (NO labevents! therefore use chartevents..) 221 | 222 | ) 223 | 224 | ) pvt 225 | WHERE VitalID IS NOT NULL 226 | --group by pvt.subject_id, pvt.hadm_id, pvt.icustay_id 227 | order by pvt.icustay_id, pvt.subject_id, pvt.chart_time; -- removed pvt.hadm_id, 228 | 229 | 230 | 231 | 232 | -------------------------------------------------------------------------------- /src/data_preprocessing/extract_MIMIC_data/extract_features/extract-55h-of-hourly-control-vital-series_ex1c.sql: -------------------------------------------------------------------------------- 1 | /* 2 | Extract 55 hours of vital time series of Control icustays (48 hours before 'control onset' and 7 hours after (buffer)). 3 | --------------------------------------------------------------------------------------------------------------------- 4 | - MODIFIED VERSION 5 | - SOURCE: https://github.com/MIT-LCP/mimic-code/blob/7ff270c7079a42621f6e011de6ce4ddc0f7fd45c/concepts/firstday/vitals-first-day.sql 6 | - AUTHOR (of this version): Michael Moor, October 2018 7 | - HINT: to add/remove vitals uncomment/comment both the 'case' statement (e.g. l.35) the corresponding itemids below (e.g. l.112,113) 8 | - CAVE: For our purposes we did not use CareVue data. Therefore their IDs were removed from the 'case' statement in l.66! 9 | However, you find them commented out in the 'where' clause at the end of the script. 
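- SANITY CHECK (illustrative only, not part of the original script): after materialising this view,
  a rough completeness check is to compare it against matched_controls_hourly (written back to
  Postgres by make_data.py step3_match_controls_to_sql), e.g.:
      select count(distinct mc.icustay_id) as matched_controls,
             count(distinct cv.icustay_id) as controls_with_vitals
      from matched_controls_hourly mc
      left join control_55h_hourly_vitals_ex1c cv on mc.icustay_id = cv.icustay_id;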
10 | --------------------------------------------------------------------------------------------------------------------- 11 | */ 12 | 13 | --extract-time-series.sql 14 | -- extract time series, TEMPLATE/inspiration: vitals-first-day.sql 15 | -- info: to choose only 1 vital: comment out both the case statement (l.14) of the other variables and the corresponding itemids below (l.71-127) 16 | 17 | 18 | DROP MATERIALIZED VIEW IF EXISTS control_55h_hourly_vitals_ex1c CASCADE; 19 | create materialized view control_55h_hourly_vitals_ex1c as 20 | SELECT pvt.icustay_id, pvt.subject_id -- removed , pvt.hadm_id, 21 | , pvt.chart_time 22 | , case 23 | when pvt.chart_time < pvt.control_onset_time then 0 24 | when pvt.chart_time between pvt.control_onset_time and (pvt.control_onset_time+interval '5' hour ) then 1 25 | else 2 end as pseudo_target 26 | --, case 27 | -- when pvt.sepsis_onset > (pvt.intime + interval '150' hour) then 1 28 | -- else 0 end as late_onset_after_150h 29 | 30 | 31 | -- Easier names 32 | 33 | , case when VitalID = 2 then valuenum else null end as SysBP 34 | , case when VitalID = 3 then valuenum else null end as DiaBP 35 | , case when VitalID = 4 then valuenum else null end as MeanBP 36 | , case when VitalID = 5 then valuenum else null end as RespRate 37 | , case when VitalID = 1 then valuenum else null end as HeartRate 38 | , case when VitalID = 7 then valuenum else null end as SpO2_pulsoxy 39 | --, case when VitalID = 8 then valuenum else null end as Glucose 40 | , case when VitalID = 6 then valuenum else null end as TempC 41 | , case when VitalID = 10 then valuenum else null end as CardiacOutput 42 | , case when VitalID = 11 then valuenum else null end as SV 43 | , case when VitalID = 12 then valuenum else null end as SVI 44 | , case when VitalID = 13 then valuenum else null end as SVV 45 | , case when VitalID = 14 then valuenum else null end as TFC 46 | , case when VitalID = 15 then valuenum else null end as TPR 47 | , case when VitalID = 16 then valuenum else null end as TVset 48 | , case when VitalID = 17 then valuenum else null end as TVobserved 49 | , case when VitalID = 18 then valuenum else null end as TVspontaneous 50 | , case when VitalID = 19 then valuenum else null end as Flowrate 51 | , case when VitalID = 20 then valuenum else null end as PeakInspPressure 52 | , case when VitalID = 21 then valuenum else null end as TotalPEEPLevel 53 | , case when VitalID = 22 then valuenum else null end as VitalCapacity 54 | , case when VitalID = 23 then valuenum else null end as O2Flow 55 | , case when VitalID = 24 then valuenum else null end as FiO2 56 | , case when VitalID = 25 then valuenum else null end as CRP 57 | 58 | FROM ( 59 | select ch.icustay_id, ie.subject_id -- removed: ie.subject_id, ie.hadm_id, 60 | , case 61 | when itemid in (220045) and valuenum > 0 and valuenum < 300 then 1 -- HeartRate 62 | when itemid in (220179,220050,225309) and valuenum > 0 and valuenum < 400 then 2 -- SysBP 63 | when itemid in (8368,8440,8441,8555,220180,220051,225310) and valuenum > 0 and valuenum < 300 then 3 -- DiasBP 64 | when itemid in (220052,220181,225312) and valuenum > 0 and valuenum < 300 then 4 -- MeanBP 65 | when itemid in (220210,224690) and valuenum > 0 and valuenum < 70 then 5 -- RespRate 66 | when itemid in (223761,678) and valuenum > 70 and valuenum < 120 then 6 -- TempF, converted to degC in valuenum call 67 | when itemid in (223762,676) and valuenum > 10 and valuenum < 50 then 6 -- TempC 68 | when itemid in (646,220277) and valuenum > 0 and valuenum <= 100 then 7 -- SpO2 
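-- note: the Fahrenheit temperatures matched just above (itemids 223761, 678) share VitalID 6 with the
-- Celsius readings and are converted further down in this subquery via (valuenum-32)/1.8,
-- e.g. a charted 98.6 F becomes (98.6-32)/1.8 = 37.0 C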
69 | when itemid in (807,811,1529,3745,3744,225664,220621,226537) and valuenum > 0 then 8 -- Glucose 70 | --when itemid in (227428) and valuenum >= 0 then 9 -- SOFA score (computed in-ICU!) 71 | when itemid in (228369, 224842, 220088, 227543) and valuenum > 0 then 10 -- Cardiac Output (l/min), ids: NICOM, hemodynamics, hemodynamics Thermodilution, CO arterial (hemodynamics) 72 | when itemid in (228374, 227547) and valuenum > 0 then 11 -- SV: Stroke Volume (ml/beat) (id: NICOM, SV Arterial (Hemodynamics)) 73 | when itemid in (228375) and valuenum >= 0 then 12 -- SVI-NICOM: Strove Volume Index (%) (id: NICOM) 74 | when itemid in (228376, 227546) then 13 -- SVV : Strove Volume Variation (no unit) (id: NICOM, SVV arterial (hemodynamics)) 75 | when itemid in (228380) then 14 -- Thoracic Fluid Content (TFC) (no unit) (NICOM) 76 | when itemid in (228381) then 15 -- Total Peripheral Resistance (TPR) (dynes*sec/cm5) (NICOM) 77 | when itemid in (224684) and valuenum > 0 then 16 -- Tidal Volume (set) (ml) 78 | when itemid in (224685) and valuenum > 0 then 17 -- Tidal Volume (observed) (ml) 79 | when itemid in (224686) and valuenum > 0 then 18 -- Tidal Volume (spontaneous) (ml) 80 | when itemid in (224691) and valuenum >= 0 then 19 -- Flow rate (L/min) (respiratory) 81 | when itemid in (224695) and valuenum >= 0 then 20 -- Peak Insp. Pressure metavision chartevents Respiratory (cmH2O) 82 | when itemid in (224700) and valuenum >= 0 then 21 -- Total PEEP Level metavision chartevents Respiratory (cmH2O) 83 | when itemid in (220218) and valuenum > 0 then 22 -- Vital Capacity VC metavision chartevents Respiratory (Liters) 84 | when itemid in (223834, 227582) and valuenum >= 0 then 23 -- O2 Flow O2 Flow metavision chartevents Respiratory L/min (ids: respiratory, BIBAP) 85 | when itemid in (223835) and valuenum >= 0 then 24 -- Inspired O2 Fraction FiO2 metavision chartevents Respiratory (No unit) 86 | when itemid in (227444) and valuenum >= 0 then 25 -- CRP (no values in labevents! 
thus we use chartevents CRP) 87 | 88 | else null end as VitalID 89 | -- convert F to C 90 | , case when itemid in (223761,678) then (valuenum-32)/1.8 else valuenum end as valuenum 91 | , ce.charttime as chart_time 92 | , ch.control_onset_time 93 | , s3c.intime 94 | 95 | from matched_controls_hourly ch 96 | left join icustays ie 97 | on ch.icustay_id = ie.icustay_id 98 | left join sepsis3_cohort_mr s3c 99 | on ch.icustay_id = s3c.icustay_id 100 | left join chartevents ce 101 | on ch.icustay_id = ce.icustay_id -- removed: ie.subject_id = ce.subject_id and ie.hadm_id = ce.hadm_id and 102 | and ce.charttime between (ch.control_onset_time-interval '48' hour ) and (ch.control_onset_time+interval '7' hour ) 103 | 104 | 105 | -- exclude rows marked as error 106 | where ce.error=0 and 107 | ce.itemid in -- and sepsis_case = 1 108 | ( 109 | ---- HEART RATE 110 | --211, --"Heart Rate" 111 | 220045, --"Heart Rate" 112 | 113 | -- Systolic/diastolic 114 | 115 | -- 51, -- Arterial BP [Systolic] 116 | -- 442, -- Manual BP [Systolic] 117 | -- 455, -- NBP [Systolic] 118 | -- 6701, -- Arterial BP #2 [Systolic] 119 | 220179, -- Non Invasive Blood Pressure systolic 120 | 220050, -- Arterial Blood Pressure systolic 121 | 225309, -- ART BP systolic 122 | 123 | -- 8368, -- Arterial BP [Diastolic] 124 | -- 8440, -- Manual BP [Diastolic] 125 | -- 8441, -- NBP [Diastolic] 126 | -- 8555, -- Arterial BP #2 [Diastolic] 127 | 220180, -- Non Invasive Blood Pressure diastolic 128 | 220051, -- Arterial Blood Pressure diastolic 129 | 225310, -- ART BP diastolic 130 | 131 | 132 | -- -- MEAN ARTERIAL PRESSURE 133 | -- 456, --"NBP Mean" 134 | -- 52, --"Arterial BP Mean" 135 | -- 6702, -- Arterial BP Mean #2 136 | -- 443, -- Manual BP Mean(calc) 137 | 220052, --"Arterial Blood Pressure mean" 138 | 220181, --"Non Invasive Blood Pressure mean" 139 | 225312, --"ART BP mean" 140 | -- 224322, -- I-ABP mean 141 | 142 | -- RESPIRATORY RATE 143 | -- 618,-- Respiratory Rate 144 | -- 615,-- Resp Rate (Total) 145 | 220210,-- Respiratory Rate 146 | 224690, --, -- Respiratory Rate (Total) 147 | 148 | 149 | -- SPO2, peripheral 150 | 220277, 151 | 152 | -- GLUCOSE, both lab and fingerstick 153 | -- 807,-- Fingerstick Glucose 154 | -- 811,-- Glucose (70-105) 155 | -- 1529,-- Glucose 156 | -- 3745,-- BloodGlucose 157 | -- 3744,-- Blood Glucose 158 | 225664,-- Glucose finger stick 159 | 220621,-- Glucose (serum) 160 | 226537,-- Glucose (whole blood) 161 | 162 | -- -- TEMPERATURE 163 | 223762, -- "Temperature Celsius" 164 | -- 676, -- "Temperature C" 165 | 223761, -- "Temperature Fahrenheit" 166 | -- 678 -- "Temperature F" 167 | 168 | -- --SOFA SCORE (in icu) 169 | -- 227428 170 | 171 | -- Cardiac Output 172 | 228369, -- NICOM, 173 | 224842, -- hemodynamics, 174 | 220088, -- hemodynamics Thermodilution, 175 | 227543, -- CO arterial (hemodynamics) 176 | 177 | -- Stroke Volume 178 | 228374, -- SV NICOM 179 | 227547, -- SV Arterial (Hemodynamics) 180 | 181 | -- Stroke Volume Index 182 | 228375, -- SVI-NICOM: Strove Volume Index (%) NICOM 183 | 184 | -- Stroke Volume Variation 185 | 228376, -- SVV: NICOM 186 | 227546, -- SVV: arterial (hemodynamics)) 187 | 188 | -- Thoracic Fluid Content 189 | 228380, -- TFC (no unit) (NICOM) 190 | 191 | -- Total Peripheral Resistance 192 | 228381, -- TPR (dynes*sec/cm5) (NICOM) 193 | 194 | -- Tidal Volume set 195 | 224684, -- (ml) 196 | 197 | -- Tidal Volume (observed) 198 | 224685, -- (ml) 199 | 200 | -- Tidal Volume (spontaneous) 201 | 224686, -- (ml) 202 | 203 | -- Flow rate (respiratory) 204 | 224691, -- (L/min) 
(respiratory) 205 | 206 | -- Peak Insp. Pressure 207 | 224695, -- metavision chartevents Respiratory (cmH2O) 208 | 209 | -- Total PEEP Level 210 | 224700, -- Total PEEP Level metavision chartevents Respiratory (cmH2O) 211 | 212 | -- Vital Capacity VC 213 | 220218, -- Vital Capacity VC metavision chartevents Respiratory (Liters) 214 | 215 | -- O2 Flow 216 | 223834, -- O2 Flow respiratory 217 | 227582, -- O2 Flow BIBAP 218 | -- Inspired O2 Fraction (FiO2) 219 | 223835, -- Inspired O2 Fraction FiO2 metavision chartevents Respiratory (No unit) 220 | 221 | 227444 -- C Reactive Protein (CRP) mg/L metavision chartevents Labs (NO labevents! therefore use chartevents..) 222 | 223 | ) 224 | 225 | ) pvt 226 | WHERE VitalID IS NOT NULL 227 | --group by pvt.subject_id, pvt.hadm_id, pvt.icustay_id 228 | order by pvt.icustay_id, pvt.subject_id, pvt.chart_time; -- removed pvt.hadm_id, 229 | 230 | 231 | 232 | 233 | --------------------------------------------------------------------------------
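For orientation, two short usage sketches tying the pieces above together. They are illustrative only and not part of the repository; import paths and file locations follow the directory layout and defaults shown above and may need adjusting to your setup.

Chaining the MIMIC-III extraction steps defined in make_data.py (a control-matching step, presumably match_controls.py, is expected to produce data/interim/q13_matched_controls.csv before step 3):

    # sketch: driving the MakeData steps against a local MIMIC-III Postgres instance
    from src.data_preprocessing.extract_MIMIC_data.extract_features.make_data import MakeData

    md = MakeData()                    # default connect_key: dbname=mimic user=postgres host=localhost ...
    md.step1_cohort()                  # builds cases_hourly_ex1c / controls_hourly and dumps them to data/interim
    # run the control matching here so that data/interim/q13_matched_controls.csv exists
    md.step3_match_controls_to_sql()   # writes the matched controls back into Postgres
    md.step4_extract_data()            # materialises the 55h vital/lab views and saves the CSVs

Unpacking the pickle written by CompactTransform.save() (the eleven entries follow the order of self.result above; the path assumes the 'train' output directory was chosen via --out_path):

    # sketch: loading GP_prep.pkl and recovering the per-icustay arrays
    import pickle

    with open("data/train/GP_prep.pkl", "rb") as f:
        (values, times, ind_lvs, ind_times, labels, num_rnn_grid_times,
         rnn_grid_times, num_obs_times, num_obs_values, onset_hour, ids) = pickle.load(f)

    print(len(ids), "icustays; first stay has", len(values[0]), "observed values")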