├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── dissect_all.sh ├── experiment ├── __init__.py ├── dissect_experiment.py ├── generator_int_experiment.py ├── intervention_experiment.py ├── oldalexnet.py ├── oldresnet152.py ├── oldvgg16.py ├── proggan.py ├── readdissect.py ├── setting.py └── shapebias_experiment.py ├── g_intervention.sh ├── intervention.sh ├── netdissect ├── __init__.py ├── bargraph.py ├── easydict.py ├── imgsave.py ├── imgviz.py ├── labwidget.py ├── nethook.py ├── paintwidget.py ├── parallelfolder.py ├── pbar.py ├── pidfile.py ├── renormalize.py ├── report.html ├── runningstats.py ├── sampler.py ├── segmenter.py ├── segmodel │ ├── __init__.py │ ├── colors150.npy │ ├── mobilenet.py │ ├── models.py │ ├── object150_info.csv │ ├── resnet.py │ └── resnext.py ├── segviz.py ├── show.py ├── tally.py ├── upsample.py ├── upsegmodel │ ├── __init__.py │ ├── models.py │ ├── prroi_pool │ │ ├── .gitignore │ │ ├── README.md │ │ ├── __init__.py │ │ ├── build.py │ │ ├── functional.py │ │ ├── prroi_pool.py │ │ ├── src │ │ │ ├── prroi_pooling_gpu.c │ │ │ ├── prroi_pooling_gpu.h │ │ │ ├── prroi_pooling_gpu_impl.cu │ │ │ └── prroi_pooling_gpu_impl.cuh │ │ └── test_prroi_pooling2d.py │ ├── resnet.py │ └── resnext.py ├── workerpool.py └── zdataset.py ├── notebooks ├── adv_experiment_bedroom.ipynb ├── adv_experiment_plot.ipynb ├── dissect_classifier_experiment.ipynb ├── dissect_generator_experiment.ipynb ├── intervention-classifier-experiment.ipynb ├── intervention-generator-experiment.ipynb ├── ipynb_drop_output.py ├── setup_notebooks.sh ├── shapebias_experiment.ipynb ├── single-classifier-unit-plot.ipynb └── single-generator-unit-plot.ipynb ├── setup ├── denv.yml └── setup_denv.sh ├── stylization ├── LICENSE ├── README.md ├── function.py ├── models │ └── download_models.sh ├── net.py ├── parallel-make-sp.sh ├── stylize.py └── torch_to_pytorch.py └── www ├── arxiv-thumb.png ├── classifier-dissection.png ├── classifier-intervention.png ├── dissection-compare.png ├── gandissect-tutorial.png ├── generator-dissection.png ├── generator-intervention.png ├── netdissect-tutorial.png ├── netdissect_code.png ├── paper-thumb.png ├── si-thumb.png └── website-thumb.png /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb filter=clean_ipynb 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | /datasets 3 | /results 4 | /notebooks/datasets 5 | /notebooks/results 6 | /notebooks/netdissect 7 | /notebooks/experiment 8 | /notebooks/unused 9 | /stylization/data 10 | /stylization/models/*.pth 11 | __pycache__ 12 | .DS_Store 13 | .__* 14 | .ipynb* 15 | .nfs* 16 | .*swp 17 | .idea 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Most files in this directory (stylization/) are either directly copied from the 2 | pytorch-AdaIN repository (https://github.com/naoto0804/pytorch-AdaIN) 3 | or adapted slightly. 
The following license applies to these files: 4 | 5 | MIT License 6 | 7 | Copyright (c) 2018 Naoto Inoue 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining a copy 10 | of this software and associated documentation files (the "Software"), to deal 11 | in the Software without restriction, including without limitation the rights 12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | copies of the Software, and to permit persons to whom the Software is 14 | furnished to do so, subject to the following conditions: 15 | 16 | The above copyright notice and this permission notice shall be included in all 17 | copies or substantial portions of the Software. 18 | 19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | SOFTWARE. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # What is the Role of a Neuron? 2 | 3 | When a deep network is trained on a high-level task such as classifying a place or synthesizing a scene, individual neural units within the network will often emerge that match specific human-interpretable concepts, like "trees", "windows", or "human faces." 4 | 5 | What role do such individual units serve within a deep network? 6 | 7 | We examine this question in two types of networks that contain interpretable units: networks trained to classify images of scenes (supervised image classifiers), and networks trained to synthesize images of scenes (generative adversarial networks). 8 | 9 | [**Understanding the Role of Individual Units in a Deep Network**](https://dissect.csail.mit.edu/).
10 | [David Bau](https://people.csail.mit.edu/davidbau/home/), [Jun-Yan Zhu](https://www.cs.cmu.edu/~junyanz/), [Hendrik Strobelt](http://hendrik.strobelt.com/), [Agata Lapedriza](https://www.media.mit.edu/people/agata/overview/), [Bolei Zhou](http://bzhou.ie.cuhk.edu.hk/), [Antonio Torralba](http://web.mit.edu/torralba/www/).
11 | Proceedings of the National Academy of Sciences, September 2020.
12 | MIT, MIT-IBM Watson AI Lab, IBM Research, The Chinese University of Hong Kong, Adobe Research 13 | 14 | 15 | 16 | 17 | 18 | 19 |

PNAS Paper · Supplemental · Website · arXiv
20 | 21 | 22 | ## Dissecting Units in Classifiers and Generators 23 | 24 | Network dissection compares individual network units to the predictions of a semantic segmentation network that can label pixels with a broad set of object, part, material, and color classes. This technique gives us a standard and scalable way to identify any units within the networks that match those same semantic classes. 25 | 26 | It works both in classification settings where the image is the input, and in generative settings where the image is the output. 27 | 28 | ![Dissection](/www/classifier-dissection.png) 29 | 30 | We find that both state-of-the-art GANs and classifiers contain object-matching units that correspond to a variety of object and part concepts, with semantics emerging in different layers. 31 | 32 | ![Comparing a Classifier to a Generator](/www/dissection-compare.png) 33 | 34 | To investigate the role of such units within classifiers, we measure the impact on the accuracy of the network when we turn off units individually or in groups. We find that removing as few as 20 units can destroy the network's ability to detect a class, but retaining only those 20 units and removing 492 other units in the same layer can keep the network's accuracy on that same class mostly intact. Furthermore, we find that those units that are important for the largest number of output classes are also the emergent units that match human-interpretable concepts best. 35 | 36 | ![Classifier Intervention Experiments](/www/classifier-intervention.png) 37 | 38 | In a generative network, we can understand the causal effects of neurons by observing changes to output images when sets of units are turned on and off. We find causal effects are strong enough to enable users to paint images out of object classes by activating neurons; we also find that some units reveal interactions between objects and specific contexts within a model. 39 | 40 | ![Genereator Intervention Experiments](/www/generator-intervention.png) 41 | 42 | ## Citation 43 | 44 | Bau, David, Jun-Yan Zhu, Hendrik Strobelt, Agata Lapedriza, Bolei Zhou, and Antonio Torralba. *Understanding the role of individual units in a deep neural network.* Proceedings of the National Academy of Sciences (2020). 
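## Example: Turning Off Units in a Classifier

The intervention experiments above come down to hooking one layer of a classifier and zeroing out a chosen set of channels. The sketch below illustrates that idea with the `netdissect.nethook` and `experiment.setting` helpers in this repository. It is a minimal sketch, not the full experiment: it assumes a CUDA machine and a working directory at the repository root, it downloads the Places validation set on first use, and the unit indices are made-up placeholders rather than numbers from the paper. The full measurement code is in `experiment/intervention_experiment.py`.

```python
import torch
from netdissect import nethook
from experiment import setting

# Load the Places365 VGG-16 and wrap it so a layer can be edited in place.
model = nethook.InstrumentedModel(setting.load_classifier('vgg16')).cuda().eval()
layername = 'features.conv5_3'

# Pick one image from the Places validation set.
dataset = setting.load_dataset('places', crop_size=224)
image = dataset[0][0][None].cuda()

units_to_ablate = [20, 150, 443]  # illustrative unit indices, not from the paper

def zero_units(activations, *args):
    # Zero the chosen channels of this layer's activation tensor.
    activations[:, units_to_ablate] = 0
    return activations

model.edit_layer(layername, rule=zero_units)
with torch.no_grad():
    scores = model(image)
print('predicted class with units removed:',
      dataset.classes[scores.max(1)[1].item()])
model.remove_edits()
```

`edit_layer` applies the rule on every forward pass until `remove_edits` is called, which is how the experiment scripts in this repository ablate units one group at a time.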
45 | 46 | ## Bibtex 47 | 48 | ``` 49 | @article{bau2020role, 50 | author = {Bau, David and Zhu, Jun-Yan and Strobelt, Hendrik and Lapedriza, Agata and Zhou, Bolei and Torralba, Antonio}, 51 | title = {Understanding the role of individual units in a deep neural network}, 52 | elocation-id = {201907375}, 53 | year = {2020}, 54 | doi = {10.1073/pnas.1907375117}, 55 | publisher = {National Academy of Sciences}, 56 | issn = {0027-8424}, 57 | URL = {https://www.pnas.org/content/early/2020/08/31/1907375117}, 58 | journal = {Proceedings of the National Academy of Sciences} 59 | } 60 | ``` 61 | -------------------------------------------------------------------------------- /dissect_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | QUANTILE=0.01 4 | MINIOU=0.04 5 | SEG=netpqc 6 | DATASET=places 7 | 8 | MODEL=vgg16 9 | for LAYER in conv1_1 conv1_2 conv2_1 conv2_2 conv3_1 conv3_2 conv3_3 \ 10 | conv4_1 conv4_2 conv4_3 conv5_1 conv5_2 conv5_3 11 | do 12 | 13 | python -m experiment.dissect_experiment \ 14 | --quantile ${QUANTILE} --miniou ${MINIOU} \ 15 | --model ${MODEL} --dataset ${DATASET} --seg ${SEG} --layer ${LAYER} 16 | 17 | done 18 | 19 | MODEL=resnet152 20 | for LAYER in 0 4 5 6 7 21 | do 22 | 23 | python -m experiment.dissect_experiment \ 24 | --quantile ${QUANTILE} --miniou ${MINIOU} \ 25 | --model ${MODEL} --dataset ${DATASET} --seg ${SEG} --layer ${LAYER} 26 | 27 | done 28 | 29 | MODEL=alexnet 30 | for LAYER in conv1 conv2 conv3 conv4 conv5 31 | do 32 | 33 | python -m experiment.dissect_experiment \ 34 | --quantile ${QUANTILE} --miniou ${MINIOU} \ 35 | --model ${MODEL} --dataset ${DATASET} --seg ${SEG} --layer ${LAYER} 36 | 37 | done 38 | 39 | MODEL=progan 40 | for DATASET in kitchen church bedroom livingroom 41 | do 42 | 43 | for LAYER in layer1 layer2 layer3 layer4 layer5 layer6 layer7 \ 44 | layer8 layer9 layer10 layer11 layer12 layer13 layer14 45 | do 46 | 47 | python -m experiment.dissect_experiment \ 48 | --quantile ${QUANTILE} --miniou ${MINIOU} \ 49 | --model ${MODEL} --dataset ${DATASET} --seg ${SEG} --layer ${LAYER} 50 | 51 | done 52 | 53 | done 54 | 55 | -------------------------------------------------------------------------------- /experiment/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidbau/dissect/9421eaa8672fd051088de6c0225a385064070935/experiment/__init__.py -------------------------------------------------------------------------------- /experiment/intervention_experiment.py: -------------------------------------------------------------------------------- 1 | # Measuring the importance of a unit to a class by measuring the 2 | # impact of removing sets of units on binary classification 3 | # accuracy for individual classes. 4 | 5 | import torch, argparse, os, json, numpy, random 6 | from netdissect import pbar, nethook 7 | from netdissect.sampler import FixedSubsetSampler 8 | from . 
import setting 9 | import netdissect 10 | torch.backends.cudnn.benchmark = True 11 | 12 | def parseargs(): 13 | parser = argparse.ArgumentParser() 14 | def aa(*args, **kwargs): 15 | parser.add_argument(*args, **kwargs) 16 | aa('--model', choices=['vgg16'], default='vgg16') 17 | aa('--dataset', choices=['places'], default='places') 18 | aa('--layer', default='features.conv5_3') 19 | args = parser.parse_args() 20 | return args 21 | 22 | def main(): 23 | args = parseargs() 24 | 25 | model = setting.load_classifier(args.model) 26 | model = nethook.InstrumentedModel(model).cuda().eval() 27 | layername = args.layer 28 | model.retain_layer(layername) 29 | dataset = setting.load_dataset(args.dataset, crop_size=224) 30 | train_dataset = setting.load_dataset(args.dataset, crop_size=224, 31 | split='train') 32 | sample_size = len(dataset) 33 | 34 | # Probe layer to get sizes 35 | model(dataset[0][0][None].cuda()) 36 | num_units = model.retained_layer(layername).shape[1] 37 | classlabels = dataset.classes 38 | 39 | # Measure baseline classification accuracy on val set, and cache. 40 | pbar.descnext('baseline_pra') 41 | baseline_precision, baseline_recall, baseline_accuracy, baseline_ba = ( 42 | test_perclass_pra( 43 | model, dataset, 44 | cachefile=sharedfile('pra-%s-%s/pra_baseline.npz' 45 | % (args.model, args.dataset)))) 46 | pbar.print('baseline acc', baseline_ba.mean().item()) 47 | 48 | # Now erase each unit, one at a time, and retest accuracy. 49 | unit_list = random.sample(list(range(num_units)), num_units) 50 | val_single_unit_ablation_ba = torch.zeros(num_units, len(classlabels)) 51 | for unit in pbar(unit_list): 52 | pbar.descnext('test unit %d' % unit) 53 | # Get binary accuracy if the model after ablating the unit. 54 | _, _, _, ablation_ba = test_perclass_pra( 55 | model, dataset, 56 | layername=layername, 57 | ablated_units=[unit], 58 | cachefile=sharedfile('pra-%s-%s/pra_ablate_unit_%d.npz' % 59 | (args.model, args.dataset, unit))) 60 | val_single_unit_ablation_ba[unit] = ablation_ba 61 | 62 | # For the purpose of ranking units by importance to a class, we 63 | # measure using the training set (to avoid training unit ordering 64 | # on the test set). 65 | sample_size = None 66 | # Measure baseline classification accuracy, and cache. 67 | pbar.descnext('train_baseline_pra') 68 | baseline_precision, baseline_recall, baseline_accuracy, baseline_ba = ( 69 | test_perclass_pra( 70 | model, train_dataset, 71 | sample_size=sample_size, 72 | cachefile=sharedfile('ttv-pra-%s-%s/pra_train_baseline.npz' 73 | % (args.model, args.dataset)))) 74 | pbar.print('baseline acc', baseline_ba.mean().item()) 75 | 76 | # Measure accuracy on the val set. 77 | pbar.descnext('val_baseline_pra') 78 | _, _, _, val_baseline_ba = ( 79 | test_perclass_pra( 80 | model, dataset, 81 | cachefile=sharedfile('ttv-pra-%s-%s/pra_val_baseline.npz' 82 | % (args.model, args.dataset)))) 83 | pbar.print('val baseline acc', val_baseline_ba.mean().item()) 84 | 85 | # Do in shuffled order to allow multiprocessing. 
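# Each per-unit result is cached in its own file under results/shared (see
# sharedfile below), so several copies of this script can run concurrently;
# visiting the units in random order keeps separate processes from repeating
# each other's work.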
86 | single_unit_ablation_ba = torch.zeros(num_units, len(classlabels)) 87 | for unit in pbar(unit_list): 88 | pbar.descnext('test unit %d' % unit) 89 | _, _, _, ablation_ba = test_perclass_pra( 90 | model, train_dataset, 91 | layername=layername, 92 | ablated_units=[unit], 93 | sample_size=sample_size, 94 | cachefile= 95 | sharedfile('ttv-pra-%s-%s/pra_train_ablate_unit_%d.npz' % 96 | (args.model, args.dataset, unit))) 97 | single_unit_ablation_ba[unit] = ablation_ba 98 | 99 | # Now for every class, remove a set of the N most-important 100 | # and N least-important units for that class, and measure accuracy. 101 | for classnum in pbar(random.sample(range(len(classlabels)), 102 | len(classlabels))): 103 | # For a few classes, let's chart the whole range of ablations. 104 | if classnum in [100, 169, 351, 304]: 105 | num_best_list = range(1, num_units) 106 | else: 107 | num_best_list = [1, 2, 3, 4, 5, 20, 64, 128, 256] 108 | pbar.descnext('numbest') 109 | for num_best in pbar(random.sample(num_best_list, len(num_best_list))): 110 | num_worst = num_units - num_best 111 | unitlist = single_unit_ablation_ba[:,classnum].sort(0)[1][:num_best] 112 | _, _, _, testba = test_perclass_pra(model, dataset, 113 | layername=layername, 114 | ablated_units=unitlist, 115 | cachefile=sharedfile( 116 | 'ttv-pra-%s-%s/pra_val_ablate_classunits_%s_ba_%d.npz' 117 | % (args.model, args.dataset, classlabels[classnum], 118 | len(unitlist)))) 119 | unitlist = ( 120 | single_unit_ablation_ba[:,classnum].sort(0)[1][-num_worst:]) 121 | _, _, _, testba2 = test_perclass_pra(model, dataset, 122 | layername=layername, 123 | ablated_units=unitlist, 124 | cachefile=sharedfile( 125 | 'ttv-pra-%s-%s/pra_val_ablate_classunits_%s_worstba_%d.npz' % 126 | (args.model, args.dataset, classlabels[classnum], 127 | len(unitlist)))) 128 | pbar.print('%s: best %d %.3f vs worst N %.3f' % 129 | (classlabels[classnum], num_best, 130 | testba[classnum] - val_baseline_ba[classnum], 131 | testba2[classnum] - val_baseline_ba[classnum])) 132 | 133 | def test_perclass_pra(model, dataset, 134 | layername=None, ablated_units=None, sample_size=None, cachefile=None): 135 | '''Classifier precision/recall/accuracy measurement. 136 | Disables a set of units in the specified layer, and then 137 | measures per-class precision, recall, accuracy and 138 | balanced (binary classification) accuracy for each class, 139 | compared to the ground truth in the given dataset.''' 140 | try: 141 | if cachefile is not None: 142 | data = numpy.load(cachefile) 143 | # verify that this is computed. 
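# (a missing key raises here, dropping into the bare except below so the
# cached file is ignored and the result is recomputed)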
144 | data['true_negative_rate'] 145 | result = tuple(torch.tensor(data[key]) for key in 146 | ['precision', 'recall', 'accuracy', 'balanced_accuracy']) 147 | pbar.print('Loading cached %s' % cachefile) 148 | return result 149 | except: 150 | pass 151 | model.remove_edits() 152 | if ablated_units is not None: 153 | def ablate_the_units(x, *args): 154 | x[:,ablated_units] = 0 155 | return x 156 | model.edit_layer(layername, rule=ablate_the_units) 157 | with torch.no_grad(): 158 | num_classes = len(dataset.classes) 159 | true_counts = torch.zeros(num_classes, dtype=torch.int64).cuda() 160 | pred_counts = torch.zeros(num_classes, dtype=torch.int64).cuda() 161 | correct_counts = torch.zeros(num_classes, dtype=torch.int64).cuda() 162 | total_count = 0 163 | sampler = None if sample_size is None else ( 164 | FixedSubsetSampler(list(range(sample_size)))) 165 | loader = torch.utils.data.DataLoader( 166 | dataset, batch_size=100, num_workers=20, 167 | sampler=sampler, pin_memory=True) 168 | for image_batch, class_batch in pbar(loader): 169 | total_count += len(image_batch) 170 | image_batch, class_batch = [ 171 | d.cuda() for d in [image_batch, class_batch]] 172 | scores = model(image_batch) 173 | preds = scores.max(1)[1] 174 | correct = (preds == class_batch) 175 | true_counts.add_(class_batch.bincount(minlength=num_classes)) 176 | pred_counts.add_(preds.bincount(minlength=num_classes)) 177 | correct_counts.add_(class_batch.bincount( 178 | correct, minlength=num_classes).long()) 179 | model.remove_edits() 180 | true_neg_counts = ( 181 | (total_count - true_counts) - (pred_counts - correct_counts)) 182 | precision = (correct_counts.float() / pred_counts.float()).cpu() 183 | recall = (correct_counts.float() / true_counts.float()).cpu() 184 | accuracy = (correct_counts + true_neg_counts).float().cpu() / total_count 185 | true_neg_rate = (true_neg_counts.float() / 186 | (total_count - true_counts).float()).cpu() 187 | balanced_accuracy = (recall + true_neg_rate) / 2 188 | if cachefile is not None: 189 | numpy.savez(cachefile, 190 | precision=precision.numpy(), 191 | recall=recall.numpy(), 192 | accuracy=accuracy.numpy(), 193 | true_negative_rate=true_neg_rate.numpy(), 194 | balanced_accuracy=balanced_accuracy.numpy()) 195 | return precision, recall, accuracy, balanced_accuracy 196 | 197 | def sharedfile(fn): 198 | filename = os.path.join('results/shared', fn) 199 | os.makedirs(os.path.dirname(filename), exist_ok=True) 200 | return filename 201 | 202 | if __name__ == '__main__': 203 | main() 204 | 205 | -------------------------------------------------------------------------------- /experiment/oldalexnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | # based on https://github.com/jiecaoyu/pytorch_imagenet 3 | 4 | import os 5 | import torch 6 | import sys 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from collections import OrderedDict 10 | import torchvision.transforms 11 | import numpy 12 | 13 | def load_places_alexnet(weight_file): 14 | model = AlexNet() 15 | state_dict = torch.load(weight_file) 16 | model.load_state_dict(state_dict) 17 | return model 18 | 19 | 20 | class AlexNet(nn.Sequential): 21 | 22 | def __init__(self, num_classes=None, 23 | include_lrn=True, split_groups=True, 24 | include_dropout=True): 25 | w = [3, 96, 256, 384, 384, 256, 4096, 4096, 365] 26 | if num_classes is not None: 27 | w[-1] = num_classes 28 | if split_groups is True: 29 | groups = [1, 2, 1, 2, 2] 30 | else: 31 | 
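# no channel grouping: one group for every conv layer, versus the
# [1, 2, 1, 2, 2] grouping above that mirrors the original two-GPU AlexNet split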
groups = [1, 1, 1, 1, 1] 32 | sequence = OrderedDict() 33 | for name, module in [ 34 | ('conv1', nn.Conv2d(w[0], w[1], kernel_size=11, 35 | stride=4, 36 | groups=groups[0], bias=True)), 37 | ('relu1', nn.ReLU(inplace=True)), 38 | ('pool1', nn.MaxPool2d(kernel_size=3, stride=2)), 39 | ('lrn1', LRN(local_size=5, alpha=0.0001, beta=0.75)), 40 | ('conv2', nn.Conv2d(w[1], w[2], kernel_size=5, padding=2, 41 | groups=groups[1], bias=True)), 42 | ('relu2', nn.ReLU(inplace=True)), 43 | ('pool2', nn.MaxPool2d(kernel_size=3, stride=2)), 44 | ('lrn2', LRN(local_size=5, alpha=0.0001, beta=0.75)), 45 | ('conv3', nn.Conv2d(w[2], w[3], kernel_size=3, padding=1, 46 | groups=groups[2], bias=True)), 47 | ('relu3', nn.ReLU(inplace=True)), 48 | ('conv4', nn.Conv2d(w[3], w[4], kernel_size=3, padding=1, 49 | groups=groups[3], bias=True)), 50 | ('relu4', nn.ReLU(inplace=True)), 51 | ('conv5', nn.Conv2d(w[4], w[5], kernel_size=3, padding=1, 52 | groups=groups[4], bias=True)), 53 | ('relu5', nn.ReLU(inplace=True)), 54 | ('pool5', nn.MaxPool2d(kernel_size=3, stride=2)), 55 | ('flatten', Vectorize()), 56 | ('fc6', nn.Linear(w[5] * 6 * 6, w[6], bias=True)), 57 | ('relu6', nn.ReLU(inplace=True)), 58 | ('dropout6', nn.Dropout()), 59 | ('fc7', nn.Linear(w[6], w[7], bias=True)), 60 | ('relu7', nn.ReLU(inplace=True)), 61 | ('dropout7', nn.Dropout()), 62 | ('fc8', nn.Linear(w[7], w[8])) ]: 63 | if not include_lrn and name.startswith('lrn'): 64 | continue 65 | if not include_dropout and name.startswith('drop'): 66 | continue 67 | sequence[name] = module 68 | super(AlexNet, self).__init__(sequence) 69 | 70 | class LRN(nn.Module): 71 | def __init__(self, local_size=1, alpha=1.0, beta=0.75, 72 | ACROSS_CHANNELS=True): 73 | super(LRN, self).__init__() 74 | self.ACROSS_CHANNELS = ACROSS_CHANNELS 75 | if ACROSS_CHANNELS: 76 | self.average=nn.AvgPool3d(kernel_size=(local_size, 1, 1), 77 | stride=1, 78 | padding=(int((local_size-1.0)/2), 0, 0)) 79 | else: 80 | self.average=nn.AvgPool2d(kernel_size=local_size, 81 | stride=1, 82 | padding=int((local_size-1.0)/2)) 83 | self.alpha = alpha 84 | self.beta = beta 85 | 86 | def forward(self, x): 87 | if self.ACROSS_CHANNELS: 88 | div = x.pow(2).unsqueeze(1) 89 | div = self.average(div).squeeze(1) 90 | div = div.mul(self.alpha).add(1.0).pow(self.beta) 91 | else: 92 | div = x.pow(2) 93 | div = self.average(div) 94 | div = div.mul(self.alpha).add(1.0).pow(self.beta) 95 | x = x.div(div) 96 | return x 97 | 98 | class Vectorize(nn.Module): 99 | def __init__(self): 100 | super(Vectorize, self).__init__() 101 | 102 | def forward(self, x): 103 | x = x.view(x.size(0), int(numpy.prod(x.size()[1:]))) 104 | return x 105 | 106 | -------------------------------------------------------------------------------- /experiment/oldvgg16.py: -------------------------------------------------------------------------------- 1 | import collections, torch, torchvision, numpy 2 | 3 | # Return a version of vgg16 where the layers are given their research names. 
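# (so a layer can be addressed by a readable name such as 'features.conv5_3'
# when retaining or editing it, rather than by a bare numeric index)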
4 | def vgg16(*args, **kwargs): 5 | model = torchvision.models.vgg16(*args, **kwargs) 6 | model.features = torch.nn.Sequential(collections.OrderedDict(zip([ 7 | 'conv1_1', 'relu1_1', 8 | 'conv1_2', 'relu1_2', 9 | 'pool1', 10 | 'conv2_1', 'relu2_1', 11 | 'conv2_2', 'relu2_2', 12 | 'pool2', 13 | 'conv3_1', 'relu3_1', 14 | 'conv3_2', 'relu3_2', 15 | 'conv3_3', 'relu3_3', 16 | 'pool3', 17 | 'conv4_1', 'relu4_1', 18 | 'conv4_2', 'relu4_2', 19 | 'conv4_3', 'relu4_3', 20 | 'pool4', 21 | 'conv5_1', 'relu5_1', 22 | 'conv5_2', 'relu5_2', 23 | 'conv5_3', 'relu5_3', 24 | 'pool5'], 25 | model.features))) 26 | 27 | model.classifier = torch.nn.Sequential(collections.OrderedDict(zip([ 28 | 'fc6', 'relu6', 29 | 'drop6', 30 | 'fc7', 'relu7', 31 | 'drop7', 32 | 'fc8a'], 33 | model.classifier))) 34 | 35 | return model 36 | -------------------------------------------------------------------------------- /experiment/readdissect.py: -------------------------------------------------------------------------------- 1 | import argparse, os, json, numpy, PIL.Image, torch, torchvision, collections 2 | import math, shutil 3 | from netdissect import pidfile, tally, nethook, parallelfolder 4 | from netdissect import upsample, imgviz, imgsave, renormalize, bargraph 5 | from netdissect import runningstats 6 | 7 | 8 | class DissectVis: 9 | ''' 10 | Code to read out the dissection in a set of directories. 11 | ''' 12 | def __init__(self, outdir='results', model='vgg16', 13 | dataset='places', layers=None, 14 | seg='netpqc', quantile=0.01): 15 | labels = {} 16 | iou = {} 17 | images = {} 18 | rq = {} 19 | dirs = {} 20 | for k in layers: 21 | dirname = os.path.join(outdir, 22 | f"{model}-{dataset}-{seg}-{k}-{int(1000 * quantile)}") 23 | dirs[k] = dirname 24 | with open(os.path.join(dirname, 'report.json')) as f: 25 | labels[k] = json.load(f)['units'] 26 | rq[k] = runningstats.RunningQuantile( 27 | state=numpy.load(os.path.join(dirname, 'rq.npz'), 28 | allow_pickle=True)) 29 | images[k] = [None] * rq[k].depth 30 | self.dirs = dirs 31 | self.labels = labels 32 | self.rqtable = rq 33 | self.images = images 34 | self.basedir = outdir 35 | 36 | def label(self, layer, unit): 37 | return self.labels[layer][unit]['label'] 38 | def iou(self, layer, unit): 39 | return self.labels[layer][unit]['iou'] 40 | def dir(self, layer): 41 | return self.dirs[layer] 42 | def rq(self, layer): 43 | return self.rqtable[layer] 44 | def image(self, layer, unit): 45 | result = self.images[layer][unit] 46 | # Lazy loading of images. 
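# each unit's visualization image is read from disk only the first time it is
# requested, then memoized in self.images for later lookups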
47 | if result is None: 48 | result = PIL.Image.open(os.path.join( 49 | self.dirs[layer], 50 | 'image/unit%d.jpg' % unit)) 51 | result.load() 52 | self.images[layer][unit] = result 53 | return result 54 | 55 | def save_bargraph(self, filename, layer, min_iou=0.04): 56 | svg = self.bargraph(layer, min_iou=min_iou, file_header=True) 57 | with open(filename, 'w') as f: 58 | f.write(svg) 59 | 60 | def img_bargraph(self, layer, min_iou=0.04): 61 | url = self.bargraph(layer, min_iou=min_iou, data_url=True) 62 | class H: 63 | def __init__(self, url): 64 | self.url = url 65 | def _repr_html_(self): 66 | return '' % self.url 67 | return H(url) 68 | 69 | def bargraph(self, layer, min_iou=0.04, **kwargs): 70 | labelcat_list = [] 71 | for rec in self.labels[layer]: 72 | if rec['iou'] and rec['iou'] >= min_iou: 73 | labelcat_list.append(tuple(rec['cat'])) 74 | return self.bargraph_from_conceptcatlist(labelcat_list, **kwargs) 75 | 76 | def bargraph_from_conceptcatlist(self, conceptcatlist, **kwargs): 77 | count = collections.defaultdict(int) 78 | catcount = collections.defaultdict(int) 79 | for c in conceptcatlist: 80 | count[c] += 1 81 | for c in count.keys(): 82 | catcount[c[1]] += 1 83 | cats = ['object', 'part', 'material', 'texture', 'color'] 84 | catorder = dict((c, i) for i, c in enumerate(cats)) 85 | sorted_labels = sorted(count.keys(), 86 | key=lambda x: (catorder[x[1]], -count[x])) 87 | sorted_labels 88 | return bargraph.make_svg_bargraph( 89 | [label for label, cat in sorted_labels], 90 | [count[k] for k in sorted_labels], 91 | [(c, catcount[c]) for c in cats], **kwargs) 92 | 93 | -------------------------------------------------------------------------------- /experiment/setting.py: -------------------------------------------------------------------------------- 1 | import torch, torchvision, os, collections 2 | from netdissect import parallelfolder, zdataset, renormalize, segmenter 3 | from . import oldalexnet, oldvgg16, oldresnet152 4 | 5 | def load_proggan(domain): 6 | # Automatically download and cache progressive GAN model 7 | # (From Karras, converted from Tensorflow to Pytorch.) 8 | from . import proggan 9 | weights_filename = dict( 10 | bedroom='proggan_bedroom-d8a89ff1.pth', 11 | church='proggan_churchoutdoor-7e701dd5.pth', 12 | conferenceroom='proggan_conferenceroom-21e85882.pth', 13 | diningroom='proggan_diningroom-3aa0ab80.pth', 14 | kitchen='proggan_kitchen-67f1e16c.pth', 15 | livingroom='proggan_livingroom-5ef336dd.pth', 16 | restaurant='proggan_restaurant-b8578299.pth', 17 | celebhq='proggan_celebhq-620d161c.pth')[domain] 18 | # Posted here. 
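# torch.hub downloads the weights on first use and keeps a local copy, so
# later runs load from the cache rather than the network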
19 | url = 'https://dissect.csail.mit.edu/models/' + weights_filename 20 | try: 21 | sd = torch.hub.load_state_dict_from_url(url) # pytorch 1.1+ 22 | except: 23 | sd = torch.hub.model_zoo.load_url(url) # pytorch 1.0 24 | model = proggan.from_state_dict(sd) 25 | return model 26 | 27 | def load_classifier(architecture): 28 | model_factory = dict( 29 | alexnet=oldalexnet.AlexNet, 30 | vgg16=oldvgg16.vgg16, 31 | resnet152=oldresnet152.OldResNet152)[architecture] 32 | weights_filename = dict( 33 | alexnet='alexnet_places365-92864cf6.pth', 34 | vgg16='vgg16_places365-0bafbc55.pth', 35 | resnet152='resnet152_places365-f928166e5c.pth')[architecture] 36 | model = model_factory(num_classes=365) 37 | baseurl = 'https://dissect.csail.mit.edu/models/' 38 | url = baseurl + weights_filename 39 | try: 40 | sd = torch.hub.load_state_dict_from_url(url) # pytorch 1.1 41 | except: 42 | sd = torch.hub.model_zoo.load_url(url) # pytorch 1.0 43 | model.load_state_dict(sd) 44 | model.eval() 45 | return model 46 | 47 | def load_dataset(domain, split=None, full=False, crop_size=None, download=True): 48 | if domain in ['places', 'imagenet']: 49 | if split is None: 50 | split = 'val' 51 | dirname = 'datasets/%s/%s' % (domain, split) 52 | if download and not os.path.exists(dirname) and domain == 'places': 53 | os.makedirs('datasets', exist_ok=True) 54 | torchvision.datasets.utils.download_and_extract_archive( 55 | 'https://dissect.csail.mit.edu/datasets/' + 56 | 'places_%s.zip' % split, 57 | 'datasets', 58 | md5=dict(val='593bbc21590cf7c396faac2e600cd30c', 59 | train='d1db6ad3fc1d69b94da325ac08886a01')[split]) 60 | places_transform = torchvision.transforms.Compose([ 61 | torchvision.transforms.Resize((256, 256)), 62 | torchvision.transforms.CenterCrop(crop_size or 224), 63 | torchvision.transforms.ToTensor(), 64 | renormalize.NORMALIZER['imagenet']]) 65 | return parallelfolder.ParallelImageFolders([dirname], 66 | classification=True, 67 | shuffle=True, 68 | transform=places_transform) 69 | 70 | def load_segmenter(segmenter_name='netpqc'): 71 | '''Loads the segementer.''' 72 | all_parts = ('p' in segmenter_name) 73 | quad_seg = ('q' in segmenter_name) 74 | textures = ('x' in segmenter_name) 75 | colors = ('c' in segmenter_name) 76 | 77 | segmodels = [] 78 | segmodels.append(segmenter.UnifiedParsingSegmenter(segsizes=[256], 79 | all_parts=all_parts, 80 | segdiv=('quad' if quad_seg else None))) 81 | if textures: 82 | segmenter.ensure_segmenter_downloaded('datasets/segmodel', 'texture') 83 | segmodels.append(segmenter.SemanticSegmenter( 84 | segvocab="texture", segarch=("resnet18dilated", "ppm_deepsup"))) 85 | if colors: 86 | segmenter.ensure_segmenter_downloaded('datasets/segmodel', 'color') 87 | segmodels.append(segmenter.SemanticSegmenter( 88 | segvocab="color", segarch=("resnet18dilated", "ppm_deepsup"))) 89 | if len(segmodels) == 1: 90 | segmodel = segmodels[0] 91 | else: 92 | segmodel = segmenter.MergedSegmenter(segmodels) 93 | seglabels = [l for l, c in segmodel.get_label_and_category_names()[0]] 94 | segcatlabels = segmodel.get_label_and_category_names()[0] 95 | return segmodel, seglabels, segcatlabels 96 | 97 | if __name__ == '__main__': 98 | main() 99 | 100 | -------------------------------------------------------------------------------- /experiment/shapebias_experiment.py: -------------------------------------------------------------------------------- 1 | from netdissect import parallelfolder, show, tally, nethook, renormalize 2 | from . 
import readdissect, setting 3 | import copy, PIL.Image 4 | from netdissect import upsample, imgsave, imgviz 5 | import re, torchvision, torch, os 6 | from IPython.display import SVG 7 | from matplotlib import pyplot as plt 8 | 9 | torch.set_grad_enabled(False) 10 | 11 | def normalize_filename(n): 12 | return re.match(r'^(.*Places365_\w+_\d+)', n).group(1) 13 | 14 | ds = parallelfolder.ParallelImageFolders( 15 | ['datasets/places/val', 'datasets/stylized-places/val'], 16 | transform=torchvision.transforms.Compose([ 17 | torchvision.transforms.Resize(256), 18 | # transforms.CenterCrop(224), 19 | torchvision.transforms.CenterCrop(256), 20 | torchvision.transforms.ToTensor(), 21 | renormalize.NORMALIZER['imagenet'], 22 | ]), 23 | normalize_filename=normalize_filename, 24 | shuffle=True) 25 | 26 | 27 | layers = [ 28 | 'conv5_3', 29 | 'conv5_2', 30 | 'conv5_1', 31 | 'conv4_3', 32 | 'conv4_2', 33 | 'conv4_1', 34 | 'conv3_3', 35 | 'conv3_2', 36 | 'conv3_1', 37 | 'conv2_2', 38 | 'conv2_1', 39 | 'conv1_2', 40 | 'conv1_1', 41 | ] 42 | qd = readdissect.DissectVis(layers=layers) 43 | net = setting.load_classifier('vgg16') 44 | 45 | sds = parallelfolder.ParallelImageFolders( 46 | ['datasets/stylized-places/val'], 47 | transform=torchvision.transforms.Compose([ 48 | torchvision.transforms.Resize(256), 49 | # transforms.CenterCrop(224), 50 | torchvision.transforms.CenterCrop(256), 51 | torchvision.transforms.ToTensor(), 52 | renormalize.NORMALIZER['imagenet'], 53 | ]), 54 | normalize_filename=normalize_filename, 55 | shuffle=True) 56 | 57 | uds = parallelfolder.ParallelImageFolders( 58 | ['datasets/places/val'], 59 | transform=torchvision.transforms.Compose([ 60 | torchvision.transforms.Resize(256), 61 | # transforms.CenterCrop(224), 62 | torchvision.transforms.CenterCrop(256), 63 | torchvision.transforms.ToTensor(), 64 | renormalize.NORMALIZER['imagenet'], 65 | ]), 66 | normalize_filename=normalize_filename, 67 | shuffle=True) 68 | 69 | # def s_image(layername, unit): 70 | # result = PIL.Image.open(os.path.join(qd.dir(layername), 's_imgs/unit_%d.png' % unit)) 71 | # result.load() 72 | # return result 73 | 74 | for layername in layers: 75 | #if os.path.isfile(os.path.join(qd.dir(layername), 'intersect_99.npz')): 76 | # continue 77 | busy_fn = os.path.join(qd.dir(layername), 'busy.txt') 78 | if os.path.isfile(busy_fn): 79 | print(busy_fn) 80 | continue 81 | with open(busy_fn, 'w') as f: 82 | f.write('busy') 83 | print('working on', layername) 84 | 85 | inst_net = nethook.InstrumentedModel(copy.deepcopy(net)).cuda() 86 | inst_net.retain_layer('features.' + layername) 87 | inst_net(ds[0][0][None].cuda()) 88 | sample_act = inst_net.retained_layer('features.' + layername).cpu() 89 | upfn = upsample.upsampler((64, 64), sample_act.shape[2:]) 90 | 91 | def flat_acts(batch): 92 | inst_net(batch.cuda()) 93 | acts = upfn(inst_net.retained_layer('features.' + layername)) 94 | return acts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1]) 95 | s_rq = tally.tally_quantile(flat_acts, sds, cachefile=os.path.join( 96 | qd.dir(layername), 's_rq.npz')) 97 | u_rq = qd.rq(layername) 98 | 99 | def intersect_99_fn(uimg, simg): 100 | s_99 = s_rq.quantiles(0.99)[None,:,None,None].cuda() 101 | u_99 = u_rq.quantiles(0.99)[None,:,None,None].cuda() 102 | with torch.no_grad(): 103 | ux, sx = uimg.cuda(), simg.cuda() 104 | inst_net(ux) 105 | ur = inst_net.retained_layer('features.' + layername) 106 | inst_net(sx) 107 | sr = inst_net.retained_layer('features.' 
+ layername) 108 | return ((sr > s_99).float() * (ur > u_99).float() 109 | ).permute(0, 2, 3, 1).reshape(-1, ur.size(1)) 110 | 111 | intersect_99 = tally.tally_mean(intersect_99_fn, ds, 112 | cachefile=os.path.join(qd.dir(layername), 'intersect_99.npz')) 113 | 114 | def compute_image_max(batch): 115 | inst_net(batch.cuda()) 116 | return inst_net.retained_layer( 117 | 'features.' + layername).max(3)[0].max(2)[0] 118 | 119 | s_topk = tally.tally_topk(compute_image_max, sds, 120 | cachefile=os.path.join(qd.dir(layername), 's_topk.npz')) 121 | 122 | def compute_acts(image_batch): 123 | inst_net(image_batch.cuda()) 124 | acts_batch = inst_net.retained_layer('features.' + layername) 125 | return (acts_batch, image_batch) 126 | 127 | iv = imgviz.ImageVisualizer(128, quantiles=s_rq, source=sds) 128 | unit_images = iv.masked_images_for_topk(compute_acts, sds, s_topk, k=5) 129 | os.makedirs(os.path.join(qd.dir(layername),'s_imgs'), exist_ok=True) 130 | imgsave.save_image_set(unit_images, 131 | os.path.join(qd.dir(layername),'s_imgs/unit%d.jpg')) 132 | 133 | iv = imgviz.ImageVisualizer(128, quantiles=u_rq, source=uds) 134 | unit_images = iv.masked_images_for_topk(compute_acts, uds, s_topk, k=5) 135 | os.makedirs(os.path.join(qd.dir(layername),'su_imgs'), exist_ok=True) 136 | imgsave.save_image_set(unit_images, 137 | os.path.join(qd.dir(layername),'su_imgs/unit%d.jpg')) 138 | 139 | os.remove(busy_fn) 140 | -------------------------------------------------------------------------------- /g_intervention.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATASET=church 4 | MODEL=progan 5 | LAYER=layer4 6 | 7 | python -m experiment.generator_int_experiment \ 8 | --model ${MODEL} --dataset ${DATASET} --layer ${LAYER} 9 | 10 | -------------------------------------------------------------------------------- /intervention.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATASET=places 4 | MODEL=vgg16 5 | LAYER=features.conv5_3 6 | 7 | python -m experiment.intervention_experiment \ 8 | --model ${MODEL} --dataset ${DATASET} --layer ${LAYER} 9 | 10 | -------------------------------------------------------------------------------- /netdissect/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidbau/dissect/9421eaa8672fd051088de6c0225a385064070935/netdissect/__init__.py -------------------------------------------------------------------------------- /netdissect/bargraph.py: -------------------------------------------------------------------------------- 1 | from xml.etree import ElementTree as et 2 | 3 | 4 | def make_svg_bargraph(labels, heights, categories=None, palette=None, 5 | barheight=100, barwidth=12, show_labels=True, file_header=False, 6 | data_url=False): 7 | if palette is None: 8 | palette = default_bargraph_palette 9 | if categories is None: 10 | categories = [('', len(labels))] 11 | unitheight = float(barheight) / max(max(heights, default=1), 1) 12 | textheight = barheight if show_labels else 0 13 | labelsize = float(barwidth) 14 | gap = float(barwidth) / 4 15 | # textsize = barwidth + gap 16 | textsize = barwidth + gap / 2 17 | rollup = max(heights, default=1) 18 | textmargin = float(labelsize) * 2 / 3 19 | leftmargin = 32 20 | rightmargin = 8 21 | svgwidth = len(heights) * (barwidth + gap) + 2 * leftmargin + rightmargin 22 | svgheight = barheight + textheight 23 | 24 | # create an SVG XML element 25 | svg 
= et.Element('svg', width=str(svgwidth), height=str(svgheight), 26 | version='1.1', xmlns='http://www.w3.org/2000/svg') 27 | 28 | # Draw the bar graph 29 | basey = svgheight - textheight 30 | x = leftmargin 31 | # Add units scale on left 32 | if len(heights): 33 | for h in [1, (max(heights) + 1) // 2, max(heights)]: 34 | et.SubElement(svg, 'text', x='0', y='0', 35 | style=('font-family:sans-serif;font-size:%dpx;' + 36 | 'text-anchor:end;alignment-baseline:hanging;' + 37 | 'transform:translate(%dpx, %dpx);') % 38 | (textsize, x - gap, basey - h * unitheight)).text = str(h) 39 | et.SubElement(svg, 'text', x='0', y='0', 40 | style=('font-family:sans-serif;font-size:%dpx;' + 41 | 'text-anchor:middle;' + 42 | 'transform:translate(%dpx, %dpx) rotate(-90deg)') % 43 | (textsize, x - gap - textsize, basey - h * unitheight / 2) 44 | ).text = 'units' 45 | # Draw big category background rectangles 46 | for catindex, (cat, catcount) in enumerate(categories): 47 | if not catcount: 48 | continue 49 | et.SubElement(svg, 'rect', x=str(x), y=str(basey - rollup * unitheight), 50 | width=(str((barwidth + gap) * catcount - gap)), 51 | height=str(rollup * unitheight), 52 | fill=palette[catindex % len(palette)][1]) 53 | x += (barwidth + gap) * catcount 54 | # Draw small bars as well as 45degree text labels 55 | x = leftmargin 56 | catindex = -1 57 | catcount = 0 58 | for label, height in zip(labels, heights): 59 | while not catcount and catindex <= len(categories): 60 | catindex += 1 61 | catcount = categories[catindex][1] 62 | color = palette[catindex % len(palette)][0] 63 | et.SubElement(svg, 'rect', x=str(x), y=str(basey - (height * unitheight)), 64 | width=str(barwidth), height=str(height * unitheight), 65 | fill=color) 66 | x += barwidth 67 | if show_labels: 68 | et.SubElement(svg, 'text', x='0', y='0', 69 | style=('font-family:sans-serif;font-size:%dpx;text-anchor:end;' + 70 | 'transform:translate(%dpx, %dpx) rotate(-45deg);') % 71 | (labelsize, x, basey + textmargin)).text = label 72 | x += gap 73 | catcount -= 1 74 | # Text labels for each category 75 | x = leftmargin 76 | for cat, catcount in categories: 77 | if not catcount: 78 | continue 79 | et.SubElement(svg, 'text', x='0', y='0', 80 | style=('font-family:sans-serif;font-size:%dpx;text-anchor:end;' + 81 | 'transform:translate(%dpx, %dpx) rotate(-90deg);') % 82 | (textsize, x + (barwidth + gap) * catcount - gap, 83 | basey - rollup * unitheight + gap)).text = '%d %s' % ( 84 | catcount, cat + ('s' if catcount != 1 else '')) 85 | x += (barwidth + gap) * catcount 86 | # Output - this is the bare svg. 87 | result = et.tostring(svg).decode('utf-8') 88 | if file_header or data_url: 89 | result = ''.join([ 90 | '\n', 91 | '\n', 93 | result]) 94 | if data_url: 95 | import base64 96 | result = 'data:image/svg+xml;base64,' + base64.b64encode( 97 | result.encode('utf-8')).decode('utf-8') 98 | return result 99 | 100 | 101 | default_bargraph_palette = [ 102 | ('#4B4CBF', '#B6B6F2'), 103 | ('#55B05B', '#B6F2BA'), 104 | ('#50BDAC', '#A5E5DB'), 105 | ('#81C679', '#C0FF9B'), 106 | ('#F0883B', '#F2CFB6'), 107 | ('#D4CF24', '#F2F1B6'), 108 | ('#D92E2B', '#F2B6B6'), 109 | ('#AB6BC6', '#CFAAFF'), 110 | ] 111 | -------------------------------------------------------------------------------- /netdissect/easydict.py: -------------------------------------------------------------------------------- 1 | ''' 2 | From https://github.com/makinacorpus/easydict. 
3 | ''' 4 | 5 | 6 | class EasyDict(dict): 7 | """ 8 | Get attributes 9 | 10 | >>> d = EasyDict({'foo':3}) 11 | >>> d['foo'] 12 | 3 13 | >>> d.foo 14 | 3 15 | >>> d.bar 16 | Traceback (most recent call last): 17 | ... 18 | AttributeError: 'EasyDict' object has no attribute 'bar' 19 | 20 | Works recursively 21 | 22 | >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}}) 23 | >>> isinstance(d.bar, dict) 24 | True 25 | >>> d.bar.x 26 | 1 27 | 28 | Bullet-proof 29 | 30 | >>> EasyDict({}) 31 | {} 32 | >>> EasyDict(d={}) 33 | {} 34 | >>> EasyDict(None) 35 | {} 36 | >>> d = {'a': 1} 37 | >>> EasyDict(**d) 38 | {'a': 1} 39 | 40 | Set attributes 41 | 42 | >>> d = EasyDict() 43 | >>> d.foo = 3 44 | >>> d.foo 45 | 3 46 | >>> d.bar = {'prop': 'value'} 47 | >>> d.bar.prop 48 | 'value' 49 | >>> d 50 | {'foo': 3, 'bar': {'prop': 'value'}} 51 | >>> d.bar.prop = 'newer' 52 | >>> d.bar.prop 53 | 'newer' 54 | 55 | 56 | Values extraction 57 | 58 | >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]}) 59 | >>> isinstance(d.bar, list) 60 | True 61 | >>> from operator import attrgetter 62 | >>> map(attrgetter('x'), d.bar) 63 | [1, 3] 64 | >>> map(attrgetter('y'), d.bar) 65 | [2, 4] 66 | >>> d = EasyDict() 67 | >>> d.keys() 68 | [] 69 | >>> d = EasyDict(foo=3, bar=dict(x=1, y=2)) 70 | >>> d.foo 71 | 3 72 | >>> d.bar.x 73 | 1 74 | 75 | Still like a dict though 76 | 77 | >>> o = EasyDict({'clean':True}) 78 | >>> o.items() 79 | [('clean', True)] 80 | 81 | And like a class 82 | 83 | >>> class Flower(EasyDict): 84 | ... power = 1 85 | ... 86 | >>> f = Flower() 87 | >>> f.power 88 | 1 89 | >>> f = Flower({'height': 12}) 90 | >>> f.height 91 | 12 92 | >>> f['power'] 93 | 1 94 | >>> sorted(f.keys()) 95 | ['height', 'power'] 96 | """ 97 | 98 | def __init__(self, d=None, **kwargs): 99 | if d is None: 100 | d = {} 101 | if kwargs: 102 | d.update(**kwargs) 103 | for k, v in d.items(): 104 | setattr(self, k, v) 105 | # Class attributes 106 | for k in self.__class__.__dict__.keys(): 107 | if not (k.startswith('__') and k.endswith('__')): 108 | setattr(self, k, getattr(self, k)) 109 | 110 | def __setattr__(self, name, value): 111 | if isinstance(value, (list, tuple)): 112 | value = [self.__class__(x) 113 | if isinstance(x, dict) else x for x in value] 114 | elif isinstance(value, dict) and not isinstance(value, self.__class__): 115 | value = self.__class__(value) 116 | super(EasyDict, self).__setattr__(name, value) 117 | super(EasyDict, self).__setitem__(name, value) 118 | 119 | __setitem__ = __setattr__ 120 | 121 | 122 | def load_json(filename): 123 | import json 124 | with open(filename) as f: 125 | return EasyDict(json.load(f)) 126 | 127 | 128 | if __name__ == "__main__": 129 | import doctest 130 | doctest.testmod() 131 | -------------------------------------------------------------------------------- /netdissect/imgsave.py: -------------------------------------------------------------------------------- 1 | # A utility for saving a large number of images quickly without 2 | # blocking a single thread to wait for each individual image to save. 3 | 4 | import os 5 | import PIL 6 | from .workerpool import WorkerBase, WorkerPool 7 | from . 
import pbar 8 | 9 | 10 | def all_items_and_filenames(img_array, filename_pattern, index=()): 11 | for i, data in enumerate(img_array): 12 | inner_index = index + (i,) 13 | if PIL.Image.isImageType(data): 14 | yield data, (filename_pattern % inner_index) 15 | else: 16 | for img, name in all_items_and_filenames(data, filename_pattern, 17 | inner_index): 18 | yield img, name 19 | 20 | 21 | def expand_last_filename(img_array, filename_pattern): 22 | index, data = (), img_array 23 | while not PIL.Image.isImageType(data): 24 | index += (len(data) - 1,) 25 | data = data[len(data) - 1] 26 | return filename_pattern % index 27 | 28 | 29 | def num_items(img_array): 30 | num = 1 31 | while not PIL.Image.isImageType(img_array): 32 | num *= len(img_array) 33 | img_array = img_array[-1] 34 | return num 35 | 36 | 37 | def save_image_set(img_array, filename_pattern, sourcefile=None): 38 | ''' 39 | Saves all the (PIL) images in the given array, using the 40 | given filename pattern (which should contain a `%d` to get 41 | the index number of the image). 42 | ''' 43 | if sourcefile is not None: 44 | last_filename = expand_last_filename(img_array, filename_pattern) 45 | # Do nothing if the last file exists and is newer than the sourcefile 46 | if os.path.isfile(last_filename) and (os.path.getmtime(last_filename) 47 | >= os.path.getmtime(sourcefile)): 48 | pbar.descnext(None) 49 | return 50 | # Use multiple threads to write all the image files faster. 51 | pool = WorkerPool(worker=SaveImageWorker) 52 | for img, filename in pbar( 53 | all_items_and_filenames(img_array, filename_pattern), 54 | total=num_items(img_array)): 55 | pool.add(img, filename) 56 | pool.join() 57 | 58 | 59 | class SaveImageWorker(WorkerBase): 60 | def work(self, img, filename, quality=99): 61 | os.makedirs(os.path.dirname(filename), exist_ok=True) 62 | img.save(filename, optimize=True, quality=quality) 63 | 64 | 65 | class SaveImagePool(WorkerPool): 66 | def __init__(self, *args, **kwargs): 67 | super().__init__(*args, worker=SaveImageWorker, **kwargs) 68 | -------------------------------------------------------------------------------- /netdissect/paintwidget.py: -------------------------------------------------------------------------------- 1 | from .labwidget import Widget, Property, minify 2 | 3 | 4 | class PaintWidget(Widget): 5 | def __init__(self, 6 | width=256, height=256, 7 | image='', mask='', brushsize=10.0, oneshot=False, disabled=False, 8 | vanishing=True, opacity=0.7, 9 | **kwargs): 10 | super().__init__(**kwargs) 11 | self.mask = Property(mask) 12 | self.image = Property(image) 13 | self.vanishing = Property(vanishing) 14 | self.brushsize = Property(brushsize) 15 | self.erase = Property(False) 16 | self.oneshot = Property(oneshot) 17 | self.disabled = Property(disabled) 18 | self.width = Property(width) 19 | self.height = Property(height) 20 | self.opacity = Property(opacity) 21 | self.startpos = Property(None) 22 | self.dragpos = Property(None) 23 | self.dragging = Property(False) 24 | 25 | def widget_js(self): 26 | return minify(f''' 27 | {PAINT_WIDGET_JS} 28 | var pw = new PaintWidget(element, model); 29 | ''') 30 | 31 | def widget_html(self): 32 | v = self.view_id() 33 | return minify(f''' 34 | 43 |
44 | ''') 45 | 46 | 47 | PAINT_WIDGET_JS = """ 48 | class PaintWidget { 49 | constructor(el, model) { 50 | this.el = el; 51 | this.model = model; 52 | this.size_changed(); 53 | this.model.on('mask', this.mask_changed.bind(this)); 54 | this.model.on('image', this.image_changed.bind(this)); 55 | this.model.on('vanishing', this.mask_changed.bind(this)); 56 | this.model.on('width', this.size_changed.bind(this)); 57 | this.model.on('height', this.size_changed.bind(this)); 58 | } 59 | mouse_stroke(first_event) { 60 | var self = this; 61 | if (first_event.which === 3 || first_event.button === 2) { 62 | first_event.preventDefault(); 63 | self.mask_canvas.style.pointerEvents = 'none'; 64 | setTimeout(() => { 65 | self.mask_canvas.style.pointerEvents = 'all'; 66 | }, 3000); 67 | return; 68 | } 69 | if (self.model.get('disabled')) { return; } 70 | if (self.model.get('oneshot')) { 71 | var canvas = self.mask_canvas; 72 | var ctx = canvas.getContext('2d'); 73 | ctx.clearRect(0, 0, canvas.width, canvas.height); 74 | } 75 | function track_mouse(evt) { 76 | if (evt.type == 'keydown' || self.model.get('disabled')) { 77 | if (self.model.get('disabled') || evt.key === "Escape") { 78 | window.removeEventListener('mousemove', track_mouse); 79 | window.removeEventListener('mouseup', track_mouse); 80 | window.removeEventListener('keydown', track_mouse, true); 81 | if (self.model.get('dragging')) { 82 | self.model.set('dragging', false); 83 | } 84 | self.mask_changed(); 85 | } 86 | return; 87 | } 88 | if (evt.type == 'mouseup' || 89 | (typeof evt.buttons != 'undefined' && evt.buttons == 0)) { 90 | window.removeEventListener('mousemove', track_mouse); 91 | window.removeEventListener('mouseup', track_mouse); 92 | window.removeEventListener('keydown', track_mouse, true); 93 | self.model.set('dragging', false); 94 | self.model.set('mask', self.mask_canvas.toDataURL()); 95 | return; 96 | } 97 | var p = self.cursor_position(evt); 98 | var d = self.model.get('dragging'); 99 | var e = self.model.get('erase') ^ (evt.ctrlKey); 100 | if (!d) { self.model.set('startpos', [p.x, p.y]); } 101 | self.model.set('dragpos', [p.x, p.y]); 102 | if (!d) { self.model.set('dragging', true); } 103 | self.fill_circle(p.x, p.y, 104 | self.model.get('brushsize'), 105 | e); 106 | } 107 | this.mask_canvas.focus(); 108 | window.addEventListener('mousemove', track_mouse); 109 | window.addEventListener('mouseup', track_mouse); 110 | window.addEventListener('keydown', track_mouse, true); 111 | track_mouse(first_event); 112 | } 113 | mask_changed() { 114 | this.mask_canvas.classList.toggle("vanishing", this.model.get('vanishing')); 115 | this.draw_data_url(this.mask_canvas, this.model.get('mask')); 116 | } 117 | image_changed() { 118 | this.image.src = this.model.get('image'); 119 | } 120 | size_changed() { 121 | this.mask_canvas = document.createElement('canvas'); 122 | this.image = document.createElement('img'); 123 | this.mask_canvas.className = "paintmask"; 124 | this.image.className = "paintimage"; 125 | for (var attr of ['width', 'height']) { 126 | this.mask_canvas[attr] = this.model.get(attr); 127 | this.image[attr] = this.model.get(attr); 128 | } 129 | 130 | this.el.innerHTML = ''; 131 | this.el.appendChild(this.image); 132 | this.el.appendChild(this.mask_canvas); 133 | this.mask_canvas.addEventListener('mousedown', 134 | this.mouse_stroke.bind(this)); 135 | this.mask_changed(); 136 | this.image_changed(); 137 | } 138 | 139 | cursor_position(evt) { 140 | const rect = this.mask_canvas.getBoundingClientRect(); 141 | const x = event.clientX - 
rect.left; 142 | const y = event.clientY - rect.top; 143 | return {x: x, y: y}; 144 | } 145 | 146 | fill_circle(x, y, r, erase, blur) { 147 | var ctx = this.mask_canvas.getContext('2d'); 148 | ctx.save(); 149 | if (blur) { 150 | ctx.filter = 'blur(' + blur + 'px)'; 151 | } 152 | ctx.globalCompositeOperation = ( 153 | erase ? "destination-out" : 'source-over'); 154 | ctx.fillStyle = '#fff'; 155 | ctx.beginPath(); 156 | ctx.arc(x, y, r, 0, 2 * Math.PI); 157 | ctx.fill(); 158 | ctx.restore() 159 | } 160 | 161 | draw_data_url(canvas, durl) { 162 | var ctx = canvas.getContext('2d'); 163 | var img = new Image; 164 | canvas.pendingImg = img; 165 | function imgdone() { 166 | if (canvas.pendingImg == img) { 167 | ctx.clearRect(0, 0, canvas.width, canvas.height); 168 | ctx.drawImage(img, 0, 0); 169 | canvas.pendingImg = null; 170 | } 171 | } 172 | img.addEventListener('load', imgdone); 173 | img.addEventListener('error', imgdone); 174 | img.src = durl; 175 | } 176 | } 177 | """ 178 | -------------------------------------------------------------------------------- /netdissect/parallelfolder.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Variants of pytorch's ImageFolder for loading image datasets with more 3 | information, such as parallel feature channels in separate files, 4 | cached files with lists of filenames, etc. 5 | ''' 6 | 7 | import os 8 | import torch 9 | import re 10 | import random 11 | import numpy 12 | import itertools 13 | import copy 14 | import torch.utils.data as data 15 | from torchvision.datasets.folder import default_loader as tv_default_loader 16 | from PIL import Image 17 | from collections import OrderedDict 18 | from . import pbar 19 | 20 | 21 | def grayscale_loader(path): 22 | with open(path, 'rb') as f: 23 | return Image.open(f).convert('L') 24 | 25 | 26 | class ndarray(numpy.ndarray): 27 | ''' 28 | Wrapper to make ndarrays into heap objects so that shared_state can 29 | be attached as an attribute. 30 | ''' 31 | pass 32 | 33 | 34 | def default_loader(filename): 35 | ''' 36 | Handles both numpy files and image formats. 
37 | ''' 38 | if filename.endswith('.npy') or filename.endswith('.NPY'): 39 | return numpy.load(filename).view(ndarray) 40 | elif filename.endswith('.npz') or filename.endswith('.NPZ'): 41 | return numpy.load(filename) 42 | else: 43 | return tv_default_loader(filename) 44 | 45 | 46 | class ParallelImageFolders(data.Dataset): 47 | """ 48 | A data loader that looks for parallel image filenames, for example 49 | 50 | photo1/park/004234.jpg 51 | photo1/park/004236.jpg 52 | photo1/park/004237.jpg 53 | 54 | photo2/park/004234.png 55 | photo2/park/004236.png 56 | photo2/park/004237.png 57 | """ 58 | 59 | def __init__(self, image_roots, 60 | transform=None, 61 | loader=default_loader, 62 | stacker=None, 63 | classification=False, 64 | identification=False, 65 | intersection=False, 66 | filter_tuples=None, 67 | normalize_filename=None, 68 | verbose=None, 69 | size=None, 70 | shuffle=None, 71 | lazy_init=True): 72 | self.image_roots = image_roots 73 | if transform is not None and not hasattr(transform, '__iter__'): 74 | transform = [transform for _ in image_roots] 75 | self.transforms = transform 76 | self.stacker = stacker 77 | self.loader = loader 78 | self.identification = identification 79 | 80 | def do_lazy_init(): 81 | self.images, self.classes, self.class_to_idx = ( 82 | make_parallel_dataset(image_roots, 83 | classification=classification, 84 | intersection=intersection, 85 | filter_tuples=filter_tuples, 86 | normalize_fn=normalize_filename, 87 | verbose=verbose)) 88 | if len(self.images) == 0: 89 | raise RuntimeError("Found 0 images within: %s" % image_roots) 90 | if shuffle is not None: 91 | random.Random(shuffle).shuffle(self.images) 92 | if size is not None: 93 | self.images = self.images[:size] 94 | self._do_lazy_init = None 95 | # Do slow initialization lazily. 96 | if lazy_init: 97 | self._do_lazy_init = do_lazy_init 98 | else: 99 | do_lazy_init() 100 | 101 | def subset(self, indexes): 102 | ''' 103 | Returns a subset of the current dataset, given by 104 | the set of specified indexes. 105 | ''' 106 | if self._do_lazy_init is not None: 107 | self._do_lazy_init() 108 | # Copy over transforms and other settings. 109 | ds = ParallelImageFolders( 110 | self.image_roots, 111 | transform=self.transforms, 112 | loader=self.loader, 113 | stacker=self.stacker, 114 | identification=self.identification, 115 | lazy_init=True) 116 | # Initialize the subset items directly. 117 | ds.images = [ 118 | copy.deepcopy(self.images[i]) for i in indexes] 119 | ds.classes = self.classes 120 | ds.class_to_idx = self.class_to_idx 121 | ds._do_lazy_init = None 122 | return ds 123 | 124 | def __getattr__(self, attr): 125 | if self._do_lazy_init is not None: 126 | self._do_lazy_init() 127 | return getattr(self, attr) 128 | raise AttributeError() 129 | 130 | def __getitem__(self, index): 131 | if self._do_lazy_init is not None: 132 | self._do_lazy_init() 133 | paths = self.images[index] 134 | if self.classes is not None: 135 | classidx = paths[-1] 136 | paths = paths[:-1] 137 | sources = [self.loader(path) for path in paths] 138 | # Add a common shared state dict to allow random crops/flips to be 139 | # coordinated. 
140 | shared_state = {} 141 | for s in sources: 142 | try: 143 | s.shared_state = shared_state 144 | except: 145 | pass 146 | if self.transforms is not None: 147 | sources = [transform(source) if transform is not None else source 148 | for source, transform 149 | in itertools.zip_longest(sources, self.transforms)] 150 | if self.stacker is not None: 151 | sources = self.stacker(sources) 152 | if self.classes is None and not self.identification: 153 | return sources 154 | else: 155 | sources = [sources] 156 | if self.classes is not None: 157 | sources.append(classidx) 158 | if self.identification: 159 | sources.append(index) 160 | sources = tuple(sources) 161 | return sources 162 | 163 | def __len__(self): 164 | if self._do_lazy_init is not None: 165 | self._do_lazy_init() 166 | return len(self.images) 167 | 168 | 169 | def is_npy_file(path): 170 | return (path.endswith('.npy') or path.endswith('.NPY') or 171 | path.endswith('.npz') or path.endswith('.NPZ')) 172 | 173 | 174 | def is_image_file(path): 175 | return None != re.search(r'\.(jpe?g|png)$', path, re.IGNORECASE) 176 | 177 | 178 | def walk_image_files(rootdir, verbose=None): 179 | indexfile = '%s.txt' % rootdir 180 | if os.path.isfile(indexfile): 181 | basedir = os.path.dirname(rootdir) 182 | with open(indexfile) as f: 183 | result = sorted([os.path.join(basedir, line.strip()) 184 | for line in f.readlines()]) 185 | return result 186 | result = [] 187 | # for dirname, _, fnames in sorted(pbar(os.walk(rootdir), 188 | # desc='Walking %s' % os.path.basename(rootdir))): 189 | for dirname, _, fnames in sorted(os.walk(rootdir)): 190 | for fname in sorted(fnames): 191 | if is_image_file(fname) or is_npy_file(fname): 192 | result.append(os.path.join(dirname, fname)) 193 | return result 194 | 195 | 196 | def make_parallel_dataset(image_roots, classification=False, 197 | intersection=False, filter_tuples=None, normalize_fn=None, 198 | verbose=None): 199 | """ 200 | Returns ([(img1, img2, clsid, id), (img1, img2, clsid, id)..], 201 | classes, class_to_idx) 202 | """ 203 | image_roots = [os.path.expanduser(d) for d in image_roots] 204 | image_sets = OrderedDict() 205 | if normalize_fn is None: 206 | def normalize_fn(x): return os.path.splitext(x)[0] 207 | for j, root in enumerate(image_roots): 208 | for path in walk_image_files(root, verbose=verbose): 209 | key = normalize_fn(os.path.relpath(path, root)) 210 | if key not in image_sets: 211 | image_sets[key] = [] 212 | if not intersection and len(image_sets[key]) != j: 213 | raise RuntimeError('Images not parallel: ' 214 | '{} missing from {}'.format(key, root)) 215 | image_sets[key].append(path) 216 | if classification: 217 | classes = sorted(set([os.path.basename(os.path.dirname(k)) 218 | for k in image_sets.keys()])) 219 | class_to_idx = dict({k: v for v, k in enumerate(classes)}) 220 | for k, v in image_sets.items(): 221 | v.append(class_to_idx[os.path.basename(os.path.dirname(k))]) 222 | else: 223 | classes, class_to_idx = None, None 224 | tuples = [] 225 | for key, value in image_sets.items(): 226 | if len(value) != (len(image_roots) + (1 if classification else 0)): 227 | if intersection: 228 | continue 229 | else: 230 | raise RuntimeError( 231 | 'Images not parallel: %s missing from one dir' % (key)) 232 | value = tuple(value) 233 | if filter_tuples and not filter_tuples(value): 234 | continue 235 | tuples.append(value) 236 | return tuples, classes, class_to_idx 237 | 238 | 239 | class NpzToTensor: 240 | """ 241 | A data transformer for converting a loaded npz file to a pytorch 242 | 
tensor. Since an npz file stores tensors under keys, a key can be 243 | specified. Otherwise, the first key is dereferenced. 244 | """ 245 | 246 | def __init__(self, key=None): 247 | self.key = key 248 | 249 | def __call__(self, data): 250 | key = self.key or next(iter(data)) 251 | return torch.from_numpy(data[key]) 252 | -------------------------------------------------------------------------------- /netdissect/pbar.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Utilities for showing progress bars, controlling default verbosity, etc. 3 | ''' 4 | 5 | # If the tqdm package is not available, then do not show progress bars; 6 | # just connect print_progress to print. 7 | import sys 8 | import types 9 | import builtins 10 | try: 11 | from tqdm import tqdm 12 | try: 13 | from tqdm.notebook import tqdm as tqdm_nb 14 | except: 15 | from tqdm import tqdm_notebook as tqdm_nb 16 | except: 17 | tqdm = None 18 | 19 | default_verbosity = True 20 | next_description = None 21 | python_print = builtins.print 22 | 23 | 24 | def post(**kwargs): 25 | ''' 26 | When within a progress loop, pbar.post(k=str) will display 27 | the given k=str status on the right-hand-side of the progress 28 | status bar. If not within a visible progress bar, does nothing. 29 | ''' 30 | innermost = innermost_tqdm() 31 | if innermost is not None: 32 | innermost.set_postfix(**kwargs) 33 | 34 | 35 | def desc(desc): 36 | ''' 37 | When within a progress loop, pbar.desc(str) changes the 38 | left-hand-side description of the loop toe the given description. 39 | ''' 40 | innermost = innermost_tqdm() 41 | if innermost is not None: 42 | innermost.set_description(str(desc)) 43 | 44 | 45 | def descnext(desc): 46 | ''' 47 | Called before starting a progress loop, pbar.descnext(str) 48 | sets the description text that will be used in the following loop. 49 | ''' 50 | global next_description 51 | if not default_verbosity or tqdm is None: 52 | return 53 | next_description = desc 54 | 55 | 56 | def print(*args): 57 | ''' 58 | When within a progress loop, will print above the progress loop. 59 | ''' 60 | global next_description 61 | next_description = None 62 | if default_verbosity: 63 | msg = ' '.join(str(s) for s in args) 64 | if tqdm is None: 65 | python_print(msg) 66 | else: 67 | tqdm.write(msg) 68 | 69 | 70 | def tqdm_terminal(it, *args, **kwargs): 71 | ''' 72 | Some settings for tqdm that make it run better in resizable terminals. 73 | ''' 74 | return tqdm(it, *args, dynamic_ncols=True, ascii=True, 75 | leave=(innermost_tqdm() is not None), **kwargs) 76 | 77 | 78 | def in_notebook(): 79 | ''' 80 | True if running inside a Jupyter notebook. 81 | ''' 82 | # From https://stackoverflow.com/a/39662359/265298 83 | try: 84 | shell = get_ipython().__class__.__name__ 85 | if shell == 'ZMQInteractiveShell': 86 | return True # Jupyter notebook or qtconsole 87 | elif shell == 'TerminalInteractiveShell': 88 | return False # Terminal running IPython 89 | else: 90 | return False # Other type (?) 91 | except NameError: 92 | return False # Probably standard Python interpreter 93 | 94 | 95 | def innermost_tqdm(): 96 | ''' 97 | Returns the innermost active tqdm progress loop on the stack. 98 | ''' 99 | if hasattr(tqdm, '_instances') and len(tqdm._instances) > 0: 100 | return max(tqdm._instances, key=lambda x: x.pos) 101 | else: 102 | return None 103 | 104 | 105 | def reporthook(*args, **kwargs): 106 | ''' 107 | For use with urllib.request.urlretrieve. 
108 | 109 | with pbar.reporthook() as hook: 110 | urllib.request.urlretrieve(url, filename, reporthook=hook) 111 | ''' 112 | kwargs2 = dict(unit_scale=True, miniters=1) 113 | kwargs2.update(kwargs) 114 | bar = __call__(None, *args, **kwargs2) 115 | 116 | class ReportHook(object): 117 | def __init__(self, t): 118 | self.t = t 119 | 120 | def __call__(self, b=1, bsize=1, tsize=None): 121 | if hasattr(self.t, 'total'): 122 | if tsize is not None: 123 | self.t.total = tsize 124 | if hasattr(self.t, 'update'): 125 | self.t.update(b * bsize - self.t.n) 126 | 127 | def __enter__(self): 128 | return self 129 | 130 | def __exit__(self, *exc): 131 | if hasattr(self.t, '__exit__'): 132 | self.t.__exit__(*exc) 133 | return ReportHook(bar) 134 | 135 | 136 | def __call__(x, *args, **kwargs): 137 | ''' 138 | Invokes a progress function that can wrap iterators to print 139 | progress messages, if verbose is True. 140 | 141 | If verbose is False or tqdm is unavailable, then a quiet 142 | non-printing identity function is used. 143 | 144 | verbose can also be set to a spefific progress function rather 145 | than True, and that function will be used. 146 | ''' 147 | global default_verbosity, next_description 148 | if not default_verbosity or tqdm is None: 149 | return x 150 | if default_verbosity == True: 151 | fn = tqdm_nb if in_notebook() else tqdm_terminal 152 | else: 153 | fn = default_verbosity 154 | if next_description is not None: 155 | kwargs = dict(kwargs) 156 | kwargs['desc'] = next_description 157 | next_description = None 158 | return fn(x, *args, **kwargs) 159 | 160 | 161 | class VerboseContextManager(): 162 | def __init__(self, v, entered=False): 163 | self.v, self.entered, self.saved = v, False, [] 164 | if entered: 165 | self.__enter__() 166 | self.entered = True 167 | 168 | def __enter__(self): 169 | global default_verbosity 170 | if self.entered: 171 | self.entered = False 172 | else: 173 | self.saved.append(default_verbosity) 174 | default_verbosity = self.v 175 | return self 176 | 177 | def __exit__(self, exc_type, exc_value, exc_traceback): 178 | global default_verbosity 179 | default_verbosity = self.saved.pop() 180 | 181 | def __call__(self, v=True): 182 | ''' 183 | Calling the context manager makes a new context that is 184 | pre-entered, so it works as both a plain function and as a 185 | factory for a context manager. 186 | ''' 187 | new_v = v if self.v else not v 188 | cm = VerboseContextManager(new_v, entered=True) 189 | default_verbosity = new_v 190 | return cm 191 | 192 | 193 | # Use as either "with pbar.verbose:" or "pbar.verbose(False)", or also 194 | # "with pbar.verbose(False):" 195 | verbose = VerboseContextManager(True) 196 | 197 | # Use as either "with @pbar.quiet" or "pbar.quiet(True)". 
or also 198 | # "with pbar.quiet(True):" 199 | quiet = VerboseContextManager(False) 200 | 201 | 202 | class CallableModule(types.ModuleType): 203 | def __init__(self): 204 | # or super().__init__(__name__) for Python 3 205 | types.ModuleType.__init__(self, __name__) 206 | self.__dict__.update(sys.modules[__name__].__dict__) 207 | 208 | def __call__(self, x, *args, **kwargs): 209 | return __call__(x, *args, **kwargs) 210 | 211 | 212 | sys.modules[__name__] = CallableModule() 213 | -------------------------------------------------------------------------------- /netdissect/pidfile.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Utility for simple distribution of work on multiple processes, by 3 | making sure only one process is working on a job at once. 4 | ''' 5 | 6 | import os 7 | import errno 8 | import socket 9 | import atexit 10 | import time 11 | import sys 12 | 13 | 14 | def exclusive_dirfn(*args): 15 | ''' 16 | Convenience function to get exclusive access to an unfinished 17 | experiment directory. Exits the program if the directory is 18 | already done or busy (using exit_of_job_done). Otherwise, 19 | returns a function creates filenames within that directory. 20 | ''' 21 | directory = os.path.join(*[str(a) for a in args]) 22 | exit_if_job_done(directory) 23 | 24 | def dirfn(*fn): 25 | return os.path.join(directory, *fn) 26 | dirfn.dir = directory 27 | 28 | def done(): 29 | mark_job_done(directory) 30 | dirfn.done = done 31 | print('Working in %s' % directory) 32 | return dirfn 33 | 34 | 35 | def exit_if_job_done(directory, redo=False, force=False, verbose=True): 36 | if pidfile_taken(os.path.join(directory, 'lockfile.pid'), 37 | force=force, verbose=verbose): 38 | sys.exit(0) 39 | donefile = os.path.join(directory, 'done.txt') 40 | if os.path.isfile(donefile): 41 | with open(donefile) as f: 42 | msg = f.read() 43 | if redo or force: 44 | if verbose: 45 | print('Removing %s %s' % (donefile, msg)) 46 | os.remove(donefile) 47 | else: 48 | if verbose: 49 | print('%s %s' % (donefile, msg)) 50 | sys.exit(0) 51 | 52 | 53 | def mark_job_done(directory): 54 | with open(os.path.join(directory, 'done.txt'), 'w') as f: 55 | f.write('done by %d@%s %s at %s' % 56 | (os.getpid(), socket.gethostname(), 57 | os.getenv('STY', ''), 58 | time.strftime('%c'))) 59 | 60 | 61 | def pidfile_taken(path, verbose=False, force=False): 62 | ''' 63 | Usage. To grab an exclusive lock for the remaining duration of the 64 | current process (and exit if another process already has the lock), 65 | do this: 66 | 67 | if pidfile_taken('job_423/lockfile.pid', verbose=True): 68 | sys.exit(0) 69 | 70 | To do a batch of jobs, just run a script that does them all on 71 | each available machine, sharing a network filesystem. When each 72 | job grabs a lock, then this will automatically distribute the 73 | jobs so that each one is done just once on one machine. 74 | ''' 75 | 76 | # Try to create the file exclusively and write my pid into it. 77 | try: 78 | os.makedirs(os.path.dirname(path), exist_ok=True) 79 | fd = os.open(path, os.O_CREAT | os.O_EXCL | os.O_RDWR) 80 | except OSError as e: 81 | if e.errno == errno.EEXIST: 82 | # If we cannot because there was a race, yield the conflicter. 83 | conflicter = 'race' 84 | try: 85 | with open(path, 'r') as lockfile: 86 | conflicter = lockfile.read().strip() or 'empty' 87 | except: 88 | pass 89 | # Force is for manual one-time use, for deleting stale lockfiles. 
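            # (Illustrative note: with force=True the stale lockfile is deleted and the
            # function retries once with force=False, so at most one removal happens per call.)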
90 | if force: 91 | if verbose: 92 | print('Removing %s from %s' % (path, conflicter)) 93 | os.remove(path) 94 | return pidfile_taken(path, verbose=verbose, force=False) 95 | if verbose: 96 | print('%s held by %s' % (path, conflicter)) 97 | return conflicter 98 | else: 99 | # Other problems get an exception. 100 | raise 101 | # Register to delete this file on exit. 102 | lockfile = os.fdopen(fd, 'r+') 103 | atexit.register(delete_pidfile, lockfile, path) 104 | # Write my pid into the open file. 105 | lockfile.write('%d@%s %s\n' % (os.getpid(), socket.gethostname(), 106 | os.getenv('STY', ''))) 107 | lockfile.flush() 108 | os.fsync(lockfile) 109 | # Return 'None' to say there was not a conflict. 110 | return None 111 | 112 | 113 | def delete_pidfile(lockfile, path): 114 | ''' 115 | Runs at exit after pidfile_taken succeeds. 116 | ''' 117 | if lockfile is not None: 118 | try: 119 | lockfile.close() 120 | except: 121 | pass 122 | try: 123 | os.unlink(path) 124 | except: 125 | pass 126 | -------------------------------------------------------------------------------- /netdissect/renormalize.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import torch 3 | import PIL 4 | import io 5 | import base64 6 | import re 7 | from torchvision import transforms 8 | 9 | 10 | def as_tensor(data, source='zc', target='zc'): 11 | renorm = renormalizer(source=source, target=target) 12 | return renorm(data) 13 | 14 | 15 | def as_image(data, source='zc', target='byte'): 16 | assert len(data.shape) == 3 17 | renorm = renormalizer(source=source, target=target) 18 | return PIL.Image.fromarray(renorm(data). 19 | permute(1, 2, 0).cpu().numpy()) 20 | 21 | 22 | def as_url(data, source='zc', size=None): 23 | if isinstance(data, PIL.Image.Image): 24 | img = data 25 | else: 26 | img = as_image(data, source) 27 | if size is not None: 28 | img = img.resize(size, resample=PIL.Image.BILINEAR) 29 | buffered = io.BytesIO() 30 | img.save(buffered, format='png') 31 | b64 = base64.b64encode(buffered.getvalue()).decode('utf-8') 32 | return 'data:image/png;base64,%s' % (b64) 33 | 34 | 35 | def from_image(im, target='zc', size=None): 36 | if im.format != 'RGB': 37 | im = im.convert('RGB') 38 | if size is not None: 39 | im = im.resize(size, resample=PIL.Image.BILINEAR) 40 | pt = transforms.functional.to_tensor(im) 41 | renorm = renormalizer(source='pt', target=target) 42 | return renorm(pt) 43 | 44 | 45 | def from_url(url, target='zc', size=None): 46 | image_data = re.sub('^data:image/.+;base64,', '', url) 47 | im = PIL.Image.open(io.BytesIO(base64.b64decode(image_data))) 48 | if target == 'image' and size is None: 49 | return im 50 | return from_image(im, target, size=size) 51 | 52 | 53 | def renormalizer(source='zc', target='zc'): 54 | ''' 55 | Returns a function that imposes a standard normalization on 56 | the image data. The returned renormalizer operates on either 57 | 3d tensor (single image) or 4d tensor (image batch) data. 58 | The normalization target choices are: 59 | 60 | zc (default) - zero centered [-1..1] 61 | pt - pytorch [0..1] 62 | imagenet - zero mean, unit stdev imagenet stats (approx [-2.1...2.6]) 63 | byte - as from an image file, [0..255] 64 | 65 | If a source is provided (a dataset or transform), then, the renormalizer 66 | first reverses any normalization found in the data source before 67 | imposing the specified normalization. When no source is provided, 68 | the input data is assumed to be pytorch-normalized (range [0..1]). 
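    A short usage sketch (tensor names here are illustrative):

        renorm = renormalizer(source='pt', target='imagenet')
        batch_imagenet = renorm(batch_pt)            # [0..1] -> imagenet mean/std
        to_bytes = renormalizer(source='zc', target='byte')
        img_bytes = to_bytes(img_zc)                 # [-1..1] -> uint8 in [0..255]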
69 | ''' 70 | if isinstance(source, str): 71 | oldoffset, oldscale = OFFSET_SCALE[source] 72 | else: 73 | normalizer = find_normalizer(source) 74 | oldoffset, oldscale = ( 75 | (normalizer.mean, normalizer.std) if normalizer is not None 76 | else OFFSET_SCALE['pt']) 77 | newoffset, newscale = (target if isinstance(target, tuple) 78 | else OFFSET_SCALE[target]) 79 | return Renormalizer(oldoffset, oldscale, newoffset, newscale, 80 | tobyte=(target == 'byte')) 81 | 82 | 83 | # The three commonly-seen image normalization schemes. 84 | OFFSET_SCALE = dict( 85 | pt=([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]), 86 | zc=([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 87 | imagenet=([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 88 | imagenet_meanonly=([0.485, 0.456, 0.406], 89 | [1.0 / 255, 1.0 / 255, 1.0 / 255]), 90 | places_meanonly=([0.475, 0.441, 0.408], 91 | [1.0 / 255, 1.0 / 255, 1.0 / 255]), 92 | byte=([0.0, 0.0, 0.0], [1.0 / 255, 1.0 / 255, 1.0 / 255])) 93 | 94 | NORMALIZER = {k: transforms.Normalize(*OFFSET_SCALE[k]) for k in OFFSET_SCALE} 95 | 96 | 97 | def find_normalizer(source=None): 98 | ''' 99 | Crawl around the transforms attached to a dataset looking for a 100 | Normalize transform to return. 101 | ''' 102 | if source is None: 103 | return None 104 | if isinstance(source, (transforms.Normalize, Renormalizer)): 105 | return source 106 | t = getattr(source, 'transform', None) 107 | if t is not None: 108 | return find_normalizer(t) 109 | ts = getattr(source, 'transforms', None) 110 | if ts is not None: 111 | for t in reversed(ts): 112 | result = find_normalizer(t) 113 | if result is not None: 114 | return result 115 | return None 116 | 117 | 118 | class Renormalizer: 119 | def __init__(self, oldoffset, oldscale, newoffset, newscale, tobyte=False): 120 | self.mul = torch.from_numpy( 121 | numpy.array(oldscale) / numpy.array(newscale)) 122 | self.add = torch.from_numpy( 123 | (numpy.array(oldoffset) - numpy.array(newoffset)) 124 | / numpy.array(newscale)) 125 | self.tobyte = tobyte 126 | # Store these away to allow the data to be renormalized again 127 | self.mean = newoffset 128 | self.std = newscale 129 | 130 | def __call__(self, data): 131 | mul, add = [d.to(data.device, data.dtype) for d in [self.mul, self.add]] 132 | if data.ndimension() == 3: 133 | mul, add = [d[:, None, None] for d in [mul, add]] 134 | elif data.ndimension() == 4: 135 | mul, add = [d[None, :, None, None] for d in [mul, add]] 136 | result = data.mul(mul).add_(add) 137 | if self.tobyte: 138 | result = result.clamp(0, 255).byte() 139 | return result 140 | -------------------------------------------------------------------------------- /netdissect/report.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | 10 | 14 | 17 | 20 | 24 | 31 | 32 | 33 |
[report.html: the HTML/JS markup of this template was stripped during extraction, leaving only bare line numbers. What survives of the structure: a page header bound to {{ header.name }}, "sort by" controls for iou / label / unit, and one entry per unit rendered as "unit {{ r.unit }} ({{ r.label }}, iou {{ r.iou | fixed(4) }})".]
54 | 55 | 80 | 81 | -------------------------------------------------------------------------------- /netdissect/sampler.py: -------------------------------------------------------------------------------- 1 | ''' 2 | A sampler is just a list of integer listing the indexes of the 3 | inputs in a data set to sample. For reproducibility, the 4 | FixedRandomSubsetSampler uses a seeded prng to produce the same 5 | sequence always. FixedSubsetSampler is just a wrapper for an 6 | explicit list of integers. 7 | 8 | coordinate_sample solves another sampling problem: when testing 9 | convolutional outputs, we can reduce data explosing by sampling 10 | random points of the feature map rather than the entire feature map. 11 | coordinate_sample does this in a deterministic way that is also 12 | resolution-independent. 13 | ''' 14 | 15 | import numpy 16 | import random 17 | from torch.utils.data.sampler import Sampler 18 | 19 | 20 | class FixedSubsetSampler(Sampler): 21 | """Represents a fixed sequence of data set indices. 22 | Subsets can be created by specifying a subset of output indexes. 23 | """ 24 | 25 | def __init__(self, samples): 26 | self.samples = samples 27 | 28 | def __iter__(self): 29 | return iter(self.samples) 30 | 31 | def __len__(self): 32 | return len(self.samples) 33 | 34 | def __getitem__(self, key): 35 | return self.samples[key] 36 | 37 | def subset(self, new_subset): 38 | return FixedSubsetSampler(self.dereference(new_subset)) 39 | 40 | def dereference(self, indices): 41 | ''' 42 | Translate output sample indices (small numbers indexing the sample) 43 | to input sample indices (larger number indexing the original full set) 44 | ''' 45 | return [self.samples[i] for i in indices] 46 | 47 | 48 | class FixedRandomSubsetSampler(FixedSubsetSampler): 49 | """Samples a fixed number of samples from the dataset, deterministically. 50 | Arguments: 51 | data_source, 52 | sample_size, 53 | seed (optional) 54 | """ 55 | 56 | def __init__(self, data_source, start=None, end=None, seed=1): 57 | rng = random.Random(seed) 58 | shuffled = list(range(len(data_source))) 59 | rng.shuffle(shuffled) 60 | self.data_source = data_source 61 | super(FixedRandomSubsetSampler, self).__init__(shuffled[start:end]) 62 | 63 | def class_subset(self, class_filter): 64 | ''' 65 | Returns only the subset matching the given rule. 66 | ''' 67 | if isinstance(class_filter, int): 68 | def rule(d): return d[1] == class_filter 69 | else: 70 | rule = class_filter 71 | return self.subset([i for i, j in enumerate(self.samples) 72 | if rule(self.data_source[j])]) 73 | 74 | 75 | def coordinate_sample(shape, sample_size, seeds, grid=13, seed=1, flat=False): 76 | ''' 77 | Returns a (end-start) sets of sample_size grid points within 78 | the shape given. If the shape dimensions are a multiple of 'grid', 79 | then sampled points within the same row will never be duplicated. 80 | ''' 81 | if flat: 82 | sampind = numpy.zeros((len(seeds), sample_size), dtype=int) 83 | else: 84 | sampind = numpy.zeros((len(seeds), 2, sample_size), dtype=int) 85 | assert sample_size <= grid 86 | for j, seed in enumerate(seeds): 87 | rng = numpy.random.RandomState(seed) 88 | # Shuffle the 169 random grid squares, and pick :sample_size. 89 | square_count = grid ** len(shape) 90 | square = numpy.stack(numpy.unravel_index( 91 | rng.choice(square_count, square_count)[:sample_size], 92 | (grid,) * len(shape))) 93 | # Then add a random offset to each x, y and put in the range [0...1) 94 | # Notice this selects the same locations regardless of resolution. 
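        # (Worked example of the idea: grid square (3, 7) with random offsets
        # (0.25, 0.5) becomes fractional position (3.25/13, 7.5/13), which maps to
        # the same relative location whether the feature map is 13x13 or 26x26.)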
95 | uniform = (square + rng.uniform(size=square.shape)) / grid 96 | # TODO: support affine scaling so that we can align receptive field 97 | # centers exactly when sampling neurons in different layers. 98 | coords = (uniform * numpy.array(shape)[:, None]).astype(int) 99 | # Now take sample_size without replacement. We do this in a way 100 | # such that if sample_size is decreased or increased up to 'grid', 101 | # the selected points become a subset, not totally different points. 102 | if flat: 103 | sampind[j] = numpy.ravel_multi_index(coords, dims=shape) 104 | else: 105 | sampind[j] = coords 106 | return sampind 107 | 108 | 109 | def main(): 110 | from . import parallelfolder 111 | import argparse 112 | import os 113 | import shutil 114 | 115 | parser = argparse.ArgumentParser(description='Net dissect utility', 116 | prog='python -m %s.sampler' % __package__) 117 | parser.add_argument('indir') 118 | parser.add_argument('outdir') 119 | parser.add_argument('--size', type=int, default=100) 120 | parser.add_argument('--test', action='store_true', default=False) 121 | args = parser.parse_args() 122 | if os.path.exists(args.outdir): 123 | print('%s already exists' % args.outdir) 124 | sys.exit(1) 125 | os.makedirs(args.outdir) 126 | dataset = parallelfolder.ParallelImageFolders([args.indir]) 127 | sampler = FixedRandomSubsetSampler(dataset, end=args.size) 128 | seen_filenames = set() 129 | 130 | def number_filename(filename, number): 131 | if '.' in filename: 132 | a, b = filename.rsplit('.', 1) 133 | return a + '_%d.' % number + b 134 | return filename + '_%d' % number 135 | for i in sampler.dereference(range(args.size)): 136 | sourcefile = dataset.images[i][0] 137 | filename = os.path.basename(sourcefile) 138 | template = filename 139 | num = 0 140 | while filename in seen_filenames: 141 | num += 1 142 | filename = number_filename(template, num) 143 | seen_filenames.add(filename) 144 | shutil.copy(os.path.join(args.indir, sourcefile), 145 | os.path.join(args.outdir, filename)) 146 | 147 | 148 | def test(): 149 | from numpy.testing import assert_almost_equal 150 | # Test that coordinate_sample is deterministic, in-range, and scalable. 
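    # (The literal arrays below pin the sampler's output for fixed seeds, so any
    # change to the seeding or grid logic shows up as a test failure.)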
151 | assert_almost_equal(coordinate_sample((26, 26), 10, range(101, 102)), 152 | [[[14, 0, 12, 11, 8, 13, 11, 20, 7, 20], 153 | [9, 22, 7, 11, 23, 18, 21, 15, 2, 5]]]) 154 | assert_almost_equal(coordinate_sample((13, 13), 10, range(101, 102)), 155 | [[[7, 0, 6, 5, 4, 6, 5, 10, 3, 20 // 2], 156 | [4, 11, 3, 5, 11, 9, 10, 7, 1, 5 // 2]]]) 157 | assert_almost_equal(coordinate_sample((13, 13), 10, range(100, 102), 158 | flat=True), 159 | [[8, 24, 67, 103, 87, 79, 138, 94, 98, 53], 160 | [95, 11, 81, 70, 63, 87, 75, 137, 40, 2 + 10 * 13]]) 161 | assert_almost_equal(coordinate_sample((13, 13), 10, range(101, 103), 162 | flat=True), 163 | [[95, 11, 81, 70, 63, 87, 75, 137, 40, 132], 164 | [0, 78, 114, 111, 66, 45, 72, 73, 79, 135]]) 165 | assert_almost_equal(coordinate_sample((26, 26), 10, range(101, 102), 166 | flat=True), 167 | [[373, 22, 319, 297, 231, 356, 307, 535, 184, 5 + 20 * 26]]) 168 | # Test FixedRandomSubsetSampler 169 | fss = FixedRandomSubsetSampler(range(10)) 170 | assert len(fss) == 10 171 | assert_almost_equal(list(fss), [6, 8, 9, 7, 5, 3, 0, 4, 1, 2]) 172 | fss = FixedRandomSubsetSampler(range(10), 3, 8) 173 | assert len(fss) == 5 174 | assert_almost_equal(list(fss), [7, 5, 3, 0, 4]) 175 | fss = FixedRandomSubsetSampler([(i, i % 3) for i in range(10)] 176 | ).class_subset(class_filter=1) 177 | assert len(fss) == 3 178 | assert_almost_equal(list(fss), [7, 4, 1]) 179 | 180 | 181 | if __name__ == '__main__': 182 | import sys 183 | if '--test' in sys.argv[1:]: 184 | test() 185 | else: 186 | main() 187 | -------------------------------------------------------------------------------- /netdissect/segmodel/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import ModelBuilder, SegmentationModule 2 | -------------------------------------------------------------------------------- /netdissect/segmodel/colors150.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidbau/dissect/9421eaa8672fd051088de6c0225a385064070935/netdissect/segmodel/colors150.npy -------------------------------------------------------------------------------- /netdissect/segmodel/mobilenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | This MobileNetV2 implementation is modified from the following repository: 3 | https://github.com/tonylins/pytorch-mobilenet-v2 4 | """ 5 | 6 | import os 7 | import sys 8 | import torch 9 | import torch.nn as nn 10 | import math 11 | try: 12 | from lib.nn import SynchronizedBatchNorm2d 13 | except ImportError: 14 | from torch.nn import BatchNorm2d as SynchronizedBatchNorm2d 15 | 16 | try: 17 | from urllib import urlretrieve 18 | except ImportError: 19 | from urllib.request import urlretrieve 20 | 21 | 22 | __all__ = ['mobilenetv2'] 23 | 24 | 25 | model_urls = { 26 | 'mobilenetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/mobilenet_v2.pth.tar', 27 | } 28 | 29 | 30 | def conv_bn(inp, oup, stride): 31 | return nn.Sequential( 32 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 33 | SynchronizedBatchNorm2d(oup), 34 | nn.ReLU6(inplace=True) 35 | ) 36 | 37 | 38 | def conv_1x1_bn(inp, oup): 39 | return nn.Sequential( 40 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 41 | SynchronizedBatchNorm2d(oup), 42 | nn.ReLU6(inplace=True) 43 | ) 44 | 45 | 46 | class InvertedResidual(nn.Module): 47 | def __init__(self, inp, oup, stride, expand_ratio): 48 | super(InvertedResidual, self).__init__() 49 | self.stride 
= stride 50 | assert stride in [1, 2] 51 | 52 | hidden_dim = round(inp * expand_ratio) 53 | self.use_res_connect = self.stride == 1 and inp == oup 54 | 55 | if expand_ratio == 1: 56 | self.conv = nn.Sequential( 57 | # dw 58 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 59 | SynchronizedBatchNorm2d(hidden_dim), 60 | nn.ReLU6(inplace=True), 61 | # pw-linear 62 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 63 | SynchronizedBatchNorm2d(oup), 64 | ) 65 | else: 66 | self.conv = nn.Sequential( 67 | # pw 68 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 69 | SynchronizedBatchNorm2d(hidden_dim), 70 | nn.ReLU6(inplace=True), 71 | # dw 72 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 73 | SynchronizedBatchNorm2d(hidden_dim), 74 | nn.ReLU6(inplace=True), 75 | # pw-linear 76 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 77 | SynchronizedBatchNorm2d(oup), 78 | ) 79 | 80 | def forward(self, x): 81 | if self.use_res_connect: 82 | return x + self.conv(x) 83 | else: 84 | return self.conv(x) 85 | 86 | 87 | class MobileNetV2(nn.Module): 88 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 89 | super(MobileNetV2, self).__init__() 90 | block = InvertedResidual 91 | input_channel = 32 92 | last_channel = 1280 93 | interverted_residual_setting = [ 94 | # t, c, n, s 95 | [1, 16, 1, 1], 96 | [6, 24, 2, 2], 97 | [6, 32, 3, 2], 98 | [6, 64, 4, 2], 99 | [6, 96, 3, 1], 100 | [6, 160, 3, 2], 101 | [6, 320, 1, 1], 102 | ] 103 | 104 | # building first layer 105 | assert input_size % 32 == 0 106 | input_channel = int(input_channel * width_mult) 107 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel 108 | self.features = [conv_bn(3, input_channel, 2)] 109 | # building inverted residual blocks 110 | for t, c, n, s in interverted_residual_setting: 111 | output_channel = int(c * width_mult) 112 | for i in range(n): 113 | if i == 0: 114 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 115 | else: 116 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 117 | input_channel = output_channel 118 | # building last several layers 119 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 120 | # make it nn.Sequential 121 | self.features = nn.Sequential(*self.features) 122 | 123 | # building classifier 124 | self.classifier = nn.Sequential( 125 | nn.Dropout(0.2), 126 | nn.Linear(self.last_channel, n_class), 127 | ) 128 | 129 | self._initialize_weights() 130 | 131 | def forward(self, x): 132 | x = self.features(x) 133 | x = x.mean(3).mean(2) 134 | x = self.classifier(x) 135 | return x 136 | 137 | def _initialize_weights(self): 138 | for m in self.modules(): 139 | if isinstance(m, nn.Conv2d): 140 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 141 | m.weight.data.normal_(0, math.sqrt(2. / n)) 142 | if m.bias is not None: 143 | m.bias.data.zero_() 144 | elif isinstance(m, SynchronizedBatchNorm2d): 145 | m.weight.data.fill_(1) 146 | m.bias.data.zero_() 147 | elif isinstance(m, nn.Linear): 148 | n = m.weight.size(1) 149 | m.weight.data.normal_(0, 0.01) 150 | m.bias.data.zero_() 151 | 152 | 153 | def mobilenetv2(pretrained=False, **kwargs): 154 | """Constructs a MobileNet_V2 model. 
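    (Usage sketch: model = mobilenetv2(pretrained=True) fetches the checkpoint
    named in model_urls into ./pretrained via load_url on first use; this
    assumes network access and write permission in the working directory.)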
155 | 156 | Args: 157 | pretrained (bool): If True, returns a model pre-trained on ImageNet 158 | """ 159 | model = MobileNetV2(n_class=1000, **kwargs) 160 | if pretrained: 161 | model.load_state_dict(load_url(model_urls['mobilenetv2']), strict=False) 162 | return model 163 | 164 | 165 | def load_url(url, model_dir='./pretrained', map_location=None): 166 | if not os.path.exists(model_dir): 167 | os.makedirs(model_dir) 168 | filename = url.split('/')[-1] 169 | cached_file = os.path.join(model_dir, filename) 170 | if not os.path.exists(cached_file): 171 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 172 | urlretrieve(url, cached_file) 173 | return torch.load(cached_file, map_location=map_location) 174 | 175 | -------------------------------------------------------------------------------- /netdissect/segmodel/object150_info.csv: -------------------------------------------------------------------------------- 1 | Idx,Ratio,Train,Val,Stuff,Name 2 | 1,0.1576,11664,1172,1,wall 3 | 2,0.1072,6046,612,1,building;edifice 4 | 3,0.0878,8265,796,1,sky 5 | 4,0.0621,9336,917,1,floor;flooring 6 | 5,0.0480,6678,641,0,tree 7 | 6,0.0450,6604,643,1,ceiling 8 | 7,0.0398,4023,408,1,road;route 9 | 8,0.0231,1906,199,0,bed 10 | 9,0.0198,4688,460,0,windowpane;window 11 | 10,0.0183,2423,225,1,grass 12 | 11,0.0181,2874,294,0,cabinet 13 | 12,0.0166,3068,310,1,sidewalk;pavement 14 | 13,0.0160,5075,526,0,person;individual;someone;somebody;mortal;soul 15 | 14,0.0151,1804,190,1,earth;ground 16 | 15,0.0118,6666,796,0,door;double;door 17 | 16,0.0110,4269,411,0,table 18 | 17,0.0109,1691,160,1,mountain;mount 19 | 18,0.0104,3999,441,0,plant;flora;plant;life 20 | 19,0.0104,2149,217,0,curtain;drape;drapery;mantle;pall 21 | 20,0.0103,3261,318,0,chair 22 | 21,0.0098,3164,306,0,car;auto;automobile;machine;motorcar 23 | 22,0.0074,709,75,1,water 24 | 23,0.0067,3296,315,0,painting;picture 25 | 24,0.0065,1191,106,0,sofa;couch;lounge 26 | 25,0.0061,1516,162,0,shelf 27 | 26,0.0060,667,69,1,house 28 | 27,0.0053,651,57,1,sea 29 | 28,0.0052,1847,224,0,mirror 30 | 29,0.0046,1158,128,1,rug;carpet;carpeting 31 | 30,0.0044,480,44,1,field 32 | 31,0.0044,1172,98,0,armchair 33 | 32,0.0044,1292,184,0,seat 34 | 33,0.0033,1386,138,0,fence;fencing 35 | 34,0.0031,698,61,0,desk 36 | 35,0.0030,781,73,0,rock;stone 37 | 36,0.0027,380,43,0,wardrobe;closet;press 38 | 37,0.0026,3089,302,0,lamp 39 | 38,0.0024,404,37,0,bathtub;bathing;tub;bath;tub 40 | 39,0.0024,804,99,0,railing;rail 41 | 40,0.0023,1453,153,0,cushion 42 | 41,0.0023,411,37,0,base;pedestal;stand 43 | 42,0.0022,1440,162,0,box 44 | 43,0.0022,800,77,0,column;pillar 45 | 44,0.0020,2650,298,0,signboard;sign 46 | 45,0.0019,549,46,0,chest;of;drawers;chest;bureau;dresser 47 | 46,0.0019,367,36,0,counter 48 | 47,0.0018,311,30,1,sand 49 | 48,0.0018,1181,122,0,sink 50 | 49,0.0018,287,23,1,skyscraper 51 | 50,0.0018,468,38,0,fireplace;hearth;open;fireplace 52 | 51,0.0018,402,43,0,refrigerator;icebox 53 | 52,0.0018,130,12,1,grandstand;covered;stand 54 | 53,0.0018,561,64,1,path 55 | 54,0.0017,880,102,0,stairs;steps 56 | 55,0.0017,86,12,1,runway 57 | 56,0.0017,172,11,0,case;display;case;showcase;vitrine 58 | 57,0.0017,198,18,0,pool;table;billiard;table;snooker;table 59 | 58,0.0017,930,109,0,pillow 60 | 59,0.0015,139,18,0,screen;door;screen 61 | 60,0.0015,564,52,1,stairway;staircase 62 | 61,0.0015,320,26,1,river 63 | 62,0.0015,261,29,1,bridge;span 64 | 63,0.0014,275,22,0,bookcase 65 | 64,0.0014,335,60,0,blind;screen 66 | 65,0.0014,792,75,0,coffee;table;cocktail;table 67 | 
66,0.0014,395,49,0,toilet;can;commode;crapper;pot;potty;stool;throne 68 | 67,0.0014,1309,138,0,flower 69 | 68,0.0013,1112,113,0,book 70 | 69,0.0013,266,27,1,hill 71 | 70,0.0013,659,66,0,bench 72 | 71,0.0012,331,31,0,countertop 73 | 72,0.0012,531,56,0,stove;kitchen;stove;range;kitchen;range;cooking;stove 74 | 73,0.0012,369,36,0,palm;palm;tree 75 | 74,0.0012,144,9,0,kitchen;island 76 | 75,0.0011,265,29,0,computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system 77 | 76,0.0010,324,33,0,swivel;chair 78 | 77,0.0009,304,27,0,boat 79 | 78,0.0009,170,20,0,bar 80 | 79,0.0009,68,6,0,arcade;machine 81 | 80,0.0009,65,8,1,hovel;hut;hutch;shack;shanty 82 | 81,0.0009,248,25,0,bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle 83 | 82,0.0008,492,49,0,towel 84 | 83,0.0008,2510,269,0,light;light;source 85 | 84,0.0008,440,39,0,truck;motortruck 86 | 85,0.0008,147,18,1,tower 87 | 86,0.0008,583,56,0,chandelier;pendant;pendent 88 | 87,0.0007,533,61,0,awning;sunshade;sunblind 89 | 88,0.0007,1989,239,0,streetlight;street;lamp 90 | 89,0.0007,71,5,0,booth;cubicle;stall;kiosk 91 | 90,0.0007,618,53,0,television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box 92 | 91,0.0007,135,12,0,airplane;aeroplane;plane 93 | 92,0.0007,83,5,1,dirt;track 94 | 93,0.0007,178,17,0,apparel;wearing;apparel;dress;clothes 95 | 94,0.0006,1003,104,0,pole 96 | 95,0.0006,182,12,1,land;ground;soil 97 | 96,0.0006,452,50,0,bannister;banister;balustrade;balusters;handrail 98 | 97,0.0006,42,6,1,escalator;moving;staircase;moving;stairway 99 | 98,0.0006,307,31,0,ottoman;pouf;pouffe;puff;hassock 100 | 99,0.0006,965,114,0,bottle 101 | 100,0.0006,117,13,0,buffet;counter;sideboard 102 | 101,0.0006,354,35,0,poster;posting;placard;notice;bill;card 103 | 102,0.0006,108,9,1,stage 104 | 103,0.0006,557,55,0,van 105 | 104,0.0006,52,4,0,ship 106 | 105,0.0005,99,5,0,fountain 107 | 106,0.0005,57,4,1,conveyer;belt;conveyor;belt;conveyer;conveyor;transporter 108 | 107,0.0005,292,31,0,canopy 109 | 108,0.0005,77,9,0,washer;automatic;washer;washing;machine 110 | 109,0.0005,340,38,0,plaything;toy 111 | 110,0.0005,66,3,1,swimming;pool;swimming;bath;natatorium 112 | 111,0.0005,465,49,0,stool 113 | 112,0.0005,50,4,0,barrel;cask 114 | 113,0.0005,622,75,0,basket;handbasket 115 | 114,0.0005,80,9,1,waterfall;falls 116 | 115,0.0005,59,3,0,tent;collapsible;shelter 117 | 116,0.0005,531,72,0,bag 118 | 117,0.0005,282,30,0,minibike;motorbike 119 | 118,0.0005,73,7,0,cradle 120 | 119,0.0005,435,44,0,oven 121 | 120,0.0005,136,25,0,ball 122 | 121,0.0005,116,24,0,food;solid;food 123 | 122,0.0004,266,31,0,step;stair 124 | 123,0.0004,58,12,0,tank;storage;tank 125 | 124,0.0004,418,83,0,trade;name;brand;name;brand;marque 126 | 125,0.0004,319,43,0,microwave;microwave;oven 127 | 126,0.0004,1193,139,0,pot;flowerpot 128 | 127,0.0004,97,23,0,animal;animate;being;beast;brute;creature;fauna 129 | 128,0.0004,347,36,0,bicycle;bike;wheel;cycle 130 | 129,0.0004,52,5,1,lake 131 | 130,0.0004,246,22,0,dishwasher;dish;washer;dishwashing;machine 132 | 131,0.0004,108,13,0,screen;silver;screen;projection;screen 133 | 132,0.0004,201,30,0,blanket;cover 134 | 133,0.0004,285,21,0,sculpture 135 | 134,0.0004,268,27,0,hood;exhaust;hood 136 | 135,0.0003,1020,108,0,sconce 137 | 136,0.0003,1282,122,0,vase 138 | 137,0.0003,528,65,0,traffic;light;traffic;signal;stoplight 139 | 138,0.0003,453,57,0,tray 140 | 
139,0.0003,671,100,0,ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin 141 | 140,0.0003,397,44,0,fan 142 | 141,0.0003,92,8,1,pier;wharf;wharfage;dock 143 | 142,0.0003,228,18,0,crt;screen 144 | 143,0.0003,570,59,0,plate 145 | 144,0.0003,217,22,0,monitor;monitoring;device 146 | 145,0.0003,206,19,0,bulletin;board;notice;board 147 | 146,0.0003,130,14,0,shower 148 | 147,0.0003,178,28,0,radiator 149 | 148,0.0002,504,57,0,glass;drinking;glass 150 | 149,0.0002,775,96,0,clock 151 | 150,0.0002,421,56,0,flag 152 | -------------------------------------------------------------------------------- /netdissect/segmodel/resnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | try: 7 | from lib.nn import SynchronizedBatchNorm2d 8 | except ImportError: 9 | from torch.nn import BatchNorm2d as SynchronizedBatchNorm2d 10 | 11 | try: 12 | from urllib import urlretrieve 13 | except ImportError: 14 | from urllib.request import urlretrieve 15 | 16 | 17 | __all__ = ['ResNet', 'resnet18', 'resnet50', 'resnet101'] # resnet101 is coming soon! 18 | 19 | 20 | model_urls = { 21 | 'resnet18': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet18-imagenet.pth', 22 | 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth', 23 | 'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth' 24 | } 25 | 26 | 27 | def conv3x3(in_planes, out_planes, stride=1): 28 | "3x3 convolution with padding" 29 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 30 | padding=1, bias=False) 31 | 32 | 33 | class BasicBlock(nn.Module): 34 | expansion = 1 35 | 36 | def __init__(self, inplanes, planes, stride=1, downsample=None): 37 | super(BasicBlock, self).__init__() 38 | self.conv1 = conv3x3(inplanes, planes, stride) 39 | self.bn1 = SynchronizedBatchNorm2d(planes) 40 | self.relu = nn.ReLU(inplace=True) 41 | self.conv2 = conv3x3(planes, planes) 42 | self.bn2 = SynchronizedBatchNorm2d(planes) 43 | self.downsample = downsample 44 | self.stride = stride 45 | 46 | def forward(self, x): 47 | residual = x 48 | 49 | out = self.conv1(x) 50 | out = self.bn1(out) 51 | out = self.relu(out) 52 | 53 | out = self.conv2(out) 54 | out = self.bn2(out) 55 | 56 | if self.downsample is not None: 57 | residual = self.downsample(x) 58 | 59 | out += residual 60 | out = self.relu(out) 61 | 62 | return out 63 | 64 | 65 | class Bottleneck(nn.Module): 66 | expansion = 4 67 | 68 | def __init__(self, inplanes, planes, stride=1, downsample=None): 69 | super(Bottleneck, self).__init__() 70 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 71 | self.bn1 = SynchronizedBatchNorm2d(planes) 72 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 73 | padding=1, bias=False) 74 | self.bn2 = SynchronizedBatchNorm2d(planes) 75 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 76 | self.bn3 = SynchronizedBatchNorm2d(planes * 4) 77 | self.relu = nn.ReLU(inplace=True) 78 | self.downsample = downsample 79 | self.stride = stride 80 | 81 | def forward(self, x): 82 | residual = x 83 | 84 | out = self.conv1(x) 85 | out = self.bn1(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv2(out) 89 | out = self.bn2(out) 90 | out = self.relu(out) 91 | 92 | out = self.conv3(out) 93 | out = self.bn3(out) 94 | 95 | if self.downsample is not None: 96 | residual = 
self.downsample(x) 97 | 98 | out += residual 99 | out = self.relu(out) 100 | 101 | return out 102 | 103 | 104 | class ResNet(nn.Module): 105 | 106 | def __init__(self, block, layers, num_classes=1000): 107 | self.inplanes = 128 108 | super(ResNet, self).__init__() 109 | self.conv1 = conv3x3(3, 64, stride=2) 110 | self.bn1 = SynchronizedBatchNorm2d(64) 111 | self.relu1 = nn.ReLU(inplace=True) 112 | self.conv2 = conv3x3(64, 64) 113 | self.bn2 = SynchronizedBatchNorm2d(64) 114 | self.relu2 = nn.ReLU(inplace=True) 115 | self.conv3 = conv3x3(64, 128) 116 | self.bn3 = SynchronizedBatchNorm2d(128) 117 | self.relu3 = nn.ReLU(inplace=True) 118 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 119 | 120 | self.layer1 = self._make_layer(block, 64, layers[0]) 121 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 122 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 123 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 124 | self.avgpool = nn.AvgPool2d(7, stride=1) 125 | self.fc = nn.Linear(512 * block.expansion, num_classes) 126 | 127 | for m in self.modules(): 128 | if isinstance(m, nn.Conv2d): 129 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 130 | m.weight.data.normal_(0, math.sqrt(2. / n)) 131 | elif isinstance(m, SynchronizedBatchNorm2d): 132 | m.weight.data.fill_(1) 133 | m.bias.data.zero_() 134 | 135 | def _make_layer(self, block, planes, blocks, stride=1): 136 | downsample = None 137 | if stride != 1 or self.inplanes != planes * block.expansion: 138 | downsample = nn.Sequential( 139 | nn.Conv2d(self.inplanes, planes * block.expansion, 140 | kernel_size=1, stride=stride, bias=False), 141 | SynchronizedBatchNorm2d(planes * block.expansion), 142 | ) 143 | 144 | layers = [] 145 | layers.append(block(self.inplanes, planes, stride, downsample)) 146 | self.inplanes = planes * block.expansion 147 | for i in range(1, blocks): 148 | layers.append(block(self.inplanes, planes)) 149 | 150 | return nn.Sequential(*layers) 151 | 152 | def forward(self, x): 153 | x = self.relu1(self.bn1(self.conv1(x))) 154 | x = self.relu2(self.bn2(self.conv2(x))) 155 | x = self.relu3(self.bn3(self.conv3(x))) 156 | x = self.maxpool(x) 157 | 158 | x = self.layer1(x) 159 | x = self.layer2(x) 160 | x = self.layer3(x) 161 | x = self.layer4(x) 162 | 163 | x = self.avgpool(x) 164 | x = x.view(x.size(0), -1) 165 | x = self.fc(x) 166 | 167 | return x 168 | 169 | def resnet18(pretrained=False, **kwargs): 170 | """Constructs a ResNet-18 model. 171 | 172 | Args: 173 | pretrained (bool): If True, returns a model pre-trained on ImageNet 174 | """ 175 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 176 | if pretrained: 177 | model.load_state_dict(load_url(model_urls['resnet18'])) 178 | return model 179 | 180 | ''' 181 | def resnet34(pretrained=False, **kwargs): 182 | """Constructs a ResNet-34 model. 183 | 184 | Args: 185 | pretrained (bool): If True, returns a model pre-trained on ImageNet 186 | """ 187 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 188 | if pretrained: 189 | model.load_state_dict(load_url(model_urls['resnet34'])) 190 | return model 191 | ''' 192 | 193 | def resnet50(pretrained=False, **kwargs): 194 | """Constructs a ResNet-50 model. 
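    (Usage sketch: model = resnet50(pretrained=True) downloads resnet50-imagenet.pth
    into ./pretrained through load_url on first use; assumes network access.)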
195 | 196 | Args: 197 | pretrained (bool): If True, returns a model pre-trained on ImageNet 198 | """ 199 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 200 | if pretrained: 201 | model.load_state_dict(load_url(model_urls['resnet50']), strict=False) 202 | return model 203 | 204 | 205 | def resnet101(pretrained=False, **kwargs): 206 | """Constructs a ResNet-101 model. 207 | 208 | Args: 209 | pretrained (bool): If True, returns a model pre-trained on ImageNet 210 | """ 211 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 212 | if pretrained: 213 | model.load_state_dict(load_url(model_urls['resnet101']), strict=False) 214 | return model 215 | 216 | # def resnet152(pretrained=False, **kwargs): 217 | # """Constructs a ResNet-152 model. 218 | # 219 | # Args: 220 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 221 | # """ 222 | # model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 223 | # if pretrained: 224 | # model.load_state_dict(load_url(model_urls['resnet152'])) 225 | # return model 226 | 227 | def load_url(url, model_dir='./pretrained', map_location=None): 228 | if not os.path.exists(model_dir): 229 | os.makedirs(model_dir) 230 | filename = url.split('/')[-1] 231 | cached_file = os.path.join(model_dir, filename) 232 | if not os.path.exists(cached_file): 233 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 234 | urlretrieve(url, cached_file) 235 | return torch.load(cached_file, map_location=map_location) 236 | -------------------------------------------------------------------------------- /netdissect/segmodel/resnext.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | try: 7 | from lib.nn import SynchronizedBatchNorm2d 8 | except ImportError: 9 | from torch.nn import BatchNorm2d as SynchronizedBatchNorm2d 10 | 11 | try: 12 | from urllib import urlretrieve 13 | except ImportError: 14 | from urllib.request import urlretrieve 15 | 16 | 17 | __all__ = ['ResNeXt', 'resnext101'] # support resnext 101 18 | 19 | 20 | model_urls = { 21 | #'resnext50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext50-imagenet.pth', 22 | 'resnext101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext101-imagenet.pth' 23 | } 24 | 25 | 26 | def conv3x3(in_planes, out_planes, stride=1): 27 | "3x3 convolution with padding" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=1, bias=False) 30 | 31 | 32 | class GroupBottleneck(nn.Module): 33 | expansion = 2 34 | 35 | def __init__(self, inplanes, planes, stride=1, groups=1, downsample=None): 36 | super(GroupBottleneck, self).__init__() 37 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 38 | self.bn1 = SynchronizedBatchNorm2d(planes) 39 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 40 | padding=1, groups=groups, bias=False) 41 | self.bn2 = SynchronizedBatchNorm2d(planes) 42 | self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=False) 43 | self.bn3 = SynchronizedBatchNorm2d(planes * 2) 44 | self.relu = nn.ReLU(inplace=True) 45 | self.downsample = downsample 46 | self.stride = stride 47 | 48 | def forward(self, x): 49 | residual = x 50 | 51 | out = self.conv1(x) 52 | out = self.bn1(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv2(out) 56 | out = self.bn2(out) 57 | out = self.relu(out) 58 | 59 | out = self.conv3(out) 60 | out = self.bn3(out) 61 | 62 | if 
self.downsample is not None: 63 | residual = self.downsample(x) 64 | 65 | out += residual 66 | out = self.relu(out) 67 | 68 | return out 69 | 70 | 71 | class ResNeXt(nn.Module): 72 | 73 | def __init__(self, block, layers, groups=32, num_classes=1000): 74 | self.inplanes = 128 75 | super(ResNeXt, self).__init__() 76 | self.conv1 = conv3x3(3, 64, stride=2) 77 | self.bn1 = SynchronizedBatchNorm2d(64) 78 | self.relu1 = nn.ReLU(inplace=True) 79 | self.conv2 = conv3x3(64, 64) 80 | self.bn2 = SynchronizedBatchNorm2d(64) 81 | self.relu2 = nn.ReLU(inplace=True) 82 | self.conv3 = conv3x3(64, 128) 83 | self.bn3 = SynchronizedBatchNorm2d(128) 84 | self.relu3 = nn.ReLU(inplace=True) 85 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 86 | 87 | self.layer1 = self._make_layer(block, 128, layers[0], groups=groups) 88 | self.layer2 = self._make_layer(block, 256, layers[1], stride=2, groups=groups) 89 | self.layer3 = self._make_layer(block, 512, layers[2], stride=2, groups=groups) 90 | self.layer4 = self._make_layer(block, 1024, layers[3], stride=2, groups=groups) 91 | self.avgpool = nn.AvgPool2d(7, stride=1) 92 | self.fc = nn.Linear(1024 * block.expansion, num_classes) 93 | 94 | for m in self.modules(): 95 | if isinstance(m, nn.Conv2d): 96 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups 97 | m.weight.data.normal_(0, math.sqrt(2. / n)) 98 | elif isinstance(m, SynchronizedBatchNorm2d): 99 | m.weight.data.fill_(1) 100 | m.bias.data.zero_() 101 | 102 | def _make_layer(self, block, planes, blocks, stride=1, groups=1): 103 | downsample = None 104 | if stride != 1 or self.inplanes != planes * block.expansion: 105 | downsample = nn.Sequential( 106 | nn.Conv2d(self.inplanes, planes * block.expansion, 107 | kernel_size=1, stride=stride, bias=False), 108 | SynchronizedBatchNorm2d(planes * block.expansion), 109 | ) 110 | 111 | layers = [] 112 | layers.append(block(self.inplanes, planes, stride, groups, downsample)) 113 | self.inplanes = planes * block.expansion 114 | for i in range(1, blocks): 115 | layers.append(block(self.inplanes, planes, groups=groups)) 116 | 117 | return nn.Sequential(*layers) 118 | 119 | def forward(self, x): 120 | x = self.relu1(self.bn1(self.conv1(x))) 121 | x = self.relu2(self.bn2(self.conv2(x))) 122 | x = self.relu3(self.bn3(self.conv3(x))) 123 | x = self.maxpool(x) 124 | 125 | x = self.layer1(x) 126 | x = self.layer2(x) 127 | x = self.layer3(x) 128 | x = self.layer4(x) 129 | 130 | x = self.avgpool(x) 131 | x = x.view(x.size(0), -1) 132 | x = self.fc(x) 133 | 134 | return x 135 | 136 | 137 | ''' 138 | def resnext50(pretrained=False, **kwargs): 139 | """Constructs a ResNet-50 model. 140 | 141 | Args: 142 | pretrained (bool): If True, returns a model pre-trained on Places 143 | """ 144 | model = ResNeXt(GroupBottleneck, [3, 4, 6, 3], **kwargs) 145 | if pretrained: 146 | model.load_state_dict(load_url(model_urls['resnext50']), strict=False) 147 | return model 148 | ''' 149 | 150 | 151 | def resnext101(pretrained=False, **kwargs): 152 | """Constructs a ResNet-101 model. 153 | 154 | Args: 155 | pretrained (bool): If True, returns a model pre-trained on Places 156 | """ 157 | model = ResNeXt(GroupBottleneck, [3, 4, 23, 3], **kwargs) 158 | if pretrained: 159 | model.load_state_dict(load_url(model_urls['resnext101']), strict=False) 160 | return model 161 | 162 | 163 | # def resnext152(pretrained=False, **kwargs): 164 | # """Constructs a ResNeXt-152 model. 
165 | # 166 | # Args: 167 | # pretrained (bool): If True, returns a model pre-trained on Places 168 | # """ 169 | # model = ResNeXt(GroupBottleneck, [3, 8, 36, 3], **kwargs) 170 | # if pretrained: 171 | # model.load_state_dict(load_url(model_urls['resnext152'])) 172 | # return model 173 | 174 | 175 | def load_url(url, model_dir='./pretrained', map_location=None): 176 | if not os.path.exists(model_dir): 177 | os.makedirs(model_dir) 178 | filename = url.split('/')[-1] 179 | cached_file = os.path.join(model_dir, filename) 180 | if not os.path.exists(cached_file): 181 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 182 | urlretrieve(url, cached_file) 183 | return torch.load(cached_file, map_location=map_location) 184 | -------------------------------------------------------------------------------- /netdissect/show.py: -------------------------------------------------------------------------------- 1 | # show.py 2 | # 3 | # An abbreviated way to output simple HTML layout of text and images 4 | # into a python notebook. 5 | # 6 | # - show a PIL image to show an inline HTML . 7 | # - show an array of items to vertically stack them, centered in a block. 8 | # - show an array of arrays to horizontally lay them out as inline blocks. 9 | # - show an array of tuples to create a table. 10 | 11 | import PIL.Image 12 | import base64 13 | import io 14 | import IPython 15 | import types 16 | import sys 17 | import html as html_module 18 | from IPython.display import display 19 | 20 | g_buffer = None 21 | 22 | 23 | def blocks(obj, space=''): 24 | return IPython.display.HTML(space.join(blocks_tags(obj))) 25 | 26 | 27 | def rows(obj, space=''): 28 | return IPython.display.HTML(space.join(rows_tags(obj))) 29 | 30 | 31 | def rows_tags(obj): 32 | if isinstance(obj, dict): 33 | obj = obj.items() 34 | results = [] 35 | results.append('') 36 | for row in obj: 37 | results.append('') 38 | for item in row: 39 | results.append('') 43 | results.append('') 44 | results.append('
') 41 | results.extend(blocks_tags(item)) 42 | results.append('
') 45 | return results 46 | 47 | 48 | def blocks_tags(obj): 49 | results = [] 50 | if hasattr(obj, '_repr_html_'): 51 | results.append(obj._repr_html_()) 52 | elif isinstance(obj, PIL.Image.Image): 53 | results.append(pil_to_html(obj)) 54 | elif isinstance(obj, (str, int, float)): 55 | results.append('
') 56 | results.append(html_module.escape(str(obj))) 57 | results.append('
') 58 | elif isinstance(obj, dict): 59 | results.extend(blocks_tags([(k, v) for k, v in obj.items()])) 60 | elif hasattr(obj, '__iter__'): 61 | if hasattr(obj, 'tolist'): 62 | # Handle numpy/pytorch tensors as lists. 63 | try: 64 | obj = obj.tolist() 65 | except: 66 | pass 67 | blockstart, blockend, tstart, tend, rstart, rend, cstart, cend = [ 68 | '
', 70 | '
', 71 | '', 72 | '
', 73 | '', 74 | '', 75 | '', 76 | '', 77 | ] 78 | needs_end = False 79 | table_mode = False 80 | for i, line in enumerate(obj): 81 | if i == 0: 82 | needs_end = True 83 | if isinstance(line, tuple): 84 | table_mode = True 85 | results.append(tstart) 86 | else: 87 | results.append(blockstart) 88 | if table_mode: 89 | results.append(rstart) 90 | if not isinstance(line, str) and hasattr(line, '__iter__'): 91 | for cell in line: 92 | results.append(cstart) 93 | results.extend(blocks_tags(cell)) 94 | results.append(cend) 95 | else: 96 | results.append(cstart) 97 | results.extend(blocks_tags(line)) 98 | results.append(cend) 99 | results.append(rend) 100 | else: 101 | results.extend(blocks_tags(line)) 102 | if needs_end: 103 | results.append(table_mode and tend or blockend) 104 | return results 105 | 106 | 107 | def pil_to_b64(img, format='png'): 108 | buffered = io.BytesIO() 109 | img.save(buffered, format=format) 110 | return base64.b64encode(buffered.getvalue()).decode('utf-8') 111 | 112 | 113 | def pil_to_url(img, format='png'): 114 | return 'data:image/%s;base64,%s' % (format, pil_to_b64(img, format)) 115 | 116 | 117 | def pil_to_html(img, margin=1): 118 | mattr = ' style="margin:%dpx"' % margin 119 | return '' % (pil_to_url(img), mattr) 120 | 121 | 122 | def a(x, cols=None): 123 | global g_buffer 124 | if g_buffer is None: 125 | g_buffer = [] 126 | g_buffer.append(x) 127 | if cols is not None and len(g_buffer) >= cols: 128 | flush() 129 | 130 | 131 | def reset(): 132 | global g_buffer 133 | g_buffer = None 134 | 135 | 136 | def flush(*args, **kwargs): 137 | global g_buffer 138 | if g_buffer is not None: 139 | x = g_buffer 140 | g_buffer = None 141 | display(blocks(x, *args, **kwargs)) 142 | 143 | 144 | def show(x=None, *args, **kwargs): 145 | flush(*args, **kwargs) 146 | if x is not None: 147 | display(blocks(x, *args, **kwargs)) 148 | 149 | 150 | def html(obj, space=''): 151 | return blocks(obj, space)._repr_html_() 152 | 153 | 154 | class CallableModule(types.ModuleType): 155 | def __init__(self): 156 | # or super().__init__(__name__) for Python 3 157 | types.ModuleType.__init__(self, __name__) 158 | self.__dict__.update(sys.modules[__name__].__dict__) 159 | 160 | def __call__(self, x=None, *args, **kwargs): 161 | show(x, *args, **kwargs) 162 | 163 | 164 | sys.modules[__name__] = CallableModule() 165 | -------------------------------------------------------------------------------- /netdissect/upsample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | 4 | 5 | def upsampler(target_shape, data_shape=None, 6 | image_size=None, scale_offset=None, 7 | source=None, convolutions=None, dtype=torch.float, device=None): 8 | ''' 9 | Returns a function that will upsample a batch of torch data from the 10 | expected data_shape to the specified target_shape. Can use scale_offset 11 | and image_size to center the grid in a nondefault way: scale_offset 12 | maps feature pixels to image_size pixels, and it is assumed that 13 | the target_shape is a uniform downsampling of image_size. 
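    A minimal usage sketch (shapes are illustrative):

        up = upsampler(target_shape=(56, 56), data_shape=(7, 7))
        heat = up(feats)   # feats: float tensor of shape [batch, channels, 7, 7]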
14 | ''' 15 | if source is not None: 16 | assert image_size is None 17 | image_size = image_size_from_source(source) 18 | if convolutions is not None: 19 | assert scale_offset is None 20 | scale_offset = sequence_scale_offset(convolutions) 21 | if image_size is not None and data_shape is None: 22 | data_shape = sequence_data_size(convolutions, image_size) 23 | assert data_shape is not None 24 | assert len(data_shape) == 2 25 | grid = upsample_grid(data_shape, target_shape, image_size, scale_offset, 26 | dtype, device) 27 | batch_grid = grid 28 | # padding mode could be 'border' 29 | 30 | def upsample_func(data, mode='bilinear', padding_mode='zeros'): 31 | nonlocal grid, batch_grid 32 | # Use the same grid over the whole batch 33 | if batch_grid.shape[0] != data.shape[0]: 34 | batch_grid = grid.expand((data.shape[0],) + grid.shape[1:]) 35 | if batch_grid.device != data.device: 36 | batch_grid = batch_grid.to(data.device) 37 | try: 38 | return torch.nn.functional.grid_sample(data, batch_grid, mode=mode, 39 | padding_mode=padding_mode, align_corners=True) 40 | except: 41 | return torch.nn.functional.grid_sample(data, batch_grid, mode=mode, 42 | padding_mode=padding_mode) # older pytorch version 43 | return upsample_func 44 | 45 | 46 | def sequence_scale_offset(modulelist): 47 | '''Returns (yscale, yoffset), (xscale, xoffset) given a list of modules. 48 | To convert output coordinates back to input coordinates while preserving 49 | centers of receptive fields, the affine transformation is: 50 | inpx = outx * xscale + xoffset 51 | inpy = outy * yscale + yoffset 52 | In both coordinate systems, (0, 0) refers to the upper-left corner 53 | of the first pixel, (0.5, 0.5) refers to the center of that pixel, 54 | and (1, 1) refers to the lower-right corner of that same pixel. 55 | 56 | Modern convnets tend to add padding to keep receptive fields centered 57 | while scaling, which will result in zero offsets. For example, after resnet 58 | does five stride-2 reductions, the scale_offset is just ((32, 0), (32, 0)). 59 | However, AlexNet does not pad every layer, and after five stride-2 60 | reductions, the scale_offset is ((32, 31), (32, 31)). 61 | ''' 62 | return tuple(convconfig_scale_offset(d) for d in convconfigs(modulelist)) 63 | 64 | 65 | def sequence_data_size(modulelist, input_size): 66 | '''Returns (yscale, yoffset), (xscale, xoffset) given a list of modules. 67 | To convert output coordinates back to input coordinates while preserving 68 | centers of receptive fields, the affine transformation is: 69 | inpx = outx * xscale + xoffset 70 | inpy = outy * yscale + yoffset 71 | In both coordinate systems, (0, 0) refers to the upper-left corner 72 | of the first pixel, (0.5, 0.5) refers to the center of that pixel, 73 | and (1, 1) refers to the lower-right corner of that same pixel. 74 | 75 | Modern convnets tend to add padding to keep receptive fields centered 76 | while scaling, which will result in zero offsets. For example, after resnet 77 | does five stride-2 reductions, the scale_offset is just ((32, 0), (32, 0)). 78 | However, AlexNet does not pad every layer, and after five stride-2 79 | reductions, the scale_offset is ((32, 31), (32, 31)). 80 | ''' 81 | return tuple(convconfig_data_size(d, s) 82 | for d, s in zip(convconfigs(modulelist), input_size)) 83 | 84 | 85 | def convconfig_scale_offset(convconfigs): 86 | '''Composes a lists of [(k, d, s, p)...] into a single total scale and 87 | offset that returns to the input coordinates. 
88 | ''' 89 | if len(convconfigs) == 0: 90 | return (1, 0) 91 | scale, offset = convconfig_scale_offset(convconfigs[1:]) 92 | kernel, dilation, stride, padding = convconfigs[0] 93 | scale *= stride 94 | offset *= stride 95 | offset += (kernel - 1) * dilation / 2.0 - padding 96 | return scale, offset 97 | 98 | 99 | def convconfig_data_size(convconfigs, data_size): 100 | '''Applies a list of [(k, d, s, p)...] to the given input size to obtain 101 | an output size. 102 | ''' 103 | for kernel, dilation, stride, padding in convconfigs: 104 | data_size = (1 + (data_size + 2 * padding 105 | - dilation * (kernel - 1) - 1) // stride) 106 | return data_size 107 | 108 | 109 | def convconfigs(modulelist): 110 | '''Converts a list of modules to a pair of lists of 111 | [(kernel_size, dilation, stride, padding)...]: one for x, and one for y.''' 112 | result = [] 113 | for module in modulelist: 114 | settings = tuple(getattr(module, n, d) 115 | for n, d in (('kernel_size', 1), 116 | ('dilation', 1), ('stride', 1), ('padding', 0))) 117 | settings = tuple((s if isinstance(s, tuple) else (s, s)) 118 | for s in settings) 119 | if settings != ((1, 1), (1, 1), (1, 1), (0, 0)): 120 | result.append(zip(*settings)) 121 | return list(zip(*result)) 122 | 123 | 124 | def upsample_grid(data_shape, target_shape, image_size=None, 125 | scale_offset=None, dtype=torch.float, device=None): 126 | '''Prepares a grid to use with grid_sample to upsample a batch of 127 | features in data_shape to the target_shape. Can use scale_offset 128 | and image_size to center the grid in a nondefault way: scale_offset 129 | maps feature pixels to image_size pixels, and it is assumed that 130 | the target_shape is a uniform downsampling of image_size.''' 131 | # Default is that nothing is resized. 132 | if target_shape is None: 133 | target_shape = data_shape 134 | # Make a default scale_offset to fill the image if there isn't one 135 | if scale_offset is None: 136 | scale = tuple(float(ts) / ds 137 | for ts, ds in zip(target_shape, data_shape)) 138 | offset = tuple(0.5 * s - 0.5 for s in scale) 139 | else: 140 | scale, offset = (v for v in zip(*scale_offset)) 141 | # Handle downsampling for different input vs target shape. 142 | if image_size is not None: 143 | scale = tuple(s * (ts - 1) / (ns - 1) 144 | for s, ns, ts in zip(scale, image_size, target_shape)) 145 | offset = tuple(o * (ts - 1) / (ns - 1) 146 | for o, ns, ts in zip(offset, image_size, target_shape)) 147 | # Pytorch needs target coordinates in terms of source coordinates [-1..1] 148 | ty, tx = (((torch.arange(ts, dtype=dtype, device=device) - o) 149 | * (2 / (s * max(1, (ss - 1)))) - 1) 150 | for ts, ss, s, o, in zip(target_shape, data_shape, scale, offset)) 151 | # Whoa, note that grid_sample reverses the order y, x -> x, y. 152 | grid = torch.stack( 153 | (tx[None, :].expand(target_shape), ty[:, None].expand(target_shape)), 2 154 | )[None, :, :, :].expand((1, target_shape[0], target_shape[1], 2)) 155 | return grid 156 | 157 | 158 | def image_size_from_source(source): 159 | sizer = find_sizer(source) 160 | if sizer is not None: 161 | size = sizer.size 162 | elif hasattr(source, 'resolution'): 163 | size = source.resolution 164 | if hasattr(size, '__len__'): 165 | return size 166 | return (size, size) 167 | 168 | 169 | def find_sizer(source): 170 | ''' 171 | Crawl around the transforms attached to a dataset looking for 172 | the last crop or resize transform to return. 
173 | ''' 174 | if source is None: 175 | return None 176 | if isinstance(source, (transforms.Resize, transforms.RandomCrop, 177 | transforms.RandomResizedCrop, transforms.CenterCrop)): 178 | return source 179 | t = getattr(source, 'transform', None) 180 | if t is not None: 181 | return find_sizer(t) 182 | ts = getattr(source, 'transforms', None) 183 | if ts is not None: 184 | for t in reversed(ts): 185 | result = find_sizer(t) 186 | if result is not None: 187 | return result 188 | return None 189 | -------------------------------------------------------------------------------- /netdissect/upsegmodel/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import ModelBuilder, SegmentationModule 2 | -------------------------------------------------------------------------------- /netdissect/upsegmodel/prroi_pool/.gitignore: -------------------------------------------------------------------------------- 1 | *.o 2 | /_prroi_pooling 3 | -------------------------------------------------------------------------------- /netdissect/upsegmodel/prroi_pool/README.md: -------------------------------------------------------------------------------- 1 | # PreciseRoIPooling 2 | This repo implements the **Precise RoI Pooling** (PrRoI Pooling), proposed in the paper **Acquisition of Localization Confidence for Accurate Object Detection** published at ECCV 2018 (Oral Presentation). 3 | 4 | **Acquisition of Localization Confidence for Accurate Object Detection** 5 | 6 | _Borui Jiang*, Ruixuan Luo*, Jiayuan Mao*, Tete Xiao, Yuning Jiang_ (* indicates equal contribution.) 7 | 8 | https://arxiv.org/abs/1807.11590 9 | 10 | ## Brief 11 | 12 | In short, Precise RoI Pooling is an integration-based (bilinear interpolation) average pooling method for RoI Pooling. It avoids any quantization and has a continuous gradient on bounding box coordinates. It is: 13 | 14 | - different from the original RoI Pooling proposed in [Fast R-CNN](https://arxiv.org/abs/1504.08083). PrRoI Pooling uses average pooling instead of max pooling for each bin and has a continuous gradient on bounding box coordinates. That is, one can take the derivatives of some loss function w.r.t the coordinates of each RoI and optimize the RoI coordinates. 15 | - different from the RoI Align proposed in [Mask R-CNN](https://arxiv.org/abs/1703.06870). PrRoI Pooling uses a full integration-based average pooling instead of sampling a constant number of points. This makes the gradient w.r.t. the coordinates continuous. 16 | 17 | For a better illustration, we illustrate RoI Pooling, RoI Align and PrRoI Pooing in the following figure. More details including the gradient computation can be found in our paper. 18 | 19 |
20 | 21 | ## Implementation 22 | 23 | PrRoI Pooling was originally implemented by [Tete Xiao](http://tetexiao.com/) based on MegBrain, an (internal) deep learning framework built by Megvii Inc. It was later adapted into open-source deep learning frameworks. Currently, we only support PyTorch. Unfortunately, we don't have any specific plan for the adaptation into other frameworks such as TensorFlow, but any contributions (pull requests) will be more than welcome. 24 | 25 | ## Usage (PyTorch 1.0) 26 | 27 | In the directory `pytorch/`, we provide a PyTorch-based implementation of PrRoI Pooling. It requires PyTorch 1.0+ and only supports CUDA (CPU mode is not implemented). 28 | Since we use PyTorch JIT for cxx/cuda code compilation, to use the module in your code, simply do: 29 | 30 | ``` 31 | from prroi_pool import PrRoIPool2D 32 | 33 | avg_pool = PrRoIPool2D(window_height, window_width, spatial_scale) 34 | roi_features = avg_pool(features, rois) 35 | 36 | # for those who want to use the "functional" 37 | 38 | from prroi_pool.functional import prroi_pool2d 39 | roi_features = prroi_pool2d(features, rois, window_height, window_width, spatial_scale) 40 | ``` 41 | 42 | 43 | ## Usage (PyTorch 0.4) 44 | 45 | **!!! Please first checkout to the branch pytorch0.4.** 46 | 47 | In the directory `pytorch/`, we provide a PyTorch-based implementation of PrRoI Pooling. It requires PyTorch 0.4 and only supports CUDA (CPU mode is not implemented). 48 | To use the PrRoI Pooling module, first goto `pytorch/prroi_pool` and execute `./travis.sh` to compile the essential components (you may need `nvcc` for this step). To use the module in your code, simply do: 49 | 50 | ``` 51 | from prroi_pool import PrRoIPool2D 52 | 53 | avg_pool = PrRoIPool2D(window_height, window_width, spatial_scale) 54 | roi_features = avg_pool(features, rois) 55 | 56 | # for those who want to use the "functional" 57 | 58 | from prroi_pool.functional import prroi_pool2d 59 | roi_features = prroi_pool2d(features, rois, window_height, window_width, spatial_scale) 60 | ``` 61 | 62 | Here, 63 | 64 | - RoI is an `m * 5` float tensor of format `(batch_index, x0, y0, x1, y1)`, following the convention in the original Caffe implementation of RoI Pooling, although in some frameworks the batch indices are provided by an integer tensor. 65 | - `spatial_scale` is multiplied to the RoIs. For example, if your feature maps are down-sampled by a factor of 16 (w.r.t. the input image), you should use a spatial scale of `1/16`. 66 | - The coordinates for RoI follows the [L, R) convension. That is, `(0, 0, 4, 4)` denotes a box of size `4x4`. 67 | -------------------------------------------------------------------------------- /netdissect/upsegmodel/prroi_pool/__init__.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : __init__.py 4 | # Author : Jiayuan Mao, Tete Xiao 5 | # Email : maojiayuan@gmail.com, jasonhsiao97@gmail.com 6 | # Date : 07/13/2018 7 | # 8 | # This file is part of PreciseRoIPooling. 9 | # Distributed under terms of the MIT license. 10 | # Copyright (c) 2017 Megvii Technology Limited. 11 | 12 | from .prroi_pool import * 13 | 14 | -------------------------------------------------------------------------------- /netdissect/upsegmodel/prroi_pool/build.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : build.py 4 | # Author : Jiayuan Mao, Tete Xiao 5 | # Email : maojiayuan@gmail.com, jasonhsiao97@gmail.com 6 | # Date : 07/13/2018 7 | # 8 | # This file is part of PreciseRoIPooling. 9 | # Distributed under terms of the MIT license. 10 | # Copyright (c) 2017 Megvii Technology Limited. 11 | 12 | import os 13 | import torch 14 | 15 | from torch.utils.ffi import create_extension 16 | 17 | headers = [] 18 | sources = [] 19 | defines = [] 20 | extra_objects = [] 21 | with_cuda = False 22 | 23 | if torch.cuda.is_available(): 24 | with_cuda = True 25 | 26 | headers+= ['src/prroi_pooling_gpu.h'] 27 | sources += ['src/prroi_pooling_gpu.c'] 28 | defines += [('WITH_CUDA', None)] 29 | 30 | this_file = os.path.dirname(os.path.realpath(__file__)) 31 | extra_objects_cuda = ['src/prroi_pooling_gpu_impl.cu.o'] 32 | extra_objects_cuda = [os.path.join(this_file, fname) for fname in extra_objects_cuda] 33 | extra_objects.extend(extra_objects_cuda) 34 | else: 35 | # TODO(Jiayuan Mao @ 07/13): remove this restriction after we support the cpu implementation. 36 | raise NotImplementedError('Precise RoI Pooling only supports GPU (cuda) implememtations.') 37 | 38 | ffi = create_extension( 39 | '_prroi_pooling', 40 | headers=headers, 41 | sources=sources, 42 | define_macros=defines, 43 | relative_to=__file__, 44 | with_cuda=with_cuda, 45 | extra_objects=extra_objects 46 | ) 47 | 48 | if __name__ == '__main__': 49 | ffi.build() 50 | 51 | -------------------------------------------------------------------------------- /netdissect/upsegmodel/prroi_pool/functional.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : functional.py 4 | # Author : Jiayuan Mao, Tete Xiao 5 | # Email : maojiayuan@gmail.com, jasonhsiao97@gmail.com 6 | # Date : 07/13/2018 7 | # 8 | # This file is part of PreciseRoIPooling. 9 | # Distributed under terms of the MIT license. 10 | # Copyright (c) 2017 Megvii Technology Limited. 11 | 12 | import torch 13 | import torch.autograd as ag 14 | 15 | try: 16 | from os.path import join as pjoin, dirname 17 | from torch.utils.cpp_extension import load as load_extension 18 | root_dir = pjoin(dirname(__file__), 'src') 19 | _prroi_pooling = load_extension( 20 | '_prroi_pooling', 21 | [pjoin(root_dir, 'prroi_pooling_gpu.c'), pjoin(root_dir, 'prroi_pooling_gpu_impl.cu')], 22 | verbose=False 23 | ) 24 | except ImportError: 25 | raise ImportError('Can not compile Precise RoI Pooling library.') 26 | 27 | __all__ = ['prroi_pool2d'] 28 | 29 | 30 | class PrRoIPool2DFunction(ag.Function): 31 | @staticmethod 32 | def forward(ctx, features, rois, pooled_height, pooled_width, spatial_scale): 33 | assert 'FloatTensor' in features.type() and 'FloatTensor' in rois.type(), \ 34 | 'Precise RoI Pooling only takes float input, got {} for features and {} for rois.'.format(features.type(), rois.type()) 35 | 36 | pooled_height = int(pooled_height) 37 | pooled_width = int(pooled_width) 38 | spatial_scale = float(spatial_scale) 39 | 40 | features = features.contiguous() 41 | rois = rois.contiguous() 42 | params = (pooled_height, pooled_width, spatial_scale) 43 | 44 | if features.is_cuda: 45 | output = _prroi_pooling.prroi_pooling_forward_cuda(features, rois, *params) 46 | ctx.params = params 47 | # everything here is contiguous. 
48 | ctx.save_for_backward(features, rois, output) 49 | else: 50 | raise NotImplementedError('Precise RoI Pooling only supports GPU (cuda) implememtations.') 51 | 52 | return output 53 | 54 | @staticmethod 55 | def backward(ctx, grad_output): 56 | features, rois, output = ctx.saved_tensors 57 | grad_input = grad_coor = None 58 | 59 | if features.requires_grad: 60 | grad_output = grad_output.contiguous() 61 | grad_input = _prroi_pooling.prroi_pooling_backward_cuda(features, rois, output, grad_output, *ctx.params) 62 | if rois.requires_grad: 63 | grad_output = grad_output.contiguous() 64 | grad_coor = _prroi_pooling.prroi_pooling_coor_backward_cuda(features, rois, output, grad_output, *ctx.params) 65 | 66 | return grad_input, grad_coor, None, None, None 67 | 68 | 69 | prroi_pool2d = PrRoIPool2DFunction.apply 70 | 71 | -------------------------------------------------------------------------------- /netdissect/upsegmodel/prroi_pool/prroi_pool.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : prroi_pool.py 4 | # Author : Jiayuan Mao, Tete Xiao 5 | # Email : maojiayuan@gmail.com, jasonhsiao97@gmail.com 6 | # Date : 07/13/2018 7 | # 8 | # This file is part of PreciseRoIPooling. 9 | # Distributed under terms of the MIT license. 10 | # Copyright (c) 2017 Megvii Technology Limited. 11 | 12 | import torch.nn as nn 13 | 14 | from .functional import prroi_pool2d 15 | 16 | __all__ = ['PrRoIPool2D'] 17 | 18 | 19 | class PrRoIPool2D(nn.Module): 20 | def __init__(self, pooled_height, pooled_width, spatial_scale): 21 | super().__init__() 22 | 23 | self.pooled_height = int(pooled_height) 24 | self.pooled_width = int(pooled_width) 25 | self.spatial_scale = float(spatial_scale) 26 | 27 | def forward(self, features, rois): 28 | return prroi_pool2d(features, rois, self.pooled_height, self.pooled_width, self.spatial_scale) 29 | -------------------------------------------------------------------------------- /netdissect/upsegmodel/prroi_pool/src/prroi_pooling_gpu.c: -------------------------------------------------------------------------------- 1 | /* 2 | * File : prroi_pooling_gpu.c 3 | * Author : Jiayuan Mao, Tete Xiao 4 | * Email : maojiayuan@gmail.com, jasonhsiao97@gmail.com 5 | * Date : 07/13/2018 6 | * 7 | * Distributed under terms of the MIT license. 8 | * Copyright (c) 2017 Megvii Technology Limited. 
9 | */ 10 | 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | #include 18 | 19 | #include "prroi_pooling_gpu_impl.cuh" 20 | 21 | 22 | at::Tensor prroi_pooling_forward_cuda(const at::Tensor &features, const at::Tensor &rois, int pooled_height, int pooled_width, float spatial_scale) { 23 | int nr_rois = rois.size(0); 24 | int nr_channels = features.size(1); 25 | int height = features.size(2); 26 | int width = features.size(3); 27 | int top_count = nr_rois * nr_channels * pooled_height * pooled_width; 28 | auto output = at::zeros({nr_rois, nr_channels, pooled_height, pooled_width}, features.options()); 29 | 30 | if (output.numel() == 0) { 31 | THCudaCheck(cudaGetLastError()); 32 | return output; 33 | } 34 | 35 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 36 | PrRoIPoolingForwardGpu( 37 | stream, features.data(), rois.data(), output.data(), 38 | nr_channels, height, width, pooled_height, pooled_width, spatial_scale, 39 | top_count 40 | ); 41 | 42 | THCudaCheck(cudaGetLastError()); 43 | return output; 44 | } 45 | 46 | at::Tensor prroi_pooling_backward_cuda( 47 | const at::Tensor &features, const at::Tensor &rois, const at::Tensor &output, const at::Tensor &output_diff, 48 | int pooled_height, int pooled_width, float spatial_scale) { 49 | 50 | auto features_diff = at::zeros_like(features); 51 | 52 | int nr_rois = rois.size(0); 53 | int batch_size = features.size(0); 54 | int nr_channels = features.size(1); 55 | int height = features.size(2); 56 | int width = features.size(3); 57 | int top_count = nr_rois * nr_channels * pooled_height * pooled_width; 58 | int bottom_count = batch_size * nr_channels * height * width; 59 | 60 | if (output.numel() == 0) { 61 | THCudaCheck(cudaGetLastError()); 62 | return features_diff; 63 | } 64 | 65 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 66 | PrRoIPoolingBackwardGpu( 67 | stream, 68 | features.data(), rois.data(), output.data(), output_diff.data(), 69 | features_diff.data(), 70 | nr_channels, height, width, pooled_height, pooled_width, spatial_scale, 71 | top_count, bottom_count 72 | ); 73 | 74 | THCudaCheck(cudaGetLastError()); 75 | return features_diff; 76 | } 77 | 78 | at::Tensor prroi_pooling_coor_backward_cuda( 79 | const at::Tensor &features, const at::Tensor &rois, const at::Tensor &output, const at::Tensor &output_diff, 80 | int pooled_height, int pooled_width, float spatial_scale) { 81 | 82 | auto coor_diff = at::zeros_like(rois); 83 | 84 | int nr_rois = rois.size(0); 85 | int nr_channels = features.size(1); 86 | int height = features.size(2); 87 | int width = features.size(3); 88 | int top_count = nr_rois * nr_channels * pooled_height * pooled_width; 89 | int bottom_count = nr_rois * 5; 90 | 91 | if (output.numel() == 0) { 92 | THCudaCheck(cudaGetLastError()); 93 | return coor_diff; 94 | } 95 | 96 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 97 | PrRoIPoolingCoorBackwardGpu( 98 | stream, 99 | features.data(), rois.data(), output.data(), output_diff.data(), 100 | coor_diff.data(), 101 | nr_channels, height, width, pooled_height, pooled_width, spatial_scale, 102 | top_count, bottom_count 103 | ); 104 | 105 | THCudaCheck(cudaGetLastError()); 106 | return coor_diff; 107 | } 108 | 109 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 110 | m.def("prroi_pooling_forward_cuda", &prroi_pooling_forward_cuda, "PRRoIPooling_forward"); 111 | m.def("prroi_pooling_backward_cuda", &prroi_pooling_backward_cuda, "PRRoIPooling_backward"); 112 | m.def("prroi_pooling_coor_backward_cuda", 
&prroi_pooling_coor_backward_cuda, "PRRoIPooling_backward_coor"); 113 | } 114 | -------------------------------------------------------------------------------- /netdissect/upsegmodel/prroi_pool/src/prroi_pooling_gpu.h: -------------------------------------------------------------------------------- 1 | /* 2 | * File : prroi_pooling_gpu.h 3 | * Author : Jiayuan Mao, Tete Xiao 4 | * Email : maojiayuan@gmail.com, jasonhsiao97@gmail.com 5 | * Date : 07/13/2018 6 | * 7 | * Distributed under terms of the MIT license. 8 | * Copyright (c) 2017 Megvii Technology Limited. 9 | */ 10 | 11 | int prroi_pooling_forward_cuda(THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, int pooled_height, int pooled_width, float spatial_scale); 12 | 13 | int prroi_pooling_backward_cuda( 14 | THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, THCudaTensor *output_diff, THCudaTensor *features_diff, 15 | int pooled_height, int pooled_width, float spatial_scale 16 | ); 17 | 18 | int prroi_pooling_coor_backward_cuda( 19 | THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, THCudaTensor *output_diff, THCudaTensor *features_diff, 20 | int pooled_height, int pooled_width, float spatial_scal 21 | ); 22 | 23 | -------------------------------------------------------------------------------- /netdissect/upsegmodel/prroi_pool/src/prroi_pooling_gpu_impl.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * File : prroi_pooling_gpu_impl.cuh 3 | * Author : Tete Xiao, Jiayuan Mao 4 | * Email : jasonhsiao97@gmail.com 5 | * 6 | * Distributed under terms of the MIT license. 7 | * Copyright (c) 2017 Megvii Technology Limited. 8 | */ 9 | 10 | #ifndef PRROI_POOLING_GPU_IMPL_CUH 11 | #define PRROI_POOLING_GPU_IMPL_CUH 12 | 13 | #ifdef __cplusplus 14 | extern "C" { 15 | #endif 16 | 17 | #define F_DEVPTR_IN const float * 18 | #define F_DEVPTR_OUT float * 19 | 20 | void PrRoIPoolingForwardGpu( 21 | cudaStream_t stream, 22 | F_DEVPTR_IN bottom_data, 23 | F_DEVPTR_IN bottom_rois, 24 | F_DEVPTR_OUT top_data, 25 | const int channels_, const int height_, const int width_, 26 | const int pooled_height_, const int pooled_width_, 27 | const float spatial_scale_, 28 | const int top_count); 29 | 30 | void PrRoIPoolingBackwardGpu( 31 | cudaStream_t stream, 32 | F_DEVPTR_IN bottom_data, 33 | F_DEVPTR_IN bottom_rois, 34 | F_DEVPTR_IN top_data, 35 | F_DEVPTR_IN top_diff, 36 | F_DEVPTR_OUT bottom_diff, 37 | const int channels_, const int height_, const int width_, 38 | const int pooled_height_, const int pooled_width_, 39 | const float spatial_scale_, 40 | const int top_count, const int bottom_count); 41 | 42 | void PrRoIPoolingCoorBackwardGpu( 43 | cudaStream_t stream, 44 | F_DEVPTR_IN bottom_data, 45 | F_DEVPTR_IN bottom_rois, 46 | F_DEVPTR_IN top_data, 47 | F_DEVPTR_IN top_diff, 48 | F_DEVPTR_OUT bottom_diff, 49 | const int channels_, const int height_, const int width_, 50 | const int pooled_height_, const int pooled_width_, 51 | const float spatial_scale_, 52 | const int top_count, const int bottom_count); 53 | 54 | #ifdef __cplusplus 55 | } /* !extern "C" */ 56 | #endif 57 | 58 | #endif /* !PRROI_POOLING_GPU_IMPL_CUH */ 59 | 60 | -------------------------------------------------------------------------------- /netdissect/upsegmodel/prroi_pool/test_prroi_pooling2d.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_prroi_pooling2d.py 3 | # Author : Jiayuan Mao 4 | # Email : 
maojiayuan@gmail.com 5 | # Date : 18/02/2018 6 | # 7 | # This file is part of Jacinle. 8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from jactorch.utils.unittest import TorchTestCase 16 | 17 | from prroi_pool import PrRoIPool2D 18 | 19 | 20 | class TestPrRoIPool2D(TorchTestCase): 21 | def test_forward(self): 22 | pool = PrRoIPool2D(7, 7, spatial_scale=0.5) 23 | features = torch.rand((4, 16, 24, 32)).cuda() 24 | rois = torch.tensor([ 25 | [0, 0, 0, 14, 14], 26 | [1, 14, 14, 28, 28], 27 | ]).float().cuda() 28 | 29 | out = pool(features, rois) 30 | out_gold = F.avg_pool2d(features, kernel_size=2, stride=1) 31 | 32 | self.assertTensorClose(out, torch.stack(( 33 | out_gold[0, :, :7, :7], 34 | out_gold[1, :, 7:14, 7:14], 35 | ), dim=0)) 36 | 37 | def test_backward_shapeonly(self): 38 | pool = PrRoIPool2D(2, 2, spatial_scale=0.5) 39 | 40 | features = torch.rand((4, 2, 24, 32)).cuda() 41 | rois = torch.tensor([ 42 | [0, 0, 0, 4, 4], 43 | [1, 14, 14, 18, 18], 44 | ]).float().cuda() 45 | features.requires_grad = rois.requires_grad = True 46 | out = pool(features, rois) 47 | 48 | loss = out.sum() 49 | loss.backward() 50 | 51 | self.assertTupleEqual(features.size(), features.grad.size()) 52 | self.assertTupleEqual(rois.size(), rois.grad.size()) 53 | 54 | 55 | if __name__ == '__main__': 56 | unittest.main() 57 | -------------------------------------------------------------------------------- /netdissect/upsegmodel/resnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | try: 7 | from lib.nn import SynchronizedBatchNorm2d 8 | except ImportError: 9 | from torch.nn import BatchNorm2d as SynchronizedBatchNorm2d 10 | 11 | try: 12 | from urllib import urlretrieve 13 | except ImportError: 14 | from urllib.request import urlretrieve 15 | 16 | 17 | __all__ = ['ResNet', 'resnet50', 'resnet101'] # resnet101 is coming soon! 
18 | 19 | 20 | model_urls = { 21 | 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth', 22 | 'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth' 23 | } 24 | 25 | 26 | def conv3x3(in_planes, out_planes, stride=1): 27 | "3x3 convolution with padding" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=1, bias=False) 30 | 31 | 32 | class BasicBlock(nn.Module): 33 | expansion = 1 34 | 35 | def __init__(self, inplanes, planes, stride=1, downsample=None): 36 | super(BasicBlock, self).__init__() 37 | self.conv1 = conv3x3(inplanes, planes, stride) 38 | self.bn1 = SynchronizedBatchNorm2d(planes) 39 | self.relu = nn.ReLU(inplace=True) 40 | self.conv2 = conv3x3(planes, planes) 41 | self.bn2 = SynchronizedBatchNorm2d(planes) 42 | self.downsample = downsample 43 | self.stride = stride 44 | 45 | def forward(self, x): 46 | residual = x 47 | 48 | out = self.conv1(x) 49 | out = self.bn1(out) 50 | out = self.relu(out) 51 | 52 | out = self.conv2(out) 53 | out = self.bn2(out) 54 | 55 | if self.downsample is not None: 56 | residual = self.downsample(x) 57 | 58 | out += residual 59 | out = self.relu(out) 60 | 61 | return out 62 | 63 | 64 | class Bottleneck(nn.Module): 65 | expansion = 4 66 | 67 | def __init__(self, inplanes, planes, stride=1, downsample=None): 68 | super(Bottleneck, self).__init__() 69 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 70 | self.bn1 = SynchronizedBatchNorm2d(planes) 71 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 72 | padding=1, bias=False) 73 | self.bn2 = SynchronizedBatchNorm2d(planes) 74 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 75 | self.bn3 = SynchronizedBatchNorm2d(planes * 4) 76 | self.relu = nn.ReLU(inplace=True) 77 | self.downsample = downsample 78 | self.stride = stride 79 | 80 | def forward(self, x): 81 | residual = x 82 | 83 | out = self.conv1(x) 84 | out = self.bn1(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv2(out) 88 | out = self.bn2(out) 89 | out = self.relu(out) 90 | 91 | out = self.conv3(out) 92 | out = self.bn3(out) 93 | 94 | if self.downsample is not None: 95 | residual = self.downsample(x) 96 | 97 | out += residual 98 | out = self.relu(out) 99 | 100 | return out 101 | 102 | 103 | class ResNet(nn.Module): 104 | 105 | def __init__(self, block, layers, num_classes=1000): 106 | self.inplanes = 128 107 | super(ResNet, self).__init__() 108 | self.conv1 = conv3x3(3, 64, stride=2) 109 | self.bn1 = SynchronizedBatchNorm2d(64) 110 | self.relu1 = nn.ReLU(inplace=True) 111 | self.conv2 = conv3x3(64, 64) 112 | self.bn2 = SynchronizedBatchNorm2d(64) 113 | self.relu2 = nn.ReLU(inplace=True) 114 | self.conv3 = conv3x3(64, 128) 115 | self.bn3 = SynchronizedBatchNorm2d(128) 116 | self.relu3 = nn.ReLU(inplace=True) 117 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 118 | 119 | self.layer1 = self._make_layer(block, 64, layers[0]) 120 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 121 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 122 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 123 | self.avgpool = nn.AvgPool2d(7, stride=1) 124 | self.fc = nn.Linear(512 * block.expansion, num_classes) 125 | 126 | for m in self.modules(): 127 | if isinstance(m, nn.Conv2d): 128 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 129 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 130 | elif isinstance(m, SynchronizedBatchNorm2d): 131 | m.weight.data.fill_(1) 132 | m.bias.data.zero_() 133 | 134 | def _make_layer(self, block, planes, blocks, stride=1): 135 | downsample = None 136 | if stride != 1 or self.inplanes != planes * block.expansion: 137 | downsample = nn.Sequential( 138 | nn.Conv2d(self.inplanes, planes * block.expansion, 139 | kernel_size=1, stride=stride, bias=False), 140 | SynchronizedBatchNorm2d(planes * block.expansion), 141 | ) 142 | 143 | layers = [] 144 | layers.append(block(self.inplanes, planes, stride, downsample)) 145 | self.inplanes = planes * block.expansion 146 | for i in range(1, blocks): 147 | layers.append(block(self.inplanes, planes)) 148 | 149 | return nn.Sequential(*layers) 150 | 151 | def forward(self, x): 152 | x = self.relu1(self.bn1(self.conv1(x))) 153 | x = self.relu2(self.bn2(self.conv2(x))) 154 | x = self.relu3(self.bn3(self.conv3(x))) 155 | x = self.maxpool(x) 156 | 157 | x = self.layer1(x) 158 | x = self.layer2(x) 159 | x = self.layer3(x) 160 | x = self.layer4(x) 161 | 162 | x = self.avgpool(x) 163 | x = x.view(x.size(0), -1) 164 | x = self.fc(x) 165 | 166 | return x 167 | 168 | ''' 169 | def resnet18(pretrained=False, **kwargs): 170 | """Constructs a ResNet-18 model. 171 | 172 | Args: 173 | pretrained (bool): If True, returns a model pre-trained on Places 174 | """ 175 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 176 | if pretrained: 177 | model.load_state_dict(load_url(model_urls['resnet18'])) 178 | return model 179 | 180 | 181 | def resnet34(pretrained=False, **kwargs): 182 | """Constructs a ResNet-34 model. 183 | 184 | Args: 185 | pretrained (bool): If True, returns a model pre-trained on Places 186 | """ 187 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 188 | if pretrained: 189 | model.load_state_dict(load_url(model_urls['resnet34'])) 190 | return model 191 | ''' 192 | 193 | def resnet50(pretrained=False, **kwargs): 194 | """Constructs a ResNet-50 model. 195 | 196 | Args: 197 | pretrained (bool): If True, returns a model pre-trained on Places 198 | """ 199 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 200 | if pretrained: 201 | model.load_state_dict(load_url(model_urls['resnet50']), strict=False) 202 | return model 203 | 204 | 205 | def resnet101(pretrained=False, **kwargs): 206 | """Constructs a ResNet-101 model. 207 | 208 | Args: 209 | pretrained (bool): If True, returns a model pre-trained on Places 210 | """ 211 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 212 | if pretrained: 213 | model.load_state_dict(load_url(model_urls['resnet101']), strict=False) 214 | return model 215 | 216 | # def resnet152(pretrained=False, **kwargs): 217 | # """Constructs a ResNet-152 model. 
218 | # 219 | # Args: 220 | # pretrained (bool): If True, returns a model pre-trained on Places 221 | # """ 222 | # model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 223 | # if pretrained: 224 | # model.load_state_dict(load_url(model_urls['resnet152'])) 225 | # return model 226 | 227 | def load_url(url, model_dir='./pretrained', map_location=None): 228 | if not os.path.exists(model_dir): 229 | os.makedirs(model_dir) 230 | filename = url.split('/')[-1] 231 | cached_file = os.path.join(model_dir, filename) 232 | if not os.path.exists(cached_file): 233 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 234 | urlretrieve(url, cached_file) 235 | return torch.load(cached_file, map_location=map_location) 236 | -------------------------------------------------------------------------------- /netdissect/upsegmodel/resnext.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | try: 7 | from lib.nn import SynchronizedBatchNorm2d 8 | except ImportError: 9 | from torch.nn import BatchNorm2d as SynchronizedBatchNorm2d 10 | 11 | try: 12 | from urllib import urlretrieve 13 | except ImportError: 14 | from urllib.request import urlretrieve 15 | 16 | 17 | __all__ = ['ResNeXt', 'resnext101'] # support resnext 101 18 | 19 | 20 | model_urls = { 21 | #'resnext50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext50-imagenet.pth', 22 | 'resnext101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext101-imagenet.pth' 23 | } 24 | 25 | 26 | def conv3x3(in_planes, out_planes, stride=1): 27 | "3x3 convolution with padding" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=1, bias=False) 30 | 31 | 32 | class GroupBottleneck(nn.Module): 33 | expansion = 2 34 | 35 | def __init__(self, inplanes, planes, stride=1, groups=1, downsample=None): 36 | super(GroupBottleneck, self).__init__() 37 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 38 | self.bn1 = SynchronizedBatchNorm2d(planes) 39 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 40 | padding=1, groups=groups, bias=False) 41 | self.bn2 = SynchronizedBatchNorm2d(planes) 42 | self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=False) 43 | self.bn3 = SynchronizedBatchNorm2d(planes * 2) 44 | self.relu = nn.ReLU(inplace=True) 45 | self.downsample = downsample 46 | self.stride = stride 47 | 48 | def forward(self, x): 49 | residual = x 50 | 51 | out = self.conv1(x) 52 | out = self.bn1(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv2(out) 56 | out = self.bn2(out) 57 | out = self.relu(out) 58 | 59 | out = self.conv3(out) 60 | out = self.bn3(out) 61 | 62 | if self.downsample is not None: 63 | residual = self.downsample(x) 64 | 65 | out += residual 66 | out = self.relu(out) 67 | 68 | return out 69 | 70 | 71 | class ResNeXt(nn.Module): 72 | 73 | def __init__(self, block, layers, groups=32, num_classes=1000): 74 | self.inplanes = 128 75 | super(ResNeXt, self).__init__() 76 | self.conv1 = conv3x3(3, 64, stride=2) 77 | self.bn1 = SynchronizedBatchNorm2d(64) 78 | self.relu1 = nn.ReLU(inplace=True) 79 | self.conv2 = conv3x3(64, 64) 80 | self.bn2 = SynchronizedBatchNorm2d(64) 81 | self.relu2 = nn.ReLU(inplace=True) 82 | self.conv3 = conv3x3(64, 128) 83 | self.bn3 = SynchronizedBatchNorm2d(128) 84 | self.relu3 = nn.ReLU(inplace=True) 85 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 86 | 87 | 
self.layer1 = self._make_layer(block, 128, layers[0], groups=groups) 88 | self.layer2 = self._make_layer(block, 256, layers[1], stride=2, groups=groups) 89 | self.layer3 = self._make_layer(block, 512, layers[2], stride=2, groups=groups) 90 | self.layer4 = self._make_layer(block, 1024, layers[3], stride=2, groups=groups) 91 | self.avgpool = nn.AvgPool2d(7, stride=1) 92 | self.fc = nn.Linear(1024 * block.expansion, num_classes) 93 | 94 | for m in self.modules(): 95 | if isinstance(m, nn.Conv2d): 96 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups 97 | m.weight.data.normal_(0, math.sqrt(2. / n)) 98 | elif isinstance(m, SynchronizedBatchNorm2d): 99 | m.weight.data.fill_(1) 100 | m.bias.data.zero_() 101 | 102 | def _make_layer(self, block, planes, blocks, stride=1, groups=1): 103 | downsample = None 104 | if stride != 1 or self.inplanes != planes * block.expansion: 105 | downsample = nn.Sequential( 106 | nn.Conv2d(self.inplanes, planes * block.expansion, 107 | kernel_size=1, stride=stride, bias=False), 108 | SynchronizedBatchNorm2d(planes * block.expansion), 109 | ) 110 | 111 | layers = [] 112 | layers.append(block(self.inplanes, planes, stride, groups, downsample)) 113 | self.inplanes = planes * block.expansion 114 | for i in range(1, blocks): 115 | layers.append(block(self.inplanes, planes, groups=groups)) 116 | 117 | return nn.Sequential(*layers) 118 | 119 | def forward(self, x): 120 | x = self.relu1(self.bn1(self.conv1(x))) 121 | x = self.relu2(self.bn2(self.conv2(x))) 122 | x = self.relu3(self.bn3(self.conv3(x))) 123 | x = self.maxpool(x) 124 | 125 | x = self.layer1(x) 126 | x = self.layer2(x) 127 | x = self.layer3(x) 128 | x = self.layer4(x) 129 | 130 | x = self.avgpool(x) 131 | x = x.view(x.size(0), -1) 132 | x = self.fc(x) 133 | 134 | return x 135 | 136 | 137 | ''' 138 | def resnext50(pretrained=False, **kwargs): 139 | """Constructs a ResNet-50 model. 140 | 141 | Args: 142 | pretrained (bool): If True, returns a model pre-trained on Places 143 | """ 144 | model = ResNeXt(GroupBottleneck, [3, 4, 6, 3], **kwargs) 145 | if pretrained: 146 | model.load_state_dict(load_url(model_urls['resnext50']), strict=False) 147 | return model 148 | ''' 149 | 150 | 151 | def resnext101(pretrained=False, **kwargs): 152 | """Constructs a ResNet-101 model. 153 | 154 | Args: 155 | pretrained (bool): If True, returns a model pre-trained on Places 156 | """ 157 | model = ResNeXt(GroupBottleneck, [3, 4, 23, 3], **kwargs) 158 | if pretrained: 159 | model.load_state_dict(load_url(model_urls['resnext101']), strict=False) 160 | return model 161 | 162 | 163 | # def resnext152(pretrained=False, **kwargs): 164 | # """Constructs a ResNeXt-152 model. 
165 | # 166 | # Args: 167 | # pretrained (bool): If True, returns a model pre-trained on Places 168 | # """ 169 | # model = ResNeXt(GroupBottleneck, [3, 8, 36, 3], **kwargs) 170 | # if pretrained: 171 | # model.load_state_dict(load_url(model_urls['resnext152'])) 172 | # return model 173 | 174 | 175 | def load_url(url, model_dir='./pretrained', map_location=None): 176 | if not os.path.exists(model_dir): 177 | os.makedirs(model_dir) 178 | filename = url.split('/')[-1] 179 | cached_file = os.path.join(model_dir, filename) 180 | if not os.path.exists(cached_file): 181 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 182 | urlretrieve(url, cached_file) 183 | return torch.load(cached_file, map_location=map_location) 184 | -------------------------------------------------------------------------------- /netdissect/workerpool.py: -------------------------------------------------------------------------------- 1 | ''' 2 | WorkerPool and WorkerBase for handling the common problems in managing 3 | a multiprocess pool of workers that aren't done by multiprocessing.Pool, 4 | including setup with per-process state, debugging by putting the worker 5 | on the main thread, and correct handling of unexpected errors, and ctrl-C. 6 | 7 | To use it, 8 | 1. Put the per-process setup and the per-task work in the 9 | setup() and work() methods of your own WorkerBase subclass. 10 | 2. To prepare the process pool, instantiate a WorkerPool, passing your 11 | subclass type as the first (worker) argument, as well as any setup keyword 12 | arguments. The WorkerPool will instantiate one of your workers in each 13 | worker process (passing in the setup arguments in those processes). 14 | If debugging, the pool can have process_count=0 to force all the work 15 | to be done immediately on the main thread; otherwise all the work 16 | will be passed to other processes. 17 | 3. Whenever there is a new piece of work to distribute, call pool.add(*args). 18 | The arguments will be queued and passed as worker.work(*args) to the 19 | next available worker. 20 | 4. When all the work has been distributed, call pool.join() to wait for all 21 | the work to complete and to finish and terminate all the worker processes. 22 | When pool.join() returns, all the work will have been done. 23 | 24 | No arrangement is made to collect the results of the work: for example, 25 | the return value of work() is ignored. If you need to collect the 26 | results, use your own mechanism (filesystem, shared memory object, queue) 27 | which can be distributed using setup arguments. 28 | ''' 29 | 30 | from multiprocessing import Process, Queue, cpu_count 31 | import signal 32 | import atexit 33 | import sys 34 | 35 | 36 | class WorkerBase(Process): 37 | ''' 38 | Subclass this class and override its work() method (and optionally, 39 | setup() as well) to define the units of work to be done in a process 40 | worker in a woker pool. 41 | ''' 42 | 43 | def __init__(self, i, process_count, queue, initargs): 44 | if process_count > 0: 45 | # Make sure we ignore ctrl-C if we are not on main process. 
46 | signal.signal(signal.SIGINT, signal.SIG_IGN) 47 | self.process_id = i 48 | self.process_count = process_count 49 | self.queue = queue 50 | super(WorkerBase, self).__init__() 51 | self.setup(**initargs) 52 | 53 | def run(self): 54 | # Do the work until None is dequeued 55 | while True: 56 | try: 57 | work_batch = self.queue.get() 58 | except (KeyboardInterrupt, SystemExit): 59 | print('Exiting...') 60 | break 61 | if work_batch is None: 62 | self.queue.put(None) # for another worker 63 | return 64 | self.work(*work_batch) 65 | 66 | def setup(self, **initargs): 67 | ''' 68 | Override this method for any per-process initialization. 69 | Keywoard args are passed from WorkerPool constructor. 70 | ''' 71 | pass 72 | 73 | def work(self, *args): 74 | ''' 75 | Override this method for one-time initialization. 76 | Args are passed from WorkerPool.add() arguments. 77 | ''' 78 | raise NotImplementedError('worker subclass needed') 79 | 80 | 81 | class WorkerPool(object): 82 | ''' 83 | Instantiate this object (passing a WorkerBase subclass type 84 | as its first argument) to create a worker pool. Then call 85 | pool.add(*args) to queue args to distribute to worker.work(*args), 86 | and call pool.join() to wait for all the workers to complete. 87 | ''' 88 | 89 | def __init__(self, worker=WorkerBase, process_count=None, **initargs): 90 | global active_pools 91 | if process_count is None: 92 | process_count = cpu_count() 93 | if process_count == 0: 94 | # zero process_count uses only main process, for debugging. 95 | self.queue = None 96 | self.processes = None 97 | self.worker = worker(None, 0, None, initargs) 98 | return 99 | # Ctrl-C strategy: worker processes should ignore ctrl-C. Set 100 | # this up to be inherited by child processes before forking. 101 | original_sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN) 102 | active_pools[id(self)] = self 103 | self.queue = Queue(maxsize=(process_count * 3)) 104 | self.processes = None # Initialize before trying to construct workers 105 | self.processes = [worker(i, process_count, self.queue, initargs) 106 | for i in range(process_count)] 107 | for p in self.processes: 108 | p.start() 109 | # The main process should handle ctrl-C. Restore this now. 110 | signal.signal(signal.SIGINT, original_sigint_handler) 111 | 112 | def add(self, *work_batch): 113 | if self.queue is None: 114 | if hasattr(self, 'worker'): 115 | self.worker.work(*work_batch) 116 | else: 117 | print('WorkerPool shutting down.', file=sys.stderr) 118 | else: 119 | try: 120 | # The queue can block if the work is so slow it gets full. 121 | self.queue.put(work_batch) 122 | except (KeyboardInterrupt, SystemExit): 123 | # Handle ctrl-C if done while waiting for the queue. 124 | self.early_terminate() 125 | 126 | def join(self): 127 | # End the queue, and wait for all worker processes to complete nicely. 128 | if self.queue is not None: 129 | self.queue.put(None) 130 | for p in self.processes: 131 | p.join() 132 | self.queue = None 133 | # Remove myself from the set of pools that need cleanup on shutdown. 134 | try: 135 | del active_pools[id(self)] 136 | except: 137 | pass 138 | 139 | def early_terminate(self): 140 | # When shutting down unexpectedly, first end the queue. 141 | if self.queue is not None: 142 | try: 143 | self.queue.put_nowait(None) # Nonblocking put throws if full. 144 | self.queue = None 145 | except: 146 | pass 147 | # But then don't wait: just forcibly terminate workers. 
148 | if self.processes is not None: 149 | for p in self.processes: 150 | p.terminate() 151 | self.processes = None 152 | try: 153 | del active_pools[id(self)] 154 | except: 155 | pass 156 | 157 | def __del__(self): 158 | if self.queue is not None: 159 | print('ERROR: workerpool.join() not called!', file=sys.stderr) 160 | self.join() 161 | 162 | 163 | # Error and ctrl-C handling: kill worker processes if the main process ends. 164 | active_pools = {} 165 | 166 | 167 | def early_terminate_pools(): 168 | for _, pool in list(active_pools.items()): 169 | pool.early_terminate() 170 | 171 | 172 | atexit.register(early_terminate_pools) 173 | -------------------------------------------------------------------------------- /netdissect/zdataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy 3 | import itertools 4 | from torch.utils.data import TensorDataset 5 | 6 | 7 | def z_dataset_for_model(model, size=100, seed=1, indices=None): 8 | if indices is not None: 9 | indices = torch.as_tensor(indices, dtype=torch.int64, device='cpu') 10 | zs = z_sample_for_model(model, indices.max().item() + 1, seed) 11 | zs = zs[indices] 12 | else: 13 | zs = z_sample_for_model(model, size, seed) 14 | return TensorDataset(zs) 15 | 16 | 17 | def z_sample_for_model(model, size=100, seed=1): 18 | # If the model is marked with an input shape, use it. 19 | if hasattr(model, 'input_shape'): 20 | sample = standard_z_sample(size, model.input_shape[1], seed=seed).view( 21 | (size,) + model.input_shape[1:]) 22 | return sample 23 | # Examine first conv in model to determine input feature size. 24 | first_layer = [c for c in model.modules() 25 | if isinstance(c, (torch.nn.Conv2d, torch.nn.ConvTranspose2d, 26 | torch.nn.Linear))][0] 27 | # 4d input if convolutional, 2d input if first layer is linear. 28 | if isinstance(first_layer, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)): 29 | sample = standard_z_sample( 30 | size, first_layer.in_channels, seed=seed)[:, :, None, None] 31 | else: 32 | sample = standard_z_sample( 33 | size, first_layer.in_features, seed=seed) 34 | return sample 35 | 36 | 37 | def standard_z_sample(size, depth, seed=1, device=None): 38 | ''' 39 | Generate a standard set of random Z as a (size, z_dimension) tensor. 40 | With the same random seed, it always returns the same z (e.g., 41 | the first one is always the same regardless of the size.) 42 | ''' 43 | # Use numpy RandomState since it can be done deterministically 44 | # without affecting global state 45 | rng = numpy.random.RandomState(seed) 46 | result = torch.from_numpy( 47 | rng.standard_normal(size * depth) 48 | .reshape(size, depth)).float() 49 | if device is not None: 50 | result = result.to(device) 51 | return result 52 | 53 | 54 | def standard_y_sample(size, num_classes, seed=1, device=None): 55 | ''' 56 | Generate a standard set of random categorical as a (size,) tensor 57 | of integers up to (num_classes-1). 58 | With the same random seed, it always returns the same y (e.g., 59 | the first one is always the same regardless of the size.) 
60 | ''' 61 | # Use numpy RandomState since it can be done deterministically 62 | # without affecting global state 63 | rng = numpy.random.RandomState(seed) 64 | result = torch.from_numpy( 65 | rng.randint(num_classes, size=size)).long() 66 | if device is not None: 67 | result = result.to(device) 68 | return result 69 | 70 | 71 | def training_loader(z_generator, batch_size, loader_size=10000): 72 | ''' 73 | Returns an infinite generator that runs through randomized z 74 | batches, forever. 75 | ''' 76 | g_epoch = 1 77 | while True: 78 | z_data = z_dataset_for_model( 79 | z_generator, size=loader_size, seed=g_epoch + 1) 80 | dataloader = torch.utils.data.DataLoader( 81 | z_data, 82 | shuffle=False, 83 | batch_size=batch_size, 84 | num_workers=10, 85 | pin_memory=True) 86 | for batch in dataloader: 87 | yield batch 88 | g_epoch += 1 89 | 90 | 91 | def testing_loader(z_generator, batch_size, test_size=1000): 92 | ''' 93 | Returns an a short iterator that returns a small set of test data. 94 | ''' 95 | z_data = z_dataset_for_model( 96 | z_generator, size=test_size, seed=1) 97 | dataloader = torch.utils.data.DataLoader( 98 | z_data, 99 | shuffle=False, 100 | batch_size=batch_size, 101 | num_workers=10, 102 | pin_memory=True) 103 | return dataloader 104 | 105 | 106 | def epoch_grouper(loader, epoch_size): 107 | ''' 108 | To use with the infinite training loader: groups the training data 109 | batches into epochs of the given size. 110 | ''' 111 | it = iter(loader) 112 | while True: 113 | chunk_it = itertools.islice(it, epoch_size) 114 | try: 115 | first_el = next(chunk_it) 116 | except StopIteration: 117 | return 118 | yield itertools.chain((first_el,), chunk_it) 119 | -------------------------------------------------------------------------------- /notebooks/ipynb_drop_output.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Suppress output and prompt numbers in git version control. 5 | 6 | This script will tell git to ignore prompt numbers and cell output 7 | when looking at ipynb files UNLESS their metadata contains: 8 | 9 | "git" : { "keep_output" : true } 10 | 11 | The notebooks themselves are not changed. 12 | 13 | See also this blogpost: http://pascalbugnion.net/blog/ipython-notebooks-and-git.html. 14 | 15 | Usage instructions 16 | ================== 17 | 18 | 1. Put this script in a directory that is on the system's path. 19 | For future reference, I will assume you saved it in 20 | `~/scripts/ipynb_drop_output`. 21 | 2. Make sure it is executable by typing the command 22 | `chmod +x ~/scripts/ipynb_drop_output`. 23 | 3. Register a filter for ipython notebooks by 24 | putting the following line in `~/.config/git/attributes`: 25 | `*.ipynb filter=clean_ipynb` 26 | 4. Connect this script to the filter by running the following 27 | git commands: 28 | 29 | git config --global filter.clean_ipynb.clean ipynb_drop_output 30 | git config --global filter.clean_ipynb.smudge cat 31 | 32 | To tell git NOT to ignore the output and prompts for a notebook, 33 | open the notebook's metadata (Edit > Edit Notebook Metadata). 
A 34 | panel should open containing the lines: 35 | 36 | { 37 | "name" : "", 38 | "signature" : "some very long hash" 39 | } 40 | 41 | Add an extra line so that the metadata now looks like: 42 | 43 | { 44 | "name" : "", 45 | "signature" : "don't change the hash, but add a comma at the end of the line", 46 | "git" : { "keep_outputs" : true } 47 | } 48 | 49 | You may need to "touch" the notebooks for git to actually register a change, if 50 | your notebooks are already under version control. 51 | 52 | Notes 53 | ===== 54 | 55 | Changed by David Bau to make stripping output the default. 56 | 57 | This script is inspired by http://stackoverflow.com/a/20844506/827862, but 58 | lets the user specify whether the ouptut of a notebook should be kept 59 | in the notebook's metadata, and works for IPython v3.0. 60 | """ 61 | 62 | import sys 63 | import json 64 | 65 | nb = sys.stdin.read() 66 | 67 | json_in = json.loads(nb) 68 | nb_metadata = json_in["metadata"] 69 | keep_output = False 70 | if "git" in nb_metadata: 71 | if "keep_outputs" in nb_metadata["git"] and nb_metadata["git"]["keep_outputs"]: 72 | keep_output = True 73 | if keep_output: 74 | sys.stdout.write(nb) 75 | exit() 76 | 77 | 78 | ipy_version = int(json_in["nbformat"])-1 # nbformat is 1 more than actual version. 79 | 80 | def strip_output_from_cell(cell): 81 | if "outputs" in cell: 82 | cell["outputs"] = [] 83 | if "prompt_number" in cell: 84 | del cell["prompt_number"] 85 | if "execution_count" in cell: 86 | cell["execution_count"] = None 87 | 88 | 89 | if ipy_version == 2: 90 | for sheet in json_in["worksheets"]: 91 | for cell in sheet["cells"]: 92 | strip_output_from_cell(cell) 93 | else: 94 | for cell in json_in["cells"]: 95 | strip_output_from_cell(cell) 96 | 97 | json.dump(json_in, sys.stdout, sort_keys=True, indent=1, separators=(",",": ")) 98 | -------------------------------------------------------------------------------- /notebooks/setup_notebooks.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Start from directory of script 4 | cd "$(dirname "${BASH_SOURCE[0]}")" 5 | 6 | # Set up git config filters so huge output of notebooks is not committed. 7 | git config filter.clean_ipynb.clean "$(pwd)/ipynb_drop_output.py" 8 | git config filter.clean_ipynb.smudge cat 9 | git config filter.clean_ipynb.required true 10 | 11 | # Set up symlinks for the example notebooks 12 | for DIRNAME in datasets results netdissect experiment 13 | do 14 | ln -sfn ../${DIRNAME} . 
15 | done 16 | -------------------------------------------------------------------------------- /notebooks/shapebias_experiment.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from netdissect import parallelfolder, show, tally, nethook, renormalize\n", 10 | "from experiment import readdissect, setting\n", 11 | "import copy, PIL.Image\n", 12 | "from netdissect import upsample, imgsave, imgviz\n", 13 | "import re, torchvision, torch, os\n", 14 | "from IPython.display import SVG\n", 15 | "from matplotlib import pyplot as plt\n", 16 | "\n", 17 | "def normalize_filename(n):\n", 18 | " return re.match(r'^(.*Places365_\\w+_\\d+)', n).group(1)\n", 19 | "\n", 20 | "ds = parallelfolder.ParallelImageFolders(\n", 21 | " ['datasets/places/val', 'datasets/stylized-places/val'],\n", 22 | " transform=torchvision.transforms.Compose([\n", 23 | " torchvision.transforms.Resize(256),\n", 24 | " # transforms.CenterCrop(224),\n", 25 | " torchvision.transforms.CenterCrop(256),\n", 26 | " torchvision.transforms.ToTensor(),\n", 27 | " renormalize.NORMALIZER['imagenet'],\n", 28 | " ]),\n", 29 | " normalize_filename=normalize_filename,\n", 30 | " shuffle=True)\n", 31 | "\n", 32 | "\n", 33 | "layers = [\n", 34 | " 'conv5_3',\n", 35 | " 'conv5_2',\n", 36 | " 'conv5_1',\n", 37 | " 'conv4_3',\n", 38 | " 'conv4_2',\n", 39 | " 'conv4_1',\n", 40 | " 'conv3_3',\n", 41 | " 'conv3_2',\n", 42 | " 'conv3_1',\n", 43 | " 'conv2_2',\n", 44 | " 'conv2_1',\n", 45 | " 'conv1_2',\n", 46 | " 'conv1_1',\n", 47 | "]\n", 48 | "qd = readdissect.DissectVis(layers=layers)\n", 49 | "net = setting.load_classifier('vgg16')\n", 50 | "\n", 51 | "sds = parallelfolder.ParallelImageFolders(\n", 52 | " ['datasets/stylized-places/val'],\n", 53 | " transform=torchvision.transforms.Compose([\n", 54 | " torchvision.transforms.Resize(256),\n", 55 | " # transforms.CenterCrop(224),\n", 56 | " torchvision.transforms.CenterCrop(256),\n", 57 | " torchvision.transforms.ToTensor(),\n", 58 | " renormalize.NORMALIZER['imagenet'],\n", 59 | " ]),\n", 60 | " normalize_filename=normalize_filename,\n", 61 | " shuffle=True)\n", 62 | "\n", 63 | "def s_image(layername, unit):\n", 64 | " result = PIL.Image.open(os.path.join(qd.dir(layername), 's_imgs/unit%d.jpg' % unit))\n", 65 | " result.load()\n", 66 | " return result\n", 67 | "\n", 68 | "\n", 69 | "def su_image(layername, unit):\n", 70 | " result = PIL.Image.open(os.path.join(qd.dir(layername), 'su_imgs/unit%d.jpg' % unit))\n", 71 | " result.load()\n", 72 | " return result\n" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": { 79 | "scrolled": false 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "for layername in layers:\n", 84 | " inst_net = nethook.InstrumentedModel(copy.deepcopy(net)).cuda()\n", 85 | " inst_net.retain_layer('features.' + layername)\n", 86 | " inst_net(ds[0][0][None].cuda())\n", 87 | " sample_act = inst_net.retained_layer('features.' + layername).cpu()\n", 88 | " upfn = upsample.upsampler((64, 64), sample_act.shape[2:])\n", 89 | "\n", 90 | " def flat_acts(batch):\n", 91 | " inst_net(batch.cuda())\n", 92 | " acts = upfn(inst_net.retained_layer('features.' 
+ layername))\n", 93 | " return acts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1])\n", 94 | " s_rq = tally.tally_quantile(flat_acts, sds, cachefile=os.path.join(qd.dir(layername), 's_rq.npz'))\n", 95 | " u_rq = qd.rq(layername)\n", 96 | "\n", 97 | " def intersect_99_fn(uimg, simg):\n", 98 | " s_99 = s_rq.quantiles(0.99)[None,:,None,None].cuda()\n", 99 | " u_99 = u_rq.quantiles(0.99)[None,:,None,None].cuda()\n", 100 | " with torch.no_grad():\n", 101 | " ux, sx = uimg.cuda(), simg.cuda()\n", 102 | " inst_net(ux)\n", 103 | " ur = inst_net.retained_layer('features.' + layername)\n", 104 | " inst_net(sx)\n", 105 | " sr = inst_net.retained_layer('features.' + layername)\n", 106 | " return ((sr > s_99).float() * (ur > u_99).float()).permute(0, 2, 3, 1).reshape(-1, ur.size(1))\n", 107 | " \n", 108 | " intersect_99 = tally.tally_mean(intersect_99_fn, ds,\n", 109 | " cachefile=os.path.join(qd.dir(layername), 'intersect_99.npz'))\n", 110 | " print(layername)\n", 111 | " numerator = intersect_99.mean()\n", 112 | " denominator = (0.02 - intersect_99.mean())\n", 113 | " score = (numerator / denominator).clamp(0, 1)\n", 114 | " plt.plot(score)\n", 115 | " plt.show()\n", 116 | " fig, ax = plt.subplots(1, 1, figsize=(3,1.2), dpi=300)\n", 117 | " ax.hist(score)\n", 118 | " ax.set_ylabel('%s units' % (layername.replace('features.', '')))\n", 119 | " ax.spines['right'].set_visible(False)\n", 120 | " ax.spines['top'].set_visible(False)\n", 121 | " # ax.set_xlabel('unit IoU (stylized vs original)')\n", 122 | " plt.show()\n", 123 | " labelcat_list_h = []\n", 124 | " labelcat_list_l = []\n", 125 | " for i, rec in enumerate(qd.labels[layername]):\n", 126 | " if rec['iou'] and float(rec['iou']) >= 0.04:\n", 127 | " if score[i] > 0.1:\n", 128 | " labelcat_list_h.append((rec['label'], rec['cat']))\n", 129 | " else:\n", 130 | " labelcat_list_l.append((rec['label'], rec['cat']))\n", 131 | " display(SVG(qd.bargraph_from_conceptcatlist(labelcat_list_l)))\n", 132 | " display(SVG(qd.bargraph_from_conceptcatlist(labelcat_list_h)))\n", 133 | " \n", 134 | " ordering = score.sort()[1]\n", 135 | "\n", 136 | " for i in torch.cat([ordering[:5], ordering[-10:]]):\n", 137 | " #if qd.iou(layername, i) > 0.04:\n", 138 | " print(i.item(), score[i].item(), qd.label(layername, i), qd.iou(layername, i))\n", 139 | " display(qd.image(layername, i))\n", 140 | " display(s_image(layername, i))\n", 141 | "\n", 142 | " #result = [qd.iou(layername, i) for i in ordering]\n", 143 | " #plt.plot(result)" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": { 150 | "scrolled": false 151 | }, 152 | "outputs": [], 153 | "source": [ 154 | "fig, axes = plt.subplots(5, 1, figsize=(5,6), dpi=300, sharex=True)\n", 155 | "plotlayers = [\n", 156 | " 'features.conv1_2',\n", 157 | " 'features.conv2_2',\n", 158 | " 'features.conv3_3',\n", 159 | " 'features.conv4_3',\n", 160 | " 'features.conv5_3',\n", 161 | "]\n", 162 | "for i, layername in enumerate(plotlayers):\n", 163 | " inst_net = nethook.InstrumentedModel(copy.deepcopy(net)).cuda()\n", 164 | " inst_net.retain_layer(layername)\n", 165 | " \n", 166 | " def flat_acts(batch):\n", 167 | " inst_net(batch.cuda())\n", 168 | " acts = upfn(inst_net.retained_layer(layername))\n", 169 | " return acts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1])\n", 170 | " s_rq = tally.tally_quantile(flat_acts, sds, cachefile=os.path.join(qd.dir(layername), 's_rq.npz'))\n", 171 | " u_rq = qd.rq(layername)\n", 172 | "\n", 173 | " def intersect_99_fn(uimg, simg):\n", 
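" # For each unit, flag the spatial positions whose response exceeds the unit's 99th-percentile threshold (computed separately over original and stylized images) on BOTH versions of the image; tally_mean below averages these flags into a per-unit agreement rate.\n",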
174 | " s_99 = s_rq.quantiles(0.99)[None,:,None,None].cuda()\n", 175 | " u_99 = u_rq.quantiles(0.99)[None,:,None,None].cuda()\n", 176 | " with torch.no_grad():\n", 177 | " ux, sx = uimg.cuda(), simg.cuda()\n", 178 | " inst_net(ux)\n", 179 | " ur = inst_net.retained_layer(layername)\n", 180 | " inst_net(sx)\n", 181 | " sr = inst_net.retained_layer(layername)\n", 182 | " return ((sr > s_99).float() * (ur > u_99).float()).permute(0, 2, 3, 1).reshape(-1, ur.size(1))\n", 183 | " \n", 184 | " intersect_99 = tally.tally_mean(intersect_99_fn, ds,\n", 185 | " cachefile=os.path.join(qd.dir(layername), 'intersect_99.npz'))\n", 186 | " numerator = intersect_99.mean()\n", 187 | " denominator = (0.02 - intersect_99.mean())\n", 188 | " score = (numerator / denominator).clamp(0, 0.5)\n", 189 | " ax = axes[i]\n", 190 | " ax.hist(score)\n", 191 | " # ax.set_ylabel('%s' % (layername.replace('features.', '')))\n", 192 | " ax.spines['right'].set_visible(False)\n", 193 | " ax.spines['top'].set_visible(False)\n", 194 | " # ax.set_xlabel('unit IoU (stylized vs original)')\n", 195 | "plt.show()\n" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "for u in [166, 107, 268, 434, 436, 437, 73, 220, 299, 494, 485, 477, 462, 338]:\n", 205 | " print(u, score[u].item())" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "qd.dirs" 215 | ] 216 | } 217 | ], 218 | "metadata": { 219 | "kernelspec": { 220 | "display_name": "Python 3", 221 | "language": "python", 222 | "name": "python3" 223 | }, 224 | "language_info": { 225 | "codemirror_mode": { 226 | "name": "ipython", 227 | "version": 3 228 | }, 229 | "file_extension": ".py", 230 | "mimetype": "text/x-python", 231 | "name": "python", 232 | "nbconvert_exporter": "python", 233 | "pygments_lexer": "ipython3", 234 | "version": "3.6.10" 235 | } 236 | }, 237 | "nbformat": 4, 238 | "nbformat_minor": 4 239 | } -------------------------------------------------------------------------------- /setup/denv.yml: -------------------------------------------------------------------------------- 1 | name: denv 2 | channels: 3 | - pytorch 4 | - ostrokach-forge 5 | - conda-forge 6 | dependencies: 7 | - python=3.6 8 | - cudatoolkit=10.2 9 | - cudnn=7.6.0 10 | - pytorch=1.5.1 11 | - torchvision 12 | - mkl-include 13 | - numpy 14 | - scipy 15 | - scikit-learn 16 | - scikit-image 17 | - matplotlib 18 | - seaborn 19 | - numba 20 | - jupyter 21 | - jupyterlab 22 | - pyyaml 23 | - mkl 24 | - tqdm 25 | - pip 26 | -------------------------------------------------------------------------------- /setup/setup_denv.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Bash script to set up an anaconda python-based deep learning environment 4 | # that has support for pytorch, tensorflow, pycaffe in the same environment, 5 | # long with juypter, scipy etc. 6 | 7 | # This should not require root. However, it does copy and build a lot of 8 | # binaries into your ~/.conda directory. If you do not want to store 9 | # these in your homedir disk, then ~/.conda can be a symlink somewhere else. 10 | # (At MIT CSAIL, you should symlink ~/.conda to a directory on NFS or local 11 | # disk instead of leaving it on AFS, or else you will exhaust your quota.) 
12 | 13 | # Start from parent directory of script 14 | cd "$(dirname "$(dirname "$(readlink -f "$0")")")" 15 | 16 | # Default RECIPE 'denv' can be overridden by 'RECIPE=foo setup.sh' 17 | RECIPE=${RECIPE:-denv} 18 | # Default ENV_NAME 'denv' can be overridden by 'ENV_NAME=foo setup.sh' 19 | ENV_NAME="${ENV_NAME:-${RECIPE}}" 20 | echo "Creating conda environment ${ENV_NAME}" 21 | 22 | if [[ ! $(type -P conda) ]] 23 | then 24 | echo "conda not in PATH" 25 | echo "read: https://conda.io/docs/user-guide/install/index.html" 26 | exit 1 27 | fi 28 | 29 | if df "${HOME}/.conda" --type=afs > /dev/null 2>&1 30 | then 31 | echo "Not installing: your ~/.conda directory is on AFS." 32 | echo "Use 'ln -s /some/nfs/dir ~/.conda' to avoid using up your AFS quota." 33 | exit 1 34 | fi 35 | 36 | # Uninstall existing environment 37 | source deactivate 38 | rm -rf ~/.conda/envs/${ENV_NAME} 39 | 40 | # Build new environment: torch and torch vision from source 41 | # CUDA_HOME is needed 42 | # https://github.com/rusty1s/pytorch_scatter/issues/19#issuecomment-449735614 43 | conda env create --name=${ENV_NAME} -f setup/${RECIPE}.yml 44 | 45 | # Set up CUDA_HOME to set itself up correctly on every source activate 46 | # https://stackoverflow.com/questions/31598963 47 | mkdir -p ~/.conda/envs/${ENV_NAME}/etc/conda/activate.d 48 | echo "export CUDA_HOME=/usr/local/cuda-10.2" > \ 49 | ~/.conda/envs/${ENV_NAME}/etc/conda/activate.d/CUDA_HOME.sh 50 | 51 | source activate ${ENV_NAME} 52 | 53 | -------------------------------------------------------------------------------- /stylization/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 MIT CSAIL and David Bau 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /stylization/README.md: -------------------------------------------------------------------------------- 1 | # stylize-datasets 2 | This repository contains code for stylizing arbitrary image datasets using [AdaIN](https://arxiv.org/abs/1703.06868). The code is a generalization of Robert Geirhos' [Stylized-ImageNet](https://github.com/rgeirhos/Stylized-ImageNet) code, which is tailored to stylizing ImageNet. Everything in this repository is based on naoto0804's [pytorch-AdaIN](https://github.com/naoto0804/pytorch-AdaIN) implementation. 
3 | 4 | Given an image dataset, the script creates the specified number of stylized versions of every image while keeping the directory structure and naming scheme intact (useful for existing data loaders or if directory names include class annotations). 5 | 6 | Feel free to open an issue if you have any questions. 7 | 8 | ## Usage 9 | - Dependencies: 10 | - python >= 3.6 11 | - Pillow 12 | - torch 13 | - torchvision 14 | - tqdm 15 | - Download the models: 16 | - either run `bash models/download_models.sh` or download the models manually from [vgg](https://drive.google.com/file/d/108uza-dsmwvbW2zv-G73jtVcMU_2Nb7Y/view)/[decoder](https://drive.google.com/file/d/1w9r1NoYnn7tql1VYG3qDUzkbIks24RBQ/view) and move both files to the `models/` directory 17 | - Get style images: Download train.zip from [Kaggle's painter-by-numbers dataset](https://www.kaggle.com/c/painter-by-numbers/data) 18 | - To stylize a dataset, run `python stylize.py` with the arguments described below. 19 | 20 | Arguments: 21 | - `--content-dir <content-dir>` the top-level directory of the content image dataset (mandatory) 22 | - `--style-dir <style-dir>` the top-level directory of the style images (mandatory) 23 | - `--output-dir <output-dir>` the directory where the stylized dataset will be stored (optional, default: `output/`) 24 | - `--num-styles <N>` number of stylizations to create for each content image (optional, default: `1`) 25 | - `--alpha <alpha>` weight that controls the strength of stylization, should be between 0 and 1 (optional, default: `1`) 26 | - `--extensions <ext1> <ext2> ...` list of image extensions to scan style and content directory for (optional, default: `png, jpeg, jpg`). Note: this is case sensitive, `--extensions jpg` will not scan for files ending in `.JPG`. Image types must be compatible with PIL's `Image.open()` ([Documentation](https://pillow.readthedocs.io/en/5.1.x/handbook/image-file-formats.html)) 27 | - `--content-size <N>` minimum size for content images, resulting in scaling of the shorter side of the content image to `N` (optional, default: `0`). Set this to 0 to keep the original image dimensions. 28 | - `--style-size <N>` minimum size for style images, resulting in scaling of the shorter side of the style image to `N` (optional, default: `512`). Set this to 0 to keep the original image dimensions (for large style images, this will result in high (GPU) memory consumption). 29 | - `--crop` if set, content and style images will be cropped at the center to create square output images 30 | 31 | Here is an example call: 32 | 33 | ``` 34 | python3 stylize.py --content-dir '/home/username/stylize-datasets/images/' --style-dir '/home/username/stylize-datasets/train/' --num-styles 10 --content-size 0 --style-size 256 35 | ``` 36 | -------------------------------------------------------------------------------- /stylization/function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def calc_mean_std(feat, eps=1e-5): 5 | # eps is a small value added to the variance to avoid divide-by-zero. 
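# Returns the per-channel mean and standard deviation of an (N, C, H, W) feature map, each shaped (N, C, 1, 1) so they can be expanded against the feature map in adaptive_instance_normalization below.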
6 | size = feat.data.size() 7 | assert (len(size) == 4) 8 | N, C = size[:2] 9 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 10 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 11 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 12 | return feat_mean, feat_std 13 | 14 | 15 | def adaptive_instance_normalization(content_feat, style_feat): 16 | assert (content_feat.data.size()[:2] == style_feat.data.size()[:2]) 17 | size = content_feat.data.size() 18 | style_mean, style_std = calc_mean_std(style_feat) 19 | content_mean, content_std = calc_mean_std(content_feat) 20 | 21 | normalized_feat = (content_feat - content_mean.expand( 22 | size)) / content_std.expand(size) 23 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 24 | 25 | 26 | def _calc_feat_flatten_mean_std(feat): 27 | # takes 3D feat (C, H, W), return mean and std of array within channels 28 | assert (feat.size()[0] == 3) 29 | assert (isinstance(feat, torch.FloatTensor)) 30 | feat_flatten = feat.view(3, -1) 31 | mean = feat_flatten.mean(dim=-1, keepdim=True) 32 | std = feat_flatten.std(dim=-1, keepdim=True) 33 | return feat_flatten, mean, std 34 | 35 | 36 | def _mat_sqrt(x): 37 | U, D, V = torch.svd(x) 38 | return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t()) 39 | 40 | 41 | def coral(source, target): 42 | # assume both source and target are 3D array (C, H, W) 43 | # Note: flatten -> f 44 | 45 | source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source) 46 | source_f_norm = (source_f - source_f_mean.expand_as( 47 | source_f)) / source_f_std.expand_as(source_f) 48 | source_f_cov_eye = \ 49 | torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3) 50 | 51 | target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target) 52 | target_f_norm = (target_f - target_f_mean.expand_as( 53 | target_f)) / target_f_std.expand_as(target_f) 54 | target_f_cov_eye = \ 55 | torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3) 56 | 57 | source_f_norm_transfer = torch.mm( 58 | _mat_sqrt(target_f_cov_eye), 59 | torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)), 60 | source_f_norm) 61 | ) 62 | 63 | source_f_transfer = source_f_norm_transfer * \ 64 | target_f_std.expand_as(source_f_norm) + \ 65 | target_f_mean.expand_as(source_f_norm) 66 | 67 | return source_f_transfer.view(source.size()) 68 | -------------------------------------------------------------------------------- /stylization/models/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd models 3 | # The VGG-19 network is obtained by: 4 | # 1. converting vgg_normalised.caffemodel to .t7 using loadcaffe 5 | # 2. inserting a convolutional module at the beginning to preprocess the image 6 | # 3. replacing zero-padding with reflection-padding 7 | # The original vgg_normalised.caffemodel can be obtained with: 8 | # "wget -c --no-check-certificate https://bethgelab.org/media/uploads/deeptextures/vgg_normalised.caffemodel" 9 | wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=108uza-dsmwvbW2zv-G73jtVcMU_2Nb7Y' -O vgg_normalised.pth 10 | wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1w9r1NoYnn7tql1VYG3qDUzkbIks24RBQ' -O decoder.pth 11 | cd .. 
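# Note: if the Google Drive fetch fails or is rate-limited, download vgg_normalised.pth and decoder.pth manually from the links in ../README.md and place them in this models/ directory.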
12 | -------------------------------------------------------------------------------- /stylization/net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.autograd import Variable 3 | 4 | from function import adaptive_instance_normalization as adain 5 | from function import calc_mean_std 6 | 7 | decoder = nn.Sequential( 8 | nn.ReflectionPad2d((1, 1, 1, 1)), 9 | nn.Conv2d(512, 256, (3, 3)), 10 | nn.ReLU(), 11 | nn.Upsample(scale_factor=2), 12 | nn.ReflectionPad2d((1, 1, 1, 1)), 13 | nn.Conv2d(256, 256, (3, 3)), 14 | nn.ReLU(), 15 | nn.ReflectionPad2d((1, 1, 1, 1)), 16 | nn.Conv2d(256, 256, (3, 3)), 17 | nn.ReLU(), 18 | nn.ReflectionPad2d((1, 1, 1, 1)), 19 | nn.Conv2d(256, 256, (3, 3)), 20 | nn.ReLU(), 21 | nn.ReflectionPad2d((1, 1, 1, 1)), 22 | nn.Conv2d(256, 128, (3, 3)), 23 | nn.ReLU(), 24 | nn.Upsample(scale_factor=2), 25 | nn.ReflectionPad2d((1, 1, 1, 1)), 26 | nn.Conv2d(128, 128, (3, 3)), 27 | nn.ReLU(), 28 | nn.ReflectionPad2d((1, 1, 1, 1)), 29 | nn.Conv2d(128, 64, (3, 3)), 30 | nn.ReLU(), 31 | nn.Upsample(scale_factor=2), 32 | nn.ReflectionPad2d((1, 1, 1, 1)), 33 | nn.Conv2d(64, 64, (3, 3)), 34 | nn.ReLU(), 35 | nn.ReflectionPad2d((1, 1, 1, 1)), 36 | nn.Conv2d(64, 3, (3, 3)), 37 | ) 38 | 39 | vgg = nn.Sequential( 40 | nn.Conv2d(3, 3, (1, 1)), 41 | nn.ReflectionPad2d((1, 1, 1, 1)), 42 | nn.Conv2d(3, 64, (3, 3)), 43 | nn.ReLU(), # relu1-1 44 | nn.ReflectionPad2d((1, 1, 1, 1)), 45 | nn.Conv2d(64, 64, (3, 3)), 46 | nn.ReLU(), # relu1-2 47 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 48 | nn.ReflectionPad2d((1, 1, 1, 1)), 49 | nn.Conv2d(64, 128, (3, 3)), 50 | nn.ReLU(), # relu2-1 51 | nn.ReflectionPad2d((1, 1, 1, 1)), 52 | nn.Conv2d(128, 128, (3, 3)), 53 | nn.ReLU(), # relu2-2 54 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 55 | nn.ReflectionPad2d((1, 1, 1, 1)), 56 | nn.Conv2d(128, 256, (3, 3)), 57 | nn.ReLU(), # relu3-1 58 | nn.ReflectionPad2d((1, 1, 1, 1)), 59 | nn.Conv2d(256, 256, (3, 3)), 60 | nn.ReLU(), # relu3-2 61 | nn.ReflectionPad2d((1, 1, 1, 1)), 62 | nn.Conv2d(256, 256, (3, 3)), 63 | nn.ReLU(), # relu3-3 64 | nn.ReflectionPad2d((1, 1, 1, 1)), 65 | nn.Conv2d(256, 256, (3, 3)), 66 | nn.ReLU(), # relu3-4 67 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 68 | nn.ReflectionPad2d((1, 1, 1, 1)), 69 | nn.Conv2d(256, 512, (3, 3)), 70 | nn.ReLU(), # relu4-1, this is the last layer used 71 | nn.ReflectionPad2d((1, 1, 1, 1)), 72 | nn.Conv2d(512, 512, (3, 3)), 73 | nn.ReLU(), # relu4-2 74 | nn.ReflectionPad2d((1, 1, 1, 1)), 75 | nn.Conv2d(512, 512, (3, 3)), 76 | nn.ReLU(), # relu4-3 77 | nn.ReflectionPad2d((1, 1, 1, 1)), 78 | nn.Conv2d(512, 512, (3, 3)), 79 | nn.ReLU(), # relu4-4 80 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 81 | nn.ReflectionPad2d((1, 1, 1, 1)), 82 | nn.Conv2d(512, 512, (3, 3)), 83 | nn.ReLU(), # relu5-1 84 | nn.ReflectionPad2d((1, 1, 1, 1)), 85 | nn.Conv2d(512, 512, (3, 3)), 86 | nn.ReLU(), # relu5-2 87 | nn.ReflectionPad2d((1, 1, 1, 1)), 88 | nn.Conv2d(512, 512, (3, 3)), 89 | nn.ReLU(), # relu5-3 90 | nn.ReflectionPad2d((1, 1, 1, 1)), 91 | nn.Conv2d(512, 512, (3, 3)), 92 | nn.ReLU() # relu5-4 93 | ) 94 | 95 | 96 | class Net(nn.Module): 97 | def __init__(self, encoder, decoder): 98 | super(Net, self).__init__() 99 | enc_layers = list(encoder.children()) 100 | self.enc_1 = nn.Sequential(*enc_layers[:4]) # input -> relu1_1 101 | self.enc_2 = nn.Sequential(*enc_layers[4:11]) # relu1_1 -> relu2_1 102 | self.enc_3 = nn.Sequential(*enc_layers[11:18]) # relu2_1 
-> relu3_1 103 | self.enc_4 = nn.Sequential(*enc_layers[18:31]) # relu3_1 -> relu4_1 104 | self.decoder = decoder 105 | self.mse_loss = nn.MSELoss() 106 | 107 | # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image 108 | def encode_with_intermediate(self, input): 109 | results = [input] 110 | for i in range(4): 111 | func = getattr(self, 'enc_{:d}'.format(i + 1)) 112 | results.append(func(results[-1])) 113 | return results[1:] 114 | 115 | # extract relu4_1 from input image 116 | def encode(self, input): 117 | for i in range(4): 118 | input = getattr(self, 'enc_{:d}'.format(i + 1))(input) 119 | return input 120 | 121 | def calc_content_loss(self, input, target): 122 | assert (input.data.size() == target.data.size()) 123 | assert (target.requires_grad is False) 124 | return self.mse_loss(input, target) 125 | 126 | def calc_style_loss(self, input, target): 127 | assert (input.data.size() == target.data.size()) 128 | assert (target.requires_grad is False) 129 | input_mean, input_std = calc_mean_std(input) 130 | target_mean, target_std = calc_mean_std(target) 131 | return self.mse_loss(input_mean, target_mean) + \ 132 | self.mse_loss(input_std, target_std) 133 | 134 | def forward(self, content, style): 135 | style_feats = self.encode_with_intermediate(style) 136 | t = adain(self.encode(content), style_feats[-1]) 137 | 138 | g_t = self.decoder(Variable(t.data, requires_grad=True)) 139 | g_t_feats = self.encode_with_intermediate(g_t) 140 | 141 | loss_c = self.calc_content_loss(g_t_feats[-1], t) 142 | loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0]) 143 | for i in range(1, 4): 144 | loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i]) 145 | return loss_c, loss_s 146 | -------------------------------------------------------------------------------- /stylization/parallel-make-sp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | # Start from directory of script 6 | cd "$(dirname "${BASH_SOURCE[0]}")" 7 | 8 | # Make sure the datasets directory is symlinked 9 | ln -sfn ../datasets data 10 | 11 | # Only create a stylized validation set 12 | for SPLIT in val # train 13 | do 14 | 15 | DATASET=data/places/${SPLIT} 16 | OUTPUT=data/stylized-places/${SPLIT} 17 | DIRS=$(ls -1 $DATASET) 18 | 19 | # Loop through every class dir: this script can be run in parallel 20 | for D in $DIRS 21 | do 22 | 23 | # Only one process should work on a directory at once 24 | if [ ! 
-d "${OUTPUT}/${D}" ] 25 | then 26 | 27 | # If the process fails, then remove the whole directory 28 | mkdir -p "${OUTPUT}/${D}" 29 | trap "rm -rf ${OUTPUT}/${D}; exit" INT TERM EXIT 30 | 31 | python3 stylize.py \ 32 | --content-dir "${DATASET}/${D}" \ 33 | --output-dir "${OUTPUT}/${D}" \ 34 | --style-dir 'data/painter-by-numbers/train/' \ 35 | --num-styles 1 \ 36 | --content-size 0 \ 37 | --style-size 256 38 | 39 | # Mark the directory as complete 40 | date > "${OUTPUT}/${D}/done.txt" 41 | trap - INT TERM EXIT 42 | 43 | fi 44 | 45 | done 46 | 47 | done 48 | -------------------------------------------------------------------------------- /stylization/stylize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | from function import adaptive_instance_normalization 4 | import net 5 | from pathlib import Path 6 | from PIL import Image 7 | import random 8 | import torch 9 | import torch.nn as nn 10 | import torchvision.transforms 11 | from torchvision.utils import save_image 12 | from tqdm import tqdm 13 | import zlib # For hash 14 | 15 | parser = argparse.ArgumentParser(description='This script applies the AdaIN style transfer method to arbitrary datasets.') 16 | parser.add_argument('--content-dir', type=str, 17 | help='Directory path to a batch of content images') 18 | parser.add_argument('--style-dir', type=str, 19 | help='Directory path to a batch of style images') 20 | parser.add_argument('--output-dir', type=str, default='output', 21 | help='Directory to save the output images') 22 | parser.add_argument('--num-styles', type=int, default=1, help='Number of styles to \ 23 | create for each image (default: 1)') 24 | parser.add_argument('--alpha', type=float, default=1.0, 25 | help='The weight that controls the degree of \ 26 | stylization. 
Should be between 0 and 1') 27 | parser.add_argument('--extensions', nargs='+', type=str, default=['png', 'jpeg', 'jpg'], help='List of image extensions to scan style and content directory for (case sensitive), default: png, jpeg, jpg') 28 | 29 | # Advanced options 30 | parser.add_argument('--content-size', type=int, default=0, 31 | help='New (minimum) size for the content image, \ 32 | keeping the original size if set to 0') 33 | parser.add_argument('--style-size', type=int, default=512, 34 | help='New (minimum) size for the style image, \ 35 | keeping the original size if set to 0') 36 | parser.add_argument('--crop', action='store_true', 37 | help='do center crop to create squared image') 38 | 39 | 40 | def input_transform(size, crop): 41 | transform_list = [] 42 | if size != 0: 43 | transform_list.append(torchvision.transforms.Resize(size)) 44 | if crop: 45 | transform_list.append(torchvision.transforms.CenterCrop(size)) 46 | transform_list.append(torchvision.transforms.ToTensor()) 47 | transform = torchvision.transforms.Compose(transform_list) 48 | return transform 49 | 50 | def style_transfer(vgg, decoder, content, style, alpha=1.0): 51 | assert (0.0 <= alpha <= 1.0) 52 | content_f = vgg(content) 53 | style_f = vgg(style) 54 | feat = adaptive_instance_normalization(content_f, style_f) 55 | feat = feat * alpha + content_f * (1 - alpha) 56 | return decoder(feat) 57 | 58 | def main(): 59 | args = parser.parse_args() 60 | 61 | # deterministic hash per output directory (adler32 needs bytes, so encode the path first) 62 | random.seed(zlib.adler32(args.output_dir.encode('utf-8'))) 63 | 64 | # set content and style directories 65 | content_dir = Path(args.content_dir) 66 | style_dir = Path(args.style_dir) 67 | style_dir = style_dir.resolve() 68 | output_dir = Path(args.output_dir) 69 | output_dir = output_dir.resolve() 70 | assert style_dir.is_dir(), 'Style directory not found' 71 | 72 | # collect content files 73 | extensions = args.extensions 74 | assert len(extensions) > 0, 'No file extensions specified' 75 | content_dir = Path(content_dir) 76 | content_dir = content_dir.resolve() 77 | assert content_dir.is_dir(), 'Content directory not found' 78 | dataset = [] 79 | for ext in extensions: 80 | dataset += list(content_dir.rglob('*.' + ext)) 81 | 82 | assert len(dataset) > 0, 'No images with specified extensions found in content directory ' + str(content_dir) 83 | content_paths = sorted(dataset) 84 | print('Found %d content images in %s' % (len(content_paths), content_dir)) 85 | 86 | # collect style files 87 | styles = [] 88 | for ext in extensions: 89 | styles += list(style_dir.rglob('*.'
+ ext)) 90 | 91 | assert len(styles) > 0, 'No images with specified extensions found in style directory ' + str(style_dir) 92 | styles = sorted(styles) 93 | print('Found %d style images in %s' % (len(styles), style_dir)) 94 | 95 | decoder = net.decoder 96 | vgg = net.vgg 97 | 98 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 99 | 100 | decoder.eval() 101 | vgg.eval() 102 | 103 | decoder.load_state_dict(torch.load('models/decoder.pth')) 104 | vgg.load_state_dict(torch.load('models/vgg_normalised.pth')) 105 | vgg = nn.Sequential(*list(vgg.children())[:31]) 106 | 107 | vgg.to(device) 108 | decoder.to(device) 109 | 110 | content_tf = input_transform(args.content_size, args.crop) 111 | style_tf = input_transform(args.style_size, args.crop) 112 | 113 | 114 | # disable decompression bomb errors 115 | Image.MAX_IMAGE_PIXELS = None 116 | skipped_imgs = [] 117 | 118 | # actual style transfer as in AdaIN 119 | with tqdm(total=len(content_paths)) as pbar: 120 | for content_path in content_paths: 121 | try: 122 | content_img = Image.open(content_path).convert('RGB') 123 | for style_path in random.sample(styles, args.num_styles): 124 | style_img = Image.open(style_path).convert('RGB') 125 | 126 | content = content_tf(content_img) 127 | style = style_tf(style_img) 128 | style = style.to(device).unsqueeze(0) 129 | content = content.to(device).unsqueeze(0) 130 | with torch.no_grad(): 131 | output = style_transfer(vgg, decoder, content, style, 132 | args.alpha) 133 | output = output.cpu() 134 | 135 | rel_path = content_path.relative_to(content_dir) 136 | out_dir = output_dir.joinpath(rel_path.parent) 137 | 138 | # create directory structure if it does not exist 139 | if not out_dir.is_dir(): 140 | out_dir.mkdir(parents=True) 141 | 142 | content_name = content_path.stem 143 | style_name = style_path.stem 144 | out_filename = content_name + '-stylized-' + style_name + content_path.suffix 145 | output_name = out_dir.joinpath(out_filename) 146 | 147 | save_image(output, output_name, padding=0) #default image padding is 2. 
148 | style_img.close() 149 | content_img.close() 150 | except OSError as e: 151 | print('Skipping stylization of %s due to an error' %(content_path)) 152 | skipped_imgs.append(content_path) 153 | continue 154 | except RuntimeError as e: 155 | print('Skipping stylization of %s due to an error' %(content_path)) 156 | skipped_imgs.append(content_path) 157 | continue 158 | finally: 159 | pbar.update(1) 160 | 161 | if(len(skipped_imgs) > 0): 162 | with open(output_dir.joinpath('skipped_imgs.txt'), 'w') as f: 163 | for item in skipped_imgs: 164 | f.write("%s\n" % item) 165 | 166 | if __name__ == '__main__': 167 | main() 168 | -------------------------------------------------------------------------------- /www/arxiv-thumb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidbau/dissect/9421eaa8672fd051088de6c0225a385064070935/www/arxiv-thumb.png -------------------------------------------------------------------------------- /www/classifier-dissection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidbau/dissect/9421eaa8672fd051088de6c0225a385064070935/www/classifier-dissection.png -------------------------------------------------------------------------------- /www/classifier-intervention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidbau/dissect/9421eaa8672fd051088de6c0225a385064070935/www/classifier-intervention.png -------------------------------------------------------------------------------- /www/dissection-compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidbau/dissect/9421eaa8672fd051088de6c0225a385064070935/www/dissection-compare.png -------------------------------------------------------------------------------- /www/gandissect-tutorial.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidbau/dissect/9421eaa8672fd051088de6c0225a385064070935/www/gandissect-tutorial.png -------------------------------------------------------------------------------- /www/generator-dissection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidbau/dissect/9421eaa8672fd051088de6c0225a385064070935/www/generator-dissection.png -------------------------------------------------------------------------------- /www/generator-intervention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidbau/dissect/9421eaa8672fd051088de6c0225a385064070935/www/generator-intervention.png -------------------------------------------------------------------------------- /www/netdissect-tutorial.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidbau/dissect/9421eaa8672fd051088de6c0225a385064070935/www/netdissect-tutorial.png -------------------------------------------------------------------------------- /www/netdissect_code.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidbau/dissect/9421eaa8672fd051088de6c0225a385064070935/www/netdissect_code.png -------------------------------------------------------------------------------- /www/paper-thumb.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidbau/dissect/9421eaa8672fd051088de6c0225a385064070935/www/paper-thumb.png -------------------------------------------------------------------------------- /www/si-thumb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidbau/dissect/9421eaa8672fd051088de6c0225a385064070935/www/si-thumb.png -------------------------------------------------------------------------------- /www/website-thumb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidbau/dissect/9421eaa8672fd051088de6c0225a385064070935/www/website-thumb.png --------------------------------------------------------------------------------