├── .gitignore
├── README.md
├── iterative_saliency.py
├── minsaliency
│   ├── border_collie.jpg
│   ├── goose.jpg
│   ├── model-1.ckpt
│   └── training.jpg
├── real_time_saliency.py
├── requirements.txt
├── sal
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── cifar_dataset.py
│   │   ├── imagenet_dataset.py
│   │   └── imagenet_synset.py
│   ├── saliency_model.py
│   ├── small.py
│   └── utils
│       ├── __init__.py
│       ├── gaussian_blur.py
│       ├── mask.py
│       ├── pt_store.py
│       ├── pytorch_fixes.py
│       ├── pytorch_trainer.py
│       ├── resnet_encoder.py
│       ├── test.jpg
│       └── test2.jpg
├── saliency_eval.py
├── saliency_train.py
├── small_trainer.py
└── test_black_box.py
/.gitignore:
--------------------------------------------------------------------------------
1 | 
2 | ### JetBrains template
3 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
4 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
5 | 
6 | # User-specific stuff:
7 | .idea/**/workspace.xml
8 | .idea/**/tasks.xml
9 | .idea/dictionaries
10 | .idea/
11 | .DS_Store
12 | 
13 | # Sensitive or high-churn files:
14 | .idea/**/dataSources/
15 | .idea/**/dataSources.ids
16 | .idea/**/dataSources.xml
17 | .idea/**/dataSources.local.xml
18 | .idea/**/sqlDataSources.xml
19 | .idea/**/dynamic.xml
20 | .idea/**/uiDesigner.xml
21 | 
22 | # Gradle:
23 | .idea/**/gradle.xml
24 | .idea/**/libraries
25 | 
26 | # CMake
27 | cmake-build-debug/
28 | # Mongo Explorer plugin:
29 | .idea/**/mongoSettings.xml
30 | 
31 | ## File-based project format:
32 | *.iws
33 | 
34 | ## Plugin-specific files:
35 | 
36 | # IntelliJ
37 | out/
38 | 
39 | # mpeltonen/sbt-idea plugin
40 | .idea_modules/
41 | 
42 | # JIRA plugin
43 | atlassian-ide-plugin.xml
44 | 
45 | # Cursive Clojure plugin
46 | .idea/replstate.xml
47 | 
48 | # Crashlytics plugin (for Android Studio and IntelliJ)
49 | com_crashlytics_export_strings.xml
50 | crashlytics.properties
51 | crashlytics-build.properties
52 | fabric.properties
53 | ### Python template
54 | # Byte-compiled / optimized / DLL files
55 | __pycache__/
56 | *.py[cod]
57 | *$py.class
58 | 
59 | # C extensions
60 | *.so
61 | 
62 | # Distribution / packaging
63 | .Python
64 | build/
65 | develop-eggs/
66 | dist/
67 | downloads/
68 | eggs/
69 | .eggs/
70 | lib/
71 | lib64/
72 | parts/
73 | sdist/
74 | var/
75 | wheels/
76 | *.egg-info/
77 | .installed.cfg
78 | *.egg
79 | MANIFEST
80 | 
81 | # PyInstaller
82 | # Usually these files are written by a python script from a template
83 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
84 | *.manifest 85 | *.spec 86 | 87 | # Installer logs 88 | pip-log.txt 89 | pip-delete-this-directory.txt 90 | 91 | # Unit test / coverage reports 92 | htmlcov/ 93 | .tox/ 94 | .coverage 95 | .coverage.* 96 | .cache 97 | nosetests.xml 98 | coverage.xml 99 | *.cover 100 | .hypothesis/ 101 | 102 | # Translations 103 | *.mo 104 | *.pot 105 | 106 | # Django stuff: 107 | *.log 108 | .static_storage/ 109 | .media/ 110 | local_settings.py 111 | 112 | # Flask stuff: 113 | instance/ 114 | .webassets-cache 115 | 116 | # Scrapy stuff: 117 | .scrapy 118 | 119 | # Sphinx documentation 120 | docs/_build/ 121 | 122 | # PyBuilder 123 | target/ 124 | 125 | # Jupyter Notebook 126 | .ipynb_checkpoints 127 | 128 | # pyenv 129 | .python-version 130 | 131 | # celery beat schedule file 132 | celerybeat-schedule 133 | 134 | # SageMath parsed files 135 | *.sage.py 136 | 137 | # Environments 138 | .env 139 | .venv 140 | env/ 141 | venv/ 142 | ENV/ 143 | env.bak/ 144 | venv.bak/ 145 | 146 | # Spyder project settings 147 | .spyderproject 148 | .spyproject 149 | 150 | # Rope project settings 151 | .ropeproject 152 | 153 | # mkdocs documentation 154 | /site 155 | 156 | # mypy 157 | .mypy_cache/ 158 | 159 | ### Python template 160 | # Byte-compiled / optimized / DLL files 161 | __pycache__/ 162 | *.py[cod] 163 | *$py.class 164 | 165 | # C extensions 166 | *.so 167 | 168 | # Distribution / packaging 169 | .Python 170 | build/ 171 | develop-eggs/ 172 | dist/ 173 | downloads/ 174 | eggs/ 175 | .eggs/ 176 | lib/ 177 | lib64/ 178 | parts/ 179 | sdist/ 180 | var/ 181 | wheels/ 182 | *.egg-info/ 183 | .installed.cfg 184 | *.egg 185 | MANIFEST 186 | 187 | # PyInstaller 188 | # Usually these files are written by a python script from a template 189 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 190 | *.manifest 191 | *.spec 192 | 193 | # Installer logs 194 | pip-log.txt 195 | pip-delete-this-directory.txt 196 | 197 | # Unit test / coverage reports 198 | htmlcov/ 199 | .tox/ 200 | .coverage 201 | .coverage.* 202 | .cache 203 | nosetests.xml 204 | coverage.xml 205 | *.cover 206 | .hypothesis/ 207 | 208 | # Translations 209 | *.mo 210 | *.pot 211 | 212 | # Django stuff: 213 | *.log 214 | .static_storage/ 215 | .media/ 216 | local_settings.py 217 | 218 | # Flask stuff: 219 | instance/ 220 | .webassets-cache 221 | 222 | # Scrapy stuff: 223 | .scrapy 224 | 225 | # Sphinx documentation 226 | docs/_build/ 227 | 228 | # PyBuilder 229 | target/ 230 | 231 | # Jupyter Notebook 232 | .ipynb_checkpoints 233 | 234 | # pyenv 235 | .python-version 236 | 237 | # celery beat schedule file 238 | celerybeat-schedule 239 | 240 | # SageMath parsed files 241 | *.sage.py 242 | 243 | # Environments 244 | .env 245 | .venv 246 | env/ 247 | venv/ 248 | ENV/ 249 | env.bak/ 250 | venv.bak/ 251 | 252 | # Spyder project settings 253 | .spyderproject 254 | .spyproject 255 | # Rope project settings 256 | .ropeproject 257 | 258 | # mkdocs documentation 259 | /site 260 | 261 | # mypy 262 | .mypy_cache/ 263 | lastidea.py -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Real-time image saliency 2 | 3 | 4 | _See what your classifier is looking at!_ [[PAPER]](https://arxiv.org/abs/1705.07857) 5 | 6 | ![UI](minsaliency/border_collie.jpg) 7 | 8 | 9 | #### Real-time saliency view 10 | Run `python real_time_saliency.py` to perform the saliency detection on the video feed from your webcam. 
11 | You can choose the class to visualise (1000 ImageNet classes) as well as the confidence level - low
12 | confidence will highlight anything that resembles or is related to the target class, while higher confidence
13 | will only show the most salient parts.
14 | 
15 | The model runs on a CPU by default and achieves about 5 frames per
16 | second on my MacBook Pro (and over 150 frames per second on a GPU).
17 | 
18 | #### Training
19 | 
20 | Run `python saliency_train.py` to start the training. By default it will train the model to perform saliency detection on the ImageNet dataset for the resnet50 classifier, but you can choose your own dataset/classifier combination.
21 | You will need PyTorch with CUDA support; the training will be performed on all your GPUs in parallel. I also advise running the script from the iTerm2 terminal so that you can see the images during training.
22 | 
23 | ![UI](minsaliency/training.jpg)
24 | 
25 | #### Using pretrained model
26 | 
27 | ```python
28 | from saliency_eval import get_pretrained_saliency_fn
29 | 
30 | sal_fn = get_pretrained_saliency_fn()
31 | 
32 | # get the saliency map (see get_pretrained_saliency_fn doc for details)
33 | sal_map = sal_fn(images, selectors)
34 | 
35 | ```
36 | 
37 | #### Requirements
38 | 
39 | `pip install -r requirements.txt`
40 | 
41 | Also, in case you don't have OpenCV 3 installed, run:
42 | 
43 | `pip install opencv-contrib-python`
--------------------------------------------------------------------------------
/iterative_saliency.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as torch_optim
4 | import numpy as np
5 | from torch.autograd import Variable
6 | from sal.saliency_model import SaliencyLoss, get_black_box_fn
7 | from sal.utils import pt_store
8 | from saliency_eval import to_batch_variable, load_image_as_variable
9 | PT = pt_store.PTStore()
10 | import pycat
11 | import os
12 | 
13 | 
14 | class IterativeSaliency:
15 |     def __init__(self, cuda=True, black_box_fn=None, mask_resolution=32, num_classes=1000, default_iterations=200):
16 |         # use the provided black-box classifier, or fall back to ResNet-50 on ImageNet
17 |         self.black_box_fn = black_box_fn if black_box_fn is not None else get_black_box_fn(cuda=cuda)
18 |         self.default_iterations = default_iterations
19 |         self.mask_resolution = mask_resolution
20 |         self.num_classes = num_classes
21 |         self.saliency_loss_calc = SaliencyLoss(self.black_box_fn, area_loss_coef=11, smoothness_loss_coef=0.5, preserver_loss_coef=0.2)
22 |         self.cuda = cuda
23 | 
24 |     def get_saliency_maps(self, _images, _targets, iterations=None, show=False):
25 |         ''' returns saliency maps.
26 |         Params
27 |         _images - input images of shape (C, H, W) or (N, C, H, W) if in batch. Can be either a numpy array, a Tensor or a Variable
28 |         _targets - class ids to be masked. Can be either an int or an array with N integers. Again can be either a numpy array, a Tensor or a Variable
29 | 
30 |         returns a Variable of shape (N, 1, H, W) with one saliency map for each input image.
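
        Example (an illustrative sketch mirroring test() at the bottom of this file; it assumes
        the default ResNet-50 ImageNet black box and that CUDA is available):

            explainer = IterativeSaliency(cuda=True)
            ims = load_image_as_variable('sal/utils/test.jpg')   # helper imported at the top of this file
            masks = explainer.get_saliency_maps(ims, [203])      # class 203: West Highland white terrier
            # `masks` has shape (N, 1, H, W) - one saliency map per input image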
31 | ''' 32 | _images, _targets = to_batch_variable(_images, 4, self.cuda).float(), to_batch_variable(_targets, 1, self.cuda).long() 33 | 34 | 35 | if iterations is None: 36 | iterations = self.default_iterations 37 | 38 | if self.cuda: 39 | _mask = nn.Parameter(torch.Tensor(_images.size(0), 2, self.mask_resolution, self.mask_resolution).fill_(0.5).cuda()) 40 | else: 41 | _mask = nn.Parameter(torch.Tensor(_images.size(0), 2, self.mask_resolution, self.mask_resolution).fill_(0.5)) 42 | optim = torch_optim.SGD([_mask], 0.1, 0.9, nesterov=True) 43 | #optim = torch_optim.Adam([_mask], 0.2) 44 | 45 | for iteration in xrange(iterations): 46 | #_mask.data.clamp_(0., 1.) 47 | optim.zero_grad() 48 | 49 | a = torch.abs(_mask[:, 0, :, :]) 50 | b = torch.abs(_mask[:, 1, :, :]) 51 | _mask_ = torch.unsqueeze(a / (a + b+0.001), dim=1) 52 | 53 | total_loss = self.saliency_loss_calc.get_loss(_images, _targets, _mask_, pt_store=PT) 54 | 55 | total_loss.backward() 56 | 57 | optim.step() 58 | if show: 59 | pycat.show(PT['masks'][0]*255, auto_normalize=False) 60 | pycat.show(PT['preserved'][0]) 61 | return PT.masks 62 | 63 | 64 | 65 | def test(): 66 | from PIL import Image 67 | import time 68 | print 'We will optimize for the white terrier class...' 69 | time.sleep(3) 70 | 71 | i = IterativeSaliency() 72 | 73 | ims = load_image_as_variable(os.path.join(os.path.dirname(__file__), 'sal/utils/test.jpg')) 74 | i.get_saliency_maps(ims, [203], show=True) 75 | 76 | 77 | if __name__ == '__main__': 78 | test() 79 | -------------------------------------------------------------------------------- /minsaliency/border_collie.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiotrDabkowski/pytorch-saliency/bfd501ec7888dbb3727494d06c71449df1530196/minsaliency/border_collie.jpg -------------------------------------------------------------------------------- /minsaliency/goose.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiotrDabkowski/pytorch-saliency/bfd501ec7888dbb3727494d06c71449df1530196/minsaliency/goose.jpg -------------------------------------------------------------------------------- /minsaliency/model-1.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiotrDabkowski/pytorch-saliency/bfd501ec7888dbb3727494d06c71449df1530196/minsaliency/model-1.ckpt -------------------------------------------------------------------------------- /minsaliency/training.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiotrDabkowski/pytorch-saliency/bfd501ec7888dbb3727494d06c71449df1530196/minsaliency/training.jpg -------------------------------------------------------------------------------- /real_time_saliency.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from Queue import Queue 3 | import time 4 | import numpy as np 5 | import threading 6 | from sal.datasets.imagenet_dataset import CLASS_ID_TO_NAME, CLASS_NAME_TO_ID 7 | import wx 8 | from torch.nn.functional import softmax 9 | import io 10 | from PIL import Image 11 | import textwrap 12 | 13 | import matplotlib 14 | matplotlib.use('Agg') 15 | import matplotlib.pyplot as plt 16 | 17 | TO_SHOW = 737 18 | CONFIDENCE = 5 19 | FAST_MODE = True 20 | POLL_DELAY = 0.01 21 | LOGITS = 1000*[0] 22 | 23 | SAVE_SIGNAL = False 24 | 25 | def 
numpy_to_wx(image): 26 | height, width, c = image.shape 27 | buffer = Image.fromarray(image).convert('RGB').tobytes() 28 | bitmap = wx.BitmapFromBuffer(width, height, buffer) 29 | return bitmap 30 | 31 | class RealTimeSaliency(wx.Frame): 32 | # ---------------------------------------------------------------------- 33 | def __init__(self): 34 | wx.Frame.__init__(self, None, wx.ID_ANY, "Real-time saliency", size=(1100, 800)) 35 | 36 | self.SetMinClientSize((600, 400)) 37 | self.on_update = None 38 | 39 | panel = wx.Panel(self, wx.ID_ANY) 40 | self.img_viewer = ImgViewPanel(self) 41 | self.cls_viewer = ImgViewPanel(self) 42 | 43 | 44 | 45 | self.index = 0 46 | self.list_ctrl = wx.ListCtrl(panel, 47 | style=wx.LC_REPORT) 48 | self.search_ctrl = wx.TextCtrl(panel, value='Search', size=(200, 25)) 49 | self.search_ctrl.Bind(wx.EVT_TEXT, self.on_search) 50 | 51 | self.list_ctrl.InsertColumn(0, 'Class name', width=200) 52 | 53 | self.static_img_picker = wx.FilePickerCtrl(panel) 54 | self.static_img_picker.SetPath('Static img (optional)') 55 | 56 | self.slider_ctrl = wx.Slider(panel, value=4, minValue=-2, maxValue=11, style=wx.SL_MIN_MAX_LABELS|wx.SL_VALUE_LABEL) 57 | self.slider_ctrl.Bind(wx.EVT_SCROLL, self.on_slide) 58 | self.info = wx.StaticText(panel) 59 | self.info_ = wx.StaticText(panel, label='Confidence:') 60 | 61 | self.show_items_that_contain() 62 | 63 | btn = wx.Button(panel, label='Choose') 64 | btn.Bind(wx.EVT_BUTTON, self.choose_class) 65 | 66 | save_btn = wx.Button(panel, label='Save') 67 | save_btn.Bind(wx.EVT_BUTTON, self.save_imgs) 68 | 69 | hsizer = wx.BoxSizer(wx.HORIZONTAL) 70 | hsizer.Add(panel, 1, wx.ALL | wx.EXPAND, 5) 71 | hsizer.Add(self.img_viewer, 2, wx.ALL | wx.EXPAND, 5) 72 | 73 | self.SetSizer(hsizer) 74 | 75 | 76 | sizer = wx.BoxSizer(wx.VERTICAL) 77 | sizer.Add(self.static_img_picker, 0, wx.EXPAND, 5) 78 | sizer.Add(self.info, 0, wx.ALL | wx.EXPAND, 5) 79 | sizer.Add(self.info_, 0, wx.TOP | wx.LEFT | wx.EXPAND, 5) 80 | sizer.Add(self.slider_ctrl, 0, wx.EXPAND, 0) 81 | sizer.Add(self.list_ctrl, 3, wx.ALL | wx.EXPAND, 5) 82 | sizer.Add(self.search_ctrl, 0, wx.ALL | wx.EXPAND, 5) 83 | 84 | btn_h_sizer = wx.BoxSizer(wx.HORIZONTAL) 85 | btn_h_sizer.Add(btn, 0, wx.ALL | wx.CENTER, 5) 86 | btn_h_sizer.Add(save_btn, 0, wx.ALL | wx.CENTER, 5) 87 | sizer.Add(btn_h_sizer, 0, wx.ALL | wx.CENTER, 0) 88 | 89 | sizer.Add(self.cls_viewer, 2, wx.ALL | wx.EXPAND, 5) 90 | panel.SetSizer(sizer) 91 | wx.CallLater(100, self.update) 92 | 93 | def save_imgs(self, event): 94 | global SAVE_SIGNAL 95 | SAVE_SIGNAL = True 96 | 97 | def on_slide(self, event): 98 | global CONFIDENCE 99 | CONFIDENCE = self.slider_ctrl.GetValue() 100 | 101 | def get_img(self): 102 | static_img_path = self.static_img_picker.GetPath() 103 | if static_img_path: 104 | try: 105 | img = cv2.imread(static_img_path) 106 | img = np.flip(img, 2) 107 | except: 108 | return None 109 | return img 110 | return None 111 | 112 | def update(self): 113 | self.info.SetLabel('Showing: %s (logits: %f)' % (CLASS_ID_TO_NAME[TO_SHOW], LOGITS[TO_SHOW])) 114 | if self.on_update is not None: 115 | self.on_update() 116 | wx.CallLater(100, self.update) 117 | 118 | def on_search(self, event): 119 | self.show_items_that_contain(self.search_ctrl.GetValue()) 120 | 121 | def show_items_that_contain(self, text=''): 122 | self.list_ctrl.DeleteAllItems() 123 | i = 0 124 | for name in CLASS_ID_TO_NAME.values(): 125 | if text.lower().strip() in name.lower(): 126 | self.list_ctrl.InsertItem(i, name) 127 | i += 1 128 | 129 | def choose_class(self, 
event):
130 |         global TO_SHOW
131 |         TO_SHOW = CLASS_NAME_TO_ID[self.list_ctrl.GetItem(self.list_ctrl.GetFocusedItem()).GetText()]
132 | 
133 | 
134 | class ImgViewPanel(wx.Panel):
135 |     def __init__(self, parent):
136 |         self.parent = parent
137 |         self.dialog_init_function = False
138 |         self.dialog_out = False
139 |         super(ImgViewPanel, self).__init__(parent, -1)
140 |         self.SetBackgroundStyle(wx.BG_STYLE_CUSTOM)
141 |         self.Bind(wx.EVT_PAINT, self.on_paint)
142 |         self.img_now = None
143 |         self.changed = False
144 |         self.change_frame(np.zeros((10, 10, 3), dtype=np.uint8))
145 |         self.update()
146 | 
147 |     def bind_mouse(self, click, move, release, scroll):
148 |         self.Bind(wx.EVT_LEFT_DOWN, click)
149 |         self.Bind(wx.EVT_MOTION, move)
150 |         self.Bind(wx.EVT_LEFT_UP, release)
151 |         self.Bind(wx.EVT_MOUSEWHEEL, scroll)
152 | 
153 |     def update(self):
154 |         if not self.changed:
155 |             self.changed = True
156 |             self.Refresh()
157 |             self.Update()
158 |         if self.dialog_init_function:
159 |             try:
160 |                 # dialog_init_function must be a function that takes the parent as arg
161 |                 # and returns a wx dialog object
162 |                 dialog = self.dialog_init_function(self)
163 |                 dialog.ShowModal()
164 |                 self.dialog_out = dialog.GetValue()
165 |                 self.dialog_out = self.dialog_out if self.dialog_out else True
166 |                 dialog.Destroy()
167 | 
168 |             except Exception:
169 |                 print 'Could not open the dialog!'
170 |             self.dialog_init_function = False
171 |         wx.CallLater(15, self.update)
172 | 
173 |     def on_paint(self, event):
174 |         dc = wx.AutoBufferedPaintDC(self)
175 |         dc.DrawBitmap(self.img_now, 0, 0)
176 | 
177 |     def change_frame(self, image):
178 |         '''image must be a numpy array of shape (H, W, 3), dtype uint8'''
179 |         s = self.GetSize()
180 |         x = s.x
181 |         y = s.y
182 |         image = cv2.resize(image, (x, y), interpolation=cv2.INTER_LINEAR)
183 |         self.img_now = numpy_to_wx(image)
184 | 
185 |         self.changed = False
186 | 
187 | 
188 | class RT:
189 |     DELAY_SMOOTH = 0.85
190 |     def __init__(self, processor, batch_size=1, view_size=(324, 324)):
191 |         '''
192 |         How it works:
193 |         The images are continuously captured and added to the queue together with their frame timestamps.
194 |         Another thread processes the images in the queue by passing them to the processor function.
195 |         The processor function operates on batches of images: it takes a numpy array of shape (batch_size, H, W, 3) where (H, W) is the native resolution of the camera
196 |         and must return a numpy array of shape (batch_size, ?, ?, 3) - note that the size of the returned image does not matter as it will be resized to the view_size anyway.
197 |         Images are normalized between -1 and 1!
198 |         If the queue grows faster than we can process the images, we skip images and the processor function is given only
199 |         every Nth image in the outstanding queue.
200 | 
201 |         Finally, the processed images are placed on the display queue together with their timestamps, and the display thread takes care of displaying images
202 |         at the correct times by estimating the current overall delay and fps.
203 |         For example, if the processor takes 1 second to process one image then the delay will be 1 second and the resulting fps will be 1.
204 |         '''
205 |         self.batch_size = batch_size
206 |         self.processor = processor
207 |         self.cam = None
208 |         self.req_queue = Queue()
209 |         self.display_queue = Queue()
210 |         self.delay = 0.
211 |         self.time_per_frame = 0.
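        # Note: `delay` and `time_per_frame` above are exponential moving averages maintained by
        # the display thread - `delay` (smoothed with DELAY_SMOOTH) estimates the capture-to-display
        # latency and decides when each processed frame should be shown, while `time_per_frame`
        # backs the `fps` property (fps = 1. / time_per_frame).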
212 | self.show_image = None 213 | self.get_custom_rgb_img = None 214 | 215 | 216 | def start(self): 217 | if self.cam is None: 218 | self.cam = cv2.VideoCapture(0) 219 | self._stop = False 220 | # start transformer and display services 221 | tr = threading.Thread(target=self.transform) 222 | dis = threading.Thread(target=self.display) 223 | tr.daemon = True 224 | dis.daemon = True 225 | tr.start() 226 | dis.start() 227 | self._get_next_frame() 228 | 229 | def _get_next_frame(self): 230 | if self.get_custom_rgb_img: 231 | img = self.get_custom_rgb_img() 232 | else: 233 | img = None 234 | if img is None: 235 | ret_val, img = self.cam.read() 236 | img = np.flip(img, 1) 237 | # remember, remember to switch to RGB! 238 | img = np.flip(img, 2) 239 | 240 | self.req_queue.put((time.time(), img)) 241 | 242 | def stop(self): 243 | self._stop = True 244 | time.sleep(1.) 245 | 246 | def transform(self): 247 | while not self._stop: 248 | if self.req_queue.qsize()< self.batch_size: 249 | time.sleep(POLL_DELAY) 250 | continue 251 | to_proc = [] 252 | while not self.req_queue.empty(): 253 | to_proc.append(self.req_queue.get(timeout=0.1)) 254 | 255 | if len(to_proc) > self.batch_size: 256 | # usual case, take self.batch_size equally separated items 257 | sep = int(len(to_proc) / self.batch_size) 258 | old = len(to_proc) 259 | to_proc = to_proc[:sep*self.batch_size:sep] 260 | assert len(to_proc) == self.batch_size 261 | 262 | imgs = np.concatenate(tuple(np.expand_dims(e[1], 0) for e in to_proc), 0) 263 | done_imgs = ((self.processor(imgs/(255./2) - 1.) + 1) * (255./2.)).astype(np.uint8) 264 | 265 | for e in xrange(len(done_imgs)): 266 | im = done_imgs[e] 267 | t = to_proc[e][0] 268 | self.display_queue.put((t, im)) 269 | 270 | def display(self): 271 | last_frame = time.time() 272 | while not self._stop: 273 | if self.display_queue.empty(): 274 | time.sleep(POLL_DELAY) 275 | continue 276 | creation_time, im = self.display_queue.get(timeout=11) 277 | self.delay = self.DELAY_SMOOTH*self.delay + (1.-self.DELAY_SMOOTH)*(time.time() - creation_time) 278 | while time.time() < creation_time + self.delay: 279 | time.sleep(POLL_DELAY) 280 | self.time_per_frame = 0.9*self.time_per_frame + 0.1*(time.time() - last_frame) 281 | if self.show_image is not None: 282 | self.show_image(im) 283 | last_frame = time.time() 284 | 285 | @property 286 | def fps(self): 287 | return 1./self.time_per_frame 288 | 289 | 290 | 291 | from saliency_eval import get_pretrained_saliency_fn 292 | def get_proc_fn(cuda=False): 293 | fn = get_pretrained_saliency_fn(cuda=cuda, return_classification_logits=True) 294 | 295 | def proc(ims): 296 | global LOGITS 297 | print ims.shape 298 | if FAST_MODE: 299 | sq = square_centrer_crop_resize_op(np.squeeze(ims, 0), (224, 224)) 300 | else: 301 | sq = cv2.resize(np.squeeze(ims, 0), (288, 512), interpolation=cv2.INTER_AREA) 302 | sq = np.transpose(sq, (2, 0, 1)) 303 | mask, cls = fn(sq, TO_SHOW, CONFIDENCE) 304 | 305 | mask = mask[0].cpu().data.numpy() 306 | LOGITS = cls[0].cpu().data.numpy() 307 | probs = softmax(cls)[0].cpu().data.numpy() 308 | 309 | 310 | cls_im = get_probs_np_img(probs) 311 | frame.cls_viewer.change_frame(cls_im) 312 | 313 | 314 | if SAVE_SIGNAL: 315 | global SAVE_SIGNAL 316 | save_img(sq, 'original') 317 | save_img(sq*(1-mask), 'destroyed') 318 | save_img(sq*mask, 'preserved') 319 | save_img(2*np.concatenate((0.*mask, 0.*mask, mask), 0)-1, 'mask') 320 | SAVE_SIGNAL = False 321 | 322 | mask = mask/2. 
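        # Build the frame sent back to the viewer: attenuate the image with the (halved) mask and
        # tint the salient regions red by adding the mask to the R channel and subtracting it
        # from G and B (channels-first layout, pixel values roughly in the [-1, 1] range).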
323 | sq = sq*(1-mask) + np.concatenate((mask, -mask, -mask), 0) 324 | return np.expand_dims(np.transpose(sq, (1, 2, 0)), 0) 325 | return proc 326 | 327 | def square_centrer_crop_resize_op(im, size): 328 | short_edge = min(im.shape[:2]) 329 | yy = int((im.shape[0] - short_edge) / 2) 330 | xx = int((im.shape[1] - short_edge) / 2) 331 | max_square = im[yy: yy + short_edge, xx: xx + short_edge] 332 | return cv2.resize(max_square, size, interpolation=cv2.INTER_AREA) 333 | 334 | 335 | from sal.datasets import imagenet_dataset 336 | def get_probs_np_img(probs, num=6): 337 | top_k, top_k_probs = np.argsort(probs)[::-1][:num], np.sort(probs)[::-1][:num] 338 | print top_k, top_k_probs 339 | # top_k = [162, 463, 281, 178, 181, 596] 340 | # top_k_probs = [ 0.20559976, 0.10194654, 0.0338834, 0.03120151, 0.02777977, 0.02264762] 341 | objects = ['\n'.join(textwrap.wrap(imagenet_dataset.CLASS_ID_TO_NAME[e], 15)[:2]) for e in top_k] 342 | y_pos = np.arange(len(objects)) 343 | performance = list(top_k_probs) 344 | 345 | plt.figure() 346 | plt.bar(y_pos, performance, align='center', alpha=0.8, color='blue') 347 | plt.ylim([0.,1.]) 348 | plt.xticks(y_pos, objects, rotation='vertical') 349 | 350 | plt.ylabel('Probability') 351 | plt.tight_layout() 352 | 353 | buf = io.BytesIO() 354 | plt.savefig(buf, format='png') 355 | buf.seek(0) 356 | 357 | im = Image.open(buf) 358 | plt.close() 359 | 360 | return np.array(im.convert('RGB')) 361 | 362 | 363 | def save_img(img, name): 364 | img = np.transpose(img, [1,2,0]) 365 | img = np.flip(img, 2) 366 | cv2.imwrite((name+'.png') if '.' not in name else name, ((img + 1) * (255. / 2.)).astype(np.uint8)) 367 | 368 | 369 | if __name__ == "__main__": 370 | app = wx.App(False) 371 | frame = RealTimeSaliency() 372 | a = RT(get_proc_fn(cuda=False)) 373 | a.start() 374 | a.show_image = frame.img_viewer.change_frame 375 | a.get_custom_rgb_img = frame.get_img 376 | frame.on_update = a._get_next_frame 377 | frame.Show() 378 | app.MainLoop() 379 | 380 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pycat-real 2 | wxPython 3 | torchvision 4 | matplotlib -------------------------------------------------------------------------------- /sal/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiotrDabkowski/pytorch-saliency/bfd501ec7888dbb3727494d06c71449df1530196/sal/__init__.py -------------------------------------------------------------------------------- /sal/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import cifar_dataset 2 | 3 | SUPPORTED_DATASETS = { 4 | 'cifar10': cifar_dataset 5 | } -------------------------------------------------------------------------------- /sal/datasets/cifar_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.transforms import * 3 | from torchvision.datasets import CIFAR10 4 | from torch.utils.data import dataloader 5 | import time, random 6 | from ..utils.pytorch_fixes import * 7 | 8 | 9 | SUGGESTED_BS = 512 10 | SUGGESTED_EPOCHS_PER_STEP = 22 11 | SUGGESTED_BASE = 32 12 | NUM_CLASSES = 10 13 | CLASSES = '''airplane airplane 14 | automobile automobile 15 | bird bird 16 | cat cat 17 | deer deer 18 | dog dog 19 | frog frog 20 | horse horse 21 | ship ship 22 | truck truck'''.splitlines() 23 | 24 | def 
get_train_dataset(size=32): 25 | return CIFAR10('/home/piter/CIFAR10Dataset', train=True, transform=Compose( 26 | [Scale(size), RandomSizedCrop2(size, min_area=0.5), RandomHorizontalFlip(), ToTensor(), STD_NORMALIZE]), download=True) 27 | 28 | 29 | 30 | def get_val_dataset(size=32): 31 | return CIFAR10('/home/piter/CIFAR10Dataset', train=False, transform=Compose( 32 | [Scale(size), CenterCrop(size), ToTensor(), STD_NORMALIZE]), download=True) 33 | 34 | 35 | def get_loader(dataset, batch_size=64, pin_memory=True): 36 | return dataloader.DataLoader(dataset=dataset, batch_size=batch_size, 37 | shuffle=True, drop_last=True, num_workers=8, pin_memory=True) 38 | 39 | 40 | def test(): 41 | BS = 64 42 | SAMP = 20 43 | dts = get_train_dataset() 44 | loader = get_loader(dts, batch_size=BS) 45 | i = 0 46 | t = time.time() 47 | for ims, labs in loader: 48 | i+=1 49 | if not i%20: 50 | print "Images per second:", SAMP*BS/(time.time()-t) 51 | pycat.show(ims[0].numpy()) 52 | t = time.time() 53 | if i==100: 54 | break -------------------------------------------------------------------------------- /sal/datasets/imagenet_dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import * 2 | from torchvision.datasets import ImageFolder 3 | from torch.utils.data import dataloader 4 | import pycat, time, random 5 | from ..utils.pytorch_fixes import * 6 | import os 7 | 8 | # Images must be segregated in folders by class! So both train and val folders should contain 1000 folders, one for each class. 9 | # PLEASE EDIT THESE 2 LINES: 10 | IMAGE_NET_TRAIN_PATH = '/home/piter/ImageNetFull/train/' 11 | IMAGE_NET_VAL_PATH = '/home/piter/ImageNetFull/val/' 12 | 13 | 14 | #----------------------------------------------------- 15 | 16 | SUGGESTED_BS = 128 17 | NUM_CLASSES = 1000 18 | SUGGESTED_EPOCHS_PER_STEP = 11 19 | SUGGESTED_BASE = 64 20 | 21 | 22 | def get_train_dataset(size=224): 23 | if not (os.path.exists(IMAGE_NET_TRAIN_PATH) and os.path.exists(IMAGE_NET_VAL_PATH)): 24 | raise ValueError( 25 | 'Please make sure that you specify a path to the ImageNet dataset folder in sal/datasets/imagenet_dataset.py file!') 26 | return ImageFolder(IMAGE_NET_TRAIN_PATH, transform=Compose([ 27 | RandomSizedCrop2(size, min_area=0.3), 28 | RandomHorizontalFlip(), 29 | ToTensor(), 30 | STD_NORMALIZE, # Images will be in range -1 to 1 31 | ])) 32 | 33 | 34 | def get_val_dataset(size=224): 35 | if not (os.path.exists(IMAGE_NET_TRAIN_PATH) and os.path.exists(IMAGE_NET_VAL_PATH)): 36 | raise ValueError( 37 | 'Please make sure that you specify a path to the ImageNet dataset folder in sal/datasets/imagenet_dataset.py file!') 38 | return ImageFolder(IMAGE_NET_VAL_PATH, transform=Compose([ 39 | Scale(224), 40 | CenterCrop(size), 41 | ToTensor(), 42 | STD_NORMALIZE, 43 | ])) 44 | 45 | def get_loader(dataset, batch_size=64, pin_memory=True): 46 | return dataloader.DataLoader(dataset=dataset, batch_size=batch_size, 47 | shuffle=True, drop_last=True, num_workers=8, pin_memory=pin_memory) 48 | 49 | 50 | def test(): 51 | BS = 64 52 | SAMP = 20 53 | dts = get_val_dataset() 54 | loader = get_loader(dts, batch_size=BS) 55 | i = 0 56 | t = time.time() 57 | for ims, labs in loader: 58 | i+=1 59 | if not i%20: 60 | print 'min', torch.min(ims),'max', torch.max(ims), 'var', torch.var(ims), 'mean', torch.mean(ims) 61 | print "Images per second:", SAMP*BS/(time.time()-t) 62 | pycat.show(ims[0].numpy()) 63 | t = time.time() 64 | if i==100: 65 | break 66 | 67 | 68 | 69 | from imagenet_synset 
import synset 70 | SYNSET_TO_NAME= dict((e[:9], e[10:]) for e in synset.splitlines()) 71 | SYNSET_TO_CLASS_ID = dict((e[:9], i) for i, e in enumerate(synset.splitlines())) 72 | 73 | CLASS_ID_TO_SYNSET = {v:k for k,v in SYNSET_TO_CLASS_ID.items()} 74 | CLASS_ID_TO_NAME = {i:SYNSET_TO_NAME[CLASS_ID_TO_SYNSET[i]] for i in CLASS_ID_TO_SYNSET} 75 | CLASS_NAME_TO_ID = {v:k for k, v in CLASS_ID_TO_NAME.items()} 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /sal/datasets/imagenet_synset.py: -------------------------------------------------------------------------------- 1 | synset = '''n01440764 tench, Tinca tinca 2 | n01443537 goldfish, Carassius auratus 3 | n01484850 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias 4 | n01491361 tiger shark, Galeocerdo cuvieri 5 | n01494475 hammerhead, hammerhead shark 6 | n01496331 electric ray, crampfish, numbfish, torpedo 7 | n01498041 stingray 8 | n01514668 cock 9 | n01514859 hen 10 | n01518878 ostrich, Struthio camelus 11 | n01530575 brambling, Fringilla montifringilla 12 | n01531178 goldfinch, Carduelis carduelis 13 | n01532829 house finch, linnet, Carpodacus mexicanus 14 | n01534433 junco, snowbird 15 | n01537544 indigo bunting, indigo finch, indigo bird, Passerina cyanea 16 | n01558993 robin, American robin, Turdus migratorius 17 | n01560419 bulbul 18 | n01580077 jay 19 | n01582220 magpie 20 | n01592084 chickadee 21 | n01601694 water ouzel, dipper 22 | n01608432 kite 23 | n01614925 bald eagle, American eagle, Haliaeetus leucocephalus 24 | n01616318 vulture 25 | n01622779 great grey owl, great gray owl, Strix nebulosa 26 | n01629819 European fire salamander, Salamandra salamandra 27 | n01630670 common newt, Triturus vulgaris 28 | n01631663 eft 29 | n01632458 spotted salamander, Ambystoma maculatum 30 | n01632777 axolotl, mud puppy, Ambystoma mexicanum 31 | n01641577 bullfrog, Rana catesbeiana 32 | n01644373 tree frog, tree-frog 33 | n01644900 tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui 34 | n01664065 loggerhead, loggerhead turtle, Caretta caretta 35 | n01665541 leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea 36 | n01667114 mud turtle 37 | n01667778 terrapin 38 | n01669191 box turtle, box tortoise 39 | n01675722 banded gecko 40 | n01677366 common iguana, iguana, Iguana iguana 41 | n01682714 American chameleon, anole, Anolis carolinensis 42 | n01685808 whiptail, whiptail lizard 43 | n01687978 agama 44 | n01688243 frilled lizard, Chlamydosaurus kingi 45 | n01689811 alligator lizard 46 | n01692333 Gila monster, Heloderma suspectum 47 | n01693334 green lizard, Lacerta viridis 48 | n01694178 African chameleon, Chamaeleo chamaeleon 49 | n01695060 Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis 50 | n01697457 African crocodile, Nile crocodile, Crocodylus niloticus 51 | n01698640 American alligator, Alligator mississipiensis 52 | n01704323 triceratops 53 | n01728572 thunder snake, worm snake, Carphophis amoenus 54 | n01728920 ringneck snake, ring-necked snake, ring snake 55 | n01729322 hognose snake, puff adder, sand viper 56 | n01729977 green snake, grass snake 57 | n01734418 king snake, kingsnake 58 | n01735189 garter snake, grass snake 59 | n01737021 water snake 60 | n01739381 vine snake 61 | n01740131 night snake, Hypsiglena torquata 62 | n01742172 boa constrictor, Constrictor constrictor 63 | n01744401 rock python, rock snake, Python sebae 64 | n01748264 Indian cobra, Naja 
naja 65 | n01749939 green mamba 66 | n01751748 sea snake 67 | n01753488 horned viper, cerastes, sand viper, horned asp, Cerastes cornutus 68 | n01755581 diamondback, diamondback rattlesnake, Crotalus adamanteus 69 | n01756291 sidewinder, horned rattlesnake, Crotalus cerastes 70 | n01768244 trilobite 71 | n01770081 harvestman, daddy longlegs, Phalangium opilio 72 | n01770393 scorpion 73 | n01773157 black and gold garden spider, Argiope aurantia 74 | n01773549 barn spider, Araneus cavaticus 75 | n01773797 garden spider, Aranea diademata 76 | n01774384 black widow, Latrodectus mactans 77 | n01774750 tarantula 78 | n01775062 wolf spider, hunting spider 79 | n01776313 tick 80 | n01784675 centipede 81 | n01795545 black grouse 82 | n01796340 ptarmigan 83 | n01797886 ruffed grouse, partridge, Bonasa umbellus 84 | n01798484 prairie chicken, prairie grouse, prairie fowl 85 | n01806143 peacock 86 | n01806567 quail 87 | n01807496 partridge 88 | n01817953 African grey, African gray, Psittacus erithacus 89 | n01818515 macaw 90 | n01819313 sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita 91 | n01820546 lorikeet 92 | n01824575 coucal 93 | n01828970 bee eater 94 | n01829413 hornbill 95 | n01833805 hummingbird 96 | n01843065 jacamar 97 | n01843383 toucan 98 | n01847000 drake 99 | n01855032 red-breasted merganser, Mergus serrator 100 | n01855672 goose 101 | n01860187 black swan, Cygnus atratus 102 | n01871265 tusker 103 | n01872401 echidna, spiny anteater, anteater 104 | n01873310 platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus 105 | n01877812 wallaby, brush kangaroo 106 | n01882714 koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus 107 | n01883070 wombat 108 | n01910747 jellyfish 109 | n01914609 sea anemone, anemone 110 | n01917289 brain coral 111 | n01924916 flatworm, platyhelminth 112 | n01930112 nematode, nematode worm, roundworm 113 | n01943899 conch 114 | n01944390 snail 115 | n01945685 slug 116 | n01950731 sea slug, nudibranch 117 | n01955084 chiton, coat-of-mail shell, sea cradle, polyplacophore 118 | n01968897 chambered nautilus, pearly nautilus, nautilus 119 | n01978287 Dungeness crab, Cancer magister 120 | n01978455 rock crab, Cancer irroratus 121 | n01980166 fiddler crab 122 | n01981276 king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica 123 | n01983481 American lobster, Northern lobster, Maine lobster, Homarus americanus 124 | n01984695 spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish 125 | n01985128 crayfish, crawfish, crawdad, crawdaddy 126 | n01986214 hermit crab 127 | n01990800 isopod 128 | n02002556 white stork, Ciconia ciconia 129 | n02002724 black stork, Ciconia nigra 130 | n02006656 spoonbill 131 | n02007558 flamingo 132 | n02009229 little blue heron, Egretta caerulea 133 | n02009912 American egret, great white heron, Egretta albus 134 | n02011460 bittern 135 | n02012849 crane 136 | n02013706 limpkin, Aramus pictus 137 | n02017213 European gallinule, Porphyrio porphyrio 138 | n02018207 American coot, marsh hen, mud hen, water hen, Fulica americana 139 | n02018795 bustard 140 | n02025239 ruddy turnstone, Arenaria interpres 141 | n02027492 red-backed sandpiper, dunlin, Erolia alpina 142 | n02028035 redshank, Tringa totanus 143 | n02033041 dowitcher 144 | n02037110 oystercatcher, oyster catcher 145 | n02051845 pelican 146 | n02056570 king penguin, Aptenodytes patagonica 147 | n02058221 albatross, mollymawk 148 | n02066245 grey whale, gray whale, devilfish, 
Eschrichtius gibbosus, Eschrichtius robustus 149 | n02071294 killer whale, killer, orca, grampus, sea wolf, Orcinus orca 150 | n02074367 dugong, Dugong dugon 151 | n02077923 sea lion 152 | n02085620 Chihuahua 153 | n02085782 Japanese spaniel 154 | n02085936 Maltese dog, Maltese terrier, Maltese 155 | n02086079 Pekinese, Pekingese, Peke 156 | n02086240 Shih-Tzu 157 | n02086646 Blenheim spaniel 158 | n02086910 papillon 159 | n02087046 toy terrier 160 | n02087394 Rhodesian ridgeback 161 | n02088094 Afghan hound, Afghan 162 | n02088238 basset, basset hound 163 | n02088364 beagle 164 | n02088466 bloodhound, sleuthhound 165 | n02088632 bluetick 166 | n02089078 black-and-tan coonhound 167 | n02089867 Walker hound, Walker foxhound 168 | n02089973 English foxhound 169 | n02090379 redbone 170 | n02090622 borzoi, Russian wolfhound 171 | n02090721 Irish wolfhound 172 | n02091032 Italian greyhound 173 | n02091134 whippet 174 | n02091244 Ibizan hound, Ibizan Podenco 175 | n02091467 Norwegian elkhound, elkhound 176 | n02091635 otterhound, otter hound 177 | n02091831 Saluki, gazelle hound 178 | n02092002 Scottish deerhound, deerhound 179 | n02092339 Weimaraner 180 | n02093256 Staffordshire bullterrier, Staffordshire bull terrier 181 | n02093428 American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier 182 | n02093647 Bedlington terrier 183 | n02093754 Border terrier 184 | n02093859 Kerry blue terrier 185 | n02093991 Irish terrier 186 | n02094114 Norfolk terrier 187 | n02094258 Norwich terrier 188 | n02094433 Yorkshire terrier 189 | n02095314 wire-haired fox terrier 190 | n02095570 Lakeland terrier 191 | n02095889 Sealyham terrier, Sealyham 192 | n02096051 Airedale, Airedale terrier 193 | n02096177 cairn, cairn terrier 194 | n02096294 Australian terrier 195 | n02096437 Dandie Dinmont, Dandie Dinmont terrier 196 | n02096585 Boston bull, Boston terrier 197 | n02097047 miniature schnauzer 198 | n02097130 giant schnauzer 199 | n02097209 standard schnauzer 200 | n02097298 Scotch terrier, Scottish terrier, Scottie 201 | n02097474 Tibetan terrier, chrysanthemum dog 202 | n02097658 silky terrier, Sydney silky 203 | n02098105 soft-coated wheaten terrier 204 | n02098286 West Highland white terrier 205 | n02098413 Lhasa, Lhasa apso 206 | n02099267 flat-coated retriever 207 | n02099429 curly-coated retriever 208 | n02099601 golden retriever 209 | n02099712 Labrador retriever 210 | n02099849 Chesapeake Bay retriever 211 | n02100236 German short-haired pointer 212 | n02100583 vizsla, Hungarian pointer 213 | n02100735 English setter 214 | n02100877 Irish setter, red setter 215 | n02101006 Gordon setter 216 | n02101388 Brittany spaniel 217 | n02101556 clumber, clumber spaniel 218 | n02102040 English springer, English springer spaniel 219 | n02102177 Welsh springer spaniel 220 | n02102318 cocker spaniel, English cocker spaniel, cocker 221 | n02102480 Sussex spaniel 222 | n02102973 Irish water spaniel 223 | n02104029 kuvasz 224 | n02104365 schipperke 225 | n02105056 groenendael 226 | n02105162 malinois 227 | n02105251 briard 228 | n02105412 kelpie 229 | n02105505 komondor 230 | n02105641 Old English sheepdog, bobtail 231 | n02105855 Shetland sheepdog, Shetland sheep dog, Shetland 232 | n02106030 collie 233 | n02106166 Border collie 234 | n02106382 Bouvier des Flandres, Bouviers des Flandres 235 | n02106550 Rottweiler 236 | n02106662 German shepherd, German shepherd dog, German police dog, alsatian 237 | n02107142 Doberman, Doberman pinscher 238 | n02107312 miniature pinscher 239 | 
n02107574 Greater Swiss Mountain dog 240 | n02107683 Bernese mountain dog 241 | n02107908 Appenzeller 242 | n02108000 EntleBucher 243 | n02108089 boxer 244 | n02108422 bull mastiff 245 | n02108551 Tibetan mastiff 246 | n02108915 French bulldog 247 | n02109047 Great Dane 248 | n02109525 Saint Bernard, St Bernard 249 | n02109961 Eskimo dog, husky 250 | n02110063 malamute, malemute, Alaskan malamute 251 | n02110185 Siberian husky 252 | n02110341 dalmatian, coach dog, carriage dog 253 | n02110627 affenpinscher, monkey pinscher, monkey dog 254 | n02110806 basenji 255 | n02110958 pug, pug-dog 256 | n02111129 Leonberg 257 | n02111277 Newfoundland, Newfoundland dog 258 | n02111500 Great Pyrenees 259 | n02111889 Samoyed, Samoyede 260 | n02112018 Pomeranian 261 | n02112137 chow, chow chow 262 | n02112350 keeshond 263 | n02112706 Brabancon griffon 264 | n02113023 Pembroke, Pembroke Welsh corgi 265 | n02113186 Cardigan, Cardigan Welsh corgi 266 | n02113624 toy poodle 267 | n02113712 miniature poodle 268 | n02113799 standard poodle 269 | n02113978 Mexican hairless 270 | n02114367 timber wolf, grey wolf, gray wolf, Canis lupus 271 | n02114548 white wolf, Arctic wolf, Canis lupus tundrarum 272 | n02114712 red wolf, maned wolf, Canis rufus, Canis niger 273 | n02114855 coyote, prairie wolf, brush wolf, Canis latrans 274 | n02115641 dingo, warrigal, warragal, Canis dingo 275 | n02115913 dhole, Cuon alpinus 276 | n02116738 African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus 277 | n02117135 hyena, hyaena 278 | n02119022 red fox, Vulpes vulpes 279 | n02119789 kit fox, Vulpes macrotis 280 | n02120079 Arctic fox, white fox, Alopex lagopus 281 | n02120505 grey fox, gray fox, Urocyon cinereoargenteus 282 | n02123045 tabby, tabby cat 283 | n02123159 tiger cat 284 | n02123394 Persian cat 285 | n02123597 Siamese cat, Siamese 286 | n02124075 Egyptian cat 287 | n02125311 cougar, puma, catamount, mountain lion, painter, panther, Felis concolor 288 | n02127052 lynx, catamount 289 | n02128385 leopard, Panthera pardus 290 | n02128757 snow leopard, ounce, Panthera uncia 291 | n02128925 jaguar, panther, Panthera onca, Felis onca 292 | n02129165 lion, king of beasts, Panthera leo 293 | n02129604 tiger, Panthera tigris 294 | n02130308 cheetah, chetah, Acinonyx jubatus 295 | n02132136 brown bear, bruin, Ursus arctos 296 | n02133161 American black bear, black bear, Ursus americanus, Euarctos americanus 297 | n02134084 ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus 298 | n02134418 sloth bear, Melursus ursinus, Ursus ursinus 299 | n02137549 mongoose 300 | n02138441 meerkat, mierkat 301 | n02165105 tiger beetle 302 | n02165456 ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle 303 | n02167151 ground beetle, carabid beetle 304 | n02168699 long-horned beetle, longicorn, longicorn beetle 305 | n02169497 leaf beetle, chrysomelid 306 | n02172182 dung beetle 307 | n02174001 rhinoceros beetle 308 | n02177972 weevil 309 | n02190166 fly 310 | n02206856 bee 311 | n02219486 ant, emmet, pismire 312 | n02226429 grasshopper, hopper 313 | n02229544 cricket 314 | n02231487 walking stick, walkingstick, stick insect 315 | n02233338 cockroach, roach 316 | n02236044 mantis, mantid 317 | n02256656 cicada, cicala 318 | n02259212 leafhopper 319 | n02264363 lacewing, lacewing fly 320 | n02268443 dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk 321 | n02268853 damselfly 322 | n02276258 admiral 323 | n02277742 ringlet, ringlet butterfly 324 | n02279972 
monarch, monarch butterfly, milkweed butterfly, Danaus plexippus 325 | n02280649 cabbage butterfly 326 | n02281406 sulphur butterfly, sulfur butterfly 327 | n02281787 lycaenid, lycaenid butterfly 328 | n02317335 starfish, sea star 329 | n02319095 sea urchin 330 | n02321529 sea cucumber, holothurian 331 | n02325366 wood rabbit, cottontail, cottontail rabbit 332 | n02326432 hare 333 | n02328150 Angora, Angora rabbit 334 | n02342885 hamster 335 | n02346627 porcupine, hedgehog 336 | n02356798 fox squirrel, eastern fox squirrel, Sciurus niger 337 | n02361337 marmot 338 | n02363005 beaver 339 | n02364673 guinea pig, Cavia cobaya 340 | n02389026 sorrel 341 | n02391049 zebra 342 | n02395406 hog, pig, grunter, squealer, Sus scrofa 343 | n02396427 wild boar, boar, Sus scrofa 344 | n02397096 warthog 345 | n02398521 hippopotamus, hippo, river horse, Hippopotamus amphibius 346 | n02403003 ox 347 | n02408429 water buffalo, water ox, Asiatic buffalo, Bubalus bubalis 348 | n02410509 bison 349 | n02412080 ram, tup 350 | n02415577 bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis 351 | n02417914 ibex, Capra ibex 352 | n02422106 hartebeest 353 | n02422699 impala, Aepyceros melampus 354 | n02423022 gazelle 355 | n02437312 Arabian camel, dromedary, Camelus dromedarius 356 | n02437616 llama 357 | n02441942 weasel 358 | n02442845 mink 359 | n02443114 polecat, fitch, foulmart, foumart, Mustela putorius 360 | n02443484 black-footed ferret, ferret, Mustela nigripes 361 | n02444819 otter 362 | n02445715 skunk, polecat, wood pussy 363 | n02447366 badger 364 | n02454379 armadillo 365 | n02457408 three-toed sloth, ai, Bradypus tridactylus 366 | n02480495 orangutan, orang, orangutang, Pongo pygmaeus 367 | n02480855 gorilla, Gorilla gorilla 368 | n02481823 chimpanzee, chimp, Pan troglodytes 369 | n02483362 gibbon, Hylobates lar 370 | n02483708 siamang, Hylobates syndactylus, Symphalangus syndactylus 371 | n02484975 guenon, guenon monkey 372 | n02486261 patas, hussar monkey, Erythrocebus patas 373 | n02486410 baboon 374 | n02487347 macaque 375 | n02488291 langur 376 | n02488702 colobus, colobus monkey 377 | n02489166 proboscis monkey, Nasalis larvatus 378 | n02490219 marmoset 379 | n02492035 capuchin, ringtail, Cebus capucinus 380 | n02492660 howler monkey, howler 381 | n02493509 titi, titi monkey 382 | n02493793 spider monkey, Ateles geoffroyi 383 | n02494079 squirrel monkey, Saimiri sciureus 384 | n02497673 Madagascar cat, ring-tailed lemur, Lemur catta 385 | n02500267 indri, indris, Indri indri, Indri brevicaudatus 386 | n02504013 Indian elephant, Elephas maximus 387 | n02504458 African elephant, Loxodonta africana 388 | n02509815 lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens 389 | n02510455 giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca 390 | n02514041 barracouta, snoek 391 | n02526121 eel 392 | n02536864 coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch 393 | n02606052 rock beauty, Holocanthus tricolor 394 | n02607072 anemone fish 395 | n02640242 sturgeon 396 | n02641379 gar, garfish, garpike, billfish, Lepisosteus osseus 397 | n02643566 lionfish 398 | n02655020 puffer, pufferfish, blowfish, globefish 399 | n02666196 abacus 400 | n02667093 abaya 401 | n02669723 academic gown, academic robe, judge's robe 402 | n02672831 accordion, piano accordion, squeeze box 403 | n02676566 acoustic guitar 404 | n02687172 aircraft carrier, carrier, flattop, attack aircraft carrier 405 | n02690373 airliner 406 | n02692877 
airship, dirigible 407 | n02699494 altar 408 | n02701002 ambulance 409 | n02704792 amphibian, amphibious vehicle 410 | n02708093 analog clock 411 | n02727426 apiary, bee house 412 | n02730930 apron 413 | n02747177 ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin 414 | n02749479 assault rifle, assault gun 415 | n02769748 backpack, back pack, knapsack, packsack, rucksack, haversack 416 | n02776631 bakery, bakeshop, bakehouse 417 | n02777292 balance beam, beam 418 | n02782093 balloon 419 | n02783161 ballpoint, ballpoint pen, ballpen, Biro 420 | n02786058 Band Aid 421 | n02787622 banjo 422 | n02788148 bannister, banister, balustrade, balusters, handrail 423 | n02790996 barbell 424 | n02791124 barber chair 425 | n02791270 barbershop 426 | n02793495 barn 427 | n02794156 barometer 428 | n02795169 barrel, cask 429 | n02797295 barrow, garden cart, lawn cart, wheelbarrow 430 | n02799071 baseball 431 | n02802426 basketball 432 | n02804414 bassinet 433 | n02804610 bassoon 434 | n02807133 bathing cap, swimming cap 435 | n02808304 bath towel 436 | n02808440 bathtub, bathing tub, bath, tub 437 | n02814533 beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon 438 | n02814860 beacon, lighthouse, beacon light, pharos 439 | n02815834 beaker 440 | n02817516 bearskin, busby, shako 441 | n02823428 beer bottle 442 | n02823750 beer glass 443 | n02825657 bell cote, bell cot 444 | n02834397 bib 445 | n02835271 bicycle-built-for-two, tandem bicycle, tandem 446 | n02837789 bikini, two-piece 447 | n02840245 binder, ring-binder 448 | n02841315 binoculars, field glasses, opera glasses 449 | n02843684 birdhouse 450 | n02859443 boathouse 451 | n02860847 bobsled, bobsleigh, bob 452 | n02865351 bolo tie, bolo, bola tie, bola 453 | n02869837 bonnet, poke bonnet 454 | n02870880 bookcase 455 | n02871525 bookshop, bookstore, bookstall 456 | n02877765 bottlecap 457 | n02879718 bow 458 | n02883205 bow tie, bow-tie, bowtie 459 | n02892201 brass, memorial tablet, plaque 460 | n02892767 brassiere, bra, bandeau 461 | n02894605 breakwater, groin, groyne, mole, bulwark, seawall, jetty 462 | n02895154 breastplate, aegis, egis 463 | n02906734 broom 464 | n02909870 bucket, pail 465 | n02910353 buckle 466 | n02916936 bulletproof vest 467 | n02917067 bullet train, bullet 468 | n02927161 butcher shop, meat market 469 | n02930766 cab, hack, taxi, taxicab 470 | n02939185 caldron, cauldron 471 | n02948072 candle, taper, wax light 472 | n02950826 cannon 473 | n02951358 canoe 474 | n02951585 can opener, tin opener 475 | n02963159 cardigan 476 | n02965783 car mirror 477 | n02966193 carousel, carrousel, merry-go-round, roundabout, whirligig 478 | n02966687 carpenter's kit, tool kit 479 | n02971356 carton 480 | n02974003 car wheel 481 | n02977058 cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM 482 | n02978881 cassette 483 | n02979186 cassette player 484 | n02980441 castle 485 | n02981792 catamaran 486 | n02988304 CD player 487 | n02992211 cello, violoncello 488 | n02992529 cellular telephone, cellular phone, cellphone, cell, mobile phone 489 | n02999410 chain 490 | n03000134 chainlink fence 491 | n03000247 chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour 492 | n03000684 chain saw, chainsaw 493 | n03014705 chest 494 | n03016953 chiffonier, commode 495 | n03017168 chime, bell, gong 496 | n03018349 china cabinet, china closet 497 | n03026506 Christmas stocking 498 | 
n03028079 church, church building 499 | n03032252 cinema, movie theater, movie theatre, movie house, picture palace 500 | n03041632 cleaver, meat cleaver, chopper 501 | n03042490 cliff dwelling 502 | n03045698 cloak 503 | n03047690 clog, geta, patten, sabot 504 | n03062245 cocktail shaker 505 | n03063599 coffee mug 506 | n03063689 coffeepot 507 | n03065424 coil, spiral, volute, whorl, helix 508 | n03075370 combination lock 509 | n03085013 computer keyboard, keypad 510 | n03089624 confectionery, confectionary, candy store 511 | n03095699 container ship, containership, container vessel 512 | n03100240 convertible 513 | n03109150 corkscrew, bottle screw 514 | n03110669 cornet, horn, trumpet, trump 515 | n03124043 cowboy boot 516 | n03124170 cowboy hat, ten-gallon hat 517 | n03125729 cradle 518 | n03126707 crane 519 | n03127747 crash helmet 520 | n03127925 crate 521 | n03131574 crib, cot 522 | n03133878 Crock Pot 523 | n03134739 croquet ball 524 | n03141823 crutch 525 | n03146219 cuirass 526 | n03160309 dam, dike, dyke 527 | n03179701 desk 528 | n03180011 desktop computer 529 | n03187595 dial telephone, dial phone 530 | n03188531 diaper, nappy, napkin 531 | n03196217 digital clock 532 | n03197337 digital watch 533 | n03201208 dining table, board 534 | n03207743 dishrag, dishcloth 535 | n03207941 dishwasher, dish washer, dishwashing machine 536 | n03208938 disk brake, disc brake 537 | n03216828 dock, dockage, docking facility 538 | n03218198 dogsled, dog sled, dog sleigh 539 | n03220513 dome 540 | n03223299 doormat, welcome mat 541 | n03240683 drilling platform, offshore rig 542 | n03249569 drum, membranophone, tympan 543 | n03250847 drumstick 544 | n03255030 dumbbell 545 | n03259280 Dutch oven 546 | n03271574 electric fan, blower 547 | n03272010 electric guitar 548 | n03272562 electric locomotive 549 | n03290653 entertainment center 550 | n03291819 envelope 551 | n03297495 espresso maker 552 | n03314780 face powder 553 | n03325584 feather boa, boa 554 | n03337140 file, file cabinet, filing cabinet 555 | n03344393 fireboat 556 | n03345487 fire engine, fire truck 557 | n03347037 fire screen, fireguard 558 | n03355925 flagpole, flagstaff 559 | n03372029 flute, transverse flute 560 | n03376595 folding chair 561 | n03379051 football helmet 562 | n03384352 forklift 563 | n03388043 fountain 564 | n03388183 fountain pen 565 | n03388549 four-poster 566 | n03393912 freight car 567 | n03394916 French horn, horn 568 | n03400231 frying pan, frypan, skillet 569 | n03404251 fur coat 570 | n03417042 garbage truck, dustcart 571 | n03424325 gasmask, respirator, gas helmet 572 | n03425413 gas pump, gasoline pump, petrol pump, island dispenser 573 | n03443371 goblet 574 | n03444034 go-kart 575 | n03445777 golf ball 576 | n03445924 golfcart, golf cart 577 | n03447447 gondola 578 | n03447721 gong, tam-tam 579 | n03450230 gown 580 | n03452741 grand piano, grand 581 | n03457902 greenhouse, nursery, glasshouse 582 | n03459775 grille, radiator grille 583 | n03461385 grocery store, grocery, food market, market 584 | n03467068 guillotine 585 | n03476684 hair slide 586 | n03476991 hair spray 587 | n03478589 half track 588 | n03481172 hammer 589 | n03482405 hamper 590 | n03483316 hand blower, blow dryer, blow drier, hair dryer, hair drier 591 | n03485407 hand-held computer, hand-held microcomputer 592 | n03485794 handkerchief, hankie, hanky, hankey 593 | n03492542 hard disc, hard disk, fixed disk 594 | n03494278 harmonica, mouth organ, harp, mouth harp 595 | n03495258 harp 596 | n03496892 harvester, reaper 597 | n03498962 
hatchet 598 | n03527444 holster 599 | n03529860 home theater, home theatre 600 | n03530642 honeycomb 601 | n03532672 hook, claw 602 | n03534580 hoopskirt, crinoline 603 | n03535780 horizontal bar, high bar 604 | n03538406 horse cart, horse-cart 605 | n03544143 hourglass 606 | n03584254 iPod 607 | n03584829 iron, smoothing iron 608 | n03590841 jack-o'-lantern 609 | n03594734 jean, blue jean, denim 610 | n03594945 jeep, landrover 611 | n03595614 jersey, T-shirt, tee shirt 612 | n03598930 jigsaw puzzle 613 | n03599486 jinrikisha, ricksha, rickshaw 614 | n03602883 joystick 615 | n03617480 kimono 616 | n03623198 knee pad 617 | n03627232 knot 618 | n03630383 lab coat, laboratory coat 619 | n03633091 ladle 620 | n03637318 lampshade, lamp shade 621 | n03642806 laptop, laptop computer 622 | n03649909 lawn mower, mower 623 | n03657121 lens cap, lens cover 624 | n03658185 letter opener, paper knife, paperknife 625 | n03661043 library 626 | n03662601 lifeboat 627 | n03666591 lighter, light, igniter, ignitor 628 | n03670208 limousine, limo 629 | n03673027 liner, ocean liner 630 | n03676483 lipstick, lip rouge 631 | n03680355 Loafer 632 | n03690938 lotion 633 | n03691459 loudspeaker, speaker, speaker unit, loudspeaker system, speaker system 634 | n03692522 loupe, jeweler's loupe 635 | n03697007 lumbermill, sawmill 636 | n03706229 magnetic compass 637 | n03709823 mailbag, postbag 638 | n03710193 mailbox, letter box 639 | n03710637 maillot 640 | n03710721 maillot, tank suit 641 | n03717622 manhole cover 642 | n03720891 maraca 643 | n03721384 marimba, xylophone 644 | n03724870 mask 645 | n03729826 matchstick 646 | n03733131 maypole 647 | n03733281 maze, labyrinth 648 | n03733805 measuring cup 649 | n03742115 medicine chest, medicine cabinet 650 | n03743016 megalith, megalithic structure 651 | n03759954 microphone, mike 652 | n03761084 microwave, microwave oven 653 | n03763968 military uniform 654 | n03764736 milk can 655 | n03769881 minibus 656 | n03770439 miniskirt, mini 657 | n03770679 minivan 658 | n03773504 missile 659 | n03775071 mitten 660 | n03775546 mixing bowl 661 | n03776460 mobile home, manufactured home 662 | n03777568 Model T 663 | n03777754 modem 664 | n03781244 monastery 665 | n03782006 monitor 666 | n03785016 moped 667 | n03786901 mortar 668 | n03787032 mortarboard 669 | n03788195 mosque 670 | n03788365 mosquito net 671 | n03791053 motor scooter, scooter 672 | n03792782 mountain bike, all-terrain bike, off-roader 673 | n03792972 mountain tent 674 | n03793489 mouse, computer mouse 675 | n03794056 mousetrap 676 | n03796401 moving van 677 | n03803284 muzzle 678 | n03804744 nail 679 | n03814639 neck brace 680 | n03814906 necklace 681 | n03825788 nipple 682 | n03832673 notebook, notebook computer 683 | n03837869 obelisk 684 | n03838899 oboe, hautboy, hautbois 685 | n03840681 ocarina, sweet potato 686 | n03841143 odometer, hodometer, mileometer, milometer 687 | n03843555 oil filter 688 | n03854065 organ, pipe organ 689 | n03857828 oscilloscope, scope, cathode-ray oscilloscope, CRO 690 | n03866082 overskirt 691 | n03868242 oxcart 692 | n03868863 oxygen mask 693 | n03871628 packet 694 | n03873416 paddle, boat paddle 695 | n03874293 paddlewheel, paddle wheel 696 | n03874599 padlock 697 | n03876231 paintbrush 698 | n03877472 pajama, pyjama, pj's, jammies 699 | n03877845 palace 700 | n03884397 panpipe, pandean pipe, syrinx 701 | n03887697 paper towel 702 | n03888257 parachute, chute 703 | n03888605 parallel bars, bars 704 | n03891251 park bench 705 | n03891332 parking meter 706 | n03895866 passenger 
car, coach, carriage 707 | n03899768 patio, terrace 708 | n03902125 pay-phone, pay-station 709 | n03903868 pedestal, plinth, footstall 710 | n03908618 pencil box, pencil case 711 | n03908714 pencil sharpener 712 | n03916031 perfume, essence 713 | n03920288 Petri dish 714 | n03924679 photocopier 715 | n03929660 pick, plectrum, plectron 716 | n03929855 pickelhaube 717 | n03930313 picket fence, paling 718 | n03930630 pickup, pickup truck 719 | n03933933 pier 720 | n03935335 piggy bank, penny bank 721 | n03937543 pill bottle 722 | n03938244 pillow 723 | n03942813 ping-pong ball 724 | n03944341 pinwheel 725 | n03947888 pirate, pirate ship 726 | n03950228 pitcher, ewer 727 | n03954731 plane, carpenter's plane, woodworking plane 728 | n03956157 planetarium 729 | n03958227 plastic bag 730 | n03961711 plate rack 731 | n03967562 plow, plough 732 | n03970156 plunger, plumber's helper 733 | n03976467 Polaroid camera, Polaroid Land camera 734 | n03976657 pole 735 | n03977966 police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria 736 | n03980874 poncho 737 | n03982430 pool table, billiard table, snooker table 738 | n03983396 pop bottle, soda bottle 739 | n03991062 pot, flowerpot 740 | n03992509 potter's wheel 741 | n03995372 power drill 742 | n03998194 prayer rug, prayer mat 743 | n04004767 printer 744 | n04005630 prison, prison house 745 | n04008634 projectile, missile 746 | n04009552 projector 747 | n04019541 puck, hockey puck 748 | n04023962 punching bag, punch bag, punching ball, punchball 749 | n04026417 purse 750 | n04033901 quill, quill pen 751 | n04033995 quilt, comforter, comfort, puff 752 | n04037443 racer, race car, racing car 753 | n04039381 racket, racquet 754 | n04040759 radiator 755 | n04041544 radio, wireless 756 | n04044716 radio telescope, radio reflector 757 | n04049303 rain barrel 758 | n04065272 recreational vehicle, RV, R.V. 
759 | n04067472 reel 760 | n04069434 reflex camera 761 | n04070727 refrigerator, icebox 762 | n04074963 remote control, remote 763 | n04081281 restaurant, eating house, eating place, eatery 764 | n04086273 revolver, six-gun, six-shooter 765 | n04090263 rifle 766 | n04099969 rocking chair, rocker 767 | n04111531 rotisserie 768 | n04116512 rubber eraser, rubber, pencil eraser 769 | n04118538 rugby ball 770 | n04118776 rule, ruler 771 | n04120489 running shoe 772 | n04125021 safe 773 | n04127249 safety pin 774 | n04131690 saltshaker, salt shaker 775 | n04133789 sandal 776 | n04136333 sarong 777 | n04141076 sax, saxophone 778 | n04141327 scabbard 779 | n04141975 scale, weighing machine 780 | n04146614 school bus 781 | n04147183 schooner 782 | n04149813 scoreboard 783 | n04152593 screen, CRT screen 784 | n04153751 screw 785 | n04154565 screwdriver 786 | n04162706 seat belt, seatbelt 787 | n04179913 sewing machine 788 | n04192698 shield, buckler 789 | n04200800 shoe shop, shoe-shop, shoe store 790 | n04201297 shoji 791 | n04204238 shopping basket 792 | n04204347 shopping cart 793 | n04208210 shovel 794 | n04209133 shower cap 795 | n04209239 shower curtain 796 | n04228054 ski 797 | n04229816 ski mask 798 | n04235860 sleeping bag 799 | n04238763 slide rule, slipstick 800 | n04239074 sliding door 801 | n04243546 slot, one-armed bandit 802 | n04251144 snorkel 803 | n04252077 snowmobile 804 | n04252225 snowplow, snowplough 805 | n04254120 soap dispenser 806 | n04254680 soccer ball 807 | n04254777 sock 808 | n04258138 solar dish, solar collector, solar furnace 809 | n04259630 sombrero 810 | n04263257 soup bowl 811 | n04264628 space bar 812 | n04265275 space heater 813 | n04266014 space shuttle 814 | n04270147 spatula 815 | n04273569 speedboat 816 | n04275548 spider web, spider's web 817 | n04277352 spindle 818 | n04285008 sports car, sport car 819 | n04286575 spotlight, spot 820 | n04296562 stage 821 | n04310018 steam locomotive 822 | n04311004 steel arch bridge 823 | n04311174 steel drum 824 | n04317175 stethoscope 825 | n04325704 stole 826 | n04326547 stone wall 827 | n04328186 stopwatch, stop watch 828 | n04330267 stove 829 | n04332243 strainer 830 | n04335435 streetcar, tram, tramcar, trolley, trolley car 831 | n04336792 stretcher 832 | n04344873 studio couch, day bed 833 | n04346328 stupa, tope 834 | n04347754 submarine, pigboat, sub, U-boat 835 | n04350905 suit, suit of clothes 836 | n04355338 sundial 837 | n04355933 sunglass 838 | n04356056 sunglasses, dark glasses, shades 839 | n04357314 sunscreen, sunblock, sun blocker 840 | n04366367 suspension bridge 841 | n04367480 swab, swob, mop 842 | n04370456 sweatshirt 843 | n04371430 swimming trunks, bathing trunks 844 | n04371774 swing 845 | n04372370 switch, electric switch, electrical switch 846 | n04376876 syringe 847 | n04380533 table lamp 848 | n04389033 tank, army tank, armored combat vehicle, armoured combat vehicle 849 | n04392985 tape player 850 | n04398044 teapot 851 | n04399382 teddy, teddy bear 852 | n04404412 television, television system 853 | n04409515 tennis ball 854 | n04417672 thatch, thatched roof 855 | n04418357 theater curtain, theatre curtain 856 | n04423845 thimble 857 | n04428191 thresher, thrasher, threshing machine 858 | n04429376 throne 859 | n04435653 tile roof 860 | n04442312 toaster 861 | n04443257 tobacco shop, tobacconist shop, tobacconist 862 | n04447861 toilet seat 863 | n04456115 torch 864 | n04458633 totem pole 865 | n04461696 tow truck, tow car, wrecker 866 | n04462240 toyshop 867 | n04465501 tractor 868 | 
n04467665 trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi 869 | n04476259 tray 870 | n04479046 trench coat 871 | n04482393 tricycle, trike, velocipede 872 | n04483307 trimaran 873 | n04485082 tripod 874 | n04486054 triumphal arch 875 | n04487081 trolleybus, trolley coach, trackless trolley 876 | n04487394 trombone 877 | n04493381 tub, vat 878 | n04501370 turnstile 879 | n04505470 typewriter keyboard 880 | n04507155 umbrella 881 | n04509417 unicycle, monocycle 882 | n04515003 upright, upright piano 883 | n04517823 vacuum, vacuum cleaner 884 | n04522168 vase 885 | n04523525 vault 886 | n04525038 velvet 887 | n04525305 vending machine 888 | n04532106 vestment 889 | n04532670 viaduct 890 | n04536866 violin, fiddle 891 | n04540053 volleyball 892 | n04542943 waffle iron 893 | n04548280 wall clock 894 | n04548362 wallet, billfold, notecase, pocketbook 895 | n04550184 wardrobe, closet, press 896 | n04552348 warplane, military plane 897 | n04553703 washbasin, handbasin, washbowl, lavabo, wash-hand basin 898 | n04554684 washer, automatic washer, washing machine 899 | n04557648 water bottle 900 | n04560804 water jug 901 | n04562935 water tower 902 | n04579145 whiskey jug 903 | n04579432 whistle 904 | n04584207 wig 905 | n04589890 window screen 906 | n04590129 window shade 907 | n04591157 Windsor tie 908 | n04591713 wine bottle 909 | n04592741 wing 910 | n04596742 wok 911 | n04597913 wooden spoon 912 | n04599235 wool, woolen, woollen 913 | n04604644 worm fence, snake fence, snake-rail fence, Virginia fence 914 | n04606251 wreck 915 | n04612504 yawl 916 | n04613696 yurt 917 | n06359193 web site, website, internet site, site 918 | n06596364 comic book 919 | n06785654 crossword puzzle, crossword 920 | n06794110 street sign 921 | n06874185 traffic light, traffic signal, stoplight 922 | n07248320 book jacket, dust cover, dust jacket, dust wrapper 923 | n07565083 menu 924 | n07579787 plate 925 | n07583066 guacamole 926 | n07584110 consomme 927 | n07590611 hot pot, hotpot 928 | n07613480 trifle 929 | n07614500 ice cream, icecream 930 | n07615774 ice lolly, lolly, lollipop, popsicle 931 | n07684084 French loaf 932 | n07693725 bagel, beigel 933 | n07695742 pretzel 934 | n07697313 cheeseburger 935 | n07697537 hotdog, hot dog, red hot 936 | n07711569 mashed potato 937 | n07714571 head cabbage 938 | n07714990 broccoli 939 | n07715103 cauliflower 940 | n07716358 zucchini, courgette 941 | n07716906 spaghetti squash 942 | n07717410 acorn squash 943 | n07717556 butternut squash 944 | n07718472 cucumber, cuke 945 | n07718747 artichoke, globe artichoke 946 | n07720875 bell pepper 947 | n07730033 cardoon 948 | n07734744 mushroom 949 | n07742313 Granny Smith 950 | n07745940 strawberry 951 | n07747607 orange 952 | n07749582 lemon 953 | n07753113 fig 954 | n07753275 pineapple, ananas 955 | n07753592 banana 956 | n07754684 jackfruit, jak, jack 957 | n07760859 custard apple 958 | n07768694 pomegranate 959 | n07802026 hay 960 | n07831146 carbonara 961 | n07836838 chocolate sauce, chocolate syrup 962 | n07860988 dough 963 | n07871810 meat loaf, meatloaf 964 | n07873807 pizza, pizza pie 965 | n07875152 potpie 966 | n07880968 burrito 967 | n07892512 red wine 968 | n07920052 espresso 969 | n07930864 cup 970 | n07932039 eggnog 971 | n09193705 alp 972 | n09229709 bubble 973 | n09246464 cliff, drop, drop-off 974 | n09256479 coral reef 975 | n09288635 geyser 976 | n09332890 lakeside, lakeshore 977 | n09399592 promontory, headland, head, foreland 978 | n09421951 sandbar, sand bar 979 | n09428293 seashore, 
coast, seacoast, sea-coast 980 | n09468604 valley, vale 981 | n09472597 volcano 982 | n09835506 ballplayer, baseball player 983 | n10148035 groom, bridegroom 984 | n10565667 scuba diver 985 | n11879895 rapeseed 986 | n11939491 daisy 987 | n12057211 yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum 988 | n12144580 corn 989 | n12267677 acorn 990 | n12620546 hip, rose hip, rosehip 991 | n12768682 buckeye, horse chestnut, conker 992 | n12985857 coral fungus 993 | n12998815 agaric 994 | n13037406 gyromitra 995 | n13040303 stinkhorn, carrion fungus 996 | n13044778 earthstar 997 | n13052670 hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa 998 | n13054560 bolete 999 | n13133613 ear, spike, capitulum 1000 | n15075141 toilet tissue, toilet paper, bathroom tissue''' -------------------------------------------------------------------------------- /sal/saliency_model.py: -------------------------------------------------------------------------------- 1 | from utils.pytorch_fixes import * 2 | from torch.nn import functional as F 3 | import torch.nn as nn 4 | from torch.nn import Module 5 | from sal.utils.mask import * 6 | from torchvision.models.resnet import resnet50 7 | import os 8 | 9 | def get_black_box_fn(model_zoo_model=resnet50, cuda=True, image_domain=(-2., 2.)): 10 | ''' You can try any model from the pytorch model zoo (torchvision.models) 11 | eg. VGG, inception, mobilenet, alexnet... 12 | ''' 13 | black_box_model = model_zoo_model(pretrained=True) 14 | 15 | black_box_model.train(False) 16 | if cuda: 17 | black_box_model = torch.nn.DataParallel(black_box_model).cuda() 18 | 19 | def black_box_fn(_images): 20 | return black_box_model(adapt_to_image_domain(_images, image_domain)) 21 | return black_box_fn 22 | 23 | 24 | 25 | class SaliencyModel(Module): 26 | def __init__(self, encoder, encoder_scales, encoder_base, upsampler_scales, upsampler_base, fix_encoder=True, 27 | use_simple_activation=False, allow_selector=False, num_classes=1000): 28 | super(SaliencyModel, self).__init__() 29 | assert upsampler_scales <= encoder_scales 30 | 31 | self.encoder = encoder # decoder must return at least scale0 to scaleN where N is num_scales 32 | self.upsampler_scales = upsampler_scales 33 | self.encoder_scales = encoder_scales 34 | self.fix_encoder = fix_encoder 35 | self.use_simple_activation = use_simple_activation 36 | 37 | # now build the decoder for the specified number of scales 38 | # start with the top scale 39 | down = self.encoder_scales 40 | modulator_sizes = [] 41 | for up in reversed(xrange(self.upsampler_scales)): 42 | upsampler_chans = upsampler_base * 2**(up+1) 43 | encoder_chans = encoder_base * 2**down 44 | inc = upsampler_chans if down!=encoder_scales else encoder_chans 45 | modulator_sizes.append(inc) 46 | self.add_module('up%d'%up, 47 | UNetUpsampler( 48 | in_channels=inc, 49 | passthrough_channels=encoder_chans/2, 50 | out_channels=upsampler_chans/2, 51 | follow_up_residual_blocks=1, 52 | activation_fn=lambda: nn.ReLU(), 53 | )) 54 | down -= 1 55 | 56 | self.to_saliency_chans = nn.Conv2d(upsampler_base, 2, 1) 57 | 58 | self.allow_selector = allow_selector 59 | 60 | if self.allow_selector: 61 | s = encoder_base*2**encoder_scales 62 | self.selector_module = nn.Embedding(num_classes, s) 63 | self.selector_module.weight.data.normal_(0, 1./s**0.5) 64 | 65 | 66 | def minimialistic_restore(self, save_dir): 67 | assert self.fix_encoder, 'You should not use this function if you are not using a pre-trained encoder like resnet' 68 | 
69 | p = os.path.join(save_dir, 'model-%d.ckpt' % 1) 70 | if not os.path.exists(p): 71 | print 'Could not find any checkpoint at %s, skipping restore' % p 72 | return 73 | for name, data in torch.load(p, map_location=lambda storage, loc: storage).items(): 74 | self._modules[name].load_state_dict(data) 75 | 76 | def minimalistic_save(self, save_dir): 77 | assert self.fix_encoder, 'You should not use this function if you are not using a pre-trained encoder like resnet' 78 | data = {} 79 | for name, module in self._modules.items(): 80 | if module is self.encoder: # we do not want to restore the encoder as it should have its own restore function 81 | continue 82 | data[name] = module.state_dict() 83 | if not os.path.exists(save_dir): 84 | os.mkdir(save_dir) 85 | torch.save(data, os.path.join(save_dir, 'model-%d.ckpt' % 1)) 86 | 87 | 88 | def get_trainable_parameters(self): 89 | all_params = self.parameters() 90 | if not self.fix_encoder: return set(all_params) 91 | unwanted = self.encoder.parameters() 92 | return set(all_params) - set(unwanted) - (set(self.selector_module.parameters()) if self.allow_selector else set([])) 93 | 94 | def forward(self, _images, _selectors=None, pt_store=None, model_confidence=0.): 95 | # forward pass through the encoder 96 | out = self.encoder(_images) 97 | if self.fix_encoder: 98 | out = [e.detach() for e in out] 99 | 100 | down = self.encoder_scales 101 | main_flow = out[down] 102 | 103 | if self.allow_selector: 104 | assert _selectors is not None 105 | em = torch.squeeze(self.selector_module(_selectors.view(-1, 1)), 1) 106 | act = torch.sum(main_flow*em.view(-1, 2048, 1, 1), 1, keepdim=True) 107 | th = torch.sigmoid(act-model_confidence) 108 | main_flow = main_flow*th 109 | 110 | ex = torch.mean(torch.mean(act, 3), 2) 111 | exists_logits = torch.cat((-ex / 2., ex / 2.), 1) 112 | else: 113 | exists_logits = None 114 | 115 | for up in reversed(xrange(self.upsampler_scales)): 116 | assert down > 0 117 | main_flow = self._modules['up%d'%up](main_flow, out[down-1]) 118 | down -= 1 119 | # now get the final saliency map (the reslution of the map = resolution_of_the_image / (2**(encoder_scales-upsampler_scales))) 120 | saliency_chans = self.to_saliency_chans(main_flow) 121 | 122 | if self.use_simple_activation: 123 | return torch.unsqueeze(torch.sigmoid(saliency_chans[:,0,:,:]/2), dim=1), exists_logits, out[-1] 124 | 125 | 126 | a = torch.abs(saliency_chans[:,0,:,:]) 127 | b = torch.abs(saliency_chans[:,1,:,:]) 128 | return torch.unsqueeze(a/(a+b), dim=1), exists_logits, out[-1] 129 | 130 | 131 | 132 | class SaliencyLoss: 133 | def __init__(self, black_box_fn, area_loss_coef=8, smoothness_loss_coef=0.5, preserver_loss_coef=0.3, 134 | num_classes=1000, area_loss_power=0.3, preserver_confidence=1, destroyer_confidence=5, **apply_mask_kwargs): 135 | self.black_box_fn = black_box_fn 136 | self.area_loss_coef = area_loss_coef 137 | self.smoothness_loss_coef = smoothness_loss_coef 138 | self.preserver_loss_coef = preserver_loss_coef 139 | self.num_classes = num_classes 140 | self.area_loss_power =area_loss_power 141 | self.preserver_confidence = preserver_confidence 142 | self.destroyer_confidence = destroyer_confidence 143 | self.apply_mask_kwargs = apply_mask_kwargs 144 | 145 | def get_loss(self, _images, _targets, _masks, _is_real_target=None, pt_store=None): 146 | ''' masks must be already in the range 0,1 and of shape: (B, 1, ?, ?)''' 147 | if _masks.size()[-2:] != _images.size()[-2:]: 148 | _masks = F.upsample(_masks, (_images.size(2), _images.size(3)), 
mode='bilinear') 149 | 150 | if _is_real_target is None: 151 | _is_real_target = Variable(tensor_like(_targets).fill_(1.)) 152 | destroyed_images = apply_mask(_images, 1.-_masks, **self.apply_mask_kwargs) 153 | destroyed_logits = self.black_box_fn(destroyed_images) 154 | 155 | preserved_images = apply_mask(_images, _masks, **self.apply_mask_kwargs) 156 | preserved_logits = self.black_box_fn(preserved_images) 157 | 158 | _one_hot_targets = one_hot(_targets, self.num_classes) 159 | preserver_loss = cw_loss(preserved_logits, _one_hot_targets, targeted=_is_real_target == 1, t_conf=self.preserver_confidence, nt_conf=1.) 160 | destroyer_loss = cw_loss(destroyed_logits, _one_hot_targets, targeted=_is_real_target == 0, t_conf=1., nt_conf=self.destroyer_confidence) 161 | area_loss = calc_area_loss(_masks, self.area_loss_power) 162 | smoothness_loss = calc_smoothness_loss(_masks) 163 | 164 | total_loss = destroyer_loss + self.area_loss_coef*area_loss + self.smoothness_loss_coef*smoothness_loss + self.preserver_loss_coef*preserver_loss 165 | 166 | if pt_store is not None: 167 | # add variables to the pt_store 168 | pt_store(masks=_masks) 169 | pt_store(destroyed=destroyed_images) 170 | pt_store(preserved=preserved_images) 171 | pt_store(area_loss=area_loss) 172 | pt_store(smoothness_loss=smoothness_loss) 173 | pt_store(destroyer_loss=destroyer_loss) 174 | pt_store(preserver_loss=preserver_loss) 175 | pt_store(preserved_logits=preserved_logits) 176 | pt_store(destroyed_logits=destroyed_logits) 177 | return total_loss 178 | 179 | 180 | 181 | -------------------------------------------------------------------------------- /sal/small.py: -------------------------------------------------------------------------------- 1 | from .datasets import cifar_dataset 2 | from utils.pytorch_fixes import * 3 | from utils.pytorch_trainer import * 4 | import torch 5 | from torch.autograd import Variable 6 | from torch.nn import functional as F 7 | import torch.nn as nn 8 | from torch.nn import Module 9 | import numpy as np 10 | import pycat 11 | 12 | 13 | 14 | 15 | class SimpleClassifier(EasyModule): 16 | '''Relatively fast and works well for 32x32 and 64x64 images. Achieves ~89% on CIFAR10 with base=24. 
17 | Returns all the intermediate features so it can serve as a feature extractor or a base for the U-Net decoder.''' 18 | def __init__(self, base_channels=24, num_classes=10, image_channels=3): 19 | super(SimpleClassifier, self).__init__() 20 | 21 | BASE1 = base_channels 22 | BASE2 = BASE1 * 2 23 | BASE3 = BASE1 * 4 24 | BASE4 = BASE1 * 8 25 | self.scale0 = SimpleCNNBlock(image_channels, BASE1, layers=3) 26 | self.scale1 = SimpleCNNBlock(BASE1, BASE2, stride=2, layers=3) 27 | self.scale2 = SimpleCNNBlock(BASE2, BASE3, stride=2, layers=3) 28 | self.scale3 = SimpleCNNBlock(BASE3, BASE4, stride=2, layers=3) 29 | self.scaleX = GlobalAvgPool() 30 | self.scaleC = nn.Linear(BASE4, num_classes) 31 | 32 | def forward(self, _images): 33 | s0 = self.scale0(_images) 34 | s1 = self.scale1(s0) 35 | s2 = self.scale2(s1) 36 | s3 = self.scale3(s2) 37 | sX = self.scaleX(s3) 38 | sC = self.scaleC(sX) 39 | return s0, s1, s2, s3, sX, sC 40 | 41 | @staticmethod 42 | def out_to_logits(out): 43 | s0, s1, s2, s3, sX, sC = out 44 | return sC 45 | 46 | 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /sal/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiotrDabkowski/pytorch-saliency/bfd501ec7888dbb3727494d06c71449df1530196/sal/utils/__init__.py -------------------------------------------------------------------------------- /sal/utils/gaussian_blur.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.nn import Module 3 | from torch.nn.functional import conv2d 4 | import torch 5 | from torch.autograd import Variable 6 | 7 | 8 | def _gaussian_kernels(kernel_size, sigma, chans): 9 | assert kernel_size % 2, 'Kernel size of the gaussian blur must be odd!' 10 | x = np.expand_dims(np.array(range(-kernel_size/2, -kernel_size/2+kernel_size, 1)), 0) 11 | vals = np.exp(-np.square(x)/(2.*sigma**2)) 12 | _kernel = np.reshape(vals / np.sum(vals), (1, 1, kernel_size, 1)) 13 | kernel = np.zeros((chans, 1, kernel_size, 1), dtype=np.float32) + _kernel 14 | return kernel, np.transpose(kernel, [0, 1, 3, 2]) 15 | 16 | def gaussian_blur(_images, kernel_size=55, sigma=11): 17 | ''' Very fast, linear time gaussian blur, using separable convolution. Operates on batch of images [N, C, H, W]. 18 | Returns blurred images of the same size. Kernel size must be odd. 19 | Increasing kernel size over 4*simga yields little improvement in quality. 
So kernel size = 4*sigma is a good choice.''' 20 | kernel_a, kernel_b = _gaussian_kernels(kernel_size=kernel_size, sigma=sigma, chans=_images.size(1)) 21 | kernel_a = torch.Tensor(kernel_a) 22 | kernel_b = torch.Tensor(kernel_b) 23 | if _images.is_cuda: 24 | kernel_a = kernel_a.cuda() 25 | kernel_b = kernel_b.cuda() 26 | _rows = conv2d(_images, Variable(kernel_a, requires_grad=False), groups=_images.size(1), padding=(kernel_size / 2, 0)) 27 | return conv2d(_rows, Variable(kernel_b, requires_grad=False), groups=_images.size(1), padding=(0, kernel_size / 2)) 28 | 29 | 30 | def test(): 31 | from PIL import Image 32 | import os, pycat 33 | im = Variable(torch.Tensor(np.expand_dims(np.transpose(np.array(Image.open(os.path.join(os.path.dirname(__file__), 'test.jpg'))), (2, 0, 1)), 0)/255.), requires_grad=True) 34 | g = gaussian_blur(im) 35 | print 'Original' 36 | pycat.show(im[0].data.numpy()) 37 | print 'Blurred version' 38 | pycat.show(g[0].data.numpy()) 39 | print 'Image gradient over blurred sum (should be white in the middle + turning darker at the edges)' 40 | l = torch.sum(g) 41 | l.backward() 42 | gr = im.grad[0].data.numpy() 43 | assert np.mean(gr) > 0.9 and np.mean(np.flip(gr, 1)-gr) < 1e-6 and np.mean(np.flip(gr, 2)-gr) < 1e-6 44 | pycat.show(gr) 45 | 46 | 47 | -------------------------------------------------------------------------------- /sal/utils/mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import gaussian_blur 4 | 5 | def calc_smoothness_loss(mask, power=2, border_penalty=0.3): 6 | ''' For a given image this loss should be more or less invariant to image resize when using power=2... 7 | let L be the length of a side 8 | EdgesLength ~ L 9 | EdgesSharpness ~ 1/L, easy to see if you imagine just a single vertical edge in the whole image''' 10 | x_loss = torch.sum((torch.abs(mask[:,:,1:,:] - mask[:,:,:-1,:]))**power) 11 | y_loss = torch.sum((torch.abs(mask[:,:,:,1:] - mask[:,:,:,:-1]))**power) 12 | if border_penalty>0: 13 | border = float(border_penalty)*torch.sum(mask[:,:,-1,:]**power + mask[:,:,0,:]**power + mask[:,:,:,-1]**power + mask[:,:,:,0]**power) 14 | else: 15 | border = 0. 16 | return (x_loss + y_loss + border) / float(power * mask.size(0)) # watch out, normalised by the batch size! 17 | 18 | 19 | 20 | def calc_area_loss(mask, power=1.): 21 | if power != 1: 22 | mask = (mask+0.0005)**power # prevent nan (derivative of sqrt at 0 is inf) 23 | return torch.mean(mask) 24 | 25 | 26 | def tensor_like(x): 27 | if x.is_cuda: 28 | return torch.Tensor(*x.size()).cuda() 29 | else: 30 | return torch.Tensor(*x.size()) 31 | 32 | def apply_mask(images, mask, noise=True, random_colors=True, blurred_version_prob=0.5, noise_std=0.11, 33 | color_range=0.66, blur_kernel_size=55, blur_sigma=11, 34 | bypass=0., boolean=False, preserved_imgs_noise_std=0.03): 35 | images = images.clone() 36 | cuda = images.is_cuda 37 | 38 | if boolean: 39 | # remember its just for validation! 40 | return (mask > 0.5).float() *images 41 | 42 | assert 0. 
<= bypass < 0.9 43 | n, c, _, _ = images.size() 44 | if preserved_imgs_noise_std > 0: 45 | images = images + Variable(tensor_like(images).normal_(std=preserved_imgs_noise_std), requires_grad=False) 46 | if bypass > 0: 47 | mask = (1.-bypass)*mask + bypass 48 | if noise and noise_std: 49 | alt = tensor_like(images).normal_(std=noise_std) 50 | else: 51 | alt = tensor_like(images).zero_() 52 | if random_colors: 53 | if cuda: 54 | alt += torch.Tensor(n, c, 1, 1).cuda().uniform_(-color_range/2., color_range/2.) 55 | else: 56 | alt += torch.Tensor(n, c, 1, 1).uniform_(-color_range/2., color_range/2.) 57 | 58 | alt = Variable(alt, requires_grad=False) 59 | 60 | if blurred_version_prob > 0.: # <- it can be a scalar between 0 and 1 61 | cand = gaussian_blur.gaussian_blur(images, kernel_size=blur_kernel_size, sigma=blur_sigma) 62 | if cuda: 63 | when = Variable((torch.Tensor(n, 1, 1, 1).cuda().uniform_(0., 1.) < blurred_version_prob).float(), requires_grad=False) 64 | else: 65 | when = Variable((torch.Tensor(n, 1, 1, 1).uniform_(0., 1.) < blurred_version_prob).float(), requires_grad=False) 66 | alt = alt*(1.-when) + cand*when 67 | 68 | return (mask*images.detach()) + (1. - mask)*alt.detach() 69 | 70 | 71 | 72 | 73 | def test(): 74 | from PIL import Image 75 | import numpy as np 76 | import os, pycat 77 | im = Variable(torch.Tensor(np.expand_dims(np.transpose(np.array(Image.open(os.path.join(os.path.dirname(__file__), 'test.jpg'))), (2, 0, 1)), 0)/255.*2-1.), requires_grad=False) 78 | print 'Original' 79 | pycat.show(im[0].data.numpy()) 80 | print 81 | for pres in [1., 0.5, 0.1]: 82 | print 'Mask strength =', pres 83 | for e in xrange(5): 84 | m = Variable(torch.Tensor(1, 3, im.size(2), im.size(3)).fill_(pres), requires_grad=True) 85 | res = apply_mask(im, m) 86 | pycat.show(res[0].data.numpy()) 87 | s = torch.sum(res) 88 | s.backward() 89 | print torch.sum(m.grad) 90 | 91 | 92 | -------------------------------------------------------------------------------- /sal/utils/pt_store.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.autograd import Variable 3 | import torch 4 | 5 | class PTStore: 6 | def __init__(self): 7 | self.__dict__['vars'] = {} 8 | 9 | def __call__(self, **kwargs): 10 | assert len(kwargs)==1, "You must specify just 1 variable to add" 11 | key, value = kwargs.items()[0] 12 | setattr(self, key, value) 13 | return value 14 | 15 | def __setattr__(self, key, value): 16 | self.__dict__['vars'][key] = value 17 | 18 | def __getattr__(self, key): 19 | if key=='_vars': 20 | return self.__dict__['vars'] 21 | if key not in self.__dict__['vars']: 22 | raise KeyError('Key %s was not found in the pt_store! Forgot to add it?' % key) 23 | return self.__dict__['vars'][key] 24 | 25 | def __getitem__(self, key): 26 | if key not in self.__dict__['vars']: 27 | raise KeyError('Key %s was not found in the pt_store! Forgot to add it?' 
% key) 28 | cand = self.__dict__['vars'][key] 29 | return to_numpy(cand) 30 | 31 | def clear(self): 32 | self.__dict__['vars'].clear() 33 | 34 | def to_numpy(cand): 35 | if isinstance(cand, Variable): 36 | return cand.data.cpu().numpy() 37 | elif isinstance(cand, torch._TensorBase): 38 | return cand.cpu().numpy() 39 | elif isinstance(cand, (list, tuple)): 40 | return map(to_numpy, cand) 41 | elif isinstance(cand, np.ndarray): 42 | return cand 43 | else: 44 | return np.array([cand]) 45 | 46 | def to_number(x): 47 | if isinstance(x, (int, long, float)): 48 | return float(x) 49 | if isinstance(x, np.ndarray): 50 | return x[0] 51 | return x.data[0] 52 | 53 | -------------------------------------------------------------------------------- /sal/utils/pytorch_fixes.py: -------------------------------------------------------------------------------- 1 | __all__ = ['RandomSizedCrop2', 'STD_NORMALIZE', 'ShapeLog', 'AssertSize', 'GlobalAvgPool', 'SimpleCNNBlock', 'SimpleLinearBlock', 2 | 'SimpleExtractor', 'SimpleGenerator', 'DiscreteNeuron', 'chain', 'EasyModule', 'UNetUpsampler', 3 | 'SimpleUpsamplerSubpixel', 'CustomModule', 'BottleneckBlock', 'losses', 'F', 'torch_optim', 'Variable', 4 | 'one_hot', 'cw_loss', 'adapt_to_image_domain', 'nn', 'MultiModulator', 5 | 'B4D_4D', 'B2D_4D', 'B2D_2D', 'Veri4D', 'Veri2D', 'GanGroup', 6 | 7 | ] 8 | 9 | 10 | import torch 11 | from PIL import Image 12 | from torchvision.transforms import * 13 | from torch.nn import * 14 | import torch.nn as nn 15 | import torch.nn.modules.loss as losses 16 | from torch.nn import functional as F 17 | from itertools import chain 18 | import torch.optim as torch_optim 19 | from torch.autograd import Variable 20 | import signal 21 | import sys 22 | 23 | INFO_TEMPLATE = '\033[38;5;2mINFO: %s\033[0m\n' 24 | WARN_TEMPLATE = '\033[38;5;1mWARNING: %s\033[0m\n' 25 | 26 | 27 | def signal_handler(signal, frame): 28 | print('Finishing...') 29 | sys.exit(0) 30 | signal.signal(signal.SIGINT, signal_handler) 31 | 32 | 33 | 34 | STD_NORMALIZE = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 35 | 36 | 37 | class DiscreteNeuron(Module): 38 | def forward(self, x): 39 | return discrete_neuron_func()(x) 40 | 41 | class discrete_neuron_func(torch.autograd.Function): 42 | def forward(self, x): 43 | self.save_for_backward(x) 44 | x = x.clone() 45 | x[x>0] = 1. 46 | x[x<0] = -1. 47 | return x 48 | 49 | def backward(self, grad_out): 50 | x, = self.saved_tensors 51 | grad_input = grad_out.clone() 52 | this_grad = torch.exp(-(x**2)/2.) / (3.14/2.)**0.5 53 | return this_grad*grad_input 54 | 55 | class RandomSizedCrop2(object): 56 | """Crop the given PIL.Image to random size and aspect ratio. 57 | 58 | A crop of random size of (0.08 to 1.0) of the original size and a random 59 | aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop 60 | is finally resized to given size. 61 | This is popularly used to train the Inception networks. 62 | 63 | Args: 64 | size: size of the smaller edge 65 | interpolation: Default: PIL.Image.BILINEAR 66 | """ 67 | 68 | def __init__(self, size, min_area=0.3, interpolation=Image.BILINEAR): 69 | self.size = size 70 | self.interpolation = interpolation 71 | self.min_area = min_area 72 | 73 | def __call__(self, img): 74 | for attempt in range(10): 75 | area = img.size[0] * img.size[1] 76 | target_area = random.uniform(self.min_area, 1.0) * area 77 | aspect_ratio = random.uniform(3. / 4, 4. 
/ 3) 78 | 79 | w = int(round(math.sqrt(target_area * aspect_ratio))) 80 | h = int(round(math.sqrt(target_area / aspect_ratio))) 81 | 82 | if random.random() < 0.5: 83 | w, h = h, w 84 | 85 | if w <= img.size[0] and h <= img.size[1]: 86 | x1 = random.randint(0, img.size[0] - w) 87 | y1 = random.randint(0, img.size[1] - h) 88 | 89 | img = img.crop((x1, y1, x1 + w, y1 + h)) 90 | assert(img.size == (w, h)) 91 | 92 | return img.resize((self.size, self.size), self.interpolation) 93 | 94 | # Fallback 95 | scale = Scale(self.size, interpolation=self.interpolation) 96 | crop = CenterCrop(self.size) 97 | return crop(scale(img)) 98 | 99 | 100 | class AssertSize(Module): 101 | def __init__(self, expected_dim=None): 102 | super(AssertSize, self).__init__() 103 | self.expected_dim = expected_dim 104 | 105 | def forward(self, x): 106 | if self.expected_dim is not None: 107 | assert self.expected_dim==x.size(2), 'expected %d got %d' % (self.expected_dim, x.size(2)) 108 | return x 109 | 110 | 111 | class ShapeLog(Module): 112 | def forward(self, x): 113 | print x.size() 114 | return x 115 | 116 | class PixelShuffleBlock(Module): 117 | def forward(self, x): 118 | return F.pixel_shuffle(x, 2) 119 | 120 | 121 | class GlobalAvgPool(Module): 122 | def forward(self, x): 123 | x = F.avg_pool2d(x, x.size(2), stride=None, padding=0, ceil_mode=False, 124 | count_include_pad=True) 125 | return x.view(x.size(0), -1) 126 | 127 | 128 | 129 | def SimpleCNNBlock(in_channels, out_channels, 130 | kernel_size=3, layers=1, stride=1, 131 | follow_with_bn=True, activation_fn=lambda: ReLU(True), affine=True): 132 | assert layers > 0 and kernel_size%2 and stride>0 133 | current_channels = in_channels 134 | _modules = [] 135 | for layer in range(layers): 136 | _modules.append(Conv2d(current_channels, out_channels, kernel_size, stride=stride if layer==0 else 1, padding=kernel_size/2, bias=not follow_with_bn)) 137 | current_channels = out_channels 138 | if follow_with_bn: 139 | _modules.append(BatchNorm2d(current_channels, affine=affine)) 140 | if activation_fn is not None: 141 | _modules.append(activation_fn()) 142 | return Sequential(*_modules) 143 | 144 | 145 | def ReducedCNNBlock(in_channels, out_channels, 146 | kernel_size=3, layers=1, stride=1, activation_fn=lambda: ReLU(True), reducer_family='all'): 147 | assert layers > 0 and kernel_size % 2 and stride > 0 148 | current_channels = in_channels 149 | follow_with_bn = True 150 | _modules = [] 151 | for layer in range(layers): 152 | _modules.append(Conv2d(current_channels, out_channels, kernel_size, stride=stride if layer == layers - 1 else 1, 153 | padding=kernel_size / 2, bias=not follow_with_bn)) 154 | current_channels = out_channels 155 | if follow_with_bn: 156 | _modules.append(Reducer4D(current_channels, family=reducer_family)) 157 | if activation_fn is not None: 158 | _modules.append(activation_fn()) 159 | return Sequential(*_modules) 160 | 161 | 162 | def SimpleLinearBlock(in_channels, out_channels, layers=1, follow_with_bn=True, activation_fn=lambda: ReLU(inplace=False), affine=True): 163 | assert layers > 0 164 | current_channels = in_channels 165 | _modules = [] 166 | for layer in range(layers): 167 | _modules.append(Linear(current_channels, out_channels, bias=not follow_with_bn)) 168 | current_channels = out_channels 169 | if follow_with_bn: 170 | _modules.append(BatchNorm1d(current_channels, affine=affine)) 171 | if activation_fn is not None: 172 | _modules.append(activation_fn()) 173 | return Sequential(*_modules) 174 | 175 | 176 | 177 | def 
ReducedExtractor(base_channels, downsampling_blocks, extra_modules=(), activation_fn=lambda: torch.nn.ReLU(inplace=False)): 178 | # final_dimension is an extra layer of protection so that we have the dimensions right 179 | current_channels = 3 180 | _modules = [BatchNorm2d(current_channels)] 181 | for layers in downsampling_blocks: 182 | if layers-1>0: 183 | _modules.append(ReducedCNNBlock(current_channels, base_channels, layers=layers-1, activation_fn=activation_fn)) 184 | current_channels = base_channels 185 | base_channels *= 2 186 | _modules.append(ReducedCNNBlock(current_channels, base_channels, stride=2, activation_fn=activation_fn)) 187 | current_channels = base_channels 188 | _modules.extend(extra_modules) 189 | return Sequential(*_modules) 190 | 191 | 192 | 193 | def SimpleExtractor(base_channels, downsampling_blocks, extra_modules=(), affine=True, activation_fn=lambda: torch.nn.ReLU(inplace=False)): 194 | # final_dimension is an extra layer of protection so that we have the dimensions right 195 | current_channels = 3 196 | _modules = [BatchNorm2d(current_channels)] 197 | for layers in downsampling_blocks: 198 | if layers-1>0: 199 | _modules.append(SimpleCNNBlock(current_channels, base_channels, layers=layers-1, activation_fn=activation_fn)) 200 | current_channels = base_channels 201 | base_channels *= 2 202 | _modules.append(SimpleCNNBlock(current_channels, base_channels, stride=2, affine=affine, activation_fn=activation_fn)) 203 | current_channels = base_channels 204 | _modules.extend(extra_modules) 205 | return Sequential(*_modules) 206 | 207 | 208 | def SimpleUpsamplerSubpixel(in_channels, out_channels, kernel_size=3, activation_fn=lambda: torch.nn.ReLU(inplace=False), follow_with_bn=True): 209 | _modules = [ 210 | SimpleCNNBlock(in_channels, out_channels * 4, kernel_size=kernel_size, follow_with_bn=follow_with_bn), 211 | PixelShuffleBlock(), 212 | activation_fn(), 213 | ] 214 | return Sequential(*_modules) 215 | 216 | 217 | class UNetUpsampler(Module): 218 | def __init__(self, in_channels, out_channels, passthrough_channels, follow_up_residual_blocks=1, upsampler_block=SimpleUpsamplerSubpixel, 219 | upsampler_kernel_size=3, activation_fn=lambda: torch.nn.ReLU(inplace=False)): 220 | super(UNetUpsampler, self).__init__() 221 | assert follow_up_residual_blocks >= 1, 'You must follow up with residuals when using unet!' 
222 | assert passthrough_channels >= 1, 'You must use passthrough with unet' 223 | self.upsampler = upsampler_block(in_channels=in_channels, 224 | out_channels=out_channels, kernel_size=upsampler_kernel_size, activation_fn=activation_fn) 225 | 226 | self.follow_up = BottleneckBlock(out_channels+passthrough_channels, out_channels, layers=follow_up_residual_blocks, activation_fn=activation_fn) 227 | 228 | def forward(self, inp, passthrough): 229 | upsampled = self.upsampler(inp) 230 | upsampled = torch.cat((upsampled, passthrough), 1) 231 | return self.follow_up(upsampled) 232 | 233 | 234 | class CustomModule(Module): 235 | def __init__(self, py_func): 236 | super(CustomModule, self).__init__() 237 | self.py_func = py_func 238 | 239 | def forward(self, inp): 240 | return self.py_func(inp) 241 | 242 | 243 | class MultiModulator(Module): 244 | def __init__(self, embedding_size, num_classes, modulator_sizes): 245 | super(MultiModulator, self).__init__() 246 | self.emb = Embedding(num_classes, embedding_size) 247 | self.num_modulators = len(modulator_sizes) 248 | for i, m in enumerate(modulator_sizes): 249 | self.add_module('m%d'%i, Linear(embedding_size, m)) 250 | 251 | def forward(self, selectors): 252 | ''' class selector must be of shape (BS,) Returns (BS, MODULATOR_SIZE) for each modulator.''' 253 | em = torch.squeeze(self.emb(selectors.view(-1, 1)), 1) 254 | res = [] 255 | for i in xrange(self.num_modulators): 256 | res.append( self._modules['m%d'%i](em) ) 257 | return tuple(res) 258 | 259 | 260 | import os 261 | class EasyModule(Module): 262 | def save(self, save_dir, step=1): 263 | if not os.path.exists(save_dir): 264 | os.mkdir(save_dir) 265 | torch.save(self.state_dict(), os.path.join(save_dir, 'model-%d.ckpt'%step)) 266 | 267 | 268 | def restore(self, save_dir, step=1): 269 | p = os.path.join(save_dir, 'model-%d.ckpt' % step) 270 | if not os.path.exists(p): 271 | print WARN_TEMPLATE % ('Could not find any checkpoint at %s, skipping restore' % p) 272 | return 273 | self.load_state_dict(torch.load(p)) 274 | Module.save = EasyModule.save.__func__ 275 | Module.restore = EasyModule.restore.__func__ 276 | 277 | def one_hot(labels, depth): 278 | if labels.is_cuda: 279 | return Variable(torch.zeros(labels.size(0), depth).cuda().scatter_(1, labels.long().view(-1, 1).data, 1)) 280 | else: 281 | return Variable(torch.zeros(labels.size(0), depth).scatter_(1, labels.long().view(-1, 1).data, 1)) 282 | 283 | 284 | def cw_loss(logits, one_hot_labels, targeted=True, t_conf=2, nt_conf=5): 285 | ''' computes the advantage of the selected label over other highest prob guess. 286 | In case of the targeted it tries to maximise this advantage to reach desired confidence. 287 | For example confidence of 3 would mean that the desired label is e^3 (about 20) times more probable than the second top guess. 288 | In case of non targeted optimisation the case is opposite and we try to minimise this advantage - the probability of the label is 289 | 20 times smaller than the probability of the top guess. 
290 | 291 | So for targeted optim a small confidence should be enough (about 2) and for non targeted about 5-6 would work better (assuming 1000 classes so log(no_idea)=6.9) 292 | ''' 293 | this = torch.sum(logits*one_hot_labels, 1) 294 | other_best, _ = torch.max(logits*(1.-one_hot_labels) - 12111*one_hot_labels, 1) # subtracting 12111 from selected labels to make sure that they dont end up a maximum 295 | t = F.relu(other_best - this + t_conf) 296 | nt = F.relu(this - other_best + nt_conf) 297 | if isinstance(targeted, (bool, int)): 298 | return torch.mean(t) if targeted else torch.mean(nt) 299 | else: # must be a byte tensor of zeros and ones 300 | 301 | return torch.mean(t*(targeted>0).float() + nt*(targeted==0).float()) 302 | 303 | def adapt_to_image_domain(images_plus_minus_one, desired_domain): 304 | if desired_domain == (-1., 1.): 305 | return images_plus_minus_one 306 | return images_plus_minus_one * (desired_domain[1] - desired_domain[0]) / 2. + (desired_domain[0] + desired_domain[1]) / 2. 307 | 308 | class Bottleneck(Module): 309 | def __init__(self, in_channels, out_channels, stride=1, bottleneck_ratio=4, activation_fn=lambda: torch.nn.ReLU(inplace=False)): 310 | super(Bottleneck, self).__init__() 311 | bottleneck_channels = out_channels/bottleneck_ratio 312 | self.conv1 = Conv2d(in_channels, bottleneck_channels, kernel_size=1, bias=False) 313 | self.bn1 = BatchNorm2d(bottleneck_channels) 314 | self.conv2 = Conv2d(bottleneck_channels, bottleneck_channels, kernel_size=3, stride=stride, 315 | padding=1, bias=False) 316 | self.bn2 = BatchNorm2d(bottleneck_channels) 317 | self.conv3 = Conv2d(bottleneck_channels, out_channels, kernel_size=1, bias=False) 318 | self.bn3 = BatchNorm2d(out_channels) 319 | self.activation_fn = activation_fn() 320 | 321 | if stride != 1 or in_channels != out_channels : 322 | self.residual_transformer = Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=True) 323 | else: 324 | self.residual_transformer = None 325 | 326 | def forward(self, x): 327 | residual = x 328 | 329 | out = self.conv1(x) 330 | out = self.bn1(out) 331 | out = self.activation_fn(out) 332 | 333 | out = self.conv2(out) 334 | out = self.bn2(out) 335 | out = self.activation_fn(out) 336 | 337 | out = self.conv3(out) 338 | out = self.bn3(out) 339 | 340 | if self.residual_transformer is not None: 341 | residual = self.residual_transformer(residual) 342 | out += residual 343 | 344 | out = self.activation_fn(out) 345 | return out 346 | 347 | def BottleneckBlock(in_channels, out_channels, stride=1, layers=1, activation_fn=lambda: torch.nn.ReLU(inplace=False)): 348 | assert layers > 0 and stride > 0 349 | current_channels = in_channels 350 | _modules = [] 351 | for layer in range(layers): 352 | _modules.append(Bottleneck(current_channels, out_channels, stride=stride if layer==0 else 1, activation_fn=activation_fn)) 353 | current_channels = out_channels 354 | return Sequential(*_modules) if len(_modules)>1 else _modules[0] 355 | 356 | def SimpleGenerator(in_channels, base_channels, upsampling_blocks=lambda: torch.nn.ReLU(inplace=False)): 357 | _modules = [] 358 | current_channels = in_channels 359 | base_channels = base_channels * 2**len(upsampling_blocks) 360 | for layers in upsampling_blocks: 361 | if layers-1>0: 362 | _modules.append(SimpleCNNBlock(current_channels, base_channels, layers=layers-1)) 363 | current_channels = base_channels 364 | _modules.append(SimpleCNNBlock(current_channels, base_channels*2)) 365 | _modules.append(PixelShuffleBlock()) 366 | base_channels /= 2 367 
| current_channels = base_channels 368 | _modules.append(Conv2d(base_channels, 3, 1)) 369 | _modules.append(Tanh()) 370 | return Sequential(*_modules) 371 | 372 | 373 | # Backward modules 374 | class B4D_4D(Module): 375 | """Converts 4d to 4d representation, note increases the resolution by 2!""" 376 | def __init__(self, in_channels, out_channels, preprocess_layers=2, keep_resolution=False): 377 | super(B4D_4D, self).__init__() 378 | assert preprocess_layers > 0 379 | self.keep_resolution = keep_resolution 380 | self.preprocess = SimpleCNNBlock(in_channels=in_channels, out_channels=in_channels, 381 | kernel_size=3, layers=preprocess_layers, stride=1, follow_with_bn=True, 382 | activation_fn=lambda: nn.ReLU(inplace=False)) 383 | self.trns = SimpleCNNBlock(in_channels=in_channels, out_channels=4*out_channels if not self.keep_resolution else out_channels, 384 | kernel_size=3, layers=1, stride=1, follow_with_bn=False, 385 | activation_fn=None) 386 | 387 | def forward(self, x): 388 | x = self.preprocess(x) 389 | x = self.trns(x) 390 | if not self.keep_resolution: 391 | x = F.pixel_shuffle(x, 2) 392 | return x 393 | 394 | 395 | 396 | class B2D_4D(Module): 397 | """Note: will always return 2x2""" 398 | def __init__(self, in_channels, out_channels, preprocess_layers=2): 399 | super(B2D_4D, self).__init__() 400 | if preprocess_layers>0: 401 | self.preprocess = SimpleLinearBlock(in_channels=in_channels, out_channels=in_channels, 402 | layers=preprocess_layers, follow_with_bn=True, 403 | activation_fn=lambda: nn.ReLU(inplace=False)) 404 | else: 405 | self.preprocess = None 406 | 407 | self.trns = SimpleLinearBlock(in_channels=in_channels, out_channels=out_channels*4, layers=1, 408 | follow_with_bn=False, activation_fn=None) 409 | 410 | def forward(self, x): 411 | x = self.preprocess(x) if self.preprocess else x 412 | x = self.trns(x) 413 | x = torch.unsqueeze(x, 2) 414 | x = torch.unsqueeze(x, 2) 415 | x = F.pixel_shuffle(x, 2) 416 | return x 417 | 418 | 419 | class B2D_2D(Module): 420 | def __init__(self, in_channels, out_channels, preprocess_layers=2): 421 | super(B2D_2D, self).__init__() 422 | if preprocess_layers>0: 423 | self.preprocess = SimpleLinearBlock(in_channels=in_channels, out_channels=in_channels, 424 | layers=preprocess_layers, follow_with_bn=True, 425 | activation_fn=lambda: nn.ReLU(inplace=False)) 426 | else: 427 | self.preprocess = None 428 | 429 | self.trns = SimpleLinearBlock(in_channels=in_channels, out_channels=out_channels, layers=1, 430 | follow_with_bn=False, activation_fn=None) 431 | 432 | def forward(self, x): 433 | x = self.preprocess(x) if self.preprocess else x 434 | x = self.trns(x) 435 | return x 436 | 437 | 438 | # Veri modules 439 | class Veri4D(Module): 440 | def __init__(self, current_channels, base_channels, downsampling_blocks, activation_fn=lambda: torch.nn.LeakyReLU(negative_slope=0.2, inplace=False)): 441 | super(Veri4D, self).__init__() 442 | _modules = [] 443 | for i, layers in enumerate(downsampling_blocks): 444 | is_last = i==len(downsampling_blocks)-1 445 | if is_last: 446 | _modules.append(SimpleCNNBlock(current_channels, base_channels, layers=layers, activation_fn=activation_fn)) 447 | else: 448 | if layers - 1 > 0: 449 | _modules.append( 450 | SimpleCNNBlock(current_channels, base_channels, layers=layers - 1, activation_fn=activation_fn)) 451 | current_channels = base_channels 452 | base_channels *= 2 453 | _modules.append( 454 | SimpleCNNBlock(current_channels, base_channels, stride=2, activation_fn=activation_fn)) 455 | current_channels = 
base_channels 456 | _modules.append(nn.Conv2d(current_channels, 1, 1)) 457 | _modules.append(nn.Sigmoid()) 458 | self.trns = nn.Sequential(*_modules) 459 | 460 | def forward(self, x): 461 | return self.trns(x) 462 | 463 | 464 | class Veri2D(Module): 465 | def __init__(self, current_channels, fully_connected_blocks, activation_fn=lambda: torch.nn.LeakyReLU(negative_slope=0.2, inplace=False)): 466 | super(Veri2D, self).__init__() 467 | _modules = [] 468 | for new_chans in fully_connected_blocks: 469 | _modules.append(nn.Linear(current_channels, new_chans, bias=False)) 470 | current_channels = new_chans 471 | _modules.append(nn.BatchNorm1d(current_channels)) 472 | _modules.append(activation_fn()) 473 | _modules.append(nn.Linear(current_channels, 1)) 474 | _modules.append(nn.Sigmoid()) 475 | self.trns = nn.Sequential(*_modules) 476 | 477 | def forward(self, x): 478 | return self.trns(x) 479 | 480 | 481 | 482 | class GanGroup(Module): 483 | def __init__(self, discriminators): 484 | super(GanGroup, self).__init__() 485 | self.keys = [] 486 | for idx, module in enumerate(discriminators): 487 | key = str(idx) 488 | self.add_module(key, module) 489 | self.keys.append(key) 490 | self.reals = None 491 | self.fakes = None 492 | 493 | def set_data(self, reals, fakes): 494 | self.reals = reals 495 | self.fakes = fakes 496 | assert len(self.reals)==len(self.keys)==len(self.fakes) 497 | 498 | def generator_loss(self): 499 | return sum(self.calc_adv_loss(self._modules[k](i), 1.) for k, i in zip(self.keys, self.fakes))/len(self.keys) 500 | 501 | def discriminator_loss(self): 502 | losses = [] 503 | for i, k in enumerate(self.keys): 504 | discriminator = self._modules[k] 505 | assert self.reals[i].size(0)==self.fakes[i].size(0) 506 | inp = torch.cat((self.reals[i], self.fakes[i]), dim=0).detach() 507 | r, f = torch.split(discriminator(inp), self.reals[i].size(0), dim=0) 508 | losses.append(self.calc_adv_loss(r, 1.)+self.calc_adv_loss(f, 0.)) 509 | return sum(losses)/len(self.keys) 510 | 511 | def get_train_event(self, disc_steps=3, optimizer=torch_optim.SGD, **optimizer_kwargs): 512 | from .pytorch_trainer import TrainStepEvent, PT 513 | opt = optimizer(self.parameters(), **optimizer_kwargs) 514 | @TrainStepEvent() 515 | def gan_group_train_evnt(s): 516 | for _ in xrange(disc_steps): 517 | opt.zero_grad() 518 | loss = self.discriminator_loss() 519 | PT(disc_loss_=loss) 520 | loss.backward() 521 | opt.step() 522 | return gan_group_train_evnt 523 | 524 | @staticmethod 525 | def calc_adv_loss(x, target, tolerance=0.01): 526 | # target can be either 0 or 1 527 | if target == 1: 528 | return -torch.mean(torch.log(x + tolerance)) 529 | elif target == 0: 530 | return -torch.mean(torch.log((1. 
+ tolerance) - x)) 531 | else: 532 | raise ValueError('target can only be 0 or 1') 533 | 534 | -------------------------------------------------------------------------------- /sal/utils/pytorch_trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import sys 3 | import numpy as np 4 | import os 5 | import torch.nn 6 | import torch.utils.data as torch_data 7 | import torch.optim as torch_optim 8 | from torch.optim.lr_scheduler import StepLR 9 | import torch.nn.modules.loss as losses 10 | 11 | from torch.autograd import Variable 12 | import pycat 13 | 14 | GREEN_STR = '\033[38;5;2m%s\033[0m' 15 | RED_STR = '\033[38;5;1m%s\033[0m' 16 | INFO_TEMPLATE = '\033[38;5;2mINFO: %s\033[0m\n' 17 | WARN_TEMPLATE = '\033[38;5;1mWARNING: %s\033[0m\n' 18 | 19 | assert torch.cuda.is_available(), 'CUDA must be available' 20 | 21 | 22 | 23 | from .pt_store import PTStore, to_number, to_numpy 24 | PT = PTStore() 25 | BATCHES_DONE_INFO = '{batches_done}/{batches_per_epoch}' 26 | TIME_INFO = 'time: {comp_time:.3f} - data: {data_time:.3f} - ETA: {eta:.0f}' 27 | SPEED_INFO = 'e/s: {examples_per_sec:.1f}' 28 | 29 | 30 | 31 | 32 | def smoothing_dict_update(main, update, smooth_coef): 33 | for k, v in update.items(): 34 | if main.get(k) is None: 35 | main[k] = v 36 | else: 37 | main[k] = smooth_coef*main[k] + (1.-smooth_coef)*v 38 | return main 39 | 40 | 41 | class NiceTrainer: 42 | def __init__(self, 43 | forward_step, # forward step takes the output of the transform_inputs 44 | train_dts, 45 | optimizer, 46 | pt_store=PT, 47 | transform_inputs=lambda batch, trainer: batch, 48 | 49 | printable_vars=(), 50 | events=(), 51 | computed_variables=None, 52 | 53 | loss_name='loss', 54 | 55 | val_dts=None, 56 | set_trainable=None, 57 | 58 | modules=None, 59 | save_every=None, 60 | save_dir='ntsave', 61 | 62 | info_string=(BATCHES_DONE_INFO, TIME_INFO, SPEED_INFO), 63 | smooth_coef=0.95, 64 | goodbye_after = 5, 65 | 66 | lr_step_period=None, 67 | lr_step_gamma=0.1, 68 | ): 69 | ''' 70 | 71 | ''' 72 | self.forward_step = forward_step 73 | assert isinstance(train_dts, torch_data.DataLoader), 'train_dts must be an instance of torch.utils.data.DataLoader' 74 | self.train_dts = train_dts 75 | 76 | assert isinstance(optimizer, torch_optim.Optimizer), 'optimizer must be an instance of torch.optim.Optimizer' 77 | self.optimizer = optimizer 78 | assert isinstance(pt_store, PTStore), 'pt_store must be an instance of PTStore' 79 | self.pt_store = pt_store 80 | self.transform_inputs = transform_inputs 81 | 82 | 83 | self.printable_vars = list(printable_vars) 84 | self.events = list(events) if events else [] 85 | assert all(map(lambda x: isinstance(x, BaseEvent), self.events)), 'All events must be instances of the BaseEvent!' 
86 | self.computed_variables = computed_variables if computed_variables is not None else {} 87 | 88 | self.loss_name = loss_name 89 | 90 | 91 | assert val_dts is None or isinstance(val_dts, torch_data.DataLoader), 'val_dts must be an instance of torch.utils.data.DataLoader or None' 92 | self.val_dts = val_dts 93 | if modules is not None: 94 | if not hasattr(modules, '__iter__'): 95 | modules = [modules] 96 | assert modules is None or all(map(lambda x: isinstance(x, torch.nn.Module), modules)), 'The list of modules can only contain instances of torch.nn.Module' 97 | self.modules = modules 98 | if set_trainable is None and self.modules is not None: 99 | def set_trainable(is_training): 100 | for m in self.modules: 101 | m.train(is_training) 102 | 103 | self.set_trainable = set_trainable 104 | 105 | 106 | # todo implement save/restore functionality! 107 | self.save_every = save_every 108 | self.save_dir = save_dir 109 | if self.save_every is not None: 110 | self._add_timed_save_event(save_every) 111 | 112 | self.smooth_coef = smooth_coef 113 | 114 | 115 | self._is_in_train_mode = None 116 | self.goodbye_after = goodbye_after 117 | 118 | self._extra_var_info_string = (info_string if isinstance(info_string, basestring) else ' - '.join(info_string)) + ( 119 | (' - ' if self.printable_vars else '') + ' - '.join('%s: {%s:.4f}' % (e, e) for e in self.printable_vars)) 120 | 121 | 122 | 123 | self.info_vars = dict( 124 | epochs_done=0, 125 | batches_done=0, 126 | total_batches_done=0, 127 | batch_size=0, 128 | total_examples_done=0, 129 | batches_per_sec=float('nan'), 130 | examples_per_sec=float('nan'), 131 | batches_per_epoch=float('nan'), 132 | eta=float('nan'), 133 | data_time=float('nan'), 134 | comp_time=float('nan'), 135 | is_training=None, 136 | ) 137 | self._core_info_vars = set(self.info_vars.keys()) 138 | 139 | if lr_step_period is not None: 140 | self._add_lr_scheduling_event(lr_step_period, lr_step_gamma) 141 | 142 | def _add_timed_save_event(self, period): 143 | @TimeEvent(period=period, first_at=period) 144 | def periodic_save(s): 145 | print 146 | print INFO_TEMPLATE % 'Performing a periodic save...' 147 | s.save() 148 | self.events.append(periodic_save) 149 | 150 | def _add_lr_scheduling_event(self, period, gamma): 151 | @TrainStepEvent() 152 | @EpochChangeEvent() 153 | def lr_step_event(s): 154 | if not hasattr(s, '_lr_sheduler') or s._lr_sheduler is None: 155 | s._lr_sheduler = StepLR(s.optimizer, period, gamma) 156 | s._lr_sheduler.step(epoch=s.info_vars['epochs_done']) 157 | for param_group in s.optimizer.param_groups: 158 | print INFO_TEMPLATE % ('LR ' + str(param_group['lr'])) 159 | self.events.append(lr_step_event) 160 | 161 | 162 | 163 | 164 | def _main_loop(self, is_training, steps=None, allow_switch_mode=True): 165 | """Trains for 1 epoch if steps is None. Otherwise performs specified number of steps.""" 166 | if steps is not None: print WARN_TEMPLATE % 'Num steps is not fully supported yet! 
(fix it!)' # todo allow continue and partial execution 167 | if not is_training: 168 | assert self.val_dts is not None, 'Validation dataset was not provided' 169 | if allow_switch_mode and self._is_in_train_mode != is_training: 170 | if self.set_trainable is not None: 171 | self.set_trainable(is_training=is_training) 172 | self._is_in_train_mode = is_training 173 | else: 174 | if is_training: 175 | print WARN_TEMPLATE % "could not set the modules to the training mode because neither set_trainable nor modules were provided, assuming already in the training mode" 176 | self._is_in_train_mode = True 177 | else: 178 | raise ValueError("cannot set the modules to the eval mode because neither set_trainable nor modules were provided") 179 | 180 | dts = self.train_dts if is_training else self.val_dts 181 | dts_iter = iter(dts) 182 | smooth_burn_in = 0.8 * self.smooth_coef 183 | smooth_normal = self.smooth_coef 184 | 185 | smoothing_dict = dict( 186 | comp_time=None, 187 | data_time=None, 188 | ) 189 | smoothing_dict.update(dict(zip(self.printable_vars, len(self.printable_vars)*[None]))) 190 | 191 | 192 | batches_per_epoch = len(dts) 193 | batch_size = dts.batch_size 194 | steps_done_here = 0 195 | batches_done = 0 # todo allow continue! 196 | 197 | self.info_vars.update(dict( 198 | batch_size=batch_size, 199 | batches_per_epoch=batches_per_epoch, 200 | batches_done=batches_done, 201 | is_training=is_training, 202 | )) 203 | 204 | 205 | 206 | last_hearbeat = [time.time()+5.] 207 | # def guard(): 208 | # while 1: 209 | # time.sleep(1) 210 | # last = last_hearbeat[0] 211 | # if last is None: 212 | # break 213 | # if time.time()-last>self.goodbye_after: 214 | # print 'We have not received any heartbeat update since last %d seconds. Time to say goodbye!' % self.goodbye_after 215 | # os.system('kill %d' % os.getpid()) 216 | # break 217 | # t = threading.Thread(target=guard) 218 | # t.daemon = True 219 | # t.start() 220 | 221 | 222 | 223 | smoothed_printable_vars = {} 224 | 225 | t_fetch = time.time() 226 | for batch in dts_iter: 227 | last_hearbeat[0] = time.time() 228 | 229 | self.pt_store.clear() 230 | batch = self.transform_inputs(batch, self) 231 | self.pt_store.batch = batch 232 | 233 | torch.cuda.synchronize() 234 | t_start = time.time() 235 | # --------------------------- OPTIMIZATION STEP ------------------------------------ 236 | if is_training: 237 | self.optimizer.zero_grad() 238 | 239 | self.forward_step(batch) 240 | loss = getattr(self.pt_store, self.loss_name) 241 | 242 | if is_training: 243 | loss.backward() 244 | self.optimizer.step() 245 | # ---------------------------------------------------------------------------------- 246 | torch.cuda.synchronize() 247 | t_end = time.time() 248 | 249 | # call events 250 | for event in self.events: 251 | event(self) 252 | 253 | 254 | # smoothing coef should be relatively small at the start because the starting values are very unstable 255 | # todo use bias correction by dividing by (1-smooth^t) 256 | smooth_coef = smooth_normal if steps_done_here > 22 else smooth_burn_in 257 | 258 | # important to smooth computation times to have a nice estimates 259 | smoothing_update = dict( 260 | comp_time= t_end - t_start, 261 | data_time= t_start - t_fetch, 262 | ) 263 | 264 | # calculate computed variables and add them to the pt_store 265 | for var_name, func in self.computed_variables.items(): 266 | setattr(self.pt_store, var_name, func(self)) 267 | 268 | # add ALL the printable variables to the smoother 269 | 
smoothing_update.update({e:to_number(self.pt_store[e]) for e in self.printable_vars if e not in self._core_info_vars}) 270 | 271 | # perform smoother update 272 | smoothing_dict_update(smoothed_printable_vars, smoothing_update, smooth_coef) 273 | 274 | # calculate 275 | batches_per_sec = 1./ (smoothed_printable_vars['comp_time']+smoothed_printable_vars['data_time']) 276 | batches_done += 1 277 | steps_done_here += 1 278 | eta = (batches_per_epoch - batches_done) / batches_per_sec 279 | 280 | # update info vars after iteration step 281 | self.info_vars.update(dict( 282 | batches_done=batches_done, 283 | total_batches_done=self.info_vars['total_batches_done']+1, 284 | total_examples_done=self.info_vars['total_examples_done']+batch_size, 285 | batches_per_sec=batches_per_sec, 286 | examples_per_sec=batches_per_sec*batch_size, 287 | eta=eta, 288 | )) 289 | # info vars should also contain all the printable vars 290 | self.info_vars.update(smoothed_printable_vars) 291 | 292 | 293 | 294 | 295 | formatted_info_string = self._extra_var_info_string.format(**self.info_vars) 296 | sys.stdout.write('\r' + formatted_info_string) 297 | sys.stdout.flush() 298 | 299 | t_fetch = t_end 300 | if steps is not None and steps <= steps_done_here: 301 | break 302 | else: 303 | if is_training: 304 | self.info_vars['epochs_done'] += 1 305 | 306 | if steps is not None: # really annoying that pytorch does not handle this on its own... todo improve this! 307 | dts_iter._shutdown_workers() 308 | time.sleep(11) 309 | for e in dts_iter.workers: 310 | e.terminate() 311 | time.sleep(2) 312 | del dts_iter 313 | 314 | 315 | # we have left the dangerous loop, quit the guarding process... 316 | last_hearbeat[0] = None 317 | sys.stdout.write('\n') 318 | sys.stdout.flush() 319 | sys.stdout.write('\n') 320 | sys.stdout.flush() 321 | 322 | def train(self, steps=None): 323 | if steps is None: 324 | print '_'*55 325 | print 'Epoch', self.info_vars['epochs_done']+1 326 | self._main_loop(is_training=True, steps=steps, allow_switch_mode=True) 327 | 328 | def validate(self, allow_switch_mode=False): 329 | old_info = self.info_vars.copy() 330 | print "Validation:" 331 | self._main_loop(is_training=False, steps=None, allow_switch_mode=allow_switch_mode) 332 | self.info_vars = old_info 333 | 334 | 335 | def _get_state(self): 336 | return dict( 337 | info_vars={k:v for k, v in self.info_vars.items() if k in self._core_info_vars}, 338 | state_dicts=[m.state_dict() for m in self.modules], 339 | optimizer_state=self.optimizer.state_dict(), 340 | ) 341 | 342 | def _set_state(self, state): 343 | self.info_vars = state['info_vars'] 344 | self.optimizer.load_state_dict(state['optimizer_state']) 345 | if len(self.modules)!=len(state['state_dicts']): 346 | raise ValueError('The number of save dicts is different from number of models') 347 | for m, s in zip(self.modules, state['state_dicts']): 348 | m.load_state_dict(s) 349 | 350 | def save(self, step=1): 351 | if not self.modules: 352 | raise ValueError("nothing to save - the list of modules was not provided") 353 | if not os.path.exists(self.save_dir): 354 | os.mkdir(self.save_dir) 355 | torch.save(self._get_state(), os.path.join(self.save_dir, 'model-%d.ckpt'%step)) 356 | 357 | 358 | def restore(self, step=1): 359 | if not self.modules: 360 | raise ValueError("nothing to load - the list of modules was not provided") 361 | p = os.path.join(self.save_dir, 'model-%d.ckpt' % step) 362 | if not os.path.exists(p): 363 | return 364 | self._set_state(torch.load(p)) 365 | 366 | # 367 | # 368 | # def 
restore(self, allow_restore_crash=True, relaxed=False): 369 | # """ If you set allow_restore_crash to True we will 370 | # check whether automatic periodic save was made after standard save and if this is the case 371 | # we will continue from periodic save.""" 372 | # assert self.saver is not None, 'You must specify saver if you want to use restore' 373 | # 374 | # std_checkpoint = tf.train.get_checkpoint_state(self.save_dir) 375 | # if std_checkpoint and std_checkpoint.model_checkpoint_path: 376 | # step = int(std_checkpoint.model_checkpoint_path.split('-')[-1]) 377 | # std_nt = cPickle.load(open(os.path.join(self.save_dir, 'nice_trainer%d'%step), 'rb')) 378 | # else: 379 | # std_nt = None 380 | # 381 | # if allow_restore_crash: 382 | # # check periodic save folder 383 | # periodic_check_dir = os.path.join(self.save_dir, 'periodic_check') 384 | # periodic_checkpoint = tf.train.get_checkpoint_state(periodic_check_dir) 385 | # if periodic_checkpoint and periodic_checkpoint.model_checkpoint_path: 386 | # periodic_nt = cPickle.load(open(os.path.join(periodic_check_dir, 'nice_trainer%d' % 0), 'rb')) 387 | # else: 388 | # periodic_nt = None 389 | # if periodic_nt is not None and (std_nt is None or std_nt['last_save_time'] < periodic_nt['last_save_time']): 390 | # # restore from crash 391 | # print 'Restoring from periodic save (maybe in the middle of the epoch). Training will be continued.' 392 | # self._restore(periodic_checkpoint.model_checkpoint_path, relaxed) 393 | # self._restore_nt(periodic_nt, continue_epoch=True) 394 | # return 395 | # 396 | # if std_checkpoint and std_checkpoint.model_checkpoint_path: 397 | # print 'Loading model from', std_checkpoint.model_checkpoint_path 398 | # self._restore(std_checkpoint.model_checkpoint_path, relaxed) 399 | # self._restore_nt(std_nt) 400 | # return 401 | # else: 402 | # print 'No saved models to restore from' 403 | # return 404 | # 405 | # def _restore(self, path, relaxed): 406 | # if not relaxed: 407 | # self.saver.restore(self.sess, path) 408 | # else: 409 | # optimistic_restore(self.sess, path, var_list=self.saver._var_list) 410 | # 411 | # def _save_nt(self, save_path): 412 | # nt = { 413 | # 'last_save_time': self._last_save_time, 414 | # 'save_num': self._save_num, 415 | # 'epoch': self._epoch, 416 | # 'measured_batches_per_sec': self.measured_batches_per_sec, 417 | # 'train_bm_state': self.bm_train.get_state(), 418 | # 'logs': self.logs, 419 | # } 420 | # cPickle.dump(nt, open(save_path, 'wb')) 421 | # 422 | # 423 | # 424 | # def _restore_nt(self, old_nt, continue_epoch=False): 425 | # self._epoch = old_nt['epoch'] 426 | # self._save_num = old_nt['save_num'] 427 | # self.measured_batches_per_sec = old_nt['measured_batches_per_sec'] 428 | # self.logs = old_nt.get('logs', self.logs) 429 | # 430 | # if continue_epoch: 431 | # self._epoch -= 1 432 | # # now make changes to the train batch manager... 
433 | #         self.bm_train.continue_from_state(old_nt['train_bm_state'])
434 | 
435 | 
436 | 
437 | class BaseEvent:
438 |     def __init__(self):
439 |         self.func = None
440 | 
441 |     def __call__(self, func_or_trainer):
442 |         if self.func is None:
443 |             self.func = func_or_trainer
444 |             return self
445 |         if self.should_run(func_or_trainer):
446 |             self.func(func_or_trainer)
447 | 
448 |     def should_run(self, trainer):
449 |         raise NotImplementedError()
450 | 
451 | 
452 | class StepEvent(BaseEvent):
453 |     def should_run(self, trainer):
454 |         return True
455 | 
456 | class ValStepEvent(BaseEvent):
457 |     def should_run(self, trainer):
458 |         return not trainer.info_vars['is_training']
459 | 
460 | class TrainStepEvent(BaseEvent):
461 |     def should_run(self, trainer):
462 |         return trainer.info_vars['is_training']
463 | 
464 | 
465 | class EpochChangeEvent(BaseEvent):
466 |     def __init__(self, call_on_first=True):
467 |         BaseEvent.__init__(self)
468 |         self.last_epoch = None
469 |         self.call_on_first = call_on_first
470 | 
471 |     def should_run(self, trainer):
472 |         current = trainer.info_vars['epochs_done']
473 |         if self.last_epoch is None:
474 |             self.last_epoch = current
475 |             if self.call_on_first:
476 |                 return True
477 |             else:
478 |                 return False
479 |         if self.last_epoch!=current:
480 |             self.last_epoch = current
481 |             return True
482 |         return False
483 | 
484 | 
485 | class EveryNthEvent(BaseEvent):
486 |     def __init__(self, every_n, required_remainder=0):
487 |         # required_remainder=1 in order to trigger at the first call.
488 |         BaseEvent.__init__(self)
489 |         self.every_n = every_n
490 |         self.required_remainder = required_remainder % every_n
491 |         self.count = 0
492 | 
493 |     def should_run(self, trainer):
494 |         self.count += 1
495 |         return self.count%self.every_n == self.required_remainder
496 | 
497 | 
498 | 
499 | class TimeEvent(BaseEvent):
500 |     def __init__(self, period, first_at=0):
501 |         assert first_at <= period
502 |         BaseEvent.__init__(self)
503 |         self.period = period
504 |         self.last = time.time()-(period-first_at)
505 | 
506 |     def should_run(self, trainer):
507 |         t = time.time()
508 |         if t>=self.last+self.period:
509 |             self.last = t
510 |             return True
511 |         return False
512 | 
513 | 
514 | def auto_norm(im, auto_normalize=True, auto_fix=True):
515 |     if im.shape[0]==1 and auto_fix:
516 |         im = np.concatenate((np.zeros_like(im), np.zeros_like(im), im), axis=0)
517 |     if im.shape[2]==1 and auto_fix:
518 |         im = np.concatenate((np.zeros_like(im), np.zeros_like(im), im), axis=2)
519 |     if im.shape[-1]!=3 and auto_fix:
520 |         if len(im.shape) == 3:
521 |             im = np.transpose(im, (1, 2, 0))
522 |         else:
523 |             im = np.transpose(im, (0, 2, 3, 1))
524 |     min_v = np.min(im)
525 |     max_v = np.max(im)
526 |     if 0 [...]
[...]
31 | [...] num_validation_examples:
32 |     break
33 | print 'Top 1 accuracy:', np.mean(scores)
--------------------------------------------------------------------------------
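
The training loop in `_main_loop` above smooths the per-batch timings and all printable variables with an exponential moving average (via `smoothing_dict_update`, using a reduced coefficient for roughly the first 22 burn-in steps) and leaves a todo about bias correction. The sketch below is only an illustration of the bias-corrected update that the comment hints at; it assumes the coefficient weights the previous average, and the helper name `ema_update` is not part of this repository.

def ema_update(state, new_values, smooth_coef):
    # state maps name -> (ema, step); new_values maps name -> latest raw value.
    # Returns bias-corrected estimates, per the "todo use bias correction by
    # dividing by (1-smooth^t)" note in _main_loop.
    corrected = {}
    for name, value in new_values.items():
        ema, step = state.get(name, (0.0, 0))
        ema = smooth_coef * ema + (1.0 - smooth_coef) * value
        step += 1
        state[name] = (ema, step)
        # The raw EMA starts biased towards its zero initialisation; dividing by
        # (1 - smooth_coef**step) removes that bias during the first batches.
        corrected[name] = ema / (1.0 - smooth_coef ** step)
    return corrected

state = {}
print(ema_update(state, {'comp_time': 0.21, 'data_time': 0.05}, smooth_coef=0.9))
# the first call already returns the raw values instead of estimates shrunk towards 0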
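
The `BaseEvent` subclasses above double as decorators: the first call stores the wrapped callable and returns the event object, and every later call receives the trainer and is gated by `should_run`, exactly as `lr_step_event` stacks `@TrainStepEvent()` and `@EpochChangeEvent()`. A hedged usage sketch follows; the `log_progress` callback and the `trainer` variable are assumptions made purely for illustration.

@TrainStepEvent()            # runs only while info_vars['is_training'] is True
@EveryNthEvent(every_n=100)  # ...and then only on every 100th gated call
def log_progress(t):
    # t is the trainer; _main_loop invokes every registered event once per batch
    print('total batches done:', t.info_vars['total_batches_done'])

trainer.events.append(log_progress)  # after decoration this name refers to the outer TrainStepEvent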
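
`save` and `restore` above write and read `model-%d.ckpt` files containing the module state dicts, the optimizer state and the core progress counters. A possible driver loop is sketched below, assuming the trainer was built with a non-empty module list and a `save_dir` such as the bundled `minsaliency` directory; the epoch count and the `trainer` name are illustrative and not taken from the repository's training scripts.

trainer.restore(step=1)      # silently does nothing if save_dir/model-1.ckpt is missing
for _ in range(15):
    trainer.train()          # one full epoch per call when steps is None
    trainer.validate()       # requires val_dts to have been provided
    trainer.save(step=1)     # overwrites save_dir/model-1.ckpt with the latest state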