├── LICENSE ├── requirements.txt ├── TARNet ├── layers │ ├── extract.py │ ├── gather.py │ ├── split.py │ └── MLP.py ├── __init__.py └── models │ └── TARNet.py ├── MANIFEST.in ├── pyproject.toml ├── .gitignore └── readme.md /LICENSE: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /TARNet/layers/extract.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | global-include *.txt *.py *.md -------------------------------------------------------------------------------- /TARNet/__init__.py: -------------------------------------------------------------------------------- 1 | from .models.TARNet import TARNet -------------------------------------------------------------------------------- /TARNet/layers/gather.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class Gather_Streams(tf.keras.layers.Layer): 5 | def __init__(self): 6 | super(Gather_Streams, self).__init__() 7 | 8 | def call(self, inputs): 9 | 10 | x, y = inputs 11 | return tf.dynamic_stitch(y, x) 12 | -------------------------------------------------------------------------------- /TARNet/layers/split.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class Split_Streams(tf.keras.layers.Layer): 5 | def __init__(self): 6 | super(Split_Streams, self).__init__() 7 | 8 | def call(self, inputs): 9 | x, y, z = inputs 10 | 11 | indice_position = tf.reshape( 12 | tf.cast(tf.where(tf.equal(tf.reshape(y, (-1,)), z)), tf.int32), 13 | (-1,), 14 | ) 15 | 16 | return tf.gather(x, indice_position), indice_position 17 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["tensorflow", "pickle"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "TARNet" 7 | version = "1.0" 8 | authors = [ 9 | { name="", email=""}, 10 | ] 11 | description = "TARNet Model with tensorflow 2 API." 12 | readme = "README.md" 13 | license = { file="LICENSE" } 14 | requires-python = ">=3.7" 15 | classifiers = [ 16 | "Programming Language :: Python :: 3", 17 | "License :: OSI Approved :: MIT License", 18 | "Operating System :: OS Independent", 19 | ] 20 | 21 | [project.urls] 22 | "Homepage" = "https://github.com/arnaud39/TARNet" 23 | "Bug Tracker" = "https://github.com/arnaud39/TARNet/issues" 24 | -------------------------------------------------------------------------------- /TARNet/layers/MLP.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tensorflow.keras import regularizers 4 | 5 | 6 | class MLP(tf.keras.layers.Layer): 7 | """Multi Layer Perceptron as a Keras layer.""" 8 | 9 | def __init__( 10 | self, 11 | units: int = 200, 12 | num_layers: int = 1, 13 | kernel_initializer=tf.keras.initializers.HeNormal(), 14 | activation: str = "relu", 15 | name: str = "phi", 16 | reg_l2: float = 0, 17 | ): 18 | """Initiate the layers used by the multi layer perceptron. 19 | 20 | Args: 21 | units (int, optional): _description_. Defaults to 200. 22 | num_layers (int, optional): _description_. Defaults to 1. 23 | kernel_initializer (_type_, optional): _description_. Defaults to tf.keras.initializers.HeNormal(). 24 | activation (str, optional): _description_. Defaults to "relu". 25 | name (str, optional): _description_. Defaults to "phi". 26 | reg_l2 (float, optional): _description_. Defaults to 0. 27 | """ 28 | 29 | super(MLP, self).__init__() 30 | self.layers = [ 31 | tf.keras.layers.Dense( 32 | units=units, 33 | activation=activation, 34 | kernel_initializer=kernel_initializer, 35 | kernel_regularizer=regularizers.l2(reg_l2), 36 | name=f"{name}_{k}", 37 | trainable=True, 38 | ) 39 | for k in range(num_layers) 40 | ] 41 | 42 | def call(self, x: tf.Tensor) -> tf.Tensor: 43 | for layer in self.layers: 44 | x = layer(x) 45 | return x 46 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | .DS_Store 3 | 4 | *.csv 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | .pdm.toml 91 | 92 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 93 | __pypackages__/ 94 | 95 | # Celery stuff 96 | celerybeat-schedule 97 | celerybeat.pid 98 | 99 | # SageMath parsed files 100 | *.sage.py 101 | 102 | # Environments 103 | .env 104 | .venv 105 | env/ 106 | venv/ 107 | ENV/ 108 | env.bak/ 109 | venv.bak/ 110 | 111 | # Spyder project settings 112 | .spyderproject 113 | .spyproject 114 | 115 | # Rope project settings 116 | .ropeproject 117 | 118 | # mkdocs documentation 119 | /site 120 | 121 | # mypy 122 | .mypy_cache/ 123 | .dmypy.json 124 | dmypy.json 125 | 126 | # Pyre type checker 127 | .pyre/ 128 | 129 | # pytype static type analyzer 130 | .pytype/ 131 | 132 | # Cython debug symbols 133 | cython_debug/ -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # TARNet 2 | 3 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 4 | [![pypi version](https://img.shields.io/pypi/v/tarnet.svg)](https://pypi.python.org/pypi/tarnet) 5 | 6 | [![Downloads](https://static.pepy.tech/badge/tarnet)](https://pepy.tech/project/tarnet) 7 | 8 | **TARNet: TARNet Model with tensorflow 2 API.** 9 | 10 | Treatment-Agnostic Representation Network 🩺 is a machine learning architecture that has a common MLP feeding specific sub-networks. It can help to identify bias in the data, estimate average treatment effect or act as transfer-learning like model. 11 | 12 | ![TARNet model architecture](https://i.ibb.co/Lt7B7vV/TARNet.png) 13 | 14 | This package implement this model as a keras-like TensorFlow API model. 15 | 16 | Parameters are: 17 | ```python 18 | normalizer_layer: tf.keras.layers.Layer = None, 19 | n_treatments: int = 2, 20 | output_dim: int = 1, 21 | phi_layers: int = 2, 22 | units:int = 20, 23 | y_layers: int = 3, 24 | activation: str = "relu", 25 | reg_l2: float = 0.0, 26 | treatment_as_input: bool = False, 27 | scaler: Any = None, 28 | output_bias: float = None, 29 | ``` 30 | 31 | 32 | The input wil be (X,t) with t a (X,1) shape tensor representing the hidden treatment/category. 33 | 34 | 35 | 36 | Author & Maintainer: Arnaud Petit 37 | 38 | ## Installation 39 | 40 | Using pip: 41 | 42 | ```python 43 | pip install tarnet 44 | ``` 45 | 46 | ## Dependencies 47 | 48 | TARNet requires: 49 | 50 | - Python (>= 3.7) 51 | - TensorFlow 52 | 53 | ## Documentation 54 | 55 | Link to the documentation: coming soon 56 | 57 | ## Examples 58 | 59 | General case, import one of the classes Classifiers, Regressions, Clustering from vulpes.automl, add some parameters to the object (optional), fit your dataset: 60 | 61 | ```python 62 | from tarnet import TARNet 63 | 64 | df = pd.read_csv("...") 65 | 66 | X, y, t = ( 67 | df.drop(output + ["icu_type"], axis=1).to_numpy(dtype="float32"), 68 | df[output].to_numpy(dtype="int").reshape(-1, 1), 69 | df.icu_type.to_numpy(dtype="int32").reshape(-1, 1), 70 | ) 71 | 72 | import tensorflow as tf 73 | 74 | normalizer_layer = tf.keras.layers.Normalization(axis=None) 75 | normalizer_layer.adapt(X) 76 | scaler = normalizer_layer 77 | 78 | DATASET_SIZE = len(df) 79 | 80 | batch_size = 64 81 | 82 | train_size = int(0.7 * DATASET_SIZE) 83 | val_size = int(0.2 * DATASET_SIZE) 84 | test_size = int(0.1 * DATASET_SIZE) 85 | 86 | dataset = tf.data.Dataset.zip( 87 | (tf.data.Dataset.from_tensor_slices((X, t)), tf.data.Dataset.from_tensor_slices(y)) 88 | ).shuffle(buffer_size=DATASET_SIZE, reshuffle_each_iteration=False)#batch(64) 89 | 90 | train_dataset = dataset.take(train_size).batch(batch_size) 91 | test_dataset = dataset.skip(train_size) 92 | val_dataset = test_dataset.take(val_size).batch(batch_size) 93 | test_dataset = test_dataset.skip(val_size) 94 | 95 | neg, pos = np.bincount(np.concatenate([y for _, y in train_dataset]).reshape(-1).astype("int")) 96 | 97 | initial_bias = tf.keras.initializers.Constant(np.log([pos/neg])) 98 | 99 | model = tarNET( 100 | output_dim=1, 101 | n_treatments=10, 102 | normalizer_layer=normalizer_layer, 103 | scaler=scaler, 104 | output_bias=initial_bias, 105 | phi_layers=10, 106 | ) 107 | 108 | ``` 109 | 110 | 111 | ## Why TARNet? 112 | 113 | TARNet stands for: **T**reatment-**A**gnostic **R**epresentation **N**etwork. 114 | 115 | 116 | 117 | ## License 118 | 119 | [MIT](https://choosealicense.com/licenses/mit/) 120 | -------------------------------------------------------------------------------- /TARNet/models/TARNet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from ..layers.gather import Gather_Streams 4 | from ..layers.split import Split_Streams 5 | from ..layers.MLP import MLP 6 | from pickle import load 7 | 8 | from typing import Any 9 | 10 | 11 | class TARNet(tf.keras.Model): 12 | """Return a tarnet sub KERAS API model.""" 13 | 14 | def __init__( 15 | self, 16 | normalizer_layer: tf.keras.layers.Layer = None, 17 | n_treatments: int = 2, 18 | output_dim: int = 1, 19 | phi_layers: int = 2, 20 | units:int = 20, 21 | y_layers: int = 3, 22 | activation: str = "relu", 23 | reg_l2: float = 0.0, 24 | treatment_as_input: bool = False, 25 | scaler: Any = None, 26 | output_bias: float = None, 27 | ): 28 | """Initialize the layers used by the model. 29 | 30 | Args: 31 | normalizer_layer (tf.keras.layer, optional): _description_. Defaults to None. 32 | n_treatments (int, optional): _description_. Defaults to 2. 33 | output_dim (int, optional): _description_. Defaults to 1. 34 | phi_layers (int, optional): _description_. Defaults to 2. 35 | y_layers (int, optional): _description_. Defaults to 3. 36 | activation (str, optional): _description_. Defaults to "relu". 37 | reg_l2 (float, optional): _description_. Defaults to 0.0. 38 | """ 39 | super(TARNet, self).__init__() 40 | # uniform quantile transform for treatment 41 | self.scaler = scaler if scaler else load(open("scaler.pkl", "rb")) 42 | 43 | # input normalization layer 44 | self.normalizer_layer = normalizer_layer 45 | self.phi = MLP( 46 | units=units, 47 | activation=activation, 48 | name="phi", 49 | num_layers=phi_layers, 50 | ) 51 | 52 | self.splitter = Split_Streams() 53 | 54 | self.y_hiddens = [ 55 | MLP( 56 | units=units, 57 | activation=activation, 58 | name=f"y_{k}", 59 | num_layers=y_layers, 60 | ) 61 | for k in range(n_treatments) 62 | ] 63 | 64 | # add linear function to cover the normalized output 65 | self.y_outputs = [ 66 | tf.keras.layers.Dense( 67 | output_dim, 68 | activation="sigmoid", 69 | bias_initializer=output_bias, 70 | name=f"top_{k}", 71 | ) 72 | for k in range(n_treatments) 73 | ] 74 | 75 | self.n_treatments = n_treatments 76 | 77 | self.output_ = Gather_Streams() 78 | 79 | def call(self, x): 80 | 81 | cofeatures_input, treatment_input = x 82 | treatment_cat = tf.cast(treatment_input, tf.int32) 83 | 84 | if self.normalizer_layer: 85 | cofeatures_input = self.normalizer_layer(cofeatures_input) 86 | x_flux = self.phi(cofeatures_input) 87 | 88 | streams = [ 89 | self.splitter([x_flux, treatment_cat, tf.cast(indice_treatment, tf.int32)]) 90 | for indice_treatment in range(len(self.y_hiddens)) 91 | ] 92 | # xstream is a list of tuple, containing the gathered and indice position, let's unpack them 93 | x_streams, indice_streams = zip(*streams) 94 | # tf.print(indice_streams, output_stream=sys.stderr) 95 | x_streams = [ 96 | y_hidden(x_stream) for y_hidden, x_stream in zip(self.y_hiddens, x_streams) 97 | ] 98 | x_streams = [ 99 | y_output(x_stream) for y_output, x_stream in zip(self.y_outputs, x_streams) 100 | ] 101 | 102 | return self.output_([x_streams, indice_streams]) 103 | --------------------------------------------------------------------------------