├── assets └── intro.png ├── data └── IPH │ └── example.xlsx ├── tools.py ├── README.md ├── upload_data.py ├── resnet3d.py ├── LICENSE ├── check_result.py ├── models.py └── main_VAE.py /assets/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/med-air/TOP-GPM/HEAD/assets/intro.png -------------------------------------------------------------------------------- /data/IPH/example.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/med-air/TOP-GPM/HEAD/data/IPH/example.xlsx -------------------------------------------------------------------------------- /tools.py: -------------------------------------------------------------------------------- 1 | from models import * 2 | import os 3 | import glob 4 | import argparse 5 | from sklearn.preprocessing import StandardScaler 6 | from tqdm import tqdm 7 | from sklearn.model_selection import train_test_split 8 | # from idhp_data import * 9 | import SimpleITK as sitk 10 | import cv2 11 | import numpy as np 12 | import math 13 | 14 | 15 | 16 | def policy_val(t, yf, q_t0, q_t1, compute_policy_curve=False): 17 | # if np.any(np.isnan(eff_pred)): 18 | # return np.nan, np.nan 19 | q_cat = np.concatenate((q_t0, q_t1),1) 20 | 21 | policy = np.argmax(q_cat,1) 22 | policy = policy[:,np.newaxis] 23 | t0_overlap = (policy==t)*(t==0) 24 | t1_overlap = (policy==t)*(t==1) 25 | 26 | 27 | if np.sum(t0_overlap) == 0: 28 | t0_value = 0 29 | else: 30 | t0_value = np.mean(yf[t0_overlap]) 31 | 32 | if np.sum(t1_overlap) == 0: 33 | t1_value = 0 34 | else: 35 | t1_value = np.mean(yf[t1_overlap]) 36 | 37 | 38 | 39 | pit_0 = np.sum(policy==0)/len(t) 40 | pit_1 = np.sum(policy==1)/len(t) 41 | 42 | policy_value = pit_0*t0_value + pit_1*t1_value 43 | 44 | 45 | return policy_value 46 | 47 | def factual_acc(t, yf, q_t0, q_t1): 48 | 49 | q_t0[q_t0>=0.5] = 1 50 | q_t0[q_t0<0.5] = 0 51 | 52 | q_t1[q_t1>=0.5] = 1 53 | q_t1[q_t1<0.5] = 0 54 | 55 | 56 | accuracy_0 = np.sum(q_t0[t==0]==yf[t==0])/len(yf[t==0]) 57 | accuracy_1 = np.sum(q_t1[t==1]==yf[t==1])/len(yf[t==1]) 58 | 59 | 60 | print("Factual accuracy of t0:", accuracy_0) 61 | print("Factual accuracy of t1:", accuracy_1) 62 | 63 | return accuracy_0,accuracy_1 64 | 65 | def factual_auc(t, yf, q_t0, q_t1): 66 | from sklearn import metrics 67 | y_t0 = [] 68 | y_t1 = [] 69 | p_t0 = [] 70 | p_t1 = [] 71 | 72 | for index in range(len(t)): 73 | if t[index] ==0: 74 | y_t0.append(yf[index]) 75 | p_t0.append(q_t0[index]) 76 | else: 77 | y_t1.append(yf[index]) 78 | p_t1.append(q_t1[index]) 79 | 80 | 81 | y_t0,p_t0, y_t1,p_t1 = np.array(y_t0), np.array(p_t0), np.array(y_t1), np.array(p_t1) 82 | auc0 = metrics.roc_auc_score(y_t0,p_t0) 83 | auc1 = metrics.roc_auc_score(y_t1,p_t1) 84 | 85 | 86 | 87 | print("Factual auc of t0:", auc0) 88 | print("Factual auc of t1:", auc1) 89 | 90 | return auc0,auc1 91 | 92 | def policy_risk_multi(t, yf, q_t0, q_t1): 93 | policy_value = policy_val(t, yf, q_t0, q_t1) 94 | policy_risk = 1 - policy_value 95 | return policy_risk 96 | 97 | def ate_error_0_1(t, yf, eff_pred): 98 | att = np.mean(yf[t==0]) - np.mean(yf[t==1]) 99 | pred_att = np.mean(eff_pred) 100 | 101 | return np.abs(att-pred_att) 102 | 103 | def ate_error_0_2(t, yf, eff_pred): 104 | att = np.mean(yf[t==0]) - np.mean(yf[t==2]) 105 | pred_att = np.mean(eff_pred) 106 | 107 | return np.abs(att-pred_att) 108 | 109 | def ate_error_1_2(t, yf, eff_pred): 110 | att = np.mean(yf[t==1]) - np.mean(yf[t==2]) 111 | pred_att = 
np.mean(eff_pred) 112 | 113 | return np.abs(att-pred_att) 114 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Introduction 2 | 3 | + PyTorch implementation for the paper **[Treatment Outcome Prediction for Intracerebral Hemorrhage via Generative Prognostic Model with Imaging and Tabular Data](https://arxiv.org/pdf/2307.12858)** 4 | 5 |

6 |    <img src="assets/intro.png" alt="intro"> 7 |

8 | 9 | ## Abstract 10 | > Intracerebral hemorrhage (ICH) is the second most common and deadliest form of stroke. Despite medical advances, predicting treatment outcomes for ICH remains a challenge. 11 | This paper proposes a novel prognostic model that utilizes both imaging and tabular data to predict treatment outcome for ICH. Our model is trained on observational data collected from non-randomized controlled trials, providing reliable predictions of treatment success. Specifically, we propose to employ a variational autoencoder model to generate a low-dimensional prognostic score, which can effectively address the selection bias resulting from the non-randomized controlled trials. Importantly, we develop a variational distributions combination module that combines the information from imaging data, non-imaging clinical data, and treatment assignment to accurately generate the prognostic score. We conducted extensive experiments on a real-world clinical dataset of intracerebral hemorrhage. Our proposed method demonstrates a substantial improvement in treatment outcome prediction compared to existing state-of-the-art approaches. 12 |
13 | ## Setup 14 | 15 | ### OS Requirements 16 | This model has been tested on the following systems: 17 | 18 | + Linux: Ubuntu 18.04 19 | 20 | ```bash 21 | Package Version 22 | ---------------------- ------------------- 23 | torch 1.4.0 24 | torchvision 0.5.0 25 | h5py 3.1.0 26 | opencv-python 4.5.2.52 27 | SimpleITK 2.0.2 28 | scikit-image 0.17.2 29 | ml-collections 0.1.1 30 | tensorboardx 2.2.0 31 | medpy 0.4.0 32 | scikit-learn 0.24.2 33 | pandas 1.1.5 34 | ``` 35 |
36 | ## Training & Testing 37 | + The paper uses a private dataset, so you need to prepare your own dataset to run the code. 38 | + Specifically, you need to prepare an .xlsx file that stores each patient's non-imaging clinical data and the path to the corresponding imaging data. An example file is provided at "./data/IPH/example.xlsx"; the expected column layout is sketched at the end of this README. 39 | + Run main_VAE.py to train and evaluate the model: 40 | ```bash 41 | python main_VAE.py 42 | ``` 43 | + Our proposed model is implemented in models.py as the class "VAE_MM". 44 |
45 | ## Citation 46 | If this repository is useful for your research, please cite: 47 | 48 | @inproceedings{ma2023treatment, 49 | title={Treatment Outcome Prediction for Intracerebral Hemorrhage via Generative Prognostic Model with Imaging and Tabular Data}, 50 | author={Ma, Wenao and Chen, Cheng and Abrigo, Jill and Mak, Calvin Hoi-Kwan and Gong, Yuqi and Chan, Nga Yan and Han, Chu and Liu, Zaiyi and Dou, Qi}, 51 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, 52 | pages={715--725}, 53 | year={2023}, 54 | organization={Springer} 55 | } 56 |
57 | ## Contact 58 | For any questions, please contact 'wama@cse.cuhk.edu.hk'. 59 | 60 | ## License 61 | This project is covered under the **Apache 2.0 License**.
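## Data format (sketch)

+ The spreadsheet schema is not documented in this repository, so the layout below is inferred from `load_and_format_covariates()` and `load_all_other_crap()` in `upload_data.py` and should be read as a hedged sketch rather than an official specification. The loaders index columns by position and drop the first row of values after the header.
+ Also note that checkpoints and per-epoch predictions are written to hard-coded relative paths (`../models_save/...`, `../results_save/...`), and each image volume is read from `'../../../../' + <path column> + '/Img_final_0.nii.gz'`; adjust these to your environment.

```python
# Hedged sketch of how upload_data.py appears to consume the spreadsheet.
# Column order is inferred from the loader code; the names are illustrative only.
#   col 0 : treatment assignment t
#   col 1 : factual outcome y
#   col 2 : counterfactual outcome y_cf (placeholder if unknown)
#   col 3 : mu_0, col 4: mu_1 (placeholders unless you use simulated data)
#   col 5 : path to the patient's image folder (joined with Img_final_0.nii.gz)
#   col 6+: non-imaging clinical covariates
from upload_data import load_and_format_covariates, load_all_other_crap

x, img_path = load_and_format_covariates('./data/IPH/example.xlsx')
t, y, y_cf, mu_0, mu_1 = load_all_other_crap('./data/IPH/example.xlsx')
print(x.shape, img_path.shape, t.shape, y.shape)
```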
62 | 63 | -------------------------------------------------------------------------------- /upload_data.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from torch.utils.data import Dataset, DataLoader 4 | from torchvision import transforms, utils 5 | import SimpleITK as sitk 6 | import numpy as np 7 | import torch 8 | 9 | 10 | class trainerData3d_preload(Dataset): 11 | def __init__(self, img_path, data, outcome, treatment, is_train = True): 12 | self.is_train = is_train 13 | self.img_path = img_path 14 | self.data = data 15 | self.outcome = outcome 16 | self.treatment = treatment 17 | self.all_image_data = [] 18 | for index in range(len(self.img_path)): 19 | 20 | get_img = sitk.ReadImage('../../../../' + self.img_path[index]+'/Img_final_0.nii.gz') 21 | return_img = sitk.GetArrayFromImage(get_img).astype(np.float32) 22 | 23 | num_index = len(return_img) // 2 24 | return_img = return_img[num_index-10:num_index+10] 25 | return_img = return_img[np.newaxis,:,:,:] 26 | self.all_image_data.append(return_img) 27 | 28 | def __getitem__(self, index): 29 | return_data = torch.from_numpy(self.data[index]).float().cuda() 30 | return_yt = torch.from_numpy(np.concatenate([self.outcome[index], self.treatment[index]], 0)).float().cuda() 31 | 32 | return_img = self.all_image_data[index] 33 | return_img = torch.from_numpy(return_img).float().cuda() 34 | return return_data, return_yt, return_img 35 | def __len__(self): 36 | return len(self.img_path) 37 | 38 | 39 | 40 | class trainerData_cli(Dataset): 41 | def __init__(self, img_path, data, outcome, treatment, is_train = True): 42 | self.is_train = is_train 43 | self.img_path = img_path 44 | self.data = data 45 | self.outcome = outcome 46 | self.treatment = treatment 47 | def __getitem__(self, index): 48 | return_data = torch.from_numpy(self.data[index]).float().cuda() 49 | return_yt = torch.from_numpy(np.concatenate([self.outcome[index], self.treatment[index]], 0)).float().cuda() 50 | #return_outcome = torch.from_numpy(self.outcome[index]).float().cuda() 51 | #return_treatment = torch.from_numpy(self.return_treatment[index]).float().cuda() 52 | 53 | 54 | return return_data, return_yt 55 | def __len__(self): 56 | return len(self.img_path) 57 | 58 | 59 | class trainerData(Dataset): 60 | def __init__(self, img_path, data, outcome, treatment, is_train = True): 61 | self.is_train = is_train 62 | self.img_path = img_path 63 | self.data = data 64 | self.outcome = outcome 65 | self.treatment = treatment 66 | def __getitem__(self, index): 67 | return_data = torch.from_numpy(self.data[index]).float().cuda() 68 | return_yt = torch.from_numpy(np.concatenate([self.outcome[index], self.treatment[index]], 0)).float().cuda() 69 | #return_outcome = torch.from_numpy(self.outcome[index]).float().cuda() 70 | #return_treatment = torch.from_numpy(self.return_treatment[index]).float().cuda() 71 | try: 72 | get_img = sitk.ReadImage('../../../../' + self.img_path[index]+'/Img_final_0.nii.gz') 73 | return_img = sitk.GetArrayFromImage(get_img).astype(np.float32) 74 | if return_img.shape[0] < 14: 75 | print(self.img_path[index]) 76 | return_img = np.zeros((25,224,224)) 77 | except: 78 | return_img = np.zeros((25,224,224)) 79 | num_index = len(return_img) // 2 80 | return_img = torch.from_numpy(return_img[num_index - 2: num_index + 1]).float().cuda() 81 | return return_data, return_yt, return_img 82 | def __len__(self): 83 | return len(self.img_path) 84 | 85 | 86 | class trainerData_single(Dataset): 
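    # Tabular-only Dataset variant: unlike trainerData3d_preload it never loads the
    # image volumes; __getitem__ returns only the clinical covariates and the
    # concatenated [outcome, treatment] target (same interface as trainerData_cli).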
87 | def __init__(self, img_path, data, outcome, treatment, is_train = True): 88 | self.is_train = is_train 89 | self.img_path = img_path 90 | self.data = data 91 | self.outcome = outcome 92 | self.treatment = treatment 93 | def __getitem__(self, index): 94 | return_data = torch.from_numpy(self.data[index]).float().cuda() 95 | return_yt = torch.from_numpy(np.concatenate([self.outcome[index], self.treatment[index]], 0)).float().cuda() 96 | #return_outcome = torch.from_numpy(self.outcome[index]).float().cuda() 97 | #return_treatment = torch.from_numpy(self.return_treatment[index]).float().cuda() 98 | 99 | return return_data, return_yt 100 | def __len__(self): 101 | return len(self.img_path) 102 | 103 | def convert_file(x): 104 | x = x.values 105 | x = x.astype(float) 106 | return x 107 | 108 | 109 | def load_and_format_covariates(file_path): 110 | 111 | data = pd.read_excel(file_path) 112 | 113 | data = data.values[1:, ] 114 | 115 | #binfeats = list(range(6,37)) 116 | #contfeats = [i for i in range(37) if i not in binfeats] 117 | 118 | mu_0, mu_1, path, x = data[:, 3][:, None], data[:, 4][:, None], data[:, 5], data[:, 6:] 119 | #perm = binfeats 120 | #x = x[:, perm].astype(float) 121 | 122 | # for num in range(len(x)): 123 | # a1 = x[num].astype(float) 124 | # print(num) 125 | 126 | return x.astype(float), path 127 | 128 | 129 | def load_all_other_crap(file_path): 130 | data = pd.read_excel(file_path) 131 | data = data.values[1:, ] 132 | t, y, y_cf = data[:, 0], data[:, 1][:, None], data[:, 2][:, None] 133 | mu_0, mu_1, x = data[:, 3][:, None], data[:, 4][:, None], data[:, 6:] 134 | return t.reshape(-1, 1).astype(float), y.astype(float), y_cf.astype(float), mu_0.astype(float), mu_1.astype(float) 135 | 136 | def main(): 137 | pass 138 | 139 | 140 | if __name__ == '__main__': 141 | main() 142 | -------------------------------------------------------------------------------- /resnet3d.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | import torch.nn.functional as F 4 | 5 | #__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 9 | 10 | class BasicBlock(nn.Module): 11 | expansion = 1 12 | 13 | def __init__(self, inplanes, planes, stride=1, downsample=None): 14 | super(BasicBlock, self).__init__() 15 | self.conv1 = conv3x3(inplanes, planes, stride) 16 | self.bn1 = nn.BatchNorm3d(planes) 17 | self.relu = nn.ReLU(inplace=True) 18 | self.conv2 = conv3x3(planes, planes) 19 | self.bn2 = nn.BatchNorm3d(planes) 20 | self.downsample = downsample 21 | self.stride = stride 22 | 23 | def forward(self, x): 24 | residual = x 25 | 26 | out = self.conv1(x) 27 | out = self.bn1(out) 28 | out = self.relu(out) 29 | 30 | out = self.conv2(out) 31 | out = self.bn2(out) 32 | 33 | if self.downsample is not None: 34 | residual = self.downsample(x) 35 | 36 | out += residual 37 | out = self.relu(out) 38 | 39 | return out 40 | 41 | class Bottleneck(nn.Module): 42 | expansion = 4 43 | 44 | def __init__(self, inplanes, planes, stride=1, downsample=None): 45 | super(Bottleneck, self).__init__() 46 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size = 1, bias = False) 47 | self.bn1 = nn.BatchNorm3d(planes) 48 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 49 | self.bn2 = nn.BatchNorm3d(planes) 
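        # forward() checks self.downsample, so the optional downsample branch must be
        # stored on the module; the assignment below adds it (this 3D Bottleneck is not
        # used by the released VAE_MM model, which builds on BasicBlock).
        self.downsample = downsample
        # The 1x1 convolution that follows expands the channels to planes * expansion
        # (4x), the standard ResNet bottleneck output width.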
50 | self.conv3 = nn.Conv3d(planes, planes*self.expansion, kernel_size=1, bias=False) 51 | self.bn3 = nn.BatchNorm3d(planes*self.expansion) 52 | self.relu = nn.ReLU(inplace=True) 53 | self.stride = stride 54 | 55 | def forward(self, x): 56 | residual = x 57 | 58 | out = self.conv1(x) 59 | out = self.bn1(out) 60 | out = self.relu(out) 61 | 62 | out = self.conv2(out) 63 | out = self.bn2(out) 64 | out = self.relu(out) 65 | 66 | out = self.conv3(out) 67 | out = self.bn3(out) 68 | 69 | if self.downsample is not None: 70 | residual = self.downsample(x) 71 | 72 | out += residual 73 | out = self.relu(out) 74 | 75 | return out 76 | 77 | # class ResNet(nn.Module): 78 | # def __init__(self, block, layers, num_classes=1000, channel_size=[64,64,128,256,512],dropout=False): 79 | # c = channel_size 80 | # self.inplanes = c[0] 81 | # super(ResNet, self).__init__() 82 | # net = nn.Sequential() 83 | # net.add_module('conv1', nn.Conv3d(1, c[0],kernel_size=7, stride=2, padding=0, bias=False)) 84 | # net.add_module('bn1', nn.BatchNorm3d(c[0])) 85 | # net.add_module('relu', nn.ReLU(inplace=True)) 86 | # net.add_module('maxpool',nn.MaxPool3d(kernel_size=3, stride=2, padding=1)) 87 | # net.add_module('layer1', self._make_layer(block, c[1], layers[0])) 88 | # net.add_module('layer2', self._make_layer(block, c[2], layers[1], stride=2)) 89 | # net.add_module('layer3', self._make_layer(block, c[3], layers[2], stride=2)) 90 | # net.add_module('layer4', self._make_layer(block, c[4], layers[3], stride=2)) 91 | # net.add_module('avgpool', nn.AvgPool3d([5,6,5], stride=1)) 92 | # if dropout is True: 93 | # net.add_module('dropout', nn.Dropout(0.5)) 94 | # self.feature_extractor = net 95 | # self.classifier = nn.Linear(c[4] * block.expansion, num_classes) 96 | 97 | # for m in self.modules(): 98 | # if isinstance(m, nn.Conv3d): 99 | # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 100 | # elif isinstance(m, nn.BatchNorm3d): 101 | # nn.init.constant_(m.weight,1) 102 | # nn.init.constant_(m.bias, 0) 103 | 104 | # def _make_layer(self, block, planes, blocks, stride=1): 105 | # downsample = None 106 | # if stride != 1 or self.inplanes != planes * block.expansion: 107 | # downsample = nn.Sequential(nn.Conv3d(self.inplanes, planes*block.expansion,kernel_size=1, stride=stride, bias=False), 108 | # nn.BatchNorm3d(planes*block.expansion)) 109 | 110 | # layers = [] 111 | # layers.append(block(self.inplanes, planes, stride, downsample)) 112 | # self.inplanes = planes * block.expansion 113 | # for i in range(1, blocks): 114 | # layers.append(block(self.inplanes, planes)) 115 | 116 | # return nn.Sequential(*layers) 117 | 118 | # def forward(self, x): 119 | # x = self.feature_extractor(x) 120 | # x = x.view(x.size(0),-1) 121 | # x = self.classifier(x) 122 | # x = F.log_softmax(x) 123 | # return x 124 | 125 | 126 | 127 | # def resnet18(**kwargs): 128 | # model = ResNet(BasicBlock, [2,2,2,2], **kwargs) 129 | # return model 130 | 131 | # def resnet34(**kwargs): 132 | # model = ResNet(BasicBlock, [3,4,6,3], **kwargs) 133 | # return model 134 | 135 | # def resnet50(**kwargs): 136 | # model = ResNet(BasicBlock, [3,4,6,3], **kwargs) 137 | # return model 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | 
http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /check_result.py: -------------------------------------------------------------------------------- 1 | from models import * 2 | import os 3 | import glob 4 | import argparse 5 | from sklearn.preprocessing import StandardScaler 6 | from tqdm import tqdm 7 | from sklearn.model_selection import train_test_split 8 | from sklearn import metrics 9 | # from idhp_data import * 10 | import SimpleITK as sitk 11 | import cv2 12 | import numpy as np 13 | import math 14 | import copy 15 | from tools import * 16 | 17 | def _split_output(yt_hat, t, y, y_scaler, x, is_train=False): 18 | """ 19 | Split output into dictionary for easier use in estimation 20 | Args: 21 | yt_hat: Generated prediction 22 | t: Binary treatment assignments 23 | y: Treatment outcomes 24 | y_scaler: Scaled treatment outcomes 25 | x: Covariates 26 | index: Index in data 27 | 28 | Returns: 29 | Dictionary of all needed data 30 | """ 31 | 32 | 33 | yt_hat = yt_hat 34 | q_t0 = yt_hat[:, 0].reshape(-1, 1).copy() 35 | q_t1 = yt_hat[:, 1].reshape(-1, 1).copy() 36 | 37 | g = yt_hat[:, 2].copy() 38 | treatment_predicted = g.copy() 39 | treatment_predicted[treatment_predicted>=0.5] = 1 40 | treatment_predicted[treatment_predicted<0.5] = 0 41 | 42 | y = y.copy() 43 | var = "average propensity for t: {}".format(g[t.squeeze() == 1.].mean()) 44 | 45 | q_cat = np.concatenate((q_t0, q_t1),1) 46 | 47 | policy = np.argmax(q_cat,1) 48 | 49 | print(var) 50 | print("Policy Risk:", policy_risk_multi(t, y, q_t0, q_t1)) 51 | print("Ate_Error:", ate_error_0_1(t, y, q_t0 - q_t1)) 52 | 53 | print("Treatment accuracy:", np.sum(treatment_predicted==t.squeeze())/treatment_predicted.shape[0]) 54 | 55 | if not is_train: 56 | print("Treatment policy :",policy) 57 | print("Treatment prediction:",treatment_predicted) 58 | print("Treatment label :",t.squeeze().astype(int)) 59 | 60 | auc0,auc1 = factual_auc(t, y, q_t0, q_t1) 61 | accuracy_0, accuracy_1 = factual_acc(t, y, q_t0, q_t1) 62 | 63 | 64 | return {'ave propensity for t': g[t.squeeze() == 1.].mean(), 65 | 'Policy Risk': policy_risk_multi(t, y, q_t0, q_t1), 66 | 'Ate_Error_0_1': ate_error_0_1(t, y, q_t0 - q_t1), 'Treatment accuracy': np.sum(treatment_predicted==t.squeeze())/treatment_predicted.shape[0], 67 | 'Treatment policy': policy, 'Treatment prediction': 
treatment_predicted, 'Treatment label': t.squeeze().astype(int), 'accuracy_0': accuracy_0, 'accuracy_1':accuracy_1, 68 | 'auc0': auc0, 'auc1':auc1} 69 | 70 | average_propensity_for_t0 = [] 71 | average_propensity_for_t1 = [] 72 | average_propensity_for_t2 = [] 73 | policy_risk = [] 74 | test_ate_error_0_1 = [] 75 | test_ate_error_0_2 = [] 76 | test_ate_error_1_2 = [] 77 | treatment_accuracy = [] 78 | treatment_policy=np.array([]) 79 | treatment_prediction=np.array([]) 80 | treatment_label=np.array([]) 81 | test_factual_accuracy_of_t0 = [] 82 | test_factual_accuracy_of_t1 = [] 83 | test_factual_accuracy_of_t2 = [] 84 | 85 | 86 | train_average_propensity_for_t0 = [] 87 | train_average_propensity_for_t1 = [] 88 | train_average_propensity_for_t2 = [] 89 | train_policy_risk = [] 90 | train_ate_error_0_1 = [] 91 | train_ate_error_0_2 = [] 92 | train_ate_error_1_2 = [] 93 | train_treatment_accuracy = [] 94 | train_factual_accuracy_of_t0 = [] 95 | train_factual_accuracy_of_t1 = [] 96 | train_factual_accuracy_of_t2 = [] 97 | 98 | train_factual_auc_of_t0 = [] 99 | train_factual_auc_of_t1 = [] 100 | test_factual_auc_of_t0 = [] 101 | test_factual_auc_of_t1 = [] 102 | 103 | key_word = 'Treatment accuracy' 104 | key_word4 = 'Policy Risk' 105 | key_word5 = 'accuracy_0' 106 | key_word6 = 'accuracy_1' 107 | key_word7 = 'accuracy_2' 108 | 109 | key_word1 = 'Ate_Error_0_1' 110 | key_word2 = 'Ate_Error_0_2' 111 | key_word3 = 'Ate_Error_1_2' 112 | 113 | key_word_auc0 = 'auc0' 114 | key_word_auc1 = 'auc1' 115 | epoch_index = 0 116 | 117 | for validation_index in range(1): 118 | best_evaluation = 0. 119 | train_outputs_best = {} 120 | test_outputs_best = {} 121 | for epoch in range(0,1500,10): 122 | test_results = np.load("../results_save/IPH_limited_ours2/{}_fold_{}_epoch_test.npz".format(validation_index, epoch), allow_pickle=True) 123 | train_results = np.load("../results_save/IPH_limited_ours2/{}_fold_{}_epoch_train.npz".format(validation_index, epoch), allow_pickle=True) 124 | 125 | yt_hat_test, t_test, y_test, y, x_test = test_results['yt_hat_test'], test_results['t_test'], test_results['y_test'], \ 126 | test_results['y'], test_results['x_test'] 127 | yt_hat_train, t_train, y_train, y, x_train = train_results['yt_hat_train'], train_results['t_train'], train_results['y_train'], \ 128 | train_results['y'], train_results['x_train'] 129 | 130 | test_outputs = _split_output(yt_hat_test, t_test, y_test, y, x_test, is_train=False) 131 | train_outputs = _split_output(yt_hat_train, t_train, y_train, y, x_train, is_train=True) 132 | #test_outputs = test_outputs['arr_0'].item() 133 | #train_outputs = train_outputs['arr_0'].item() 134 | if test_outputs[key_word_auc0]+test_outputs[key_word_auc1] >= best_evaluation and epoch>=100: 135 | #if (test_outputs[key_word1]+test_outputs[key_word2]+test_outputs[key_word3]+test_outputs[key_word4]+(1-test_outputs[key_word5])+(1-test_outputs[key_word6])+(1-test_outputs[key_word7]))/7 <= best_evaluation and epoch>=500: 136 | test_outputs_best = test_outputs 137 | best_evaluation = test_outputs[key_word_auc0]+test_outputs[key_word_auc1] 138 | epoch_index = epoch 139 | #best_evaluation = (test_outputs[key_word1]+test_outputs[key_word2]+test_outputs[key_word3]+test_outputs[key_word4]+(1-test_outputs[key_word5])+(1-test_outputs[key_word6])+(1-test_outputs[key_word7]))/7 140 | 141 | train_outputs_best = train_outputs 142 | # if 
(train_outputs[key_word1]+train_outputs[key_word2]+train_outputs[key_word3]+train_outputs[key_word4]+(1-train_outputs[key_word5])+(1-train_outputs[key_word6])+(1-train_outputs[key_word7]))/7 <= best_evaluation and epoch>=500: 143 | # train_outputs_best = train_outputs 144 | 145 | # #best_evaluation = test_outputs[key_word] 146 | # best_evaluation = (train_outputs[key_word1]+train_outputs[key_word2]+train_outputs[key_word3]+train_outputs[key_word4]+(1-train_outputs[key_word5])+(1-train_outputs[key_word6])+(1-train_outputs[key_word7]))/7 147 | 148 | 149 | print("==========Best test results for the {} fold==========".format(validation_index)) 150 | 151 | print("average propensity for t: {}".format(test_outputs_best['ave propensity for t'])) 152 | print("Policy Risk:", test_outputs_best['Policy Risk']) 153 | print("Ate_Error_0_1:", test_outputs_best['Ate_Error_0_1']) 154 | 155 | print("Treatment accuracy:", test_outputs_best['Treatment accuracy']) 156 | print("Treatment policy :",test_outputs_best['Treatment policy']) 157 | print("Treatment prediction:",test_outputs_best['Treatment prediction']) 158 | print("Treatment label :",test_outputs_best['Treatment label']) 159 | print("Factual accuracy of t0:", test_outputs_best['accuracy_0']) 160 | print("Factual accuracy of t1:", test_outputs_best['accuracy_1']) 161 | print("Factual auc of t0:", test_outputs_best['auc0']) 162 | print("Factual auc of t1:", test_outputs_best['auc1']) 163 | 164 | # print("Factual auc of t0:", test_outputs_best['auc_0']) 165 | 166 | print("==========Best train results for the {} fold==========".format(validation_index)) 167 | print("average propensity for t: {}".format(train_outputs_best['ave propensity for t'])) 168 | print("Policy Risk:", train_outputs_best['Policy Risk']) 169 | print("Ate_Error_0_1:", train_outputs_best['Ate_Error_0_1']) 170 | 171 | print("Treatment accuracy:", train_outputs_best['Treatment accuracy']) 172 | print("Factual accuracy of t0:", train_outputs_best['accuracy_0']) 173 | print("Factual accuracy of t1:", train_outputs_best['accuracy_1']) 174 | print("Factual auc of t0:", train_outputs_best['auc0']) 175 | print("Factual auc of t1:", train_outputs_best['auc1']) 176 | 177 | # print("Factual auc of t0:", train_outputs_best['auc_0']) 178 | 179 | print("====================================================") 180 | average_propensity_for_t0.append(test_outputs_best['ave propensity for t']) 181 | 182 | policy_risk.append(test_outputs_best['Policy Risk']) 183 | test_ate_error_0_1.append(test_outputs_best['Ate_Error_0_1']) 184 | 185 | treatment_accuracy.append(test_outputs_best['Treatment accuracy']) 186 | test_factual_accuracy_of_t0.append(test_outputs_best['accuracy_0']) 187 | test_factual_accuracy_of_t1.append(test_outputs_best['accuracy_1']) 188 | test_factual_auc_of_t0.append(test_outputs_best['auc0']) 189 | test_factual_auc_of_t1.append(test_outputs_best['auc1']) 190 | 191 | # test_factual_auc_of_t0.append(test_outputs_best['auc_0']) 192 | 193 | treatment_policy=np.concatenate((treatment_policy,test_outputs_best['Treatment policy']),0) 194 | treatment_prediction=np.concatenate((treatment_prediction,test_outputs_best['Treatment prediction']),0) 195 | treatment_label=np.concatenate((treatment_label,test_outputs_best['Treatment label']),0) 196 | 197 | train_average_propensity_for_t0.append(train_outputs_best['ave propensity for t']) 198 | 199 | train_policy_risk.append(train_outputs_best['Policy Risk']) 200 | train_ate_error_0_1.append(train_outputs_best['Ate_Error_0_1']) 201 | 202 | 
train_factual_accuracy_of_t0.append(train_outputs_best['accuracy_0']) 203 | train_factual_accuracy_of_t1.append(train_outputs_best['accuracy_1']) 204 | train_factual_auc_of_t0.append(train_outputs_best['auc0']) 205 | train_factual_auc_of_t1.append(train_outputs_best['auc1']) 206 | 207 | train_treatment_accuracy.append(train_outputs_best['Treatment accuracy']) 208 | 209 | print("==========Average best test results==========") 210 | print("The best epoch:",epoch_index) 211 | print("average propensity for t: {}".format(np.mean(average_propensity_for_t0))) 212 | print("Policy Risk: {} +- {}".format(np.mean(policy_risk),np.std(policy_risk))) 213 | print("Ate_Error_0_1: {} +- {}".format(np.mean(test_ate_error_0_1),np.std(test_ate_error_0_1))) 214 | 215 | print("Treatment accuracy: {} +- {}".format(np.mean(treatment_accuracy),np.std(treatment_accuracy))) 216 | print("Treatment policy :",treatment_policy) 217 | print("Treatment prediction:",treatment_prediction) 218 | print("Treatment label :",treatment_label) 219 | print("Factual accuracy of t0: {} +- {}".format(np.mean(test_factual_accuracy_of_t0),np.std(test_factual_accuracy_of_t0))) 220 | print("Factual accuracy of t1: {} +- {}".format(np.mean(test_factual_accuracy_of_t1),np.std(test_factual_accuracy_of_t1))) 221 | print("Factual auc of t0: {} +- {}".format(np.mean(test_factual_auc_of_t0),np.std(test_factual_auc_of_t0))) 222 | print("Factual auc of t1: {} +- {}".format(np.mean(test_factual_auc_of_t1),np.std(test_factual_auc_of_t1))) 223 | 224 | # print("Factual auc of t0: {} +- {}".format(np.mean(test_factual_auc_of_t0),np.std(test_factual_auc_of_t0))) 225 | # print("Factual auc of t1: {} +- {}".format(np.mean(test_factual_auc_of_t1),np.std(test_factual_auc_of_t1))) 226 | # print("Factual auc of t2: {} +- {}".format(np.mean(test_factual_auc_of_t2),np.std(test_factual_auc_of_t2))) 227 | print("==========Average best train results=========") 228 | print("average propensity for t: {}".format(np.mean(train_average_propensity_for_t0))) 229 | print("Policy Risk: {} +- {}".format(np.mean(train_policy_risk),np.std(train_policy_risk))) 230 | print("Ate_Error_0_1: {} +- {}".format(np.mean(train_ate_error_0_1),np.std(train_ate_error_0_1))) 231 | 232 | print("Treatment accuracy: {} +- {}".format(np.mean(train_treatment_accuracy), np.std(train_treatment_accuracy))) 233 | print("=============================================") 234 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from torch.utils.data import TensorDataset, DataLoader 6 | from torch.utils.tensorboard import SummaryWriter 7 | from torchvision import datasets, transforms 8 | from torch.distributions import Normal, Independent, kl 9 | 10 | import torch.nn.init as init 11 | from torchvision import models 12 | 13 | from resnet3d import * 14 | from torch.autograd import Variable 15 | import math 16 | import numpy as np 17 | import copy 18 | import random 19 | 20 | class BasicBlock(nn.Module): 21 | expansion = 1 22 | 23 | def __init__(self, inplanes, planes, stride=1, downsample=None): 24 | super(BasicBlock, self).__init__() 25 | self.conv1 = conv3x3(inplanes, planes, stride) 26 | self.bn1 = nn.BatchNorm3d(planes) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.conv2 = conv3x3(planes, planes) 29 | self.bn2 = nn.BatchNorm3d(planes) 30 | self.downsample = 
downsample 31 | self.stride = stride 32 | 33 | def forward(self, x): 34 | residual = x 35 | 36 | out = self.conv1(x) 37 | out = self.bn1(out) 38 | out = self.relu(out) 39 | 40 | out = self.conv2(out) 41 | out = self.bn2(out) 42 | 43 | if self.downsample is not None: 44 | residual = self.downsample(x) 45 | 46 | out += residual 47 | out = self.relu(out) 48 | 49 | return out 50 | 51 | class EpsilonLayer(nn.Module): 52 | def __init__(self): 53 | super(EpsilonLayer, self).__init__() 54 | 55 | # building epsilon trainable weight 56 | self.weights = nn.Parameter(torch.Tensor(1, 1)) 57 | 58 | # initializing weight parameter with RandomNormal 59 | nn.init.normal_(self.weights, mean=0, std=0.05) 60 | 61 | def forward(self, inputs): 62 | return torch.mm(torch.ones_like(inputs)[:, 0:1], self.weights.T) 63 | 64 | class VAE_MM(nn.Module): 65 | def __init__(self,input_size,hidden_size,latent_size,out_features=[200, 64, 1]): 66 | super(VAE_MM,self).__init__() 67 | dropout = False 68 | self.input_size = input_size 69 | self.hidden_size = hidden_size 70 | self.latent_size = latent_size 71 | 72 | # self.encoder = nn.Sequential( 73 | # nn.Linear(input_size, hidden_size), 74 | # nn.ReLU(), 75 | # nn.Linear(hidden_size,latent_size*4) 76 | # ) 77 | 78 | self.decoder0 = nn.Sequential( 79 | #nn.Linear(latent_size+1,4), 80 | #nn.ReLU(), 81 | nn.Linear(latent_size+1,1) 82 | ) 83 | 84 | 85 | self.dropout = nn.Dropout(p=0.2) 86 | self.experts = ProductOfExperts() 87 | self.epsilon = EpsilonLayer() 88 | c=[64,64,128,256,512] 89 | layers = [3, 4, 6, 3] 90 | self.inplanes = c[0] 91 | self.share = torch.nn.Sequential() 92 | self.share.add_module('conv1', nn.Conv3d(1, c[0],kernel_size=7, stride=2, padding=0, bias=False)) 93 | self.share.add_module('bn1', nn.BatchNorm3d(c[0])) 94 | self.share.add_module('relu', nn.ReLU(inplace=True)) 95 | self.share.add_module('maxpool',nn.MaxPool3d(kernel_size=3, stride=2, padding=1)) 96 | self.share.add_module('layer1', self._make_layer(BasicBlock, c[1], layers[0])) 97 | self.share.add_module('layer2', self._make_layer(BasicBlock, c[2], layers[1], stride=2)) 98 | self.share.add_module('layer3', self._make_layer(BasicBlock, c[3], layers[2], stride=2)) 99 | self.share.add_module('layer4', self._make_layer(BasicBlock, c[4], layers[3], stride=2)) 100 | self.share.add_module('avgpool', nn.AvgPool3d([1,7,7])) 101 | 102 | 103 | if dropout is True: 104 | self.share.add_module('dropout', nn.Dropout(0.5)) 105 | self.resenet_head = nn.Sequential(nn.Linear(512, out_features[1]), nn.BatchNorm1d(out_features[1]), nn.ReLU()) 106 | self.resenet_head_t = nn.Sequential(nn.Linear(out_features[1]+1, out_features[1]), nn.BatchNorm1d(out_features[1]), nn.ReLU()) 107 | self.pnet_net_im_0 = nn.Sequential(nn.Linear(out_features[1], latent_size*2), nn.BatchNorm1d(latent_size*2)) 108 | self.pnet_net_im_1 = nn.Sequential(nn.Linear(out_features[1], latent_size*2), nn.BatchNorm1d(latent_size*2)) 109 | self.qnet_net_im_0 = nn.Sequential(nn.Linear(out_features[1]+1, latent_size*2), nn.BatchNorm1d(latent_size*2)) 110 | self.qnet_net_im_1 = nn.Sequential(nn.Linear(out_features[1]+1, latent_size*2), nn.BatchNorm1d(latent_size*2)) 111 | 112 | self.pnet_net_tab_0 = nn.Sequential(nn.Linear(out_features[1], latent_size*2), nn.BatchNorm1d(latent_size*2)) 113 | self.pnet_net_tab_1 = nn.Sequential(nn.Linear(out_features[1], latent_size*2), nn.BatchNorm1d(latent_size*2)) 114 | self.qnet_net_tab_0 = nn.Sequential(nn.Linear(out_features[1]+1, latent_size*2), nn.BatchNorm1d(latent_size*2)) 115 | self.qnet_net_tab_1 = 
nn.Sequential(nn.Linear(out_features[1]+1, latent_size*2), nn.BatchNorm1d(latent_size*2)) 116 | 117 | self.representation_block = nn.Sequential( 118 | nn.Linear(in_features=input_size, out_features=out_features[0]), 119 | nn.BatchNorm1d(out_features[0]), 120 | nn.ReLU(), 121 | nn.Linear(in_features=out_features[0], out_features=out_features[1]), 122 | nn.BatchNorm1d(out_features[1]), 123 | nn.ReLU() 124 | ) 125 | 126 | self.representation_block_0 = nn.Sequential( 127 | nn.Linear(in_features=out_features[1]+1, out_features=out_features[1]), 128 | nn.BatchNorm1d(out_features[1]), 129 | nn.ReLU() 130 | ) 131 | 132 | self.representation_block_1 = nn.Sequential( 133 | nn.Linear(in_features=out_features[1], out_features=out_features[1]), 134 | nn.BatchNorm1d(out_features[1]), 135 | nn.ReLU() 136 | ) 137 | 138 | for m in self.modules(): 139 | if isinstance(m,nn.Conv3d): 140 | nn.init.kaiming_uniform_(m.weight,mode='fan_out',nonlinearity='relu') 141 | elif isinstance(m,nn.BatchNorm1d): 142 | nn.init.constant_(m.weight,1) 143 | nn.init.constant_(m.bias,0) 144 | elif isinstance(m,nn.BatchNorm3d): 145 | nn.init.constant_(m.weight,1) 146 | nn.init.constant_(m.bias,0) 147 | if isinstance(m,nn.Linear): 148 | nn.init.normal_(m.weight, std=0.01) 149 | 150 | 151 | def _make_layer(self, block, planes, blocks, stride=1): 152 | downsample = None 153 | if stride != 1 or self.inplanes != planes * block.expansion: 154 | downsample = nn.Sequential(nn.Conv3d(self.inplanes, planes*block.expansion,kernel_size=1, stride=stride, bias=False), 155 | nn.BatchNorm3d(planes*block.expansion)) 156 | 157 | layers = [] 158 | layers.append(block(self.inplanes, planes, stride, downsample)) 159 | self.inplanes = planes * block.expansion 160 | for i in range(1, blocks): 161 | layers.append(block(self.inplanes, planes)) 162 | 163 | return nn.Sequential(*layers) 164 | 165 | def encode_tab(self,x,treatment): 166 | x = self.representation_block(x) 167 | treatment_0 = torch.zeros(treatment.shape).cuda() 168 | treatment_1 = torch.ones(treatment.shape).cuda() 169 | x_0 = torch.cat((x,treatment_0),dim=1) 170 | x_1 = torch.cat((x,treatment_1),dim=1) 171 | h_tab_0 = self.representation_block_0(x_0) 172 | h_tab_1 = self.representation_block_0(x_1) 173 | return h_tab_0, h_tab_1 174 | 175 | def encode_img(self,x,treatment): 176 | x = self.resenet_head(x) 177 | treatment_0 = torch.zeros(treatment.shape).cuda() 178 | treatment_1 = torch.ones(treatment.shape).cuda() 179 | x_0 = torch.cat((x,treatment_0),dim=1) 180 | x_1 = torch.cat((x,treatment_1),dim=1) 181 | h_im_0 = self.resenet_head_t(x_0) 182 | h_im_1 = self.resenet_head_t(x_1) 183 | return h_im_0, h_im_1 184 | 185 | def p_net_im(self,phi0_im,phi1_im): 186 | h_im0 = self.pnet_net_im_0(phi0_im) 187 | mu0_im, logvar0_im = torch.chunk(h_im0,2,dim=-1) 188 | h_im1 = self.pnet_net_im_0(phi1_im) 189 | mu1_im, logvar1_im = torch.chunk(h_im1,2,dim=-1) 190 | return mu0_im, logvar0_im, mu1_im, logvar1_im 191 | 192 | def p_net_tab(self,phi0_tab,phi1_tab): 193 | h_tab0 = self.pnet_net_tab_0(phi0_tab) 194 | mu0_tab, logvar0_tab = torch.chunk(h_tab0,2,dim=-1) 195 | h_tab1 = self.pnet_net_tab_0(phi1_tab) 196 | mu1_tab, logvar1_tab = torch.chunk(h_tab1,2,dim=-1) 197 | return mu0_tab, logvar0_tab, mu1_tab, logvar1_tab 198 | 199 | def q_net_im(self,phi0_im,phi1_im): 200 | h_im0 = self.qnet_net_im_0(phi0_im) 201 | mu0_im, logvar0_im = torch.chunk(h_im0,2,dim=-1) 202 | h_im1 = self.qnet_net_im_0(phi1_im) 203 | mu1_im, logvar1_im = torch.chunk(h_im1,2,dim=-1) 204 | return mu0_im, logvar0_im, mu1_im, logvar1_im 
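    # Note on the p_net_* / q_net_* helpers (above and below): each head emits
    # 2*latent_size values that torch.chunk splits into a mean and a log-variance for
    # one treatment arm; the q_* (posterior) heads additionally receive the factual
    # outcome, concatenated in VDC_q. PoE() then fuses the per-modality Gaussians with
    # ProductOfExperts (defined at the bottom of this file) after prepending a
    # unit-Gaussian prior expert from prior_expert(), i.e. precision-weighted averaging:
    #     T_i     = 1 / var_i
    #     mu_poe  = sum_i(mu_i * T_i) / sum_i(T_i)
    #     var_poe = 1 / sum_i(T_i)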
205 | 206 | def q_net_tab(self,phi0_tab,phi1_tab): 207 | h_tab0 = self.qnet_net_tab_0(phi0_tab) 208 | mu0_tab, logvar0_tab = torch.chunk(h_tab0,2,dim=-1) 209 | h_tab1 = self.qnet_net_tab_0(phi1_tab) 210 | mu1_tab, logvar1_tab = torch.chunk(h_tab1,2,dim=-1) 211 | return mu0_tab, logvar0_tab, mu1_tab, logvar1_tab 212 | 213 | def PoE(self, mu0_im, logvar0_im, mu1_im, logvar1_im, 214 | mu0_tab, logvar0_tab, mu1_tab, logvar1_tab): 215 | mu0, logvar0, mu1, logvar1 = prior_expert((1,mu0_im.shape[0],self.latent_size)) 216 | mu0 = torch.cat((mu0,mu0_tab.unsqueeze(0)),dim=0) 217 | logvar0 = torch.cat((logvar0,logvar0_tab.unsqueeze(0)),dim=0) 218 | mu1 = torch.cat((mu1,mu1_tab.unsqueeze(0)),dim=0) 219 | logvar1 = torch.cat((logvar1,logvar1_tab.unsqueeze(0)),dim=0) 220 | 221 | mu0 = torch.cat((mu0,mu0_im.unsqueeze(0)),dim=0) 222 | logvar0 = torch.cat((logvar0,logvar0_im.unsqueeze(0)),dim=0) 223 | mu1 = torch.cat((mu1,mu1_im.unsqueeze(0)),dim=0) 224 | logvar1 = torch.cat((logvar1,logvar1_im.unsqueeze(0)),dim=0) 225 | mu0, logvar0 = self.experts(mu0, logvar0) 226 | mu1, logvar1 = self.experts(mu1, logvar1) 227 | 228 | return mu0, logvar0, mu1, logvar1 229 | 230 | def VDC_p(self,phi0_im, phi1_im, phi0_tab, phi1_tab): 231 | mu0_im_p, logvar0_im_p, mu1_im_p, logvar1_im_p = self.p_net_im(phi0_im,phi1_im) 232 | mu0_tab_p, logvar0_tab_p, mu1_tab_p, logvar1_tab_p = self.p_net_tab(phi0_tab, phi1_tab) 233 | 234 | return self.PoE(mu0_im_p, logvar0_im_p, mu1_im_p, logvar1_im_p, 235 | mu0_tab_p, logvar0_tab_p, mu1_tab_p, logvar1_tab_p) 236 | 237 | def VDC_q(self, phi0_im, phi1_im, phi0_tab, phi1_tab,labels): 238 | phi0_im = torch.cat((phi0_im,labels[:,0].unsqueeze(1)),dim=1) 239 | phi1_im = torch.cat((phi1_im,labels[:,0].unsqueeze(1)),dim=1) 240 | phi0_tab = torch.cat((phi0_tab,labels[:,0].unsqueeze(1)),dim=1) 241 | phi1_tab = torch.cat((phi1_tab,labels[:,0].unsqueeze(1)),dim=1) 242 | mu0_im_p, logvar0_im_p, mu1_im_p, logvar1_im_p = self.q_net_im(phi0_im,phi1_im) 243 | mu0_tab_p, logvar0_tab_p, mu1_tab_p, logvar1_tab_p = self.q_net_tab(phi0_tab, phi1_tab) 244 | 245 | return self.PoE(mu0_im_p, logvar0_im_p, mu1_im_p, logvar1_im_p, 246 | mu0_tab_p, logvar0_tab_p, mu1_tab_p, logvar1_tab_p) 247 | 248 | def reparameterize(self,mu,logvar): 249 | std = torch.exp(0.5*logvar) 250 | eps = torch.randn_like(std) 251 | z = mu + eps*std 252 | return z 253 | 254 | 255 | def decode(self,z0,z1, treatment): 256 | treatment_0 = torch.zeros(treatment.shape).cuda() 257 | treatment_1 = torch.ones(treatment.shape).cuda() 258 | z0 = torch.cat((z0,treatment_0),dim=1) 259 | z1 = torch.cat((z1,treatment_1),dim=1) 260 | y0 = self.decoder0(z0) 261 | y1 = self.decoder0(z1) 262 | 263 | return F.sigmoid(y0)[:,0], F.sigmoid(y1)[:,0] 264 | 265 | 266 | def forward(self,x,labels,image,is_train=False): 267 | 268 | treatment = labels[:,1] 269 | treatment = torch.unsqueeze(treatment,dim=1) 270 | 271 | image = self.share.forward(image) 272 | image = self.dropout(image) 273 | image = image[:,:,0,0,0] 274 | phi0_im, phi1_im = self.encode_img(image,treatment) 275 | phi0_tab, phi1_tab = self.encode_tab(x,treatment) 276 | 277 | mu0_p, logvar0_p, mu1_p, logvar1_p = self.VDC_p(phi0_im, phi1_im,phi0_tab, phi1_tab) 278 | if is_train: 279 | mu0_q, logvar0_q, mu1_q, logvar1_q = self.VDC_q(phi0_im, phi1_im,phi0_tab, phi1_tab,labels) 280 | z0 = self.reparameterize(mu0_q, logvar0_q) 281 | z1 = self.reparameterize(mu1_q, logvar1_q) 282 | y0,y1 = self.decode(z0,z1,treatment) 283 | dist_p_0 = Independent(Normal(loc=mu0_p,scale=torch.exp(logvar0_p)),1) 284 | dist_q_0 = 
Independent(Normal(loc=mu0_q,scale=torch.exp(logvar0_q)),1) 285 | dist_p_1 = Independent(Normal(loc=mu1_p,scale=torch.exp(logvar1_p)),1) 286 | dist_q_1 = Independent(Normal(loc=mu1_q,scale=torch.exp(logvar1_q)),1) 287 | return y0, y1, dist_p_0, dist_q_0, dist_p_1, dist_q_1 288 | else: 289 | z0 = self.reparameterize(mu0_p, logvar0_p) 290 | z1 = self.reparameterize(mu1_p, logvar1_p) 291 | y0,y1 = self.decode(z0,z1,treatment) 292 | return y0, y1 293 | 294 | def prior_expert(size): 295 | mu0 = Variable(torch.zeros(size)) 296 | logvar0 = Variable(torch.log(torch.ones(size))) 297 | mu1 = Variable(torch.zeros(size)) 298 | logvar1 = Variable(torch.log(torch.ones(size))) 299 | mu0, logvar0, mu1, logvar1 = mu0.cuda(), logvar0.cuda(), mu1.cuda(), logvar1.cuda() 300 | return mu0, logvar0, mu1, logvar1 301 | 302 | class ProductOfExperts(nn.Module): 303 | def forward(self,mu,logvar,eps=1e-8): 304 | var = torch.exp(logvar)+eps 305 | T = 1./var 306 | pd_mu = torch.sum(mu*T,dim=0) / torch.sum(T, dim=0) 307 | pd_var = 1./torch.sum(T,dim=0) 308 | pd_logvar = torch.log(pd_var) 309 | 310 | return pd_mu, pd_logvar 311 | 312 | def VAE_loss_function(y0, y1, dist_p_0, dist_q_0, dist_p_1, dist_q_1, labels, class_ratio,ratio_as_t1): 313 | treatment = labels[:,1] 314 | map_0 = torch.zeros(treatment.shape) 315 | map_0[treatment==0] = 1 316 | map_1 = torch.zeros(treatment.shape) 317 | map_1[treatment==1] = 1 318 | BCE = torch.sum(map_0.cuda()*F.binary_cross_entropy(y0, labels[:,0], reduction='none'))/(torch.sum(map_0)+1e-8) + torch.sum(map_1.cuda()*F.binary_cross_entropy(y1, labels[:,0], reduction='none'))/(torch.sum(map_1)+1e-8) 319 | 320 | KLD = torch.mean(kl.kl_divergence(dist_q_0,dist_p_0)+kl.kl_divergence(dist_q_1,dist_p_1)) 321 | return BCE, 1.0*KLD 322 | 323 | def normal_loss_function(y0, y1, labels, class_ratio,ratio_as_t1): 324 | treatment = labels[:,1] 325 | map_0 = torch.zeros(treatment.shape) 326 | map_0[treatment==0] = 1 327 | map_1 = torch.zeros(treatment.shape) 328 | map_1[treatment==1] = 1 329 | BCE = torch.sum(map_0.cuda()*F.binary_cross_entropy(y0, labels[:,0], reduction='none'))/(torch.sum(map_0)+1e-8) + torch.sum(map_1.cuda()*F.binary_cross_entropy(y1, labels[:,0], reduction='none'))/(torch.sum(map_1)+1e-8) 330 | 331 | return BCE 332 | -------------------------------------------------------------------------------- /main_VAE.py: -------------------------------------------------------------------------------- 1 | from models import * 2 | import os 3 | import glob 4 | import argparse 5 | from sklearn.preprocessing import StandardScaler 6 | from tqdm import tqdm 7 | from sklearn.model_selection import train_test_split 8 | from upload_data import * 9 | import SimpleITK as sitk 10 | import cv2 11 | import numpy as np 12 | import math 13 | from tools import * 14 | 15 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 16 | #torch.backends.cudnn.enabled = False 17 | #random.shuffle 18 | 19 | def _split_output(yt_hat, t, y, y_scaler, x, is_train=False): 20 | """ 21 | Split output into dictionary for easier use in estimation 22 | Args: 23 | yt_hat: Generated prediction 24 | t: Binary treatment assignments 25 | y: Treatment outcomes 26 | y_scaler: Scaled treatment outcomes 27 | x: Covariates 28 | index: Index in data 29 | 30 | Returns: 31 | Dictionary of all needed data 32 | """ 33 | 34 | 35 | yt_hat = yt_hat.detach().cpu().numpy() 36 | q_t0 = yt_hat[:, 0].reshape(-1, 1).copy() 37 | q_t1 = yt_hat[:, 1].reshape(-1, 1).copy() 38 | 39 | g = yt_hat[:, 2].copy() 40 | treatment_predicted = g.copy() 41 | 
treatment_predicted[treatment_predicted>=0.5] = 1 42 | treatment_predicted[treatment_predicted<0.5] = 0 43 | 44 | y = y.copy() 45 | var = "average propensity for t: {}".format(g[t.squeeze() == 1.].mean()) 46 | 47 | q_cat = np.concatenate((q_t0, q_t1),1) 48 | 49 | policy = np.argmax(q_cat,1) 50 | #policy = np.zeros(q_cat.shape[0]) 51 | 52 | print(var) 53 | print("Policy Risk:", policy_risk_multi(t, y, q_t0, q_t1)) 54 | print("Ate_Error:", ate_error_0_1(t, y, q_t0 - q_t1)) 55 | 56 | print("Treatment accuracy:", np.sum(treatment_predicted==t.squeeze())/treatment_predicted.shape[0]) 57 | 58 | if not is_train: 59 | print("Treatment policy :",policy) 60 | print("Treatment prediction:",treatment_predicted) 61 | print("Treatment label :",t.squeeze().astype(int)) 62 | 63 | factual_auc(t, y, q_t0, q_t1) 64 | factual_acc(t, y, q_t0, q_t1) 65 | 66 | return {'ave propensity for t': g[t.squeeze() == 1.].mean(), 67 | 'Policy Risk': policy_risk_multi(t, y, q_t0, q_t1), 68 | 'Ate_Error_0_1': ate_error_0_1(t, y, q_t0 - q_t1), 'Treatment accuracy': np.sum(treatment_predicted==t.squeeze())/treatment_predicted.shape[0], 69 | 'Treatment policy': policy, 'Treatment prediction': treatment_predicted, 'Treatment label': t.squeeze().astype(int)} 70 | 71 | 72 | def train(train_loader, net, optimizer, criterion, class_ratio,ratio_as_t1): 73 | """ 74 | Trains network for one epoch in batches. 75 | 76 | Args: 77 | train_loader: Data loader for training set. 78 | net: Neural network model. 79 | optimizer: Optimizer (e.g. SGD). 80 | criterion: Loss function (e.g. cross-entropy loss). 81 | """ 82 | 83 | avg_loss_0 = 0 84 | avg_loss_1 = 0 85 | 86 | for i, data in enumerate(train_loader): 87 | 88 | # get the inputs; data is a list of [inputs, labels] 89 | inputs, labels, images = data 90 | #traumatic = inputs[:,3] 91 | # zero the parameter gradients 92 | optimizer.zero_grad() 93 | 94 | # forward + backward + optimize 95 | y0, y1, dist_p_0, dist_q_0, dist_p_1, dist_q_1 = net(inputs,labels,images,is_train=True) 96 | BCE,KLD = criterion(y0, y1, dist_p_0, dist_q_0, dist_p_1, dist_q_1, labels, class_ratio,ratio_as_t1) 97 | loss = BCE + KLD 98 | #loss = criterion(outputs, labels, traumatic, class_ratio) 99 | loss.backward() 100 | optimizer.step() 101 | 102 | # keep track of loss and accuracy 103 | avg_loss_0 += BCE 104 | avg_loss_1 += KLD 105 | 106 | return avg_loss_0 / len(train_loader), avg_loss_1 / len(train_loader), loss / len(train_loader) 107 | 108 | def test(train_loader, net, criterion, number): 109 | """ 110 | Trains network for one epoch in batches. 111 | 112 | Args: 113 | train_loader: Data loader for training set. 114 | net: Neural network model. 115 | optimizer: Optimizer (e.g. SGD). 116 | criterion: Loss function (e.g. cross-entropy loss). 
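        number: Total number of samples in the loader's dataset; used to preallocate
            the returned prediction buffer (columns 0 and 1 hold y0 and y1).

    Note: this function runs inference only; the network is put in eval mode,
    predictions are collected under torch.no_grad(), and the model is switched back
    to train mode before the stacked y0/y1 predictions are returned.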
117 | """ 118 | net.eval() 119 | avg_loss = 0 120 | 121 | yt_hat_test = torch.from_numpy(np.zeros((number,4))) 122 | num_ = 0 123 | with torch.no_grad(): 124 | for i, data in enumerate(train_loader): 125 | 126 | # get the inputs; data is a list of [inputs, labels] 127 | inputs, labels, images = data 128 | traumatic = inputs[:,3] 129 | # zero the parameter gradients 130 | 131 | y0, y1 = net(inputs,labels,images) 132 | yt_hat_test[num_:num_+y0.shape[0],0] = y0 133 | yt_hat_test[num_:num_+y0.shape[0],1] = y1 134 | num_ += y0.shape[0] 135 | net.train() 136 | return yt_hat_test 137 | 138 | def load_image(path): 139 | get_test_X = sitk.ReadImage(path) 140 | test_X = sitk.GetArrayFromImage(get_test_X).astype(np.float32) 141 | image = np.zeros((test_X.shape[0],224,224)).astype(np.float32) 142 | for num in range(len(image)): 143 | image[num] = cv2.resize(test_X[num], (224, 224)) 144 | return image 145 | 146 | def train_and_predict_dragons(t, y, x, img_path, targeted_regularization=True, output_dir='', 147 | knob_loss=VAE_loss_function, ratio=1., dragon='', val_split=0.2, batch_size=64, validation_index=0): 148 | """ 149 | Method for training dragonnet and tarnet and predicting new results 150 | Returns: 151 | Outputs on train and test data 152 | """ 153 | 154 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 155 | print(device) 156 | 157 | verbose = 0 158 | 159 | train_outputs_best = {} 160 | test_outputs_best = {} 161 | best_evaluation = 1. 162 | 163 | if dragon == 'tarnet': 164 | print('I am here making tarnet') 165 | net = TarNet(x.shape[1]).to("cuda") 166 | 167 | elif dragon == 'dragonnet': 168 | print("I am here making dragonnet") 169 | net = DragonNet(x.shape[1]).to("cuda") 170 | 171 | elif dragon == 'ours': 172 | print("I am here making ours") 173 | net = VAE_MM(x.shape[1],hidden_size=200,latent_size=10).to("cuda") 174 | 175 | # Which loss to use for training the network 176 | #net = torch.nn.DataParallel(net) 177 | if targeted_regularization: 178 | loss = make_tarreg_loss(ratio=ratio, dragonnet_loss=knob_loss) 179 | else: 180 | loss = knob_loss 181 | 182 | # loss = knob_loss 183 | # for reporducing the IHDP experimemt 184 | 185 | i = 0 186 | torch.manual_seed(i) 187 | np.random.seed(i) 188 | # Get the data and optionally divide into train and test set 189 | 190 | all_index = np.arange(int(x.shape[0])) 191 | 192 | y_origin = copy.copy(y) 193 | y[y_origin>=4]=0 194 | y[y_origin<=3]=1 195 | 196 | ivh = x[:,1] 197 | GCS = x[:,4] 198 | ivh_index = [] 199 | GCS_index = [] 200 | for limited_index in all_index: 201 | if ivh[limited_index]==1 and t[limited_index]==0: 202 | ivh_index.append(limited_index) 203 | if len(ivh_index)>=50: 204 | break 205 | for limited_index in all_index: 206 | if GCS[limited_index]<9 and t[limited_index]==1: 207 | GCS_index.append(limited_index) 208 | if len(GCS_index)>=50: 209 | break 210 | test_index = np.array(ivh_index+GCS_index) 211 | x = x[:,1:] 212 | 213 | train_index = [] 214 | for m in all_index: 215 | if m not in test_index: 216 | train_index.append(m) 217 | 218 | x_train, x_test = x[train_index], x[test_index] 219 | y_train, y_test = y[train_index], y[test_index] 220 | t_train, t_test = t[train_index], t[test_index] 221 | img_path_train, img_path_test = img_path[train_index], img_path[test_index] 222 | 223 | yt_train = np.concatenate([y_train, t_train], 1) 224 | yt_test = np.concatenate([y_test, t_test], 1) 225 | 226 | t0_index = np.where(t_train==0) 227 | t1_index = np.where(t_train==1) 228 | 229 | ratio_t0 = 
np.sum(y_train[t0_index])/len(y_train[t0_index])
230 |     ratio_t1 = np.sum(y_train[t1_index])/len(y_train[t1_index])
231 | 
232 |     class_ratio = [ratio_t0, ratio_t1]
233 | 
234 |     ratio_as_t1 = np.sum(t_train)/len(t_train)
235 | 
236 | 
237 |     train_data = trainerData3d_preload(img_path_train, x_train, y_train, t_train, is_train = True)
238 |     test_data = trainerData3d_preload(img_path_test, x_test, y_test, t_test, is_train = False)
239 |     train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last = True)
240 |     train_loader_test = DataLoader(train_data, batch_size=batch_size, shuffle=False)
241 |     test_loader = DataLoader(test_data, batch_size=256, shuffle=False)
242 | 
243 |     import time
244 |     start_time = time.time()
245 | 
246 |     # Configuring optimizers
247 |     # The network is trained for epochs1 epochs with the Adam optimizer;
248 |     # the SGD stage is currently disabled (see the commented-out lines below).
249 |     epochs1 = 1500
250 |     epochs2 = 500
251 | 
252 |     # No separate L2 regularization is applied to the t0/t1 heads in this version
253 | 
254 |     optimizer_Adam = optim.Adam(net.parameters(), lr=5e-3)
255 |     scheduler_Adam = optim.lr_scheduler.StepLR(optimizer=optimizer_Adam, step_size = 300, gamma=0.5)
256 |     #scheduler_SGD = optim.lr_scheduler.StepLR(optimizer=optimizer_SGD, step_size = 200, gamma=0.5)
257 | 
258 |     train_loss = 0
259 |     epochs0 = 0
260 | 
261 |     if epochs0 != 0:
262 |         load_model_path = '../models_save/IPH_limited_ours2/'+str(epochs0)+'.pth'
263 |         net.load_state_dict(torch.load(load_model_path))
264 | 
265 |     # Adam training run
266 |     for epoch in range(epochs0, epochs1):
267 |         # Train on data
268 |         train_loss0,train_loss1, train_sum = train(train_loader, net, optimizer_Adam, loss, class_ratio,ratio_as_t1)
269 |         scheduler_Adam.step()  # StepLR is stepped once per epoch; unlike ReduceLROnPlateau it takes no loss argument
270 | 
271 |         #train_loss = train(train_loader, net, optimizer_SGD, loss, class_ratio)
272 |         #scheduler_SGD.step(train_loss)
273 | 
274 | 
275 | 
276 |         if epoch % 10 ==0:
277 |             print("BCE:"+str(epoch)+"/"+str(epochs1)+" "+f"Adam loss: {train_loss0}")
278 |             print("KLD:"+str(epoch)+"/"+str(epochs1)+" "+f"Adam loss: {train_loss1}")
279 |             yt_hat_test = test(test_loader, net, loss, len(test_index))
280 |             yt_hat_train = test(train_loader_test, net, loss, len(train_index))
281 |             np.savez_compressed("../results_save/IPH_limited_ours2/{}_fold_{}_epoch_test.npz".format(validation_index, epoch),yt_hat_test=yt_hat_test,t_test=t_test,y_test=y_test,
282 |                                 y=y,x_test=x_test)
283 |             np.savez_compressed("../results_save/IPH_limited_ours2/{}_fold_{}_epoch_train.npz".format(validation_index, epoch),yt_hat_train=yt_hat_train,t_train=t_train,y_train=y_train,
284 |                                 y=y,x_train=x_train)
285 |             test_outputs = _split_output(yt_hat_test, t_test, y_test, y, x_test, is_train=False)
286 |             train_outputs = _split_output(yt_hat_train, t_train, y_train, y, x_train, is_train=True)
287 |             if test_outputs['Policy Risk'] <= best_evaluation:
288 |                 train_outputs_best = train_outputs
289 |                 test_outputs_best = test_outputs
290 |                 best_evaluation = test_outputs['Policy Risk']
291 |             print("==================the {} fold====================".format(validation_index))
292 | 
293 |         if epoch % 100 ==0:
294 |             save_model_path = '../models_save/IPH_limited_ours/'+str(epoch)+'.pth'
295 |             torch.save(net.state_dict(),save_model_path)
296 |             save_model_path = '../models_save/IPH_limited_ours/'+str(epoch)+ '_' + str(validation_index) + '_fold.pth'
297 |             torch.save(net.state_dict(),save_model_path)
298 |     return test_outputs_best, train_outputs_best
299 | 
300 | 
301 | def run_ihdp(data_base_dir, output_dir='~/result/IPH/',
302 | 
knob_loss=VAE_loss_function, 303 | ratio=1., dragon=''): 304 | 305 | print("the dragon is {}".format(dragon)) 306 | 307 | simulation_files = sorted(glob.glob("{}/*.xls".format(data_base_dir))) 308 | 309 | for idx, simulation_file in enumerate(simulation_files): 310 | 311 | simulation_output_dir = os.path.join(output_dir, str(idx)) 312 | 313 | os.makedirs(simulation_output_dir, exist_ok=True) 314 | 315 | x, img_path = load_and_format_covariates(simulation_file) 316 | t, y, y_cf, mu_0, mu_1 = load_all_other_crap(simulation_file) 317 | np.savez_compressed(os.path.join(simulation_output_dir, "simulation_outputs.npz"), 318 | t=t, y=y, y_cf=y_cf, mu_0=mu_0, mu_1=mu_1) 319 | 320 | average_propensity_for_t0 = [] 321 | average_propensity_for_t1 = [] 322 | average_propensity_for_t2 = [] 323 | policy_risk = [] 324 | ate_error_0_1 = [] 325 | ate_error_0_2 = [] 326 | ate_error_1_2 = [] 327 | treatment_accuracy = [] 328 | treatment_policy=np.array([]) 329 | treatment_prediction=np.array([]) 330 | treatment_label=np.array([]) 331 | 332 | 333 | train_average_propensity_for_t0 = [] 334 | train_average_propensity_for_t1 = [] 335 | train_average_propensity_for_t2 = [] 336 | train_policy_risk = [] 337 | train_ate_error_0_1 = [] 338 | train_ate_error_0_2 = [] 339 | train_ate_error_1_2 = [] 340 | train_treatment_accuracy = [] 341 | 342 | 343 | for validation_index in range(0,1): 344 | # print("Is targeted regularization: {}".format(is_targeted_regularization)) 345 | test_outputs_best, train_outputs_best = train_and_predict_dragons(t, y, x, img_path, 346 | targeted_regularization=False, 347 | output_dir=simulation_output_dir, 348 | knob_loss=knob_loss, ratio=ratio, dragon=dragon, 349 | val_split=0.2, batch_size=128, validation_index=validation_index) 350 | 351 | #np.savez_compressed("../results_save/cli_img/{}_fold_test.npz".format(validation_index),test_outputs_best) 352 | #np.savez_compressed("../results_save/cli_img/{}_fold_train.npz".format(validation_index),train_outputs_best) 353 | print("==========Best test results for the {} fold==========".format(validation_index)) 354 | print("average propensity for t: {}".format(test_outputs_best['ave propensity for t'])) 355 | print("Policy Risk:", test_outputs_best['Policy Risk']) 356 | print("Ate_Error_0_1:", test_outputs_best['Ate_Error_0_1']) 357 | 358 | 359 | print("Treatment accuracy:", test_outputs_best['Treatment accuracy']) 360 | print("Treatment policy :",test_outputs_best['Treatment policy']) 361 | print("Treatment prediction:",test_outputs_best['Treatment prediction']) 362 | print("Treatment label :",test_outputs_best['Treatment label']) 363 | print("==========Best train results for the {} fold==========".format(validation_index)) 364 | print("average propensity for t: {}".format(train_outputs_best['ave propensity for t'])) 365 | print("Policy Risk:", train_outputs_best['Policy Risk']) 366 | print("Ate_Error_0_1:", train_outputs_best['Ate_Error_0_1']) 367 | 368 | print("Treatment accuracy:", train_outputs_best['Treatment accuracy']) 369 | print("====================================================") 370 | average_propensity_for_t0.append(test_outputs_best['ave propensity for t']) 371 | 372 | policy_risk.append(test_outputs_best['Policy Risk']) 373 | ate_error_0_1.append(test_outputs_best['Ate_Error_0_1']) 374 | 375 | treatment_accuracy.append(test_outputs_best['Treatment accuracy']) 376 | treatment_policy=np.concatenate((treatment_policy,test_outputs_best['Treatment policy']),0) 377 | 
treatment_prediction=np.concatenate((treatment_prediction,test_outputs_best['Treatment prediction']),0) 378 | treatment_label=np.concatenate((treatment_label,test_outputs_best['Treatment label']),0) 379 | 380 | train_average_propensity_for_t0.append(train_outputs_best['ave propensity for t']) 381 | 382 | train_policy_risk.append(train_outputs_best['Policy Risk']) 383 | train_ate_error_0_1.append(train_outputs_best['Ate_Error_0_1']) 384 | 385 | train_treatment_accuracy.append(train_outputs_best['Treatment accuracy']) 386 | 387 | print("==========Average best test results==========") 388 | print("average propensity for t: {}".format(np.mean(average_propensity_for_t0))) 389 | print("Policy Risk:", np.mean(policy_risk)) 390 | print("Ate_Error_0_1:", np.mean(ate_error_0_1)) 391 | 392 | print("Treatment accuracy:", np.mean(treatment_accuracy)) 393 | print("Treatment policy :",treatment_policy) 394 | print("Treatment prediction:",treatment_prediction) 395 | print("Treatment label :",treatment_label) 396 | print("==========Average best train results=========") 397 | print("average propensity for t: {}".format(np.mean(train_average_propensity_for_t0))) 398 | print("Policy Risk:", np.mean(train_policy_risk)) 399 | print("Ate_Error_0_1:", np.mean(train_ate_error_0_1)) 400 | 401 | print("Treatment accuracy:", np.mean(train_treatment_accuracy)) 402 | print("=============================================") 403 | 404 | 405 | def turn_knob(data_base_dir, knob='dragonnet', 406 | output_base_dir=''): 407 | output_dir = os.path.join(output_base_dir, knob) 408 | 409 | if knob == 'dragonnet': 410 | run_ihdp(data_base_dir=data_base_dir, output_dir=output_dir, dragon='dragonnet') 411 | 412 | if knob == 'tarnet': 413 | run_ihdp(data_base_dir=data_base_dir, output_dir=output_dir, dragon='tarnet') 414 | 415 | if knob == 'ours': 416 | run_ihdp(data_base_dir=data_base_dir, output_dir=output_dir, dragon='ours') 417 | 418 | def main(): 419 | parser = argparse.ArgumentParser() 420 | parser.add_argument('--data_base_dir', type=str, help="path to directory",default='../data/IPH') 421 | parser.add_argument('--knob', type=str, default='ours', 422 | help="dragonnet or tarnet or ours") 423 | 424 | parser.add_argument('--output_base_dir', type=str, help="directory to save the output",default='../result/ours') 425 | 426 | args = parser.parse_args() 427 | turn_knob(args.data_base_dir, args.knob, args.output_base_dir) 428 | 429 | 430 | if __name__ == '__main__': 431 | main() 432 | --------------------------------------------------------------------------------
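Usage sketch: with the argparse defaults defined in main_VAE.py, training can be launched as

    python main_VAE.py --data_base_dir ../data/IPH --knob ours --output_base_dir ../result/ours

run_ihdp() looks for spreadsheets matching *.xls inside --data_base_dir, and the training loop writes checkpoints to ../models_save/IPH_limited_ours/ and per-epoch prediction arrays to ../results_save/IPH_limited_ours2/; these relative directories are assumed to exist alongside the code before training starts.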