├── source ├── __init__.py ├── descn │ ├── __init__.py │ ├── util.py │ ├── models.py │ └── descn.py ├── dragonnet │ ├── __init__.py │ ├── model.py │ └── dragonnet.py ├── factory.py ├── objective.py ├── runner.py └── pb_utils.py ├── requirements.txt ├── setup_env.sh ├── config.yaml ├── README.md ├── .gitignore ├── RunCPUTasks.ipynb ├── run_experiment.py ├── RunGPUTasks.ipynb ├── visualization.py ├── datasets └── get_data.py └── LICENSE /source/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /source/descn/__init__.py: -------------------------------------------------------------------------------- 1 | from .descn import DESCNNet 2 | -------------------------------------------------------------------------------- /source/dragonnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .dragonnet import DragonNet 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | py-boost==0.5.1 2 | scikit-uplift==0.5.1 3 | openml==0.14.2 4 | causalml==0.15.1 5 | econml==0.15.0 6 | optuna==3.6.1 7 | geomloss=0.2.6 -------------------------------------------------------------------------------- /setup_env.sh: -------------------------------------------------------------------------------- 1 | conda create -p ./rapids-24.04 -c rapidsai -c conda-forge -c nvidia \ 2 | rapids=24.04 python=3.10 cuda-version=11.8 pytorch -y 3 | conda activate -p ./rapids-24.04 4 | pip install -r requirements.txt 5 | -------------------------------------------------------------------------------- /source/descn/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from geomloss import SamplesLoss 3 | 4 | 5 | def wasserstein_torch(X, t): 6 | """ Returns the Wasserstein distance between treatment groups """ 7 | 8 | it = torch.where(t == 1)[0] 9 | ic = torch.where(t == 0)[0] 10 | Xc = X[ic] 11 | Xt = X[it] 12 | samples_loss = SamplesLoss(loss="sinkhorn", p=2, blur=0.05, backend="tensorized") 13 | imbalance_loss = samples_loss(Xt, Xc) 14 | 15 | return imbalance_loss 16 | 17 | 18 | def mmd2_torch(X, t): 19 | it = torch.where(t == 1)[0] 20 | ic = torch.where(t == 0)[0] 21 | Xc = X[ic] 22 | Xt = X[it] 23 | 24 | samples_loss = SamplesLoss(loss="energy", p=2, blur=0.05, backend="tensorized") 25 | imbalance_loss = samples_loss(Xt, Xc) 26 | 27 | return imbalance_loss 28 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | # general params 2 | experiment: 'final' 3 | n_runs: 5 4 | 5 | # hyperparams search params 6 | optuna: 7 | n_startup_trials: 20 8 | multivariate: True 9 | n_trials: 50 10 | timeout: 10000 11 | 12 | 13 | # default params for all the frameworks 14 | xgb: 15 | n_estimators: 500 16 | learning_rate: 0.01 17 | max_depth: 6 18 | tree_method: 'hist' 19 | min_child_weight: 0 20 | lambda: 1 21 | max_bin: 256 22 | gamma: 0 23 | alpha: 0 24 | n_jobs: 12 25 | 26 | lgb: 27 | n_estimators: 500 28 | learning_rate: 0.01 29 | num_leaves: 255 30 | max_depth: 6 31 | min_child_samples: 1 32 | reg_lambda: 1 33 | max_bin: 256 34 | min_split_gain: 0 35 | reg_alpha: 0 36 | 37 | pb: 38 | ntrees: 500 39 | lr: 0.01 40 | verbose: 1000 41 | es: 0 42 | lambda_l2: 1 
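  # annotation (assumed mapping): the keys in this `pb` block are forwarded as keyword arguments to
  # py_boost.GradientBoosting by get_pb_uplift() in source/factory.py; es: 0 likely disables early
  # stopping, and ntrees/lr play the role of the n_estimators/learning_rate settings used above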
43 | gd_steps: 1 44 | subsample: 1 45 | colsample: 1 46 | min_data_in_leaf: 1 47 | use_hess: True 48 | max_bin: 256 49 | max_depth: 6 50 | 51 | crf: 52 | n_estimators: 500 53 | min_samples_split: 10 54 | criterion: 'mse' 55 | max_depth: 10 56 | min_samples_leaf: 100 57 | max_features: 0.3 58 | max_samples: 0.5 59 | min_balancedness_tol: 0.2 60 | honest: True 61 | n_jobs: 40 62 | 63 | dr: 64 | hidden_scale: 2. 65 | outcome_scale: .5 66 | alpha: 1.0 67 | beta: 1.0 68 | epochs: 30 69 | steps_per_epoch: 150 70 | learning_rate: 1e-3 71 | data_loader_num_workers: 4 72 | loss_type: 'tarreg' 73 | device: 'cuda' 74 | 75 | dcn: 76 | epochs: 10 77 | lr: 0.001 78 | steps_per_epoch: 150 79 | prpsy_w: 0.5 80 | escvr1_w: 0.5 81 | escvr0_w: 1 82 | h1_w: 0 83 | h0_w: 0 84 | mu0hat_w: 0.5 85 | mu1hat_w: 1 86 | share_scale: 2 87 | base_scale: .5 88 | data_loader_num_workers: 4 89 | device: 'cuda' 90 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Experimental code for "Uplift modeling via Gradient Boosting paper" 2 | 3 | This repository is created to reproduce the experiments of the paper. It contains the scripts to run the experiments and Jupyter notebooks to launch and analyze the results. 4 | 5 | 6 | ### Requirements 7 | 8 | The proposed method is based on the frameworks that requires Nvidia GPU. We used 24 Cores CPU, 512 GB RAM, and 2xTesla V100 to obtain the experimental results, but it is possible to execute with less hardware. 9 | 10 | To setup the environment you need to have `conda` installed. After that, please execute the following to install all the dependencies (it is assumed that you run everything from the repository root directory): 11 | 12 | ```bash 13 | bash ./setup_env.sh 14 | 15 | ``` 16 | 17 | After that `rapids-24.04` conda env will be created in the repository root dir. You need to run all the experiments under this env. You can activate it by executing `conda activate -p ./rapids-24.04` 18 | 19 | ### Data 20 | 21 | To download and preprocess all the data please execute 22 | 23 | ```bash 24 | python datasets/get_data.py 25 | 26 | ``` 27 | 28 | ### Synthetic experiment 29 | 30 | Please run `Synthetic.ipynb` notebook to obtain the results provided in **Section 4.1** 31 | 32 | ### Main experiment 33 | 34 | The main experiment provided in **Section 4.2** is separated on the two parts: 35 | 36 | * GPU based experiments contain the proposed method together with Neural Network baselines. You can execute it by running `RunGPUTasks.ipynb`. Please, adjust the `Params` cell according to your hardware. The proposed results were obtained with the listed above hardware 37 | 38 | * CPU based experiments Meta Learners based algorithms and CausalForest. You can execute it by running `RunCPUTasks.ipynb`. Please, adjust the `Params` cell according to your hardware. 
The proposed results were obtained with the listed above hardware 39 | 40 | ### Analyze the results 41 | 42 | After evaluations are finalized, you can obtain the contents by running `ResultsMain.ipynb` notebook 43 | -------------------------------------------------------------------------------- /source/factory.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from causalml.inference.meta import BaseDRRegressor 4 | from causalml.inference.meta import BaseTClassifier, BaseXClassifier, BaseRClassifier 5 | from econml.grf import CausalForest 6 | from py_boost import GradientBoosting 7 | from xgboost import XGBRegressor, XGBClassifier 8 | 9 | from .descn import DESCNNet 10 | from .dragonnet import DragonNet 11 | from .pb_utils import ComposedUpliftLoss, QINIMetric, BCEWithNaNLoss, UpliftSplitter, UpliftSplitterXN, \ 12 | RandomSamplingSketchX 13 | 14 | 15 | def get_tlearner_xgb(params, estimator=XGBClassifier): 16 | model = BaseTClassifier( 17 | control_learner=estimator(**params['control_learner']), 18 | treatment_learner=estimator(**params['treatment_learner']), 19 | control_name='control' 20 | ) 21 | 22 | return model 23 | 24 | 25 | def get_xlearner_xgb(params, estimator=XGBClassifier, regressor=XGBRegressor): 26 | model = BaseXClassifier( 27 | control_outcome_learner=estimator(**params['control_outcome_learner']), 28 | treatment_outcome_learner=estimator(**params['treatment_outcome_learner']), 29 | control_effect_learner=regressor(**params['control_effect_learner']), 30 | treatment_effect_learner=regressor(**params['treatment_effect_learner']), 31 | control_name='control' 32 | ) 33 | 34 | return model 35 | 36 | 37 | def get_rlearner_xgb(params, estimator=XGBClassifier, regressor=XGBRegressor): 38 | model = BaseRClassifier( 39 | outcome_learner=estimator(**params['outcome_learner']), 40 | effect_learner=regressor(**params['effect_learner']), 41 | control_name='control' 42 | ) 43 | 44 | return model 45 | 46 | 47 | class XGBRegClassifier(XGBClassifier): 48 | 49 | def predict(self, X): 50 | return self.predict_proba(X)[:, 1] 51 | 52 | 53 | def get_drlearner_xgb(params, estimator=XGBRegClassifier, regressor=XGBRegressor): 54 | model = BaseDRRegressor( 55 | control_outcome_learner=estimator(**params['control_outcome_learner']), 56 | treatment_outcome_learner=estimator(**params['treatment_outcome_learner']), 57 | treatment_effect_learner=regressor(**params['treatment_effect_learner']), 58 | 59 | control_name='control' 60 | ) 61 | 62 | return model 63 | 64 | 65 | def get_pb_uplift(params, xn=True, masked=True, weight=.5): 66 | params = deepcopy(params['params']) 67 | 68 | loss = ComposedUpliftLoss(BCEWithNaNLoss(), 1, weight=weight, masked=masked) 69 | metric = QINIMetric(freq=10) 70 | splitter = UpliftSplitterXN() if xn else UpliftSplitter() 71 | 72 | model = GradientBoosting( 73 | loss, metric, 74 | target_splitter=splitter, 75 | multioutput_sketch=RandomSamplingSketchX(1, smooth=1), 76 | callbacks=[ 77 | loss, 78 | metric, 79 | ], 80 | **params 81 | ) 82 | 83 | return model 84 | 85 | 86 | def get_crf(params, ): 87 | model = CausalForest( 88 | **params['params'] 89 | ) 90 | 91 | return model 92 | 93 | 94 | def get_drnet(params, cat_cols=None): 95 | model = DragonNet( 96 | cat_cols=cat_cols, **params['params'] 97 | ) 98 | 99 | return model 100 | 101 | 102 | def get_dcn(params, cat_cols=None): 103 | model = DESCNNet( 104 | cat_cols=cat_cols, **params['params'] 105 | ) 106 | 107 | return model 108 | 
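For orientation, a minimal usage sketch of these factories (assumptions: the nested parameter layout follows the `keys` lists in `run_experiment.py`, and the toy data and parameter values below are illustrative only, not the settings used in the paper):

```python
import numpy as np

from source.factory import get_tlearner_xgb

# toy data: binary treatment encoded with the string labels causalml expects ('control' vs others)
X = np.random.rand(1000, 5)
treatment = np.where(np.random.rand(1000) < 0.5, 'treatment', 'control')
y = np.random.binomial(1, 0.3, size=1000)

# one XGBoost parameter dict per base learner, under the keys get_tlearner_xgb reads
xgb_params = {'n_estimators': 100, 'max_depth': 3, 'learning_rate': 0.05}
params = {'control_learner': xgb_params, 'treatment_learner': xgb_params}

model = get_tlearner_xgb(params)          # builds a causalml BaseTClassifier
model.fit(X, treatment=treatment, y=y)
uplift = model.predict(X)                 # estimated effect, one column per treatment group
```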
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # project specific 2 | datasets/*/ 3 | rapids-24.04/ 4 | experiment/ 5 | final*/ 6 | visualization_* 7 | Results.ipynb 8 | synthetic.ipynb 9 | w_vs_auuc.png 10 | w_vs_mse.png 11 | data_cs_scatter.png 12 | data_ctr_scatter.png 13 | data_effect_hist.png 14 | data_tr_scatter.png 15 | pred_hist_kde.png 16 | pred_scatter.png 17 | 18 | # Byte-compiled / optimized / DLL files 19 | __pycache__/ 20 | *.py[cod] 21 | *$py.class 22 | 23 | # C extensions 24 | *.so 25 | 26 | # Distribution / packaging 27 | .Python 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | .eggs/ 34 | lib/ 35 | lib64/ 36 | parts/ 37 | sdist/ 38 | var/ 39 | wheels/ 40 | share/python-wheels/ 41 | *.egg-info/ 42 | .installed.cfg 43 | *.egg 44 | MANIFEST 45 | 46 | # PyInstaller 47 | # Usually these files are written by a python script from a template 48 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 49 | *.manifest 50 | *.spec 51 | 52 | # Installer logs 53 | pip-log.txt 54 | pip-delete-this-directory.txt 55 | 56 | # Unit test / coverage reports 57 | htmlcov/ 58 | .tox/ 59 | .nox/ 60 | .coverage 61 | .coverage.* 62 | .cache 63 | nosetests.xml 64 | coverage.xml 65 | *.cover 66 | *.py,cover 67 | .hypothesis/ 68 | .pytest_cache/ 69 | cover/ 70 | 71 | # Translations 72 | *.mo 73 | *.pot 74 | 75 | # Django stuff: 76 | *.log 77 | local_settings.py 78 | db.sqlite3 79 | db.sqlite3-journal 80 | 81 | # Flask stuff: 82 | instance/ 83 | .webassets-cache 84 | 85 | # Scrapy stuff: 86 | .scrapy 87 | 88 | # Sphinx documentation 89 | docs/_build/ 90 | 91 | # PyBuilder 92 | .pybuilder/ 93 | target/ 94 | 95 | # Jupyter Notebook 96 | .ipynb_checkpoints 97 | 98 | # IPython 99 | profile_default/ 100 | ipython_config.py 101 | 102 | # pyenv 103 | # For a library or package, you might want to ignore these files since the code is 104 | # intended to run in multiple environments; otherwise, check them in: 105 | # .python-version 106 | 107 | # pipenv 108 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 109 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 110 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 111 | # install all needed dependencies. 112 | #Pipfile.lock 113 | 114 | # poetry 115 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 116 | # This is especially recommended for binary packages to ensure reproducibility, and is more 117 | # commonly ignored for libraries. 118 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 119 | #poetry.lock 120 | 121 | # pdm 122 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 123 | #pdm.lock 124 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 125 | # in version control. 126 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 127 | .pdm.toml 128 | .pdm-python 129 | .pdm-build/ 130 | 131 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 132 | __pypackages__/ 133 | 134 | # Celery stuff 135 | celerybeat-schedule 136 | celerybeat.pid 137 | 138 | # SageMath parsed files 139 | *.sage.py 140 | 141 | # Environments 142 | .env 143 | .venv 144 | env/ 145 | venv/ 146 | ENV/ 147 | env.bak/ 148 | venv.bak/ 149 | 150 | # Spyder project settings 151 | .spyderproject 152 | .spyproject 153 | 154 | # Rope project settings 155 | .ropeproject 156 | 157 | # mkdocs documentation 158 | /site 159 | 160 | # mypy 161 | .mypy_cache/ 162 | .dmypy.json 163 | dmypy.json 164 | 165 | # Pyre type checker 166 | .pyre/ 167 | 168 | # pytype static type analyzer 169 | .pytype/ 170 | 171 | # Cython debug symbols 172 | cython_debug/ 173 | 174 | # PyCharm 175 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 176 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 177 | # and can be added to the global gitignore or merged into this file. For a more nuclear 178 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 179 | .idea/ -------------------------------------------------------------------------------- /RunCPUTasks.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "c833111d", 6 | "metadata": {}, 7 | "source": [ 8 | "## Run CPU baselines" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "8b86fd14", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import os\n", 19 | "import subprocess\n", 20 | "from joblib import Parallel, delayed\n", 21 | "from multiprocessing import Queue\n", 22 | "\n", 23 | "from itertools import product" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "id": "2e2f992c", 29 | "metadata": {}, 30 | "source": [ 31 | "### Params" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "id": "387d955e", 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "# parameters\n", 42 | "N_JOBS = 8\n", 43 | "N_PARALLEL_TASKS = 3\n", 44 | "PYTHON = 'rapids-24.04/bin/python'" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "id": "ae295eed", 50 | "metadata": {}, 51 | "source": [ 52 | "### Utils" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "id": "916c751e", 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "# run script to execute the task\n", 63 | "def get_script(path, runner, tuner, model):\n", 64 | " \"\"\"\n", 65 | " Get run script for the task\n", 66 | " \"\"\"\n", 67 | " command = f\"\"\"\n", 68 | " {PYTHON} run_experiment.py \\\n", 69 | " --path {os.path.join('datasets', path)} \\\n", 70 | " --njobs {N_JOBS} \\\n", 71 | " --seed 42 \\\n", 72 | " --device 0 \\\n", 73 | " --runner {runner} \\\n", 74 | " --tuner {tuner} \\\n", 75 | " --model {model} \\\n", 76 | " --config config.yaml\n", 77 | " \"\"\"\n", 78 | " return command\n", 79 | "\n", 80 | "\n", 81 | "def run(path, model, runner, tuner, ):\n", 82 | " \"\"\"\n", 83 | " Run task\n", 84 | " \"\"\"\n", 85 | " # generate script\n", 86 | " script = get_script(path, runner, tuner, model)\n", 87 | " print(script)\n", 88 | " # run task\n", 89 | " subprocess.check_output(script, shell=True, stderr=subprocess.STDOUT,)\n", 90 | " \n", 91 | " return " 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "id": "2889b48b", 97 | "metadata": {}, 98 | "source": [ 99 | "### 
Tasks list" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "id": "5346ad5c", 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "# tasks list\n", 110 | "datasets = [\n", 111 | " \n", 112 | " 'synth1', \n", 113 | " 'hillstrom', \n", 114 | " 'criteo',\n", 115 | " 'lenta',\n", 116 | " 'megafon',\n", 117 | "]\n", 118 | "\n", 119 | "# tuple: (type of model, run function, objective with param space)\n", 120 | "models = [\n", 121 | " # t learner\n", 122 | " ('xgb_t', 'meta', 'xgb_single'), \n", 123 | " # x learner\n", 124 | " ('xgb_x', 'meta', 'xgb_single'), \n", 125 | " # r learner\n", 126 | " ('xgb_r', 'meta', 'xgb_single'), \n", 127 | " # dr learner\n", 128 | " ('xgb_dr', 'meta', 'xgb_single'), \n", 129 | " # Causal RF\n", 130 | " ('crf', 'crf', 'crf')\n", 131 | "]\n", 132 | "\n", 133 | "# combine datasets and models\n", 134 | "tasks = product(\n", 135 | " map(\n", 136 | " lambda x: x[0] + '_' + str(x[1]), product(datasets, range(5))\n", 137 | " ),\n", 138 | " models\n", 139 | ")" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "id": "ae420827", 145 | "metadata": {}, 146 | "source": [ 147 | "### Run " 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "id": "d0628c4f", 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "with Parallel(N_PARALLEL_TASKS, backend='threading') as p:\n", 158 | " p(delayed(run)(d, *m) for (d, m) in tasks)" 159 | ] 160 | } 161 | ], 162 | "metadata": { 163 | "kernelspec": { 164 | "display_name": "rapids-env", 165 | "language": "python", 166 | "name": "rapids-env" 167 | }, 168 | "language_info": { 169 | "codemirror_mode": { 170 | "name": "ipython", 171 | "version": 3 172 | }, 173 | "file_extension": ".py", 174 | "mimetype": "text/x-python", 175 | "name": "python", 176 | "nbconvert_exporter": "python", 177 | "pygments_lexer": "ipython3", 178 | "version": "3.10.14" 179 | } 180 | }, 181 | "nbformat": 4, 182 | "nbformat_minor": 5 183 | } 184 | -------------------------------------------------------------------------------- /run_experiment.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | 5 | parser.add_argument('-p', '--path', type=str) # dataset path 6 | parser.add_argument('-n', '--njobs', type=int, default=8) 7 | parser.add_argument('-s', '--seed', type=int, default=42) 8 | parser.add_argument('-d', '--device', type=str, default='0') 9 | parser.add_argument('-r', '--runner', type=str) # train_fn_dict 10 | parser.add_argument('-t', '--tuner', type=str) # obj_dict 11 | parser.add_argument('-m', '--model', type=str) # model_dict 12 | parser.add_argument('-c', '--config', type=str) # general confit file 13 | 14 | if __name__ == '__main__': 15 | 16 | import os 17 | 18 | args = parser.parse_args() 19 | str_nthr = str(args.njobs) 20 | 21 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 22 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 23 | 24 | os.environ["OMP_NUM_THREADS"] = str_nthr # export OMP_NUM_THREADS=4 25 | os.environ["OPENBLAS_NUM_THREADS"] = str_nthr # export OPENBLAS_NUM_THREADS=4 26 | os.environ["MKL_NUM_THREADS"] = str_nthr # export MKL_NUM_THREADS=6 27 | os.environ["VECLIB_MAXIMUM_THREADS"] = str_nthr # export VECLIB_MAXIMUM_THREADS=4 28 | os.environ["NUMEXPR_NUM_THREADS"] = str_nthr # export NUMEXPR_NUM_THREADS=6 29 | 30 | import optuna 31 | import numpy as np 32 | import shutil 33 | import yaml 34 | from copy import deepcopy 35 | 
from source.objective import * 36 | from source.factory import * 37 | from source.runner import * 38 | 39 | np.random.seed(args.seed) 40 | 41 | train_fn_dict = { 42 | 'meta': train_meta_uber, 43 | 'pb': train_pb_upd_weight, 44 | 'crf': train_crf, 45 | 'dr': train_drnet, 46 | 47 | } 48 | 49 | obj_dict = { 50 | 'xgb_single': ObjectiveSingleXGB, # xgboost with single tuning 51 | 'pb': ObjectivePB, # our model 52 | 'crf': ObjectiveCRF, 53 | 'dr': ObjectiveDR, 54 | 'dcn': ObjectiveDCN 55 | } 56 | 57 | model_dict = { 58 | 'xgb_t': ( 59 | get_tlearner_xgb, ['treatment_learner', 'control_learner'] 60 | ), 61 | 62 | 'xgb_x': ( 63 | get_xlearner_xgb, 64 | [ 65 | 'control_outcome_learner', 66 | 'treatment_outcome_learner', 67 | 'control_effect_learner', 68 | 'treatment_effect_learner' 69 | ] 70 | ), 71 | 72 | 'xgb_r': ( 73 | get_rlearner_xgb, 74 | ['outcome_learner', 'effect_learner'] 75 | ), 76 | 77 | 'xgb_dr': ( 78 | get_drlearner_xgb, 79 | ['control_outcome_learner', 'treatment_outcome_learner', 'treatment_effect_learner'] 80 | ), 81 | 82 | # py-boost learner 83 | 'pb_lc_f_t': ( 84 | lambda x: get_pb_uplift(x, False, True), 85 | ['params'] 86 | ), 87 | 88 | 'crf': ( 89 | get_crf, ['params'] 90 | ), 91 | 'dr': ( 92 | lambda x: get_drnet(x, cat_cols=[5, 7] if 'hillstrom' in args.path else None), 93 | ['params'] 94 | ), 95 | 'dcn': ( 96 | lambda x: get_dcn(x, cat_cols=[5, 7] if 'hillstrom' in args.path else None), 97 | ['params'] 98 | ) 99 | } 100 | 101 | train_fn = train_fn_dict[args.runner] 102 | Obj = obj_dict[args.tuner] 103 | factory, keys = model_dict[args.model] 104 | 105 | with open(args.config, 'r') as f: 106 | config = yaml.safe_load(f) 107 | 108 | os.makedirs(config['experiment'], exist_ok=True) 109 | 110 | params_key = args.model.split('_')[0] 111 | params = deepcopy(config[params_key]) 112 | 113 | ds_name = args.path 114 | if ds_name[-1] == '/': 115 | ds_name = ds_name[:-1] 116 | study_path = os.path.join(config['experiment'], os.path.basename(ds_name), 117 | f'{args.runner}-{args.tuner}-{args.model}') 118 | # remove previous runs 119 | try: 120 | shutil.rmtree(study_path) 121 | except FileNotFoundError: 122 | pass 123 | 124 | # get data 125 | train = joblib.load(os.path.join(args.path, 'train.pkl')) 126 | test = joblib.load(os.path.join(args.path, 'test.pkl')) 127 | print('Study started') 128 | # run study 129 | study = optuna.create_study( 130 | directions=['maximize'] * (train['t'].max()), 131 | sampler=optuna.samplers.TPESampler( 132 | n_startup_trials=config['optuna']['n_startup_trials'], 133 | multivariate=config['optuna']['multivariate'] 134 | ) 135 | 136 | ) 137 | objective = Obj( 138 | # dataset and learner 139 | train, test, train_fn, 140 | # facroty, 141 | factory=factory, 142 | # keys 143 | keys=keys, 144 | # params 145 | params=params, 146 | folder=study_path 147 | ) 148 | 149 | study.optimize(objective, n_trials=config['optuna']['n_trials'], timeout=config['optuna']['timeout']) 150 | joblib.dump(study, os.path.join(study_path, 'study.pkl')) 151 | -------------------------------------------------------------------------------- /RunGPUTasks.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "d8ee1a3a", 6 | "metadata": {}, 7 | "source": [ 8 | "## Run GPU baselines" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "1d6aec6c", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import os\n", 19 | "import subprocess\n", 20 | 
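    "# joblib threads plus a Queue of GPU ids (filled in the Utils cell below) let several run_experiment.py subprocesses share the devices\n",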
"from joblib import Parallel, delayed\n", 21 | "from multiprocessing import Queue\n", 22 | "\n", 23 | "from itertools import product" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "id": "926f499c", 29 | "metadata": {}, 30 | "source": [ 31 | "### Params" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "id": "af193cf9", 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "# parameters\n", 42 | "DEVICE_LIST = [0, 1]\n", 43 | "N_JOBS = 4\n", 44 | "N_TASKS_PER_DEVICE = 1 # since utilization is not high, some of the tasks could be ran on the same device\n", 45 | "PYTHON = '../rapids-24.04/bin/python'" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "id": "2d63ff46", 51 | "metadata": {}, 52 | "source": [ 53 | "### Utils" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "id": "8fb12f50", 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "# fill GPU queue\n", 64 | "QUEUE = Queue(maxsize=len(DEVICE_LIST) * N_TASKS_PER_DEVICE)\n", 65 | "for i in range(len(DEVICE_LIST)):\n", 66 | " for _ in range(N_TASKS_PER_DEVICE):\n", 67 | " QUEUE.put(i)\n", 68 | " \n", 69 | "# run script to execute the task\n", 70 | "def get_script(path, runner, tuner, model, device):\n", 71 | " \"\"\"\n", 72 | " Get run script for the task\n", 73 | " \"\"\"\n", 74 | " command = f\"\"\"\n", 75 | " {PYTHON} run_experiment.py \\\n", 76 | " --path {os.path.join('datasets', path)} \\\n", 77 | " --njobs {N_JOBS} \\\n", 78 | " --seed 42 \\\n", 79 | " --device {device} \\\n", 80 | " --runner {runner} \\\n", 81 | " --tuner {tuner} \\\n", 82 | " --model {model} \\\n", 83 | " --config config.yaml\n", 84 | " \"\"\"\n", 85 | " return command\n", 86 | "\n", 87 | "\n", 88 | "def run(path, model, runner, tuner, ):\n", 89 | " \"\"\"\n", 90 | " Run task\n", 91 | " \"\"\"\n", 92 | " # get free GPU\n", 93 | " device = QUEUE.get()\n", 94 | " # generate script\n", 95 | " print('Task started')\n", 96 | " script = get_script(path, runner, tuner, model, device)\n", 97 | " print(script)\n", 98 | " # run task\n", 99 | " subprocess.check_output(script, shell=True, stderr=subprocess.STDOUT,)\n", 100 | " # back to queue\n", 101 | " QUEUE.put(device)\n", 102 | " \n", 103 | " return " 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "id": "4e48e344", 109 | "metadata": {}, 110 | "source": [ 111 | "### Tasks list" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "id": "5ee65812", 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "# tasks list\n", 122 | "datasets = [\n", 123 | " \n", 124 | " 'synth1', \n", 125 | "# 'hillstrom', \n", 126 | "# 'criteo',\n", 127 | "# 'lenta',\n", 128 | "# 'megafon',\n", 129 | "]\n", 130 | "\n", 131 | "# tuple: (type of model, run function, objective with param space)\n", 132 | "models = [\n", 133 | " # py-boost baseline\n", 134 | " # ('pb_lc_f_t', 'pb', 'pb'), \n", 135 | " # dragonnet\n", 136 | " ('dr', 'dr', 'dr'),\n", 137 | " # DESCN\n", 138 | " # ('dcn', 'dr', 'dcn')\n", 139 | "]\n", 140 | "\n", 141 | "# combine datasets and models\n", 142 | "tasks = product(\n", 143 | " map(\n", 144 | " lambda x: x[0] + '_' + str(x[1]), product(datasets, range(5))\n", 145 | " ),\n", 146 | " models\n", 147 | ")" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "id": "93ebf10b", 153 | "metadata": {}, 154 | "source": [ 155 | "### Run " 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "id": "933db84d", 162 | 
"metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "with Parallel(len(DEVICE_LIST) * N_TASKS_PER_DEVICE, backend='threading') as p:\n", 166 | " p(delayed(run)(d, *m) for (d, m) in tasks)" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "id": "04ee3f26", 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [] 176 | } 177 | ], 178 | "metadata": { 179 | "kernelspec": { 180 | "display_name": "Python 3 (ipykernel)", 181 | "language": "python", 182 | "name": "python3" 183 | }, 184 | "language_info": { 185 | "codemirror_mode": { 186 | "name": "ipython", 187 | "version": 3 188 | }, 189 | "file_extension": ".py", 190 | "mimetype": "text/x-python", 191 | "name": "python", 192 | "nbconvert_exporter": "python", 193 | "pygments_lexer": "ipython3", 194 | "version": "3.7.12" 195 | } 196 | }, 197 | "nbformat": 4, 198 | "nbformat_minor": 5 199 | } 200 | -------------------------------------------------------------------------------- /source/dragonnet/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class DragonNetBase(nn.Module): 11 | """ 12 | Base Dragonnet model. 13 | 14 | Parameters 15 | ---------- 16 | input_dim: int 17 | input dimension for convariates 18 | shared_hidden: int 19 | layer size for hidden shared representation layers 20 | outcome_hidden: int 21 | layer size for conditional outcome layers 22 | """ 23 | 24 | def __init__(self, input_dim, shared_hidden=200, outcome_hidden=100): 25 | super(DragonNetBase, self).__init__() 26 | self.fc1 = nn.Linear(in_features=input_dim, out_features=shared_hidden) 27 | self.fc2 = nn.Linear(in_features=shared_hidden, out_features=shared_hidden) 28 | self.fcz = nn.Linear(in_features=shared_hidden, out_features=shared_hidden) 29 | 30 | self.treat_out = nn.Linear(in_features=shared_hidden, out_features=1) 31 | 32 | self.y0_fc1 = nn.Linear(in_features=shared_hidden, out_features=outcome_hidden) 33 | self.y0_fc2 = nn.Linear(in_features=outcome_hidden, out_features=outcome_hidden) 34 | self.y0_out = nn.Linear(in_features=outcome_hidden, out_features=1) 35 | 36 | self.y1_fc1 = nn.Linear(in_features=shared_hidden, out_features=outcome_hidden) 37 | self.y1_fc2 = nn.Linear(in_features=outcome_hidden, out_features=outcome_hidden) 38 | self.y1_out = nn.Linear(in_features=outcome_hidden, out_features=1) 39 | 40 | self.epsilon = nn.Linear(in_features=1, out_features=1) 41 | torch.nn.init.xavier_normal_(self.epsilon.weight) 42 | 43 | def forward(self, inputs): 44 | """ 45 | forward method to train model. 
46 | 47 | Parameters 48 | ---------- 49 | inputs: torch.Tensor 50 | covariates 51 | 52 | Returns 53 | ------- 54 | y0: torch.Tensor 55 | outcome under control 56 | y1: torch.Tensor 57 | outcome under treatment 58 | t_pred: torch.Tensor 59 | predicted treatment 60 | eps: torch.Tensor 61 | trainable epsilon parameter 62 | """ 63 | x = F.relu(self.fc1(inputs)) 64 | x = F.relu(self.fc2(x)) 65 | z = F.relu(self.fcz(x)) 66 | 67 | t_pred = torch.sigmoid(self.treat_out(z)) 68 | 69 | y0 = F.relu(self.y0_fc1(z)) 70 | y0 = F.relu(self.y0_fc2(y0)) 71 | y0 = self.y0_out(y0) 72 | 73 | y1 = F.relu(self.y1_fc1(z)) 74 | y1 = F.relu(self.y1_fc2(y1)) 75 | y1 = self.y1_out(y1) 76 | 77 | eps = self.epsilon(torch.ones_like(t_pred)[:, 0:1]) 78 | 79 | return y0, y1, t_pred, eps 80 | 81 | 82 | def dragonnet_loss(y_true, t_true, t_pred, y0_pred, y1_pred, eps, alpha=1.0): 83 | """ 84 | Generic loss function for dragonnet 85 | 86 | Parameters 87 | ---------- 88 | y_true: torch.Tensor 89 | Actual target variable 90 | t_true: torch.Tensor 91 | Actual treatment variable 92 | t_pred: torch.Tensor 93 | Predicted treatment 94 | y0_pred: torch.Tensor 95 | Predicted target variable under control 96 | y1_pred: torch.Tensor 97 | Predicted target variable under treatment 98 | eps: torch.Tensor 99 | Trainable epsilon parameter 100 | alpha: float 101 | loss component weighting hyperparameter between 0 and 1 102 | Returns 103 | ------- 104 | loss: torch.Tensor 105 | """ 106 | t_pred = (t_pred + 0.01) / 1.02 107 | loss_t = torch.sum(F.binary_cross_entropy(t_pred, t_true)) 108 | 109 | loss0 = torch.sum((1. - t_true) * torch.square(y_true - y0_pred)) 110 | loss1 = torch.sum(t_true * torch.square(y_true - y1_pred)) 111 | loss_y = loss0 + loss1 112 | 113 | loss = loss_y + alpha * loss_t 114 | 115 | return loss 116 | 117 | 118 | def tarreg_loss(y_true, t_true, t_pred, y0_pred, y1_pred, eps, alpha=1.0, beta=1.0): 119 | """ 120 | Targeted regularisation loss function for dragonnet 121 | 122 | Parameters 123 | ---------- 124 | y_true: torch.Tensor 125 | Actual target variable 126 | t_true: torch.Tensor 127 | Actual treatment variable 128 | t_pred: torch.Tensor 129 | Predicted treatment 130 | y0_pred: torch.Tensor 131 | Predicted target variable under control 132 | y1_pred: torch.Tensor 133 | Predicted target variable under treatment 134 | eps: torch.Tensor 135 | Trainable epsilon parameter 136 | alpha: float 137 | loss component weighting hyperparameter between 0 and 1 138 | beta: float 139 | targeted regularization hyperparameter between 0 and 1 140 | Returns 141 | ------- 142 | loss: torch.Tensor 143 | """ 144 | vanilla_loss = dragonnet_loss(y_true, t_true, t_pred, y0_pred, y1_pred, alpha) 145 | t_pred = (t_pred + 0.01) / 1.02 146 | 147 | y_pred = t_true * y1_pred + (1 - t_true) * y0_pred 148 | 149 | h = (t_true / t_pred) - ((1 - t_true) / (1 - t_pred)) 150 | 151 | y_pert = y_pred + eps * h 152 | targeted_regularization = torch.sum((y_true - y_pert) ** 2) 153 | 154 | # final 155 | loss = vanilla_loss + beta * targeted_regularization 156 | return loss 157 | 158 | 159 | class EarlyStopper: 160 | def __init__(self, temp_folder, patience=15, min_delta=0, ): 161 | self.patience = patience 162 | self.min_delta = min_delta 163 | self.counter = 0 164 | self.min_validation_loss = np.inf 165 | self.temp_folder = temp_folder 166 | os.makedirs(temp_folder, exist_ok=False) 167 | 168 | def early_stop(self, validation_loss, model): 169 | if validation_loss < self.min_validation_loss: 170 | self.min_validation_loss = validation_loss 171 | self.counter = 
0 172 | torch.save(model.state_dict(), os.path.join(self.temp_folder, 'checkpoint.pth')) 173 | 174 | 175 | elif validation_loss > (self.min_validation_loss + self.min_delta): 176 | self.counter += 1 177 | if self.counter >= self.patience: 178 | return True 179 | return False 180 | 181 | def load(self, model): 182 | 183 | model.load_state_dict(torch.load(os.path.join(self.temp_folder, 'checkpoint.pth'))) 184 | return model 185 | 186 | def clear(self): 187 | 188 | shutil.rmtree(self.temp_folder) 189 | -------------------------------------------------------------------------------- /source/descn/models.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | def init_weights(m): 8 | if isinstance(m, nn.Linear): 9 | stdv = 1 / math.sqrt(m.weight.size(1)) 10 | torch.nn.init.normal_(m.weight, mean=0.0, std=stdv) 11 | # torch.nn.init.xavier_uniform_(m.weight) 12 | m.bias.data.fill_(0) 13 | 14 | 15 | def sigmod2(y): 16 | # y = torch.clamp(0.995 / (1.0 + torch.exp(-y)) + 0.0025, 0, 1) 17 | # y = torch.clamp(y, -16, 16) 18 | y = torch.sigmoid(y) 19 | # y = 0.995 / (1.0 + torch.exp(-y)) + 0.0025 20 | 21 | return y 22 | 23 | 24 | def safe_sqrt(x): 25 | ''' Numerically safe version of Pytoch sqrt ''' 26 | return torch.sqrt(torch.clip(x, 1e-9, 1e+9)) 27 | 28 | 29 | class ShareNetwork(nn.Module): 30 | def __init__(self, input_dim, share_dim, base_dim, cfg, device): 31 | super(ShareNetwork, self).__init__() 32 | if cfg.BatchNorm1d == 'true': 33 | print("use BatchNorm1d") 34 | self.DNN = nn.Sequential( 35 | 36 | nn.BatchNorm1d(input_dim), 37 | nn.Linear(input_dim, share_dim), 38 | nn.ELU(), 39 | nn.Dropout(p=cfg.do_rate), 40 | nn.Linear(share_dim, share_dim), 41 | # nn.BatchNorm1d(share_dim), 42 | nn.ELU(), 43 | nn.Dropout(p=cfg.do_rate), 44 | nn.Linear(share_dim, base_dim), 45 | # nn.BatchNorm1d(base_dim), 46 | nn.ELU(), 47 | nn.Dropout(p=cfg.do_rate) 48 | ) 49 | else: 50 | print("No BatchNorm1d") 51 | self.DNN = nn.Sequential( 52 | nn.Linear(input_dim, share_dim), 53 | nn.ELU(), 54 | nn.Dropout(p=cfg.do_rate), 55 | nn.Linear(share_dim, share_dim), 56 | nn.ELU(), 57 | nn.Dropout(p=cfg.do_rate), 58 | nn.Linear(share_dim, base_dim), 59 | nn.ELU(), 60 | ) 61 | 62 | self.DNN.apply(init_weights) 63 | self.cfg = cfg 64 | self.device = device 65 | self.to(device) 66 | 67 | def forward(self, x): 68 | x = x.to(self.device) 69 | h_rep = self.DNN(x) 70 | if self.cfg.normalization == "divide": 71 | h_rep_norm = h_rep / safe_sqrt(torch.sum(torch.square(h_rep), dim=1, keepdim=True)) 72 | else: 73 | h_rep_norm = 1.0 * h_rep 74 | return h_rep_norm 75 | 76 | 77 | class BaseModel(nn.Module): 78 | def __init__(self, base_dim, cfg): 79 | super(BaseModel, self).__init__() 80 | self.DNN = nn.Sequential( 81 | nn.Linear(base_dim, base_dim), 82 | # nn.BatchNorm1d(base_dim), 83 | nn.ELU(), 84 | nn.Dropout(p=cfg.do_rate), 85 | nn.Linear(base_dim, base_dim), 86 | # nn.BatchNorm1d(base_dim), 87 | nn.ELU(), 88 | nn.Dropout(p=cfg.do_rate), 89 | nn.Linear(base_dim, base_dim), 90 | # nn.BatchNorm1d(base_dim), 91 | nn.ELU(), 92 | nn.Dropout(p=cfg.do_rate) 93 | ) 94 | self.DNN.apply(init_weights) 95 | 96 | def forward(self, x): 97 | logits = self.DNN(x) 98 | return logits 99 | 100 | 101 | class BaseModel4MetaLearner(nn.Module): 102 | def __init__(self, input_dim, base_dim, cfg, device): 103 | super(BaseModel4MetaLearner, self).__init__() 104 | self.DNN = nn.Sequential( 105 | nn.BatchNorm1d(input_dim), 106 | nn.Linear(input_dim, base_dim), 
107 | nn.ELU(), 108 | nn.Dropout(p=cfg.do_rate), 109 | nn.Linear(base_dim, base_dim), 110 | # nn.BatchNorm1d(share_dim), 111 | # nn.ELU(), 112 | # nn.Dropout(p=cfg.do_rate), 113 | # nn.Linear(base_dim, base_dim), 114 | # nn.BatchNorm1d(share_dim), 115 | nn.ELU(), 116 | nn.Dropout(p=cfg.do_rate), 117 | nn.Linear(base_dim, 1), 118 | # nn.ELU() 119 | # nn.BatchNorm1d(base_dim), 120 | ) 121 | self.DNN.apply(init_weights) 122 | self.cfg = cfg 123 | self.device = device 124 | self.to(device) 125 | 126 | def forward(self, x): 127 | x = x.to(self.device) 128 | logit = self.DNN(x) 129 | return logit 130 | 131 | 132 | class PrpsyNetwork(nn.Module): 133 | """propensity network""" 134 | 135 | def __init__(self, base_dim, cfg): 136 | super(PrpsyNetwork, self).__init__() 137 | self.baseModel = BaseModel(base_dim, cfg) 138 | self.logitLayer = nn.Linear(base_dim, 1) 139 | self.sigmoid = nn.Sigmoid() 140 | self.logitLayer.apply(init_weights) 141 | 142 | def forward(self, inputs): 143 | inputs = self.baseModel(inputs) 144 | p = self.logitLayer(inputs) 145 | return p 146 | 147 | 148 | class Mu0Network(nn.Module): 149 | def __init__(self, base_dim, cfg): 150 | super(Mu0Network, self).__init__() 151 | self.baseModel = BaseModel(base_dim, cfg) 152 | self.logitLayer = nn.Linear(base_dim, 1) 153 | self.logitLayer.apply(init_weights) 154 | self.sigmoid = nn.Sigmoid() 155 | self.relu = nn.ReLU() 156 | 157 | def forward(self, inputs): 158 | inputs = self.baseModel(inputs) 159 | p = self.logitLayer(inputs) 160 | # return self.relu(p) 161 | return p 162 | 163 | 164 | class Mu1Network(nn.Module): 165 | def __init__(self, base_dim, cfg): 166 | super(Mu1Network, self).__init__() 167 | self.baseModel = BaseModel(base_dim, cfg) 168 | self.logitLayer = nn.Linear(base_dim, 1) 169 | self.logitLayer.apply(init_weights) 170 | self.sigmoid = nn.Sigmoid() 171 | self.relu = nn.ReLU() 172 | 173 | def forward(self, inputs): 174 | inputs = self.baseModel(inputs) 175 | p = self.logitLayer(inputs) 176 | # return self.relu(p) 177 | return p 178 | 179 | 180 | class TauNetwork(nn.Module): 181 | """pseudo tau network""" 182 | 183 | def __init__(self, base_dim, cfg): 184 | super(TauNetwork, self).__init__() 185 | self.baseModel = BaseModel(base_dim, cfg) 186 | self.logitLayer = nn.Linear(base_dim, 1) 187 | self.logitLayer.apply(init_weights) 188 | self.tanh = nn.Tanh() 189 | 190 | def forward(self, inputs): 191 | inputs = self.baseModel(inputs) 192 | tau_logit = self.logitLayer(inputs) 193 | # return self.tanh(p) 194 | return tau_logit 195 | 196 | 197 | class ESX(nn.Module): 198 | """ESX""" 199 | 200 | def __init__(self, prpsy_network: PrpsyNetwork, \ 201 | mu1_network: Mu1Network, mu0_network: Mu0Network, tau_network: TauNetwork, shareNetwork: ShareNetwork, 202 | cfg, device): 203 | super(ESX, self).__init__() 204 | # self.feature_extractor = feature_extractor 205 | self.shareNetwork = shareNetwork.to(device) 206 | self.prpsy_network = prpsy_network.to(device) 207 | self.mu1_network = mu1_network.to(device) 208 | self.mu0_network = mu0_network.to(device) 209 | self.tau_network = tau_network.to(device) 210 | self.cfg = cfg 211 | self.device = device 212 | self.to(device) 213 | 214 | def forward(self, inputs): 215 | shared_h = self.shareNetwork(inputs) 216 | 217 | # propensity output_logit 218 | p_prpsy_logit = self.prpsy_network(shared_h) 219 | 220 | # p_prpsy = torch.clip(torch.sigmoid(p_prpsy_logit), 0.05, 0.95) 221 | p_prpsy = torch.clip(torch.sigmoid(p_prpsy_logit), 0.001, 0.999) 222 | 223 | # logit for mu1, mu0 224 | mu1_logit = 
self.mu1_network(shared_h) 225 | mu0_logit = self.mu0_network(shared_h) 226 | 227 | # pseudo tau 228 | tau_logit = self.tau_network(shared_h) 229 | 230 | p_mu1 = sigmod2(mu1_logit) 231 | p_mu0 = sigmod2(mu0_logit) 232 | p_h1 = p_mu1 # Refer to the naming in TARnet/CFR 233 | p_h0 = p_mu0 # Refer to the naming in TARnet/CFR 234 | 235 | # entire space 236 | p_estr = torch.mul(p_prpsy, p_h1) 237 | p_i_prpsy = 1 - p_prpsy 238 | p_escr = torch.mul(p_i_prpsy, p_h0) 239 | 240 | return p_prpsy_logit, p_estr, p_escr, tau_logit, mu1_logit, mu0_logit, p_prpsy, p_mu1, p_mu0, p_h1, p_h0, shared_h 241 | -------------------------------------------------------------------------------- /visualization.py: -------------------------------------------------------------------------------- 1 | import joblib 2 | import numpy as np 3 | import pandas as pd 4 | 5 | 6 | def get_mse(ds, pred): 7 | """ 8 | 9 | :param ds: 10 | :param pred: 11 | :return: 12 | """ 13 | mse = [] 14 | mse_stds = [] 15 | for t in range(1, ds['t'].max() + 1): 16 | stds = [] 17 | 18 | sl = (ds['t'] == t) 19 | y = ds['effect'][sl] 20 | p = pred[sl][:, t - 1] 21 | mse.append(((y - p) ** 2).mean()) 22 | 23 | for _ in range(100): 24 | idx = np.random.randint(0, p.shape[0], size=p.shape[0]) 25 | stds.append(((y[idx] - p[idx]) ** 2).mean()) 26 | 27 | mse_stds.append(np.std(stds, ddof=1)) 28 | 29 | return mse, mse_stds 30 | 31 | 32 | def get_results_simple( 33 | data, experiment, task, nstart=0, nstop=5, key='test', 34 | selector_fn=lambda x: np.mean(x.values), weight=None 35 | ): 36 | """ 37 | 38 | :param data: 39 | :param experiment: 40 | :param task: 41 | :param nstart: 42 | :param nstop: 43 | :param key: 44 | :param selector_fn: 45 | :param weight: 46 | :return: 47 | """ 48 | res = {'AUUC': {}, 'MSE': {}, 'ATE': {}, 'ATE_ERR, %': {}, 'QINI': {}} 49 | 50 | auuc = [] 51 | auuc_std = [] 52 | mses = [] 53 | mse_stds = [] 54 | 55 | for n in range(nstart, nstop): 56 | # dataset for metrics' 57 | data_ = data.split('/')[-1] 58 | ds_name = key if key == 'test' else 'trian' 59 | ds = joblib.load(f'datasets/{data}_{n}/{ds_name}.pkl') 60 | 61 | # get best trial 62 | study = joblib.load(f'{experiment}/{data_}_{n}/{task}/study.pkl') 63 | best = max(study.trials, key=selector_fn) 64 | # print(best.params) 65 | # get all scores 66 | scores = joblib.load( 67 | f'{experiment}/{data_}_{n}/{task}/trial_{best.number}/scores.pkl' 68 | ) 69 | pred = joblib.load(f'{experiment}/{data_}_{n}/{task}/trial_{best.number}/{key}_pred.pkl') 70 | 71 | if weight is not None: 72 | idx = np.searchsorted(np.linspace(0, 1, 11), weight) 73 | sc = scores[f'{key}_ext_w'][:, idx, :] 74 | pred = pred[idx] 75 | else: 76 | sc = scores[f'{key}_ext'] 77 | if len(pred) == 11: 78 | pred = pred[5] 79 | 80 | # calc all the additional metrics we need 81 | # loop by treatments 82 | 83 | if 'effect' in ds: 84 | mse_, mse_stds_ = get_mse(ds, pred) 85 | mses.append(np.array(mse_)) 86 | mse_stds.append(np.array(mse_stds_)) 87 | 88 | auuc.append(np.mean(sc, axis=0)) 89 | auuc_std.append(np.std(sc, axis=0, ddof=1)) 90 | 91 | # save auucs 92 | if 'effect' in ds: 93 | res['MSE']['mean'] = np.mean(mses, axis=0) 94 | res['MSE']['std'] = np.mean(mse_stds, axis=0) 95 | 96 | res['AUUC']['mean'] = np.mean(auuc, axis=0) 97 | res['AUUC']['std'] = np.std(auuc_std, axis=0) 98 | 99 | return res 100 | 101 | 102 | def get_baselines_results(experiment, datasets, models, K=5): 103 | """ 104 | 105 | :param experiment: 106 | :param datasets: 107 | :param models: 108 | :param K: 109 | :return: 110 | """ 111 | res = [] 112 | 
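    # collect one 'mean' row and one 'std' row per (dataset, model) pair, tagged for later pivoting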
113 | for dataset in datasets: 114 | for model in models: 115 | 116 | D = get_results_simple( 117 | dataset, 118 | experiment, 119 | model, 120 | key='test', 121 | nstop=K 122 | ) 123 | 124 | for key in ['mean', 'std']: 125 | df = pd.DataFrame({x: D[x][key] for x in D if key in D[x]}, ) 126 | df['treat'] = np.arange(df.shape[0]) 127 | df['data'] = dataset 128 | df['model'] = model 129 | df['stat'] = key 130 | 131 | res.append(df) 132 | 133 | return res 134 | 135 | 136 | def get_pb_results(experiment, datasets, weights, K=5): 137 | """ 138 | 139 | :param experiment: 140 | :param datasets: 141 | :param weights: 142 | :param K: 143 | :return: 144 | """ 145 | res = [] 146 | 147 | for dataset in datasets: 148 | for w in weights: 149 | D = get_results_simple( 150 | dataset, 151 | experiment, 152 | 'pb-pb-pb_lc_f_t', 153 | key='test', 154 | nstop=K, 155 | weight=round(w, 1) 156 | ) 157 | for key in ['mean', 'std']: 158 | df = pd.DataFrame({x: D[x][key] for x in D if key in D[x]}, ) 159 | df['treat'] = np.arange(df.shape[0]) 160 | df['data'] = dataset 161 | df['model'] = 'pb-pb-pb_lc_f_t' + str(w) 162 | df['stat'] = key 163 | 164 | res.append(df) 165 | return res 166 | 167 | 168 | def get_pb_weighted_results(experiment, dataset, K=5): 169 | """ 170 | 171 | :param experiment: 172 | :param dataset: 173 | :param K: 174 | :return: 175 | """ 176 | res = [] 177 | 178 | for w in np.linspace(0, 1, 11): 179 | w = round(w, 1) 180 | D = get_results_simple( 181 | dataset, 182 | experiment, 183 | 'pb-pb-pb_lc_f_t', 184 | key='test', 185 | nstop=K, 186 | weight=w 187 | ) 188 | for key in ['mean', 'std']: 189 | df = pd.DataFrame({x: D[x][key] for x in D if key in D[x]}, ) 190 | df['treat'] = np.arange(df.shape[0]) 191 | df['data'] = dataset 192 | df['model'] = 'pb-pb-pb_lc_f_t' + str(w) 193 | df['stat'] = key 194 | 195 | res.append(df) 196 | 197 | res = pd.concat(res) 198 | res['w'] = res['model'].map(lambda x: x[-3:]).astype(float) 199 | return res 200 | 201 | 202 | def replace_index(df, mapping): 203 | df = df.loc[list(mapping.keys())] 204 | df = df.reset_index() 205 | df['model'] = df['model'].map(mapping) 206 | df = df.set_index('model') 207 | 208 | return df 209 | 210 | 211 | def get_datasets_summary(res, stat, mapping=None, round_mean=3, round_std=4): 212 | """ 213 | 214 | :param res: 215 | :param stat: 216 | :param round_mean: 217 | :param round_std: 218 | :return: 219 | """ 220 | df_mean = res.query('stat == "mean"')[[stat, 'treat', 'data', 'model']] 221 | df_mean = pd.pivot_table( 222 | df_mean, values=stat, index='model', columns=['data', 'treat'] 223 | 224 | ) 225 | 226 | df_std = res.query('stat == "std"')[[stat, 'treat', 'data', 'model']] 227 | df_std = pd.pivot_table( 228 | df_std, values=stat, index='model', columns=['data', 'treat'] 229 | 230 | ) 231 | tot = df_mean.round(round_mean).astype(str) + '\u00b1' + df_std.round(round_std).astype(str) 232 | if mapping is not None: 233 | tot = replace_index(tot, mapping) 234 | 235 | return tot 236 | 237 | 238 | def get_rank_stats(res, mapping): 239 | avg_rank = res.query('stat == "mean"').copy() 240 | avg_rank['AUUC_rank'] = avg_rank.groupby(['data', 'treat'])['AUUC'].rank(method='dense', ascending=False) 241 | avg_rank['AUUC_from_top'] = 1 - avg_rank['AUUC'] / avg_rank.groupby(['data', 'treat'])['AUUC'].transform('max') 242 | avg_rank['AUUC_from_top'] = avg_rank['AUUC_from_top'] * 100 243 | 244 | avg_rank['MSE_rank'] = avg_rank.groupby(['data', 'treat'])['MSE'].rank(method='dense', ascending=True) 245 | avg_rank['MSE_from_top'] = avg_rank['MSE'] / 
avg_rank.groupby(['data', 'treat'])['MSE'].transform('min') - 1 246 | avg_rank['MSE_from_top'] = avg_rank['MSE_from_top'] * 100 247 | 248 | avg_rank = avg_rank.groupby('model')[['AUUC_rank', 'MSE_rank', 'AUUC_from_top', 'MSE_from_top']] \ 249 | .mean().round(1) 250 | 251 | if mapping is not None: 252 | avg_rank = replace_index(avg_rank, mapping) 253 | 254 | return avg_rank 255 | 256 | 257 | def to_latex(df, direction, is_str=True): 258 | """ 259 | 260 | :param df: 261 | :param direction: 262 | :return: 263 | """ 264 | df = df.copy() 265 | 266 | best = [] 267 | 268 | for col in df.columns: 269 | ser = df[col] 270 | if is_str: 271 | ser = ser.str.split('\u00b1').map(lambda x: x[0]) 272 | ser = ser.astype(float) 273 | 274 | best.append(ser.max() if direction == 'max' else ser.min()) 275 | 276 | for n, col in enumerate(df.columns): 277 | ser = df[col] 278 | if is_str: 279 | ser = ser.str.split('\u00b1').map(lambda x: x[0]) 280 | ser = ser.astype(float) 281 | sl = ser == best[n] 282 | df[col] = df[col].astype(str) 283 | df[col].loc[sl] = "\\textbf{" + df[col].loc[sl].astype(str) + "}" 284 | 285 | df = df.reset_index() 286 | 287 | df['model'] = "\\textbf{" + df['model'] + "}" 288 | 289 | df = df.apply(lambda x: ' & '.join(x), axis=1).tolist() 290 | df = ' \\\\ \n'.join(df) 291 | 292 | return df 293 | -------------------------------------------------------------------------------- /source/dragonnet/dragonnet.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from functools import partial 3 | 4 | import numpy as np 5 | import torch 6 | from sklearn.exceptions import NotFittedError 7 | from sklearn.preprocessing import StandardScaler, OneHotEncoder 8 | from torch.utils.data import TensorDataset, DataLoader 9 | 10 | from .model import DragonNetBase, dragonnet_loss, tarreg_loss, EarlyStopper 11 | 12 | 13 | class DragonNet: 14 | """ 15 | Main class for the Dragonnet model 16 | 17 | Parameters 18 | ---------- 19 | input_dim: int 20 | input dimension for convariates 21 | shared_hidden: int, default=200 22 | layer size for hidden shared representation layers 23 | outcome_hidden: int, default=100 24 | layer size for conditional outcome layers 25 | alpha: float, default=1.0 26 | loss component weighting hyperparameter between 0 and 1 27 | beta: float, default=1.0 28 | targeted regularization hyperparameter between 0 and 1 29 | epochs: int, default=200 30 | Number training epochs 31 | steps_per_epoch: int, default=100 32 | Number of steps per epoch to scale batch size 33 | learning_rate: float, default=1e-3 34 | Learning rate 35 | data_loader_num_workers: int, default=4 36 | Number of workers for data loader 37 | loss_type: str, {'tarreg', 'default'}, default='tarreg' 38 | Loss function to use 39 | """ 40 | 41 | def __init__( 42 | self, 43 | hidden_scale=2., 44 | outcome_scale=.5, 45 | alpha=1.0, 46 | beta=1.0, 47 | epochs=200, 48 | steps_per_epoch=100, 49 | learning_rate=1e-5, 50 | data_loader_num_workers=4, 51 | loss_type="tarreg", 52 | device='cuda', 53 | es=10, 54 | cat_cols=None, 55 | cat_params=None 56 | ): 57 | 58 | self.temp_folder = str(uuid.uuid1()) 59 | 60 | self.hidden_scale = hidden_scale 61 | self.outcome_scale = outcome_scale 62 | self.learning_rate = learning_rate 63 | self.epochs = epochs 64 | self.steps_per_epoch = steps_per_epoch 65 | self.batch_size = None 66 | self.num_workers = data_loader_num_workers 67 | self.train_dataloader = None 68 | self.valid_dataloader = None 69 | self.device = device 70 | self.scaler = StandardScaler() 71 
| self.es = es 72 | self.cat_cols = [] if cat_cols is None else cat_cols 73 | if cat_params is None: 74 | cat_params = { 75 | 'min_frequency': 10, 76 | 'max_categories': 100, 77 | 'handle_unknown': 'infrequent_if_exist', 78 | 'sparse_output': False 79 | } 80 | self.enc = OneHotEncoder(**cat_params) 81 | 82 | if loss_type == "tarreg": 83 | self.loss_f = partial(tarreg_loss, alpha=alpha, beta=beta) 84 | elif loss_type == "default": 85 | self.loss_f = partial(dragonnet_loss, alpha=alpha) 86 | 87 | def create_model(self, x): 88 | 89 | x = self.preprocess(x) 90 | nrows, input_dim = x.shape 91 | 92 | self.batch_size = max(32, int(nrows / self.steps_per_epoch)) 93 | 94 | shared_hidden = max(int(self.hidden_scale * input_dim), 4) 95 | outcome_hidden = max(int(self.outcome_scale * shared_hidden), 2) 96 | self.model = DragonNetBase(input_dim, shared_hidden, outcome_hidden).to(self.device) 97 | self.optim = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate) 98 | 99 | return x 100 | 101 | def preprocess(self, x): 102 | 103 | if len(self.cat_cols) > 0: 104 | not_cat = np.setdiff1d(np.arange(x.shape[1]), self.cat_cols) 105 | x_cat = x[:, self.cat_cols] 106 | x = x[:, not_cat] 107 | 108 | try: 109 | x_cat = self.enc.transform(x_cat) 110 | except NotFittedError: 111 | x_cat = self.enc.fit_transform(x_cat) 112 | 113 | try: 114 | x = self.scaler.transform(x) 115 | except NotFittedError: 116 | x = self.scaler.fit_transform(x) 117 | 118 | if len(self.cat_cols) > 0: 119 | x = np.concatenate([x, x_cat], axis=1) 120 | 121 | return x 122 | 123 | def create_dataloaders(self, x, y, t, x_v=None, y_v=None, t_v=None): 124 | """ 125 | Utility function to create train and validation data loader: 126 | 127 | Parameters 128 | ---------- 129 | x: np.array 130 | covariates 131 | y: np.array 132 | target variable 133 | t: np.array 134 | treatment 135 | """ 136 | 137 | x = torch.Tensor(x) 138 | t = torch.Tensor(t).reshape(-1, 1) 139 | y = torch.Tensor(y).reshape(-1, 1) 140 | train_dataset = TensorDataset(x, t, y) 141 | self.train_dataloader = DataLoader( 142 | train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True 143 | ) 144 | 145 | if x_v is not None: 146 | x_v = self.preprocess(x_v) 147 | x_v = torch.Tensor(x_v) 148 | y_v = torch.Tensor(y_v).reshape(-1, 1) 149 | t_v = torch.Tensor(t_v).reshape(-1, 1) 150 | valid_dataset = TensorDataset(x_v, t_v, y_v) 151 | self.valid_dataloader = DataLoader( 152 | valid_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False 153 | ) 154 | 155 | def fit(self, x, y, t, x_v=None, y_v=None, t_v=None): 156 | """ 157 | Function used to train the dragonnet model 158 | 159 | Parameters 160 | ---------- 161 | x: np.array 162 | covariates 163 | y: np.array 164 | target variable 165 | t: np.array 166 | treatment 167 | valid_perc: float 168 | Percentage of data to allocate to validation set 169 | """ 170 | x = self.create_model(x) 171 | self.create_dataloaders(x, y, t, x_v, y_v, t_v) 172 | early_stopper = EarlyStopper(self.temp_folder, patience=self.es, min_delta=0) 173 | for epoch in range(self.epochs): 174 | 175 | self.model.train() 176 | 177 | for batch, (X, tr, y1) in enumerate(self.train_dataloader): 178 | X, tr, y1 = X.to(self.device), tr.to(self.device), y1.to(self.device) 179 | y0_pred, y1_pred, t_pred, eps = self.model(X) 180 | loss = self.loss_f(y1, tr, t_pred, y0_pred, y1_pred, eps) 181 | self.optim.zero_grad() 182 | loss.backward() 183 | self.optim.step() 184 | if self.valid_dataloader: 185 | valid_loss = 
self.validate_step() 186 | print( 187 | f"epoch: {epoch}--------- train_loss: {loss} ----- valid_loss: {valid_loss}" 188 | ) 189 | if early_stopper.early_stop(valid_loss, self.model): 190 | break 191 | else: 192 | print(f"epoch: {epoch}--------- train_loss: {loss}") 193 | 194 | self.model = early_stopper.load(self.model) 195 | early_stopper.clear() 196 | 197 | def validate_step(self): 198 | """ 199 | Calculates validation loss 200 | 201 | Returns 202 | ------- 203 | valid_loss: torch.Tensor 204 | validation loss 205 | """ 206 | 207 | self.model.eval() 208 | 209 | valid_loss = [] 210 | with torch.no_grad(): 211 | for batch, (X, tr, y1) in enumerate(self.valid_dataloader): 212 | X, tr, y1 = X.to(self.device), tr.to(self.device), y1.to(self.device) 213 | y0_pred, y1_pred, t_pred, eps = self.model(X) 214 | loss = self.loss_f(y1, tr, t_pred, y0_pred, y1_pred, eps) 215 | valid_loss.append(loss) 216 | return torch.Tensor(valid_loss).mean() 217 | 218 | def predict(self, x): 219 | """ 220 | Function used to predict on covariates. 221 | 222 | Parameters 223 | ---------- 224 | x: torch.Tensor or numpy.array 225 | covariates 226 | 227 | Returns 228 | ------- 229 | y0_pred: torch.Tensor 230 | outcome under control 231 | y1_pred: torch.Tensor 232 | outcome under treatment 233 | t_pred: torch.Tensor 234 | predicted treatment 235 | eps: torch.Tensor 236 | trainable epsilon parameter 237 | """ 238 | self.model.eval() 239 | 240 | res = np.zeros((x.shape[0],), dtype=np.float32) 241 | x = self.preprocess(x) 242 | x = torch.Tensor(x) 243 | 244 | ds = TensorDataset(x) 245 | dl = DataLoader( 246 | ds, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False 247 | ) 248 | 249 | with torch.no_grad(): 250 | for n, (batch,) in enumerate(dl): 251 | batch = batch.to(self.device) 252 | y0_pred, y1_pred, t_pred, eps = self.model(batch) 253 | res[n * self.batch_size: (n + 1) * self.batch_size] = (y1_pred - y0_pred).detach().cpu().numpy()[:, 0] 254 | 255 | return res 256 | -------------------------------------------------------------------------------- /source/objective.py: -------------------------------------------------------------------------------- 1 | import os 2 | from copy import deepcopy 3 | 4 | import joblib 5 | import numpy as np 6 | 7 | 8 | class Objective: 9 | 10 | def __init__(self, train, test, runner, factory, params, folder, keys=('params',)): 11 | 12 | # data and method 13 | self.train = train 14 | self.test = test 15 | self.runner = runner 16 | # how to create model 17 | self.factory = factory 18 | # keys for each meta learner base model 19 | self.keys = keys 20 | # params 21 | if params is None: 22 | params = {} 23 | self.params = params 24 | # where to save 25 | self.folder = folder 26 | os.makedirs(folder, exist_ok=True) 27 | 28 | def _set_params(self, trial, params): 29 | 30 | total_params = {} 31 | params = self.set_params(trial, deepcopy(params)) 32 | 33 | for key in self.keys: 34 | total_params[key] = params 35 | 36 | return total_params 37 | 38 | def __call__(self, trial): 39 | 40 | params = deepcopy(self.params) 41 | trial_name = f'trial_{trial.number}' 42 | folder = os.path.join(self.folder, trial_name) 43 | os.makedirs(folder, exist_ok=True) 44 | params = self._set_params(trial, params) 45 | 46 | scores, oof_pred, test_pred = self.runner(self.train, self.test, params, self.factory) 47 | joblib.dump(scores, os.path.join(folder, 'scores.pkl')) 48 | joblib.dump(params, os.path.join(folder, 'params.pkl')) 49 | joblib.dump(oof_pred, os.path.join(folder, 'oof_pred.pkl')) 50 | 
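        # test-set predictions are saved as well so visualization.get_results_simple can later
        # reload the best trial's outputs when building the result tables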
joblib.dump(test_pred, os.path.join(folder, 'test_pred.pkl')) 51 | 52 | print(len(scores['valid'])) 53 | print(np.mean(scores['test_ext'], axis=0)) 54 | 55 | return list(np.array(scores['valid_ext']).mean(axis=0)) # scores['valid'] # np.mean(scores['valid']) 56 | 57 | 58 | class ObjectiveMulti(Objective): 59 | 60 | def _set_params(self, trial, params): 61 | total_params = {} 62 | for key in self.keys: 63 | total_params[key] = self.set_params(trial, deepcopy(params), key) 64 | 65 | return total_params 66 | 67 | 68 | class ObjectiveSingleXGB(Objective): 69 | 70 | def set_params(self, trial, params): 71 | params['min_child_weight'] = trial.suggest_float("min_child_weight", 1e-5, 10, log=True) 72 | params['subsample'] = trial.suggest_float("subsample", 0.7, 1.0) 73 | params['colsample_bytree'] = trial.suggest_float("colsample_bytree", 0.7, 1.0) 74 | params['max_depth'] = trial.suggest_int("max_depth", 2, 6) 75 | params['lambda'] = trial.suggest_float("lambda", .1, 50, log=True) 76 | # params['learning_rate'] = trial.suggest_float("learning_rate", .01, 0.3, log=True) 77 | params['learning_rate'] = trial.suggest_float("learning_rate", .005, 0.1, log=False) 78 | params['gamma'] = trial.suggest_float("gamma", 1e-5, 100., log=True) 79 | 80 | return params 81 | 82 | 83 | class ObjectiveMultiXGB(ObjectiveMulti): 84 | 85 | def set_params(self, trial, params, prefix): 86 | params['min_child_weight'] = trial.suggest_float(f"{prefix}_min_child_weight", 1e-5, 10, log=True) 87 | params['subsample'] = trial.suggest_float(f"{prefix}_subsample", 0.7, 1.0) 88 | params['colsample_bytree'] = trial.suggest_float(f"{prefix}_colsample_bytree", 0.7, 1.0) 89 | params['max_depth'] = trial.suggest_int(f"{prefix}_max_depth", 2, 6) 90 | params['lambda'] = trial.suggest_float(f"{prefix}_lambda", .1, 50, log=True) 91 | # params['learning_rate'] = trial.suggest_float(f"{prefix}_learning_rate", .01, 0.3, log=True) 92 | params['learning_rate'] = trial.suggest_float(f"{prefix}_learning_rate", .005, 0.1, log=False) 93 | params['gamma'] = trial.suggest_float(f"{prefix}_gamma", 1e-5, 100., log=True) 94 | 95 | return params 96 | 97 | 98 | class ObjectivePB(Objective): 99 | 100 | def set_params(self, trial, params): 101 | params['min_data_in_leaf'] = trial.suggest_int("min_data_in_leaf", 1, 100, log=True) 102 | params['subsample'] = trial.suggest_float("subsample", 0.7, 1.0) 103 | params['colsample'] = trial.suggest_float("colsample", 0.7, 1.0) 104 | params['max_depth'] = trial.suggest_int("max_depth", 2, 6) 105 | params['lambda_l2'] = trial.suggest_float("lambda_l2", .1, 50, log=True) 106 | # params['lr'] = trial.suggest_float("lr", .01, 0.3, log=True) 107 | params['lr'] = trial.suggest_float("lr", .005, 0.1, log=False) 108 | params['min_gain_to_split'] = trial.suggest_float("min_gain_to_split", 1e-5, 100., log=True) 109 | 110 | # if 'weight' in params: 111 | # params['weight'] = trial.suggest_float("weight", 0, 1, log=False) 112 | 113 | return params 114 | 115 | 116 | class ObjectiveMultiPB(ObjectiveMulti): 117 | 118 | def set_params(self, trial, params, prefix): 119 | params['min_data_in_leaf'] = trial.suggest_int(f"{prefix}_min_data_in_leaf", 1, 100, log=True) 120 | params['subsample'] = trial.suggest_float(f"{prefix}_subsample", 0.7, 1.0) 121 | params['colsample'] = trial.suggest_float(f"{prefix}_colsample", 0.7, 1.0) 122 | params['max_depth'] = trial.suggest_int(f"{prefix}_max_depth", 2, 6) 123 | params['lambda_l2'] = trial.suggest_float(f"{prefix}_lambda_l2", .1, 50, log=True) 124 | params['lr'] = 
trial.suggest_float(f"{prefix}_lr", .005, 0.1, log=False) 125 | params['min_gain_to_split'] = trial.suggest_float(f"{prefix}_min_gain_to_split", 1e-5, 100., log=True) 126 | 127 | return params 128 | 129 | 130 | class ObjectiveLGB(Objective): 131 | 132 | def set_params(self, trial, params): 133 | params['min_child_samples'] = trial.suggest_int("min_data_in_leaf", 1, 100, log=True) 134 | params['subsample'] = trial.suggest_float("subsample", 0.7, 1.0) 135 | params['colsample_bytree'] = trial.suggest_float("colsample_bytree", 0.7, 1.0) 136 | params['max_depth'] = trial.suggest_int("max_depth", 2, 6) 137 | params['reg_lambda'] = trial.suggest_float("reg_lambda", .1, 50, log=True) 138 | params['learning_rate'] = trial.suggest_float("learning_rate", .005, 0.1, log=False) 139 | params['min_split_gain'] = trial.suggest_float("min_split_gain", 1e-5, 100., log=True) 140 | 141 | return params 142 | 143 | 144 | class ObjectiveMultiLGB(ObjectiveMulti): 145 | 146 | def set_params(self, trial, params, prefix): 147 | params['min_child_samples'] = trial.suggest_int(f"{prefix}_min_child_samples", 1, 100, log=True) 148 | params['subsample'] = trial.suggest_float(f"{prefix}_subsample", 0.7, 1.0) 149 | params['colsample_bytree'] = trial.suggest_float(f"{prefix}_colsample_bytree", 0.7, 1.0) 150 | params['max_depth'] = trial.suggest_int(f"{prefix}_max_depth", 2, 6) 151 | params['reg_lambda'] = trial.suggest_float(f"{prefix}_reg_lambda", .1, 50, log=True) 152 | params['learning_rate'] = trial.suggest_float(f"{prefix}_learning_rate", .005, 0.1, log=False) 153 | params['min_split_gain'] = trial.suggest_float(f"{prefix}_min_split_gain", 1e-5, 100., log=True) 154 | 155 | return params 156 | 157 | 158 | class ObjectiveCRF(Objective): 159 | 160 | def set_params(self, trial, params): 161 | params['criterion'] = trial.suggest_categorical("criterion", ["mse", "het"]) 162 | params['honest'] = trial.suggest_categorical("honest", [False, True]) 163 | 164 | params['max_depth'] = trial.suggest_int("max_depth", 2, 12) 165 | params['min_samples_leaf'] = trial.suggest_int("min_samples_leaf", 2, 100, log=True) 166 | 167 | params['max_features'] = trial.suggest_float("max_features", 0.2, 0.8) 168 | params['max_samples'] = trial.suggest_float("max_samples", 0.1, 0.5) 169 | 170 | params['min_balancedness_tol'] = trial.suggest_float("min_balancedness_tol", 0.05, 0.45) 171 | 172 | return params 173 | 174 | 175 | class ObjectiveDR(Objective): 176 | def set_params(self, trial, params): 177 | params['hidden_scale'] = trial.suggest_float("hidden_scale", .5, 2.) 178 | params['outcome_scale'] = trial.suggest_float("outcome_scale", .5, 2.) 179 | 180 | params['alpha'] = trial.suggest_float("alpha", .5, 1.5) 181 | params['beta'] = trial.suggest_float("beta", .5, 1.5) 182 | params['steps_per_epoch'] = trial.suggest_int("steps_per_epoch", 100, 300) 183 | 184 | params['learning_rate'] = trial.suggest_float("learning_rate", 1e-4, 1e-3) 185 | 186 | return params 187 | 188 | 189 | class ObjectiveDCN(Objective): 190 | def set_params(self, trial, params): 191 | params['share_scale'] = trial.suggest_float("share_scale", .5, 2.) 192 | params['base_scale'] = trial.suggest_float("base_scale", .5, 2.) 
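# the suggestions below tune the DESCN loss-term weights (propensity head, entire-space response heads and cross-network heads); the [0.5, 1] search range covers the defaults defined in descn.Config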
193 | 194 | params['prpsy_w'] = trial.suggest_float("prpsy_w", 0.5, 1) 195 | params['escvr1_w'] = trial.suggest_float("escvr1_w", 0.5, 1) 196 | params['escvr0_w'] = trial.suggest_float("escvr0_w", 0.5, 1) 197 | params['mu0hat_w'] = trial.suggest_float("mu0hat_w", 0.5, 1) 198 | params['mu1hat_w'] = trial.suggest_float("mu1hat_w", 0.5, 1) 199 | 200 | params['steps_per_epoch'] = trial.suggest_int("steps_per_epoch", 100, 300) 201 | 202 | params['lr'] = trial.suggest_float("lr", 1e-4, 1e-3) 203 | 204 | return params 205 | -------------------------------------------------------------------------------- /datasets/get_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import joblib 4 | import numpy as np 5 | import pandas as pd 6 | from causalml.dataset import make_uplift_classification 7 | from sklearn.metrics import roc_auc_score 8 | from sklearn.model_selection import StratifiedKFold, KFold 9 | from sklift import datasets 10 | from xgboost import XGBClassifier 11 | 12 | DATASET_PATH = os.path.dirname(os.path.abspath(__file__)) 13 | 14 | 15 | def make_testset(X, y, t, fold, random_state=42, stratify=True, effect=False): 16 | train, test = {}, {} 17 | nt = t.max() + 1 18 | 19 | trg = np.zeros((t.shape[0], nt), dtype=np.float32) 20 | trg[:] = np.nan 21 | rows, cols = np.nonzero(t[:, np.newaxis] == np.arange(nt)[np.newaxis, :]) 22 | trg[rows, cols] = y 23 | 24 | if stratify: 25 | stratify = t + nt * y 26 | folds = StratifiedKFold(5, random_state=random_state, shuffle=True).split(stratify, stratify) 27 | else: 28 | folds = KFold(5, random_state=random_state, shuffle=True).split(y, y) 29 | 30 | for n, (f0, f1) in enumerate(folds): 31 | 32 | if n == fold: 33 | break 34 | 35 | train['X'], test['X'] = X[f0], X[f1] 36 | train['y'], test['y'] = y[f0], y[f1] 37 | train['t'], test['t'] = t[f0], t[f1] 38 | train['trg'], test['trg'] = trg[f0], trg[f1] 39 | 40 | if effect: 41 | for data in [train, test]: 42 | data['X'], data['effect'] = data['X'][:, :-1], data['X'][:, -1] 43 | 44 | return train, test 45 | 46 | 47 | def save_data(train, test, alias, i): 48 | folder = os.path.join(DATASET_PATH, f'{alias}_{i}') 49 | os.makedirs(folder, exist_ok=True) 50 | joblib.dump(train, os.path.join(folder, 'train.pkl')) 51 | joblib.dump(test, os.path.join(folder, 'test.pkl')) 52 | 53 | return 54 | 55 | 56 | def get_synth_6_trt(): 57 | names = ["control", "treatment1", "treatment2", "treatment3", "treatment4", "treatment5", "treatment6"] 58 | 59 | df, x_names = make_uplift_classification( 60 | n_samples=10000, 61 | treatment_name=names, 62 | y_name="conversion", 63 | n_classification_features=100, 64 | n_classification_informative=20, 65 | n_classification_redundant=10, 66 | n_classification_repeated=10, 67 | n_uplift_increase_dict={ 68 | "treatment1": 3, "treatment2": 5, "treatment3": 7, "treatment4": 9, "treatment5": 11, "treatment6": 13, 69 | }, 70 | n_uplift_decrease_dict={ 71 | "treatment1": 1, "treatment2": 2, "treatment3": 3, "treatment4": 3, "treatment5": 4, "treatment6": 4, 72 | }, 73 | delta_uplift_increase_dict={ 74 | "treatment1": 0.05, 75 | "treatment2": 0.1, 76 | "treatment3": 0.12, 77 | "treatment4": 0.15, 78 | "treatment5": 0.17, 79 | "treatment6": 0.2, 80 | }, 81 | delta_uplift_decrease_dict={ 82 | "treatment1": 0.01, 83 | "treatment2": 0.02, 84 | "treatment3": 0.03, 85 | "treatment4": 0.05, 86 | "treatment5": 0.06, 87 | "treatment6": 0.07, 88 | }, 89 | n_uplift_increase_mix_informative_dict={ 90 | "treatment1": 1, 91 | "treatment2": 2, 92 | 
"treatment3": 3, 93 | "treatment4": 4, 94 | "treatment5": 5, 95 | "treatment6": 6, 96 | }, 97 | n_uplift_decrease_mix_informative_dict={ 98 | "treatment1": 1, 99 | "treatment2": 1, 100 | "treatment3": 1, 101 | "treatment4": 1, 102 | "treatment5": 1, 103 | "treatment6": 1, 104 | }, 105 | positive_class_proportion=0.2, 106 | random_seed=42, 107 | ) 108 | 109 | tkeys = {f"treatment{x}": x for x in range(1, 7)} 110 | tkeys['control'] = 0 111 | 112 | alias = 'synth1' 113 | 114 | X = df.drop(['treatment_group_key', 'conversion', ], axis=1).values.astype(np.float32) 115 | y = df['conversion'].values.astype(np.float32) 116 | t = df['treatment_group_key'].map(tkeys).values.astype(np.int32) 117 | 118 | for i in range(5): 119 | train, test = make_testset(X, y, t, fold=i, random_state=42, stratify=True, effect=True) 120 | save_data(train, test, alias, i) 121 | 122 | return 123 | 124 | 125 | def get_hillstom(): 126 | alias = 'hillstrom' 127 | data = datasets.fetch_hillstrom() 128 | data, target, treatment = data['data'], data['target'], data['treatment'] 129 | 130 | data['history_segment'] = data['history_segment'].str.slice(0, 1).astype(int) 131 | for col in ['zip_code', 'channel']: 132 | data[col] = pd.factorize(data[col])[0] 133 | 134 | X = data.values.astype(np.float32) 135 | y = target.values.astype(np.float32) 136 | t = treatment.map({'No E-Mail': 0, 'Mens E-Mail': 1, 'Womens E-Mail': 2}).values.astype(np.int32) 137 | 138 | for i in range(5): 139 | train, test = make_testset(X, y, t, fold=i, random_state=42, stratify=False, effect=False) 140 | save_data(train, test, alias, i) 141 | 142 | return 143 | 144 | 145 | def get_criteo(): 146 | alias = 'criteo' 147 | data = datasets.fetch_criteo() 148 | data, target, treatment = data['data'], data['target'], data['treatment'] 149 | t = treatment.values.astype(np.int32) 150 | y = target.values.astype(np.float32) 151 | 152 | np.random.seed(42) 153 | 154 | idx = np.arange(t.shape[0]) 155 | idx0, idx1 = idx[y == 0], idx[y == 1] 156 | np.random.shuffle(idx0) 157 | 158 | idx0 = idx0[:1000000] 159 | 160 | idx = np.concatenate([idx0, idx1]) 161 | np.random.shuffle(idx) 162 | 163 | X = data.values.astype(np.float32)[idx] 164 | y = y[idx] 165 | t = t[idx] 166 | 167 | for i in range(5): 168 | train, test = make_testset(X, y, t, fold=i, random_state=42, stratify=False, effect=False) 169 | save_data(train, test, alias, i) 170 | 171 | return 172 | 173 | 174 | def get_lenta(): 175 | alias = 'lenta' 176 | data = datasets.fetch_lenta() 177 | 178 | data, target, treatment = data['data'], data['target'], data['treatment'] 179 | t = treatment.map({'control': 0, 'test': 1}).values.astype(np.int32) 180 | y = target.values.astype(np.float32) 181 | 182 | data['gender'] = (data['gender'] == data['gender'].iloc[0]) 183 | # for simplicity - fill NaNs with median 184 | data = data.fillna(data.median()) 185 | X = data.values.astype(np.float32) 186 | 187 | for i in range(5): 188 | train, test = make_testset(X, y, t, fold=i, random_state=42, stratify=False, effect=False) 189 | save_data(train, test, alias, i) 190 | 191 | return 192 | 193 | 194 | def get_megafon(): 195 | alias = 'megafon' 196 | data = datasets.fetch_megafon() 197 | 198 | data, target, treatment = data['data'], data['target'], data['treatment'] 199 | t = treatment.map({'control': 0, 'treatment': 1}).values.astype(np.int32) 200 | y = target.values.astype(np.float32) 201 | X = data.values.astype(np.float32) 202 | 203 | for i in range(5): 204 | train, test = make_testset(X, y, t, fold=i, random_state=42, stratify=False, 
effect=False) 205 | save_data(train, test, alias, i) 206 | 207 | return 208 | 209 | 210 | # params for propensity estimator 211 | params_xgb = { 212 | 'n_estimators': 1000, 213 | 'learning_rate': 0.01, 214 | 'max_depth': 3, 215 | 'tree_method': 'gpu_hist', 216 | 'gpu_id': 0, 217 | 'min_child_weight': 0, 218 | 'lambda': 1, 219 | 'max_bin': 256, 220 | 'gamma': 0, 221 | 'alpha': 0, 222 | } 223 | 224 | 225 | def get_trt_slice(x, y, t): 226 | sl = np.isin(y, [0, t]) 227 | y = (y[sl] == t).astype(np.float32) 228 | x = x[sl] 229 | 230 | return x, y 231 | 232 | 233 | def set_propensity(train, test, params, cutoff=0.55): 234 | n_trt = train['t'].max() 235 | 236 | X, y = train['X'], train['t'] 237 | X_test = test['X'] 238 | 239 | folds = KFold(5, shuffle=True, random_state=42) 240 | 241 | oof_pred = np.zeros((X.shape[0], n_trt), dtype=np.float32) 242 | test_pred = np.zeros((X_test.shape[0], n_trt), dtype=np.float32) 243 | 244 | scores = [] 245 | priors = [] 246 | 247 | for n, (f0, f1) in enumerate(folds.split(y, y)): 248 | 249 | x_tr, x_val = X[f0], X[f1] 250 | y_tr, y_val = y[f0], y[f1] 251 | score = [] 252 | prior = [] 253 | for i in range(n_trt): 254 | # get target slice for control + curr treatment 255 | ds = [ 256 | get_trt_slice(x, y, i + 1) for (x, y) in 257 | [[x_tr, y_tr], [x_val, y_val]] 258 | ] 259 | 260 | model = XGBClassifier(**params) 261 | model.fit( 262 | *ds[0], eval_set=[ds[1]], 263 | early_stopping_rounds=100, eval_metric='auc', verbose=1000 264 | ) 265 | 266 | score.append( 267 | roc_auc_score(ds[1][1], model.predict_proba(ds[1][0])[:, 1]) 268 | ) 269 | prior.append(ds[0][1].mean()) 270 | 271 | oof_pred[f1, i] = model.predict_proba(x_val)[:, 1] 272 | test_pred[:, i] += model.predict_proba(X_test)[:, 1] 273 | 274 | scores.append(score) 275 | priors.append(prior) 276 | 277 | scores = np.array(scores).mean(axis=0) 278 | priors = np.array(priors).mean(axis=0) 279 | test_pred /= 5 280 | 281 | for i in range(n_trt): 282 | if scores[i] < cutoff: 283 | oof_pred[:, i] = priors[i] 284 | test_pred[:, i] = priors[i] 285 | 286 | train['p'] = oof_pred 287 | test['p'] = test_pred 288 | 289 | return train, test 290 | 291 | 292 | if __name__ == '__main__': 293 | print('Fetching the data...') 294 | # fetch datasets 295 | get_synth_6_trt() 296 | get_hillstom() 297 | get_criteo() 298 | get_megafon() 299 | get_lenta() 300 | 301 | print('Estimating propensities...') 302 | # create propensity scores - required for some models 303 | for alias in ['synth1', 'hillstrom', 'criteo', 'megafon', 'lenta']: 304 | for i in range(5): 305 | folder = os.path.join(DATASET_PATH, f'{alias}_{i}') 306 | train = joblib.load(os.path.join(folder, 'train.pkl')) 307 | test = joblib.load(os.path.join(folder, 'test.pkl')) 308 | 309 | train, test = set_propensity(train, test, params_xgb, cutoff=0.55) 310 | save_data(train, test, alias, i) 311 | 312 | print('Done') 313 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 Vakhrusev Anton, Ibragimov Bulat 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /source/runner.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.model_selection import StratifiedKFold 3 | from sklift.metrics.metrics import uplift_auc_score 4 | 5 | 6 | def qini(y_true, uplift, treatment): 7 | scores = [] 8 | 9 | for t_id in range(treatment.max()): 10 | sl = (treatment == 0) | (treatment == (t_id + 1)) 11 | y_ = y_true[sl] 12 | upl = uplift[:, t_id][sl] 13 | trt = np.clip(treatment[sl], 0, 1) 14 | 15 | scores.append( 16 | uplift_auc_score(y_, upl, trt, ) 17 | ) 18 | 19 | return scores 20 | 21 | 22 | def mse(y_true, uplift, treatment): 23 | scores = [] 24 | 25 | for t_id in range(treatment.max()): 26 | sl = treatment == (t_id + 1) 27 | y_ = y_true[sl] 28 | upl = uplift[:, t_id][sl] 29 | 30 | scores.append( 31 | ((y_ - upl) ** 2).mean() 32 | ) 33 | 34 | return scores 35 | 36 | 37 | def prepare_for_cml(data, n_trt=None): 38 | data = data.copy() 39 | if n_trt is None: 40 | n_trt = data['t'].max() + 1 41 | 42 | if 't' in data: 43 | t = data.pop('t') 44 | trt = np.array(['control'] + [f'treatment{x}' for x in range(1, n_trt)]) 45 | data['treatment'] = trt[t] 46 | 47 | if 'p' in data: 48 | p = data.pop('p') 49 | data['p'] = {f'treatment{x}': p[:, x - 1] for x in range(1, n_trt)} 50 | 51 | return data 52 | 53 | 54 | def train_meta_uber(train, test, params, factory): 55 | train, test = train.copy(), test.copy() 56 | nt = train['t'].max() + 1 57 | 58 | oof_pred = np.zeros((train['X'].shape[0], nt - 1), dtype=np.float32) 59 | test_pred = np.zeros((test['X'].shape[0], nt - 1), dtype=np.float32) 60 | 61 | folds = StratifiedKFold(5, shuffle=True, random_state=42) 62 | 63 | scores = { 64 | 'valid_ext': [], 65 | 'test_ext': [], 66 | } 67 | 68 | if 'effect' in train: 69 | scores['valid_mse_ext'] = [] 70 | scores['test_mse_ext'] = [] 71 | 72 | strf = train['t'] + nt * train['y'] 73 | for n, (f0, f1) in enumerate(folds.split(strf, strf)): 74 | X_tr, X_val = train['X'][f0], train['X'][f1] 75 | t_tr, t_val = train['t'][f0], train['t'][f1] 76 | y_tr, y_val = train['y'][f0], train['y'][f1] 77 | p_tr, p_val = train['p'][f0], train['p'][f1] 78 | model = factory(params) 79 | 80 | _ = model.fit_predict(**prepare_for_cml( 81 | {'X': X_tr, 'y': y_tr, 't': t_tr, 'p': p_tr}, n_trt=nt 82 | ), return_ci=False) 83 | 84 | # baseline 85 | oof_pred[f1] = model.predict(**prepare_for_cml( 86 | {'X': X_val, 'p': p_val}, n_trt=nt 87 | )) 88 | tt = model.predict(**prepare_for_cml( 89 | {'X': test['X'], 'p': test['p']}, n_trt=nt 90 | )) 91 | 92 | scores['test_ext'].append(qini(test['y'], tt, test['t'])) 93 | scores['valid_ext'].append(qini(train['y'][f1], oof_pred[f1], train['t'][f1])) 94 | 95 | # if effect ... 
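# 'effect' is present only for the synthetic data, where the true individual uplift is split off as the last feature column in make_testset; when available, also track the MSE between the predicted uplift and that ground-truth effect (per treatment, on the rows that received it)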
96 | if 'effect' in train: 97 | scores['test_mse_ext'].append(mse(test['effect'], tt, test['t'])) 98 | scores['valid_mse_ext'].append(mse(train['effect'][f1], oof_pred[f1], train['t'][f1])) 99 | 100 | test_pred += tt 101 | 102 | test_pred /= 5 103 | 104 | scores = { 105 | **scores, 106 | 'valid': qini(train['y'], oof_pred, train['t']), 107 | 'test': qini(test['y'], test_pred, test['t']), 108 | 109 | } 110 | 111 | if 'effect' in train: 112 | scores = { 113 | **scores, 114 | 'valid_mse': mse(train['effect'], oof_pred, train['t']), 115 | 'test_mse': mse(test['effect'], test_pred, test['t']), 116 | 117 | } 118 | 119 | return scores, oof_pred, test_pred 120 | 121 | 122 | def train_crf(train, test, params, factory): 123 | train, test = train.copy(), test.copy() 124 | nt = train['t'].max() + 1 125 | 126 | oof_pred = np.zeros((train['X'].shape[0], nt - 1), dtype=np.float32) 127 | test_pred = np.zeros((test['X'].shape[0], nt - 1), dtype=np.float32) 128 | 129 | folds = StratifiedKFold(5, shuffle=True, random_state=42) 130 | 131 | scores = { 132 | 'valid_ext': [], 133 | 'test_ext': [], 134 | } 135 | 136 | if 'effect' in train: 137 | scores['valid_mse_ext'] = [] 138 | scores['test_mse_ext'] = [] 139 | 140 | strf = train['t'] + nt * train['y'] 141 | treat = 1 - np.isnan(train['trg'][:, 1:]) 142 | 143 | for n, (f0, f1) in enumerate(folds.split(strf, strf)): 144 | X_tr, X_val = train['X'][f0], train['X'][f1] 145 | t_tr, t_val = treat[f0], treat[f1] 146 | y_tr, y_val = train['y'][f0], train['y'][f1] 147 | model = factory(params) 148 | 149 | model.fit(X_tr, t_tr, y_tr) 150 | 151 | # baseline 152 | oof_pred[f1] = model.predict(X_val) 153 | tt = model.predict(test['X']) 154 | 155 | scores['test_ext'].append(qini(test['y'], tt, test['t'])) 156 | scores['valid_ext'].append(qini(train['y'][f1], oof_pred[f1], train['t'][f1])) 157 | 158 | # if effect ... 
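# same optional ground-truth-effect MSE bookkeeping as in train_meta_uber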
159 | if 'effect' in train: 160 | scores['test_mse_ext'].append(mse(test['effect'], tt, test['t'])) 161 | scores['valid_mse_ext'].append(mse(train['effect'][f1], oof_pred[f1], train['t'][f1])) 162 | 163 | test_pred += tt 164 | 165 | test_pred /= 5 166 | 167 | scores = { 168 | **scores, 169 | 'valid': qini(train['y'], oof_pred, train['t']), 170 | 'test': qini(test['y'], test_pred, test['t']), 171 | 172 | } 173 | 174 | if 'effect' in train: 175 | scores = { 176 | **scores, 177 | 'valid_mse': mse(train['effect'], oof_pred, train['t']), 178 | 'test_mse': mse(test['effect'], test_pred, test['t']), 179 | 180 | } 181 | 182 | return scores, oof_pred, test_pred 183 | 184 | 185 | def train_drnet(train, test, params, factory): 186 | train, test = train.copy(), test.copy() 187 | nt = train['t'].max() + 1 188 | 189 | oof_pred = np.zeros((train['X'].shape[0], nt - 1), dtype=np.float32) 190 | test_pred = np.zeros((test['X'].shape[0], nt - 1), dtype=np.float32) 191 | 192 | folds = StratifiedKFold(5, shuffle=True, random_state=42) 193 | 194 | scores = { 195 | 'valid_ext': [], 196 | 'test_ext': [], 197 | } 198 | 199 | if 'effect' in train: 200 | scores['valid_mse_ext'] = [] 201 | scores['test_mse_ext'] = [] 202 | 203 | strf = train['t'] + nt * train['y'] 204 | treat = 1 - np.isnan(train['trg'][:, 1:]) 205 | 206 | for n, (f0, f1) in enumerate(folds.split(strf, strf)): 207 | 208 | X_tr, X_val = train['X'][f0], train['X'][f1] 209 | t_tr, t_val = treat[f0], treat[f1] 210 | y_tr, y_val = train['y'][f0], train['y'][f1] 211 | 212 | tt = np.zeros((test['X'].shape[0], t_tr.shape[1]), dtype=np.float32) 213 | 214 | for t in range(treat.shape[1]): 215 | model = factory(params) 216 | sl_tr = (t_tr.sum(axis=1) == 0) | (t_tr[:, t] == 1) 217 | sl_val = (t_val.sum(axis=1) == 0) | (t_val[:, t] == 1) 218 | 219 | model.fit(X_tr[sl_tr], y_tr[sl_tr], t_tr[sl_tr][:, t], X_val[sl_val], y_val[sl_val], t_val[sl_val][:, t]) 220 | oof_pred[f1, t] = model.predict(X_val) 221 | tt[:, t] = model.predict(test['X']) 222 | 223 | scores['test_ext'].append(qini(test['y'], tt, test['t'])) 224 | scores['valid_ext'].append(qini(train['y'][f1], oof_pred[f1], train['t'][f1])) 225 | 226 | # if effect ... 
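# as above: optionally score the predicted uplift against the known synthetic effect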
227 | if 'effect' in train: 228 | scores['test_mse_ext'].append(mse(test['effect'], tt, test['t'])) 229 | scores['valid_mse_ext'].append(mse(train['effect'][f1], oof_pred[f1], train['t'][f1])) 230 | 231 | test_pred += tt 232 | 233 | test_pred /= 5 234 | 235 | scores = { 236 | **scores, 237 | 'valid': qini(train['y'], oof_pred, train['t']), 238 | 'test': qini(test['y'], test_pred, test['t']), 239 | 240 | } 241 | 242 | if 'effect' in train: 243 | scores = { 244 | **scores, 245 | 'valid_mse': mse(train['effect'], oof_pred, train['t']), 246 | 'test_mse': mse(test['effect'], test_pred, test['t']), 247 | 248 | } 249 | 250 | return scores, oof_pred, test_pred 251 | 252 | 253 | def train_pb_upd_weight(train, test, params, factory): 254 | train, test = train.copy(), test.copy() 255 | nt = train['t'].max() + 1 256 | 257 | for data in [train, test]: 258 | t = data['t'] 259 | 260 | trg = np.zeros((t.shape[0], nt), dtype=np.float32) 261 | trg[:] = np.nan 262 | rows, cols = np.nonzero(t[:, np.newaxis] == np.arange(nt)[np.newaxis, :]) 263 | trg[rows, cols] = data['y'] 264 | 265 | data['new_y'] = trg 266 | 267 | oof_pb = [np.zeros((train['X'].shape[0], nt - 1), dtype=np.float32) for _ in range(11)] 268 | test_pb = [np.zeros((test['X'].shape[0], nt - 1), dtype=np.float32) for _ in range(11)] 269 | 270 | folds = StratifiedKFold(5, shuffle=True, random_state=42) 271 | 272 | scores = { 273 | 274 | 'valid_ext': [], 275 | 'test_ext': [], 276 | 'valid_ext_w': [], 277 | 'test_ext_w': [] 278 | 279 | } 280 | 281 | if 'effect' in train: 282 | scores['valid_mse_ext'] = [] 283 | scores['test_mse_ext'] = [] 284 | scores['valid_mse_ext_w'] = [] 285 | scores['test_mse_ext_w'] = [] 286 | 287 | models = [] 288 | 289 | strf = train['t'] + nt * train['y'] 290 | 291 | for n, (f0, f1) in enumerate(folds.split(strf, strf)): 292 | 293 | X_tr, X_val = train['X'][f0], train['X'][f1] 294 | y_tr, y_val = train['new_y'][f0], train['new_y'][f1] 295 | 296 | model = factory(params) 297 | 298 | model.fit(X_tr, y_tr, eval_sets=[{'X': X_val, 'y': y_val}]) 299 | models.append(model) 300 | 301 | idx = np.searchsorted(np.linspace(0, 1, 11), model.loss.weight) 302 | 303 | # baseline 304 | # get valid scores and test scores for each w 305 | vs, ts = [], [] 306 | vm, tm = [], [] 307 | 308 | for k, w in enumerate(np.linspace(0, 1, 11)): 309 | model.loss.weight = w 310 | 311 | oof_pb[k][f1] = model.predict(X_val, batch_size=1e10) 312 | tt = model.predict(test['X'], batch_size=1e10) 313 | test_pb[k] += tt 314 | 315 | vs.append( 316 | qini(train['y'][f1], oof_pb[k][f1], train['t'][f1]) 317 | ) 318 | 319 | ts.append( 320 | qini(test['y'], tt, test['t']) 321 | ) 322 | 323 | if 'effect' in train: 324 | vm.append( 325 | mse(train['effect'][f1], oof_pb[k][f1], train['t'][f1]) 326 | ) 327 | 328 | tm.append( 329 | mse(test['effect'], tt, test['t']) 330 | ) 331 | 332 | scores['valid_ext_w'].append(vs) 333 | scores['test_ext_w'].append(ts) 334 | 335 | if 'effect' in train: 336 | scores['valid_mse_ext_w'].append(vm) 337 | scores['test_mse_ext_w'].append(tm) 338 | 339 | scores['valid_ext_w'] = np.array(scores['valid_ext_w']) 340 | scores['test_ext_w'] = np.array(scores['test_ext_w']) 341 | 342 | # print(scores['valid_ext_w'].shape) 343 | scores['valid_w'] = [ 344 | qini(train['y'], oof_pb[x], train['t']) for x in range(11) 345 | 346 | ] 347 | 348 | scores['test_w'] = [ 349 | qini(test['y'], test_pb[x], test['t']) for x in range(11) 350 | 351 | ] 352 | 353 | scores['best_k'] = np.argmax(np.mean(scores['valid_w'], axis=1)) 354 | 355 | scores['valid_ext'] = 
scores['valid_ext_w'][:, idx] 356 | scores['test_ext'] = scores['test_ext_w'][:, idx] 357 | 358 | if 'effect' in train: 359 | scores['valid_mse_ext_w'] = np.array(scores['valid_mse_ext_w']) 360 | scores['test_mse_ext_w'] = np.array(scores['test_mse_ext_w']) 361 | 362 | scores['valid_mse_w'] = [ 363 | mse(train['effect'], oof_pb[x], train['t']) for x in range(11) 364 | 365 | ] 366 | 367 | scores['test_mse_w'] = [ 368 | mse(test['effect'], test_pb[x], test['t']) for x in range(11) 369 | 370 | ] 371 | 372 | scores['valid_mse_ext'] = scores['valid_mse_ext_w'][:, idx] 373 | scores['test_mse_ext'] = scores['test_mse_ext_w'][:, idx] 374 | 375 | for tt in test_pb: 376 | tt /= 5 377 | 378 | scores = { 379 | **scores, 380 | 'valid': scores['valid_w'][scores['best_k']], 381 | 'test': scores['test_w'][scores['best_k']], 382 | } 383 | 384 | if 'effect' in train: 385 | scores = { 386 | **scores, 387 | 'valid_mse': scores['valid_mse_w'][scores['best_k']], 388 | 'test_mse': scores['test_mse_w'][scores['best_k']], 389 | } 390 | 391 | return scores, oof_pb, test_pb 392 | -------------------------------------------------------------------------------- /source/descn/descn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import sys 4 | import uuid 5 | from time import time 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from sklearn.exceptions import NotFittedError 11 | from sklearn.preprocessing import StandardScaler, OneHotEncoder 12 | from torch.utils.data import TensorDataset, DataLoader 13 | 14 | from .models import ShareNetwork, PrpsyNetwork, Mu1Network, Mu0Network, TauNetwork, ESX 15 | from .util import wasserstein_torch, mmd2_torch 16 | 17 | 18 | class Config: 19 | lr = 0.001 20 | decay_rate = 0.95 21 | decay_step_size = 1 22 | l2 = 0.001 23 | model_name = "model_name" 24 | 25 | n_experiments = 1 26 | batch_size = 3000 27 | share_dim = 128 28 | base_dim = 64 29 | reweight_sample = 1 30 | val_rate = 0.01 31 | do_rate = 0.1 32 | normalization = "divide" 33 | epochs = 10 34 | log_step = 100 35 | pred_step = 1 36 | optim = 'Adam' 37 | 38 | BatchNorm1d = True 39 | # loss weights 40 | prpsy_w = 0.5 41 | escvr1_w = 0.5 42 | escvr0_w = 1 43 | 44 | h1_w = 0 45 | h0_w = 0 46 | # ***sub space's loss weights 47 | mu0hat_w = 0.5 48 | mu1hat_w = 1 49 | 50 | # CFR loss 51 | # wass,mmd 52 | imb_dist = 'wass' 53 | # if imb_dist_w <=0 mean no use imb_dist_loss 54 | imb_dist_w = 0.0 55 | 56 | def __init__(self, **kwargs): 57 | for key in kwargs: 58 | self.__dict__[key] = kwargs[key] 59 | 60 | 61 | class EarlyStopper: 62 | def __init__(self, temp_folder, patience=15, min_delta=0, ): 63 | self.patience = patience 64 | self.min_delta = min_delta 65 | self.counter = 0 66 | self.min_validation_loss = np.inf 67 | self.temp_folder = temp_folder 68 | os.makedirs(temp_folder, exist_ok=False) 69 | 70 | def save(self, model): 71 | torch.save(model.state_dict(), os.path.join(self.temp_folder, 'checkpoint.pth')) 72 | 73 | def early_stop(self, validation_loss, model): 74 | if validation_loss < self.min_validation_loss: 75 | self.min_validation_loss = validation_loss 76 | self.counter = 0 77 | self.save(model) 78 | # torch.save(model.state_dict(), os.path.join(self.temp_folder, 'checkpoint.pth')) 79 | 80 | 81 | elif validation_loss > (self.min_validation_loss + self.min_delta): 82 | self.counter += 1 83 | if self.counter >= self.patience: 84 | return True 85 | return False 86 | 87 | def load(self, model): 88 | 89 | 
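# reload the best checkpoint (lowest validation loss) saved by early_stop()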
model.load_state_dict(torch.load(os.path.join(self.temp_folder, 'checkpoint.pth'))) 90 | return model 91 | 92 | def clear(self): 93 | 94 | shutil.rmtree(self.temp_folder) 95 | 96 | 97 | class DESCNNet: 98 | 99 | def __init__( 100 | self, 101 | share_scale=2., 102 | base_scale=.5, 103 | steps_per_epoch=150, 104 | data_loader_num_workers=4, 105 | device='cuda', 106 | es=10, 107 | cat_cols=None, 108 | cat_params=None, 109 | **kwargs 110 | ): 111 | 112 | self.temp_folder = str(uuid.uuid4()) + f'_{time()}' 113 | 114 | self.share_scale = share_scale 115 | self.base_scale = base_scale 116 | self.steps_per_epoch = steps_per_epoch 117 | 118 | self.kwargs = kwargs 119 | 120 | self.num_workers = data_loader_num_workers 121 | 122 | self.train_dataloader = None 123 | self.valid_dataloader = None 124 | self.device = device 125 | self.scaler = StandardScaler() 126 | self.es = es 127 | self.cat_cols = [] if cat_cols is None else cat_cols 128 | if cat_params is None: 129 | cat_params = { 130 | 'min_frequency': 10, 131 | 'max_categories': 100, 132 | 'handle_unknown': 'infrequent_if_exist', 133 | 'sparse_output': False 134 | } 135 | self.enc = OneHotEncoder(**cat_params) 136 | 137 | def create_model(self, x): 138 | 139 | x = self.preprocess(x) 140 | nrows, input_dim = x.shape 141 | print(f'Train shape: {x.shape}') 142 | 143 | device = self.device 144 | 145 | share_dim = max(int(self.share_scale * input_dim), 4) 146 | base_dim = max(int(self.base_scale * share_dim), 2) 147 | 148 | batch_size = max(32, int(nrows / self.steps_per_epoch)) 149 | 150 | print(f'Dataset specific params: share_dim={share_dim}; base_dim={base_dim}; batch_size={batch_size}') 151 | 152 | cfg = Config(share_dim=share_dim, base_dim=base_dim, batch_size=batch_size, **self.kwargs) 153 | self.cfg = cfg 154 | 155 | shareNetwork = ShareNetwork(input_dim=input_dim, share_dim=share_dim, base_dim=base_dim, cfg=cfg, device=device) 156 | prpsy_network = PrpsyNetwork(base_dim, cfg=cfg) 157 | mu1_network = Mu1Network(base_dim, cfg=cfg) 158 | mu0_network = Mu0Network(base_dim, cfg=cfg) 159 | tau_network = TauNetwork(base_dim, cfg=cfg) 160 | 161 | model = ESX(prpsy_network, mu1_network, mu0_network, tau_network, shareNetwork, cfg, device) 162 | self.model = model.to(device) 163 | optim = torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.l2) 164 | lr_scheduler = torch.optim.lr_scheduler.StepLR( 165 | optimizer=optim, step_size=cfg.decay_step_size, 166 | gamma=cfg.decay_rate 167 | ) 168 | 169 | return x, cfg, optim, lr_scheduler 170 | 171 | def preprocess(self, x): 172 | 173 | if len(self.cat_cols) > 0: 174 | not_cat = np.setdiff1d(np.arange(x.shape[1]), self.cat_cols) 175 | x_cat = x[:, self.cat_cols] 176 | x = x[:, not_cat] 177 | 178 | try: 179 | x_cat = self.enc.transform(x_cat) 180 | except NotFittedError: 181 | x_cat = self.enc.fit_transform(x_cat) 182 | 183 | try: 184 | x = self.scaler.transform(x) 185 | except NotFittedError: 186 | x = self.scaler.fit_transform(x) 187 | 188 | if len(self.cat_cols) > 0: 189 | x = np.concatenate([x, x_cat], axis=1) 190 | 191 | return x 192 | 193 | def create_dataloaders(self, x, y, t, x_v=None, y_v=None, t_v=None): 194 | """ 195 | Utility function to create train and validation data loader: 196 | 197 | Parameters 198 | ---------- 199 | x: np.array 200 | covariates 201 | y: np.array 202 | target variable 203 | t: np.array 204 | treatment 205 | """ 206 | 207 | x = torch.Tensor(x) 208 | t = torch.Tensor(t).reshape(-1, 1) 209 | y = torch.Tensor(y).reshape(-1, 1) 210 | train_dataset = TensorDataset(x, t, 
y) 211 | self.train_dataloader = DataLoader( 212 | train_dataset, batch_size=self.cfg.batch_size, num_workers=self.num_workers, shuffle=True, drop_last=True 213 | ) 214 | 215 | if x_v is not None: 216 | x_v = self.preprocess(x_v) 217 | x_v = torch.Tensor(x_v) 218 | y_v = torch.Tensor(y_v).reshape(-1, 1) 219 | t_v = torch.Tensor(t_v).reshape(-1, 1) 220 | valid_dataset = TensorDataset(x_v, t_v, y_v) 221 | self.valid_dataloader = DataLoader( 222 | valid_dataset, batch_size=self.cfg.batch_size, num_workers=self.num_workers, shuffle=False 223 | ) 224 | 225 | def fit(self, x, y, t, x_v=None, y_v=None, t_v=None): 226 | """ 227 | Function used to train the dragonnet model 228 | 229 | Parameters 230 | ---------- 231 | x: np.array 232 | covariates 233 | y: np.array 234 | target variable 235 | t: np.array 236 | treatment 237 | """ 238 | x, cfg, optim, lr_scheduler = self.create_model(x) 239 | 240 | self.create_dataloaders(x, y, t, x_v, y_v, t_v) 241 | early_stopper = EarlyStopper(self.temp_folder, patience=self.es, min_delta=0) 242 | early_stopper.save(self.model) 243 | 244 | for epoch in range(self.cfg.epochs): 245 | 246 | self.model.train() 247 | 248 | for batch, (X, tr, y1) in enumerate(self.train_dataloader): 249 | X, tr, y1 = X.to(self.device), tr.to(self.device), y1.to(self.device) 250 | 251 | p_prpsy_logit, p_estr, p_escr, p_tau_logit, p_mu1_logit, p_mu0_logit, p_prpsy, p_mu1, p_mu0, p_h1, p_h0, shared_h \ 252 | = self.model(X) 253 | 254 | # y0_pred, y1_pred, t_pred, eps = self.model(X) 255 | try: 256 | 257 | loss = self.loss_f( 258 | tr, y1, p_prpsy_logit, p_estr, p_escr, p_tau_logit, 259 | p_mu1_logit, p_mu0_logit, p_prpsy, p_mu1, p_mu0, p_h1, p_h0, shared_h 260 | ) 261 | optim.zero_grad() 262 | loss.backward() 263 | optim.step() 264 | 265 | except Exception: 266 | continue 267 | 268 | lr_scheduler.step() 269 | 270 | if self.valid_dataloader: 271 | valid_loss = self.validate_step() 272 | print( 273 | f"epoch: {epoch}--------- train_loss: {loss} ----- valid_loss: {valid_loss}" 274 | ) 275 | if early_stopper.early_stop(valid_loss, self.model): 276 | break 277 | else: 278 | print(f"epoch: {epoch}--------- train_loss: {loss}") 279 | 280 | self.model = early_stopper.load(self.model) 281 | early_stopper.clear() 282 | 283 | def loss_f(self, t_labels, y_labels, *args): 284 | 285 | p_prpsy_logit, p_estr, p_escr, p_tau_logit, p_mu1_logit, p_mu0_logit, p_prpsy, p_mu1, p_mu0, p_h1, p_h0, shared_h = args 286 | 287 | e_labels = torch.zeros_like(t_labels).to(t_labels.device) 288 | 289 | p_t = torch.mean(t_labels).item() 290 | if self.cfg.reweight_sample: 291 | w_t = t_labels / ( 292 | 2 * p_t) 293 | w_c = (1 - t_labels) / (2 * (1 - p_t)) 294 | sample_weight = w_t + w_c 295 | else: 296 | sample_weight = torch.ones_like(t_labels) 297 | p_t = 0.5 298 | 299 | # set loss functions 300 | sample_weight = sample_weight[~e_labels.bool()] 301 | loss_w_fn = nn.BCELoss(weight=sample_weight) 302 | loss_fn = nn.BCELoss() 303 | loss_mse = nn.MSELoss() 304 | loss_with_logit_fn = nn.BCEWithLogitsLoss() # for logit 305 | loss_w_with_logit_fn = nn.BCEWithLogitsLoss( 306 | pos_weight=torch.tensor(1 / (2 * p_t))) # for propensity loss 307 | 308 | # calc loss 309 | # p_prpsy_logit, p_estr, p_escr, p_tau_logit, p_mu1_logit, p_mu0_logit, p_prpsy, p_mu1, p_mu0, p_h1, p_h0, shared_h 310 | 311 | # try: 312 | # loss for propensity 313 | prpsy_loss = self.cfg.prpsy_w * loss_w_with_logit_fn(p_prpsy_logit[~e_labels.bool()], 314 | t_labels[~e_labels.bool()]) 315 | # loss for ESTR, ESCR 316 | estr_loss = self.cfg.escvr1_w * 
loss_w_fn(p_estr[~e_labels.bool()], 317 | (y_labels * t_labels)[~e_labels.bool()]) 318 | escr_loss = self.cfg.escvr0_w * loss_w_fn(p_escr[~e_labels.bool()], 319 | (y_labels * (1 - t_labels))[~e_labels.bool()]) 320 | 321 | # loss for TR, CR 322 | tr_loss = self.cfg.h1_w * loss_fn(p_h1[t_labels.bool()], 323 | y_labels[t_labels.bool()]) # * (1 / (2 * p_t)) 324 | cr_loss = self.cfg.h0_w * loss_fn(p_h0[~t_labels.bool()], 325 | y_labels[~t_labels.bool()]) # * (1 / (2 * (1 - p_t))) 326 | 327 | # loss for cross TR: mu1_prime, cross CR: mu0_prime 328 | cross_tr_loss = self.cfg.mu1hat_w * loss_fn(torch.sigmoid(p_mu0_logit + p_tau_logit)[t_labels.bool()], 329 | y_labels[t_labels.bool()]) 330 | cross_cr_loss = self.cfg.mu0hat_w * loss_fn(torch.sigmoid(p_mu1_logit - p_tau_logit)[~t_labels.bool()], 331 | y_labels[~t_labels.bool()]) 332 | 333 | imb_dist = 0 334 | if self.cfg.imb_dist_w > 0: 335 | if self.cfg.imb_dist == "wass": 336 | imb_dist = wasserstein_torch(X=shared_h, t=t_labels) 337 | elif self.cfg.imb_dist == "mmd": 338 | imb_dist = mmd2_torch(shared_h, t_labels) 339 | else: 340 | sys.exit(1) 341 | imb_dist_loss = self.cfg.imb_dist_w * imb_dist 342 | 343 | total_loss = prpsy_loss + estr_loss + escr_loss \ 344 | + tr_loss + cr_loss \ 345 | + cross_tr_loss + cross_cr_loss \ 346 | + imb_dist_loss 347 | 348 | return total_loss 349 | 350 | def validate_step(self): 351 | """ 352 | Calculates validation loss 353 | 354 | Returns 355 | ------- 356 | valid_loss: torch.Tensor 357 | validation loss 358 | """ 359 | 360 | self.model.eval() 361 | 362 | valid_loss = [] 363 | with torch.no_grad(): 364 | for batch, (X, tr, y1) in enumerate(self.valid_dataloader): 365 | 366 | X, tr, y1 = X.to(self.device), tr.to(self.device), y1.to(self.device) 367 | p_prpsy_logit, p_estr, p_escr, p_tau_logit, p_mu1_logit, p_mu0_logit, p_prpsy, p_mu1, p_mu0, p_h1, p_h0, shared_h \ 368 | = self.model(X) 369 | 370 | try: 371 | loss = self.loss_f( 372 | tr, y1, p_prpsy_logit, p_estr, p_escr, p_tau_logit, 373 | p_mu1_logit, p_mu0_logit, p_prpsy, p_mu1, p_mu0, p_h1, p_h0, shared_h 374 | ) 375 | except Exception: 376 | continue 377 | 378 | valid_loss.append(loss) 379 | return torch.Tensor(valid_loss).mean() 380 | 381 | def predict(self, x): 382 | """ 383 | Function used to predict on covariates. 
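Prediction is performed batch-wise; the returned array is the estimated uplift, i.e. the difference between the treated and control response heads (p_h1 - p_h0).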
384 | 385 | Parameters 386 | ---------- 387 | x: torch.Tensor or numpy.array 388 | covariates 389 | 390 | Returns 391 | ------- 392 | 393 | """ 394 | self.model.eval() 395 | 396 | res = np.zeros((x.shape[0],), dtype=np.float32) 397 | x = self.preprocess(x) 398 | x = torch.Tensor(x) 399 | 400 | ds = TensorDataset(x) 401 | dl = DataLoader( 402 | ds, batch_size=self.cfg.batch_size, num_workers=self.num_workers, shuffle=False 403 | ) 404 | 405 | with torch.no_grad(): 406 | for n, (batch,) in enumerate(dl): 407 | batch = batch.to(self.device) 408 | p_prpsy_logit, p_estr, p_escr, p_tau_logit, p_mu1_logit, p_mu0_logit, p_prpsy, p_mu1, p_mu0, p_h1, p_h0, shared_h \ 409 | = self.model(batch) 410 | res[n * self.cfg.batch_size: (n + 1) * self.cfg.batch_size] = (p_h1 - p_h0).detach().cpu().numpy()[:, 0] 411 | 412 | return res 413 | -------------------------------------------------------------------------------- /source/pb_utils.py: -------------------------------------------------------------------------------- 1 | import cupy as cp 2 | import numpy as np 3 | from py_boost.callbacks.callback import Callback 4 | from py_boost.gpu.boosting import GradientBoosting 5 | from py_boost.gpu.losses import Loss, BCELoss, BCEMetric, MSELoss, Metric 6 | from py_boost.multioutput.sketching import GradSketch, RandomProjectionSketch, RandomSamplingSketch 7 | from py_boost.multioutput.target_splitter import SingleSplitter 8 | 9 | 10 | def get_treshold_stats(fact, pred): 11 | order = pred.argsort()[::-1] 12 | 13 | sorted_y = fact[order] 14 | uplift = pred[order] 15 | 16 | idx = cp.r_[ 17 | cp.nonzero(cp.diff(uplift))[0], 18 | order.shape[0] - 1 19 | ] 20 | cs = cp.nancumsum(sorted_y, axis=0)[idx] 21 | cc = cp.nancumsum((~cp.isnan(sorted_y)).astype(cp.float32), axis=0)[idx] 22 | 23 | return cs, cc, idx 24 | 25 | 26 | def get_qini_curve(fact, pred): 27 | cs, cc, idx = get_treshold_stats(fact, pred) 28 | curve = cs[:, 1] - cp.where(cc[:, 0] > 0, cs[:, 0] * cc[:, 1] / cc[:, 0], 0) 29 | return idx + 1, curve 30 | 31 | 32 | def get_perfect_and_baseline_qini(fact): 33 | y_ = cp.nansum(fact, axis=1) 34 | t_ = cp.isnan(fact[:, 0]).astype(cp.float32) 35 | 36 | perfect = (y_ * t_ - y_ * (1 - t_)) 37 | 38 | x, y = get_qini_curve(fact, perfect) 39 | x = cp.r_[0, x] 40 | y = cp.r_[0, y] 41 | score_perfect = float(cp.trapz(y, x)) 42 | 43 | x, y = np.array([0, float(x[-1])]), np.array([0, float(y[-1])]) 44 | score_baseline = np.trapz(y, x) 45 | 46 | return score_perfect, score_baseline 47 | 48 | 49 | class BCEwithNaNMetric(BCEMetric): 50 | 51 | def __call__(self, y_true, y_pred, sample_weight=None): 52 | mask = ~cp.isnan(y_true) 53 | 54 | err = super().error(cp.where(mask, y_true, 0), y_pred) 55 | err = err * mask 56 | 57 | if sample_weight is not None: 58 | err = err * sample_weight 59 | mask = mask * sample_weight 60 | 61 | return float(err.sum() / mask.sum()) 62 | 63 | 64 | class BCEWithNaNLoss(BCELoss): 65 | 66 | def __init__(self, uplift=False): 67 | self.uplift = uplift 68 | self.clip_value = 1e-6 69 | 70 | def base_score(self, y_true): 71 | # Replace .mean with nanmean function to calc base score 72 | means = cp.nanmean(y_true, axis=0) 73 | means = cp.where(cp.isnan(means), 0, means) 74 | means = cp.clip(means, self.clip_value, 1 - self.clip_value) 75 | 76 | return cp.log(means / (1 - means)) 77 | # return cp.zeros(y_true.shape[1], dtype=cp.float32) 78 | 79 | def get_grad_hess(self, y_true, y_pred): 80 | # first, get nan mask for y_true 81 | mask = cp.isnan(y_true) 82 | # then, compute loss with any values at nan places just 
to prevent the exception 83 | grad, hess = super().get_grad_hess(cp.where(mask, 0, y_true), y_pred) 84 | # invert mask 85 | mask = (~mask).astype(cp.float32) 86 | # multiply grad and hess on inverted mask 87 | # now grad and hess eq. 0 on NaN points 88 | # that actually means that prediction on that place should not be updated 89 | grad = grad * mask 90 | hess = hess * mask 91 | 92 | return grad, hess 93 | 94 | def postprocess_output(self, y_pred): 95 | y_pred = super().postprocess_output(y_pred) 96 | 97 | if self.uplift: 98 | uplift = y_pred[:, 1:] - y_pred[:, :1] 99 | 100 | return uplift 101 | 102 | return y_pred 103 | 104 | 105 | class MSEWithNaNLoss(MSELoss): 106 | 107 | def get_grad_hess(self, y_true, y_pred): 108 | # first, get nan mask for y_true 109 | mask = cp.isnan(y_true) 110 | # then, compute loss with any values at nan places just to prevent the exception 111 | grad, hess = super().get_grad_hess(cp.where(mask, 0, y_true), y_pred) 112 | # invert mask 113 | mask = (~mask).astype(cp.float32) 114 | # multiply grad and hess on inverted mask 115 | # now grad and hess eq. 0 on NaN points 116 | # that actually means that prediction on that place should not be updated 117 | grad = grad * mask 118 | hess = hess * mask 119 | 120 | return grad, hess 121 | 122 | 123 | class QINIMetric(Metric, Callback): 124 | 125 | def __init__(self, freq=1): 126 | 127 | self.freq = freq 128 | self.value = None 129 | self.n = None 130 | self.base = None 131 | self.perf = None 132 | self.trt_sl = None 133 | 134 | self.last_score = None 135 | 136 | def before_iteration(self, build_info): 137 | 138 | self.n = build_info['num_iter'] 139 | 140 | def before_train(self, build_info): 141 | 142 | y_true = build_info['data']['valid']['target'] 143 | assert len(y_true) <= 1, 'Only single dataset is avaliable to evaluate' 144 | y_true = y_true[0] 145 | 146 | nnans = ~np.isnan(y_true) 147 | self.trt_sl = nnans[:, 1:] | nnans[:, :1] 148 | 149 | self.n = None 150 | self.base, self.perf = [], [] 151 | 152 | for i in range(y_true.shape[1] - 1): 153 | cols = [0, i + 1] 154 | sl = cp.nonzero(self.trt_sl[:, i])[0] 155 | fact = y_true[:, cols][sl] 156 | perf, base = get_perfect_and_baseline_qini(fact) 157 | self.perf.append(perf) 158 | self.base.append(base) 159 | 160 | return 161 | 162 | def after_train(self, build_info): 163 | 164 | self.__init__(self.freq) 165 | 166 | def __call__(self, y_true, y_pred, sample_weight=None): 167 | 168 | if (self.n % self.freq) == 0: 169 | 170 | qinis = [] 171 | 172 | for i in range(y_pred.shape[1]): 173 | cols = [0, i + 1] 174 | sl = cp.nonzero(self.trt_sl[:, i])[0] 175 | fact = y_true[:, cols][sl] 176 | pred = y_pred[:, i][sl] 177 | 178 | x, y = get_qini_curve(fact, pred) 179 | q = float(cp.trapz(y, x)) 180 | score = (q - self.base[i]) / (self.perf[i] - self.base[i]) 181 | qinis.append(score) 182 | 183 | self.last_score = np.mean(qinis) 184 | 185 | return self.last_score 186 | 187 | def compare(self, v0, v1): 188 | 189 | return v0 > v1 190 | 191 | 192 | class UpliftSketch(GradSketch): 193 | 194 | def __call__(self, grad, hess): 195 | grad = grad.sum(axis=1, keepdims=True) 196 | hess = hess.sum(axis=1, keepdims=True) 197 | # hess = cp.ones_like(grad) 198 | 199 | return grad, hess 200 | 201 | 202 | class MixedUpliftSketch(UpliftSketch): 203 | 204 | def __init__(self): 205 | self.base_sketch = RandomProjectionSketch(1) 206 | 207 | def __call__(self, grad, hess): 208 | bg, bh = self.base_sketch(grad, hess) 209 | ug, uh = super().__call__(grad, hess) 210 | 211 | return cp.concatenate([bg, ug], 
212 | 
213 | 
214 | class RFCallback(Callback):
215 | 
216 |     def process(self, ens, last_pred, base_score, n, lr):
217 | 
218 |         # clean ensemble from prediction
219 |         ens = ens - last_pred - base_score
220 |         # add as mean
221 |         n = n + 1
222 |         ens = base_score + ens * ((n - 1) / n) + last_pred / (n * lr)
223 | 
224 |         return ens
225 | 
226 |     def after_iteration(self, build_info):
227 | 
228 |         train = build_info['data']['train']
229 |         valid = build_info['data']['valid']
230 | 
231 |         train['ensemble'][:] = build_info['model'].base_score
232 | 
233 |         for i in range(len(valid['ensemble'])):
234 |             valid['ensemble'][i] = self.process(
235 |                 valid['ensemble'][i],
236 |                 valid['last_tree']['preds'][i],
237 |                 build_info['model'].base_score,
238 |                 build_info['num_iter'],
239 |                 build_info['model'].lr
240 |             )
241 | 
242 |         return False
243 | 
244 |     def after_train(self, build_info):
245 | 
246 |         model = build_info['model']
247 |         trees, lr = model.models, model.lr
248 |         n = len(trees)
249 | 
250 |         for i in range(n):
251 |             trees[i].values = trees[i].values / (lr * n)
252 | 
253 |         return
254 | 
255 | 
256 | class RandomSamplingSketchX(RandomSamplingSketch):
257 | 
258 |     def before_iteration(self, build_info):
259 |         super().before_iteration(build_info)
260 |         self.num_iter = build_info['num_iter']
261 | 
262 |     def __call__(self, grad, hess):
263 |         if np.random.rand() > .8:
264 |             return grad, hess
265 | 
266 |         return super().__call__(grad, hess)
267 | 
268 | 
269 | class PyBoostClassifier:
270 | 
271 |     def __init__(self, *args, **kwargs):
272 |         self.args = args
273 |         self.kwargs = kwargs
274 |         self.model = None
275 |         self.loss = 'bce'
276 | 
277 |     def fit(self, X, y, sample_weight=None):
278 |         print(X.shape, y.shape)
279 | 
280 |         self.model = GradientBoosting(self.loss, *self.args, **self.kwargs)
281 |         self.model.fit(X, y, sample_weight=sample_weight)
282 | 
283 |         return self
284 | 
285 |     def predict_proba(self, X):
286 |         pred = self.model.predict(X)
287 |         return np.concatenate([1 - pred, pred], axis=1)
288 | 
289 |     def predict(self, X):
290 |         return self.predict_proba(X).argmax(axis=1)
291 | 
292 |     def fit_predict(self, X, y, sample_weight=None):
293 |         self.fit(X, y, sample_weight=sample_weight)
294 |         return self.predict(X)
295 | 
296 | 
297 | class PyBoostRegressor(PyBoostClassifier):
298 | 
299 |     def __init__(self, *args, **kwargs):
300 |         super().__init__(*args, **kwargs)
301 |         self.loss = 'mse'
302 | 
303 |     def predict(self, X):
304 |         return self.model.predict(X)[:, 0]
305 | 
306 | 
307 | class ComposedUpliftLoss(Loss, Callback):
308 | 
309 |     def __init__(self, base_loss, start_iter=10, weight=.5, masked=False):
310 | 
311 |         self.base_loss = base_loss
312 |         self.uplift_loss = MSEWithNaNLoss()
313 |         self.start_iter = start_iter
314 |         self.n = 0
315 |         self.weight = weight
316 |         self.masked = masked
317 | 
318 |     def before_iteration(self, build_info):
319 | 
320 |         self.n = build_info['num_iter']
321 | 
322 |     def base_score(self, y_true):
323 | 
324 |         score_base = self.base_loss.base_score(y_true)
325 |         score_upl = cp.nanmean(y_true, axis=0)
326 |         score_upl = score_upl[1:] - score_upl[:1]
327 | 
328 |         return cp.concatenate([score_base, score_upl])
329 | 
330 |     def get_grad_hess(self, y_true, y_pred):
331 | 
332 |         l = y_true.shape[1]
333 | 
334 |         grad, hess = self.base_loss.get_grad_hess(y_true, y_pred[:, :l])
335 |         # if self.n < self.start_iter:
336 |         #     not_null_mask = ~np.isnan(y_true)
337 |         #     proxy = cp.where(not_null_mask, y_true, self.base_loss.postprocess_output(y_pred[:, :l]))
338 |         #     uplift = proxy[:, 1:] - proxy[:, :1]
339 |         #     if self.masked:
340 |         #         mask = not_null_mask[:, 1:] | not_null_mask[:, :1]  # mask the cells with at least one real value
341 |         #         uplift = cp.where(mask, uplift, np.nan)
342 | 
343 |         proxy = self.base_loss.postprocess_output(
344 |             y_pred[:, :l])  # cp.where(not_null_mask, y_true, self.base_loss.postprocess_output(y_pred[:, :l]))
345 |         uplift = proxy[:, 1:] - proxy[:, :1]
346 |         # if self.masked:
347 |         #     mask = not_null_mask[:, 1:] | not_null_mask[:, :1]  # mask the cells with at least one real value
348 |         #     uplift = cp.where(mask, uplift, np.nan)
349 | 
350 |         if self.n >= self.start_iter:
351 |             # print('Good branch')
352 |             grad_upl, _ = self.uplift_loss(uplift, y_pred[:, l:])
353 |             # print(grad_upl.mean(axis=0), grad_upl.std(axis=0))
354 |         else:
355 |             grad_upl = cp.zeros((grad.shape[0], grad.shape[1] - 1), dtype=cp.float32)
356 | 
357 |         hess_upl = cp.ones_like(grad_upl)
358 |         # if self.masked:
359 |         #     hess_upl = mask.astype(cp.float32)
360 |         # else:
361 |         #     hess_upl = cp.ones_like(grad_upl)
362 | 
363 |         grad = cp.concatenate([grad, grad_upl], axis=1)
364 |         hess = cp.concatenate([hess, hess_upl], axis=1)
365 | 
366 |         return grad, hess
367 | 
368 |     def postprocess_output(self, y_pred):
369 | 
370 |         # print(y_pred.shape, y_pred.mean(axis=0), y_pred.std(axis=0))
371 | 
372 |         l = y_pred.shape[1] // 2 + 1
373 | 
374 |         base_pred = self.base_loss.postprocess_output(y_pred[:, :l])
375 |         base_pred = base_pred[:, 1:] - base_pred[:, :1]
376 |         y_pred = y_pred[:, l:]
377 | 
378 |         # print(y_pred.std(axis=0), base_pred.std(axis=0))
379 | 
380 |         return y_pred * self.weight + base_pred * (1 - self.weight)
381 | 
382 | 
383 | class UpliftSplitter(SingleSplitter):
384 | 
385 |     def before_iteration(self, build_info):
386 |         """Initialize indexers
387 | 
388 |         Args:
389 |             build_info: dict
390 | 
391 |         Returns:
392 | 
393 |         """
394 |         if build_info['num_iter'] == 0:
395 |             nout = build_info['data']['train']['grad'].shape[1] // 2 + 1
396 | 
397 |             self.indexer = [cp.arange(nout, dtype=cp.uint64), cp.arange(nout, nout * 2 - 1, dtype=cp.uint64)]
398 | 
399 |     def __call__(self):
400 |         """Get list of indexers for each group
401 | 
402 |         Returns:
403 |             list of cp.ndarrays of indexers
404 |         """
405 |         return self.indexer
406 | 
407 | 
408 | class UpliftSplitterXN(SingleSplitter):
409 | 
410 |     def before_iteration(self, build_info):
411 |         """Initialize indexers
412 | 
413 |         Args:
414 |             build_info: dict
415 | 
416 |         Returns:
417 | 
418 |         """
419 |         if build_info['num_iter'] == 0:
420 |             nout = build_info['data']['train']['grad'].shape[1] // 2 + 1
421 | 
422 |             self.indexer = [
423 |                 cp.arange(nout, dtype=cp.uint64),
424 |                 # cp.arange(nout, nout * 2 - 1, dtype=cp.uint64)
425 |             ] + [cp.asarray([x], dtype=cp.uint64) for x in range(nout, nout * 2 - 1)]
426 | 
427 |     def __call__(self):
428 |         """Get list of indexers for each group
429 | 
430 |         Returns:
431 |             list of cp.ndarrays of indexers
432 |         """
433 |         return self.indexer
434 | 
435 | 
436 | class ComposedUpliftSketch(GradSketch):
437 | 
438 |     def __init__(self, base_sketch=None):
439 | 
440 |         self.base_sketch = base_sketch
441 |         if base_sketch is None:
442 |             self.base_sketch = GradSketch()
443 | 
444 |         self.flg = False
445 | 
446 |     def before_iteration(self, build_info):
447 | 
448 |         self.flg = False
449 | 
450 |     def __call__(self, grad, hess):
451 | 
452 |         if self.flg:
453 |             return self.base_sketch(grad, hess)
454 | 
455 |         self.flg = True
456 |         grad = grad.sum(axis=1, keepdims=True)
457 |         hess = hess.sum(axis=1, keepdims=True)
458 |         return grad, hess
459 | 
460 | 
461 | class ComposedUpliftSketchMix(GradSketch):
462 | 
463 |     def __init__(self, base_sketch=None):
464 | 
465 |         self.base_sketch = base_sketch
466 |         if base_sketch is None:
467 |             self.base_sketch = GradSketch()
468 | 
469 |         self.upl_sketch = MixedUpliftSketch()
470 | 
471 |         self.flg = False
472 | 
473 |     def before_iteration(self, build_info):
474 | 
475 |         self.flg = False
476 | 
477 |     def __call__(self, grad, hess):
478 | 
479 |         if self.flg:
480 |             return self.base_sketch(grad, hess)
481 | 
482 |         self.flg = True
483 |         grad, hess = self.upl_sketch(grad, hess)
484 |         return grad, hess
485 | 
--------------------------------------------------------------------------------
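Note on usage: the classes in /source/pb_utils.py are building blocks rather than a ready-made estimator, so a short illustrative sketch may help. The snippet below is NOT part of the repository; it only shows how the NaN-aware loss, metric, and sketch could plausibly be plugged into py_boost's GradientBoosting under the usual multi-treatment uplift target layout (column 0 = control outcome, columns 1..k = outcomes under each treatment, NaN wherever an observation did not belong to that arm). Keyword names such as multioutput_sketch, callbacks, and eval_sets follow the py_boost tutorial API; X, y, and the toy data are placeholders, and the repository's own wiring (see runner.py / factory.py) may differ.

# Hypothetical usage sketch (assumption, not repo code): multi-treatment uplift
# with the NaN-masked BCE loss from source/pb_utils.py.
import numpy as np
from py_boost.gpu.boosting import GradientBoosting

from source.pb_utils import BCEWithNaNLoss, BCEwithNaNMetric, UpliftSketch, QINIMetric

rng = np.random.default_rng(0)
n = 1000
X = rng.normal(size=(n, 20)).astype(np.float32)
arm = rng.integers(0, 3, size=n)                      # 0 = control, 1..2 = treatments
outcome = rng.integers(0, 2, size=n).astype(np.float32)

# Transformed target: NaN everywhere except the arm actually observed,
# so the masked loss skips gradient updates for the unobserved columns.
y = np.full((n, 3), np.nan, dtype=np.float32)
y[np.arange(n), arm] = outcome

model = GradientBoosting(
    loss=BCEWithNaNLoss(uplift=True),                 # postprocess returns sigmoid(p_t) - sigmoid(p_c)
    metric=BCEwithNaNMetric(),                        # NaN-aware logloss on observed cells only
    ntrees=500, lr=0.01, lambda_l2=1,
    multioutput_sketch=UpliftSketch(),                # one shared tree structure across all outputs
    callbacks=[QINIMetric(freq=10)],                  # implements both Metric and Callback interfaces
)
model.fit(X, y, eval_sets=[{'X': X, 'y': y}])         # self-validation only for illustration

uplift = model.predict(X)                             # shape (n, 2): one uplift column per treatment

The key design point the sketch tries to surface is that the NaN pattern in y carries the treatment assignment: BCEWithNaNLoss zeroes grad/hess on unobserved cells, so a single multioutput booster can fit control and treatment responses jointly and emit per-treatment uplift at prediction time.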