├── .gitignore ├── LICENSE ├── README.md ├── figures └── concept_learner.png └── src ├── args.py ├── data_clevr_hans.py ├── docker ├── Dockerfile └── requirements.txt ├── logs └── slot-attention-clevr-state-3_final ├── model.py ├── pretrain-slot-attention ├── README.md ├── data.py ├── logs │ └── slot-attention-clevr-state-3 ├── model.py ├── preprocess-images.py ├── scripts │ ├── clevr-slot-attention.sh │ └── clevr_preprocess.sh ├── train.py └── utils.py ├── runs └── CLEVR-Hans3 │ └── concept-learner-0-CLEVR-Hans3_seed0 │ ├── args.txt │ └── events.out.tfevents.1615486383.ml-meteor.local.26896.0 ├── scripts └── clevr-hans-concept-learner_CLEVR_Hans3.sh ├── train_nesy_concept_learner_clevr_hans.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | *.DS_Store
132 |
133 | src/runs/*
134 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 ml-research@TUDarmstadt
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neuro-Symbolic Concept Learner based on Slot Attention and Set Transformer
2 |
3 | This is the official repository for the Neuro-Symbolic Concept Learner introduced in
4 | [Right for the Right Concept: Revising Neuro-Symbolic Concepts by Interacting
5 | with their Explanations](https://arxiv.org/pdf/2011.12854.pdf) by Wolfgang Stammer, Patrick Schramowski,
6 | Kristian Kersting, to be published at CVPR 2021.
7 |
8 | ![Concept Learner with NeSy XIL](./figures/concept_learner.png)
9 |
10 | This repository contains the model source code for the Neuro-Symbolic Concept Learner together with a script for training the
11 | Concept Learner on the [CLEVR-Hans3](https://github.com/ml-research/CLEVR-Hans) data set as a minimal example of how to
12 | use the model. As in the original paper, the concept embedding module (Set Prediction Network with Slot Attention) was
13 | pretrained on the original CLEVR data.
14 |
15 | Files for pre-training yourself can be found in ```src/pretrain-slot-attention/```
16 | (follow the instructions in the corresponding README).
17 |
18 | Please visit the [NeSy XIL](https://github.com/ml-research/NeSyXIL) repository for the Neuro-Symbolic Explanatory
19 | Interactive Learning approach based on this Concept Learner to see further examples from the original paper.
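For a first orientation, the following is a minimal usage sketch of the Concept Learner defined in ```src/model.py``` (adapted from the ```__main__``` block of that file; it assumes a CUDA device, uses the attribute-group boundaries from ```src/data_clevr_hans.py```, and feeds random tensors as stand-ins for preprocessed CLEVR-Hans images):

```python
import torch
from model import NeSyConceptLearner  # run from within src/

# stand-in batch of two 128x128 RGB images, rescaled to [-1, 1] as in data_clevr_hans.py
imgs = (torch.rand(2, 3, 128, 128) * 2. - 1.).cuda()

net = NeSyConceptLearner(n_classes=3, n_slots=10, n_iters=3, n_attr=18,
                         n_set_heads=4, set_transf_hidden=128,
                         category_ids=[3, 6, 8, 10, 18], device='cuda').cuda()

# cls: [batch, n_classes] prediction of the Set Transformer reasoning module
# symbolic: [batch, n_slots, n_attr] binarized output of the concept embedding module
cls, symbolic = net(imgs)
```

Note that this sketch uses a randomly initialized concept embedding module; for the actual experiments this module is initialized with the pretrained slot-attention weights (see ```src/pretrain-slot-attention/``` above).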
20 |
21 | ## How to Run with Docker on GPU
22 |
23 | ### Dataset
24 |
25 | First download the CLEVR-Hans3 data set. Please visit the [CLEVR-Hans](https://github.com/ml-research/CLEVR-Hans)
26 | repository for instructions on this.
27 |
28 | ### Docker
29 |
30 | To run the example train script with the CLEVR-Hans3 data, follow these steps:
31 |
32 | 1. ```cd src/docker/```
33 |
34 | 2. ```docker build -t nesy-concept-learner -f Dockerfile .```
35 |
36 | 3. ```docker run -it -v /pathto/NeSyConceptLearner:/workspace/repositories/NeSyConceptLearner -v /pathto/CLEVR-Hans3:/workspace/datasets/CLEVR-Hans3 --name nesy-concept-learner --entrypoint='/bin/bash' --runtime nvidia nesy-concept-learner```
37 |
38 | 4. ```cd repositories/NeSyConceptLearner/src/```
39 |
40 | 5. ```./scripts/clevr-hans-concept-learner_CLEVR_Hans3.sh 0 0 /workspace/datasets/CLEVR-Hans3/``` to run on GPU 0
41 | with run number 0 (the run number is used for naming the stored results)
42 |
43 | ## Citation
44 | If you find this code useful in your research, please consider citing:
45 |
46 | > @article{stammer2020right,
47 | title={Right for the Right Concept: Revising Neuro-Symbolic Concepts by Interacting with their Explanations},
48 | author={Stammer, Wolfgang and Schramowski, Patrick and Kersting, Kristian},
49 | journal={arXiv preprint arXiv:2011.12854},
50 | year={2020}
51 | }
52 |
--------------------------------------------------------------------------------
/figures/concept_learner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-research/NeSyConceptLearner/66dddba1e879359dcefd6685ca3e405c17369c8d/figures/concept_learner.png
--------------------------------------------------------------------------------
/src/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from datetime import datetime
4 |
5 | import utils as utils
6 |
7 | def get_args():
8 |     parser = argparse.ArgumentParser()
9 |     # generic params
10 |     parser.add_argument(
11 |         "--name",
12 |         default=datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
13 |         help="Name to store the log file as",
14 |     )
15 |     parser.add_argument("--mode", type=str, required=True, help="train, test, or plot")
16 |     parser.add_argument("--resume", help="Path to log file to resume from")
17 |
18 |     parser.add_argument(
19 |         "--seed", type=int, default=10, help="Random generator seed for all frameworks"
20 |     )
21 |     parser.add_argument(
22 |         "--epochs", type=int, default=10, help="Number of epochs to train with"
23 |     )
24 |     parser.add_argument(
25 |         "--lr", type=float, default=1e-2, help="Outer learning rate of model"
26 |     )
27 |     parser.add_argument(
28 |         "--l2_grads", type=float, default=1, help="Right for right reason weight"
29 |     )
30 |     parser.add_argument(
31 |         "--batch-size", type=int, default=32, help="Batch size to train with"
32 |     )
33 |     parser.add_argument(
34 |         "--num-workers", type=int, default=4, help="Number of threads for data loader"
35 |     )
36 |     parser.add_argument(
37 |         "--dataset",
38 |         choices=["clevr-hans-state"],
39 |     )
40 |     parser.add_argument(
41 |         "--no-cuda",
42 |         action="store_true",
43 |         help="Run on CPU instead of GPU (not recommended)",
44 |     )
45 |     parser.add_argument(
46 |         "--train-only", action="store_true", help="Only run training, no evaluation"
47 |     )
48 |     parser.add_argument(
49 |         "--eval-only", action="store_true", help="Only run evaluation, no training"
50 |     )
51 |     parser.add_argument("--multi-gpu", action="store_true", help="Use multiple GPUs")
52 |
53 |
parser.add_argument("--data-dir", type=str, help="Directory to data") 54 | parser.add_argument("--fp-ckpt", type=str, default=None, help="checkpoint filepath") 55 | 56 | # Slot attention params 57 | parser.add_argument('--n-slots', default=10, type=int, 58 | help='number of slots for slot attention module') 59 | parser.add_argument('--n-iters-slot-att', default=3, type=int, 60 | help='number of iterations in slot attention module') 61 | parser.add_argument('--n-attr', default=18, type=int, 62 | help='number of attributes per object') 63 | 64 | args = parser.parse_args() 65 | 66 | # hard set !!!!!!!!!!!!!!!!!!!!!!!!! 67 | args.n_heads = 4 68 | args.set_transf_hidden = 128 69 | 70 | assert args.data_dir.endswith(os.path.sep) 71 | args.conf_version = args.data_dir.split(os.path.sep)[-2] 72 | args.name = args.name + f"-{args.conf_version}" 73 | 74 | if args.mode == 'test' or args.mode == 'plot': 75 | assert args.fp_ckpt 76 | 77 | if args.no_cuda: 78 | args.device = 'cpu' 79 | else: 80 | args.device = 'cuda' 81 | 82 | utils.seed_everything(args.seed) 83 | 84 | return args 85 | -------------------------------------------------------------------------------- /src/data_clevr_hans.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import torch.utils.data 6 | import torchvision.transforms as transforms 7 | from torchvision.datasets.folder import pil_loader 8 | import numpy as np 9 | 10 | from pycocotools import mask as coco_mask 11 | 12 | os.environ["MKL_NUM_THREADS"] = "6" 13 | os.environ["NUMEXPR_NUM_THREADS"] = "6" 14 | os.environ["OMP_NUM_THREADS"] = "6" 15 | torch.set_num_threads(6) 16 | 17 | def get_loader(dataset, batch_size, num_workers=8, shuffle=True): 18 | return torch.utils.data.DataLoader( 19 | dataset, 20 | shuffle=shuffle, 21 | batch_size=batch_size, 22 | pin_memory=True, 23 | num_workers=num_workers, 24 | drop_last=True, 25 | ) 26 | 27 | 28 | CLASSES = { 29 | "shape": ["sphere", "cube", "cylinder"], 30 | "size": ["large", "small"], 31 | "material": ["rubber", "metal"], 32 | "color": ["cyan", "blue", "yellow", "purple", "red", "green", "gray", "brown"], 33 | } 34 | 35 | 36 | class CLEVR_HANS_EXPL(torch.utils.data.Dataset): 37 | def __init__(self, base_path, split, lexi=False, conf_vers='conf_2'): 38 | assert split in { 39 | "train", 40 | "val", 41 | "test", 42 | } 43 | self.lexi = lexi 44 | self.base_path = base_path 45 | self.split = split 46 | self.max_objects = 10 47 | self.conf_vers = conf_vers 48 | 49 | with open(self.scenes_path) as fd: 50 | scenes = json.load(fd)["scenes"] 51 | self.img_ids, self.img_class_ids, self.scenes, self.fnames, self.gt_img_expls, self.gt_table_expls = \ 52 | self.prepare_scenes(scenes) 53 | 54 | self.transform = transforms.Compose( 55 | [transforms.Resize((128, 128)), 56 | transforms.ToTensor()] 57 | ) 58 | # self.transform_img_expl = transforms.Compose( 59 | # [transforms.ToPILImage(mode='L'), 60 | # transforms.Resize((14, 14), interpolation=5), 61 | # transforms.ToTensor()] 62 | # ) 63 | 64 | self.n_classes = len(np.unique(self.img_class_ids)) 65 | self.category_dict = CLASSES 66 | 67 | # get ids of category ranges, i.e. 
shape has three categories from ids 0 to 2 68 | self.category_ids = np.array([3, 6, 8, 10, 18]) 69 | 70 | def convert_coords(self, obj, scene_directions): 71 | # convert the 3d coords based on camera position 72 | # conversion from ns-vqa paper, normalization for slot attention 73 | position = [np.dot(obj['3d_coords'], scene_directions['right']), 74 | np.dot(obj['3d_coords'], scene_directions['front']), 75 | obj['3d_coords'][2]] 76 | coords = [(p +4.)/ 8. for p in position] 77 | return coords 78 | 79 | def object_to_fv(self, obj, scene_directions): 80 | coords = self.convert_coords(obj, scene_directions) 81 | one_hot = lambda key: [obj[key] == x for x in CLASSES[key]] 82 | material = one_hot("material") 83 | color = one_hot("color") 84 | shape = one_hot("shape") 85 | size = one_hot("size") 86 | assert sum(material) == 1 87 | assert sum(color) == 1 88 | assert sum(shape) == 1 89 | assert sum(size) == 1 90 | # concatenate all the classes 91 | return coords + shape + size + material + color 92 | 93 | def prepare_scenes(self, scenes_json): 94 | img_ids = [] 95 | scenes = [] 96 | gt_img_expls = [] 97 | img_class_ids = [] 98 | gt_table_expls = [] 99 | fnames = [] 100 | for scene in scenes_json: 101 | fnames.append(os.path.join(self.images_folder, scene['image_filename'])) 102 | img_class_ids.append(scene['class_id']) 103 | img_idx = scene["image_index"] 104 | 105 | objects = [self.object_to_fv(obj, scene['directions']) for obj in scene["objects"]] 106 | objects = torch.FloatTensor(objects).transpose(0, 1) 107 | 108 | # get gt image explanation based on the classification rule of the class label 109 | gt_img_expl_mask = self.get_img_expl_mask(scene) 110 | gt_img_expls.append(gt_img_expl_mask) 111 | 112 | num_objects = objects.size(1) 113 | # pad with 0s 114 | if num_objects < self.max_objects: 115 | objects = torch.cat( 116 | [ 117 | objects, 118 | torch.zeros(objects.size(0), self.max_objects - num_objects), 119 | ], 120 | dim=1, 121 | ) 122 | 123 | # get gt table explanation based on the classification rule of the class label 124 | gt_table_expl_mask = self.get_table_expl_mask(objects, scene['class_id']) 125 | gt_table_expls.append(gt_table_expl_mask) 126 | 127 | # fill in masks 128 | mask = torch.zeros(self.max_objects) 129 | mask[:num_objects] = 1 130 | 131 | # concatenate obj indication to end of object list 132 | objects = torch.cat((mask.unsqueeze(dim=0), objects), dim=0) 133 | 134 | img_ids.append(img_idx) 135 | scenes.append(objects.T) 136 | return img_ids, img_class_ids, scenes, fnames, gt_img_expls, gt_table_expls 137 | 138 | def get_img_expl_mask(self, scene): 139 | class_id = scene['class_id'] 140 | 141 | mask = 0 142 | if self.conf_vers == 'CLEVR-Hans3': 143 | for obj in scene['objects']: 144 | if class_id == 0: 145 | if (obj['shape'] == 'cube' and obj['size'] == 'large') or \ 146 | (obj['shape'] == 'cylinder' and obj['size'] == 'large'): 147 | rle = obj['mask'] 148 | mask += coco_mask.decode(rle) 149 | elif class_id == 1: 150 | if (obj['shape'] == 'cube' and obj['size'] == 'small' and obj['material'] == 'metal') or \ 151 | (obj['shape'] == 'sphere' and obj['size'] == 'small'): 152 | rle = obj['mask'] 153 | mask += coco_mask.decode(rle) 154 | elif class_id == 2: 155 | if (obj['shape'] == 'sphere' and obj['size'] == 'large' and obj['color'] == 'blue') or \ 156 | (obj['shape'] == 'sphere' and obj['size'] == 'small' and obj['color'] == 'yellow'): 157 | rle = obj['mask'] 158 | mask += coco_mask.decode(rle) 159 | elif self.conf_vers == 'CLEVR-Hans7': 160 | for obj in scene['objects']: 
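# CLEVR-Hans7 ground-truth image explanations: classes 0, 1 and 6 reuse the CLEVR-Hans3 rules above,
# class 3 marks all small objects, and the remaining classes additionally depend on the converted
# object coordinates (cf. the corresponding rules in get_table_expl_mask below)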
161 | if class_id == 0: 162 | if (obj['shape'] == 'cube' and obj['size'] == 'large') or \ 163 | (obj['shape'] == 'cylinder' and obj['size'] == 'large'): 164 | rle = obj['mask'] 165 | mask += coco_mask.decode(rle) 166 | elif class_id == 1: 167 | if (obj['shape'] == 'cube' and obj['size'] == 'small' and obj['material'] == 'metal') or \ 168 | (obj['shape'] == 'sphere' and obj['size'] == 'small'): 169 | rle = obj['mask'] 170 | mask += coco_mask.decode(rle) 171 | if class_id == 2: 172 | # get y coord of red and cyan objects 173 | objects = [self.object_to_fv(obj, scene['directions']) for obj in scene["objects"]] 174 | y_red = [obj[1] for obj in objects if obj[14] == 1] 175 | y_cyan = [obj[1] for obj in objects if obj[10] == 1] 176 | obj_coords = self.convert_coords(obj, scene['directions']) 177 | if (obj['color'] == 'cyan' and sum(obj_coords[1] > y_red) >= 2) or \ 178 | (obj['color'] == 'red' and sum(obj_coords[1] < y_cyan) >= 1): 179 | rle = obj['mask'] 180 | mask += coco_mask.decode(rle) 181 | elif class_id == 3: 182 | if (obj['size'] == 'small'): 183 | rle = obj['mask'] 184 | mask += coco_mask.decode(rle) 185 | elif class_id == 4: 186 | obj_coords = self.convert_coords(obj, scene['directions']) 187 | if (obj['shape'] == 'sphere' and obj_coords[0] < 0.5 or 188 | obj['shape'] == 'cylinder' and obj['material'] == 'metal' and obj_coords[0] > 0.5): 189 | rle = obj['mask'] 190 | mask += coco_mask.decode(rle) 191 | elif class_id == 4: 192 | obj_coords = self.convert_coords(obj, scene['directions']) 193 | if (obj['shape'] == 'cylinder' and obj['material'] == 'metal' and obj_coords[0] > 0.5): 194 | rle = obj['mask'] 195 | mask += coco_mask.decode(rle) 196 | elif class_id == 6: 197 | if (obj['shape'] == 'sphere' and obj['size'] == 'large' and obj['color'] == 'blue') or \ 198 | (obj['shape'] == 'sphere' and obj['size'] == 'small' and obj['color'] == 'yellow'): 199 | rle = obj['mask'] 200 | mask += coco_mask.decode(rle) 201 | 202 | return torch.tensor(mask) * 255 # for PIL 203 | 204 | def get_table_expl_mask(self, objects, class_id): 205 | objects = objects.T 206 | 207 | mask = torch.zeros(objects.shape) 208 | 209 | if self.conf_vers == 'CLEVR-Hans3': 210 | for i, obj in enumerate(objects): 211 | if class_id == 0: 212 | # if cube and large 213 | if (obj[3:8] == torch.tensor([0, 1, 0, 1, 0])).all(): 214 | mask[i, 3:8] = torch.tensor([0, 1, 0, 1, 0]) 215 | # or cylinder and large 216 | elif (obj[3:8] == torch.tensor([0, 0, 1, 1, 0])).all(): 217 | mask[i, 3:8] = torch.tensor([0, 0, 1, 1, 0]) 218 | elif class_id == 1: 219 | # if cube, small, metal 220 | if (obj[3:10] == torch.tensor([0, 1, 0, 0, 1, 0, 1])).all(): 221 | mask[i, 3:10] = torch.tensor([0, 1, 0, 0, 1, 0, 1]) 222 | # or sphere, small 223 | elif (obj[3:8] == torch.tensor([1, 0, 0, 0, 1])).all(): 224 | mask[i, 3:8] = torch.tensor([1, 0, 0, 0, 1]) 225 | elif class_id == 2: 226 | # if sphere large blue 227 | if ((obj[3:8] == torch.tensor([1, 0, 0, 1, 0])).all() 228 | and (obj[10:] == torch.tensor([0, 1, 0, 0, 0, 0, 0, 0])).all()).all(): 229 | mask[i, 3:8] = torch.tensor([1, 0, 0, 1, 0]) 230 | mask[i, 10:] = torch.tensor([0, 1, 0, 0, 0, 0, 0, 0]) 231 | # or sphere small yellow 232 | elif ((obj[3:8] == torch.tensor([1, 0, 0, 0, 1])).all() 233 | and (obj[10:] == torch.tensor([0, 0, 1, 0, 0, 0, 0, 0])).all()).all(): 234 | mask[i, 3:8] = torch.tensor([1, 0, 0, 0, 1]) 235 | mask[i, 10:] = torch.tensor([0, 0, 1, 0, 0, 0, 0, 0]) 236 | elif self.conf_vers == 'CLEVR-Hans7': 237 | for i, obj in enumerate(objects): 238 | if class_id == 0: 239 | # if cube and 
large 240 | if (obj[3:8] == torch.tensor([0, 1, 0, 1, 0])).all(): 241 | mask[i, 3:8] = torch.tensor([0, 1, 0, 1, 0]) 242 | # or cylinder and large 243 | elif (obj[3:8] == torch.tensor([0, 0, 1, 1, 0])).all(): 244 | mask[i, 3:8] = torch.tensor([0, 0, 1, 1, 0]) 245 | elif class_id == 1: 246 | # if cube, small, metal 247 | if (obj[3:10] == torch.tensor([0, 1, 0, 0, 1, 0, 1])).all(): 248 | mask[i, 3:10] = torch.tensor([0, 1, 0, 0, 1, 0, 1]) 249 | # or sphere, small 250 | elif (obj[3:8] == torch.tensor([1, 0, 0, 0, 1])).all(): 251 | mask[i, 3:8] = torch.tensor([1, 0, 0, 0, 1]) 252 | elif class_id == 2: 253 | # get maximal y coord of red objects 254 | y_red = objects[objects[:, 14] == 1, 1] 255 | y_cyan = objects[objects[:, 10] == 1, 1] 256 | # if cyan object and y coord greater than that of at least 2 red objs, i.e. in front of red objs 257 | if ((obj[10:] == torch.tensor([1, 0, 0, 0, 0, 0, 0, 0])).all() 258 | and (sum(obj[1] > y_red) >= 2)).all(): 259 | mask[i, 10:] = torch.tensor([1, 0, 0, 0, 0, 0, 0, 0]) 260 | mask[i, 1] = torch.tensor([1]) 261 | # or red obj 262 | elif ((obj[10:] == torch.tensor([0, 0, 0, 0, 1, 0, 0, 0])).all() 263 | and (sum(obj[1] < y_cyan) >= 1)).all(): 264 | mask[i, 10:] = torch.tensor([0, 0, 0, 0, 1, 0, 0, 0]) 265 | mask[i, 1] = torch.tensor([1]) 266 | elif class_id == 3: 267 | # if small and brown 268 | if ((obj[6:8] == torch.tensor([0, 1])).all() 269 | and (obj[10:] == torch.tensor([0, 0, 0, 0, 0, 0, 0, 1])).all()).all(): 270 | mask[i, 6:8] = torch.tensor([0, 1]) 271 | mask[i, 10:] = torch.tensor([0, 0, 0, 0, 0, 0, 0, 1]) 272 | # if small and green 273 | elif ((obj[6:8] == torch.tensor([0, 1])).all() 274 | and (obj[10:] == torch.tensor([0, 0, 0, 0, 0, 1, 0, 0])).all()).all(): 275 | mask[i, 6:8] = torch.tensor([0, 1]) 276 | mask[i, 10:] = torch.tensor([0, 0, 0, 0, 0, 1, 0, 0]) 277 | # if small and purple 278 | elif ((obj[6:8] == torch.tensor([0, 1])).all() 279 | and (obj[10:] == torch.tensor([0, 0, 0, 1, 0, 0, 0, 0])).all()).all(): 280 | mask[i, 6:8] = torch.tensor([0, 1]) 281 | mask[i, 10:] = torch.tensor([0, 0, 0, 1, 0, 0, 0, 0]) 282 | # elif small 283 | elif (obj[6:8] == torch.tensor([0, 1])).all(): 284 | mask[i, 6:8] = torch.tensor([0, 1]) 285 | elif class_id == 4: 286 | # if at least 3 metal cylinders on right side 287 | if sum((objects[:, 5] == 1) & (objects[:, 9] == 1) & (objects[:, 0] > 0.5)) >= 3: 288 | # if sphere and on left side 289 | if ((obj[3:6] == torch.tensor([1, 0, 0])).all() 290 | and obj[0] < 0.5).all(): 291 | mask[i, 3:6] = torch.tensor([1, 0, 0]) 292 | mask[i, 0] = torch.tensor([1]) 293 | # if metal cyl. 
and on right side 294 | elif ((obj[3:6] == torch.tensor([0, 0, 1])).all() 295 | and (obj[8:10] == torch.tensor([0, 1])).all() 296 | and obj[0] > 0.5).all(): 297 | mask[i, 3:6] = torch.tensor([0, 0, 1]) 298 | mask[i, 8:10] = torch.tensor([0, 1]) 299 | mask[i, 0] = torch.tensor([1]) 300 | # if sphere and on left side 301 | elif ((obj[3:6] == torch.tensor([1, 0, 0])).all() 302 | and obj[0] < 0.5).all(): 303 | mask[i, 3:6] = torch.tensor([1, 0, 0]) 304 | mask[i, 0] = torch.tensor([1]) 305 | elif class_id == 5: 306 | # if metal cylinder and on right side 307 | if ((obj[3:6] == torch.tensor([0, 0, 1])).all() 308 | and (obj[8:10] == torch.tensor([0, 1])).all() 309 | and obj[0] > 0.5).all(): 310 | mask[i, 3:6] = torch.tensor([0, 0, 1]) 311 | mask[i, 8:10] = torch.tensor([0, 1]) 312 | mask[i, 0] = torch.tensor([1]) 313 | elif class_id == 6: 314 | # if sphere large blue 315 | if ((obj[3:8] == torch.tensor([1, 0, 0, 1, 0])).all() 316 | and (obj[10:] == torch.tensor([0, 1, 0, 0, 0, 0, 0, 0])).all()).all(): 317 | mask[i, 3:8] = torch.tensor([1, 0, 0, 1, 0]) 318 | mask[i, 10:] = torch.tensor([0, 1, 0, 0, 0, 0, 0, 0]) 319 | # or sphere small yellow 320 | elif ((obj[3:8] == torch.tensor([1, 0, 0, 0, 1])).all() 321 | and (obj[10:] == torch.tensor([0, 0, 1, 0, 0, 0, 0, 0])).all()).all(): 322 | mask[i, 3:8] = torch.tensor([1, 0, 0, 0, 1]) 323 | mask[i, 10:] = torch.tensor([0, 0, 1, 0, 0, 0, 0, 0]) 324 | 325 | return mask 326 | 327 | @property 328 | def images_folder(self): 329 | return os.path.join(self.base_path, self.split, "images") 330 | 331 | @property 332 | def scenes_path(self): 333 | return os.path.join( 334 | self.base_path, self.split, "CLEVR_HANS_scenes_{}.json".format(self.split) 335 | ) 336 | 337 | def __getitem__(self, item): 338 | image_id = self.img_ids[item] 339 | 340 | image = pil_loader(self.fnames[item]) 341 | # TODO: sofar only dummy 342 | img_expl = torch.tensor([0]) 343 | 344 | if self.transform is not None: 345 | image = self.transform(image) # in range [0., 1.] 346 | image = (image - 0.5) * 2.0 # Rescale to [-1, 1]. 
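# the same [-1, 1] rescaling is applied to the CLEVR images used for pretraining the slot-attention
# module (cf. load_image in pretrain-slot-attention/data.py)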
347 | # img_expl = self.transform_img_expl(img_expl) 348 | 349 | objects = self.scenes[item] 350 | table_expl = self.gt_table_expls[item] 351 | img_class_id = self.img_class_ids[item] 352 | 353 | # remove objects presence indicator from gt table 354 | objects = objects[:, 1:] 355 | 356 | return image, objects, img_class_id, image_id, img_expl, table_expl 357 | 358 | def __len__(self): 359 | return len(self.scenes) 360 | 361 | -------------------------------------------------------------------------------- /src/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:20.03-py3 2 | WORKDIR /workspace 3 | COPY requirements.txt requirements.txt 4 | RUN pip install -r requirements.txt 5 | RUN ["apt-get", "update"] 6 | RUN ["apt-get", "install", "-y", "zsh"] 7 | RUN wget https://github.com/robbyrussell/oh-my-zsh/raw/master/tools/install.sh -O - | zsh || true -------------------------------------------------------------------------------- /src/docker/requirements.txt: -------------------------------------------------------------------------------- 1 | setuptools 2 | Cython 3 | numpy 4 | pandas 5 | scipy 6 | statsmodels 7 | scikit-learn>=0.20.0 8 | tensorboard_logger 9 | tqdm 10 | torchsummary 11 | seaborn 12 | captum 13 | rtpt 14 | h5py -------------------------------------------------------------------------------- /src/logs/slot-attention-clevr-state-3_final: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyConceptLearner/66dddba1e879359dcefd6685ca3e405c17369c8d/src/logs/slot-attention-clevr-state-3_final -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import math 5 | from torch import nn 6 | 7 | class SlotAttention(nn.Module): 8 | """ 9 | Implementation from https://github.com/lucidrains/slot-attention by lucidrains. 
10 | """ 11 | def __init__(self, num_slots, dim, iters=3, eps=1e-8, hidden_dim=128): 12 | super().__init__() 13 | self.num_slots = num_slots 14 | self.iters = iters 15 | self.eps = eps 16 | self.scale = dim ** -0.5 17 | 18 | self.slots_mu = nn.Parameter(torch.randn(1, 1, dim)) 19 | # self.slots_log_sigma = nn.Parameter(torch.randn(1, 1, dim)) 20 | self.slots_log_sigma = nn.Parameter(torch.randn(1, 1, dim)).abs().to(device='cuda') 21 | 22 | self.project_q = nn.Linear(dim, dim) 23 | self.project_k = nn.Linear(dim, dim) 24 | self.project_v = nn.Linear(dim, dim) 25 | 26 | self.gru = nn.GRUCell(dim, dim) 27 | 28 | hidden_dim = max(dim, hidden_dim) 29 | 30 | self.mlp = nn.Sequential( 31 | nn.Linear(dim, hidden_dim), 32 | nn.ReLU(inplace=True), 33 | nn.Linear(hidden_dim, dim) 34 | ) 35 | 36 | self.norm_inputs = nn.LayerNorm(dim, eps=1e-05) 37 | self.norm_slots = nn.LayerNorm(dim, eps=1e-05) 38 | self.norm_mlp = nn.LayerNorm(dim, eps=1e-05) 39 | 40 | # dummy initialisation 41 | self.attn = 0 42 | 43 | def forward(self, inputs, num_slots=None): 44 | b, n, d = inputs.shape 45 | n_s = num_slots if num_slots is not None else self.num_slots 46 | 47 | mu = self.slots_mu.expand(b, n_s, -1) 48 | sigma = self.slots_log_sigma.expand(b, n_s, -1) 49 | slots = torch.normal(mu, sigma) 50 | 51 | inputs = self.norm_inputs(inputs) 52 | k, v = self.project_k(inputs), self.project_v(inputs) 53 | 54 | for _ in range(self.iters): 55 | slots_prev = slots 56 | 57 | slots = self.norm_slots(slots) 58 | q = self.project_q(slots) 59 | 60 | dots = torch.einsum('bid,bjd->bij', q, k) * self.scale 61 | attn = dots.softmax(dim=1) + self.eps 62 | attn = attn / attn.sum(dim=-1, keepdim=True) 63 | 64 | updates = torch.einsum('bjd,bij->bid', v, attn) 65 | 66 | slots = self.gru( 67 | updates.reshape(-1, d), 68 | slots_prev.reshape(-1, d) 69 | ) 70 | 71 | slots = slots.reshape(b, -1, d) 72 | slots = slots + self.mlp(self.norm_mlp(slots)) 73 | 74 | self.attn = attn 75 | 76 | return slots 77 | 78 | 79 | class SlotAttention_encoder(nn.Module): 80 | """ 81 | Slot attention encoder for CLEVR as in Locatello et al. 2020 according to the set prediction architecture. 82 | """ 83 | def __init__(self, in_channels, hidden_channels): 84 | """ 85 | Builds the Encoder for the set prediction architecture 86 | :param in_channels: Integer, input channel dimensions to encoder 87 | :param hidden_channels: Integer, hidden channel dimensions within encoder 88 | """ 89 | super(SlotAttention_encoder, self).__init__() 90 | self.network = nn.Sequential( 91 | nn.Conv2d(in_channels, hidden_channels, (5, 5), stride=(1, 1), padding=2), 92 | nn.ReLU(inplace=True), 93 | nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(2, 2), padding=2), 94 | nn.ReLU(inplace=True), 95 | nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(2, 2), padding=2), 96 | nn.ReLU(inplace=True), 97 | nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(1, 1), padding=2), 98 | nn.ReLU(inplace=True), 99 | ) 100 | 101 | def forward(self, x): 102 | return self.network(x) 103 | 104 | 105 | class MLP(nn.Module): 106 | """ 107 | MLP for CLEVR as in Locatello et al. 2020 according to the set prediction architecture. 108 | """ 109 | def __init__(self, hidden_channels): 110 | """ 111 | Builds the MLP 112 | :param hidden_channels: Integer, hidden channel dimensions within encoder, is also equivalent to the input 113 | channel dims here. 
114 | """ 115 | super(MLP, self).__init__() 116 | self.network = nn.Sequential( 117 | nn.Linear(hidden_channels, hidden_channels), 118 | nn.ReLU(inplace=True), 119 | nn.Linear(hidden_channels, hidden_channels), 120 | ) 121 | 122 | def forward(self, x): 123 | return self.network(x) 124 | 125 | 126 | def build_grid(resolution): 127 | """ 128 | Builds the grid for the Posisition Embedding. 129 | :param resolution: Tuple of Ints, in the dimensions of the latent space of the encoder. 130 | :return: 2D Float meshgrid representing th x y position. 131 | """ 132 | ranges = [np.linspace(0., 1., num=res) for res in resolution] 133 | grid = np.meshgrid(*ranges, sparse=False, indexing="ij") 134 | grid = np.stack(grid, axis=-1) 135 | grid = np.reshape(grid, [resolution[0], resolution[1], -1]) 136 | grid = np.expand_dims(grid, axis=0) 137 | grid = grid.astype(np.float32) 138 | return np.concatenate([grid, 1.0 - grid], axis=-1) 139 | 140 | 141 | class SoftPositionEmbed(nn.Module): 142 | """ 143 | Adds soft positional embedding with learnable projection. 144 | """ 145 | def __init__(self, hidden_size, resolution, device="cuda"): 146 | """Builds the soft position embedding layer. 147 | Args: 148 | hidden_size: Size of input feature dimension. 149 | resolution: Tuple of integers specifying width and height of grid. 150 | device: String specifiying the device, cpu or cuda 151 | """ 152 | super().__init__() 153 | self.dense = nn.Linear(4, hidden_size) 154 | self.grid = torch.FloatTensor(build_grid(resolution)) 155 | self.grid = self.grid.to(device) 156 | self.resolution = resolution[0] 157 | self.hidden_size = hidden_size 158 | 159 | def forward(self, inputs): 160 | return inputs + self.dense(self.grid).view((-1, self.hidden_size, self.resolution, self.resolution)) 161 | 162 | 163 | class SlotAttention_classifier(nn.Module): 164 | """ 165 | The classifier of the set prediction architecture of Locatello et al. 2020 166 | """ 167 | def __init__(self, in_channels, out_channels): 168 | """ 169 | Builds the classifier for the set prediction architecture. 170 | :param in_channels: Integer, input channel dimensions 171 | :param out_channels: Integer, output channel dimensions 172 | """ 173 | super(SlotAttention_classifier, self).__init__() 174 | self.network = nn.Sequential( 175 | nn.Linear(in_channels, in_channels), 176 | nn.ReLU(inplace=True), 177 | nn.Linear(in_channels, out_channels), 178 | nn.Sigmoid() 179 | ) 180 | 181 | def forward(self, x): 182 | return self.network(x) 183 | 184 | 185 | class SlotAttention_model(nn.Module): 186 | """ 187 | The set prediction slot attention architecture for CLEVR as in Locatello et al 2020. 188 | """ 189 | def __init__(self, n_slots, n_iters, n_attr, category_ids, 190 | in_channels=3, 191 | encoder_hidden_channels=64, 192 | attention_hidden_channels=128, 193 | device="cuda"): 194 | """ 195 | Builds the set prediction slot attention architecture. 196 | :param n_slots: Integer, number of slots 197 | :param n_iters: Integer, number of attention iterations 198 | :param n_attr: Integer, number of attributes per object to predict 199 | :param category_ids: List of Integers, specifying the boundaries of each attribute group, e.g. 
color 200 | attributes are variables 10 to 17 201 | :param in_channels: Integer, number of input channel dimensions 202 | :param encoder_hidden_channels: Integer, number of encoder hidden channel dimensions 203 | :param attention_hidden_channels: Integer, number of hidden channel dimensions for slot attention module 204 | :param device: String, either 'cpu' or 'cuda' 205 | """ 206 | super(SlotAttention_model, self).__init__() 207 | self.n_slots = n_slots 208 | self.n_iters = n_iters 209 | self.n_attr = n_attr 210 | self.category_ids = category_ids 211 | self.n_attr = n_attr + 1 # additional slot to indicate if it is a object or empty slot 212 | self.device = device 213 | 214 | self.encoder_cnn = SlotAttention_encoder(in_channels=in_channels, hidden_channels=encoder_hidden_channels) 215 | self.encoder_pos = SoftPositionEmbed(encoder_hidden_channels, (32, 32), device=device) 216 | self.layer_norm = nn.LayerNorm(encoder_hidden_channels, eps=1e-05) 217 | self.mlp = MLP(hidden_channels=encoder_hidden_channels) 218 | self.slot_attention = SlotAttention(num_slots=n_slots, dim=encoder_hidden_channels, iters=n_iters, eps=1e-8, 219 | hidden_dim=attention_hidden_channels) 220 | self.mlp_classifier = SlotAttention_classifier(in_channels=encoder_hidden_channels, out_channels=self.n_attr) 221 | 222 | def forward(self, x): 223 | x = self.encoder_cnn(x) 224 | x = self.encoder_pos(x) 225 | x = torch.flatten(x, start_dim=2) 226 | x = x.permute(0, 2, 1) 227 | x = self.layer_norm(x) 228 | x = self.mlp(x) 229 | x = self.slot_attention(x) 230 | x = self.mlp_classifier(x) 231 | return x 232 | 233 | def _transform_attrs(self, attrs): 234 | """ 235 | receives the attribute predictions and binarizes them by computing the argmax per attribute group 236 | :param attrs: 3D Tensor, [batch, n_slots, n_attrs] attribute predictions for a batch. 237 | :return: binarized attribute predictions 238 | """ 239 | presence = attrs[:, :, 0] 240 | attrs_trans = attrs[:, :, 1:] 241 | 242 | # threshold presence prediction, i.e. where is an object predicted 243 | presence = presence < 0.5 244 | 245 | # flatten first two dims 246 | attrs_trans = attrs_trans.view(1, -1, attrs_trans.shape[2]).squeeze() 247 | # binarize attributes 248 | # set argmax per attr to 1, all other to 0, s.t. only zeros and ones are contained within graph 249 | # NOTE: this way it is not differentiable! 250 | bin_attrs = torch.zeros(attrs_trans.shape, device=self.device) 251 | for i in range(len(self.category_ids) - 1): 252 | # find the argmax within each category and set this to one 253 | bin_attrs[range(bin_attrs.shape[0]), 254 | # e.g. x[:, 0:(3+0)], x[:, 3:(5+3)], etc 255 | (attrs_trans[:, 256 | self.category_ids[i]:self.category_ids[i + 1]].argmax(dim=1) + self.category_ids[i]).type( 257 | torch.LongTensor)] = 1 258 | 259 | # reshape back to batch x n_slots x n_attrs 260 | bin_attrs = bin_attrs.view(attrs.shape[0], attrs.shape[1], attrs.shape[2] - 1) 261 | 262 | # add coordinates back 263 | bin_attrs[:, :, :3] = attrs[:, :, 1:4] 264 | 265 | # redo presence zeroing 266 | bin_attrs[presence, :] = 0 267 | 268 | return bin_attrs 269 | 270 | 271 | ############ 272 | # Transformers # 273 | ############ 274 | """ 275 | Code largely from https://github.com/juho-lee/set_transformer by yoonholee. 
276 | """ 277 | class MAB(nn.Module): 278 | def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False): 279 | super(MAB, self).__init__() 280 | self.dim_V = dim_V 281 | self.num_heads = num_heads 282 | self.fc_q = nn.Linear(dim_Q, dim_V) 283 | self.fc_k = nn.Linear(dim_K, dim_V) 284 | self.fc_v = nn.Linear(dim_K, dim_V) 285 | if ln: 286 | self.ln0 = nn.LayerNorm(dim_V) 287 | self.ln1 = nn.LayerNorm(dim_V) 288 | self.fc_o = nn.Linear(dim_V, dim_V) 289 | 290 | def forward(self, Q, K): 291 | Q = self.fc_q(Q) 292 | K, V = self.fc_k(K), self.fc_v(K) 293 | 294 | dim_split = self.dim_V // self.num_heads 295 | Q_ = torch.cat(Q.split(dim_split, 2), 0) 296 | K_ = torch.cat(K.split(dim_split, 2), 0) 297 | V_ = torch.cat(V.split(dim_split, 2), 0) 298 | 299 | A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2) 300 | O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2) 301 | O = O if getattr(self, 'ln0', None) is None else self.ln0(O) 302 | O = O + F.relu(self.fc_o(O)) 303 | O = O if getattr(self, 'ln1', None) is None else self.ln1(O) 304 | return O 305 | 306 | class SAB(nn.Module): 307 | def __init__(self, dim_in, dim_out, num_heads, ln=False): 308 | super(SAB, self).__init__() 309 | self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln) 310 | 311 | def forward(self, X): 312 | return self.mab(X, X) 313 | 314 | class ISAB(nn.Module): 315 | def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False): 316 | super(ISAB, self).__init__() 317 | self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out)) 318 | nn.init.xavier_uniform_(self.I) 319 | self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln) 320 | self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln) 321 | 322 | def forward(self, X): 323 | H = self.mab0(self.I.repeat(X.size(0), 1, 1), X) 324 | return self.mab1(X, H) 325 | 326 | class PMA(nn.Module): 327 | def __init__(self, dim, num_heads, num_seeds, ln=False): 328 | super(PMA, self).__init__() 329 | self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim)) 330 | nn.init.xavier_uniform_(self.S) 331 | self.mab = MAB(dim, dim, dim, num_heads, ln=ln) 332 | 333 | def forward(self, X): 334 | return self.mab(self.S.repeat(X.size(0), 1, 1), X) 335 | 336 | class SetTransformer(nn.Module): 337 | """ 338 | Set Transformer used for the Neuro-Symbolic Concept Learner. 339 | """ 340 | def __init__(self, dim_input=3, dim_output=40, dim_hidden=128, num_heads=4, ln=False): 341 | """ 342 | Builds the Set Transformer. 343 | :param dim_input: Integer, input dimensions 344 | :param dim_output: Integer, output dimensions 345 | :param dim_hidden: Integer, hidden layer dimensions 346 | :param num_heads: Integer, number of attention heads 347 | :param ln: Boolean, whether to use Layer Norm 348 | """ 349 | super(SetTransformer, self).__init__() 350 | self.enc = nn.Sequential( 351 | SAB(dim_in=dim_input, dim_out=dim_hidden, num_heads=num_heads, ln=ln), 352 | SAB(dim_in=dim_hidden, dim_out=dim_hidden, num_heads=num_heads, ln=ln), 353 | ) 354 | self.dec = nn.Sequential( 355 | nn.Dropout(), 356 | PMA(dim=dim_hidden, num_heads=num_heads, num_seeds=1, ln=ln), 357 | nn.Dropout(), 358 | nn.Linear(dim_hidden, dim_output), 359 | ) 360 | 361 | def forward(self, x): 362 | x = self.enc(x) 363 | x = self.dec(x) 364 | return x.squeeze() 365 | 366 | ############ 367 | # NeSyConceptLearner # 368 | ############ 369 | class NeSyConceptLearner(nn.Module): 370 | """ 371 | The Neuro-Symbolic Concept Learner of Stammer et al. 2021 based on Slot Attention and Set Transformer. 
372 |     """
373 |     def __init__(self, n_classes, n_slots=1, n_iters=3, n_attr=18, n_set_heads=4, set_transf_hidden=128,
374 |                  category_ids=[3, 6, 8, 10, 17], device='cuda'):
375 |         """
376 |
377 |         :param n_classes: Integer, number of classes
378 |         :param n_slots: Integer, number of slots for slot attention module
379 |         :param n_iters: Integer, number of attention iterations for slot attentions
380 |         :param n_attr: Integer, number of attributes per object
381 |         :param n_set_heads: Integer, number of attention heads for set transformer
382 |         :param set_transf_hidden: Integer, hidden dim of set transformer
383 |         :param category_ids: List of Integers, specifying the starting ids of the attribute groups for the attribute
384 |         prediction
385 |         :param device: String, either 'cpu' or 'cuda'
386 |         """
387 |         super().__init__()
388 |         self.device = device
389 |         # Concept Embedding Module
390 |         self.img2state_net = SlotAttention_model(n_slots, n_iters, n_attr, encoder_hidden_channels=64,
391 |                                                  attention_hidden_channels=128, category_ids=category_ids,
392 |                                                  device=device)
393 |         # Reasoning module
394 |         self.set_cls = SetTransformer(dim_input=n_attr, dim_hidden=set_transf_hidden, num_heads=n_set_heads,
395 |                                       dim_output=n_classes, ln=True)
396 |
397 |     def forward(self, img):
398 |         """
399 |         Receives an image, passes it through the concept embedding module and the reasoning module. For simplicity, we
400 |         binarize the continuous output of the concept embedding module before passing it to the reasoning module.
401 |         The outputs of both modules are returned, i.e. the final class prediction and the latent symbolic representation.
402 |         :param img: 4D Tensor [batch, channels, width, height]
403 |         :return: Tuple of outputs of both modules, [batch, n_classes] classification/ reasoning module output,
404 |         [batch, n_slots, n_attr] concept embedding module output/ symbolic representation
405 |         """
406 |         attrs = self.img2state_net(img)
407 |         # binarize slot attention output, apart from coordinate output
408 |         attrs_trans = self.img2state_net._transform_attrs(attrs)
409 |         # run through classifier via set transformer
410 |         cls = self.set_cls(attrs_trans)
411 |
412 |         return cls.squeeze(), attrs_trans
413 |
414 |
415 | if __name__ == "__main__":
416 |     x = torch.rand(20, 3, 128, 128)
417 |     net = NeSyConceptLearner(n_classes=3, n_slots=10, n_iters=3, n_attr=18, n_set_heads=4, set_transf_hidden=128,
418 |                              category_ids = [3, 6, 8, 10, 17], device='cpu')
419 |     output = net(x)
420 |
--------------------------------------------------------------------------------
/src/pretrain-slot-attention/README.md:
--------------------------------------------------------------------------------
1 | ## Set Prediction with Slot Attention on CLEVR
2 |
3 | Code for pretraining the set prediction slot attention architecture on the original CLEVR data set. To run this, please
4 | download the original [CLEVR](https://cs.stanford.edu/people/jcjohns/clevr/) data set and preprocess it via the
5 | ```scripts/clevr_preprocess.sh``` shell script as:
6 |
7 | ```./scripts/clevr_preprocess.sh```
8 |
9 | This will download the CLEVR v1 data and preprocess it to create the required .h5 files.
10 |
11 | Then run ```scripts/clevr-slot-attention.sh``` as:
12 |
13 | ```./scripts/clevr-slot-attention.sh 0 13 /path/to/CLEVR/```
14 |
15 | This trains on GPU device 0 with run number 13 (stored as slot-attention-clevr-state-13), using your local path
16 | to the CLEVR directory. Note: also here please use the Docker image as described in the main directory.
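For reference, with these example arguments the wrapper script boils down to a single call of ```train.py``` (see ```scripts/clevr-slot-attention.sh```): ```CUDA_VISIBLE_DEVICES=0 python train.py --data-dir /path/to/CLEVR/ --dataset clevr-state --epochs 2000 --name slot-attention-clevr-state-13 --lr 0.0004 --batch-size 512 --n-slots 10 --n-iters-slot-att 3 --n-attr 18```.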
-------------------------------------------------------------------------------- /src/pretrain-slot-attention/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | import json 5 | 6 | import torch 7 | import torch.utils.data 8 | import torchvision 9 | import torchvision.transforms as transforms 10 | from torchvision.datasets.folder import pil_loader 11 | import torchvision.transforms.functional as T 12 | import h5py 13 | import numpy as np 14 | 15 | 16 | CLASSES = { 17 | "shape": ["sphere", "cube", "cylinder"], 18 | "size": ["large", "small"], 19 | "material": ["rubber", "metal"], 20 | "color": ["cyan", "blue", "yellow", "purple", "red", "green", "gray", "brown"], 21 | } 22 | 23 | 24 | def get_loader(dataset, batch_size, num_workers=8, shuffle=True): 25 | return torch.utils.data.DataLoader( 26 | dataset, 27 | shuffle=shuffle, 28 | batch_size=batch_size, 29 | pin_memory=True, 30 | num_workers=num_workers, 31 | drop_last=True, 32 | ) 33 | 34 | 35 | class CLEVR(torch.utils.data.Dataset): 36 | def __init__(self, base_path, split): 37 | assert split in { 38 | "train", 39 | "val", 40 | "test", 41 | } # note: test isn't very useful since it doesn't have ground-truth scene information 42 | self.base_path = base_path 43 | self.split = split 44 | self.max_objects = 10 45 | 46 | with self.img_db() as db: 47 | ids = db["image_ids"] 48 | self.image_id_to_index = {id: i for i, id in enumerate(ids)} 49 | self.image_db = None 50 | 51 | with open(self.scenes_path) as fd: 52 | scenes = json.load(fd)["scenes"] 53 | self.img_ids, self.scenes = self.prepare_scenes(scenes) 54 | 55 | self.category_dict = CLASSES 56 | 57 | def object_to_fv(self, obj, scene_directions): 58 | # coords = position 59 | # Originally the x, y, z positions are in [-3, 3]. 60 | # We re-normalize them to [0, 1]. 61 | # coords = (obj["3d_coords"] + 3.) / 6. 62 | # from slot attention 63 | # coords = [(p +3.)/ 6. for p in position] 64 | # convert the 3d coords based on camera position 65 | # conversion from ns-vqa paper, normalization for slot attention 66 | position = [np.dot(obj['3d_coords'], scene_directions['right']), 67 | np.dot(obj['3d_coords'], scene_directions['front']), 68 | obj['3d_coords'][2]] 69 | coords = [(p +4.)/ 8. 
for p in position] 70 | 71 | one_hot = lambda key: [obj[key] == x for x in CLASSES[key]] 72 | material = one_hot("material") 73 | color = one_hot("color") 74 | shape = one_hot("shape") 75 | size = one_hot("size") 76 | assert sum(material) == 1 77 | assert sum(color) == 1 78 | assert sum(shape) == 1 79 | assert sum(size) == 1 80 | # concatenate all the classes 81 | # return coords + size + material + shape + color 82 | return coords + shape + size + material + color 83 | 84 | def prepare_scenes(self, scenes_json): 85 | img_ids = [] 86 | scenes = [] 87 | for scene in scenes_json: 88 | img_idx = scene["image_index"] 89 | # different objects depending on bbox version or attribute version of CLEVR sets 90 | objects = [self.object_to_fv(obj, scene['directions']) for obj in scene["objects"]] 91 | objects = torch.FloatTensor(objects).transpose(0, 1) 92 | num_objects = objects.size(1) 93 | # pad with 0s 94 | if num_objects < self.max_objects: 95 | objects = torch.cat( 96 | [ 97 | objects, 98 | torch.zeros(objects.size(0), self.max_objects - num_objects), 99 | ], 100 | dim=1, 101 | ) 102 | # fill in masks 103 | mask = torch.zeros(self.max_objects) 104 | mask[:num_objects] = 1 105 | 106 | # concatenate obj indication to end of object list 107 | objects = torch.cat((mask.unsqueeze(dim=0), objects), dim=0) 108 | 109 | img_ids.append(img_idx) 110 | # scenes.append((objects, mask)) 111 | scenes.append(objects.T) 112 | return img_ids, scenes 113 | 114 | @property 115 | def images_folder(self): 116 | return os.path.join(self.base_path, "images", self.split) 117 | 118 | @property 119 | def scenes_path(self): 120 | if self.split == "test": 121 | raise ValueError("Scenes are not available for test") 122 | return os.path.join( 123 | self.base_path, "scenes", "CLEVR_{}_scenes.json".format(self.split) 124 | ) 125 | 126 | def img_db(self): 127 | path = os.path.join(self.base_path, "{}-images.h5".format(self.split)) 128 | return h5py.File(path, "r") 129 | 130 | def load_image(self, image_id): 131 | if self.image_db is None: 132 | self.image_db = self.img_db() 133 | index = self.image_id_to_index[image_id] 134 | image = self.image_db["images"][index] 135 | image = (image - 0.5) * 2.0 # Rescale to [-1, 1]. 136 | return image 137 | 138 | def __getitem__(self, item): 139 | image_id = self.img_ids[item] 140 | image = self.load_image(image_id) 141 | objects = self.scenes[item] 142 | return image, objects 143 | 144 | def __len__(self): 145 | return len(self.scenes) 146 | 147 | 148 | -------------------------------------------------------------------------------- /src/pretrain-slot-attention/logs/slot-attention-clevr-state-3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyConceptLearner/66dddba1e879359dcefd6685ca3e405c17369c8d/src/pretrain-slot-attention/logs/slot-attention-clevr-state-3 -------------------------------------------------------------------------------- /src/pretrain-slot-attention/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Slot attention model based on code of tkipf and the corresponding paper Locatello et al. 
2020 3 | """ 4 | from torch import nn 5 | import torch 6 | import torch.nn.functional as F 7 | import torchvision.models as models 8 | import numpy as np 9 | from torchsummary import summary 10 | 11 | class SlotAttention(nn.Module): 12 | def __init__(self, num_slots, dim, iters=3, eps=1e-8, hidden_dim=128): 13 | super().__init__() 14 | self.num_slots = num_slots 15 | self.iters = iters 16 | self.eps = eps 17 | self.scale = dim ** -0.5 18 | 19 | self.slots_mu = nn.Parameter(torch.randn(1, 1, dim)) 20 | self.slots_log_sigma = nn.Parameter(torch.randn(1, 1, dim)) 21 | 22 | self.project_q = nn.Linear(dim, dim) 23 | self.project_k = nn.Linear(dim, dim) 24 | self.project_v = nn.Linear(dim, dim) 25 | 26 | self.gru = nn.GRUCell(dim, dim) 27 | 28 | hidden_dim = max(dim, hidden_dim) 29 | 30 | self.mlp = nn.Sequential( 31 | nn.Linear(dim, hidden_dim), 32 | nn.ReLU(inplace=True), 33 | nn.Linear(hidden_dim, dim) 34 | ) 35 | 36 | self.norm_inputs = nn.LayerNorm(dim, eps=1e-05) 37 | self.norm_slots = nn.LayerNorm(dim, eps=1e-05) 38 | self.norm_mlp = nn.LayerNorm(dim, eps=1e-05) 39 | 40 | def forward(self, inputs, num_slots=None): 41 | b, n, d = inputs.shape 42 | n_s = num_slots if num_slots is not None else self.num_slots 43 | 44 | mu = self.slots_mu.expand(b, n_s, -1) 45 | sigma = self.slots_log_sigma.expand(b, n_s, -1) 46 | slots = torch.normal(mu, sigma) 47 | 48 | inputs = self.norm_inputs(inputs) 49 | k, v = self.project_k(inputs), self.project_v(inputs) 50 | 51 | for _ in range(self.iters): 52 | slots_prev = slots 53 | 54 | slots = self.norm_slots(slots) 55 | q = self.project_q(slots) 56 | 57 | dots = torch.einsum('bid,bjd->bij', q, k) * self.scale 58 | attn = dots.softmax(dim=1) + self.eps 59 | attn = attn / attn.sum(dim=-1, keepdim=True) 60 | 61 | updates = torch.einsum('bjd,bij->bid', v, attn) 62 | 63 | slots = self.gru( 64 | updates.reshape(-1, d), 65 | slots_prev.reshape(-1, d) 66 | ) 67 | 68 | slots = slots.reshape(b, -1, d) 69 | slots = slots + self.mlp(self.norm_mlp(slots)) 70 | 71 | return slots 72 | 73 | 74 | class SlotAttention_encoder(nn.Module): 75 | def __init__(self, in_channels, hidden_channels): 76 | super(SlotAttention_encoder, self).__init__() 77 | self.network = nn.Sequential( 78 | nn.Conv2d(in_channels, hidden_channels, (5, 5), stride=(1, 1), padding=2), 79 | nn.ReLU(inplace=True), 80 | nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(2, 2), padding=2), 81 | nn.ReLU(inplace=True), 82 | nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(2, 2), padding=2), 83 | nn.ReLU(inplace=True), 84 | nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(1, 1), padding=2), 85 | nn.ReLU(inplace=True), 86 | ) 87 | 88 | def forward(self, x): 89 | return self.network(x) 90 | 91 | 92 | class MLP(nn.Module): 93 | def __init__(self, hidden_channels): 94 | super(MLP, self).__init__() 95 | self.network = nn.Sequential( 96 | nn.Linear(hidden_channels, hidden_channels), 97 | nn.ReLU(inplace=True), 98 | nn.Linear(hidden_channels, hidden_channels), 99 | ) 100 | 101 | def forward(self, x): 102 | return self.network(x) 103 | 104 | 105 | def build_grid(resolution): 106 | ranges = [np.linspace(0., 1., num=res) for res in resolution] 107 | grid = np.meshgrid(*ranges, sparse=False, indexing="ij") 108 | grid = np.stack(grid, axis=-1) 109 | grid = np.reshape(grid, [resolution[0], resolution[1], -1]) 110 | grid = np.expand_dims(grid, axis=0) 111 | grid = grid.astype(np.float32) 112 | return np.concatenate([grid, 1.0 - grid], axis=-1) 113 | 114 | 115 | class SoftPositionEmbed(nn.Module): 
116 | """Adds soft positional embedding with learnable projection.""" 117 | 118 | def __init__(self, hidden_size, resolution, device="cuda"): 119 | """Builds the soft position embedding layer. 120 | Args: 121 | hidden_size: Size of input feature dimension. 122 | resolution: Tuple of integers specifying width and height of grid. 123 | """ 124 | super().__init__() 125 | self.dense = nn.Linear(4, hidden_size) 126 | self.grid = torch.FloatTensor(build_grid(resolution)) 127 | self.grid = self.grid.to(device) 128 | self.resolution = resolution[0] 129 | self.hidden_size = hidden_size 130 | 131 | def forward(self, inputs): 132 | return inputs + self.dense(self.grid).view((-1, self.hidden_size, self.resolution, self.resolution)) 133 | 134 | 135 | class SlotAttention_classifier(nn.Module): 136 | def __init__(self, in_channels, out_channels): 137 | super(SlotAttention_classifier, self).__init__() 138 | self.network = nn.Sequential( 139 | nn.Linear(in_channels, in_channels), 140 | nn.ReLU(inplace=True), 141 | nn.Linear(in_channels, out_channels), 142 | nn.Sigmoid() 143 | ) 144 | 145 | def forward(self, x): 146 | return self.network(x) 147 | 148 | 149 | class SlotAttention_model(nn.Module): 150 | def __init__(self, n_slots, n_iters, n_attr, 151 | in_channels=3, 152 | encoder_hidden_channels=64, 153 | attention_hidden_channels=128, 154 | device="cuda"): 155 | super(SlotAttention_model, self).__init__() 156 | self.n_slots = n_slots 157 | self.n_iters = n_iters 158 | self.n_attr = n_attr 159 | self.n_attr = n_attr + 1 # additional slot to indicate if it is a object or empty slot 160 | self.device = device 161 | 162 | self.encoder_cnn = SlotAttention_encoder(in_channels=in_channels, hidden_channels=encoder_hidden_channels) 163 | self.encoder_pos = SoftPositionEmbed(encoder_hidden_channels, (32, 32), device=device) 164 | self.layer_norm = nn.LayerNorm(encoder_hidden_channels, eps=1e-05) 165 | self.mlp = MLP(hidden_channels=encoder_hidden_channels) 166 | self.slot_attention = SlotAttention(num_slots=n_slots, dim=encoder_hidden_channels, iters=n_iters, eps=1e-8, 167 | hidden_dim=attention_hidden_channels) 168 | self.mlp_classifier = SlotAttention_classifier(in_channels=encoder_hidden_channels, out_channels=self.n_attr) 169 | 170 | def forward(self, x): 171 | x = self.encoder_cnn(x) 172 | x = self.encoder_pos(x) 173 | x = torch.flatten(x, start_dim=2) 174 | x = x.permute(0, 2, 1) 175 | x = self.layer_norm(x) 176 | x = self.mlp(x) 177 | x = self.slot_attention(x) 178 | x = self.mlp_classifier(x) 179 | return x 180 | 181 | 182 | if __name__ == "__main__": 183 | x = torch.rand(1, 3, 128, 128) 184 | net = SlotAttention_model(n_slots=10, n_iters=3, n_attr=18, 185 | encoder_hidden_channels=64, attention_hidden_channels=128) 186 | output = net(x) 187 | print(output.shape) 188 | summary(net, (3, 128, 128)) 189 | 190 | -------------------------------------------------------------------------------- /src/pretrain-slot-attention/preprocess-images.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code from https://github.com/Cyanogenoid/dspn/tree/master/dspn 3 | """ 4 | import os 5 | 6 | import h5py 7 | import torch.utils.data 8 | import torchvision.models as models 9 | import torchvision.transforms as transforms 10 | from PIL import Image 11 | from tqdm import tqdm 12 | 13 | 14 | class CLEVR_Images(torch.utils.data.Dataset): 15 | """ Dataset for MSCOCO images located in a folder on the filesystem """ 16 | 17 | def __init__(self, path, transform=None): 18 | super().__init__() 
19 | self.path = path 20 | self.id_to_filename = self._find_images() 21 | self.sorted_ids = sorted( 22 | self.id_to_filename.keys() 23 | ) # used for deterministic iteration order 24 | print("found {} images in {}".format(len(self), self.path)) 25 | self.transform = transform 26 | 27 | def _find_images(self): 28 | id_to_filename = {} 29 | for filename in os.listdir(self.path): 30 | if not filename.endswith(".png"): 31 | continue 32 | id_and_extension = filename.split("_")[-1] 33 | id = int(id_and_extension.split(".")[0]) 34 | id_to_filename[id] = filename 35 | return id_to_filename 36 | 37 | def __getitem__(self, item): 38 | id = self.sorted_ids[item] 39 | path = os.path.join(self.path, self.id_to_filename[id]) 40 | img = Image.open(path).convert("RGB") 41 | 42 | if self.transform is not None: 43 | img = self.transform(img) 44 | return id, img 45 | 46 | def __len__(self): 47 | return len(self.sorted_ids) 48 | 49 | 50 | def create_coco_loader(path): 51 | transform = transforms.Compose( 52 | [transforms.Resize((128, 128)), transforms.ToTensor()] 53 | ) 54 | dataset = CLEVR_Images(path, transform=transform) 55 | data_loader = torch.utils.data.DataLoader( 56 | dataset, batch_size=64, num_workers=12, shuffle=False, pin_memory=True 57 | ) 58 | return data_loader 59 | 60 | 61 | def main(): 62 | for split_name in ["train", "val"]: 63 | path = os.path.join("clevr", "images", split_name) 64 | loader = create_coco_loader(path) 65 | images_shape = (len(loader.dataset), 3, 128, 128) 66 | 67 | with h5py.File("{}-images.h5".format(split_name), libver="latest") as fd: 68 | images = fd.create_dataset("images", shape=images_shape, dtype="float32") 69 | image_ids = fd.create_dataset( 70 | "image_ids", shape=(len(loader.dataset),), dtype="int32" 71 | ) 72 | 73 | i = 0 74 | for ids, imgs in tqdm(loader): 75 | j = i + imgs.size(0) 76 | images[i:j, :, :] = imgs.numpy() 77 | image_ids[i:j] = ids.numpy().astype("int32") 78 | i = j 79 | 80 | 81 | if __name__ == "__main__": 82 | main() 83 | -------------------------------------------------------------------------------- /src/pretrain-slot-attention/scripts/clevr-slot-attention.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # to be called as: python clevr-clot-attention.sh 0 0 /pathtoclevrv1/ 4 | # (for cuda device 0 and run 0) 5 | 6 | # CUDA DEVICE ID 7 | DEVICE=$1 8 | NUM=$2 9 | DATA=$3 10 | MODEL="slot-attention-clevr-state-$NUM" 11 | DATASET=clevr-state 12 | #-------------------------------------------------------------------------------# 13 | # Train on CLEVR_v1 14 | CUDA_VISIBLE_DEVICES=$DEVICE python train.py --data-dir $DATA --dataset $DATASET --epochs 2000 --name $MODEL --lr 0.0004 --batch-size 512 --n-slots 10 --n-iters-slot-att 3 --n-attr 18 15 | -------------------------------------------------------------------------------- /src/pretrain-slot-attention/scripts/clevr_preprocess.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | wget https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip 4 | unzip CLEVR_v1.0.zip 5 | mv CLEVR_v1.0 clevr 6 | python preprocess-images.py 7 | mkdir CLEVR 8 | mv train-images.h5 CLEVR/ 9 | mv val-images.h5 CLEVR/ 10 | mv clevr/scenes CLEVR/ 11 | rm -r clevr 12 | -------------------------------------------------------------------------------- /src/pretrain-slot-attention/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from datetime import 
datetime 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | import torchvision.transforms as transforms 9 | import torchvision.datasets as datasets 10 | import torch.multiprocessing as mp 11 | 12 | import scipy.optimize 13 | import numpy as np 14 | from tqdm import tqdm 15 | import matplotlib 16 | from torch.optim import lr_scheduler 17 | 18 | matplotlib.use("Agg") 19 | import matplotlib.pyplot as plt 20 | from torch.utils.tensorboard import SummaryWriter 21 | 22 | import data 23 | import model 24 | import utils as utils 25 | 26 | torch.set_num_threads(6) 27 | 28 | def get_args(): 29 | parser = argparse.ArgumentParser() 30 | # generic params 31 | parser.add_argument( 32 | "--name", 33 | default=datetime.now().strftime("%Y-%m-%d_%H:%M:%S"), 34 | help="Name to store the log file as", 35 | ) 36 | parser.add_argument("--resume", help="Path to log file to resume from") 37 | 38 | parser.add_argument( 39 | "--epochs", type=int, default=10, help="Number of epochs to train with" 40 | ) 41 | parser.add_argument( 42 | "--ap-log", type=int, default=10, help="Number of epochs before logging AP" 43 | ) 44 | parser.add_argument( 45 | "--lr", type=float, default=1e-2, help="Outer learning rate of model" 46 | ) 47 | parser.add_argument( 48 | "--batch-size", type=int, default=32, help="Batch size to train with" 49 | ) 50 | parser.add_argument( 51 | "--num-workers", type=int, default=4, help="Number of threads for data loader" 52 | ) 53 | parser.add_argument( 54 | "--dataset", 55 | choices=["clevr-state"], 56 | help="Dataset to use", 57 | ) 58 | parser.add_argument( 59 | "--no-cuda", 60 | action="store_true", 61 | help="Run on CPU instead of GPU (not recommended)", 62 | ) 63 | parser.add_argument( 64 | "--train-only", action="store_true", help="Only run training, no evaluation" 65 | ) 66 | parser.add_argument( 67 | "--eval-only", action="store_true", help="Only run evaluation, no training" 68 | ) 69 | parser.add_argument("--multi-gpu", action="store_true", help="Use multiple GPUs") 70 | 71 | parser.add_argument("--data-dir", type=str, help="Directory to data") 72 | # Slot attention params 73 | parser.add_argument('--n-slots', default=10, type=int, 74 | help='number of slots for slot attention module') 75 | parser.add_argument('--n-iters-slot-att', default=3, type=int, 76 | help='number of iterations in slot attention module') 77 | parser.add_argument('--n-attr', default=18, type=int, 78 | help='number of attributes per object') 79 | 80 | args = parser.parse_args() 81 | return args 82 | 83 | 84 | def run(net, loader, optimizer, criterion, writer, args, train=False, epoch=0, pool=None): 85 | if train: 86 | net.train() 87 | prefix = "train" 88 | torch.set_grad_enabled(True) 89 | else: 90 | net.eval() 91 | prefix = "test" 92 | torch.set_grad_enabled(False) 93 | 94 | preds_all = torch.zeros(0, args.n_slots, args.n_attr) 95 | target_all = torch.zeros(0, args.n_slots, args.n_attr) 96 | 97 | iters_per_epoch = len(loader) 98 | 99 | for i, sample in tqdm(enumerate(loader, start=epoch * iters_per_epoch)): 100 | # input is either a set or an image 101 | imgs, target_set = map(lambda x: x.cuda(), sample) 102 | 103 | output = net.forward(imgs) 104 | 105 | loss = utils.hungarian_loss(output, target_set, thread_pool=pool) 106 | 107 | if train: 108 | optimizer.zero_grad() 109 | loss.backward() 110 | optimizer.step() 111 | 112 | writer.add_scalar("metric/train_loss", loss.item(), global_step=i) 113 | print(f"Epoch {epoch} Train Loss: {loss.item()}") 114 | 115 | else: 116 | if i % iters_per_epoch == 0: 117 | # 
print predictions for one image, match predictions with targets 118 | matched_output = utils.hungarian_matching(target_set[:2], output[:2].to('cuda'), verbose=0) 119 | # for k in range(2): 120 | print(f"\nGT: \n{np.round(target_set.detach().cpu().numpy()[0, 0], 2)}") 121 | print(f"\nPred: \n{np.round(matched_output.detach().cpu().numpy()[0, 0], 2)}\n") 122 | 123 | preds_all = torch.cat((preds_all, output), 0) 124 | target_all = torch.cat((target_all, target_set), 0) 125 | 126 | writer.add_scalar("metric/val_loss", loss.item(), global_step=i) 127 | 128 | 129 | def main(): 130 | args = get_args() 131 | 132 | writer = SummaryWriter(f"runs/{args.name}", purge_step=0) 133 | # writer = None 134 | utils.save_args(args, writer) 135 | 136 | dataset_train = data.CLEVR( 137 | args.data_dir, "train", 138 | ) 139 | dataset_test = data.CLEVR( 140 | args.data_dir, "val", 141 | ) 142 | 143 | if not args.eval_only: 144 | train_loader = data.get_loader( 145 | dataset_train, batch_size=args.batch_size, num_workers=args.num_workers 146 | ) 147 | if not args.train_only: 148 | test_loader = data.get_loader( 149 | dataset_test, 150 | batch_size=args.batch_size, 151 | num_workers=args.num_workers, 152 | shuffle=False, 153 | ) 154 | 155 | net = model.SlotAttention_model(n_slots=10, n_iters=3, n_attr=18, 156 | encoder_hidden_channels=64, 157 | attention_hidden_channels=128) 158 | args.n_attr = net.n_attr 159 | 160 | start_epoch = 0 161 | if args.resume: 162 | print("Loading ckpt ...") 163 | log = torch.load(args.resume) 164 | weights = log["weights"] 165 | net.load_state_dict(weights, strict=True) 166 | start_epoch = log["args"]["epochs"] 167 | 168 | 169 | if not args.no_cuda: 170 | net = net.cuda() 171 | 172 | optimizer = torch.optim.Adam(net.parameters(), lr=args.lr) 173 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=0.00005) 174 | criterion = torch.nn.SmoothL1Loss() 175 | 176 | # store args as txt file 177 | utils.save_args(args, writer) 178 | 179 | for epoch in np.arange(start_epoch, args.epochs + start_epoch): 180 | with mp.Pool(10) as pool: 181 | if not args.eval_only: 182 | run(net, train_loader, optimizer, criterion, writer, args, train=True, epoch=epoch, pool=pool) 183 | cur_lr = optimizer.param_groups[0]["lr"] 184 | writer.add_scalar("lr", cur_lr, global_step=epoch * len(train_loader)) 185 | # if args.resume is not None: 186 | scheduler.step() 187 | if not args.train_only: 188 | run(net, test_loader, None, criterion, writer, args, train=False, epoch=epoch, pool=pool) 189 | if args.eval_only: 190 | exit() 191 | 192 | results = { 193 | "name": args.name, 194 | "weights": net.state_dict(), 195 | "args": vars(args), 196 | } 197 | print(os.path.join("logs", args.name)) 198 | torch.save(results, os.path.join("logs", args.name)) 199 | if args.eval_only: 200 | break 201 | 202 | 203 | if __name__ == "__main__": 204 | main() 205 | -------------------------------------------------------------------------------- /src/pretrain-slot-attention/utils.py: -------------------------------------------------------------------------------- 1 | import scipy.optimize 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import os 6 | 7 | 8 | def save_args(args, writer): 9 | # store args as txt file 10 | with open(os.path.join(writer.log_dir, 'args.txt'), 'w') as f: 11 | for arg in vars(args): 12 | f.write(f"\n{arg}: {getattr(args, arg)}") 13 | 14 | 15 | def hungarian_matching(attrs, preds_attrs, verbose=0): 16 | """ 17 | Receives unordered predicted set and 
orders this to match the nearest GT set. 18 | :param attrs: 19 | :param preds_attrs: 20 | :param verbose: 21 | :return: 22 | """ 23 | assert attrs.shape[1] == preds_attrs.shape[1] 24 | assert attrs.shape == preds_attrs.shape 25 | from scipy.optimize import linear_sum_assignment 26 | matched_preds_attrs = preds_attrs.clone() 27 | for sample_id in range(attrs.shape[0]): 28 | # using euclidean distance 29 | cost_matrix = torch.cdist(attrs[sample_id], preds_attrs[sample_id]).detach().cpu() 30 | 31 | idx_mapping = linear_sum_assignment(cost_matrix) 32 | # convert to tuples of [(row_id, col_id)] of the cost matrix 33 | idx_mapping = [(idx_mapping[0][i], idx_mapping[1][i]) for i in range(len(idx_mapping[0]))] 34 | 35 | for i, (row_id, col_id) in enumerate(idx_mapping): 36 | matched_preds_attrs[sample_id, row_id, :] = preds_attrs[sample_id, col_id, :] 37 | if verbose: 38 | print('GT: {}'.format(attrs[sample_id])) 39 | print('Pred: {}'.format(preds_attrs[sample_id])) 40 | print('Cost Matrix: {}'.format(cost_matrix)) 41 | print('idx mapping: {}'.format(idx_mapping)) 42 | print('Matched Pred: {}'.format(matched_preds_attrs[sample_id])) 43 | print('\n') 44 | # exit() 45 | 46 | return matched_preds_attrs 47 | 48 | 49 | def hungarian_loss(predictions, targets, thread_pool): 50 | # permute dimensions for pairwise distance computation between all slots 51 | predictions = predictions.permute(0, 2, 1) 52 | targets = targets.permute(0, 2, 1) 53 | 54 | # predictions and targets shape :: (n, c, s) 55 | predictions, targets = outer(predictions, targets) 56 | # squared_error shape :: (n, s, s) 57 | squared_error = F.smooth_l1_loss(predictions, targets.expand_as(predictions), reduction="none").mean(1) 58 | 59 | squared_error_np = squared_error.detach().cpu().numpy() 60 | indices = thread_pool.map(hungarian_loss_per_sample, squared_error_np) 61 | losses = [ 62 | sample[row_idx, col_idx].mean() 63 | for sample, (row_idx, col_idx) in zip(squared_error, indices) 64 | ] 65 | total_loss = torch.mean(torch.stack(list(losses))) 66 | return total_loss 67 | 68 | 69 | def hungarian_loss_per_sample(sample_np): 70 | return scipy.optimize.linear_sum_assignment(sample_np) 71 | 72 | 73 | def outer(a, b=None): 74 | """ Compute outer product between a and b (or a and a if b is not specified). 
""" 75 | if b is None: 76 | b = a 77 | size_a = tuple(a.size()) + (b.size()[-1],) 78 | size_b = tuple(b.size()) + (a.size()[-1],) 79 | a = a.unsqueeze(dim=-1).expand(*size_a) 80 | b = b.unsqueeze(dim=-2).expand(*size_b) 81 | return a, b -------------------------------------------------------------------------------- /src/runs/CLEVR-Hans3/concept-learner-0-CLEVR-Hans3_seed0/args.txt: -------------------------------------------------------------------------------- 1 | 2 | name: concept-learner-0-CLEVR-Hans3 3 | mode: train 4 | resume: None 5 | seed: 0 6 | epochs: 50 7 | lr: 0.0001 8 | l2_grads: 1 9 | batch_size: 12 10 | num_workers: 4 11 | dataset: clevr-hans-state 12 | no_cuda: False 13 | train_only: False 14 | eval_only: False 15 | multi_gpu: False 16 | data_dir: /workspace/datasets/CLEVR-Hans3/ 17 | fp_ckpt: None 18 | n_slots: 10 19 | n_iters_slot_att: 3 20 | n_attr: 18 21 | n_heads: 4 22 | set_transf_hidden: 128 23 | conf_version: CLEVR-Hans3 24 | device: cuda 25 | n_imgclasses: 3 26 | class_weights: tensor([0.3333, 0.3333, 0.3333]) 27 | classes: [0 1 2] 28 | category_ids: [ 3 6 8 10 18] -------------------------------------------------------------------------------- /src/runs/CLEVR-Hans3/concept-learner-0-CLEVR-Hans3_seed0/events.out.tfevents.1615486383.ml-meteor.local.26896.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyConceptLearner/66dddba1e879359dcefd6685ca3e405c17369c8d/src/runs/CLEVR-Hans3/concept-learner-0-CLEVR-Hans3_seed0/events.out.tfevents.1615486383.ml-meteor.local.26896.0 -------------------------------------------------------------------------------- /src/scripts/clevr-hans-concept-learner_CLEVR_Hans3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # CUDA DEVICE ID 4 | DEVICE=$1 5 | NUM=$2 6 | DATA=$3 7 | MODEL="concept-learner-$NUM" 8 | DATASET=clevr-hans-state 9 | OUTPATH="out/clevr-state/$MODEL-$ITER" 10 | 11 | #-------------------------------------------------------------------------------# 12 | # CLEVR-Hans3 13 | 14 | # For gpu 15 | CUDA_VISIBLE_DEVICES=$DEVICE python train_nesy_concept_learner_clevr_hans.py --data-dir $DATA --dataset $DATASET \ 16 | --epochs 50 --name $MODEL --lr 0.0001 --batch-size 128 --n-slots 10 --n-iters-slot-att 3 --n-attr 18 --seed 0 \ 17 | --mode train 18 | 19 | # For cpu 20 | #CUDA_VISIBLE_DEVICES=$DEVICE python train_nesy_concept_learner_clevr_hans.py --data-dir $DATA --dataset $DATASET \ 21 | #--epochs 50 --name $MODEL --lr 0.0001 --batch-size 128 --n-slots 10 --n-iters-slot-att 3 --n-attr 18 --seed 0 \ 22 | #--mode train --no-cuda 23 | -------------------------------------------------------------------------------- /src/train_nesy_concept_learner_clevr_hans.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use("Agg") 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | import glob 8 | from sklearn import metrics 9 | from tqdm import tqdm 10 | 11 | import data_clevr_hans as data 12 | import model 13 | import utils as utils 14 | from rtpt import RTPT 15 | from args import get_args 16 | 17 | torch.autograd.set_detect_anomaly(True) 18 | 19 | os.environ["MKL_NUM_THREADS"] = "6" 20 | os.environ["NUMEXPR_NUM_THREADS"] = "6" 21 | os.environ["OMP_NUM_THREADS"] = "6" 22 | torch.set_num_threads(6) 23 | 24 | # ----------------------------------------- 25 | # - Define basic and data related methods - 26 
| # ----------------------------------------- 27 | def get_confusion_from_ckpt(net, test_loader, criterion, args, datasplit, writer=None): 28 | 29 | true, pred, true_wrong, pred_wrong = run_test_final(net, test_loader, criterion, writer, args, datasplit) 30 | precision, recall, accuracy, f1_score = utils.performance_matrix(true, pred) 31 | 32 | # Generate Confusion Matrix 33 | if writer is not None: 34 | utils.plot_confusion_matrix(true, pred, normalize=True, classes=args.classes, 35 | sFigName=os.path.join(writer.log_dir, 'Confusion_matrix_normalize_{}.pdf'.format( 36 | datasplit)) 37 | ) 38 | utils.plot_confusion_matrix(true, pred, normalize=False, classes=args.classes, 39 | sFigName=os.path.join(writer.log_dir, 'Confusion_matrix_{}.pdf'.format(datasplit))) 40 | else: 41 | utils.plot_confusion_matrix(true, pred, normalize=True, classes=args.classes, 42 | sFigName=os.path.join(os.path.sep.join(args.fp_ckpt.split(os.path.sep)[:-1]), 43 | 'Confusion_matrix_normalize_{}.pdf'.format(datasplit))) 44 | utils.plot_confusion_matrix(true, pred, normalize=False, classes=args.classes, 45 | sFigName=os.path.join(os.path.sep.join(args.fp_ckpt.split(os.path.sep)[:-1]), 46 | 'Confusion_matrix_{}.pdf'.format(datasplit))) 47 | return accuracy 48 | 49 | # ----------------------------------------- 50 | # - Define Train/Test/Validation methods - 51 | # ----------------------------------------- 52 | def run_test_final(net, loader, criterion, writer, args, datasplit): 53 | net.eval() 54 | 55 | running_corrects = 0 56 | running_loss=0 57 | pred_wrong = [] 58 | true_wrong = [] 59 | preds_all = [] 60 | labels_all = [] 61 | with torch.no_grad(): 62 | 63 | for i, sample in enumerate(tqdm(loader)): 64 | # input is either a set or an image 65 | imgs, target_set, img_class_ids, img_ids, _, table_expl = map(lambda x: x.cuda(), sample) 66 | img_class_ids = img_class_ids.long() 67 | 68 | # forward evaluation through the network 69 | output_cls, output_attr = net(imgs) 70 | # class prediction 71 | _, preds = torch.max(output_cls, 1) 72 | 73 | labels_all.extend(img_class_ids.cpu().numpy()) 74 | preds_all.extend(preds.cpu().numpy()) 75 | 76 | running_corrects = running_corrects + torch.sum(preds == img_class_ids) 77 | loss = criterion(output_cls, img_class_ids) 78 | running_loss += loss.item() 79 | preds = preds.cpu().numpy() 80 | target = img_class_ids.cpu().numpy() 81 | preds = np.reshape(preds, (len(preds), 1)) 82 | target = np.reshape(target, (len(preds), 1)) 83 | 84 | for i in range(len(preds)): 85 | if (preds[i] != target[i]): 86 | pred_wrong.append(preds[i]) 87 | true_wrong.append(target[i]) 88 | 89 | bal_acc = metrics.balanced_accuracy_score(labels_all, preds_all) 90 | 91 | if writer is not None: 92 | writer.add_scalar(f"Loss/{datasplit}_loss", running_loss / len(loader), 0) 93 | writer.add_scalar(f"Acc/{datasplit}_bal_acc", bal_acc, 0) 94 | 95 | return labels_all, preds_all, true_wrong, pred_wrong 96 | 97 | 98 | def run(net, loader, optimizer, criterion, split, writer, args, train=False, plot=False, epoch=0): 99 | if train: 100 | net.img2state_net.eval() 101 | net.set_cls.train() 102 | torch.set_grad_enabled(True) 103 | else: 104 | net.eval() 105 | torch.set_grad_enabled(False) 106 | 107 | iters_per_epoch = len(loader) 108 | loader = tqdm( 109 | loader, 110 | ncols=0, 111 | desc="{1} E{0:02d}".format(epoch, "train" if train else "val "), 112 | ) 113 | running_loss = 0 114 | preds_all = [] 115 | labels_all = [] 116 | for i, sample in enumerate(loader, start=epoch * iters_per_epoch): 117 | # input is either a set or 
an image 118 | imgs, target_set, img_class_ids, img_ids, _, table_expl = map(lambda x: x.cuda(), sample) 119 | img_class_ids = img_class_ids.long() 120 | 121 | # forward evaluation through the network 122 | output_cls, output_attr = net(imgs) 123 | 124 | # class prediction 125 | _, preds = torch.max(output_cls, 1) 126 | 127 | loss = criterion(output_cls, img_class_ids) 128 | 129 | # Outer optim step 130 | if train: 131 | optimizer.zero_grad() 132 | loss.backward() 133 | optimizer.step() 134 | 135 | running_loss += loss.item() 136 | labels_all.extend(img_class_ids.cpu().numpy()) 137 | preds_all.extend(preds.cpu().numpy()) 138 | 139 | # Plot predictions in Tensorboard 140 | if plot and not(i % iters_per_epoch): 141 | utils.write_expls(net, loader, f"Expl/{split}", epoch, writer) 142 | 143 | bal_acc = metrics.balanced_accuracy_score(labels_all, preds_all) 144 | 145 | writer.add_scalar(f"Loss/{split}_loss", running_loss / len(loader), epoch) 146 | writer.add_scalar(f"Acc/{split}_bal_acc", bal_acc, epoch) 147 | 148 | print("Epoch: {}/{}.. ".format(epoch, args.epochs), 149 | "{} Loss: {:.3f}.. ".format(split, running_loss / len(loader)), 150 | "{} Accuracy: {:.3f}.. ".format(split, bal_acc), 151 | ) 152 | 153 | return running_loss / len(loader) 154 | 155 | 156 | def train(args): 157 | 158 | if args.dataset == "clevr-hans-state": 159 | dataset_train = data.CLEVR_HANS_EXPL( 160 | args.data_dir, "train", lexi=True, conf_vers=args.conf_version 161 | ) 162 | dataset_val = data.CLEVR_HANS_EXPL( 163 | args.data_dir, "val", lexi=True, conf_vers=args.conf_version 164 | ) 165 | dataset_test = data.CLEVR_HANS_EXPL( 166 | args.data_dir, "test", lexi=True, conf_vers=args.conf_version 167 | ) 168 | else: 169 | print("Wrong dataset specifier") 170 | exit() 171 | 172 | args.n_imgclasses = dataset_train.n_classes 173 | args.class_weights = torch.ones(args.n_imgclasses)/args.n_imgclasses 174 | args.classes = np.arange(args.n_imgclasses) 175 | args.category_ids = dataset_train.category_ids 176 | 177 | train_loader = data.get_loader( 178 | dataset_train, 179 | batch_size=args.batch_size, 180 | num_workers=args.num_workers, 181 | shuffle=True, 182 | ) 183 | test_loader = data.get_loader( 184 | dataset_test, 185 | batch_size=args.batch_size, 186 | num_workers=args.num_workers, 187 | shuffle=False, 188 | ) 189 | val_loader = data.get_loader( 190 | dataset_val, 191 | batch_size=args.batch_size, 192 | num_workers=args.num_workers, 193 | shuffle=False, 194 | ) 195 | 196 | net = model.NeSyConceptLearner(n_classes=args.n_imgclasses, n_slots=args.n_slots, n_iters=args.n_iters_slot_att, 197 | n_attr=args.n_attr, n_set_heads=args.n_heads, set_transf_hidden=args.set_transf_hidden, 198 | category_ids=args.category_ids, device=args.device) 199 | 200 | # load pretrained concept embedding module 201 | log = torch.load("logs/slot-attention-clevr-state-3_final", map_location=torch.device(args.device)) 202 | net.img2state_net.load_state_dict(log['weights'], strict=True) 203 | print("Pretrained slot attention model loaded!") 204 | 205 | net = net.to(args.device) 206 | 207 | # only optimize the set transformer classifier for now, i.e. 
freeze the state predictor 208 | optimizer = torch.optim.Adam( 209 | [p for name, p in net.named_parameters() if p.requires_grad and 'set_cls' in name], lr=args.lr 210 | ) 211 | criterion = nn.CrossEntropyLoss() 212 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=0.000001) 213 | 214 | torch.backends.cudnn.benchmark = True 215 | 216 | # Create RTPT object 217 | rtpt = RTPT(name_initials='WS', experiment_name=f"Clevr Hans Slot Att Set Transf xil", 218 | max_iterations=args.epochs) 219 | # Start the RTPT tracking 220 | rtpt.start() 221 | 222 | # tensorboard writer 223 | writer = utils.create_writer(args) 224 | 225 | cur_best_val_loss = np.inf 226 | for epoch in range(args.epochs): 227 | _ = run(net, train_loader, optimizer, criterion, split='train', args=args, writer=writer, 228 | train=True, plot=False, epoch=epoch) 229 | scheduler.step() 230 | val_loss = run(net, val_loader, optimizer, criterion, split='val', args=args, writer=writer, 231 | train=False, plot=True, epoch=epoch) 232 | _ = run(net, test_loader, optimizer, criterion, split='test', args=args, writer=writer, 233 | train=False, plot=False, epoch=epoch) 234 | 235 | results = { 236 | "name": args.name, 237 | "weights": net.state_dict(), 238 | "args": args, 239 | } 240 | if cur_best_val_loss > val_loss: 241 | if epoch > 0: 242 | # remove previous best model 243 | os.remove(glob.glob(os.path.join(writer.log_dir, "model_*_bestvalloss*.pth"))[0]) 244 | torch.save(results, os.path.join(writer.log_dir, "model_epoch{}_bestvalloss_{:.4f}.pth".format(epoch, 245 | val_loss))) 246 | cur_best_val_loss = val_loss 247 | 248 | # Update the RTPT (subtitle is optional) 249 | rtpt.step() 250 | 251 | # load best model for final evaluation 252 | net = model.NeSyConceptLearner(args, n_slots=args.n_slots, n_iters=args.n_iters_slot_att, n_attr=args.n_attr, 253 | set_transf_hidden=args.set_transf_hidden, category_ids=args.category_ids, 254 | device=args.device) 255 | net = net.to(args.device) 256 | 257 | checkpoint = torch.load(glob.glob(os.path.join(writer.log_dir, "model_*_bestvalloss*.pth"))[0]) 258 | net.load_state_dict(checkpoint['weights']) 259 | net.eval() 260 | print("\nModel loaded from checkpoint for final evaluation\n") 261 | 262 | get_confusion_from_ckpt(net, test_loader, criterion, args=args, datasplit='test_best', 263 | writer=writer) 264 | get_confusion_from_ckpt(net, val_loader, criterion, args=args, datasplit='val_best', 265 | writer=writer) 266 | 267 | # plot expls 268 | run(net, train_loader, optimizer, criterion, split='train_best', args=args, 269 | writer=writer, train=False, plot=True, epoch=0) 270 | run(net, val_loader, optimizer, criterion, split='val_best', args=args, 271 | writer=writer, train=False, plot=True, epoch=0) 272 | run(net, test_loader, optimizer, criterion, split='test_best', args=args, 273 | writer=writer, train=False, plot=True, epoch=0) 274 | 275 | writer.close() 276 | 277 | 278 | def test(args): 279 | 280 | print(f"\n\n{args.name} seed {args.seed}\n") 281 | if args.dataset == "clevr-hans-state": 282 | dataset_val = data.CLEVR_HANS_EXPL( 283 | args.data_dir, "val", lexi=True, conf_vers=args.conf_version 284 | ) 285 | dataset_test = data.CLEVR_HANS_EXPL( 286 | args.data_dir, "test", lexi=True, conf_vers=args.conf_version 287 | ) 288 | else: 289 | print("Wrong dataset specifier") 290 | exit() 291 | 292 | args.n_imgclasses = dataset_val.n_classes 293 | args.class_weights = torch.ones(args.n_imgclasses)/args.n_imgclasses 294 | args.classes = np.arange(args.n_imgclasses) 295 
| args.category_ids = dataset_val.category_ids 296 | 297 | test_loader = data.get_loader( 298 | dataset_test, 299 | batch_size=args.batch_size, 300 | num_workers=args.num_workers, 301 | shuffle=False, 302 | ) 303 | val_loader = data.get_loader( 304 | dataset_val, 305 | batch_size=args.batch_size, 306 | num_workers=args.num_workers, 307 | shuffle=False, 308 | ) 309 | 310 | criterion = nn.CrossEntropyLoss() 311 | 312 | net = model.NeSyConceptLearner(n_classes=args.n_imgclasses, n_slots=args.n_slots, n_iters=args.n_iters_slot_att, 313 | n_attr=args.n_attr, n_set_heads=args.n_heads, set_transf_hidden=args.set_transf_hidden, 314 | category_ids=args.category_ids, device=args.device) 315 | net = net.to(args.device) 316 | 317 | checkpoint = torch.load(args.fp_ckpt) 318 | net.load_state_dict(checkpoint['weights']) 319 | net.eval() 320 | print("\nModel loaded from checkpoint for final evaluation\n") 321 | 322 | acc = get_confusion_from_ckpt(net, val_loader, criterion, args=args, datasplit='val_best', writer=None) 323 | print(f"\nVal. accuracy: {(100*acc):.2f}") 324 | acc = get_confusion_from_ckpt(net, test_loader, criterion, args=args, datasplit='test_best', writer=None) 325 | print(f"\nTest accuracy: {(100*acc):.2f}") 326 | 327 | 328 | def plot(args): 329 | 330 | print(f"\n\n{args.name} seed {args.seed}\n") 331 | 332 | # no positional info per object 333 | if args.dataset == "clevr-hans-state": 334 | dataset_val = data.CLEVR_HANS_EXPL( 335 | args.data_dir, "val", lexi=True, conf_vers=args.conf_version 336 | ) 337 | dataset_test = data.CLEVR_HANS_EXPL( 338 | args.data_dir, "test", lexi=True, conf_vers=args.conf_version 339 | ) 340 | else: 341 | print("Wrong dataset specifier") 342 | exit() 343 | 344 | args.n_imgclasses = dataset_val.n_classes 345 | args.class_weights = torch.ones(args.n_imgclasses)/args.n_imgclasses 346 | args.classes = np.arange(args.n_imgclasses) 347 | args.category_ids = dataset_val.category_ids 348 | 349 | test_loader = data.get_loader( 350 | dataset_test, 351 | batch_size=args.batch_size, 352 | num_workers=args.num_workers, 353 | shuffle=False, 354 | ) 355 | 356 | # load best model for final evaluation 357 | net = model.NeSyConceptLearner(n_classes=args.n_imgclasses, n_slots=args.n_slots, n_iters=args.n_iters_slot_att, 358 | n_attr=args.n_attr, n_set_heads=args.n_heads, set_transf_hidden=args.set_transf_hidden, 359 | category_ids=args.category_ids, device=args.device) 360 | net = net.to(args.device) 361 | 362 | checkpoint = torch.load(args.fp_ckpt) 363 | net.load_state_dict(checkpoint['weights']) 364 | net.eval() 365 | print("\nModel loaded from checkpoint for final evaluation\n") 366 | 367 | save_dir = args.fp_ckpt.split('model_epoch')[0]+'figures/' 368 | try: 369 | os.makedirs(save_dir) 370 | except FileExistsError: 371 | # directory already exists 372 | pass 373 | 374 | # change plotting function in utils in order to visualize explanations 375 | assert args.conf_version == 'CLEVR-Hans3' 376 | utils.save_expls(net, test_loader, "test", save_path=save_dir) 377 | 378 | 379 | def main(): 380 | args = get_args() 381 | if args.mode == 'train': 382 | train(args) 383 | elif args.mode == 'test': 384 | test(args) 385 | elif args.mode == 'plot': 386 | plot(args) 387 | 388 | 389 | if __name__ == "__main__": 390 | main() 391 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import io 4 | import os 5 | import torch 6 
| import matplotlib.pyplot as plt 7 | from PIL import Image 8 | # from skimage import color 9 | from sklearn import metrics 10 | from matplotlib import rc 11 | from torch.utils.tensorboard import SummaryWriter 12 | from torchvision import transforms 13 | from captum.attr import IntegratedGradients 14 | 15 | axislabel_fontsize = 8 16 | ticklabel_fontsize = 8 17 | titlelabel_fontsize = 8 18 | 19 | torch.backends.cudnn.deterministic = True 20 | torch.backends.cudnn.benchmark = False 21 | random.seed(1) 22 | np.random.seed(1) 23 | torch.manual_seed(1) 24 | 25 | 26 | def seed_everything(seed=42): 27 | random.seed(seed) 28 | os.environ['PYTHONHASHSEED'] = str(seed) 29 | np.random.seed(seed) 30 | torch.manual_seed(seed) 31 | torch.backends.cudnn.deterministic = True 32 | torch.backends.cudnn.benchmark = False 33 | 34 | 35 | def resize_tensor(input_tensors, h, w): 36 | input_tensors = torch.squeeze(input_tensors, 1) 37 | 38 | for i, img in enumerate(input_tensors): 39 | img_PIL = transforms.ToPILImage()(img) 40 | img_PIL = transforms.Resize([h, w])(img_PIL) 41 | img_PIL = transforms.ToTensor()(img_PIL) 42 | if i == 0: 43 | final_output = img_PIL 44 | else: 45 | final_output = torch.cat((final_output, img_PIL), 0) 46 | final_output = torch.unsqueeze(final_output, 1) 47 | return final_output 48 | 49 | 50 | def norm_saliencies(saliencies): 51 | saliencies_norm = saliencies.clone() 52 | 53 | for i in range(saliencies.shape[0]): 54 | if len(torch.nonzero(saliencies[i], as_tuple=False)) == 0: 55 | saliencies_norm[i] = saliencies[i] 56 | else: 57 | saliencies_norm[i] = (saliencies[i] - torch.min(saliencies[i])) / \ 58 | (torch.max(saliencies[i]) - torch.min(saliencies[i])) 59 | 60 | return saliencies_norm 61 | 62 | 63 | def generate_intgrad_captum_table(net, input, labels): 64 | labels = labels.to("cuda") 65 | explainer = IntegratedGradients(net) 66 | saliencies = explainer.attribute(input, target=labels) 67 | # remove negative attributions 68 | saliencies[saliencies < 0] = 0. 69 | # normalize each saliency map by its max 70 | for k, sal in enumerate(saliencies): 71 | saliencies[k] = sal/torch.max(sal) 72 | return norm_saliencies(saliencies) 73 | 74 | 75 | def test_hungarian_matching(attrs=torch.tensor([[[0, 1, 1, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 0]], 76 | [[0, 1, 1, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 0]]]).type(torch.float), 77 | pred_attrs=torch.tensor([[[0.01, 0.1, 0.2, 0.1, 0.2, 0.2, 0.01], 78 | [0.1, 0.6, 0.8, 0., 0.4, 0.001, 0.9]], 79 | [[0.01, 0.1, 0.2, 0.1, 0.2, 0.2, 0.01], 80 | [0.1, 0.6, 0.8, 0., 0.4, 0.001, 0.9]]]).type(torch.float)): 81 | hungarian_matching(attrs, pred_attrs, verbose=1) 82 | 83 | 84 | def hungarian_matching(attrs, preds_attrs, verbose=0): 85 | """ 86 | Receives unordered predicted set and orders this to match the nearest GT set. 
87 | :param attrs: 88 | :param preds_attrs: 89 | :param verbose: 90 | :return: 91 | """ 92 | assert attrs.shape[1] == preds_attrs.shape[1] 93 | assert attrs.shape == preds_attrs.shape 94 | from scipy.optimize import linear_sum_assignment 95 | matched_preds_attrs = preds_attrs.clone() 96 | idx_map_ids = [] 97 | for sample_id in range(attrs.shape[0]): 98 | # using euclidean distance 99 | cost_matrix = torch.cdist(attrs[sample_id], preds_attrs[sample_id]).detach().cpu() 100 | 101 | idx_mapping = linear_sum_assignment(cost_matrix) 102 | # convert to tuples of [(row_id, col_id)] of the cost matrix 103 | idx_mapping = [(idx_mapping[0][i], idx_mapping[1][i]) for i in range(len(idx_mapping[0]))] 104 | 105 | idx_map_ids.append([idx_mapping[i][1] for i in range(len(idx_mapping))]) 106 | 107 | for i, (row_id, col_id) in enumerate(idx_mapping): 108 | matched_preds_attrs[sample_id, row_id, :] = preds_attrs[sample_id, col_id, :] 109 | if verbose: 110 | print('GT: {}'.format(attrs[sample_id])) 111 | print('Pred: {}'.format(preds_attrs[sample_id])) 112 | print('Cost Matrix: {}'.format(cost_matrix)) 113 | print('idx mapping: {}'.format(idx_mapping)) 114 | print('Matched Pred: {}'.format(matched_preds_attrs[sample_id])) 115 | print('\n') 116 | # exit() 117 | 118 | idx_map_ids = np.array(idx_map_ids) 119 | return matched_preds_attrs, idx_map_ids 120 | 121 | 122 | def create_writer(args): 123 | writer = SummaryWriter(f"runs/{args.conf_version}/{args.name}_seed{args.seed}", purge_step=0) 124 | 125 | writer.add_scalar('Hyperparameters/learningrate', args.lr, 0) 126 | writer.add_scalar('Hyperparameters/num_epochs', args.epochs, 0) 127 | writer.add_scalar('Hyperparameters/batchsize', args.batch_size, 0) 128 | 129 | # store args as txt file 130 | with open(os.path.join(writer.log_dir, 'args.txt'), 'w') as f: 131 | for arg in vars(args): 132 | f.write(f"\n{arg}: {getattr(args, arg)}") 133 | return writer 134 | 135 | 136 | def create_expl_images(img, pred_attrs, table_expl_attrs, img_expl, true_class_name, pred_class_name, xticklabels): 137 | """ 138 | """ 139 | assert pred_attrs.shape[0:2] == table_expl_attrs.shape[0:2] 140 | 141 | fig, ax = plt.subplots(nrows=1, ncols=4, figsize=(8, 3)) 142 | ax[0].imshow(img) 143 | ax[0].axis('off') 144 | ax[0].set_title("Img") 145 | 146 | ax[1].imshow(pred_attrs, cmap='gray') 147 | ax[1].set_ylabel('Slot. ID', fontsize=axislabel_fontsize) 148 | ax[1].yaxis.set_label_coords(-0.1, 0.5) 149 | ax[1].set_yticks(np.arange(0, 11)) 150 | ax[1].yaxis.set_tick_params(labelsize=axislabel_fontsize) 151 | ax[1].set_xlabel('Obj. Attr', fontsize=axislabel_fontsize) 152 | ax[1].set_xticks(range(len(xticklabels))) 153 | ax[1].set_xticklabels(xticklabels, rotation=90, fontsize=ticklabel_fontsize) 154 | ax[1].set_title("Pred Attr") 155 | 156 | ax[2].imshow(img_expl) 157 | ax[2].axis('off') 158 | ax[2].set_title("Img Expl") 159 | 160 | im = ax[3].imshow(table_expl_attrs) 161 | ax[3].set_yticks(np.arange(0, 11)) 162 | ax[3].yaxis.set_tick_params(labelsize=axislabel_fontsize) 163 | ax[3].set_xlabel('Obj. 
Attr', fontsize=axislabel_fontsize) 164 | ax[3].set_xticks(range(len(xticklabels))) 165 | ax[3].set_xticklabels(xticklabels, rotation=90, fontsize=ticklabel_fontsize) 166 | ax[3].set_title("Table Expl") 167 | 168 | fig.suptitle(f"True Class: {true_class_name}; Pred Class: {pred_class_name}", fontsize=titlelabel_fontsize) 169 | 170 | return fig 171 | 172 | 173 | def performance_matrix(true, pred): 174 | precision = metrics.precision_score(true, pred, average='macro') 175 | recall = metrics.recall_score(true, pred, average='macro') 176 | accuracy = metrics.accuracy_score(true, pred) 177 | f1_score = metrics.f1_score(true, pred, average='macro') 178 | # print('Confusion Matrix:\n', metrics.confusion_matrix(true, pred)) 179 | print('Precision: {:.3f} Recall: {:.3f}, Accuracy: {:.3f}: ,f1_score: {:.3f}'.format(precision*100,recall*100, 180 | accuracy*100,f1_score*100)) 181 | return precision, recall, accuracy, f1_score 182 | 183 | 184 | def plot_confusion_matrix(y_true, y_pred, classes, normalize=True, title=None, 185 | cmap=plt.cm.Blues, sFigName='confusion_matrix.pdf'): 186 | """ 187 | This function prints and plots the confusion matrix. 188 | Normalization can be applied by setting `normalize=True`. 189 | """ 190 | if not title: 191 | if normalize: 192 | title = 'Normalized confusion matrix' 193 | else: 194 | title = 'Confusion matrix, without normalization' 195 | # Compute confusion matrix 196 | cm = metrics.confusion_matrix(y_true, y_pred) 197 | if normalize: 198 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 199 | print("Normalized confusion matrix") 200 | else: 201 | print('Confusion matrix, without normalization') 202 | print(cm) 203 | fig, ax = plt.subplots() 204 | im = ax.imshow(cm, interpolation='nearest', cmap=cmap) 205 | ax.figure.colorbar(im, ax=ax) 206 | # We want to show all ticks... 207 | ax.set(xticks=np.arange(cm.shape[1]), 208 | yticks=np.arange(cm.shape[0]), 209 | # ... and label them with the respective list entries 210 | xticklabels=classes, yticklabels=classes, 211 | title=title, 212 | ylabel='True label', 213 | xlabel='Predicted label') 214 | # Rotate the tick labels and set their alignment. 215 | plt.setp(ax.get_xticklabels(), ha="right", 216 | rotation_mode="anchor") 217 | # Loop over data dimensions and create text annotations. 218 | fmt = '.2f' if normalize else 'd' 219 | thresh = cm.max() / 2. 220 | for i in range(cm.shape[0]): 221 | for j in range(cm.shape[1]): 222 | ax.text(j, i, format(cm[i, j], fmt), 223 | ha="center", va="center", 224 | color="white" if cm[i, j] > thresh else "black") 225 | fig.tight_layout() 226 | plt.savefig(sFigName) 227 | return ax 228 | 229 | 230 | def write_expls(net, data_loader, tagname, epoch, writer): 231 | """ 232 | Writes NeSy Concpet Learner explanations to tensorboard writer. 
233 | """ 234 | 235 | attr_labels = ['Sphere', 'Cube', 'Cylinder', 236 | 'Large', 'Small', 237 | 'Rubber', 'Metal', 238 | 'Cyan', 'Blue', 'Yellow', 'Purple', 'Red', 'Green', 'Gray', 'Brown'] 239 | 240 | net.eval() 241 | 242 | for i, sample in enumerate(data_loader): 243 | # input is either a set or an image 244 | imgs, target_set, img_class_ids, img_ids, _, _ = map(lambda x: x.cuda(), sample) 245 | img_class_ids = img_class_ids.long() 246 | 247 | # forward evaluation through the network 248 | output_cls, output_attr = net(imgs) 249 | _, preds = torch.max(output_cls, 1) 250 | 251 | # convert sorting gt target set and gt table explanations to match the order of the predicted table 252 | target_set, match_ids = hungarian_matching(output_attr.to('cuda'), target_set) 253 | # table_expls = table_expls[:, match_ids][range(table_expls.shape[0]), range(table_expls.shape[0])] 254 | 255 | # get explanations of set classifier 256 | table_saliencies = generate_intgrad_captum_table(net.set_cls, output_attr, preds) 257 | 258 | # get the ids of the two objects that receive the maximal importance, i.e. most important for the classification 259 | max_expl_obj_ids = table_saliencies.max(dim=2)[0].topk(2)[1] 260 | 261 | # get attention masks 262 | attns = net.img2state_net.slot_attention.attn 263 | # reshape attention masks to 2D 264 | attns = attns.reshape((attns.shape[0], attns.shape[1], int(np.sqrt(attns.shape[2])), 265 | int(np.sqrt(attns.shape[2])))) 266 | 267 | # concatenate the visual explanation of the top two objects that are most important for the classification 268 | img_saliencies = torch.zeros(attns.shape[0], attns.shape[2], attns.shape[3]) 269 | for obj_id in range(max_expl_obj_ids.shape[1]): 270 | img_saliencies += attns[range(attns.shape[0]), obj_id, :, :].detach().cpu() 271 | 272 | # upscale img_saliencies to orig img shape 273 | img_saliencies = resize_tensor(img_saliencies.cpu(), imgs.shape[2], imgs.shape[2]).squeeze(dim=1).cpu() 274 | 275 | for img_id, (img, gt_table, pred_table, table_expl, img_expl, true_label, pred_label, imgid) in enumerate(zip( 276 | imgs, target_set, output_attr, table_saliencies, 277 | img_saliencies, img_class_ids, preds, 278 | img_ids 279 | )): 280 | # unnormalize images 281 | img = img / 2. + 0.5 # Rescale to [0, 1]. 282 | 283 | fig = create_expl_images(np.array(transforms.ToPILImage()(img.cpu()).convert("RGB")), 284 | pred_table.detach().cpu().numpy(), 285 | table_expl.detach().cpu().numpy(), 286 | img_expl.detach().cpu().numpy(), 287 | true_label, pred_label, attr_labels) 288 | writer.add_figure(f"{tagname}_{img_id}", fig, epoch) 289 | if img_id > 10: 290 | break 291 | 292 | break 293 | 294 | 295 | def save_expls(net, data_loader, tagname, save_path): 296 | """ 297 | Stores the explanation plots at the specified location. 
298 | """ 299 | 300 | xticklabels = ['Sphere', 'Cube', 'Cylinder', 301 | 'Large', 'Small', 302 | 'Rubber', 'Metal', 303 | 'Cyan', 'Blue', 'Yellow', 'Purple', 'Red', 'Green', 'Gray', 'Brown'] 304 | 305 | net.eval() 306 | 307 | for i, sample in enumerate(data_loader): 308 | # input is either a set or an image 309 | imgs, target_set, img_class_ids, img_ids, _, _ = map(lambda x: x.cuda(), sample) 310 | img_class_ids = img_class_ids.long() 311 | 312 | # forward evaluation through the network 313 | output_cls, output_attr = net(imgs) 314 | _, preds = torch.max(output_cls, 1) 315 | 316 | # # convert sorting gt target set and gt table explanations to match the order of the predicted table 317 | # target_set, match_ids = utils.hungarian_matching(output_attr.to('cuda'), target_set) 318 | # # table_expls = table_expls[:, match_ids][range(table_expls.shape[0]), range(table_expls.shape[0])] 319 | 320 | # get explanations of set classifier 321 | table_saliencies = generate_intgrad_captum_table(net.set_cls, output_attr, preds) 322 | # remove xyz coords from tables for conf_3 323 | output_attr = output_attr[:, :, 3:] 324 | table_saliencies = table_saliencies[:, :, 3:] 325 | 326 | # get the ids of the two objects that receive the maximal importance, i.e. most important for the classification 327 | max_expl_obj_ids = table_saliencies.max(dim=2)[0].topk(2)[1] 328 | 329 | # get attention masks 330 | attns = net.img2state_net.slot_attention.attn 331 | # reshape attention masks to 2D 332 | attns = attns.reshape((attns.shape[0], attns.shape[1], int(np.sqrt(attns.shape[2])), 333 | int(np.sqrt(attns.shape[2])))) 334 | 335 | # concatenate the visual explanation of the top two objects that are most important for the classification 336 | img_saliencies = torch.zeros(attns.shape[0], attns.shape[2], attns.shape[3]) 337 | batch_size = attns.shape[0] 338 | for i in range(max_expl_obj_ids.shape[1]): 339 | img_saliencies += attns[range(batch_size), max_expl_obj_ids[range(batch_size), i], :, :].detach().cpu() 340 | 341 | num_stored_imgs = 0 342 | relevant_ids = [618, 154, 436, 244, 318, 85] 343 | 344 | for img_id, (img, gt_table, pred_table, table_expl, img_expl, true_label, pred_label, imgid) in enumerate(zip( 345 | imgs, target_set, output_attr.detach().cpu().numpy(), 346 | table_saliencies.detach().cpu().numpy(), img_saliencies.detach().cpu().numpy(), 347 | img_class_ids, preds, img_ids 348 | )): 349 | if imgid in relevant_ids: 350 | num_stored_imgs += 1 351 | # norm img expl to be between 0 and 255 352 | img_expl = (img_expl - np.min(img_expl))/(np.max(img_expl) - np.min(img_expl)) 353 | # resize to img size 354 | img_expl = np.array(Image.fromarray(img_expl).resize((img.shape[1], img.shape[2]), resample=1)) 355 | 356 | # unnormalize images 357 | img = img / 2. + 0.5 # Rescale to [0, 1]. 358 | img = np.array(transforms.ToPILImage()(img.cpu()).convert("RGB")) 359 | 360 | np.save(f"{save_path}{tagname}_{imgid}.npy", img) 361 | np.save(f"{save_path}{tagname}_{imgid}_imgexpl.npy", img_expl) 362 | np.save(f"{save_path}{tagname}_{imgid}_table.npy", pred_table) 363 | np.save(f"{save_path}{tagname}_{imgid}_tableexpl.npy", table_expl) 364 | 365 | fig = create_expl_images(img, pred_table, table_expl, img_expl, 366 | true_label, pred_label, xticklabels) 367 | plt.savefig(f"{save_path}{tagname}_{imgid}.png") 368 | plt.close(fig) 369 | 370 | if num_stored_imgs == len(relevant_ids): 371 | exit() 372 | --------------------------------------------------------------------------------