├── models ├── __init__.py ├── ards_lstm_model.py └── vent_lstm_model.py ├── analyzer ├── __init__.py ├── container.py ├── analyzer_utils.py ├── analyzer.py ├── ards_nearest_cls.py ├── vent_nearest_cls.py ├── vent_logistic_regression.py ├── ards_logistic_regression.py ├── ards_catboost_dynamic.py ├── dataset_explore │ ├── dataset_report.py │ ├── vent_explore.py │ ├── ards_explore.py │ └── raw_explore.py ├── vent_catboost_dynamic.py ├── cross_validation.py ├── vent_lstm.py └── ards_lstm.py ├── datasets ├── __init__.py ├── helper.py ├── cv_dataset.py └── derived_raw_dataset.py ├── start.sh ├── .vscode ├── settings.json └── launch.json ├── tools ├── data │ ├── __init__.py │ ├── utils.py │ ├── label_generator.py │ └── data_generator.py ├── __init__.py ├── config_loader.py ├── logging.py ├── module_test.py ├── generic.py ├── feature_importance.py └── plot.py ├── documents ├── general_pipeline.png ├── processing.md └── CHANGELOG.md ├── requirements.txt ├── .gitignore ├── main.py ├── configs ├── cv_dataset.yml ├── global_config.yaml ├── mimiciv_dataset_raw.yaml ├── mimiciv_dataset_ards.yaml ├── mimiciv_dataset_vent.yaml └── analyzers.yml ├── launch_list.yml ├── test.py ├── README_CN.md └── README.md /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /analyzer/__init__.py: -------------------------------------------------------------------------------- 1 | from .analyzer import Analyzer -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .derived_ards_dataset import * -------------------------------------------------------------------------------- /start.sh: -------------------------------------------------------------------------------- 1 | rm main.log 2 | nohup python main.py >> main.log 2>&1 & -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.analysis.typeCheckingMode": "off", 3 | } -------------------------------------------------------------------------------- /tools/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_generator import * 2 | from .label_generator import * 3 | from .utils import * -------------------------------------------------------------------------------- /documents/general_pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/on1262/sepsisdataprocessing/HEAD/documents/general_pipeline.png -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .generic import * 2 | from .plot import * 3 | from .config_loader import GLOBAL_CONF_LOADER 4 | from .logging import logger 5 | from .metrics import * 6 | from .feature_importance import * -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | captum==0.6.0 2 | compress_pickle==2.1.0 3 | matplotlib==3.8.0 4 | missingno==0.5.2 5 | numpy==1.23.5 6 | pandas==2.0.2 7 | PyYAML==6.0.1 8 | scikit_learn==1.2.2 9 | 
shap==0.42.1 10 | torch==1.13.1+cu116 11 | tqdm==4.65.0 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | mimic-iv-2.2.zip 2 | __pycache__ 3 | /data/ 4 | /outputs/ 5 | *.pyc 6 | *.pt 7 | *.log 8 | catboost_info/ 9 | nohup.out 10 | simhei.ttf 11 | generate_requirements.sh 12 | test_plot/ 13 | tools/module_test.py 14 | launch_list.yml 15 | 16 | 17 | test.py 18 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from analyzer import Analyzer 2 | import yaml 3 | 4 | with open('./launch_list.yml', 'r', encoding='utf-8') as fp: 5 | analyzer_params = yaml.load(fp, Loader=yaml.SafeLoader)['analyzer_params'] 6 | 7 | if __name__ == '__main__': 8 | if len(analyzer_params) > 0: 9 | analyzer = Analyzer(analyzer_params) 10 | print('Done') -------------------------------------------------------------------------------- /configs/cv_dataset.yml: -------------------------------------------------------------------------------- 1 | --- 2 | extract_cols: 3 | - 性别 4 | - 出生年月 5 | - 体重(kg) 6 | - 身高(cm) 7 | - 入住ICU日期 8 | - '#dX_tX_最高呼吸频率(次/min)' 9 | - '#dX_tX_最低呼吸频率(次/min)' 10 | - '#dX_tX_最高心率(次/min)' 11 | - '#dX_tX_最低心率(次/min)' 12 | - '#dX_tX_最高SPO2(%)' 13 | - '#dX_tX_最低SPO2(%)' 14 | - '#dX_tX_最高体温(℃)' 15 | - '#dX_PaO2(mmHg)' 16 | - '#dX_PH' 17 | - '#dX_PaO2(mmHg) / FiO2(%)' -------------------------------------------------------------------------------- /launch_list.yml: -------------------------------------------------------------------------------- 1 | --- 2 | 3 | # analyzers will run in sequence according to the following order 4 | analyzer_params: 5 | # - ards_nearest_4cls 6 | - ards_catboost_dynamic 7 | # - ards_feature_explore 8 | # - ards_logistic_reg 9 | # - ards_lstm 10 | 11 | 12 | #- dataset_report 13 | # - vent_feature_explore 14 | # - vent_catboost_dynamic 15 | # - vent_logistic_reg 16 | # - vent_nearest_3cls 17 | # - vent_lstm 18 | 19 | 20 | # - cross_validation 21 | # - raw_feature_explore -------------------------------------------------------------------------------- /analyzer/container.py: -------------------------------------------------------------------------------- 1 | from datasets import MIMICIV_ARDS_Dataset 2 | import tools 3 | import os 4 | 5 | 6 | class DataContainer(): 7 | '''存放数据和一些与模型无关的内容''' 8 | def __init__(self): 9 | self._conf = tools.GLOBAL_CONF_LOADER['analyzer']['data_container'] # 这部分是global, 对外界不可见 10 | self.n_fold = self._conf['n_fold'] 11 | self.seed = self._conf['seed'] 12 | # for feature importance 13 | self.register_values = {} 14 | 15 | def get_analyzer_params(self, model_name) -> dict: 16 | '''根据数据集和模型名不同, 获取所需的模型参数''' 17 | paths = tools.GLOBAL_CONF_LOADER['paths'] 18 | analyzer_params = tools.Config('configs/analyzers.yml')[model_name] 19 | analyzer_params['paths'] = paths # 添加global config的paths到params中 20 | return analyzer_params 21 | 22 | def clear_register(self): 23 | self.register_values.clear() -------------------------------------------------------------------------------- /analyzer/analyzer_utils.py: -------------------------------------------------------------------------------- 1 | from tools import logger as logger 2 | import pickle 3 | import numpy as np 4 | import tools 5 | import os 6 | 7 | def create_final_result(out_dir): 8 | '''Collect result.log from each folder, merge to final result.log''' 9 
| logger.info('Creating final result') 10 | with open(os.path.join(out_dir, 'final_result.log'), 'w') as final_f: 11 | for dir in os.listdir(out_dir): 12 | p = os.path.join(out_dir, dir) 13 | if os.path.isdir(p): 14 | if 'result.log' in os.listdir(p): 15 | rp = os.path.join(p, 'result.log') 16 | logger.info(f'Find: {rp}') 17 | with open(rp, 'r') as f: 18 | final_f.write(f.read()) 19 | final_f.write('\n') 20 | logger.info(f'Final result saved at ' + os.path.join(out_dir, 'final_result.log')) -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // 使用 IntelliSense 了解相关属性。 3 | // 悬停以查看现有属性的描述。 4 | // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Current File", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal", 13 | "justMyCode": true 14 | }, 15 | { 16 | "name": "Main", 17 | "type": "python", 18 | "request": "launch", 19 | "program": "${cwd}/main.py", 20 | "console": "integratedTerminal", 21 | "justMyCode": true 22 | }, 23 | { 24 | "name": "Dataset", 25 | "type": "python", 26 | "request": "launch", 27 | "program": "${cwd}/datasets/mimic_dataset.py", 28 | "console": "integratedTerminal", 29 | "justMyCode": true 30 | }, 31 | ] 32 | } -------------------------------------------------------------------------------- /documents/processing.md: -------------------------------------------------------------------------------- 1 | ## Generating sepsis3.csv 2 | 3 | The extracted sepsis3 table has **32971** rows, **25596** patients and contains the columns: 4 | - subject_id 5 | - stay_id 6 | - antibiotic_time 7 | - culture_time 8 | - suspected_infection_time 9 | - sofa_time 10 | - sofa_score 11 | - respiration 12 | - coagulation 13 | - liver 14 | - cardiovascular 15 | - cns 16 | - renal 17 | - sepsis3 18 | 19 | **step1: build postgresql** 20 | 21 | ``` 22 | createdb mimiciv 23 | 24 | cd ~/mimic-code/mimic-iv/buildmimic/postgres 25 | 26 | psql -d mimiciv -f create.sql 27 | 28 | psql -d mimiciv -v ON_ERROR_STOP=1 -v mimic_data_dir=/path/to/mimic-iv -f load.sql 29 | 30 | ``` 31 | **step2: build concepts** 32 | 33 | ``` 34 | cd ~/mimic-code/mimic-iv/concepts_postgres 35 | 36 | psql -d mimiciv 37 | 38 | \i postgres-functions.sql -- only needs to be run once 39 | 40 | \i postgres-make-concepts.sql 41 | 42 | ``` 43 | **step3: extract csv** 44 | 45 | ``` 46 | 47 | \copy (SELECT * FROM mimiciv_derived.sepsis3) TO '~/sepsis3.csv' WITH CSV HEADER; 48 | 49 | ``` -------------------------------------------------------------------------------- /models/ards_lstm_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | import torch 4 | import tools 5 | from tools import logger as logger 6 | 7 | class ArdsLSTMModel(nn.Module): 8 | '''带预测窗口的多分类判别模型''' 9 | def __init__(self, in_channels:int, n_cls, hidden_size=128) -> None: 10 | super().__init__() 11 | self.in_channels = in_channels 12 | self.n_cls = n_cls 13 | self.hidden_size = hidden_size 14 | 15 | self.norm = nn.BatchNorm1d(num_features=in_channels) 16 | self.den = nn.Linear(in_features=hidden_size, out_features=n_cls) 17 | self.lstm = nn.LSTM(input_size=in_channels, hidden_size=hidden_size, batch_first=True) 18 | 19 | def forward(self, x:torch.Tensor): 20 | # x: (batch, feature, time) 21 | x = self.norm(x) 
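# nn.BatchNorm1d treats dim 1 as the channel dim, so each feature is normalized over the batch and time steps; the tensor shape is unchanged here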
22 | # x: (batch, feature, time) 23 | x = x.transpose(1,2) 24 | # x: (batch, time, feature) out带有tanh 25 | x, _ = self.lstm(x) 26 | # x: (batch, time, hidden_size) 27 | x = self.den(x) 28 | # x: (batch, time, n_cls) 29 | return x -------------------------------------------------------------------------------- /models/vent_lstm_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | import torch 4 | import tools 5 | from tools import logger as logger 6 | 7 | class VentLSTMModel(nn.Module): 8 | '''带预测窗口的多分类判别模型''' 9 | def __init__(self, in_channels:int, n_cls, hidden_size=128) -> None: 10 | super().__init__() 11 | self.in_channels = in_channels 12 | self.n_cls = n_cls 13 | self.hidden_size = hidden_size 14 | 15 | self.norm = nn.BatchNorm1d(num_features=in_channels) 16 | self.den = nn.Linear(in_features=hidden_size, out_features=n_cls) 17 | self.lstm = nn.LSTM(input_size=in_channels, hidden_size=hidden_size, batch_first=True) 18 | 19 | def forward(self, x:torch.Tensor): 20 | # x: (batch, feature, time) 21 | x = self.norm(x) 22 | # x: (batch, feature, time) 23 | x = x.transpose(1,2) 24 | # x: (batch, time, feature) out带有tanh 25 | x, _ = self.lstm(x) 26 | # x: (batch, time, hidden_size) 27 | x = self.den(x) 28 | # x: (batch, time, n_cls) 29 | return x -------------------------------------------------------------------------------- /configs/global_config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | work_dir: '.' 3 | paths: 4 | out_dir: outputs 5 | mimic-iv-ards: 6 | mimic_dir: data/mimic-iv 7 | cache_dir: data/mimic-iv-ards/cache 8 | conf_manual_path: configs/mimiciv_dataset_ards.yaml 9 | sepsis_patient_path: data/mimic-iv/sepsis_result/sepsis3.csv 10 | out_dir: outputs/mimic-iv-ards 11 | mimic-iv-raw: 12 | mimic_dir: data/mimic-iv 13 | cache_dir: data/mimic-iv-raw/cache 14 | conf_manual_path: configs/mimiciv_dataset_raw.yaml 15 | out_dir: outputs/mimic-iv-raw 16 | mimic-iv-vent: 17 | mimic_dir: data/mimic-iv 18 | cache_dir: data/mimic-iv-vent/cache 19 | conf_manual_path: configs/mimiciv_dataset_vent.yaml 20 | ventilation_path: data/mimic-iv/ventilation_result/ventilation.csv 21 | out_dir: outputs/mimic-iv-vent 22 | cv: 23 | conf_manual_path: configs/cv_dataset.yml 24 | data_dir: data/cv/单病种数据导出.csv 25 | cache_dir: data/cv/cache 26 | out_dir: outputs/cv 27 | 28 | 29 | analyzer: 30 | data_container: 31 | n_fold: 5 32 | seed: 100 33 | -------------------------------------------------------------------------------- /tools/config_loader.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | 4 | 5 | ''' 6 | ConfigLoader: load and process global configs from a single yaml file 7 | work directory will be appended if 'work_dir' in the keys of dict 8 | ''' 9 | class ConfigLoader(): 10 | def __init__(self, glob_conf_path:str) -> None: 11 | self.glob_conf_path = glob_conf_path 12 | with open(self.glob_conf_path, 'r',encoding='utf-8') as fp: 13 | self.glob_conf = yaml.load(fp, Loader=yaml.SafeLoader) 14 | # append work_dir to every values under 'paths' 15 | self.glob_conf['paths'] = self.process_path(self.glob_conf['work_dir'], self.glob_conf['paths']) 16 | 17 | def process_path(self, root, conf_dict): 18 | for key in conf_dict: 19 | if isinstance(conf_dict[key], str): 20 | conf_dict[key] = os.path.join(root, conf_dict[key]) 21 | elif isinstance(conf_dict[key], dict): 22 | conf_dict[key] = 
self.process_path(root, conf_dict[key]) 23 | return conf_dict 24 | 25 | def __getitem__(self, __key: str): 26 | return self.glob_conf[__key] 27 | 28 | 29 | 30 | GLOBAL_CONF_LOADER = ConfigLoader('./configs/global_config.yaml') 31 | 32 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import seaborn as sns 2 | from matplotlib import pyplot as plt 3 | import tools 4 | import numpy as np 5 | import pandas as pd 6 | 7 | def plot_missrate_comp(): 8 | processed_row = tools.load_pkl('outputs/feature_explore[ards@origin]/row_missrate.pkl').flatten() 9 | processed_col = tools.load_pkl('outputs/feature_explore[ards@origin]/col_missrate.pkl').flatten() 10 | raw_row = tools.load_pkl('outputs/feature_explore[raw@version]/row_missrate.pkl').flatten() 11 | raw_col = tools.load_pkl('outputs/feature_explore[raw@version]/col_missrate.pkl').flatten() 12 | row_data = np.concatenate([processed_row, raw_row], axis=0) 13 | col_data = np.concatenate([processed_col, raw_col], axis=0) 14 | for data, label in zip([row_data, col_data], ['row', 'col']): 15 | df = pd.DataFrame(data, columns=['data']) 16 | df['source'] = 'raw' 17 | lens = len(processed_row) if label == 'row' else len(processed_col) 18 | df.loc[:lens, 'source'] = 'processed' 19 | sns.histplot(df, x='data', hue='source', bins=20, stat='proportion', common_norm=False, shrink=0.95, element='bars', edgecolor=None) 20 | plt.xlabel(f'{label} missrate') 21 | plt.savefig(f'test_plot/{label}_missrate.png') 22 | plt.close() 23 | 24 | 25 | 26 | if __name__ == '__main__': 27 | plot_missrate_comp() -------------------------------------------------------------------------------- /tools/logging.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | logger.info('Import logger from loguru') 6 | 7 | 8 | class SummaryWriter: 9 | '''A simple version of summary writer that mimics tensorboard''' 10 | def __init__(self) -> None: 11 | self.clear() 12 | 13 | def add_scalar(self, tag, value, global_step): 14 | self.data[tag] = self.data[tag] + [(value, global_step)] if tag in self.data else [(value, global_step)] 15 | 16 | def clear(self): 17 | self.data = {} 18 | 19 | def plot(self, tags:[list, str], k_fold=False, log_y=False, title='Title', out_path:str=None): 20 | plt.figure(figsize=(12, 8)) 21 | plt.title(title) 22 | tags:list = [tags] if isinstance(tags, str) else tags 23 | for tag in tags: 24 | tag_data = np.asarray(self.data[tag]) 25 | color = f'C{tags.index(tag)}' if not k_fold else 'grey' 26 | alpha = 1.0 if not k_fold else 0.5 27 | plt.plot(tag_data[:, 1], tag_data[:, 0], '-o', color=color, alpha=alpha) 28 | if k_fold: # tags are different folds. 
Plot mean and std bar 29 | ax = plt.gca() 30 | data = np.asarray([self.data[tag] for tag in tags]) 31 | x_ticks = data.mean(axis=0)[:, 1] 32 | mean_y = data.mean(axis=0)[:, 0] # (n_steps,) 33 | std_data = data[:, :, 0].std(axis=0) # (n_steps,) 34 | ax.fill_between(x_ticks, mean_y + std_data * 0.5, mean_y - std_data * 0.5, alpha=0.5, linewidth=0) 35 | plt.plot(x_ticks, mean_y, '-o', color='C0') 36 | if log_y: 37 | plt.yscale('log') 38 | if out_path is not None: 39 | plt.savefig(out_path) 40 | plt.close() 41 | -------------------------------------------------------------------------------- /documents/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # CHANGELOG 2 | 3 | ## 2023.8.18 4 | 5 | - Fixed the missing cache folder when mimic_dataset is first preprocessed, or you can fix it manually by creating a new `cache` empty folder under `data/mimic-iv`. 6 | - Added instructions for configuring the pytorch CUDA version in the readme 7 | 8 | ## 2023.10.26 9 | 10 | Upgrade to generic framework 11 | 12 | ### Important changes 13 | - Improved code readability 14 | - Decoupled generic code and Sepsis/ARDS-specific processing code, customized processing by modifying abstract functions through derived classes. 15 | - Improved data processing speed, it takes about 40 minutes to process the whole MIMIC-IV dataset. 16 | - Include new data such as hosp.labevents/ed.vitalsign, linking MIMIC-IV and ED datasets. 17 | - Fixed potential data leakage by changing linear interpolation to historical nearest neighbor padding 18 | - 19 | 20 | ### Other changes 21 | - Compressed storage cache files and read only necessary files. 22 | - Iterative culling of high missing samples and high missing features, this algorithm gives better results 23 | - Configuration files are now in yaml format for easier annotation. 24 | 25 | ## 2023.11.24 26 | 27 | ### Important changes 28 | - remove everything about sepsis in mimic-iv-raw dataset 29 | - add example pipeline: ventilation 30 | - mimic_dataset->mimic_ards_dataset. All these 3 datasets are examples 31 | - New ways of interpolation 32 | - nearest_static will be changed to latest_static. It will only obtain latest history. 
return -1 if no history is found
33 | - empty align target: push forward the start time when at least one dynamic feature is valid
34 | 
35 | ### Other changes
36 | - fix bug: admittime < dischtime
37 | - fix bug: careunit now is in static data
38 | - fix bug: generate label
39 | - empty align id is available
40 | 
41 | ## 2023.11.30
42 | 
43 | - default compression format: xz -> gz (faster)
44 | - value->valuenum in hosp extraction
45 | - add category filter (>2%)
46 | 
--------------------------------------------------------------------------------
/analyzer/analyzer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tools
3 | import os
4 | from tools import logger as logger
5 | from .container import DataContainer
6 | 
7 | from analyzer.ards_catboost_dynamic import ARDSCatboostAnalyzer
8 | from analyzer.ards_lstm import ArdsLSTMAnalyzer
9 | from analyzer.ards_nearest_cls import ArdsNearest4ClsAnalyzer
10 | from analyzer.ards_logistic_regression import ArdsLogisticRegAnalyzer
11 | 
12 | from analyzer.cross_validation import CV_Analyzer
13 | from analyzer.dataset_explore.dataset_report import DatasetReport
14 | 
15 | from analyzer.dataset_explore.ards_explore import ArdsFeatureExplorer
16 | from analyzer.dataset_explore.raw_explore import RawFeatureExplorer
17 | from analyzer.dataset_explore.vent_explore import VentFeatureExplorer
18 | 
19 | from analyzer.vent_catboost_dynamic import VentCatboostDynamicAnalyzer
20 | from analyzer.vent_nearest_cls import VentNearest3ClsAnalyzer
21 | from analyzer.vent_lstm import VentLSTMAnalyzer
22 | from analyzer.vent_logistic_regression import VentLogisticRegAnalyzer
23 | 
24 | 
25 | 
26 | class Analyzer:
27 |     def __init__(self, params:list) -> None:
28 |         '''
29 |         params: startup list of analyzer names; if None, call run_sub_analyzer manually
30 |         '''
31 |         self.container = DataContainer()
32 |         self.analyzer_dict = {
33 |             'ards_nearest_4cls': ArdsNearest4ClsAnalyzer,
34 |             'ards_catboost_dynamic': ARDSCatboostAnalyzer,
35 |             'ards_feature_explore': ArdsFeatureExplorer,
36 |             'ards_lstm': ArdsLSTMAnalyzer,
37 |             'ards_logistic_reg': ArdsLogisticRegAnalyzer,
38 | 
39 |             'vent_feature_explore': VentFeatureExplorer,
40 |             'vent_catboost_dynamic': VentCatboostDynamicAnalyzer,
41 |             'vent_nearest_3cls': VentNearest3ClsAnalyzer,
42 |             'vent_lstm': VentLSTMAnalyzer,
43 |             'vent_logistic_reg': VentLogisticRegAnalyzer,
44 | 
45 |             'cross_validation': CV_Analyzer,
46 |             'dataset_report': DatasetReport,
47 |             'raw_feature_explore': RawFeatureExplorer,
48 | 
49 |         }
50 |         if params is not None:
51 |             for name in params:
52 |                 for label in self.analyzer_dict.keys():
53 |                     if label == name:
54 |                         self.run_sub_analyzer(name, label)
55 |                         break
56 | 
57 |     def run_sub_analyzer(self, analyzer_name, label):
58 |         logger.info(f'Run Analyzer: {analyzer_name}')
59 |         params = self.container.get_analyzer_params(analyzer_name)
60 |         params['analyzer_name'] = analyzer_name
61 |         sub_analyzer = self.analyzer_dict[label](params, self.container)
62 |         sub_analyzer.run()
63 |         # utils.create_final_result()
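# Usage sketch (illustrative, not part of this file): passing params=None skips the startup loop,
# so a single sub-analyzer can be launched manually, e.g.
#   analyzer = Analyzer(params=None)
#   analyzer.run_sub_analyzer('ards_catboost_dynamic', 'ards_catboost_dynamic')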
--------------------------------------------------------------------------------
/tools/data/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | import torch
3 | import numpy as np
4 | import tools
5 | import compress_pickle as pickle
6 | from os.path import exists, join as osjoin
7 | from abc import abstractmethod
8 | 
9 | class Normalization():
10 |     def __init__(self, norm_dict:dict, total_keys:list) -> None:
11 |         self.means = np.asarray([norm_dict[key]['mean'] for key in total_keys])
12 |         self.stds = np.asarray([norm_dict[key]['std'] for key in total_keys])
13 | 
14 |     def restore(self, in_data, selected_idx):
15 |         # restore de-normalized data; inverse of __call__, so broadcast over the last (time) axis
16 |         # in_data: (..., n_selected_fea, seqs_len)
17 |         means, stds = self.means[selected_idx], self.stds[selected_idx] + 1e-4
18 |         out = in_data * stds[:, None] + means[:, None]
19 |         return out
20 | 
21 |     def __call__(self, in_data, selected_idx) -> Any:
22 |         # in_data: (..., n_selected_fea, seqs_len)
23 |         means, stds = self.means[selected_idx], self.stds[selected_idx] + 1e-4
24 |         out = (in_data - means[:, None]) / (stds[:, None])
25 |         return out
26 | 
27 | def Collect_Fn(data_list:list):
28 |     result = {}
29 |     result['data'] = np.stack([d['data'] for d in data_list], axis=0)
30 |     result['length'] = np.asarray([d['length'] for d in data_list], dtype=np.int32)
31 |     return result
32 | 
33 | def unroll(x:np.ndarray, mask:np.ndarray):
34 |     # x: (batch, n_fea, seqs_len) or (batch, seqs_len) or (batch, seqs_len, n_cls)
35 |     # mask: (batch, seqs_len)
36 |     assert(len(x.shape) <= 3 and len(mask.shape) == 2)
37 |     if len(x.shape) == 2:
38 |         return x.flatten()[mask.flatten()]
39 |     elif x.shape[2] == mask.shape[1]:
40 |         batch, n_fea, seqs_len = x.shape
41 |         x = np.transpose(x, (0, 2, 1)).reshape((batch*seqs_len, n_fea))
42 |         return x[mask.flatten(), :]
43 |     elif x.shape[1] == mask.shape[1]:
44 |         batch, seqs_len, n_cls = x.shape
45 |         x = x.reshape((batch*seqs_len, n_cls))
46 |         return x[mask.flatten(), :]
47 |     else:
48 |         assert(0)
49 | 
50 | def map_func(a:np.ndarray, mapping:dict):
51 |     '''
52 |     mapping: key=origin_idx, value=target_idx
53 |     '''
54 |     a_shape = list(a.shape)
55 |     n_targets = len(np.unique(list(mapping.values())))
56 |     a_shape[-1] = n_targets
57 |     result = np.zeros(tuple(a_shape))
58 |     for k, v in mapping.items():
59 |         result[..., v] += a[..., k]
60 |     return result
61 | 
62 | def cal_label_weight(n_cls, label:np.ndarray):
63 |     '''
64 |     Get the weight of n_cls inversely proportional to the number.
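    (e.g. for 3 classes with empirical frequencies [0.5, 0.3, 0.2], the inverse frequencies [2.0, 3.3, 5.0] are normalized to weights of roughly [0.19, 0.32, 0.48])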
65 | label: (..., n_cls) 66 | return: (n_cls,) 67 | ''' 68 | hard_label = np.argmax(label, axis=-1).flatten() 69 | weight = np.asarray([np.mean(hard_label == c) for c in range(n_cls)]) 70 | weight = 1 / weight 71 | weight = weight / np.sum(weight) 72 | return weight 73 | 74 | -------------------------------------------------------------------------------- /tools/data/label_generator.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import torch 3 | import numpy as np 4 | import tools 5 | import compress_pickle as pickle 6 | from os.path import exists, join as osjoin 7 | from abc import abstractmethod 8 | 9 | class LabelGenerator(): 10 | def __init__(self) -> None: 11 | pass 12 | 13 | @abstractmethod 14 | def __call__(self, slice) -> Any: 15 | pass 16 | 17 | class LabelGenerator_cls(LabelGenerator): 18 | def __init__(self, centers:list, soft_label=False, smooth_band=0) -> None: 19 | super().__init__() 20 | self.centers = centers 21 | self.soft_label = soft_label 22 | self.smooth_band = smooth_band 23 | 24 | def __call__(self, target:np.ndarray) -> Any: 25 | if self.soft_label: 26 | return self.label_cls(self.centers, target, smooth_band=self.smooth_band) 27 | else: 28 | return self.label_cls(self.centers, target, smooth_band=0) 29 | 30 | def label_cls(self, centers:list, nums:np.ndarray, smooth_band:int): 31 | ''' 32 | centers: centers of each class, needs to be increasing, n_cls = len(centers) 33 | nums: input(in_shape,) can be arbitrary 34 | band: linear smoothing between two classes, band is the total width to be smoothed. 35 | When the inputs are outside the band (near the center or over the sides), they are hard labels, only inside the band they are soft labels. 36 | return: (... , len(centers)) where (...) 
= nums.shape 37 | ''' 38 | num_classes = len(centers) 39 | smoothed_labels = np.zeros((nums.shape + (num_classes,))) 40 | for i in range(num_classes-1): 41 | center_i = centers[i] 42 | center_j = centers[i+1] 43 | lower = 0.5*(center_i + center_j) - smooth_band/2 44 | upper = 0.5*(center_i + center_j) + smooth_band/2 45 | hard_i = np.logical_and(nums >= center_i, nums <= lower) 46 | hard_j = np.logical_and(nums < center_j, nums > upper) 47 | mask = np.logical_and(nums > lower, nums <= upper) 48 | if smooth_band > 0 and mask.any(): 49 | diff = (nums - center_i) / (center_j - center_i) 50 | smooth_i = 1 - diff 51 | smooth_j = diff 52 | smoothed_labels[..., i][mask] = smooth_i[mask] 53 | smoothed_labels[..., i+1][mask] = smooth_j[mask] 54 | smoothed_labels[..., i][hard_i] = 1 55 | smoothed_labels[..., i+1][hard_j] = 1 56 | smoothed_labels[..., 0][nums <= centers[0]] = 1 57 | smoothed_labels[..., -1][nums >= centers[-1]] = 1 58 | return smoothed_labels 59 | 60 | class LabelGenerator_origin(LabelGenerator): 61 | def __init__(self) -> None: 62 | super().__init__() 63 | 64 | def __call__(self, target) -> Any: 65 | return target[..., None] 66 | -------------------------------------------------------------------------------- /analyzer/ards_nearest_cls.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tools 3 | import os 4 | from tqdm import tqdm 5 | from tools import logger as logger 6 | from .container import DataContainer 7 | from tools.data import DynamicDataGenerator, LabelGenerator_cls, map_func, label_func_min 8 | from datasets.derived_ards_dataset import MIMICIV_ARDS_Dataset 9 | 10 | 11 | class ArdsNearest4ClsAnalyzer: 12 | def __init__(self, params:dict, container:DataContainer) -> None: 13 | self.params = params 14 | self.paths = params['paths'] 15 | self.dataset = MIMICIV_ARDS_Dataset() 16 | self.dataset.load_version(params['dataset_version']) 17 | self.model_name = self.params['analyzer_name'] 18 | self.target_idx = self.dataset.idx_dict['PF_ratio'] 19 | 20 | def predict(self, X_test:np.ndarray): 21 | ''' 22 | input: batch, n_fea, seq_len 23 | output: (test_batch, seq_len, n_cls) 24 | ''' 25 | prediction = np.zeros((X_test.shape[0], X_test.shape[2], 4)) 26 | target = X_test[:, self.target_idx, :] 27 | prediction[:,:,0] = np.logical_and(target > 0, target <= 100) 28 | prediction[:,:,1] = np.logical_and(target > 100, target <= 200) 29 | prediction[:,:,2] = np.logical_and(target > 200, target <= 300) 30 | prediction[:,:,3] = target > 300 31 | return prediction 32 | 33 | def train(self): 34 | pass 35 | 36 | def run(self): 37 | # step 1: init variables 38 | out_dir = os.path.join(self.paths['out_dir'], self.model_name) 39 | tools.reinit_dir(out_dir, build=True) 40 | # metric_2cls = tools.DichotomyMetric() 41 | metric_4cls = tools.MultiClassMetric(class_names=self.params['class_names'], out_dir=out_dir) 42 | generator = DynamicDataGenerator( 43 | window_points=self.params['window'], 44 | n_fea=len(self.dataset.total_keys), 45 | label_generator=LabelGenerator_cls( 46 | centers=self.params['centers'] 47 | ), 48 | label_func=label_func_min, 49 | target_idx=self.target_idx, 50 | limit_idx=[], 51 | forbidden_idx=[] 52 | ) 53 | # step 2: train and predict 54 | for idx, (train_index, valid_index, test_index) in enumerate(self.dataset.enumerate_kf()): 55 | result = generator(self.dataset.data[test_index, :, :], self.dataset.seqs_len[test_index]) 56 | X_test, Y_mask, Y_gt = result['data'], result['mask'], result['label'] 57 | Y_pred = 
self.predict(X_test) 58 | Y_pred = np.asarray(Y_pred) 59 | metric_4cls.add_prediction(Y_pred, Y_gt, Y_mask) # 去掉mask外的数据 60 | # metric_2cls.add_prediction(map_func(Y_pred)[..., 1].flatten()[Y_mask], map_func(Y_gt)[..., 1].flatten()[Y_mask]) 61 | 62 | metric_4cls.confusion_matrix(comment=self.model_name) 63 | # metric_2cls.plot_curve(curve_type='roc', title=f'{self.model_name} model ROC (4->2 cls)', save_path=os.path.join(out_dir, f'{self.model_name}_ROC.png')) 64 | 65 | with open(os.path.join(out_dir, 'result.txt'), 'w') as fp: 66 | print('Overall performance:', file=fp) 67 | metric_4cls.write_result(fp) -------------------------------------------------------------------------------- /analyzer/vent_nearest_cls.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tools 3 | import os 4 | from tqdm import tqdm 5 | from tools import logger as logger 6 | from .container import DataContainer 7 | from tools.data import DynamicDataGenerator, LabelGenerator_cls, label_func_max, map_func 8 | from datasets.derived_vent_dataset import MIMICIV_Vent_Dataset 9 | 10 | 11 | class VentNearest3ClsAnalyzer: 12 | def __init__(self, params:dict, container:DataContainer) -> None: 13 | self.params = params 14 | self.paths = params['paths'] 15 | self.dataset = MIMICIV_Vent_Dataset() 16 | self.dataset.load_version(params['dataset_version']) 17 | self.model_name = self.params['analyzer_name'] 18 | self.target_idx = self.dataset.idx_dict['vent_status'] 19 | 20 | def predict(self, X_test:np.ndarray): 21 | ''' 22 | input: batch, n_fea, seq_len 23 | output: (test_batch, seq_len, n_cls) 24 | ''' 25 | prediction = np.zeros((X_test.shape[0], X_test.shape[2], 3)) 26 | target = X_test[:, self.target_idx, :] 27 | prediction[:,:,0] = target <= 0.5 28 | prediction[:,:,1] = np.logical_and(target > 0.5, target <= 1.5) 29 | prediction[:,:,2] = np.logical_and(target > 1.5, target <= 2.5) 30 | return prediction 31 | 32 | def train(self): 33 | pass 34 | 35 | def run(self): 36 | # step 1: init variables 37 | out_dir = os.path.join(self.paths['out_dir'], self.model_name) 38 | tools.reinit_dir(out_dir, build=True) 39 | metric_2cls = tools.DichotomyMetric() 40 | metric_3cls = tools.MultiClassMetric(class_names=self.params['class_names'], out_dir=out_dir) 41 | generator = DynamicDataGenerator( 42 | window_points=self.params['window'], 43 | n_fea=len(self.dataset.total_keys), 44 | label_generator=LabelGenerator_cls( 45 | centers=self.params['centers'], 46 | ), 47 | label_func=label_func_max, 48 | target_idx=self.target_idx, 49 | limit_idx=[], 50 | forbidden_idx=[] 51 | ) 52 | # step 2: train and predict 53 | for idx, (train_index, valid_index, test_index) in enumerate(self.dataset.enumerate_kf()): 54 | result = generator(self.dataset.data[test_index, :, :], self.dataset.seqs_len[test_index]) 55 | X_test, Y_mask, Y_gt = result['data'], result['mask'], result['label'] 56 | Y_pred = self.predict(X_test) 57 | Y_pred = np.asarray(Y_pred) 58 | metric_3cls.add_prediction(Y_pred, Y_gt, Y_mask) # 去掉mask外的数据 59 | Y_mask = Y_mask.flatten() 60 | metric_2cls.add_prediction( 61 | map_func(Y_pred, mapping={0:0, 1:1, 2:1})[..., 1].flatten()[Y_mask], 62 | map_func(Y_gt, mapping={0:0, 1:1, 2:1})[..., 1].flatten()[Y_mask] 63 | ) 64 | 65 | metric_3cls.confusion_matrix(comment=self.model_name) 66 | metric_2cls.plot_curve(title=f'{self.model_name} model ROC (4->2 cls)', save_path=os.path.join(out_dir, f'{self.model_name}_ROC.png')) 67 | 68 | with open(os.path.join(out_dir, 'result.txt'), 'w') as 
fp: 69 | print('Overall performance:', file=fp) 70 | metric_3cls.write_result(fp) -------------------------------------------------------------------------------- /configs/mimiciv_dataset_raw.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | ## configs for mimic-iv-core 3 | 4 | data_linkage: 5 | ed: True # link mimic-iv-ED dataset 6 | hosp: True 7 | icu: True 8 | 9 | category_to_numeric: 10 | insurance: { Medicare: 1, Medicaid: 2, Other: 0, Default: 0} 11 | language: {ENGLISH: 1, Default: 0} 12 | marital_status: {WIDOWED: 1, SINGLE: 2, MARRIED: 3, DIVORCED: 4, Default: 0} 13 | race: 14 | ASIAN: 8 15 | ASIAN - ASIAN INDIAN: 20 16 | ASIAN - CHINESE: 9 17 | ASIAN - SOUTH EAST ASIAN: 18 18 | BLACK/AFRICAN: 14 19 | BLACK/AFRICAN AMERICAN: 2 20 | BLACK/CAPE VERDEAN: 11 21 | BLACK/CARIBBEAN ISLAND: 13 22 | Default: 0 23 | HISPANIC OR LATINO: 7 24 | HISPANIC/LATINO - DOMINICAN: 12 25 | HISPANIC/LATINO - GUATEMALAN: 19 26 | HISPANIC/LATINO - PUERTO RICAN: 5 27 | OTHER: 3 28 | PATIENT DECLINED TO ANSWER: 15 29 | PORTUGUESE: 17 30 | UNABLE TO OBTAIN: 16 31 | UNKNOWN: 4 32 | WHITE: 1 33 | WHITE - EASTERN EUROPEAN: 21 34 | WHITE - OTHER EUROPEAN: 6 35 | WHITE - RUSSIAN: 10 36 | careunit: 37 | Cardiac Surgery: 17 38 | Cardiac Vascular Intensive Care Unit (CVICU): 20 39 | Coronary Care Unit (CCU): 26 40 | Default: 0 41 | Discharge Lounge: 5 42 | Emergency Department: 1 43 | Emergency Department Observation: 4 44 | Hematology/Oncology: 9 45 | Hematology/Oncology Intermediate: 25 46 | Labor & Delivery: 15 47 | Med/Surg: 6 48 | Med/Surg/GYN: 12 49 | Med/Surg/Trauma: 13 50 | Medical Intensive Care Unit (MICU): 16 51 | Medical/Surgical (Gynecology): 27 52 | Medical/Surgical Intensive Care Unit (MICU/SICU): 19 53 | Medicine: 3 54 | Medicine/Cardiology: 7 55 | Neurology: 8 56 | Obstetrics (Postpartum & Antepartum): 21 57 | PACU: 14 58 | Psychiatry: 23 59 | Surgery: 28 60 | Surgery/Pancreatic/Biliary/Bariatric: 29 61 | Surgery/Trauma: 18 62 | Surgical Intensive Care Unit (SICU): 22 63 | Transplant: 11 64 | Trauma SICU (TSICU): 24 65 | Vascular: 10 66 | empty: 2 67 | 68 | remove_rule: 69 | pass1: 70 | duration_minmax: [2, 192] 71 | pass2: 72 | # gradually eliminate invalid cols/rows in several iterations 73 | # the last value in each list is an approximation of final missrate threshold 74 | max_col_missrate: [0.99] # there are >5000 features. They will cause OOM 75 | max_subject_missrate: [1.0] 76 | adm_select_strategy: default # default always select zero. Another option: random 77 | 78 | generate_table: 79 | align_target: [] 80 | default_missing_value: -1 81 | calculate_bin: avg # how to calculate multiple data points in a discrete bin: avg=use average value, latest: use last value 82 | delta_t_hour: 4.0 83 | 84 | k-fold: 5 85 | validation_proportion: 0.15 # use 0.15*80% samples for validation 86 | compress_cache: True # compress cache file to speed up loading 87 | compress_suffix: .gz # zip: fast compression but slow loading. 
xz: slow compression but fast loading (smaller file size) 88 | 89 | version: 90 | raw@version: 91 | fill_missvalue: none # do not fill miss value in feature explore status 92 | feature_limit: [] 93 | forbidden_feas: [] 94 | 95 | value_clip: {} 96 | 97 | -------------------------------------------------------------------------------- /analyzer/vent_logistic_regression.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tools 3 | import os 4 | from tqdm import tqdm 5 | from tools import logger as logger 6 | from .container import DataContainer 7 | from tools.data import SliceDataGenerator, LabelGenerator_cls, cal_label_weight, vent_label_func, Normalization 8 | from sklearn.linear_model import LogisticRegression 9 | from os.path import join as osjoin 10 | from datasets.derived_vent_dataset import MIMICIV_Vent_Dataset 11 | 12 | class VentLogisticRegAnalyzer: 13 | def __init__(self, params:dict, container:DataContainer) -> None: 14 | self.params = params 15 | self.paths = params['paths'] 16 | self.dataset = MIMICIV_Vent_Dataset() 17 | self.dataset.load_version(params['dataset_version']) 18 | self.model_name = self.params['analyzer_name'] 19 | self.target_idx = self.dataset.idx_dict['vent_status'] 20 | 21 | def run(self): 22 | # step 1: init variables 23 | out_dir = os.path.join(self.paths['out_dir'], self.model_name) 24 | tools.reinit_dir(out_dir, build=True) 25 | # metric_2cls = tools.DichotomyMetric() 26 | metric_2cls = tools.DichotomyMetric() 27 | generator = SliceDataGenerator( 28 | window_points=self.params['window'], 29 | n_fea=len(self.dataset.total_keys), 30 | label_generator=LabelGenerator_cls( 31 | centers=self.params['centers'] 32 | ), 33 | norm=Normalization(self.dataset.norm_dict, self.dataset.total_keys), 34 | label_func=vent_label_func, 35 | target_idx=self.target_idx, 36 | limit_idx=[], 37 | forbidden_idx=[self.dataset.idx_dict[id] for id in ['vent_status']] 38 | ) 39 | print(f'Available features: {[self.dataset.total_keys[idx] for idx in generator.avail_idx]}') 40 | # step 2: train and predict 41 | for fold_idx, (train_index, valid_index, test_index) in enumerate(self.dataset.enumerate_kf()): 42 | reg_train_index = np.concatenate([train_index, valid_index], axis=0) # lineat regression do not need validation 43 | train_result = generator(self.dataset.data[reg_train_index, :, :], self.dataset.seqs_len[reg_train_index]) 44 | X_train, Y_train = train_result['data'], train_result['label'] 45 | 46 | test_result = generator(self.dataset.data[test_index, :, :], self.dataset.seqs_len[test_index]) 47 | X_test, Y_test = test_result['data'], test_result['label'] 48 | 49 | class_weight = cal_label_weight(len(self.params['centers']), Y_train) 50 | class_weight = {idx:class_weight[idx] for idx in range(len(class_weight))} 51 | logger.info(f'class weight: {class_weight}') 52 | 53 | model = LogisticRegression(max_iter=self.params['max_iter'], multi_class='multinomial', class_weight=class_weight) 54 | model.fit(X_train, Y_train[:, 1]) 55 | 56 | Y_test_pred = model.predict_proba(X_test) 57 | metric_2cls.add_prediction(Y_test_pred[:, 1], Y_test[:, 1]) 58 | 59 | metric_2cls.plot_curve(curve_type='roc', title=f'ROC for ventilation', save_path=osjoin(out_dir, f'vent_roc.png')) 60 | metric_2cls.plot_curve(curve_type='prc', title=f'PRC for ventilation', save_path=osjoin(out_dir, f'vent_prc.png')) 61 | 62 | with open(os.path.join(out_dir, 'result.txt'), 'w') as fp: 63 | print('Overall performance:', file=fp) 64 | 
metric_2cls.write_result(fp) -------------------------------------------------------------------------------- /configs/mimiciv_dataset_ards.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | data_linkage: 3 | ed: False # link mimic-iv-ED dataset 4 | hosp: False 5 | icu: True 6 | 7 | category_to_numeric: 8 | insurance: { Medicare: 1, Medicaid: 2, Other: 0, Default: 0} 9 | language: {ENGLISH: 1, Default: 0} 10 | marital_status: {WIDOWED: 1, SINGLE: 2, MARRIED: 3, DIVORCED: 4, Default: 0} 11 | race: 12 | ASIAN: 8 13 | ASIAN - ASIAN INDIAN: 20 14 | ASIAN - CHINESE: 9 15 | ASIAN - SOUTH EAST ASIAN: 18 16 | BLACK/AFRICAN: 14 17 | BLACK/AFRICAN AMERICAN: 2 18 | BLACK/CAPE VERDEAN: 11 19 | BLACK/CARIBBEAN ISLAND: 13 20 | Default: 0 21 | HISPANIC OR LATINO: 7 22 | HISPANIC/LATINO - DOMINICAN: 12 23 | HISPANIC/LATINO - GUATEMALAN: 19 24 | HISPANIC/LATINO - PUERTO RICAN: 5 25 | OTHER: 3 26 | PATIENT DECLINED TO ANSWER: 15 27 | PORTUGUESE: 17 28 | UNABLE TO OBTAIN: 16 29 | UNKNOWN: 4 30 | WHITE: 1 31 | WHITE - EASTERN EUROPEAN: 21 32 | WHITE - OTHER EUROPEAN: 6 33 | WHITE - RUSSIAN: 10 34 | careunit: 35 | Cardiac Surgery: 17 36 | Cardiac Vascular Intensive Care Unit (CVICU): 20 37 | Coronary Care Unit (CCU): 26 38 | Default: 0 39 | Discharge Lounge: 5 40 | Emergency Department: 1 41 | Emergency Department Observation: 4 42 | Hematology/Oncology: 9 43 | Hematology/Oncology Intermediate: 25 44 | Labor & Delivery: 15 45 | Med/Surg: 6 46 | Med/Surg/GYN: 12 47 | Med/Surg/Trauma: 13 48 | Medical Intensive Care Unit (MICU): 16 49 | Medical/Surgical (Gynecology): 27 50 | Medical/Surgical Intensive Care Unit (MICU/SICU): 19 51 | Medicine: 3 52 | Medicine/Cardiology: 7 53 | Neurology: 8 54 | Obstetrics (Postpartum & Antepartum): 21 55 | PACU: 14 56 | Psychiatry: 23 57 | Surgery: 28 58 | Surgery/Pancreatic/Biliary/Bariatric: 29 59 | Surgery/Trauma: 18 60 | Surgical Intensive Care Unit (SICU): 22 61 | Transplant: 11 62 | Trauma SICU (TSICU): 24 63 | Vascular: 10 64 | empty: 2 65 | 66 | remove_rule: 67 | pass1: 68 | target_id: ['220224', '223835'] 69 | duration_minmax: [2, 96] 70 | check_sepsis_time: [-30, 10] 71 | pass2: 72 | # gradually eliminate invalid cols/rows in several iterations 73 | # the last value in each list is an approximation of final missrate threshold 74 | max_col_missrate: [0.9, 0.8, 0.5] 75 | max_subject_missrate: [0.9, 0.8, 0.5] 76 | adm_select_strategy: default # default always select zero. Another option: random 77 | 78 | generate_table: 79 | align_target: ['220224', '223835'] 80 | default_missing_value: -1 81 | calculate_bin: avg # how to calculate multiple data points in a discrete bin: avg=use average value, latest: use last value 82 | delta_t_hour: 0.5 83 | 84 | k-fold: 5 85 | validation_proportion: 0.15 # use 0.15*80% samples for validation 86 | compress_cache: True # compress cache file to speed up loading 87 | compress_suffix: .gz # zip: fast compression but slow loading. 
xz: slow compression but fast loading (smaller file size) 88 | 89 | version: 90 | ards@origin: 91 | fill_missvalue: none # do not fill miss value in feature explore status 92 | feature_limit: [] 93 | forbidden_feas: [] 94 | ards@filled: 95 | fill_missvalue: avg 96 | feature_limit: [] 97 | forbidden_feas: [Weight, BMI, Height (Inches), '225677', '220545', '226730', '226531', Weight (Lbs), '227467', '220339', '227443', '220587', 'dod'] 98 | value_clip: 99 | Inspired O2 Fraction: 100 | min: 21 101 | max: 100 102 | Arterial O2 pressure: 103 | min: 0 104 | max: 500 105 | 106 | -------------------------------------------------------------------------------- /analyzer/ards_logistic_regression.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tools 3 | import os 4 | from tqdm import tqdm 5 | from tools import logger as logger 6 | from .container import DataContainer 7 | from tools.data import SliceDataGenerator, LabelGenerator_cls, cal_label_weight, label_func_min, Normalization, map_func 8 | from sklearn.linear_model import LogisticRegression 9 | from os.path import join as osjoin 10 | from datasets.derived_ards_dataset import MIMICIV_ARDS_Dataset 11 | 12 | class ArdsLogisticRegAnalyzer: 13 | def __init__(self, params:dict, container:DataContainer) -> None: 14 | self.params = params 15 | self.paths = params['paths'] 16 | self.dataset = MIMICIV_ARDS_Dataset() 17 | self.dataset.load_version(params['dataset_version']) 18 | self.model_name = self.params['analyzer_name'] 19 | self.target_idx = self.dataset.idx_dict['PF_ratio'] 20 | 21 | def run(self): 22 | # step 1: init variables 23 | out_dir = os.path.join(self.paths['out_dir'], self.model_name) 24 | tools.reinit_dir(out_dir, build=True) 25 | # metric_2cls = tools.DichotomyMetric() 26 | metric_4cls = tools.MultiClassMetric(class_names=self.params['class_names'], out_dir=out_dir) 27 | metric_2cls = [tools.DichotomyMetric() for _ in range(4)] 28 | generator = SliceDataGenerator( 29 | window_points=self.params['window'], 30 | n_fea=len(self.dataset.total_keys), 31 | label_generator=LabelGenerator_cls( 32 | centers=self.params['centers'] 33 | ), 34 | norm=Normalization(self.dataset.norm_dict, self.dataset.total_keys), 35 | label_func=label_func_min, 36 | target_idx=self.target_idx, 37 | limit_idx=[self.dataset.fea_idx(id) for id in self.params['limit_feas']], 38 | forbidden_idx=[self.dataset.fea_idx(id) for id in self.params['forbidden_feas']] 39 | ) 40 | feature_names = [self.dataset.fea_label(idx) for idx in generator.avail_idx] 41 | print(f'Available features: {feature_names}') 42 | # step 2: train and predict 43 | for fold_idx, (train_index, valid_index, test_index) in enumerate(self.dataset.enumerate_kf()): 44 | reg_train_index = np.concatenate([train_index, valid_index], axis=0) # lineat regression do not need validation 45 | train_result = generator(self.dataset.data[reg_train_index, :, :], self.dataset.seqs_len[reg_train_index]) 46 | X_train, Y_train = train_result['data'], train_result['label'] 47 | 48 | test_result = generator(self.dataset.data[test_index, :, :], self.dataset.seqs_len[test_index]) 49 | X_test, Y_test = test_result['data'], test_result['label'] 50 | 51 | class_weight = cal_label_weight(len(self.params['centers']), Y_train) 52 | class_weight = {idx:class_weight[idx] for idx in range(len(class_weight))} 53 | logger.info(f'class weight: {class_weight}') 54 | 55 | model = LogisticRegression(max_iter=self.params['max_iter'], multi_class='multinomial', 
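# class_weight is the inverse-frequency dict computed above by cal_label_weight (class index -> normalized weight)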
class_weight=class_weight) 56 | model.fit(X_train, np.argmax(Y_train, axis=-1)) 57 | 58 | Y_test_pred = model.predict_proba(X_test) 59 | metric_4cls.add_prediction(Y_test_pred, Y_test) # 去掉mask外的数据 60 | for idx, map_dict in zip([0,1,2,3], [{0:0,1:1,2:1,3:1}, {0:0,1:1,2:0,3:0}, {0:0,1:0,2:1,3:0}, {0:0,1:0,2:0,3:1}]): # TODO 这里写错了 61 | metric_2cls[idx].add_prediction(map_func(Y_test_pred, map_dict)[:, 1], map_func(Y_test, map_dict)[:, 1]) 62 | 63 | metric_4cls.confusion_matrix(comment=self.model_name) 64 | for idx in range(4): 65 | metric_2cls[idx].plot_curve(curve_type='roc', title=f'ROC for {self.params["class_names"][idx]}', save_path=osjoin(out_dir, f'roc_cls_{idx}.png')) 66 | 67 | with open(os.path.join(out_dir, 'result.txt'), 'w') as fp: 68 | print('Overall performance:', file=fp) 69 | metric_4cls.write_result(fp) -------------------------------------------------------------------------------- /tools/module_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from matplotlib.colors import to_rgb 4 | 5 | 6 | class SummaryWriter: 7 | '''A simple version of summary writer that mimics tensorboard''' 8 | def __init__(self) -> None: 9 | self.clear() 10 | 11 | def add_scalar(self, tag, value, global_step): 12 | self.data[tag] = self.data[tag] + [(value, global_step)] if tag in self.data else [(value, global_step)] 13 | 14 | def clear(self): 15 | self.data = {} 16 | 17 | def plot(self, tags:[list, str], k_fold=False, log_y=False, title='Title', out_path:str=None): 18 | plt.figure(figsize=(12, 8)) 19 | plt.title(title) 20 | tags:list = [tags] if isinstance(tags, str) else tags 21 | for tag in tags: 22 | tag_data = np.asarray(self.data[tag]) 23 | color = f'C{tags.index(tag)}' if not k_fold else 'grey' 24 | alpha = 1.0 if not k_fold else 0.5 25 | plt.plot(tag_data[:, 1], tag_data[:, 0], '-o', color=color, alpha=alpha) 26 | if k_fold: # tags are different folds. 
Plot mean and std bar 27 | ax = plt.gca() 28 | data = np.asarray([self.data[tag] for tag in tags]) 29 | x_ticks = data.mean(axis=0)[:, 1] 30 | mean_y = data.mean(axis=0)[:, 0] # (n_steps,) 31 | std_data = data[:, :, 0].std(axis=0) # (n_steps,) 32 | ax.fill_between(x_ticks, mean_y + std_data * 0.5, mean_y - std_data * 0.5, alpha=0.5, linewidth=0) 33 | plt.plot(x_ticks, mean_y, '-o', color='C0') 34 | if log_y: 35 | plt.yscale('log') 36 | if out_path is not None: 37 | plt.savefig(out_path) 38 | plt.close() 39 | 40 | 41 | def plot_stack_proportion(data:dict[str, tuple], out_path=None): 42 | plt.figure(figsize=(15, 8)) 43 | 44 | names = list(data.keys()) 45 | style = [to_rgb(f'C{idx}') for idx in range(10)] 46 | plt.barh(names, [0 for _ in names]) 47 | idx = 0 48 | for k_idx, (key, (x, label)) in enumerate(data.items()): 49 | x_sum = 0 50 | for idx in range(len(x)): 51 | color = np.asarray(style[k_idx % 10]) 52 | color = np.clip(color * (1 - idx/len(x)) + 1.0*(idx/len(x)), 0, 0.95) 53 | plt.barh([key], x[idx], left=x_sum, color=tuple(color)) 54 | label_wid = len(label[idx])*0.005 55 | if x[idx] > label_wid: 56 | plt.annotate(label[idx], (x_sum + x[idx]*0.5 - label_wid*0.5, k_idx), fontsize=10) 57 | x_sum += x[idx] 58 | 59 | plt.xlim(left=0, right=1) 60 | plt.savefig(out_path) 61 | 62 | if __name__ == '__main__': 63 | wrt = SummaryWriter() 64 | for fold in range(5): 65 | for idx in range(10): 66 | wrt.add_scalar(f'tr_{fold}', np.random.rand(), idx) 67 | wrt.plot([f'tr_{fold}' for fold in range(5)], k_fold=True, log_y=False, title='test', out_path='test_plot/test_writer.png') 68 | 69 | 70 | def interp(fx:np.ndarray, fy:np.ndarray, x_start:float, interval:float, n_bins:int, missing=-1, fill_bin=['avg', 'latest']): 71 | # fx, (N,), N >= 1, sample time for each data point (irrgular) 72 | # fy: same size as fx, sample value for each data point 73 | # x: dim=1 74 | assert(fx.shape[0] == fy.shape[0] and len(fx.shape) == len(fy.shape) and len(fx.shape) == 1 and fx.shape[0] >= 1) 75 | assert(interval > 0 and n_bins > 0) 76 | assert(fill_bin in ['avg', 'latest']) 77 | result = np.ones((n_bins)) * missing 78 | 79 | for idx in range(n_bins): 80 | t_bin_start = x_start + (idx - 1) * interval 81 | t_bin_end = x_start + idx * interval 82 | valid_mask = np.logical_and(fx > t_bin_start, fx <= t_bin_end) # (start, end] 83 | if np.any(valid_mask): # have at least one point 84 | if fill_bin == 'avg': 85 | result[idx] = np.mean(fy[valid_mask]) 86 | elif fill_bin == 'latest': 87 | result[idx] = fy[valid_mask][-1] 88 | else: 89 | assert(0) 90 | else: # no point in current bin 91 | if idx == 0: 92 | result[idx] = missing 93 | else: 94 | result[idx] = result[idx-1] # history is always available 95 | return result -------------------------------------------------------------------------------- /configs/mimiciv_dataset_vent.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | ## configs for mimic-iv-core 3 | 4 | data_linkage: 5 | ed: True # link mimic-iv-ED dataset 6 | hosp: True 7 | icu: True 8 | 9 | category_to_numeric: 10 | insurance: { Medicare: 1, Medicaid: 2, Other: 0, Default: 0} 11 | language: {ENGLISH: 1, Default: 0} 12 | marital_status: {WIDOWED: 1, SINGLE: 2, MARRIED: 3, DIVORCED: 4, Default: 0} 13 | race: 14 | ASIAN: 8 15 | ASIAN - ASIAN INDIAN: 20 16 | ASIAN - CHINESE: 9 17 | ASIAN - SOUTH EAST ASIAN: 18 18 | BLACK/AFRICAN: 14 19 | BLACK/AFRICAN AMERICAN: 2 20 | BLACK/CAPE VERDEAN: 11 21 | BLACK/CARIBBEAN ISLAND: 13 22 | Default: 0 23 | HISPANIC OR LATINO: 7 24 | 
HISPANIC/LATINO - DOMINICAN: 12 25 | HISPANIC/LATINO - GUATEMALAN: 19 26 | HISPANIC/LATINO - PUERTO RICAN: 5 27 | OTHER: 3 28 | PATIENT DECLINED TO ANSWER: 15 29 | PORTUGUESE: 17 30 | UNABLE TO OBTAIN: 16 31 | UNKNOWN: 4 32 | WHITE: 1 33 | WHITE - EASTERN EUROPEAN: 21 34 | WHITE - OTHER EUROPEAN: 6 35 | WHITE - RUSSIAN: 10 36 | careunit: 37 | Cardiac Surgery: 17 38 | Cardiac Vascular Intensive Care Unit (CVICU): 20 39 | Coronary Care Unit (CCU): 26 40 | Default: 0 41 | Discharge Lounge: 5 42 | Emergency Department: 1 43 | Emergency Department Observation: 4 44 | Hematology/Oncology: 9 45 | Hematology/Oncology Intermediate: 25 46 | Labor & Delivery: 15 47 | Med/Surg: 6 48 | Med/Surg/GYN: 12 49 | Med/Surg/Trauma: 13 50 | Medical Intensive Care Unit (MICU): 16 51 | Medical/Surgical (Gynecology): 27 52 | Medical/Surgical Intensive Care Unit (MICU/SICU): 19 53 | Medicine: 3 54 | Medicine/Cardiology: 7 55 | Neurology: 8 56 | Obstetrics (Postpartum & Antepartum): 21 57 | PACU: 14 58 | Psychiatry: 23 59 | Surgery: 28 60 | Surgery/Pancreatic/Biliary/Bariatric: 29 61 | Surgery/Trauma: 18 62 | Surgical Intensive Care Unit (SICU): 22 63 | Transplant: 11 64 | Trauma SICU (TSICU): 24 65 | Vascular: 10 66 | empty: 2 67 | 68 | 69 | remove_rule: 70 | pass1: 71 | duration_minmax: [4, 96] 72 | pass2: 73 | # gradually eliminate invalid cols/rows in several iterations 74 | # the last value in each list is an approximation of final missrate threshold 75 | max_col_missrate: [0.95, 0.95, 0.90] 76 | max_subject_missrate: [0.95, 0.9, 0.5] 77 | adm_select_strategy: default # default always select the first available admission. Another option: random 78 | 79 | generate_table: 80 | align_target: [] # no target: pipeline will use the start time of nearest valid feature 81 | default_missing_value: -1 82 | calculate_bin: avg # how to calculate multiple data points in a discrete bin: avg=use average value, latest: use last value 83 | delta_t_hour: 1.0 84 | 85 | k-fold: 5 86 | validation_proportion: 0.15 # use X*80% samples for validation 87 | compress_cache: True # compress cache file to speed up loading, it will slow down cache dumping 88 | compress_suffix: .gz # zip: fast compression but slow loading. 
xz: slow compression but fast loading (smaller file size) 89 | 90 | 91 | version: 92 | vent@origin: 93 | fill_missvalue: none 94 | feature_limit: [] 95 | forbidden_feas: [ 96 | Weight, BMI, Height (Inches), Weight (Lbs), dod, Specific Gravity, 97 | ventilation_start, ventilation_end, ventilation_num, hosp_expire 98 | ] 99 | vent@filled: 100 | fill_missvalue: avg 101 | feature_limit: [] 102 | forbidden_feas: [ 103 | Weight, BMI, Height (Inches), Weight (Lbs), dod, Specific Gravity, 104 | ventilation_start, ventilation_end, ventilation_num, hosp_expire 105 | ] 106 | 107 | value_clip: {} 108 | 109 | ## configs for mimic-iv-vent 110 | 111 | # 0: no ventilation, 1: non-invasive, 2:invasive 112 | ventilation_to_numeric: { 113 | SupplementalOxygen: 1, 114 | HFNC: 1, 115 | NonInvasiveVent: 1, 116 | InvasiveVent: 2, 117 | Tracheostomy: 2 118 | } 119 | 120 | -------------------------------------------------------------------------------- /analyzer/ards_catboost_dynamic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tools 3 | import os 4 | from os.path import join as osjoin 5 | from tqdm import tqdm 6 | from tools import logger as logger 7 | from .container import DataContainer 8 | from tools.data import SliceDataGenerator, LabelGenerator_cls, cal_label_weight, label_func_min, map_func 9 | from catboost import Pool, CatBoostClassifier 10 | from tools.feature_importance import TreeFeatureImportance 11 | from datasets.derived_ards_dataset import MIMICIV_ARDS_Dataset 12 | 13 | class ARDSCatboostAnalyzer: 14 | def __init__(self, params:dict, container:DataContainer) -> None: 15 | self.params = params 16 | self.paths = params['paths'] 17 | self.dataset = MIMICIV_ARDS_Dataset() 18 | self.dataset.load_version(params['dataset_version']) 19 | self.model_name = self.params['analyzer_name'] 20 | self.target_idx = self.dataset.idx_dict['PF_ratio'] 21 | 22 | def run(self): 23 | # step 1: init variables 24 | out_dir = os.path.join(self.paths['out_dir'], self.model_name) 25 | tools.reinit_dir(out_dir, build=True) 26 | # metric_2cls = tools.DichotomyMetric() 27 | metric_4cls = tools.MultiClassMetric(class_names=self.params['class_names'], out_dir=out_dir) 28 | metric_2cls = [tools.DichotomyMetric() for _ in range(4)] 29 | generator = SliceDataGenerator( 30 | window_points=self.params['window'], 31 | n_fea=len(self.dataset.total_keys), 32 | label_generator=LabelGenerator_cls( 33 | centers=self.params['centers'] 34 | ), 35 | label_func=label_func_min, 36 | target_idx=self.target_idx, 37 | limit_idx=[self.dataset.fea_idx(id) for id in self.params['limit_feas']], 38 | forbidden_idx=[self.dataset.fea_idx(id) for id in self.params['forbidden_feas']] 39 | ) 40 | feature_names = [self.dataset.fea_label(idx) for idx in generator.avail_idx] 41 | print(f'Available features: {feature_names}') 42 | # step 2: train and predict 43 | for fold_idx, (train_index, valid_index, test_index) in enumerate(self.dataset.enumerate_kf()): 44 | train_result = generator(self.dataset.data[train_index, :, :], self.dataset.seqs_len[train_index]) 45 | X_train, Y_train = train_result['data'], train_result['label'] 46 | 47 | valid_result = generator(self.dataset.data[valid_index, :, :], self.dataset.seqs_len[valid_index]) 48 | X_valid, Y_valid = valid_result['data'], valid_result['label'] 49 | 50 | test_result = generator(self.dataset.data[test_index, :, :], self.dataset.seqs_len[test_index]) 51 | X_test, Y_test = test_result['data'], test_result['label'] 52 | 53 | model = 
CatBoostClassifier( 54 | iterations=self.params['iterations'], 55 | learning_rate=self.params['learning_rate'], 56 | loss_function=self.params['loss_function'], 57 | class_weights=cal_label_weight(len(self.params['centers']), Y_train), 58 | use_best_model=True 59 | ) 60 | pool_train = Pool(X_train, Y_train.argmax(axis=-1)) 61 | pool_valid = Pool(X_valid, Y_valid.argmax(axis=-1)) 62 | 63 | model.fit(pool_train, eval_set=pool_valid) 64 | 65 | Y_test_pred = model.predict_proba(X_test) 66 | metric_4cls.add_prediction(Y_test_pred, Y_test) # 去掉mask外的数据 67 | for idx, map_dict in zip([0,1,2,3], [{0:0,1:1,2:1,3:1}, {0:0,1:1,2:0,3:0}, {0:0,1:0,2:1,3:0}, {0:0,1:0,2:0,3:1}]): # TODO 这里写错了 68 | metric_2cls[idx].add_prediction(map_func(Y_test_pred, map_dict)[:, 1], map_func(Y_test, map_dict)[:, 1]) 69 | if fold_idx == 0: 70 | explorer = TreeFeatureImportance(map_func=lambda x:x[:, :, 1], fea_names=feature_names, missvalue=-1, n_approx=2000) 71 | explorer.add_record(model, valid_X=X_valid) 72 | explorer.plot_beeswarm(max_disp=10, plot_path=osjoin(out_dir, f'importance.png')) 73 | 74 | metric_4cls.confusion_matrix(comment=self.model_name) 75 | for idx in range(4): 76 | metric_2cls[idx].plot_curve(curve_type='roc', title=f'ROC for {self.params["class_names"][idx]}', save_path=osjoin(out_dir, f'roc_cls_{idx}.png')) 77 | 78 | with open(os.path.join(out_dir, 'result.txt'), 'w') as fp: 79 | print('Overall performance:', file=fp) 80 | metric_4cls.write_result(fp) -------------------------------------------------------------------------------- /tools/generic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import yaml 3 | import datetime 4 | from sklearn import random as sk_random 5 | import pandas as pd 6 | import pickle 7 | import os, sys 8 | import missingno as msno 9 | import hashlib 10 | from .logging import logger 11 | 12 | def reinit_dir(write_dir_path=None, build=True): 13 | '''清除并且重建一个文件夹和其中所有的内容''' 14 | if write_dir_path is not None: 15 | if os.path.exists(write_dir_path): 16 | for name in os.listdir(write_dir_path): 17 | p = os.path.join(write_dir_path, name) 18 | if os.path.isdir(p): 19 | reinit_dir(p, build=False) 20 | os.rmdir(p) 21 | elif os.path.isfile(p): 22 | os.remove(p) 23 | if build: 24 | os.makedirs(write_dir_path, exist_ok=True) 25 | 26 | ''' 27 | 设置matplotlib显示中文, 对于pandas_profile不可用 28 | ''' 29 | def set_chinese_font(): 30 | logger.info("Set Chinese Font in Matplotlib") 31 | from matplotlib import pyplot as plt 32 | plt.rcParams['font.family'] = ['Arial Unicode MS'] 33 | 34 | def save_pkl(obj, path): 35 | with open(path, 'wb') as f: 36 | pickle.dump(obj, f) 37 | 38 | def load_pkl(path): 39 | with open(path, 'rb') as f: 40 | result = pickle.load(f) 41 | return result 42 | 43 | # 清空文件 44 | def clear_file(name): 45 | with open(name, 'w+'): 46 | pass 47 | 48 | def set_sk_random_seed(seed:int=100): 49 | sk_random.seed(seed) 50 | 51 | def remove_slash(name:str): 52 | return name.replace('/','%') 53 | 54 | 55 | def cal_file_md5(filename:str) -> str: 56 | '''输入文件名, 返回文件的MD5字符串''' 57 | with open(filename, 'rb') as fp: 58 | data = fp.read() 59 | file_md5= hashlib.md5(data).hexdigest() 60 | return file_md5 61 | 62 | 63 | def assert_no_na(dataset:pd.DataFrame): 64 | try: 65 | assert(not np.any(dataset.isna().to_numpy())) 66 | except Exception as e: 67 | na_mat = dataset.isna() 68 | for col in dataset.columns: 69 | if np.any(na_mat[col].to_numpy()): 70 | logger.error(f'assert_na: NA in feature:{col}') 71 | assert(0) 72 | 73 | class 
Config: 74 | ''' 75 | 加载配置表 76 | cache_path: 自动配置 77 | manual_path: 手动配置 78 | ''' 79 | def __init__(self, conf_path) -> None: 80 | self.conf_path = conf_path 81 | with open(conf_path, 'r', encoding='utf-8') as fp: 82 | self.configs = yaml.load(fp, Loader=yaml.SafeLoader) 83 | 84 | def __getitem__(self, idx): 85 | return self.configs.copy()[idx] 86 | 87 | def dump(self): 88 | with open(self.conf_path, 'w', encoding='utf-8') as fp: 89 | yaml.dump(self.configs, fp) 90 | 91 | class TimeConverter: 92 | ''' 93 | 将一段时间字符串转化为时间戳 94 | ''' 95 | def __init__(self, format:str=None, out_unit=['day','hour','minute']) -> None: 96 | ''' 97 | format: 年%Y 月%m 日%d 小时%H 分钟%M 秒%S" 98 | ''' 99 | self.format = format 100 | coeff = 1 101 | if out_unit == 'day': 102 | coeff *= 60*60*24 103 | elif out_unit == 'hour': 104 | coeff *= 60*60 105 | elif out_unit == 'minute': 106 | coeff *= 60 107 | self.coeff = coeff 108 | 109 | def __call__(self, in_str:str) -> float: 110 | dt = datetime.datetime.strptime(in_str, self.format) 111 | return dt.timestamp() / self.coeff 112 | 113 | def make_mask(m_shape, seq_lens) -> np.ndarray: 114 | ''' 115 | m_shape: (batch, seq_lens) 或者 (batch, n_fea, seq_lens) 116 | mask: (batch, seq_lens) or (batch, n_fea, seq_lens) 取决于m_shape 117 | ''' 118 | mask = np.zeros(m_shape, dtype=bool) 119 | if len(m_shape) == 2: 120 | for idx in range(m_shape[0]): 121 | mask[idx, :seq_lens[idx]] = True 122 | elif len(m_shape) == 3: 123 | for idx in range(m_shape[0]): 124 | mask[idx, :, :seq_lens[idx]] = True 125 | else: 126 | assert(0) 127 | return mask 128 | 129 | 130 | def find_best(path_dir, prefix='best'): 131 | '''寻找不含子文件的文件夹中最新文件的full path''' 132 | # get a list of all files in the directory 133 | all_files = os.listdir(path_dir) 134 | for file in all_files: 135 | if str(file).startswith(prefix): 136 | # get the full path of the latest file 137 | best_file_path = os.path.join(path_dir, file) 138 | return best_file_path 139 | assert(0) 140 | return None 141 | 142 | # set_chinese_font() -------------------------------------------------------------------------------- /configs/analyzers.yml: -------------------------------------------------------------------------------- 1 | --- 2 | cross_validation: 3 | alignment_dict: 4 | gender: 性别 5 | PH (Arterial): dX_PH 6 | age: 年龄 7 | Respiratory Rate (Total): dX_tX_呼吸频率 8 | O2 saturation pulseoxymetry: dX_tX_SPO2 9 | Arterial O2 pressure: dX_PaO2(mmHg) 10 | Inspired O2 Fraction: dX_FiO2(%) 11 | PF_ratio: dX_PaO2(mmHg) / FiO2(%) 12 | window: 48 13 | smoothing_band: 50 14 | loss_function: MultiClass 15 | iterations: 300 16 | depth: 5 17 | learning_rate: 0.03 18 | centers: [50, 150, 250, 350] 19 | class_names: [Severe, Moderate, Mild, No_ARDS] 20 | limit_feas: [gender, '223830', '224690', age, '220277', '220224', '223835', PF_ratio] 21 | forbidden_feas: [] 22 | slice_len: 192 23 | 24 | dataset_report: 25 | dataset_name: ards # ['ards', 'raw', 'vent'] 26 | dataset_version: ards@origin 27 | basic: true 28 | dynamic_dist: true 29 | static_dist: true 30 | 31 | ards_feature_explore: 32 | dataset_version: ards@origin 33 | coverrate: 34 | enabled: true 35 | class_names: [Severe, Moderate, Mild, No_ARDS] 36 | window: 16 37 | centers: [50, 150, 250, 350] 38 | plot_admission_dist: true 39 | plot_chart_vis: 40 | enabled: true 41 | collect_list: [transfer, admission] 42 | plot_samples: 43 | enabled: true 44 | n_sample: 10 45 | features: 46 | - Arterial O2 pressure 47 | - Inspired O2 Fraction 48 | - '220045' 49 | plot_time_series: 50 | enabled: false 51 | names: 52 | - PF_ratio 
53 | n_sample: 400 54 | n_per_plots: 40 55 | correlation: 56 | enabled: false 57 | target: PF_ratio 58 | miss_mat: true 59 | first_ards_time: false 60 | feature_count: true 61 | 62 | raw_feature_explore: 63 | dataset_version: raw@version 64 | plot_admission_dist: true 65 | plot_chart_vis: 66 | enabled: true 67 | collect_list: [transfer, admission] 68 | plot_samples: 69 | enabled: false 70 | n_sample: 50 71 | features: 72 | - Arterial O2 pressure 73 | - Inspired O2 Fraction 74 | abnormal_dist: 75 | enabled: true 76 | value_limitation: 77 | Inspired O2 Fraction: {min: 21, max: 100} 78 | Arterial O2 pressure: {min: 0, max: 500} 79 | Respiratory Rate: {min: 1, max: 120} 80 | Arterial O2 Saturation: {min: 50, max: 100} 81 | O2 saturation pulseoxymetry: {min: 50, max: 100} 82 | Daily Weight: {min: 30, max: 200} 83 | miss_mat: true 84 | feature_count: true 85 | correlation: false 86 | 87 | vent_feature_explore: 88 | dataset_version: vent@origin 89 | correlation: 90 | enabled: true 91 | target: vent_status 92 | miss_mat: true 93 | vent_statistics: true 94 | vent_sample: 95 | enabled: true 96 | n_plot: 10 97 | 98 | ards_catboost_dynamic: 99 | dataset_version: ards@origin 100 | window: 16 101 | soft_label: false 102 | smoothing_band: 50 103 | loss_function: MultiClass 104 | iterations: 600 105 | depth: 5 106 | learning_rate: 0.05 107 | centers: [50, 150, 250, 350] 108 | class_names: [Severe, Moderate, Mild, No_ARDS] 109 | limit_feas: [] 110 | forbidden_feas: ['220224', '223835', 'PF_ratio'] 111 | 112 | ards_lstm: 113 | dataset_version: ards@filled 114 | window: 16 115 | device: 'cuda:1' 116 | centers: [50, 150, 250, 350] 117 | class_names: [Severe, Moderate, Mild, No_ARDS] 118 | hidden_size: 128 119 | batch_size: 2048 120 | epoch: 100 121 | lr: 0.001 122 | limit_feas: [] 123 | forbidden_feas: ['220224', '223835', 'PF_ratio'] 124 | 125 | ards_nearest_4cls: 126 | dataset_version: ards@origin 127 | window: 16 128 | soft_label: false 129 | centers: [50, 150, 250, 350] 130 | class_names: [Severe, Moderate, Mild, No_ARDS] 131 | smoothing_band: 50 132 | 133 | ards_logistic_reg: 134 | dataset_version: ards@filled 135 | window: 16 136 | centers: [50, 150, 250, 350] 137 | class_names: [Severe, Moderate, Mild, No_ARDS] 138 | max_iter: 1000 139 | limit_feas: [] 140 | forbidden_feas: ['220224', '223835', 'PF_ratio'] 141 | 142 | vent_nearest_2cls: 143 | dataset_version: vent@origin 144 | window: 8 145 | centers: [0, 1] 146 | class_names: ['no_vent', 'use_vent'] 147 | 148 | vent_catboost_dynamic: 149 | dataset_version: vent@origin 150 | window: 8 151 | loss_function: MultiClass 152 | iterations: 400 153 | depth: 5 154 | learning_rate: 0.03 155 | centers: [0, 1] 156 | class_names: ['no_vent', 'use_vent'] 157 | limit_feas: [] 158 | forbidden_feas: ['vent_status'] 159 | 160 | vent_logistic_reg: 161 | dataset_version: vent@filled 162 | window: 8 163 | centers: [0, 1] 164 | class_names: ['no_vent', 'use_vent'] 165 | max_iter: 1000 166 | 167 | vent_lstm: 168 | dataset_version: vent@filled 169 | window: 8 170 | device: 'cuda:1' 171 | centers: [0, 1] 172 | class_names: ['no_vent', 'use_vent'] 173 | hidden_size: 128 174 | batch_size: 2048 175 | epoch: 100 176 | lr: 0.001 177 | -------------------------------------------------------------------------------- /analyzer/dataset_explore/dataset_report.py: -------------------------------------------------------------------------------- 1 | import tools 2 | from tools.logging import logger 3 | import matplotlib.pyplot as plt 4 | from ..container import DataContainer 5 | import 
numpy as np 6 | from tqdm import tqdm 7 | import os 8 | from os.path import join as osjoin 9 | import pandas as pd 10 | import yaml 11 | from datasets.derived_vent_dataset import MIMICIV_Vent_Dataset 12 | from datasets.derived_ards_dataset import MIMICIV_ARDS_Dataset 13 | from datasets.derived_raw_dataset import MIMICIV_Raw_Dataset 14 | 15 | class DatasetReport(): 16 | def __init__(self, params:dict, container:DataContainer) -> None: 17 | self.params = params 18 | self.paths = params['paths'] 19 | if params['dataset_name'] == 'ards': 20 | self.dataset = MIMICIV_ARDS_Dataset() 21 | elif params['dataset_name'] == 'raw': 22 | self.dataset = MIMICIV_Raw_Dataset() 23 | elif params['dataset_name'] == 'vent': 24 | self.dataset = MIMICIV_Vent_Dataset() 25 | else: 26 | logger.error('Incorrect dataset_name') 27 | assert(0) 28 | self.dataset.load_version(params['dataset_version']) 29 | self.dataset.mode('all') 30 | self.gbl_conf = container._conf 31 | self.data = self.dataset.data 32 | 33 | def run(self): 34 | out_dir = os.path.join(self.paths['out_dir'], f'report_{self.params["dataset_name"]}') 35 | tools.reinit_dir(out_dir, build=True) 36 | report_path = osjoin(out_dir, f'dataset_report_{self.params["dataset_version"]}.txt') 37 | dist_dir = os.path.join(out_dir, 'dist') 38 | dir_names = ['points', 'duration', 'frequency', 'dynamic_value', 'static_value'] 39 | for name in dir_names: 40 | os.makedirs(os.path.join(dist_dir, name), exist_ok=True) 41 | logger.info('generating dataset report') 42 | write_lines = [] 43 | if self.params['basic']: 44 | # basic statistics 45 | write_lines.append('='*10 + 'basic' + '='*10) 46 | write_lines.append(f'Version: {self.params["dataset_version"]}') 47 | write_lines.append(f'Static keys: {len(self.dataset.static_keys)}') 48 | write_lines.append(f'Dynamic keys: {len(self.dataset.dynamic_keys)}') 49 | write_lines.append(f'Subjects:{len(self.dataset)}') 50 | write_lines.append(f'Static feature: {[self.dataset.fea_label(id) for id in self.dataset.static_keys]}') 51 | write_lines.append(f'Dynamic feature: {[self.dataset.fea_label(id) for id in self.dataset.dynamic_keys]}') 52 | if self.params['dynamic_dist']: 53 | # dynamic feature explore 54 | for id in tqdm(self.dataset.dynamic_keys, 'plot dynamic dist'): 55 | fea_name = self.dataset.fea_label(id) 56 | save_name = tools.remove_slash(str(fea_name)) 57 | write_lines.append('='*10 + f'{fea_name}({id})' + '='*10) 58 | arr_points = [] 59 | arr_duration = [] 60 | arr_frequency = [] 61 | arr_avg_value = [] 62 | for s in self.dataset._subjects.values(): 63 | for adm in s.admissions: 64 | if id in adm.keys(): 65 | dur = adm[id][-1,1] - adm[id][0,1] 66 | arr_points.append(adm[id].shape[0]) 67 | arr_duration.append(dur) 68 | if dur > 1e-3: # TODO 只有一个点无法计算 69 | arr_frequency.append(arr_points[-1] / arr_duration[-1]) 70 | else: 71 | arr_frequency.append(0) 72 | arr_avg_value.append(adm[id][:,0].mean()) 73 | arr_points, arr_duration, arr_frequency, arr_avg_value = \ 74 | np.asarray(arr_points), np.asarray(arr_duration), np.asarray(arr_frequency), np.asarray(arr_avg_value) 75 | if np.size(arr_points) != 0: 76 | write_lines.append(f'average points per admission: {arr_points.mean():.3f}') 77 | if np.size(arr_duration) != 0: 78 | write_lines.append(f'average duration(hour) per admission: {arr_duration.mean():.3f}') 79 | if np.size(arr_frequency) != 0: 80 | write_lines.append(f'average frequency(point/hour) per admission: {arr_frequency.mean():.3f}') 81 | if np.size(arr_avg_value) != 0: 82 | write_lines.append(f'average avg value per 
admission: {arr_avg_value.mean():.3f}') 83 | # plot distribution 84 | titles = ['points', 'duration', 'frequency', 'dynamic_value'] 85 | arrs = [arr_points, arr_duration, arr_frequency, arr_avg_value] 86 | for title, arr in zip(titles, arrs): 87 | if np.size(arr) != 0: 88 | tools.plot_single_dist( 89 | data=arr, data_name=f'{title}: {fea_name}', 90 | save_path=os.path.join(dist_dir, title, save_name + '.png'), discrete=False, adapt=True, bins=50) 91 | if self.params['static_dist']: 92 | # static feature explore 93 | for id in tqdm(self.dataset.static_keys, 'generate static feature report'): 94 | fea_name = self.dataset.fea_label(id) 95 | save_name = tools.remove_slash(str(fea_name)) 96 | write_lines.append('='*10 + f'{fea_name}({id})' + '='*10) 97 | idx = self.dataset.idx_dict[str(id)] 98 | static_data = self.dataset.data[:, idx, 0] 99 | write_lines.append(f'mean: {static_data.mean():.3f}') 100 | write_lines.append(f'std: {static_data.std():.3f}') 101 | write_lines.append(f'max: {np.max(static_data):.3f}') 102 | write_lines.append(f'min: {np.min(static_data):.3f}') 103 | tools.plot_single_dist( 104 | data=static_data, data_name=f'{fea_name}', 105 | save_path=os.path.join(dist_dir, 'static_value', save_name + '.png'), discrete=False, adapt=True, bins=50) 106 | # write report 107 | with open(report_path, 'w', encoding='utf-8') as fp: 108 | for line in write_lines: 109 | fp.write(line + '\n') 110 | logger.info(f'Report generated at {report_path}') -------------------------------------------------------------------------------- /analyzer/vent_catboost_dynamic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tools 3 | import os 4 | from tqdm import tqdm 5 | from tools import logger as logger 6 | from .container import DataContainer 7 | from tools.data import SliceDataGenerator, DynamicDataGenerator, LabelGenerator_cls, cal_label_weight, vent_label_func 8 | from tools.feature_importance import TreeFeatureImportance 9 | from catboost import Pool, CatBoostClassifier 10 | from os.path import join as osjoin 11 | from datasets.derived_vent_dataset import MIMICIV_Vent_Dataset 12 | 13 | import matplotlib.pyplot as plt 14 | 15 | class VentCatboostDynamicAnalyzer: 16 | def __init__(self, params:dict, container:DataContainer) -> None: 17 | self.params = params 18 | self.paths = params['paths'] 19 | self.dataset = MIMICIV_Vent_Dataset() 20 | self.dataset.load_version(params['dataset_version']) 21 | self.model_name = self.params['analyzer_name'] 22 | self.target_idx = self.dataset.idx_dict['vent_status'] 23 | 24 | def run(self): 25 | # step 1: init variables 26 | out_dir = os.path.join(self.paths['out_dir'], self.model_name) 27 | tools.reinit_dir(out_dir, build=True) 28 | # metric_2cls = tools.DichotomyMetric() 29 | metric_2cls = tools.DichotomyMetric() 30 | generator = SliceDataGenerator( 31 | window_points=self.params['window'], 32 | n_fea=len(self.dataset.total_keys), 33 | label_generator=LabelGenerator_cls( 34 | centers=self.params['centers'] 35 | ), 36 | label_func=vent_label_func, 37 | target_idx=self.target_idx, 38 | limit_idx=[self.dataset.fea_idx(id) for id in self.params['limit_feas']], 39 | forbidden_idx=[self.dataset.fea_idx(id) for id in self.params['forbidden_feas']] 40 | ) 41 | feature_names = [self.dataset.fea_label(idx) for idx in generator.avail_idx] 42 | print(f'Available features: {feature_names}') 43 | # step 2: train and predict 44 | for fold_idx, (train_index, valid_index, test_index) in 
enumerate(self.dataset.enumerate_kf()): 45 | train_result = generator(self.dataset.data[train_index, :, :], self.dataset.seqs_len[train_index]) 46 | X_train, Y_train = train_result['data'], train_result['label'] 47 | 48 | valid_result = generator(self.dataset.data[valid_index, :, :], self.dataset.seqs_len[valid_index]) 49 | X_valid, Y_valid = valid_result['data'], valid_result['label'] 50 | 51 | test_result = generator(self.dataset.data[test_index, :, :], self.dataset.seqs_len[test_index]) 52 | X_test, Y_test = test_result['data'], test_result['label'] 53 | 54 | label_weight = cal_label_weight(len(self.params['centers']), Y_train) 55 | logger.info(f'label weight: {label_weight}') 56 | 57 | model = CatBoostClassifier( 58 | iterations=self.params['iterations'], 59 | learning_rate=self.params['learning_rate'], 60 | loss_function=self.params['loss_function'], 61 | class_weights=label_weight, 62 | use_best_model=True 63 | ) 64 | pool_train = Pool(X_train, Y_train.argmax(axis=-1)) 65 | pool_valid = Pool(X_valid, Y_valid.argmax(axis=-1)) 66 | 67 | model.fit(pool_train, eval_set=pool_valid) 68 | 69 | Y_test_pred = model.predict_proba(X_test) 70 | metric_2cls.add_prediction(Y_test_pred[:, 1], Y_test[:, 1]) 71 | if fold_idx == 0: 72 | # plot sample 73 | self.plot_examples(test_index, model, 20, osjoin(out_dir, 'samples')) 74 | explorer = TreeFeatureImportance(map_func=lambda x:x[:, :, 1], fea_names=feature_names, missvalue=-1, n_approx=-1) 75 | explorer.add_record(model, valid_X=X_valid) 76 | explorer.plot_beeswarm(max_disp=10, plot_path=osjoin(out_dir, f'importance.png')) 77 | 78 | metric_2cls.plot_curve(curve_type='roc', title=f'ROC for ventilation', save_path=osjoin(out_dir, f'vent_roc.png')) 79 | metric_2cls.plot_curve(curve_type='prc', title=f'PRC for ventilation', save_path=osjoin(out_dir, f'vent_prc.png')) 80 | 81 | with open(os.path.join(out_dir, 'result.txt'), 'w') as fp: 82 | print('Overall performance:', file=fp) 83 | metric_2cls.write_result(fp) 84 | 85 | def plot_examples(self, test_index, model:CatBoostClassifier, n_sample:int, out_dir:str): 86 | tools.reinit_dir(out_dir) 87 | generator = DynamicDataGenerator( 88 | window_points=self.params['window'], 89 | n_fea=len(self.dataset.total_keys), 90 | label_generator=LabelGenerator_cls( 91 | centers=self.params['centers'] 92 | ), 93 | label_func=vent_label_func, 94 | target_idx=self.target_idx, 95 | limit_idx=[], 96 | forbidden_idx=[self.dataset.idx_dict[id] for id in ['vent_status']] 97 | ) 98 | test_result = generator(self.dataset.data[test_index, :, :], self.dataset.seqs_len[test_index]) 99 | X_test, mask, Y_test = test_result['data'], test_result['mask'], test_result['label'] 100 | # random sample 10 sequences 101 | seq_index = np.arange(len(Y_test))[np.logical_and(np.max(Y_test[:, :, 1], axis=-1) > 0.5, np.min(Y_test[:, :, 1], axis=-1) < 0.5)] 102 | seq_index = seq_index[np.random.choice(len(seq_index), n_sample)] 103 | origin_label = self.dataset.data[test_index[seq_index], self.target_idx, :] # origin label 104 | X_test = X_test[seq_index, :, :] 105 | for idx in range(n_sample): 106 | seq_mask = mask[seq_index[idx], :] 107 | Y_origin = np.clip(origin_label[idx, seq_mask], 0, 1) # 2->1 108 | Y_pred = model.predict_proba(X_test[idx, :, :].T)[seq_mask, 1] 109 | Y_target = Y_test[seq_index[idx], seq_mask, 1] 110 | plt.figure(figsize=(8, 8)) 111 | ax = plt.gca() 112 | ax.plot(Y_origin, '-o', color='C0', label='origin ventilation status') 113 | ax.plot(Y_pred, '-o', color='C1', label='prediction probability') 114 | ax.plot(Y_target, '-o', 
color='C2', label='prediction target') 115 | ax.legend() 116 | plt.axhline(0.5, xmin=0, xmax=np.sum(seq_mask)) 117 | plt.savefig(osjoin(out_dir, f'{idx}.png')) 118 | plt.close() 119 | 120 | 121 | 122 | 123 | -------------------------------------------------------------------------------- /analyzer/cross_validation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.model_selection import KFold 3 | import tools 4 | import os 5 | from tools import logger as logger 6 | from .container import DataContainer 7 | from tools.data import SliceDataGenerator, LabelGenerator_cls, cal_label_weight 8 | from catboost import CatBoostClassifier, Pool 9 | from os.path import join as osjoin 10 | from datasets.cv_dataset import CrossValidationDataset 11 | from datasets.derived_ards_dataset import MIMICIV_ARDS_Dataset 12 | from datasets.derived_raw_dataset import MIMICIV_Raw_Dataset 13 | 14 | 15 | 16 | class CV_Analyzer: 17 | def __init__(self, params:dict, container:DataContainer) -> None: 18 | self.params = params 19 | self.paths = params['paths'] 20 | self.container = container 21 | self.model_name = self.params['analyzer_name'] 22 | # copy attribute from container 23 | self.mimic_raw_dataset = MIMICIV_Raw_Dataset() 24 | self.mimic_raw_dataset.load_version('raw_version') 25 | self.mimic_dataset = MIMICIV_ARDS_Dataset() 26 | self.mimic_dataset.load_version('no_fill_version') 27 | # prepare mimic-iv data 28 | self.mimic_data = self.mimic_dataset.data 29 | # initialize 30 | self.out_dir = os.path.join(self.paths['out_dir'], self.model_name) 31 | tools.reinit_dir(self.out_dir, build=True) 32 | # prepare cross validation dataset 33 | self.cv_dataset = CrossValidationDataset() 34 | 35 | def prepare_cross_validation_data(self): 36 | # alignment 37 | # init validation out dir 38 | val_out_dir = osjoin(self.out_dir, 'seperate_validation') 39 | tools.reinit_dir(val_out_dir) 40 | # generate label and mask 41 | self.val_generator = SliceDataGenerator( 42 | window_points=self.params['window'], 43 | n_fea=len(self.cv_dataset.total_keys), 44 | label_generator=LabelGenerator_cls( 45 | centers=self.params['centers'], 46 | soft_label=False, 47 | smooth_band=self.params['smoothing_band'] 48 | ), 49 | target_idx=self.cv_dataset.total_keys.index('dX_PaO2(mmHg) / FiO2(%)') 50 | ) 51 | 52 | # make validation data into slice 53 | mimic_limit_idx = sorted([self.mimic_dataset.idx_dict[n] for n in self.params['feature_limit']]) # generator会对idx排序 54 | self.fea_names = [self.mimic_dataset.fea_label(idx) for idx in mimic_limit_idx] 55 | val_fea_names = [ 56 | self.params['alignment_dict'][self.mimic_dataset.fea_label(idx)] for idx in mimic_limit_idx] 57 | 58 | val_idx = [self.cv_dataset.total_keys.index(name) for name in val_fea_names] 59 | cv_result = self.val_generator('val_cv', self.cv_dataset.data, seq_lens=self.cv_dataset.seqs_len) 60 | X_cv, Y_cv = cv_result['data'][:, val_idx], cv_result['label'] 61 | return X_cv, Y_cv 62 | 63 | def train(self, dataset, label): 64 | out_dir = os.path.join(self.paths['out_dir'], self.model_name, label) 65 | tools.reinit_dir(out_dir, build=True) 66 | metric_4cls = tools.MultiClassMetric(class_names=self.params['class_names'], out_dir=os.path.join(out_dir, '4cls_mimic')) 67 | metric_4cls_sep_val = tools.MultiClassMetric(class_names=self.params['class_names'], out_dir=os.path.join(out_dir, '4cls_sep_val')) 68 | tools.reinit_dir(os.path.join(out_dir, '4cls_mimic'), build=True) 69 | tools.reinit_dir(os.path.join(out_dir, 
'4cls_sep_val'), build=True) 70 | # step 3: generate mimic-iv labels 71 | mimic_limit_idx = [dataset.idx_dict[n] for n in self.params['feature_limit']] 72 | 73 | # fea_names = [self.dataset.get_fea_label(idx) for idx in generator.available_idx()] 74 | imp_logger = tools.TreeFeatureImportance(fea_names=self.fea_names) 75 | generator = SliceDataGenerator( 76 | window_points=self.params['window'], 77 | n_fea=len(dataset.total_keys), 78 | label_generator=LabelGenerator_cls( 79 | centers=self.params['centers'], 80 | soft_label=False, 81 | smooth_band=self.params['smoothing_band'] 82 | ), 83 | target_idx=dataset.idx_dict['PF_ratio'], 84 | limit_idx=mimic_limit_idx 85 | ) 86 | # step 4: train and predict 87 | for idx, (train_index, valid_index, test_index) in enumerate(dataset.enumerate_kf()): 88 | # train model on mimic-iv 89 | train_result = generator(f'{idx}_train', dataset.data[train_index, :, :], dataset.seqs_len[train_index]) 90 | X_train, Y_train = train_result['data'], train_result['label'] 91 | valid_result = generator(f'{idx}_train', dataset.data[valid_index, :, :], dataset.seqs_len[valid_index]) 92 | X_valid, Y_valid = valid_result['data'], valid_result['label'] 93 | test_result = generator(f'{idx}_test', dataset.data[test_index, :, :], dataset.seqs_len[test_index]) 94 | X_test, Y_test = test_result['data'], test_result['label'] 95 | model = CatBoostClassifier( 96 | iterations=self.params['iterations'], 97 | learning_rate=self.params['learning_rate'], 98 | loss_function=self.params['loss_function'], 99 | class_weights=cal_label_weight(4, Y_train) 100 | ) 101 | pool_train = Pool(X_train, Y_train.argmax(axis=-1)) 102 | pool_valid = Pool(X_valid, Y_valid.argmax(axis=-1)) 103 | model.fit(pool_train, eval_set=pool_valid) 104 | Y_pred = model.predict_proba(X_test) 105 | metric_4cls.add_prediction(Y_pred, Y_test) # 去掉mask外的数据 106 | # test model on cross validation dataset 107 | cv_pred = model.predict_proba(self.X_cv) 108 | metric_4cls_sep_val.add_prediction(cv_pred, self.Y_cv) 109 | 110 | imp_logger.add_record(model, self.X_cv) 111 | dataset.mode('all') # 恢复原本状态 112 | 113 | # step 5: result explore 114 | imp_logger.plot_beeswarm(os.path.join(out_dir, f'cv_shap_{label}.png')) 115 | single_imp_out = os.path.join(out_dir, 'single_shap') 116 | tools.reinit_dir(single_imp_out, build=True) 117 | imp_logger.plot_single_importance(out_dir=single_imp_out, select=10) 118 | metric_4cls.confusion_matrix(comment=self.model_name + '_all') 119 | metric_4cls_sep_val.confusion_matrix(comment='cross validation k-fold=5') 120 | 121 | return { 122 | 'metric_4cls': metric_4cls, 123 | 'metric_4cls_seq_val': metric_4cls_sep_val 124 | } 125 | 126 | def run(self): 127 | # init cross validation data 128 | self.X_cv, self.Y_cv = self.prepare_cross_validation_data() 129 | # step 2: init variables 130 | 131 | pipeline_result = self.train(dataset=self.mimic_dataset, label='with_pipeline') 132 | raw_result = self.train(dataset=self.mimic_raw_dataset, label='raw') 133 | 134 | with open(os.path.join(self.out_dir, 'result.txt'), 'a') as f: 135 | print('mimic-iv with pipeline:', file=f) 136 | pipeline_result['metric_4cls'].write_result(f) 137 | print('mimic-iv without pipeline:', file=f) 138 | raw_result['metric_4cls'].write_result(f) 139 | 140 | print('cross validation with pipeline:', file=f) 141 | pipeline_result['metric_4cls_seq_val'].write_result(f) 142 | print('cross validation without pipeline:', file=f) 143 | raw_result['metric_4cls_seq_val'].write_result(f) 144 | 145 | 
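A minimal, self-contained sketch (toy lists and ad-hoc variable names; not part of the repository) of the feature-alignment step performed in CV_Analyzer.prepare_cross_validation_data above: the labels of the selected MIMIC-IV features are mapped through alignment_dict (excerpted from configs/analyzers.yml) to the external validation table's column names, and those names are then turned into column indices so that a model trained on MIMIC-IV slices can be scored on the external cross-validation data.

alignment_dict = {                      # excerpt of the mapping in configs/analyzers.yml
    'PH (Arterial)': 'dX_PH',
    'PF_ratio': 'dX_PaO2(mmHg) / FiO2(%)',
}
mimic_fea_labels = ['PH (Arterial)', 'PF_ratio']                    # toy: labels of the limited MIMIC-IV features
cv_total_keys = ['dX_PH', 'dX_FiO2(%)', 'dX_PaO2(mmHg) / FiO2(%)']  # toy: total_keys of the external CV dataset

val_fea_names = [alignment_dict[label] for label in mimic_fea_labels]
val_idx = [cv_total_keys.index(name) for name in val_fea_names]
print(val_idx)  # [0, 2] -> only these external columns are fed to the MIMIC-trained model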
-------------------------------------------------------------------------------- /datasets/helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tools 3 | import pickle 4 | import numpy as np 5 | import pandas as pd 6 | from tools import GLOBAL_CONF_LOADER 7 | from tools import logger 8 | from sklearn.model_selection import KFold 9 | import math 10 | from tqdm import tqdm 11 | 12 | def interp(fx:np.ndarray, fy:np.ndarray, x_start:float, interval:float, n_bins:int, missing=-1, fill_bin=['avg', 'latest']): 13 | # fx, (N,), N >= 1, sample time for each data point (irrgular) 14 | # fy: same size as fx, sample value for each data point 15 | # x: dim=1 16 | assert(fx.shape[0] == fy.shape[0] and len(fx.shape) == len(fy.shape) and len(fx.shape) == 1 and fx.shape[0] >= 1) 17 | assert(interval > 0 and n_bins > 0) 18 | assert(fill_bin in ['avg', 'latest']) 19 | result = np.ones((n_bins)) * missing 20 | 21 | for idx in range(n_bins): 22 | t_bin_start = x_start + (idx - 1) * interval 23 | t_bin_end = x_start + idx * interval 24 | valid_mask = np.logical_and(fx > t_bin_start, fx <= t_bin_end) # (start, end] 25 | if np.any(valid_mask): # have at least one point 26 | if fill_bin == 'avg': 27 | result[idx] = np.mean(fy[valid_mask]) 28 | elif fill_bin == 'latest': 29 | result[idx] = fy[valid_mask][-1] 30 | else: 31 | assert(0) 32 | else: # no point in current bin 33 | if idx == 0: 34 | result[idx] = missing 35 | else: 36 | result[idx] = result[idx-1] # history is always available 37 | return result 38 | 39 | class Admission: 40 | ''' 41 | 代表一段连续的、环境较稳定的住院经历,原subject/admission/stay/transfer的四级结构被精简到subject/admission的二级结构 42 | label: 代表急诊室数据或ICU数据 43 | admittime: 起始时间 44 | dischtime: 结束时间 45 | ''' 46 | def __init__(self, unique_id:int, admittime:float, dischtime:float) -> None: 47 | self.dynamic_data = {} # dict(fea_name:ndarray(value, time)) 48 | assert(admittime < dischtime) 49 | self.unique_id = unique_id # 16 digits 50 | self.admittime = admittime 51 | self.dischtime = dischtime 52 | self._data_updated = False 53 | 54 | def append_dynamic(self, itemid, time:float, value): 55 | assert(not self._data_updated) 56 | if self.dynamic_data.get(itemid) is None: 57 | self.dynamic_data[itemid] = [(value, time)] 58 | else: 59 | self.dynamic_data[itemid].append((value, time)) 60 | 61 | def pop_dynamic(self, itemid): 62 | if self.dynamic_data.get(itemid) is not None: 63 | self.dynamic_data.pop(itemid) 64 | 65 | def update_data(self): 66 | '''绝对时间变为相对时间,更改动态特征的格式''' 67 | if not self._data_updated: 68 | self._data_updated = True 69 | for key in self.dynamic_data: 70 | arr = np.asarray(sorted(self.dynamic_data[key], key=lambda x:x[1])) 71 | arr[:, 1] -= self.admittime 72 | self.dynamic_data[key] = arr 73 | 74 | def duration(self): 75 | return max(0, self.dischtime - self.admittime) 76 | 77 | def empty(self): 78 | return True if len(self.dynamic_data) == 0 else False 79 | 80 | def __getitem__(self, idx): 81 | return self.dynamic_data[idx] 82 | 83 | def __len__(self): 84 | return len(self.dynamic_data) 85 | 86 | def keys(self): 87 | return self.dynamic_data.keys() 88 | 89 | 90 | class Subject: 91 | ''' 92 | 每个患者有一张表, 每列是一个指标, 每行是一次检测结果, 每个结果包含一个(值, 时间戳)的结构 93 | static data: dict(feature name: (value, charttime)) 94 | dyanmic data: admissions->(id, chart time, value) 95 | ''' 96 | def __init__(self, subject_id, birth_year:int) -> None: 97 | self.subject_id = subject_id 98 | self.birth_year = birth_year 99 | self.static_data:dict[str, np.ndarray] = {} # 
dict(fea_name:value) 100 | self.admissions:list[Admission] = [] 101 | 102 | def append_admission(self, admission:Admission): 103 | self.admissions.append(admission) 104 | # 维护时间序列 105 | if len(self.admissions) >= 1: 106 | self.admissions = sorted(self.admissions, key=lambda adm:adm.admittime) 107 | 108 | def append_static(self, charttime:float, name, value): 109 | if charttime is None: 110 | self.static_data[name] = value 111 | else: 112 | if name not in self.static_data: 113 | self.static_data[name] = [(value, charttime)] 114 | else: 115 | self.static_data[name].append((value, charttime)) 116 | 117 | def latest_static(self, key, time=None): 118 | if key not in self.static_data.keys(): 119 | return None 120 | 121 | if not isinstance(self.static_data[key], np.ndarray): # single value 122 | return self.static_data[key] 123 | else: 124 | assert(time is not None) 125 | idx = np.argmin(time - self.static_data[key][:, 1]) 126 | if time - self.static_data[key][idx, 1] >= 0: 127 | return self.static_data[key][idx, 0] 128 | else: 129 | return None # input time is too early 130 | 131 | def nearest_static(self, key, time=None): 132 | if key not in self.static_data.keys(): 133 | return None 134 | 135 | if not isinstance(self.static_data[key], np.ndarray): # single value 136 | return self.static_data[key] 137 | else: 138 | assert(time is not None) 139 | idx = np.argmin(np.abs(self.static_data[key][:, 1] - time)) 140 | return self.static_data[key][idx, 0] 141 | 142 | def append_dynamic(self, charttime:float, itemid, value): 143 | '''添加一个动态特征到合适的admission中''' 144 | for adm in self.admissions: # search admission by charttime 145 | if adm.admittime < charttime and charttime < adm.dischtime: 146 | adm.append_dynamic(itemid, charttime, value) 147 | 148 | def update_data(self): 149 | '''将数据整理成连续形式''' 150 | for adm in self.admissions: 151 | adm.update_data() 152 | 153 | def find_admission(self, unique_id:int): 154 | for adm in self.admissions: 155 | if adm.unique_id == unique_id: 156 | return adm 157 | return None 158 | def del_empty_admission(self): 159 | # 删除空的admission 160 | new_adm = [] 161 | for idx in range(len(self.admissions)): 162 | if not self.admissions[idx].empty(): 163 | new_adm.append(self.admissions[idx]) 164 | self.admissions = new_adm 165 | 166 | def empty(self): 167 | for adm in self.admissions: 168 | if not adm.empty(): 169 | return False 170 | return True 171 | 172 | class KFoldIterator: 173 | def __init__(self, dataset, k): 174 | self._current = -1 175 | self._k = k 176 | self._dataset = dataset 177 | 178 | def __iter__(self): 179 | return self 180 | 181 | def __next__(self): 182 | self._current += 1 183 | if self._current < self._k: 184 | return self._dataset.set_kf_index(self._current) 185 | else: 186 | raise StopIteration 187 | 188 | def load_all_subjects(patient_table_path:str) -> set: 189 | # return a dict with key=subject_id, value=None 190 | patient_set = set() 191 | patients = pd.read_csv(os.path.join(patient_table_path), encoding='utf-8') 192 | for row in tqdm(patients.itertuples(), 'Find all subjects', total=len(patients)): 193 | patient_set.add(row.subject_id) 194 | return patient_set 195 | -------------------------------------------------------------------------------- /analyzer/dataset_explore/vent_explore.py: -------------------------------------------------------------------------------- 1 | import tools 2 | from tools.logging import logger 3 | from ..container import DataContainer 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from tqdm import tqdm 7 | 
import seaborn as sns 8 | import os 9 | from os.path import join as osjoin 10 | import pandas as pd 11 | import yaml 12 | from datasets.derived_vent_dataset import MIMICIV_Vent_Dataset 13 | from scipy.signal import convolve2d 14 | 15 | class VentFeatureExplorer: 16 | def __init__(self, params:dict, container:DataContainer) -> None: 17 | self.params = params 18 | self.container = container 19 | self.dataset = MIMICIV_Vent_Dataset() 20 | self.dataset.load_version(params['dataset_version']) 21 | self.gbl_conf = container._conf 22 | self.data = self.dataset.data 23 | 24 | def run(self): 25 | '''输出mimic-iv数据集的统计特征, 独立于模型和研究方法''' 26 | logger.info('Vent dataset Explore') 27 | dataset_version = self.params['dataset_version'] 28 | out_dir = osjoin(self.params['paths']['out_dir'], f'feature_explore[{dataset_version}]') 29 | tools.reinit_dir(out_dir, build=True) 30 | # random plot sample time series 31 | if self.params['correlation']['enabled']: 32 | self.correlation(out_dir, self.params['correlation']['target']) 33 | if self.params['miss_mat']: 34 | self.miss_mat(out_dir) 35 | if self.params['vent_statistics']: 36 | self.vent_statistics(out_dir) 37 | if self.params['vent_sample']['enabled']: 38 | self.vent_sample(out_dir) 39 | 40 | def correlation(self, out_dir, target_id_or_label): 41 | # plot correlation matrix 42 | target_id, target_label = self.dataset.fea_id(target_id_or_label), self.dataset.fea_label(target_id_or_label) 43 | target_index = self.dataset.idx_dict[target_id] 44 | labels = [self.dataset.fea_label(id) for id in self.dataset._total_keys] 45 | corr_mat = tools.plot_correlation_matrix(self.data[:, :, 0], labels, save_path=os.path.join(out_dir, 'correlation_matrix')) 46 | correlations = [] 47 | for idx in range(corr_mat.shape[1]): 48 | correlations.append([corr_mat[target_index, idx], labels[idx]]) # list[(correlation coeff, label)] 49 | correlations = sorted(correlations, key=lambda x:np.abs(x[0]), reverse=True) 50 | with open(os.path.join(out_dir, 'correlation.txt'), 'w') as fp: 51 | fp.write(f"Target feature: {target_label}\n") 52 | for idx in range(corr_mat.shape[1]): 53 | fp.write(f'Correlation with target: {correlations[idx][0]} \t{correlations[idx][1]}\n') 54 | 55 | def miss_mat(self, out_dir): 56 | '''计算行列缺失分布并输出''' 57 | na_table = np.ones((len(self.dataset.subjects), len(self.dataset._dynamic_keys)), dtype=bool) # True=miss 58 | for r_id, s_id in enumerate(self.dataset.subjects): 59 | for adm in self.dataset.subjects[s_id].admissions: 60 | # TODO 替换dynamic keys到total keys 61 | adm_key = set(adm.keys()) 62 | for c_id, key in enumerate(self.dataset._dynamic_keys): 63 | if key in adm_key: 64 | na_table[r_id, c_id] = False 65 | 66 | row_nas = na_table.mean(axis=1) 67 | col_nas = na_table.mean(axis=0) 68 | tools.plot_single_dist(row_nas, f"Row miss rate", os.path.join(out_dir, "row_miss_rate.png"), discrete=False, adapt=True) 69 | tools.plot_single_dist(col_nas, f"Column miss rate", os.path.join(out_dir, "col_miss_rate.png"), discrete=False, adapt=True) 70 | # save raw/col miss rate to file 71 | tools.save_pkl(row_nas, os.path.join(out_dir, "row_missrate.pkl")) 72 | tools.save_pkl(col_nas, os.path.join(out_dir, "col_missrate.pkl")) 73 | 74 | # plot matrix 75 | row_idx = sorted(list(range(row_nas.shape[0])), key=lambda x:row_nas[x]) 76 | col_idx = sorted(list(range(col_nas.shape[0])), key=lambda x:col_nas[x]) 77 | na_table = na_table[row_idx, :][:, col_idx] # (n_subjects, n_feature) 78 | # apply conv to get density 79 | conv_kernel = np.ones((3,3)) / 9 80 | na_table = 
np.clip(convolve2d(na_table, conv_kernel, boundary='symm'), 0, 1.0) 81 | tools.plot_density_matrix(1.0-na_table, 'Missing distribution for subjects and features [miss=white]', xlabel='features', ylabel='subjects', 82 | aspect='auto', save_path=os.path.join(out_dir, "miss_mat.png")) 83 | 84 | def vent_statistics(self, out_dir): 85 | result = {'no_vent':0, 'non-invasive_vent':0, 'invasive_vent':0} 86 | non_invasive_vent_times = [] 87 | invasive_vent_times = [] 88 | for s_id in self.dataset.subjects: 89 | subject = self.dataset.subjects[s_id] 90 | if 'ventilation_num' not in subject.static_data: 91 | result['no_vent'] += 1 92 | continue 93 | vent_num = subject.static_data['ventilation_num'] 94 | max_vent = int(np.max(vent_num[:, 0])) 95 | if max_vent == 0: 96 | result['no_vent'] += 1 97 | elif max_vent == 1: 98 | result['non-invasive_vent'] += 1 99 | for idx in range(vent_num.shape[0]): 100 | if int(vent_num[idx, 0]) == 1 and (vent_num[idx, 1] >= subject.admissions[0].admittime) and (vent_num[idx, 1] < subject.admissions[0].dischtime): 101 | non_invasive_vent_times.append(vent_num[idx, 1] - subject.admissions[0].admittime) 102 | break 103 | elif max_vent == 2: 104 | result['invasive_vent'] += 1 105 | for idx in range(vent_num.shape[0]): 106 | if int(vent_num[idx, 0]) == 2 and (vent_num[idx, 1] >= subject.admissions[0].admittime) and (vent_num[idx, 1] < subject.admissions[0].dischtime): 107 | invasive_vent_times.append(vent_num[idx, 1] - subject.admissions[0].admittime) 108 | break 109 | else: 110 | assert(0) 111 | logger.info(f'All subjects: {len(self.dataset.subjects)}') 112 | logger.info(f'Vent status (max vent for each sequence): {result}') 113 | # plot distribution of first ventilation time 114 | tools.plot_single_dist(np.asanyarray(non_invasive_vent_times), 'non-invasive first ventilation time', osjoin(out_dir, 'non_invasive_dist.png'), bins=72) 115 | tools.plot_single_dist(np.asanyarray(invasive_vent_times), 'invasive first ventilation time', osjoin(out_dir, 'invasive_dist.png'), bins=72) 116 | 117 | def vent_sample(self, out_dir): 118 | n_plot = self.params['vent_sample']['n_plot'] 119 | n_idx = 1 120 | plt.figure(figsize = (6, n_plot*2)) 121 | for s_id in self.dataset.subjects: 122 | subject = self.dataset.subjects[s_id] 123 | adm = subject.admissions[0] 124 | if 'ventilation_num' not in subject.static_data: 125 | continue 126 | else: 127 | vent_status = subject.static_data['ventilation_num'][:, 0] 128 | vent_start = subject.static_data['ventilation_start'][:, 0] 129 | vent_end = subject.static_data['ventilation_end'][:, 0] 130 | mask = np.logical_and(vent_start >= adm.admittime, vent_start < adm.dischtime) 131 | vent_status = vent_status[mask] 132 | if not np.any(mask): 133 | continue 134 | status_list = np.unique(vent_status).astype(int).tolist() 135 | if 2 not in status_list or 1 not in status_list: 136 | continue 137 | vent_start = vent_start[mask] 138 | vent_end = vent_end[mask] 139 | plt.subplot(n_plot, 1, n_idx) 140 | for idx in range(vent_status.shape[0]): 141 | plt.plot([vent_start[idx] - adm.admittime, vent_end[idx] - adm.admittime], [vent_status[idx], vent_status[idx]], '-o') 142 | n_idx += 1 143 | if n_idx > n_plot: 144 | break 145 | plt.savefig(osjoin(out_dir, 'vent_sample.png')) 146 | plt.close() 147 | -------------------------------------------------------------------------------- /tools/data/data_generator.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import torch 3 | import numpy as np 4 | import 
tools 5 | import compress_pickle as pickle 6 | from os.path import exists, join as osjoin 7 | from abc import abstractmethod 8 | from .label_generator import LabelGenerator 9 | from .utils import Normalization, unroll 10 | 11 | 12 | def label_func_min(pred_window:np.ndarray, pred_window_mask:np.ndarray): 13 | assert(pred_window.shape == pred_window_mask.shape) 14 | invalid_flag = pred_window.max() + 1 15 | pred_window = pred_window * pred_window_mask + invalid_flag * np.logical_not(pred_window_mask) 16 | label = np.min(pred_window, axis=1) 17 | sequence_mask = label != invalid_flag 18 | return sequence_mask, label 19 | 20 | def label_func_max(pred_window:np.ndarray, pred_window_mask:np.ndarray): 21 | assert(pred_window.shape == pred_window_mask.shape) 22 | invalid_flag = pred_window.min() - 1 23 | pred_window = pred_window * pred_window_mask + invalid_flag * np.logical_not(pred_window_mask) 24 | label = np.max(pred_window, axis=1) 25 | sequence_mask = label != invalid_flag 26 | return sequence_mask, label 27 | 28 | def vent_label_func(pred_window:np.ndarray, pred_window_mask:np.ndarray): 29 | sequence_mask, label = label_func_max(pred_window, pred_window_mask) 30 | label = np.clip(label, -1, 1) # 1,2 -> 1 31 | return sequence_mask, label 32 | 33 | class DataGenerator(): 34 | def __init__(self, n_fea, limit_idx=[], forbidden_idx=[]) -> None: 35 | if len(limit_idx) == 0: 36 | self._avail_idx = [idx for idx in range(n_fea) if idx not in forbidden_idx] 37 | else: 38 | self._avail_idx = [idx for idx in range(n_fea) if (idx in limit_idx) and (idx not in forbidden_idx)] 39 | 40 | @property 41 | def avail_idx(self): 42 | return self._avail_idx 43 | 44 | class DynamicDataGenerator(DataGenerator): 45 | def __init__(self, window_points, 46 | n_fea, 47 | label_generator: LabelGenerator, 48 | label_func, 49 | target_idx, 50 | limit_idx=[], 51 | forbidden_idx=[], 52 | norm:Normalization=None 53 | ) -> None: 54 | super().__init__(n_fea, limit_idx, forbidden_idx) 55 | self.norm = norm 56 | self.label_func = label_func 57 | self.target_idx = target_idx 58 | self.window = window_points # how many points we should look forward 59 | self.label_gen = label_generator 60 | 61 | def __call__(self, _data:np.ndarray, seq_lens:np.ndarray) -> dict: 62 | ''' 63 | data: (batch, n_fea, seq_lens) 64 | mask: (batch, seq_lens) 65 | return: 66 | mask(batch, seq_lens), label(batch, seq_lens, n_cls) 67 | ''' 68 | 69 | mask = tools.make_mask((_data.shape[0], _data.shape[2]), seq_lens) # (batch, seq_lens) 70 | data = _data.copy() 71 | target = data[:, self.target_idx, :] 72 | data = data[:, self.avail_idx, :] 73 | # 将target按照时间顺序平移 74 | for idx in range(target.shape[1]-1): # 最后一个格子预测一格 75 | stop = min(data.shape[2], idx+self.window) 76 | pred_window = target[:, idx+1:stop] # seq_len的最后一个格子是无效的 77 | pred_window_mask = mask[:, idx+1:stop] 78 | sequence_mask, label = self.label_func(pred_window, pred_window_mask) # (batch, window) -> (batch, ) 79 | target[:, idx] = label 80 | mask[:, idx] = sequence_mask 81 | # 将target转化为标签 82 | label = self.label_gen(target) * mask[..., None] 83 | if self.norm is not None: 84 | data = self.norm(data, self.avail_idx) 85 | 86 | result = {'data': data, 'mask': mask, 'label': label} 87 | return result 88 | 89 | class SliceDataGenerator(DataGenerator): 90 | ''' 91 | 生成每个时间点和预测窗口的标签, 并进行展开 92 | ''' 93 | def __init__(self, 94 | window_points, 95 | n_fea, 96 | label_generator: LabelGenerator, 97 | label_func, 98 | target_idx, 99 | limit_idx=[], 100 | forbidden_idx=[], 101 | norm:Normalization=None 102 | 
) -> None: 103 | super().__init__(n_fea, limit_idx, forbidden_idx) 104 | self.norm = norm 105 | self.label_func = label_func 106 | self.target_idx = target_idx 107 | self.window = window_points # 向前预测多少个点内的ARDS 108 | self.label_gen = label_generator 109 | self.slice_len = None 110 | 111 | def __call__(self, _data:np.ndarray, seq_lens:np.ndarray) -> dict: 112 | ''' 113 | data: (batch, n_fea, seq_lens) 114 | mask: (batch, seq_lens) 115 | return: 116 | label(n_slice, dim_target) 117 | ''' 118 | 119 | mask = tools.make_mask((_data.shape[0], _data.shape[2]), seq_lens) # (batch, seq_lens) 120 | data = _data.copy() 121 | target = data[:, self.target_idx, :] 122 | data = data[:, self.avail_idx, :] 123 | self.slice_len = data.shape[0] 124 | # 将target按照时间顺序平移 125 | for idx in range(target.shape[1]-1): # 最后一个格子预测一格 126 | stop = min(data.shape[2], idx+self.window) 127 | pred_window = target[:, idx+1:stop] # seq_len的最后一个格子是无效的 128 | pred_window_mask = mask[:, idx+1:stop] 129 | sequence_mask, label = self.label_func(pred_window, pred_window_mask) # (batch, window) -> (batch, ) 130 | target[:, idx] = label 131 | mask[:, idx] = np.logical_and(mask[:, idx], sequence_mask) 132 | # 将target转化为标签 133 | label = self.label_gen(target) * mask[..., None] 134 | if self.norm is not None: 135 | data = self.norm(data, self.avail_idx) 136 | 137 | # 转化为slice 138 | data, label = unroll(data, mask), unroll(label, mask) 139 | result = {'data': data, 'mask': mask, 'label': label, 'slice_len': self.slice_len} 140 | return result 141 | 142 | def restore_from_slice(self, x:np.ndarray): 143 | '''make_slice的反向操作, 保证顺序不会更改''' 144 | pass 145 | 146 | class StaticDataGenerator(DataGenerator): 147 | ''' 148 | 生成每个时间点和预测窗口的标签,但是不展开时间轴 149 | ''' 150 | def __init__(self, 151 | start_point, 152 | window_points, 153 | n_fea, 154 | label_generator: LabelGenerator, 155 | label_func, 156 | target_idx, 157 | limit_idx=[], 158 | forbidden_idx=[], 159 | norm:Normalization=None 160 | ) -> None: 161 | super().__init__(n_fea, limit_idx, forbidden_idx) 162 | self.norm = norm 163 | self.label_func = label_func 164 | self.target_idx = target_idx 165 | self.window = window_points 166 | self.start_point = start_point 167 | self.label_gen = label_generator 168 | 169 | def __call__(self, _data:np.ndarray, seq_lens:np.ndarray) -> dict: 170 | ''' 171 | data: (batch, n_fea, seq_lens) 172 | mask: (batch, seq_lens) 173 | return: 174 | mask(new_batch, window), label(new_batch, n_cls), data(new_batch, avail_idx, window) 175 | ''' 176 | 177 | mask = tools.make_mask(_data.shape[[0,2]], seq_lens) # (batch, seq_lens) 178 | mask = mask[:, self.start_point:min(target.size(1), self.start_point+self.window)] # (batch, window) 179 | seq_mask = np.max(mask, axis=1) # (batch, ) throw away all zero sequences 180 | mask = mask[seq_mask, :] 181 | 182 | data = _data.copy() 183 | target = data[seq_mask, self.target_idx, :] 184 | data = data[seq_mask, self.avail_idx, :] 185 | # calculate target in prediction window 186 | target = target * mask 187 | pred_window = target[:, min(target.size(1), self.window)] 188 | pred_window_mask = mask[:, min(target.size(1), self.window)] 189 | sequence_mask, target = self.label_func(pred_window, pred_window_mask) # (batch, window) -> (batch, ) 190 | mask[np.logical_not(sequence_mask), :] = False 191 | # target -> label 192 | label = self.label_gen(target) * mask[: None] 193 | if self.norm is not None: 194 | data = self.norm(data, self.avail_idx) 195 | 196 | result = {'data': data, 'mask': mask, 'label': label} 197 | return result 
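A minimal, self-contained sketch (toy array and ad-hoc variable names; not part of the repository) of the forward-window labelling loop shared by DynamicDataGenerator and SliceDataGenerator above: each time step is relabelled by applying label_func_min-style logic to the next `window` points of the target feature, and steps whose prediction window contains no observed data are masked out. The real generators then pass these per-step values through LabelGenerator_cls to obtain class labels.

import numpy as np

window = 3
target = np.array([[300., 260., 180., 90., 120., 400.]])  # (batch=1, seq_len=6), e.g. hourly PF_ratio
mask = np.ones_like(target, dtype=bool)                    # every point observed in this toy case

labels = target.copy()
for t in range(target.shape[1] - 1):                       # the last step has no future window
    stop = min(target.shape[1], t + window)
    pred_window = target[:, t + 1:stop]                    # future values only
    pred_mask = mask[:, t + 1:stop]
    invalid = pred_window.max() + 1                        # sentinel above any observed value
    label = np.where(pred_mask, pred_window, invalid).min(axis=1)
    labels[:, t] = label                                   # worst (lowest) value within the window
    mask[:, t] = label != invalid                          # valid only if the window had data

print(labels[0, :-1])  # [180.  90.  90. 120. 400.]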
-------------------------------------------------------------------------------- /tools/feature_importance.py: -------------------------------------------------------------------------------- 1 | import shap 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import os 5 | import torch 6 | import numpy as np 7 | import random 8 | from captum.attr import DeepLift 9 | import seaborn as sns 10 | import sys 11 | 12 | 13 | class TreeFeatureImportance(): 14 | def __init__(self, map_func, fea_names, missvalue=None, n_approx=-1) -> None: 15 | self.map_func = map_func # map shape value of each class into a total shap value 16 | self.fea_names = fea_names 17 | self.n_approx = n_approx 18 | self.missvalue = missvalue 19 | # register 20 | self.records = [] 21 | 22 | def add_record(self, model, valid_X:np.ndarray): 23 | explainer = shap.Explainer(model) 24 | n_in = min(valid_X.shape[0], self.n_approx) if self.n_approx > 0 else valid_X.shape[0] 25 | permutation = np.random.permutation(valid_X.shape[0])[:n_in] 26 | shap_values = explainer(valid_X[permutation]) 27 | if self.missvalue is not None: 28 | data = shap_values.data # (N, n_fea) 29 | values = shap_values.values # (N, n_fea, n_cls) 30 | values[data == self.missvalue, :] = 0 31 | self.records.append((shap_values.base_values, data, values)) # (sample, n_fea) 32 | else: 33 | self.records.append((shap_values.base_values, shap_values.data, shap_values.values)) # (sample, n_fea) 34 | 35 | def _update_record(self): 36 | if isinstance(self.records, list): 37 | base_values = np.concatenate([record[0] for record in self.records], axis=0) 38 | if len(base_values.shape) == 2: # prediction bias for each category 39 | base_values = np.mean(base_values, axis=-1) 40 | data = np.concatenate([record[1] for record in self.records], axis=0) 41 | shap_values = np.concatenate([record[2] for record in self.records], axis=0) 42 | shap_values = self.map_func(shap_values) 43 | self.records = shap.Explanation(base_values=base_values, data=data, values=shap_values, feature_names=self.fea_names) 44 | 45 | def plot_beeswarm(self, max_disp=20, plot_path=None): 46 | self._update_record() 47 | plt.subplots_adjust(left=0.3) 48 | shap.plots.beeswarm(self.records, order=self.records.abs.mean(0), max_display=max_disp, show=False, plot_size=(14,14)) 49 | plt.savefig(plot_path) 50 | plt.close() 51 | 52 | def plot_single_importance(self, out_dir, select=None): 53 | ''' 54 | 输出每个特征的取值和重要性关系 55 | select: 可以是list/int/None 56 | int: 选择前k个特征输出 57 | None: 输出所有特征 58 | ''' 59 | self._update_record() 60 | imp = self.records.abs.mean(0).values 61 | order = sorted(list(range(len(self.fea_names))), key=lambda x:imp[x], reverse=True) 62 | if isinstance(select, int): 63 | order = order[:min(select, len(order))] 64 | names = [self.fea_names[idx] for idx in order] 65 | for idx, name in zip(order, names): 66 | data = self.records[:,name] 67 | if len(data) >= 20: 68 | data_sorted = np.sort(data.data)[5:-5] 69 | xmin, xmax = data_sorted[0], data_sorted[-1] 70 | plt.subplots_adjust(left=0.3) 71 | shap.plots.scatter(self.records[:,name]) 72 | plt.xlabel(name) 73 | plt.xlim((xmin, xmax)) 74 | plt.savefig(os.path.join(out_dir, f'{name}.png')) 75 | plt.close() 76 | 77 | def get_importance_array(self, feature_ids, fp=sys.stdout): 78 | '''返回特征重要性的排序''' 79 | self._update_record() 80 | imp = self.records.abs.mean(0).values 81 | order = sorted(list(range(len(self.fea_names))), key=lambda x:imp[x], reverse=True) 82 | ids = [feature_ids[idx] for idx in order] 83 | names = [self.fea_names[idx] for idx in 
order] 84 | result = ['\nFeature Importance\n'] 85 | for id, name in zip(ids, names): 86 | result.append(f'\"{id}\", {name}\n') 87 | fp.writelines(result) 88 | 89 | 90 | class DeepFeatureImportance(): 91 | '''基于intergrated-gradients对深度学习网络计算重要性''' 92 | def __init__(self, device, fea_names) -> None: 93 | self.fea_names = fea_names 94 | self.device = torch.device(device) 95 | # register 96 | self.records = [] 97 | 98 | def add_record(self, model:torch.nn.Module, valid_X:np.ndarray, threshold:int): 99 | ''' 100 | 要求forward_func输入为(batch, seq_len, n_fea) 101 | threshold: 时序上截取的最终时刻 102 | ''' 103 | max_k = min(500, valid_X.shape[0]//2) 104 | valid_X = torch.as_tensor(valid_X, dtype=torch.float32).to(self.device) 105 | model = model.eval().to(self.device) 106 | background = torch.mean(valid_X[:max_k,...], dim=0)[None, ...] 107 | valid = valid_X[max_k:,...] 108 | explainer = DeepLift(model=model) 109 | shap_values = explainer.attribute(valid, background) 110 | shap_values = shap_values # (batch, n_fea, seq_len) 111 | valid = valid.detach().clone().cpu().numpy()[:, :, :threshold] 112 | shap_values = shap_values.detach().clone().cpu().numpy()[:, :, :threshold] 113 | self.records.append((valid, shap_values)) # (batch, n_fea, threshold) 114 | 115 | 116 | def update_record(self): 117 | if isinstance(self.records, list): 118 | data = np.concatenate([record[0] for record in self.records], axis=0) 119 | shap_values = np.concatenate([record[1] for record in self.records], axis=0) 120 | self.records = {} 121 | self.records['exp'] = shap.Explanation(base_values=None, data=np.mean(data, axis=-1), values=np.mean(shap_values, axis=-1), feature_names=self.fea_names) 122 | self.records['data'] = data 123 | self.records['shap_values'] = shap_values 124 | 125 | 126 | def plot_beeswarm(self, plot_path): 127 | self.update_record() 128 | plt.subplots_adjust(left=0.3) 129 | shap.plots.beeswarm(self.records['exp'], order=self.records['exp'].abs.mean(0), max_display=20, show=False, plot_size=(14,14)) 130 | plt.savefig(plot_path) 131 | plt.close() 132 | 133 | def plot_hotspot(self, plot_path): 134 | self.update_record() 135 | shap_values = np.mean(np.abs(self.records['shap_values']), axis=0) # (n_fea, threshold) 136 | shap_values = np.log10(shap_values-np.min(shap_values)+1e-6) 137 | # reference_index = list(self.fea_names).index('age') 138 | # reference_seq = shap_values[reference_index, :].copy() 139 | # for idx in range(shap_values.shape[0]): 140 | # shap_values[idx, :] -= reference_seq 141 | imps = np.mean(shap_values, axis=-1) 142 | sorted_idx = sorted(list(range(len(imps))), key=lambda x:imps[x]) 143 | sorted_names = [self.fea_names[idx] for idx in sorted_idx] 144 | sorted_values = np.asarray([shap_values[idx, :] for idx in sorted_idx]) 145 | time_ticks = [n for n in range(shap_values.shape[-1])] 146 | f, ax = plt.subplots(figsize=(15, 15)) 147 | cmap = sns.diverging_palette(230, 20, as_cmap=True) 148 | sns.heatmap(sorted_values, cmap=cmap, annot=False, yticklabels=sorted_names, xticklabels=time_ticks, 149 | square=True, linewidths=.5, cbar_kws={"shrink": .5}) 150 | plt.xticks(rotation=90) 151 | plt.yticks(rotation=0) 152 | plt.title('Feature importance DeepSHAP', fontsize = 13) 153 | plt.savefig(plot_path) 154 | plt.close() 155 | 156 | def plot_single_importance(self, out_dir, select=None): 157 | ''' 158 | 输出每个特征的取值和重要性关系 159 | select: 可以是list/int/None 160 | int: 选择前k个特征输出 161 | None: 输出所有特征 162 | ''' 163 | self.update_record() 164 | imp = self.records.abs.mean(0).values 165 | order = 
sorted(list(range(len(self.fea_names))), key=lambda x:imp[x], reverse=True) 166 | if isinstance(select, int): 167 | order = order[:min(select, len(order))] 168 | names = [self.fea_names[idx] for idx in order] 169 | for idx, name in zip(order, names): 170 | plt.subplots_adjust(left=0.3) 171 | shap.plots.scatter(self.records[:,name], ) 172 | plt.savefig(os.path.join(out_dir, f'{name}.png')) 173 | plt.close() 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | -------------------------------------------------------------------------------- /README_CN.md: -------------------------------------------------------------------------------- 1 | 2 | ![Pipeline Overview](documents/general_pipeline.png) 3 | 4 | # MIMIC-IV Data Processing Pipeline 5 | 6 | **[中文版本](README_CN.md) | [English Version](README.md)** 7 | 8 | MIMIC-IV数据集广泛地用于各种医学研究,然而原始数据集并没有经过数据清洗。本框架提供了针对MIMIC-IV的一个高度可配置的Pipeline,以最小化封装、高灵活度、易拓展性为目标。数据处理的代码耦合度低,可以方便地与其他数据集的处理代码合并。框架本身提供了一个默认的处理流程,同时,该框架提供在配置文件和调用接口两个层级的用户自定义配置,可以满足复杂的用户自定义需求。 9 | 10 | ## 架构 11 | 12 | 该框架主要包括三个部分:数据集dataset、模型model、分析器analyzer。dataset将原数据抽象为torch.dataset接口;model对批次输入计算输出;analyzer类似trainer,提供K-fold、指标计算、绘图等工作。将model和analyzer拆分,使得一个analyzer调用多个model进行集成学习、一个model被多个analzyer调用等情况更加方便。其余的tools部分包括共用的工具方法,configs部分为需要配置的字段,例如路径、数据清洗的参数等。 13 | 14 | analyzer: 分析模块 15 | 1. analyzer: 按照序列运行anlayzer,添加新analyzer时需要注册 16 | 2. analyzer_utils: 工具类 17 | 3. container: 存放与模型无关的参数 18 | 4. 其他文件: 每个文件代表一个独立的analyzer, 执行特定的下游任务 19 | 20 | configs: 每个数据集对应的配置文件 21 | 1. global_config: 配置路径 22 | 2. mimiciv_dataset_XXX: 对应`dataset/derived_XXX_dataset.py`,是我们对不同pipeline的示例实现。默认提供三个实现,ards/vent/raw 23 | 1. ards:进行ARDS时序预测的四分类任务 24 | 2. vent:进行ventilation时序预测的二分类任务 25 | 3. raw:用于数据集可视化 26 | 27 | 其他模块: 28 | - data: 数据集文件 29 | - datasets: 数据集抽象, 包括数据提取/清洗/重新组织 30 | - models: 模型 31 | - outputs: 输出文件夹 32 | - tools: 工具类 33 | - main.py: **主入口** 34 | - launch_list.yml 配置程序启动后运行哪些analyzer 35 | 36 | ## 部署方法 37 | 38 | 按照以下步骤部署: 39 | 1. 在`python=3.10.11`环境下配置conda环境,并安装所需的packages:`pip install -r requirements.txt` 40 | 2. 第一步中关于pytorch的cuda版本问题参考下一小节 41 | 3. 将MIMIC-IV数据集解压至`data/mimic-iv`文件夹下, 子文件夹有`hosp`,`icu`等 42 | 4. (可选)如果有MIMIC-IV-ED文件,解压至`ed`子文件夹 43 | 5. 将生成的`sepsis3.csv`存放在`data/mimic-iv/sepsis_result`下,ventilation类似(需要MIMIC-Code的结果才能生成标签) 44 | 6. 运行`python -u main.py`,一次生成 45 | 46 | 安装Pytorch对应的CUDA版本: 47 | 1. 新建并进入一个conda虚拟环境 48 | 2. 输入`nvidia-smi` 查看服务器安装的CUDA版本 49 | 3. 按照 https://pytorch.org/get-started/previous-versions/ 选择linux下的对应版本的安装命令,pytorch对应的CUDA版本可以落后于服务器的CUDA版本 50 | 4. 检查是否安装成功: https://blog.csdn.net/qq_45032341/article/details/105196680 51 | 5. 如果安装了不同于`requirements.txt`中的pytorch版本,将对应的行删掉,避免重复安装 52 | 6. 
这个框架本身对第三方库的版本没有严格限制 53 | 54 | ## MIMIC-IV数据集 55 | 56 | 我们将对数据集的处理分为两个部分:`dataset specific processing`和`model & task specific processing`,前者主要处理数据集的内部结构产生的问题,后者对不同的下游任务会进行不同的处理。`datasets/mimiciv/MIMICIV`中的代码用于`dataset specific processing`,在一般情况下用户不需要修改其中的内容。 57 | 58 | `model & task specific processing`由派生类和`data_generator`共同完成。`datasets/mimic_dataset/MIMICIVDataset`是`MIMICIV`的派生类,通过重写以`on_`开头的抽象方法来实现对不同下游任务的灵活处理,用户需要对这部代码进行修改,用于满足不同的研究需求。`MIMICIV`的每个派生类都在`config`中拥有独立的配置文件,通过修改配置文件,可以在不修改代码的情况下调整`MIMICIV`的行为。 59 | 60 | ### 一些重要概念 61 | 62 | `static/dynamic features`: 实际上我们不区分静态和动态特征,静态特征也可以有采样时刻。它们的区别主要是,前者可以不属于任何一个admission,按照时间被分配到某个admission中,而且时间轴不会减去起始时间。后者一定属于某个admission,在采集的时候就确定好了。 63 | 64 | `derived_XXX_dataset/dataset_version`:一个下游任务从`MIMIC_Core`中派生得到derived_XXX_dataset,derived_XXX_dataset可以有多个datset_version 65 | 66 | ### MIMICIV处理流程介绍 67 | 68 | 这部分处理对应`dataset specific processing`,在大多数情况下用户不需要修改其中的内容,只需理解大致的处理流程和接口的定义。按照最小封装原则,仅用`@property`区分内部和外部可见属性,用`@abstractmethod`指示哪些方法需要被派生类重写。 69 | 70 | **数据存取**:`MIMICIV`的数据处理分为7个阶段(phase),每个阶段处理后的数据会存入cache,在`data/mimic-iv-XXX/cache`下(例如, 如果是`mimic_raw`,则存储在`data/mimic-iv-raw/cache`下),一旦所有数据处理完毕,每次实例化时仅载入最终结果对应的cache,去除不必要的IO操作,此时会显示`Bare Mode Enabled`。另外,为了节约空间,cache采用lzma算法压缩,压缩耗时会大于pickle直接存储的耗时,但是对读取几乎没有影响,如果需要省去压缩的时间,修改`configs/mimic_dataset/compress_cache`为`False` 71 | 72 | **preprocess_phase1**:载入`hosp`, `icu`的item映射表,为`ed`创造对应的映射表,不需要实际读取ED的数据。之后调用`on_extract_subjects`引入需要筛选subject列表和外部数据。 73 | 74 | **preprocess_phase2**: 按照subject_id列表读取患者的基本信息,这一阶段先后调用`on_build_subject`, `on_extract_admission` 75 | 76 | **preprocess_phase3**: 从MIMIC-IV中遍历所有的可用数据,并抽取需要的数据。有三个数据源(`icu`, `ed`, `hosp`)通过`configs/mimic_dataset/data_linkage`可以配置哪些数据源不需要读取。在数据抽取之前,通过`on_select_feature`简单地排除哪些特征是不需要的 77 | 78 | **preprocess_phase4**: 将数据转化为数值型,筛选可用的样本,最后将异常值按照`configs/mimic_dataset/value_clip`约束在正常范围内。在前两个过程中,将依次调用`on_convert_numeric`, `on_select_admissions` 79 | 80 | **preprocess_phase5**: 进一步去除高缺失率的特征和样本,生成每个特征的统计数据,这些统计数据在时间轴对齐之前生成,因此不受插值和缺失填充的影响。这一步将调用`on_remove_missing_data` 81 | 82 | **preprocess_phase6**: 对不均匀采样的数据进行插值和对齐,按照一定的时间间隔生成一个三维数组`(n_subject, n_feature, n_sequence)`,不同长度的序列用`-1`填充末尾,最后进行特征工程。该阶段依次调用`on_build_table`, `on_feature_engineering` 83 | 84 | 插值考虑三种情况: 85 | 1. 一个admission的头和尾通过`configs/mimic_dataset/align_target`确定,寻找target features都存在的最早和最晚时间点。插值的时间间隔由`configs/mimic_dataset/delta_t_hour`确定。 86 | 2. 当某个时间点存在历史数据时,选用最近的历史可用数据,不受因果性影响 87 | 3. 当插值起始时间早于该特征已知的历史数据时,用第一个数据点填充空缺的部分。 88 | 89 | **preprocess_phase7**: 按照`configs/mimic_dataset/version`生成同一个数据集的不同衍生版本,用户可以在配置中自定义特征的限制范围和排除范围,指定缺失值的处理方法,目前支持`avg`均值填充和`none`缺省(-1)填充。不同版本可以分别用于数据探查和模型训练。这一阶段也会设置固定的K-fold,保持训练的稳定性 90 | 91 | ### MIMICIV derived dataset介绍 92 | 93 | 这部分对应`datasets/mimic_dataset`中的函数,建议用户只修改必要的部分,并且明确接口的输入输出格式。这个框架对于用户自定义行为几乎没有检查,如果用户修改的部分引入新的问题(例如产生NaN),可能会使得Pipeline的其他部分报错。 94 | 95 | **on_extract_subjects** 96 | 97 | 输出: dict[key=int(subject_id), value=object(extra_data)] 98 | 1. subject_id需要覆盖所有采集的患者,可以和其他数据pipeline链接,比如通过读取`mimic-code`的衍生表得到 99 | 2. `extra_data`包含subject对应的附加数据,格式不受限制,仅在`on_build_subject`中被用户处理 100 | 101 | **on_build_subject** 102 | 103 | 创建Subject实例 104 | 105 | 输入: 106 | 1. subject_id: int 107 | 2. subject: Subject类型,通过`append_static`方法添加新数据 108 | 3. row: `hosp/patients.csv`中对应subject的行数据,包括`gender`, `dod`等 109 | 4. _extract_result: 在`on_extract_subjects`中提取的数据 110 | 111 | 输出: subject,这是一个reference,实质上添加内容是通过`append_static`方法 112 | 113 | **on_extract_admission** 114 | 115 | 向subjects中添加Admission实例 116 | 117 | 输入: 118 | 1. 
subject: Subject 119 | 2. source: str, 取值可以是(`admission`, `icu`, `transfer`) 分别提取自 `hosp/admission.csv`, `icu/icu_stays.csv`, `hosp/transfers.csv` 120 | 3. row: namedtuple, 从`source`对应的表中提取的行 121 | 122 | 输出:bool,是否添加了admission,用于统计 123 | 124 | 通过`Subject.append_static()`和`Admission.append_dyanmic()`添加数据 125 | 126 | **on_select_feature** 127 | 128 | 输入: 129 | 1. id: int, subject_id 130 | 2. source: str, 取值可以是(`icu`, `hosp`, `ed`), 源文件分别来自于`icu/d_items.csv`, `hosp/d_labitems.csv`, `ed`具有内置的格式 131 | 3. row: dict, key的含义参考`source`对应源文件中各列的含义 132 | - source=`icu`, key=[`id`, `label`, `category`, `type`, `low`, `high`] 133 | - source=`hosp`, key=[`id`, `label`, `fluid`, `category`] 134 | - source=`ed`, key=[`id`, `link_id`, `label`] 其中`link_id`将ED的基本生命体征链接到ICU到生命体征中,不作为单独一列 135 | 136 | 输出:bool, False表示不选择该特征,True表示选择。注意这一阶段只需去除不需要的特征,选择的特征还会被进一步筛选 137 | 138 | **on_convert_numeric** 139 | 140 | 不建议重写该方法,应当仿造示例仅重写判断语句中的内容 141 | 142 | `Subject.static_data`: dict[key=str, value=list[tuple(value, time)]],修改非数值型特征的值使得所有特征都变为数值型 143 | 144 | **on_select_admissions(remove_invalid_pass1)** 145 | 146 | 按照条件筛选可用的样本,不建议重写整个方法,应该按照任务需要修改筛选条件 147 | 148 | 输入: 149 | 1. rule: dict, 来源是`configs/mimiciv_dataset/remove_rule/pass1` 150 | 2. subjects: dict[subject_id: int, Subject] 151 | 152 | 默认的筛选条件用于筛选Sepsis患者,需要全部满足: 153 | 1. 是否包含`target_id`中的全部特征,如果不包含,在某些任务中无法生成时序标签 154 | 2. 时长是否满足条件,太短的序列无法用于预测,太长的序列会影响padding后表的大小 155 | 3. 关于`Sepsis time`的筛选条件,在非Sepsis任务中不需要 156 | 157 | **on_remove_missing_data(remove_invalid_pass2)** 158 | 159 | 进一步筛选可用的样本和特征,除非必要,建议只修改配置文件中的missrate,不修改代码 160 | 161 | 默认的算法进行多步迭代,同时删除高缺失的特征和样本,迭代步骤由配置文件中的`max_col_missrate`和`max_subject_missrate`的列表长度决定,列表末尾的missrate是最终的missrate 162 | 163 | `adm_select_strategy`有两种选择,`random`表示在每个subject的可用admission中随机选择一个,`default`表示默认选择第一个有效的admission 164 | 165 | **on_build_table** 166 | 167 | 在确定序列开始时间时,修改某些特征的取值,这些特征往往和时间有关 168 | 169 | 输入: 170 | 1. subject: Subject 171 | 2. key: str, 当前修改的特征id 172 | 3. value: float:原值 173 | 4. t_start: float,表示当前序列的起始时间距离admit time的时间,是正数 174 | 175 | 输出:new_value: float, 当前特征的新值 176 | 177 | **on_feature_engineering** 178 | 179 | 在插值后,根据已有特征计算新的特征。不建议重写整个方法,可以修改判断语句中的一部分。注意原始特征需要考虑缺失(-1)的情况,如果原始特征缺失,新特征也应当设置为-1 180 | 181 | 输入: 182 | 1. tables: list[np.ndarray[n_features, n_sequences]], 制表后的数据,没有padding 183 | 2. norm_dict: dict[key='mean' or 'std', value=float], 存储特征的统计数据,新特征需要更新它 184 | 3. 
static_keys, dynamic_keys: list[str],新特征需要更新它们 185 | 186 | 输出:更新后的tables, norm_dict, static_keys, dynamic_keys 187 | 188 | ## Data generator介绍 189 | 190 | `model & task specific processing`不总是能被pipeline解决,例如,不同模型需要不同的归一化方法、时序和非时序任务需要不同的数据组织格式。我们提供了多种generator生成不同的数据格式和标签 191 | 192 | 总的来说,有3种DataGenerator和两种LabelGenerator,它们在`models/utils`下,用户可以很方便地开发新的Generator。我们提供了几种示例算法,说明不同的generator是如何用于不同任务的。 193 | 194 | ### label generator 195 | 196 | **LabelGenerator_4cls**:将任意大小的target数组拓展一个新的维度,按照事先设定的大小划分为四个类别,用于分类预测 197 | 198 | **LabelGenerator_regression**: 不进行任何处理,可以修改用于其他任务 199 | 200 | ### data generator 201 | 202 | 不同的data generator和label generator可以任意组合,参考示例算法: `analyzer/method_nearest_cls` 203 | 204 | **DynamicDataGenerator**: 对每个数据点,向前搜索一个预测窗口,并计算窗口内的最低值作为预测目标。它也会采集可用的特征,和`dataset version`类似,并且可以进行归一化。最终结果依然是一个三维结构,时间轴不会被展开。可以用于LSTM等RNN算法。 205 | 206 | **SliceDataGenerator**:按照`DynamicDataGenerator`的方式计算,但是最终数据的时间轴会被展开,成为(n_sequence, n_feature)的形式,无效数据会被剔除。用于GBDT等非时序方法 207 | 208 | **StaticDataGenerator**:对每个序列,生成一个大预测窗口,用第一时刻的特征预测终局情况。每个序列生成一个样本。 209 | 210 | ## 更多帮助 211 | 212 | [如何生成Sepsis3.csv](documents/processing.md) 213 | 214 | -------------------------------------------------------------------------------- /analyzer/vent_lstm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tools 3 | import os 4 | from tools import logger as logger 5 | from .container import DataContainer 6 | from tools.data import DynamicDataGenerator, LabelGenerator_cls, vent_label_func, unroll 7 | from datasets.derived_vent_dataset import MIMICIV_Vent_Dataset 8 | from models.vent_lstm_model import VentLSTMModel 9 | from torch.utils.data.dataloader import DataLoader 10 | import os 11 | from os.path import join as osjoin 12 | import torch 13 | from tools.data import Collect_Fn 14 | from tools.logging import SummaryWriter 15 | 16 | class VentLSTMTrainer(): 17 | def __init__(self, params:dict, dataset:MIMICIV_Vent_Dataset, generator:DynamicDataGenerator) -> None: 18 | self.params = params 19 | self.dataset = dataset 20 | self.generator = generator 21 | self.train_dataloader = DataLoader(dataset=self.dataset, batch_size=params['batch_size'], shuffle=True, collate_fn=Collect_Fn) 22 | self.valid_dataloader = DataLoader(dataset=self.dataset, batch_size=1, shuffle=False, collate_fn=Collect_Fn) 23 | self.test_dataloader = DataLoader(dataset=self.dataset, batch_size=1, shuffle=False, collate_fn=Collect_Fn) 24 | 25 | def load_checkpoint(self, p_checkpoint:str): 26 | result = torch.load(p_checkpoint) 27 | self.model = result['model'].to(self.params['device']) 28 | 29 | def save(self, save_path): 30 | torch.save({ 31 | 'model': self.model 32 | }, save_path) 33 | 34 | def cal_label_weight(self, phase:str, dataset:MIMICIV_Vent_Dataset, generator:DynamicDataGenerator): 35 | dataset.mode(phase) 36 | result = None 37 | dl = DataLoader(dataset=dataset, batch_size=2000, shuffle=False, collate_fn=Collect_Fn) 38 | for batch in dl: 39 | data_dict = generator(batch['data'], batch['length']) 40 | mask = tools.make_mask((batch['data'].shape[0], batch['data'].shape[2]), batch['length']) 41 | Y_gt:np.ndarray = data_dict['label'] 42 | weight = np.sum(unroll(Y_gt, mask=mask), axis=0) 43 | result = weight if result is None else result + weight 44 | result = 1 / result 45 | result = result / result.min() 46 | logger.info(f'Label weight: {result}') 47 | return result 48 | 49 | def train(self, fold_idx:int, summary_writer:SummaryWriter, cache_dir:str): 50 | self.record = {} 51 | 
# create model 52 | self.model = VentLSTMModel( 53 | in_channels=len(self.generator.avail_idx), 54 | n_cls=len(self.params['centers']), 55 | hidden_size=self.params['hidden_size'] 56 | ).to(self.params['device']) # TODO add sample weight 57 | self.opt = torch.optim.Adam(self.model.parameters(), lr=self.params['lr']) 58 | self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, self.params['epoch'], eta_min=self.params['lr']*0.1) 59 | train_cls_weight = self.cal_label_weight('train', self.dataset, self.generator) 60 | self.criterion = torch.nn.CrossEntropyLoss(reduction='none', weight=torch.as_tensor(train_cls_weight, device=self.params['device'])) 61 | for epoch in range(self.params['epoch']): 62 | self.record[f'{fold_idx}_epoch_train_loss'], self.record[f'{fold_idx}_epoch_valid_loss'] = 0, 0 63 | for phase in ['train', 'valid']: 64 | self.dataset.mode(phase) 65 | for batch in self.train_dataloader: 66 | data_dict = self.generator(batch['data'], batch['length']) 67 | mask = torch.as_tensor( 68 | tools.make_mask((batch['data'].shape[0], batch['data'].shape[2]), batch['length']), 69 | dtype=bool, device=self.params['device'] 70 | ) 71 | X = torch.as_tensor(data_dict['data'], dtype=torch.float32, device=self.params['device']) 72 | Y_gt:torch.Tensor = torch.as_tensor(data_dict['label'], dtype=torch.float32, device=self.params['device']) 73 | Y_pred:torch.Tensor = self.model(X) 74 | loss = torch.sum(self.criterion(Y_pred.permute((0, 2, 1)), Y_gt.permute((0, 2, 1))) * mask) / torch.sum(mask) 75 | if phase == 'train': 76 | self.opt.zero_grad() 77 | loss.backward() 78 | self.opt.step() 79 | self.record[f'{fold_idx}_epoch_{phase}_loss'] += loss.detach().cpu().item() * batch['data'].shape[0] / len(self.dataset) 80 | summary_writer.add_scalar(f'{fold_idx}_epoch_{phase}_loss', self.record[f'{fold_idx}_epoch_{phase}_loss'], global_step=epoch) 81 | self.scheduler.step() 82 | if (not 'best_valid_loss' in self.record) or (self.record[f'{fold_idx}_epoch_valid_loss'] < self.record['best_valid_loss']): 83 | self.record['best_valid_loss'] = self.record[f'{fold_idx}_epoch_valid_loss'] 84 | logger.info(f'Save model at epoch {epoch}, valid loss={self.record[f"{fold_idx}_epoch_valid_loss"]}') 85 | self.save(osjoin(cache_dir, 'model_best.pth')) 86 | logger.info('Done') 87 | 88 | def predict(self): 89 | self.dataset.mode('test') 90 | result = {'Y_gt':[], 'Y_pred':[], 'mask':[]} 91 | with torch.no_grad(): 92 | for batch in self.train_dataloader: 93 | data_dict = self.generator(batch['data'], batch['length']) 94 | mask = tools.make_mask((batch['data'].shape[0], batch['data'].shape[2]), batch['length']) 95 | X = torch.as_tensor(data_dict['data'], dtype=torch.float32, device=self.params['device']) 96 | Y_gt:torch.Tensor = torch.as_tensor(data_dict['label'], dtype=torch.float32, device=self.params['device']) 97 | Y_pred:torch.Tensor = torch.softmax(self.model(X), dim=-1) 98 | result['Y_pred'].append(Y_pred.cpu().numpy()) 99 | result['Y_gt'].append(Y_gt.cpu().numpy()) 100 | result['mask'].append(mask) 101 | return { 102 | 'Y_pred': np.concatenate(result['Y_pred'], axis=0), 103 | 'Y_gt': np.concatenate(result['Y_gt'], axis=0), 104 | 'mask': np.concatenate(result['mask'], axis=0) 105 | } 106 | 107 | 108 | class VentLSTMAnalyzer: 109 | def __init__(self, params:dict, container:DataContainer) -> None: 110 | self.params = params 111 | self.paths = params['paths'] 112 | self.dataset = MIMICIV_Vent_Dataset() 113 | self.dataset.load_version(params['dataset_version']) 114 | self.model_name = 
self.params['analyzer_name'] 115 | self.target_idx = self.dataset.idx_dict['vent_status'] 116 | 117 | def run(self): 118 | # step 1: init variables 119 | out_dir = os.path.join(self.paths['out_dir'], self.model_name) 120 | tools.reinit_dir(out_dir, build=True) 121 | 122 | metric_2cls = tools.DichotomyMetric() 123 | 124 | generator = DynamicDataGenerator( 125 | window_points=self.params['window'], 126 | n_fea=len(self.dataset.total_keys), 127 | label_generator=LabelGenerator_cls( 128 | centers=self.params['centers'] 129 | ), 130 | label_func=vent_label_func, # predict most severe ventilation in each bin 131 | target_idx=self.target_idx, 132 | limit_idx=[], 133 | forbidden_idx=[self.dataset.idx_dict[id] for id in ['vent_status']] 134 | ) 135 | feature_names = [self.dataset.fea_label(idx) for idx in generator.avail_idx] 136 | print(f'Available features: {feature_names}') 137 | 138 | summary_writer = SummaryWriter() 139 | # step 2: train and predict 140 | for fold_idx, _ in enumerate(self.dataset.enumerate_kf()): 141 | trainer = VentLSTMTrainer(params=self.params, dataset=self.dataset, generator=generator) 142 | cache_dir = osjoin(out_dir, f'fold_{fold_idx}') 143 | tools.reinit_dir(cache_dir, build=True) 144 | trainer.train(fold_idx=fold_idx, summary_writer=summary_writer, cache_dir=cache_dir) 145 | out_dict = trainer.predict() 146 | out_dict['Y_pred'] = unroll(out_dict['Y_pred'], out_dict['mask']) 147 | out_dict['Y_gt'] = unroll(out_dict['Y_gt'], out_dict['mask']) 148 | 149 | metric_2cls.add_prediction(out_dict['Y_pred'][:, 1], out_dict['Y_gt'][:, 1]) 150 | 151 | metric_2cls.plot_curve(curve_type='roc', title=f'ROC for ventilation', save_path=osjoin(out_dir, f'vent_roc.png')) 152 | metric_2cls.plot_curve(curve_type='prc', title=f'PRC for ventilation', save_path=osjoin(out_dir, f'vent_prc.png')) 153 | 154 | for phase in ['train', 'valid']: 155 | summary_writer.plot(tags=[f'{fold}_epoch_{phase}_loss' for fold in range(5)], 156 | k_fold=True, log_y=True, title=f'{phase} loss for vent LSTM', out_path=osjoin(out_dir, f'{phase}_loss.png')) 157 | with open(os.path.join(out_dir, 'result.txt'), 'w') as fp: 158 | print('Overall performance:', file=fp) 159 | metric_2cls.write_result(fp) -------------------------------------------------------------------------------- /datasets/cv_dataset.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import tools 3 | from os.path import join as osjoin, exists 4 | import pickle 5 | import numpy as np 6 | import pandas as pd 7 | from tools import GLOBAL_CONF_LOADER 8 | from tools import logger 9 | from tqdm import tqdm 10 | import time 11 | from sklearn.model_selection import KFold 12 | 13 | class Converter: 14 | def __init__(self) -> None: 15 | # self.date_converter = datetime.datetime() 16 | self.col = None 17 | 18 | def switch_col(self, new_col): 19 | self.col = new_col 20 | 21 | def __call__(self, x): 22 | if self.col == '性别': 23 | assert(x in ['男', '女']) 24 | return 1 if x == '男' else 0 25 | elif self.col == '出生年月': # 只保留年份 26 | assert(x != '') 27 | return int(x.split('T ')[-1]) 28 | elif self.col == '入住ICU日期': # 只保留年份 29 | return int(x.split('-')[0]) 30 | else: 31 | try: 32 | x = float(x) 33 | return x 34 | except Exception as e: 35 | return -1 36 | 37 | 38 | class CrossValidationDataset(): 39 | __name = 'cv' 40 | 41 | @classmethod 42 | def name(cls): 43 | return cls.__name 44 | 45 | def __init__(self) -> None: 46 | self.paths = GLOBAL_CONF_LOADER['paths']['cv'] 47 | self.loc_conf = 
tools.Config(self.paths['conf_manual_path']) 48 | self.data = None # (samples, fea, ticks) 49 | self.total_keys = None 50 | self.preprocess() 51 | self.seqs_len = np.ones(self.data.shape[0], dtype=np.int64) * (7*24*2) 52 | 53 | def preprocess(self, load_pkl=False): 54 | pkl_path = osjoin(self.paths['cache_dir'], 'data.pkl') 55 | if exists(pkl_path) and load_pkl: 56 | with open(pkl_path, 'rb') as fp: 57 | result = pickle.load(fp) 58 | self.data = result['data'] 59 | self.total_keys = result['total_keys'] 60 | return 61 | # preprocess 62 | if not exists(self.paths['cache_dir']): 63 | tools.reinit_dir(self.paths['cache_dir']) 64 | data_path = self.paths['data_dir'] 65 | data = pd.read_csv(data_path) 66 | dt_head = [f'd{d}_t{t}_' for t in range(1,5) for d in range(1,8)] 67 | t_head = [f'd{d}_' for d in range(1,8)] 68 | # step1: 展开需要提取的列 69 | extract_cols = self.loc_conf['extract_cols'] 70 | col_maps = {} # 特征和其对应名字的从属关系 71 | expanded_cols = [] 72 | for col in extract_cols: 73 | col = str(col) 74 | if col.startswith('#'): 75 | col = col.replace('#', '') 76 | sublist = [] 77 | for day in range(1, 8): 78 | sublist.append(col.replace('dX', f'd{day}')) 79 | if 'tX' in col: 80 | subsublist = [] 81 | for s in sublist: 82 | for time in range(1, 5): 83 | subsublist.append(s.replace('tX', f't{time}')) 84 | expanded_cols += subsublist 85 | col_maps[col] = subsublist 86 | else: 87 | expanded_cols += sublist 88 | col_maps[col] = sublist 89 | else: 90 | expanded_cols.append(col) 91 | col_maps[col] = [col] 92 | for col in expanded_cols: 93 | assert(col in data.columns) 94 | data = data[expanded_cols] 95 | # step2: 均值填充 96 | converter = Converter() 97 | for col in data.columns: 98 | converter.switch_col(col) 99 | data[col] = data[col].apply(converter) 100 | col_data = data[col].to_numpy() 101 | fill_flag = (col_data < -0.99) * (col_data > -1.01) 102 | col_mean = None 103 | if np.any(fill_flag): 104 | col_mean = np.sum(col_data * (1-fill_flag)) / np.sum(1-fill_flag) 105 | data.loc[fill_flag, col] = col_mean 106 | assert(np.sum(np.abs(data.to_numpy(np.float32) + 1) < 0.01) == 0) # 确保所有均值都被填充了 107 | # step3: 静态特征加工 108 | deprecated_cols = [ 109 | '入住ICU日期', '出生年月', 'dX_tX_最高心率(次/min)', 'dX_tX_最低心率(次/min)', 110 | 'dX_tX_最高SPO2(%)', 'dX_tX_最低SPO2(%)', 'dX_tX_最高体温(℃)', 111 | 'dX_tX_最高呼吸频率(次/min)', 'dX_tX_最低呼吸频率(次/min)' 112 | ] 113 | data['年龄'] = data['入住ICU日期'] - data['出生年月'] 114 | col_maps['年龄'] = ['年龄'] 115 | for idx, head in enumerate(dt_head): 116 | if idx == 0: 117 | col_maps['dX_tX_心率'] = [] 118 | col_maps['dX_tX_SPO2'] = [] 119 | col_maps['dX_tX_体温'] = [] 120 | col_maps['dX_tX_呼吸频率'] = [] # 次/min 121 | 122 | data[head + '心率'] = 0.5 * data[head + '最高心率(次/min)'] + 0.5 * data[head + '最低心率(次/min)'] 123 | col_maps['dX_tX_心率'].append(head + '心率') 124 | data[head + 'SPO2'] = 0.5 * data[head + '最高SPO2(%)'] + 0.5 * data[head + '最低SPO2(%)'] 125 | col_maps['dX_tX_SPO2'].append(head + 'SPO2') 126 | data[head + '体温'] = data[head + '最高体温(℃)'] 127 | col_maps['dX_tX_体温'].append(head + '体温') 128 | data[head + '呼吸频率(次/min)'] = 0.5 * data[head + '最高呼吸频率(次/min)'] + 0.5 * data[head + '最低呼吸频率(次/min)'] 129 | col_maps['dX_tX_呼吸频率'].append(head + '呼吸频率(次/min)') 130 | 131 | for idx, head in enumerate(t_head): 132 | if idx == 0: 133 | col_maps['dX_FiO2(%)'] = [] 134 | data[head + 'FiO2(%)'] = data[head + 'PaO2(mmHg)'] / data[head + 'PaO2(mmHg) / FiO2(%)'] 135 | col_maps['dX_FiO2(%)'].append(head + 'FiO2(%)') 136 | for feas in deprecated_cols: 137 | col_maps.pop(feas) 138 | data = data.copy() 139 | total_keys = list(col_maps.keys()) 
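        # Illustration of the template expansion above (pattern sketch, not values read from the csv):
        #   '#dX_tX_最高心率(次/min)' -> ['d1_t1_最高心率(次/min)', ..., 'd7_t4_最高心率(次/min)']  (7 days x 4 slots = 28 columns)
        #   '#dX_PaO2(mmHg)'          -> ['d1_PaO2(mmHg)', ..., 'd7_PaO2(mmHg)']                  (7 daily columns)
        #   '性别'                    -> ['性别']                                                  (static, single column)
        # col_maps keeps each template -> expanded-columns mapping; total_keys holds one entry per template.
        # The sort below orders keys by name length and pushes the target 'dX_PaO2(mmHg) / FiO2(%)' to the end.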
140 | total_keys = sorted(total_keys, key=lambda x: len(x) if x != 'dX_PaO2(mmHg) / FiO2(%)' else 999) # move target to the end 141 | # step4: 沿时间轴展开 142 | table = -np.ones((len(data.index), len(col_maps), 28)) 143 | for idx, col in enumerate(total_keys): 144 | if len(col_maps[col]) > 1: 145 | if len(col_maps[col]) == 7: # 只有天数,没有间隔 146 | for sub_idx, subcol in enumerate(col_maps[col]): 147 | np_arr = data[subcol].to_numpy()[..., None] 148 | table[:, idx, 4*sub_idx:4*(sub_idx+1)] = np_arr 149 | elif len(col_maps[col]) == 28: 150 | for sub_idx, subcol in enumerate(col_maps[col]): 151 | np_arr = data[subcol].to_numpy() 152 | table[:, idx, sub_idx] = np_arr 153 | else: 154 | assert(0) 155 | else: 156 | assert(col in data.columns) 157 | np_arr = data[col].to_numpy()[..., None] # (subjects, 1) 158 | table[:, idx, :] = np_arr# 没有时间信息,直接展开 159 | # step5: 插值得到最小粒度为半小时的表 160 | assert(np.sum((table < -0.99) * (table > -1.01)) == 0) 161 | final_table = np.zeros((table.shape[0], table.shape[1], 7*24*2)) 162 | for t in range(28): 163 | final_table[:, :, 12*t:12*(t+1)] = table[:,:,[t]] 164 | # step7: 清除异常值 165 | for idx, col in enumerate(total_keys): 166 | if col == '体重(kg)': 167 | final_table[:, idx, :] = np.clip(final_table[:, idx, :], 40, 200) 168 | elif col == '身高(cm)': 169 | final_table[:, idx, :] = np.clip(final_table[:, idx, :], 150, 200) 170 | elif col == 'dX_tX_心率': 171 | final_table[:, idx, :] = np.clip(final_table[:, idx, :], 50, 200) 172 | elif col == 'dX_tX_体温': 173 | final_table[:, idx, :] = np.clip(final_table[:, idx, :], 30, 45) 174 | elif col == 'dX_tX_SPO2': # 0-100 175 | final_table[:, idx, :] = np.clip(final_table[:, idx, :], 60, 100) 176 | elif col == 'dX_FiO2(%)': # 0-1 177 | final_table[:, idx, :] = np.clip(final_table[:, idx, :], 0.21, 1) 178 | elif col == 'dX_PaO2(mmHg)': 179 | final_table[:, idx, :] = np.clip(final_table[:, idx, :], 100, 500) 180 | elif col == 'dX_PaO2(mmHg) / FiO2(%)': 181 | final_table[:, idx, :] = np.clip(final_table[:, idx, :], 100, 500) 182 | # step6: 打印结果 183 | print('total keys: ', total_keys) 184 | print('table size: ', final_table.shape) 185 | for idx, col in enumerate(total_keys): 186 | print(f'col{idx}: ', col, 'avg: ', final_table[:, idx, :].mean(), 187 | 'min: ', final_table[:, idx, :].min(), 'max: ', final_table[:, idx, :].max()) 188 | # last step: 加载并保存信息 189 | self.data = final_table 190 | self.total_keys = total_keys 191 | with open(pkl_path, 'wb') as fp: 192 | pickle.dump({ 193 | 'data': final_table, 194 | 'total_keys': total_keys 195 | }, fp) 196 | 197 | def mode(self, mode=['train', 'valid', 'test', 'all']): 198 | '''切换dataset的模式, train/valid/test需要在register_split方法调用后才能使用''' 199 | pass 200 | 201 | def __getitem__(self, idx): 202 | return {'data': self.data[idx, :, :], 'length': self.data.shape[-1]} 203 | 204 | def __len__(self): 205 | return self.data.shape[0] -------------------------------------------------------------------------------- /analyzer/ards_lstm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tools 3 | import os 4 | from tools import logger as logger 5 | from .container import DataContainer 6 | from tools.data import DynamicDataGenerator, LabelGenerator_cls, label_func_min, unroll, map_func 7 | from datasets.derived_ards_dataset import MIMICIV_ARDS_Dataset 8 | from models.ards_lstm_model import ArdsLSTMModel 9 | from torch.utils.data.dataloader import DataLoader 10 | import os 11 | from os.path import join as osjoin 12 | import torch 13 | from tools.data 
import Collect_Fn 14 | from tools.logging import SummaryWriter 15 | 16 | class ArdsLSTMTrainer(): 17 | def __init__(self, params:dict, dataset:MIMICIV_ARDS_Dataset, generator:DynamicDataGenerator) -> None: 18 | self.params = params 19 | self.dataset = dataset 20 | self.generator = generator 21 | self.train_dataloader = DataLoader(dataset=self.dataset, batch_size=params['batch_size'], shuffle=True, collate_fn=Collect_Fn) 22 | self.valid_dataloader = DataLoader(dataset=self.dataset, batch_size=1, shuffle=False, collate_fn=Collect_Fn) 23 | self.test_dataloader = DataLoader(dataset=self.dataset, batch_size=1, shuffle=False, collate_fn=Collect_Fn) 24 | 25 | def load_checkpoint(self, p_checkpoint:str): 26 | result = torch.load(p_checkpoint) 27 | self.model = result['model'].to(self.params['device']) 28 | 29 | def save(self, save_path): 30 | torch.save({ 31 | 'model': self.model 32 | }, save_path) 33 | 34 | def cal_label_weight(self, phase:str, dataset:MIMICIV_ARDS_Dataset, generator:DynamicDataGenerator): 35 | dataset.mode(phase) 36 | result = None 37 | dl = DataLoader(dataset=dataset, batch_size=2000, shuffle=False, collate_fn=Collect_Fn) 38 | for batch in dl: 39 | data_dict = generator(batch['data'], batch['length']) 40 | mask = tools.make_mask((batch['data'].shape[0], batch['data'].shape[2]), batch['length']) 41 | Y_gt:np.ndarray = data_dict['label'] 42 | weight = np.sum(unroll(Y_gt, mask=mask), axis=0) 43 | result = weight if result is None else result + weight 44 | result = 1 / result 45 | result = result / result.min() 46 | logger.info(f'Label weight: {result}') 47 | return result 48 | 49 | def train(self, fold_idx:int, summary_writer:SummaryWriter, cache_dir:str): 50 | self.record = {} 51 | # create model 52 | self.model = ArdsLSTMModel( 53 | in_channels=len(self.generator.avail_idx), 54 | n_cls=len(self.params['centers']), 55 | hidden_size=self.params['hidden_size'] 56 | ).to(self.params['device']) # TODO add sample weight 57 | self.opt = torch.optim.Adam(self.model.parameters(), lr=self.params['lr']) 58 | self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, self.params['epoch'], eta_min=self.params['lr']*0.1) 59 | train_cls_weight = self.cal_label_weight('train', self.dataset, self.generator) 60 | self.criterion = torch.nn.CrossEntropyLoss(reduction='none', weight=torch.as_tensor(train_cls_weight, device=self.params['device'])) 61 | for epoch in range(self.params['epoch']): 62 | self.record[f'{fold_idx}_epoch_train_loss'], self.record[f'{fold_idx}_epoch_valid_loss'] = 0, 0 63 | for phase in ['train', 'valid']: 64 | self.dataset.mode(phase) 65 | for batch in self.train_dataloader: 66 | data_dict = self.generator(batch['data'], batch['length']) 67 | mask = torch.as_tensor( 68 | tools.make_mask((batch['data'].shape[0], batch['data'].shape[2]), batch['length']), 69 | dtype=bool, device=self.params['device'] 70 | ) 71 | X = torch.as_tensor(data_dict['data'], dtype=torch.float32, device=self.params['device']) 72 | Y_gt:torch.Tensor = torch.as_tensor(data_dict['label'], dtype=torch.float32, device=self.params['device']) 73 | Y_pred:torch.Tensor = self.model(X) 74 | loss = torch.sum(self.criterion(Y_pred.permute((0, 2, 1)), Y_gt.permute((0, 2, 1))) * mask) / torch.sum(mask) 75 | if phase == 'train': 76 | self.opt.zero_grad() 77 | loss.backward() 78 | self.opt.step() 79 | self.record[f'{fold_idx}_epoch_{phase}_loss'] += loss.detach().cpu().item() * batch['data'].shape[0] / len(self.dataset) 80 | summary_writer.add_scalar(f'{fold_idx}_epoch_{phase}_loss', 
self.record[f'{fold_idx}_epoch_{phase}_loss'], global_step=epoch) 81 | self.scheduler.step() 82 | if (not 'best_valid_loss' in self.record) or (self.record[f'{fold_idx}_epoch_valid_loss'] < self.record['best_valid_loss']): 83 | self.record['best_valid_loss'] = self.record[f'{fold_idx}_epoch_valid_loss'] 84 | logger.info(f'Save model at epoch {epoch}, valid loss={self.record[f"{fold_idx}_epoch_valid_loss"]}') 85 | self.save(osjoin(cache_dir, 'model_best.pth')) 86 | logger.info('Done') 87 | 88 | def predict(self): 89 | self.dataset.mode('test') 90 | result = {'Y_gt':[], 'Y_pred':[], 'mask':[]} 91 | with torch.no_grad(): 92 | for batch in self.train_dataloader: 93 | data_dict = self.generator(batch['data'], batch['length']) 94 | mask = tools.make_mask((batch['data'].shape[0], batch['data'].shape[2]), batch['length']) 95 | X = torch.as_tensor(data_dict['data'], dtype=torch.float32, device=self.params['device']) 96 | Y_gt:torch.Tensor = torch.as_tensor(data_dict['label'], dtype=torch.float32, device=self.params['device']) 97 | Y_pred:torch.Tensor = torch.softmax(self.model(X), dim=-1) 98 | result['Y_pred'].append(Y_pred.cpu().numpy()) 99 | result['Y_gt'].append(Y_gt.cpu().numpy()) 100 | result['mask'].append(mask) 101 | return { 102 | 'Y_pred': np.concatenate(result['Y_pred'], axis=0), 103 | 'Y_gt': np.concatenate(result['Y_gt'], axis=0), 104 | 'mask': np.concatenate(result['mask'], axis=0) 105 | } 106 | 107 | 108 | class ArdsLSTMAnalyzer: 109 | def __init__(self, params:dict, container:DataContainer) -> None: 110 | self.params = params 111 | self.paths = params['paths'] 112 | self.dataset = MIMICIV_ARDS_Dataset() 113 | self.dataset.load_version(params['dataset_version']) 114 | self.model_name = self.params['analyzer_name'] 115 | self.target_idx = self.dataset.idx_dict['PF_ratio'] 116 | 117 | def run(self): 118 | # step 1: init variables 119 | out_dir = os.path.join(self.paths['out_dir'], self.model_name) 120 | tools.reinit_dir(out_dir, build=True) 121 | 122 | metric_2cls = [tools.DichotomyMetric() for _ in range(len(self.params['class_names']))] 123 | metric_4cls = tools.MultiClassMetric(class_names=self.params['class_names'], out_dir=out_dir) 124 | 125 | generator = DynamicDataGenerator( 126 | window_points=self.params['window'], 127 | n_fea=len(self.dataset.total_keys), 128 | label_generator=LabelGenerator_cls( 129 | centers=self.params['centers'] 130 | ), 131 | label_func=label_func_min, 132 | target_idx=self.target_idx, 133 | limit_idx=[self.dataset.fea_idx(id) for id in self.params['limit_feas']], 134 | forbidden_idx=[self.dataset.fea_idx(id) for id in self.params['forbidden_feas']] 135 | ) 136 | feature_names = [self.dataset.fea_label(idx) for idx in generator.avail_idx] 137 | print(f'Available features: {feature_names}') 138 | summary_writer = SummaryWriter() 139 | # step 2: train and predict 140 | for fold_idx, _ in enumerate(self.dataset.enumerate_kf()): 141 | trainer = ArdsLSTMTrainer(params=self.params, dataset=self.dataset, generator=generator) 142 | cache_dir = osjoin(out_dir, f'fold_{fold_idx}') 143 | tools.reinit_dir(cache_dir, build=True) 144 | trainer.train(fold_idx=fold_idx, summary_writer=summary_writer, cache_dir=cache_dir) 145 | out_dict = trainer.predict() 146 | out_dict['Y_pred'] = unroll(out_dict['Y_pred'], out_dict['mask']) 147 | out_dict['Y_gt'] = unroll(out_dict['Y_gt'], out_dict['mask']) 148 | 149 | metric_4cls.add_prediction(out_dict['Y_pred'], out_dict['Y_gt']) # 去掉mask外的数据 150 | for idx, map_dict in zip([0,1,2,3], [{0:0,1:1,2:1,3:1}, {0:0,1:1,2:0,3:0}, 
{0:0,1:0,2:1,3:0}, {0:0,1:0,2:0,3:1}]): # TODO 这里写错了 151 | metric_2cls[idx].add_prediction(map_func(out_dict['Y_pred'], map_dict)[:, 1], map_func(out_dict['Y_gt'], map_dict)[:, 1]) 152 | 153 | metric_4cls.confusion_matrix(comment=self.model_name) 154 | for idx in range(4): 155 | metric_2cls[idx].plot_curve(curve_type='roc', title=f'ROC for {self.params["class_names"][idx]}', save_path=osjoin(out_dir, f'roc_cls_{idx}.png')) 156 | 157 | for phase in ['train', 'valid']: 158 | summary_writer.plot(tags=[f'{fold}_epoch_{phase}_loss' for fold in range(5)], 159 | k_fold=True, log_y=True, title=f'{phase} loss for ards LSTM', out_path=osjoin(out_dir, f'{phase}_loss.png')) 160 | with open(os.path.join(out_dir, 'result.txt'), 'w') as fp: 161 | print('Overall performance:', file=fp) 162 | metric_4cls.write_result(fp) -------------------------------------------------------------------------------- /analyzer/dataset_explore/ards_explore.py: -------------------------------------------------------------------------------- 1 | import tools 2 | from tools.logging import logger 3 | from ..container import DataContainer 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from tqdm import tqdm 7 | import seaborn as sns 8 | import os 9 | from os.path import join as osjoin 10 | import pandas as pd 11 | import yaml 12 | from datasets.derived_ards_dataset import MIMICIV_ARDS_Dataset 13 | from scipy.signal import convolve2d 14 | from tools.data import DynamicDataGenerator, label_func_min, LabelGenerator_cls, cal_label_weight 15 | 16 | 17 | 18 | class ArdsFeatureExplorer: 19 | def __init__(self, params:dict, container:DataContainer) -> None: 20 | self.params = params 21 | self.container = container 22 | self.dataset = MIMICIV_ARDS_Dataset() 23 | self.dataset.load_version(params['dataset_version']) 24 | self.gbl_conf = container._conf 25 | self.data = self.dataset.data 26 | self.dataset.mode('all') 27 | self.target_idx = self.dataset.fea_idx('PF_ratio') 28 | 29 | def run(self): 30 | '''输出mimic-iv数据集的统计特征, 独立于模型和研究方法''' 31 | logger.info('Analyzer: Feature explore') 32 | dataset_version = self.params['dataset_version'] 33 | out_dir = osjoin(self.params['paths']['out_dir'], f'feature_explore[{dataset_version}]') 34 | tools.reinit_dir(out_dir, build=True) 35 | # random plot sample time series 36 | if self.params['coverrate']['enabled']: 37 | self.plot_cover_rate(self.params['coverrate']['class_names']) 38 | if self.params['plot_samples']['enabled']: 39 | n_sample = self.params['plot_samples']['n_sample'] 40 | id_list = [self.dataset.fea_id(x) for x in self.params['plot_samples']['features']] 41 | id_names = [self.dataset.fea_label(x) for x in self.params['plot_samples']['features']] 42 | self.plot_samples(num=n_sample, id_list=id_list, id_names=id_names, out_dir=os.path.join(out_dir, 'samples')) 43 | if self.params['plot_time_series']['enabled']: 44 | n_sample = self.params['plot_time_series']['n_sample'] 45 | n_per_plots = self.params['plot_time_series']['n_per_plots'] 46 | for name in self.params['plot_time_series']["names"]: 47 | self.plot_time_series_samples(name, n_sample=n_sample, n_per_plots=n_per_plots, write_dir=os.path.join(out_dir, f"time_series_{name}")) 48 | if self.params['correlation']['enabled']: 49 | self.correlation(out_dir, self.params['correlation']['target']) 50 | if self.params['miss_mat']: 51 | self.miss_mat(out_dir) 52 | if self.params['first_ards_time']: 53 | self.first_ards_time(out_dir) 54 | 55 | def first_ards_time(self, out_dir): 56 | '''打印首次呼衰出现的时间分布''' 57 | times = [] 58 | counts 
= [] # 产生呼衰的次数 59 | ards_count = 0 60 | adms = [adm for s in self.dataset.subjects.values() for adm in s.admissions] 61 | pao2_id, fio2_id = "220224", "223835" 62 | for adm in adms: 63 | count = 0 64 | ticks = adm[pao2_id][:, 1] 65 | fio2 = np.interp(x=ticks, xp=adm[fio2_id][:, 1], fp=adm[fio2_id][:, 0]) 66 | pao2 = adm[pao2_id][:, 0] 67 | pf = pao2 / fio2 68 | for idx in range(pf.shape[0]): 69 | if pf[idx] < self.container.ards_threshold: 70 | times.append(adm[pao2_id][idx, 1]) 71 | count += 1 72 | if count != 0: 73 | ards_count += 1 74 | counts.append(count) 75 | tools.plot_single_dist(np.asarray(times), f"First ARDS time(hour)", os.path.join(out_dir, "first_ards_time.png"), adapt=True) 76 | tools.plot_single_dist(np.asarray(counts), f"ARDS Count", os.path.join(out_dir, "ards_count.png"), adapt=True) 77 | logger.info(f"ARDS patients count={ards_count}") 78 | 79 | def correlation(self, out_dir, target_id_or_label): 80 | # plot correlation matrix 81 | target_id, target_label = self.dataset.fea_id(target_id_or_label), self.dataset.fea_label(target_id_or_label) 82 | target_index = self.dataset.idx_dict[target_id] 83 | labels = [self.dataset.fea_label(id) for id in self.dataset._total_keys] 84 | corr_mat = tools.plot_correlation_matrix(self.data[:, :, 0], labels, save_path=os.path.join(out_dir, 'correlation_matrix')) 85 | correlations = [] 86 | for idx in range(corr_mat.shape[1]): 87 | correlations.append([corr_mat[target_index, idx], labels[idx]]) # list[(correlation coeff, label)] 88 | correlations = sorted(correlations, key=lambda x:np.abs(x[0]), reverse=True) 89 | with open(os.path.join(out_dir, 'correlation.txt'), 'w') as fp: 90 | fp.write(f"Target feature: {target_label}\n") 91 | for idx in range(corr_mat.shape[1]): 92 | fp.write(f'Correlation with target: {correlations[idx][0]} \t{correlations[idx][1]}\n') 93 | 94 | def miss_mat(self, out_dir): 95 | '''计算行列缺失分布并输出''' 96 | na_table = np.ones((len(self.dataset.subjects), len(self.dataset._dynamic_keys)), dtype=bool) # True=miss 97 | for r_id, s_id in enumerate(self.dataset.subjects): 98 | for adm in self.dataset.subjects[s_id].admissions: 99 | # TODO 替换dynamic keys到total keys 100 | adm_key = set(adm.keys()) 101 | for c_id, key in enumerate(self.dataset._dynamic_keys): 102 | if key in adm_key: 103 | na_table[r_id, c_id] = False 104 | 105 | row_nas = na_table.mean(axis=1) 106 | col_nas = na_table.mean(axis=0) 107 | tools.plot_single_dist(row_nas, f"Row miss rate", os.path.join(out_dir, "row_miss_rate.png"), discrete=False, adapt=True) 108 | tools.plot_single_dist(col_nas, f"Column miss rate", os.path.join(out_dir, "col_miss_rate.png"), discrete=False, adapt=True) 109 | # save raw/col miss rate to file 110 | tools.save_pkl(row_nas, os.path.join(out_dir, "row_missrate.pkl")) 111 | tools.save_pkl(col_nas, os.path.join(out_dir, "col_missrate.pkl")) 112 | 113 | # plot matrix 114 | row_idx = sorted(list(range(row_nas.shape[0])), key=lambda x:row_nas[x]) 115 | col_idx = sorted(list(range(col_nas.shape[0])), key=lambda x:col_nas[x]) 116 | na_table = na_table[row_idx, :][:, col_idx] # (n_subjects, n_feature) 117 | # apply conv to get density 118 | conv_kernel = np.ones((5,5)) / 25 119 | na_table = np.clip(convolve2d(na_table, conv_kernel, boundary='symm'), 0, 1.0) 120 | tools.plot_density_matrix(1.0-na_table, 'Missing distribution for subjects and features [miss=white]', xlabel='features', ylabel='subjects', 121 | aspect='auto', save_path=os.path.join(out_dir, "miss_mat.png")) 122 | 123 | def plot_samples(self, num, id_list:list, id_names:list, 
out_dir): 124 | '''随机抽取num个样本生成id_list中特征的时间序列, 在非对齐的时间刻度下表示''' 125 | tools.reinit_dir(out_dir, build=True) 126 | count = 0 127 | nrow = len(id_list) 128 | assert(nrow <= 5) # 太多会导致subplot拥挤 129 | bar = tqdm(desc='plot samples', total=num) 130 | for s_id, s in self.dataset.subjects.items(): 131 | for adm in s.admissions: 132 | if count >= num: 133 | return 134 | plt.figure(figsize = (6, nrow*2)) 135 | # register xlim 136 | xmin, xmax = np.inf,-np.inf 137 | for idx, id in enumerate(id_list): 138 | if id in adm.keys(): 139 | xmin = min(xmin, np.min(adm[id][:,1])) 140 | xmax = max(xmax, np.max(adm[id][:,1])) 141 | for idx, id in enumerate(id_list): 142 | if id in adm.keys(): 143 | plt.subplot(nrow, 1, idx+1) 144 | plt.plot(adm[id][:,1], adm[id][:,0], '-o', label=id_names[idx]) 145 | plt.gca().set_xlim([xmin-1, xmax+1]) 146 | plt.legend() 147 | plt.suptitle(f'subject={s_id}') 148 | plt.savefig(os.path.join(out_dir, f'{count}.png')) 149 | plt.close() 150 | bar.update(1) 151 | count += 1 152 | 153 | def plot_time_series_samples(self, fea_name:str, n_sample:int=100, n_per_plots:int=10, write_dir=None): 154 | ''' 155 | fea_name: total_keys中的项, 例如"220224" 156 | ''' 157 | if write_dir is not None: 158 | tools.reinit_dir(write_dir) 159 | n_sample = min(n_sample, self.data.shape[0]) 160 | n_plot = int(np.ceil(n_sample / n_per_plots)) 161 | fea_idx = self.dataset._idx_dict[fea_name] 162 | start_idx = 0 163 | label = self.dataset.fea_label(fea_name) 164 | for p_idx in range(n_plot): 165 | stop_idx = min(start_idx + n_per_plots, n_sample) 166 | mat = self.data[start_idx:stop_idx, fea_idx, :] 167 | for idx in range(stop_idx-start_idx): 168 | series = mat[idx, :] 169 | plt.plot(series[series > 0], alpha=0.3) 170 | plt.title(f"Time series sample of {label}") 171 | plt.xlabel("time tick=0.5 hour") 172 | plt.xlim(left=0, right=72) 173 | plt.ylabel(label) 174 | start_idx = stop_idx 175 | if write_dir is None: 176 | plt.show() 177 | else: 178 | plt.savefig(os.path.join(write_dir, f"plot_{p_idx}.png")) 179 | plt.close() 180 | 181 | 182 | def plot_cover_rate(self, class_names): 183 | generator = DynamicDataGenerator( 184 | window_points=self.params['coverrate']['window'], 185 | n_fea=len(self.dataset.total_keys), 186 | label_generator=LabelGenerator_cls( 187 | centers=self.params['coverrate']['centers'] 188 | ), 189 | label_func=label_func_min, 190 | target_idx=self.target_idx, 191 | limit_idx=[self.target_idx], 192 | forbidden_idx=[] 193 | ) 194 | result = generator(self.dataset.data, self.dataset.seqs_len) 195 | label, mask = result['label'], result['mask'] 196 | weight = cal_label_weight(4, label[mask, :]) 197 | rw = 1 / weight 198 | rw = rw / np.sum(rw) 199 | logger.info(f'Slice label proportion: {rw}') 200 | label = np.argmax(label, axis=-1) 201 | label[np.logical_not(mask)] = 4 202 | cls_label = [np.sum(np.min(label, axis=-1) == idx) for idx in range(4)] 203 | cls_label = np.asarray(cls_label) 204 | n_sum = np.sum(cls_label) 205 | for idx, name in enumerate(class_names): 206 | logger.info(f'{name}: n={cls_label[idx]}, proportion={cls_label[idx] / n_sum:.3f}') -------------------------------------------------------------------------------- /analyzer/dataset_explore/raw_explore.py: -------------------------------------------------------------------------------- 1 | import tools 2 | from tools.logging import logger 3 | from ..container import DataContainer 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from tqdm import tqdm 7 | import seaborn as sns 8 | import os 9 | from os.path import join as 
osjoin 10 | import pandas as pd 11 | from datasets.derived_raw_dataset import MIMICIV_Raw_Dataset 12 | import yaml 13 | from scipy.signal import convolve2d 14 | 15 | 16 | class RawFeatureExplorer: 17 | def __init__(self, params:dict, container:DataContainer) -> None: 18 | self.params = params 19 | self.container = container 20 | self.dataset = MIMICIV_Raw_Dataset() 21 | self.dataset.load_version(params['dataset_version']) 22 | self.gbl_conf = container._conf 23 | self.data = self.dataset.data 24 | 25 | def run(self): 26 | '''输出mimic-iv数据集的统计特征, 独立于模型和研究方法''' 27 | logger.info('Analyzer: Feature explore') 28 | dataset_version = self.params['dataset_version'] 29 | out_dir = osjoin(self.params['paths']['out_dir'], f'feature_explore[{dataset_version}]') 30 | tools.reinit_dir(out_dir, build=True) 31 | # random plot sample time series 32 | if self.params['plot_admission_dist']: 33 | self.plot_admission_dist(out_dir=out_dir) 34 | if self.params['plot_chart_vis']['enabled']: 35 | self.plot_chart_vis(out_dir=osjoin(out_dir, 'chart_vis')) 36 | if self.params['plot_samples']['enabled']: 37 | n_sample = self.params['plot_samples']['n_sample'] 38 | id_list = [self.dataset.fea_id(x) for x in self.params['plot_samples']['features']] 39 | id_names = [self.dataset.fea_label(x) for x in self.params['plot_samples']['features']] 40 | self.plot_samples(num=n_sample, id_list=id_list, id_names=id_names, out_dir=os.path.join(out_dir, 'samples')) 41 | if self.params['correlation']: 42 | self.correlation(out_dir) 43 | if self.params['abnormal_dist']['enabled']: 44 | self.abnormal_dist(out_dir) 45 | if self.params['miss_mat']: 46 | self.miss_mat(out_dir) 47 | if self.params['feature_count']: 48 | self.feature_count(out_dir) 49 | 50 | def plot_admission_dist(self, out_dir): 51 | out_path = osjoin(out_dir, 'admission_dist.png') 52 | admission_path = osjoin(self.params['paths']['mimic-iv-ards']['mimic_dir'], 'hosp', 'admissions.csv') 53 | admissions = pd.read_csv(admission_path, encoding='utf-8') 54 | subject_dict = {} 55 | for row in tqdm(admissions.itertuples(), 'Collect Admission info', total=len(admissions)): 56 | subject_dict[row.subject_id] = 1 if row.subject_id not in subject_dict else subject_dict[row.subject_id] + 1 57 | n_adm = np.asarray(list(subject_dict.values())) 58 | logger.info(f'Admission: mean={n_adm.mean():.3f}') 59 | 60 | logger.info(f'Retain {100*len(n_adm)/np.sum(n_adm):.3f}% admissions if we only choose the first admission.') 61 | tools.plot_single_dist(n_adm, 'Number of Admission', save_path=out_path, discrete=True, adapt=True, label=True, shrink=0.9, edgecolor=None) 62 | 63 | def plot_chart_vis(self, out_dir): 64 | tools.reinit_dir(out_dir) 65 | plot_keys = self.params['plot_chart_vis']['collect_list'] 66 | record = {} 67 | for plot_key in plot_keys: 68 | if plot_key == 'transfer': 69 | key_record = {} 70 | transfer_path = osjoin(self.params['paths']['mimic-iv-ards']['mimic_dir'], 'hosp', 'transfers.csv') 71 | table = pd.read_csv(transfer_path, engine='c', encoding='utf-8') 72 | for row in tqdm(table.itertuples(), 'plot category distribution: transfers'): 73 | r = 'empty' if not isinstance(row.careunit, str) else row.careunit 74 | key_record[r] = 1 if r not in key_record else key_record[r] + 1 75 | record[plot_key] = key_record 76 | elif plot_key == 'admission': 77 | names = ['insurance', 'language', 'marital_status', 'race'] 78 | for name in names: 79 | record[name] = {} 80 | admission_path = osjoin(self.params['paths']['mimic-iv-ards']['mimic_dir'], 'hosp', 'admissions.csv') 81 | table = 
pd.read_csv(admission_path, engine='c', encoding='utf-8') 82 | for row in tqdm(table.itertuples(), 'plot chcategory distribution: admissions'): 83 | for name in names: 84 | r = 'empty' if not isinstance(getattr(row, name), str) else getattr(row, name) 85 | record[name][r] = 1 if r not in record[name] else record[name][r] + 1 86 | 87 | # sort and plot 88 | plot_in = {} 89 | for key, key_record in record.items(): 90 | x = sorted(key_record.values(), reverse=True) 91 | y = sorted(list(key_record.keys()), key=lambda x:key_record[x], reverse=True) 92 | total = sum(key_record.values()) 93 | x = [xi/total for xi in x] 94 | plot_in[key] = [x, y] 95 | tools.plot_bar_with_label(data=np.asarray(x), labels=y, title=f'Category distribution for {key}', sort=False, out_path=osjoin(out_dir, f'category_{key}.png')) 96 | with open(osjoin(out_dir, 'categories.yml'), 'w', encoding='utf-8') as fp: 97 | out_dict = {} 98 | for key in plot_in: 99 | out_dict[key] = {name:idx+1 for idx, name in enumerate(plot_in[key][1]) if sum(plot_in[key][0][idx:]) > 0.02} 100 | out_dict[key].update({'Default':0}) 101 | yaml.dump(out_dict, fp) 102 | 103 | tools.plot_stack_proportion(plot_in, out_path=os.path.join(out_dir, f"stack_percentage.png")) 104 | 105 | 106 | def abnormal_dist(self, out_dir): 107 | limit:dict = self.params['abnormal_dist']['value_limitation'] 108 | limit_names = list(limit.keys()) 109 | abnormal_table = np.zeros((len(limit_names), 2)) # True=miss 110 | idx_dict = {name:idx for idx, name in enumerate(limit_names)} 111 | for s_id in self.dataset.subjects: 112 | adm = self.dataset.subjects[s_id].admissions[0] 113 | for key in adm.dynamic_data: 114 | name = self.dataset.fea_label(key) 115 | if name in limit.keys(): 116 | is_abnormal = np.any( 117 | np.logical_or(adm.dynamic_data[key][:, 0] < limit[name]['min'], adm.dynamic_data[key][:, 0] > limit[name]['max']) 118 | ) 119 | abnormal_table[idx_dict[name], 0] += (1 if is_abnormal else 0) 120 | abnormal_table[idx_dict[name], 1] += 1 121 | abnormal_table = abnormal_table[:, 0] / np.maximum(abnormal_table[:, 1], 1e-3) # avoid nan 122 | # sort table 123 | limit_names = sorted(limit_names, key=lambda x:abnormal_table[limit_names.index(x)], reverse=True) 124 | abnormal_table = np.sort(abnormal_table)[::-1] 125 | # bar plot 126 | plt.figure(figsize=(10,10)) 127 | ax = sns.barplot(x=np.asarray(limit_names), y=abnormal_table) 128 | ax.set_xticklabels(limit_names, rotation=45, ha='right') 129 | ax.bar_label(ax.containers[0], fontsize=10, fmt=lambda x:f'{x:.4f}') 130 | plt.subplots_adjust(bottom=0.3) 131 | plt.savefig(osjoin(out_dir, 'abnormal.png')) 132 | plt.close() 133 | 134 | def correlation(self, out_dir): 135 | # plot correlation matrix 136 | labels = [self.dataset.fea_label(id) for id in self.dataset._total_keys] 137 | logger.info('plot correlation') 138 | tools.plot_correlation_matrix(self.data[:, :, 0], labels, save_path=os.path.join(out_dir, 'correlation_matrix'), corr_thres=0.8) 139 | 140 | def miss_mat(self, out_dir): 141 | '''计算行列缺失分布并输出''' 142 | na_table = np.ones((len(self.dataset.subjects), len(self.dataset._dynamic_keys)), dtype=bool) # True=miss 143 | for r_id, s_id in enumerate(self.dataset.subjects): 144 | for adm in self.dataset.subjects[s_id].admissions: 145 | # TODO 替换dynamic keys到total keys 146 | adm_key = set(adm.keys()) 147 | for c_id, key in enumerate(self.dataset._dynamic_keys): 148 | if key in adm_key: 149 | na_table[r_id, c_id] = False 150 | 151 | row_nas = na_table.mean(axis=1) 152 | col_nas = na_table.mean(axis=0) 153 | 
tools.plot_single_dist(row_nas, f"Row miss rate", os.path.join(out_dir, "row_miss_rate.png"), discrete=False, adapt=True) 154 | tools.plot_single_dist(col_nas, f"Column miss rate", os.path.join(out_dir, "col_miss_rate.png"), discrete=False, adapt=True) 155 | # save raw/col miss rate to file 156 | tools.save_pkl(row_nas, os.path.join(out_dir, "row_missrate.pkl")) 157 | tools.save_pkl(col_nas, os.path.join(out_dir, "col_missrate.pkl")) 158 | 159 | # plot matrix 160 | row_idx = sorted(list(range(row_nas.shape[0])), key=lambda x:row_nas[x]) 161 | col_idx = sorted(list(range(col_nas.shape[0])), key=lambda x:col_nas[x]) 162 | na_table = na_table[row_idx, :][:, col_idx] # (n_subjects, n_feature) 163 | # apply conv to get density 164 | conv_kernel = np.ones((5,5)) / 25 165 | na_table = np.clip(convolve2d(na_table, conv_kernel, boundary='symm'), 0, 1.0) 166 | tools.plot_density_matrix(1.0-na_table, 'Missing distribution for subjects and features [miss=white]', xlabel='features', ylabel='subjects', 167 | aspect='auto', save_path=os.path.join(out_dir, "miss_mat.png")) 168 | 169 | def feature_count(self, out_dir): 170 | '''打印vital_sig中特征出现的次数和最短间隔排序''' 171 | adms = [adm for s in self.dataset.subjects.values() for adm in s.admissions] 172 | count_hist = {} 173 | for adm in adms: 174 | for key in adm.keys(): 175 | if key not in count_hist.keys(): 176 | count_hist[key] = {'num':0, 'count':0, 'interval':0} 177 | count_hist[key]['num'] += 1 178 | count_hist[key]['count'] += adm[key].shape[0] 179 | count_hist[key]['interval'] += ((adm[key][-1, 1] - adm[key][0, 1]) / adm[key].shape[0]) 180 | for key in count_hist.keys(): 181 | count_hist[key]['count'] /= count_hist[key]['num'] 182 | count_hist[key]['interval'] /= count_hist[key]['num'] 183 | key_list = list(count_hist.keys()) 184 | key_list = sorted(key_list, key=lambda x:count_hist[x]['count']) 185 | key_list = key_list[-40:] # 最多80, 否则vital_sig可能不准 186 | with open(os.path.join(out_dir, 'interval.txt'), 'w') as fp: 187 | for key in key_list: 188 | interval = count_hist[key]['interval'] 189 | fp.write(f'\"{key}\", {self.dataset.fea_label(key)} mean interval={interval:.1f}\n') 190 | vital_sig = {"220045", "220210", "220277", "220181", "220179", "220180", "223761", "223762", "224685", "224684", "224686", "228640", "224417"} 191 | med_ind = {key for key in key_list} - vital_sig 192 | for name in ['vital_sig', 'med_ind']: 193 | subset = vital_sig if name == 'vital_sig' else med_ind 194 | new_list = [] 195 | for key in key_list: 196 | if key in subset: 197 | new_list.append(key) 198 | counts = np.asarray([count_hist[key]['count'] for key in new_list]) 199 | intervals = np.asarray([count_hist[key]['interval'] for key in new_list]) 200 | labels = [self.dataset.fea_label(key) for key in new_list] 201 | tools.plot_bar_with_label(counts, labels, f'{name} Count', out_path=os.path.join(out_dir, f"{name}_feature_count.png")) 202 | tools.plot_bar_with_label(intervals, labels, f'{name} Interval', out_path=os.path.join(out_dir, f"{name}_feature_interval.png")) 203 | 204 | def plot_samples(self, num, id_list:list, id_names:list, out_dir): 205 | '''随机抽取num个样本生成id_list中特征的时间序列, 在非对齐的时间刻度下表示''' 206 | tools.reinit_dir(out_dir, build=True) 207 | count = 0 208 | nrow = len(id_list) 209 | assert(nrow <= 5) # 太多会导致subplot拥挤 210 | bar = tqdm(desc='plot samples', total=num) 211 | for s_id, s in self.dataset.subjects.items(): 212 | for adm in s.admissions: 213 | if count >= num: 214 | return 215 | plt.figure(figsize = (6, nrow*3)) 216 | # register xlim 217 | xmin, xmax = np.inf,-np.inf 
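                # Two-pass layout: the first loop below only scans the selected features present in this admission
                # to obtain a shared x-axis range (xmin, xmax); the second loop then draws each feature and applies
                # that range, so all subplots of one subject share the same time window.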
218 | for idx, id in enumerate(id_list): 219 | if id in adm.keys(): 220 | xmin = min(xmin, np.min(adm[id][:,1])) 221 | xmax = max(xmax, np.max(adm[id][:,1])) 222 | for idx, id in enumerate(id_list): 223 | if id in adm.keys(): 224 | plt.subplot(nrow, 1, idx+1) 225 | plt.plot(adm[id][:,1], adm[id][:,0], '-o', label=id_names[idx]) 226 | plt.gca().set_xlim([xmin, xmax]) 227 | plt.legend() 228 | plt.suptitle(f'subject={s_id}') 229 | plt.savefig(os.path.join(out_dir, f'{count}.png')) 230 | plt.close() 231 | bar.update(1) 232 | count += 1 -------------------------------------------------------------------------------- /datasets/derived_raw_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tools 3 | import numpy as np 4 | from tools import logger 5 | from tqdm import tqdm 6 | from .helper import Subject, Admission, load_all_subjects 7 | from collections import namedtuple, Counter 8 | from random import choice 9 | from os.path import join as osjoin 10 | from .mimiciv_core import MIMICIV_Core 11 | 12 | class MIMICIV_Raw_Dataset(MIMICIV_Core): 13 | _name = 'mimic-iv-raw' 14 | 15 | def __init__(self): 16 | super().__init__(self.name()) 17 | 18 | def on_extract_subjects(self) -> dict: 19 | # extract all subjects 20 | patient_set = load_all_subjects(osjoin(self._mimic_dir, 'hosp', 'patients.csv')) 21 | return patient_set, None 22 | 23 | def on_build_subject(self, subject_id:int, subject:Subject, row:namedtuple, patient_set:set, extra_data:object) -> Subject: 24 | ''' 25 | subject: Subject() 26 | row: dict, {column_name:value} 27 | extract_value: value of _extract_reuslt[id] 28 | ''' 29 | ymd_convertor = tools.TimeConverter(format="%Y-%m-%d", out_unit='hour') 30 | subject.append_static(0, 'age', -1) 31 | subject.append_static(0, 'gender', row.gender) 32 | if row.dod is not None and isinstance(row.dod, str): 33 | subject.append_static(ymd_convertor(row.dod), 'dod', ymd_convertor(row.dod)) 34 | return subject 35 | 36 | def on_extract_admission(self, subject:Subject, source:str, row:namedtuple) -> bool: 37 | ymdhms_converter = tools.TimeConverter(format="%Y-%m-%d %H:%M:%S", out_unit='hour') 38 | if source == 'admission': 39 | admittime = ymdhms_converter(row.admittime) 40 | dischtime = ymdhms_converter(row.dischtime) 41 | if dischtime <= admittime: 42 | return False 43 | adm = Admission( 44 | unique_id=int(row.hadm_id*1e8), 45 | admittime=admittime, 46 | dischtime=dischtime 47 | ) 48 | discretizer = self._loc_conf['category_to_numeric'] 49 | subject.append_admission(adm) 50 | for name, val in zip( 51 | ['insurance', 'language', 'race', 'marital_status'], 52 | [row.insurance, row.language, row.race, row.marital_status] 53 | ): 54 | subject.append_static(admittime, name, discretizer[name][val] if val in discretizer[name] else discretizer[name]['Default']) 55 | 56 | subject.append_static(admittime, 'hosp_expire', row.hospital_expire_flag) 57 | return True 58 | elif source == 'icu': 59 | return False 60 | elif source == 'transfer': 61 | if not np.isnan(row.hadm_id): 62 | adm = subject.find_admission(int(row.hadm_id*1e8)) 63 | if adm is not None: 64 | discretizer = self._loc_conf['category_to_numeric'] 65 | careunit = discretizer['careunit'][row.careunit] if row.careunit in discretizer['careunit'] else discretizer['careunit']['Default'] 66 | adm.append_dynamic('careunit', ymdhms_converter(row.intime), careunit) 67 | return False 68 | else: 69 | assert(0) 70 | return False 71 | 72 | 73 | def on_select_feature(self, subject_id:int, row:dict, 
source:str=['icu', 'hosp', 'ed']): 74 | if source == 'icu': 75 | if row['type'] in ['Numeric', 'Numeric with tag'] and row['category'] != 'Alarms': 76 | return True # select 77 | else: 78 | return False # not select 79 | elif source == 'ed': 80 | return True 81 | elif source == 'hosp': 82 | return True 83 | 84 | def on_convert_numeric(self, s:Subject) -> Subject: 85 | ''' 86 | 1. 对特定格式的特征进行转换(血压) 87 | 2. 检测不能转换为float的静态特征 88 | ''' 89 | # step1: convert static data 90 | invalid_record = {} 91 | def add_invalid(key, value): 92 | if key not in invalid_record: 93 | invalid_record[key] = {'count':1, 'examples':set()} 94 | invalid_record[key]['examples'].add(value) 95 | else: 96 | invalid_record[key]['count'] += 1 97 | if len(invalid_record[key]['examples']) < 5: 98 | invalid_record[key]['examples'].add(value) 99 | 100 | static_data = s.static_data 101 | pop_keys = [] 102 | for key in list(static_data.keys()): 103 | if key == 'gender': 104 | # female: 0, male: 1 105 | static_data[key] = 0 if static_data[key][0][0] == 'F' else 1 106 | elif 'Blood Pressure' in key: 107 | static_data['systolic pressure'] = [] 108 | static_data['diastolic pressure'] = [] 109 | for idx in range(len(static_data[key])): 110 | p_result = static_data[key][idx][0].split('/') 111 | time = static_data[key][idx][1] 112 | vs, vd = float(p_result[0]), float(p_result[1]) 113 | static_data['systolic pressure'].append((vs, time)) 114 | static_data['diastolic pressure'].append((vd, time)) 115 | static_data.pop(key) 116 | static_data['systolic pressure'] = np.asarray(static_data['systolic pressure']) 117 | static_data['diastolic pressure'] = np.asarray(static_data['diastolic pressure']) 118 | elif isinstance(static_data[key], list): 119 | valid_idx = [] 120 | for idx in range(len(static_data[key])): 121 | value, t = static_data[key][idx] 122 | try: 123 | v = float(value) 124 | assert(not np.isnan(v)) 125 | valid_idx.append(idx) 126 | except Exception as e: 127 | add_invalid(key, value) 128 | if len(valid_idx) == 0: # no valid 129 | pop_keys.append(key) 130 | else: 131 | static_data[key] = np.asarray(static_data[key])[valid_idx, :].astype(np.float64) 132 | for key in pop_keys: 133 | static_data.pop(key) 134 | 135 | s.static_data = static_data 136 | # step2: convert dynamic data in admissions 137 | # NOTE: just throw away invalid row 138 | for adm in s.admissions: 139 | pop_keys = [] 140 | for key in adm.dynamic_data: 141 | valid_idx = [] 142 | for idx, row in enumerate(adm.dynamic_data[key]): 143 | value = row[0] 144 | try: 145 | v = float(value) 146 | assert(not np.isnan(v)) 147 | valid_idx.append(idx) 148 | except Exception as e: 149 | add_invalid(key, value) 150 | if len(valid_idx) == 0: 151 | pop_keys.append(key) 152 | elif len(valid_idx) < len(adm.dynamic_data[key]): 153 | adm.dynamic_data[key] = np.asarray(adm.dynamic_data[key])[valid_idx, :].astype(np.float64) # TODO 这部分代码需要复制给另外两个dataset 154 | else: 155 | adm.dynamic_data[key] = np.asarray(adm.dynamic_data[key]).astype(np.float64) 156 | for key in pop_keys: 157 | adm.dynamic_data.pop(key) 158 | 159 | return invalid_record 160 | 161 | def on_select_admissions(self, rule:dict, subjects:dict[int, Subject]): 162 | invalid_record = {'age': 0, 'duration_positive':0, 'duration_limit':0, 'empty':0} 163 | for s_id in subjects: 164 | if not subjects[s_id].empty(): 165 | retain_adms = [] 166 | for idx, adm in enumerate(subjects[s_id].admissions): 167 | age = int(adm.admittime / (24*365) + 1970 - subjects[s_id].birth_year) 168 | if age <= 18: 169 | invalid_record['age'] += 1 170 | 
continue 171 | dur = adm.dischtime - adm.admittime 172 | if dur <= 0: 173 | invalid_record['duration_positive'] += 1 174 | continue 175 | if 'duration_minmax' in rule: 176 | dur_min, dur_max = rule['duration_minmax'] 177 | if not (dur > dur_min and dur < dur_max): 178 | invalid_record['duration_limit'] += 1 179 | continue 180 | retain_adms.append(idx) 181 | subjects[s_id].admissions = [subjects[s_id].admissions[idx] for idx in retain_adms] 182 | else: 183 | invalid_record['empty'] += 1 184 | pop_list = [] 185 | for s_id in subjects: 186 | subjects[s_id].del_empty_admission() 187 | if subjects[s_id].empty(): 188 | pop_list.append(s_id) 189 | for s_id in pop_list: 190 | subjects.pop(s_id) 191 | 192 | logger.info(f'invalid admissions with age <= 18: {invalid_record["age"]}') 193 | logger.info(f'invalid admissions without positive duration: {invalid_record["duration_positive"]}') 194 | logger.info(f'invalid admissions exceeding the duration limit: {invalid_record["duration_limit"]}') 195 | logger.info(f'invalid subjects with no admission (empty): {invalid_record["empty"]}') 196 | logger.info(f'remove_pass1: Deleted {len(pop_list)}/{len(pop_list)+len(subjects)} subjects') 197 | 198 | return subjects 199 | 200 | def on_remove_missing_data(self, rule:dict, subjects: dict[int, Subject]) -> dict[int, Subject]: 201 | n_iter = 0 202 | while n_iter < len(rule['max_col_missrate']) or (len(post_dynamic_keys)+len(post_static_keys) != len(col_missrate)) or (len(pop_subject_ids) > 0): 203 | # step1: create column missrate dict 204 | prior_static_keys = [k for s in subjects.values() for k in s.static_data.keys()] 205 | prior_dynamic_keys = [k for s in subjects.values() if not s.empty() for k in choice(s.admissions).keys()] # random sampling strategy 206 | col_missrate = {k:1-v/len(subjects) for key_list in [prior_static_keys, prior_dynamic_keys] for k,v in Counter(key_list).items()} 207 | # step2: remove invalid columns and admissions 208 | for s_id, s in subjects.items(): 209 | for adm in s.admissions: 210 | pop_keys = [k for k in adm.keys() if k not in col_missrate or col_missrate[k] > rule['max_col_missrate'][min(n_iter, len(rule['max_col_missrate'])-1)]] 211 | for key in pop_keys: 212 | adm.pop_dynamic(key) 213 | s.del_empty_admission() 214 | if len(s.admissions) > 1: 215 | retain_adm = np.random.randint(0, len(s.admissions)) if rule['adm_select_strategy'] == 'random' else 0 216 | s.admissions = [s.admissions[retain_adm]] 217 | # step3: create subject missrate dict 218 | post_static_keys = set([k for s in subjects.values() for k in s.static_data.keys()]) 219 | post_dynamic_keys = set([k for s in subjects.values() if not s.empty() for k in s.admissions[0].keys()]) 220 | subject_missrate = { 221 | s_id:1 - (len(s.static_data) + len(s.admissions[0]))/(len(post_static_keys)+len(post_dynamic_keys)) \ 222 | if not s.empty() else 1 for s_id, s in subjects.items() 223 | } 224 | # step4: remove invalid subjects 225 | pop_subject_ids = set([s_id for s_id in subjects if subject_missrate[s_id] > rule['max_subject_missrate'][min(n_iter, len(rule['max_subject_missrate'])-1)] or subjects[s_id].empty()]) 226 | for s_id in pop_subject_ids: 227 | subjects.pop(s_id) 228 | # step5: calculate removed subjects/columns 229 | logger.info(f'remove_pass2: iter[{n_iter}] Retain {len(self._subjects)}/{len(pop_subject_ids)+len(self._subjects)} subjects') 230 | logger.info(f'remove_pass2: iter[{n_iter}] Retain {len(post_dynamic_keys)+len(post_static_keys)}/{len(col_missrate)} keys in selected admission') 231 | n_iter += 1 232 | return
subjects 233 | 234 | def on_build_table(self, subject:Subject, key, value, t_start): 235 | admittime = subject.admissions[0].admittime 236 | if key == 'dod': 237 | if abs(value+1.0) < 1e-3: 238 | return -1 239 | else: 240 | delta_year = np.floor(value / (24*365) - ((t_start+admittime) / (24*365))) # number of years until death; both values are timestamps, so 1970 does not need to be added 241 | assert(-100 < delta_year < 100) 242 | return delta_year 243 | elif key == 'age': 244 | age = (admittime + t_start) // (24*365) + 1970 - subject.birth_year # admittime is counted from 1970 245 | return age 246 | else: 247 | return value 248 | 249 | def on_feature_engineering(self, tables:list[np.ndarray], norm_dict:dict, static_keys:list, dynamic_keys): 250 | return { 251 | 'tables': tables, 252 | 'norm_dict': norm_dict, 253 | 'static_keys': static_keys, 254 | 'dynamic_keys': dynamic_keys 255 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Pipeline Overview](documents/general_pipeline.png) 2 | 3 | # MIMIC-IV Data Processing Pipeline 4 | 5 | **[中文版本](README_CN.md) | [English Version](README.md)** 6 | 7 | The MIMIC-IV dataset is widely used in medical research; however, the original dataset has not undergone data cleaning. This framework provides a highly configurable pipeline for MIMIC-IV, aiming at minimal encapsulation, high flexibility, and easy extensibility. The data processing code is loosely coupled and can easily be merged with the processing code for other datasets. The framework provides a default processing flow and, at the same time, offers user-defined configuration at both the configuration-file and calling-interface levels, which can satisfy complex customisation requirements. 8 | 9 | ## Architecture 10 | 11 | The framework consists of three main parts: dataset, model, and analyzer. The dataset abstracts the original data into a torch Dataset interface; the model computes outputs on batched inputs; and the analyzer is similar to a trainer, providing K-fold cross validation, metrics computation, plotting, and other tasks. Splitting model and analyzer makes it more convenient for a single analyzer to call multiple models for ensemble learning, and for a single model to be called by multiple analyzers. The tools section contains common tools and methods, and the configs section holds the fields that need to be configured, such as paths and parameters for data cleaning. 12 | 13 | analyzer: the analyzer module 14 | 1. analyzer: runs the analyzers in sequence; a new analyzer needs to be registered here when it is added. 15 | 2. container: stores parameters that are not related to the model 16 | 3. feature_explore: generates an exploratory report of the dataset, with configurable generation parameters 17 | 4. utils: utility methods 18 | 19 | configs: configuration files for each dataset. 20 | 1. global_config: path configuration 21 | 2. mimiciv_dataset: corresponds to `dataset/mimic_dataset.py`; by default it is configured to extract oxygenation indices and corresponding features from Sepsis patients, and it is recommended to build downstream-task modifications on this dataset. 22 | 3.
mimiciv_dataset_raw: corresponds to `dataset/mimic_raw_dataset.py`; provides a minimally processed dataset for data exploration 23 | 24 | other modules: 25 | - data: dataset files 26 | - datasets: dataset abstractions, including data extraction/cleaning/reorganisation 27 | - libs: third-party libraries and related code 28 | - models: models 29 | - outputs: output folder 30 | - tools: tool classes 31 | - main.py: main entry point; interfaces with the launcher via parameters 32 | 33 | ## Deployment method 34 | 35 | Follow the steps below to deploy: 36 | 1. Configure a conda environment with `python=3.10.11` and install the required packages: `pip install -r requirements.txt`. 37 | 2. When doing the first step, refer to the next subsection for installing the CUDA version of PyTorch. 38 | 3. Extract the MIMIC-IV dataset to the `data/mimic-iv` folder, with sub-folders `hosp`, `icu`, etc. 39 | 4. (Optional) If there is a MIMIC-IV-ED file, extract it to the `ed` subfolder. 40 | 5. Store the generated `sepsis3.csv` under `data/mimic-iv/sepsis_result`. 41 | 6. Run `python -u main.py` to generate everything in one pass. 42 | 43 | Install the corresponding CUDA version of PyTorch: 44 | 1. Create and enter a new conda virtual environment. 45 | 2. Type `nvidia-smi` to see which CUDA version is installed on the server. 46 | 3. Follow the installation commands at https://pytorch.org/get-started/previous-versions/ and select the corresponding version under Linux; the CUDA version of PyTorch may be lower than the CUDA version of the server. 47 | 4. Check whether the installation was successful: https://blog.csdn.net/qq_45032341/article/details/105196680 48 | 5. If you installed a different version of PyTorch than the one in `requirements.txt`, delete the corresponding line to avoid duplicate installations. 49 | 6. The framework itself has no strict restrictions on the versions of third-party libraries. 50 | 51 | ## MIMIC-IV datasets 52 | 53 | We divide the processing of the dataset into two parts: `dataset specific processing`, which deals with issues arising from the internal structure of the dataset, and `model & task specific processing`, which performs different processing for different downstream tasks. The code in `datasets/mimiciv/MIMICIV` is used for `dataset specific processing`, and in general the user does not need to modify its contents. 54 | 55 | `model & task specific processing` is done by derived classes and `data_generator`. `datasets/mimic_dataset/MIMICIVDataset` is a derived class of `MIMICIV` that enables flexible processing for different downstream tasks by overriding the abstract methods starting with `on_`; the user needs to modify this code for different research needs. Each derived class of `MIMICIV` has a separate configuration file in `config`, which can be modified to adjust the behaviour of `MIMICIV` without modifying the code. 56 | 57 | ### Introduction to the MIMICIV processing flow 58 | 59 | This part of the processing corresponds to `dataset specific processing`; in most cases users do not need to modify its contents, only to understand the general processing flow and the interface definitions. Following the principle of minimal encapsulation, only `@property` is used to distinguish between internally and externally visible properties, and `@abstractmethod` is used to indicate which methods need to be overridden by derived classes.
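As a point of reference, the snippet below is a self-contained illustration of this pattern; it is a simplified stand-in, not the actual `MIMICIV_Core` implementation, and `PipelineBase`/`RawPipeline` are hypothetical names. The base class drives the pipeline and declares `@abstractmethod` hooks, while a derived class only overrides the `on_` methods.

```python
# Simplified illustration of the base-class / derived-class pattern described above.
from abc import ABC, abstractmethod

class PipelineBase(ABC):
    def run(self) -> list:
        # the base class owns the processing flow and calls the hooks in order
        subjects = self.on_extract_subjects()
        return [self.on_build_subject(s) for s in subjects]

    @abstractmethod
    def on_extract_subjects(self) -> list: ...

    @abstractmethod
    def on_build_subject(self, subject_id: int) -> dict: ...

class RawPipeline(PipelineBase):
    def on_extract_subjects(self) -> list:
        return [10001, 10002]          # e.g. subject_ids read from hosp/patients.csv

    def on_build_subject(self, subject_id: int) -> dict:
        return {'id': subject_id}      # attach static data here

print(RawPipeline().run())             # [{'id': 10001}, {'id': 10002}]
```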
60 | 61 | **Data access**: `MIMICIV`'s data processing is divided into 7 stages (phases), and the processed data from each stage is cached under `data/mimic-iv-XXX/cache` (e.g. `data/mimic-iv-raw/cache` in the case of `mimic_raw`). Once all the data has been processed, only the cache corresponding to the final result is loaded each time the dataset is instantiated, removing unnecessary IO operations; at this point `Bare Mode Enabled` is displayed. In addition, to save space, the cache is compressed with the lzma algorithm; compression takes longer than storing a plain pickle, but reading is almost unaffected. If you need to save the compression time, set `configs/mimic_dataset/compress_cache` to `False`. 62 | 63 | **preprocess_phase1**: loads the `hosp` and `icu` item mapping tables and creates the corresponding mapping tables for `ed`, without actually reading the ED data. It then calls `on_extract_subjects` to obtain the list of subjects to be extracted and any external data. 64 | 65 | **preprocess_phase2**: reads the basic information of each patient according to the subject_id list; this phase calls `on_build_subject` and `on_extract_admission` in turn. 66 | 67 | **preprocess_phase3**: traverses all available data from MIMIC-IV and extracts the required data. There are three data sources (`icu`, `ed`, `hosp`); `configs/mimic_dataset/data_linkage` configures which data sources do not need to be read. Features that are not needed can simply be excluded before data extraction via `on_select_feature`. 68 | 69 | **preprocess_phase4**: converts the data to numeric values, filters the available samples, and finally clips outliers into the normal range according to `configs/mimic_dataset/value_clip`. In the first two steps, `on_convert_numeric` and `on_select_admissions` are called in turn. 70 | 71 | **preprocess_phase5**: further removes features and samples with high missing rates and generates per-feature statistics; these are computed before the timeline alignment and are therefore unaffected by interpolation and missing-value padding. This step calls `on_remove_missing_data`. 72 | 73 | **preprocess_phase6**: interpolates and aligns the unevenly sampled data, generating a 3D array `(n_subject, n_feature, n_sequence)` at regular time intervals, with sequences of different lengths padded at the end with `-1`; finally, feature engineering is performed. This stage calls `on_build_table` and `on_feature_engineering` in turn. 74 | 75 | Interpolation considers three cases (a minimal sketch is given after the phase descriptions below): 76 | 1. the head and tail of an admission are determined by `configs/mimic_dataset/align_target`, looking for the earliest and latest points in time at which all the target features are present. The interpolation interval is determined by `configs/mimic_dataset/delta_t_hour`. 77 | 2. when historical data exists at a given point in time, the most recent available historical value is used (causality is not considered further) 78 | 3. when the interpolation start time is earlier than the known historical data for that feature, the first data point is used to fill the gaps. 79 | 80 | **preprocess_phase7**: generates different derived versions of the same dataset according to `configs/mimic_dataset/version`. The user can customise the restriction and exclusion ranges of features in the configuration and specify the handling of missing values; currently `avg` (mean fill) and `none` (default `-1` padding) are supported. Different versions can be used for data exploration and model training respectively.
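To make the interpolation rules above concrete, here is a minimal, self-contained sketch of aligning one feature onto a regular grid (fixed `delta_t`, most recent historical value, head filled with the first observation). It illustrates the rules only and is not the actual phase-6 implementation; the function name and toy values are hypothetical.

```python
import numpy as np

def align_feature(ts: np.ndarray, vs: np.ndarray, t_start: float, t_end: float, delta_t: float) -> np.ndarray:
    """Resample one feature's irregular observations (ts, vs) onto a regular grid.
    t_start/t_end: earliest/latest time at which all align_target features are present.
    delta_t: interpolation interval (configs/mimic_dataset/delta_t_hour)."""
    grid = np.arange(t_start, t_end + 1e-6, delta_t)
    out = np.empty_like(grid)
    for i, t in enumerate(grid):
        past = np.nonzero(ts <= t)[0]
        # rule 2: take the most recent historical value;
        # rule 3: before the first observation, fall back to the first data point
        out[i] = vs[past[-1]] if len(past) > 0 else vs[0]
    return out

# toy example: observations at hours 1, 5, 9 aligned to a 4-hour grid
print(align_feature(np.array([1.0, 5.0, 9.0]), np.array([90.0, 95.0, 88.0]),
                    t_start=0.0, t_end=12.0, delta_t=4.0))   # [90. 90. 95. 88.]
```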
81 | 82 | ### Introduction to MIMICIV derived dataset 83 | 84 | This section corresponds to the functions in `datasets/mimic_dataset`; it is recommended that the user modify only the necessary parts, and the input and output formats of each interface are specified below. The framework performs few checks on user-defined behaviour, so the rest of the pipeline may report errors if user-modified parts introduce new problems (e.g. generating NaNs). 85 | 86 | **on_extract_subjects** 87 | 88 | Output: dict[key=int(subject_id), value=object(extra_data)] 89 | 1. The subject_ids need to cover all patients to be extracted and can be linked to other data pipelines, e.g. by reading the derived tables of `mimic-code`. 90 | 2. `extra_data` contains additional data corresponding to the subject, in an unrestricted format, and is only processed by the user in `on_build_subject` 91 | 92 | **on_build_subject** 93 | 94 | Creates an instance of Subject 95 | 96 | Inputs: 97 | 1. subject_id: int 98 | 2. subject: Subject; add new data via the `append_static` method 99 | 3. row: the row of `hosp/patients.csv` corresponding to the subject, including `gender`, `dod`, etc. 100 | 4. _extract_result: the data extracted in `on_extract_subjects` 101 | 102 | Output: subject; it is passed by reference, so content is essentially added via the `append_static` method 103 | 104 | **on_extract_admission** 105 | 106 | Adds Admission instances to subjects 107 | 108 | Inputs: 109 | 1. subject: Subject 110 | 2. source: str, one of (`admission`, `icu`, `transfer`), extracted from `hosp/admission.csv`, `icu/icu_stays.csv`, `hosp/transfers.csv` respectively 111 | 3. row: namedtuple, a row from the table corresponding to `source` 112 | 113 | No output; add data via `Subject.append_static()` and `Admission.append_dynamic()` 114 | 115 | **on_select_feature** 116 | 117 | Inputs: 118 | 1. id: int, subject_id 119 | 2. source: str, one of (`icu`, `hosp`, `ed`), from `icu/d_items.csv`, `hosp/d_labitems.csv`, or the built-in formatting for `ed` 120 | 3. row: dict; the keys correspond to the columns of the source file for each `source`: 121 | - source=`icu`, key=[`id`, `label`, `category`, `type`, `low`, `high`] 122 | - source=`hosp`, key=[`id`, `label`, `fluid`, `category`] 123 | - source=`ed`, key=[`id`, `link_id`, `label`], where `link_id` links the ED's basic vital signs to the ICU vital signs instead of forming a separate column 124 | 125 | Output: bool, False means the feature is not selected, True means it is. Note that this stage only removes unwanted features; the selected features are further filtered later 126 | 127 | **on_convert_numeric** 128 | 129 | It is not recommended to rewrite this whole method; follow the existing example and only rewrite the contents of the judgement (if/elif) statements 130 | 131 | `Subject.static_data`: dict[key=str, value=list[tuple(value, time)]]; modify the values of non-numeric features so that all features become numeric 132 | 133 | **on_select_admissions(remove_invalid_pass1)** 134 | 135 | Filters the available samples according to the given conditions. It is not recommended to rewrite the whole method; the filtering conditions should be modified according to the needs of the task 136 | 137 | Inputs: 138 | 1. rule: dict, from `configs/mimiciv_dataset/remove_rule/pass1` 139 | 2. subjects: dict[subject_id: int, Subject] 140 | 141 | The default filter conditions are designed for Sepsis patients; all of the following must be satisfied: 142 | 1. all the features in `target_id` must be present; otherwise the temporal labels cannot be generated in some tasks 143 | 2.
the sequence length must satisfy the configured limits; sequences that are too short cannot be used for prediction, while sequences that are too long increase the size of the table after padding. 144 | 3. the filter condition on `Sepsis time` must hold; this is not needed in non-Sepsis tasks 145 | 146 | **on_remove_missing_data(remove_invalid_pass2)** 147 | 148 | Further filters the available samples and features. Unless necessary, it is recommended to modify only the missrate values in the configuration file rather than the code 149 | 150 | The default algorithm performs a multi-step iteration that removes highly missing features and samples at the same time; the number of iteration steps is determined by the length of the `max_col_missrate` and `max_subject_missrate` lists in the config file, and the missrate at the end of each list is the final missrate 151 | 152 | The `adm_select_strategy` has two choices: `random`, which randomly selects one of the available admissions for each subject, and `default`, which selects the first valid admission. 153 | 154 | **on_build_table** 155 | 156 | Modifies the values of certain features (which are often time dependent) once the start time of the sequence has been determined 157 | 158 | Inputs: 159 | 1. subject: Subject 160 | 2. key: str, the id of the feature currently being modified 161 | 3. value: float, the original value 162 | 4. t_start: float, the start time of the current sequence relative to the admit time; it is a positive number 163 | 164 | Output: new_value: float, the new value of the current feature 165 | 166 | **on_feature_engineering** 167 | 168 | After interpolation, new features are computed based on existing features. It is not recommended to rewrite the whole method; you can modify part of the judgement statement. Note that missing values (-1) in the original features must be taken into account: if an original feature is missing, the new feature should also be set to -1. A minimal sketch is given after the input/output description below. 169 | 170 | Input: 171 | 1. tables: list[np.ndarray[n_features, n_sequences]], tabulated data without padding 172 | 2. norm_dict: dict[key='mean' or 'std', value=float], stores the statistics of the features; new features need to update it 173 | 3. static_keys, dynamic_keys: list[str]; new features require updating them as well 174 | 175 | Output: updated tables, norm_dict, static_keys, dynamic_keys
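As a concrete illustration of this contract, the sketch below derives one new dynamic feature from two existing ones while propagating missing values (-1). The feature names, the assumption that each table's rows are ordered as `static_keys + dynamic_keys`, and the assumption that `norm_dict` maps each feature key to its `mean`/`std` are illustrative only, not guarantees of the framework.

```python
import numpy as np

def add_ratio_feature(tables: list, norm_dict: dict, static_keys: list, dynamic_keys: list,
                      numer_key: str = 'feature_a', denom_key: str = 'feature_b',
                      new_key: str = 'a_div_b') -> dict:
    """Append new_key = numer_key / denom_key to every table; -1 marks missing values."""
    keys = static_keys + dynamic_keys              # assumed row order of each table
    i_num, i_den = keys.index(numer_key), keys.index(denom_key)
    collected = []
    for idx, table in enumerate(tables):           # table: (n_features, n_sequences)
        num, den = table[i_num, :], table[i_den, :]
        valid = (num != -1) & (den != -1) & (den != 0)
        new_row = np.full_like(num, -1.0)          # missing inputs stay -1 in the new feature
        new_row[valid] = num[valid] / den[valid]
        tables[idx] = np.concatenate([table, new_row[None, :]], axis=0)
        collected.append(new_row[valid])
    values = np.concatenate(collected) if collected else np.zeros(0)
    norm_dict[new_key] = {'mean': float(values.mean()) if values.size else 0.0,
                          'std': float(values.std()) if values.size else 1.0}
    dynamic_keys.append(new_key)
    return {'tables': tables, 'norm_dict': norm_dict,
            'static_keys': static_keys, 'dynamic_keys': dynamic_keys}
```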
176 | 177 | ## Introduction to Data generator 178 | 179 | `model & task specific processing` cannot always be solved by the dataset pipeline alone; for example, different models require different normalisation methods, and temporal and non-temporal tasks require different data organisation formats. We provide multiple generators to generate different data formats and labels. 180 | 181 | Overall, there are three DataGenerators and two LabelGenerators under `models/utils`, and users can easily develop new generators. We provide several sample algorithms to illustrate how different generators can be used for different tasks. 182 | 183 | ### label generator 184 | 185 | **LabelGenerator_4cls**: expands an arbitrarily sized target array along a new dimension, dividing the values into four classes with predefined boundaries, for classification prediction. 186 | 187 | **LabelGenerator_regression**: applies no processing; it can be modified for other tasks. 188 | 189 | ### data generator 190 | 191 | Different data generators and label generators can be combined in any way; see the example algorithm `analyzer/method_nearest_cls`. 192 | 193 | **DynamicDataGenerator**: for each data point, searches forward through a prediction window and takes the lowest value within the window as the prediction target. It also selects the available features, similar to `dataset version`, and can normalise them. The end result is still a 3D structure and the timeline is not expanded. It can be used for RNN algorithms such as LSTM. 194 | 195 | **SliceDataGenerator**: computed in the same way as `DynamicDataGenerator`, but the timeline of the final data is unfolded into the form (n_sequence, n_feature) and invalid data are removed. It is intended for non-temporal methods such as GBDT. 196 | 197 | **StaticDataGenerator**: for each sequence, generates one large prediction window so that the final outcome is predicted from the features at the first time step. One sample is generated per sequence. 198 | 199 | 200 | ## More help 201 | 202 | [How to generate Sepsis3.csv](documents/processing.md) -------------------------------------------------------------------------------- /tools/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from matplotlib.colors import Normalize as ColorNorm 3 | import numpy as np 4 | import seaborn as sns 5 | import scipy 6 | import pandas as pd 7 | from collections.abc import Iterable 8 | import os, sys 9 | import random 10 | from tqdm import tqdm 11 | import subprocess 12 | import missingno as msno 13 | from .generic import reinit_dir, remove_slash 14 | from matplotlib.colors import to_rgb 15 | from .logging import logger 16 | 17 | ''' 18 | Used for quantile-regression plots: linear interpolation yields the quantile corresponding to each point, so that the data distribution does not reduce the colour diversity of the figure 19 | ''' 20 | class HueColorNormlize(ColorNorm): 21 | def __init__(self, data:np.ndarray) -> None: 22 | super().__init__(vmin=data.min(), vmax=data.max(), clip=False) 23 | n_points = 21 24 | points = [data.min()] 25 | data_sorted = np.sort(data, axis=None) 26 | for idx in range(n_points-1): 27 | points.append(data_sorted[round(((idx+1)/(n_points-1))*(data.shape[0]-1))]) 28 | self.x = np.asarray(points) 29 | self.y = np.asarray(list(range(len(points)))) / (len(points)-1) 30 | 31 | def get_ticks(self): 32 | return self.x 33 | 34 | def __call__(self, value, clip: bool = None): 35 | return np.interp(value, self.x, self.y, left=0, right=1) 36 | 37 | def inverse(self, value): 38 | return np.interp(value, self.y, self.x, left=self.vmin, right=self.vmax) 39 | 40 | def simple_plot(data, title='Title', out_path=None): 41 | plt.figure(figsize = (6,12)) 42 | for idx in range(data.shape[0]): 43 | plt.plot(data[idx, :]) 44 | plt.title(title) 45 | if out_path is None: 46 | plt.show() 47 | else: 48 | plt.savefig(out_path) 49 | plt.close() 50 | 51 | 52 | def plot_bar_with_label(data:np.ndarray, labels:list, title:str, sort=True, out_path=None): 53 | # Validate the input 54 | if not isinstance(data, np.ndarray): 55 | raise ValueError("Input data must be a numpy array.") 56 | if not isinstance(labels, list) or not all(isinstance(l, str) for l in labels): 57 | raise ValueError("Input labels must be a list of strings.") 58 | 59 | # sort 60 | if sort: 61 | ind = np.argsort(data) 62 | data = data[ind] 63 | labels = [labels[i] for i in ind] 64 | 65 | # Set up the histogram 66 | fig, ax = plt.subplots(figsize=(12,12)) # Set figure size 67 | plt.subplots_adjust(bottom=0.4, left=0.2, right=0.8) 68 | ind = np.arange(len(data)) 69 | width = 0.8 70 | if len(labels) < 20: 71 | fontsize = 20 72 | else: 73 | fontsize = 10 74 | ax.bar(ind, data, width, color='SkyBlue') 75 | 76 | # Set up the X axis 77 | ax.set_xticks(range(0, len(labels))) 78 | ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=fontsize) 79 | 80 | ax.set_title(title) 81 | 82 | if out_path
is None: 83 | # Show the histogram 84 | plt.show() 85 | else: 86 | plt.savefig(out_path) 87 | plt.close() 88 | 89 | def plot_stack_proportion(data:dict[str, tuple], out_path=None): 90 | plt.figure(figsize=(25, 10)) 91 | height = 0.5 92 | names = list(data.keys()) 93 | style = [to_rgb(f'C{idx}') for idx in range(10)] 94 | plt.barh(names, [0 for _ in names], height=height) 95 | idx = 0 96 | for k_idx, (key, (x, label)) in enumerate(data.items()): 97 | x_sum = 0 98 | for idx in range(len(x)): 99 | color = np.asarray(style[k_idx % 10]) 100 | color = np.clip(color + 0.2 * (idx % 2), 0, 1.0) 101 | plt.barh([key], x[idx], left=x_sum, color=tuple(color), height=height) 102 | label_wid = len(label[idx])*0.006 103 | if x[idx] > label_wid: 104 | plt.annotate(label[idx], (x_sum + x[idx]*0.5 - label_wid*0.5, k_idx), fontsize=10) 105 | x_sum += x[idx] 106 | 107 | plt.xlim(left=0, right=1) 108 | plt.subplots_adjust(left=0.1, right=0.9) 109 | plt.savefig(out_path) 110 | 111 | 112 | def plot_density_matrix(data:np.ndarray, title:str, xlabel:str, ylabel:str, aspect='equal', save_path=None): 113 | plt.figure(figsize=(10, 10)) 114 | plt.title(title) 115 | plt.imshow(data, cmap='jet', aspect=aspect) # auto for square picture, equal for original aspect ratio 116 | plt.colorbar() 117 | plt.xlabel(xlabel) 118 | plt.ylabel(ylabel) 119 | plt.savefig(save_path) 120 | plt.close() 121 | 122 | 123 | def plot_single_dist(data:np.ndarray, data_name:str, save_path=None, discrete=True, adapt=False, label=False, **kwargs): 124 | ''' 125 | 从源数据直接打印直方图 126 | data: shape任意, 每个元素代表一个样本 127 | data_name: str 特征名 128 | discrete: 取值是否离散, True=离散取值 129 | adapt: 自动调整输出取值范围, 可能会忽略某些极端值 130 | ''' 131 | data = data[:] 132 | if data.shape[0] > 2000: 133 | data = np.asarray(random.choices(data, k=2000)) 134 | assert(not np.any(np.isnan(data))) 135 | mu, sigma = scipy.stats.norm.fit(data) 136 | if adapt and sigma > 0.01: 137 | data = data[np.logical_and(data >= mu-3*sigma, data <= mu+3*sigma)] 138 | 139 | plt.figure(figsize=(8,8)) 140 | ax = sns.histplot(data=data, stat='proportion', discrete=discrete, **kwargs) 141 | if adapt: 142 | if discrete: 143 | ax.set_xlim(left=max(mu-3*sigma, np.min(data))-0.5, right=min(mu+3*sigma, np.max(data))+0.5) 144 | else: 145 | ax.set_xlim(left=max(mu-3*sigma, np.min(data)), right=min(mu+3*sigma, np.max(data))) 146 | if label: 147 | ax.bar_label(ax.containers[1], fontsize=10, fmt=lambda x:f'{x:.3f}') 148 | plt.title('Distribution of ' + data_name, fontsize = 13) 149 | plt.legend(['data', 'Normal dist. 
($\mu=$ {:.2f} and $\sigma=$ {:.2f} )'.format(mu, sigma)], loc='best') 150 | if save_path is None: 151 | plt.show() 152 | else: 153 | plt.savefig(save_path) 154 | plt.close() 155 | 156 | 157 | def plot_correlation_matrix(data:np.ndarray, fea_names:list, invalid_flag=-1, corr_thres=-1, save_path=None): 158 | ''' 159 | 生成相关矩阵 160 | data: (sample, n_fea) 161 | fea_names: (n_fea) 162 | percent_corr: 只打印相关性大于阈值的特征,我们对多重共线性特征更感兴趣 163 | ''' 164 | assert(len(fea_names) == data.shape[1]) 165 | fig_size = int(8+len(fea_names)*0.2) 166 | mat = np.zeros((data.shape[1], data.shape[1])) 167 | for i in tqdm(range(data.shape[1]), 'plot correlation'): 168 | for j in range(data.shape[1]): 169 | if i < j: 170 | idx_i = np.logical_and(data[:, i] != invalid_flag, np.logical_not(np.isnan(data[:, i]))) 171 | idx_j = np.logical_and(data[:, j] != invalid_flag, np.logical_not(np.isnan(data[:, j]))) 172 | avail_idx = np.logical_and(idx_i, idx_j) # availble subject index 173 | if np.sum(avail_idx) < 100: 174 | continue 175 | vi = data[avail_idx, i] 176 | vj = data[avail_idx, j] 177 | coef = np.corrcoef(vi, vj) # (2,2) 178 | if not np.isnan(coef[0,1]): 179 | mat[j, i] = coef[0,1] 180 | 181 | if corr_thres > 0: 182 | valid_idx = np.logical_or((np.abs(mat).max(axis=0) >= corr_thres), (np.abs(mat).max(axis=1) >= corr_thres)) 183 | mat = mat[valid_idx, :][:, valid_idx] 184 | fea_names = [name for idx, name in enumerate(fea_names) if valid_idx[idx]] 185 | f, ax = plt.subplots(figsize=(fig_size, fig_size)) 186 | mask = np.triu(np.ones_like(mat, dtype=bool)) 187 | cmap = sns.diverging_palette(230, 20, as_cmap=True) 188 | sns.heatmap(mat, mask=mask, cmap=cmap, vmax=1, center=0, annot=False, xticklabels=fea_names, yticklabels=fea_names, 189 | square=True, linewidths=.5, cbar_kws={"shrink": .5}) 190 | plt.xticks(rotation=90) 191 | plt.yticks(rotation=0) 192 | plt.title('Pearson Correlation Matrix', fontsize = 13) 193 | plt.subplots_adjust(left=0.3) 194 | plt.savefig(save_path) 195 | plt.close() 196 | return mat 197 | 198 | def plot_confusion_matrix(cm:np.ndarray, labels:list, title='Confusion matrix', comment='', save_path='./out.png'): 199 | ''' 200 | 生成混淆矩阵 201 | cm: cm[x][y]代表pred=x, gt=y 202 | labels: list(str) 各个class的名字 203 | save_path: 完整路径名 204 | ''' 205 | sns.set_style("whitegrid") 206 | plt.figure(figsize=(12, 10)) 207 | plt.gca().grid(False) 208 | plt.imshow(cm.T, interpolation='nearest', cmap=plt.cm.OrRd) 209 | if comment != '': 210 | title = title + f'[{comment}]' 211 | plt.title(title, size=18) 212 | plt.colorbar() 213 | tick_marks = np.arange(len(labels)) 214 | plt.xticks(tick_marks, labels, rotation=45, size=15) 215 | plt.yticks(tick_marks, labels, size=15) 216 | plt.ylabel('True label', size=18) 217 | plt.xlabel('Predicted label', size=18) 218 | width, height = cm.shape 219 | out_type = 'int' if np.max(cm) > 1+1e-3 else 'float' 220 | for x in range(width): 221 | for y in range(height): 222 | num_color = 'black' if cm[y][x] < 1.5*cm.mean() else 'white' 223 | cm_str = str(cm[y][x]) if out_type == 'int' else f'{cm[y][x]:.2f}' 224 | plt.annotate(cm_str, xy=(y, x), fontsize=24, color=num_color, 225 | horizontalalignment='center', 226 | verticalalignment='center') 227 | plt.savefig(save_path) 228 | plt.close() 229 | 230 | def plot_reg_correlation(X:np.ndarray, fea_names:Iterable, Y:np.ndarray, 231 | target_name: str, adapt=False, write_dir_path=None, plot_dash=True, comment:str=''): 232 | ''' 233 | 生成X的每一列关于Y的线性回归, 用来探究单变量对目标的影响 234 | write_dir_path: 将每个变量保存为一张图, 放在给定文件夹中 235 | X: (sample, n_fea) 236 | fea_names: 
list(str) 237 | Y: (sample,) 238 | target_name:str 239 | plot_dash: 是否画出Y=X的虚线 240 | 241 | ''' 242 | if write_dir_path is not None: 243 | os.makedirs(write_dir_path, exist_ok=True) 244 | Y = Y.reshape(Y.shape[0], 1) 245 | ymin, ymax = np.inf, -np.inf 246 | X,Y = X.astype(np.float32), Y.astype(np.float32) 247 | x_valid = ((1 - np.isnan(X)) * (1 - np.isnan(Y))).astype(bool) # 两者都是true才行 248 | corr_list = [] 249 | for idx, _ in enumerate(fea_names): 250 | x,y = X[x_valid[:, idx],idx], Y[x_valid[:, idx]] 251 | ymin, ymax = min(ymin, y.min()), max(ymax, y.max()) 252 | corr = np.corrcoef(x=x, y=y, rowvar=False)[1][0] # 相关矩阵, 2*2 253 | corr_list.append(corr) 254 | idx_list = list(range(len(fea_names))) 255 | idx_list = sorted(idx_list, key = lambda idx:abs(corr_list[idx]), reverse=True) # 按相关系数绝对值排序 256 | for rank, idx in enumerate(idx_list): 257 | name = fea_names[idx] 258 | logger.debug(f'Plot correlation: {name} cmt=[{comment}]') 259 | plt.figure(figsize = (12,12)) 260 | sns.regplot(x=X[x_valid[:, idx], idx], y=Y[x_valid[:, idx]], scatter_kws={'alpha':0.2}) 261 | # plot line y=x 262 | d_min, d_max = ymin, ymax 263 | if plot_dash: 264 | plt.plot(np.asarray([d_min, d_max]),np.asarray([d_min, d_max]), 265 | linestyle='dashed', color='C7', label='Y=X') 266 | plt.title(f'{name} vs {target_name} cmt=[{comment}]', fontsize = 12) 267 | if adapt and Y.shape[0] > 20: 268 | # 去除20个极值, 使得显示效果更好 269 | Y_sorted = np.sort(Y[x_valid[:, idx], 0], axis=0) 270 | X_sorted = np.sort(X[x_valid[:, idx], idx], axis=0) 271 | Y_span = Y_sorted[-10] - Y_sorted[10] 272 | X_span = X_sorted[-10] - X_sorted[10] 273 | plt.ylim(bottom=Y_sorted[10]-Y_span*0.05, top=Y_sorted[-10]+Y_span*0.05) 274 | plt.xlim(left=X_sorted[10]-X_span*0.05, right=X_sorted[-10]+X_span*0.05) 275 | else: 276 | plt.ylim(bottom=ymin, top=ymax) 277 | plt.xlabel(name) 278 | plt.ylabel(target_name) 279 | plt.legend(['$Pearson=$ {:.2f}'.format(corr_list[idx])], loc = 'best') 280 | if write_dir_path is None: 281 | plt.show() 282 | else: 283 | plt.savefig( 284 | os.path.join(write_dir_path, remove_slash(rf'{rank}@{fea_names[idx]}_vs_{target_name}{comment}.png')) 285 | ) 286 | plt.close() 287 | 288 | 289 | def plot_shap_scatter(fea_name:str, shap:np.ndarray, values:np.ndarray, x_lim=(0, -1), write_dir_path=None): 290 | ''' 291 | 生成某个特征所有样本的shap value和特征值的对应关系 292 | ''' 293 | plt.figure(figsize = (6,6)) 294 | sns.scatterplot(x=values, y=shap) 295 | plt.title(f'SHAP Value scatter plot for {fea_name}', fontsize = 12) 296 | plt.xlabel(fea_name) 297 | plt.ylabel('SHAP Value') 298 | if x_lim[1] > x_lim[0]: 299 | plt.xlim(left=x_lim[0], right=min(x_lim[1], values.max())) 300 | if write_dir_path is None: 301 | plt.show() 302 | else: 303 | plt.savefig( 304 | os.path.join(write_dir_path, f'shap_scatter_{remove_slash(fea_name)}.png') 305 | ) 306 | plt.close() 307 | 308 | 309 | def plot_dis_correlation(X:np.ndarray, fea_names, Y:np.ndarray, target_name, write_dir_path=None): 310 | ''' 311 | 生成X的每一列关于Y的条件分布, 用来探究单变量对目标的影响, 要求Y能转换为Bool型 312 | write_dir_path: 将每个变量保存为一张图, 放在给定文件夹中 313 | ''' 314 | Y = np.nan_to_num(Y, copy=False, nan=0) 315 | reinit_dir(write_dir_path) 316 | Y = Y.astype(bool) 317 | convert_list = [] 318 | for idx, name in enumerate(fea_names): 319 | try: 320 | x = X[:,idx].astype(float) 321 | x = np.nan_to_num(x, copy=False, nan=-1) 322 | valid = (x > -0.5) 323 | x_valid = x[valid] 324 | Y_valid = Y[valid] 325 | corr = np.corrcoef(x=x_valid, y=Y_valid, rowvar=False)[0,1] 326 | convert_list.append(corr) 327 | except Exception as e: 328 | 
logger.info(f'plot_dis_correlation: No correlation for {name}.') 329 | convert_list.append(-2) 330 | idx_list = list(range(len(fea_names))) 331 | idx_list = sorted(idx_list, key= lambda idx:abs(convert_list[idx]), reverse=True) 332 | for rank, idx in enumerate(idx_list): 333 | logger.debug(f'Plot correlation: {fea_names[idx]}') 334 | name = fea_names[idx] 335 | if convert_list[idx] > -1: 336 | df = pd.DataFrame(data=np.stack([X[:, idx].astype(float),Y],axis=1), columns=[name,'y']) 337 | df = df[df[name] > -0.5] # remove missing value 338 | else: 339 | df = pd.DataFrame(data=np.stack([X[:, idx].astype(str),Y],axis=1), columns=[name,'y']) 340 | sns.displot( 341 | data=df, x=name, hue='y', kind='hist', stat='proportion', common_norm=False, bins=20 342 | ) 343 | if convert_list[idx] > -1: 344 | plt.annotate(f'corr={convert_list[idx]:.3f}', xy=(0.05, 0.95), xycoords='axes fraction') 345 | if write_dir_path is None: 346 | plt.show() 347 | else: 348 | plt.savefig( 349 | os.path.join(write_dir_path, rf'{rank}.png') 350 | ) 351 | plt.close() 352 | if write_dir_path is not None: 353 | logger.info(f'dis correlation is saved in {write_dir_path}') 354 | 355 | 356 | def plot_na(data:pd.DataFrame, mode='matrix', disp=False, save_path=None): 357 | if mode == 'matrix': 358 | msno.matrix(data) 359 | plt.title('feature valid matrix (miss=white)') 360 | elif mode == 'bar': 361 | p=0.5 362 | df = msno.nullity_filter(data, filter='bottom', p=p) 363 | if not df.empty: 364 | fea_count = len(df.columns) 365 | msno.bar(df=df,fontsize=10, 366 | figsize=(15, (25 + max(fea_count,50) - 50) * 0.5), 367 | sort='descending') 368 | plt.title(f'feature valid rate, thres={p}') 369 | elif mode == 'sample': # calculate row missing rate 370 | na_mat = data.isna().to_numpy(dtype=np.int32) 371 | valid_mat = 1 - np.mean(na_mat, axis=1) 372 | sns.histplot(x=valid_mat, bins=20, stat='proportion') 373 | plt.title('sample valid rate') 374 | else: 375 | assert(False) 376 | if save_path: 377 | plt.savefig(save_path) 378 | if disp: 379 | plt.show() 380 | plt.close() 381 | 382 | def plot_category_dist(data:pd.DataFrame, type_dict:dict, output_dir=None): 383 | reinit_dir(output_dir) 384 | for name in type_dict.keys(): 385 | if isinstance(type_dict[name], dict): 386 | sns.histplot(data=data[name], stat='proportion') 387 | plt.title(f'Category items distribution of {name}') 388 | plt.savefig(os.path.join(output_dir, f'{name}_dist.png')) 389 | plt.close() 390 | 391 | def plot_model_comparison(csv_path, title, out_path): 392 | '''输出模型对比的散点图''' 393 | df = pd.read_csv(csv_path, encoding='utf-8') 394 | plt.figure(figsize=(10,10)) 395 | # columns=[model_name, hyper_params, metricA, metricB] 396 | sns.scatterplot(data=df, x="4cls_accuracy", y="robust_AUC", 397 | hue="model", 398 | palette="ch:r=-.2,d=.3_r", linewidth=0) 399 | plt.title(title) 400 | plt.savefig(out_path) --------------------------------------------------------------------------------