├── dragonnet
│   ├── __init__.py
│   ├── model.py
│   └── dragonnet.py
├── requirements.txt
├── .gitignore
├── setup.py
├── LICENSE
└── README.md

/dragonnet/__init__.py:
--------------------------------------------------------------------------------
from dragonnet.dragonnet import DragonNet
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch~=1.12
pandas~=1.4
numpy~=1.22
scikit-learn~=1.1.1
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea/
data/
__pycache__/
dist
build
dragonnet.egg-info
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import find_packages, setup

setup(
    name="dragonnet",
    packages=find_packages(
        include=["dragonnet"]
    ),
    version="0.1",
    description="PyTorch implementation of DragonNet",
    author="Faraz",
    classifiers=[
        "Programming Language :: Python :: 3",
        "Operating System :: OS Independent",
    ],
    python_requires=">=3.6",
    install_requires=[
        "pandas>=1.4",
        "numpy>=1.22",
        "torch>=1.12",
        "scikit-learn>=1.1",  # dragonnet.dragonnet imports sklearn.model_selection
    ],
)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Faraz Mahmoodian

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
PyTorch implementation of DragonNet from the paper:

Shi, C., Blei, D. and Veitch, V., 2019. Adapting neural networks for the estimation of treatment effects. Advances in Neural Information Processing Systems, 32.
[arXiv link](https://arxiv.org/abs/1906.02120)

Author's original TensorFlow [implementation](https://github.com/claudiashi57/dragonnet)

### Installation

```shell
python setup.py bdist_wheel
pip install dist/dragonnet-0.1-py3-none-any.whl
```

### Usage

```python
# import the model class
from dragonnet.dragonnet import DragonNet

# initialize the model and train it
model = DragonNet(X_train.shape[1])
model.fit(X_train, y_train, t_train)

# predict both potential outcomes and the propensity score
y0_pred, y1_pred, t_pred, _ = model.predict(X_test)
```
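Since `predict` returns both potential outcomes for every row, a plug-in estimate of the average treatment effect (ATE) is simply the mean of their difference. A minimal sketch continuing the example above:

```python
# ATE estimate: average gap between predicted treated and control outcomes
y0_pred, y1_pred, t_pred, _ = model.predict(X_test)
ate_hat = (y1_pred - y0_pred).mean().item()
print(f"estimated ATE: {ate_hat:.3f}")
```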
### Parameters
```text
class dragonnet.DragonNet(input_dim, shared_hidden=200, outcome_hidden=100, alpha=1.0, beta=1.0, epochs=200, batch_size=64, learning_rate=1e-5, data_loader_num_workers=4, loss_type='tarreg')

input_dim: int
    input dimension for covariates
shared_hidden: int, default=200
    layer size for hidden shared representation layers
outcome_hidden: int, default=100
    layer size for conditional outcome layers
alpha: float, default=1.0
    loss component weighting hyperparameter between 0 and 1
beta: float, default=1.0
    targeted regularization hyperparameter between 0 and 1
epochs: int, default=200
    Number of training epochs
batch_size: int, default=64
    Training batch size
learning_rate: float, default=1e-5
    Learning rate
data_loader_num_workers: int, default=4
    Number of workers for data loader
loss_type: str, {'tarreg', 'default'}, default='tarreg'
    Loss function to use
```
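For example, a configuration with explicit loss weights and a held-out validation split (a sketch; `X`, `y`, `t` stand in for your own NumPy arrays):

```python
from dragonnet.dragonnet import DragonNet

model = DragonNet(
    X.shape[1],
    alpha=1.0,          # weight on the treatment (propensity) loss term
    beta=1.0,           # weight on the targeted regularization term
    loss_type="tarreg",
)
# with valid_perc set, fit() holds out 20% of the data, reports the
# validation loss each epoch, and stops early once it stops improving
model.fit(X, y, t, valid_perc=0.2)
```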
### To do:
1) Replicate experiments on IHDP and ACIC data
--------------------------------------------------------------------------------
/dragonnet/model.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class DragonNetBase(nn.Module):
    """
    Base Dragonnet model.

    Parameters
    ----------
    input_dim: int
        input dimension for covariates
    shared_hidden: int
        layer size for hidden shared representation layers
    outcome_hidden: int
        layer size for conditional outcome layers
    """
    def __init__(self, input_dim, shared_hidden=200, outcome_hidden=100):
        super(DragonNetBase, self).__init__()
        # shared representation layers
        self.fc1 = nn.Linear(in_features=input_dim, out_features=shared_hidden)
        self.fc2 = nn.Linear(in_features=shared_hidden, out_features=shared_hidden)
        self.fcz = nn.Linear(in_features=shared_hidden, out_features=shared_hidden)

        # treatment (propensity) head
        self.treat_out = nn.Linear(in_features=shared_hidden, out_features=1)

        # outcome head under control (t = 0)
        self.y0_fc1 = nn.Linear(in_features=shared_hidden, out_features=outcome_hidden)
        self.y0_fc2 = nn.Linear(in_features=outcome_hidden, out_features=outcome_hidden)
        self.y0_out = nn.Linear(in_features=outcome_hidden, out_features=1)

        # outcome head under treatment (t = 1)
        self.y1_fc1 = nn.Linear(in_features=shared_hidden, out_features=outcome_hidden)
        self.y1_fc2 = nn.Linear(in_features=outcome_hidden, out_features=outcome_hidden)
        self.y1_out = nn.Linear(in_features=outcome_hidden, out_features=1)

        # single trainable epsilon used by the targeted regularization
        self.epsilon = nn.Linear(in_features=1, out_features=1)
        torch.nn.init.xavier_normal_(self.epsilon.weight)

    def forward(self, inputs):
        """
        Forward pass through the model.

        Parameters
        ----------
        inputs: torch.Tensor
            covariates

        Returns
        -------
        y0: torch.Tensor
            outcome under control
        y1: torch.Tensor
            outcome under treatment
        t_pred: torch.Tensor
            predicted treatment probability
        eps: torch.Tensor
            trainable epsilon parameter
        """
        x = F.relu(self.fc1(inputs))
        x = F.relu(self.fc2(x))
        z = F.relu(self.fcz(x))

        t_pred = torch.sigmoid(self.treat_out(z))

        y0 = F.relu(self.y0_fc1(z))
        y0 = F.relu(self.y0_fc2(y0))
        y0 = self.y0_out(y0)

        y1 = F.relu(self.y1_fc1(z))
        y1 = F.relu(self.y1_fc2(y1))
        y1 = self.y1_out(y1)

        # feed a constant 1 through the 1x1 linear layer so eps is a single trainable scalar
        eps = self.epsilon(torch.ones_like(t_pred)[:, 0:1])

        return y0, y1, t_pred, eps


def dragonnet_loss(y_true, t_true, t_pred, y0_pred, y1_pred, eps, alpha=1.0):
    """
    Generic loss function for dragonnet

    Parameters
    ----------
    y_true: torch.Tensor
        Actual target variable
    t_true: torch.Tensor
        Actual treatment variable
    t_pred: torch.Tensor
        Predicted treatment
    y0_pred: torch.Tensor
        Predicted target variable under control
    y1_pred: torch.Tensor
        Predicted target variable under treatment
    eps: torch.Tensor
        Trainable epsilon parameter (unused here; kept so the signature matches tarreg_loss)
    alpha: float
        loss component weighting hyperparameter between 0 and 1

    Returns
    -------
    loss: torch.Tensor
    """
    # squeeze propensity predictions away from 0 and 1 for numerical stability
    t_pred = (t_pred + 0.01) / 1.02
    # sum reduction keeps the treatment loss on the same scale as the summed outcome losses
    loss_t = F.binary_cross_entropy(t_pred, t_true, reduction="sum")

    # each unit contributes only through the outcome head matching its observed treatment
    loss0 = torch.sum((1. - t_true) * torch.square(y_true - y0_pred))
    loss1 = torch.sum(t_true * torch.square(y_true - y1_pred))
    loss_y = loss0 + loss1

    loss = loss_y + alpha * loss_t

    return loss


def tarreg_loss(y_true, t_true, t_pred, y0_pred, y1_pred, eps, alpha=1.0, beta=1.0):
    """
    Targeted regularization loss function for dragonnet

    Parameters
    ----------
    y_true: torch.Tensor
        Actual target variable
    t_true: torch.Tensor
        Actual treatment variable
    t_pred: torch.Tensor
        Predicted treatment
    y0_pred: torch.Tensor
        Predicted target variable under control
    y1_pred: torch.Tensor
        Predicted target variable under treatment
    eps: torch.Tensor
        Trainable epsilon parameter
    alpha: float
        loss component weighting hyperparameter between 0 and 1
    beta: float
        targeted regularization hyperparameter between 0 and 1

    Returns
    -------
    loss: torch.Tensor
    """
    # standard dragonnet loss, plus the targeted regularization term below
    vanilla_loss = dragonnet_loss(y_true, t_true, t_pred, y0_pred, y1_pred, eps, alpha=alpha)
    t_pred = (t_pred + 0.01) / 1.02

    # predicted outcome under the observed treatment
    y_pred = t_true * y1_pred + (1 - t_true) * y0_pred

    # inverse-propensity "clever covariate", as in targeted maximum likelihood estimation
    h = (t_true / t_pred) - ((1 - t_true) / (1 - t_pred))

    # perturb the outcome prediction by the trainable epsilon and penalize the residual
    y_pert = y_pred + eps * h
    targeted_regularization = torch.sum((y_true - y_pert) ** 2)

    # final
    loss = vanilla_loss + beta * targeted_regularization
    return loss


class EarlyStopper:
    """Signals a stop once validation loss has not improved for `patience` checks."""

    def __init__(self, patience=15, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False
--------------------------------------------------------------------------------
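A quick shape-and-loss sanity check for the base network on random data (an illustrative sketch, not part of the package):

```python
import torch
from dragonnet.model import DragonNetBase, tarreg_loss

# a random batch: 32 samples, 10 covariates, binary treatment, continuous outcome
x = torch.randn(32, 10)
t = torch.randint(0, 2, (32, 1)).float()
y = torch.randn(32, 1)

model = DragonNetBase(input_dim=10)
y0_pred, y1_pred, t_pred, eps = model(x)
assert y0_pred.shape == y1_pred.shape == t_pred.shape == (32, 1)

# the targeted-regularization loss reduces everything to a scalar tensor
loss = tarreg_loss(y, t, t_pred, y0_pred, y1_pred, eps)
print(loss.item())
```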
/dragonnet/dragonnet.py:
--------------------------------------------------------------------------------
from functools import partial

import torch
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader

from dragonnet.model import DragonNetBase, dragonnet_loss, tarreg_loss, EarlyStopper


class DragonNet:
    """
    Main class for the Dragonnet model

    Parameters
    ----------
    input_dim: int
        input dimension for covariates
    shared_hidden: int, default=200
        layer size for hidden shared representation layers
    outcome_hidden: int, default=100
        layer size for conditional outcome layers
    alpha: float, default=1.0
        loss component weighting hyperparameter between 0 and 1
    beta: float, default=1.0
        targeted regularization hyperparameter between 0 and 1
    epochs: int, default=200
        Number of training epochs
    batch_size: int, default=64
        Training batch size
    learning_rate: float, default=1e-5
        Learning rate
    data_loader_num_workers: int, default=4
        Number of workers for data loader
    loss_type: str, {'tarreg', 'default'}, default='tarreg'
        Loss function to use
    """

    def __init__(
        self,
        input_dim,
        shared_hidden=200,
        outcome_hidden=100,
        alpha=1.0,
        beta=1.0,
        epochs=200,
        batch_size=64,
        learning_rate=1e-5,
        data_loader_num_workers=4,
        loss_type="tarreg",
    ):

        self.model = DragonNetBase(input_dim, shared_hidden, outcome_hidden)
        self.epochs = epochs
        self.batch_size = batch_size
        self.num_workers = data_loader_num_workers
        self.optim = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
        self.train_dataloader = None
        self.valid_dataloader = None
        if loss_type == "tarreg":
            self.loss_f = partial(tarreg_loss, alpha=alpha, beta=beta)
        elif loss_type == "default":
            self.loss_f = partial(dragonnet_loss, alpha=alpha)
        else:
            raise ValueError(f"loss_type must be 'tarreg' or 'default', got {loss_type!r}")

    def create_dataloaders(self, x, y, t, valid_perc=None):
        """
        Utility function to create train and validation data loaders.

        Parameters
        ----------
        x: np.array
            covariates
        y: np.array
            target variable
        t: np.array
            treatment
        valid_perc: float, optional
            percentage of data to allocate to the validation set
        """
        if valid_perc:
            x_train, x_test, y_train, y_test, t_train, t_test = train_test_split(
                x, y, t, test_size=valid_perc, random_state=42
            )
            x_train = torch.Tensor(x_train)
            x_test = torch.Tensor(x_test)
            y_train = torch.Tensor(y_train).reshape(-1, 1)
            y_test = torch.Tensor(y_test).reshape(-1, 1)
            t_train = torch.Tensor(t_train).reshape(-1, 1)
            t_test = torch.Tensor(t_test).reshape(-1, 1)
            train_dataset = TensorDataset(x_train, t_train, y_train)
            valid_dataset = TensorDataset(x_test, t_test, y_test)
            # shuffle the training data each epoch; keep validation order fixed
            self.train_dataloader = DataLoader(
                train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers
            )
            self.valid_dataloader = DataLoader(
                valid_dataset, batch_size=self.batch_size, num_workers=self.num_workers
            )
        else:
            x = torch.Tensor(x)
            t = torch.Tensor(t).reshape(-1, 1)
            y = torch.Tensor(y).reshape(-1, 1)
            train_dataset = TensorDataset(x, t, y)
            self.train_dataloader = DataLoader(
                train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers
            )

    def fit(self, x, y, t, valid_perc=None):
        """
        Function used to train the dragonnet model

        Parameters
        ----------
        x: np.array
            covariates
        y: np.array
            target variable
        t: np.array
            treatment
        valid_perc: float, optional
            Percentage of data to allocate to validation set
        """
        self.create_dataloaders(x, y, t, valid_perc)
        early_stopper = EarlyStopper(patience=10, min_delta=0)
        for epoch in range(self.epochs):
            for batch, (X, tr, y_true) in enumerate(self.train_dataloader):
                y0_pred, y1_pred, t_pred, eps = self.model(X)
                loss = self.loss_f(y_true, tr, t_pred, y0_pred, y1_pred, eps)
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
            if self.valid_dataloader:
                self.model.eval()
                valid_loss = self.validate_step()
                print(
                    f"epoch: {epoch} --- train_loss: {loss.item():.4f} --- valid_loss: {valid_loss.item():.4f}"
                )
                self.model.train()
                if early_stopper.early_stop(valid_loss):
                    break
            else:
                print(f"epoch: {epoch} --- train_loss: {loss.item():.4f}")

    def validate_step(self):
        """
        Calculates validation loss

        Returns
        -------
        valid_loss: torch.Tensor
            validation loss
        """
        valid_loss = []
        with torch.no_grad():
            for batch, (X, tr, y_true) in enumerate(self.valid_dataloader):
                y0_pred, y1_pred, t_pred, eps = self.predict(X)
                loss = self.loss_f(y_true, tr, t_pred, y0_pred, y1_pred, eps)
                valid_loss.append(loss)
        # average the per-batch validation losses into a single scalar
        return torch.stack(valid_loss).mean()

    def predict(self, x):
        """
        Function used to predict on covariates.

        Parameters
        ----------
        x: torch.Tensor or numpy.array
            covariates

        Returns
        -------
        y0_pred: torch.Tensor
            outcome under control
        y1_pred: torch.Tensor
            outcome under treatment
        t_pred: torch.Tensor
            predicted treatment probability
        eps: torch.Tensor
            trainable epsilon parameter
        """
        x = torch.Tensor(x)
        with torch.no_grad():
            y0_pred, y1_pred, t_pred, eps = self.model(x)
        return y0_pred, y1_pred, t_pred, eps
--------------------------------------------------------------------------------
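Putting it together, a minimal end-to-end run on synthetic data (the data-generating process below is illustrative; the true ATE is 2.0 by construction):

```python
import numpy as np
from dragonnet.dragonnet import DragonNet

rng = np.random.default_rng(0)
n, d = 1000, 10

# confounded synthetic data: X[:, 0] drives both treatment assignment and outcome
X = rng.normal(size=(n, d)).astype(np.float32)
propensity = 1.0 / (1.0 + np.exp(-X[:, 0]))
t = rng.binomial(1, propensity).astype(np.float32)
y = (X[:, 0] + 2.0 * t + rng.normal(scale=0.5, size=n)).astype(np.float32)

model = DragonNet(d, epochs=50)
model.fit(X, y, t, valid_perc=0.2)

y0_pred, y1_pred, _, _ = model.predict(X)
print("estimated ATE:", (y1_pred - y0_pred).mean().item())
```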