├── .gitignore ├── README.md ├── data.py ├── experiment_builder.py ├── experiment_config ├── alfa+maml.json ├── alfa+maml_resnet12.json ├── alfa+random_init.json └── alfa+random_init_resnet12.json ├── experiment_scripts ├── alfa+maml.sh ├── alfa+maml_resnet12.sh ├── alfa+random_init.sh └── alfa+random_init_resnet12.sh ├── few_shot_learning_system.py ├── inner_loop_optimizers.py ├── install.sh ├── meta_neural_network_architectures.py ├── train_maml_system.py └── utils ├── __init__.py ├── dataset_tools.py ├── parser_utils.py └── storage.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | *.swp 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ALFA - Meta-Learning with Adaptive Hyperparameters 2 | #### Sungyong Baik, Myungsub Choi, Janghoon Choi, Heewon Kim, Kyoung Mu Lee 3 | 4 | Source code for NeurIPS 2020 paper "Meta-Learning with Adaptive Hyperparameters" (previously titled "Adaptive Learning for Fast Adaptation") 5 | 6 | This repository is the implementation of ALFA. 7 | The code is based off the public code of MAML++, where their reimplementation of MAML is used as the baseline. 8 | 9 | [Paper-arXiv](https://arxiv.org/abs/2011.00209) 10 | 11 | ## Requirements 12 | 13 | - Ubuntu 18.04 14 | - Anaconda3 15 | - Python==3.6 16 | - PyTorch==1.5 17 | - numpy==1.19.1 18 | 19 | To install requirements, first download Anaconda3 and then run the following: 20 | 21 | ```setup 22 | bash install.sh 23 | ``` 24 | 25 | ## Hardware Requirements 26 | - GPU with memory more than 27GB for a single-GPU ResNet12 backbone second-order training. 27 | - The current version does not support a multi-GPU setting. While running on a multi-GPU will not give errors, it will give incorrect results (due to uneven distribution of labels across GPUs). 28 | 29 | ## Datasets 30 | For miniIamgenet, the dataset can be downloaded from the link provided from MAML++ public code. 31 | make a directory named 'datasets' and place the downloaded miniImagnet under the 'datasets' directory. 32 | 33 | 34 | ## Training 35 | 36 | To train the model(s) in the paper, run this command in experiment_scripts folder: 37 | 38 | For single GPU 39 | ```train 40 | bash alfa+maml.sh 0 41 | ``` 42 | where 0 represent GPU_ID. 43 | 44 | 45 | ## Evaluation 46 | 47 | After training is finished, the same command is run to evaluate: 48 | 49 | For single GPU: 50 | ```eval 51 | bash alfa+maml.sh 0 52 | ``` 53 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | from torch.utils.data import Dataset, DataLoader 6 | import tqdm 7 | import concurrent.futures 8 | import pickle 9 | import torch 10 | from torchvision import transforms 11 | from PIL import ImageFile 12 | ImageFile.LOAD_TRUNCATED_IMAGES = True 13 | 14 | from utils.parser_utils import get_args 15 | 16 | 17 | class rotate_image(object): 18 | 19 | def __init__(self, k, channels): 20 | self.k = k 21 | self.channels = channels 22 | 23 | def __call__(self, image): 24 | if self.channels == 1 and len(image.shape) == 3: 25 | image = image[:, :, 0] 26 | image = np.expand_dims(image, axis=2) 27 | 28 | elif self.channels == 1 and len(image.shape) == 4: 29 | image = image[:, :, :, 0] 30 | image = np.expand_dims(image, axis=3) 31 | 32 | image = np.rot90(image, k=self.k).copy() 33 | return image 34 | 35 | 36 | class torch_rotate_image(object): 37 | 38 | def __init__(self, k, channels): 39 | self.k = k 40 | self.channels = channels 41 | 42 | def __call__(self, image): 43 | rotate = transforms.RandomRotation(degrees=self.k * 90) 44 | if image.shape[-1] == 1: 45 | image = image[:, :, 0] 46 | image = Image.fromarray(image) 47 | image = rotate(image) 48 | image = np.array(image) 49 | if len(image.shape) == 2: 50 | image = np.expand_dims(image, axis=2) 51 | return image 52 | 53 | 54 | def augment_image(image, k, channels, augment_bool, args, dataset_name): 55 | transform_train, transform_evaluation = get_transforms_for_dataset(dataset_name=dataset_name, 56 | args=args, k=k) 57 | if len(image.shape) > 3: 58 | images = [item for item in image] 59 | output_images = [] 60 | for image in images: 61 | if augment_bool is True: 62 | for transform_current in transform_train: 63 | image = transform_current(image) 64 | else: 65 | for transform_current in transform_evaluation: 66 | image = transform_current(image) 67 | output_images.append(image) 68 | image = torch.stack(output_images) 69 | else: 70 | if augment_bool is True: 71 | # meanstd transformation 72 | for transform_current in transform_train: 73 | image = transform_current(image) 74 | else: 75 | for transform_current in transform_evaluation: 76 | image = transform_current(image) 77 | return image 78 | 79 | 80 | def get_transforms_for_dataset(dataset_name, args, k): 81 | if "cifar10" in dataset_name or "cifar100" in dataset_name or "FC100" in dataset_name: 82 | transform_train = [ 83 | transforms.ToTensor(), 84 | transforms.Normalize((0.5071, 0.4847, 0.4408), (0.2675, 0.2565, 0.2761))] 85 | 86 | #transforms.RandomCrop(32, padding=4), 87 | #transforms.RandomHorizontalFlip(), 88 | #transforms.ToTensor(), 89 | #transforms.Normalize(args.classification_mean, args.classification_std)] 90 | 91 | transform_evaluate = [ 92 | transforms.ToTensor(), 93 | transforms.Normalize((0.5071, 0.4847, 0.4408), (0.2675, 0.2565, 0.2761))] 94 | #transforms.ToTensor(), 95 | #transforms.Normalize(args.classification_mean, args.classification_std)] 96 | 97 | elif 'omniglot' in dataset_name: 98 | 99 | transform_train = [rotate_image(k=k, channels=args.image_channels), transforms.ToTensor()] 100 | transform_evaluate = [transforms.ToTensor()] 101 | 102 | 103 | elif 'imagenet' in dataset_name: 104 | 105 | transform_train = [transforms.Compose([ 106 | 107 | transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])] 108 | 109 | transform_evaluate = [transforms.Compose([ 110 | 111 | transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])] 112 | 113 | return transform_train, transform_evaluate 114 | 115 | 116 | class FewShotLearningDatasetParallel(Dataset): 117 | def __init__(self, args): 118 | """ 119 | A data provider class inheriting from Pytorch's Dataset class. It takes care of creating task sets for 120 | our few-shot learning model training and evaluation 121 | :param args: Arguments in the form of a Bunch object. Includes all hyperparameters necessary for the 122 | data-provider. For transparency and readability reasons to explicitly set as self.object_name all arguments 123 | required for the data provider, such that the reader knows exactly what is necessary for the data provider/ 124 | """ 125 | self.data_path = args.dataset_path 126 | self.dataset_name = args.dataset_name 127 | self.data_loaded_in_memory = False 128 | self.image_height, self.image_width, self.image_channel = args.image_height, args.image_width, args.image_channels 129 | self.args = args 130 | self.indexes_of_folders_indicating_class = args.indexes_of_folders_indicating_class 131 | self.reverse_channels = args.reverse_channels 132 | self.labels_as_int = args.labels_as_int 133 | self.train_val_test_split = args.train_val_test_split 134 | self.current_set_name = "train" 135 | self.num_target_samples = args.num_target_samples 136 | self.reset_stored_filepaths = args.reset_stored_filepaths 137 | val_rng = np.random.RandomState(seed=args.val_seed) 138 | val_seed = val_rng.randint(1, 999999) 139 | train_rng = np.random.RandomState(seed=args.train_seed) 140 | train_seed = train_rng.randint(1, 999999) 141 | test_rng = np.random.RandomState(seed=args.val_seed) 142 | test_seed = test_rng.randint(1, 999999) 143 | args.val_seed = val_seed 144 | args.train_seed = train_seed 145 | args.test_seed = test_seed 146 | self.init_seed = {"train": args.train_seed, "val": args.val_seed, 'test': args.val_seed} 147 | self.seed = {"train": args.train_seed, "val": args.val_seed, 'test': args.val_seed} 148 | self.num_of_gpus = args.num_of_gpus 149 | self.batch_size = args.batch_size 150 | 151 | self.train_index = 0 152 | self.val_index = 0 153 | self.test_index = 0 154 | 155 | self.augment_images = False 156 | self.num_samples_per_class = args.num_samples_per_class 157 | self.num_classes_per_set = args.num_classes_per_set 158 | 159 | self.rng = np.random.RandomState(seed=self.seed['val']) 160 | self.datasets = self.load_dataset() 161 | 162 | self.indexes = {"train": 0, "val": 0, 'test': 0} 163 | self.dataset_size_dict = { 164 | "train": {key: len(self.datasets['train'][key]) for key in list(self.datasets['train'].keys())}, 165 | "val": {key: len(self.datasets['val'][key]) for key in list(self.datasets['val'].keys())}, 166 | 'test': {key: len(self.datasets['test'][key]) for key in list(self.datasets['test'].keys())}} 167 | self.label_set = self.get_label_set() 168 | self.data_length = {name: np.sum([len(self.datasets[name][key]) 169 | for key in self.datasets[name]]) for name in self.datasets.keys()} 170 | 171 | print("data", self.data_length) 172 | self.observed_seed_set = None 173 | 174 | def load_dataset(self): 175 | """ 176 | Loads a dataset's dictionary files and splits the data according to the train_val_test_split variable stored 177 | in the args object. 178 | :return: Three sets, the training set, validation set and test sets (referred to as the meta-train, 179 | meta-val and meta-test in the paper) 180 | """ 181 | rng = np.random.RandomState(seed=self.seed['val']) 182 | 183 | if self.args.sets_are_pre_split == True: 184 | data_image_paths, index_to_label_name_dict_file, label_to_index = self.load_datapaths() 185 | dataset_splits = dict() 186 | for key, value in data_image_paths.items(): 187 | key = self.get_label_from_index(index=key) 188 | bits = key.split("/") 189 | set_name = bits[0] 190 | class_label = bits[1] 191 | if set_name not in dataset_splits: 192 | dataset_splits[set_name] = {class_label: value} 193 | else: 194 | dataset_splits[set_name][class_label] = value 195 | else: 196 | data_image_paths, index_to_label_name_dict_file, label_to_index = self.load_datapaths() 197 | total_label_types = len(data_image_paths) 198 | num_classes_idx = np.arange(len(data_image_paths.keys()), dtype=np.int32) 199 | rng.shuffle(num_classes_idx) 200 | keys = list(data_image_paths.keys()) 201 | values = list(data_image_paths.values()) 202 | new_keys = [keys[idx] for idx in num_classes_idx] 203 | new_values = [values[idx] for idx in num_classes_idx] 204 | data_image_paths = dict(zip(new_keys, new_values)) 205 | # data_image_paths = self.shuffle(data_image_paths) 206 | x_train_id, x_val_id, x_test_id = int(self.train_val_test_split[0] * total_label_types), \ 207 | int(np.sum(self.train_val_test_split[:2]) * total_label_types), \ 208 | int(total_label_types) 209 | print(x_train_id, x_val_id, x_test_id) 210 | x_train_classes = (class_key for class_key in list(data_image_paths.keys())[:x_train_id]) 211 | x_val_classes = (class_key for class_key in list(data_image_paths.keys())[x_train_id:x_val_id]) 212 | x_test_classes = (class_key for class_key in list(data_image_paths.keys())[x_val_id:x_test_id]) 213 | x_train, x_val, x_test = {class_key: data_image_paths[class_key] for class_key in x_train_classes}, \ 214 | {class_key: data_image_paths[class_key] for class_key in x_val_classes}, \ 215 | {class_key: data_image_paths[class_key] for class_key in x_test_classes}, 216 | dataset_splits = {"train": x_train, "val":x_val , "test": x_test} 217 | 218 | if self.args.load_into_memory is True: 219 | 220 | print("Loading data into RAM") 221 | x_loaded = {"train": [], "val": [], "test": []} 222 | 223 | for set_key, set_value in dataset_splits.items(): 224 | print("Currently loading into memory the {} set".format(set_key)) 225 | x_loaded[set_key] = {key: np.zeros(len(value), ) for key, value in set_value.items()} 226 | # for class_key, class_value in set_value.items(): 227 | with tqdm.tqdm(total=len(set_value)) as pbar_memory_load: 228 | with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor: 229 | # Process the list of files, but split the work across the process pool to use all CPUs! 230 | for (class_label, class_images_loaded) in executor.map(self.load_parallel_batch, (set_value.items())): 231 | x_loaded[set_key][class_label] = class_images_loaded 232 | pbar_memory_load.update(1) 233 | 234 | dataset_splits = x_loaded 235 | self.data_loaded_in_memory = True 236 | 237 | return dataset_splits 238 | 239 | def load_datapaths(self): 240 | """ 241 | If saved json dictionaries of the data are available, then this method loads the dictionaries such that the 242 | data is ready to be read. If the json dictionaries do not exist, then this method calls get_data_paths() 243 | which will build the json dictionary containing the class to filepath samples, and then store them. 244 | :return: data_image_paths: dict containing class to filepath list pairs. 245 | index_to_label_name_dict_file: dict containing numerical indexes mapped to the human understandable 246 | string-names of the class 247 | label_to_index: dictionary containing human understandable string mapped to numerical indexes 248 | """ 249 | dataset_dir = os.environ['DATASET_DIR'] 250 | data_path_file = "{}/{}.json".format(dataset_dir, self.dataset_name) 251 | self.index_to_label_name_dict_file = "{}/map_to_label_name_{}.json".format(dataset_dir, self.dataset_name) 252 | self.label_name_to_map_dict_file = "{}/label_name_to_map_{}.json".format(dataset_dir, self.dataset_name) 253 | 254 | if not os.path.exists(data_path_file): 255 | self.reset_stored_filepaths = True 256 | 257 | if self.reset_stored_filepaths == True: 258 | if os.path.exists(data_path_file): 259 | os.remove(data_path_file) 260 | self.reset_stored_filepaths = False 261 | 262 | try: 263 | data_image_paths = self.load_from_json(filename=data_path_file) 264 | label_to_index = self.load_from_json(filename=self.label_name_to_map_dict_file) 265 | index_to_label_name_dict_file = self.load_from_json(filename=self.index_to_label_name_dict_file) 266 | return data_image_paths, index_to_label_name_dict_file, label_to_index 267 | except: 268 | print("Mapped data paths can't be found, remapping paths..") 269 | data_image_paths, code_to_label_name, label_name_to_code = self.get_data_paths() 270 | self.save_to_json(dict_to_store=data_image_paths, filename=data_path_file) 271 | self.save_to_json(dict_to_store=code_to_label_name, filename=self.index_to_label_name_dict_file) 272 | self.save_to_json(dict_to_store=label_name_to_code, filename=self.label_name_to_map_dict_file) 273 | return self.load_datapaths() 274 | 275 | def save_to_json(self, filename, dict_to_store): 276 | with open(os.path.abspath(filename), 'w') as f: 277 | json.dump(dict_to_store, fp=f) 278 | 279 | def load_from_json(self, filename): 280 | with open(filename, mode="r") as f: 281 | load_dict = json.load(fp=f) 282 | 283 | return load_dict 284 | 285 | def load_test_image(self, filepath): 286 | """ 287 | Tests whether a target filepath contains an uncorrupted image. If image is corrupted, attempt to fix. 288 | :param filepath: Filepath of image to be tested 289 | :return: Return filepath of image if image exists and is uncorrupted (or attempt to fix has succeeded), 290 | else return None 291 | """ 292 | image = None 293 | try: 294 | image = Image.open(filepath) 295 | except RuntimeWarning: 296 | os.system("convert {} -strip {}".format(filepath, filepath)) 297 | print("converting") 298 | image = Image.open(filepath) 299 | except: 300 | print("Broken image") 301 | 302 | if image is not None: 303 | return filepath 304 | else: 305 | return None 306 | 307 | def get_data_paths(self): 308 | """ 309 | Method that scans the dataset directory and generates class to image-filepath list dictionaries. 310 | :return: data_image_paths: dict containing class to filepath list pairs. 311 | index_to_label_name_dict_file: dict containing numerical indexes mapped to the human understandable 312 | string-names of the class 313 | label_to_index: dictionary containing human understandable string mapped to numerical indexes 314 | """ 315 | print("Get images from", self.data_path) 316 | data_image_path_list_raw = [] 317 | labels = set() 318 | for subdir, dir, files in os.walk(self.data_path): 319 | for file in files: 320 | if (".jpeg") in file.lower() or (".png") in file.lower() or (".jpg") in file.lower(): 321 | filepath = os.path.abspath(os.path.join(subdir, file)) 322 | label = self.get_label_from_path(filepath) 323 | data_image_path_list_raw.append(filepath) 324 | labels.add(label) 325 | 326 | labels = sorted(labels) 327 | idx_to_label_name = {idx: label for idx, label in enumerate(labels)} 328 | label_name_to_idx = {label: idx for idx, label in enumerate(labels)} 329 | data_image_path_dict = {idx: [] for idx in list(idx_to_label_name.keys())} 330 | with tqdm.tqdm(total=len(data_image_path_list_raw)) as pbar_error: 331 | with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor: 332 | # Process the list of files, but split the work across the process pool to use all CPUs! 333 | for image_file in executor.map(self.load_test_image, (data_image_path_list_raw)): 334 | pbar_error.update(1) 335 | if image_file is not None: 336 | label = self.get_label_from_path(image_file) 337 | data_image_path_dict[label_name_to_idx[label]].append(image_file) 338 | 339 | return data_image_path_dict, idx_to_label_name, label_name_to_idx 340 | 341 | def get_label_set(self): 342 | """ 343 | Generates a set containing all class numerical indexes 344 | :return: A set containing all class numerical indexes 345 | """ 346 | index_to_label_name_dict_file = self.load_from_json(filename=self.index_to_label_name_dict_file) 347 | return set(list(index_to_label_name_dict_file.keys())) 348 | 349 | def get_index_from_label(self, label): 350 | """ 351 | Given a class's (human understandable) string, returns the numerical index of that class 352 | :param label: A string of a human understandable class contained in the dataset 353 | :return: An int containing the numerical index of the given class-string 354 | """ 355 | label_to_index = self.load_from_json(filename=self.label_name_to_map_dict_file) 356 | return label_to_index[label] 357 | 358 | def get_label_from_index(self, index): 359 | """ 360 | Given an index return the human understandable label mapping to it. 361 | :param index: A numerical index (int) 362 | :return: A human understandable label (str) 363 | """ 364 | index_to_label_name = self.load_from_json(filename=self.index_to_label_name_dict_file) 365 | return index_to_label_name[index] 366 | 367 | def get_label_from_path(self, filepath): 368 | """ 369 | Given a path of an image generate the human understandable label for that image. 370 | :param filepath: The image's filepath 371 | :return: A human understandable label. 372 | """ 373 | label_bits = filepath.split("/") 374 | label = "/".join([label_bits[idx] for idx in self.indexes_of_folders_indicating_class]) 375 | if self.labels_as_int: 376 | label = int(label) 377 | return label 378 | 379 | def load_image(self, image_path, channels): 380 | """ 381 | Given an image filepath and the number of channels to keep, load an image and keep the specified channels 382 | :param image_path: The image's filepath 383 | :param channels: The number of channels to keep 384 | :return: An image array of shape (h, w, channels), whose values range between 0.0 and 1.0. 385 | """ 386 | if not self.data_loaded_in_memory: 387 | image = Image.open(image_path) 388 | if 'omniglot' in self.dataset_name: 389 | image = image.resize((self.image_height, self.image_width), resample=Image.LANCZOS) 390 | image = np.array(image, np.float32) 391 | if channels == 1: 392 | image = np.expand_dims(image, axis=2) 393 | else: 394 | image = image.resize((self.image_height, self.image_width)).convert('RGB') 395 | image = np.array(image, np.float32) 396 | image = image / 255.0 397 | else: 398 | image = image_path 399 | 400 | return image 401 | 402 | def load_batch(self, batch_image_paths): 403 | """ 404 | Load a batch of images, given a list of filepaths 405 | :param batch_image_paths: A list of filepaths 406 | :return: A numpy array of images of shape batch, height, width, channels 407 | """ 408 | image_batch = [] 409 | 410 | if self.data_loaded_in_memory: 411 | for image_path in batch_image_paths: 412 | image_batch.append(image_path) 413 | image_batch = np.array(image_batch, dtype=np.float32) 414 | #print(image_batch.shape) 415 | else: 416 | image_batch = [self.load_image(image_path=image_path, channels=self.image_channel) 417 | for image_path in batch_image_paths] 418 | image_batch = np.array(image_batch, dtype=np.float32) 419 | image_batch = self.preprocess_data(image_batch) 420 | 421 | return image_batch 422 | 423 | def load_parallel_batch(self, inputs): 424 | """ 425 | Load a batch of images, given a list of filepaths 426 | :param batch_image_paths: A list of filepaths 427 | :return: A numpy array of images of shape batch, height, width, channels 428 | """ 429 | class_label, batch_image_paths = inputs 430 | image_batch = [] 431 | 432 | if self.data_loaded_in_memory: 433 | for image_path in batch_image_paths: 434 | image_batch.append(np.copy(image_path)) 435 | image_batch = np.array(image_batch, dtype=np.float32) 436 | else: 437 | #with tqdm.tqdm(total=1) as load_pbar: 438 | image_batch = [self.load_image(image_path=image_path, channels=self.image_channel) 439 | for image_path in batch_image_paths] 440 | #load_pbar.update(1) 441 | 442 | image_batch = np.array(image_batch, dtype=np.float32) 443 | image_batch = self.preprocess_data(image_batch) 444 | 445 | return class_label, image_batch 446 | 447 | def preprocess_data(self, x): 448 | """ 449 | Preprocesses data such that their shapes match the specified structures 450 | :param x: A data batch to preprocess 451 | :return: A preprocessed data batch 452 | """ 453 | x_shape = x.shape 454 | x = np.reshape(x, (-1, x_shape[-3], x_shape[-2], x_shape[-1])) 455 | if self.reverse_channels is True: 456 | reverse_photos = np.ones(shape=x.shape) 457 | for channel in range(x.shape[-1]): 458 | reverse_photos[:, :, :, x.shape[-1] - 1 - channel] = x[:, :, :, channel] 459 | x = reverse_photos 460 | x = x.reshape(x_shape) 461 | return x 462 | 463 | def reconstruct_original(self, x): 464 | """ 465 | Applies the reverse operations that preprocess_data() applies such that the data returns to their original form 466 | :param x: A batch of data to reconstruct 467 | :return: A reconstructed batch of data 468 | """ 469 | x = x * 255.0 470 | return x 471 | 472 | def shuffle(self, x, rng): 473 | """ 474 | Shuffles the data batch along it's first axis 475 | :param x: A data batch 476 | :return: A shuffled data batch 477 | """ 478 | indices = np.arange(len(x)) 479 | rng.shuffle(indices) 480 | x = x[indices] 481 | return x 482 | 483 | def get_set(self, dataset_name, seed, augment_images=False): 484 | """ 485 | Generates a task-set to be used for training or evaluation 486 | :param set_name: The name of the set to use, e.g. "train", "val" etc. 487 | :return: A task-set containing an image and label support set, and an image and label target set. 488 | """ 489 | #seed = seed % self.args.total_unique_tasks 490 | rng = np.random.RandomState(seed) 491 | selected_classes = rng.choice(list(self.dataset_size_dict[dataset_name].keys()), 492 | size=self.num_classes_per_set, replace=False) 493 | rng.shuffle(selected_classes) 494 | k_list = rng.randint(0, 4, size=self.num_classes_per_set) 495 | k_dict = {selected_class: k_item for (selected_class, k_item) in zip(selected_classes, k_list)} 496 | episode_labels = [i for i in range(self.num_classes_per_set)] 497 | class_to_episode_label = {selected_class: episode_label for (selected_class, episode_label) in 498 | zip(selected_classes, episode_labels)} 499 | 500 | x_images = [] 501 | y_labels = [] 502 | 503 | for class_entry in selected_classes: 504 | choose_samples_list = rng.choice(self.dataset_size_dict[dataset_name][class_entry], 505 | size=self.num_samples_per_class + self.num_target_samples, replace=False) 506 | class_image_samples = [] 507 | class_labels = [] 508 | for sample in choose_samples_list: 509 | choose_samples = self.datasets[dataset_name][class_entry][sample] 510 | x_class_data = self.load_batch([choose_samples])[0] 511 | k = k_dict[class_entry] 512 | x_class_data = augment_image(image=x_class_data, k=k, 513 | channels=self.image_channel, augment_bool=augment_images, 514 | dataset_name=self.dataset_name, args=self.args) 515 | class_image_samples.append(x_class_data) 516 | class_labels.append(int(class_to_episode_label[class_entry])) 517 | class_image_samples = torch.stack(class_image_samples) 518 | x_images.append(class_image_samples) 519 | y_labels.append(class_labels) 520 | 521 | x_images = torch.stack(x_images) 522 | y_labels = np.array(y_labels, dtype=np.float32) 523 | 524 | support_set_images = x_images[:, :self.num_samples_per_class] 525 | support_set_labels = y_labels[:, :self.num_samples_per_class] 526 | target_set_images = x_images[:, self.num_samples_per_class:] 527 | target_set_labels = y_labels[:, self.num_samples_per_class:] 528 | 529 | return support_set_images, target_set_images, support_set_labels, target_set_labels, seed 530 | 531 | def __len__(self): 532 | total_samples = self.data_length[self.current_set_name] 533 | return total_samples 534 | 535 | def length(self, set_name): 536 | self.switch_set(set_name=set_name) 537 | return len(self) 538 | 539 | def set_augmentation(self, augment_images): 540 | self.augment_images = augment_images 541 | 542 | def switch_set(self, set_name, current_iter=None): 543 | self.current_set_name = set_name 544 | if set_name == "train": 545 | self.update_seed(dataset_name=set_name, seed=self.init_seed[set_name] + current_iter) 546 | 547 | def update_seed(self, dataset_name, seed=100): 548 | self.seed[dataset_name] = seed 549 | 550 | def __getitem__(self, idx): 551 | support_set_images, target_set_image, support_set_labels, target_set_label, seed = \ 552 | self.get_set(self.current_set_name, seed=self.seed[self.current_set_name] + idx, 553 | augment_images=self.augment_images) 554 | 555 | return support_set_images, target_set_image, support_set_labels, target_set_label, seed 556 | 557 | def reset_seed(self): 558 | self.seed = self.init_seed 559 | 560 | 561 | class MetaLearningSystemDataLoader(object): 562 | def __init__(self, args, current_iter=0): 563 | """ 564 | Initializes a meta learning system dataloader. The data loader uses the Pytorch DataLoader class to parallelize 565 | batch sampling and preprocessing. 566 | :param args: An arguments NamedTuple containing all the required arguments. 567 | :param current_iter: Current iter of experiment. Is used to make sure the data loader continues where it left 568 | of previously. 569 | """ 570 | self.num_of_gpus = args.num_of_gpus 571 | self.batch_size = args.batch_size 572 | self.samples_per_iter = args.samples_per_iter 573 | self.num_workers = args.num_dataprovider_workers 574 | self.total_train_iters_produced = 0 575 | self.dataset = FewShotLearningDatasetParallel(args=args) 576 | self.batches_per_iter = args.samples_per_iter 577 | self.full_data_length = self.dataset.data_length 578 | self.continue_from_iter(current_iter=current_iter) 579 | self.args = args 580 | 581 | def get_dataloader(self): 582 | """ 583 | Returns a data loader with the correct set (train, val or test), continuing from the current iter. 584 | :return: 585 | """ 586 | return DataLoader(self.dataset, batch_size=(self.num_of_gpus * self.batch_size * self.samples_per_iter), 587 | shuffle=False, num_workers=self.num_workers, drop_last=True) 588 | 589 | def continue_from_iter(self, current_iter): 590 | """ 591 | Makes sure the data provider is aware of where we are in terms of training iterations in the experiment. 592 | :param current_iter: 593 | """ 594 | self.total_train_iters_produced += (current_iter * (self.num_of_gpus * self.batch_size * self.samples_per_iter)) 595 | 596 | def get_train_batches(self, total_batches=-1, augment_images=False): 597 | """ 598 | Returns a training batches data_loader 599 | :param total_batches: The number of batches we want the data loader to sample 600 | :param augment_images: Whether we want the images to be augmented. 601 | """ 602 | if total_batches == -1: 603 | self.dataset.data_length = self.full_data_length 604 | else: 605 | self.dataset.data_length["train"] = total_batches * self.dataset.batch_size 606 | self.dataset.switch_set(set_name="train", current_iter=self.total_train_iters_produced) 607 | self.dataset.set_augmentation(augment_images=augment_images) 608 | self.total_train_iters_produced += (self.num_of_gpus * self.batch_size * self.samples_per_iter) 609 | for sample_id, sample_batched in enumerate(self.get_dataloader()): 610 | yield sample_batched 611 | 612 | 613 | def get_val_batches(self, total_batches=-1, augment_images=False): 614 | """ 615 | Returns a validation batches data_loader 616 | :param total_batches: The number of batches we want the data loader to sample 617 | :param augment_images: Whether we want the images to be augmented. 618 | """ 619 | if total_batches == -1: 620 | self.dataset.data_length = self.full_data_length 621 | else: 622 | self.dataset.data_length['val'] = total_batches * self.dataset.batch_size 623 | self.dataset.switch_set(set_name="val") 624 | self.dataset.set_augmentation(augment_images=augment_images) 625 | for sample_id, sample_batched in enumerate(self.get_dataloader()): 626 | yield sample_batched 627 | 628 | 629 | def get_test_batches(self, total_batches=-1, augment_images=False): 630 | """ 631 | Returns a testing batches data_loader 632 | :param total_batches: The number of batches we want the data loader to sample 633 | :param augment_images: Whether we want the images to be augmented. 634 | """ 635 | if total_batches == -1: 636 | self.dataset.data_length = self.full_data_length 637 | else: 638 | self.dataset.data_length['test'] = total_batches * self.dataset.batch_size 639 | self.dataset.switch_set(set_name='test') 640 | self.dataset.set_augmentation(augment_images=augment_images) 641 | for sample_id, sample_batched in enumerate(self.get_dataloader()): 642 | yield sample_batched 643 | 644 | -------------------------------------------------------------------------------- /experiment_builder.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import os 3 | import numpy as np 4 | import sys 5 | from utils.storage import build_experiment_folder, save_statistics, save_to_json 6 | import time 7 | import torch 8 | 9 | 10 | class ExperimentBuilder(object): 11 | def __init__(self, args, data, model, device): 12 | """ 13 | Initializes an experiment builder using a named tuple (args), a data provider (data), a meta learning system 14 | (model) and a device (e.g. gpu/cpu/n) 15 | :param args: A namedtuple containing all experiment hyperparameters 16 | :param data: A data provider of instance MetaLearningSystemDataLoader 17 | :param model: A meta learning system instance 18 | :param device: Device/s to use for the experiment 19 | """ 20 | self.args, self.device = args, device 21 | 22 | self.model = model 23 | self.saved_models_filepath, self.logs_filepath, self.samples_filepath = build_experiment_folder( 24 | experiment_name=self.args.experiment_name) 25 | 26 | self.total_losses = dict() 27 | self.state = dict() 28 | self.state['best_val_acc'] = 0. 29 | self.state['best_val_iter'] = 0 30 | self.state['current_iter'] = 0 31 | self.state['current_iter'] = 0 32 | self.start_epoch = 0 33 | self.max_models_to_save = self.args.max_models_to_save 34 | self.create_summary_csv = False 35 | 36 | experiment_path = os.path.abspath(self.args.experiment_name) 37 | exp_name = experiment_path.split('/')[-1] 38 | log_base_dir = 'logs' 39 | os.makedirs(log_base_dir, exist_ok=True) 40 | 41 | log_dir = os.path.join(log_base_dir, exp_name) 42 | print(log_dir) 43 | 44 | if self.args.continue_from_epoch == 'from_scratch': 45 | self.create_summary_csv = True 46 | 47 | elif self.args.continue_from_epoch == 'latest': 48 | checkpoint = os.path.join(self.saved_models_filepath, "train_model_latest") 49 | print("attempting to find existing checkpoint", ) 50 | if os.path.exists(checkpoint): 51 | self.state = \ 52 | self.model.load_model(model_save_dir=self.saved_models_filepath, model_name="train_model", 53 | model_idx='latest') 54 | self.start_epoch = int(self.state['current_iter'] / self.args.total_iter_per_epoch) 55 | 56 | else: 57 | self.args.continue_from_epoch = 'from_scratch' 58 | self.create_summary_csv = True 59 | elif int(self.args.continue_from_epoch) >= 0: 60 | self.state = \ 61 | self.model.load_model(model_save_dir=self.saved_models_filepath, model_name="train_model", 62 | model_idx=self.args.continue_from_epoch) 63 | self.start_epoch = int(self.state['current_iter'] / self.args.total_iter_per_epoch) 64 | 65 | self.data = data(args=args, current_iter=self.state['current_iter']) 66 | 67 | print("train_seed {}, val_seed: {}, at start time".format(self.data.dataset.seed["train"], 68 | self.data.dataset.seed["val"])) 69 | self.total_epochs_before_pause = self.args.total_epochs_before_pause 70 | self.state['best_epoch'] = int(self.state['best_val_iter'] / self.args.total_iter_per_epoch) 71 | self.epoch = int(self.state['current_iter'] / self.args.total_iter_per_epoch) 72 | self.augment_flag = True if 'omniglot' in self.args.dataset_name.lower() else False 73 | self.start_time = time.time() 74 | self.epochs_done_in_this_run = 0 75 | print(self.state['current_iter'], int(self.args.total_iter_per_epoch * self.args.total_epochs)) 76 | 77 | def build_summary_dict(self, total_losses, phase, summary_losses=None): 78 | """ 79 | Builds/Updates a summary dict directly from the metric dict of the current iteration. 80 | :param total_losses: Current dict with total losses (not aggregations) from experiment 81 | :param phase: Current training phase 82 | :param summary_losses: Current summarised (aggregated/summarised) losses stats means, stdv etc. 83 | :return: A new summary dict with the updated summary statistics information. 84 | """ 85 | if summary_losses is None: 86 | summary_losses = dict() 87 | 88 | for key in total_losses: 89 | summary_losses["{}_{}_mean".format(phase, key)] = np.mean(total_losses[key]) 90 | summary_losses["{}_{}_std".format(phase, key)] = np.std(total_losses[key]) 91 | 92 | return summary_losses 93 | 94 | def build_loss_summary_string(self, summary_losses): 95 | """ 96 | Builds a progress bar summary string given current summary losses dictionary 97 | :param summary_losses: Current summary statistics 98 | :return: A summary string ready to be shown to humans. 99 | """ 100 | output_update = "" 101 | for key, value in zip(list(summary_losses.keys()), list(summary_losses.values())): 102 | if "loss" in key or "accuracy" in key: 103 | value = float(value) 104 | output_update += "{}: {:.4f}, ".format(key, value) 105 | 106 | return output_update 107 | 108 | def merge_two_dicts(self, first_dict, second_dict): 109 | """Given two dicts, merge them into a new dict as a shallow copy.""" 110 | z = first_dict.copy() 111 | z.update(second_dict) 112 | return z 113 | 114 | def train_iteration(self, train_sample, sample_idx, epoch_idx, total_losses, current_iter, pbar_train): 115 | """ 116 | Runs a training iteration, updates the progress bar and returns the total and current epoch train losses. 117 | :param train_sample: A sample from the data provider 118 | :param sample_idx: The index of the incoming sample, in relation to the current training run. 119 | :param epoch_idx: The epoch index. 120 | :param total_losses: The current total losses dictionary to be updated. 121 | :param current_iter: The current training iteration in relation to the whole experiment. 122 | :param pbar_train: The progress bar of the training. 123 | :return: Updates total_losses, train_losses, current_iter 124 | """ 125 | x_support_set, x_target_set, y_support_set, y_target_set, seed = train_sample 126 | data_batch = (x_support_set, x_target_set, y_support_set, y_target_set) 127 | 128 | if sample_idx == 0: 129 | print("shape of data", x_support_set.shape, x_target_set.shape, y_support_set.shape, 130 | y_target_set.shape) 131 | 132 | losses, _ = self.model.run_train_iter(data_batch=data_batch, epoch=epoch_idx) 133 | 134 | for key, value in zip(list(losses.keys()), list(losses.values())): 135 | if key not in total_losses: 136 | total_losses[key] = [float(value)] 137 | else: 138 | total_losses[key].append(float(value)) 139 | 140 | train_losses = self.build_summary_dict(total_losses=total_losses, phase="train") 141 | train_output_update = self.build_loss_summary_string(losses) 142 | 143 | pbar_train.update(1) 144 | pbar_train.set_description("training phase {} -> {}".format(self.epoch, train_output_update)) 145 | 146 | current_iter += 1 147 | 148 | return train_losses, total_losses, current_iter 149 | 150 | def evaluation_iteration(self, val_sample, total_losses, pbar_val, phase): 151 | """ 152 | Runs a validation iteration, updates the progress bar and returns the total and current epoch val losses. 153 | :param val_sample: A sample from the data provider 154 | :param total_losses: The current total losses dictionary to be updated. 155 | :param pbar_val: The progress bar of the val stage. 156 | :return: The updated val_losses, total_losses 157 | """ 158 | x_support_set, x_target_set, y_support_set, y_target_set, seed = val_sample 159 | data_batch = ( 160 | x_support_set, x_target_set, y_support_set, y_target_set) 161 | 162 | losses, _ = self.model.run_validation_iter(data_batch=data_batch) 163 | for key, value in zip(list(losses.keys()), list(losses.values())): 164 | if key not in total_losses: 165 | total_losses[key] = [float(value)] 166 | else: 167 | total_losses[key].append(float(value)) 168 | 169 | val_losses = self.build_summary_dict(total_losses=total_losses, phase=phase) 170 | val_output_update = self.build_loss_summary_string(losses) 171 | 172 | pbar_val.update(1) 173 | pbar_val.set_description( 174 | "val_phase {} -> {}".format(self.epoch, val_output_update)) 175 | 176 | return val_losses, total_losses 177 | 178 | def test_evaluation_iteration(self, val_sample, model_idx, sample_idx, per_model_per_batch_preds, pbar_test): 179 | """ 180 | Runs a validation iteration, updates the progress bar and returns the total and current epoch val losses. 181 | :param val_sample: A sample from the data provider 182 | :param total_losses: The current total losses dictionary to be updated. 183 | :param pbar_test: The progress bar of the val stage. 184 | :return: The updated val_losses, total_losses 185 | """ 186 | x_support_set, x_target_set, y_support_set, y_target_set, seed = val_sample 187 | data_batch = ( 188 | x_support_set, x_target_set, y_support_set, y_target_set) 189 | 190 | losses, per_task_preds = self.model.run_validation_iter(data_batch=data_batch) 191 | 192 | per_model_per_batch_preds[model_idx].extend(list(per_task_preds)) 193 | 194 | test_output_update = self.build_loss_summary_string(losses) 195 | 196 | pbar_test.update(1) 197 | pbar_test.set_description( 198 | "test_phase {} -> {}".format(self.epoch, test_output_update)) 199 | 200 | return per_model_per_batch_preds 201 | 202 | def save_models(self, model, epoch, state): 203 | """ 204 | Saves two separate instances of the current model. One to be kept for history and reloading later and another 205 | one marked as "latest" to be used by the system for the next epoch training. Useful when the training/val 206 | process is interrupted or stopped. Leads to fault tolerant training and validation systems that can continue 207 | from where they left off before. 208 | :param model: Current meta learning model of any instance within the few_shot_learning_system.py 209 | :param epoch: Current epoch 210 | :param state: Current model and experiment state dict. 211 | """ 212 | model.save_model(model_save_dir=os.path.join(self.saved_models_filepath, "train_model_{}".format(int(epoch))), 213 | state=state) 214 | 215 | model.save_model(model_save_dir=os.path.join(self.saved_models_filepath, "train_model_latest"), 216 | state=state) 217 | 218 | print("saved models to", self.saved_models_filepath) 219 | 220 | def pack_and_save_metrics(self, start_time, create_summary_csv, train_losses, val_losses, state): 221 | """ 222 | Given current epochs start_time, train losses, val losses and whether to create a new stats csv file, pack stats 223 | and save into a statistics csv file. Return a new start time for the new epoch. 224 | :param start_time: The start time of the current epoch 225 | :param create_summary_csv: A boolean variable indicating whether to create a new statistics file or 226 | append results to existing one 227 | :param train_losses: A dictionary with the current train losses 228 | :param val_losses: A dictionary with the currrent val loss 229 | :return: The current time, to be used for the next epoch. 230 | """ 231 | epoch_summary_losses = self.merge_two_dicts(first_dict=train_losses, second_dict=val_losses) 232 | 233 | if 'per_epoch_statistics' not in state: 234 | state['per_epoch_statistics'] = dict() 235 | 236 | for key, value in epoch_summary_losses.items(): 237 | 238 | if key not in state['per_epoch_statistics']: 239 | state['per_epoch_statistics'][key] = [value] 240 | else: 241 | state['per_epoch_statistics'][key].append(value) 242 | 243 | epoch_summary_string = self.build_loss_summary_string(epoch_summary_losses) 244 | epoch_summary_losses["epoch"] = self.epoch 245 | epoch_summary_losses['epoch_run_time'] = time.time() - start_time 246 | 247 | if create_summary_csv: 248 | self.summary_statistics_filepath = save_statistics(self.logs_filepath, list(epoch_summary_losses.keys()), 249 | create=True) 250 | self.create_summary_csv = False 251 | 252 | start_time = time.time() 253 | print("epoch {} -> {}".format(epoch_summary_losses["epoch"], epoch_summary_string)) 254 | 255 | self.summary_statistics_filepath = save_statistics(self.logs_filepath, 256 | list(epoch_summary_losses.values())) 257 | return start_time, state 258 | 259 | def evaluated_test_set_using_the_best_models(self, top_n_models): 260 | per_epoch_statistics = self.state['per_epoch_statistics'] 261 | val_acc = np.copy(per_epoch_statistics['val_accuracy_mean']) 262 | val_idx = np.array([i for i in range(len(val_acc))]) 263 | sorted_idx = np.argsort(val_acc, axis=0).astype(dtype=np.int32)[::-1][:top_n_models] 264 | 265 | sorted_val_acc = val_acc[sorted_idx] 266 | val_idx = val_idx[sorted_idx] 267 | print(sorted_idx) 268 | print(sorted_val_acc) 269 | 270 | top_n_idx = val_idx[:top_n_models] 271 | per_model_per_batch_preds = [[] for i in range(top_n_models)] 272 | per_model_per_batch_targets = [[] for i in range(top_n_models)] 273 | test_losses = [dict() for i in range(top_n_models)] 274 | for idx, model_idx in enumerate(top_n_idx): 275 | self.state = \ 276 | self.model.load_model(model_save_dir=self.saved_models_filepath, model_name="train_model", 277 | model_idx=model_idx + 1) 278 | with tqdm.tqdm(total=int(self.args.num_evaluation_tasks / self.args.batch_size)) as pbar_test: 279 | for sample_idx, test_sample in enumerate( 280 | self.data.get_test_batches(total_batches=int(self.args.num_evaluation_tasks / self.args.batch_size), 281 | augment_images=False)): 282 | #print(test_sample[4]) 283 | per_model_per_batch_targets[idx].extend(np.array(test_sample[3])) 284 | per_model_per_batch_preds = self.test_evaluation_iteration(val_sample=test_sample, 285 | sample_idx=sample_idx, 286 | model_idx=idx, 287 | per_model_per_batch_preds=per_model_per_batch_preds, 288 | pbar_test=pbar_test) 289 | # for i in range(top_n_models): 290 | # print("test assertion", 0) 291 | # print(per_model_per_batch_targets[0], per_model_per_batch_targets[i]) 292 | # assert np.equal(np.array(per_model_per_batch_targets[0]), np.array(per_model_per_batch_targets[i])) 293 | 294 | per_batch_preds = np.mean(per_model_per_batch_preds, axis=0) 295 | #print(per_batch_preds.shape) 296 | per_batch_max = np.argmax(per_batch_preds, axis=2) 297 | per_batch_targets = np.array(per_model_per_batch_targets[0]).reshape(per_batch_max.shape) 298 | #print(per_batch_max) 299 | accuracy = np.mean(np.equal(per_batch_targets, per_batch_max)) 300 | accuracy_std = np.std(np.equal(per_batch_targets, per_batch_max)) 301 | 302 | test_losses = {"test_accuracy_mean": accuracy, "test_accuracy_std": accuracy_std} 303 | 304 | _ = save_statistics(self.logs_filepath, 305 | list(test_losses.keys()), 306 | create=True, filename="test_summary.csv") 307 | 308 | summary_statistics_filepath = save_statistics(self.logs_filepath, 309 | list(test_losses.values()), 310 | create=False, filename="test_summary.csv") 311 | print(test_losses) 312 | print("saved test performance at", summary_statistics_filepath) 313 | 314 | def run_experiment(self): 315 | """ 316 | Runs a full training experiment with evaluations of the model on the val set at every epoch. Furthermore, 317 | will return the test set evaluation results on the best performing validation model. 318 | """ 319 | with tqdm.tqdm(initial=self.state['current_iter'], 320 | total=int(self.args.total_iter_per_epoch * self.args.total_epochs)) as pbar_train: 321 | 322 | while (self.state['current_iter'] < (self.args.total_epochs * self.args.total_iter_per_epoch)) and (self.args.evaluate_on_test_set_only == False): 323 | 324 | for train_sample_idx, train_sample in enumerate( 325 | self.data.get_train_batches(total_batches=int(self.args.total_iter_per_epoch * 326 | self.args.total_epochs) - self.state[ 327 | 'current_iter'], 328 | augment_images=self.augment_flag)): 329 | # print(self.state['current_iter'], (self.args.total_epochs * self.args.total_iter_per_epoch)) 330 | train_losses, total_losses, self.state['current_iter'] = self.train_iteration( 331 | train_sample=train_sample, 332 | total_losses=self.total_losses, 333 | epoch_idx=(self.state['current_iter'] / 334 | self.args.total_iter_per_epoch), 335 | pbar_train=pbar_train, 336 | current_iter=self.state['current_iter'], 337 | sample_idx=self.state['current_iter']) 338 | 339 | if self.state['current_iter'] % self.args.total_iter_per_epoch == 0: 340 | 341 | total_losses = dict() 342 | val_losses = dict() 343 | with tqdm.tqdm(total=int(self.args.num_evaluation_tasks / self.args.batch_size)) as pbar_val: 344 | for _, val_sample in enumerate( 345 | self.data.get_val_batches(total_batches=int(self.args.num_evaluation_tasks / self.args.batch_size), 346 | augment_images=False)): 347 | val_losses, total_losses = self.evaluation_iteration(val_sample=val_sample, 348 | total_losses=total_losses, 349 | pbar_val=pbar_val, phase='val') 350 | 351 | if val_losses["val_accuracy_mean"] > self.state['best_val_acc']: 352 | print("Best validation accuracy", val_losses["val_accuracy_mean"]) 353 | self.state['best_val_acc'] = val_losses["val_accuracy_mean"] 354 | self.state['best_val_iter'] = self.state['current_iter'] 355 | self.state['best_epoch'] = int( 356 | self.state['best_val_iter'] / self.args.total_iter_per_epoch) 357 | 358 | 359 | self.epoch += 1 360 | self.state = self.merge_two_dicts(first_dict=self.merge_two_dicts(first_dict=self.state, 361 | second_dict=train_losses), 362 | second_dict=val_losses) 363 | 364 | self.save_models(model=self.model, epoch=self.epoch, state=self.state) 365 | 366 | self.start_time, self.state = self.pack_and_save_metrics(start_time=self.start_time, 367 | create_summary_csv=self.create_summary_csv, 368 | train_losses=train_losses, 369 | val_losses=val_losses, 370 | state=self.state) 371 | 372 | self.total_losses = dict() 373 | 374 | self.epochs_done_in_this_run += 1 375 | 376 | save_to_json(filename=os.path.join(self.logs_filepath, "summary_statistics.json"), 377 | dict_to_store=self.state['per_epoch_statistics']) 378 | 379 | if self.epochs_done_in_this_run >= self.total_epochs_before_pause: 380 | print("train_seed {}, val_seed: {}, at pause time".format(self.data.dataset.seed["train"], 381 | self.data.dataset.seed["val"])) 382 | sys.exit() 383 | self.evaluated_test_set_using_the_best_models(top_n_models=5) 384 | -------------------------------------------------------------------------------- /experiment_config/alfa+maml.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size":2, 3 | "image_height":84, 4 | "image_width":84, 5 | "image_channels":3, 6 | "gpu_to_use":0, 7 | "num_dataprovider_workers":4, 8 | "max_models_to_save":5, 9 | "dataset_name":"mini_imagenet_full_size", 10 | "dataset_path":"mini_imagenet_full_size", 11 | "reset_stored_paths":false, 12 | "experiment_name":"alfa+maml", 13 | "train_seed": 0, "val_seed": 0, 14 | "indexes_of_folders_indicating_class": [-3, -2], 15 | "sets_are_pre_split": true, 16 | "train_val_test_split": [0.64, 0.16, 0.20], 17 | "evaluate_on_test_set_only": false, 18 | 19 | "total_epochs": 100, 20 | "total_iter_per_epoch":500, "continue_from_epoch": -2, 21 | "num_evaluation_tasks":600, 22 | "multi_step_loss_num_epochs": 15, 23 | "minimum_per_task_contribution": 0.01, 24 | "learnable_per_layer_per_step_inner_loop_learning_rate": false, 25 | "enable_inner_loop_optimizable_bn_params": false, 26 | "evalute_on_test_set_only": false, 27 | 28 | "max_pooling": true, 29 | "per_step_bn_statistics": false, 30 | "learnable_batch_norm_momentum": false, 31 | "load_into_memory": false, 32 | "init_inner_loop_learning_rate": 0.01, 33 | "init_inner_loop_weight_decay": 0.0005, 34 | "learnable_bn_gamma": true, 35 | "learnable_bn_beta": true, 36 | 37 | "dropout_rate_value":0.0, 38 | "min_learning_rate":0.001, 39 | "meta_learning_rate":0.001, "total_epochs_before_pause": 100, 40 | "first_order_to_second_order_epoch":-1, 41 | "weight_decay": 0.0, 42 | 43 | "norm_layer":"batch_norm", 44 | "cnn_num_filters":48, 45 | "num_stages":4, 46 | "conv_padding": true, 47 | "number_of_training_steps_per_iter":5, 48 | "number_of_evaluation_steps_per_iter":5, 49 | "cnn_blocks_per_stage":1, 50 | "num_classes_per_set":5, 51 | "num_samples_per_class":5, 52 | "num_target_samples": 15, 53 | 54 | "second_order": true, 55 | "use_multi_step_loss_optimization":false, 56 | "attenuate": false, 57 | "alfa": true, 58 | "random_init": false, 59 | "backbone": "4-CONV" 60 | } 61 | -------------------------------------------------------------------------------- /experiment_config/alfa+maml_resnet12.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size":2, 3 | "image_height":84, 4 | "image_width":84, 5 | "image_channels":3, 6 | "gpu_to_use":0, 7 | "num_dataprovider_workers":4, 8 | "max_models_to_save":5, 9 | "dataset_name":"mini_imagenet_full_size", 10 | "dataset_path":"mini_imagenet_full_size", 11 | "reset_stored_paths":false, 12 | "experiment_name":"alfa+maml_resnet12", 13 | "train_seed": 0, "val_seed": 0, 14 | "indexes_of_folders_indicating_class": [-3, -2], 15 | "sets_are_pre_split": true, 16 | "train_val_test_split": [0.64, 0.16, 0.20], 17 | "evaluate_on_test_set_only": false, 18 | 19 | "total_epochs": 100, 20 | "total_iter_per_epoch":500, "continue_from_epoch": -2, 21 | "num_evaluation_tasks":600, 22 | "multi_step_loss_num_epochs": 15, 23 | "minimum_per_task_contribution": 0.01, 24 | "learnable_per_layer_per_step_inner_loop_learning_rate": false, 25 | "enable_inner_loop_optimizable_bn_params": false, 26 | "evalute_on_test_set_only": false, 27 | 28 | "max_pooling": true, 29 | "per_step_bn_statistics": false, 30 | "learnable_batch_norm_momentum": false, 31 | "load_into_memory": false, 32 | "init_inner_loop_learning_rate": 0.01, 33 | "init_inner_loop_weight_decay": 0.0005, 34 | "learnable_bn_gamma": true, 35 | "learnable_bn_beta": true, 36 | 37 | "dropout_rate_value":0.0, 38 | "min_learning_rate":0.001, 39 | "meta_learning_rate":0.001, "total_epochs_before_pause": 100, 40 | "first_order_to_second_order_epoch":-1, 41 | "weight_decay": 0.0, 42 | 43 | "norm_layer":"batch_norm", 44 | "cnn_num_filters":48, 45 | "num_stages":4, 46 | "conv_padding": true, 47 | "number_of_training_steps_per_iter":5, 48 | "number_of_evaluation_steps_per_iter":5, 49 | "cnn_blocks_per_stage":1, 50 | "num_classes_per_set":5, 51 | "num_samples_per_class":5, 52 | "num_target_samples": 15, 53 | 54 | "second_order": true, 55 | "use_multi_step_loss_optimization":false, 56 | "attenuate": false, 57 | "alfa": true, 58 | "backbone": "ResNet12" 59 | } 60 | -------------------------------------------------------------------------------- /experiment_config/alfa+random_init.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size":2, 3 | "image_height":84, 4 | "image_width":84, 5 | "image_channels":3, 6 | "gpu_to_use":0, 7 | "num_dataprovider_workers":4, 8 | "max_models_to_save":5, 9 | "dataset_name":"mini_imagenet_full_size", 10 | "dataset_path":"mini_imagenet_full_size", 11 | "reset_stored_paths":false, 12 | "experiment_name":"alfa+random_init", 13 | "train_seed": 0, "val_seed": 0, 14 | "indexes_of_folders_indicating_class": [-3, -2], 15 | "sets_are_pre_split": true, 16 | "train_val_test_split": [0.64, 0.16, 0.20], 17 | "evaluate_on_test_set_only": false, 18 | 19 | "total_epochs": 100, 20 | "total_iter_per_epoch":500, "continue_from_epoch": -2, 21 | "num_evaluation_tasks":600, 22 | "multi_step_loss_num_epochs": 15, 23 | "minimum_per_task_contribution": 0.01, 24 | "learnable_per_layer_per_step_inner_loop_learning_rate": false, 25 | "enable_inner_loop_optimizable_bn_params": false, 26 | "evalute_on_test_set_only": false, 27 | 28 | "max_pooling": true, 29 | "per_step_bn_statistics": false, 30 | "learnable_batch_norm_momentum": false, 31 | "load_into_memory": false, 32 | "init_inner_loop_learning_rate": 0.01, 33 | "init_inner_loop_weight_decay": 0.0005, 34 | "learnable_bn_gamma": true, 35 | "learnable_bn_beta": true, 36 | 37 | "dropout_rate_value":0.0, 38 | "min_learning_rate":0.001, 39 | "meta_learning_rate":0.001, "total_epochs_before_pause": 100, 40 | "first_order_to_second_order_epoch":-1, 41 | "weight_decay": 0.0, 42 | 43 | "norm_layer":"batch_norm", 44 | "cnn_num_filters":48, 45 | "num_stages":4, 46 | "conv_padding": true, 47 | "number_of_training_steps_per_iter":5, 48 | "number_of_evaluation_steps_per_iter":5, 49 | "cnn_blocks_per_stage":1, 50 | "num_classes_per_set":5, 51 | "num_samples_per_class":5, 52 | "num_target_samples": 15, 53 | 54 | "second_order": true, 55 | "use_multi_step_loss_optimization":false, 56 | "attenuate": false, 57 | "alfa": true, 58 | "random_init": true, 59 | "backbone": "4-CONV" 60 | } 61 | -------------------------------------------------------------------------------- /experiment_config/alfa+random_init_resnet12.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size":2, 3 | "image_height":84, 4 | "image_width":84, 5 | "image_channels":3, 6 | "gpu_to_use":0, 7 | "num_dataprovider_workers":4, 8 | "max_models_to_save":5, 9 | "dataset_name":"mini_imagenet_full_size", 10 | "dataset_path":"mini_imagenet_full_size", 11 | "reset_stored_paths":false, 12 | "experiment_name":"alfa+random_init_resnet12", 13 | "train_seed": 0, "val_seed": 0, 14 | "indexes_of_folders_indicating_class": [-3, -2], 15 | "sets_are_pre_split": true, 16 | "train_val_test_split": [0.64, 0.16, 0.20], 17 | "evaluate_on_test_set_only": false, 18 | 19 | "total_epochs": 100, 20 | "total_iter_per_epoch":500, "continue_from_epoch": -2, 21 | "num_evaluation_tasks":600, 22 | "multi_step_loss_num_epochs": 15, 23 | "minimum_per_task_contribution": 0.01, 24 | "learnable_per_layer_per_step_inner_loop_learning_rate": false, 25 | "enable_inner_loop_optimizable_bn_params": false, 26 | "evalute_on_test_set_only": false, 27 | 28 | "max_pooling": true, 29 | "per_step_bn_statistics": false, 30 | "learnable_batch_norm_momentum": false, 31 | "load_into_memory": false, 32 | "init_inner_loop_learning_rate": 0.01, 33 | "init_inner_loop_weight_decay": 0.0005, 34 | "learnable_bn_gamma": true, 35 | "learnable_bn_beta": true, 36 | 37 | "dropout_rate_value":0.0, 38 | "min_learning_rate":0.001, 39 | "meta_learning_rate":0.001, "total_epochs_before_pause": 100, 40 | "first_order_to_second_order_epoch":-1, 41 | "weight_decay": 0.0, 42 | 43 | "norm_layer":"batch_norm", 44 | "cnn_num_filters":48, 45 | "num_stages":4, 46 | "conv_padding": true, 47 | "number_of_training_steps_per_iter":5, 48 | "number_of_evaluation_steps_per_iter":5, 49 | "cnn_blocks_per_stage":1, 50 | "num_classes_per_set":5, 51 | "num_samples_per_class":5, 52 | "num_target_samples": 15, 53 | 54 | "second_order": true, 55 | "use_multi_step_loss_optimization":false, 56 | "attenuate": false, 57 | "alfa": true, 58 | "random_init": true, 59 | "backbone": "ResNet12" 60 | } 61 | -------------------------------------------------------------------------------- /experiment_scripts/alfa+maml.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | export GPU_ID=$1 4 | 5 | echo $GPU_ID 6 | 7 | cd .. 8 | export DATASET_DIR="datasets/" 9 | export CUDA_VISIBLE_DEVICES=$1,$2,$3,$4 10 | # Activate the relevant virtual environment: 11 | python train_maml_system.py --name_of_args_json_file experiment_config/alfa+maml.json --gpu_to_use $GPU_ID 12 | -------------------------------------------------------------------------------- /experiment_scripts/alfa+maml_resnet12.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | export GPU_ID=$1 4 | 5 | echo $GPU_ID 6 | 7 | cd .. 8 | export DATASET_DIR="datasets/" 9 | export CUDA_VISIBLE_DEVICES=$1,$2,$3,$4 10 | # Activate the relevant virtual environment: 11 | python train_maml_system.py --name_of_args_json_file experiment_config/alfa+maml_resnet12.json --gpu_to_use $GPU_ID 12 | -------------------------------------------------------------------------------- /experiment_scripts/alfa+random_init.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | export GPU_ID=$1 4 | 5 | echo $GPU_ID 6 | 7 | cd .. 8 | export DATASET_DIR="datasets/" 9 | export CUDA_VISIBLE_DEVICES=$1,$2,$3,$4 10 | # Activate the relevant virtual environment: 11 | python train_maml_system.py --name_of_args_json_file experiment_config/alfa+random_init.json --gpu_to_use $GPU_ID 12 | -------------------------------------------------------------------------------- /experiment_scripts/alfa+random_init_resnet12.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | export GPU_ID=$1 4 | 5 | echo $GPU_ID 6 | 7 | cd .. 8 | export DATASET_DIR="datasets/" 9 | export CUDA_VISIBLE_DEVICES=$1,$2,$3,$4 10 | # Activate the relevant virtual environment: 11 | python train_maml_system.py --name_of_args_json_file experiment_config/alfa+random_init_resnet12.json --gpu_to_use $GPU_ID 12 | -------------------------------------------------------------------------------- /few_shot_learning_system.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | 9 | from meta_neural_network_architectures import VGGReLUNormNetwork, ResNet12 10 | from inner_loop_optimizers import LSLRGradientDescentLearningRule 11 | 12 | 13 | def set_torch_seed(seed): 14 | """ 15 | Sets the pytorch seeds for current experiment run 16 | :param seed: The seed (int) 17 | :return: A random number generator to use 18 | """ 19 | rng = np.random.RandomState(seed=seed) 20 | torch_seed = rng.randint(0, 999999) 21 | torch.manual_seed(seed=torch_seed) 22 | 23 | return rng 24 | 25 | 26 | class MAMLFewShotClassifier(nn.Module): 27 | def __init__(self, im_shape, device, args): 28 | """ 29 | Initializes a MAML few shot learning system 30 | :param im_shape: The images input size, in batch, c, h, w shape 31 | :param device: The device to use to use the model on. 32 | :param args: A namedtuple of arguments specifying various hyperparameters. 33 | """ 34 | super(MAMLFewShotClassifier, self).__init__() 35 | self.args = args 36 | self.device = device 37 | self.batch_size = args.batch_size 38 | self.use_cuda = args.use_cuda 39 | self.im_shape = im_shape 40 | self.current_epoch = 0 41 | 42 | self.rng = set_torch_seed(seed=args.seed) 43 | 44 | if self.args.backbone == 'ResNet12': 45 | self.classifier = ResNet12(im_shape=self.im_shape, num_output_classes=self.args. 46 | num_classes_per_set, 47 | args=args, device=device, meta_classifier=True).to(device=self.device) 48 | else: 49 | self.classifier = VGGReLUNormNetwork(im_shape=self.im_shape, num_output_classes=self.args. 50 | num_classes_per_set, 51 | args=args, device=device, meta_classifier=True).to(device=self.device) 52 | 53 | self.task_learning_rate = args.init_inner_loop_learning_rate 54 | 55 | self.inner_loop_optimizer = LSLRGradientDescentLearningRule(device=device, 56 | init_learning_rate=self.task_learning_rate, 57 | init_weight_decay=args.init_inner_loop_weight_decay, 58 | total_num_inner_loop_steps=self.args.number_of_training_steps_per_iter, 59 | use_learnable_weight_decay=self.args.alfa, 60 | use_learnable_learning_rates=self.args.alfa, 61 | alfa=self.args.alfa, random_init=self.args.random_init) 62 | 63 | names_weights_copy = self.get_inner_loop_parameter_dict(self.classifier.named_parameters()) 64 | 65 | if self.args.attenuate: 66 | num_layers = len(names_weights_copy) 67 | self.attenuator = nn.Sequential( 68 | nn.Linear(num_layers, num_layers), 69 | nn.ReLU(inplace=True), 70 | nn.Linear(num_layers, num_layers), 71 | nn.Sigmoid() 72 | ).to(device=self.device) 73 | 74 | self.inner_loop_optimizer.initialise( 75 | names_weights_dict=names_weights_copy) 76 | 77 | print("Inner Loop parameters") 78 | for key, value in self.inner_loop_optimizer.named_parameters(): 79 | print(key, value.shape) 80 | 81 | self.use_cuda = args.use_cuda 82 | self.device = device 83 | self.args = args 84 | self.to(device) 85 | print("Outer Loop parameters") 86 | for name, param in self.named_parameters(): 87 | if param.requires_grad: 88 | print(name, param.shape, param.device, param.requires_grad) 89 | 90 | # ALFA 91 | if self.args.alfa: 92 | num_layers = len(names_weights_copy) 93 | input_dim = num_layers*2 94 | self.regularizer = nn.Sequential( 95 | nn.Linear(input_dim, input_dim), 96 | nn.ReLU(inplace=True), 97 | nn.Linear(input_dim, input_dim) 98 | ).to(device=self.device) 99 | 100 | if self.args.attenuate: 101 | if self.args.alfa: 102 | self.optimizer = optim.Adam([ 103 | {'params':self.classifier.parameters()}, 104 | {'params': self.inner_loop_optimizer.parameters()}, 105 | {'params': self.regularizer.parameters()}, 106 | {'params':self.attenuator.parameters()}, 107 | ],lr=args.meta_learning_rate, amsgrad=False) 108 | else: 109 | self.optimizer = optim.Adam([ 110 | {'params':self.classifier.parameters()}, 111 | {'params':self.attenuator.parameters()}, 112 | ],lr=args.meta_learning_rate, amsgrad=False) 113 | else: 114 | if self.args.alfa: 115 | if self.args.random_init: 116 | self.optimizer = optim.Adam([ 117 | {'params': self.inner_loop_optimizer.parameters()}, 118 | {'params': self.regularizer.parameters()}, 119 | ], lr=args.meta_learning_rate, amsgrad=False) 120 | else: 121 | self.optimizer = optim.Adam([ 122 | {'params': self.classifier.parameters()}, 123 | {'params': self.inner_loop_optimizer.parameters()}, 124 | {'params': self.regularizer.parameters()}, 125 | ], lr=args.meta_learning_rate, amsgrad=False) 126 | else: 127 | self.optimizer = optim.Adam([ 128 | {'params': self.classifier.parameters()}, 129 | ], lr=args.meta_learning_rate, amsgrad=False) 130 | 131 | self.scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=self.optimizer, T_max=self.args.total_epochs, 132 | eta_min=self.args.min_learning_rate) 133 | 134 | self.device = torch.device('cpu') 135 | if torch.cuda.is_available(): 136 | print(torch.cuda.device_count()) 137 | if torch.cuda.device_count() > 1: 138 | self.to(torch.cuda.current_device()) 139 | self.classifier = nn.DataParallel(module=self.classifier) 140 | else: 141 | self.to(torch.cuda.current_device()) 142 | 143 | self.device = torch.cuda.current_device() 144 | 145 | def get_task_embeddings(self, x_support_set_task, y_support_set_task, names_weights_copy): 146 | # Use gradients as task embeddings 147 | support_loss, support_preds = self.net_forward(x=x_support_set_task, 148 | y=y_support_set_task, 149 | weights=names_weights_copy, 150 | backup_running_statistics=True, 151 | training=True, num_step=0) 152 | 153 | if torch.cuda.device_count() > 1: 154 | self.classifier.module.zero_grad(names_weights_copy) 155 | else: 156 | self.classifier.zero_grad(names_weights_copy) 157 | grads = torch.autograd.grad(support_loss, names_weights_copy.values(), create_graph=True) 158 | 159 | 160 | layerwise_mean_grads = [] 161 | 162 | for i in range(len(grads)): 163 | layerwise_mean_grads.append(grads[i].mean()) 164 | 165 | layerwise_mean_grads = torch.stack(layerwise_mean_grads) 166 | 167 | return layerwise_mean_grads 168 | 169 | def attenuate_init(self, task_embeddings, names_weights_copy): 170 | # Generate attenuation parameters 171 | gamma = self.attenuator(task_embeddings) 172 | 173 | ## Attenuate 174 | 175 | updated_names_weights_copy = dict() 176 | i = 0 177 | for key in names_weights_copy.keys(): 178 | updated_names_weights_copy[key] = gamma[i] * names_weights_copy[key] 179 | i+=1 180 | 181 | return updated_names_weights_copy 182 | 183 | 184 | def get_per_step_loss_importance_vector(self): 185 | """ 186 | Generates a tensor of dimensionality (num_inner_loop_steps) indicating the importance of each step's target 187 | loss towards the optimization loss. 188 | :return: A tensor to be used to compute the weighted average of the loss, useful for 189 | the MSL (Multi Step Loss) mechanism. 190 | """ 191 | loss_weights = np.ones(shape=(self.args.number_of_training_steps_per_iter)) * ( 192 | 1.0 / self.args.number_of_training_steps_per_iter) 193 | decay_rate = 1.0 / self.args.number_of_training_steps_per_iter / self.args.multi_step_loss_num_epochs 194 | min_value_for_non_final_losses = 0.03 / self.args.number_of_training_steps_per_iter 195 | for i in range(len(loss_weights) - 1): 196 | curr_value = np.maximum(loss_weights[i] - (self.current_epoch * decay_rate), min_value_for_non_final_losses) 197 | loss_weights[i] = curr_value 198 | 199 | curr_value = np.minimum( 200 | loss_weights[-1] + (self.current_epoch * (self.args.number_of_training_steps_per_iter - 1) * decay_rate), 201 | 1.0 - ((self.args.number_of_training_steps_per_iter - 1) * min_value_for_non_final_losses)) 202 | loss_weights[-1] = curr_value 203 | loss_weights = torch.Tensor(loss_weights).to(device=self.device) 204 | return loss_weights 205 | 206 | def get_inner_loop_parameter_dict(self, params): 207 | """ 208 | Returns a dictionary with the parameters to use for inner loop updates. 209 | :param params: A dictionary of the network's parameters. 210 | :return: A dictionary of the parameters to use for the inner loop optimization process. 211 | """ 212 | param_dict = dict() 213 | for name, param in params: 214 | if param.requires_grad: 215 | if self.args.enable_inner_loop_optimizable_bn_params: 216 | param_dict[name] = param.to(device=self.device) 217 | else: 218 | if "norm_layer" not in name: 219 | param_dict[name] = param.to(device=self.device) 220 | 221 | return param_dict 222 | 223 | def apply_inner_loop_update(self, loss, names_weights_copy, generated_alpha_params, generated_beta_params, use_second_order, current_step_idx): 224 | """ 225 | Applies an inner loop update given current step's loss, the weights to update, a flag indicating whether to use 226 | second order derivatives and the current step's index. 227 | :param loss: Current step's loss with respect to the support set. 228 | :param names_weights_copy: A dictionary with names to parameters to update. 229 | :param use_second_order: A boolean flag of whether to use second order derivatives. 230 | :param current_step_idx: Current step's index. 231 | :return: A dictionary with the updated weights (name, param) 232 | """ 233 | num_gpus = torch.cuda.device_count() 234 | if num_gpus > 1: 235 | self.classifier.module.zero_grad(params=names_weights_copy) 236 | else: 237 | self.classifier.zero_grad(params=names_weights_copy) 238 | 239 | grads = torch.autograd.grad(loss, names_weights_copy.values(), 240 | create_graph=use_second_order, allow_unused=True) 241 | names_grads_copy = dict(zip(names_weights_copy.keys(), grads)) 242 | 243 | names_weights_copy = {key: value[0] for key, value in names_weights_copy.items()} 244 | 245 | for key, grad in names_grads_copy.items(): 246 | if grad is None: 247 | print('Grads not found for inner loop parameter', key) 248 | names_grads_copy[key] = names_grads_copy[key].sum(dim=0) 249 | 250 | 251 | names_weights_copy = self.inner_loop_optimizer.update_params(names_weights_dict=names_weights_copy, 252 | names_grads_wrt_params_dict=names_grads_copy, 253 | generated_alpha_params=generated_alpha_params, 254 | generated_beta_params=generated_beta_params, 255 | num_step=current_step_idx) 256 | 257 | num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1 258 | names_weights_copy = { 259 | name.replace('module.', ''): value.unsqueeze(0).repeat( 260 | [num_devices] + [1 for i in range(len(value.shape))]) for 261 | name, value in names_weights_copy.items()} 262 | 263 | 264 | return names_weights_copy 265 | 266 | def get_across_task_loss_metrics(self, total_losses, total_accuracies): 267 | losses = dict() 268 | 269 | losses['loss'] = torch.mean(torch.stack(total_losses)) 270 | losses['accuracy'] = np.mean(total_accuracies) 271 | 272 | return losses 273 | 274 | def forward(self, data_batch, epoch, use_second_order, use_multi_step_loss_optimization, num_steps, training_phase): 275 | """ 276 | Runs a forward outer loop pass on the batch of tasks using the MAML/++ framework. 277 | :param data_batch: A data batch containing the support and target sets. 278 | :param epoch: Current epoch's index 279 | :param use_second_order: A boolean saying whether to use second order derivatives. 280 | :param use_multi_step_loss_optimization: Whether to optimize on the outer loop using just the last step's 281 | target loss (True) or whether to use multi step loss which improves the stability of the system (False) 282 | :param num_steps: Number of inner loop steps. 283 | :param training_phase: Whether this is a training phase (True) or an evaluation phase (False) 284 | :return: A dictionary with the collected losses of the current outer forward propagation. 285 | """ 286 | x_support_set, x_target_set, y_support_set, y_target_set = data_batch 287 | 288 | [b, ncs, spc] = y_support_set.shape 289 | 290 | self.num_classes_per_set = ncs 291 | 292 | total_losses = [] 293 | total_accuracies = [] 294 | total_support_accuracies = [[] for i in range(num_steps)] 295 | total_target_accuracies = [[] for i in range(num_steps)] 296 | per_task_target_preds = [[] for i in range(len(x_target_set))] 297 | 298 | if torch.cuda.device_count() > 1: 299 | self.classifier.module.zero_grad() 300 | else: 301 | self.classifier.zero_grad() 302 | for task_id, (x_support_set_task, y_support_set_task, x_target_set_task, y_target_set_task) in \ 303 | enumerate(zip(x_support_set, 304 | y_support_set, 305 | x_target_set, 306 | y_target_set)): 307 | task_losses = [] 308 | task_accuracies = [] 309 | per_step_support_accuracy = [] 310 | per_step_target_accuracy = [] 311 | per_step_loss_importance_vectors = self.get_per_step_loss_importance_vector() 312 | names_weights_copy = self.get_inner_loop_parameter_dict(self.classifier.named_parameters()) 313 | 314 | num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1 315 | 316 | names_weights_copy = { 317 | name.replace('module.', ''): value.unsqueeze(0).repeat( 318 | [num_devices] + [1 for i in range(len(value.shape))]) for 319 | name, value in names_weights_copy.items()} 320 | 321 | 322 | n, s, c, h, w = x_target_set_task.shape 323 | 324 | x_support_set_task = x_support_set_task.view(-1, c, h, w) 325 | y_support_set_task = y_support_set_task.view(-1) 326 | x_target_set_task = x_target_set_task.view(-1, c, h, w) 327 | y_target_set_task = y_target_set_task.view(-1) 328 | 329 | # Attenuate the initialization for L2F 330 | if self.args.attenuate: 331 | # Obtain gradients from support set for task embedding 332 | task_embeddings = self.get_task_embeddings(x_support_set_task=x_support_set_task, 333 | y_support_set_task=y_support_set_task, 334 | names_weights_copy=names_weights_copy) 335 | 336 | names_weights_copy = self.attenuate_init(task_embeddings=task_embeddings, 337 | names_weights_copy=names_weights_copy) 338 | 339 | 340 | for num_step in range(num_steps): 341 | 342 | support_loss, support_preds = self.net_forward(x=x_support_set_task, 343 | y=y_support_set_task, 344 | weights=names_weights_copy, 345 | backup_running_statistics= 346 | True if (num_step == 0) else False, 347 | training=True, num_step=num_step) 348 | generated_alpha_params = {} 349 | generated_beta_params = {} 350 | 351 | if self.args.alfa: 352 | 353 | support_loss_grad = torch.autograd.grad(support_loss, names_weights_copy.values(), retain_graph=True) 354 | per_step_task_embedding = [] 355 | for k, v in names_weights_copy.items(): 356 | per_step_task_embedding.append(v.mean()) 357 | 358 | for i in range(len(support_loss_grad)): 359 | per_step_task_embedding.append(support_loss_grad[i].mean()) 360 | 361 | per_step_task_embedding = torch.stack(per_step_task_embedding) 362 | 363 | generated_params = self.regularizer(per_step_task_embedding) 364 | num_layers = len(names_weights_copy) 365 | 366 | generated_alpha, generated_beta = torch.split(generated_params, split_size_or_sections=num_layers) 367 | g = 0 368 | for key in names_weights_copy.keys(): 369 | generated_alpha_params[key] = generated_alpha[g] 370 | generated_beta_params[key] = generated_beta[g] 371 | g+=1 372 | 373 | names_weights_copy = self.apply_inner_loop_update(loss=support_loss, 374 | names_weights_copy=names_weights_copy, 375 | generated_beta_params=generated_beta_params, 376 | generated_alpha_params=generated_alpha_params, 377 | use_second_order=use_second_order, 378 | current_step_idx=num_step) 379 | 380 | if use_multi_step_loss_optimization and training_phase and epoch < self.args.multi_step_loss_num_epochs: 381 | target_loss, target_preds = self.net_forward(x=x_target_set_task, 382 | y=y_target_set_task, weights=names_weights_copy, 383 | backup_running_statistics=False, training=True, 384 | num_step=num_step) 385 | 386 | task_losses.append(per_step_loss_importance_vectors[num_step] * target_loss) 387 | 388 | else: 389 | if num_step == (self.args.number_of_training_steps_per_iter - 1): 390 | target_loss, target_preds = self.net_forward(x=x_target_set_task, 391 | y=y_target_set_task, weights=names_weights_copy, 392 | backup_running_statistics=False, training=True, 393 | num_step=num_step) 394 | task_losses.append(target_loss) 395 | 396 | per_task_target_preds[task_id] = target_preds.detach().cpu().numpy() 397 | _, predicted = torch.max(target_preds.data, 1) 398 | 399 | accuracy = predicted.float().eq(y_target_set_task.data.float()).cpu().float() 400 | task_losses = torch.sum(torch.stack(task_losses)) 401 | total_losses.append(task_losses) 402 | total_accuracies.extend(accuracy) 403 | 404 | if not training_phase: 405 | if torch.cuda.device_count() > 1: 406 | self.classifier.module.restore_backup_stats() 407 | else: 408 | self.classifier.restore_backup_stats() 409 | 410 | losses = self.get_across_task_loss_metrics(total_losses=total_losses, 411 | total_accuracies=total_accuracies) 412 | 413 | for idx, item in enumerate(per_step_loss_importance_vectors): 414 | losses['loss_importance_vector_{}'.format(idx)] = item.detach().cpu().numpy() 415 | 416 | return losses, per_task_target_preds 417 | 418 | def net_forward(self, x, y, weights, backup_running_statistics, training, num_step): 419 | """ 420 | A base model forward pass on some data points x. Using the parameters in the weights dictionary. Also requires 421 | boolean flags indicating whether to reset the running statistics at the end of the run (if at evaluation phase). 422 | A flag indicating whether this is the training session and an int indicating the current step's number in the 423 | inner loop. 424 | :param x: A data batch of shape b, c, h, w 425 | :param y: A data targets batch of shape b, n_classes 426 | :param weights: A dictionary containing the weights to pass to the network. 427 | :param backup_running_statistics: A flag indicating whether to reset the batch norm running statistics to their 428 | previous values after the run (only for evaluation) 429 | :param training: A flag indicating whether the current process phase is a training or evaluation. 430 | :param num_step: An integer indicating the number of the step in the inner loop. 431 | :return: the crossentropy losses with respect to the given y, the predictions of the base model. 432 | """ 433 | preds = self.classifier.forward(x=x, params=weights, 434 | training=training, 435 | backup_running_statistics=backup_running_statistics, num_step=num_step) 436 | 437 | loss = F.cross_entropy(input=preds, target=y) 438 | 439 | return loss, preds 440 | 441 | def trainable_parameters(self): 442 | """ 443 | Returns an iterator over the trainable parameters of the model. 444 | """ 445 | for param in self.parameters(): 446 | if param.requires_grad: 447 | yield param 448 | 449 | def train_forward_prop(self, data_batch, epoch): 450 | """ 451 | Runs an outer loop forward prop using the meta-model and base-model. 452 | :param data_batch: A data batch containing the support set and the target set input, output pairs. 453 | :param epoch: The index of the currrent epoch. 454 | :return: A dictionary of losses for the current step. 455 | """ 456 | losses, per_task_target_preds = self.forward(data_batch=data_batch, epoch=epoch, 457 | use_second_order=self.args.second_order and 458 | epoch > self.args.first_order_to_second_order_epoch, 459 | use_multi_step_loss_optimization=self.args.use_multi_step_loss_optimization, 460 | num_steps=self.args.number_of_training_steps_per_iter, 461 | training_phase=True) 462 | return losses, per_task_target_preds 463 | 464 | def evaluation_forward_prop(self, data_batch, epoch): 465 | """ 466 | Runs an outer loop evaluation forward prop using the meta-model and base-model. 467 | :param data_batch: A data batch containing the support set and the target set input, output pairs. 468 | :param epoch: The index of the currrent epoch. 469 | :return: A dictionary of losses for the current step. 470 | """ 471 | losses, per_task_target_preds = self.forward(data_batch=data_batch, epoch=epoch, use_second_order=False, 472 | use_multi_step_loss_optimization=True, 473 | num_steps=self.args.number_of_evaluation_steps_per_iter, 474 | training_phase=False) 475 | 476 | return losses, per_task_target_preds 477 | 478 | def meta_update(self, loss): 479 | """ 480 | Applies an outer loop update on the meta-parameters of the model. 481 | :param loss: The current crossentropy loss. 482 | """ 483 | self.optimizer.zero_grad() 484 | loss.backward() 485 | #if 'imagenet' in self.args.dataset_name: 486 | # for name, param in self.classifier.named_parameters(): 487 | # if param.requires_grad: 488 | # param.grad.data.clamp_(-10, 10) # not sure if this is necessary, more experiments are needed 489 | #for name, param in self.classifier.named_parameters(): 490 | # print(param.mean()) 491 | 492 | self.optimizer.step() 493 | 494 | def run_train_iter(self, data_batch, epoch): 495 | """ 496 | Runs an outer loop update step on the meta-model's parameters. 497 | :param data_batch: input data batch containing the support set and target set input, output pairs 498 | :param epoch: the index of the current epoch 499 | :return: The losses of the ran iteration. 500 | """ 501 | epoch = int(epoch) 502 | self.scheduler.step(epoch=epoch) 503 | if self.current_epoch != epoch: 504 | self.current_epoch = epoch 505 | 506 | if not self.training: 507 | self.train() 508 | 509 | x_support_set, x_target_set, y_support_set, y_target_set = data_batch 510 | 511 | x_support_set = torch.Tensor(x_support_set).float().to(device=self.device) 512 | x_target_set = torch.Tensor(x_target_set).float().to(device=self.device) 513 | y_support_set = torch.Tensor(y_support_set).long().to(device=self.device) 514 | y_target_set = torch.Tensor(y_target_set).long().to(device=self.device) 515 | 516 | data_batch = (x_support_set, x_target_set, y_support_set, y_target_set) 517 | 518 | losses, per_task_target_preds = self.train_forward_prop(data_batch=data_batch, epoch=epoch) 519 | 520 | self.meta_update(loss=losses['loss']) 521 | losses['learning_rate'] = self.scheduler.get_lr()[0] 522 | self.optimizer.zero_grad() 523 | self.zero_grad() 524 | 525 | return losses, per_task_target_preds 526 | 527 | def run_validation_iter(self, data_batch): 528 | """ 529 | Runs an outer loop evaluation step on the meta-model's parameters. 530 | :param data_batch: input data batch containing the support set and target set input, output pairs 531 | :param epoch: the index of the current epoch 532 | :return: The losses of the ran iteration. 533 | """ 534 | 535 | if self.training: 536 | self.eval() 537 | 538 | x_support_set, x_target_set, y_support_set, y_target_set = data_batch 539 | 540 | x_support_set = torch.Tensor(x_support_set).float().to(device=self.device) 541 | x_target_set = torch.Tensor(x_target_set).float().to(device=self.device) 542 | y_support_set = torch.Tensor(y_support_set).long().to(device=self.device) 543 | y_target_set = torch.Tensor(y_target_set).long().to(device=self.device) 544 | 545 | data_batch = (x_support_set, x_target_set, y_support_set, y_target_set) 546 | 547 | losses, per_task_target_preds = self.evaluation_forward_prop(data_batch=data_batch, epoch=self.current_epoch) 548 | 549 | losses['loss'].backward() # uncomment if you get the weird memory error 550 | self.zero_grad() 551 | self.optimizer.zero_grad() 552 | 553 | return losses, per_task_target_preds 554 | 555 | def save_model(self, model_save_dir, state): 556 | """ 557 | Save the network parameter state and experiment state dictionary. 558 | :param model_save_dir: The directory to store the state at. 559 | :param state: The state containing the experiment state and the network. It's in the form of a dictionary 560 | object. 561 | """ 562 | state['network'] = self.state_dict() 563 | torch.save(state, f=model_save_dir) 564 | 565 | def load_model(self, model_save_dir, model_name, model_idx): 566 | """ 567 | Load checkpoint and return the state dictionary containing the network state params and experiment state. 568 | :param model_save_dir: The directory from which to load the files. 569 | :param model_name: The model_name to be loaded from the direcotry. 570 | :param model_idx: The index of the model (i.e. epoch number or 'latest' for the latest saved model of the current 571 | experiment) 572 | :return: A dictionary containing the experiment state and the saved model parameters. 573 | """ 574 | filepath = os.path.join(model_save_dir, "{}_{}".format(model_name, model_idx)) 575 | state = torch.load(filepath) 576 | state_dict_loaded = state['network'] 577 | self.load_state_dict(state_dict=state_dict_loaded) 578 | return state 579 | -------------------------------------------------------------------------------- /inner_loop_optimizers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from collections import OrderedDict 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | 11 | 12 | class GradientDescentLearningRule(nn.Module): 13 | """Simple (stochastic) gradient descent learning rule. 14 | For a scalar error function `E(p[0], p_[1] ... )` of some set of 15 | potentially multidimensional parameters this attempts to find a local 16 | minimum of the loss function by applying updates to each parameter of the 17 | form 18 | p[i] := p[i] - learning_rate * dE/dp[i] 19 | With `learning_rate` a positive scaling parameter. 20 | The error function used in successive applications of these updates may be 21 | a stochastic estimator of the true error function (e.g. when the error with 22 | respect to only a subset of data-points is calculated) in which case this 23 | will correspond to a stochastic gradient descent learning rule. 24 | """ 25 | 26 | def __init__(self, device, learning_rate=1e-3): 27 | """Creates a new learning rule object. 28 | Args: 29 | learning_rate: A postive scalar to scale gradient updates to the 30 | parameters by. This needs to be carefully set - if too large 31 | the learning dynamic will be unstable and may diverge, while 32 | if set too small learning will proceed very slowly. 33 | """ 34 | super(GradientDescentLearningRule, self).__init__() 35 | assert learning_rate > 0., 'learning_rate should be positive.' 36 | self.learning_rate = torch.ones(1) * learning_rate 37 | self.learning_rate.to(device) 38 | 39 | def update_params(self, names_weights_dict, names_grads_wrt_params_dict, num_step, tau=0.9): 40 | """Applies a single gradient descent update to all parameters. 41 | All parameter updates are performed using in-place operations and so 42 | nothing is returned. 43 | Args: 44 | grads_wrt_params: A list of gradients of the scalar loss function 45 | with respect to each of the parameters passed to `initialise` 46 | previously, with this list expected to be in the same order. 47 | """ 48 | updated_names_weights_dict = dict() 49 | for key in names_weights_dict.keys(): 50 | updated_names_weights_dict[key] = names_weights_dict[key] - self.learning_rate * \ 51 | names_grads_wrt_params_dict[ 52 | key] 53 | 54 | return updated_names_weights_dict 55 | 56 | 57 | class LSLRGradientDescentLearningRule(nn.Module): 58 | """Simple (stochastic) gradient descent learning rule. 59 | For a scalar error function `E(p[0], p_[1] ... )` of some set of 60 | potentially multidimensional parameters this attempts to find a local 61 | minimum of the loss function by applying updates to each parameter of the 62 | form 63 | p[i] := p[i] - learning_rate * dE/dp[i] 64 | With `learning_rate` a positive scaling parameter. 65 | The error function used in successive applications of these updates may be 66 | a stochastic estimator of the true error function (e.g. when the error with 67 | respect to only a subset of data-points is calculated) in which case this 68 | will correspond to a stochastic gradient descent learning rule. 69 | """ 70 | 71 | def __init__(self, device, total_num_inner_loop_steps, use_learnable_learning_rates, use_learnable_weight_decay, alfa, random_init, init_learning_rate=1e-3, init_weight_decay=5e-4): 72 | """Creates a new learning rule object. 73 | Args: 74 | init_learning_rate: A postive scalar to scale gradient updates to the 75 | parameters by. This needs to be carefully set - if too large 76 | the learning dynamic will be unstable and may diverge, while 77 | if set too small learning will proceed very slowly. 78 | """ 79 | super(LSLRGradientDescentLearningRule, self).__init__() 80 | print(init_learning_rate) 81 | assert init_learning_rate > 0., 'learning_rate should be positive.' 82 | 83 | self.alfa = alfa 84 | self.random_init = random_init 85 | 86 | self.init_lr_val = init_learning_rate 87 | self.init_wd_val = init_weight_decay 88 | 89 | self.init_pspl_weight_decay = torch.ones(1) 90 | self.init_pspl_weight_decay.to(device) 91 | 92 | self.init_learning_rate = torch.ones(1) * init_learning_rate 93 | self.init_learning_rate.to(device) 94 | self.total_num_inner_loop_steps = total_num_inner_loop_steps 95 | self.use_learnable_learning_rates = use_learnable_learning_rates 96 | self.use_learnable_weight_decay = use_learnable_weight_decay 97 | self.init_weight_decay = torch.ones(1) * init_weight_decay 98 | self.init_bias_decay = torch.ones(1) 99 | 100 | def initialise(self, names_weights_dict): 101 | if self.alfa: 102 | if self.random_init: 103 | self.names_beta_dict_per_param = nn.ParameterDict() 104 | 105 | self.names_alpha_dict = nn.ParameterDict() 106 | self.names_beta_dict = nn.ParameterDict() 107 | 108 | for idx, (key, param) in enumerate(names_weights_dict.items()): 109 | 110 | if self.random_init: 111 | # per-param weight decay for random init 112 | self.names_beta_dict_per_param[key.replace(".", "-")] = nn.Parameter( 113 | data=torch.ones(param.shape) * self.init_weight_decay * self.init_learning_rate, 114 | requires_grad=self.use_learnable_learning_rates) 115 | 116 | self.names_beta_dict[key.replace(".", "-")] = nn.Parameter( 117 | data=torch.ones(self.total_num_inner_loop_steps + 1), 118 | requires_grad=self.use_learnable_learning_rates) 119 | else: 120 | # per-step per-layer meta-learnable weight decay bias term (for more stable training and better performance by 2~3%) 121 | self.names_beta_dict[key.replace(".", "-")] = nn.Parameter( 122 | data=torch.ones(self.total_num_inner_loop_steps + 1) * self.init_weight_decay * self.init_learning_rate, 123 | requires_grad=self.use_learnable_learning_rates) 124 | 125 | # per-step per-layer meta-learnable learning rate bias term (for more stable training and better performance by 2~3%) 126 | self.names_alpha_dict[key.replace(".", "-")] = nn.Parameter( 127 | data=torch.ones(self.total_num_inner_loop_steps + 1) * self.init_learning_rate, 128 | requires_grad=self.use_learnable_learning_rates) 129 | 130 | def update_params(self, names_weights_dict, names_grads_wrt_params_dict, generated_alpha_params, generated_beta_params, num_step, tau=0.1): 131 | """Applies a single gradient descent update to all parameters. 132 | All parameter updates are performed using in-place operations and so 133 | nothing is returned. 134 | Args: 135 | grads_wrt_params: A list of gradients of the scalar loss function 136 | with respect to each of the parameters passed to `initialise` 137 | previously, with this list expected to be in the same order. 138 | """ 139 | updated_names_weights_dict = dict() 140 | 141 | for key in names_grads_wrt_params_dict.keys(): 142 | # beta = (1 - generated_beta * meta-learned per-step-per-layer bias term) 143 | # alpha = generated_alpha * meta-learned per-step-per-layer bias term) 144 | if self.alfa: 145 | if self.random_init: 146 | updated_names_weights_dict[key] = (1 - self.names_beta_dict_per_param[key.replace(".", "-")] * generated_beta_params[key] * self.names_beta_dict[key.replace(".", "-")][num_step]) * names_weights_dict[key] - generated_alpha_params[key] * self.names_alpha_dict[key.replace(".", "-")][num_step] * names_grads_wrt_params_dict[key] 147 | else: 148 | updated_names_weights_dict[key] = (1 - generated_beta_params[key] * self.names_beta_dict[key.replace(".", "-")][num_step]) * names_weights_dict[key] - generated_alpha_params[key] * self.names_alpha_dict[key.replace(".", "-")][num_step] * names_grads_wrt_params_dict[key] 149 | else: 150 | updated_names_weights_dict[key] = names_weights_dict[key] - self.init_lr_val * names_grads_wrt_params_dict[key] 151 | 152 | return updated_names_weights_dict 153 | 154 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | # Change the cuda version if necessary 2 | conda install pytorch torchvision cudatoolkit=10.2 -c pytorch 3 | conda install -c conda-forge tensorboard 4 | conda install numpy scipy matplotlib 5 | conda install -c conda-forge pbzip2 pydrive 6 | conda install pillow tqdm 7 | -------------------------------------------------------------------------------- /meta_neural_network_architectures.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | from copy import copy 3 | 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch 7 | import numpy as np 8 | 9 | 10 | 11 | def extract_top_level_dict(current_dict): 12 | """ 13 | Builds a graph dictionary from the passed depth_keys, value pair. Useful for dynamically passing external params 14 | :param depth_keys: A list of strings making up the name of a variable. Used to make a graph for that params tree. 15 | :param value: Param value 16 | :param key_exists: If none then assume new dict, else load existing dict and add new key->value pairs to it. 17 | :return: A dictionary graph of the params already added to the graph. 18 | """ 19 | output_dict = dict() 20 | for key in current_dict.keys(): 21 | name = key.replace("layer_dict.", "") 22 | name = name.replace("layer_dict.", "") 23 | name = name.replace("block_dict.", "") 24 | name = name.replace("module-", "") 25 | top_level = name.split(".")[0] 26 | sub_level = ".".join(name.split(".")[1:]) 27 | 28 | if top_level not in output_dict: 29 | if sub_level == "": 30 | output_dict[top_level] = current_dict[key] 31 | else: 32 | output_dict[top_level] = {sub_level: current_dict[key]} 33 | else: 34 | new_item = {key: value for key, value in output_dict[top_level].items()} 35 | new_item[sub_level] = current_dict[key] 36 | output_dict[top_level] = new_item 37 | 38 | #print(current_dict.keys(), output_dict.keys()) 39 | return output_dict 40 | 41 | class MetaMaxResLayerReLU(nn.Module): 42 | def __init__(self, input_shape, num_filters, kernel_size, stride, padding, use_bias, args, normalization=True, 43 | meta_layer=True, no_bn_learnable_params=False, device=None, downsample=None, max_padding=0, maxpool=True): 44 | """ 45 | Initializes a BatchNorm->Conv->ReLU layer which applies those operation in that order. 46 | :param args: A named tuple containing the system's hyperparameters. 47 | :param device: The device to run the layer on. 48 | :param normalization: The type of normalization to use 'batch_norm' or 'layer_norm' 49 | :param meta_layer: Whether this layer will require meta-layer capabilities such as meta-batch norm, 50 | meta-conv etc. 51 | :param input_shape: The image input shape in the form (b, c, h, w) 52 | :param num_filters: number of filters for convolutional layer 53 | :param kernel_size: the kernel size of the convolutional layer 54 | :param stride: the stride of the convolutional layer 55 | :param padding: the bias of the convolutional layer 56 | :param use_bias: whether the convolutional layer utilizes a bias 57 | """ 58 | super(MetaMaxResLayerReLU, self).__init__() 59 | self.normalization = normalization 60 | self.use_per_step_bn_statistics = args.per_step_bn_statistics 61 | self.input_shape = input_shape 62 | self.args = args 63 | self.num_filters = num_filters 64 | self.kernel_size = kernel_size 65 | self.stride = stride 66 | self.padding = padding 67 | self.use_bias = use_bias 68 | self.meta_layer = meta_layer 69 | self.no_bn_learnable_params = no_bn_learnable_params 70 | self.device = device 71 | self.layer_dict = nn.ModuleDict() 72 | self.downsample = downsample 73 | self.max_padding = max_padding 74 | self.maxpool = maxpool 75 | self.build_block() 76 | 77 | def build_block(self): 78 | 79 | x = torch.zeros(self.input_shape) 80 | 81 | identity = x 82 | out = x 83 | 84 | self.conv1 = MetaConvNormLayerSwish(input_shape=out.shape, 85 | num_filters=self.num_filters, 86 | kernel_size=3, stride=self.stride, 87 | padding=1, 88 | use_bias=self.use_bias, args=self.args, 89 | normalization=True, 90 | meta_layer=self.meta_layer, 91 | no_bn_learnable_params=False, 92 | device=self.device) 93 | out = self.conv1(out, training=True, num_step=0) 94 | 95 | self.conv2 = MetaConvNormLayerSwish(input_shape=out.shape, 96 | num_filters=self.num_filters, 97 | kernel_size=3, stride=self.stride, 98 | padding=1, 99 | use_bias=self.use_bias, args=self.args, 100 | normalization=True, 101 | meta_layer=self.meta_layer, 102 | no_bn_learnable_params=False, 103 | device=self.device) 104 | out = self.conv2(out, training=True, num_step=0) 105 | 106 | self.conv3 = MetaConv2dLayer(in_channels=out.shape[1], out_channels=out.shape[1], 107 | kernel_size=3, 108 | stride=1, padding=self.padding, use_bias=self.use_bias) 109 | 110 | out = self.conv3(out) 111 | 112 | self.norm_layer = MetaBatchNormLayer(out.shape[1], track_running_stats=True, 113 | meta_batch_norm=self.meta_layer, 114 | no_learnable_params=self.no_bn_learnable_params, 115 | device=self.device, 116 | use_per_step_bn_statistics=self.use_per_step_bn_statistics, 117 | args=self.args) 118 | 119 | out = self.norm_layer(out, num_step=0) 120 | 121 | self.shortcut_conv = MetaConv2dLayer(in_channels=identity.shape[1], out_channels=out.shape[1], 122 | kernel_size=1, 123 | stride=1, padding=0, use_bias=self.use_bias) 124 | 125 | self.shortcut_norm_layer = MetaBatchNormLayer(out.shape[1], track_running_stats=True, 126 | meta_batch_norm=self.meta_layer, 127 | no_learnable_params=self.no_bn_learnable_params, 128 | device=self.device, 129 | use_per_step_bn_statistics=self.use_per_step_bn_statistics, 130 | args=self.args) 131 | 132 | identity = self.shortcut_conv(identity) 133 | identity = self.shortcut_norm_layer(identity, num_step=0) 134 | 135 | out += identity 136 | 137 | out = F.relu(out) 138 | 139 | if self.maxpool: 140 | out = F.max_pool2d(input=out, kernel_size=(2, 2), stride=2, padding=self.max_padding) 141 | 142 | 143 | print(out.shape) 144 | 145 | def forward(self, x, num_step, params=None, training=False, backup_running_statistics=False): 146 | """ 147 | Forward propagates by applying the function. If params are none then internal params are used. 148 | Otherwise passed params will be used to execute the function. 149 | :param input: input data batch, size either can be any. 150 | :param num_step: The current inner loop step being taken. This is used when we are learning per step params and 151 | collecting per step batch statistics. It indexes the correct object to use for the current time-step 152 | :param params: A dictionary containing 'weight' and 'bias'. 153 | :param training: Whether this is currently the training or evaluation phase. 154 | :param backup_running_statistics: Whether to backup the running statistics. This is used 155 | at evaluation time, when after the pass is complete we want to throw away the collected validation stats. 156 | :return: The result of the batch norm operation. 157 | """ 158 | conv_params_1 = None 159 | conv_params_2 = None 160 | conv_params_3 = None 161 | conv_params_shortcut = None 162 | norm_params = None 163 | norm_params_shortcut = None 164 | activation_function_pre_params = None 165 | 166 | if params is not None: 167 | params = extract_top_level_dict(current_dict=params) 168 | 169 | if self.normalization: 170 | if 'activation_function_pre' in params: 171 | activation_function_pre_params = params['activation_function_pre'] 172 | 173 | conv_params_1 = params['conv1'] 174 | conv_params_2 = params['conv2'] 175 | conv_params_3 = params['conv3'] 176 | conv_params_shortcut = params['shortcut_conv'] 177 | 178 | if 'norm_layer' in params: 179 | norm_params = params['norm_layer'] 180 | norm_params_shortcut = params['shortcut_norm_layer'] 181 | 182 | out = x 183 | identity = x 184 | 185 | out = self.conv1(out, params=conv_params_1, training=training, 186 | backup_running_statistics=backup_running_statistics, 187 | num_step=num_step) 188 | 189 | out = self.conv2(out, params=conv_params_2, training=training, 190 | backup_running_statistics=backup_running_statistics, 191 | num_step=num_step) 192 | 193 | out = self.conv3(out, params=conv_params_3) 194 | 195 | out = self.norm_layer.forward(out, num_step=num_step, 196 | params=norm_params, training=training, 197 | backup_running_statistics=backup_running_statistics) 198 | 199 | 200 | 201 | identity = self.shortcut_conv(identity, params=conv_params_shortcut) 202 | identity = self.shortcut_norm_layer.forward(identity, num_step=num_step, 203 | params=norm_params_shortcut, training=training, 204 | backup_running_statistics=backup_running_statistics) 205 | out += identity 206 | 207 | out = F.relu(out) 208 | 209 | if self.maxpool: 210 | out = F.max_pool2d(input=out, kernel_size=(2, 2), stride=2, padding=self.max_padding) 211 | 212 | 213 | return out 214 | 215 | def restore_backup_stats(self): 216 | """ 217 | Restore stored statistics from the backup, replacing the current ones. 218 | """ 219 | self.conv1.restore_backup_stats() 220 | self.conv2.restore_backup_stats() 221 | self.norm_layer.restore_backup_stats() 222 | self.shortcut_norm_layer.restore_backup_stats() 223 | 224 | 225 | class MetaConvNormLayerSwish(nn.Module): 226 | def __init__(self, input_shape, num_filters, kernel_size, stride, padding, use_bias, args, normalization=True, 227 | meta_layer=True, no_bn_learnable_params=False, device=None): 228 | """ 229 | Initializes a BatchNorm->Conv->ReLU layer which applies those operation in that order. 230 | :param args: A named tuple containing the system's hyperparameters. 231 | :param device: The device to run the layer on. 232 | :param normalization: The type of normalization to use 'batch_norm' or 'layer_norm' 233 | :param meta_layer: Whether this layer will require meta-layer capabilities such as meta-batch norm, 234 | meta-conv etc. 235 | :param input_shape: The image input shape in the form (b, c, h, w) 236 | :param num_filters: number of filters for convolutional layer 237 | :param kernel_size: the kernel size of the convolutional layer 238 | :param stride: the stride of the convolutional layer 239 | :param padding: the bias of the convolutional layer 240 | :param use_bias: whether the convolutional layer utilizes a bias 241 | """ 242 | super(MetaConvNormLayerSwish, self).__init__() 243 | self.normalization = normalization 244 | self.use_per_step_bn_statistics = args.per_step_bn_statistics 245 | self.input_shape = input_shape 246 | self.args = args 247 | self.num_filters = num_filters 248 | self.kernel_size = kernel_size 249 | self.stride = stride 250 | self.padding = padding 251 | self.use_bias = use_bias 252 | self.meta_layer = meta_layer 253 | self.no_bn_learnable_params = no_bn_learnable_params 254 | self.device = device 255 | self.layer_dict = nn.ModuleDict() 256 | self.build_block() 257 | 258 | def build_block(self): 259 | 260 | x = torch.zeros(self.input_shape) 261 | 262 | out = x 263 | 264 | self.conv = MetaConv2dLayer(in_channels=out.shape[1], out_channels=self.num_filters, 265 | kernel_size=self.kernel_size, 266 | stride=self.stride, padding=self.padding, use_bias=self.use_bias) 267 | 268 | 269 | 270 | out = self.conv(out) 271 | 272 | if self.normalization: 273 | if self.args.norm_layer == "batch_norm": 274 | self.norm_layer = MetaBatchNormLayer(out.shape[1], track_running_stats=True, 275 | meta_batch_norm=self.meta_layer, 276 | no_learnable_params=self.no_bn_learnable_params, 277 | device=self.device, 278 | use_per_step_bn_statistics=self.use_per_step_bn_statistics, 279 | args=self.args) 280 | 281 | elif self.args.norm_layer == "layer_norm": 282 | self.norm_layer = MetaLayerNormLayer(input_feature_shape=out.shape[1:]) 283 | 284 | out = self.norm_layer(out, num_step=0) 285 | 286 | out = F.relu(out) 287 | 288 | print(out.shape) 289 | 290 | def forward(self, x, num_step, params=None, training=False, backup_running_statistics=False): 291 | """ 292 | Forward propagates by applying the function. If params are none then internal params are used. 293 | Otherwise passed params will be used to execute the function. 294 | :param input: input data batch, size either can be any. 295 | :param num_step: The current inner loop step being taken. This is used when we are learning per step params and 296 | collecting per step batch statistics. It indexes the correct object to use for the current time-step 297 | :param params: A dictionary containing 'weight' and 'bias'. 298 | :param training: Whether this is currently the training or evaluation phase. 299 | :param backup_running_statistics: Whether to backup the running statistics. This is used 300 | at evaluation time, when after the pass is complete we want to throw away the collected validation stats. 301 | :return: The result of the batch norm operation. 302 | """ 303 | batch_norm_params = None 304 | conv_params = None 305 | activation_function_pre_params = None 306 | 307 | if params is not None: 308 | params = extract_top_level_dict(current_dict=params) 309 | 310 | if self.normalization: 311 | if 'norm_layer' in params: 312 | batch_norm_params = params['norm_layer'] 313 | 314 | if 'activation_function_pre' in params: 315 | activation_function_pre_params = params['activation_function_pre'] 316 | 317 | conv_params = params['conv'] 318 | 319 | out = x 320 | 321 | 322 | out = self.conv(out, params=conv_params) 323 | 324 | if self.normalization: 325 | out = self.norm_layer.forward(out, num_step=num_step, 326 | params=batch_norm_params, training=training, 327 | backup_running_statistics=backup_running_statistics) 328 | 329 | out = F.relu(out) 330 | 331 | return out 332 | 333 | def restore_backup_stats(self): 334 | """ 335 | Restore stored statistics from the backup, replacing the current ones. 336 | """ 337 | if self.normalization: 338 | self.norm_layer.restore_backup_stats() 339 | 340 | 341 | class MetaConv2dLayer(nn.Module): 342 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, use_bias, groups=1, dilation_rate=1): 343 | """ 344 | A MetaConv2D layer. Applies the same functionality of a standard Conv2D layer with the added functionality of 345 | being able to receive a parameter dictionary at the forward pass which allows the convolution to use external 346 | weights instead of the internal ones stored in the conv layer. Useful for inner loop optimization in the meta 347 | learning setting. 348 | :param in_channels: Number of input channels 349 | :param out_channels: Number of output channels 350 | :param kernel_size: Convolutional kernel size 351 | :param stride: Convolutional stride 352 | :param padding: Convolution padding 353 | :param use_bias: Boolean indicating whether to use a bias or not. 354 | """ 355 | super(MetaConv2dLayer, self).__init__() 356 | num_filters = out_channels 357 | self.stride = int(stride) 358 | self.padding = int(padding) 359 | self.dilation_rate = int(dilation_rate) 360 | self.use_bias = use_bias 361 | self.groups = int(groups) 362 | self.weight = nn.Parameter(torch.empty(num_filters, in_channels, kernel_size, kernel_size)) 363 | nn.init.xavier_uniform_(self.weight) 364 | 365 | if self.use_bias: 366 | self.bias = nn.Parameter(torch.zeros(num_filters)) 367 | 368 | def forward(self, x, params=None): 369 | """ 370 | Applies a conv2D forward pass. If params are not None will use the passed params as the conv weights and biases 371 | :param x: Input image batch. 372 | :param params: If none, then conv layer will use the stored self.weights and self.bias, if they are not none 373 | then the conv layer will use the passed params as its parameters. 374 | :return: The output of a convolutional function. 375 | """ 376 | if params is not None: 377 | params = extract_top_level_dict(current_dict=params) 378 | if self.use_bias: 379 | (weight, bias) = params["weight"], params["bias"] 380 | else: 381 | (weight) = params["weight"] 382 | bias = None 383 | else: 384 | #print("No inner loop params") 385 | if self.use_bias: 386 | weight, bias = self.weight, self.bias 387 | else: 388 | weight = self.weight 389 | bias = None 390 | 391 | out = F.conv2d(input=x, weight=weight, bias=bias, stride=self.stride, 392 | padding=self.padding, dilation=self.dilation_rate, groups=self.groups) 393 | return out 394 | 395 | 396 | class MetaLinearLayer(nn.Module): 397 | def __init__(self, input_shape, num_filters, use_bias): 398 | """ 399 | A MetaLinear layer. Applies the same functionality of a standard linearlayer with the added functionality of 400 | being able to receive a parameter dictionary at the forward pass which allows the convolution to use external 401 | weights instead of the internal ones stored in the linear layer. Useful for inner loop optimization in the meta 402 | learning setting. 403 | :param input_shape: The shape of the input data, in the form (b, f) 404 | :param num_filters: Number of output filters 405 | :param use_bias: Whether to use biases or not. 406 | """ 407 | super(MetaLinearLayer, self).__init__() 408 | b, c = input_shape 409 | 410 | self.use_bias = use_bias 411 | self.weights = nn.Parameter(torch.ones(num_filters, c)) 412 | nn.init.xavier_uniform_(self.weights) 413 | if self.use_bias: 414 | self.bias = nn.Parameter(torch.zeros(num_filters)) 415 | 416 | def forward(self, x, params=None): 417 | """ 418 | Forward propagates by applying a linear function (Wx + b). If params are none then internal params are used. 419 | Otherwise passed params will be used to execute the function. 420 | :param x: Input data batch, in the form (b, f) 421 | :param params: A dictionary containing 'weights' and 'bias'. If params are none then internal params are used. 422 | Otherwise the external are used. 423 | :return: The result of the linear function. 424 | """ 425 | if params is not None: 426 | params = extract_top_level_dict(current_dict=params) 427 | if self.use_bias: 428 | (weight, bias) = params["weights"], params["bias"] 429 | else: 430 | (weight) = params["weights"] 431 | bias = None 432 | else: 433 | pass 434 | #print('no inner loop params', self) 435 | 436 | if self.use_bias: 437 | weight, bias = self.weights, self.bias 438 | else: 439 | weight = self.weights 440 | bias = None 441 | # print(x.shape) 442 | out = F.linear(input=x, weight=weight, bias=bias) 443 | return out 444 | 445 | 446 | class MetaBatchNormLayer(nn.Module): 447 | def __init__(self, num_features, device, args, eps=1e-5, momentum=0.1, affine=True, 448 | track_running_stats=True, meta_batch_norm=True, no_learnable_params=False, 449 | use_per_step_bn_statistics=False): 450 | """ 451 | A MetaBatchNorm layer. Applies the same functionality of a standard BatchNorm layer with the added functionality of 452 | being able to receive a parameter dictionary at the forward pass which allows the convolution to use external 453 | weights instead of the internal ones stored in the conv layer. Useful for inner loop optimization in the meta 454 | learning setting. Also has the additional functionality of being able to store per step running stats and per step beta and gamma. 455 | :param num_features: 456 | :param device: 457 | :param args: 458 | :param eps: 459 | :param momentum: 460 | :param affine: 461 | :param track_running_stats: 462 | :param meta_batch_norm: 463 | :param no_learnable_params: 464 | :param use_per_step_bn_statistics: 465 | """ 466 | super(MetaBatchNormLayer, self).__init__() 467 | self.num_features = num_features 468 | self.eps = eps 469 | 470 | self.affine = affine 471 | self.track_running_stats = track_running_stats 472 | self.meta_batch_norm = meta_batch_norm 473 | self.num_features = num_features 474 | self.device = device 475 | self.use_per_step_bn_statistics = use_per_step_bn_statistics 476 | self.args = args 477 | self.learnable_gamma = self.args.learnable_bn_gamma 478 | self.learnable_beta = self.args.learnable_bn_beta 479 | 480 | if use_per_step_bn_statistics: 481 | self.running_mean = nn.Parameter(torch.zeros(args.number_of_training_steps_per_iter, num_features), 482 | requires_grad=False) 483 | self.running_var = nn.Parameter(torch.ones(args.number_of_training_steps_per_iter, num_features), 484 | requires_grad=False) 485 | self.bias = nn.Parameter(torch.zeros(args.number_of_training_steps_per_iter, num_features), 486 | requires_grad=self.learnable_beta) 487 | self.weight = nn.Parameter(torch.ones(args.number_of_training_steps_per_iter, num_features), 488 | requires_grad=self.learnable_gamma) 489 | else: 490 | self.running_mean = nn.Parameter(torch.zeros(num_features), requires_grad=False) 491 | self.running_var = nn.Parameter(torch.zeros(num_features), requires_grad=False) 492 | self.bias = nn.Parameter(torch.zeros(num_features), 493 | requires_grad=self.learnable_beta) 494 | self.weight = nn.Parameter(torch.ones(num_features), 495 | requires_grad=self.learnable_gamma) 496 | 497 | if self.args.enable_inner_loop_optimizable_bn_params: 498 | self.bias = nn.Parameter(torch.zeros(num_features), 499 | requires_grad=self.learnable_beta) 500 | self.weight = nn.Parameter(torch.ones(num_features), 501 | requires_grad=self.learnable_gamma) 502 | 503 | self.backup_running_mean = torch.zeros(self.running_mean.shape) 504 | self.backup_running_var = torch.ones(self.running_var.shape) 505 | 506 | self.momentum = momentum 507 | 508 | def forward(self, input, num_step, params=None, training=False, backup_running_statistics=False): 509 | """ 510 | Forward propagates by applying a bach norm function. If params are none then internal params are used. 511 | Otherwise passed params will be used to execute the function. 512 | :param input: input data batch, size either can be any. 513 | :param num_step: The current inner loop step being taken. This is used when we are learning per step params and 514 | collecting per step batch statistics. It indexes the correct object to use for the current time-step 515 | :param params: A dictionary containing 'weight' and 'bias'. 516 | :param training: Whether this is currently the training or evaluation phase. 517 | :param backup_running_statistics: Whether to backup the running statistics. This is used 518 | at evaluation time, when after the pass is complete we want to throw away the collected validation stats. 519 | :return: The result of the batch norm operation. 520 | """ 521 | if params is not None: 522 | params = extract_top_level_dict(current_dict=params) 523 | (weight, bias) = params["weight"], params["bias"] 524 | #print(num_step, params['weight']) 525 | else: 526 | #print(num_step, "no params") 527 | weight, bias = self.weight, self.bias 528 | 529 | if self.use_per_step_bn_statistics: 530 | running_mean = self.running_mean[num_step] 531 | running_var = self.running_var[num_step] 532 | if params is None: 533 | if not self.args.enable_inner_loop_optimizable_bn_params: 534 | bias = self.bias[num_step] 535 | weight = self.weight[num_step] 536 | else: 537 | running_mean = None 538 | running_var = None 539 | 540 | 541 | if backup_running_statistics and self.use_per_step_bn_statistics: 542 | self.backup_running_mean.data = copy(self.running_mean.data) 543 | self.backup_running_var.data = copy(self.running_var.data) 544 | 545 | momentum = self.momentum 546 | 547 | output = F.batch_norm(input, running_mean, running_var, weight, bias, 548 | training=True, momentum=momentum, eps=self.eps) 549 | 550 | return output 551 | 552 | def restore_backup_stats(self): 553 | """ 554 | Resets batch statistics to their backup values which are collected after each forward pass. 555 | """ 556 | if self.use_per_step_bn_statistics: 557 | self.running_mean = nn.Parameter(self.backup_running_mean.to(device=self.device), requires_grad=False) 558 | self.running_var = nn.Parameter(self.backup_running_var.to(device=self.device), requires_grad=False) 559 | 560 | def extra_repr(self): 561 | return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \ 562 | 'track_running_stats={track_running_stats}'.format(**self.__dict__) 563 | 564 | 565 | class MetaLayerNormLayer(nn.Module): 566 | def __init__(self, input_feature_shape, eps=1e-5, elementwise_affine=True): 567 | """ 568 | A MetaLayerNorm layer. A layer that applies the same functionality as a layer norm layer with the added 569 | capability of being able to receive params at inference time to use instead of the internal ones. As well as 570 | being able to use its own internal weights. 571 | :param input_feature_shape: The input shape without the batch dimension, e.g. c, h, w 572 | :param eps: Epsilon to use for protection against overflows 573 | :param elementwise_affine: Whether to learn a multiplicative interaction parameter 'w' in addition to 574 | the biases. 575 | """ 576 | super(MetaLayerNormLayer, self).__init__() 577 | if isinstance(input_feature_shape, numbers.Integral): 578 | input_feature_shape = (input_feature_shape,) 579 | self.normalized_shape = torch.Size(input_feature_shape) 580 | self.eps = eps 581 | self.elementwise_affine = elementwise_affine 582 | if self.elementwise_affine: 583 | self.weight = nn.Parameter(torch.Tensor(*input_feature_shape), requires_grad=False) 584 | self.bias = nn.Parameter(torch.Tensor(*input_feature_shape)) 585 | else: 586 | self.register_parameter('weight', None) 587 | self.register_parameter('bias', None) 588 | self.reset_parameters() 589 | 590 | def reset_parameters(self): 591 | """ 592 | Reset parameters to their initialization values. 593 | """ 594 | if self.elementwise_affine: 595 | self.weight.data.fill_(1) 596 | self.bias.data.zero_() 597 | 598 | def forward(self, input, num_step, params=None, training=False, backup_running_statistics=False): 599 | """ 600 | Forward propagates by applying a layer norm function. If params are none then internal params are used. 601 | Otherwise passed params will be used to execute the function. 602 | :param input: input data batch, size either can be any. 603 | :param num_step: The current inner loop step being taken. This is used when we are learning per step params and 604 | collecting per step batch statistics. It indexes the correct object to use for the current time-step 605 | :param params: A dictionary containing 'weight' and 'bias'. 606 | :param training: Whether this is currently the training or evaluation phase. 607 | :param backup_running_statistics: Whether to backup the running statistics. This is used 608 | at evaluation time, when after the pass is complete we want to throw away the collected validation stats. 609 | :return: The result of the batch norm operation. 610 | """ 611 | if params is not None: 612 | params = extract_top_level_dict(current_dict=params) 613 | bias = params["bias"] 614 | else: 615 | bias = self.bias 616 | #print('no inner loop params', self) 617 | 618 | return F.layer_norm( 619 | input, self.normalized_shape, self.weight, bias, self.eps) 620 | 621 | def restore_backup_stats(self): 622 | pass 623 | 624 | def extra_repr(self): 625 | return '{normalized_shape}, eps={eps}, ' \ 626 | 'elementwise_affine={elementwise_affine}'.format(**self.__dict__) 627 | 628 | 629 | class MetaConvNormLayerReLU(nn.Module): 630 | def __init__(self, input_shape, num_filters, kernel_size, stride, padding, use_bias, args, normalization=True, 631 | meta_layer=True, no_bn_learnable_params=False, device=None): 632 | """ 633 | Initializes a BatchNorm->Conv->ReLU layer which applies those operation in that order. 634 | :param args: A named tuple containing the system's hyperparameters. 635 | :param device: The device to run the layer on. 636 | :param normalization: The type of normalization to use 'batch_norm' or 'layer_norm' 637 | :param meta_layer: Whether this layer will require meta-layer capabilities such as meta-batch norm, 638 | meta-conv etc. 639 | :param input_shape: The image input shape in the form (b, c, h, w) 640 | :param num_filters: number of filters for convolutional layer 641 | :param kernel_size: the kernel size of the convolutional layer 642 | :param stride: the stride of the convolutional layer 643 | :param padding: the bias of the convolutional layer 644 | :param use_bias: whether the convolutional layer utilizes a bias 645 | """ 646 | super(MetaConvNormLayerReLU, self).__init__() 647 | self.normalization = normalization 648 | self.use_per_step_bn_statistics = args.per_step_bn_statistics 649 | self.input_shape = input_shape 650 | self.args = args 651 | self.num_filters = num_filters 652 | self.kernel_size = kernel_size 653 | self.stride = stride 654 | self.padding = padding 655 | self.use_bias = use_bias 656 | self.meta_layer = meta_layer 657 | self.no_bn_learnable_params = no_bn_learnable_params 658 | self.device = device 659 | self.layer_dict = nn.ModuleDict() 660 | self.build_block() 661 | 662 | def build_block(self): 663 | 664 | x = torch.zeros(self.input_shape) 665 | 666 | out = x 667 | 668 | self.conv = MetaConv2dLayer(in_channels=out.shape[1], out_channels=self.num_filters, 669 | kernel_size=self.kernel_size, 670 | stride=self.stride, padding=self.padding, use_bias=self.use_bias) 671 | 672 | out = self.conv(out) 673 | 674 | if self.normalization: 675 | if self.args.norm_layer == "batch_norm": 676 | self.norm_layer = MetaBatchNormLayer(out.shape[1], track_running_stats=True, 677 | meta_batch_norm=self.meta_layer, 678 | no_learnable_params=self.no_bn_learnable_params, 679 | device=self.device, 680 | use_per_step_bn_statistics=self.use_per_step_bn_statistics, 681 | args=self.args) 682 | elif self.args.norm_layer == "layer_norm": 683 | self.norm_layer = MetaLayerNormLayer(input_feature_shape=out.shape[1:]) 684 | 685 | out = self.norm_layer(out, num_step=0) 686 | 687 | out = F.leaky_relu(out) 688 | 689 | print(out.shape) 690 | 691 | def forward(self, x, num_step, params=None, training=False, backup_running_statistics=False): 692 | """ 693 | Forward propagates by applying the function. If params are none then internal params are used. 694 | Otherwise passed params will be used to execute the function. 695 | :param input: input data batch, size either can be any. 696 | :param num_step: The current inner loop step being taken. This is used when we are learning per step params and 697 | collecting per step batch statistics. It indexes the correct object to use for the current time-step 698 | :param params: A dictionary containing 'weight' and 'bias'. 699 | :param training: Whether this is currently the training or evaluation phase. 700 | :param backup_running_statistics: Whether to backup the running statistics. This is used 701 | at evaluation time, when after the pass is complete we want to throw away the collected validation stats. 702 | :return: The result of the batch norm operation. 703 | """ 704 | batch_norm_params = None 705 | conv_params = None 706 | activation_function_pre_params = None 707 | 708 | if params is not None: 709 | params = extract_top_level_dict(current_dict=params) 710 | 711 | if self.normalization: 712 | if 'norm_layer' in params: 713 | batch_norm_params = params['norm_layer'] 714 | 715 | if 'activation_function_pre' in params: 716 | activation_function_pre_params = params['activation_function_pre'] 717 | 718 | conv_params = params['conv'] 719 | 720 | out = x 721 | 722 | 723 | out = self.conv(out, params=conv_params) 724 | 725 | if self.normalization: 726 | out = self.norm_layer.forward(out, num_step=num_step, 727 | params=batch_norm_params, training=training, 728 | backup_running_statistics=backup_running_statistics) 729 | 730 | out = F.leaky_relu(out) 731 | 732 | return out 733 | 734 | def restore_backup_stats(self): 735 | """ 736 | Restore stored statistics from the backup, replacing the current ones. 737 | """ 738 | if self.normalization: 739 | self.norm_layer.restore_backup_stats() 740 | 741 | 742 | class MetaNormLayerConvReLU(nn.Module): 743 | def __init__(self, input_shape, num_filters, kernel_size, stride, padding, use_bias, args, normalization=True, 744 | meta_layer=True, no_bn_learnable_params=False, device=None): 745 | """ 746 | Initializes a BatchNorm->Conv->ReLU layer which applies those operation in that order. 747 | :param args: A named tuple containing the system's hyperparameters. 748 | :param device: The device to run the layer on. 749 | :param normalization: The type of normalization to use 'batch_norm' or 'layer_norm' 750 | :param meta_layer: Whether this layer will require meta-layer capabilities such as meta-batch norm, 751 | meta-conv etc. 752 | :param input_shape: The image input shape in the form (b, c, h, w) 753 | :param num_filters: number of filters for convolutional layer 754 | :param kernel_size: the kernel size of the convolutional layer 755 | :param stride: the stride of the convolutional layer 756 | :param padding: the bias of the convolutional layer 757 | :param use_bias: whether the convolutional layer utilizes a bias 758 | """ 759 | super(MetaNormLayerConvReLU, self).__init__() 760 | self.normalization = normalization 761 | self.use_per_step_bn_statistics = args.per_step_bn_statistics 762 | self.input_shape = input_shape 763 | self.args = args 764 | self.num_filters = num_filters 765 | self.kernel_size = kernel_size 766 | self.stride = stride 767 | self.padding = padding 768 | self.use_bias = use_bias 769 | self.meta_layer = meta_layer 770 | self.no_bn_learnable_params = no_bn_learnable_params 771 | self.device = device 772 | self.layer_dict = nn.ModuleDict() 773 | self.build_block() 774 | 775 | def build_block(self): 776 | 777 | x = torch.zeros(self.input_shape) 778 | 779 | out = x 780 | if self.normalization: 781 | if self.args.norm_layer == "batch_norm": 782 | self.norm_layer = MetaBatchNormLayer(self.input_shape[1], track_running_stats=True, 783 | meta_batch_norm=self.meta_layer, 784 | no_learnable_params=self.no_bn_learnable_params, 785 | device=self.device, 786 | use_per_step_bn_statistics=self.use_per_step_bn_statistics, 787 | args=self.args) 788 | elif self.args.norm_layer == "layer_norm": 789 | self.norm_layer = MetaLayerNormLayer(input_feature_shape=out.shape[1:]) 790 | 791 | out = self.norm_layer.forward(out, num_step=0) 792 | self.conv = MetaConv2dLayer(in_channels=out.shape[1], out_channels=self.num_filters, 793 | kernel_size=self.kernel_size, 794 | stride=self.stride, padding=self.padding, use_bias=self.use_bias) 795 | 796 | 797 | self.layer_dict['activation_function_pre'] = nn.LeakyReLU() 798 | 799 | 800 | out = self.layer_dict['activation_function_pre'].forward(self.conv.forward(out)) 801 | print(out.shape) 802 | 803 | def forward(self, x, num_step, params=None, training=False, backup_running_statistics=False): 804 | """ 805 | Forward propagates by applying the function. If params are none then internal params are used. 806 | Otherwise passed params will be used to execute the function. 807 | :param input: input data batch, size either can be any. 808 | :param num_step: The current inner loop step being taken. This is used when we are learning per step params and 809 | collecting per step batch statistics. It indexes the correct object to use for the current time-step 810 | :param params: A dictionary containing 'weight' and 'bias'. 811 | :param training: Whether this is currently the training or evaluation phase. 812 | :param backup_running_statistics: Whether to backup the running statistics. This is used 813 | at evaluation time, when after the pass is complete we want to throw away the collected validation stats. 814 | :return: The result of the batch norm operation. 815 | """ 816 | batch_norm_params = None 817 | 818 | if params is not None: 819 | params = extract_top_level_dict(current_dict=params) 820 | 821 | if self.normalization: 822 | if 'norm_layer' in params: 823 | batch_norm_params = params['norm_layer'] 824 | 825 | conv_params = params['conv'] 826 | else: 827 | conv_params = None 828 | #print('no inner loop params', self) 829 | 830 | out = x 831 | 832 | if self.normalization: 833 | out = self.norm_layer.forward(out, num_step=num_step, 834 | params=batch_norm_params, training=training, 835 | backup_running_statistics=backup_running_statistics) 836 | 837 | out = self.conv.forward(out, params=conv_params) 838 | out = self.layer_dict['activation_function_pre'].forward(out) 839 | 840 | return out 841 | 842 | def restore_backup_stats(self): 843 | """ 844 | Restore stored statistics from the backup, replacing the current ones. 845 | """ 846 | if self.normalization: 847 | self.norm_layer.restore_backup_stats() 848 | 849 | 850 | class VGGReLUNormNetwork(nn.Module): 851 | def __init__(self, im_shape, num_output_classes, args, device, meta_classifier=True): 852 | """ 853 | Builds a multilayer convolutional network. It also provides functionality for passing external parameters to be 854 | used at inference time. Enables inner loop optimization readily. 855 | :param im_shape: The input image batch shape. 856 | :param num_output_classes: The number of output classes of the network. 857 | :param args: A named tuple containing the system's hyperparameters. 858 | :param device: The device to run this on. 859 | :param meta_classifier: A flag indicating whether the system's meta-learning (inner-loop) functionalities should 860 | be enabled. 861 | """ 862 | super(VGGReLUNormNetwork, self).__init__() 863 | b, c, self.h, self.w = im_shape 864 | self.device = device 865 | self.total_layers = 0 866 | self.args = args 867 | self.upscale_shapes = [] 868 | self.cnn_filters = args.cnn_num_filters 869 | self.input_shape = list(im_shape) 870 | self.num_stages = args.num_stages 871 | self.num_output_classes = num_output_classes 872 | 873 | if args.max_pooling: 874 | print("Using max pooling") 875 | self.conv_stride = 1 876 | else: 877 | print("Using strided convolutions") 878 | self.conv_stride = 2 879 | self.meta_classifier = meta_classifier 880 | 881 | self.build_network() 882 | print("meta network params") 883 | for name, param in self.named_parameters(): 884 | print(name, param.shape) 885 | 886 | def build_network(self): 887 | """ 888 | Builds the network before inference is required by creating some dummy inputs with the same input as the 889 | self.im_shape tuple. Then passes that through the network and dynamically computes input shapes and 890 | sets output shapes for each layer. 891 | """ 892 | x = torch.zeros(self.input_shape) 893 | out = x 894 | self.layer_dict = nn.ModuleDict() 895 | self.upscale_shapes.append(x.shape) 896 | 897 | for i in range(self.num_stages): 898 | self.layer_dict['conv{}'.format(i)] = MetaConvNormLayerReLU(input_shape=out.shape, 899 | num_filters=self.cnn_filters, 900 | kernel_size=3, stride=self.conv_stride, 901 | padding=self.args.conv_padding, 902 | use_bias=True, args=self.args, 903 | normalization=True, 904 | meta_layer=self.meta_classifier, 905 | no_bn_learnable_params=False, 906 | device=self.device) 907 | out = self.layer_dict['conv{}'.format(i)](out, training=True, num_step=0) 908 | 909 | if self.args.max_pooling: 910 | out = F.max_pool2d(input=out, kernel_size=(2, 2), stride=2, padding=0) 911 | 912 | 913 | if not self.args.max_pooling: 914 | out = F.avg_pool2d(out, out.shape[2]) 915 | 916 | self.encoder_features_shape = list(out.shape) 917 | out = out.view(out.shape[0], -1) 918 | 919 | self.layer_dict['linear'] = MetaLinearLayer(input_shape=(out.shape[0], np.prod(out.shape[1:])), 920 | num_filters=self.num_output_classes, use_bias=True) 921 | 922 | out = self.layer_dict['linear'](out) 923 | print("VGGNetwork build", out.shape) 924 | 925 | def forward(self, x, num_step, params=None, training=False, backup_running_statistics=False): 926 | """ 927 | Forward propages through the network. If any params are passed then they are used instead of stored params. 928 | :param x: Input image batch. 929 | :param num_step: The current inner loop step number 930 | :param params: If params are None then internal parameters are used. If params are a dictionary with keys the 931 | same as the layer names then they will be used instead. 932 | :param training: Whether this is training (True) or eval time. 933 | :param backup_running_statistics: Whether to backup the running statistics in their backup store. Which is 934 | then used to reset the stats back to a previous state (usually after an eval loop, when we want to throw away stored statistics) 935 | :return: Logits of shape b, num_output_classes. 936 | """ 937 | param_dict = dict() 938 | 939 | if params is not None: 940 | params = {key: value[0] for key, value in params.items()} 941 | param_dict = extract_top_level_dict(current_dict=params) 942 | 943 | # print('top network', param_dict.keys()) 944 | for name, param in self.layer_dict.named_parameters(): 945 | path_bits = name.split(".") 946 | layer_name = path_bits[0] 947 | if layer_name not in param_dict: 948 | param_dict[layer_name] = None 949 | 950 | out = x 951 | 952 | for i in range(self.num_stages): 953 | out = self.layer_dict['conv{}'.format(i)](out, params=param_dict['conv{}'.format(i)], training=training, 954 | backup_running_statistics=backup_running_statistics, 955 | num_step=num_step) 956 | if self.args.max_pooling: 957 | out = F.max_pool2d(input=out, kernel_size=(2, 2), stride=2, padding=0) 958 | 959 | if not self.args.max_pooling: 960 | out = F.avg_pool2d(out, out.shape[2]) 961 | 962 | out = out.view(out.size(0), -1) 963 | out = self.layer_dict['linear'](out, param_dict['linear']) 964 | 965 | return out 966 | 967 | def re_init(self): 968 | #for param in self.parameters(): 969 | for name, param in self.named_parameters(): 970 | if param.requires_grad and 'weight' in name and 'norm' not in name: 971 | nn.init.xavier_uniform_(param) 972 | 973 | def zero_grad(self, params=None): 974 | if params is None: 975 | for param in self.parameters(): 976 | if param.requires_grad == True: 977 | if param.grad is not None: 978 | if torch.sum(param.grad) > 0: 979 | print(param.grad) 980 | param.grad.zero_() 981 | else: 982 | for name, param in params.items(): 983 | if param.requires_grad == True: 984 | if param.grad is not None: 985 | if torch.sum(param.grad) > 0: 986 | print(param.grad) 987 | param.grad.zero_() 988 | params[name].grad = None 989 | 990 | def restore_backup_stats(self): 991 | """ 992 | Reset stored batch statistics from the stored backup. 993 | """ 994 | for i in range(self.num_stages): 995 | self.layer_dict['conv{}'.format(i)].restore_backup_stats() 996 | 997 | 998 | class ResNet12(nn.Module): 999 | def __init__(self, im_shape, num_output_classes, args, device, meta_classifier=True): 1000 | """ 1001 | Builds a multilayer convolutional network. It also provides functionality for passing external parameters to be 1002 | used at inference time. Enables inner loop optimization readily. 1003 | :param im_shape: The input image batch shape. 1004 | :param num_output_classes: The number of output classes of the network. 1005 | :param args: A named tuple containing the system's hyperparameters. 1006 | :param device: The device to run this on. 1007 | :param meta_classifier: A flag indicating whether the system's meta-learning (inner-loop) functionalities should 1008 | be enabled. 1009 | """ 1010 | super(ResNet12, self).__init__() 1011 | b, c, self.h, self.w = im_shape 1012 | self.device = device 1013 | self.total_layers = 0 1014 | self.args = args 1015 | self.upscale_shapes = [] 1016 | self.cnn_filters = args.cnn_num_filters 1017 | self.input_shape = list(im_shape) 1018 | self.num_stages = args.num_stages 1019 | self.num_output_classes = num_output_classes 1020 | 1021 | if args.max_pooling: 1022 | print("Using max pooling") 1023 | self.conv_stride = 1 1024 | else: 1025 | print("Using strided convolutions") 1026 | self.conv_stride = 2 1027 | self.meta_classifier = meta_classifier 1028 | 1029 | self.build_network() 1030 | print("meta network params") 1031 | for name, param in self.named_parameters(): 1032 | print(name, param.shape) 1033 | 1034 | def build_network(self): 1035 | """ 1036 | Builds the network before inference is required by creating some dummy inputs with the same input as the 1037 | self.im_shape tuple. Then passes that through the network and dynamically computes input shapes and 1038 | sets output shapes for each layer. 1039 | """ 1040 | x = torch.zeros(self.input_shape) 1041 | out = x 1042 | self.layer_dict = nn.ModuleDict() 1043 | self.upscale_shapes.append(x.shape) 1044 | 1045 | num_chn = [64, 128, 256, 512] 1046 | max_padding = [0, 0, 1, 1] 1047 | maxpool = [True,True,True,False] 1048 | for i in range(len(num_chn)): 1049 | self.layer_dict['layer{}'.format(i)] = MetaMaxResLayerReLU(input_shape=out.shape, 1050 | num_filters=num_chn[i], 1051 | kernel_size=3, stride=1, 1052 | padding=1, 1053 | use_bias=False, args=self.args, 1054 | #use_bias=True, args=self.args, 1055 | normalization=True, 1056 | meta_layer=self.meta_classifier, 1057 | no_bn_learnable_params=False, 1058 | device=self.device, 1059 | downsample=False, 1060 | max_padding=max_padding[i], 1061 | maxpool=maxpool[i]) 1062 | out = self.layer_dict['layer{}'.format(i)](out, training=True, num_step=0) 1063 | 1064 | out = F.adaptive_avg_pool2d(out, (1,1)) 1065 | 1066 | out = out.view(out.shape[0], -1) 1067 | 1068 | self.layer_dict['linear'] = MetaLinearLayer(input_shape=(out.shape[0], np.prod(out.shape[1:])), 1069 | num_filters=self.num_output_classes, use_bias=True) 1070 | 1071 | out = self.layer_dict['linear'](out) 1072 | print("ResNet12 build", out.shape) 1073 | 1074 | def forward(self, x, num_step, params=None, training=False, backup_running_statistics=False): 1075 | """ 1076 | Forward propages through the network. If any params are passed then they are used instead of stored params. 1077 | :param x: Input image batch. 1078 | :param num_step: The current inner loop step number 1079 | :param params: If params are None then internal parameters are used. If params are a dictionary with keys the 1080 | same as the layer names then they will be used instead. 1081 | :param training: Whether this is training (True) or eval time. 1082 | :param backup_running_statistics: Whether to backup the running statistics in their backup store. Which is 1083 | then used to reset the stats back to a previous state (usually after an eval loop, when we want to throw away stored statistics) 1084 | :return: Logits of shape b, num_output_classes. 1085 | """ 1086 | param_dict = dict() 1087 | 1088 | if params is not None: 1089 | #param_dict = parallel_extract_top_level_dict(current_dict=params) 1090 | 1091 | params = {key: value[0] for key, value in params.items()} 1092 | param_dict = extract_top_level_dict(current_dict=params) 1093 | 1094 | # print('top network', param_dict.keys()) 1095 | for name, param in self.layer_dict.named_parameters(): 1096 | path_bits = name.split(".") 1097 | layer_name = path_bits[0] 1098 | if layer_name not in param_dict: 1099 | param_dict[layer_name] = None 1100 | 1101 | out = x 1102 | 1103 | for i in range(self.num_stages): 1104 | out = self.layer_dict['layer{}'.format(i)](out, params=param_dict['layer{}'.format(i)], training=training, 1105 | backup_running_statistics=backup_running_statistics, 1106 | num_step=num_step) 1107 | 1108 | out = F.adaptive_avg_pool2d(out, (1,1)) 1109 | out = out.view(out.size(0), -1) 1110 | out = self.layer_dict['linear'](out, param_dict['linear']) 1111 | 1112 | return out 1113 | 1114 | def zero_grad(self, params=None): 1115 | if params is None: 1116 | for param in self.parameters(): 1117 | if param.requires_grad == True: 1118 | if param.grad is not None: 1119 | if torch.sum(param.grad) > 0: 1120 | print(param.grad) 1121 | param.grad.zero_() 1122 | else: 1123 | for name, param in params.items(): 1124 | if param.requires_grad == True: 1125 | if param.grad is not None: 1126 | if torch.sum(param.grad) > 0: 1127 | print(param.grad) 1128 | param.grad.zero_() 1129 | params[name].grad = None 1130 | 1131 | def restore_backup_stats(self): 1132 | """ 1133 | Reset stored batch statistics from the stored backup. 1134 | """ 1135 | #self.layer_dict['conv0'].restore_backup_stats() 1136 | for i in range(self.num_stages): 1137 | self.layer_dict['layer{}'.format(i)].restore_backup_stats() 1138 | -------------------------------------------------------------------------------- /train_maml_system.py: -------------------------------------------------------------------------------- 1 | from data import MetaLearningSystemDataLoader 2 | from experiment_builder import ExperimentBuilder 3 | from few_shot_learning_system import MAMLFewShotClassifier 4 | from utils.parser_utils import get_args 5 | from utils.dataset_tools import maybe_unzip_dataset 6 | 7 | # Combines the arguments, model, data and experiment builders to run an experiment 8 | args, device = get_args() 9 | model = MAMLFewShotClassifier(args=args, device=device, 10 | im_shape=(2, 3, 11 | args.image_height, args.image_width)) 12 | maybe_unzip_dataset(args=args) 13 | data = MetaLearningSystemDataLoader 14 | maml_system = ExperimentBuilder(model=model, data=data, args=args, device=device) 15 | maml_system.run_experiment() 16 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baiksung/ALFA/25cd381932812c99b7542cd11e9b45ae8ca7f125/utils/__init__.py -------------------------------------------------------------------------------- /utils/dataset_tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | def maybe_unzip_dataset(args): 5 | 6 | datasets = [args.dataset_name] 7 | dataset_paths = [args.dataset_path] 8 | done = False 9 | 10 | for dataset_idx, dataset_path in enumerate(dataset_paths): 11 | if dataset_path.endswith('/'): 12 | dataset_path = dataset_path[:-1] 13 | print(dataset_path) 14 | if not os.path.exists(dataset_path): 15 | print("Not found dataset folder structure.. searching for .tar.bz2 file") 16 | zip_directory = "{}.tar.bz2".format(os.path.join(os.environ['DATASET_DIR'], datasets[dataset_idx])) 17 | 18 | assert os.path.exists(os.path.abspath(zip_directory)), "{} dataset zip file not found" \ 19 | "place dataset in datasets folder as explained in README".format(os.path.abspath(zip_directory)) 20 | print("Found zip file, unpacking") 21 | 22 | unzip_file(filepath_pack=os.path.join(os.environ['DATASET_DIR'], "{}.tar.bz2".format(datasets[dataset_idx])), 23 | filepath_to_store=os.environ['DATASET_DIR']) 24 | 25 | 26 | 27 | args.reset_stored_filepaths = True 28 | 29 | total_files = 0 30 | for subdir, dir, files in os.walk(dataset_path): 31 | for file in files: 32 | if file.lower().endswith(".jpeg") or file.lower().endswith(".jpg") or file.lower().endswith( 33 | ".png") or file.lower().endswith(".pkl"): 34 | total_files += 1 35 | print("count stuff________________________________________", total_files) 36 | if (total_files == 1623 * 20 and datasets[dataset_idx] == 'omniglot_dataset') or ( 37 | total_files == 100 * 600 and 'mini_imagenet' in datasets[dataset_idx]) or ( 38 | total_files == 3 and 'mini_imagenet_pkl' in datasets[dataset_idx]): 39 | print("file count is correct") 40 | done = True 41 | elif datasets[dataset_idx] != 'omniglot_dataset' and datasets[dataset_idx] != 'mini_imagenet' and datasets[dataset_idx] != 'mini_imagenet_pkl': 42 | done = True 43 | print("using new dataset") 44 | 45 | if not done: 46 | shutil.rmtree(dataset_path, ignore_errors=True) 47 | maybe_unzip_dataset(args) 48 | 49 | 50 | def unzip_file(filepath_pack, filepath_to_store): 51 | command_to_run = "tar -I pbzip2 -xf {} -C {}".format(filepath_pack, filepath_to_store) 52 | os.system(command_to_run) 53 | -------------------------------------------------------------------------------- /utils/parser_utils.py: -------------------------------------------------------------------------------- 1 | from torch import cuda 2 | 3 | 4 | def get_args(): 5 | import argparse 6 | import os 7 | import torch 8 | import json 9 | parser = argparse.ArgumentParser(description='Welcome to the L2F training and inference system') 10 | 11 | parser.add_argument('--batch_size', nargs="?", type=int, default=32, help='Batch_size for experiment') 12 | parser.add_argument('--image_height', nargs="?", type=int, default=28) 13 | parser.add_argument('--image_width', nargs="?", type=int, default=28) 14 | parser.add_argument('--image_channels', nargs="?", type=int, default=1) 15 | parser.add_argument('--reset_stored_filepaths', type=str, default="False") 16 | parser.add_argument('--reverse_channels', type=str, default="False") 17 | parser.add_argument('--num_of_gpus', type=int, default=1) 18 | parser.add_argument('--indexes_of_folders_indicating_class', nargs='+', default=[-2, -3]) 19 | parser.add_argument('--train_val_test_split', nargs='+', default=[0.73982737361, 0.26, 0.13008631319]) 20 | parser.add_argument('--samples_per_iter', nargs="?", type=int, default=1) 21 | parser.add_argument('--labels_as_int', type=str, default="False") 22 | parser.add_argument('--seed', type=int, default=104) 23 | 24 | parser.add_argument('--gpu_to_use', type=int) 25 | parser.add_argument('--num_dataprovider_workers', nargs="?", type=int, default=4) 26 | parser.add_argument('--max_models_to_save', nargs="?", type=int, default=5) 27 | parser.add_argument('--dataset_name', type=str, default="omniglot_dataset") 28 | parser.add_argument('--dataset_path', type=str, default="datasets/omniglot_dataset") 29 | parser.add_argument('--reset_stored_paths', type=str, default="False") 30 | parser.add_argument('--experiment_name', nargs="?", type=str, ) 31 | parser.add_argument('--architecture_name', nargs="?", type=str) 32 | parser.add_argument('--continue_from_epoch', nargs="?", type=str, default='latest', help='Continue from checkpoint of epoch') 33 | parser.add_argument('--dropout_rate_value', type=float, default=0.3, help='Dropout_rate_value') 34 | parser.add_argument('--num_target_samples', type=int, default=15, help='Dropout_rate_value') 35 | parser.add_argument('--second_order', type=str, default="False", help='Dropout_rate_value') 36 | parser.add_argument('--total_epochs', type=int, default=200, help='Number of epochs per experiment') 37 | parser.add_argument('--total_iter_per_epoch', type=int, default=500, help='Number of iters per epoch') 38 | parser.add_argument('--min_learning_rate', type=float, default=0.00001, help='Min learning rate') 39 | parser.add_argument('--meta_learning_rate', type=float, default=0.001, help='Learning rate of overall MAML system') 40 | parser.add_argument('--meta_opt_bn', type=str, default="False") 41 | parser.add_argument('--task_learning_rate', type=float, default=0.1, help='Learning rate per task gradient step') 42 | 43 | parser.add_argument('--norm_layer', type=str, default="batch_norm") 44 | parser.add_argument('--max_pooling', type=str, default="False") 45 | parser.add_argument('--per_step_bn_statistics', type=str, default="False") 46 | parser.add_argument('--num_classes_per_set', type=int, default=20, help='Number of classes to sample per set') 47 | parser.add_argument('--cnn_num_blocks', type=int, default=4, help='Number of classes to sample per set') 48 | parser.add_argument('--number_of_training_steps_per_iter', type=int, default=1, help='Number of classes to sample per set') 49 | parser.add_argument('--number_of_evaluation_steps_per_iter', type=int, default=1, help='Number of classes to sample per set') 50 | parser.add_argument('--cnn_num_filters', type=int, default=64, help='Number of classes to sample per set') 51 | parser.add_argument('--cnn_blocks_per_stage', type=int, default=1, 52 | help='Number of classes to sample per set') 53 | parser.add_argument('--num_samples_per_class', type=int, default=1, help='Number of samples per set to sample') 54 | parser.add_argument('--name_of_args_json_file', type=str, default="None") 55 | 56 | # Architecture Backbone 57 | parser.add_argument('--backbone', type=str, default="4-CONV", help='Base learner architecture backbone') 58 | 59 | # L2F 60 | parser.add_argument('--attenuate', type=str, default="False", help='Whether to attenuate the initialization (for L2F)') 61 | 62 | # ALFA 63 | parser.add_argument('--alfa', type=str, default="False", help='Whether to perform adaptive inner-loop optimization') 64 | parser.add_argument('--random_init', type=str, default="False", help='Whether to use random initialization') 65 | 66 | args = parser.parse_args() 67 | args_dict = vars(args) 68 | if args.name_of_args_json_file is not "None": 69 | args_dict = extract_args_from_json(args.name_of_args_json_file, args_dict) 70 | 71 | for key in list(args_dict.keys()): 72 | 73 | if str(args_dict[key]).lower() == "true": 74 | args_dict[key] = True 75 | elif str(args_dict[key]).lower() == "false": 76 | args_dict[key] = False 77 | if key == "dataset_path": 78 | args_dict[key] = os.path.join(os.environ['DATASET_DIR'], args_dict[key]) 79 | print(key, os.path.join(os.environ['DATASET_DIR'], args_dict[key])) 80 | 81 | print(key, args_dict[key], type(args_dict[key])) 82 | 83 | args = Bunch(args_dict) 84 | 85 | 86 | args.use_cuda = torch.cuda.is_available() 87 | if torch.cuda.is_available(): # checks whether a cuda gpu is available and whether the gpu flag is True 88 | device = torch.cuda.current_device() 89 | 90 | print("use GPU", device) 91 | print("GPU ID {}".format(torch.cuda.current_device())) 92 | 93 | else: 94 | print("use CPU") 95 | device = torch.device('cpu') # sets the device to be CPU 96 | 97 | 98 | return args, device 99 | 100 | 101 | 102 | class Bunch(object): 103 | def __init__(self, adict): 104 | self.__dict__.update(adict) 105 | 106 | def extract_args_from_json(json_file_path, args_dict): 107 | import json 108 | summary_filename = json_file_path 109 | with open(summary_filename) as f: 110 | summary_dict = json.load(fp=f) 111 | 112 | for key in summary_dict.keys(): 113 | if "continue_from" in key: 114 | pass 115 | elif "gpu_to_use" in key: 116 | pass 117 | else: 118 | args_dict[key] = summary_dict[key] 119 | 120 | return args_dict 121 | 122 | 123 | 124 | 125 | 126 | -------------------------------------------------------------------------------- /utils/storage.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import datetime 3 | import os 4 | import numpy as np 5 | from utils.parser_utils import get_args 6 | import json 7 | 8 | def save_to_json(filename, dict_to_store): 9 | with open(os.path.abspath(filename), 'w') as f: 10 | json.dump(dict_to_store, fp=f) 11 | 12 | def load_from_json(filename): 13 | with open(filename, mode="r") as f: 14 | load_dict = json.load(fp=f) 15 | 16 | return load_dict 17 | 18 | def save_statistics(experiment_name, line_to_add, filename="summary_statistics.csv", create=False): 19 | summary_filename = "{}/{}".format(experiment_name, filename) 20 | if create: 21 | with open(summary_filename, 'w') as f: 22 | writer = csv.writer(f) 23 | writer.writerow(line_to_add) 24 | else: 25 | with open(summary_filename, 'a') as f: 26 | writer = csv.writer(f) 27 | writer.writerow(line_to_add) 28 | 29 | return summary_filename 30 | 31 | def load_statistics(experiment_name, filename="summary_statistics.csv"): 32 | data_dict = dict() 33 | summary_filename = "{}/{}".format(experiment_name, filename) 34 | with open(summary_filename, 'r') as f: 35 | lines = f.readlines() 36 | data_labels = lines[0].replace("\n", "").split(",") 37 | del lines[0] 38 | 39 | for label in data_labels: 40 | data_dict[label] = [] 41 | 42 | for line in lines: 43 | data = line.replace("\n", "").split(",") 44 | for key, item in zip(data_labels, data): 45 | data_dict[key].append(item) 46 | return data_dict 47 | 48 | 49 | def build_experiment_folder(experiment_name): 50 | experiment_path = os.path.abspath(experiment_name) 51 | saved_models_filepath = "{}/{}".format(experiment_path, "saved_models") 52 | logs_filepath = "{}/{}".format(experiment_path, "logs") 53 | samples_filepath = "{}/{}".format(experiment_path, "visual_outputs") 54 | 55 | if not os.path.exists(experiment_path): 56 | os.makedirs(experiment_path) 57 | if not os.path.exists(logs_filepath): 58 | os.makedirs(logs_filepath) 59 | if not os.path.exists(samples_filepath): 60 | os.makedirs(samples_filepath) 61 | if not os.path.exists(saved_models_filepath): 62 | os.makedirs(saved_models_filepath) 63 | 64 | outputs = (saved_models_filepath, logs_filepath, samples_filepath) 65 | outputs = (os.path.abspath(item) for item in outputs) 66 | return outputs 67 | 68 | def get_best_validation_model_statistics(experiment_name, filename="summary_statistics.csv"): 69 | """ 70 | Returns the best val epoch and val accuracy from a log csv file 71 | :param log_dir: The log directory the file is saved in 72 | :param statistics_file_name: The log file name 73 | :return: The best validation accuracy and the epoch at which it is produced 74 | """ 75 | log_file_dict = load_statistics(filename=filename, experiment_name=experiment_name) 76 | d_val_loss = np.array(log_file_dict['total_d_val_loss_mean'], dtype=np.float32) 77 | best_d_val_loss = np.min(d_val_loss) 78 | best_d_val_epoch = np.argmin(d_val_loss) 79 | 80 | return best_d_val_loss, best_d_val_epoch 81 | 82 | def create_json_experiment_log(experiment_log_dir, args, log_name="experiment_log.json"): 83 | summary_filename = "{}/{}".format(experiment_log_dir, log_name) 84 | 85 | experiment_summary_dict = dict() 86 | 87 | for key, value in vars(args).items(): 88 | experiment_summary_dict[key] = value 89 | 90 | experiment_summary_dict["epoch_stats"] = dict() 91 | timestamp = datetime.datetime.now().timestamp() 92 | experiment_summary_dict["experiment_status"] = [(timestamp, "initialization")] 93 | experiment_summary_dict["experiment_initialization_time"] = timestamp 94 | with open(os.path.abspath(summary_filename), 'w') as f: 95 | json.dump(experiment_summary_dict, fp=f) 96 | 97 | def update_json_experiment_log_dict(key, value, experiment_log_dir, log_name="experiment_log.json"): 98 | summary_filename = "{}/{}".format(experiment_log_dir, log_name) 99 | with open(summary_filename) as f: 100 | summary_dict = json.load(fp=f) 101 | 102 | summary_dict[key].append(value) 103 | 104 | with open(summary_filename, 'w') as f: 105 | json.dump(summary_dict, fp=f) 106 | 107 | def change_json_log_experiment_status(experiment_status, experiment_log_dir, log_name="experiment_log.json"): 108 | timestamp = datetime.datetime.now().timestamp() 109 | experiment_status = (timestamp, experiment_status) 110 | update_json_experiment_log_dict(key="experiment_status", value=experiment_status, 111 | experiment_log_dir=experiment_log_dir, log_name=log_name) 112 | 113 | def update_json_experiment_log_epoch_stats(epoch_stats, experiment_log_dir, log_name="experiment_log.json"): 114 | summary_filename = "{}/{}".format(experiment_log_dir, log_name) 115 | with open(summary_filename) as f: 116 | summary_dict = json.load(fp=f) 117 | 118 | epoch_stats_dict = summary_dict["epoch_stats"] 119 | 120 | for key in epoch_stats.keys(): 121 | entry = float(epoch_stats[key]) 122 | if key in epoch_stats_dict: 123 | epoch_stats_dict[key].append(entry) 124 | else: 125 | epoch_stats_dict[key] = [entry] 126 | 127 | summary_dict['epoch_stats'] = epoch_stats_dict 128 | 129 | with open(summary_filename, 'w') as f: 130 | json.dump(summary_dict, fp=f) 131 | return summary_filename 132 | --------------------------------------------------------------------------------