├── assets
│   └── intro.png
├── data
│   └── IPH
│       └── example.xlsx
├── tools.py
├── README.md
├── upload_data.py
├── resnet3d.py
├── LICENSE
├── check_result.py
├── models.py
└── main_VAE.py
/assets/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/med-air/TOP-GPM/HEAD/assets/intro.png
--------------------------------------------------------------------------------
/data/IPH/example.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/med-air/TOP-GPM/HEAD/data/IPH/example.xlsx
--------------------------------------------------------------------------------
/tools.py:
--------------------------------------------------------------------------------
1 | from models import *
2 | import os
3 | import glob
4 | import argparse
5 | from sklearn.preprocessing import StandardScaler
6 | from tqdm import tqdm
7 | from sklearn.model_selection import train_test_split
8 | # from idhp_data import *
9 | import SimpleITK as sitk
10 | import cv2
11 | import numpy as np
12 | import math
13 |
14 |
15 |
16 | def policy_val(t, yf, q_t0, q_t1, compute_policy_curve=False):
17 |     # Value of the policy implied by the predicted outcomes: pick the arm with the
18 |     # larger predicted outcome, then average factual outcomes where policy == t.
19 | q_cat = np.concatenate((q_t0, q_t1),1)
20 |
21 | policy = np.argmax(q_cat,1)
22 | policy = policy[:,np.newaxis]
23 | t0_overlap = (policy==t)*(t==0)
24 | t1_overlap = (policy==t)*(t==1)
25 |
26 |
27 | if np.sum(t0_overlap) == 0:
28 | t0_value = 0
29 | else:
30 | t0_value = np.mean(yf[t0_overlap])
31 |
32 | if np.sum(t1_overlap) == 0:
33 | t1_value = 0
34 | else:
35 | t1_value = np.mean(yf[t1_overlap])
36 |
37 |     # pit_k: fraction of samples that the learned policy assigns to treatment arm k.
38 |
39 | pit_0 = np.sum(policy==0)/len(t)
40 | pit_1 = np.sum(policy==1)/len(t)
41 |
42 | policy_value = pit_0*t0_value + pit_1*t1_value
43 |
44 |
45 | return policy_value
46 |
47 | def factual_acc(t, yf, q_t0, q_t1):
48 |     # Binarize the predicted outcome probabilities at 0.5 (in place) and compute accuracy on each arm's factual samples.
49 | q_t0[q_t0>=0.5] = 1
50 | q_t0[q_t0<0.5] = 0
51 |
52 | q_t1[q_t1>=0.5] = 1
53 | q_t1[q_t1<0.5] = 0
54 |
55 |
56 | accuracy_0 = np.sum(q_t0[t==0]==yf[t==0])/len(yf[t==0])
57 | accuracy_1 = np.sum(q_t1[t==1]==yf[t==1])/len(yf[t==1])
58 |
59 |
60 | print("Factual accuracy of t0:", accuracy_0)
61 | print("Factual accuracy of t1:", accuracy_1)
62 |
63 | return accuracy_0,accuracy_1
64 |
65 | def factual_auc(t, yf, q_t0, q_t1):
66 | from sklearn import metrics
67 | y_t0 = []
68 | y_t1 = []
69 | p_t0 = []
70 | p_t1 = []
71 |     # Group predictions by the observed treatment arm, then compute per-arm ROC-AUC.
72 | for index in range(len(t)):
73 | if t[index] ==0:
74 | y_t0.append(yf[index])
75 | p_t0.append(q_t0[index])
76 | else:
77 | y_t1.append(yf[index])
78 | p_t1.append(q_t1[index])
79 |
80 |
81 | y_t0,p_t0, y_t1,p_t1 = np.array(y_t0), np.array(p_t0), np.array(y_t1), np.array(p_t1)
82 | auc0 = metrics.roc_auc_score(y_t0,p_t0)
83 | auc1 = metrics.roc_auc_score(y_t1,p_t1)
84 |
85 |
86 |
87 | print("Factual auc of t0:", auc0)
88 | print("Factual auc of t1:", auc1)
89 |
90 | return auc0,auc1
91 |
92 | def policy_risk_multi(t, yf, q_t0, q_t1):
93 | policy_value = policy_val(t, yf, q_t0, q_t1)
94 | policy_risk = 1 - policy_value
95 | return policy_risk
96 |
97 | def ate_error_0_1(t, yf, eff_pred):
98 | att = np.mean(yf[t==0]) - np.mean(yf[t==1])
99 | pred_att = np.mean(eff_pred)
100 |
101 | return np.abs(att-pred_att)
102 |
103 | def ate_error_0_2(t, yf, eff_pred):
104 | att = np.mean(yf[t==0]) - np.mean(yf[t==2])
105 | pred_att = np.mean(eff_pred)
106 |
107 | return np.abs(att-pred_att)
108 |
109 | def ate_error_1_2(t, yf, eff_pred):
110 | att = np.mean(yf[t==1]) - np.mean(yf[t==2])
111 | pred_att = np.mean(eff_pred)
112 |
113 | return np.abs(att-pred_att)
114 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Introduction
2 |
3 | + PyTorch implementation of the paper **[Treatment Outcome Prediction for Intracerebral Hemorrhage via Generative Prognostic Model with Imaging and Tabular Data](https://arxiv.org/pdf/2307.12858)**
4 |
5 |
6 |
7 |
8 |
9 | ## Abstract
10 | > Intracerebral hemorrhage (ICH) is the second most common and deadliest form of stroke. Despite medical advances, predicting treatment outcomes for ICH remains a challenge.
11 | This paper proposes a novel prognostic model that utilizes both imaging and tabular data to predict treatment outcome for ICH. Our model is trained on observational data collected from non-randomized controlled trials, providing reliable predictions of treatment success. Specifically, we propose to employ a variational autoencoder model to generate a low-dimensional prognostic score, which can effectively address the selection bias resulting from the non-randomized controlled trials. Importantly, we develop a variational distributions combination module that combines the information from imaging data, non-imaging clinical data, and treatment assignment to accurately generate the prognostic score. We conducted extensive experiments on a real-world clinical dataset of intracerebral hemorrhage. Our proposed method demonstrates a substantial improvement in treatment outcome prediction compared to existing state-of-the-art approaches.
12 |
13 | ## Setup
14 |
15 | ### Requirements
16 | The model has been tested on the following system and package versions:
17 |
18 | + Linux: Ubuntu 18.04
19 |
20 | ```bash
21 | Package Version
22 | ---------------------- -------------------
23 | torch 1.4.0
24 | torchvision 0.5.0
25 | h5py 3.1.0
26 | opencv-python 4.5.2.52
27 | SimpleITK 2.0.2
28 | scikit-image            0.17.2
29 | ml-collections 0.1.1
30 | tensorboardx 2.2.0
31 | medpy 0.4.0
32 | scikit-learn 0.24.2
33 | pandas 1.1.5
34 | ```
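
One possible way to set up a matching environment is to install the pinned versions with pip (a sketch only; it assumes a Python version compatible with torch 1.4.0, e.g. Python 3.6/3.7, and a CUDA-capable machine for training):

```bash
pip install torch==1.4.0 torchvision==0.5.0 h5py==3.1.0 opencv-python==4.5.2.52 \
    SimpleITK==2.0.2 scikit-image==0.17.2 ml-collections==0.1.1 tensorboardx==2.2.0 \
    medpy==0.4.0 scikit-learn==0.24.2 pandas==1.1.5
```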
35 |
36 | ## Training & Testing
37 | + This work uses a private dataset, so you need to prepare your own dataset to run the code.
38 | + Specifically, you need to prepare an .xlsx file that stores each patient's non-imaging clinical data and the path to the corresponding imaging data. An example file is provided at "./data/IPH/example.xlsx" (see the sketch after this list for the expected column layout).
39 | + Run main_VAE.py to train and evaluate the model:
40 | ```bash
41 | python main_VAE.py
42 | ```
43 | + Our proposed model is implemented in models.py as the class "VAE_MM".
44 |
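As a reference for preparing your own spreadsheet, below is a minimal sketch of how `upload_data.py` reads the file (column order inferred from `load_and_format_covariates` and `load_all_other_crap`; the first data row is skipped, and the column meanings are our reading of the code rather than an official specification):

```python
import pandas as pd

# Expected column order (0-indexed) in the .xlsx file:
#   0: treatment t, 1: factual outcome y, 2: counterfactual outcome y_cf,
#   3: mu_0, 4: mu_1, 5: path to the imaging data, 6+: tabular covariates.
data = pd.read_excel("./data/IPH/example.xlsx").values[1:, ]
t, y = data[:, 0].astype(float), data[:, 1].astype(float)
img_path, x = data[:, 5], data[:, 6:].astype(float)
```

For orientation, main_VAE.py instantiates the model as `VAE_MM(x.shape[1], hidden_size=200, latent_size=10)`; in training mode the forward pass returns the two potential-outcome predictions together with the prior/posterior distributions consumed by `VAE_loss_function`, while in evaluation mode it returns only `(y0, y1)`.
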
45 | ## Citation
46 | If this repository is useful for your research, please cite:
47 |
48 | @inproceedings{ma2023treatment,
49 | title={Treatment Outcome Prediction for Intracerebral Hemorrhage via Generative Prognostic Model with Imaging and Tabular Data},
50 | author={Ma, Wenao and Chen, Cheng and Abrigo, Jill and Mak, Calvin Hoi-Kwan and Gong, Yuqi and Chan, Nga Yan and Han, Chu and Liu, Zaiyi and Dou, Qi},
51 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
52 | pages={715--725},
53 | year={2023},
54 | organization={Springer}
55 | }
56 |
57 | ## Contact
58 | For any questions, please contact 'wama@cse.cuhk.edu.hk'
59 |
60 | ## License
61 | This project is covered under the **Apache 2.0 License**.
62 |
63 |
--------------------------------------------------------------------------------
/upload_data.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | from torch.utils.data import Dataset, DataLoader
4 | from torchvision import transforms, utils
5 | import SimpleITK as sitk
6 | import torch
7 |
8 |
9 |
10 | class trainerData3d_preload(Dataset):
11 | def __init__(self, img_path, data, outcome, treatment, is_train = True):
12 | self.is_train = is_train
13 | self.img_path = img_path
14 | self.data = data
15 | self.outcome = outcome
16 | self.treatment = treatment
17 | self.all_image_data = []
18 | for index in range(len(self.img_path)):
19 |             # Preload each volume once and keep the 20 slices around the volume centre.
20 | get_img = sitk.ReadImage('../../../../' + self.img_path[index]+'/Img_final_0.nii.gz')
21 | return_img = sitk.GetArrayFromImage(get_img).astype(np.float32)
22 |
23 | num_index = len(return_img) // 2
24 | return_img = return_img[num_index-10:num_index+10]
25 | return_img = return_img[np.newaxis,:,:,:]
26 | self.all_image_data.append(return_img)
27 |
28 | def __getitem__(self, index):
29 | return_data = torch.from_numpy(self.data[index]).float().cuda()
30 | return_yt = torch.from_numpy(np.concatenate([self.outcome[index], self.treatment[index]], 0)).float().cuda()
31 |
32 | return_img = self.all_image_data[index]
33 | return_img = torch.from_numpy(return_img).float().cuda()
34 | return return_data, return_yt, return_img
35 | def __len__(self):
36 | return len(self.img_path)
37 |
38 |
39 |
40 | class trainerData_cli(Dataset):
41 | def __init__(self, img_path, data, outcome, treatment, is_train = True):
42 | self.is_train = is_train
43 | self.img_path = img_path
44 | self.data = data
45 | self.outcome = outcome
46 | self.treatment = treatment
47 | def __getitem__(self, index):
48 | return_data = torch.from_numpy(self.data[index]).float().cuda()
49 | return_yt = torch.from_numpy(np.concatenate([self.outcome[index], self.treatment[index]], 0)).float().cuda()
50 | #return_outcome = torch.from_numpy(self.outcome[index]).float().cuda()
51 | #return_treatment = torch.from_numpy(self.return_treatment[index]).float().cuda()
52 |
53 |
54 | return return_data, return_yt
55 | def __len__(self):
56 | return len(self.img_path)
57 |
58 |
59 | class trainerData(Dataset):
60 | def __init__(self, img_path, data, outcome, treatment, is_train = True):
61 | self.is_train = is_train
62 | self.img_path = img_path
63 | self.data = data
64 | self.outcome = outcome
65 | self.treatment = treatment
66 | def __getitem__(self, index):
67 | return_data = torch.from_numpy(self.data[index]).float().cuda()
68 | return_yt = torch.from_numpy(np.concatenate([self.outcome[index], self.treatment[index]], 0)).float().cuda()
69 | #return_outcome = torch.from_numpy(self.outcome[index]).float().cuda()
70 | #return_treatment = torch.from_numpy(self.return_treatment[index]).float().cuda()
71 | try:
72 | get_img = sitk.ReadImage('../../../../' + self.img_path[index]+'/Img_final_0.nii.gz')
73 | return_img = sitk.GetArrayFromImage(get_img).astype(np.float32)
74 | if return_img.shape[0] < 14:
75 | print(self.img_path[index])
76 | return_img = np.zeros((25,224,224))
77 | except:
78 | return_img = np.zeros((25,224,224))
79 | num_index = len(return_img) // 2
80 | return_img = torch.from_numpy(return_img[num_index - 2: num_index + 1]).float().cuda()
81 | return return_data, return_yt, return_img
82 | def __len__(self):
83 | return len(self.img_path)
84 |
85 |
86 | class trainerData_single(Dataset):
87 | def __init__(self, img_path, data, outcome, treatment, is_train = True):
88 | self.is_train = is_train
89 | self.img_path = img_path
90 | self.data = data
91 | self.outcome = outcome
92 | self.treatment = treatment
93 | def __getitem__(self, index):
94 | return_data = torch.from_numpy(self.data[index]).float().cuda()
95 | return_yt = torch.from_numpy(np.concatenate([self.outcome[index], self.treatment[index]], 0)).float().cuda()
96 | #return_outcome = torch.from_numpy(self.outcome[index]).float().cuda()
97 | #return_treatment = torch.from_numpy(self.return_treatment[index]).float().cuda()
98 |
99 | return return_data, return_yt
100 | def __len__(self):
101 | return len(self.img_path)
102 |
103 | def convert_file(x):
104 | x = x.values
105 | x = x.astype(float)
106 | return x
107 |
108 |
109 | def load_and_format_covariates(file_path):
110 |     # Spreadsheet layout after the skipped first row: mu_0 (col 3), mu_1 (col 4), image path (col 5), covariates (cols 6+).
111 | data = pd.read_excel(file_path)
112 |
113 | data = data.values[1:, ]
114 |
115 | #binfeats = list(range(6,37))
116 | #contfeats = [i for i in range(37) if i not in binfeats]
117 |
118 | mu_0, mu_1, path, x = data[:, 3][:, None], data[:, 4][:, None], data[:, 5], data[:, 6:]
119 | #perm = binfeats
120 | #x = x[:, perm].astype(float)
121 |
122 | # for num in range(len(x)):
123 | # a1 = x[num].astype(float)
124 | # print(num)
125 |
126 | return x.astype(float), path
127 |
128 |
129 | def load_all_other_crap(file_path):
130 | data = pd.read_excel(file_path)
131 | data = data.values[1:, ]
132 | t, y, y_cf = data[:, 0], data[:, 1][:, None], data[:, 2][:, None]
133 | mu_0, mu_1, x = data[:, 3][:, None], data[:, 4][:, None], data[:, 6:]
134 | return t.reshape(-1, 1).astype(float), y.astype(float), y_cf.astype(float), mu_0.astype(float), mu_1.astype(float)
135 |
136 | def main():
137 | pass
138 |
139 |
140 | if __name__ == '__main__':
141 | main()
142 |
--------------------------------------------------------------------------------
/resnet3d.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.model_zoo as model_zoo
3 | import torch.nn.functional as F
4 |
5 | #__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
6 |
7 | def conv3x3(in_planes, out_planes, stride=1):
8 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
9 |
10 | class BasicBlock(nn.Module):
11 | expansion = 1
12 |
13 | def __init__(self, inplanes, planes, stride=1, downsample=None):
14 | super(BasicBlock, self).__init__()
15 | self.conv1 = conv3x3(inplanes, planes, stride)
16 | self.bn1 = nn.BatchNorm3d(planes)
17 | self.relu = nn.ReLU(inplace=True)
18 | self.conv2 = conv3x3(planes, planes)
19 | self.bn2 = nn.BatchNorm3d(planes)
20 | self.downsample = downsample
21 | self.stride = stride
22 |
23 | def forward(self, x):
24 | residual = x
25 |
26 | out = self.conv1(x)
27 | out = self.bn1(out)
28 | out = self.relu(out)
29 |
30 | out = self.conv2(out)
31 | out = self.bn2(out)
32 |
33 | if self.downsample is not None:
34 | residual = self.downsample(x)
35 |
36 | out += residual
37 | out = self.relu(out)
38 |
39 | return out
40 |
41 | class Bottleneck(nn.Module):
42 | expansion = 4
43 |
44 | def __init__(self, inplanes, planes, stride=1, downsample=None):
45 | super(Bottleneck, self).__init__()
46 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size = 1, bias = False)
47 | self.bn1 = nn.BatchNorm3d(planes)
48 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
49 | self.bn2 = nn.BatchNorm3d(planes)
50 | self.conv3 = nn.Conv3d(planes, planes*self.expansion, kernel_size=1, bias=False)
51 | self.bn3 = nn.BatchNorm3d(planes*self.expansion)
52 | self.relu = nn.ReLU(inplace=True)
53 | self.stride = stride
54 |
55 | def forward(self, x):
56 | residual = x
57 |
58 | out = self.conv1(x)
59 | out = self.bn1(out)
60 | out = self.relu(out)
61 |
62 | out = self.conv2(out)
63 | out = self.bn2(out)
64 | out = self.relu(out)
65 |
66 | out = self.conv3(out)
67 | out = self.bn3(out)
68 |
69 | if self.downsample is not None:
70 | residual = self.downsample(x)
71 |
72 | out += residual
73 | out = self.relu(out)
74 |
75 | return out
76 |
77 | # class ResNet(nn.Module):
78 | # def __init__(self, block, layers, num_classes=1000, channel_size=[64,64,128,256,512],dropout=False):
79 | # c = channel_size
80 | # self.inplanes = c[0]
81 | # super(ResNet, self).__init__()
82 | # net = nn.Sequential()
83 | # net.add_module('conv1', nn.Conv3d(1, c[0],kernel_size=7, stride=2, padding=0, bias=False))
84 | # net.add_module('bn1', nn.BatchNorm3d(c[0]))
85 | # net.add_module('relu', nn.ReLU(inplace=True))
86 | # net.add_module('maxpool',nn.MaxPool3d(kernel_size=3, stride=2, padding=1))
87 | # net.add_module('layer1', self._make_layer(block, c[1], layers[0]))
88 | # net.add_module('layer2', self._make_layer(block, c[2], layers[1], stride=2))
89 | # net.add_module('layer3', self._make_layer(block, c[3], layers[2], stride=2))
90 | # net.add_module('layer4', self._make_layer(block, c[4], layers[3], stride=2))
91 | # net.add_module('avgpool', nn.AvgPool3d([5,6,5], stride=1))
92 | # if dropout is True:
93 | # net.add_module('dropout', nn.Dropout(0.5))
94 | # self.feature_extractor = net
95 | # self.classifier = nn.Linear(c[4] * block.expansion, num_classes)
96 |
97 | # for m in self.modules():
98 | # if isinstance(m, nn.Conv3d):
99 | # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
100 | # elif isinstance(m, nn.BatchNorm3d):
101 | # nn.init.constant_(m.weight,1)
102 | # nn.init.constant_(m.bias, 0)
103 |
104 | # def _make_layer(self, block, planes, blocks, stride=1):
105 | # downsample = None
106 | # if stride != 1 or self.inplanes != planes * block.expansion:
107 | # downsample = nn.Sequential(nn.Conv3d(self.inplanes, planes*block.expansion,kernel_size=1, stride=stride, bias=False),
108 | # nn.BatchNorm3d(planes*block.expansion))
109 |
110 | # layers = []
111 | # layers.append(block(self.inplanes, planes, stride, downsample))
112 | # self.inplanes = planes * block.expansion
113 | # for i in range(1, blocks):
114 | # layers.append(block(self.inplanes, planes))
115 |
116 | # return nn.Sequential(*layers)
117 |
118 | # def forward(self, x):
119 | # x = self.feature_extractor(x)
120 | # x = x.view(x.size(0),-1)
121 | # x = self.classifier(x)
122 | # x = F.log_softmax(x)
123 | # return x
124 |
125 |
126 |
127 | # def resnet18(**kwargs):
128 | # model = ResNet(BasicBlock, [2,2,2,2], **kwargs)
129 | # return model
130 |
131 | # def resnet34(**kwargs):
132 | # model = ResNet(BasicBlock, [3,4,6,3], **kwargs)
133 | # return model
134 |
135 | # def resnet50(**kwargs):
136 | # model = ResNet(BasicBlock, [3,4,6,3], **kwargs)
137 | # return model
138 |
139 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/check_result.py:
--------------------------------------------------------------------------------
1 | from models import *
2 | import os
3 | import glob
4 | import argparse
5 | from sklearn.preprocessing import StandardScaler
6 | from tqdm import tqdm
7 | from sklearn.model_selection import train_test_split
8 | from sklearn import metrics
9 | # from idhp_data import *
10 | import SimpleITK as sitk
11 | import cv2
12 | import numpy as np
13 | import math
14 | import copy
15 | from tools import *
16 |
17 | def _split_output(yt_hat, t, y, y_scaler, x, is_train=False):
18 | """
19 | Split output into dictionary for easier use in estimation
20 | Args:
21 | yt_hat: Generated prediction
22 | t: Binary treatment assignments
23 | y: Treatment outcomes
24 | y_scaler: Scaled treatment outcomes
25 | x: Covariates
26 | index: Index in data
27 |
28 | Returns:
29 | Dictionary of all needed data
30 | """
31 |
32 |
33 |     # yt_hat columns: q_t0 (col 0), q_t1 (col 1), propensity g (col 2).
34 | q_t0 = yt_hat[:, 0].reshape(-1, 1).copy()
35 | q_t1 = yt_hat[:, 1].reshape(-1, 1).copy()
36 |
37 | g = yt_hat[:, 2].copy()
38 | treatment_predicted = g.copy()
39 | treatment_predicted[treatment_predicted>=0.5] = 1
40 | treatment_predicted[treatment_predicted<0.5] = 0
41 |
42 | y = y.copy()
43 | var = "average propensity for t: {}".format(g[t.squeeze() == 1.].mean())
44 |
45 | q_cat = np.concatenate((q_t0, q_t1),1)
46 |
47 | policy = np.argmax(q_cat,1)
48 |
49 | print(var)
50 | print("Policy Risk:", policy_risk_multi(t, y, q_t0, q_t1))
51 | print("Ate_Error:", ate_error_0_1(t, y, q_t0 - q_t1))
52 |
53 | print("Treatment accuracy:", np.sum(treatment_predicted==t.squeeze())/treatment_predicted.shape[0])
54 |
55 | if not is_train:
56 | print("Treatment policy :",policy)
57 | print("Treatment prediction:",treatment_predicted)
58 | print("Treatment label :",t.squeeze().astype(int))
59 |
60 | auc0,auc1 = factual_auc(t, y, q_t0, q_t1)
61 | accuracy_0, accuracy_1 = factual_acc(t, y, q_t0, q_t1)
62 |
63 |
64 | return {'ave propensity for t': g[t.squeeze() == 1.].mean(),
65 | 'Policy Risk': policy_risk_multi(t, y, q_t0, q_t1),
66 | 'Ate_Error_0_1': ate_error_0_1(t, y, q_t0 - q_t1), 'Treatment accuracy': np.sum(treatment_predicted==t.squeeze())/treatment_predicted.shape[0],
67 | 'Treatment policy': policy, 'Treatment prediction': treatment_predicted, 'Treatment label': t.squeeze().astype(int), 'accuracy_0': accuracy_0, 'accuracy_1':accuracy_1,
68 | 'auc0': auc0, 'auc1':auc1}
69 |
70 | average_propensity_for_t0 = []
71 | average_propensity_for_t1 = []
72 | average_propensity_for_t2 = []
73 | policy_risk = []
74 | test_ate_error_0_1 = []
75 | test_ate_error_0_2 = []
76 | test_ate_error_1_2 = []
77 | treatment_accuracy = []
78 | treatment_policy=np.array([])
79 | treatment_prediction=np.array([])
80 | treatment_label=np.array([])
81 | test_factual_accuracy_of_t0 = []
82 | test_factual_accuracy_of_t1 = []
83 | test_factual_accuracy_of_t2 = []
84 |
85 |
86 | train_average_propensity_for_t0 = []
87 | train_average_propensity_for_t1 = []
88 | train_average_propensity_for_t2 = []
89 | train_policy_risk = []
90 | train_ate_error_0_1 = []
91 | train_ate_error_0_2 = []
92 | train_ate_error_1_2 = []
93 | train_treatment_accuracy = []
94 | train_factual_accuracy_of_t0 = []
95 | train_factual_accuracy_of_t1 = []
96 | train_factual_accuracy_of_t2 = []
97 |
98 | train_factual_auc_of_t0 = []
99 | train_factual_auc_of_t1 = []
100 | test_factual_auc_of_t0 = []
101 | test_factual_auc_of_t1 = []
102 |
103 | key_word = 'Treatment accuracy'
104 | key_word4 = 'Policy Risk'
105 | key_word5 = 'accuracy_0'
106 | key_word6 = 'accuracy_1'
107 | key_word7 = 'accuracy_2'
108 |
109 | key_word1 = 'Ate_Error_0_1'
110 | key_word2 = 'Ate_Error_0_2'
111 | key_word3 = 'Ate_Error_1_2'
112 |
113 | key_word_auc0 = 'auc0'
114 | key_word_auc1 = 'auc1'
115 | epoch_index = 0
116 |
117 | for validation_index in range(1):
118 | best_evaluation = 0.
119 | train_outputs_best = {}
120 | test_outputs_best = {}
121 | for epoch in range(0,1500,10):
122 | test_results = np.load("../results_save/IPH_limited_ours2/{}_fold_{}_epoch_test.npz".format(validation_index, epoch), allow_pickle=True)
123 | train_results = np.load("../results_save/IPH_limited_ours2/{}_fold_{}_epoch_train.npz".format(validation_index, epoch), allow_pickle=True)
124 |
125 | yt_hat_test, t_test, y_test, y, x_test = test_results['yt_hat_test'], test_results['t_test'], test_results['y_test'], \
126 | test_results['y'], test_results['x_test']
127 | yt_hat_train, t_train, y_train, y, x_train = train_results['yt_hat_train'], train_results['t_train'], train_results['y_train'], \
128 | train_results['y'], train_results['x_train']
129 |
130 | test_outputs = _split_output(yt_hat_test, t_test, y_test, y, x_test, is_train=False)
131 | train_outputs = _split_output(yt_hat_train, t_train, y_train, y, x_train, is_train=True)
132 | #test_outputs = test_outputs['arr_0'].item()
133 | #train_outputs = train_outputs['arr_0'].item()
134 | if test_outputs[key_word_auc0]+test_outputs[key_word_auc1] >= best_evaluation and epoch>=100:
135 | #if (test_outputs[key_word1]+test_outputs[key_word2]+test_outputs[key_word3]+test_outputs[key_word4]+(1-test_outputs[key_word5])+(1-test_outputs[key_word6])+(1-test_outputs[key_word7]))/7 <= best_evaluation and epoch>=500:
136 | test_outputs_best = test_outputs
137 | best_evaluation = test_outputs[key_word_auc0]+test_outputs[key_word_auc1]
138 | epoch_index = epoch
139 | #best_evaluation = (test_outputs[key_word1]+test_outputs[key_word2]+test_outputs[key_word3]+test_outputs[key_word4]+(1-test_outputs[key_word5])+(1-test_outputs[key_word6])+(1-test_outputs[key_word7]))/7
140 |
141 | train_outputs_best = train_outputs
142 | # if (train_outputs[key_word1]+train_outputs[key_word2]+train_outputs[key_word3]+train_outputs[key_word4]+(1-train_outputs[key_word5])+(1-train_outputs[key_word6])+(1-train_outputs[key_word7]))/7 <= best_evaluation and epoch>=500:
143 | # train_outputs_best = train_outputs
144 |
145 | # #best_evaluation = test_outputs[key_word]
146 | # best_evaluation = (train_outputs[key_word1]+train_outputs[key_word2]+train_outputs[key_word3]+train_outputs[key_word4]+(1-train_outputs[key_word5])+(1-train_outputs[key_word6])+(1-train_outputs[key_word7]))/7
147 |
148 |
149 | print("==========Best test results for the {} fold==========".format(validation_index))
150 |
151 | print("average propensity for t: {}".format(test_outputs_best['ave propensity for t']))
152 | print("Policy Risk:", test_outputs_best['Policy Risk'])
153 | print("Ate_Error_0_1:", test_outputs_best['Ate_Error_0_1'])
154 |
155 | print("Treatment accuracy:", test_outputs_best['Treatment accuracy'])
156 | print("Treatment policy :",test_outputs_best['Treatment policy'])
157 | print("Treatment prediction:",test_outputs_best['Treatment prediction'])
158 | print("Treatment label :",test_outputs_best['Treatment label'])
159 | print("Factual accuracy of t0:", test_outputs_best['accuracy_0'])
160 | print("Factual accuracy of t1:", test_outputs_best['accuracy_1'])
161 | print("Factual auc of t0:", test_outputs_best['auc0'])
162 | print("Factual auc of t1:", test_outputs_best['auc1'])
163 |
164 | # print("Factual auc of t0:", test_outputs_best['auc_0'])
165 |
166 | print("==========Best train results for the {} fold==========".format(validation_index))
167 | print("average propensity for t: {}".format(train_outputs_best['ave propensity for t']))
168 | print("Policy Risk:", train_outputs_best['Policy Risk'])
169 | print("Ate_Error_0_1:", train_outputs_best['Ate_Error_0_1'])
170 |
171 | print("Treatment accuracy:", train_outputs_best['Treatment accuracy'])
172 | print("Factual accuracy of t0:", train_outputs_best['accuracy_0'])
173 | print("Factual accuracy of t1:", train_outputs_best['accuracy_1'])
174 | print("Factual auc of t0:", train_outputs_best['auc0'])
175 | print("Factual auc of t1:", train_outputs_best['auc1'])
176 |
177 | # print("Factual auc of t0:", train_outputs_best['auc_0'])
178 |
179 | print("====================================================")
180 | average_propensity_for_t0.append(test_outputs_best['ave propensity for t'])
181 |
182 | policy_risk.append(test_outputs_best['Policy Risk'])
183 | test_ate_error_0_1.append(test_outputs_best['Ate_Error_0_1'])
184 |
185 | treatment_accuracy.append(test_outputs_best['Treatment accuracy'])
186 | test_factual_accuracy_of_t0.append(test_outputs_best['accuracy_0'])
187 | test_factual_accuracy_of_t1.append(test_outputs_best['accuracy_1'])
188 | test_factual_auc_of_t0.append(test_outputs_best['auc0'])
189 | test_factual_auc_of_t1.append(test_outputs_best['auc1'])
190 |
191 | # test_factual_auc_of_t0.append(test_outputs_best['auc_0'])
192 |
193 | treatment_policy=np.concatenate((treatment_policy,test_outputs_best['Treatment policy']),0)
194 | treatment_prediction=np.concatenate((treatment_prediction,test_outputs_best['Treatment prediction']),0)
195 | treatment_label=np.concatenate((treatment_label,test_outputs_best['Treatment label']),0)
196 |
197 | train_average_propensity_for_t0.append(train_outputs_best['ave propensity for t'])
198 |
199 | train_policy_risk.append(train_outputs_best['Policy Risk'])
200 | train_ate_error_0_1.append(train_outputs_best['Ate_Error_0_1'])
201 |
202 | train_factual_accuracy_of_t0.append(train_outputs_best['accuracy_0'])
203 | train_factual_accuracy_of_t1.append(train_outputs_best['accuracy_1'])
204 | train_factual_auc_of_t0.append(train_outputs_best['auc0'])
205 | train_factual_auc_of_t1.append(train_outputs_best['auc1'])
206 |
207 | train_treatment_accuracy.append(train_outputs_best['Treatment accuracy'])
208 |
209 | print("==========Average best test results==========")
210 | print("The best epoch:",epoch_index)
211 | print("average propensity for t: {}".format(np.mean(average_propensity_for_t0)))
212 | print("Policy Risk: {} +- {}".format(np.mean(policy_risk),np.std(policy_risk)))
213 | print("Ate_Error_0_1: {} +- {}".format(np.mean(test_ate_error_0_1),np.std(test_ate_error_0_1)))
214 |
215 | print("Treatment accuracy: {} +- {}".format(np.mean(treatment_accuracy),np.std(treatment_accuracy)))
216 | print("Treatment policy :",treatment_policy)
217 | print("Treatment prediction:",treatment_prediction)
218 | print("Treatment label :",treatment_label)
219 | print("Factual accuracy of t0: {} +- {}".format(np.mean(test_factual_accuracy_of_t0),np.std(test_factual_accuracy_of_t0)))
220 | print("Factual accuracy of t1: {} +- {}".format(np.mean(test_factual_accuracy_of_t1),np.std(test_factual_accuracy_of_t1)))
221 | print("Factual auc of t0: {} +- {}".format(np.mean(test_factual_auc_of_t0),np.std(test_factual_auc_of_t0)))
222 | print("Factual auc of t1: {} +- {}".format(np.mean(test_factual_auc_of_t1),np.std(test_factual_auc_of_t1)))
223 |
224 | # print("Factual auc of t0: {} +- {}".format(np.mean(test_factual_auc_of_t0),np.std(test_factual_auc_of_t0)))
225 | # print("Factual auc of t1: {} +- {}".format(np.mean(test_factual_auc_of_t1),np.std(test_factual_auc_of_t1)))
226 | # print("Factual auc of t2: {} +- {}".format(np.mean(test_factual_auc_of_t2),np.std(test_factual_auc_of_t2)))
227 | print("==========Average best train results=========")
228 | print("average propensity for t: {}".format(np.mean(train_average_propensity_for_t0)))
229 | print("Policy Risk: {} +- {}".format(np.mean(train_policy_risk),np.std(train_policy_risk)))
230 | print("Ate_Error_0_1: {} +- {}".format(np.mean(train_ate_error_0_1),np.std(train_ate_error_0_1)))
231 |
232 | print("Treatment accuracy: {} +- {}".format(np.mean(train_treatment_accuracy), np.std(train_treatment_accuracy)))
233 | print("=============================================")
234 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 | from torch.utils.data import TensorDataset, DataLoader
6 | from torch.utils.tensorboard import SummaryWriter
7 | from torchvision import datasets, transforms
8 | from torch.distributions import Normal, Independent, kl
9 |
10 | import torch.nn.init as init
11 | from torchvision import models
12 |
13 | from resnet3d import *
14 | from torch.autograd import Variable
15 | import math
16 | import numpy as np
17 | import copy
18 | import random
19 |
20 | class BasicBlock(nn.Module):
21 | expansion = 1
22 |
23 | def __init__(self, inplanes, planes, stride=1, downsample=None):
24 | super(BasicBlock, self).__init__()
25 | self.conv1 = conv3x3(inplanes, planes, stride)
26 | self.bn1 = nn.BatchNorm3d(planes)
27 | self.relu = nn.ReLU(inplace=True)
28 | self.conv2 = conv3x3(planes, planes)
29 | self.bn2 = nn.BatchNorm3d(planes)
30 | self.downsample = downsample
31 | self.stride = stride
32 |
33 | def forward(self, x):
34 | residual = x
35 |
36 | out = self.conv1(x)
37 | out = self.bn1(out)
38 | out = self.relu(out)
39 |
40 | out = self.conv2(out)
41 | out = self.bn2(out)
42 |
43 | if self.downsample is not None:
44 | residual = self.downsample(x)
45 |
46 | out += residual
47 | out = self.relu(out)
48 |
49 | return out
50 |
51 | class EpsilonLayer(nn.Module):
52 | def __init__(self):
53 | super(EpsilonLayer, self).__init__()
54 |
55 | # building epsilon trainable weight
56 | self.weights = nn.Parameter(torch.Tensor(1, 1))
57 |
58 | # initializing weight parameter with RandomNormal
59 | nn.init.normal_(self.weights, mean=0, std=0.05)
60 |
61 | def forward(self, inputs):
62 | return torch.mm(torch.ones_like(inputs)[:, 0:1], self.weights.T)
63 |
64 | class VAE_MM(nn.Module):
65 | def __init__(self,input_size,hidden_size,latent_size,out_features=[200, 64, 1]):
66 | super(VAE_MM,self).__init__()
67 | dropout = False
68 | self.input_size = input_size
69 | self.hidden_size = hidden_size
70 | self.latent_size = latent_size
71 |
72 | # self.encoder = nn.Sequential(
73 | # nn.Linear(input_size, hidden_size),
74 | # nn.ReLU(),
75 | # nn.Linear(hidden_size,latent_size*4)
76 | # )
77 |
78 | self.decoder0 = nn.Sequential(
79 | #nn.Linear(latent_size+1,4),
80 | #nn.ReLU(),
81 | nn.Linear(latent_size+1,1)
82 | )
83 |
84 |
85 | self.dropout = nn.Dropout(p=0.2)
86 | self.experts = ProductOfExperts()
87 | self.epsilon = EpsilonLayer()
88 | c=[64,64,128,256,512]
89 | layers = [3, 4, 6, 3]
90 | self.inplanes = c[0]
91 | self.share = torch.nn.Sequential()
92 | self.share.add_module('conv1', nn.Conv3d(1, c[0],kernel_size=7, stride=2, padding=0, bias=False))
93 | self.share.add_module('bn1', nn.BatchNorm3d(c[0]))
94 | self.share.add_module('relu', nn.ReLU(inplace=True))
95 | self.share.add_module('maxpool',nn.MaxPool3d(kernel_size=3, stride=2, padding=1))
96 | self.share.add_module('layer1', self._make_layer(BasicBlock, c[1], layers[0]))
97 | self.share.add_module('layer2', self._make_layer(BasicBlock, c[2], layers[1], stride=2))
98 | self.share.add_module('layer3', self._make_layer(BasicBlock, c[3], layers[2], stride=2))
99 | self.share.add_module('layer4', self._make_layer(BasicBlock, c[4], layers[3], stride=2))
100 | self.share.add_module('avgpool', nn.AvgPool3d([1,7,7]))
101 |
102 |
103 | if dropout is True:
104 | self.share.add_module('dropout', nn.Dropout(0.5))
105 | self.resenet_head = nn.Sequential(nn.Linear(512, out_features[1]), nn.BatchNorm1d(out_features[1]), nn.ReLU())
106 | self.resenet_head_t = nn.Sequential(nn.Linear(out_features[1]+1, out_features[1]), nn.BatchNorm1d(out_features[1]), nn.ReLU())
107 | self.pnet_net_im_0 = nn.Sequential(nn.Linear(out_features[1], latent_size*2), nn.BatchNorm1d(latent_size*2))
108 | self.pnet_net_im_1 = nn.Sequential(nn.Linear(out_features[1], latent_size*2), nn.BatchNorm1d(latent_size*2))
109 | self.qnet_net_im_0 = nn.Sequential(nn.Linear(out_features[1]+1, latent_size*2), nn.BatchNorm1d(latent_size*2))
110 | self.qnet_net_im_1 = nn.Sequential(nn.Linear(out_features[1]+1, latent_size*2), nn.BatchNorm1d(latent_size*2))
111 |
112 | self.pnet_net_tab_0 = nn.Sequential(nn.Linear(out_features[1], latent_size*2), nn.BatchNorm1d(latent_size*2))
113 | self.pnet_net_tab_1 = nn.Sequential(nn.Linear(out_features[1], latent_size*2), nn.BatchNorm1d(latent_size*2))
114 | self.qnet_net_tab_0 = nn.Sequential(nn.Linear(out_features[1]+1, latent_size*2), nn.BatchNorm1d(latent_size*2))
115 | self.qnet_net_tab_1 = nn.Sequential(nn.Linear(out_features[1]+1, latent_size*2), nn.BatchNorm1d(latent_size*2))
116 |
117 | self.representation_block = nn.Sequential(
118 | nn.Linear(in_features=input_size, out_features=out_features[0]),
119 | nn.BatchNorm1d(out_features[0]),
120 | nn.ReLU(),
121 | nn.Linear(in_features=out_features[0], out_features=out_features[1]),
122 | nn.BatchNorm1d(out_features[1]),
123 | nn.ReLU()
124 | )
125 |
126 | self.representation_block_0 = nn.Sequential(
127 | nn.Linear(in_features=out_features[1]+1, out_features=out_features[1]),
128 | nn.BatchNorm1d(out_features[1]),
129 | nn.ReLU()
130 | )
131 |
132 | self.representation_block_1 = nn.Sequential(
133 | nn.Linear(in_features=out_features[1], out_features=out_features[1]),
134 | nn.BatchNorm1d(out_features[1]),
135 | nn.ReLU()
136 | )
137 |
138 | for m in self.modules():
139 | if isinstance(m,nn.Conv3d):
140 | nn.init.kaiming_uniform_(m.weight,mode='fan_out',nonlinearity='relu')
141 | elif isinstance(m,nn.BatchNorm1d):
142 | nn.init.constant_(m.weight,1)
143 | nn.init.constant_(m.bias,0)
144 | elif isinstance(m,nn.BatchNorm3d):
145 | nn.init.constant_(m.weight,1)
146 | nn.init.constant_(m.bias,0)
147 | if isinstance(m,nn.Linear):
148 | nn.init.normal_(m.weight, std=0.01)
149 |
150 |
151 | def _make_layer(self, block, planes, blocks, stride=1):
152 | downsample = None
153 | if stride != 1 or self.inplanes != planes * block.expansion:
154 | downsample = nn.Sequential(nn.Conv3d(self.inplanes, planes*block.expansion,kernel_size=1, stride=stride, bias=False),
155 | nn.BatchNorm3d(planes*block.expansion))
156 |
157 | layers = []
158 | layers.append(block(self.inplanes, planes, stride, downsample))
159 | self.inplanes = planes * block.expansion
160 | for i in range(1, blocks):
161 | layers.append(block(self.inplanes, planes))
162 |
163 | return nn.Sequential(*layers)
164 |
165 | def encode_tab(self,x,treatment):
166 | x = self.representation_block(x)
167 | treatment_0 = torch.zeros(treatment.shape).cuda()
168 | treatment_1 = torch.ones(treatment.shape).cuda()
169 | x_0 = torch.cat((x,treatment_0),dim=1)
170 | x_1 = torch.cat((x,treatment_1),dim=1)
171 | h_tab_0 = self.representation_block_0(x_0)
172 | h_tab_1 = self.representation_block_0(x_1)
173 | return h_tab_0, h_tab_1
174 |
175 | def encode_img(self,x,treatment):
176 | x = self.resenet_head(x)
177 | treatment_0 = torch.zeros(treatment.shape).cuda()
178 | treatment_1 = torch.ones(treatment.shape).cuda()
179 | x_0 = torch.cat((x,treatment_0),dim=1)
180 | x_1 = torch.cat((x,treatment_1),dim=1)
181 | h_im_0 = self.resenet_head_t(x_0)
182 | h_im_1 = self.resenet_head_t(x_1)
183 | return h_im_0, h_im_1
184 |
185 | def p_net_im(self,phi0_im,phi1_im):
186 | h_im0 = self.pnet_net_im_0(phi0_im)
187 | mu0_im, logvar0_im = torch.chunk(h_im0,2,dim=-1)
188 | h_im1 = self.pnet_net_im_0(phi1_im)
189 | mu1_im, logvar1_im = torch.chunk(h_im1,2,dim=-1)
190 | return mu0_im, logvar0_im, mu1_im, logvar1_im
191 |
192 | def p_net_tab(self,phi0_tab,phi1_tab):
193 | h_tab0 = self.pnet_net_tab_0(phi0_tab)
194 | mu0_tab, logvar0_tab = torch.chunk(h_tab0,2,dim=-1)
195 | h_tab1 = self.pnet_net_tab_0(phi1_tab)
196 | mu1_tab, logvar1_tab = torch.chunk(h_tab1,2,dim=-1)
197 | return mu0_tab, logvar0_tab, mu1_tab, logvar1_tab
198 |
199 | def q_net_im(self,phi0_im,phi1_im):
200 | h_im0 = self.qnet_net_im_0(phi0_im)
201 | mu0_im, logvar0_im = torch.chunk(h_im0,2,dim=-1)
202 | h_im1 = self.qnet_net_im_0(phi1_im)
203 | mu1_im, logvar1_im = torch.chunk(h_im1,2,dim=-1)
204 | return mu0_im, logvar0_im, mu1_im, logvar1_im
205 |
206 | def q_net_tab(self,phi0_tab,phi1_tab):
207 | h_tab0 = self.qnet_net_tab_0(phi0_tab)
208 | mu0_tab, logvar0_tab = torch.chunk(h_tab0,2,dim=-1)
209 | h_tab1 = self.qnet_net_tab_0(phi1_tab)
210 | mu1_tab, logvar1_tab = torch.chunk(h_tab1,2,dim=-1)
211 | return mu0_tab, logvar0_tab, mu1_tab, logvar1_tab
212 |
213 | def PoE(self, mu0_im, logvar0_im, mu1_im, logvar1_im,
214 | mu0_tab, logvar0_tab, mu1_tab, logvar1_tab):
215 | mu0, logvar0, mu1, logvar1 = prior_expert((1,mu0_im.shape[0],self.latent_size))
216 | mu0 = torch.cat((mu0,mu0_tab.unsqueeze(0)),dim=0)
217 | logvar0 = torch.cat((logvar0,logvar0_tab.unsqueeze(0)),dim=0)
218 | mu1 = torch.cat((mu1,mu1_tab.unsqueeze(0)),dim=0)
219 | logvar1 = torch.cat((logvar1,logvar1_tab.unsqueeze(0)),dim=0)
220 |
221 | mu0 = torch.cat((mu0,mu0_im.unsqueeze(0)),dim=0)
222 | logvar0 = torch.cat((logvar0,logvar0_im.unsqueeze(0)),dim=0)
223 | mu1 = torch.cat((mu1,mu1_im.unsqueeze(0)),dim=0)
224 | logvar1 = torch.cat((logvar1,logvar1_im.unsqueeze(0)),dim=0)
225 | mu0, logvar0 = self.experts(mu0, logvar0)
226 | mu1, logvar1 = self.experts(mu1, logvar1)
227 |
228 | return mu0, logvar0, mu1, logvar1
229 |
230 | def VDC_p(self,phi0_im, phi1_im, phi0_tab, phi1_tab):
231 | mu0_im_p, logvar0_im_p, mu1_im_p, logvar1_im_p = self.p_net_im(phi0_im,phi1_im)
232 | mu0_tab_p, logvar0_tab_p, mu1_tab_p, logvar1_tab_p = self.p_net_tab(phi0_tab, phi1_tab)
233 |
234 | return self.PoE(mu0_im_p, logvar0_im_p, mu1_im_p, logvar1_im_p,
235 | mu0_tab_p, logvar0_tab_p, mu1_tab_p, logvar1_tab_p)
236 |
237 | def VDC_q(self, phi0_im, phi1_im, phi0_tab, phi1_tab,labels):
238 | phi0_im = torch.cat((phi0_im,labels[:,0].unsqueeze(1)),dim=1)
239 | phi1_im = torch.cat((phi1_im,labels[:,0].unsqueeze(1)),dim=1)
240 | phi0_tab = torch.cat((phi0_tab,labels[:,0].unsqueeze(1)),dim=1)
241 | phi1_tab = torch.cat((phi1_tab,labels[:,0].unsqueeze(1)),dim=1)
242 | mu0_im_p, logvar0_im_p, mu1_im_p, logvar1_im_p = self.q_net_im(phi0_im,phi1_im)
243 | mu0_tab_p, logvar0_tab_p, mu1_tab_p, logvar1_tab_p = self.q_net_tab(phi0_tab, phi1_tab)
244 |
245 | return self.PoE(mu0_im_p, logvar0_im_p, mu1_im_p, logvar1_im_p,
246 | mu0_tab_p, logvar0_tab_p, mu1_tab_p, logvar1_tab_p)
247 |
248 | def reparameterize(self,mu,logvar):
249 | std = torch.exp(0.5*logvar)
250 | eps = torch.randn_like(std)
251 | z = mu + eps*std
252 | return z
253 |
254 |
255 | def decode(self,z0,z1, treatment):
256 | treatment_0 = torch.zeros(treatment.shape).cuda()
257 | treatment_1 = torch.ones(treatment.shape).cuda()
258 | z0 = torch.cat((z0,treatment_0),dim=1)
259 | z1 = torch.cat((z1,treatment_1),dim=1)
260 | y0 = self.decoder0(z0)
261 | y1 = self.decoder0(z1)
262 |
263 |         return torch.sigmoid(y0)[:,0], torch.sigmoid(y1)[:,0]
264 |
265 |
266 | def forward(self,x,labels,image,is_train=False):
267 |
268 | treatment = labels[:,1]
269 | treatment = torch.unsqueeze(treatment,dim=1)
270 |
271 | image = self.share.forward(image)
272 | image = self.dropout(image)
273 | image = image[:,:,0,0,0]
274 | phi0_im, phi1_im = self.encode_img(image,treatment)
275 | phi0_tab, phi1_tab = self.encode_tab(x,treatment)
276 |
277 | mu0_p, logvar0_p, mu1_p, logvar1_p = self.VDC_p(phi0_im, phi1_im,phi0_tab, phi1_tab)
278 | if is_train:
279 | mu0_q, logvar0_q, mu1_q, logvar1_q = self.VDC_q(phi0_im, phi1_im,phi0_tab, phi1_tab,labels)
280 | z0 = self.reparameterize(mu0_q, logvar0_q)
281 | z1 = self.reparameterize(mu1_q, logvar1_q)
282 | y0,y1 = self.decode(z0,z1,treatment)
283 | dist_p_0 = Independent(Normal(loc=mu0_p,scale=torch.exp(logvar0_p)),1)
284 | dist_q_0 = Independent(Normal(loc=mu0_q,scale=torch.exp(logvar0_q)),1)
285 | dist_p_1 = Independent(Normal(loc=mu1_p,scale=torch.exp(logvar1_p)),1)
286 | dist_q_1 = Independent(Normal(loc=mu1_q,scale=torch.exp(logvar1_q)),1)
287 | return y0, y1, dist_p_0, dist_q_0, dist_p_1, dist_q_1
288 | else:
289 | z0 = self.reparameterize(mu0_p, logvar0_p)
290 | z1 = self.reparameterize(mu1_p, logvar1_p)
291 | y0,y1 = self.decode(z0,z1,treatment)
292 | return y0, y1
293 |
294 | def prior_expert(size):
295 | mu0 = Variable(torch.zeros(size))
296 | logvar0 = Variable(torch.log(torch.ones(size)))
297 | mu1 = Variable(torch.zeros(size))
298 | logvar1 = Variable(torch.log(torch.ones(size)))
299 | mu0, logvar0, mu1, logvar1 = mu0.cuda(), logvar0.cuda(), mu1.cuda(), logvar1.cuda()
300 | return mu0, logvar0, mu1, logvar1
301 |
302 | class ProductOfExperts(nn.Module):
303 | def forward(self,mu,logvar,eps=1e-8):
304 | var = torch.exp(logvar)+eps
305 | T = 1./var
306 | pd_mu = torch.sum(mu*T,dim=0) / torch.sum(T, dim=0)
307 | pd_var = 1./torch.sum(T,dim=0)
308 | pd_logvar = torch.log(pd_var)
309 |
310 | return pd_mu, pd_logvar
311 |
312 | def VAE_loss_function(y0, y1, dist_p_0, dist_q_0, dist_p_1, dist_q_1, labels, class_ratio,ratio_as_t1):
313 | treatment = labels[:,1]
314 | map_0 = torch.zeros(treatment.shape)
315 | map_0[treatment==0] = 1
316 | map_1 = torch.zeros(treatment.shape)
317 | map_1[treatment==1] = 1
318 | BCE = torch.sum(map_0.cuda()*F.binary_cross_entropy(y0, labels[:,0], reduction='none'))/(torch.sum(map_0)+1e-8) + torch.sum(map_1.cuda()*F.binary_cross_entropy(y1, labels[:,0], reduction='none'))/(torch.sum(map_1)+1e-8)
319 |
320 | KLD = torch.mean(kl.kl_divergence(dist_q_0,dist_p_0)+kl.kl_divergence(dist_q_1,dist_p_1))
321 | return BCE, 1.0*KLD
322 |
323 | def normal_loss_function(y0, y1, labels, class_ratio,ratio_as_t1):
324 | treatment = labels[:,1]
325 | map_0 = torch.zeros(treatment.shape)
326 | map_0[treatment==0] = 1
327 | map_1 = torch.zeros(treatment.shape)
328 | map_1[treatment==1] = 1
329 | BCE = torch.sum(map_0.cuda()*F.binary_cross_entropy(y0, labels[:,0], reduction='none'))/(torch.sum(map_0)+1e-8) + torch.sum(map_1.cuda()*F.binary_cross_entropy(y1, labels[:,0], reduction='none'))/(torch.sum(map_1)+1e-8)
330 |
331 | return BCE
332 |
--------------------------------------------------------------------------------
/main_VAE.py:
--------------------------------------------------------------------------------
1 | from models import *
2 | import os
3 | import glob
4 | import argparse
5 | from sklearn.preprocessing import StandardScaler
6 | from tqdm import tqdm
7 | from sklearn.model_selection import train_test_split
8 | from upload_data import *
9 | import SimpleITK as sitk
10 | import cv2
11 | import numpy as np
12 | import math
13 | from tools import *
14 |
15 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
16 | #torch.backends.cudnn.enabled = False
17 | #random.shuffle
18 |
19 | def _split_output(yt_hat, t, y, y_scaler, x, is_train=False):
20 | """
21 | Split output into dictionary for easier use in estimation
22 | Args:
23 | yt_hat: Generated prediction
24 | t: Binary treatment assignments
25 | y: Treatment outcomes
26 | y_scaler: Scaled treatment outcomes
27 | x: Covariates
28 | index: Index in data
29 |
30 | Returns:
31 | Dictionary of all needed data
32 | """
33 |
34 |
35 | yt_hat = yt_hat.detach().cpu().numpy()
36 | q_t0 = yt_hat[:, 0].reshape(-1, 1).copy()
37 | q_t1 = yt_hat[:, 1].reshape(-1, 1).copy()
38 |
39 | g = yt_hat[:, 2].copy()
40 | treatment_predicted = g.copy()
41 | treatment_predicted[treatment_predicted>=0.5] = 1
42 | treatment_predicted[treatment_predicted<0.5] = 0
43 |
44 | y = y.copy()
45 | var = "average propensity for t: {}".format(g[t.squeeze() == 1.].mean())
46 |
47 | q_cat = np.concatenate((q_t0, q_t1),1)
48 |
49 | policy = np.argmax(q_cat,1)
50 | #policy = np.zeros(q_cat.shape[0])
51 |
52 | print(var)
53 | print("Policy Risk:", policy_risk_multi(t, y, q_t0, q_t1))
54 | print("Ate_Error:", ate_error_0_1(t, y, q_t0 - q_t1))
55 |
56 | print("Treatment accuracy:", np.sum(treatment_predicted==t.squeeze())/treatment_predicted.shape[0])
57 |
58 | if not is_train:
59 | print("Treatment policy :",policy)
60 | print("Treatment prediction:",treatment_predicted)
61 | print("Treatment label :",t.squeeze().astype(int))
62 |
63 | factual_auc(t, y, q_t0, q_t1)
64 | factual_acc(t, y, q_t0, q_t1)
65 |
66 | return {'ave propensity for t': g[t.squeeze() == 1.].mean(),
67 | 'Policy Risk': policy_risk_multi(t, y, q_t0, q_t1),
68 | 'Ate_Error_0_1': ate_error_0_1(t, y, q_t0 - q_t1), 'Treatment accuracy': np.sum(treatment_predicted==t.squeeze())/treatment_predicted.shape[0],
69 | 'Treatment policy': policy, 'Treatment prediction': treatment_predicted, 'Treatment label': t.squeeze().astype(int)}
70 |
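# --- Hedged note (added annotation, not part of the original file) ---
# In _split_output the prediction tensor is laid out column-wise: column 0 is
# q_t0 (predicted outcome probability under t=0), column 1 is q_t1 (under t=1)
# and column 2 is g, the propensity score. The recommended policy simply picks
# the arm with the larger predicted outcome probability. A tiny sketch of that
# rule on made-up numbers (the function name is hypothetical):
def _example_policy_from_heads():
    q_t0 = np.array([[0.2], [0.7], [0.4]])
    q_t1 = np.array([[0.6], [0.1], [0.5]])
    q_cat = np.concatenate((q_t0, q_t1), 1)
    return np.argmax(q_cat, 1)   # -> array([1, 0, 1]): recommended treatment index per patient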
71 |
72 | def train(train_loader, net, optimizer, criterion, class_ratio,ratio_as_t1):
73 | """
74 | Trains the network for one epoch in batches and returns the per-batch-averaged BCE and KLD losses.
75 |
76 | Args:
77 | train_loader: Data loader for the training set.
78 | net: Neural network model.
79 | optimizer: Optimizer (e.g. Adam).
80 | criterion: Loss function returning (BCE, KLD); class_ratio and ratio_as_t1 are forwarded to it.
81 | """
82 |
83 | avg_loss_0 = 0
84 | avg_loss_1 = 0
85 |
86 | for i, data in enumerate(train_loader):
87 |
88 | # get the inputs; data is a list of [inputs, labels, images]
89 | inputs, labels, images = data
90 | #traumatic = inputs[:,3]
91 | # zero the parameter gradients
92 | optimizer.zero_grad()
93 |
94 | # forward + backward + optimize
95 | y0, y1, dist_p_0, dist_q_0, dist_p_1, dist_q_1 = net(inputs,labels,images,is_train=True)
96 | BCE,KLD = criterion(y0, y1, dist_p_0, dist_q_0, dist_p_1, dist_q_1, labels, class_ratio,ratio_as_t1)
97 | loss = BCE + KLD
98 | #loss = criterion(outputs, labels, traumatic, class_ratio)
99 | loss.backward()
100 | optimizer.step()
101 |
102 | # keep track of loss and accuracy
103 | avg_loss_0 += BCE
104 | avg_loss_1 += KLD
105 |
106 | return avg_loss_0 / len(train_loader), avg_loss_1 / len(train_loader), loss / len(train_loader)
107 |
108 | def test(train_loader, net, criterion, number):
109 | """
110 | Runs the network in evaluation mode over a data loader and returns the stacked (y0, y1) predictions.
111 |
112 | Args:
113 | train_loader: Data loader for the split to evaluate.
114 | net: Neural network model (set to eval mode for inference, then back to train mode).
115 | criterion: Loss function (not used here; kept for interface symmetry).
116 | number: Total number of samples, used to pre-allocate the output tensor.
117 | """
118 | net.eval()
119 | avg_loss = 0
120 |
121 | yt_hat_test = torch.from_numpy(np.zeros((number,4)))
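# Note (added annotation): only columns 0 (y0) and 1 (y1) are filled in the
# loop below; columns 2 and 3 stay at zero because the network returns just
# the two outcome heads at evaluation time.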
122 | num_ = 0
123 | with torch.no_grad():
124 | for i, data in enumerate(train_loader):
125 |
126 | # get the inputs; data is a list of [inputs, labels, images]
127 | inputs, labels, images = data
128 | traumatic = inputs[:,3]  # kept for parity with the training loop; not used below
129 |
130 |
131 | y0, y1 = net(inputs,labels,images)
132 | yt_hat_test[num_:num_+y0.shape[0],0] = y0
133 | yt_hat_test[num_:num_+y0.shape[0],1] = y1
134 | num_ += y0.shape[0]
135 | net.train()
136 | return yt_hat_test
137 |
138 | def load_image(path):
139 | get_test_X = sitk.ReadImage(path)
140 | test_X = sitk.GetArrayFromImage(get_test_X).astype(np.float32)
141 | image = np.zeros((test_X.shape[0],224,224)).astype(np.float32)
142 | for num in range(len(image)):
143 | image[num] = cv2.resize(test_X[num], (224, 224))
144 | return image
145 |
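# --- Hedged usage sketch (added annotation, not part of the original file) ---
# load_image reads a 3D volume with SimpleITK and resizes every axial slice to
# 224x224 with OpenCV. A minimal illustration using a synthetic volume written
# to a temporary file; the path, volume size and function name are made up for
# this example.
def _example_load_image(tmp_path="/tmp/example_volume.nii.gz"):
    vol = np.random.rand(16, 512, 512).astype(np.float32)   # slices x height x width
    sitk.WriteImage(sitk.GetImageFromArray(vol), tmp_path)
    resized = load_image(tmp_path)
    assert resized.shape == (16, 224, 224)
    return resized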
146 | def train_and_predict_dragons(t, y, x, img_path, targeted_regularization=True, output_dir='',
147 | knob_loss=VAE_loss_function, ratio=1., dragon='', val_split=0.2, batch_size=64, validation_index=0):
148 | """
149 | Trains dragonnet / tarnet / the proposed VAE model and predicts on a held-out test split.
150 | Returns:
151 | Best outputs on the test and train data, selected by the lowest test Policy Risk.
152 | """
153 |
154 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
155 | print(device)
156 |
157 | verbose = 0
158 |
159 | train_outputs_best = {}
160 | test_outputs_best = {}
161 | best_evaluation = 1.
162 |
163 | if dragon == 'tarnet':
164 | print('I am here making tarnet')
165 | net = TarNet(x.shape[1]).to("cuda")
166 |
167 | elif dragon == 'dragonnet':
168 | print("I am here making dragonnet")
169 | net = DragonNet(x.shape[1]).to("cuda")
170 |
171 | elif dragon == 'ours':
172 | print("I am here making ours")
173 | net = VAE_MM(x.shape[1],hidden_size=200,latent_size=10).to("cuda")
174 |
175 | # Which loss to use for training the network
176 | #net = torch.nn.DataParallel(net)
177 | if targeted_regularization:
178 | loss = make_tarreg_loss(ratio=ratio, dragonnet_loss=knob_loss)
179 | else:
180 | loss = knob_loss
181 |
182 | # loss = knob_loss
183 | # for reproducing the IHDP experiment
184 |
185 | i = 0
186 | torch.manual_seed(i)
187 | np.random.seed(i)
188 | # Get the data and optionally divide into train and test set
189 |
190 | all_index = np.arange(int(x.shape[0]))
191 |
192 | y_origin = copy.copy(y)
193 | y[y_origin>=4]=0
194 | y[y_origin<=3]=1
195 |
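# Added annotation: the block below reserves a fixed held-out test set of up to
# 50 patients with ivh==1 under treatment arm 0 and up to 50 patients with
# GCS < 9 under treatment arm 1; all remaining patients form the training set,
# and the first covariate column is then dropped from x.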
196 | ivh = x[:,1]
197 | GCS = x[:,4]
198 | ivh_index = []
199 | GCS_index = []
200 | for limited_index in all_index:
201 | if ivh[limited_index]==1 and t[limited_index]==0:
202 | ivh_index.append(limited_index)
203 | if len(ivh_index)>=50:
204 | break
205 | for limited_index in all_index:
206 | if GCS[limited_index]<9 and t[limited_index]==1:
207 | GCS_index.append(limited_index)
208 | if len(GCS_index)>=50:
209 | break
210 | test_index = np.array(ivh_index+GCS_index)
211 | x = x[:,1:]
212 |
213 | train_index = []
214 | for m in all_index:
215 | if m not in test_index:
216 | train_index.append(m)
217 |
218 | x_train, x_test = x[train_index], x[test_index]
219 | y_train, y_test = y[train_index], y[test_index]
220 | t_train, t_test = t[train_index], t[test_index]
221 | img_path_train, img_path_test = img_path[train_index], img_path[test_index]
222 |
223 | yt_train = np.concatenate([y_train, t_train], 1)
224 | yt_test = np.concatenate([y_test, t_test], 1)
225 |
226 | t0_index = np.where(t_train==0)
227 | t1_index = np.where(t_train==1)
228 |
229 | ratio_t0 = np.sum(y_train[t0_index])/len(y_train[t0_index])
230 | ratio_t1 = np.sum(y_train[t1_index])/len(y_train[t1_index])
231 |
232 | class_ratio = [ratio_t0, ratio_t1]
233 |
234 | ratio_as_t1 = np.sum(t_train)/len(t_train)
235 |
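# Added annotation: class_ratio and ratio_as_t1 are computed here and forwarded
# to the criterion; the VAE_loss_function defined earlier accepts them but does
# not currently use them.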
236 |
237 | train_data = trainerData3d_preload(img_path_train, x_train, y_train, t_train, is_train = True)
238 | test_data = trainerData3d_preload(img_path_test, x_test, y_test, t_test, is_train = False)
239 | train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last = True)
240 | train_loader_test = DataLoader(train_data, batch_size=batch_size, shuffle=False)
241 | test_loader = DataLoader(test_data, batch_size=256, shuffle=False)
242 |
243 | import time
244 | start_time = time.time()
245 |
246 | # Configuring optimizers
247 | # The network is trained for `epochs1` epochs with the Adam optimizer;
248 | # `epochs2` is defined below but the SGD stage is currently disabled.
249 | epochs1 = 1500
250 | epochs2 = 500
251 |
252 | # Add L2 regularization to t0 and t1 heads of the network
253 |
254 | optimizer_Adam = optim.Adam(net.parameters(), lr=5e-3)
255 | scheduler_Adam = optim.lr_scheduler.StepLR(optimizer=optimizer_Adam, step_size = 300, gamma=0.5)
256 | #scheduler_SGD = optim.lr_scheduler.StepLR(optimizer=optimizer_SGD, step_size = 200, gamma=0.5)
257 |
258 | train_loss = 0
259 | epochs0 = 0
260 |
261 | if epochs0 != 0:
262 | load_model_path = '../models_save/IPH_limited_ours2/'+str(epochs0)+'.pth'
263 | net.load_state_dict(torch.load(load_model_path))
264 |
265 | # Adam training run
266 | for epoch in range(epochs0, epochs1):
267 | # Train on data
268 | train_loss0,train_loss1, train_sum = train(train_loader, net, optimizer_Adam, loss, class_ratio,ratio_as_t1)
269 | scheduler_Adam.step()  # StepLR takes no metric argument (that pattern belongs to ReduceLROnPlateau)
270 |
271 | #train_loss = train(train_loader, net, optimizer_SGD, loss, class_ratio)
272 | #scheduler_SGD.step(train_loss)
273 |
274 |
275 |
276 | if epoch % 10 ==0:
277 | print("BCE:"+str(epoch)+"/"+str(epochs1)+" "+f"Adam loss: {train_loss0}")
278 | print("KLD:"+str(epoch)+"/"+str(epochs1)+" "+f"Adam loss: {train_loss1}")
279 | yt_hat_test = test(test_loader, net, loss, len(test_index))
280 | yt_hat_train = test(train_loader_test, net, loss, len(train_index))
281 | np.savez_compressed("../results_save/IPH_limited_ours2/{}_fold_{}_epoch_test.npz".format(validation_index, epoch),yt_hat_test=yt_hat_test,t_test=t_test,y_test=y_test,
282 | y=y,x_test=x_test)
283 | np.savez_compressed("../results_save/IPH_limited_ours2/{}_fold_{}_epoch_train.npz".format(validation_index, epoch),yt_hat_train=yt_hat_train,t_train=t_train,y_train=y_train,
284 | y=y,x_train=x_train)
285 | test_outputs = _split_output(yt_hat_test, t_test, y_test, y, x_test, is_train=False)
286 | train_outputs = _split_output(yt_hat_train, t_train, y_train, y, x_train, is_train=True)
287 | if test_outputs['Policy Risk'] <= best_evaluation:
288 | train_outputs_best = train_outputs
289 | test_outputs_best = test_outputs
290 | best_evaluation = test_outputs['Policy Risk']
291 | print("==================the {} fold====================".format(validation_index))
292 |
293 | if epoch % 100 ==0:
294 | save_model_path = '../models_save/IPH_limited_ours/'+str(epoch)+'.pth'
295 | torch.save(net.state_dict(),save_model_path)
296 | save_model_path = '../models_save/IPH_limited_ours/'+str(epoch)+ '_' + str(validation_index) + '_fold.pth'
297 | torch.save(net.state_dict(),save_model_path)
298 | return test_outputs_best, train_outputs_best
299 |
300 |
301 | def run_ihdp(data_base_dir, output_dir='~/result/IPH/',
302 | knob_loss=VAE_loss_function,
303 | ratio=1., dragon=''):
304 |
305 | print("the dragon is {}".format(dragon))
306 |
307 | simulation_files = sorted(glob.glob("{}/*.xls*".format(data_base_dir)))  # match .xls and .xlsx (the bundled example data is .xlsx)
308 |
309 | for idx, simulation_file in enumerate(simulation_files):
310 |
311 | simulation_output_dir = os.path.join(output_dir, str(idx))
312 |
313 | os.makedirs(simulation_output_dir, exist_ok=True)
314 |
315 | x, img_path = load_and_format_covariates(simulation_file)
316 | t, y, y_cf, mu_0, mu_1 = load_all_other_crap(simulation_file)
317 | np.savez_compressed(os.path.join(simulation_output_dir, "simulation_outputs.npz"),
318 | t=t, y=y, y_cf=y_cf, mu_0=mu_0, mu_1=mu_1)
319 |
320 | average_propensity_for_t0 = []
321 | average_propensity_for_t1 = []
322 | average_propensity_for_t2 = []
323 | policy_risk = []
324 | ate_error_0_1 = []
325 | ate_error_0_2 = []
326 | ate_error_1_2 = []
327 | treatment_accuracy = []
328 | treatment_policy=np.array([])
329 | treatment_prediction=np.array([])
330 | treatment_label=np.array([])
331 |
332 |
333 | train_average_propensity_for_t0 = []
334 | train_average_propensity_for_t1 = []
335 | train_average_propensity_for_t2 = []
336 | train_policy_risk = []
337 | train_ate_error_0_1 = []
338 | train_ate_error_0_2 = []
339 | train_ate_error_1_2 = []
340 | train_treatment_accuracy = []
341 |
342 |
343 | for validation_index in range(0,1):
344 | # print("Is targeted regularization: {}".format(is_targeted_regularization))
345 | test_outputs_best, train_outputs_best = train_and_predict_dragons(t, y, x, img_path,
346 | targeted_regularization=False,
347 | output_dir=simulation_output_dir,
348 | knob_loss=knob_loss, ratio=ratio, dragon=dragon,
349 | val_split=0.2, batch_size=128, validation_index=validation_index)
350 |
351 | #np.savez_compressed("../results_save/cli_img/{}_fold_test.npz".format(validation_index),test_outputs_best)
352 | #np.savez_compressed("../results_save/cli_img/{}_fold_train.npz".format(validation_index),train_outputs_best)
353 | print("==========Best test results for the {} fold==========".format(validation_index))
354 | print("average propensity for t: {}".format(test_outputs_best['ave propensity for t']))
355 | print("Policy Risk:", test_outputs_best['Policy Risk'])
356 | print("Ate_Error_0_1:", test_outputs_best['Ate_Error_0_1'])
357 |
358 |
359 | print("Treatment accuracy:", test_outputs_best['Treatment accuracy'])
360 | print("Treatment policy :",test_outputs_best['Treatment policy'])
361 | print("Treatment prediction:",test_outputs_best['Treatment prediction'])
362 | print("Treatment label :",test_outputs_best['Treatment label'])
363 | print("==========Best train results for the {} fold==========".format(validation_index))
364 | print("average propensity for t: {}".format(train_outputs_best['ave propensity for t']))
365 | print("Policy Risk:", train_outputs_best['Policy Risk'])
366 | print("Ate_Error_0_1:", train_outputs_best['Ate_Error_0_1'])
367 |
368 | print("Treatment accuracy:", train_outputs_best['Treatment accuracy'])
369 | print("====================================================")
370 | average_propensity_for_t0.append(test_outputs_best['ave propensity for t'])
371 |
372 | policy_risk.append(test_outputs_best['Policy Risk'])
373 | ate_error_0_1.append(test_outputs_best['Ate_Error_0_1'])
374 |
375 | treatment_accuracy.append(test_outputs_best['Treatment accuracy'])
376 | treatment_policy=np.concatenate((treatment_policy,test_outputs_best['Treatment policy']),0)
377 | treatment_prediction=np.concatenate((treatment_prediction,test_outputs_best['Treatment prediction']),0)
378 | treatment_label=np.concatenate((treatment_label,test_outputs_best['Treatment label']),0)
379 |
380 | train_average_propensity_for_t0.append(train_outputs_best['ave propensity for t'])
381 |
382 | train_policy_risk.append(train_outputs_best['Policy Risk'])
383 | train_ate_error_0_1.append(train_outputs_best['Ate_Error_0_1'])
384 |
385 | train_treatment_accuracy.append(train_outputs_best['Treatment accuracy'])
386 |
387 | print("==========Average best test results==========")
388 | print("average propensity for t: {}".format(np.mean(average_propensity_for_t0)))
389 | print("Policy Risk:", np.mean(policy_risk))
390 | print("Ate_Error_0_1:", np.mean(ate_error_0_1))
391 |
392 | print("Treatment accuracy:", np.mean(treatment_accuracy))
393 | print("Treatment policy :",treatment_policy)
394 | print("Treatment prediction:",treatment_prediction)
395 | print("Treatment label :",treatment_label)
396 | print("==========Average best train results=========")
397 | print("average propensity for t: {}".format(np.mean(train_average_propensity_for_t0)))
398 | print("Policy Risk:", np.mean(train_policy_risk))
399 | print("Ate_Error_0_1:", np.mean(train_ate_error_0_1))
400 |
401 | print("Treatment accuracy:", np.mean(train_treatment_accuracy))
402 | print("=============================================")
403 |
404 |
405 | def turn_knob(data_base_dir, knob='dragonnet',
406 | output_base_dir=''):
407 | output_dir = os.path.join(output_base_dir, knob)
408 |
409 | if knob == 'dragonnet':
410 | run_ihdp(data_base_dir=data_base_dir, output_dir=output_dir, dragon='dragonnet')
411 |
412 | if knob == 'tarnet':
413 | run_ihdp(data_base_dir=data_base_dir, output_dir=output_dir, dragon='tarnet')
414 |
415 | if knob == 'ours':
416 | run_ihdp(data_base_dir=data_base_dir, output_dir=output_dir, dragon='ours')
417 |
418 | def main():
419 | parser = argparse.ArgumentParser()
420 | parser.add_argument('--data_base_dir', type=str, help="path to directory",default='../data/IPH')
421 | parser.add_argument('--knob', type=str, default='ours',
422 | help="dragonnet or tarnet or ours")
423 |
424 | parser.add_argument('--output_base_dir', type=str, help="directory to save the output",default='../result/ours')
425 |
426 | args = parser.parse_args()
427 | turn_knob(args.data_base_dir, args.knob, args.output_base_dir)
428 |
429 |
430 | if __name__ == '__main__':
431 | main()
432 |
--------------------------------------------------------------------------------
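Hedged usage note (added, not part of the repository files): the entry point is
main_VAE.py, whose relative paths (../data/IPH, ../result/ours, ../models_save/...,
../results_save/...) all resolve against the parent of the working directory, so
those folders are assumed to exist before training (the output directory itself
is created by run_ihdp). Under those assumptions, a run of the proposed model
might look like:

    mkdir -p ../models_save/IPH_limited_ours ../results_save/IPH_limited_ours2
    python main_VAE.py --data_base_dir ../data/IPH --knob ours --output_base_dir ../result/ours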