├── .gitignore
├── 4334173592_145856d89b.jpg
├── 5595774449_b3f85b36ec.jpg
├── README.md
├── test.ipynb
├── test_debug.ipynb
└── torchvis
    ├── __init__.py
    └── util.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# Jupyter Notebook
.ipynb_checkpoints

# model files
/vgg16.pkl

--------------------------------------------------------------------------------
/4334173592_145856d89b.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leelabcnbc/cnnvis-pytorch/503fd835c378f70a82cd476e79ccef549cba421a/4334173592_145856d89b.jpg

--------------------------------------------------------------------------------
/5595774449_b3f85b36ec.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leelabcnbc/cnnvis-pytorch/503fd835c378f70a82cd476e79ccef549cba421a/5595774449_b3f85b36ec.jpg

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# cnnvis-pytorch
visualization of CNN in PyTorch

This project is inspired by a summary of visualization methods in the
[Lasagne examples](https://github.com/Lasagne/Recipes/blob/master/examples/Saliency%20Maps%20and%20Guided%20Backpropagation.ipynb),
as well as by the [deep visualization toolbox](https://github.com/yosinski/deep-visualization-toolbox).

Visualization of CNN units in higher layers is important for my work, and currently (May 2017) I'm not
aware of any library written for PyTorch with capabilities similar to the two mentioned above. For a discussion of
this feature in PyTorch, see

I have some experience with the deep visualization toolbox, but it only supports Caffe.
It also has very poor support for networks whose input size is not around 256x256x3
(the standard size for the ImageNet dataset, before cropping),
and I need to visualize networks with other input sizes, such as networks trained on CIFAR-10.
In addition, it supports no visualization technique other than "deconvolution". Converting PyTorch models
to Caffe and then hacking the code of the deep visualization toolbox to make it work is therefore
probably not worthwhile.

Some people have tried doing this kind of visualization in TensorFlow.
See .
However, TensorFlow involves a lot of boilerplate, and in general I'm not familiar with it. Given that boilerplate,
figuring out how to apply existing visualization code to my particular models, adapted to my particular needs,
would probably take more of my time than writing a pure PyTorch solution.

## Implementation

The visualization is implemented mainly through forward and backward hooks on `torch.nn.Module`. Since most
visualization techniques work by fiddling with ReLU layers, the code should work as long as your ReLU layers,
as well as the layers containing the units you are interested in, are implemented as `torch.nn.Module` instances
rather than `torch.nn.functional` calls, as sketched below.
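For example, in the first toy network below the ReLU is an `nn.Module` and can therefore be hooked, while the second hides it behind a functional call and cannot. This is only a minimal sketch; the class names and layer sizes are made up for illustration and are not part of this repository.

```python
import torch.nn as nn
import torch.nn.functional as F


class HookFriendlyNet(nn.Module):
    """ReLU is an nn.Module instance, so forward/backward hooks can be attached to it."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 8)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(8, 4)

    def forward(self, x):
        return self.fc2(self.relu1(self.fc1(x)))


class HookUnfriendlyNet(nn.Module):
    """F.relu leaves no module to hook, so its gradient cannot be intercepted."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 8)
        self.fc2 = nn.Linear(8, 4)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))
```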

### Alternatives

While it's also possible to define modified ReLU layers,
as [suggested](https://discuss.pytorch.org/t/inherit-from-autograd-function/2117/2) by PyTorch developers,
this may break other code, since autograd assumes that gradients are computed correctly.

--------------------------------------------------------------------------------
/torchvis/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leelabcnbc/cnnvis-pytorch/503fd835c378f70a82cd476e79ccef549cba421a/torchvis/__init__.py

--------------------------------------------------------------------------------
/torchvis/util.py:
--------------------------------------------------------------------------------
"""utilities for attaching visualization hooks to a network"""
from collections import OrderedDict
from enum import Enum, auto
from functools import partial

import torch
import torch.nn as nn
from torch.autograd import Variable


# class LayerType(Enum):
#     RELU = auto()
#     OTHER = auto()
#

# In general, the code assumes that each module is only called once.

class GradType(Enum):
    # here, comments follow those in
    #
    NAIVE = auto()   # Simonyan et al. (2013): Plain Gradient
    GUIDED = auto()  # Springenberg et al. (2014): Guided Backpropagation
    DECONV = auto()  # Zeiler & Fergus (2014): Deconvolution ("deconvnet")


def augment_module(net: nn.Module):
    layer_dict, remove_forward = _augment_module_pre(net)
    vis_param_dict, remove_backward = _augment_module_post(net, layer_dict)

    def remove_handles():
        remove_forward()
        remove_backward()

    def reset_state():
        for x, y in layer_dict.items():
            print('clearing {}'.format(x))
            assert isinstance(y, dict)
            y.clear()

    return vis_param_dict, reset_state, remove_handles


def _forward_hook(m, in_, out_, module_name, callback_dict):
    # if callback_dict[module_name]['type'] == LayerType.RELU:
    assert isinstance(out_, Variable)
    assert 'output' not in callback_dict[module_name], 'same module called twice!'
    # store a Tensor so that during backward
    # I don't have to think about moving numpy arrays to/from devices.
    callback_dict[module_name]['output'] = out_.data.clone()
    print(module_name, callback_dict[module_name]['output'].size())


def _augment_module_pre(net: nn.Module) -> (dict, callable):
    callback_dict = OrderedDict()  # ordering isn't strictly needed, but it can help readability.
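    # callback_dict maps each ReLU module's name to a small per-module dict;
    # _forward_hook stores that module's forward output there under the key 'output',
    # and _backward_hook later reads it back to decide which units were active.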

    forward_hook_remove_func_list = []

    for x, y in net.named_modules():
        if not isinstance(y, nn.Sequential) and y is not net:
            if isinstance(y, nn.ReLU):
                callback_dict[x] = {}
                forward_hook_remove_func_list.append(
                    y.register_forward_hook(partial(_forward_hook, module_name=x, callback_dict=callback_dict)))

    def remove_handles():
        for x in forward_hook_remove_func_list:
            x.remove()

    return callback_dict, remove_handles


def _backward_hook(m: nn.Module, grad_in_, grad_out_, module_name, callback_dict, vis_param_dict):
    # print(module_name)
    # assert isinstance(grad_in_, tuple) and isinstance(grad_out_, tuple)
    # print('in', [z.size() if z is not None else None for z in grad_in_])
    # print('out', [z.size() if z is not None else None for z in grad_out_])

    # set grad for the layer
    (layer, index, method) = (vis_param_dict['layer'],
                              vis_param_dict['index'],
                              vis_param_dict['method'])
    if module_name not in callback_dict and layer != module_name:
        # print(module_name, 'SKIP')
        return
    # print(module_name, type(m), 'WORKING', isinstance(m, nn.Linear), isinstance(m, nn.ReLU))
    # sanity check.
    assert isinstance(grad_in_, tuple) and isinstance(grad_out_, tuple)
    # just for sanity checking; I don't want to confuse Variable and Tensor.
    for z in grad_in_:
        assert isinstance(z, Variable)
    for z in grad_out_:
        assert isinstance(z, Variable)
    # print('in', [z.size() if z is not None else None for z in grad_in_])
    # print('out', [z.size() if z is not None else None for z in grad_out_])
    assert len(grad_out_) == 1

    # first, work on the actual grad_out. clone for safety.
    grad_out_actual = grad_out_[0].clone()
    if layer == module_name:
        # then hack a grad_out with 1 at the selected unit.
        print('change grad!')
        # print('in', [z.size() if z is not None else None for z in grad_in_])
        # print(grad_in_[0].std(), grad_in_[0].max(), grad_in_[0].min())
        assert index is not None
        # then you should set them.
        grad_out_actual.data.zero_()
        for var_ in grad_out_actual:
            var_[index] = 1

    # otherwise, using the actual gradient as-is is fine.
    # ok, now time to compute the fake gradient.
    # first case: ReLU.
    if isinstance(m, nn.ReLU):
        new_grad = grad_out_actual
        # here, you need to work on the stored forward response.
        response = Variable(callback_dict[module_name]['output'])
        if method == GradType.NAIVE:
            new_grad[response <= 0] = 0
        elif method == GradType.GUIDED:
            new_grad[response <= 0] = 0
            new_grad[grad_out_actual <= 0] = 0
        elif method == GradType.DECONV:
            new_grad[grad_out_actual <= 0] = 0
        else:
            raise ValueError('unsupported yet!')
    elif isinstance(m, nn.Linear):
        w = None
        for w in m.parameters():
            break
        # I think for Linear, it's always the first parameter that is the weight.
        # it should be of size output x input.
        assert w is not None
        # grad_in_[0] is the grad w.r.t. the previous layer.
        # grad_in_[1] is the grad w.r.t. the weight.
        assert tuple(w.size()) == (grad_out_actual.size()[1], grad_in_[0].size()[1]) == tuple(
            grad_in_[1].size())
        # then let's do the multiplication myself.
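        # nn.Linear computes y = x @ W^T + b with W of shape (out_features, in_features),
        # so the gradient w.r.t. the input is grad_out @ W; the torch.mm below recomputes
        # exactly that from grad_out_actual (which may be the hacked one-hot gradient).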
        new_grad = torch.mm(grad_out_actual, w)
    else:
        raise TypeError('must be ReLU or Linear')
    if layer != module_name:
        if isinstance(m, nn.Linear) or (isinstance(m, nn.ReLU) and method == GradType.NAIVE):
            # check that my gradient matches the one computed by autograd.
            print('check grad debug!')
            # use .data as Variable doesn't give a scalar properly right now.
            # see
            assert (new_grad - grad_in_[0]).abs().data.max() < 1e-4
    return (new_grad,) + grad_in_[1:]


def _augment_module_post(net: nn.Module, callback_dict: dict) -> (dict, callable):
    backward_hook_remove_func_list = []

    vis_param_dict = dict()
    vis_param_dict['layer'] = None
    vis_param_dict['index'] = None
    vis_param_dict['method'] = GradType.NAIVE

    for x, y in net.named_modules():
        if not isinstance(y, nn.Sequential) and y is not net:
            # add the hook to all layers, in case they are needed.
            backward_hook_remove_func_list.append(
                y.register_backward_hook(
                    partial(_backward_hook, module_name=x, callback_dict=callback_dict, vis_param_dict=vis_param_dict)))

    def remove_handles():
        for x in backward_hook_remove_func_list:
            x.remove()

    return vis_param_dict, remove_handles

--------------------------------------------------------------------------------
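Below is a rough usage sketch of `augment_module`. It is not part of the repository: the toy network, the layer name `'2'`, and the unit index are illustrative only, and it uses the pre-0.4 `Variable` API that the code above is written against.

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

from torchvis.util import augment_module, GradType

# toy network made only of Linear and ReLU modules, which is what the backward hook supports
net = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))

vis_param_dict, reset_state, remove_handles = augment_module(net)
vis_param_dict['layer'] = '2'                # name of the layer holding the unit of interest
vis_param_dict['index'] = 0                  # visualize unit 0 of that layer
vis_param_dict['method'] = GradType.GUIDED   # or GradType.NAIVE / GradType.DECONV

x = Variable(torch.randn(1, 16), requires_grad=True)
out = net(x)
out.sum().backward()       # the hook replaces the target layer's grad_out with a one-hot vector
saliency = x.grad.data     # gradient of the chosen unit with respect to the input

reset_state()              # clear stored ReLU activations before processing the next input
remove_handles()           # detach all hooks when done
```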