├── .gitignore ├── EfficientLab-6-3_FOMAML-star_checkpoint.tar.gz ├── LICENSE ├── README.md ├── augmenters ├── __init__.py └── np_augmenters.py ├── data ├── __init__.py ├── fp-k_test_set.txt ├── fss_1000_image_to_joint_tfrecord_shards.py ├── fss_1000_image_to_tfrecord.py ├── fss_1000_utils.py ├── fss_test_set.txt ├── fss_train_set.txt └── input_fn.py ├── figures ├── EfficientLab.png └── example_5-shot_predictions.png ├── joint_train.py ├── joint_train ├── __init__.py └── data │ ├── __init__.py │ ├── constants.py │ └── input_fn.py ├── make_python_virtualenv.sh ├── meta_learners ├── __init__.py ├── args.py ├── hyperparam_search.py ├── metaseg.py ├── supervised_reptile │ ├── LICENSE │ ├── __init__.py │ └── supervised_reptile │ │ ├── __init__.py │ │ ├── eval.py │ │ ├── miniimagenet.py │ │ ├── models.py │ │ ├── omniglot.py │ │ ├── reptile.py │ │ └── train.py └── variables.py ├── models ├── __init__.py ├── constants.py ├── efficientlab.py ├── efficientnet │ ├── __init__.py │ ├── constants.py │ ├── efficientnet_builder.py │ ├── efficientnet_model.py │ └── utils.py ├── lr_schedulers.py └── regularizers.py ├── requirements.txt ├── run.sh ├── run_metasegnet.py ├── setup.py └── utils ├── __init__.py ├── debug_tf_dataset.py ├── util.py └── viz.py /.gitignore: -------------------------------------------------------------------------------- 1 | supervised-reptile/supervised_reptile/dev/ 2 | *.DS_Store 3 | 4 | ## Core latex/pdflatex auxiliary files: 5 | *.aux 6 | *.lof 7 | *.log 8 | *.lot 9 | *.fls 10 | *.out 11 | *.toc 12 | *.fmt 13 | *.fot 14 | *.cb 15 | *.cb2 16 | .*.lb 17 | 18 | ## Intermediate documents: 19 | *.dvi 20 | *.xdv 21 | *-converted-to.* 22 | # these rules might exclude image files for figures etc. 23 | # *.ps 24 | # *.eps 25 | # *.pdf 26 | 27 | ## Generated if empty string is given at "Please type another file name for output:" 28 | # *.pdf 29 | 30 | ## Bibliography auxiliary files (bibtex/biblatex/biber): 31 | *.bbl 32 | *.bcf 33 | *.blg 34 | *-blx.aux 35 | *-blx.bib 36 | *.run.xml 37 | 38 | ## Build tool auxiliary files: 39 | *.fdb_latexmk 40 | *.synctex 41 | *.synctex(busy) 42 | *.synctex.gz 43 | *.synctex.gz(busy) 44 | *.pdfsync 45 | 46 | ## Build tool directories for auxiliary files 47 | # latexrun 48 | latex.out/ 49 | 50 | ## Auxiliary and intermediate files from other packages: 51 | # algorithms 52 | *.alg 53 | *.loa 54 | 55 | # achemso 56 | acs-*.bib 57 | 58 | # amsthm 59 | *.thm 60 | 61 | # beamer 62 | *.nav 63 | *.pre 64 | *.snm 65 | *.vrb 66 | 67 | # changes 68 | *.soc 69 | 70 | # comment 71 | *.cut 72 | 73 | # cprotect 74 | *.cpt 75 | 76 | # elsarticle (documentclass of Elsevier journals) 77 | *.spl 78 | 79 | # endnotes 80 | *.ent 81 | 82 | # fixme 83 | *.lox 84 | 85 | # feynmf/feynmp 86 | *.mf 87 | *.mp 88 | *.t[1-9] 89 | *.t[1-9][0-9] 90 | *.tfm 91 | 92 | #(r)(e)ledmac/(r)(e)ledpar 93 | *.end 94 | *.?end 95 | *.[1-9] 96 | *.[1-9][0-9] 97 | *.[1-9][0-9][0-9] 98 | *.[1-9]R 99 | *.[1-9][0-9]R 100 | *.[1-9][0-9][0-9]R 101 | *.eledsec[1-9] 102 | *.eledsec[1-9]R 103 | *.eledsec[1-9][0-9] 104 | *.eledsec[1-9][0-9]R 105 | *.eledsec[1-9][0-9][0-9] 106 | *.eledsec[1-9][0-9][0-9]R 107 | 108 | # glossaries 109 | *.acn 110 | *.acr 111 | *.glg 112 | *.glo 113 | *.gls 114 | *.glsdefs 115 | *.lzo 116 | *.lzs 117 | 118 | # uncomment this for glossaries-extra (will ignore makeindex's style files!) 
119 | # *.ist 120 | 121 | # gnuplottex 122 | *-gnuplottex-* 123 | 124 | # gregoriotex 125 | *.gaux 126 | *.gtex 127 | 128 | # htlatex 129 | *.4ct 130 | *.4tc 131 | *.idv 132 | *.lg 133 | *.trc 134 | *.xref 135 | 136 | # hyperref 137 | *.brf 138 | 139 | # knitr 140 | *-concordance.tex 141 | # TODO Comment the next line if you want to keep your tikz graphics files 142 | *.tikz 143 | *-tikzDictionary 144 | 145 | # listings 146 | *.lol 147 | 148 | # luatexja-ruby 149 | *.ltjruby 150 | 151 | # makeidx 152 | *.idx 153 | *.ilg 154 | *.ind 155 | 156 | # minitoc 157 | *.maf 158 | *.mlf 159 | *.mlt 160 | *.mtc[0-9]* 161 | *.slf[0-9]* 162 | *.slt[0-9]* 163 | *.stc[0-9]* 164 | 165 | # minted 166 | _minted* 167 | *.pyg 168 | 169 | # morewrites 170 | *.mw 171 | 172 | # nomencl 173 | *.nlg 174 | *.nlo 175 | *.nls 176 | 177 | # pax 178 | *.pax 179 | 180 | # pdfpcnotes 181 | *.pdfpc 182 | 183 | # sagetex 184 | *.sagetex.sage 185 | *.sagetex.py 186 | *.sagetex.scmd 187 | 188 | # scrwfile 189 | *.wrt 190 | 191 | # sympy 192 | *.sout 193 | *.sympy 194 | sympy-plots-for-*.tex/ 195 | 196 | # pdfcomment 197 | *.upa 198 | *.upb 199 | 200 | # pythontex 201 | *.pytxcode 202 | pythontex-files-*/ 203 | 204 | # tcolorbox 205 | *.listing 206 | 207 | # thmtools 208 | *.loe 209 | 210 | # TikZ & PGF 211 | *.dpth 212 | *.md5 213 | *.auxlock 214 | 215 | # todonotes 216 | *.tdo 217 | 218 | # vhistory 219 | *.hst 220 | *.ver 221 | 222 | # easy-todo 223 | *.lod 224 | 225 | # xcolor 226 | *.xcp 227 | 228 | # xmpincl 229 | *.xmpi 230 | 231 | # xindy 232 | *.xdy 233 | 234 | # xypic precompiled matrices and outlines 235 | *.xyc 236 | *.xyd 237 | 238 | # endfloat 239 | *.ttt 240 | *.fff 241 | 242 | # Latexian 243 | TSWLatexianTemp* 244 | 245 | ## Editors: 246 | # WinEdt 247 | *.bak 248 | *.sav 249 | 250 | # Texpad 251 | .texpadtmp 252 | 253 | # LyX 254 | *.lyx~ 255 | 256 | # Kile 257 | *.backup 258 | 259 | # gummi 260 | .*.swp 261 | 262 | # auto folder when using emacs and auctex 263 | ./auto/* 264 | *.el 265 | 266 | # expex forward references with \gathertags 267 | *-tags.tex 268 | 269 | # standalone packages 270 | *.sta 271 | 272 | # Makeindex log files 273 | *.lpz 274 | -------------------------------------------------------------------------------- /EfficientLab-6-3_FOMAML-star_checkpoint.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml4ai/mliis/f40352e734f77609bcd5c4ad330ea73a897a217d/EfficientLab-6-3_FOMAML-star_checkpoint.tar.gz -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Sean M. Hendryx 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Meta-Learning Initializations for Image Segmentation
2 | 
3 | Code for meta-learning and evaluating initializations for image segmentation as described in our paper, which was [presented at the 4th Workshop on Meta-Learning at NeurIPS 2020](https://meta-learn.github.io/2020/papers/44_paper.pdf).
4 | 
5 | Note that this repository is in archival status. Code is provided as-is and no updates are expected.
6 | 
7 | Example 5-shot predictions on test samples from meta-test tasks:
8 | 
9 | ![5-shot](figures/example_5-shot_predictions.png)
10 | 
11 | 
12 | ## Citing
13 | If you find this project useful in your research, please consider citing:
14 | 
15 | ```
16 | @article{hendryx2019meta,
17 |   title={Meta-Learning Initializations for Image Segmentation},
18 |   author={Hendryx, Sean M and Leach, Andrew B and Hein, Paul D and Morrison, Clayton T},
19 |   journal={4th Workshop on Meta-Learning at NeurIPS 2020},
20 |   year={2020},
21 | }
22 | ```
23 | 
24 | ## Setup
25 | 
26 | 
27 | We have included a `requirements.txt` file with dependencies. You can also see `make_python_virtualenv.sh` for recommended steps for setting up your environment.
28 | 
29 | You can download the FSS-1000 meta-training and evaluation tfrecord shards from:
30 | https://drive.google.com/open?id=1aGHP0ev_1eAFSnYtN0ObDI-DnB0TsQUU
31 | 
32 | 
33 | And the joint-training shards from:
34 | https://drive.google.com/open?id=1aQpyQ0CEBCL9EW8xoCaI6xveYxtXNYKq
35 | 
36 | The FP-k dataset shards are available at:
37 | https://drive.google.com/open?id=1G1NJIyQlkxAb4vlsRDPR3W3If_RJ4rPd
38 | 
39 | The FP-k dataset is derived from the [FSS-1000](https://github.com/HKUSTCV/FSS-1000) and PASCAL-5i datasets. PASCAL-5i was in turn derived from the parent datasets [PASCAL](http://host.robots.ox.ac.uk/pascal/VOC/) and [Semantic Boundaries Dataset](http://home.bharathh.info/pubs/codes/SBD/download.html) as described in
40 | [One-Shot Learning for Semantic Segmentation](https://arxiv.org/abs/1709.03410).
41 | 
42 | We created our meta-training tfrecord shards by following these steps:
43 | 1. Download the FSS-1000 dataset from https://github.com/HKUSTCV/FSS-1000
44 | 2. Convert the images and masks to tfrecords:
45 | ```
46 | python fss_1000_image_to_tfrecord.py --input_dir fewshot_data --tfrecord_dir fewshot_shards/
47 | ```
48 | 
49 | ## Run the SOTA evaluation
50 | 
51 | Extract the checkpoint:
52 | ```
53 | tar -xzvf EfficientLab-6-3_FOMAML-star_checkpoint.tar.gz
54 | ```
55 | 
56 | Put the FSS-1000 meta-training and evaluation tfrecord shards at the root of this repo or edit the `data_dir` path in `run.sh` to point to the shards on your machine.
57 | 
58 | Finally, call:
59 | ```
60 | ./run.sh
61 | ```
62 | 
63 | ## Run an experiment
64 | 
65 | The main entry point of this codebase is:
66 | ```
67 | python run_metasegnet.py
68 | ```
69 | 
70 | See `meta_learners/args.py` for arguments and their descriptions.
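Before launching a run, you can sanity-check a converted task shard with the loader in `data/input_fn.py`. A minimal sketch (TF1-style; the shard path is a placeholder for any task you converted):

```
import tensorflow as tf
from data.input_fn import load_from_tfrecords

# "abacus" is a placeholder task name; point this at any converted shard.
examples = load_from_tfrecords(["fewshot_shards/abacus.tfrecord.gzip"], image_width=224)
with tf.Session() as sess:
    image, mask = examples[0]
    print(sess.run(tf.shape(image)))  # [224 224 3], float32 in [0, 255]
    print(sess.run(tf.shape(mask)))   # [224 224 2]: background and foreground channels
```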
71 | 
72 | Our SOTA meta-learned initialization that generated the best FSS-1000 results reported in our paper is in this repository at `EfficientLab-6-3_FOMAML-star_checkpoint`.
73 | 
74 | ## Visualize predictions
75 | To see predictions, set the environment variable like so:
76 | 
77 | ```
78 | export SAVE_PREDICTIONS=1
79 | ```
80 | 
81 | ## Save an adapted model
82 | To save the weights of an updated model on your task(s), run:
83 | ```
84 | python run_metasegnet.py --save_fine_tuned_checkpoints --save_fine_tuned_checkpoints_dir /path/to/save/to <--other_args>
85 | ```
86 | See `run.sh` for our recommended hyperparameters, found via hyperparameter optimization of the update routine.
87 | 
88 | ## EfficientLab
89 | Our SOTA network architecture class is defined in `models/efficientlab.py`.
90 | 
91 | ![EfficientLab](figures/EfficientLab.png)
92 | 
93 | 
94 | ## Acknowledgements
95 | This repository builds on the [Reptile implementation by OpenAI](https://github.com/openai/supervised-reptile) and the [EfficientNet backbone implementation by Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet).
96 | 
--------------------------------------------------------------------------------
/augmenters/__init__.py:
--------------------------------------------------------------------------------
1 | """Image segmentation augmentation functions"""
--------------------------------------------------------------------------------
/augmenters/np_augmenters.py:
--------------------------------------------------------------------------------
1 | """Image augmentations in numpy with support for dense labels. Input features (images) should be in range [0, 255]."""
2 | import random
3 | from random import shuffle
4 | import numpy as np
5 | from scipy.ndimage import rotate
6 | from typing import Optional, Union, List
7 | 
8 | 
9 | def additive_gaussian_noise(image, mask, mean_sd=5.1):
10 |     sd = np.abs(np.random.normal(mean_sd, 1, 1))
11 |     noise = np.random.normal(0, sd, image.shape)
12 |     return np.clip(image + noise, 0., 255.).astype(np.float32), mask.astype(np.float32)
13 | 
14 | 
15 | def exposure(image, mask, mean_sd=12.75):
16 |     sd = np.abs(np.random.normal(mean_sd, 1, 1))
17 |     noise = np.random.normal(0, sd, 1)
18 |     return np.clip(image + noise, 0., 255.).astype(np.float32), mask.astype(np.float32)
19 | 
20 | 
21 | def random_eraser(input_img, mask, s_l=0.02, s_h=0.10, r_1=0.3, r_2=1/0.3, v_l=0, v_h=255):
22 |     """
23 |     Random eraser https://arxiv.org/pdf/1708.04896.pdf
24 |     Adapted for image segmentation and speed from: https://github.com/yu4u/mixup-generator/blob/master/random_eraser.py
25 |     """
26 |     img_h, img_w, _ = input_img.shape
27 |     s = np.random.uniform(s_l, s_h) * img_h * img_w
28 |     r = np.random.uniform(r_1, r_2)
29 |     w = int(np.sqrt(s / r))
30 |     h = int(np.sqrt(s * r))
31 |     top = np.random.randint(0, img_h)
32 |     left = np.random.randint(0, img_w)
33 |     c = np.random.uniform(v_l, v_h)
34 |     input_img[top:top + h, left:left + w, :] = c
35 |     mask[top:top + h, left:left + w, :] = [1, 0]  # Set the background to true, foreground to false
36 |     return input_img.astype(np.float32), mask.astype(np.float32)
37 | 
38 | 
39 | def fliplr(image, mask):
40 |     image = np.fliplr(image)
41 |     mask = np.fliplr(mask)
42 |     return image.astype(np.float32), mask.astype(np.float32)
43 | 
44 | 
45 | def shift_img_lr(image, shift, roll, right, fill: Optional[Union[int, List[int]]] = None):
46 |     if right:
47 |         image = np.roll(image, shift, 1)  # Horizontal shift: roll axis 1 (columns) so the fill below overwrites the wrapped columns.
48 |         if not roll:
49 |             if fill is not None:
50 |                 left_fill = fill
51 |             else:
52 |                 left_fill = np.random.uniform(0, 255, image.shape[2])
53 |             image[:, :shift] = left_fill
54 |     else:
55 |         image = np.roll(image, -shift, 1)  # Wrapped columns land on the right.
56 |         if not roll:
57 |             if fill is not None:
58 |                 right_fill = fill
59 |             else:
60 |                 right_fill = np.random.uniform(0, 255, image.shape[2])
61 |             image[:, -shift:] = right_fill
62 |     return image
63 | 
64 | 
65 | def shift_img_ud(image, shift, roll, up, fill: Optional[Union[int, List[int]]] = None):
66 |     if up:
67 |         image = np.roll(image, -shift, 0)  # Vertical shift: roll axis 0 (rows); wrapped rows land at the bottom.
68 |         if not roll:
69 |             if fill is not None:
70 |                 low_fill = fill
71 |             else:
72 |                 low_fill = np.random.uniform(0, 255, image.shape[2])
73 |             image[-shift:, :] = low_fill
74 |     else:
75 |         image = np.roll(image, shift, 0)  # Wrapped rows land at the top.
76 |         if not roll:
77 |             if fill is not None:
78 |                 top_fill = fill
79 |             else:
80 |                 top_fill = np.random.uniform(0, 255, image.shape[2])
81 |             image[:shift, :] = top_fill
82 |     return image
83 | 
84 | 
85 | def translate(image, mask, max_shift=23, mask_fill=[1, 0]):  # TODO: try larger max_shift
86 |     """Randomly jitter an image horizontally or vertically."""
87 |     vert = random.getrandbits(1)
88 |     direction = random.getrandbits(1)
89 |     shift = np.random.randint(1, max_shift + 1, 1)[0]
90 |     roll = random.getrandbits(1)
91 |     if vert:
92 |         image = shift_img_ud(image, shift, roll, direction)
93 |         mask = shift_img_ud(mask, shift, roll, direction, fill=mask_fill)
94 |     else:
95 |         image = shift_img_lr(image, shift, roll, direction)
96 |         mask = shift_img_lr(mask, shift, roll, direction, fill=mask_fill)
97 |     return image.astype(np.float32), mask.astype(np.float32)
98 | 
99 | 
100 | def rotate_img_mask(image, mask, max_angle: int = 45, mask_fill=[1, 0]):
101 |     angle = np.random.randint(-max_angle, max_angle)
102 |     mode = random.sample(['reflect', 'constant', 'mirror', 'wrap'], 1)[0]
103 |     reshape = False
104 | 
105 |     fill_with_noise = False
106 | 
107 |     if mode == "constant":
108 |         if random.getrandbits(1):
109 |             cval = -256
110 |             fill_with_noise = True
111 |         else:
112 |             cval = np.random.randint(0, 256)
113 |     else:
114 |         cval = 0
115 | 
116 |     image = rotate(image, angle=angle, reshape=reshape, mode=mode, cval=cval)
117 | 
118 |     if mode == "constant" and fill_with_noise:
119 |         bg = image == -256
120 |         noise = np.random.randint(0, 256, size=image.shape)
121 |         image[bg] = noise[bg]
122 | 
123 |     cval = -256
124 |     mask = rotate(mask, angle=angle, reshape=reshape, mode=mode, cval=cval, order=0)
125 |     if mode == "constant":
126 |         bg = mask[:, :, 0] == -256
127 |         mask[bg] = mask_fill
128 | 
129 |     return image, mask
130 | 
131 | 
132 | cur_aug_funcs = [random_eraser, translate, fliplr, additive_gaussian_noise, exposure, rotate_img_mask]
133 | 
134 | 
135 | class Augmenter:
136 |     """Image segmentation augmenter."""
137 |     def __init__(self, aug_funcs=None):
138 |         if aug_funcs is None:
139 |             aug_funcs = cur_aug_funcs
140 |         self.aug_funcs = aug_funcs
141 |         self.prob_to_return_original = 1. / (len(aug_funcs) + 1)
142 |         print("Initialized image segmentation augmenter.")
143 | 
144 |     def apply_augmentations(self, image, mask, prob_to_return_original=0.0, return_image_mask_in_list: bool = True):  # 0.5
145 |         if prob_to_return_original is not None:
146 |             prob = prob_to_return_original
147 |         else:
148 |             prob = self.prob_to_return_original
149 |         if np.random.rand() <= prob:
150 |             return [image, mask] if return_image_mask_in_list else (image, mask)  # Honor the requested return container on the early exit as well.
151 |         image, mask = image.copy(), mask.copy()
152 |         shuffle(self.aug_funcs)
153 |         # Apply some or all of them in the shuffled order
154 |         num_to_apply = np.random.randint(1, len(self.aug_funcs) + 1)
155 |         for fn in self.aug_funcs[:num_to_apply]:
156 |             image, mask = fn(image, mask)
157 |         if return_image_mask_in_list:
158 |             return [image, mask]
159 |         else:
160 |             return image, mask
161 | 
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml4ai/mliis/f40352e734f77609bcd5c4ad330ea73a897a217d/data/__init__.py
--------------------------------------------------------------------------------
/data/fp-k_test_set.txt:
--------------------------------------------------------------------------------
1 | airliner
2 | bus
3 | motorbike
4 | potted_plant
5 | television
--------------------------------------------------------------------------------
/data/fss_1000_image_to_joint_tfrecord_shards.py:
--------------------------------------------------------------------------------
1 | """
2 | This script converts the image-mask pairs of the FSS-1000 dataset into train/val/test tfrecord shards for joint training, with many tasks per shard.
3 | """
4 | import argparse
5 | import glob
6 | import math
7 | import os
8 | from itertools import repeat
9 | from multiprocessing.pool import Pool
10 | from pathlib import Path
11 | import sys
12 | import time
13 | import random
14 | import warnings
15 | from typing import List, Tuple, Optional
16 | 
17 | import imageio
18 | import numpy as np
19 | import tensorflow as tf
20 | 
21 | from data.fss_1000_utils import TEST_TASK_IDS, FP_K_TEST_TASK_IDS, split_train_test_tasks, TOTAL_NUM_FSS_CLASSES, \
22 |     IMAGE_DIMS
23 | from joint_train.data.constants import SERIALIZED_DTYPE
24 | 
25 | MAX_NUM_PROCESSES = 8
26 | 
27 | 
28 | def parse_arguments(argv):
29 |     """Parses command line arguments."""
30 |     parser = argparse.ArgumentParser(description='Writes FSS-1000 images to TFRecords.')
31 |     parser.add_argument(
32 |         '--input_dir',
33 |         type=str,
34 |         default=None,
35 |         help='Absolute path to base directory of the FSS-1000 dataset.')
36 |     parser.add_argument(
37 |         '--tfrecord_dir',
38 |         required=True,
39 |         type=str,
40 |         help='Directory to write tfrecords to.')
41 |     parser.add_argument(
42 |         '--overwrite',
43 |         required=False,
44 |         default=False,
45 |         action="store_true",
46 |         help='Overwrite existing tfrecords?')
47 |     parser.add_argument("--compress", action="store_true", default=False)
48 |     parser.add_argument("--fp_k_test_set", help="Hold out the test task for the fp-k classes.", action="store_true")
49 |     parser.add_argument("--num_val_tasks", help="Number of validation tasks to hold out in addition to the 240 test tasks.", type=int, default=0)
50 |     args, _ = parser.parse_known_args(args=argv[1:])
51 |     return args
52 | 
53 | 
54 | def get_fss_dir_paths(data_dir):
55 |     return glob.glob(os.path.join(data_dir, "*/"))
56 | 
57 | 
58 | def get_image_mask_pairs(task: str, image_ext: str = ".jpg", mask_ext: str = ".png") -> List[Tuple[str, str]]:
59 |     masks = 
glob.glob(os.path.join(task, "*" + mask_ext)) 60 | image_mask_pairs = [] 61 | for mask in masks: 62 | image = mask.replace(mask_ext, image_ext) 63 | if os.path.exists(image): 64 | image_mask_pairs.append((image, mask)) 65 | else: 66 | warnings.warn("No corresponding image found for mask: {}".format(mask)) 67 | return image_mask_pairs 68 | 69 | 70 | def main(): 71 | """Write images to TFRecords.""" 72 | print("Converting FSS-1000 image-mask pairs to tfrecord shards.") 73 | dry_run = False 74 | start = time.time() 75 | print(start) 76 | args = parse_arguments(sys.argv) 77 | if args.compress: 78 | ext = ".tfrecord.gzip" 79 | else: 80 | ext = ".tfrecord" 81 | 82 | if args.fp_k_test_set: 83 | test_task_ids = FP_K_TEST_TASK_IDS 84 | else: 85 | test_task_ids = TEST_TASK_IDS 86 | 87 | train_dirs, test_dirs, all_classes = get_fss_train_test_absolute_dir_paths(args.input_dir, test_task_ids=test_task_ids) 88 | 89 | train_dirs, val_dirs = split_train_test_tasks(train_dirs, args.num_val_tasks, reproducbile_splits=True) 90 | 91 | assert len(train_dirs) + len(val_dirs) + len(test_dirs) == TOTAL_NUM_FSS_CLASSES 92 | 93 | if not dry_run: 94 | mkdir(args.tfrecord_dir) 95 | 96 | for set_name, paths in zip(["train", "val", "test"], [train_dirs, val_dirs, test_dirs]): 97 | image_mask_pairs = [] 98 | for folder in paths: 99 | print("folder: {}".format(folder)) 100 | image_mask_pairs.extend(get_image_mask_pairs(folder)) 101 | 102 | tfrecord_filename = os.path.join(args.tfrecord_dir, set_name + ext) 103 | 104 | if not dry_run: 105 | if not os.path.exists(tfrecord_filename) or args.overwrite: 106 | write_tfrecords(tfrecord_filename, image_mask_pairs, all_classes, compress=args.compress) 107 | print("Wrote tfrecord file to {}".format(tfrecord_filename)) 108 | 109 | print("Finished.") 110 | print("Took {} minutes.".format((time.time() - start) / 60.0)) 111 | 112 | 113 | def get_fss_train_test_absolute_dir_paths(data_dir, test_task_ids: List[str] = TEST_TASK_IDS): 114 | expected_classes = TOTAL_NUM_FSS_CLASSES 115 | all_classes = get_classes_from_subdirs(data_dir, expected_classes) 116 | train_classes = list(set(all_classes) - set(test_task_ids)) 117 | test_classes = list(set(test_task_ids)) 118 | assert len(train_classes) + len(test_classes) == expected_classes 119 | 120 | return [os.path.join(data_dir, x) for x in train_classes], [os.path.join(data_dir, x) for x in test_classes], sorted(all_classes) 121 | 122 | 123 | def get_classes_from_subdirs(data_dir, expected_num_classes: int): 124 | """Get the classes from the names of the subdirectories""" 125 | all_classes = os.listdir(data_dir) 126 | all_classes = [x for x in all_classes if os.path.isdir(os.path.join(data_dir, x))] 127 | if len(all_classes) != expected_num_classes: 128 | print("length of found classes does not equal number of expected classes") 129 | import pdb; pdb.set_trace() 130 | all_classes = sorted(all_classes) 131 | return all_classes 132 | 133 | 134 | def mkdir(path): 135 | """ 136 | Recursive create dir at `path` if `path` does not exist. 
137 |     """
138 |     if not os.path.exists(path):
139 |         os.makedirs(path)
140 | 
141 | 
142 | def one_hot_encode(mask, class_name: str, class_names: List[str], image_width: int = IMAGE_DIMS, truth_value: int = 255, seperate_background_channel: bool = True):
143 |     if seperate_background_channel:
144 |         background = truth_value - mask
145 | 
146 |     n_classes = len(class_names)
147 |     i = class_names.index(class_name)
148 | 
149 |     if seperate_background_channel:
150 |         n_classes += 1
151 |         i += 1
152 | 
153 |     all_classes = np.zeros([image_width, image_width, n_classes])
154 |     all_classes[:, :, i] = mask
155 | 
156 |     if seperate_background_channel:
157 |         all_classes[:, :, 0] = background
158 | 
159 |     return all_classes
160 | 
161 | 
162 | def image_to_feature(image_filename, take_first_channel=False, one_hot_encode_mask=False, all_classes: Optional[List[str]] = None, serialize_as=SERIALIZED_DTYPE):
163 |     """
164 |     Converts target image to a bytes feature.
165 | 
166 |     Args:
167 |         image_filename: Full path of image with image type extension.
168 |         take_first_channel: Set to True for masks.
169 | 
170 |     Returns:
171 |         TF bytes Feature for the image.
172 |     """
173 |     im = imageio.imread(image_filename)
174 |     class_name = os.path.basename(Path(image_filename).parent)
175 |     img_shape = im.shape
176 |     height, width = im.shape[0], im.shape[1]
177 |     if height != IMAGE_DIMS or width != IMAGE_DIMS:
178 |         print("{} is not of expected image dimensions. Skipping this sample".format(image_filename))
179 |         return None
180 |     if take_first_channel:
181 |         if len(img_shape) > 2:
182 |             im = im[:, :, 0]
183 |     if one_hot_encode_mask:
184 |         im = one_hot_encode(im, class_name=class_name, class_names=all_classes)
185 |     im = im.astype(serialize_as)
186 |     bytes_list = tf.train.BytesList(value=[im.tobytes()])
187 |     return tf.train.Feature(bytes_list=bytes_list)
188 | 
189 | 
190 | def make_example(image_filename, mask_filename, all_classes):
191 |     """Collect TF Features into a TF Example."""
192 |     image = image_to_feature(image_filename)
193 |     mask = image_to_feature(mask_filename, take_first_channel=True, one_hot_encode_mask=True, all_classes=all_classes)  # Must be a single Feature; a stray trailing comma here would wrap it in a tuple and break tf.train.Features below.
194 |     if (image is None) or (mask is None):
195 |         return None
196 |     feature = {
197 |         'image': image,
198 |         'mask': mask,
199 |     }
200 |     features = tf.train.Features(feature=feature)
201 |     return tf.train.Example(features=features)
202 | 
203 | 
204 | def image_mask_are_valid(image: str, mask: str) -> bool:
205 |     def _valid(image_filename):
206 |         im = imageio.imread(image_filename)
207 |         height, width = im.shape[0], im.shape[1]
208 |         if height != IMAGE_DIMS or width != IMAGE_DIMS:
209 |             print("{} is not of expected image dimensions. Skipping this sample".format(image_filename))
210 |             return False
211 |         return True
212 |     res = [_valid(x) for x in [image, mask]]
213 |     if all(res):
214 |         return True
215 |     return False
216 | 
217 | 
218 | def chunks(lst, n):
219 |     """Yield successive n-sized chunks from lst."""
220 |     for i in range(0, len(lst), n):
221 |         yield lst[i: i + n]
222 | 
223 | 
224 | def write_tfrecords(tfrecord_basename, filename_pairs, all_classes: List[str], max_examples: int = 200, compress: bool = False):
225 |     """Write tfrecord shards in parallel."""
226 |     if isinstance(filename_pairs, zip):
227 |         filename_pairs = list(filename_pairs)
228 |     elif not isinstance(filename_pairs, list):
229 |         raise ValueError("filename_pairs must be list or zip object but is {}".format(type(filename_pairs)))
230 |     random.shuffle(filename_pairs)
231 |     num_examples = len(filename_pairs)
232 | 
233 |     num_shards = int(math.ceil(float(num_examples) / float(max_examples)))
234 |     record_names = ['%s-%05i-of-%05i' % (tfrecord_basename, n + 1, num_shards) for n in range(num_shards)]
235 | 
236 |     shard_filename_pairs = [[] for _ in range(num_shards)]
237 |     for n, filename_pair in enumerate(filename_pairs):
238 |         sublist_index = n % num_shards
239 |         shard_filename_pairs[sublist_index].append(filename_pair)
240 | 
241 |     iterable = [(fn, fn_pairs, ac, compress_bool) for fn, fn_pairs, ac, compress_bool in zip(record_names, shard_filename_pairs, repeat(all_classes), repeat(compress))]
242 | 
243 |     num_processes = min(num_shards, MAX_NUM_PROCESSES)
244 |     with Pool(num_processes) as pool:
245 |         pool.starmap(write_tfrecord, iterable)
246 | 
247 | 
248 | def write_tfrecord(tfrecord_filename, filename_pairs, all_classes: List[str], compress: bool = False):
249 |     """Write TFExamples containing images and masks to TFRecord(s).
250 | 
251 |     Args:
252 |         tfrecord_filename: Filename to write record to, full path and extension.
253 |         filename_pairs: List of (image_filename, mask_filename) tuples.
254 |         all_classes: Sorted list of all class names, used to index the one-hot mask channels.
255 |     """
256 |     print("Writing examples to tfrecord at {}".format(tfrecord_filename))
257 |     if compress:
258 |         options = tf.python_io.TFRecordOptions(
259 |             compression_type=tf.python_io.TFRecordCompressionType.GZIP)
260 |     else:
261 |         options = None
262 |     writer = tf.python_io.TFRecordWriter(
263 |         tfrecord_filename,
264 |         options)
265 |     for n, filename_pair in enumerate(filename_pairs):
266 |         image_filename, mask_filename = filename_pair
267 |         if not image_mask_are_valid(image_filename, mask_filename):
268 |             continue
269 | 
270 |         example = make_example(image_filename, mask_filename, all_classes)
271 | 
272 |         if example is None:
273 |             continue
274 |         serialized_example = example.SerializeToString()
275 |         writer.write(serialized_example)
276 |     writer.close()
277 |     print("Examples written to {}".format(tfrecord_filename))
278 | 
279 | 
280 | def old_write_tfrecord(tfrecord_filename, filename_pairs, all_classes: List[str], max_examples: int = 200):
281 |     """Write TFExamples containing images and masks to TFRecord(s).
282 | 
283 |     Args:
284 |         tfrecord_filename: Filename to write record to, full path and extension.
285 |         filename_pairs: List of (image_filename, mask_filename) tuples.
286 |         all_classes: Sorted list of all class names, used to index the one-hot mask channels.
287 |         max_examples: Maximum number of examples to write to each TFRecord shard.
288 | """ 289 | options = tf.python_io.TFRecordOptions( 290 | compression_type=tf.python_io.TFRecordCompressionType.GZIP) 291 | if isinstance(filename_pairs, zip): 292 | filename_pairs = list(filename_pairs) 293 | elif not isinstance(filename_pairs, list): 294 | raise ValueError("filename_pairs must be list or zip object but is {}".format(type(filename_pairs))) 295 | random.shuffle(filename_pairs) 296 | num_examples = len(filename_pairs) 297 | # Casting for consistent behavior in python 2 and 3. 298 | num_shards = int(math.ceil(float(num_examples) / float(max_examples))) 299 | writers = [ 300 | tf.python_io.TFRecordWriter( 301 | '%s-%05i-of-%05i' % (tfrecord_filename, n + 1, num_shards), 302 | options, 303 | ) for n in range(num_shards) 304 | ] 305 | 306 | for n, filename_pair in enumerate(filename_pairs): 307 | writer = writers[n % num_shards] 308 | image_filename, mask_filename = filename_pair 309 | if not image_mask_are_valid(image_filename, mask_filename): 310 | continue 311 | 312 | example = make_example(image_filename, mask_filename, all_classes) 313 | 314 | if example is None: 315 | continue 316 | serialized_example = example.SerializeToString() 317 | writer.write(serialized_example) 318 | for writer in writers: 319 | writer.close() 320 | print("Examples written to {}".format(tfrecord_filename)) 321 | 322 | 323 | if __name__ == '__main__': 324 | main() 325 | -------------------------------------------------------------------------------- /data/fss_1000_image_to_tfrecord.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script converts the image-mask pairs of the FSS-1000 dataset to tfrecords, with one task per tfrecord. 3 | 4 | Example usage: 5 | python fss_1000_image_to_tfrecord.py --input_dir fewshot_data --tfrecord_dir fewshot_shards/ 6 | """ 7 | import argparse 8 | import glob 9 | import os 10 | import sys 11 | import time 12 | import random 13 | import warnings 14 | from typing import List, Tuple 15 | 16 | import imageio 17 | import tensorflow as tf 18 | 19 | 20 | IMAGE_DIMS = 224 # Length of one side of input images. Images assumed to be square. 
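# Schema note (added context): each tf.train.Example written by this script stores two
# bytes features, which data/input_fn.py's parse_example decodes with tf.decode_raw:
#   'image': raw uint8 RGB bytes, shape (IMAGE_DIMS, IMAGE_DIMS, 3)
#   'mask':  raw uint8 bytes, shape (IMAGE_DIMS, IMAGE_DIMS), with 255 encoding the positive class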
21 | 
22 | 
23 | def parse_arguments(argv):
24 |     """Parses command line arguments."""
25 |     parser = argparse.ArgumentParser(description='Writes FSS-1000 images to TFRecords.')
26 |     parser.add_argument(
27 |         '--input_dir',
28 |         type=str,
29 |         default=None,
30 |         help='Absolute path to base directory of the FSS-1000 dataset.')
31 |     parser.add_argument(
32 |         '--tfrecord_dir',
33 |         required=True,
34 |         type=str,
35 |         help='Directory to write tfrecords to.')
36 |     parser.add_argument(
37 |         '--overwrite',
38 |         required=False,
39 |         default=False,
40 |         action="store_true",  # A boolean flag; argparse's type=bool would treat any non-empty string, including "False", as True.
41 |         help='Overwrite existing tfrecords?')
42 |     args, _ = parser.parse_known_args(args=argv[1:])
43 |     return args
44 | 
45 | 
46 | def get_fss_dir_paths(data_dir):
47 |     return glob.glob(os.path.join(data_dir, "*/"))
48 | 
49 | 
50 | def get_image_mask_pairs(task: str, image_ext: str = ".jpg", mask_ext: str = ".png") -> List[Tuple[str, str]]:
51 |     masks = glob.glob(os.path.join(task, "*" + mask_ext))
52 |     image_mask_pairs = []
53 |     for mask in masks:
54 |         image = mask.replace(mask_ext, image_ext)
55 |         if os.path.exists(image):
56 |             image_mask_pairs.append((image, mask))
57 |         else:
58 |             warnings.warn("No corresponding image found for mask: {}".format(mask))
59 |     return image_mask_pairs
60 | 
61 | 
62 | def main():
63 |     """Write images to TFRecords."""
64 |     print("Converting FSS-1000 image-mask pairs to tfrecords, one per task.")
65 |     dry_run = False
66 |     start = time.time()
67 |     print(start)
68 |     args = parse_arguments(sys.argv)
69 | 
70 |     task_dirs = get_fss_dir_paths(args.input_dir)
71 |     print("{} tasks found".format(len(task_dirs)))
72 | 
73 |     if not dry_run:
74 |         mkdir(args.tfrecord_dir)
75 | 
76 |     for task in task_dirs:
77 |         image_mask_pairs = get_image_mask_pairs(task)
78 |         task_name = os.path.basename(task.rstrip("/"))
79 |         print("Processing task: {}".format(task_name))
80 |         tfrecord_filename = os.path.join(args.tfrecord_dir, task_name + ".tfrecord.gzip")
81 | 
82 |         if not dry_run:
83 |             if not os.path.exists(tfrecord_filename) or args.overwrite:
84 |                 write_tfrecord(tfrecord_filename, image_mask_pairs)
85 |                 print("Wrote tfrecord file to {}".format(tfrecord_filename))
86 | 
87 |     print("Finished.")
88 |     print("Took {} minutes.".format((time.time() - start) / 60.0))
89 | 
90 | 
91 | def mkdir(path):
92 |     """
93 |     Recursively create dir at `path` if `path` does not exist.
94 |     """
95 |     if not os.path.exists(path):
96 |         os.makedirs(path)
97 | 
98 | 
99 | def image_to_feature(image_filename, take_first_channel=False):
100 |     """
101 |     Converts target image to a bytes feature.
102 | 
103 |     Args:
104 |         image_filename: Full path of image with image type extension.
105 |         take_first_channel: Set to True for masks.
106 | 
107 |     Returns:
108 |         TF bytes Feature for the image.
109 |     """
110 |     im = imageio.imread(image_filename)
111 |     img_shape = im.shape
112 |     height, width = im.shape[0], im.shape[1]
113 |     if height != IMAGE_DIMS or width != IMAGE_DIMS:
114 |         print("{} is not of expected image dimensions. Skipping this sample".format(image_filename))
115 |         return None
116 |     if take_first_channel:
117 |         if len(img_shape) > 2:
118 |             im = im[:, :, 0]
119 |     bytes_list = tf.train.BytesList(value=[im.tobytes()])
120 |     return tf.train.Feature(bytes_list=bytes_list)
121 | 
122 | 
123 | def make_example(image_filename, mask_filename):
124 |     """Collect TF Features into a TF Example."""
125 |     image = image_to_feature(image_filename)
126 |     mask = image_to_feature(mask_filename, take_first_channel=True)  # Must be a single Feature; a stray trailing comma here would wrap it in a tuple and break tf.train.Features below.
127 |     if (image is None) or (mask is None):
128 |         return None
129 |     feature = {
130 |         'image': image,
131 |         'mask': mask,
132 |     }
133 |     features = tf.train.Features(feature=feature)
134 |     return tf.train.Example(features=features)
135 | 
136 | 
137 | def image_mask_are_valid(image: str, mask: str) -> bool:
138 |     def _valid(image_filename):
139 |         im = imageio.imread(image_filename)
140 |         height, width = im.shape[0], im.shape[1]
141 |         if height != IMAGE_DIMS or width != IMAGE_DIMS:
142 |             print("{} is not of expected image dimensions. Skipping this sample".format(image_filename))
143 |             return False
144 |         return True
145 |     res = [_valid(x) for x in [image, mask]]
146 |     if all(res):
147 |         return True
148 |     return False
149 | 
150 | 
151 | def write_tfrecord(tfrecord_filename, filename_pairs, max_examples=None):
152 |     """Write TFExamples containing images and masks to TFRecord(s).
153 | 
154 |     Args:
155 |         tfrecord_filename: Filename to write record to, full path and extension.
156 |         filename_pairs: List of (image_filename, mask_filename) tuples.
157 |         max_examples: Maximum number of examples to write to each TFRecord shard (currently unused).
158 |     """
159 |     options = tf.python_io.TFRecordOptions(
160 |         compression_type=tf.python_io.TFRecordCompressionType.GZIP)
161 |     if isinstance(filename_pairs, zip):
162 |         filename_pairs = list(filename_pairs)
163 |     elif not isinstance(filename_pairs, list):
164 |         raise ValueError("filename_pairs must be list or zip object but is {}".format(type(filename_pairs)))
165 |     random.shuffle(filename_pairs)  # Shuffle examples within a task.
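    # These records are GZIP-compressed; readers must use the matching
    # compression type (see _COMPRESSION_TYPE = "GZIP" in data/input_fn.py).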
166 | 
167 |     writer = tf.python_io.TFRecordWriter(tfrecord_filename, options=options)
168 |     i = 0
169 |     for filename_pair in filename_pairs:
170 |         image_filename, mask_filename = filename_pair
171 |         if not image_mask_are_valid(image_filename, mask_filename):
172 |             continue
173 |         example = make_example(image_filename, mask_filename)
174 |         serialized_example = example.SerializeToString()
175 |         writer.write(serialized_example)
176 |         i += 1
177 |     writer.close()
178 |     print("{} examples written to {}".format(i, tfrecord_filename))
179 | 
180 | 
181 | if __name__ == '__main__':
182 |     main()
183 | 
--------------------------------------------------------------------------------
/data/fss_1000_utils.py:
--------------------------------------------------------------------------------
1 | 
2 | import glob
3 | import os
4 | import random
5 | from typing import List
6 | 
7 | 
8 | def split_train_test_tasks(all_tasks: List[str], n_test, reproducbile_splits: bool = False):
9 |     if not isinstance(all_tasks, list):
10 |         all_tasks = list(all_tasks)
11 |     if reproducbile_splits:
12 |         all_tasks = sorted(all_tasks)
13 |     else:
14 |         random.shuffle(all_tasks)
15 |     test_set = []
16 |     for i in range(n_test):
17 |         test_set.append(all_tasks.pop())
18 |     assert_train_test_split(all_tasks, test_set)
19 |     return all_tasks, test_set
20 | 
21 | 
22 | def assert_train_test_split(train, test):
23 |     for i in test:
24 |         assert i not in train, "train-test leakage"
25 | 
26 | 
27 | def get_fss_tasks(data_dir):
28 |     return glob.glob(os.path.join(data_dir, "*.tfrecord*"))
29 | 
30 | 
31 | def get_fss_test_set() -> List[str]:
32 |     dirname = os.path.dirname(__file__)
33 |     path = "fss_test_set.txt"  # File containing the test examples from the FSS-1000 authors.
34 |     filename = os.path.join(dirname, path)
35 |     with open(filename, "r") as file:
36 |         tasks = [line.rstrip("\n") for line in file]
37 |     return tasks
38 | 
39 | 
40 | def get_fss_train_set() -> List[str]:
41 |     dirname = os.path.dirname(__file__)
42 |     path = "fss_train_set.txt"
43 |     filename = os.path.join(dirname, path)
44 |     with open(filename, "r") as file:
45 |         tasks = [line.rstrip("\n") for line in file]
46 |     return tasks
47 | 
48 | 
49 | def get_fp_k_test_set() -> List[str]:
50 |     dirname = os.path.dirname(__file__)
51 |     path = "fp-k_test_set.txt"  # File containing the held-out FP-k test classes.
52 |     filename = os.path.join(dirname, path)
53 |     with open(filename, "r") as file:
54 |         tasks = [line.rstrip("\n") for line in file]
55 |     return tasks
56 | 
57 | 
58 | TEST_TASK_IDS = get_fss_test_set()
59 | TRAIN_TASK_IDS = get_fss_train_set()
60 | FP_K_TEST_TASK_IDS = get_fp_k_test_set()
61 | TOTAL_NUM_FSS_CLASSES = 1000
62 | IMAGE_DIMS = 224  # Length of one side of input images. Images assumed to be square.
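# Usage sketch (illustrative; mirrors data/fss_1000_image_to_joint_tfrecord_shards.py):
#   train_dirs, val_dirs = split_train_test_tasks(train_dirs, n_test=20, reproducbile_splits=True)
# With reproducbile_splits=True the tasks are sorted before the held-out set is popped,
# so repeated runs yield the same split; assert_train_test_split guards against leakage.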
63 | -------------------------------------------------------------------------------- /data/fss_test_set.txt: -------------------------------------------------------------------------------- 1 | bus 2 | hotel_slipper 3 | burj_al 4 | reflex_camera 5 | abe's_flyingfish 6 | oiltank_car 7 | doormat 8 | fish_eagle 9 | barber_shaver 10 | motorbike 11 | feather_clothes 12 | wandering_albatross 13 | rice_cooker 14 | delta_wing 15 | fish 16 | nintendo_switch 17 | bustard 18 | diver 19 | minicooper 20 | cathedrale_paris 21 | big_ben 22 | combination_lock 23 | villa_savoye 24 | american_alligator 25 | gym_ball 26 | andean_condor 27 | leggings 28 | pyramid_cube 29 | jet_aircraft 30 | meatloaf 31 | reel 32 | swan 33 | osprey 34 | crt_screen 35 | microscope 36 | rubber_eraser 37 | arrow 38 | monkey 39 | mitten 40 | spiderman 41 | parthenon 42 | bat 43 | chess_king 44 | sulphur_butterfly 45 | quail_egg 46 | oriole 47 | iron_man 48 | wooden_boat 49 | anise 50 | steering_wheel 51 | groenendael 52 | dwarf_beans 53 | pteropus 54 | chalk_brush 55 | bloodhound 56 | moon 57 | english_foxhound 58 | boxing_gloves 59 | peregine_falcon 60 | pyraminx 61 | cicada 62 | screw 63 | shower_curtain 64 | tredmill 65 | bulb 66 | bell_pepper 67 | lemur_catta 68 | doughnut 69 | twin_tower 70 | astronaut 71 | nintendo_3ds 72 | fennel_bulb 73 | indri 74 | captain_america_shield 75 | kunai 76 | broom 77 | iphone 78 | earphone1 79 | flying_squirrel 80 | onion 81 | vinyl 82 | sydney_opera_house 83 | oyster 84 | harmonica 85 | egg 86 | breast_pump 87 | guitar 88 | potato_chips 89 | tunnel 90 | cuckoo 91 | rubick_cube 92 | plastic_bag 93 | phonograph 94 | net_surface_shoes 95 | goldfinch 96 | ipad 97 | mite_predator 98 | coffee_mug 99 | golden_plover 100 | f1_racing 101 | lapwing 102 | nintendo_gba 103 | pizza 104 | rally_car 105 | drilling_platform 106 | cd 107 | fly 108 | magpie_bird 109 | leaf_fan 110 | little_blue_heron 111 | carriage 112 | moist_proof_pad 113 | flying_snakes 114 | dart_target 115 | warehouse_tray 116 | nintendo_wiiu 117 | chiffon_cake 118 | bath_ball 119 | manatee 120 | cloud 121 | marimba 122 | eagle 123 | ruler 124 | soymilk_machine 125 | sled 126 | seagull 127 | glider_flyingfish 128 | doublebus 129 | transport_helicopter 130 | window_screen 131 | truss_bridge 132 | wasp 133 | snowman 134 | poached_egg 135 | strawberry 136 | spinach 137 | earphone2 138 | downy_pitch 139 | taj_mahal 140 | rocking_chair 141 | cablestayed_bridge 142 | sealion 143 | banana_boat 144 | pheasant 145 | stone_lion 146 | electronic_stove 147 | fox 148 | iguana 149 | rugby_ball 150 | hang_glider 151 | water_buffalo 152 | lotus 153 | paper_plane 154 | missile 155 | flamingo 156 | american_chamelon 157 | kart 158 | chinese_knot 159 | cabbage_butterfly 160 | key 161 | church 162 | tiltrotor 163 | helicopter 164 | french_fries 165 | water_heater 166 | snow_leopard 167 | goblet 168 | fan 169 | snowplow 170 | leafhopper 171 | pspgo 172 | black_bear 173 | quail 174 | condor 175 | chandelier 176 | hair_razor 177 | white_wolf 178 | toaster 179 | pidan 180 | pyramid 181 | chicken_leg 182 | letter_opener 183 | apple_icon 184 | porcupine 185 | chicken 186 | stingray 187 | warplane 188 | windmill 189 | bamboo_slip 190 | wig 191 | flying_geckos 192 | stonechat 193 | haddock 194 | australian_terrier 195 | hover_board 196 | siamang 197 | canton_tower 198 | santa_sledge 199 | arch_bridge 200 | curlew 201 | sushi 202 | beet_root 203 | accordion 204 | leaf_egg 205 | stealth_aircraft 206 | stork 207 | bucket 208 | hawk 209 | chess_queen 210 | ocarina 211 
| knife 212 | whippet 213 | cantilever_bridge 214 | may_bug 215 | wagtail 216 | leather_shoes 217 | wheelchair 218 | shumai 219 | speedboat 220 | vacuum_cup 221 | chess_knight 222 | pumpkin_pie 223 | wooden_spoon 224 | bamboo_dragonfly 225 | ganeva_chair 226 | soap 227 | clearwing_flyingfish 228 | pencil_sharpener1 229 | cricket 230 | photocopier 231 | nintendo_sp 232 | samarra_mosque 233 | clam 234 | charge_battery 235 | flying_frog 236 | ferrari911 237 | polo_shirt 238 | echidna 239 | coin 240 | tower_pisa -------------------------------------------------------------------------------- /data/fss_train_set.txt: -------------------------------------------------------------------------------- 1 | boston_bull 2 | brush_pen 3 | woodpecker 4 | brain_coral 5 | gliding_lizard 6 | zebra 7 | wallet 8 | prayer_rug 9 | kwanyin 10 | mcdonald_sign 11 | whistle 12 | hen_of_the_woods 13 | radio_telescope 14 | persimmon 15 | lhasa_apso 16 | pingpong_racket 17 | pingpong_ball 18 | bullet_train 19 | dart 20 | brambling 21 | wrench 22 | hyena 23 | light_tube 24 | coffeepot 25 | volleyball 26 | mooli 27 | leopard 28 | scorpion 29 | upright_piano 30 | warthog 31 | goose 32 | taxi 33 | timber_wolf 34 | muscle_car 35 | yoga_pad 36 | vending_machine 37 | raven 38 | tray 39 | cpu 40 | wreck 41 | cauliflower 42 | tobacco_pipe 43 | lycaenid_butterfly 44 | adidas_logo2 45 | ceiling_fan 46 | hippo 47 | meerkat 48 | jackfruit 49 | plate 50 | brasscica 51 | wok 52 | carp 53 | broccoli 54 | orang 55 | loafer 56 | crash_helmet 57 | artichoke 58 | teapot 59 | bracelet 60 | german_pointer 61 | totem_pole 62 | pay_phone 63 | shuriken 64 | spider 65 | giant_schnauzer 66 | pickelhaube 67 | car_mirror 68 | saluki 69 | dumbbell 70 | kite 71 | coucal 72 | pistachio 73 | redheart 74 | pufferfish 75 | sandal 76 | black_grouse 77 | vestment 78 | snowball 79 | papaya 80 | crab 81 | cactus_ball 82 | toilet_tissue 83 | space_heater 84 | starfish 85 | gas_pump 86 | tractor 87 | chest 88 | cherry 89 | ox 90 | litchi 91 | throne 92 | toothbrush 93 | envelope 94 | loggerhead_turtle 95 | bra 96 | wafer 97 | lawn_mower 98 | swim_ring 99 | earplug 100 | airedale 101 | waffle 102 | park_bench 103 | scissors 104 | radiator 105 | tiger_cat 106 | syringe 107 | consomme 108 | cream 109 | mushroom 110 | washer 111 | hamster 112 | school_bus 113 | garlic 114 | baseball_bat 115 | water_snake 116 | guinea_pig 117 | ibex 118 | matchstick 119 | hartebeest 120 | blossom_card 121 | ferret 122 | oscilloscope 123 | barbell 124 | african_elephant 125 | teddy 126 | saxophone 127 | snake 128 | toucan 129 | umbrella 130 | miniskirt 131 | abacus 132 | dingo 133 | lionfish 134 | pubg_lvl3backpack 135 | jellyfish 136 | tow_truck 137 | egret 138 | stinkhorn 139 | sandwich 140 | pretzel 141 | partridge 142 | lacewing 143 | beaver 144 | pumpkin 145 | hornet 146 | lion 147 | ladder 148 | egyptian_cat 149 | bradypod 150 | cello 151 | water_bike 152 | harvester 153 | lobster 154 | torii 155 | beam_bridge 156 | poker 157 | cougar 158 | basset 159 | cottontail 160 | hammer 161 | seal 162 | impala 163 | electric_fan 164 | stupa 165 | stretcher 166 | giant_panda 167 | pubg_lvl3helmet 168 | bouzouki 169 | vulture 170 | pineapple 171 | arabian_camel 172 | necklace 173 | goldfish 174 | balance_weight 175 | hair_drier 176 | motor_scooter 177 | rock_snake 178 | redshank 179 | cactus 180 | jinrikisha 181 | single_log 182 | digital_clock 183 | usb 184 | armadillo 185 | common_newt 186 | bee_eater 187 | agama 188 | neck_brace 189 | coconut 190 | bassoon 191 | seatbelt 192 | 
water_tower 193 | carrot 194 | petri_dish 195 | fig 196 | sloth_bear 197 | kazoo 198 | scroll_brush 199 | pickup 200 | carambola 201 | ab_wheel 202 | squirrel 203 | trimaran 204 | bee 205 | conversion_plug 206 | bolotie 207 | trolleybus 208 | egg_tart 209 | kinguin 210 | ocicat 211 | cigarette 212 | raft 213 | skua 214 | spotted_salamander 215 | white_shark 216 | band-aid 217 | zucchini 218 | capuchin 219 | dowitcher 220 | studio_couch 221 | shotgun 222 | sports_car 223 | lifeboat 224 | dragonfly 225 | cn_tower 226 | dugong 227 | ostrich 228 | lemon 229 | icecream 230 | grey_fox 231 | hook 232 | roller_coaster 233 | cradle 234 | shih-tzu 235 | ski_mask 236 | cardoon 237 | cup 238 | persian_cat 239 | razor 240 | lipstick 241 | quill_pen 242 | mailbox 243 | rocket 244 | streetcar 245 | otter 246 | shower_cap 247 | witch_hat 248 | croquet_ball 249 | beagle 250 | yorkshire_terrier 251 | ambulance 252 | balloon 253 | esport_chair 254 | toothpaste 255 | grey_whale 256 | marshmallow 257 | aubergine 258 | narcissus 259 | pepitas 260 | boa_constrictor 261 | stopwatch 262 | revolver 263 | traffic_light 264 | box_turtle 265 | air_strip 266 | sunglasses 267 | proboscis 268 | chicory 269 | mule 270 | beer_glass 271 | daisy 272 | spider_monkey 273 | chalk 274 | chihuahua 275 | har_gow 276 | wombat 277 | ladyfinger 278 | dutch_oven 279 | diamond 280 | lady_slipper 281 | garfish 282 | parallel_bars 283 | sea_urchin 284 | electronic_toothbrush 285 | plaice 286 | cocacola 287 | white_stork 288 | mooncake 289 | coyote 290 | conch 291 | sungnyemun 292 | spoon 293 | chinese_date 294 | eletrical_switch 295 | tomb 296 | wine_bottle 297 | raccoon 298 | mouse 299 | arctic_fox 300 | steak 301 | english_setter 302 | marmot 303 | prairie_chicken 304 | okra 305 | pillow 306 | paddle 307 | mango 308 | wild_boar 309 | hummingbird 310 | stop_sign 311 | leatherback_turtle 312 | snail 313 | water_polo 314 | polar_bear 315 | ptarmigan 316 | stole 317 | gecko 318 | sturgeon 319 | pill_bottle 320 | fountain 321 | bison 322 | black_swan 323 | cheese_burger 324 | wash_basin 325 | handkerchief 326 | banjo 327 | children_slide 328 | spade 329 | bushtit 330 | eggnog 331 | indian_elephant 332 | ashtray 333 | mount_fuji 334 | roller_skate 335 | flat-coated_retriever 336 | sniper_rifle 337 | pinecone 338 | potato 339 | conveyor 340 | tank 341 | bottle_cap 342 | kangaroo 343 | pen 344 | donkey 345 | parking_meter 346 | collar 347 | bluetick 348 | shift_gear 349 | calculator 350 | baseball 351 | handshower 352 | laptop 353 | loguat 354 | window_shade 355 | head_cabbage 356 | koala 357 | microphone 358 | croissant 359 | red_breasted_merganser 360 | telescope 361 | black_stork 362 | gyromitra 363 | almond 364 | pubg_airdrop 365 | one-armed_bandit 366 | siamese_cat 367 | howler_monkey 368 | lynx 369 | toilet_brush 370 | angora 371 | acorn 372 | mink 373 | taro 374 | microwave 375 | airship 376 | sweatshirt 377 | kappa_logo 378 | mashed_potato 379 | sewing_machine 380 | recreational_vehicle 381 | barometer 382 | pool_table 383 | backpack 384 | car_wheel 385 | weasel 386 | rain_barrel 387 | table_lamp 388 | convertible 389 | lettuce 390 | sandbar 391 | measuring_cup 392 | carousel 393 | spark_plug 394 | bittern 395 | baboon 396 | louvre_pyramid 397 | whale 398 | swab 399 | wallaby 400 | solar_dish 401 | tulip 402 | sombrero 403 | cigar 404 | chainsaw 405 | vacuum 406 | nagoya_castle 407 | stove 408 | swimming_glasses 409 | "thors_hammer" 410 | dishwasher 411 | camomile 412 | soccer_ball 413 | motarboard 414 | macaque 415 | birdhouse 416 | 
grasshopper 417 | brick_card 418 | power_drill 419 | ruddy_turnstone 420 | candle 421 | ballpoint 422 | hatchet 423 | aircraft_carrier 424 | typewriter 425 | equestrian_helmet 426 | cucumber 427 | briard 428 | puma_logo 429 | garbage_truck 430 | squirrel_monkey 431 | housefinch 432 | globe 433 | maotai_bottle 434 | fire_screen 435 | gibbon 436 | besom 437 | pokermon_ball 438 | punching_bag 439 | piano_keyboard 440 | fur_coat 441 | jacko_lantern 442 | hawthorn 443 | sock 444 | trailer_truck 445 | cableways 446 | corn 447 | colubus 448 | african_crocodile 449 | assult_rifle 450 | espresso 451 | tile_roof 452 | mcdonald_uncle 453 | gorilla 454 | hard_disk 455 | whiptail 456 | buckingham_palace 457 | turnstile 458 | ac_wall 459 | fire_engine 460 | nike_logo 461 | eel 462 | tiger 463 | blenheim_spaniel 464 | yurt 465 | statue_liberty 466 | chess_bishop 467 | soap_dispenser 468 | baby 469 | gourd 470 | balance_beam 471 | surfboard 472 | tokyo_tower 473 | tofu 474 | thimble 475 | rock_beauty 476 | vine_snake 477 | pear 478 | hammerhead_shark 479 | file_cabinet 480 | shopping_cart 481 | psp 482 | folding_chair 483 | patas 484 | bedlington_terrier 485 | windsor_tie 486 | thrush 487 | jay_bird 488 | running_shoe 489 | frog 490 | walnut 491 | cassette 492 | watermelon 493 | indian_cobra 494 | three-toed_sloth 495 | red_fox 496 | fork 497 | yonex_icon 498 | paper_towel 499 | cristo_redentor 500 | sea_cucumber 501 | relay_stick 502 | crayon 503 | trench_coat 504 | pomelo 505 | hami_melon 506 | forklift 507 | refrigerator 508 | african_grey 509 | igloo 510 | tiger_shark 511 | strainer 512 | brick_tea 513 | peanut 514 | chopsticks 515 | skateboard 516 | cornet 517 | leeks 518 | ringlet_butterfly 519 | sleeping_bag 520 | handcuff 521 | crane 522 | toilet_seat 523 | radio 524 | christmas_stocking 525 | cairn 526 | kremlin 527 | mud_turtle 528 | bighorn_sheep 529 | cricketball 530 | tebby_cat 531 | nail_scissor 532 | carbonara 533 | suitcase 534 | pinwheel 535 | victor_icon 536 | peacock 537 | celery 538 | scabbard 539 | guacamole 540 | oil_filter 541 | fire_balloon 542 | shovel 543 | medical_kit 544 | flowerpot 545 | submarine 546 | triceratops 547 | badger 548 | afghan_hound 549 | excavator 550 | skull 551 | rabbit 552 | toilet_plunger 553 | fire_hydrant 554 | cosmetic_brush 555 | battery 556 | beer_bottle 557 | trilobite 558 | speaker 559 | beacon 560 | sparrow 561 | ginger 562 | avocado 563 | red_wolf 564 | ipod 565 | turtle 566 | feeder 567 | killer_whale 568 | great_wall 569 | ac_ground 570 | stool 571 | snowmobile 572 | drumstick 573 | flute 574 | night_snake 575 | pomegranate 576 | water_ouzel 577 | brick 578 | monitor 579 | lampshade 580 | parachute 581 | bald_eagle 582 | coffin 583 | garbage_can 584 | mortar 585 | keyboard 586 | drum 587 | green_mamba 588 | cornmeal 589 | titi_monkey 590 | rosehip 591 | mountain_tent 592 | crepe 593 | polecat 594 | skunk 595 | shakuhachi 596 | banana 597 | golden_retriever 598 | carton 599 | hotdog 600 | golfcart 601 | joystick 602 | spatula 603 | printer 604 | computer_mouse 605 | espresso_maker 606 | melon_seed 607 | screwdriver 608 | dhole 609 | langur 610 | briefcase 611 | paper_crane 612 | beaker 613 | lorikeet 614 | bowtie 615 | panpipe 616 | harp 617 | tresher 618 | panther 619 | comb 620 | american_staffordshire 621 | stapler 622 | frying_pan 623 | kitchen_knife 624 | jordan_logo 625 | bell 626 | sunscreen 627 | dinosaur 628 | apron 629 | chimpanzee 630 | tomato 631 | sponge 632 | red_bayberry 633 | cleaver 634 | space_shuttle 635 | dough 636 | bagel 637 | 
projector 638 | triumphal_arch 639 | camel 640 | coho 641 | potted_plant 642 | owl 643 | hock 644 | bear 645 | bathtub 646 | cumquat 647 | military_vest 648 | cabbage 649 | jacamar 650 | dandie_dinmont 651 | hornbill 652 | ruffed_grouse 653 | cheese 654 | monocycle 655 | ironing_board 656 | brown_bear 657 | smoothing_iron 658 | face_powder 659 | border_terrier 660 | swimming_trunk 661 | ladybug 662 | banded_gecko 663 | gas_tank 664 | yawl 665 | iceberg 666 | mario 667 | violin 668 | remote_control 669 | obelisk 670 | cheetah 671 | vase 672 | pig 673 | apple 674 | can_opener 675 | flatworm 676 | waffle_iron 677 | wardrobe 678 | basketball 679 | lark 680 | wall_clock 681 | kit_fox 682 | modem 683 | strongbox 684 | perfume 685 | staffordshire 686 | cowboy_hat 687 | french_ball 688 | television 689 | digital_watch 690 | cocktail_shaker 691 | kobe_logo 692 | llama 693 | orange 694 | sulphur_crested 695 | platypus 696 | saltshaker 697 | quad_drone 698 | chickadee_bird 699 | adidas_logo1 700 | ladle 701 | flying_disc 702 | crocodile 703 | ice_lolly 704 | canoe 705 | tape_player 706 | bulbul_bird 707 | adhensive_tape 708 | sundial 709 | macaw 710 | gazelle 711 | gypsy_moth 712 | schooner 713 | microsd 714 | baseball_player 715 | wolf 716 | cannon 717 | pencil_sharpener2 718 | spoonbill 719 | mongoose 720 | sarong 721 | soup_bowl 722 | anemone_fish 723 | manx 724 | fox_squirrel 725 | rhinoceros 726 | butterfly 727 | thatch 728 | bee_house 729 | mouthpiece 730 | hare 731 | police_car 732 | tennis_racket 733 | eft_newt 734 | chicken_wings 735 | monarch_butterfly 736 | golf_ball 737 | armour 738 | maraca 739 | drake 740 | buckler 741 | steam_locomotive 742 | nematode 743 | paint_brush 744 | airliner 745 | bolete 746 | cushion 747 | sandwich_cookies 748 | diaper 749 | olive 750 | albatross 751 | lesser_panda 752 | scarerow 753 | panda 754 | memory_stick 755 | spring_scroll 756 | rose 757 | pencil_box 758 | bomb 759 | terrapin_turtle 760 | sidewinder -------------------------------------------------------------------------------- /data/input_fn.py: -------------------------------------------------------------------------------- 1 | """Input function for sampling few-shot image segmentation data from tfrecords.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import argparse 8 | import os 9 | import sys 10 | from typing import List, Optional 11 | import multiprocessing 12 | 13 | import tensorflow as tf 14 | 15 | # File shuffle buffer size should be larger than the number of shards. 16 | # Shuffle buffer size should be larger than the number of examples per shard. 17 | # Cycle length equal to num shards will reduce randomness as only the first 18 | # examples will be read. 1 reads each TFRecord to the end. 19 | _COMPRESSION_TYPE = "GZIP" 20 | _IMAGE_WIDTH = 512 21 | _PREFETCH_BUFFER_SIZE = 1 # Batches for training in inner loop 22 | _FILE_SHUFFLE_BUFFER_SIZE = 256 # Max number of TFRecord files to shuffle, should be larger than the number of shards 23 | _SHUFFLE_BUFFER_SIZE = 512 # Max size when all tasks are loaded into memory, should be much larger than number of examples, capture handle of the shards at a time 24 | _CYCLE_LENGTH = 1 # Number of TFRecords read simultaneously 25 | _BLOCK_LENGTH = 1 # Number of consecutive elements from each iterator 26 | 27 | 28 | def parse_example(example, image_width, scale_to_0_1: bool = False): 29 | """ 30 | Parse a TF Example into corresponding image and mask. 
31 | 
32 | Positive class is assumed to be encoded as the int value 255 in tfrecords.
33 | 
34 | Args:
35 | example: Batch of TF Example protos.
36 | image_width: Width to resize the image and mask to, assumed to be the same as the height.
37 | scale_to_0_1: If True, divide the image by 255.
38 | 
39 | Returns:
40 | A pair of 3 channel float images and 2 channel float segmentation masks.
41 | The first channel in the masks will be the background, the second channel
42 | will be the class of interest.
43 | """
44 | 
45 | features = {
46 | 'image': tf.FixedLenFeature([], tf.string),
47 | 'mask': tf.FixedLenFeature([], tf.string),
48 | }
49 | 
50 | parsed_example = tf.parse_single_example(example, features)
51 | 
52 | image = tf.decode_raw(parsed_example['image'], tf.uint8)
53 | 
54 | image = tf.reshape(image, (image_width, image_width, 3))
55 | image = tf.cast(image, tf.float32)
56 | if scale_to_0_1:
57 | image /= 255.
58 | 
59 | mask = tf.decode_raw(parsed_example['mask'], tf.uint8)
60 | 
61 | mask = tf.reshape(mask, (image_width, image_width))
62 | mask = tf.stack([255 - mask, mask], axis=2)  # Converts pos. class to 2-class.
63 | mask = tf.cast(mask, tf.float32) / 255.
64 | 
65 | return image, mask
66 | 
67 | 
68 | def make_dataset(
69 | file_pattern,
70 | batch_size,
71 | compression_type=_COMPRESSION_TYPE,
72 | image_width=_IMAGE_WIDTH,
73 | prefetch_buffer_size=_PREFETCH_BUFFER_SIZE,
74 | file_shuffle_buffer_size=_FILE_SHUFFLE_BUFFER_SIZE,
75 | shuffle_buffer_size=_SHUFFLE_BUFFER_SIZE,
76 | cycle_length=_CYCLE_LENGTH,
77 | ):
78 | """Makes a TF Dataset from tfrecords.
79 | 
80 | Args:
81 | file_pattern: TF Record filenames ending in a wildcard.
82 | batch_size: Number of training examples per batch.
83 | compression_type: Compression type for reading TFRecords; set to 'GZIP' if compressed, None if no compression.
84 | image_width: Width of image and mask, assumed to be same as height.
85 | prefetch_buffer_size: Number of data elements to preload to device.
86 | file_shuffle_buffer_size: Size of the shuffle buffer for shard filenames.
87 | shuffle_buffer_size: Number of training examples to sample from;
88 | smaller values trade randomness for speed.
89 | cycle_length: Number of TFRecords to read from simultaneously. A value of 1
90 | reads each TFRecord to the end before moving on to the next.
91 | 
92 | Returns:
93 | TF Record dataset for (image, mask) pairs.
94 | """
95 | verbose = True
96 | # TODO: parameterize num_cpus so that when running on hpc can set to <= 28
97 | num_cpus = 20
98 | if verbose:
99 | print("tf dataset listing files matching pattern: {}".format(file_pattern))
100 | dataset = tf.data.Dataset.list_files(file_pattern)  # Shuffles by default.
101 | dataset = dataset.repeat()
102 | dataset = dataset.shuffle(file_shuffle_buffer_size)
103 | # Behavior of interleave with cycle_length 1 is the same as flat_map. Each
104 | # TFRecord will be read until exhaustion before moving on to the next.
105 | dataset = dataset.interleave(lambda filenames: tf.data.TFRecordDataset(
106 | filenames=filenames,
107 | compression_type=compression_type),
108 | cycle_length=cycle_length,
109 | block_length=_BLOCK_LENGTH)
110 | dataset = dataset.map(lambda example: parse_example(example, image_width),
111 | num_parallel_calls=num_cpus)
112 | # META-LEARNING CODE NEEDS THE RESULTING BATCH TO BE A LIST OF UNIQUE EXAMPLES.
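# With cycle_length=1, batches below are read from one shard (i.e. one task) at a
# time, and leaving the example-level shuffle off keeps the examples within a
# batch distinct, so the shuffle calls stay commented out: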
113 | # dataset = dataset.shuffle(shuffle_buffer_size) 114 | # dataset = dataset.shuffle(batch_size) 115 | dataset = dataset.batch(batch_size) 116 | dataset = dataset.prefetch(prefetch_buffer_size) 117 | 118 | return dataset 119 | 120 | 121 | def load_from_tfrecords(filenames, image_width=_IMAGE_WIDTH) -> List: 122 | """ 123 | Loads all examples in the tfrecord into memory. 124 | Example usage: 125 | ``` 126 | import tensorflow as tf 127 | import random 128 | data = load_from_tfrecords(["path.tfrecords"]) 129 | with tf.Session() as sess: 130 | for tuple in random.sample(data, 100): 131 | for tensor in tuple: # Loop through (image, mask) 132 | print(tensor.eval()) 133 | ``` 134 | """ 135 | compression = tf.python_io.TFRecordCompressionType.GZIP 136 | options = tf.python_io.TFRecordOptions(compression) 137 | examples = [] 138 | for file in filenames: 139 | for message in tf.python_io.tf_record_iterator(file, options=options): 140 | examples.append(parse_example(message, image_width=image_width)) 141 | return examples 142 | 143 | # def debug(dataset): 144 | # """Debugging utility for tf.data.Dataset.""" 145 | # iterator = tf.data.Iterator.from_structure( 146 | # dataset.output_types, dataset.output_shapes) 147 | # next_element = iterator.get_next() 148 | # 149 | # ds_init_op = iterator.make_initializer(dataset) 150 | # 151 | # with tf.Session() as sess: 152 | # sess.run(ds_init_op) 153 | # viz(sess, next_element) 154 | # import pdb; pdb.set_trace() 155 | # res = sess.run(next_element) 156 | # # for i in range(len(res)): 157 | # # print("IoU of label with itself:") 158 | # # print(Gecko._iou(res[i][1], res[i][1], class_of_interest_channel=None)) 159 | # print(res) 160 | # 161 | # 162 | # def plot_mask(mask_j: np.ndarray, figure_index=0, channel_index: Optional[int] = None): 163 | # import matplotlib.pyplot as plt 164 | # plt.figure(figure_index) 165 | # if channel_index is None: 166 | # for k in range(mask_j.shape[2]): 167 | # if np.sum(mask_j[:, :, k]) == 0: 168 | # continue 169 | # break 170 | # print("class at channel {}".format(k)) 171 | # else: 172 | # k = channel_index 173 | # plt.imshow(mask_j[:, :, k]) 174 | # plt.show() 175 | # print("IoU of label with itself:") 176 | # print(Gecko._iou(mask_j.copy(), mask_j.copy(), class_of_interest_channel=None, round_labels=True)) 177 | # import pdb; pdb.set_trace() 178 | # return k 179 | # 180 | # 181 | # def viz(sess, next_element, num_to_viz=2): 182 | # try: 183 | # import matplotlib.pyplot as plt 184 | # 185 | # for i in range(num_to_viz): 186 | # res = sess.run(next_element) 187 | # image = res[0].astype(int) 188 | # mask = res[1] 189 | # if len(image.shape) == 4: 190 | # for j in range(image.shape[0]): 191 | # plt.figure(i + j) 192 | # plt.imshow(image[j]) 193 | # plt.show() 194 | # mask_j = mask[j] 195 | # plot_mask(mask_j, i + j) 196 | # else: 197 | # plt.figure(i) 198 | # plt.imshow(image) 199 | # plt.show() 200 | # plot_mask(mask, i ) 201 | # except Exception as e: 202 | # print(e) 203 | # import pdb; pdb.set_trace() -------------------------------------------------------------------------------- /figures/EfficientLab.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml4ai/mliis/f40352e734f77609bcd5c4ad330ea73a897a217d/figures/EfficientLab.png -------------------------------------------------------------------------------- /figures/example_5-shot_predictions.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ml4ai/mliis/f40352e734f77609bcd5c4ad330ea73a897a217d/figures/example_5-shot_predictions.png
--------------------------------------------------------------------------------
/joint_train.py:
--------------------------------------------------------------------------------
1 | """
2 | Trains an image segmentation model with SGD.
3 | 
4 | python joint_train.py --seperate_background_channel --data_dir joint_fewshot_shards_uint8_background_channel --augment --epochs 10 --steps_per_epoch 2 --batch_size 3 --val_batches 2 --sgd --l2 --final_layer_dropout_rate 0.2 --rsd 2 --restore_efficient_net_weights_from models/efficientnet/efficientnet-b0
5 | python joint_train.py --fp_k_test_set --seperate_background_channel --augment --epochs 10 --steps_per_epoch 2 --batch_size 3 --val_batches 2 --sgd --l2 --final_layer_dropout_rate 0.2 --rsd 2 --data_dir joint_fewshot_shards_uint8_background_channel_fp-k-test-set --restore_efficient_net_weights_from models/efficientnet/efficientnet-b0
6 | python joint_train.py --test_on_val_set --seperate_background_channel --data_dir joint_fewshot_shards_uint8_background_channel_val-set/ --augment --epochs 10 --steps_per_epoch 2 --batch_size 3 --val_batches 2 --sgd --l2 --final_layer_dropout_rate 0.2 --rsd 2 --restore_efficient_net_weights_from models/efficientnet/efficientnet-b0
7 | 
8 | """
9 | import argparse
10 | import os
11 | import time
12 | from functools import partial
13 | from typing import List, Tuple, Optional, Callable
14 | 
15 | import numpy as np
16 | import tensorflow as tf
17 | 
18 | from augmenters.np_augmenters import Augmenter, translate, fliplr, additive_gaussian_noise, exposure
19 | from data.fss_1000_utils import TEST_TASK_IDS, TRAIN_TASK_IDS, FP_K_TEST_TASK_IDS
20 | from joint_train.data.input_fn import TFRecordSegmentationDataset
21 | from models.constants import SUPPORTED_MODELS
22 | from models.efficientlab import EfficientLab
23 | from meta_learners.supervised_reptile.supervised_reptile.reptile import Gecko
24 | from utils.util import log_estimated_time_remaining
25 | 
26 | TRAIN_ID = "train"
27 | VAL_ID = "val"
28 | TEST_ID = "test"
29 | 
30 | 
31 | def parse_args():
32 | """
33 | Returns an argument parser for the image segmentation training script.
34 | """
35 | 
36 | parser = argparse.ArgumentParser(description="Train segmentation model via SGD.")
37 | 
38 | # Data
39 | parser.add_argument("--data_dir", help="Path to folder containing tfrecords", required=True)
40 | parser.add_argument("--fp_k_test_set", help="Hold out the test task for the fp-k classes.", action="store_true")
41 | parser.add_argument("--test_on_val_set", help="If specified, will train on train shards and test on val shards, else will train on both train and val and test on test.", action="store_true")
42 | 
43 | # Model
44 | parser.add_argument('--model_name',
45 | help="Name of the model architecture to meta-train. Must be in the set: {}.".format(SUPPORTED_MODELS), required=False,
46 | default='EfficientLab')
47 | parser.add_argument("--rsd", help="List of integers specifying the 1-indexed reduction endpoints from EfficientNet to input into the lightweight skip decoding layers of EfficientLab.", type=int, nargs="+")
48 | parser.add_argument("--feature_extractor_name", help="efficientnet-b0 or efficientnet-b3", type=str, default="efficientnet-b0")
49 | parser.add_argument("--image_size", help="size of image in pixels.
images assumed to be square", type=int, default=224)
50 | parser.add_argument("--seperate_background_channel", help="Whether or not to make a mutually exclusive background channel.", action='store_true', default=False)
51 | 
52 | # Training
53 | parser.add_argument("--restore_efficient_net_weights_from", help="path to dir to restore efficientnet weights from", type=str, default=None)
54 | parser.add_argument('--sgd', help='use vanilla SGD instead of Adam', action='store_true')
55 | parser.add_argument('--loss_name', help='Name of the loss function to use. Should be cross_entropy, cross_entropy_dice, or ce_dice', default='ce_dice')
56 | parser.add_argument("--l2", help="Applies l2 weight decay to all weights in network", action="store_true")
57 | parser.add_argument("--augment", help="Apply augmentations to training data",
58 | action="store_true")
59 | parser.add_argument("--final_layer_dropout_rate", help="Probability of dropping out inputs at the final layer.", type=float, default=0.0)
60 | parser.add_argument('--batch_size', help='Training batch size', default=64, type=int)
61 | parser.add_argument('--epochs', help='Number of training epochs', default=200, type=int)
62 | parser.add_argument("--steps_per_epoch", help="Number of gradient steps to take per epoch. If unspecified will be determined from batch size and number of examples.", type=int, default=None)
63 | parser.add_argument("--learning_rate", default=0.005, type=float)
64 | parser.add_argument("--final_learning_rate", default=5e-7, type=float)
65 | parser.add_argument("--label_smoothing", default=0.0, type=float)
66 | 
67 | 
68 | # Evaluation
69 | parser.add_argument("--val_batches", default=20, type=int)
70 | parser.add_argument('--pretrained', help='Evaluate a pre-trained model.',
71 | action='store_true', default=False)
72 | parser.add_argument('--eval_interval', help='Epochs between validation evaluations', default=2, type=int)
73 | 
74 | # Misc. config
75 | parser.add_argument('--seed', help='random seed', default=0, type=int)
76 | parser.add_argument('--checkpoint', help='Checkpoint directory to write to (or restore from).', default='/tmp/model_checkpoint', type=str)
77 | return parser.parse_args()
78 | 
79 | 
80 | 
81 | def get_model_kwargs(parsed_args):
82 | """
83 | Build the kwargs for model constructors from the
84 | parsed command-line arguments.
85 | """ 86 | parsed_args.model_name = parsed_args.model_name.lower() 87 | if parsed_args.model_name not in SUPPORTED_MODELS: 88 | raise ValueError("Model name must be in the set: {}".format(SUPPORTED_MODELS)) 89 | res = {'learning_rate': parsed_args.learning_rate} 90 | restore_ckpt_dir = parsed_args.restore_efficient_net_weights_from 91 | res["restore_ckpt_dir"] = restore_ckpt_dir 92 | if parsed_args.lsd: 93 | res["rsd"] = parsed_args.lsd 94 | res["feature_extractor_name"] = parsed_args.feature_extractor_name 95 | res["l2"] = parsed_args.l2 96 | res["final_layer_dropout_rate"] = parsed_args.final_layer_dropout_rate 97 | res["label_smoothing"] = parsed_args.label_smoothing 98 | if "dice" not in parsed_args.loss_name: 99 | res["dice"] = False 100 | if parsed_args.sgd: 101 | res['optimizer'] = tf.train.GradientDescentOptimizer 102 | else: 103 | res['optimizer'] = partial(tf.train.AdamOptimizer, beta1=0) 104 | res['loss_name'] = parsed_args.loss_name 105 | res["n_rows"] = parsed_args.image_size 106 | res["n_cols"] = parsed_args.image_size 107 | return res 108 | 109 | 110 | def after_step(): 111 | """Function to be called after a step of gradient descent""" 112 | raise NotImplementedError 113 | 114 | 115 | def after_epoch(): 116 | """Function to be called after an epoch""" 117 | raise NotImplementedError 118 | 119 | 120 | def get_train_test_shards_from_dir(data_dir, ext: str = ".tfrecord.gzip", test_on_val_set: bool = False): 121 | all_shards = os.listdir(data_dir) 122 | all_shards = [x for x in all_shards if ext in x] 123 | train_shards = [x for x in all_shards if TEST_ID not in x] 124 | test_shards = [x for x in all_shards if TRAIN_ID not in x] 125 | 126 | if test_on_val_set: 127 | train_shards = [x for x in train_shards if VAL_ID not in x] 128 | test_shards = [x for x in all_shards if VAL_ID in x] 129 | assert len(set(train_shards + test_shards)) == len(all_shards) - len([x for x in all_shards if TEST_ID in x]) 130 | else: 131 | assert len(set(train_shards + test_shards)) == len(all_shards) 132 | 133 | assert len(set(test_shards).intersection(set(train_shards))) == 0 134 | return [os.path.join(data_dir, x) for x in train_shards], [os.path.join(data_dir, x) for x in test_shards] 135 | 136 | 137 | def get_training_data(data_dir: str, num_classes: int, batch_size: int, image_size: int, ext: str = ".tfrecord.gzip", augment:bool = False, seperate_background_channel: bool = True, test_on_val_set: bool = False) -> Tuple[tf.Tensor, tf.Tensor, tf.Operation]: 138 | train_shards, test_shards = get_train_test_shards_from_dir(data_dir, ext, test_on_val_set=test_on_val_set) 139 | 140 | if augment: 141 | if seperate_background_channel: 142 | mask_filled_translate = partial(translate, mask_fill=[1] + [0] * num_classes) 143 | else: 144 | mask_filled_translate = partial(translate, mask_fill=[0] * num_classes) 145 | 146 | augmenter = Augmenter(aug_funcs=[mask_filled_translate, fliplr, additive_gaussian_noise, exposure]) 147 | else: 148 | augmenter = None 149 | dataset = TFRecordSegmentationDataset(tfrecord_paths=train_shards, image_width=image_size, mask_channels=num_classes, augmenter=augmenter, seperate_background_channel=seperate_background_channel) 150 | dataset, ds_init_op = dataset.make_dataset(batch_size) 151 | return dataset, ds_init_op 152 | 153 | 154 | def train(sess: tf.Session, model: EfficientLab, dataset_init_op: tf.Operation, epochs: int, steps_per_epoch: int, images, masks, save_dir: str, lr_fn: Callable, restore_ckpt_dir: Optional[str] = None, val_batches: int = 20, 
save_checkpoint_every_n_epochs: int = 2, time_deadline=None, max_checkpoints_to_keep: int = 2, eval_interval: int = 2, report_allocated_tensors_on_oom: bool = True):
155 | """
156 | Trains the model with SGD, periodically validating and checkpointing.
157 | Args:
158 | sess: The tf.Session to run training in.
159 | model: The EfficientLab model to train.
160 | dataset_init_op: Initializer op for the training dataset iterator.
161 | epochs: Number of epochs to train for.
162 | steps_per_epoch: Number of gradient steps to take per epoch.
163 | images: Input image batch tensor produced by the data pipeline.
164 | masks: Label mask batch tensor produced by the data pipeline.
165 | save_dir: Directory to write checkpoints and summaries to.
166 | lr_fn: A function that takes in the epoch number and returns the learning rate. For a constant learning rate, define a lambda: lr_fn = lambda i: lr
167 | val_batches: Number of batches to evaluate at the end of each epoch
168 | save_checkpoint_every_n_epochs: How often, in epochs, to save a checkpoint.
169 | time_deadline: Optional wall-clock time after which training stops.
170 | max_checkpoints_to_keep: Passed to tf.train.Saver as max_to_keep.
171 | 
172 | Returns:
173 | None. Prints the history of IoU estimates when training completes.
174 | """
175 | assert isinstance(epochs, int)
176 | assert isinstance(steps_per_epoch, int)
177 | 
178 | if not os.path.exists(save_dir):
179 | os.mkdir(save_dir)
180 | print("Logging to {}".format(save_dir))
181 | 
182 | saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)
183 | 
184 | if restore_ckpt_dir is not None:
185 | print("Restoring from checkpoint {}".format(restore_ckpt_dir))
186 | model.restore_model(sess, restore_ckpt_dir, filter_to_scopes=[model.feature_extractor_name])
187 | 
188 | try:
189 | if not model.variables_initialized:
190 | print("Initializing variables.")
191 | sess.run(tf.global_variables_initializer())
192 | 
193 | except AttributeError:
194 | print("Model does not explicitly track whether variable initialization has already been run on the graph at attribute .variables_initialized.")
195 | print("Initializing variables.")
196 | sess.run(tf.global_variables_initializer())
197 | 
198 | 
199 | print("Training...")
200 | sess.run(dataset_init_op)
201 | 
202 | print("Saving graph definition to {}.".format(save_dir))
203 | saver.save(sess, os.path.join(save_dir, 'model.ckpt'), global_step=0)
204 | tf.summary.FileWriter(os.path.join(save_dir, 'train'), sess.graph)
205 | 
206 | if report_allocated_tensors_on_oom:
207 | run_opts = tf.RunOptions(report_tensor_allocations_upon_oom=True)
208 | else:
209 | run_opts = None
210 | 
211 | ious = []
212 | for i in range(epochs):
213 | start_time = time.time()
214 | print('Epoch: ', i)
215 | lr = lr_fn(i)
216 | print("lr: ", lr)
217 | for _ in range(steps_per_epoch):
218 | try:
219 | _ = sess.run(model.minimize_op, feed_dict={model.lr_ph: lr}, options=run_opts)
220 | except tf.errors.OutOfRangeError:
221 | sess.run(dataset_init_op, options=run_opts)
222 | print("Finished epoch {} with {} steps.".format(i, steps_per_epoch))
223 | epoch_minutes = log_estimated_time_remaining(start_time, i, epochs, unit_name="epoch")
224 | iters_per_sec = steps_per_epoch / (epoch_minutes * 60)
225 | print("Iterations per second: {}".format(iters_per_sec))
226 | 
227 | if i % eval_interval == 0:
228 | # TODO implement val set accuracy callback
229 | print("Validating")
230 | iou, loss = iou_callback(sess, model, val_batches, run_opts)
231 | print("Loss: {}".format(loss))
232 | print("IoU on epoch {} estimated on {} batches:".format(i, val_batches))
233 | print(iou)
234 | ious.append(iou)
235 | 
236 | if i % save_checkpoint_every_n_epochs == 0 or i == epochs - 1:
237 | print("Saving checkpoint to {}.".format(save_dir))
238 | saver.save(sess, os.path.join(save_dir, 'model.ckpt'), global_step=i)
239 | 
240 | if time_deadline is not None and time.time() > time_deadline:
241 | break
242 | 
243 | print("Training complete.
History:") 244 | print("Train set Intersection over Union (IoU):") 245 | print(ious) 246 | 247 | 248 | def iou_callback(sess, model: EfficientLab, val_batches, run_opts): 249 | ious = [] 250 | losses = [] 251 | for _ in range(val_batches): 252 | images, preds, labels, loss = sess.run([model.input_ph, model.predictions, model.label_ph, model.loss], options=run_opts, feed_dict={model.is_training_ph: False}) 253 | # viz(images, preds, labels) 254 | ious.append(compute_iou_metric(preds, labels)) 255 | losses.append(loss) 256 | iou = np.nanmean(ious) 257 | loss = np.nanmean(losses) 258 | return iou, loss 259 | 260 | 261 | def compute_iou_metric(predictions: np.ndarray, labels: np.ndarray): 262 | assert len(predictions) == len(labels) 263 | assert len(predictions.shape) == 4 264 | # Pass prediction and label arrays to _iou: 265 | iou = [Gecko._iou(predictions[i], labels[i], class_of_interest_channel=None) for i in range(predictions.shape[0])] 266 | iou = np.nanmean(iou) 267 | return iou 268 | 269 | 270 | def viz(images, preds, labels): 271 | from utils.debug_tf_dataset import plot_mask 272 | import matplotlib.pyplot as plt 273 | 274 | images = images / 255. 275 | 276 | if len(images.shape) == 4: 277 | for j in range(images.shape[0]): 278 | print("image") 279 | plt.figure(j) 280 | plt.imshow(images[j]) 281 | plt.show() 282 | print("label mask") 283 | mask_j = labels[j] 284 | k = plot_mask(mask_j, j + 1) 285 | print("predicted mask") 286 | pred = preds[j] 287 | plot_mask(pred, j + 2, channel_index=k) 288 | else: 289 | plt.figure(0) 290 | plt.imshow(images) 291 | plt.show() 292 | plot_mask(labels, 1) 293 | 294 | 295 | def main(): 296 | # Reference: https://github.com/SMHendryx/tf-segmentation-trainer/blob/master/train.py 297 | start = time.time() 298 | # Args: 299 | args = parse_args() 300 | data_dir = args.data_dir 301 | learning_rate = args.learning_rate 302 | final_learning_rate = args.final_learning_rate 303 | epochs = args.epochs 304 | 305 | #all_classes, train_classes = get_classes_from_dir(data_dir, ext=".tfrecord.gzip") 306 | 307 | train_classes, test_classes = TRAIN_TASK_IDS, TEST_TASK_IDS 308 | 309 | all_classes = sorted(list(train_classes + test_classes)) 310 | 311 | if args.fp_k_test_set: 312 | test_classes = FP_K_TEST_TASK_IDS 313 | train_classes = [x for x in all_classes if x not in test_classes] 314 | 315 | assert len(set(test_classes).intersection(set(train_classes))) == 0, "train-test class names overlap" 316 | assert len(train_classes + test_classes) == len(set(all_classes)) 317 | 318 | num_classes = len(all_classes) 319 | next_element, dataset_init_op = get_training_data(data_dir, num_classes=num_classes, batch_size=args.batch_size, image_size=args.image_size, augment=args.augment, seperate_background_channel=args.seperate_background_channel, test_on_val_set=args.test_on_val_set) 320 | images = next_element[0] 321 | masks = next_element[1] 322 | 323 | model_kwargs = get_model_kwargs(args) 324 | restore_ckpt_dir = model_kwargs["restore_ckpt_dir"] 325 | model = EfficientLab(images=images, labels=masks, n_classes=num_classes, seperate_background_channel=args.seperate_background_channel, binary_iou_loss=False, **model_kwargs) 326 | 327 | if args.steps_per_epoch is None: 328 | steps_per_epoch = int(760 * 10 // args.batch_size) 329 | else: 330 | steps_per_epoch = args.steps_per_epoch 331 | 332 | def lr_fn(i, epochs=epochs, initial_lr=learning_rate, final_lr=final_learning_rate): 333 | frac_done = i / epochs 334 | cur_lr = frac_done * final_lr + (1 - frac_done) * initial_lr 335 | return 
cur_lr 336 | 337 | with tf.Session() as sess: 338 | train(sess, model, dataset_init_op, args.epochs, steps_per_epoch=steps_per_epoch, save_dir=args.checkpoint, 339 | lr_fn=lr_fn, val_batches=args.val_batches, images=images, masks=masks, eval_interval=args.eval_interval, restore_ckpt_dir=restore_ckpt_dir) 340 | 341 | print("Finished training") 342 | end = time.time() 343 | print("Experiment took {} hours".format((end - start) / 3600.)) 344 | 345 | 346 | if __name__ == '__main__': 347 | main() 348 | -------------------------------------------------------------------------------- /joint_train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml4ai/mliis/f40352e734f77609bcd5c4ad330ea73a897a217d/joint_train/__init__.py -------------------------------------------------------------------------------- /joint_train/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml4ai/mliis/f40352e734f77609bcd5c4ad330ea73a897a217d/joint_train/data/__init__.py -------------------------------------------------------------------------------- /joint_train/data/constants.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | SERIALIZED_DTYPE = np.uint8 -------------------------------------------------------------------------------- /joint_train/data/input_fn.py: -------------------------------------------------------------------------------- 1 | """Input function for sampling few-shot image segmentation data from tfrecords.""" 2 | 3 | from typing import List, Optional 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | from joint_train.data.constants import SERIALIZED_DTYPE 9 | from augmenters.np_augmenters import Augmenter 10 | from utils.debug_tf_dataset import debug 11 | 12 | _COMPRESSION_TYPE = "GZIP" 13 | _PREFETCH_BUFFER_SIZE = 10 # Prefetch batches for training. Reduce if testing locally. 14 | # Shuffle buffer size should be larger than the number of examples per shard. 15 | # Cycle length equal to num shards will reduce randomness as only the first 16 | # examples will be read. 1 reads each TFRecord to the end. 17 | _SHUFFLE_BUFFER_SIZE = 400 # Reduce if testing locally. # Maintain a buffer of _SHUFFLE_BUFFER_SIZE elements, and randomly select the next element from that buffer 18 | _NUM_SUBPROCESSES = 28 19 | _NP_TO_TF_DTYPES = {np.uint8: tf.uint8} 20 | _DEFAULT_READER_BUFFER_SIZE_BYTES = int(2.56e8) 21 | 22 | 23 | class TFRecordSegmentationDataset: 24 | 25 | def __init__(self, tfrecord_paths: List[str], image_width:int, 26 | image_channels=3, mask_channels=1, seed=0, augmenter: Optional[Augmenter]=None, seperate_background_channel: bool = True): 27 | """ 28 | Image segmentation tf.data.Dataset constructor for batching from tfrecords. 29 | Args: 30 | tfrecord_paths: list of paths to tfrecords 31 | image_width: side of images and masks (all images and masks assumed to be square and of the same size) 32 | image_channels: Int number of channels in the images. 33 | mask_channels: Int number of channels in the masks. 34 | seed: Integer seed for the RNG in the data pipeline. 
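seperate_background_channel: If True, an extra, mutually exclusive background channel is added, so the masks have mask_channels + 1 channels in total.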
35 | augmenter: An object with an apply_augmentations method that will be wrapped with tf.py_func and applied to input examples 36 | """ 37 | self.tfrecord_paths = tfrecord_paths 38 | self.num_shards = len(self.tfrecord_paths) 39 | self.tfrecord_paths_tensor = tf.constant(self.tfrecord_paths) 40 | self.image_width = image_width 41 | self.image_channels = image_channels 42 | if seperate_background_channel: 43 | mask_channels += 1 44 | self.mask_channels = mask_channels 45 | print("building dataset with labels with {} mask channels".format(self.mask_channels)) 46 | self.serialized_image_raw_dtype = _NP_TO_TF_DTYPES[SERIALIZED_DTYPE] 47 | self.serialized_mask_raw_dtype = _NP_TO_TF_DTYPES[SERIALIZED_DTYPE] 48 | self.seed = seed 49 | self.augmenter = augmenter 50 | 51 | def make_dataset(self, batch_size=8, num_parallel_calls=_NUM_SUBPROCESSES, num_concurrent_reads: Optional[int] = None): 52 | """ 53 | Parse tfrecords from paths, shuffle and batch. 54 | the next element in dataset op and the dataset initializer op. 55 | Args: 56 | batch_size: Number of images/masks in each batch returned. 57 | num_parallel_calls: Number of parallel subprocesses for Dataset.map calls. 58 | num_concurrent_reads: Interleave reading from this number of tfrecords. Sets `cycle_length` param to 59 | interleave. If unspecified, will be set to the number of shards. 60 | Returns: 61 | next_element: A tensor with shape [2], where next_element[0] 62 | is image batch, next_element[1] is the corresponding 63 | mask batch. 64 | init_op: Data initializer op, needs to be executed in a session 65 | for the data queue to be filled up and the next_element op 66 | to yield batches. 67 | """ 68 | print("Making dataset from shards {}".format(self.tfrecord_paths)) 69 | dataset = tf.data.Dataset.from_tensor_slices(self.tfrecord_paths) 70 | dataset = dataset.repeat() 71 | dataset = dataset.shuffle(self.num_shards) 72 | if num_concurrent_reads is None: 73 | num_concurrent_reads = self.num_shards 74 | dataset = dataset.interleave( 75 | lambda filename: tf.data.TFRecordDataset(filename, compression_type=_COMPRESSION_TYPE, buffer_size=_DEFAULT_READER_BUFFER_SIZE_BYTES), 76 | cycle_length=num_concurrent_reads, block_length=1) 77 | dataset = dataset.map(lambda record: self._parse_example(record), num_parallel_calls=num_parallel_calls) 78 | 79 | if self.augmenter is not None: 80 | apply_augs = lambda i, m: self.augmenter.apply_augmentations(i, m, return_image_mask_in_list=False) 81 | dataset = dataset.map( 82 | lambda image, mask: tf.py_func(apply_augs, [image, mask], [tf.float32, tf.float32]), 83 | num_parallel_calls=num_parallel_calls) 84 | dataset = dataset.map(lambda image, mask: ( 85 | tf.reshape(image, (self.image_width, self.image_width, self.image_channels)), 86 | tf.reshape(mask, (self.image_width, self.image_width, self.mask_channels))), 87 | num_parallel_calls=num_parallel_calls) 88 | 89 | dataset = dataset.shuffle(_SHUFFLE_BUFFER_SIZE) 90 | dataset = dataset.batch(batch_size) 91 | dataset = dataset.prefetch(_PREFETCH_BUFFER_SIZE) 92 | # debug(dataset) # Uncomment to visualize examples 93 | iterator = tf.data.Iterator.from_structure( 94 | dataset.output_types, dataset.output_shapes) 95 | next_element = iterator.get_next() 96 | ds_init_op = iterator.make_initializer(dataset) 97 | 98 | return next_element, ds_init_op 99 | 100 | def _parse_example(self, example, scale_to_0_1: bool = False): 101 | """ 102 | Parse a TF Example into corresponding image and mask. 
103 | 104 | Positive class is assumed to be encoded as the int value 255 in tfrecords. 105 | 106 | Args: 107 | record_name: name of the tfrecord indicating the class 108 | example: Batch of TF Example protos. 109 | image_width: Width to resize to of image and mask, assumed to be same as height. 110 | scale_to_0_1: if True, scale images by dividing by 255. 111 | 112 | Returns: 113 | A pair of 3 channel float images and n-channel float segmentation masks. 114 | The first channel in the masks will be the background, the rest 115 | will be the classes of interest. 116 | """ 117 | 118 | features = { 119 | 'image': tf.FixedLenFeature([], tf.string), 120 | 'mask': tf.FixedLenFeature([], tf.string), 121 | } 122 | parsed_example = tf.parse_single_example(example, features) 123 | 124 | image = tf.decode_raw(parsed_example['image'], self.serialized_image_raw_dtype) 125 | image = tf.reshape(image, (self.image_width, self.image_width, self.image_channels)) 126 | image = tf.cast(image, tf.float32) 127 | if scale_to_0_1: 128 | image /= 255. 129 | 130 | mask = tf.decode_raw(parsed_example['mask'], self.serialized_mask_raw_dtype) 131 | mask = tf.reshape(mask, (self.image_width, self.image_width, self.mask_channels)) 132 | mask = tf.cast(mask, tf.float32) / 255. 133 | return image, mask 134 | 135 | 136 | def parse_example(example, image_width:int = 224, image_channels: int = 3, mask_channels: int = 1000, scale_to_0_1: bool = False, serialized_mask_raw_dtype = tf.float64): 137 | """ 138 | Parse a TF Example into corresponding image and mask. 139 | 140 | Positive class is assumed to be encoded as the int value 255 in tfrecords. 141 | 142 | Args: 143 | record_name: name of the tfrecord indicating the class 144 | example: Batch of TF Example protos. 145 | image_width: Width to resize to of image and mask, assumed to be same as height. 146 | scale_to_0_1: if True, divide by 255. 147 | 148 | Returns: 149 | A pair of 3 channel float images and n-channel float segmentation masks. 150 | The first channel in the masks will be the background, the rest 151 | will be the classes of interest. 152 | """ 153 | 154 | features = { 155 | 'image': tf.FixedLenFeature([], tf.string), 156 | 'mask': tf.FixedLenFeature([], tf.string), 157 | } 158 | 159 | parsed_example = tf.parse_single_example(example, features) 160 | 161 | image = tf.decode_raw(parsed_example['image'], tf.uint8) 162 | image = tf.reshape(image, (image_width, image_width, image_channels)) 163 | image = tf.cast(image, tf.float32) 164 | if scale_to_0_1: 165 | image /= 255. 166 | 167 | mask = tf.decode_raw(parsed_example['mask'], serialized_mask_raw_dtype) # tf.uint8) 168 | mask = tf.reshape(mask, (image_width, image_width, mask_channels)) 169 | mask = tf.cast(mask, tf.float32) / 255. 170 | return image, mask 171 | 172 | 173 | def load_from_tfrecords(filenames) -> List: 174 | """ 175 | Loads all examples in the tfrecord into memory. 
176 | Example usage:
177 | ```
178 | import tensorflow as tf
179 | import random
180 | data = load_from_tfrecords(["path.tfrecords"])
181 | with tf.Session() as sess:
182 | for pair in random.sample(data, 100):
183 | for tensor in pair:  # Loop through (image, mask)
184 | print(tensor.eval())
185 | ```
186 | """
187 | compression = tf.python_io.TFRecordCompressionType.GZIP
188 | options = tf.python_io.TFRecordOptions(compression)
189 | examples = []
190 | for file in filenames:
191 | for message in tf.python_io.tf_record_iterator(file, options=options):
192 | examples.append(parse_example(message))
193 | break  # NOTE: this break cuts reading short, so only a subset of the records is loaded.
194 | return examples
195 | 
--------------------------------------------------------------------------------
/make_python_virtualenv.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | # Makes Python 3 virtual environment in a directory called .env
4 | # Tested on OSX (needs python3 with the venv module and pip3 installed)
5 | 
6 | pip3 install -U pip  # Upgrade pip itself (the PyPI package is named "pip").
7 | pip3 install -U setuptools
8 | 
9 | pathToPython3=$(which python3)
10 | echo $pathToPython3
11 | 
12 | python3 -m venv .env
13 | #virtualenv -p $pathToPython3 .env
14 | #venv -p $pathToPython3 .env
15 | 
16 | source .env/bin/activate
17 | 
18 | pip3 install -r requirements.txt
19 | 
20 | 
21 | # To deactivate run:
22 | # deactivate
23 | 
--------------------------------------------------------------------------------
/meta_learners/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml4ai/mliis/f40352e734f77609bcd5c4ad330ea73a897a217d/meta_learners/__init__.py
--------------------------------------------------------------------------------
/meta_learners/hyperparam_search.py:
--------------------------------------------------------------------------------
1 | """
2 | Functions to estimate optimal hyperparameters when adapting to unseen tasks.
3 | 
4 | Contains the implementation of update hyperparameter optimization (UHO) with Bayesian optimization and Gaussian processes.
5 | """
6 | 
7 | from collections import deque
8 | import operator
9 | import os
10 | from copy import copy
11 | from typing import Callable, Dict, List, Tuple, Optional, Any
12 | 
13 | import numpy as np
14 | import pandas as pd
15 | from skopt import Optimizer
16 | from skopt.space import Categorical, Real, Integer
17 | 
18 | DROPOUT_RATE_NAME = "drop_rate"
19 | AUG_RATE_NAME = "aug_rate"
20 | BATCH_SIZE_NAME = "inner_batch_size"
21 | LEARNING_RATE_NAME = "lr"
22 | SUPPORTED_SEARCH_ALGS = {"GP"}
23 | 
24 | class EarlyStopper:
25 | """
26 | Computes stopping criterion given a metric and a patience.
27 | """
28 | 
29 | def __init__(self, patience: int = 10, metric_should_increase: bool = True, min_steps: int = 0):
30 | """
31 | Args:
32 | patience: How many evaluations without improvement to tolerate before stopping.
33 | metric_should_increase: If True, metric is expected to increase (i.e. set to True for accuracy or IoU,
34 | False for a loss function such as cross entropy).
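min_steps: Minimum number of training steps to take before early stopping is allowed to trigger.

Example (a minimal sketch; the metric values are made up):
```
stopper = EarlyStopper(patience=3)
for step, miou in enumerate([0.50, 0.60, 0.60, 0.60, 0.60, 0.60]):
    if not stopper.continue_training(miou, total_steps_taken=step):
        break  # Stop once mIoU has not improved for more than `patience` evals.
```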
35 | """ 36 | self.patience = patience 37 | self.metric_should_increase = metric_should_increase 38 | if metric_should_increase: 39 | self.eval_operator = operator.gt 40 | else: 41 | self.eval_operator = operator.lt 42 | self._best_metric = None 43 | self._best_num_steps = None 44 | self.num_evals_without_improving = 0 45 | self.min_steps = min_steps 46 | if min_steps > 0: 47 | self._best_num_steps = min_steps 48 | print("Built EarlyStopper with patience {}".format(self.patience)) 49 | 50 | def continue_training(self, metric, total_steps_taken): 51 | if total_steps_taken <= self.min_steps: 52 | self._best_metric = metric 53 | return True 54 | elif self._best_metric is None or self.eval_operator(metric, self._best_metric): 55 | self.num_evals_without_improving = 0 56 | self._best_metric = metric 57 | self._best_num_steps = total_steps_taken 58 | else: 59 | self.num_evals_without_improving += 1 60 | if self.num_evals_without_improving > self.patience: 61 | return False 62 | return True 63 | 64 | def best_metric(self): 65 | return self._best_metric 66 | 67 | def best_num_steps(self): 68 | return self._best_num_steps 69 | 70 | 71 | def run_m(eval_fn: Callable, params: Dict, m: int = 1): 72 | """ 73 | Calls `eval_fn` with `params`, returns results. Assumes that eval_fn returns a tuple of: 74 | (list of task IDs, list of iterations, and list of metrics). 75 | 76 | Args: 77 | eval_fn: A callable that takes `params` and returns a tuple of: 78 | (list of task IDs, list of iterations, and list of metrics). 79 | params: kwargs fed into eval_fn call 80 | m: Number of times to eval_fn(**params). 81 | 82 | Returns: 83 | The metrics returned by running eval_fn with params 84 | """ 85 | all_task_ids, all_num_steps, all_metrics = [], [], [] 86 | for _ in range(m): 87 | task_ids, num_steps, metrics = eval_fn(**params) 88 | all_task_ids.extend(task_ids) 89 | all_num_steps.extend(num_steps) 90 | all_metrics.extend(metrics) 91 | return all_task_ids, all_num_steps, all_metrics 92 | 93 | 94 | def save_results(results: List[Tuple[Dict, Tuple[List, List, List]]], path: str, metric_name: str = "mIoU", append_if_exists: bool = False): 95 | """Takes the results and saves them to csv""" 96 | print("Saving results to {}".format(path)) 97 | # Format results to make a dataframe: 98 | formatted = {"task_ID": [], "best_num_steps": [], metric_name: []} 99 | for result in results: 100 | config, config_results = result 101 | task_ids, num_steps, metrics = config_results 102 | 103 | for key, val in config.items(): 104 | try: 105 | formatted[key].extend([val for _ in range(len(task_ids))]) 106 | except KeyError: 107 | formatted[key] = [val for _ in range(len(task_ids))] 108 | 109 | formatted["task_ID"].extend(task_ids) 110 | formatted["best_num_steps"].extend(num_steps) 111 | formatted[metric_name].extend(metrics) 112 | df = pd.DataFrame(formatted) 113 | mode = "w" 114 | header = True 115 | if os.path.exists(path): 116 | if not append_if_exists: 117 | i = 0 118 | while True: 119 | new_path = path + "_{}".format(i) 120 | if not os.path.exists(new_path): 121 | break 122 | i += 1 123 | path = new_path 124 | mode = "w" 125 | header = True 126 | else: 127 | mode = "a" 128 | header = False 129 | df.to_csv(path, index=False, mode=mode, header=header) 130 | print("Saved optimization raw results to {}".format(path)) 131 | 132 | 133 | def compute_best_configuration(results_list, metric_should_increase=True): 134 | if metric_should_increase: 135 | eval_operator = operator.gt 136 | best_metric = -np.inf 137 | else: 138 | eval_operator = 
operator.lt
139 | best_metric = np.inf
140 | 
141 | for sampled_config, results in results_list:
142 | task_ids, num_steps, metrics = results
143 | miou_across_tasks = np.mean(metrics)
144 | if eval_operator(miou_across_tasks, best_metric):
145 | best_config = sampled_config
146 | best_metric = miou_across_tasks
147 | m_best_num_steps = np.median(num_steps)
148 | best_step_num = m_best_num_steps
149 | 
150 | print("Best mIoU found: {}".format(best_metric))
151 | print("with median iteration: {}".format(best_step_num))
152 | print("and config: {}".format(best_config))
153 | 
154 | return best_config, int(best_step_num), best_metric
155 | 
156 | 
157 | def log_opt_progress(hyperparams, results_i, task_ids, num_steps, metrics, save_results_to):
158 | print("Results for hyperparams {}: task IDs: {}, best num steps: {}, mIoUs: {}".format(hyperparams, task_ids, num_steps,
159 | metrics))
160 | print("mean mIoU: {}".format(np.nanmean(metrics)))
161 | 
162 | if save_results_to is not None:
163 | save_results([results_i], save_results_to, append_if_exists=True)
164 | 
165 | 
166 | def insert_sampled_into_full_set_of_hyperparams(sampled, hyperparams) -> Dict:
167 | for key, val in sampled.items():
168 | hyperparams[key] = val
169 | return hyperparams
170 | 
171 | 
172 | def get_dim_type(value: List[Any]):
173 | value = value[0]
174 | if isinstance(value, float):
175 | return Real
176 | elif isinstance(value, int):
177 | return Integer
178 | elif isinstance(value, str):
179 | return Categorical
180 | else:
181 | raise ValueError("Value must be float, int, or str, but {} is {}".format(value, type(value)))
182 | 
183 | 
184 | def gp_update_hyperparameter_optimization(eval_fn: Callable, hyperparams: Dict, search_key_ranges: Dict, n: int,
185 | save_results_to: Optional[str] = "gp_hyper_param_search_results.csv", m: int = 1,
186 | metric_should_increase: bool = True, metric_name: str = "mIoU", base: int = 2,
187 | n_initial_points: Optional[int] = None, prior: str = "log-uniform"):
188 | """
189 | Multitask hyperparameter search with Gaussian process regression of values in search_key_ranges.
190 | Calls `eval_fn` with `hyperparams`, replacing values in `hyperparams` with expected-improvement-maximizing values
191 | sampled from the ranges in `search_key_ranges` for the keys that are in both `hyperparams` and `search_key_ranges`.
192 | 
193 | Args:
194 | eval_fn: The function to call with hyperparams; must return a tuple of (list of task IDs, list of best step counts, list of metrics).
195 | hyperparams: Dictionary of kwargs that must be specified to call eval_fn.
196 | search_key_ranges: Dictionary mapping a key in hyperparams to a range to sample from.
197 | n: Number of hyperparameter configurations to sample.
198 | m: Number of times to call eval_fn per sampled configuration (e.g., over m train-val splits).
199 | metric_should_increase: If True, larger metric values are better (e.g. mIoU); if False, smaller is better (e.g. a loss).
200 | prior: Distribution to sample points from, e.g. "log-uniform" samples from a log-scaled uniform distribution.
201 | 
202 | Returns:
203 | Tuple of (best_config, expected_best_step_num, best_metric, results), where best_config maps the searched keys to their best sampled values.
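
Example (a minimal sketch; `my_eval_fn` and the ranges shown are hypothetical):
```
def my_eval_fn(lr, drop_rate):
    # Adapt the model with the given hyperparameters and return lists of
    # task IDs, best step counts, and metrics, one entry per task.
    return ["task_a"], [10], [0.7]

best_config, best_steps, best_metric, results = gp_update_hyperparameter_optimization(
    eval_fn=my_eval_fn,
    hyperparams={"lr": 0.01, "drop_rate": 0.1},
    search_key_ranges={"lr": [5e-4, 5e-2], "drop_rate": [1e-3, 0.5]},
    n=20)
```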
204 | """ 205 | for key in search_key_ranges.keys(): 206 | assert key in hyperparams, "key: {} not in hyperparams: {}".format(key, hyperparams) 207 | 208 | if n_initial_points is None: 209 | n_initial_points = int(n / 2) 210 | print("Sampling {} points initially at random.".format(n_initial_points)) 211 | 212 | search_dim_types = {key: get_dim_type(val) for key, val in search_key_ranges.items()} 213 | 214 | dims = [search_dim_types[key](domain[0], domain[1], prior=prior, base=base, name=key) for key, domain in search_key_ranges.items() if domain[0] != domain[1]] 215 | dim_names = [dim.name for dim in dims] 216 | opt = Optimizer( 217 | dims, 218 | "GP", # Estimate metric as a function of lr using a Gaussian Process. 219 | acq_func='EI', # Use Expected Improvement as an acquisition function. 220 | acq_optimizer="lbfgs", # Draw random samples from GP then optimize to find best lr to suggest. 221 | n_initial_points=n_initial_points, # First points will be completely random to avoid exploiting too early. 222 | ) 223 | 224 | results = [] 225 | for i in range(n): 226 | print("Running configuration sample {} of {}.".format(i + 1, n)) 227 | print("With sampled hyperparams:") 228 | sampled_list = opt.ask() 229 | sampled = {name: x for name, x in zip(dim_names, sampled_list)} 230 | print(sampled) 231 | 232 | hyperparams = insert_sampled_into_full_set_of_hyperparams(sampled, hyperparams) 233 | 234 | task_ids, num_steps, metrics = run_m(eval_fn, hyperparams, m) 235 | 236 | # Most recent metric observed for given params 237 | objective = np.nanmean(metrics) 238 | if metric_should_increase: 239 | objective *= -1 240 | print("Objective value at sample {} of {}: {}".format(i + 1, n, objective)) 241 | opt_result = opt.tell(sampled_list, objective) 242 | 243 | results_i = (sampled, (task_ids, num_steps, metrics)) 244 | results.append(results_i) 245 | log_opt_progress(hyperparams, results_i, task_ids, num_steps, metrics, save_results_to) 246 | 247 | best_config, expected_best_step_num, best_metric = compute_best_configuration(results, metric_should_increase) 248 | 249 | return best_config, expected_best_step_num, best_metric, results 250 | 251 | 252 | def lr_droprate_aug_rate_batch_size_gp_search(eval_fn: Callable, params: Dict, lr_name: str = LEARNING_RATE_NAME, lr_search_range_low: float = 0.0005, lr_search_range_high: float = 0.05, 253 | droprate_name: str = DROPOUT_RATE_NAME, drop_rate_search_range_low: float = 0.2, drop_rate_search_range_high: float = 0.2, 254 | aug_rate_name: str = AUG_RATE_NAME, aug_rate_search_range_low: float = 0.5, aug_rate_search_range_high: float = 0.5, 255 | batch_size_name: str = BATCH_SIZE_NAME, batch_size_search_range_low: int = 8, batch_size_search_range_high: int = 8, 256 | n: int = 100, 257 | save_results_to: str = "hyper_param_search_results.csv", m: int = 1, 258 | metric_should_increase: bool = True, metric_name: str = "mIoU") -> Tuple[float, int]: 259 | """ 260 | Performs search over learning rates by randomly sampling within range and successively reducing the range based on 261 | top x percent of results. Returns the best learning rate and expected number of iterations. 
262 | """ 263 | lr_range = [float(lr_search_range_low), float(lr_search_range_high)] 264 | if lr_range[0] > lr_range[1]: 265 | lr_range[0], lr_range[1] = lr_range[1], lr_range[0] 266 | drop_range = [float(drop_rate_search_range_low), float(drop_rate_search_range_high)] 267 | if drop_range[0] > drop_range[1]: 268 | drop_range[0], drop_range[1] = drop_range[1], drop_range[0] 269 | aug_range = [float(aug_rate_search_range_low), float(aug_rate_search_range_high)] 270 | if aug_range[0] > aug_range[1]: 271 | aug_range[0], aug_range[1] = aug_range[1], aug_range[0] 272 | batch_range = [int(batch_size_search_range_low), int(batch_size_search_range_high)] 273 | if batch_range[0] > batch_range[1]: 274 | batch_range[0], batch_range[1] = batch_range[1], batch_range[0] 275 | 276 | search_key_ranges = {lr_name: lr_range, droprate_name: drop_range, aug_rate_name: aug_range, batch_size_name: batch_range} 277 | 278 | best_config, expected_best_step_num, _, _ = gp_update_hyperparameter_optimization(eval_fn=eval_fn, hyperparams=params, search_key_ranges=search_key_ranges, n=n, 279 | save_results_to=save_results_to, m=m, metric_should_increase=metric_should_increase, metric_name=metric_name) 280 | 281 | return float(best_config[lr_name]), int(expected_best_step_num) 282 | -------------------------------------------------------------------------------- /meta_learners/metaseg.py: -------------------------------------------------------------------------------- 1 | """ 2 | APIs for loading the meta-learning segmentation datasets. 3 | """ 4 | 5 | import os 6 | import glob 7 | import random 8 | import warnings 9 | from typing import List, Tuple, Union, Optional 10 | 11 | import numpy as np 12 | import tensorflow as tf 13 | 14 | from augmenters.np_augmenters import Augmenter 15 | from data import input_fn 16 | from data.fss_1000_utils import split_train_test_tasks, get_fss_tasks, TEST_TASK_IDS 17 | from utils.viz import savefig_mask_on_image 18 | from utils.util import count_examples_in_tfrecords, hash_np_array 19 | 20 | DEFAULT_NUM_TEST_EXAMPLES = 5 21 | DEFAULT_K_SHOT_SET = [{"airliner", "aeroplane"}, {"bus"}, {"motorbike"}, {"potted_plant", "potted plant"}, {"television", "tvmonitor"}] 22 | 23 | 24 | def read_fss_1000_dataset(data_dir: str, 25 | num_val_tasks: int = 0, 26 | num_test_tasks: int = 240, 27 | test_task_ids: Optional[List[str]] = TEST_TASK_IDS, 28 | image_size: Optional[int] = 224 29 | ) -> Tuple[List["BinarySegmentationTask"], List["BinarySegmentationTask"], List["BinarySegmentationTask"], 30 | List[str], List[str], List[str]]: 31 | """ 32 | Reads in the FSS-1000 meta-learning image segmentation dataset. Assumes each task is in a shard. 33 | Args: 34 | data_dir: a directory containing tfrecords files for each semantic class. 35 | 36 | Returns: 37 | Tuple of (train_tasks, val_tasks, test_tasks, train_task_names, val_task_names, test_task_names). 38 | First three objects are instances of BinarySegmentationTask. 
39 | """ 40 | verbose = False 41 | 42 | all_tasks = get_fss_tasks(data_dir) 43 | 44 | if test_task_ids is None: 45 | train_shards, test_shards = split_train_test_tasks(all_tasks, num_test_tasks) 46 | else: 47 | train_shards, test_shards = [], [] 48 | for task in all_tasks: 49 | comparer = os.path.basename(task).replace(".tfrecord.gzip", "") 50 | if comparer in test_task_ids: 51 | test_shards.append(task) 52 | else: 53 | train_shards.append(task) 54 | assert all([os.path.basename(x).replace(".tfrecord.gzip", "") in test_task_ids for x in test_shards]), "Test shard not in test_task_ids" 55 | assert all([not os.path.basename(x).replace(".tfrecord.gzip", "") in test_task_ids for x in train_shards]), "Test set task found in train shards" 56 | 57 | train_shards, val_shards = split_train_test_tasks(train_shards, num_val_tasks, reproducbile_splits=True) 58 | 59 | print("{} training tasks, {} val tasks, {} test tasks.".format(len(train_shards), len(val_shards), len(test_shards))) 60 | 61 | train_tasks, val_tasks, test_tasks = [], [], [] 62 | iterator = None 63 | 64 | print("Building FSS-1000 training task samplers...") 65 | train_task_names = [] 66 | for task in train_shards: 67 | task_name = os.path.basename(task) 68 | train_task_names.append(task_name) 69 | batch_size = count_examples_in_tfrecords([task]) 70 | if verbose: 71 | print("{} examples in task {}".format(batch_size, task_name)) 72 | few_shot_seg_task = BinarySegmentationTask( 73 | iterator=iterator, 74 | tfrecord_paths=task, 75 | batch_size=batch_size, 76 | name=task_name, 77 | image_size=image_size, 78 | verbose=False) 79 | # Initialize all tasks using the same iterator. 80 | if not iterator: 81 | print("making new iterator in read_dataset for task: {}".format(task_name)) 82 | iterator = few_shot_seg_task.iterator 83 | train_tasks.append(few_shot_seg_task) 84 | 85 | print("Building FSS-1000 val task samplers...") 86 | val_task_names = [] 87 | for task in val_shards: 88 | task_name = os.path.basename(task) 89 | val_task_names.append(task_name) 90 | batch_size = count_examples_in_tfrecords([task]) 91 | if verbose: 92 | print("Meta-val task: {}".format(task_name)) 93 | print("{} examples in task {}".format(batch_size, task_name)) 94 | few_shot_seg_task = BinarySegmentationTask( 95 | iterator=iterator, 96 | tfrecord_paths=task, 97 | batch_size=batch_size, 98 | name=task_name, 99 | image_size=image_size, 100 | verbose=False) 101 | val_tasks.append(few_shot_seg_task) 102 | 103 | print("Building FSS-1000 test task samplers...") 104 | test_task_names = [] 105 | for task in test_shards: 106 | task_name = os.path.basename(task) 107 | test_task_names.append(task_name) 108 | batch_size = count_examples_in_tfrecords([task]) 109 | if verbose: 110 | print("Meta-test task: {}".format(task_name)) 111 | print("{} examples in task {}".format(batch_size, task_name)) 112 | few_shot_seg_task = BinarySegmentationTask( 113 | iterator=iterator, 114 | tfrecord_paths=task, 115 | batch_size=batch_size, 116 | name=task_name, 117 | image_size=image_size, 118 | verbose=False) 119 | test_tasks.append(few_shot_seg_task) 120 | 121 | return train_tasks, val_tasks, test_tasks, train_task_names, val_task_names, test_task_names 122 | 123 | 124 | def read_fp_k_shot_dataset(data_dir: str, 125 | all_task_names = DEFAULT_K_SHOT_SET, 126 | image_size: Optional[int] = 224 127 | ) -> Tuple[List["BinarySegmentationTask"], List[str]]: 128 | """ 129 | Reads in the FP-k-shot meta-learning image segmentation dataset, which contains FSS-1000 and PASCAL-5^i classes. 
130 | Args: 131 | data_dir: a directory containing tfrecords files for each semantic class. 132 | all_task_names: list of set of synonyms defining the tasks. 133 | Returns: 134 | BinarySegmentationTask objects for the tasks in `tasks`. 135 | """ 136 | verbose = True 137 | 138 | all_tasks = get_fss_tasks(data_dir) 139 | 140 | print("{} tasks found.".format(len(all_tasks))) 141 | 142 | test_tasks = [] 143 | iterator = None 144 | 145 | print("Building k-shot-FSS-1000 test task samplers...") 146 | test_task_names = [] 147 | for synonyms in all_task_names: 148 | task_shards = [] 149 | task_globs = [] 150 | for i, synonym in enumerate(synonyms): 151 | synonym = synonym.replace(" ", "") 152 | if i == 0: 153 | task_name = synonym 154 | print("Processing task: {}".format(task_name)) 155 | syn_shards = [x for x in all_tasks if synonym in os.path.basename(x)] 156 | task_shards.extend(syn_shards) 157 | task_globs.append(os.path.join(data_dir, "{}*.tfrecord*".format(synonym))) 158 | 159 | print("task shards: {}".format(task_shards)) 160 | 161 | test_task_names.append(task_name) 162 | batch_size = count_examples_in_tfrecords(task_shards) 163 | if verbose: 164 | print("{} examples in task {}".format(batch_size, task_name)) 165 | few_shot_seg_task = BinarySegmentationTask( 166 | iterator=iterator, 167 | tfrecord_paths=task_globs, 168 | batch_size=batch_size, 169 | name=task_name, 170 | image_size=image_size, 171 | verbose=False) 172 | # Initialize all tasks using the same iterator. 173 | if not iterator: 174 | print("making new iterator in read_dataset for task: {}".format(task_name)) 175 | iterator = few_shot_seg_task.iterator 176 | test_tasks.append(few_shot_seg_task) 177 | 178 | return test_tasks, test_task_names 179 | 180 | 181 | class BinarySegmentationTask: 182 | """ 183 | Segmentation maps for binary segmentations. 184 | Label dimensions are [n_row, n_col, 2] (one-hot encoding). 185 | """ 186 | def __init__(self, 187 | tfrecord_paths, 188 | iterator=None, 189 | batch_size=32, 190 | seed=None, 191 | name: str = None, 192 | image_size: Optional[int] = None, 193 | verbose: bool = False): 194 | self.tfrecord_paths = tfrecord_paths 195 | self.batch_size = batch_size 196 | 197 | 198 | if image_size is not None: 199 | dataset = input_fn.make_dataset(self.tfrecord_paths, self.batch_size, image_width=image_size) 200 | else: 201 | dataset = input_fn.make_dataset(self.tfrecord_paths, self.batch_size) 202 | if not iterator: 203 | print("making new iterator in BinarySegmentationTask.__init__ for task: {}".format(name)) 204 | iterator = tf.data.Iterator.from_structure(dataset.output_types, 205 | dataset.output_shapes) 206 | self.iterator = iterator 207 | self._initialization_op = self.iterator.make_initializer(dataset) 208 | self._next_element = self.iterator.get_next() 209 | 210 | self.name = name 211 | if verbose: 212 | print("BinarySegmentationTask for data {} will return batches of size {}".format(self.name, self.batch_size)) 213 | 214 | def sample(self, sess, num_images, verbose=False) -> List[List[np.array]]: 215 | """ 216 | Sample tuple of (image, label) from tfrecords 217 | Args: 218 | sess: tf session 219 | Returns: 220 | A sequence of (image, label) pairs 221 | """ 222 | if num_images > self.batch_size: 223 | raise ValueError("Tried to sample {} examples.Cannot sample more than {} examples that generator was initialized with.".format(num_images, self.batch_size)) 224 | 225 | # Reinitialize iterator with this task's dataset, then fetch one batch. 
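# All tasks can share a single tf.data.Iterator (see __init__): running this
# task's initialization op repoints the shared iterator at this task's
# dataset before the batch is fetched.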
226 |         sess.run(self._initialization_op)
227 |         # Fetch a batch of size self.batch_size of images and corresponding masks:
228 |         images, masks = sess.run(self._next_element)
229 | 
230 |         return [[image, mask] for image, mask in zip(images[:num_images], masks[:num_images])]
231 | 
232 | 
233 | def _sample_mini_image_segmentation_dataset(sess, dataset, num_classes, num_shots, return_task_name: bool = False) -> Union[list, Tuple[list, str]]:
234 |     """
235 |     Samples a binary image segmentation task from a dataset, returning num_shots examples.
236 |     `num_classes` is currently ignored.
237 | 
238 |     Returns:
239 |         An iterable of (input, label) tuples of length num_shots.
240 |     """
241 |     tasks = list(dataset)
242 | 
243 |     # Sample a random task:
244 |     class_obj = random.sample(tasks, 1)[0]
245 | 
246 |     if num_shots > class_obj.batch_size:  # Account for fewer examples being available than num_shots.
247 |         warnings.warn("Requested {} examples but the dataset can return a max of {} examples.".format(num_shots, class_obj.batch_size))
248 |         num_shots = class_obj.batch_size
249 | 
250 |     if not return_task_name:
251 |         return class_obj.sample(sess, num_shots)
252 |     else:
253 |         return class_obj.sample(sess, num_shots), class_obj.name
254 | 
255 | 
256 | def _mini_batches(samples, batch_size, num_batches, replacement: bool = False, augmenter: Optional[Augmenter] = None, aug_rate: Optional[float] = None):
257 |     """
258 |     Generate mini-batches from some data.
259 |     Args:
260 |         replacement: bool. If False, loop through all examples before sampling an example again.
261 |         augmenter: optional Augmenter applied to each (image, mask) pair.
262 |         aug_rate: probability that an augmentation is applied to a given pair.
263 |     Returns:
264 |         An iterable of sequences of (input, label) pairs,
265 |         where each sequence is a mini-batch.
266 |     """
267 |     if aug_rate is not None:
268 |         prob_to_return_original = 1.0 - aug_rate
269 |     else:
270 |         prob_to_return_original = None
271 |     samples = list(samples)
272 |     if len(samples) == 0:
273 |         raise ValueError('Nothing to sample: `samples` is empty: {}'.format(samples))
274 |     if replacement:
275 |         for _ in range(num_batches):
276 |             cur_batch = random.sample(samples, batch_size)
277 |             if augmenter is not None:
278 |                 _cur_batch = []
279 |                 for sample in cur_batch:
280 |                     sample = augmenter.apply_augmentations(sample[0], sample[1], prob_to_return_original)
281 |                     _cur_batch.append(sample)
282 |                 cur_batch = _cur_batch
283 |             yield cur_batch
284 |         return
285 |     cur_batch = []
286 |     batch_count = 0
287 |     while True:
288 |         random.shuffle(samples)
289 |         for sample in samples:
290 |             if augmenter is not None:
291 |                 sample = augmenter.apply_augmentations(sample[0], sample[1], prob_to_return_original)
292 |             cur_batch.append(sample)
293 |             if len(cur_batch) < batch_size:
294 |                 continue
295 |             yield cur_batch
296 |             cur_batch = []
297 |             batch_count += 1
298 |             if batch_count == num_batches:
299 |                 return
300 | 
301 | 
302 | def assert_train_test_split(train_set, test_set):
303 |     """Assert that no training image appears in the test set (compared by array hash)."""
304 |     train_hashes = set()
305 |     for image, _ in train_set:
306 |         train_hashes.add(hash_np_array(image))
307 |     for image, _ in test_set:
308 |         assert hash_np_array(image) not in train_hashes
309 | 
310 | 
311 | def _sample_train_test_segmentation_with_replacement(samples: List, train_shots: int = 5, test_shots: int = 5):
312 |     indices = np.random.randint(len(samples), size=train_shots)
313 |     train_set = [samples[x] for x in indices]
314 |     indices = np.random.randint(len(samples), size=test_shots)
315 |     test_set = [samples[x] for x in indices]
316 |     return train_set, test_set
317 | 
318 | 
319 | def _split_train_test_segmentation(samples, test_shots=1, test_train_test_split: bool = False, shuffle_before_split: bool = True):
320 |     """
321 |     Split a few-shot task into a train and a test set.
322 | 
323 |     Args:
324 |         samples: an iterable of (input, label) pairs. Should already be shuffled.
325 |         test_shots: the number of examples held out for the test set.
326 |         test_train_test_split: if True, assert that the two splits are disjoint.
327 | 
328 |     Returns:
329 |         A tuple (train, test), where train and test are
330 |         sequences of (input, label) pairs.
331 |     """
332 |     samples = list(samples)
333 | 
334 |     if shuffle_before_split:
335 |         random.shuffle(samples)
336 | 
337 |     train_set = samples[:-test_shots]
338 |     test_set = samples[-test_shots:]
339 |     if test_train_test_split:
340 |         assert_train_test_split(train_set, test_set)
341 |     return train_set, test_set
342 | 
--------------------------------------------------------------------------------
/meta_learners/supervised_reptile/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2017 OpenAI
4 | Modified Work Copyright (c) 2020 Sean M. Hendryx
5 | 
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 | 
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /meta_learners/supervised_reptile/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reptile for supervised meta-learning. 3 | """ -------------------------------------------------------------------------------- /meta_learners/supervised_reptile/supervised_reptile/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reptile for supervised meta-learning. 3 | """ 4 | -------------------------------------------------------------------------------- /meta_learners/supervised_reptile/supervised_reptile/eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for evaluating models. 3 | """ 4 | import itertools 5 | import os 6 | from typing import Optional, List, Dict, Tuple 7 | 8 | import numpy as np 9 | import pandas as pd 10 | 11 | from meta_learners.hyperparam_search import LEARNING_RATE_NAME, \ 12 | DROPOUT_RATE_NAME, lr_droprate_aug_rate_batch_size_gp_search, AUG_RATE_NAME 13 | from utils.util import ci95 14 | from .reptile import Gecko, DEFAULT_ITER_RANGE 15 | from meta_learners.variables import weight_decay 16 | 17 | 18 | def evaluate_gecko(sess, 19 | model, 20 | dataset, 21 | num_classes=1, 22 | num_shots=5, 23 | eval_inner_batch_size=5, 24 | eval_inner_iters=50, 25 | replacement=False, 26 | num_samples=100, 27 | transductive=False, 28 | weight_decay_rate=1, 29 | meta_fn=Gecko, 30 | visualize_predicted_segmentations=True, 31 | save_fine_tuned_checkpoints=False, 32 | save_fine_tuned_checkpoints_dir: Optional[str] = None, 33 | lr_scheduler=None, 34 | lr=None, 35 | augment=False, 36 | serially_eval_all_tasks: bool = False, 37 | aug_rate: Optional[float] = None, 38 | ) -> Tuple[float, Dict[str, List[float]]]: 39 | """ 40 | Evaluates an image segmentation model on a dataset. 41 | """ 42 | print("Evaluating with eval_inner_iters: {}".format(eval_inner_iters)) 43 | print("Evaluating with lr: {}".format(lr)) 44 | 45 | if save_fine_tuned_checkpoints: 46 | print("Saving fine-tuned checkpoints to {}".format(save_fine_tuned_checkpoints_dir)) 47 | if weight_decay_rate != 1: 48 | pre_step_op = weight_decay(weight_decay_rate) 49 | else: 50 | pre_step_op = None # no need to just multiply all vars by 1. 
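    # `weight_decay(rate)` (see meta_learners/variables.py) returns a grouped assign
    # op that rescales every trainable variable in place, w <- rate * w, and is
    # intended to run before each update step. As a rough worked example,
    # rate=0.9999 over 50 adaptation steps shrinks a weight by 0.9999 ** 50 ~= 0.995.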
51 |     gecko = meta_fn(sess,
52 |                     transductive=transductive,
53 |                     pre_step_op=pre_step_op,
54 |                     lr_scheduler=lr_scheduler,
55 |                     augment=augment,
56 |                     aug_rate=aug_rate)
57 | 
58 |     mean_ious = []
59 |     task_iou_map = {}
60 |     for i in range(num_samples):
61 |         mean_iou, task_iou_map_i = gecko.evaluate(dataset, model.input_ph, model.label_ph,
62 |                                                   model.minimize_op, model.predictions,
63 |                                                   num_classes=num_classes, num_shots=num_shots,
64 |                                                   inner_batch_size=eval_inner_batch_size,
65 |                                                   inner_iters=eval_inner_iters, replacement=replacement,
66 |                                                   eval_all_tasks=serially_eval_all_tasks,
67 |                                                   save_fine_tuned_checkpoints=save_fine_tuned_checkpoints,
68 |                                                   save_fine_tuned_checkpoints_dir=save_fine_tuned_checkpoints_dir,
69 |                                                   eval_sample_num=i, is_training_ph=model.is_training_ph, lr_ph=model.lr_ph, lr=lr)
70 |         for key, val in task_iou_map_i.items():
71 |             try:
72 |                 task_iou_map[key].append(val)
73 |             except KeyError:
74 |                 task_iou_map[key] = [val]
75 |         mean_ious.append(mean_iou)
76 | 
77 |     all_ious = list(itertools.chain(*task_iou_map.values()))
78 |     ninety_five_perc_ci = ci95(all_ious)
79 | 
80 |     mean_of_all_task_splits = np.nanmean(all_ious)
81 |     print("Mean of all {} task-splits: {} +/- 95% CI: {}".format(len(all_ious), mean_of_all_task_splits, ninety_five_perc_ci))
82 | 
83 |     print("{} NaN values out of total number of samples: {}".format(np.count_nonzero(np.isnan(mean_ious)), num_samples))
84 |     mean_iou = np.nanmean(mean_ious)
85 |     print("Mean of samples:")
86 |     print("{} mean IoU, +/- 95% CI: {}".format(mean_iou, ninety_five_perc_ci))
87 |     print("Evaluated with eval_inner_iters: {}".format(eval_inner_iters))
88 |     print("Evaluated with lr: {}".format(lr))
89 | 
90 |     return mean_iou, task_iou_map
91 | 
92 | 
93 | def optimize_update_hyperparams(sess,
94 |                                 model,
95 |                                 dataset,
96 |                                 num_classes=1,
97 |                                 num_shots=5,
98 |                                 eval_inner_batch_size=5,
99 |                                 eval_inner_iters=5,
100 |                                 replacement=False,
101 |                                 num_samples=100,
102 |                                 transductive=False,
103 |                                 weight_decay_rate=1,
104 |                                 meta_fn=Gecko,
105 |                                 save_fine_tuned_checkpoints=False,
106 |                                 save_fine_tuned_checkpoints_dir: Optional[str] = None,
107 |                                 lr_scheduler=None,
108 |                                 lr=None,
109 |                                 lr_search_range_low: float = 0.0005,
110 |                                 lr_search_range_high: float = 0.05,
111 |                                 drop_rate=None,
112 |                                 drop_rate_search_range_low: float = 0.1,
113 |                                 drop_rate_search_range_high: float = 0.8,
114 |                                 aug_rate: float = 0.5,
115 |                                 aug_rate_search_range_low: float = 0.5,
116 |                                 aug_rate_search_range_high: float = 0.5,
117 |                                 batch_size_search_range_low: int = 8,
118 |                                 batch_size_search_range_high: int = 8,
119 |                                 augment=False,
120 |                                 serially_eval_all_tasks: bool = True,
121 |                                 min_steps: int = 0,
122 |                                 max_steps: int = 80,
123 |                                 num_configs_to_sample=100,
124 |                                 num_train_val_data_splits_to_sample_per_config=1,
125 |                                 save_dir: Optional[str] = None,  # Dir in which to save results csv.
126 |                                 results_csv_name: str = "GP_val-set_hyper_param_search_results.csv",
127 |                                 eval_tasks_with_median_early_stopping_iterations: bool = False,
128 |                                 estimator: str = "GP",
129 |                                 ):
130 |     """
131 |     Searches over update-routine hyperparameters (learning rate, final-layer dropout rate, augmentation rate, and batch size) for the configuration that maximizes validation mean IoU.
132 |     """
133 |     supported_estimators = {"GP"}
134 |     assert estimator in supported_estimators
135 | 
136 |     if save_fine_tuned_checkpoints:
137 |         print("Saving fine-tuned checkpoints to {}".format(save_fine_tuned_checkpoints_dir))
138 |     if weight_decay_rate != 1:
139 |         pre_step_op = weight_decay(weight_decay_rate)
140 |     else:
141 |         pre_step_op = None  # No need to multiply all vars by 1.
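    # The GP search below treats validation mean IoU as a black-box objective over
    # (learning rate, dropout rate, augmentation rate, batch size). A rough sketch
    # of such a loop, with hypothetical `suggest`/`observe` names standing in for
    # the internals of lr_droprate_aug_rate_batch_size_gp_search:
    #
    #   for _ in range(num_configs_to_sample):
    #       config = surrogate.suggest()          # propose a config via acquisition
    #       score = eval_fn(**params, **config)   # adapt on train split, score on val
    #       surrogate.observe(config, score)      # refit the GP on the new evidence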
142 |     gecko = meta_fn(sess,
143 |                     transductive=transductive,
144 |                     pre_step_op=pre_step_op,
145 |                     lr_scheduler=lr_scheduler,
146 |                     augment=augment)
147 | 
148 |     params = {"dataset": dataset, "input_ph": model.input_ph, "label_ph": model.label_ph,
149 |               "minimize_op": model.minimize_op, "predictions": model.predictions, "num_classes": num_classes,
150 |               "num_shots": num_shots, "inner_batch_size": eval_inner_batch_size,
151 |               "replacement": replacement, "eval_all_tasks": serially_eval_all_tasks, "is_training_ph": model.is_training_ph,
152 |               "lr_ph": model.lr_ph, LEARNING_RATE_NAME: lr, "drop_rate_ph": model.final_layer_dropout_rate_ph, DROPOUT_RATE_NAME: drop_rate, AUG_RATE_NAME: aug_rate,
153 |               "eval_tasks_with_median_early_stopping_iterations": eval_tasks_with_median_early_stopping_iterations, "min_steps": min_steps, "max_steps": max_steps,
154 |               }
155 | 
156 |     eval_fn = gecko.evaluate_with_early_stopping
157 | 
158 |     if eval_tasks_with_median_early_stopping_iterations:
159 |         print("Evaluating val-set tasks with the median number of iterations returned by early stopping.")
160 | 
161 |     before_ext, ext = os.path.splitext(results_csv_name)
162 |     before_ext += "_{}-shot".format(num_shots)
163 |     results_csv_name = before_ext + ext
164 |     if save_dir is not None:
165 |         save_results_to = os.path.join(save_dir, results_csv_name)
166 |     else:
167 |         save_results_to = results_csv_name
168 | 
169 |     if estimator == "GP":
170 |         best_lr, expected_best_step_num = lr_droprate_aug_rate_batch_size_gp_search(eval_fn, params,
171 |                                                                                     lr_search_range_low=lr_search_range_low,
172 |                                                                                     lr_search_range_high=lr_search_range_high,
173 |                                                                                     drop_rate_search_range_low=drop_rate_search_range_low,
174 |                                                                                     drop_rate_search_range_high=drop_rate_search_range_high,
175 |                                                                                     aug_rate_search_range_low=aug_rate_search_range_low,
176 |                                                                                     aug_rate_search_range_high=aug_rate_search_range_high,
177 |                                                                                     batch_size_search_range_low=batch_size_search_range_low,
178 |                                                                                     batch_size_search_range_high=batch_size_search_range_high,
179 |                                                                                     n=num_configs_to_sample,
180 |                                                                                     m=num_train_val_data_splits_to_sample_per_config,
181 |                                                                                     save_results_to=save_results_to)
182 |     else:
183 |         raise ValueError("Unsupported hyperparameter optimizer estimator {}. `estimator` must be in {}".format(estimator, supported_estimators))
184 | 
185 |     return best_lr, expected_best_step_num
186 | 
187 | 
188 | DEFAULT_K_RANGE = [1, 5, 10, 50, 100, 200, 400]
189 | 
190 | def run_k_shot_learning_curves_experiment(sess,
191 |                                           model,
192 |                                           dataset,
193 |                                           num_classes=1,
194 |                                           num_shots=5,
195 |                                           eval_inner_batch_size=8,
196 |                                           eval_inner_iters=5,
197 |                                           replacement=False,
198 |                                           num_samples=100,
199 |                                           transductive=True,
200 |                                           weight_decay_rate=1,
201 |                                           meta_fn=Gecko,
202 |                                           lr_scheduler=None,
203 |                                           lr=None,
204 |                                           augment=True,
205 |                                           aug_rate: float = 0.5,
206 |                                           csv_outpath="k-shot-results.csv",
207 |                                           iter_range=DEFAULT_ITER_RANGE,
208 |                                           ):
209 |     print("Running k-shot learning curves experiment over k values {} and tasks {}".format(DEFAULT_K_RANGE, [x.name for x in dataset]))
210 |     if iter_range is None:
211 |         iter_range = DEFAULT_ITER_RANGE
212 |     print("Using iter range {}".format(iter_range))
213 | 
214 |     gecko = meta_fn(sess,
215 |                     transductive=transductive,
216 |                     pre_step_op=weight_decay(weight_decay_rate),
217 |                     lr_scheduler=lr_scheduler,
218 |                     augment=augment,
219 |                     aug_rate=aug_rate)
220 | 
221 |     ks, results = gecko.evaluate_m_k_shot_ranges_all_tasks(tasks=dataset, k_range=DEFAULT_K_RANGE, m=num_samples, input_ph=model.input_ph,
222 |                                                            label_ph=model.label_ph, minimize_op=model.minimize_op, predictions=model.predictions,
223 |                                                            inner_batch_size=eval_inner_batch_size,
224 |                                                            inner_iters=eval_inner_iters, replacement=replacement, is_training_ph=model.is_training_ph,
225 |                                                            lr_ph=model.lr_ph, lr=lr, test_samples=20, iter_range=iter_range, aug_rate=aug_rate)
226 | 
227 |     print("k-shot learning curve results:")
228 |     print("ks:")
229 |     print(ks)
230 |     print("IoUs:")
231 |     print(results)
232 | 
233 |     if csv_outpath is not None:
234 |         df = pd.DataFrame({"k": ks, "mIoU": results})
235 |         if not os.path.isfile(csv_outpath):
236 |             df.to_csv(csv_outpath, index=False)
237 |         else:
238 |             # Append without rewriting the header so results accumulate across runs.
239 |             df.to_csv(csv_outpath, mode="a", header=False, index=False)
240 | 
241 |     return ks, results
242 | 
--------------------------------------------------------------------------------
/meta_learners/supervised_reptile/supervised_reptile/miniimagenet.py:
--------------------------------------------------------------------------------
1 | """
2 | Loading and using the Mini-ImageNet dataset.
3 | 
4 | To use these APIs, you should prepare a directory that
5 | contains three sub-directories: train, test, and val.
6 | Each of these three directories should contain one
7 | sub-directory per WordNet ID.
8 | """
9 | 
10 | import os
11 | import random
12 | 
13 | from PIL import Image
14 | import numpy as np
15 | 
16 | def read_dataset(data_dir):
17 |     """
18 |     Read the Mini-ImageNet dataset.
19 | 
20 |     Args:
21 |       data_dir: directory containing Mini-ImageNet.
22 | 
23 |     Returns:
24 |       A tuple (train, val, test) of sequences of
25 |       ImageNetClass instances.
26 |     """
27 |     return tuple(_read_classes(os.path.join(data_dir, x)) for x in ['train', 'val', 'test'])
28 | 
29 | def _read_classes(dir_path):
30 |     """
31 |     Read the WNID directories in a directory.
32 |     """
33 |     return [ImageNetClass(os.path.join(dir_path, f)) for f in os.listdir(dir_path)
34 |             if f.startswith('n')]
35 | 
36 | # pylint: disable=R0903
37 | class ImageNetClass:
38 |     """
39 |     A single image class.
40 |     """
41 |     def __init__(self, dir_path):
42 |         self.dir_path = dir_path
43 |         self._cache = {}
44 | 
45 |     def sample(self, num_images):
46 |         """
47 |         Sample images (as numpy arrays) from the class.
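        File names are shuffled before selection, so repeated calls return
        different random subsets; decoded images are cached in memory.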
48 | 
49 |         Returns:
50 |           A sequence of 84x84x3 numpy arrays.
51 |           Each pixel ranges from 0 to 1.
52 |         """
53 |         names = [f for f in os.listdir(self.dir_path) if f.endswith('.JPEG')]
54 |         random.shuffle(names)
55 |         images = []
56 |         for name in names[:num_images]:
57 |             images.append(self._read_image(name))
58 |         return images
59 | 
60 |     def _read_image(self, name):
61 |         # Serve from the cache when possible, normalizing uint8 values to [0, 1].
62 |         if name in self._cache:
63 |             return self._cache[name].astype('float32') / 0xff
64 |         with open(os.path.join(self.dir_path, name), 'rb') as in_file:
65 |             img = Image.open(in_file).resize((84, 84)).convert('RGB')
66 |             self._cache[name] = np.array(img)
67 |         return self._read_image(name)
--------------------------------------------------------------------------------
/meta_learners/supervised_reptile/supervised_reptile/models.py:
--------------------------------------------------------------------------------
1 | """
2 | Image classification models for supervised meta-learning.
3 | """
4 | 
5 | from functools import partial
6 | 
7 | import numpy as np
8 | import tensorflow as tf
9 | 
10 | DEFAULT_OPTIMIZER = partial(tf.train.AdamOptimizer, beta1=0)
11 | 
12 | # pylint: disable=R0903
13 | class OmniglotModel:
14 |     """
15 |     A model for Omniglot classification.
16 |     """
17 |     def __init__(self, num_classes, optimizer=DEFAULT_OPTIMIZER, **optim_kwargs):
18 |         self.input_ph = tf.placeholder(tf.float32, shape=(None, 28, 28))
19 |         out = tf.reshape(self.input_ph, (-1, 28, 28, 1))
20 |         for _ in range(4):
21 |             out = tf.layers.conv2d(out, 64, 3, strides=2, padding='same')
22 |             out = tf.layers.batch_normalization(out, training=True)
23 |             out = tf.nn.relu(out)
24 |         out = tf.reshape(out, (-1, int(np.prod(out.get_shape()[1:]))))
25 |         self.logits = tf.layers.dense(out, num_classes)
26 |         self.label_ph = tf.placeholder(tf.int32, shape=(None,))
27 |         self.loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.label_ph,
28 |                                                                    logits=self.logits)
29 |         self.predictions = tf.argmax(self.logits, axis=-1)
30 |         self.minimize_op = optimizer(**optim_kwargs).minimize(self.loss)
31 | 
32 | # pylint: disable=R0903
33 | class MiniImageNetModel:
34 |     """
35 |     A model for Mini-ImageNet classification.
36 |     """
37 |     def __init__(self, num_classes, optimizer=DEFAULT_OPTIMIZER, **optim_kwargs):
38 |         self.input_ph = tf.placeholder(tf.float32, shape=(None, 84, 84, 3))
39 |         out = self.input_ph
40 |         for _ in range(4):
41 |             out = tf.layers.conv2d(out, 32, 3, padding='same')
42 |             out = tf.layers.batch_normalization(out, training=True)
43 |             out = tf.layers.max_pooling2d(out, 2, 2, padding='same')
44 |             out = tf.nn.relu(out)
45 |         out = tf.reshape(out, (-1, int(np.prod(out.get_shape()[1:]))))
46 |         self.logits = tf.layers.dense(out, num_classes)
47 |         self.label_ph = tf.placeholder(tf.int32, shape=(None,))
48 |         self.loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.label_ph,
49 |                                                                    logits=self.logits)
50 |         self.predictions = tf.argmax(self.logits, axis=-1)
51 |         self.minimize_op = optimizer(**optim_kwargs).minimize(self.loss)
52 | 
--------------------------------------------------------------------------------
/meta_learners/supervised_reptile/supervised_reptile/omniglot.py:
--------------------------------------------------------------------------------
1 | """
2 | Loading and augmenting the Omniglot dataset.
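"Augmenting" here means adding 90-degree rotations of each character class
(see augment_dataset below), which quadruples the number of classes
available for meta-training.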
3 | 4 | To use these APIs, you should prepare a directory that 5 | contains all of the alphabets from both images_background 6 | and images_evaluation. 7 | """ 8 | 9 | import os 10 | import random 11 | 12 | from PIL import Image 13 | import numpy as np 14 | 15 | def read_dataset(data_dir): 16 | """ 17 | Iterate over the characters in a data directory. 18 | 19 | Args: 20 | data_dir: a directory of alphabet directories. 21 | 22 | Returns: 23 | An iterable over Characters. 24 | 25 | The dataset is unaugmented and not split up into 26 | training and test sets. 27 | """ 28 | for alphabet_name in sorted(os.listdir(data_dir)): 29 | alphabet_dir = os.path.join(data_dir, alphabet_name) 30 | if not os.path.isdir(alphabet_dir): 31 | continue 32 | for char_name in sorted(os.listdir(alphabet_dir)): 33 | if not char_name.startswith('character'): 34 | continue 35 | yield Character(os.path.join(alphabet_dir, char_name), 0) 36 | 37 | def split_dataset(dataset, num_train=1200): 38 | """ 39 | Split the dataset into a training and test set. 40 | 41 | Args: 42 | dataset: an iterable of Characters. 43 | 44 | Returns: 45 | A tuple (train, test) of Character sequences. 46 | """ 47 | all_data = list(dataset) 48 | random.shuffle(all_data) 49 | return all_data[:num_train], all_data[num_train:] 50 | 51 | def augment_dataset(dataset): 52 | """ 53 | Augment the dataset by adding 90 degree rotations. 54 | 55 | Args: 56 | dataset: an iterable of Characters. 57 | 58 | Returns: 59 | An iterable of augmented Characters. 60 | """ 61 | for character in dataset: 62 | for rotation in [0, 90, 180, 270]: 63 | yield Character(character.dir_path, rotation=rotation) 64 | 65 | # pylint: disable=R0903 66 | class Character: 67 | """ 68 | A single character class. 69 | """ 70 | def __init__(self, dir_path, rotation=0): 71 | self.dir_path = dir_path 72 | self.rotation = rotation 73 | self._cache = {} 74 | 75 | def sample(self, num_images): 76 | """ 77 | Sample images (as numpy arrays) from the class. 78 | 79 | Returns: 80 | A sequence of 28x28 numpy arrays. 81 | Each pixel ranges from 0 to 1. 82 | """ 83 | names = [f for f in os.listdir(self.dir_path) if f.endswith('.png')] 84 | random.shuffle(names) 85 | images = [] 86 | for name in names[:num_images]: 87 | images.append(self._read_image(os.path.join(self.dir_path, name))) 88 | return images 89 | 90 | def _read_image(self, path): 91 | if path in self._cache: 92 | return self._cache[path] 93 | with open(path, 'rb') as in_file: 94 | img = Image.open(in_file).resize((28, 28)).rotate(self.rotation) 95 | self._cache[path] = np.array(img).astype('float32') 96 | return self._cache[path] 97 | -------------------------------------------------------------------------------- /meta_learners/supervised_reptile/supervised_reptile/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training helpers for supervised meta-learning. 
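The outer loop linearly anneals the meta step size from `meta_step_size` to
`meta_step_size_final` over `meta_iters`, i.e.
cur_meta_step_size = frac_done * meta_step_size_final + (1 - frac_done) * meta_step_size.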
3 | """ 4 | 5 | import os 6 | import time 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | from typing import Optional 11 | 12 | from utils.util import log_estimated_time_remaining 13 | from .reptile import Gecko 14 | from meta_learners.variables import weight_decay 15 | 16 | 17 | # pylint: disable=R0913,R0914 18 | def train_gecko(sess, 19 | model, 20 | train_set, 21 | test_set, 22 | save_dir, 23 | num_classes=5, 24 | num_shots=5, 25 | inner_batch_size=5, 26 | inner_iters=20, 27 | replacement=False, 28 | meta_step_size=0.1, 29 | meta_step_size_final=0.1, 30 | meta_batch_size=1, 31 | meta_iters=10000, 32 | eval_inner_batch_size=5, 33 | eval_inner_iters=50, 34 | eval_interval=10, 35 | weight_decay_rate=1, 36 | time_deadline=None, 37 | train_shots=None, 38 | transductive=False, 39 | meta_fn=Gecko, 40 | log_fn=print, 41 | save_checkpoint_every_n_meta_iters=100, 42 | max_checkpoints_to_keep=2, 43 | augment=False, 44 | lr_scheduler=None, 45 | lr=None, 46 | save_best_seen=False, 47 | num_tasks_to_eval=100, 48 | aug_rate: Optional[float]=None): 49 | """ 50 | Train a model on a dataset. 51 | """ 52 | if not os.path.exists(save_dir): 53 | os.mkdir(save_dir) 54 | saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep) 55 | 56 | if save_best_seen: 57 | best_save_dir = os.path.join(save_dir, "best_eval") 58 | if not os.path.exists(best_save_dir): 59 | os.mkdir(best_save_dir) 60 | best_saver = tf.train.Saver(max_to_keep=1) 61 | best_eval_iou = -np.inf 62 | 63 | if weight_decay_rate != 1: 64 | pre_step_op = weight_decay(weight_decay_rate) 65 | else: 66 | pre_step_op = None # no need to just multiply all vars by 1. 67 | reptile = meta_fn(sess, 68 | transductive=transductive, 69 | pre_step_op=pre_step_op, lr_scheduler=lr_scheduler, augment=augment, aug_rate=aug_rate) 70 | iou_ph = tf.placeholder(tf.float32, shape=()) 71 | tf.summary.scalar('IoU', iou_ph) 72 | merged = tf.summary.merge_all() 73 | train_writer = tf.summary.FileWriter(os.path.join(save_dir, 'train'), sess.graph) 74 | test_writer = tf.summary.FileWriter(os.path.join(save_dir, 'test'), sess.graph) 75 | try: 76 | if not model.variables_initialized: 77 | print("Initializing variables.") 78 | tf.global_variables_initializer().run() 79 | sess.run(tf.global_variables_initializer()) 80 | except AttributeError: 81 | print("Model does not explicitly track whether variable initialization has already been run on the graph.") 82 | print("Initializing variables.") 83 | tf.global_variables_initializer().run() 84 | sess.run(tf.global_variables_initializer()) 85 | 86 | 87 | for i in range(meta_iters): 88 | begin_time = time.time() 89 | print('Reptile training step {} of {}'.format(i + 1, meta_iters)) 90 | frac_done = i / meta_iters 91 | print('{} done'.format(frac_done)) 92 | cur_meta_step_size = frac_done * meta_step_size_final + (1 - frac_done) * meta_step_size 93 | print("Current meta-step size: {}".format(cur_meta_step_size)) 94 | reptile.train_step(train_set, model.input_ph, model.label_ph, model.minimize_op, 95 | num_classes=num_classes, num_shots=(train_shots or num_shots), 96 | inner_batch_size=inner_batch_size, inner_iters=inner_iters, 97 | replacement=replacement, 98 | meta_step_size=cur_meta_step_size, meta_batch_size=meta_batch_size, lr_ph=model.lr_ph, lr=lr,) 99 | # call Gecko.evaluate to track progress: 100 | if i % eval_interval == 0: 101 | print('Evaluating training performance.') 102 | # track accuracy with mean intersection over union: 103 | mean_ious = [] 104 | for dataset, writer in [(train_set, train_writer), (test_set, 
test_writer)]: 105 | mean_iou, _ = reptile.evaluate(dataset, model.input_ph, model.label_ph, 106 | model.minimize_op, model.predictions, 107 | num_classes=num_classes, num_shots=num_shots, 108 | inner_batch_size=eval_inner_batch_size, 109 | inner_iters=eval_inner_iters, replacement=replacement, 110 | eval_all_tasks=False, 111 | num_tasks_to_sample=num_tasks_to_eval, 112 | save_fine_tuned_checkpoints=False, is_training_ph=model.is_training_ph, 113 | lr_ph=model.lr_ph) 114 | summary = sess.run(merged, feed_dict={iou_ph: mean_iou}) 115 | writer.add_summary(summary, i) 116 | # Log the learning rate: 117 | summary = tf.Summary(value=[tf.Summary.Value(tag="meta_step_size", simple_value=cur_meta_step_size)]) 118 | writer.add_summary(summary, i) 119 | writer.flush() 120 | mean_ious.append(mean_iou) 121 | log_fn('Train step %d: train=%f test=%f' % (i, mean_ious[0], mean_ious[1])) 122 | 123 | if save_best_seen and mean_ious[1] > best_eval_iou: 124 | best_eval_iou = mean_ious[1] 125 | print("Highest test-set evaluation IoU seen at step {}: {}".format(i, best_eval_iou)) 126 | print("Saving checkpoint to {}.".format(best_save_dir)) 127 | best_saver.save(sess, os.path.join(best_save_dir, 'model.ckpt'), global_step=i) 128 | 129 | if i % save_checkpoint_every_n_meta_iters == 0 or i == meta_iters - 1: # save checkpoint every n (should be 100) meta-iters and final meta-iter 130 | print("Saving checkpoint to {}.".format(save_dir)) 131 | saver.save(sess, os.path.join(save_dir, 'model.ckpt'), global_step=i) 132 | if time_deadline is not None and time.time() > time_deadline: 133 | break 134 | log_estimated_time_remaining(begin_time, i, meta_iters) 135 | return reptile 136 | -------------------------------------------------------------------------------- /meta_learners/variables.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tools for manipulating sets of variables. 3 | """ 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | 9 | def interpolate_vars(old_vars, new_vars, epsilon): 10 | """ 11 | Interpolate between two sequences of variables. 12 | """ 13 | return add_vars(old_vars, scale_vars(subtract_vars(new_vars, old_vars), epsilon)) 14 | 15 | 16 | def average_vars(var_seqs): 17 | """ 18 | Average a sequence of variable sequences. 19 | """ 20 | res = [] 21 | for variables in zip(*var_seqs): 22 | res.append(np.mean(variables, axis=0)) 23 | return res 24 | 25 | 26 | def subtract_vars(var_seq_1, var_seq_2): 27 | """ 28 | Subtract `var_seq_2` from `var_seq_1`. 29 | """ 30 | # import pdb; pdb.set_trace() 31 | return [v1 - v2 for v1, v2 in zip(var_seq_1, var_seq_2)] 32 | 33 | 34 | def add_vars(var_seq_1, var_seq_2): 35 | """ 36 | Add two variable sequences. 37 | """ 38 | return [v1 + v2 for v1, v2 in zip(var_seq_1, var_seq_2)] 39 | 40 | 41 | def scale_vars(var_seq, scale): 42 | """ 43 | Scale a variable sequence. 44 | """ 45 | return [v * scale for v in var_seq] 46 | 47 | 48 | def weight_decay(rate, variables=None): 49 | """ 50 | Create an Op that performs weight decay. 51 | """ 52 | if variables is None: 53 | variables = tf.trainable_variables() 54 | ops = [tf.assign(var, var * rate) for var in variables] 55 | return tf.group(*ops) 56 | 57 | 58 | class VariableState: 59 | """ 60 | Manage the state of a set of variables. 
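    Values are snapshotted with export_variables() and restored with
    import_variables(), which feeds the saved values through placeholders
    into a single grouped assign op.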
61 | """ 62 | def __init__(self, session, variables): 63 | self._session = session 64 | self._variables = variables 65 | self._placeholders = [tf.placeholder(v.dtype.base_dtype, shape=v.get_shape()) 66 | for v in variables] 67 | assigns = [tf.assign(v, p) for v, p in zip(self._variables, self._placeholders)] 68 | self._assign_op = tf.group(*assigns) 69 | 70 | def export_variables(self): 71 | """ 72 | Save the current variables. 73 | """ 74 | return self._session.run(self._variables) 75 | 76 | def import_variables(self, values): 77 | """ 78 | Restore the variables. 79 | """ 80 | self._session.run(self._assign_op, feed_dict=dict(zip(self._placeholders, values))) 81 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Image segmentation neural networks 3 | """ 4 | -------------------------------------------------------------------------------- /models/constants.py: -------------------------------------------------------------------------------- 1 | SUPPORTED_MODELS = {"efficientlab"} 2 | -------------------------------------------------------------------------------- /models/efficientnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml4ai/mliis/f40352e734f77609bcd5c4ad330ea73a897a217d/models/efficientnet/__init__.py -------------------------------------------------------------------------------- /models/efficientnet/constants.py: -------------------------------------------------------------------------------- 1 | MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255] 2 | STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255] -------------------------------------------------------------------------------- /models/efficientnet/efficientnet_builder.py: -------------------------------------------------------------------------------- 1 | # Modified Copyright 2020 Sean M. Hendryx. All Rights Reserved. 2 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """Model Builder for EfficientNet.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | import re 24 | import tensorflow as tf 25 | 26 | from models.efficientnet import efficientnet_model 27 | 28 | 29 | def efficientnet_params(model_name): 30 | """Get efficientnet params based on model name.""" 31 | params_dict = { 32 | # (width_coefficient, depth_coefficient, resolution, dropout_rate) 33 | 'efficientnet-b0': (1.0, 1.0, 224, 0.2), 34 | 'efficientnet-b1': (1.0, 1.1, 240, 0.2), 35 | 'efficientnet-b2': (1.1, 1.2, 260, 0.3), 36 | 'efficientnet-b3': (1.2, 1.4, 300, 0.3), 37 | 'efficientnet-b4': (1.4, 1.8, 380, 0.4), 38 | 'efficientnet-b5': (1.6, 2.2, 456, 0.4), 39 | 'efficientnet-b6': (1.8, 2.6, 528, 0.5), 40 | 'efficientnet-b7': (2.0, 3.1, 600, 0.5), 41 | } 42 | return params_dict[model_name] 43 | 44 | 45 | class BlockDecoder(object): 46 | """Block Decoder for readability.""" 47 | 48 | def _decode_block_string(self, block_string): 49 | """Gets a block through a string notation of arguments.""" 50 | assert isinstance(block_string, str) 51 | ops = block_string.split('_') 52 | options = {} 53 | for op in ops: 54 | splits = re.split(r'(\d.*)', op) 55 | if len(splits) >= 2: 56 | key, value = splits[:2] 57 | options[key] = value 58 | 59 | if 's' not in options or len(options['s']) != 2: 60 | raise ValueError('Strides options should be a pair of integers.') 61 | 62 | return efficientnet_model.BlockArgs( 63 | kernel_size=int(options['k']), 64 | num_repeat=int(options['r']), 65 | input_filters=int(options['i']), 66 | output_filters=int(options['o']), 67 | expand_ratio=int(options['e']), 68 | id_skip=('noskip' not in block_string), 69 | se_ratio=float(options['se']) if 'se' in options else None, 70 | strides=[int(options['s'][0]), int(options['s'][1])], 71 | conv_type=int(options['c']) if 'c' in options else 0) 72 | 73 | def _encode_block_string(self, block): 74 | """Encodes a block to a string.""" 75 | args = [ 76 | 'r%d' % block.num_repeat, 77 | 'k%d' % block.kernel_size, 78 | 's%d%d' % (block.strides[0], block.strides[1]), 79 | 'e%s' % block.expand_ratio, 80 | 'i%d' % block.input_filters, 81 | 'o%d' % block.output_filters, 82 | 'c%d' % block.conv_type, 83 | ] 84 | if block.se_ratio > 0 and block.se_ratio <= 1: 85 | args.append('se%s' % block.se_ratio) 86 | if block.id_skip is False: 87 | args.append('noskip') 88 | return '_'.join(args) 89 | 90 | def decode(self, string_list, max_block_num=None): 91 | """Decodes a list of string notations to specify blocks inside the network. 92 | 93 | Args: 94 | string_list: a list of strings, each string is a notation of block. 95 | 96 | Returns: 97 | A list of namedtuples to represent blocks arguments. 98 | """ 99 | assert isinstance(string_list, list) 100 | blocks_args = [] 101 | num_blocks = 0 102 | for block_string in string_list: 103 | block_args = self._decode_block_string(block_string) 104 | num_blocks += block_args.num_repeat 105 | if max_block_num is not None and num_blocks > max_block_num + 1: # account for zero-indexed blocks 106 | print("more blocks than max_block_num. Stopping graph construction of more blocks at {} blocks.".format(max_block_num)) 107 | break 108 | blocks_args.append(block_args) 109 | return blocks_args 110 | 111 | def encode(self, blocks_args): 112 | """Encodes a list of Blocks to a list of strings. 
113 | 
114 |     Args:
115 |       blocks_args: A list of namedtuples to represent blocks arguments.
116 |     Returns:
117 |       A list of strings, each of which is a notation of one block.
118 |     """
119 |     block_strings = []
120 |     for block in blocks_args:
121 |       block_strings.append(self._encode_block_string(block))
122 |     return block_strings
123 | 
124 | 
125 | def efficientnet(width_coefficient=None,
126 |                  depth_coefficient=None,
127 |                  dropout_rate=0.2,
128 |                  drop_connect_rate=0.2, max_block_num=None):
129 |   """Creates an efficientnet model."""
130 |   blocks_args = [
131 |       'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25',
132 |       'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25',
133 |       'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25',
134 |       'r1_k3_s11_e6_i192_o320_se0.25',
135 |   ]
136 |   global_params = efficientnet_model.GlobalParams(
137 |       batch_norm_momentum=0.99,
138 |       batch_norm_epsilon=1e-3,
139 |       dropout_rate=dropout_rate,
140 |       drop_connect_rate=drop_connect_rate,
141 |       data_format='channels_last',
142 |       num_classes=1000,
143 |       width_coefficient=width_coefficient,
144 |       depth_coefficient=depth_coefficient,
145 |       depth_divisor=8,
146 |       min_depth=None,
147 |       relu_fn=tf.nn.swish)
148 |   decoder = BlockDecoder()
149 |   return decoder.decode(blocks_args, max_block_num), global_params
150 | 
151 | 
152 | def get_model_params(model_name, override_params, max_block_num=None):
153 |   """Get the block args and global params for a given model."""
154 |   if model_name.startswith('efficientnet'):
155 |     width_coefficient, depth_coefficient, _, dropout_rate = (
156 |         efficientnet_params(model_name))
157 |     blocks_args, global_params = efficientnet(
158 |         width_coefficient, depth_coefficient, dropout_rate, max_block_num=max_block_num)
159 |   else:
160 |     raise NotImplementedError('model name is not pre-defined: %s' % model_name)
161 | 
162 |   if override_params:
163 |     # ValueError will be raised here if override_params has fields not included
164 |     # in global_params.
165 |     global_params = global_params._replace(**override_params)
166 | 
167 |   tf.logging.info('global_params= %s', global_params)
168 |   tf.logging.info('blocks_args= %s', blocks_args)
169 |   return blocks_args, global_params
170 | 
171 | 
172 | def build_model(images,
173 |                 model_name,
174 |                 training,
175 |                 override_params=None,
176 |                 model_dir=None):
177 |   """A helper function that creates a model and returns predicted logits.
178 | 
179 |   Args:
180 |     images: input images tensor.
181 |     model_name: string, the predefined model name.
182 |     training: boolean, whether the model is constructed for training.
183 |     override_params: A dictionary of params for overriding. Fields must exist in
184 |       efficientnet_model.GlobalParams.
185 |     model_dir: string, optional model dir for saving configs.
186 | 
187 |   Returns:
188 |     logits: the logits tensor of classes.
189 |     endpoints: the endpoints for each layer.
190 | 
191 |   Raises:
192 |     NotImplementedError: when model_name specifies an undefined model.
193 |     ValueError: when override_params has invalid fields.
194 | """ 195 | assert isinstance(images, tf.Tensor) 196 | blocks_args, global_params = get_model_params(model_name, override_params) 197 | 198 | if model_dir: 199 | param_file = os.path.join(model_dir, 'model_params.txt') 200 | if not tf.gfile.Exists(param_file): 201 | if not tf.gfile.Exists(model_dir): 202 | tf.gfile.MakeDirs(model_dir) 203 | with tf.gfile.GFile(param_file, 'w') as f: 204 | tf.logging.info('writing to %s' % param_file) 205 | f.write('model_name= %s\n\n' % model_name) 206 | f.write('global_params= %s\n\n' % str(global_params)) 207 | f.write('blocks_args= %s\n\n' % str(blocks_args)) 208 | 209 | with tf.variable_scope(model_name): 210 | model = efficientnet_model.Model(blocks_args, global_params) 211 | logits = model(images, training=training) 212 | 213 | logits = tf.identity(logits, 'logits') 214 | return logits, model.endpoints 215 | 216 | 217 | def build_model_base(images, model_name, training, override_params=None, max_block_num=None): 218 | """A helper functiion to create a base model and return global_pool. 219 | 220 | Args: 221 | images: input images tensor. 222 | model_name: string, the model name of a pre-defined MnasNet. 223 | training: boolean, whether the model is constructed for training. 224 | override_params: A dictionary of params for overriding. Fields must exist in 225 | mnasnet_model.GlobalParams. 226 | 227 | Returns: 228 | features: global pool features. 229 | endpoints: the endpoints for each layer. 230 | 231 | Raises: 232 | When model_name specified an undefined model, raises NotImplementedError. 233 | When override_params has invalid fields, raises ValueError. 234 | """ 235 | assert isinstance(images, tf.Tensor) 236 | blocks_args, global_params = get_model_params(model_name, override_params, max_block_num) 237 | 238 | with tf.variable_scope(model_name, reuse=tf.AUTO_REUSE): 239 | model = efficientnet_model.Model(blocks_args, global_params) 240 | features = model(images, training=training, features_only=True) 241 | 242 | features = tf.identity(features, 'global_pool') 243 | return features, model.endpoints 244 | -------------------------------------------------------------------------------- /models/efficientnet/efficientnet_model.py: -------------------------------------------------------------------------------- 1 | # Modified Copyright 2020 Sean M. Hendryx. All Rights Reserved. 2 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """ 17 | Contains definitions for EfficientNet model architecture adapted for few-shot meta-learning. 18 | 19 | [1] Mingxing Tan, Quoc V. Le 20 | EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. 
21 | ICML'19, https://arxiv.org/abs/1905.11946 22 | """ 23 | 24 | from __future__ import absolute_import 25 | from __future__ import division 26 | from __future__ import print_function 27 | 28 | import os 29 | import sys 30 | sys.path.insert(1, os.path.join(sys.path[0], '../..')) 31 | 32 | import collections 33 | import math 34 | import numpy as np 35 | import six 36 | from six.moves import xrange # pylint: disable=redefined-builtin 37 | import tensorflow as tf 38 | 39 | from models.efficientnet import utils 40 | 41 | 42 | GlobalParams = collections.namedtuple('GlobalParams', [ 43 | 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 'data_format', 44 | 'num_classes', 'width_coefficient', 'depth_coefficient', 45 | 'depth_divisor', 'min_depth', 'drop_connect_rate', 'relu_fn', 46 | ]) 47 | GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) 48 | 49 | # batchnorm = tf.layers.BatchNormalization 50 | batchnorm = utils.TpuBatchNormalization # TPU-specific requirement. 51 | 52 | BlockArgs = collections.namedtuple('BlockArgs', [ 53 | 'kernel_size', 'num_repeat', 'input_filters', 'output_filters', 54 | 'expand_ratio', 'id_skip', 'strides', 'se_ratio', 'conv_type', 55 | ]) 56 | # defaults will be a public argument for namedtuple in Python 3.7 57 | # https://docs.python.org/3/library/collections.html#collections.namedtuple 58 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) 59 | 60 | 61 | def conv_kernel_initializer(shape, dtype=None, partition_info=None): 62 | """Initialization for convolutional kernels. 63 | 64 | The main difference with tf.variance_scaling_initializer is that 65 | tf.variance_scaling_initializer uses a truncated normal with an uncorrected 66 | standard deviation, whereas here we use a normal distribution. Similarly, 67 | tf.contrib.layers.variance_scaling_initializer uses a truncated normal with 68 | a corrected standard deviation. 69 | 70 | Args: 71 | shape: shape of variable 72 | dtype: dtype of variable 73 | partition_info: unused 74 | 75 | Returns: 76 | an initialization for the variable 77 | """ 78 | del partition_info 79 | kernel_height, kernel_width, _, out_filters = shape 80 | fan_out = int(kernel_height * kernel_width * out_filters) 81 | return tf.random_normal( 82 | shape, mean=0.0, stddev=np.sqrt(2.0 / fan_out), dtype=dtype) 83 | 84 | 85 | def dense_kernel_initializer(shape, dtype=None, partition_info=None): 86 | """Initialization for dense kernels. 87 | 88 | This initialization is equal to 89 | tf.variance_scaling_initializer(scale=1.0/3.0, mode='fan_out', 90 | distribution='uniform'). 91 | It is written out explicitly here for clarity. 92 | 93 | Args: 94 | shape: shape of variable 95 | dtype: dtype of variable 96 | partition_info: unused 97 | 98 | Returns: 99 | an initialization for the variable 100 | """ 101 | del partition_info 102 | init_range = 1.0 / np.sqrt(shape[1]) 103 | return tf.random_uniform(shape, -init_range, init_range, dtype=dtype) 104 | 105 | 106 | def round_filters(filters, global_params): 107 | """Round number of filters based on depth multiplier.""" 108 | orig_f = filters 109 | multiplier = global_params.width_coefficient 110 | divisor = global_params.depth_divisor 111 | min_depth = global_params.min_depth 112 | if not multiplier: 113 | return filters 114 | 115 | filters *= multiplier 116 | min_depth = min_depth or divisor 117 | new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) 118 | # Make sure that round down does not go down by more than 10%. 
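  # Worked example of this rounding rule: filters=32, width_coefficient=1.1,
  # divisor=8 gives 32 * 1.1 = 35.2 -> int(35.2 + 4) // 8 * 8 = 32; since
  # 32 >= 0.9 * 35.2 = 31.68, the check below leaves it at 32.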
119 | if new_filters < 0.9 * filters: 120 | new_filters += divisor 121 | tf.logging.info('round_filter input={} output={}'.format(orig_f, new_filters)) 122 | return int(new_filters) 123 | 124 | 125 | def round_repeats(repeats, global_params): 126 | """Round number of filters based on depth multiplier.""" 127 | multiplier = global_params.depth_coefficient 128 | if not multiplier: 129 | return repeats 130 | return int(math.ceil(multiplier * repeats)) 131 | 132 | 133 | class MBConvBlock(object): 134 | """A class of MBConv: Mobile Inverted Residual Bottleneck. 135 | 136 | Attributes: 137 | endpoints: dict. A list of internal tensors. 138 | """ 139 | 140 | def __init__(self, block_args, global_params): 141 | """Initializes a MBConv block. 142 | 143 | Args: 144 | block_args: BlockArgs, arguments to create a Block. 145 | global_params: GlobalParams, a set of global parameters. 146 | """ 147 | self._block_args = block_args 148 | self._batch_norm_momentum = global_params.batch_norm_momentum 149 | self._batch_norm_epsilon = global_params.batch_norm_epsilon 150 | self._data_format = global_params.data_format 151 | if self._data_format == 'channels_first': 152 | self._channel_axis = 1 153 | self._spatial_dims = [2, 3] 154 | else: 155 | self._channel_axis = -1 156 | self._spatial_dims = [1, 2] 157 | 158 | self._relu_fn = global_params.relu_fn or tf.nn.swish 159 | self._has_se = (self._block_args.se_ratio is not None) and ( 160 | self._block_args.se_ratio > 0) and (self._block_args.se_ratio <= 1) 161 | 162 | self.endpoints = None 163 | 164 | # Builds the block according to arguments. 165 | self._build() 166 | 167 | def block_args(self): 168 | return self._block_args 169 | 170 | def _build(self): 171 | """Builds block according to the arguments.""" 172 | filters = self._block_args.input_filters * self._block_args.expand_ratio 173 | if self._block_args.expand_ratio != 1: 174 | # Expansion phase: 175 | self._expand_conv = tf.layers.Conv2D( 176 | filters, 177 | kernel_size=[1, 1], 178 | strides=[1, 1], 179 | kernel_initializer=conv_kernel_initializer, 180 | padding='same', 181 | data_format=self._data_format, 182 | use_bias=False) 183 | self._bn0 = batchnorm( 184 | axis=self._channel_axis, 185 | momentum=self._batch_norm_momentum, 186 | epsilon=self._batch_norm_epsilon) 187 | 188 | kernel_size = self._block_args.kernel_size 189 | # Depth-wise convolution phase: 190 | self._depthwise_conv = utils.DepthwiseConv2D( 191 | [kernel_size, kernel_size], 192 | strides=self._block_args.strides, 193 | depthwise_initializer=conv_kernel_initializer, 194 | padding='same', 195 | data_format=self._data_format, 196 | use_bias=False) 197 | self._bn1 = batchnorm( 198 | axis=self._channel_axis, 199 | momentum=self._batch_norm_momentum, 200 | epsilon=self._batch_norm_epsilon) 201 | 202 | if self._has_se: 203 | num_reduced_filters = max( 204 | 1, int(self._block_args.input_filters * self._block_args.se_ratio)) 205 | # Squeeze and Excitation layer. 
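      # Squeeze-and-excitation gates channels using a bottleneck over globally
      # pooled features, out = sigmoid(W2 * swish(W1 * mean_hw(x))) * x (see
      # _call_se below). E.g. input_filters=32 with se_ratio=0.25 yields
      # max(1, int(32 * 0.25)) = 8 reduced filters above.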
206 | self._se_reduce = tf.layers.Conv2D( 207 | num_reduced_filters, 208 | kernel_size=[1, 1], 209 | strides=[1, 1], 210 | kernel_initializer=conv_kernel_initializer, 211 | padding='same', 212 | data_format=self._data_format, 213 | use_bias=True) 214 | self._se_expand = tf.layers.Conv2D( 215 | filters, 216 | kernel_size=[1, 1], 217 | strides=[1, 1], 218 | kernel_initializer=conv_kernel_initializer, 219 | padding='same', 220 | data_format=self._data_format, 221 | use_bias=True) 222 | 223 | # Output phase: 224 | filters = self._block_args.output_filters 225 | self._project_conv = tf.layers.Conv2D( 226 | filters, 227 | kernel_size=[1, 1], 228 | strides=[1, 1], 229 | kernel_initializer=conv_kernel_initializer, 230 | padding='same', 231 | data_format=self._data_format, 232 | use_bias=False) 233 | self._bn2 = batchnorm( 234 | axis=self._channel_axis, 235 | momentum=self._batch_norm_momentum, 236 | epsilon=self._batch_norm_epsilon) 237 | 238 | def _call_se(self, input_tensor): 239 | """Call Squeeze and Excitation layer. 240 | 241 | Args: 242 | input_tensor: Tensor, a single input tensor for Squeeze/Excitation layer. 243 | 244 | Returns: 245 | A output tensor, which should have the same shape as input. 246 | """ 247 | se_tensor = tf.reduce_mean(input_tensor, self._spatial_dims, keepdims=True) 248 | se_tensor = self._se_expand(self._relu_fn(self._se_reduce(se_tensor))) 249 | tf.logging.info('Built Squeeze and Excitation with tensor shape: %s' % 250 | (se_tensor.shape)) 251 | return tf.sigmoid(se_tensor) * input_tensor 252 | 253 | def call(self, inputs, training=True, drop_connect_rate=None): 254 | """Implementation of call(). 255 | 256 | Args: 257 | inputs: the inputs tensor. 258 | training: boolean, whether the model is constructed for training. 259 | drop_connect_rate: float, between 0 to 1, drop connect rate. 260 | 261 | Returns: 262 | A output tensor. 263 | """ 264 | tf.logging.info('Block input: %s shape: %s' % (inputs.name, inputs.shape)) 265 | if self._block_args.expand_ratio != 1: 266 | x = self._relu_fn(self._bn0(self._expand_conv(inputs), training=training)) 267 | else: 268 | x = inputs 269 | tf.logging.info('Expand: %s shape: %s' % (x.name, x.shape)) 270 | 271 | x = self._relu_fn(self._bn1(self._depthwise_conv(x), training=training)) 272 | tf.logging.info('DWConv: %s shape: %s' % (x.name, x.shape)) 273 | 274 | if self._has_se: 275 | with tf.variable_scope('se'): 276 | x = self._call_se(x) 277 | 278 | self.endpoints = {'expansion_output': x} 279 | 280 | x = self._bn2(self._project_conv(x), training=training) 281 | if self._block_args.id_skip: 282 | if all( 283 | s == 1 for s in self._block_args.strides 284 | ) and self._block_args.input_filters == self._block_args.output_filters: 285 | # only apply drop_connect if skip presents. 286 | if drop_connect_rate: 287 | x = utils.drop_connect(x, training, drop_connect_rate) 288 | x = tf.add(x, inputs) 289 | tf.logging.info('Project: %s shape: %s' % (x.name, x.shape)) 290 | return x 291 | 292 | 293 | class Model(tf.keras.Model): 294 | """A class implements tf.keras.Model for MNAS-like model. 295 | 296 | Reference: https://arxiv.org/abs/1807.11626 297 | """ 298 | 299 | def __init__(self, blocks_args=None, global_params=None): 300 | """Initializes an `Model` instance. 301 | 302 | Args: 303 | blocks_args: A list of BlockArgs to construct block modules. 304 | global_params: GlobalParams, a set of global parameters. 305 | 306 | Raises: 307 | ValueError: when blocks_args is not specified as a list. 
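    Example, a minimal sketch (variable names are illustrative):
      blocks_args, global_params = efficientnet_builder.efficientnet(1.0, 1.0)
      model = Model(blocks_args, global_params)
      logits = model(images, training=False)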
308 | """ 309 | super(Model, self).__init__() 310 | if not isinstance(blocks_args, list): 311 | raise ValueError('blocks_args should be a list.') 312 | self._global_params = global_params 313 | self._blocks_args = blocks_args 314 | self._relu_fn = global_params.relu_fn or tf.nn.swish 315 | 316 | self.endpoints = None 317 | 318 | self._build() 319 | 320 | def _get_conv_block(self, conv_type): 321 | conv_block_map = { 322 | 0: MBConvBlock 323 | } 324 | return conv_block_map[conv_type] 325 | 326 | def _build(self): 327 | """Builds a model.""" 328 | self._blocks = [] 329 | # Builds blocks. 330 | for block_args in self._blocks_args: 331 | assert block_args.num_repeat > 0 332 | # Update block input and output filters based on depth multiplier. 333 | block_args = block_args._replace( 334 | input_filters=round_filters(block_args.input_filters, 335 | self._global_params), 336 | output_filters=round_filters(block_args.output_filters, 337 | self._global_params), 338 | num_repeat=round_repeats(block_args.num_repeat, self._global_params)) 339 | 340 | # The first block needs to take care of stride and filter size increase. 341 | conv_block = self._get_conv_block(block_args.conv_type) 342 | self._blocks.append(conv_block(block_args, self._global_params)) 343 | if block_args.num_repeat > 1: 344 | # pylint: disable=protected-access 345 | block_args = block_args._replace( 346 | input_filters=block_args.output_filters, strides=[1, 1]) 347 | # pylint: enable=protected-access 348 | for _ in xrange(block_args.num_repeat - 1): 349 | self._blocks.append(conv_block(block_args, self._global_params)) 350 | 351 | batch_norm_momentum = self._global_params.batch_norm_momentum 352 | batch_norm_epsilon = self._global_params.batch_norm_epsilon 353 | if self._global_params.data_format == 'channels_first': 354 | channel_axis = 1 355 | else: 356 | channel_axis = -1 357 | 358 | # Stem part. 359 | self._conv_stem = tf.layers.Conv2D( 360 | filters=round_filters(32, self._global_params), 361 | kernel_size=[3, 3], 362 | strides=[2, 2], 363 | kernel_initializer=conv_kernel_initializer, 364 | padding='same', 365 | data_format=self._global_params.data_format, 366 | use_bias=False) 367 | self._bn0 = batchnorm( 368 | axis=channel_axis, 369 | momentum=batch_norm_momentum, 370 | epsilon=batch_norm_epsilon) 371 | 372 | # Head part. 373 | self._conv_head = tf.layers.Conv2D( 374 | filters=round_filters(1280, self._global_params), 375 | kernel_size=[1, 1], 376 | strides=[1, 1], 377 | kernel_initializer=conv_kernel_initializer, 378 | padding='same', 379 | use_bias=False) 380 | self._bn1 = batchnorm( 381 | axis=channel_axis, 382 | momentum=batch_norm_momentum, 383 | epsilon=batch_norm_epsilon) 384 | 385 | self._avg_pooling = tf.keras.layers.GlobalAveragePooling2D( 386 | data_format=self._global_params.data_format) 387 | # Dropout is called here at graph construction time 388 | if self._global_params.dropout_rate > 0: 389 | self._dropout = tf.keras.layers.Dropout(self._global_params.dropout_rate) 390 | else: 391 | self._dropout = None 392 | self._fc = tf.layers.Dense( 393 | self._global_params.num_classes, 394 | kernel_initializer=dense_kernel_initializer) 395 | 396 | def call(self, inputs, training=True, features_only=None): 397 | """Implementation of call(). 398 | 399 | Args: 400 | inputs: input tensors. 401 | training: boolean, whether the model is constructed for training. 402 | features_only: build the base feature network only. 403 | 404 | Returns: 405 | output tensors. 
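      When features_only is True, the stem and blocks still run but the head
      (conv head, global pooling, dropout, and the fully connected layer) is
      skipped, and the final block's output is returned instead of logits.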
406 | """ 407 | outputs = None 408 | self.endpoints = {} 409 | # Calls Stem layers 410 | with tf.variable_scope('stem'): 411 | outputs = self._relu_fn( 412 | self._bn0(self._conv_stem(inputs), training=training)) 413 | tf.logging.info('Built stem layers with output shape: %s' % outputs.shape) 414 | self.endpoints['stem'] = outputs 415 | 416 | # Calls blocks. 417 | reduction_idx = 0 418 | for idx, block in enumerate(self._blocks): 419 | is_reduction = False 420 | if ((idx == len(self._blocks) - 1) or 421 | self._blocks[idx + 1].block_args().strides[0] > 1): 422 | is_reduction = True 423 | reduction_idx += 1 424 | 425 | with tf.variable_scope('blocks_%s' % idx): 426 | drop_rate = self._global_params.drop_connect_rate 427 | if drop_rate: 428 | drop_rate *= float(idx) / len(self._blocks) 429 | tf.logging.info('block_%s drop_connect_rate: %s' % (idx, drop_rate)) 430 | outputs = block.call( 431 | outputs, training=training, drop_connect_rate=drop_rate) 432 | self.endpoints['block_%s' % idx] = outputs 433 | if is_reduction: 434 | self.endpoints['reduction_%s' % reduction_idx] = outputs 435 | if block.endpoints: 436 | for k, v in six.iteritems(block.endpoints): 437 | self.endpoints['block_%s/%s' % (idx, k)] = v 438 | if is_reduction: 439 | self.endpoints['reduction_%s/%s' % (reduction_idx, k)] = v 440 | self.endpoints['global_pool'] = outputs 441 | 442 | if not features_only: 443 | # Calls final layers and returns logits. 444 | with tf.variable_scope('head'): 445 | outputs = self._relu_fn( 446 | self._bn1(self._conv_head(outputs), training=training)) 447 | outputs = self._avg_pooling(outputs) 448 | if self._dropout: 449 | outputs = self._dropout(outputs, training=training) 450 | self.endpoints['global_pool'] = outputs 451 | outputs = self._fc(outputs) 452 | self.endpoints['head'] = outputs 453 | return outputs -------------------------------------------------------------------------------- /models/efficientnet/utils.py: -------------------------------------------------------------------------------- 1 | # Modified Copyright 2020 Sean M. Hendryx. All Rights Reserved. 2 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ==============================================================================
16 | """Model utilities."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import os
23 | import numpy as np
24 | import tensorflow as tf
25 | 
26 | from tensorflow.contrib.tpu.python.ops import tpu_ops
27 | from tensorflow.contrib.tpu.python.tpu import tpu_function
28 | 
29 | 
30 | def build_learning_rate(initial_lr,
31 |                         global_step,
32 |                         steps_per_epoch=None,
33 |                         lr_decay_type='exponential',
34 |                         decay_factor=0.97,
35 |                         decay_epochs=2.4,
36 |                         total_steps=None,
37 |                         warmup_epochs=5):
38 |   """Build learning rate."""
39 |   if lr_decay_type == 'exponential':
40 |     assert steps_per_epoch is not None
41 |     decay_steps = steps_per_epoch * decay_epochs
42 |     lr = tf.train.exponential_decay(
43 |         initial_lr, global_step, decay_steps, decay_factor, staircase=True)
44 |   elif lr_decay_type == 'cosine':
45 |     assert total_steps is not None
46 |     lr = 0.5 * initial_lr * (
47 |         1 + tf.cos(np.pi * tf.cast(global_step, tf.float32) / total_steps))
48 |   elif lr_decay_type == 'constant':
49 |     lr = initial_lr
50 |   else:
51 |     assert False, 'Unknown lr_decay_type : %s' % lr_decay_type
52 | 
53 |   if warmup_epochs:
54 |     tf.logging.info('Learning rate warmup_epochs: %d' % warmup_epochs)
55 |     warmup_steps = int(warmup_epochs * steps_per_epoch)
56 |     warmup_lr = (
57 |         initial_lr * tf.cast(global_step, tf.float32) / tf.cast(
58 |             warmup_steps, tf.float32))
59 |     lr = tf.cond(global_step < warmup_steps, lambda: warmup_lr, lambda: lr)
60 | 
61 |   return lr
62 | 
63 | 
64 | def build_optimizer(learning_rate,
65 |                     optimizer_name='rmsprop',
66 |                     decay=0.9,
67 |                     epsilon=0.001,
68 |                     momentum=0.9):
69 |   """Build optimizer."""
70 |   if optimizer_name == 'sgd':
71 |     tf.logging.info('Using SGD optimizer')
72 |     optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
73 |   elif optimizer_name == 'momentum':
74 |     tf.logging.info('Using Momentum optimizer')
75 |     optimizer = tf.train.MomentumOptimizer(
76 |         learning_rate=learning_rate, momentum=momentum)
77 |   elif optimizer_name == 'rmsprop':
78 |     tf.logging.info('Using RMSProp optimizer')
79 |     optimizer = tf.train.RMSPropOptimizer(learning_rate, decay, momentum,
80 |                                           epsilon)
81 |   else:
82 |     tf.logging.fatal('Unknown optimizer: %s', optimizer_name)
83 | 
84 |   return optimizer
85 | 
86 | 
87 | class TpuBatchNormalization(tf.layers.BatchNormalization):
88 |   """Cross replica batch normalization."""
89 | 
90 | 
91 |   def __init__(self, fused=False, **kwargs):
92 |     if fused in (True, None):
93 |       raise ValueError('TpuBatchNormalization does not support fused=True.')
94 |     super(TpuBatchNormalization, self).__init__(fused=fused, **kwargs)
95 | 
96 |   def _cross_replica_average(self, t, num_shards_per_group):
97 |     """Calculates the average value of input tensor across TPU replicas."""
98 |     num_shards = tpu_function.get_tpu_context().number_of_shards
99 |     group_assignment = None
100 |     if num_shards_per_group > 1:
101 |       if num_shards % num_shards_per_group != 0:
102 |         raise ValueError('num_shards: %d mod shards_per_group: %d, should be 0'
103 |                          % (num_shards, num_shards_per_group))
104 |       num_groups = num_shards // num_shards_per_group
105 |       group_assignment = [[
106 |           x for x in range(num_shards) if x // num_shards_per_group == y
107 |       ] for y in range(num_groups)]
108 |     return tpu_ops.cross_replica_sum(t, group_assignment) / tf.cast(
109 |         num_shards_per_group, t.dtype)
110 | 
111 |   def _moments(self, inputs, reduction_axes, keep_dims):
112 |     """Compute the mean and variance: it overrides the original _moments."""
113 |     shard_mean, shard_variance = super(TpuBatchNormalization, self)._moments(
114 |         inputs, reduction_axes, keep_dims=keep_dims)
115 | 
116 |     num_shards = tpu_function.get_tpu_context().number_of_shards or 1
117 |     if num_shards <= 8:  # Skip cross_replica for 2x2 or smaller slices.
118 |       num_shards_per_group = 1
119 |     else:
120 |       num_shards_per_group = max(8, num_shards // 8)
121 |     tf.logging.info('TpuBatchNormalization with num_shards_per_group %s',
122 |                     num_shards_per_group)
123 |     if num_shards_per_group > 1:
124 |       # Compute variance using: Var[X]= E[X^2] - E[X]^2.
125 |       shard_square_of_mean = tf.math.square(shard_mean)
126 |       shard_mean_of_square = shard_variance + shard_square_of_mean
127 |       group_mean = self._cross_replica_average(
128 |           shard_mean, num_shards_per_group)
129 |       group_mean_of_square = self._cross_replica_average(
130 |           shard_mean_of_square, num_shards_per_group)
131 |       group_variance = group_mean_of_square - tf.math.square(group_mean)
132 |       return (group_mean, group_variance)
133 |     else:
134 |       return (shard_mean, shard_variance)
135 | 
136 | 
137 | def drop_connect(inputs, is_training, drop_connect_rate):
138 |   """Apply drop connect."""
139 |   if isinstance(is_training, tf.Tensor):
140 |     return drop_connect_cond(inputs, is_training, drop_connect_rate)
141 |   if not is_training:
142 |     return inputs
143 | 
144 |   # Compute keep_prob
145 |   # TODO(tanmingxing): add support for training progress.
146 |   keep_prob = 1.0 - drop_connect_rate
147 | 
148 |   # Compute drop_connect tensor
149 |   batch_size = tf.shape(inputs)[0]
150 |   random_tensor = keep_prob
151 |   random_tensor += tf.random_uniform([batch_size, 1, 1, 1], dtype=inputs.dtype)
152 |   binary_tensor = tf.floor(random_tensor)
153 |   output = tf.div(inputs, keep_prob) * binary_tensor
154 |   return output
155 | 
156 | 
157 | def drop_connect_cond(inputs, is_training, drop_connect_rate):
158 |   """Apply drop connect."""
159 |   def dropc(_inputs, _rate):
160 |     # Compute keep_prob
161 |     keep_prob = 1.0 - _rate
162 | 
163 |     # Compute drop_connect tensor
164 |     batch_size = tf.shape(_inputs)[0]
165 |     random_tensor = keep_prob
166 |     random_tensor += tf.random_uniform([batch_size, 1, 1, 1], dtype=_inputs.dtype)
167 |     binary_tensor = tf.floor(random_tensor)
168 |     return tf.div(_inputs, keep_prob) * binary_tensor
169 | 
170 |   return tf.cond(is_training, lambda: dropc(inputs, drop_connect_rate), lambda: inputs)
171 | 
172 | 
173 | def archive_ckpt(ckpt_eval, ckpt_objective, ckpt_path):
174 |   """Archive a checkpoint if the metric is better."""
175 |   ckpt_dir, ckpt_name = os.path.split(ckpt_path)
176 | 
177 |   saved_objective_path = os.path.join(ckpt_dir, 'best_objective.txt')
178 |   saved_objective = float('-inf')
179 |   if tf.gfile.Exists(saved_objective_path):
180 |     with tf.gfile.GFile(saved_objective_path, 'r') as f:
181 |       saved_objective = float(f.read())
182 |   if saved_objective > ckpt_objective:
183 |     tf.logging.info('Ckpt %s is worse than %s', ckpt_objective, saved_objective)
184 |     return False
185 | 
186 |   filenames = tf.gfile.Glob(ckpt_path + '.*')
187 |   if not filenames:  # Glob returns a (possibly empty) list, never None.
188 |     tf.logging.info('No files to copy for checkpoint %s', ckpt_path)
189 |     return False
190 | 
191 |   # Clear the old folder.
192 |   dst_dir = os.path.join(ckpt_dir, 'archive')
193 |   if tf.gfile.Exists(dst_dir):
194 |     tf.gfile.DeleteRecursively(dst_dir)
195 |   tf.gfile.MakeDirs(dst_dir)
196 | 
197 |   # Write checkpoints.
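  # Each TF checkpoint is a set of files sharing the ckpt_path prefix
  # (typically .index, .meta, and .data-* shards); every matching file is
  # copied below, and a fresh CheckpointState proto is written so that
  # tf.train.latest_checkpoint() resolves inside the archive directory.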
198 | for f in filenames: 199 | dest = os.path.join(dst_dir, os.path.basename(f)) 200 | tf.gfile.Copy(f, dest, overwrite=True) 201 | ckpt_state = tf.train.generate_checkpoint_state_proto( 202 | dst_dir, 203 | model_checkpoint_path=ckpt_name, 204 | all_model_checkpoint_paths=[ckpt_name]) 205 | with tf.gfile.GFile(os.path.join(dst_dir, 'checkpoint'), 'w') as f: 206 | f.write(str(ckpt_state)) 207 | with tf.gfile.GFile(os.path.join(dst_dir, 'best_eval.txt'), 'w') as f: 208 | f.write('%s' % ckpt_eval) 209 | 210 | # Update the best objective. 211 | with tf.gfile.GFile(saved_objective_path, 'w') as f: 212 | f.write('%f' % ckpt_objective) 213 | 214 | tf.logging.info('Copying checkpoint %s to %s', ckpt_path, dst_dir) 215 | return True 216 | 217 | 218 | # TODO(hongkuny): Consolidate this as a common library cross models. 219 | class DepthwiseConv2D(tf.keras.layers.DepthwiseConv2D, tf.layers.Layer): 220 | """Wrap keras DepthwiseConv2D to tf.layers.""" 221 | 222 | pass -------------------------------------------------------------------------------- /models/lr_schedulers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Optional 3 | 4 | 5 | class LRScheduler: 6 | 7 | def __init__(self, initial_lr: float, total_steps: int): 8 | self.initial_lr = initial_lr 9 | self.total_steps = total_steps 10 | 11 | def anneal_lr(self, cur_step: int): 12 | """Implemented by subclass""" 13 | pass 14 | 15 | def cur_lr(self, cur_step): 16 | lr = self.anneal_lr(cur_step) 17 | return lr 18 | 19 | 20 | class CosineLRScheduler(LRScheduler): 21 | 22 | def __init__(self, initial_lr: float, total_steps: int): 23 | super().__init__(initial_lr, total_steps) 24 | 25 | def anneal_lr(self, cur_step: int, min_to_decay_to: float = 0.0): 26 | lr = 0.5 * self.initial_lr * (1 + np.cos(np.pi * cur_step / self.total_steps)) 27 | lr = np.max([lr, min_to_decay_to]) 28 | return lr 29 | 30 | 31 | class StepDecay(LRScheduler): 32 | 33 | def __init__(self, initial_lr: float, total_steps: Optional[int] = None, decay_rate: float = 0.5, decay_after_n_steps: int = 5, min_lr: float = 1e-7): 34 | super().__init__(initial_lr, total_steps) 35 | assert decay_rate is not None and decay_after_n_steps is not None 36 | self.decay_rate = decay_rate 37 | self.decay_after_n_steps = decay_after_n_steps 38 | self.min_lr = min_lr 39 | 40 | def anneal_lr(self, cur_step: int, decay_rate: Optional[float] = None, decay_after_n_steps: Optional[int] = None): 41 | if decay_after_n_steps is None: 42 | decay_after_n_steps = self.decay_after_n_steps 43 | if decay_rate is None: 44 | decay_rate = self.decay_rate 45 | m = cur_step // decay_after_n_steps 46 | lr = self.initial_lr * (decay_rate ** m) 47 | lr = self.min_lr if lr < self.min_lr else lr 48 | return lr 49 | 50 | 51 | supported_learning_rate_schedulers = {"cosine_anneal": CosineLRScheduler, "fixed": None, "constant": None, 52 | "step": StepDecay, "step_decay": StepDecay} 53 | -------------------------------------------------------------------------------- /models/regularizers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def l2_term(weight_decay=0.0005): 5 | """ 6 | L2 loss on trainable variables. 
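    Equivalent to weight_decay * sum_v tf.nn.l2_loss(v), where
    tf.nn.l2_loss(v) = sum(v ** 2) / 2, taken over all trainable variables
    except batch-normalization parameters.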
7 | Adapted from tensorflow tpu efficientnet implementation 8 | """ 9 | return tf.identity(weight_decay * tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables() 10 | if 'batch_normalization' not in v.name]), name="l2") 11 | 12 | 13 | def l1_term(weight_decay=0.0005): 14 | """ 15 | L1 loss on trainable variables. 16 | """ 17 | return tf.identity(weight_decay * tf.add_n([tf.reduce_sum(tf.math.abs(v)) for v in tf.trainable_variables() 18 | if 'batch_normalization' not in v.name]), name="l1") 19 | 20 | def darc1_term(logits, weight=0.0005): 21 | """Assumes batch dim is first.""" 22 | return tf.identity(weight * tf.reduce_max(tf.reduce_sum(tf.abs(logits), axis=0), name="darc1")) 23 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==1.15.4 2 | click 3 | numpy 4 | Pillow 5 | matplotlib 6 | scipy 7 | scikit-optimize 8 | pandas 9 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | data_dir=fewshot_shards/ 5 | name=meta-eval_GPUHO_EffLab_rsd-stages-3-6_`date +%s` 6 | checkpoint_dir=EfficientLab-6-3_FOMAML-star_checkpoint 7 | 8 | python run_metasegnet.py --fss_1000 --image_size 224 \ 9 | --pretrained \ 10 | --rsd 2 4 --l2 \ 11 | --foml --foml-tail 5 \ 12 | --final_layer_dropout_rate 0.5 --augment --aug_rate 0.5 \ 13 | --sgd --loss_name bce_dice --inner-batch 8 --learning-rate 0.0005 --train-shots 10 --inner-iters 59 --learning_rate_scheduler fixed \ 14 | --meta-iters 50000 --meta-batch 5 \ 15 | --eval-interval 500 --serially_eval_all_test_tasks --eval-samples 2 --shots 5 --eval-batch 8 --eval-iters 59 --transductive \ 16 | --model_name efficientlab --sgd --meta-step 0.1 --meta-step-final 0.00001 \ 17 | --checkpoint ${checkpoint_dir} --data-dir ${data_dir} # 2>&1 | tee log_${name}.txt 18 | -------------------------------------------------------------------------------- /run_metasegnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Meta-trains and evaluates image segmentation models. 3 | """ 4 | import copy 5 | import datetime 6 | import json 7 | import logging 8 | import os 9 | 10 | import numpy as np 11 | import random 12 | import tensorflow as tf 13 | 14 | from data.fss_1000_utils import FP_K_TEST_TASK_IDS 15 | from models.constants import SUPPORTED_MODELS 16 | from models.efficientlab import EfficientLab 17 | from models.lr_schedulers import supported_learning_rate_schedulers 18 | from meta_learners.args import argument_parser, model_kwargs, train_kwargs, evaluate_kwargs, hyper_search_kwargs 19 | from meta_learners.metaseg import read_fss_1000_dataset, read_fp_k_shot_dataset 20 | from meta_learners.supervised_reptile.supervised_reptile.eval import evaluate_gecko, optimize_update_hyperparams, \ 21 | run_k_shot_learning_curves_experiment 22 | from meta_learners.supervised_reptile.supervised_reptile.train import train_gecko 23 | from utils.util import latest_checkpoint, validate_datasets 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | def main(): 29 | """ 30 | Load data and train an image segmentation model on it. 
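    Flow: parse CLI args, build the EfficientLab network, meta-train with
    FOMAML or Reptile (or restore a checkpoint when --pretrained is set),
    optionally optimize the update-routine hyperparameters on validation
    tasks, and finally evaluate k-shot learning (mean IoU) on held-out tasks.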
31 | """ 32 | verbose = True 33 | eval_train_tasks = True 34 | logger.info("Running image segmentation meta-learning...") 35 | start_time = datetime.datetime.now() 36 | print("Experiment started at: {}".format(start_time)) 37 | 38 | args = argument_parser().parse_args() 39 | 40 | if args.optimize_update_hyperparms_on_val_set: 41 | assert args.num_val_tasks > 0, "Must specify number of validation tasks greater than 0 to optimize update hyperparams." 42 | 43 | random.seed(args.seed) 44 | global DATA_DIR 45 | DATA_DIR = args.data_dir 46 | 47 | print('Defining model architecture:') 48 | loss_name = model_kwargs(args)['loss_name'] 49 | print('Using loss {}'.format(loss_name)) 50 | args.model_name = args.model_name.lower() 51 | lr_scheduler = None 52 | if args.model_name == "efficientlab": 53 | restore_ckpt_dir = model_kwargs(args)["restore_ckpt_dir"] 54 | model = EfficientLab(**model_kwargs(args)) 55 | initial_lr = args.learning_rate 56 | total_inner_steps = train_kwargs(args)["eval_inner_iters"] 57 | lr_scheduler_name = args.learning_rate_scheduler 58 | if supported_learning_rate_schedulers[lr_scheduler_name] is not None: 59 | if "step" in lr_scheduler_name: 60 | lr_sched_kwargs = {"decay_rate": args.step_decay_rate, "decay_after_n_steps": args.decay_after_n_steps} 61 | else: 62 | lr_sched_kwargs = {} 63 | lr_scheduler = supported_learning_rate_schedulers[lr_scheduler_name](initial_lr, total_inner_steps, **lr_sched_kwargs) 64 | else: 65 | lr_scheduler = None 66 | else: 67 | raise ValueError("model_name must be in {}".format(SUPPORTED_MODELS)) 68 | print('{} instantiated.'.format(args.model_name)) 69 | print("Model contains {} trainable parameters.".format(np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]))) 70 | 71 | # Define the meta-learner: 72 | print("Meta-learning with algorithm:") 73 | if args.foml: 74 | print("FOMAML") 75 | else: 76 | print("Reptile") 77 | train_fn, evaluate_fn = train_gecko, evaluate_gecko 78 | 79 | # Get the meta-learning dataset. 
Each item in train_set and test_set is a task: 80 | print("Setting up meta-learning dataset") 81 | serially_eval_all_test_tasks = args.serially_eval_all_test_tasks 82 | if args.run_k_shot_learning_curves_experiment: 83 | test_set, test_task_names = read_fp_k_shot_dataset(DATA_DIR, image_size=args.image_size) 84 | val_set = None 85 | train_set = None 86 | elif args.fp_k_test_set: 87 | print("Holding out FP-k classes: {}".format(FP_K_TEST_TASK_IDS)) 88 | dataset = read_fss_1000_dataset(DATA_DIR, num_val_tasks=args.num_val_tasks, test_task_ids=FP_K_TEST_TASK_IDS) 89 | train_set, val_set, test_set, train_task_names, val_task_names, test_task_names = dataset 90 | if len(val_set) == 0: 91 | val_set = None 92 | else: 93 | dataset = read_fss_1000_dataset(DATA_DIR, num_val_tasks=args.num_val_tasks) 94 | train_set, val_set, test_set, train_task_names, val_task_names, test_task_names = dataset 95 | if len(val_set) == 0: 96 | val_set = None 97 | 98 | validate_datasets(args, train_set, val_set, test_set) 99 | 100 | if verbose: 101 | print('Found {} testing tasks:'.format(len(test_set))) 102 | for test_task in test_set: 103 | print("{}".format(test_task.name)) 104 | if train_set is not None: 105 | print('Found {} training tasks:'.format(len(train_set))) 106 | for train_task in train_set: 107 | print("{}".format(train_task.name)) 108 | 109 | with tf.Session() as sess: 110 | if args.model_name == "efficientlab": 111 | if restore_ckpt_dir is not None and not args.pretrained: 112 | print("Restoring from checkpoint {}".format(restore_ckpt_dir)) 113 | model.restore_model(sess, restore_ckpt_dir, filter_to_scopes=[model.feature_extractor_name]) 114 | if not args.pretrained: 115 | print("Meta-training...") 116 | 117 | if args.continue_training_from_checkpoint is not None: 118 | continue_training_from_checkpoint = latest_checkpoint(args.continue_training_from_checkpoint) 119 | print('Continuing meta-training from checkpoint: {}'.format(continue_training_from_checkpoint)) 120 | tf.train.Saver().restore(sess, continue_training_from_checkpoint) 121 | model.variables_initialized = True 122 | 123 | _ = train_fn(sess, model, train_set, val_set or test_set, args.checkpoint, lr_scheduler=lr_scheduler, 124 | augment=args.augment, **train_kwargs(args)) 125 | else: 126 | if args.do_not_restore_final_layer_weights: 127 | print('Restoring from checkpoint: {}'.format(args.checkpoint)) 128 | # model.restore_model(sess, args.checkpoint, filter_to_scopes=[model.feature_extractor_name, model.feature_decoder_name], filter_out_scope=model.final_layer_scope, convert_ckpt_to_rel_path=True) 129 | model.restore_model(sess, args.checkpoint, filter_out_scope=model.final_layer_scope, convert_ckpt_to_rel_path=True) 130 | else: 131 | checkpoint = latest_checkpoint(args.checkpoint) 132 | print('Restoring from checkpoint: {}'.format(checkpoint)) 133 | tf.train.Saver().restore(sess, checkpoint) 134 | 135 | eval_kwargs = evaluate_kwargs(args) 136 | 137 | if args.optimize_update_hyperparms_on_val_set: 138 | print("Optimizing the update routine hyperparams on the val set") 139 | assert len(val_set) > 0, "Dev set has no tasks" 140 | save_fine_tuned_checkpoints_test = eval_kwargs["save_fine_tuned_checkpoints"] 141 | eval_kwargs["save_fine_tuned_checkpoints"] = False 142 | num_train_val_data_splits_to_sample_per_config = 1 if args.fss_1000 else 4 143 | estimated_lr, estimated_steps = optimize_update_hyperparams(sess, model, val_set, 144 | lr_scheduler=lr_scheduler, 145 | serially_eval_all_tasks=serially_eval_all_test_tasks, 146 | 
num_configs_to_sample=args.num_configs_to_sample, save_dir=args.checkpoint, 147 | results_csv_name=args.uho_results_csv_name, 148 | num_train_val_data_splits_to_sample_per_config=num_train_val_data_splits_to_sample_per_config, 149 | max_steps=args.max_steps, min_steps=args.min_steps, 150 | b=args.uho_outer_iters, **eval_kwargs, **hyper_search_kwargs(args)) 151 | eval_kwargs["save_fine_tuned_checkpoints"] = save_fine_tuned_checkpoints_test 152 | eval_kwargs["eval_inner_iters"] = estimated_steps 153 | eval_kwargs["lr"] = estimated_lr 154 | 155 | # Optionally meta-fine-tune on train + val sets here with optimal params, for small number of meta-iters 156 | # (e.g. 200, which is ~1.33 epochs on FSS-1000), and meta-step-final 157 | if args.meta_fine_tune_steps_on_train_val > 0: 158 | print("Fine-tuning meta-learned init for {} meta-steps with optimized hyperparameters.".format(args.meta_fine_tune_steps_on_train_val)) 159 | training_params = train_kwargs(args) 160 | training_params["inner_iters"] = estimated_steps 161 | training_params["lr"] = estimated_lr 162 | training_params["meta_step_size"] = training_params["meta_step_size_final"] 163 | _ = train_fn(sess, model, train_set + val_set, test_set, 164 | os.path.join(args.checkpoint, "fine-tuned_on_train_val_with_optimized_update_hyperparams"), 165 | lr_scheduler=lr_scheduler, augment=args.augment, **training_params) 166 | 167 | del eval_kwargs["eval_tasks_with_median_early_stopping_iterations"] 168 | 169 | if args.run_k_shot_learning_curves_experiment: 170 | k_shot_eval_kwargs = copy.deepcopy(eval_kwargs) 171 | del k_shot_eval_kwargs["save_fine_tuned_checkpoints"] 172 | del k_shot_eval_kwargs["save_fine_tuned_checkpoints_dir"] 173 | run_k_shot_learning_curves_experiment(sess, model, test_set, lr_scheduler=lr_scheduler, iter_range=args.k_shot_iter_range, **k_shot_eval_kwargs) 174 | else: 175 | print('Evaluating {}-shot learning on training tasks.'.format(args.shots)) 176 | if eval_train_tasks: 177 | save_fine_tuned_checkpoints_test = eval_kwargs["save_fine_tuned_checkpoints"] 178 | eval_kwargs["save_fine_tuned_checkpoints"] = args.save_fine_tuned_checkpoints_train 179 | mean_train_iou, _ = evaluate_fn(sess, model, train_set, visualize_predicted_segmentations=False, 180 | lr_scheduler=lr_scheduler, serially_eval_all_tasks=False, 181 | **eval_kwargs) 182 | eval_kwargs["save_fine_tuned_checkpoints"] = save_fine_tuned_checkpoints_test 183 | 184 | if args.eval_val_tasks: 185 | test_set = val_set 186 | test_task_names = val_task_names 187 | test_set_string = "val" 188 | else: 189 | test_set_string = "test" 190 | print('Evaluating {}-shot learning on meta-{} tasks.'.format(args.shots, test_set_string)) 191 | mean_test_iou, task_name_iou_map = evaluate_fn(sess, model, test_set, visualize_predicted_segmentations=False, 192 | lr_scheduler=lr_scheduler, 193 | serially_eval_all_tasks=serially_eval_all_test_tasks, **eval_kwargs) 194 | 195 | print("Evaluated meta-{} tasks:".format(test_set_string)) 196 | print(task_name_iou_map) 197 | if eval_train_tasks: 198 | print("Mean meta-train IoU: {}".format(mean_train_iou)) 199 | # Do NOT change this print (it's used to grep logs): 200 | print("Mean IoU over all meta-test tasks: {}".format(mean_test_iou)) 201 | 202 | # Write results out: 203 | results_path = os.path.join(args.checkpoint, "meta-test_results.json") 204 | with open(results_path, "w") as f: 205 | json.dump(task_name_iou_map, f) 206 | print("Wrote results to {}".format(results_path)) 207 | 208 | 209 | end_time = datetime.datetime.now() 210 | 
print("Experiment finished at: {}, taking {}".format(end_time, end_time - start_time)) 211 | 212 | 213 | if __name__ == '__main__': 214 | main() 215 | 216 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module configuration. 3 | """ 4 | 5 | from setuptools import setup 6 | 7 | with open('requirements.txt') as f: 8 | requirements = f.read().splitlines() 9 | 10 | 11 | setup( 12 | name='mliis', 13 | version='0.0.1', 14 | description='Meta-learning initializations for image segmentation', 15 | long_description='Code for reproducing experiments in https://arxiv.org/abs/1912.06290', 16 | url='https://github.com/ml4ai/mliis', 17 | author='Sean M. Hendryx', 18 | author_email='seanmhendryx@email.arizona.edu', 19 | license='MIT', 20 | keywords='meta-learning image segmentation ai machine learning', 21 | packages=['mliis'], 22 | install_requires=requirements 23 | ) 24 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Utilities for image segmentation meta-learning.""" -------------------------------------------------------------------------------- /utils/debug_tf_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | 7 | def debug(dataset): 8 | """Debugging utility for tf.data.Dataset.""" 9 | iterator = tf.data.Iterator.from_structure( 10 | dataset.output_types, dataset.output_shapes) 11 | next_element = iterator.get_next() 12 | 13 | ds_init_op = iterator.make_initializer(dataset) 14 | 15 | with tf.Session() as sess: 16 | sess.run(ds_init_op) 17 | viz(sess, next_element) 18 | import pdb; pdb.set_trace() 19 | res = sess.run(next_element) 20 | # for i in range(len(res)): 21 | # print("IoU of label with itself:") 22 | # print(Gecko._iou(res[i][1], res[i][1], class_of_interest_channel=None)) 23 | print(res) 24 | 25 | 26 | def plot_mask(mask_j: np.ndarray, figure_index=0, channel_index: Optional[int] = None, test_iou_of_label: bool = False): 27 | if test_iou_of_label: 28 | from meta_learners.supervised_reptile.supervised_reptile.reptile import Gecko 29 | import matplotlib.pyplot as plt 30 | plt.figure(figure_index) 31 | if channel_index is None: 32 | for k in range(mask_j.shape[2]): 33 | if np.sum(mask_j[:, :, k]) == 0: 34 | continue 35 | break 36 | print("class at channel {}".format(k)) 37 | else: 38 | k = channel_index 39 | plt.imshow(mask_j[:, :, k]) 40 | plt.show() 41 | if test_iou_of_label: 42 | print("IoU of label with itself:") 43 | print(Gecko._iou(mask_j.copy(), mask_j.copy(), class_of_interest_channel=None, round_labels=True)) 44 | return k 45 | 46 | 47 | def viz(sess, next_element, num_to_viz=20): 48 | try: 49 | import matplotlib.pyplot as plt 50 | 51 | for i in range(num_to_viz): 52 | res = sess.run(next_element) 53 | image = res[0].astype(int) 54 | mask = res[1] 55 | if len(image.shape) == 4: 56 | for j in range(image.shape[0]): 57 | plt.figure(i + j) 58 | plt.imshow(image[j]) 59 | plt.show() 60 | mask_j = mask[j] 61 | plot_mask(mask_j, i + j) 62 | else: 63 | plt.figure(i) 64 | plt.imshow(image) 65 | plt.show() 66 | plot_mask(mask, i ) 67 | except Exception as e: 68 | print(e) 69 | import pdb; pdb.set_trace() -------------------------------------------------------------------------------- 
/utils/util.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import hashlib
3 | import os
4 | import re
5 | import time
6 | import warnings
7 | from typing import List, Optional, Tuple, Dict, Union
8 | 
9 | import numpy as np
10 | import tensorflow as tf
11 | from tensorflow import Session
12 | from tensorflow.python import pywrap_tensorflow
13 | 
14 | # from data.input_fn import parse_example, _IMAGE_WIDTH
15 | 
16 | 
17 | def hash_np_array(a: np.ndarray) -> bytes:
18 |     """Returns the sha-256 hash bytes of a stringified numpy array."""
19 |     m = hashlib.sha256()
20 |     m.update(a.tostring())
21 |     return m.digest()
22 | 
23 | 
24 | def count_examples_in_tfrecords(paths: List[str]) -> int:
25 |     if not isinstance(paths, list):
26 |         paths = [paths] if isinstance(paths, str) else list(paths)  # A bare string must not be split into characters.
27 |     options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)
28 |     c = 0
29 |     with tf.Session() as sess:
30 |         for fn in paths:
31 |             for record in tf.python_io.tf_record_iterator(fn, options=options):
32 |                 c += 1
33 |     return c
34 | 
35 | 
36 | def count_unique_task_examples(dir: str, task_name: str) -> int:
37 |     shards = glob.glob(os.path.join(dir, "*.tfrecord*"))
38 |     shards = [x for x in shards if task_name in x]
39 |     return count_examples_in_tfrecords(shards)
40 | 
41 | 
42 | def latest_checkpoint(checkpoint_dir: str, ckpt_prefix: str = "model.ckpt", return_relative: bool = True) -> str:
43 |     if return_relative:
44 |         with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
45 |             text = f.readline()
46 |         pattern = re.compile(re.escape(ckpt_prefix + "-") + r"[0-9]+")
47 |         basename = pattern.findall(text)[0]
48 |         return os.path.join(checkpoint_dir, basename)
49 |     else:
50 |         return tf.train.latest_checkpoint(checkpoint_dir)
51 | 
52 | 
53 | def get_list_of_tensor_names(sess: tf.Session) -> List[str]:
54 |     sess.as_default()
55 |     graph = tf.get_default_graph()
56 |     return [t.name for op in graph.get_operations() for t in op.values()]
57 | 
58 | 
59 | def get_list_of_node_shapes(sess: tf.Session) -> List:
60 |     with sess:
61 |         graph_def = sess.graph.as_graph_def(add_shapes=True)
62 |         return [n._output_shapes for n in graph_def.node]
63 | 
64 | 
65 | def mkdir(path):
66 |     """
67 |     Recursively create dir at `path` if `path` does not exist.
68 | """ 69 | os.makedirs(path, exist_ok=True) 70 | 71 | 72 | def save_fine_tuned_checkpoint(save_fine_tuned_checkpoint_dir: str, sess: Session, step: Optional[int] = None, 73 | eval_sample_num: Optional[int] = None): 74 | if save_fine_tuned_checkpoint_dir is None: 75 | raise ValueError("Must specify directory in which to save fine-tuned checkpoints if saving them.") 76 | if eval_sample_num is not None: 77 | save_fine_tuned_checkpoint_dir = os.path.join(save_fine_tuned_checkpoint_dir, str(eval_sample_num)) 78 | mkdir(save_fine_tuned_checkpoint_dir) 79 | saver = tf.train.Saver() 80 | saver.save(sess, os.path.join(save_fine_tuned_checkpoint_dir, 'model.ckpt'), global_step=step) 81 | print("Saved fine-tuned checkpoint to {}.".format(save_fine_tuned_checkpoint_dir)) 82 | 83 | 84 | def get_training_set_hash_map(training_set: List[Tuple]) -> Dict[bytes, bytes]: 85 | """Returns a dict mapping sha-256 of image to sha-256 of corresponding mask.""" 86 | hash_map = {} 87 | for pair in training_set: 88 | image_hash = hash_np_array(pair[0]) 89 | mask_hash = hash_np_array(pair[1]) 90 | hash_map[image_hash] = mask_hash 91 | return hash_map 92 | 93 | 94 | def log_estimated_time_remaining(start_time, cur_step, total_steps, unit_name="meta-step"): 95 | elapsed = (time.time() - start_time) / 60. 96 | print("This {} took:".format(unit_name), elapsed, "minutes.") 97 | print('Estimated training hours remaining:%.4f' % ((total_steps - cur_step) * elapsed / 60.)) 98 | return elapsed 99 | 100 | 101 | def get_image_paths(list_of_paths): 102 | return [x for x in list_of_paths if is_image_file(x)] 103 | 104 | 105 | def is_image_file(path): 106 | _, ext = os.path.splitext(path) 107 | if ext in ['.jpg', '.jpeg', '.png', '.tiff', '.tif', '.bmp', ".mat"]: 108 | return True 109 | else: 110 | return False 111 | 112 | 113 | def initialize_uninitialized_vars(session, list_of_variables=None): 114 | if list_of_variables is None: 115 | list_of_variables = tf.global_variables() 116 | uninitialized_variables = list(tf.get_variable(name) for name in 117 | session.run(tf.report_uninitialized_variables(list_of_variables))) 118 | warnings.warn("Initializing the following variables: {}".format(uninitialized_variables)) 119 | print("Initializing the following variables: {}".format(uninitialized_variables)) 120 | session.run(tf.variables_initializer(uninitialized_variables)) 121 | return uninitialized_variables 122 | 123 | 124 | def validate_datasets(args, train_set, val_set, test_set): 125 | if not args.pretrained and not args.run_k_shot_learning_curves_experiment: 126 | assert len(train_set) > 0, "Training set must have examples." 127 | assert len(test_set) > 0, "Test set must have examples." 
128 |     if args.eval_val_tasks and val_set is not None:
129 |         if len(val_set) == 0:
130 |             raise ValueError("Val set has no tasks to evaluate")
131 | 
132 | 
133 | def ci95(a: Union[List[float], np.ndarray]):
134 |     """Computes the 95% confidence interval of the array `a`."""
135 |     sigma = np.std(a)
136 |     return 1.96 * sigma / np.sqrt(len(a))
137 | 
138 | 
139 | def runtime_metrics(runtimes: List[Union[float, int]]):
140 |     """runtimes is a list of the time it takes to process one image."""
141 |     ci = ci95(runtimes)
142 |     return np.mean(runtimes), ci
143 | 
--------------------------------------------------------------------------------
/utils/viz.py:
--------------------------------------------------------------------------------
1 | """Visualization utils for image segmentation."""
2 | import copy
3 | import os
4 | 
5 | import numpy as np
6 | 
7 | 
8 | def plot_mask_on_image(image: np.ndarray, mask: np.ndarray, truth_value: int = 1, alpha=0.75, scale_to_0_1: bool = True) -> None:
9 |     import matplotlib.pyplot as plt
10 |     import matplotlib.cm as cm
11 | 
12 |     if scale_to_0_1:
13 |         image = image / 255.  # Not in-place: augmented assignment fails on integer arrays and would mutate the caller's array.
14 |     masked = np.ma.masked_where(mask != truth_value, mask)
15 |     fig, ax1 = plt.subplots()
16 |     ax1.imshow(image)
17 |     ax1.imshow(masked, interpolation='none', alpha=alpha, cmap=cm.jet)
18 |     plt.show()
19 | 
20 | 
21 | def _plot_two_images(A, B):
22 |     import matplotlib.pyplot as plt
23 | 
24 |     plt.figure()
25 | 
26 |     plt.subplot(121)
27 |     plt.imshow(A)
28 |     plt.subplot(122)
29 |     plt.imshow(B)
30 |     plt.show()
31 | 
32 | 
33 | def _save_plot_two_images(A, B, fname):
34 |     import matplotlib.pyplot as plt
35 | 
36 |     fig = plt.figure()
37 | 
38 |     plt.subplot(121)
39 |     plt.imshow(A)
40 |     plt.subplot(122)
41 |     plt.imshow(B)
42 | 
43 |     plt.savefig(fname)
44 | 
45 |     plt.close(fig)
46 | 
47 | 
48 | def savefig_mask_on_image(image: np.ndarray, mask: np.ndarray, truth_value: int = 1, alpha=0.5, save_path=None) -> None:
49 |     """Plot mask on image for binary segmentation."""
50 |     if save_path and not os.path.exists(os.path.dirname(save_path)):
51 |         os.makedirs(os.path.dirname(save_path))
52 |     from matplotlib import pyplot as plt, cm as cm  # Import locally to avoid errors when matplotlib is not available
53 |     image = image.copy()
54 |     mask = mask.copy()
55 |     if truth_value not in {1, 255}:
56 |         raise ValueError("Foreground class should be 1 or 255")
57 |     print("mask.shape: {}".format(mask.shape))
58 |     if mask.shape[2] == 2:  # Get the second channel
59 |         mask = mask[:, :, 1]
60 |     if truth_value == 1:
61 |         mask *= 255
62 |         truth_value = 255
63 |     image, mask = image.astype(int), mask.astype(int)
64 |     masked = np.ma.masked_where(mask != truth_value, mask)
65 | 
66 |     plt.gca().set_axis_off()
67 |     plt.subplots_adjust(top=1, bottom=0, right=1, left=0,
68 |                         hspace=0, wspace=0)
69 |     plt.margins(0, 0)
70 |     plt.gca().xaxis.set_major_locator(plt.NullLocator())
71 |     plt.gca().yaxis.set_major_locator(plt.NullLocator())
72 | 
73 |     fig, ax1 = plt.subplots()
74 |     ax1.imshow(image)
75 |     ax1.imshow(masked, interpolation='none', alpha=alpha, cmap=cm.autumn)  # cm.jet
76 |     plt.axis('off')
77 | 
78 |     if save_path:
79 |         plt.savefig(save_path, bbox_inches=0,
80 |                     pad_inches=0)
81 |         print("saved figure to {}".format(save_path))
82 |         plt.close()
83 |     else:
84 |         print("No path specified to save to.")
85 |         plt.show()
86 | 
87 | 
88 | def savefig_batch_mask_on_image(images, masks, truth_value: int = 1, alpha=0.75, save_path_bn=None, ext=".png") -> None:
89 |     """Plot mask on image for binary segmentation."""
90 |     from matplotlib import pyplot as plt, cm as cm  # Import locally to avoid errors when matplotlib is not available
91 |     for i, (image, mask) in enumerate(zip(images, masks)):
92 |         image = image.copy()
93 |         mask = mask.copy()
94 |         if truth_value not in {1, 255}:
95 |             raise ValueError("Foreground class should be 1 or 255")
96 |         print("mask.shape: {}".format(mask.shape))
97 |         if mask.shape[2] == 2:  # Get the second channel
98 |             mask = mask[:, :, 1]
99 |         _truth_value = truth_value  # Default; overridden below when the mask is coded with 1.
100 |         if truth_value == 1:
101 |             mask *= 255
102 |             _truth_value = 255
103 |         image, mask = image.astype(int), mask.astype(int)
104 |         masked = np.ma.masked_where(mask != _truth_value, mask)
105 |         fig, ax1 = plt.subplots()
106 |         ax1.imshow(image)
107 |         ax1.imshow(masked, interpolation='none', alpha=alpha, cmap=cm.jet)
108 |         if save_path_bn:
109 |             save_path = save_path_bn + "_" + str(i) + ext
110 |             plt.savefig(save_path)
111 |             print("saved figure to {}".format(save_path))
112 |             plt.close()
113 |         else:
114 |             print("No path specified to save to.")
115 | 
--------------------------------------------------------------------------------
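Usage sketch for the viz helpers above (illustrative only — the array values and the save path are made up). Masks follow the (H, W, 2) one-hot layout handled by the `mask.shape[2] == 2` branch:

import numpy as np
from utils.viz import savefig_batch_mask_on_image

images = [np.random.randint(0, 256, size=(224, 224, 3))]  # Dummy RGB image.
masks = [np.zeros((224, 224, 2))]
masks[0][50:100, 50:100, 1] = 1.0  # Foreground coded as 1 in the second channel.
savefig_batch_mask_on_image(images, masks, truth_value=1, save_path_bn="figures/example_overlay")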