├── data ├── __init__.py └── read_medkg.py ├── models ├── __init__.py ├── prtransx.py ├── base.py ├── darling.py └── transx.py ├── paper └── Appendix.pdf ├── .gitignore ├── requirements.txt ├── scripts ├── medicines_with_codes.py ├── partition_triples.py ├── patients.py ├── patient_to_drug.py ├── patient_to_procedure.py ├── patients_to_diagnosis.py ├── diagnosis_to_procedure.py ├── demographic_sensitivity_chart.py ├── diagnosis_to_drug.py ├── medicine_to_demographics.py ├── procedure_to_demographics.py ├── disease_to_demographics.py ├── disease_to_procedure.py ├── disease_to_medicine.py ├── patient_demographics.py ├── use_case.py ├── patient_demographics_stats.py ├── medical_kg_with_demo_triples_all.py └── prob_medical_kg_with_demographics.py ├── constants.py ├── LICENCE ├── test.py ├── README.md ├── args.py ├── train.py └── utils.py /data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /paper/Appendix.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AynurGuluzade/DARLING/HEAD/paper/Appendix.pdf -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/medical_kg 2 | data/mimic 3 | __pycache__/ 4 | experiments/ 5 | .ipynb_checkpoints -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.17.4 2 | matplotlib==3.3.3 3 | torch==1.2.0 4 | pandas==1.1.0 5 | scikit_learn==0.24.1 6 | -------------------------------------------------------------------------------- /models/prtransx.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.transx import TransE, TransH 3 | from models.base import PrKGEBase 4 | from constants import * 5 | 6 | class PrTransE(TransE, PrKGEBase): 7 | def __init__(self, vocabs, dim=args.emb_dim, d_norm=args.d_norm, gamma=args.gamma, target=args.target): 8 | super(PrTransE, self).__init__(vocabs) 9 | 10 | def _distance(self, data): 11 | return torch.abs(self._probability_score(data[PROBABILITY]) - super(PrTransE, self)._distance(data)) 12 | 13 | class PrTransH(TransH, PrKGEBase): 14 | def __init__(self, vocabs, dim=args.emb_dim, d_norm=args.d_norm, gamma=args.gamma, target=args.target): 15 | super(PrTransH, self).__init__(vocabs) 16 | 17 | def _distance(self, data): 18 | return torch.abs(self._probability_score(data[PROBABILITY]) - super(PrTransH, self)._distance(data)) 19 | -------------------------------------------------------------------------------- /scripts/medicines_with_codes.py: -------------------------------------------------------------------------------- 1 | # %% 2 | # import libraries 3 | import os 4 | import csv 5 | from pathlib import Path 6 | 7 | ROOT_PATH = Path(os.path.dirname(__file__)).parent 8 | # %% 9 | # read mimic data 10 | prescriptions_data_path = f'{ROOT_PATH}/data/mimic/PRESCRIPTIONS.csv' 11 | 12 | prescriptions_columns = [] 13 | prescriptions = [] 14 | with open(prescriptions_data_path, newline='') as f: 15 | prescriptions = list(csv.reader(f)) 16 | prescriptions_columns = prescriptions.pop(0) #remove first row (column names) 17 | # %% 18 | # Create dictionary with MIMIC-III medicine name as key and DrugBank name as value 19 | 20 | medicine = {} 21 | for prescription in prescriptions: 22 | adm_id = prescription[2] 23 | medicine = '_'.join(prescription[7].lower().split()) # make string lowercase and concatenate words 24 | if adm_id in medicine: 25 | medicine[adm_id].append(medicine) # if adm_id is already in dict append new medicine 26 | else: 27 | medicine[adm_id] = [medicine] # add new adm_id in dictionary with list of medicines -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from pathlib import Path 4 | from args import get_args 5 | 6 | # set root path 7 | ROOT_PATH = Path(os.path.dirname(__file__)) 8 | 9 | # read arguments 10 | args = get_args() 11 | 12 | # define device 13 | CUDA = 'cuda' 14 | CPU = 'cpu' 15 | DEVICE = torch.device(CUDA if torch.cuda.is_available() else CPU) 16 | 17 | EPOCH = 'epoch' 18 | STATE_DICT = 'state_dict' 19 | BEST_VAL = 'best_val' 20 | OPTIMIZER = 'optimizer' 21 | POS = 'pos' 22 | NEG = 'neg' 23 | HEAD = 'head' 24 | RELATION = 'relation' 25 | TAIL = 'tail' 26 | TRIPLE = 'triple' 27 | DEMOGRAPHIC = 'demographic' 28 | PROBABILITY = 'probability' 29 | ENTITY = 'entity' 30 | LOSS = 'loss' 31 | TRANSE = 'TransE' 32 | TRANSH = 'TransH' 33 | TRANSR = 'TransR' 34 | TRANSD = 'TransD' 35 | PRTRANSE = 'PrTransE' 36 | PRTRANSH = 'PrTransH' 37 | DARLIN = 'DARLING' 38 | TREATMENT_RECOMMENDATION = 'treatment_recommendation' 39 | MEDICINE_RECOMMENDATION = 'medicine_recommendation' 40 | HITS_AT_1 = 'hits@1' 41 | HITS_AT_3 = 'hits@3' 42 | HITS_AT_10 = 'hits@10' 43 | MR = 'mr' 44 | MRR = 'mrr' 45 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Aynur Guluzade 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /models/base.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from constants import * 5 | 6 | class PrKGEBase: 7 | def _probability_score(self, probability): 8 | return args.scaling_prob * torch.log(1/probability) 9 | 10 | class KGEBase(nn.Module): 11 | def __init__(self, vocabs, dim=args.emb_dim, d_norm=args.d_norm, gamma=args.gamma, target=args.target): 12 | super(KGEBase, self).__init__() 13 | 14 | self.dim = dim 15 | self.d_norm = d_norm 16 | self.num_entities = len(vocabs[ENTITY]) 17 | self.num_relations = len(vocabs[RELATION]) 18 | 19 | # define loss function (criterion) 20 | self.criterion = nn.MarginRankingLoss(margin=gamma, reduction=args.reduction) 21 | 22 | # define target 23 | self.target = torch.FloatTensor([target]).to(DEVICE) 24 | 25 | def _init_embedding(self, num): 26 | return NotImplementedError 27 | 28 | def _distance(self, triples): 29 | return NotImplementedError 30 | 31 | def forward(self, positive, negative): 32 | return { 33 | LOSS: self.loss(positive, negative) 34 | } 35 | 36 | def loss(self, positive, negative): 37 | return self.criterion(self._distance(positive), self._distance(negative), target=self.target) 38 | 39 | def predict(self, data): 40 | return self._distance(data) -------------------------------------------------------------------------------- /scripts/partition_triples.py: -------------------------------------------------------------------------------- 1 | # %% 2 | # import libraries 3 | import os 4 | import csv 5 | import random 6 | from glob import glob 7 | from pathlib import Path 8 | from collections import Counter 9 | from sklearn.model_selection import train_test_split 10 | 11 | ROOT_PATH = Path(os.path.dirname(__file__)).parent.parent 12 | # %% 13 | # read data 14 | kg_path = f'{ROOT_PATH}/data/kg/simple/*.tsv' 15 | kg_file_paths = glob(kg_path) 16 | 17 | train = [] 18 | val = [] 19 | test = [] 20 | for file_path in kg_file_paths: 21 | with open(file_path) as tsv_file: 22 | read_tsv = csv.reader(tsv_file, delimiter="\t") 23 | if 'demographics' in file_path: 24 | train.extend(list(read_tsv)) 25 | else: 26 | train_part, val_part = train_test_split(list(read_tsv), test_size=0.2, shuffle=True) 27 | val_part, test_part = train_test_split(val_part, test_size=0.6, shuffle=True) 28 | train.extend(train_part) 29 | val.extend(val_part) 30 | test.extend(test_part) 31 | 32 | random.shuffle(train) 33 | random.shuffle(val) 34 | random.shuffle(test) 35 | # %% 36 | write_path = f'{ROOT_PATH}/data/kg/final' 37 | with open(f'{write_path}/train.txt', 'w') as outfile: 38 | csv.writer(outfile, delimiter='\t').writerows(train) 39 | 40 | with open(f'{write_path}/val.txt', 'w') as outfile: 41 | csv.writer(outfile, delimiter='\t').writerows(val) 42 | 43 | with open(f'{write_path}/test.txt', 'w') as outfile: 44 | csv.writer(outfile, delimiter='\t').writerows(test) 45 | # %% 46 | -------------------------------------------------------------------------------- /models/darling.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from models.base import KGEBase, PrKGEBase 6 | from constants import * 7 | 8 | 9 | class DARLING(KGEBase, PrKGEBase): 10 | ''' 11 | DARLING: Demographic Aware pRobabiListic medIcal knowLedge embeddinG 12 | ''' 13 | def __init__(self, vocabs, dim=args.emb_dim, d_norm=args.d_norm, gamma=args.gamma, target=args.target): 14 | super(DARLING, self).__init__(vocabs) 15 | 16 | self.num_demographics = len(vocabs[DEMOGRAPHIC]) 17 | 18 | # create embedding layers 19 | self.entity_embedding = self._init_embedding(self.num_entities) 20 | self.relation_embedding = self._init_embedding(self.num_relations) 21 | self.demographic_embedding = self._init_embedding(self.num_demographics) 22 | 23 | def _init_embedding(self, num): 24 | weight = torch.FloatTensor(num, self.dim) 25 | nn.init.xavier_uniform_(weight) 26 | embeddings = nn.Embedding(num, self.dim) 27 | embeddings.weight = nn.Parameter(weight) 28 | embeddings.weight.data = F.normalize(embeddings.weight.data, p=2, dim=1) 29 | 30 | return embeddings 31 | 32 | def _distance(self, data): 33 | assert data[TRIPLE].size()[1] == 3 34 | 35 | heads, relations, tails, demographics = data[TRIPLE][:, 0], data[TRIPLE][:, 1], data[TRIPLE][:, 2], data[DEMOGRAPHIC] 36 | 37 | h = self.entity_embedding(heads) 38 | r = self.relation_embedding(relations) 39 | t = self.entity_embedding(tails) 40 | 41 | w_de = self.demographic_embedding(demographics) 42 | 43 | h_de = h - torch.sum(h * w_de, dim=1, keepdim=True) * w_de 44 | r_de = r - torch.sum(r * w_de, dim=1, keepdim=True) * w_de 45 | t_de = t - torch.sum(t * w_de, dim=1, keepdim=True) * w_de 46 | 47 | distance = h_de + r_de - t_de 48 | 49 | return torch.abs(self._probability_score(data[PROBABILITY]) - torch.norm(distance, p=self.d_norm, dim=1)) 50 | -------------------------------------------------------------------------------- /scripts/patients.py: -------------------------------------------------------------------------------- 1 | # %% 2 | # import libraries 3 | import os 4 | import csv 5 | from pathlib import Path 6 | 7 | ROOT_PATH = Path(os.path.dirname(__file__)).parent 8 | # %% 9 | # read mimic data 10 | # Include demographics (Gender, Date of Birth, Marital status, Ethnicity) 11 | patients_data_path = f'{ROOT_PATH}/data/mimic/PATIENTS.csv' 12 | admission_data_path = f'{ROOT_PATH}/data/mimic/ADMISSIONS.csv' 13 | 14 | patient_columns = [] 15 | patients = [] 16 | with open(patients_data_path, newline='') as f: 17 | patients = list(csv.reader(f)) 18 | patient_columns = patients.pop(0) 19 | 20 | admission_columns = [] 21 | admissions = [] 22 | with open(admission_data_path, newline='') as f: 23 | admissions = list(csv.reader(f)) 24 | admission_columns = admissions.pop(0) 25 | # %% 26 | # write data 27 | # demography: age, sex, ethnic group, country of birth, religion, marital status, population mobility. 28 | kg_patient_data = [['PATIENT_ID', 'GENDER', 'YOB', 'ETHNICITY', 'RELIGION', 'MARITAL_STATUS']] 29 | 30 | # create patient dictionary with id as key and gender, age as values 31 | patients_dictionary = { 32 | patient[1]: { 33 | 'gender': patient[2], 34 | 'age': patient[3].split()[0].split('-')[0] 35 | } for patient in patients 36 | } 37 | 38 | seen_ids = set() 39 | # create final patient data 40 | for admission in admissions: 41 | # get patient data 42 | pid = admission[1] 43 | gender = patients_dictionary[pid]['gender'] 44 | age = patients_dictionary[pid]['age'] 45 | ethnicity = admission[13] 46 | religion = admission[11] 47 | marital_status = admission[12] 48 | # check if we already have patient demographics 49 | if pid in seen_ids: 50 | continue 51 | 52 | # add patient data 53 | kg_patient_data.append([pid, gender, age, ethnicity, religion, marital_status]) 54 | seen_ids.add(pid) 55 | 56 | # write data 57 | write_path = f'{ROOT_PATH}/data/kg/patients.tsv' 58 | with open(write_path, 'w') as outfile: 59 | writer = csv.writer(outfile, delimiter='\t') 60 | writer.writerows(kg_patient_data) 61 | # %% 62 | -------------------------------------------------------------------------------- /scripts/patient_to_drug.py: -------------------------------------------------------------------------------- 1 | # %% 2 | # import libraries 3 | import os 4 | import csv 5 | from pathlib import Path 6 | 7 | ROOT_PATH = Path(os.path.dirname(__file__)).parent 8 | # %% 9 | # read mimic data 10 | # Include demographics (Gender, Date of Birth, Marital status, Ethnicity) 11 | drugs_data_path = f'{ROOT_PATH}/data/mimic/PRESCRIPTIONS.csv' 12 | admission_data_path = f'{ROOT_PATH}/data/mimic/ADMISSIONS.csv' 13 | 14 | drugs_columns = [] 15 | drugs = [] 16 | with open(drugs_data_path, newline='') as f: 17 | drugs = list(csv.reader(f)) 18 | drugs_columns = drugs.pop(0) #remove first row (column names) 19 | 20 | admission_columns = [] 21 | admissions = [] 22 | with open(admission_data_path, newline='') as f: 23 | admissions = list(csv.reader(f)) 24 | admission_columns = admissions.pop(0) 25 | 26 | # %% 27 | # extract and write data 28 | # create drugs dictionary with adm_id as key, list of drugs as values 29 | drugs_dict = {} 30 | 31 | for d in drugs: 32 | adm_id = d[2] # get adm_id from prescription table 33 | drug = d[7] # get drug from prescription table 34 | if adm_id in drugs_dict: 35 | drugs_dict[adm_id].append(drug) # if adm_id is already in dict append new drug 36 | else: 37 | drugs_dict[adm_id] = [drug] # add new adm_id in dictionary with list of drugs 38 | 39 | #assert len(drugs_dict.keys()) == len(admissions) 40 | # %% 41 | patient_to_drug = [] 42 | 43 | for adm in admissions: 44 | patient_id = adm[1] 45 | adm_id = adm[2] 46 | adm_time = adm[3].split()[0] # get only date without time 47 | if adm_id not in drugs_dict: 48 | continue 49 | drug_names = drugs_dict[adm_id] # get drugs for this adm_id (we already get drugs for adm_id in drugs_dict) 50 | for drug in drug_names: 51 | quadruple = [patient_id, 'patient_to_drug', drug, adm_time] 52 | patient_to_drug.append(quadruple) 53 | 54 | # remove duplicates 55 | unique_patient_to_drug = list(set(['|'.join(pd) for pd in patient_to_drug])) 56 | patient_to_drug = [upd.split('|') for upd in unique_patient_to_drug] 57 | # %% 58 | # write data 59 | write_path = f'{ROOT_PATH}/data/kg/patient_to_drug.tsv' 60 | with open(write_path, 'w') as outfile: 61 | writer = csv.writer(outfile, delimiter='\t') 62 | writer.writerows(patient_to_drug) 63 | 64 | # %% 65 | -------------------------------------------------------------------------------- /scripts/patient_to_procedure.py: -------------------------------------------------------------------------------- 1 | # %% 2 | # import libraries 3 | import os 4 | import csv 5 | from pathlib import Path 6 | 7 | ROOT_PATH = Path(os.path.dirname(__file__)).parent 8 | # %% 9 | # read mimic data 10 | procedures_data_path = f'{ROOT_PATH}/data/mimic/PROCEDURES_ICD.csv' 11 | admission_data_path = f'{ROOT_PATH}/data/mimic/ADMISSIONS.csv' 12 | 13 | procedures_columns = [] 14 | procedures = [] 15 | with open(procedures_data_path, newline='') as f: 16 | procedures = list(csv.reader(f)) 17 | procedures_columns = procedures.pop(0) #remove first row (column names) 18 | 19 | admission_columns = [] 20 | admissions = [] 21 | with open(admission_data_path, newline='') as f: 22 | admissions = list(csv.reader(f)) 23 | admission_columns = admissions.pop(0) 24 | 25 | # %% 26 | # extract and write data 27 | # create procedure dictionary with adm_id as key, list of procedure_icd codes as values 28 | procedures_dict = {} 29 | 30 | for p in procedures: 31 | adm_id = p[2] # get adm_id from prescription table 32 | proc_icd = p[4] # get drug from prescription table 33 | if adm_id in procedures_dict: 34 | procedures_dict[adm_id].append(proc_icd) # if adm_id is already in dict append new drug 35 | else: 36 | procedures_dict[adm_id] = [proc_icd] # add new adm_id in dictionary with list of drugs 37 | 38 | #assert len(procedures_dict.keys()) == len(admissions) 39 | # %% 40 | patient_to_procedure = [] 41 | 42 | for adm in admissions: 43 | patient_id = adm[1] 44 | adm_id = adm[2] 45 | adm_time = adm[3].split()[0] # get only date without time 46 | if adm_id not in procedures_dict: 47 | continue 48 | procedures = procedures_dict[adm_id] # get drugs for this adm_id (we already get drugs for adm_id in drugs_dict) 49 | for procedure in procedures: 50 | quadruple = [patient_id, 'patient_to_procedure', procedure, adm_time] 51 | patient_to_procedure.append(quadruple) 52 | 53 | # remove duplicates 54 | unique_patient_to_procedure = list(set(['|'.join(pd) for pd in patient_to_procedure])) 55 | patient_to_procedure = [upd.split('|') for upd in unique_patient_to_procedure] 56 | # %% 57 | # write data 58 | write_path = f'{ROOT_PATH}/data/kg/patient_to_procedure.tsv' 59 | with open(write_path, 'w') as outfile: 60 | writer = csv.writer(outfile, delimiter='\t') 61 | writer.writerows(patient_to_procedure) 62 | -------------------------------------------------------------------------------- /scripts/patients_to_diagnosis.py: -------------------------------------------------------------------------------- 1 | # %% 2 | # import libraries 3 | import os 4 | import csv 5 | from pathlib import Path 6 | 7 | ROOT_PATH = Path(os.path.dirname(__file__)).parent 8 | # %% 9 | # read mimic data 10 | # Include demographics (Gender, Date of Birth, Marital status, Ethnicity) 11 | diagnosis_data_path = f'{ROOT_PATH}/data/mimic/DIAGNOSES_ICD.csv' 12 | # disease_dict_data_path = f'{ROOT_PATH}/data/mimic/diseases/D_ICD_DIAGNOSES.csv' # this is for disease title 13 | admission_data_path = f'{ROOT_PATH}/data/mimic/ADMISSIONS.csv' 14 | 15 | diagnosis_columns = [] 16 | diagnosis_icd = [] 17 | with open(diagnosis_data_path, newline='') as f: 18 | diagnosis_icd = list(csv.reader(f)) 19 | diagnosis_columns = diagnosis_icd.pop(0) #remove first row (column names) 20 | 21 | admission_columns = [] 22 | admissions = [] 23 | with open(admission_data_path, newline='') as f: 24 | admissions = list(csv.reader(f)) 25 | admission_columns = admissions.pop(0) 26 | # %% 27 | # extract and write data 28 | 29 | # create diagnosis dictionary with adm_id as key, list of icd_codes as values 30 | diagnosis_icd_dict = {} 31 | 32 | for d in diagnosis_icd: 33 | adm_id = d[2] # get adm_id from diagnosis_icd table 34 | icd_code = d[4] # get icd_code from diagnosis_icd table 35 | if adm_id in diagnosis_icd_dict: 36 | diagnosis_icd_dict[adm_id].append(icd_code) # if adm_id is already in dict append new icd_code 37 | else: 38 | diagnosis_icd_dict[adm_id] = [icd_code] # add new adm_id in dictionary with list of icd_code 39 | 40 | assert len(diagnosis_icd_dict.keys()) == len(admissions) 41 | 42 | patient_to_diagnosis = [] 43 | 44 | for adm in admissions: 45 | patient_id = adm[1] 46 | adm_id = adm[2] 47 | adm_time = adm[3].split()[0] # get only date without time 48 | icd_codes = diagnosis_icd_dict[adm_id] # get icd codes for this adm_id (we already get icd_codes for adm_id in diagnosis_icd_dict) 49 | for icd in icd_codes: 50 | quadruple = [patient_id, 'patient_to_diagnosis', icd, adm_time] 51 | patient_to_diagnosis.append(quadruple) 52 | 53 | # remove duplicates 54 | unique_patient_to_diagnosis = list(set(['|'.join(pd) for pd in patient_to_diagnosis])) 55 | patient_to_diagnosis = [upd.split('|') for upd in unique_patient_to_diagnosis] 56 | # %% 57 | # write data 58 | write_path = f'{ROOT_PATH}/data/kg/patient_to_diagnosis.tsv' 59 | with open(write_path, 'w') as outfile: 60 | writer = csv.writer(outfile, delimiter='\t') 61 | writer.writerows(patient_to_diagnosis) 62 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import random 5 | import logging 6 | import torch 7 | import numpy as np 8 | from pathlib import Path 9 | from data.read_medkg import MedicalKG 10 | from torch.utils.data import DataLoader 11 | from utils import models, AverageMeter, RankEvaluator 12 | 13 | # import constants 14 | from constants import * 15 | 16 | # set logger 17 | logging.basicConfig(format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s', 18 | datefmt='%d/%m/%Y %I:%M:%S %p', 19 | level=logging.INFO, 20 | handlers=[ 21 | logging.FileHandler(f'{args.results_path}/test_{args.model}_{args.task}.log', 'w'), 22 | logging.StreamHandler() 23 | ]) 24 | logger = logging.getLogger(__name__) 25 | 26 | # set a seed value 27 | random.seed(args.seed) 28 | np.random.seed(args.seed) 29 | if torch.cuda.is_available(): 30 | torch.manual_seed(args.seed) 31 | torch.cuda.manual_seed(args.seed) 32 | torch.cuda.manual_seed_all(args.seed) 33 | 34 | # set device 35 | torch.cuda.set_device(args.cuda_device) 36 | 37 | def main(): 38 | # load test data and prepare loader 39 | data = MedicalKG() 40 | vocabs = data.get_vocabs() 41 | _, _, test_data = data.get_data() 42 | test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=True) 43 | 44 | # load model 45 | model = models[args.model](vocabs).to(DEVICE) 46 | 47 | logger.info(f"=> loading checkpoint '{args.checkpoint_path}'") 48 | if DEVICE.type=='cpu': 49 | checkpoint = torch.load(f'{ROOT_PATH}/{args.checkpoint_path}', encoding='latin1', map_location='cpu') 50 | else: 51 | checkpoint = torch.load(f'{ROOT_PATH}/{args.checkpoint_path}', encoding='latin1') 52 | model.load_state_dict(checkpoint['state_dict']) 53 | logger.info(f"=> loaded checkpoint '{args.checkpoint_path}' (epoch {checkpoint['epoch']})") 54 | 55 | # define evaluator 56 | evaluator = RankEvaluator(vocabs) 57 | 58 | # get results 59 | results = evaluator.evaluate(test_loader, model) 60 | 61 | # log results 62 | logger.info(f'''Test results: 63 | \t\t\t\t\t Hits@1: {results[HITS_AT_1]:.4f} 64 | \t\t\t\t\t Hits@3: {results[HITS_AT_3]:.4f} 65 | \t\t\t\t\t Hits@10: {results[HITS_AT_10]:.4f} 66 | \t\t\t\t\t Mean Rank: {results[MR]:.4f} 67 | \t\t\t\t\t Mean Reciprocal Rank: {results[MRR]:.4f}''') 68 | 69 | if __name__ == '__main__': 70 | main() -------------------------------------------------------------------------------- /scripts/diagnosis_to_procedure.py: -------------------------------------------------------------------------------- 1 | # %% 2 | # import libraries 3 | import os 4 | import csv 5 | from pathlib import Path 6 | 7 | ROOT_PATH = Path(os.path.dirname(__file__)).parent 8 | # %% 9 | # read mimic data 10 | # Include demographics (Gender, Date of Birth, Marital status, Ethnicity) 11 | diagnosis_data_path = f'{ROOT_PATH}/data/mimic/DIAGNOSES_ICD.csv' # this is for diseace icd9 code 12 | procedures_data_path = f'{ROOT_PATH}/data/mimic/PROCEDURES_ICD.csv' 13 | 14 | diagnosis_columns = [] 15 | diagnosis_icd = [] 16 | with open(diagnosis_data_path, newline='') as f: 17 | diagnosis_icd = list(csv.reader(f)) 18 | diagnosis_columns = diagnosis_icd.pop(0) #remove first row (column names) 19 | 20 | rocedures_columns = [] 21 | procedures = [] 22 | with open(procedures_data_path, newline='') as f: 23 | procedures = list(csv.reader(f)) 24 | procedures_columns = procedures.pop(0) #remove first row (column names) 25 | # %% 26 | # extract and write data 27 | # create diagnosis dictionary with adm_id as key, list of icd_codes as values 28 | diagnosis_icd_dict = {} 29 | 30 | for d in diagnosis_icd: 31 | adm_id = d[2] # get adm_id from diagnosis_icd table 32 | icd_code = d[4] # get icd_code from diagnosis_icd table 33 | if adm_id in diagnosis_icd_dict: 34 | diagnosis_icd_dict[adm_id].append(icd_code) # if adm_id is already in dict append new icd_code 35 | else: 36 | diagnosis_icd_dict[adm_id] = [icd_code] # add new adm_id in dictionary with list of icd_code 37 | 38 | #%% 39 | # create procedure dictionary with adm_id as key, list of procedure_icd codes as values 40 | procedures_dict = {} 41 | 42 | for p in procedures: 43 | adm_id = p[2] # get adm_id from prescription table 44 | proc_icd = p[4] # get drug from prescription table 45 | if adm_id in procedures_dict: 46 | procedures_dict[adm_id].append(proc_icd) # if adm_id is already in dict append new drug 47 | else: 48 | procedures_dict[adm_id] = [proc_icd] # add new adm_id in dictionary with list of drugs 49 | 50 | #%% 51 | diagnosis_to_procedure = [] 52 | 53 | # iterate through dict diagnosis_icd_dict and 54 | 55 | for adm_id, icd_codes in diagnosis_icd_dict.items(): 56 | if adm_id not in procedures_dict: #if there is no drug for this adm_id skip it 57 | continue 58 | procedures = procedures_dict[adm_id] 59 | for diagnosis in icd_codes: 60 | for procedure in procedures: 61 | triple = [diagnosis, 'diagnosis_to_procedure', procedure] 62 | diagnosis_to_procedure.append(triple) 63 | 64 | # remove duplicates 65 | unique_diagnosis_to_procedure = list(set(['|'.join(dp) for dp in diagnosis_to_procedure])) 66 | diagnosis_to_procedure = [udp.split('|') for udp in unique_diagnosis_to_procedure] 67 | # %% 68 | # write data 69 | write_path = f'{ROOT_PATH}/data/kg/diagnosis_to_procedure.tsv' 70 | with open(write_path, 'w') as outfile: 71 | writer = csv.writer(outfile, delimiter='\t') 72 | writer.writerows(diagnosis_to_procedure) 73 | 74 | # %% 75 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Demographic Aware Probabilistic Medical Knowledge Graph Embeddings of Electronic Medical Records 2 | 3 | ## Requirements and Setup 4 | Python version >= 3.7 5 | 6 | PyTorch version >= 1.2.0 7 | 8 | ``` bash 9 | # clone the repository 10 | git clone https://github.com/AynurGuluzade/DARLING.git 11 | cd DARLING 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | ## Download MIMIC-III dataset 16 | We construct a Medical Knowledge Graph using MIMIC-III dataset, which contains clinical information of patients. For accessing the data, researchers should complete an online training course and then apply for permission to download the complete MIMIC-III dataset. You can find more information [here](https://mimic.physionet.org/). 17 | 18 | After dowloading you will need to the move files under the [data](data) directory. 19 | 20 | ## Construct Probabilistic Medical Knowledge Graph with Demographics 21 | For constructing the knowledge graph please run: 22 | ``` bash 23 | # construct medical kg 24 | python scripts/prob_medical_kg_with_demographics.py 25 | ``` 26 | 27 | ## Train Framework 28 | For training you will need to adjust the paths in [args](args.py) file. At the same file you can also modify and experiment with different model settings. 29 | ``` bash 30 | # train Framework 31 | python train.py 32 | ``` 33 | 34 | ## Test Framework 35 | ``` bash 36 | # test Framework 37 | python test.py 38 | ``` 39 | 40 | ## License 41 | The repository is under [MIT License](LICENCE). 42 | 43 | ## Cite 44 | ```bash 45 | @InProceedings{10.1007/978-3-030-77211-6_48, 46 | author="Guluzade, Aynur 47 | and Kacupaj, Endri 48 | and Maleshkova, Maria", 49 | editor="Tucker, Allan 50 | and Henriques Abreu, Pedro 51 | and Cardoso, Jaime 52 | and Pereira Rodrigues, Pedro 53 | and Ria{\~{n}}o, David", 54 | title="Demographic Aware Probabilistic Medical Knowledge Graph Embeddings of Electronic Medical Records", 55 | booktitle="Artificial Intelligence in Medicine", 56 | year="2021", 57 | publisher="Springer International Publishing", 58 | address="Cham", 59 | pages="408--417", 60 | abstract="Medical knowledge graphs (KGs) constructed from Electronic Medical Records (EMR) contain abundant information about patients and medical entities. The utilization of KG embedding models on these data has proven to be efficient for different medical tasks. However, existing models do not properly incorporate patient demographics and most of them ignore the probabilistic features of the medical KG. In this paper, we propose DARLING (Demographic Aware pRobabiListic medIcal kNowledge embeddinG), a demographic-aware medical KG embedding framework that explicitly incorporates demographics in the medical entities space by associating patient demographics with a corresponding hyperplane. Our framework leverages the probabilistic features within the medical entities for learning their representations through demographic guidance. We evaluate DARLING through link prediction for treatments and medicines, on a medical KG constructed from EMR data, and illustrate its superior performance compared to existing KG embedding models.", 61 | isbn="978-3-030-77211-6" 62 | } 63 | ``` 64 | -------------------------------------------------------------------------------- /scripts/demographic_sensitivity_chart.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | 5 | def autolabel(ax, rects): 6 | """Attach a text label above each bar in *rects*, displaying its height.""" 7 | for rect in rects: 8 | height = rect.get_height() 9 | ax.annotate('{}'.format(height), 10 | xy=(rect.get_x() + rect.get_width() / 2, 0.85 * height), 11 | xytext=(0, 3), # 3 points vertical offset 12 | textcoords="offset points", 13 | ha='center', va='bottom', 14 | fontsize=6, 15 | rotation=90) 16 | 17 | labels = ['Gender', 'Age', 'Ethnic', 'G+A', 'G+E', 'A+E', 'All'] 18 | 19 | x = np.arange(len(labels)) # the label locations 20 | width = 0.35 # the width of the bars 21 | fig, (ax, ax1) = plt.subplots(2, 2) # create subplots 22 | 23 | # mean rank chart with probabilities 24 | ax[0].set_ylabel('Mean Rank', labelpad=1) 25 | mr_treatment_prob = ax[0].bar(x - width/2, [68.89, 66.46, 67.82, 66.01, 67.35, 65.18, 64.65], width, color='#CCE5FF', label='Treatment') 26 | mr_medicine_prob = ax[0].bar(x + width/2, [25.67, 23.84, 25.57, 23.92, 24.97, 23.29, 22.86], width, color='#FFCCE6', label='Medicine') 27 | ax[0].set_xticks(x) 28 | ax[0].set_xticklabels(labels, fontsize=8, rotation=-15) 29 | autolabel(ax[0], mr_treatment_prob) 30 | autolabel(ax[0], mr_medicine_prob) 31 | 32 | # mean rank chart without probabilities 33 | mr_treatment_without = ax[1].bar(x - width/2, [71.11, 69.32, 69.28, 68.83, 70.14, 67.83, 67.18], width, color='#CCE5FF', label='_nolegend_') 34 | mr_medicine_without = ax[1].bar(x + width/2, [27.13, 25.16, 26.86, 24.97, 26.46, 25.01, 24.89], width, color='#FFCCE6', label='_nolegend_') 35 | ax[1].set_xticks(x) 36 | ax[1].set_xticklabels(labels, fontsize=8, rotation=-15) 37 | autolabel(ax[1], mr_treatment_without) 38 | autolabel(ax[1], mr_medicine_without) 39 | 40 | # hits@10 bar chart 41 | ax1[0].set_ylabel('Hits@10(%)', labelpad=0) 42 | ax1[0].set_xlabel('With Probability Score', labelpad=0) 43 | hits10_treatment_prob = ax1[0].bar(x - width/2, [47.62, 50.48, 48.52, 50.97, 48.92, 51.32, 52.19], width, color='#CCE5FF', label='_nolegend_') 44 | hits10_medicine_prob = ax1[0].bar(x + width/2, [56.94, 59.71, 57.64, 60.25, 58.27, 60.97, 61.73], width, color='#FFCCE6', label='_nolegend_') 45 | ax1[0].set_xticks(x) 46 | ax1[0].set_xticklabels(labels, fontsize=8, rotation=-15) 47 | autolabel(ax1[0], hits10_treatment_prob) 48 | autolabel(ax1[0], hits10_medicine_prob) 49 | 50 | # hits@10 bar chart 51 | ax1[1].set_xlabel('Without Probability Score', labelpad=0) 52 | hits10_treatment_without = ax1[1].bar(x - width/2, [45.83, 48.17, 46.85, 48.17, 45.96, 48.25, 50.41], width, color='#CCE5FF', label='_nolegend_') 53 | hits10_medicine_without = ax1[1].bar(x + width/2, [54.58, 57.94, 55.42, 58.09, 56.12, 59.31, 59.96], width, color='#FFCCE6', label='_nolegend_') 54 | ax1[1].set_xticks(x) 55 | ax1[1].set_xticklabels(labels, fontsize=8, rotation=-15) 56 | autolabel(ax1[1], hits10_treatment_without) 57 | autolabel(ax1[1], hits10_medicine_without) 58 | 59 | 60 | fig.tight_layout() 61 | plt.figlegend(loc='upper center', fancybox=True, shadow=True, ncol=2) 62 | plt.show() 63 | -------------------------------------------------------------------------------- /scripts/diagnosis_to_drug.py: -------------------------------------------------------------------------------- 1 | # %% 2 | # import libraries 3 | import os 4 | import csv 5 | from pathlib import Path 6 | 7 | ROOT_PATH = Path(os.path.dirname(__file__)).parent 8 | # %% 9 | # read mimic data 10 | # Include demographics (Gender, Date of Birth, Marital status, Ethnicity) 11 | diagnosis_data_path = f'{ROOT_PATH}/data/mimic/DIAGNOSES_ICD.csv' # this is for diseace icd9 code 12 | drugs_data_path = f'{ROOT_PATH}/data/mimic/PRESCRIPTIONS.csv' 13 | 14 | diagnosis_columns = [] 15 | diagnosis_icd = [] 16 | with open(diagnosis_data_path, newline='') as f: 17 | diagnosis_icd = list(csv.reader(f)) 18 | diagnosis_columns = diagnosis_icd.pop(0) #remove first row (column names) 19 | 20 | drugs_columns = [] 21 | drugs = [] 22 | with open(drugs_data_path, newline='') as f: 23 | drugs = list(csv.reader(f)) 24 | drugs_columns = drugs.pop(0) #remove first row (column names) 25 | # %% 26 | # extract and write data 27 | # create 2 helper dictionaries 28 | # 1) diagnosis: admision_id -> [diagnose_1, diagnose_2...] 29 | # 2) perscriptions: admision_id -> [drug_1, drug_2...] 30 | # Iterate through diagnosis (for adm_id, diagnosis_codes in diagnosis.items():) 31 | # use admision key to get drugs (drugs <- perscriptions[k]) 32 | # iterate through diagnosis_codes and inside the loop iterate through drugs 33 | # create triple and append it on the triple list 34 | # format: ICD9_CODE(diagnosis) - diagnosis_to_drug - Drug 35 | 36 | # create diagnosis dictionary with adm_id as key, list of icd_codes as values 37 | diagnosis_icd_dict = {} 38 | 39 | for d in diagnosis_icd: 40 | adm_id = d[2] # get adm_id from diagnosis_icd table 41 | icd_code = d[4] # get icd_code from diagnosis_icd table 42 | if adm_id in diagnosis_icd_dict: 43 | diagnosis_icd_dict[adm_id].append(icd_code) # if adm_id is already in dict append new icd_code 44 | else: 45 | diagnosis_icd_dict[adm_id] = [icd_code] # add new adm_id in dictionary with list of icd_code 46 | 47 | #assert len(diagnosis_icd_dict.keys()) == len(admissions) 48 | #%% 49 | # create drugs dictionary with adm_id as key, list of drugs as values 50 | drugs_dict = {} 51 | 52 | for d in drugs: 53 | adm_id = d[2] # get adm_id from prescription table 54 | drug = d[7] # get drug from prescription table 55 | if adm_id in drugs_dict: 56 | drugs_dict[adm_id].append(drug) # if adm_id is already in dict append new drug 57 | else: 58 | drugs_dict[adm_id] = [drug] # add new adm_id in dictionary with list of drugs 59 | 60 | #%% 61 | diagnosis_to_drug = [] 62 | # iterate through dict diagnosis_icd_dict and 63 | for adm_id, icd_codes in diagnosis_icd_dict.items(): 64 | if adm_id not in drugs_dict: #if there is no drug for this adm_id skip it 65 | continue 66 | drugs = drugs_dict[adm_id] 67 | for diagnosis in icd_codes: 68 | for drug in drugs: 69 | triple = [diagnosis, 'diagnosis_to_drug', drug] 70 | diagnosis_to_drug.append(triple) 71 | 72 | # remove duplicates 73 | # unique_diagnosis_to_drug = list(set(['|'.join(dd) for dd in diagnosis_to_drug])) 74 | # diagnosis_to_drug = [udd.split('|') for udd in unique_diagnosis_to_drug] 75 | # %% 76 | # write data 77 | write_path = f'{ROOT_PATH}/data/kg/diagnosis_to_drug.tsv' 78 | with open(write_path, 'w') as outfile: 79 | writer = csv.writer(outfile, delimiter='\t') 80 | writer.writerows(diagnosis_to_drug) 81 | 82 | # %% 83 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_args(): 4 | parser = argparse.ArgumentParser(description='Demographic Aware Probabilistic Medical Knowledge Graph Embeddings of Electronic Medical Records') 5 | 6 | # general 7 | parser.add_argument('--seed', default=1234, type=int) 8 | parser.add_argument('--no-cuda', action='store_true') 9 | parser.add_argument('--cuda_device', default=0, type=int) 10 | 11 | # data 12 | parser.add_argument('--data_path', default='/data/medical_kg') 13 | 14 | # experiments 15 | parser.add_argument('--snapshots', default='experiments/snapshots', type=str) 16 | parser.add_argument('--results_path', default='experiments/results', type=str) 17 | parser.add_argument('--resume', default='experiments/snapshots', type=str) 18 | parser.add_argument('--checkpoint_path', default='experiments/snapshots', type=str) 19 | 20 | # model 21 | parser.add_argument('--model', default='DARLING', choices=['TransE', 22 | 'TransH', 23 | 'TransR', 24 | 'TransD', 25 | 'PrTransE', 26 | 'PrTransH', 27 | 'DARLING'], type=str) 28 | 29 | # task 30 | parser.add_argument('--task', default='both', choices=['both', 31 | 'treatment_recommendation', 32 | 'medicine_recommendation'], type=str) 33 | 34 | # model parameters 35 | parser.add_argument('--emb_dim', default=100, type=int) 36 | parser.add_argument('--dropout', default=1e-1, type=int) 37 | parser.add_argument('--d_norm', default=2, type=int) 38 | parser.add_argument('--gamma', default=1, type=int) 39 | parser.add_argument('--target', default=-1, type=int) 40 | parser.add_argument('--reduction', default='sum', choices=['none', 'mean', 'sum'], type=str) 41 | 42 | # training 43 | parser.add_argument('--lr', default=1e-3, type=float) 44 | parser.add_argument('--epochs', default=100, type=int) 45 | parser.add_argument('--start_epoch', default=0, type=int) 46 | parser.add_argument('--valfreq', default=1, type=int) 47 | parser.add_argument('--clip', default=5, type=int) 48 | parser.add_argument('--batch_size', default=128, type=int) 49 | 50 | # other 51 | parser.add_argument('--negative_prob', default=1e-15, type=float) 52 | parser.add_argument('--scaling_prob', default=1e-2, type=float) 53 | 54 | args, argv = parser.parse_known_args() 55 | 56 | if args.model in ['PrTransE', 'PrTransH']: 57 | parser.add_argument('--demographic_aware', default=False, action='store_true') 58 | parser.add_argument('--prob_embedding', default=True, action='store_true') 59 | elif args.model in ['DARLING']: 60 | parser.add_argument('--demographic_aware', default=True, action='store_true') 61 | parser.add_argument('--prob_embedding', default=True, action='store_true') 62 | else: 63 | parser.add_argument('--demographic_aware', default=False, action='store_true') 64 | parser.add_argument('--prob_embedding', default=False, action='store_true') 65 | 66 | parser.parse_args(argv, namespace=args) 67 | 68 | return args -------------------------------------------------------------------------------- /scripts/medicine_to_demographics.py: -------------------------------------------------------------------------------- 1 | # %% 2 | # import libraries 3 | import os 4 | import csv 5 | import random 6 | from pathlib import Path 7 | 8 | ROOT_PATH = Path(os.path.dirname(__file__)).parent.parent 9 | # %% 10 | # read mimic data 11 | perscriptions_data_path = f'{ROOT_PATH}/data/mimic/PRESCRIPTIONS.csv' 12 | patient_demographics_data_path = f'{ROOT_PATH}/data/kg/patient_demographics.tsv' 13 | 14 | perscriptions_columns = [] 15 | perscriptions = [] 16 | with open(perscriptions_data_path, newline='') as f: 17 | perscriptions = list(csv.reader(f)) 18 | perscriptions_columns = perscriptions.pop(0) #remove first row (column names) 19 | 20 | patient_demographics_columns = [] 21 | patient_demographics = [] 22 | with open(patient_demographics_data_path, newline='') as f: 23 | patient_demographics = list(csv.reader(f, delimiter="\t")) 24 | patient_demographics_columns = patient_demographics.pop(0) # remove first row (column names) 25 | # %% 26 | # create helper patient dictionary 27 | d_patient_demographics = {} 28 | for patient in patient_demographics: 29 | patient_id = patient[0] 30 | gender = patient[1].lower() 31 | age_group = patient[2] 32 | ethnic_group = patient[3] 33 | 34 | d_patient_demographics[patient_id] = { 35 | 'gender': gender, 36 | 'age_group': age_group, 37 | 'ethnic_group': ethnic_group 38 | } 39 | 40 | # iterate through perscriptions and use subject_id (patient id) to create relations with demographics 41 | medicine_to_gender = [] 42 | medicine_to_agegroup = [] 43 | medicine_to_ethnicgroup = [] 44 | 45 | for perscription in perscriptions: 46 | # extarct data 47 | patient_id = perscription[1] 48 | medicine = '_'.join(perscription[7].lower().split()) # make string lowercase and concatenate words 49 | demographics = d_patient_demographics[patient_id] 50 | gender = demographics['gender'] 51 | age_group = demographics['age_group'] 52 | ethnic_group = demographics['ethnic_group'] 53 | 54 | # add triples to list 55 | medicine_to_gender.append([medicine, 'medicine_to_gender', gender]) # Total: 4156450, Unique: 6034 56 | medicine_to_agegroup.append([medicine, 'medicine_to_agegroup', age_group]) # Total: 4156450, Unique: 10966 57 | medicine_to_ethnicgroup.append([medicine, 'medicine_to_ethnicgroup', ethnic_group]) # Total: 4156450, Unique: 11156 58 | 59 | # %% 60 | # gender 61 | string_medicine_to_gender = list(set(['|'.join(d) for d in medicine_to_gender])) 62 | unique_medicine_to_gender = [d.split('|') for d in string_medicine_to_gender] 63 | 64 | # age group 65 | string_medicine_to_agegroup = list(set(['|'.join(d) for d in medicine_to_agegroup])) 66 | unique_medicine_to_agegroup = [d.split('|') for d in string_medicine_to_agegroup] 67 | 68 | # ethnic group 69 | string_medicine_to_ethnicgroup = list(set(['|'.join(d) for d in medicine_to_ethnicgroup])) 70 | unique_medicine_to_ethnicgroup = [d.split('|') for d in string_medicine_to_ethnicgroup] 71 | 72 | # create final list and shuffle 73 | medicine_to_demographics = [] 74 | medicine_to_demographics.extend(unique_medicine_to_gender) 75 | medicine_to_demographics.extend(unique_medicine_to_agegroup) 76 | medicine_to_demographics.extend(unique_medicine_to_ethnicgroup) 77 | random.shuffle(medicine_to_demographics) 78 | 79 | write_medicine_to_demographics_path = f'{ROOT_PATH}/data/kg/simple/medicine_to_demographics.tsv' 80 | with open(write_medicine_to_demographics_path, 'w') as outfile: 81 | writer = csv.writer(outfile, delimiter='\t') 82 | writer.writerows(medicine_to_demographics) 83 | 84 | # %% 85 | -------------------------------------------------------------------------------- /scripts/procedure_to_demographics.py: -------------------------------------------------------------------------------- 1 | # %% 2 | # import libraries 3 | import os 4 | import csv 5 | import random 6 | from pathlib import Path 7 | 8 | ROOT_PATH = Path(os.path.dirname(__file__)).parent.parent 9 | # %% 10 | # read mimic data 11 | procedure_data_path = f'{ROOT_PATH}/data/mimic/PROCEDURES_ICD.csv' 12 | patient_demographics_data_path = f'{ROOT_PATH}/data/kg/patient_demographics.tsv' 13 | 14 | procedure_columns = [] 15 | procedure_icd = [] 16 | with open(procedure_data_path, newline='') as f: 17 | procedure_icd = list(csv.reader(f)) 18 | procedure_columns = procedure_icd.pop(0) # remove first row (column names) 19 | 20 | patient_demographics_columns = [] 21 | patient_demographics = [] 22 | with open(patient_demographics_data_path, newline='') as f: 23 | patient_demographics = list(csv.reader(f, delimiter="\t")) 24 | patient_demographics_columns = patient_demographics.pop(0) # remove first row (column names) 25 | 26 | # %% 27 | # create helper patient dictionary 28 | d_patient_demographics = {} 29 | for patient in patient_demographics: 30 | patient_id = patient[0] 31 | gender = patient[1].lower() 32 | age_group = patient[2] 33 | ethnic_group = patient[3] 34 | 35 | d_patient_demographics[patient_id] = { 36 | 'gender': gender, 37 | 'age_group': age_group, 38 | 'ethnic_group': ethnic_group 39 | } 40 | 41 | # iterate through procedure and use subject_id (patient id) to create relations with demographics 42 | procedure_to_gender = [] 43 | procedure_to_agegroup = [] 44 | procedure_to_ethnicgroup = [] 45 | 46 | for procedure in procedure_icd: 47 | # extarct data 48 | patient_id = procedure[1] 49 | procedure_icd = procedure[-1].lower() 50 | demographics = d_patient_demographics[patient_id] 51 | gender = demographics['gender'] 52 | age_group = demographics['age_group'] 53 | ethnic_group = demographics['ethnic_group'] 54 | 55 | # add triples to list 56 | procedure_to_gender.append([f'icd9_{procedure_icd}', 'procedure_to_gender', gender]) # Total: 240095, Unique: 3350 57 | procedure_to_agegroup.append([f'icd9_{procedure_icd}', 'procedure_to_agegroup', age_group]) # Total: 240095, Unique: 6517 58 | procedure_to_ethnicgroup.append([f'icd9_{procedure_icd}', 'procedure_to_ethnicgroup', ethnic_group]) # Total: 240095, Unique: 6056 59 | # %% 60 | # write unique triples for simple approach 61 | # gender 62 | string_procedure_to_gender = list(set(['|'.join(d) for d in procedure_to_gender])) 63 | unique_procedure_to_gender = [d.split('|') for d in string_procedure_to_gender] 64 | 65 | # age group 66 | string_procedure_to_agegroup = list(set(['|'.join(d) for d in procedure_to_agegroup])) 67 | unique_procedure_to_agegroup = [d.split('|') for d in string_procedure_to_agegroup] 68 | 69 | # ethnic group 70 | string_procedure_to_ethnicgroup = list(set(['|'.join(d) for d in procedure_to_ethnicgroup])) 71 | unique_procedure_to_ethnicgroup = [d.split('|') for d in string_procedure_to_ethnicgroup] 72 | 73 | # create final list and shuffle 74 | procedure_to_demographics = [] 75 | procedure_to_demographics.extend(unique_procedure_to_gender) 76 | procedure_to_demographics.extend(unique_procedure_to_agegroup) 77 | procedure_to_demographics.extend(unique_procedure_to_ethnicgroup) 78 | random.shuffle(procedure_to_demographics) 79 | 80 | write_procedure_to_demographics_path = f'{ROOT_PATH}/data/kg/simple/procedure_to_demographics.tsv' 81 | with open(write_procedure_to_demographics_path, 'w') as outfile: 82 | writer = csv.writer(outfile, delimiter='\t') 83 | writer.writerows(procedure_to_demographics) 84 | # %% 85 | -------------------------------------------------------------------------------- /scripts/disease_to_demographics.py: -------------------------------------------------------------------------------- 1 | # %% 2 | # import libraries 3 | import os 4 | import csv 5 | import random 6 | from pathlib import Path 7 | 8 | ROOT_PATH = Path(os.path.dirname(__file__)).parent.parent 9 | # %% 10 | # read mimic data 11 | diagnosis_data_path = f'{ROOT_PATH}/data/mimic/DIAGNOSES_ICD.csv' # this is for diseace icd9 code 12 | patient_demographics_data_path = f'{ROOT_PATH}/data/kg/patient_demographics.tsv' 13 | 14 | diagnosis_columns = [] 15 | diagnosis_icd = [] 16 | with open(diagnosis_data_path, newline='') as f: 17 | diagnosis_icd = list(csv.reader(f)) 18 | diagnosis_columns = diagnosis_icd.pop(0) # remove first row (column names) 19 | 20 | patient_demographics_columns = [] 21 | patient_demographics = [] 22 | with open(patient_demographics_data_path, newline='') as f: 23 | patient_demographics = list(csv.reader(f, delimiter="\t")) 24 | patient_demographics_columns = patient_demographics.pop(0) # remove first row (column names) 25 | 26 | # %% 27 | # create helper patient dictionary 28 | d_patient_demographics = {} 29 | for patient in patient_demographics: 30 | patient_id = patient[0] 31 | gender = patient[1].lower() 32 | age_group = patient[2] 33 | ethnic_group = patient[3] 34 | 35 | d_patient_demographics[patient_id] = { 36 | 'gender': gender, 37 | 'age_group': age_group, 38 | 'ethnic_group': ethnic_group 39 | } 40 | 41 | # iterate through diagnosis and use subject_id (patient id) to create relations with demographics 42 | disease_to_gender = [] 43 | disease_to_agegroup = [] 44 | disease_to_ethnicgroup = [] 45 | 46 | for disease in diagnosis_icd: 47 | # extarct data 48 | patient_id = disease[1] 49 | disease_icd = disease[-1].lower() 50 | demographics = d_patient_demographics[patient_id] 51 | gender = demographics['gender'] 52 | age_group = demographics['age_group'] 53 | ethnic_group = demographics['ethnic_group'] 54 | 55 | # add triples to list 56 | disease_to_gender.append([f'icd9_{disease_icd}', 'disease_to_gender', gender]) # Total: 651047, Unique: 11503 57 | disease_to_agegroup.append([f'icd9_{disease_icd}', 'disease_to_agegroup', age_group]) # Total: 651047, Unique: 21253 58 | disease_to_ethnicgroup.append([f'icd9_{disease_icd}', 'disease_to_ethnicgroup', ethnic_group]) # Total: 651047, Unique: 20784 59 | 60 | # %% 61 | # write unique triples for simple approach, merge all demographics into one file 62 | 63 | # gender 64 | string_disease_to_gender = list(set(['|'.join(d) for d in disease_to_gender])) 65 | unique_disease_to_gender = [d.split('|') for d in string_disease_to_gender] 66 | 67 | # age group 68 | string_disease_to_agegroup = list(set(['|'.join(d) for d in disease_to_agegroup])) 69 | unique_disease_to_agegroup = [d.split('|') for d in string_disease_to_agegroup] 70 | 71 | # ethnic group 72 | string_disease_to_ethnicgroup = list(set(['|'.join(d) for d in disease_to_ethnicgroup])) 73 | unique_disease_to_ethnicgroup = [d.split('|') for d in string_disease_to_ethnicgroup] 74 | 75 | # create final list and shuffle 76 | disease_to_demographics = [] 77 | disease_to_demographics.extend(unique_disease_to_gender) 78 | disease_to_demographics.extend(unique_disease_to_agegroup) 79 | disease_to_demographics.extend(unique_disease_to_ethnicgroup) 80 | random.shuffle(disease_to_demographics) 81 | 82 | # write data 83 | write_disease_to_demographics_path = f'{ROOT_PATH}/data/kg/simple/disease_to_demographics.tsv' 84 | with open(write_disease_to_demographics_path, 'w') as outfile: 85 | writer = csv.writer(outfile, delimiter='\t') 86 | writer.writerows(disease_to_demographics) 87 | 88 | # %% 89 | -------------------------------------------------------------------------------- /scripts/disease_to_procedure.py: -------------------------------------------------------------------------------- 1 | # %% 2 | # import libraries 3 | import os 4 | import csv 5 | import random 6 | from pathlib import Path 7 | from collections import Counter 8 | 9 | ROOT_PATH = Path(os.path.dirname(__file__)).parent.parent 10 | # %% 11 | # read mimic data 12 | disease_data_path = f'{ROOT_PATH}/data/mimic/DIAGNOSES_ICD.csv' # this is for diseace icd9 code 13 | procedure_data_path = f'{ROOT_PATH}/data/mimic/PROCEDURES_ICD.csv' 14 | 15 | diseases_columns = [] 16 | diseases_icd = [] 17 | with open(disease_data_path, newline='') as f: 18 | diseases_icd = list(csv.reader(f)) 19 | diseases_columns = diseases_icd.pop(0) # remove first row (column names) 20 | 21 | procedures_columns = [] 22 | procedures = [] 23 | with open(procedure_data_path, newline='') as f: 24 | procedures = list(csv.reader(f)) 25 | procedures_columns = procedures.pop(0) # remove first row (column names) 26 | # %% 27 | # create diseases dictionary with adm_id as key, list of icd_codes as values 28 | d_disease_icd = {} 29 | for d_icd in diseases_icd: 30 | adm_id = d_icd[2] # get adm_id from diseases_icd table 31 | icd_code = d_icd[4].lower() # get icd_code from diseases_icd table 32 | if adm_id in d_disease_icd: 33 | d_disease_icd[adm_id].append(icd_code) # if adm_id is already in dict append new icd_code 34 | else: 35 | d_disease_icd[adm_id] = [icd_code] # add new adm_id in dictionary with list of icd_code 36 | # %% 37 | # create procedures dictionary with adm_id as key, list of procedures as values 38 | d_procedures = {} 39 | for procedure in procedures: 40 | adm_id = procedure[2] # get adm_id from prescription table 41 | procedure_icd = procedure[-1].lower() 42 | if adm_id in d_procedures: 43 | d_procedures[adm_id].append(procedure_icd) # if adm_id is already in dict append new procedure 44 | else: 45 | d_procedures[adm_id] = [procedure_icd] # add new adm_id in dictionary with list of procedures 46 | #%% 47 | # Total triples: 61472226 48 | d_disease_to_procedure = {} # we calculate triple co-occurrence and extarct the top 5 49 | for adm_id, icd_codes in d_disease_icd.items(): 50 | if adm_id not in d_procedures: # if there is no procedure for this adm_id skip it 51 | continue 52 | procedures = d_procedures[adm_id] 53 | for disease in icd_codes: 54 | if disease not in d_disease_to_procedure: 55 | d_disease_to_procedure[disease] = [] 56 | for procedure in procedures: 57 | triple = [f'icd9_{disease}', 'disease_to_procedure', f'icd9_{procedure}'] 58 | d_disease_to_procedure[disease].append('|'.join(triple)) # join list for applying counter 59 | # %% 60 | # Since we map all possible diseases with all admission procedures, 61 | # we filter them by considering how many times they have co-occurred 62 | # we select the top k most co-occurred triples 63 | k = 10 64 | disease_to_procedure = [] 65 | for disease, triples in d_disease_to_procedure.items(): 66 | count_triples = Counter(triples) # counter triple co-occurrence 67 | sorted_triples = sorted(count_triples.items(), key=lambda kv: kv[1], reverse=True) # sort dictionary by value co-occurrence 68 | topk_triples = [tup[0] for tup in sorted_triples[:k]] # select top k triples 69 | final_triples = [triple.split('|') for triple in topk_triples] # create triple list from string 70 | disease_to_procedure.extend(final_triples) # add triples to list 71 | 72 | # %% 73 | # write data 74 | write_path = f'{ROOT_PATH}/data/kg/simple/disease_to_procedure.tsv' 75 | random.shuffle(disease_to_procedure) 76 | with open(write_path, 'w') as outfile: 77 | writer = csv.writer(outfile, delimiter='\t') 78 | writer.writerows(disease_to_procedure) 79 | 80 | # %% 81 | -------------------------------------------------------------------------------- /scripts/disease_to_medicine.py: -------------------------------------------------------------------------------- 1 | # %% 2 | # import libraries 3 | import os 4 | import csv 5 | import random 6 | from pathlib import Path 7 | from collections import Counter 8 | 9 | ROOT_PATH = Path(os.path.dirname(__file__)).parent.parent 10 | # %% 11 | # read mimic data 12 | disease_data_path = f'{ROOT_PATH}/data/mimic/DIAGNOSES_ICD.csv' # this is for diseace icd9 code 13 | medicine_data_path = f'{ROOT_PATH}/data/mimic/PRESCRIPTIONS.csv' 14 | 15 | diseases_columns = [] 16 | diseases_icd = [] 17 | with open(disease_data_path, newline='') as f: 18 | diseases_icd = list(csv.reader(f)) 19 | diseases_columns = diseases_icd.pop(0) #remove first row (column names) 20 | 21 | perscriptions_columns = [] 22 | perscriptions = [] 23 | with open(medicine_data_path, newline='') as f: 24 | perscriptions = list(csv.reader(f)) 25 | perscriptions_columns = perscriptions.pop(0) #remove first row (column names) 26 | # %% 27 | # create diseases dictionary with adm_id as key, list of icd_codes as values 28 | d_disease_icd = {} 29 | for d_icd in diseases_icd: 30 | adm_id = d_icd[2] # get adm_id from diseases_icd table 31 | icd_code = d_icd[4].lower() # get icd_code from diseases_icd table 32 | if adm_id in d_disease_icd: 33 | d_disease_icd[adm_id].append(icd_code) # if adm_id is already in dict append new icd_code 34 | else: 35 | d_disease_icd[adm_id] = [icd_code] # add new adm_id in dictionary with list of icd_code 36 | # %% 37 | # create medicines dictionary with adm_id as key, list of medicines as values 38 | d_medicines = {} 39 | for perscription in perscriptions: 40 | adm_id = perscription[2] # get adm_id from prescription table 41 | medicine = '_'.join(perscription[7].lower().split()) # make string lowercase and concatenate words 42 | if adm_id in d_medicines: 43 | d_medicines[adm_id].append(medicine) # if adm_id is already in dict append new medicine 44 | else: 45 | d_medicines[adm_id] = [medicine] # add new adm_id in dictionary with list of medicines 46 | #%% 47 | # Total triples: 61472226 48 | d_disease_to_medicine = {} # we calculate triple co-occurrence and extarct the top 5 49 | for adm_id, icd_codes in d_disease_icd.items(): 50 | if adm_id not in d_medicines: # if there is no medicine for this adm_id skip it 51 | continue 52 | medicines = d_medicines[adm_id] 53 | for disease in icd_codes: 54 | if disease not in d_disease_to_medicine: 55 | d_disease_to_medicine[disease] = [] 56 | for medicine in medicines: 57 | triple = [f'icd9_{disease}', 'disease_to_medicine', medicine] 58 | d_disease_to_medicine[disease].append('|'.join(triple)) # join list for applying counter 59 | # %% 60 | # Since we map all possible diseases with all perscripted medicines, 61 | # we filter them by considering how many times they have co-occurred 62 | # we select the top k most co-occurred triples 63 | k = 10 64 | disease_to_medicine = [] 65 | for disease, triples in d_disease_to_medicine.items(): 66 | count_triples = Counter(triples) # counter triple co-occurrence 67 | sorted_triples = sorted(count_triples.items(), key=lambda kv: kv[1], reverse=True) # sort dictionary by value co-occurrence 68 | topk_triples = [tup[0] for tup in sorted_triples[:k]] # select top k triples 69 | final_triples = [triple.split('|') for triple in topk_triples] # create triple list from string 70 | disease_to_medicine.extend(final_triples) # add triples to list 71 | 72 | # %% 73 | # write data 74 | write_path = f'{ROOT_PATH}/data/kg/simple/disease_to_medicine.tsv' 75 | random.shuffle(disease_to_medicine) 76 | with open(write_path, 'w') as outfile: 77 | writer = csv.writer(outfile, delimiter='\t') 78 | writer.writerows(disease_to_medicine) 79 | 80 | # %% 81 | -------------------------------------------------------------------------------- /data/read_medkg.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import random 3 | import torch 4 | import pandas as pd 5 | from glob import glob 6 | from collections import Counter 7 | from torch.utils.data import Dataset 8 | from sklearn.model_selection import train_test_split 9 | 10 | # import constants 11 | from constants import * 12 | 13 | class ModelData(Dataset): 14 | def __init__(self, raw_data, entity_vocab, relation_vocab, demographic_vocab): 15 | super(ModelData, self).__init__() 16 | 17 | self.triples = torch.LongTensor([[entity_vocab[h], relation_vocab[r], entity_vocab[t]] for h, r, t, _, _ in raw_data]).to(DEVICE) 18 | self.demographics = torch.LongTensor([demographic_vocab[d] for _, _, _, d, _ in raw_data]).to(DEVICE) 19 | self.probabilities = torch.FloatTensor([p for _, _, _, _, p in raw_data]).to(DEVICE) 20 | 21 | self.triples_num = len(self.triples) 22 | 23 | def __len__(self): 24 | return self.triples_num 25 | 26 | def __getitem__(self, item): 27 | return { 28 | TRIPLE: self.triples[item], 29 | DEMOGRAPHIC: self.demographics[item], 30 | PROBABILITY: self.probabilities[item] 31 | } 32 | 33 | class MedicalKG: 34 | def __init__(self): 35 | self.data_path = str(ROOT_PATH) + args.data_path 36 | self.read_data() 37 | self.create_vocabs() 38 | self.create_model_data() 39 | 40 | def read_file_with_pandas(self, path, col_sep='\t', col_names=[HEAD, RELATION, TAIL, DEMOGRAPHIC, PROBABILITY]): 41 | return pd.read_csv(path, 42 | sep=col_sep, 43 | header=None, 44 | names=col_names, 45 | keep_default_na=False, 46 | encoding='utf-8') 47 | 48 | def read_data(self): 49 | # read train data 50 | self.train_raw_data = self.read_file_with_pandas(f'{self.data_path}/train.txt') 51 | 52 | # read validation data 53 | self.val_raw_data = self.read_file_with_pandas(f'{self.data_path}/val.txt') 54 | 55 | # read test data 56 | self.test_raw_data = self.read_file_with_pandas(f'{self.data_path}/test.txt') 57 | 58 | def create_vocabs(self): 59 | # extarct train parts 60 | train_head = Counter(self.train_raw_data[HEAD]) 61 | train_relation = Counter(self.train_raw_data[RELATION]) 62 | train_tail = Counter(self.train_raw_data[TAIL]) 63 | train_demographic = Counter(self.train_raw_data[DEMOGRAPHIC]) 64 | 65 | # extarct val parts 66 | val_head = Counter(self.val_raw_data[HEAD]) 67 | val_relation = Counter(self.val_raw_data[RELATION]) 68 | val_tail = Counter(self.val_raw_data[TAIL]) 69 | val_demographic = Counter(self.val_raw_data[DEMOGRAPHIC]) 70 | 71 | # extarct test parts 72 | test_head = Counter(self.test_raw_data[HEAD]) 73 | test_relation = Counter(self.test_raw_data[RELATION]) 74 | test_tail = Counter(self.test_raw_data[TAIL]) 75 | test_demographic = Counter(self.test_raw_data[DEMOGRAPHIC]) 76 | 77 | # create list with entities and relations 78 | entity_list = list((train_head + val_head + test_head + train_tail + val_tail + test_tail).keys()) 79 | relation_list = list(train_relation.keys()) 80 | demographic_list = list((train_demographic + val_demographic + test_demographic).keys()) 81 | 82 | # create entity and relation vocabularies 83 | self.entity_vocab = {word: i for i, word in enumerate(entity_list)} 84 | self.relation_vocab = {word: i for i, word in enumerate(relation_list)} 85 | self.demographic_vocab = {word: i for i, word in enumerate(demographic_list)} 86 | 87 | def create_model_data(self): 88 | self.train_data = ModelData(self.train_raw_data.values, self.entity_vocab, self.relation_vocab, self.demographic_vocab) 89 | self.val_data = ModelData(self.val_raw_data.values, self.entity_vocab, self.relation_vocab, self.demographic_vocab) 90 | self.test_data = ModelData(self.test_raw_data.values, self.entity_vocab, self.relation_vocab, self.demographic_vocab) 91 | 92 | def get_vocabs(self): 93 | return { 94 | ENTITY: self.entity_vocab, 95 | RELATION: self.relation_vocab, 96 | DEMOGRAPHIC: self.demographic_vocab 97 | } 98 | 99 | def get_data(self): 100 | return self.train_data, self.val_data, self.test_data 101 | -------------------------------------------------------------------------------- /scripts/patient_demographics.py: -------------------------------------------------------------------------------- 1 | # %% 2 | # import libraries 3 | import os 4 | import csv 5 | from pathlib import Path 6 | 7 | ROOT_PATH = Path(os.path.dirname(__file__)).parent.parent 8 | # %% 9 | # read mimic data 10 | # Include demographics (Gender, Date of Birth, Marital status, Ethnicity) 11 | patients_data_path = f'{ROOT_PATH}/data/mimic/PATIENTS.csv' 12 | admission_data_path = f'{ROOT_PATH}/data/mimic/ADMISSIONS.csv' 13 | 14 | patient_columns = [] 15 | patients = [] 16 | with open(patients_data_path, newline='') as f: 17 | patients = list(csv.reader(f)) 18 | patient_columns = patients.pop(0) 19 | 20 | admission_columns = [] 21 | admissions = [] 22 | with open(admission_data_path, newline='') as f: 23 | admissions = list(csv.reader(f)) 24 | admission_columns = admissions.pop(0) 25 | 26 | # %% 27 | # group age 28 | def find_age_group(age): 29 | if age < 18: 30 | return '[0-18)' 31 | elif age >= 18 and age < 48: 32 | return '[18-48)' 33 | elif age >= 48 and age < 60: 34 | return '[48-60)' 35 | elif age >= 60 and age < 70: 36 | return '[60-70)' 37 | elif age >= 70 and age < 80: 38 | return '[70-80)' 39 | elif age >= 80: 40 | return '>=80' 41 | else: 42 | raise ValueError(f'Unkown value of age: {age}.') 43 | 44 | # all ethnicities and races 45 | white = [ 46 | 'WHITE', # 40996 47 | 'WHITE - RUSSIAN', # 164 48 | 'WHITE - OTHER EUROPEAN', # 81 49 | 'WHITE - BRAZILIAN', # 59 50 | 'WHITE - EASTERN EUROPEAN' # 25 51 | ] 52 | 53 | black = [ 54 | 'BLACK/AFRICAN AMERICAN', # 5440 55 | 'BLACK/CAPE VERDEAN', # 200 56 | 'BLACK/HAITIAN', # 101 57 | 'BLACK/AFRICAN', # 44 58 | 'CARIBBEAN ISLAND' # 9 59 | ] 60 | 61 | hispanic = [ 62 | 'HISPANIC OR LATINO', # 1696 63 | 'HISPANIC/LATINO - PUERTO RICAN', # 232 64 | 'HISPANIC/LATINO - DOMINICAN', # 78 65 | 'HISPANIC/LATINO - GUATEMALAN', # 40 66 | 'HISPANIC/LATINO - CUBAN', # 24 67 | 'HISPANIC/LATINO - SALVADORAN', # 19 68 | 'HISPANIC/LATINO - CENTRAL AMERICAN (OTHER)', # 13 69 | 'HISPANIC/LATINO - MEXICAN', # 13 70 | 'HISPANIC/LATINO - COLOMBIAN', # 9 71 | 'HISPANIC/LATINO - HONDURAN' # 4 72 | ] 73 | 74 | asian = [ 75 | 'ASIAN', # 1509 76 | 'ASIAN - CHINESE', # 277 77 | 'ASIAN - ASIAN INDIAN', # 85 78 | 'ASIAN - VIETNAMESE', # 53 79 | 'ASIAN - FILIPINO', # 25 80 | 'ASIAN - CAMBODIAN', # 17 81 | 'ASIAN - OTHER', # 17 82 | 'ASIAN - KOREAN', # 13 83 | 'ASIAN - JAPANESE', # 7 84 | 'ASIAN - THAI', # 4 85 | ] 86 | 87 | native = [ 88 | 'AMERICAN INDIAN/ALASKA NATIVE', # 51 89 | 'AMERICAN INDIAN/ALASKA NATIVE FEDERALLY RECOGNIZED TRIBE' # 3 90 | ] 91 | 92 | unknown = [ 93 | 'UNKNOWN/NOT SPECIFIED', # 4523 94 | 'UNABLE TO OBTAIN', # 814 95 | 'PATIENT DECLINED TO ANSWER' # 559 96 | ] 97 | 98 | other = [ 99 | 'OTHER', # 1512 100 | 'MULTI RACE ETHNICITY', # 130 101 | 'PORTUGUESE', # 61 102 | 'MIDDLE EASTERN', # 43 103 | 'NATIVE HAWAIIAN OR OTHER PACIFIC ISLANDER', # 18 104 | 'SOUTH AMERICAN' # 8 105 | ] 106 | 107 | def find_ethnic_group(ethnicity): 108 | if ethnicity in white: 109 | return 'white' 110 | elif ethnicity in black: 111 | return 'black' 112 | elif ethnicity in hispanic: 113 | return 'hispanic' 114 | elif ethnicity in asian: 115 | return 'asian' 116 | elif ethnicity in native: 117 | return 'native' 118 | elif ethnicity in unknown: 119 | return 'unknown' 120 | elif ethnicity in other: 121 | return 'other' 122 | else: 123 | raise ValueError(f'Unknown value for ethnicity: {ethnicity}') 124 | # %% 125 | # demographics: age, gender, ethnic group 126 | kg_patient_data = [['PATIENT_ID', 'GENDER', 'AGE_GROUP', 'ETHNIC_GROUP']] 127 | 128 | # create patient dictionary with id as key and gender, age as values 129 | patients_dictionary = { 130 | patient[1]: { 131 | 'gender': patient[2], 132 | 'dob_year': patient[3].split()[0].split('-')[0] 133 | } for patient in patients 134 | } 135 | 136 | seen_ids = set() 137 | # create final patient data 138 | for admission in admissions: 139 | # get patient data 140 | pid = admission[1] 141 | gender = patients_dictionary[pid]['gender'] 142 | adm_year = admission[3].split()[0].split('-')[0] 143 | dob_year = patients_dictionary[pid]['dob_year'] 144 | age = int(adm_year) - int(dob_year) 145 | age_group = find_age_group(age) 146 | ethnicity = admission[13] 147 | ethnic_group = find_ethnic_group(ethnicity) 148 | # check if we already have patient demographics 149 | if pid in seen_ids: 150 | continue 151 | 152 | # add patient data 153 | kg_patient_data.append([pid, gender, age_group, ethnic_group]) 154 | seen_ids.add(pid) 155 | 156 | # %% 157 | # write data 158 | write_path = f'{ROOT_PATH}/data/medical_kg/patient_demographics.tsv' 159 | with open(write_path, 'w') as outfile: 160 | writer = csv.writer(outfile, delimiter='\t') 161 | writer.writerows(kg_patient_data) 162 | # %% 163 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import random 5 | import logging 6 | import torch 7 | import numpy as np 8 | import torch.optim 9 | import torch.nn as nn 10 | from pathlib import Path 11 | from data.read_medkg import MedicalKG 12 | from torch.utils.data import DataLoader 13 | from utils import models, AverageMeter, NegativeSampling, RankEvaluator 14 | 15 | # import constants 16 | from constants import * 17 | 18 | # set logger 19 | logging.basicConfig(format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s', 20 | datefmt='%d/%m/%Y %I:%M:%S %p', 21 | level=logging.INFO, 22 | handlers=[ 23 | logging.FileHandler(f'{args.results_path}/train_{args.model}.log', 'w'), 24 | logging.StreamHandler() 25 | ]) 26 | logger = logging.getLogger(__name__) 27 | 28 | # set a seed value 29 | random.seed(args.seed) 30 | np.random.seed(args.seed) 31 | if torch.cuda.is_available(): 32 | torch.manual_seed(args.seed) 33 | torch.cuda.manual_seed(args.seed) 34 | torch.cuda.manual_seed_all(args.seed) 35 | 36 | # set device 37 | torch.cuda.set_device(args.cuda_device) 38 | 39 | def main(): 40 | # load data 41 | data = MedicalKG() 42 | vocabs = data.get_vocabs() 43 | train_data, val_data, _ = data.get_data() 44 | 45 | # load model 46 | model = models[args.model](vocabs).to(DEVICE) 47 | 48 | # define optimizer 49 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 50 | 51 | # define negative sampling method 52 | negative_sampling = NegativeSampling(model.num_entities) 53 | 54 | logger.info(f'The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters') 55 | 56 | if args.resume: 57 | if os.path.isfile(args.resume): 58 | logger.info(f"=> loading checkpoint '{args.resume}''") 59 | checkpoint = torch.load(args.resume) 60 | args.start_epoch = checkpoint[EPOCH] 61 | best_val = checkpoint[BEST_VAL] 62 | model.load_state_dict(checkpoint[STATE_DICT]) 63 | optimizer.load_state_dict(checkpoint[OPTIMIZER]) 64 | logger.info(f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint[EPOCH]})") 65 | else: 66 | logger.info(f"=> no checkpoint found at '{args.resume}'") 67 | best_val = float('+inf') 68 | else: 69 | best_val = float('+inf') 70 | 71 | # prepare training and validation loader 72 | train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True) 73 | val_loader = DataLoader(val_data, batch_size=args.batch_size, shuffle=True) 74 | 75 | # define evaluator 76 | evaluator = RankEvaluator(vocabs) 77 | 78 | logger.info('Loaders prepared.') 79 | logger.info(f"Training data: {len(train_data)}") 80 | logger.info(f"Validation data: {len(val_data)}") 81 | logger.info(f"Unique tokens in entity vocabulary: {len(vocabs[ENTITY])}") 82 | logger.info(f"Unique tokens in relation vocabulary: {len(vocabs[RELATION])}") 83 | if args.demographic_aware: 84 | logger.info(f"Unique tokens in demographic vocabulary: {len(vocabs[DEMOGRAPHIC])}") 85 | logger.info(f'Batch: {args.batch_size}') 86 | logger.info(f'Epochs: {args.epochs}') 87 | 88 | # run epochs 89 | for epoch in range(args.start_epoch, args.epochs): 90 | # train for one epoch 91 | train(train_loader, model, optimizer, negative_sampling, epoch) 92 | 93 | # evaluate on validation set 94 | if (epoch+1) % args.valfreq == 0: 95 | results = evaluator.evaluate(val_loader, model) 96 | if results[MR] < best_val: 97 | best_val = min(results[MR], best_val) 98 | state = { 99 | EPOCH: epoch + 1, 100 | STATE_DICT: model.state_dict(), 101 | BEST_VAL: best_val, 102 | OPTIMIZER: optimizer.state_dict() 103 | } 104 | torch.save(state, f'{ROOT_PATH}/{args.snapshots}/{args.model}/e{state[EPOCH]}_v{state[BEST_VAL]:.4f}.pth.tar') 105 | # log results 106 | logger.info(f'''Val results - Epoch: {epoch+1} 107 | \t\t\t\t Hit@1: {results[HITS_AT_1]:.4f} 108 | \t\t\t\t Hit@3: {results[HITS_AT_3]:.4f} 109 | \t\t\t\t Hit@10: {results[HITS_AT_10]:.4f} 110 | \t\t\t\t Mean Rank: {results[MR]:.4f} 111 | \t\t\t\t Mean Reciprocal Rank: {results[MRR]:.4f}''') 112 | 113 | def train(train_loader, model, optimizer, negative_sampling, epoch): 114 | batch_time = AverageMeter() 115 | losses = AverageMeter() 116 | 117 | # switch to train mode 118 | model.train() 119 | 120 | end = time.time() 121 | for i, pos in enumerate(train_loader): 122 | # sample negative 123 | neg = negative_sampling(pos) 124 | 125 | # compute output 126 | output = model(pos, neg) 127 | 128 | # record loss 129 | losses.update(output[LOSS].data, pos[TRIPLE].size(0)) 130 | 131 | # compute gradient and do Adam step 132 | optimizer.zero_grad() 133 | output[LOSS].backward() 134 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 135 | optimizer.step() 136 | 137 | # measure elapsed time 138 | batch_time.update(time.time() - end) 139 | end = time.time() 140 | 141 | logger.info(f'Epoch: {epoch+1} - Train loss: {losses.val:.4f} ({losses.avg:.4f}) - Batch: {((i+1)/len(train_loader))*100:.2f}% - Time: {batch_time.sum:0.2f}s') 142 | 143 | if __name__ == '__main__': 144 | main() -------------------------------------------------------------------------------- /scripts/use_case.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import csv 4 | import time 5 | import random 6 | import logging 7 | import torch 8 | import numpy as np 9 | from pathlib import Path 10 | from ..data.read_medkg import MedicalKG 11 | from torch.utils.data import DataLoader 12 | from ..utils import models, AverageMeter, RankEvaluator 13 | 14 | # import constants 15 | from ..constants import * 16 | 17 | # set logger 18 | logging.basicConfig(format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s', 19 | datefmt='%d/%m/%Y %I:%M:%S %p', 20 | level=logging.INFO, 21 | handlers=[ 22 | logging.FileHandler(f'{args.results_path}/usecase_{args.model}_{args.task}.log', 'w'), 23 | logging.StreamHandler() 24 | ]) 25 | logger = logging.getLogger(__name__) 26 | 27 | # set a seed value 28 | random.seed(args.seed) 29 | np.random.seed(args.seed) 30 | if torch.cuda.is_available(): 31 | torch.manual_seed(args.seed) 32 | torch.cuda.manual_seed(args.seed) 33 | torch.cuda.manual_seed_all(args.seed) 34 | 35 | # set device 36 | torch.cuda.set_device(args.cuda_device) 37 | 38 | # load test data and prepare loader 39 | data = MedicalKG() 40 | vocabs = data.get_vocabs() 41 | _, _, test_data = data.get_data() 42 | test_loader = DataLoader(test_data, batch_size=1, shuffle=True) 43 | 44 | # load model 45 | model = models[args.model](vocabs).to(DEVICE) 46 | 47 | logger.info(f"=> loading checkpoint '{args.checkpoint_path}'") 48 | if DEVICE.type=='cpu': 49 | checkpoint = torch.load(f'{ROOT_PATH}/{args.checkpoint_path}', encoding='latin1', map_location='cpu') 50 | else: 51 | checkpoint = torch.load(f'{ROOT_PATH}/{args.checkpoint_path}', encoding='latin1') 52 | model.load_state_dict(checkpoint['state_dict']) 53 | logger.info(f"=> loaded checkpoint '{args.checkpoint_path}' (epoch {checkpoint['epoch']})") 54 | 55 | # mimic read data 56 | diagnosis_data_path = f'{ROOT_PATH}/data/mimic/D_ICD_DIAGNOSES.csv' 57 | diagnosis_dict = {} 58 | with open(diagnosis_data_path, newline='') as f: 59 | for d in list(csv.reader(f))[1:]: 60 | diagnosis_dict[f'd_{d[1].lower()}'] = d[2] 61 | 62 | procedures_data_path = f'{ROOT_PATH}/data/mimic/D_ICD_PROCEDURES.csv' 63 | procedures_dict = {} 64 | with open(procedures_data_path, newline='') as f: 65 | for d in list(csv.reader(f))[1:]: 66 | procedures_dict[f'p_{d[1].lower()}'] = d[2] 67 | 68 | if args.task not in [TREATMENT_RECOMMENDATION, MEDICINE_RECOMMENDATION]: 69 | raise ValueError(f'Argument task must have value {TREATMENT_RECOMMENDATION} or {MEDICINE_RECOMMENDATION}.') 70 | 71 | evaluator = RankEvaluator(vocabs) 72 | 73 | k = 3 74 | 75 | # switch to evaluate mode 76 | model.eval() 77 | 78 | entity_ids = torch.arange(end=len(vocabs[ENTITY])).to(DEVICE) 79 | entity_names = list(vocabs[ENTITY].keys()) 80 | task_names = [entity_names[ti] for ti in evaluator.task_ids[args.task]] 81 | with torch.no_grad(): 82 | for _, data in enumerate(test_loader): 83 | val_triples = data[TRIPLE] 84 | 85 | # get batch size 86 | batch_size = val_triples.shape[0] 87 | 88 | all_entities = entity_ids.repeat(batch_size, 1) 89 | 90 | head, relation, tail = val_triples[:, 0], val_triples[:, 1], val_triples[:, 2] 91 | 92 | if args.task == TREATMENT_RECOMMENDATION and not entity_names[tail].startswith('p_'): 93 | continue 94 | if args.task == MEDICINE_RECOMMENDATION and (entity_names[tail].startswith('p_') or entity_names[tail].startswith('d_')): 95 | continue 96 | 97 | # exapnd for all entities 98 | expanded_heads = head.reshape(-1, 1).repeat(1, all_entities.size()[1]) 99 | expanded_relations = relation.reshape(-1, 1).repeat(1, all_entities.size()[1]) 100 | 101 | expanded_triples = torch.stack((expanded_heads, expanded_relations, all_entities), dim=2).reshape(-1, val_triples.shape[1]) 102 | model_data = {TRIPLE: expanded_triples} 103 | 104 | if args.demographic_aware: 105 | expanded_demographics = data[DEMOGRAPHIC].reshape(-1, 1).repeat(1, all_entities.size()[1]).reshape(-1, 1).squeeze() 106 | model_data.update({DEMOGRAPHIC: expanded_demographics}) 107 | 108 | if args.prob_embedding: 109 | expanded_probabilities = data[PROBABILITY].reshape(-1, 1).repeat(1, all_entities.size()[1]).reshape(-1, 1).squeeze() 110 | model_data.update({PROBABILITY: expanded_probabilities}) 111 | 112 | predicted, actual = evaluator._filter_entities(model.predict(model_data).reshape(batch_size, -1), tail, args.task) 113 | 114 | if evaluator._hits_at_k(predicted, actual, k=k) == 1: 115 | top_predictions = predicted.topk(k=k, largest=False)[1][0] 116 | if entity_names[head] in diagnosis_dict: 117 | head_name = diagnosis_dict[entity_names[head]] 118 | tail_name = entity_names[tail] if args.task == MEDICINE_RECOMMENDATION else procedures_dict[entity_names[tail]] 119 | prediction_name_1 = task_names[top_predictions[0]] if args.task == MEDICINE_RECOMMENDATION else procedures_dict[task_names[top_predictions[0]]] 120 | prediction_name_2 = task_names[top_predictions[1]] if args.task == MEDICINE_RECOMMENDATION else procedures_dict[task_names[top_predictions[1]]] 121 | prediction_name_3 = task_names[top_predictions[1]] if args.task == MEDICINE_RECOMMENDATION else procedures_dict[task_names[top_predictions[2]]] 122 | assert tail_name in [prediction_name_1, prediction_name_2, prediction_name_3] 123 | rank = [prediction_name_1, prediction_name_2, prediction_name_3].index(tail_name) + 1 124 | logger.info(f'{head_name} -> {tail_name} | [{prediction_name_1}, {prediction_name_2}, {prediction_name_3}], Rank: {rank}') 125 | -------------------------------------------------------------------------------- /models/transx.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from models.base import KGEBase 6 | from constants import * 7 | 8 | class TransE(KGEBase): 9 | def __init__(self, vocabs, dim=args.emb_dim, d_norm=args.d_norm, gamma=args.gamma, target=args.target): 10 | super(TransE, self).__init__(vocabs) 11 | 12 | # create embedding layers 13 | self.entity_embedding = self._init_embedding(self.num_entities) 14 | self.relation_embedding = self._init_embedding(self.num_relations) 15 | 16 | def _init_embedding(self, num): 17 | emb = nn.Embedding(num, self.dim) 18 | uniform_range = 6 / math.sqrt(self.dim) 19 | emb.weight.data.uniform_(-uniform_range, uniform_range) 20 | 21 | # e <= e / ||e|| 22 | emb.weight.data = emb.weight.data / torch.norm(emb.weight.data, dim=1, keepdim=True) 23 | 24 | return emb 25 | 26 | def _distance(self, data): 27 | assert data[TRIPLE].size()[1] == 3 28 | 29 | heads, relations, tails = data[TRIPLE][:, 0], data[TRIPLE][:, 1], data[TRIPLE][:, 2] # head, relation, tail: [batch_size] 30 | distance = self.entity_embedding(heads) + self.relation_embedding(relations) - self.entity_embedding(tails) 31 | 32 | return torch.norm(distance, p=self.d_norm, dim=1) 33 | 34 | 35 | class TransH(KGEBase): 36 | def __init__(self, vocabs, dim=args.emb_dim, d_norm=args.d_norm, gamma=args.gamma, target=args.target): 37 | super(TransH, self).__init__(vocabs) 38 | 39 | # create embedding layers 40 | self.entity_embedding = self._init_embedding(self.num_entities) 41 | self.relation_embedding = self._init_embedding(self.num_relations) 42 | self.normal_embedding = self._init_embedding(self.num_relations) 43 | 44 | def _init_embedding(self, num): 45 | weight = torch.FloatTensor(num, self.dim) 46 | nn.init.xavier_uniform_(weight) 47 | embeddings = nn.Embedding(num, self.dim) 48 | embeddings.weight = nn.Parameter(weight) 49 | embeddings.weight.data = F.normalize(embeddings.weight.data, p=2, dim=1) 50 | 51 | return embeddings 52 | 53 | def _distance(self, data): 54 | assert data[TRIPLE].size()[1] == 3 55 | 56 | heads, relations, tails = data[TRIPLE][:, 0], data[TRIPLE][:, 1], data[TRIPLE][:, 2] # head, relation, tail: [batch_size] 57 | 58 | w_r = self.normal_embedding(relations) 59 | 60 | h = self.entity_embedding(heads) 61 | d_r = self.relation_embedding(relations) 62 | t = self.entity_embedding(tails) 63 | 64 | # project to hyperplane 65 | h_r = h - torch.sum(h * w_r, dim=1, keepdim=True) * w_r 66 | t_r = t - torch.sum(t * w_r, dim=1, keepdim=True) * w_r 67 | 68 | distance = h_r + d_r - t_r 69 | 70 | return torch.norm(distance, p=self.d_norm, dim=1) 71 | 72 | 73 | class TransR(KGEBase): 74 | def __init__(self, vocabs, dim=args.emb_dim, d_norm=args.d_norm, gamma=args.gamma, target=args.target): 75 | super(TransR, self).__init__(vocabs) 76 | 77 | # create embedding layers 78 | self.entity_embedding = self._init_embedding(self.num_entities) 79 | self.relation_embedding = self._init_embedding(self.num_relations) 80 | self.projection_embedding = self._init_embedding(self.num_relations, self.dim * self.dim) 81 | 82 | def _init_embedding(self, num, dim=args.emb_dim): 83 | weight = torch.FloatTensor(num, dim) 84 | nn.init.xavier_uniform_(weight) 85 | embeddings = nn.Embedding(num, dim) 86 | embeddings.weight = nn.Parameter(weight) 87 | embeddings.weight.data = F.normalize(embeddings.weight.data, p=2, dim=1) 88 | 89 | return embeddings 90 | 91 | def _distance(self, data): 92 | assert data[TRIPLE].size()[1] == 3 93 | 94 | heads, relations, tails = data[TRIPLE][:, 0], data[TRIPLE][:, 1], data[TRIPLE][:, 2] # head, relation, tail: [batch_size] 95 | 96 | projection_matrix = self.projection_embedding(relations).view(-1, self.dim, self.dim) 97 | 98 | h = self.entity_embedding(heads) 99 | r = self.relation_embedding(relations) 100 | t = self.entity_embedding(tails) 101 | 102 | h_r = torch.matmul(projection_matrix, h.unsqueeze(-1)).squeeze(-1) 103 | t_r = torch.matmul(projection_matrix, t.unsqueeze(-1)).squeeze(-1) 104 | 105 | distance = h_r + r - t_r 106 | 107 | return torch.norm(distance, p=self.d_norm, dim=1) 108 | 109 | 110 | class TransD(KGEBase): 111 | def __init__(self, vocabs, dim=args.emb_dim, d_norm=args.d_norm, gamma=args.gamma, target=args.target): 112 | super(TransD, self).__init__(vocabs) 113 | 114 | # create embedding layers 115 | self.entity_embedding = self._init_embedding(self.num_entities) 116 | self.entity_projection = self._init_embedding(self.num_entities) 117 | self.relation_embedding = self._init_embedding(self.num_relations) 118 | self.relation_projection = self._init_embedding(self.num_relations) 119 | 120 | def _init_embedding(self, num, dim=args.emb_dim): 121 | weight = torch.FloatTensor(num, dim) 122 | nn.init.xavier_uniform_(weight) 123 | embeddings = nn.Embedding(num, dim) 124 | embeddings.weight = nn.Parameter(weight) 125 | embeddings.weight.data = F.normalize(embeddings.weight.data, p=2, dim=1) 126 | 127 | return embeddings 128 | 129 | def _distance(self, data): 130 | assert data[TRIPLE].size()[1] == 3 131 | 132 | heads, relations, tails = data[TRIPLE][:, 0], data[TRIPLE][:, 1], data[TRIPLE][:, 2] # head, relation, tail: [batch_size] 133 | 134 | h = self.entity_embedding(heads) 135 | r = self.relation_embedding(relations) 136 | t = self.entity_embedding(tails) 137 | 138 | h_p = self.entity_projection(heads) 139 | r_p = self.relation_projection(relations) 140 | t_p = self.entity_projection(tails) 141 | 142 | h_r = h + torch.sum(h_p * h, dim=-1, keepdim=True) * r_p 143 | t_r = t + torch.sum(t_p * t, dim=-1, keepdim=True) * r_p 144 | 145 | distance = h_r + r - t_r 146 | 147 | return torch.norm(distance, p=self.d_norm, dim=1) -------------------------------------------------------------------------------- /scripts/patient_demographics_stats.py: -------------------------------------------------------------------------------- 1 | # %% 2 | # import libraries 3 | import os 4 | import csv 5 | from pathlib import Path 6 | from datetime import datetime 7 | from collections import Counter 8 | 9 | ROOT_PATH = Path(os.path.dirname(__file__)).parent 10 | # %% 11 | # read mimic data 12 | # Include demographics (Gender, Date of Birth, Ethnicity) 13 | patients_data_path = f'{ROOT_PATH}/data/mimic/PATIENTS.csv' 14 | patient_columns = [] 15 | patients = [] 16 | with open(patients_data_path, newline='') as f: 17 | patients = list(csv.reader(f)) 18 | patient_columns = patients.pop(0) 19 | 20 | admission_data_path = f'{ROOT_PATH}/data/mimic/ADMISSIONS.csv' 21 | admission_columns = [] 22 | admissions = [] 23 | with open(admission_data_path, newline='') as f: 24 | admissions = list(csv.reader(f)) 25 | admission_columns = admissions.pop(0) 26 | # %% 27 | # age stats - based on number of admissions 28 | age_year = [] 29 | 30 | # create patient dictionary 31 | patients_year = {patient[1]: patient[3].split()[0].split('-')[0] for patient in patients} 32 | 33 | # How to quantize age values??? 34 | quantize_counter_admission = { 35 | '0': 0, # [0-18) 36 | '1': 0, # [18-48) 37 | '2': 0, # [48-60) 38 | '3': 0, # [60-70) 39 | '4': 0, # [70-80) 40 | '5': 0 # >=0 41 | } 42 | 43 | quantize_counter_patients = { 44 | '0': 0, # [0-18) 45 | '1': 0, # [18-48) 46 | '2': 0, # [48-60) 47 | '3': 0, # [60-70) 48 | '4': 0, # [70-80) 49 | '5': 0 # >=0 50 | } 51 | 52 | seen_patients = set() 53 | 54 | for adm in admissions: 55 | patient_id = adm[1] 56 | adm_year = adm[3].split()[0].split('-')[0] 57 | dob_year = patients_year[patient_id] 58 | age_year.append(int(adm_year) - int(dob_year)) 59 | if age_year[-1] < 18: 60 | quantize_counter_admission['0'] += 1 61 | if patient_id not in seen_patients: quantize_counter_patients['0'] += 1 62 | elif age_year[-1] >= 18 and age_year[-1] < 48: 63 | quantize_counter_admission['1'] += 1 64 | if patient_id not in seen_patients: quantize_counter_patients['1'] += 1 65 | elif age_year[-1] >= 48 and age_year[-1] < 60: 66 | quantize_counter_admission['2'] += 1 67 | if patient_id not in seen_patients: quantize_counter_patients['2'] += 1 68 | elif age_year[-1] >= 60 and age_year[-1] < 70: 69 | quantize_counter_admission['3'] += 1 70 | if patient_id not in seen_patients: quantize_counter_patients['3'] += 1 71 | elif age_year[-1] >= 70 and age_year[-1] < 80: 72 | quantize_counter_admission['4'] += 1 73 | if patient_id not in seen_patients: quantize_counter_patients['4'] += 1 74 | elif age_year[-1] >= 80: 75 | quantize_counter_admission['5'] += 1 76 | if patient_id not in seen_patients: quantize_counter_patients['5'] += 1 77 | seen_patients.add(patient_id) 78 | 79 | min_year = min(age_year) 80 | max_year = max(age_year) 81 | unique_years = set(age_year) 82 | # counter_age = Counter(age_year) 83 | 84 | print(f'Min year: {min_year}') 85 | print(f'Max year: {max_year}') 86 | print(f'Total unique years: {len(unique_years)}') 87 | # print(f'Counter year: {counter_age}') 88 | # print(f'Unique years: {sorted(list(unique_years))}') 89 | print(quantize_counter_admission) # {'0': 8180, '1': 8894, '2': 10050, '3': 10618, '4': 10474, '5': 10760} 90 | 91 | print(quantize_counter_patients) # {'0': 7942, '1': 7005, '2': 7515, '3': 7860, '4': 7939, '5': 8259} 92 | # %% 93 | # gender stats - based on number of admissions 94 | genders = [] 95 | 96 | # create patient dictionary 97 | patients_gender = {patient[1]: patient[2] for patient in patients} 98 | 99 | for adm in admissions: 100 | patient_id = adm[1] 101 | gender = patients_gender[patient_id] 102 | genders.append(gender) 103 | 104 | counter_gender_admission = Counter(genders) 105 | print(counter_gender_admission) # {'M': 32950, 'F': 26026} 106 | 107 | counter_gender_patients = Counter(patients_gender.values()) 108 | print(counter_gender_patients) # {'M': 26121, 'F': 20399} 109 | 110 | # %% 111 | # all ethnicities and races 112 | white = [ 113 | 'WHITE', # 40996 114 | 'WHITE - RUSSIAN', # 164 115 | 'WHITE - OTHER EUROPEAN', # 81 116 | 'WHITE - BRAZILIAN', # 59 117 | 'WHITE - EASTERN EUROPEAN' # 25 118 | ] 119 | 120 | black = [ 121 | 'BLACK/AFRICAN AMERICAN', # 5440 122 | 'BLACK/CAPE VERDEAN', # 200 123 | 'BLACK/HAITIAN', # 101 124 | 'BLACK/AFRICAN', # 44 125 | 'CARIBBEAN ISLAND' # 9 126 | ] 127 | 128 | hispanic = [ 129 | 'HISPANIC OR LATINO', # 1696 130 | 'HISPANIC/LATINO - PUERTO RICAN', # 232 131 | 'HISPANIC/LATINO - DOMINICAN', # 78 132 | 'HISPANIC/LATINO - GUATEMALAN', # 40 133 | 'HISPANIC/LATINO - CUBAN', # 24 134 | 'HISPANIC/LATINO - SALVADORAN', # 19 135 | 'HISPANIC/LATINO - CENTRAL AMERICAN (OTHER)', # 13 136 | 'HISPANIC/LATINO - MEXICAN', # 13 137 | 'HISPANIC/LATINO - COLOMBIAN', # 9 138 | 'HISPANIC/LATINO - HONDURAN' # 4 139 | ] 140 | 141 | asian = [ 142 | 'ASIAN', # 1509 143 | 'ASIAN - CHINESE', # 277 144 | 'ASIAN - ASIAN INDIAN', # 85 145 | 'ASIAN - VIETNAMESE', # 53 146 | 'ASIAN - FILIPINO', # 25 147 | 'ASIAN - CAMBODIAN', # 17 148 | 'ASIAN - OTHER', # 17 149 | 'ASIAN - KOREAN', # 13 150 | 'ASIAN - JAPANESE', # 7 151 | 'ASIAN - THAI', # 4 152 | ] 153 | 154 | native = [ 155 | 'AMERICAN INDIAN/ALASKA NATIVE', # 51 156 | 'AMERICAN INDIAN/ALASKA NATIVE FEDERALLY RECOGNIZED TRIBE' # 3 157 | ] 158 | 159 | unknown = [ 160 | 'UNKNOWN/NOT SPECIFIED', # 4523 161 | 'UNABLE TO OBTAIN', # 814 162 | 'PATIENT DECLINED TO ANSWER' # 559 163 | ] 164 | 165 | other = [ 166 | 'OTHER', # 1512 167 | 'MULTI RACE ETHNICITY', # 130 168 | 'PORTUGUESE', # 61 169 | 'MIDDLE EASTERN', # 43 170 | 'NATIVE HAWAIIAN OR OTHER PACIFIC ISLANDER', # 18 171 | 'SOUTH AMERICAN' # 8 172 | ] 173 | 174 | def find_ethnic_group(ethnicity): 175 | if ethnicity in white: 176 | return 'white' 177 | elif ethnicity in black: 178 | return 'black' 179 | elif ethnicity in hispanic: 180 | return 'hispanic' 181 | elif ethnicity in asian: 182 | return 'asian' 183 | elif ethnicity in native: 184 | return 'native' 185 | elif ethnicity in unknown: 186 | return 'unknown' 187 | elif ethnicity in other: 188 | return 'other' 189 | else: 190 | raise ValueError(f'Unknown value for ethnicity: {ethnicity}') 191 | # %% 192 | # ethnic stats - based on number of admissions 193 | ethnicities = [] 194 | 195 | # create patient dictionary 196 | patients_ethnic = {patient[1]: '' for patient in patients} 197 | 198 | for adm in admissions: 199 | patient_id = adm[1] 200 | ethnicity = adm[13] 201 | ethnic_group = find_ethnic_group(ethnicity) 202 | patients_ethnic[patient_id] = ethnic_group 203 | ethnicities.append(ethnic_group) 204 | 205 | counter_ethnicity_admission = Counter(ethnicities) 206 | print(counter_ethnicity_admission) # {'white': 41325, 'unknown': 5896, 'black': 5794, 'hispanic': 2128, 'asian': 2007, 'other': 1772, 'native': 54} 207 | 208 | counter_ethnicity_patients = Counter(patients_ethnic.values()) 209 | print(counter_ethnicity_patients) # {'white': 32372, 'unknown': 5410, 'black': 3871, 'asian': 1690, 'hispanic': 1642, 'other': 1489, 'native': 46} 210 | # %% 211 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from pathlib import Path 5 | from models.transx import TransE, TransH, TransR, TransD 6 | from models.prtransx import PrTransE, PrTransH 7 | from models.darling import DARLING 8 | 9 | # import constants 10 | from constants import * 11 | 12 | # define models 13 | models = { 14 | TRANSE: TransE, 15 | TRANSH: TransH, 16 | TRANSR: TransR, 17 | TRANSD: TransD, 18 | PRTRANSE: PrTransE, 19 | PRTRANSH: PrTransH, 20 | DARLIN: DARLING 21 | } 22 | 23 | # meter class for storing results 24 | class AverageMeter(object): 25 | """Computes and stores the average and current value""" 26 | def __init__(self): 27 | self.reset() 28 | 29 | def reset(self): 30 | self.val = 0 31 | self.avg = 0 32 | self.sum = 0 33 | self.count = 0 34 | 35 | def update(self, val, n=1): 36 | self.val = val / n 37 | self.sum += val 38 | self.count += n 39 | self.avg = self.sum / self.count 40 | 41 | class NegativeSampling: 42 | def __init__(self, num_entities): 43 | self.num_entities = num_entities 44 | self.demographic_aware = args.demographic_aware 45 | self.prob_embedding = args.prob_embedding 46 | 47 | def __call__(self, pos): 48 | if args.demographic_aware or args.prob_embedding: 49 | return { 50 | **self.uniform(pos), 51 | **{ 52 | DEMOGRAPHIC: pos[DEMOGRAPHIC], 53 | PROBABILITY: torch.FloatTensor([args.negative_prob] * pos[TRIPLE].shape[0]).to(DEVICE) 54 | } 55 | } 56 | else: 57 | return self.uniform(pos) 58 | 59 | def uniform(self, pos): 60 | pos_heads, pos_relations, pos_tails = pos[TRIPLE][:, 0], pos[TRIPLE][:, 1], pos[TRIPLE][:, 2] 61 | 62 | replace = torch.randint(high=2, size=pos_heads.size()).to(DEVICE) # 1 for head, 0 for tail 63 | 64 | random_entities = torch.randint(high=self.num_entities, size=pos_heads.size()).to(DEVICE) 65 | 66 | neg_heads = torch.where(replace == 1, random_entities, pos_heads) 67 | neg_tails = torch.where(replace == 0, random_entities, pos_tails) 68 | 69 | return { 70 | TRIPLE: torch.stack((neg_heads, pos_relations, neg_tails), dim=1).to(DEVICE) 71 | } 72 | 73 | class RankEvaluator: 74 | def __init__(self, vocabs): 75 | self.vocabs = vocabs 76 | self.task_ids = { 77 | TREATMENT_RECOMMENDATION: [v for k, v in self.vocabs[ENTITY].items() if k.startswith('p_')], 78 | MEDICINE_RECOMMENDATION: [v for k, v in self.vocabs[ENTITY].items() if not k.startswith('p_') and not k.startswith('d_')] 79 | } 80 | 81 | def _hits_at_k(self, predicted, actual, k=10): 82 | return torch.sum(torch.eq(predicted.topk(k=k, largest=False)[1], actual.unsqueeze(1))).item() / actual.size(0) 83 | 84 | def _mean_rank(self, predicted, actual): 85 | return torch.sum(torch.eq(predicted.argsort(), actual.unsqueeze(1)).nonzero()[:, 1].float().add(1.0)) / actual.size(0) 86 | 87 | def _mean_reciprocal_rank(self, predicted, actual): 88 | return torch.sum((1.0 / torch.eq(predicted.argsort(), actual.unsqueeze(1)).nonzero()[:, 1].float().add(1.0))).item() / actual.size(0) 89 | 90 | def _filter_entities(self, predicted, actual, task): 91 | filter_predicted, filter_actual = [], [] 92 | entity_vocab_list = list(self.vocabs[ENTITY].keys()) 93 | for p, a in zip(predicted.tolist(), actual.tolist()): 94 | if task == TREATMENT_RECOMMENDATION and not entity_vocab_list[a].startswith('p_'): 95 | continue 96 | if task == MEDICINE_RECOMMENDATION and (entity_vocab_list[a].startswith('p_') or entity_vocab_list[a].startswith('d_')): 97 | continue 98 | fp = [pv for i, pv in enumerate(p) if i in self.task_ids[task]] 99 | 100 | assert len(fp) == len(self.task_ids[task]) 101 | 102 | filter_predicted.append(torch.FloatTensor(fp)) 103 | filter_actual.append(fp.index(p[a])) 104 | 105 | return torch.stack(filter_predicted).to(DEVICE), torch.LongTensor(filter_actual).to(DEVICE) 106 | 107 | def _filter_medicines(self, predicted, actual): 108 | return None 109 | 110 | def _rank(self, predicted, actual, task): 111 | if task in [TREATMENT_RECOMMENDATION, MEDICINE_RECOMMENDATION]: 112 | predicted, actual = self._filter_entities(predicted, actual, task) 113 | assert predicted.size(0) == actual.size(0) 114 | 115 | self.metrics[HITS_AT_1].update(self._hits_at_k(predicted, actual, 1)) 116 | self.metrics[HITS_AT_3].update(self._hits_at_k(predicted, actual, 3)) 117 | self.metrics[HITS_AT_10].update(self._hits_at_k(predicted, actual, 10)) 118 | self.metrics[MR].update(self._mean_rank(predicted, actual)) 119 | self.metrics[MRR].update(self._mean_reciprocal_rank(predicted, actual)) 120 | 121 | def _results(self): 122 | return { 123 | HITS_AT_1: self.metrics[HITS_AT_1].avg, 124 | HITS_AT_3: self.metrics[HITS_AT_3].avg, 125 | HITS_AT_10: self.metrics[HITS_AT_10].avg, 126 | MR: self.metrics[MR].avg, 127 | MRR: self.metrics[MRR].avg 128 | } 129 | 130 | def _reset(self): 131 | self.metrics = { 132 | HITS_AT_1: AverageMeter(), 133 | HITS_AT_3: AverageMeter(), 134 | HITS_AT_10: AverageMeter(), 135 | MR: AverageMeter(), 136 | MRR: AverageMeter() 137 | } 138 | 139 | def evaluate(self, data_loader, model, task=args.task): 140 | # reset metrics 141 | self._reset() 142 | 143 | # switch to evaluate mode 144 | model.eval() 145 | 146 | entity_ids = torch.arange(end=len(self.vocabs[ENTITY])).to(DEVICE) 147 | with torch.no_grad(): 148 | for _, data in enumerate(data_loader): 149 | val_triples = data[TRIPLE] 150 | 151 | # get batch size 152 | batch_size = val_triples.shape[0] 153 | 154 | all_entities = entity_ids.repeat(batch_size, 1) 155 | 156 | heads, relations, tails = val_triples[:, 0], val_triples[:, 1], val_triples[:, 2] 157 | 158 | # exapnd for all entities 159 | expanded_heads = heads.reshape(-1, 1).repeat(1, all_entities.size()[1]) 160 | expanded_relations = relations.reshape(-1, 1).repeat(1, all_entities.size()[1]) 161 | 162 | expanded_triples = torch.stack((expanded_heads, expanded_relations, all_entities), dim=2).reshape(-1, val_triples.shape[1]) 163 | 164 | if args.demographic_aware: 165 | expanded_demographics = data[DEMOGRAPHIC].reshape(-1, 1).repeat(1, all_entities.size()[1]).reshape(-1, 1).squeeze() 166 | 167 | if args.prob_embedding: 168 | expanded_probabilities = data[PROBABILITY].reshape(-1, 1).repeat(1, all_entities.size()[1]).reshape(-1, 1).squeeze() 169 | 170 | # chunk data and predict results 171 | predicted_tails = [] 172 | for i in range(0, len(expanded_triples), batch_size**2): 173 | model_data = {TRIPLE: expanded_triples[i:i + batch_size**2]} 174 | 175 | if args.demographic_aware: 176 | model_data.update({DEMOGRAPHIC: expanded_demographics[i:i + batch_size**2]}) 177 | 178 | if args.prob_embedding: 179 | model_data.update({PROBABILITY: expanded_probabilities[i:i + batch_size**2]}) 180 | 181 | predicted_tails.append(model.predict(model_data)) 182 | 183 | predicted_tails = torch.cat(predicted_tails, dim=0).reshape(batch_size, -1) 184 | 185 | # rank results 186 | self._rank(predicted_tails, tails, task) 187 | 188 | return self._results() 189 | -------------------------------------------------------------------------------- /scripts/medical_kg_with_demo_triples_all.py: -------------------------------------------------------------------------------- 1 | # %% 2 | # import libraries 3 | import os 4 | import csv 5 | import random 6 | from pathlib import Path 7 | from collections import Counter 8 | 9 | ROOT_PATH = Path(os.path.dirname(__file__)).parent 10 | # %% 11 | # read patient demographics 12 | patient_demographics_data_path = f'{ROOT_PATH}/data/medical_kg/patient_demographics.tsv' 13 | patient_demographics = [] 14 | with open(patient_demographics_data_path, newline='') as f: 15 | patient_demographics = list(csv.reader(f, delimiter="\t")) 16 | patient_demographics_columns = patient_demographics.pop(0) # remove first row (column names) 17 | 18 | # helper patient dictionary 19 | d_patient_demographics = {} 20 | for patient in patient_demographics: 21 | patient_id = patient[0] 22 | gender = patient[1].lower() 23 | age_group = patient[2] 24 | ethnic_group = patient[3] 25 | 26 | d_patient_demographics[patient_id] = { 27 | 'gender': gender, 28 | 'age_group': age_group, 29 | 'ethnic_group': ethnic_group 30 | } 31 | # %% 32 | # read mimic data 33 | diagnosis_data_path = f'{ROOT_PATH}/data/mimic/DIAGNOSES_ICD.csv' # this is for diseace icd9 code 34 | diagnoses = [] 35 | with open(diagnosis_data_path, newline='') as f: 36 | diagnoses = list(csv.reader(f)) 37 | diagnosis_columns = diagnoses.pop(0) # remove first row (column names) 38 | 39 | perscriptions_data_path = f'{ROOT_PATH}/data/mimic/PRESCRIPTIONS.csv' 40 | perscriptions = [] 41 | with open(perscriptions_data_path, newline='') as f: 42 | perscriptions = list(csv.reader(f)) 43 | perscriptions_columns = perscriptions.pop(0) #remove first row (column names) 44 | 45 | procedure_data_path = f'{ROOT_PATH}/data/mimic/PROCEDURES_ICD.csv' 46 | procedures = [] 47 | with open(procedure_data_path, newline='') as f: 48 | procedures = list(csv.reader(f)) 49 | procedure_columns = procedures.pop(0) # remove first row (column names) 50 | # %% 51 | # create Diseace to Demographics 52 | 53 | # iterate through diagnosis and use subject_id (patient id) to create relations with demographics 54 | disease_to_gender = [] 55 | disease_to_agegroup = [] 56 | disease_to_ethnicgroup = [] 57 | 58 | for disease in diagnoses: 59 | # extarct data 60 | patient_id = disease[1] 61 | disease_icd = disease[-1].lower() 62 | demographics = d_patient_demographics[patient_id] 63 | gender = demographics['gender'] 64 | age_group = demographics['age_group'] 65 | ethnic_group = demographics['ethnic_group'] 66 | 67 | # add triples to list 68 | disease_to_gender.append([f'd_{disease_icd}', 'gender', gender]) # Total: 651047, Unique: 11503 69 | disease_to_agegroup.append([f'd_{disease_icd}', 'age_group', age_group]) # Total: 651047, Unique: 21253 70 | disease_to_ethnicgroup.append([f'd_{disease_icd}', 'ethnic_group', ethnic_group]) # Total: 651047, Unique: 20784 71 | 72 | # create final list and shuffle 73 | disease_to_demographics = [] 74 | disease_to_demographics.extend(disease_to_gender) 75 | disease_to_demographics.extend(disease_to_agegroup) 76 | disease_to_demographics.extend(disease_to_ethnicgroup) 77 | random.shuffle(disease_to_demographics) 78 | # %% 79 | # create Medicine to Demographics 80 | 81 | # iterate through perscriptions and use subject_id (patient id) to create relations with demographics 82 | medicine_to_gender = [] 83 | medicine_to_agegroup = [] 84 | medicine_to_ethnicgroup = [] 85 | 86 | for perscription in perscriptions: 87 | # extarct data 88 | patient_id = perscription[1] 89 | medicine = '_'.join(perscription[7].lower().split()) # make string lowercase and concatenate words 90 | demographics = d_patient_demographics[patient_id] 91 | gender = demographics['gender'] 92 | age_group = demographics['age_group'] 93 | ethnic_group = demographics['ethnic_group'] 94 | 95 | # add triples to list 96 | medicine_to_gender.append([medicine, 'gender', gender]) # Total: 4156450, Unique: 6034 97 | medicine_to_agegroup.append([medicine, 'age_group', age_group]) # Total: 4156450, Unique: 10966 98 | medicine_to_ethnicgroup.append([medicine, 'ethnic_group', ethnic_group]) # Total: 4156450, Unique: 11156 99 | 100 | # create final list and shuffle 101 | medicine_to_demographics = [] 102 | medicine_to_demographics.extend(medicine_to_gender) 103 | medicine_to_demographics.extend(medicine_to_agegroup) 104 | medicine_to_demographics.extend(medicine_to_ethnicgroup) 105 | random.shuffle(medicine_to_demographics) 106 | # %% 107 | # create Procedure to Demographics 108 | 109 | # iterate through procedure and use subject_id (patient id) to create relations with demographics 110 | procedure_to_gender = [] 111 | procedure_to_agegroup = [] 112 | procedure_to_ethnicgroup = [] 113 | 114 | for procedure in procedures: 115 | # extarct data 116 | patient_id = procedure[1] 117 | procedure_icd = procedure[-1].lower() 118 | demographics = d_patient_demographics[patient_id] 119 | gender = demographics['gender'] 120 | age_group = demographics['age_group'] 121 | ethnic_group = demographics['ethnic_group'] 122 | 123 | # add triples to list 124 | procedure_to_gender.append([f'p_{procedure_icd}', 'gender', gender]) # Total: 240095, Unique: 3350 125 | procedure_to_agegroup.append([f'p_{procedure_icd}', 'age_group', age_group]) # Total: 240095, Unique: 6517 126 | procedure_to_ethnicgroup.append([f'p_{procedure_icd}', 'ethnic_group', ethnic_group]) # Total: 240095, Unique: 6056 127 | 128 | # create final list and shuffle 129 | procedure_to_demographics = [] 130 | procedure_to_demographics.extend(procedure_to_gender) 131 | procedure_to_demographics.extend(procedure_to_agegroup) 132 | procedure_to_demographics.extend(procedure_to_ethnicgroup) 133 | random.shuffle(procedure_to_demographics) 134 | # %% 135 | # create Diseace to Medicine 136 | 137 | # create diseases dictionary with adm_id as key, list of icd_codes as values 138 | d_diagnoses = {} 139 | for d_icd in diagnoses: 140 | adm_id = d_icd[2] # get adm_id from diagnoses table 141 | icd_code = d_icd[4].lower() # get icd_code from diagnoses table 142 | if adm_id in d_diagnoses: 143 | d_diagnoses[adm_id].append(icd_code) # if adm_id is already in dict append new icd_code 144 | else: 145 | d_diagnoses[adm_id] = [icd_code] # add new adm_id in dictionary with list of icd_code 146 | 147 | # create medicines dictionary with adm_id as key, list of medicines as values 148 | d_medicines = {} 149 | for perscription in perscriptions: 150 | adm_id = perscription[2] # get adm_id from prescription table 151 | medicine = '_'.join(perscription[7].lower().split()) # make string lowercase and concatenate words 152 | if adm_id in d_medicines: 153 | d_medicines[adm_id].append(medicine) # if adm_id is already in dict append new medicine 154 | else: 155 | d_medicines[adm_id] = [medicine] # add new adm_id in dictionary with list of medicines 156 | 157 | # Total triples: 61472226 158 | d_disease_to_medicine = {} # we calculate triple co-occurrence and extarct the top 5 159 | for adm_id, icd_codes in d_diagnoses.items(): 160 | if adm_id not in d_medicines: # if there is no medicine for this adm_id skip it 161 | continue 162 | medicines = d_medicines[adm_id] 163 | for disease in icd_codes: 164 | if disease not in d_disease_to_medicine: 165 | d_disease_to_medicine[disease] = [] 166 | for medicine in medicines: 167 | triple = [f'd_{disease}', 'disease_to_medicine', medicine] 168 | d_disease_to_medicine[disease].append('|'.join(triple)) # join list for applying counter 169 | 170 | # Since we map all possible diseases with all perscripted medicines, 171 | # we filter them by considering how many times they have co-occurred 172 | # we select the top k most co-occurred triples 173 | k = 10 174 | disease_to_medicine = [] 175 | for disease, triples in d_disease_to_medicine.items(): 176 | count_triples = Counter(triples) # counter triple co-occurrence 177 | sorted_triples = sorted(count_triples.items(), key=lambda kv: kv[1], reverse=True) # sort dictionary by value co-occurrence 178 | final_triples = [triple_counter[0].split('|') for triple_counter in sorted_triples[:k] for _ in range(triple_counter[1])] # create triple list from string 179 | disease_to_medicine.extend(final_triples) # add triples to list 180 | random.shuffle(disease_to_medicine) 181 | # %% 182 | # create Diseace to Procedure 183 | # create procedures dictionary with adm_id as key, list of procedures as values 184 | d_procedures = {} 185 | for procedure in procedures: 186 | adm_id = procedure[2] # get adm_id from prescription table 187 | procedures = procedure[-1].lower() 188 | if adm_id in d_procedures: 189 | d_procedures[adm_id].append(procedures) # if adm_id is already in dict append new procedure 190 | else: 191 | d_procedures[adm_id] = [procedures] # add new adm_id in dictionary with list of procedures 192 | 193 | # Total triples: 61472226 194 | d_disease_to_procedure = {} # we calculate triple co-occurrence and extarct the top 5 195 | for adm_id, icd_codes in d_diagnoses.items(): 196 | if adm_id not in d_procedures: # if there is no procedure for this adm_id skip it 197 | continue 198 | procedures = d_procedures[adm_id] 199 | for disease in icd_codes: 200 | if disease not in d_disease_to_procedure: 201 | d_disease_to_procedure[disease] = [] 202 | for procedure in procedures: 203 | triple = [f'd_{disease}', 'disease_to_procedure', f'p_{procedure}'] 204 | d_disease_to_procedure[disease].append('|'.join(triple)) # join list for applying counter 205 | 206 | # Since we map all possible diseases with all admission procedures, 207 | # we filter them by considering how many times they have co-occurred 208 | # we select the top k most co-occurred triples 209 | k = 10 210 | disease_to_procedure = [] 211 | for disease, triples in d_disease_to_procedure.items(): 212 | count_triples = Counter(triples) # counter triple co-occurrence 213 | sorted_triples = sorted(count_triples.items(), key=lambda kv: kv[1], reverse=True) # sort dictionary by value co-occurrence 214 | final_triples = [triple_counter[0].split('|') for triple_counter in sorted_triples[:k] for _ in range(triple_counter[1])] # create triple list from string 215 | disease_to_procedure.extend(final_triples) # add triples to list 216 | random.shuffle(disease_to_procedure) 217 | # %% 218 | # merge all triples into one list 219 | all_triples = [] 220 | all_triples.extend(disease_to_demographics) 221 | all_triples.extend(medicine_to_demographics) 222 | all_triples.extend(procedure_to_demographics) 223 | all_triples.extend(disease_to_medicine) 224 | all_triples.extend(disease_to_procedure) 225 | random.shuffle(all_triples) 226 | # %% 227 | # write all triples 228 | write_path = f'{ROOT_PATH}/data/medical_kg/all_triples.txt' 229 | with open(write_path, 'w') as outfile: 230 | writer = csv.writer(outfile, delimiter='\t') 231 | writer.writerows(all_triples) 232 | # %% 233 | -------------------------------------------------------------------------------- /scripts/prob_medical_kg_with_demographics.py: -------------------------------------------------------------------------------- 1 | # %% 2 | # import libraries 3 | import os 4 | import csv 5 | import random 6 | from pathlib import Path 7 | from collections import Counter 8 | from sklearn.model_selection import train_test_split 9 | 10 | ROOT_PATH = Path(os.path.dirname(__file__)).parent 11 | # %% 12 | # read mimic data 13 | patients_data_path = f'{ROOT_PATH}/data/mimic/PATIENTS.csv' 14 | patients = [] 15 | with open(patients_data_path, newline='') as f: 16 | patients = list(csv.reader(f)) 17 | 18 | admission_data_path = f'{ROOT_PATH}/data/mimic/ADMISSIONS.csv' 19 | admissions = [] 20 | with open(admission_data_path, newline='') as f: 21 | admissions = list(csv.reader(f)) 22 | 23 | diagnosis_data_path = f'{ROOT_PATH}/data/mimic/DIAGNOSES_ICD.csv' # this is for diseace icd9 code 24 | diagnoses = [] 25 | with open(diagnosis_data_path, newline='') as f: 26 | diagnoses = list(csv.reader(f)) 27 | diagnosis_columns = diagnoses.pop(0) # remove first row (column names) 28 | 29 | perscriptions_data_path = f'{ROOT_PATH}/data/mimic/PRESCRIPTIONS.csv' 30 | perscriptions = [] 31 | with open(perscriptions_data_path, newline='') as f: 32 | perscriptions = list(csv.reader(f)) 33 | perscriptions_columns = perscriptions.pop(0) #remove first row (column names) 34 | 35 | procedure_data_path = f'{ROOT_PATH}/data/mimic/PROCEDURES_ICD.csv' 36 | procedures = [] 37 | with open(procedure_data_path, newline='') as f: 38 | procedures = list(csv.reader(f)) 39 | procedure_columns = procedures.pop(0) # remove first row (column names) 40 | # %% 41 | # prepare demographics 42 | # group age 43 | def find_age_group(age): 44 | if age < 18: 45 | return '[0-18)' 46 | elif age >= 18 and age < 48: 47 | return '[18-48)' 48 | elif age >= 48 and age < 60: 49 | return '[48-60)' 50 | elif age >= 60 and age < 70: 51 | return '[60-70)' 52 | elif age >= 70 and age < 80: 53 | return '[70-80)' 54 | elif age >= 80: 55 | return '>=80' 56 | else: 57 | raise ValueError(f'Unkown value of age: {age}.') 58 | 59 | # all ethnicities and races 60 | white = [ 61 | 'WHITE', # 40996 62 | 'WHITE - RUSSIAN', # 164 63 | 'WHITE - OTHER EUROPEAN', # 81 64 | 'WHITE - BRAZILIAN', # 59 65 | 'WHITE - EASTERN EUROPEAN' # 25 66 | ] 67 | 68 | black = [ 69 | 'BLACK/AFRICAN AMERICAN', # 5440 70 | 'BLACK/CAPE VERDEAN', # 200 71 | 'BLACK/HAITIAN', # 101 72 | 'BLACK/AFRICAN', # 44 73 | 'CARIBBEAN ISLAND' # 9 74 | ] 75 | 76 | hispanic = [ 77 | 'HISPANIC OR LATINO', # 1696 78 | 'HISPANIC/LATINO - PUERTO RICAN', # 232 79 | 'HISPANIC/LATINO - DOMINICAN', # 78 80 | 'HISPANIC/LATINO - GUATEMALAN', # 40 81 | 'HISPANIC/LATINO - CUBAN', # 24 82 | 'HISPANIC/LATINO - SALVADORAN', # 19 83 | 'HISPANIC/LATINO - CENTRAL AMERICAN (OTHER)', # 13 84 | 'HISPANIC/LATINO - MEXICAN', # 13 85 | 'HISPANIC/LATINO - COLOMBIAN', # 9 86 | 'HISPANIC/LATINO - HONDURAN' # 4 87 | ] 88 | 89 | asian = [ 90 | 'ASIAN', # 1509 91 | 'ASIAN - CHINESE', # 277 92 | 'ASIAN - ASIAN INDIAN', # 85 93 | 'ASIAN - VIETNAMESE', # 53 94 | 'ASIAN - FILIPINO', # 25 95 | 'ASIAN - CAMBODIAN', # 17 96 | 'ASIAN - OTHER', # 17 97 | 'ASIAN - KOREAN', # 13 98 | 'ASIAN - JAPANESE', # 7 99 | 'ASIAN - THAI', # 4 100 | ] 101 | 102 | native = [ 103 | 'AMERICAN INDIAN/ALASKA NATIVE', # 51 104 | 'AMERICAN INDIAN/ALASKA NATIVE FEDERALLY RECOGNIZED TRIBE' # 3 105 | ] 106 | 107 | unknown = [ 108 | 'UNKNOWN/NOT SPECIFIED', # 4523 109 | 'UNABLE TO OBTAIN', # 814 110 | 'PATIENT DECLINED TO ANSWER' # 559 111 | ] 112 | 113 | other = [ 114 | 'OTHER', # 1512 115 | 'MULTI RACE ETHNICITY', # 130 116 | 'PORTUGUESE', # 61 117 | 'MIDDLE EASTERN', # 43 118 | 'NATIVE HAWAIIAN OR OTHER PACIFIC ISLANDER', # 18 119 | 'SOUTH AMERICAN' # 8 120 | ] 121 | 122 | def find_ethnic_group(ethnicity): 123 | if ethnicity in white: 124 | return 'white' 125 | elif ethnicity in black: 126 | return 'black' 127 | elif ethnicity in hispanic: 128 | return 'hispanic' 129 | elif ethnicity in asian: 130 | return 'asian' 131 | elif ethnicity in native: 132 | return 'native' 133 | elif ethnicity in unknown: 134 | return 'unknown' 135 | elif ethnicity in other: 136 | return 'other' 137 | else: 138 | raise ValueError(f'Unknown value for ethnicity: {ethnicity}') 139 | 140 | # create patient dictionary with id as key and gender, age as values 141 | patients_dictionary = { 142 | patient[1]: { 143 | 'gender': patient[2], 144 | 'dob_year': patient[3].split()[0].split('-')[0] 145 | } for patient in patients 146 | } 147 | 148 | # create helper demographics dictionary 149 | # demographics: gender, age group, ethnic group 150 | d_demographics = {} 151 | for admission in admissions: 152 | # get patient data 153 | admid = admission[2] 154 | pid = admission[1] 155 | gender = patients_dictionary[pid]['gender'].lower() 156 | adm_year = admission[3].split()[0].split('-')[0] 157 | dob_year = patients_dictionary[pid]['dob_year'] 158 | age = int(adm_year) - int(dob_year) 159 | age_group = find_age_group(age) 160 | ethnicity = admission[13] 161 | ethnic_group = find_ethnic_group(ethnicity) 162 | 163 | # add patient data 164 | d_demographics[admid] = { 165 | 'patient_id': pid, 166 | 'gender': gender, 167 | 'age_group': age_group, 168 | 'ethnic_group': ethnic_group 169 | } 170 | # %% 171 | # create Diseace to Medicine 172 | # create diseases dictionary with adm_id as key, list of icd_codes as values 173 | d_diagnoses = {} 174 | for d_icd in diagnoses: 175 | adm_id = d_icd[2] # get adm_id from diagnoses table 176 | icd_code = d_icd[4].lower() # get icd_code from diagnoses table 177 | if adm_id in d_diagnoses: 178 | d_diagnoses[adm_id].append(icd_code) # if adm_id is already in dict append new icd_code 179 | else: 180 | d_diagnoses[adm_id] = [icd_code] # add new adm_id in dictionary with list of icd_code 181 | 182 | # create medicines dictionary with adm_id as key, list of medicines as values 183 | d_medicines = {} 184 | for perscription in perscriptions: 185 | adm_id = perscription[2] # get adm_id from prescription table 186 | medicine = '_'.join(perscription[7].lower().split()) # make string lowercase and concatenate words 187 | if adm_id in d_medicines: 188 | d_medicines[adm_id].append(medicine) # if adm_id is already in dict append new medicine 189 | else: 190 | d_medicines[adm_id] = [medicine] # add new adm_id in dictionary with list of medicines 191 | 192 | d_disease_to_medicine = {} # we calculate triple co-occurrence and extarct the top 5 193 | for adm_id, icd_codes in d_diagnoses.items(): 194 | if adm_id not in d_medicines: # if there is no medicine for this adm_id skip it 195 | continue 196 | medicines = d_medicines[adm_id] 197 | for disease in icd_codes: 198 | if disease not in d_disease_to_medicine: 199 | d_disease_to_medicine[disease] = {} 200 | d_disease_to_medicine[disease]['triples'] = [] 201 | d_disease_to_medicine[disease]['demographics'] = {} 202 | for medicine in medicines: 203 | triple = [f'd_{disease}', 'disease_to_medicine', medicine] 204 | d_disease_to_medicine[disease]['triples'].append('|'.join(triple)) # join list for applying counter 205 | d_disease_to_medicine[disease]['demographics']['|'.join(triple)] = f'{d_demographics[adm_id]["gender"]}|{d_demographics[adm_id]["age_group"]}|{d_demographics[adm_id]["ethnic_group"]}' 206 | 207 | # Since we map all possible diseases with all perscripted medicines, 208 | # we filter them by considering how many times they have co-occurred 209 | # we select the top k most co-occurred triples 210 | k = 10 211 | disease_to_medicine = [] 212 | for disease, medicine_dict in d_disease_to_medicine.items(): 213 | sorted_triples = sorted(Counter(medicine_dict['triples']).items(), key=lambda kv: kv[1], reverse=True) # sort triples by value co-occurrence 214 | # generate quadruples with demographics 215 | quadruples = [quadruple for quadruplelist in [[triple.split('|') + [medicine_dict['demographics'][triple]]] * count for triple, count in sorted_triples[:k]] for quadruple in quadruplelist] 216 | disease_to_medicine.extend(quadruples) # add quadruples to list 217 | random.shuffle(disease_to_medicine) 218 | # %% 219 | # create Diseace to Procedure 220 | # create procedures dictionary with adm_id as key, list of procedures as values 221 | d_procedures = {} 222 | for procedure in procedures: 223 | adm_id = procedure[2] # get adm_id from prescription table 224 | procedures = procedure[-1].lower() 225 | if adm_id in d_procedures: 226 | d_procedures[adm_id].append(procedures) # if adm_id is already in dict append new procedure 227 | else: 228 | d_procedures[adm_id] = [procedures] # add new adm_id in dictionary with list of procedures 229 | 230 | d_disease_to_procedure = {} # we calculate triple co-occurrence and extarct the top 5 231 | for adm_id, icd_codes in d_diagnoses.items(): 232 | if adm_id not in d_procedures: # if there is no procedure for this adm_id skip it 233 | continue 234 | procedures = d_procedures[adm_id] 235 | for disease in icd_codes: 236 | if disease not in d_disease_to_procedure: 237 | d_disease_to_procedure[disease] = {} 238 | d_disease_to_procedure[disease]['triples'] = [] 239 | d_disease_to_procedure[disease]['demographics'] = {} 240 | for procedure in procedures: 241 | triple = [f'd_{disease}', 'disease_to_procedure', f'p_{procedure}'] 242 | d_disease_to_procedure[disease]['triples'].append('|'.join(triple)) # join list for applying counter 243 | d_disease_to_procedure[disease]['demographics']['|'.join(triple)] = f'{d_demographics[adm_id]["gender"]}|{d_demographics[adm_id]["age_group"]}|{d_demographics[adm_id]["ethnic_group"]}' 244 | 245 | # Since we map all possible diseases with all admission procedures, 246 | # we filter them by considering how many times they have co-occurred 247 | # we select the top k most co-occurred triples 248 | k = 10 249 | disease_to_procedure = [] 250 | for disease, procedure_dict in d_disease_to_procedure.items(): 251 | sorted_triples = sorted(Counter(procedure_dict['triples']).items(), key=lambda kv: kv[1], reverse=True) # sort dictionary by value co-occurrence 252 | # generate quadruples with demographics 253 | quadruples = [quadruple for quadruplelist in [[triple.split('|') + [procedure_dict['demographics'][triple]]] * count for triple, count in sorted_triples[:k]] for quadruple in quadruplelist] 254 | disease_to_procedure.extend(quadruples) # add quadruples to list 255 | random.shuffle(disease_to_procedure) 256 | # %% 257 | # merge all triples into one list 258 | all_quadruples = [] 259 | all_quadruples.extend(disease_to_medicine) 260 | all_quadruples.extend(disease_to_procedure) 261 | random.shuffle(all_quadruples) 262 | # %% 263 | # create final quadruples with probabilities => quintuplets 264 | disease_counter = Counter([triple[0] for triple in all_quadruples]) 265 | quadruples_counter = Counter(['~'.join(triple) for triple in all_quadruples]) 266 | final_quintuplets = [q.split('~') + [round(c/disease_counter[q.split('~')[0]], 4)] for q, c in quadruples_counter.items()] 267 | # %% 268 | # partion final quintuplets 269 | # TODO: Filter demographic categories with less number of triples (e.g. triple_num < 500) 270 | train = [] 271 | val = [] 272 | test = [] 273 | train, val = train_test_split(final_quintuplets, test_size=0.2, shuffle=True) 274 | val, test = train_test_split(val, test_size=0.6, shuffle=True) 275 | 276 | random.shuffle(train) 277 | random.shuffle(val) 278 | random.shuffle(test) 279 | # %% 280 | write_path = f'{ROOT_PATH}/data/medical_kg' 281 | with open(f'{write_path}/train.txt', 'w') as outfile: 282 | csv.writer(outfile, delimiter='\t').writerows(train) 283 | 284 | with open(f'{write_path}/val.txt', 'w') as outfile: 285 | csv.writer(outfile, delimiter='\t').writerows(val) 286 | 287 | with open(f'{write_path}/test.txt', 'w') as outfile: 288 | csv.writer(outfile, delimiter='\t').writerows(test) 289 | # %% 290 | --------------------------------------------------------------------------------