├── .gitignore ├── LICENSE.md ├── README.md ├── concepts_xai ├── __init__.py ├── datasets │ ├── __init__.py │ ├── cars3D.py │ ├── dSprites.py │ ├── dataset_utils.py │ ├── latentFactorData.py │ ├── load_paths.py │ ├── shapes3D.py │ ├── smallNorb.py │ └── tabular_toy.py ├── evaluation │ ├── __init__.py │ └── metrics │ │ ├── accuracy.py │ │ ├── completeness.py │ │ ├── downstream_task.py │ │ ├── mpo.py │ │ ├── niching.py │ │ └── purity.py ├── methods │ ├── CBM │ │ └── CBModel.py │ ├── CME │ │ ├── CtlModel.py │ │ └── ItCModel.py │ ├── CW │ │ └── CWLayer.py │ ├── OCACE │ │ ├── __init__.py │ │ ├── topicModel.py │ │ └── visualisation.py │ ├── SENN │ │ ├── __init__.py │ │ ├── aggregators.py │ │ └── base_senn.py │ ├── SSCC │ │ └── SSCClassifier.py │ ├── VAE │ │ ├── baseVAE.py │ │ ├── betaVAE.py │ │ ├── losses.py │ │ └── weak_vae.py │ └── __init__.py └── utils │ ├── __init__.py │ ├── architectures.py │ ├── model_loader.py │ ├── utils.py │ └── visualisation.py ├── config.yml ├── config_template.yml ├── download_datasets.sh ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea/ 3 | .python-version 4 | *.ipynb 5 | __pycache__ 6 | build/ 7 | cars/ 8 | concepts_xai.egg-info/ 9 | dist/ 10 | dsprites/ 11 | shapes3d/ 12 | small_norb/ 13 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Dmitry Kazhdan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Concept-based XAI Library 2 | 3 | CXAI is an open-source library for research on concept-based Explainable AI 4 | (XAI). 5 | 6 | CXAI supports a variety of different models, datasets, and evaluation metrics, 7 | associated with concept-based approaches: 8 | 9 | 10 | ### High-level Specs: 11 | 12 | _Methods_: 13 | - [Now You See Me (CME): Concept-based Model Extraction](https://arxiv.org/abs/2010.13233). 
14 | - [Concept Bottleneck Models](https://arxiv.org/abs/2007.04612) 15 | - [Weakly-Supervised Disentanglement Without Compromises](https://arxiv.org/abs/2002.02886) 16 | - [Challenging Common Assumptions in the Unsupervised Learning of Disentangled Representations](https://arxiv.org/abs/1811.12359) 17 | - [Concept Whitening for Interpretable Image Recognition](https://arxiv.org/abs/2002.01650) 18 | - [On Completeness-aware Concept-Based Explanations in Deep Neural Networks](https://arxiv.org/abs/1910.07969) 19 | - [Towards Robust Interpretability with Self-Explaining Neural Networks](https://arxiv.org/abs/1806.07538) 20 | 21 | 22 | 23 | _Datasets_: 24 | - [dSprites](https://github.com/deepmind/dsprites-dataset) 25 | - [Shapes3D](https://github.com/deepmind/3d-shapes) 26 | - [Cars3D](https://papers.nips.cc/paper/2015/hash/e07413354875be01a996dc560274708e-Abstract.html) 27 | - [SmallNorb](https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/) 28 | 29 | To download the datasets, run the `download_datasets.sh` script. 30 | 31 | ### Requirements 32 | 33 | - Python 3.7 - 3.8 34 | - See `requirements.txt` for the remaining required packages 35 | 36 | ### Installation 37 | To install from source, run the following command: 38 | ```bash 39 | python setup.py install 40 | ``` 41 | This will install the `concepts-xai` package together with all its dependencies. 42 | 43 | To test that the package has been installed successfully, you may run: 44 | ```python 45 | import concepts_xai 46 | help("concepts_xai") 47 | ``` 48 | to display all the subpackages included in this installation. 49 | 50 | ### Subpackages 51 | 52 | - `datasets`: dataset loaders, together with their task label functions. 53 | - `evaluation`: evaluation metrics for assessing concept-based methods. 54 | - `experiments`: experimental setups (to be added soon). 55 | - `methods`: defines the concept-based methods. Note: SSCC defines wrappers around these methods that turn them into semi-supervised concept-labelling methods. 56 | - `utils`: utility functions for model creation and data management. 57 | 58 | 59 | ### Citing 60 | 61 | If you find this code useful in your research, please consider citing: 62 | 63 | ``` 64 | @article{kazhdan2021disentanglement, 65 |   title={Is Disentanglement all you need? Comparing Concept-based \& Disentanglement Approaches}, 66 |   author={Kazhdan, Dmitry and Dimanov, Botty and Terre, Helena Andres and Jamnik, Mateja and Li{\`o}, Pietro and Weller, Adrian}, 67 |   journal={arXiv preprint arXiv:2104.06917}, 68 |   year={2021} 69 | } 70 | ``` 71 | 72 | This work has been presented at the [RAI](https://sites.google.com/view/rai-workshop/), [WeaSuL](https://weasul.github.io/), and 73 | [RobustML](https://sites.google.com/connect.hku.hk/robustml-2021/home) workshops at [The Ninth International Conference on Learning 74 | Representations (ICLR 2021)](https://iclr.cc/).
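### Example Usage

As a quick-start sketch, the snippet below instantiates the dSprites loader from `concepts_xai/datasets` and inspects the generated splits. It assumes you have already fetched the dSprites `.npz` file (e.g., via `download_datasets.sh`); the path used here is purely illustrative:

```python
from concepts_xai.datasets.dSprites import dSprites

# Illustrative path -- point this at your local copy of the dSprites .npz file.
data = dSprites(
    dataset_path='dsprites/dsprites.npz',
    task='shape_small_skip',  # one of the pre-defined tasks in DSPRITES_TASKS
    train_size=0.85,
    random_state=42,
)

# load_data() returns tf.data generators yielding (x, c, y) triples,
# together with the list of concept names.
train_gen, test_gen, concept_names = data.load_data()
print(concept_names)       # ['shape', 'scale', 'rotation', 'x_pos', 'y_pos']
print(data.n_classes)      # number of task labels
print(data.n_c_vals_list)  # number of distinct values per concept
```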
75 | -------------------------------------------------------------------------------- /concepts_xai/__init__.py: -------------------------------------------------------------------------------- 1 | from concepts_xai.datasets import * 2 | from concepts_xai.evaluation import * 3 | from concepts_xai.methods import * 4 | from concepts_xai.utils import * 5 | -------------------------------------------------------------------------------- /concepts_xai/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmitrykazhdan/concept-based-xai/a86c6abd609ec2ef0cf56cb556e3f834a0962bdc/concepts_xai/datasets/__init__.py -------------------------------------------------------------------------------- /concepts_xai/datasets/cars3D.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import PIL 4 | import numpy as np 5 | import scipy.io as sio 6 | import tensorflow as tf 7 | 8 | from .latentFactorData import LatentFactorData, get_task_data, built_task_fn 9 | 10 | CARS_concept_names = ['elevation', 'azimuth', 'object_type'] 11 | CARS_concept_n_vals = [4, 24, 183] 12 | 13 | 14 | class Cars3D(LatentFactorData): 15 | 16 | def __init__( 17 | self, 18 | dataset_path, 19 | task_name='elevation_full', 20 | train_size=0.85, 21 | random_state=42, 22 | ): 23 | ''' 24 | :param dataset_path: path to the cars dataset folder 25 | :param task_name: the task to use with the dataset for creating labels 26 | ''' 27 | super().__init__( 28 | dataset_path=dataset_path, 29 | task_name=task_name, 30 | num_factors=3, 31 | sample_shape=[64, 64, 3], 32 | c_names=CARS_concept_names, 33 | task_fn=CARS3D_TASKS[task_name], 34 | ) 35 | self._get_generators(train_size, random_state) 36 | 37 | def _load_x_c_data(self): 38 | x_data = [] 39 | all_files = [ 40 | x for x in tf.io.gfile.listdir(self.dataset_path) if ".mat" in x 41 | ] 42 | c_data = [] 43 | 44 | for i, filename in enumerate(all_files): 45 | data_mesh = load_mesh(os.path.join(self.dataset_path, filename)) 46 | factor1 = np.array(list(range(4))) 47 | factor2 = np.array(list(range(24))) 48 | all_factors = np.transpose([ 49 | np.tile(factor1, len(factor2)), 50 | np.repeat(factor2, len(factor1)), 51 | np.tile(i, len(factor1) * len(factor2)) 52 | ]) 53 | 54 | c_data += [ 55 | list(all_factors[j]) for j in range(all_factors.shape[0]) 56 | ] 57 | x_data.append(data_mesh) 58 | 59 | x_data = np.concatenate(x_data) 60 | c_data = np.array(c_data) 61 | 62 | return x_data, c_data 63 | 64 | 65 | def load_mesh(filename): 66 | """Parses a single source file and rescales contained images.""" 67 | with tf.io.gfile.GFile(filename, "rb") as f: 68 | mesh = np.einsum("abcde->deabc", sio.loadmat(f)["im"]) 69 | flattened_mesh = mesh.reshape((-1,) + mesh.shape[2:]) 70 | rescaled_mesh = np.zeros((flattened_mesh.shape[0], 64, 64, 3)) 71 | for i in range(flattened_mesh.shape[0]): 72 | flattened_im = flattened_mesh[i, :, :, :] 73 | pic = PIL.Image.fromarray(flattened_im) 74 | pic.thumbnail((64, 64), PIL.Image.ANTIALIAS) 75 | np_pic = np.array(pic) 76 | rescaled_mesh[i, :, :, :] = np_pic 77 | return rescaled_mesh * 1. 
/ 255 78 | 79 | 80 | # =========================================================================== 81 | # Task DEFINITIONS 82 | # =========================================================================== 83 | 84 | def get_elevation_full(x_data, c_data): 85 | label_fn = lambda c: c[0] 86 | return get_task_data(x_data, c_data, label_fn, filter_fn=None) 87 | 88 | 89 | def get_all_concepts(x_data, c_data): 90 | label_fn = lambda c: c 91 | return get_task_data(x_data, c_data, label_fn, filter_fn=None) 92 | 93 | 94 | # =========================================================================== 95 | # Define task function lookups 96 | # =========================================================================== 97 | 98 | CARS3D_TASKS = { 99 | 'all_concepts': get_all_concepts, 100 | 'elevation_full': get_elevation_full, 101 | "bin_elevation": built_task_fn(lambda c: int(c[0] >= 2)), 102 | } 103 | -------------------------------------------------------------------------------- /concepts_xai/datasets/dSprites.py: -------------------------------------------------------------------------------- 1 | ''' 2 | See https://github.com/deepmind/dsprites-dataset/blob/master/dsprites_reloading_example.ipynb 3 | for a nice overview. 4 | 5 | 6 latent factors: color, shape, scale, rotation and position (x and y) 6 | 'latents_sizes': [1, 3, 6, 40, 32, 32] 7 | ''' 8 | 9 | import numpy as np 10 | 11 | from .latentFactorData import LatentFactorData, get_task_data 12 | 13 | ################################################################################ 14 | ## GLOBAL VARIABLES 15 | ################################################################################ 16 | 17 | CONCEPT_NAMES = [ 18 | 'shape', 19 | 'scale', 20 | 'rotation', 21 | 'x_pos', 22 | 'y_pos', 23 | ] 24 | 25 | CONCEPT_N_VALUES = [ 26 | 3, # [square, ellipse, heart] 27 | 6, # np.linspace(0.5, 1, 6) 28 | 40, # 40 values in {0, 1, ..., 39} representing angles in [0, 2 * pi] 29 | 32, # 32 values in {0, 2, 3, ..., 31} representing coordinates in [0, 1] 30 | 32, # 32 values in {0, 2, 3, ..., 31} representing coordinates in [0, 1] 31 | ] 32 | 33 | 34 | ################################################################################ 35 | ## DATASET LOADER 36 | ################################################################################ 37 | 38 | class dSprites(LatentFactorData): 39 | 40 | def __init__( 41 | self, 42 | dataset_path, 43 | task='shape_scale_small_skip', 44 | train_size=0.85, 45 | random_state=None, 46 | ): 47 | ''' 48 | :param str dataset_path: path to the .npz dsprites file. 49 | :param Or[ 50 | str, 51 | Function[(ndarray, ndarray), (ndarray, ndarray, ndarray) 52 | ] task: the task to use with the dataset for creating 53 | labels. If this is a string, then it must be the name of a 54 | pre-defined task in the DSPRITES_TASKS lookup table. Otherwise 55 | we expect a function that takes two np.ndarrays (x_data, c_data), 56 | corresponding to the dSprites samples and their respective concepts 57 | respectively, and produces a tuple of three np.ndarrays 58 | (x_data, c_data, y_data) corresponding to the task's 59 | samples, ground truth concept values, and labels, respectively. 60 | ''' 61 | 62 | # Note: we exclude the trivial 'color' concept 63 | if isinstance(task, str): 64 | if task not in DSPRITES_TASKS: 65 | raise ValueError( 66 | f'If the given task is a string, then it is expected to be ' 67 | f'the name of a pre-defined task in ' 68 | f'{list(DSPRITES_TASKS.keys())}. 
However, we were given ' 69 | f'"{task}" which is not a known task.' 70 | ) 71 | task_fn = DSPRITES_TASKS[task] 72 | else: 73 | task_fn = task 74 | super().__init__( 75 | dataset_path=dataset_path, 76 | task_name="dSprites", 77 | num_factors=len(CONCEPT_NAMES), 78 | sample_shape=[64, 64, 1], 79 | c_names=CONCEPT_NAMES, 80 | task_fn=task_fn, 81 | ) 82 | self._get_generators(train_size, random_state) 83 | 84 | def _load_x_c_data(self): 85 | # Load dataset 86 | dataset_zip = np.load(self.dataset_path) 87 | x_data = dataset_zip['imgs'] 88 | x_data = np.expand_dims(x_data, axis=-1) 89 | c_data = dataset_zip['latents_classes'] 90 | c_data = c_data[:, 1:] # Remove color concept 91 | return x_data, c_data 92 | 93 | 94 | ################################################################################ 95 | # TASK DEFINITIONS 96 | ################################################################################ 97 | 98 | def cardinality_encoding(card_group_1, card_group_2): 99 | result_to_encoding = {} 100 | for i in card_group_1: 101 | for j in card_group_2: 102 | result_to_encoding[(i, j)] = len(result_to_encoding) 103 | return result_to_encoding 104 | 105 | 106 | def small_skip_ranges_filter_fn(concept): 107 | ''' 108 | Filter out certain values only 109 | ''' 110 | ranges = [ 111 | list(range(3)), 112 | list(range(6)), 113 | list(range(0, 40, 5)), 114 | list(range(0, 32, 2)), 115 | list(range(0, 32, 2)), 116 | ] 117 | return all([ 118 | (concept[i] in ranges[i]) for i in range(len(ranges)) 119 | ]) 120 | 121 | 122 | def get_shape_full(x_data, c_data): 123 | return get_task_data( 124 | x_data=x_data, 125 | c_data=c_data, 126 | label_fn=lambda c_data: c_data[0], 127 | ) 128 | 129 | 130 | def get_shape_small_skip(x_data, c_data): 131 | return get_task_data( 132 | x_data=x_data, 133 | c_data=c_data, 134 | label_fn=lambda c_data: c_data[0], 135 | filter_fn=small_skip_ranges_filter_fn, 136 | ) 137 | 138 | 139 | def get_shape_scale_full(x_data, c_data): 140 | return get_task_data( 141 | x_data=x_data, 142 | c_data=c_data, 143 | label_fn=lambda c_data: ( 144 | c_data[0] * CONCEPT_N_VALUES[1] + c_data[1] 145 | ), 146 | ) 147 | 148 | 149 | def get_shape_scale_small_skip(x_data, c_data): 150 | label_remap = cardinality_encoding( 151 | list(range(3)), 152 | list(range(6)), 153 | ) 154 | return get_task_data( 155 | x_data=x_data, 156 | c_data=c_data, 157 | label_fn=lambda c_data: label_remap[(c_data[0], c_data[1])], 158 | filter_fn=small_skip_ranges_filter_fn, 159 | ) 160 | 161 | 162 | ################################################################################ 163 | # TASK FUNCTION LOOKUP TABLE 164 | ################################################################################ 165 | 166 | DSPRITES_TASKS = { 167 | "shape_full": get_shape_full, 168 | "shape_scale_full": get_shape_scale_full, 169 | "shape_scale_small_skip": get_shape_scale_small_skip, 170 | "shape_small_skip": get_shape_small_skip, 171 | } 172 | -------------------------------------------------------------------------------- /concepts_xai/datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from . 
import dSprites, shapes3D, smallNorb, cars3D 4 | 5 | ''' 6 | A primer on bases: 7 | Assume you have a vector A = (x, y, z) whose dimensions are in bases 8 | (a, b, c). Then, to convert A into a single decimal index, we compute: 9 | 10 | D = (z * 1) + (y * c) + (x * b * c) 11 | 12 | Alternatively, we can define the bases vector B = [b*c, c, 1] and compute D = B.A 13 | 14 | Example: 15 | (1, 0, 1) in bases (2, 2, 2) (i.e. in binary): 16 | D = (2^0 * 1) + (2^1 * 0) + (2^2 * 1) = 1 + 4 = 5 17 | ''' 18 | 19 | 20 | def get_latent_bases(latent_sizes): 21 |     return np.concatenate( 22 |         [latent_sizes[::-1].cumprod()[::-1][1:], np.array([1, ])] 23 |     ) 24 | 25 | 26 | # Convert a concept-based index into a single index 27 | def latent_to_index(latents, latents_bases): 28 |     return np.dot(latents, latents_bases).astype(int) 29 | 30 | # Note: dSprites and shapes3D expose their concept name lists as CONCEPT_NAMES 31 | dataset_concept_names = { 32 |     "dSprites": dSprites.CONCEPT_NAMES, 33 |     "shapes3d": shapes3D.CONCEPT_NAMES, 34 |     "smallNorb": smallNorb.SMALLNORB_concept_names, 35 |     "cars3D": cars3D.CARS_concept_names 36 | } -------------------------------------------------------------------------------- /concepts_xai/datasets/latentFactorData.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from sklearn.model_selection import train_test_split 4 | 5 | 6 | class LatentFactorData(object): 7 |     """ 8 |     Abstract class for datasets with known ground truth factors 9 |     """ 10 |     def __init__( 11 |         self, 12 |         dataset_path, 13 |         task_name, 14 |         num_factors, 15 |         sample_shape, 16 |         c_names, 17 |         task_fn, 18 |     ): 19 |         self.dataset_path = dataset_path 20 |         self.task_name = task_name 21 |         self.num_factors = num_factors 22 |         self.sample_shape = sample_shape 23 |         self.c_names = c_names 24 |         self.task_fn = task_fn 25 |         self._has_generators = False 26 | 27 |     def _load_x_c_data(self): 28 |         raise NotImplementedError( 29 |             "Need to implement code for loading input and concept data" 30 |         ) 31 | 32 |     def get_concept_values(self): 33 |         if not self._has_generators: 34 |             self._get_generators() 35 |         return self.n_c_vals_list 36 | 37 |     def _get_generators(self, train_size=0.85, random_state=42): 38 | 39 |         x_data, c_data = self._load_x_c_data() 40 |         x_data, c_data, y_data = self.task_fn(x_data, c_data) 41 |         x_data = x_data.astype(np.float32) 42 |         self.n_classes = len(np.unique(y_data)) 43 |         self.n_c_vals_list = [ 44 |             len(np.unique(c_data[:, i])) for i in range(c_data.shape[1]) 45 |         ] 46 | 47 |         # Create dictionary of 48 |         # concept_id --> dictionary: consecutive_id : original_id 49 |         self.cid_new_to_old = [] 50 |         for i in range(c_data.shape[1]): 51 |             unique_vals, unique_inds = np.unique( 52 |                 c_data[:, i], 53 |                 return_inverse=True, 54 |             ) 55 |             self.cid_new_to_old.append( 56 |                 {j: unique_vals[j] for j in range(len(unique_vals))} 57 |             ) 58 |             c_data[:, i] = unique_inds 59 | 60 |         ( 61 |             self.x_train, 62 |             self.x_test, 63 |             self.y_train, 64 |             self.y_test, 65 |             self.c_train, 66 |             self.c_test, 67 |         ) = train_test_split( 68 |             x_data, 69 |             y_data, 70 |             c_data, 71 |             train_size=train_size, 72 |             random_state=random_state, 73 |         ) 74 | 75 |         self.n_train_samples = self.c_train.shape[0] 76 |         self.n_test_samples = self.c_test.shape[0] 77 |         self.data_gen_train = tf.data.Dataset.from_tensor_slices( 78 |             (self.x_train, self.c_train, self.y_train) 79 |         ) 80 |         self.data_gen_test = tf.data.Dataset.from_tensor_slices( 81 |             (self.x_test, self.c_test, self.y_test) 82 |         ) 83 |         self._has_generators = True 84 | 85 |     def 
load_data(self): 86 |         return self.data_gen_train, self.data_gen_test, self.c_names 87 | 88 | 89 | # =========================================================================== 90 | #                   Task definition utility functions 91 | # =========================================================================== 92 | 93 | def built_task_fn(label_fn, filter_fn=None): 94 |     def task_fn(x_data, c_data): 95 |         return get_task_data(x_data, c_data, label_fn, filter_fn=filter_fn) 96 |     return task_fn 97 | 98 | 99 | def get_task_data(x_data, c_data, label_fn, filter_fn=None): 100 | 101 |     if filter_fn is not None: 102 |         ids = np.array([filter_fn(c) for c in c_data]) 103 |         ids = np.where(ids)[0] 104 |         c_data = c_data[ids] 105 |         x_data = x_data[ids] 106 | 107 |     y_data = np.array([label_fn(c) for c in c_data]) 108 | 109 |     return x_data, c_data, y_data 110 | 111 | 112 | 113 | -------------------------------------------------------------------------------- /concepts_xai/datasets/load_paths.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | 4 | def load_dataset_paths(config_path): 5 |     ''' 6 |     :param config_path: Path to .yml file containing the individual dataset 7 |         paths 8 |     :return: Dictionary of dataset_name --> dataset_path 9 |     ''' 10 | 11 |     dataset_paths_dict = {} 12 | 13 |     with open(config_path) as config_file: 14 |         data = yaml.load(config_file, Loader=yaml.FullLoader) 15 | 16 |         dataset_paths_dict["dSprites"] = data['dsprites_path'] 17 |         dataset_paths_dict["cars3D"] = data['cars3D_path'] 18 |         dataset_paths_dict["smallNorb"] = data['smallNorb_path'] 19 |         dataset_paths_dict["shapes3d"] = data['shapes3d_path'] 20 | 21 |     return dataset_paths_dict 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /concepts_xai/datasets/shapes3D.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import numpy as np 3 | import h5py 4 | 5 | from .latentFactorData import LatentFactorData, get_task_data 6 | 7 | ################################################################################ 8 | ## GLOBAL VARIABLES 9 | ################################################################################ 10 | 11 | 12 | CONCEPT_NAMES = [ 13 |     'floor_hue', 14 |     'wall_hue', 15 |     'object_hue', 16 |     'scale', 17 |     'shape', 18 |     'orientation', 19 | ] 20 | CONCEPT_N_VALUES = [ 21 |     10, 22 |     10, 23 |     10, 24 |     8, 25 |     4, 26 |     15, 27 | ] 28 | 29 | 30 | ################################################################################ 31 | ## DATASET LOADER 32 | ################################################################################ 33 | 34 | class shapes3D(LatentFactorData): 35 | 36 |     def __init__( 37 |         self, 38 |         dataset_path, 39 |         task='shape_full', 40 |         train_size=0.85, 41 |         random_state=None, 42 |     ): 43 |         ''' 44 |         :param dataset_path: path to the 3dshapes .h5 file 45 |         :param Or[ 46 |             str, 47 |             Function[(ndarray, ndarray), (ndarray, ndarray, ndarray)] 48 |         ] task: the task to use with the dataset for creating 49 |             labels. If this is a string, then it must be the name of a 50 |             pre-defined task in the SHAPES3D_TASKS lookup table. Otherwise 51 |             we expect a function that takes two np.ndarrays (x_data, c_data), 52 |             corresponding to the shapes3D samples and their respective 53 |             concepts, and produces a tuple of three np.ndarrays 54 |             (x_data, c_data, y_data) corresponding to the task's 55 |             samples, ground truth concept values, and labels, respectively.
56 |         ''' 57 | 58 |         if isinstance(task, str): 59 |             if task not in SHAPES3D_TASKS: 60 |                 raise ValueError( 61 |                     f'If the given task is a string, then it is expected to be ' 62 |                     f'the name of a pre-defined task in ' 63 |                     f'{list(SHAPES3D_TASKS.keys())}. However, we were given ' 64 |                     f'"{task}" which is not a known task.' 65 |                 ) 66 |             task_fn = SHAPES3D_TASKS[task] 67 |         else: 68 |             task_fn = task 69 | 70 |         super().__init__( 71 |             dataset_path=dataset_path, 72 |             task_name="3dshapes", 73 |             num_factors=len(CONCEPT_NAMES), 74 |             sample_shape=[64, 64, 3], 75 |             c_names=CONCEPT_NAMES, 76 |             task_fn=task_fn, 77 |         ) 78 |         self._get_generators(train_size, random_state) 79 | 80 |     def _load_x_c_data(self): 81 |         # Get concept data 82 |         latent_size_lists = [list(np.arange(i)) for i in CONCEPT_N_VALUES] 83 |         c_data = np.array(list(itertools.product(*latent_size_lists))) 84 |         # Load image data 85 |         with h5py.File(self.dataset_path, 'r') as hf: 86 |             x_data = np.array(hf.get('images')) / 255. 87 | 88 | 89 |         return x_data, c_data 90 | 91 | 92 | ################################################################################ 93 | #                           TASK DEFINITIONS 94 | ################################################################################ 95 | 96 | def small_skip_ranges_filter_fn(concept): 97 |     ''' 98 |     Keep only samples whose concept values fall within the reduced ranges 99 |     ''' 100 |     ranges = [ 101 |         list(range(0, 10, 2)), 102 |         list(range(0, 10, 2)), 103 |         list(range(0, 10, 2)), 104 |         list(range(0, 8, 2)), 105 |         list(range(4)), 106 |         list(range(0, 15, 2)), 107 |     ] 108 | 109 |     return all([(concept[i] in ranges[i]) for i in range(len(ranges))]) 110 | 111 | 112 | def get_shape_full(x_data, c_data): 113 |     return get_task_data( 114 |         x_data=x_data, 115 |         c_data=c_data, 116 |         label_fn=lambda x: x[4], 117 |     ) 118 | 119 | 120 | def get_shape_small_skip(x_data, c_data): 121 |     return get_task_data( 122 |         x_data=x_data, 123 |         c_data=c_data, 124 |         label_fn=lambda x: x[4], 125 |         filter_fn=small_skip_ranges_filter_fn, 126 |     ) 127 | 128 | 129 | def get_reduced_filter_fn(): 130 |     ranges = [ 131 |         list(range(0, 10, 2)), 132 |         list(range(0, 10, 2)), 133 |         list(range(0, 10, 2)), 134 |         list(range(0, 8, 2)), 135 |         list(range(4)), 136 |         list(range(15)), 137 |     ] 138 | 139 |     def filter_fn(concept): 140 |         return all([(concept[i] in ranges[i]) for i in range(len(ranges))]) 141 | 142 |     return filter_fn 143 | 144 | 145 | def get_reduced_shapes3d(x_data, c_data): 146 | 147 |     ranges = [ 148 |         list(range(0, 10, 2)), 149 |         list(range(0, 10, 2)), 150 |         list(range(0, 10, 2)), 151 |         list(range(0, 8, 2)), 152 |         list(range(4)), 153 |         list(range(15)), 154 |     ] 155 | 156 |     # The label is the shape concept (index 4), as in get_shape_full above 157 |     label_fn = lambda c: c[4] 158 | 159 |     def filter_fn(concept): 160 |         return all([(concept[i] in ranges[i]) for i in range(len(ranges))]) 161 | 162 |     return get_task_data( 163 |         x_data=x_data, 164 |         c_data=c_data, 165 |         label_fn=label_fn, 166 |         filter_fn=filter_fn, 167 |     ) 168 | 169 | ################################################################################ 170 | #                        TASK FUNCTION LOOKUP TABLE 171 | ################################################################################ 172 | 173 | SHAPES3D_TASKS = { 174 |     "reduced_shapes3d" : get_reduced_shapes3d, 175 |     "shape_full" : get_shape_full, 176 |     "shape_small_skip" : get_shape_small_skip, 177 | } 178 | 179 | 180 | -------------------------------------------------------------------------------- /concepts_xai/datasets/smallNorb.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import PIL 4 | import numpy as 
np 5 | import tensorflow as tf 6 | 7 | from .latentFactorData import LatentFactorData, get_task_data 8 | 9 | SMALLNORB_concept_names = [ 10 | 'category', 11 | 'instance', 12 | 'elevation', 13 | 'azimuth', 14 | 'lighting', 15 | ] 16 | SMALLNORB_concept_n_vals = [5, 10, 9, 18, 6] 17 | 18 | 19 | class SmallNorb(LatentFactorData): 20 | 21 | def __init__( 22 | self, 23 | dataset_path, 24 | task_name='category_full', 25 | train_size=0.85, 26 | random_state=42, 27 | ): 28 | ''' 29 | :param dataset_path: path to the smallnorb files directory 30 | :param task_name: the task to use with the dataset for creating labels 31 | ''' 32 | 33 | super().__init__( 34 | dataset_path=dataset_path, 35 | task_name=task_name, 36 | num_factors=5, 37 | sample_shape=[64, 64, 1], 38 | c_names=SMALLNORB_concept_names, 39 | task_fn=SMALLNORB_TASKS[task_name], 40 | ) 41 | self._get_generators(train_size, random_state) 42 | 43 | def _load_x_c_data(self): 44 | files_dir = self.dataset_path 45 | filename_template = "smallnorb-{}-{}.mat" 46 | splits = [ 47 | "5x46789x9x18x6x2x96x96-training", 48 | "5x01235x9x18x6x2x96x96-testing", 49 | ] 50 | x_datas, c_datas = [], [] 51 | 52 | for i, split in enumerate(splits): 53 | data_fname = os.path.join( 54 | files_dir, 55 | filename_template.format(splits[i], 'dat'), 56 | ) 57 | cat_fname = os.path.join( 58 | files_dir, 59 | filename_template.format(splits[i], 'cat'), 60 | ) 61 | info_fname = os.path.join( 62 | files_dir, 63 | filename_template.format(splits[i], 'info'), 64 | ) 65 | 66 | x_data = _read_binary_matrix(data_fname) 67 | # Resize data, and only retain data from 1 camera 68 | x_data = _resize_images(x_data[:, 0]) 69 | c_cat = _read_binary_matrix(cat_fname) 70 | c_info = _read_binary_matrix(info_fname) 71 | c_info = np.copy(c_info) 72 | # Set azimuth values to be consecutive digits 73 | c_info[:, 2:3] = c_info[:, 2:3] / 2 74 | c_data = np.column_stack((c_cat, c_info)) 75 | 76 | x_datas.append(x_data) 77 | c_datas.append(c_data) 78 | 79 | x_data = np.concatenate(x_datas) 80 | x_data = np.expand_dims(x_data, axis=-1) 81 | c_data = np.concatenate(c_datas) 82 | 83 | return x_data, c_data 84 | 85 | 86 | def _resize_images(integer_images): 87 | resized_images = np.zeros((integer_images.shape[0], 64, 64)) 88 | for i in range(integer_images.shape[0]): 89 | image = PIL.Image.fromarray(integer_images[i, :, :]) 90 | image = image.resize((64, 64), PIL.Image.ANTIALIAS) 91 | resized_images[i, :, :] = image 92 | return resized_images / 255. 
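# Note on the binary layout parsed by `_read_binary_matrix` below: each
# smallNORB ".mat" file begins with a 4-byte magic number encoding the
# element dtype (see `dtype_map`), followed by a 4-byte dimension count
# `ndim` and max(3, ndim) 4-byte dimension sizes; the raw array data
# follows immediately afterwards, hence the read at offset 8 + eff_dim * 4.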
93 | 94 | 95 | def _read_binary_matrix(filename): 96 | """Reads and returns binary formatted matrix stored in filename.""" 97 | with tf.io.gfile.GFile(filename, "rb") as f: 98 | s = f.read() 99 | magic = int(np.frombuffer(s, "int32", 1)) 100 | ndim = int(np.frombuffer(s, "int32", 1, 4)) 101 | eff_dim = max(3, ndim) 102 | raw_dims = np.frombuffer(s, "int32", eff_dim, 8) 103 | dims = [] 104 | for i in range(0, ndim): 105 | dims.append(raw_dims[i]) 106 | 107 | dtype_map = { 108 | 507333717: "int8", 109 | 507333716: "int32", 110 | 507333713: "float", 111 | 507333715: "double" 112 | } 113 | data = np.frombuffer(s, dtype_map[magic], offset=8 + eff_dim * 4) 114 | data = data.reshape(tuple(dims)) 115 | return data 116 | 117 | 118 | # =========================================================================== 119 | # Task DEFINITIONS 120 | # =========================================================================== 121 | 122 | def get_category_full(x_data, c_data): 123 | label_fn = lambda c: c[0] 124 | return get_task_data(x_data, c_data, label_fn, filter_fn=None) 125 | 126 | 127 | # =========================================================================== 128 | # Define task function lookups 129 | # =========================================================================== 130 | 131 | SMALLNORB_TASKS = { 132 | 'category_full': get_category_full, 133 | } 134 | -------------------------------------------------------------------------------- /concepts_xai/datasets/tabular_toy.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Implements the tabular toy dataset defined by Mahinpei et al. in "Promises and 3 | Pitfalls of Black-Box Concept Learning Models" (found in 4 | https://arxiv.org/abs/2106.13314). 5 | 6 | ''' 7 | 8 | import numpy as np 9 | from .latentFactorData import LatentFactorData 10 | 11 | 12 | ################################################################################ 13 | ## GLOBAL VARIABLES 14 | ################################################################################ 15 | 16 | CONCEPT_NAMES = [ 17 | 'x_pos', 18 | 'y_pos', 19 | 'z_pos', 20 | ] 21 | 22 | CONCEPT_N_VALUES = [ 23 | 1, 24 | 1, 25 | 1, 26 | ] 27 | 28 | 29 | ################################################################################ 30 | ## DATASET LOADER 31 | ################################################################################ 32 | 33 | class TabularToy(LatentFactorData): 34 | 35 | def __init__( 36 | self, 37 | num_samples, 38 | cov=None, 39 | num_concepts=len(CONCEPT_NAMES), 40 | train_size=0.85, 41 | random_state=None, 42 | task=None, 43 | ): 44 | """ 45 | This dataset has 7 features and 3 concepts. It is generated by 3 latent 46 | variables (x, y, z), randomly sampled from a multivariate normal 47 | distribution (with the provided covariance matrix), which define all 48 | seven features of a sample as: 49 | 50 | Feature 0: np.sin(x_vars) + x_vars 51 | Feature 1: np.cos(x_vars) + x_vars 52 | Feature 2: np.sin(y_vars) + y_vars 53 | Feature 3: np.cos(y_vars) + y_vars 54 | Feature 4 np.sin(z_vars) + z_vars 55 | Feature 5: np.cos(z_vars) + z_vars 56 | Feature 6: x_vars**2 + y_vars**2 + z_vars**2 57 | 58 | As concepts, we provide by default a vector of three values 59 | [x_pos, y_pos,z_pos] indicating whether each concept is positive or not. 60 | BY default, if a task is not provided, then we assign as a label to each 61 | sample whether or not two or more of its latent factors are positive. 
62 | 63 | :param int num_samples: The total number of samples we will generate 64 | for this dataset. 65 | :param Or[np.ndarray, float] cov: A covariance matrix to use for 66 | sampling the latent variables. If provided as a float, then we will 67 | assign cov as the covariance between any two distinct latent 68 | factors. If not given, then we will use the identity matrix. 69 | :param int num_concepts: A number in {1, 2, 3} indicating how many 70 | concepts we want to include in our concept representation. If less 71 | than 3, then we will trim the concept list from the right. 72 | :param float train_size: A number in [0,1] indicating which fraction of 73 | the entire dataset will be used for training and which fraction will 74 | be used for testing. 75 | :param Function[(ndarray, ndarray), (ndarray, ndarray, ndarray)] task: 76 | The task to use with the dataset for creating 77 | labels. We expect a function that takes two np.ndarrays 78 | (x_data, c_data) corresponding to the dSprites samples and their 79 | respective concepts respectively, and produces a tuple of three 80 | np.ndarrays (x_data, c_data, y_data) corresponding to the task's 81 | samples, ground truth concept values, and labels, respectively. 82 | """ 83 | 84 | if task is None: 85 | task = lambda x, c: ( 86 | x, 87 | c, 88 | (np.sum(c, axis=-1) > 1).astype(np.int32), 89 | ) 90 | super().__init__( 91 | dataset_path=None, 92 | task_name="tabular_toy", 93 | num_factors=num_concepts, 94 | sample_shape=[7], 95 | c_names=CONCEPT_NAMES, 96 | task_fn=task, 97 | ) 98 | self.num_samples = num_samples 99 | self.cov = np.eye(3) if cov is None else cov 100 | if isinstance(self.cov, (float, int)): 101 | self.cov = np.array([ 102 | [1, self.cov, self.cov], 103 | [self.cov, 1, self.cov], 104 | [self.cov, self.cov, 1], 105 | ]) 106 | 107 | self._get_generators(train_size, random_state) 108 | 109 | def _load_x_c_data(self): 110 | # Sample the x, y, and z variables 111 | latent_vars = np.random.multivariate_normal( 112 | mean=[0, 0, 0], 113 | cov=self.cov, 114 | size=(self.num_samples,), 115 | ) 116 | x_vars = latent_vars[:, 0] 117 | y_vars = latent_vars[:, 1] 118 | z_vars = latent_vars[:, 2] 119 | 120 | # The features are just non-linear functions applied to each 121 | # variable 122 | features = [ 123 | np.sin(x_vars) + x_vars, 124 | np.cos(x_vars) + x_vars, 125 | np.sin(y_vars) + y_vars, 126 | np.cos(y_vars) + y_vars, 127 | np.sin(z_vars) + z_vars, 128 | np.cos(z_vars) + z_vars, 129 | x_vars**2 + y_vars**2 + z_vars**2, 130 | ] 131 | features = np.stack(features, axis=1) 132 | 133 | # The concepts just check if the variables are positive 134 | x_pos = (x_vars > 0).astype(np.int32) 135 | y_pos = (y_vars > 0).astype(np.int32) 136 | z_pos = (z_vars > 0).astype(np.int32) 137 | concepts = np.squeeze( 138 | np.stack([x_pos, y_pos, z_pos][:self.num_factors], axis=1) 139 | ) 140 | 141 | # And that's it buds 142 | return features, concepts 143 | -------------------------------------------------------------------------------- /concepts_xai/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmitrykazhdan/concept-based-xai/a86c6abd609ec2ef0cf56cb556e3f834a0962bdc/concepts_xai/evaluation/__init__.py -------------------------------------------------------------------------------- /concepts_xai/evaluation/metrics/accuracy.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import accuracy_score 2 | 3 | 4 | def 
compute_accuracies(c_true, c_pred): 5 |     ''' 6 |     Compute the accuracy scores per concept 7 |     :param c_true: Numpy array of (n_samples, n_concepts) of ground-truth 8 |         concept values 9 |     :param c_pred: Numpy array of (n_samples, n_concepts) of predicted concept 10 |         values 11 |     :return: Per-concept accuracies, computed across all samples 12 |     ''' 13 |     n_concepts = c_true.shape[1] 14 |     return [ 15 |         accuracy_score(c_true[:, i], c_pred[:, i]) for i in range(n_concepts) 16 |     ] 17 | -------------------------------------------------------------------------------- /concepts_xai/evaluation/metrics/completeness.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Implementation of metrics used for computing the completeness/explicitness 3 | of a given set of concepts. 4 | These metrics are inspired by Yeh et al. "On Completeness-aware Concept-Based 5 | Explanations in Deep Neural Networks" seen in https://arxiv.org/abs/1910.07969v5 6 | ''' 7 | 8 | import numpy as np 9 | import sklearn 10 | import tensorflow as tf 11 | 12 | from sklearn.model_selection import train_test_split 13 | 14 | 15 | ################################################################################ 16 | ## Helper Functions 17 | ################################################################################ 18 | 19 | def _get_default_model(num_concepts, num_hidden_acts, end_activation=None): 20 |     """ 21 |     Helper function that returns a simple ReLU MLP with a single 500-unit 22 |     hidden layer. 23 | 24 |     Used as the default estimator for computing the completeness score. 25 | 26 |     :param int num_concepts: The number of concept vectors we have in our setup. 27 |     :param int num_hidden_acts: The dimensionality of the concept vectors. 28 | 29 |     :returns tf.keras.Model: The default model used for reconstructing a set of 30 |         intermediate hidden activations from a set of concepts scores. 31 |     """ 32 | 33 |     return tf.keras.models.Sequential([ 34 |         tf.keras.layers.Dense( 35 |             500, 36 |             input_dim=num_concepts, 37 |             activation='relu' 38 |         ), 39 |         tf.keras.layers.Dense( 40 |             num_hidden_acts, 41 |             activation=end_activation, 42 |         ), 43 |     ]) 44 | 45 | 46 | ################################################################################ 47 | ## Concept Score Functions 48 | ################################################################################ 49 | 50 | 51 | def dot_prod_concept_score( 52 |     features, 53 |     concept_vectors, 54 |     epsilon=1e-5, 55 |     beta=1e-5, 56 |     channels_axis=-1, 57 | ): 58 |     """ 59 |     Returns a vector of concept scores for the given features using a normalized 60 |     dot product similarity as in Yeh et al. 61 | 62 |     :param np.ndarray features: An array of samples with shape 63 |         (n_samples, ..., n_features, ...), where the dimension at axis 64 |         channels_axis has `n_features` entries. 65 |     :param np.ndarray concept_vectors: A 2D array of shape 66 |         (num_concepts, n_features) where each row represents a given 67 |         meaningful concept direction. 68 |     :param float beta: A value used for zeroing dot products that are 69 |         considered to be irrelevant. If a dot product is less than beta, then 70 |         its score will be zero. 71 |     :param float epsilon: A value used for numerical stability to avoid 72 |         division by zero. 73 |     :param int channels_axis: the channels axis in the input features. If not 74 |         given, then we assume it is the last dimension always.
75 | 76 | :returns np.ndarray: A 2D matrix with shape (n_samples, n_concepts) where 77 | the (i, j)-th entry represents the score that the j-th concept assigned 78 | the i-th sample in `features`. 79 | """ 80 | # First check that all the dimensions make sense 81 | assert features.shape[channels_axis] == concept_vectors.shape[-1], ( 82 | f'Expected input to have {concept_vectors.shape[-1]} elements in its ' 83 | f'channels axis (defined as axis {channels_axis}). ' 84 | f'Instead, we found the input to have shape {features.shape}.' 85 | ) 86 | # First normalize all concepts across their channels dimension 87 | concept_vectors_norm = np.linalg.norm( 88 | concept_vectors, 89 | axis=-1, 90 | keepdims=True, 91 | ) 92 | concept_vectors = concept_vectors / (concept_vectors_norm + epsilon) 93 | 94 | # For simplicity, we will always move the channels dimension to the end 95 | if channels_axis not in [-1, len(features.shape) - 1]: 96 | # Then perform a transpose in here 97 | perm = list(range(len(features.shape))) 98 | perm[channels_axis] = len(features.shape) - 1 99 | perm[len(features.shape) - 1] = channels_axis 100 | features = np.transpose(features, perm) 101 | 102 | # And similarly, do the same for the input features 103 | x_norm = np.linalg.norm( 104 | features, 105 | axis=-1, 106 | keepdims=True, 107 | ) 108 | 109 | x_norm = features / (x_norm + epsilon) 110 | 111 | # Compute the concept probabilities accordingly 112 | concept_prob = np.dot(features, concept_vectors.transpose()) 113 | concept_prob_norm = np.dot(x_norm, concept_vectors.transpose()) 114 | 115 | # Threshold scores accordingly using the normalized scores 116 | if beta is not None: 117 | concept_prob = concept_prob * (concept_prob_norm > beta) 118 | 119 | # And end by normalizing them 120 | norm = np.sum(concept_prob, axis=-1, keepdims=True) 121 | result = concept_prob / (norm + epsilon) 122 | 123 | # Finally, restore the shape of the output tensor if a transpose was 124 | # done at the beginning 125 | if channels_axis not in [-1, len(features.shape) - 1]: 126 | perm = list(range(len(features.shape))) 127 | perm[channels_axis] = len(features.shape) - 1 128 | perm[len(features.shape) - 1] = channels_axis 129 | result = np.transpose(result, perm) 130 | 131 | # And that's all folks 132 | return result 133 | 134 | 135 | ################################################################################ 136 | ## Completeness Score Computation 137 | ################################################################################ 138 | 139 | def completeness_score( 140 | X, 141 | y, 142 | features_to_concepts_fn, 143 | concepts_to_labels_model, 144 | concept_vectors, 145 | task_loss, 146 | g_model=None, 147 | test_size=0.2, 148 | concept_score_fn=dot_prod_concept_score, 149 | predictor_train_kwags=None, 150 | g_optimizer='adam', 151 | acc_fn=sklearn.metrics.accuracy_score, 152 | channels_axis=-1, 153 | ): 154 | """ 155 | Returns the completeness score for the given set of concept vectors 156 | `concept_vectors` using testing data `X` with labels `y`. This score 157 | is computed using Yeh et al.'s definition of a concept completeness 158 | score based on a model `features_to_concepts_fn`, which maps input features 159 | in the test data to a M-dimensional space, and a model 160 | `concepts_to_labels_model` which maps M-dimensional vectors to some 161 | probability distribution over classes in `y`. 
162 | 163 | :param np.ndarray X: A tensor of testing samples that are in the domain of 164 | given function `features_to_concepts_fn` where the first dimension 165 | represents the number of test samples (which we call `n_samples`). 166 | :param np.ndarray y: A tensor of testing labels corresponding to matrix 167 | X whose first dimension must also be `n_samples`. 168 | :param Function[(np.ndarray), np.ndarray] features_to_concepts_fn: A 169 | function mapping batches of samples with the same dimensionality as 170 | X into some (n_samples, ..., M, ...)-dimensional vector space 171 | corresponding to the same vector space as that used for the given 172 | concept vectors. In this case, M is the channels dimension at location 173 | channels_axis. 174 | :param tf.keras.Model concepts_to_labels_model: An arbitrary Keras model 175 | which maps M-dimensional vectors (as those produced by calling the 176 | `features_to_concepts_fn` function on a batch of samples) into a 177 | probability distribution over labels in `y`. 178 | :param np.ndarray concept_vectors: A 2D array of shape (num_concepts, M) 179 | where each column represents a given meaningful concept direction. 180 | :param tf.keras.losses.Loss task_loss: The loss function one intends to 181 | minimize when mapping instances in `X` to labels in `y`. 182 | :param tf.keras.Model g_model: The model `g` we will train for mapping 183 | concept scores to the same M-dimensional space when computing 184 | the concept completeness score. If not given, then we will use a 185 | 3-layered ReLU MLP with 500 hidden activations. 186 | :param float test_size: A value between 0 and 1 representing what percent 187 | of the (X, y) data will be used for testing the accuracy of our g_model 188 | (and the original model) when computing the completeness score. The 189 | rest of the data will be used for training our g_model. 190 | :param Function[(np.ndarray,List[np.ndarray]), np.ndarray] concept_score_fn: 191 | A function taking as an input a matrix of shape (n_samples, M), 192 | representing outputs produced by the `features_to_concepts_fn` function, 193 | and a list of`n_concepts` M-dimensional vectors, representing unit 194 | directions of meaningful concepts, and returning a vector with 195 | n_concepts concept scores. By default we use the normalized dot product 196 | scores. 197 | :param Dict[Any, Any] predictor_train_kwags: An optional set of parameters 198 | to pass to the g_model when trained for reconstructing the M-dimensional 199 | activations from their corresponding concept scores. 200 | :param tf.keras.optimizers.Optimizer g_optimizer: The optimizer used for 201 | training the g model for the reconstruction. By default we will use an 202 | ADAM optimizer. 203 | :param Function[(np.ndarray, np.ndarray), float] acc_fn: An accuracy 204 | function taking (true_labels, predicted_labels) and returning an 205 | accuracy value between 0 and 1. 206 | :param int channels_axis: The channels dimension axis of the output of the 207 | features_to_concepts function. If not given, then it is assumed to be 208 | the last dimension. 209 | 210 | :returns Tuple[float, tf.keras.Model]: A tuple (score, g_model) containing 211 | the computed completeness score together with the resulting trained 212 | g_model. 
213 | """ 214 | # Let's first start by splitting our data into a training and a testing 215 | # set 216 | X_train, X_test, y_train, y_test = train_test_split( 217 | X, 218 | y, 219 | test_size=test_size, 220 | ) 221 | 222 | # Let's take a look at the intermediate activations we will be using 223 | phi_train = features_to_concepts_fn(X_train) 224 | scores_train = concept_score_fn(phi_train, concept_vectors) 225 | num_labels = len(set(y)) 226 | 227 | # Compute some useful variables while also handling the default case for 228 | # the model we will optimize over 229 | num_concepts = concept_vectors.shape[0] 230 | num_hidden_acts = phi_train.shape[channels_axis] 231 | n_samples = X_train.shape[0] 232 | g_model = g_model or _get_default_model( 233 | num_concepts=num_concepts, 234 | num_hidden_acts=num_hidden_acts, 235 | ) 236 | predictor_train_kwags = predictor_train_kwags or { 237 | 'epochs': 50, 238 | 'batch_size': min(16, n_samples), 239 | 'verbose': 0, 240 | } 241 | 242 | # Construct a model that we can use for optimizing our g function 243 | # For this, we will first need to make sure that we set our concepts 244 | # to labels model so that we do not optimize over its parameters 245 | prev_trainable = concepts_to_labels_model.trainable 246 | concepts_to_labels_model.trainable = False 247 | f_prime_input = tf.keras.layers.Input( 248 | shape=scores_train.shape[1:], 249 | dtype=scores_train.dtype, 250 | ) 251 | f_prime_output = concepts_to_labels_model(g_model(f_prime_input)) 252 | f_prime_optimized = tf.keras.Model( 253 | f_prime_input, 254 | f_prime_output, 255 | ) 256 | 257 | # Time to optimize it! 258 | f_prime_optimized.compile( 259 | optimizer=g_optimizer, 260 | loss=task_loss, 261 | ) 262 | f_prime_optimized.fit( 263 | scores_train, 264 | y_train, 265 | **predictor_train_kwags, 266 | ) 267 | 268 | # Don't forget to reconstruct the state of the concept to labels model 269 | concepts_to_labels_model.trainable = prev_trainable 270 | 271 | # Finally, compute the actual score by computing the accuracy of the 272 | # original concept-composable model 273 | phi_test = features_to_concepts_fn(X_test) 274 | random_pred_acc = 1 / num_labels 275 | f_preds = concepts_to_labels_model.predict( 276 | phi_test 277 | ) 278 | f_acc = acc_fn( 279 | y_test, 280 | f_preds, 281 | ) 282 | 283 | # And the accuracy of the model using the reconstruction from the 284 | # concept scores 285 | f_prime_preds = f_prime_optimized.predict( 286 | concept_score_fn(phi_test, concept_vectors) 287 | ) 288 | f_prime_acc = acc_fn( 289 | y_test, 290 | f_prime_preds, 291 | ) 292 | 293 | # That gives us everything we need 294 | if f_prime_acc == random_pred_acc: 295 | return 0, g_model 296 | completeness = (f_prime_acc - random_pred_acc) / (f_acc - random_pred_acc) 297 | return completeness, g_model 298 | 299 | 300 | def direct_completeness_score( 301 | X, 302 | y, 303 | features_to_concepts_fn, 304 | concept_vectors, 305 | task_loss, 306 | g_model=None, 307 | test_size=0.2, 308 | concept_score_fn=dot_prod_concept_score, 309 | predictor_train_kwags=None, 310 | g_optimizer='adam', 311 | acc_fn=sklearn.metrics.accuracy_score, 312 | channels_axis=-1, 313 | ): 314 | """ 315 | Returns the completeness score for the given set of concept vectors 316 | `concept_vectors` using testing data `X` with labels `y`. This score 317 | is computed as the predictive accuracy of a model trained to predict 318 | the labels using the concepts scores alone. 
It differs from the method 319 | above in that it does not require a pre-trained concept_to_labels map. 320 | 321 | :param np.ndarray X: A tensor of testing samples that are in the domain of 322 | given function `features_to_concepts_fn` where the first dimension 323 | represents the number of test samples (which we call `n_samples`). 324 | :param np.ndarray y: A tensor of testing labels corresponding to matrix 325 | X whose first dimension must also be `n_samples`. 326 | :param Function[(np.ndarray), np.ndarray] features_to_concepts_fn: A 327 | function mapping batches of samples with the same dimensionality as 328 | X into some (n_samples, ..., M, ...)-dimensional vector space 329 | corresponding to the same vector space as that used for the given 330 | concept vectors. In this case, M is the channels dimension at location 331 | channels_axis. 332 | :param np.ndarray concept_vectors: A 2D array of shape (num_concepts, M) 333 | where each column represents a given meaningful concept direction. 334 | :param tf.keras.losses.Loss task_loss: The loss function one intends to 335 | minimize when mapping instances in `X` to labels in `y`. 336 | :param tf.keras.Model g_model: The model `g` we will train for mapping 337 | concept scores to the space of labels when computing the concept 338 | completeness score. If not given, then we will use a 3-layered ReLU MLP 339 | with 500 hidden activations. 340 | :param float test_size: A value between 0 and 1 representing what percent 341 | of the (X, y) data will be used for testing the accuracy of our g_model 342 | (and the original model) when computing the completeness score. The 343 | rest of the data will be used for training our g_model. 344 | :param Function[(np.ndarray,List[np.ndarray]), np.ndarray] concept_score_fn: 345 | A function taking as an input a matrix of shape (n_samples, M), 346 | representing outputs produced by the `features_to_concepts_fn` function, 347 | and a list of`n_concepts` M-dimensional vectors, representing unit 348 | directions of meaningful concepts, and returning a vector with 349 | n_concepts concept scores. By default we use the normalized dot product 350 | scores. 351 | :param Dict[Any, Any] predictor_train_kwags: An optional set of parameters 352 | to pass to the g_model when trained for reconstructing the M-dimensional 353 | activations from their corresponding concept scores. 354 | :param tf.keras.optimizers.Optimizer g_optimizer: The optimizer used for 355 | training the g model for the reconstruction. By default we will use an 356 | ADAM optimizer. 357 | :param Function[(np.ndarray, np.ndarray), float] acc_fn: An accuracy 358 | function taking (true_labels, predicted_labels) and returning an 359 | accuracy value between 0 and 1. 360 | :param int channels_axis: The channels dimension axis of the output of the 361 | features_to_concepts function. If not given, then it is assumed to be 362 | the last dimension. 363 | 364 | :returns Tuple[float, tf.keras.Model]: A tuple (score, g_model) containing 365 | the computed completeness score together with the resulting trained 366 | g_model. 
367 |     """ 368 |     # Let's first start by splitting our data into a training and a testing 369 |     # set 370 |     X_train, X_test, y_train, y_test = train_test_split( 371 |         X, 372 |         y, 373 |         test_size=test_size, 374 |     ) 375 | 376 |     # Let's take a look at the intermediate activations we will be using 377 |     phi_train = features_to_concepts_fn(X_train) 378 |     scores_train = concept_score_fn( 379 |         phi_train, 380 |         concept_vectors, 381 |     ) 382 |     num_labels = len(set(y)) 383 | 384 |     # Compute some useful variables while also handling the default case for 385 |     # the model we will optimize over 386 |     num_concepts = concept_vectors.shape[0] 387 |     num_hidden_acts = phi_train.shape[channels_axis] 388 |     n_samples = X_train.shape[0] 389 |     g_model = g_model or _get_default_model( 390 |         num_concepts=num_concepts, 391 |         num_hidden_acts=num_labels if num_labels > 2 else 1, 392 |     ) 393 |     predictor_train_kwags = predictor_train_kwags or { 394 |         'epochs': 50, 395 |         'batch_size': min(16, n_samples), 396 |         'verbose': 0, 397 |     } 398 | 399 |     # Time to optimize it! 400 |     g_model.compile( 401 |         optimizer=g_optimizer, 402 |         loss=task_loss, 403 |     ) 404 |     g_model.fit( 405 |         scores_train, 406 |         y_train, 407 |         **predictor_train_kwags, 408 |     ) 409 | 410 |     # Finally, compute the actual score by computing the accuracy of predicting 411 |     # the output labels using only the concept scores 412 |     phi_test = features_to_concepts_fn(X_test) 413 |     from_concepts_preds = g_model.predict( 414 |         concept_score_fn(phi_test, concept_vectors) 415 |     ) 416 |     from_concepts_acc = acc_fn( 417 |         y_test, 418 |         from_concepts_preds, 419 |     ) 420 | 421 |     # That gives us everything we need 422 |     return from_concepts_acc, g_model 423 | -------------------------------------------------------------------------------- /concepts_xai/evaluation/metrics/downstream_task.py: -------------------------------------------------------------------------------- 1 | from sklearn.model_selection import train_test_split 2 | from sklearn.metrics import accuracy_score 3 | 4 | 5 | def compute_downstream_task(c_pred, y_true, predictor_model): 6 |     ''' 7 |     Computes the accuracy score for the downstream task 8 |     :param c_pred: Concept data predictions, numpy array of shape 9 |         (n_samples, n_concepts) 10 |     :param y_true: Ground-truth task label data, numpy array of 11 |         shape (n_samples,) 12 |     :param predictor_model: sklearn model to use for predicting the task labels 13 |         from the concept data 14 |     :return: Accuracy of predictor_model, trained and evaluated on the provided 15 |         concept and label data 16 |     ''' 17 |     c_train, c_test, y_train, y_test = train_test_split(c_pred, y_true) 18 |     predictor_model.fit(c_train, y_train) 19 |     # Predict the task labels from the held-out concept data 20 |     y_pred = predictor_model.predict(c_test) 21 |     return accuracy_score(y_test, y_pred) 22 | -------------------------------------------------------------------------------- /concepts_xai/evaluation/metrics/mpo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def total_mispredictions_fn(c_true, c_pred): 5 |     ''' 6 |     Count the total number of mispredictions 7 |     :param c_true: ground-truth concept values 8 |     :param c_pred: predicted concept values 9 |     :return: Number of elements i where c_true[i] != c_pred[i] 10 |     ''' 11 | 12 |     n_errs = np.sum(c_true != c_pred) 13 |     return n_errs 14 | 15 | 16 | def compute_MPO(c_true, c_pred, err_fn=total_mispredictions_fn): 17 |     ''' 18 |     Implementation of the (M)is-(P)rediction (O)verlap (MPO) metric from the 19 |     CME paper https://arxiv.org/pdf/2010.13233.pdf
20 |     Given a set of predicted concept values, MPO computes the fraction of 21 |     samples in the test set that have at least m relevant concepts predicted 22 |     incorrectly. 23 | 24 |     :param c_true: ground truth concept data, numpy array of shape 25 |         (n_samples, n_concepts) 26 |     :param c_pred: predicted concept data, numpy array of shape 27 |         (n_samples, n_concepts) 28 |     :param err_fn: function for computing the error on one single sample of 29 |         c_true and c_pred (e.g. if you want to ignore certain concepts, or 30 |         weight concept errors differently). Should be a function taking two 31 |         arguments of shape (n_concepts,), and returning a scalar value. 32 |         Defaults to computing the total number of mispredictions. 33 |     :return: MPO metric values computed from c_true and c_pred using err_fn, 34 |         with m ranging from 0 to n_concepts - 1 35 |     ''' 36 | 37 |     # Apply error function over all samples 38 |     err_vals = np.array([ 39 |         err_fn(c_true[i], c_pred[i]) for i in range(c_true.shape[0]) 40 |     ]) 41 | 42 |     # Compute MPO values for m ranging from 0 to n_concepts - 1 43 |     n_concepts = c_true.shape[1] 44 |     mpo_vals = [] 45 |     for i in range(n_concepts): 46 |         # Compute number of samples with at least i incorrect concept 47 |         # predictions 48 |         n_incorrect = (err_vals >= i).astype(int) 49 | 50 |         # Compute % of these samples from total 51 |         metric = (np.sum(n_incorrect) / c_true.shape[0]) 52 |         mpo_vals.append(metric) 53 | 54 |     return np.array(mpo_vals) 55 | -------------------------------------------------------------------------------- /concepts_xai/evaluation/metrics/niching.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import accuracy_score, roc_auc_score 3 | from sklearn.feature_selection import mutual_info_classif 4 | from scipy.special import softmax 5 | 6 | 7 | def niche_completeness(c_pred, y_true, predictor_model, niches): 8 |     ''' 9 |     Computes the niche completeness score for the downstream task 10 |     :param c_pred: Concept data predictions, numpy array of shape 11 |         (n_samples, n_concepts) 12 |     :param y_true: Ground-truth task label data, numpy array of shape 13 |         (n_samples, n_tasks) 14 |     :param predictor_model: trained decoder model to use for predicting the task 15 |         labels from the concept data 16 |     :return: Dictionary with the ROC-AUC of predictor_model evaluated on the 17 |         niches ('auc_completeness'), together with the predictions ('y_preds') 18 |     ''' 19 |     n_tasks = y_true.shape[1] 20 |     # compute niche completeness for each task 21 |     y_pred_list = [] 22 |     for task in range(n_tasks): 23 |         # find niche 24 |         niche = np.zeros_like(c_pred) 25 |         niche[:, niches[:, task] > 0] = c_pred[:, niches[:, task] > 0] 26 | 27 |         # compute task predictions 28 |         y_pred_niche = predictor_model.predict_proba(niche) 29 |         if predictor_model.__class__.__name__ == 'Sequential': 30 |             # get class labels from logits 31 |             y_pred_niche = y_pred_niche > 0 32 |         elif len(y_pred_niche.shape) == 1: 33 |             y_pred_niche = y_pred_niche[:, np.newaxis] 34 | 35 |         y_pred_list.append(y_pred_niche[:, task]) 36 | 37 |     y_preds = np.vstack(y_pred_list).T 38 |     y_preds = softmax(y_preds, axis=1) 39 |     auc = roc_auc_score(y_true, y_preds, multi_class='ovo') 40 | 41 |     result = { 42 |         'auc_completeness': auc, 43 |         'y_preds': y_preds, 44 |     } 45 |     return result 46 | 47 | 48 | def niche_completeness_ratio(c_pred, y_true, predictor_model, niches): 49 |     ''' 50 |     Computes the niche completeness ratio for the downstream task 51 |     :param c_pred: Concept data predictions, numpy
array of shape
52 |         (n_samples, n_concepts)
53 |     :param y_true: Ground-truth task label data, numpy array of shape
54 |         (n_samples, n_tasks)
55 |     :param predictor_model: sklearn model to use for predicting the task labels
56 |         from the concept data
57 |     :return: Accuracy ratio between the accuracy of predictor_model evaluated
58 |         on niches and the accuracy of predictor_model evaluated on all concepts
59 |     '''
60 |     n_tasks = y_true.shape[1]
61 |
62 |     y_pred_test = predictor_model.predict_proba(c_pred)
63 |     if predictor_model.__class__.__name__ == 'Sequential':
64 |         # get class labels from logits
65 |         y_pred_test = y_pred_test > 0
66 |     elif len(y_pred_test.shape) == 1:
67 |         y_pred_test = y_pred_test[:, np.newaxis]
68 |
69 |     # compute niche completeness for each task
70 |     niche_completeness_list = []
71 |     for task in range(n_tasks):
72 |         # find niche
73 |         niche = np.zeros_like(c_pred)
74 |         niche[:, niches[:, task] > 0] = c_pred[:, niches[:, task] > 0]
75 |
76 |         # compute task predictions
77 |         y_pred_niche = predictor_model.predict_proba(niche)
78 |         if predictor_model.__class__.__name__ == 'Sequential':
79 |             # get class labels from logits
80 |             y_pred_niche = y_pred_niche > 0
81 |         elif len(y_pred_niche.shape) == 1:
82 |             y_pred_niche = y_pred_niche[:, np.newaxis]
83 |
84 |         # compute accuracies
85 |         accuracy_base = accuracy_score(y_true[:, task], y_pred_test[:, task])
86 |         accuracy_niche = accuracy_score(y_true[:, task], y_pred_niche[:, task])
87 |
88 |         # compute the accuracy ratio of the niche w.r.t. the baseline
89 |         # (full concept bottleneck); the higher the ratio, the higher the
90 |         # predictive power of the niche
91 |         niche_completeness = accuracy_niche / accuracy_base
92 |         niche_completeness_list.append(niche_completeness)
93 |
94 |     result = {
95 |         'niche_completeness_ratio_mean': np.mean(niche_completeness_list),
96 |         'niche_completeness_ratio': niche_completeness_list,
97 |     }
98 |     return result
99 |
100 |
101 | def niche_impurity(c_pred, y_true, predictor_model, niches):
102 |     '''
103 |     Computes the niche impurity score for the downstream task
104 |     :param c_pred: Concept data predictions, numpy array of shape
105 |         (n_samples, n_concepts)
106 |     :param y_true: Ground-truth task label data, numpy array of shape
107 |         (n_samples, n_tasks)
108 |     :param predictor_model: sklearn model to use for predicting the task labels
109 |         from the concept data
110 |     :return: Dictionary with the one-vs-one AUC obtained when predicting each
111 |         task using only the concepts outside of its niche (higher values
112 |         indicate more impure niches), together with those predictions
113 |     '''
114 |     n_tasks = y_true.shape[1]
115 |
116 |     # compute, for each task, predictions using only out-of-niche concepts
117 |     y_pred_list = []
118 |     for task in range(n_tasks):
119 |         # find niche
120 |         niche = np.zeros_like(c_pred)
121 |         niche[:, niches[:, task] > 0] = c_pred[:, niches[:, task] > 0]
122 |
123 |         # find concepts outside the niche
124 |         niche_out = np.zeros_like(c_pred)
125 |         niche_out[:, niches[:, task] <= 0] = c_pred[:, niches[:, task] <= 0]
126 |
127 |         # compute task predictions
128 |         y_pred_niche = predictor_model.predict_proba(niche)
129 |         y_pred_niche_out = predictor_model.predict_proba(niche_out)
130 |         if predictor_model.__class__.__name__ == 'Sequential':
131 |             # get class labels from logits
132 |             y_pred_niche_out = y_pred_niche_out > 0
133 |         elif len(y_pred_niche_out.shape) == 1:
134 |             y_pred_niche_out = y_pred_niche_out[:, np.newaxis]
135 |
136 |         y_pred_list.append(y_pred_niche_out[:, task])
137 |
138 |     y_preds = np.vstack(y_pred_list).T
139 |     y_preds = softmax(y_preds,
axis=1)
140 |     auc = roc_auc_score(y_true, y_preds, multi_class='ovo')
141 |
142 |     return {
143 |         'auc_impurity': auc,
144 |         'y_preds': y_preds,
145 |     }
146 |
147 |
148 | def niche_finding(c, y, mode='mi', threshold=0.5):
149 |     n_concepts = c.shape[1]
150 |     if mode == 'corr':
151 |         corrm = np.corrcoef(np.hstack([c, y]).T)
152 |         niching_matrix = corrm[:n_concepts, n_concepts:]
153 |         niches = np.abs(niching_matrix) > threshold
154 |     elif mode == 'mi':
155 |         nm = []
156 |         for yj in y.T:
157 |             mi = mutual_info_classif(c, yj)
158 |             nm.append(mi)
159 |         nm = np.vstack(nm).T
160 |         niching_matrix = nm / np.max(nm)
161 |         niches = niching_matrix > threshold
162 |     else:
163 |         return None, None
164 |
165 |     return niches, niching_matrix
166 | --------------------------------------------------------------------------------
/concepts_xai/methods/CBM/CBModel.py:
--------------------------------------------------------------------------------
1 | """
2 | Module containing implementations for Concept Bottleneck Models (CBMs) as
3 | described by Koh et al. in https://arxiv.org/abs/2007.04612
4 | """
5 |
6 | import numpy as np
7 | import tensorflow as tf
8 |
9 | from collections import defaultdict
10 | from tensorflow.python.keras.engine import data_adapter
11 |
12 |
13 | ################################################################################
14 | ## Exposed Functions
15 | ################################################################################
16 |
17 | def produce_bottleneck(model, layer_idx):
18 |     """
19 |     Partitions a Keras model into two disjoint computational graphs: (1) an
20 |     encoder that maps from inputs to activations from layer with index
21 |     `layer_idx` and (2) a decoder model that maps activations from layer
22 |     with index `layer_idx` to the output of the given model.
23 |
24 |     For this operation to be successful, the layer at index `layer_idx`
25 |     must be a bottleneck of the input model (i.e., there may not be any
26 |     connections from the layers preceding the bottleneck layer to those
27 |     topologically after the bottleneck layer).
28 |
29 |     :param tf.keras.Model model: A model which we will split into two disjoint
30 |         submodels.
31 |     :param int layer_idx: A valid layer index in the given model. It must
32 |         be a valid index in the array of layers represented by model.layers.
33 |
34 |     :return Tuple[tf.keras.Model, tf.keras.Model]: a tuple of models
35 |         (encoder, decoder) representing the input to bottleneck model and the
36 |         bottleneck to output model, respectively.
37 |     """
38 |
39 |     # Let's start by making sure we get a full understanding of the input
40 |     # model's topology
41 |     in_edges = defaultdict(set)
42 |     out_edges = defaultdict(set)
43 |     name_to_layer = {}
44 |     for src in model.layers:
45 |         name_to_layer[src.name] = src
46 |         for dst_node in src._outbound_nodes:
47 |             in_edges[dst_node.layer.name].add(src.name)
48 |             out_edges[src.name].add(dst_node.layer.name)
49 |
50 |     # Now let's find the layer we will use as our bottleneck layer
51 |     if len(model.layers) <= layer_idx:
52 |         raise ValueError(
53 |             f"Requested to use layer with index {layer_idx} as the bottleneck "
54 |             f"layer; however, the given model '{model.name}' has only "
55 |             f"{len(model.layers)} indexable layers."
56 |         )
57 |     bottleneck_layer = model.layers[layer_idx]
58 |
59 |     # Once we have the bottleneck, let's look at all the nodes that precede it
60 |     # and follow it in the computational graph defined by `model`.
For this to
61 |     # be considered a valid bottleneck, the set of nodes after the bottleneck
62 |     # layer must be disjoint from the set of nodes preceding the bottleneck
63 |     # layer (i.e., there must be no edges from the subgraph preceding the
64 |     # bottleneck layer into the subgraph defined by nodes that are topologically
65 |     # after the bottleneck layer).
66 |     preceding_nodes = set()
67 |     frontier = [bottleneck_layer.name]
68 |     while frontier:
69 |         next_node = frontier.pop()
70 |         for src_name in in_edges[next_node]:
71 |             if src_name in preceding_nodes:
72 |                 # Then we have already dealt with it
73 |                 continue
74 |             preceding_nodes.add(src_name)
75 |             frontier.append(src_name)
76 |
77 |     # And walk the graph to compute the nodes after the bottleneck layer
78 |     posterior_nodes = set()
79 |     frontier = [bottleneck_layer.name]
80 |     while frontier:
81 |         next_node = frontier.pop()
82 |         for dst_name in out_edges[next_node]:
83 |             if dst_name in posterior_nodes:
84 |                 # Then we have already dealt with it
85 |                 continue
86 |             posterior_nodes.add(dst_name)
87 |             frontier.append(dst_name)
88 |
89 |     if (posterior_nodes & preceding_nodes):
90 |         raise ValueError(
91 |             f"Requested bottleneck layer {bottleneck_layer.name} (index "
92 |             f"{layer_idx}) does not partition the computational graph of "
93 |             f"the provided model '{model.name}' into two disjoint subgraphs ("
94 |             f"i.e., there is a connection between layers preceding the "
95 |             f"bottleneck and layers after the bottleneck layer)."
96 |         )
97 |
98 |     # We can now compute the size of the actual bottleneck
99 |     if isinstance(bottleneck_layer.output, list):
100 |         raise ValueError(
101 |             f"Currently we do not support using a layer with more than "
102 |             f"one output as the bottleneck layer. Requested bottleneck "
103 |             f"layer {bottleneck_layer.name} (at index {layer_idx}) has "
104 |             f"{len(bottleneck_layer.output)} outputs."
105 | ) 106 | # Else let's check the number of concepts we are expecting vs the number 107 | # of entries 108 | num_concepts = bottleneck_layer.output.shape[-1] 109 | 110 | # With this, building the encoder is now trivial 111 | encoder = tf.keras.Model( 112 | inputs=model.inputs, 113 | outputs=bottleneck_layer.output, 114 | ) 115 | 116 | decoder_input = tf.keras.layers.Input(num_concepts) 117 | decoder_outputs = [] 118 | name_to_layer[bottleneck_layer.name] = decoder_input 119 | for layer in model.layers: 120 | if (layer.name == bottleneck_layer.name) or ( 121 | layer.name not in posterior_nodes 122 | ): 123 | continue 124 | # Otherwise let's make sure we feed it with the input corresponding 125 | # to the new computational graph we are constructing from the bottleneck 126 | # layer (NOTE: this works as we are iterating over layers in topological 127 | # order) 128 | input_layers = [] 129 | for input_name in in_edges[layer.name]: 130 | input_layers.append(name_to_layer[input_name]) 131 | if len(input_layers) == 1: 132 | input_layers = input_layers[0] 133 | # Generate the new node 134 | new_node = layer(input_layers) 135 | name_to_layer[layer.name] = new_node 136 | 137 | # And add it to the output if this was an original output 138 | if layer.name in model.output_names: 139 | decoder_outputs.append(new_node) 140 | decoder = tf.keras.Model( 141 | inputs=decoder_input, 142 | outputs=decoder_outputs 143 | ) 144 | 145 | return encoder, decoder 146 | 147 | 148 | ################################################################################ 149 | ## Exposed Classes 150 | ################################################################################ 151 | 152 | class JointConceptBottleneckModel(tf.keras.Model): 153 | """ 154 | Main class for implementing a Joint Concept Bottleneck Model with the 155 | given encoder mapping input features to concepts and the given 156 | decoder which maps concept encodings to labels. 157 | This class encapsulates the joint training process of a CBM while allowing 158 | an arbitrary encoder/decoder model to be used in its construction. 159 | 160 | Note that it generalizes the original CBM by Koh et al. by allowing the 161 | encoder to produce non-binary concepts rather than assuming all concepts 162 | are binary in nature. 163 | """ 164 | 165 | def __init__( 166 | self, 167 | encoder, 168 | decoder, 169 | task_loss, 170 | alpha=0.01, 171 | metrics=None, 172 | pass_concept_logits=False, 173 | concept_sample_weights=None, 174 | single_multiclass_concept=False, 175 | **kwargs 176 | ): 177 | """ 178 | Constructs a new trainable joint CBM which can be then trained, 179 | optimized and/or used for prediction as any other Keras model can. 180 | 181 | When using this model for prediction, it will return a tuple 182 | (labels, concepts) indicating the predicted label probabilities as 183 | well as the predicted concepts probabilities. 184 | 185 | :param tf.keras.Model encoder: A valid keras model that maps input 186 | features to a set of concepts. If the output of this model is 187 | a single vector, then every entry of this vector is assumed to be 188 | one binary concept. Otherwise, if the output of the encoder is a 189 | list of vectors, then we assume that each vector represents a 190 | probability distribution over different classes for each concept ( 191 | i.e., we assume one concept per vector). 192 | :param tf.keras.Model decoder: A valid keras model mapping a concept 193 | vector to a set of task-specific labels. 
We assume that if the
194 |             encoder outputs a list of concepts then the input to this model
195 |             is the concatenation of all the output vectors of the encoder in
196 |             the same order as produced by the encoder.
197 |         :param tf.keras.losses.Loss task_loss: The loss to be used for the
198 |             specific task labels.
199 |         :param float alpha: A parameter indicating how much weight one should
200 |             assign to the loss coming from training the bottleneck. If 0, then
201 |             there is no learning enforced in the bottleneck.
202 |         :param List[tf.keras.metrics.Metric] metrics: A list of possible
203 |             metrics of interest which one may want to monitor during training.
204 |         :param Bool pass_concept_logits: Whether the concept bottleneck will
205 |             be passed to the concept-to-task model as logits (i.e., without
206 |             a softmax or sigmoid operation applied to it) or not. If this is
207 |             set to false, then it is the responsibility of the input encoder
208 |             model to output a valid probability distribution.
209 |         :param Dict[Any, Any] kwargs: Keras Layer specific kwargs to be passed
210 |             to the parent constructor.
211 |         """
212 |         super(JointConceptBottleneckModel, self).__init__(**kwargs)
213 |         self.encoder = encoder
214 |         self.decoder = decoder
215 |         self.total_loss_tracker = tf.keras.metrics.Mean(
216 |             name="total_loss"
217 |         )
218 |         self.concept_loss_tracker = tf.keras.metrics.Mean(
219 |             name="concept_loss"
220 |         )
221 |         self.task_loss_tracker = tf.keras.metrics.Mean(
222 |             name="task_loss"
223 |         )
224 |         self.concept_accuracy_tracker = tf.keras.metrics.Mean(
225 |             name="concept_accuracy"
226 |         )
227 |         self._acc_metric = \
228 |             lambda y_true, y_pred: tf.keras.metrics.sparse_top_k_categorical_accuracy(
229 |                 y_true,
230 |                 y_pred,
231 |                 k=1,
232 |             )
233 |         self._bin_acc_metric = \
234 |             lambda y_true, y_pred: tf.math.reduce_mean(
235 |                 tf.keras.metrics.binary_accuracy(y_true, y_pred),
236 |                 axis=-1,
237 |             )
238 |         self.alpha = alpha
239 |         self.task_loss = task_loss
240 |         self.extra_metrics = metrics or []
241 |         self.pass_concept_logits = pass_concept_logits
242 |         self.concept_sample_weights = concept_sample_weights
243 |         self.single_multiclass_concept = single_multiclass_concept
244 |
245 |         # dummy call to build the model
246 |         self(tf.zeros(list(map(
247 |             lambda x: 1 if x is None else x,
248 |             self.encoder.input_shape
249 |         ))))
250 |
251 |     @property
252 |     def metrics(self):
253 |         return [
254 |             self.total_loss_tracker,
255 |             self.concept_loss_tracker,
256 |             self.task_loss_tracker,
257 |             self.concept_accuracy_tracker,
258 |         ] + self.extra_metrics
259 |
260 |     def predict_from_concepts(self, concepts):
261 |         """
262 |         Given a set of concepts (e.g., coming from an intervention), this
263 |         function returns the predicted labels for those concepts.
264 |
265 |         :param np.ndarray concepts: A matrix of concept predictions from which
266 |             we wish to obtain class predictions. Its shape must be
267 |             (n_samples, n_concepts) if all concepts are binary. Otherwise, it
268 |             should be a list with one array of scores per concept.
269 |
270 |         :returns np.ndarray: Label probability predictions for the given set
271 |             of concepts.
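        Illustrative intervention sketch (the variable names and the
        intervened concept index below are our own assumptions):

            _, concepts = cbm.predict(x_batch)   # (n_samples, n_concepts)
            concepts[:, 2] = 1.0                 # force the third concept on
            label_probs = cbm.predict_from_concepts(concepts)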
272 | """ 273 | if isinstance(concepts, list): 274 | if len(concepts) > 1: 275 | concepts = tf.keras.layers.Concatenate(axis=-1)( 276 | concepts 277 | ) 278 | else: 279 | concepts = concepts[0] 280 | return self.decoder(concepts) 281 | 282 | def call(self, inputs): 283 | # We will use the log of the variance rather than the actual variance 284 | # for stability purposes 285 | outputs, concepts, _ = self._call_fn(inputs) 286 | return outputs, concepts 287 | 288 | def _call_fn(self, inputs, **kwargs): 289 | # This method is separate from the call method above as it allows one 290 | # to overwrite this class and include an extra set of losses (returned 291 | # as the third element in the tuple) which could, for example, include 292 | # some decorrelation regularization term between concept predictions. 293 | concepts = self.encoder(inputs, **kwargs) 294 | return self.predict_from_concepts(concepts), concepts, [] 295 | 296 | def _compute_losses( 297 | self, 298 | predicted_labels, 299 | predicted_concepts, 300 | true_labels, 301 | true_concepts, 302 | ): 303 | """ 304 | Helper function for computing all the losses we require for training 305 | our joint CBM. 306 | """ 307 | # Updates stateful loss metrics. 308 | task_loss = self.task_loss(true_labels, predicted_labels) 309 | concept_loss = 0.0 310 | concept_accuracy = 0.0 311 | # If generating model does not produce a list of outputs, then we will 312 | # assume all concepts are binary 313 | if isinstance(predicted_concepts, list): 314 | for i, predicted_vec in enumerate(predicted_concepts): 315 | true_vec = true_concepts[:, i] 316 | sample_weight = None 317 | if self.concept_sample_weights is not None: 318 | sample_weight = self.concept_sample_weights[:, i:i+1] 319 | if (len(predicted_vec.shape) == 1) or ( 320 | predicted_vec.shape[-1] == 1 321 | ): 322 | # Then use binary loss here 323 | concept_loss += tf.keras.losses.BinaryCrossentropy( 324 | from_logits=self.pass_concept_logits, 325 | )( 326 | true_vec, 327 | predicted_vec, 328 | sample_weight=sample_weight, 329 | ) 330 | if len(predicted_vec.shape) == 2: 331 | # Then, let's remove the degenerate dimension 332 | predicted_vec = tf.squeeze(predicted_vec, axis=-1) 333 | concept_accuracy += self._bin_acc_metric( 334 | true_vec, 335 | predicted_vec, 336 | ) 337 | else: 338 | # Otherwise use normal cross entropy 339 | concept_loss += \ 340 | tf.keras.losses.SparseCategoricalCrossentropy( 341 | from_logits=self.pass_concept_logits, 342 | )( 343 | true_vec, 344 | predicted_vec, 345 | sample_weight=sample_weight, 346 | ) 347 | concept_accuracy += self._acc_metric( 348 | true_vec, 349 | predicted_vec, 350 | ) 351 | 352 | # And time to normalize over all the different heads 353 | concept_loss = concept_loss / len(predicted_concepts) 354 | concept_accuracy = concept_accuracy / len(predicted_concepts) 355 | elif self.single_multiclass_concept: 356 | # Then all elements in the bottleneck correspond to a single 357 | # concept that is a multi-class concept 358 | concept_loss += \ 359 | tf.keras.losses.SparseCategoricalCrossentropy( 360 | from_logits=self.pass_concept_logits, 361 | )( 362 | true_concepts, 363 | predicted_concepts, 364 | sample_weight=self.concept_sample_weights, 365 | ) 366 | concept_accuracy += self._acc_metric( 367 | true_concepts, 368 | predicted_concepts, 369 | ) 370 | else: 371 | # Then use binary loss here as we are given a single vector and we 372 | # will assume in that instance they all represent independent 373 | # binary concepts 374 | concept_loss += 
tf.keras.losses.BinaryCrossentropy( 375 | from_logits=self.pass_concept_logits, 376 | )( 377 | true_concepts, 378 | predicted_concepts, 379 | sample_weight=self.concept_sample_weights, 380 | ) 381 | concept_accuracy = self._bin_acc_metric( 382 | true_concepts, 383 | predicted_concepts, 384 | ) 385 | return task_loss, concept_loss, concept_accuracy 386 | 387 | def test_step(self, data): 388 | """ 389 | Overwrite function for the Keras model indicating how a test step 390 | will operate. 391 | 392 | :param Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]] data: The input 393 | training data is expected to be provided in the form 394 | (input_features, (true_labels, true_concepts)). 395 | """ 396 | # Massage the data 397 | data = data_adapter.expand_1d(data) 398 | input_features, (true_labels, true_concepts), sample_weight = \ 399 | data_adapter.unpack_x_y_sample_weight(data) 400 | 401 | # Obtain a prediction of labels and concepts 402 | predicted_labels, predicted_concepts, extra_losses = self._call_fn( 403 | input_features, 404 | training=False, 405 | ) 406 | # Compute the actual losses 407 | task_loss, concept_loss, concept_accuracy = self._compute_losses( 408 | predicted_labels=predicted_labels, 409 | predicted_concepts=predicted_concepts, 410 | true_labels=true_labels, 411 | true_concepts=true_concepts, 412 | ) 413 | 414 | # Accumulate both the concept and task-specific loss into a single value 415 | total_loss = ( 416 | task_loss + 417 | self.alpha * concept_loss 418 | ) 419 | for extra_loss in extra_losses: 420 | total_loss += extra_loss 421 | result = { 422 | self.concept_accuracy_tracker.name: concept_accuracy, 423 | self.concept_loss_tracker.name: concept_loss, 424 | self.task_loss_tracker.name: task_loss, 425 | self.total_loss_tracker.name: total_loss, 426 | } 427 | for metric in self.extra_metrics: 428 | result[metric.name] = metric( 429 | true_labels, 430 | predicted_labels, 431 | sample_weight, 432 | ) 433 | return result 434 | 435 | @tf.function 436 | def train_step(self, data): 437 | """ 438 | Overwrite function for the Keras model indicating how a train step 439 | will operate. 440 | 441 | :param Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]] data: The input 442 | training data is expected to be provided in the form 443 | (input_features, (true_labels, true_concepts)). 
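        A minimal sketch of how such a dataset could be built and used (all
        names and shapes here are illustrative assumptions):

            ds = tf.data.Dataset.from_tensor_slices(
                (x_data, (y_labels, c_concepts))
            ).batch(32)
            cbm.compile(optimizer='adam')
            cbm.fit(ds, epochs=10)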
444 | """ 445 | # Massage the data 446 | data = data_adapter.expand_1d(data) 447 | input_features, (true_labels, true_concepts), sample_weight = \ 448 | data_adapter.unpack_x_y_sample_weight(data) 449 | with tf.GradientTape() as tape: 450 | # Obtain a prediction of labels and concepts 451 | predicted_labels, predicted_concepts, extra_losses = self._call_fn( 452 | input_features 453 | ) 454 | # Compute the actual losses 455 | task_loss, concept_loss, concept_accuracy = self._compute_losses( 456 | predicted_labels=predicted_labels, 457 | predicted_concepts=predicted_concepts, 458 | true_labels=true_labels, 459 | true_concepts=true_concepts, 460 | ) 461 | # Accumulate both the concept and task-specific loss into a single 462 | # value 463 | total_loss = ( 464 | task_loss + 465 | self.alpha * concept_loss 466 | ) 467 | # And include any extra losses coming from this process 468 | for extra_loss in extra_losses: 469 | total_loss += extra_loss 470 | 471 | num_concepts = ( 472 | len(predicted_concepts) if isinstance(predicted_concepts, list) else 473 | predicted_concepts.shape[-1] 474 | ) 475 | grads = tape.gradient(total_loss, self.trainable_weights) 476 | self.optimizer.apply_gradients(zip(grads, self.trainable_weights)) 477 | self.total_loss_tracker.update_state(total_loss, sample_weight) 478 | self.task_loss_tracker.update_state(task_loss, sample_weight) 479 | self.concept_loss_tracker.update_state(concept_loss, sample_weight) 480 | self.concept_accuracy_tracker.update_state( 481 | concept_accuracy, 482 | sample_weight, 483 | ) 484 | for metric in self.extra_metrics: 485 | metric.update_state(true_labels, predicted_labels, sample_weight) 486 | return { 487 | metric.name: metric.result() 488 | for metric in self.metrics 489 | } 490 | 491 | 492 | class BypassJointCBM(JointConceptBottleneckModel): 493 | def __init__( 494 | self, 495 | encoder, 496 | decoder, 497 | task_loss, 498 | alpha=0.01, 499 | metrics=None, 500 | pass_concept_logits=False, 501 | concept_sample_weights=None, 502 | single_multiclass_concept=False, 503 | **kwargs, 504 | ): 505 | """ 506 | Extension of CBM model above that allows extra capacity in the 507 | bottleneck for activations that have no concept supervision. 508 | Expects the encoder to output a tuple (concepts, latent_code) where 509 | concepts has the same required properties of the output of the encoder 510 | in JointConceptBottleneckModel and latent_code is a np.ndarray vector 511 | representing activations in the bottleneck that have no supervision. 512 | 513 | The concatentation of the elements in concepts and latent_code will be 514 | fed into the provided decoder model. 515 | 516 | When using this model for prediction, it will return a tuple 517 | (labels, concepts) indicating the predicted label probabilities as 518 | well as the predicted concepts probabilities. 519 | 520 | :param tf.keras.Model encoder: A valid keras model that maps input 521 | features to a set of concepts and a vector of unsupervised latent 522 | activations. If the concept output of this model is a single vector, 523 | then every entry of that vector is assumed to be one binary concept. 524 | Otherwise, if the output of the encoder's concepts is a list of 525 | vectors, then we assume that each vector represents a probability 526 | distribution over different classes for each concept (i.e., we 527 | assume one concept per vector). 
528 |         :param tf.keras.Model decoder: A valid keras model mapping a concept
529 |             vector concatenated to the unsupervised latent dimensions to a set
530 |             of task-specific labels. We assume that if the encoder outputs a
531 |             list of concepts, then the input to this model is the concatenation
532 |             of all the output vectors of the encoder (including the unsupervised
533 |             latent dimensions) in the same order as produced by the encoder.
534 |         :param tf.keras.losses.Loss task_loss: The loss to be used for the
535 |             specific task labels.
536 |         :param float alpha: A parameter indicating how much weight one should
537 |             assign to the loss coming from training the bottleneck. If 0, then
538 |             there is no learning enforced in the bottleneck.
539 |         :param List[tf.keras.metrics.Metric] metrics: A list of possible
540 |             metrics of interest which one may want to monitor during training.
541 |         :param Bool pass_concept_logits: Whether the concept bottleneck will
542 |             be passed to the concept-to-task model as logits (i.e., without
543 |             a softmax or sigmoid operation applied to it) or not. If this is
544 |             set to false, then it is the responsibility of the input encoder
545 |             model to output a valid probability distribution.
546 |         :param Dict[Any, Any] kwargs: Keras Layer specific kwargs to be passed
547 |             to the parent constructor.
548 |
549 |         """
550 |         super(BypassJointCBM, self).__init__(
551 |             encoder=encoder,
552 |             decoder=decoder,
553 |             task_loss=task_loss,
554 |             alpha=alpha,
555 |             metrics=metrics,
556 |             pass_concept_logits=pass_concept_logits,
557 |             concept_sample_weights=concept_sample_weights,
558 |             single_multiclass_concept=single_multiclass_concept,
559 |             **kwargs
560 |         )
561 |
562 |     def call(self, inputs):
563 |         # Compute our concepts and unsupervised latent code, then predict the
564 |         # task labels from their concatenation
565 |         concepts, latent_code = self.encoder(inputs)
566 |         if not isinstance(concepts, list):
567 |             decode_inputs = [concepts] + [latent_code]
568 |         else:
569 |             decode_inputs = concepts + [latent_code]
570 |         return (
571 |             self.predict_from_concepts(decode_inputs),
572 |             concepts,
573 |             latent_code,
574 |         )
575 |
576 |     def _call_fn(self, inputs, **kwargs):
577 |         # Compute our concepts and latent code
578 |         concepts, latent_code = self.encoder(inputs, **kwargs)
579 |         if not isinstance(concepts, list):
580 |             decode_inputs = [concepts] + [latent_code]
581 |         else:
582 |             decode_inputs = concepts + [latent_code]
583 |         return self.predict_from_concepts(decode_inputs), concepts, []
584 | --------------------------------------------------------------------------------
/concepts_xai/methods/CME/CtlModel.py:
--------------------------------------------------------------------------------
1 | from sklearn.linear_model import LogisticRegression, LinearRegression
2 | from sklearn.tree import DecisionTreeClassifier
3 | from sklearn.ensemble import GradientBoostingClassifier
4 |
5 | '''
6 | CtL : Concept-to-Label
7 |
8 | Class for the transparent model representing a function from concepts to task
9 | labels.
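A minimal usage sketch (the data arrays and sizes below are illustrative
assumptions, not part of the library):

    ctl = CtLModel(method='DT', n_concepts=4, n_classes=2)
    ctl.train(c_data, y_data)   # c_data: (n_samples, 4), y_data: (n_samples,)
    y_pred = ctl.predict(c_data)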
Represents the decision-making process of a given black-box model, in 10 | the concept representation 11 | ''' 12 | 13 | 14 | class CtLModel: 15 | 16 | def __init__(self, **params): 17 | 18 | # Create copy of passed-in parameters 19 | self.params = params 20 | 21 | # Assign parameter values 22 | self.clf_type = self.params.get("method", "DT") 23 | self.n_concepts = self.params.get("n_concepts") 24 | self.n_classes = self.params.get("n_classes") 25 | self.c_names = self.params.get( 26 | "concept_names", 27 | ["Concept_" + str(i) for i in range(self.n_concepts)], 28 | ) 29 | self.cls_names = self.params.get( 30 | "class_names", 31 | ["Class_ " + str(i) for i in range(self.n_classes)], 32 | ) 33 | 34 | def train(self, c_data, y_data): 35 | if self.clf_type == 'DT': 36 | clf = DecisionTreeClassifier(class_weight='balanced') 37 | elif self.clf_type == 'LR': 38 | clf = LogisticRegression( 39 | max_iter=200, 40 | multi_class='auto', 41 | solver='lbfgs', 42 | ) 43 | elif self.clf_type == 'LinearRegression': 44 | clf = LinearRegression() 45 | elif self.clf_type == 'GBT': 46 | clf = GradientBoostingClassifier() 47 | else: 48 | raise ValueError("Unrecognised model type...") 49 | 50 | clf.fit(c_data, y_data) 51 | self.clf = clf 52 | 53 | def predict(self, c_data): 54 | return self.clf.predict(c_data) 55 | 56 | 57 | -------------------------------------------------------------------------------- /concepts_xai/methods/CME/ItCModel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from sklearn.dummy import DummyClassifier 4 | from sklearn.ensemble import GradientBoostingClassifier 5 | from sklearn.linear_model import LogisticRegression 6 | from sklearn.metrics import accuracy_score 7 | from sklearn.model_selection import train_test_split 8 | from sklearn.semi_supervised import LabelSpreading 9 | 10 | 11 | class ItCModel(object): 12 | 13 | def __init__(self, model, **params): 14 | 15 | # Create copy of passed-in parameters 16 | self.params = params 17 | 18 | # Retrieve the layers to use for hidden representations 19 | self.layer_ids = params.get( 20 | "layer_ids", 21 | [i for i in range(len(model.layers))], 22 | ) 23 | if self.layer_ids is None: 24 | self.layer_ids = [i for i in range(len(model.layers))] 25 | # Retrieve the corresponding layer names 26 | self.layer_names = params.get( 27 | "layer_names", 28 | ["Layer_" + str(i) for i in range(len(model.layers))], 29 | ) 30 | 31 | # Set number of parameters 32 | self.n_concepts = params['n_concepts'] 33 | 34 | # Retrieve the concept names 35 | self.concept_names = params.get( 36 | "concept_names", 37 | ["Concept_" + str(i) for i in range(self.n_concepts)], 38 | ) 39 | 40 | # Batch size to use during computation 41 | self.batch_size = params.get("batch_size", 128) 42 | 43 | # Set classifier to use for concept value prediction 44 | self.clf_method = params.get("method", "LR") 45 | 46 | # Model for extracting activations from 47 | self.model = model 48 | 49 | def get_clf(self): 50 | if self.clf_method == 'LR': 51 | semi_supervised = False 52 | clf = LogisticRegression(max_iter=200, class_weight='balanced') 53 | elif self.clf_method == 'LP': 54 | semi_supervised = True 55 | clf = LabelSpreading() 56 | elif self.clf_method == 'GBT': 57 | clf = GradientBoostingClassifier() 58 | semi_supervised = False 59 | else: 60 | raise ValueError("Non-implemented method") 61 | 62 | return clf, semi_supervised 63 | 64 | def _get_layer_concept_predictor_model(self, h_data_l, c_data_l, 
h_data_u): 65 | ''' 66 | Train cme concept label predictor for a particular layer and concept 67 | :param h_data_l: Activation data for a given layer 68 | :param c_data_l: Corresponding concept labels for that data 69 | :param h_data_u: Activation data without corresponding labels 70 | :return: Classifier predicting concept values from the activations 71 | ''' 72 | 73 | # Create safe copies of data 74 | x_data_l = np.copy(h_data_l) 75 | y_data_l = np.copy(c_data_l) 76 | x_data_u = np.copy(h_data_u) 77 | y_data_u = np.ones((x_data_u.shape[0])) * -1 # Here, label is 'undefined' 78 | 79 | # Specify whether to use the unlabelled data at all 80 | semi_supervised = False 81 | 82 | # If there is only 1 value, then return simple classifier 83 | unique_vals = np.unique(y_data_l) 84 | if len(unique_vals) == 1: 85 | clf = DummyClassifier(strategy="constant", constant=c_data_l[0]) 86 | else: 87 | # Otherwise, train classifier model 88 | clf, semi_supervised = self.get_clf() 89 | 90 | # Split the labelled data for train/test 91 | x_train, x_test, y_train, y_test = train_test_split( 92 | x_data_l, 93 | y_data_l, 94 | test_size=0.15, 95 | ) 96 | 97 | # Combine with unlabelled data, if using a SSCC method 98 | if semi_supervised: 99 | x_train = np.concatenate([x_train, x_data_u]) 100 | y_train = np.concatenate([y_train, y_data_u]) 101 | 102 | # Train classifier 103 | clf.fit(x_train, y_train) 104 | 105 | # Retrieve predictive accuracy of classifier 106 | pred_acc = accuracy_score(y_test, clf.predict(x_test)) 107 | 108 | return clf, pred_acc 109 | 110 | def train(self, ds_l, ds_u): 111 | ''' 112 | Compute a dictionary with structure: layer --> concept --> classifier 113 | i.e., the dictionary returns which clf to use to predict a given 114 | concept, from a given layer 115 | :param ds_l: labelled concept dataset, consisting of tuples 116 | (input data, concept data) 117 | :param ds_u: unlabelled dataset, consisting of (input data) only, 118 | without concept labels 119 | ''' 120 | 121 | # Load all concept data into a numpy array 122 | c_data_l = np.array([elem[1].numpy() for elem in ds_l]) 123 | 124 | self.model_ensemble = {} 125 | self.model_accuracies = {} 126 | 127 | # Compute activations for specified layers 128 | h_data_ls = compute_activation_per_layer( 129 | ds_l, 130 | self.layer_ids, 131 | self.model, 132 | aggregation_function=flatten_activations, 133 | ) 134 | h_data_us = compute_activation_per_layer( 135 | ds_u, 136 | self.layer_ids, 137 | self.model, 138 | aggregation_function=flatten_activations, 139 | ) 140 | 141 | for i, layer_id in enumerate(self.layer_ids): 142 | print("-" * 20) 143 | print( 144 | "Processing layer ", 145 | str(i + 1), 146 | " of ", str(len(self.layer_ids)), 147 | ":", 148 | self.model.layers[layer_id].name, 149 | ) 150 | # Retrieve activations for next layer 151 | activations_l = h_data_ls[i] 152 | activations_u = h_data_us[i] 153 | 154 | output_layer = self.model.layers[layer_id] 155 | self.model_ensemble[output_layer] = [] 156 | self.model_accuracies[output_layer] = [] 157 | 158 | # Generate predictive models for every concept 159 | for c in range(self.n_concepts): 160 | clf, pred_acc = self._get_layer_concept_predictor_model( 161 | activations_l, 162 | c_data_l[:, c], 163 | activations_u, 164 | ) 165 | self.model_ensemble[output_layer].append(clf) 166 | self.model_accuracies[output_layer].append(pred_acc) 167 | 168 | print( 169 | "Processed layer ", 170 | str(i + 1), 171 | " of ", 172 | str(len(self.layer_ids)), 173 | ) 174 | 175 | # Initialise concept predictors with 
models in the first layer 176 | # Consists of an array of size |concepts|, in which 177 | # The element arr[i] is the Layer Id of the layer to use when predicting 178 | # that concept 179 | self.concept_predictor_layer_ids = [ 180 | self.layer_ids[0] for _ in range(self.n_concepts) 181 | ] 182 | 183 | # For every concept, identify the layer with the best clf predictive 184 | # accuracy 185 | for c in range(self.n_concepts): 186 | max_acc = -1 187 | for i, layer_id in enumerate(self.layer_ids): 188 | layer = self.model.layers[layer_id] 189 | acc = self.model_accuracies[layer][c] 190 | if acc > max_acc: 191 | max_acc = acc 192 | self.concept_predictor_layer_ids[c] = layer_id 193 | 194 | def predict_concepts(self, ds): 195 | ''' 196 | Predict concept values from given data 197 | :param ds: tensorflow dataset of the data (has to have a known 198 | cardinality) 199 | ''' 200 | 201 | n_samples = int(ds.cardinality().numpy()) 202 | concept_vals = np.zeros((n_samples, self.n_concepts), dtype=float) 203 | 204 | for c in range(self.n_concepts): 205 | # Retrieve clf corresponding to concept c 206 | layer_id = self.concept_predictor_layer_ids[c] 207 | output_layer = self.model.layers[layer_id] 208 | clf = self.model_ensemble[output_layer][c] 209 | 210 | # Compute activations for that layer 211 | output_layer = self.model.layers[layer_id] 212 | reduced_model = tf.keras.Model( 213 | inputs=self.model.inputs, 214 | outputs=[output_layer.output], 215 | ) 216 | hidden_features = reduced_model.predict(ds.batch(self.batch_size)) 217 | clf_data = flatten_activations(hidden_features) 218 | 219 | # Predict concept values from the activations 220 | concept_vals[:, c] = clf.predict(clf_data) 221 | 222 | return concept_vals 223 | 224 | 225 | def flatten_activations(x_data): 226 | ''' 227 | Flatten all axes except the first one 228 | ''' 229 | 230 | if len(x_data.shape) > 2: 231 | n_samples = x_data.shape[0] 232 | shape = x_data.shape[1:] 233 | flattened = np.reshape(x_data, (n_samples, np.prod(shape))) 234 | else: 235 | flattened = x_data 236 | 237 | return flattened 238 | 239 | 240 | def compute_activation_per_layer( 241 | ds, 242 | layer_ids, 243 | model, 244 | batch_size=128, 245 | aggregation_function=flatten_activations, 246 | ): 247 | ''' 248 | Compute activations of input data for 'layer_ids' layers 249 | For every layer, aggregate values using 'aggregation_function' 250 | 251 | Returns a list L of size |layer_ids|, in which element L[i] is the 252 | activations computed from the model layer model.layers[layer_ids[i]], 253 | processed by the aggregation function 254 | 255 | :param ds: tf.dataset returning the input data 256 | :param layer_ids: list of indices, indicating model layers to use 257 | :param model: tf.Keras model to compute the activations from 258 | :param batch_size: batch size to use during processing 259 | :param aggregation_function: aggregation function for aggregating the 260 | activation values 261 | ''' 262 | 263 | hidden_features_list = [] 264 | 265 | for layer_id in layer_ids: 266 | # Compute and aggregate hidden activations 267 | output_layer = model.layers[layer_id] 268 | reduced_model = tf.keras.Model( 269 | inputs=model.inputs, 270 | outputs=output_layer.output, 271 | ) 272 | hidden_features = reduced_model.predict(ds.batch(batch_size)) 273 | flattened = aggregation_function(hidden_features) 274 | 275 | hidden_features_list.append(flattened) 276 | 277 | return hidden_features_list 278 | -------------------------------------------------------------------------------- 
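A hedged end-to-end sketch of the input-to-concept extraction implemented in
`ItCModel` above (the tiny Keras model and all data arrays here are our own
illustrative assumptions; only `ItCModel` itself comes from this file):

```python
import numpy as np
import tensorflow as tf

from concepts_xai.methods.CME.ItCModel import ItCModel

# Hypothetical trained classifier whose hidden layers we probe.
keras_model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu', input_shape=(16,)),
    tf.keras.layers.Dense(2, activation='softmax'),
])

# Hypothetical data: a small concept-labelled set and a larger unlabelled one.
x_l = np.random.rand(100, 16).astype(np.float32)
c_l = np.random.randint(0, 2, (100, 4))
x_u = np.random.rand(200, 16).astype(np.float32)
ds_l = tf.data.Dataset.from_tensor_slices((x_l, c_l))
ds_u = tf.data.Dataset.from_tensor_slices(x_u)

# Fit one concept predictor per (layer, concept) pair and keep, for every
# concept, the layer whose classifier generalizes best.
itc = ItCModel(keras_model, n_concepts=4, method='LR')
itc.train(ds_l, ds_u)

# Extract concept predictions for new inputs.
c_pred = itc.predict_concepts(ds_u)  # shape: (200, 4)
```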
/concepts_xai/methods/CW/CWLayer.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Re-implementation of the Concept Whitening module proposed by Chen et al.. 3 | 4 | 1) See the IterNormRotation class in for the original paper implementation: 5 | https://github.com/zhiCHEN96/ConceptWhitening/blob/final_version/MODELS/iterative_normalization.py 6 | 7 | 2) See Concept Whitening for Interpretable Image Recognition 8 | (https://arxiv.org/abs/2002.01650) for the original paper 9 | ''' 10 | 11 | import tensorflow as tf 12 | 13 | ################################################################################ 14 | ## Helper classes taken from tensorflow-addons to avoid import and extend 15 | ## functions. 16 | ## Full code can be found in 17 | ## https://github.com/tensorflow/addons/blob/v0.13.0/tensorflow_addons/layers/max_unpooling_2d.py#L88-L147 18 | ################################################################################ 19 | 20 | 21 | def normalize_tuple(value, n, name): 22 | """Transforms an integer or iterable of integers into an integer tuple. 23 | A copy of tensorflow.python.keras.util. 24 | Args: 25 | value: The value to validate and convert. Could an int, or any iterable 26 | of ints. 27 | n: The size of the tuple to be returned. 28 | name: The name of the argument being validated, e.g. "strides" or 29 | "kernel_size". This is only used to format error messages. 30 | Returns: 31 | A tuple of n integers. 32 | Raises: 33 | ValueError: If something else than an int/long or iterable thereof was 34 | passed. 35 | """ 36 | if isinstance(value, int): 37 | return (value,) * n 38 | else: 39 | try: 40 | value_tuple = tuple(value) 41 | except TypeError: 42 | raise TypeError( 43 | "The `" 44 | + name 45 | + "` argument must be a tuple of " 46 | + str(n) 47 | + " integers. Received: " 48 | + str(value) 49 | ) 50 | if len(value_tuple) != n: 51 | raise ValueError( 52 | "The `" 53 | + name 54 | + "` argument must be a tuple of " 55 | + str(n) 56 | + " integers. Received: " 57 | + str(value) 58 | ) 59 | for single_value in value_tuple: 60 | try: 61 | int(single_value) 62 | except (ValueError, TypeError): 63 | raise ValueError( 64 | "The `" 65 | + name 66 | + "` argument must be a tuple of " 67 | + str(n) 68 | + " integers. 
Received: " 69 | + str(value) 70 | + " " 71 | "including element " 72 | + str(single_value) 73 | + " of type" 74 | + " " 75 | + str(type(single_value)) 76 | ) 77 | return value_tuple 78 | 79 | 80 | def _calculate_output_shape(input_shape, pool_size, strides, padding): 81 | """Calculates the shape of the unpooled output.""" 82 | if padding == "VALID": 83 | output_shape = ( 84 | input_shape[0], 85 | (input_shape[1] - 1) * strides[0] + pool_size[0], 86 | (input_shape[2] - 1) * strides[1] + pool_size[1], 87 | input_shape[3], 88 | ) 89 | elif padding == "SAME": 90 | output_shape = ( 91 | input_shape[0], 92 | input_shape[1] * strides[0], 93 | input_shape[2] * strides[1], 94 | input_shape[3], 95 | ) 96 | else: 97 | raise ValueError('Padding must be a string from: "SAME", "VALID"') 98 | return output_shape 99 | 100 | 101 | def _max_unpooling_2d( 102 | updates, 103 | mask, 104 | pool_size=(2, 2), 105 | strides=(2, 2), 106 | padding="SAME", 107 | ): 108 | """Unpool the outputs of a maximum pooling operation.""" 109 | mask = tf.cast(mask, "int32") 110 | pool_size = normalize_tuple(pool_size, 2, "pool_size") 111 | strides = normalize_tuple(strides, 2, "strides") 112 | input_shape = tf.shape(updates, out_type="int32") 113 | input_shape = [updates.shape[i] or input_shape[i] for i in range(4)] 114 | output_shape = _calculate_output_shape( 115 | input_shape, 116 | pool_size, 117 | strides, 118 | padding, 119 | ) 120 | 121 | # Calculates indices for batch, height, width and feature maps. 122 | one_like_mask = tf.ones_like(mask, dtype="int32") 123 | batch_shape = tf.concat([[input_shape[0]], [1], [1], [1]], axis=0) 124 | batch_range = tf.reshape( 125 | tf.range(output_shape[0], dtype="int32"), shape=batch_shape 126 | ) 127 | b = one_like_mask * batch_range 128 | y = mask // (output_shape[2] * output_shape[3]) 129 | x = (mask // output_shape[3]) % output_shape[2] 130 | feature_range = tf.range(output_shape[3], dtype="int32") 131 | f = one_like_mask * feature_range 132 | 133 | # Transposes indices & reshape update values to one dimension. 134 | updates_size = tf.size(updates) 135 | indices = tf.transpose( 136 | tf.reshape(tf.stack([b, y, x, f]), [4, updates_size]) 137 | ) 138 | values = tf.reshape(updates, [updates_size]) 139 | return tf.scatter_nd(indices, values, output_shape) 140 | 141 | 142 | ################################################################################ 143 | ## Concept Whitening Layer Class 144 | ################################################################################ 145 | 146 | 147 | class ConceptWhiteningLayer(tf.keras.layers.Layer): 148 | def __init__( 149 | self, 150 | T=5, 151 | eps=1e-5, 152 | momentum=0.9, 153 | activation_mode='max_pool_mean', 154 | c1=1e-4, 155 | c2=0.9, 156 | max_tau_iterations=500, 157 | initial_tau=1000.0, # Original CW code: 1000 158 | data_format="channels_first", 159 | initial_beta=1e8, # Original CW code: 1e8 160 | initial_alpha=0, # Original CW code: 0 161 | **kwargs 162 | ): 163 | super(ConceptWhiteningLayer, self).__init__(**kwargs) 164 | assert data_format in ["channels_first", "channels_last"], ( 165 | f'Expected data format to be either "channels_first" or ' 166 | f'"channels_last" but got {data_format} instead.' 
167 | ) 168 | 169 | self.T = T 170 | self.eps = eps 171 | self.momentum = momentum 172 | self.c1 = c1 173 | self.c2 = c2 174 | self.max_tau_iterations = max_tau_iterations 175 | self.data_format = data_format 176 | self.initial_tau = initial_tau 177 | self.initial_alpha = initial_alpha 178 | self.initial_beta = initial_beta 179 | 180 | # Methods for aggregating a feature map into an score 181 | self.activation_mode = activation_mode 182 | 183 | def compute_output_shape(self, input_shape): 184 | return input_shape 185 | 186 | def concept_scores( 187 | self, 188 | inputs, 189 | aggregator='max_pool_mean', 190 | concept_indices=None, 191 | ): 192 | outputs = self(inputs, training=False) 193 | if len(tf.shape(outputs)) == 2: 194 | # Then the scores are already computed by our forward pass 195 | scores = outputs 196 | else: 197 | if self.data_format == "channels_last": 198 | # Then we will transpose to make things simpler so that 199 | # downstream we can always assume it is channels first 200 | # NHWC -> NCHW 201 | outputs = tf.transpose( 202 | outputs, 203 | perm=[0, 3, 1, 2], 204 | ) 205 | 206 | # Else, we need to do some aggregation 207 | if aggregator == 'mean': 208 | # Compute the mean over all channels 209 | scores = tf.math.reduce_mean(outputs, axis=[2, 3]) 210 | elif aggregator == 'max_pool_mean': 211 | # First downsample using a max pool and then continue with 212 | # a mean 213 | window_size = min( 214 | 2, 215 | outputs.shape[-1], 216 | outputs.shape[-2], 217 | ) 218 | scores = tf.nn.max_pool( 219 | outputs, 220 | ksize=window_size, 221 | strides=window_size, 222 | padding="SAME", 223 | data_format="NCHW", 224 | ) 225 | scores = tf.math.reduce_mean(scores, axis=[2, 3]) 226 | elif aggregator == 'max': 227 | # Simply select the maximum value across a given channel 228 | scores = tf.math.reduce_max(outputs, axis=[2, 3]) 229 | else: 230 | raise ValueError(f'Unsupported aggregator {aggregator}.') 231 | 232 | if concept_indices is not None: 233 | return scores[:, concept_indices] 234 | return scores 235 | 236 | def build(self, input_shape): 237 | super(ConceptWhiteningLayer, self).build(input_shape) 238 | 239 | # And use the shape to construct all our running variables 240 | assert len(input_shape) in [2, 4], ( 241 | f'Expected input to CW layer to be a rank-2 or rank-4 matrix but ' 242 | f'instead got a tensor with shape {input_shape}.' 
243 | ) 244 | 245 | # We assume channels-first data format 246 | if self.data_format == "channels_first": 247 | self.num_features = input_shape[1] 248 | else: 249 | self.num_features = input_shape[-1] 250 | 251 | # Running mean 252 | self.running_mean = self.add_weight( 253 | name="running_mean", 254 | shape=(self.num_features,), 255 | dtype=tf.float32, 256 | initializer=tf.constant_initializer(0), 257 | # No gradient flow is expected to come into this variable 258 | trainable=False, 259 | ) 260 | 261 | # Running whitened matrix 262 | self.running_wm = self.add_weight( 263 | name="running_wm", 264 | shape=(self.num_features, self.num_features), 265 | dtype=tf.float32, 266 | initializer=tf.keras.initializers.Identity(), 267 | # No gradient flow is expected to come into this variable 268 | trainable=False, 269 | ) 270 | 271 | # Running rotation matrix 272 | self.running_rot = self.add_weight( 273 | name="running_rot", 274 | shape=(self.num_features, self.num_features), 275 | dtype=tf.float32, 276 | initializer=tf.keras.initializers.Identity(), 277 | # No gradient flow is expected to come into this variable 278 | trainable=False, 279 | ) 280 | 281 | # Sum of Gradient matrix 282 | self.sum_G = self.add_weight( 283 | name="sum_G", 284 | shape=(self.num_features, self.num_features), 285 | dtype=tf.float32, 286 | initializer=tf.keras.initializers.Zeros(), 287 | # No gradient flow is expected to come into this variable 288 | trainable=False, 289 | ) 290 | 291 | # Counter of gradients for each feature 292 | self.counter = self.add_weight( 293 | name="counter", 294 | shape=(self.num_features,), 295 | dtype=tf.float32, 296 | initializer=tf.constant_initializer(0.001), 297 | # No gradient flow is expected to come into this variable 298 | trainable=False, 299 | ) 300 | 301 | def update_rotation_matrix(self, concept_groups, index_map=lambda x: x): 302 | """ 303 | Update the rotation matrix R using the accumulated gradient G. 304 | """ 305 | 306 | # Updating the gradient matrix, using the concept datasets and their 307 | # aligned maps 308 | for i, concept_samples in enumerate(concept_groups): 309 | 310 | samples_shape = tf.shape(concept_samples) 311 | if len(samples_shape) == 2: 312 | # Add trivial dimensions to make it 4D so that it works as 313 | # it does with image data. We will undo this at the end 314 | concept_samples = tf.reshape( 315 | concept_samples, 316 | tf.concat( 317 | [samples_shape[0:1], samples_shape[1:], [1, 1]], 318 | axis=0, 319 | ), 320 | ) 321 | if ( 322 | (self.data_format == "channels_last") and 323 | (len(samples_shape) != 2) 324 | ): 325 | # Then we will transpose to make things simpler so that 326 | # downstream we can always assume it is channels first 327 | # NHWC -> NCHW 328 | concept_samples = tf.transpose( 329 | concept_samples, 330 | perm=[0, 3, 1, 2], 331 | ) 332 | 333 | # Produce the whitened only activations of this concept group 334 | X_hat = self._compute_whitened_activations( 335 | concept_samples, 336 | rotate=False, 337 | training=False, 338 | ) 339 | 340 | # Determine which feature map to use for this concept group 341 | feature_idx = index_map(i) 342 | 343 | # And update the gradient by performing an accumulation using the 344 | # requested activation mode 345 | if self.activation_mode == 'mean': 346 | grad_col = -tf.math.reduce_mean( 347 | tf.math.reduce_mean(X_hat, axis=[2, 3]), 348 | axis=0, 349 | ) 350 | 351 | self.sum_G[feature_idx, :].assign( 352 | ( 353 | (grad_col * self.momentum) 354 | ) + (1. 
- self.momentum) * self.sum_G[feature_idx, :] 355 | ) 356 | 357 | elif self.activation_mode == 'max_pool_mean': 358 | X_test_nchw = tf.linalg.einsum( 359 | 'bchw,dc->bdhw', 360 | X_hat, 361 | self.running_rot, 362 | ) 363 | 364 | # Move to NHWC as tf.nn.max_pool_with_argmax only supports 365 | # channels last format 366 | X_test_nhwc = tf.transpose(X_test_nchw, perm=[0, 2, 3, 1]) 367 | 368 | window_size = min( 369 | 2, 370 | X_hat.shape[-1], 371 | X_hat.shape[-2], 372 | ) 373 | maxpool_value, max_indices = tf.nn.max_pool_with_argmax( 374 | X_test_nhwc, 375 | ksize=window_size, 376 | strides=window_size, 377 | padding='SAME', 378 | data_format='NHWC', 379 | ) 380 | X_test_unpool = _max_unpooling_2d( 381 | maxpool_value, 382 | max_indices, 383 | pool_size=window_size, 384 | strides=window_size, 385 | padding="SAME", 386 | ) 387 | 388 | # And reshape to original NCHW format from NHWC format 389 | X_test_unpool = tf.transpose(X_test_unpool, perm=[0, 3, 1, 2]) 390 | 391 | # Finally, compute the actual gradient and update or running 392 | # matrix 393 | maxpool_mask = tf.cast( 394 | tf.math.equal(X_test_nchw, X_test_unpool), 395 | tf.float32, 396 | ) 397 | # Average only over those elements selected by the max pool 398 | # operator. 399 | grad = ( 400 | tf.reduce_sum(X_hat * maxpool_mask, axis=(2, 3)) / 401 | tf.reduce_sum(maxpool_mask, axis=(2, 3)) 402 | ) 403 | 404 | # And average over all samples 405 | grad = -tf.reduce_mean(grad, axis=0) 406 | self.sum_G[feature_idx, :].assign( 407 | self.momentum * grad + 408 | (1. - self.momentum) * self.sum_G[feature_idx, :] 409 | ) 410 | 411 | else: 412 | raise NotImplementedError( 413 | "Currently supporting only the max_pool_mean and mean " 414 | "options" 415 | ) 416 | 417 | # And increase the counter 418 | self.counter[feature_idx].assign(self.counter[feature_idx] + 1) 419 | 420 | # Time to update our rotation matrix 421 | # Original CW paper does this counter division so keeping it for now 422 | # for backwards compatability 423 | G = self.sum_G / tf.expand_dims(self.counter, axis=-1) 424 | 425 | # Original CW code uses range(2) for some bizzare reason 426 | for _ in range(2): 427 | tau = self.initial_tau # learning rate in Cayley transform 428 | alpha = self.initial_alpha 429 | beta = self.initial_beta 430 | 431 | # Compute: GR^T - RG^T 432 | A = tf.einsum('in,jn->ij', G, self.running_rot) - tf.einsum( 433 | 'in,jn->ij', 434 | self.running_rot, 435 | G, 436 | ) 437 | I = tf.eye(self.num_features) 438 | dF_0 = -0.5 * tf.math.reduce_sum(A * A) 439 | 440 | # Computing tau using Algorithm 1 in 441 | # https://link.springer.com/article/10.1007/s10107-012-0584-1 442 | # Binary search for appropriate learning rate 443 | count = 0 444 | while count < self.max_tau_iterations: 445 | Q = tf.linalg.matmul( 446 | tf.linalg.inv(I + 0.5 * tau * A), 447 | I - 0.5 * tau * A, 448 | ) 449 | Y_tau = tf.linalg.matmul(Q, self.running_rot) 450 | F_X = tf.math.reduce_sum(G * self.running_rot) 451 | F_Y_tau = tf.math.reduce_sum(G * Y_tau) 452 | dF_tau = tf.linalg.matmul( 453 | tf.einsum( 454 | 'ni,nj->ij', 455 | G, 456 | tf.linalg.inv(I + 0.5 * tau * A), 457 | ), 458 | tf.linalg.matmul(A, 0.5 * (self.running_rot + Y_tau)), 459 | ) 460 | dF_tau = -tf.linalg.trace(dF_tau) 461 | 462 | if F_Y_tau > F_X + self.c1 * tau * dF_0 + 1e-18: 463 | beta = tau 464 | tau = (beta + alpha) / 2 465 | elif dF_tau + 1e-18 < self.c2 * dF_0: 466 | alpha = tau 467 | tau = (beta + alpha) / 2 468 | else: 469 | break 470 | count += 1 471 | 472 | if count > self.max_tau_iterations: 473 | 
print("------------------update fail----------------------") 474 | print(F_Y_tau, F_X + self.c1 * tau * dF_0) 475 | print(dF_tau, self.c2 * dF_0) 476 | print("---------------------------------------------------") 477 | break 478 | 479 | # Using the un-numbered equation in the Concept Whitening paper 480 | # Lines 12-13 of Algorithm 2 in CW paper 481 | Q = tf.linalg.matmul( 482 | tf.linalg.matmul( 483 | tf.linalg.inv(I + 0.5 * tau * A), 484 | I - 0.5 * tau * A, 485 | ), 486 | self.running_rot, 487 | ) 488 | # And update the rotation matrix as well as reset the counters 489 | self.running_rot.assign(Q) 490 | 491 | self.counter.assign(tf.ones((self.num_features,)) * 0.001) 492 | 493 | def call(self, inputs, training=False): 494 | input_shape = tf.shape(inputs) 495 | static_imputs_shape = inputs.shape 496 | if len(static_imputs_shape) == 2: 497 | # Add trivial dimensions to make it 4D so that it works as 498 | # it does with image data. We will undo this at the end 499 | inputs = tf.reshape( 500 | inputs, 501 | tf.concat( 502 | [input_shape[0:1], input_shape[1:], [1, 1]], 503 | axis=0, 504 | ), 505 | ) 506 | if (self.data_format == "channels_last") and ( 507 | len(static_imputs_shape) != 2 508 | ): 509 | # Then we will transpose to make things simpler so that downstream 510 | # we can always assume it is channels first 511 | # NHWC -> NCHW 512 | inputs = tf.transpose( 513 | inputs, 514 | perm=[0, 3, 1, 2], 515 | ) 516 | 517 | result = tf.linalg.einsum( 518 | 'bchw,dc->bdhw', 519 | self._compute_whitened_activations(inputs, training), 520 | self.running_rot, 521 | ) 522 | 523 | if len(static_imputs_shape) == 2: 524 | # Then let's get it back to its original shape 525 | result = tf.reshape(result, input_shape) 526 | if (self.data_format == "channels_last") and ( 527 | len(static_imputs_shape) != 2 528 | ): 529 | # Then let's move it back to channels last 530 | # NCHW -> NHWC 531 | result = tf.transpose( 532 | result, 533 | perm=[0, 2, 3, 1], 534 | ) 535 | return result 536 | 537 | def _compute_whitened_activations(self, X, training, rotate=False): 538 | ''' 539 | Implements Algorithm 1 from https://arxiv.org/pdf/1904.03441.pdf 540 | Also, updates the running mean and whitening matrices using a moving 541 | average 542 | 543 | Assumes the input X is in the format of (N, C, H, W) 544 | ''' 545 | 546 | input_shape = tf.shape(X) 547 | 548 | # Flip first two dimensions in order to obtain a (C, N, H, W) tensor 549 | x = tf.transpose(X, perm=[1, 0, 2, 3]) 550 | 551 | # Change (C, N, H, W) to (D, NxHxW) 552 | cnhw_shape = tf.shape(x) 553 | x = tf.reshape(x, [self.num_features, -1]) 554 | m = tf.shape(x)[-1] 555 | 556 | if training: 557 | # Calculate mini-batch mean 558 | # Line 4 of Algorithm 1 559 | mean = tf.reduce_mean(x, axis=-1, keepdims=True) 560 | 561 | # Calculate centered activation 562 | # Line 5 of Algorithm 1 563 | xc = x - mean 564 | 565 | # Calculate covariance matrix 566 | # Line 6 of Algorithm 1 567 | I = tf.eye(self.num_features) 568 | sigma = self.eps * I 569 | sigma += 1./tf.cast(m, tf.float32) * tf.linalg.matmul( 570 | xc, 571 | tf.transpose(xc) 572 | ) 573 | # Calculate trace-normalized covariance matrix using eqn. 
(4) in 574 | # the paper 575 | # Line 7 of Algorithm 1 576 | sigma_tr_rec = tf.expand_dims( 577 | tf.math.reciprocal(tf.linalg.trace(sigma)), 578 | axis=-1, 579 | ) 580 | sigma_N = sigma * sigma_tr_rec 581 | 582 | # Original CW code: (they do not use the actual trace and this can 583 | # leas to instability during training) 584 | # sigma_tr_rec = tf.reduce_sum(sigma, axis=(0, 1), keepdims=True) 585 | # sigma_N = sigma * sigma_tr_rec 586 | 587 | # Calculate whitening matrix 588 | # Lines 8-11 of Algorithm 1 589 | P = tf.eye(self.num_features) 590 | for _ in range(self.T): 591 | P_cubed = tf.linalg.matmul( 592 | tf.linalg.matmul(P, P), 593 | P, 594 | ) 595 | P = 1.5 * P - ( 596 | 0.5 * tf.linalg.matmul(P_cubed, sigma_N) 597 | ) 598 | 599 | # Line 12 of Algorithm 1 600 | wm = tf.math.multiply(P, tf.math.sqrt(sigma_tr_rec)) 601 | 602 | # Update the running mean and whitening matrix 603 | self.running_mean.assign( 604 | self.momentum * tf.squeeze(mean, axis=-1) + 605 | (1. - self.momentum) * self.running_mean 606 | ) 607 | self.running_wm.assign( 608 | self.momentum * wm + 609 | (1. - self.momentum) * self.running_wm 610 | ) 611 | 612 | else: 613 | xc = x - tf.expand_dims(self.running_mean, axis=-1) 614 | wm = self.running_wm 615 | 616 | # Calculate whitening output 617 | xn = tf.linalg.matmul(wm, xc) 618 | 619 | # And, if requested, apply the rotation while it is in this format 620 | if rotate: 621 | xn = tf.linalg.einsum( 622 | 'bchw,dc->bdhw', 623 | xn, 624 | self.running_rot, 625 | ) 626 | 627 | # Transform back to original shape of (N, C, H, W) 628 | return tf.transpose(tf.reshape(xn, cnhw_shape), perm=[1, 0, 2, 3]) 629 | 630 | def get_config(self): 631 | """ 632 | Serialization function. 633 | """ 634 | result = super(ConceptWhiteningLayer, self).get_config() 635 | result.update(dict( 636 | T=self.T, 637 | eps=self.eps, 638 | momentum=self.momentum, 639 | activation_mode=self.activation_mode, 640 | c1=self.c1, 641 | c2=self.c2, 642 | max_tau_iterations=self.max_tau_iterations, 643 | )) 644 | return result 645 | -------------------------------------------------------------------------------- /concepts_xai/methods/OCACE/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmitrykazhdan/concept-based-xai/a86c6abd609ec2ef0cf56cb556e3f834a0962bdc/concepts_xai/methods/OCACE/__init__.py -------------------------------------------------------------------------------- /concepts_xai/methods/OCACE/topicModel.py: -------------------------------------------------------------------------------- 1 | import concepts_xai.evaluation.metrics.completeness as completeness 2 | import numpy as np 3 | import tensorflow as tf 4 | import tensorflow.keras.backend as K 5 | import scipy 6 | 7 | ''' 8 | Re-implementation of the "On Completeness-aware Concept-Based Explanations in 9 | Deep Neural Networks": 10 | 11 | 1) See https://arxiv.org/abs/1910.07969 for the original paper 12 | 13 | 2) See https://github.com/chihkuanyeh/concept_exp for the original paper 14 | implementation 15 | 16 | ''' 17 | 18 | 19 | class TopicModel(tf.keras.Model): 20 | """Base class of a topic model.""" 21 | 22 | def __init__( 23 | self, 24 | concepts_to_labels_model, 25 | n_channels, 26 | n_concepts, 27 | g_model=None, 28 | threshold=0.5, 29 | loss_fn=tf.keras.losses.sparse_categorical_crossentropy, 30 | top_k=32, 31 | lambda1=0.1, 32 | lambda2=0.1, 33 | seed=None, 34 | eps=1e-5, 35 | data_format="channels_last", 36 | allow_gradient_flow_to_c2l=False, 37 | 
acc_metric=None, 38 | initial_topic_vector=None, 39 | **kwargs, 40 | ): 41 | super(TopicModel, self).__init__(**kwargs) 42 | 43 | initializer = tf.keras.initializers.RandomUniform( 44 | minval=-0.5, 45 | maxval=0.5, 46 | seed=seed, 47 | ) 48 | 49 | # Initialize our topic vector tensor which we will learn 50 | # as part of our training 51 | if initial_topic_vector is not None: 52 | self.topic_vector = self.add_weight( 53 | name="topic_vector", 54 | shape=(n_channels, n_concepts), 55 | dtype=tf.float32, 56 | initializer=lambda *args, **kwargs: initial_topic_vector, 57 | trainable=True, 58 | ) 59 | else: 60 | self.topic_vector = self.add_weight( 61 | name="topic_vector", 62 | shape=(n_channels, n_concepts), 63 | dtype=tf.float32, 64 | initializer=tf.keras.initializers.RandomUniform( 65 | minval=-0.5, 66 | maxval=0.5, 67 | seed=seed, 68 | ), 69 | trainable=True, 70 | ) 71 | 72 | # Initialize the g model which will be in charge of reconstructing 73 | # the model latent activations from the concept scores alone 74 | self.g_model = g_model 75 | if self.g_model is None: 76 | self.g_model = completeness._get_default_model( 77 | num_concepts=n_concepts, 78 | num_hidden_acts=n_channels, 79 | ) 80 | 81 | # Set the concept-to-label predictor model 82 | self.concepts_to_labels_model = concepts_to_labels_model 83 | 84 | # Set remaining model hyperparams 85 | self.eps = eps 86 | self.threshold = threshold 87 | self.n_concepts = n_concepts 88 | self.loss_fn = loss_fn 89 | self.top_k = top_k 90 | self.lambda1 = lambda1 91 | self.lambda2 = lambda2 92 | self.n_channels = n_channels 93 | self.allow_gradient_flow_to_c2l = allow_gradient_flow_to_c2l 94 | assert data_format in ["channels_last", "channels_first"], ( 95 | f'Expected data format to be either "channels_last" or ' 96 | f'"channels_first" however we obtained "{data_format}".' 97 | ) 98 | if data_format == "channels_last": 99 | self._channel_axis = -1 100 | else: 101 | raise ValueError( 102 | 'Currently we only support "channels_last" data_format' 103 | ) 104 | 105 | self.metric_names = ["loss", "mean_sim", "accuracy"] 106 | self.metrics_dict = { 107 | name: tf.keras.metrics.Mean(name=name) 108 | for name in self.metric_names 109 | } 110 | self._acc_metric = ( 111 | acc_metric or tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1) 112 | ) 113 | 114 | @property 115 | def metrics(self): 116 | return [self.metrics_dict[name] for name in self.metric_names] 117 | 118 | def update_metrics(self, losses): 119 | for (loss_name, loss) in losses: 120 | self.metrics_dict[loss_name].update_state(loss) 121 | 122 | def concept_scores(self, x, compute_reg_terms=False): 123 | # Compute the concept representation by first normalizing across the 124 | # channel dimension both the concept vectors and the input 125 | assert x.shape[self._channel_axis] == self.n_channels, ( 126 | f'Expected input to have {self.n_channels} elements in its ' 127 | f'channels axis (defined as axis {self._channel_axis}). ' 128 | f'Instead, we found the input to have shape {x.shape}.' 
129 | ) 130 | 131 | x_norm = tf.math.l2_normalize(x, axis=self._channel_axis) 132 | topic_vector_norm = tf.math.l2_normalize(self.topic_vector, axis=0) 133 | # Compute the concept probability scores 134 | topic_prob = K.dot(x, topic_vector_norm) 135 | topic_prob_norm = K.dot(x_norm, topic_vector_norm) 136 | 137 | # Threshold them if they are below the given threshold value 138 | if self.threshold is not None: 139 | topic_prob = topic_prob * tf.cast( 140 | (topic_prob_norm > self.threshold), 141 | tf.float32, 142 | ) 143 | topic_prob_sum = tf.reduce_sum( 144 | topic_prob, 145 | axis=self._channel_axis, 146 | keepdims=True, 147 | ) 148 | # And normalize the actual scores 149 | topic_prob = topic_prob / (topic_prob_sum + self.eps) 150 | if not compute_reg_terms: 151 | return topic_prob 152 | 153 | # Compute the regularization loss terms 154 | reshaped_topic_probs = tf.transpose( 155 | tf.reshape(topic_prob_norm, (-1, self.n_concepts)) 156 | ) 157 | reg_loss_1 = tf.reduce_mean( 158 | tf.nn.top_k( 159 | reshaped_topic_probs, 160 | k=tf.math.minimum( 161 | self.top_k, 162 | tf.shape(reshaped_topic_probs)[-1] 163 | ), 164 | sorted=True, 165 | ).values 166 | ) 167 | 168 | reg_loss_2 = tf.reduce_mean( 169 | K.dot(tf.transpose(topic_vector_norm), topic_vector_norm) - 170 | tf.eye(self.n_concepts) 171 | ) 172 | return topic_prob, reg_loss_1, reg_loss_2 173 | 174 | def _compute_loss(self, x, y_true, training): 175 | # First, compute the concept scores for the given samples 176 | scores, reg_loss_1, reg_loss_2 = self.concept_scores( 177 | x, 178 | compute_reg_terms=True 179 | ) 180 | 181 | # Then predict the labels after reconstructing the activations 182 | # from the scores via the g model 183 | y_pred = self.concepts_to_labels_model( 184 | self.g_model(scores), 185 | training=training, 186 | ) 187 | 188 | # Compute the task loss 189 | log_prob_loss = tf.reduce_mean(self.loss_fn(y_true, y_pred)) 190 | 191 | # And include them into the total loss 192 | total_loss = ( 193 | log_prob_loss - 194 | self.lambda1 * reg_loss_1 + 195 | self.lambda2 * reg_loss_2 196 | ) 197 | 198 | # Compute the accuracy metric to track 199 | self._acc_metric.update_state(y_true, y_pred) 200 | return total_loss, reg_loss_1, self._acc_metric.result() 201 | 202 | def train_step(self, inputs): 203 | x, y = inputs 204 | 205 | # We first need to make sure that we set our concepts to labels model 206 | # so that we do not optimize over its parameters 207 | prev_trainable = self.concepts_to_labels_model.trainable 208 | if not self.allow_gradient_flow_to_c2l: 209 | self.concepts_to_labels_model.trainable = False 210 | with tf.GradientTape() as tape: 211 | loss, mean_sim, acc = self._compute_loss( 212 | x, 213 | y, 214 | # Only train the decoder if requested by the user 215 | training=self.allow_gradient_flow_to_c2l, 216 | ) 217 | 218 | gradients = tape.gradient(loss, self.trainable_variables) 219 | self.optimizer.apply_gradients( 220 | zip(gradients, self.trainable_variables) 221 | ) 222 | self.update_metrics([ 223 | ("loss", loss), 224 | ("mean_sim", mean_sim), 225 | ("accuracy", acc), 226 | ]) 227 | 228 | # And recover the previous step of the concept to labels model 229 | self.concepts_to_labels_model.trainable = prev_trainable 230 | 231 | return { 232 | name: self.metrics_dict[name].result() 233 | for name in self.metric_names 234 | } 235 | 236 | def test_step(self, inputs): 237 | x, y = inputs 238 | loss, mean_sim, acc = self._compute_loss(x, y, training=False) 239 | self.update_metrics([ 240 | ("loss", loss), 241 | ("mean_sim", 
mean_sim), 242 | ("accuracy", acc) 243 | ]) 244 | 245 | return { 246 | name: self.metrics_dict[name].result() 247 | for name in self.metric_names 248 | } 249 | 250 | def call(self, x, **kwargs): 251 | concept_scores = self.concept_scores(x) 252 | predicted_labels = self.concepts_to_labels_model( 253 | self.g_model(concept_scores), 254 | training=False, 255 | ) 256 | return predicted_labels, concept_scores 257 | 258 | -------------------------------------------------------------------------------- /concepts_xai/methods/OCACE/visualisation.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | import numpy as np 4 | 5 | # Update this to change font sizes 6 | plt.rcParams.update({'font.size': 18}) 7 | 8 | 9 | def visualize_nearest_neighbours( 10 | f_data, 11 | x_data, 12 | topic_model, 13 | n_prototypes=5, 14 | channels_axis=-1, 15 | ): 16 | 17 | topic_vec = topic_model.topic_vector.numpy() 18 | fig, axes = plt.subplots( 19 | topic_vec.shape[1], 20 | n_prototypes, 21 | figsize=(20, 20), 22 | ) 23 | topic_prob = topic_model.concept_scores(f_data).numpy() 24 | if len(f_data.shape) == 2: 25 | # Then let's trivially expand its dimensions so that we always 26 | # work with 4D (n_samples, 1, 1, n_channels) representations 27 | f_data = np.expand_dims( 28 | np.expand_dims(f_data, axis=1), 29 | axis=1, 30 | ) 31 | topic_prob = np.expand_dims( 32 | np.expand_dims(topic_prob, axis=1), 33 | axis=1, 34 | ) 35 | dim1_size, dim2_size = f_data.shape[1], f_data.shape[2] 36 | dim1_factor = int(x_data.shape[1] / dim1_size) 37 | dim2_factor = int(x_data.shape[2] / dim2_size) 38 | 39 | for i in range(topic_prob.shape[channels_axis]): 40 | # Then topic_prob is an array of shape 41 | # (n_samples, kernel_w, kernel_h, n_concepts) 42 | # Here, we retrieve the indices of the n_prototypes largest elements 43 | ind = np.argsort(-topic_prob[:, :, :, i].flatten())[:n_prototypes] 44 | 45 | for jc, j in enumerate(ind): 46 | # Retrieve the dim0 index 47 | dim0 = int(np.floor(j/(dim1_size*dim2_size))) 48 | # Retrieve the dim1 index 49 | dim1 = int((j-dim0 * (dim1_size*dim2_size))/dim2_size) 50 | # Retrieve the dim2 index 51 | dim2 = int((j-dim0 * (dim1_size*dim2_size)) % dim2_size) 52 | 53 | # Retrieve the subpart of the image corresponding to the 54 | # activated patch 55 | dim1_start = dim1_factor * dim1 56 | dim2_start = dim2_factor * dim2 57 | img = x_data[dim0, :, :, :] 58 | img = img[ 59 | dim1_start: dim1_start + dim1_factor, 60 | dim2_start: dim2_start + dim2_factor, 61 | :, 62 | ] 63 | 64 | # Plot the image 65 | ax = axes[i, jc] 66 | 67 | if img.shape[-1] == 1: 68 | ax.imshow(img, cmap='gray') 69 | else: 70 | ax.imshow(img) 71 | ax.yaxis.grid(False) 72 | ax.get_yaxis().set_visible(False) 73 | ax.get_xaxis().set_visible(False) 74 | 75 | fig.tight_layout() 76 | plt.show() 77 | -------------------------------------------------------------------------------- /concepts_xai/methods/SENN/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmitrykazhdan/concept-based-xai/a86c6abd609ec2ef0cf56cb556e3f834a0962bdc/concepts_xai/methods/SENN/__init__.py -------------------------------------------------------------------------------- /concepts_xai/methods/SENN/aggregators.py: -------------------------------------------------------------------------------- 1 | """ 2 | Library of some default aggregator functions that one can use in SENN.
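As a note on the expected interface (inferred from the shape annotations in the code below): each aggregator takes `thetas` of shape (batch, n_outputs, n_concepts) and `concepts` of shape (batch, n_concepts), so that, for instance, `multiclass_additive_aggregator` reduces to a batched matrix-vector product returning (batch, n_outputs) scores.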
3 | """ 4 | 5 | import tensorflow as tf 6 | 7 | 8 | def multiclass_additive_aggregator(thetas, concepts): 9 | # Returns output shape (batch, n_outputs) 10 | return tf.squeeze( 11 | # (batch, n_outputs, 1) 12 | tf.linalg.matmul( 13 | # (batch, n_outputs, n_concepts) 14 | thetas, 15 | # (batch, n_concepts, 1) 16 | tf.expand_dims(concepts, axis=-1), 17 | ), 18 | axis=-1 19 | ) 20 | 21 | 22 | def scalar_additive_aggregator(thetas, concepts): 23 | # Returns output shape (batch) 24 | return tf.squeeze( 25 | multiclass_additive_aggregator(thetas=thetas, concepts=concepts), 26 | axis=-1, 27 | ) 28 | 29 | 30 | def softmax_additive_aggregator(thetas, concepts): 31 | # Returns output shape (batch, n_outputs) 32 | return tf.nn.softmax( 33 | multiclass_additive_aggregator(thetas=thetas, concepts=concepts), 34 | axis=-1, 35 | ) 36 | -------------------------------------------------------------------------------- /concepts_xai/methods/SENN/base_senn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tensorflow implementation of Self-Explaining Neural Networks (SENN) by 3 | Alvarez-Melis and Jaakkola (NeurIPS 2018) [1]. 4 | 5 | [1] https://papers.nips.cc/paper/2018/file/3e9f0fc9b2f89e043bc6233994dfcf76-Paper.pdf 6 | 7 | """ 8 | 9 | import tensorflow as tf 10 | from tensorflow.python.keras.engine import data_adapter 11 | 12 | 13 | class SelfExplainingNN(tf.keras.Model): 14 | """ 15 | Main class implementation for Self-Explaining Neural Networks (SENN)s as 16 | defined and described by Alvarez-Melis and Jaakkola (NeurIPS 2018). 17 | """ 18 | 19 | def __init__( 20 | self, 21 | encoder_model, 22 | coefficient_model, 23 | aggregator_fn, 24 | task_loss_fn, 25 | regularization_strength=1e-1, # "lambda" regulatizer parameter 26 | sparsity_strength=2e-5, # "zeta" autoencoder parameter 27 | reconstruction_loss_fn=None, 28 | task_loss_weight=1, 29 | robustness_norm_fn=lambda x: tf.norm(x, ord='fro', axis=[-2, -1]), 30 | metrics=None, 31 | **kwargs, 32 | ): 33 | """ 34 | Constructs a SENN Keras Model ready for training. 35 | 36 | When using this model for prediction, it will return a tuple 37 | (labels, (concept_vectors, thetas)) indicating the predicted label 38 | probabilities as well as the concepts vectors, together with their 39 | linear importance weights defined by thetas. 
40 | """ 41 | super(SelfExplainingNN, self).__init__(**kwargs) 42 | self.task_loss_weight = task_loss_weight 43 | self.encoder_model = encoder_model 44 | self.coefficient_model = coefficient_model 45 | self.aggregator_fn = aggregator_fn 46 | self.task_loss_fn = task_loss_fn 47 | self.regularization_strength = regularization_strength 48 | self.sparsity_strength = sparsity_strength 49 | self.reconstruction_loss_fn = reconstruction_loss_fn 50 | self.robustness_norm_fn = robustness_norm_fn 51 | if self.reconstruction_loss_fn is not None: 52 | self.metrics_dict = { 53 | "loss": tf.keras.metrics.Mean( 54 | name="loss" 55 | ), 56 | "task_loss": tf.keras.metrics.Mean( 57 | name="task_loss" 58 | ), 59 | "robustness_loss": tf.keras.metrics.Mean( 60 | name="robustness_loss" 61 | ), 62 | "reconstruction_loss": tf.keras.metrics.Mean( 63 | name="reconstruction_loss" 64 | ), 65 | } 66 | else: 67 | self.metrics_dict = { 68 | "loss": tf.keras.metrics.Mean( 69 | name="loss" 70 | ), 71 | "task_loss": tf.keras.metrics.Mean( 72 | name="task_loss" 73 | ), 74 | "robustness_loss": tf.keras.metrics.Mean( 75 | name="robustness_loss" 76 | ), 77 | } 78 | self._extra_metrics = [] 79 | for metric in (metrics or []): 80 | if isinstance(metric, (tuple, list)): 81 | if len(metric) != 2: 82 | raise ValueError( 83 | f"Expected metrics to be a list of tuples " 84 | f"(name, metric) or TF metric objects. Instead we " 85 | f"received {metric}." 86 | ) 87 | name, metric = metric 88 | else: 89 | name = metric.name 90 | self.metrics_dict[name] = metric 91 | self._extra_metrics.append((name, metric)) 92 | 93 | @property 94 | def metrics(self): 95 | return [ 96 | self.metrics_dict[name] for name in self.metrics_dict.keys() 97 | ] 98 | 99 | def update_metrics(self, losses): 100 | for (name, loss) in losses: 101 | self.metrics_dict[name].update_state(loss) 102 | 103 | def call(self, inputs): 104 | # First compute our concepts 105 | concepts = self.encoder_model(inputs) # (batch, n_concepts) 106 | # Then all of the theta weights which we will use for our concepts 107 | thetas = self.coefficient_model(inputs) # (batch, n_outputs, n_concepts) 108 | if len(thetas.shape) < 3: 109 | # Then the number of classes/outputs is 1 so let's make it explicit 110 | thetas = tf.expand_dims(thetas, axis=1) 111 | # Make sure the dimensions match 112 | tf.debugging.assert_equal( 113 | tf.shape(concepts)[-1], 114 | tf.shape(thetas)[-1], 115 | message=( 116 | "The last dimension of the returned concept tensor and the " 117 | "theta tensor must be the same (they both should correspond to " 118 | "the number of concepts used in the SENN). We found " 119 | f"{tf.shape(concepts)[-1]} entries in concepts.shape[-1] while " 120 | f"{tf.shape(thetas)[-1]} entries in thetas.shape[-1]." 
121 | ) 122 | ) 123 | predictions = self.aggregator_fn( 124 | thetas=thetas, 125 | concepts=concepts, 126 | ) 127 | return predictions, (concepts, thetas) 128 | 129 | def train_step(self, inputs): 130 | # This will allow us to compute the task specific loss 131 | inputs = data_adapter.expand_1d(inputs) 132 | x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(inputs) 133 | with tf.GradientTape() as outter_tape: 134 | total_loss = 0 135 | # Do a nesting of tapes as we will need to compute gradients and 136 | # jacobian in order for one 137 | with tf.GradientTape( 138 | persistent=True, 139 | watch_accessed_variables=False, 140 | ) as inner_tape: 141 | # First compute predictions and their corresponding explanation 142 | inner_tape.watch(x) 143 | preds, (concepts, thetas) = self(x) 144 | # This gives us the task specific loss 145 | task_loss = self.task_loss_fn(y, preds) 146 | total_loss += task_loss * self.task_loss_weight 147 | 148 | if self.reconstruction_loss_fn is not None: 149 | # Now compute the encoder reconstruction loss similarly 150 | reconstruction_loss = self.reconstruction_loss_fn( 151 | x, 152 | concepts, 153 | ) 154 | total_loss += self.sparsity_strength * reconstruction_loss 155 | 156 | # Finally, time to add the robustness loss. This is the trickiest 157 | # and most expensive one as it requires the computation of a 158 | # jacobian 159 | # Gradient of f(x) with respect to x should have shape 160 | # (batch, n_outputs, ). 161 | # Note that we use a jacobian computation rather than a gradient 162 | # computation (as done in the paper) as we support general 163 | # multi-dimensional outputs 164 | if len(preds.shape) < 3: 165 | # Jacobian requires a 2D input 166 | preds = tf.expand_dims(preds, axis=1) 167 | f_grad_x = inner_tape.batch_jacobian(preds, x) 168 | 169 | # Reshape to (batch, n_outputs, flatten_x_shape) 170 | f_grad_x = tf.reshape( 171 | f_grad_x, 172 | [tf.shape(x)[0], tf.shape(preds)[-1], -1] 173 | ) 174 | 175 | # Jacobian of h(x) with respect to x should have shape 176 | # (batch, n_concepts, ) 177 | h_jacobian_x = inner_tape.batch_jacobian( 178 | concepts, 179 | x, 180 | ) 181 | # Reshape to (batch, n_concepts, flatten_x_shape) 182 | h_jacobian_x = tf.reshape( 183 | h_jacobian_x, 184 | [tf.shape(x)[0], tf.shape(concepts)[-1], -1] 185 | ) 186 | robustness_loss = self.robustness_norm_fn( 187 | f_grad_x - tf.matmul( 188 | # No need to transpose thetas as they already have the 189 | # number of outputs as its first non-batch dimension (i.e., 190 | # its shape is (batch, n_outputs, n_concepts)) 191 | thetas, 192 | # h_jacobian_x shape: (batch, n_concepts, flatten_x_shape) 193 | h_jacobian_x, 194 | ) # matmul shape: (batch, n_outputs, flatten_x_shape) 195 | ) 196 | robustness_loss = tf.math.reduce_mean(robustness_loss) 197 | total_loss += self.regularization_strength * robustness_loss 198 | 199 | # Compute gradients and proceed with SGD 200 | gradients = outter_tape.gradient(total_loss, self.trainable_variables) 201 | self.optimizer.apply_gradients( 202 | zip(gradients, self.trainable_variables) 203 | ) 204 | 205 | # And update all of our metrics 206 | self.update_metrics([ 207 | ("loss", total_loss), 208 | ("task_loss", task_loss), 209 | ("robustness_loss", robustness_loss), 210 | ]) 211 | if self.reconstruction_loss_fn is not None: 212 | self.update_metrics([ 213 | ("reconstruction_loss", reconstruction_loss), 214 | ]) 215 | for (name, metric) in self._extra_metrics: 216 | self.metrics_dict[name].update_state(y, preds, sample_weight) 217 | return { 218 | 
name: val.result() 219 | for name, val in self.metrics_dict.items() 220 | } 221 | 222 | def test_step(self, inputs): 223 | # This will allow us to compute the task specific loss 224 | inputs = data_adapter.expand_1d(inputs) 225 | x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(inputs) 226 | with tf.GradientTape() as outter_tape: 227 | total_loss = 0 228 | # Do a nesting of tapes as we will need to compute gradients and 229 | # jacobian in order for one 230 | with tf.GradientTape( 231 | persistent=True, 232 | watch_accessed_variables=False, 233 | ) as inner_tape: 234 | # First compute predictions and their corresponding explanation 235 | inner_tape.watch(x) 236 | preds, (concepts, thetas) = self(x) 237 | # This gives us the task specific loss 238 | task_loss = self.task_loss_fn(y, preds) 239 | total_loss += task_loss * self.task_loss_weight 240 | 241 | if self.reconstruction_loss_fn is not None: 242 | # Now compute the encoder reconstruction loss similarly 243 | reconstruction_loss = self.reconstruction_loss_fn( 244 | x, 245 | concepts, 246 | ) 247 | total_loss += self.sparsity_strength * reconstruction_loss 248 | 249 | # Finally, time to add the robustness loss. This is the trickiest 250 | # and most expensive one as it requires the computation of a 251 | # jacobian 252 | # Gradient of f(x) with respect to x should have shape 253 | # (batch, n_outputs, ). 254 | # Note that we use a jacobian computation rather than a gradient 255 | # computation (as done in the paper) as we support general 256 | # multi-dimensional outputs 257 | if len(preds.shape) < 3: 258 | # Jacobian requires a 2D input 259 | preds = tf.expand_dims(preds, axis=1) 260 | f_grad_x = inner_tape.batch_jacobian(preds, x) 261 | # Reshape to (batch, n_outputs, flatten_x_shape) 262 | f_grad_x = tf.reshape( 263 | f_grad_x, 264 | [tf.shape(x)[0], tf.shape(preds)[-1], -1] 265 | ) 266 | 267 | # Jacobian of h(x) with respect to x should have shape 268 | # (batch, n_concepts, ) 269 | h_jacobian_x = inner_tape.batch_jacobian( 270 | concepts, 271 | x, 272 | ) 273 | # Reshape to (batch, n_concepts, flatten_x_shape) 274 | h_jacobian_x = tf.reshape( 275 | h_jacobian_x, 276 | [tf.shape(x)[0], tf.shape(concepts)[-1], -1] 277 | ) 278 | robustness_loss = self.robustness_norm_fn( 279 | f_grad_x - tf.matmul( 280 | # No need to transpose thetas as they already have the 281 | # number of outputs as its first non-batch dimension (i.e., 282 | # its shape is (batch, n_outputs, n_concepts)) 283 | thetas, 284 | # h_jacobian_x shape: (batch, n_concepts, flatten_x_shape) 285 | h_jacobian_x, 286 | ) # matmul shape: (batch, n_outputs, flatten_x_shape) 287 | ) 288 | robustness_loss = tf.math.reduce_mean(robustness_loss) 289 | total_loss += self.regularization_strength * robustness_loss 290 | 291 | # And update all of our metrics 292 | self.update_metrics([ 293 | ("loss", total_loss), 294 | ("task_loss", task_loss), 295 | ("robustness_loss", robustness_loss), 296 | ]) 297 | if self.reconstruction_loss_fn is not None: 298 | self.update_metrics([ 299 | ("reconstruction_loss", reconstruction_loss), 300 | ]) 301 | for (name, metric) in self._extra_metrics: 302 | self.metrics_dict[name].update_state(y, preds, sample_weight) 303 | return { 304 | name: val.result() 305 | for name, val in self.metrics_dict.items() 306 | } 307 | -------------------------------------------------------------------------------- /concepts_xai/methods/SSCC/SSCClassifier.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, 
abstractmethod 2 | import numpy as np 3 | import os 4 | import tensorflow as tf 5 | 6 | from sklearn.linear_model import LogisticRegression 7 | from sklearn.ensemble import GradientBoostingClassifier 8 | from sklearn.model_selection import train_test_split 9 | from sklearn.metrics import accuracy_score 10 | 11 | from concepts_xai.methods.CME.ItCModel import ItCModel 12 | from concepts_xai.methods.CBM.CBModel import ConceptBottleneckModel 13 | from concepts_xai.evaluation.metrics.accuracy import compute_accuracies 14 | from concepts_xai.methods.VAE.weak_vae import GroupVAEArgmax 15 | from concepts_xai.methods.VAE.betaVAE import BetaVAE 16 | 17 | 18 | ''' 19 | Implementation of the Semi-Supervised Concept Classifier (SSCClassifier) 20 | 21 | Contains implementations of the following: 22 | 1) An abstract SSCClassifier class, specifying the requirements of a 23 | generic SSCC concept classifier 24 | 2) Implementations of the SSCClassifier, serving as wrappers of various 25 | concept-based methods defined in the .methods folder 26 | ''' 27 | 28 | 29 | class SSCClassifier(ABC): 30 | 31 | @abstractmethod 32 | def __init__(self, **kwargs): 33 | ''' 34 | Initialise the SSCClassifier wrapper 35 | :param kwargs: any extra arguments needed to initialise the concept 36 | predictor model 37 | ''' 38 | pass 39 | 40 | @abstractmethod 41 | def fit(self, data_gen_l, data_gen_u): 42 | ''' 43 | Method for training the underlying SSCC model 44 | :param data_gen_l: Tensorflow Dataset, returning triples of points 45 | (x_data, c_data, y_data) for input, concept, and label data, 46 | respectively 47 | :param data_gen_u: Tensorflow Dataset, returning tuples of points 48 | (x_data, y_data). In the fully-unsupervised case (where y-labels 49 | are not available as well), y_data should be set to -1 50 | :return: Trains the underlying SSCC concept labeler 51 | ''' 52 | pass 53 | 54 | @abstractmethod 55 | def predict(self, data_gen_u): 56 | ''' 57 | Method for predicting the concept labels from the input data 58 | :param data_gen_u: Tensorflow Dataset, returning tuples of points 59 | (x_data, y_data). In the fully-unsupervised case (where y-labels 60 | are not available as well), y_data should be set to -1 61 | :return: numpy array of predicted concept labels, corresponding to the 62 | input points 63 | ''' 64 | pass 65 | 66 | 67 | class SSCC_CME(SSCClassifier): 68 | ''' 69 | CME Concept Predictor Wrapper 70 | ''' 71 | 72 | def __init__(self, **kwargs): 73 | ''' 74 | :param kwargs: Required arguments: 75 | - "base_model" : the underlying trained model to extract concepts from 76 | - "n_concepts" : the number of concepts 77 | ''' 78 | super().__init__(**kwargs) 79 | # Extract model parameter 80 | model = kwargs["base_model"] 81 | # Copy all the other parameters into the params dict. 82 | params = { 83 | key: val for (key, val) in kwargs.items() if key != "base_model" 84 | } 85 | # Create the underlying ItCModel concept labeler 86 | self.concept_predictor = ItCModel(model, **params) 87 | 88 | def fit(self, data_gen_l, data_gen_u): 89 | # Train underlying CME concept predictor 90 | # The default ItCModel accepts tf_datasets without the y-labels during 91 | # training. Thus, we filter out the y-labels from data_gen_l and 92 | # data_gen_u via a mapping 93 | ds_l, ds_u = remove_ds_el(data_gen_l), remove_ds_el(data_gen_u) 94 | self.concept_predictor.train(ds_l, ds_u) 95 | 96 | def predict(self, data_gen_u): 97 | # Run the underlying CME predictor 98 | # The default ItCModel accepts tf_datasets without the y-labels during 99 | # prediction.
Thus, we filter out the y-labels from data_gen_u via a 100 | # mapping 101 | ds_u = remove_ds_el(data_gen_u) 102 | predicted = self.concept_predictor.predict_concepts(ds_u) 103 | return predicted 104 | 105 | 106 | class SSCC_CBM(SSCClassifier): 107 | ''' 108 | CBM Model wrapper 109 | ''' 110 | def __init__(self, **kwargs): 111 | ''' 112 | :param kwargs: Required arguments: 113 | - "base_model" : the underlying trained model to extract concepts from 114 | - "layer_id" : layer of the model to use as bottleneck 115 | - "n_classes" : number of classes 116 | - "n_concept_vals" : number of concept values 117 | - "multi_task" : whether to use the multi-task setup, or sigmoid 118 | setup (Optional) 119 | - "end_to_end" : whether to create an end-to-end keras model 120 | ''' 121 | super().__init__(**kwargs) 122 | # Extract required parameters 123 | model = kwargs["base_model"] 124 | layer_id = kwargs["layer_id"] 125 | n_classes = kwargs["n_classes"] 126 | n_concept_vals = kwargs["n_concept_vals"] 127 | multi_task = kwargs.get("multi_task", False) 128 | end_to_end = kwargs.get("end_to_end", False) 129 | c_epochs = kwargs.get("epochs", 10) 130 | c_extr_path = kwargs["save_path"] 131 | c_optimizer = kwargs.get("c_optimizer", None) 132 | overwrite = kwargs.get("overwrite", True) 133 | lambda_param = kwargs.get("lambda_param", 1.0) 134 | 135 | # Create the underlying CBM concept model 136 | self.n_epochs = c_epochs 137 | self.log_path = kwargs["log_path"] 138 | self.do_logged_fit = kwargs.get("logged_fit", False) 139 | self.n_concepts = len(n_concept_vals) 140 | self.n_train_predictor = kwargs.get("n_train_predictor", 2000) 141 | self.concept_predictor = ConceptBottleneckModel( 142 | model, 143 | layer_id, 144 | n_classes, 145 | n_concept_vals, 146 | c_extr_path, 147 | multi_task=multi_task, 148 | end_to_end=end_to_end, 149 | c_epochs=c_epochs, 150 | c_optimizer=c_optimizer, 151 | overwrite=overwrite, 152 | lambda_param=lambda_param, 153 | ) 154 | 155 | def fit(self, data_gen_l, data_gen_u): 156 | if self.do_logged_fit: 157 | self.logged_fit(data_gen_l, n_epochs=self.n_epochs, frequency=5) 158 | else: 159 | self.concept_predictor.train(data_gen_l) 160 | 161 | def logged_fit(self, data_gen_l, n_epochs=60, frequency=10): 162 | 163 | all_c_accs = [] 164 | pred_overwrite = self.concept_predictor.overwrite 165 | self.concept_predictor.overwrite = True 166 | self.concept_predictor.n_epochs = frequency 167 | 168 | for i in range(0, n_epochs, frequency): 169 | self.concept_predictor.train(data_gen_l) 170 | 171 | # Evaluate the current concept predictions 172 | c_true = np.array([elem[1].numpy() for elem in data_gen_l]) 173 | c_pred = self.predict(data_gen_l) 174 | c_accs = compute_accuracies(c_true, c_pred) 175 | all_c_accs.append(c_accs) 176 | print("Accuracies: ", c_accs) 177 | print("\n" * 3) 178 | print(f"Ran {i + frequency}/{n_epochs} epochs...") 179 | 180 | all_c_accs = np.array(all_c_accs) 181 | fpath = os.path.join(self.log_path, "freq_accuracies.txt") 182 | np.savetxt(fpath, all_c_accs, fmt='%2.4f') 183 | 184 | self.concept_predictor.overwrite = pred_overwrite 185 | self.concept_predictor.n_epochs = self.n_epochs 186 | 187 | def predict(self, data_gen_u): 188 | ds_u = remove_ds_el(data_gen_u) 189 | predicted = self.concept_predictor.predict(ds_u) 190 | return predicted 191 | 192 | 193 | class SSCC_Group_VAEArgmax(SSCClassifier): 194 | ''' 195 | Weakly-supervised VAE Model wrapper 196 | ''' 197 | def __init__(self, **kwargs): 198 | super().__init__(**kwargs) 199 | self.n_changed_c = kwargs["k"] 200 |
self.n_concept_vals = kwargs["n_concept_vals"] 201 | self.n_concepts = len(self.n_concept_vals) 202 | self.batch_size = kwargs.get("batch_size", 256) 203 | self.n_epochs = kwargs.get("epochs", 10) 204 | loss_fn = kwargs["loss_fn"] 205 | encoder_fn = kwargs["encoder_fn"] 206 | decoder_fn = kwargs["decoder_fn"] 207 | latent_dim = kwargs["latent_dim"] 208 | input_shape = kwargs["input_shape"] 209 | self.log_path = kwargs["log_path"] 210 | self.save_path = kwargs["save_path"] 211 | self.do_logged_fit = kwargs.get("logged_fit", False) 212 | self.overwrite = kwargs.get("overwrite", False) 213 | self.n_train_predictor = kwargs.get("n_train_predictor", 2000) 214 | self.optimizer = kwargs.get( 215 | "optimizer", 216 | tf.keras.optimizers.Adam( 217 | beta_1=0.9, 218 | beta_2=0.999, 219 | epsilon=1e-8, 220 | learning_rate=0.001, 221 | ), 222 | ) 223 | self.vae = GroupVAEArgmax( 224 | latent_dim, 225 | encoder_fn, 226 | decoder_fn, 227 | loss_fn, 228 | input_shape, 229 | ) 230 | self.vae.compile(optimizer=self.optimizer) 231 | # Here, we rely on using an ensemble of models (one per concept) for 232 | # predicting the concepts from the latent factors 233 | self.concept_predictors = [ 234 | GradientBoostingClassifier() for _ in range(self.n_concepts) 235 | ] 236 | 237 | # Load concept extractor, if the model already exists 238 | if (not self.overwrite) and (os.path.exists(self.save_path)): 239 | # Need to call the model once before loading weights in the used version of TF 240 | self.vae(np.zeros([1]+input_shape)) 241 | self.vae.load_weights(self.save_path) 242 | 243 | def _build_paired_dataset(self, ds): 244 | 245 | random_state = np.random.RandomState(0) 246 | 247 | # TODO: temporarily rely on non-tf-data implementation 248 | from concepts_xai.datasets.dataset_utils import get_latent_bases, latent_to_index 249 | latent_sizes = np.array(self.n_concept_vals) 250 | latents_bases = get_latent_bases(latent_sizes) 251 | x_data = np.array([elem[0].numpy() for elem in ds]) 252 | c_data = np.array([elem[1].numpy() for elem in ds]) 253 | c_ids = np.array([ 254 | latent_to_index(elem[1].numpy(), latents_bases) for elem in ds 255 | ]) 256 | 257 | x_modified = [] 258 | label_ids = [] 259 | x_filtered_data = [] 260 | 261 | n_pairs, n_non_pairs = 0, 0 262 | 263 | for i in range(c_data.shape[0]): 264 | c = c_data[i] 265 | x = x_data[i] 266 | 267 | if self.n_changed_c == -1: 268 | k_observed = random_state.randint(1, self.n_concepts) 269 | else: 270 | k_observed = self.n_changed_c 271 | 272 | index_list = random_state.choice( 273 | c.shape[0], 274 | random_state.choice([1, k_observed]), 275 | replace=False, 276 | ) 277 | idx = -1 278 | c_m = np.copy(c) 279 | 280 | for index in index_list: 281 | v = random_state.choice(np.arange(self.n_concept_vals[index])) 282 | c_m[index] = v 283 | idx = index 284 | 285 | x_m = np.where(c_ids == latent_to_index(c_m, latents_bases))[0] 286 | 287 | if len(x_m) > 0: 288 | n_pairs += 1 289 | random_state.shuffle(x_m) 290 | x_m = x_m[0] 291 | x_m = x_data[x_m] 292 | else: 293 | n_non_pairs += 1 294 | continue 295 | 296 | x_filtered_data.append(x) 297 | x_modified.append(x_m) 298 | label_ids.append(idx) 299 | 300 | x_filtered_data = np.array(x_filtered_data) 301 | x_modified = np.array(x_modified) 302 | label_ids = np.array(label_ids) 303 | x_pairs = np.concatenate([x_filtered_data, x_modified], axis=1) 304 | 305 | print("Pairs: ", n_pairs) 306 | print("Non-pairs: ", n_non_pairs) 307 | print("% pairs: ", str(n_pairs/(1.
* (n_pairs+n_non_pairs)) * 100.)) 308 | 309 | paired_ds = tf.data.Dataset.from_tensor_slices((x_pairs, label_ids)) 310 | return paired_ds 311 | 312 | def fit(self, data_gen_l, data_gen_u): 313 | if not ((not self.overwrite) and (os.path.exists(self.save_path))): 314 | # Train the VAE 315 | paired_ds_l = self._build_paired_dataset(data_gen_l) 316 | 317 | self.logged_fit( 318 | self.vae, 319 | paired_ds_l, 320 | data_gen_l, 321 | n_epochs=self.n_epochs, 322 | frequency=5, 323 | ) 324 | 325 | def logged_fit( 326 | self, 327 | model, 328 | training_gen, 329 | eval_gen, 330 | n_epochs=60, 331 | frequency=10, 332 | ): 333 | 334 | cp_callback = tf.keras.callbacks.ModelCheckpoint( 335 | filepath=self.save_path, 336 | verbose=True, 337 | save_best_only=True, 338 | monitor='loss', 339 | mode='auto', 340 | save_freq='epoch', 341 | ) 342 | callbacks = [cp_callback] 343 | 344 | all_c_accs = [] 345 | 346 | eval_gen_xs = remove_ds_el(remove_ds_el(eval_gen)) 347 | 348 | if not self.do_logged_fit: 349 | frequency = n_epochs 350 | 351 | for i in range(0, n_epochs, frequency): 352 | model.fit( 353 | training_gen.batch(self.batch_size), 354 | epochs=frequency, 355 | callbacks=callbacks, 356 | ) 357 | 358 | # Train the concept predictor models 359 | c_data = np.array([elem[1].numpy() for elem in eval_gen]) 360 | z_data = model.predict(eval_gen_xs.batch(self.batch_size)) 361 | 362 | c_accs = [] 363 | 364 | for j in range(self.n_concepts): 365 | # Train concept label predictors, and evaluate their predictive 366 | # accuracy 367 | z_train, z_test, c_train, c_test = train_test_split( 368 | z_data, 369 | c_data[:, j], 370 | test_size=0.15, 371 | ) 372 | 373 | if self.n_train_predictor is not None: 374 | z_train, c_train = ( 375 | z_train[:self.n_train_predictor], 376 | c_train[:self.n_train_predictor], 377 | ) 378 | # Note: here we assume sklearn .fit() re-initializes all 379 | # previous params 380 | self.concept_predictors[j].fit(z_train, c_train) 381 | 382 | accuracy = accuracy_score( 383 | c_test, 384 | self.concept_predictors[j].predict(z_test), 385 | ) 386 | print("Accuracy of concept ", str(j), " : ", str(accuracy)) 387 | c_accs.append(accuracy) 388 | 389 | all_c_accs.append(c_accs) 390 | print("\n"*3) 391 | print(f"Ran {i+frequency}/{n_epochs} epochs...") 392 | 393 | all_c_accs = np.array(all_c_accs) 394 | fpath = os.path.join(self.log_path, "freq_accuracies.txt") 395 | np.savetxt(fpath, all_c_accs, fmt='%2.4f') 396 | 397 | def predict(self, data_gen_u): 398 | data_x = remove_ds_el(data_gen_u) 399 | z_data = self.vae.predict(data_x.batch(self.batch_size)) 400 | c_data = np.stack( 401 | [ 402 | self.concept_predictors[i].predict(z_data) 403 | for i in range(self.n_concepts) 404 | ], 405 | axis=-1, 406 | ) 407 | return c_data 408 | 409 | 410 | class SSCC_BetaVAE(SSCClassifier): 411 | ''' 412 | Beta-VAE Model wrapper 413 | ''' 414 | def __init__(self, **kwargs): 415 | super().__init__(**kwargs) 416 | self.n_concept_vals = kwargs["n_concept_vals"] 417 | self.n_concepts = len(self.n_concept_vals) 418 | self.batch_size = kwargs.get("batch_size", 256) 419 | self.n_epochs = kwargs.get("epochs", 10) 420 | beta = kwargs.get("beta", 1) 421 | loss_fn = kwargs["loss_fn"] 422 | encoder_fn = kwargs["encoder_fn"] 423 | decoder_fn = kwargs["decoder_fn"] 424 | latent_dim = kwargs["latent_dim"] 425 | input_shape = kwargs["input_shape"] 426 | self.model_path = kwargs.get("save_path", None) 427 | self.retrain = kwargs.get("overwrite", False) 428 | self.n_train_predictor = kwargs.get("n_train_predictor", 2000) 429 | 
self.optimizer = kwargs.get( 430 | "optimizer", 431 | tf.keras.optimizers.Adam( 432 | beta_1=0.9, 433 | beta_2=0.999, 434 | epsilon=1e-8, 435 | learning_rate=0.001, 436 | ), 437 | ) 438 | self.vae = BetaVAE( 439 | latent_dim, 440 | encoder_fn, 441 | decoder_fn, 442 | loss_fn, 443 | input_shape, 444 | beta, 445 | ) 446 | self.vae.compile(optimizer=self.optimizer) 447 | # Here, we rely on an ensemble of models (one per concept) for 448 | # predicting concepts from latent factors 449 | self.concept_predictors = [ 450 | GradientBoostingClassifier() for _ in range(self.n_concepts) 451 | ] 452 | 453 | # Load concept extractor, if the model already exists 454 | if (self.model_path is not None) and (os.path.exists(self.model_path)): 455 | # Need to call the model once before loading weights in the used version of TF 456 | self.vae(np.zeros([1]+input_shape)) 457 | self.vae.load_weights(self.model_path) 458 | 459 | def fit(self, data_gen_l, data_gen_u): 460 | if not ( 461 | (self.model_path is not None) and 462 | (os.path.exists(self.model_path)) and 463 | (not self.retrain) 464 | ): 465 | 466 | cp_callback = tf.keras.callbacks.ModelCheckpoint( 467 | filepath=self.model_path, 468 | verbose=True, 469 | save_best_only=True, 470 | monitor='loss', 471 | mode='auto', 472 | save_freq='epoch', 473 | ) 474 | callbacks = [cp_callback] 475 | self.vae.fit( 476 | data_gen_l.batch(self.batch_size), 477 | epochs=self.n_epochs, 478 | callbacks=callbacks, 479 | ) 480 | 481 | # Train the concept predictor models 482 | c_data = np.array([elem[1].numpy() for elem in data_gen_l]) 483 | z_data = self.vae.predict( 484 | remove_ds_el(remove_ds_el(data_gen_l)).batch(self.batch_size) 485 | ) 486 | 487 | for i in range(self.n_concepts): 488 | # (train_test_split and accuracy_score are already imported at the 489 | # top of this module) 490 | z_train, z_test, c_train, c_test = train_test_split( 491 | z_data, 492 | c_data[:, i], 493 | test_size=0.15, 494 | ) 495 | 496 | if self.n_train_predictor is not None: 497 | z_train, c_train = ( 498 | z_train[:self.n_train_predictor], 499 | c_train[:self.n_train_predictor], 500 | ) 501 | 502 | self.concept_predictors[i].fit(z_train, c_train) 503 | print( 504 | "Accuracy of concept ", 505 | i, 506 | " : ", 507 | accuracy_score( 508 | c_test, 509 | self.concept_predictors[i].predict(z_test), 510 | ), 511 | ) 512 | 513 | def predict(self, data_gen_u): 514 | data_x = remove_ds_el(data_gen_u) 515 | z_data = self.vae.predict(data_x.batch(self.batch_size)) 516 | return np.stack( 517 | [ 518 | self.concept_predictors[i].predict(z_data) 519 | for i in range(self.n_concepts) 520 | ], 521 | axis=-1, 522 | ) 523 | 524 | 525 | 526 | def remove_ds_el(data_gen): 527 | ''' 528 | Utility function for removing the last element from each tuple produced by a generator.
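For example, mapping a labelled dataset that yields (x_data, c_data, y_data) triples produces a dataset yielding only (x_data, c_data) pairs; this is how the wrappers above strip away the y-labels before concept training.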
529 | :param data_gen: tf.data generator 530 | :return: data generator without the last tuple element, for every item in 531 | data_gen 532 | ''' 533 | return data_gen.map( 534 | lambda *args: tuple([args[i] for i in range(len(args)-1)]) 535 | ) 536 | 537 | -------------------------------------------------------------------------------- /concepts_xai/methods/VAE/baseVAE.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class BaseVAE(tf.keras.Model): 5 | """Abstract base class of a VAE model.""" 6 | 7 | def __init__(self, encoder, decoder, loss_fn, **kwargs): 8 | super(BaseVAE, self).__init__(**kwargs) 9 | self.encoder = encoder 10 | self.decoder = decoder 11 | self.loss_fn = loss_fn 12 | self.metric_names = ["loss", "reconstruction_loss", "elbo"] 13 | self.metrics_dict = { 14 | name: tf.keras.metrics.Mean(name=name) 15 | for name in self.metric_names 16 | } 17 | 18 | @property 19 | def metrics(self): 20 | return [self.metrics_dict[name] for name in self.metric_names] 21 | 22 | def encode(self, x): 23 | return self.encoder(x) 24 | 25 | def decode(self, z): 26 | return self.decoder(z) 27 | 28 | def sample_from_latent_distribution(self, z_mean, z_logvar): 29 | """ 30 | Samples from the Gaussian distribution defined by z_mean and z_logvar. 31 | """ 32 | return tf.add( 33 | z_mean, 34 | tf.exp(z_logvar / 2) * tf.random.normal(tf.shape(z_mean), 0, 1), 35 | name="sampled_latent_variable" 36 | ) 37 | 38 | def generate_random_sample(self, z=None, num_samples=1, seed=None): 39 | (_, latent_size), (_, _) = self.encoder.output_shape 40 | if z is None: 41 | z = tf.random.normal( 42 | shape=[num_samples, latent_size], 43 | mean=0.0, 44 | stddev=1.0, 45 | seed=seed, 46 | ) 47 | return self.decoder(z) 48 | 49 | def _compute_losses(self, x, is_training=False): 50 | z_mean, z_logvar = self.encoder(x, training=is_training) 51 | z_sampled = self.sample_from_latent_distribution(z_mean, z_logvar) 52 | reconstructions = self.decoder(z_sampled, training=is_training) 53 | per_sample_loss = self.loss_fn(x, reconstructions) 54 | reconstruction_loss = tf.reduce_mean(per_sample_loss) 55 | kl_loss = compute_gaussian_kl(z_mean, z_logvar) 56 | regularizer = self.regularizer(kl_loss, z_mean, z_logvar, z_sampled) 57 | loss = tf.add(reconstruction_loss, regularizer, name="loss") 58 | elbo = tf.add(reconstruction_loss, kl_loss, name="elbo") 59 | 60 | return loss, reconstruction_loss, elbo 61 | 62 | def update_metrics(self, losses): 63 | for (loss_name, loss) in losses: 64 | self.metrics_dict[loss_name].update_state(loss) 65 | 66 | def train_step(self, inputs): 67 | """Executes one training step and returns the loss. 68 | This function computes the loss and gradients, and uses the latter to 69 | update the model's parameters. 70 | """ 71 | with tf.GradientTape() as tape: 72 | loss, reconstruction_loss, elbo = self._compute_losses( 73 | inputs, 74 | is_training=True, 75 | ) 76 | 77 | gradients = tape.gradient(loss, self.trainable_variables) 78 | self.optimizer.apply_gradients( 79 | zip(gradients, self.trainable_variables) 80 | ) 81 | self.update_metrics([ 82 | ("loss", loss), 83 | ("reconstruction_loss", reconstruction_loss), 84 | ("elbo", elbo), 85 | ]) 86 | 87 | return { 88 | name: self.metrics_dict[name].result() 89 | for name in self.metric_names 90 | } 91 | 92 | def test_step(self, inputs): 93 | """Executes one test step and returns the loss. 94 | This function computes the loss, without updating the model parameters. 
95 | """ 96 | loss, reconstruction_loss, elbo = self._compute_losses( 97 | inputs, 98 | is_training=False, 99 | ) 100 | self.update_metrics([ 101 | ("loss", loss), 102 | ("reconstruction_loss", reconstruction_loss), 103 | ("elbo", elbo), 104 | ]) 105 | 106 | return { 107 | name: self.metrics_dict[name].result() 108 | for name in self.metric_names 109 | } 110 | 111 | def call(self, x, **kwargs): 112 | ''' 113 | Here, we assume that calling the VAE model returns the encoder output, 114 | or the decoder output 115 | Default behaviour is to return the encoder output. 116 | To return the decoder output, pass the "decode" argument as True in the 117 | kwargs dict 118 | ''' 119 | 120 | decode = kwargs.get("decode", False) 121 | z_mean, z_logvar = self.encoder(x, training=False) 122 | z_sampled = self.sample_from_latent_distribution(z_mean, z_logvar) 123 | 124 | if decode: 125 | return self.decoder(z_sampled, training=False) 126 | return z_sampled 127 | 128 | 129 | def compute_gaussian_kl(z_mean, z_logvar): 130 | """Compute KL divergence between input Gaussian and Standard Normal.""" 131 | return tf.reduce_mean( 132 | 0.5 * tf.reduce_sum( 133 | tf.square(z_mean) + tf.exp(z_logvar) - z_logvar - 1, [1] 134 | ), 135 | name="kl_loss", 136 | ) 137 | 138 | 139 | -------------------------------------------------------------------------------- /concepts_xai/methods/VAE/betaVAE.py: -------------------------------------------------------------------------------- 1 | from concepts_xai.methods.VAE.baseVAE import BaseVAE 2 | 3 | 4 | class BetaVAE(BaseVAE): 5 | """BetaVAE model.""" 6 | 7 | def __init__(self, encoder, decoder, loss_fn, beta=1, **kwargs): 8 | """Creates a beta-VAE model. 9 | 10 | Implementing Eq. 4 of "beta-VAE: Learning Basic Visual Concepts with a 11 | Constrained Variational Framework" 12 | (https://openreview.net/forum?id=Sy2fzU9gl). 13 | 14 | :param beta: Hyperparameter for the regularizer. 15 | """ 16 | super(BetaVAE, self).__init__( 17 | encoder=encoder, 18 | decoder=decoder, 19 | loss_fn=loss_fn, 20 | **kwargs 21 | ) 22 | self.beta = beta 23 | 24 | def regularizer(self, kl_loss, z_mean, z_logvar, z_sampled): 25 | return self.beta * kl_loss 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /concepts_xai/methods/VAE/losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow_probability as tfp 3 | import tensorflow as tf 4 | 5 | 6 | def bernoulli_loss( 7 | true_images, 8 | reconstructed_images, 9 | activation, 10 | subtract_true_image_entropy=False, 11 | ): 12 | 13 | """Computes the Bernoulli loss.""" 14 | flattened_dim = np.prod(true_images.get_shape().as_list()[1:]) 15 | reconstructed_images = tf.reshape( 16 | reconstructed_images, 17 | shape=[-1, flattened_dim] 18 | ) 19 | true_images = tf.reshape(true_images, shape=[-1, flattened_dim]) 20 | 21 | # Because true images are not binary, the lower bound in the xent is not 22 | # zero: the lower bound in the xent is the entropy of the true images. 
23 | if subtract_true_image_entropy: 24 | dist = tfp.distributions.Bernoulli( 25 | probs=tf.clip_by_value(true_images, 1e-6, 1 - 1e-6) 26 | ) 27 | loss_lower_bound = tf.reduce_sum(dist.entropy(), axis=1) 28 | else: 29 | loss_lower_bound = 0 30 | 31 | if activation == "logits": 32 | loss = tf.reduce_sum( 33 | tf.nn.sigmoid_cross_entropy_with_logits( 34 | logits=reconstructed_images, 35 | labels=true_images 36 | ), 37 | axis=1, 38 | ) 39 | elif activation == "tanh": 40 | reconstructed_images = tf.clip_by_value( 41 | tf.nn.tanh(reconstructed_images) / 2 + 0.5, 1e-6, 1 - 1e-6 42 | ) 43 | loss = -tf.reduce_sum( 44 | ( 45 | true_images * tf.math.log(reconstructed_images) + 46 | (1 - true_images) * tf.math.log(1 - reconstructed_images) 47 | ), 48 | axis=1, 49 | ) 50 | else: 51 | raise NotImplementedError("Activation not supported.") 52 | 53 | return loss - loss_lower_bound 54 | 55 | 56 | def l2_loss(true_images, reconstructed_images, activation): 57 | """Computes the l2 loss.""" 58 | if activation == "logits": 59 | return tf.reduce_sum( 60 | tf.square(true_images - tf.nn.sigmoid(reconstructed_images)), 61 | [1, 2, 3] 62 | ) 63 | elif activation == "tanh": 64 | reconstructed_images = tf.nn.tanh(reconstructed_images) / 2 + 0.5 65 | return tf.reduce_sum( 66 | tf.square(true_images - reconstructed_images), 67 | [1, 2, 3], 68 | ) 69 | else: 70 | raise NotImplementedError("Activation not supported.") 71 | 72 | 73 | def bernoulli_fn_wrapper( 74 | activation="logits", 75 | subtract_true_image_entropy=False, 76 | ): 77 | 78 | def loss_fn(true_images, reconstructed_images): 79 | return bernoulli_loss( 80 | true_images, 81 | reconstructed_images, 82 | activation, 83 | subtract_true_image_entropy, 84 | ) 85 | return loss_fn 86 | 87 | 88 | def l2_loss_wrapper(activation="logits"): 89 | def loss_fn(true_images, reconstructed_images): 90 | return l2_loss(true_images, reconstructed_images, activation) 91 | 92 | return loss_fn 93 | -------------------------------------------------------------------------------- /concepts_xai/methods/VAE/weak_vae.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from concepts_xai.methods.VAE.baseVAE import BaseVAE, compute_gaussian_kl 3 | 4 | 5 | class GroupVAEBase(BaseVAE): 6 | """Beta-VAE with averaging from https://arxiv.org/abs/1809.02383.""" 7 | 8 | def __init__( 9 | self, 10 | encoder, 11 | decoder, 12 | loss_fn, 13 | beta=1, 14 | **kwargs, 15 | ): 16 | """ 17 | Creates a beta-VAE model with additional averaging for weak 18 | supervision. 19 | Based on https://arxiv.org/abs/1809.02383. 20 | 21 | :param beta: Hyperparameter for KL divergence. 22 | """ 23 | super(GroupVAEBase, self).__init__( 24 | encoder=encoder, 25 | decoder=decoder, 26 | loss_fn=loss_fn, 27 | **kwargs, 28 | ) 29 | self.beta = beta 30 | self.metric_names.append("regularizer") 31 | self.metrics_dict["regularizer"] = tf.keras.metrics.Mean( 32 | name="regularizer" 33 | ) 34 | 35 | def regularizer(self, kl_loss, z_mean, z_logvar, z_sampled): 36 | return self.beta * kl_loss 37 | 38 | def _split_sample_pairs(self, x): 39 | ''' 40 | Note: each point contains two frames stacked along the first dim and 41 | integer labels. 
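For example, assuming dSprites-style 64x64x1 observations, a weakly-supervised pair is passed in as a single (batch, 128, 64, 1) tensor, which this helper splits back into two (batch, 64, 64, 1) halves.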
42 | ''' 43 | if isinstance(x, tuple): 44 | assert len(x) == 2, f'Expected samples to come as pairs f{x}' 45 | return (x[0], x[1]) 46 | data_shape = x.get_shape().as_list()[1:] 47 | assert data_shape[0] % 2 == 0, ( 48 | "1st dimension of concatenated pairs assumed to be even" 49 | ) 50 | data_shape[0] = data_shape[0] // 2 51 | return x[:, :data_shape[0], ...], x[:, data_shape[0]:, ...] 52 | 53 | def _compute_losses_weak(self, x_1, x_2, is_training, labels=None): 54 | 55 | z_mean, z_logvar = self.encoder(x_1, training=is_training) 56 | z_mean_2, z_logvar_2 = self.encoder(x_2, training=is_training) 57 | if labels is not None: 58 | labels = tf.squeeze( 59 | tf.one_hot(labels, z_mean.get_shape().as_list()[1]) 60 | ) 61 | kl_per_point = compute_kl(z_mean, z_mean_2, z_logvar, z_logvar_2) 62 | 63 | new_mean = 0.5 * z_mean + 0.5 * z_mean_2 64 | var_1 = tf.exp(z_logvar) 65 | var_2 = tf.exp(z_logvar_2) 66 | new_log_var = tf.math.log(0.5*var_1 + 0.5*var_2) 67 | 68 | mean_sample_1, log_var_sample_1 = self.aggregate( 69 | z_mean, 70 | z_logvar, 71 | new_mean, 72 | new_log_var, 73 | labels, 74 | kl_per_point, 75 | ) 76 | mean_sample_2, log_var_sample_2 = self.aggregate( 77 | z_mean_2, 78 | z_logvar_2, 79 | new_mean, 80 | new_log_var, 81 | labels, 82 | kl_per_point, 83 | ) 84 | 85 | z_sampled_1 = self.sample_from_latent_distribution( 86 | mean_sample_1, 87 | log_var_sample_1, 88 | ) 89 | z_sampled_2 = self.sample_from_latent_distribution( 90 | mean_sample_2, 91 | log_var_sample_2, 92 | ) 93 | 94 | reconstructions_1 = self.decoder( 95 | z_sampled_1, 96 | training=is_training 97 | ) 98 | reconstructions_2 = self.decoder( 99 | z_sampled_2, 100 | training=is_training, 101 | ) 102 | 103 | per_sample_loss_1 = self.loss_fn(x_1, reconstructions_1) 104 | per_sample_loss_2 = self.loss_fn(x_2, reconstructions_2) 105 | reconstruction_loss_1 = tf.reduce_mean(per_sample_loss_1) 106 | reconstruction_loss_2 = tf.reduce_mean(per_sample_loss_2) 107 | reconstruction_loss = ( 108 | 0.5 * reconstruction_loss_1 + 0.5 * reconstruction_loss_2 109 | ) 110 | 111 | kl_loss_1 = compute_gaussian_kl(mean_sample_1, log_var_sample_1) 112 | kl_loss_2 = compute_gaussian_kl(mean_sample_2, log_var_sample_2) 113 | kl_loss = 0.5 * kl_loss_1 + 0.5 * kl_loss_2 114 | 115 | regularizer = self.regularizer(kl_loss, None, None, None) 116 | 117 | loss = tf.add(reconstruction_loss, regularizer, name="loss") 118 | elbo = tf.add(reconstruction_loss, kl_loss, name="elbo") 119 | 120 | return reconstruction_loss, regularizer, loss, elbo 121 | 122 | def _split_labels(self, inputs): 123 | return inputs, None 124 | 125 | def train_step(self, inputs): 126 | x, labels = self._split_labels(inputs) 127 | x_1, x_2 = self._split_sample_pairs(x) 128 | 129 | with tf.GradientTape() as tape: 130 | rec_loss, regularizer, loss, elbo = self._compute_losses_weak( 131 | x_1=x_1, 132 | x_2=x_2, 133 | is_training=True, 134 | labels=labels, 135 | ) 136 | 137 | gradients = tape.gradient(loss, self.trainable_variables) 138 | self.optimizer.apply_gradients( 139 | zip(gradients, self.trainable_variables) 140 | ) 141 | self.update_metrics([ 142 | ("loss", loss), 143 | ("reconstruction_loss", rec_loss), 144 | ("elbo", elbo), 145 | ("regularizer",regularizer), 146 | ]) 147 | 148 | return { 149 | name: self.metrics_dict[name].result() 150 | for name in self.metric_names 151 | } 152 | 153 | def test_step(self, inputs): 154 | x, labels = self._split_labels(inputs) 155 | x_1, x_2 = self._split_sample_pairs(x) 156 | rec_loss, regularizer, loss, elbo = self._compute_losses_weak( 157 | 
x_1=x_1, 158 | x_2=x_2, 159 | is_training=False, 160 | labels=labels, 161 | ) 162 | self.update_metrics([ 163 | ("loss", loss), 164 | ("reconstruction_loss", rec_loss), 165 | ("elbo", elbo), 166 | ("regularizer", regularizer) 167 | ]) 168 | 169 | return { 170 | name: self.metrics_dict[name].result() 171 | for name in self.metric_names 172 | } 173 | 174 | 175 | class GroupVAEArgmax(GroupVAEBase): 176 | """Class implementing the group-VAE without any label.""" 177 | 178 | def aggregate( 179 | self, 180 | z_mean, 181 | z_logvar, 182 | new_mean, 183 | new_log_var, 184 | labels, 185 | kl_per_point, 186 | ): 187 | return aggregate_argmax( 188 | z_mean, 189 | z_logvar, 190 | new_mean, 191 | new_log_var, 192 | kl_per_point, 193 | ) 194 | 195 | 196 | class GroupVAELabels(GroupVAEBase): 197 | """ 198 | Class implementing the group-VAE with labels on which factor is shared. 199 | """ 200 | 201 | def _split_labels(self, inputs): 202 | return inputs 203 | 204 | def aggregate( 205 | self, 206 | z_mean, 207 | z_logvar, 208 | new_mean, 209 | new_log_var, 210 | labels, 211 | kl_per_point, 212 | ): 213 | return aggregate_labels( 214 | z_mean, 215 | z_logvar, 216 | new_mean, 217 | new_log_var, 218 | labels, 219 | kl_per_point, 220 | ) 221 | 222 | 223 | class MLVae(GroupVAEBase): 224 | """Beta-VAE with averaging from https://arxiv.org/abs/1705.08841.""" 225 | 226 | def _compute_losses_weak(self, x_1, x_2, is_training, labels=None): 227 | z_mean, z_logvar = self.encoder(x_1, training=is_training) 228 | z_mean_2, z_logvar_2 = self.encoder(x_2, training=is_training) 229 | if labels is not None: 230 | labels = tf.squeeze( 231 | tf.one_hot(labels, z_mean.get_shape().as_list()[1]) 232 | ) 233 | kl_per_point = compute_kl(z_mean, z_mean_2, z_logvar, z_logvar_2) 234 | 235 | var_1 = tf.exp(z_logvar) 236 | var_2 = tf.exp(z_logvar_2) 237 | new_var = 2 * var_1 * var_2 / (var_1 + var_2) 238 | new_mean = ((z_mean / var_1) + (z_mean_2 / var_2)) * new_var * 0.5 239 | 240 | new_log_var = tf.math.log(new_var) 241 | 242 | mean_sample_1, log_var_sample_1 = self.aggregate( 243 | z_mean, 244 | z_logvar, 245 | new_mean, 246 | new_log_var, 247 | labels, 248 | kl_per_point, 249 | ) 250 | mean_sample_2, log_var_sample_2 = self.aggregate( 251 | z_mean_2, 252 | z_logvar_2, 253 | new_mean, 254 | new_log_var, 255 | labels, 256 | kl_per_point, 257 | ) 258 | 259 | z_sampled_1 = self.sample_from_latent_distribution( 260 | mean_sample_1, 261 | log_var_sample_1, 262 | ) 263 | z_sampled_2 = self.sample_from_latent_distribution( 264 | mean_sample_2, 265 | log_var_sample_2, 266 | ) 267 | 268 | reconstructions_1 = self.decoder( 269 | z_sampled_1, 270 | training=is_training 271 | ) 272 | reconstructions_2 = self.decoder( 273 | z_sampled_2, 274 | training=is_training, 275 | ) 276 | 277 | per_sample_loss_1 = self.loss_fn(x_1, reconstructions_1) 278 | per_sample_loss_2 = self.loss_fn(x_2, reconstructions_2) 279 | reconstruction_loss_1 = tf.reduce_mean(per_sample_loss_1) 280 | reconstruction_loss_2 = tf.reduce_mean(per_sample_loss_2) 281 | reconstruction_loss = ( 282 | 0.5 * reconstruction_loss_1 + 0.5 * reconstruction_loss_2 283 | ) 284 | 285 | kl_loss_1 = compute_gaussian_kl(mean_sample_1, log_var_sample_1) 286 | kl_loss_2 = compute_gaussian_kl(mean_sample_2, log_var_sample_2) 287 | kl_loss = 0.5 * kl_loss_1 + 0.5 * kl_loss_2 288 | 289 | regularizer = self.regularizer(kl_loss, None, None, None) 290 | 291 | loss = tf.add(reconstruction_loss, regularizer, name="loss") 292 | elbo = tf.add(reconstruction_loss, kl_loss, name="elbo") 293 | 294 | return 
reconstruction_loss, regularizer, loss, elbo 295 | 296 | 297 | class MLVaeLabels(MLVae): 298 | """Class implementing the ML-VAE with labels on which factor is shared.""" 299 | 300 | def _split_labels(self, inputs): 301 | return inputs 302 | 303 | def aggregate( 304 | self, 305 | z_mean, 306 | z_logvar, 307 | new_mean, 308 | new_log_var, 309 | labels, 310 | kl_per_point, 311 | ): 312 | return aggregate_labels( 313 | z_mean, 314 | z_logvar, 315 | new_mean, 316 | new_log_var, 317 | labels, 318 | kl_per_point, 319 | ) 320 | 321 | 322 | class MLVaeArgmax(MLVae): 323 | """Class implementing the ML-VAE without any label.""" 324 | 325 | def aggregate( 326 | self, 327 | z_mean, 328 | z_logvar, 329 | new_mean, 330 | new_log_var, 331 | labels, 332 | kl_per_point, 333 | ): 334 | return aggregate_argmax( 335 | z_mean, 336 | z_logvar, 337 | new_mean, 338 | new_log_var, 339 | kl_per_point, 340 | ) 341 | 342 | 343 | def aggregate_labels( 344 | z_mean, 345 | z_logvar, 346 | new_mean, 347 | new_log_var, 348 | labels, 349 | kl_per_point, 350 | ): 351 | """Use labels to aggregate. 352 | 353 | Labels contains a one-hot encoding with a single 1 of a factor shared. We 354 | enforce which dimension of the latent code learn which factor (dimension 1 355 | learns factor 1) and we enforce that each factor of variation is encoded 356 | in a single dimension. 357 | 358 | Args: 359 | z_mean: Mean of the encoder distribution for the original image. 360 | z_logvar: Logvar of the encoder distribution for the original image. 361 | new_mean: Average mean of the encoder distribution of the pair of images. 362 | new_log_var: Average logvar of the encoder distribution of the pair of 363 | images. 364 | labels: One-hot-encoding with the position of the dimension that should not 365 | be shared. 366 | kl_per_point: Distance between the two encoder distributions (unused). 367 | 368 | Returns: 369 | Mean and logvariance for the new observation. 370 | """ 371 | z_mean_averaged = tf.where( 372 | tf.math.equal( 373 | labels, 374 | tf.expand_dims(tf.reduce_max(labels, axis=1), 1) 375 | ), 376 | z_mean, 377 | new_mean, 378 | ) 379 | z_logvar_averaged = tf.where( 380 | tf.math.equal( 381 | labels, 382 | tf.expand_dims(tf.reduce_max(labels, axis=1), 1) 383 | ), 384 | z_logvar, 385 | new_log_var, 386 | ) 387 | return z_mean_averaged, z_logvar_averaged 388 | 389 | 390 | def aggregate_argmax( 391 | z_mean, 392 | z_logvar, 393 | new_mean, 394 | new_log_var, 395 | kl_per_point, 396 | ): 397 | """Argmax aggregation with adaptive k. 398 | 399 | The bottom k dimensions in terms of distance are not averaged. K is 400 | estimated adaptively by binning the distance into two bins of equal width. 401 | 402 | Args: 403 | z_mean: Mean of the encoder distribution for the original image. 404 | z_logvar: Logvar of the encoder distribution for the original image. 405 | new_mean: Average mean of the encoder distribution of the pair of images. 406 | new_log_var: Average logvar of the encoder distribution of the pair of 407 | images. 408 | kl_per_point: Distance between the two encoder distributions. 409 | 410 | Returns: 411 | Mean and logvariance for the new observation. 
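For instance, with kl_per_point = [0.05, 0.1, 0.8, 0.9], binning into two equal-width bins puts the last two dimensions in the upper bin: those dimensions (the ones inferred to have changed) keep their original posterior parameters, while the remaining dimensions are replaced by the averaged parameters.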
390 | def aggregate_argmax(
391 |     z_mean,
392 |     z_logvar,
393 |     new_mean,
394 |     new_log_var,
395 |     kl_per_point,
396 | ):
397 |     """Argmax aggregation with adaptive k.
398 | 
399 |     The bottom k dimensions in terms of distance are not averaged; k is
400 |     estimated adaptively by binning the distances into two bins of equal width.
401 | 
402 |     Args:
403 |         z_mean: Mean of the encoder distribution for the original image.
404 |         z_logvar: Logvar of the encoder distribution for the original image.
405 |         new_mean: Average mean of the encoder distribution of the pair of images.
406 |         new_log_var: Average logvar of the encoder distribution of the pair of
407 |             images.
408 |         kl_per_point: Distance between the two encoder distributions.
409 | 
410 |     Returns:
411 |         Mean and logvariance for the new observation.
412 |     """
413 |     mask = tf.equal(tf.map_fn(discretize_in_bins, kl_per_point, tf.int32), 1)  # upper-bin (large-KL) dims keep their original stats
414 |     z_mean_averaged = tf.where(mask, z_mean, new_mean)
415 |     z_logvar_averaged = tf.where(mask, z_logvar, new_log_var)
416 |     return z_mean_averaged, z_logvar_averaged
417 | 
418 | 
419 | def discretize_in_bins(x):
420 |     """Discretize a vector in two bins."""
421 |     return tf.histogram_fixed_width_bins(
422 |         x,
423 |         [tf.reduce_min(x), tf.reduce_max(x)],
424 |         nbins=2,
425 |     )
426 | 
427 | 
428 | def compute_kl(z_1, z_2, logvar_1, logvar_2):
429 |     """Twice the per-dim KL divergence KL(N(z_1, var_1) || N(z_2, var_2))."""
430 |     var_1 = tf.exp(logvar_1)
431 |     var_2 = tf.exp(logvar_2)
432 |     return var_1/var_2 + tf.square(z_2-z_1)/var_2 - 1 + logvar_2 - logvar_1
--------------------------------------------------------------------------------
/concepts_xai/methods/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dmitrykazhdan/concept-based-xai/a86c6abd609ec2ef0cf56cb556e3f834a0962bdc/concepts_xai/methods/__init__.py
--------------------------------------------------------------------------------
/concepts_xai/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dmitrykazhdan/concept-based-xai/a86c6abd609ec2ef0cf56cb556e3f834a0962bdc/concepts_xai/utils/__init__.py
--------------------------------------------------------------------------------
/concepts_xai/utils/architectures.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras.layers import Dense, Dropout
3 | 
4 | 
5 | def small_cnn(input_shape, num_classes=10):
6 |     '''
7 |     CNN architecture used in CME (https://arxiv.org/abs/2010.13233) for the
8 |     dSprites task
9 |     :param input_shape: input sample shape
10 |     :param num_classes: number of output classes
11 |     :return: compiled keras model
12 |     '''
13 | 
14 |     inputs = tf.keras.Input(shape=input_shape)
15 |     x = get_cnn_body_layers(inputs)
16 |     outputs = Dense(num_classes, activation='softmax')(x)
17 | 
18 |     model = tf.keras.Model(inputs=inputs, outputs=outputs)
19 | 
20 |     optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3, amsgrad=True)
21 | 
22 |     model.compile(optimizer=optimizer,
23 |                   loss='sparse_categorical_crossentropy',
24 |                   metrics=['acc'])
25 | 
26 |     return model
27 | 
28 | 
29 | def get_cnn_body_layers(inputs):
30 |     x = tf.keras.layers.Conv2D(
31 |         filters=32,
32 |         kernel_size=4,
33 |         strides=2,
34 |         activation='relu',
35 |         padding="same",
36 |         name="e1",
37 |     )(inputs)
38 |     x = tf.keras.layers.Conv2D(
39 |         filters=32,
40 |         kernel_size=4,
41 |         strides=2,
42 |         activation='relu',
43 |         padding="same",
44 |         name="e2",
45 |     )(x)
46 |     x = tf.keras.layers.Conv2D(
47 |         filters=64,
48 |         kernel_size=2,
49 |         strides=2,
50 |         activation='relu',
51 |         padding="same",
52 |         name="e3",
53 |     )(x)
54 |     x = tf.keras.layers.Conv2D(
55 |         filters=64,
56 |         kernel_size=2,
57 |         strides=2,
58 |         activation='relu',
59 |         padding="same",
60 |         name="e4",
61 |     )(x)
62 |     x = tf.keras.layers.Flatten()(x)
63 |     x = Dense(128, activation='relu')(x)
64 |     x = Dropout(0.4)(x)
65 |     x = Dense(64, activation='relu')(x)
66 |     x = Dropout(0.2)(x)
67 |     return x
68 | 
69 | 
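# --- Usage sketch (illustrative; the 64x64x1 input shape and 3 classes are
# assumptions in line with the dSprites setup referenced above):
def _demo_small_cnn():
    model = small_cnn(input_shape=(64, 64, 1), num_classes=3)
    model.summary()
    return model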
70 | def multi_task_cnn(input_shape, num_concept_values, concept_names=[]):
71 |     '''
72 |     :param input_shape: input sample shape
73 |     :param num_concept_values: list containing the number of values for each
74 |         concept (e.g. [3, 6, 40, 32, 32] for dSprites)
75 |     :param concept_names: optional list of names used to label the output heads
76 |     :return: CNN model with shared convolutional layers and one softmax head
77 |         per concept (e.g. shape, scale, rotation, x and y positions)
78 |     '''
79 | 
80 |     if concept_names != []:
81 |         assert len(num_concept_values) == len(concept_names), \
82 |             "The number of concepts is different for the values and the names"
83 |     inputs = tf.keras.Input(shape=input_shape)
84 | 
85 |     h = get_cnn_body_layers(inputs)
86 | 
87 |     output_layers = []
88 |     for i, c in enumerate(num_concept_values):
89 |         l = tf.keras.layers.Dense(
90 |             num_concept_values[i],
91 |             activation="softmax",
92 |             name="l" + str(i) if concept_names == [] else concept_names[i]
93 |         )(h)
94 |         output_layers.append(l)
95 | 
96 |     model = tf.keras.Model(inputs=inputs, outputs=output_layers)
97 | 
98 |     optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3, amsgrad=True)
99 | 
100 |     model.compile(
101 |         optimizer=optimizer,
102 |         loss=[
103 |             tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
104 |             for c in num_concept_values
105 |         ],
106 |         metrics=['acc'],
107 |     )
108 | 
109 |     return model
110 | 
111 | 
112 | def sigmoid_cnn(input_shape, num_concept_values, concept_names=[]):
113 |     '''
114 |     :param input_shape: input sample shape
115 |     :param num_concept_values: list containing the number of values for each
116 |         concept (e.g. [3, 6, 40, 32, 32] for dSprites)
117 |     :param concept_names: optional list of names used to label the output heads
118 |     :return: CNN model with shared convolutional layers and one
119 |         sigmoid-activated head per concept
120 |     '''
121 | 
122 |     if concept_names != []:
123 |         assert len(num_concept_values) == len(concept_names), \
124 |             "The number of concepts is different for the values and the names"
125 | 
126 |     inputs = tf.keras.Input(shape=input_shape)
127 | 
128 |     h = get_cnn_body_layers(inputs)
129 | 
130 |     output_layers = []
131 |     for i, c in enumerate(num_concept_values):
132 |         l = tf.keras.layers.Dense(
133 |             num_concept_values[i],
134 |             activation="sigmoid",  # per-value sigmoid heads (cf. multi_task_cnn's softmax)
135 |             name="l" + str(i) if concept_names == [] else concept_names[i],
136 |         )(h)
137 |         output_layers.append(l)
138 | 
139 |     model = tf.keras.Model(inputs=inputs, outputs=output_layers)
140 | 
141 |     optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3, amsgrad=True)
142 | 
143 |     model.compile(
144 |         optimizer=optimizer,
145 |         loss=[
146 |             tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
147 |             for c in num_concept_values
148 |         ],
149 |         metrics=['acc'],
150 |     )
151 |     return model
152 | 
153 | 
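# --- Usage sketch (illustrative; the concept cardinalities follow the dSprites
# example in the docstring above, and the concept names are assumptions):
def _demo_multi_task_cnn():
    model = multi_task_cnn(
        input_shape=(64, 64, 1),
        num_concept_values=[3, 6, 40, 32, 32],
        concept_names=["shape", "scale", "rotation", "pos_x", "pos_y"],
    )
    return model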
154 | def conv_encoder(input_shape, num_latent):
155 |     """
156 |     CNN encoder architecture used in the 'Challenging Common Assumptions in the
157 |     Unsupervised Learning of Disentangled Representations' paper
158 |     (https://arxiv.org/abs/1811.12359)
159 | 
160 |     Note: model is uncompiled
161 |     """
162 | 
163 |     inputs = tf.keras.Input(shape=input_shape)
164 | 
165 |     e1 = tf.keras.layers.Conv2D(
166 |         filters=32,
167 |         kernel_size=4,
168 |         strides=2,
169 |         activation='relu',
170 |         padding="same",
171 |         name="e1",
172 |     )(inputs)
173 | 
174 |     e2 = tf.keras.layers.Conv2D(
175 |         filters=32,
176 |         kernel_size=4,
177 |         strides=2,
178 |         activation='relu',
179 |         padding="same",
180 |         name="e2",
181 |     )(e1)
182 | 
183 |     e3 = tf.keras.layers.Conv2D(
184 |         filters=64,
185 |         kernel_size=2,
186 |         strides=2,
187 |         activation='relu',
188 |         padding="same",
189 |         name="e3",
190 |     )(e2)
191 | 
192 |     e4 = tf.keras.layers.Conv2D(
193 |         filters=64,
194 |         kernel_size=2,
195 |         strides=2,
196 |         activation='relu',
197 |         padding="same",
198 |         name="e4",
199 |     )(e3)
200 | 
201 |     flat_e4 = tf.keras.layers.Flatten()(e4)
202 |     e5 = tf.keras.layers.Dense(256, activation='relu', name="e5")(flat_e4)
203 | 
204 |     means = tf.keras.layers.Dense(
205 |         num_latent,
206 |         activation=None,
207 |         name="means",
208 |     )(e5)
209 |     log_var = tf.keras.layers.Dense(
210 |         num_latent,
211 |         activation=None,
212 |         name="log_var",
213 |     )(e5)
214 | 
215 |     encoder = tf.keras.Model(inputs=inputs, outputs=[means, log_var])
216 | 
217 |     return encoder
218 | 
219 | 
220 | def deconv_decoder(output_shape, num_latent):
221 |     """
222 |     CNN decoder architecture used in the 'Challenging Common Assumptions in the
223 |     Unsupervised Learning of Disentangled Representations' paper
224 |     (https://arxiv.org/abs/1811.12359)
225 | 
226 |     Note: model is uncompiled
227 |     """
228 | 
229 |     latent_inputs = tf.keras.Input(shape=(num_latent,))
230 |     d1 = tf.keras.layers.Dense(256, activation='relu')(latent_inputs)
231 |     d2 = tf.keras.layers.Dense(1024, activation='relu')(d1)
232 |     d2_reshaped = tf.keras.layers.Reshape([4, 4, 64])(d2)
233 | 
234 |     d3 = tf.keras.layers.Conv2DTranspose(
235 |         filters=64,
236 |         kernel_size=4,
237 |         strides=2,
238 |         activation='relu',
239 |         padding="same",
240 |     )(d2_reshaped)
241 | 
242 |     d4 = tf.keras.layers.Conv2DTranspose(
243 |         filters=32,
244 |         kernel_size=4,
245 |         strides=2,
246 |         activation='relu',
247 |         padding="same",
248 |     )(d3)
249 | 
250 |     d5 = tf.keras.layers.Conv2DTranspose(
251 |         filters=32,
252 |         kernel_size=4,
253 |         strides=2,
254 |         activation='relu',
255 |         padding="same",
256 |     )(d4)
257 | 
258 |     d6 = tf.keras.layers.Conv2DTranspose(
259 |         filters=output_shape[2],
260 |         kernel_size=4,
261 |         strides=2,
262 |         padding="same",
263 |     )(d5)
264 |     output = tf.keras.layers.Reshape(output_shape)(d6)
265 | 
266 |     return tf.keras.Model(inputs=latent_inputs, outputs=[output])
267 | 
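# --- Usage sketch (illustrative; the 64x64x1 shapes and 10 latents are
# assumptions based on the dSprites format): wires the encoder and decoder
# builders above into a simple reconstruction pass.
def _demo_encoder_decoder():
    input_shape = (64, 64, 1)
    num_latent = 10
    encoder = conv_encoder(input_shape, num_latent)
    decoder = deconv_decoder(input_shape, num_latent)
    x = tf.random.uniform((8,) + input_shape)
    means, log_var = encoder(x)
    # Reparameterisation trick: z = mu + sigma * eps
    z = means + tf.exp(0.5 * log_var) * tf.random.normal(tf.shape(means))
    return decoder(z)  # shape: (8, 64, 64, 1)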
--------------------------------------------------------------------------------
/concepts_xai/utils/model_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tensorflow as tf
3 | 
4 | 
5 | def get_model(
6 |     model,
7 |     model_save_path,
8 |     overwrite=False,
9 |     train_gen=None,
10 |     val_gen=None,
11 |     batch_size=256,
12 |     epochs=250,
13 | ):
14 |     '''
15 |     Code for loading/training a given model
16 |     :param model: Compiled Keras model
17 |     :param model_save_path: Path for saving/loading the model weights
18 |     :param overwrite: Whether to retrain the model even if saved weights exist
19 |     :param train_gen: tf.dataset generator for training the model
20 |     :param val_gen: tf.dataset generator for validating the model
21 |     :param batch_size: Batch size used during training
22 |     :param epochs: Number of epochs to use for training
23 |     :return: Trained/loaded model
24 |     '''
25 | 
26 |     if (overwrite) or (not os.path.exists(model_save_path)):
27 | 
28 |         if (train_gen is None) or (val_gen is None):
29 |             raise ValueError(
30 |                 "Training and/or Validation data generators not provided."
31 |             )
32 | 
33 |         # Make sure the save directory exists before training, since the
34 |         # checkpoint callback writes to it at the end of each epoch
35 |         dir_name = os.path.dirname(model_save_path)
36 |         if dir_name and not os.path.exists(dir_name):
37 |             os.makedirs(dir_name)
38 | 
39 |         # Train model
40 |         callbacks = []
41 |         cp_callback = tf.keras.callbacks.ModelCheckpoint(
42 |             filepath=model_save_path,
43 |             verbose=True,
44 |             save_best_only=True,
45 |             monitor='val_acc',
46 |             mode='auto',
47 |             save_freq='epoch',
48 |         )
49 |         callbacks.append(cp_callback)
50 | 
51 |         model.fit(
52 |             train_gen.batch(batch_size),
53 |             epochs=epochs,
54 |             validation_data=val_gen.batch(batch_size),
55 |             callbacks=callbacks,
56 |         )
57 | 
58 |         # Note: this final save overwrites the best checkpoint with the
59 |         # weights from the last epoch
60 |         model.save_weights(model_save_path)
61 | 
62 |     else:
63 |         print("Loading pre-trained model")
64 |         model.load_weights(model_save_path)
65 | 
66 |     return model
67 | 
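# --- Usage sketch (illustrative; the checkpoint path is a hypothetical
# example, and `model`, `train_ds` and `val_ds` are assumed to be a compiled
# Keras model plus unbatched tf.data datasets of (x, y) pairs):
def _demo_get_model(model, train_ds, val_ds):
    return get_model(
        model,
        model_save_path="./checkpoints/small_cnn/weights",
        overwrite=False,
        train_gen=train_ds,
        val_gen=val_ds,
        batch_size=256,
        epochs=10,
    )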
--------------------------------------------------------------------------------
/concepts_xai/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | 
4 | from .architectures import multi_task_cnn, small_cnn
5 | from .model_loader import get_model
6 | 
7 | 
8 | def tf_data_split(ds, test_size=0.15, n_samples=None):
9 |     '''
10 |     Method for train/test splitting a tf.dataset, assuming it was generated
11 |     from numpy arrays
12 |     :param ds: tf.dataset, assumed to be generated "from_tensor_slices" (see
13 |         ./datasets/dSprites.py for an example)
14 |     :param test_size: Ratio of test samples
15 |     :param n_samples: Total number of samples in ds
16 |     :return: Two tf datasets, obtained by splitting ds
17 |     '''
18 | 
19 |     if n_samples is None:
20 |         # Compute total number of samples, if not provided
21 |         n_samples = int(ds.cardinality().numpy())
22 | 
23 |     # Split the dataset
24 |     train_size = int((1. - test_size) * n_samples)
25 |     n_test = n_samples - train_size
26 |     ds_train = ds.take(train_size)
27 |     ds_test = ds.skip(train_size).take(n_test)
28 | 
29 |     return ds_train, ds_test
30 | 
31 | 
32 | def setup_experiment_dir(dir_path, overwrite=False):
33 | 
34 |     # Remove prior contents
35 |     if overwrite and os.path.exists(dir_path):
36 |         shutil.rmtree(dir_path)
37 | 
38 |     # Create the directory if it doesn't exist
39 |     if not os.path.exists(dir_path):
40 |         os.makedirs(dir_path)
41 | 
42 | 
43 | def convert_to_multioutput(c):
44 |     # Split a concept vector into a tuple of per-concept targets
45 |     multi_output = tuple([c[i:i + 1] for i in range(c.shape[0])])
46 |     return multi_output
47 | 
48 | 
49 | def setup_basic_model(dataset, n_epochs, save_path, model_type=""):
50 |     '''
51 |     Train a standard CNN model (used for CBM and CME)
52 |     :param dataset: dataset class to train with
53 |     :param n_epochs: number of epochs to train for
54 |     :param save_path: path to saving/loading the model
55 |     :param model_type: whether single-output ("") or "multi_task"
56 |     :return: the trained/loaded model
57 |     '''
58 | 
59 |     data_gen_train, data_gen_test, c_names = dataset.load_data()
60 | 
61 |     n_classes = dataset.n_classes
62 |     input_shape = dataset.sample_shape
63 |     n_samples = dataset.n_train_samples
64 |     num_concept_values = dataset.n_c_vals_list
65 |     concept_names = dataset.c_names
66 | 
67 |     print(f"No. training samples: {n_samples}")
68 |     # Train/Load CNN model using data generators
69 |     if model_type == "multi_task":
70 | 
71 |         # remove y and convert concept data to multi-task
72 |         data_gen_train_c = data_gen_train.map(
73 |             lambda x, c, y: (x, convert_to_multioutput(c))
74 |         )
75 |         data_gen_t, data_gen_v = tf_data_split(
76 |             data_gen_train_c,
77 |             test_size=0.15,
78 |             n_samples=n_samples,
79 |         )
80 |         untrained_model = multi_task_cnn(
81 |             input_shape,
82 |             num_concept_values,
83 |             concept_names,
84 |         )
85 |         # Note: keyword arguments are required here, as get_model's third
86 |         # positional parameter is 'overwrite'
87 |         basic_model = get_model(
88 |             untrained_model,
89 |             save_path,
90 |             train_gen=data_gen_t,
91 |             val_gen=data_gen_v,
92 |             epochs=n_epochs,
93 |         )
94 |         # Evaluate on the test set, converting concept labels to the
95 |         # multi-output format expected by the model
96 |         data_gen_test_c = data_gen_test.map(
97 |             lambda x, c, y: (x, convert_to_multioutput(c))
98 |         )
99 |         results = basic_model.evaluate(data_gen_test_c.batch(batch_size=256))
100 |     else:
101 |         # Remove concept data
102 |         data_gen_train_no_c = data_gen_train.map(lambda x, c, y: (x, y))
103 |         # Train/Load CNN model using data generators
104 |         data_gen_t, data_gen_v = tf_data_split(
105 |             data_gen_train_no_c,
106 |             test_size=0.15,
107 |             n_samples=n_samples,
108 |         )
109 |         untrained_model = small_cnn(input_shape, n_classes)
110 |         basic_model = get_model(
111 |             untrained_model,
112 |             save_path,
113 |             train_gen=data_gen_t,
114 |             val_gen=data_gen_v,
115 |             epochs=n_epochs,
116 |         )
117 | 
118 |         # Evaluate the model on the test set
119 |         data_gen_test_no_c = data_gen_test.map(lambda x, c, y: (x, y))
120 |         results = basic_model.evaluate(data_gen_test_no_c.batch(batch_size=256))
121 |     print("Model performance: ", results)
122 | 
123 |     return basic_model
--------------------------------------------------------------------------------
/concepts_xai/utils/visualisation.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | from scipy import stats
4 | 
5 | 
6 | def plot_np_img(img_np, cmap=None):
7 |     '''
8 |     Plot an image given as a numpy array
9 |     '''
10 |     plt.imshow(img_np, cmap=cmap)
11 |     plt.show()
12 | 
13 | 
14 | def visualisation_experiment(vae, imgs):
15 |     '''
16 |     Plot each image in 'imgs' next to its 'vae' reconstruction
17 |     '''
18 |     kwargs = {"decode": True}
19 | 
20 |     for img in imgs:
21 |         # Plot original image
22 |         plot_np_img(img)
23 |         # Retrieve reconstructed image from vae
24 |         reconstruction = vae(np.expand_dims(img, axis=0), **kwargs)
25 |         reconstruction = reconstruction.numpy()[0]
26 |         reconstruction = stats.logistic.cdf(reconstruction)
27 |         # Plot reconstructed image
28 |         plot_np_img(reconstruction)
--------------------------------------------------------------------------------
/config.yml:
--------------------------------------------------------------------------------
1 | 
2 | # Path to the downloaded dsprites.npz file
3 | dsprites_path : ...
4 | 
5 | # Path to the downloaded 'cars' directory
6 | cars3D_path : ...
7 | 
8 | # Path to the downloaded 'smallNorb' directory
9 | smallNorb_path : ...
10 | 
11 | # Path to the downloaded '3dshapes.h5' file
12 | shapes3d_path : ...
13 | 
--------------------------------------------------------------------------------
/config_template.yml:
--------------------------------------------------------------------------------
1 | 
2 | # Path to the downloaded dsprites.npz file
3 | dsprites_path : ...
4 | 
5 | # Path to the downloaded 'cars' directory
6 | cars3D_path : ...
7 | 
8 | # Path to the downloaded 'smallNorb' directory
9 | smallNorb_path : ...
10 | 
11 | # Path to the downloaded '3dshapes.h5' file
12 | shapes3d_path : ...
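# Example (illustrative): if the datasets were fetched with
# ./download_datasets.sh from the repository root, the entries above
# would typically point to:
#
# dsprites_path : ./dsprites/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz
# cars3D_path : ./cars
# smallNorb_path : ./small_norb
# shapes3d_path : ./shapes3d/3dshapes.h5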
--------------------------------------------------------------------------------
/download_datasets.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 | 
4 | # Note: code partially adapted from the
5 | # 'https://github.com/google-research/disentanglement_lib' repo
6 | 
7 | 
8 | echo "Downloading small_norb."
9 | if [[ ! -d "small_norb" ]]; then
10 |     mkdir small_norb
11 | fi
12 | if [[ ! -e small_norb/"smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat" ]]; then
13 |     wget -O small_norb/smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat.gz https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat.gz
14 |     gunzip small_norb/smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat.gz
15 | fi
16 | if [[ ! -e small_norb/"smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat" ]]; then
17 |     wget -O small_norb/smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat.gz https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat.gz
18 |     gunzip small_norb/smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat.gz
19 | fi
20 | if [[ ! -e small_norb/"smallnorb-5x46789x9x18x6x2x96x96-training-info.mat" ]]; then
21 |     wget -O small_norb/smallnorb-5x46789x9x18x6x2x96x96-training-info.mat.gz https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-info.mat.gz
22 |     gunzip small_norb/smallnorb-5x46789x9x18x6x2x96x96-training-info.mat.gz
23 | fi
24 | if [[ ! -e small_norb/"smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat" ]]; then
25 |     wget -O small_norb/smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat.gz https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat.gz
26 |     gunzip small_norb/smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat.gz
27 | fi
28 | # (The testing files use the '5x01235' prefix, so check for that name here)
29 | if [[ ! -e small_norb/"smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat" ]]; then
30 |     wget -O small_norb/smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat.gz https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat.gz
31 |     gunzip small_norb/smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat.gz
32 | fi
33 | if [[ ! -e small_norb/"smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat" ]]; then
34 |     wget -O small_norb/smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat.gz https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat.gz
35 |     gunzip small_norb/smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat.gz
36 | fi
37 | echo "Downloading small_norb completed!"
38 | 
39 | 
40 | echo "Downloading cars dataset."
41 | if [[ ! -d "cars" ]]; then
42 |     wget -O nips2015-analogy-data.tar.gz http://www.scottreed.info/files/nips2015-analogy-data.tar.gz
43 |     tar xzf nips2015-analogy-data.tar.gz
44 |     rm nips2015-analogy-data.tar.gz
45 |     mv data/cars .
46 |     rm -r data
47 | fi
48 | echo "Downloading cars completed!"
49 | 
50 | 
51 | echo "Downloading dSprites dataset."
52 | if [[ ! -d "dsprites" ]]; then
53 |     mkdir dsprites
54 |     wget -O dsprites/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz https://github.com/deepmind/dsprites-dataset/raw/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz
55 | fi
56 | echo "Downloading dSprites completed!"
57 | 
58 | 
59 | echo "Downloading shapes3d dataset."
60 | if [[ ! -d "shapes3d" ]]; then
61 |     mkdir shapes3d
62 |     # Use the public storage.googleapis.com endpoint (the storage.cloud.google.com
63 |     # URL requires browser-based authentication)
64 |     wget -O shapes3d/3dshapes.h5 https://storage.googleapis.com/3d-shapes/3dshapes.h5
65 | fi
66 | echo "Downloading shapes3d completed!"
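# Optional sanity check (a sketch; paths assume the script was run from the
# repository root): confirm the key files landed where the config expects them.
for f in \
    dsprites/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz \
    shapes3d/3dshapes.h5 \
    cars \
    small_norb; do
    if [[ -e "$f" ]]; then echo "OK: $f"; else echo "MISSING: $f"; fi
done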
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.3.3
2 | numpy==1.18.5
3 | scipy==1.4.1
4 | tensorflow==2.3.2
5 | scikit-learn==0.24.0
6 | Pillow==8.1.0
7 | pandas==1.2.0
8 | h5py==2.10.0
9 | PyYAML==5.3.1
10 | setuptools==50.3.2
11 | tensorflow-probability==0.11.1
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | with open("README.md", "r") as f:
4 |     long_description = f.read()
5 | 
6 | version = '0.1.0'
7 | setup(
8 |     name='concepts_xai',
9 |     version=version,
10 |     packages=find_packages(),
11 |     description='Concept Extraction Comparison',
12 |     long_description=long_description,
13 |     long_description_content_type="text/markdown",
14 |     python_requires='>=3.7',
15 |     classifiers=[
16 |         "Programming Language :: Python :: 3",
17 |         "License :: OSI Approved :: MIT License",
18 |         "Operating System :: OS Independent",
19 |     ],
20 | )
21 | 
22 | '''
23 | VERSIONING:
24 | 
25 | 1.2.0.dev1  # Development release
26 | 1.2.0a1     # Alpha Release
27 | 1.2.0b1     # Beta Release
28 | 1.2.0rc1    # Release Candidate
29 | 1.2.0       # Final Release
30 | 1.2.0.post1 # Post Release
31 | 15.10       # Date based release
32 | 23          # Serial release
33 | 
34 | MAJOR version when they make incompatible API changes,
35 | MINOR version when they add functionality in a backwards-compatible manner, and
36 | MAINTENANCE version when they make backwards-compatible bug fixes.
37 | '''
--------------------------------------------------------------------------------
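As a quick end-to-end illustration of the utilities above, the following minimal sketch (an illustration only, assuming the package and its pinned requirements are installed; the toy random arrays stand in for a real dataset) exercises `tf_data_split` from `concepts_xai/utils/utils.py`:

```python
import numpy as np
import tensorflow as tf

from concepts_xai.utils.utils import tf_data_split

# Toy stand-in for a real dataset of 64x64 single-channel images
xs = np.random.rand(100, 64, 64, 1).astype("float32")
ys = np.random.randint(0, 3, size=(100,)).astype("int64")
ds = tf.data.Dataset.from_tensor_slices((xs, ys))

ds_train, ds_test = tf_data_split(ds, test_size=0.15)  # 85/15 split
print(ds_train.cardinality().numpy(), ds_test.cardinality().numpy())  # 85 15
```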