├── .gitignore
├── LICENSE.md
├── README.md
├── concepts_xai
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── cars3D.py
│   │   ├── dSprites.py
│   │   ├── dataset_utils.py
│   │   ├── latentFactorData.py
│   │   ├── load_paths.py
│   │   ├── shapes3D.py
│   │   ├── smallNorb.py
│   │   └── tabular_toy.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   └── metrics
│   │       ├── accuracy.py
│   │       ├── completeness.py
│   │       ├── downstream_task.py
│   │       ├── mpo.py
│   │       ├── niching.py
│   │       └── purity.py
│   ├── methods
│   │   ├── CBM
│   │   │   └── CBModel.py
│   │   ├── CME
│   │   │   ├── CtlModel.py
│   │   │   └── ItCModel.py
│   │   ├── CW
│   │   │   └── CWLayer.py
│   │   ├── OCACE
│   │   │   ├── __init__.py
│   │   │   ├── topicModel.py
│   │   │   └── visualisation.py
│   │   ├── SENN
│   │   │   ├── __init__.py
│   │   │   ├── aggregators.py
│   │   │   └── base_senn.py
│   │   ├── SSCC
│   │   │   └── SSCClassifier.py
│   │   ├── VAE
│   │   │   ├── baseVAE.py
│   │   │   ├── betaVAE.py
│   │   │   ├── losses.py
│   │   │   └── weak_vae.py
│   │   └── __init__.py
│   └── utils
│       ├── __init__.py
│       ├── architectures.py
│       ├── model_loader.py
│       ├── utils.py
│       └── visualisation.py
├── config.yml
├── config_template.yml
├── download_datasets.sh
├── requirements.txt
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea/
3 | .python-version
4 | *.ipynb
5 | __pycache__
6 | build/
7 | cars/
8 | concepts_xai.egg-info/
9 | dist/
10 | dsprites/
11 | shapes3d/
12 | small_norb/
13 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Dmitry Kazhdan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Concept-based XAI Library
2 |
3 | CXAI is an open-source library for research on concept-based Explainable AI
4 | (XAI).
5 |
6 | CXAI supports a variety of models, datasets, and evaluation metrics
7 | associated with concept-based approaches:
8 |
9 |
10 | ### High-level Specs
11 |
12 | _Methods_:
13 | - [Now You See Me (CME): Concept-based Model Extraction](https://arxiv.org/abs/2010.13233)
14 | - [Concept Bottleneck Models](https://arxiv.org/abs/2007.04612)
15 | - [Weakly-Supervised Disentanglement Without Compromises](https://arxiv.org/abs/2002.02886)
16 | - [Challenging Common Assumptions in the Unsupervised Learning of Disentangled Representations](https://arxiv.org/abs/1811.12359)
17 | - [Concept Whitening for Interpretable Image Recognition](https://arxiv.org/abs/2002.01650)
18 | - [On Completeness-aware Concept-Based Explanations in Deep Neural Networks](https://arxiv.org/abs/1910.07969)
19 | - [Towards Robust Interpretability with Self-Explaining Neural Networks
20 | ](https://arxiv.org/abs/1806.07538)
21 |
22 |
23 | _Datasets_:
24 | - [dSprites](https://github.com/deepmind/dsprites-dataset)
25 | - [Shapes3D](https://github.com/deepmind/3d-shapes)
26 | - [Cars3D](https://papers.nips.cc/paper/2015/hash/e07413354875be01a996dc560274708e-Abstract.html)
27 | - [SmallNorb](https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/)
28 |
29 | To download the datasets, run the `download_datasets.sh` script from the repository root.
30 |
31 | ### Requirements
32 |
33 | - Python 3.7 - 3.8
34 | - See `requirements.txt` for the remaining required packages
35 |
36 | ### Installation
37 | To install from source, run the following command:
38 | ```bash
39 | python setup.py install
40 | ```
41 | This will install the `concepts-xai` package together with all its dependencies.
42 |
43 | To test that the package has been successfully installed, you may run:
44 | ```python
45 | import concepts_xai
46 | help("concepts_xai")
47 | ```
48 | to display all the subpackages included in this installation.
49 |
50 | ### Subpackages
51 |
52 | - `datasets`: datasets to use, including task functions.
53 | - `evaluation`: different evaluation metrics to use for evaluating our methods.
54 | - `experiments`: experimental setups (to be added soon)
55 | - `methods`: defines the concept-based methods. Note: SSCC defines wrappers around these methods that turn them into semi-supervised concept-labelling methods.
56 | - `utils`: contains utility functions for model creation as well as data management.
57 |
58 |
59 | ### Citing
60 |
61 | If you find this code useful in your research, please consider citing:
62 |
63 | ```
64 | @article{kazhdan2021disentanglement,
65 | title={Is Disentanglement all you need? Comparing Concept-based \& Disentanglement Approaches},
66 | author={Kazhdan, Dmitry and Dimanov, Botty and Terre, Helena Andres and Jamnik, Mateja and Li{\`o}, Pietro and Weller, Adrian},
67 | journal={arXiv preprint arXiv:2104.06917},
68 | year={2021}
69 | }
70 | ```
71 |
72 | This work has been presented at the [RAI](https://sites.google.com/view/rai-workshop/), [WeaSuL](https://weasul.github.io/), and
73 | [RobustML](https://sites.google.com/connect.hku.hk/robustml-2021/home) workshops at [The Ninth International Conference on Learning
74 | Representations (ICLR 2021)](https://iclr.cc/).
75 |
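76 | ### Example Usage
77 |
78 | A minimal usage sketch (the `.npz` path below is a placeholder for wherever
79 | you downloaded dSprites to):
80 |
81 | ```python
82 | from concepts_xai.datasets.dSprites import dSprites
83 |
84 | # Wrap the raw .npz file with one of the pre-defined tasks
85 | data = dSprites(
86 |     dataset_path='path/to/dsprites.npz',
87 |     task='shape_scale_small_skip',
88 | )
89 |
90 | # load_data() returns tf.data train/test generators plus the concept names
91 | train_ds, test_ds, concept_names = data.load_data()
92 | for x, c, y in train_ds.batch(32).take(1):
93 |     print(x.shape, c.shape, y.shape)
94 | ```
95 |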
--------------------------------------------------------------------------------
/concepts_xai/__init__.py:
--------------------------------------------------------------------------------
1 | from concepts_xai.datasets import *
2 | from concepts_xai.evaluation import *
3 | from concepts_xai.methods import *
4 | from concepts_xai.utils import *
5 |
--------------------------------------------------------------------------------
/concepts_xai/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dmitrykazhdan/concept-based-xai/a86c6abd609ec2ef0cf56cb556e3f834a0962bdc/concepts_xai/datasets/__init__.py
--------------------------------------------------------------------------------
/concepts_xai/datasets/cars3D.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import PIL
4 | import numpy as np
5 | import scipy.io as sio
6 | import tensorflow as tf
7 |
8 | from .latentFactorData import LatentFactorData, get_task_data, built_task_fn
9 |
10 | CARS_concept_names = ['elevation', 'azimuth', 'object_type']
11 | CARS_concept_n_vals = [4, 24, 183]
12 |
13 |
14 | class Cars3D(LatentFactorData):
15 |
16 | def __init__(
17 | self,
18 | dataset_path,
19 | task_name='elevation_full',
20 | train_size=0.85,
21 | random_state=42,
22 | ):
23 | '''
24 | :param dataset_path: path to the cars dataset folder
25 | :param task_name: the task to use with the dataset for creating labels
26 | '''
27 | super().__init__(
28 | dataset_path=dataset_path,
29 | task_name=task_name,
30 | num_factors=3,
31 | sample_shape=[64, 64, 3],
32 | c_names=CARS_concept_names,
33 | task_fn=CARS3D_TASKS[task_name],
34 | )
35 | self._get_generators(train_size, random_state)
36 |
37 | def _load_x_c_data(self):
38 | x_data = []
39 | all_files = [
40 | x for x in tf.io.gfile.listdir(self.dataset_path) if ".mat" in x
41 | ]
42 | c_data = []
43 |
44 | for i, filename in enumerate(all_files):
45 | data_mesh = load_mesh(os.path.join(self.dataset_path, filename))
46 | factor1 = np.array(list(range(4)))
47 | factor2 = np.array(list(range(24)))
48 | all_factors = np.transpose([
49 | np.tile(factor1, len(factor2)),
50 | np.repeat(factor2, len(factor1)),
51 | np.tile(i, len(factor1) * len(factor2))
52 | ])
53 |
54 | c_data += [
55 | list(all_factors[j]) for j in range(all_factors.shape[0])
56 | ]
57 | x_data.append(data_mesh)
58 |
59 | x_data = np.concatenate(x_data)
60 | c_data = np.array(c_data)
61 |
62 | return x_data, c_data
63 |
64 |
65 | def load_mesh(filename):
66 | """Parses a single source file and rescales contained images."""
67 | with tf.io.gfile.GFile(filename, "rb") as f:
68 | mesh = np.einsum("abcde->deabc", sio.loadmat(f)["im"])
69 | flattened_mesh = mesh.reshape((-1,) + mesh.shape[2:])
70 | rescaled_mesh = np.zeros((flattened_mesh.shape[0], 64, 64, 3))
71 | for i in range(flattened_mesh.shape[0]):
72 | flattened_im = flattened_mesh[i, :, :, :]
73 | pic = PIL.Image.fromarray(flattened_im)
74 | pic.thumbnail((64, 64), PIL.Image.ANTIALIAS)
75 | np_pic = np.array(pic)
76 | rescaled_mesh[i, :, :, :] = np_pic
77 | return rescaled_mesh * 1. / 255
78 |
79 |
80 | # ===========================================================================
81 | # Task DEFINITIONS
82 | # ===========================================================================
83 |
84 | def get_elevation_full(x_data, c_data):
85 | label_fn = lambda c: c[0]
86 | return get_task_data(x_data, c_data, label_fn, filter_fn=None)
87 |
88 |
89 | def get_all_concepts(x_data, c_data):
90 | label_fn = lambda c: c
91 | return get_task_data(x_data, c_data, label_fn, filter_fn=None)
92 |
93 |
94 | # ===========================================================================
95 | # Define task function lookups
96 | # ===========================================================================
97 |
98 | CARS3D_TASKS = {
99 | 'all_concepts': get_all_concepts,
100 | 'elevation_full': get_elevation_full,
101 | "bin_elevation": built_task_fn(lambda c: int(c[0] >= 2)),
102 | }
103 |
--------------------------------------------------------------------------------
/concepts_xai/datasets/dSprites.py:
--------------------------------------------------------------------------------
1 | '''
2 | See https://github.com/deepmind/dsprites-dataset/blob/master/dsprites_reloading_example.ipynb
3 | for a nice overview.
4 |
5 | 6 latent factors: color, shape, scale, rotation and position (x and y)
6 | 'latents_sizes': [1, 3, 6, 40, 32, 32]
7 | '''
8 |
9 | import numpy as np
10 |
11 | from .latentFactorData import LatentFactorData, get_task_data
12 |
13 | ################################################################################
14 | ## GLOBAL VARIABLES
15 | ################################################################################
16 |
17 | CONCEPT_NAMES = [
18 | 'shape',
19 | 'scale',
20 | 'rotation',
21 | 'x_pos',
22 | 'y_pos',
23 | ]
24 |
25 | CONCEPT_N_VALUES = [
26 | 3, # [square, ellipse, heart]
27 | 6, # np.linspace(0.5, 1, 6)
28 | 40, # 40 values in {0, 1, ..., 39} representing angles in [0, 2 * pi]
29 | 32, # 32 values in {0, 1, 2, ..., 31} representing x coordinates in [0, 1]
30 | 32, # 32 values in {0, 1, 2, ..., 31} representing y coordinates in [0, 1]
31 | ]
32 |
33 |
34 | ################################################################################
35 | ## DATASET LOADER
36 | ################################################################################
37 |
38 | class dSprites(LatentFactorData):
39 |
40 | def __init__(
41 | self,
42 | dataset_path,
43 | task='shape_scale_small_skip',
44 | train_size=0.85,
45 | random_state=None,
46 | ):
47 | '''
48 | :param str dataset_path: path to the .npz dsprites file.
49 | :param Or[
50 | str,
51 | Function[(ndarray, ndarray), (ndarray, ndarray, ndarray)]
52 | ] task: the task to use with the dataset for creating
53 | labels. If this is a string, then it must be the name of a
54 | pre-defined task in the DSPRITES_TASKS lookup table. Otherwise
55 | we expect a function that takes two np.ndarrays (x_data, c_data),
56 | corresponding to the dSprites samples and their respective
57 | concepts, and produces a tuple of three np.ndarrays
58 | (x_data, c_data, y_data) corresponding to the task's
59 | samples, ground truth concept values, and labels, respectively.
60 | '''
61 |
62 | # Note: we exclude the trivial 'color' concept
63 | if isinstance(task, str):
64 | if task not in DSPRITES_TASKS:
65 | raise ValueError(
66 | f'If the given task is a string, then it is expected to be '
67 | f'the name of a pre-defined task in '
68 | f'{list(DSPRITES_TASKS.keys())}. However, we were given '
69 | f'"{task}" which is not a known task.'
70 | )
71 | task_fn = DSPRITES_TASKS[task]
72 | else:
73 | task_fn = task
74 | super().__init__(
75 | dataset_path=dataset_path,
76 | task_name="dSprites",
77 | num_factors=len(CONCEPT_NAMES),
78 | sample_shape=[64, 64, 1],
79 | c_names=CONCEPT_NAMES,
80 | task_fn=task_fn,
81 | )
82 | self._get_generators(train_size, random_state)
83 |
84 | def _load_x_c_data(self):
85 | # Load dataset
86 | dataset_zip = np.load(self.dataset_path)
87 | x_data = dataset_zip['imgs']
88 | x_data = np.expand_dims(x_data, axis=-1)
89 | c_data = dataset_zip['latents_classes']
90 | c_data = c_data[:, 1:] # Remove color concept
91 | return x_data, c_data
92 |
93 |
94 | ################################################################################
95 | # TASK DEFINITIONS
96 | ################################################################################
97 |
98 | def cardinality_encoding(card_group_1, card_group_2):
99 | result_to_encoding = {}
100 | for i in card_group_1:
101 | for j in card_group_2:
102 | result_to_encoding[(i, j)] = len(result_to_encoding)
103 | return result_to_encoding
104 |
105 |
106 | def small_skip_ranges_filter_fn(concept):
107 | '''
108 | Filter out certain values only
109 | '''
110 | ranges = [
111 | list(range(3)),
112 | list(range(6)),
113 | list(range(0, 40, 5)),
114 | list(range(0, 32, 2)),
115 | list(range(0, 32, 2)),
116 | ]
117 | return all([
118 | (concept[i] in ranges[i]) for i in range(len(ranges))
119 | ])
120 |
121 |
122 | def get_shape_full(x_data, c_data):
123 | return get_task_data(
124 | x_data=x_data,
125 | c_data=c_data,
126 | label_fn=lambda c_data: c_data[0],
127 | )
128 |
129 |
130 | def get_shape_small_skip(x_data, c_data):
131 | return get_task_data(
132 | x_data=x_data,
133 | c_data=c_data,
134 | label_fn=lambda c_data: c_data[0],
135 | filter_fn=small_skip_ranges_filter_fn,
136 | )
137 |
138 |
139 | def get_shape_scale_full(x_data, c_data):
140 | return get_task_data(
141 | x_data=x_data,
142 | c_data=c_data,
143 | label_fn=lambda c_data: (
144 | c_data[0] * CONCEPT_N_VALUES[1] + c_data[1]
145 | ),
146 | )
147 |
148 |
149 | def get_shape_scale_small_skip(x_data, c_data):
150 | label_remap = cardinality_encoding(
151 | list(range(3)),
152 | list(range(6)),
153 | )
154 | return get_task_data(
155 | x_data=x_data,
156 | c_data=c_data,
157 | label_fn=lambda c_data: label_remap[(c_data[0], c_data[1])],
158 | filter_fn=small_skip_ranges_filter_fn,
159 | )
160 |
161 |
162 | ################################################################################
163 | # TASK FUNCTION LOOKUP TABLE
164 | ################################################################################
165 |
166 | DSPRITES_TASKS = {
167 | "shape_full": get_shape_full,
168 | "shape_scale_full": get_shape_scale_full,
169 | "shape_scale_small_skip": get_shape_scale_small_skip,
170 | "shape_small_skip": get_shape_small_skip,
171 | }
172 |
--------------------------------------------------------------------------------
/concepts_xai/datasets/dataset_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from . import dSprites, shapes3D, smallNorb, cars3D
4 |
5 | '''
6 | A primer on bases:
7 | Assume you have a vector A = (x, y, z), where every dimension is in base
8 | (a, b, c), then, in order to convert each of those dimensions to decimal, we do:
9 |
10 | D = (z * 1) + (y * c) + (x * b * c)
11 |
12 | Or, can define bases vector B = [b*c, c, 1], and then define D = B.A
13 |
14 | Example:
15 | (1, 0, 1) in bases (2, 2, 2) (i.e. in binary):
16 | D = (2^0 * 1) + (2^1 * 0) + (2^2 * 1) = 1 + 4 = 5
17 | '''
18 |
19 |
20 | def get_latent_bases(latent_sizes):
21 | return np.concatenate(
22 | [latent_sizes[::-1].cumprod()[::-1][1:], np.array([1, ])]
23 | )
24 |
25 |
26 | # Convert a concept-based index into a single index
27 | def latent_to_index(latents, latents_bases):
28 | return np.dot(latents, latents_bases).astype(int)
29 |
30 | dataset_concept_names = {
31 | "dSprites": dSprites.DSPRITES_concept_names,
32 | "shapes3d": shapes3D.SHAPES3D_concept_names,
33 | "smallNorb": smallNorb.SMALLNORB_concept_names,
34 | "cars3D": cars3D.CARS_concept_names
35 | }
36 |
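37 | # Illustrative sketch (assuming dSprites' latents_sizes [1, 3, 6, 40, 32, 32]):
38 | #
39 | #   >>> bases = get_latent_bases(np.array([1, 3, 6, 40, 32, 32]))
40 | #   >>> bases
41 | #   array([737280, 245760,  40960,   1024,     32,      1])
42 | #   >>> latent_to_index([0, 2, 5, 39, 31, 31], bases)
43 | #   737279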
--------------------------------------------------------------------------------
/concepts_xai/datasets/latentFactorData.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from sklearn.model_selection import train_test_split
4 |
5 |
6 | class LatentFactorData(object):
7 | """
8 | Abstract class for datasets with known ground truth factors
9 | """
10 | def __init__(
11 | self,
12 | dataset_path,
13 | task_name,
14 | num_factors,
15 | sample_shape,
16 | c_names,
17 | task_fn,
18 | ):
19 | self.dataset_path = dataset_path
20 | self.task_name = task_name
21 | self.num_factors = num_factors
22 | self.sample_shape = sample_shape
23 | self.c_names = c_names
24 | self.task_fn = task_fn
25 | self._has_generators = False
26 |
27 | def _load_x_c_data(self):
28 | raise NotImplementedError(
29 | "Need to implement code for loading input and concept data"
30 | )
31 |
32 | def get_concept_values(self):
33 | if not self._has_generators:
34 | self._get_generators()
35 | return self.n_c_vals_list
36 |
37 | def _get_generators(self, train_size=0.85, random_state=42):
38 |
39 | x_data, c_data = self._load_x_c_data()
40 | x_data, c_data, y_data = self.task_fn(x_data, c_data)
41 | x_data = x_data.astype(np.float32)
42 | self.n_classes = len(np.unique(y_data))
43 | self.n_c_vals_list = [
44 | len(np.unique(c_data[:, i])) for i in range(c_data.shape[1])
45 | ]
46 |
47 | # For each concept, create a dictionary mapping
48 | # consecutive ids --> original ids
49 | self.cid_new_to_old = []
50 | for i in range(c_data.shape[1]):
51 | unique_vals, unique_inds = np.unique(
52 | c_data[:, i],
53 | return_inverse=True,
54 | )
55 | self.cid_new_to_old.append(
56 | {i : unique_vals[i] for i in range(len(unique_vals))}
57 | )
58 | c_data[:, i] = unique_inds
59 |
60 | (
61 | self.x_train,
62 | self.x_test,
63 | self.y_train,
64 | self.y_test,
65 | self.c_train,
66 | self.c_test,
67 | ) = train_test_split(
68 | x_data,
69 | y_data,
70 | c_data,
71 | train_size=train_size,
72 | random_state=random_state,
73 | )
74 |
75 | self.n_train_samples = self.c_train.shape[0]
76 | self.n_test_samples = self.c_test.shape[0]
77 | self.data_gen_train = tf.data.Dataset.from_tensor_slices(
78 | (self.x_train, self.c_train, self.y_train)
79 | )
80 | self.data_gen_test = tf.data.Dataset.from_tensor_slices(
81 | (self.x_test, self.c_test, self.y_test)
82 | )
83 | self._has_generators = True
84 |
85 | def load_data(self):
86 | return self.data_gen_train, self.data_gen_test, self.c_names
87 |
88 |
89 | # ===========================================================================
90 | # Task definition utility functions
91 | # ===========================================================================
92 |
93 | def built_task_fn(label_fn, filter_fn=None):
94 | def task_fn(x_data, c_data):
95 | return get_task_data(x_data, c_data, label_fn, filter_fn=filter_fn)
96 | return task_fn
97 |
98 |
99 | def get_task_data(x_data, c_data, label_fn, filter_fn=None):
100 |
101 | if filter_fn is not None:
102 | ids = np.array([filter_fn(c) for c in c_data])
103 | ids = np.where(ids)[0]
104 | c_data = c_data[ids]
105 | x_data = x_data[ids]
106 |
107 | y_data = np.array([label_fn(c) for c in c_data])
108 |
109 | return x_data, c_data, y_data
110 |
111 |
112 |
113 |
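114 | # Illustrative sketch: built_task_fn builds a task function from a label
115 | # function (and an optional concept filter). For example, a task labelling
116 | # each sample by its first concept, keeping only samples whose first
117 | # concept value is below 3:
118 | #
119 | #   >>> task_fn = built_task_fn(lambda c: c[0], filter_fn=lambda c: c[0] < 3)
120 | #   >>> x_data, c_data, y_data = task_fn(x_data, c_data)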
--------------------------------------------------------------------------------
/concepts_xai/datasets/load_paths.py:
--------------------------------------------------------------------------------
1 | import yaml
2 |
3 |
4 | def load_dataset_paths(config_path):
5 | '''
6 | :param config_path: Path to .yml file containing the individual dataset
7 | paths
8 | :return: Dictionary of dataset_name --> dataset_path
9 | '''
10 |
11 | dataset_paths_dict = {}
12 |
13 | with open(config_path) as config_file:
14 | data = yaml.load(config_file, Loader=yaml.FullLoader)
15 |
16 | dataset_paths_dict["dSprites"] = data['dsprites_path']
17 | dataset_paths_dict["cars3D"] = data['cars3D_path']
18 | dataset_paths_dict["smallNorb"] = data['smallNorb_path']
19 | dataset_paths_dict["shapes3d"] = data['shapes3d_path']
20 |
21 | return dataset_paths_dict
22 |
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/concepts_xai/datasets/shapes3D.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import numpy as np
3 | import h5py
4 |
5 | from .latentFactorData import LatentFactorData, get_task_data
6 |
7 | ################################################################################
8 | ## GLOBAL VARIABLES
9 | ################################################################################
10 |
11 |
12 | CONCEPT_NAMES = [
13 | 'floor_hue',
14 | 'wall_hue',
15 | 'object_hue',
16 | 'scale',
17 | 'shape',
18 | 'orientation',
19 | ]
20 | CONCEPT_N_VALUES = [
21 | 10,
22 | 10,
23 | 10,
24 | 8,
25 | 4,
26 | 15,
27 | ]
28 |
29 |
30 | ################################################################################
31 | ## DATASET LOADER
32 | ################################################################################
33 |
34 | class shapes3D(LatentFactorData):
35 |
36 | def __init__(
37 | self,
38 | dataset_path,
39 | task='shape_full',
40 | train_size=0.85,
41 | random_state=None,
42 | ):
43 | '''
44 | :param dataset_path: path to the shapes3D .h5 file
45 | :param Or[
46 | str,
47 | Function[(ndarray, ndarray), (ndarray, ndarray, ndarray)]
48 | ] task: the task to use with the dataset for creating
49 | labels. If this is a string, then it must be the name of a
50 | pre-defined task in the SHAPES3D_TASKS lookup table. Otherwise
51 | we expect a function that takes two np.ndarrays (x_data, c_data),
52 | corresponding to the shapes3D samples and their respective
53 | concepts, and produces a tuple of three np.ndarrays
54 | (x_data, c_data, y_data) corresponding to the task's
55 | samples, ground truth concept values, and labels, respectively.
56 | '''
57 |
58 | if isinstance(task, str):
59 | if task not in SHAPES3D_TASKS:
60 | raise ValueError(
61 | f'If the given task is a string, then it is expected to be '
62 | f'the name of a pre-defined task in '
63 | f'{list(SHAPES3D_TASKS.keys())}. However, we were given '
64 | f'"{task}" which is not a known task.'
65 | )
66 | task_fn = SHAPES3D_TASKS[task]
67 | else:
68 | task_fn = task
69 |
70 | super().__init__(
71 | dataset_path=dataset_path,
72 | task_name="3dshapes",
73 | num_factors=len(CONCEPT_NAMES),
74 | sample_shape=[64, 64, 3],
75 | c_names=CONCEPT_NAMES,
76 | task_fn=task_fn,
77 | )
78 | self._get_generators(train_size, random_state)
79 |
80 | def _load_x_c_data(self):
81 | # Get concept data
82 | latent_size_lists = [list(np.arange(i)) for i in CONCEPT_N_VALUES]
83 | c_data = np.array(list(itertools.product(*latent_size_lists)))
84 | # Load image data
85 | with h5py.File(self.dataset_path, 'r') as hf:
86 | x_data = np.array(hf.get('images')) / 255.
87 |
88 |
89 | return x_data, c_data
90 |
91 |
92 | ################################################################################
93 | # TASK DEFINITIONS
94 | ################################################################################
95 |
96 | def small_skip_ranges_filter_fn(concept):
97 | '''
98 | Filter out certain values only
99 | '''
100 | ranges = [
101 | list(range(0, 10, 2)),
102 | list(range(0, 10, 2)),
103 | list(range(0, 10, 2)),
104 | list(range(0, 8, 2)),
105 | list(range(4)),
106 | list(range(0, 15, 2)),
107 | ]
108 |
109 | return all([(concept[i] in ranges[i]) for i in range(len(ranges))])
110 |
111 |
112 | def get_shape_full(x_data, c_data):
113 | return get_task_data(
114 | x_data=x_data,
115 | c_data=c_data,
116 | label_fn=lambda x: x[4],
117 | )
118 |
119 |
120 | def get_shape_small_skip(x_data, c_data):
121 | return get_task_data(
122 | x_data=x_data,
123 | c_data=c_data,
124 | label_fn=lambda x: x[4],
125 | filter_fn=small_skip_ranges_filter_fn,
126 | )
127 |
128 |
129 | def get_reduced_filter_fn():
130 | ranges = [
131 | list(range(0, 10, 2)),
132 | list(range(0, 10, 2)),
133 | list(range(0, 10, 2)),
134 | list(range(0, 8, 2)),
135 | list(range(4)),
136 | list(range(15)),
137 | ]
138 |
139 | def filter_fn(concept):
140 | return all([(concept[i] in ranges[i]) for i in range(len(ranges))])
141 |
142 | return filter_fn
143 |
144 |
145 | def get_reduced_shapes3d(x_data, c_data):
146 |
147 | ranges = [
148 | list(range(0, 10, 2)),
149 | list(range(0, 10, 2)),
150 | list(range(0, 10, 2)),
151 | list(range(0, 8, 2)),
152 | list(range(4)),
153 | list(range(15)),
154 | ]
155 |
156 | label_fn = lambda concept: concept[4] # predict the shape concept
157 |
158 | def filter_fn(concept):
159 | return all([(concept[i] in ranges[i]) for i in range(len(ranges))])
160 |
161 | return get_task_data(
162 | x_data=x_data,
163 | c_data=c_data,
164 | label_fn=label_fn,
165 | filter_fn=filter_fn,
166 | )
167 |
168 |
169 | ################################################################################
170 | # TASK FUNCTION LOOKUP TABLE
171 | ################################################################################
172 |
173 | SHAPES3D_TASKS = {
174 | "reduced_shapes3d" : get_reduced_shapes3d,
175 | "shape_full" : get_shape_full,
176 | "shape_small_skip" : get_shape_small_skip,
177 | }
178 |
179 |
180 |
--------------------------------------------------------------------------------
/concepts_xai/datasets/smallNorb.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import PIL
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 | from .latentFactorData import LatentFactorData, get_task_data
8 |
9 | SMALLNORB_concept_names = [
10 | 'category',
11 | 'instance',
12 | 'elevation',
13 | 'azimuth',
14 | 'lighting',
15 | ]
16 | SMALLNORB_concept_n_vals = [5, 10, 9, 18, 6]
17 |
18 |
19 | class SmallNorb(LatentFactorData):
20 |
21 | def __init__(
22 | self,
23 | dataset_path,
24 | task_name='category_full',
25 | train_size=0.85,
26 | random_state=42,
27 | ):
28 | '''
29 | :param dataset_path: path to the smallnorb files directory
30 | :param task_name: the task to use with the dataset for creating labels
31 | '''
32 |
33 | super().__init__(
34 | dataset_path=dataset_path,
35 | task_name=task_name,
36 | num_factors=5,
37 | sample_shape=[64, 64, 1],
38 | c_names=SMALLNORB_concept_names,
39 | task_fn=SMALLNORB_TASKS[task_name],
40 | )
41 | self._get_generators(train_size, random_state)
42 |
43 | def _load_x_c_data(self):
44 | files_dir = self.dataset_path
45 | filename_template = "smallnorb-{}-{}.mat"
46 | splits = [
47 | "5x46789x9x18x6x2x96x96-training",
48 | "5x01235x9x18x6x2x96x96-testing",
49 | ]
50 | x_datas, c_datas = [], []
51 |
52 | for i, split in enumerate(splits):
53 | data_fname = os.path.join(
54 | files_dir,
55 | filename_template.format(splits[i], 'dat'),
56 | )
57 | cat_fname = os.path.join(
58 | files_dir,
59 | filename_template.format(splits[i], 'cat'),
60 | )
61 | info_fname = os.path.join(
62 | files_dir,
63 | filename_template.format(splits[i], 'info'),
64 | )
65 |
66 | x_data = _read_binary_matrix(data_fname)
67 | # Resize data, and only retain data from 1 camera
68 | x_data = _resize_images(x_data[:, 0])
69 | c_cat = _read_binary_matrix(cat_fname)
70 | c_info = _read_binary_matrix(info_fname)
71 | c_info = np.copy(c_info)
72 | # Set azimuth values to be consecutive digits
73 | c_info[:, 2:3] = c_info[:, 2:3] / 2
74 | c_data = np.column_stack((c_cat, c_info))
75 |
76 | x_datas.append(x_data)
77 | c_datas.append(c_data)
78 |
79 | x_data = np.concatenate(x_datas)
80 | x_data = np.expand_dims(x_data, axis=-1)
81 | c_data = np.concatenate(c_datas)
82 |
83 | return x_data, c_data
84 |
85 |
86 | def _resize_images(integer_images):
87 | resized_images = np.zeros((integer_images.shape[0], 64, 64))
88 | for i in range(integer_images.shape[0]):
89 | image = PIL.Image.fromarray(integer_images[i, :, :])
90 | image = image.resize((64, 64), PIL.Image.ANTIALIAS)
91 | resized_images[i, :, :] = image
92 | return resized_images / 255.
93 |
94 |
95 | def _read_binary_matrix(filename):
96 | """Reads and returns binary formatted matrix stored in filename."""
97 | with tf.io.gfile.GFile(filename, "rb") as f:
98 | s = f.read()
99 | magic = int(np.frombuffer(s, "int32", 1))
100 | ndim = int(np.frombuffer(s, "int32", 1, 4))
101 | eff_dim = max(3, ndim)
102 | raw_dims = np.frombuffer(s, "int32", eff_dim, 8)
103 | dims = []
104 | for i in range(0, ndim):
105 | dims.append(raw_dims[i])
106 |
107 | dtype_map = {
108 | 507333717: "int8",
109 | 507333716: "int32",
110 | 507333713: "float",
111 | 507333715: "double"
112 | }
113 | data = np.frombuffer(s, dtype_map[magic], offset=8 + eff_dim * 4)
114 | data = data.reshape(tuple(dims))
115 | return data
116 |
117 |
118 | # ===========================================================================
119 | # Task DEFINITIONS
120 | # ===========================================================================
121 |
122 | def get_category_full(x_data, c_data):
123 | label_fn = lambda c: c[0]
124 | return get_task_data(x_data, c_data, label_fn, filter_fn=None)
125 |
126 |
127 | # ===========================================================================
128 | # Define task function lookups
129 | # ===========================================================================
130 |
131 | SMALLNORB_TASKS = {
132 | 'category_full': get_category_full,
133 | }
134 |
--------------------------------------------------------------------------------
/concepts_xai/datasets/tabular_toy.py:
--------------------------------------------------------------------------------
1 | '''
2 | Implements the tabular toy dataset defined by Mahinpei et al. in "Promises and
3 | Pitfalls of Black-Box Concept Learning Models" (found in
4 | https://arxiv.org/abs/2106.13314).
5 |
6 | '''
7 |
8 | import numpy as np
9 | from .latentFactorData import LatentFactorData
10 |
11 |
12 | ################################################################################
13 | ## GLOBAL VARIABLES
14 | ################################################################################
15 |
16 | CONCEPT_NAMES = [
17 | 'x_pos',
18 | 'y_pos',
19 | 'z_pos',
20 | ]
21 |
22 | CONCEPT_N_VALUES = [
23 | 1,
24 | 1,
25 | 1,
26 | ]
27 |
28 |
29 | ################################################################################
30 | ## DATASET LOADER
31 | ################################################################################
32 |
33 | class TabularToy(LatentFactorData):
34 |
35 | def __init__(
36 | self,
37 | num_samples,
38 | cov=None,
39 | num_concepts=len(CONCEPT_NAMES),
40 | train_size=0.85,
41 | random_state=None,
42 | task=None,
43 | ):
44 | """
45 | This dataset has 7 features and 3 concepts. It is generated by 3 latent
46 | variables (x, y, z), randomly sampled from a multivariate normal
47 | distribution (with the provided covariance matrix), which define all
48 | seven features of a sample as:
49 |
50 | Feature 0: np.sin(x_vars) + x_vars
51 | Feature 1: np.cos(x_vars) + x_vars
52 | Feature 2: np.sin(y_vars) + y_vars
53 | Feature 3: np.cos(y_vars) + y_vars
54 | Feature 4: np.sin(z_vars) + z_vars
55 | Feature 5: np.cos(z_vars) + z_vars
56 | Feature 6: x_vars**2 + y_vars**2 + z_vars**2
57 |
58 | As concepts, we provide by default a vector of three binary values
59 | [x_pos, y_pos, z_pos] indicating whether each latent factor is positive.
60 | By default, if a task is not provided, then we assign to each sample a
61 | label indicating whether two or more of its latent factors are positive.
62 |
63 | :param int num_samples: The total number of samples we will generate
64 | for this dataset.
65 | :param Or[np.ndarray, float] cov: A covariance matrix to use for
66 | sampling the latent variables. If provided as a float, then we will
67 | assign cov as the covariance between any two distinct latent
68 | factors. If not given, then we will use the identity matrix.
69 | :param int num_concepts: A number in {1, 2, 3} indicating how many
70 | concepts we want to include in our concept representation. If less
71 | than 3, then we will trim the concept list from the right.
72 | :param float train_size: A number in [0,1] indicating which fraction of
73 | the entire dataset will be used for training and which fraction will
74 | be used for testing.
75 | :param Function[(ndarray, ndarray), (ndarray, ndarray, ndarray)] task:
76 | The task to use with the dataset for creating
77 | labels. We expect a function that takes two np.ndarrays
78 | (x_data, c_data), corresponding to the dataset's samples and their
79 | respective concepts, and produces a tuple of three
80 | np.ndarrays (x_data, c_data, y_data) corresponding to the task's
81 | samples, ground truth concept values, and labels, respectively.
82 | """
83 |
84 | if task is None:
85 | task = lambda x, c: (
86 | x,
87 | c,
88 | (np.sum(c, axis=-1) > 1).astype(np.int32),
89 | )
90 | super().__init__(
91 | dataset_path=None,
92 | task_name="tabular_toy",
93 | num_factors=num_concepts,
94 | sample_shape=[7],
95 | c_names=CONCEPT_NAMES,
96 | task_fn=task,
97 | )
98 | self.num_samples = num_samples
99 | self.cov = np.eye(3) if cov is None else cov
100 | if isinstance(self.cov, (float, int)):
101 | self.cov = np.array([
102 | [1, self.cov, self.cov],
103 | [self.cov, 1, self.cov],
104 | [self.cov, self.cov, 1],
105 | ])
106 |
107 | self._get_generators(train_size, random_state)
108 |
109 | def _load_x_c_data(self):
110 | # Sample the x, y, and z variables
111 | latent_vars = np.random.multivariate_normal(
112 | mean=[0, 0, 0],
113 | cov=self.cov,
114 | size=(self.num_samples,),
115 | )
116 | x_vars = latent_vars[:, 0]
117 | y_vars = latent_vars[:, 1]
118 | z_vars = latent_vars[:, 2]
119 |
120 | # The features are just non-linear functions applied to each
121 | # variable
122 | features = [
123 | np.sin(x_vars) + x_vars,
124 | np.cos(x_vars) + x_vars,
125 | np.sin(y_vars) + y_vars,
126 | np.cos(y_vars) + y_vars,
127 | np.sin(z_vars) + z_vars,
128 | np.cos(z_vars) + z_vars,
129 | x_vars**2 + y_vars**2 + z_vars**2,
130 | ]
131 | features = np.stack(features, axis=1)
132 |
133 | # The concepts just check if the variables are positive
134 | x_pos = (x_vars > 0).astype(np.int32)
135 | y_pos = (y_vars > 0).astype(np.int32)
136 | z_pos = (z_vars > 0).astype(np.int32)
137 | concepts = np.squeeze(
138 | np.stack([x_pos, y_pos, z_pos][:self.num_factors], axis=1)
139 | )
140 |
141 | # And that's it buds
142 | return features, concepts
143 |
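144 | # Illustrative sketch: 10,000 samples whose three latent factors are
145 | # pairwise correlated with covariance 0.5, using the default task (label =
146 | # at least two positive latent factors):
147 | #
148 | #   >>> data = TabularToy(num_samples=10000, cov=0.5, random_state=42)
149 | #   >>> train_ds, test_ds, concept_names = data.load_data()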
--------------------------------------------------------------------------------
/concepts_xai/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dmitrykazhdan/concept-based-xai/a86c6abd609ec2ef0cf56cb556e3f834a0962bdc/concepts_xai/evaluation/__init__.py
--------------------------------------------------------------------------------
/concepts_xai/evaluation/metrics/accuracy.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import accuracy_score
2 |
3 |
4 | def compute_accuracies(c_true, c_pred):
5 | '''
6 | Compute the accuracy scores per concept
7 | :param c_true: Numpy array of (n_samples, n_concepts) of ground-truth
8 | concept values
9 | :param c_pred: Numpy array of (n_samples, n_concepts) of predicted concept
10 | values
11 | :return: List of per-concept accuracy scores
12 | '''
13 | n_concepts = c_true.shape[1]
14 | return [
15 | accuracy_score(c_true[:, i], c_pred[:, i]) for i in range(n_concepts)
16 | ]
17 |
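18 | # Illustrative sketch (assuming numpy imported as np): two concepts, where
19 | # the first is always predicted correctly and the second half the time:
20 | #
21 | #   >>> compute_accuracies(np.array([[0, 1], [1, 1]]), np.array([[0, 0], [1, 1]]))
22 | #   [1.0, 0.5]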
--------------------------------------------------------------------------------
/concepts_xai/evaluation/metrics/completeness.py:
--------------------------------------------------------------------------------
1 | '''
2 | Implementation of metrics used for computing the completeness/explicitness
3 | of a given set of concepts.
4 | These metrics are inspired by Yeh et al. "On Completeness-aware Concept-Based
5 | Explanations in Deep Neural Networks" seen in https://arxiv.org/abs/1910.07969v5
6 | '''
7 |
8 | import numpy as np
9 | import sklearn
10 | import tensorflow as tf
11 |
12 | from sklearn.model_selection import train_test_split
13 |
14 |
15 | ################################################################################
16 | ## Helper Functions
17 | ################################################################################
18 |
19 | def _get_default_model(num_concepts, num_hidden_acts, end_activation=None):
20 | """
21 | Helper function that returns a simple MLP with a single 500-unit ReLU
22 | hidden layer.
23 |
24 | Used as the default estimator for computing the completeness score.
25 |
26 | :param int num_concepts: The number concept vectors we have in our setup.
27 | :param int num_hidden_acts: The dimensionality of the concept vectors.
28 |
29 | :returns tf.keras.Model: The default model used for reconstructing a set of
30 | intermediate hidden activations from a set of concepts scores.
31 | """
32 |
33 | return tf.keras.models.Sequential([
34 | tf.keras.layers.Dense(
35 | 500,
36 | input_dim=num_concepts,
37 | activation='relu'
38 | ),
39 | tf.keras.layers.Dense(
40 | num_hidden_acts,
41 | activation=end_activation,
42 | ),
43 | ])
44 |
45 |
46 | ################################################################################
47 | ## Concept Score Functions
48 | ################################################################################
49 |
50 |
51 | def dot_prod_concept_score(
52 | features,
53 | concept_vectors,
54 | epsilon=1e-5,
55 | beta=1e-5,
56 | channels_axis=-1,
57 | ):
58 | """
59 | Returns a vector of concept scores for the given features using a normalized
60 | dot product similarity as in Yeh et al.
61 |
62 | :param np.ndarray features: An array of samples with shape
63 | (n_samples, ..., n_features, ...), where the dimension at axis
64 | channels_axis has `n_features` elements.
65 | :param np.ndarray concept_vectors: A 2D array of shape
66 | (num_concepts, n_features) where each row represents a
67 | meaningful concept direction.
68 | :param float beta: A value used for zeroing dot products that are
69 | considered to be irrelevant. If a dot product is less than beta, then
70 | its score will be zero.
71 | :param float epsilon: A value used for numerical stability to avoid
72 | division by zero.
73 | :param int channels_axis: the channels axis in the input features. If not
74 | given, then we assume it is the last dimension always.
75 |
76 | :returns np.ndarray: A 2D matrix with shape (n_samples, n_concepts) where
77 | the (i, j)-th entry represents the score that the j-th concept assigned
78 | the i-th sample in `features`.
79 | """
80 | # First check that all the dimensions make sense
81 | assert features.shape[channels_axis] == concept_vectors.shape[-1], (
82 | f'Expected input to have {concept_vectors.shape[-1]} elements in its '
83 | f'channels axis (defined as axis {channels_axis}). '
84 | f'Instead, we found the input to have shape {features.shape}.'
85 | )
86 | # First normalize all concepts across their channels dimension
87 | concept_vectors_norm = np.linalg.norm(
88 | concept_vectors,
89 | axis=-1,
90 | keepdims=True,
91 | )
92 | concept_vectors = concept_vectors / (concept_vectors_norm + epsilon)
93 |
94 | # For simplicity, we will always move the channels dimension to the end
95 | if channels_axis not in [-1, len(features.shape) - 1]:
96 | # Then perform a transpose in here
97 | perm = list(range(len(features.shape)))
98 | perm[channels_axis] = len(features.shape) - 1
99 | perm[len(features.shape) - 1] = channels_axis
100 | features = np.transpose(features, perm)
101 |
102 | # And similarly, do the same for the input features
103 | x_norm = np.linalg.norm(
104 | features,
105 | axis=-1,
106 | keepdims=True,
107 | )
108 |
109 | x_norm = features / (x_norm + epsilon)
110 |
111 | # Compute the concept probabilities accordingly
112 | concept_prob = np.dot(features, concept_vectors.transpose())
113 | concept_prob_norm = np.dot(x_norm, concept_vectors.transpose())
114 |
115 | # Threshold scores accordingly using the normalized scores
116 | if beta is not None:
117 | concept_prob = concept_prob * (concept_prob_norm > beta)
118 |
119 | # And end by normalizing them
120 | norm = np.sum(concept_prob, axis=-1, keepdims=True)
121 | result = concept_prob / (norm + epsilon)
122 |
123 | # Finally, restore the shape of the output tensor if a transpose was
124 | # done at the beginning
125 | if channels_axis not in [-1, len(features.shape) - 1]:
126 | perm = list(range(len(features.shape)))
127 | perm[channels_axis] = len(features.shape) - 1
128 | perm[len(features.shape) - 1] = channels_axis
129 | result = np.transpose(result, perm)
130 |
131 | # And that's all folks
132 | return result
133 |
134 |
135 | ################################################################################
136 | ## Completeness Score Computation
137 | ################################################################################
138 |
139 | def completeness_score(
140 | X,
141 | y,
142 | features_to_concepts_fn,
143 | concepts_to_labels_model,
144 | concept_vectors,
145 | task_loss,
146 | g_model=None,
147 | test_size=0.2,
148 | concept_score_fn=dot_prod_concept_score,
149 | predictor_train_kwags=None,
150 | g_optimizer='adam',
151 | acc_fn=sklearn.metrics.accuracy_score,
152 | channels_axis=-1,
153 | ):
154 | """
155 | Returns the completeness score for the given set of concept vectors
156 | `concept_vectors` using testing data `X` with labels `y`. This score
157 | is computed using Yeh et al.'s definition of a concept completeness
158 | score based on a model `features_to_concepts_fn`, which maps input features
159 | in the test data to a M-dimensional space, and a model
160 | `concepts_to_labels_model` which maps M-dimensional vectors to some
161 | probability distribution over classes in `y`.
162 |
163 | :param np.ndarray X: A tensor of testing samples that are in the domain of
164 | given function `features_to_concepts_fn` where the first dimension
165 | represents the number of test samples (which we call `n_samples`).
166 | :param np.ndarray y: A tensor of testing labels corresponding to matrix
167 | X whose first dimension must also be `n_samples`.
168 | :param Function[(np.ndarray), np.ndarray] features_to_concepts_fn: A
169 | function mapping batches of samples with the same dimensionality as
170 | X into some (n_samples, ..., M, ...)-dimensional vector space
171 | corresponding to the same vector space as that used for the given
172 | concept vectors. In this case, M is the channels dimension at location
173 | channels_axis.
174 | :param tf.keras.Model concepts_to_labels_model: An arbitrary Keras model
175 | which maps M-dimensional vectors (as those produced by calling the
176 | `features_to_concepts_fn` function on a batch of samples) into a
177 | probability distribution over labels in `y`.
178 | :param np.ndarray concept_vectors: A 2D array of shape (num_concepts, M)
179 | where each row represents a meaningful concept direction.
180 | :param tf.keras.losses.Loss task_loss: The loss function one intends to
181 | minimize when mapping instances in `X` to labels in `y`.
182 | :param tf.keras.Model g_model: The model `g` we will train for mapping
183 | concept scores to the same M-dimensional space when computing
184 | the concept completeness score. If not given, then we will use an
185 | MLP with a single 500-unit ReLU hidden layer.
186 | :param float test_size: A value between 0 and 1 representing what percent
187 | of the (X, y) data will be used for testing the accuracy of our g_model
188 | (and the original model) when computing the completeness score. The
189 | rest of the data will be used for training our g_model.
190 | :param Function[(np.ndarray,List[np.ndarray]), np.ndarray] concept_score_fn:
191 | A function taking as an input a matrix of shape (n_samples, M),
192 | representing outputs produced by the `features_to_concepts_fn` function,
193 | and a list of `n_concepts` M-dimensional vectors, representing unit
194 | directions of meaningful concepts, and returning a vector with
195 | n_concepts concept scores. By default we use the normalized dot product
196 | scores.
197 | :param Dict[Any, Any] predictor_train_kwags: An optional set of parameters
198 | to pass to the g_model when trained for reconstructing the M-dimensional
199 | activations from their corresponding concept scores.
200 | :param tf.keras.optimizers.Optimizer g_optimizer: The optimizer used for
201 | training the g model for the reconstruction. By default we will use an
202 | ADAM optimizer.
203 | :param Function[(np.ndarray, np.ndarray), float] acc_fn: An accuracy
204 | function taking (true_labels, predicted_labels) and returning an
205 | accuracy value between 0 and 1.
206 | :param int channels_axis: The channels dimension axis of the output of the
207 | features_to_concepts function. If not given, then it is assumed to be
208 | the last dimension.
209 |
210 | :returns Tuple[float, tf.keras.Model]: A tuple (score, g_model) containing
211 | the computed completeness score together with the resulting trained
212 | g_model.
213 | """
214 | # Let's first start by splitting our data into a training and a testing
215 | # set
216 | X_train, X_test, y_train, y_test = train_test_split(
217 | X,
218 | y,
219 | test_size=test_size,
220 | )
221 |
222 | # Let's take a look at the intermediate activations we will be using
223 | phi_train = features_to_concepts_fn(X_train)
224 | scores_train = concept_score_fn(phi_train, concept_vectors)
225 | num_labels = len(set(y))
226 |
227 | # Compute some useful variables while also handling the default case for
228 | # the model we will optimize over
229 | num_concepts = concept_vectors.shape[0]
230 | num_hidden_acts = phi_train.shape[channels_axis]
231 | n_samples = X_train.shape[0]
232 | g_model = g_model or _get_default_model(
233 | num_concepts=num_concepts,
234 | num_hidden_acts=num_hidden_acts,
235 | )
236 | predictor_train_kwags = predictor_train_kwags or {
237 | 'epochs': 50,
238 | 'batch_size': min(16, n_samples),
239 | 'verbose': 0,
240 | }
241 |
242 | # Construct a model that we can use for optimizing our g function
243 | # For this, we will first need to make sure that we set our concepts
244 | # to labels model so that we do not optimize over its parameters
245 | prev_trainable = concepts_to_labels_model.trainable
246 | concepts_to_labels_model.trainable = False
247 | f_prime_input = tf.keras.layers.Input(
248 | shape=scores_train.shape[1:],
249 | dtype=scores_train.dtype,
250 | )
251 | f_prime_output = concepts_to_labels_model(g_model(f_prime_input))
252 | f_prime_optimized = tf.keras.Model(
253 | f_prime_input,
254 | f_prime_output,
255 | )
256 |
257 | # Time to optimize it!
258 | f_prime_optimized.compile(
259 | optimizer=g_optimizer,
260 | loss=task_loss,
261 | )
262 | f_prime_optimized.fit(
263 | scores_train,
264 | y_train,
265 | **predictor_train_kwags,
266 | )
267 |
268 | # Don't forget to reconstruct the state of the concept to labels model
269 | concepts_to_labels_model.trainable = prev_trainable
270 |
271 | # Finally, compute the actual score by computing the accuracy of the
272 | # original concept-composable model
273 | phi_test = features_to_concepts_fn(X_test)
274 | random_pred_acc = 1 / num_labels
275 | f_preds = concepts_to_labels_model.predict(
276 | phi_test
277 | )
278 | f_acc = acc_fn(
279 | y_test,
280 | f_preds,
281 | )
282 |
283 | # And the accuracy of the model using the reconstruction from the
284 | # concept scores
285 | f_prime_preds = f_prime_optimized.predict(
286 | concept_score_fn(phi_test, concept_vectors)
287 | )
288 | f_prime_acc = acc_fn(
289 | y_test,
290 | f_prime_preds,
291 | )
292 |
293 | # That gives us everything we need
294 | if f_prime_acc == random_pred_acc:
295 | return 0, g_model
296 | completeness = (f_prime_acc - random_pred_acc) / (f_acc - random_pred_acc)
297 | return completeness, g_model
298 |
299 |
300 | def direct_completeness_score(
301 | X,
302 | y,
303 | features_to_concepts_fn,
304 | concept_vectors,
305 | task_loss,
306 | g_model=None,
307 | test_size=0.2,
308 | concept_score_fn=dot_prod_concept_score,
309 | predictor_train_kwags=None,
310 | g_optimizer='adam',
311 | acc_fn=sklearn.metrics.accuracy_score,
312 | channels_axis=-1,
313 | ):
314 | """
315 | Returns the completeness score for the given set of concept vectors
316 | `concept_vectors` using testing data `X` with labels `y`. This score
317 | is computed as the predictive accuracy of a model trained to predict
318 | the labels using the concepts scores alone. It differs from the method
319 | above in that it does not require a pre-trained concept_to_labels map.
320 |
321 | :param np.ndarray X: A tensor of testing samples that are in the domain of
322 | given function `features_to_concepts_fn` where the first dimension
323 | represents the number of test samples (which we call `n_samples`).
324 | :param np.ndarray y: A tensor of testing labels corresponding to matrix
325 | X whose first dimension must also be `n_samples`.
326 | :param Function[(np.ndarray), np.ndarray] features_to_concepts_fn: A
327 | function mapping batches of samples with the same dimensionality as
328 | X into some (n_samples, ..., M, ...)-dimensional vector space
329 | corresponding to the same vector space as that used for the given
330 | concept vectors. In this case, M is the channels dimension at location
331 | channels_axis.
332 | :param np.ndarray concept_vectors: A 2D array of shape (num_concepts, M)
333 | where each row represents a meaningful concept direction.
334 | :param tf.keras.losses.Loss task_loss: The loss function one intends to
335 | minimize when mapping instances in `X` to labels in `y`.
336 | :param tf.keras.Model g_model: The model `g` we will train for mapping
337 | concept scores to the space of labels when computing the concept
338 | completeness score. If not given, then we will use an MLP with a
339 | single 500-unit ReLU hidden layer.
340 | :param float test_size: A value between 0 and 1 representing what percent
341 | of the (X, y) data will be used for testing the accuracy of our g_model
342 | (and the original model) when computing the completeness score. The
343 | rest of the data will be used for training our g_model.
344 | :param Function[(np.ndarray,List[np.ndarray]), np.ndarray] concept_score_fn:
345 | A function taking as an input a matrix of shape (n_samples, M),
346 | representing outputs produced by the `features_to_concepts_fn` function,
347 | and a list of `n_concepts` M-dimensional vectors, representing unit
348 | directions of meaningful concepts, and returning a vector with
349 | n_concepts concept scores. By default we use the normalized dot product
350 | scores.
351 | :param Dict[Any, Any] predictor_train_kwags: An optional set of parameters
352 | to pass to the g_model when trained for reconstructing the M-dimensional
353 | activations from their corresponding concept scores.
354 | :param tf.keras.optimizers.Optimizer g_optimizer: The optimizer used for
355 | training the g model for the reconstruction. By default we will use an
356 | ADAM optimizer.
357 | :param Function[(np.ndarray, np.ndarray), float] acc_fn: An accuracy
358 | function taking (true_labels, predicted_labels) and returning an
359 | accuracy value between 0 and 1.
360 | :param int channels_axis: The channels dimension axis of the output of the
361 | features_to_concepts function. If not given, then it is assumed to be
362 | the last dimension.
363 |
364 | :returns Tuple[float, tf.keras.Model]: A tuple (score, g_model) containing
365 | the computed completeness score together with the resulting trained
366 | g_model.
367 | """
368 | # Let's first start by splitting our data into a training and a testing
369 | # set
370 | X_train, X_test, y_train, y_test = train_test_split(
371 | X,
372 | y,
373 | test_size=test_size,
374 | )
375 |
376 | # Let's take a look at the intermediate activations we will be using
377 | phi_train = features_to_concepts_fn(X_train)
378 | scores_train = concept_score_fn(
379 | phi_train,
380 | concept_vectors,
381 | )
382 | num_labels = len(set(y))
383 |
384 | # Compute some useful variables while also handling the default case for
385 | # the model we will optimize over
386 | num_concepts = concept_vectors.shape[0]
387 | num_hidden_acts = phi_train.shape[channels_axis]
388 | n_samples = X_train.shape[0]
389 | g_model = g_model or _get_default_model(
390 | num_concepts=num_concepts,
391 | num_hidden_acts=num_labels if num_labels > 2 else 1,
392 | )
393 | predictor_train_kwags = predictor_train_kwags or {
394 | 'epochs': 50,
395 | 'batch_size': min(16, n_samples),
396 | 'verbose': 0,
397 | }
398 |
399 | # Time to optimize it!
400 | g_model.compile(
401 | optimizer=g_optimizer,
402 | loss=task_loss,
403 | )
404 | g_model.fit(
405 | scores_train,
406 | y_train,
407 | **predictor_train_kwags,
408 | )
409 |
410 | # Finally, compute the actual score by computing the accuracy of predicting
411 | # the output labels using only the concept scores
412 | phi_test = features_to_concepts_fn(X_test)
413 | from_concepts_preds = g_model.predict(
414 | concept_score_fn(phi_test, concept_vectors)
415 | )
416 | from_concepts_acc = acc_fn(
417 | y_test,
418 | from_concepts_preds,
419 | )
420 |
421 | # That gives us everything we need
422 | return from_concepts_acc, g_model
423 |
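424 | # Illustrative sketch of the score function alone, with random features and
425 | # two random concept directions (shapes follow the conventions above):
426 | #
427 | #   >>> phi = np.random.randn(100, 64)             # (n_samples, M)
428 | #   >>> cvs = np.random.randn(2, 64)               # (num_concepts, M)
429 | #   >>> dot_prod_concept_score(phi, cvs).shape
430 | #   (100, 2)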
--------------------------------------------------------------------------------
/concepts_xai/evaluation/metrics/downstream_task.py:
--------------------------------------------------------------------------------
1 | from sklearn.model_selection import train_test_split
2 | from sklearn.metrics import accuracy_score
3 |
4 |
5 | def compute_downstream_task(c_pred, y_true, predictor_model):
6 | '''
7 | Computes the accuracy score for the downstream task
8 | :param c_pred: Concept data predictions, numpy array of shape
9 | (n_samples, n_concepts)
10 | :param y_true: Ground-truth task label data, numpy array of
11 | shape (n_samples,)
12 | :param predictor_model: sklearn model to use for predicting the task labels
13 | from the concept data
14 | :return: Accuracy of predictor_model, trained and evaluated on the provided
15 | concept and label data
16 | '''
17 | c_train, c_test, y_train, y_test = train_test_split(c_pred, y_true)
18 | predictor_model.fit(c_train, y_train)
19 | y_pred = predictor_model.predict(c_test)
20 | return accuracy_score(y_test, y_pred)
21 |
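22 | # Illustrative sketch with an sklearn classifier as the predictor:
23 | #
24 | #   >>> from sklearn.tree import DecisionTreeClassifier
25 | #   >>> acc = compute_downstream_task(c_pred, y_true, DecisionTreeClassifier())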
--------------------------------------------------------------------------------
/concepts_xai/evaluation/metrics/mpo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def total_mispredictions_fn(c_true, c_pred):
5 | '''
6 | Count the total number of mispredictions
7 | :param c_true: ground-truth concept values
8 | :param c_pred: predicted concept values
9 | :return: Number of elements i where c_true[i] != c_pred[i]
10 | '''
11 |
12 | n_errs = np.sum(c_true != c_pred)
13 | return n_errs
14 |
15 |
16 | def compute_MPO(c_true, c_pred, err_fn=total_mispredictions_fn):
17 | '''
18 | Implementation of the (M)is-(P)rediction (O)verlap (MPO) metric from the
19 | CME paper https://arxiv.org/pdf/2010.13233.pdf
20 | Given a set of predicted concept values, MPO computes the fraction of
21 | samples in the test set that have at least m relevant concepts predicted
22 | incorrectly
23 |
24 | :param c_true: ground truth concept data, numpy array of shape
25 | (n_samples, n_concepts)
26 | :param c_pred: predicted concept data, numpy array of shape
27 | (n_samples, n_concepts)
28 | :param err_fn: function for computing the error on one single sample of
29 | c_test and c_pred (e.g. if you want to ignore certain concepts, or
30 | weight concept errors differently). Should be a function taking two
31 | arguments of shape (n_concepts), and returning a scalar value.
32 | Defaults to computing the total number of mispredictions.
33 | :return: MPO metric values computed from c_true and c_pred using err_fn,
34 | with m ranging from 0 to n_concepts - 1
35 | '''
36 |
37 | # Apply error function over all samples
38 | err_vals = np.array([
39 | err_fn(c_true[i], c_pred[i]) for i in range(c_true.shape[0])
40 | ])
41 |
42 | # Compute MPO values for m ranging from 0 to n_concepts
43 | n_concepts = c_true.shape[1]
44 | mpo_vals = []
45 | for i in range(n_concepts):
46 | # Compute number of samples with at least i incorrect concept
47 | # predictions
48 | n_incorrect = (err_vals >= i).astype(int)
49 |
50 | # Compute % of these samples from total
51 | metric = (np.sum(n_incorrect) / c_true.shape[0])
52 | mpo_vals.append(metric)
53 |
54 | return np.array(mpo_vals)
55 |
--------------------------------------------------------------------------------
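
A small sketch of compute_MPO on synthetic binary concept data; mpo_vals[m]
gives the fraction of samples with at least m mispredicted concepts (so
mpo_vals[0] is always 1.0). The corruption rate below is an arbitrary choice.

    import numpy as np
    from concepts_xai.evaluation.metrics.mpo import compute_MPO

    rng = np.random.default_rng(0)
    c_true = rng.integers(0, 2, size=(100, 6))

    # Corrupt roughly 10% of the entries to simulate concept mispredictions
    c_pred = c_true.copy()
    flip = rng.random(c_true.shape) < 0.1
    c_pred[flip] = 1 - c_pred[flip]

    mpo_vals = compute_MPO(c_true, c_pred)  # shape: (n_concepts,)
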
/concepts_xai/evaluation/metrics/niching.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.metrics import accuracy_score, roc_auc_score
3 | from sklearn.feature_selection import mutual_info_classif
4 | from scipy.special import softmax
5 |
6 |
7 | def niche_completeness(c_pred, y_true, predictor_model, niches):
8 | '''
9 | Computes the niche completeness score for the downstream task
10 | :param c_pred: Concept data predictions, numpy array of shape
11 | (n_samples, n_concepts)
12 | :param y_true: Ground-truth task label data, numpy array of shape
13 | (n_samples, n_tasks)
14 | :param predictor_model: trained decoder model to use for predicting the task
15 | labels from the concept data
16 | :return: Dictionary with the one-vs-one ROC AUC of predictor_model,
17 | evaluated on niches obtained from the provided concept and label data
18 | '''
19 | n_tasks = y_true.shape[1]
20 | # compute niche completeness for each task
21 | y_pred_list = []
22 | for task in range(n_tasks):
23 | # find niche
24 | niche = np.zeros_like(c_pred)
25 | niche[:, niches[:, task] > 0] = c_pred[:, niches[:, task] > 0]
26 |
27 | # compute task predictions
28 | y_pred_niche = predictor_model.predict_proba(niche)
29 | if predictor_model.__class__.__name__ == 'Sequential':
30 | # get class labels from logits
31 | y_pred_niche = y_pred_niche > 0
32 | elif len(y_pred_niche.shape) == 1:
33 | y_pred_niche = y_pred_niche[:, np.newaxis]
34 |
35 | y_pred_list.append(y_pred_niche[:, task])
36 |
37 | y_preds = np.vstack(y_pred_list).T
38 | y_preds = softmax(y_preds, axis=1)
39 | auc = roc_auc_score(y_true, y_preds, multi_class='ovo')
40 |
41 | result = {
42 | 'auc_completeness': auc,
43 | 'y_preds': y_preds,
44 | }
45 | return result
46 |
47 |
48 | def niche_completeness_ratio(c_pred, y_true, predictor_model, niches):
49 | '''
50 | Computes the niche completeness ratio for the downstream task
51 | :param c_pred: Concept data predictions, numpy array of shape
52 | (n_samples, n_concepts)
53 | :param y_true: Ground-truth task label data, numpy array of shape
54 | (n_samples, n_tasks)
55 | :param predictor_model: sklearn model to use for predicting the task labels
56 | from the concept data
57 | :return: Accuracy ratio between the accuracy of predictor_model evaluated
58 | on niches and the accuracy of predictor_model evaluated on all concepts
59 | '''
60 | n_tasks = y_true.shape[1]
61 |
62 | y_pred_test = predictor_model.predict_proba(c_pred)
63 | if predictor_model.__class__.__name__ == 'Sequential':
64 | # get class labels from logits
65 | y_pred_test = y_pred_test > 0
66 | elif len(y_pred_test.shape) == 1:
67 | y_pred_test = y_pred_test[:, np.newaxis]
68 |
69 | # compute niche completeness for each task
70 | niche_completeness_list = []
71 | for task in range(n_tasks):
72 | # find niche
73 | niche = np.zeros_like(c_pred)
74 | niche[:, niches[:, task] > 0] = c_pred[:, niches[:, task] > 0]
75 |
76 | # compute task predictions
77 | y_pred_niche = predictor_model.predict_proba(niche)
78 | if predictor_model.__class__.__name__ == 'Sequential':
79 | # get class labels from logits
80 | y_pred_niche = y_pred_niche > 0
81 | elif len(y_pred_niche.shape) == 1:
82 | y_pred_niche = y_pred_niche[:, np.newaxis]
83 |
84 | # compute accuracies
85 | accuracy_base = accuracy_score(y_true[:, task], y_pred_test[:, task])
86 | accuracy_niche = accuracy_score(y_true[:, task], y_pred_niche[:, task])
87 |
88 | # compute the accuracy ratio of the niche w.r.t. the baseline
89 | # (full concept bottleneck) the higher the better (high predictive power
90 | # of the niche)
91 | niche_completeness = accuracy_niche / accuracy_base
92 | niche_completeness_list.append(niche_completeness)
93 |
94 | result = {
95 | 'niche_completeness_ratio_mean': np.mean(niche_completeness_list),
96 | 'niche_completeness_ratio': niche_completeness_list,
97 | }
98 | return result
99 |
100 |
101 | def niche_impurity(c_pred, y_true, predictor_model, niches):
102 | '''
103 | Computes the niche impurity score for the downstream task
104 | :param c_pred: Concept data predictions, numpy array of shape
105 | (n_samples, n_concepts)
106 | :param y_true: Ground-truth task label data, numpy array of shape
107 | (n_samples, n_tasks)
108 | :param predictor_model: sklearn model to use for predicting the task labels
109 | from the concept data
110 | :return: Dictionary with the one-vs-one ROC AUC of predictor_model evaluated
111 | on the concepts left outside each task's niche (a high value signals
112 | concept information leaking across niche boundaries)
113 | '''
114 | n_tasks = y_true.shape[1]
115 |
116 | # compute niche completeness for each task
117 | y_pred_list = []
118 | for task in range(n_tasks):
119 | # find niche
120 | niche = np.zeros_like(c_pred)
121 | niche[:, niches[:, task] > 0] = c_pred[:, niches[:, task] > 0]
122 |
123 | # find concepts outside the niche
124 | niche_out = np.zeros_like(c_pred)
125 | niche_out[:, niches[:, task] <= 0] = c_pred[:, niches[:, task] <= 0]
126 |
127 | # compute task predictions
128 | y_pred_niche = predictor_model.predict_proba(niche)
129 | y_pred_niche_out = predictor_model.predict_proba(niche_out)
130 | if predictor_model.__class__.__name__ == 'Sequential':
131 | # get class labels from logits
132 | y_pred_niche_out = y_pred_niche_out > 0
133 | elif len(y_pred_niche_out.shape) == 1:
134 | y_pred_niche_out = y_pred_niche_out[:, np.newaxis]
135 |
136 | y_pred_list.append(y_pred_niche_out[:, task])
137 |
138 | y_preds = np.vstack(y_pred_list).T
139 | y_preds = softmax(y_preds, axis=1)
140 | auc = roc_auc_score(y_true, y_preds, multi_class='ovo')
141 |
142 | return {
143 | 'auc_impurity': auc,
144 | 'y_preds': y_preds,
145 | }
146 |
147 |
148 | def niche_finding(c, y, mode='mi', threshold=0.5):
149 | n_concepts = c.shape[1]
150 | if mode == 'corr':
151 | corrm = np.corrcoef(np.hstack([c, y]).T)
152 | niching_matrix = corrm[:n_concepts, n_concepts:]
153 | niches = np.abs(niching_matrix) > threshold
154 | elif mode == 'mi':
155 | nm = []
156 | for yj in y.T:
157 | mi = mutual_info_classif(c, yj)
158 | nm.append(mi)
159 | nm = np.vstack(nm).T
160 | niching_matrix = nm / np.max(nm)
161 | niches = niching_matrix > threshold
162 | else:
163 | raise ValueError(f"Unsupported niche finding mode '{mode}'")
164 |
165 | return niches, niching_matrix
166 |
--------------------------------------------------------------------------------
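
A hedged sketch of niche_finding on synthetic multi-task data; the shapes and
the 0.5 threshold are illustrative. The resulting boolean matrix is what the
niche_* metrics above expect as their niches argument.

    import numpy as np
    from concepts_xai.evaluation.metrics.niching import niche_finding

    c = np.random.rand(300, 8)                  # concept predictions
    y = np.random.randint(0, 2, size=(300, 3))  # binary labels for 3 tasks

    niches, niching_matrix = niche_finding(c, y, mode='mi', threshold=0.5)
    # niches[i, j] is True when concept i is strongly associated with task j
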
/concepts_xai/methods/CBM/CBModel.py:
--------------------------------------------------------------------------------
1 | """
2 | Module containing implementations for Concept Bottleneck Models (CBMs) as
3 | described by Koh et al. in https://arxiv.org/abs/2007.04612
4 | """
5 |
6 | import numpy as np
7 | import tensorflow as tf
8 |
9 | from collections import defaultdict
10 | from tensorflow.python.keras.engine import data_adapter
11 |
12 |
13 | ################################################################################
14 | ## Exposed Functions
15 | ################################################################################
16 |
17 | def produce_bottleneck(model, layer_idx):
18 | """
19 | Partitions a Keras model into two disjoint computational graphs: (1) an
20 | encoder that maps from inputs to activations from layer with index
21 | `layer_idx` and (2) a decoder model that maps activations from layer
22 | with index `layer_idx` to the output of the given model.
23 |
24 | For this operation to be successful, the layer at index `layer_idx`
25 | must be a bottleneck of the input model (i.e., there may not be any
26 |     connections from the layers preceding the bottleneck layer to those
27 | topologically after the bottleneck layer).
28 |
29 | :param tf.keras.Model model: A model which we will split into two disjoint
30 | submodels.
31 | :param int layer_idx: A valid layer index in the given model. It must
32 | be a valid index in the array of layers represented by model.layers.
33 |
34 | :return Tuple[tf.keras.Model, tf.keras.Model]: a tuple of models
35 | (encoder, decoder) representing the input to bottleneck model and the
36 | bottleneck to output model, respectively.
37 | """
38 |
39 | # Let's start by making sure we get a full understanding of the input
40 | # model's topology
41 | in_edges = defaultdict(set)
42 | out_edges = defaultdict(set)
43 | name_to_layer = {}
44 | for src in model.layers:
45 | name_to_layer[src.name] = src
46 | for dst_node in src._outbound_nodes:
47 | in_edges[dst_node.layer.name].add(src.name)
48 | out_edges[src.name].add(dst_node.layer.name)
49 |
50 | # Now let's find the layer we will use as our bottleneck layer
51 | if len(model.layers) <= layer_idx:
52 | raise ValueError(
53 | f"Requested to use layer with index {layer_idx} as the bottleneck "
54 | f"layer however given model '{model.name}' has only "
55 | f"{len(model.layers)} indexable layers."
56 | )
57 | bottleneck_layer = model.layers[layer_idx]
58 |
59 | # Once we have the bottleneck, let's look at all the nodes that precede it
60 | # and follow it in the computational graph defined by `model`. For this to
61 | # be considered a valid bottleneck, the set of nodes after the bottleneck
62 | # layer must be disjoint from the set of nodes preceding the bottleneck
63 | # layer (i.e., there must be no edges from the subgraph preceding the
64 | # bottleneck layer into the subgraph defined by nodes that are topologically
65 | # after the bottleneck layer).
66 | preceding_nodes = set()
67 | frontier = [bottleneck_layer.name]
68 | while frontier:
69 | next_node = frontier.pop()
70 | for src_name in in_edges[next_node]:
71 | if src_name in preceding_nodes:
72 | # Then we have already dealt with it
73 | continue
74 | preceding_nodes.add(src_name)
75 | frontier.append(src_name)
76 |
77 | # And walk the graph to compute the nodes after the bottleneck layer
78 | posterior_nodes = set()
79 | frontier = [bottleneck_layer.name]
80 | while frontier:
81 | next_node = frontier.pop()
82 | for dst_name in out_edges[next_node]:
83 | if dst_name in posterior_nodes:
84 | # Then we have already dealt with it
85 | continue
86 | posterior_nodes.add(dst_name)
87 | frontier.append(dst_name)
88 |
89 | if (posterior_nodes & preceding_nodes):
90 | raise ValueError(
91 | f"Requested bottleneck layer {bottleneck_layer.name} (index "
92 | f"{layer_idx}) does not partition the computational graph of "
93 |             f"the provided model '{model.name}' into two disjoint subgraphs ("
94 |             f"i.e., there is a connection between layers preceding the "
95 | f"bottleneck and layers after the bottleneck layer)."
96 | )
97 |
98 | # We can now compute the size of the actual bottleneck
99 | if isinstance(bottleneck_layer.output, list):
100 | raise ValueError(
101 | f"Currently we do not support as a bottleneck layer a layer that "
102 | f"has more than one output. Requested bottleneck layer "
103 | f"{bottleneck_layer.name} (at index {layer_idx}) has "
104 | f"{len(bottleneck_layer.output)} outputs."
105 | )
106 | # Else let's check the number of concepts we are expecting vs the number
107 | # of entries
108 | num_concepts = bottleneck_layer.output.shape[-1]
109 |
110 | # With this, building the encoder is now trivial
111 | encoder = tf.keras.Model(
112 | inputs=model.inputs,
113 | outputs=bottleneck_layer.output,
114 | )
115 |
116 |     decoder_input = tf.keras.layers.Input(shape=(num_concepts,))
117 | decoder_outputs = []
118 | name_to_layer[bottleneck_layer.name] = decoder_input
119 | for layer in model.layers:
120 | if (layer.name == bottleneck_layer.name) or (
121 | layer.name not in posterior_nodes
122 | ):
123 | continue
124 | # Otherwise let's make sure we feed it with the input corresponding
125 | # to the new computational graph we are constructing from the bottleneck
126 | # layer (NOTE: this works as we are iterating over layers in topological
127 | # order)
128 | input_layers = []
129 | for input_name in in_edges[layer.name]:
130 | input_layers.append(name_to_layer[input_name])
131 | if len(input_layers) == 1:
132 | input_layers = input_layers[0]
133 | # Generate the new node
134 | new_node = layer(input_layers)
135 | name_to_layer[layer.name] = new_node
136 |
137 | # And add it to the output if this was an original output
138 | if layer.name in model.output_names:
139 | decoder_outputs.append(new_node)
140 | decoder = tf.keras.Model(
141 | inputs=decoder_input,
142 | outputs=decoder_outputs
143 | )
144 |
145 | return encoder, decoder
146 |
147 |
148 | ################################################################################
149 | ## Exposed Classes
150 | ################################################################################
151 |
152 | class JointConceptBottleneckModel(tf.keras.Model):
153 | """
154 | Main class for implementing a Joint Concept Bottleneck Model with the
155 | given encoder mapping input features to concepts and the given
156 | decoder which maps concept encodings to labels.
157 | This class encapsulates the joint training process of a CBM while allowing
158 | an arbitrary encoder/decoder model to be used in its construction.
159 |
160 | Note that it generalizes the original CBM by Koh et al. by allowing the
161 | encoder to produce non-binary concepts rather than assuming all concepts
162 | are binary in nature.
163 | """
164 |
165 | def __init__(
166 | self,
167 | encoder,
168 | decoder,
169 | task_loss,
170 | alpha=0.01,
171 | metrics=None,
172 | pass_concept_logits=False,
173 | concept_sample_weights=None,
174 | single_multiclass_concept=False,
175 | **kwargs
176 | ):
177 | """
178 |         Constructs a new trainable joint CBM which can then be trained,
179 | optimized and/or used for prediction as any other Keras model can.
180 |
181 | When using this model for prediction, it will return a tuple
182 | (labels, concepts) indicating the predicted label probabilities as
183 | well as the predicted concepts probabilities.
184 |
185 | :param tf.keras.Model encoder: A valid keras model that maps input
186 | features to a set of concepts. If the output of this model is
187 | a single vector, then every entry of this vector is assumed to be
188 | one binary concept. Otherwise, if the output of the encoder is a
189 | list of vectors, then we assume that each vector represents a
190 | probability distribution over different classes for each concept (
191 | i.e., we assume one concept per vector).
192 | :param tf.keras.Model decoder: A valid keras model mapping a concept
193 | vector to a set of task-specific labels. We assume that if the
194 | encoder outputs a list of concepts then the input to this model
195 | is the concatenation of all the output vectors of the encoder in
196 | the same order as produced by the encoder.
197 | :param tf.keras.losses.Loss task_loss: The loss to be used for the
198 | specific task labels.
199 |         :param float alpha: A parameter indicating how much weight one should
200 |             assign to the loss coming from training the bottleneck. If 0, then
201 | there is no learning enforced in the bottleneck.
202 | :param List[tf.keras.metrics.Metric] metrics: A list of possible
203 | metrics of interest which one may want to monitor during training.
204 | :param Bool pass_concept_logits: Whether the concept bottleneck will
205 | be passed to the concept-to-task model as logits (i.e., without
206 | a softmax or sigmoid operation applied to it) or not. If this is
207 |             set to false, then it is the responsibility of the input encoder
208 | model to output a valid probability distribution.
209 | :param Dict[Any, Any] kwargs: Keras Layer specific kwargs to be passed
210 | to the parent constructor.
211 | """
212 | super(JointConceptBottleneckModel, self).__init__(**kwargs)
213 | self.encoder = encoder
214 | self.decoder = decoder
215 | self.total_loss_tracker = tf.keras.metrics.Mean(
216 | name="total_loss"
217 | )
218 | self.concept_loss_tracker = tf.keras.metrics.Mean(
219 | name="concept_loss"
220 | )
221 | self.task_loss_tracker = tf.keras.metrics.Mean(
222 | name="task_loss"
223 | )
224 | self.concept_accuracy_tracker = tf.keras.metrics.Mean(
225 | name="concept_accuracy"
226 | )
227 | self._acc_metric = \
228 | lambda y_true, y_pred: tf.keras.metrics.sparse_top_k_categorical_accuracy(
229 | y_true,
230 | y_pred,
231 | k=1,
232 | )
233 | self._bin_acc_metric = \
234 | lambda y_true, y_pred: tf.math.reduce_mean(
235 | tf.keras.metrics.binary_accuracy(y_true, y_pred),
236 | axis=-1,
237 | )
238 | self.alpha = alpha
239 | self.task_loss = task_loss
240 | self.extra_metrics = metrics or []
241 | self.pass_concept_logits = pass_concept_logits
242 | self.concept_sample_weights = concept_sample_weights
243 | self.single_multiclass_concept = single_multiclass_concept
244 |
245 | # dummy call to build the model
246 | self(tf.zeros(list(map(
247 | lambda x: 1 if x is None else x,
248 | self.encoder.input_shape
249 | ))))
250 |
251 | @property
252 | def metrics(self):
253 | return [
254 | self.total_loss_tracker,
255 | self.concept_loss_tracker,
256 | self.task_loss_tracker,
257 | self.concept_accuracy_tracker,
258 | ] + self.extra_metrics
259 |
260 | def predict_from_concepts(self, concepts):
261 | """
262 | Given a set of concepts (e.g., coming from an intervention), this
263 | function returns the predicted labels for those concepts.
264 |
265 | :param np.ndarray concepts: A matrix of concepts predictions from which
266 |             we wish to obtain classes for. Its shape must be
267 |             (n_samples, n_concepts) if concepts are binary. Otherwise, it should
268 |             be a list with one probability vector per concept.
269 |
270 | :returns np.ndarray: Label probability predictions for the given set
271 | of concepts.
272 | """
273 | if isinstance(concepts, list):
274 | if len(concepts) > 1:
275 | concepts = tf.keras.layers.Concatenate(axis=-1)(
276 | concepts
277 | )
278 | else:
279 | concepts = concepts[0]
280 | return self.decoder(concepts)
281 |
282 | def call(self, inputs):
283 |         # Delegate to _call_fn, discarding any extra losses it may produce
284 |         # (subclasses may add regularization terms there)
285 | outputs, concepts, _ = self._call_fn(inputs)
286 | return outputs, concepts
287 |
288 | def _call_fn(self, inputs, **kwargs):
289 | # This method is separate from the call method above as it allows one
290 | # to overwrite this class and include an extra set of losses (returned
291 | # as the third element in the tuple) which could, for example, include
292 | # some decorrelation regularization term between concept predictions.
293 | concepts = self.encoder(inputs, **kwargs)
294 | return self.predict_from_concepts(concepts), concepts, []
295 |
296 | def _compute_losses(
297 | self,
298 | predicted_labels,
299 | predicted_concepts,
300 | true_labels,
301 | true_concepts,
302 | ):
303 | """
304 | Helper function for computing all the losses we require for training
305 | our joint CBM.
306 | """
307 | # Updates stateful loss metrics.
308 | task_loss = self.task_loss(true_labels, predicted_labels)
309 | concept_loss = 0.0
310 | concept_accuracy = 0.0
311 |         # If the encoder does not produce a list of outputs, then we will
312 |         # assume all concepts are binary
313 | if isinstance(predicted_concepts, list):
314 | for i, predicted_vec in enumerate(predicted_concepts):
315 | true_vec = true_concepts[:, i]
316 | sample_weight = None
317 | if self.concept_sample_weights is not None:
318 | sample_weight = self.concept_sample_weights[:, i:i+1]
319 | if (len(predicted_vec.shape) == 1) or (
320 | predicted_vec.shape[-1] == 1
321 | ):
322 | # Then use binary loss here
323 | concept_loss += tf.keras.losses.BinaryCrossentropy(
324 | from_logits=self.pass_concept_logits,
325 | )(
326 | true_vec,
327 | predicted_vec,
328 | sample_weight=sample_weight,
329 | )
330 | if len(predicted_vec.shape) == 2:
331 | # Then, let's remove the degenerate dimension
332 | predicted_vec = tf.squeeze(predicted_vec, axis=-1)
333 | concept_accuracy += self._bin_acc_metric(
334 | true_vec,
335 | predicted_vec,
336 | )
337 | else:
338 | # Otherwise use normal cross entropy
339 | concept_loss += \
340 | tf.keras.losses.SparseCategoricalCrossentropy(
341 | from_logits=self.pass_concept_logits,
342 | )(
343 | true_vec,
344 | predicted_vec,
345 | sample_weight=sample_weight,
346 | )
347 | concept_accuracy += self._acc_metric(
348 | true_vec,
349 | predicted_vec,
350 | )
351 |
352 | # And time to normalize over all the different heads
353 | concept_loss = concept_loss / len(predicted_concepts)
354 | concept_accuracy = concept_accuracy / len(predicted_concepts)
355 | elif self.single_multiclass_concept:
356 | # Then all elements in the bottleneck correspond to a single
357 | # concept that is a multi-class concept
358 | concept_loss += \
359 | tf.keras.losses.SparseCategoricalCrossentropy(
360 | from_logits=self.pass_concept_logits,
361 | )(
362 | true_concepts,
363 | predicted_concepts,
364 | sample_weight=self.concept_sample_weights,
365 | )
366 | concept_accuracy += self._acc_metric(
367 | true_concepts,
368 | predicted_concepts,
369 | )
370 | else:
371 | # Then use binary loss here as we are given a single vector and we
372 | # will assume in that instance they all represent independent
373 | # binary concepts
374 | concept_loss += tf.keras.losses.BinaryCrossentropy(
375 | from_logits=self.pass_concept_logits,
376 | )(
377 | true_concepts,
378 | predicted_concepts,
379 | sample_weight=self.concept_sample_weights,
380 | )
381 | concept_accuracy = self._bin_acc_metric(
382 | true_concepts,
383 | predicted_concepts,
384 | )
385 | return task_loss, concept_loss, concept_accuracy
386 |
387 | def test_step(self, data):
388 | """
389 |         Override of the Keras test step, defining how a single evaluation
390 |         step operates.
391 |
392 | :param Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]] data: The input
393 | training data is expected to be provided in the form
394 | (input_features, (true_labels, true_concepts)).
395 | """
396 | # Massage the data
397 | data = data_adapter.expand_1d(data)
398 | input_features, (true_labels, true_concepts), sample_weight = \
399 | data_adapter.unpack_x_y_sample_weight(data)
400 |
401 | # Obtain a prediction of labels and concepts
402 | predicted_labels, predicted_concepts, extra_losses = self._call_fn(
403 | input_features,
404 | training=False,
405 | )
406 | # Compute the actual losses
407 | task_loss, concept_loss, concept_accuracy = self._compute_losses(
408 | predicted_labels=predicted_labels,
409 | predicted_concepts=predicted_concepts,
410 | true_labels=true_labels,
411 | true_concepts=true_concepts,
412 | )
413 |
414 | # Accumulate both the concept and task-specific loss into a single value
415 | total_loss = (
416 | task_loss +
417 | self.alpha * concept_loss
418 | )
419 | for extra_loss in extra_losses:
420 | total_loss += extra_loss
421 | result = {
422 | self.concept_accuracy_tracker.name: concept_accuracy,
423 | self.concept_loss_tracker.name: concept_loss,
424 | self.task_loss_tracker.name: task_loss,
425 | self.total_loss_tracker.name: total_loss,
426 | }
427 | for metric in self.extra_metrics:
428 | result[metric.name] = metric(
429 | true_labels,
430 | predicted_labels,
431 | sample_weight,
432 | )
433 | return result
434 |
435 | @tf.function
436 | def train_step(self, data):
437 | """
438 |         Override of the Keras train step, defining how a single training
439 |         step operates.
440 |
441 | :param Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]] data: The input
442 | training data is expected to be provided in the form
443 | (input_features, (true_labels, true_concepts)).
444 | """
445 | # Massage the data
446 | data = data_adapter.expand_1d(data)
447 | input_features, (true_labels, true_concepts), sample_weight = \
448 | data_adapter.unpack_x_y_sample_weight(data)
449 | with tf.GradientTape() as tape:
450 | # Obtain a prediction of labels and concepts
451 | predicted_labels, predicted_concepts, extra_losses = self._call_fn(
452 | input_features
453 | )
454 | # Compute the actual losses
455 | task_loss, concept_loss, concept_accuracy = self._compute_losses(
456 | predicted_labels=predicted_labels,
457 | predicted_concepts=predicted_concepts,
458 | true_labels=true_labels,
459 | true_concepts=true_concepts,
460 | )
461 | # Accumulate both the concept and task-specific loss into a single
462 | # value
463 | total_loss = (
464 | task_loss +
465 | self.alpha * concept_loss
466 | )
467 | # And include any extra losses coming from this process
468 | for extra_loss in extra_losses:
469 | total_loss += extra_loss
470 |
471 | num_concepts = (
472 | len(predicted_concepts) if isinstance(predicted_concepts, list) else
473 | predicted_concepts.shape[-1]
474 | )
475 | grads = tape.gradient(total_loss, self.trainable_weights)
476 | self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
477 | self.total_loss_tracker.update_state(total_loss, sample_weight)
478 | self.task_loss_tracker.update_state(task_loss, sample_weight)
479 | self.concept_loss_tracker.update_state(concept_loss, sample_weight)
480 | self.concept_accuracy_tracker.update_state(
481 | concept_accuracy,
482 | sample_weight,
483 | )
484 | for metric in self.extra_metrics:
485 | metric.update_state(true_labels, predicted_labels, sample_weight)
486 | return {
487 | metric.name: metric.result()
488 | for metric in self.metrics
489 | }
490 |
491 |
492 | class BypassJointCBM(JointConceptBottleneckModel):
493 | def __init__(
494 | self,
495 | encoder,
496 | decoder,
497 | task_loss,
498 | alpha=0.01,
499 | metrics=None,
500 | pass_concept_logits=False,
501 | concept_sample_weights=None,
502 | single_multiclass_concept=False,
503 | **kwargs,
504 | ):
505 | """
506 | Extension of CBM model above that allows extra capacity in the
507 | bottleneck for activations that have no concept supervision.
508 | Expects the encoder to output a tuple (concepts, latent_code) where
509 | concepts has the same required properties of the output of the encoder
510 | in JointConceptBottleneckModel and latent_code is a np.ndarray vector
511 | representing activations in the bottleneck that have no supervision.
512 |
513 |         The concatenation of the elements in concepts and latent_code will be
514 | fed into the provided decoder model.
515 |
516 | When using this model for prediction, it will return a tuple
517 | (labels, concepts) indicating the predicted label probabilities as
518 | well as the predicted concepts probabilities.
519 |
520 | :param tf.keras.Model encoder: A valid keras model that maps input
521 | features to a set of concepts and a vector of unsupervised latent
522 | activations. If the concept output of this model is a single vector,
523 | then every entry of that vector is assumed to be one binary concept.
524 | Otherwise, if the output of the encoder's concepts is a list of
525 | vectors, then we assume that each vector represents a probability
526 | distribution over different classes for each concept (i.e., we
527 | assume one concept per vector).
528 | :param tf.keras.Model decoder: A valid keras model mapping a concept
529 | vector concatenated to the unsupervised latent dimensions to a set
530 | of task-specific labels. We assume that if the encoder outputs a
531 | list of concepts, then the input to this model is the concatenation
532 | of all the output vectors of the encoder (including the unsupervised
533 | latent dimensions) in the same order as produced by the encoder.
534 | :param tf.keras.losses.Loss task_loss: The loss to be used for the
535 | specific task labels.
536 |         :param float alpha: A parameter indicating how much weight one should
537 |             assign to the loss coming from training the bottleneck. If 0, then
538 | there is no learning enforced in the bottleneck.
539 | :param List[tf.keras.metrics.Metric] metrics: A list of possible
540 | metrics of interest which one may want to monitor during training.
541 | :param Bool pass_concept_logits: Whether the concept bottleneck will
542 | be passed to the concept-to-task model as logits (i.e., without
543 | a softmax or sigmoid operation applied to it) or not. If this is
544 |             set to false, then it is the responsibility of the input encoder
545 | model to output a valid probability distribution.
546 | :param Dict[Any, Any] kwargs: Keras Layer specific kwargs to be passed
547 | to the parent constructor.
548 |
549 | """
550 | super(BypassJointCBM, self).__init__(
551 | encoder=encoder,
552 | decoder=decoder,
553 | task_loss=task_loss,
554 | alpha=alpha,
555 | metrics=metrics,
556 | pass_concept_logits=pass_concept_logits,
557 | concept_sample_weights=concept_sample_weights,
558 | single_multiclass_concept=single_multiclass_concept,
559 | **kwargs
560 | )
561 |
562 | def call(self, inputs):
563 |         # Compute the concepts and unsupervised latent code, then decode
564 |         # their concatenation into task predictions
565 | concepts, latent_code = self.encoder(inputs)
566 | if not isinstance(concepts, list):
567 | decode_inputs = [concepts] + [latent_code]
568 | else:
569 | decode_inputs = concepts + [latent_code]
570 | return (
571 | self.predict_from_concepts(decode_inputs),
572 | concepts,
573 | latent_code,
574 | )
575 |
576 | def _call_fn(self, inputs, **kwargs):
577 | # Compute our concepts and latent code
578 | concepts, latent_code = self.encoder(inputs, **kwargs)
579 | if not isinstance(concepts, list):
580 | decode_inputs = [concepts] + [latent_code]
581 | else:
582 | decode_inputs = concepts + [latent_code]
583 | return self.predict_from_concepts(decode_inputs), concepts, []
584 |
--------------------------------------------------------------------------------
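
A minimal sketch of assembling and training a JointConceptBottleneckModel from
two small Keras models, assuming a TensorFlow version compatible with this
module's tensorflow.python.keras imports. Shapes, hyperparameters, and the
synthetic data are illustrative; the targets follow the (x, (y, c)) convention
the train/test steps above expect. For an existing end-to-end model, the
produce_bottleneck helper can instead split it into such an (encoder, decoder)
pair at a chosen layer.

    import numpy as np
    import tensorflow as tf
    from concepts_xai.methods.CBM.CBModel import JointConceptBottleneckModel

    n_features, n_concepts, n_classes = 10, 4, 3
    encoder = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(n_features,)),
        tf.keras.layers.Dense(16, activation='relu'),
        # Sigmoid output => one binary concept per entry
        tf.keras.layers.Dense(n_concepts, activation='sigmoid'),
    ])
    decoder = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(n_concepts,)),
        tf.keras.layers.Dense(n_classes, activation='softmax'),
    ])

    cbm = JointConceptBottleneckModel(
        encoder=encoder,
        decoder=decoder,
        task_loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        alpha=0.01,
    )
    cbm.compile(optimizer='adam')

    x = np.random.rand(256, n_features).astype(np.float32)
    c = np.random.randint(0, 2, size=(256, n_concepts)).astype(np.float32)
    y = np.random.randint(0, n_classes, size=(256,))
    cbm.fit(x, (y, c), epochs=2, batch_size=32, verbose=0)

    label_probs, concept_probs = cbm.predict(x)
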
/concepts_xai/methods/CME/CtlModel.py:
--------------------------------------------------------------------------------
1 | from sklearn.linear_model import LogisticRegression, LinearRegression
2 | from sklearn.tree import DecisionTreeClassifier
3 | from sklearn.ensemble import GradientBoostingClassifier
4 |
5 | '''
6 | CtL : Concept-to-Label
7 |
8 | Class for the transparent model representing a function from concepts to task
9 | labels. Represents the decision-making process of a given black-box model in
10 | the concept representation.
11 | '''
12 |
13 |
14 | class CtLModel:
15 |
16 | def __init__(self, **params):
17 |
18 | # Create copy of passed-in parameters
19 | self.params = params
20 |
21 | # Assign parameter values
22 | self.clf_type = self.params.get("method", "DT")
23 | self.n_concepts = self.params.get("n_concepts")
24 | self.n_classes = self.params.get("n_classes")
25 | self.c_names = self.params.get(
26 | "concept_names",
27 | ["Concept_" + str(i) for i in range(self.n_concepts)],
28 | )
29 | self.cls_names = self.params.get(
30 | "class_names",
31 |             ["Class_" + str(i) for i in range(self.n_classes)],
32 | )
33 |
34 | def train(self, c_data, y_data):
35 | if self.clf_type == 'DT':
36 | clf = DecisionTreeClassifier(class_weight='balanced')
37 | elif self.clf_type == 'LR':
38 | clf = LogisticRegression(
39 | max_iter=200,
40 | multi_class='auto',
41 | solver='lbfgs',
42 | )
43 | elif self.clf_type == 'LinearRegression':
44 | clf = LinearRegression()
45 | elif self.clf_type == 'GBT':
46 | clf = GradientBoostingClassifier()
47 | else:
48 | raise ValueError("Unrecognised model type...")
49 |
50 | clf.fit(c_data, y_data)
51 | self.clf = clf
52 |
53 | def predict(self, c_data):
54 | return self.clf.predict(c_data)
55 |
56 |
57 |
--------------------------------------------------------------------------------
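
A short usage sketch for CtLModel; the parameter values and random data are
illustrative.

    import numpy as np
    from concepts_xai.methods.CME.CtlModel import CtLModel

    c_data = np.random.randint(0, 2, size=(200, 5))  # concept labels
    y_data = np.random.randint(0, 3, size=(200,))    # task labels

    ctl = CtLModel(method='DT', n_concepts=5, n_classes=3)
    ctl.train(c_data, y_data)
    y_pred = ctl.predict(c_data)
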
/concepts_xai/methods/CME/ItCModel.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from sklearn.dummy import DummyClassifier
4 | from sklearn.ensemble import GradientBoostingClassifier
5 | from sklearn.linear_model import LogisticRegression
6 | from sklearn.metrics import accuracy_score
7 | from sklearn.model_selection import train_test_split
8 | from sklearn.semi_supervised import LabelSpreading
9 |
10 |
11 | class ItCModel(object):
12 |
13 | def __init__(self, model, **params):
14 |
15 | # Create copy of passed-in parameters
16 | self.params = params
17 |
18 | # Retrieve the layers to use for hidden representations
19 | self.layer_ids = params.get(
20 | "layer_ids",
21 | [i for i in range(len(model.layers))],
22 | )
23 | if self.layer_ids is None:
24 | self.layer_ids = [i for i in range(len(model.layers))]
25 | # Retrieve the corresponding layer names
26 | self.layer_names = params.get(
27 | "layer_names",
28 | ["Layer_" + str(i) for i in range(len(model.layers))],
29 | )
30 |
31 | # Set number of parameters
32 | self.n_concepts = params['n_concepts']
33 |
34 | # Retrieve the concept names
35 | self.concept_names = params.get(
36 | "concept_names",
37 | ["Concept_" + str(i) for i in range(self.n_concepts)],
38 | )
39 |
40 | # Batch size to use during computation
41 | self.batch_size = params.get("batch_size", 128)
42 |
43 | # Set classifier to use for concept value prediction
44 | self.clf_method = params.get("method", "LR")
45 |
46 | # Model for extracting activations from
47 | self.model = model
48 |
49 | def get_clf(self):
50 | if self.clf_method == 'LR':
51 | semi_supervised = False
52 | clf = LogisticRegression(max_iter=200, class_weight='balanced')
53 | elif self.clf_method == 'LP':
54 | semi_supervised = True
55 | clf = LabelSpreading()
56 | elif self.clf_method == 'GBT':
57 | clf = GradientBoostingClassifier()
58 | semi_supervised = False
59 | else:
60 | raise ValueError("Non-implemented method")
61 |
62 | return clf, semi_supervised
63 |
64 | def _get_layer_concept_predictor_model(self, h_data_l, c_data_l, h_data_u):
65 | '''
66 | Train cme concept label predictor for a particular layer and concept
67 | :param h_data_l: Activation data for a given layer
68 | :param c_data_l: Corresponding concept labels for that data
69 | :param h_data_u: Activation data without corresponding labels
70 | :return: Classifier predicting concept values from the activations
71 | '''
72 |
73 | # Create safe copies of data
74 | x_data_l = np.copy(h_data_l)
75 | y_data_l = np.copy(c_data_l)
76 | x_data_u = np.copy(h_data_u)
77 | y_data_u = np.ones((x_data_u.shape[0])) * -1 # Here, label is 'undefined'
78 |
79 | # Specify whether to use the unlabelled data at all
80 | semi_supervised = False
81 |
82 | # If there is only 1 value, then return simple classifier
83 | unique_vals = np.unique(y_data_l)
84 | if len(unique_vals) == 1:
85 | clf = DummyClassifier(strategy="constant", constant=c_data_l[0])
86 | else:
87 | # Otherwise, train classifier model
88 | clf, semi_supervised = self.get_clf()
89 |
90 | # Split the labelled data for train/test
91 | x_train, x_test, y_train, y_test = train_test_split(
92 | x_data_l,
93 | y_data_l,
94 | test_size=0.15,
95 | )
96 |
97 | # Combine with unlabelled data, if using a SSCC method
98 | if semi_supervised:
99 | x_train = np.concatenate([x_train, x_data_u])
100 | y_train = np.concatenate([y_train, y_data_u])
101 |
102 | # Train classifier
103 | clf.fit(x_train, y_train)
104 |
105 | # Retrieve predictive accuracy of classifier
106 | pred_acc = accuracy_score(y_test, clf.predict(x_test))
107 |
108 | return clf, pred_acc
109 |
110 | def train(self, ds_l, ds_u):
111 | '''
112 | Compute a dictionary with structure: layer --> concept --> classifier
113 | i.e., the dictionary returns which clf to use to predict a given
114 | concept, from a given layer
115 | :param ds_l: labelled concept dataset, consisting of tuples
116 | (input data, concept data)
117 | :param ds_u: unlabelled dataset, consisting of (input data) only,
118 | without concept labels
119 | '''
120 |
121 | # Load all concept data into a numpy array
122 | c_data_l = np.array([elem[1].numpy() for elem in ds_l])
123 |
124 | self.model_ensemble = {}
125 | self.model_accuracies = {}
126 |
127 | # Compute activations for specified layers
128 | h_data_ls = compute_activation_per_layer(
129 | ds_l,
130 | self.layer_ids,
131 | self.model,
132 | aggregation_function=flatten_activations,
133 | )
134 | h_data_us = compute_activation_per_layer(
135 | ds_u,
136 | self.layer_ids,
137 | self.model,
138 | aggregation_function=flatten_activations,
139 | )
140 |
141 | for i, layer_id in enumerate(self.layer_ids):
142 | print("-" * 20)
143 | print(
144 | "Processing layer ",
145 | str(i + 1),
146 | " of ", str(len(self.layer_ids)),
147 | ":",
148 | self.model.layers[layer_id].name,
149 | )
150 | # Retrieve activations for next layer
151 | activations_l = h_data_ls[i]
152 | activations_u = h_data_us[i]
153 |
154 | output_layer = self.model.layers[layer_id]
155 | self.model_ensemble[output_layer] = []
156 | self.model_accuracies[output_layer] = []
157 |
158 | # Generate predictive models for every concept
159 | for c in range(self.n_concepts):
160 | clf, pred_acc = self._get_layer_concept_predictor_model(
161 | activations_l,
162 | c_data_l[:, c],
163 | activations_u,
164 | )
165 | self.model_ensemble[output_layer].append(clf)
166 | self.model_accuracies[output_layer].append(pred_acc)
167 |
168 | print(
169 | "Processed layer ",
170 | str(i + 1),
171 | " of ",
172 | str(len(self.layer_ids)),
173 | )
174 |
175 | # Initialise concept predictors with models in the first layer
176 | # Consists of an array of size |concepts|, in which
177 | # The element arr[i] is the Layer Id of the layer to use when predicting
178 | # that concept
179 | self.concept_predictor_layer_ids = [
180 | self.layer_ids[0] for _ in range(self.n_concepts)
181 | ]
182 |
183 | # For every concept, identify the layer with the best clf predictive
184 | # accuracy
185 | for c in range(self.n_concepts):
186 | max_acc = -1
187 | for i, layer_id in enumerate(self.layer_ids):
188 | layer = self.model.layers[layer_id]
189 | acc = self.model_accuracies[layer][c]
190 | if acc > max_acc:
191 | max_acc = acc
192 | self.concept_predictor_layer_ids[c] = layer_id
193 |
194 | def predict_concepts(self, ds):
195 | '''
196 | Predict concept values from given data
197 | :param ds: tensorflow dataset of the data (has to have a known
198 | cardinality)
199 | '''
200 |
201 | n_samples = int(ds.cardinality().numpy())
202 | concept_vals = np.zeros((n_samples, self.n_concepts), dtype=float)
203 |
204 | for c in range(self.n_concepts):
205 | # Retrieve clf corresponding to concept c
206 | layer_id = self.concept_predictor_layer_ids[c]
207 | output_layer = self.model.layers[layer_id]
208 | clf = self.model_ensemble[output_layer][c]
209 |
210 | # Compute activations for that layer
211 | output_layer = self.model.layers[layer_id]
212 | reduced_model = tf.keras.Model(
213 | inputs=self.model.inputs,
214 | outputs=[output_layer.output],
215 | )
216 | hidden_features = reduced_model.predict(ds.batch(self.batch_size))
217 | clf_data = flatten_activations(hidden_features)
218 |
219 | # Predict concept values from the activations
220 | concept_vals[:, c] = clf.predict(clf_data)
221 |
222 | return concept_vals
223 |
224 |
225 | def flatten_activations(x_data):
226 | '''
227 | Flatten all axes except the first one
228 | '''
229 |
230 | if len(x_data.shape) > 2:
231 | n_samples = x_data.shape[0]
232 | shape = x_data.shape[1:]
233 | flattened = np.reshape(x_data, (n_samples, np.prod(shape)))
234 | else:
235 | flattened = x_data
236 |
237 | return flattened
238 |
239 |
240 | def compute_activation_per_layer(
241 | ds,
242 | layer_ids,
243 | model,
244 | batch_size=128,
245 | aggregation_function=flatten_activations,
246 | ):
247 | '''
248 | Compute activations of input data for 'layer_ids' layers
249 | For every layer, aggregate values using 'aggregation_function'
250 |
251 | Returns a list L of size |layer_ids|, in which element L[i] is the
252 | activations computed from the model layer model.layers[layer_ids[i]],
253 | processed by the aggregation function
254 |
255 | :param ds: tf.dataset returning the input data
256 | :param layer_ids: list of indices, indicating model layers to use
257 | :param model: tf.Keras model to compute the activations from
258 | :param batch_size: batch size to use during processing
259 | :param aggregation_function: aggregation function for aggregating the
260 | activation values
261 | '''
262 |
263 | hidden_features_list = []
264 |
265 | for layer_id in layer_ids:
266 | # Compute and aggregate hidden activations
267 | output_layer = model.layers[layer_id]
268 | reduced_model = tf.keras.Model(
269 | inputs=model.inputs,
270 | outputs=output_layer.output,
271 | )
272 | hidden_features = reduced_model.predict(ds.batch(batch_size))
273 | flattened = aggregation_function(hidden_features)
274 |
275 | hidden_features_list.append(flattened)
276 |
277 | return hidden_features_list
278 |
--------------------------------------------------------------------------------
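
A minimal sketch of mining concepts from a trained model with ItCModel (the
input-to-concept counterpart of CtLModel above). The tiny network, dataset
sizes, and the 'LR' concept classifier are illustrative assumptions.

    import numpy as np
    import tensorflow as tf
    from concepts_xai.methods.CME.ItCModel import ItCModel

    # A small end-to-end model whose hidden layers we mine for concepts
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(8,)),
        tf.keras.layers.Dense(16, activation='relu'),
        tf.keras.layers.Dense(3, activation='softmax'),
    ])

    x_l = np.random.rand(120, 8).astype(np.float32)   # labelled inputs
    c_l = np.random.randint(0, 2, size=(120, 2))      # concept labels
    x_u = np.random.rand(300, 8).astype(np.float32)   # unlabelled inputs

    ds_l = tf.data.Dataset.from_tensor_slices((x_l, c_l))
    ds_u = tf.data.Dataset.from_tensor_slices(x_u)

    itc = ItCModel(model, n_concepts=2, method='LR')
    itc.train(ds_l, ds_u)
    c_pred = itc.predict_concepts(ds_u)  # shape: (300, 2)
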
/concepts_xai/methods/CW/CWLayer.py:
--------------------------------------------------------------------------------
1 | '''
2 | Re-implementation of the Concept Whitening module proposed by Chen et al.
3 |
4 | 1) See the IterNormRotation class for the original paper's implementation:
5 | https://github.com/zhiCHEN96/ConceptWhitening/blob/final_version/MODELS/iterative_normalization.py
6 |
7 | 2) See Concept Whitening for Interpretable Image Recognition
8 | (https://arxiv.org/abs/2002.01650) for the original paper
9 | '''
10 |
11 | import tensorflow as tf
12 |
13 | ################################################################################
14 | ## Helper classes taken from tensorflow-addons to avoid import and extend
15 | ## functions.
16 | ## Full code can be found in
17 | ## https://github.com/tensorflow/addons/blob/v0.13.0/tensorflow_addons/layers/max_unpooling_2d.py#L88-L147
18 | ################################################################################
19 |
20 |
21 | def normalize_tuple(value, n, name):
22 | """Transforms an integer or iterable of integers into an integer tuple.
23 | A copy of tensorflow.python.keras.util.
24 | Args:
25 |     value: The value to validate and convert. Could be an int, or any iterable
26 | of ints.
27 | n: The size of the tuple to be returned.
28 | name: The name of the argument being validated, e.g. "strides" or
29 | "kernel_size". This is only used to format error messages.
30 | Returns:
31 | A tuple of n integers.
32 | Raises:
33 | ValueError: If something else than an int/long or iterable thereof was
34 | passed.
35 | """
36 | if isinstance(value, int):
37 | return (value,) * n
38 | else:
39 | try:
40 | value_tuple = tuple(value)
41 | except TypeError:
42 | raise TypeError(
43 | "The `"
44 | + name
45 | + "` argument must be a tuple of "
46 | + str(n)
47 | + " integers. Received: "
48 | + str(value)
49 | )
50 | if len(value_tuple) != n:
51 | raise ValueError(
52 | "The `"
53 | + name
54 | + "` argument must be a tuple of "
55 | + str(n)
56 | + " integers. Received: "
57 | + str(value)
58 | )
59 | for single_value in value_tuple:
60 | try:
61 | int(single_value)
62 | except (ValueError, TypeError):
63 | raise ValueError(
64 | "The `"
65 | + name
66 | + "` argument must be a tuple of "
67 | + str(n)
68 | + " integers. Received: "
69 | + str(value)
70 | + " "
71 | "including element "
72 | + str(single_value)
73 | + " of type"
74 | + " "
75 | + str(type(single_value))
76 | )
77 | return value_tuple
78 |
79 |
80 | def _calculate_output_shape(input_shape, pool_size, strides, padding):
81 | """Calculates the shape of the unpooled output."""
82 | if padding == "VALID":
83 | output_shape = (
84 | input_shape[0],
85 | (input_shape[1] - 1) * strides[0] + pool_size[0],
86 | (input_shape[2] - 1) * strides[1] + pool_size[1],
87 | input_shape[3],
88 | )
89 | elif padding == "SAME":
90 | output_shape = (
91 | input_shape[0],
92 | input_shape[1] * strides[0],
93 | input_shape[2] * strides[1],
94 | input_shape[3],
95 | )
96 | else:
97 | raise ValueError('Padding must be a string from: "SAME", "VALID"')
98 | return output_shape
99 |
100 |
101 | def _max_unpooling_2d(
102 | updates,
103 | mask,
104 | pool_size=(2, 2),
105 | strides=(2, 2),
106 | padding="SAME",
107 | ):
108 | """Unpool the outputs of a maximum pooling operation."""
109 | mask = tf.cast(mask, "int32")
110 | pool_size = normalize_tuple(pool_size, 2, "pool_size")
111 | strides = normalize_tuple(strides, 2, "strides")
112 | input_shape = tf.shape(updates, out_type="int32")
113 | input_shape = [updates.shape[i] or input_shape[i] for i in range(4)]
114 | output_shape = _calculate_output_shape(
115 | input_shape,
116 | pool_size,
117 | strides,
118 | padding,
119 | )
120 |
121 | # Calculates indices for batch, height, width and feature maps.
122 | one_like_mask = tf.ones_like(mask, dtype="int32")
123 | batch_shape = tf.concat([[input_shape[0]], [1], [1], [1]], axis=0)
124 | batch_range = tf.reshape(
125 | tf.range(output_shape[0], dtype="int32"), shape=batch_shape
126 | )
127 | b = one_like_mask * batch_range
128 | y = mask // (output_shape[2] * output_shape[3])
129 | x = (mask // output_shape[3]) % output_shape[2]
130 | feature_range = tf.range(output_shape[3], dtype="int32")
131 | f = one_like_mask * feature_range
132 |
133 | # Transposes indices & reshape update values to one dimension.
134 | updates_size = tf.size(updates)
135 | indices = tf.transpose(
136 | tf.reshape(tf.stack([b, y, x, f]), [4, updates_size])
137 | )
138 | values = tf.reshape(updates, [updates_size])
139 | return tf.scatter_nd(indices, values, output_shape)
140 |
141 |
142 | ################################################################################
143 | ## Concept Whitening Layer Class
144 | ################################################################################
145 |
146 |
147 | class ConceptWhiteningLayer(tf.keras.layers.Layer):
148 | def __init__(
149 | self,
150 | T=5,
151 | eps=1e-5,
152 | momentum=0.9,
153 | activation_mode='max_pool_mean',
154 | c1=1e-4,
155 | c2=0.9,
156 | max_tau_iterations=500,
157 | initial_tau=1000.0, # Original CW code: 1000
158 | data_format="channels_first",
159 | initial_beta=1e8, # Original CW code: 1e8
160 | initial_alpha=0, # Original CW code: 0
161 | **kwargs
162 | ):
163 | super(ConceptWhiteningLayer, self).__init__(**kwargs)
164 | assert data_format in ["channels_first", "channels_last"], (
165 | f'Expected data format to be either "channels_first" or '
166 | f'"channels_last" but got {data_format} instead.'
167 | )
168 |
169 | self.T = T
170 | self.eps = eps
171 | self.momentum = momentum
172 | self.c1 = c1
173 | self.c2 = c2
174 | self.max_tau_iterations = max_tau_iterations
175 | self.data_format = data_format
176 | self.initial_tau = initial_tau
177 | self.initial_alpha = initial_alpha
178 | self.initial_beta = initial_beta
179 |
180 |         # Method for aggregating a feature map into a score
181 | self.activation_mode = activation_mode
182 |
183 | def compute_output_shape(self, input_shape):
184 | return input_shape
185 |
186 | def concept_scores(
187 | self,
188 | inputs,
189 | aggregator='max_pool_mean',
190 | concept_indices=None,
191 | ):
192 | outputs = self(inputs, training=False)
193 | if len(tf.shape(outputs)) == 2:
194 | # Then the scores are already computed by our forward pass
195 | scores = outputs
196 | else:
197 | if self.data_format == "channels_last":
198 | # Then we will transpose to make things simpler so that
199 | # downstream we can always assume it is channels first
200 | # NHWC -> NCHW
201 | outputs = tf.transpose(
202 | outputs,
203 | perm=[0, 3, 1, 2],
204 | )
205 |
206 | # Else, we need to do some aggregation
207 | if aggregator == 'mean':
208 | # Compute the mean over all channels
209 | scores = tf.math.reduce_mean(outputs, axis=[2, 3])
210 | elif aggregator == 'max_pool_mean':
211 | # First downsample using a max pool and then continue with
212 | # a mean
213 | window_size = min(
214 | 2,
215 | outputs.shape[-1],
216 | outputs.shape[-2],
217 | )
218 | scores = tf.nn.max_pool(
219 | outputs,
220 | ksize=window_size,
221 | strides=window_size,
222 | padding="SAME",
223 | data_format="NCHW",
224 | )
225 | scores = tf.math.reduce_mean(scores, axis=[2, 3])
226 | elif aggregator == 'max':
227 | # Simply select the maximum value across a given channel
228 | scores = tf.math.reduce_max(outputs, axis=[2, 3])
229 | else:
230 | raise ValueError(f'Unsupported aggregator {aggregator}.')
231 |
232 | if concept_indices is not None:
233 | return scores[:, concept_indices]
234 | return scores
235 |
236 | def build(self, input_shape):
237 | super(ConceptWhiteningLayer, self).build(input_shape)
238 |
239 | # And use the shape to construct all our running variables
240 | assert len(input_shape) in [2, 4], (
241 | f'Expected input to CW layer to be a rank-2 or rank-4 matrix but '
242 | f'instead got a tensor with shape {input_shape}.'
243 | )
244 |
245 | # We assume channels-first data format
246 | if self.data_format == "channels_first":
247 | self.num_features = input_shape[1]
248 | else:
249 | self.num_features = input_shape[-1]
250 |
251 | # Running mean
252 | self.running_mean = self.add_weight(
253 | name="running_mean",
254 | shape=(self.num_features,),
255 | dtype=tf.float32,
256 | initializer=tf.constant_initializer(0),
257 | # No gradient flow is expected to come into this variable
258 | trainable=False,
259 | )
260 |
261 | # Running whitened matrix
262 | self.running_wm = self.add_weight(
263 | name="running_wm",
264 | shape=(self.num_features, self.num_features),
265 | dtype=tf.float32,
266 | initializer=tf.keras.initializers.Identity(),
267 | # No gradient flow is expected to come into this variable
268 | trainable=False,
269 | )
270 |
271 | # Running rotation matrix
272 | self.running_rot = self.add_weight(
273 | name="running_rot",
274 | shape=(self.num_features, self.num_features),
275 | dtype=tf.float32,
276 | initializer=tf.keras.initializers.Identity(),
277 | # No gradient flow is expected to come into this variable
278 | trainable=False,
279 | )
280 |
281 | # Sum of Gradient matrix
282 | self.sum_G = self.add_weight(
283 | name="sum_G",
284 | shape=(self.num_features, self.num_features),
285 | dtype=tf.float32,
286 | initializer=tf.keras.initializers.Zeros(),
287 | # No gradient flow is expected to come into this variable
288 | trainable=False,
289 | )
290 |
291 | # Counter of gradients for each feature
292 | self.counter = self.add_weight(
293 | name="counter",
294 | shape=(self.num_features,),
295 | dtype=tf.float32,
296 | initializer=tf.constant_initializer(0.001),
297 | # No gradient flow is expected to come into this variable
298 | trainable=False,
299 | )
300 |
301 | def update_rotation_matrix(self, concept_groups, index_map=lambda x: x):
302 | """
303 | Update the rotation matrix R using the accumulated gradient G.
304 | """
305 |
306 | # Updating the gradient matrix, using the concept datasets and their
307 | # aligned maps
308 | for i, concept_samples in enumerate(concept_groups):
309 |
310 | samples_shape = tf.shape(concept_samples)
311 | if len(samples_shape) == 2:
312 | # Add trivial dimensions to make it 4D so that it works as
313 | # it does with image data. We will undo this at the end
314 | concept_samples = tf.reshape(
315 | concept_samples,
316 | tf.concat(
317 | [samples_shape[0:1], samples_shape[1:], [1, 1]],
318 | axis=0,
319 | ),
320 | )
321 | if (
322 | (self.data_format == "channels_last") and
323 | (len(samples_shape) != 2)
324 | ):
325 | # Then we will transpose to make things simpler so that
326 | # downstream we can always assume it is channels first
327 | # NHWC -> NCHW
328 | concept_samples = tf.transpose(
329 | concept_samples,
330 | perm=[0, 3, 1, 2],
331 | )
332 |
333 | # Produce the whitened only activations of this concept group
334 | X_hat = self._compute_whitened_activations(
335 | concept_samples,
336 | rotate=False,
337 | training=False,
338 | )
339 |
340 | # Determine which feature map to use for this concept group
341 | feature_idx = index_map(i)
342 |
343 | # And update the gradient by performing an accumulation using the
344 | # requested activation mode
345 | if self.activation_mode == 'mean':
346 | grad_col = -tf.math.reduce_mean(
347 | tf.math.reduce_mean(X_hat, axis=[2, 3]),
348 | axis=0,
349 | )
350 |
351 | self.sum_G[feature_idx, :].assign(
352 | (
353 | (grad_col * self.momentum)
354 | ) + (1. - self.momentum) * self.sum_G[feature_idx, :]
355 | )
356 |
357 | elif self.activation_mode == 'max_pool_mean':
358 | X_test_nchw = tf.linalg.einsum(
359 | 'bchw,dc->bdhw',
360 | X_hat,
361 | self.running_rot,
362 | )
363 |
364 | # Move to NHWC as tf.nn.max_pool_with_argmax only supports
365 | # channels last format
366 | X_test_nhwc = tf.transpose(X_test_nchw, perm=[0, 2, 3, 1])
367 |
368 | window_size = min(
369 | 2,
370 | X_hat.shape[-1],
371 | X_hat.shape[-2],
372 | )
373 | maxpool_value, max_indices = tf.nn.max_pool_with_argmax(
374 | X_test_nhwc,
375 | ksize=window_size,
376 | strides=window_size,
377 | padding='SAME',
378 | data_format='NHWC',
379 | )
380 | X_test_unpool = _max_unpooling_2d(
381 | maxpool_value,
382 | max_indices,
383 | pool_size=window_size,
384 | strides=window_size,
385 | padding="SAME",
386 | )
387 |
388 | # And reshape to original NCHW format from NHWC format
389 | X_test_unpool = tf.transpose(X_test_unpool, perm=[0, 3, 1, 2])
390 |
391 | # Finally, compute the actual gradient and update or running
392 | # matrix
393 | maxpool_mask = tf.cast(
394 | tf.math.equal(X_test_nchw, X_test_unpool),
395 | tf.float32,
396 | )
397 | # Average only over those elements selected by the max pool
398 | # operator.
399 | grad = (
400 | tf.reduce_sum(X_hat * maxpool_mask, axis=(2, 3)) /
401 | tf.reduce_sum(maxpool_mask, axis=(2, 3))
402 | )
403 |
404 | # And average over all samples
405 | grad = -tf.reduce_mean(grad, axis=0)
406 | self.sum_G[feature_idx, :].assign(
407 | self.momentum * grad +
408 | (1. - self.momentum) * self.sum_G[feature_idx, :]
409 | )
410 |
411 | else:
412 | raise NotImplementedError(
413 | "Currently supporting only the max_pool_mean and mean "
414 | "options"
415 | )
416 |
417 | # And increase the counter
418 | self.counter[feature_idx].assign(self.counter[feature_idx] + 1)
419 |
420 | # Time to update our rotation matrix
421 | # Original CW paper does this counter division so keeping it for now
422 |         # for backwards compatibility
423 | G = self.sum_G / tf.expand_dims(self.counter, axis=-1)
424 |
425 |         # Original CW code uses range(2) for some bizarre reason
426 | for _ in range(2):
427 | tau = self.initial_tau # learning rate in Cayley transform
428 | alpha = self.initial_alpha
429 | beta = self.initial_beta
430 |
431 | # Compute: GR^T - RG^T
432 | A = tf.einsum('in,jn->ij', G, self.running_rot) - tf.einsum(
433 | 'in,jn->ij',
434 | self.running_rot,
435 | G,
436 | )
437 | I = tf.eye(self.num_features)
438 | dF_0 = -0.5 * tf.math.reduce_sum(A * A)
439 |
440 | # Computing tau using Algorithm 1 in
441 | # https://link.springer.com/article/10.1007/s10107-012-0584-1
442 | # Binary search for appropriate learning rate
443 | count = 0
444 | while count < self.max_tau_iterations:
445 | Q = tf.linalg.matmul(
446 | tf.linalg.inv(I + 0.5 * tau * A),
447 | I - 0.5 * tau * A,
448 | )
449 | Y_tau = tf.linalg.matmul(Q, self.running_rot)
450 | F_X = tf.math.reduce_sum(G * self.running_rot)
451 | F_Y_tau = tf.math.reduce_sum(G * Y_tau)
452 | dF_tau = tf.linalg.matmul(
453 | tf.einsum(
454 | 'ni,nj->ij',
455 | G,
456 | tf.linalg.inv(I + 0.5 * tau * A),
457 | ),
458 | tf.linalg.matmul(A, 0.5 * (self.running_rot + Y_tau)),
459 | )
460 | dF_tau = -tf.linalg.trace(dF_tau)
461 |
462 | if F_Y_tau > F_X + self.c1 * tau * dF_0 + 1e-18:
463 | beta = tau
464 | tau = (beta + alpha) / 2
465 | elif dF_tau + 1e-18 < self.c2 * dF_0:
466 | alpha = tau
467 | tau = (beta + alpha) / 2
468 | else:
469 | break
470 | count += 1
471 |
472 |             if count >= self.max_tau_iterations:
473 | print("------------------update fail----------------------")
474 | print(F_Y_tau, F_X + self.c1 * tau * dF_0)
475 | print(dF_tau, self.c2 * dF_0)
476 | print("---------------------------------------------------")
477 | break
478 |
479 | # Using the un-numbered equation in the Concept Whitening paper
480 | # Lines 12-13 of Algorithm 2 in CW paper
481 | Q = tf.linalg.matmul(
482 | tf.linalg.matmul(
483 | tf.linalg.inv(I + 0.5 * tau * A),
484 | I - 0.5 * tau * A,
485 | ),
486 | self.running_rot,
487 | )
488 | # And update the rotation matrix as well as reset the counters
489 | self.running_rot.assign(Q)
490 |
491 | self.counter.assign(tf.ones((self.num_features,)) * 0.001)
492 |
493 | def call(self, inputs, training=False):
494 | input_shape = tf.shape(inputs)
495 |         static_inputs_shape = inputs.shape
496 |         if len(static_inputs_shape) == 2:
497 | # Add trivial dimensions to make it 4D so that it works as
498 | # it does with image data. We will undo this at the end
499 | inputs = tf.reshape(
500 | inputs,
501 | tf.concat(
502 | [input_shape[0:1], input_shape[1:], [1, 1]],
503 | axis=0,
504 | ),
505 | )
506 | if (self.data_format == "channels_last") and (
507 |             len(static_inputs_shape) != 2
508 | ):
509 | # Then we will transpose to make things simpler so that downstream
510 | # we can always assume it is channels first
511 | # NHWC -> NCHW
512 | inputs = tf.transpose(
513 | inputs,
514 | perm=[0, 3, 1, 2],
515 | )
516 |
517 | result = tf.linalg.einsum(
518 | 'bchw,dc->bdhw',
519 | self._compute_whitened_activations(inputs, training),
520 | self.running_rot,
521 | )
522 |
523 |         if len(static_inputs_shape) == 2:
524 | # Then let's get it back to its original shape
525 | result = tf.reshape(result, input_shape)
526 | if (self.data_format == "channels_last") and (
527 |             len(static_inputs_shape) != 2
528 | ):
529 | # Then let's move it back to channels last
530 | # NCHW -> NHWC
531 | result = tf.transpose(
532 | result,
533 | perm=[0, 2, 3, 1],
534 | )
535 | return result
536 |
537 | def _compute_whitened_activations(self, X, training, rotate=False):
538 | '''
539 | Implements Algorithm 1 from https://arxiv.org/pdf/1904.03441.pdf
540 | Also, updates the running mean and whitening matrices using a moving
541 | average
542 |
543 | Assumes the input X is in the format of (N, C, H, W)
544 | '''
545 |
546 | input_shape = tf.shape(X)
547 |
548 | # Flip first two dimensions in order to obtain a (C, N, H, W) tensor
549 | x = tf.transpose(X, perm=[1, 0, 2, 3])
550 |
551 |         # Change (C, N, H, W) to (C, NxHxW)
552 | cnhw_shape = tf.shape(x)
553 | x = tf.reshape(x, [self.num_features, -1])
554 | m = tf.shape(x)[-1]
555 |
556 | if training:
557 | # Calculate mini-batch mean
558 | # Line 4 of Algorithm 1
559 | mean = tf.reduce_mean(x, axis=-1, keepdims=True)
560 |
561 | # Calculate centered activation
562 | # Line 5 of Algorithm 1
563 | xc = x - mean
564 |
565 | # Calculate covariance matrix
566 | # Line 6 of Algorithm 1
567 | I = tf.eye(self.num_features)
568 | sigma = self.eps * I
569 | sigma += 1./tf.cast(m, tf.float32) * tf.linalg.matmul(
570 | xc,
571 | tf.transpose(xc)
572 | )
573 | # Calculate trace-normalized covariance matrix using eqn. (4) in
574 | # the paper
575 | # Line 7 of Algorithm 1
576 | sigma_tr_rec = tf.expand_dims(
577 | tf.math.reciprocal(tf.linalg.trace(sigma)),
578 | axis=-1,
579 | )
580 | sigma_N = sigma * sigma_tr_rec
581 |
582 | # Original CW code: (they do not use the actual trace and this can
583 |         # lead to instability during training)
584 | # sigma_tr_rec = tf.reduce_sum(sigma, axis=(0, 1), keepdims=True)
585 | # sigma_N = sigma * sigma_tr_rec
586 |
587 | # Calculate whitening matrix
588 | # Lines 8-11 of Algorithm 1
589 | P = tf.eye(self.num_features)
590 | for _ in range(self.T):
591 | P_cubed = tf.linalg.matmul(
592 | tf.linalg.matmul(P, P),
593 | P,
594 | )
595 | P = 1.5 * P - (
596 | 0.5 * tf.linalg.matmul(P_cubed, sigma_N)
597 | )
598 |
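            # (After T Newton-Schulz iterations, P approximates
            # sigma_N^{-1/2}; multiplying by sqrt(1 / trace(sigma)) below
            # therefore yields wm ~= sigma^{-1/2}, i.e. a ZCA whitening
            # matrix satisfying wm @ sigma @ wm^T ~= I.)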
599 | # Line 12 of Algorithm 1
600 | wm = tf.math.multiply(P, tf.math.sqrt(sigma_tr_rec))
601 |
602 | # Update the running mean and whitening matrix
603 | self.running_mean.assign(
604 | self.momentum * tf.squeeze(mean, axis=-1) +
605 | (1. - self.momentum) * self.running_mean
606 | )
607 | self.running_wm.assign(
608 | self.momentum * wm +
609 | (1. - self.momentum) * self.running_wm
610 | )
611 |
612 | else:
613 | xc = x - tf.expand_dims(self.running_mean, axis=-1)
614 | wm = self.running_wm
615 |
616 | # Calculate whitening output
617 | xn = tf.linalg.matmul(wm, xc)
618 |
619 | # And, if requested, apply the rotation while it is in this format
620 | if rotate:
621 | xn = tf.linalg.einsum(
622 | 'bchw,dc->bdhw',
623 | xn,
624 | self.running_rot,
625 | )
626 |
627 | # Transform back to original shape of (N, C, H, W)
628 | return tf.transpose(tf.reshape(xn, cnhw_shape), perm=[1, 0, 2, 3])
629 |
630 | def get_config(self):
631 | """
632 | Serialization function.
633 | """
634 | result = super(ConceptWhiteningLayer, self).get_config()
635 | result.update(dict(
636 | T=self.T,
637 | eps=self.eps,
638 | momentum=self.momentum,
639 | activation_mode=self.activation_mode,
640 | c1=self.c1,
641 | c2=self.c2,
642 | max_tau_iterations=self.max_tau_iterations,
643 | ))
644 | return result
645 |
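# --- Usage sketch (illustrative; not part of the original source file) ---
# A minimal example of dropping the layer into a Keras model. The exact
# __init__ signature is an assumption here (inferred from the get_config()
# keys above and the num_features attribute used throughout); adjust it to
# match the real constructor.
#
#   import tensorflow as tf
#   from concepts_xai.methods.CW.CWLayer import ConceptWhiteningLayer
#
#   inputs = tf.keras.Input(shape=(32, 32, 3))
#   h = tf.keras.layers.Conv2D(16, 3, padding="same", activation="relu")(inputs)
#   h = ConceptWhiteningLayer(num_features=16)(h)  # hypothetical signature
#   pooled = tf.keras.layers.GlobalAveragePooling2D()(h)
#   outputs = tf.keras.layers.Dense(10, activation="softmax")(pooled)
#   model = tf.keras.Model(inputs, outputs)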
--------------------------------------------------------------------------------
/concepts_xai/methods/OCACE/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dmitrykazhdan/concept-based-xai/a86c6abd609ec2ef0cf56cb556e3f834a0962bdc/concepts_xai/methods/OCACE/__init__.py
--------------------------------------------------------------------------------
/concepts_xai/methods/OCACE/topicModel.py:
--------------------------------------------------------------------------------
1 | import concepts_xai.evaluation.metrics.completeness as completeness
2 | import numpy as np
3 | import tensorflow as tf
4 | import tensorflow.keras.backend as K
5 | import scipy
6 |
7 | '''
8 | Re-implementation of "On Completeness-aware Concept-Based Explanations in
9 | Deep Neural Networks":
10 |
11 | 1) See https://arxiv.org/abs/1910.07969 for the original paper
12 |
13 | 2) See https://github.com/chihkuanyeh/concept_exp for the original paper
14 | implementation
15 |
16 | '''
17 |
18 |
19 | class TopicModel(tf.keras.Model):
20 | """Base class of a topic model."""
21 |
22 | def __init__(
23 | self,
24 | concepts_to_labels_model,
25 | n_channels,
26 | n_concepts,
27 | g_model=None,
28 | threshold=0.5,
29 | loss_fn=tf.keras.losses.sparse_categorical_crossentropy,
30 | top_k=32,
31 | lambda1=0.1,
32 | lambda2=0.1,
33 | seed=None,
34 | eps=1e-5,
35 | data_format="channels_last",
36 | allow_gradient_flow_to_c2l=False,
37 | acc_metric=None,
38 | initial_topic_vector=None,
39 | **kwargs,
40 | ):
41 | super(TopicModel, self).__init__(**kwargs)
42 |
43 | initializer = tf.keras.initializers.RandomUniform(
44 | minval=-0.5,
45 | maxval=0.5,
46 | seed=seed,
47 | )
48 |
49 | # Initialize our topic vector tensor which we will learn
50 | # as part of our training
51 | if initial_topic_vector is not None:
52 | self.topic_vector = self.add_weight(
53 | name="topic_vector",
54 | shape=(n_channels, n_concepts),
55 | dtype=tf.float32,
56 | initializer=lambda *args, **kwargs: initial_topic_vector,
57 | trainable=True,
58 | )
59 |         else:
60 |             self.topic_vector = self.add_weight(
61 |                 name="topic_vector",
62 |                 shape=(n_channels, n_concepts),
63 |                 dtype=tf.float32,
64 |                 # Reuse the random-uniform initializer constructed at
65 |                 # the top of this constructor instead of re-creating
66 |                 # it inline
67 |                 initializer=initializer,
68 |                 trainable=True,
69 |             )
70 |
71 |
72 | # Initialize the g model which will be in charge of reconstructing
73 | # the model latent activations from the concept scores alone
74 | self.g_model = g_model
75 | if self.g_model is None:
76 | self.g_model = completeness._get_default_model(
77 | num_concepts=n_concepts,
78 | num_hidden_acts=n_channels,
79 | )
80 |
81 | # Set the concept-to-label predictor model
82 | self.concepts_to_labels_model = concepts_to_labels_model
83 |
84 | # Set remaining model hyperparams
85 | self.eps = eps
86 | self.threshold = threshold
87 | self.n_concepts = n_concepts
88 | self.loss_fn = loss_fn
89 | self.top_k = top_k
90 | self.lambda1 = lambda1
91 | self.lambda2 = lambda2
92 | self.n_channels = n_channels
93 | self.allow_gradient_flow_to_c2l = allow_gradient_flow_to_c2l
94 | assert data_format in ["channels_last", "channels_first"], (
95 | f'Expected data format to be either "channels_last" or '
96 | f'"channels_first" however we obtained "{data_format}".'
97 | )
98 | if data_format == "channels_last":
99 | self._channel_axis = -1
100 | else:
101 | raise ValueError(
102 | 'Currently we only support "channels_last" data_format'
103 | )
104 |
105 | self.metric_names = ["loss", "mean_sim", "accuracy"]
106 | self.metrics_dict = {
107 | name: tf.keras.metrics.Mean(name=name)
108 | for name in self.metric_names
109 | }
110 | self._acc_metric = (
111 | acc_metric or tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)
112 | )
113 |
114 | @property
115 | def metrics(self):
116 | return [self.metrics_dict[name] for name in self.metric_names]
117 |
118 | def update_metrics(self, losses):
119 | for (loss_name, loss) in losses:
120 | self.metrics_dict[loss_name].update_state(loss)
121 |
122 | def concept_scores(self, x, compute_reg_terms=False):
123 | # Compute the concept representation by first normalizing across the
124 | # channel dimension both the concept vectors and the input
125 | assert x.shape[self._channel_axis] == self.n_channels, (
126 | f'Expected input to have {self.n_channels} elements in its '
127 | f'channels axis (defined as axis {self._channel_axis}). '
128 | f'Instead, we found the input to have shape {x.shape}.'
129 | )
130 |
131 | x_norm = tf.math.l2_normalize(x, axis=self._channel_axis)
132 | topic_vector_norm = tf.math.l2_normalize(self.topic_vector, axis=0)
133 | # Compute the concept probability scores
134 | topic_prob = K.dot(x, topic_vector_norm)
135 | topic_prob_norm = K.dot(x_norm, topic_vector_norm)
136 |
137 | # Threshold them if they are below the given threshold value
138 | if self.threshold is not None:
139 | topic_prob = topic_prob * tf.cast(
140 | (topic_prob_norm > self.threshold),
141 | tf.float32,
142 | )
143 | topic_prob_sum = tf.reduce_sum(
144 | topic_prob,
145 | axis=self._channel_axis,
146 | keepdims=True,
147 | )
148 | # And normalize the actual scores
149 | topic_prob = topic_prob / (topic_prob_sum + self.eps)
150 | if not compute_reg_terms:
151 | return topic_prob
152 |
153 | # Compute the regularization loss terms
154 | reshaped_topic_probs = tf.transpose(
155 | tf.reshape(topic_prob_norm, (-1, self.n_concepts))
156 | )
157 | reg_loss_1 = tf.reduce_mean(
158 | tf.nn.top_k(
159 | reshaped_topic_probs,
160 | k=tf.math.minimum(
161 | self.top_k,
162 | tf.shape(reshaped_topic_probs)[-1]
163 | ),
164 | sorted=True,
165 | ).values
166 | )
167 |
168 | reg_loss_2 = tf.reduce_mean(
169 | K.dot(tf.transpose(topic_vector_norm), topic_vector_norm) -
170 | tf.eye(self.n_concepts)
171 | )
172 | return topic_prob, reg_loss_1, reg_loss_2
173 |
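    # Shape note (illustrative): for convolutional feature maps x of shape
    # (batch, H, W, n_channels) in "channels_last" format, concept_scores
    # replaces the channel axis with concepts and returns scores of shape
    # (batch, H, W, n_concepts); for flat activations of shape
    # (batch, n_channels) it returns (batch, n_concepts).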
174 | def _compute_loss(self, x, y_true, training):
175 | # First, compute the concept scores for the given samples
176 | scores, reg_loss_1, reg_loss_2 = self.concept_scores(
177 | x,
178 | compute_reg_terms=True
179 | )
180 |
181 | # Then predict the labels after reconstructing the activations
182 | # from the scores via the g model
183 | y_pred = self.concepts_to_labels_model(
184 | self.g_model(scores),
185 | training=training,
186 | )
187 |
188 | # Compute the task loss
189 | log_prob_loss = tf.reduce_mean(self.loss_fn(y_true, y_pred))
190 |
191 | # And include them into the total loss
192 | total_loss = (
193 | log_prob_loss -
194 | self.lambda1 * reg_loss_1 +
195 | self.lambda2 * reg_loss_2
196 | )
197 |
198 | # Compute the accuracy metric to track
199 | self._acc_metric.update_state(y_true, y_pred)
200 | return total_loss, reg_loss_1, self._acc_metric.result()
201 |
202 | def train_step(self, inputs):
203 | x, y = inputs
204 |
205 | # We first need to make sure that we set our concepts to labels model
206 | # so that we do not optimize over its parameters
207 | prev_trainable = self.concepts_to_labels_model.trainable
208 | if not self.allow_gradient_flow_to_c2l:
209 | self.concepts_to_labels_model.trainable = False
210 | with tf.GradientTape() as tape:
211 | loss, mean_sim, acc = self._compute_loss(
212 | x,
213 | y,
214 | # Only train the decoder if requested by the user
215 | training=self.allow_gradient_flow_to_c2l,
216 | )
217 |
218 | gradients = tape.gradient(loss, self.trainable_variables)
219 | self.optimizer.apply_gradients(
220 | zip(gradients, self.trainable_variables)
221 | )
222 | self.update_metrics([
223 | ("loss", loss),
224 | ("mean_sim", mean_sim),
225 | ("accuracy", acc),
226 | ])
227 |
228 | # And recover the previous step of the concept to labels model
229 | self.concepts_to_labels_model.trainable = prev_trainable
230 |
231 | return {
232 | name: self.metrics_dict[name].result()
233 | for name in self.metric_names
234 | }
235 |
236 | def test_step(self, inputs):
237 | x, y = inputs
238 | loss, mean_sim, acc = self._compute_loss(x, y, training=False)
239 | self.update_metrics([
240 | ("loss", loss),
241 | ("mean_sim", mean_sim),
242 | ("accuracy", acc)
243 | ])
244 |
245 | return {
246 | name: self.metrics_dict[name].result()
247 | for name in self.metric_names
248 | }
249 |
250 | def call(self, x, **kwargs):
251 | concept_scores = self.concept_scores(x)
252 | predicted_labels = self.concepts_to_labels_model(
253 | self.g_model(concept_scores),
254 | training=False,
255 | )
256 | return predicted_labels, concept_scores
257 |
258 |
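if __name__ == "__main__":
    # Illustrative wiring sketch (not part of the original source file). The
    # concepts-to-labels head below is a stand-in for the (typically
    # pretrained and frozen) top of the model under inspection.
    n_channels, n_concepts, n_classes = 64, 5, 10
    c2l_head = tf.keras.Sequential([
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(n_classes, activation="softmax"),
    ])
    topic_model = TopicModel(
        concepts_to_labels_model=c2l_head,
        n_channels=n_channels,
        n_concepts=n_concepts,
    )
    topic_model.compile(optimizer="adam")
    # topic_model.fit(ds) with ds yielding (feature_maps, labels) batches,
    # where feature_maps has shape (batch, H, W, n_channels).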
--------------------------------------------------------------------------------
/concepts_xai/methods/OCACE/visualisation.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 | import numpy as np
4 |
5 | # Update this to change font sizes
6 | plt.rcParams.update({'font.size': 18})
7 |
8 |
9 | def visualize_nearest_neighbours(
10 | f_data,
11 | x_data,
12 | topic_model,
13 | n_prototypes=5,
14 | channels_axis=-1,
15 | ):
16 |
17 | topic_vec = topic_model.topic_vector.numpy()
18 | fig, axes = plt.subplots(
19 | topic_vec.shape[1],
20 | n_prototypes,
21 | figsize=(20, 20),
22 | )
23 |     topic_prob = topic_model.concept_scores(f_data).numpy()
24 |     if len(f_data.shape) == 2:
25 |         # Then let's trivially expand its dimensions so that we always
26 |         # work with 4D (n_samples, H, W, channels) representations
27 | f_data = np.expand_dims(
28 | np.expand_dims(f_data, axis=1),
29 | axis=1,
30 | )
31 | topic_prob = np.expand_dims(
32 | np.expand_dims(topic_prob, axis=1),
33 | axis=1,
34 | )
35 | dim1_size, dim2_size = f_data.shape[1], f_data.shape[2]
36 | dim1_factor = int(x_data.shape[1] / dim1_size)
37 | dim2_factor = int(x_data.shape[2] / dim2_size)
38 |
39 | for i in range(topic_prob.shape[channels_axis]):
40 | # Then topic_prob is an array of shape
41 |         # (n_samples, kernel_w, kernel_h, n_concepts)
42 | # Here, we retrieve the indices of the n_imgs_per_c largest elements
43 | ind = np.argsort(-topic_prob[:, :, :, i].flatten())[:n_prototypes]
44 |
45 | for jc, j in enumerate(ind):
46 | # Retrieve the dim0 index
47 | dim0 = int(np.floor(j/(dim1_size*dim2_size)))
48 | # Retrieve the dim1 index
49 | dim1 = int((j-dim0 * (dim1_size*dim2_size))/dim2_size)
50 | # Retrieve the dim2 index
51 | dim2 = int((j-dim0 * (dim1_size*dim2_size)) % dim2_size)
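            # Worked example of this flat-index recovery: with dim1_size=2,
            # dim2_size=3 and j=7, dim0 = 7 // 6 = 1, dim1 = (7 - 6) // 3 = 0
            # and dim2 = (7 - 6) % 3 = 1, i.e. sample 1, patch row 0,
            # patch column 1.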
52 |
53 | # Retrieve the subpart of the image corresponding to the
54 | # activated patch
55 | dim1_start = dim1_factor * dim1
56 | dim2_start = dim2_factor * dim2
57 | img = x_data[dim0, :, :, :]
58 | img = img[
59 | dim1_start: dim1_start + dim1_factor,
60 | dim2_start: dim2_start + dim2_factor,
61 | :,
62 | ]
63 |
64 | # Plot the image
65 | ax = axes[i, jc]
66 |
67 | if img.shape[-1] == 1:
68 | ax.imshow(img, cmap='gray')
69 | else:
70 | ax.imshow(img)
71 | ax.yaxis.grid(False)
72 | ax.get_yaxis().set_visible(False)
73 | ax.get_xaxis().set_visible(False)
74 |
75 | fig.tight_layout()
76 | plt.show()
77 |
--------------------------------------------------------------------------------
/concepts_xai/methods/SENN/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dmitrykazhdan/concept-based-xai/a86c6abd609ec2ef0cf56cb556e3f834a0962bdc/concepts_xai/methods/SENN/__init__.py
--------------------------------------------------------------------------------
/concepts_xai/methods/SENN/aggregators.py:
--------------------------------------------------------------------------------
1 | """
2 | Library of some default aggregator functions that one can use in SENN.
3 | """
4 |
5 | import tensorflow as tf
6 |
7 |
8 | def multiclass_additive_aggregator(thetas, concepts):
9 | # Returns output shape (batch, n_outputs)
10 | return tf.squeeze(
11 | # (batch, n_outputs, 1)
12 | tf.linalg.matmul(
13 | # (batch, n_outputs, n_concepts)
14 | thetas,
15 | # (batch, n_concepts, 1)
16 | tf.expand_dims(concepts, axis=-1),
17 | ),
18 | axis=-1
19 | )
20 |
21 |
22 | def scalar_additive_aggregator(thetas, concepts):
23 | # Returns output shape (batch)
24 | return tf.squeeze(
25 | multiclass_additive_aggregator(thetas=thetas, concepts=concepts),
26 | axis=-1,
27 | )
28 |
29 |
30 | def softmax_additive_aggregator(thetas, concepts):
31 | # Returns output shape (batch, n_outputs)
32 | return tf.nn.softmax(
33 | multiclass_additive_aggregator(thetas=thetas, concepts=concepts),
34 | axis=-1,
35 | )
36 |
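if __name__ == "__main__":
    # Illustrative shape check (not part of the original source file):
    batch, n_outputs, n_concepts = 4, 3, 5
    thetas = tf.random.normal((batch, n_outputs, n_concepts))
    concepts = tf.random.normal((batch, n_concepts))
    print(multiclass_additive_aggregator(thetas, concepts).shape)  # (4, 3)
    print(softmax_additive_aggregator(thetas, concepts).shape)     # (4, 3)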
--------------------------------------------------------------------------------
/concepts_xai/methods/SENN/base_senn.py:
--------------------------------------------------------------------------------
1 | """
2 | Tensorflow implementation of Self-Explaining Neural Networks (SENN) by
3 | Alvarez-Melis and Jaakkola (NeurIPS 2018) [1].
4 |
5 | [1] https://papers.nips.cc/paper/2018/file/3e9f0fc9b2f89e043bc6233994dfcf76-Paper.pdf
6 |
7 | """
8 |
9 | import tensorflow as tf
10 | from tensorflow.python.keras.engine import data_adapter
11 |
12 |
13 | class SelfExplainingNN(tf.keras.Model):
14 | """
15 | Main class implementation for Self-Explaining Neural Networks (SENN)s as
16 | defined and described by Alvarez-Melis and Jaakkola (NeurIPS 2018).
17 | """
18 |
19 | def __init__(
20 | self,
21 | encoder_model,
22 | coefficient_model,
23 | aggregator_fn,
24 | task_loss_fn,
25 |         regularization_strength=1e-1,  # "lambda" regularizer parameter
26 | sparsity_strength=2e-5, # "zeta" autoencoder parameter
27 | reconstruction_loss_fn=None,
28 | task_loss_weight=1,
29 | robustness_norm_fn=lambda x: tf.norm(x, ord='fro', axis=[-2, -1]),
30 | metrics=None,
31 | **kwargs,
32 | ):
33 | """
34 | Constructs a SENN Keras Model ready for training.
35 |
36 | When using this model for prediction, it will return a tuple
37 | (labels, (concept_vectors, thetas)) indicating the predicted label
38 | probabilities as well as the concepts vectors, together with their
39 | linear importance weights defined by thetas.
40 | """
41 | super(SelfExplainingNN, self).__init__(**kwargs)
42 | self.task_loss_weight = task_loss_weight
43 | self.encoder_model = encoder_model
44 | self.coefficient_model = coefficient_model
45 | self.aggregator_fn = aggregator_fn
46 | self.task_loss_fn = task_loss_fn
47 | self.regularization_strength = regularization_strength
48 | self.sparsity_strength = sparsity_strength
49 | self.reconstruction_loss_fn = reconstruction_loss_fn
50 | self.robustness_norm_fn = robustness_norm_fn
51 | if self.reconstruction_loss_fn is not None:
52 | self.metrics_dict = {
53 | "loss": tf.keras.metrics.Mean(
54 | name="loss"
55 | ),
56 | "task_loss": tf.keras.metrics.Mean(
57 | name="task_loss"
58 | ),
59 | "robustness_loss": tf.keras.metrics.Mean(
60 | name="robustness_loss"
61 | ),
62 | "reconstruction_loss": tf.keras.metrics.Mean(
63 | name="reconstruction_loss"
64 | ),
65 | }
66 | else:
67 | self.metrics_dict = {
68 | "loss": tf.keras.metrics.Mean(
69 | name="loss"
70 | ),
71 | "task_loss": tf.keras.metrics.Mean(
72 | name="task_loss"
73 | ),
74 | "robustness_loss": tf.keras.metrics.Mean(
75 | name="robustness_loss"
76 | ),
77 | }
78 | self._extra_metrics = []
79 | for metric in (metrics or []):
80 | if isinstance(metric, (tuple, list)):
81 | if len(metric) != 2:
82 | raise ValueError(
83 | f"Expected metrics to be a list of tuples "
84 | f"(name, metric) or TF metric objects. Instead we "
85 | f"received {metric}."
86 | )
87 | name, metric = metric
88 | else:
89 | name = metric.name
90 | self.metrics_dict[name] = metric
91 | self._extra_metrics.append((name, metric))
92 |
93 | @property
94 | def metrics(self):
95 | return [
96 | self.metrics_dict[name] for name in self.metrics_dict.keys()
97 | ]
98 |
99 | def update_metrics(self, losses):
100 | for (name, loss) in losses:
101 | self.metrics_dict[name].update_state(loss)
102 |
103 | def call(self, inputs):
104 | # First compute our concepts
105 | concepts = self.encoder_model(inputs) # (batch, n_concepts)
106 | # Then all of the theta weights which we will use for our concepts
107 | thetas = self.coefficient_model(inputs) # (batch, n_outputs, n_concepts)
108 | if len(thetas.shape) < 3:
109 | # Then the number of classes/outputs is 1 so let's make it explicit
110 | thetas = tf.expand_dims(thetas, axis=1)
111 | # Make sure the dimensions match
112 | tf.debugging.assert_equal(
113 | tf.shape(concepts)[-1],
114 | tf.shape(thetas)[-1],
115 | message=(
116 | "The last dimension of the returned concept tensor and the "
117 | "theta tensor must be the same (they both should correspond to "
118 | "the number of concepts used in the SENN). We found "
119 | f"{tf.shape(concepts)[-1]} entries in concepts.shape[-1] while "
120 | f"{tf.shape(thetas)[-1]} entries in thetas.shape[-1]."
121 | )
122 | )
123 | predictions = self.aggregator_fn(
124 | thetas=thetas,
125 | concepts=concepts,
126 | )
127 | return predictions, (concepts, thetas)
128 |
129 | def train_step(self, inputs):
130 | # This will allow us to compute the task specific loss
131 | inputs = data_adapter.expand_1d(inputs)
132 | x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(inputs)
133 |         with tf.GradientTape() as outer_tape:
134 | total_loss = 0
135 | # Do a nesting of tapes as we will need to compute gradients and
136 | # jacobian in order for one
137 | with tf.GradientTape(
138 | persistent=True,
139 | watch_accessed_variables=False,
140 | ) as inner_tape:
141 | # First compute predictions and their corresponding explanation
142 | inner_tape.watch(x)
143 | preds, (concepts, thetas) = self(x)
144 | # This gives us the task specific loss
145 | task_loss = self.task_loss_fn(y, preds)
146 | total_loss += task_loss * self.task_loss_weight
147 |
148 | if self.reconstruction_loss_fn is not None:
149 | # Now compute the encoder reconstruction loss similarly
150 | reconstruction_loss = self.reconstruction_loss_fn(
151 | x,
152 | concepts,
153 | )
154 | total_loss += self.sparsity_strength * reconstruction_loss
155 |
156 | # Finally, time to add the robustness loss. This is the trickiest
157 | # and most expensive one as it requires the computation of a
158 | # jacobian
159 |             # The Jacobian of f(x) with respect to x has shape
160 |             # (batch, n_outputs, *x.shape[1:]).
161 | # Note that we use a jacobian computation rather than a gradient
162 | # computation (as done in the paper) as we support general
163 | # multi-dimensional outputs
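            # Concretely, this is the SENN stability regularizer
            #   || grad_x f(x) - thetas(x) @ J_h(x) ||,
            # which vanishes exactly when f acts locally like the linear
            # concept model f(x) ~= thetas(x) @ h(x).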
164 | if len(preds.shape) < 3:
165 | # Jacobian requires a 2D input
166 | preds = tf.expand_dims(preds, axis=1)
167 | f_grad_x = inner_tape.batch_jacobian(preds, x)
168 |
169 | # Reshape to (batch, n_outputs, flatten_x_shape)
170 | f_grad_x = tf.reshape(
171 | f_grad_x,
172 | [tf.shape(x)[0], tf.shape(preds)[-1], -1]
173 | )
174 |
175 |             # The Jacobian of h(x) with respect to x has shape
176 |             # (batch, n_concepts, *x.shape[1:])
177 | h_jacobian_x = inner_tape.batch_jacobian(
178 | concepts,
179 | x,
180 | )
181 | # Reshape to (batch, n_concepts, flatten_x_shape)
182 | h_jacobian_x = tf.reshape(
183 | h_jacobian_x,
184 | [tf.shape(x)[0], tf.shape(concepts)[-1], -1]
185 | )
186 | robustness_loss = self.robustness_norm_fn(
187 | f_grad_x - tf.matmul(
188 | # No need to transpose thetas as they already have the
189 | # number of outputs as its first non-batch dimension (i.e.,
190 | # its shape is (batch, n_outputs, n_concepts))
191 | thetas,
192 | # h_jacobian_x shape: (batch, n_concepts, flatten_x_shape)
193 | h_jacobian_x,
194 | ) # matmul shape: (batch, n_outputs, flatten_x_shape)
195 | )
196 | robustness_loss = tf.math.reduce_mean(robustness_loss)
197 | total_loss += self.regularization_strength * robustness_loss
198 |
199 | # Compute gradients and proceed with SGD
200 |         gradients = outer_tape.gradient(total_loss, self.trainable_variables)
201 | self.optimizer.apply_gradients(
202 | zip(gradients, self.trainable_variables)
203 | )
204 |
205 | # And update all of our metrics
206 | self.update_metrics([
207 | ("loss", total_loss),
208 | ("task_loss", task_loss),
209 | ("robustness_loss", robustness_loss),
210 | ])
211 | if self.reconstruction_loss_fn is not None:
212 | self.update_metrics([
213 | ("reconstruction_loss", reconstruction_loss),
214 | ])
215 | for (name, metric) in self._extra_metrics:
216 | self.metrics_dict[name].update_state(y, preds, sample_weight)
217 | return {
218 | name: val.result()
219 | for name, val in self.metrics_dict.items()
220 | }
221 |
222 | def test_step(self, inputs):
223 | # This will allow us to compute the task specific loss
224 | inputs = data_adapter.expand_1d(inputs)
225 | x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(inputs)
226 |         with tf.GradientTape() as outer_tape:  # no gradient step is taken here
227 | total_loss = 0
228 | # Do a nesting of tapes as we will need to compute gradients and
229 | # jacobian in order for one
230 | with tf.GradientTape(
231 | persistent=True,
232 | watch_accessed_variables=False,
233 | ) as inner_tape:
234 | # First compute predictions and their corresponding explanation
235 | inner_tape.watch(x)
236 | preds, (concepts, thetas) = self(x)
237 | # This gives us the task specific loss
238 | task_loss = self.task_loss_fn(y, preds)
239 | total_loss += task_loss * self.task_loss_weight
240 |
241 | if self.reconstruction_loss_fn is not None:
242 | # Now compute the encoder reconstruction loss similarly
243 | reconstruction_loss = self.reconstruction_loss_fn(
244 | x,
245 | concepts,
246 | )
247 | total_loss += self.sparsity_strength * reconstruction_loss
248 |
249 | # Finally, time to add the robustness loss. This is the trickiest
250 | # and most expensive one as it requires the computation of a
251 | # jacobian
252 |             # The Jacobian of f(x) with respect to x has shape
253 |             # (batch, n_outputs, *x.shape[1:]).
254 | # Note that we use a jacobian computation rather than a gradient
255 | # computation (as done in the paper) as we support general
256 | # multi-dimensional outputs
257 | if len(preds.shape) < 3:
258 | # Jacobian requires a 2D input
259 | preds = tf.expand_dims(preds, axis=1)
260 | f_grad_x = inner_tape.batch_jacobian(preds, x)
261 | # Reshape to (batch, n_outputs, flatten_x_shape)
262 | f_grad_x = tf.reshape(
263 | f_grad_x,
264 | [tf.shape(x)[0], tf.shape(preds)[-1], -1]
265 | )
266 |
267 |             # The Jacobian of h(x) with respect to x has shape
268 |             # (batch, n_concepts, *x.shape[1:])
269 | h_jacobian_x = inner_tape.batch_jacobian(
270 | concepts,
271 | x,
272 | )
273 | # Reshape to (batch, n_concepts, flatten_x_shape)
274 | h_jacobian_x = tf.reshape(
275 | h_jacobian_x,
276 | [tf.shape(x)[0], tf.shape(concepts)[-1], -1]
277 | )
278 | robustness_loss = self.robustness_norm_fn(
279 | f_grad_x - tf.matmul(
280 | # No need to transpose thetas as they already have the
281 | # number of outputs as its first non-batch dimension (i.e.,
282 | # its shape is (batch, n_outputs, n_concepts))
283 | thetas,
284 | # h_jacobian_x shape: (batch, n_concepts, flatten_x_shape)
285 | h_jacobian_x,
286 | ) # matmul shape: (batch, n_outputs, flatten_x_shape)
287 | )
288 | robustness_loss = tf.math.reduce_mean(robustness_loss)
289 | total_loss += self.regularization_strength * robustness_loss
290 |
291 | # And update all of our metrics
292 | self.update_metrics([
293 | ("loss", total_loss),
294 | ("task_loss", task_loss),
295 | ("robustness_loss", robustness_loss),
296 | ])
297 | if self.reconstruction_loss_fn is not None:
298 | self.update_metrics([
299 | ("reconstruction_loss", reconstruction_loss),
300 | ])
301 | for (name, metric) in self._extra_metrics:
302 | self.metrics_dict[name].update_state(y, preds, sample_weight)
303 | return {
304 | name: val.result()
305 | for name, val in self.metrics_dict.items()
306 | }
307 |
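if __name__ == "__main__":
    # Illustrative wiring sketch (not part of the original source file). The
    # encoder/coefficient networks are stand-ins; any pair producing
    # (batch, n_concepts) and (batch, n_outputs, n_concepts) tensors works.
    from concepts_xai.methods.SENN.aggregators import (
        softmax_additive_aggregator,
    )

    n_concepts, n_outputs = 5, 10
    encoder = tf.keras.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(n_concepts, activation="sigmoid"),
    ])
    coefficients = tf.keras.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(n_outputs * n_concepts),
        tf.keras.layers.Reshape((n_outputs, n_concepts)),
    ])
    senn = SelfExplainingNN(
        encoder_model=encoder,
        coefficient_model=coefficients,
        aggregator_fn=softmax_additive_aggregator,
        task_loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(),
    )
    senn.compile(optimizer="adam")
    preds, (concepts, thetas) = senn(tf.random.normal((2, 8, 8, 1)))
    print(preds.shape, concepts.shape, thetas.shape)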
--------------------------------------------------------------------------------
/concepts_xai/methods/SSCC/SSCClassifier.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import numpy as np
3 | import os
4 | import tensorflow as tf
5 |
6 | from sklearn.linear_model import LogisticRegression
7 | from sklearn.ensemble import GradientBoostingClassifier
8 | from sklearn.model_selection import train_test_split
9 | from sklearn.metrics import accuracy_score
10 |
11 | from concepts_xai.methods.CME.ItCModel import ItCModel
12 | from concepts_xai.methods.CBM.CBModel import ConceptBottleneckModel
13 | from concepts_xai.evaluation.metrics.accuracy import compute_accuracies
14 | from concepts_xai.methods.VAE.weak_vae import GroupVAEArgmax
15 | from concepts_xai.methods.VAE.betaVAE import BetaVAE
16 |
17 |
18 | '''
19 | Implementation of the Semi-Supervised Concept Classifier (SSCClassifier)
20 |
21 | Contains implementations of the following:
22 | 1) An abstract SSCClassifier class, specifying the requirements of a
23 | generic SSCC concept classifier
24 | 2) Implementations of the SSCClassifier, serving as wrappers of various
25 | concept-based methods defined in the .methods folder
26 | '''
27 |
28 |
29 | class SSCClassifier(ABC):
30 |
31 | @abstractmethod
32 | def __init__(self, **kwargs):
33 | '''
34 | Initialise the SSCClassifier wrapper
35 | :param kwargs: any extra arguments needed to initialise the concept
36 | predictor model
37 | '''
38 | pass
39 |
40 | @abstractmethod
41 | def fit(self, data_gen_l, data_gen_u):
42 | '''
43 | Method for training the underlying SSCC model
44 | :param data_gen_l: Tensorflow Dataset, returning triples of points
45 | (x_data, c_data, y_data) for input, concept, and label data,
46 | respectively
47 | :param data_gen_u: Tensorflow Dataset, returning tuples of points
48 |             (x_data, y_data). In the fully-unsupervised case (where
49 |             y-labels are also unavailable), y_data should be set to -1
50 | :return: Trains the underlying SSCC concept labeler
51 | '''
52 | pass
53 |
54 | @abstractmethod
55 | def predict(self, data_gen_u):
56 | '''
57 | Method for predicting the concept labels from the input data
58 | :param data_gen_u: Tensorflow Dataset, returning tuples of points
59 |             (x_data, y_data). In the fully-unsupervised case (where
60 |             y-labels are also unavailable), y_data should be set to -1
61 | :return: numpy array of predicted samples, corresponding to the input
62 | points
63 | '''
64 | pass
65 |
66 |
67 | class SSCC_CME(SSCClassifier):
68 | '''
69 | CME Concept Predictor Wrapper
70 | '''
71 |
72 | def __init__(self, **kwargs):
73 | '''
74 | :param kwargs: Required arguments:
75 |             - "base_model" : the underlying trained model to extract concepts from
76 | - "n_concepts" : the number of concepts
77 | '''
78 | super().__init__(**kwargs)
79 | # Extract model parameter
80 | model = kwargs["base_model"]
81 | # Copy all the other parameters into the params dict.
82 | params = {
83 | key: val for (key, val) in kwargs.items() if key != "base_model"
84 | }
85 | # Create the underlying ItCModel concept labeler
86 | self.concept_predictor = ItCModel(model, **params)
87 |
88 | def fit(self, data_gen_l, data_gen_u):
89 | # Train underlying CME concept predictor
90 | # The default ItCModel accepts tf_datasets without the y-labels during
91 | # training. Thus, we filter out the y-labels from data_gen_l and
92 | # data_gen_u via a mapping
93 | ds_l, ds_u = remove_ds_el(data_gen_l), remove_ds_el(data_gen_u)
94 | self.concept_predictor.train(ds_l, ds_u)
95 |
96 | def predict(self, data_gen_u):
97 | # Run the underlying CME predictor
98 | # The default ItCModel accepts tf_datasets without the y-labels during
99 | # prediction. Thus, we filter out the y-labels from data_gen_u via a
100 | # mapping
101 | ds_u = remove_ds_el(data_gen_u)
102 | predicted = self.concept_predictor.predict_concepts(ds_u)
103 | return predicted
104 |
105 |
106 | class SSCC_CBM(SSCClassifier):
107 | '''
108 | CBM Model wrapper
109 | '''
110 | def __init__(self, **kwargs):
111 | '''
112 | :param kwargs: Required arguments:
113 |             - "base_model" : the underlying trained model to extract concepts from
114 | - "layer_id" : layer of the model to use as bottleneck
115 | - "n_classes" : number of classes
116 | - "n_concept_vals" : number of concept values
117 | - "multi_task" : whether to use the multi-task setup, or sigmoid
118 | setup (Optional)
119 | - "end_to_end" : whether to create an end-to-end keras model
120 | '''
121 | super().__init__(**kwargs)
122 | # Extract required parameters
123 | model = kwargs["base_model"]
124 | layer_id = kwargs["layer_id"]
125 | n_classes = kwargs["n_classes"]
126 | n_concept_vals = kwargs["n_concept_vals"]
127 | multi_task = kwargs.get("multi_task", False)
128 | end_to_end = kwargs.get("end_to_end", False)
129 | c_epochs = kwargs.get("epochs", 10)
130 | c_extr_path = kwargs["save_path"]
131 | c_optimizer = kwargs.get("c_optimizer", None)
132 | overwrite = kwargs.get("overwrite", True)
133 | lambda_param = kwargs.get("lambda_param", 1.0)
134 |
135 | # Create the underlying CBM concept model
136 | self.n_epochs = c_epochs
137 | self.log_path = kwargs["log_path"]
138 | self.do_logged_fit = kwargs.get("logged_fit", False)
139 | self.n_concepts = len(n_concept_vals)
140 | self.n_train_predictor = kwargs.get("n_train_predictor", 2000)
141 | self.concept_predictor = ConceptBottleneckModel(
142 | model,
143 | layer_id,
144 | n_classes,
145 | n_concept_vals,
146 | c_extr_path,
147 | multi_task=multi_task,
148 | end_to_end=end_to_end,
149 | c_epochs=c_epochs,
150 | c_optimizer=c_optimizer,
151 |             overwrite=overwrite,
152 | lambda_param=lambda_param,
153 | )
154 |
155 | def fit(self, data_gen_l, data_gen_u):
156 | if self.do_logged_fit:
157 | self.logged_fit(data_gen_l, n_epochs=self.n_epochs, frequency=5)
158 | else:
159 | self.concept_predictor.train(data_gen_l)
160 |
161 | def logged_fit(self, data_gen_l, n_epochs=60, frequency=10):
162 |
163 | all_c_accs = []
164 | pred_overwrite = self.concept_predictor.overwrite
165 | self.concept_predictor.overwrite = True
166 | self.concept_predictor.n_epochs = frequency
167 |
168 | for i in range(0, n_epochs, frequency):
169 | self.concept_predictor.train(data_gen_l)
170 |
171 | # Train the concept predictor models
172 | c_true = np.array([elem[1].numpy() for elem in data_gen_l])
173 | c_pred = self.predict(data_gen_l)
174 | c_accs = compute_accuracies(c_true, c_pred)
175 | all_c_accs.append(c_accs)
176 | print("Accuracies: ", c_accs)
177 | print("\n" * 3)
178 | print(f"Ran {i + frequency}/{n_epochs} epochs...")
179 |
180 | all_c_accs = np.array(all_c_accs)
181 | fpath = os.path.join(self.log_path, "freq_accuracies.txt")
182 | np.savetxt(fpath, all_c_accs, fmt='%2.4f')
183 |
184 | self.concept_predictor.overwrite = pred_overwrite
185 | self.concept_predictor.n_epochs = self.n_epochs
186 |
187 | def predict(self, data_gen_u):
188 | ds_u = remove_ds_el(data_gen_u)
189 | predicted = self.concept_predictor.predict(ds_u)
190 | return predicted
191 |
192 |
193 | class SSCC_Group_VAEArgmax(SSCClassifier):
194 | '''
195 | Weakly-supervised VAE Model wrapper
196 | '''
197 | def __init__(self, **kwargs):
198 | super().__init__(**kwargs)
199 | self.n_changed_c = kwargs["k"]
200 | self.n_concept_vals = kwargs["n_concept_vals"]
201 | self.n_concepts = len(self.n_concept_vals)
202 | self.batch_size = kwargs.get("batch_size", 256)
203 | self.n_epochs = kwargs.get("epochs", 10)
204 | loss_fn = kwargs["loss_fn"]
205 | encoder_fn = kwargs["encoder_fn"]
206 | decoder_fn = kwargs["decoder_fn"]
207 | latent_dim = kwargs["latent_dim"]
208 | input_shape = kwargs["input_shape"]
209 | self.log_path = kwargs["log_path"]
210 | self.save_path = kwargs["save_path"]
211 | self.do_logged_fit = kwargs.get("logged_fit", False)
212 | self.overwrite = kwargs.get("overwrite", False)
213 | self.n_train_predictor = kwargs.get("n_train_predictor", 2000)
214 | self.optimizer = kwargs.get(
215 | "optimizer",
216 | tf.keras.optimizers.Adam(
217 | beta_1=0.9,
218 | beta_2=0.999,
219 | epsilon=1e-8,
220 | learning_rate=0.001,
221 | ),
222 | )
223 | self.vae = GroupVAEArgmax(
224 | latent_dim,
225 | encoder_fn,
226 | decoder_fn,
227 | loss_fn,
228 | input_shape,
229 | )
230 | self.vae.compile(optimizer=self.optimizer)
231 | # Here, we rely on using an ensemble of models (one per concept) for
232 | # predicting the concepts from the latent factors
233 | self.concept_predictors = [
234 | GradientBoostingClassifier() for _ in range(self.n_concepts)
235 | ]
236 |
237 | # Load concept extractor, if the model already exists
238 | if (not self.overwrite) and (os.path.exists(self.save_path)):
239 |             # Need to call the model once before loading weights (required by the TF version used)
240 | self.vae(np.zeros([1]+input_shape))
241 | self.vae.load_weights(self.save_path)
242 |
243 | def _build_paired_dataset(self, ds):
244 |
245 | random_state = np.random.RandomState(0)
246 |
247 | # TODO: temporarily rely on non-tf-data implementation
248 |         from concepts_xai.datasets.dataset_utils import get_latent_bases, latent_to_index
249 | latent_sizes = np.array(self.n_concept_vals)
250 | latents_bases = get_latent_bases(latent_sizes)
251 | x_data = np.array([elem[0].numpy() for elem in ds])
252 | c_data = np.array([elem[1].numpy() for elem in ds])
253 | c_ids = np.array([
254 | latent_to_index(elem[1].numpy(), latents_bases) for elem in ds
255 | ])
256 |
257 | x_modified = []
258 | label_ids = []
259 | x_filtered_data = []
260 |
261 | n_pairs, n_non_pairs = 0, 0
262 |
263 | for i in range(c_data.shape[0]):
264 | c = c_data[i]
265 | x = x_data[i]
266 |
267 | if self.n_changed_c == -1:
268 | k_observed = random_state.randint(1, self.n_concepts)
269 | else:
270 | k_observed = self.n_changed_c
271 |
272 | index_list = random_state.choice(
273 | c.shape[0],
274 | random_state.choice([1, k_observed]),
275 | replace=False,
276 | )
277 | idx = -1
278 | c_m = np.copy(c)
279 |
280 | for index in index_list:
281 | v = np.random.choice(np.arange(self.n_concept_vals[index]))
282 | c_m[index] = v
283 | idx = index
284 |
285 | x_m = np.where(c_ids == latent_to_index(c_m, latents_bases))[0]
286 |
287 | if len(x_m) > 0:
288 | n_pairs += 1
289 | np.random.shuffle(x_m)
290 | x_m = x_m[0]
291 | x_m = x_data[x_m]
292 | else:
293 | n_non_pairs += 1
294 | continue
295 |
296 | x_filtered_data.append(x)
297 | x_modified.append(x_m)
298 | label_ids.append(idx)
299 |
300 | x_filtered_data = np.array(x_filtered_data)
301 | x_modified = np.array(x_modified)
302 | label_ids = np.array(label_ids)
303 | x_pairs = np.concatenate([x_filtered_data, x_modified], axis=1)
304 |
305 | print("Pairs: ", n_pairs)
306 | print("Non-pairs: ", n_non_pairs)
307 | print("% pairs: ", str(n_pairs/(1. * (n_pairs+n_non_pairs)) * 100.))
308 |
309 | paired_ds = tf.data.Dataset.from_tensor_slices((x_pairs, label_ids))
310 | return paired_ds
311 |
312 | def fit(self, data_gen_l, data_gen_u):
313 | if not ((not self.overwrite) and (os.path.exists(self.save_path))):
314 | # Train the VAE
315 | paired_ds_l = self._build_paired_dataset(data_gen_l)
316 |
317 | self.logged_fit(
318 | self.vae,
319 | paired_ds_l,
320 | data_gen_l,
321 | n_epochs=self.n_epochs,
322 | frequency=5,
323 | )
324 |
325 | def logged_fit(
326 | self,
327 | model,
328 | training_gen,
329 | eval_gen,
330 | n_epochs=60,
331 | frequency=10,
332 | ):
333 |
334 | cp_callback = tf.keras.callbacks.ModelCheckpoint(
335 | filepath=self.save_path,
336 | verbose=True,
337 | save_best_only=True,
338 | monitor='loss',
339 | mode='auto',
340 | save_freq='epoch',
341 | )
342 | callbacks = [cp_callback]
343 |
344 | all_c_accs = []
345 |
346 | eval_gen_xs = remove_ds_el(remove_ds_el(eval_gen))
347 |
348 | if not self.do_logged_fit:
349 | frequency = n_epochs
350 |
351 | for i in range(0, n_epochs, frequency):
352 | model.fit(
353 | training_gen.batch(self.batch_size),
354 | epochs=frequency,
355 | callbacks=callbacks,
356 | )
357 |
358 | # Train the concept predictor models
359 | c_data = np.array([elem[1].numpy() for elem in eval_gen])
360 | z_data = model.predict(eval_gen_xs.batch(self.batch_size))
361 |
362 | c_accs = []
363 |
364 | for j in range(self.n_concepts):
365 | # Train concept label predictors, and evaluate their predictive
366 | # accuracy
367 | z_train, z_test, c_train, c_test = train_test_split(
368 | z_data,
369 | c_data[:, j],
370 | test_size=0.15,
371 | )
372 |
373 | if self.n_train_predictor is not None:
374 | z_train, c_train = (
375 | z_train[:self.n_train_predictor],
376 | c_train[:self.n_train_predictor],
377 | )
378 | # Note: here we assume sklearn .fit() re-initializes all
379 | # previous params
380 | self.concept_predictors[j].fit(z_train, c_train)
381 |
382 | accuracy = accuracy_score(
383 | c_test,
384 | self.concept_predictors[j].predict(z_test),
385 | )
386 | print("Accuracy of concept ", str(j), " : ", str(accuracy))
387 | c_accs.append(accuracy)
388 |
389 | all_c_accs.append(c_accs)
390 | print("\n"*3)
391 | print(f"Ran {i+frequency}/{n_epochs} epochs...")
392 |
393 | all_c_accs = np.array(all_c_accs)
394 | fpath = os.path.join(self.log_path, "freq_accuracies.txt")
395 | np.savetxt(fpath, all_c_accs, fmt='%2.4f')
396 |
397 | def predict(self, data_gen_u):
398 | data_x = remove_ds_el(data_gen_u)
399 | z_data = self.vae.predict(data_x.batch(self.batch_size))
400 | c_data = np.stack(
401 | [
402 | self.concept_predictors[i].predict(z_data)
403 | for i in range(self.n_concepts)
404 | ],
405 | axis=-1,
406 | )
407 | return c_data
408 |
409 |
410 | class SSCC_BetaVAE(SSCClassifier):
411 | '''
412 | Beta-VAE Model wrapper
413 | '''
414 | def __init__(self, **kwargs):
415 | super().__init__(**kwargs)
416 | self.n_concept_vals = kwargs["n_concept_vals"]
417 | self.n_concepts = len(self.n_concept_vals)
418 | self.batch_size = kwargs.get("batch_size", 256)
419 | self.n_epochs = kwargs.get("epochs", 10)
420 | beta = kwargs.get("beta", 1)
421 | loss_fn = kwargs["loss_fn"]
422 | encoder_fn = kwargs["encoder_fn"]
423 | decoder_fn = kwargs["decoder_fn"]
424 | latent_dim = kwargs["latent_dim"]
425 | input_shape = kwargs["input_shape"]
426 | self.model_path = kwargs.get("save_path", None)
427 | self.retrain = kwargs.get("overwrite", False)
428 | self.n_train_predictor = kwargs.get("n_train_predictor", 2000)
429 | self.optimizer = kwargs.get(
430 | "optimizer",
431 | tf.keras.optimizers.Adam(
432 | beta_1=0.9,
433 | beta_2=0.999,
434 | epsilon=1e-8,
435 | learning_rate=0.001,
436 | ),
437 | )
438 | self.vae = BetaVAE(
439 | latent_dim,
440 | encoder_fn,
441 | decoder_fn,
442 | loss_fn,
443 | input_shape,
444 | beta,
445 | )
446 | self.vae.compile(optimizer=self.optimizer)
447 | # Here, we rely on an ensemble of models (one per concept) for
448 | # predicting concepts from latent factors
449 | self.concept_predictors = [
450 | GradientBoostingClassifier() for _ in range(self.n_concepts)
451 | ]
452 |
453 | # Load concept extractor, if the model already exists
454 | if (self.model_path is not None) and (os.path.exists(self.model_path)):
455 |             # Need to call the model once before loading weights (required by the TF version used)
456 | self.vae(np.zeros([1]+input_shape))
457 | self.vae.load_weights(self.model_path)
458 |
459 | def fit(self, data_gen_l, data_gen_u):
460 | if not (
461 | (self.model_path is not None) and
462 | (os.path.exists(self.model_path)) and
463 | (not self.retrain)
464 | ):
465 |
466 | cp_callback = tf.keras.callbacks.ModelCheckpoint(
467 | filepath=self.model_path,
468 | verbose=True,
469 | save_best_only=True,
470 | monitor='loss',
471 | mode='auto',
472 | save_freq='epoch',
473 | )
474 | callbacks = [cp_callback]
475 | self.vae.fit(
476 | data_gen_l.batch(self.batch_size),
477 | epochs=self.n_epochs,
478 | callbacks=callbacks,
479 | )
480 |
481 | # Train the concept predictor models
482 | c_data = np.array([elem[1].numpy() for elem in data_gen_l])
483 | z_data = self.vae.predict(
484 | remove_ds_el(remove_ds_el(data_gen_l)).batch(self.batch_size)
485 | )
486 |
487 | for i in range(self.n_concepts):
488 |             # Split latent codes into train/test folds for this
489 |             # concept's predictor
490 | z_train, z_test, c_train, c_test = train_test_split(
491 | z_data,
492 | c_data[:, i],
493 | test_size=0.15,
494 | )
495 |
496 | if self.n_train_predictor is not None:
497 | z_train, c_train = (
498 | z_train[:self.n_train_predictor],
499 | c_train[:self.n_train_predictor],
500 | )
501 |
502 | self.concept_predictors[i].fit(z_train, c_train)
503 | print(
504 | "Accuracy of concept ",
505 | i,
506 | " : ",
507 | accuracy_score(
508 | c_test,
509 | self.concept_predictors[i].predict(z_test),
510 | ),
511 | )
512 |
513 | def predict(self, data_gen_u):
514 | data_x = remove_ds_el(data_gen_u)
515 | z_data = self.vae.predict(data_x.batch(self.batch_size))
516 | return np.stack(
517 | [
518 | self.concept_predictors[i].predict(z_data)
519 | for i in range(self.n_concepts)
520 | ],
521 | axis=-1,
522 | )
523 |
524 |
525 |
526 | def remove_ds_el(data_gen):
527 | '''
528 |     Utility function for removing the last tuple element from a generator.
529 | :param data_gen: tf.data generator
530 | :return: data generator without the last tuple element, for every item in
531 | data_gen
532 | '''
533 | return data_gen.map(
534 | lambda *args: tuple([args[i] for i in range(len(args)-1)])
535 | )
536 |
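if __name__ == "__main__":
    # Illustrative example (not part of the original source file): dropping
    # the y-labels from a labeled (x, c, y) dataset.
    ds = tf.data.Dataset.from_tensor_slices((
        np.zeros((4, 8), dtype=np.float32),  # x_data
        np.zeros((4, 3), dtype=np.float32),  # c_data
        np.zeros((4,), dtype=np.int32),      # y_data
    ))
    ds_xc = remove_ds_el(ds)  # now yields (x, c) pairs only
    for x, c in ds_xc.take(1):
        print(x.shape, c.shape)  # (8,) (3,)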
537 |
--------------------------------------------------------------------------------
/concepts_xai/methods/VAE/baseVAE.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | class BaseVAE(tf.keras.Model):
5 | """Abstract base class of a VAE model."""
6 |
7 | def __init__(self, encoder, decoder, loss_fn, **kwargs):
8 | super(BaseVAE, self).__init__(**kwargs)
9 | self.encoder = encoder
10 | self.decoder = decoder
11 | self.loss_fn = loss_fn
12 | self.metric_names = ["loss", "reconstruction_loss", "elbo"]
13 | self.metrics_dict = {
14 | name: tf.keras.metrics.Mean(name=name)
15 | for name in self.metric_names
16 | }
17 |
18 | @property
19 | def metrics(self):
20 | return [self.metrics_dict[name] for name in self.metric_names]
21 |
22 | def encode(self, x):
23 | return self.encoder(x)
24 |
25 | def decode(self, z):
26 | return self.decoder(z)
27 |
28 | def sample_from_latent_distribution(self, z_mean, z_logvar):
29 | """
30 | Samples from the Gaussian distribution defined by z_mean and z_logvar.
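        Uses the reparameterization trick: z = z_mean + exp(z_logvar / 2) * eps
        with eps ~ N(0, I).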
31 | """
32 | return tf.add(
33 | z_mean,
34 | tf.exp(z_logvar / 2) * tf.random.normal(tf.shape(z_mean), 0, 1),
35 | name="sampled_latent_variable"
36 | )
37 |
38 | def generate_random_sample(self, z=None, num_samples=1, seed=None):
39 | (_, latent_size), (_, _) = self.encoder.output_shape
40 | if z is None:
41 | z = tf.random.normal(
42 | shape=[num_samples, latent_size],
43 | mean=0.0,
44 | stddev=1.0,
45 | seed=seed,
46 | )
47 | return self.decoder(z)
48 |
49 | def _compute_losses(self, x, is_training=False):
50 | z_mean, z_logvar = self.encoder(x, training=is_training)
51 | z_sampled = self.sample_from_latent_distribution(z_mean, z_logvar)
52 | reconstructions = self.decoder(z_sampled, training=is_training)
53 | per_sample_loss = self.loss_fn(x, reconstructions)
54 | reconstruction_loss = tf.reduce_mean(per_sample_loss)
55 | kl_loss = compute_gaussian_kl(z_mean, z_logvar)
56 | regularizer = self.regularizer(kl_loss, z_mean, z_logvar, z_sampled)
57 | loss = tf.add(reconstruction_loss, regularizer, name="loss")
58 | elbo = tf.add(reconstruction_loss, kl_loss, name="elbo")
59 |
60 | return loss, reconstruction_loss, elbo
61 |
62 | def update_metrics(self, losses):
63 | for (loss_name, loss) in losses:
64 | self.metrics_dict[loss_name].update_state(loss)
65 |
66 | def train_step(self, inputs):
67 | """Executes one training step and returns the loss.
68 | This function computes the loss and gradients, and uses the latter to
69 | update the model's parameters.
70 | """
71 | with tf.GradientTape() as tape:
72 | loss, reconstruction_loss, elbo = self._compute_losses(
73 | inputs,
74 | is_training=True,
75 | )
76 |
77 | gradients = tape.gradient(loss, self.trainable_variables)
78 | self.optimizer.apply_gradients(
79 | zip(gradients, self.trainable_variables)
80 | )
81 | self.update_metrics([
82 | ("loss", loss),
83 | ("reconstruction_loss", reconstruction_loss),
84 | ("elbo", elbo),
85 | ])
86 |
87 | return {
88 | name: self.metrics_dict[name].result()
89 | for name in self.metric_names
90 | }
91 |
92 | def test_step(self, inputs):
93 | """Executes one test step and returns the loss.
94 | This function computes the loss, without updating the model parameters.
95 | """
96 | loss, reconstruction_loss, elbo = self._compute_losses(
97 | inputs,
98 | is_training=False,
99 | )
100 | self.update_metrics([
101 | ("loss", loss),
102 | ("reconstruction_loss", reconstruction_loss),
103 | ("elbo", elbo),
104 | ])
105 |
106 | return {
107 | name: self.metrics_dict[name].result()
108 | for name in self.metric_names
109 | }
110 |
111 | def call(self, x, **kwargs):
112 | '''
113 |         Calling the VAE model returns either the (sampled) encoder
114 |         output or the decoder output.
115 |         Default behaviour is to return the encoder output.
116 |         To return the decoder output, pass the "decode" argument as
117 |         True in the kwargs dict.
118 | '''
119 |
120 | decode = kwargs.get("decode", False)
121 | z_mean, z_logvar = self.encoder(x, training=False)
122 | z_sampled = self.sample_from_latent_distribution(z_mean, z_logvar)
123 |
124 | if decode:
125 | return self.decoder(z_sampled, training=False)
126 | return z_sampled
127 |
128 |
129 | def compute_gaussian_kl(z_mean, z_logvar):
130 | """Compute KL divergence between input Gaussian and Standard Normal."""
131 | return tf.reduce_mean(
132 | 0.5 * tf.reduce_sum(
133 | tf.square(z_mean) + tf.exp(z_logvar) - z_logvar - 1, [1]
134 | ),
135 | name="kl_loss",
136 | )
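# (The closed form implemented above: for a diagonal Gaussian N(mu, sigma^2)
#  against the standard normal N(0, I),
#      KL = 0.5 * sum_j (mu_j^2 + sigma_j^2 - log sigma_j^2 - 1),
#  with z_logvar = log sigma^2, averaged over the batch.)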
137 |
138 |
139 |
--------------------------------------------------------------------------------
/concepts_xai/methods/VAE/betaVAE.py:
--------------------------------------------------------------------------------
1 | from concepts_xai.methods.VAE.baseVAE import BaseVAE
2 |
3 |
4 | class BetaVAE(BaseVAE):
5 | """BetaVAE model."""
6 |
7 | def __init__(self, encoder, decoder, loss_fn, beta=1, **kwargs):
8 | """Creates a beta-VAE model.
9 |
10 | Implementing Eq. 4 of "beta-VAE: Learning Basic Visual Concepts with a
11 | Constrained Variational Framework"
12 | (https://openreview.net/forum?id=Sy2fzU9gl).
13 |
14 | :param beta: Hyperparameter for the regularizer.
15 | """
16 | super(BetaVAE, self).__init__(
17 | encoder=encoder,
18 | decoder=decoder,
19 | loss_fn=loss_fn,
20 | **kwargs
21 | )
22 | self.beta = beta
23 |
24 | def regularizer(self, kl_loss, z_mean, z_logvar, z_sampled):
25 | return self.beta * kl_loss
26 |
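# (With this regularizer, BaseVAE._compute_losses optimizes
#  E_q(z|x)[-log p(x|z)] + beta * KL(q(z|x) || p(z)), i.e. Eq. 4 of the
#  beta-VAE paper; beta == 1 recovers the standard VAE objective.)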
27 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/concepts_xai/methods/VAE/losses.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow_probability as tfp
3 | import tensorflow as tf
4 |
5 |
6 | def bernoulli_loss(
7 | true_images,
8 | reconstructed_images,
9 | activation,
10 | subtract_true_image_entropy=False,
11 | ):
12 |
13 | """Computes the Bernoulli loss."""
14 | flattened_dim = np.prod(true_images.get_shape().as_list()[1:])
15 | reconstructed_images = tf.reshape(
16 | reconstructed_images,
17 | shape=[-1, flattened_dim]
18 | )
19 | true_images = tf.reshape(true_images, shape=[-1, flattened_dim])
20 |
21 | # Because true images are not binary, the lower bound in the xent is not
22 | # zero: the lower bound in the xent is the entropy of the true images.
23 | if subtract_true_image_entropy:
24 | dist = tfp.distributions.Bernoulli(
25 | probs=tf.clip_by_value(true_images, 1e-6, 1 - 1e-6)
26 | )
27 | loss_lower_bound = tf.reduce_sum(dist.entropy(), axis=1)
28 | else:
29 | loss_lower_bound = 0
30 |
31 | if activation == "logits":
32 | loss = tf.reduce_sum(
33 | tf.nn.sigmoid_cross_entropy_with_logits(
34 | logits=reconstructed_images,
35 | labels=true_images
36 | ),
37 | axis=1,
38 | )
39 | elif activation == "tanh":
40 | reconstructed_images = tf.clip_by_value(
41 | tf.nn.tanh(reconstructed_images) / 2 + 0.5, 1e-6, 1 - 1e-6
42 | )
43 | loss = -tf.reduce_sum(
44 | (
45 | true_images * tf.math.log(reconstructed_images) +
46 | (1 - true_images) * tf.math.log(1 - reconstructed_images)
47 | ),
48 | axis=1,
49 | )
50 | else:
51 | raise NotImplementedError("Activation not supported.")
52 |
53 | return loss - loss_lower_bound
54 |
55 |
56 | def l2_loss(true_images, reconstructed_images, activation):
57 | """Computes the l2 loss."""
58 | if activation == "logits":
59 | return tf.reduce_sum(
60 | tf.square(true_images - tf.nn.sigmoid(reconstructed_images)),
61 | [1, 2, 3]
62 | )
63 | elif activation == "tanh":
64 | reconstructed_images = tf.nn.tanh(reconstructed_images) / 2 + 0.5
65 | return tf.reduce_sum(
66 | tf.square(true_images - reconstructed_images),
67 | [1, 2, 3],
68 | )
69 | else:
70 | raise NotImplementedError("Activation not supported.")
71 |
72 |
73 | def bernoulli_fn_wrapper(
74 | activation="logits",
75 | subtract_true_image_entropy=False,
76 | ):
77 |
78 | def loss_fn(true_images, reconstructed_images):
79 | return bernoulli_loss(
80 | true_images,
81 | reconstructed_images,
82 | activation,
83 | subtract_true_image_entropy,
84 | )
85 | return loss_fn
86 |
87 |
88 | def l2_loss_wrapper(activation="logits"):
89 | def loss_fn(true_images, reconstructed_images):
90 | return l2_loss(true_images, reconstructed_images, activation)
91 |
92 | return loss_fn
93 |
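# --- Usage sketch (illustrative; not part of the original source file) ---
# These wrappers produce per-sample reconstruction losses suitable for the
# loss_fn argument of the VAE classes in this package, e.g.:
#
#   from concepts_xai.methods.VAE.betaVAE import BetaVAE
#
#   recon_loss_fn = bernoulli_fn_wrapper(activation="logits")
#   vae = BetaVAE(encoder=enc, decoder=dec, loss_fn=recon_loss_fn, beta=4)
#
# assuming the decoder outputs unnormalized logits with the same flattened
# shape as the inputs, whose values lie in [0, 1].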
--------------------------------------------------------------------------------
/concepts_xai/methods/VAE/weak_vae.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from concepts_xai.methods.VAE.baseVAE import BaseVAE, compute_gaussian_kl
3 |
4 |
5 | class GroupVAEBase(BaseVAE):
6 | """Beta-VAE with averaging from https://arxiv.org/abs/1809.02383."""
7 |
8 | def __init__(
9 | self,
10 | encoder,
11 | decoder,
12 | loss_fn,
13 | beta=1,
14 | **kwargs,
15 | ):
16 | """
17 | Creates a beta-VAE model with additional averaging for weak
18 | supervision.
19 | Based on https://arxiv.org/abs/1809.02383.
20 |
21 | :param beta: Hyperparameter for KL divergence.
22 | """
23 | super(GroupVAEBase, self).__init__(
24 | encoder=encoder,
25 | decoder=decoder,
26 | loss_fn=loss_fn,
27 | **kwargs,
28 | )
29 | self.beta = beta
30 | self.metric_names.append("regularizer")
31 | self.metrics_dict["regularizer"] = tf.keras.metrics.Mean(
32 | name="regularizer"
33 | )
34 |
35 | def regularizer(self, kl_loss, z_mean, z_logvar, z_sampled):
36 | return self.beta * kl_loss
37 |
38 | def _split_sample_pairs(self, x):
39 | '''
40 | Note: each point contains two frames stacked along the first dim and
41 | integer labels.
42 | '''
43 | if isinstance(x, tuple):
44 |             assert len(x) == 2, f'Expected samples to come as pairs {x}'
45 | return (x[0], x[1])
46 | data_shape = x.get_shape().as_list()[1:]
47 | assert data_shape[0] % 2 == 0, (
48 | "1st dimension of concatenated pairs assumed to be even"
49 | )
50 | data_shape[0] = data_shape[0] // 2
51 | return x[:, :data_shape[0], ...], x[:, data_shape[0]:, ...]
52 |
53 | def _compute_losses_weak(self, x_1, x_2, is_training, labels=None):
54 |
55 | z_mean, z_logvar = self.encoder(x_1, training=is_training)
56 | z_mean_2, z_logvar_2 = self.encoder(x_2, training=is_training)
57 | if labels is not None:
58 | labels = tf.squeeze(
59 | tf.one_hot(labels, z_mean.get_shape().as_list()[1])
60 | )
61 | kl_per_point = compute_kl(z_mean, z_mean_2, z_logvar, z_logvar_2)
62 |
63 | new_mean = 0.5 * z_mean + 0.5 * z_mean_2
64 | var_1 = tf.exp(z_logvar)
65 | var_2 = tf.exp(z_logvar_2)
66 | new_log_var = tf.math.log(0.5*var_1 + 0.5*var_2)
67 |
68 | mean_sample_1, log_var_sample_1 = self.aggregate(
69 | z_mean,
70 | z_logvar,
71 | new_mean,
72 | new_log_var,
73 | labels,
74 | kl_per_point,
75 | )
76 | mean_sample_2, log_var_sample_2 = self.aggregate(
77 | z_mean_2,
78 | z_logvar_2,
79 | new_mean,
80 | new_log_var,
81 | labels,
82 | kl_per_point,
83 | )
84 |
85 | z_sampled_1 = self.sample_from_latent_distribution(
86 | mean_sample_1,
87 | log_var_sample_1,
88 | )
89 | z_sampled_2 = self.sample_from_latent_distribution(
90 | mean_sample_2,
91 | log_var_sample_2,
92 | )
93 |
94 | reconstructions_1 = self.decoder(
95 | z_sampled_1,
96 | training=is_training
97 | )
98 | reconstructions_2 = self.decoder(
99 | z_sampled_2,
100 | training=is_training,
101 | )
102 |
103 | per_sample_loss_1 = self.loss_fn(x_1, reconstructions_1)
104 | per_sample_loss_2 = self.loss_fn(x_2, reconstructions_2)
105 | reconstruction_loss_1 = tf.reduce_mean(per_sample_loss_1)
106 | reconstruction_loss_2 = tf.reduce_mean(per_sample_loss_2)
107 | reconstruction_loss = (
108 | 0.5 * reconstruction_loss_1 + 0.5 * reconstruction_loss_2
109 | )
110 |
111 | kl_loss_1 = compute_gaussian_kl(mean_sample_1, log_var_sample_1)
112 | kl_loss_2 = compute_gaussian_kl(mean_sample_2, log_var_sample_2)
113 | kl_loss = 0.5 * kl_loss_1 + 0.5 * kl_loss_2
114 |
115 | regularizer = self.regularizer(kl_loss, None, None, None)
116 |
117 | loss = tf.add(reconstruction_loss, regularizer, name="loss")
118 | elbo = tf.add(reconstruction_loss, kl_loss, name="elbo")
119 |
120 | return reconstruction_loss, regularizer, loss, elbo
121 |
122 | def _split_labels(self, inputs):
123 | return inputs, None
124 |
125 | def train_step(self, inputs):
126 | x, labels = self._split_labels(inputs)
127 | x_1, x_2 = self._split_sample_pairs(x)
128 |
129 | with tf.GradientTape() as tape:
130 | rec_loss, regularizer, loss, elbo = self._compute_losses_weak(
131 | x_1=x_1,
132 | x_2=x_2,
133 | is_training=True,
134 | labels=labels,
135 | )
136 |
137 | gradients = tape.gradient(loss, self.trainable_variables)
138 | self.optimizer.apply_gradients(
139 | zip(gradients, self.trainable_variables)
140 | )
141 | self.update_metrics([
142 | ("loss", loss),
143 | ("reconstruction_loss", rec_loss),
144 | ("elbo", elbo),
145 |             ("regularizer", regularizer),
146 | ])
147 |
148 | return {
149 | name: self.metrics_dict[name].result()
150 | for name in self.metric_names
151 | }
152 |
153 | def test_step(self, inputs):
154 | x, labels = self._split_labels(inputs)
155 | x_1, x_2 = self._split_sample_pairs(x)
156 | rec_loss, regularizer, loss, elbo = self._compute_losses_weak(
157 | x_1=x_1,
158 | x_2=x_2,
159 | is_training=False,
160 | labels=labels,
161 | )
162 | self.update_metrics([
163 | ("loss", loss),
164 | ("reconstruction_loss", rec_loss),
165 | ("elbo", elbo),
166 | ("regularizer", regularizer)
167 | ])
168 |
169 | return {
170 | name: self.metrics_dict[name].result()
171 | for name in self.metric_names
172 | }
173 |
174 |
175 | class GroupVAEArgmax(GroupVAEBase):
176 | """Class implementing the group-VAE without any label."""
177 |
178 | def aggregate(
179 | self,
180 | z_mean,
181 | z_logvar,
182 | new_mean,
183 | new_log_var,
184 | labels,
185 | kl_per_point,
186 | ):
187 | return aggregate_argmax(
188 | z_mean,
189 | z_logvar,
190 | new_mean,
191 | new_log_var,
192 | kl_per_point,
193 | )
194 |
195 |
196 | class GroupVAELabels(GroupVAEBase):
197 | """
198 | Class implementing the group-VAE with labels on which factor is shared.
199 | """
200 |
201 |     def _split_labels(self, inputs):
202 |         # Inputs are expected to come as (x, labels) tuples
203 |         return inputs
203 |
204 | def aggregate(
205 | self,
206 | z_mean,
207 | z_logvar,
208 | new_mean,
209 | new_log_var,
210 | labels,
211 | kl_per_point,
212 | ):
213 | return aggregate_labels(
214 | z_mean,
215 | z_logvar,
216 | new_mean,
217 | new_log_var,
218 | labels,
219 | kl_per_point,
220 | )
221 |
222 |
223 | class MLVae(GroupVAEBase):
224 | """Beta-VAE with averaging from https://arxiv.org/abs/1705.08841."""
225 |
226 | def _compute_losses_weak(self, x_1, x_2, is_training, labels=None):
227 | z_mean, z_logvar = self.encoder(x_1, training=is_training)
228 | z_mean_2, z_logvar_2 = self.encoder(x_2, training=is_training)
229 | if labels is not None:
230 | labels = tf.squeeze(
231 | tf.one_hot(labels, z_mean.get_shape().as_list()[1])
232 | )
233 | kl_per_point = compute_kl(z_mean, z_mean_2, z_logvar, z_logvar_2)
234 |
235 |         # ML-VAE averaging: new_var is the harmonic mean of the two
236 |         # posterior variances, and new_mean is the corresponding
237 |         # precision-weighted average of the two posterior means
238 |         var_1 = tf.exp(z_logvar)
239 |         var_2 = tf.exp(z_logvar_2)
240 |         new_var = 2 * var_1 * var_2 / (var_1 + var_2)
241 |         new_mean = ((z_mean / var_1) + (z_mean_2 / var_2)) * new_var * 0.5
242 |
243 |         new_log_var = tf.math.log(new_var)
241 |
242 | mean_sample_1, log_var_sample_1 = self.aggregate(
243 | z_mean,
244 | z_logvar,
245 | new_mean,
246 | new_log_var,
247 | labels,
248 | kl_per_point,
249 | )
250 | mean_sample_2, log_var_sample_2 = self.aggregate(
251 | z_mean_2,
252 | z_logvar_2,
253 | new_mean,
254 | new_log_var,
255 | labels,
256 | kl_per_point,
257 | )
258 |
259 | z_sampled_1 = self.sample_from_latent_distribution(
260 | mean_sample_1,
261 | log_var_sample_1,
262 | )
263 | z_sampled_2 = self.sample_from_latent_distribution(
264 | mean_sample_2,
265 | log_var_sample_2,
266 | )
267 |
268 | reconstructions_1 = self.decoder(
269 | z_sampled_1,
270 | training=is_training
271 | )
272 | reconstructions_2 = self.decoder(
273 | z_sampled_2,
274 | training=is_training,
275 | )
276 |
277 | per_sample_loss_1 = self.loss_fn(x_1, reconstructions_1)
278 | per_sample_loss_2 = self.loss_fn(x_2, reconstructions_2)
279 | reconstruction_loss_1 = tf.reduce_mean(per_sample_loss_1)
280 | reconstruction_loss_2 = tf.reduce_mean(per_sample_loss_2)
281 | reconstruction_loss = (
282 | 0.5 * reconstruction_loss_1 + 0.5 * reconstruction_loss_2
283 | )
284 |
285 | kl_loss_1 = compute_gaussian_kl(mean_sample_1, log_var_sample_1)
286 | kl_loss_2 = compute_gaussian_kl(mean_sample_2, log_var_sample_2)
287 | kl_loss = 0.5 * kl_loss_1 + 0.5 * kl_loss_2
288 |
289 | regularizer = self.regularizer(kl_loss, None, None, None)
290 |
291 | loss = tf.add(reconstruction_loss, regularizer, name="loss")
292 | elbo = tf.add(reconstruction_loss, kl_loss, name="elbo")
293 |
294 | return reconstruction_loss, regularizer, loss, elbo
295 |
296 |
297 | class MLVaeLabels(MLVae):
298 | """Class implementing the ML-VAE with labels on which factor is shared."""
299 |
300 |     def _split_labels(self, inputs):
301 |         # Inputs are expected to come as (x, labels) tuples
302 |         return inputs
302 |
303 | def aggregate(
304 | self,
305 | z_mean,
306 | z_logvar,
307 | new_mean,
308 | new_log_var,
309 | labels,
310 | kl_per_point,
311 | ):
312 | return aggregate_labels(
313 | z_mean,
314 | z_logvar,
315 | new_mean,
316 | new_log_var,
317 | labels,
318 | kl_per_point,
319 | )
320 |
321 |
322 | class MLVaeArgmax(MLVae):
323 | """Class implementing the ML-VAE without any label."""
324 |
325 | def aggregate(
326 | self,
327 | z_mean,
328 | z_logvar,
329 | new_mean,
330 | new_log_var,
331 | labels,
332 | kl_per_point,
333 | ):
334 | return aggregate_argmax(
335 | z_mean,
336 | z_logvar,
337 | new_mean,
338 | new_log_var,
339 | kl_per_point,
340 | )
341 |
342 |
343 | def aggregate_labels(
344 | z_mean,
345 | z_logvar,
346 | new_mean,
347 | new_log_var,
348 | labels,
349 | kl_per_point,
350 | ):
351 |     """Use labels to aggregate.
352 |
353 |     Labels contain a one-hot encoding with a single 1 marking the factor
354 |     that is not shared within the pair. We enforce which dimension of the
355 |     latent code learns which factor (dimension 1 learns factor 1), and that
356 |     each factor of variation is encoded in a single dimension: labelled
357 |     dimensions keep their own posterior, all others are averaged. E.g. with
358 |     labels=[[0, 1]], z_mean=[[1., 5.]] and new_mean=[[0., 0.]], the result
359 |     is [[0., 5.]].
357 |
358 | Args:
359 | z_mean: Mean of the encoder distribution for the original image.
360 | z_logvar: Logvar of the encoder distribution for the original image.
361 | new_mean: Average mean of the encoder distribution of the pair of images.
362 | new_log_var: Average logvar of the encoder distribution of the pair of
363 | images.
364 | labels: One-hot-encoding with the position of the dimension that should not
365 | be shared.
366 | kl_per_point: Distance between the two encoder distributions (unused).
367 |
368 | Returns:
369 | Mean and logvariance for the new observation.
370 | """
371 | z_mean_averaged = tf.where(
372 | tf.math.equal(
373 | labels,
374 | tf.expand_dims(tf.reduce_max(labels, axis=1), 1)
375 | ),
376 | z_mean,
377 | new_mean,
378 | )
379 | z_logvar_averaged = tf.where(
380 | tf.math.equal(
381 | labels,
382 | tf.expand_dims(tf.reduce_max(labels, axis=1), 1)
383 | ),
384 | z_logvar,
385 | new_log_var,
386 | )
387 | return z_mean_averaged, z_logvar_averaged
388 |
389 |
390 | def aggregate_argmax(
391 | z_mean,
392 | z_logvar,
393 | new_mean,
394 | new_log_var,
395 | kl_per_point,
396 | ):
397 | """Argmax aggregation with adaptive k.
398 |
399 |     The top k dimensions in terms of distance are not averaged: k is
400 |     estimated adaptively by binning the distances into two bins of equal
401 |     width, and dimensions falling in the upper bin keep their own posterior.
401 |
402 | Args:
403 | z_mean: Mean of the encoder distribution for the original image.
404 | z_logvar: Logvar of the encoder distribution for the original image.
405 | new_mean: Average mean of the encoder distribution of the pair of images.
406 | new_log_var: Average logvar of the encoder distribution of the pair of
407 | images.
408 | kl_per_point: Distance between the two encoder distributions.
409 |
410 | Returns:
411 | Mean and logvariance for the new observation.
412 | """
413 | mask = tf.equal(tf.map_fn(discretize_in_bins, kl_per_point, tf.int32), 1)
414 | z_mean_averaged = tf.where(mask, z_mean, new_mean)
415 | z_logvar_averaged = tf.where(mask, z_logvar, new_log_var)
416 | return z_mean_averaged, z_logvar_averaged
417 |
418 |
419 | def discretize_in_bins(x):
420 |     """Discretize a vector into two bins of equal width.
421 |
422 |     E.g. [0.1, 0.3, 0.9] -> [0, 0, 1]: values in the lower half of the
423 |     [min, max] range map to bin 0, values in the upper half to bin 1.
424 |     """
425 |     return tf.histogram_fixed_width_bins(
426 |         x,
427 |         [tf.reduce_min(x), tf.reduce_max(x)],
428 |         nbins=2,
429 |     )
426 |
427 |
428 | def compute_kl(z_1, z_2, logvar_1, logvar_2):
429 |     # Per-dimension KL divergence between the two diagonal Gaussian
430 |     # posteriors, up to a constant factor of 2 (only relative distances
431 |     # between dimensions are used downstream)
432 |     var_1 = tf.exp(logvar_1)
433 |     var_2 = tf.exp(logvar_2)
434 |     return var_1/var_2 + tf.square(z_2-z_1)/var_2 - 1 + logvar_2 - logvar_1
432 |
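To make the intended training setup concrete, the sketch below (illustrative
only, with arbitrary hyperparameters) trains a GroupVAEArgmax on synthetic
paired observations. It assumes BaseVAE behaves as a standard tf.keras.Model
subclass that only needs the encoder, decoder and loss_fn forwarded above, and
it reuses the conv_encoder/deconv_decoder helpers from
concepts_xai/utils/architectures.py. Each sample stacks the two frames of a
pair along the first (non-batch) dimension, as _split_sample_pairs expects:

    import numpy as np
    import tensorflow as tf
    from concepts_xai.methods.VAE.losses import bernoulli_fn_wrapper
    from concepts_xai.methods.VAE.weak_vae import GroupVAEArgmax
    from concepts_xai.utils.architectures import conv_encoder, deconv_decoder

    latent_dim = 10
    encoder = conv_encoder(input_shape=(64, 64, 1), num_latent=latent_dim)
    decoder = deconv_decoder(output_shape=(64, 64, 1), num_latent=latent_dim)

    vae = GroupVAEArgmax(
        encoder=encoder,
        decoder=decoder,
        loss_fn=bernoulli_fn_wrapper(),
        beta=4,
    )
    vae.compile(optimizer=tf.keras.optimizers.Adam(1e-4))

    # 32 dummy pairs of 64x64x1 frames, each stacked to shape (128, 64, 1)
    pairs = np.random.rand(32, 128, 64, 1).astype(np.float32)
    vae.fit(tf.data.Dataset.from_tensor_slices(pairs).batch(8), epochs=1)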
--------------------------------------------------------------------------------
/concepts_xai/methods/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dmitrykazhdan/concept-based-xai/a86c6abd609ec2ef0cf56cb556e3f834a0962bdc/concepts_xai/methods/__init__.py
--------------------------------------------------------------------------------
/concepts_xai/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dmitrykazhdan/concept-based-xai/a86c6abd609ec2ef0cf56cb556e3f834a0962bdc/concepts_xai/utils/__init__.py
--------------------------------------------------------------------------------
/concepts_xai/utils/architectures.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras.layers import Dense, Dropout
3 |
4 |
5 | def small_cnn(input_shape, num_classes=10):
6 | '''
7 | CNN architecture used in CME (https://arxiv.org/abs/2010.13233) for the
8 | dSprites task
9 | :param input_shape: input sample shape
10 | :param num_classes: number of output classes
11 | :return: compiled keras model
12 | '''
13 |
14 | inputs = tf.keras.Input(shape=input_shape)
15 | x = get_cnn_body_layers(inputs)
16 | outputs = Dense(num_classes, activation='softmax')(x)
17 |
18 | model = tf.keras.Model(inputs=inputs, outputs=outputs)
19 |
20 |     optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3, amsgrad=True)
21 |
22 | model.compile(optimizer=optimizer,
23 | loss='sparse_categorical_crossentropy',
24 | metrics=['acc'])
25 |
26 | return model
27 |
28 |
29 | def get_cnn_body_layers(inputs):
30 | x = tf.keras.layers.Conv2D(
31 | filters=32,
32 | kernel_size=4,
33 | strides=2,
34 | activation='relu',
35 | padding="same",
36 | name="e1",
37 | )(inputs)
38 | x = tf.keras.layers.Conv2D(
39 | filters=32,
40 | kernel_size=4,
41 | strides=2,
42 | activation='relu',
43 | padding="same",
44 | name="e2",
45 | )(x)
46 | x = tf.keras.layers.Conv2D(
47 | filters=64,
48 | kernel_size=2,
49 | strides=2,
50 | activation='relu',
51 | padding="same",
52 | name="e3",
53 | )(x)
54 | x = tf.keras.layers.Conv2D(
55 | filters=64,
56 | kernel_size=2,
57 | strides=2,
58 | activation='relu',
59 | padding="same",
60 | name="e4",
61 | )(x)
62 | x = tf.keras.layers.Flatten()(x)
63 | x = Dense(128, activation='relu')(x)
64 | x = Dropout(0.4)(x)
65 | x = Dense(64, activation='relu')(x)
66 | x = Dropout(0.2)(x)
67 | return x
68 |
69 |
70 | def multi_task_cnn(input_shape, num_concept_values, concept_names=[]):
71 |     '''
72 |     :param input_shape: input sample shape
73 |     :param num_concept_values: list containing the number of possible values
74 |         for each concept (e.g. [3, 6, 40, 32, 32] for the dSprites shape,
75 |         scale, rotation, and x/y position concepts)
76 |     :param concept_names: optional list of names, one per concept head
77 |     :return: CNN model with shared convolutional layers and one softmax head
78 |         per concept, each predicting that concept's value
79 |     '''
79 |
80 |     if concept_names != []:
81 |         assert len(num_concept_values) == len(concept_names), \
82 |             "The number of concepts is different for the values and the names"
83 |     inputs = tf.keras.Input(shape=input_shape)
84 |
85 |     h = get_cnn_body_layers(inputs)
86 |
87 |     output_layers = []
88 |     for i, c in enumerate(num_concept_values):
89 |         l = tf.keras.layers.Dense(
90 |             c,
91 |             activation="softmax",
92 |             name="l" + str(i) if concept_names == [] else concept_names[i],
93 |         )(h)
94 |         output_layers.append(l)
95 |
96 |     model = tf.keras.Model(inputs=inputs, outputs=output_layers)
97 |
98 |     optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3, amsgrad=True)
99 |
100 |     model.compile(
101 |         optimizer=optimizer,
102 |         loss=[
103 |             tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
104 |             for c in num_concept_values
105 |         ],
106 |         metrics=['acc'],
107 |     )
108 |
109 |     return model
110 |
111 |
112 | def sigmoid_cnn(input_shape, num_concept_values, concept_names=[]):
113 |     '''
114 |     :param input_shape: input sample shape
115 |     :param num_concept_values: list containing the number of possible values
116 |         for each concept (e.g. [3, 6, 40, 32, 32] for dSprites)
117 |     :param concept_names: optional list of names, one per concept head
118 |     :return: CNN model with shared convolutional layers and one sigmoid head
119 |         per concept, each predicting that concept's value
120 |     '''
121 |
122 | if concept_names != []:
123 | assert len(num_concept_values) == len(concept_names), \
124 | "The number of concepts is different for the values and the names"
125 |
126 | inputs = tf.keras.Input(shape=input_shape)
127 |
128 | h = get_cnn_body_layers(inputs)
129 |
130 | output_layers = []
131 | for i, c in enumerate(num_concept_values):
132 | l = tf.keras.layers.Dense(
133 |             c,
134 | activation="sigmoid",
135 | name="l" + str(i) if concept_names == [] else concept_names[i],
136 | )(h)
137 | output_layers.append(l)
138 |
139 | model = tf.keras.Model(inputs=inputs, outputs=output_layers)
140 |
141 |     optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3, amsgrad=True)
142 |
143 | model.compile(
144 | optimizer=optimizer,
145 | loss=[
146 | tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
147 | for c in num_concept_values
148 | ],
149 | metrics=['acc'],
150 | )
151 | return model
152 |
153 |
154 | def conv_encoder(input_shape, num_latent):
155 | """
156 | CNN encoder architecture used in the 'Challenging Common Assumptions in the
157 | Unsupervised Learning of Disentangled Representations' paper
158 | (https://arxiv.org/abs/1811.12359)
159 |
160 | Note: model is uncompiled
161 | """
162 |
163 | inputs = tf.keras.Input(shape=input_shape)
164 |
165 | e1 = tf.keras.layers.Conv2D(
166 | filters=32,
167 | kernel_size=4,
168 | strides=2,
169 | activation='relu',
170 | padding="same",
171 | name="e1",
172 | )(inputs)
173 |
174 | e2 = tf.keras.layers.Conv2D(
175 | filters=32,
176 | kernel_size=4,
177 | strides=2,
178 | activation='relu',
179 | padding="same",
180 | name="e2",
181 | )(e1)
182 |
183 | e3 = tf.keras.layers.Conv2D(
184 | filters=64,
185 | kernel_size=2,
186 | strides=2,
187 | activation='relu',
188 | padding="same",
189 | name="e3",
190 | )(e2)
191 |
192 | e4 = tf.keras.layers.Conv2D(
193 | filters=64,
194 | kernel_size=2,
195 | strides=2,
196 | activation='relu',
197 | padding="same",
198 | name="e4",
199 | )(e3)
200 |
201 | flat_e4 = tf.keras.layers.Flatten()(e4)
202 | e5 = tf.keras.layers.Dense(256, activation='relu', name="e5")(flat_e4)
203 |
204 | means = tf.keras.layers.Dense(
205 | num_latent,
206 | activation=None,
207 | name="means",
208 | )(e5)
209 | log_var = tf.keras.layers.Dense(
210 | num_latent,
211 | activation=None,
212 | name="log_var",
213 | )(e5)
214 |
215 | encoder = tf.keras.Model(inputs=inputs, outputs=[means, log_var])
216 |
217 | return encoder
218 |
219 |
220 | def deconv_decoder(output_shape, num_latent):
221 | """
222 | CNN decoder architecture used in the 'Challenging Common Assumptions in the
223 | Unsupervised Learning of Disentangled Representations' paper
224 | (https://arxiv.org/abs/1811.12359)
225 |
226 | Note: model is uncompiled
227 | """
228 |
229 | latent_inputs = tf.keras.Input(shape=(num_latent,))
230 | d1 = tf.keras.layers.Dense(256, activation='relu')(latent_inputs)
231 | d2 = tf.keras.layers.Dense(1024, activation='relu')(d1)
232 | d2_reshaped = tf.keras.layers.Reshape([4, 4, 64])(d2)
233 |
234 | d3 = tf.keras.layers.Conv2DTranspose(
235 | filters=64,
236 | kernel_size=4,
237 | strides=2,
238 | activation='relu',
239 | padding="same",
240 | )(d2_reshaped)
241 |
242 | d4 = tf.keras.layers.Conv2DTranspose(
243 | filters=32,
244 | kernel_size=4,
245 | strides=2,
246 | activation='relu',
247 | padding="same",
248 | )(d3)
249 |
250 | d5 = tf.keras.layers.Conv2DTranspose(
251 | filters=32,
252 | kernel_size=4,
253 | strides=2,
254 | activation='relu',
255 | padding="same",
256 | )(d4)
257 |
258 | d6 = tf.keras.layers.Conv2DTranspose(
259 | filters=output_shape[2],
260 | kernel_size=4,
261 |         strides=2,
262 |         padding="same",
263 |     )(d5)
263 | output = tf.keras.layers.Reshape(output_shape)(d6)
264 |
265 | return tf.keras.Model(inputs=latent_inputs, outputs=[output])
266 |
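A quick shape check for the encoder/decoder pair above (illustrative only;
both architectures assume 64x64 inputs):

    import tensorflow as tf
    from concepts_xai.utils.architectures import conv_encoder, deconv_decoder

    encoder = conv_encoder(input_shape=(64, 64, 1), num_latent=10)
    decoder = deconv_decoder(output_shape=(64, 64, 1), num_latent=10)

    x = tf.zeros((4, 64, 64, 1))
    means, log_var = encoder(x)
    print(means.shape, log_var.shape)  # (4, 10) (4, 10)
    print(decoder(means).shape)        # (4, 64, 64, 1)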
--------------------------------------------------------------------------------
/concepts_xai/utils/model_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tensorflow as tf
3 |
4 |
5 | def get_model(
6 | model,
7 | model_save_path,
8 | overwrite=False,
9 | train_gen=None,
10 | val_gen=None,
11 | batch_size=256,
12 | epochs=250,
13 | ):
14 |     '''
15 |     Load a pre-trained model if saved weights exist, else train it
16 |     :param model: Compiled Keras model
17 |     :param model_save_path: Path for saving/loading the model weights
18 |     :param overwrite: If True, retrain even if saved weights already exist
19 |     :param train_gen: tf.data.Dataset used for training the model
20 |     :param val_gen: tf.data.Dataset used for validating the model
21 |     :param batch_size: Batch size used during training
22 |     :param epochs: Number of epochs to use for training
23 |     :return: Trained/loaded model
24 |     '''
24 |
25 | if (overwrite) or (not os.path.exists(model_save_path)):
26 |
27 | if (train_gen is None) or (val_gen is None):
28 | raise ValueError(
29 | "Training and/or Validation data generators not provided."
30 | )
31 |
32 |         # Make sure the save directory exists before the checkpoint
33 |         # callback below tries to write to it
34 |         dir_name = os.path.dirname(model_save_path)
35 |         if dir_name and not os.path.exists(dir_name):
36 |             os.makedirs(dir_name)
37 |
38 |         # Train model, checkpointing the best weights seen so far
39 |         callbacks = []
40 |         cp_callback = tf.keras.callbacks.ModelCheckpoint(
41 |             filepath=model_save_path,
42 |             verbose=True,
43 |             save_best_only=True,
44 |             monitor='val_acc',
45 |             mode='auto',
46 |             save_freq='epoch',
47 |         )
48 |         callbacks.append(cp_callback)
49 |
50 |         model.fit(
51 |             train_gen.batch(batch_size),
52 |             epochs=epochs,
53 |             validation_data=val_gen.batch(batch_size),
54 |             callbacks=callbacks,
55 |         )
56 |         model.save_weights(model_save_path)  # saves final-epoch weights
57 |
58 | else:
59 |         print("Loading pre-trained model")
60 | model.load_weights(model_save_path)
61 |
62 | return model
63 |
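Hypothetical usage of get_model (the checkpoint path and dummy data below are
placeholders): weights are loaded if they already exist at model_save_path,
otherwise the model is trained and checkpointed there:

    import numpy as np
    import tensorflow as tf
    from concepts_xai.utils.architectures import small_cnn
    from concepts_xai.utils.model_loader import get_model

    # Dummy unbatched datasets of (image, label) pairs; get_model batches them
    xs = np.random.rand(100, 64, 64, 1).astype(np.float32)
    ys = np.random.randint(0, 3, size=100)
    ds = tf.data.Dataset.from_tensor_slices((xs, ys))
    ds_train, ds_val = ds.take(85), ds.skip(85)

    model = small_cnn(input_shape=(64, 64, 1), num_classes=3)
    model = get_model(
        model,
        model_save_path="checkpoints/small_cnn.h5",  # hypothetical path
        train_gen=ds_train,
        val_gen=ds_val,
        epochs=2,
    )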
--------------------------------------------------------------------------------
/concepts_xai/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | from .architectures import multi_task_cnn, small_cnn
5 | from .model_loader import get_model
6 |
7 |
8 | def tf_data_split(ds, test_size=0.15, n_samples=None):
9 | '''
10 |     Method for train/test splitting a tf.data.Dataset, assuming it was
11 |     generated from numpy arrays
12 |     :param ds: tf.data.Dataset, assumed to be generated via
13 |         "from_tensor_slices" (see ./datasets/dSprites.py for an example)
14 | :param test_size: Ratio of test samples
15 | :param n_samples: Total number of samples in ds
16 | :return: Two tf datasets, obtained by splitting ds
17 | '''
18 |
19 | if n_samples is None:
20 | # Compute total number of samples, if not provided
21 | n_samples = int(ds.cardinality().numpy())
22 |
23 | # Split the dataset
24 | train_size = int((1. - test_size) * n_samples)
25 | test_size = n_samples - train_size
26 | ds_train = ds.take(train_size)
27 | ds_test = ds.skip(train_size).take(test_size)
28 |
29 | return ds_train, ds_test
30 |
31 |
32 | def setup_experiment_dir(dir_path, overwrite=False):
33 |
34 | # Remove prior contents
35 | if overwrite and os.path.exists(dir_path):
36 | shutil.rmtree(dir_path)
37 |
38 | # Create the directory if it doesn't exist
39 | if not os.path.exists(dir_path):
40 | os.makedirs(dir_path)
41 |
42 |
43 | def convert_to_multioutput(c):
44 |     # Split a concept vector into a tuple of one-element targets, one per
45 |     # output head of a multi-task model
46 |     multi_output = tuple([c[i:i + 1] for i in range(c.shape[0])])
47 |     return multi_output
46 |
47 |
48 | def setup_basic_model(dataset, n_epochs, save_path, model_type=""):
49 | '''
50 | Train a standard CNN model (used for CBM and CME)
51 | :param dataset: dataset class to train with
52 | :param n_epochs: number of epochs to train for
53 | :param save_path: path to saving/loading the model
54 |     :param model_type: either "" (single-output) or "multi_task"
55 | :return:
56 | '''
57 |
58 | data_gen_train, data_gen_test, c_names = dataset.load_data()
59 |
60 | n_classes = dataset.n_classes
61 | input_shape = dataset.sample_shape
62 | n_samples = dataset.n_train_samples
63 |     num_concept_values = dataset.n_c_vals_list
64 | concept_names = dataset.c_names
65 |
66 | print(f"No. training samples: {n_samples}")
67 | # Train/Load CNN model using data generators
68 | if model_type == "multi_task":
69 |
70 | # remove y and convert concept data to multi-task
71 | data_gen_train_c = data_gen_train.map(
72 | lambda x, c, y: (x, convert_to_multioutput(c))
73 | )
74 | data_gen_t, data_gen_v = tf_data_split(
75 | data_gen_train_c,
76 | test_size=0.15,
77 | n_samples=n_samples,
78 | )
79 |         untrained_model = multi_task_cnn(
80 |             input_shape,
81 |             num_concept_values,
82 |             concept_names,
83 |         )
84 |         basic_model = get_model(
85 |             untrained_model,
86 |             save_path,
87 |             train_gen=data_gen_t,
88 |             val_gen=data_gen_v,
89 |             epochs=n_epochs,
90 |         )
91 |         # Map test data to the same multi-output format for evaluation
92 |         data_gen_test_eval = data_gen_test.map(
93 |             lambda x, c, y: (x, convert_to_multioutput(c))
94 |         )
92 | else:
93 | # Remove concept data
94 | data_gen_train_no_c = data_gen_train.map(lambda x, c, y: (x, y))
95 | # Train/Load CNN model using data generators
96 | data_gen_t, data_gen_v = tf_data_split(
97 | data_gen_train_no_c,
98 | test_size=0.15,
99 | n_samples=n_samples,
100 | )
101 | untrained_model = small_cnn(input_shape, n_classes)
102 | basic_model = get_model(
103 | untrained_model,
104 | save_path,
105 | train_gen=data_gen_t,
106 | val_gen=data_gen_v,
107 | epochs=n_epochs,
108 | )
109 |
110 |         # Remove concept data from the test set for evaluation
111 |         data_gen_test_eval = data_gen_test.map(lambda x, c, y: (x, y))
112 |
113 |     # Evaluate the model on the held-out test set
114 |     results = basic_model.evaluate(data_gen_test_eval.batch(batch_size=256))
115 |     print("Model performance:", results)
114 |
115 | return basic_model
116 |
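For completeness, an illustrative use of the tf_data_split helper above; note
that the split is positional (take/skip), so shuffle the dataset beforehand if
sample order matters:

    import numpy as np
    import tensorflow as tf
    from concepts_xai.utils.utils import tf_data_split

    xs = np.random.rand(100, 64, 64, 1).astype(np.float32)
    ys = np.random.randint(0, 3, size=100)
    ds = tf.data.Dataset.from_tensor_slices((xs, ys))

    ds_train, ds_test = tf_data_split(ds, test_size=0.15)
    print(ds_train.cardinality().numpy())  # 85
    print(ds_test.cardinality().numpy())   # 15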
--------------------------------------------------------------------------------
/concepts_xai/utils/visualisation.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | from scipy import stats
4 |
5 |
6 | def plot_np_img(img_np, cmap=None):
7 | '''
8 | Plot image given as a numpy array
9 | '''
10 | plt.imshow(img_np, cmap=cmap)
11 | plt.show()
12 |
13 |
14 | def visualisation_experiment(vae, imgs):
15 |     '''
16 |     Plot each image in 'imgs' alongside its 'vae' reconstruction
17 |     '''
18 |     kwargs = {"decode": True}
19 |
20 | for img in imgs:
21 | # Plot original image
22 | plot_np_img(img)
23 | # Retrieve reconstructed image from vae
24 | reconstruction = vae(np.expand_dims(img, axis=0), **kwargs)
25 | reconstruction = reconstruction.numpy()[0]
26 | reconstruction = stats.logistic.cdf(reconstruction)
27 | # Plot reconstructed image
28 | plot_np_img(reconstruction)
29 |
30 |
--------------------------------------------------------------------------------
/config.yml:
--------------------------------------------------------------------------------
1 |
2 | # Path to the downloaded dsprites.npz file
3 | dsprites_path : ...
4 |
5 | # Path to the downloaded 'cars' directory
6 | cars3D_path : ...
7 |
8 | # Path to the downloaded 'smallNorb' directory
9 | smallNorb_path : ...
10 |
11 | # Path to the downloaded '3dshapes.h5' file
12 | shapes3d_path : ...
13 |
14 |
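These paths are read at runtime by the dataset loaders (see
concepts_xai/datasets/load_paths.py). As a standalone sketch of the expected
format, not the library's actual loader, the file can be parsed with PyYAML:

    import yaml

    with open("config.yml", "r") as f:
        paths = yaml.safe_load(f)

    dsprites_path = paths["dsprites_path"]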
--------------------------------------------------------------------------------
/config_template.yml:
--------------------------------------------------------------------------------
1 |
2 | # Path to the downloaded dsprites.npz file
3 | dsprites_path : ...
4 |
5 | # Path to the downloaded 'cars' directory
6 | cars3D_path : ...
7 |
8 | # Path to the downloaded 'smallNorb' directory
9 | smallNorb_path : ...
10 |
11 | # Path to the downloaded '3dshapes.h5' file
12 | shapes3d_path : ...
--------------------------------------------------------------------------------
/download_datasets.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | # Note: code partially-adapted from the
5 | # 'https://github.com/google-research/disentanglement_lib' repo
6 |
7 |
8 | echo "Downloading small_norb."
9 | if [[ ! -d "small_norb" ]]; then
10 | mkdir small_norb
11 | fi
12 | if [[ ! -e small_norb/"smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat" ]]; then
13 | wget -O small_norb/smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat.gz https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat.gz
14 | gunzip small_norb/smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat.gz
15 | fi
16 | if [[ ! -e small_norb/"smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat" ]]; then
17 | wget -O small_norb/smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat.gz https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat.gz
18 | gunzip small_norb/smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat.gz
19 | fi
20 | if [[ ! -e small_norb/"smallnorb-5x46789x9x18x6x2x96x96-training-info.mat" ]]; then
21 | wget -O small_norb/smallnorb-5x46789x9x18x6x2x96x96-training-info.mat.gz https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-info.mat.gz
22 | gunzip small_norb/smallnorb-5x46789x9x18x6x2x96x96-training-info.mat.gz
23 | fi
24 | if [[ ! -e small_norb/"smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat" ]]; then
25 | wget -O small_norb/smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat.gz https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat.gz
26 | gunzip small_norb/smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat.gz
27 | fi
28 | if [[ ! -e small_norb/"smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat" ]]; then
29 | wget -O small_norb/smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat.gz https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat.gz
30 | gunzip small_norb/smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat.gz
31 | fi
32 | if [[ ! -e small_norb/"smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat" ]]; then
33 | wget -O small_norb/smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat.gz https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat.gz
34 | gunzip small_norb/smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat.gz
35 | fi
36 | echo "Downloading small_norb completed!"
37 |
38 |
39 | echo "Downloading cars dataset."
40 | if [[ ! -d "cars" ]]; then
41 | wget -O nips2015-analogy-data.tar.gz http://www.scottreed.info/files/nips2015-analogy-data.tar.gz
42 | tar xzf nips2015-analogy-data.tar.gz
43 | rm nips2015-analogy-data.tar.gz
44 | mv data/cars .
45 | rm -r data
46 | fi
47 | echo "Downloading cars completed!"
48 |
49 |
50 | echo "Downloading dSprites dataset."
51 | if [[ ! -d "dsprites" ]]; then
52 | mkdir dsprites
53 | wget -O dsprites/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz https://github.com/deepmind/dsprites-dataset/raw/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz
54 | fi
55 | echo "Downloading dSprites completed!"
56 |
57 |
58 | echo "Downloading shapes3d dataset."
59 | if [[ ! -d "shapes3d" ]]; then
60 |   mkdir shapes3d
61 |   wget -O shapes3d/3dshapes.h5 https://storage.googleapis.com/3d-shapes/3dshapes.h5
62 | fi
64 | echo "Downloading shapes3d completed!"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.3.3
2 | numpy==1.18.5
3 | scipy==1.4.1
4 | tensorflow==2.3.2
5 | scikit-learn==0.24.0
6 | Pillow==8.1.0
7 | pandas==1.2.0
8 | h5py==2.10.0
9 | PyYAML==5.3.1
11 | setuptools==50.3.2
12 | tensorflow-probability==0.11.1
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | with open("README.md", "r") as f:
4 | long_description = f.read()
5 |
6 | version = '0.1.0'
7 | setup(
8 | name='concepts_xai',
9 | version=version,
10 | packages=find_packages(),
11 | description='Concept Extraction Comparison',
12 | long_description=long_description,
13 | long_description_content_type="text/markdown",
14 | python_requires='>=3.6',
15 | classifiers=[
16 | "Programming Language :: Python :: 3",
17 | "License :: OSI Approved :: MIT License",
18 | "Operating System :: OS Independent",
19 | ],
20 | )
21 |
22 | '''
23 | VERSIONING:
24 |
25 | 1.2.0.dev1   # Development release
26 | 1.2.0a1      # Alpha release
27 | 1.2.0b1      # Beta release
28 | 1.2.0rc1     # Release candidate
29 | 1.2.0        # Final release
30 | 1.2.0.post1  # Post release
31 | 15.10        # Date-based release
32 | 23           # Serial release
33 |
34 | Increment the MAJOR version for incompatible API changes, the MINOR
35 | version for backwards-compatible feature additions, and the MAINTENANCE
36 | version for backwards-compatible bug fixes.
37 | '''
44 |
--------------------------------------------------------------------------------