├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── datasets ├── apc_config.py ├── gqn_config.py ├── multi_object_config.py ├── multid_config.py ├── shapestacks_config.py └── sketchy_config.py ├── environment.yml ├── models ├── __init__.py ├── genesis_config.py ├── genesisv2_config.py ├── monet_config.py └── vae_config.py ├── modules ├── __init__.py ├── attention.py ├── blocks.py ├── component_vae.py ├── decoders.py ├── encoders.py └── unet.py ├── scripts ├── __init__.py ├── compute_fid.py ├── compute_seg_metrics.py ├── generate_multid.py ├── sketchy_preparation.py ├── visualise_data.py ├── visualise_generation.py └── visualise_reconstruction.py ├── third_party ├── __init__ .py ├── multi_object_datasets │ ├── CONTRIBUTING.md │ ├── LICENSE │ ├── README.md │ ├── clevr_with_masks.py │ ├── multi_dsprites.py │ ├── objects_room.py │ ├── preview.png │ ├── segmentation_metrics.py │ └── tetrominoes.py ├── pytorch_fid │ ├── LICENSE │ ├── README.md │ ├── fid_score.py │ └── inception.py ├── shapestacks │ ├── LICENSE │ ├── __init__.py │ ├── segmentation_utils.py │ └── shapestacks_provider.py ├── sylvester │ ├── LICENSE │ ├── VAE.py │ ├── __init__.py │ └── layers.py └── tf_gqn │ ├── LICENSE │ ├── __init__.py │ └── gqn_tfr_provider.py ├── train.py └── utils ├── __init__.py ├── colour_palette15.json ├── geco.py ├── misc.py ├── plotting.py └── shapestacks_urls.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # ----- Costum additions ----- # 107 | 108 | # Editors 109 | .vscode/ 110 | *.swp 111 | 112 | # Data directories 113 | data/ 114 | tmp/ 115 | checkpoints/ 116 | 117 | # Local binaries 118 | bin/ 119 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "forge"] 2 | path = forge 3 | url = https://github.com/akosiorek/forge.git 4 | -------------------------------------------------------------------------------- /datasets/apc_config.py: -------------------------------------------------------------------------------- 1 | # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | import os 15 | from glob import glob 16 | import random 17 | from tqdm import tqdm 18 | 19 | import torch 20 | from torch.utils.data import Dataset, DataLoader 21 | from torchvision import transforms 22 | 23 | from PIL import Image 24 | 25 | from forge import flags 26 | from forge.experiment_tools import fprint 27 | 28 | from utils.misc import loader_throughput 29 | 30 | 31 | flags.DEFINE_string('data_folder', 'data/apc', 32 | 'Path to data folder.') 33 | flags.DEFINE_integer('img_size', 128, 34 | 'Dimension of images. 
Images are square.') 35 | flags.DEFINE_integer('num_workers', 4, 'Number of threads for loading data.') 36 | 37 | flags.DEFINE_integer('K_steps', 10, 'Number of component steps.') 38 | 39 | 40 | def load(cfg, **unused_kwargs): 41 | 42 | del unused_kwargs 43 | if not os.path.exists(cfg.data_folder): 44 | raise Exception("Data folder does not exist.") 45 | 46 | assert cfg.img_size == 128 47 | 48 | # Create splits if needed 49 | modes = ['train', 'val', 'test'] 50 | create_splits = False 51 | for m in modes: 52 | if not os.path.exists(f'{cfg.data_folder}/{m}_images.txt'): 53 | create_splits = True 54 | break 55 | if create_splits: 56 | fprint("Creating new train/val/test splits...") 57 | # Randomly split into train/val/test with fixed seed 58 | all_scenes = sorted(glob(f'{cfg.data_folder}/processed/*/*/scene-*')) 59 | random.seed(0) 60 | random.shuffle(all_scenes) 61 | num_eval_scenes = len(all_scenes) // 10 62 | train_scenes = all_scenes[2*num_eval_scenes:] 63 | val_scenes = all_scenes[:num_eval_scenes] 64 | test_scenes = all_scenes[num_eval_scenes:2*num_eval_scenes] 65 | modes = ['train', 'val', 'test'] 66 | mode_scenes = [train_scenes, val_scenes, test_scenes] 67 | for mode, mscs in zip(modes, mode_scenes): 68 | img_paths = [] 69 | for sc in mscs: 70 | img_paths += glob(f'{sc}/frame-*.color.png') 71 | with open(f'{cfg.data_folder}/{mode}_images.txt', 'w') as f: 72 | for item in sorted(img_paths): 73 | f.write("%s\n" % item) 74 | # Sanity checks 75 | assert len(train_scenes + val_scenes + test_scenes) == len(all_scenes) 76 | assert not list(set(train_scenes).intersection(val_scenes)) 77 | assert not list(set(train_scenes).intersection(test_scenes)) 78 | assert not list(set(val_scenes).intersection(test_scenes)) 79 | fprint("Created new train/val/test splits!") 80 | 81 | # Read splits 82 | with open(f'{cfg.data_folder}/train_images.txt') as f: 83 | train_images = f.readlines() 84 | train_images = [x.strip() for x in train_images] 85 | with open(f'{cfg.data_folder}/val_images.txt') as f: 86 | val_images = f.readlines() 87 | val_images = [x.strip() for x in val_images] 88 | with open(f'{cfg.data_folder}/test_images.txt') as f: 89 | test_images = f.readlines() 90 | test_images = [x.strip() for x in test_images] 91 | fprint(f"{len(train_images)} train images") 92 | fprint(f"{len(val_images)} val images") 93 | fprint(f"{len(test_images)} test images") 94 | 95 | # Datasets 96 | trainset = APCDataset(train_images) 97 | valset = APCDataset(val_images) 98 | testset = APCDataset(test_images) 99 | # Loaders 100 | train_loader = DataLoader( 101 | trainset, batch_size=cfg.batch_size, shuffle=True, 102 | num_workers=cfg.num_workers) 103 | val_loader = DataLoader( 104 | valset, batch_size=cfg.batch_size, shuffle=True, 105 | num_workers=cfg.num_workers) 106 | test_loader = DataLoader( 107 | testset, batch_size=cfg.batch_size, shuffle=True, 108 | num_workers=cfg.num_workers) 109 | 110 | # Throughput stats 111 | if not cfg.debug: 112 | loader_throughput(train_loader) 113 | 114 | return (train_loader, val_loader, test_loader) 115 | 116 | 117 | class APCDataset(Dataset): 118 | 119 | def __init__(self, image_paths): 120 | self.image_paths = image_paths 121 | # Transforms 122 | self.transform = transforms.ToTensor() 123 | 124 | def __len__(self): 125 | return len(self.image_paths) 126 | 127 | def __getitem__(self, idx): 128 | fp = self.image_paths[idx] 129 | img = self.transform(Image.open(fp)) 130 | mfp= fp.replace('frame', 'masks/frame').replace('color', 'mask') 131 | try: 132 | mfp= fp.replace('frame', 
'masks/frame').replace('color', 'mask') 133 | mask = self.transform(Image.open(mfp)).long() 134 | except FileNotFoundError: 135 | mask = torch.zeros_like(img[:1, :, :]).long() 136 | return {'input': img, 'instances': mask} 137 | 138 | 139 | def preprocess(data_folder='data/apc', img_size=128): 140 | print("Getting image paths...") 141 | image_paths = glob( 142 | f'{data_folder}/training/*/*/scene-*/frame-*.color.png') 143 | print(f"Done. Found {len(image_paths)}.") 144 | img_T = transforms.Compose([ 145 | transforms.Resize(img_size, interpolation=Image.BILINEAR), 146 | transforms.CenterCrop(img_size) 147 | ]) 148 | mask_T = transforms.Compose([ 149 | transforms.Resize(img_size, interpolation=Image.NEAREST), 150 | transforms.CenterCrop(img_size) 151 | ]) 152 | # Created folders 153 | print("Creating folders...") 154 | for path in tqdm(glob(f'{data_folder}/training/*/*/scene-*/')): 155 | os.makedirs(path.replace('training', 'processed')) 156 | os.makedirs(path.replace('training', 'processed')+'/masks') 157 | print("Done.") 158 | print("Preprocessing images...") 159 | # Preprocess images 160 | for path in tqdm(image_paths): 161 | # Image 162 | img = img_T(Image.open(path)) 163 | img.save(path.replace('training', 'processed')) 164 | # Mask 165 | if 'scene-empty' not in path: 166 | m_path = path.replace('frame', 'masks/frame').replace('color', 'mask') 167 | mask = mask_T(Image.open(m_path)) 168 | mask.save(m_path.replace('training', 'processed')) 169 | print("ALL DONE!") 170 | 171 | 172 | if __name__ == "__main__": 173 | preprocess() 174 | -------------------------------------------------------------------------------- /datasets/gqn_config.py: -------------------------------------------------------------------------------- 1 | # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | import os 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | 19 | import tensorflow as tf 20 | 21 | import numpy as np 22 | 23 | import third_party.tf_gqn.gqn_tfr_provider as gqn 24 | 25 | from forge import flags 26 | from forge.experiment_tools import fprint 27 | 28 | from utils.misc import loader_throughput 29 | 30 | 31 | flags.DEFINE_string('data_folder', 'data/gqn_datasets', 32 | 'Path to data folder.') 33 | flags.DEFINE_integer('img_size', 64, 34 | 'Dimension of images. 
Images are square.') 35 | flags.DEFINE_integer('val_frac', 60, 36 | 'Fraction of training images to use for validation.') 37 | 38 | flags.DEFINE_integer('num_workers', 4, 'TF records dataset.') 39 | flags.DEFINE_integer('buffer_size', 128, 'TF records dataset.') 40 | 41 | flags.DEFINE_integer('K_steps', 7, 'Number of recurrent steps.') 42 | 43 | 44 | SEED = 0 45 | 46 | 47 | def load(cfg, **unused_kwargs): 48 | # Fix TensorFlow seed 49 | global SEED 50 | SEED = cfg.seed 51 | tf.set_random_seed(SEED) 52 | 53 | if cfg.num_workers == 0: 54 | fprint("Need to use at least one worker for loading tfrecords.") 55 | cfg.num_workers = 1 56 | 57 | del unused_kwargs 58 | if not os.path.exists(cfg.data_folder): 59 | raise Exception("Data folder does not exist.") 60 | print(f"Using {cfg.num_workers} data workers.") 61 | # Create data iterators 62 | train_loader = GQNLoader( 63 | data_folder=cfg.data_folder, mode='devel_train', img_size=cfg.img_size, 64 | val_frac=cfg.val_frac, batch_size=cfg.batch_size, 65 | num_workers=cfg.num_workers, buffer_size=cfg.buffer_size) 66 | val_loader = GQNLoader( 67 | data_folder=cfg.data_folder, mode='devel_val', img_size=cfg.img_size, 68 | val_frac=cfg.val_frac, batch_size=cfg.batch_size, 69 | num_workers=cfg.num_workers, buffer_size=cfg.buffer_size) 70 | test_loader = GQNLoader( 71 | data_folder=cfg.data_folder, mode='test', img_size=cfg.img_size, 72 | val_frac=cfg.val_frac, batch_size=1, 73 | num_workers=1, buffer_size=cfg.buffer_size) 74 | # Create session to be used by loaders 75 | sess = tf.InteractiveSession() 76 | train_loader.sess = sess 77 | val_loader.sess = sess 78 | test_loader.sess = sess 79 | 80 | # Throughput stats 81 | if not cfg.debug: 82 | loader_throughput(train_loader) 83 | 84 | return (train_loader, val_loader, test_loader) 85 | 86 | 87 | class GQNLoader(): 88 | """GQN dataset.""" 89 | 90 | def __init__(self, data_folder, mode, img_size, val_frac, batch_size, 91 | num_workers, buffer_size): 92 | self.img_size = img_size 93 | self.batch_size = batch_size 94 | self.sess = None 95 | # Create GQN reader 96 | reader = gqn.GQNTFRecordDataset( 97 | dataset='rooms_ring_camera', 98 | context_size=0, 99 | root=data_folder, 100 | mode=mode, 101 | val_frac=val_frac, 102 | custom_frame_size=None, 103 | num_threads=num_workers, 104 | buffer_size=buffer_size) 105 | # Operato on TFRecordsDataset 106 | dataset = reader._dataset 107 | # Set properties 108 | dataset = dataset.repeat(1) 109 | if 'train' in mode: 110 | dataset = dataset.shuffle( 111 | buffer_size=buffer_size * self.batch_size, seed=SEED) 112 | dataset = dataset.batch(self.batch_size) 113 | self.dataset = dataset.prefetch(buffer_size * self.batch_size) 114 | # Create iterator 115 | it = self.dataset.make_one_shot_iterator() 116 | self.frames, _ = it.get_next() 117 | # TODO(martin): avoid hard coding these 118 | train_sz = 10800000 119 | test_sz = 1200000 120 | if mode == 'train': 121 | num_frames = train_sz 122 | elif mode == 'test': 123 | num_frames = test_sz 124 | elif mode == 'devel_train': 125 | num_frames = (train_sz // val_frac) * (val_frac-1) 126 | elif mode == 'devel_val': 127 | num_frames = (train_sz // val_frac) 128 | else: 129 | raise ValueError("Mode not known.") 130 | self.length = num_frames // batch_size 131 | 132 | def __len__(self): 133 | return self.length 134 | 135 | def __iter__(self): 136 | return self 137 | 138 | def __next__(self): 139 | try: 140 | img = self.sess.run(self.frames) 141 | img = img[:, 0, :, :, :] 142 | img = np.moveaxis(img, 3, 1) 143 | img = torch.FloatTensor(img) 144 
| if self.img_size != 64: 145 | img = F.interpolate(img, size=self.img_size) 146 | return {'input': img} 147 | except tf.errors.OutOfRangeError: 148 | print("Reached end of epoch. Creating new iterator.") 149 | # Create new iterator for next epoch 150 | it = self.dataset.make_one_shot_iterator() 151 | self.frames, _ = it.get_next() 152 | raise StopIteration 153 | -------------------------------------------------------------------------------- /datasets/multi_object_config.py: -------------------------------------------------------------------------------- 1 | # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | 17 | import tensorflow as tf 18 | 19 | import numpy as np 20 | 21 | from forge import flags 22 | from forge.experiment_tools import fprint 23 | 24 | from utils.misc import loader_throughput, len_tfrecords, np_img_centre_crop 25 | 26 | import third_party.multi_object_datasets.multi_dsprites as multi_dsprites 27 | import third_party.multi_object_datasets.objects_room as objects_room 28 | import third_party.multi_object_datasets.clevr_with_masks as clevr_with_masks 29 | import third_party.multi_object_datasets.tetrominoes as tetrominoes 30 | 31 | 32 | flags.DEFINE_string('data_folder', 'data/multi-object-datasets', 33 | 'Path to data folder.') 34 | flags.DEFINE_string('dataset', 'objects_room', 35 | '{multi_dsprites, objects_room, clevr, tetrominoes}') 36 | flags.DEFINE_integer('img_size', -1, 37 | 'Dimension of images. 
Images are square.') 38 | flags.DEFINE_integer('dataset_size', -1, 'Number of images to use.') 39 | 40 | flags.DEFINE_integer('num_workers', 4, 41 | 'Number of threads for loading data.') 42 | flags.DEFINE_integer('buffer_size', 128, 'TF records dataset.') 43 | 44 | flags.DEFINE_integer('K_steps', -1, 'Number of recurrent steps.') 45 | 46 | 47 | MULTI_DSPRITES = '/multi_dsprites/multi_dsprites_colored_on_colored.tfrecords' 48 | OBJECTS_ROOM = '/objects_room/objects_room_train.tfrecords' 49 | CLEVR = '/clevr_with_masks/clevr_with_masks_train.tfrecords' 50 | TETROMINOS = '/tetrominoes/tetrominoes_train.tfrecords' 51 | CLEVR_CROP = 192 # Following pre-processing in the IODINE paper 52 | 53 | SEED = 0 54 | 55 | 56 | def load(cfg, **unused_kwargs): 57 | # Fix TensorFlow seed 58 | global SEED 59 | SEED = cfg.seed 60 | tf.set_random_seed(SEED) 61 | 62 | del unused_kwargs 63 | fprint(f"Using {cfg.num_workers} data workers.") 64 | 65 | sess = tf.InteractiveSession() 66 | 67 | if cfg.dataset == 'multi_dsprites': 68 | cfg.img_size = 64 if cfg.img_size < 0 else cfg.img_size 69 | cfg.K_steps = 5 if cfg.K_steps < 0 else cfg.K_steps 70 | background_entities = 1 71 | max_frames = 60000 72 | raw_dataset = multi_dsprites.dataset( 73 | cfg.data_folder + MULTI_DSPRITES, 74 | 'colored_on_colored', 75 | map_parallel_calls=cfg.num_workers if cfg.num_workers > 0 else None) 76 | elif cfg.dataset == 'objects_room': 77 | cfg.img_size = 64 if cfg.img_size < 0 else cfg.img_size 78 | cfg.K_steps = 7 if cfg.K_steps < 0 else cfg.K_steps 79 | background_entities = 4 80 | max_frames = 1000000 81 | raw_dataset = objects_room.dataset( 82 | cfg.data_folder + OBJECTS_ROOM, 83 | 'train', 84 | map_parallel_calls=cfg.num_workers if cfg.num_workers > 0 else None) 85 | elif cfg.dataset == 'clevr': 86 | cfg.img_size = 128 if cfg.img_size < 0 else cfg.img_size 87 | cfg.K_steps = 11 if cfg.K_steps < 0 else cfg.K_steps 88 | background_entities = 1 89 | max_frames = 70000 90 | raw_dataset = clevr_with_masks.dataset( 91 | cfg.data_folder + CLEVR, 92 | map_parallel_calls=cfg.num_workers if cfg.num_workers > 0 else None) 93 | elif cfg.dataset == 'tetrominoes': 94 | cfg.img_size = 32 if cfg.img_size < 0 else cfg.img_size 95 | cfg.K_steps = 4 if cfg.K_steps < 0 else cfg.K_steps 96 | background_entities = 1 97 | max_frames = 60000 98 | raw_dataset = tetrominoes.dataset( 99 | cfg.data_folder + TETROMINOS, 100 | map_parallel_calls=cfg.num_workers if cfg.num_workers > 0 else None) 101 | else: 102 | raise NotImplementedError(f"{cfg.dataset} not a valid dataset.") 103 | 104 | # Split into train / val / test 105 | if cfg.dataset_size > max_frames: 106 | fprint(f"WARNING: {cfg.dataset_size} frames requested, "\ 107 | "but only {max_frames} available.") 108 | cfg.dataset_size = max_frames 109 | if cfg.dataset_size > 0: 110 | total_sz = cfg.dataset_size 111 | raw_dataset = raw_dataset.take(total_sz) 112 | else: 113 | total_sz = max_frames 114 | if total_sz < 0: 115 | fprint("Determining size of dataset...") 116 | total_sz = len_tfrecords(raw_dataset, sess) 117 | fprint(f"Dataset has {total_sz} frames") 118 | 119 | val_sz = 10000 120 | tst_sz = 10000 121 | tng_sz = total_sz - val_sz - tst_sz 122 | assert tng_sz > 0 123 | fprint(f"Splitting into {tng_sz}/{val_sz}/{tst_sz} for tng/val/tst") 124 | tst_dataset = raw_dataset.take(tst_sz) 125 | val_dataset = raw_dataset.skip(tst_sz).take(val_sz) 126 | tng_dataset = raw_dataset.skip(tst_sz + val_sz) 127 | 128 | tng_loader = MultiOjectLoader(sess, tng_dataset, background_entities, 129 | tng_sz, cfg.batch_size, 
130 | cfg.img_size, cfg.buffer_size) 131 | val_loader = MultiOjectLoader(sess, val_dataset, background_entities, 132 | val_sz, cfg.batch_size, 133 | cfg.img_size, cfg.buffer_size) 134 | tst_loader = MultiOjectLoader(sess, tst_dataset, background_entities, 135 | tst_sz, cfg.batch_size, 136 | cfg.img_size, cfg.buffer_size) 137 | 138 | # Throughput stats 139 | if not cfg.debug: 140 | loader_throughput(tng_loader) 141 | 142 | return (tng_loader, val_loader, tst_loader) 143 | 144 | 145 | class MultiOjectLoader(): 146 | 147 | def __init__(self, sess, dataset, background_entities, 148 | num_frames, batch_size, img_size=64, buffer_size=128): 149 | # Batch and shuffle 150 | dataset = dataset.shuffle(buffer_size*batch_size, seed=SEED) 151 | dataset = dataset.batch(batch_size) 152 | self.dataset = dataset.prefetch(buffer_size) 153 | # State 154 | self.sess = sess 155 | self.background_entities = background_entities 156 | self.num_frames = num_frames 157 | self.batch_size = batch_size 158 | self.length = self.num_frames // batch_size 159 | self.img_size = img_size 160 | self.count = 0 161 | self.frames = None 162 | 163 | def __len__(self): 164 | return self.length 165 | 166 | def __iter__(self): 167 | fprint("Creating new one_shot_iterator.") 168 | it = self.dataset.make_one_shot_iterator() 169 | self.frames = it.get_next() 170 | return self 171 | 172 | def __next__(self): 173 | try: 174 | frame = self.sess.run(self.frames) 175 | self.count += 1 176 | 177 | # Parse image 178 | img = frame['image'] 179 | img = np.moveaxis(img, 3, 1) 180 | shape = img.shape 181 | # TODO(martin): use more explicit CLEVR flag? 182 | if shape[2] != shape[3]: 183 | img = np_img_centre_crop(img, CLEVR_CROP, batch=True) 184 | img = torch.FloatTensor(img) / 255. 185 | if self.img_size != shape[2]: 186 | img = F.interpolate(img, size=self.img_size) 187 | 188 | # Parse masks 189 | raw_masks = frame['mask'] 190 | masks = np.zeros((shape[0], 1, shape[2], shape[3]), dtype='int') 191 | # Convert to boolean masks 192 | cond = np.where(raw_masks[:, :, :, :, 0] == 255, True, False) 193 | # Ignore background entities 194 | num_entities = cond.shape[1] 195 | for o_idx in range(self.background_entities, num_entities): 196 | masks[cond[:, o_idx:o_idx+1, :, :]] = o_idx + 1 197 | masks = torch.FloatTensor(masks) 198 | if shape[2] != shape[3]: 199 | masks = np_img_centre_crop(masks, CLEVR_CROP, batch=True) 200 | masks = torch.FloatTensor(masks) 201 | if self.img_size != shape[2]: 202 | masks = F.interpolate(masks, size=self.img_size) 203 | masks = masks.type(torch.LongTensor) 204 | 205 | return {'input': img, 'instances': masks} 206 | 207 | except tf.errors.OutOfRangeError: 208 | fprint("Reached end of epoch. Creating new iterator.") 209 | fprint(f"Counted {self.count} batches, expected {self.length}.") 210 | fprint("Creating new iterator.") 211 | self.count = 0 212 | raise StopIteration 213 | -------------------------------------------------------------------------------- /datasets/multid_config.py: -------------------------------------------------------------------------------- 1 | # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 
8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | import os 15 | 16 | import torch 17 | from torch.utils.data import Dataset, DataLoader 18 | from torchvision import transforms 19 | import torch.nn.functional as F 20 | 21 | import numpy as np 22 | 23 | from forge import flags 24 | 25 | from utils.misc import loader_throughput 26 | 27 | 28 | flags.DEFINE_string('data_folder', 'data/multi_dsprites/processed', 29 | 'Path to data folder.') 30 | flags.DEFINE_boolean('unique_colours', False, 'Dataset with unique colours.') 31 | flags.DEFINE_boolean('load_instances', True, 'Load instances.') 32 | flags.DEFINE_integer('img_size', 64, 33 | 'Dimension of images. Images are square.') 34 | 35 | flags.DEFINE_integer('num_workers', 4, 36 | 'Number of threads for loading data.') 37 | flags.DEFINE_boolean('mem_map', False, 'Use memory mapping.') 38 | 39 | flags.DEFINE_integer('K_steps', 5, 'Number of recurrent steps.') 40 | 41 | 42 | def load(cfg, **unused_kwargs): 43 | """ 44 | Args: 45 | cfg (obj): Forge config 46 | Returns: 47 | (DataLoader, DataLoader, DataLoader): 48 | Tuple of data loaders for train, val, test 49 | """ 50 | del unused_kwargs 51 | if not os.path.exists(cfg.data_folder): 52 | raise Exception("Data folder does not exist.") 53 | print(f"Using {cfg.num_workers} data workers.") 54 | 55 | if not hasattr(cfg, 'unique_colours'): 56 | cfg.unique_colours = False 57 | 58 | # Paths 59 | if cfg.unique_colours: 60 | train_path = 'training_images_rand4_unique.npy' 61 | val_path = 'validation_images_rand4_unique.npy' 62 | test_path = 'test_images_rand4_unique.npy' 63 | else: 64 | train_path = 'training_images_rand4.npy' 65 | val_path = 'validation_images_rand4.npy' 66 | test_path = 'test_images_rand4.npy' 67 | 68 | # Training 69 | train_dataset = dSpritesDataset(os.path.join(cfg.data_folder, train_path), 70 | cfg.load_instances, 71 | cfg.img_size, 72 | cfg.mem_map) 73 | train_loader = DataLoader(train_dataset, 74 | batch_size=cfg.batch_size, 75 | shuffle=True, 76 | num_workers=cfg.num_workers) 77 | # Validation 78 | val_dataset = dSpritesDataset(os.path.join(cfg.data_folder, val_path), 79 | cfg.load_instances, 80 | cfg.img_size, 81 | cfg.mem_map) 82 | val_loader = DataLoader(val_dataset, 83 | batch_size=cfg.batch_size, 84 | shuffle=True, 85 | num_workers=cfg.num_workers) 86 | # Test 87 | test_dataset = dSpritesDataset(os.path.join(cfg.data_folder, test_path), 88 | cfg.load_instances, 89 | cfg.img_size, 90 | cfg.mem_map) 91 | test_loader = DataLoader(test_dataset, 92 | batch_size=cfg.batch_size, 93 | shuffle=True, 94 | num_workers=1) 95 | 96 | # Throughput stats 97 | if not cfg.debug: 98 | loader_throughput(train_loader) 99 | 100 | return (train_loader, val_loader, test_loader) 101 | 102 | 103 | class dSpritesDataset(Dataset): 104 | """dSprites dataset.""" 105 | 106 | def __init__(self, file_path, load_instances=True, 107 | img_size=64, mem_map=False): 108 | """ 109 | Args: 110 | file_path (string): Path to the npy file of dSprites dataset. 
111 | transform (callable, optional): Optional transform to be applied 112 | """ 113 | if mem_map: 114 | self.all_images = np.load(file_path, mmap_mode='r') 115 | else: 116 | self.all_images = np.load(file_path) 117 | self.to_tensor = transforms.ToTensor() 118 | if load_instances and mem_map: 119 | self.all_instance_masks = np.load( 120 | file_path.replace('images', 'masks'), mmap_mode='r') 121 | elif load_instances: 122 | self.all_instance_masks = np.load( 123 | file_path.replace('images', 'masks')) 124 | else: 125 | self.all_instance_masks = None 126 | self.img_size = img_size 127 | 128 | def __len__(self): 129 | return len(self.all_images) 130 | 131 | def __getitem__(self, idx): 132 | img = self.all_images[idx] 133 | img = self.to_tensor(img) 134 | if self.img_size != 64: 135 | img = F.interpolate(img.unsqueeze(0), size=self.img_size).squeeze(0) 136 | output = {'input': img} 137 | if self.all_instance_masks is not None: 138 | ins = self.all_instance_masks[idx] 139 | ins = self.to_tensor(ins) 140 | if self.img_size != 64: 141 | ins = F.interpolate( 142 | ins.unsqueeze(0), size=self.img_size).squeeze(0) 143 | output['instances'] = ins.type(torch.LongTensor) 144 | return output 145 | -------------------------------------------------------------------------------- /datasets/shapestacks_config.py: -------------------------------------------------------------------------------- 1 | # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | import os 15 | from shutil import copytree 16 | 17 | import torch 18 | import torch.nn.functional as F 19 | from torch.utils.data import Dataset, DataLoader 20 | from torchvision import transforms 21 | 22 | import numpy as np 23 | from PIL import Image 24 | 25 | from forge import flags 26 | from forge.experiment_tools import fprint 27 | 28 | from utils.misc import loader_throughput, np_img_centre_crop 29 | 30 | from third_party.shapestacks.shapestacks_provider import _get_filenames_with_labels 31 | from third_party.shapestacks.segmentation_utils import load_segmap_as_matrix 32 | 33 | 34 | flags.DEFINE_string('data_folder', 'data/shapestacks', 'Path to data folder.') 35 | flags.DEFINE_string('split_name', 'default', '{default, blocks_all, css_all}') 36 | flags.DEFINE_integer('img_size', 64, 'Dimension of images. 
Images are square.') 37 | flags.DEFINE_boolean('shuffle_test', False, 'Shuffle test set.') 38 | 39 | flags.DEFINE_integer('num_workers', 4, 'Number of threads for loading data.') 40 | flags.DEFINE_boolean('load_instances', True, 'Load instances.') 41 | flags.DEFINE_boolean('copy_to_tmp', False, 'Copy files to /tmp.') 42 | 43 | flags.DEFINE_integer('K_steps', 9, 'Number of recurrent steps.') 44 | 45 | 46 | MAX_SHAPES = 6 47 | CENTRE_CROP = 196 48 | 49 | 50 | def load(cfg, **unused_kwargs): 51 | del unused_kwargs 52 | if not os.path.exists(cfg.data_folder): 53 | raise Exception("Data folder does not exist.") 54 | print(f"Using {cfg.num_workers} data workers.") 55 | 56 | # Copy all images and splits to /tmp 57 | if cfg.copy_to_tmp: 58 | for directory in ['/recordings', '/splits', '/iseg']: 59 | src = cfg.data_folder + directory 60 | dst = '/tmp' + directory 61 | fprint(f"Copying dataset from {src} to {dst}.") 62 | copytree(src, dst) 63 | cfg.data_folder = '/tmp' 64 | 65 | # Training 66 | tng_set = ShapeStacksDataset(cfg.data_folder, 67 | cfg.split_name, 68 | 'train', 69 | cfg.img_size, 70 | cfg.load_instances) 71 | tng_loader = DataLoader(tng_set, 72 | batch_size=cfg.batch_size, 73 | shuffle=True, 74 | num_workers=cfg.num_workers) 75 | # Validation 76 | val_set = ShapeStacksDataset(cfg.data_folder, 77 | cfg.split_name, 78 | 'eval', 79 | cfg.img_size, 80 | cfg.load_instances) 81 | val_loader = DataLoader(val_set, 82 | batch_size=cfg.batch_size, 83 | shuffle=True, 84 | num_workers=cfg.num_workers) 85 | # Test 86 | tst_set = ShapeStacksDataset(cfg.data_folder, 87 | cfg.split_name, 88 | 'test', 89 | cfg.img_size, 90 | cfg.load_instances, 91 | shuffle_files=cfg.shuffle_test) 92 | tst_loader = DataLoader(tst_set, 93 | batch_size=cfg.batch_size, 94 | shuffle=True, 95 | num_workers=1) 96 | 97 | # Throughput stats 98 | if not cfg.debug: 99 | loader_throughput(tng_loader) 100 | 101 | return (tng_loader, val_loader, tst_loader) 102 | 103 | 104 | class ShapeStacksDataset(Dataset): 105 | 106 | def __init__(self, data_dir, split_name, mode, img_size=224, 107 | load_instances=True, shuffle_files=False): 108 | self.data_dir = data_dir 109 | self.img_size = img_size 110 | self.load_instances = load_instances 111 | 112 | # Files 113 | split_dir = os.path.join(data_dir, 'splits', split_name) 114 | self.filenames, self.stability_labels = _get_filenames_with_labels( 115 | mode, data_dir, split_dir) 116 | 117 | # Shuffle files? 
118 | if shuffle_files: 119 | print(f"Shuffling {len(self.filenames)} files") 120 | idx = np.arange(len(self.filenames), dtype='int32') 121 | np.random.shuffle(idx) 122 | self.filenames = [self.filenames[i] for i in list(idx)] 123 | self.stability_labels = [self.stability_labels[i] for i in list(idx)] 124 | 125 | # Transforms 126 | T = [transforms.CenterCrop(CENTRE_CROP)] 127 | if img_size != CENTRE_CROP: 128 | T.append(transforms.Resize(img_size)) 129 | T.append(transforms.ToTensor()) 130 | self.transform = transforms.Compose(T) 131 | 132 | def __len__(self): 133 | return len(self.filenames) 134 | 135 | def __getitem__(self, idx): 136 | # --- Load image --- 137 | # File name example: 138 | # data_dir + /recordings/env_ccs-hard-h=2-vcom=0-vpsf=0-v=60/ 139 | # rgb-w=5-f=2-l=1-c=unique-cam_7-mono-0.png 140 | file = self.filenames[idx] 141 | img = Image.open(file) 142 | output = {'input': self.transform(img)} 143 | 144 | # --- Load instances --- 145 | if self.load_instances: 146 | file_split = file.split('/') 147 | # cam = file_split[4].split('-')[5][4:] 148 | # map_path = os.path.join( 149 | # self.data_dir, 'iseg', file_split[3], 150 | # 'iseg-w=0-f=0-l=0-c=original-cam_' + cam + '-mono-0.map') 151 | cam = file_split[-1].split('-')[5][4:] 152 | map_path = os.path.join( 153 | self.data_dir, 'iseg', file_split[-2], 154 | 'iseg-w=0-f=0-l=0-c=original-cam_' + cam + '-mono-0.map') 155 | masks = load_segmap_as_matrix(map_path) 156 | masks = np.expand_dims(masks, 0) 157 | masks = np_img_centre_crop(masks, CENTRE_CROP) 158 | masks = torch.FloatTensor(masks) 159 | if self.img_size != masks.shape[2]: 160 | masks = masks.unsqueeze(0) 161 | masks = F.interpolate(masks, size=self.img_size) 162 | masks = masks.squeeze(0) 163 | output['instances'] = masks.type(torch.LongTensor) 164 | 165 | return output 166 | -------------------------------------------------------------------------------- /datasets/sketchy_config.py: -------------------------------------------------------------------------------- 1 | # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | import os 15 | from glob import glob 16 | 17 | from torch.utils.data import Dataset, DataLoader 18 | from torchvision import transforms 19 | 20 | from PIL import Image 21 | 22 | from forge import flags 23 | from forge.experiment_tools import fprint 24 | 25 | from utils.misc import loader_throughput 26 | 27 | 28 | flags.DEFINE_string('data_folder', 'data/sketchy', 'Path to data folder.') 29 | flags.DEFINE_integer('num_workers', 4, 'Number of threads for loading data.') 30 | flags.DEFINE_integer('img_size', 128, 'Dimension of images. 
Images are square.') 31 | # Object slots: 3 objects, robot base, gripper, wrist, arm, ground, cables, wall 32 | flags.DEFINE_integer('K_steps', 10, 'Number of object slots.') 33 | 34 | 35 | def load(cfg, **unused_kwargs): 36 | del unused_kwargs 37 | if not os.path.exists(cfg.data_folder): 38 | raise Exception("Data folder does not exist.") 39 | fprint(f"Using {cfg.num_workers} data workers.") 40 | 41 | assert cfg.img_size == 128 42 | 43 | tng_set = SketchyDataset(cfg.data_folder, 'train') 44 | val_set = SketchyDataset(cfg.data_folder, 'valid') 45 | tst_set = SketchyDataset(cfg.data_folder, 'test') 46 | 47 | tng_loader = DataLoader( 48 | tng_set, 49 | batch_size=cfg.batch_size, 50 | shuffle=True, 51 | num_workers=cfg.num_workers) 52 | val_loader = DataLoader( 53 | val_set, 54 | batch_size=cfg.batch_size, 55 | shuffle=True, 56 | num_workers=cfg.num_workers) 57 | tst_loader = DataLoader( 58 | tst_set, 59 | batch_size=1, 60 | shuffle=True, 61 | num_workers=1) 62 | 63 | if not cfg.debug: 64 | loader_throughput(tng_loader) 65 | 66 | return tng_loader, val_loader, tst_loader 67 | 68 | 69 | class SketchyDataset(Dataset): 70 | 71 | def __init__(self, data_dir, mode): 72 | split_file = f'{data_dir}/processed/{mode}_images.txt' 73 | if os.path.exists(split_file): 74 | fprint(f"Reading paths for {mode} files...") 75 | with open(split_file, "r") as f: 76 | self.filenames = f.readlines() 77 | self.filenames = [item.strip() for item in self.filenames] 78 | else: 79 | fprint(f"Searching for {mode} files...") 80 | self.filenames = glob(f'{data_dir}/processed/{mode}/ep*/ep*.png') 81 | with open(split_file, "w") as f: 82 | for item in self.filenames: 83 | f.write(f'{item}\n') 84 | fprint(f"Found {len(self.filenames)}.") 85 | 86 | def __len__(self): 87 | return len(self.filenames) 88 | 89 | def __getitem__(self, idx): 90 | file = self.filenames[idx] 91 | img = Image.open(file) 92 | return {'input': transforms.functional.to_tensor(img)} 93 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: genesis_env 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _tflow_select=2.3.0=mkl 8 | - absl-py=0.8.1=py37_0 9 | - astor=0.8.0=py37_0 10 | - blas=1.0=mkl 11 | - c-ares=1.15.0=h7b6447c_1001 12 | - ca-certificates=2019.11.27=0 13 | - certifi=2019.11.28=py37_0 14 | - cffi=1.13.2=py37h2e261b9_0 15 | - cudatoolkit=10.1.243=h6bb024c_0 16 | - cycler=0.10.0=py37_0 17 | - dbus=1.13.12=h746ee38_0 18 | - expat=2.2.6=he6710b0_0 19 | - fontconfig=2.13.0=h9420a91_0 20 | - freetype=2.9.1=h8a8886c_1 21 | - gast=0.3.2=py_0 22 | - glib=2.63.1=h5a9c865_0 23 | - google-pasta=0.1.8=py_0 24 | - grpcio=1.16.1=py37hf8bcb03_1 25 | - gst-plugins-base=1.14.0=hbbd80ab_1 26 | - gstreamer=1.14.0=hb453b48_1 27 | - h5py=2.9.0=py37h7918eee_0 28 | - hdf5=1.10.4=hb1b8bf9_0 29 | - icu=58.2=h9c2bf20_1 30 | - imageio=2.6.1=py37_0 31 | - intel-openmp=2019.4=243 32 | - jpeg=9b=h024ee3a_2 33 | - keras-applications=1.0.8=py_0 34 | - keras-preprocessing=1.1.0=py_1 35 | - kiwisolver=1.1.0=py37he6710b0_0 36 | - libedit=3.1.20181209=hc058e9b_0 37 | - libffi=3.2.1=hd88cf55_4 38 | - libgcc-ng=9.1.0=hdf63c60_0 39 | - libgfortran-ng=7.3.0=hdf63c60_0 40 | - libpng=1.6.37=hbc83047_0 41 | - libprotobuf=3.11.2=hd408876_0 42 | - libstdcxx-ng=9.1.0=hdf63c60_0 43 | - libtiff=4.1.0=h2733197_0 44 | - libuuid=1.0.3=h1bed415_2 45 | - libxcb=1.13=h1bed415_1 46 | - libxml2=2.9.9=hea5a465_1 47 | - 
markdown=3.1.1=py37_0 48 | - matplotlib=3.1.1=py37h5429711_0 49 | - mkl=2019.4=243 50 | - mkl-service=2.3.0=py37he904b0f_0 51 | - mkl_fft=1.0.15=py37ha843d7b_0 52 | - mkl_random=1.1.0=py37hd6b4f25_0 53 | - ncurses=6.1=he6710b0_1 54 | - ninja=1.9.0=py37hfd86e86_0 55 | - numpy=1.17.4=py37hc1035e2_0 56 | - numpy-base=1.17.4=py37hde5b4d6_0 57 | - olefile=0.46=py37_0 58 | - openssl=1.1.1d=h7b6447c_3 59 | - pcre=8.43=he6710b0_0 60 | - pillow=6.2.1=py37h34e0f95_0 61 | - pip=19.3.1=py37_0 62 | - pycparser=2.19=py37_0 63 | - pyparsing=2.4.5=py_0 64 | - pyqt=5.9.2=py37h05f1152_2 65 | - python=3.7.5=h0371630_0 66 | - python-dateutil=2.8.1=py_0 67 | - pytorch=1.3.1=py3.7_cuda10.1.243_cudnn7.6.3_0 68 | - pytz=2019.3=py_0 69 | - qt=5.9.7=h5867ecd_1 70 | - readline=7.0=h7b6447c_5 71 | - scipy=1.3.2=py37h7c811a0_0 72 | - setuptools=42.0.2=py37_0 73 | - simplejson=3.17.0=py37h7b6447c_0 74 | - sip=4.19.8=py37hf484d3e_0 75 | - six=1.13.0=py37_0 76 | - sqlite=3.30.1=h7b6447c_0 77 | - tensorboard=1.14.0=py37hf484d3e_0 78 | - tensorflow=1.14.0=mkl_py37h45c423b_0 79 | - tensorflow-base=1.14.0=mkl_py37h7ce6ba3_0 80 | - tensorflow-estimator=1.14.0=py_0 81 | - termcolor=1.1.0=py37_1 82 | - tk=8.6.8=hbc83047_0 83 | - torchvision=0.4.2=py37_cu101 84 | - tornado=6.0.3=py37h7b6447c_0 85 | - tqdm=4.40.2=py_0 86 | - werkzeug=0.16.0=py_0 87 | - wheel=0.33.6=py37_0 88 | - wrapt=1.11.2=py37h7b6447c_0 89 | - xz=5.2.4=h14c3975_4 90 | - zlib=1.2.11=h7b6447c_3 91 | - zstd=1.3.7=h0b5b093_0 92 | - pip: 93 | - argcomplete==1.11.0 94 | - attrdict==2.0.1 95 | - boto==2.49.0 96 | - crcmod==1.7 97 | - cryptography==2.8 98 | - fasteners==0.15 99 | - gcs-oauth2-boto-plugin==2.5 100 | - google-apitools==0.5.30 101 | - google-reauth==0.1.0 102 | - gsutil==4.46 103 | - httplib2==0.15.0 104 | - importlib-metadata==1.3.0 105 | - joblib==0.14.1 106 | - mock==2.0.0 107 | - monotonic==1.5 108 | - more-itertools==8.0.2 109 | - oauth2client==4.1.3 110 | - pbr==5.4.4 111 | - protobuf==3.11.2 112 | - pyasn1==0.4.8 113 | - pyasn1-modules==0.2.7 114 | - pyopenssl==19.1.0 115 | - pyu2f==0.1.4 116 | - retry-decorator==1.1.0 117 | - rsa==4.0 118 | - scikit-learn==0.22 119 | - sklearn==0.0 120 | - socksipy-branch==1.1 121 | - tensorboardx==1.9 122 | - zipp==0.6.0 123 | prefix: /opt/conda/envs/genesis_env -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/applied-ai-lab/genesis/9abf202bbad6fa4a675117fdea0be163e4f16695/models/__init__.py -------------------------------------------------------------------------------- /models/genesisv2_config.py: -------------------------------------------------------------------------------- 1 | # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 
11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | from attrdict import AttrDict 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | from torch.distributions.normal import Normal 20 | import numpy as np 21 | 22 | from forge import flags 23 | 24 | from modules.unet import UNet 25 | import modules.attention as attention 26 | import modules.blocks as B 27 | 28 | from models.genesis_config import Genesis 29 | from models.monet_config import MONet 30 | 31 | import utils.misc as misc 32 | 33 | 34 | # Architecture 35 | flags.DEFINE_integer('feat_dim', 64, 'Number of features and latents.') 36 | # Segmentation 37 | flags.DEFINE_string('kernel', 'gaussian', '{laplacian, gaussian, epanechnikov') 38 | flags.DEFINE_boolean('semiconv', True, 'Use semi-convolutional embeddings.') 39 | flags.DEFINE_boolean('dynamic_K', False, 'Dynamic K.') 40 | # Auxiliary mask consistency loss 41 | flags.DEFINE_boolean('klm_loss', False, 'KL mask regulariser.') 42 | flags.DEFINE_boolean('detach_mr_in_klm', True, 'Detach reconstructed masks.') 43 | 44 | 45 | def load(cfg): 46 | return GenesisV2(cfg) 47 | 48 | 49 | class GenesisV2(nn.Module): 50 | 51 | def __init__(self, cfg): 52 | super(GenesisV2, self).__init__() 53 | # Configuration 54 | self.K_steps = cfg.K_steps 55 | self.pixel_bound = cfg.pixel_bound 56 | self.feat_dim = cfg.feat_dim 57 | self.klm_loss = cfg.klm_loss 58 | self.detach_mr_in_klm = cfg.detach_mr_in_klm 59 | self.dynamic_K = cfg.dynamic_K 60 | self.debug = cfg.debug 61 | self.multi_gpu = cfg.multi_gpu 62 | # Encoder 63 | self.encoder = UNet( 64 | num_blocks=int(np.log2(cfg.img_size)-1), 65 | img_size=cfg.img_size, 66 | filter_start=min(cfg.feat_dim, 64), 67 | in_chnls=3, 68 | out_chnls=cfg.feat_dim, 69 | norm='gn') 70 | self.encoder.final_conv = nn.Identity() 71 | self.att_process = attention.InstanceColouringSBP( 72 | img_size=cfg.img_size, 73 | kernel=cfg.kernel, 74 | colour_dim=8, 75 | K_steps=self.K_steps, 76 | feat_dim=cfg.feat_dim, 77 | semiconv=cfg.semiconv) 78 | self.seg_head = B.ConvGNReLU(cfg.feat_dim, cfg.feat_dim, 3, 1, 1) 79 | self.feat_head = nn.Sequential( 80 | B.ConvGNReLU(cfg.feat_dim, cfg.feat_dim, 3, 1, 1), 81 | nn.Conv2d(cfg.feat_dim, 2*cfg.feat_dim, 1)) 82 | self.z_head = nn.Sequential( 83 | nn.LayerNorm(2*cfg.feat_dim), 84 | nn.Linear(2*cfg.feat_dim, 2*cfg.feat_dim), 85 | nn.ReLU(inplace=True), 86 | nn.Linear(2*cfg.feat_dim, 2*cfg.feat_dim)) 87 | # Decoder 88 | c = cfg.feat_dim 89 | self.decoder_module = nn.Sequential( 90 | B.BroadcastLayer(cfg.img_size // 16), 91 | nn.ConvTranspose2d(cfg.feat_dim+2, c, 5, 2, 2, 1), 92 | nn.GroupNorm(8, c), nn.ReLU(inplace=True), 93 | nn.ConvTranspose2d(c, c, 5, 2, 2, 1), 94 | nn.GroupNorm(8, c), nn.ReLU(inplace=True), 95 | nn.ConvTranspose2d(c, min(c, 64), 5, 2, 2, 1), 96 | nn.GroupNorm(8, min(c, 64)), nn.ReLU(inplace=True), 97 | nn.ConvTranspose2d(min(c, 64), min(c, 64), 5, 2, 2, 1), 98 | nn.GroupNorm(8, min(c, 64)), nn.ReLU(inplace=True), 99 | nn.Conv2d(min(c, 64), 4, 1)) 100 | # --- Prior --- 101 | self.autoreg_prior = cfg.autoreg_prior 102 | self.prior_lstm, self.prior_linear = None, None 103 | if self.autoreg_prior and self.K_steps > 1: 104 | self.prior_lstm = nn.LSTM(cfg.feat_dim, 4*cfg.feat_dim) 105 | self.prior_linear = nn.Linear(4*cfg.feat_dim, 2*cfg.feat_dim) 106 | # --- Output pixel distribution --- 107 | assert cfg.pixel_std1 == cfg.pixel_std2 108 | self.std = cfg.pixel_std1 109 | 110 | def forward(self, x): 111 | batch_size, _, H, W = x.shape 112 
| 113 | # --- Extract features --- 114 | enc_feat, _ = self.encoder(x) 115 | enc_feat = F.relu(enc_feat) 116 | 117 | # --- Predict attention masks --- 118 | if self.dynamic_K: 119 | if batch_size > 1: 120 | # Iterate over individual elements in batch 121 | log_m_k = [[] for _ in range(self.K_steps)] 122 | att_stats, log_s_k = None, None 123 | for f in torch.split(enc_feat, 1, dim=0): 124 | log_m_k_b, _, _ = self.att_process( 125 | self.seg_head(f), self.K_steps-1, dynamic_K=True) 126 | for step in range(self.K_steps): 127 | if step < len(log_m_k_b): 128 | log_m_k[step].append(log_m_k_b[step]) 129 | else: 130 | log_m_k[step].append(-1e10*torch.ones([1, 1, H, W])) 131 | for step in range(self.K_steps): 132 | log_m_k[step] = torch.cat(log_m_k[step], dim=0) 133 | if self.debug: 134 | assert len(log_m_k) == self.K_steps 135 | else: 136 | log_m_k, log_s_k, att_stats = self.att_process( 137 | self.seg_head(enc_feat), self.K_steps-1, dynamic_K=True) 138 | else: 139 | log_m_k, log_s_k, att_stats = self.att_process( 140 | self.seg_head(enc_feat), self.K_steps-1, dynamic_K=False) 141 | if self.debug: 142 | assert len(log_m_k) == self.K_steps 143 | 144 | # -- Object features, latents, and KL 145 | comp_stats = AttrDict(mu_k=[], sigma_k=[], z_k=[], kl_l_k=[], q_z_k=[]) 146 | for log_m in log_m_k: 147 | mask = log_m.exp() 148 | # Masked sum 149 | obj_feat = mask * self.feat_head(enc_feat) 150 | obj_feat = obj_feat.sum((2, 3)) 151 | # Normalise 152 | obj_feat = obj_feat / (mask.sum((2, 3)) + 1e-5) 153 | # Posterior 154 | mu, sigma_ps = self.z_head(obj_feat).chunk(2, dim=1) 155 | sigma = B.to_sigma(sigma_ps) 156 | q_z = Normal(mu, sigma) 157 | z = q_z.rsample() 158 | comp_stats['mu_k'].append(mu) 159 | comp_stats['sigma_k'].append(sigma) 160 | comp_stats['z_k'].append(z) 161 | comp_stats['q_z_k'].append(q_z) 162 | 163 | # --- Decode latents --- 164 | recon, x_r_k, log_m_r_k = self.decode_latents(comp_stats.z_k) 165 | 166 | # --- Loss terms --- 167 | losses = AttrDict() 168 | # -- Reconstruction loss 169 | losses['err'] = Genesis.x_loss(x, log_m_r_k, x_r_k, self.std) 170 | mx_r_k = [x*logm.exp() for x, logm in zip(x_r_k, log_m_r_k)] 171 | # -- Optional: Attention mask loss 172 | if self.klm_loss: 173 | if self.detach_mr_in_klm: 174 | log_m_r_k = [m.detach() for m in log_m_r_k] 175 | losses['kl_m'] = MONet.kl_m_loss( 176 | log_m_k=log_m_k, log_m_r_k=log_m_r_k, debug=self.debug) 177 | # -- Component KL 178 | losses['kl_l_k'], p_z_k = Genesis.mask_latent_loss( 179 | comp_stats.q_z_k, comp_stats.z_k, 180 | prior_lstm=self.prior_lstm, prior_linear=self.prior_linear, 181 | debug=self.debug) 182 | 183 | # Track quantities of interest 184 | stats = AttrDict( 185 | recon=recon, log_m_k=log_m_k, log_s_k=log_s_k, x_r_k=x_r_k, 186 | log_m_r_k=log_m_r_k, mx_r_k=mx_r_k, 187 | instance_seg=torch.argmax(torch.cat(log_m_k, dim=1), dim=1), 188 | instance_seg_r=torch.argmax(torch.cat(log_m_r_k, dim=1), dim=1)) 189 | 190 | # Sanity checks 191 | if self.debug: 192 | if not self.dynamic_K: 193 | assert len(log_m_k) == self.K_steps 194 | assert len(log_m_r_k) == self.K_steps 195 | misc.check_log_masks(log_m_k) 196 | misc.check_log_masks(log_m_r_k) 197 | 198 | if self.multi_gpu: 199 | # q_z_k is a torch.distribution which doesn't work with the 200 | # gathering used by DataParallel. 
201 | del comp_stats['q_z_k'] 202 | 203 | return recon, losses, stats, att_stats, comp_stats 204 | 205 | def decode_latents(self, z_k): 206 | # --- Reconstruct components and image --- 207 | x_r_k, m_r_logits_k = [], [] 208 | for z in z_k: 209 | dec = self.decoder_module(z) 210 | x_r_k.append(dec[:, :3, :, :]) 211 | m_r_logits_k.append(dec[:, 3: , :, :]) 212 | # Optional: Apply pixelbound 213 | if self.pixel_bound: 214 | x_r_k = [torch.sigmoid(item) for item in x_r_k] 215 | # --- Reconstruct masks --- 216 | log_m_r_stack = MONet.get_mask_recon_stack( 217 | m_r_logits_k, 'softmax', log=True) 218 | log_m_r_k = torch.split(log_m_r_stack, 1, dim=4) 219 | log_m_r_k = [m[:, :, :, :, 0] for m in log_m_r_k] 220 | # --- Reconstruct input image by marginalising (aka summing) --- 221 | x_r_stack = torch.stack(x_r_k, dim=4) 222 | m_r_stack = torch.stack(log_m_r_k, dim=4).exp() 223 | recon = (m_r_stack * x_r_stack).sum(dim=4) 224 | 225 | return recon, x_r_k, log_m_r_k 226 | 227 | def sample(self, batch_size, K_steps=None): 228 | K_steps = self.K_steps if K_steps is None else K_steps 229 | 230 | # Sample latents 231 | if self.autoreg_prior: 232 | z_k = [Normal(0, 1).sample([batch_size, self.feat_dim])] 233 | state = None 234 | for k in range(1, K_steps): 235 | # TODO(martin): reuse code from forward method? 236 | lstm_out, state = self.prior_lstm( 237 | z_k[-1].view(1, batch_size, -1), state) 238 | linear_out = self.prior_linear(lstm_out) 239 | linear_out = torch.chunk(linear_out, 2, dim=2) 240 | linear_out = [item.squeeze(0) for item in linear_out] 241 | mu = torch.tanh(linear_out[0]) 242 | sigma = B.to_prior_sigma(linear_out[1]) 243 | p_z = Normal(mu.view([batch_size, self.feat_dim]), 244 | sigma.view([batch_size, self.feat_dim])) 245 | z_k.append(p_z.sample()) 246 | else: 247 | p_z = Normal(0, 1) 248 | z_k = [p_z.sample([batch_size, self.feat_dim]) 249 | for _ in range(K_steps)] 250 | 251 | # Decode latents 252 | recon, x_r_k, log_m_r_k = self.decode_latents(z_k) 253 | 254 | stats = AttrDict(x_k=x_r_k, log_m_k=log_m_r_k, 255 | mx_k=[x*m.exp() for x, m in zip(x_r_k, log_m_r_k)]) 256 | return recon, stats 257 | -------------------------------------------------------------------------------- /models/monet_config.py: -------------------------------------------------------------------------------- 1 | # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 
11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | from attrdict import AttrDict 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | 20 | from torch.distributions.normal import Normal 21 | from torch.distributions.categorical import Categorical 22 | from torch.distributions.kl import kl_divergence 23 | 24 | import numpy as np 25 | 26 | from forge import flags 27 | 28 | from modules.unet import UNet 29 | import modules.attention as attention 30 | from modules.component_vae import ComponentVAE 31 | from models.genesis_config import Genesis 32 | from utils import misc 33 | 34 | 35 | # Attention network 36 | flags.DEFINE_integer('filter_start', 32, 'Starting number of channels in UNet.') 37 | flags.DEFINE_string('prior_mode', 'softmax', '{scope, softmax}') 38 | 39 | 40 | def load(cfg): 41 | return MONet(cfg) 42 | 43 | 44 | class MONet(nn.Module): 45 | 46 | def __init__(self, cfg): 47 | super(MONet, self).__init__() 48 | # Configuration 49 | self.K_steps = cfg.K_steps 50 | self.prior_mode = cfg.prior_mode 51 | self.mckl = cfg.montecarlo_kl 52 | self.debug = cfg.debug 53 | self.pixel_bound = cfg.pixel_bound 54 | # Sub-Modules 55 | # - Attention Network 56 | if not hasattr(cfg, 'filter_start'): 57 | cfg['filter_start'] = 32 58 | core = UNet( 59 | num_blocks=int(np.log2(cfg.img_size)-1), 60 | img_size=cfg.img_size, 61 | filter_start=cfg.filter_start, 62 | in_chnls=4, 63 | out_chnls=1, 64 | norm='in') 65 | self.att_process = attention.SimpleSBP(core) 66 | # - Component VAE 67 | self.comp_vae = ComponentVAE(nout=4, cfg=cfg, act=nn.ReLU()) 68 | self.comp_vae.pixel_bound = False 69 | # Initialise pixel output standard deviations 70 | std = cfg.pixel_std2 * torch.ones(1, 1, 1, 1, self.K_steps) 71 | std[0, 0, 0, 0, 0] = cfg.pixel_std1 # first step 72 | self.register_buffer('std', std) 73 | 74 | def forward(self, x): 75 | """ 76 | Args: 77 | x (torch.Tensor): Input images [batch size, 3, dim, dim] 78 | """ 79 | # --- Predict segmentation masks --- 80 | log_m_k, log_s_k, att_stats = self.att_process(x, self.K_steps-1) 81 | 82 | # --- Reconstruct components --- 83 | x_m_r_k, comp_stats = self.comp_vae(x, log_m_k) 84 | # Split into appearances and mask prior 85 | x_r_k = [item[:, :3, :, :] for item in x_m_r_k] 86 | m_r_logits_k = [item[:, 3:, :, :] for item in x_m_r_k] 87 | # Apply pixelbound 88 | if self.pixel_bound: 89 | x_r_k = [torch.sigmoid(item) for item in x_r_k] 90 | 91 | # --- Reconstruct input image by marginalising (aka summing) --- 92 | x_r_stack = torch.stack(x_r_k, dim=4) 93 | m_stack = torch.stack(log_m_k, dim=4).exp() 94 | recon = (m_stack * x_r_stack).sum(dim=4) 95 | 96 | # --- Reconstruct masks --- 97 | log_m_r_stack = self.get_mask_recon_stack( 98 | m_r_logits_k, self.prior_mode, log=True) 99 | log_m_r_k = torch.split(log_m_r_stack, 1, dim=4) 100 | log_m_r_k = [m[:, :, :, :, 0] for m in log_m_r_k] 101 | 102 | # --- Loss terms --- 103 | losses = AttrDict() 104 | # -- Reconstruction loss 105 | losses['err'] = Genesis.x_loss(x, log_m_k, x_r_k, self.std) 106 | # -- Attention mask KL 107 | losses['kl_m'] = self.kl_m_loss(log_m_k=log_m_k, log_m_r_k=log_m_r_k) 108 | # -- Component KL 109 | q_z_k = [Normal(m, s) for m, s in 110 | zip(comp_stats.mu_k, comp_stats.sigma_k)] 111 | kl_l_k = misc.get_kl( 112 | comp_stats.z_k, q_z_k, len(q_z_k)*[Normal(0, 1)], self.mckl) 113 | losses['kl_l_k'] = [kld.sum(1) for kld in kl_l_k] 114 | 115 | # Track quantities of interest 116 | stats = AttrDict( 117 | recon=recon, 
log_m_k=log_m_k, log_s_k=log_s_k, x_r_k=x_r_k, 118 | log_m_r_k=log_m_r_k, 119 | mx_r_k=[x*logm.exp() for x, logm in zip(x_r_k, log_m_k)]) 120 | 121 | # Sanity check that masks sum to one if in debug mode 122 | if self.debug: 123 | assert len(log_m_k) == self.K_steps 124 | assert len(log_m_r_k) == self.K_steps 125 | misc.check_log_masks(log_m_k) 126 | misc.check_log_masks(log_m_r_k) 127 | 128 | return recon, losses, stats, att_stats, comp_stats 129 | 130 | def get_features(self, image_batch): 131 | with torch.no_grad(): 132 | _, _, _, _, comp_stats = self.forward(image_batch) 133 | return torch.cat(comp_stats.z_k, dim=1) 134 | 135 | @staticmethod 136 | def get_mask_recon_stack(m_r_logits_k, prior_mode, log): 137 | if prior_mode == 'softmax': 138 | if log: 139 | return F.log_softmax(torch.stack(m_r_logits_k, dim=4), dim=4) 140 | return F.softmax(torch.stack(m_r_logits_k, dim=4), dim=4) 141 | elif prior_mode == 'scope': 142 | log_m_r_k = [] 143 | log_s = torch.zeros_like(m_r_logits_k[0]) 144 | for step, logits in enumerate(m_r_logits_k): 145 | if step == len(m_r_logits_k) - 1: 146 | log_m_r_k.append(log_s) 147 | else: 148 | log_a = F.logsigmoid(logits) 149 | log_neg_a = F.logsigmoid(-logits) 150 | log_m_r_k.append(log_s + log_a) 151 | log_s = log_s + log_neg_a 152 | log_m_r_stack = torch.stack(log_m_r_k, dim=4) 153 | return log_m_r_stack if log else log_m_r_stack.exp() 154 | else: 155 | raise ValueError("No valid prior mode.") 156 | 157 | @staticmethod 158 | def kl_m_loss(log_m_k, log_m_r_k, debug=False): 159 | if debug: 160 | assert len(log_m_k) == len(log_m_r_k) 161 | batch_size = log_m_k[0].size(0) 162 | m_stack = torch.stack(log_m_k, dim=4).exp() 163 | m_r_stack = torch.stack(log_m_r_k, dim=4).exp() 164 | # Lower bound to 1e-5 to avoid infinities 165 | m_stack = torch.max(m_stack, torch.tensor(1e-5)) 166 | m_r_stack = torch.max(m_r_stack, torch.tensor(1e-5)) 167 | q_m = Categorical(m_stack.view(-1, len(log_m_k))) 168 | p_m = Categorical(m_r_stack.view(-1, len(log_m_k))) 169 | kl_m_ppc = kl_divergence(q_m, p_m).view(batch_size, -1) 170 | return kl_m_ppc.sum(dim=1) 171 | 172 | def sample(self, batch_size, K_steps=None): 173 | K_steps = self.K_steps if K_steps is None else K_steps 174 | # Sample latents 175 | z_batched = Normal(0, 1).sample((batch_size*K_steps, self.comp_vae.ldim)) 176 | # Pass latent through decoder 177 | x_hat_batched = self.comp_vae.decode(z_batched) 178 | # Split into appearances and masks 179 | x_r_batched = x_hat_batched[:, :3, :, :] 180 | m_r_logids_batched = x_hat_batched[:, 3:, :, :] 181 | # Apply pixel bound to appearances 182 | if self.pixel_bound: 183 | x_r_batched = torch.sigmoid(x_r_batched) 184 | # Chunk into K steps 185 | x_r_k = torch.chunk(x_r_batched, K_steps, dim=0) 186 | m_r_logits_k = torch.chunk(m_r_logids_batched, K_steps, dim=0) 187 | # Normalise masks 188 | m_r_stack = self.get_mask_recon_stack( 189 | m_r_logits_k, self.prior_mode, log=False) 190 | # Apply masking and sum to get generated image 191 | x_r_stack = torch.stack(x_r_k, dim=4) 192 | gen_image = (m_r_stack * x_r_stack).sum(dim=4) 193 | # Tracking 194 | log_m_r_k = [item.squeeze(dim=4) for item in 195 | torch.split(m_r_stack.log(), 1, dim=4)] 196 | stats = AttrDict(gen_image=gen_image, x_k=x_r_k, log_m_k=log_m_r_k, 197 | mx_k=[x*m.exp() for x, m in zip(x_r_k, log_m_r_k)]) 198 | return gen_image, stats 199 | -------------------------------------------------------------------------------- /models/vae_config.py: -------------------------------------------------------------------------------- 1 
| # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | from attrdict import AttrDict 15 | 16 | import torch 17 | import torch.nn as nn 18 | from torch.distributions.normal import Normal 19 | 20 | from forge import flags 21 | 22 | from modules.blocks import Flatten 23 | from modules.decoders import BroadcastDecoder 24 | from third_party.sylvester.VAE import VAE 25 | 26 | 27 | # GatedConvVAE 28 | flags.DEFINE_integer('latent_dimension', 64, 'Latent channels.') 29 | flags.DEFINE_boolean('broadcast_decoder', False, 30 | 'Use broadcast decoder instead of deconv.') 31 | # Losses 32 | flags.DEFINE_boolean('pixel_bound', True, 'Bound pixel values to [0, 1].') 33 | flags.DEFINE_float('pixel_std', 0.7, 'StdDev of reconstructed pixels.') 34 | 35 | 36 | def load(cfg): 37 | return BaselineVAE(cfg) 38 | 39 | 40 | class BaselineVAE(nn.Module): 41 | 42 | def __init__(self, cfg): 43 | super(BaselineVAE, self).__init__() 44 | cfg.K_steps = None 45 | # Configuration 46 | self.ldim = cfg.latent_dimension 47 | self.pixel_std = cfg.pixel_std 48 | self.pixel_bound = cfg.pixel_bound 49 | self.debug = cfg.debug 50 | # Module 51 | nin = cfg.input_channels if hasattr(cfg, 'input_channels') else 3 52 | self.vae = VAE(self.ldim, [nin, cfg.img_size, cfg.img_size], nin) 53 | if cfg.broadcast_decoder: 54 | self.vae.p_x_nn = nn.Sequential( 55 | Flatten(), 56 | BroadcastDecoder(in_chnls=self.ldim, out_chnls=64, h_chnls=64, 57 | num_layers=4, img_dim=cfg.img_size, 58 | act=nn.ELU()), 59 | nn.ELU() 60 | ) 61 | self.vae.p_x_mean = nn.Conv2d(64, nin, 1, 1, 0) 62 | 63 | def forward(self, x): 64 | """ x (torch.Tensor): Input images [batch size, 3, dim, dim] """ 65 | # Forward propagation 66 | recon, stats = self.vae(x) 67 | if self.pixel_bound: 68 | recon = torch.sigmoid(recon) 69 | # Reconstruction loss 70 | p_xr = Normal(recon, self.pixel_std) 71 | err = -p_xr.log_prob(x).sum(dim=(1, 2, 3)) 72 | # KL divergence loss 73 | p_z = Normal(0, 1) 74 | # TODO(martin): the parsing below is not very intuitive 75 | # -- No flow 76 | if 'z' in stats: 77 | q_z = Normal(stats.mu, stats.sigma) 78 | kl = q_z.log_prob(stats.z) - p_z.log_prob(stats.z) 79 | kl = kl.sum(dim=1) 80 | # -- Using normalising flow 81 | else: 82 | q_z_0 = Normal(stats.mu_0, stats.sigma_0) 83 | kl = q_z_0.log_prob(stats.z_0) - p_z.log_prob(stats.z_k) 84 | kl = kl.sum(dim=1) - stats.ldj 85 | # Tracking 86 | losses = AttrDict(err=err, kl_l=kl) 87 | return recon, losses, stats, None, None 88 | 89 | def sample(self, batch_size, *args, **kwargs): 90 | # Sample z 91 | z = Normal(0, 1).sample([batch_size, self.ldim]) 92 | # Decode z 93 | x = self.vae.decode(z) 94 | if self.pixel_bound: 95 | x = torch.sigmoid(x) 96 | return x, AttrDict(z=z) 97 | 98 | def get_features(self, image_batch): 99 | with torch.no_grad(): 100 | _, _, stats, _, _ = self.forward(image_batch) 101 | return stats.z 102 | -------------------------------------------------------------------------------- 
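Usage sketch: the model configs in this repository share a common interface: load(cfg) returns an nn.Module whose forward pass yields a (recon, losses, stats, att_stats, comp_stats) tuple and which supports unconditional sampling via sample(batch_size). The snippet below is a minimal, hypothetical example of exercising the BaselineVAE above on random data; it assumes the repository root is on the Python path with the forge submodule and dependencies installed, and the config values (img_size of 64, batch of 8) are illustrative assumptions rather than recommended settings.

# Minimal usage sketch, not part of the repository. Assumes it is run from the
# repo root with the forge submodule and dependencies available; the config
# values below are illustrative assumptions, not recommended settings.
import torch
from attrdict import AttrDict
from models import vae_config

cfg = AttrDict(latent_dimension=64, broadcast_decoder=False,
               pixel_bound=True, pixel_std=0.7,
               img_size=64, debug=False)
model = vae_config.load(cfg)

x = torch.rand(8, 3, cfg.img_size, cfg.img_size)  # dummy image batch in [0, 1]
recon, losses, stats, _, _ = model(x)             # forward returns a 5-tuple
neg_elbo = (losses.err + losses.kl_l).mean()      # reconstruction error plus KL
samples, _ = model.sample(batch_size=8)           # unconditional generation
print(recon.shape, neg_elbo.item(), samples.shape)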
/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/applied-ai-lab/genesis/9abf202bbad6fa4a675117fdea0be163e4f16695/modules/__init__.py -------------------------------------------------------------------------------- /modules/attention.py: -------------------------------------------------------------------------------- 1 | # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | from attrdict import AttrDict 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | 20 | import numpy as np 21 | 22 | from modules import blocks as B 23 | 24 | 25 | class SimpleSBP(nn.Module): 26 | 27 | def __init__(self, core): 28 | super(SimpleSBP, self).__init__() 29 | self.core = core 30 | 31 | def forward(self, x, steps_to_run): 32 | # Initialise lists to store tensors over K steps 33 | log_m_k = [] 34 | # Set initial scope to all ones, so log scope is all zeros 35 | log_s_k = [torch.zeros_like(x)[:, :1, :, :]] 36 | # Loop over steps 37 | for step in range(steps_to_run): 38 | # Compute mask and update scope. Last step is different 39 | # Compute a_logits given input and current scope 40 | core_out, _ = self.core(torch.cat((x, log_s_k[step]), dim=1)) 41 | # Take first channel as logits for masks 42 | a_logits = core_out[:, :1, :, :] 43 | log_a = F.logsigmoid(a_logits) 44 | log_neg_a = F.logsigmoid(-a_logits) 45 | # Compute mask. Note that old scope needs to be used!! 46 | log_m_k.append(log_s_k[step] + log_a) 47 | # Update scope given attention 48 | log_s_k.append(log_s_k[step] + log_neg_a) 49 | # Set mask equal to scope for last step 50 | log_m_k.append(log_s_k[-1]) 51 | return log_m_k, log_s_k, {} 52 | 53 | def masks_from_zm_k(self, zm_k, img_size): 54 | # zm_k: K*(batch_size, ldim) 55 | b_sz = zm_k[0].size(0) 56 | log_m_k = [] 57 | log_s_k = [torch.zeros(b_sz, 1, img_size, img_size)] 58 | other_k = [] 59 | # TODO(martin): parallelise decoding 60 | for zm in zm_k: 61 | core_out = self.core.decode(zm) 62 | # Take first channel as logits for masks 63 | a_logits = core_out[:, :1, :, :] 64 | log_a = F.logsigmoid(a_logits) 65 | log_neg_a = F.logsigmoid(-a_logits) 66 | # Take rest of channels for other 67 | other_k.append(core_out[:, 1:, :, :]) 68 | # Compute mask. Note that old scope needs to be used!!
69 | log_m_k.append(log_s_k[-1] + log_a) 70 | # Update scope given attention 71 | log_s_k.append(log_s_k[-1] + log_neg_a) 72 | # Set mask equal to scope for last step 73 | log_m_k.append(log_s_k[-1]) 74 | return log_m_k, log_s_k, other_k 75 | 76 | 77 | class LatentSBP(SimpleSBP): 78 | 79 | def __init__(self, core): 80 | super(LatentSBP, self).__init__(core) 81 | self.lstm = nn.LSTM(core.z_size+256, 2*core.z_size) 82 | self.linear = nn.Linear(2*core.z_size, 2*core.z_size) 83 | 84 | def forward(self, x, steps_to_run): 85 | h = self.core.q_z_nn(x) 86 | bs = h.size(0) 87 | h = h.view(bs, -1) 88 | mean_0 = self.core.q_z_mean(h) 89 | var_0 = self.core.q_z_var(h) 90 | z, q_z = self.core.reparameterize(mean_0, var_0) 91 | z_k = [z] 92 | q_z_k = [q_z] 93 | state = None 94 | for step in range(1, steps_to_run): 95 | h_and_z = torch.cat([h, z_k[-1]], dim=1) 96 | lstm_out, state = self.lstm(h_and_z.view(1, bs, -1), state) 97 | linear_out = self.linear(lstm_out)[0, :, :] 98 | linear_out = torch.chunk(linear_out, 2, dim=1) 99 | mean_k = linear_out[0] 100 | var_k = B.to_var(linear_out[1]) 101 | z, q_z = self.core.reparameterize(mean_k, var_k) 102 | z_k.append(z) 103 | q_z_k.append(q_z) 104 | # Initialise lists to store tensors over K steps 105 | log_m_k = [] 106 | stats_k = [] 107 | # Set initial scope to all ones, so log scope is all zeros 108 | log_s_k = [torch.zeros_like(x)[:, :1, :, :]] 109 | # Run decoder in parallel 110 | z_batch = torch.cat(z_k, dim=0) 111 | core_out_batch = self.core.decode(z_batch) 112 | core_out = torch.chunk(core_out_batch, steps_to_run, dim=0) 113 | # Compute masks 114 | for step, (z, q_z, out) in enumerate(zip(z_k, q_z_k, core_out)): 115 | # Compute a_logits given input and current scope 116 | stats = AttrDict(x=out, mu=q_z.mean, sigma=q_z.scale, z=z) 117 | # Take first channel for masks 118 | a_logits = out[:, :1, :, :] 119 | log_a = F.logsigmoid(a_logits) 120 | log_neg_a = F.logsigmoid(-a_logits) 121 | # Compute mask. Note that old scope needs to be used!! 
122 | log_m_k.append(log_s_k[step] + log_a) 123 | # Update scope given attention 124 | log_s_k.append(log_s_k[step] + log_neg_a) 125 | # Track stats 126 | stats_k.append(stats) 127 | # Set mask equal to scope for last step 128 | log_m_k.append(log_s_k[-1]) 129 | # Convert list of dicts into dict of lists 130 | stats = AttrDict() 131 | for key in stats_k[0]: 132 | stats[key+'_k'] = [s[key] for s in stats_k] 133 | return log_m_k, log_s_k, stats 134 | 135 | 136 | class InstanceColouringSBP(nn.Module): 137 | 138 | def __init__(self, img_size, kernel='gaussian', 139 | colour_dim=8, K_steps=None, feat_dim=None, 140 | semiconv=True): 141 | super(InstanceColouringSBP, self).__init__() 142 | # Config 143 | self.img_size = img_size 144 | self.kernel = kernel 145 | self.colour_dim = colour_dim 146 | # Initialise kernel sigma 147 | if self.kernel == 'laplacian': 148 | sigma_init = 1.0 / (np.sqrt(K_steps)*np.log(2)) 149 | elif self.kernel == 'gaussian': 150 | sigma_init = 1.0 / (K_steps*np.log(2)) 151 | elif self.kernel == 'epanechnikov': 152 | sigma_init = 2.0 / K_steps 153 | else: 154 | raise ValueError("No valid kernel.") 155 | self.log_sigma = nn.Parameter(torch.tensor(sigma_init).log()) 156 | # Colour head 157 | if semiconv: 158 | self.colour_head = B.SemiConv(feat_dim, self.colour_dim, img_size) 159 | else: 160 | self.colour_head = nn.Conv2d(feat_dim, self.colour_dim, 1) 161 | 162 | def forward(self, features, steps_to_run, debug=False, 163 | dynamic_K=False, *args, **kwargs): 164 | batch_size = features.size(0) 165 | stats = AttrDict() 166 | if isinstance(features, tuple): 167 | features = features[0] 168 | if dynamic_K: 169 | assert batch_size == 1 170 | # Get colours 171 | colour_out = self.colour_head(features) 172 | if isinstance(colour_out, tuple): 173 | colour, delta = colour_out 174 | else: 175 | colour, delta = colour_out, None 176 | # Sample from uniform to select random pixels as seeds 177 | rand_pixel = torch.empty(batch_size, 1, *colour.shape[2:]) 178 | rand_pixel = rand_pixel.uniform_() 179 | # Run SBP 180 | seed_list = [] 181 | log_m_k = [] 182 | log_s_k = [torch.zeros(batch_size, 1, self.img_size, self.img_size)] 183 | for step in range(steps_to_run): 184 | # Determine seed 185 | scope = F.interpolate(log_s_k[step].exp(), size=colour.shape[2:], 186 | mode='bilinear', align_corners=False) 187 | pixel_probs = rand_pixel * scope 188 | rand_max = pixel_probs.flatten(2).argmax(2).flatten() 189 | # TODO(martin): parallelise this 190 | seed = torch.empty((batch_size, self.colour_dim)) 191 | for bidx in range(batch_size): 192 | seed[bidx, :] = colour.flatten(2)[bidx, :, rand_max[bidx]] 193 | seed_list.append(seed) 194 | # Compute masks 195 | if self.kernel == 'laplacian': 196 | distance = B.euclidian_distance(colour, seed) 197 | alpha = torch.exp(- distance / self.log_sigma.exp()) 198 | elif self.kernel == 'gaussian': 199 | distance = B.squared_distance(colour, seed) 200 | alpha = torch.exp(- distance / self.log_sigma.exp()) 201 | elif self.kernel == 'epanechnikov': 202 | distance = B.squared_distance(colour, seed) 203 | alpha = (1 - distance / self.log_sigma.exp()).relu() 204 | else: 205 | raise ValueError("No valid kernel.") 206 | alpha = alpha.unsqueeze(1) 207 | # Sanity checks 208 | if debug: 209 | assert alpha.max() <= 1, alpha.max() 210 | assert alpha.min() >= 0, alpha.min() 211 | # Clamp mask values to [0.01, 0.99] for numerical stability 212 | # TODO(martin): clamp less aggressively?
213 | alpha = B.clamp_preserve_gradients(alpha, 0.01, 0.99) 214 | # SBP update 215 | log_a = torch.log(alpha) 216 | log_neg_a = torch.log(1 - alpha) 217 | log_m = log_s_k[step] + log_a 218 | if dynamic_K and log_m.exp().sum() < 20: 219 | break 220 | log_m_k.append(log_m) 221 | log_s_k.append(log_s_k[step] + log_neg_a) 222 | # Set mask equal to scope for last step 223 | log_m_k.append(log_s_k[-1]) 224 | # Accumulate stats 225 | stats.update({'colour': colour, 'delta': delta, 'seeds': seed_list}) 226 | return log_m_k, log_s_k, stats 227 | -------------------------------------------------------------------------------- /modules/blocks.py: -------------------------------------------------------------------------------- 1 | # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | def clamp_preserve_gradients(x, lower, upper): 19 | # From: http://docs.pyro.ai/en/0.3.3/_modules/pyro/distributions/iaf.html 20 | return x + (x.clamp(lower, upper) - x).detach() 21 | 22 | def to_sigma(x): 23 | return F.softplus(x + 0.5) + 1e-8 24 | 25 | def to_var(x): 26 | return to_sigma(x)**2 27 | 28 | def to_prior_sigma(x, simgoid_bias=4.0, eps=1e-4): 29 | """ 30 | This parameterisation bounds sigma of a learned prior to [eps, 1+eps]. 31 | The default sigmoid_bias of 4.0 initialises sigma to be close to 1.0. 32 | The default eps prevents instability as sigma -> 0. 
33 | """ 34 | return torch.sigmoid(x + simgoid_bias) + eps 35 | 36 | def flatten(x): 37 | return x.view(x.size(0), -1) 38 | 39 | def unflatten(x): 40 | return x.view(x.size(0), -1, 1, 1) 41 | 42 | def pixel_coords(img_size): 43 | g_1, g_2 = torch.meshgrid(torch.linspace(-1, 1, img_size), 44 | torch.linspace(-1, 1, img_size)) 45 | g_1 = g_1.view(1, 1, img_size, img_size) 46 | g_2 = g_2.view(1, 1, img_size, img_size) 47 | return torch.cat((g_1, g_2), dim=1) 48 | 49 | def euclidian_norm(x): 50 | # Clamp before taking sqrt for numerical stability 51 | return clamp_preserve_gradients((x**2).sum(1), 1e-10, 1e10).sqrt() 52 | 53 | def euclidian_distance(embedA, embedB): 54 | # Unflatten if needed if one is an image and the other a vector 55 | # Assumes inputs are batches 56 | if embedA.dim() == 4 or embedB.dim() == 4: 57 | if embedA.dim() == 2: 58 | embedA = unflatten(embedA) 59 | if embedB.dim() == 2: 60 | embedB = unflatten(embedB) 61 | return euclidian_norm(embedA - embedB) 62 | 63 | def squared_distance(embedA, embedB): 64 | # Unflatten if needed if one is an image and the other a vector 65 | # Assumes inputs are batches 66 | if embedA.dim() == 4 or embedB.dim() == 4: 67 | if embedA.dim() == 2: 68 | embedA = unflatten(embedA) 69 | if embedB.dim() == 2: 70 | embedB = unflatten(embedB) 71 | return ((embedA - embedB)**2).sum(1) 72 | 73 | class ToSigma(nn.Module): 74 | def __init__(self): 75 | super(ToSigma, self).__init__() 76 | def forward(self, x): 77 | return to_sigma(x) 78 | 79 | class ToVar(nn.Module): 80 | def __init__(self): 81 | super(ToVar, self).__init__() 82 | def forward(self, x): 83 | return to_var(x) 84 | 85 | class ScalarGate(nn.Module): 86 | def __init__(self, init=0.0): 87 | super(ScalarGate, self).__init__() 88 | self.gate = nn.Parameter(torch.tensor(init)) 89 | def forward(self, x): 90 | return self.gate * x 91 | 92 | class Flatten(nn.Module): 93 | def __init__(self): 94 | super(Flatten, self).__init__() 95 | def forward(self, x): 96 | return x.view(x.size(0), -1) 97 | 98 | class UnFlatten(nn.Module): 99 | def __init__(self): 100 | super(UnFlatten, self).__init__() 101 | def forward(self, x): 102 | return x.view(x.size(0), -1, 1, 1) 103 | 104 | class BroadcastLayer(nn.Module): 105 | def __init__(self, dim): 106 | super(BroadcastLayer, self).__init__() 107 | self.dim = dim 108 | self.pixel_coords = PixelCoords(dim) 109 | def forward(self, x): 110 | b_sz = x.size(0) 111 | # Broadcast 112 | if x.dim() == 2: 113 | x = x.view(b_sz, -1, 1, 1) 114 | x = x.expand(-1, -1, self.dim, self.dim) 115 | else: 116 | x = F.interpolate(x, self.dim) 117 | return self.pixel_coords(x) 118 | 119 | class PixelCoords(nn.Module): 120 | def __init__(self, im_dim): 121 | super(PixelCoords, self).__init__() 122 | # TODO(martin): avoid duplication 123 | g_1, g_2 = torch.meshgrid(torch.linspace(-1, 1, im_dim), 124 | torch.linspace(-1, 1, im_dim)) 125 | self.g_1 = g_1.view((1, 1) + g_1.shape) 126 | self.g_2 = g_2.view((1, 1) + g_2.shape) 127 | def forward(self, x): 128 | g_1 = self.g_1.expand(x.size(0), -1, -1, -1) 129 | g_2 = self.g_2.expand(x.size(0), -1, -1, -1) 130 | return torch.cat((x, g_1, g_2), dim=1) 131 | 132 | class Interpolate(nn.Module): 133 | def __init__(self, size=None, scale_factor=None, mode='nearest', 134 | align_corners=None): 135 | super(Interpolate, self).__init__() 136 | self.size = size 137 | self.scale_factor = scale_factor 138 | self.mode = mode 139 | self.align_corners = align_corners 140 | def forward(self, x): 141 | return F.interpolate(x, size=self.size, 
scale_factor=self.scale_factor, 142 | mode=self.mode, align_corners=self.align_corners) 143 | 144 | class ConvReLU(nn.Sequential): 145 | def __init__(self, nin, nout, kernel, stride=1, padding=0): 146 | super(ConvReLU, self).__init__( 147 | nn.Conv2d(nin, nout, kernel, stride, padding), 148 | nn.ReLU(inplace=True) 149 | ) 150 | 151 | class ConvINReLU(nn.Sequential): 152 | def __init__(self, nin, nout, kernel, stride=1, padding=0): 153 | super(ConvINReLU, self).__init__( 154 | nn.Conv2d(nin, nout, kernel, stride, padding, bias=False), 155 | nn.InstanceNorm2d(nout, affine=True), 156 | nn.ReLU(inplace=True) 157 | ) 158 | 159 | class ConvGNReLU(nn.Sequential): 160 | def __init__(self, nin, nout, kernel, stride=1, padding=0, groups=8): 161 | super(ConvGNReLU, self).__init__( 162 | nn.Conv2d(nin, nout, kernel, stride, padding, bias=False), 163 | nn.GroupNorm(groups, nout), 164 | nn.ReLU(inplace=True) 165 | ) 166 | 167 | class SemiConv(nn.Module): 168 | def __init__(self, nin, nout, img_size): 169 | super(SemiConv, self).__init__() 170 | self.conv = nn.Conv2d(nin, nout, 1) 171 | self.gate = ScalarGate() 172 | coords = pixel_coords(img_size) 173 | zeros = torch.zeros(1, nout-2, img_size, img_size) 174 | self.uv = torch.cat((zeros, coords), dim=1) 175 | def forward(self, x): 176 | out = self.gate(self.conv(x)) 177 | delta = out[:, -2:, :, :] 178 | return out + self.uv.to(out.device), delta 179 | -------------------------------------------------------------------------------- /modules/component_vae.py: -------------------------------------------------------------------------------- 1 | # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 
11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | from attrdict import AttrDict 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | 20 | from torch.distributions.normal import Normal 21 | 22 | import modules.blocks as B 23 | from modules.encoders import MONetCompEncoder 24 | from modules.decoders import BroadcastDecoder 25 | 26 | 27 | class ComponentVAE(nn.Module): 28 | 29 | def __init__(self, nout, cfg, act): 30 | super(ComponentVAE, self).__init__() 31 | self.ldim = cfg.comp_ldim # paper uses 16 32 | self.montecarlo = cfg.montecarlo_kl 33 | self.pixel_bound = cfg.pixel_bound 34 | # Sub-Modules 35 | self.encoder_module = MONetCompEncoder(cfg=cfg, act=act) 36 | self.decoder_module = BroadcastDecoder( 37 | in_chnls=self.ldim, 38 | out_chnls=nout, 39 | h_chnls=cfg.comp_dec_channels, 40 | num_layers=cfg.comp_dec_layers, 41 | img_dim=cfg.img_size, 42 | act=act 43 | ) 44 | 45 | def forward(self, x, log_mask): 46 | """ 47 | Args: 48 | x (torch.Tensor): Input to reconstruct [batch size, 3, dim, dim] 49 | log_mask (torch.Tensor or list of torch.Tensors): 50 | Mask to reconstruct [batch size, 1, dim, dim] 51 | """ 52 | # -- Check if inputs are lists 53 | K = 1 54 | b_sz = x.size(0) 55 | if isinstance(log_mask, list) or isinstance(log_mask, tuple): 56 | K = len(log_mask) 57 | # Repeat x along batch dimension 58 | x = x.repeat(K, 1, 1, 1) 59 | # Concat log_m_k along batch dimension 60 | log_mask = torch.cat(log_mask, dim=0) 61 | 62 | # -- Encode 63 | x = torch.cat((log_mask, x), dim=1) # Concat along feature dimension 64 | mu, sigma = self.encode(x) 65 | 66 | # -- Sample latents 67 | q_z = Normal(mu, sigma) 68 | # z - [batch_size * K, l_dim] with first axis: b0,k0 -> b0,k1 -> ... 69 | z = q_z.rsample() 70 | 71 | # -- Decode 72 | # x_r, m_r_logits = self.decode(z) 73 | x_r = self.decode(z) 74 | 75 | # -- Track quantities of interest and return 76 | x_r_k = torch.chunk(x_r, K, dim=0) 77 | z_k = torch.chunk(z, K, dim=0) 78 | mu_k = torch.chunk(mu, K, dim=0) 79 | sigma_k = torch.chunk(sigma, K, dim=0) 80 | stats = AttrDict(mu_k=mu_k, sigma_k=sigma_k, z_k=z_k) 81 | return x_r_k, stats 82 | 83 | def encode(self, x): 84 | x = self.encoder_module(x) 85 | mu, sigma_ps = torch.chunk(x, 2, dim=1) 86 | sigma = B.to_sigma(sigma_ps) 87 | return mu, sigma 88 | 89 | def decode(self, z): 90 | x_hat = self.decoder_module(z) 91 | if self.pixel_bound: 92 | x_hat = torch.sigmoid(x_hat) 93 | return x_hat 94 | 95 | def sample(self, batch_size=1, steps=1): 96 | raise NotImplementedError 97 | -------------------------------------------------------------------------------- /modules/decoders.py: -------------------------------------------------------------------------------- 1 | # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 
11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | import modules.blocks as B 19 | 20 | 21 | class BroadcastDecoder(nn.Module): 22 | 23 | def __init__(self, in_chnls, out_chnls, h_chnls, num_layers, img_dim, act): 24 | super(BroadcastDecoder, self).__init__() 25 | broad_dim = img_dim + 2*num_layers 26 | mods = [B.BroadcastLayer(broad_dim), 27 | nn.Conv2d(in_chnls+2, h_chnls, 3), 28 | act] 29 | for _ in range(num_layers - 1): 30 | mods.extend([nn.Conv2d(h_chnls, h_chnls, 3), act]) 31 | mods.append(nn.Conv2d(h_chnls, out_chnls, 1)) 32 | self.seq = nn.Sequential(*mods) 33 | 34 | def forward(self, x): 35 | return self.seq(x) 36 | -------------------------------------------------------------------------------- /modules/encoders.py: -------------------------------------------------------------------------------- 1 | # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | import torch 15 | import torch.nn as nn 16 | from torch.nn import Sequential as Seq 17 | import torch.nn.functional as F 18 | 19 | import modules.blocks as B 20 | 21 | 22 | class MONetCompEncoder(nn.Module): 23 | 24 | def __init__(self, cfg, act): 25 | super(MONetCompEncoder, self).__init__() 26 | nin = cfg.input_channels if hasattr(cfg, 'input_channels') else 3 27 | c = cfg.comp_enc_channels 28 | self.ldim = cfg.comp_ldim 29 | nin_mlp = 2*c * (cfg.img_size//16)**2 30 | nhid_mlp = max(256, 2*self.ldim) 31 | self.module = Seq(nn.Conv2d(nin+1, c, 3, 2, 1), act, 32 | nn.Conv2d(c, c, 3, 2, 1), act, 33 | nn.Conv2d(c, 2*c, 3, 2, 1), act, 34 | nn.Conv2d(2*c, 2*c, 3, 2, 1), act, 35 | B.Flatten(), 36 | nn.Linear(nin_mlp, nhid_mlp), act, 37 | nn.Linear(nhid_mlp, 2*self.ldim)) 38 | 39 | def forward(self, x): 40 | return self.module(x) 41 | -------------------------------------------------------------------------------- /modules/unet.py: -------------------------------------------------------------------------------- 1 | # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 
11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | import modules.blocks as B 19 | 20 | 21 | class UNet(nn.Module): 22 | 23 | def __init__(self, num_blocks, img_size=64, 24 | filter_start=32, in_chnls=4, out_chnls=1, 25 | norm='in'): 26 | super(UNet, self).__init__() 27 | # TODO(martin): make more general 28 | c = filter_start 29 | if norm == 'in': 30 | conv_block = B.ConvINReLU 31 | elif norm == 'gn': 32 | conv_block = B.ConvGNReLU 33 | else: 34 | conv_block = B.ConvReLU 35 | if num_blocks == 4: 36 | enc_in = [in_chnls, c, 2*c, 2*c] 37 | enc_out = [c, 2*c, 2*c, 2*c] 38 | dec_in = [4*c, 4*c, 4*c, 2*c] 39 | dec_out = [2*c, 2*c, c, c] 40 | elif num_blocks == 5: 41 | enc_in = [in_chnls, c, c, 2*c, 2*c] 42 | enc_out = [c, c, 2*c, 2*c, 2*c] 43 | dec_in = [4*c, 4*c, 4*c, 2*c, 2*c] 44 | dec_out = [2*c, 2*c, c, c, c] 45 | elif num_blocks == 6: 46 | enc_in = [in_chnls, c, c, c, 2*c, 2*c] 47 | enc_out = [c, c, c, 2*c, 2*c, 2*c] 48 | dec_in = [4*c, 4*c, 4*c, 2*c, 2*c, 2*c] 49 | dec_out = [2*c, 2*c, c, c, c, c] 50 | self.down = [] 51 | self.up = [] 52 | # 3x3 kernels, stride 1, padding 1 53 | for i, o in zip(enc_in, enc_out): 54 | self.down.append(conv_block(i, o, 3, 1, 1)) 55 | for i, o in zip(dec_in, dec_out): 56 | self.up.append(conv_block(i, o, 3, 1, 1)) 57 | self.down = nn.ModuleList(self.down) 58 | self.up = nn.ModuleList(self.up) 59 | self.featuremap_size = img_size // 2**(num_blocks-1) 60 | self.mlp = nn.Sequential( 61 | B.Flatten(), 62 | nn.Linear(2*c*self.featuremap_size**2, 128), nn.ReLU(), 63 | nn.Linear(128, 128), nn.ReLU(), 64 | nn.Linear(128, 2*c*self.featuremap_size**2), nn.ReLU() 65 | ) 66 | self.final_conv = nn.Conv2d(c, out_chnls, 1) 67 | self.out_chnls = out_chnls 68 | 69 | def forward(self, x): 70 | batch_size = x.size(0) 71 | x_down = [x] 72 | skip = [] 73 | # Down 74 | for i, block in enumerate(self.down): 75 | act = block(x_down[-1]) 76 | skip.append(act) 77 | if i < len(self.down)-1: 78 | act = F.interpolate(act, scale_factor=0.5, mode='nearest') 79 | x_down.append(act) 80 | # FC 81 | x_up = self.mlp(x_down[-1]) 82 | x_up = x_up.view(batch_size, -1, 83 | self.featuremap_size, self.featuremap_size) 84 | # Up 85 | for i, block in enumerate(self.up): 86 | features = torch.cat([x_up, skip[-1 - i]], dim=1) 87 | x_up = block(features) 88 | if i < len(self.up)-1: 89 | x_up = F.interpolate(x_up, scale_factor=2.0, mode='nearest') 90 | return self.final_conv(x_up), None 91 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/applied-ai-lab/genesis/9abf202bbad6fa4a675117fdea0be163e4f16695/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/compute_fid.py: -------------------------------------------------------------------------------- 1 | # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 
8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | import os 15 | from os import path as osp 16 | import random 17 | from attrdict import AttrDict 18 | from tqdm import tqdm 19 | 20 | import torch 21 | import numpy as np 22 | from PIL import Image 23 | 24 | import forge 25 | from forge import flags 26 | import forge.experiment_tools as fet 27 | from forge.experiment_tools import fprint 28 | 29 | from third_party.pytorch_fid import fid_score as FID 30 | 31 | 32 | def main_flags(): 33 | # Data & model config 34 | flags.DEFINE_string('data_config', 'datasets/gqn_config.py', 35 | 'Path to a data config file.') 36 | flags.DEFINE_string('model_config', 'models/genesis_config.py', 37 | 'Path to a model config file.') 38 | # Trained model 39 | flags.DEFINE_string('model_dir', 'checkpoints/test/1', 40 | 'Path to model directory.') 41 | flags.DEFINE_string('model_file', 'model.ckpt-FINAL', 'Name of model file.') 42 | # FID 43 | flags.DEFINE_integer('feat_dim', 2048, 'Number of Inception features.') 44 | flags.DEFINE_integer('num_fid_images', 10000, 45 | 'Number of images to compute the FID on.') 46 | # Other 47 | flags.DEFINE_string('img_dir', '/tmp', 'Directory for saving pngs.') 48 | flags.DEFINE_integer('batch_size', 10, 'Mini-batch size.') 49 | flags.DEFINE_boolean('gpu', True, 'Use GPU if available.') 50 | flags.DEFINE_integer('seed', 0, 'Seed for random number generators.') 51 | 52 | 53 | def main(): 54 | # Parse flags 55 | config = forge.config() 56 | fet.EXPERIMENT_FOLDER = config.model_dir 57 | fet.FPRINT_FILE = 'fid_evaluation.txt' 58 | config.shuffle_test = True 59 | 60 | # Fix seeds. Always first thing to be done after parsing the config! 61 | torch.manual_seed(config.seed) 62 | np.random.seed(config.seed) 63 | random.seed(config.seed) 64 | # Make CUDA operations deterministic 65 | torch.backends.cudnn.deterministic = True 66 | torch.backends.cudnn.benchmark = False 67 | 68 | # Using GPU?
69 | if torch.cuda.is_available() and config.gpu: 70 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 71 | else: 72 | config.gpu = False 73 | torch.set_default_tensor_type('torch.FloatTensor') 74 | fet.print_flags() 75 | 76 | # Load data 77 | _, _, test_loader = fet.load(config.data_config, config) 78 | 79 | # Load model 80 | flag_path = osp.join(config.model_dir, 'flags.json') 81 | fprint(f"Restoring flags from {flag_path}") 82 | pretrained_flags = AttrDict(fet.json_load(flag_path)) 83 | model = fet.load(config.model_config, pretrained_flags) 84 | model_path = osp.join(config.model_dir, config.model_file) 85 | fprint(f"Restoring model from {model_path}") 86 | checkpoint = torch.load(model_path, map_location='cpu') 87 | model_state_dict = checkpoint['model_state_dict'] 88 | model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1', None) 89 | model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2', None) 90 | model.load_state_dict(model_state_dict) 91 | fprint(model) 92 | # Put model on GPU 93 | if config.gpu: 94 | model = model.cuda() 95 | 96 | # Compute FID 97 | fid_from_model(model, test_loader, config.batch_size, 98 | config.num_fid_images, config.feat_dim, config.img_dir) 99 | 100 | 101 | def fid_from_model(model, test_loader, batch_size=10, num_images=10000, 102 | feat_dim=2048, img_dir='/tmp'): 103 | 104 | model.eval() 105 | 106 | # Save images from test set as pngs 107 | fprint("Saving images from test set as pngs.", True) 108 | test_dir = osp.join(img_dir, 'test_images') 109 | os.makedirs(test_dir) 110 | count = 0 111 | for bidx, batch in enumerate(test_loader): 112 | count = tensor_to_png(batch['input'], test_dir, count, num_images) 113 | if count >= num_images: 114 | break 115 | 116 | # Generate images and save as pngs 117 | fprint("Generate images and save as pngs.", True) 118 | gen_dir = osp.join(img_dir, 'generated_images') 119 | os.makedirs(gen_dir) 120 | count = 0 121 | for _ in tqdm(range(num_images // batch_size + 1)): 122 | if count >= num_images: 123 | break 124 | with torch.no_grad(): 125 | gen_img, _ = model.sample(batch_size) 126 | count = tensor_to_png(gen_img, gen_dir, count, num_images) 127 | 128 | # Compute FID 129 | fprint("Computing FID.", True) 130 | gpu = next(model.parameters()).is_cuda 131 | fid_value = FID.calculate_fid_given_paths( 132 | [test_dir, gen_dir], batch_size, gpu, feat_dim) 133 | fprint(f"FID: {fid_value}", True) 134 | 135 | model.train() 136 | 137 | return fid_value 138 | 139 | 140 | def tensor_to_png(tensor, save_dir, count, stop): 141 | np_images = tensor.cpu().numpy() 142 | np_images = np.moveaxis(np_images, 1, 3) 143 | for i in range(len(np_images)): 144 | im = Image.fromarray(np.uint8(255*np_images[i])) 145 | fn = osp.join(save_dir, str(count).zfill(6) + '.png') 146 | im.save(fn) 147 | count += 1 148 | if count >= stop: 149 | return count 150 | return count 151 | 152 | 153 | if __name__ == "__main__": 154 | main_flags() 155 | main() 156 | -------------------------------------------------------------------------------- /scripts/compute_seg_metrics.py: -------------------------------------------------------------------------------- 1 | # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 
8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | import os 15 | from attrdict import AttrDict 16 | from tqdm import tqdm 17 | import random 18 | 19 | import torch 20 | 21 | import numpy as np 22 | 23 | import forge 24 | from forge import flags 25 | import forge.experiment_tools as fet 26 | from forge.experiment_tools import fprint 27 | 28 | from utils.misc import average_ari, average_segcover 29 | 30 | 31 | # Config 32 | flags.DEFINE_string('data_config', 'datasets/shapestacks_config.py', 33 | 'Path to a data config file.') 34 | flags.DEFINE_string('model_config', 'models/genesis_config.py', 35 | 'Path to a model config file.') 36 | # Trained model 37 | flags.DEFINE_string('model_dir', 'checkpoints/test/1', 38 | 'Path to model directory.') 39 | flags.DEFINE_string('model_file', 'model.ckpt-FINAL', 'Name of model file.') 40 | # Other 41 | flags.DEFINE_integer('seed', 0, 'Seed for random number generators.') 42 | flags.DEFINE_integer('num_images', 320, 'Number of images to run on.') 43 | flags.DEFINE_string('split', 'test', '{train, val, test}') 44 | 45 | # Set manual seed 46 | torch.manual_seed(0) 47 | np.random.seed(0) 48 | random.seed(0) 49 | # Make CUDA operations deterministic 50 | torch.backends.cudnn.deterministic = True 51 | torch.backends.cudnn.benchmark = False 52 | 53 | 54 | def main(): 55 | # Parse flags 56 | config = forge.config() 57 | config.batch_size = 1 58 | config.load_instances = True 59 | config.debug = False 60 | fet.print_flags() 61 | 62 | # Restore original model flags 63 | pretrained_flags = AttrDict( 64 | fet.json_load(os.path.join(config.model_dir, 'flags.json'))) 65 | 66 | # Get validation loader 67 | train_loader, val_loader, test_loader = fet.load(config.data_config, config) 68 | fprint(f"Split: {config.split}") 69 | if config.split == 'train': 70 | batch_loader = train_loader 71 | elif config.split == 'val': 72 | batch_loader = val_loader 73 | elif config.split == 'test': 74 | batch_loader = test_loader 75 | # Shuffle and prefetch to get same data for different models 76 | if 'gqn' not in config.data_config: 77 | batch_loader = torch.utils.data.DataLoader( 78 | batch_loader.dataset, batch_size=1, num_workers=0, shuffle=True) 79 | # Prefetch batches 80 | prefetched_batches = [] 81 | for i, x in enumerate(batch_loader): 82 | if i == config.num_images: 83 | break 84 | prefetched_batches.append(x) 85 | 86 | # Load model 87 | model = fet.load(config.model_config, pretrained_flags) 88 | fprint(model) 89 | model_path = os.path.join(config.model_dir, config.model_file) 90 | fprint(f"Restoring model from {model_path}") 91 | checkpoint = torch.load(model_path, map_location='cpu') 92 | model_state_dict = checkpoint['model_state_dict'] 93 | model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1', None) 94 | model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2', None) 95 | model.load_state_dict(model_state_dict) 96 | 97 | # Set experiment folder and fprint file for logging 98 | fet.EXPERIMENT_FOLDER = config.model_dir 99 | fet.FPRINT_FILE = 'segmentation_metrics.txt' 100 | 101 | # Compute metrics 102 | model.eval() 103 | ari_fg_list, msc_fg_list, ari_fg_r_list, msc_fg_r_list = [], [], [], [] 104 | with torch.no_grad(): 105 | for i, x in 
enumerate(tqdm(prefetched_batches)): 106 | _, _, stats, _, _ = model(x['input']) 107 | for mode in ['log_m_k', 'log_m_r_k']: 108 | if mode in stats: 109 | log_masks = stats[mode] 110 | else: 111 | continue 112 | # ARI 113 | ari_fg, _ = average_ari(log_masks, x['instances'], 114 | foreground_only=True) 115 | # Segmentation covering - foreground only 116 | ins_seg = torch.argmax(torch.cat(log_masks, 1), 1, True) 117 | msc_fg, _ = average_segcover(x['instances'], ins_seg, True) 118 | # Recording 119 | if mode == 'log_m_k': 120 | ari_fg_list.append(ari_fg) 121 | msc_fg_list.append(msc_fg) 122 | elif mode == 'log_m_r_k': 123 | ari_fg_r_list.append(ari_fg) 124 | msc_fg_r_list.append(msc_fg) 125 | 126 | # Print average metrics 127 | fprint(f"Average FG ARI: {sum(ari_fg_list)/len(ari_fg_list)}") 128 | fprint(f"Average FG MSC: {sum(msc_fg_list)/len(msc_fg_list)}") 129 | if len(ari_fg_r_list) > 0 and len(msc_fg_r_list) > 0: 130 | fprint(f"Average FG-R ARI: {sum(ari_fg_r_list)/len(ari_fg_r_list)}") 131 | fprint(f"Average FG-R MSC: {sum(msc_fg_r_list)/len(msc_fg_r_list)}") 132 | 133 | 134 | if __name__ == "__main__": 135 | main() 136 | -------------------------------------------------------------------------------- /scripts/generate_multid.py: -------------------------------------------------------------------------------- 1 | # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 
11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | import random 15 | from random import randint, choice 16 | 17 | import torch 18 | 19 | import numpy as np 20 | from PIL import Image 21 | 22 | 23 | # Set manual seed 24 | torch.manual_seed(0) 25 | np.random.seed(0) 26 | random.seed(0) 27 | # Make CUDA operations deterministic 28 | torch.backends.cudnn.deterministic = True 29 | torch.backends.cudnn.benchmark = False 30 | 31 | 32 | def rand_rgb_tuple(): 33 | val = [0, 63, 127, 191, 255] 34 | return choice(val), choice(val), choice(val) 35 | 36 | 37 | def generate(sprites, dataset_size, num_objects=None, unique=False): 38 | # Initialise 39 | all_images = np.zeros((dataset_size, 64, 64, 3)) 40 | all_instance_masks = np.zeros((dataset_size, 64, 64, 1)) 41 | 42 | # Create images 43 | for i in range(dataset_size): 44 | if (i+1)%10000 == 0: 45 | print(f"Processing [{i+1} | {dataset_size}]") 46 | 47 | # Create background 48 | background_colour = rand_rgb_tuple() 49 | image = np.array(Image.new('RGB', (64, 64), background_colour)) 50 | # Initialise instance masks 51 | instance_masks = np.zeros((64, 64, 1)).astype('int') 52 | 53 | img_colours = [background_colour] 54 | 55 | # Add objects 56 | if num_objects is None: 57 | num_sprites = randint(1, 4) 58 | else: 59 | num_sprites = num_objects 60 | for obj_idx in range(num_sprites): 61 | object_index = randint(0, 737279) 62 | sprite_mask = np.array(sprites[object_index], dtype=bool) 63 | crop_index = np.where(sprite_mask == True) 64 | object_colour = rand_rgb_tuple() 65 | # Optional: get new random colour if colour has already been used 66 | while unique and object_colour in img_colours: 67 | object_colour = rand_rgb_tuple() 68 | image[crop_index] = object_colour 69 | instance_masks[crop_index] = obj_idx + 1 70 | img_colours.append(object_colour) 71 | # Collate 72 | all_images[i] = image 73 | all_instance_masks[i] = instance_masks 74 | 75 | all_images = all_images.astype('float32') / 255.0 76 | return all_images, all_instance_masks 77 | 78 | 79 | def main(): 80 | # Load dataset 81 | dataset_zip = np.load( 82 | 'data/multi_dsprites/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', 83 | encoding="latin1") 84 | sprites = dataset_zip['imgs'] 85 | 86 | # --- Random colours --- 87 | # Generate training data 88 | print("Generate training images...") 89 | train_images, train_masks = generate(sprites, 50000) 90 | print("Saving...") 91 | np.save("data/multi_dsprites/processed/training_images_rand4.npy", 92 | train_images) 93 | np.save("data/multi_dsprites/processed/training_masks_rand4.npy", 94 | train_masks) 95 | # Generate validation data 96 | print("Generate validation images...") 97 | val_images, val_masks = generate(sprites, 10000) 98 | print("Saving...") 99 | np.save("data/multi_dsprites/processed/validation_images_rand4.npy", 100 | val_images) 101 | np.save("data/multi_dsprites/processed/validation_masks_rand4.npy", 102 | val_masks) 103 | # Generate test data 104 | print("Generate test images...") 105 | test_images, test_masks = generate(sprites, 10000) 106 | print("Saving...") 107 | np.save("data/multi_dsprites/processed/test_images_rand4.npy", 108 | test_images) 109 | np.save("data/multi_dsprites/processed/test_masks_rand4.npy", 110 | test_masks) 111 | print("Done!") 112 | 113 | # --- Unique random colours --- 114 | # Generate training data 115 | print("Generate training images...") 116 | train_images, train_masks = generate(sprites, 50000, unique=True) 117 | print("Saving...") 118 | 
np.save("data/multi_dsprites/processed/training_images_rand4_unique.npy", 119 | train_images) 120 | np.save("data/multi_dsprites/processed/training_masks_rand4_unique.npy", 121 | train_masks) 122 | # Generate validation data 123 | print("Generate validation images...") 124 | val_images, val_masks = generate(sprites, 10000, unique=True) 125 | print("Saving...") 126 | np.save("data/multi_dsprites/processed/validation_images_rand4_unique.npy", 127 | val_images) 128 | np.save("data/multi_dsprites/processed/validation_masks_rand4_unique.npy", 129 | val_masks) 130 | # Generate test data 131 | print("Generate test images...") 132 | test_images, test_masks = generate(sprites, 10000, unique=True) 133 | print("Saving...") 134 | np.save("data/multi_dsprites/processed/test_images_rand4_unique.npy", 135 | test_images) 136 | np.save("data/multi_dsprites/processed/test_masks_rand4_unique.npy", 137 | test_masks) 138 | print("Done!") 139 | 140 | 141 | if __name__ == "__main__": 142 | main() 143 | -------------------------------------------------------------------------------- /scripts/sketchy_preparation.py: -------------------------------------------------------------------------------- 1 | # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | import os 15 | from glob import glob 16 | from tqdm import tqdm 17 | from PIL import Image 18 | 19 | import torch 20 | from sketchy import sketchy 21 | 22 | data_folder = 'data/sketchy' 23 | filenames = sorted(glob(f'{data_folder}/records/*')) 24 | 25 | # Split into train/valid/test files 26 | num_files = len(filenames) 27 | num_eval = num_files//10 28 | valid_files = filenames[:num_eval] 29 | test_files = filenames[num_eval:2*num_eval] 30 | train_files = filenames[2*num_eval:] 31 | 32 | thumbnail_size = (128, 128) 33 | 34 | # Check for dublicates 35 | all_files = train_files + valid_files + test_files 36 | assert not len(all_files) != len(set(all_files)) 37 | 38 | episode_idx = 0 39 | for mode, files in zip(['train', 'valid', 'test'], [train_files, valid_files, test_files]): 40 | save_folder = f'{data_folder}/processed/{mode}' 41 | print(f'Processing {mode} data. 
Destination: {save_folder}') 42 | os.makedirs(save_folder) 43 | for episode_file in tqdm(files): 44 | episode = sketchy.load_frames(episode_file, 4) 45 | episode_folder = f'{save_folder}/ep{str(episode_idx).zfill(6)}' 46 | os.makedirs(episode_folder) 47 | prefix = f'{episode_folder}/ep{str(episode_idx).zfill(6)}' 48 | for ex_idx, frame in enumerate(episode): 49 | im_fl = frame['pixels/basket_front_left'].numpy() 50 | im_fr = frame['pixels/basket_front_right'].numpy() 51 | # Crop to 448x672 52 | im_fl = im_fl[71:-81, 144:-144] 53 | im_fr = im_fr[91:-61, 144:-144] 54 | assert im_fl.shape == im_fr.shape 55 | ss = im_fl.shape[0] # short side 56 | ls = im_fl.shape[1] # long side 57 | cs = ss-64-32 # crop size 58 | mc = int(ls//2 - cs//2) # middle crop location 59 | for im, view in zip([im_fl, im_fr], ['fl', 'fr']): 60 | # Save full image (448x448 crop) 61 | full = Image.fromarray(im[:, int(ls//2-ss//2):int(ls//2-ss//2)+ss]) 62 | full = full.resize(thumbnail_size, resample=Image.BILINEAR) 63 | full.save(f'{prefix}_t{str(ex_idx).zfill(3)}_{view}_full.png') 64 | # Save crops 65 | c = 0 66 | for x1, x2 in zip([0, -cs], [cs, ss+1]): 67 | for y1, y2 in zip([0, mc, -cs], [cs, mc+cs, ls+1]): 68 | crop = im[x1:x2, y1:y2, :] 69 | crop = Image.fromarray(crop) 70 | crop = crop.resize(thumbnail_size, resample=Image.BILINEAR) 71 | crop.save(f'{prefix}_t{str(ex_idx).zfill(3)}_{view}_c{c}.png') 72 | c += 1 73 | state = {} 74 | for key, val in frame.items(): 75 | if 'pixels' in key: 76 | continue 77 | state[key] = torch.tensor(val.numpy()) 78 | torch.save(state, f'{prefix}_t{str(ex_idx).zfill(3)}_state.pt') 79 | episode_idx += 1 80 | -------------------------------------------------------------------------------- /scripts/visualise_data.py: -------------------------------------------------------------------------------- 1 | # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 
11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | import json 15 | import numpy as np 16 | import matplotlib.pyplot as plt 17 | from matplotlib.colors import NoNorm 18 | 19 | import torch 20 | 21 | import forge 22 | from forge import flags 23 | import forge.experiment_tools as fet 24 | 25 | # Config 26 | flags.DEFINE_string('data_config', 'datasets/multid_config.py', 27 | 'Path to a data config file.') 28 | flags.DEFINE_integer('batch_size', 8, 'Mini-batch size.') 29 | flags.DEFINE_integer('seed', 0, 'Seed for random number generators.') 30 | 31 | 32 | def main(): 33 | # Parse flags 34 | cfg = forge.config() 35 | cfg.num_workers = 0 36 | 37 | # Set manual seed 38 | torch.manual_seed(cfg.seed) 39 | np.random.seed(cfg.seed) 40 | # Make CUDA operations deterministic 41 | torch.backends.cudnn.deterministic = True 42 | torch.backends.cudnn.benchmark = False 43 | 44 | # Get data loaders 45 | train_loader, _, _ = fet.load(cfg.data_config, cfg) 46 | 47 | # Optimally distinct RGB colour palette (15 colours) 48 | colours = json.load(open('utils/colour_palette15.json')) 49 | 50 | # Visualise 51 | for x in train_loader: 52 | fig, axes = plt.subplots(2, cfg.batch_size, figsize=(20,10)) 53 | 54 | for f_idx, field in enumerate(['input', 'instances']): 55 | for b_idx in range(cfg.batch_size): 56 | axes[f_idx, b_idx].axis('off') 57 | 58 | if field not in x: 59 | continue 60 | img = x[field] 61 | 62 | # Colour instance masks 63 | if field == 'instances': 64 | img_list = [] 65 | for b_idx in range(img.shape[0]): 66 | instances = img[b_idx, :, :, :] 67 | img_r = torch.zeros_like(instances) 68 | img_g = torch.zeros_like(instances) 69 | img_b = torch.zeros_like(instances) 70 | ins_idx = 0 71 | for ins in range(instances.max().numpy()): 72 | ins_map = instances == ins + 1 73 | if ins_map.any(): 74 | img_r[ins_map] = colours['palette'][ins_idx][0] 75 | img_g[ins_map] = colours['palette'][ins_idx][1] 76 | img_b[ins_map] = colours['palette'][ins_idx][2] 77 | ins_idx += 1 78 | img_list.append(torch.cat([img_r, img_g, img_b], dim=0)) 79 | img = torch.stack(img_list, dim=0) 80 | 81 | for b_idx in range(cfg.batch_size): 82 | np_img = np.moveaxis(img.data.numpy()[b_idx], 0, -1) 83 | if img.shape[1] == 1: 84 | axes[f_idx, b_idx].imshow( 85 | np_img[:, :, 0], norm=NoNorm(), cmap='gray') 86 | elif img.shape[1] == 3: 87 | axes[f_idx, b_idx].imshow(np_img) 88 | 89 | manager = plt.get_current_fig_manager() 90 | manager.resize(*manager.window.maxsize()) 91 | plt.show() 92 | 93 | if __name__ == "__main__": 94 | main() 95 | -------------------------------------------------------------------------------- /scripts/visualise_generation.py: -------------------------------------------------------------------------------- 1 | # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 
11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | import os 15 | from os import path as osp 16 | import random 17 | from attrdict import AttrDict 18 | 19 | import torch 20 | import numpy as np 21 | import matplotlib.pyplot as plt 22 | 23 | import forge 24 | from forge import flags 25 | import forge.experiment_tools as fet 26 | from forge.experiment_tools import fprint 27 | 28 | from utils.plotting import plot 29 | 30 | 31 | # Data & model config 32 | flags.DEFINE_string('data_config', 'datasets/gqn_config.py', 33 | 'Path to a data config file.') 34 | flags.DEFINE_string('model_config', 'models/genesis_config.py', 35 | 'Path to a model config file.') 36 | # Trained model 37 | flags.DEFINE_string('model_dir', 'checkpoints/test/1', 38 | 'Path to model directory.') 39 | flags.DEFINE_string('model_file', 'model.ckpt-FINAL', 'Name of model file.') 40 | 41 | 42 | def main(): 43 | # Parse flags 44 | config = forge.config() 45 | # Restore flags of pretrained model 46 | flag_path = osp.join(config.model_dir, 'flags.json') 47 | fprint(f"Restoring flags from {flag_path}") 48 | pretrained_flags = AttrDict(fet.json_load(flag_path)) 49 | pretrained_flags.batch_size = 1 50 | pretrained_flags.gpu = False 51 | pretrained_flags.debug = True 52 | fet.print_flags() 53 | 54 | # Fix seeds. Always first thing to be done after parsing the config! 55 | torch.manual_seed(0) 56 | np.random.seed(0) 57 | random.seed(0) 58 | # Make CUDA operations deterministic 59 | torch.backends.cudnn.deterministic = True 60 | torch.backends.cudnn.benchmark = False 61 | 62 | # Load model 63 | if pretrained_flags.K_steps < 0 and 'multi_object' in config.data_config: 64 | if pretrained_flags.dataset == 'multi_dprites': 65 | pretrained_flags.K_steps = 5 66 | elif pretrained_flags.dataset == 'objects_room': 67 | pretrained_flags.K_steps = 7 68 | elif pretrained_flags.dataset == 'clevr': 69 | pretrained_flags.K_steps = 11 70 | elif pretrained_flags.dataset == 'tetrominoes': 71 | pretrained_flags.K_steps = 4 72 | model = fet.load(config.model_config, pretrained_flags) 73 | model_path = osp.join(config.model_dir, config.model_file) 74 | fprint(f"Restoring model from {model_path}") 75 | checkpoint = torch.load(model_path, map_location='cpu') 76 | model_state_dict = checkpoint['model_state_dict'] 77 | model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1', None) 78 | model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2', None) 79 | model.load_state_dict(model_state_dict) 80 | fprint(model) 81 | 82 | # Visualise 83 | model.eval() 84 | for _ in range(100): 85 | y, stats = model.sample(1, pretrained_flags.K_steps) 86 | fig, axes = plt.subplots(nrows=4, ncols=1+pretrained_flags.K_steps) 87 | 88 | # Generated 89 | plot(axes, 0, 0, y, title='Generated scene', fontsize=12) 90 | # Empty plots 91 | plot(axes, 1, 0, fontsize=12) 92 | plot(axes, 2, 0, fontsize=12) 93 | plot(axes, 3, 0, fontsize=12) 94 | 95 | # Put K generation steps in separate subfigures 96 | for step in range(pretrained_flags.K_steps): 97 | x_step = stats['x_k'][step] 98 | m_step = stats['log_m_k'][step].exp() 99 | mx_step = stats['mx_k'][step] 100 | if 'log_s_k' in stats: 101 | s_step = stats['log_s_k'][step].exp() 102 | pre = 'Mask x RGB ' if step == 0 else '' 103 | plot(axes, 0, 1+step, mx_step, pre+f'k={step+1}', fontsize=12) 104 | pre = 'RGB ' if step == 0 else '' 105 | plot(axes, 1, 1+step, x_step, pre+f'k={step+1}', fontsize=12) 106 | pre = 'Mask ' if step == 0 else '' 107 | plot(axes, 2, 1+step, 
m_step, pre+f'k={step+1}', True, fontsize=12) 108 | if 'log_s_k' in stats: 109 | pre = 'Scope ' if step == 0 else '' 110 | plot(axes, 3, 1+step, s_step, pre+f'k={step+1}', True, 111 | axis=step == 0, fontsize=12) 112 | 113 | # Beautify and show figure 114 | plt.subplots_adjust(wspace=0.05, hspace=0.05) 115 | manager = plt.get_current_fig_manager() 116 | manager.resize(*manager.window.maxsize()) 117 | plt.show() 118 | 119 | 120 | if __name__ == "__main__": 121 | main() 122 | -------------------------------------------------------------------------------- /scripts/visualise_reconstruction.py: -------------------------------------------------------------------------------- 1 | # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | import os 15 | from os import path as osp 16 | import random 17 | from attrdict import AttrDict 18 | 19 | import torch 20 | import numpy as np 21 | import matplotlib.pyplot as plt 22 | 23 | import forge 24 | from forge import flags 25 | import forge.experiment_tools as fet 26 | from forge.experiment_tools import fprint 27 | 28 | from utils.plotting import plot 29 | 30 | 31 | # Data & model config 32 | flags.DEFINE_string('data_config', 'datasets/gqn_config.py', 33 | 'Path to a data config file.') 34 | flags.DEFINE_string('model_config', 'models/genesis_config.py', 35 | 'Path to a model config file.') 36 | # Trained model 37 | flags.DEFINE_string('model_dir', 'checkpoints/test/1', 38 | 'Path to model directory.') 39 | flags.DEFINE_string('model_file', 'model.ckpt-FINAL', 'Name of model file.') 40 | # Other 41 | flags.DEFINE_integer('num_images', 10, 'Number of images to visualize.') 42 | 43 | 44 | def main(): 45 | # Parse flags 46 | config = forge.config() 47 | fet.print_flags() 48 | # Restore flags of pretrained model 49 | flag_path = osp.join(config.model_dir, 'flags.json') 50 | fprint(f"Restoring flags from {flag_path}") 51 | pretrained_flags = AttrDict(fet.json_load(flag_path)) 52 | pretrained_flags.debug = True 53 | 54 | # Fix seeds. Always first thing to be done after parsing the config! 
55 | torch.manual_seed(0) 56 | np.random.seed(0) 57 | random.seed(0) 58 | # Make CUDA operations deterministic 59 | torch.backends.cudnn.deterministic = True 60 | torch.backends.cudnn.benchmark = False 61 | 62 | # Load data 63 | config.batch_size = 1 64 | _, _, test_loader = fet.load(config.data_config, config) 65 | if pretrained_flags.K_steps < 0 and 'multi_object' in config.data_config: 66 | pretrained_flags.K_steps = config.K_steps 67 | 68 | # Load model 69 | model = fet.load(config.model_config, pretrained_flags) 70 | model_path = osp.join(config.model_dir, config.model_file) 71 | fprint(f"Restoring model from {model_path}") 72 | checkpoint = torch.load(model_path, map_location='cpu') 73 | model_state_dict = checkpoint['model_state_dict'] 74 | model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1', None) 75 | model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2', None) 76 | model.load_state_dict(model_state_dict) 77 | fprint(model) 78 | 79 | # Visualise 80 | model.eval() 81 | for count, batch in enumerate(test_loader): 82 | if count >= config.num_images: 83 | break 84 | 85 | # Forward pass 86 | output, _, stats, _, _ = model(batch['input']) 87 | # Set up figure 88 | fig, axes = plt.subplots(nrows=4, ncols=1+pretrained_flags.K_steps) 89 | 90 | # Input and reconstruction 91 | plot(axes, 0, 0, batch['input'], title='Input image', fontsize=12) 92 | plot(axes, 1, 0, output, title='Reconstruction', fontsize=12) 93 | # Empty plots 94 | plot(axes, 2, 0, fontsize=12) 95 | plot(axes, 3, 0, fontsize=12) 96 | 97 | # Put K reconstruction steps into separate subfigures 98 | x_k = stats['x_r_k'] 99 | if 'genesisv2' not in config.model_config: 100 | log_masks = stats['log_m_k'] 101 | else: 102 | log_masks = stats['log_m_r_k'] 103 | mx_k = [x*m.exp() for x, m in zip(x_k, log_masks)] 104 | log_s_k = stats['log_s_k'] if 'log_s_k' in stats else None 105 | for step in range(pretrained_flags.K_steps): 106 | mx_step = mx_k[step] 107 | x_step = x_k[step] 108 | m_step = log_masks[step].exp() 109 | if log_s_k: 110 | s_step = log_s_k[step].exp() 111 | 112 | pre = 'Mask x RGB ' if step == 0 else '' 113 | plot(axes, 0, 1+step, mx_step, pre+f'k={step+1}', fontsize=12) 114 | pre = 'RGB ' if step == 0 else '' 115 | plot(axes, 1, 1+step, x_step, pre+f'k={step+1}', fontsize=12) 116 | pre = 'Mask ' if step == 0 else '' 117 | plot(axes, 2, 1+step, m_step, pre+f'k={step+1}', True, fontsize=12) 118 | if log_s_k: 119 | pre = 'Scope ' if step == 0 else '' 120 | plot(axes, 3, 1+step, s_step, pre+f'k={step+1}', True, 121 | axis=step == 0, fontsize=12) 122 | 123 | # Beautify and show figure 124 | plt.subplots_adjust(wspace=0.05, hspace=0.15) 125 | manager = plt.get_current_fig_manager() 126 | manager.resize(*manager.window.maxsize()) 127 | plt.show() 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /third_party/__init__ .py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/applied-ai-lab/genesis/9abf202bbad6fa4a675117fdea0be163e4f16695/third_party/__init__ .py -------------------------------------------------------------------------------- /third_party/multi_object_datasets/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 
5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows 28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). 29 | -------------------------------------------------------------------------------- /third_party/multi_object_datasets/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /third_party/multi_object_datasets/README.md: -------------------------------------------------------------------------------- 1 | # Multi-Object Datasets 2 | 3 | This repository contains datasets for multi-object representation learning, used 4 | in developing scene decomposition methods like 5 | [MONet](https://arxiv.org/abs/1901.11390) [1] and 6 | [IODINE](http://proceedings.mlr.press/v97/greff19a.html) [2]. The datasets we 7 | provide are: 8 | 9 | 1. 
[Multi-dSprites](#multi-dsprites) 10 | 2. [Objects Room](#objects-room) 11 | 3. [CLEVR (with masks)](#clevr-with-masks) 12 | 4. [Tetrominoes](#tetrominoes) 13 | 14 | ![preview](preview.png) 15 | 16 | The datasets consist of multi-object scenes. Each image is accompanied by 17 | ground-truth segmentation masks for all objects in the scene. We also provide 18 | per-object generative factors (except in Objects Room) to facilitate 19 | representation learning. The generative factors include all necessary and 20 | sufficient features (size, color, position, etc.) to describe and render the 21 | objects present in a scene. 22 | 23 | Lastly, the `segmentation_metrics` module contains a TensorFlow implementation 24 | of the 25 | [adjusted Rand index](https://en.wikipedia.org/wiki/Rand_index#Adjusted_Rand_index) 26 | [3], which can be used to compare inferred object segmentations with 27 | ground-truth segmentation masks. All code has been tested to work with 28 | TensorFlow r1.14. 29 | 30 | ## Bibtex 31 | 32 | If you use one of these datasets in your work, please cite it as follows: 33 | 34 | ``` 35 | @misc{multiobjectdatasets19, 36 | title={Multi-Object Datasets}, 37 | author={Kabra, Rishabh and Burgess, Chris and Matthey, Loic and 38 | Kaufman, Raphael Lopez and Greff, Klaus and Reynolds, Malcolm and 39 | Lerchner, Alexander}, 40 | howpublished={https://github.com/deepmind/multi-object-datasets/}, 41 | year={2019} 42 | } 43 | ``` 44 | 45 | ## Descriptions 46 | 47 | ### Multi-dSprites 48 | 49 | This is a dataset based on 50 | [dSprites](https://github.com/deepmind/dsprites-dataset). Each image consists of 51 | multiple oval, heart, or square-shaped sprites (with some occlusions) set 52 | against a uniformly colored background. 53 | 54 | We're releasing three versions of this dataset containing 1M datapoints each: 55 | 56 | 1.1 Binarized: each image has 2-3 white sprites on a black background. 57 | 58 | 1.2 Colored sprites on grayscale: each scene has 2-5 randomly colored HSV 59 | sprites on a randomly sampled grayscale background. 60 | 61 | 1.3 Colored sprites and background: each scene has 1-4 sprites. All colors are 62 | randomly sampled RGB values. 63 | 64 | Each datapoint contains an image, a number of background and object masks, and 65 | the following ground-truth features per object: `x` and `y` positions, `shape`, 66 | `color` (rgb values), `orientation`, and `scale`. Lastly, `visibility` is a 67 | binary feature indicating which objects are not null. 68 | 69 | ### Objects Room 70 | 71 | This dataset is based on the [MuJoCo](http://www.mujoco.org/) environment used 72 | by the Generative Query Network [4] and is a multi-object extension of the 73 | [3d-shapes dataset](https://github.com/deepmind/3d-shapes). The training set 74 | contains 1M scenes with up to three objects. We also provide ~1K test examples 75 | for the following variants: 76 | 77 | 2.1 Empty room: scenes consist of the sky, walls, and floor only. 78 | 79 | 2.2 Six objects: exactly 6 objects are visible in each image. 80 | 81 | 2.3 Identical color: 4-6 objects are placed in the room and have an identical, 82 | randomly sampled color. 83 | 84 | Datapoints consist of an image and fixed number of masks. The first four masks 85 | correspond to the sky, floor, and two halves of the wall respectively. The 86 | remaining masks correspond to the foreground objects. 
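As a concrete illustration of this datapoint layout, the Objects Room training set can be read with the bundled reader in the same way as the Multi-dSprites example in the Usage section below. This is a minimal sketch: the `.tfrecords` path is a placeholder, and the call follows the `objects_room.dataset` signature defined in `objects_room.py`.

```python
import objects_room
import tensorflow as tf

tf_records_path = 'path/to/objects_room_train.tfrecords'  # placeholder path
batch_size = 32

# The 'train' variant uses MAX_NUM_ENTITIES = 7:
# sky, floor, two wall halves, and up to three foreground objects.
dataset = objects_room.dataset(tf_records_path, 'train')
batched_dataset = dataset.batch(batch_size)
iterator = batched_dataset.make_one_shot_iterator()
data = iterator.get_next()

with tf.train.SingularMonitoredSession() as sess:
    d = sess.run(data)
    # d['image'] has shape [32, 64, 64, 3] (uint8),
    # d['mask'] has shape [32, 7, 64, 64, 1] (uint8, values 0 or 255).
```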
87 | 88 | ### CLEVR (with masks) 89 | 90 | We adapted the 91 | [open-source script](https://github.com/facebookresearch/clevr-dataset-gen) 92 | provided by Johnson et al. to produce ground-truth segmentation masks for CLEVR 93 | [5] scenes. These were generated afresh, so images in this dataset are not 94 | identical to those in the original CLEVR dataset. We ignore the original 95 | question-answering task. 96 | 97 | The images and masks in the dataset are of size 320x240. We also provide all 98 | ground-truth factors included in the original dataset (namely `x`, `y`, and `z` 99 | position, `pixel_coords`, and `rotation`, which are real-valued; plus `size`, 100 | `material`, `shape`, and `color`, which are encoded as integers) along with a 101 | `visibility` vector to indicate which objects are not null. 102 | 103 | ### Tetrominoes 104 | 105 | This is a dataset of Tetris-like shapes (aka tetrominoes). Each 35x35 image 106 | contains three tetrominoes, sampled from 17 unique shapes/orientations. Each 107 | tetromino has one of six possible colors (red, green, blue, yellow, magenta, 108 | cyan). We provide `x` and `y` position, `shape`, and `color` (integer-coded) as 109 | ground-truth features. Datapoints also include a `visibility` vector. 110 | 111 | ## Download 112 | 113 | The datasets can be downloaded from 114 | [Google Cloud Storage](https://console.cloud.google.com/storage/browser/multi-object-datasets). 115 | Each dataset is a single 116 | [TFRecords](https://www.tensorflow.org/tutorials/load_data/tf_records) file. To 117 | download a particular dataset, use the web interface, or run `wget` with the 118 | appropriate filename as follows: 119 | 120 | ```shell 121 | wget https://storage.googleapis.com/multi-object-datasets/multi_dsprites/multi_dsprites_colored_on_colored.tfrecords 122 | ``` 123 | 124 | To download all datasets, you'll need the `gsutil` tool, which comes with the 125 | [Google Cloud SDK](https://cloud.google.com/sdk/docs/). Simply run: 126 | 127 | ```shell 128 | gsutil cp -r gs://multi-object-datasets . 129 | ``` 130 | 131 | The approximate download sizes are: 132 | 133 | 1. Multi-dSprites: between 500 MB and 1 GB. 134 | 2. Objects Room: the training set is 7 GB. The test sets are 6-8 MB. 135 | 3. CLEVR (with masks): 10.5 GB. 136 | 4. Tetrominoes: 300 MB. 137 | 138 | ## Usage 139 | 140 | After downloading the dataset files, you can read them as 141 | [`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) 142 | instances with the readers provided. The example below shows how to read the 143 | colored-sprites-and-background version of Multi-dSprites: 144 | 145 | ```python 146 | import multi_dsprites 147 | import tensorflow as tf 148 | 149 | tf_records_path = 'path/to/multi_dsprites_colored_on_colored.tfrecords' 150 | batch_size = 32 151 | 152 | dataset = multi_dsprites.dataset(tf_records_path, 'colored_on_colored') 153 | batched_dataset = dataset.batch(batch_size) # optional batching 154 | iterator = batched_dataset.make_one_shot_iterator() 155 | data = iterator.get_next() 156 | 157 | with tf.train.SingularMonitoredSession() as sess: 158 | d = sess.run(data) 159 | ``` 160 | 161 | All dataset readers return images and segmentation masks in the following 162 | canonical format (assuming the dataset is batched as above): 163 | 164 | - 'image': `Tensor` of shape [batch_size, height, width, channels] and type 165 | uint8. 166 | 167 | - 'mask': `Tensor` of shape [batch_size, max_num_entities, height, width, 168 | channels] and type uint8. 
The tensor takes on values of 255 or 0, denoting 169 | whether a pixel belongs to a particular entity or not. 170 | 171 | You can compare predicted object segmentation masks with the ground-truth masks 172 | using `segmentation_metrics.adjusted_rand_index` as below: 173 | 174 | ```python 175 | max_num_entities = multi_dsprites.MAX_NUM_ENTITIES['colored_on_colored'] 176 | # Ground-truth segmentation masks are always returned in the canonical 177 | # [batch_size, max_num_entities, height, width, channels] format. To use these 178 | # as an input for `segmentation_metrics.adjusted_rand_index`, we need them in 179 | # the [batch_size, n_points, n_true_groups] format, 180 | # where n_true_groups == max_num_entities. We implement this reshape below. 181 | # Note that 'oh' denotes 'one-hot'. 182 | desired_shape = [batch_size, 183 | multi_dsprites.IMAGE_SIZE[0] * multi_dsprites.IMAGE_SIZE[1], 184 | max_num_entities] 185 | true_groups_oh = tf.transpose(data['mask'], [0, 2, 3, 4, 1]) 186 | true_groups_oh = tf.reshape(true_groups_oh, desired_shape) 187 | 188 | random_prediction = tf.random_uniform(desired_shape[:-1], 189 | minval=0, maxval=max_num_entities, 190 | dtype=tf.int32) 191 | random_prediction_oh = tf.one_hot(random_prediction, depth=max_num_entities) 192 | 193 | ari = segmentation_metrics.adjusted_rand_index(true_groups_oh, 194 | random_prediction_oh) 195 | ``` 196 | 197 | To exclude all background pixels from the ARI score (as in [2]), you can compute 198 | it as follows instead. This assumes the first true group contains all background 199 | pixels: 200 | 201 | ```python 202 | ari_nobg = segmentation_metrics.adjusted_rand_index(true_groups_oh[..., 1:], 203 | random_prediction_oh) 204 | ``` 205 | 206 | ## References 207 | 208 | [1] Burgess, C. P., Matthey, L., Watters, N., Kabra, R., Higgins, I., Botvinick, 209 | M., & Lerchner, A. (2019). Monet: Unsupervised scene decomposition and 210 | representation. arXiv preprint arXiv:1901.11390. 211 | 212 | [2] Greff, K., Kaufman, R. L., Kabra, R., Watters, N., Burgess, C., Zoran, D., 213 | Matthey, L., Botvinick, M., & Lerchner, A. (2019). Multi-Object Representation 214 | Learning with Iterative Variational Inference. Proceedings of the 36th 215 | International Conference on Machine Learning, in PMLR 97:2424-2433. 216 | 217 | [3] Rand, W. M. (1971). Objective criteria for the evaluation of clustering 218 | methods. Journal of the American Statistical association, 66(336), 846-850. 219 | 220 | [4] Eslami, S., Rezende, D. J., Besse, F., Viola, F., Morcos, A., Garnelo, M., 221 | Ruderman, A., Rusu, A., Danihelka, I., Gregor, K., Reichert, D., Buesing, L., 222 | Weber, T., Vinyals, O., Rosenbaum, D., Rabinowitz, N., King, H., Hillier, C., 223 | Botvinick, M., Wierstra, D., Kavukcuoglu, K., & Hassabis, D. (2018). Neural 224 | scene representation and rendering. Science, 360(6394), 1204-1210. 225 | 226 | [5] Johnson, J., Hariharan, B., van der Maaten, L., Fei-Fei, L., Lawrence 227 | Zitnick, C., & Girshick, R. (2017). Clevr: A diagnostic dataset for 228 | compositional language and elementary visual reasoning. In Proceedings of the 229 | IEEE Conference on Computer Vision and Pattern Recognition (pp. 2901-2910). 230 | 231 | ## Disclaimers 232 | 233 | This is not an official Google product. 
234 | -------------------------------------------------------------------------------- /third_party/multi_object_datasets/clevr_with_masks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """CLEVR (with masks) dataset reader.""" 16 | 17 | import tensorflow as tf 18 | 19 | 20 | COMPRESSION_TYPE = tf.io.TFRecordOptions.get_compression_type_string('GZIP') 21 | IMAGE_SIZE = [240, 320] 22 | # The maximum number of foreground and background entities in the provided 23 | # dataset. This corresponds to the number of segmentation masks returned per 24 | # scene. 25 | MAX_NUM_ENTITIES = 11 26 | BYTE_FEATURES = ['mask', 'image', 'color', 'material', 'shape', 'size'] 27 | 28 | # Create a dictionary mapping feature names to `tf.Example`-compatible 29 | # shape and data type descriptors. 30 | features = { 31 | 'image': tf.FixedLenFeature(IMAGE_SIZE+[3], tf.string), 32 | 'mask': tf.FixedLenFeature([MAX_NUM_ENTITIES]+IMAGE_SIZE+[1], tf.string), 33 | 'x': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 34 | 'y': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 35 | 'z': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 36 | 'pixel_coords': tf.FixedLenFeature([MAX_NUM_ENTITIES, 3], tf.float32), 37 | 'rotation': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 38 | 'size': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.string), 39 | 'material': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.string), 40 | 'shape': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.string), 41 | 'color': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.string), 42 | 'visibility': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 43 | } 44 | 45 | 46 | def _decode(example_proto): 47 | # Parse the input `tf.Example` proto using the feature description dict above. 48 | single_example = tf.parse_single_example(example_proto, features) 49 | for k in BYTE_FEATURES: 50 | single_example[k] = tf.squeeze(tf.decode_raw(single_example[k], tf.uint8), 51 | axis=-1) 52 | return single_example 53 | 54 | 55 | def dataset(tfrecords_path, read_buffer_size=None, map_parallel_calls=None): 56 | """Read, decompress, and parse the TFRecords file. 57 | 58 | Args: 59 | tfrecords_path: str. Path to the dataset file. 60 | read_buffer_size: int. Number of bytes in the read buffer. See documentation 61 | for `tf.data.TFRecordDataset.__init__`. 62 | map_parallel_calls: int. Number of elements decoded asynchronously in 63 | parallel. See documentation for `tf.data.Dataset.map`. 64 | 65 | Returns: 66 | An unbatched `tf.data.TFRecordDataset`. 
67 | """ 68 | raw_dataset = tf.data.TFRecordDataset( 69 | tfrecords_path, compression_type=COMPRESSION_TYPE, 70 | buffer_size=read_buffer_size) 71 | return raw_dataset.map(_decode, num_parallel_calls=map_parallel_calls) 72 | -------------------------------------------------------------------------------- /third_party/multi_object_datasets/multi_dsprites.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Multi-dSprites dataset reader.""" 16 | 17 | import functools 18 | import tensorflow as tf 19 | 20 | 21 | COMPRESSION_TYPE = tf.io.TFRecordOptions.get_compression_type_string('GZIP') 22 | IMAGE_SIZE = [64, 64] 23 | # The maximum number of foreground and background entities in each variant 24 | # of the provided datasets. The values correspond to the number of 25 | # segmentation masks returned per scene. 26 | MAX_NUM_ENTITIES = { 27 | 'binarized': 4, 28 | 'colored_on_grayscale': 6, 29 | 'colored_on_colored': 5 30 | } 31 | BYTE_FEATURES = ['mask', 'image'] 32 | 33 | 34 | def feature_descriptions(max_num_entities, is_grayscale=False): 35 | """Create a dictionary describing the dataset features. 36 | 37 | Args: 38 | max_num_entities: int. The maximum number of foreground and background 39 | entities in each image. This corresponds to the number of segmentation 40 | masks and generative factors returned per scene. 41 | is_grayscale: bool. Whether images are grayscale. Otherwise they're assumed 42 | to be RGB. 43 | 44 | Returns: 45 | A dictionary which maps feature names to `tf.Example`-compatible shape and 46 | data type descriptors. 47 | """ 48 | 49 | num_channels = 1 if is_grayscale else 3 50 | return { 51 | 'image': tf.FixedLenFeature(IMAGE_SIZE+[num_channels], tf.string), 52 | 'mask': tf.FixedLenFeature(IMAGE_SIZE+[max_num_entities, 1], tf.string), 53 | 'x': tf.FixedLenFeature([max_num_entities], tf.float32), 54 | 'y': tf.FixedLenFeature([max_num_entities], tf.float32), 55 | 'shape': tf.FixedLenFeature([max_num_entities], tf.float32), 56 | 'color': tf.FixedLenFeature([max_num_entities, num_channels], tf.float32), 57 | 'visibility': tf.FixedLenFeature([max_num_entities], tf.float32), 58 | 'orientation': tf.FixedLenFeature([max_num_entities], tf.float32), 59 | 'scale': tf.FixedLenFeature([max_num_entities], tf.float32), 60 | } 61 | 62 | 63 | def _decode(example_proto, features): 64 | # Parse the input `tf.Example` proto using a feature description dictionary. 65 | single_example = tf.parse_single_example(example_proto, features) 66 | for k in BYTE_FEATURES: 67 | single_example[k] = tf.squeeze(tf.decode_raw(single_example[k], tf.uint8), 68 | axis=-1) 69 | # To return masks in the canonical [entities, height, width, channels] format, 70 | # we need to transpose the tensor axes. 
71 | single_example['mask'] = tf.transpose(single_example['mask'], [2, 0, 1, 3]) 72 | return single_example 73 | 74 | 75 | def dataset(tfrecords_path, dataset_variant, read_buffer_size=None, 76 | map_parallel_calls=None): 77 | """Read, decompress, and parse the TFRecords file. 78 | 79 | Args: 80 | tfrecords_path: str. Path to the dataset file. 81 | dataset_variant: str. One of ['binarized', 'colored_on_grayscale', 82 | 'colored_on_colored']. This is used to identify the maximum number of 83 | entities in each scene. If an incorrect identifier is passed in, the 84 | TFRecords file will not be read correctly. 85 | read_buffer_size: int. Number of bytes in the read buffer. See documentation 86 | for `tf.data.TFRecordDataset.__init__`. 87 | map_parallel_calls: int. Number of elements decoded asynchronously in 88 | parallel. See documentation for `tf.data.Dataset.map`. 89 | 90 | Returns: 91 | An unbatched `tf.data.TFRecordDataset`. 92 | """ 93 | if dataset_variant not in MAX_NUM_ENTITIES: 94 | raise ValueError('Invalid `dataset_variant` provided. The supported values' 95 | ' are: {}'.format(list(MAX_NUM_ENTITIES.keys()))) 96 | max_num_entities = MAX_NUM_ENTITIES[dataset_variant] 97 | is_grayscale = dataset_variant == 'binarized' 98 | raw_dataset = tf.data.TFRecordDataset( 99 | tfrecords_path, compression_type=COMPRESSION_TYPE, 100 | buffer_size=read_buffer_size) 101 | features = feature_descriptions(max_num_entities, is_grayscale) 102 | partial_decode_fn = functools.partial(_decode, features=features) 103 | return raw_dataset.map(partial_decode_fn, 104 | num_parallel_calls=map_parallel_calls) 105 | -------------------------------------------------------------------------------- /third_party/multi_object_datasets/objects_room.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Objects Room dataset reader.""" 16 | 17 | import functools 18 | import tensorflow as tf 19 | 20 | 21 | COMPRESSION_TYPE = tf.io.TFRecordOptions.get_compression_type_string('GZIP') 22 | IMAGE_SIZE = [64, 64] 23 | # The maximum number of foreground and background entities in each variant 24 | # of the provided datasets. The values correspond to the number of 25 | # segmentation masks returned per scene. 26 | MAX_NUM_ENTITIES = { 27 | 'train': 7, 28 | 'six_objects': 10, 29 | 'empty_room': 4, 30 | 'identical_color': 10 31 | } 32 | BYTE_FEATURES = ['mask', 'image'] 33 | 34 | 35 | def feature_descriptions(max_num_entities): 36 | """Create a dictionary describing the dataset features. 37 | 38 | Args: 39 | max_num_entities: int. The maximum number of foreground and background 40 | entities in each image. This corresponds to the number of segmentation 41 | masks returned per scene. 
42 | 43 | Returns: 44 | A dictionary which maps feature names to `tf.Example`-compatible shape and 45 | data type descriptors. 46 | """ 47 | return { 48 | 'image': tf.FixedLenFeature(IMAGE_SIZE+[3], tf.string), 49 | 'mask': tf.FixedLenFeature([max_num_entities]+IMAGE_SIZE+[1], tf.string), 50 | } 51 | 52 | 53 | def _decode(example_proto, features): 54 | # Parse the input `tf.Example` proto using a feature description dictionary. 55 | single_example = tf.parse_single_example(example_proto, features) 56 | for k in BYTE_FEATURES: 57 | single_example[k] = tf.squeeze(tf.decode_raw(single_example[k], tf.uint8), 58 | axis=-1) 59 | return single_example 60 | 61 | 62 | def dataset(tfrecords_path, dataset_variant, read_buffer_size=None, 63 | map_parallel_calls=None): 64 | """Read, decompress, and parse the TFRecords file. 65 | 66 | Args: 67 | tfrecords_path: str. Path to the dataset file. 68 | dataset_variant: str. One of ['train', 'six_objects', 'empty_room', 69 | 'identical_color']. This is used to identify the maximum number of 70 | entities in each scene. If an incorrect identifier is passed in, the 71 | TFRecords file will not be read correctly. 72 | read_buffer_size: int. Number of bytes in the read buffer. See documentation 73 | for `tf.data.TFRecordDataset.__init__`. 74 | map_parallel_calls: int. Number of elements decoded asynchronously in 75 | parallel. See documentation for `tf.data.Dataset.map`. 76 | 77 | Returns: 78 | An unbatched `tf.data.TFRecordDataset`. 79 | """ 80 | if dataset_variant not in MAX_NUM_ENTITIES: 81 | raise ValueError('Invalid `dataset_variant` provided. The supported values' 82 | ' are: {}'.format(list(MAX_NUM_ENTITIES.keys()))) 83 | max_num_entities = MAX_NUM_ENTITIES[dataset_variant] 84 | raw_dataset = tf.data.TFRecordDataset( 85 | tfrecords_path, compression_type=COMPRESSION_TYPE, 86 | buffer_size=read_buffer_size) 87 | features = feature_descriptions(max_num_entities) 88 | partial_decode_fn = functools.partial(_decode, features=features) 89 | return raw_dataset.map(partial_decode_fn, 90 | num_parallel_calls=map_parallel_calls) 91 | -------------------------------------------------------------------------------- /third_party/multi_object_datasets/preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/applied-ai-lab/genesis/9abf202bbad6fa4a675117fdea0be163e4f16695/third_party/multi_object_datasets/preview.png -------------------------------------------------------------------------------- /third_party/multi_object_datasets/segmentation_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Implementation of the adjusted Rand index.""" 16 | 17 | import tensorflow as tf 18 | 19 | 20 | def adjusted_rand_index(true_mask, pred_mask, name='ari_score'): 21 | r"""Computes the adjusted Rand index (ARI), a clustering similarity score. 22 | 23 | This implementation ignores points with no cluster label in `true_mask` (i.e. 24 | those points for which `true_mask` is a zero vector). In the context of 25 | segmentation, that means this function can ignore points in an image 26 | corresponding to the background (i.e. not to an object). 27 | 28 | Args: 29 | true_mask: `Tensor` of shape [batch_size, n_points, n_true_groups]. 30 | The true cluster assignment encoded as one-hot. 31 | pred_mask: `Tensor` of shape [batch_size, n_points, n_pred_groups]. 32 | The predicted cluster assignment encoded as categorical probabilities. 33 | This function works on the argmax over axis 2. 34 | name: str. Name of this operation (defaults to "ari_score"). 35 | 36 | Returns: 37 | ARI scores as a tf.float32 `Tensor` of shape [batch_size]. 38 | 39 | Raises: 40 | ValueError: if n_points <= n_true_groups and n_points <= n_pred_groups. 41 | We've chosen not to handle the special cases that can occur when you have 42 | one cluster per datapoint (which would be unusual). 43 | 44 | References: 45 | Lawrence Hubert, Phipps Arabie. 1985. "Comparing partitions" 46 | https://link.springer.com/article/10.1007/BF01908075 47 | Wikipedia 48 | https://en.wikipedia.org/wiki/Rand_index 49 | Scikit Learn 50 | http://scikit-learn.org/stable/modules/generated/\ 51 | sklearn.metrics.adjusted_rand_score.html 52 | """ 53 | with tf.name_scope(name): 54 | _, n_points, n_true_groups = true_mask.shape.as_list() 55 | n_pred_groups = pred_mask.shape.as_list()[-1] 56 | if n_points <= n_true_groups and n_points <= n_pred_groups: 57 | # This rules out the n_true_groups == n_pred_groups == n_points 58 | # corner case, and also n_true_groups == n_pred_groups == 0, since 59 | # that would imply n_points == 0 too. 60 | # The sklearn implementation has a corner-case branch which does 61 | # handle this. We chose not to support these cases to avoid counting 62 | # distinct clusters just to check if we have one cluster per datapoint. 63 | raise ValueError( 64 | "adjusted_rand_index requires n_groups < n_points. We don't handle " 65 | "the special cases that can occur when you have one cluster " 66 | "per datapoint.") 67 | 68 | true_group_ids = tf.argmax(true_mask, -1) 69 | pred_group_ids = tf.argmax(pred_mask, -1) 70 | # We convert true and predicted clusters to one-hot ('oh') representations. 71 | true_mask_oh = tf.cast(true_mask, tf.float32) # already one-hot 72 | pred_mask_oh = tf.one_hot(pred_group_ids, n_pred_groups) # returns float32 73 | 74 | n_points = tf.cast(tf.reduce_sum(true_mask_oh, axis=[1, 2]), tf.float32) 75 | 76 | nij = tf.einsum('bji,bjk->bki', pred_mask_oh, true_mask_oh) 77 | a = tf.reduce_sum(nij, axis=1) 78 | b = tf.reduce_sum(nij, axis=2) 79 | 80 | rindex = tf.reduce_sum(nij * (nij - 1), axis=[1, 2]) 81 | aindex = tf.reduce_sum(a * (a - 1), axis=1) 82 | bindex = tf.reduce_sum(b * (b - 1), axis=1) 83 | expected_rindex = aindex * bindex / (n_points*(n_points-1)) 84 | max_rindex = (aindex + bindex) / 2 85 | ari = (rindex - expected_rindex) / (max_rindex - expected_rindex) 86 | 87 | # The case where n_true_groups == n_pred_groups == 1 needs to be 88 | # special-cased (to return 1) as the above formula gives a divide-by-zero. 
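    # (When each partition consists of a single cluster, the two partitions
    # trivially agree, so by convention the ARI is 1.)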
89 | # This might not work when true_mask has values that do not sum to one: 90 | both_single_cluster = tf.logical_and( 91 | _all_equal(true_group_ids), _all_equal(pred_group_ids)) 92 | return tf.where(both_single_cluster, tf.ones_like(ari), ari) 93 | 94 | 95 | def _all_equal(values): 96 | """Whether values are all equal along the final axis.""" 97 | return tf.reduce_all(tf.equal(values, values[..., :1]), axis=-1) 98 | -------------------------------------------------------------------------------- /third_party/multi_object_datasets/tetrominoes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tetrominoes dataset reader.""" 16 | 17 | import tensorflow as tf 18 | 19 | 20 | COMPRESSION_TYPE = tf.io.TFRecordOptions.get_compression_type_string('GZIP') 21 | IMAGE_SIZE = [35, 35] 22 | # The maximum number of foreground and background entities in the provided 23 | # dataset. This corresponds to the number of segmentation masks returned per 24 | # scene. 25 | MAX_NUM_ENTITIES = 4 26 | BYTE_FEATURES = ['mask', 'image'] 27 | 28 | # Create a dictionary mapping feature names to `tf.Example`-compatible 29 | # shape and data type descriptors. 30 | features = { 31 | 'image': tf.FixedLenFeature(IMAGE_SIZE+[3], tf.string), 32 | 'mask': tf.FixedLenFeature([MAX_NUM_ENTITIES]+IMAGE_SIZE+[1], tf.string), 33 | 'x': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 34 | 'y': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 35 | 'shape': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 36 | 'color': tf.FixedLenFeature([MAX_NUM_ENTITIES, 3], tf.float32), 37 | 'visibility': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 38 | } 39 | 40 | 41 | def _decode(example_proto): 42 | # Parse the input `tf.Example` proto using the feature description dict above. 43 | single_example = tf.parse_single_example(example_proto, features) 44 | for k in BYTE_FEATURES: 45 | single_example[k] = tf.squeeze(tf.decode_raw(single_example[k], tf.uint8), 46 | axis=-1) 47 | return single_example 48 | 49 | 50 | def dataset(tfrecords_path, read_buffer_size=None, map_parallel_calls=None): 51 | """Read, decompress, and parse the TFRecords file. 52 | 53 | Args: 54 | tfrecords_path: str. Path to the dataset file. 55 | read_buffer_size: int. Number of bytes in the read buffer. See documentation 56 | for `tf.data.TFRecordDataset.__init__`. 57 | map_parallel_calls: int. Number of elements decoded asynchronously in 58 | parallel. See documentation for `tf.data.Dataset.map`. 59 | 60 | Returns: 61 | An unbatched `tf.data.TFRecordDataset`. 
62 | """ 63 | raw_dataset = tf.data.TFRecordDataset( 64 | tfrecords_path, compression_type=COMPRESSION_TYPE, 65 | buffer_size=read_buffer_size) 66 | return raw_dataset.map(_decode, num_parallel_calls=map_parallel_calls) 67 | -------------------------------------------------------------------------------- /third_party/pytorch_fid/README.md: -------------------------------------------------------------------------------- 1 | # Fréchet Inception Distance (FID score) in PyTorch 2 | 3 | This is a port of the official implementation of [Fréchet Inception Distance](https://arxiv.org/abs/1706.08500) to PyTorch. 4 | See [https://github.com/bioinf-jku/TTUR](https://github.com/bioinf-jku/TTUR) for the original implementation using Tensorflow. 5 | 6 | FID is a measure of similarity between two datasets of images. 7 | It was shown to correlate well with human judgement of visual quality and is most often used to evaluate the quality of samples of Generative Adversarial Networks. 8 | FID is calculated by computing the [Fréchet distance](https://en.wikipedia.org/wiki/Fr%C3%A9chet_distance) between two Gaussians fitted to feature representations of the Inception network. 9 | 10 | Further insights and an independent evaluation of the FID score can be found in [Are GANs Created Equal? A Large-Scale Study](https://arxiv.org/abs/1711.10337). 11 | 12 | **Note that the official implementation gives slightly different scores.** If you report FID scores in your paper, and you want them to be exactly comparable to FID scores reported in other papers, you should use [the official Tensorflow implementation](https://github.com/bioinf-jku/TTUR). 13 | You can still use this version if you want a quick FID estimate without installing Tensorflow. 14 | 15 | **Update:** The weights and the model are now exactly the same as in the official Tensorflow implementation, and I verified them to give the same results (around `1e-8` mean absolute error) on single inputs on my platform. However, due to differences in the image interpolation implementation and library backends, FID results might still differ slightly from the original implementation. A test I ran (details are to come) resulted in `.08` absolute error and `0.0009` relative error. 16 | 17 | ## Usage 18 | 19 | Requirements: 20 | - python3 21 | - pytorch 22 | - torchvision 23 | - numpy 24 | - scipy 25 | 26 | To compute the FID score between two datasets, where images of each dataset are contained in an individual folder: 27 | ``` 28 | ./fid_score.py path/to/dataset1 path/to/dataset2 29 | ``` 30 | 31 | To run the evaluation on GPU, use the flag `--gpu N`, where `N` is the index of the GPU to use. 32 | 33 | ### Using different layers for feature maps 34 | 35 | In difference to the official implementation, you can choose to use a different feature layer of the Inception network instead of the default `pool3` layer. 36 | As the lower layer features still have spatial extent, the features are first global average pooled to a vector before estimating mean and covariance. 37 | 38 | This might be useful if the datasets you want to compare have less than the otherwise required 2048 images. 39 | Note that this changes the magnitude of the FID score and you can not compare them against scores calculated on another dimensionality. 40 | The resulting scores might also no longer correlate with visual quality. 41 | 42 | You can select the dimensionality of features to use with the flag `--dims N`, where N is the dimensionality of features. 
43 | The choices are: 44 | - 64: first max pooling features 45 | - 192: second max pooling featurs 46 | - 768: pre-aux classifier features 47 | - 2048: final average pooling features (this is the default) 48 | 49 | ## License 50 | 51 | This implementation is licensed under the Apache License 2.0. 52 | 53 | FID was introduced by Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler and Sepp Hochreiter in "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium", see [https://arxiv.org/abs/1706.08500](https://arxiv.org/abs/1706.08500) 54 | 55 | The original implementation is by the Institute of Bioinformatics, JKU Linz, licensed under the Apache License 2.0. 56 | See [https://github.com/bioinf-jku/TTUR](https://github.com/bioinf-jku/TTUR). 57 | -------------------------------------------------------------------------------- /third_party/pytorch_fid/fid_score.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Calculates the Frechet Inception Distance (FID) to evalulate GANs 3 | 4 | The FID metric calculates the distance between two distributions of images. 5 | Typically, we have summary statistics (mean & covariance matrix) of one 6 | of these distributions, while the 2nd distribution is given by a GAN. 7 | 8 | When run as a stand-alone program, it compares the distribution of 9 | images that are stored as PNG/JPEG at a specified location with a 10 | distribution given by summary statistics (in pickle format). 11 | 12 | The FID is calculated by assuming that X_1 and X_2 are the activations of 13 | the pool_3 layer of the inception net for generated samples and real world 14 | samples respectively. 15 | 16 | See --help to see further details. 17 | 18 | Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead 19 | of Tensorflow 20 | 21 | Copyright 2018 Institute of Bioinformatics, JKU Linz 22 | 23 | Licensed under the Apache License, Version 2.0 (the "License"); 24 | you may not use this file except in compliance with the License. 25 | You may obtain a copy of the License at 26 | 27 | http://www.apache.org/licenses/LICENSE-2.0 28 | 29 | Unless required by applicable law or agreed to in writing, software 30 | distributed under the License is distributed on an "AS IS" BASIS, 31 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 32 | See the License for the specific language governing permissions and 33 | limitations under the License. 34 | 35 | 36 | Modified by Martin Engelcke. 37 | """ 38 | 39 | import os 40 | import pathlib 41 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 42 | 43 | import numpy as np 44 | import torch 45 | from scipy import linalg 46 | from imageio import imread 47 | from torch.nn.functional import adaptive_avg_pool2d 48 | 49 | try: 50 | from tqdm import tqdm 51 | except ImportError: 52 | # If not tqdm is not available, provide a mock version of it 53 | def tqdm(x): return x 54 | 55 | from third_party.pytorch_fid.inception import InceptionV3 56 | 57 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 58 | parser.add_argument('path', type=str, nargs=2, 59 | help=('Path to the generated images or ' 60 | 'to .npz statistic files')) 61 | parser.add_argument('--batch-size', type=int, default=50, 62 | help='Batch size to use') 63 | parser.add_argument('--dims', type=int, default=2048, 64 | choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), 65 | help=('Dimensionality of Inception features to use. 
' 66 | 'By default, uses pool3 features')) 67 | parser.add_argument('-c', '--gpu', default='', type=str, 68 | help='GPU to use (leave blank for CPU only)') 69 | 70 | 71 | def get_activations(files, model, batch_size=50, dims=2048, 72 | cuda=False, verbose=False): 73 | """Calculates the activations of the pool_3 layer for all images. 74 | 75 | Params: 76 | -- files : List of image files paths 77 | -- model : Instance of inception model 78 | -- batch_size : Batch size of images for the model to process at once. 79 | Make sure that the number of samples is a multiple of 80 | the batch size, otherwise some samples are ignored. This 81 | behavior is retained to match the original FID score 82 | implementation. 83 | -- dims : Dimensionality of features returned by Inception 84 | -- cuda : If set to True, use GPU 85 | -- verbose : If set to True and parameter out_step is given, the number 86 | of calculated batches is reported. 87 | Returns: 88 | -- A numpy array of dimension (num images, dims) that contains the 89 | activations of the given tensor when feeding inception with the 90 | query tensor. 91 | """ 92 | model.eval() 93 | 94 | if len(files) % batch_size != 0: 95 | print(('Warning: number of images is not a multiple of the ' 96 | 'batch size. Some samples are going to be ignored.')) 97 | if batch_size > len(files): 98 | print(('Warning: batch size is bigger than the data size. ' 99 | 'Setting batch size to data size')) 100 | batch_size = len(files) 101 | 102 | n_batches = len(files) // batch_size 103 | n_used_imgs = n_batches * batch_size 104 | 105 | pred_arr = np.empty((n_used_imgs, dims)) 106 | 107 | for i in tqdm(range(n_batches)): 108 | if verbose: 109 | print('\rPropagating batch %d/%d' % (i + 1, n_batches), 110 | end='', flush=True) 111 | start = i * batch_size 112 | end = start + batch_size 113 | 114 | images = np.array([imread(str(f)).astype(np.float32) 115 | for f in files[start:end]]) 116 | 117 | # Reshape to (n_images, 3, height, width) 118 | images = images.transpose((0, 3, 1, 2)) 119 | images /= 255 120 | 121 | batch = torch.from_numpy(images).type(torch.FloatTensor) 122 | if cuda: 123 | batch = batch.cuda() 124 | 125 | pred = model(batch)[0] 126 | 127 | # If model output is not scalar, apply global spatial average pooling. 128 | # This happens if you choose a dimensionality not equal 2048. 129 | if pred.shape[2] != 1 or pred.shape[3] != 1: 130 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 131 | 132 | pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1) 133 | 134 | if verbose: 135 | print(' done') 136 | 137 | return pred_arr 138 | 139 | 140 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 141 | """Numpy implementation of the Frechet Distance. 142 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 143 | and X_2 ~ N(mu_2, C_2) is 144 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 145 | 146 | Stable version by Dougal J. Sutherland. 147 | 148 | Params: 149 | -- mu1 : Numpy array containing the activations of a layer of the 150 | inception net (like returned by the function 'get_predictions') 151 | for generated samples. 152 | -- mu2 : The sample mean over activations, precalculated on an 153 | representative data set. 154 | -- sigma1: The covariance matrix over activations for generated samples. 155 | -- sigma2: The covariance matrix over activations, precalculated on an 156 | representative data set. 157 | 158 | Returns: 159 | -- : The Frechet Distance. 
160 | """ 161 | 162 | mu1 = np.atleast_1d(mu1) 163 | mu2 = np.atleast_1d(mu2) 164 | 165 | sigma1 = np.atleast_2d(sigma1) 166 | sigma2 = np.atleast_2d(sigma2) 167 | 168 | assert mu1.shape == mu2.shape, \ 169 | 'Training and test mean vectors have different lengths' 170 | assert sigma1.shape == sigma2.shape, \ 171 | 'Training and test covariances have different dimensions' 172 | 173 | diff = mu1 - mu2 174 | 175 | # Product might be almost singular 176 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 177 | if not np.isfinite(covmean).all(): 178 | msg = ('fid calculation produces singular product; ' 179 | 'adding %s to diagonal of cov estimates') % eps 180 | print(msg) 181 | offset = np.eye(sigma1.shape[0]) * eps 182 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 183 | 184 | # Numerical error might give slight imaginary component 185 | if np.iscomplexobj(covmean): 186 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 187 | m = np.max(np.abs(covmean.imag)) 188 | raise ValueError('Imaginary component {}'.format(m)) 189 | covmean = covmean.real 190 | 191 | tr_covmean = np.trace(covmean) 192 | 193 | return (diff.dot(diff) + np.trace(sigma1) + 194 | np.trace(sigma2) - 2 * tr_covmean) 195 | 196 | 197 | def calculate_activation_statistics(files, model, batch_size=50, 198 | dims=2048, cuda=False, verbose=False): 199 | """Calculation of the statistics used by the FID. 200 | Params: 201 | -- files : List of image files paths 202 | -- model : Instance of inception model 203 | -- batch_size : The images numpy array is split into batches with 204 | batch size batch_size. A reasonable batch size 205 | depends on the hardware. 206 | -- dims : Dimensionality of features returned by Inception 207 | -- cuda : If set to True, use GPU 208 | -- verbose : If set to True and parameter out_step is given, the 209 | number of calculated batches is reported. 210 | Returns: 211 | -- mu : The mean over samples of the activations of the pool_3 layer of 212 | the inception model. 213 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 214 | the inception model. 
215 | """ 216 | act = get_activations(files, model, batch_size, dims, cuda, verbose) 217 | mu = np.mean(act, axis=0) 218 | sigma = np.cov(act, rowvar=False) 219 | return mu, sigma 220 | 221 | 222 | def _compute_statistics_of_path(path, model, batch_size, dims, cuda): 223 | if path.endswith('.npz'): 224 | f = np.load(path) 225 | m, s = f['mu'][:], f['sigma'][:] 226 | f.close() 227 | else: 228 | path = pathlib.Path(path) 229 | files = list(path.glob('*.jpg')) + list(path.glob('*.png')) 230 | m, s = calculate_activation_statistics(files, model, batch_size, 231 | dims, cuda) 232 | 233 | return m, s 234 | 235 | 236 | def calculate_fid_given_paths(paths, batch_size, cuda, dims): 237 | """Calculates the FID of two paths""" 238 | for p in paths: 239 | if not os.path.exists(p): 240 | raise RuntimeError('Invalid path: %s' % p) 241 | 242 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 243 | 244 | model = InceptionV3([block_idx]) 245 | if cuda: 246 | model.cuda() 247 | 248 | m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, 249 | dims, cuda) 250 | m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, 251 | dims, cuda) 252 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 253 | 254 | return fid_value 255 | 256 | 257 | if __name__ == '__main__': 258 | args = parser.parse_args() 259 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 260 | 261 | fid_value = calculate_fid_given_paths(args.path, 262 | args.batch_size, 263 | args.gpu != '', 264 | args.dims) 265 | print('FID: ', fid_value) 266 | -------------------------------------------------------------------------------- /third_party/shapestacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/applied-ai-lab/genesis/9abf202bbad6fa4a675117fdea0be163e4f16695/third_party/shapestacks/__init__.py -------------------------------------------------------------------------------- /third_party/shapestacks/segmentation_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions to work with segmentation maps (.map files). 3 | """ 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | 8 | 9 | # max. number of labels allowed in a uint8 map 10 | MAX_LABELS = 256 11 | 12 | # max. number of labels in VSEG maps, only labels 0-4 used! 13 | VSEG_LABEL_RESOLUTION = 8 14 | 15 | # label semantics in VSEG maps 16 | # 0 : background 17 | # 1 : stable base of stack 18 | # 2 : object violating global stack stability 19 | # 3 : object above stability violation / first to fall 20 | # 4 : top of stack 21 | 22 | 23 | def load_segmap_as_matrix( 24 | map_path: str, 25 | label_resolution: int = VSEG_LABEL_RESOLUTION): 26 | """ 27 | Loads a .map file and returns a matrix of the label values (uint8 between 0 28 | and 255). 29 | 30 | Args: 31 | map_path: path to the .map file to load 32 | label_resolution: max. number of labels used in the map's encoding, 33 | must be a power of 2 34 | 35 | Returns: 36 | A np.ndarray of the semantic segmentation labels. 
37 | """ 38 | png_map = plt.imread(map_path) 39 | label_bin_size = MAX_LABELS // label_resolution 40 | lbl_map = np.copy(png_map[:, :, 0]) # slice of first image layer 41 | lbl_map = lbl_map / label_bin_size 42 | return lbl_map -------------------------------------------------------------------------------- /third_party/shapestacks/shapestacks_provider.py: -------------------------------------------------------------------------------- 1 | """ 2 | Provides an input_fn for tf.estimator.Estimator to load the images of the real 3 | synthetic simulation recordings of a ShapeStacks dataset. 4 | 5 | Adapted from https://github.com/ogroth/shapestacks 6 | 7 | Modified by Martin Engelcke 8 | """ 9 | 10 | 11 | from __future__ import absolute_import 12 | from __future__ import division 13 | from __future__ import print_function 14 | 15 | import os 16 | # import tensorflow as tf 17 | # import numpy as np 18 | 19 | 20 | # dataset constants 21 | _CHANNELS = 3 # RGB images 22 | _HEIGHT = 224 23 | _WIDTH = 224 24 | _NUM_CLASSES = 2 # stable | unstable 25 | # label semantics: 0 = stable | 1 = unstable 26 | 27 | # data augmentation constants 28 | _CROP_HEIGHT = 196 29 | _CROP_WIDTH = 196 30 | 31 | 32 | # internal dataset creation, file parsing and pre-processing 33 | 34 | def _get_filenames_with_labels(mode, data_dir, split_dir): 35 | """ 36 | Returns all training or test files in the data directory with their 37 | respective labels. 38 | """ 39 | if mode == 'train': 40 | scenario_list_file = os.path.join(split_dir, 'train.txt') 41 | elif mode == 'eval': 42 | scenario_list_file = os.path.join(split_dir, 'eval.txt') 43 | elif mode == 'test': 44 | scenario_list_file = os.path.join(split_dir, 'test.txt') 45 | else: 46 | raise ValueError("Mode %s is not supported!" % mode) 47 | with open(scenario_list_file) as f: 48 | scenario_list = f.read().split('\n') 49 | scenario_list.pop() 50 | 51 | filenames = [] 52 | labels = [] 53 | for i, scenario in enumerate(scenario_list): 54 | if (i+1) % 100 == 0: 55 | print("%s / %s : %s" % (i+1, len(scenario_list), scenario)) 56 | scenario_dir = os.path.join(data_dir, 'recordings', scenario) 57 | if "vcom=0" in scenario and "vpsf=0" in scenario: # stable scenario 58 | label = 0.0 59 | else: # unstable scenario 60 | label = 1.0 61 | for img_file in filter( 62 | lambda f: f.startswith('rgb-') and f.endswith('-mono-0.png'), 63 | os.listdir(scenario_dir)): 64 | filenames.append(os.path.join(scenario_dir, img_file)) 65 | labels.append(label) 66 | 67 | return filenames, labels 68 | 69 | # def _create_dataset(filenames, labels): 70 | # """ 71 | # Creates a dataset from the given filename and label tensors. 72 | # """ 73 | # tf_filenames = tf.constant(filenames) 74 | # tf_labels = tf.constant(labels) 75 | # dataset = tf.data.Dataset.from_tensor_slices((tf_filenames, tf_labels)) 76 | # return dataset 77 | 78 | # def _parse_record(filename, label): 79 | # """ 80 | # Reads the file and returns a (feature, label) pair. 81 | # Image feature values are returned to scale in [0.0, 1.0]. 82 | # """ 83 | # image_string = tf.read_file(filename) 84 | # image_decoded = tf.image.decode_image(image_string, channels=_CHANNELS) 85 | # image_resized = tf.image.resize_image_with_crop_or_pad(image_decoded, _HEIGHT, _WIDTH) 86 | # image_float = tf.cast(image_resized, tf.float32) 87 | # image_float = tf.reshape(image_float, [_HEIGHT, _WIDTH, _CHANNELS]) 88 | # return image_float, label 89 | 90 | # def _augment(feature, label, augment): 91 | # """ 92 | # Applies data augmentation to the features. 
93 | # Augmentaion contains: 94 | # - random cropping and resizing back to _HEIGHT & _WIDTH 95 | # - random LR flip 96 | # - random recoloring 97 | # - clip within [-1, 1] 98 | # """ 99 | 100 | # feature = tf.image.convert_image_dtype(feature, tf.float32, saturate=True) 101 | # convert_factor = 1.0 102 | 103 | # if 'rotate' in augment: 104 | # random_rotation = tf.reshape( 105 | # tf.random_uniform([1], minval=-0.01, maxval=0.01, dtype=tf.float32), 106 | # []) 107 | # feature = tf.contrib.image.rotate( 108 | # feature, random_rotation * 3.1415, interpolation='BILINEAR') 109 | 110 | # if 'convert' in augment: 111 | # feature = tf.multiply(feature, 1.0 / 255.0) 112 | # convert_factor = 255.0 113 | 114 | # if 'crop' in augment: 115 | 116 | # if 'stretch' in augment: 117 | # rand_crop_height = tf.reshape( 118 | # tf.random_uniform( 119 | # [1], minval=_CROP_HEIGHT, maxval=_HEIGHT, dtype=tf.int32), 120 | # []) 121 | # rand_crop_width = tf.reshape( 122 | # tf.random_uniform( 123 | # [1], minval=_CROP_WIDTH, maxval=_WIDTH, dtype=tf.int32), 124 | # []) 125 | # else: 126 | # rand_crop_height = _CROP_HEIGHT 127 | # rand_crop_width = _CROP_WIDTH 128 | 129 | # feature = tf.random_crop( 130 | # value=feature, size=[rand_crop_height, rand_crop_width, _CHANNELS]) 131 | # feature = tf.image.resize_bilinear( 132 | # images=tf.reshape( 133 | # feature, [1, rand_crop_height, rand_crop_width, _CHANNELS]), 134 | # size=[_HEIGHT, _WIDTH]) 135 | 136 | # if 'flip' in augment: 137 | # feature = tf.image.random_flip_left_right( 138 | # tf.reshape(feature, [_HEIGHT, _WIDTH, _CHANNELS])) 139 | 140 | # if 'recolour' in augment: 141 | # feature = tf.image.random_brightness(feature, max_delta=32. / convert_factor) 142 | # feature = tf.image.random_saturation(feature, lower=0.5, upper=1.5) 143 | # feature = tf.image.random_hue(feature, max_delta=0.2) 144 | # feature = tf.image.random_contrast(feature, lower=0.5, upper=1.5) 145 | 146 | # if 'noise' in augment: 147 | # # add gaussian noise 148 | # gaussian_noise = tf.random_normal( 149 | # [_HEIGHT, _WIDTH, _CHANNELS], stddev=4. / convert_factor) 150 | # feature = tf.add(feature, gaussian_noise) 151 | 152 | # if 'clip' in augment: 153 | # if 'convert' in augment: 154 | # # clip to [0,1] 155 | # feature = tf.clip_by_value(feature, 0.0, 1.0) 156 | # else: 157 | # feature = tf.clip_by_value(feature, 0.0, 255.0) 158 | 159 | # if 'center' in augment: 160 | # # center around 0 161 | # feature = tf.subtract(feature, 0.5) 162 | # feature = tf.multiply(feature, 2.0) 163 | 164 | # feature = tf.reshape(feature, [_HEIGHT, _WIDTH, _CHANNELS]) 165 | # return feature, label 166 | 167 | # def _center_data(feature, label, rgb_mean): 168 | # """ 169 | # Subtracts the mean of the respective data split part to center the data. 170 | # rgb_mean is expected to scale in [0.0, 1.0]. 171 | # """ 172 | # feature_centered = feature - tf.reshape(tf.constant(rgb_mean), [1, 1, 3]) 173 | # return feature_centered, label 174 | 175 | 176 | # # public input_fn for dataset iteration 177 | 178 | # def shapestacks_input_fn( 179 | # mode, data_dir, split_name, 180 | # batch_size, num_epochs=1, 181 | # n_prefetch=2, augment=[]): 182 | # """ 183 | # Input_fn to feed a tf.estimator.Estimator with ShapeStacks images. 
184 | 185 | # Args: 186 | # mode: train | eval | test 187 | # data_dir: 188 | # split_name: directory name under data_dir/splits containing train.txt, eval.txt and test.txt 189 | # batch_size: 190 | # num_epochs: 191 | # n_prefetch: number of images to prefetch into RAM 192 | # augment: data augmentations to apply 193 | # 'rotate': randomly rotates the image in plane by +/- 2 degrees 194 | # 'convert': converts input values into [0.0, 1.0] 195 | # 'crop': performs a random quadratic center crop 196 | # 'stretch': performs a random center crop not preserving aspect ratio 197 | # 'flip': applies a random left-right flip 198 | # 'recolour': recolours the image by randomly tuning brightness, saturation, 199 | # hue and contrast 200 | # 'noise': adds Gaussian noise to the image 201 | # 'clip': clips input values to [0.0, 1.0] 202 | # 'center': 203 | # 'subtract_mean': subtracts the RGB mean of the data chunk loaded 204 | # """ 205 | # split_dir = os.path.join(data_dir, 'splits', split_name) 206 | # filenames, labels = _get_filenames_with_labels(mode, data_dir, split_dir) 207 | # rgb_mean_npy = np.load(os.path.join(split_dir, mode + '_bgr_mean.npy'))[[2, 1, 0]] 208 | # dataset = _create_dataset(filenames, labels) 209 | 210 | # # shuffle before providing data 211 | # if mode == 'train': 212 | # dataset = dataset.shuffle(buffer_size=len(filenames)) 213 | 214 | # # parse data from files and apply pre-processing 215 | # dataset = dataset.map(_parse_record) 216 | # if augment != [] and mode == 'train': 217 | # dataset = dataset.map(lambda feature, label: _augment(feature, label, augment)) 218 | # if 'subtract_mean' in augment: 219 | # dataset = dataset.map(lambda feature, label: _center_data(feature, label, rgb_mean_npy)) 220 | 221 | # # prepare batch and epoch cycle 222 | # dataset = dataset.prefetch(n_prefetch * batch_size) 223 | # dataset = dataset.repeat(num_epochs) 224 | # dataset = dataset.batch(batch_size) 225 | 226 | # # set up iterator 227 | # iterator = dataset.make_one_shot_iterator() 228 | # images, labels = iterator.get_next() 229 | # return images, labels -------------------------------------------------------------------------------- /third_party/sylvester/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Rianne van den Berg 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /third_party/sylvester/VAE.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # Adapted from https://github.com/riannevdberg/sylvester-flows 3 | # 4 | # Modified by Martin Engelcke 5 | ################################################################################ 6 | 7 | from attrdict import AttrDict 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.distributions.normal import Normal 12 | 13 | from third_party.sylvester.layers import GatedConv2d, GatedConvTranspose2d 14 | 15 | from modules.blocks import ToVar 16 | 17 | 18 | def build_gc_encoder(cin, cout, stride, cfc, kfc, hn=None, gn=None): 19 | assert len(cin) == len(cout) and len(cin) == len(stride) 20 | layers = [] 21 | for l, (i, o, s) in enumerate(zip(cin, cout, stride)): 22 | layers.append(GatedConv2d(i, o, 5, s, 2, h_norm=hn, g_norm=gn)) 23 | layers.append(GatedConv2d(cout[-1], cfc, kfc, 1, 0)) 24 | return nn.Sequential(*layers) 25 | 26 | 27 | def build_gc_decoder(cin, cout, stride, zdim, kz, hn=None, gn=None): 28 | assert len(cin) == len(cout) and len(cin) == len(stride) 29 | layers = [GatedConvTranspose2d(zdim, cin[0], kz, 1, 0)] 30 | for l, (i, o, s) in enumerate(zip(cin, cout, stride)): 31 | layers.append(GatedConvTranspose2d(i, o, 5, s, 2, s-1, 32 | h_norm=hn, g_norm=gn)) 33 | return nn.Sequential(*layers) 34 | 35 | 36 | class VAE(nn.Module): 37 | """ 38 | The base VAE class containing gated convolutional encoder and decoder. 39 | Can be used as a base class for VAE's with normalizing flows. 40 | """ 41 | 42 | def __init__(self, z_size, input_size, nout, 43 | enc_norm=None, dec_norm=None): 44 | super(VAE, self).__init__() 45 | 46 | # extract model settings from args 47 | self.z_size = z_size 48 | self.input_size = input_size 49 | if nout is not None: 50 | self.nout = nout 51 | else: 52 | self.nout = input_size[0] 53 | self.enc_norm = enc_norm 54 | self.dec_norm = dec_norm 55 | 56 | if self.input_size[1] == 32 and self.input_size[2] == 32: 57 | self.last_kernel_size = 8 58 | strides = [1, 2, 1, 2, 1] 59 | elif self.input_size[1] == 64 and self.input_size[2] == 64: 60 | self.last_kernel_size = 16 61 | strides = [1, 2, 1, 2, 1] 62 | elif self.input_size[1] == 128 and self.input_size[2] == 128: 63 | self.last_kernel_size = 16 64 | strides = [2, 2, 2, 1, 1] 65 | elif self.input_size[1] == 256 and self.input_size[2] == 256: 66 | self.last_kernel_size = 16 67 | strides = [2, 2, 2, 2, 1] 68 | else: 69 | raise ValueError('Invalid input size.') 70 | 71 | self.q_z_nn_output_dim = 256 72 | 73 | # Build encoder 74 | cin = [self.input_size[0], 32, 32, 64, 64] 75 | cout = [32, 32, 64, 64, 64] 76 | self.q_z_nn, self.q_z_mean, self.q_z_var = self.create_encoder( 77 | cin, cout, strides) 78 | 79 | # Build decoder 80 | cin = [64, 64, 32, 32, 32] 81 | cout = [64, 32, 32, 32, 32] 82 | self.p_x_nn, self.p_x_mean = self.create_decoder( 83 | cin, cout, list(reversed(strides))) 84 | 85 | # log-det-jacobian = 0 without flows 86 | self.log_det_j = torch.tensor(0) 87 | 88 | def create_encoder(self, cin, cout, strides): 89 | """ 90 | Helper function to create the elemental blocks for the encoder. 91 | Creates a gated convnet encoder. 92 | the encoder expects data as input of shape: 93 | (batch_size, num_channels, width, height). 
94 | """ 95 | 96 | q_z_nn = build_gc_encoder( 97 | cin, cout, strides, self.q_z_nn_output_dim, self.last_kernel_size, 98 | hn=self.enc_norm, gn=self.enc_norm 99 | ) 100 | q_z_mean = nn.Linear(256, self.z_size) 101 | q_z_var = nn.Sequential( 102 | nn.Linear(256, self.z_size), 103 | ToVar(), 104 | ) 105 | return q_z_nn, q_z_mean, q_z_var 106 | 107 | def create_decoder(self, cin, cout, strides): 108 | """ 109 | Helper function to create the elemental blocks for the decoder. 110 | Creates a gated convnet decoder. 111 | """ 112 | 113 | p_x_nn = build_gc_decoder( 114 | cin, cout, strides, self.z_size, self.last_kernel_size, 115 | hn=self.dec_norm, gn=self.dec_norm 116 | ) 117 | p_x_mean = nn.Conv2d(cout[-1], self.nout, 1, 1, 0) 118 | return p_x_nn, p_x_mean 119 | 120 | def reparameterize(self, mu, var): 121 | """ 122 | Samples z from a multivariate Gaussian with diagonal covariance matrix using the 123 | reparameterization trick. 124 | """ 125 | 126 | q_z = Normal(mu, var.sqrt()) 127 | z = q_z.rsample() 128 | return z, q_z 129 | 130 | def encode(self, x): 131 | """ 132 | Encoder expects following data shapes as input: 133 | shape = (batch_size, num_channels, width, height) 134 | """ 135 | 136 | h = self.q_z_nn(x) 137 | h = h.view(h.size(0), -1) 138 | mean = self.q_z_mean(h) 139 | var = self.q_z_var(h) 140 | 141 | return mean, var 142 | 143 | def decode(self, z): 144 | """ 145 | Decoder outputs reconstructed image in the following shapes: 146 | x_mean.shape = (batch_size, num_channels, width, height) 147 | """ 148 | 149 | z = z.view(z.size(0), self.z_size, 1, 1) 150 | h = self.p_x_nn(z) 151 | x_mean = self.p_x_mean(h) 152 | 153 | return x_mean 154 | 155 | def forward(self, x): 156 | """ 157 | Evaluates the model as a whole, encodes and decodes. Note that the log det jacobian is zero 158 | for a plain VAE (without flows), and z_0 = z_k. 
159 | """ 160 | 161 | # mean and variance of z 162 | z_mu, z_var = self.encode(x) 163 | # sample z 164 | z, q_z = self.reparameterize(z_mu, z_var) 165 | x_mean = self.decode(z) 166 | 167 | stats = AttrDict(x=x_mean, mu=z_mu, sigma=z_var.sqrt(), z=z) 168 | return x_mean, stats 169 | -------------------------------------------------------------------------------- /third_party/sylvester/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/applied-ai-lab/genesis/9abf202bbad6fa4a675117fdea0be163e4f16695/third_party/sylvester/__init__.py -------------------------------------------------------------------------------- /third_party/sylvester/layers.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # Adapted from https://github.com/riannevdberg/sylvester-flows 3 | # 4 | # Modified by Martin Engelcke 5 | ################################################################################ 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class GatedConv2d(nn.Module): 12 | def __init__(self, input_channels, output_channels, kernel_size, stride, 13 | padding, dilation=1, activation=None, 14 | h_norm=None, g_norm=None): 15 | super(GatedConv2d, self).__init__() 16 | # Main 17 | self.activation = activation 18 | self.sigmoid = nn.Sigmoid() 19 | self.conv = nn.Conv2d(input_channels, 2*output_channels, kernel_size, 20 | stride, padding, dilation) 21 | # Normalisation 22 | self.h_norm, self.g_norm = None, None 23 | # - Hiddens 24 | if h_norm == 'in': 25 | self.h_norm = nn.InstanceNorm2d(output_channels, affine=True) 26 | elif h_norm == 'bn': 27 | self.h_norm = nn.BatchNorm2d(output_channels) 28 | elif h_norm is None or h_norm == 'none': 29 | pass 30 | else: 31 | raise ValueError("Normalisation option not recognised.") 32 | # - Gates 33 | if g_norm == 'in': 34 | self.g_norm = nn.InstanceNorm2d(output_channels, affine=True) 35 | elif g_norm == 'bn': 36 | self.g_norm = nn.BatchNorm2d(output_channels) 37 | elif g_norm is None or g_norm == 'none': 38 | pass 39 | else: 40 | raise ValueError("Normalisation option not recognised.") 41 | 42 | def forward(self, x): 43 | h, g = torch.chunk(self.conv(x), 2, dim=1) 44 | # Features 45 | if self.h_norm is not None: 46 | h = self.h_norm(h) 47 | if self.activation is not None: 48 | h = self.activation(h) 49 | # Gates 50 | if self.g_norm is not None: 51 | g = self.g_norm(g) 52 | g = self.sigmoid(g) 53 | # Output 54 | return h * g 55 | 56 | 57 | class GatedConvTranspose2d(nn.Module): 58 | def __init__(self, input_channels, output_channels, kernel_size, stride, 59 | padding, output_padding=0, dilation=1, activation=None, 60 | h_norm=None, g_norm=None): 61 | super(GatedConvTranspose2d, self).__init__() 62 | # Main 63 | self.activation = activation 64 | self.sigmoid = nn.Sigmoid() 65 | self.conv = nn.ConvTranspose2d( 66 | input_channels, 2*output_channels, kernel_size, stride, padding, 67 | output_padding, dilation=dilation) 68 | # Normalisation 69 | # - Hiddens 70 | self.h_norm, self.g_norm = None, None 71 | if h_norm == 'in': 72 | self.h_norm = nn.InstanceNorm2d(output_channels, affine=True) 73 | elif h_norm == 'bn': 74 | self.h_norm = nn.BatchNorm2d(output_channels) 75 | elif h_norm is None or h_norm == 'none': 76 | pass 77 | else: 78 | raise ValueError("Normalisation option not recognised.") 79 | # - Gates 80 | if g_norm == 'in': 81 | self.g_norm = 
nn.InstanceNorm2d(output_channels, affine=True) 82 | elif g_norm == 'bn': 83 | self.g_norm = nn.BatchNorm2d(output_channels) 84 | elif g_norm is None or g_norm == 'none': 85 | pass 86 | else: 87 | raise ValueError("Normalisation option not recognised.") 88 | 89 | def forward(self, x): 90 | h, g = torch.chunk(self.conv(x), 2, dim=1) 91 | # Features 92 | if self.h_norm is not None: 93 | h = self.h_norm(h) 94 | if self.activation is not None: 95 | h = self.activation(h) 96 | # Gates 97 | if self.g_norm is not None: 98 | g = self.g_norm(g) 99 | g = self.sigmoid(g) 100 | # Output 101 | return h * g 102 | -------------------------------------------------------------------------------- /third_party/tf_gqn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/applied-ai-lab/genesis/9abf202bbad6fa4a675117fdea0be163e4f16695/third_party/tf_gqn/__init__.py -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/applied-ai-lab/genesis/9abf202bbad6fa4a675117fdea0be163e4f16695/utils/__init__.py -------------------------------------------------------------------------------- /utils/colour_palette15.json: -------------------------------------------------------------------------------- 1 | { 2 | "palette": 3 | [[204,78,51], 4 | [82,180,78], 5 | [125,102,215], 6 | [151,177,49], 7 | [193,91,184], 8 | [217,149,40], 9 | [121,125,196], 10 | [182,136,63], 11 | [76,170,212], 12 | [208,67,130], 13 | [74,170,134], 14 | [202,86,102], 15 | [117,145,74], 16 | [187,109,153], 17 | [188,120,88]] 18 | } 19 | -------------------------------------------------------------------------------- /utils/geco.py: -------------------------------------------------------------------------------- 1 | # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 
11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | import torch 15 | 16 | 17 | class GECO(): 18 | 19 | def __init__(self, goal, step_size, alpha=0.99, beta_init=1.0, 20 | beta_min=1e-10, speedup=None): 21 | self.err_ema = None 22 | self.goal = goal 23 | self.step_size = step_size 24 | self.alpha = alpha 25 | self.beta = torch.tensor(beta_init) 26 | self.beta_min = torch.tensor(beta_min) 27 | self.beta_max = torch.tensor(1e10) 28 | self.speedup = speedup 29 | 30 | def to_cuda(self): 31 | self.beta = self.beta.cuda() 32 | if self.err_ema is not None: 33 | self.err_ema = self.err_ema.cuda() 34 | 35 | def loss(self, err, kld): 36 | # Compute loss with current beta 37 | loss = err + self.beta * kld 38 | # Update beta without computing / backpropping gradients 39 | with torch.no_grad(): 40 | if self.err_ema is None: 41 | self.err_ema = err 42 | else: 43 | self.err_ema = (1.0-self.alpha)*err + self.alpha*self.err_ema 44 | constraint = (self.goal - self.err_ema) 45 | if self.speedup is not None and constraint.item() > 0: 46 | factor = torch.exp(self.speedup * self.step_size * constraint) 47 | else: 48 | factor = torch.exp(self.step_size * constraint) 49 | self.beta = (factor * self.beta).clamp(self.beta_min, self.beta_max) 50 | # Return loss 51 | return loss 52 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 
11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | import sys 15 | import time 16 | import simplejson as json 17 | import datetime 18 | 19 | import torch 20 | from torch.distributions.kl import kl_divergence 21 | 22 | import tensorflow as tf 23 | 24 | import numpy as np 25 | 26 | from sklearn.metrics import adjusted_rand_score 27 | 28 | from forge.experiment_tools import fprint 29 | 30 | 31 | def len_tfrecords(dataset, sess): 32 | iterator = dataset.make_one_shot_iterator() 33 | frame = iterator.get_next() 34 | total_sz = 0 35 | while True: 36 | try: 37 | _ = sess.run(frame) 38 | total_sz += 1 39 | if total_sz % 1000 == 0: 40 | print(total_sz) 41 | except tf.errors.OutOfRangeError: 42 | return total_sz 43 | 44 | 45 | def np_img_centre_crop(np_img, crop_dim, batch=False): 46 | # np_img: [c, dim1, dim2] if batch == False else [batch_sz, c, dim1, dim2] 47 | shape = np_img.shape 48 | if batch: 49 | s2 = (shape[2]-crop_dim)//2 50 | s3 = (shape[3]-crop_dim)//2 51 | return np_img[:, :, s2:s2+crop_dim, s3:s3+crop_dim] 52 | else: 53 | s1 = (shape[1]-crop_dim)//2 54 | s2 = (shape[2]-crop_dim)//2 55 | return np_img[:, s1:s1+crop_dim, s2:s2+crop_dim] 56 | 57 | 58 | def loader_throughput(loader, num_batches=100, burn_in=5): 59 | assert num_batches > 0 60 | if burn_in is None: 61 | burn_in = num_batches // 10 62 | num_samples = 0 63 | fprint(f"Train loader throughput stats on {num_batches} batches...") 64 | for i, batch in enumerate(loader): 65 | if i == burn_in: 66 | timer = time.time() 67 | if i >= burn_in: 68 | num_samples += batch['input'].size(0) 69 | if i == num_batches + burn_in: 70 | break 71 | dt = time.time() - timer 72 | spb = dt / num_batches 73 | ips = num_samples / dt 74 | fprint(f"{spb:.3f} s/b, {ips:.1f} im/s") 75 | 76 | 77 | def log_scalars(sdict, tag, step, writer): 78 | for key, val in sdict.items(): 79 | writer.add_scalar(f'{tag}/{key}', val, step) 80 | 81 | 82 | def colour_seg_masks(masks, palette='15'): 83 | # NOTE: Maps negative (ignore) labels to black 84 | if masks.dim() == 3: 85 | masks = masks.unsqueeze(1) 86 | assert masks.dim() == 4 87 | assert masks.shape[1] == 1 88 | colours = json.load(open(f'utils/colour_palette{palette}.json')) 89 | img_r = torch.zeros_like(masks) 90 | img_g = torch.zeros_like(masks) 91 | img_b = torch.zeros_like(masks) 92 | for c_idx in range(masks.max().item() + 1): 93 | c_map = masks == c_idx 94 | if c_map.any(): 95 | img_r[c_map] = colours['palette'][c_idx][0] 96 | img_g[c_map] = colours['palette'][c_idx][1] 97 | img_b[c_map] = colours['palette'][c_idx][2] 98 | return torch.cat([img_r, img_g, img_b], dim=1) 99 | 100 | 101 | def average_ari(log_m_k, instances, foreground_only=False): 102 | ari = [] 103 | masks_stacked = torch.stack(log_m_k, dim=4).exp().detach() 104 | masks_split = torch.split(masks_stacked, 1, dim=0) 105 | # Loop over elements in batch 106 | for i, m in enumerate(masks_split): 107 | masks_pred = np.argmax(m.cpu().numpy(), axis=-1).flatten() 108 | masks_gt = instances[i].detach().cpu().numpy().flatten() 109 | if foreground_only: 110 | masks_pred = masks_pred[np.where(masks_gt > 0)] 111 | masks_gt = masks_gt[np.where(masks_gt > 0)] 112 | score = adjusted_rand_score(masks_pred, masks_gt) 113 | ari.append(score) 114 | return sum(ari)/len(ari), ari 115 | 116 | 117 | def dataset_ari(model, data_loader, num_images=300): 118 | 119 | model.eval() 120 | 121 | fprint("Computing ARI on dataset") 122 | ari = [] 123 | ari_fg = [] 124 | model.eval() 125 | for bidx, batch in enumerate(data_loader): 
126 | if next(model.parameters()).is_cuda: 127 | batch['input'] = batch['input'].cuda() 128 | with torch.no_grad(): 129 | _, _, stats, _, _ = model(batch['input']) 130 | 131 | # Return zero if labels or segmentations are not available 132 | if 'instances' not in batch or not hasattr(stats, 'log_m_k'): 133 | return 0., 0., [0], [0] 134 | 135 | _, ari_list = average_ari(stats.log_m_k, batch['instances']) 136 | _, ari_fg_list = average_ari(stats.log_m_k, batch['instances'], True) 137 | ari += ari_list 138 | ari_fg += ari_fg_list 139 | if bidx % 1 == 0: 140 | log_ari = sum(ari)/len(ari) 141 | log_ari_fg = sum(ari_fg)/len(ari_fg) 142 | t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") 143 | fprint(f"{t} | After [{len(ari)} / {num_images}] images: " + 144 | f"ARI {log_ari:.4f}, FG ARI {log_ari_fg:.4f}") 145 | if len(ari) >= num_images: 146 | break 147 | 148 | assert len(ari) == len(ari_fg) 149 | ari = ari[:num_images] 150 | ari_fg = ari_fg[:num_images] 151 | 152 | avg_ari = sum(ari)/len(ari) 153 | avg_ari_fg = sum(ari_fg)/len(ari_fg) 154 | fprint(f"FINAL ARI for {len(ari)} images: {avg_ari:.4f}") 155 | fprint(f"FINAL FG ARI for {len(ari_fg)} images: {avg_ari_fg:.4f}") 156 | 157 | model.train() 158 | 159 | return avg_ari, avg_ari_fg, ari_list, ari_fg_list 160 | 161 | 162 | def iou_binary(mask_A, mask_B): 163 | assert mask_A.shape == mask_B.shape 164 | assert mask_A.dtype == torch.bool 165 | assert mask_B.dtype == torch.bool 166 | intersection = (mask_A * mask_B).sum((1, 2, 3)) 167 | union = (mask_A + mask_B).sum((1, 2, 3)) 168 | # Return -100 if union is zero, else return IOU 169 | return torch.where(union == 0, torch.tensor(-100.0), 170 | intersection.float() / union.float()) 171 | 172 | 173 | def average_segcover(segA, segB, ignore_background=False): 174 | """ 175 | Covering of segA by segB 176 | segA.shape = [batch size, 1, img_dim1, img_dim2] 177 | segB.shape = [batch size, 1, img_dim1, img_dim2] 178 | 179 | Returns a mean covering score and a scaled covering score, where the scaled 180 | score weights each IOU value by the number of pixels of the mask being covered. 181 | 182 | Assumes labels in segA and segB are non-negative integers. 183 | Negative labels will be ignored. 
184 | """ 185 | 186 | assert segA.shape == segB.shape, f"{segA.shape} - {segB.shape}" 187 | assert segA.shape[1] == 1 and segB.shape[1] == 1 188 | bsz = segA.shape[0] 189 | nonignore = (segA >= 0) 190 | 191 | mean_scores = torch.tensor(bsz*[0.0]) 192 | N = torch.tensor(bsz*[0]) 193 | scaled_scores = torch.tensor(bsz*[0.0]) 194 | scaling_sum = torch.tensor(bsz*[0]) 195 | 196 | # Find unique label indices to iterate over 197 | if ignore_background: 198 | iter_segA = torch.unique(segA[segA > 0]).tolist() 199 | else: 200 | iter_segA = torch.unique(segA[segA >= 0]).tolist() 201 | iter_segB = torch.unique(segB[segB >= 0]).tolist() 202 | # Loop over segA 203 | for i in iter_segA: 204 | binaryA = segA == i 205 | if not binaryA.any(): 206 | continue 207 | max_iou = torch.tensor(bsz*[0.0]) 208 | # Loop over segB to find max IOU 209 | for j in iter_segB: 210 | # Do not penalise pixels that are in ignore regions 211 | binaryB = (segB == j) * nonignore 212 | if not binaryB.any(): 213 | continue 214 | iou = iou_binary(binaryA, binaryB) 215 | max_iou = torch.where(iou > max_iou, iou, max_iou) 216 | # Accumulate scores 217 | mean_scores += max_iou 218 | N = torch.where(binaryA.sum((1, 2, 3)) > 0, N+1, N) 219 | scaled_scores += binaryA.sum((1, 2, 3)).float() * max_iou 220 | scaling_sum += binaryA.sum((1, 2, 3)) 221 | 222 | # Compute coverage 223 | mean_sc = mean_scores / torch.max(N, torch.tensor(1)).float() 224 | scaled_sc = scaled_scores / torch.max(scaling_sum, torch.tensor(1)).float() 225 | 226 | # Sanity check 227 | assert (mean_sc >= 0).all() and (mean_sc <= 1).all(), mean_sc 228 | assert (scaled_sc >= 0).all() and (scaled_sc <= 1).all(), scaled_sc 229 | assert (mean_scores[N == 0] == 0).all() 230 | assert (mean_scores[nonignore.sum((1, 2, 3)) == 0] == 0).all() 231 | assert (scaled_scores[N == 0] == 0).all() 232 | assert (scaled_scores[nonignore.sum((1, 2, 3)) == 0] == 0).all() 233 | 234 | # Return mean over batch dimension 235 | return mean_sc.mean(0), scaled_sc.mean(0) 236 | 237 | 238 | def get_kl(z, q_z, p_z, montecarlo): 239 | if isinstance(q_z, list) or isinstance(q_z, tuple): 240 | assert len(q_z) == len(p_z) 241 | kl = [] 242 | for i in range(len(q_z)): 243 | if montecarlo: 244 | assert len(q_z) == len(z) 245 | kl.append(get_mc_kl(z[i], q_z[i], p_z[i])) 246 | else: 247 | kl.append(kl_divergence(q_z[i], p_z[i])) 248 | return kl 249 | elif montecarlo: 250 | return get_mc_kl(z, q_z, p_z) 251 | return kl_divergence(q_z, p_z) 252 | 253 | 254 | def get_mc_kl(z, q_z, p_z): 255 | return q_z.log_prob(z) - p_z.log_prob(z) 256 | 257 | 258 | def check_log_masks(log_m_k): 259 | summed_masks = torch.stack(log_m_k, dim=4).exp().sum(dim=4) 260 | summed_masks = summed_masks.clone().data.cpu().numpy() 261 | flat = summed_masks.flatten() 262 | diff = flat - np.ones_like(flat) 263 | idx = np.argmax(diff) 264 | max_diff = diff[idx] 265 | if max_diff > 1e-3 or np.any(np.isnan(flat)): 266 | print("Max difference: {}".format(max_diff)) 267 | for i, log_m in enumerate(log_m_k): 268 | mask_k = log_m.exp().data.cpu().numpy() 269 | print("Mask value at k={}: {}".format(i, mask_k.flatten()[idx])) 270 | raise ValueError("Masks do not sum to 1.0. Not close enough.") 271 | -------------------------------------------------------------------------------- /utils/plotting.py: -------------------------------------------------------------------------------- 1 | # =========================== A2I Copyright Header =========================== 2 | # 3 | # Copyright (c) 2003-2021 University of Oxford. All rights reserved. 
4 | # Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford 5 | # https://ori.ox.ac.uk/labs/a2i/ 6 | # 7 | # This file is the property of the University of Oxford. 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, is not permitted without an explicit licensing agreement 10 | # (research or commercial). No warranty, explicit or implicit, provided. 11 | # 12 | # =========================== A2I Copyright Header =========================== 13 | 14 | import numpy as np 15 | import matplotlib.pyplot as plt 16 | from matplotlib.colors import NoNorm 17 | 18 | def convert_to_np_im(torch_tensor, batch_idx=0): 19 | return np.moveaxis(torch_tensor.data.numpy()[batch_idx], 0, -1) 20 | 21 | def plot(axes, ax1, ax2, tensor=None, title=None, grey=False, axis=False, 22 | fontsize=4): 23 | if tensor is not None: 24 | im = convert_to_np_im(tensor) 25 | if grey: 26 | im = im[:, :, 0] 27 | axes[ax1, ax2].imshow(im, norm=NoNorm(), cmap='gray') 28 | else: 29 | axes[ax1, ax2].imshow(im) 30 | if not axis: 31 | axes[ax1, ax2].axis('off') 32 | else: 33 | axes[ax1, ax2].set_xticks([]) 34 | axes[ax1, ax2].set_yticks([]) 35 | if title is not None: 36 | axes[ax1, ax2].set_title(title, fontsize=fontsize) 37 | # axes[ax1, ax2].set_aspect('equal') 38 | -------------------------------------------------------------------------------- /utils/shapestacks_urls.txt: -------------------------------------------------------------------------------- 1 | http://shapestacks-file.robots.ox.ac.uk/static/download/v1/shapestacks-mjcf.tar.gz 2 | http://shapestacks-file.robots.ox.ac.uk/static/download/v1/shapestacks-meta.tar.gz 3 | http://shapestacks-file.robots.ox.ac.uk/static/download/v1/shapestacks-rgb.tar.gz 4 | --------------------------------------------------------------------------------
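A minimal usage sketch for the GECO objective defined in utils/geco.py, intended only as illustration: the goal and step_size values and the err/kld tensors below are placeholder assumptions, not values taken from the repository configs.

import torch
from utils.geco import GECO

# Placeholder hyperparameters; the actual values are set in the model configs.
geco = GECO(goal=0.5655, step_size=1e-5)
# geco.to_cuda()  # moves the beta parameter to the GPU when training on CUDA

# Inside a training step: err is the reconstruction error and kld the KL
# divergence for the current batch, both scalar tensors with gradients.
err = torch.tensor(0.7, requires_grad=True)
kld = torch.tensor(25.0, requires_grad=True)
loss = geco.loss(err, kld)  # returns err + beta * kld and updates beta from an EMA of err
loss.backward()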